Source code for ocean_model_skill_assessor.main

"""
Main run functions.
"""

import logging
import mimetypes
import pathlib
import warnings

from collections.abc import Sequence
from glob import glob
from pathlib import Path, PurePath
from typing import Any, Dict, List, Optional, Tuple, Union

import cf_xarray
import extract_model as em
import extract_model.accessor
import intake
import intake_xarray
import matplotlib.pyplot as plt
import numpy as np
import pandas as pd
import requests
import shapely.wkt
import xarray as xr
import yaml

from cf_pandas import Vocab, astype
from cf_pandas import set_options as cfp_set_options
from cf_xarray import set_options as cfx_set_options
from datetimerange import DateTimeRange
from intake.catalog import Catalog
from intake.catalog.local import LocalCatalogEntry
from pandas import DataFrame, to_datetime
from shapely.geometry import Point
from xgcm import Grid

# from ocean_model_skill_assessor.plot import map
import ocean_model_skill_assessor.plot as plot

from .featuretype import ftconfig
from .paths import Paths
from .stats import compute_stats, save_stats
from .utils import (
    check_catalog,
    check_dataframe,
    check_dataset,
    coords1Dto2D,
    find_bbox,
    fix_dataset,
    get_mask,
    kwargs_search_from_model,
    open_catalogs,
    open_vocab_labels,
    open_vocabs,
    read_model_file,
    read_processed_data_file,
    save_processed_files,
    set_up_logging,
    shift_longitudes,
)


# turn off annoying warning in cf-xarray
cfx_set_options(warn_on_missing_variables=False)


def _entries_from_sources(sources: list) -> dict:
    """Build the ``{name: LocalCatalogEntry}`` mapping for a list of intake sources.

    Each entry is named after the stem of the source's ``urlpath`` and reuses
    the driver, arguments and metadata the source already carries.
    """
    entries = {}
    for source in sources:
        entry_name = PurePath(source.urlpath).stem
        source_yaml = source._yaml()["sources"][source.name]
        entries[entry_name] = LocalCatalogEntry(
            name=entry_name,
            description=source.description if source.description is not None else "",
            driver=source_yaml["driver"],
            args=source_yaml["args"],
            metadata=source.metadata,
            direct_access="allow",
        )
    return entries


def make_local_catalog(
    filenames: List[str],
    filetype: Optional[str] = None,
    name: str = "local_catalog",
    description: str = "Catalog of user files.",
    metadata: Optional[dict] = None,
    metadata_catalog: Optional[dict] = None,
    skip_entry_metadata: bool = False,
    skip_strings: Optional[list] = None,
    kwargs_open: Optional[Dict] = None,
    logger=None,
) -> Catalog:
    """Make an intake catalog from specified data files, including model output locations.

    Pass keywords for opening the files into the catalog through ``kwargs_open``.

    ``kwargs_open`` and ``metadata`` must be the same for all filenames. If they
    are not, make multiple catalogs and input them individually into the run
    command. Neither dictionary is modified by this function.

    Parameters
    ----------
    filenames : list of paths
        Where to find dataset(s) from which to make local catalog.
    filetype : str, optional
        Type of the input filenames, if you don't want the function to try to
        guess. Must be in the form that can go into intake as
        f"open_{filetype}".
    name : str, optional
        Name for catalog.
    description : str, optional
        Description for catalog.
    metadata : dict, optional
        Metadata for individual source. If input dataset does not include the
        longitude and latitude position(s), you will need to include it in the
        metadata as keys `minLongitude`, `minLatitude`, `maxLongitude`,
        `maxLatitude`.
    metadata_catalog : dict, optional
        Metadata for catalog.
    skip_entry_metadata : bool, optional
        This is useful for testing in which case we don't want to actually read
        the file. If you are making a catalog file for a model, you may want to
        set this to `True` to avoid reading it all in for metadata.
    skip_strings : list of strings, optional
        If provided, source_names in catalog will only be checked for goodness
        if they do not contain one of skip_strings. For example, if
        `skip_strings=["_base"]` then any source in the catalog whose name
        contains that string will be skipped.
    kwargs_open : dict, optional
        Keyword arguments to pass on to the appropriate ``intake`` ``open_*``
        call for model or dataset.
    logger : logger, optional
        Logger used to warn when a timezone is removed from a DataFrame index.

    Returns
    -------
    Catalog
        Intake catalog with an entry for each dataset represented by a filename.

    Raises
    ------
    ValueError
        If ``filetype`` is not given and the type of a file cannot be guessed
        from its mimetype or name.

    Examples
    --------

    Make catalog to represent local or remote files with specific locations:

    >>> make_local_catalog([filename1, filename2])

    Make catalog to represent model output:

    >>> make_local_catalog([model output location], skip_entry_metadata=True, kwargs_open={"drop_variables": "tau"})
    """

    # work on copies so that the caller's dictionaries are never modified
    metadata = dict(metadata) if metadata else {}
    metadata_catalog = metadata_catalog or {}
    # if any of kwargs_open came in with "None" instead of None because of CLI,
    # change back to None
    kwargs_open = {
        key: (None if val == "None" else val)
        for key, val in (kwargs_open or {}).items()
    }

    sources = []
    for filename in filenames:
        filename = str(filename)
        mtype = mimetypes.guess_type(filename)[0]

        if filetype is not None:
            source = getattr(intake, f"open_{filetype}")(filename, **kwargs_open)
        elif (
            (mtype is not None and ("csv" in mtype or "text" in mtype))
            or "csv" in filename
            or "text" in filename
        ):
            source = getattr(intake, "open_csv")(filename, csv_kwargs=kwargs_open)
        elif ("thredds" in filename and "dodsC" in filename) or "dods" in filename:
            # use netcdf4 engine if not input in kwargs_open; the default is
            # applied to this call only so it cannot leak into other file types
            opendap_kwargs = {"engine": "netcdf4", **kwargs_open}
            source = getattr(intake, "open_opendap")(filename, **opendap_kwargs)
        elif (
            (mtype is not None and "netcdf" in mtype)
            or "netcdf" in filename
            or ".nc" in filename
        ):
            source = getattr(intake, "open_netcdf")(filename, **kwargs_open)
        else:
            # without this, `source` would be undefined (first file) or silently
            # be the previous file's source (later files)
            raise ValueError(
                f"Could not determine the file type of {filename!r}; "
                "input `filetype` to specify it."
            )

        # combine input metadata with source metadata
        source.metadata.update(metadata)
        sources.append(source)

    # create catalog
    cat = Catalog.from_dict(
        _entries_from_sources(sources),
        name=name,
        description=description,
        metadata=metadata_catalog,
    )

    # now that catalog is made, go through sources and add metadata
    if not skip_entry_metadata:
        bbox_keys = ("minLongitude", "minLatitude", "maxLongitude", "maxLatitude")
        for source_name in list(cat):
            dd = cat[source_name].read()

            # only read lon/lat from file if didn't input lon/lat info
            if cat[source_name].metadata.keys() >= set(bbox_keys):
                given = cat[source_name].metadata
                dd["longitude"] = given["minLongitude"]
                dd["latitude"] = given["minLatitude"]
                cat[source_name].metadata = {key: given[key] for key in bbox_keys}
                # per-source copy of the input metadata, updated with times below
                source_metadata = dict(metadata)
            else:
                source_metadata = {
                    "minLongitude": float(dd.cf["longitude"].min()),
                    "minLatitude": float(dd.cf["latitude"].min()),
                    "maxLongitude": float(dd.cf["longitude"].max()),
                    "maxLatitude": float(dd.cf["latitude"].max()),
                }

            # set up some basic metadata for each source
            if isinstance(dd, pd.DataFrame):
                dd[dd.cf["T"].name] = to_datetime(dd.cf["T"])
                dd.set_index(dd.cf["T"], inplace=True)
                if dd.index.tz is not None:
                    if logger is not None:
                        logger.warning(  # type: ignore
                            "Dataset %s had a timezone %s which is being removed. Make sure the timezone matches the model output.",
                            source_name,
                            str(dd.index.tz),
                        )
                    dd.index = dd.index.tz_convert(None)
                    dd.cf["T"] = dd.index

            source_metadata.update(
                {
                    "minTime": str(dd.cf["T"].values.min()),  # works for df and ds!
                    "maxTime": str(dd.cf["T"].values.max()),  # works for df and ds!
                }
            )

            cat[source_name].metadata.update(source_metadata)
            cat[source_name]._entry._metadata.update(source_metadata)

    # rebuild the catalog so that its entries carry the updated metadata
    cat = Catalog.from_dict(
        _entries_from_sources([cat[source_name] for source_name in list(cat)]),
        name=name,
        description=description,
        metadata=metadata_catalog,
    )

    # this allows for not checking a model catalog
    if not skip_entry_metadata:
        check_catalog(cat, skip_strings=skip_strings)

    return cat
def make_catalog(
    catalog_type: str,
    project_name: str,
    catalog_name: Optional[str] = None,
    description: Optional[str] = None,
    metadata: Optional[dict] = None,
    kwargs: Optional[Dict[str, Any]] = None,
    kwargs_search: Optional[Dict[str, Union[str, int, float]]] = None,
    kwargs_open: Optional[Dict] = None,
    skip_strings: Optional[list] = None,
    vocab: Optional[Union[Vocab, str, PurePath]] = None,
    return_cat: bool = True,
    save_cat: bool = False,
    verbose: bool = True,
    mode: str = "w",
    testing: bool = False,
    cache_dir: Optional[Union[str, PurePath]] = None,
):
    """Make a catalog given input selections.

    The input dictionaries ``kwargs`` and ``kwargs_search`` are copied before
    use, so the caller's objects are not modified.

    Parameters
    ----------
    catalog_type : str
        Which type of catalog to make? Options are "erddap", "axds", or "local".
    project_name : str
        Subdirectory in cache dir to store files associated together.
    catalog_name : str, optional
        Catalog name, with or without suffix of yaml. Otherwise a default name
        based on the catalog type will be used.
    description : str, optional
        Description for catalog.
    metadata : dict, optional
        Catalog metadata.
    kwargs : dict, optional
        Available keyword arguments for catalog types. Find more information
        about options in the original docs for each type. Some inputs might be
        required, depending on the catalog type ("filenames" for "local",
        "server" for "erddap").
    kwargs_search : dict, optional
        Keyword arguments to input to search on the server before making the
        catalog. These are not used with ``make_local_catalog()``; only for
        catalog types "erddap" and "axds". Options are:

        * to search by bounding box: include all of min_lon, max_lon, min_lat,
          max_lat: (int, float). Longitudes must be between -180 to +180.
        * to search within a datetime range: include both of min_time,
          max_time: interpretable datetime string, e.g., "2021-1-1"
        * to search using a textual keyword: include `search_for` as a string.
        * model_name can be input in place of either the spatial box or the
          time range or both in which case those values will be found from the
          model output. model_name should match a catalog file in the directory
          described by project_name.
    kwargs_open : dict, optional
        Keyword arguments to save into local catalog for model to pass on to
        ``xr.open_mfdataset`` call or ``pandas`` ``open_csv``. Only for use with
        ``catalog_type=local``.
    skip_strings : list of strings, optional
        If provided, source_names in catalog will only be checked for goodness
        if they do not contain one of skip_strings. For example, if
        `skip_strings=["_base"]` then any source in the catalog whose name
        contains that string will be skipped.
    vocab : str, Vocab, Path, optional
        Way to find the criteria to use to map from variable to attributes
        describing the variable. This is to be used with a key representing
        what variable to search for.
    return_cat : bool, optional
        Return catalog. For when using as a Python package instead of with
        command line.
    save_cat: bool, optional
        Save catalog to disk into project directory under `catalog_name`.
    verbose : bool, optional
        Print useful runtime commands to stdout if True as well as save in log,
        otherwise silently save in log.
    mode : str, optional
        mode for logging file. Default is to overwrite an existing logfile, but
        can be changed to other modes, e.g. "a" to instead append to an
        existing log file.
    testing : boolean, optional
        Set to True if testing so warnings come through instead of being logged.
    cache_dir: str, Path
        Pass on to omsa.paths to set cache directory location if you don't want
        to use the default. Good for testing.

    Returns
    -------
    Catalog or None
        The catalog if ``return_cat`` is True, otherwise None.

    Raises
    ------
    ValueError
        If ``catalog_type`` is not one of "local", "erddap", "axds"; if a
        required entry of ``kwargs`` is missing; or if ``vocab`` has an
        unsupported type.
    """

    # fail fast, before any directory or log file is set up; otherwise an
    # unknown type would surface later as an undefined `cat`
    if catalog_type not in ("local", "erddap", "axds"):
        raise ValueError(
            f"`catalog_type` must be 'local', 'erddap', or 'axds' but is {catalog_type!r}."
        )

    paths = Paths(project_name, cache_dir=cache_dir)

    logger = set_up_logging(verbose, paths=paths, mode=mode, testing=testing)

    if kwargs_search is not None and catalog_type == "local":
        warnings.warn(
            "`kwargs_search` were input but will not be used since `catalog_type=='local'`.",
            UserWarning,
        )

    if kwargs_open is not None and catalog_type != "local":
        warnings.warn(
            f"`kwargs_open` were input but will not be used since `catalog_type=={catalog_type}`.",
            UserWarning,
        )

    # copies: entries are popped from / added to these below
    kwargs = dict(kwargs) if kwargs else {}
    kwargs_search = dict(kwargs_search) if kwargs_search else {}

    # get spatial and/or temporal search terms from model if desired
    kwargs_search.update({"project_name": project_name})
    if catalog_type != "local":
        kwargs_search = kwargs_search_from_model(kwargs_search, paths)

    if vocab is not None:
        if isinstance(vocab, str):
            vocab = Vocab(paths.VOCAB_PATH(vocab))
        elif isinstance(vocab, PurePath):
            vocab = Vocab(vocab)
        elif isinstance(vocab, Vocab):
            pass
        else:
            raise ValueError("Vocab should be input as string, Path, or Vocab object.")

    if description is None:
        description = f"Catalog of type {catalog_type}."

    if catalog_type == "local":
        catalog_name = "local_cat" if catalog_name is None else catalog_name
        if "filenames" not in kwargs:
            raise ValueError("For `catalog_type=='local'`, must input `filenames`.")
        filenames = kwargs.pop("filenames")
        cat = make_local_catalog(
            astype(filenames, list),
            name=catalog_name,
            description=description,
            metadata=metadata,
            kwargs_open=kwargs_open,
            skip_strings=skip_strings,
            logger=logger,
            **kwargs,
        )

    else:
        if catalog_type == "erddap":
            catalog_name = "erddap_cat" if catalog_name is None else catalog_name
            if "server" not in kwargs:
                raise ValueError("For `catalog_type=='erddap'`, must input `server`.")
            opener = intake.open_erddap_cat
        else:  # "axds"
            catalog_name = "axds_cat" if catalog_name is None else catalog_name
            opener = intake.open_axds_cat

        def _open_remote_catalog():
            """Open the remote catalog with the shared set of arguments."""
            return opener(
                kwargs_search=kwargs_search,
                name=catalog_name,
                description=description,
                metadata=metadata,
                **kwargs,
            )

        if vocab is not None:
            # make the vocab available as custom criteria while searching
            with cfp_set_options(custom_criteria=vocab.vocab):
                cat = _open_remote_catalog()
        else:
            cat = _open_remote_catalog()

    # this allows for not checking a model catalog
    if "skip_entry_metadata" in kwargs and not kwargs["skip_entry_metadata"]:
        check_catalog(cat, skip_strings=skip_strings)

    if save_cat:
        # save cat to file
        cat.save(paths.CAT_PATH(catalog_name))
        logger.info(
            f"Catalog saved to {paths.CAT_PATH(catalog_name)} with {len(list(cat))} entries."
        )

    if return_cat:
        return cat
    return None
def _initial_model_handling(
    model_name: Union[str, Catalog],
    paths: Paths,
    override_chunks: dict,
    model_source_name: Optional[str] = None,
) -> Tuple[xr.Dataset, str]:
    """Initial model handling.

    cf-xarray needs to be able to identify Z, T, longitude, latitude coming out
    of here.

    Parameters
    ----------
    model_name : str, Catalog
        Name of catalog for model output, created with ``make_catalog`` call,
        or Catalog instance.
    paths : Paths
        Paths object for finding paths to use.
    override_chunks : dict
        Override chunks for model output. Might be empty dict.
    model_source_name : str, optional
        Use this to access a specific source in the input model_catalog instead
        of otherwise just using the first source in the catalog.

    Returns
    -------
    dsm : Dataset
        Dataset pointing to model output.
    model_source_name : str
        Name of the catalog source that was read: the input value if one was
        given, otherwise the first source in the catalog.
    """

    # read in model output
    model_cat = open_catalogs(model_name, paths, skip_check=True)[0]
    model_source_name = model_source_name or list(model_cat)[0]
    model_source = model_cat[model_source_name]

    # chunks are only passed to `read` for sources that are not NetCDFSource
    if isinstance(model_source, intake_xarray.netcdf.NetCDFSource):
        dsm = model_source.read()
    else:
        dsm = model_source.read(chunks=override_chunks)

    # the main preprocessing happens later, but do a minimal job here
    # so that cf-xarray can be used hopefully
    dsm = em.preprocess(dsm, kwargs=dict(find_depth_coords=False))

    check_dataset(dsm)

    return dsm, model_source_name
def _narrow_model_time_range(
    dsm: xr.Dataset,
    user_min_time: pd.Timestamp,
    user_max_time: pd.Timestamp,
    model_min_time: pd.Timestamp,
    model_max_time: pd.Timestamp,
    data_min_time: pd.Timestamp,
    data_max_time: pd.Timestamp,
) -> Optional[xr.Dataset]:
    """Narrow the model time range to approximately what is needed, to save memory.

    If user_min_time and user_max_time were input and are not null values and
    are narrower than the model time range, use those to control time range.

    Otherwise use data_min_time and data_max_time to narrow the time range, but
    add 1 model timestep on either end to make sure to have extra model output
    if need to interpolate in that range. If the model has fewer than two time
    steps, no padding is added.

    Do not deal with time in detail here since that will happen when the model
    and data are "aligned" a little later. For now, just return a slice of
    model times, outside of the extract_model code since not interpolating yet.
    The case that data is available before or after the model but overlapping
    is not dealt with.

    Parameters
    ----------
    dsm: xr.Dataset
        model dataset
    user_min_time : pd.Timestamp
        If this is input, it will be used as the min time for the model. At
        this point in the code, it will be a pandas Timestamp though could be
        "NaT" (a null time value).
    user_max_time : pd.Timestamp
        If this is input, it will be used as the max time for the model. At
        this point in the code, it will be a pandas Timestamp though could be
        "NaT" (a null time value).
    model_min_time : pd.Timestamp
        Min model time step
    model_max_time : pd.Timestamp
        Max model time step
    data_min_time : pd.Timestamp
        The min time in the dataset catalog metadata, or if there is a
        constraint in the metadata such as an ERDDAP catalog allows, and it is
        more constrained than data_min_time, then the constraint time.
    data_max_time : pd.Timestamp
        The max time in the dataset catalog metadata, or if there is a
        constraint in the metadata such as an ERDDAP catalog allows, and it is
        more constrained than data_max_time, then the constraint time.

    Returns
    -------
    xr.Dataset or None
        Model dataset, but narrowed in time. None if no model times fall in the
        padded data time range (the data is in a hole in the model output).

    Raises
    ------
    ValueError
        If the user time range is used and selects no model times.
    """

    user_range_within_model = (
        pd.notnull(user_min_time)
        and pd.notnull(user_max_time)
        and (model_min_time.date() <= user_min_time.date())
        and (model_max_time.date() >= user_max_time.date())
    )

    if user_range_within_model:
        dsm2 = dsm.cf.sel(T=slice(user_min_time, user_max_time))
        if dsm2.cf["T"].size == 0:
            raise ValueError(
                "accidentally narrowed time range too far; no more times left!"
            )
        return dsm2

    # calculate delta time for model; only needed for padding the data range,
    # and only defined if there are at least two model times
    model_times = dsm.cf["T"]
    if model_times.size > 1:
        dt = pd.Timestamp(model_times[1].values) - pd.Timestamp(
            model_times[0].values
        )
    else:
        dt = pd.Timedelta(0)

    # always take an extra timestep just in case
    dsm2 = dsm.cf.sel(T=slice(data_min_time - dt, data_max_time + dt))

    if dsm2.cf["T"].size == 0:
        # this means the data is in a hole in the model output
        return None

    return dsm2
def _as_naive_utc_timestamp(value) -> pd.Timestamp:
    """Convert a time-like value to a timezone-naive pandas Timestamp.

    A "Z" in a string is removed since times are assumed to be in UTC. Any
    other timezone information is converted to UTC and then dropped so that
    the result can be compared with timezone-naive times.
    """
    if isinstance(value, str):
        value = value.replace("Z", "")
    stamp = pd.Timestamp(value)
    if stamp.tzinfo is not None:
        stamp = stamp.tz_convert(None)
    return stamp


def _find_data_time_range(cat: Catalog, source_name: str) -> tuple:
    """Determine min and max data times.

    Parameters
    ----------
    cat : Catalog
        Catalog that contains dataset source_name from which to find data time
        range.
    source_name : str
        Name of dataset within cat to examine.

    Returns
    -------
    data_min_time : pd.Timestamp
        The min time in the dataset catalog metadata ("minTime"), falling back
        to "min_time" in the catalog's "kwargs_search" metadata. If there is a
        "time>=" constraint on the source, such as an ERDDAP catalog allows,
        and it is more constrained, then the constraint time. Timezone
        information is removed (times are assumed to be UTC).
    data_max_time : pd.Timestamp
        Same as data_min_time but for "maxTime", "max_time" and the "time<="
        constraint.

    Raises
    ------
    KeyError
        If neither the source metadata nor the catalog "kwargs_search" metadata
        provides a min (or max) time.
    """

    def _time_from_metadata(metadata_key: str, search_key: str, which: str):
        """Look up a time in the source metadata, then in catalog kwargs_search."""
        if metadata_key in cat[source_name].metadata:
            return cat[source_name].metadata[metadata_key]
        # use kwargs_search min/max times if available
        if (
            "kwargs_search" in cat.metadata
            and search_key in cat.metadata["kwargs_search"]
        ):
            return cat.metadata["kwargs_search"][search_key]
        raise KeyError(f"Need a way to input {which} time desired.")

    # Do min and max separately.
    data_min_time = _as_naive_utc_timestamp(
        _time_from_metadata("minTime", "min_time", "min")
    )
    data_max_time = _as_naive_utc_timestamp(
        _time_from_metadata("maxTime", "max_time", "max")
    )

    # take time constraints as min/max if available and more constricting.
    # where the source arguments live depends on intake v1 or v2
    if cat.version == 2:
        source_args = cat[source_name].kwargs
    else:  # intake v1
        source_args = cat[source_name].describe()["args"]
    constraints = source_args["constraints"] if "constraints" in source_args else {}

    if "time>=" in constraints:
        constrained_min_time = _as_naive_utc_timestamp(constraints["time>="])
        if constrained_min_time > data_min_time:
            data_min_time = constrained_min_time

    if "time<=" in constraints:
        constrained_max_time = _as_naive_utc_timestamp(constraints["time<="])
        if constrained_max_time < data_max_time:
            data_max_time = constrained_max_time

    return data_min_time, data_max_time
def _match_depth_sign(dd, model_depth_attr_positive):
    """Match depth sign to model depth direction attribute.

    Parameters
    ----------
    dd : DataFrame, Series, DataArray or Dataset
        Data container in which cf-pandas/cf-xarray can identify "Z".
    model_depth_attr_positive : str
        "up" if model depths are negative below the surface; any other value
        is treated as positive "down".

    Returns
    -------
    Same type as ``dd``
        ``dd`` with its Z values made all non-positive ("up") or all
        non-negative (otherwise). Other container types are returned unchanged.
    """

    depths_are_negative = model_depth_attr_positive == "up"

    if isinstance(dd, (xr.DataArray, xr.Dataset)):
        zname = dd.cf["Z"].name

        # the arithmetic below can drop attrs/encoding, so save and restore them
        attrs = dd[zname].attrs
        has_encoding = hasattr(dd[zname], "encoding")
        if has_encoding:
            encoding = dd[zname].encoding

        # take the magnitude in both cases so that the data sign is always
        # forced to match the model, consistent with the pandas branch
        magnitude = np.absolute(dd.cf["Z"])
        dd[zname] = np.negative(magnitude) if depths_are_negative else magnitude

        dd.cf["Z"].attrs = attrs
        if has_encoding:
            dd.cf["Z"].encoding = encoding

    elif isinstance(dd, (pd.DataFrame, pd.Series)):
        if depths_are_negative:
            dd.cf["Z"] = np.negative(abs(dd.cf["Z"]))
        else:
            dd.cf["Z"] = np.positive(abs(dd.cf["Z"]))

    return dd
def _choose_depths(
    dd: Union[pd.DataFrame, xr.Dataset],
    model_depth_attr_positive: str,
    no_Z: bool,
    want_vertical_interp: bool,
    key_variables: list,
    dam: xr.DataArray,
    mask: xr.DataArray,
    lon: float,
    lat: float,
    logger=None,
) -> tuple:
    """Determine depths to interpolate to, if any.

    This assumes the data container does not have indices, or at least no depth
    indices.

    One of four cases is chosen, in this order:

    1. no "Z" axis in the data or ``no_Z`` is True: no depth is used.
    2. more than one depth, depths not all equal, and ``want_vertical_interp``:
       interpolate vertically to the data depths.
    3. a single depth or all depths equal: use that one depth (nearest model
       depth, no vertical interpolation).
    4. otherwise (depths vary but no vertical interpolation requested): find a
       model depth index ``iZ`` near the mean data depth.

    Parameters
    ----------
    dd: DataFrame or Dataset
        Data container
    model_depth_attr_positive: str
        result of model.cf["Z"].attrs["positive"]: "up" or "down", from model
    no_Z : bool
        If True, set Z=None so no vertical interpolation or selection occurs.
        Do this if your variable has no concept of depth, like the sea surface
        height.
    want_vertical_interp: bool
        This is False unless the user wants to specify that vertical
        interpolation should happen. This is used in only certain cases but in
        those cases it is important so that it is known to interpolate instead
        of try to figure out a vertical level index.
    key_variables: list
        List of key variables in the dataset. This is used to determine if ssh
        is available.
    dam: DataArray
        Model array variable. This is used if we need to find the depth index,
        for example for a mooring that is below the surface.
    mask: DataArray
        Mask for the model output. This is used to find the depth index, for
        example for a mooring that is below the surface.
    lon: float
        Longitude of the data point. This is used to find the depth index, for
        example for a mooring that is below the surface.
    lat: float
        Latitude of the data point. This is used to find the depth index, for
        example for a mooring that is below the surface.
    logger : logger, optional
        Logger for messages.

    Returns
    -------
    dd
        Possibly modified Dataset with sign of depths to match model
    Z
        Depths to interpolate to with sign that matches the model depths, a
        single depth, or None.
    vertical_interp
        Flag, True if we should interpolate vertically, False if not.
    iZ
        Depth index to use for model extraction; None except in case 4.
    """

    Z = None
    iZ = None

    # Catch case in which z column represent fixed depth and SSH is water column above depth
    # such that SSH gives information to z
    dd = _change_z_by_ssh(dd, model_depth_attr_positive, key_variables, logger)

    # sort out depths between model and data
    # 1 location: interpolate or nearest neighbor horizontally
    # have it figure out depth
    if ("Z" not in dd.cf.axes) or no_Z:
        Z = None
        vertical_interp = False
        if logger is not None:
            logger.info(
                f"Will not perform vertical interpolation and there is no concept of depth for this variable."
            )

    # if depth varies in time and will interpolate to match depths
    elif (
        (dd.cf["Z"].size > 1)
        and (dd.cf["Z"] != dd.cf["Z"][0]).any()
        and want_vertical_interp
    ):
        # if the model depths are positive up/negative down, make sure the data match
        dd = _match_depth_sign(dd, model_depth_attr_positive)
        Z = dd.cf["Z"].values
        vertical_interp = True
        if logger is not None:
            logger.info(f"Will perform vertical interpolation, to depths {Z}.")

    # if depth is constant in time and will not interpolate
    elif (dd.cf["Z"].size == 1) or (dd.cf["Z"] == dd.cf["Z"][0]).all():
        if dd.cf["Z"].size == 1:
            Z = float(dd.cf["Z"])
        else:
            Z = float(
                dd.cf["Z"][0]
            )  # do nearest depth to the one depth represented in dataset
        vertical_interp = False
        if logger is not None:
            logger.info(
                f"Will not perform vertical interpolation and will find nearest depth to {Z}."
            )

    # if depth varies in time and need to determine depth index
    else:
        if logger is not None:
            logger.info("Need to find depth index")

        dd = _match_depth_sign(dd, model_depth_attr_positive)

        mean_depth = dd.cf["Z"].mean()
        if "ssh" in key_variables:
            mean_ssh = dd.cf["ssh"].mean()
        else:
            mean_ssh = 0

        # when the depths average to 0 but the mean ssh exceeds 3, the mean
        # ssh is used in place of the mean depth
        # NOTE(review): the threshold 3 is unexplained in the code; confirm intent and units
        if (mean_depth == 0) and (mean_ssh > 3):
            mean = mean_ssh
        else:
            mean = mean_depth

        # find ix, iy for location
        var_out, kwargs = em.sel2dcf(
            dam,
            longitude=lon,
            latitude=lat,
            mask=mask,
            return_info=True,
            use_xoak=False,
        )

        # need xgcm grid to have vertical coords for this
        # index into the vertical coordinate at the horizontal location found
        # above, keeping all leading dimensions
        inds = ..., kwargs[dam.cf["Y"].name], kwargs[dam.cf["X"].name]

        # sample the depth profile at the first, middle and last model times
        mid = int(np.floor(dam.cf["T"].size / 2))
        izs = []
        for i in [0, mid, -1]:
            # NOTE(review): the time dimension name "ocean_time" is hard-coded
            # here; confirm this is acceptable for models that name it differently
            depths = dam.cf["vertical"][inds].squeeze().isel(ocean_time=i).load()

            # find the index of the mean data depth
            iz = depths.values.tolist().index(depths.cf.sel(Z=mean, method="nearest"))
            izs.append(iz)

        # NOTE(review): an original comment states this was not tested here
        # but came from a model caching script that worked
        from scipy.stats import mode

        # use the most common depth index of the sampled times
        iZ = mode(np.array(izs))[0]

        vertical_interp = False
        if logger is not None:
            logger.info(
                f"Will not perform vertical interpolation and will use depth index {iZ}."
            )

    return dd, Z, vertical_interp, iZ
def _dam_from_dsm(
    dsm2: xr.Dataset,
    key_variable: Union[str, dict],
    key_variable_data: str,
    source_metadata: dict,
    no_Z: bool,
    logger=None,
) -> xr.DataArray:
    """Select or calculate variable from Dataset.

    cf-xarray needs to work for Z, T, longitude, latitude after this

    Parameters
    ----------
    dsm2 : Dataset
        Dataset containing model output. If this is being run from `main`, the
        model output has already been narrowed to the relevant time range.
    key_variable : str, dict
        Information to select variable from Dataset. Will be a dict if
        something needs to be calculated or accessed. In that case it must
        have keys "accessor" and "function", naming an attribute of ``dsm2``
        and a function or property on it, and may have "inputs" (keyword
        arguments for the function) and "add_to_inputs" (a key of
        ``source_metadata`` whose value is added to the inputs under the same
        name). The dict is not modified. In the more simple case will be a
        string, and ``key_variable_data`` is used to select the variable.
    key_variable_data : str
        A string containing the key variable name that can be interpreted with
        cf-xarray to access the variable of interest from the Dataset.
    source_metadata : dict
        Metadata for dataset source. Accessed by `cat[source_name].metadata`.
    no_Z : bool
        If True, set Z=None so no vertical interpolation or selection occurs.
        Do this if your variable has no concept of depth, like the sea surface
        height.
    logger : logger, optional
        Logger for messages.

    Returns
    -------
    DataArray:
        Single variable DataArray from Dataset.
    """

    if isinstance(key_variable, dict):
        # copy so the caller's dict is not changed; also covers the case in
        # which "add_to_inputs" is given without any "inputs"
        inputs = dict(key_variable.get("inputs") or {})

        if "add_to_inputs" in key_variable:
            new_input_key = key_variable["add_to_inputs"]
            inputs[new_input_key] = source_metadata[new_input_key]

        # e.g. ds.xroms.east_rotated(angle=-90, reference="compass", isradians=False, name="along_channel")
        function_or_property = getattr(
            getattr(dsm2, key_variable["accessor"]),
            key_variable["function"],
        )

        # if it is a property can't call it like a function
        if isinstance(function_or_property, (property, xr.DataArray)):
            dam = function_or_property
        else:
            dam = function_or_property(**inputs)

    else:
        dam = dsm2.cf[key_variable_data]

    check_dataset(dam, no_Z=no_Z)

    # if dask-backed, read into memory
    if dam.cf["longitude"].chunks is not None:
        dam[dam.cf["longitude"].name] = dam.cf["longitude"].load()
    if dam.cf["latitude"].chunks is not None:
        dam[dam.cf["latitude"].name] = dam.cf["latitude"].load()

    # if vertical isn't present either the variable doesn't have the concept,
    # like ssh, or it is missing
    if "vertical" not in dam.cf.coordinates:
        if logger is not None:
            logger.warning(
                "the 'vertical' key cannot be identified in dam by cf-xarray. Maybe you need to include the xgcm grid and vertical metrics for xgcm grid, but maybe your variable does not have a vertical axis."
            )

    return dam
def _processed_file_names(
    fname_processed_orig: Union[str, pathlib.Path],
    dfd_type: type,
    user_min_time: pd.Timestamp,
    user_max_time: pd.Timestamp,
    paths: Paths,
    ts_mods: list,
    logger=None,
) -> Tuple[pathlib.Path, pathlib.Path, pathlib.Path, pathlib.Path]:
    """Determine file names for base of stats and figure names and processed data and model names

    The names are built in stages:

    * the input name, with the user time range dates appended if both were
      given, placed in the processed cache directory
    * ``fname_processed``: the above plus one ``_{name_mod}`` per entry of
      ``ts_mods``
    * ``fname_processed_data`` / ``fname_processed_model``: ``fname_processed``
      plus "_data" / "_model" and a file suffix
    * ``model_file_name``: the name from the first stage, without ``ts_mods``,
      in the model cache directory

    Parameters
    ----------
    fname_processed_orig : str
        Filename base without modification for user_min_time and
        user_max_time.
    dfd_type : type
        pd.DataFrame or xr.Dataset depending on the data container type.
    user_min_time : pd.Timestamp
        If this is input, it will be used as the min time for the model. At
        this point in the code, it will be a pandas Timestamp though could be
        "NaT" (a null time value).
    user_max_time : pd.Timestamp
        If this is input, it will be used as the max time for the model. At
        this point in the code, it will be a pandas Timestamp though could be
        "NaT" (a null time value).
    paths : Paths
        Paths object for finding paths to use.
    ts_mods : list
        list of time series modifications to apply to data and model. Each
        entry must have a "name_mod" key. Can be an empty list if no
        modifications to apply.
    logger : logger, optional
        Logger for messages.

    Returns
    -------
    tuple of Paths
        * fname_processed: base to be used for stats and figure
        * fname_processed_data: file name for processed data (".csv" for
          DataFrame, ".nc" for Dataset)
        * fname_processed_model: file name for processed model
        * model_file_name: (unprocessed) model output

    Raises
    ------
    TypeError
        If ``dfd_type`` is neither pd.DataFrame nor xr.Dataset.
    """

    # append the user time range to the base name, if there is one
    base_name = fname_processed_orig
    if not (pd.isnull(user_min_time) or pd.isnull(user_max_time)):
        base_name = f"{base_name}_{user_min_time.date()}_{user_max_time.date()}"

    base_path = paths.PROCESSED_CACHE_DIR / base_name
    assert isinstance(base_path, pathlib.Path)

    # also for ts_mods
    mods_tag = "".join(f"_{mod['name_mod']}" for mod in ts_mods)
    fname_processed = base_path.with_name(base_path.stem + mods_tag).with_suffix(
        base_path.suffix
    )

    def _tagged(tag: str, suffix: str) -> pathlib.Path:
        """Sibling of fname_processed with `tag` added to the stem and the given suffix."""
        return (fname_processed.parent / (fname_processed.stem + tag)).with_suffix(
            suffix
        )

    if dfd_type == pd.DataFrame:
        fname_processed_data = _tagged("_data", ".csv")
    elif dfd_type == xr.Dataset:
        fname_processed_data = _tagged("_data", ".nc")
    else:
        raise TypeError("object is neither DataFrame nor Dataset.")

    fname_processed_model = _tagged("_model", ".nc")

    # use same file name as for processed but with different path base and
    # make sure .nc
    model_file_name: pathlib.Path = (paths.MODEL_CACHE_DIR / base_path.stem).with_suffix(
        ".nc"
    )

    if logger is not None:
        logger.info(f"Processed data file name is {fname_processed_data}.")
        logger.info(f"Processed model file name is {fname_processed_model}.")
        logger.info(f"model file name is {model_file_name}.")

    return fname_processed, fname_processed_data, fname_processed_model, model_file_name
def _check_prep_narrow_data(
    dd: Union[pd.DataFrame, xr.Dataset],
    key_variable_data: str,
    source_name: str,
    maps: list,
    vocab: Vocab,
    user_min_time: pd.Timestamp,
    user_max_time: pd.Timestamp,
    data_min_time: pd.Timestamp,
    data_max_time: pd.Timestamp,
    logger=None,
) -> Tuple[Optional[Union[pd.DataFrame, xr.Dataset]], list]:
    """Check, prep, and narrow the data in time range.

    Steps, in order: verify the key variable can be identified; if several
    DataFrame columns match it, keep only the first; strip any timezone from
    a DataFrame's time column; narrow to the user time range when both user
    times are set; and reject the dataset if the key variable is all NaN.

    Parameters
    ----------
    dd : Union[pd.DataFrame, xr.Dataset]
        Dataset. A DataFrame may be modified in place (extra matched columns
        dropped, timezone removed).
    key_variable_data : str
        Name of variable to access from dataset.
    source_name : str
        Name of dataset we are accessing from the catalog.
    maps : list
        Each entry is a list of information about a dataset; the last entry is
        for the present source_name or dataset. Each entry contains
        [min_lon, max_lon, min_lat, max_lat, source_name] and possibly an
        additional element containing "maptype".
    vocab : Vocab
        Way to find the criteria to use to map from variable to attributes
        describing the variable. For xarray input the key variable check is
        only performed when this is not None.
    user_min_time : pd.Timestamp
        If this is input, it will be used as the min time for the model. At
        this point in the code, it will be a pandas Timestamp though could be
        "NaT" (a null time value).
    user_max_time : pd.Timestamp
        If this is input, it will be used as the max time for the model. At
        this point in the code, it will be a pandas Timestamp though could be
        "NaT" (a null time value).
    data_min_time : pd.Timestamp
        The min time in the dataset catalog metadata. Currently unused by this
        function; kept for interface compatibility.
    data_max_time : pd.Timestamp
        The max time in the dataset catalog metadata. Currently unused by this
        function; kept for interface compatibility.
    logger : optional
        Logger for messages. If None, warnings are silently skipped.

    Returns
    -------
    tuple
        * dd: data container that has been checked and processed. Will be None
          if a problem has been detected.
        * maps: list of data information. If there was a problem with this
          dataset, the final entry in `maps` representing the dataset will have
          been deleted.
    """

    missing_msg = f"Key variable {key_variable_data} cannot be identified in dataset {source_name}. Skipping dataset.\n"

    if isinstance(dd, DataFrame) and key_variable_data not in dd.cf:
        if logger is not None:
            logger.warning(missing_msg)
        maps.pop(-1)
        return None, maps

    elif isinstance(dd, (xr.DataArray, xr.Dataset)) and vocab is not None:
        try:
            dd.cf[key_variable_data]
        except KeyError:
            if logger is not None:
                logger.warning(missing_msg)
            maps.pop(-1)
            return None, maps

    # see if more than one column of data is being identified as key_variable_data
    # if more than one, log warning and then choose first
    # variable might be calculated later
    if key_variable_data in dd.cf and isinstance(dd.cf[key_variable_data], DataFrame):
        matched_cols = dd.cf[key_variable_data].columns
        msg = f"More than one variable ({matched_cols}) have been matched to input variable {key_variable_data}. The first {matched_cols[0]} is being selected. To change this, modify the vocabulary so that the two variables are not both matched, or change the input data catalog."
        if logger is not None:
            logger.warning(msg)
        # remove other data columns (in place, as before)
        for col in matched_cols[1:]:
            dd.drop(col, axis=1, inplace=True)

    if isinstance(dd, pd.DataFrame):
        # deal with possible time zone: the comparison assumes naive times
        if dd.cf["T"].dt.tz is not None:
            if logger is not None:
                logger.warning(
                    "Dataset %s had a timezone %s which is being removed. Make sure the timezone matches the model output.",
                    source_name,
                    str(dd.cf["T"].dt.tz),
                )
            # remove time zone
            dd.cf["T"] = dd.cf["T"].dt.tz_convert(None)

    # This is meant to limit the data range when user has input time range
    # for limiting time range of long datasets
    if pd.notnull(user_min_time) and pd.notnull(user_max_time):
        if isinstance(dd, pd.DataFrame):
            # index by time only for the slice; the time column itself is kept
            dd = (
                dd.set_index(dd.cf["T"])
                .loc[user_min_time:user_max_time]
                .reset_index(drop=True)
            )
        elif isinstance(dd, xr.Dataset):
            dd = dd.cf.sel(T=slice(user_min_time, user_max_time))

    # check if all of variable is nan
    # variable might be calculated later
    if key_variable_data in dd.cf and dd.cf[key_variable_data].isnull().all():
        msg = f"All values of key variable {key_variable_data} are nan in dataset {source_name}. Skipping dataset.\n"
        if logger is not None:
            logger.warning(msg)
        maps.pop(-1)
        return None, maps

    return dd, maps
def _check_time_ranges(
    source_name: str,
    data_min_time: pd.Timestamp,
    data_max_time: pd.Timestamp,
    model_min_time: pd.Timestamp,
    model_max_time: pd.Timestamp,
    user_min_time: pd.Timestamp,
    user_max_time: pd.Timestamp,
    maps,
    logger=None,
) -> Tuple[bool, list]:
    """Compare time ranges in case should skip dataset source_name.

    The dataset is skipped when its time range does not intersect the model
    time range, or, when both user times are set, when the user time range
    does not intersect the data or the model time range.

    Parameters
    ----------
    source_name : str
        Name of dataset we are accessing from the catalog.
    data_min_time : pd.Timestamp
        The min time in the dataset catalog metadata, or if there is a
        constraint in the metadata such as an ERDDAP catalog allows, and it is
        more constrained than data_min_time, then the constraint time.
    data_max_time : pd.Timestamp
        The max time in the dataset catalog metadata, or if there is a
        constraint in the metadata such as an ERDDAP catalog allows, and it is
        more constrained than data_max_time, then the constraint time.
    model_min_time : pd.Timestamp
        Min model time step
    model_max_time : pd.Timestamp
        Max model time step
    user_min_time : pd.Timestamp
        If this is input, it will be used as the min time for the model. At
        this point in the code, it will be a pandas Timestamp though could be
        "NaT" (a null time value).
    user_max_time : pd.Timestamp
        If this is input, it will be used as the max time for the model. At
        this point in the code, it will be a pandas Timestamp though could be
        "NaT" (a null time value).
    maps : list
        Each entry is a list of information about a dataset; the last entry is
        for the present source_name or dataset. Each entry contains
        [min_lon, max_lon, min_lat, max_lat, source_name] and possibly an
        additional element containing "maptype".
    logger : logger, optional
        Logger for messages.

    Returns
    -------
    tuple
        * skip_dataset: bool that is True if this dataset should be skipped
        * maps: list of dataset information with the final entry (representing
          the present dataset) removed if skip_dataset is True.
    """

    if logger is not None:
        min_lon, max_lon, min_lat, max_lat = maps[-1][:4]
        logger.info(
            f"""
        User time range: {user_min_time} to {user_max_time}.
        Model time range: {model_min_time} to {model_max_time}.
        Data time range: {data_min_time} to {data_max_time}.
        Data lon range: {min_lon} to {max_lon}.
        Data lat range: {min_lat} to {max_lat}."""
        )

    data_time_range = DateTimeRange(data_min_time, data_max_time)
    model_time_range = DateTimeRange(model_min_time, model_max_time)

    if not data_time_range.is_intersection(model_time_range):
        msg = f"Time range of dataset {source_name} and model output do not overlap. Skipping dataset.\n"
        if logger is not None:
            logger.warning(msg)
        maps.pop(-1)
        return True, maps

    # user range checks only apply when the user supplied both ends
    if pd.notnull(user_min_time) and pd.notnull(user_max_time):
        user_time_range = DateTimeRange(user_min_time, user_max_time)

        if not data_time_range.is_intersection(user_time_range):
            msg = f"Time range of dataset {source_name} and user-input time range do not overlap. Skipping dataset.\n"
            if logger is not None:
                logger.warning(msg)
            maps.pop(-1)
            return True, maps

        # in certain cases, the user input time range might be outside of the
        # model availability
        if not model_time_range.is_intersection(user_time_range):
            if logger is not None:
                logger.warning(
                    "User-input time range is outside of model availability, so moving on..."
                )
            # keep `maps` consistent with the other skip branches and with the
            # documented contract: a skipped dataset has no entry in `maps`
            maps.pop(-1)
            return True, maps

    return False, maps
def _return_p1(
    paths: Paths,
    dsm: xr.Dataset,
    mask: Union[xr.DataArray, None],
    alpha: int,
    dd: int,
    logger=None,
) -> shapely.Polygon:
    """Find and return the model domain boundary.

    If the file ``paths.ALPHA_PATH`` exists, the boundary is read from its
    first line as WKT. Otherwise it is computed with ``find_bbox`` (called
    with ``save=True``).

    Parameters
    ----------
    paths : Paths
        Paths object; ``paths.ALPHA_PATH`` is the boundary file location.
    dsm : xr.Dataset
        Model output, passed to ``find_bbox`` when the boundary must be
        calculated.
    mask : xr.DataArray or None
        Values are 1 for active cells and 0 for inactive grid cells in the
        model dsm. Passed to ``find_bbox``.
    alpha : int
        Number for alphashape to determine what counts as the convex hull.
        Larger number is more detailed, 1 is a good starting point.
    dd : int
        Number to decimate model output lon/lat, as a stride.
    logger : logger, optional
        Logger for messages.

    Returns
    -------
    shapely.Polygon
        Model domain boundary
    """

    boundary_file = paths.ALPHA_PATH

    if boundary_file.is_file():
        if logger is not None:
            logger.info("Using existing numerical domain boundary.")
        with open(boundary_file) as wkt_file:
            boundary_wkt = wkt_file.readlines()[0]
        return shapely.wkt.loads(boundary_wkt)

    # no saved boundary yet: let find_bbox work it out
    bbox_outputs = find_bbox(
        dsm,
        paths=paths,
        mask=mask,
        alpha=alpha,
        dd=dd,
        save=True,
    )
    _, _, _, p1 = bbox_outputs
    if logger is not None:
        logger.info("Calculating numerical domain boundary.")
    return p1
def _return_data_locations(
    maps: list, dd: Union[pd.DataFrame, xr.Dataset], featuretype: str, logger=None
) -> Tuple[Union[float, np.array], Union[float, np.array]]:
    """Return lon, lat locations from dataset.

    A dataset is treated as moving when its bounding box in ``maps[-1]`` is
    not a single point or when ``featuretype`` is "trajectory" or
    "trajectoryProfile"; then the full longitude/latitude arrays are taken
    from ``dd``. Otherwise the single bounding-box location is returned.

    Parameters
    ----------
    maps : list
        Each entry is a list of information about a dataset; the last entry is
        for the present source_name or dataset. Each entry contains
        [min_lon, max_lon, min_lat, max_lat, source_name] and possibly an
        additional element containing "maptype".
    dd : Union[pd.DataFrame, xr.Dataset]
        Dataset; only read (through ``dd.cf``) for moving datasets.
    featuretype : str
        NCEI feature type for dataset
    logger : optional
        Logger for messages.

    Returns
    -------
    tuple
        * lons: float or array of floats
        * lats: float or array of floats
    """

    current_entry = maps[-1][:5]
    min_lon, max_lon, min_lat, max_lat, source_name = current_entry

    is_moving = (
        min_lon != max_lon
        or min_lat != max_lat
        or featuretype == "trajectory"
        or featuretype == "trajectoryProfile"
    )

    # stationary: bounding box is a single point, so min and max coincide
    if not is_moving:
        return min_lon, max_lat

    if logger is not None:
        logger.info(
            f"Source {source_name} is not stationary so using multiple locations."
        )
    lons = dd.cf["longitude"].values
    lats = dd.cf["latitude"].values
    return lons, lats
def _is_outside_boundary(
    p1: shapely.Polygon, lon: float, lat: float, source_name: str, logger=None
) -> bool:
    """Check whether a point is outside the model domain.

    This currently assumes that the dataset is fixed in space.

    Parameters
    ----------
    p1 : shapely.Polygon
        Model domain boundary
    lon : float
        Longitude of point to compare with model domain boundary
    lat : float
        Latitude of point to compare with model domain boundary
    source_name : str
        Name of dataset within cat to examine; used in the warning message.
    logger : optional
        Logger for messages.

    Returns
    -------
    bool
        True if lon, lat point is outside the model domain boundary, otherwise
        False.
    """

    # BUT — might want to just use nearest point so make this optional
    if p1.contains(Point(lon, lat)):
        return False

    msg = f"Dataset {source_name} at lon {lon}, lat {lat} not located within model domain. Skipping dataset.\n"
    if logger is not None:
        logger.warning(msg)
    return True
def _process_model(
    dsm2: xr.Dataset,
    preprocess: bool,
    need_xgcm_grid: bool,
    kwargs_xroms: dict,
    logger=None,
) -> Tuple[xr.Dataset, Grid, bool]:
    """Process model output a second time, possibly.

    Processing only happens when both ``preprocess`` and ``need_xgcm_grid``
    are True. In that case, ROMS output (as guessed by
    ``em.preprocessing.guess_model_type``) is set up with ``xroms`` to obtain
    an xgcm grid, and the dataset is then validated with ``check_dataset``.
    For non-ROMS output no grid is built and ``grid`` is returned as None.

    Parameters
    ----------
    dsm2 : xr.Dataset
        Model output Dataset, already narrowed in time.
    preprocess : bool
        True to preprocess.
    need_xgcm_grid : bool
        True if need to find `xgcm` grid object.
    kwargs_xroms : dict
        Keyword arguments to pass to ``xroms.roms_dataset``. None is treated
        as no keyword arguments.
    logger : optional
        Logger for messages.

    Returns
    -------
    tuple
        * dsm2: Model output, possibly modified
        * grid: xgcm grid object or None
        * preprocessed: bool that is True if model output was processed in
          this function
    """

    # default for every path on which no xgcm grid is built; previously the
    # non-ROMS preprocessing path left `grid` unbound and crashed on return
    grid = None

    if not (preprocess and need_xgcm_grid):
        return dsm2, grid, False

    # process model output without using open_mfdataset
    # vertical coords have been an issue for ROMS and POM, related to dask and
    # OFS models
    if em.preprocessing.guess_model_type(dsm2) == "ROMS":
        # optional dependency, only needed for ROMS output
        import xroms

        if logger is not None:
            logger.info(
                "setting up for model output with xroms, might take a few minutes..."
            )
        dsm2, grid = xroms.roms_dataset(dsm2, **(kwargs_xroms or {}))
        dsm2.xroms.set_grid(grid)

    check_dataset(dsm2)

    # now has been preprocessed
    return dsm2, grid, True
def _return_mask(
    mask: xr.DataArray,
    dsm: xr.Dataset,
    lon_name: str,
    wetdry: bool,
    key_variable_data: str,
    paths: Paths,
    logger=None,
) -> xr.DataArray:
    """Find or calculate and check mask.

    When ``mask`` is None, a cached mask is opened from
    ``paths.MASK_PATH(key_variable_data)`` if that file exists; otherwise the
    mask is computed with ``get_mask`` and written to that location. The
    resulting mask must not contain any NaNs.

    Parameters
    ----------
    mask : xr.DataArray or None
        Values are 1 for active cells and 0 for inactive grid cells in the
        model dsm.
    dsm : xr.Dataset
        Model output Dataset
    lon_name : str
        variable name for longitude in dsm.
    wetdry : bool
        Passed to ``get_mask``. Adjusts the logic in the search for mask such
        that if True, selected mask must include "wetdry" in name and will use
        first time step.
    key_variable_data : str
        Key name of variable
    paths : Paths
        Paths to files and directories for this project.
    logger : optional
        Logger for messages.

    Returns
    -------
    DataArray
        Mask

    Raises
    ------
    ValueError
        If the mask contains NaNs.
    """

    # take out relevant variable and identify mask if available (otherwise None)
    # this mask has to match dam for em.select()
    if mask is None:
        mask_file = paths.MASK_PATH(key_variable_data)
        have_cached_mask = mask_file.is_file()

        if logger is not None:
            if have_cached_mask:
                logger.info(f"Using cached mask from {mask_file}.")
            else:
                logger.info(f"Finding and saving mask to cache to {mask_file}.")

        if have_cached_mask:
            mask = xr.open_dataarray(mask_file)
        else:
            mask = get_mask(dsm, lon_name, wetdry=wetdry)
            assert mask is not None
            mask.to_netcdf(mask_file)

    # there should not be any nans in the mask!
    has_nans = mask.isnull().any()
    if has_nans:
        raise ValueError(
            f"""there are nans in your mask — better fix something.
                         The cached version is at {paths.MASK_PATH(key_variable_data)}.
                         """
        )

    return mask
def _select_process_save_model(
    select_kwargs: dict,
    source_name: str,
    model_source_name: str,
    model_file_name: pathlib.Path,
    save_horizontal_interp_weights: bool,
    key_variable_data: str,
    maps: list,
    paths: Paths,
    logger=None,
) -> Tuple[xr.Dataset, bool, list]:
    """Select model output, process, and save to file

    Note that ``select_kwargs`` is modified in place: the ``"dam"`` entry is
    removed and a ``"triangulation"`` entry is added before calling
    ``em.select()``.

    Parameters
    ----------
    select_kwargs : dict
        Keyword arguments to send to `em.select()` for model extraction. Must
        also contain ``"dam"``, the model variable to select from.
    source_name : str
        Name of dataset within cat to examine.
    model_source_name : str
        Source name for model in the model catalog
    model_file_name : pathlib.Path
        Path to where to save model output
    save_horizontal_interp_weights : bool
        Whether or not to save horizontal interp info like Delaunay
        triangulation to file. Set to False to not save which is useful for
        testing.
    key_variable_data : str
        Name of variable to select, to be interpreted with cf-xarray
    maps : list
        Each entry is a list of information about a dataset; the last entry is
        for the present source_name or dataset. Each entry contains
        [min_lon, max_lon, min_lat, max_lat, source_name] and possibly an
        additional element containing "maptype".
    paths : Paths
        Paths object for finding paths to use.
    logger : logger, optional
        Logger for messages.

    Returns
    -------
    tuple
        * model_var: selected model output
        * skip_dataset: True if we should skip this dataset due to checks in
          this function (nearest model point more than 100 km away, or no
          model times matching the data)
        * maps: Same as input except might be missing final entry if skipping
          this dataset

    Raises
    ------
    KeyError
        If a vertical index is needed but ``dam`` has no "Z" axis.
    """

    import pickle

    dam = select_kwargs.pop("dam")
    skip_dataset = False

    horizontal_interp = select_kwargs["horizontal_interp"]
    use_delaunay = (
        horizontal_interp and select_kwargs["horizontal_interp_code"] == "delaunay"
    )

    # use pickle of triangulation from project dir if available
    tri_name = paths.PROJ_DIR / "tri.pickle"
    tri = None
    if use_delaunay and tri_name.is_file():
        if logger is not None:
            logger.info(
                f"Using previously-calculated Delaunay triangulation located at {tri_name}."
            )
        # NOTE: unpickling runs arbitrary code; this file is expected to have
        # been written by this function into the user's own project directory.
        with open(tri_name, "rb") as handle:
            tri = pickle.load(handle)

    # add tri to select_kwargs to use in em.select
    select_kwargs["triangulation"] = tri

    if logger is not None:
        logger.info(
            f"Selecting model output at locations to match dataset {source_name}."
        )
    model_var, kwargs_out = em.select(dam, **select_kwargs)

    # save pickle of triangulation to project dir
    if use_delaunay and not tri_name.is_file() and save_horizontal_interp_weights:
        with open(tri_name, "wb") as handle:
            pickle.dump(
                kwargs_out["tri"],
                handle,
                protocol=pickle.HIGHEST_PROTOCOL,
            )

    msg = f"""
    Model coordinates found are {model_var.coords}.
    """
    if horizontal_interp:
        msg += f"""
        Interpolation coordinates used for horizontal interpolation are {kwargs_out["interp_coords"]}."""
    else:
        msg += f"""
        Output information from finding nearest neighbors to requested points are {kwargs_out}."""
    if logger is not None:
        logger.info(msg)

    # Use distances from xoak to give context to how far the returned model
    # points might be from the data locations
    if not horizontal_interp:
        distance = kwargs_out["distances"]
        # The stricter threshold has to be tested first: anything over 100 km
        # is also over 5 km, so the reverse order never reached the skip.
        if (distance > 100).any():
            # report the largest distance; there may be more than one location
            max_distance = float(np.asarray(distance).max())
            msg = f"Distance between nearest model location and data location for source {source_name} is over 100 km with a distance of {max_distance}. Skipping dataset.\n"
            if logger is not None:
                logger.warning(msg)
            maps.pop(-1)
            skip_dataset = True
        elif (distance > 5).any():
            max_distance = float(np.asarray(distance).max())
            if logger is not None:
                logger.warning(
                    "Distance between nearest model location and data location for source %s is over 5 km with a distance of %s",
                    source_name,
                    str(max_distance),
                )

    if model_var.cf["T"].size == 0:
        # model output isn't available to match data
        # data must not be in the space/time range of model
        if not skip_dataset:
            # only remove this dataset's entry once
            maps.pop(-1)
        if logger is not None:
            logger.warning(
                "Model output is not present to match dataset %s.",
                source_name,
            )
        skip_dataset = True

    need_z_index = (
        select_kwargs["Z"] is not None and not select_kwargs["vertical_interp"]
    )

    # this is trying to drop z_rho type coordinates to not save an extra time
    # series. do need to use "vertical" here instead of "Z" since "Z" will be
    # s_rho and we want to keep that
    if need_z_index and "vertical" in model_var.cf.coordinates:
        if logger is not None:
            logger.info("Trying to drop vertical coordinates time series")
        if model_var.cf["vertical"].ndim > 2:
            model_var = model_var.drop_vars(model_var.cf["vertical"].name)

    # try rechunking to avoid killing kernel
    if model_var.dims == (model_var.cf["T"].name,):
        # for simple case of only time, just rechunk into pieces if no chunks
        if model_var.chunks == ((model_var.size,),):
            if logger is not None:
                logger.info("Rechunking model output...")
            model_var = model_var.chunk({model_var.cf["T"].name: 1})

    if logger is not None:
        logger.info("Loading model output...")
    model_var = model_var.compute()

    # depths shouldn't need to be saved if interpolated since then will be a
    # dimension
    if need_z_index:
        if "Z" not in dam.cf.axes:
            raise KeyError("Z missing from dam axes")
        # find Z index
        zkey = dam.cf["Z"].name
        iz = list(dam.cf["Z"].values).index(model_var[zkey].values)
        model_var[f"{zkey}_index"] = iz
        # if we chose an index maybe there is no vertical? experimental
        if "vertical" not in model_var.cf:
            model_var[f"{zkey}_index"].attrs["positive"] = dam.cf["vertical"].attrs[
                "positive"
            ]

    if not horizontal_interp:
        if len(distance) > 1:
            # if more than one distance, it is array along time
            model_var["distance"] = (
                model_var.cf["T"].name,
                distance,
            )
        else:
            model_var["distance"] = float(distance)
        model_var["distance"].attrs["units"] = "km"
    else:
        # when lons/lats are function of time, add them back in
        coord_specs = (
            ("longitude", "X", "degrees_east"),
            ("latitude", "Y", "degrees_north"),
        )
        for cf_key, axis, units in coord_specs:
            if cf_key in model_var.cf:
                continue
            requested = select_kwargs[cf_key]
            coord_attrs = dict(axis=axis, units=units, standard_name=cf_key)
            if isinstance(requested, (float, int)):
                coord_name = dam.cf[cf_key].name
                model_var[coord_name] = requested
                model_var[coord_name].attrs = coord_attrs
            elif (
                model_var.ndim == 1
                and len(model_var[model_var.dims[0]]) == requested.size
            ):
                model_var[dam.cf[cf_key].name] = (
                    model_var.dims[0],
                    requested,
                    coord_attrs,
                )

    attrs = {
        "key_variable": key_variable_data,
        "vertical_interp": str(select_kwargs["vertical_interp"]),
        "interpolate_horizontal": str(select_kwargs["horizontal_interp"]),
        "model_source_name": model_source_name,
        "source_name": source_name,
    }
    if horizontal_interp:
        attrs.update(
            {
                "horizontal_interp_code": select_kwargs["horizontal_interp_code"],
            }
        )
    model_var.attrs.update(attrs)

    no_Z = select_kwargs["Z"] is None

    model_var = model_var.cf.guess_coord_axis()
    try:
        check_dataset(model_var, no_Z=no_Z)
    except KeyError:
        # see if I can fix it
        model_var = fix_dataset(model_var, dam)
        check_dataset(model_var, no_Z=no_Z)

    if logger is not None:
        logger.info("Saving model output to file...")
    model_var.to_netcdf(model_file_name)

    return model_var, skip_dataset, maps
def _change_z_by_ssh(dd, model_depth_attr_positive, key_variables, logger):
    """Possibly modify depth by ssh.

    Catch case in which z column represent fixed depth and SSH is water column
    above depth such that SSH gives information to z.

    The depth column is replaced by the mean SSH, and then passed through
    ``_match_depth_sign``, only when all of the following hold: Z has more
    than one value, all Z values equal the first one, "ssh" is in
    ``key_variables``, and the ssh variable is named
    "sea_surface_height_above_sea_level_geoid_local_station_datum".
    Otherwise ``dd`` is returned untouched.

    Parameters
    ----------
    dd : data container
        Dataset with cf accessor; its Z column may be overwritten.
    model_depth_attr_positive : str
        Passed on to ``_match_depth_sign``.
    key_variables : container of str
        Key variables in use; checked for "ssh".
    logger : logger or None
        Logger for messages.

    Returns
    -------
    data container
        ``dd``, possibly with modified depths.
    """

    depths = dd.cf["Z"]
    fixed_depth_series = depths.size > 1 and (depths == depths[0]).all()
    if not fixed_depth_series:
        return dd

    if "ssh" not in key_variables:
        return dd

    station_datum_name = "sea_surface_height_above_sea_level_geoid_local_station_datum"
    if dd.cf["ssh"].name != station_datum_name:
        return dd

    constant_z = dd.cf["Z"][0]
    mean_ssh = dd.cf["ssh"].mean()

    # want to use the ssh as the effective depth
    depth_name = dd.cf["Z"].name
    dd[depth_name] = mean_ssh
    dd = _match_depth_sign(dd, model_depth_attr_positive)

    if logger is not None:
        logger.info(
            f"Changed constant Z {constant_z} to mean SSH value that is measured relative to that fixed depth which gives {dd[dd.cf['Z'].name].mean()} instead."
        )

    return dd
[docs] def run( catalogs: Union[str, Catalog, Sequence], project_name: str, key_variable: Union[str, dict], model_name: Union[str, Catalog], vocabs: Optional[Union[str, Vocab, Sequence, PurePath]] = None, vocab_labels: Optional[Union[str, Path, dict]] = None, ndatasets: Optional[int] = None, kwargs_map: Optional[Dict] = None, verbose: bool = True, mode: str = "w", testing: bool = False, alpha: int = 5, dd: int = 2, preprocess: bool = False, need_xgcm_grid: bool = False, xcmocean_options: Optional[dict] = None, kwargs_xroms: Optional[dict] = None, locstream: bool = True, interpolate_horizontal: bool = True, horizontal_interp_code="delaunay", save_horizontal_interp_weights: bool = True, want_vertical_interp: bool = False, want_locstreamZ: bool = False, extrap: bool = False, model_source_name: Optional[str] = None, override_chunks: Optional[Dict] = None, catalog_source_names=None, user_min_time: Optional[Union[str, pd.Timestamp]] = None, user_max_time: Optional[Union[str, pd.Timestamp]] = None, check_in_boundary: bool = True, tidal_filtering: Optional[Dict[str, bool]] = None, ts_mods: Optional[list] = None, model_only: bool = False, plot_map: bool = True, no_Z: bool = False, skip_mask: bool = False, override_mask_lon: Optional[str] = None, known_model_depth_attr_positive: Optional[str] = None, wetdry: bool = False, plot_count_title: bool = True, cache_dir: Optional[Union[str, PurePath]] = None, return_fig: bool = False, override_model: bool = False, override_processed: bool = False, override_stats: bool = False, override_plot: bool = False, plot_description: Optional[str] = None, kwargs_plot: Optional[Dict] = None, skip_key_variable_check: bool = False, **kwargs, ): """Run the model-data comparison. Note that timezones are assumed to match between the model output and data. To avoid calculating a mask you need to input `skip_mask=True`, `check_in_boundary=False`, and `plot_map=False`. 
Parameters ---------- catalogs : str, list, Catalog Catalog name(s) or list of names, or catalog object or list of catalog objects. Datasets will be accessed from catalog entries. project_name : str Subdirectory in cache dir to store files associated together. key_variable : str, dict Key in vocab(s) representing variable to compare between model and datasets. model_name : str, Catalog Name of catalog for model output, created with ``make_catalog`` call, or Catalog instance. vocabs : str, list, Vocab, PurePath, optional Criteria to use to map from variable to attributes describing the variable. This is to be used with a key representing what variable to search for. This input is for the name of one or more existing vocabularies which are stored in a user application cache. This should be supplied, however it is made optional because it could be provided by setting it outside of the OMSA code. vocab_labels : str, dict, Path, optional Ultimately a dictionary whose keys match the input vocab and values have strings to be used in plot labels, such as "Sea water temperature [C]" for the key "temp". They can be input from a stored file or as a dict. ndatasets : int, optional Max number of datasets from each input catalog to use. kwargs_map : dict, optional Keyword arguments to pass on to ``omsa.plot.map.plot_map`` call. verbose : bool, optional Print useful runtime commands to stdout if True as well as save in log, otherwise silently save in log. mode : str, optional mode for logging file. Default is to overwrite an existing logfile, but can be changed to other modes, e.g. "a" to instead append to an existing log file. testing : boolean, optional Set to True if testing so warnings come through instead of being logged. alpha : int parameter for alphashape. 0 returns qhull, and higher values make a tighter polygon around the points. dd : int number to decimate model points by when calculating model boundary with alphashape. input 1 to not decimate. 
preprocess : bool, optional If True, use function from ``extract_model`` to preprocess model output. need_xgcm_grid: bool If True, try to set up xgcm grid for run, which will be used for the variable calculation for the model. kwargs_xroms : dict Optional keyword arguments to pass to xroms.open_dataset locstream: boolean, optional Which type of interpolation to do, passed to em.select(): * False: 2D array of points with 1 dimension the lons and the other dimension the lats. * True: lons/lats as unstructured coordinate pairs (in xESMF language, LocStream). interpolate_horizontal : bool, optional If True, interpolate horizontally. Otherwise find nearest model points. horizontal_interp_code: str Default "xesmf" to use package ``xESMF`` for horizontal interpolation, which is probably better if you need to interpolate to many points. To use ``xESMF`` you have install it as an optional dependency. Input "tree" to use BallTree to find nearest 3 neighbors and interpolate using barycentric coordinates. This has been tested for interpolating to 3 locations so far. Input "delaunay" to use a delaunay triangulation to find the nearest triangle points and interpolate the same as with "tree" using barycentric coordinates. This should be faster when you have more points to interpolate to, especially if you save and reuse the triangulation. save_horizontal_interp_weights : bool Default True. Whether or not to save horizontal interp info like Delaunay triangulation to file. Set to False to not save which is useful for testing. want_vertical_interp: bool This is False unless the user wants to specify that vertical interpolation should happen. This is used in only certain cases but in those cases it is important so that it is known to interpolate instead of try to figure out a vertical level index (which is not possible currently). want_locstreamZ: bool This is False unless the user wants to specify that advanced indexing should happen in extract_model. 
There is logic built in for cases in which this will always be necessary, but there are cases in which it needs to be specified by the user to occur, such as a mooring at depth, in which case the featureType is "timeSeries" but the model will be interpolated in depth and will need to use advanced indexing to extract a time series ultimately. extrap: bool Passed to `extract_model.select()`. Defaults to False. Pass True to extrapolate outside the model domain. model_source_name : str, optional Use this to access a specific source in the input model_catalog instead of otherwise just using the first source in the catalog. override_chunks : dict, optional If input, override the chunks in the model catalog with this dict. This is useful if you want to change the chunking of the model output to make it more efficient for your use case, or if you need it to be different due to the interpolation that will be occurring for your use case. This is a dict with keys that are the dimension names and values that are the chunk sizes. For example, ``{"time": 1}`` would make the time dimension have chunks of size 1. catalog_source_names : list, optional If input, use these source names for each catalog instead of all of the sources listed in the catalog. user_min_time : str, optional If this is input, it will be used as the min time for the model. user_max_time : str, optional If this is input, it will be used as the max time for the model. check_in_boundary : bool If True, station location will be compared against model domain polygon to check if inside domain. Set to False to skip this check which might be desirable if you want to just compare with the closest model point. tidal_filtering: dict, ``tidal_filtering["model"]=True`` to tidally filter modeling output after em.select() is run, and ``tidal_filtering["data"]=True`` to tidally filter data. ts_mods : list list of time series modifications to apply to data and model. model_only: bool If True, reads in model output and saves to cache, then stops. Default False. 
plot_map : bool If False, don't plot map. no_Z : bool If True, set Z=None so no vertical interpolation or selection occurs. Do this if your variable has no concept of depth, like the sea surface height. skip_mask : bool Allows user to override mask behavior and keep it as None. Good for testing. Default False. Also skips mask in p1 calculation and map plotting if set to False and those are set to True. override_mask_lon : str If input, override the mask lon name with this string. Do this when you know the name of the horizontal grid you want associated with the variable you are using and the code is struggling to figure it out with "lonname". This may occur because key_variable is not a variable in the model output but is calculated. known_model_depth_attr_positive : str If input, override the model depth attribute positive value with this string. This is useful if you know the model depth attribute positive value and it is not available in the cached model output files. wetdry : bool If True, insist that mask used has "wetdry" in the name and then use the first time step of that mask. plot_count_title : bool If True, have a count to match the map of the station number in the title, like "0: [station name]". Otherwise skip count. cache_dir: str, Path Pass on to omsa.paths to set cache directory location if you don't want to use the default. Good for testing. vocab_labels: dict, optional dict with keys that match input vocab for putting labels with units on the plots. User has to make sure they match both the data and model; there is no unit handling. return_fig: bool Set to True to return all outputs from this function. Use for testing. Only works if using a single source. override_model : bool Flag to force-redo model selection. Default False. override_processed : bool Flag to force-redo model and data processing. Default False. override_stats : bool Flag to force-redo stats calculation. Default False. override_plot : bool Flag to force-redo plot. 
If True, only redos plot itself if other files are already available. If False, only redos the plot not the other files. Default False. kwargs_plot : dict to pass to omsa plot selection and then through the omsa plot selection to the subsequent plot itself for source. If you need more fine options, run the run function per source. skip_key_variable_check : bool If True, don't check for key_variable name being in catalog source metadata. """ paths = Paths(project_name, cache_dir=cache_dir) logger = set_up_logging(verbose, paths=paths, mode=mode, testing=testing) logger.info(f"Input parameters: {locals()}") kwargs_map = kwargs_map or {} kwargs_plot = kwargs_plot or {} kwargs_xroms = kwargs_xroms or {} ts_mods = ts_mods or [] override_chunks = override_chunks or {} # add override_plot to kwargs_plot in case the fignames are changed later and should be checked there instead kwargs_plot.update({"override_plot": override_plot}) mask = None # After this, we have a single Vocab object with vocab stored in vocab.vocab if vocabs is not None: vocab = open_vocabs(vocabs, paths) # now we shouldn't need to worry about this for the rest of the run right? cfp_set_options(custom_criteria=vocab.vocab) cfx_set_options(custom_criteria=vocab.vocab) else: vocab = None # After this, we have None or a dict with key, values of vocab keys, string description for plot labels if vocab_labels is not None: vocab_labels = open_vocab_labels(vocab_labels, paths) # Open and check catalogs. cats = open_catalogs(catalogs, paths, skip_strings=["_base", "_all", "_tidecons"]) # Warning about number of datasets ndata = np.sum([len(list(cat)) for cat in cats]) if ndatasets is not None: logger.info( f"Note that we are using {ndatasets} datasets of {ndata} datasets. This might take awhile." ) else: logger.info( f"Note that there are {ndata} datasets to use. This might take awhile." 
) # initialize model Dataset as None to compare with later # don't open model output at all if not needed (because it has already been saved, for example) dsm = None preprocessed = False p1 = None # have to save this because of my poor variable naming at the moment as I make a list possible key_variable_orig = key_variable # loop over catalogs and sources to pull out lon/lat locations for plot maps = [] count = 0 # track datasets since count is used to match on map for cat in cats: logger.info(f"Catalog {cat}.") if catalog_source_names is not None: source_names = catalog_source_names else: source_names = list(cat) for i, source_name in enumerate(source_names[:ndatasets]): skip_dataset = False if ndatasets is None: msg = ( f"\nsource name: {source_name} ({i+1} of {ndata} for catalog {cat}." ) else: msg = f"\nsource name: {source_name} ({i+1} of {ndatasets} for catalog {cat}." logger.info(msg) # this check doesn't work if key_data is a dict since too hard to figure out what to check then # change to iterable key_variable_list = cf_xarray.utils.always_iterable(key_variable_orig) # import pdb; pdb.set_trace() if ( "key_variables" in cat[source_name].metadata and all( [ key not in cat[source_name].metadata["key_variables"] for key in key_variable_list ] ) # and key_variable_list not in cat[source_name].metadata["key_variables"] # and not isinstance(key_variable_list, dict) and all([not isinstance(key, dict) for key in key_variable_list]) and not skip_key_variable_check ): logger.info( f"no `key_variables` key found in source metadata or at least not {key_variable}" ) skip_dataset = True continue min_lon = cat[source_name].metadata["minLongitude"] max_lon = cat[source_name].metadata["maxLongitude"] min_lat = cat[source_name].metadata["minLatitude"] max_lat = cat[source_name].metadata["maxLatitude"] new_map = [min_lon, max_lon, min_lat, max_lat, source_name] # include maptype if available if "maptype" in cat[source_name].metadata: new_map += 
[cat[source_name].metadata["maptype"]] maps.append(new_map) # first loop dsm should be None # this is just a simple connection, no extra processing etc if dsm is None: dsm, model_source_name = _initial_model_handling( model_name, paths, override_chunks, model_source_name ) assert isinstance(model_source_name, str) # for mypy # Determine data min and max times user_min_time, user_max_time = pd.Timestamp(user_min_time), pd.Timestamp( user_max_time ) model_min_time = pd.Timestamp(str(dsm.cf["T"][0].values)) model_max_time = pd.Timestamp(str(dsm.cf["T"][-1].values)) data_min_time, data_max_time = _find_data_time_range(cat, source_name) # skip this dataset if times between data and model don't align skip_dataset, maps = _check_time_ranges( source_name, data_min_time, data_max_time, model_min_time, model_max_time, user_min_time, user_max_time, maps, logger, ) if skip_dataset: continue # key_variable could be a list of strings or dicts and here we loop over them if so obss, models, statss, key_variable_datas = [], [], [], [] for key_variable in key_variable_list: # allow for possibility that key_variable is a dict with more complicated usage than just a string if isinstance(key_variable, dict): key_variable_data = key_variable["data"] else: key_variable_data = key_variable logger.info( f"running {source_name} for key_variable(s) {key_variable_data} from key_variable_list {key_variable_list}\n" ) # # Combine and align the two time series of variable # with cfp_set_options(custom_criteria=vocab.vocab): try: dfd = cat[source_name].read() if isinstance(dfd, pd.DataFrame): dfd = check_dataframe(dfd, no_Z) except requests.exceptions.HTTPError as e: logger.warning(str(e)) msg = f"Data cannot be loaded for dataset {source_name}. Skipping dataset.\n" logger.warning(msg) maps.pop(-1) skip_dataset = True continue except Exception as e: logger.warning(str(e)) msg = f"Data cannot be loaded for dataset {source_name}. 
Skipping dataset.\n" logger.warning(msg) maps.pop(-1) skip_dataset = True continue # check for already-aligned model-data file cat_name = cat.name or cat.metadata["name"] source_name_use = source_name.replace(".", "_").replace(" ", "_") fname_processed_orig = ( f"{cat_name}_{source_name_use}_{key_variable_data}" ) ( fname_processed, fname_processed_data, fname_processed_model, model_file_name, ) = _processed_file_names( fname_processed_orig, type(dfd), user_min_time, user_max_time, paths, ts_mods, logger, ) figname = (paths.OUT_DIR / f"{fname_processed.stem}").with_suffix( ".png" ) # in case there are multiple key_variables in key_variable_list which will be joined # for the figure, renamed including both names # import pdb; pdb.set_trace() if len(key_variable_list) > 1: figname = pathlib.Path( str(figname).replace( key_variable_data, "_".join(key_variable_list) ) ) logger.info(f"Figure name is {figname}.") if figname.is_file() and not override_plot: logger.info(f"plot already exists so skipping dataset.") continue # read in previously-saved processed model output and obs. 
if ( not override_processed and fname_processed_data.is_file() and fname_processed_model.is_file() ): logger.info( "Reading previously-processed model output and data for %s.", source_name, ) obs = read_processed_data_file(fname_processed_data, no_Z) model = read_model_file(fname_processed_model, no_Z, dsm) else: logger.info( "No previously processed model output and data available for %s, so setting up now.", source_name, ) # take out relevant variable and identify mask if available (otherwise None) # this mask has to match dam for em.select() if not skip_mask: # first discover if key_variable_data is either directly in dsm.data_vars # or is a cf-xarray/pandas custom name, to get the long coord name # dsm.cf[key_variable_data] works whether key_variable_data is a # direct variable name or a custom name if override_mask_lon is not None: lonname = override_mask_lon else: try: lonname = dsm.cf[key_variable_data].cf["longitude"].name except KeyError: lonname = dsm.cf.coordinates["longitude"][0] mask = _return_mask( mask, dsm, lonname, wetdry, key_variable_data, paths, logger, ) # if key_variable_data in dsm.data_vars: # lonname = dsm[key_variable_data].cf["longitude"].name # else: # try: # dsm.cf[key_variable_data] # if key_variable_data in dsm.data_vars: # mask = _return_mask( # mask, # dsm, # lonname, # wetdry, # key_variable_data, # paths, # logger, # ) # else: # mask = _return_mask( # mask, # dsm, # dsm.cf.coordinates["longitude"][ # 0 # ], # using the first longitude key is adequate # wetdry, # key_variable_data, # paths, # logger, # ) # I think these should always be true together if skip_mask: assert mask is None # Calculate boundary of model domain to compare with data locations and for map # don't need p1 if check_in_boundary False and plot_map False if (check_in_boundary or plot_map) and p1 is None: p1 = _return_p1(paths, dsm, mask, alpha, dd, logger) # see if data location is inside alphashape-calculated polygon of model domain if check_in_boundary: if 
_is_outside_boundary( p1, min_lon, min_lat, source_name, logger ): maps.pop(-1) continue # Check, prep, and possibly narrow data time range dfd, maps = _check_prep_narrow_data( dfd, key_variable_data, source_name, maps, vocab, user_min_time, user_max_time, data_min_time, data_max_time, logger, ) # if there were any issues in the last function, dfd should be None and we should # skip this dataset if dfd is None: skip_dataset = True continue # Read in model output from cache if possible. if not override_model and model_file_name.is_file(): logger.info("Reading model output from file.") model_var = read_model_file(model_file_name, no_Z, dsm) if not interpolate_horizontal: distance = model_var["distance"] # Update data depth to account for ssh if necessary # this happens in _choose_depths if the model output is read in model_depth_attr_positive = ( dsm[dsm.cf.axes["Z"][0]].attrs["positive"] or known_model_depth_attr_positive ) dfd = _change_z_by_ssh( dfd, model_depth_attr_positive, cat[source_name].metadata["key_variables"], logger, ) # Is this necessary? It removes `s_rho_index` when present which causes an issue # since it is "vertical" for cf # model_var = model_var.cf[key_variable_data] # if model_only: # logger.info("Running model only so moving on to next source...") # continue # have to read in the model output else: # lons, lats might be one location or many lons, lats = _return_data_locations( maps, dfd, cat[source_name].metadata["featuretype"], logger ) # narrow time range to limit how much model output to deal with dsm2 = _narrow_model_time_range( dsm, user_min_time, user_max_time, model_min_time, model_max_time, data_min_time, data_max_time, ) if dsm2 is None: skip_dataset = True logger.info( "Data range was within a gap in the model output, so skipping data source." 
) continue # more processing opportunity and chance to use xroms if needed dsm2, grid, preprocessed = _process_model( dsm2, preprocess, need_xgcm_grid, kwargs_xroms, logger ) # Narrow model from Dataset to DataArray here # key_variable = ["xroms", "ualong", "theta"] # and all necessary steps to get there will happen # key_variable = {"accessor": "xroms", "function": "ualong", "inputs": {"theta": theta}} # # HOW TO GET THETA IN THE DICT? # dam might be a Dataset but it has to be on a single grid, that is, e.g., all variable on the ROMS rho grid. # well, that is only partially true. em.select requires DataArrays for certain operations like vertical # interpolation. dam = _dam_from_dsm( dsm2, key_variable, key_variable_data, cat[source_name].metadata, no_Z, logger, ) # shift if 0 to 360 dam = shift_longitudes(dam) # this is fast if not needed # expand 1D coordinates to 2D, so all models dealt with in OMSA are treated with 2D coords. # if your model is too large to be treated with this way, subset the model first. dam = coords1Dto2D(dam) # this is fast if not needed # if locstreamT then want to keep all the data times (like a CTD transect) # if not, just want the unique values (like a CTD profile) locstreamT = ftconfig[cat[source_name].metadata["featuretype"]][ "locstreamT" ] locstreamZ = ftconfig[cat[source_name].metadata["featuretype"]][ "locstreamZ" ] if locstreamT: T = [pd.Timestamp(date) for date in dfd.cf["T"].values] else: T = [ pd.Timestamp(date) for date in np.unique(dfd.cf["T"].values) ] # Need to have this here because if model file has previously been read in but # aligned file doesn't exist yet, this needs to run to update the sign of the # data depths in certain cases. 
zkeym = dsm.cf.axes["Z"][0] dfd, Z, vertical_interp, iZ = _choose_depths( dfd, dsm[zkeym].attrs["positive"], no_Z, want_vertical_interp, cat[source_name].metadata["key_variables"], dam, mask, lons, lats, logger, ) select_kwargs = dict( dam=dam, longitude=lons, latitude=lats, # T=slice(user_min_time, user_max_time), # T=np.unique(dfd.cf["T"].values), # works for Datasets # T=np.unique(dfd.cf["T"].values).tolist(), # works for DataFrame # T=list(np.unique(dfd.cf["T"].values)), # might work for both # T=[pd.Timestamp(date) for date in np.unique(dfd.cf["T"].values)], T=T, # # works for both # T=None, # changed this because wasn't working with CTD profiles. Time interpolation happens during _align. Z=Z, vertical_interp=vertical_interp, iT=None, iZ=iZ, extrap=extrap, extrap_val=None, locstream=locstream, locstreamT=locstreamT, locstreamZ=locstreamZ or want_locstreamZ, # locstream_dim="z_rho", weights=None, mask=mask, use_xoak=False, horizontal_interp=interpolate_horizontal, horizontal_interp_code=horizontal_interp_code, xgcm_grid=grid, return_info=True, ) model_var, skip_dataset, maps = _select_process_save_model( select_kwargs, source_name, model_source_name, model_file_name, save_horizontal_interp_weights, key_variable_data, maps, paths, logger, ) if skip_dataset: continue if model_only: logger.info("Running model only so moving on to next source...") continue # opportunity to modify time series data # fnamemods = "" from copy import deepcopy ts_mods_copy = deepcopy(ts_mods) # ts_mods_copy = ts_mods.copy() # otherwise you modify ts_mods when adding data for mod in ts_mods_copy: logger.info( f"Apply a time series modification called {mod['function']}." 
) if isinstance(dfd, pd.DataFrame): dfd.set_index(dfd.cf["T"], inplace=True) # this is how you include the dataset in the inputs if ( "include_data" in mod["inputs"] and mod["inputs"]["include_data"] ): mod["inputs"].update({"dd": dfd}) mod["inputs"].pop("include_data") # apply ts_mod to full dataset instead of just one variable since might want # to use more than one of the variables # also need to overwrite Dataset since the shape of the variables might change here dfd = mod["function"](dfd, **mod["inputs"]) # dfd[dfd.cf[key_variable_data].name] = mod["function"]( # dfd.cf[key_variable_data], **mod["inputs"] # ) if isinstance(dfd, pd.DataFrame): if dfd.cf["T"].name in dfd.columns: drop = True else: drop = False dfd = dfd.reset_index(drop=drop) model_var = mod["function"](model_var, **mod["inputs"]) # check model output for nans ind_keep = np.arange(0, model_var.cf["T"].size)[ model_var.cf["T"].notnull() ] if model_var.cf["T"].name in model_var.dims: model_var = model_var.isel({model_var.cf["T"].name: ind_keep}) # there could be a small mismatch in the length of time if times were pulled # out separately if np.unique(model_var.cf["T"]).size != np.unique(dfd.cf["T"]).size: logger.info("Changing the timing of the model or data.") # if model_var.cf["T"].size != np.unique(dfd.cf["T"]).size: # if (isinstance(dfd, pd.DataFrame) and model_var.cf["T"].size != dfd.cf["T"].unique().size) or (isinstance(dfd, xr.Dataset) and model_var.cf["T"].size != dfd.cf["T"].drop_duplicates(dim=dfd.cf["T"].name).size): # if len(model_var.cf["T"]) != len(dfd.cf["T"]): # timeSeries stime = pd.Timestamp( max(dfd.cf["T"].values[0], model_var.cf["T"].values[0]) ) etime = pd.Timestamp( min(dfd.cf["T"].values[-1], model_var.cf["T"].values[-1]) ) if stime != etime: model_var = model_var.cf.sel({"T": slice(stime, etime)}) if isinstance(dfd, pd.DataFrame): dfd = dfd.set_index(dfd.cf["T"].name) dfd = dfd.loc[stime:etime] # interpolate data to model times # Times between data and model should already 
match from em.select # except in the case that model output was cached in convenient time series # in which case the times aren't already matched. For this case, the data # also might be missing the occasional data points, and want # the data index to match the model index since the data resolution might be very high. # get combined index of model and obs to first interpolate then reindex obs to model # otherwise only nan's come through # accounting for known issue for interpolation after sampling if indices changes # https://github.com/pandas-dev/pandas/issues/14297 # this won't run for single ctd profiles if len(dfd.cf["T"].unique()) > 1: model_index = model_var.cf["T"].to_pandas().index model_index.name = dfd.index.name ind = model_index.union(dfd.index) dfd = ( dfd.reindex(ind) .interpolate(method="time", limit=3) .reindex(model_index) ) dfd = dfd.reset_index() elif isinstance(dfd, xr.Dataset): # interpolate data to model times # model_index = model_var.cf["T"].to_pandas().index # ind = model_index.union(dfd.cf["T"].to_pandas().index) dfd = dfd.interp( {dfd.cf["T"].name: model_var.cf["T"].values} ) # dfd = dfd.cf.sel({"T": slice(stime, etime)}) # change names of model to match data so that stats will calculate without adding variables # not necessary if dfd is DataFrame (i think) if isinstance(dfd, (xr.Dataset, xr.DataArray)): rename = {} for model_dim in model_var.squeeze().dims: matching_dim = [ data_dim for data_dim in dfd.dims if dfd[data_dim].size == model_var[model_dim].size ][0] rename.update({model_dim: matching_dim}) # rename = {model_var.cf[key].name: dfd.cf[key].name for key in ["T","Z","latitude","longitude"]} model_var = model_var.rename(rename) # Save processed data and model files save_processed_files( dfd, fname_processed_data, model_var, fname_processed_model ) obs = read_processed_data_file(fname_processed_data, no_Z) model = read_model_file(fname_processed_model, no_Z, dsm) logger.info(f"model file name is {model_file_name}.") if not 
override_model and model_file_name.is_file(): logger.info("Reading model output from file.") model = read_model_file(fname_processed_model, no_Z, dsm) if not interpolate_horizontal: distance = model["distance"] else: raise ValueError( f"If the processed files are available need model file {model_file_name} too." ) # make sure that obs depths were modified if necessary, to match model model_depth_attr_positive = ( known_model_depth_attr_positive or model[model.cf.axes["Z"][0]].attrs["positive"] ) obs = _match_depth_sign(obs, model_depth_attr_positive) if model_only: logger.info("Running model only so moving on to next source...") continue stats_fname = (paths.OUT_DIR / f"{fname_processed.stem}").with_suffix( ".yaml" ) if not override_stats and stats_fname.is_file(): logger.info("Reading from previously-saved stats file.") with open(stats_fname, "r") as stream: stats = yaml.safe_load(stream) else: logger.info(f"Calculating stats for {key_variable_data}.") stats = compute_stats( obs.cf[key_variable_data], model.cf[key_variable_data].squeeze() ) # stats = obs.omsa.compute_stats # add distance in if not interpolate_horizontal: stats["dist"] = float(distance) # save stats save_stats( source_name, stats, key_variable_data, paths, filename=stats_fname, ) logger.info("Saved stats file.") # this conflicts with quiver plots over time... 
# # If stats are 2D, assume I want to actually plot them instead of # # obs and model, so replace obs/model with their respective stats # # I guess this only works if you have # if isinstance(stats["ss"]["value"], (xr.DataArray,xr.Dataset)) and stats["ss"]["value"].ndim > 1: # missing_dims = set(obs.cf.axes) - set(stats["ss"]["value"].cf.axes) # if len(missing_dims) > 0: # stats["ss"]["value"][obs.cf["T"].name] = obs.cf["T"][0] # # this is not a very good solution but it works for now # # check for axes with multiple values and drop the first one # # this messes up the plot title later otherwise # too_many_dict = {key: value for key, value in stats["ss"]["value"].cf.axes.items() if len(value) > 1} # to_drop = list(too_many_dict.values())[0][0] # del stats["ss"]["value"][to_drop] # # use key_variable_data name so plot works # # import pdb; pdb.set_trace() # obs = stats["ss"]["value"].rename(key_variable_data) # # obs = stats["ss"]["value"]#.to_dataset(name=key_variable_data) # # obs = stats["ss"]["value"].rename(key_variable_data).to_dataset(name="skill_score") # model = None # # import pdb; pdb.set_trace() # logger.info("Stats are 2D, so replacing obs with stats and model with None.") # Combine across key_variable in case there was a list of inputs obss.append(obs) models.append(model) statss.append(stats) key_variable_datas.append(key_variable_data) # combine list of outputs in the case there is more than one key variable if len(obss) > 1: # import pdb; pdb.set_trace() # if both key variables are in the dataset both times just take one # or could check to see if both key variables are in the first dataset if obss[0].equals(obss[1]): obs = obss[0] else: if isinstance(obs, xr.Dataset): obs = xr.merge(obss) if models[0].equals(models[1]): model = models[0] else: if isinstance(model, xr.Dataset): model = xr.merge(models) # # assume one key variable in each model output # if all( # [ # len(cf_xarray.accessor._get_all(model, key)) > 0 # for model, key in zip(models, 
key_variable_list) # ] # ): # # if len(cf_xarray.accessor._get_all(models[0], key_variable_list[0])) > 0 and : # model = xr.merge(models) # else: # raise NotImplementedError # leave stats as a list stats = statss # if there was always just one key variable for this run, do nothing since the variables are # already available correctly named else: pass # # currently title is being set in plot.selection # if plot_count_title: # title = f"{count}: {source_name}" # else: # title = f"{source_name}" if ( not skip_dataset and (not figname.is_file() or override_plot) and not model_only ): fig = plot.selection( obs, model, cat[source_name].metadata["featuretype"], key_variable_datas, source_name, stats, figname, plot_description, vocab_labels, xcmocean_options=xcmocean_options, **kwargs_plot, ) msg = f"Made plot for {source_name}\n." logger.info(msg) count += 1 # map of model domain with data locations if plot_map: if len(maps) > 0: try: figname = paths.OUT_DIR / "map.png" plot.map.plot_map(np.asarray(maps), figname, p=p1, **kwargs_map) except ModuleNotFoundError: pass else: logger.warning("Not plotting map since no datasets to plot.") logger.info( "Finished analysis. Find plots, stats summaries, and log in %s.", str(paths.PROJ_DIR), ) # just have option for returning info for testing and if dealing with # a single source if len(maps) == 1 and return_fig: # model output, processed data, processed model, stats, fig return fig
# else: # plt.close(fig)