#!/usr/bin/env python
# encoding: utf-8
"""
plotting_util.py

Purpose: Utility methods for plotting, not linked to input data formatting.

Created by Eva Schiffer <eva.schiffer@ssec.wisc.edu> on 2017-04-05.
Copyright (c) 2017 University of Wisconsin Regents. All rights reserved.

    This program is free software: you can redistribute it and/or modify
    it under the terms of the GNU General Public License as published by
    the Free Software Foundation, either version 3 of the License, or
    (at your option) any later version.

    This program is distributed in the hope that it will be useful,
    but WITHOUT ANY WARRANTY; without even the implied warranty of
    MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE.  See the
    GNU General Public License for more details.

    You should have received a copy of the GNU General Public License
    along with this program.  If not, see <http://www.gnu.org/licenses/>.
"""

from pylab import *

import matplotlib.pyplot as plt
import matplotlib.cm as cm

import logging
import numpy
from   numpy import ma

import cartopy
import cartopy.crs as crs
import cartopy.feature as cfeature

from distutils.version import StrictVersion

from constants import *

LOG = logging.getLogger(__name__)

def _version_tuple(version_str, length=3) :
    """
    Convert a version string such as "0.16.0" or "0.19.0.post1" into a tuple of ints for comparison.

    Only the first `length` dot separated pieces are used, and only the leading digits of each piece
    count, so suffixes like "post1" or "rc1" are ignored. Missing or non-numeric pieces count as 0.

    :param version_str: The version string to convert.
    :param length:      How many numeric components to produce.
    :return: A tuple of `length` ints.
    """

    numbers = []
    for piece in str(version_str).split(".")[:length] :
        digits = ""
        for char in piece :
            if char not in "0123456789" :
                break
            digits += char
        numbers.append(int(digits) if digits else 0)
    # pad short versions like "1.0" out to the full length
    numbers.extend([0] * (length - len(numbers)))
    return tuple(numbers)

# check whether we have an old cartopy or one that's new enough to handle the sweep direction correctly for GOES-R and later
try :
    have_new_cartopy = StrictVersion(cartopy.__version__) >= StrictVersion("0.16.0")
except ValueError :
    # StrictVersion can't parse some published cartopy versions (ex. "0.19.0.post1"), so compare the leading numbers instead
    have_new_cartopy = _version_tuple(cartopy.__version__) >= (0, 16, 0)
if not have_new_cartopy :
    LOG.warning("Older version of cartopy (pre 0.16.0) detected. The sweep parameter can not be set "
                "correctly in this version which may lead to slight navigational misalignment in the "
                "resulting quicklooks images.")

def _add_colorbar(axes, plot_object, place_str="right", size_percent_str="5%", pad_inches=0.05,
                  units=None, valid_range=None, data_shape=None, tick_positions=None, tick_labels=None,
                  extend_constant=None, ):
    """
    Add a colorbar to our image next to the plot in axes.

    :param axes:             The axes for the plot we want a colorbar for.
    :param plot_object:      The plotted object (ex. the image returned by imshow) the colorbar describes.
    :param place_str:        Which side the colorbar should be on as a string ("right", "left", "top" or "bottom").
                             This only selects the orientation and is ignored when data_shape is given.
    :param size_percent_str: Currently unused; kept so existing callers still work.
    :param pad_inches:       Currently unused; kept so existing callers still work.
    :param units:            The units to label the colorbar with; None, "none" and "1" produce no label.
    :param valid_range:      The [min, max] range of the plotted values, used to decide if smaller tick labels are needed.
    :param data_shape:       If given, the shape of the plotted data. Data with fewer rows than columns gets a
                             horizontal colorbar, anything else a vertical one.
    :param tick_positions:   Explicit positions for the colorbar ticks, or None for the matplotlib defaults.
    :param tick_labels:      Explicit labels for the colorbar ticks, or None for the matplotlib defaults.
    :param extend_constant:  The colorbar extend value ("min", "max" or "both"), or None for no overflow arrows.
    :return: The new colorbar.
    """

    # if we were given a data shape, override the place_str and try to pick a place for the axes intelligently
    which_side = place_str
    if data_shape is not None :
        which_side = "bottom" if data_shape[0] < data_shape[1] else "right"

    # a colorbar on the left or right is vertical, anywhere else it's horizontal
    orientation_temp = 'vertical' if which_side in ("right", "left") else 'horizontal'

    # make the colorbar, only asking for overflow arrows if we were given an extend setting
    colorbar_kwargs = {'ax': axes, 'orientation': orientation_temp, }
    if extend_constant is not None :
        colorbar_kwargs['extend'] = extend_constant
    cbar = plt.colorbar(plot_object, **colorbar_kwargs)

    # set various tick mark related stuff on the colorbar
    # (smaller tick labels for wide value ranges on horizontal bars and for explicitly labeled categories)
    if ( (valid_range is not None and (valid_range[1] - valid_range[0] > 5000.0) and (orientation_temp == 'horizontal')) or
         ((tick_labels is not None) and (len(tick_labels) > 3)) or
         (tick_labels is not None and orientation_temp == 'vertical')):
        cbar.ax.tick_params(labelsize=6)
    if tick_positions is not None :
        cbar.set_ticks(tick_positions)
    if tick_labels is not None :
        cbar.set_ticklabels(tick_labels)
        # labeled categories don't need the tick marks themselves
        cbar.ax.tick_params(axis="both", which="both", length=0)

    # add the units to the colorbar; use the string form consistently so non-str units don't break the concatenation
    if units is not None :
        units_text = str(units)
        if units_text.lower() not in ("none", "1") :
            temp_units = "[" + units_text + "]"
            # if we have a very long units text, make it smaller sized text so it will hopefully fit
            if len(temp_units) > 20 :
                cbar.set_label(temp_units, fontstyle='italic', fontsize=8,)
            else :
                cbar.set_label(temp_units, fontstyle='italic',)

    return cbar

# Note: we are using this for our category data, but it only works because our category data
# is very limited in its range and values
def _cmap_discretize(cmap, N):
    """Return a discrete colormap with N colors sampled from the continuous colormap cmap.

        cmap: colormap instance (eg. cm.jet) or the registered name of a colormap.
        N: number of colors.

    Example
        x = resize(arange(100), (5,100))
        djet = _cmap_discretize(cm.jet, 5)
        imshow(x, cmap=djet)

    Note: this function came from the scipy cookbook
    http://scipy-cookbook.readthedocs.io/items/Matplotlib_ColormapTransformations.html
    """

    if isinstance(cmap, str):
        cmap = plt.get_cmap(cmap)
    # sample N evenly spaced colors; the padding entries (all sampled at 0.0) supply colors_rgba[-1] for
    # the i == 0 row and colors_rgba[N] for the i == N row below
    colors_i = numpy.concatenate((numpy.linspace(0, 1., N), (0., 0., 0., 0.)))
    colors_rgba = cmap(colors_i)
    indices = numpy.linspace(0, 1., N + 1)
    cdict = {}
    for ki, key in enumerate(('red', 'green', 'blue')):
        # each row is (x, color just below x, color just above x), making a hard step at every index
        cdict[key] = [(indices[i], colors_rgba[i - 1, ki], colors_rgba[i, ki]) for i in range(N + 1)]
    # Return colormap object.
    return matplotlib.colors.LinearSegmentedColormap(cmap.name + "_%d" % N, cdict, 1024)

def _cap_extent (my_extent, projection_object, ) :
    """
    Make sure my extent doesn't go beyond the available extent for the projection.

    :param my_extent:
    :param projection_object:
    :return:
    """

    global_bounds = projection_object.boundary.bounds
    g_extent = (global_bounds[0], global_bounds[2], global_bounds[1], global_bounds[3])

    return [max(my_extent[0], g_extent[0]), min(my_extent[1], g_extent[1]), max(my_extent[2], g_extent[2]), min(my_extent[3], g_extent[3]),]

def _position_category_labels(flag_values, ):
    """
    Given a list of flag values that are an increasing list of integers that starts with 0,
    position the labels so the first and last ticks are not at the ends, but rather in the
    middle of their segments.

    Given something like [0, 1, 2, 3,] we want to put each of the ticks in the "middle" of their
    numbered segment. So there are 4 categories in 3 segments of space. The width of each segment is
    (3 - 0) / 4 and we need to space the first and last labels in 1/2 of that amount so we'd end up
    with [3/8, 3/8 + 3/4, 3/8 + 6/4, 3 - 3/8 or 3/8 + 9/4,]

    :param flag_values: a list of flag values that are an increasing list of integers starting with 0
    :return:
    """

    # we don't strictly need to subtract the flag_values[0] here (since it should always be 0),
    # but a good reminder if this ever becomes more general
    category_size = float(flag_values[-1] - flag_values[0]) / float(len(flag_values))
    half_cat_size = category_size / 2.0

    # note: this is not a very general solution, if we ever need to plot more general sets it will need work
    new_positions = [half_cat_size + (x * category_size) for x in flag_values]

    return new_positions

def close_figure(figure_obj) :
    """
    Close a figure through pyplot once we are done with it.

    :param figure_obj: The figure to close.
    :return: None
    """

    plt.close(figure_obj)

def _setup_kwargs ( input_kwargs, masked_data, colormap, flag_values, range_to_use, scale_const, ) :
    """
    Fill in the colormap and range related keyword arguments for imshow and build the scale normalization.

    :param input_kwargs: Dictionary of keyword arguments for the plotting call; it is updated in place and returned.
    :param masked_data:  The masked data that will be plotted; its min and max are compared against range_to_use.
    :param colormap:     The colormap to use, or None to keep whatever input_kwargs already has (if anything).
    :param flag_values:  If this is a flag or category variable, the category values. The colormap is discretized
                         to one color per value (cm.jet is used when there is no colormap).
    :param range_to_use: The range to plot (the first and last entries are used as min and max), or None to use
                         the range of the data.
    :param scale_const:  The scale constant (ex. REGULAR_SCALE, LOG_SCALE or SQ_ROOT_SCALE from the constants module).
    :return: A tuple (extend_const, input_kwargs, norm_object). extend_const is None, "min", "max" or "both"
             for the colorbar. norm_object is a matplotlib normalization, or None for a plain linear scale.
    """
    import copy

    # get the min and max data values for convenience
    data_min = numpy.min(masked_data)
    data_max = numpy.max(masked_data)

    # the low and high ends of what will actually be plotted (range_to_use may be None)
    min_to_use = data_min if range_to_use is None else range_to_use[0]
    max_to_use = data_max if range_to_use is None else range_to_use[-1]

    # a log scale can't include zero or negative values; if it would, use a regular linear scale instead of log
    if scale_const == LOG_SCALE and min_to_use <= 0.0 :
        LOG.warning("Detected attempt to use log scale with a data range that includes zero or negative values. "
                    "Turning off log scaling for this variable.")
        scale_const = REGULAR_SCALE

    LOG.debug("Data min / max: " + str(data_min) + " / " + str(data_max))

    # if we've got a color map, pass it to the list of things we want to tell the plotting function
    if colormap is not None:
        input_kwargs['cmap'] = colormap
    # category data gets one discrete color per category
    if flag_values is not None:
        cmap_temp = input_kwargs.get('cmap')
        if cmap_temp is None :
            cmap_temp = cm.jet
        input_kwargs['cmap'] = _cmap_discretize(cmap_temp, len(flag_values))

    # add the overflow display if needed; this applies whether or not we were given a colormap
    extend_const = None
    if range_to_use is not None:
        # if the low end of the range is not as far down as the smallest data
        # Future, when matplotlib can handle the power norm and min-extend together, change this logic
        # see also: https://github.com/matplotlib/matplotlib/issues/11486
        if range_to_use[0] > data_min :
            if scale_const == SQ_ROOT_SCALE :
                LOG.warning("Square root scaling is configured for this product and the product data overflows "
                            "below the lower limit of the range to be plotted. But matplotlib currently cannot "
                            "display colorbars with an arrow indicating minimum overflow due to a bug. Therefore "
                            "the minimum overflow indication will not be displayed in the plot for this data.")
                if input_kwargs.get('cmap') is not None :
                    # work on a copy so we don't change a colormap object that may be shared with other plots
                    under_cmap = copy.deepcopy(plt.get_cmap(input_kwargs['cmap']))
                    under_cmap.set_under(under_cmap(0.0))  # get the min color from the colormap for the min overflow
                    input_kwargs['cmap'] = under_cmap
            else :
                extend_const = "min"
        # if the high end of the range is not as far up as the largest data
        if range_to_use[-1] < data_max :
            extend_const = "max" if extend_const is None else "both"

    # get the scale normalization we want to use
    norm_object = None
    if scale_const == LOG_SCALE:
        norm_object = matplotlib.colors.LogNorm(vmin=min_to_use, vmax=max_to_use, )
    elif scale_const == SQ_ROOT_SCALE:
        norm_object = matplotlib.colors.PowerNorm(gamma=0.5, vmin=min_to_use, vmax=max_to_use, )

    # matplotlib rejects vmin / vmax given alongside a norm instance, and the norm already carries the range
    if norm_object is None and range_to_use is not None :
        input_kwargs['vmin'] = range_to_use[0]
        input_kwargs['vmax'] = range_to_use[-1]

    return extend_const, input_kwargs, norm_object

def _make_geos_projection(expected_longitude, expected_sat_height, ) :
    """
    Build the cartopy Geostationary projection our data is plotted in.

    :param expected_longitude:  The central longitude of the projection, or None to use cartopy's default.
    :param expected_sat_height: The satellite height in meters (not km), or None to use cartopy's default.
    :return: The cartopy Geostationary projection object.
    """

    # the equivalent proj4 is "+proj=geos +h=35786023.44 +lon_0=-75.0 +sweep=x", but cartopy won't take it in that format
    # note: the preliminary central longitude was -89.5 before Dec. 2017, then it moved to -75.0
    proj_kwargs = {}
    # leave anything we weren't given to cartopy's defaults rather than passing None through
    if expected_longitude is not None :
        proj_kwargs['central_longitude'] = expected_longitude
    if expected_sat_height is not None :
        proj_kwargs['satellite_height'] = expected_sat_height
    if have_new_cartopy :
        proj_kwargs['sweep_axis'] = 'x'
    # else: in the old version of cartopy there is no way to set the sweep direction to x, and unfortunately the default is y

    return crs.Geostationary(**proj_kwargs)


def _draw_map_features(axes, backColor, ) :
    """
    Draw the background, gridlines, coastlines, borders, lakes and state/province lines on a map.

    :param axes:      The cartopy axes to draw on.
    :param backColor: The color to fill land and ocean with, or None to leave the background alone.
    """

    # FUTURE: this is kind of a hack way to get the line colors right, we should probably actually pass this into the method instead
    line_color = 'black' if (backColor is None) or (backColor == 'white') else 'violet'
    # draw some background if we need to; note: cartopy refuses to let me change the image background any other way
    if backColor is not None :
        axes.add_feature(cartopy.feature.LAND,  facecolor=backColor,)
        axes.add_feature(cartopy.feature.OCEAN, facecolor=backColor,)
    # draw some lat/lon gridlines
    axes.gridlines(linestyle=':', zorder=2, edgecolor=line_color,)
    # FUTURE, right now cartopy can't label grid lines on a Geostationary plot
    axes.coastlines(resolution='50m',  zorder=3, color=line_color,)
    axes.add_feature(cfeature.BORDERS, linewidth=0.5, zorder=4, edgecolor=line_color,) # this makes country borders
    axes.add_feature(cfeature.LAKES, linestyle=':', linewidth=0.5, facecolor='none', edgecolor=line_color, zorder=5,)
    # add states and provinces
    # Create a feature for States/Admin 1 regions at 1:50m from Natural Earth
    states_provinces = cfeature.NaturalEarthFeature(
        category='cultural',
        name='admin_1_states_provinces_lines',
        scale='50m',
        facecolor='none')
    axes.add_feature(states_provinces, edgecolor=line_color, linewidth=0.25, zorder=6,)


def create_mapped_figure(data, boundingAxes, title,
                         invalidMask=None, colorMap=None, backColor=None, dataRanges=None, units=None,
                         valid_range=None, range_to_use=None, scale_const=REGULAR_SCALE,
                         flag_values=None, flag_names=None, version_txt=None,
                         expected_longitude=None, expected_sat_height=None,):
    """
    create a figure including our data on a map

    :param data:                The raw data to plot
    :param boundingAxes:        The bounding axes to select the plot view in terms of x and y
    :param title:               The plot title

    :param invalidMask:         A mask of any invalid or missing data
    :param colorMap:            The colormap to use for plotting the data
    :param backColor:           The background color to use where there isn't any data
    :param dataRanges:          If the data has discrete ranges, the list of ranges to use (currently unused)
    :param units:               The units that will be displayed on the colorbar

    :param valid_range:         The potential valid range of the data as defined in the file (currently unused)
    :param range_to_use:        The range of the data that should be plotted, pinning overflows to the min and max as needed.
                                If passed as None, use the existing range of the data.
    :param scale_const:         The type of scale to use on the colorbar. Possible values are defined by constants in the constants module

    :param flag_values:         If this is a flag or category variable, what are the values we need to plot and label
    :param flag_names:          If this is a flag or category value, the space separated display names of the values

    :param version_txt:         Informational text describing the versions used to make and plot the data to be displayed on the plot
    :param expected_longitude:  The central longitude for the projection this will be plotted in (None uses cartopy's default).
    :param expected_sat_height: The satellite height in meters for the projection this will be plotted in (None uses cartopy's default).

    :return:                The figure containing the data plotted on a map
    """

    # make a clean version of data
    dataClean = ma.array(data, mask=invalidMask)

    # make our cartopy projection object
    proj_object = _make_geos_projection(expected_longitude, expected_sat_height)

    # build the plot
    figure = plt.figure()
    axes = figure.add_subplot(111, projection=proj_object)

    # setup stuff like the colormap and range related kwargs and get the extend constant for the colorbar
    extend_const, kwargs, norm_object = _setup_kwargs({}, dataClean, colorMap, flag_values, range_to_use, scale_const,)

    LOG.debug("plotting with bounding axes: " + str(boundingAxes))

    # cap the extents so cartopy won't freak out about them and set the extent of the plot
    extent_list = _cap_extent(boundingAxes, proj_object)
    # for some reason cartopy can't handle setting the extents for full disk, so don't do that
    if extent_list[1] - extent_list[0] < 10000000 :
        axes.set_extent(extent_list, crs=proj_object)

    # draw the background and map lines
    _draw_map_features(axes, backColor)

    # plot our data with a map
    temp_plt = plt.imshow(dataClean,
                          transform=proj_object, extent=extent_list,
                          origin='upper', interpolation='nearest',
                          norm=norm_object,
                          zorder=1, **kwargs)

    # FUTURE, if we ever need to go back to one of these other plotting methods, here is how to call them
    # Note: these could be called using the extent_list rather than lat/lon, but these calls would need revision
    # plt.contourf(longitudeClean, latitudeClean, dataClean, transform=crs.PlateCarree(), zorder=1, **kwargs)
    # plt.scatter(longitudeClean, latitudeClean, c=dataClean, s=0.2, transform=crs.PlateCarree(), zorder=1, **kwargs)
    # plt.pcolormesh(longitudeClean, latitudeClean, dataClean, transform=crs.PlateCarree(), zorder=1, **kwargs)

    # add the version message if we have one
    if version_txt is not None:
        plt.figtext(.02, .02, version_txt, size=6,)

    # set the title of the plot
    axes.set_title(title)

    # show a colorbar
    if data is not None:

        # if we have a category variable, override the range and tick info we use to make the colorbar look ok.
        # Note: right now we are relying on the fact that all of the variables we are plotting have categories that are
        # increasing integers (so stuff like [0,1,] or [0,1,2,3,4,5,]). In FUTURE we should try to handle this in
        # a more flexible manner (this strategy definitely wouldn't work for things like plotting the QA flags).
        temp_valid_range = range_to_use if flag_values is None else [flag_values[0] - 0.5, flag_values[-1] + 0.5, ]
        tick_positions = None if flag_values is None else _position_category_labels(flag_values)
        # split on any whitespace so repeated spaces don't produce empty labels
        tick_labels = None if flag_names is None else flag_names.split()

        _add_colorbar(axes, temp_plt, units=units, valid_range=temp_valid_range, data_shape=data.shape,
                      tick_positions=tick_positions, tick_labels=tick_labels, extend_constant=extend_const,)

    return figure

def create_simple_figure(data, figureTitle, invalidMask=None, colorMap=None, backColor=None, units=None, drawColorbar=True,
                         valid_range=None, range_to_use=None, scale_const=REGULAR_SCALE,
                         flag_values=None, flag_names=None, version_txt=None,):
    """
    create a simple figure showing the data given

    :param data:            The data to show
    :param figureTitle:     The title of the figure
    :param invalidMask:     A mask of any invalid or missing data in the data set (None means nothing extra is masked)
    :param colorMap:        The colormap to use when plotting the data (if None the default for imshow will be used)
    :param backColor:       The background color to use where there isn't any data
    :param units:           The units that will be displayed on the colorbar
    :param drawColorbar:    Whether or not to draw a colorbar at all

    :param valid_range:     The potential valid range of the data as defined in the file (currently unused)
    :param range_to_use:    The range of the data that should be plotted, pinning overflows to the min and max as needed.
                            If passed as None, use the existing range of the data.
    :param scale_const:     The type of scale to use on the colorbar. Possible values are defined by constants in the constants module

    :param flag_values:     If this is a flag or category variable, what are the values we need to plot and label
    :param flag_names:      If this is a flag or category value, the space separated display names of the values
    :param version_txt:     Informational text describing the versions used to make and plot the data to be displayed on the plot

    :return:                The figure containing the image of the data
    """

    cleanData = ma.array(data, mask=invalidMask)

    # build the plot
    figure = plt.figure()
    axes = figure.add_subplot(111, facecolor=backColor,)

    # only draw if at least one data point survives the mask; counting the unmasked points
    # works whether or not an invalidMask was given
    if (data is not None) and (cleanData.count() > 0):

        # setup stuff like the colormap and range related kwargs and get the extend constant for the colorbar
        extend_const, kwargs, norm_object = _setup_kwargs({}, cleanData, colorMap, flag_values, range_to_use, scale_const,)

        # draw our data
        im = plt.imshow(cleanData, norm=norm_object, **kwargs)

        if drawColorbar :

            # if we have a category variable, override tick info we use to make the colorbar look ok.
            # Note: right now we are relying on the fact that all of the variables we are plotting have categories that are
            # increasing integers (so stuff like [0,1,] or [0,1,2,3,4,5,]). In FUTURE we should try to handle this in
            # a more flexible manner (this strategy definitely wouldn't work for things like plotting the QA flags).
            temp_valid_range = range_to_use if flag_values is None else [flag_values[0] - 0.5, flag_values[-1] + 0.5,]
            tick_positions = None if flag_values is None else _position_category_labels(flag_values)
            # split on any whitespace so repeated spaces don't produce empty labels
            tick_labels = None if flag_names is None else flag_names.split()

            _add_colorbar(axes, im, units=units,
                          valid_range=temp_valid_range, data_shape=data.shape,
                          tick_positions=tick_positions, tick_labels=tick_labels,
                          extend_constant=extend_const, )

    # add the version message if we have one
    if version_txt is not None:
        plt.figtext(.02, .02, version_txt, size=6,)

    # and some informational stuff
    axes.set_title(figureTitle)

    return figure

def create_nodata_figure(figureTitle, version_txt=None,):
    """
    Build a placeholder figure stating that there was no data to plot.

    :param figureTitle: The title of the figure.
    :param version_txt: Optional informational text about the versions used, drawn small near the lower left corner.

    :return: The new figure holding the title and a red "No Data Available" message.
    """

    new_figure = plt.figure()
    only_axes = new_figure.add_subplot(111)
    only_axes.set_title(figureTitle)

    # the explicit no data message, placed in the axes' data coordinates
    only_axes.text(0.35, 0.4, "No Data\nAvailable", size=40, color='red',)

    # the version message goes in figure coordinates, just like plt.figtext would put it
    if version_txt is not None :
        new_figure.text(.02, .02, version_txt, size=6,)

    return new_figure