#!/usr/bin/env python
# encoding: utf-8

import os

import numpy as np
from pyhdf import SD

# directory holding the MODIS SRF files: a "modis_bright" subdirectory
# next to this module (the files are opened through pyhdf's SD interface
# in read_modis_srfs)
SRF_DIR = os.path.join(os.path.dirname(__file__), "modis_bright")
AQUA_SRF  = SRF_DIR + "/modis_aqua.srf.nc"
TERRA_SRF = SRF_DIR + "/modis_terra.srf.nc"


# speed of light; m * s^-1
c = 2.99792458e8

# Planck constant; J * s
h = 6.62606957e-34

# Boltzmann constant; J * K^-1
k_B = 1.3806488e-23

# Planck function coefficients in SI units, used by planck_wl and
# inv_planck_wl (which convert microns to metres before applying them):
#   c1 = 2 h c^2   -> W * m^2 (per sr, as used for spectral radiance)
#   c2 = h c / k_B -> m * K
c1 = 2.0 * h * c**2
c2 = h * c / k_B

def modis_bright(platform, band, R):

    """Convert MODIS radiances to brightness temperatures.

    platform: 'aqua' or 'terra'
    band: an integer emissive band number 20-36 excluding 26
    R: a numpy array of radiance values in W per m^2 per sr per micron

    The SRF for (platform, band) is looked up in the module-level
    all_srfs table, so an unknown platform or band raises KeyError.
    """

    platform_srfs = all_srfs[platform]
    band_srf = platform_srfs[band]
    return band_srf.bright(R)

class Srf(object):

    """An instrument spectral response function (SRF).

    The response F is tabulated on the wavelength grid lambda_ (microns,
    matching planck_wl / inv_planck_wl) and stored normalised so that its
    values sum to 1. Because apply() resamples the SRF with np.interp
    against self.lambda_, that grid must be monotonically increasing.
    """

    def __init__(self, lambda_, F):

        """Create an SRF from a wavelength grid and response values.

        lambda_: increasing 1-D array-like of wavelengths in microns
        F: 1-D array-like of response values, same length as lambda_
        """

        # accept plain sequences as well as ndarrays; np.asarray returns
        # an existing ndarray unchanged (no copy, no dtype change)
        self.lambda_ = np.asarray(lambda_)
        F = np.asarray(F)
        # normalise so the weights sum to 1; centroid relies on this
        self.F = F / F.sum()

    def apply(self, lambda_, R):

        """Compute the radiance seen via this SRF for a given spectrum.

        lambda_: 1-D wavelength grid (microns) on which R is tabulated
        R: spectral radiance; its last axis runs along lambda_ and any
           leading axes (e.g. one row per temperature) are kept
        Returns the trapezoidal integral of F * R over the last axis
        divided by the trapezoidal integral of F, i.e. the SRF-weighted
        mean radiance.
        """

        lambda_ = np.asarray(lambda_)
        # asanyarray keeps ndarray subclasses (e.g. masked arrays) intact
        R = np.asanyarray(R)

        # resample our SRF onto the given grid; the response is taken to
        # be zero outside our own wavelength range
        F = np.interp(lambda_, self.lambda_, self.F, 0.0, 0.0)

        # trapezoidal rule using "doubled mean heights" (dmh): the factor
        # of 1/2 per trapezoid cancels in the ratio, so it is omitted
        FR = F * R
        widths = np.diff(lambda_)
        dmh_FR = FR[..., 1:] + FR[..., :-1]
        dmh_F = F[1:] + F[:-1]
        return (widths * dmh_FR).sum(axis=-1) / (widths * dmh_F).sum(axis=-1)

    def bright(self, R):

        """Convert band radiance(s) to brightness temperature(s).

        R: radiance in W per m^2 per sr per micron; a scalar, a sequence
           or an ndarray is accepted.
        Returns temperatures in Kelvin with the same shape as R.
        """

        # FIXME: needs better way for dealing with degree, T_min,
        # and T_max

        # the correction polynomial is fitted on first use and cached on
        # the instance
        if not hasattr(self, 'monochrom_coeffs'):
            self.monochrom_coeffs = self.monochrom(180.0, 340.0, 2)

        # allow plain floats and lists (which have no .shape and cannot
        # be scaled by inv_planck_wl) while leaving ndarrays, including
        # subclasses, untouched
        R = np.asanyarray(R)
        T_b_mono = inv_planck_wl(self.centroid, R)

        # evaluate the correction polynomial (coefficients highest power
        # first, as returned by np.polyfit) by Horner's scheme, starting
        # from a float64 zero array
        T_b = np.zeros(R.shape)
        for coeff in self.monochrom_coeffs:
            T_b = T_b_mono * T_b + coeff

        return T_b

    @property
    def centroid(self):

        """The centroid wavelength value of this SRF (microns).

        Since F is normalised to sum to 1, this is the F-weighted mean of
        the tabulated wavelengths. It is computed once and cached.
        """

        if not hasattr(self, '_centroid'):
            self._centroid = (self.lambda_ * self.F).sum()
        return self._centroid

    def monochrom(self, T_min, T_max, degree):

        """Compute monochromatic correction coefficients.

        The "monochromatic" brightness temperature is obtained by
        inverting the Planck function at the centroid wavelength only.
        This fits a polynomial that maps that value back to the
        temperature of the blackbody that produced the band radiance.

        T_min, T_max: temperature range (Kelvin) over which to train
        degree: degree of the fitted polynomial
        Returns the np.polyfit coefficients, highest power first.
        """

        # adapted from bndfit.f

        # establish 100 point equal spacing in radiance between T_min
        # and T_max and convert these radiances to brightness temperatures
        B_in = np.linspace(planck_wl(self.centroid, T_min),
                           planck_wl(self.centroid, T_max),
                           100)
        T_in = inv_planck_wl(self.centroid, B_in)

        # compute table of monochromatic radiances for each of the
        # temperatures in T_in on our own wavelength grid; B_mono has
        # dimensions [temperature, wavelength]
        B_mono = planck_wl(self.lambda_, T_in[:, np.newaxis])

        # apply this SRF to the monochromatic radiances for each
        # temperature; this gives us a value for what an instrument
        # with this SRF would measure when looking at a blackbody
        # at each of the temperatures in T_in
        B_obs = self.apply(self.lambda_, B_mono)

        # the "uncorrected" brightness temperatures are what we'd get
        # from applying the monochromatic inverse planck function using
        # our centroid wavelength
        T_uncorrected = inv_planck_wl(self.centroid, B_obs)

        # fit a polynomial of the requested degree giving T_in as a
        # function of T_uncorrected
        return np.polyfit(T_uncorrected, T_in, degree)

def planck_wl(lambda_, T):

    """Evaluate Planck's law at wavelength lambda_ and temperature T.

    lambda_ is in microns and T in Kelvin; both may be arrays that
    broadcast together. The result is spectral radiance in W per m^2
    per sr per micron.
    """

    wl_m = lambda_ * 1e-6
    exponent = c2 / (wl_m * T)
    # radiance per metre of wavelength, then rescaled to per micron
    per_metre = c1 / wl_m**5 / (np.exp(exponent) - 1.0)
    return per_metre * 1e-6

def inv_planck_wl(lambda_, B):

    """Invert Planck's law: brightness temperature (Kelvin) for a radiance.

    lambda_ is the wavelength in microns; B is radiance in W per m^2 per
    sr per micron.
    """

    wl_m = lambda_ * 1e-6
    # per-micron radiance -> per-metre radiance to match SI c1, c2
    B_per_m = B * 1e6
    log_term = np.log(c1 / (wl_m**5 * B_per_m) + 1.0)
    return c2 / wl_m / log_term

def wn2wl(nu):

    """Convert wavenumber(s) to wavelength(s).

    nu is in inverse centimetres; the result is in microns. Since one
    centimetre is 1e4 microns, wavelength = 1e4 / nu.
    """

    microns_per_cm = 1e4
    return microns_per_cm / nu

def read_modis_srfs(filename):

    """Read MODIS emissive band SRFs from an HDF file.

    Returns a dict of Srf objects indexed by band, using the band values
    stored in the file's 'channel_list' dataset as keys. The file handle
    is always released, even if reading a dataset fails.
    """

    sd = SD.SD(filename)
    try:
        # work around a weird issue with modis_terra.srf.nc. indexing the
        # SDS with [:] returns [20 0 0 ...], but iterating over it works
        # fine. no idea why
        bands = list(sd.select('channel_list'))
        nu_begins = sd.select('begin_frequency')[:]
        nu_ends = sd.select('end_frequency')[:]

        srfs = {}
        for band, nu_begin, nu_end in zip(bands, nu_begins, nu_ends):
            # the response is tabulated on an evenly spaced wavenumber
            # grid (inverse cm, per wn2wl); presumably nu_begin < nu_end,
            # so wavelength decreases along it and both arrays are
            # reversed to give the increasing grid Srf needs for np.interp
            F = sd.select('channel_%s_response' % band)[:][::-1]
            lambda_ = wn2wl(np.linspace(nu_begin, nu_end, len(F)))[::-1]
            srfs[band] = Srf(lambda_, F)
    finally:
        sd.end()

    return srfs

# SRFs for both MODIS instruments, read once at import time; keyed by
# platform name, then by band number (see modis_bright)
all_srfs = {
    'aqua': read_modis_srfs(AQUA_SRF),
    'terra': read_modis_srfs(TERRA_SRF),
}