""" This is an addition to Enthoughts "quaterion" module.
	The Enthought module can be imported with 
	import enthought.mathematics.quaternion

    Quaternions are frequently used to avoid 'gimbal lock' and/or 'twisting'
    when rotating polygons in 3D space.
"""

# Major library imports.
import numpy as np
import matplotlib.pyplot as mp
#import math
#import enthought.mathematics.quaternion as eq


def qmult(q1, q2):
    """
    Multiply two quaternions (Hamilton product ``q1 * q2``).

    Parameters
    ----------
    q1, q2 : array_like, shape (4,)
        Quaternions with the scalar part first: ``[s, x, y, z]``.
        Plain lists are accepted.

    Returns
    -------
    ndarray, shape (4,)
        The product, scalar part first.

    Notes
    -----
    The product is not commutative: ``qmult(q1, q2)`` and ``qmult(q2, q1)``
    differ in the sign of the cross-product term.
    """
    # asarray: with plain lists, ``s2*v1`` would be list repetition (int s2)
    # or a TypeError (float s2) instead of element-wise scaling.
    q1 = np.asarray(q1)
    q2 = np.asarray(q2)

    s1, v1 = q1[0], q1[1:]
    s2, v2 = q2[0], q2[1:]
    scalar = s1*s2 - np.dot(v1, v2)
    vector = s2*v1 + s1*v2 + np.cross(v1, v2)
    return np.hstack((scalar, vector))

def vect2quat(vect):
    """
    Build a quaternion from its 3D vector part.

    The scalar part is reconstructed as ``sqrt(1 - |vect|**2)``, which is
    non-negative and makes the result unit-norm whenever ``|vect| <= 1``.
    For ``|vect| > 1`` the square root argument is negative and NumPy
    yields ``nan`` for the scalar part.

    Returns
    -------
    ndarray, shape (4,)
        ``[scalar, vect[0], vect[1], vect[2]]``.
    """
    w = np.sqrt(1 - np.dot(vect, vect))  # non-negative scalar part
    return np.hstack((w, vect))

def quat2vect(quat):
    """
    Return the vector part of a quaternion (everything after the scalar).

    The input is sliced, not copied into a new type: a list yields a list,
    an ndarray yields an ndarray view.
    """
    vector_part = quat[1:]
    return vector_part
    
def vel2quat(vel, q0, rate, CStype):
    '''
    Integrate an angular velocity (in deg/s) into the corresponding
    orientation quaternions, one per sample.

    Parameters
    ----------
    vel : array_like, shape (N, 3)
        Angular velocity of each sample, in deg/s. N must be >= 1.
    q0 : array_like, shape (3,)
        Vector-part of the initial quaternion (!!); the scalar part is
        filled in by ``vect2quat``.
    rate : float
        Sampling rate in Hz, i.e. each sample spans 1/rate seconds.
    CStype : str
        How each incremental rotation is composed with the current
        orientation: 'sf' -> q_delta * q (left-multiplication),
        'bf' -> q * q_delta (right-multiplication). Presumably 'sf'/'bf'
        stand for space-fixed / body-fixed velocities -- confirm with callers.

    Returns
    -------
    q_pos : ndarray, shape (N, 4)
        Orientation quaternions, scalar part first. q_pos[0] is the initial
        orientation; q_pos[i+1] applies the rotation of sample i to q_pos[i],
        so the velocity of the last sample is not used.

    Raises
    ------
    ValueError
        If CStype is not 'sf' or 'bf', or if vel contains no samples.
    '''
    # Validate once, up front. Previously an unknown CStype only printed a
    # message and then crashed with UnboundLocalError on the product variable.
    if CStype not in ('sf', 'bf'):
        raise ValueError("I don't know this type of coordinate system: %r "
                         "(expected 'sf' or 'bf')" % (CStype,))

    # Convert from deg/s to rad/s (asarray also accepts nested lists)
    vel = np.asarray(vel, dtype=float) * np.pi/180
    if len(vel) == 0:
        raise ValueError('vel must contain at least one sample')

    vel_t = np.sqrt(np.sum(vel**2, 1))   # angular speed per sample, rad/s
    vel_nonZero = vel_t > 0

    # initialize the quaternions
    q_delta = np.zeros(np.shape(vel))
    q_pos = np.zeros((len(vel), 4))
    q_pos[0, :] = vect2quat(q0)

    # Vector part of each incremental rotation: unit rotation axis times
    # sin(half the angle turned during one sample, i.e. vel_t/(2*rate)).
    # Zero-velocity samples keep q_delta = 0 (no rotation) and are skipped
    # here to avoid dividing by zero.
    dq_total = np.sin(vel_t[vel_nonZero]/(2*rate))
    q_delta[vel_nonZero, :] = (vel[vel_nonZero, :]
                               * (dq_total/vel_t[vel_nonZero])[:, np.newaxis])

    for ii in range(len(vel)-1):
        q1 = vect2quat(q_delta[ii, :])
        q2 = q_pos[ii, :]
        if CStype == 'sf':
            q_pos[ii+1, :] = qmult(q1, q2)
        else:  # 'bf'
            q_pos[ii+1, :] = qmult(q2, q1)

    return q_pos
    
def quat2deg(quat):
    """
    Convert quaternion component(s) to rotation angle(s) in degrees.

    Element-wise inverse of ``deg2quat``: a component ``sin(theta/2)`` maps
    back to ``theta`` (in degrees). Presumably intended for the components
    of a unit quaternion; values outside [-1, 1] give ``nan``.
    """
    half_angle_rad = np.arcsin(quat)
    return 2 * half_angle_rad * 180 / np.pi
    
def deg2quat(deg):
    """
    Convert rotation angle(s) in degrees to quaternion component(s).

    Computes ``sin(deg/2)`` element-wise (angles converted to radians), the
    inverse of ``quat2deg``. Accepts scalars, lists, or arrays.
    """
    # asarray: ``0.5 * deg`` raises TypeError for a plain Python list.
    return np.sin(0.5 * np.asarray(deg) * np.pi / 180)
    
def qinv(q):
    """
    Return the conjugate of quaternion ``q`` (scalar part first).

    The vector part is negated and the scalar part kept. This equals the
    true inverse only for unit quaternions; no division by ``|q|**2`` is
    performed.

    Parameters
    ----------
    q : array_like, shape (4,)
        Quaternion ``[s, x, y, z]``; plain lists are accepted.

    Returns
    -------
    ndarray, shape (4,)
        ``[s, -x, -y, -z]``.
    """
    # asarray: negating a plain list (``-v``) raises TypeError.
    q = np.asarray(q)
    s, v = q[0], q[1:]
    return np.hstack((s, -v))

def rotate_vector(vector, q):
    """
    Rotate a 3D vector by a quaternion: ``q * (0, vector) * qinv(q)``.

    Parameters
    ----------
    vector : array_like, shape (3,)
        Vector to rotate.
    q : array_like, shape (4,)
        Rotation quaternion, scalar part first; plain lists are accepted.
        It should be a unit quaternion: ``qinv`` returns the conjugate, so
        for a non-unit ``q`` the result is additionally scaled by ``|q|**2``.

    Returns
    -------
    ndarray, shape (3,)
        The rotated vector.
    """
    # asarray: qinv/qmult index and negate q, which fails for plain lists.
    q = np.asarray(q)
    qvector = np.hstack((0, vector))
    vRotated = qmult(q, qmult(qvector, qinv(q)))
    return vRotated[1:4]
    
    
if __name__=='__main__':
    # Demo: constant rotation of 100 deg/s about the z-axis, sampled at
    # 100 Hz for 1000 samples, starting from the identity orientation
    # (zero vector part).
    v0 = [0., 0., 100.]
    vel = np.tile(v0, (1000,1))
    rate = 100

    out = vel2quat(vel, [0., 0., 0.], rate, 'sf')
    # Single-argument print(...) works under both Python 2 and Python 3.
    print(out[-1:])
    # Plot the vector part of the orientation quaternion over the samples.
    mp.plot(out[:,1:4])
    mp.show()
