'''
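Utilities for analyzing movements in 3D space: position and orientation from
3D marker recordings, and movement reconstruction from angular-velocity and
linear-acceleration recordings.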
'''

'''
Author: Thomas Haslwanter
Version: 1.2
Date: Jan-2014
'''

from numpy.linalg import inv
import numpy as np
from numpy import r_, sum, pi, sqrt, zeros, shape, sin, tile, arange, float, array, nan, ones_like, rad2deg, deg2rad
from scipy.integrate import cumtrapz
from thLib import quat, vector

def analyze3Dmarkers(MarkerPos, ReferencePos):
    '''
    Take recorded positions from 3 markers, and calculate center-of-mass (COM) and orientation
    Can be used e.g. for the analysis of Optotrak data.

    Parameters
    ----------
    MarkerPos : ndarray, shape (N,9)
        x/y/z coordinates of 3 markers, column order
        [x0, y0, z0, x1, y1, z1, x2, y2, z2].
        A single sample of shape (9,) is treated as N=1.

    ReferencePos : ndarray, shape (9,) or (1,9)
        x/y/z coordinates of markers in the reference position, in the same
        column order as MarkerPos

    Returns
    -------
    Position : ndarray, shape (N,3)
        x/y/z coordinates of COM, relative to the reference position
    Orientation : ndarray, shape (N,3)
        Orientation relative to reference orientation, expressed as the
        vector part of a quaternion

    Example
    -------
    >>> (PosOut, OrientOut) = analyze3Dmarkers(MarkerPos, ReferencePos)

    '''

    MarkerPos = np.atleast_2d(MarkerPos)
    # Accept both a flat (9,) and a (1,9) reference: all code below indexes
    # the reference as a 1-D array (previously the COM part required 2-D and
    # the orientation part required 1-D, so no input shape worked).
    ReferencePos = np.ravel(ReferencePos)

    # Specify where the x-, y-, and z-position are located in the data
    cols = {'x' : r_[(0,3,6)]}
    cols['y'] = cols['x'] + 1
    cols['z'] = cols['x'] + 2

    # Calculate the position: the COM is the mean of the three markers
    cog = np.column_stack([sum(MarkerPos[:, cols[axis]], axis=1)
                           for axis in 'xyz']) / 3.
    cog_ref = np.array([sum(ReferencePos[cols[axis]])
                        for axis in 'xyz']) / 3.

    position = cog - cog_ref

    # Calculate the orientation
    numPoints = len(MarkerPos)
    orientation = np.ones((numPoints,3))

    refOrientation = vector.plane_orientation(ReferencePos[:3], ReferencePos[3:6], ReferencePos[6:])
    # Reference markers, one marker per row (loop-invariant)
    refPos_matrix = np.reshape(ReferencePos, (3,3))

    for ii in range(numPoints):
        # The three points define a triangle. The first rotation (q1) is such
        # that the orientation of the reference-triangle, defined as the
        # direction perpendicular to the triangle, is rotated along the
        # shortest path to the current orientation.
        # In other words, this is a rotation out of the plane spanned by the
        # three marker points.
        curOrientation = vector.plane_orientation(MarkerPos[ii,:3], MarkerPos[ii,3:6], MarkerPos[ii,6:])
        alpha = vector.angle(refOrientation, curOrientation)

        if alpha > 0:
            # NOTE(review): for alpha == pi (triangle flipped) the cross
            # product is zero and n1 becomes NaN.
            n1 = np.cross(refOrientation, curOrientation)
            n1 = n1/np.linalg.norm(n1)
            q1 = n1 * sin(alpha/2)
        else:
            q1 = r_[0,0,0]

        # Now rotate the triangle into this orientation ...
        refPos_after_q1 = vector.rotate_vector(refPos_matrix, q1)

        # ... and find which further rotation (q2), within the plane spanned
        # by the three marker points, brings the data into the measured
        # orientation.
        Vector10 = MarkerPos[ii,:3] - MarkerPos[ii,3:6]
        vector10_ref = refPos_after_q1[0] - refPos_after_q1[1]
        beta = vector.angle(Vector10, vector10_ref)

        q2 = curOrientation * np.sin(beta/2)

        # The angle is unsigned; choose the rotation sense from the cross product
        if np.cross(vector10_ref, Vector10).dot(curOrientation) <= 0:
            q2 = -q2
        orientation[ii,:] = quat.quatmult(q2, q1)

    return (position, orientation)

def movement_from_markers(r0, Position, Orientation):
    '''
    Movement trajectory of a point on an object, computed from the position
    and orientation of a set of markers and from the position of the point
    relative to the markers at t=0.

    Parameters
    ----------
    r0 : ndarray (3,)
        Position of point relative to center of markers, when the object is
        in the reference position.
    Position : ndarray, shape (N,3)
        x/y/z coordinates of COM, relative to the reference position
    Orientation : ndarray, shape (N,3)
        Orientation relative to reference orientation, expressed as quaternion

    Returns
    -------
    mov : ndarray, shape (N,3)
        x/y/z coordinates of the position on the object, relative to the
        reference position of the markers

    Notes
    -----
    The point is the marker COM plus the offset r0, rotated into the current
    orientation:

      .. math::

          \\vec C(t) = \\vec M(t) + \\vec r(t) = \\vec M(t) +
          {{\\bf{R}}_{mov}}(t) \\cdot \\vec r({t_0})

    Examples
    --------
    >>> t = np.arange(0,10,0.1)
    >>> translation = (np.c_[[1,1,0]]*t).T
    >>> M = np.empty((3,3))
    >>> M[0] = np.r_[0,0,0]
    >>> M[1]= np.r_[1,0,0]
    >>> M[2] = np.r_[1,1,0]
    >>> M -= np.mean(M, axis=0)
    >>> q = np.vstack((np.zeros_like(t), np.zeros_like(t),quat.deg2quat(100*t))).T
    >>> M0 = vector.rotate_vector(M[0], q) + translation
    >>> M1 = vector.rotate_vector(M[1], q) + translation
    >>> M2 = vector.rotate_vector(M[2], q) + translation
    >>> data = np.hstack((M0,M1,M2))
    >>> (pos, ori) = analyze3Dmarkers(data, data[0])
    >>> r0 = np.r_[1,2,3]
    >>> movement = movement_from_markers(r0, pos, ori)

    '''

    # Offset of the point, rotated into the current marker orientation ...
    offset_now = vector.rotate_vector(r0, Orientation)
    # ... attached to the current marker COM.
    return Position + offset_now
    
    
def vel2quat(vel, q0, rate, CStype):
    '''
    Take an angular velocity (in deg/s), and convert it into the
    corresponding orientation quaternion.

    Parameters
    ----------
    vel : array, shape (3,) or (N,3)
        angular velocity [deg/s]. A single (3,) sample is treated as N=1.
    q0 : array (3,)
        vector-part of the initial orientation quaternion (!!)
    rate : float
        sampling rate (in [Hz])
    CStype:  string
        coordinate_system, space-fixed ("sf") or body_fixed ("bf")

    Returns
    -------
    quats : array, shape (N,4)
        unit quaternion vectors. Row 0 is built from q0; row i+1 combines
        row i with the rotation during sample i (so the last velocity sample
        is not used).

    Raises
    ------
    ValueError
        If CStype is neither 'sf' nor 'bf'.

    Notes
    -----
    For angular velocity with respect to space ("sf"), the orientation is given by

      .. math::
          q(t) = \\Delta q(t_n) \\circ \\Delta q(t_{n-1}) \\circ ... \\circ \\Delta q(t_2) \\circ \\Delta q(t_1) \\circ q(t_0)

      .. math::
        \\Delta \\vec{q_i} = \\vec{n(t)}\\sin (\\frac{\\Delta \\phi (t_i)}{2}) = \\frac{\\vec \\omega (t_i)}{\\left| {\\vec \\omega (t_i)} \\right|}\\sin \\left( \\frac{\\left| {\\vec \\omega ({t_i})} \\right|\\Delta t}{2} \\right)

    For angular velocity with respect to the body ("bf"), the sequence of quaternions is inverted.

    Take care that you choose a high enough sampling rate!

    Examples
    --------
    >>> v0 = np.r_[0., 0., 100.]
    >>> vel = np.tile(v0, (1000,1))
    >>> rate = 100
    >>> out = vel2quat(vel, [0., 0., 0.], rate, 'sf')
    >>> out[-1:]
    array([[-0.76040597,  0.        ,  0.        ,  0.64944805]])

    '''

    # Validate up front: previously an unknown type only printed a message
    # and then crashed on the undefined product quaternion.
    if CStype not in ('sf', 'bf'):
        raise ValueError(
            "I don't know this type of coordinate system: {0!r}".format(CStype))

    # convert from deg/s to rad/s; allow a single (3,) sample
    vel = np.atleast_2d(deg2rad(vel))

    # angular speed [rad/s] of each sample
    vel_t = sqrt(sum(vel**2, axis=1))
    vel_nonZero = vel_t > 0

    # Vector part of the incremental rotation during each sample:
    # direction omega/|omega|, magnitude sin(|omega|*dt/2). Zero where omega=0.
    q_delta = zeros(shape(vel))
    dq_total = sin(vel_t[vel_nonZero]/(2*rate))
    q_delta[vel_nonZero,:] = (vel[vel_nonZero,:]
                              * (dq_total/vel_t[vel_nonZero])[:, np.newaxis])

    # initialize the quaternion
    q_pos = zeros((len(vel),4))
    q_pos[0,:] = quat.vect2quat(q0)

    for ii in range(len(vel)-1):
        q_step = quat.vect2quat(q_delta[ii,:])
        q_prev = q_pos[ii,:]
        if CStype == 'sf':
            # space-fixed: the increment is applied from the left
            q_pos[ii+1,:] = quat.quatmult(q_step, q_prev)
        else:
            # body-fixed: the increment is applied from the right
            q_pos[ii+1,:] = quat.quatmult(q_prev, q_step)

    return q_pos

def reconstruct_movement(omega, accMeasured, initialPosition, R_initialOrientation, rate):
    ''' From the measured data, reconstruct the movement.
    Assume a start in a stationary position. No compensation for drift.

    Parameters
    ----------
    omega : ndarray(N,3)
        Angular velocity, in [rad/s]
    accMeasured : ndarray(N,3)
        Linear acceleration, in [m/s^2]
    initialPosition : ndarray(3,)
        initial Position, in [m]
    R_initialOrientation: ndarray(3,3)
        Rotation matrix describing the initial orientation of the sensor,
        except a mis-orientation with respect to gravity
    rate : float
        sampling rate, in [Hz]

    Returns
    -------
    q : ndarray(N,4)
        Orientation quaternions (from vel2quat), expressed relative to the
        orientation at the first sample
    pos : ndarray(N,3)
        Position in space [m]; pos[0] equals initialPosition

    Example
    -------
    >>> q1, pos1 = reconstruct_movement(omega, acc, initialPosition, R_initialOrientation, rate)

    '''
    # Standard acceleration of gravity [m/s^2]
    from scipy.constants import g

    accMeasured = np.asarray(accMeasured)
    R_initialOrientation = np.asarray(R_initialOrientation)
    g_v = r_[0, 0, g]

    # Transform recordings to angVel/acceleration in space --------------

    # Orientation of \vec{g} with the sensor in the "R_initialOrientation"
    g0 = inv(R_initialOrientation).dot(g_v)

    # for the remaining deviation, assume the shortest rotation to there
    q0 = vector.qrotate(accMeasured[0], g0)
    R0 = quat.quat2rotmat(q0)

    # combine the two, to form a reference orientation. Note that the sequence
    # is very important!
    R_ref = R_initialOrientation.dot(R0)
    q_ref = quat.rotmat2quat(R_ref)

    # Calculate orientation q by "integrating" omega -----------------
    q = vel2quat(rad2deg(omega), q_ref, rate, 'bf')

    # Acceleration, velocity, and position ----------------------------
    # From q and the measured acceleration, get the \frac{d^2x}{dt^2}
    accReSensor = accMeasured - vector.rotate_vector(g_v, quat.quatinv(q))
    accReSpace = vector.rotate_vector(accReSensor, q)

    # Make the first position the reference position
    q = quat.quatmult(q, quat.quatinv(q[0]))

    # NOTE: no drift compensation is applied to accReSpace.

    # Position and Velocity through integration, assuming 0-velocity at t=0
    dt = 1./rate
    vel = nan*ones_like(accReSpace)
    pos = nan*ones_like(accReSpace)

    for ii in range(accReSpace.shape[1]):
        vel[:,ii] = cumtrapz(accReSpace[:,ii], dx=dt, initial=0)
        # cumtrapz only *prepends* `initial` to the result; it does not offset
        # the integral. Add the starting position to every sample explicitly.
        pos[:,ii] = initialPosition[ii] + cumtrapz(vel[:,ii], dx=dt, initial=0)

    return (q, pos)

if __name__ == '__main__':
    # Running the module directly only confirms that it imports cleanly.
    status = 'Done'
    print(status)
