Skip to content
Snippets Groups Projects
split_nc4.py 2.35 KiB
Newer Older
tomrink's avatar
tomrink committed
import netCDF4 as nc
import numpy as np
tomrink's avatar
tomrink committed
import xarray as xr
tomrink's avatar
tomrink committed


def _write_chunk(ds, output_file, dim_name, chunk_slice):
    """Write every variable of ``ds`` to ``output_file``, slicing ``dim_name`` by ``chunk_slice``.

    Args:
        ds: Open source ``netCDF4.Dataset``. Its variables are expected to
            already have auto mask/scale disabled, so raw values are copied.
        output_file: Path of the NETCDF4 file to create (overwritten if present).
        dim_name: Name of the dimension being split.
        chunk_slice: ``slice(start, stop)`` with explicit integer bounds
            selecting this chunk's portion of ``dim_name``.
    """
    with nc.Dataset(output_file, 'w', format='NETCDF4') as rootgrp:
        # Copy dimensions; the split dimension is shrunk to this chunk's length.
        for name, dim in ds.dimensions.items():
            if name == dim_name:
                dim_size = chunk_slice.stop - chunk_slice.start
            elif dim.isunlimited():
                dim_size = None  # keep other unlimited dimensions unlimited
            else:
                dim_size = len(dim)
            rootgrp.createDimension(name, dim_size)

        # Copy variables: metadata first, then (possibly sliced) data.
        for name, var in ds.variables.items():
            if name == 'gs_1c_spect':
                # The original file has bad metadata for this, and possibly other fields
                attrs = {}
            else:
                attrs = {k: var.getncattr(k) for k in var.ncattrs()}
            # netCDF4 requires _FillValue to be supplied when the variable is
            # created rather than set afterwards as an ordinary attribute.
            fill_value = attrs.pop('_FillValue', None)

            out_var = rootgrp.createVariable(name, var.datatype, var.dimensions,
                                             fill_value=fill_value)
            out_var.set_auto_maskandscale(False)
            out_var.setncatts(attrs)

            if not var.dimensions:
                # Zero-dimensional (scalar) variable.
                out_var.assignValue(var.getValue())
            elif dim_name in var.dimensions:
                # Put the chunk slice on whichever axis carries dim_name,
                # not just the first one.
                index = tuple(chunk_slice if d == dim_name else slice(None)
                              for d in var.dimensions)
                out_var[:] = var[index]
            else:
                out_var[:] = var[:]

        # Copy global attributes
        rootgrp.setncatts({k: ds.getncattr(k) for k in ds.ncattrs()})


def split_dataset(input_file, output_pattern, dim_name, chunk_size):
    """Split a netCDF file into several NETCDF4 files along one dimension.

    Every variable and every global attribute of the root group of
    ``input_file`` is copied into each output file. Variables that use
    ``dim_name`` receive only the current chunk's slice of that dimension.
    All other variables are copied in full. Sub-groups are not copied.

    Args:
        input_file: Path of the netCDF file to read.
        output_pattern: ``str.format`` template for output paths. It is
            formatted with the zero-based chunk index, e.g. ``'out_{}.nc'``.
        dim_name: Name of the dimension to split along.
        chunk_size: Maximum number of ``dim_name`` entries per output file.
            The last file holds the remainder and may be shorter.

    Raises:
        ValueError: If ``chunk_size`` is not positive.
        KeyError: If ``dim_name`` is not a dimension of ``input_file``.
    """
    # Validate before opening anything.
    if chunk_size <= 0:
        raise ValueError(f"chunk_size must be positive, got {chunk_size!r}")

    with nc.Dataset(input_file, 'r', format='NETCDF4') as ds:
        outer_dim_size = len(ds.dimensions[dim_name])

        # Ceiling division: a partial trailing chunk still gets its own file.
        num_chunks = -(-outer_dim_size // chunk_size)

        # Read raw stored values (no masking/scaling) so the copy is
        # byte-faithful. Done once here rather than on every chunk.
        for var in ds.variables.values():
            var.set_auto_maskandscale(False)

        # Iterate over *all* chunks, including the final (possibly short) one.
        for i in range(num_chunks):
            start = i * chunk_size
            end = min(start + chunk_size, outer_dim_size)
            _write_chunk(ds, output_pattern.format(i), dim_name, slice(start, end))
tomrink's avatar
tomrink committed

tomrink's avatar
tomrink committed

def concatenate_nc4_files(nc_files, output_file, concat_dim_name='time'):
    """Concatenate netCDF files along one dimension and write the result.

    Args:
        nc_files: Iterable of input file paths, concatenated in the given order.
        output_file: Path of the netCDF file to write.
        concat_dim_name: Dimension to concatenate along (default ``'time'``).

    Raises:
        ValueError: If ``nc_files`` is empty.
    """
    paths = list(nc_files)
    if not paths:
        raise ValueError("nc_files must contain at least one file")

    datasets = []
    try:
        for path in paths:
            datasets.append(xr.open_dataset(path))
        combined = xr.concat(datasets, dim=concat_dim_name)
        combined.to_netcdf(output_file)
    finally:
        # open_dataset keeps file handles open for lazy loading; release them
        # even if a later open, the concat, or the write fails.
        for ds in datasets:
            ds.close()
    print(f"All files combined and saved to {output_file}")

tomrink's avatar
tomrink committed
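# Example usage: split input.nc into output_0.nc, output_1.nc, ...,
# each holding at most 10 steps of the 'time' dimension.
# split_dataset('input.nc', 'output_{}.nc', 'time', 10)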