Commit b0b75bdd authored by BorjaEst

Implement forwarding of attrs

parent 112a9839
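
Selecting a single variable with to_dataset() produces a new Dataset whose attrs do not include the parent Dataset attributes; the accessor therefore copies them over explicitly, and skim() now forwards both dataset- and per-variable attrs. A minimal sketch of the pattern (the names below are illustrative, not part of the codebase):

import numpy as np
import xarray as xr

ds = xr.Dataset({"tco3_zm": ("time", np.arange(3.0))},
                attrs={"description": "parent attrs"})
single = ds["tco3_zm"].to_dataset()  # new Dataset; parent attrs are not forwarded
single.attrs = ds.attrs              # explicit forwarding, as ModelAccessor.tco3 now does
assert single.attrs == ds.attrs
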
......@@ -9,7 +9,8 @@ import unittest
import xarray as xr
import pandas as pd
import numpy as np
from o3skim import standardization
from xarray.core.dataset import Dataset
from o3skim import standardization, utils
logger = logging.getLogger('extended_xr')
mean_coord = 'lon'
......@@ -19,12 +20,15 @@ mean_coord = 'lon'
class ModelAccessor:
def __init__(self, xarray_obj):
self._model = xarray_obj
self._metadata = {}
@property
def tco3(self):
"""Return the total ozone column of this dataset."""
if "tco3_zm" in list(self._model.var()):
return self._model["tco3_zm"].to_dataset()
dataset = self._model["tco3_zm"].to_dataset()
dataset.attrs = self._model.attrs
return dataset
else:
return None
......@@ -32,17 +36,24 @@ class ModelAccessor:
def vmro3(self):
"""Return the ozone volume mixing ratio of this dataset."""
if "vmro3_zm" in list(self._model.var()):
return self._model["vmro3_zm"].to_dataset()
dataset = self._model["vmro3_zm"].to_dataset()
dataset.attrs = self._model.attrs
return dataset
else:
return None
@property
def metadata(self):
"""Return the ozone volume mixing ratio of this dataset."""
result = self._model.attrs
for var in self._model.var():
result = {**result, var: self._model[var].attrs}
return result
"""Returns the metadata property"""
return self._metadata
def add_metadata(self, metadata):
"""Merges the input metadata with the model metadata"""
utils.mergedicts(self._metadata, metadata)
def set_metadata(self, metadata):
"""Sets the metadata to the input variable."""
self._metadata = metadata
def groupby_year(self):
"""Returns a grouped dataset by year"""
......@@ -61,5 +72,8 @@ class ModelAccessor:
def skim(self):
"""Skims model producing reduced dataset"""
logger.debug("Skimming model")
return self._model.mean(mean_coord)
skimmed = self._model.mean(mean_coord)
skimmed.attrs = self._model.attrs
for var in self._model:
skimmed[var].attrs = self._model[var].attrs
return skimmed
......@@ -115,25 +115,28 @@ def _load_model(tco3_zm=None, vmro3_zm=None, metadata={}):
:return: Dataset with specified variables.
:rtype: xarray.Dataset
"""
dataset = xr.Dataset(attrs=metadata)
dataset = xr.Dataset()
dataset.model.set_metadata(metadata)
def conflict(d1, d2): raise Exception(
"Conflict merging {}, {}".format(d1, d2))
if tco3_zm:
logger.debug("Loading tco3_zm into model")
with xr.open_mfdataset(tco3_zm['paths']) as load:
standardized = standardization.standardize_tco3(
dataset=load,
variable=tco3_zm['name'],
dataset['tco3_zm'] = standardization.standardize_tco3(
array=load[tco3_zm['name']],
coordinates=tco3_zm['coordinates'])
dataset = dataset.merge(standardized)
dataset.tco3_zm.attrs = tco3_zm.get('metadata', {})
metadata = {'tco3_zm': tco3_zm.get('metadata', {})}
dataset.model.add_metadata(metadata)
utils.mergedicts(dataset.attrs, load.attrs, if_conflict=conflict)
if vmro3_zm:
logger.debug("Loading vmro3_zm into model")
with xr.open_mfdataset(vmro3_zm['paths']) as load:
standardized = standardization.standardize_vmro3(
dataset=load,
variable=vmro3_zm['name'],
dataset['vmro3_zm'] = standardization.standardize_vmro3(
array=load[vmro3_zm['name']],
coordinates=vmro3_zm['coordinates'])
dataset = dataset.merge(standardized)
dataset.vmro3_zm.attrs = vmro3_zm.get('metadata', {})
metadata = {'vmro3_zm': vmro3_zm.get('metadata', {})}
dataset.model.add_metadata(metadata)
utils.mergedicts(dataset.attrs, load.attrs, if_conflict=conflict)
return dataset
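
Per-variable metadata is now kept on the accessor instead of on the Dataset attrs, namespaced by variable name. A rough sketch of the resulting behaviour, assuming the o3skim module that registers the 'model' Dataset accessor has been imported (keys and values below are illustrative):

import xarray as xr
# import o3skim  # assumed: importing the package registers the 'model' accessor

ds = xr.Dataset()
ds.model.set_metadata({"description": "some model"})                  # dataset-level metadata
ds.model.add_metadata({"tco3_zm": {"description": "tco3 source"}})    # namespaced per variable
ds.model.add_metadata({"vmro3_zm": {"description": "vmro3 source"}})
# ds.model.metadata == {"description": "some model",
#                       "tco3_zm": {"description": "tco3 source"},
#                       "vmro3_zm": {"description": "vmro3 source"}}
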
......
......@@ -13,27 +13,23 @@ logger = logging.getLogger('o3skim.standardization')
@utils.return_on_failure("Error when loading '{0}'".format('tco3_zm'),
default=xr.Dataset())
def standardize_tco3(dataset, variable, coordinates):
def standardize_tco3(array, coordinates):
"""Standardizes a tco3 dataset.
:param dataset: Dataset to standardize.
:type dataset: xarray.Dataset
:param variable: Variable name for the tco3 on the original dataset.
:type variable: str
:param array: DataArray to standardize.
:type array: xarray.DataArray
:param coordinates: Coordinates map for tco3 variable.
:type coordinates: {'lon':str, 'lat':str, 'time':str}
:return: Standardized dataset.
:rtype: xarray.Dataset
:return: Standardized DataArray.
:rtype: xarray.DataArray
"""
array = dataset[variable]
array.name = 'tco3_zm'
array = squeeze(array)
array = rename_coords_tco3(array, **coordinates)
array = sort(array)
return array.to_dataset()
return array
def rename_coords_tco3(array, time, lat, lon):
......@@ -44,27 +40,23 @@ def rename_coords_tco3(array, time, lat, lon):
@utils.return_on_failure("Error when loading '{0}'".format('vmro3_zm'),
default=xr.Dataset())
def standardize_vmro3(dataset, variable, coordinates):
def standardize_vmro3(array, coordinates):
"""Standardizes a vmro3 dataset.
:param dataset: Dataset to standardize.
:type dataset: xarray.Dataset
:param variable: Variable name for the vmro3 on the original dataset.
:type variable: str
:param array: DataArray to standardize.
:type array: xarray.DataArray
:param coordinates: Coordinates map for vmro3 variable.
:type coordinates: {'lon':str, 'lat':str, 'plev':str, 'time':str}
:return: Standardized dataset.
:rtype: xarray.Dataset
:return: Standardized DataArray.
:rtype: xarray.DataArray
"""
array = dataset[variable]
array.name = 'vmro3_zm'
array = squeeze(array)
array = rename_coords_vmro3(array, **coordinates)
array = sort(array)
return array.to_dataset()
return array
def rename_coords_vmro3(array, time, plev, lat, lon):
......
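
The standardize functions now take the raw DataArray plus a coordinates map instead of a Dataset and a variable name. A minimal usage sketch under that assumption (the raw array and coordinate names are made up for illustration; the test data below exercises the same path):

import numpy as np
import pandas as pd
import xarray as xr
from o3skim import standardization

raw = xr.DataArray(
    data=np.zeros((3, 3, 3)),
    dims=["long", "latd", "time"],
    coords=dict(
        long=[-180, 0, 180],
        latd=[-90, 0, 90],
        time=pd.date_range("2000-01-01", periods=3, freq="1d"),
    ),
)
tco3 = standardization.standardize_tco3(
    array=raw,
    coordinates={"lon": "long", "lat": "latd", "time": "time"},
)
# tco3 is a DataArray named 'tco3_zm' with dimensions renamed to lon/lat/time;
# standardize_vmro3 works the same way with an additional 'plev' entry.
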
......@@ -56,16 +56,32 @@ class Tests(unittest.TestCase):
def test_tco3_property(self):
expected = tco3_datarray.to_dataset(name="tco3_zm")
xr.testing.assert_equal(dataset.model.tco3, expected)
self.assertEqual(dataset.model.tco3.attrs, dataset.attrs)
def test_vmro3_property(self):
expected = vmro3_datarray.to_dataset(name="vmro3_zm")
xr.testing.assert_equal(dataset.model.vmro3, expected)
self.assertEqual(dataset.model.vmro3.attrs, dataset.attrs)
def test_metadata_property(self):
meta = dataset.model.metadata
self.assertEqual(meta["description"], "Test dataset")
self.assertEqual(meta["tco3_zm"]["description"], "Test tco3 xarray")
self.assertEqual(meta["vmro3_zm"]["description"], "Test vmro3 xarray")
model = dataset.copy().model
self.assertEqual(model.metadata, {})
def test_add_metadata(self):
model = dataset.copy().model
self.assertEqual(model.metadata, {})
model.add_metadata(dict(d1=1))
self.assertEqual(model.metadata, dict(d1=1))
model.add_metadata(dict(d2=2))
self.assertEqual(model.metadata, dict(d1=1, d2=2))
def test_set_metadata(self):
model = dataset.copy().model
self.assertEqual(model.metadata, {})
model.set_metadata(dict(d1=1))
self.assertEqual(model.metadata, dict(d1=1))
model.set_metadata(dict(d2=2))
self.assertEqual(model.metadata, dict(d2=2))
def test_groupby_year(self):
groups = dataset.model.groupby_year()
......@@ -101,3 +117,9 @@ class Tests(unittest.TestCase):
self.assertIn('lat', result.model.vmro3.coords)
self.assertIn('plev', result.model.vmro3.coords)
self.assertNotIn('lon', result.model.vmro3.coords)
def test_skimming_attrs(self):
skimmed = dataset.model.skim()
self.assertEqual(skimmed.attrs, dataset.attrs)
for var in dataset:
self.assertEqual(skimmed[var].attrs, dataset[var].attrs)
......@@ -9,99 +9,82 @@ standardize_tco3 = standardization.standardize_tco3
standardize_vmro3 = standardization.standardize_vmro3
tco3_data_s = np.mgrid[1:3:3j, 1:3:3j, 1:3:3j][0]
tco3_data_r = np.mgrid[3:1:3j, 3:1:3j, 1:1:1j, 3:1:3j][0]
tco3_varname = "tco3"
tco3_coords = {'lon': 'long', 'lat': 'latd', 'time': 'time'}
tco3_nonstd = xr.Dataset(
data_vars=dict(
tco3=(["long", "latd", "high", "time"], tco3_data_r)
),
tco3_nonstd = xr.DataArray(
data=np.mgrid[3:1:3j, 3:1:3j, 1:1:1j, 3:1:3j][0],
dims=["long", "latd", "high", "time"],
coords=dict(
long=[180, 0, -180],
latd=[90, 0, -90],
high=[1],
time=pd.date_range("2000-01-03", periods=3, freq='-1d')
),
attrs=dict(description="Non standardized dataset")
attrs=dict(description="Non standardized DataArray")
)
tco3_standard = xr.Dataset(
data_vars=dict(
tco3_zm=(["lon", "lat", "time"], tco3_data_s)
),
tco3_standard = xr.DataArray(
data=np.mgrid[1:3:3j, 1:3:3j, 1:3:3j][0],
dims=["lon", "lat", "time"],
coords=dict(
time=pd.date_range("2000-01-01", periods=3, freq='1d'),
lat=[-90, 0, 90],
lon=[-180, 0, 180]
),
attrs=dict(description="Standardized dataset")
attrs=dict(description="Standardized DataArray")
)
class TestsTCO3(unittest.TestCase):
def test_standardize(self):
coords = {'lon': 'long', 'lat': 'latd', 'time': 'time'}
standardized_tco3 = standardize_tco3(
dataset=tco3_nonstd,
variable=tco3_varname,
coordinates=tco3_coords)
array=tco3_nonstd,
coordinates=coords)
xr.testing.assert_equal(tco3_standard, standardized_tco3)
def test_fail_returns_empty_dataset(self):
empty_dataset = standardize_tco3(
dataset=tco3_nonstd,
variable="badVariable",
coordinates=tco3_coords)
array=tco3_nonstd,
coordinates="badCoords")
xr.testing.assert_equal(xr.Dataset(), empty_dataset)
vmro3_data_s = np.mgrid[1:3:3j, 1:3:3j, 1:4:4j, 1:3:3j][0]
vmro3_data_r = np.mgrid[3:1:3j, 3:1:3j, 4:1:4j, 3:1:3j][0]
vmro3_varname = "vmro3"
vmro3_coords = {'lon': 'longit', 'lat': 'latitu',
'plev': 'level', 'time': 't'}
vmro3_nonstd = xr.Dataset(
data_vars=dict(
vmro3=(["longit", "latitu", "level", "t"], vmro3_data_r)
),
vmro3_nonstd = xr.DataArray(
data=np.mgrid[3:1:3j, 3:1:3j, 4:1:4j, 3:1:3j][0],
dims=["lo", "la", "lv", "t"],
coords=dict(
longit=[180, 0, -180],
latitu=[90, 0, -90],
level=[1000, 100, 10, 1],
lo=[180, 0, -180],
la=[90, 0, -90],
lv=[1000, 100, 10, 1],
t=pd.date_range("2000-01-03", periods=3, freq='-1d')
),
attrs=dict(description="Non standardized dataset")
attrs=dict(description="Non standardized DataArray")
)
vmro3_standard = xr.Dataset(
data_vars=dict(
vmro3_zm=(["lon", "lat", "plev", "time"], vmro3_data_s)
),
vmro3_standard = xr.DataArray(
data=np.mgrid[1:3:3j, 1:3:3j, 1:4:4j, 1:3:3j][0],
dims=["lon", "lat", "plev", "time"],
coords=dict(
time=pd.date_range("2000-01-01", periods=3, freq='1d'),
plev=[1, 10, 100, 1000],
lat=[-90, 0, 90],
lon=[-180, 0, 180]
),
attrs=dict(description="Standardized dataset")
attrs=dict(description="Standardized DataArray")
)
class TestsVMRO3(unittest.TestCase):
def test_standardize(self):
coords = {'lon': 'lo', 'lat': 'la', 'plev': 'lv', 'time': 't'}
standardized_vmro3 = standardize_vmro3(
dataset=vmro3_nonstd,
variable=vmro3_varname,
coordinates=vmro3_coords)
array=vmro3_nonstd,
coordinates=coords)
xr.testing.assert_equal(vmro3_standard, standardized_vmro3)
def test_fail_returns_empty_dataset(self):
empty_dataset = standardize_vmro3(
dataset=vmro3_nonstd,
variable="badVariable",
coordinates=vmro3_coords)
array=vmro3_nonstd,
coordinates="badCoords")
xr.testing.assert_equal(xr.Dataset(), empty_dataset)
import copy
import unittest
from unittest.case import expectedFailure
from o3skim import utils
......@@ -27,3 +28,9 @@ class Tests_mergedict(unittest.TestCase):
self.assertEqual(dict_3['b'], 2)
self.assertEqual(dict_3['c'], 0)
self.assertEqual(dict_3['z'], {'a': 1, 'b': 2, 'c': 0})
def test_merge_with_exception(self):
def raise_exception(x, y): raise Exception(x, y)
with self.assertRaises(Exception) as cm:
utils.mergedicts({'a': 1}, {'a': 2}, raise_exception)
self.assertEqual(cm.exception.args, (1, 2))
......@@ -91,7 +91,7 @@ def save(file_name, metadata):
yaml.dump(metadata, ymlfile, allow_unicode=True)
def mergedicts(d1, d2):
def mergedicts(d1, d2, if_conflict=lambda _, d: d):
"""Merges dict d2 in dict d2 recursively. If two keys exist in
both dicts, the value in d1 is superseded by the value in d2.
......@@ -102,7 +102,12 @@ def mergedicts(d1, d2):
:type d2: dict
"""
for key in d2:
if key in d1 and isinstance(d1[key], dict) and isinstance(d2[key], dict):
mergedicts(d1[key], d2[key])
if key in d1:
if isinstance(d1[key], dict) and isinstance(d2[key], dict):
mergedicts(d1[key], d2[key], if_conflict)
elif d1[key] == d2[key]:
pass # same leaf value
else:
d1[key] = if_conflict(d1[key], d2[key])
else:
d1[key] = d2[key]
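
A short usage sketch of the extended mergedicts (the dictionaries below are illustrative): the default if_conflict keeps the value from d2, while a custom callback can resolve conflicting leaf values or raise, as _load_model now does.

from o3skim import utils

d1 = {"a": 1, "nested": {"x": 1}}
d2 = {"a": 2, "nested": {"y": 2}}
utils.mergedicts(d1, d2)  # default: the d2 value wins on conflicts
# d1 is now {"a": 2, "nested": {"x": 1, "y": 2}}

def conflict(v1, v2):
    raise Exception("Conflict merging {}, {}".format(v1, v2))

try:
    utils.mergedicts({"b": 1}, {"b": 2}, if_conflict=conflict)
except Exception as err:
    print(err)  # Conflict merging 1, 2
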