"""
Methods for creating and using a Vowpal Wabbit model. This package provides the
ability to fit models using `Vowpal Wabbit
<https://github.com/JohnLangford/vowpal_wabbit>`_, an open source project meant
for large-scale online learning. This includes implementations of a variety of
models -- linear and logistic regression and others. Importantly it is
straightforward to flexibly use a large number of unique features, e.g. words in
a document, through `hashing <http://github.com/JohnLangford/vowpal_wabbit/wiki/
Feature-Hashing-and-Extraction>`_.

While there are a variety of Python wrappers for Vowpal Wabbit, this one is
directly integrated with our disk-backed :py:class:`~graphlab.SFrame`; this can
make it easier to interactively create new features for Vowpal Wabbit models.
"""

import graphlab as _graphlab
import graphlab.connect as _mt
from graphlab.toolkits._model import Model as _Model
from graphlab.data_structures.sframe import SFrame as _SFrame
from graphlab.data_structures.sarray import SArray as _SArray
from graphlab.deps import pandas as _pandas, HAS_PANDAS as _HAS_PANDAS
import time as _time
from graphlab.toolkits._main import ToolkitError as _ToolkitError
from graphlab.toolkits._model import _get_default_options_wrapper as \
                                                  __get_default_options_wrapper


# Public helper built by the shared toolkit factory
# ``_get_default_options_wrapper`` for the 'vowpal_wabbit' toolkit and the
# 'VowpalWabbitModel' class; presumably it reports the default options used
# by ``create`` (behavior lives in the factory, not in this module).
get_default_options = __get_default_options_wrapper('vowpal_wabbit',
                                                    'vowpal_wabbit',
                                                    'VowpalWabbitModel')

def create(dataset, target,
           loss_function='squared',
           quadratic=None,
           l1_penalty=0.0, l2_penalty=0.0,
           bigram=False,
           step_size=0.5, num_bits=18, verbose=False,
           max_iterations=1,
           command_line_args=''):
    r"""
    create(dataset, target, loss_function='squared', quadratic=None,
    l1_penalty=0.0, l2_penalty=0.0, bigram=False, step_size=0.5, num_bits=18,
    verbose=False, max_iterations=1, command_line_args='')

    Learn a large linear model using Vowpal Wabbit.

    Parameters
    ----------
    dataset : SFrame
        A data set. Due to the way Vowpal Wabbit creates features from each
        entry, ':' and '|' characters are not allowed in any columns containing
        strings. Each row of the dataset is translated into a string and passed
        to Vowpal Wabbit. Currently, the upper bound on the size of the string
        is 1MB. Based on the type of the SArray column, the values are passed in
        the following ways.

        - *integer* or *float*: the value is passed directly to VW.

        - *str*: the name of the column is used as the namespace, followed by
          the entire string.

        - *dict*: the name of the column is used as the namespace, and each
          key-value pair is a feature. The keys of the dictionary must be string
          or numeric and the values must be numeric (integer or float).

        - *array*: the name of the column is used as the namespace, the index of
          the array element is used as the name of the feature, and only numeric
          elements in the array are passed to VW.

        - *list (recursive type)*: the name of the column is used as the
          namespace, the index of the list element is used as the name of the
          feature, and currently only numeric elements (integer or float) are
          passed to VW.

        See the `VW input format guidelines
        <https://github.com/JohnLangford/vowpal_wabbit/wiki/Input-format>`_ for
        more details.

    target : string
        The name of the column in ``dataset`` that is the prediction target.
        This column must have a numeric type.

    loss_function : {'squared', 'hinge', 'logistic', 'quantile'}, optional
        This defines the `loss function
        <http://en.wikipedia.org/wiki/Loss_function>`_ used during optimization.
        Typical choices:

        - *real-valued target*: `squared error loss
          <http://en.wikipedia.org/wiki/Mean_squared_error>`_.

        - *binary target*: `logistic
          <http://en.wikipedia.org/wiki/Logistic_regression>`_. The target
          column must only contain -1 or 1.

        The `hinge loss <http://en.wikipedia.org/wiki/Hinge_loss>`_ is also used
        for classification, while `quantile loss
        <http://en.wikipedia.org/wiki/Quantile_regression>`_ can be good when
        one aims to predict quantities other than the mean.

    quadratic : list of pairs, optional
        This will add `interaction terms
        <http://en.wikipedia.org/wiki/Interaction_(statistics)>`_ to a linear
        model between a pair of columns. Quadratic terms add a parameter in the
        model for the product of two features, i.e. if we include an interaction
        between :math:`x_1` and :math:`x_2`, we can add a parameter :math:`b_3`.

            .. math:: y_i =  a + b_1 * x_{i1} + b_2 * x_{i2} + b_3 * x_{i1} * x_{i2}

        Multiple quadratic terms can be added by including multiple pairs, e.g.
        ``quadratic = [('a', 'b'), ('b', 'c')]`` would add interaction terms
        between columns named 'a' and 'b' as well as terms for interactions
        between 'b' and 'c'. Including ':' as one of the items in the pairs is a
        shortcut for adding quadratic terms for all pairs of features. Due to
        Vowpal Wabbit's implementation, quadratic terms are determined by the
        first letter of the column name. By default (``None`` or an empty
        list) no interaction terms are added.

    l1_penalty : float, optional
        Strength of the L1 penalty, which encourages parameters to be exactly
        zero.

    l2_penalty : float, optional
        This defines how strongly you want to keep parameters near zero.
        Specifically it adds a penalty of :math:`.5 * \lambda * |w|_2^2` to the
        weight vector w, where lambda is the provided regularization value.

    bigram : bool, optional
        Add bigram features. For columns containing the text "my name is bob"
        this will add bigram features for "my name", "name is", "is bob".

    step_size : float, optional
        Set the learning rate for online learning.

    num_bits : int, optional
        Number of bits used for the hashed feature table (see the feature
        hashing notes in the module documentation).

    verbose : bool, optional
        Print the first 10 rows as they are seen by Vowpal Wabbit.
        This is useful for debugging.

    max_iterations : int, optional
        Number of passes to take over the data set.

    command_line_args : string, optional
        Additional arguments to pass to Vowpal Wabbit, just as one would use
        when using VW via the command line.

    Returns
    -------
    out : VowpalWabbitModel
        A model that can be used for predicting new cases.

    Raises
    ------
    TypeError
        If ``dataset`` is not an SFrame, or if ``loss_function`` is
        'logistic' and the target column contains values other than 1 and -1.
    ToolkitError
        If ``target`` is not a column of ``dataset``.

    See Also
    --------
    VowpalWabbitModel.predict, VowpalWabbitModel.evaluate

    Notes
    -----
    - Other desired command line arguments can be provided manually through the
      command_line_args keyword argument. See the `VW documentation <http://gith
      ub.com/JohnLangford/vowpal_wabbit/wiki/Command-line-arguments>`_ for more
      details.

    - Several Vowpal Wabbit features are not yet supported, including importance
      weighted learning.

    Examples
    --------
    >>> data =  graphlab.SFrame('http://s3.amazonaws.com/GraphLab-Datasets/regression/houses.csv')
    >>> data['price'] = data['price'].apply(lambda x: 1 if x > 30000 else -1)
    >>> m = graphlab.vowpal_wabbit.create(data, 'price')

    Given an SFrame ``sf`` with 'user', 'movie' and 'rating' columns, add
    quadratic terms between the 'user' and 'movie' columns:

    >>> m = graphlab.vowpal_wabbit.create(sf, 'rating', quadratic=[('user', 'movie')])

    If a column contains text, each space-separated word is used as a
    unique feature. It is often useful to also include bigrams as
    features. This can be done easily with the ``bigram`` argument:

    >>> m = graphlab.vowpal_wabbit.create(sf, 'rating', bigram=True)
    """

    _mt._get_metric_tracker().track('toolkit.vowpal_wabbit.create')

    if not isinstance(dataset, _SFrame):
        raise TypeError("Input 'dataset' must be an SFrame")

    # Subclasses pass the isinstance check above; normalize them to a plain
    # SFrame before handing the data to the native toolkit.
    if type(dataset) is not _SFrame:
        dataset = _SFrame(dataset)

    # Explicit check (not ``assert``) so validation still happens under -O.
    if target not in dataset.column_names():
        raise _ToolkitError(
            "Target column '%s' not found in the input dataset." % (target,))

    y = dataset[target]

    # Validate the logistic target *before* training so a bad target fails
    # fast instead of after a full training pass over the data.
    if loss_function == 'logistic':
        is_one_or_neg_one = y.apply(lambda x: x == 1 or x == -1)
        if not is_one_or_neg_one.all():
            raise TypeError('When using `logistic` as a loss function, the target column must contain only 1\'s and -1\'s.')

    # VW uses the first letter of a column name to identify its namespace, so
    # each pair of columns becomes ' -q <a[0]><b[0]>'.
    quadratic_command = ''.join(' -q ' + feature_a[0] + feature_b[0]
                                for (feature_a, feature_b) in (quadratic or []))

    opts = {'verbose': verbose,
            'target': target,
            'loss_function': loss_function,
            'quadratic': quadratic_command,
            'step_size': step_size,
            'l1_penalty': l1_penalty,
            'l2_penalty': l2_penalty,
            'num_bits': num_bits,
            'max_iterations': max_iterations,
            'bigram': bigram,
            'extra_command_line_args': command_line_args}

    # Initialize the model with basic parameters
    response = _graphlab.toolkits._main.run("vw_init", opts)
    m = VowpalWabbitModel(response['model'])

    # Train the model on the given data set and retrieve predictions
    opts = {'model': m.__proxy__,
            'data': dataset}
    response = _graphlab.toolkits._main.run("vw_train", opts)
    m = VowpalWabbitModel(response['model'])

    yhat = _SArray(None, _proxy=response['predictions'])

    # Record a training metric on the model.
    if loss_function == 'logistic':
        # Map the {-1, 1} target onto {0, 1} before computing accuracy.
        # NOTE(review): ``VowpalWabbitModel.evaluate`` compares the raw {-1, 1}
        # targets against ``predict`` output instead; confirm which label
        # encoding the "vw_train" predictions actually use.
        y01 = y.apply(lambda x: int(x * .5 + .5))
        m = m._set('training_accuracy', _graphlab.evaluation.accuracy(y01, yhat))
    else:
        m = m._set('training_rmse', _graphlab.evaluation.rmse(y, yhat))
    return m


class VowpalWabbitModel(_Model):
    """
    Wrapper around a Vowpal Wabbit model held by the native toolkit.

    Instances are produced by :func:`~graphlab.vowpal_wabbit.create`. The
    model state lives behind ``__proxy__`` and is accessed through toolkit
    calls (e.g. ``vw_get_value``, ``vw_predict``).
    """

    def __init__(self, model_proxy):
        """
        Wrap an existing toolkit model proxy.

        Parameters
        ----------
        model_proxy
            Proxy object returned by a ``vw_*`` toolkit call.
        """
        self.__proxy__ = model_proxy
        self.__name__ = 'vowpal_wabbit'

    def _get_wrapper(self):
        """
        Return a callable that wraps a raw model proxy in a
        :class:`VowpalWabbitModel`.
        """
        def model_wrapper(model_proxy):
            return VowpalWabbitModel(model_proxy)
        return model_wrapper

    def list_fields(self):
        """
        List the fields stored in the model, including data, model, and training
        options. Each field can be queried with the ``get`` method.

        Returns
        -------
        out : list
            Sorted list of fields queryable with the ``get`` method.

        See Also
        --------
        get

        Examples
        --------
        >>> data =  graphlab.SFrame('http://s3.amazonaws.com/GraphLab-Datasets/regression/houses.csv')
        >>> data['price'] = data['price'].apply(lambda x: 1 if x > 30000 else -1)
        >>> m = graphlab.vowpal_wabbit.create(data, 'price')
        >>> m.list_fields()
        ['bigram',
        'cache_file',
        'command_line_args',
        'elapsed_time',
        'extra_command_line_args',
        'final_regressor_file',
        'initial_regressor_file',
        'l1_penalty',
        'l2_penalty',
        'loss_function',
        'max_iterations',
        'num_bits',
        'quadratic',
        'step_size',
        'target',
        'training_rmse',
        'verbose']
        """

        _mt._get_metric_tracker().track('toolkit.vowpal_wabbit.list_fields')
        opts = {'model': self.__proxy__,
                'model_name': self.__name__}
        response = _graphlab.toolkits._main.run('vw_list_keys', opts)
        return sorted(response.keys())

    def get(self, field):
        """
        Return the value of a given field. The list of all queryable fields can
        be obtained with the
        :func:`~graphlab.vowpal_wabbit.VowpalWabbitModel.list_fields` method.

        Parameters
        ----------
        field : string
            Name of the field to be retrieved.

        Returns
        -------
        out
            Value of the requested field.

        See Also
        --------
        list_fields

        Examples
        --------
        >>> data =  graphlab.SFrame('http://s3.amazonaws.com/GraphLab-Datasets/regression/houses.csv')
        >>> data['price'] = data['price'].apply(lambda x: 1 if x > 30000 else -1)
        >>> m = graphlab.vowpal_wabbit.create(data, 'price')
        >>> m.get('step_size')
        0.5
        """
        _mt._get_metric_tracker().track('toolkit.vowpal_wabbit.get')

        opts = {'model': self.__proxy__,
                'model_name': self.__name__,
                'field': field}
        response = _graphlab.toolkits._main.run('vw_get_value', opts)
        return response['value']

    def _set(self, field, value):
        """
        Store ``value`` under ``field`` in the model description.

        Returns
        -------
        out : VowpalWabbitModel
            A new wrapper around the model proxy returned by the toolkit; the
            caller should use the returned object.
        """
        opts = {'model': self.__proxy__,
                'key': field,
                'value': value}
        response = _graphlab.toolkits._main.run("vw_set_model_description", opts)
        return VowpalWabbitModel(response['model'])

    def __str__(self):
        """
        Return a string description of the model to the ``print`` method.

        Returns
        -------
        out : string
            A description of the VowpalWabbitModel.
        """

        opts = {'model': self.__proxy__,
                'model_name': self.__name__}
        fields = _graphlab.toolkits._main.run("vw_get_current_options", opts)

        model_fields = [
            ("Target column", 'target'),
            ("Loss function", 'loss_function'),
            ("Step size",     'step_size'),
            ("L1 penalty",    'l1_penalty'),
            ("L2 penalty",    'l2_penalty'),
            ("Verbose",       'verbose'),
            ("Number of bits", 'num_bits'),
            ("Max iterations", 'max_iterations')]

        # Only one training metric is recorded, depending on the loss used.
        if 'training_rmse' in fields:
            model_fields.append(("Training RMSE", 'training_rmse'))
        if 'training_accuracy' in fields:
            model_fields.append(("Training accuracy", 'training_accuracy'))
        return _graphlab.toolkits._internal_utils._toolkit_repr_print(
            self, [model_fields], width=24)

    def __repr__(self):
        """
        Print a string description of the model when the model name is entered
        in the terminal.
        """

        return self.__str__()

    def summary(self):
        """
        Display a summary of the VowpalWabbitModel.

        Prints a "Model summary" header followed by the same description
        returned by ``repr(model)``: target column, loss function, step size,
        penalties, verbosity, number of bits, maximum iterations and the
        training metric when one was recorded.

        Examples
        --------
        >>> data =  graphlab.SFrame('http://s3.amazonaws.com/GraphLab-Datasets/regression/houses.csv')
        >>> data['price'] = data['price'].apply(lambda x: 1 if x > 30000 else -1)
        >>> m = graphlab.vowpal_wabbit.create(data, 'price')
        >>> m.summary()
        """
        _mt._get_metric_tracker().track('toolkit.vowpal_wabbit.summary')

        # Single-argument print(...) behaves identically on Python 2 and 3.
        print("")
        print("                    Model summary                       ")
        print("--------------------------------------------------------")
        print(self.__repr__())

    def _training_stats(self):
        """
        Get information about model creation, e.g. time elapsed during model
        fitting, data loading, and more.

        Returns
        -------
        out : dict
            Statistics about model training, e.g. runtime.

        See Also
        --------
        summary

        Examples
        --------
        >>> data =  graphlab.SFrame('http://s3.amazonaws.com/GraphLab-Datasets/regression/houses.csv')
        >>> data['price'] = data['price'].apply(lambda x: 1 if x > 30000 else -1)
        >>> m = graphlab.vowpal_wabbit.create(data, 'price')
        >>> m._training_stats()
        {'elapsed_time': 0.006756,
        'training_rmse': 0.4149101896680863}
        """
        _mt._get_metric_tracker().track('toolkit.vowpal_wabbit._training_stats')
        opts = {'model': self.__proxy__, 'model_name': self.__name__}
        response = _graphlab.toolkits._main.run("vw_training_stats", opts)
        return response

    def predict(self, dataset):
        """
        Use the trained :class:`~graphlab.vowpal_wabbit.VowpalWabbitModel` to make
        predictions about the target column that was provided during
        :func:`~graphlab.vowpal_wabbit.create`.

        Parameters
        ----------
        dataset : SFrame
            A dataset that has the same columns that were used during training.
            If the target column exists in ``dataset`` it will be ignored while
            making predictions.

        Returns
        -------
        out : SArray
            Predicted target value for each example (i.e. row) in the dataset.

        See Also
        --------
        evaluate
        """
        _mt._get_metric_tracker().track('toolkit.vowpal_wabbit.predict')

        opts = {'model': self.__proxy__,
                'data': dataset}
        response = _graphlab.toolkits._main.run("vw_predict", opts)

        # Convert predictions to an SArray
        return _SArray(None, _proxy=response['predictions'])

    def evaluate(self, dataset):
        r"""
        Evaluate the model by making predictions of target values and comparing
        these to actual values. Currently, this method only supports vw models
        trained with ``squared`` or ``logistic`` loss.

        If the model is trained with ``squared`` loss, the evaluation metrics
        are root-mean-squared error (RMSE) and the absolute value of the maximum
        error between the actual and predicted values. Let :math:`y` and
        :math:`\hat{y}` denote vectors of length :math:`N` (number of examples)
        with actual and predicted values. The RMSE is defined as:

        .. math::

            RMSE = \sqrt{\frac{1}{N} \sum_{i=1}^N (\widehat{y}_i - y_i)^2}.

        The max-error is defined as

        .. math::

            max\_error = \max_{i=1}^N \|\widehat{y}_i - y_i\| .

        If the model is trained with ``logistic`` loss, then the model is
        evaluated as a classifier with a decision threshold of 0. The metrics
        are classification accuracy and confusion matrix.  Classification
        accuracy is the fraction of examples whose predicted and actual classes
        match. The confusion matrix contains the cross-tabulation of actual and
        predicted classes for the target variable.

        Parameters
        ----------
        dataset : SFrame
            Dataset in the same format as the SFrame used to train the model.
            The columns names and types must be the same as that used in
            training, including the target column.

        Returns
        -------
        out : dict
            Dictionary of evaluation results. For ``squared`` loss, the
            dictionary keys are *rmse* and *max_error*.  For ``logistic`` loss,
            the dictionary keys are *accuracy* and *confusion_matrix*.

        Raises
        ------
        ToolkitError
            If ``dataset`` lacks the target column, or the model was trained
            with a loss other than ``squared`` or ``logistic``.

        See Also
        --------
        predict

        References
        ----------
        - `Wikipedia - confusion matrix
          <http://en.wikipedia.org/wiki/Confusion_matrix>`_
        - `Wikipedia - root-mean-square deviation
          <http://en.wikipedia.org/wiki/Root-mean-square_deviation>`_
        """
        _mt._get_metric_tracker().track('toolkit.vowpal_wabbit.evaluate')

        target_column = self.get('target')
        if target_column not in dataset.column_names():
            raise _ToolkitError(
                "Input dataset must contain a target column for "
                "evaluation of prediction quality.")

        targets = dataset[target_column]
        predictions = self.predict(dataset)

        loss = self.get('loss_function')
        if loss == 'squared':
            rmse = _graphlab.evaluation.rmse(targets, predictions)
            max_error = _graphlab.evaluation.max_error(targets, predictions)
            return {'rmse': rmse,
                    'max_error': max_error}
        elif loss == 'logistic':
            # NOTE(review): targets are compared in their original {-1, 1}
            # encoding; confirm ``predict`` returns labels in that same
            # encoding rather than raw scores.
            accuracy = _graphlab.evaluation.accuracy(targets, predictions)
            confusion_matrix = _graphlab.evaluation.confusion_matrix(
                targets, predictions)
            return {'accuracy': accuracy,
                    'confusion_matrix': confusion_matrix}
        else:
            raise _ToolkitError("VW evaluate currently only supports models trained with squared or logistic loss.")

    @classmethod
    def _get_queryable_methods(cls):
        '''
        Return a dict mapping each method queryable through Predictive Service
        to a dict of its argument names and argument types.
        '''
        return {'predict': {'dataset': 'sframe'}}

