diff --git a/sbi/inference/posteriors/direct_posterior.py b/sbi/inference/posteriors/direct_posterior.py
index 43b7f56cd..d91afd752 100644
--- a/sbi/inference/posteriors/direct_posterior.py
+++ b/sbi/inference/posteriors/direct_posterior.py
@@ -1,16 +1,18 @@
 # This file is part of sbi, a toolkit for simulation-based inference. sbi is licensed
 # under the Affero General Public License v3, see <https://www.gnu.org/licenses/>.
+
 from functools import partial
-from typing import Any, Callable, Dict, List, Optional, Union
+from typing import Any, Callable, Dict, List, Optional, Tuple, Union
 from warnings import warn

 import numpy as np
 import torch
 import torch.distributions.transforms as torch_tf
 from torch import Tensor, log, nn

 from sbi import utils as utils
+from pyknos.mdn.mdn import MultivariateGaussianMDN as mdn
 from sbi.inference.posteriors.base_posterior import NeuralPosterior
 from sbi.types import Shape
 from sbi.utils import del_entries, mcmc_transform, rejection_sample, within_support
@@ -19,6 +22,7 @@
     batched_first_of_batch,
     ensure_theta_batched,
 )
+from sbi.utils.conditional_density import extract_and_transform_mog, condition_mog


 class DirectPosterior(NeuralPosterior):
@@ -471,19 +475,101 @@ def sample_conditional(
         Returns:
             Samples from conditional posterior.
         """
+        if type(self.net._distribution) is mdn:
+            num_samples = torch.Size(sample_shape).numel()

-        return super().sample_conditional(
-            PotentialFunctionProvider(),
-            sample_shape,
-            condition,
-            dims_to_sample,
-            x,
-            sample_with,
-            show_progress_bars,
-            mcmc_method,
-            mcmc_parameters,
-            rejection_sampling_parameters,
-        )
+            logits, means, precfs, _ = extract_and_transform_mog(self, x)
+            logits, means, precfs, _ = condition_mog(
+                self._prior, condition, dims_to_sample, logits, means, precfs
+            )
+
+            # Currently difficult to integrate `sample_posterior_within_prior`.
+            warn(
+                "Sampling MoG analytically. "
+                "Some of the samples might not be within the prior support!"
+            )
+            samples = mdn.sample_mog(num_samples, logits, means, precfs)
+            return samples.detach().reshape((*sample_shape, -1))
+
+        else:
+            return super().sample_conditional(
+                PotentialFunctionProvider(),
+                sample_shape,
+                condition,
+                dims_to_sample,
+                x,
+                sample_with,
+                show_progress_bars,
+                mcmc_method,
+                mcmc_parameters,
+                rejection_sampling_parameters,
+            )
+
+    def log_prob_conditional(
+        self,
+        theta: Tensor,
+        condition: Tensor,
+        dims_to_evaluate: List[int],
+        x: Optional[Tensor] = None,
+    ) -> Tensor:
+        """Evaluates the conditional posterior probability of an MDN at context `x`
+        for a value `theta`, given a condition.
+
+        This function only works for MDN-based posteriors, because evaluation is done
+        analytically. For all other density estimators a `NotImplementedError` will be
+        raised!
+
+        Args:
+            theta: Parameters $\theta$.
+            condition: Parameter set that all dimensions not specified in
+                `dims_to_evaluate` will be fixed to. Should contain dim_theta elements,
+                i.e. it could e.g. be a sample from the posterior distribution.
+                The entries at all `dims_to_evaluate` will be ignored.
+            dims_to_evaluate: Which dimensions to evaluate the sample for. The
+                dimensions not specified in `dims_to_evaluate` will be fixed to the
+                values given in `condition`.
+            x: Conditioning context for posterior $p(\theta|x)$. If not provided,
+                fall back onto `x` passed to `set_default_x()`.
+
+        Returns:
+            log_prob: `(len(θ),)`-shaped log posterior probability $\log p(\theta|x)$
+                for θ in the support of the prior, -∞ (corresponding to 0 probability)
+                outside.
+ """ + + if type(self.net._distribution) == mdn: + logits, means, precfs, sumlogdiag = extract_and_transform_mog(self, x) + logits, means, precfs, sumlogdiag = condition_mog( + self._prior, condition, dims_to_evaluate, logits, means, precfs + ) + + batch_size, dim = theta.shape + prec = precfs.transpose(3, 2) @ precfs + + self.net.eval() # leakage correction requires eval mode + + if dim != len(dims_to_evaluate): + X = X[:, dims_to_evaluate] + + # Implementing leakage correction is difficult for conditioned MDNs, + # because samples from self i.e. the full posterior are used rather + # then from the new, conditioned posterior. + warn("Probabilities are not adjusted for leakage.") + + log_prob = mdn.log_prob_mog( + theta, + logits.repeat(batch_size, 1), + means.repeat(batch_size, 1, 1), + prec.repeat(batch_size, 1, 1, 1), + sumlogdiag.repeat(batch_size, 1), + ) + + self.net.train(True) + return log_prob.detach() + + else: + raise NotImplementedError( + "This functionality is only available for MDN based posteriors." + ) def map( self, diff --git a/sbi/utils/__init__.py b/sbi/utils/__init__.py index 40eef9c00..db8978119 100644 --- a/sbi/utils/__init__.py +++ b/sbi/utils/__init__.py @@ -2,6 +2,8 @@ from sbi.utils.conditional_density import ( conditional_corrcoeff, eval_conditional_density, + extract_and_transform_mog, + condition_mog, ) from sbi.utils.get_nn_models import classifier_nn, likelihood_nn, posterior_nn from sbi.utils.io import get_data_root, get_log_root, get_project_root diff --git a/sbi/utils/conditional_density.py b/sbi/utils/conditional_density.py index 07b194ef8..54f26cd0c 100644 --- a/sbi/utils/conditional_density.py +++ b/sbi/utils/conditional_density.py @@ -7,7 +7,9 @@ import torch from torch import Tensor +from pyknos.mdn.mdn import MultivariateGaussianMDN as mdn from sbi.utils.torchutils import ensure_theta_batched +from sbi.utils.torchutils import BoxUniform def eval_conditional_density( @@ -330,3 +332,138 @@ def _normalize_probs(probs: Tensor, limits: Tensor) -> Tensor: """ limits_diff = torch.prod(limits[:, 1] - limits[:, 0]) return probs * probs.numel() / limits_diff / torch.sum(probs) + + +def extract_and_transform_mog( + posterior: "DirectPosterior", context: Tensor = None +) -> Tuple[Tensor, Tensor, Tensor, Tensor]: + """Extracts the Mixture of Gaussians (MoG) parameters + from an MDN based DirectPosterior at either the default x or input x. + + Args: + posterior: DirectPosterior instance. + context: Conditioning context for posterior $p(\theta|x)$. If not provided, + fall back onto `x` passed to `set_default_x()`. + + Returns: + norm_logits: Normalised log weights of the underyling MoG. + (batch_size, n_mixtures) + means_transformed: Recentred and rescaled means of the underlying MoG + (batch_size, n_mixtures, n_dims) + precfs_transformed: Rescaled precision factors of the underlying MoG. + (batch_size, n_mixtures, n_dims, n_dims) + sumlogdiag: Sum of the log of the diagonal of the precision factors + of the new conditional distribution. 
diff --git a/sbi/utils/__init__.py b/sbi/utils/__init__.py
index 40eef9c00..db8978119 100644
--- a/sbi/utils/__init__.py
+++ b/sbi/utils/__init__.py
@@ -2,6 +2,8 @@
 from sbi.utils.conditional_density import (
     conditional_corrcoeff,
     eval_conditional_density,
+    extract_and_transform_mog,
+    condition_mog,
 )
 from sbi.utils.get_nn_models import classifier_nn, likelihood_nn, posterior_nn
 from sbi.utils.io import get_data_root, get_log_root, get_project_root
diff --git a/sbi/utils/conditional_density.py b/sbi/utils/conditional_density.py
index 07b194ef8..54f26cd0c 100644
--- a/sbi/utils/conditional_density.py
+++ b/sbi/utils/conditional_density.py
@@ -7,7 +7,9 @@
 import torch
 from torch import Tensor

+from pyknos.mdn.mdn import MultivariateGaussianMDN as mdn
 from sbi.utils.torchutils import ensure_theta_batched
+from sbi.utils.torchutils import BoxUniform


 def eval_conditional_density(
@@ -330,3 +332,138 @@ def _normalize_probs(probs: Tensor, limits: Tensor) -> Tensor:
     """
     limits_diff = torch.prod(limits[:, 1] - limits[:, 0])
     return probs * probs.numel() / limits_diff / torch.sum(probs)
+
+
+def extract_and_transform_mog(
+    posterior: "DirectPosterior", context: Tensor = None
+) -> Tuple[Tensor, Tensor, Tensor, Tensor]:
+    """Extracts the Mixture of Gaussians (MoG) parameters from an MDN-based
+    DirectPosterior, at either the default x or the provided context.
+
+    Args:
+        posterior: DirectPosterior instance.
+        context: Conditioning context for posterior $p(\theta|x)$. If not provided,
+            fall back onto `x` passed to `set_default_x()`.
+
+    Returns:
+        norm_logits: Normalised log weights of the underlying MoG.
+            (batch_size, n_mixtures)
+        means_transformed: Recentred and rescaled means of the underlying MoG.
+            (batch_size, n_mixtures, n_dims)
+        precfs_transformed: Rescaled precision factors of the underlying MoG.
+            (batch_size, n_mixtures, n_dims, n_dims)
+        sumlogdiag: Sum of the log of the diagonal of the precision factors.
+            (batch_size, n_mixtures)
+    """
+
+    # Extract and rescale the means, mixture components and covariances.
+    nn = posterior.net
+    dist = nn._distribution
+
+    if context is None:
+        encoded_x = nn._embedding_net(posterior.default_x)
+    else:
+        encoded_x = nn._embedding_net(context)
+
+    logits, means, _, sumlogdiag, precfs = dist.get_mixture_components(encoded_x)
+    norm_logits = logits - torch.logsumexp(logits, dim=-1, keepdim=True)
+
+    scale = nn._transform._transforms[0]._scale
+    shift = nn._transform._transforms[0]._shift
+
+    means_transformed = (means - shift) / scale
+
+    A = scale * torch.eye(means_transformed.shape[2])
+    precfs_transformed = A @ precfs
+
+    sumlogdiag = torch.sum(
+        torch.log(torch.diagonal(precfs_transformed, dim1=2, dim2=3)), dim=2
+    )
+
+    return norm_logits, means_transformed, precfs_transformed, sumlogdiag
+
+
+def condition_mog(
+    prior: "Prior",
+    condition: Tensor,
+    dims: List[int],
+    logits: Tensor,
+    means: Tensor,
+    precfs: Tensor,
+) -> Tuple[Tensor, Tensor, Tensor, Tensor]:
+    """Finds the conditional distribution p(X|Y) for a MoG.
+
+    Args:
+        prior: Prior distribution. Used to check whether the condition is within its
+            support.
+        condition: Parameter set that all dimensions not specified in `dims` will be
+            fixed to. Should contain dim_theta elements, i.e. it could e.g. be a
+            sample from the posterior distribution. The entries at all `dims` will
+            be ignored.
+        dims: Which dimensions to keep (i.e. not to condition on). The dimensions
+            not specified in `dims` will be fixed to the values given in `condition`.
+        logits: Log weights of the MoG. (batch_size, n_mixtures)
+        means: Means of the MoG. (batch_size, n_mixtures, n_dims)
+        precfs: Precision factors of the MoG.
+            (batch_size, n_mixtures, n_dims, n_dims)
+
+    Raises:
+        ValueError: The chosen condition is not within the prior support.
+
+    Returns:
+        logits: Log weights of the conditioned MoG. (batch_size, n_mixtures)
+        means: Means of the conditioned MoG. (batch_size, n_mixtures, n_dims)
+        precfs_xx: Precision factors of the conditioned MoG.
+            (batch_size, n_mixtures, n_dims, n_dims)
+        sumlogdiag: Sum of the log of the diagonal of the precision factors
+            of the new conditional distribution.
+            (batch_size, n_mixtures)
+    """
+
+    support = prior.support
+
+    n_mixtures, n_dims = means.shape[1:]
+
+    mask = torch.zeros(n_dims, dtype=torch.bool)
+    mask[dims] = True
+
+    # Check whether the condition is within the prior bounds.
+    if type(prior) is torch.distributions.uniform.Uniform or type(prior) is BoxUniform:
+        cond_ubound = support.upper_bound[~mask]
+        cond_lbound = support.lower_bound[~mask]
+        within_support = torch.logical_and(
+            cond_lbound <= condition[:, ~mask], cond_ubound >= condition[:, ~mask]
+        )
+        if not torch.all(within_support):
+            raise ValueError("The chosen condition is not within the prior support.")
+
+    y = condition[:, ~mask]
+    mu_x = means[:, :, mask]
+    mu_y = means[:, :, ~mask]
+
+    precfs_xx = precfs[:, :, mask]
+    precfs_xx = precfs_xx[:, :, :, mask]
+    precs_xx = precfs_xx.transpose(3, 2) @ precfs_xx
+
+    precfs_yy = precfs[:, :, ~mask]
+    precfs_yy = precfs_yy[:, :, :, ~mask]
+    precs_yy = precfs_yy.transpose(3, 2) @ precfs_yy
+
+    precs = precfs.transpose(3, 2) @ precfs
+    precs_xy = precs[:, :, mask]
+    precs_xy = precs_xy[:, :, :, ~mask]
+
+    # Conditional mean: mu_x - Lambda_xx^{-1} Lambda_xy (y - mu_y).
+    means = mu_x - (
+        torch.inverse(precs_xx) @ precs_xy @ (y - mu_y).view(1, n_mixtures, -1, 1)
+    ).view(1, n_mixtures, -1)
+
+    diags = torch.diagonal(precfs_yy, dim1=2, dim2=3)
+    sumlogdiag_yy = torch.sum(torch.log(diags), dim=2)
+    log_prob = mdn.log_prob_mog(y, torch.zeros((1, 1)), mu_y, precs_yy, sumlogdiag_yy)
+
+    # Normalize the mixing coefficients: p(X|Y) = p(Y,X) / p(Y), using the marginal
+    # distribution of the conditioned dimensions.
+    new_mcs = torch.exp(logits + log_prob)
+    new_mcs = new_mcs / new_mcs.sum()
+    logits = torch.log(new_mcs)
+
+    sumlogdiag = torch.sum(torch.log(torch.diagonal(precfs_xx, dim1=2, dim2=3)), dim=2)
+    return logits, means, precfs_xx, sumlogdiag
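Note (outside the diff): `condition_mog` applies the standard Gaussian conditioning identities to each mixture component k. Writing the precision factor U_k (`precfs`) and the precision \Lambda_k = U_k^\top U_k in block form over the kept dimensions x (`dims`) and the conditioned dimensions y (block notation introduced here for exposition only),

\[
\Lambda_k = \begin{pmatrix} \Lambda_{xx} & \Lambda_{xy} \\ \Lambda_{yx} & \Lambda_{yy} \end{pmatrix},
\qquad
p_k(x \mid y) = \mathcal{N}\!\left(\mu_x - \Lambda_{xx}^{-1}\Lambda_{xy}\,(y - \mu_y),\; \Lambda_{xx}^{-1}\right),
\qquad
\pi_k \;\leftarrow\; \frac{\pi_k\, p_k(y)}{\sum_j \pi_j\, p_j(y)},
\]

which corresponds to the conditional mean, the \Lambda_{xx} precision block (`precfs_xx`), and the mixture-weight renormalization ("Normalize the mixing coefficients") computed in the function above, with p_k(y) evaluated per component as in the code.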
diff --git a/tests/linearGaussian_snpe_test.py b/tests/linearGaussian_snpe_test.py
index 8f804295a..943ec269f 100644
--- a/tests/linearGaussian_snpe_test.py
+++ b/tests/linearGaussian_snpe_test.py
@@ -12,7 +12,9 @@
 from sbi import analysis as analysis
 from sbi import utils as utils
+
 from sbi.inference import SNPE_A, SNPE_B, SNPE_C, prepare_for_sbi, simulate_for_sbi
+
 from sbi.simulators.linear_gaussian import (
     linear_gaussian,
     samples_true_posterior_linear_gaussian_mvn_prior_different_dims,
@@ -26,6 +28,8 @@
     get_prob_outside_uniform_prior,
 )

+from tests.sbiutils_test import conditional_of_mvn
+

 @pytest.mark.parametrize("snpe_method", [SNPE_A, SNPE_C])
 @pytest.mark.parametrize(
@@ -69,10 +73,11 @@ def test_c2st_snpe_on_linearGaussian(
     target_samples = samples_true_posterior_linear_gaussian_uniform_prior(
         x_o, likelihood_shift, likelihood_cov, prior=prior, num_samples=num_samples
     )
-
+
     simulator, prior = prepare_for_sbi(
         lambda theta: linear_gaussian(theta, likelihood_shift, likelihood_cov), prior
    )
+
     inference = snpe_method(prior, show_progress_bars=False)

     theta, x = simulate_for_sbi(
@@ -176,6 +181,7 @@ def test_c2st_snpe_on_linearGaussian_different_dims(set_seed):  # type: ignore
     theta, x = simulate_for_sbi(simulator, prior, 2000, simulation_batch_size=1)
+
     inference = inference.append_simulations(theta, x)
     _ = inference.train(max_num_epochs=10)  # Test whether we can stop and resume.
     _ = inference.train(resume_training=True)
@@ -380,6 +386,7 @@ def simulator(theta):
     net = utils.posterior_nn("maf", hidden_features=20)

     simulator, prior = prepare_for_sbi(simulator, prior)
+
     inference = snpe_method(prior, density_estimator=net, show_progress_bars=False)

     # We need a pretty big dataset to properly model the bimodality.
@@ -431,6 +438,74 @@ def simulator(theta):
     max_err = np.max(error)
     assert max_err < 0.0026

+
+@pytest.mark.slow
+@pytest.mark.parametrize("snpe_method", [SNPE_A, SNPE_C])
+def test_mdn_conditional_density(
+    snpe_method: type, num_dim: int = 3, cond_dim: int = 1
+):
+    """Test whether the conditional density inferred from MDN parameters of a
+    `DirectPosterior` matches analytical results for an MVN. This uses an n-D joint
+    and conditions on all but the last m dimensions to generate a conditional over
+    the remaining m dimensions.
+
+    A Gaussian prior is used so that the conditional posterior has an analytic
+    ground truth.
+
+    Args:
+        snpe_method: SNPE method with which to train the MDN posterior.
+        num_dim: Dimensionality of the MVN.
+        cond_dim: Dimensionality of the conditional (sampled) part.
+    """
+
+    assert (
+        num_dim > cond_dim
+    ), "The number of dimensions needs to be greater than that of the condition!"
+
+    x_o = zeros(1, num_dim)
+    num_samples = 1000
+    num_simulations = 2500
+    condition = 0.1 * ones(1, num_dim)
+
+    dims = list(range(num_dim))
+    dims2sample = dims[-cond_dim:]
+    dims2condition = dims[:-cond_dim]
+
+    # likelihood_mean will be likelihood_shift + theta
+    likelihood_shift = -1.0 * ones(num_dim)
+    likelihood_cov = 0.3 * eye(num_dim)
+
+    prior_mean = zeros(num_dim)
+    prior_cov = eye(num_dim)
+    prior = MultivariateNormal(loc=prior_mean, covariance_matrix=prior_cov)
+
+    joint_posterior = true_posterior_linear_gaussian_mvn_prior(
+        x_o[0], likelihood_shift, likelihood_cov, prior_mean, prior_cov
+    )
+    joint_cov = joint_posterior.covariance_matrix
+    joint_mean = joint_posterior.loc
+
+    conditional_mean, conditional_cov = conditional_of_mvn(
+        joint_mean, joint_cov, condition[0, dims2condition]
+    )
+    conditional_dist_gt = MultivariateNormal(conditional_mean, conditional_cov)
+
+    conditional_samples_gt = conditional_dist_gt.sample((num_samples,))
+
+    def simulator(theta):
+        return linear_gaussian(theta, likelihood_shift, likelihood_cov)
+
+    simulator, prior = prepare_for_sbi(simulator, prior)
+    inference = snpe_method(prior, show_progress_bars=False, density_estimator="mdn")
+
+    theta, x = simulate_for_sbi(
+        simulator, prior, num_simulations, simulation_batch_size=1000
+    )
+    _ = inference.append_simulations(theta, x).train(training_batch_size=100)
+    posterior = inference.build_posterior().set_default_x(x_o)
+
+    conditional_samples_sbi = posterior.sample_conditional(
+        (num_samples,), condition, dims2sample, x_o
+    )
+    check_c2st(
+        conditional_samples_sbi,
+        conditional_samples_gt,
+        alg="analytic_mdn_conditioning_of_direct_posterior",
+    )
+

 @pytest.mark.parametrize("snpe_method", [SNPE_A, SNPE_C])
 def test_example_posterior(snpe_method: type):
@@ -455,6 +530,7 @@ def test_example_posterior(snpe_method: type):
         lambda theta: linear_gaussian(theta, likelihood_shift, likelihood_cov), prior
     )
     inference = snpe_method(prior, show_progress_bars=False, **extra_kwargs)
+
     theta, x = simulate_for_sbi(
         simulator, prior, 1000, simulation_batch_size=10, num_workers=6
     )
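Note (outside the diff): the ground truth in `test_mdn_conditional_density` conditions the joint Gaussian posterior via `conditional_of_mvn` from `tests/sbiutils_test.py`; presumably this computes the covariance-form counterpart of the identities above,

\[
\mu_{x \mid y} = \mu_x + \Sigma_{xy}\,\Sigma_{yy}^{-1}\,(y - \mu_y),
\qquad
\Sigma_{x \mid y} = \Sigma_{xx} - \Sigma_{xy}\,\Sigma_{yy}^{-1}\,\Sigma_{yx},
\]

so the c2st check compares the analytic MDN conditioning against an exact multivariate normal conditional.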