Analytic sampling for conditional posterior instances trained with MDNs. #458

Merged (22 commits) on Jul 19, 2021

Commits
aaf7316
Added analytic sampling of conditional posterior instances trained wi…
jnsbck Mar 29, 2021
0600280
Addressing comments in #458: reformatted code snippets in docstrings. R…
jnsbck Mar 30, 2021
0978d7d
Addressing comments in #458: Sampling now via Cholesky decomposition.
jnsbck Mar 30, 2021
630b7ca
Added Missing bracket.
jnsbck Mar 30, 2021
d74fdae
Intermediate step: sample_mog and log_prob now compatible with MDNPos…
jnsbck Mar 31, 2021
501bf68
Replaced with .
jnsbck Apr 5, 2021
2910bf0
Replaced mulnormpdf with log_prob_mog.
jnsbck Apr 5, 2021
aa4e1c5
Merge branch 'analytic_mdn_conditioning' of https://github.com/JBeckU…
jnsbck Apr 5, 2021
06c3dc5
Conditional sampling and evaluation for MDNs integrated into DirectPo…
jnsbck Apr 18, 2021
383c803
Removed accidentally added files from pull request
jnsbck Apr 18, 2021
a720a93
Small rewrites. Got rid of cholesky decomps.
jnsbck Apr 24, 2021
a0e8800
renamed vars.
jnsbck Apr 24, 2021
660a9ce
Addressing critique: Fixed docstrings. Restructure. Comments. Warnings.
jnsbck May 13, 2021
c7e0904
looking for bug.
jnsbck May 13, 2021
ab20b6b
small fix.
jnsbck May 13, 2021
794a74c
BUGFIX: import of DirectPosterior type caused problems.
jnsbck Jun 13, 2021
293418f
Added unittest. Works, but not sure how to integrate it properly.
jnsbck Jun 13, 2021
449f1b7
fixed docstring formatting. removed obsolete line.
jnsbck Jun 13, 2021
8d63816
ran black, fixed formatting, moved test to linearGaussian_snpe_test.
jnsbck Jul 6, 2021
a39d186
Merge branch 'main' into analytic_mdn_conditioning
jnsbck Jul 18, 2021
43f0c51
forgot arg for snpe_method in test_mdn_conditional_density.
jnsbck Jul 18, 2021
f989443
fixed line that was accidentally overwritten during conflict resolution.
jnsbck Jul 18, 2021
113 changes: 95 additions & 18 deletions sbi/inference/posteriors/direct_posterior.py
@@ -1,13 +1,15 @@
# This file is part of sbi, a toolkit for simulation-based inference. sbi is licensed
# under the Affero General Public License v3, see <https://www.gnu.org/licenses/>.

from typing import Any, Callable, Dict, List, Optional, Union
from typing import Any, Callable, Dict, List, Optional, Union, Tuple

import numpy as np
import torch
from torch import Tensor, log, nn
from warnings import warn

from sbi import utils as utils
from pyknos.mdn.mdn import MultivariateGaussianMDN as mdn
from sbi.inference.posteriors.base_posterior import NeuralPosterior
from sbi.types import ScalarFloat, Shape
from sbi.utils import del_entries
@@ -16,6 +18,7 @@
ensure_theta_batched,
ensure_x_batched,
)
from sbi.utils.conditional_density import extract_and_transform_mog, condition_mog


class DirectPosterior(NeuralPosterior):
@@ -414,17 +417,96 @@ def sample_conditional(
Returns:
Samples from conditional posterior.
"""
if type(self.net._distribution) is mdn:
num_samples = torch.Size(sample_shape).numel()

return super().sample_conditional(
PotentialFunctionProvider(),
sample_shape,
condition,
dims_to_sample,
x,
show_progress_bars,
mcmc_method,
mcmc_parameters,
)
logits, means, precfs, _ = extract_and_transform_mog(self, x)
logits, means, precfs, _ = condition_mog(
self._prior, condition, dims_to_sample, logits, means, precfs
)

# Currently difficult to integrate `sample_posterior_within_prior`
warn(
"Sampling MoG analytically. Some of the samples might not be within the prior support!"
)
samples = mdn.sample_mog(num_samples, logits, means, precfs)
return samples.detach().reshape((*sample_shape, -1))

else:
return super().sample_conditional(
PotentialFunctionProvider(),
sample_shape,
condition,
dims_to_sample,
x,
show_progress_bars,
mcmc_method,
mcmc_parameters,
)

def log_prob_conditional(
self,
theta: Tensor,
condition: Tensor,
dims_to_evaluate: List[int],
x: Optional[Tensor] = None,
) -> Tensor:
"""Evaluates the conditional posterior probability of a MDN at a context x for a value theta given a condition.

This function only works for MDN-based posteriors, because evaluation is done
analytically. For all other density estimators a `NotImplementedError` will be raised!

Args:
theta: Parameters $\theta$.
condition: Parameter set that all dimensions not specified in
`dims_to_sample` will be fixed to. Should contain dim_theta elements,
i.e. it could e.g. be a sample from the posterior distribution.
The entries at all `dims_to_sample` will be ignored.
dims_to_evaluate: Which dimensions to evaluate the sample for.
The dimensions not specified in `dims_to_evaluate` will be fixed to values given in `condition`.
x: Conditioning context for posterior $p(\theta|x)$. If not provided,
fall back onto `x` passed to `set_default_x()`.

Returns:
log_prob: `(len(θ),)`-shaped normalized (!) log posterior probability
$\log p(\theta|x)$ for θ in the support of the prior, -∞ (corresponding
to 0 probability) outside.
"""

if type(self.net._distribution) is mdn:
logits, means, precfs, sumlogdiag = extract_and_transform_mog(self, x)
logits, means, precfs, sumlogdiag = condition_mog(
self._prior, condition, dims_to_evaluate, logits, means, precfs
)

batch_size, dim = theta.shape
prec = precfs.transpose(3, 2) @ precfs

self.net.eval() # leakage correction requires eval mode

if dim != len(dims_to_evaluate):
theta = theta[:, dims_to_evaluate]

# Implementing leakage correction is difficult for conditioned MDNs,
# because samples from self, i.e. the full posterior, are used rather
# than from the new, conditioned posterior.
warn("Probabilities are not adjusted for leakage.")

log_prob = mdn.log_prob_mog(
theta,
logits.repeat(batch_size, 1),
means.repeat(batch_size, 1, 1),
prec.repeat(batch_size, 1, 1, 1),
sumlogdiag.repeat(batch_size, 1),
)

self.net.train(True)
return log_prob.detach()

else:
raise NotImplementedError(
"This functionality is only available for MDN based posteriors."
)

def map(
self,
@@ -509,11 +591,7 @@ class PotentialFunctionProvider:
"""

def __call__(
self,
prior,
posterior_nn: nn.Module,
x: Tensor,
mcmc_method: str,
self, prior, posterior_nn: nn.Module, x: Tensor, mcmc_method: str
) -> Callable:
"""Return potential function.

@@ -549,8 +627,7 @@ def np_potential(self, theta: np.ndarray) -> ScalarFloat:

with torch.set_grad_enabled(False):
target_log_prob = self.posterior_nn.log_prob(
inputs=theta.to(self.x.device),
context=x_repeated,
inputs=theta.to(self.x.device), context=x_repeated
)
is_within_prior = torch.isfinite(self.prior.log_prob(theta))
target_log_prob[~is_within_prior] = -float("Inf")
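For orientation, a minimal usage sketch of the two methods added above. This is not part of the PR: the toy prior, simulator, and training loop are assumptions following the standard sbi workflow, and only the calls to sample_conditional and log_prob_conditional on an MDN-based DirectPosterior exercise the new analytic path.

import torch
from sbi import utils
from sbi.inference import SNPE, prepare_for_sbi, simulate_for_sbi

# Hypothetical toy problem: 3D box-uniform prior, near-identity Gaussian simulator.
prior = utils.BoxUniform(low=-2 * torch.ones(3), high=2 * torch.ones(3))

def simulator(theta):
    return theta + 0.1 * torch.randn_like(theta)

simulator, prior = prepare_for_sbi(simulator, prior)
# An MDN density estimator is required for the analytic conditional path.
inference = SNPE(prior=prior, density_estimator="mdn")
theta, x = simulate_for_sbi(simulator, prior, num_simulations=1000)
density_estimator = inference.append_simulations(theta, x).train()
posterior = inference.build_posterior(density_estimator)

x_o = torch.zeros(1, 3)
# Any parameter vector inside the prior support can serve as the condition.
condition = posterior.sample((1,), x=x_o)

# Dimensions 0 and 1 stay free; dimension 2 is fixed to the value in `condition`.
samples = posterior.sample_conditional(
    (500,), condition=condition, dims_to_sample=[0, 1], x=x_o
)
log_probs = posterior.log_prob_conditional(
    samples, condition=condition, dims_to_evaluate=[0, 1], x=x_o
)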
2 changes: 2 additions & 0 deletions sbi/utils/__init__.py
@@ -2,6 +2,8 @@
from sbi.utils.conditional_density import (
conditional_corrcoeff,
eval_conditional_density,
extract_and_transform_mog,
condition_mog,
)
from sbi.utils.get_nn_models import classifier_nn, likelihood_nn, posterior_nn
from sbi.utils.io import get_data_root, get_log_root, get_project_root
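With these re-exports, the two new helpers can be imported straight from sbi.utils. A short, hypothetical sketch (posterior, x_o, and condition are assumed to come from an MDN-based run as in the sketch above; posterior._prior is a private attribute used here only for illustration):

from sbi.utils import condition_mog, extract_and_transform_mog

# Pull the MoG parameters out of the MDN posterior at context x_o ...
logits, means, precfs, sumlogdiag = extract_and_transform_mog(posterior, context=x_o)

# ... and condition them, keeping dimensions 0 and 1 free.
logits_c, means_c, precfs_c, sumlogdiag_c = condition_mog(
    posterior._prior, condition, [0, 1], logits, means, precfs
)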
138 changes: 137 additions & 1 deletion sbi/utils/conditional_density.py
@@ -1,14 +1,15 @@
# This file is part of sbi, a toolkit for simulation-based inference. sbi is licensed
# under the Affero General Public License v3, see <https://www.gnu.org/licenses/>.


from typing import Any, Callable, List, Optional, Tuple, Union
from warnings import warn

import torch
from torch import Tensor

from pyknos.mdn.mdn import MultivariateGaussianMDN as mdn
from sbi.utils.torchutils import ensure_theta_batched
from sbi.utils.torchutils import BoxUniform


def eval_conditional_density(
@@ -321,3 +322,138 @@ def _normalize_probs(probs: Tensor, limits: Tensor) -> Tensor:
"""
limits_diff = torch.prod(limits[:, 1] - limits[:, 0])
return probs * probs.numel() / limits_diff / torch.sum(probs)


def extract_and_transform_mog(
posterior: "DirectPosterior", context: Tensor = None
) -> Tuple[Tensor, Tensor, Tensor, Tensor]:
"""Extracts the Mixture of Gaussians (MoG) parameters
from an MDN-based DirectPosterior at either the default x or input x.

Args:
posterior: DirectPosterior instance.
context: Conditioning context for posterior $p(\theta|x)$. If not provided,
fall back onto `x` passed to `set_default_x()`.

Returns:
norm_logits: Normalised log weights of the underlying MoG.
(batch_size, n_mixtures)
means_transformed: Recentred and rescaled means of the underlying MoG
(batch_size, n_mixtures, n_dims)
precfs_transformed: Rescaled precision factors of the underlying MoG.
(batch_size, n_mixtures, n_dims, n_dims)
sumlogdiag: Sum of the log of the diagonal of the precision factors
of the new conditional distribution. (batch_size, n_mixtures)
"""

# extract and rescale means, mixture components and covariances
nn = posterior.net
dist = nn._distribution

if context is None:
encoded_x = nn._embedding_net(posterior.default_x)
else:
encoded_x = nn._embedding_net(context)

logits, means, _, sumlogdiag, precfs = dist.get_mixture_components(encoded_x)
norm_logits = logits - torch.logsumexp(logits, dim=-1, keepdim=True)

scale = nn._transform._transforms[0]._scale
shift = nn._transform._transforms[0]._shift

means_transformed = (means - shift) / scale

A = scale * torch.eye(means_transformed.shape[2])
precfs_transformed = A @ precfs

sumlogdiag = torch.sum(
torch.log(torch.diagonal(precfs_transformed, dim1=2, dim2=3)), dim=2
)

return norm_logits, means_transformed, precfs_transformed, sumlogdiag


def condition_mog(
prior: "Prior",
condition: Tensor,
dims: List[int],
logits: Tensor,
means: Tensor,
precfs: Tensor,
) -> Tuple[Tensor, Tensor, Tensor, Tensor]:
"""Finds the conditional distribution p(X|Y) for a GMM.

Args:
prior: Prior distribution. Used to check whether the condition is within its support.
condition: Parameter set that all dimensions not specified in
`dims` will be fixed to. Should contain dim_theta elements,
i.e. it could e.g. be a sample from the posterior distribution.
The entries at all `dims` will be ignored.
dims: Which dimensions to sample from. The dimensions not
specified in `dims` will be fixed to values given in
`condition`.
logits: Log weights of the MoG. (batch_size, n_mixtures)
means: Means of the MoG. (batch_size, n_mixtures, n_dims)
precfs: Precision factors of the MoG.
(batch_size, n_mixtures, n_dims, n_dims)

Raises:
ValueError: The chosen condition is not within the prior support.

Returns:
logits: Log weights of the conditioned MoG. (batch_size, n_mixtures)
means: Means of the conditioned MoG. (batch_size, n_mixtures, n_dims)
precfs_xx: Precision factors of the conditioned MoG.
(batch_size, n_mixtures, n_dims, n_dims)
sumlogdiag: Sum of the log of the diagonal of the precision factors
of the new conditional distribution. (batch_size, n_mixtures)
"""

support = prior.support

n_mixtures, n_dims = means.shape[1:]

mask = torch.zeros(n_dims, dtype=bool)
mask[dims] = True

# check whether the condition is within the prior bounds
if type(prior) is torch.distributions.uniform.Uniform or type(prior) is BoxUniform:
cond_ubound = support.upper_bound[~mask]
cond_lbound = support.lower_bound[~mask]
within_support = torch.logical_and(
cond_lbound <= condition[:, ~mask], cond_ubound >= condition[:, ~mask]
)
if ~torch.all(within_support):
raise ValueError("The chosen condition is not within the prior support.")

y = condition[:, ~mask]
mu_x = means[:, :, mask]
mu_y = means[:, :, ~mask]

precfs_xx = precfs[:, :, mask]
precfs_xx = precfs_xx[:, :, :, mask]
precs_xx = precfs_xx.transpose(3, 2) @ precfs_xx

precfs_yy = precfs[:, :, ~mask]
precfs_yy = precfs_yy[:, :, :, ~mask]
precs_yy = precfs_yy.transpose(3, 2) @ precfs_yy

precs = precfs.transpose(3, 2) @ precfs
precs_xy = precs[:, :, mask]
precs_xy = precs_xy[:, :, :, ~mask]

means = mu_x - (
torch.inverse(precs_xx) @ precs_xy @ (y - mu_y).view(1, n_mixtures, -1, 1)
).view(1, n_mixtures, -1)

diags = torch.diagonal(precfs_yy, dim1=2, dim2=3)
sumlogdiag_yy = torch.sum(torch.log(diags), dim=2)
log_prob = mdn.log_prob_mog(y, torch.zeros((1, 1)), mu_y, precs_yy, sumlogdiag_yy)

# Normalize the mixing coef: p(X|Y) = p(Y,X) / p(Y) using the marginal dist.
new_mcs = torch.exp(logits + log_prob)
new_mcs = new_mcs / new_mcs.sum()
logits = torch.log(new_mcs)

sumlogdiag = torch.sum(torch.log(torch.diagonal(precfs_xx, dim1=2, dim2=3)), dim=2)
return logits, means, precfs_xx, sumlogdiag
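The identity behind condition_mog is the precision-form Gaussian conditional: partitioning the joint precision into blocks P_xx, P_xy, P_yy, the conditional p(x|y) is Gaussian with mean mu_x - P_xx^{-1} P_xy (y - mu_y) and precision P_xx, which is what the `means` and `precfs_xx` computations above apply per mixture component. Below is a small self-contained sanity check of that identity on a single 2D Gaussian; it uses only torch and is independent of the sbi codebase.

import torch

torch.manual_seed(0)

# A random 2D Gaussian with x = dimension 0 and y = dimension 1.
mu = torch.tensor([1.0, -0.5])
A = torch.randn(2, 2)
cov = A @ A.T + torch.eye(2)   # symmetric positive definite covariance
prec = torch.inverse(cov)      # joint precision matrix

y = torch.tensor(0.3)  # value that the second dimension is conditioned on

# Precision-form conditional (as used in condition_mog):
#   mean = mu_x - P_xx^{-1} P_xy (y - mu_y),  variance = P_xx^{-1}
mean_prec_form = mu[0] - (1.0 / prec[0, 0]) * prec[0, 1] * (y - mu[1])
var_prec_form = 1.0 / prec[0, 0]

# Covariance-form conditional (textbook formula) for comparison:
#   mean = mu_x + C_xy C_yy^{-1} (y - mu_y),  variance = C_xx - C_xy C_yy^{-1} C_yx
mean_cov_form = mu[0] + cov[0, 1] / cov[1, 1] * (y - mu[1])
var_cov_form = cov[0, 0] - cov[0, 1] ** 2 / cov[1, 1]

assert torch.allclose(mean_prec_form, mean_cov_form, atol=1e-5)
assert torch.allclose(var_prec_form, var_cov_form, atol=1e-5)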