###############################################################
# Tools used to compute the reflectance spectra for
# the data (tiles or pixels)
#
# Authors:
#  NOVELTIS: Cedric Bacour / Ivan Price
###############################################################

#from __init__ import *

import numpy as np


import process_brdf







# ===================================================================
# Calculate the reflectance OVER THE WHOLE SPECTRAL DOMAIN
#
# COMMENTS
#  done for each month
# => ref : nlon, nlat, nbands
# => ws  : nlon, nlat
# => chl  : nlon, nlat
#
def main(job, do_compute_error = False):
    ''' Calculate the reflectance spectrum of every pixel of the job's grid.

    NOTE: relies on data['idx_land'], data['idx_snow'] and data['idx_ocean']
    having been computed beforehand (the original note points to
    job.calculate_surface_masks()). Each entry of these index sets addresses
    the flattened (nlon*nlat) pixel axis.

     Inputs
      - job : object exposing
          job.data : dictionary containing
              'ref_land'        array [nlon, nlat, nbands]
              'ref_land_covar'  array [nlon, nlat, ncovar]; the first 28
                                entries are the row-major upper triangle of
                                a 7x7 covariance matrix
              'chloro_conc'     array [nlon, nlat]
              'idx_land', 'idx_snow', 'idx_ocean'  pixel index sets
          job.cfg  : configuration (lmbd, missval_2b, sza_std, vza_std,
                     phi_std, land/ocean parameters)
      - do_compute_error : if True, also propagate the land reflectance
          covariance to a per-wavelength standard deviation

     Output
      - [reflectance, err_reflectance_land], both arrays
        [nlon, nlat, nlmbd] (float32). Pixels in no index set keep
        cfg.missval_2b in reflectance; err_reflectance_land is NaN
        everywhere except land pixels when do_compute_error is True.

     TODO: need improvement for the processing of pixels with water fraction.
    '''
    cfg = job.cfg
    data = job.data
    nlmbd = len(cfg.lmbd)

    ref_land = data['ref_land']
    ref_land_covar = data['ref_land_covar']
    original_shape = ref_land.shape

    # flatten the (lon, lat) grid into one pixel axis: the idx_* sets index it
    ref_land = ref_land.reshape(original_shape[0] * original_shape[1], original_shape[2])
    ref_land_covar = ref_land_covar.reshape(ref_land_covar.shape[0] * ref_land_covar.shape[1],
                                            ref_land_covar.shape[2])
    chloro_conc = data['chloro_conc'].flatten()
    npix = ref_land.shape[0]

    # - initialise the variables to be returned
    reflectance = np.full((npix, nlmbd), cfg.missval_2b, dtype=np.float32)
    err_reflectance_land = np.full((npix, nlmbd), np.nan, dtype=np.float32)

    # - land pixels (this loop dominates the computation time)
    for idx in data['idx_land']:
        reflectance[idx, :] = reflectance_spectrum_land(ref_land[idx, :], cfg.lmbd, cfg)

        # error on reflectance from the packed covariance matrix
        if do_compute_error:
            covar = _unpack_covariance(ref_land_covar[idx, :], 7)
            err_reflectance_land[idx, :] = reflectance_spectrum_land(covar, cfg.lmbd, cfg, error = 1)

    # - snow (PRELIMINARY VERSION / TO BE UPDATED)
    idx_snow = data['idx_snow']
    if len(idx_snow) > 0:
        # the BRDF arguments do not depend on the pixel: compute it once
        ref_snow = process_brdf.function_brdf_snow(cfg,
                                                   sza  = cfg.sza_std,
                                                   vza  = cfg.vza_std,
                                                   phi  = cfg.phi_std,
                                                   lmbd = cfg.lmbd).ravel()
        for idx in idx_snow:
            reflectance[idx, :] = ref_snow

    # - ocean pixels
    idx_ocean = data['idx_ocean']
    if len(idx_ocean) > 0:
        ref_ocean = reflectance_spectrum_ocean(chloro_conc, cfg.lmbd, cfg)
        for idx in idx_ocean:
            reflectance[idx, :] = ref_ocean[idx, :]

    # - reshape back to the original (grid selection) dimensions
    grid_shape = (original_shape[0], original_shape[1], nlmbd)
    reflectance = np.reshape(reflectance, grid_shape)
    err_reflectance_land = np.reshape(err_reflectance_land, grid_shape)

    return [reflectance, err_reflectance_land]


def _unpack_covariance(packed, n = 7):
    '''Rebuild a symmetric n x n float64 matrix from its row-major packed
    upper triangle (entries (0,0), (0,1), ..., (0,n-1), (1,1), ...).'''
    rows, cols = np.triu_indices(n)
    values = np.asarray(packed)[:len(rows)]
    covar = np.zeros((n, n), np.float64)
    covar[rows, cols] = values
    covar[cols, rows] = values
    return covar

# END main (calculate_reflectance)
# ===================================================================


def calculate_ndvi(ref, cfg, lmbd=None, domains=None):
    ''' Compute the NDVI, (PIR - R) / (PIR + R), from a reflectance cube.

     Inputs
      - ref     : reflectance array [nlon, nlat, nlmbd]
      - cfg     : configuration; cfg.land.chanPIR and cfg.land.chanR give
                  the positions of the two channels in the selected bands
                  (PIR presumably = near infrared, "proche infrarouge")
      - lmbd    : wavelengths of ref (defaults to cfg.lmbd)
      - domains : spectral domains / wavebands passed to spectral_selection

     Output
      - NDVI array [nlon, nlat]
    '''
    wavelengths = cfg.lmbd if lmbd is None else lmbd

    bands = spectral_selection(ref, wavelengths, domains=domains)
    nir = bands[:, :, cfg.land.chanPIR]
    red = bands[:, :, cfg.land.chanR]

    return (nir - red) / (nir + red)



# ===================================================================
def spectral_selection(ref, lmbd, domains=None):
    ''' Reflectance averaging by spectral domains, or waveband extraction

     Inputs:
      - ref     : reflectance, vector [nlmbd] or array [nlon, nlat, nlmbd]
      - lmbd    : wavelength vector [nlmbd] (numpy array or list)
      - domains : either
          * a list of spectral domains [lmbd_min, lmbd_max] => average of
            ref over the wavelengths inside each domain (bounds inclusive,
            order of the two bounds irrelevant), or
          * a list of wavebands => ref at every wavelength of lmbd exactly
            equal to one of them, in the order of domains

     Outputs:
      - averaging : float32 vector [ndomains] or array [nlon, nlat, ndomains]
      - wavebands : ref restricted to the selected wavebands
     Raises:
      - TypeError if domains is None
    '''
    if domains is None:
        raise TypeError('spectral_selection: domains must be provided')

    lmbd = np.asarray(lmbd)
    single_pixel = len(ref.shape) == 1

    # a list of [min, max] pairs has sized elements; a list of wavebands does not
    try:
        len(domains[0])
        averaging = True
    except (TypeError, IndexError):
        averaging = False

    # -- Averaging over spectral domains
    if averaging:
        ndomains = len(domains)
        if single_pixel:
            ans = np.zeros(ndomains, np.float32)
        else:
            ans = np.zeros((ref.shape[0], ref.shape[1], ndomains), np.float32)

        for i, domain in enumerate(domains):
            lo, hi = min(domain[0], domain[1]), max(domain[0], domain[1])
            indices = np.flatnonzero(~((lmbd < lo) | (lmbd > hi))).tolist()
            if single_pixel:
                ans[i] = np.average(np.take(ref, indices))
            else:
                ans[:, :, i] = np.average(np.take(ref, indices, axis = 2), axis = 2)
        return ans

    # -- Extract the required wavebands
    indices = []
    for wl in domains:
        indices.extend(np.flatnonzero(lmbd == wl).tolist())
    if single_pixel:
        return np.take(ref, indices)
    return np.take(ref, indices, axis = 2)
#END spectral selection
# ===================================================================


# ===================================================================
# Reflectance averaging by spectral domains
#
# def reflectance_averaging(lmbd, ref, domains=None):
#     ''' Reflectance averaging by spectral domains

#      Inputs:
#       - wavelength         vector [nlmbd]
#       - reflectance        vector [nlmbd] or array  [nlon, nlat, nlmbd]
#       - wavelength domains (optional) defined in the config.py file
#       if no domain is provided, the averaging is performed over the
#       300-4000 nm domain

#      Outputs:
#       - reflectance      vector[ndomains] or array [nlon, nlat, ndomains]
#     '''

#     ndomains = 1
#     if domains != None:
#         ndomains = len(domains)
#     print 'len %s' % ndomains
#     # integration over the whole spectral domain
#     if ndomains == 1:
#         # One pixel
#         if len(ref.shape) == 1:
#             ans = np.average( ref )
#         else:
#             ans = np.average( ref, axis = 2)

#     # integration by spectral domain
#     else:
#         # allocation
#         if len(ref.shape) == 1:
#             ans = np.zeros(ndomains, np.float32)
#         else:
#             ans = np.zeros((ref.shape[0],ref.shape[1],ndomains), np.float32)

#         # computation per domain
#         for i in range(ndomains):
#             mask = np.ma.masked_outside(lmbd, domains[i][0], domains[i][1])
#             indices = np.ma.nonzero(mask.mask == False)[0]
#             indices = indices.tolist()

#             # One pixel
#             if len(ref.shape) == 1:
#                 ans[i] = np.average( np.take(ref, indices) )
#             else:
#                 ans[:,:,i] = np.average( np.take(ref, indices, axis = 2), axis = 2)
#     return ans
# END reflectance_averaging
# ===================================================================



# ===================================================================
# Compute reflectance spectrum over Land on a pixel basis from normalized land surface
# reflectance in N wavebands
#
def reflectance_spectrum_land(data_in, lmbd, cfg, error = None, min_reflectance = 0.005):
    '''Compute reflectance (or its error) over Land on a pixel basis

     Inputs
      - data_in : normalized land surface reflectances in N bands  (vector)
                  OR their covariance matrix [N, N]                (array)
      - lmbd    : requested wavelengths (array or list)
      - cfg     : configuration; cfg.land provides ACP_ObsMean, ACP_mat
                  and ACP_SpecMean
      - error   : None => reconstruct the reflectance spectrum;
                  anything else => data_in is a covariance matrix and the
                  standard deviation of the spectrum is returned
      - min_reflectance : floor applied to the reconstructed spectrum

     Outputs
      - error is None : ACP_mat . (data_in - ACP_ObsMean) + ACP_SpecMean,
                        floored at min_reflectance (NaN left unchanged)
      - otherwise     : float64 vector [len(lmbd)] with
                        sqrt(diag(M C M^T)), M = first len(lmbd) rows of
                        ACP_mat, C = data_in
    '''
    acp_mat = cfg.land.ACP_mat

    # Compute the reflectance spectrum
    if error is None:
        centred = data_in - cfg.land.ACP_ObsMean
        # HERE IS ALL THE TIME FOR CALCULATION
        ans = np.dot(acp_mat, centred) + cfg.land.ACP_SpecMean
        # reflectance values must be >= min_reflectance
        return np.maximum(ans, min_reflectance)

    # Compute the error on the reflectance spectrum (standard deviation):
    # row-wise quadratic form m_i C m_i^T, vectorized over the wavelengths
    nlmbd = len(lmbd)
    rows = np.asarray(acp_mat)[:nlmbd, :]
    covar = np.array(data_in)
    ans = np.zeros(nlmbd, np.float64)
    ans[:] = np.sqrt(np.sum(np.dot(rows, covar) * rows, axis = 1))
    return ans
# END reflectance_spectrum_land
# ===================================================================



# ===================================================================
# Compute reflectance over all ocean pixels between lmin and lmax
#
def reflectance_spectrum_ocean(chl, lmbd, cfg):
    ''' Compute reflectance over Ocean on a pixel basis

     The tabulated spectra cfg.ocean.ref_chl_std (one row per tabulated
     chlorophyll content cfg.ocean.chl, assumed sorted in increasing order)
     are interpolated linearly in log(chl). Pixels below / above the
     tabulated range get the first / last tabulated spectrum; pixels whose
     chl is NaN are left at 0.

     Inputs
      - chl  : chlorophyll content, vector [npix]
      - lmbd : requested wavelengths, vector [nlmbd] (only its length is used)
      - cfg  : configuration exposing cfg.ocean.chl and cfg.ocean.ref_chl_std

     Outputs
      - array [npix, nlmbd]; only columns 0..idx800 are filled, the rest
        stay 0.
      - #TODO: this function is a stub of the real thing
    '''
    # Index of 800 nm on a 300-4000 nm grid with a 1 nm step. Writing the
    # tabulated spectra to columns 0..idx800 assumes lmbd is that grid and
    # ref_chl_std has idx800+1 columns -- TODO confirm.
    idx800 = 500
    ncol = idx800 + 1

    chl = np.asarray(chl)
    chl_tab = np.asarray(cfg.ocean.chl)
    ref_tab = np.asarray(cfg.ocean.ref_chl_std)

    ref_ocean = np.zeros([len(chl), len(lmbd)])

    # NaN chl satisfies none of these comparisons and stays at 0
    inside = (chl >= chl_tab[0]) & (chl <= chl_tab[-1])
    below = chl < chl_tab[0]
    above = chl > chl_tab[-1]

    # Pixels whose chl is within the tabulated chlorophyll content min and max
    if inside.any():
        if chl_tab.size < 2:
            ref_ocean[inside, :ncol] = ref_tab[0, :]
        else:
            c = chl[inside]
            # first tabulated value >= c is the upper neighbour; clipping to 1
            # makes c == chl_tab[0] interpolate onto row 0 exactly (x == 1)
            upper = np.clip(np.searchsorted(chl_tab, c, side='left'), 1, chl_tab.size - 1)
            lower = upper - 1
            x = np.log(chl_tab[upper] / c) / np.log(chl_tab[upper] / chl_tab[lower])
            x = x[:, np.newaxis]
            ref_ocean[inside, :ncol] = x * ref_tab[lower, :] + (1 - x) * ref_tab[upper, :]

    # Pixels whose chl is lower / greater than the tabulated min / max
    ref_ocean[below, :ncol] = ref_tab[0, :]
    ref_ocean[above, :ncol] = ref_tab[-1, :]

    return ref_ocean

# END reflectance_spectrum_ocean
# ===================================================================


# ===================================================================
# Compute reflectance over all ocean pixels betwen lmin and lmax
#
# def reflectance_spectrum_ocean(chl, Cocean, lmbd):
#     ''' Compute reflectance over Ocean on a pixel basis

#      Inputs
#       - chlorophyll content

#      Outputs
#       - #TODO: this function is a stub of the real thing

#     '''

#     lmin = lmbd[0]
#     lmax = lmbd[-1]

#     # Mask depending on the values of the chlorophyll content
#     ind_in = np.ma.masked_inside(chl,Cocean.chl[0],Cocean.chl[-1])
#     ind_in = np.ma.where(ind_in.mask == True)

#     ind_inf = np.ma.masked_less(chl,Cocean.chl[0])
#     ind_inf = np.ma.where(ind_inf.mask == True)

#     ind_sup = np.ma.masked_greater(chl,Cocean.chl[-1])
#     ind_sup = np.ma.where(ind_sup.mask == True)

#     # Reflectance in the tabulated wavebands
#     nlmbd = len(Cocean.lmbd_chl_std)
# #    RefCalc = np.zeros([len(chl),nlmbd])

#     # Output reflectance and wavebands
#     ref_ocean = np.zeros([len(chl),lmax-lmin+1])
# #    lmbd_ocean = np.arange(lmax-lmin+1)+lmin


#    # indices in lmbd_std
#     imin = np.ma.masked_less_equal(Cocean.lmbd_chl_std,lmin)
#     imin = list(np.ma.where(imin.mask == True))
#     imax = np.ma.masked_greater_equal(Cocean.lmbd_chl_std,lmax)
#     imax = list(np.ma.where(imax.mask == True))
#     # indices in lmbd requested
#  #   jmin = np.ma.masked_less_equal(lmbd_ocean,300)
#     jmin = np.ma.masked_less_equal(lmbd,300)
#     jmin = np.ma.where(jmin.mask == True)
#  #   jmax = np.ma.masked_greater_equal(lmbd_ocean,800)
#     jmax = np.ma.masked_greater_equal(lmbd,800)
#     jmax = np.ma.where(jmax.mask == True)

#     if len(imin) > 0: imin=[imin[0][-1]]
#     if len(imax) > 0: imax=[imax[0][0]+1]
#     if len(jmin) > 0: jmin=[jmin[0][-1]]
#     if len(jmax) > 0: jmax=[jmax[0][0]+1]

#     # Computation for pixels which chl is within tabulated chlorphyll content min and max
#     if len(ind_in) > 0:
#       ind_in = ind_in[0]
#       for i in range(len(ind_in)):

#           diff = Cocean.chl - chl[i]
#           buf = np.ma.masked_less(diff,0)
#           ind_min = (np.ma.where(buf == np.ma.minimum.reduce(buf)))[0]-1
#           x = np.log(Cocean.chl[ind_min+1]/chl[i])/np.log(Cocean.chl[ind_min+1]/Cocean.chl[ind_min])
#           ref_interp = (x*Cocean.ref_chl_std[ind_min,:] + (1-x)*Cocean.ref_chl_std[ind_min+1,:]).ravel()
# #         RefCalc[ind_in[i],:] = ref_interp

#           if len(imin) > 0 and len(imax) > 0:
#               ref_ocean[ind_in[i],:] = ref_interp[imin[0]:imax[0]]
#           if len(imin) == 0 and len(imax) > 0:
#               ref_ocean[ind_in[i],:jmin[0]] = ref_interp[0]
#               ref_ocean[ind_in[i],jmin[0]:] = ref_interp[:imax[0]]
#           if len(imin) > 0 and len(imax) ==  0:
#               ref_ocean[ind_in[i],:jmax[0]] = ref_interp[imin[0]:]
#               #ref_ocean[ind_in[i],jmax[0]:] = ref_interp[-1]
#           if len(imin) == 0 and len(imax) ==  0:
#               ref_ocean[ind_in[i],jmin[0]:jmax[0]] = ref_interp[:]
#               ref_ocean[ind_in[i],:jmin[0]] = ref_interp[0]

#     # Computation for pixels which chl is lower than tabulated chlorphyll content min
#     if len(ind_inf) > 0:
#       ind_inf = ind_inf[0]
#       ref_interp = Cocean.ref_chl_std[0,:]
# #     RefCalc[ind_inf,:] = Cocean.ref_chl_std[0,:]
#       if len(imin) > 0 and len(imax) > 0:
#           ref_ocean[ind_inf,:] = ref_interp[imin[0]:imax[0]]
#       if len(imin) == 0 and len(imax) > 0:
#           ref_ocean[ind_inf,:jmin[0]] = ref_interp[0]
#           ref_ocean[ind_inf,jmin[0]:] = ref_interp[:imax[0]]
#       if len(imin) > 0 and len(imax) ==  0:
#           ref_ocean[ind_inf,:jmax[0]] = ref_interp[imin[0]:]
#               #ref_ocean[ind_in[i],jmax[0]:] = ref_interp[-1]
#       if len(imin) == 0 and len(imax) ==  0:
#           ref_ocean[ind_inf,jmin[0]:jmax[0]] = ref_interp[:]
#           ref_ocean[ind_inf,:jmin[0]] = ref_interp[0]

#     # Computation for pixels which chl is greater than tabulated chlorphyll content max
#     if len(ind_sup) > 0:
#       ind_sup = ind_sup[0]
# #     RefCalc[ind_sup,:] = Cocean.ref_chl_std[-1,:]
#       ref_interp = Cocean.ref_chl_std[-1,:]
#       if len(imin) > 0 and len(imax) > 0:
#           ref_ocean[ind_sup,:] = ref_interp[imin[0]:imax[0]]
#       if len(imin) == 0 and len(imax) > 0:
#           ref_ocean[ind_sup,:jmin[0]] = ref_interp[0]
#           ref_ocean[ind_sup,jmin[0]:] = ref_interp[:imax[0]]
#       if len(imin) > 0 and len(imax) ==  0:
#           ref_ocean[ind_sup,:jmax[0]] = ref_interp[imin[0]:]
#               #ref_ocean[ind_in[i],jmax[0]:] = ref_interp[-1]
#       if len(imin) == 0 and len(imax) ==  0:
#           ref_ocean[ind_sup,jmin[0]:jmax[0]] = ref_interp[:]
#           ref_ocean[ind_sup,:jmin[0]] = ref_interp[0]


#     return ref_ocean

# END reflectance_spectrum_ocean
# ===================================================================




# THIS FUNCTION IS NEVER CALLED AND IS HENCE COMMENTED OUT
#
# ===================================================================
# Create a RGB composite image
#
# Inputs:
#  - reflectance [nlon, nlat, nlmbd]
#  - indices of the lmbd vector for the R, G, and B, channels
#  - wavelength domains (optional) defined in the config.py file
#
# Outputs:
#  - reflectance      array [nlon, nlat, 3]
# ===================================================================
#def reflectance_rgb(ref, ilmbd_R = None, ilmbd_G = None, ilmbd_B = None,
#                             missval = None, fill_missval = 0):
#
#
#
#    ans = np.zeros((ref.shape[0],ref.shape[1],3), np.float32)
#
#    ref_map = ref[:]
#    ref_map = np.ma.masked_equal(ref_map,missval)
#    ref_map = np.ma.filled(ref_map,fill_missval)
#    print "ref_map : ", np.minimum.reduce(ref_map.ravel()),np.maximum.reduce(ref_map.ravel())
#    print "missval = ", missval
#    print "fill_missval", fill_missval
#
#
#    # linear stretching of the reflectance dynamics
#    ref_R = ref_map[:,:,ilmbd_R]
#    ref_G = ref_map[:,:,ilmbd_G]
#    ref_B = ref_map[:,:,ilmbd_B]
#
#    pc = 10
#
#    mnR = percentile.get_percentile(ref_R.ravel(),pc)
#    mxR = percentile.get_percentile(ref_R.ravel(),100-pc)
#    mnG = percentile.get_percentile(ref_G.ravel(),pc)
#    mxG = percentile.get_percentile(ref_G.ravel(),100-pc)
#    mnB = percentile.get_percentile(ref_B.ravel(),pc)
#    mxB = percentile.get_percentile(ref_B.ravel(),100-pc)
#
#    ref_R = (ref_R-mnR)/(mxR-mnR)
#    ref_G = (ref_G-mnG)/(mxG-mnG)
#    ref_B = (ref_B-mnB)/(mxB-mnB)
#
#
#    print "R : ", np.minimum.reduce(ref_R.ravel()),np.maximum.reduce(ref_R.ravel())
#    print "G : ", np.minimum.reduce(ref_G.ravel()),np.maximum.reduce(ref_G.ravel())
#    print "B : ", np.minimum.reduce(ref_B.ravel()),np.maximum.reduce(ref_B.ravel())
#
#
#    ans[:,:,0]  = ref_G
#    ans[:,:,1]  = ref_B
#    ans[:,:,2]  = ref_R
#
#
#    #fig = plt.figure(1,frameon=False)
#    #plt.clf()
#    #plt.imshow(ans, aspect='equal')
#    #plt.show()
#
#
#    return ans
# END compo_RGB
# ===================================================================
