#!/usr/bin/env python
"""
================================================================================
:mod:`imageraster` -- Image raster widgets
================================================================================

.. module:: imageraster
   :synopsis: Image raster widgets

.. inheritance-diagram:: pyhmsa.gui.spec.datum.imageraster

"""

# Standard library modules.

# Third party modules.
from PySide.QtGui import \
    QSlider, QFormLayout, QHBoxLayout, QRadioButton, QSplitter, QSizePolicy
from PySide.QtCore import Qt, QAbstractTableModel, Signal

import matplotlib
from matplotlib.figure import Figure

import numpy as np

# Local modules.
from pyhmsa.gui.spec.datum.datum import \
    _DatumWidget, _DatumTableWidget, _DatumFigureWidget
from pyhmsa.gui.spec.datum.analysis import \
    Analysis1DTableWidget, Analysis1DGraphWidget

from pyhmsa.spec.datum.analysis import Analysis1D
from pyhmsa.spec.datum.imageraster import ImageRaster2D, ImageRaster2DSpectral

# Global variables and constants.

class ImageRaster2DTableWidget(_DatumTableWidget):
    """
    Table widget showing the values of an :class:`ImageRaster2D` datum.
    """

    class _TableModel(QAbstractTableModel):
        """
        Read-only table model backed by a datum.

        The number of rows is the datum's ``y`` attribute and the number of
        columns its ``x`` attribute; the cell at (row, column) displays
        ``datum[column, row]``.
        """

        def __init__(self, datum):
            QAbstractTableModel.__init__(self)
            self._datum = datum

        def rowCount(self, parent=None):
            """Return the number of rows (``y`` attribute of the datum)."""
            datum = self._datum
            return datum.y

        def columnCount(self, parent=None):
            """Return the number of columns (``x`` attribute of the datum)."""
            datum = self._datum
            return datum.x

        def data(self, index, role):
            """
            Return the text displayed for *index*, or ``None`` if the index
            is invalid, outside the datum, or *role* is not the display role.
            """
            if not index.isValid():
                return None

            row = index.row()
            column = index.column()
            inside = 0 <= row < self._datum.y and 0 <= column < self._datum.x
            if not inside:
                return None

            if role == Qt.DisplayRole:
                # The datum is indexed (x, y), i.e. (column, row)
                return str(self._datum[column, row])
            return None

        def headerData(self, section , orientation, role):
            """
            Return the 1-based position of *section* as header text for both
            the horizontal and the vertical headers.
            """
            if role != Qt.DisplayRole:
                return None
            if orientation in (Qt.Horizontal, Qt.Vertical):
                return str(section + 1)
            return None

    def __init__(self, datum=None, parent=None):
        _DatumTableWidget.__init__(self, ImageRaster2D, datum, parent)

    def _createModel(self, datum):
        """Create the table model wrapping *datum*."""
        model = self._TableModel(datum)
        return model

class ImageRaster2DGraphWidget(_DatumFigureWidget):
    """
    Figure widget showing an :class:`ImageRaster2D` datum as an image with
    a colorbar.
    """

    def __init__(self, datum=None, parent=None):
        _DatumFigureWidget.__init__(self, ImageRaster2D, datum, parent)

    def _createFigure(self):
        """
        Create the figure with a single axes filling the whole figure and
        both axis decorations hidden.
        """
        figure = Figure((5.0, 4.0), dpi=100)

        axes = figure.add_axes([0.0, 0.0, 1.0, 1.0])
        for axis in (axes.xaxis, axes.yaxis):
            axis.set_visible(False)

        # No image has been drawn yet; created lazily by _drawFigure()
        self._aximage = None

        return figure

    def _drawFigure(self, fig, datum):
        """
        Draw the transpose of *datum* in the first axes of *fig*.

        The image artist and its colorbar are created on the first call;
        later calls replace the image data and rescale its color limits.
        """
        image = datum.T

        if self._aximage is not None:
            # Image already exists: refresh it in place
            self._aximage.set_data(image)
            self._aximage.autoscale()
            return

        # First draw
        self._aximage = fig.axes[0].imshow(image)
        fig.colorbar(self._aximage, shrink=0.8)

class _ImageRaster2DSpectralWidget(_DatumWidget):
    """
    Base widget to explore an :class:`ImageRaster2DSpectral` datum.

    The third axis of the datum (the channels) is reduced to a 2D image
    shown in an image raster widget. Four reductions are available: sum of
    all channels, maximum over all channels, a single channel, or the sum
    over a range of channels. The channel(s) are selected with two sliders.
    Selecting a position (signal :attr:`valueSelected`) shows the values of
    all channels at this position in an analysis widget.

    Sub-classes must implement :meth:`_createImageRaster2DWidget` and
    :meth:`_createAnalysis1DWidget`.
    """

    # Emitted with the (x, y) indexes of the selected position
    valueSelected = Signal(int, int)

    def __init__(self, datum=None, parent=None):
        # Set before calling the base constructor so that the attribute
        # always exists for the update methods, which test it against None.
        self._datum = None
        _DatumWidget.__init__(self, ImageRaster2DSpectral, datum, parent)

    def _initUI(self):
        """Create the widgets, lay them out and connect the signals."""
        # Widgets
        self._rdb_sum = QRadioButton("Sum")
        self._rdb_sum.setChecked(True)
        self._rdb_max = QRadioButton("Maximum")
        self._rdb_single = QRadioButton("Single")
        self._rdb_range = QRadioButton("Range")

        # Sliders start disabled since the initial mode (sum) uses none
        self._sld_start = QSlider(Qt.Orientation.Horizontal)
        self._sld_start.setTickPosition(QSlider.TicksBelow)
        self._sld_start.setEnabled(False)

        self._sld_end = QSlider(Qt.Orientation.Horizontal)
        self._sld_end.setTickPosition(QSlider.TicksBelow)
        self._sld_end.setEnabled(False)

        self._wdg_imageraster2d = self._createImageRaster2DWidget()
        self._wdg_analysis = self._createAnalysis1DWidget()

        # Layouts
        layout = _DatumWidget._initUI(self)

        sublayout = QHBoxLayout()
        sublayout.addWidget(self._rdb_sum)
        sublayout.addWidget(self._rdb_max)
        sublayout.addWidget(self._rdb_single)
        sublayout.addWidget(self._rdb_range)
        layout.addLayout(sublayout)

        sublayout = QFormLayout()
        sublayout.setFieldGrowthPolicy(QFormLayout.FieldGrowthPolicy.ExpandingFieldsGrow) # Fix for Mac OS
        sublayout.addRow('Channels (Start)', self._sld_start)
        sublayout.addRow('Channels (End)', self._sld_end)
        layout.addLayout(sublayout)

        splitter = QSplitter()
        splitter.setSizePolicy(QSizePolicy.Expanding, QSizePolicy.Expanding)
        splitter.addWidget(self._wdg_imageraster2d)
        splitter.addWidget(self._wdg_analysis)
        layout.addWidget(splitter)

        # Signals
        self._rdb_sum.toggled.connect(self._onModeSum)
        self._rdb_max.toggled.connect(self._onModeMax)
        self._rdb_single.toggled.connect(self._onModeSingle)
        self._rdb_range.toggled.connect(self._onModeRange)

        self._sld_start.valueChanged.connect(self._onSlideStart)
        self._sld_end.valueChanged.connect(self._onSlideEnd)

        self.valueSelected.connect(self._onValueSelected)

        return layout

    def _createAnalysis1DWidget(self):
        """Return the widget showing the channels of the selected position."""
        raise NotImplementedError

    def _createImageRaster2DWidget(self):
        """Return the widget showing the reduced 2D image."""
        raise NotImplementedError

    # NOTE: toggled() is emitted both by the button that gets checked and by
    # the one that gets unchecked. The slider state is updated in both cases
    # (as before), but the image is only recomputed for the button that
    # becomes checked, i.e. for the mode that is actually active.

    def _onModeSum(self, checked):
        """Slot of the "Sum" radio button."""
        self._sld_start.setEnabled(not checked)
        self._sld_end.setEnabled(not checked)
        if checked:
            self._updateModeSum()

    def _onModeMax(self, checked):
        """Slot of the "Maximum" radio button."""
        self._sld_start.setEnabled(not checked)
        self._sld_end.setEnabled(not checked)
        if checked:
            self._updateModeMax()

    def _onModeSingle(self, checked):
        """Slot of the "Single" radio button."""
        self._sld_start.setEnabled(checked)
        self._sld_end.setEnabled(not checked)
        if checked:
            self._updateModeSingle()

    def _onModeRange(self, checked):
        """Slot of the "Range" radio button."""
        self._sld_start.setEnabled(checked)
        self._sld_end.setEnabled(checked)
        if checked:
            self._updateModeRange()

    def _updateData(self):
        """Recompute the image for the mode currently checked."""
        if self._rdb_sum.isChecked():
            self._updateModeSum()
        elif self._rdb_max.isChecked():
            self._updateModeMax()
        elif self._rdb_single.isChecked():
            self._updateModeSingle()
        elif self._rdb_range.isChecked():
            self._updateModeRange()

    def _updateModeSum(self):
        """Show the sum over all channels."""
        if self._datum is None:
            return
        subdatum = np.sum(self._datum, 2)
        self._wdg_imageraster2d.setDatum(subdatum)

    def _updateModeMax(self):
        """Show the maximum over all channels."""
        if self._datum is None:
            return
        subdatum = np.amax(self._datum, 2)
        self._wdg_imageraster2d.setDatum(subdatum)

    def _updateModeSingle(self):
        """Show the channel selected by the start slider."""
        if self._datum is None:
            return
        channel = self._sld_start.value()
        subdatum = self._datum[:, :, channel]
        self._wdg_imageraster2d.setDatum(subdatum)

    def _updateModeRange(self):
        """
        Show the sum of the channels between the two sliders (inclusive).
        The sliders may be in any order.
        """
        if self._datum is None:
            return
        start = self._sld_start.value()
        end = self._sld_end.value()
        first = min(start, end)
        last = max(start, end)
        subdatum = np.sum(self._datum[:, :, first:last + 1], 2)
        self._wdg_imageraster2d.setDatum(subdatum)

    def _onSlideStart(self, channel):
        """Slot of the start slider; only relevant in single and range modes."""
        if self._rdb_single.isChecked():
            self._updateModeSingle()
        elif self._rdb_range.isChecked():
            self._updateModeRange()

    def _onSlideEnd(self, channel):
        """Slot of the end slider; only relevant in range mode."""
        # valueChanged() can also be emitted when the slider is disabled,
        # e.g. when setMaximum() clamps the current value in setDatum().
        if self._rdb_range.isChecked():
            self._updateModeRange()

    def _onValueSelected(self, x, y):
        """
        Show the channels at position (*x*, *y*) in the analysis widget.
        Positions outside the datum are ignored.
        """
        if self._datum is None:
            return

        # Reject out-of-range positions instead of raising IndexError or
        # silently wrapping around with negative indexes.
        width, height = self._datum.shape[:2]
        if not (0 <= x < width and 0 <= y < height):
            return

        subdatum = self._datum[x, y]
        self._wdg_analysis.setDatum(subdatum.view(Analysis1D))

        # Re-run the current mode after the analysis widget changed
        # (presumably so that sub-classes can refresh what they draw on top
        # of the analysis widget -- TODO confirm).
        self._updateData()

    def setDatum(self, datum):
        """
        Set the datum to display.

        :arg datum: datum with a ``channels`` attribute, or ``None``
        """
        _DatumWidget.setDatum(self, datum)
        self._datum = datum

        # Nothing to configure or draw without a datum; the update methods
        # also treat None as "no data".
        if datum is None:
            return

        last_channel = datum.channels - 1
        self._sld_start.setMaximum(last_channel)
        self._sld_end.setMaximum(last_channel)

        self._updateData()

class ImageRaster2DSpectralTableWidget(_ImageRaster2DSpectralWidget):
    """
    Spectral image raster widget where both the reduced image and the
    channels of the selected position are shown as tables.
    """

    def _createAnalysis1DWidget(self):
        """Create the table showing the channels of the selected position."""
        analysis_table = Analysis1DTableWidget()
        return analysis_table

    def _createImageRaster2DWidget(self):
        """
        Create the table showing the reduced image and listen to clicks on
        its cells.
        """
        image_table = ImageRaster2DTableWidget()
        image_table._table.clicked.connect(self._onTableClicked)
        return image_table

    def _onTableClicked(self, index):
        """Emit :attr:`valueSelected` with the column and row of *index*."""
        x = index.column()
        y = index.row()
        self.valueSelected.emit(x, y)

class ImageRaster2DSpectralGraphWidget(_ImageRaster2DSpectralWidget):
    """
    Spectral image raster widget where the reduced image and the channels
    of the selected position are shown as matplotlib figures.

    In single mode, a vertical line marks the selected channel on the
    analysis figure; in range mode, a vertical span marks the selected
    range of channels.
    """

    def _createImageRaster2DWidget(self):
        """Create the image figure and listen to mouse clicks on it."""
        widget = ImageRaster2DGraphWidget()
        widget._canvas.mpl_connect("button_release_event", self._onFigureClicked)
        return widget

    def _createAnalysis1DWidget(self):
        """
        Create the analysis figure and add the (initially hidden) line and
        span used to mark the selected channel(s).
        """
        widget = Analysis1DGraphWidget()

        fig = widget._canvas.figure
        ax = fig.axes[0]

        # 'axes.color_cycle' was replaced by 'axes.prop_cycle' in newer
        # versions of matplotlib; support both.
        rcparams = matplotlib.rcParams
        if 'axes.prop_cycle' in rcparams:
            colors = [props['color'] for props in rcparams['axes.prop_cycle']
                      if 'color' in props]
        else:
            colors = rcparams['axes.color_cycle']
        color = colors[1]

        self._ax_single = ax.axvline(0, lw=3, color=color, zorder=3)
        self._ax_single.set_visible(False)

        self._ax_range = ax.axvspan(0, 0, alpha=0.5, facecolor=color, zorder=3)
        self._ax_range.set_visible(False)

        return widget

    def _onFigureClicked(self, event):
        """
        Emit :attr:`valueSelected` with the indexes of the pixel under the
        mouse. Clicks outside the image are ignored.
        """
        # The figure also contains the axes of the colorbar; only clicks
        # inside the image axes (first axes of the figure) give a position.
        image_axes = self._wdg_imageraster2d._canvas.figure.axes[0]
        if event.inaxes is not image_axes:
            return

        datum = self._datum
        if datum is None:
            return

        # imshow() centers the pixels on integer coordinates, so the pixel
        # index is the nearest integer, not the truncated coordinate.
        x = int(round(event.xdata))
        y = int(round(event.ydata))

        # Rounding a click on the outer edge may fall just outside the datum
        width, height = datum.shape[:2]
        if not (0 <= x < width and 0 <= y < height):
            return

        self.valueSelected.emit(x, y)

    def _setMarkersVisible(self, single, span):
        """Show/hide the single channel line and the range span, then redraw."""
        self._ax_single.set_visible(single)
        self._ax_range.set_visible(span)
        self._wdg_analysis._canvas.draw()

    # NOTE: toggled() is also emitted by the radio button that gets
    # unchecked. The markers are only updated by the button that becomes
    # checked, which fully defines their visibility for its mode.

    def _onModeSum(self, checked):
        _ImageRaster2DSpectralWidget._onModeSum(self, checked)
        if checked:
            self._setMarkersVisible(False, False)

    def _onModeMax(self, checked):
        _ImageRaster2DSpectralWidget._onModeMax(self, checked)
        if checked:
            self._setMarkersVisible(False, False)

    def _onModeSingle(self, checked):
        _ImageRaster2DSpectralWidget._onModeSingle(self, checked)
        if checked:
            self._setMarkersVisible(True, False)

    def _onModeRange(self, checked):
        _ImageRaster2DSpectralWidget._onModeRange(self, checked)
        if checked:
            self._setMarkersVisible(False, True)

    def _updateModeSingle(self):
        """Update the image and move the line to the selected channel."""
        _ImageRaster2DSpectralWidget._updateModeSingle(self)

        axplot = self._wdg_analysis._axplot
        if axplot is None:
            return

        channel = self._sld_start.value()
        x = axplot.get_xdata()[channel]

        # The line created by axvline() uses data coordinates along x but
        # axes coordinates (0 = bottom, 1 = top) along y.
        self._ax_single.set_xdata([x, x])
        self._ax_single.set_ydata([0, 1])

        self._wdg_analysis._canvas.draw()

    def _updateModeRange(self):
        """Update the image and move the span to the selected channels."""
        _ImageRaster2DSpectralWidget._updateModeRange(self)

        axplot = self._wdg_analysis._axplot
        if axplot is None:
            return

        start = self._sld_start.value()
        end = self._sld_end.value()

        xs = axplot.get_xdata()
        xmin = xs[min(start, end)]
        xmax = xs[max(start, end)]

        # The patch created by axvspan() uses data coordinates along x but
        # axes coordinates (0 = bottom, 1 = top) along y.
        self._ax_range.set_xy([[xmin, 0],
                               [xmin, 1],
                               [xmax, 1],
                               [xmax, 0]])

        self._wdg_analysis._canvas.draw()


