"""
This module defines the ModelPredictiveObject class, which wraps a GraphLab model
into a predictive object that can be deployed to a Predictive Service.
"""

import graphlab
import logging
from _predictive_object import PredictiveObject

# Module-scoped logger, named after this module.
__logger__ = logging.getLogger(name=__name__)

class ModelPredictiveObject(PredictiveObject):
  '''Predictive Object definition for GraphLab models.

  Each ModelPredictiveObject wraps exactly one GraphLab model, stored as the
  single entry ``self.dependencies['model']``. It provides a 'query' interface
  that dispatches to the methods the model reports through
  ``model._get_queryable_methods()``.
  '''

  # Parameter types a queryable-method description is allowed to use.
  _SUPPORTED_PARAM_TYPES = ('sframe', 'sarray')

  def __init__(self, model, description = ''):
    '''Wrap ``model`` in a predictive object.

    Parameters
    ----------
    model : graphlab.Model
      The model to wrap.
    description : str, optional
      Description forwarded to the PredictiveObject base class.

    Raises
    ------
    TypeError
      If ``model`` is not a GraphLab model.
    RuntimeError
      If the model's ``_get_queryable_methods`` result is malformed.
    '''
    if not isinstance(model, graphlab.Model):
      raise TypeError('model must be a GraphLab model')

    super(ModelPredictiveObject, self).__init__(description)
    self._set_model(model)

  @property
  def model(self):
    '''Get the GraphLab model of the predictive object.

    Raises
    ------
    RuntimeError
      If no dependency is set, if more than one dependency is set, or if the
      'model' dependency is missing or is not a GraphLab model.
    '''
    num_dependencies = len(self.dependencies)
    if num_dependencies == 0:
      raise RuntimeError("Model predictive object does not have any dependency model set")
    if num_dependencies > 1:
      raise RuntimeError("Model predictive object should have only one model dependency object")

    # Explicit check rather than `assert`, which is stripped under `python -O`.
    m = self.dependencies.get('model')
    if not isinstance(m, graphlab.Model):
      raise RuntimeError("Model predictive object dependency 'model' must be a GraphLab model")
    return m

  @model.setter
  def model(self, model):
    '''Set the model this Predictive Object encapsulates.

    Raises TypeError if ``model`` is not a GraphLab model, and RuntimeError if
    its queryable-method description is malformed.
    '''
    self._set_model(model)

  def post_load(self):
    '''Do model-specific processing after the predictive object is loaded.

    Rebuilds and re-validates the cached queryable-method table from the
    loaded model dependency.
    '''
    self._set_model(self.model)

  def query(self, input):
    '''Query the model according to input.

    Query the model according to the method and query data specified in the
    input.

    Parameters
    ----------
    input : dict
      A dictionary that must have the following two keys:
         'method': The method supported by the given model. Refer to the
            individual model for the list of supported methods.
         'data' : The data used to query the model, as a dictionary that
            matches the actual method signature. A parameter declared as
            'sframe' takes a list of dicts (one per row). A parameter declared
            as 'sarray' takes a list of values. A non-list value is treated
            as a one-element list.

    Returns
    -------
    out : dict | list
      Result from the model method. The result could be any type.
      If the model method returns an SFrame, out is converted to a list of
      dictionaries. If it returns an SArray, out is converted to a list of
      values.

    Raises
    ------
    ValueError
      If 'method' is not a queryable method of the model.
    TypeError
      If 'data' is not a dictionary, or if an 'sframe' parameter contains
      values that are not dicts.
    '''
    PredictiveObject._validate_query_input(input)

    method = input['method']
    if method not in self._model_methods:
      raise ValueError("Method '%s' is not supported for current model" % method)

    data = input['data']
    if not isinstance(data, dict):
      raise TypeError('"data" value has to be a dictionary')

    # Convert into a shallow copy so the caller's input dictionary is not
    # mutated by the SFrame/SArray conversion below.
    kwargs = dict(data)

    method_description = self._model_methods[method]
    for param_name, param_type in method_description.items():
      # Parameters the caller did not supply fall back to the method's own
      # defaults.
      if param_name not in kwargs:
        continue
      kwargs[param_name] = self._convert_query_param(
          param_name, param_type, kwargs[param_name])

    # Call the method through the model's class, as the original code does.
    model = self.model
    result = getattr(model.__class__, method)(model, **kwargs)

    # Convert GraphLab objects to plain Python data for ease of serialization.
    return self._make_serializable(result)

  @staticmethod
  def _convert_query_param(param_name, param_type, value):
    '''Convert one query parameter value to the GraphLab type it must have.

    A non-list value is wrapped in a one-element list first. An 'sframe'
    parameter must be a list of dicts, which are unpacked into SFrame
    columns. An 'sarray' parameter becomes an SArray of the list values.
    '''
    values = value if isinstance(value, list) else [value]

    if param_type == 'sframe':
      # Validate before building any GraphLab object so that malformed rows
      # raise a clear TypeError. An empty list is rejected as well.
      if set(type(v) for v in values) != set([dict]):
        raise TypeError('Expect all values of %s to be of type dict' % param_name)
      return graphlab.SArray(values).unpack(column_name_prefix=None)

    if param_type == 'sarray':
      return graphlab.SArray(values)

    raise RuntimeError('Unexpected parameter type %s for parameter %s' % (param_type, param_name))

  def _set_model(self, model):
    '''Validate ``model`` and make it the sole dependency of this object.

    The model's queryable-method description is fully validated before any
    attribute is assigned, so a rejected model leaves the object unchanged.

    Raises
    ------
    TypeError
      If ``model`` is not a GraphLab model.
    RuntimeError
      If ``model._get_queryable_methods()`` does not return a dict that maps
      method names to dicts of {parameter name: 'sframe' | 'sarray'}.
    '''
    if not isinstance(model, graphlab.Model):
      raise TypeError('model must be a GraphLab model')

    model_methods = model._get_queryable_methods()
    if not isinstance(model_methods, dict):
      raise RuntimeError("_get_queryable_methods for model %s should return a "
                         "dictionary" % model.__class__)

    for description in model_methods.values():
      if not isinstance(description, dict):
        raise RuntimeError("model %s _get_queryable_methods should use dict as method "
                           "description." % model.__class__)

      for param_type in description.values():
        if param_type not in self._SUPPORTED_PARAM_TYPES:
          raise RuntimeError("model %s _get_queryable_methods should only use "
                             "'sframe' or 'sarray' type. %s is not supported"
                             % (model.__class__, param_type))

    self.dependencies = {'model': model}
    self._model_methods = model_methods

  def get_doc_string(self):
    '''Returns documentation for the predictive object query.

    The result is a fixed note on SFrame/SArray input/output conversion,
    followed by the list of queryable method names. After that, each
    method's name and docstring appear. A method without a docstring
    contributes an empty body.
    '''
    parts = [
      'Note:\n',
      '    For input that expects "SFrame" type, you need to pass in a list of dictionaries,\n',
      '    for input that expects "SArray" type, you need to pass in a list of values.\n',
      '    Similarly, output of type SFrame will be converted to a list of dictionaries,\n',
      '    output of type SArray will be converted to a list of values.\n',
      '\n',
      'The following methods are available for query for this predictive object:\n',
      '    ' + ';'.join(self._model_methods) + '\n',
    ]

    model = self.model
    for method in self._model_methods:
      parts.append('\n' + method + '\n')
      # __doc__ is None for undocumented methods; avoid a TypeError on join.
      parts.append(getattr(model, method).__doc__ or '')

    return ''.join(parts)

