Source code for cmip6_preprocessing.grids

import warnings

import numpy as np
import pkg_resources
import xarray as xr
import yaml

from xgcm import Grid
from xgcm.autogenerate import generate_grid_ds


# Location of the staggered-grid configuration file, relative to this module.
path = "specs/staggered_grid_config.yaml"  # always use slash
# Filesystem path of that resource, resolved through `pkg_resources` using this
# module's `__name__` as the package reference.
grid_spec = pkg_resources.resource_filename(__name__, path)


def _parse_bounds_vertex(da, dim="bnds", position=(0, 1)):
    """Extract the loaded data at selected positions of a bounds/vertex dimension.

    Parameters
    ----------
    da : xr.DataArray
        Array carrying the dimension `dim`.
    dim : str, optional
        Name of the dimension to select along, by default "bnds".
    position : sequence of int, optional
        Indices along `dim` to extract, by default (0, 1).

    Returns
    -------
    tuple
        One entry per element of `position`, each being the `.data` of the
        loaded selection `da.isel({dim: i})`.
    """
    # The default is an immutable tuple: a list default would be a single object
    # shared between all calls.
    return tuple(da.isel({dim: i}).load().data for i in position)


def _interp_vertex_to_bounds(da, orientation):
    """Average 4 vertex points into two bound points.

    Helpful to recreate e.g. the latitude at the `lon_bounds` points.

    Parameters
    ----------
    da : xr.DataArray
        Array with a `vertex` dimension holding (at least) 4 entries.
    orientation : str
        Either "x" (averages vertices [0, 1] and [3, 2]) or "y" (averages
        vertices [0, 3] and [1, 2]).

    Returns
    -------
    xr.DataArray
        The two averaged arrays concatenated along a new `bnds` dimension.

    Raises
    ------
    ValueError
        If `orientation` is neither "x" nor "y".
    """
    if orientation == "x":
        vertex_pairs = ([0, 1], [3, 2])
    elif orientation == "y":
        vertex_pairs = ([0, 3], [1, 2])
    else:
        # Fail early with a clear message instead of an unbound local further down.
        raise ValueError(
            f"`orientation` must be either 'x' or 'y', got {orientation!r}"
        )

    datasets = [da.isel(vertex=pair).mean("vertex") for pair in vertex_pairs]
    return xr.concat(datasets, dim="bnds")


def distance_deg(lon0, lat0, lon1, lat1, small_crit=1 / 10):
    """Calculate the distance in degrees longitude and latitude between two points

    Parameters
    ----------
    lon0 : np.array
        Longitude of first point
    lat0 : np.array
        Latitude of first point
    lon1 : np.array
        Longitude of second point
    lat1 : np.array
        Latitude of second point
    small_crit : float, optional
        Differences with an absolute value below this threshold (in degrees) are
        set to zero. Twice this value is used as tolerance when unwrapping
        longitude differences across the 360 degree discontinuity.
        By default 1 / 10.

    Returns
    -------
    tuple
        `(delta_lon, delta_lat)`, the longitude and latitude differences
        (second point minus first point) in degrees.
    """
    delta_lon = lon1 - lon0
    delta_lat = lat1 - lat0

    # very small differences can end up negative, so zero them out based on a simple
    # criterion
    # the default should work for CMIP6 (no 1/1 deg models) but should be based on
    # actual grid info in the future
    delta_lon = np.where(np.abs(delta_lon) < small_crit, 0.0, delta_lon)
    delta_lat = np.where(np.abs(delta_lat) < small_crit, 0.0, delta_lat)

    # some bounds are wrapped around the lon discontinuity: shift clearly negative
    # differences up by a full circle ...
    delta_lon = np.where(delta_lon < (-small_crit * 2), 360 + delta_lon, delta_lon)
    # ... and differences larger than a full circle down by one
    delta_lon = np.where(
        delta_lon > (360 + small_crit * 2), -360 + delta_lon, delta_lon
    )
    return delta_lon, delta_lat
def distance(lon0, lat0, lon1, lat1):
    """Calculate the distance in m between two points on a spherical globe

    Parameters
    ----------
    lon0 : np.array
        Longitude of first point
    lat0 : np.array
        Latitude of first point
    lon1 : np.array
        Longitude of second point
    lat1 : np.array
        Latitude of second point

    Returns
    -------
    np.array
        Distance in m, computed from the degree differences returned by
        `distance_deg`; the zonal part is scaled with the cosine of `lat0`.
    """
    earth_radius = 6.378e6  # sphere radius in m
    lon_diff, lat_diff = distance_deg(lon0, lat0, lon1, lat1)

    # convert the differences from degrees to arc lengths on the sphere
    meridional = earth_radius * (np.pi * lat_diff / 180)
    zonal = earth_radius * (np.pi * lon_diff / 180) * np.cos(np.pi * lat0 / 180)
    return np.sqrt(zonal**2 + meridional**2)
def recreate_metrics(ds, grid):
    """Recreate a full set of horizontal distance metrics.

    Calculates distances between points in lon/lat coordinates

    The naming of the metrics is as follows:
    [metric_axis]_t : metric centered at tracer point
    [metric_axis]_gx : metric at the cell face on the x-axis. For instance `dx_gx`
        is the x distance centered on the eastern cell face if the shift is `right`
    [metric_axis]_gy : As above but along the y-axis
    [metric_axis]_gxgy : The metric located at the corner point. For example
        `dy_gxgy` is the y distance on the south-west corner if both axes are
        shifted left.

    Parameters
    ----------
    ds : xr.Dataset
        Input dataset.
    grid : xgcm.Grid
        xgcm Grid object matching `ds`

    Returns
    -------
    xr.Dataset, dict
        Dataset with added metrics as coordinates and dictionary that can be passed
        to xgcm.Grid to recognize new metrics
    """
    ds = ds.copy()

    # The distance helpers put out numpy arrays, so the dataset has to be
    # transposed into a fixed (y, x, ...) order first
    leading_dims = ["y", "x"]
    trailing_dims = [di for di in ds.dims if di not in leading_dims]
    ds = ds.transpose(*tuple(leading_dims + trailing_dims))

    # is the vel point on left or right? Taken as the first position of each axis
    # that is not "center"
    axis_vel_pos = {}
    for axis in ["X", "Y"]:
        staggered = set(grid.axes[axis].coords.keys()) - {"center"}
        axis_vel_pos[axis] = list(staggered)[0]

    # determine the appropriate vertex position for the north/south and east/west
    # edge, based on the grid config
    if axis_vel_pos["Y"] == "left":
        ns_vertex_idx = [0, 3]
        ns_bound_idx = [0]
    elif axis_vel_pos["Y"] == "right":
        ns_vertex_idx = [1, 2]
        ns_bound_idx = [1]

    if axis_vel_pos["X"] == "left":
        ew_vertex_idx = [0, 1]
        ew_bound_idx = [0]
    elif axis_vel_pos["X"] == "right":
        ew_vertex_idx = [3, 2]
        ew_bound_idx = [1]

    # infer dx at tracer points
    if "lon_bounds" in ds.coords and "lat_verticies" in ds.coords:
        lon0, lon1 = _parse_bounds_vertex(ds["lon_bounds"])
        lat0, lat1 = _parse_bounds_vertex(
            _interp_vertex_to_bounds(ds["lat_verticies"], "x")
        )
        length = distance(lon0, lat0, lon1, lat1)
        ds.coords["dx_t"] = xr.DataArray(length, coords=ds.lon.coords)

    # infer dy at tracer points
    if "lat_bounds" in ds.coords and "lon_verticies" in ds.coords:
        lat0, lat1 = _parse_bounds_vertex(ds["lat_bounds"])
        lon0, lon1 = _parse_bounds_vertex(
            _interp_vertex_to_bounds(ds["lon_verticies"], "y")
        )
        length = distance(lon0, lat0, lon1, lat1)
        ds.coords["dy_t"] = xr.DataArray(length, coords=ds.lon.coords)

    if "lon_verticies" in ds.coords and "lat_verticies" in ds.coords:
        # infer dx at the north/south face
        lon0, lon1 = _parse_bounds_vertex(
            ds["lon_verticies"], dim="vertex", position=ns_vertex_idx
        )
        lat0, lat1 = _parse_bounds_vertex(
            ds["lat_verticies"], dim="vertex", position=ns_vertex_idx
        )
        length = distance(lon0, lat0, lon1, lat1)
        ds.coords["dx_gy"] = xr.DataArray(
            length, coords=grid.interp(ds.lon, "Y", boundary="extrapolate").coords
        )

        # infer dy at the east/west face
        lon0, lon1 = _parse_bounds_vertex(
            ds["lon_verticies"], dim="vertex", position=ew_vertex_idx
        )
        lat0, lat1 = _parse_bounds_vertex(
            ds["lat_verticies"], dim="vertex", position=ew_vertex_idx
        )
        length = distance(lon0, lat0, lon1, lat1)
        ds.coords["dy_gx"] = xr.DataArray(
            length, coords=grid.interp(ds.lon, "X", boundary="extrapolate").coords
        )

    # for the distances that dont line up with the cell boundaries we need some
    # different logic
    boundary = "extend"
    # TODO: This should be removed once we have the default boundary merged in xgcm

    # infer dx at eastern/western bound from tracer points
    x_axis = grid.axes["X"]
    lon0, lon1 = x_axis._get_neighbor_data_pairs(ds.lon.load(), axis_vel_pos["X"])
    lat0, lat1 = x_axis._get_neighbor_data_pairs(ds.lat.load(), axis_vel_pos["X"])
    dx = distance(lon0, lat0, lon1, lat1)
    ds.coords["dx_gx"] = xr.DataArray(
        dx, coords=grid.interp(ds.lon, "X", boundary=boundary).coords
    )

    # infer dy at northern bound from tracer points
    y_axis = grid.axes["Y"]
    lat0, lat1 = y_axis._get_neighbor_data_pairs(
        ds.lat.load(), axis_vel_pos["Y"], boundary=boundary
    )
    lon0, lon1 = y_axis._get_neighbor_data_pairs(
        ds.lon.load(), axis_vel_pos["Y"], boundary=boundary
    )
    dy = distance(lon0, lat0, lon1, lat1)
    ds.coords["dy_gy"] = xr.DataArray(
        dy, coords=grid.interp(ds.lat, "Y", boundary=boundary).coords
    )

    # infer dx at the corner point
    lon0, lon1 = grid.axes["X"]._get_neighbor_data_pairs(
        _interp_vertex_to_bounds(ds.lon_verticies.load(), "y")
        .isel(bnds=ns_bound_idx)
        .squeeze(),
        axis_vel_pos["X"],
    )
    lat0, lat1 = grid.axes["X"]._get_neighbor_data_pairs(
        ds.lat_bounds.isel(bnds=ns_bound_idx).squeeze().load(), axis_vel_pos["X"]
    )
    dx = distance(lon0, lat0, lon1, lat1)
    ds.coords["dx_gxgy"] = xr.DataArray(
        dx,
        coords=grid.interp(
            grid.interp(ds.lon, "X", boundary=boundary), "Y", boundary=boundary
        ).coords,
    )

    # infer dy at the corner point
    lat0, lat1 = grid.axes["Y"]._get_neighbor_data_pairs(
        _interp_vertex_to_bounds(ds.lat_verticies.load(), "x")
        .isel(bnds=ew_bound_idx)
        .squeeze(),
        axis_vel_pos["Y"],
    )
    lon0, lon1 = grid.axes["Y"]._get_neighbor_data_pairs(
        ds.lon_bounds.isel(bnds=ew_bound_idx).squeeze().load(), axis_vel_pos["Y"]
    )
    dy = distance(lon0, lat0, lon1, lat1)
    ds.coords["dy_gxgy"] = xr.DataArray(
        dy,
        coords=grid.interp(
            grid.interp(ds.lon, "X", boundary=boundary), "Y", boundary=boundary
        ).coords,
    )

    # infer dz at tracer point
    if "lev_bounds" in ds.coords:
        ds = ds.assign_coords(
            dz_t=("lev", ds["lev_bounds"].diff("bnds").squeeze(drop=True).data)
        )

    candidates = {
        "X": ["dx_t", "dx_gy", "dx_gx"],
        "Y": ["dy_t", "dy_gy", "dy_gx"],
        "Z": ["dz_t"],
    }
    metrics_dict = {
        axis: [co for co in names if co in ds.coords]
        for axis, names in candidates.items()
    }
    # only put out axes that have entries
    metrics_dict = {k: v for k, v in metrics_dict.items() if len(v) > 0}

    return ds, metrics_dict
def detect_shift(ds_base, ds, axis):
    """Detects the shift of `ds` relative to `ds_base` on logical grid axes, using
    lon and lat positions.

    Parameters
    ----------
    ds_base : xr.Dataset
        Reference ('base') dataset to compare to. Assumed that this is located at
        the 'center' coordinate.
    ds : xr.Dataset
        Comparison dataset. The resulting shift will be computed as this dataset
        relative to `ds_base`
    axis : str
        xgcm logical axis on which to detect the shift

    Returns
    -------
    str
        Shift string output, in xgcm conventions.
    """
    ds_base = ds_base.copy()
    ds = ds.copy()
    axis = axis.lower()
    axis_coords = {"x": "lon", "y": "lat"}

    # check the shift only for one point, somewhat in the center to avoid the
    # distorted polar regions
    mid_point = {"x": len(ds_base.x) // 2, "y": len(ds_base.y) // 2}
    mid_and_next = {dim: [idx, idx + 1] for dim, idx in mid_point.items()}

    coord_name = axis_coords[axis]
    compared_value = ds.isel(**mid_point)[coord_name].load().data
    base_value = ds_base.isel(**mid_point)[coord_name].load().data
    shift = compared_value - base_value

    # spacing of the dimension coordinate between the check point and its neighbor
    diff = ds[axis].isel({axis: mid_and_next[axis]}).diff(axis).data.tolist()[0]

    # the fraction of full cell distance, that a point has to be shifted in order to
    # be recognized.
    # This avoids detection of shifts for very small differences that sometimes
    # happen if the coordinates were written e.g. by different modules of a model
    threshold = 0.1

    if shift > (diff * threshold):
        return "right"
    if shift < -(diff * threshold):
        return "left"
    return "center"
def create_full_grid(base_ds, grid_dict=None):
    """Generate a full xgcm-compatible dataset from a reference dataset `base_ds`.

    This dataset should be representing a tracer field, e.g. the cell center.

    Parameters
    ----------
    base_ds : xr.Dataset
        The reference ('base') dataset, assumed to be at the tracer position/cell
        center. Its attrs must contain `source_id` and `grid_label`.
    grid_dict : dict, optional
        Dictionary with info about the grid staggering. Must be keyed by the
        `source_id` and `grid_label` found in the attrs of `base_ds`
        (e.g. {'model_name': {'grid_label': {'axis_shift': {'X': 'left', ...}}}}).
        If None (default), it is loaded from the YAML file at `grid_spec`.

    Returns
    -------
    xr.Dataset or None
        xgcm compatible dataset, or None (after emitting a warning) if no
        `axis_shift` entry is found for the source_id/grid_label combination.
    """
    # load dict with grid shift info for each axis
    if grid_dict is None:
        # the context manager closes the file even if parsing the YAML fails
        with open(grid_spec, "r") as ff:
            grid_dict = yaml.safe_load(ff)

    source_id = base_ds.attrs["source_id"]
    grid_label = base_ds.attrs["grid_label"]

    # if the source_id/grid_label combo is not in the dict, warn and ask to submit
    # an issue
    try:
        axis_shift = grid_dict[source_id][grid_label]["axis_shift"]
    except KeyError:
        warnings.warn(
            f"Could not find the source_id/grid_label ({source_id}/{grid_label}) combo in `grid_dict`, returning `None`. Please submit an issue to github: https://github.com/jbusecke/cmip6_preprocessing/issues"
        )
        return None

    position = {k: ("center", shift) for k, shift in axis_shift.items()}

    axis_dict = {"X": "x", "Y": "y"}

    ds_grid = generate_grid_ds(
        base_ds, axis_dict, position=position, boundary_discontinuity={"X": 360}
    )

    # TODO: man parse lev and lev_bounds as center and outer dims.
    # I should also be able to do this with `generate_grid_ds`, but here we
    # have the `lev_bounds` with most models, so that is probably more reliable.
    # cheapest solution right now
    if "lev" in ds_grid.dims:
        ds_grid["lev"].attrs["axis"] = "Z"

    return ds_grid
def combine_staggered_grid(
    ds_base, other_ds=None, recalculate_metrics=False, grid_dict=None, **kwargs
):
    """Combine a reference dataset with a list of other datasets to a full
    xgcm-compatible staggered grid dataset.

    Parameters
    ----------
    ds_base : xr.Dataset
        The reference ('base') dataset, assumed to be at the tracer position/cell
        center
    other_ds : list, xr.Dataset, optional
        List of datasets representing different variables. Their grid position
        will be automatically detected relative to `ds_base`. Coordinates and
        attrs of these added datasets will be lost, by default None
    recalculate_metrics : bool, optional
        Enables the reconstruction of grid metrics using simple spherical
        geometry, by default False

        !!! Check your results carefully when using reconstructed values, these
        might differ substantially if the grid geometry is complicated.
    grid_dict : dict, optional
        Dictionary for staggered grid setup. See `create_full_grid` for details.
        If None (default), will load staggered grid info from internal database.
    **kwargs
        Passed on to `xgcm.Grid`, overriding the default `periodic=["X"]`.

    Returns
    -------
    xgcm.Grid, xr.Dataset
        Grid object and single xgcm-compatible dataset, containing all variables
        on their respective staggered grid position. `(None, None)` is returned
        (with a warning) if the staggered grid could not be created.

    Raises
    ------
    RuntimeError
        If one of the datasets in `other_ds` carries dimensions that are not
        present in the generated grid dataset.
    """
    ds_base = ds_base.copy()
    if isinstance(other_ds, xr.Dataset):
        other_ds = [other_ds]

    ds_g = create_full_grid(ds_base, grid_dict=grid_dict)

    if ds_g is None:
        warnings.warn("Staggered Grid creation failed. Returning `None`")
        return None, None

    # save attrs out for later (something during alignment destroys them);
    # copies are stored so the backup cannot alias the live attrs
    dim_attrs_dict = {di: dict(ds_g[di].attrs) for di in ds_g.dims}

    # TODO: metrics and interpolation of metrics if they are parsed

    # parse other variables
    if other_ds is not None:
        for ds_new in other_ds:
            ds_new = ds_new.copy()
            # strip everything but the variable_id (perhaps I would want to
            # loosen this in the future)
            ds_new = ds_new[ds_new.attrs["variable_id"]]

            sizes_match = all(
                [
                    len(ds_new[di]) == len(ds_g[di])
                    for di in ds_new.dims
                    if di not in ["member_id", "time"]
                ]
            )
            if not sizes_match:
                warnings.warn(
                    f"Could not parse `{ds_new.name}`, due to a size mismatch. If this is the MRI model, the grid convention is currently not supported."
                )
                continue

            # detect shift and rename accordingly
            rename_dict = {}
            for axis in ["X", "Y"]:
                shift = detect_shift(ds_base, ds_new, axis)
                if shift != "center":
                    rename_dict[axis.lower()] = axis.lower() + "_" + shift
            ds_new = ds_new.rename(rename_dict)
            ds_new = ds_new.reset_coords(drop=True)

            # TODO: This needs to be coded more generally, for now hardcode x and y
            force_align_dims = [di for di in ds_new.dims if "x" in di or "y" in di]
            _, ds_new = xr.align(
                ds_g.copy(),
                ds_new,
                join="override",
                exclude=[di for di in ds_new.dims if di not in force_align_dims],
            )

            additional_dims = [di for di in ds_new.dims if di not in ds_g.dims]
            if len(additional_dims) > 0:
                raise RuntimeError(
                    f"While trying to parse `{ds_new.name}`, detected dims that are not in the base dataset:[{additional_dims}]"
                )
            ds_g[ds_new.name] = ds_new

    # Restore dims attrs from the beginning
    for di in ds_g.dims:
        ds_g.coords[di].attrs.update(dim_attrs_dict[di])

    grid_kwargs = {"periodic": ["X"]}
    grid_kwargs.update(kwargs)
    # the options have to be unpacked as keywords; passing the dict positionally
    # would hand it to the first positional option of `Grid` instead
    grid = Grid(ds_g, **grid_kwargs)

    # if activated calculate metrics
    if recalculate_metrics:
        # remove any passed metrics when recalculating them
        # I might be able to refine this more to e.g. allow axes that are not
        # recreated.
        grid_kwargs.pop("metrics", None)

        ds_g, metrics_dict = recreate_metrics(ds_g, grid)
        grid_kwargs["metrics"] = metrics_dict
        grid = Grid(ds_g, **grid_kwargs)

    return grid, ds_g