Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Add sa_ct_interp and tracer_ct_interp functions to GSW-Python #185

Open
wants to merge 14 commits into
base: main
Choose a base branch
from
Open
1 change: 1 addition & 0 deletions gsw/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -35,6 +35,7 @@
from ._fixed_wrapped_ufuncs import * # noqa
from .conversions import t90_from_t68
from .geostrophy import * # noqa
from .interpolation import * # noqa
from .stability import * # noqa
from .utility import * # noqa

Expand Down
4 changes: 2 additions & 2 deletions gsw/geostrophy.py
Original file line number Diff line number Diff line change
Expand Up @@ -36,7 +36,7 @@ def geo_strf_dyn_height(SA, CT, p, p_ref=0, axis=0, max_dp=1.0,
If any pressure interval in the input p exceeds max_dp, the dynamic
height will be calculated after interpolating to a grid with this
spacing.
interp_method : string {'pchip', 'linear'}
interp_method : string {'mrst', 'pchip', 'linear'}
Interpolation algorithm.

Returns
Expand All @@ -48,7 +48,7 @@ def geo_strf_dyn_height(SA, CT, p, p_ref=0, axis=0, max_dp=1.0,
in an isobaric surface, relative to the reference surface.

"""
interp_methods = {'pchip' : 2, 'linear' : 1}
interp_methods = {'mrst' : 3, 'pchip' : 2, 'linear' : 1}
if interp_method not in interp_methods:
raise ValueError(f'interp_method must be one of {interp_methods.keys()}')
if SA.shape != CT.shape:
Expand Down
177 changes: 177 additions & 0 deletions gsw/interpolation.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,177 @@
"""
Functions for vertical interpolation.
"""

import numpy as np

from . import _gsw_ufuncs
from ._utilities import indexer, match_args_return

__all__ = ['sa_ct_interp',
'tracer_ct_interp',
]

@match_args_return
def sa_ct_interp(SA, CT, p, p_i, axis=0):
"""
Interpolates vertical casts of values of Absolute Salinity
and Conservative Temperature to the arbitrary pressures p_i.

Parameters
----------
SA : array-like
Absolute Salinity, g/kg
CT : array-like
Conservative Temperature (ITS-90), degrees C
p : array-like
Sea pressure (absolute pressure minus 10.1325 dbar), dbar
p_i : array-like
Sea pressure to interpolate on, dbar
axis : int, optional, default is 0
The index of the pressure dimension in SA and CT.


Returns
-------
SA_i : array
Values of SA interpolated to p_i along the specified axis.
CT_i : array
Values of CT interpolated to p_i along the specified axis.

"""
if SA.shape != CT.shape:
raise ValueError(f'Shapes of SA and CT must match; found {SA.shape} and {CT.shape}')
if p.ndim != p_i.ndim:
raise ValueError(f'p and p_i must have the same number of dimensions;\n'
f' found {p.ndim} versus {p_i.ndim}')
if p.ndim == 1 and SA.ndim > 1:
if len(p) != SA.shape[axis]:
raise ValueError(
f'With 1-D p, len(p) must be SA.shape[axis];\n'
f' found {len(p)} versus {SA.shape[axis]} on specified axis, {axis}'
)
ind = [np.newaxis] * SA.ndim
ind[axis] = slice(None)
p = p[tuple(ind)]
p_i = p_i[tuple(ind)]
elif p.ndim > 1:
if p.shape != SA.shape:
raise ValueError(f'With {p.ndim}-D p, shapes of p and SA must match;\n'
f'found {p.shape} and {SA.shape}')
if any([p.shape[i] != p_i.shape[i] for i in range(p.ndim) if i != axis]):
raise ValueError(f'With {p.ndim}-D p, p and p_i must have the same dimensions outside of axis {axis};\n'
f' found {p.shape} versus {p_i.shape}')
with np.errstate(invalid='ignore'):
# The need for this context seems to be a bug in np.ma.any.
if np.ma.any(np.ma.diff(np.ma.masked_invalid(p_i), axis=axis) <= 0) \
or np.ma.any(np.ma.diff(np.ma.masked_invalid(p), axis=axis) <= 0):
raise ValueError('p and p_i must be increasing along the specified axis')
p = np.broadcast_to(p, SA.shape)
goodmask = ~(np.isnan(SA) | np.isnan(CT) | np.isnan(p))
SA_i = np.empty(p_i.shape, dtype=float)
CT_i = np.empty(p_i.shape, dtype=float)
SA_i.fill(np.nan)
CT_i.fill(np.nan)

try:
order = 'F' if SA.flags.fortran else 'C'
except AttributeError:
order = 'C' # e.g., xarray DataArray doesn't have flags
for ind in indexer(SA.shape, axis, order=order):
# this is needed to support xarray inputs for numpy < 1.23
igood = np.asarray(goodmask[ind])
pgood = p[ind][igood]
pi = p_i[ind]
# There must be at least 2 non-NaN values for interpolation
if len(pgood) > 2:
sa = SA[ind][igood]
ct = CT[ind][igood]
sai, cti = _gsw_ufuncs.sa_ct_interp(sa, ct, pgood, pi)
SA_i[ind] = sai
CT_i[ind] = cti

return (SA_i, CT_i)

@match_args_return
def tracer_ct_interp(tracer, CT, p, p_i, factor=9., axis=0):
"""
Interpolates vertical casts of values of a tracer
and Conservative Temperature to the arbitrary pressures p_i.

Parameters
----------
tracer : array-like
tracer
CT : array-like
Conservative Temperature (ITS-90), degrees C
p : array-like
Sea pressure (absolute pressure minus 10.1325 dbar), dbar
p_i : array-like
Sea pressure to interpolate on, dbar
factor: float, optional, default is 9.
Ratio between the ranges of Conservative Temperature
and tracer in the world ocean.
axis : int, optional, default is 0
The index of the pressure dimension in tracer and CT.


Returns
-------
tracer_i : array
Values of tracer interpolated to p_i along the specified axis.
CT_i : array
Values of CT interpolated to p_i along the specified axis.

"""
if tracer.shape != CT.shape:
raise ValueError(f'Shapes of tracer and CT must match; found {tracer.shape} and {CT.shape}')
if p.ndim != p_i.ndim:
raise ValueError(f'p and p_i must have the same number of dimensions;\n'
f' found {p.ndim} versus {p_i.ndim}')
if p.ndim == 1 and tracer.ndim > 1:
if len(p) != tracer.shape[axis]:
raise ValueError(
f'With 1-D p, len(p) must be tracer.shape[axis];\n'
f' found {len(p)} versus {tracer.shape[axis]} on specified axis, {axis}'
)
ind = [np.newaxis] * tracer.ndim
ind[axis] = slice(None)
p = p[tuple(ind)]
p_i = p_i[tuple(ind)]
elif p.ndim > 1:
if p.shape != tracer.shape:
raise ValueError(f'With {p.ndim}-D p, shapes of p and tracer must match;\n'
f'found {p.shape} and {tracer.shape}')
if any([p.shape[i] != p_i.shape[i] for i in range(p.ndim) if i != axis]):
raise ValueError(f'With {p.ndim}-D p, p and p_i must have the same dimensions outside of axis {axis};\n'
f' found {p.shape} versus {p_i.shape}')
with np.errstate(invalid='ignore'):
# The need for this context seems to be a bug in np.ma.any.
if np.ma.any(np.ma.diff(np.ma.masked_invalid(p_i), axis=axis) <= 0) \
or np.ma.any(np.ma.diff(np.ma.masked_invalid(p), axis=axis) <= 0):
raise ValueError('p and p_i must be increasing along the specified axis')
p = np.broadcast_to(p, tracer.shape)
goodmask = ~(np.isnan(tracer) | np.isnan(CT) | np.isnan(p))
tracer_i = np.empty(p_i.shape, dtype=float)
CT_i = np.empty(p_i.shape, dtype=float)
tracer_i.fill(np.nan)
CT_i.fill(np.nan)

try:
order = 'F' if tracer.flags.fortran else 'C'
except AttributeError:
order = 'C' # e.g., xarray DataArray doesn't have flags
for ind in indexer(tracer.shape, axis, order=order):
# this is needed to support xarray inputs for numpy < 1.23
igood = np.asarray(goodmask[ind])
pgood = p[ind][igood]
pi = p_i[ind]
# There must be at least 2 non-NaN values for interpolation
if len(pgood) > 2:
tr = tracer[ind][igood]
ct = CT[ind][igood]
tri, cti = _gsw_ufuncs.tracer_ct_interp(tr, ct, pgood, pi, factor)
tracer_i[ind] = tri
CT_i[ind] = cti

return (tracer_i, CT_i)
13 changes: 12 additions & 1 deletion gsw/tests/test_geostrophy.py
Original file line number Diff line number Diff line change
@@ -1,7 +1,7 @@
import os

import numpy as np
from numpy.testing import assert_almost_equal, assert_array_equal
from numpy.testing import assert_almost_equal, assert_array_equal, assert_allclose

import gsw
from gsw._utilities import Bunch
Expand Down Expand Up @@ -102,3 +102,14 @@ def test_pz_roundtrip():
zz = gsw.z_from_p(p, 30, 0.5, 0.25)
assert_almost_equal(z, zz)

def test_dyn_height_mrst():
"""
Tests the MRST-PCHIP interpolation method.
"""
p = cv.p_chck_cast
CT = cv.CT_chck_cast
SA = cv.SA_chck_cast
pr = cv.pr
strf = gsw.geo_strf_dyn_height(SA, CT, p, p_ref=pr, interp_method='mrst')

assert_allclose(strf, cv.geo_strf_dyn_height, rtol=0, atol=cv.geo_strf_dyn_height_ca)
29 changes: 29 additions & 0 deletions gsw/tests/test_interpolation.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,29 @@
import os

import numpy as np
from numpy.testing import assert_allclose

import gsw
from gsw._utilities import Bunch

root_path = os.path.abspath(os.path.dirname(__file__))

cv = Bunch(np.load(os.path.join(root_path, 'gsw_cv_v3_0.npz')))

def test_sa_ct_interp():
p = cv.p_chck_cast
CT = cv.CT_chck_cast
SA = cv.SA_chck_cast
p_i = np.repeat(cv.p_i[:, np.newaxis], p.shape[1], axis=1)
SA_i, CT_i = gsw.sa_ct_interp(SA, CT, p, p_i)
assert_allclose(SA_i, cv.SAi_SACTinterp, rtol=0, atol=cv.SAi_SACTinterp_ca)
assert_allclose(CT_i, cv.CTi_SACTinterp, rtol=0, atol=cv.CTi_SACTinterp_ca)

def test_tracer_ct_interp():
p = cv.p_chck_cast
CT = cv.CT_chck_cast
tracer = cv.SA_chck_cast
p_i = np.repeat(cv.p_i[:, np.newaxis], p.shape[1], axis=1)
tracer_i, CT_i = gsw.tracer_ct_interp(tracer, CT, p, p_i)
assert_allclose(tracer_i, cv.traceri_tracerCTinterp, rtol=0, atol=cv.traceri_tracerCTinterp_ca)
assert_allclose(CT_i, cv.CTi_SACTinterp, rtol=0, atol=cv.CTi_SACTinterp_ca)
Loading
Loading