Commit 645ead77 authored by BorjaEst's avatar BorjaEst
Browse files

Refactor standarization step 1 - rename and sort

parent af9df0c3
"""Module in charge of dataset standardization when loading models."""
import logging
import unittest
import xarray as xr
import pandas as pd
import numpy as np
from o3skim import utils
logger = logging.getLogger('o3skim.standardization')
# tco3 standardization
tco3_standard_name = 'tco3_zm'
tco3_mean_coordinate = 'lon'
tco3_standard_coordinates = [
'time',
'lat',
......@@ -17,7 +20,6 @@ tco3_standard_coordinates = [
# vmro3 standardization
vmro3_standard_name = 'vmro3_zm'
vmro3_mean_coordinate = 'lon'
vmro3_standard_coordinates = [
'time',
'plev',
......@@ -26,64 +28,80 @@ vmro3_standard_coordinates = [
]
@utils.return_on_failure("Error when loading '{0}'".format(tco3_standard_name),
xr.Dataset())
def __load_tco3(name, paths, coordinates):
"""Loads and standarises the tco3 data"""
logger.debug("Standard loading of '{0}' data".format(tco3_standard_name))
with xr.open_mfdataset(paths) as dataset:
dataset = dataset.rename({
**{name: tco3_standard_name},
**{coordinates[x]: x for x in tco3_standard_coordinates}
})[tco3_standard_name].to_dataset()
return dataset.mean(dim=tco3_mean_coordinate)
@utils.return_on_failure("Error when loading '{0}'".format(vmro3_standard_name),
xr.Dataset())
def __load_vmro3(name, paths, coordinates):
"""Loads and standarises the vmro3 data"""
logger.debug("Standard loading of '{0}' data".format(vmro3_standard_name))
with xr.open_mfdataset(paths) as dataset:
dataset = dataset.rename({
**{name: vmro3_standard_name},
**{coordinates[x]: x for x in vmro3_standard_coordinates}
})[vmro3_standard_name].to_dataset()
return dataset.mean(dim=vmro3_mean_coordinate)
# Load case dictionary
__loads = {
tco3_standard_name: __load_tco3,
vmro3_standard_name: __load_vmro3
}
# Non existing variable exception
class UnknownVariable(Exception):
"""To raise if variable to treat is unknown"""
def __init__(self, variable, message="Unknown variable"):
self.variable = variable
self.message = message
super().__init__(self.message)
def load(variable, configuration):
"""Loads and standarises the variable using a specific
configuration.
:param variable: Loadable variable.
:type variable: str
:param configuration: Configuration to apply standardization.
:type configuration: dict
:return: A standardized dataset.
:rtype: xarray.Dataset
"""
try:
function = __loads[variable]
except KeyError:
raise UnknownVariable(variable)
return function(**configuration)
def standardize(dataset, tco3_zm=None, vmro3_zm=None):
pass # TODO
def standardize_vmro3(dataset, variable, coordinates):
"""Standardizes a vmro3 dataset"""
dataset = rename_vmro3(dataset, variable, coordinates)
dataset = sort(dataset)
return dataset
def rename_vmro3(dataset, variable, coordinates):
"""Renames a vmro3 dataset variable and coordinates"""
logger.debug("Rename of '{0}' var and coords".format(vmro3_standard_name))
return dataset.rename({
**{variable: vmro3_standard_name},
**{coordinates[x]: x for x in vmro3_standard_coordinates}
})
def sort(dataset):
"""Sorts a dataset by coordinates"""
logger.debug("Sorting coordinates in dataset")
return dataset.sortby(list(dataset.coords))
def squeeze(dataset):
"""Squeezes the 1size dimensions on a dataset"""
logger.debug("Squeezing coordinates in dataset")
pass # TODO
class TestsVMRO3(unittest.TestCase):
data_s = np.mgrid[1:3:3j, 1:3:3j, 1:4:4j, 1:3:3j][0]
data_r = np.mgrid[3:1:3j, 3:1:3j, 4:1:4j, 3:1:3j][0]
varname = "vmro3"
coords = {'lon': 'longit', 'lat': 'latitu',
'plev': 'level', 'time': 'time'}
@staticmethod
def non_standard_ds():
return xr.Dataset(
data_vars=dict(
vmro3=(["longit", "latitu", "level", "time"], TestsVMRO3.data_r)
),
coords=dict(
longit=[180, 0, -180],
latitu=[90, 0, -90],
level=[1000, 100, 10, 1],
time=pd.date_range("2000-01-03", periods=3, freq='-1d')
),
attrs=dict(description="Non standardized dataset")
)
@staticmethod
def standard_ds():
return xr.Dataset(
data_vars=dict(
vmro3_zm=(["lon", "lat", "plev", "time"], TestsVMRO3.data_s)
),
coords=dict(
time=pd.date_range("2000-01-01", periods=3, freq='1d'),
plev=[1, 10, 100, 1000],
lat=[-90, 0, 90],
lon=[-180, 0, 180]
),
attrs=dict(description="Standardized dataset")
)
def test_standardize(self):
standardized_vmro3 = standardize_vmro3(
dataset=TestsVMRO3.non_standard_ds(),
variable=TestsVMRO3.varname,
coordinates=TestsVMRO3.coords)
xr.testing.assert_equal(TestsVMRO3.standard_ds(), standardized_vmro3)
Supports Markdown
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment