Skip to content

Commit

Permalink
Merge pull request #126 from dngoldberg/branch_obs_sensitivity
Browse files Browse the repository at this point in the history
i did pytests and all seemed ok
  • Loading branch information
dngoldberg authored May 7, 2024
2 parents ee3d698 + c0b5e6a commit 7171a83
Show file tree
Hide file tree
Showing 8 changed files with 419 additions and 57 deletions.
10 changes: 5 additions & 5 deletions example_cases/ismipc_30x30/ismipc_30x30.sh
Original file line number Diff line number Diff line change
Expand Up @@ -13,8 +13,8 @@ cp output_momsolve/ismipc_U_obs.h5 input/

#Run each phase of the model in turn
RUN_DIR=$FENICS_ICE_BASE_DIR/runs/
python $RUN_DIR/run_inv.py ismipc_30x30.toml
python $RUN_DIR/run_forward.py ismipc_30x30.toml
python $RUN_DIR/run_eigendec.py ismipc_30x30.toml
python $RUN_DIR/run_errorprop.py ismipc_30x30.toml
python $RUN_DIR/run_invsigma.py ismipc_30x30.toml
#python $RUN_DIR/run_inv.py ismipc_30x30.toml
#python $RUN_DIR/run_forward.py ismipc_30x30.toml
#python $RUN_DIR/run_eigendec.py ismipc_30x30.toml
#python $RUN_DIR/run_errorprop.py ismipc_30x30.toml
mpirun -n 4 python $RUN_DIR/run_obs_sens_prop.py ismipc_30x30.toml
13 changes: 10 additions & 3 deletions example_cases/ismipc_30x30/ismipc_30x30.toml
Original file line number Diff line number Diff line change
Expand Up @@ -32,7 +32,7 @@ random_seed = 0
[mesh]

mesh_filename = "ismip_mesh.xml"
periodic_bc = true
periodic_bc = false

[obs]

Expand Down Expand Up @@ -67,7 +67,8 @@ sliding_law = 'linear' #budd, linear
[momsolve.picard_params]
nonlinear_solver = "newton"
[momsolve.picard_params.newton_solver]
linear_solver = "umfpack"
linear_solver = "cg"
preconditioner = "hypre_amg"
maximum_iterations = 200
absolute_tolerance = 1.0e-0
relative_tolerance = 1.0e-3
Expand All @@ -77,7 +78,9 @@ error_on_nonconvergence = false
[momsolve.newton_params]
nonlinear_solver = "newton"
[momsolve.newton_params.newton_solver]
linear_solver = "umfpack"
#linear_solver = "umfpack"
linear_solver = "cg"
preconditioner = "hypre_amg"
maximum_iterations = 25
absolute_tolerance = 1.0e-7
relative_tolerance = 1.0e-8
Expand Down Expand Up @@ -128,6 +131,10 @@ qoi = 'h2' #or 'vaf'
name = "Bottom Edge"
id = 4

[mass_solve]

use_cg_thickness = true

[testing]

expected_init_alpha = 531.6114524861194
Expand Down
43 changes: 40 additions & 3 deletions fenics_ice/config.py
Original file line number Diff line number Diff line change
Expand Up @@ -124,6 +124,19 @@ def parse(self):
except KeyError:
pass

try:
obs_sens_dict = self.config_dict['obs_sens']
except KeyError:
obs_sens_dict = {}
self.obs_sens = ObsSensCfg(**obs_sens_dict)

try:
mass_solve_dict = self.config_dict['mass_solve']
except KeyError:
mass_solve_dict = {}
self.mass_solve = MassSolveCfg(**mass_solve_dict)


def check_dirs(self):
"""
Check input directory exists & create output dir if necessary.
Expand All @@ -138,13 +151,15 @@ def check_dirs(self):
self.time.phase_name,
self.eigendec.phase_name,
self.error_prop.phase_name,
self.inv_sigma.phase_name]
self.inv_sigma.phase_name,
self.obs_sens.phase_name]

ph_suffix = [self.inversion.phase_suffix,
self.time.phase_suffix,
self.eigendec.phase_suffix,
self.error_prop.phase_suffix,
self.inv_sigma.phase_suffix]
self.inv_sigma.phase_suffix,
self.obs_sens.phase_suffix]

for ph, suff in zip(ph_names, ph_suffix):
out_dir = (outdir / ph / suff)
Expand Down Expand Up @@ -253,6 +268,17 @@ class ErrorPropCfg(ConfigPrinter):
phase_name: str = 'error_prop'
phase_suffix: str = ''


@dataclass(frozen=True)
class ObsSensCfg(ConfigPrinter):
"""
Configuration related to observation sensitivities
"""
qoi: str = 'vaf'
phase_name: str = 'obs_sens'
phase_suffix: str = ''


@dataclass(frozen=True)
class SampleCfg(ConfigPrinter):
"""
Expand Down Expand Up @@ -394,6 +420,18 @@ def __post_init__(self):
assert self.min_thickness >= 0.0


@dataclass(frozen=True)
class MassSolveCfg(ConfigPrinter):
"""
Options for mass balance solver
"""

use_cg_thickness: bool = False

def __post_init__(self):
""" """


@dataclass(frozen=True)
class MomsolveCfg(ConfigPrinter):
"""
Expand Down Expand Up @@ -513,7 +551,6 @@ def __post_init__(self):

for fname in fname_default_suff:
self.set_default_filename(fname, fname_default_suff[fname])
#embed()

@dataclass(frozen=True)
class TimeCfg(ConfigPrinter):
Expand Down
1 change: 0 additions & 1 deletion fenics_ice/inout.py
Original file line number Diff line number Diff line change
Expand Up @@ -294,7 +294,6 @@ def write_variable(var, params, name=None, outdir=None, phase_name='', phase_suf
outvar.rename(name, "")
# Prefix the run name
outfname = Path(outdir) / phase_name / phase_suffix / "_".join((params.io.run_name+phase_suffix, name))
#embed()

# Write out output according to user specified format in toml
output_var_format = params.io.output_var_format
Expand Down
74 changes: 41 additions & 33 deletions fenics_ice/model.py
Original file line number Diff line number Diff line change
Expand Up @@ -26,10 +26,43 @@
from pathlib import Path
from numpy.random import randn
import logging
from IPython import embed

log = logging.getLogger("fenics_ice")

# Functions for repeated ungridded interpolation
# TODO - this will not handle extrapolation/missing data
# nicely - unfound simplex are returned '-1' which takes the last
# tri.simplices...

# at the moment i have moved these from vel_obs_from_data, so they
# can be called directly from a run script.
# the ismipc test, which calls this function, still seems to perform fine
# but this refactoring may make things less efficient.
def interp_weights(xy, uv, periodic_bc, d=2):
"""Compute the nearest vertices & weights (for reuse)"""
from scipy.spatial import Delaunay
tri = Delaunay(xy)
simplex = tri.find_simplex(uv)

if not np.all(simplex >= 0):
if not periodic_bc:
log.warning("Some points missing in interpolation "
"of velocity obs to function space.")
else:
log.warning("Some points missing in interpolation "
"of velocity obs to function space.")

vertices = np.take(tri.simplices, simplex, axis=0)
temp = np.take(tri.transform, simplex, axis=0)
delta = uv - temp[:, d]
bary = np.einsum('njk,nk->nj', temp[:, :d, :], delta)
return vertices, np.hstack((bary, 1 - bary.sum(axis=1,
keepdims=True)))

def interpolate(values, vtx, wts):
"""Bilinear interpolation, given vertices & weights above"""
return np.einsum('nj,nj->n', np.take(values, vtx), wts)

class model:
"""
The 'model' object is the core of any fenics_ice simulation. It handles loading input
Expand Down Expand Up @@ -122,7 +155,10 @@ def init_fields_from_data(self):
self.bed = self.field_from_data("bed", self.Q)
self.bmelt = self.field_from_data("bmelt", self.M, default=0.0)
self.smb = self.field_from_data("smb", self.M, default=0.0)
self.H_np = self.field_from_data("thick", self.M, min_val=min_thick)
if (self.params.mass_solve.use_cg_thickness and self.params.mesh.periodic_bc):
self.H_np = self.field_from_data("thick", self.Qp, min_val=min_thick)
else:
self.H_np = self.field_from_data("thick", self.M, min_val=min_thick)

if self.params.melt.use_melt_parameterisation:

Expand Down Expand Up @@ -308,44 +344,16 @@ def vel_obs_from_data(self):
use_cloud_point=self.params.inversion.use_cloud_point_velocities)
else:
inout.read_vel_obs(infile, model=self)
# Functions for repeated ungridded interpolation
# TODO - this will not handle extrapolation/missing data
# nicely - unfound simplex are returned '-1' which takes the last
# tri.simplices...
def interp_weights(xy, uv, d=2):
"""Compute the nearest vertices & weights (for reuse)"""
from scipy.spatial import Delaunay
tri = Delaunay(xy)
simplex = tri.find_simplex(uv)

if not np.all(simplex >= 0):
if not self.params.mesh.periodic_bc:
log.warning("Some points missing in interpolation "
"of velocity obs to function space.")
else:
log.warning("Some points missing in interpolation "
"of velocity obs to function space.")

vertices = np.take(tri.simplices, simplex, axis=0)
temp = np.take(tri.transform, simplex, axis=0)
delta = uv - temp[:, d]
bary = np.einsum('njk,nk->nj', temp[:, :d, :], delta)
return vertices, np.hstack((bary, 1 - bary.sum(axis=1,
keepdims=True)))

def interpolate(values, vtx, wts):
"""Bilinear interpolation, given vertices & weights above"""
return np.einsum('nj,nj->n', np.take(values, vtx), wts)

# Grab coordinates of both Lagrangian & DG function spaces
# and compute (once) the interpolating arrays
Q_coords = self.Q.tabulate_dof_coordinates()
M_coords = self.M.tabulate_dof_coordinates()

vtx_Q, wts_Q = interp_weights(self.vel_obs['uv_comp_pts'],
Q_coords)
Q_coords, self.params.mesh.periodic_bc)
vtx_M, wts_M = interp_weights(self.vel_obs['uv_comp_pts'],
M_coords)
M_coords, self.params.mesh.periodic_bc)

# Define new functions to hold results
self.u_obs_Q = Function(self.Q, name="u_obs")
Expand Down Expand Up @@ -378,7 +386,7 @@ def interpolate(values, vtx, wts):
# We need to do the same as above but for cloud point data
# so we can write out a nicer output in the mesh coordinates
vtx_Q_c, wts_Q_c = interp_weights(self.vel_obs['uv_obs_pts'],
Q_coords)
Q_coords, self.params.mesh.periodic_bc)

# Define new functions to hold results
self.u_cloud_Q = Function(self.Q, name="u_obs_cloud")
Expand Down
56 changes: 44 additions & 12 deletions fenics_ice/solver.py
Original file line number Diff line number Diff line change
Expand Up @@ -29,7 +29,6 @@
import time
import ufl
import weakref
from IPython import embed

log = logging.getLogger("fenics_ice")

Expand Down Expand Up @@ -77,11 +76,26 @@ def interpolation_matrix(x_coords, y_space):
return x_local, P


def Amat_obs_action(P, Rvec, vec_cg, dg_space):
# This function implements the Rvec*P*D action on a P1 function
# where D is a projection into DG space
#
# to be called for each component of velocity
#

test, trial = TestFunction(dg_space), TrialFunction(dg_space)
vec_dg = Function(dg_space)
solve(inner(trial, test) * dx == inner(vec_cg, test) * dx,
vec_dg, solver_parameters={"linear_solver": "lu"})

return Rvec * (P @ vec_dg.vector().get_local())


class ssa_solver:
"""
The ssa_solver object is currently the only kind of fenics_ice 'solver' available.
"""
def __init__(self, model, mixed_space=False):
def __init__(self, model, mixed_space=False, obs_sensitivity=False):

# Enable aggressive compiler options
parameters["form_compiler"]["optimize"] = False
Expand All @@ -93,6 +107,7 @@ def __init__(self, model, mixed_space=False):
self.model.solvers.append(self)
self.params = model.params
self.mixed_space = mixed_space
self.obs_sensitivity = obs_sensitivity

# Mesh/Function Spaces
self.mesh = model.mesh
Expand Down Expand Up @@ -146,10 +161,15 @@ def __init__(self, model, mixed_space=False):
self.U = Function(self.V, name="U")
self.U_np = Function(self.V, name="U_np")
self.Phi = TestFunction(self.V)
self.Ksi = TestFunction(self.M)
self.pTau = TestFunction(self.Qp)

self.trial_H = TrialFunction(self.M)
if not (self.params.mass_solve.use_cg_thickness and self.params.mesh.periodic_bc):
self.trial_H = TrialFunction(self.M)
self.Ksi = TestFunction(self.M)
else:
self.trial_H = TrialFunction(self.Qp)
self.Ksi = TestFunction(self.Qp)


# Facets
self.ff = model.ff
Expand Down Expand Up @@ -607,21 +627,26 @@ def def_thickadv_eq(self):
+ inner(jump(Ksi), jump(0.5 * (dot(U_np, nm) + abs(dot(U_np, nm))) * trial_H))
* dS

# Outflow at boundaries
+ conditional(dot(U_np, nm) > 0, 1.0, 0.0)*inner(Ksi, dot(U_np * trial_H, nm))
* ds

# Inflow at boundaries
+ conditional(dot(U_np, nm) < 0, 1.0, 0.0)*inner(Ksi, dot(U_np * H_init, nm))
* ds

# basal melting
+ bmelt*Ksi*dx

# surface mass balance
- smb*Ksi*dx
)


if not (self.params.mass_solve.use_cg_thickness and self.params.mesh.periodic_bc):
self.thickadv = self.thickadv + (

# Outflow at boundaries
+ conditional(dot(U_np, nm) > 0, 1.0, 0.0)*inner(Ksi, dot(U_np * trial_H, nm))
* ds

# Inflow at boundaries
+ conditional(dot(U_np, nm) < 0, 1.0, 0.0)*inner(Ksi, dot(U_np * H_init, nm))
* ds
)

# # Forward euler
# self.thickadv = (inner(Ksi, ((trial_H - H_np) / dt)) * dx
# - inner(grad(Ksi), U_np * H_np) * dx
Expand Down Expand Up @@ -1303,9 +1328,16 @@ def comp_J_inv(self, verbose=False):
J_v_obs, op=MPI.SUM)
J_v_obs, = J_v_obs

u_std_local = u_std[obs_local]
v_std_local = v_std[obs_local]

self._cached_J_mismatch_data \
= (interp_space,
u_PRP, v_PRP, l_u_obs, l_v_obs, J_u_obs, J_v_obs)
if (self.obs_sensitivity):
self._cached_Amat_vars = \
(P, u_std_local, v_std_local, obs_local, interp_space)

(interp_space,
u_PRP, v_PRP, l_u_obs, l_v_obs, J_u_obs, J_v_obs) = \
self._cached_J_mismatch_data
Expand Down
1 change: 1 addition & 0 deletions runs/run_forward.py
Original file line number Diff line number Diff line change
Expand Up @@ -72,6 +72,7 @@ def run_forward(config_file):

# Run the forward model
Q = slvr.timestep(adjoint_flag=1, qoi_func=qoi_func)

# Run the adjoint model, computing gradient of Qoi w.r.t cntrl
dQ_ts = compute_gradient(Q, cntrl) # Isaac 27

Expand Down
Loading

0 comments on commit 7171a83

Please sign in to comment.