from functools import partial

import tensorflow as tf


def augment_image(*, keep_prob=0.2, rotate_prob=0.5, flip_prob=0.5):
    """ Build a random rotate/flip augmentation applied jointly to an image and its label.

      The same geometric transform is applied to ``data`` and ``label``, so
      ``label`` must be spatially aligned with ``data`` (presumably a mask;
      confirm against the dataset). Both go through ``tf.image.rot90`` /
      ``tf.image.flip_left_right``, which expect 3-D or 4-D image tensors.

      Behaviour of the returned function:
        * with probability ``keep_prob`` the pair is returned unchanged;
        * otherwise it is rotated by 1, 2 or 3 quarter turns with probability
          ``rotate_prob``, and then flipped left-right with probability
          ``flip_prob``; the rotate and flip decisions are independent.
      With the defaults roughly 40% of samples come out untouched
      (0.2 + 0.8 * 0.5 * 0.5), not 20%.

      Args:
        keep_prob: probability (Python float) of skipping augmentation.
        rotate_prob: probability of a rotation once augmentation is applied.
        flip_prob: probability of a left-right flip once augmentation is applied.
      Returns:
        tf.data.Dataset mappable function
        ``augment_fn(data, label, *args, **kwargs) -> (data, label)``.
        Extra arguments are accepted but ignored and NOT returned.
      Raises:
        ValueError: if any probability lies outside [0, 1].
    """
    for name, prob in (("keep_prob", keep_prob),
                       ("rotate_prob", rotate_prob),
                       ("flip_prob", flip_prob)):
        if not 0.0 <= prob <= 1.0:
            raise ValueError(f"{name} must be in [0, 1], got {prob!r}")

    def maybe_apply(prob, transform, tensors):
        # Apply ``transform`` to the tuple ``tensors`` with probability ``prob``.
        # ``less`` on U[0, 1) gives exactly ``prob``: 0 -> never, 1 -> always.
        return tf.cond(
            tf.less(tf.random.uniform([]), prob),
            lambda: transform(*tensors),
            lambda: tensors)

    def rotate(data, label):
        # 1..3 quarter turns; "no rotation" is the other branch of maybe_apply.
        times = tf.random.uniform(shape=[], minval=1, maxval=4, dtype=tf.int32)
        return tf.image.rot90(data, times), tf.image.rot90(label, times)

    def flip(data, label):
        return tf.image.flip_left_right(data), tf.image.flip_left_right(label)

    def augment_steps(data, label):
        data, label = maybe_apply(rotate_prob, rotate, (data, label))
        return maybe_apply(flip_prob, flip, (data, label))

    def augment_fn(data, label, *args, **kwargs):
        # Skip augmentation entirely with probability keep_prob.
        return tf.cond(
            tf.less(tf.random.uniform([]), keep_prob),
            lambda: (data, label),
            lambda: augment_steps(data, label))

    return augment_fn


def augment_icing(*, keep_prob=0.2, rotate_prob=0.5, flip_prob=0.5):
    """ Build a random rotate/flip augmentation that transforms only ``data``.

      ``data_b`` and ``label`` are passed through untouched. Only ``data`` goes
      through ``tf.image.rot90`` / ``tf.image.flip_left_right``, which expect
      3-D or 4-D image tensors.

      Behaviour of the returned function (applied to ``data`` only):
        * with probability ``keep_prob`` it is returned unchanged;
        * otherwise it is rotated by 1, 2 or 3 quarter turns with probability
          ``rotate_prob``, and then flipped left-right with probability
          ``flip_prob``; the rotate and flip decisions are independent.
      With the defaults roughly 40% of samples come out untouched
      (0.2 + 0.8 * 0.5 * 0.5), not 20%.

      Args:
        keep_prob: probability (Python float) of skipping augmentation.
        rotate_prob: probability of a rotation once augmentation is applied.
        flip_prob: probability of a left-right flip once augmentation is applied.
      Returns:
        tf.data.Dataset mappable function
        ``augment_fn(data, data_b, label, *args, **kwargs) -> (data, data_b, label)``.
        Extra arguments are accepted but ignored and NOT returned.
      Raises:
        ValueError: if any probability lies outside [0, 1].
    """
    for name, prob in (("keep_prob", keep_prob),
                       ("rotate_prob", rotate_prob),
                       ("flip_prob", flip_prob)):
        if not 0.0 <= prob <= 1.0:
            raise ValueError(f"{name} must be in [0, 1], got {prob!r}")

    def maybe_apply(prob, transform, image):
        # Apply ``transform`` to ``image`` with probability ``prob``.
        # ``less`` on U[0, 1) gives exactly ``prob``: 0 -> never, 1 -> always.
        return tf.cond(
            tf.less(tf.random.uniform([]), prob),
            lambda: transform(image),
            lambda: image)

    def rotate(image):
        # 1..3 quarter turns; "no rotation" is the other branch of maybe_apply.
        times = tf.random.uniform(shape=[], minval=1, maxval=4, dtype=tf.int32)
        return tf.image.rot90(image, times)

    def augment_steps(image):
        image = maybe_apply(rotate_prob, rotate, image)
        return maybe_apply(flip_prob, tf.image.flip_left_right, image)

    def augment_fn(data, data_b, label, *args, **kwargs):
        # Skip augmentation entirely with probability keep_prob. data_b and
        # label are identical in every branch, so they need not enter the cond.
        augmented = tf.cond(
            tf.less(tf.random.uniform([]), keep_prob),
            lambda: data,
            lambda: augment_steps(data))
        return augmented, data_b, label

    return augment_fn


def augment_image_3arg(*, keep_prob=0.2, rotate_prob=0.5, flip_prob=0.5):
    """ Build a random rotate/flip augmentation applied jointly to three tensors.

      ``data``, ``data_b`` and ``label`` all receive the identical geometric
      transform, so they must be spatially aligned (presumably two image inputs
      plus a mask; confirm against the dataset). All go through
      ``tf.image.rot90`` / ``tf.image.flip_left_right``, which expect 3-D or
      4-D image tensors.

      Behaviour of the returned function:
        * with probability ``keep_prob`` the triple is returned unchanged;
        * otherwise it is rotated by 1, 2 or 3 quarter turns with probability
          ``rotate_prob``, and then flipped left-right with probability
          ``flip_prob``; the rotate and flip decisions are independent.
      With the defaults roughly 40% of samples come out untouched
      (0.2 + 0.8 * 0.5 * 0.5), not 20%.

      Args:
        keep_prob: probability (Python float) of skipping augmentation.
        rotate_prob: probability of a rotation once augmentation is applied.
        flip_prob: probability of a left-right flip once augmentation is applied.
      Returns:
        tf.data.Dataset mappable function
        ``augment_fn(data, data_b, label, *args, **kwargs) -> (data, data_b, label)``.
        Extra arguments are accepted but ignored and NOT returned.
      Raises:
        ValueError: if any probability lies outside [0, 1].
    """
    for name, prob in (("keep_prob", keep_prob),
                       ("rotate_prob", rotate_prob),
                       ("flip_prob", flip_prob)):
        if not 0.0 <= prob <= 1.0:
            raise ValueError(f"{name} must be in [0, 1], got {prob!r}")

    def maybe_apply(prob, transform, tensors):
        # Apply ``transform`` to the tuple ``tensors`` with probability ``prob``.
        # ``less`` on U[0, 1) gives exactly ``prob``: 0 -> never, 1 -> always.
        return tf.cond(
            tf.less(tf.random.uniform([]), prob),
            lambda: transform(*tensors),
            lambda: tensors)

    def rotate(data, data_b, label):
        # One shared turn count keeps the three tensors aligned; 1..3 quarter
        # turns because "no rotation" is the other branch of maybe_apply.
        times = tf.random.uniform(shape=[], minval=1, maxval=4, dtype=tf.int32)
        return (tf.image.rot90(data, times),
                tf.image.rot90(data_b, times),
                tf.image.rot90(label, times))

    def flip(data, data_b, label):
        return (tf.image.flip_left_right(data),
                tf.image.flip_left_right(data_b),
                tf.image.flip_left_right(label))

    def augment_steps(data, data_b, label):
        data, data_b, label = maybe_apply(rotate_prob, rotate, (data, data_b, label))
        return maybe_apply(flip_prob, flip, (data, data_b, label))

    def augment_fn(data, data_b, label, *args, **kwargs):
        # Skip augmentation entirely with probability keep_prob.
        return tf.cond(
            tf.less(tf.random.uniform([]), keep_prob),
            lambda: (data, data_b, label),
            lambda: augment_steps(data, data_b, label))

    return augment_fn