def aggregate_output_files(nc_files, output_file, track_dim_name='nj', xtrack_dim_name='ni'):
    """Stack several netCDF files along the track dimension into one file.

    The output file gets a ``time`` dimension of length 1, a track dimension
    whose length is the sum of the inputs' track lengths, and a cross-track
    dimension sized from the first input that defines it.  ``time`` is taken
    from the first input file; ``lon`` and ``lat`` are written as 2-D
    ``(track, xtrack)`` variables; every other variable found in the first
    input is written as ``(time, track, xtrack)`` and filled from index 0 of
    the leading axis of each input's variable (``var[0, :, :]``).

    Args:
        nc_files: Iterable of input netCDF file paths, in stacking order.
        output_file: Path of the netCDF4 file to create (opened with 'w').
        track_dim_name: Name of the dimension the inputs are stacked along.
        xtrack_dim_name: Name of the second (cross-track) dimension.  If no
            input defines it, the dimension is created with size ``None``.

    Raises:
        ValueError: If ``nc_files`` is empty.
        KeyError: If an input lacks the track dimension or an expected
            variable (``time``/``lon``/``lat`` or a variable of the first file).
    """
    nc_files = list(nc_files)
    if not nc_files:
        # Fail before creating an empty output file on disk.
        raise ValueError("nc_files must contain at least one input file")

    coord_names = ('time', 'lon', 'lat')
    grid_dims = (track_dim_name, xtrack_dim_name)

    input_datasets = []
    output_rootgrp = None
    try:
        for nc_file in nc_files:
            input_datasets.append(nc.Dataset(nc_file, 'r'))
        output_rootgrp = nc.Dataset(output_file, 'w', format='NETCDF4')

        # calculate new track_dim length
        track_dim_len = 0
        xtrack_dim_len = None
        for ds in input_datasets:
            for name, dim in ds.dimensions.items():
                if name == track_dim_name:
                    track_dim_len += len(dim)
                elif name == xtrack_dim_name and xtrack_dim_len is None:
                    xtrack_dim_len = len(dim)

        # create the dimensions for the new aggregation target file
        output_rootgrp.createDimension('time', 1)
        output_rootgrp.createDimension(track_dim_name, track_dim_len)
        output_rootgrp.createDimension(xtrack_dim_name, xtrack_dim_len)

        template = input_datasets[0].variables

        # use the time value from the first file for aggregated time
        time_var = _create_variable_like(
            output_rootgrp, 'time', template['time'], ('time',), raw=False)
        time_var[:] = template['time'][:]

        # lon/lat use the caller-supplied dimension names so that non-default
        # names work (they were previously hard-coded to 'nj'/'ni').
        lon_var = _create_variable_like(output_rootgrp, 'lon', template['lon'], grid_dims)
        lat_var = _create_variable_like(output_rootgrp, 'lat', template['lat'], grid_dims)

        # create the other variables
        for name, var in template.items():
            if name not in coord_names:
                _create_variable_like(output_rootgrp, name, var, ('time',) + grid_dims)

        # copy from input files to the single output file
        start_idx = 0
        for ds in input_datasets:
            track_len = len(ds.dimensions[track_dim_name])
            rows = slice(start_idx, start_idx + track_len)
            for name, var in ds.variables.items():
                if name == 'time':
                    continue
                # The output variables store raw values (auto mask/scale is
                # off), so read the inputs raw as well; otherwise packed
                # variables would be written unpacked while still carrying
                # their scale_factor/add_offset attributes.
                var.set_auto_maskandscale(False)
                if name == 'lon':
                    lon_var[rows, :] = var[:, :]
                elif name == 'lat':
                    lat_var[rows, :] = var[:, :]
                else:
                    output_rootgrp.variables[name][0, rows, :] = var[0, :, :]
            start_idx += track_len
    finally:
        # Release every file handle even when aggregation fails part-way.
        if output_rootgrp is not None:
            output_rootgrp.close()
        for ds in input_datasets:
            ds.close()


def _create_variable_like(dataset, name, template, dimensions, raw=True):
    """Create ``name`` in ``dataset`` with ``template``'s datatype and attributes.

    When ``raw`` is true, automatic masking/scaling is switched off on the new
    variable before the attributes are copied.
    """
    new_var = dataset.createVariable(name, template.datatype, dimensions)
    if raw:
        new_var.set_auto_maskandscale(False)
    new_var.setncatts({key: template.getncattr(key) for key in template.ncattrs()})
    return new_var