# split_nc4.py
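"""Helpers for splitting a NetCDF4 file into chunks along one dimension,
concatenating the pieces back together with xarray, and aggregating
swath-style granules (time x nj x ni) into a single output file."""
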
import netCDF4 as nc
import numpy as np
import xarray as xr


def split_dataset(input_file, output_pattern, dim_name, chunk_size):
    # Load the input dataset
    ds = nc.Dataset(input_file, 'r', format='NETCDF4')
    outer_dim_size = len(ds.dimensions[dim_name])

    # Calculate the number of chunks
    num_chunks = int(np.ceil(outer_dim_size / chunk_size))

    # Loop through each chunk; iterate over range(num_chunks), not
    # range(num_chunks - 1), so the final (possibly partial) chunk is written
    for i in range(num_chunks):
        # Determine the start and end indices of this chunk
        start = i * chunk_size
        end = min((i + 1) * chunk_size, outer_dim_size)

        # Slicing along our dimension of interest
        slice_indices = slice(start, end)

        # Create a new output file for this chunk
        output_file = output_pattern.format(i)
        rootgrp = nc.Dataset(output_file, 'w', format='NETCDF4')

        # Copy dimensions
        for name, dim in ds.dimensions.items():
            # Adjust the dimension size for the split dimension
            if name == dim_name:
                dim_size = end - start
            else:
                dim_size = len(dim) if not dim.isunlimited() else None

            rootgrp.createDimension(name, dim_size)

        # Copy variables; _FillValue must be supplied at creation time, since
        # netCDF4 refuses to set it on an already-created variable
        for name, var in ds.variables.items():
            var.set_auto_maskandscale(False)
            fill = var.getncattr('_FillValue') if '_FillValue' in var.ncattrs() else None
            outVar = rootgrp.createVariable(name, var.datatype, var.dimensions,
                                            fill_value=fill)
            outVar.set_auto_maskandscale(False)

            # Copy variable attributes (skip _FillValue, handled at creation above)
            if name != 'gs_1c_spect':  # the original file has bad metadata for this field, and possibly others
                outVar.setncatts({k: var.getncattr(k) for k in var.ncattrs()
                                  if k != '_FillValue'})

            # Divide variable data along the split dimension, wherever that
            # dimension sits in the variable's shape; copy others as-is
            if dim_name in var.dimensions:
                idx = tuple(slice_indices if d == dim_name else slice(None)
                            for d in var.dimensions)
                outVar[:] = var[idx]
            else:
                outVar[:] = var[:]

        # Copy global attributes
        rootgrp.setncatts({k: ds.getncattr(k) for k in ds.ncattrs()})
        rootgrp.close()

    ds.close()


def concatenate_nc4_files(nc_files, output_file, concat_dim_name='time'):
    datasets = [xr.open_dataset(nc_file) for nc_file in nc_files]
    combined = xr.concat(datasets, dim=concat_dim_name)
    combined.to_netcdf(output_file)
    # close the source handles once the combined file has been written
    for ds in datasets:
        ds.close()
    print(f"All files combined and saved to {output_file}")

# Example usage:
# split_dataset('input.nc', 'output_{}.nc', 'time', 10)
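# Round-trip sketch (hypothetical filenames): stitch the chunks written by
# split_dataset back together along the same dimension.
# concatenate_nc4_files([f'output_{i}.nc' for i in range(3)], 'recombined.nc', concat_dim_name='time')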


def aggregate_output_files(nc_files, output_file, track_dim_name='nj', xtrack_dim_name='ni'):
    input_datasets = [nc.Dataset(nc_file, 'r', format='NETCDF4') for nc_file in nc_files]
    output_rootgrp = nc.Dataset(output_file, 'w', format='NETCDF4')

    # calculate new track_dim length
    track_dim_len = 0
    xtrack_dim_len = None

    for ds in input_datasets:
        for name, dim in ds.dimensions.items():
            if name == track_dim_name:
                track_dim_len += len(dim)
            elif name == xtrack_dim_name and xtrack_dim_len is None:
                xtrack_dim_len = len(dim)

    # create the dimensions for the aggregation target file; the output keeps
    # a single time step and stacks the granules along the track dimension
    output_rootgrp.createDimension('time', 1)
    output_rootgrp.createDimension(track_dim_name, track_dim_len)
    output_rootgrp.createDimension(xtrack_dim_name, xtrack_dim_len)

    # use the time value from the first file for the aggregated time;
    # _FillValue must be given at creation time, so pull it out of the attrs
    var = input_datasets[0].variables['time']
    fill = var.getncattr('_FillValue') if '_FillValue' in var.ncattrs() else None
    time_var = output_rootgrp.createVariable('time', var.datatype, 'time', fill_value=fill)
    time_var.setncatts({k: var.getncattr(k) for k in var.ncattrs() if k != '_FillValue'})

    # assign the value of var to time_var
    time_var[:] = var[:]

    # create lon, lat variables on the parameterized dimensions
    for coord in ('lon', 'lat'):
        var = input_datasets[0].variables[coord]
        fill = var.getncattr('_FillValue') if '_FillValue' in var.ncattrs() else None
        out = output_rootgrp.createVariable(coord, var.datatype,
                                            [track_dim_name, xtrack_dim_name],
                                            fill_value=fill)
        out.set_auto_maskandscale(False)
        out.setncatts({k: var.getncattr(k) for k in var.ncattrs() if k != '_FillValue'})
    lon_var = output_rootgrp.variables['lon']
    lat_var = output_rootgrp.variables['lat']

    # create the other variables; every data variable is assumed to be
    # 3-D (time x along-track x cross-track)
    for name, var in input_datasets[0].variables.items():
        if name not in ['time', 'lon', 'lat']:
            fill = var.getncattr('_FillValue') if '_FillValue' in var.ncattrs() else None
            out_var = output_rootgrp.createVariable(name, var.datatype,
                                                    ['time', track_dim_name, xtrack_dim_name],
                                                    fill_value=fill)
            out_var.set_auto_maskandscale(False)
            out_var.setncatts({k: var.getncattr(k) for k in var.ncattrs() if k != '_FillValue'})

    # copy from input files to the single output file
    start_idx = 0
    for ds in input_datasets:
        track_len = len(ds.dimensions[track_dim_name])
        lon_var[start_idx:start_idx + track_len, :] = ds.variables['lon'][:, :]
        lat_var[start_idx:start_idx + track_len, :] = ds.variables['lat'][:, :]
        start_idx += track_len

    # copy data variables; each input granule is assumed to hold a single time slice
    start_idx = 0
    for ds in input_datasets:
        track_len = len(ds.dimensions[track_dim_name])
        for name, var in ds.variables.items():
            if name not in ['time', 'lon', 'lat']:
                out_var = output_rootgrp.variables[name]
                out_var[0, start_idx:start_idx + track_len, :] = var[0, :, :]
        start_idx += track_len

    # release file handles on both ends
    for ds in input_datasets:
        ds.close()
    output_rootgrp.close()
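

# Minimal driver sketch under assumed inputs: the 'granule_*.nc' pattern and
# the output filename are hypothetical, not part of the original snippet.
if __name__ == '__main__':
    import glob

    granules = sorted(glob.glob('granule_*.nc'))  # hypothetical input pattern
    if granules:
        aggregate_output_files(granules, 'aggregated.nc',
                               track_dim_name='nj', xtrack_dim_name='ni')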