Commit b0b75bdd authored by BorjaEst

Implement forwarding of attrs

parent 112a9839
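
Attributes are now forwarded explicitly: the ModelAccessor keeps its own metadata dictionary (set_metadata/add_metadata), the tco3, vmro3 and skim views copy attrs from the parent dataset, and the standardization functions take and return DataArrays. A minimal sketch of the intended accessor behaviour (assuming that importing o3skim registers the 'model' accessor, which this diff does not show; the dictionary values are made up):

    import xarray as xr
    import o3skim  # assumption: importing the package registers the 'model' accessor

    ds = xr.Dataset(attrs={'description': 'source attrs'})
    ds.model.set_metadata({'tco3_zm': {'description': 'made-up entry'}})
    ds.model.add_metadata({'vmro3_zm': {'description': 'made-up entry'}})
    ds.model.metadata  # both entries, merged recursively via utils.mergedicts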
@@ -9,7 +9,8 @@ import unittest
 import xarray as xr
 import pandas as pd
 import numpy as np
-from o3skim import standardization
+from xarray.core.dataset import Dataset
+from o3skim import standardization, utils
 
 logger = logging.getLogger('extended_xr')
 mean_coord = 'lon'
@@ -19,12 +20,15 @@ mean_coord = 'lon'
 class ModelAccessor:
     def __init__(self, xarray_obj):
         self._model = xarray_obj
+        self._metadata = {}
 
     @property
     def tco3(self):
         """Return the total ozone column of this dataset."""
         if "tco3_zm" in list(self._model.var()):
-            return self._model["tco3_zm"].to_dataset()
+            dataset = self._model["tco3_zm"].to_dataset()
+            dataset.attrs = self._model.attrs
+            return dataset
         else:
             return None
@@ -32,17 +36,24 @@ class ModelAccessor:
     def vmro3(self):
         """Return the ozone volume mixing ratio of this dataset."""
         if "vmro3_zm" in list(self._model.var()):
-            return self._model["vmro3_zm"].to_dataset()
+            dataset = self._model["vmro3_zm"].to_dataset()
+            dataset.attrs = self._model.attrs
+            return dataset
         else:
             return None
 
     @property
     def metadata(self):
-        """Return the ozone volume mixing ratio of this dataset."""
-        result = self._model.attrs
-        for var in self._model.var():
-            result = {**result, var: self._model[var].attrs}
-        return result
+        """Returns the metadata property"""
+        return self._metadata
+
+    def add_metadata(self, metadata):
+        """Merges the input metadata with the model metadata"""
+        utils.mergedicts(self._metadata, metadata)
+
+    def set_metadata(self, metadata):
+        """Sets the metadata to the input variable."""
+        self._metadata = metadata
     def groupby_year(self):
         """Returns a grouped dataset by year"""
@@ -61,5 +72,8 @@ class ModelAccessor:
     def skim(self):
         """Skims model producing reduced dataset"""
         logger.debug("Skimming model")
-        return self._model.mean(mean_coord)
+        skimmed = self._model.mean(mean_coord)
+        skimmed.attrs = self._model.attrs
+        for var in self._model:
+            skimmed[var].attrs = self._model[var].attrs
+        return skimmed
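
The manual copying in skim() is needed because xarray reductions such as mean() drop attrs by default; a minimal standalone illustration:

    import xarray as xr

    da = xr.DataArray([1.0, 2.0], dims='lon', attrs={'units': 'DU'})
    da.mean('lon').attrs                    # {} -- attributes are dropped
    da.mean('lon', keep_attrs=True).attrs   # {'units': 'DU'}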

@@ -115,25 +115,28 @@ def _load_model(tco3_zm=None, vmro3_zm=None, metadata={}):
     :return: Dataset with specified variables.
     :rtype: xarray.Dataset
     """
-    dataset = xr.Dataset(attrs=metadata)
+    dataset = xr.Dataset()
+    dataset.model.set_metadata(metadata)
+
+    def conflict(d1, d2): raise Exception(
+        "Conflict merging {}, {}".format(d1, d2))
+
     if tco3_zm:
         logger.debug("Loading tco3_zm into model")
         with xr.open_mfdataset(tco3_zm['paths']) as load:
-            standardized = standardization.standardize_tco3(
-                dataset=load,
-                variable=tco3_zm['name'],
+            dataset['tco3_zm'] = standardization.standardize_tco3(
+                array=load[tco3_zm['name']],
                 coordinates=tco3_zm['coordinates'])
-            dataset = dataset.merge(standardized)
-            dataset.tco3_zm.attrs = tco3_zm.get('metadata', {})
+            metadata = {'tco3_zm': tco3_zm.get('metadata', {})}
+            dataset.model.add_metadata(metadata)
+            utils.mergedicts(dataset.attrs, load.attrs, if_conflict=conflict)
     if vmro3_zm:
         logger.debug("Loading vmro3_zm into model")
         with xr.open_mfdataset(vmro3_zm['paths']) as load:
-            standardized = standardization.standardize_vmro3(
-                dataset=load,
-                variable=vmro3_zm['name'],
+            dataset['vmro3_zm'] = standardization.standardize_vmro3(
+                array=load[vmro3_zm['name']],
                 coordinates=vmro3_zm['coordinates'])
-            dataset = dataset.merge(standardized)
-            dataset.vmro3_zm.attrs = vmro3_zm.get('metadata', {})
+            metadata = {'vmro3_zm': vmro3_zm.get('metadata', {})}
+            dataset.model.add_metadata(metadata)
+            utils.mergedicts(dataset.attrs, load.attrs, if_conflict=conflict)
     return dataset
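
For context, tco3_zm and vmro3_zm are mappings with 'paths', 'name', 'coordinates' and an optional 'metadata' entry; a hedged sketch of a call (the file glob, variable name and coordinate names below are illustrative only):

    dataset = _load_model(
        tco3_zm={
            'paths': 'path/to/tco3_*.nc',  # hypothetical file glob
            'name': 'tco3',
            'coordinates': {'time': 'time', 'lat': 'latd', 'lon': 'long'},
            'metadata': {'description': 'total column ozone'},
        })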

@@ -13,27 +13,23 @@ logger = logging.getLogger('o3skim.standardization')
 @utils.return_on_failure("Error when loading '{0}'".format('tco3_zm'),
                          default=xr.Dataset())
-def standardize_tco3(dataset, variable, coordinates):
+def standardize_tco3(array, coordinates):
     """Standardizes a tco3 dataset.
 
-    :param dataset: Dataset to standardize.
-    :type dataset: xarray.Dataset
-
-    :param variable: Variable name for the tco3 on the original dataset.
-    :type variable: str
+    :param array: DataArray to standardize.
+    :type array: xarray.DataArray
 
     :param coordinates: Coordinates map for tco3 variable.
     :type coordinates: {'lon':str, 'lat':str, 'time':str}
 
-    :return: Standardized dataset.
-    :rtype: xarray.Dataset
+    :return: Standardized DataArray.
+    :rtype: xarray.DataArray
     """
-    array = dataset[variable]
     array.name = 'tco3_zm'
     array = squeeze(array)
     array = rename_coords_tco3(array, **coordinates)
     array = sort(array)
-    return array.to_dataset()
+    return array
 
 
 def rename_coords_tco3(array, time, lat, lon):
@@ -44,27 +40,23 @@ def rename_coords_tco3(array, time, lat, lon):
 @utils.return_on_failure("Error when loading '{0}'".format('vmro3_zm'),
                          default=xr.Dataset())
-def standardize_vmro3(dataset, variable, coordinates):
+def standardize_vmro3(array, coordinates):
     """Standardizes a vmro3 dataset.
 
-    :param dataset: Dataset to standardize.
-    :type dataset: xarray.Dataset
-
-    :param variable: Variable name for the vmro3 on the original dataset.
-    :type variable: str
+    :param array: DataArray to standardize.
+    :type array: xarray.DataArray
 
    :param coordinates: Coordinates map for vmro3 variable.
     :type coordinates: {'lon':str, 'lat':str, 'plev':str, 'time':str}
 
-    :return: Standardized dataset.
-    :rtype: xarray.Dataset
+    :return: Standardized DataArray.
+    :rtype: xarray.DataArray
     """
-    array = dataset[variable]
     array.name = 'vmro3_zm'
     array = squeeze(array)
     array = rename_coords_vmro3(array, **coordinates)
     array = sort(array)
-    return array.to_dataset()
+    return array
 
 
 def rename_coords_vmro3(array, time, plev, lat, lon):

@@ -56,16 +56,32 @@ class Tests(unittest.TestCase):
     def test_tco3_property(self):
         expected = tco3_datarray.to_dataset(name="tco3_zm")
         xr.testing.assert_equal(dataset.model.tco3, expected)
+        self.assertEqual(dataset.model.tco3.attrs, dataset.attrs)
 
     def test_vmro3_property(self):
         expected = vmro3_datarray.to_dataset(name="vmro3_zm")
         xr.testing.assert_equal(dataset.model.vmro3, expected)
+        self.assertEqual(dataset.model.vmro3.attrs, dataset.attrs)
 
     def test_metadata_property(self):
-        meta = dataset.model.metadata
-        self.assertEqual(meta["description"], "Test dataset")
-        self.assertEqual(meta["tco3_zm"]["description"], "Test tco3 xarray")
-        self.assertEqual(meta["vmro3_zm"]["description"], "Test vmro3 xarray")
+        model = dataset.copy().model
+        self.assertEqual(model.metadata, {})
+
+    def test_add_metadata(self):
+        model = dataset.copy().model
+        self.assertEqual(model.metadata, {})
+        model.add_metadata(dict(d1=1))
+        self.assertEqual(model.metadata, dict(d1=1))
+        model.add_metadata(dict(d2=2))
+        self.assertEqual(model.metadata, dict(d1=1, d2=2))
+
+    def test_set_metadata(self):
+        model = dataset.copy().model
+        self.assertEqual(model.metadata, {})
+        model.set_metadata(dict(d1=1))
+        self.assertEqual(model.metadata, dict(d1=1))
+        model.set_metadata(dict(d2=2))
+        self.assertEqual(model.metadata, dict(d2=2))
 
     def test_groupby_year(self):
         groups = dataset.model.groupby_year()
@@ -101,3 +117,9 @@ class Tests(unittest.TestCase):
         self.assertIn('lat', result.model.vmro3.coords)
         self.assertIn('plev', result.model.vmro3.coords)
         self.assertNotIn('lon', result.model.vmro3.coords)
+
+    def test_skimming_attrs(self):
+        skimmed = dataset.model.skim()
+        self.assertEqual(skimmed.attrs, dataset.attrs)
+        for var in dataset:
+            self.assertEqual(skimmed[var].attrs, dataset[var].attrs)

@@ -9,99 +9,82 @@ standardize_tco3 = standardization.standardize_tco3
 standardize_vmro3 = standardization.standardize_vmro3
 
-tco3_data_s = np.mgrid[1:3:3j, 1:3:3j, 1:3:3j][0]
-tco3_data_r = np.mgrid[3:1:3j, 3:1:3j, 1:1:1j, 3:1:3j][0]
-tco3_varname = "tco3"
-tco3_coords = {'lon': 'long', 'lat': 'latd', 'time': 'time'}
-
-tco3_nonstd = xr.Dataset(
-    data_vars=dict(
-        tco3=(["long", "latd", "high", "time"], tco3_data_r)
-    ),
+tco3_nonstd = xr.DataArray(
+    data=np.mgrid[3:1:3j, 3:1:3j, 1:1:1j, 3:1:3j][0],
+    dims=["long", "latd", "high", "time"],
     coords=dict(
         long=[180, 0, -180],
         latd=[90, 0, -90],
         high=[1],
         time=pd.date_range("2000-01-03", periods=3, freq='-1d')
     ),
-    attrs=dict(description="Non standardized dataset")
+    attrs=dict(description="Non standardized DataArray")
 )
 
-tco3_standard = xr.Dataset(
-    data_vars=dict(
-        tco3_zm=(["lon", "lat", "time"], tco3_data_s)
-    ),
+tco3_standard = xr.DataArray(
+    data=np.mgrid[1:3:3j, 1:3:3j, 1:3:3j][0],
+    dims=["lon", "lat", "time"],
     coords=dict(
         time=pd.date_range("2000-01-01", periods=3, freq='1d'),
         lat=[-90, 0, 90],
         lon=[-180, 0, 180]
     ),
-    attrs=dict(description="Standardized dataset")
+    attrs=dict(description="Standardized DataArray")
 )
 
 
 class TestsTCO3(unittest.TestCase):
 
     def test_standardize(self):
+        coords = {'lon': 'long', 'lat': 'latd', 'time': 'time'}
         standardized_tco3 = standardize_tco3(
-            dataset=tco3_nonstd,
-            variable=tco3_varname,
-            coordinates=tco3_coords)
+            array=tco3_nonstd,
+            coordinates=coords)
         xr.testing.assert_equal(tco3_standard, standardized_tco3)
 
     def test_fail_returns_empty_dataset(self):
         empty_dataset = standardize_tco3(
-            dataset=tco3_nonstd,
-            variable="badVariable",
-            coordinates=tco3_coords)
+            array=tco3_nonstd,
+            coordinates="badCoords")
         xr.testing.assert_equal(xr.Dataset(), empty_dataset)
 
-vmro3_data_s = np.mgrid[1:3:3j, 1:3:3j, 1:4:4j, 1:3:3j][0]
-vmro3_data_r = np.mgrid[3:1:3j, 3:1:3j, 4:1:4j, 3:1:3j][0]
-vmro3_varname = "vmro3"
-vmro3_coords = {'lon': 'longit', 'lat': 'latitu',
-                'plev': 'level', 'time': 't'}
-
-vmro3_nonstd = xr.Dataset(
-    data_vars=dict(
-        vmro3=(["longit", "latitu", "level", "t"], vmro3_data_r)
-    ),
+vmro3_nonstd = xr.DataArray(
+    data=np.mgrid[3:1:3j, 3:1:3j, 4:1:4j, 3:1:3j][0],
+    dims=["lo", "la", "lv", "t"],
     coords=dict(
-        longit=[180, 0, -180],
-        latitu=[90, 0, -90],
-        level=[1000, 100, 10, 1],
+        lo=[180, 0, -180],
+        la=[90, 0, -90],
+        lv=[1000, 100, 10, 1],
         t=pd.date_range("2000-01-03", periods=3, freq='-1d')
     ),
-    attrs=dict(description="Non standardized dataset")
+    attrs=dict(description="Non standardized DataArray")
 )
 
-vmro3_standard = xr.Dataset(
-    data_vars=dict(
-        vmro3_zm=(["lon", "lat", "plev", "time"], vmro3_data_s)
-    ),
+vmro3_standard = xr.DataArray(
+    data=np.mgrid[1:3:3j, 1:3:3j, 1:4:4j, 1:3:3j][0],
+    dims=["lon", "lat", "plev", "time"],
     coords=dict(
         time=pd.date_range("2000-01-01", periods=3, freq='1d'),
         plev=[1, 10, 100, 1000],
         lat=[-90, 0, 90],
         lon=[-180, 0, 180]
    ),
-    attrs=dict(description="Standardized dataset")
+    attrs=dict(description="Standardized DataArray")
 )
 
 
 class TestsVMRO3(unittest.TestCase):
 
     def test_standardize(self):
+        coords = {'lon': 'lo', 'lat': 'la', 'plev': 'lv', 'time': 't'}
         standardized_vmro3 = standardize_vmro3(
-            dataset=vmro3_nonstd,
-            variable=vmro3_varname,
-            coordinates=vmro3_coords)
+            array=vmro3_nonstd,
+            coordinates=coords)
         xr.testing.assert_equal(vmro3_standard, standardized_vmro3)
 
     def test_fail_returns_empty_dataset(self):
         empty_dataset = standardize_vmro3(
-            dataset=vmro3_nonstd,
-            variable="badVariable",
-            coordinates=vmro3_coords)
+            array=vmro3_nonstd,
+            coordinates="badCoords")
         xr.testing.assert_equal(xr.Dataset(), empty_dataset)

 import copy
 import unittest
+from unittest.case import expectedFailure
 
 from o3skim import utils
@@ -27,3 +28,9 @@ class Tests_mergedict(unittest.TestCase):
         self.assertEqual(dict_3['b'], 2)
         self.assertEqual(dict_3['c'], 0)
         self.assertEqual(dict_3['z'], {'a': 1, 'b': 2, 'c': 0})
+
+    def test_merge_with_exception(self):
+        def raise_exception(x, y): raise Exception(x, y)
+        with self.assertRaises(Exception) as cm:
+            utils.mergedicts({'a': 1}, {'a': 2}, raise_exception)
+        self.assertEqual(cm.exception.args, (1, 2))

@@ -91,7 +91,7 @@ def save(file_name, metadata):
         yaml.dump(metadata, ymlfile, allow_unicode=True)
 
 
-def mergedicts(d1, d2):
+def mergedicts(d1, d2, if_conflict=lambda _, d: d):
     """Merges dict d2 in dict d2 recursively. If two keys exist in
     both dicts, the value in d1 is superseded by the value in d2.
@@ -102,7 +102,12 @@ def mergedicts(d1, d2):
     :type d2: dict
     """
     for key in d2:
-        if key in d1 and isinstance(d1[key], dict) and isinstance(d2[key], dict):
-            mergedicts(d1[key], d2[key])
+        if key in d1:
+            if isinstance(d1[key], dict) and isinstance(d2[key], dict):
+                mergedicts(d1[key], d2[key], if_conflict)
+            elif d1[key] == d2[key]:
+                pass  # same leaf value
+            else:
+                d1[key] = if_conflict(d1[key], d2[key])
         else:
            d1[key] = d2[key]
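
The new if_conflict hook only fires on unequal non-dict values; by default the value from d2 still wins, matching the previous behaviour, while _load_model passes a callback that raises instead. A short sketch (dictionary contents are made up):

    from o3skim import utils

    d1 = {'model': {'name': 'A', 'runs': 3}}
    d2 = {'model': {'name': 'B'}}
    utils.mergedicts(d1, d2)  # default: d1['model']['name'] becomes 'B'

    d1 = {'model': {'name': 'A', 'runs': 3}}
    utils.mergedicts(d1, d2, if_conflict=lambda old, new: old)  # keeps 'A'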