import os
import numpy as np
from absl import logging
from functools import partial
import tensorflow as tf
# import tensorflow_datasets as tfds
from util.util import scale, make_tf_callable_generator
from util.setup import ancillary_path
import pickle

""" Dataset Handlers for ESRGAN """


def scale_down(method="bicubic", dimension=512, size=None, factor=4):
  """ Builds a tf.data.Dataset-mappable function that pairs an image with a
      downscaled copy of itself.

      Args:
        method (default: bicubic): Interpolation method passed to
          `tf.image.resize` when scaling the image down.
        dimension (default: 512): Side length of a square high resolution
          size, used only when `size` is not given (falsy).
        size (default: None): [height, width] of the high resolution image.
        factor (default: 4): Factor by which the model enhances the low
          resolution image; the low resolution output is resized to
          [size[0] // factor, size[1] // factor].
      Returns:
        A function `scale_fn(image, *args, **kwargs)` returning
        `(low_resolution, high_resolution)`, both clipped to [0, 255].
        Extra positional / keyword arguments are accepted and ignored.
        The configured size is exposed as `scale_fn.size`.
  """
  if not size:
    size = (dimension, dimension)

  def scale_fn(image, *args, **kwargs):
    # NOTE: the input is used as-is as the high resolution target; it is not
    # cropped to `size` (random cropping is disabled), so only the low
    # resolution output is resized to a fixed shape.
    high_resolution = image
    # Computed per call from `size` so that in-place changes to a list
    # exposed via `scale_fn.size` are still honoured.
    low_resolution = tf.image.resize(
        high_resolution,
        [size[0] // factor, size[1] // factor],
        method=method)
    low_resolution = tf.clip_by_value(low_resolution, 0, 255)
    high_resolution = tf.clip_by_value(high_resolution, 0, 255)
    return low_resolution, high_resolution

  scale_fn.size = size
  return scale_fn


def augment_image(
        brightness_delta=0.05,
        contrast_factor=(0.7, 1.3),
        saturation=(0.6, 1.6)):
    """ Helper function used for augmentation of images in the dataset.

      Builds a function that applies the SAME random geometric transform to a
      (data, label) pair, so inputs and targets stay spatially aligned:
        - with probability ~0.2 the pair is returned unchanged;
        - otherwise, each of the following is applied independently with
          probability ~0.5:
            * rotation by 1, 2 or 3 quarter turns (`tf.image.rot90`),
            * horizontal flip (`tf.image.flip_left_right`).
      Each decision is a single scalar random draw per call, so when `data`
      and `label` are batched, the whole batch receives the same transform.

      Args:
        brightness_delta: maximum value for randomly assigning brightness of the image.
        contrast_factor: tuple / list of minimum and maximum value of factor to set
                         random contrast. None, if not to be used.
        saturation: tuple / list of minimum and maximum value of factor to set
                    random saturation. None, if not to be used.
        NOTE: these three photometric parameters are currently NOT applied by
        the returned function (only geometric augmentation is performed); they
        are kept so existing callers keep working.
      Returns:
        tf.data.Dataset mappable function `augment_fn(data, label, *args, **kwargs)`
        returning the (possibly) transformed `(data, label)`. Extra arguments
        are accepted and ignored.
    """

    def augment_fn(data, label, *args, **kwargs):
        # Augmenting data (~ 80%)

        def augment_steps_fn(data, label):
            # Randomly rotating image (~50%)
            def rotate_fn(data, label):
                # maxval is exclusive: 1..3 quarter turns, never the identity.
                times = tf.random.uniform(minval=1, maxval=4, dtype=tf.int32, shape=[])
                return (tf.image.rot90(data, times),
                        tf.image.rot90(label, times))

            # tf.cond keeps this usable both eagerly and inside a traced graph.
            data, label = tf.cond(
                tf.less_equal(tf.random.uniform([]), 0.5),
                lambda: rotate_fn(data, label),
                lambda: (data, label))

            # Randomly flipping image (~50%)
            def flip_fn(data, label):
                return (tf.image.flip_left_right(data),
                        tf.image.flip_left_right(label))

            data, label = tf.cond(
                tf.less_equal(tf.random.uniform([]), 0.5),
                lambda: flip_fn(data, label),
                lambda: (data, label))

            return data, label

        # Randomly returning unchanged data (~20%)
        return tf.cond(
            tf.less_equal(tf.random.uniform([]), 0.2),
            lambda: (data, label),
            partial(augment_steps_fn, data, label))

    return augment_fn


# def reform_dataset(dataset, types, size, num_elems=None):
#   """ Helper function to convert the output_dtype of the dataset
#       from (tf.float32, tf.uint8) to desired dtype
#       Args:
#         dataset: Source dataset(image-label dataset) to convert.
#         types: tuple / list of target datatype.
#         size: [height, width] threshold of the images.
#         num_elems: Number of Data points to store
#       Returns:
#         tf.data.Dataset with the images of dimension >= Args.size and types = Args.types
#   """
#   _carrier = {"num_elems": num_elems}
#
#   def generator_fn():
#     for idx, data in enumerate(dataset, 1):
#       if _carrier["num_elems"]:
#         if not idx % _carrier["num_elems"]:
#           raise StopIteration
#       if data[0].shape[0] >= size[0] and data[0].shape[1] >= size[1]:
#         yield data[0], data[1]
#       else:
#         continue
#   return tf.data.Dataset.from_generator(
#       generator_fn, types, (tf.TensorShape([None, None, 3]), tf.TensorShape(None)))


# Setup the scaling-parameters dictionary passed to `scale(...)` by the
# dataset readers in this module.
# The paths are built by plain string concatenation, so `ancillary_path` is
# presumably expected to end with a path separator — TODO confirm.
# NOTE(review): pickle.load can execute arbitrary code; these files must only
# ever come from a trusted source.


def _load_mean_std_pickle(path):
    """Load and return one pickled mean/std object from `path`.

    The file handle is closed even if unpickling raises.
    """
    with open(path, 'rb') as fh:
        return pickle.load(fh)


mean_std_dct = {}

mean_std_file = ancillary_path + 'mean_std_lo_hi_l2.pkl'
mean_std_dct_l2 = _load_mean_std_pickle(mean_std_file)

mean_std_file = ancillary_path + 'mean_std_lo_hi_l1b.pkl'
mean_std_dct_l1b = _load_mean_std_pickle(mean_std_file)

# L1b entries are inserted first and L2 entries second, so a key present in
# both files takes its value from the L2 pickle.
mean_std_dct.update(mean_std_dct_l1b)
mean_std_dct.update(mean_std_dct_l2)


class OpdNpyDataset:
    """ tf.data pipeline over .npy files holding reflectance / cloud optical depth.

    Each batch element is `(low_resolution, high_resolution)` built by
    `read_numpy_file_s`:
      - low_resolution: reflectance and optical depth (2 channels), cropped to
        hr_size x hr_size, bicubic-resized to lr_size x lr_size, multiplied by
        255 and clipped to [0, 255];
      - high_resolution: optical-depth channel only at hr_size x hr_size,
        multiplied by 255, clipped to [0, 255], with a trailing channel axis.
    Both are passed through `augment_image()`.

    The finished dataset is exposed as `self.dataset`.
    """

    # Cache location used when no `cache_file` is passed (the historical default).
    DEFAULT_CACHE_FILE = '/ships22/cloud/scratch/Satellite_Output/GOES-16/global/NREL_2023/2023_east_cf/cld_opd_train/cld_opd_train'

    def __init__(self, filenames, hr_size, lr_size, batch_size=128,
                 cache_file=DEFAULT_CACHE_FILE):
        """
        Args:
          filenames: indexable sequence of .npy file paths.
          hr_size: side length (pixels) of the high resolution crop.
          lr_size: side length (pixels) of the low resolution input.
          batch_size: number of FILES per batch; arrays from all files in a
              batch are concatenated along axis 0.
          cache_file: filename prefix for `tf.data.Dataset.cache`. Per the
              tf.data docs a file cache persists across runs, so delete it
              after changing inputs. None caches in memory instead.
        """
        self.filenames = filenames
        self.num_files = len(filenames)
        self.hr_size = hr_size
        self.lr_size = lr_size
        self.batch_size = batch_size
        self.cache_file = cache_file

        # Build the indices as int32 so they match data_function's
        # input_signature; np.arange's default integer dtype is
        # platform-dependent (typically int64).
        file_idxs = np.arange(self.num_files, dtype=np.int32)
        dataset = tf.data.Dataset.from_tensor_slices(file_idxs)
        dataset = dataset.shuffle(2000, reshuffle_each_iteration=True)
        dataset = dataset.batch(batch_size)
        dataset = dataset.map(self.data_function, num_parallel_calls=8)
        # NOTE(review): cache() comes after shuffle() and after the random
        # augmentation done inside data_function. Per the tf.data docs, cache
        # yields exactly the same elements on every iteration, so later epochs
        # replay the first epoch's batch order and augmentations. Move
        # shuffle/augmentation after cache() if per-epoch randomness is wanted.
        if cache_file is None:
            dataset = dataset.cache()
        else:
            dataset = dataset.cache(filename=cache_file)
        dataset = dataset.prefetch(buffer_size=1)

        self.dataset = dataset

    def read_numpy_file_s(self, f_idxs):
        """ Load, scale, resize and augment the files at positions `f_idxs`.

        Called through `tf.numpy_function`, so `f_idxs` arrives as a NumPy
        array of integer indices into `self.filenames`.

        Returns:
          (low_resolution, high_resolution) float32 tensors. Assuming `scale`
          preserves shape, these are (N, lr_size, lr_size, 2) and
          (N, hr_size, hr_size, 1), with N the summed leading-axis length
          of the loaded arrays.
        """
        refl_s = []
        opd_s = []
        for fi in f_idxs:
            data = np.load(self.filenames[fi])
            # Index 0 / 1 along axis 1 select reflectance / optical depth;
            # presumably files are (samples, channels, H, W) — TODO confirm.
            refl = scale(data[:, 0, :, :], 'refl_0_65um_nom', mean_std_dct)
            refl_s.append(refl.astype(np.float32))
            opd = scale(data[:, 1, :, :], 'cld_opd_dcomp', mean_std_dct)
            opd_s.append(opd.astype(np.float32))

        refl = np.concatenate(refl_s)
        opd = np.concatenate(opd_s)

        # Channels-last stack: channel 0 = reflectance, channel 1 = opd.
        hr_image = np.stack([refl, opd], axis=3)
        hr_image = tf.image.crop_to_bounding_box(hr_image, 0, 0, self.hr_size, self.hr_size)

        low_resolution = tf.image.resize(hr_image, [self.lr_size, self.lr_size], method='bicubic')
        # Map to [0, 255] (presumably scale() yields ~[0, 1] — TODO confirm)
        # and clip anything outside that range.
        low_resolution = tf.math.multiply(low_resolution, 255.0)
        low_resolution = tf.clip_by_value(low_resolution, 0, 255.0)

        # The super-resolution target is the optical-depth channel only.
        hr_image = tf.math.multiply(hr_image[:, :, :, 1], 255.0)
        high_resolution = tf.clip_by_value(hr_image, 0, 255.0)
        high_resolution = tf.expand_dims(high_resolution, axis=3)

        low_resolution, high_resolution = augment_image()(low_resolution, high_resolution)

        return low_resolution, high_resolution

    @tf.function(input_signature=[tf.TensorSpec(None, tf.int32)])
    def data_function(self, indexes):
        """ Graph-side wrapper that loads one batch of int32 file indices.

        Delegates to `read_numpy_file_s` via `tf.numpy_function`, whose
        outputs carry no static shape information.
        """
        out = tf.numpy_function(self.read_numpy_file_s, [indexes], [tf.float32, tf.float32])
        return out