diff --git a/modules/util/split_nc4.py b/modules/util/split_nc4.py
index 4fcfaacc610e6090c48229afc568a90b2553704d..9480d0384cd64d336b02ec569400cbdc45fa3f40 100644
--- a/modules/util/split_nc4.py
+++ b/modules/util/split_nc4.py
@@ -4,51 +4,56 @@ import numpy as np
 
 def split_dataset(input_file, output_pattern, dim_name, chunk_size):
     # Load the input dataset
-    with nc.Dataset(input_file, 'r') as ds:
-        dim_size = len(ds.dimensions[dim_name])
-
-        # Calculate the number of chunks
-        num_chunks = int(np.ceil(dim_size / chunk_size))
-
-        # Loop through each chunk
-        for i in range(num_chunks-1):
-            # Determine the start and end indices of this chunk
-            start = i * chunk_size
-            end = min((i + 1) * chunk_size, dim_size)
-
-            # Slicing along our dimension of interest
-            slice_indices = slice(start, end)
-
-            # Create a new output file for this chunk
-            output_file = output_pattern.format(i)
-
-            with nc.Dataset(output_file, 'w') as ds_out:
-                # Copy dimensions
-                for name, dim in ds.dimensions.items():
-                    # Adjust the dimension size for the split dimension
-                    if name == dim_name:
-                        dim_size = len(range(start, end))
-                    else:
-                        dim_size = len(dim) if not dim.isunlimited() else None
-
-                    ds_out.createDimension(name, dim_size)
-
-                # Copy variables
-                for name, var in ds.variables.items():
-                    outVar = ds_out.createVariable(name, var.datatype, var.dimensions)
-
-                    # Copy variable attributes
-                    outVar.setncatts({k: var.getncattr(k) for k in var.ncattrs()})
-
-                    # Divide variable data for the split dimension, keep others as is
-                    if dim_name in var.dimensions:
-                        print(name, outVar.shape, var.shape)
-                        outVar[:,] = var[slice_indices,]
-                    else:
-                        outVar[:,] = var[:,]
-
-                # Copy global attributes
-                ds_out.setncatts({k: ds.getncattr(k) for k in ds.ncattrs()})
+    ds = nc.Dataset(input_file, 'r', format='NETCDF4')
+    dim_size = len(ds.dimensions[dim_name])
+
+    # Calculate the number of chunks
+    num_chunks = int(np.ceil(dim_size / chunk_size))
+
+    # Loop through each chunk
+    for i in range(num_chunks-1):
+        # Determine the start and end indices of this chunk
+        start = i * chunk_size
+        end = min((i + 1) * chunk_size, dim_size)
+
+        # Slicing along our dimension of interest
+        slice_indices = slice(start, end)
+
+        # Create a new output file for this chunk
+        output_file = output_pattern.format(i)
+        rootgrp = nc.Dataset(output_file, 'w', format='NETCDF4')
+
+        # Copy dimensions
+        for name, dim in ds.dimensions.items():
+            # Adjust the dimension size for the split dimension
+            if name == dim_name:
+                dim_size = len(range(start, end))
+            else:
+                dim_size = len(dim) if not dim.isunlimited() else None
+
+            rootgrp.createDimension(name, dim_size)
+
+        # Copy variables
+        for name, var in ds.variables.items():
+            var.set_auto_maskandscale(False)
+            outVar = rootgrp.createVariable(name, var.datatype, var.dimensions)
+            outVar.set_auto_maskandscale(False)
+
+            # Copy variable attributes
+            if name != 'gs_1c_spect': # The original file has bad metadata for this, and possibly other fields
+                outVar.setncatts({k: var.getncattr(k) for k in var.ncattrs()})
+
+            # Divide variable data for the split dimension, keep others as is
+            if dim_name in var.dimensions:
+                outVar[:,] = var[slice_indices,]
+            else:
+                outVar[:,] = var[:,]
+
+        # Copy global attributes
+        rootgrp.setncatts({k: ds.getncattr(k) for k in ds.ncattrs()})
+        rootgrp.close()
+
+    ds.close()
 
 # Call the function
 # split_dataset('input.nc', 'output_{}.nc', 'time', 10)
\ No newline at end of file
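
Usage note (not part of the patch): a minimal sketch of how the rewritten function is expected to be called, mirroring the commented example at the bottom of the module. The file names, dimension name, and chunk size are illustrative placeholders, and the import path assumes the repository root is on sys.path.

    from modules.util.split_nc4 import split_dataset

    # Split 'input.nc' along the 'time' dimension into files of 10 records each,
    # written as output_0.nc, output_1.nc, ... via output_pattern.format(i).
    # Note that the loop runs range(num_chunks-1), so the final chunk along the
    # split dimension is not written out.
    split_dataset('input.nc', 'output_{}.nc', 'time', 10)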