Commit c9c58dfa authored by rink

Merge remote-tracking branch 'origin/master'

parents b793b773 90d4171e
<?xml version="1.0" encoding="UTF-8"?>
<project version="4">
  <component name="ProjectRootManager" version="2" project-jdk-name="Python 3.7 (modules)" project-jdk-type="Python SDK" />
</project>
\ No newline at end of file
<?xml version="1.0" encoding="UTF-8"?>
<module type="PYTHON_MODULE" version="4">
  <component name="NewModuleRootManager">
    <content url="file://$MODULE_DIR$" />
    <orderEntry type="jdk" jdkName="Python 3.7 (modules)" jdkType="Python SDK" />
    <orderEntry type="sourceFolder" forTests="false" />
  </component>
  <component name="TestRunnerService">
    <option name="PROJECT_TEST_RUNNER" value="Unittests" />
  </component>
</module>
\ No newline at end of file
<?xml version="1.0" encoding="UTF-8"?>
<project version="4">
  <component name="ProjectModuleManager">
    <modules>
      <module fileurl="file://$PROJECT_DIR$/.idea/modules.iml" filepath="$PROJECT_DIR$/.idea/modules.iml" />
    </modules>
  </component>
</project>
\ No newline at end of file
<?xml version="1.0" encoding="UTF-8"?>
<project version="4">
  <component name="VcsDirectoryMappings">
    <mapping directory="$PROJECT_DIR$/.." vcs="Git" />
  </component>
</project>
\ No newline at end of file
@@ -25,7 +25,7 @@ gfs_date_format = '%y%m%d'
 h4_to_h5_path = home_dir + '/h4toh5convert'
-data_dir = '/data1/rink'
+data_dir = '/home/rink/data'
 converted_file_dir = data_dir + '/gfs_h5'
 CACHE_GFS = True
@@ -549,6 +549,7 @@ def get_bounding_gfs_files(timestamp):
     farr = np.array(filelist)
     farr = farr[sidxs]
     ftimes = tarr[sidxs]
+    idxs = np.arange(len(filelist))
     above = ftimes >= timestamp
     if not above.any():
@@ -559,16 +560,18 @@ def get_bounding_gfs_files(timestamp):
     if not below.any():
         return None, None, None, None
     tL = ftimes[below].max()
-    iL = np.searchsorted(ftimes, tL, 'left')
+    iL = idxs[below].max()
     iR = iL + 1
     fList = farr.tolist()
-    return fList[iL], ftimes[iL], fList[iR], ftimes[iR]
+    if timestamp == ftimes[iL]:
+        return fList[iL], ftimes[iL], None, None
+    else:
+        return fList[iL], ftimes[iL], fList[iR], ftimes[iR]

-def get_profile(xr_dataset, fld_name, lons, lats, lon360=True):
+def get_profile(xr_dataset, fld_name, lons, lats, lon360=True, do_norm=False):
     if lon360:
         lons = np.where(lons < 0, lons + 360, lons)  # convert -180,180 to 0,360
@@ -583,6 +586,9 @@ def get_profile(xr_dataset, fld_name, lons, lats, lon360=True):
     dim1 = xr.DataArray(lats, dims='k')
     intrp_fld = fld.interp(fakeDim1=dim1, fakeDim2=dim2, fakeDim0=plevs, method='linear')
     intrp_fld = intrp_fld.values
+    if do_norm:
+        intrp_fld = normalize(intrp_fld, fld_name)
     return intrp_fld
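The early return above changes the function's contract: an exact time match now comes back as (file, time, None, None) instead of a two-file bracket. A self-contained sketch of that contract, with the lookup distilled into a pure helper (the bracket function and the synthetic hour values are illustrative only, not part of this repo):

import numpy as np

def bracket(ftimes, timestamp):
    # Hypothetical distillation of the fixed lookup above: return the
    # (iL, iR) indices of the bounding file times, with iR = None when
    # the timestamp lands exactly on a file time.
    idxs = np.arange(len(ftimes))
    above = ftimes >= timestamp
    below = ftimes <= timestamp
    if not above.any() or not below.any():
        return None, None
    iL = idxs[below].max()
    if timestamp == ftimes[iL]:
        return iL, None
    return iL, iL + 1

ftimes = np.array([0, 6, 12, 18])              # synthetic forecast hours
assert bracket(ftimes, 6) == (1, None)         # exact hit: no right file needed
assert bracket(ftimes, 7) == (1, 2)            # interpolate between hours 6 and 12
assert bracket(ftimes, 19) == (None, None)     # outside the archive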
@@ -8,7 +8,7 @@ import xarray as xr
 import pickle
 from deeplearning.amv_raob import get_bounding_gfs_files, convert_file, get_images, get_interpolated_profile, \
-    split_matchup, shuffle_dict, get_interpolated_scalar, get_num_samples
+    split_matchup, shuffle_dict, get_interpolated_scalar, get_num_samples, get_time_tuple_utc, get_profile
 LOG_DEVICE_PLACEMENT = False
@@ -273,23 +273,31 @@ class CloudHeightNN:
                 label.append(tup[2])
                 sfc.append(tup[3])
                 continue
-            print('not found in cache, processing key: ', key)
             obs = self.matchup_dict.get(key)
             if obs is None:
                 print('no entry for: ', key)
             timestamp = obs[0][0]
+            print('not found in cache, processing key: ', key, get_time_tuple_utc(timestamp)[0])
             gfs_0, time_0, gfs_1, time_1 = get_bounding_gfs_files(timestamp)
-            if (gfs_0 is None) or (gfs_1 is None):
-                print('no GFS for: ', timestamp)
+            if (gfs_0 is None) and (gfs_1 is None):
+                print('no GFS for: ', get_time_tuple_utc(timestamp)[0])
                 continue
-            gfs_0 = convert_file(gfs_0)
-            gfs_1 = convert_file(gfs_1)
+            try:
+                gfs_0 = convert_file(gfs_0)
+                if gfs_1 is not None:
+                    gfs_1 = convert_file(gfs_1)
+            except Exception as exc:
+                print(get_time_tuple_utc(timestamp)[0])
+                print(exc)
+                continue
+            ds_1 = None
             try:
                 ds_0 = xr.open_dataset(gfs_0)
-                ds_1 = xr.open_dataset(gfs_1)
+                if gfs_1 is not None:
+                    ds_1 = xr.open_dataset(gfs_1)
             except Exception as exc:
                 print(exc)
                 continue
@@ -334,20 +342,26 @@ class CloudHeightNN:
             lons = lons[common_idxs]
             lats = lats[common_idxs]
-            ndb = get_interpolated_profile(ds_0, ds_1, time_0, time_1, 'temperature', timestamp, lons, lats, do_norm=True)
+            if ds_1 is not None:
+                ndb = get_interpolated_profile(ds_0, ds_1, time_0, time_1, 'temperature', timestamp, lons, lats, do_norm=True)
+            else:
+                ndb = get_profile(ds_0, 'temperature', lons, lats, do_norm=True)
             if ndb is None:
                 continue
-            ndf = get_interpolated_profile(ds_0, ds_1, time_0, time_1, 'rh', timestamp, lons, lats, do_norm=False)
+            if ds_1 is not None:
+                ndf = get_interpolated_profile(ds_0, ds_1, time_0, time_1, 'rh', timestamp, lons, lats, do_norm=False)
+            else:
+                ndf = get_profile(ds_0, 'rh', lons, lats, do_norm=False)
             if ndf is None:
                 continue
             ndf /= 100.0
             ndb = np.stack((ndb, ndf), axis=2)
-            ndd = get_interpolated_scalar(ds_0, ds_1, time_0, time_1, 'MSL pressure', timestamp, lons, lats, do_norm=False)
-            ndd /= 1000.0
+            #ndd = get_interpolated_scalar(ds_0, ds_1, time_0, time_1, 'MSL pressure', timestamp, lons, lats, do_norm=False)
+            #ndd /= 1000.0
-            nde = get_interpolated_scalar(ds_0, ds_1, time_0, time_1, 'surface temperature', timestamp, lons, lats, do_norm=True)
+            #nde = get_interpolated_scalar(ds_0, ds_1, time_0, time_1, 'surface temperature', timestamp, lons, lats, do_norm=True)
             # label/truth
             # Level of best fit (LBF)
@@ -369,7 +383,8 @@ class CloudHeightNN:
             images.append(nda)
             vprof.append(ndb)
             label.append(ndc)
-            nds = np.stack([ndd, nde], axis=1)
+            # nds = np.stack([ndd, nde], axis=1)
+            nds = np.zeros((len(lons), 2))
             sfc.append(nds)
             if not CACHE_GFS:
@@ -379,7 +394,8 @@ class CloudHeightNN:
                 self.in_mem_data_cache[key] = (nda, ndb, ndc, nds)
             ds_0.close()
-            ds_1.close()
+            if ds_1 is not None:
+                ds_1.close()
         images = np.concatenate(images)
@@ -781,15 +797,15 @@ class CloudHeightNN:
             print('loss, acc: ', self.test_loss.result(), self.test_accuracy.result())
             ckpt_manager.save()
-            if DISK_CACHE and epoch == 0:
-                f = open(cachepath, 'wb')
-                pickle.dump(self.in_mem_data_cache, f)
-                f.close()
         print('total time: ', total_time)
         self.writer_train.close()
         self.writer_valid.close()
+        if DISK_CACHE:
+            f = open(cachepath, 'wb')
+            pickle.dump(self.in_mem_data_cache, f)
+            f.close()

     def build_model(self):
         flat = self.build_cnn()
         flat_1d = self.build_1d_cnn()
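Net effect of the loader changes: a sample is no longer dropped when only the left GFS file exists, and the epoch-0 pickle becomes a single dump after training finishes. A condensed sketch of the new per-sample flow, using the functions imported above (load_profile_sample is a hypothetical wrapper, not a function in this repo):

import xarray as xr

def load_profile_sample(timestamp, lons, lats):
    # Hypothetical condensation of the loop body above: interpolate in
    # time when two bracketing GFS files exist, fall back to the single
    # left file when the timestamp hits a GFS time exactly.
    gfs_0, time_0, gfs_1, time_1 = get_bounding_gfs_files(timestamp)
    if (gfs_0 is None) and (gfs_1 is None):
        return None                          # no usable GFS data at all
    ds_0 = xr.open_dataset(convert_file(gfs_0))
    ds_1 = xr.open_dataset(convert_file(gfs_1)) if gfs_1 is not None else None
    try:
        if ds_1 is not None:
            return get_interpolated_profile(ds_0, ds_1, time_0, time_1,
                                            'temperature', timestamp, lons, lats, do_norm=True)
        return get_profile(ds_0, 'temperature', lons, lats, do_norm=True)
    finally:
        ds_0.close()
        if ds_1 is not None:
            ds_1.close()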
@@ -71,6 +71,7 @@ def get_bounding_gfs_files(timestamp):
     farr = np.array(filelist)
     farr = farr[sidxs]
     ftimes = tarr[sidxs]
+    idxs = np.arange(len(filelist))
     above = ftimes >= timestamp
     if not above.any():
@@ -82,7 +83,7 @@ def get_bounding_gfs_files(timestamp):
         return None, None, None, None
     tL = ftimes[below].max()
-    iL = np.searchsorted(ftimes, tL, 'left')
+    iL = idxs[below].max()
     iR = iL + 1
     fList = farr.tolist()
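Why replace searchsorted in both modules: with 'left', np.searchsorted returns the first position of tL, so duplicate file times would leave iL at the first duplicate and iR pointing at another file with the same time rather than a later one. idxs[below].max() always yields the last index at or below the target. A minimal repro, assuming below is the mask ftimes <= timestamp and using made-up times:

import numpy as np

ftimes = np.array([100, 200, 200, 300])   # sorted file times with a duplicate
idxs = np.arange(len(ftimes))
timestamp = 250

below = ftimes <= timestamp               # assumed definition of the mask
tL = ftimes[below].max()                  # 200

iL_old = np.searchsorted(ftimes, tL, 'left')   # 1 -> iR = 2 still points at t=200
iL_new = idxs[below].max()                     # 2 -> iR = 3 points past the duplicates
print(iL_old, iL_new)                          # 1 2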