"""Surface plot."""
import pathlib
from typing import Optional, Union
import cf_pandas
import cf_xarray
import matplotlib as mpl
import matplotlib.pyplot as plt
import numpy as np
import pandas as pd
import xarray as xr
from pandas import DataFrame
from xarray import Dataset
import ocean_model_skill_assessor as omsa
fs = 14
fs_title = 16
[docs]
def plot(
obs: Union[DataFrame, Dataset],
model: Dataset,
xname: str,
yname: str,
zname: str,
suptitle: str,
xlabel: Optional[str] = None,
ylabel: Optional[str] = None,
zlabel: Optional[str] = None,
model_title: str = "Model",
along_transect_distance: bool = False,
plot_on_map: bool = False,
proj=None,
extent=None,
kind="pcolormesh",
nsubplots: int = 3,
figname: Union[str, pathlib.Path] = "figure.png",
dpi: int = 100,
figsize=(15, 6),
return_plot: bool = False,
invert_yaxis: bool = False,
make_Z_negative=None,
**kwargs,
):
"""Plot scatter or surface plot.
For featuretype of trajectoryProfile or timeSeriesProfile.
Parameters
----------
obs: DataFrame, Dataset
Observation time series
mode: Dataset
Model time series to compare against obs
xname : str
Name of variable to plot on x-axis when interpreted with cf-xarray and cf-pandas
yname : str
Name of variable to plot on y-axis when interpreted with cf-xarray and cf-pandas
zname : str
Name of variable to plot with color when interpreted with cf-xarray and cf-pandas
suptitle: str, optional
Title for plot, over all the subplots.
xlabel: str, optional
Label for x-axis.
ylabel: str, optional
Label for y-axis.
zlabel: str, optional
Label for colorbar.
along_transect_distance:
Set to True to calculate the along-transect distance in km from the longitude and latitude, which must be interpretable through cf-pandas or cf-xarray as "longitude" and "latitude".
kind: str
Can be "pcolormesh" for surface plot or "scatter" for scatter plot.
nsubplots : int, optional
Number of subplots. Might always be 3, and that is the default.
figname: str
Filename for figure (as absolute or relative path).
dpi: int, optional
dpi for figure. Default is 100.
figsize : tuple, optional
Figsize to pass to `plt.figure()`. Default is (15,5).
return_plot : bool
If True, return plot. Use for testing.
"""
if "override_plot" in kwargs:
kwargs.pop("override_plot")
# want obs and data as DataFrames
if kind == "scatter":
if isinstance(obs, xr.Dataset):
obs = obs.to_dataframe()
if isinstance(model, xr.Dataset):
model = model.to_dataframe().reset_index()
if nsubplots == 3:
# using .values on obs prevents name clashes for time and depth
model["diff"] = obs.cf[zname].values - model.cf[zname]
# want obs and data as Datasets
elif kind == "pcolormesh":
if isinstance(obs, pd.DataFrame):
obs = obs.to_xarray()
obs = obs.assign_coords(
{obs.cf["T"].name: obs.cf["T"], model.cf["Z"].name: obs.cf["Z"]}
)
if isinstance(model, pd.DataFrame):
model = model.to_xarray()
if nsubplots == 3:
# using .values on obs prevents name clashes for time and depth
model["diff"] = obs.cf[zname].values - model.cf[zname]
# model["diff"] = obs.cf[zname].values - model.cf[zname]
model["diff"].attrs = {}
else:
raise ValueError("`kind` should be scatter or pcolormesh.")
if along_transect_distance:
obs["distance"] = omsa.utils.calculate_distance(
obs.cf["longitude"], obs.cf["latitude"]
)
if isinstance(model, xr.Dataset):
model["distance"] = (
model.cf["T"].name,
omsa.utils.calculate_distance(
model.cf["longitude"], model.cf["latitude"]
),
)
model = model.assign_coords({"distance": model["distance"]})
elif isinstance(model, pd.DataFrame):
model["distance"] = omsa.utils.calculate_distance(
model.cf["longitude"], model.cf["latitude"]
)
# diff = diff.assign_coords({"distance": distance})
# for first two plots
# vmin, vmax, cmap, extend, levels, norm
if nsubplots > 1:
cmap_params = xr.plot.utils._determine_cmap_params(
np.vstack((obs.cf[zname].values, model.cf[zname].values)), robust=True
)
else:
# I need to fix this in cf-xarray but haven't yet so instead this
# workaround for DataArray
# only works if the key is the actual variable
if isinstance(obs, xr.DataArray):
cmap_params = xr.plot.utils._determine_cmap_params(obs.values, robust=True)
else:
cmap_params = xr.plot.utils._determine_cmap_params(
obs.cf[zname].values, robust=True
)
if "vmin" in kwargs:
cmap_params.update({"vmin": kwargs["vmin"]})
if "vmax" in kwargs:
cmap_params.update({"vmax": kwargs["vmax"]})
if nsubplots == 3:
# including `center=0` forces this to return the diverging colormap option
cmap_params_diff = xr.plot.utils._determine_cmap_params(
model["diff"].values, robust=True, center=0
)
# sharex and sharey removed the y ticklabels so don't use.
# maybe don't work with layout="constrained"
if plot_on_map:
if proj is None:
import cartopy
proj = cartopy.crs.Mercator()
subplot_kw = dict(projection=proj, frameon=False)
else:
subplot_kw = {}
if make_Z_negative is not None:
if make_Z_negative == "obs":
if (obs[obs.cf["Z"].notnull()].cf["Z"] > 0).all():
obs[obs.cf["Z"].name] = -obs.cf["Z"]
elif make_Z_negative == "model":
if (model[model.cf["Z"].notnull()].cf["Z"] > 0).all():
model[model.cf["Z"].name] = -model.cf["Z"]
fig, axes = plt.subplots(
1,
nsubplots,
figsize=figsize,
layout="constrained",
subplot_kw=subplot_kw,
)
# sharex=True, sharey=True)
# setup
xarray_kwargs = dict(
add_labels=False,
add_colorbar=False,
)
pandas_kwargs = dict(colorbar=False)
kwargs.update({key: cmap_params.get(key) for key in ["vmin", "vmax", "cmap"]})
if nsubplots == 1:
ax = axes
else:
ax = axes[0]
if plot_on_map:
omsa.plot.map.setup_ax(
ax, left_labels=True, bottom_labels=True, top_labels=False, fontsize=12
)
kwargs["transform"] = omsa.plot.map.pc
if extent is not None:
ax.set_extent(extent)
if kind == "scatter":
obs.plot(
kind=kind,
x=obs.cf[xname].name,
y=obs.cf[yname].name,
c=obs.cf[zname].name,
ax=ax,
**kwargs,
**pandas_kwargs,
)
elif kind == "pcolormesh":
# I need to fix this in cf-xarray but haven't yet so instead this
# workaround for DataArray
# only works if the key is the actual variable
if isinstance(obs, xr.DataArray):
obs.cf.plot.pcolormesh(x=xname, y=yname, ax=ax, **kwargs, **xarray_kwargs)
else:
obs.cf[zname].cf.plot.pcolormesh(
x=xname, y=yname, ax=ax, **kwargs, **xarray_kwargs
)
ax.set_title("Observation", fontsize=fs_title)
ax.set_ylabel(ylabel, fontsize=fs)
ax.set_xlabel(xlabel, fontsize=fs)
ax.tick_params(axis="both", labelsize=fs)
if invert_yaxis:
ax.invert_yaxis()
if nsubplots > 1:
# plot model
if plot_on_map:
omsa.plot.map.setup_ax(
axes[1],
left_labels=False,
bottom_labels=True,
top_labels=False,
fontsize=12,
)
if extent is not None:
axes[1].set_extent(extent)
if kind == "scatter":
model.plot(
kind=kind,
x=model.cf[xname].name,
y=model.cf[yname].name,
c=model.cf[zname].name,
ax=axes[1],
**kwargs,
**pandas_kwargs,
)
elif kind == "pcolormesh":
model.cf[zname].cf.plot.pcolormesh(
x=xname, y=yname, ax=axes[1], **kwargs, **xarray_kwargs
)
axes[1].set_title(model_title, fontsize=fs_title)
axes[1].set_xlabel(xlabel, fontsize=fs)
axes[1].set_ylabel("")
axes[1].set_xlim(axes[0].get_xlim())
axes[1].set_ylim(axes[0].get_ylim())
# save space by not relabeling y axis
axes[1].set_yticklabels("")
axes[1].tick_params(axis="x", labelsize=fs)
if nsubplots > 2:
# plot difference (assume Dataset)
# for last (diff) plot
kwargs.update(
{key: cmap_params_diff.get(key) for key in ["vmin", "vmax", "cmap"]}
)
if plot_on_map:
omsa.plot.map.setup_ax(
axes[2],
left_labels=False,
bottom_labels=True,
top_labels=False,
fontsize=12,
)
if extent is not None:
axes[2].set_extent(extent)
if kind == "scatter":
model.plot(
kind=kind,
x=model.cf[xname].name,
y=model.cf[yname].name,
c="diff",
ax=axes[2],
**kwargs,
**pandas_kwargs,
)
elif kind == "pcolormesh":
model["diff"].cf.plot.pcolormesh(
x=xname, y=yname, ax=axes[2], **kwargs, **xarray_kwargs
)
# CAN SEE 3 PLOTS
axes[2].set_title("Obs - Model", fontsize=fs_title)
axes[2].set_xlabel(xlabel, fontsize=fs)
axes[2].set_ylabel("")
if not plot_on_map:
axes[2].set_xlim(axes[0].get_xlim())
axes[2].set_ylim(axes[0].get_ylim())
# commenting this out 9/13/24 bc it made the third plot not match the first 2
# axes[2].set_ylim(obs.cf[yname].min(), obs.cf[yname].max())
axes[2].set_yticklabels("")
axes[2].tick_params(axis="x", labelsize=fs)
# import pdb; pdb.set_trace()
if nsubplots > 2:
# two colorbars, 1 for obs and model and 1 for diff
# https://matplotlib.org/stable/tutorials/colors/colorbar_only.html#sphx-glr-tutorials-colors-colorbar-only-py
norm = mpl.colors.Normalize(vmin=cmap_params["vmin"], vmax=cmap_params["vmax"])
mappable = mpl.cm.ScalarMappable(norm=norm, cmap=cmap_params["cmap"])
cbar1 = fig.colorbar(
mappable, ax=axes[:2], orientation="horizontal", shrink=0.5
)
cbar1.set_label(zlabel, fontsize=fs)
cbar1.ax.tick_params(axis="both", labelsize=fs)
norm = mpl.colors.Normalize(
vmin=cmap_params_diff["vmin"], vmax=cmap_params_diff["vmax"]
)
mappable = mpl.cm.ScalarMappable(norm=norm, cmap=cmap_params_diff["cmap"])
cbar2 = fig.colorbar(
mappable, ax=axes[2], orientation="horizontal"
) # shrink=0.6)
cbar2.set_label(f"{zlabel} difference", fontsize=fs)
cbar2.ax.tick_params(axis="both", labelsize=fs)
elif nsubplots == 1:
norm = mpl.colors.Normalize(vmin=cmap_params["vmin"], vmax=cmap_params["vmax"])
mappable = mpl.cm.ScalarMappable(norm=norm, cmap=cmap_params["cmap"])
cbar1 = fig.colorbar(mappable, ax=axes, orientation="horizontal", shrink=0.5)
cbar1.set_label(zlabel, fontsize=fs)
cbar1.ax.tick_params(axis="both", labelsize=fs)
fig.suptitle(suptitle, wrap=True, fontsize=fs_title) # , loc="left")
fig.savefig(figname, dpi=dpi) # , bbox_inches="tight")
if return_plot:
return fig