#!/usr/bin/env python
"""
================================================================================
:mod:`analysis` -- Analysis widgets
================================================================================

.. module:: analysis
   :synopsis: Analysis widgets

.. inheritance-diagram:: pyhmsa.gui.spec.datum.analysis

"""

# Standard library modules.

# Third party modules.
from PySide.QtCore import Qt, QAbstractTableModel

from matplotlib.figure import Figure

# Local modules.
from pyhmsa.gui.spec.datum.datum import _DatumTableWidget, _DatumFigureWidget

from pyhmsa.spec.datum.analysis import Analysis0D, Analysis1D, Analysis2D
from pyhmsa.spec.condition.detector import DetectorSpectrometer

# Global variables and constants.

class Analysis0DTableWidget(_DatumTableWidget):
    """Table widget showing an :class:`Analysis0D` datum in a single cell."""

    class _TableModel(QAbstractTableModel):
        """One-row, one-column model whose only cell is ``str(datum)``."""

        def __init__(self, datum):
            QAbstractTableModel.__init__(self)
            self._datum = datum

        def rowCount(self, parent=None):
            """Return the fixed number of rows (1)."""
            return 1

        def columnCount(self, parent=None):
            """Return the fixed number of columns (1)."""
            return 1

        def data(self, index, role):
            """Return the datum as text for the display role, else ``None``."""
            if not index.isValid():
                return None
            # The model has a single row, so only row 0 is acceptable.
            if index.row() != 0:
                return None
            if role == Qt.DisplayRole:
                return str(self._datum)
            return None

        def headerData(self, section , orientation, role):
            """Return ``'Value'`` for the column header and 1-based row labels."""
            if role == Qt.DisplayRole:
                if orientation == Qt.Horizontal:
                    return 'Value'
                if orientation == Qt.Vertical:
                    return str(section + 1)
            return None

    def __init__(self, datum=None, parent=None):
        _DatumTableWidget.__init__(self, Analysis0D, datum, parent)

    def _createModel(self, datum):
        """Build the table model wrapping *datum*."""
        model = self._TableModel(datum)
        return model

class Analysis1DTableWidget(_DatumTableWidget):
    """Table widget listing the channels of an :class:`Analysis1D` datum."""

    class _TableModel(QAbstractTableModel):
        """
        Model with one row per channel of the datum.

        When a :class:`DetectorSpectrometer` condition with a calibration is
        found in the datum's conditions, a first column holding the
        calibrated value of each channel is shown before the value column.
        """

        def __init__(self, datum):
            QAbstractTableModel.__init__(self)
            self._datum = datum

            # Use the calibration of the first spectrometer detector, if any.
            detectors = datum.conditions.findvalues(DetectorSpectrometer)
            self._calibration = \
                next(iter(detectors)).calibration if detectors else None

        def rowCount(self, parent=None):
            """Return the number of channels of the datum."""
            return self._datum.channels

        def columnCount(self, parent=None):
            """Return 2 when a calibration is available, otherwise 1."""
            if self._calibration is None:
                return 1
            return 2

        def data(self, index, role):
            """Return the text of the cell at *index* for the display role."""
            if not index.isValid():
                return None
            if not (0 <= index.row() < self._datum.channels):
                return None
            if role != Qt.DisplayRole:
                return None

            channel = index.row()
            col = index.column()

            # Without calibration every column shows the raw value.
            if self._calibration is None:
                return str(self._datum[channel])

            if col == 0:
                return str(self._calibration(channel))
            if col == 1:
                return str(self._datum[channel])
            return None

        def headerData(self, section , orientation, role):
            """Return the column titles and the 1-based row labels."""
            if role != Qt.DisplayRole:
                return None

            if orientation == Qt.Vertical:
                return str(section + 1)

            if orientation == Qt.Horizontal:
                calibration = self._calibration
                if calibration is None:
                    return 'Value'
                if section == 0:
                    return '%s (%s)' % (calibration.quantity, calibration.unit)
                if section == 1:
                    return 'Value'

            return None

    def __init__(self, datum=None, parent=None):
        _DatumTableWidget.__init__(self, Analysis1D, datum, parent)

    def _createModel(self, datum):
        """Build the table model wrapping *datum*."""
        model = self._TableModel(datum)
        return model

class Analysis1DGraphWidget(_DatumFigureWidget):
    """Figure widget plotting an :class:`Analysis1D` datum as a line."""

    def __init__(self, datum=None, parent=None):
        _DatumFigureWidget.__init__(self, Analysis1D, datum, parent)

    def _createFigure(self):
        """
        Create the figure with a single subplot.

        The line artist is created lazily on the first call to
        :meth:`_drawFigure`, so ``_axplot`` starts as ``None``.
        """
        fig = Figure((5.0, 4.0), dpi=100)

        # Integer position arguments: a string such as "111" is rejected by
        # recent matplotlib releases.
        fig.add_subplot(1, 1, 1)

        self._axplot = None

        return fig

    def _drawFigure(self, fig, datum):
        """
        Draw (or update) the line of *datum* on the first axes of *fig*.

        :arg fig: figure returned by :meth:`_createFigure`
        :arg datum: object providing ``get_xy(with_labels=True)``, which is
            unpacked as ``(xlabel, ylabel, xy)``; column 0 of ``xy`` is
            plotted as x and column 1 as y
        """
        # Extract data and labels
        xlabel, ylabel, xy = datum.get_xy(with_labels=True)

        # Draw
        ax = fig.axes[0]

        if self._axplot is None:
            # First draw: create the line artist and keep it for reuse.
            self._axplot = ax.plot(xy[:, 0], xy[:, 1], zorder=1)[0]
        else:
            # Later draws: reuse the artist and rescale to the new data.
            self._axplot.set_data((xy[:, 0], xy[:, 1]))
            ax.relim()
            ax.autoscale(tight=True)

        ax.set_xlabel(xlabel)
        ax.set_ylabel(ylabel)

class Analysis2DTableWidget(_DatumTableWidget):
    """Table widget showing the values of an :class:`Analysis2D` datum."""

    class _TableModel(QAbstractTableModel):
        """
        Model with ``datum.v`` rows and ``datum.u`` columns.

        The cell at (row, column) displays ``datum[column, row]``.
        """

        def __init__(self, datum):
            QAbstractTableModel.__init__(self)
            self._datum = datum

        def rowCount(self, parent=None):
            """Return the ``v`` dimension of the datum."""
            return self._datum.v

        def columnCount(self, parent=None):
            """Return the ``u`` dimension of the datum."""
            return self._datum.u

        def data(self, index, role):
            """Return the text of the cell at *index* for the display role."""
            if not index.isValid():
                return None
            if not (0 <= index.row() < self._datum.v):
                return None
            if not (0 <= index.column() < self._datum.u):
                return None
            if role != Qt.DisplayRole:
                return None

            # The datum is indexed (column, row), i.e. the table row is the
            # second index.
            r, c = index.row(), index.column()
            return str(self._datum[c, r])

        def headerData(self, section , orientation, role):
            """Return 1-based labels for both the row and column headers."""
            if role != Qt.DisplayRole:
                return None
            if orientation in (Qt.Horizontal, Qt.Vertical):
                return str(section + 1)
            return None

    def __init__(self, datum=None, parent=None):
        _DatumTableWidget.__init__(self, Analysis2D, datum, parent)

    def _createModel(self, datum):
        """Build the table model wrapping *datum*."""
        model = self._TableModel(datum)
        return model

class Analysis2DGraphWidget(_DatumFigureWidget):
    """Figure widget rendering an :class:`Analysis2D` datum as an image."""

    def __init__(self, datum=None, parent=None):
        _DatumFigureWidget.__init__(self, Analysis2D, datum, parent)

    def _createFigure(self):
        """
        Create the figure with one axes filling it and both axis lines hidden.

        The image artist is created on the first call to :meth:`_drawFigure`,
        so ``_aximage`` starts as ``None``.
        """
        fig = Figure(figsize=(5.0, 4.0), dpi=100)

        axes = fig.add_axes([0.0, 0.0, 1.0, 1.0])
        for axis in (axes.xaxis, axes.yaxis):
            axis.set_visible(False)

        self._aximage = None

        return fig

    def _drawFigure(self, fig, datum):
        """
        Show the transpose of *datum* as an image on the first axes of *fig*.

        A colorbar is added on the first draw; later draws only replace the
        image data and rescale its color limits.
        """
        image = self._aximage

        if image is not None: # Redraw
            image.set_data(datum.T)
            image.autoscale()
            return

        # First draw
        self._aximage = fig.axes[0].imshow(datum.T)
        fig.colorbar(self._aximage, shrink=0.8)
