'''
Functions for working with quaternions. Note that all the functions also
work on arrays, and can deal with full quaternions as well as with
quaternion vectors.

'''

'''
author: Thomas Haslwanter
date:   Jan-2014
ver:    2.4
'''

from numpy import sqrt, sum, r_, c_, hstack, cos, sin, atleast_2d, \
     zeros, shape, vstack, prod, min, arcsin, pi, tile, array, copysign, \
     reshape
import matplotlib.pyplot as plt

def deg2quat(inDeg):
    '''
    Convert axis-angles or plain degree into the corresponding quaternion values.
    Can be used with a plain number or with an axis angle.

    Every element is handled independently: the angle (in degrees) is first
    wrapped into the interval [-180, 180), and then the sine of half of that
    angle is returned.

    Parameters
    ----------
    inDeg : float or array_like, shape (N,3)
        rotation angle(s) in degrees, or axis-angle vector(s) whose components
        are given in degrees. Plain (nested) lists are accepted as well as
        numpy arrays.
    
    Returns
    -------
    outQuat : float or array (N,3)
        number or quaternion vector.
    
    Notes
    -----
    
    .. math::
        | \\vec{q} | = sin(\\theta/2)

    More info under 
    http://en.wikipedia.org/wiki/Quaternion
    
    Examples
    --------
    >>> quat.deg2quat(array([[10,20,30], [20,30,40]]))
    array([[ 0.08715574,  0.17364818,  0.25881905],
       [ 0.17364818,  0.25881905,  0.34202014]])

    >>> quat.deg2quat(10)
    0.087155742747658166

    '''
    # Convert array_like input (e.g. nested lists) so that the arithmetic
    # below is elementwise; a plain number becomes a 0-d array, and numpy
    # arithmetic on it still yields a scalar.
    deg = array(inDeg)
    # wrap into [-180, 180), so that e.g. 370 deg is treated like 10 deg
    deg = (deg + 180) % 360 - 180
    return sin(0.5 * deg * pi/180)
    
def quatconj(q):
    ''' Conjugate quaternion(s).

    The scalar part is kept and the vector part changes sign. Quaternion
    vectors (three columns) are first expanded into full unit quaternions
    with ``vect2quat``, so the result always has four columns.

    Parameters
    ----------
    q: array_like, shape ([3,4],) or (N,[3/4])
        quaternion or quaternion vectors

    Returns
    -------
    qconj : array, shape (4,) for a single quaternion, else (N,4)
        conjugate quaternion(s)


    Examples
    --------
    >>> quat.quatconj([0,0,0.1])
    array([ 0.99498744, -0.        , -0.        , -0.1       ])

    >>> quat.quatconj([[cos(0.1),0,0,sin(0.1)],
    ...    [cos(0.2), 0, sin(0.2), 0]])
    array([[ 0.99500417, -0.        , -0.        , -0.09983342],
           [ 0.98006658, -0.        , -0.19866933, -0.        ]])

    '''
    quats = atleast_2d(q)
    if quats.shape[1] == 3:
        # expand quaternion vectors to full (unit) quaternions first
        quats = vect2quat(quats)

    signFlip = r_[1, -1, -1, -1]
    conjugated = quats * signFlip

    # a single quaternion is returned as a flat array
    return conjugated.ravel() if len(quats) == 1 else conjugated

def quatinv(q):
    ''' Quaternion inversion.

    Quaternion vectors (three columns) are treated as unit quaternions. The
    inverse of a unit quaternion is its conjugate, so the vectors are simply
    negated. Full quaternions (four columns) are conjugated and divided by
    their squared length, row by row.

    Parameters
    ----------
    q: array_like, shape ([3,4],) or (N,[3/4])
        quaternion or quaternion vectors

    Returns
    -------
    qinv : array, shape (N,3) or (N,4)
        inverse quaternion(s); always 2-D, with the same number of columns
        as the input.

    Notes
    -----

    .. math::
          q^{-1} = \\frac{q_0 - \\vec{q}}{|q|^2}

    More info under
    http://en.wikipedia.org/wiki/Quaternion

    Examples
    --------
    >>> quat.quatinv([0,0,0.1])
    array([[-0. , -0. , -0.1]])

    >>> quat.quatinv([[cos(0.1),0,0,sin(0.1)],
    ...               [cos(0.2),0,sin(0.2),0]])
    array([[ 0.99500417, -0.        , -0.        , -0.09983342],
           [ 0.98006658, -0.        , -0.19866933, -0.        ]])
    '''
    quats = atleast_2d(q)

    if quats.shape[1] == 3:
        # quaternion vectors: inverse of a unit quaternion = negated vector
        return -quats

    normSquared = sum(quats**2, 1)
    conjugates = quats * r_[1, -1, -1, -1]
    # divide every row (quaternion) by its own squared norm
    return conjugates / normSquared[:, None]

def quatmult(p,q):
    '''
    Quaternion multiplication: Calculates the product of two quaternions r = p * q
    If one or both of the quaternions have only three columns,
    the scalar component is calculated such that the length
    of the quaternion is one.
    The lengths of the quaternions have to match, or one of
    the two quaternions has to have the length one.
    If both p and q only have 3 components, the returned quaternion
    also only has 3 components (i.e. the quaternion vector). In that case
    the sign of the product is first chosen such that its scalar part is
    non-negative (q and -q describe the same rotation).

    Parameters
    ----------
    p,q : array_like, shape ([3,4],) or (N,[3,4])
        quaternions or quaternion vectors

    Returns
    -------
    r : array, shape (N,4), or (N,3) if both p and q contain quaternion vectors
        N is the larger of the two row counts; the product of two single
        quaternions therefore has shape (1,4) or (1,3).

    Raises
    ------
    AssertionError
        if p and q have different numbers of rows and neither has exactly
        one row.

    Notes
    -----

    .. math::
        p \\circ q = \\sum\\limits_{i=0}^3 {p_i I_i} * \\sum\\limits_{j=0}^3 \\
        {q_j I_j} = (p_0 q_0 - \\vec{p} \\cdot \\vec{q}) + (p_0 \\vec{q} + q_0 \\
        \\vec{p} + \\vec{p} \\times \\vec{q}) \\cdot \\vec{I}

    More info under
    http://en.wikipedia.org/wiki/Quaternion

    Examples
    --------
    >>> p = [cos(0.2), 0, 0, sin(0.2)]
    >>> q = [[0, 0, 0.1],
    ...      [0, 0.1, 0]]
    >>> r = quat.quatmult(p,q)

    '''
    p = atleast_2d(p)
    q = atleast_2d(q)
    # The result is reduced to a quaternion vector only if *both* inputs are
    # quaternion vectors. Written with `and`: `&` binds tighter than `==`,
    # so `a == 3 & b == 3` would be a chained comparison against `3 & b`.
    flag3D = p.shape[1] == 3 and q.shape[1] == 3

    if len(p) != len(q):
        assert (len(p)==1 or len(q)==1), \
            'Both arguments in the quaternion multiplication must have the same number of rows, unless one has only one row.'

    # one row per quaternion component: shape (4, N)
    p = vect2quat(p).T
    q = vect2quat(q).T

    # the output takes the shape of the larger operand; a single-row operand
    # is broadcast against it
    r = zeros(p.shape if p.size > q.size else q.shape)

    # Hamilton product p * q
    r[0] = p[0]*q[0] - p[1]*q[1] - p[2]*q[2] - p[3]*q[3]
    r[1] = p[1]*q[0] + p[0]*q[1] + p[2]*q[3] - p[3]*q[2]
    r[2] = p[2]*q[0] + p[0]*q[2] + p[3]*q[1] - p[1]*q[3]
    r[3] = p[3]*q[0] + p[0]*q[3] + p[1]*q[2] - p[2]*q[1]

    if flag3D:
        # for rotations > 180 deg: flip the sign so that the scalar part
        # (which is dropped below) is non-negative
        negative = r[0] < 0
        r[:, negative] = -r[:, negative]
        r = r[1:]

    return r.T

def quat2deg(inQuat):
    '''Calculate the axis-angle corresponding to a given quaternion.

    Only the vector part of the quaternion is used. The conversion is done
    component by component: every component v becomes 2*arcsin(v),
    expressed in degrees (the inverse of ``deg2quat`` for angles within
    [-180, 180]).

    Parameters
    ----------
    inQuat: float, or array_like, shape ([3/4],) or (N,[3/4])
        quaternion(s) or quaternion vector(s)

    Returns
    -------
    axAng : corresponding axis angle(s)
        float, or shape (3,) or (N,3)

    Notes
    -----

    .. math::
        | \\vec{q} | = sin(\\theta/2)

    More info under
    http://en.wikipedia.org/wiki/Quaternion

    Examples
    --------
    >>> quat.quat2deg(0.1)
    array([ 11.47834095])

    >>> quat.quat2deg([0.1, 0.1, 0])
    array([ 11.47834095,  11.47834095,   0.        ])

    >>> quat.quat2deg([cos(0.1), 0, sin(0.1), 0])
    array([  0.       ,  11.4591559,   0.       ])
    '''
    qVector = quat2vect(inQuat)
    halfAngle = arcsin(qVector)     # [rad]
    return 2 * halfAngle * 180 / pi

def quat2rotmat(inQuat):
    ''' Calculate the rotation matrix corresponding to the quaternion(s).

    A single quaternion yields a (3,3) matrix. If "inQuat" contains more than
    one quaternion, each matrix is flattened into one row (to facilitate the
    work with rows of quaternions); a row can be restored to matrix form by
    reshaping it into a (3,3) shape. Quaternion vectors are first expanded
    into unit quaternions with ``vect2quat``.

    Parameters
    ----------
    inQuat : array_like, shape ([3,4],) or (N,[3,4])
        quaternions or quaternion vectors

    Returns
    -------
    rotMat : array, shape (3,3) for a single quaternion, else (N,9)
        corresponding rotation matrix/matrices (flattened row-wise for N>1)

    Notes
    -----

    .. math::
        {\\bf{R}} = \\left( {\\begin{array}{*{20}{c}}
        {q_0^2 + q_1^2 - q_2^2 - q_3^2}&{2({q_1}{q_2} - {q_0}{q_3})}&{2({q_1}{q_3} + {q_0}{q_2})}\\\\
        {2({q_1}{q_2} + {q_0}{q_3})}&{q_0^2 - q_1^2 + q_2^2 - q_3^2}&{2({q_2}{q_3} - {q_0}{q_1})}\\\\
        {2({q_1}{q_3} - {q_0}{q_2})}&{2({q_2}{q_3} + {q_0}{q_1})}&{q_0^2 - q_1^2 - q_2^2 + q_3^2} \\\\
        \\end{array}} \\right)

    More info under
    http://en.wikipedia.org/wiki/Quaternion

    Examples
    --------
    >>> r = quat.quat2rotmat([0, 0, 0.1])
    >>> r.shape
    (3, 3)
    >>> r
    array([[ 0.98      , -0.19899749,  0.        ],
        [ 0.19899749,  0.98      ,  0.        ],
        [ 0.        ,  0.        ,  1.        ]])
    >>> quat.quat2rotmat([[0, 0, 0.1], [0, 0.1, 0]]).shape
    (2, 9)
    '''
    # unpack the four quaternion components, each an array of length N
    q0, q1, q2, q3 = vect2quat(inQuat).T

    # the nine matrix elements in row-major order: R11, R12, ..., R33
    elements = (
        q0**2 + q1**2 - q2**2 - q3**2,
        2*(q1*q2 - q0*q3),
        2*(q1*q3 + q0*q2),
        2*(q1*q2 + q0*q3),
        q0**2 - q1**2 + q2**2 - q3**2,
        2*(q2*q3 - q0*q1),
        2*(q1*q3 - q0*q2),
        2*(q2*q3 + q0*q1),
        q0**2 - q1**2 - q2**2 + q3**2,
    )

    nQuats = len(q0)
    R = zeros((9, nQuats))
    for row, element in enumerate(elements):
        R[row] = element

    if nQuats == 1:
        # a single quaternion: return a proper 3x3 matrix
        return R.reshape((3, 3))
    # several quaternions: one flattened matrix per row
    return R.T
    
def quat2vect(inQuat):
    '''
    Extract the quaternion vector from a full quaternion.

    Input with four columns loses its first (scalar) column. Any other input
    is returned as given, after conversion to at least 2-D. If the result has
    only a single row (or a single column), it is flattened to 1-D.

    Parameters
    ----------
    inQuat : array_like, shape ([3,4],) or (N,[3,4])
        quaternions or quaternion vectors.

    Returns
    -------
    vect : array, shape (3,) or (N,3)
        corresponding quaternion vectors. For ndarray input the result may
        share memory with the input.

    Notes
    -----
    More info under
    http://en.wikipedia.org/wiki/Quaternion

    Examples
    --------
    >>> quat.quat2vect([[cos(0.2), 0, 0, sin(0.2)],[cos(0.1), 0, sin(0.1), 0]])
    array([[ 0.        ,  0.        ,  0.19866933],
           [ 0.        ,  0.09983342,  0.        ]])

    '''
    quats = atleast_2d(inQuat)

    # drop the scalar part of full quaternions
    vect = quats[:, 1:] if quats.shape[1] == 4 else quats

    return vect.ravel() if min(vect.shape) == 1 else vect

def rotmat2quat(rMat):
    '''
    Convert rotation matrices into the corresponding unit quaternions.

    Accepts either a single (3,3) matrix (or its nine elements as a flat
    array), or an (N,9) array in which every row holds the elements of one
    rotation matrix in row-major order (R11, R12, ..., R33).

    Parameters
    ----------
    rMat : array_like, shape (3,3), (9,) or (N,9)
        single rotation matrix, or matrix with rotation-matrix elements.

    Returns
    -------
    outQuat : array, shape (N,4)
        corresponding full quaternion(s), with a non-negative scalar part;
        always 2-D, so a single matrix gives shape (1,4).

    Notes
    -----
    The signs of the vector components are taken from the antisymmetric part
    of the matrix. Small negative radicands caused by round-off are clipped
    to zero, so nearly-identity matrices do not produce nan.

    .. math::
         \\vec q = 0.5*copysign\\left( {\\begin{array}{*{20}{c}}
        {\\sqrt {1 + {R_{11}} - {R_{22}} - {R_{33}}} ,}\\\\
        {\\sqrt {1 - {R_{11}} + {R_{22}} - {R_{33}}} ,}\\\\
        {\\sqrt {1 - {R_{11}} - {R_{22}} + {R_{33}}} ,}
        \\end{array}\\begin{array}{*{20}{c}}
        {{R_{32}} - {R_{23}}}\\\\
        {{R_{13}} - {R_{31}}}\\\\
        {{R_{21}} - {R_{12}}}
        \\end{array}} \\right)

    More info under
    http://en.wikipedia.org/wiki/Quaternion

    Examples
    --------

    >>> alpha = 0.2
    >>> rotMat = array([[cos(alpha), -sin(alpha), 0],
    ...                 [sin(alpha),  cos(alpha), 0],
    ...                 [0, 0, 1]])
    >>> quat.rotmat2quat(rotMat)
    array([[ 0.99500417,  0.        ,  0.        ,  0.09983342]])

    '''
    # accept array_like input such as nested lists
    rMat = array(rMat, dtype=float)
    if rMat.shape == (3,3) or rMat.shape == (9,):
        # a single matrix -> one column holding its nine elements
        rMat = atleast_2d(rMat.ravel()).T
    else:
        # (N,9) -> (9,N): one row per matrix element
        rMat = rMat.T
    q = zeros((4, rMat.shape[1]))

    R11 = rMat[0]
    R12 = rMat[1]
    R13 = rMat[2]
    R21 = rMat[3]
    R22 = rMat[4]
    R23 = rMat[5]
    R31 = rMat[6]
    R32 = rMat[7]
    R33 = rMat[8]

    # For an exact rotation matrix these radicands equal 4*q_i^2 >= 0;
    # round-off can make them slightly negative, which would turn sqrt into nan.
    radicands = array([1 + R11 - R22 - R33,
                       1 - R11 + R22 - R33,
                       1 - R11 - R22 + R33])
    radicands[radicands < 0] = 0
    signs = array([R32 - R23, R13 - R31, R21 - R12])
    q[1:] = 0.5 * copysign(sqrt(radicands), signs)

    # scalar part from the unit-length condition (same clipping as above)
    scalarSquared = 1 - sum(q[1:]**2, 0)
    scalarSquared[scalarSquared < 0] = 0
    q[0] = sqrt(scalarSquared)

    return q.T
    
def vect2quat(inData):
    ''' Utility function, which turns quaternion vectors into unit quaternions.

    For input with three columns the scalar part is computed as
    sqrt(1 - |v|^2), so that every resulting quaternion has length one.
    Input with four columns is returned unchanged (it is *not* normalized).

    Parameters
    ----------
    inData : array_like, shape ([3,4],) or (N,[3,4])
        quaternion vectors or full quaternions

    Returns
    -------
    quats : array, shape (N,4)
        corresponding quaternions; always 2-D, so a single input gives (1,4).

    Raises
    ------
    ValueError
        if the input does not have 3 or 4 columns, or if the squared length
        of a quaternion vector exceeds 1 by more than 1e-12.

    Notes
    -----
    More info under
    http://en.wikipedia.org/wiki/Quaternion

    Examples
    --------
    >>> quats = array([[0,0, sin(0.1)],[0, sin(0.2), 0]])
    >>> quat.vect2quat(quats)
    array([[ 0.99500417,  0.        ,  0.        ,  0.09983342],
           [ 0.98006658,  0.        ,  0.19866933,  0.        ]])

    '''
    inData = atleast_2d(inData)
    # unpacking also rejects input with more than two dimensions
    (_, n) = inData.shape
    if n not in (3, 4):
        raise ValueError('Quaternion must have 3 or 4 columns')
    if n == 4:
        # already full quaternions: passed through unchanged
        return inData

    qLength = 1-sum(inData**2,1)
    numLimit = 1e-12
    # Check for numerical problems (skipped for empty input, where min()
    # would fail on the zero-size array)
    if qLength.size and min(qLength) < -numLimit:
        raise ValueError('Quaternion is too long!')
    # Correct for numerical problems: tiny negative values become zero
    qLength[qLength<0] = 0
    return hstack((c_[sqrt(qLength)], inData))


if __name__=='__main__':
    # Some simple tests to see if the functions produce the proper output.
    # More extensive tests are found in tests/test_quat.py.
    # `rotate_vector` and `vel2quat` are not defined in this module, and are
    # therefore not exercised here.

    a = r_[cos(0.1), 0,0,sin(0.1)]
    b = r_[cos(0.1),0,sin(0.1), 0]
    c = vstack((a,b))
    d = r_[sin(0.1), 0, 0]
    e = r_[2, 0, sin(0.1), 0]

    print(quatmult(a,a))
    print(quatmult(a,b))
    print(quatmult(c,c))
    print(quatmult(c,a))
    print(quatmult(d,d))

    print('The inverse of {0} is {1}'.format(a, quatinv(a)))
    print('The inverse of {0} is {1}'.format(d, quatinv(d)))
    print('The inverse of {0} is {1}'.format(e, quatinv(e)))
    print(quatmult(e, quatinv(e)))

    print(quat2vect(a))
    print('{0} is {1} degree'.format(a, quat2deg(a)))
    print('{0} is {1} degree'.format(c, quat2deg(c)))
    print(quat2deg(0.2))

    print(deg2quat(15))
    print(deg2quat(quat2deg(a)))

    # round trip: quaternion vectors -> (flattened) rotation matrices -> quaternions
    q = array([[0, 0, sin(0.1)],
               [0, sin(0.01), 0]])
    rMat = quat2rotmat(q)
    print(rMat[1].reshape((3,3)))
    qNew = rotmat2quat(rMat)
    print(qNew)
