import netCDF4 as nc
import numpy as np
import xarray as xr


def split_dataset(input_file, output_pattern, dim_name, chunk_size):
    # Load the input dataset
    ds = nc.Dataset(input_file, 'r', format='NETCDF4')
    outer_dim_size = len(ds.dimensions[dim_name])

    # Calculate the number of chunks (the last chunk may be partial)
    num_chunks = int(np.ceil(outer_dim_size / chunk_size))

    # Loop through each chunk
    for i in range(num_chunks):
        # Determine the start and end indices of this chunk
        start = i * chunk_size
        end = min((i + 1) * chunk_size, outer_dim_size)

        # Slicing along our dimension of interest
        slice_indices = slice(start, end)

        # Create a new output file for this chunk
        output_file = output_pattern.format(i)
        rootgrp = nc.Dataset(output_file, 'w', format='NETCDF4')

        # Copy dimensions, shrinking the split dimension to this chunk's extent
        for name, dim in ds.dimensions.items():
            if name == dim_name:
                dim_size = end - start
            else:
                dim_size = len(dim) if not dim.isunlimited() else None
            rootgrp.createDimension(name, dim_size)

        # Copy variables
        for name, var in ds.variables.items():
            var.set_auto_maskandscale(False)
            outVar = rootgrp.createVariable(name, var.datatype, var.dimensions)
            outVar.set_auto_maskandscale(False)

            # Copy variable attributes; the original file has bad metadata for
            # 'gs_1c_spect', and possibly other fields
            if name != 'gs_1c_spect':
                outVar.setncatts({k: var.getncattr(k) for k in var.ncattrs()})

            # Slice the data along the split dimension (wherever it appears in
            # the variable's dimension list); copy other variables as-is
            if dim_name in var.dimensions:
                idx = tuple(slice_indices if d == dim_name else slice(None)
                            for d in var.dimensions)
                outVar[:] = var[idx]
            else:
                outVar[:] = var[:]

        # Copy global attributes
        rootgrp.setncatts({k: ds.getncattr(k) for k in ds.ncattrs()})
        rootgrp.close()

    ds.close()


def concatenate_nc4_files(nc_files, output_file, concat_dim_name='time'):
    datasets = [xr.open_dataset(nc_file) for nc_file in nc_files]
    combined = xr.concat(datasets, dim=concat_dim_name)
    combined.to_netcdf(output_file)
    print(f"All files combined and saved to {output_file}")


# Call the function
# split_dataset('input.nc', 'output_{}.nc', 'time', 10)
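# A hypothetical follow-up (chunk-file names are placeholders): stitch the
# pieces produced by split_dataset back together along 'time' with xarray.
# concatenate_nc4_files(['output_0.nc', 'output_1.nc'], 'recombined.nc')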


def aggregate_output_files(nc_files, output_file, track_dim_name='nj', xtrack_dim_name='ni'):
    input_datasets = [nc.Dataset(nc_file, 'r', format='NETCDF4') for nc_file in nc_files]
    output_rootgrp = nc.Dataset(output_file, 'w', format='NETCDF4')

    # calculate the new track dim length; take the cross-track length from the
    # first file that defines it
    track_dim_len = 0
    xtrack_dim_len = None
    for ds in input_datasets:
        for name, dim in ds.dimensions.items():
            if name == track_dim_name:
                track_dim_len += len(dim)
            elif name == xtrack_dim_name and xtrack_dim_len is None:
                xtrack_dim_len = len(dim)

    # create the dimensions for the new aggregation target file
    output_rootgrp.createDimension('time', 1)
    output_rootgrp.createDimension(track_dim_name, track_dim_len)
    output_rootgrp.createDimension(xtrack_dim_name, xtrack_dim_len)

    # use the time value from the first file for the aggregated time
    var = input_datasets[0].variables['time']
    time_var = output_rootgrp.createVariable('time', var.datatype, 'time')
    time_var.setncatts({k: var.getncattr(k) for k in var.ncattrs()})
    time_var[:] = var[:]

    # create lon, lat variables on the (track, cross-track) grid
    var = input_datasets[0].variables['lon']
    lon_var = output_rootgrp.createVariable(
        'lon', var.datatype, [track_dim_name, xtrack_dim_name])
    lon_var.set_auto_maskandscale(False)
    lon_var.setncatts({k: var.getncattr(k) for k in var.ncattrs()})

    var = input_datasets[0].variables['lat']
    lat_var = output_rootgrp.createVariable(
        'lat', var.datatype, [track_dim_name, xtrack_dim_name])
    lat_var.set_auto_maskandscale(False)
    lat_var.setncatts({k: var.getncattr(k) for k in var.ncattrs()})

    # create the other variables, assumed to be laid out as
    # (time, track, cross-track)
    for name, var in input_datasets[0].variables.items():
        if name not in ['time', 'lon', 'lat']:
            out_var = output_rootgrp.createVariable(
                name, var.datatype, ['time', track_dim_name, xtrack_dim_name])
            out_var.set_auto_maskandscale(False)
            out_var.setncatts({k: var.getncattr(k) for k in var.ncattrs()})

    # copy lon/lat from the input files into the single output file, stacking
    # along the track dimension
    start_idx = 0
    for ds in input_datasets:
        track_len = len(ds.dimensions[track_dim_name])
        lon_var[start_idx:start_idx + track_len, :] = ds.variables['lon'][:, :]
        lat_var[start_idx:start_idx + track_len, :] = ds.variables['lat'][:, :]
        start_idx += track_len

    # copy the remaining variables the same way
    start_idx = 0
    for ds in input_datasets:
        track_len = len(ds.dimensions[track_dim_name])
        for name, var in ds.variables.items():
            if name not in ['time', 'lon', 'lat']:
                out_var = output_rootgrp.variables[name]
                out_var[0, start_idx:start_idx + track_len, :] = var[0, :, :]
        start_idx += track_len

    output_rootgrp.close()
    for ds in input_datasets:
        ds.close()
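

# Hypothetical usage (placeholder file names, not real data): stack two
# along-track granules into one file with a single 'time' record. This assumes
# every non-coordinate variable in the inputs is laid out as (time, nj, ni),
# matching the copy loop above.
# aggregate_output_files(['granule_0.nc', 'granule_1.nc'], 'aggregated.nc')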