Commit a47a2a73 authored by BorjaEst's avatar BorjaEst
Browse files

Reimplementation of skim with new modules

parent d6654942
......@@ -43,7 +43,7 @@ class Source:
logging.info("Load source '%s'", self.name)
for name, specifications in collections.items():
logging.info("Load model '%s'", name)
model = __load_model(**specifications)
model = _load_model(**specifications)
if model:
self._models[name] = model
......@@ -52,7 +52,7 @@ class Source:
@property
def models(self):
return self._models.keys()
return list(self._models.keys())
def skim(self, groupby=None):
"""Request to skim all source data into the current folder
......@@ -65,15 +65,15 @@ class Source:
os.makedirs(dirname, exist_ok=True)
logger.info("Skim data from '%s'", dirname)
with utils.cd(dirname):
Skimmed_ds = self[model].model.skim()
Skimmed_ds.to_netcdf(groupby)
_skim(self[model], delta=groupby)
@utils.return_on_failure("Error when loading model", default=None)
def __load_model(tco3_zm=None, vmro3_zm=None):
"""Loads and standarises a dataset using the specs."""
def _load_model(tco3_zm=None, vmro3_zm=None):
"""Loads a model merging standardized data from specified datasets."""
dataset = xr.Dataset()
if tco3_zm:
logger.debug("Loading tco3_zm into model")
with xr.open_mfdataset(tco3_zm['paths']) as load:
standardized = standardization.standardize_tco3(
dataset=load,
......@@ -81,6 +81,7 @@ def __load_model(tco3_zm=None, vmro3_zm=None):
coordinates=tco3_zm['coordinates'])
dataset = dataset.merge(standardized)
if vmro3_zm:
logger.debug("Loading vmro3_zm into model")
with xr.open_mfdataset(vmro3_zm['paths']) as load:
standardized = standardization.standardize_vmro3(
dataset=load,
......@@ -90,6 +91,37 @@ def __load_model(tco3_zm=None, vmro3_zm=None):
return dataset
def _skim(model, delta=None):
"""Skims model producing reduced dataset files"""
logger.debug("Skimming model with delta {}".format(delta))
skimmed = model.model.skim()
if delta == 'year':
def tco3_path(y): return "tco3_zm_{}-{}.nc".format(y, y + 1)
def vmro3_path(y): return "vmro3_zm_{}-{}.nc".format(y, y + 1)
groups = skimmed.model.groupby_year()
elif delta == 'decade':
def tco3_path(y): return "tco3_zm_{}-{}.nc".format(y, y + 10)
def vmro3_path(y): return "vmro3_zm_{}-{}.nc".format(y, y + 10)
groups = skimmed.model.groupby_year()
else:
def tco3_path(_): return "tco3_zm.nc"
def vmro3_path(_): return "vmro3_zm.nc"
groups = [(None, skimmed), ]
years, datasets = zip(*groups)
if skimmed.model.tco3:
logger.debug("Saving skimed tco3 into files")
xr.save_mfdataset(
datasets=[ds.model.tco3 for ds in datasets],
paths=[tco3_path(year) for year in years]
)
if skimmed.model.vmro3:
logger.debug("Saving skimed vmro3 into files")
xr.save_mfdataset(
datasets=[ds.model.vmro3 for ds in datasets],
paths=[vmro3_path(year) for year in years]
)
class TestsSource(unittest.TestCase):
name = "SourceTest"
......@@ -104,7 +136,7 @@ class TestsSource(unittest.TestCase):
self.assertEqual(expected, result)
def test_property_models(self):
expected = TestsSource.collections.keys()
expected = list(TestsSource.collections.keys())
result = self.source.models
self.assertEqual(expected, result)
......
......@@ -12,6 +12,7 @@ import numpy as np
from o3skim import standardization
logger = logging.getLogger('extended_xr')
mean_coord = 'lon'
@xr.register_dataset_accessor("model")
......@@ -22,28 +23,37 @@ class ModelAccessor:
@property
def tco3(self):
"""Return the total ozone column of this dataset."""
return self._model["tco3_zm"]
if "tco3_zm" in list(self._model.var()):
return self._model["tco3_zm"].to_dataset()
else:
return None
@property
def vmro3(self):
"""Return the ozone volume mixing ratio of this dataset."""
return self._model["vmro3_zm"]
if "vmro3_zm" in list(self._model.var()):
return self._model["vmro3_zm"].to_dataset()
else:
return None
def groupby_year(self):
"""Returns a grouped dataset by year"""
logger.debug("Performing group by year on model")
def delta_map(x): return x.year
years = self._model.indexes['time'].map(delta_map)
return self._model.groupby(xr.DataArray(years))
def groupby_decade(self):
"""Returns a grouped dataset by decade"""
logger.debug("Performing group by decade on model")
def delta_map(x): return x.year // 10 * 10
years = self._model.indexes['time'].map(delta_map)
return self._model.groupby(xr.DataArray(years))
def to_netcdf(self, delta=None):
""" """
pass #TODO
def skim(self):
"""Skims model producing reduced dataset"""
logger.debug("Skimming model")
return self._model.mean(mean_coord)
class Tests(unittest.TestCase):
......@@ -116,3 +126,21 @@ class Tests(unittest.TestCase):
for decade, dataset in groups:
self.assertIsInstance(decade, np.int64)
self.assertIsInstance(dataset, xr.Dataset)
def test_skimming(self):
result = self.ds.model.skim()
# Test general coordinates
self.assertIn('time', result.coords)
self.assertIn('lat', result.coords)
self.assertIn('plev', result.coords)
self.assertNotIn('lon', result.coords)
# Test tco3 coordinates
self.assertIn('time', result.model.tco3.coords)
self.assertIn('lat', result.model.tco3.coords)
self.assertNotIn('plev', result.model.tco3.coords)
self.assertNotIn('lon', result.model.tco3.coords)
# Test vmro3 coordinates
self.assertIn('time', result.model.vmro3.coords)
self.assertIn('lat', result.model.vmro3.coords)
self.assertIn('plev', result.model.vmro3.coords)
self.assertNotIn('lon', result.model.vmro3.coords)
......@@ -32,50 +32,48 @@ vmro3_standard_coordinates = [
default=xr.Dataset())
def standardize_tco3(dataset, variable, coordinates):
"""Standardizes a tco3 dataset"""
dataset = squeeze(dataset)
dataset = rename_tco3(dataset, variable, coordinates)
dataset = sort(dataset)
return dataset
array = dataset[variable]
array = squeeze(array)
array = rename_tco3(array, coordinates)
array = sort(array)
return array.to_dataset()
@utils.return_on_failure("Error when loading '{0}'".format(vmro3_standard_name),
default=xr.Dataset())
def standardize_vmro3(dataset, variable, coordinates):
"""Standardizes a vmro3 dataset"""
dataset = squeeze(dataset)
dataset = rename_vmro3(dataset, variable, coordinates)
dataset = sort(dataset)
return dataset
array = dataset[variable]
array = squeeze(array)
array = rename_vmro3(array, coordinates)
array = sort(array)
return array.to_dataset()
def rename_tco3(dataset, variable, coordinates):
"""Renames a tco3 dataset variable and coordinates"""
def rename_tco3(array, coordinates):
"""Renames a tco3 array variable and coordinates"""
logger.debug("Rename of '{0}' var and coords".format(tco3_standard_name))
return dataset.rename({
**{variable: tco3_standard_name},
**{coordinates[x]: x for x in tco3_standard_coordinates}
})
array.name = tco3_standard_name
return array.rename({coordinates[x]: x for x in tco3_standard_coordinates})
def rename_vmro3(dataset, variable, coordinates):
"""Renames a vmro3 dataset variable and coordinates"""
def rename_vmro3(array, coordinates):
"""Renames a vmro3 array variable and coordinates"""
logger.debug("Rename of '{0}' var and coords".format(vmro3_standard_name))
return dataset.rename({
**{variable: vmro3_standard_name},
**{coordinates[x]: x for x in vmro3_standard_coordinates}
})
array.name = vmro3_standard_name
return array.rename({coordinates[x]: x for x in vmro3_standard_coordinates})
def squeeze(dataset):
"""Squeezes the 1size dimensions on a dataset"""
def squeeze(array):
"""Squeezes the 1-size dimensions on an array"""
logger.debug("Squeezing coordinates in dataset")
return dataset.squeeze(drop=True)
return array.squeeze(drop=True)
def sort(dataset):
"""Sorts a dataset by coordinates"""
def sort(array):
"""Sorts an array by coordinates"""
logger.debug("Sorting coordinates in dataset")
return dataset.sortby(list(dataset.coords))
return array.sortby(list(array.coords))
class TestsTCO3(unittest.TestCase):
......
Supports Markdown
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment