From aaf731677b654e00ba6548ebd82e3a0b996f1f49 Mon Sep 17 00:00:00 2001 From: jonasbeck Date: Mon, 29 Mar 2021 18:42:36 +0200 Subject: [PATCH 01/20] Added analytic sampling of conditional posterior instances trained with MDNns. --- sbi/utils/conditional_density.py | 377 +++++++++++++++++++++++++++++++ 1 file changed, 377 insertions(+) diff --git a/sbi/utils/conditional_density.py b/sbi/utils/conditional_density.py index 5cba7d90d..22471ca58 100644 --- a/sbi/utils/conditional_density.py +++ b/sbi/utils/conditional_density.py @@ -10,6 +10,10 @@ from sbi.utils.torchutils import ensure_theta_batched +from copy import deepcopy + +from sbi.inference.posteriors.direct_posterior import DirectPosterior +from tqdm.auto import tqdm def eval_conditional_density( density: Any, @@ -321,3 +325,376 @@ def _normalize_probs(probs: Tensor, limits: Tensor) -> Tensor: """ limits_diff = torch.prod(limits[:, 1] - limits[:, 0]) return probs * probs.numel() / limits_diff / torch.sum(probs) + + +class MDNPosterior(DirectPosterior): + """Wrapper around MDN based DirectPosterior instances. + + Extracts the Gaussian Mixture parameters from the Mixture + Density Network. Samples from Multivariate Gaussians directly, using + torch.distributions.multivariate_normal.MultivariateNormal + rather than going through the MDN. + Replaces sample and log_prob functions of the DirectPosterior. + + Args: + MDN_Posterior: DirectPosterior instance, i.e. output of + inference.build_posterior(density_estimator), + that was trained using a MDN. + + Attributes: + S: Tensor that holds the covariance matrices of all mixture components. + m: Tensor that holds the means of all mixture components. + mc: Tensor that holds mixture coefficients of all mixture components. + support: An Interval with lower and upper bounds of the support. + """ + + def __init__(self, MDN_Posterior: DirectPosterior): + if "MultivariateGaussianMDN" in MDN_Posterior.net.__str__(): + # wrap copy of input object into self + self.__class__ = type( + "MDNPosterior", (self.__class__, deepcopy(MDN_Posterior).__class__), {} + ) + self.__dict__ = deepcopy(MDN_Posterior).__dict__ + + # MoG parameters + self.S = None + self.m = None + self.mc = None + self.support = self._prior.support + + self.extract_mixture_components() + + else: + raise AttributeError("Posterior does not contain a MDN.") + + @staticmethod + def mulnormpdf(X: Tensor, mu: Tensor, cov: Tensor) -> Tensor: + """Evaluates the PDF for the multivariate Guassian distribution. + + Args: + X: torch.tensor with inputs/entries row-wise. Can also be a 1-d array if only a + single point is evaluated. + mu: torch.tensor with center/mean, 1d array. + cov: 2d torch.tensor with covariance matrix. + + Returns: + prob: Probabilities for entries in `X`. + """ + + # Evaluate pdf at points or point: + if X.ndim == 1: + X = torch.atleast_2d(X) + sigma = torch.atleast_2d(cov) # So we also can use it for 1-d distributions + + N = mu.shape[0] + ex1 = torch.inverse(sigma) @ (X - mu).T + ex = -0.5 * (X - mu).T * ex1 + if ex.ndim == 2: + ex = torch.sum(ex, axis=0) + K = 1 / torch.sqrt( + torch.pow(2 * torch.tensor(3.14159265), N) * torch.det(sigma) + ) + return K * torch.exp(ex) + + def check_support(self, X: Tensor) -> bool: + """Takes a set of points X with X.shape[0] being the number of points + and X.shape[1] the dimensionality of the points and checks, each point + for its prior support. + + Args: + X: Contains a set of multidimensional points to check + against the prior support of the posterior object. 
+ + Returns: + within_support: Boolean array representing, whether a sample is within the + prior support or not. + """ + + lbound = self.support.lower_bound + ubound = self.support.upper_bound + + within_support = torch.logical_and(lbound < X, X < ubound) + + return torch.all(within_support, dim=1) + + def extract_mixture_components( + self, x: Tensor = None + ) -> Tuple[Tensor, Tensor, Tensor]: + """Extracts the Mixture of Gaussians (MoG) parameters + from the MDN at either the default x or input x. + + Adpated from code courtesy of @ybernaerts. + + Args: + x: x at which to evaluate the MDN in order + to extract the MoG parameters. + """ + if x == None: + encoded_x = self.net._embedding_net(self.default_x) + else: + encoded_x = self.net._embedding_net(torch.tensor(x, dtype=torch.float32)) + dist = self.net._distribution + logits, m, prec, *_ = dist.get_mixture_components(encoded_x) + norm_logits = logits - torch.logsumexp(logits, dim=-1, keepdim=True) + + scale = self.net._transform._transforms[0]._scale + shift = self.net._transform._transforms[0]._shift + + self.mc = torch.exp(norm_logits).detach() + self.m = ((m - shift) / scale).detach()[0] + + L = torch.cholesky( + prec[0].detach() + torch.eye(self.m.shape[1]) * 1e-6 + ) # sometimes matrices are not pos semi def. dirty fix. + C = torch.inverse(L) + self.S = C.transpose(2, 1) @ C + A_inv = torch.inverse(scale * torch.eye(self.S.shape[1])) + self.S = A_inv @ self.S @ A_inv.T + + return self.mc, self.m, self.S + + def log_prob(self, X: Tensor, individual=False) -> Tensor: + """Evaluates the Mixture of Gaussian (MoG) + probability density function at a value x. + + Adpated from code courtesy of @ybernaerts. + + Args: + X: Values at which to evaluate the MoG pdf. + individual: If True the probability density is returned for each cluster component. + + Returns: + log_prob: Log probabilities at values specified by X. + """ + + pdf = torch.zeros((X.shape[0], self.m.shape[0])) + for i in range(self.m.shape[0]): + pdf[:, i] = self.mulnormpdf(X, self.m[i], self.S[i]) * self.mc[0, i] + if individual: + return torch.log(pdf) + else: + log_factor = torch.log(self.leakage_correction(x=self.default_x)) + return torch.log(torch.sum(pdf, axis=1)) - log_factor + + def sample(self, sample_shape: Tuple[int, int]) -> Tensor: + """Draw samples from a Mixture of Gaussians (MoG) + + Adpated from code courtesy of @ybernaerts. + Args: + sample_shape: The number of samples to draw from the MoG distribution. + + Returns: + X: A matrix with samples rows, and input dimension columns. 
+ """ + + K, D = self.m.shape # Determine dimensionality + + num_samples = torch.Size(sample_shape).numel() + pbar = tqdm( + total=num_samples, + desc=f"Drawing {num_samples} posterior samples", + ) + + # Cluster selection + cs_mc = torch.cumsum(self.mc, 1) + cs_mc = torch.hstack((torch.tensor([[0]]), cs_mc)) + sel_idx = torch.rand(num_samples) + + # Draw samples + res = torch.zeros((num_samples, D)) + f1 = sel_idx[:, None] >= cs_mc + f2 = sel_idx[:, None] < cs_mc + idxs = f1[:, :-1] * f2[:, 1:] + ksamples = torch.sum(idxs, axis=0) + for k, samplesize in enumerate(ksamples): + # draw initial samples + multivar_normal = torch.distributions.multivariate_normal.MultivariateNormal( + self.m[k], self.S[k] + ) + drawn_samples = multivar_normal.sample((samplesize,)) + + # check if samples are within the support and how many are not + supported = self.check_support(drawn_samples) + num_not_supported = torch.count_nonzero(~supported) + drawn_samples_in_support = drawn_samples[supported] + if num_not_supported > 0: + # resample until all samples are within the prior support + while num_not_supported > 0: + # resample + multivar_normal = torch.distributions.multivariate_normal.MultivariateNormal( + self.m[k], self.S[k] + ) + redrawn_samples = multivar_normal.sample((int(num_not_supported),)) + + # reevaluate support + supported = self.check_support(redrawn_samples) + num_not_supported = torch.count_nonzero(~supported) + redrawn_samples_in_support = redrawn_samples[supported] + # append the samples + drawn_samples_in_support = torch.vstack( + [drawn_samples_in_support, redrawn_samples_in_support] + ) + + pbar.update(int(sum(ksamples[: k + 1]) - num_not_supported)) + res[idxs[:, k], :] = drawn_samples_in_support + pbar.close() + return res.reshape((*sample_shape, -1)) + + def conditionalise(self, condition: Tensor) -> ConditionalMDNPosterior: + """Instantiates a new conditional distribution, which can be evaluated + and sampled from. + + Args: + condition: An array of inputs. Inputs set to NaN are not set, and become inputs to + the resulting distribution. Order is preserved. + """ + return ConditionalMDNPosterior(self, condition) + + def sample_from_conditional( + self, condition: Tensor, sample_shape: Tuple[int, int] + ) -> Tensor: + """Conditionalises the distribution on the provided condition + and samples from the the resulting distribution. + + Args: + condition: An array of inputs. Inputs set to NaN are not set, and become inputs to + the resulting distribution. Order is preserved. + sample_shape: The number of samples to draw from the conditional distribution. + """ + conditional_posterior = ConditionalMDNPosterior(self, condition) + samples = cond_posteriori.sample(sample_shape) + return samples + + +class ConditionalMDNPosterior(MDNPosterior): + """Wrapperclass for DirectPosterior objects that were trained using + a Mixture Density Network (MDN) and have been conditionalised. + Replaces sample, sample_conditional, sample_with_mcmc and log_prob + functions. Enables the evaluation and sampling of the conditional + distribution at any arbitrary condition and point. + + Args: + MDN_Posterior: DirectPosterior instance, i.e. output of + inference.build_posterior(density_estimator), + that was trained with a MDN. + condition: A vector that holds the conditioned vector. Entries that contain + NaNs are not set and become inputs to the resulting distribution, + i.e. condition = [x1, x2, NaN, NaN] -> p(x3,x4|x1,x2). 
+ + Attributes: + condition: A Tensor containing the values which the MoG has been conditioned on. + """ + + def __init__(self, MDN_Posterior: DirectPosterior, condition: Tensor): + self.__class__ = type( + "ConditionalMDNPosterior", + (self.__class__, deepcopy(MDN_Posterior).__class__), + {}, + ) + self.__dict__ = deepcopy(MDN_Posterior).__dict__ + self.condition = condition + self.__conditionalise(condition) + + def __conditionalise(self, condition: Tensor): + """Finds the conditional distribution p(X|Y) for a GMM. + + Adpated from code courtesy of @ybernaerts. + + Args: + condition: An array of inputs. Inputs set to NaN are not set, and become inputs to + the resulting distribution. Order is preserved. + + Raises: + ValueError: The chosen condition is not within the prior support. + """ + + # revert to the old GMM parameters first + self.extract_mixture_components() + self.support = self._prior.support + + pop = self.condition.isnan().reshape(-1) + condition_without_NaNs = self.condition.reshape(-1)[~pop] + + # check whether the condition is within the prior bounds + cond_ubound = self.support.upper_bound[~pop] + cond_lbound = self.support.lower_bound[~pop] + within_support = torch.logical_and( + cond_lbound <= condition_without_NaNs, condition_without_NaNs <= cond_ubound + ) + if ~torch.all(within_support): + raise ValueError("The chosen condition is not within the prior support.") + + # adjust the dimensionality of the support + self.support.upper_bound = self.support.upper_bound[pop] + self.support.lower_bound = self.support.lower_bound[pop] + + not_set_idx = torch.nonzero(torch.isnan(condition))[ + :, 1 + ] # indices for not set parameters + set_idx = torch.nonzero(~torch.isnan(condition))[ + :, 1 + ] # indices for set parameters + new_idx = torch.cat( + (not_set_idx, set_idx) + ) # indices with not set parameters first and then set parameters + y = condition[0, set_idx] + # New centroids and covar matrices + new_cen = [] + new_ccovs = [] + # Appendix A in C. E. Rasmussen & C. K. I. Williams, Gaussian Processes + # for Machine Learning, the MIT Press, 2006 + fk = [] + for i in range(self.m.shape[0]): + # Make a new co-variance matrix with correct ordering + new_ccov = deepcopy(self.S[i]) + new_ccov = new_ccov[:, new_idx] + new_ccov = new_ccov[new_idx, :] + ux = self.m[i, not_set_idx] + uy = self.m[i, set_idx] + A = new_ccov[0 : len(not_set_idx), 0 : len(not_set_idx)] + B = new_ccov[len(not_set_idx) :, len(not_set_idx) :] + # B = B + 1e-10*torch.eye(B.shape[0]) # prevents B from becoming singular + C = new_ccov[0 : len(not_set_idx), len(not_set_idx) :] + cen = ux + C @ torch.inverse(B) @ (y - uy) + cov = A - C @ torch.inverse(B) @ C.transpose(1, 0) + new_cen.append(cen) + new_ccovs.append(cov) + fk.append(self.mulnormpdf(y, uy, B)) # Used for normalizing the mc + # Normalize the mixing coef: p(X|Y) = p(Y,X) / p(Y) using the marginal dist. + fk = torch.tensor(fk) + new_mc = self.mc * fk + new_mc = new_mc / torch.sum(new_mc) + + # set new GMM parameters + self.m = torch.stack(new_cen) + self.S = torch.stack(new_ccovs) + self.mc = new_mc + + def sample_with_mcmc(self): + """Dummy function to overwrite the existing sample_with_mcmc method.""" + + raise DeprecationWarning( + "MCMC sampling is not yet supported for the conditional MDN." + ) + + def sample_conditional( + self, n_samples: Tuple[int, int], condition: Tensor = None + ) -> Tensor: + """Samples from the condtional distribution. If a condition + is provided, a new conditional distribution will be calculated. 
+ If no condition is provided, samples will be drawn from the + exisiting condition. + + Args: + n_samples: The number of samples to draw from the conditional distribution. + condition: An array of inputs. Inputs set to NaN are not set, and become inputs to + the resulting distribution. Order is preserved. + + Returns: + samples: Contains samples from the conditional posterior (NxD). + """ + + if condition != None: + self.__conditionalise(condition) + samples = self.sample(n_samples) + return samples From 0600280b7e485b676c9dac7d2e16d14810ddf9c8 Mon Sep 17 00:00:00 2001 From: jonasbeck Date: Tue, 30 Mar 2021 13:20:53 +0200 Subject: [PATCH 02/20] Adressing comments in #458: reformated code snippets in docstrings. Removed mentions. --- sbi/utils/conditional_density.py | 24 +++++++++--------------- 1 file changed, 9 insertions(+), 15 deletions(-) diff --git a/sbi/utils/conditional_density.py b/sbi/utils/conditional_density.py index 22471ca58..b04a0497a 100644 --- a/sbi/utils/conditional_density.py +++ b/sbi/utils/conditional_density.py @@ -334,11 +334,11 @@ class MDNPosterior(DirectPosterior): Density Network. Samples from Multivariate Gaussians directly, using torch.distributions.multivariate_normal.MultivariateNormal rather than going through the MDN. - Replaces sample and log_prob functions of the DirectPosterior. + Replaces `.sample()` and `.log_prob() methods of the `DirectPosterior`. Args: - MDN_Posterior: DirectPosterior instance, i.e. output of - inference.build_posterior(density_estimator), + MDN_Posterior: `DirectPosterior` instance, i.e. output of + `inference.build_posterior(density_estimator)`, that was trained using a MDN. Attributes: @@ -423,8 +423,6 @@ def extract_mixture_components( """Extracts the Mixture of Gaussians (MoG) parameters from the MDN at either the default x or input x. - Adpated from code courtesy of @ybernaerts. - Args: x: x at which to evaluate the MDN in order to extract the MoG parameters. @@ -457,8 +455,6 @@ def log_prob(self, X: Tensor, individual=False) -> Tensor: """Evaluates the Mixture of Gaussian (MoG) probability density function at a value x. - Adpated from code courtesy of @ybernaerts. - Args: X: Values at which to evaluate the MoG pdf. individual: If True the probability density is returned for each cluster component. @@ -567,15 +563,15 @@ def sample_from_conditional( class ConditionalMDNPosterior(MDNPosterior): - """Wrapperclass for DirectPosterior objects that were trained using + """Wrapperclass for `DirectPosterior` objects that were trained using a Mixture Density Network (MDN) and have been conditionalised. - Replaces sample, sample_conditional, sample_with_mcmc and log_prob - functions. Enables the evaluation and sampling of the conditional + Replaces `.sample()`, `.sample_conditional()`, `.sample_with_mcmc()` and `.log_prob()` + methods. Enables the evaluation and sampling of the conditional distribution at any arbitrary condition and point. Args: - MDN_Posterior: DirectPosterior instance, i.e. output of - inference.build_posterior(density_estimator), + MDN_Posterior: `DirectPosterior` instance, i.e. output of + `inference.build_posterior(density_estimator)`, that was trained with a MDN. condition: A vector that holds the conditioned vector. Entries that contain NaNs are not set and become inputs to the resulting distribution, @@ -598,8 +594,6 @@ def __init__(self, MDN_Posterior: DirectPosterior, condition: Tensor): def __conditionalise(self, condition: Tensor): """Finds the conditional distribution p(X|Y) for a GMM. 
-        Adpated from code courtesy of @ybernaerts.
-
         Args:
             condition: An array of inputs. Inputs set to NaN are not set, and become inputs to
                 the resulting distribution. Order is preserved.
@@ -671,7 +665,7 @@ def __conditionalise(self, condition: Tensor):
         self.mc = new_mc
 
     def sample_with_mcmc(self):
-        """Dummy function to overwrite the existing sample_with_mcmc method."""
+        """Dummy function to overwrite the existing `.sample_with_mcmc()` method."""
 
         raise DeprecationWarning(
             "MCMC sampling is not yet supported for the conditional MDN."
         )

From 0978d7d6b54f7a493d53d3dee43cf20a7e121751 Mon Sep 17 00:00:00 2001
From: jonasbeck
Date: Tue, 30 Mar 2021 14:32:24 +0200
Subject: [PATCH 03/20] Addressing comments in #458: Sampling now via Cholesky decomposition.

---
 sbi/utils/conditional_density.py | 16 +++++++---------
 1 file changed, 7 insertions(+), 9 deletions(-)

diff --git a/sbi/utils/conditional_density.py b/sbi/utils/conditional_density.py
index b04a0497a..a5c5f5749 100644
--- a/sbi/utils/conditional_density.py
+++ b/sbi/utils/conditional_density.py
@@ -325,7 +325,7 @@ def _normalize_probs(probs: Tensor, limits: Tensor) -> Tensor:
     """
     limits_diff = torch.prod(limits[:, 1] - limits[:, 0])
     return probs * probs.numel() / limits_diff / torch.sum(probs)
-
+ 
 
 class MDNPosterior(DirectPosterior):
     """Wrapper around MDN based DirectPosterior instances.
@@ -502,12 +502,12 @@ def sample(self, sample_shape: Tuple[int, int]) -> Tensor:
         f2 = sel_idx[:, None] < cs_mc
         idxs = f1[:, :-1] * f2[:, 1:]
         ksamples = torch.sum(idxs, axis=0)
+
         for k, samplesize in enumerate(ksamples):
             # draw initial samples
-            multivar_normal = torch.distributions.multivariate_normal.MultivariateNormal(
-                self.m[k], self.S[k]
-            )
-            drawn_samples = multivar_normal.sample((samplesize,))
+            chol_factor = torch.cholesky(self.S[k])
+            std_normal_sample = torch.randn(D,samplesize)
+            drawn_samples = self.m[k] + torch.mm(chol_factor, std_normal_sample).T
 
             # check if samples are within the support and how many are not
             supported = self.check_support(drawn_samples)
@@ -517,10 +517,8 @@ def sample(self, sample_shape: Tuple[int, int]) -> Tensor:
                 # resample until all samples are within the prior support
                 while num_not_supported > 0:
                     # resample
-                    multivar_normal = torch.distributions.multivariate_normal.MultivariateNormal(
-                        self.m[k], self.S[k]
-                    )
-                    redrawn_samples = multivar_normal.sample((int(num_not_supported),))
+                    std_normal_sample = torch.randn(D,(int(num_not_supported))
+                    redrawn_samples = self.m[k] + torch.mm(chol_factor, std_normal_sample).T
 
                     # reevaluate support
                     supported = self.check_support(redrawn_samples)

From 630b7cabbf6377ae32189183e03f163ea37d30bd Mon Sep 17 00:00:00 2001
From: jonasbeck
Date: Tue, 30 Mar 2021 15:02:02 +0200
Subject: [PATCH 04/20] Added Missing bracket.
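The previous commit switched the per-component draws to a Cholesky-based scheme; this commit only restores the bracket that went missing there. Reduced to a standalone sketch, the underlying idea is to pick a mixture component per sample and transform standard-normal noise with a Cholesky factor of that component's covariance. The tensor names below (`weights`, `means`, `covs`) are illustrative and not taken from the patch:

import torch

def sample_mog_cholesky(num_samples, weights, means, covs):
    # weights: (K,) mixture coefficients, means: (K, D), covs: (K, D, D).
    _, dim = means.shape
    # Pick a mixture component for every sample according to the weights.
    ks = torch.multinomial(weights, num_samples, replacement=True)
    # Standard-normal noise, one dim-dimensional vector per sample.
    z = torch.randn(num_samples, dim)
    # Lower-triangular factors L with cov = L @ L.T, one per component.
    L = torch.linalg.cholesky(covs)
    # mean_k + L_k @ z is a draw from N(mean_k, cov_k).
    return means[ks] + torch.einsum("nij,nj->ni", L[ks], z)

weights = torch.tensor([0.3, 0.7])
means = torch.zeros(2, 2)
covs = torch.stack([torch.eye(2), 2.0 * torch.eye(2)])
samples = sample_mog_cholesky(1000, weights, means, covs)  # shape (1000, 2)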
--- sbi/utils/conditional_density.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/sbi/utils/conditional_density.py b/sbi/utils/conditional_density.py index a5c5f5749..d52eaada7 100644 --- a/sbi/utils/conditional_density.py +++ b/sbi/utils/conditional_density.py @@ -517,7 +517,7 @@ def sample(self, sample_shape: Tuple[int, int]) -> Tensor: # resample until all samples are within the prior support while num_not_supported > 0: # resample - std_normal_sample = torch.randn(D,(int(num_not_supported)) + std_normal_sample = torch.randn(D,(int(num_not_supported))) redrawn_samples = self.m[k] + torch.mm(chol_factor, std_normal_sample).T # reevaluate support From d74fdaeb4662178ae8ce3c49123f47442ebc4ecf Mon Sep 17 00:00:00 2001 From: jonasbeck Date: Thu, 1 Apr 2021 00:18:23 +0200 Subject: [PATCH 05/20] Intermediate step: sample_mog and log_prob now compatible with MDNPosterior. --- sbi/utils/conditional_density.py | 241 ++++++++++++------------------- 1 file changed, 96 insertions(+), 145 deletions(-) diff --git a/sbi/utils/conditional_density.py b/sbi/utils/conditional_density.py index d52eaada7..12879d8e2 100644 --- a/sbi/utils/conditional_density.py +++ b/sbi/utils/conditional_density.py @@ -15,6 +15,8 @@ from sbi.inference.posteriors.direct_posterior import DirectPosterior from tqdm.auto import tqdm +from pyknos.mdn.mdn import MultivariateGaussianMDN as mdn + def eval_conditional_density( density: Any, condition: Tensor, @@ -357,15 +359,54 @@ def __init__(self, MDN_Posterior: DirectPosterior): self.__dict__ = deepcopy(MDN_Posterior).__dict__ # MoG parameters - self.S = None - self.m = None - self.mc = None + self.precs = None + self.means = None + self.logits = None + self.sumlogdiag = None self.support = self._prior.support - self.extract_mixture_components() + self.extract_and_transform_mog() else: raise AttributeError("Posterior does not contain a MDN.") + + def extract_and_transform_mog( + self, context: Tensor = None + ) -> Tuple[Tensor, Tensor, Tensor]: + """Extracts the Mixture of Gaussians (MoG) parameters + from the MDN at either the default x or input x. + + Args: + x: x at which to evaluate the MDN in order + to extract the MoG parameters. 
+ """ + + # extract and rescale means, mixture componenets and covariances + nn = self.net + dist = nn._distribution + + if context == None: + encoded_x = nn._embedding_net(self.default_x) + else: + encoded_x = nn._embedding_net(torch.tensor(context, dtype=torch.float32)) + + logits, m, prec, sumlogdiag, _ = dist.get_mixture_components(encoded_x) + norm_logits = logits - torch.logsumexp(logits, dim=-1, keepdim=True) + + scale = nn._transform._transforms[0]._scale + shift = nn._transform._transforms[0]._shift + + means_transformed = ((m - shift) / scale).detach() + + A = scale * torch.eye(means_transformed.shape[2]) + precision_factors_transformed = torch.cholesky(A@prec@A) + + self.logits = norm_logits.detach() + self.means = means_transformed.detach() + self.precs = precision_factors_transformed.detach() + self.sumlogdiag = torch.sum(torch.log(torch.diagonal(self.precs, dim1=2, dim2=3)),dim=2).detach() + + return norm_logits, means_transformed, precision_factors_transformed, sumlogdiag @staticmethod def mulnormpdf(X: Tensor, mu: Tensor, cov: Tensor) -> Tensor: @@ -396,61 +437,6 @@ def mulnormpdf(X: Tensor, mu: Tensor, cov: Tensor) -> Tensor: ) return K * torch.exp(ex) - def check_support(self, X: Tensor) -> bool: - """Takes a set of points X with X.shape[0] being the number of points - and X.shape[1] the dimensionality of the points and checks, each point - for its prior support. - - Args: - X: Contains a set of multidimensional points to check - against the prior support of the posterior object. - - Returns: - within_support: Boolean array representing, whether a sample is within the - prior support or not. - """ - - lbound = self.support.lower_bound - ubound = self.support.upper_bound - - within_support = torch.logical_and(lbound < X, X < ubound) - - return torch.all(within_support, dim=1) - - def extract_mixture_components( - self, x: Tensor = None - ) -> Tuple[Tensor, Tensor, Tensor]: - """Extracts the Mixture of Gaussians (MoG) parameters - from the MDN at either the default x or input x. - - Args: - x: x at which to evaluate the MDN in order - to extract the MoG parameters. - """ - if x == None: - encoded_x = self.net._embedding_net(self.default_x) - else: - encoded_x = self.net._embedding_net(torch.tensor(x, dtype=torch.float32)) - dist = self.net._distribution - logits, m, prec, *_ = dist.get_mixture_components(encoded_x) - norm_logits = logits - torch.logsumexp(logits, dim=-1, keepdim=True) - - scale = self.net._transform._transforms[0]._scale - shift = self.net._transform._transforms[0]._shift - - self.mc = torch.exp(norm_logits).detach() - self.m = ((m - shift) / scale).detach()[0] - - L = torch.cholesky( - prec[0].detach() + torch.eye(self.m.shape[1]) * 1e-6 - ) # sometimes matrices are not pos semi def. dirty fix. - C = torch.inverse(L) - self.S = C.transpose(2, 1) @ C - A_inv = torch.inverse(scale * torch.eye(self.S.shape[1])) - self.S = A_inv @ self.S @ A_inv.T - - return self.mc, self.m, self.S - def log_prob(self, X: Tensor, individual=False) -> Tensor: """Evaluates the Mixture of Gaussian (MoG) probability density function at a value x. @@ -462,15 +448,13 @@ def log_prob(self, X: Tensor, individual=False) -> Tensor: Returns: log_prob: Log probabilities at values specified by X. 
""" + prec = self.precs@self.precs.transpose(3,2) - pdf = torch.zeros((X.shape[0], self.m.shape[0])) - for i in range(self.m.shape[0]): - pdf[:, i] = self.mulnormpdf(X, self.m[i], self.S[i]) * self.mc[0, i] - if individual: - return torch.log(pdf) - else: - log_factor = torch.log(self.leakage_correction(x=self.default_x)) - return torch.log(torch.sum(pdf, axis=1)) - log_factor + self.net.eval() # leakage correction requires eval mode + log_factor = torch.log(self.leakage_correction(x=self.default_x)) + + log_prob = mdn.log_prob_mog(X,self.logits, self.means, prec, self.sumlogdiag) # only works for single samples + return log_prob - log_factor def sample(self, sample_shape: Tuple[int, int]) -> Tensor: """Draw samples from a Mixture of Gaussians (MoG) @@ -483,58 +467,16 @@ def sample(self, sample_shape: Tuple[int, int]) -> Tensor: X: A matrix with samples rows, and input dimension columns. """ - K, D = self.m.shape # Determine dimensionality - + _, K, D = self.means.shape # Determine dimensionality + + # add sample posterior from prior (rejection sampling) num_samples = torch.Size(sample_shape).numel() - pbar = tqdm( - total=num_samples, - desc=f"Drawing {num_samples} posterior samples", - ) - # Cluster selection - cs_mc = torch.cumsum(self.mc, 1) - cs_mc = torch.hstack((torch.tensor([[0]]), cs_mc)) - sel_idx = torch.rand(num_samples) - - # Draw samples - res = torch.zeros((num_samples, D)) - f1 = sel_idx[:, None] >= cs_mc - f2 = sel_idx[:, None] < cs_mc - idxs = f1[:, :-1] * f2[:, 1:] - ksamples = torch.sum(idxs, axis=0) - - for k, samplesize in enumerate(ksamples): - # draw initial samples - chol_factor = torch.cholesky(self.S[k]) - std_normal_sample = torch.randn(D,samplesize) - drawn_samples = self.m[k] + torch.mm(chol_factor, std_normal_sample).T - - # check if samples are within the support and how many are not - supported = self.check_support(drawn_samples) - num_not_supported = torch.count_nonzero(~supported) - drawn_samples_in_support = drawn_samples[supported] - if num_not_supported > 0: - # resample until all samples are within the prior support - while num_not_supported > 0: - # resample - std_normal_sample = torch.randn(D,(int(num_not_supported))) - redrawn_samples = self.m[k] + torch.mm(chol_factor, std_normal_sample).T - - # reevaluate support - supported = self.check_support(redrawn_samples) - num_not_supported = torch.count_nonzero(~supported) - redrawn_samples_in_support = redrawn_samples[supported] - # append the samples - drawn_samples_in_support = torch.vstack( - [drawn_samples_in_support, redrawn_samples_in_support] - ) + samples = mdn.sample_mog(num_samples, self.logits, self.means, self.precs) + + return samples.reshape((*sample_shape, -1)) - pbar.update(int(sum(ksamples[: k + 1]) - num_not_supported)) - res[idxs[:, k], :] = drawn_samples_in_support - pbar.close() - return res.reshape((*sample_shape, -1)) - - def conditionalise(self, condition: Tensor) -> ConditionalMDNPosterior: + def conditionalise(self, condition: Tensor): # -> ConditionalMDNPosterior: """Instantiates a new conditional distribution, which can be evaluated and sampled from. 
@@ -544,7 +486,7 @@ def conditionalise(self, condition: Tensor) -> ConditionalMDNPosterior: """ return ConditionalMDNPosterior(self, condition) - def sample_from_conditional( + def sample_conditional( self, condition: Tensor, sample_shape: Tuple[int, int] ) -> Tensor: """Conditionalises the distribution on the provided condition @@ -601,7 +543,7 @@ def __conditionalise(self, condition: Tensor): """ # revert to the old GMM parameters first - self.extract_mixture_components() + self.extract_and_transform_mog() self.support = self._prior.support pop = self.condition.isnan().reshape(-1) @@ -629,38 +571,47 @@ def __conditionalise(self, condition: Tensor): new_idx = torch.cat( (not_set_idx, set_idx) ) # indices with not set parameters first and then set parameters - y = condition[0, set_idx] + y = condition[0, set_idx].reshape(1,-1) + + k = self.means.shape[1] + d_new = not_set_idx.shape[0] + # New centroids and covar matrices - new_cen = [] - new_ccovs = [] + new_cen = torch.zeros(1,k,d_new) + new_ccovs = torch.zeros(1,k,d_new,d_new) # Appendix A in C. E. Rasmussen & C. K. I. Williams, Gaussian Processes # for Machine Learning, the MIT Press, 2006 - fk = [] - for i in range(self.m.shape[0]): + fk = torch.zeros(1,k) + prec = self.precs@self.precs.transpose(3,2) + covs = torch.inverse(prec) + mcs = torch.exp(self.logits) + + for i in range(self.means.shape[1]): # Make a new co-variance matrix with correct ordering - new_ccov = deepcopy(self.S[i]) - new_ccov = new_ccov[:, new_idx] - new_ccov = new_ccov[new_idx, :] - ux = self.m[i, not_set_idx] - uy = self.m[i, set_idx] - A = new_ccov[0 : len(not_set_idx), 0 : len(not_set_idx)] - B = new_ccov[len(not_set_idx) :, len(not_set_idx) :] - # B = B + 1e-10*torch.eye(B.shape[0]) # prevents B from becoming singular - C = new_ccov[0 : len(not_set_idx), len(not_set_idx) :] - cen = ux + C @ torch.inverse(B) @ (y - uy) - cov = A - C @ torch.inverse(B) @ C.transpose(1, 0) - new_cen.append(cen) - new_ccovs.append(cov) - fk.append(self.mulnormpdf(y, uy, B)) # Used for normalizing the mc + new_ccov = covs[:,i].clone() + new_ccov = new_ccov[:,:, new_idx] + new_ccov = new_ccov[:,new_idx, :] + ux = self.means[:,i, not_set_idx] + uy = self.means[:,i, set_idx] + A = new_ccov[:,0 : len(not_set_idx), 0 : len(not_set_idx)] + B = new_ccov[:,len(not_set_idx) :, len(not_set_idx) :] + C = new_ccov[:,0 : len(not_set_idx), len(not_set_idx) :] + cen = ux + (C @ torch.inverse(B) @ (y - uy).T).transpose(2,1) + cov = A - C @ torch.inverse(B) @ C.transpose(2, 1) + new_cen[:,i] = cen + new_ccovs[:,i] = cov + #torch.distributions.MultivariateNormal() + fk[:,i] = self.mulnormpdf(y[0], uy[0], B[0]) # Used for normalizing the mc # Normalize the mixing coef: p(X|Y) = p(Y,X) / p(Y) using the marginal dist. 
-        fk = torch.tensor(fk)
-        new_mc = self.mc * fk
-        new_mc = new_mc / torch.sum(new_mc)
+        new_mcs = mcs * fk
+        new_mcs = new_mcs / torch.sum(new_mcs)
 
         # set new GMM parameters
-        self.m = torch.stack(new_cen)
-        self.S = torch.stack(new_ccovs)
-        self.mc = new_mc
+        self.means = new_cen
+        self.precs = torch.cholesky(torch.inverse(new_ccovs))
+        self.logits = torch.log(new_mcs)
+        self.sumlogdiag = torch.sum(torch.log(torch.diagonal(self.precs, dim1=2, dim2=3)),dim=2)
+
 
     def sample_with_mcmc(self):
         """Dummy function to overwrite the existing `.sample_with_mcmc()` method."""
@@ -670,7 +621,7 @@ def sample_with_mcmc(self):
         )
 
     def sample_conditional(
-        self, n_samples: Tuple[int, int], condition: Tensor = None
+        self, sample_shape: Tuple[int, int], condition: Tensor = None
     ) -> Tensor:
         """Samples from the condtional distribution. If a condition
         is provided, a new conditional distribution will be calculated.
@@ -688,5 +639,5 @@ def sample_conditional(
 
         if condition != None:
             self.__conditionalise(condition)
-        samples = self.sample(n_samples)
+        samples = self.sample(sample_shape)
         return samples

From 2910bf0463093c7ef281d8f9333285733b1b2704 Mon Sep 17 00:00:00 2001
From: jonasbeck
Date: Tue, 6 Apr 2021 00:33:14 +0200
Subject: [PATCH 07/20] Replaced mulnormpdf with log_prob_mog.

---
 sbi/utils/conditional_density.py | 46 +++++++++-----------------------
 1 file changed, 12 insertions(+), 34 deletions(-)

diff --git a/sbi/utils/conditional_density.py b/sbi/utils/conditional_density.py
index 12879d8e2..1c8266e0e 100644
--- a/sbi/utils/conditional_density.py
+++ b/sbi/utils/conditional_density.py
@@ -401,6 +401,9 @@ def extract_and_transform_mog(
         A = scale * torch.eye(means_transformed.shape[2])
         precision_factors_transformed = torch.cholesky(A@prec@A)
 
+        #prec = precision_factors_transformed@precision_factors_transformed.transpose(3,2)
+        #covs = torch.inverse(prec)
+
         self.logits = norm_logits.detach()
         self.means = means_transformed.detach()
         self.precs = precision_factors_transformed.detach()
@@ -408,35 +411,6 @@ def extract_and_transform_mog(
 
         return norm_logits, means_transformed, precision_factors_transformed, sumlogdiag
 
-    @staticmethod
-    def mulnormpdf(X: Tensor, mu: Tensor, cov: Tensor) -> Tensor:
-        """Evaluates the PDF for the multivariate Guassian distribution.
-
-        Args:
-            X: torch.tensor with inputs/entries row-wise. Can also be a 1-d array if only a
-                single point is evaluated.
-            mu: torch.tensor with center/mean, 1d array.
-            cov: 2d torch.tensor with covariance matrix.
-
-        Returns:
-            prob: Probabilities for entries in `X`.
- """ - - # Evaluate pdf at points or point: - if X.ndim == 1: - X = torch.atleast_2d(X) - sigma = torch.atleast_2d(cov) # So we also can use it for 1-d distributions - - N = mu.shape[0] - ex1 = torch.inverse(sigma) @ (X - mu).T - ex = -0.5 * (X - mu).T * ex1 - if ex.ndim == 2: - ex = torch.sum(ex, axis=0) - K = 1 / torch.sqrt( - torch.pow(2 * torch.tensor(3.14159265), N) * torch.det(sigma) - ) - return K * torch.exp(ex) - def log_prob(self, X: Tensor, individual=False) -> Tensor: """Evaluates the Mixture of Gaussian (MoG) probability density function at a value x. @@ -448,12 +422,13 @@ def log_prob(self, X: Tensor, individual=False) -> Tensor: Returns: log_prob: Log probabilities at values specified by X. """ + batch_size = X.shape[0] prec = self.precs@self.precs.transpose(3,2) self.net.eval() # leakage correction requires eval mode log_factor = torch.log(self.leakage_correction(x=self.default_x)) - log_prob = mdn.log_prob_mog(X,self.logits, self.means, prec, self.sumlogdiag) # only works for single samples + log_prob = mdn.log_prob_mog(X, self.logits.repeat(batch_size,1,1), self.means.repeat(batch_size,1,1), prec.repeat(batch_size,1,1,1), self.sumlogdiag.repeat(batch_size,1)) # only works for single samples return log_prob - log_factor def sample(self, sample_shape: Tuple[int, int]) -> Tensor: @@ -464,7 +439,7 @@ def sample(self, sample_shape: Tuple[int, int]) -> Tensor: sample_shape: The number of samples to draw from the MoG distribution. Returns: - X: A matrix with samples rows, and input dimension columns. + X: A matrix with samples2 tensor([0.0002]) rows, and input dimension columns. """ _, K, D = self.means.shape # Determine dimensionality @@ -573,7 +548,7 @@ def __conditionalise(self, condition: Tensor): ) # indices with not set parameters first and then set parameters y = condition[0, set_idx].reshape(1,-1) - k = self.means.shape[1] + k, D = self.means.shape[1:] d_new = not_set_idx.shape[0] # New centroids and covar matrices @@ -600,8 +575,11 @@ def __conditionalise(self, condition: Tensor): cov = A - C @ torch.inverse(B) @ C.transpose(2, 1) new_cen[:,i] = cen new_ccovs[:,i] = cov - #torch.distributions.MultivariateNormal() - fk[:,i] = self.mulnormpdf(y[0], uy[0], B[0]) # Used for normalizing the mc + prec_B = torch.inverse(B) + precf = torch.cholesky(prec_B) + sumlogdiag = torch.sum(precf.diagonal()).reshape(1,-1) + log_prob = mdn.log_prob_mog(y, torch.tensor([[0.0]]), uy.view(1,1,-1), prec_B, sumlogdiag) + fk[:,i] = torch.exp(log_prob) # Used for normalizing the mc # Normalize the mixing coef: p(X|Y) = p(Y,X) / p(Y) using the marginal dist. new_mcs = mcs * fk new_mcs = new_mcs / torch.sum(new_mcs) From 06c3dc5742a7b3dbf7b91116c89d6bcca2cd5182 Mon Sep 17 00:00:00 2001 From: jonasbeck Date: Sun, 18 Apr 2021 18:50:21 +0200 Subject: [PATCH 08/20] Conditional sampling and evaluation for MDNs integrated into DirectPosterior. 
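The `condition_mog` helper introduced in this patch applies the standard conditioning identity for a jointly Gaussian vector to each mixture component and then renormalises the mixture weights. In generic notation (not the variable names used in the code), partitioning every component into free dimensions a and conditioned dimensions b with observed value y:

    \mu^{(k)}_{a\mid b} = \mu^{(k)}_a + \Sigma^{(k)}_{ab} \bigl(\Sigma^{(k)}_{bb}\bigr)^{-1} \bigl(y - \mu^{(k)}_b\bigr)
    \Sigma^{(k)}_{a\mid b} = \Sigma^{(k)}_{aa} - \Sigma^{(k)}_{ab} \bigl(\Sigma^{(k)}_{bb}\bigr)^{-1} \Sigma^{(k)}_{ba}
    \tilde{\pi}_k \propto \pi_k \, \mathcal{N}\bigl(y \mid \mu^{(k)}_b, \Sigma^{(k)}_{bb}\bigr)

This is the identity from Appendix A of Rasmussen & Williams (2006) that the earlier conditional_density.py implementation already cited.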
--- .vscode/settings.json | 3 +- sbi/inference/posteriors/direct_posterior.py | 221 ++++++++++- sbi/utils/conditional_density.py | 298 -------------- tests/torchutils_test.py | 386 +++++++++---------- 4 files changed, 398 insertions(+), 510 deletions(-) diff --git a/.vscode/settings.json b/.vscode/settings.json index a3e50d6e5..cf36b975a 100644 --- a/.vscode/settings.json +++ b/.vscode/settings.json @@ -72,5 +72,6 @@ "autoDocstring.customTemplatePath": ".vscode/autodocstring.template", // // Signal that we are using shared workspace settings in .vscode - "window.title": "sbi.vscode:: ${dirty}${activeEditorShort}${separator}${rootName}${separator}${appName}" + "window.title": "sbi.vscode:: ${dirty}${activeEditorShort}${separator}${rootName}${separator}${appName}", + "python.pythonPath": "/home/jnsbck/Applications/anaconda3/envs/sbi_env/bin/python" } \ No newline at end of file diff --git a/sbi/inference/posteriors/direct_posterior.py b/sbi/inference/posteriors/direct_posterior.py index 41861acce..44e958034 100644 --- a/sbi/inference/posteriors/direct_posterior.py +++ b/sbi/inference/posteriors/direct_posterior.py @@ -1,13 +1,14 @@ # This file is part of sbi, a toolkit for simulation-based inference. sbi is licensed # under the Affero General Public License v3, see . -from typing import Any, Callable, Dict, List, Optional, Union +from typing import Any, Callable, Dict, List, Optional, Union, Tuple import numpy as np import torch from torch import Tensor, log, nn from sbi import utils as utils +from pyknos.mdn.mdn import MultivariateGaussianMDN as mdn from sbi.inference.posteriors.base_posterior import NeuralPosterior from sbi.types import ScalarFloat, Shape from sbi.utils import del_entries @@ -368,6 +369,126 @@ def sample( return samples.reshape((*sample_shape, -1)) + def extract_and_transform_mog( + self, context: Tensor = None + ) -> Tuple[Tensor, Tensor, Tensor]: + """Extracts the Mixture of Gaussians (MoG) parameters + from the MDN at either the default x or input x. + + Args: + x: x at which to evaluate the MDN in order + to extract the MoG parameters. + """ + + # extract and rescale means, mixture componenets and covariances + nn = self.net + dist = nn._distribution + + if context == None: + encoded_x = nn._embedding_net(self.default_x) + else: + encoded_x = nn._embedding_net(context) + + logits, m, prec, sumlogdiag, _ = dist.get_mixture_components(encoded_x) + norm_logits = logits - torch.logsumexp(logits, dim=-1, keepdim=True) + + scale = nn._transform._transforms[0]._scale + shift = nn._transform._transforms[0]._shift + + means_transformed = (m - shift) / scale + + A = scale * torch.eye(means_transformed.shape[2]) + precision_factors_transformed = torch.cholesky(A @ prec @ A) + + sumlogdiag = torch.sum( + torch.log(torch.diagonal(precision_factors_transformed, dim1=2, dim2=3)), + dim=2, + ) + + return norm_logits, means_transformed, precision_factors_transformed, sumlogdiag + + def condition_mog( + self, + condition: Tensor, + dims: List[int], + logits: Tensor, + means: Tensor, + precfs: Tensor, + ): + """Finds the conditional distribution p(X|Y) for a GMM. + + Args: + condition: An array of inputs. Inputs set to NaN are not set, and become inputs to + the resulting distribution. Order is preserved. + + Raises: + ValueError: The chosen condition is not within the prior support. 
+ """ + + support = self._prior.support + + mask = torch.zeros(means.shape[-1], dtype=bool) + mask[dims] = True + + # check whether the condition is within the prior bounds + if ( + type(self._prior) is torch.distributions.uniform.Uniform + or type(self._prior) is utils.torchutils.BoxUniform + ): + cond_ubound = support.upper_bound[~mask] + cond_lbound = support.lower_bound[~mask] + within_support = torch.logical_and( + cond_lbound <= condition[:, ~mask], condition[:, ~mask] <= cond_ubound + ) + if ~torch.all(within_support): + raise ValueError( + "The chosen condition is not within the prior support." + ) + + y = condition[0, ~mask].reshape(1, -1) + + k, D = means.shape[1:] + + prec = precfs @ precfs.transpose(3, 2) + covs = torch.inverse(prec) + mcs = torch.exp(logits) + + mu_x = means[:, :, mask] + mu_y = means[:, :, ~mask] + + S_xx = covs[:, :, mask] + S_xx = S_xx[:, :, :, mask] + + S_yy = covs[:, :, ~mask] + S_yy = S_yy[:, :, :, ~mask] + + S_xy = covs[:, :, mask] + S_xy = S_xy[:, :, :, ~mask] + + means = mu_x + ( + (S_xy @ torch.inverse(S_yy) @ (y - mu_y).view(1, k, -1, 1)).transpose(3, 2) + ).view(1, k, -1) + cov = S_xx - S_xy @ torch.inverse(S_yy) @ S_xy.transpose(3, 2) + + prec_yy = torch.inverse(S_yy) + precf_yy = torch.cholesky(prec_yy) + + sumlogdiag = torch.sum( + torch.log(torch.diagonal(precf_yy, dim1=2, dim2=3)), dim=2 + ) + log_prob = mdn.log_prob_mog(y, torch.tensor([[0.0]]), mu_y, prec_yy, sumlogdiag) + fk = torch.exp(log_prob) # Used for normalizing the mc + + # Normalize the mixing coef: p(X|Y) = p(Y,X) / p(Y) using the marginal dist. + new_mcs = mcs * fk + new_mcs = new_mcs / new_mcs.sum() + + precfs = torch.cholesky(torch.inverse(cov)) + logits = torch.log(new_mcs) + sumlogdiag = torch.sum(torch.log(torch.diagonal(precfs, dim1=2, dim2=3)), dim=2) + + return logits, means, precfs, sumlogdiag + def sample_conditional( self, sample_shape: Shape, @@ -414,17 +535,86 @@ def sample_conditional( Returns: Samples from conditional posterior. """ + if type(self.net._distribution) is mdn: + num_samples = torch.Size(sample_shape).numel() - return super().sample_conditional( - PotentialFunctionProvider(), - sample_shape, - condition, - dims_to_sample, - x, - show_progress_bars, - mcmc_method, - mcmc_parameters, - ) + logits, means, precfs, sumlogdiag = self.extract_and_transform_mog(x) + logits, means, precfs, sumlogdiag = self.condition_mog( + condition, dims_to_sample, logits, means, precfs + ) + print( + "Warning: Sampling MoG analytically. Some of the samples might not be within the prior support!" + ) + samples = mdn.sample_mog(num_samples, logits, means, precfs) + return samples.detach().reshape((*sample_shape, -1)) + + else: + return super().sample_conditional( + PotentialFunctionProvider(), + sample_shape, + condition, + dims_to_sample, + x, + show_progress_bars, + mcmc_method, + mcmc_parameters, + ) + + def log_prob_conditional( + self, + theta: Tensor, + condition: Tensor, + dims_to_evaluate: List[int], + x: Optional[Tensor] = None, + ) -> Tensor: + """Evaluates the Mixture of Gaussian (MoG) + probability density function at a value x. + + Args: + theta: Parameters $\theta$. + condition: Parameter set that all dimensions not specified in + `dims_to_sample` will be fixed to. Should contain dim_theta elements, + i.e. it could e.g. be a sample from the posterior distribution. + The entries at all `dims_to_sample` will be ignored. + dims_to_evaluate: Which dimensions to evaluate the sample for. 
+ The dimensions not specified in `dims_to_evaluate` will be fixed to values given in `condition`. + x: Conditioning context for posterior $p(\theta|x)$. If not provided, + fall back onto `x` passed to `set_default_x()`. + + Returns: + log_prob: `(len(θ),)`-shaped log posterior probability $\log p(\theta|x + for θ in the support of the prior, -∞ (corresponding to 0 probability) outside. + """ + if type(self.net._distribution) == mdn: + logits, means, precfs, sumlogdiag = self.extract_and_transform_mog(x) + logits, means, precfs, sumlogdiag = self.condition_mog( + condition, dims_to_evaluate, logits, means, precfs + ) + + batch_size, dim = theta.shape + prec = precfs @ precfs.transpose(3, 2) + + self.net.eval() # leakage correction requires eval mode + + if dim != len(dims_to_evaluate): + X = X[:, dims_to_evaluate] + + print("Warning: Probabilities are not adjusted for leakage.") + log_prob = mdn.log_prob_mog( + theta, + logits.repeat(batch_size, 1), + means.repeat(batch_size, 1, 1), + prec.repeat(batch_size, 1, 1, 1), + sumlogdiag.repeat(batch_size, 1), + ) + + self.net.train(True) + return log_prob.detach() + + else: + raise NotImplementedError( + "This functionality is only available for MDN based posteriors." + ) def map( self, @@ -509,11 +699,7 @@ class PotentialFunctionProvider: """ def __call__( - self, - prior, - posterior_nn: nn.Module, - x: Tensor, - mcmc_method: str, + self, prior, posterior_nn: nn.Module, x: Tensor, mcmc_method: str, ) -> Callable: """Return potential function. @@ -549,8 +735,7 @@ def np_potential(self, theta: np.ndarray) -> ScalarFloat: with torch.set_grad_enabled(False): target_log_prob = self.posterior_nn.log_prob( - inputs=theta.to(self.x.device), - context=x_repeated, + inputs=theta.to(self.x.device), context=x_repeated, ) is_within_prior = torch.isfinite(self.prior.log_prob(theta)) target_log_prob[~is_within_prior] = -float("Inf") diff --git a/sbi/utils/conditional_density.py b/sbi/utils/conditional_density.py index 1c8266e0e..5cba7d90d 100644 --- a/sbi/utils/conditional_density.py +++ b/sbi/utils/conditional_density.py @@ -10,12 +10,6 @@ from sbi.utils.torchutils import ensure_theta_batched -from copy import deepcopy - -from sbi.inference.posteriors.direct_posterior import DirectPosterior -from tqdm.auto import tqdm - -from pyknos.mdn.mdn import MultivariateGaussianMDN as mdn def eval_conditional_density( density: Any, @@ -327,295 +321,3 @@ def _normalize_probs(probs: Tensor, limits: Tensor) -> Tensor: """ limits_diff = torch.prod(limits[:, 1] - limits[:, 0]) return probs * probs.numel() / limits_diff / torch.sum(probs) - - -class MDNPosterior(DirectPosterior): - """Wrapper around MDN based DirectPosterior instances. - - Extracts the Gaussian Mixture parameters from the Mixture - Density Network. Samples from Multivariate Gaussians directly, using - torch.distributions.multivariate_normal.MultivariateNormal - rather than going through the MDN. - Replaces `.sample()` and `.log_prob() methods of the `DirectPosterior`. - - Args: - MDN_Posterior: `DirectPosterior` instance, i.e. output of - `inference.build_posterior(density_estimator)`, - that was trained using a MDN. - - Attributes: - S: Tensor that holds the covariance matrices of all mixture components. - m: Tensor that holds the means of all mixture components. - mc: Tensor that holds mixture coefficients of all mixture components. - support: An Interval with lower and upper bounds of the support. 
- """ - - def __init__(self, MDN_Posterior: DirectPosterior): - if "MultivariateGaussianMDN" in MDN_Posterior.net.__str__(): - # wrap copy of input object into self - self.__class__ = type( - "MDNPosterior", (self.__class__, deepcopy(MDN_Posterior).__class__), {} - ) - self.__dict__ = deepcopy(MDN_Posterior).__dict__ - - # MoG parameters - self.precs = None - self.means = None - self.logits = None - self.sumlogdiag = None - self.support = self._prior.support - - self.extract_and_transform_mog() - - else: - raise AttributeError("Posterior does not contain a MDN.") - - def extract_and_transform_mog( - self, context: Tensor = None - ) -> Tuple[Tensor, Tensor, Tensor]: - """Extracts the Mixture of Gaussians (MoG) parameters - from the MDN at either the default x or input x. - - Args: - x: x at which to evaluate the MDN in order - to extract the MoG parameters. - """ - - # extract and rescale means, mixture componenets and covariances - nn = self.net - dist = nn._distribution - - if context == None: - encoded_x = nn._embedding_net(self.default_x) - else: - encoded_x = nn._embedding_net(torch.tensor(context, dtype=torch.float32)) - - logits, m, prec, sumlogdiag, _ = dist.get_mixture_components(encoded_x) - norm_logits = logits - torch.logsumexp(logits, dim=-1, keepdim=True) - - scale = nn._transform._transforms[0]._scale - shift = nn._transform._transforms[0]._shift - - means_transformed = ((m - shift) / scale).detach() - - A = scale * torch.eye(means_transformed.shape[2]) - precision_factors_transformed = torch.cholesky(A@prec@A) - - #prec = precision_factors_transformed@precision_factors_transformed.transpose(3,2) - #covs = torch.inverse(prec) - - self.logits = norm_logits.detach() - self.means = means_transformed.detach() - self.precs = precision_factors_transformed.detach() - self.sumlogdiag = torch.sum(torch.log(torch.diagonal(self.precs, dim1=2, dim2=3)),dim=2).detach() - - return norm_logits, means_transformed, precision_factors_transformed, sumlogdiag - - def log_prob(self, X: Tensor, individual=False) -> Tensor: - """Evaluates the Mixture of Gaussian (MoG) - probability density function at a value x. - - Args: - X: Values at which to evaluate the MoG pdf. - individual: If True the probability density is returned for each cluster component. - - Returns: - log_prob: Log probabilities at values specified by X. - """ - batch_size = X.shape[0] - prec = self.precs@self.precs.transpose(3,2) - - self.net.eval() # leakage correction requires eval mode - log_factor = torch.log(self.leakage_correction(x=self.default_x)) - - log_prob = mdn.log_prob_mog(X, self.logits.repeat(batch_size,1,1), self.means.repeat(batch_size,1,1), prec.repeat(batch_size,1,1,1), self.sumlogdiag.repeat(batch_size,1)) # only works for single samples - return log_prob - log_factor - - def sample(self, sample_shape: Tuple[int, int]) -> Tensor: - """Draw samples from a Mixture of Gaussians (MoG) - - Adpated from code courtesy of @ybernaerts. - Args: - sample_shape: The number of samples to draw from the MoG distribution. - - Returns: - X: A matrix with samples2 tensor([0.0002]) rows, and input dimension columns. 
- """ - - _, K, D = self.means.shape # Determine dimensionality - - # add sample posterior from prior (rejection sampling) - num_samples = torch.Size(sample_shape).numel() - - samples = mdn.sample_mog(num_samples, self.logits, self.means, self.precs) - - return samples.reshape((*sample_shape, -1)) - - def conditionalise(self, condition: Tensor): # -> ConditionalMDNPosterior: - """Instantiates a new conditional distribution, which can be evaluated - and sampled from. - - Args: - condition: An array of inputs. Inputs set to NaN are not set, and become inputs to - the resulting distribution. Order is preserved. - """ - return ConditionalMDNPosterior(self, condition) - - def sample_conditional( - self, condition: Tensor, sample_shape: Tuple[int, int] - ) -> Tensor: - """Conditionalises the distribution on the provided condition - and samples from the the resulting distribution. - - Args: - condition: An array of inputs. Inputs set to NaN are not set, and become inputs to - the resulting distribution. Order is preserved. - sample_shape: The number of samples to draw from the conditional distribution. - """ - conditional_posterior = ConditionalMDNPosterior(self, condition) - samples = cond_posteriori.sample(sample_shape) - return samples - - -class ConditionalMDNPosterior(MDNPosterior): - """Wrapperclass for `DirectPosterior` objects that were trained using - a Mixture Density Network (MDN) and have been conditionalised. - Replaces `.sample()`, `.sample_conditional()`, `.sample_with_mcmc()` and `.log_prob()` - methods. Enables the evaluation and sampling of the conditional - distribution at any arbitrary condition and point. - - Args: - MDN_Posterior: `DirectPosterior` instance, i.e. output of - `inference.build_posterior(density_estimator)`, - that was trained with a MDN. - condition: A vector that holds the conditioned vector. Entries that contain - NaNs are not set and become inputs to the resulting distribution, - i.e. condition = [x1, x2, NaN, NaN] -> p(x3,x4|x1,x2). - - Attributes: - condition: A Tensor containing the values which the MoG has been conditioned on. - """ - - def __init__(self, MDN_Posterior: DirectPosterior, condition: Tensor): - self.__class__ = type( - "ConditionalMDNPosterior", - (self.__class__, deepcopy(MDN_Posterior).__class__), - {}, - ) - self.__dict__ = deepcopy(MDN_Posterior).__dict__ - self.condition = condition - self.__conditionalise(condition) - - def __conditionalise(self, condition: Tensor): - """Finds the conditional distribution p(X|Y) for a GMM. - - Args: - condition: An array of inputs. Inputs set to NaN are not set, and become inputs to - the resulting distribution. Order is preserved. - - Raises: - ValueError: The chosen condition is not within the prior support. 
- """ - - # revert to the old GMM parameters first - self.extract_and_transform_mog() - self.support = self._prior.support - - pop = self.condition.isnan().reshape(-1) - condition_without_NaNs = self.condition.reshape(-1)[~pop] - - # check whether the condition is within the prior bounds - cond_ubound = self.support.upper_bound[~pop] - cond_lbound = self.support.lower_bound[~pop] - within_support = torch.logical_and( - cond_lbound <= condition_without_NaNs, condition_without_NaNs <= cond_ubound - ) - if ~torch.all(within_support): - raise ValueError("The chosen condition is not within the prior support.") - - # adjust the dimensionality of the support - self.support.upper_bound = self.support.upper_bound[pop] - self.support.lower_bound = self.support.lower_bound[pop] - - not_set_idx = torch.nonzero(torch.isnan(condition))[ - :, 1 - ] # indices for not set parameters - set_idx = torch.nonzero(~torch.isnan(condition))[ - :, 1 - ] # indices for set parameters - new_idx = torch.cat( - (not_set_idx, set_idx) - ) # indices with not set parameters first and then set parameters - y = condition[0, set_idx].reshape(1,-1) - - k, D = self.means.shape[1:] - d_new = not_set_idx.shape[0] - - # New centroids and covar matrices - new_cen = torch.zeros(1,k,d_new) - new_ccovs = torch.zeros(1,k,d_new,d_new) - # Appendix A in C. E. Rasmussen & C. K. I. Williams, Gaussian Processes - # for Machine Learning, the MIT Press, 2006 - fk = torch.zeros(1,k) - prec = self.precs@self.precs.transpose(3,2) - covs = torch.inverse(prec) - mcs = torch.exp(self.logits) - - for i in range(self.means.shape[1]): - # Make a new co-variance matrix with correct ordering - new_ccov = covs[:,i].clone() - new_ccov = new_ccov[:,:, new_idx] - new_ccov = new_ccov[:,new_idx, :] - ux = self.means[:,i, not_set_idx] - uy = self.means[:,i, set_idx] - A = new_ccov[:,0 : len(not_set_idx), 0 : len(not_set_idx)] - B = new_ccov[:,len(not_set_idx) :, len(not_set_idx) :] - C = new_ccov[:,0 : len(not_set_idx), len(not_set_idx) :] - cen = ux + (C @ torch.inverse(B) @ (y - uy).T).transpose(2,1) - cov = A - C @ torch.inverse(B) @ C.transpose(2, 1) - new_cen[:,i] = cen - new_ccovs[:,i] = cov - prec_B = torch.inverse(B) - precf = torch.cholesky(prec_B) - sumlogdiag = torch.sum(precf.diagonal()).reshape(1,-1) - log_prob = mdn.log_prob_mog(y, torch.tensor([[0.0]]), uy.view(1,1,-1), prec_B, sumlogdiag) - fk[:,i] = torch.exp(log_prob) # Used for normalizing the mc - # Normalize the mixing coef: p(X|Y) = p(Y,X) / p(Y) using the marginal dist. - new_mcs = mcs * fk - new_mcs = new_mcs / torch.sum(new_mcs) - - # set new GMM parameters - self.means = new_cen - self.precs = torch.cholesky(torch.inverse(new_ccovs)) - self.logits = torch.log(new_mcs) - self.sumlogdiag = torch.sum(torch.log(torch.diagonal(self.precs, dim1=2, dim2=3)),dim=2) - - - def sample_with_mcmc(self): - """Dummy function to overwrite the existing `.sample_with_mcmc()` method.""" - - raise DeprecationWarning( - "MCMC sampling is not yet supported for the conditional MDN." - ) - - def sample_conditional( - self, sample_shape: Tuple[int, int], condition: Tensor = None - ) -> Tensor: - """Samples from the condtional distribution. If a condition - is provided, a new conditional distribution will be calculated. - If no condition is provided, samples will be drawn from the - exisiting condition. - - Args: - n_samples: The number of samples to draw from the conditional distribution. - condition: An array of inputs. Inputs set to NaN are not set, and become inputs to - the resulting distribution. 
Order is preserved. - - Returns: - samples: Contains samples from the conditional posterior (NxD). - """ - - if condition != None: - self.__conditionalise(condition) - samples = self.sample(sample_shape) - return samples diff --git a/tests/torchutils_test.py b/tests/torchutils_test.py index 20ac71323..a0889cc4f 100644 --- a/tests/torchutils_test.py +++ b/tests/torchutils_test.py @@ -1,196 +1,196 @@ -# This file is part of sbi, a toolkit for simulation-based inference. sbi is licensed -# under the Affero General Public License v3, see . - -"""Test PyTorch utility functions.""" -from __future__ import annotations - -import numpy as np -import torch -import torchtestcase -from torch import distributions as distributions -from torch import eye, ones, zeros - -from sbi.utils import torchutils -from tests.test_utils import kl_d_via_monte_carlo - - -# XXX move to pytest? - investigate how to derive from TorchTestCase -class TorchUtilsTest(torchtestcase.TorchTestCase): - def test_split_leading_dim(self): - x = torch.randn(24, 5) - self.assertEqual(torchutils.split_leading_dim(x, [-1]), x) - self.assertEqual(torchutils.split_leading_dim(x, [2, -1]), x.view(2, 12, 5)) - self.assertEqual( - torchutils.split_leading_dim(x, [2, 3, -1]), x.view(2, 3, 4, 5) - ) - with self.assertRaises(Exception): - self.assertEqual(torchutils.split_leading_dim(x, []), x) - with self.assertRaises(Exception): - self.assertEqual(torchutils.split_leading_dim(x, [5, 5]), x) - - def test_merge_leading_dims(self): - x = torch.randn(2, 3, 4, 5) - self.assertEqual(torchutils.merge_leading_dims(x, 1), x) - self.assertEqual(torchutils.merge_leading_dims(x, 2), x.view(6, 4, 5)) - self.assertEqual(torchutils.merge_leading_dims(x, 3), x.view(24, 5)) - self.assertEqual(torchutils.merge_leading_dims(x, 4), x.view(120)) - with self.assertRaises(Exception): - torchutils.merge_leading_dims(x, 0) - with self.assertRaises(Exception): - torchutils.merge_leading_dims(x, 5) - - def test_split_merge_leading_dims_are_consistent(self): - x = torch.randn(2, 3, 4, 5) - y = torchutils.split_leading_dim(torchutils.merge_leading_dims(x, 1), [2]) - self.assertEqual(y, x) - y = torchutils.split_leading_dim(torchutils.merge_leading_dims(x, 2), [2, 3]) - self.assertEqual(y, x) - y = torchutils.split_leading_dim(torchutils.merge_leading_dims(x, 3), [2, 3, 4]) - self.assertEqual(y, x) - y = torchutils.split_leading_dim( - torchutils.merge_leading_dims(x, 4), [2, 3, 4, 5] - ) - self.assertEqual(y, x) - - def test_repeat_rows(self): - x = torch.randn(2, 3, 4, 5) - self.assertEqual(torchutils.repeat_rows(x, 1), x) - y = torchutils.repeat_rows(x, 2) - self.assertEqual(y.shape, torch.Size([4, 3, 4, 5])) - self.assertEqual(x[0], y[0]) - self.assertEqual(x[0], y[1]) - self.assertEqual(x[1], y[2]) - self.assertEqual(x[1], y[3]) - with self.assertRaises(Exception): - torchutils.repeat_rows(x, 0) - - def test_logabsdet(self): - size = 10 - matrix = torch.randn(size, size) - logabsdet = torchutils.logabsdet(matrix) - logabsdet_ref = torch.log(torch.abs(matrix.det())) - self.eps = 1e-6 - self.assertEqual(logabsdet, logabsdet_ref) - - def test_random_orthogonal(self): - size = 100 - matrix = torchutils.random_orthogonal(size) - self.assertIsInstance(matrix, torch.Tensor) - self.assertEqual(matrix.shape, torch.Size([size, size])) - self.eps = 1e-5 - unit = eye(size, size) - self.assertEqual(matrix @ matrix.t(), unit) - self.assertEqual(matrix.t() @ matrix, unit) - self.assertEqual(matrix.t(), matrix.inverse()) - self.assertEqual(torch.abs(matrix.det()), 
torch.tensor(1.0)) - - def test_searchsorted(self): - bin_locations = torch.linspace(0, 1, 10) # 9 bins == 10 locations - - left_boundaries = bin_locations[:-1] - right_boundaries = bin_locations[:-1] + 0.1 - mid_points = bin_locations[:-1] + 0.05 - - for inputs in [left_boundaries, right_boundaries, mid_points]: - with self.subTest(inputs=inputs): - idx = torchutils.searchsorted(bin_locations[None, :], inputs) - self.assertEqual(idx, torch.arange(0, 9)) - - def test_searchsorted_arbitrary_shape(self): - shape = [2, 3, 4] - bin_locations = torch.linspace(0, 1, 10).repeat(*shape, 1) - inputs = torch.rand(*shape) - idx = torchutils.searchsorted(bin_locations, inputs) - self.assertEqual(idx.shape, inputs.shape) - - -def test_box_uniform_distribution(): - bu1 = torchutils.BoxUniform(low=0.0, high=torch.tensor([3.0, 3.0, 3.0])) - - assert bu1.event_shape == torch.Size([3]) - - -def test_ensure_batch_dim(): - # test if batch dimension is added when parameter is ndim==1 - t1 = torch.tensor([0.0, -1.0, 1.0]) - t2 = torchutils.ensure_theta_batched(t1) - assert t2.ndim == 2 - - # test if batch dimension is added when observation is ndim==1 - t1 = torch.tensor([0.0, -1.0, 1.0]) - t2 = torchutils.ensure_x_batched(t1) - assert t2.ndim == 2 - - # then test if batch dimension is added when observation is ndim==2, e.g. an image - t1 = torch.tensor([[1, 2, 3], [1, 2, 3]]) - t2 = torchutils.ensure_x_batched(t1) - assert t2.ndim == 3 - - -def test_atleast_2d_many(): - t1 = np.array([0.0, -1.0, 1.0]) - t2 = torch.tensor([[1, 2, 3]]) - - t3, t4 = torchutils.atleast_2d_many(t1, t2) - - assert isinstance(t3, torch.Tensor) - assert t3.ndim == 2 - assert t4.ndim == 2 - - -def test_maybe_add_batch_dim_to_size(): - t1 = torch.Size([1]) - t2 = torchutils.maybe_add_batch_dim_to_size(t1) - assert t2 == torch.Size([1, 1]) - - t1 = torch.Size([3]) - t2 = torchutils.maybe_add_batch_dim_to_size(t1) - assert t2 == torch.Size([1, 3]) - - t1 = torch.Size([1, 3]) - t2 = torchutils.maybe_add_batch_dim_to_size(t1) - assert t2 == torch.Size([1, 3]) - - t1 = torch.Size([2, 3]) - t2 = torchutils.maybe_add_batch_dim_to_size(t1) - assert t2 == torch.Size([2, 3]) - - t1 = torch.Size([1, 2, 3]) - t2 = torchutils.maybe_add_batch_dim_to_size(t1) - assert t2 == torch.Size([1, 2, 3]) - - -def test_batched_first_of_batch(): - t = torch.ones(10, 2) - out_t = torchutils.batched_first_of_batch(t) - assert (out_t == torch.ones(1, 2)).all() - - t = torch.ones(1, 2) - out_t = torchutils.batched_first_of_batch(t) - assert (out_t == torch.ones(1, 2)).all() - - -def test_dkl_gauss(): - """ - Test whether for two 1D Gaussians and two 2D Gaussians the Monte-Carlo-based KLd - gives similar results as the torch implementation. - """ - dist1 = ( - distributions.Normal(loc=0.0, scale=1.0), - distributions.MultivariateNormal(zeros(2), eye(2)), - ) - dist2 = ( - distributions.Normal(loc=1.0, scale=0.5), - distributions.MultivariateNormal(ones(2), 0.5 * eye(2)), - ) +# # This file is part of sbi, a toolkit for simulation-based inference. sbi is licensed +# # under the Affero General Public License v3, see . + +# """Test PyTorch utility functions.""" +# from __future__ import annotations + +# import numpy as np +# import torch +# import torchtestcase +# from torch import distributions as distributions +# from torch import eye, ones, zeros + +# from sbi.utils import torchutils +# from tests.test_utils import kl_d_via_monte_carlo + + +# # XXX move to pytest? 
- investigate how to derive from TorchTestCase +# class TorchUtilsTest(torchtestcase.TorchTestCase): +# def test_split_leading_dim(self): +# x = torch.randn(24, 5) +# self.assertEqual(torchutils.split_leading_dim(x, [-1]), x) +# self.assertEqual(torchutils.split_leading_dim(x, [2, -1]), x.view(2, 12, 5)) +# self.assertEqual( +# torchutils.split_leading_dim(x, [2, 3, -1]), x.view(2, 3, 4, 5) +# ) +# with self.assertRaises(Exception): +# self.assertEqual(torchutils.split_leading_dim(x, []), x) +# with self.assertRaises(Exception): +# self.assertEqual(torchutils.split_leading_dim(x, [5, 5]), x) + +# def test_merge_leading_dims(self): +# x = torch.randn(2, 3, 4, 5) +# self.assertEqual(torchutils.merge_leading_dims(x, 1), x) +# self.assertEqual(torchutils.merge_leading_dims(x, 2), x.view(6, 4, 5)) +# self.assertEqual(torchutils.merge_leading_dims(x, 3), x.view(24, 5)) +# self.assertEqual(torchutils.merge_leading_dims(x, 4), x.view(120)) +# with self.assertRaises(Exception): +# torchutils.merge_leading_dims(x, 0) +# with self.assertRaises(Exception): +# torchutils.merge_leading_dims(x, 5) + +# def test_split_merge_leading_dims_are_consistent(self): +# x = torch.randn(2, 3, 4, 5) +# y = torchutils.split_leading_dim(torchutils.merge_leading_dims(x, 1), [2]) +# self.assertEqual(y, x) +# y = torchutils.split_leading_dim(torchutils.merge_leading_dims(x, 2), [2, 3]) +# self.assertEqual(y, x) +# y = torchutils.split_leading_dim(torchutils.merge_leading_dims(x, 3), [2, 3, 4]) +# self.assertEqual(y, x) +# y = torchutils.split_leading_dim( +# torchutils.merge_leading_dims(x, 4), [2, 3, 4, 5] +# ) +# self.assertEqual(y, x) + +# def test_repeat_rows(self): +# x = torch.randn(2, 3, 4, 5) +# self.assertEqual(torchutils.repeat_rows(x, 1), x) +# y = torchutils.repeat_rows(x, 2) +# self.assertEqual(y.shape, torch.Size([4, 3, 4, 5])) +# self.assertEqual(x[0], y[0]) +# self.assertEqual(x[0], y[1]) +# self.assertEqual(x[1], y[2]) +# self.assertEqual(x[1], y[3]) +# with self.assertRaises(Exception): +# torchutils.repeat_rows(x, 0) + +# def test_logabsdet(self): +# size = 10 +# matrix = torch.randn(size, size) +# logabsdet = torchutils.logabsdet(matrix) +# logabsdet_ref = torch.log(torch.abs(matrix.det())) +# self.eps = 1e-6 +# self.assertEqual(logabsdet, logabsdet_ref) + +# def test_random_orthogonal(self): +# size = 100 +# matrix = torchutils.random_orthogonal(size) +# self.assertIsInstance(matrix, torch.Tensor) +# self.assertEqual(matrix.shape, torch.Size([size, size])) +# self.eps = 1e-5 +# unit = eye(size, size) +# self.assertEqual(matrix @ matrix.t(), unit) +# self.assertEqual(matrix.t() @ matrix, unit) +# self.assertEqual(matrix.t(), matrix.inverse()) +# self.assertEqual(torch.abs(matrix.det()), torch.tensor(1.0)) + +# def test_searchsorted(self): +# bin_locations = torch.linspace(0, 1, 10) # 9 bins == 10 locations + +# left_boundaries = bin_locations[:-1] +# right_boundaries = bin_locations[:-1] + 0.1 +# mid_points = bin_locations[:-1] + 0.05 + +# for inputs in [left_boundaries, right_boundaries, mid_points]: +# with self.subTest(inputs=inputs): +# idx = torchutils.searchsorted(bin_locations[None, :], inputs) +# self.assertEqual(idx, torch.arange(0, 9)) + +# def test_searchsorted_arbitrary_shape(self): +# shape = [2, 3, 4] +# bin_locations = torch.linspace(0, 1, 10).repeat(*shape, 1) +# inputs = torch.rand(*shape) +# idx = torchutils.searchsorted(bin_locations, inputs) +# self.assertEqual(idx.shape, inputs.shape) + + +# def test_box_uniform_distribution(): +# bu1 = torchutils.BoxUniform(low=0.0, 
high=torch.tensor([3.0, 3.0, 3.0])) + +# assert bu1.event_shape == torch.Size([3]) + + +# def test_ensure_batch_dim(): +# # test if batch dimension is added when parameter is ndim==1 +# t1 = torch.tensor([0.0, -1.0, 1.0]) +# t2 = torchutils.ensure_theta_batched(t1) +# assert t2.ndim == 2 + +# # test if batch dimension is added when observation is ndim==1 +# t1 = torch.tensor([0.0, -1.0, 1.0]) +# t2 = torchutils.ensure_x_batched(t1) +# assert t2.ndim == 2 + +# # then test if batch dimension is added when observation is ndim==2, e.g. an image +# t1 = torch.tensor([[1, 2, 3], [1, 2, 3]]) +# t2 = torchutils.ensure_x_batched(t1) +# assert t2.ndim == 3 + + +# def test_atleast_2d_many(): +# t1 = np.array([0.0, -1.0, 1.0]) +# t2 = torch.tensor([[1, 2, 3]]) + +# t3, t4 = torchutils.atleast_2d_many(t1, t2) + +# assert isinstance(t3, torch.Tensor) +# assert t3.ndim == 2 +# assert t4.ndim == 2 + + +# def test_maybe_add_batch_dim_to_size(): +# t1 = torch.Size([1]) +# t2 = torchutils.maybe_add_batch_dim_to_size(t1) +# assert t2 == torch.Size([1, 1]) + +# t1 = torch.Size([3]) +# t2 = torchutils.maybe_add_batch_dim_to_size(t1) +# assert t2 == torch.Size([1, 3]) + +# t1 = torch.Size([1, 3]) +# t2 = torchutils.maybe_add_batch_dim_to_size(t1) +# assert t2 == torch.Size([1, 3]) + +# t1 = torch.Size([2, 3]) +# t2 = torchutils.maybe_add_batch_dim_to_size(t1) +# assert t2 == torch.Size([2, 3]) + +# t1 = torch.Size([1, 2, 3]) +# t2 = torchutils.maybe_add_batch_dim_to_size(t1) +# assert t2 == torch.Size([1, 2, 3]) + + +# def test_batched_first_of_batch(): +# t = torch.ones(10, 2) +# out_t = torchutils.batched_first_of_batch(t) +# assert (out_t == torch.ones(1, 2)).all() + +# t = torch.ones(1, 2) +# out_t = torchutils.batched_first_of_batch(t) +# assert (out_t == torch.ones(1, 2)).all() + + +# def test_dkl_gauss(): +# """ +# Test whether for two 1D Gaussians and two 2D Gaussians the Monte-Carlo-based KLd +# gives similar results as the torch implementation. +# """ +# dist1 = ( +# distributions.Normal(loc=0.0, scale=1.0), +# distributions.MultivariateNormal(zeros(2), eye(2)), +# ) +# dist2 = ( +# distributions.Normal(loc=1.0, scale=0.5), +# distributions.MultivariateNormal(ones(2), 0.5 * eye(2)), +# ) - for d1, d2 in zip(dist1, dist2): - torch_dkl = distributions.kl.kl_divergence(d1, d2) - monte_carlo_dkl = kl_d_via_monte_carlo(d1, d2, num_samples=5000) +# for d1, d2 in zip(dist1, dist2): +# torch_dkl = distributions.kl.kl_divergence(d1, d2) +# monte_carlo_dkl = kl_d_via_monte_carlo(d1, d2, num_samples=5000) - max_dkl_diff = 0.4 +# max_dkl_diff = 0.4 - assert torch.abs(torch_dkl - monte_carlo_dkl) < max_dkl_diff, ( - f"Monte-Carlo-based KLd={monte_carlo_dkl} is too far from the torch" - f" implementation, {torch_dkl}." - ) +# assert torch.abs(torch_dkl - monte_carlo_dkl) < max_dkl_diff, ( +# f"Monte-Carlo-based KLd={monte_carlo_dkl} is too far from the torch" +# f" implementation, {torch_dkl}." 
+# ) From 383c80379e0b744cbbc72f8c4f5f907984c04692 Mon Sep 17 00:00:00 2001 From: jonasbeck Date: Sun, 18 Apr 2021 19:17:41 +0200 Subject: [PATCH 09/20] Removed accidentally added files from pull request --- .vscode/settings.json | 3 +- tests/torchutils_test.py | 386 +++++++++++++++++++-------------------- 2 files changed, 194 insertions(+), 195 deletions(-) diff --git a/.vscode/settings.json b/.vscode/settings.json index cf36b975a..a3e50d6e5 100644 --- a/.vscode/settings.json +++ b/.vscode/settings.json @@ -72,6 +72,5 @@ "autoDocstring.customTemplatePath": ".vscode/autodocstring.template", // // Signal that we are using shared workspace settings in .vscode - "window.title": "sbi.vscode:: ${dirty}${activeEditorShort}${separator}${rootName}${separator}${appName}", - "python.pythonPath": "/home/jnsbck/Applications/anaconda3/envs/sbi_env/bin/python" + "window.title": "sbi.vscode:: ${dirty}${activeEditorShort}${separator}${rootName}${separator}${appName}" } \ No newline at end of file diff --git a/tests/torchutils_test.py b/tests/torchutils_test.py index a0889cc4f..20ac71323 100644 --- a/tests/torchutils_test.py +++ b/tests/torchutils_test.py @@ -1,196 +1,196 @@ -# # This file is part of sbi, a toolkit for simulation-based inference. sbi is licensed -# # under the Affero General Public License v3, see . - -# """Test PyTorch utility functions.""" -# from __future__ import annotations - -# import numpy as np -# import torch -# import torchtestcase -# from torch import distributions as distributions -# from torch import eye, ones, zeros - -# from sbi.utils import torchutils -# from tests.test_utils import kl_d_via_monte_carlo - - -# # XXX move to pytest? - investigate how to derive from TorchTestCase -# class TorchUtilsTest(torchtestcase.TorchTestCase): -# def test_split_leading_dim(self): -# x = torch.randn(24, 5) -# self.assertEqual(torchutils.split_leading_dim(x, [-1]), x) -# self.assertEqual(torchutils.split_leading_dim(x, [2, -1]), x.view(2, 12, 5)) -# self.assertEqual( -# torchutils.split_leading_dim(x, [2, 3, -1]), x.view(2, 3, 4, 5) -# ) -# with self.assertRaises(Exception): -# self.assertEqual(torchutils.split_leading_dim(x, []), x) -# with self.assertRaises(Exception): -# self.assertEqual(torchutils.split_leading_dim(x, [5, 5]), x) - -# def test_merge_leading_dims(self): -# x = torch.randn(2, 3, 4, 5) -# self.assertEqual(torchutils.merge_leading_dims(x, 1), x) -# self.assertEqual(torchutils.merge_leading_dims(x, 2), x.view(6, 4, 5)) -# self.assertEqual(torchutils.merge_leading_dims(x, 3), x.view(24, 5)) -# self.assertEqual(torchutils.merge_leading_dims(x, 4), x.view(120)) -# with self.assertRaises(Exception): -# torchutils.merge_leading_dims(x, 0) -# with self.assertRaises(Exception): -# torchutils.merge_leading_dims(x, 5) - -# def test_split_merge_leading_dims_are_consistent(self): -# x = torch.randn(2, 3, 4, 5) -# y = torchutils.split_leading_dim(torchutils.merge_leading_dims(x, 1), [2]) -# self.assertEqual(y, x) -# y = torchutils.split_leading_dim(torchutils.merge_leading_dims(x, 2), [2, 3]) -# self.assertEqual(y, x) -# y = torchutils.split_leading_dim(torchutils.merge_leading_dims(x, 3), [2, 3, 4]) -# self.assertEqual(y, x) -# y = torchutils.split_leading_dim( -# torchutils.merge_leading_dims(x, 4), [2, 3, 4, 5] -# ) -# self.assertEqual(y, x) - -# def test_repeat_rows(self): -# x = torch.randn(2, 3, 4, 5) -# self.assertEqual(torchutils.repeat_rows(x, 1), x) -# y = torchutils.repeat_rows(x, 2) -# self.assertEqual(y.shape, torch.Size([4, 3, 4, 5])) -# self.assertEqual(x[0], y[0]) -# 
self.assertEqual(x[0], y[1]) -# self.assertEqual(x[1], y[2]) -# self.assertEqual(x[1], y[3]) -# with self.assertRaises(Exception): -# torchutils.repeat_rows(x, 0) - -# def test_logabsdet(self): -# size = 10 -# matrix = torch.randn(size, size) -# logabsdet = torchutils.logabsdet(matrix) -# logabsdet_ref = torch.log(torch.abs(matrix.det())) -# self.eps = 1e-6 -# self.assertEqual(logabsdet, logabsdet_ref) - -# def test_random_orthogonal(self): -# size = 100 -# matrix = torchutils.random_orthogonal(size) -# self.assertIsInstance(matrix, torch.Tensor) -# self.assertEqual(matrix.shape, torch.Size([size, size])) -# self.eps = 1e-5 -# unit = eye(size, size) -# self.assertEqual(matrix @ matrix.t(), unit) -# self.assertEqual(matrix.t() @ matrix, unit) -# self.assertEqual(matrix.t(), matrix.inverse()) -# self.assertEqual(torch.abs(matrix.det()), torch.tensor(1.0)) - -# def test_searchsorted(self): -# bin_locations = torch.linspace(0, 1, 10) # 9 bins == 10 locations - -# left_boundaries = bin_locations[:-1] -# right_boundaries = bin_locations[:-1] + 0.1 -# mid_points = bin_locations[:-1] + 0.05 - -# for inputs in [left_boundaries, right_boundaries, mid_points]: -# with self.subTest(inputs=inputs): -# idx = torchutils.searchsorted(bin_locations[None, :], inputs) -# self.assertEqual(idx, torch.arange(0, 9)) - -# def test_searchsorted_arbitrary_shape(self): -# shape = [2, 3, 4] -# bin_locations = torch.linspace(0, 1, 10).repeat(*shape, 1) -# inputs = torch.rand(*shape) -# idx = torchutils.searchsorted(bin_locations, inputs) -# self.assertEqual(idx.shape, inputs.shape) - - -# def test_box_uniform_distribution(): -# bu1 = torchutils.BoxUniform(low=0.0, high=torch.tensor([3.0, 3.0, 3.0])) - -# assert bu1.event_shape == torch.Size([3]) - - -# def test_ensure_batch_dim(): -# # test if batch dimension is added when parameter is ndim==1 -# t1 = torch.tensor([0.0, -1.0, 1.0]) -# t2 = torchutils.ensure_theta_batched(t1) -# assert t2.ndim == 2 - -# # test if batch dimension is added when observation is ndim==1 -# t1 = torch.tensor([0.0, -1.0, 1.0]) -# t2 = torchutils.ensure_x_batched(t1) -# assert t2.ndim == 2 - -# # then test if batch dimension is added when observation is ndim==2, e.g. 
an image -# t1 = torch.tensor([[1, 2, 3], [1, 2, 3]]) -# t2 = torchutils.ensure_x_batched(t1) -# assert t2.ndim == 3 - - -# def test_atleast_2d_many(): -# t1 = np.array([0.0, -1.0, 1.0]) -# t2 = torch.tensor([[1, 2, 3]]) - -# t3, t4 = torchutils.atleast_2d_many(t1, t2) - -# assert isinstance(t3, torch.Tensor) -# assert t3.ndim == 2 -# assert t4.ndim == 2 - - -# def test_maybe_add_batch_dim_to_size(): -# t1 = torch.Size([1]) -# t2 = torchutils.maybe_add_batch_dim_to_size(t1) -# assert t2 == torch.Size([1, 1]) - -# t1 = torch.Size([3]) -# t2 = torchutils.maybe_add_batch_dim_to_size(t1) -# assert t2 == torch.Size([1, 3]) - -# t1 = torch.Size([1, 3]) -# t2 = torchutils.maybe_add_batch_dim_to_size(t1) -# assert t2 == torch.Size([1, 3]) - -# t1 = torch.Size([2, 3]) -# t2 = torchutils.maybe_add_batch_dim_to_size(t1) -# assert t2 == torch.Size([2, 3]) - -# t1 = torch.Size([1, 2, 3]) -# t2 = torchutils.maybe_add_batch_dim_to_size(t1) -# assert t2 == torch.Size([1, 2, 3]) - - -# def test_batched_first_of_batch(): -# t = torch.ones(10, 2) -# out_t = torchutils.batched_first_of_batch(t) -# assert (out_t == torch.ones(1, 2)).all() - -# t = torch.ones(1, 2) -# out_t = torchutils.batched_first_of_batch(t) -# assert (out_t == torch.ones(1, 2)).all() - - -# def test_dkl_gauss(): -# """ -# Test whether for two 1D Gaussians and two 2D Gaussians the Monte-Carlo-based KLd -# gives similar results as the torch implementation. -# """ -# dist1 = ( -# distributions.Normal(loc=0.0, scale=1.0), -# distributions.MultivariateNormal(zeros(2), eye(2)), -# ) -# dist2 = ( -# distributions.Normal(loc=1.0, scale=0.5), -# distributions.MultivariateNormal(ones(2), 0.5 * eye(2)), -# ) +# This file is part of sbi, a toolkit for simulation-based inference. sbi is licensed +# under the Affero General Public License v3, see . + +"""Test PyTorch utility functions.""" +from __future__ import annotations + +import numpy as np +import torch +import torchtestcase +from torch import distributions as distributions +from torch import eye, ones, zeros + +from sbi.utils import torchutils +from tests.test_utils import kl_d_via_monte_carlo + + +# XXX move to pytest? 
- investigate how to derive from TorchTestCase +class TorchUtilsTest(torchtestcase.TorchTestCase): + def test_split_leading_dim(self): + x = torch.randn(24, 5) + self.assertEqual(torchutils.split_leading_dim(x, [-1]), x) + self.assertEqual(torchutils.split_leading_dim(x, [2, -1]), x.view(2, 12, 5)) + self.assertEqual( + torchutils.split_leading_dim(x, [2, 3, -1]), x.view(2, 3, 4, 5) + ) + with self.assertRaises(Exception): + self.assertEqual(torchutils.split_leading_dim(x, []), x) + with self.assertRaises(Exception): + self.assertEqual(torchutils.split_leading_dim(x, [5, 5]), x) + + def test_merge_leading_dims(self): + x = torch.randn(2, 3, 4, 5) + self.assertEqual(torchutils.merge_leading_dims(x, 1), x) + self.assertEqual(torchutils.merge_leading_dims(x, 2), x.view(6, 4, 5)) + self.assertEqual(torchutils.merge_leading_dims(x, 3), x.view(24, 5)) + self.assertEqual(torchutils.merge_leading_dims(x, 4), x.view(120)) + with self.assertRaises(Exception): + torchutils.merge_leading_dims(x, 0) + with self.assertRaises(Exception): + torchutils.merge_leading_dims(x, 5) + + def test_split_merge_leading_dims_are_consistent(self): + x = torch.randn(2, 3, 4, 5) + y = torchutils.split_leading_dim(torchutils.merge_leading_dims(x, 1), [2]) + self.assertEqual(y, x) + y = torchutils.split_leading_dim(torchutils.merge_leading_dims(x, 2), [2, 3]) + self.assertEqual(y, x) + y = torchutils.split_leading_dim(torchutils.merge_leading_dims(x, 3), [2, 3, 4]) + self.assertEqual(y, x) + y = torchutils.split_leading_dim( + torchutils.merge_leading_dims(x, 4), [2, 3, 4, 5] + ) + self.assertEqual(y, x) + + def test_repeat_rows(self): + x = torch.randn(2, 3, 4, 5) + self.assertEqual(torchutils.repeat_rows(x, 1), x) + y = torchutils.repeat_rows(x, 2) + self.assertEqual(y.shape, torch.Size([4, 3, 4, 5])) + self.assertEqual(x[0], y[0]) + self.assertEqual(x[0], y[1]) + self.assertEqual(x[1], y[2]) + self.assertEqual(x[1], y[3]) + with self.assertRaises(Exception): + torchutils.repeat_rows(x, 0) + + def test_logabsdet(self): + size = 10 + matrix = torch.randn(size, size) + logabsdet = torchutils.logabsdet(matrix) + logabsdet_ref = torch.log(torch.abs(matrix.det())) + self.eps = 1e-6 + self.assertEqual(logabsdet, logabsdet_ref) + + def test_random_orthogonal(self): + size = 100 + matrix = torchutils.random_orthogonal(size) + self.assertIsInstance(matrix, torch.Tensor) + self.assertEqual(matrix.shape, torch.Size([size, size])) + self.eps = 1e-5 + unit = eye(size, size) + self.assertEqual(matrix @ matrix.t(), unit) + self.assertEqual(matrix.t() @ matrix, unit) + self.assertEqual(matrix.t(), matrix.inverse()) + self.assertEqual(torch.abs(matrix.det()), torch.tensor(1.0)) + + def test_searchsorted(self): + bin_locations = torch.linspace(0, 1, 10) # 9 bins == 10 locations + + left_boundaries = bin_locations[:-1] + right_boundaries = bin_locations[:-1] + 0.1 + mid_points = bin_locations[:-1] + 0.05 + + for inputs in [left_boundaries, right_boundaries, mid_points]: + with self.subTest(inputs=inputs): + idx = torchutils.searchsorted(bin_locations[None, :], inputs) + self.assertEqual(idx, torch.arange(0, 9)) + + def test_searchsorted_arbitrary_shape(self): + shape = [2, 3, 4] + bin_locations = torch.linspace(0, 1, 10).repeat(*shape, 1) + inputs = torch.rand(*shape) + idx = torchutils.searchsorted(bin_locations, inputs) + self.assertEqual(idx.shape, inputs.shape) + + +def test_box_uniform_distribution(): + bu1 = torchutils.BoxUniform(low=0.0, high=torch.tensor([3.0, 3.0, 3.0])) + + assert bu1.event_shape == torch.Size([3]) + + +def 
test_ensure_batch_dim(): + # test if batch dimension is added when parameter is ndim==1 + t1 = torch.tensor([0.0, -1.0, 1.0]) + t2 = torchutils.ensure_theta_batched(t1) + assert t2.ndim == 2 + + # test if batch dimension is added when observation is ndim==1 + t1 = torch.tensor([0.0, -1.0, 1.0]) + t2 = torchutils.ensure_x_batched(t1) + assert t2.ndim == 2 + + # then test if batch dimension is added when observation is ndim==2, e.g. an image + t1 = torch.tensor([[1, 2, 3], [1, 2, 3]]) + t2 = torchutils.ensure_x_batched(t1) + assert t2.ndim == 3 + + +def test_atleast_2d_many(): + t1 = np.array([0.0, -1.0, 1.0]) + t2 = torch.tensor([[1, 2, 3]]) + + t3, t4 = torchutils.atleast_2d_many(t1, t2) + + assert isinstance(t3, torch.Tensor) + assert t3.ndim == 2 + assert t4.ndim == 2 + + +def test_maybe_add_batch_dim_to_size(): + t1 = torch.Size([1]) + t2 = torchutils.maybe_add_batch_dim_to_size(t1) + assert t2 == torch.Size([1, 1]) + + t1 = torch.Size([3]) + t2 = torchutils.maybe_add_batch_dim_to_size(t1) + assert t2 == torch.Size([1, 3]) + + t1 = torch.Size([1, 3]) + t2 = torchutils.maybe_add_batch_dim_to_size(t1) + assert t2 == torch.Size([1, 3]) + + t1 = torch.Size([2, 3]) + t2 = torchutils.maybe_add_batch_dim_to_size(t1) + assert t2 == torch.Size([2, 3]) + + t1 = torch.Size([1, 2, 3]) + t2 = torchutils.maybe_add_batch_dim_to_size(t1) + assert t2 == torch.Size([1, 2, 3]) + + +def test_batched_first_of_batch(): + t = torch.ones(10, 2) + out_t = torchutils.batched_first_of_batch(t) + assert (out_t == torch.ones(1, 2)).all() + + t = torch.ones(1, 2) + out_t = torchutils.batched_first_of_batch(t) + assert (out_t == torch.ones(1, 2)).all() + + +def test_dkl_gauss(): + """ + Test whether for two 1D Gaussians and two 2D Gaussians the Monte-Carlo-based KLd + gives similar results as the torch implementation. + """ + dist1 = ( + distributions.Normal(loc=0.0, scale=1.0), + distributions.MultivariateNormal(zeros(2), eye(2)), + ) + dist2 = ( + distributions.Normal(loc=1.0, scale=0.5), + distributions.MultivariateNormal(ones(2), 0.5 * eye(2)), + ) -# for d1, d2 in zip(dist1, dist2): -# torch_dkl = distributions.kl.kl_divergence(d1, d2) -# monte_carlo_dkl = kl_d_via_monte_carlo(d1, d2, num_samples=5000) + for d1, d2 in zip(dist1, dist2): + torch_dkl = distributions.kl.kl_divergence(d1, d2) + monte_carlo_dkl = kl_d_via_monte_carlo(d1, d2, num_samples=5000) -# max_dkl_diff = 0.4 + max_dkl_diff = 0.4 -# assert torch.abs(torch_dkl - monte_carlo_dkl) < max_dkl_diff, ( -# f"Monte-Carlo-based KLd={monte_carlo_dkl} is too far from the torch" -# f" implementation, {torch_dkl}." -# ) + assert torch.abs(torch_dkl - monte_carlo_dkl) < max_dkl_diff, ( + f"Monte-Carlo-based KLd={monte_carlo_dkl} is too far from the torch" + f" implementation, {torch_dkl}." + ) From a720a93ba08b42acd9efea8dc45dd66b8a057500 Mon Sep 17 00:00:00 2001 From: jonasbeck Date: Sat, 24 Apr 2021 21:57:26 +0200 Subject: [PATCH 10/20] Small rewrites. Got rid of cholesky decomps. 
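The rewrite below drops the explicit Cholesky decompositions and instead conditions each Gaussian component through its precision blocks. The underlying identity is the standard Gaussian conditioning result: for a joint Gaussian over (x, y), the conditional p(x|y) can be written either in covariance form or, equivalently, in precision form, and the precision form is what the conditioning code uses. A minimal self-contained sketch of the equivalence (variable names and sizes below are illustrative only, not taken from the code):

import torch

torch.manual_seed(0)
d_x, d_y = 2, 1
A = torch.randn(d_x + d_y, d_x + d_y)
S = A @ A.T + 0.5 * torch.eye(d_x + d_y)            # joint covariance (SPD)
mu = torch.randn(d_x + d_y)
y = torch.randn(d_y)                                 # values to condition on

S_xx, S_xy, S_yy = S[:d_x, :d_x], S[:d_x, d_x:], S[d_x:, d_x:]
mu_x, mu_y = mu[:d_x], mu[d_x:]

# Covariance form: mean = mu_x + S_xy S_yy^-1 (y - mu_y), cov = S_xx - S_xy S_yy^-1 S_yx
m_cov = mu_x + S_xy @ torch.inverse(S_yy) @ (y - mu_y)
C_cov = S_xx - S_xy @ torch.inverse(S_yy) @ S_xy.T

# Precision form with P = S^-1: mean = mu_x - P_xx^-1 P_xy (y - mu_y), cov = P_xx^-1
P = torch.inverse(S)
P_xx, P_xy = P[:d_x, :d_x], P[:d_x, d_x:]
m_prec = mu_x - torch.inverse(P_xx) @ P_xy @ (y - mu_y)
C_prec = torch.inverse(P_xx)

assert torch.allclose(m_cov, m_prec, atol=1e-4)
assert torch.allclose(C_cov, C_prec, atol=1e-4)

Working in precision form lets the code reuse the precision factors returned by the MDN rather than inverting covariance blocks first.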
--- sbi/inference/posteriors/direct_posterior.py | 69 +++++++++----------- 1 file changed, 31 insertions(+), 38 deletions(-) diff --git a/sbi/inference/posteriors/direct_posterior.py b/sbi/inference/posteriors/direct_posterior.py index 44e958034..6d64f34ea 100644 --- a/sbi/inference/posteriors/direct_posterior.py +++ b/sbi/inference/posteriors/direct_posterior.py @@ -389,23 +389,22 @@ def extract_and_transform_mog( else: encoded_x = nn._embedding_net(context) - logits, m, prec, sumlogdiag, _ = dist.get_mixture_components(encoded_x) + logits, means, _, sumlogdiag, precfs = dist.get_mixture_components(encoded_x) norm_logits = logits - torch.logsumexp(logits, dim=-1, keepdim=True) scale = nn._transform._transforms[0]._scale shift = nn._transform._transforms[0]._shift - means_transformed = (m - shift) / scale + means_transformed = (means - shift) / scale A = scale * torch.eye(means_transformed.shape[2]) - precision_factors_transformed = torch.cholesky(A @ prec @ A) + precfs_transformed = A @ precfs sumlogdiag = torch.sum( - torch.log(torch.diagonal(precision_factors_transformed, dim1=2, dim2=3)), - dim=2, + torch.log(torch.diagonal(precfs_transformed, dim1=2, dim2=3)), dim=2, ) - return norm_logits, means_transformed, precision_factors_transformed, sumlogdiag + return norm_logits, means_transformed, precfs_transformed, sumlogdiag def condition_mog( self, @@ -444,50 +443,44 @@ def condition_mog( raise ValueError( "The chosen condition is not within the prior support." ) + y = condition[:, ~mask] - y = condition[0, ~mask].reshape(1, -1) - - k, D = means.shape[1:] - - prec = precfs @ precfs.transpose(3, 2) - covs = torch.inverse(prec) - mcs = torch.exp(logits) + n_components = means.shape[1] mu_x = means[:, :, mask] mu_y = means[:, :, ~mask] - S_xx = covs[:, :, mask] - S_xx = S_xx[:, :, :, mask] - - S_yy = covs[:, :, ~mask] - S_yy = S_yy[:, :, :, ~mask] + precfs_xx = precfs[:, :, mask] + precfs_xx = precfs_xx[:, :, :, mask] + precs_xx = precfs_xx.transpose(3, 2) @ precfs_xx - S_xy = covs[:, :, mask] - S_xy = S_xy[:, :, :, ~mask] + precfs_yy = precfs[:, :, ~mask] + precfs_yy = precfs_yy[:, :, :, ~mask] + precs_yy = precfs_yy.transpose(3, 2) @ precfs_yy - means = mu_x + ( - (S_xy @ torch.inverse(S_yy) @ (y - mu_y).view(1, k, -1, 1)).transpose(3, 2) - ).view(1, k, -1) - cov = S_xx - S_xy @ torch.inverse(S_yy) @ S_xy.transpose(3, 2) + precs = precfs.transpose(3, 2) @ precfs + precs_xy = precs[:, :, mask] + precs_xy = precs_xy[:, :, :, ~mask] - prec_yy = torch.inverse(S_yy) - precf_yy = torch.cholesky(prec_yy) + means = mu_x - ( + torch.inverse(precs_xx) @ precs_xy @ (y - mu_y).view(1, n_components, -1, 1) + ).view(1, n_components, -1) - sumlogdiag = torch.sum( - torch.log(torch.diagonal(precf_yy, dim1=2, dim2=3)), dim=2 + diags = torch.diagonal(precfs_yy, dim1=2, dim2=3) + sumlogdiag_yy = torch.sum(torch.log(diags), dim=2) + log_prob = mdn.log_prob_mog( + y, torch.tensor([[0.0]]), mu_y, precs_yy, sumlogdiag_yy ) - log_prob = mdn.log_prob_mog(y, torch.tensor([[0.0]]), mu_y, prec_yy, sumlogdiag) - fk = torch.exp(log_prob) # Used for normalizing the mc # Normalize the mixing coef: p(X|Y) = p(Y,X) / p(Y) using the marginal dist. 
- new_mcs = mcs * fk + new_mcs = torch.exp(logits + log_prob) new_mcs = new_mcs / new_mcs.sum() - - precfs = torch.cholesky(torch.inverse(cov)) logits = torch.log(new_mcs) - sumlogdiag = torch.sum(torch.log(torch.diagonal(precfs, dim1=2, dim2=3)), dim=2) - return logits, means, precfs, sumlogdiag + sumlogdiag = torch.sum( + torch.log(torch.diagonal(precfs_xx, dim1=2, dim2=3)), dim=2 + ) + return logits, means, precfs_xx, sumlogdiag def sample_conditional( self, @@ -538,8 +531,8 @@ def sample_conditional( if type(self.net._distribution) is mdn: num_samples = torch.Size(sample_shape).numel() - logits, means, precfs, sumlogdiag = self.extract_and_transform_mog(x) - logits, means, precfs, sumlogdiag = self.condition_mog( + logits, means, precfs, _ = self.extract_and_transform_mog(x) + logits, means, precfs, _ = self.condition_mog( condition, dims_to_sample, logits, means, precfs ) print( @@ -592,7 +585,7 @@ def log_prob_conditional( ) batch_size, dim = theta.shape - prec = precfs @ precfs.transpose(3, 2) + prec = precfs.transpose(3, 2) @ precfs self.net.eval() # leakage correction requires eval mode From a0e880091656e785f1554ffb7e07c4fa6fec83ec Mon Sep 17 00:00:00 2001 From: jonasbeck Date: Sat, 24 Apr 2021 22:07:58 +0200 Subject: [PATCH 11/20] renamed vars. --- sbi/inference/posteriors/direct_posterior.py | 16 ++++++++-------- 1 file changed, 8 insertions(+), 8 deletions(-) diff --git a/sbi/inference/posteriors/direct_posterior.py b/sbi/inference/posteriors/direct_posterior.py index 6d64f34ea..52c321e60 100644 --- a/sbi/inference/posteriors/direct_posterior.py +++ b/sbi/inference/posteriors/direct_posterior.py @@ -426,7 +426,9 @@ def condition_mog( support = self._prior.support - mask = torch.zeros(means.shape[-1], dtype=bool) + n_mixtures, n_dims = means.shape[1:] + + mask = torch.zeros(n_dims, dtype=bool) mask[dims] = True # check whether the condition is within the prior bounds @@ -437,16 +439,14 @@ def condition_mog( cond_ubound = support.upper_bound[~mask] cond_lbound = support.lower_bound[~mask] within_support = torch.logical_and( - cond_lbound <= condition[:, ~mask], condition[:, ~mask] <= cond_ubound + cond_lbound <= condition[:, ~mask], cond_ubound >= condition[:, ~mask] ) if ~torch.all(within_support): raise ValueError( "The chosen condition is not within the prior support." ) - y = condition[:, ~mask] - - n_components = means.shape[1] + y = condition[:, ~mask] mu_x = means[:, :, mask] mu_y = means[:, :, ~mask] @@ -463,13 +463,13 @@ def condition_mog( precs_xy = precs_xy[:, :, :, ~mask] means = mu_x - ( - torch.inverse(precs_xx) @ precs_xy @ (y - mu_y).view(1, n_components, -1, 1) - ).view(1, n_components, -1) + torch.inverse(precs_xx) @ precs_xy @ (y - mu_y).view(1, n_mixtures, -1, 1) + ).view(1, n_mixtures, -1) diags = torch.diagonal(precfs_yy, dim1=2, dim2=3) sumlogdiag_yy = torch.sum(torch.log(diags), dim=2) log_prob = mdn.log_prob_mog( - y, torch.tensor([[0.0]]), mu_y, precs_yy, sumlogdiag_yy + y, torch.zeros((1, 1)), mu_y, precs_yy, sumlogdiag_yy ) # Normalize the mixing coef: p(X|Y) = p(Y,X) / p(Y) using the marginal dist. From 660a9cef596bc90f001827253fc3edcbbfccdb85 Mon Sep 17 00:00:00 2001 From: jonasbeck Date: Thu, 13 May 2021 22:04:40 +0200 Subject: [PATCH 12/20] Adressing critique: Fixed Docstrings. Restructure. Comments. Warnings. 
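Besides new component means and precisions, conditioning a mixture also reweights its components: each mixture coefficient is multiplied by how well that component's marginal explains the fixed values y, i.e. pi_k' is proportional to pi_k * N(y; mu_y_k, S_yy_k). That is what the `new_mcs = torch.exp(logits + log_prob)` step followed by normalisation computes, with the marginal log-density obtained from `mdn.log_prob_mog`. A small numerical illustration of the reweighting (the component values below are made up for the example):

import torch
from torch.distributions import MultivariateNormal

logits = torch.log(torch.tensor([0.3, 0.7]))        # log mixture weights pi_k
mu_y = torch.tensor([[-1.0], [2.0]])                 # per-component marginal means of y
S_yy = torch.tensor([[[0.5]], [[1.0]]])              # per-component marginal covariances of y
y = torch.tensor([1.5])                              # value being conditioned on

log_marginals = torch.stack(
    [MultivariateNormal(mu_y[k], S_yy[k]).log_prob(y) for k in range(2)]
)
new_weights = torch.softmax(logits + log_marginals, dim=0)
print(new_weights, new_weights.sum())                # shifted towards component 2, sums to one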
--- sbi/inference/posteriors/direct_posterior.py | 145 +++---------------- sbi/utils/__init__.py | 2 + sbi/utils/conditional_density.py | 140 ++++++++++++++++++ 3 files changed, 164 insertions(+), 123 deletions(-) diff --git a/sbi/inference/posteriors/direct_posterior.py b/sbi/inference/posteriors/direct_posterior.py index 52c321e60..d9a2227ef 100644 --- a/sbi/inference/posteriors/direct_posterior.py +++ b/sbi/inference/posteriors/direct_posterior.py @@ -6,6 +6,7 @@ import numpy as np import torch from torch import Tensor, log, nn +import warnings from sbi import utils as utils from pyknos.mdn.mdn import MultivariateGaussianMDN as mdn @@ -17,6 +18,7 @@ ensure_theta_batched, ensure_x_batched, ) +from sbi.utils.conditional_density import extract_and_transform_mog, condition_mog class DirectPosterior(NeuralPosterior): @@ -369,119 +371,6 @@ def sample( return samples.reshape((*sample_shape, -1)) - def extract_and_transform_mog( - self, context: Tensor = None - ) -> Tuple[Tensor, Tensor, Tensor]: - """Extracts the Mixture of Gaussians (MoG) parameters - from the MDN at either the default x or input x. - - Args: - x: x at which to evaluate the MDN in order - to extract the MoG parameters. - """ - - # extract and rescale means, mixture componenets and covariances - nn = self.net - dist = nn._distribution - - if context == None: - encoded_x = nn._embedding_net(self.default_x) - else: - encoded_x = nn._embedding_net(context) - - logits, means, _, sumlogdiag, precfs = dist.get_mixture_components(encoded_x) - norm_logits = logits - torch.logsumexp(logits, dim=-1, keepdim=True) - - scale = nn._transform._transforms[0]._scale - shift = nn._transform._transforms[0]._shift - - means_transformed = (means - shift) / scale - - A = scale * torch.eye(means_transformed.shape[2]) - precfs_transformed = A @ precfs - - sumlogdiag = torch.sum( - torch.log(torch.diagonal(precfs_transformed, dim1=2, dim2=3)), dim=2, - ) - - return norm_logits, means_transformed, precfs_transformed, sumlogdiag - - def condition_mog( - self, - condition: Tensor, - dims: List[int], - logits: Tensor, - means: Tensor, - precfs: Tensor, - ): - """Finds the conditional distribution p(X|Y) for a GMM. - - Args: - condition: An array of inputs. Inputs set to NaN are not set, and become inputs to - the resulting distribution. Order is preserved. - - Raises: - ValueError: The chosen condition is not within the prior support. - """ - - support = self._prior.support - - n_mixtures, n_dims = means.shape[1:] - - mask = torch.zeros(n_dims, dtype=bool) - mask[dims] = True - - # check whether the condition is within the prior bounds - if ( - type(self._prior) is torch.distributions.uniform.Uniform - or type(self._prior) is utils.torchutils.BoxUniform - ): - cond_ubound = support.upper_bound[~mask] - cond_lbound = support.lower_bound[~mask] - within_support = torch.logical_and( - cond_lbound <= condition[:, ~mask], cond_ubound >= condition[:, ~mask] - ) - if ~torch.all(within_support): - raise ValueError( - "The chosen condition is not within the prior support." 
- ) - - y = condition[:, ~mask] - mu_x = means[:, :, mask] - mu_y = means[:, :, ~mask] - - precfs_xx = precfs[:, :, mask] - precfs_xx = precfs_xx[:, :, :, mask] - precs_xx = precfs_xx.transpose(3, 2) @ precfs_xx - - precfs_yy = precfs[:, :, ~mask] - precfs_yy = precfs_yy[:, :, :, ~mask] - precs_yy = precfs_yy.transpose(3, 2) @ precfs_yy - - precs = precfs.transpose(3, 2) @ precfs - precs_xy = precs[:, :, mask] - precs_xy = precs_xy[:, :, :, ~mask] - - means = mu_x - ( - torch.inverse(precs_xx) @ precs_xy @ (y - mu_y).view(1, n_mixtures, -1, 1) - ).view(1, n_mixtures, -1) - - diags = torch.diagonal(precfs_yy, dim1=2, dim2=3) - sumlogdiag_yy = torch.sum(torch.log(diags), dim=2) - log_prob = mdn.log_prob_mog( - y, torch.zeros((1, 1)), mu_y, precs_yy, sumlogdiag_yy - ) - - # Normalize the mixing coef: p(X|Y) = p(Y,X) / p(Y) using the marginal dist. - new_mcs = torch.exp(logits + log_prob) - new_mcs = new_mcs / new_mcs.sum() - logits = torch.log(new_mcs) - - sumlogdiag = torch.sum( - torch.log(torch.diagonal(precfs_xx, dim1=2, dim2=3)), dim=2 - ) - return logits, means, precfs_xx, sumlogdiag - def sample_conditional( self, sample_shape: Shape, @@ -531,12 +420,14 @@ def sample_conditional( if type(self.net._distribution) is mdn: num_samples = torch.Size(sample_shape).numel() - logits, means, precfs, _ = self.extract_and_transform_mog(x) - logits, means, precfs, _ = self.condition_mog( - condition, dims_to_sample, logits, means, precfs + logits, means, precfs, _ = extract_and_transform_mog(self, x) + logits, means, precfs, _ = condition_mog( + self._prior, condition, dims_to_sample, logits, means, precfs ) - print( - "Warning: Sampling MoG analytically. Some of the samples might not be within the prior support!" + + # Currently difficult to integrate `sample_posterior_within_prior` + warnings.warn( + "Sampling MoG analytically. Some of the samples might not be within the prior support!" ) samples = mdn.sample_mog(num_samples, logits, means, precfs) return samples.detach().reshape((*sample_shape, -1)) @@ -560,8 +451,10 @@ def log_prob_conditional( dims_to_evaluate: List[int], x: Optional[Tensor] = None, ) -> Tensor: - """Evaluates the Mixture of Gaussian (MoG) - probability density function at a value x. + """Evaluates the conditional posterior probability of a MDN at a context x for a value theta given a condition. + + This function only works for MDN based posteriors, becuase evaluation is done analytically. For all other density estimators a `NotImplementedError` will be + raised! Args: theta: Parameters $\theta$. @@ -575,9 +468,11 @@ def log_prob_conditional( fall back onto `x` passed to `set_default_x()`. Returns: - log_prob: `(len(θ),)`-shaped log posterior probability $\log p(\theta|x - for θ in the support of the prior, -∞ (corresponding to 0 probability) outside. + log_prob: `(len(θ),)`-shaped normalized (!) log posterior probability + $\log p(\theta|x) for θ in the support of the prior, -∞ (corresponding + to 0 probability) outside. """ + if type(self.net._distribution) == mdn: logits, means, precfs, sumlogdiag = self.extract_and_transform_mog(x) logits, means, precfs, sumlogdiag = self.condition_mog( @@ -592,7 +487,11 @@ def log_prob_conditional( if dim != len(dims_to_evaluate): X = X[:, dims_to_evaluate] - print("Warning: Probabilities are not adjusted for leakage.") + # Implementing leakage correction is difficult for conditioned MDNs, + # because samples from self i.e. the full posterior are used rather + # then from the new, conditioned posterior. 
+ warnings.warn("Probabilities are not adjusted for leakage.") + log_prob = mdn.log_prob_mog( theta, logits.repeat(batch_size, 1), diff --git a/sbi/utils/__init__.py b/sbi/utils/__init__.py index acb041db1..69f0fca03 100644 --- a/sbi/utils/__init__.py +++ b/sbi/utils/__init__.py @@ -2,6 +2,8 @@ from sbi.utils.conditional_density import ( conditional_corrcoeff, eval_conditional_density, + extract_and_transform_mog, + condition_mog, ) from sbi.utils.get_nn_models import classifier_nn, likelihood_nn, posterior_nn from sbi.utils.io import get_data_root, get_log_root, get_project_root diff --git a/sbi/utils/conditional_density.py b/sbi/utils/conditional_density.py index 5cba7d90d..0ad80b54c 100644 --- a/sbi/utils/conditional_density.py +++ b/sbi/utils/conditional_density.py @@ -2,6 +2,7 @@ # under the Affero General Public License v3, see . +from sbi.inference.posteriors.direct_posterior import DirectPosterior from typing import Any, Callable, List, Optional, Tuple, Union from warnings import warn @@ -9,6 +10,7 @@ from torch import Tensor from sbi.utils.torchutils import ensure_theta_batched +from sbi.inference.posteriors.direct_posterior import DirectPosterior def eval_conditional_density( @@ -321,3 +323,141 @@ def _normalize_probs(probs: Tensor, limits: Tensor) -> Tensor: """ limits_diff = torch.prod(limits[:, 1] - limits[:, 0]) return probs * probs.numel() / limits_diff / torch.sum(probs) + + +def extract_and_transform_mog( + posterior: DirectPosterior, context: Tensor = None +) -> Tuple[Tensor, Tensor, Tensor, Tensor]: + """Extracts the Mixture of Gaussians (MoG) parameters + from an MDN based DirectPosterior at either the default x or input x. + + Args: + posterior: DirectPosterior instance. + context: Conditioning context for posterior $p(\theta|x)$. If not provided, + fall back onto `x` passed to `set_default_x()`. + + Returns: + norm_logits: Normalised log weights of the underyling MoG. + (batch_size, n_mixtures) + means_transformed: Recentred and rescaled means of the underlying MoG + (batch_size, n_mixtures, n_dims) + precfs_transformed: Rescaled precision factors of the underlying MoG. + (batch_size, n_mixtures, n_dims, n_dims) + sumlogdiag: Sum of the log of the diagonal of the precision factors + of the new conditional distribution. (batch_size, n_mixtures) + """ + + # extract and rescale means, mixture componenets and covariances + nn = posterior.net + dist = posterior_net._distribution + + if context == None: + encoded_x = nn._embedding_net(posterior.default_x) + else: + encoded_x = nn._embedding_net(context) + + logits, means, _, sumlogdiag, precfs = dist.get_mixture_components(encoded_x) + norm_logits = logits - torch.logsumexp(logits, dim=-1, keepdim=True) + + scale = nn._transform._transforms[0]._scale + shift = nn._transform._transforms[0]._shift + + means_transformed = (means - shift) / scale + + A = scale * torch.eye(means_transformed.shape[2]) + precfs_transformed = A @ precfs + + sumlogdiag = torch.sum( + torch.log(torch.diagonal(precfs_transformed, dim1=2, dim2=3)), dim=2, + ) + + return norm_logits, means_transformed, precfs_transformed, sumlogdiag + + +def condition_mog( + prior: "Prior", + condition: Tensor, + dims: List[int], + logits: Tensor, + means: Tensor, + precfs: Tensor, +) -> Tuple[Tensor, Tensor, Tensor, Tensor]: + """Finds the conditional distribution p(X|Y) for a GMM. + + Args: + prior: Prior Distribution. Used to check if condition within support. + condition: Parameter set that all dimensions not specified in + `dims_to_sample` will be fixed to. 
Should contain dim_theta elements, + i.e. it could e.g. be a sample from the posterior distribution. + The entries at all `dims_to_sample` will be ignored. + dims_to_sample: Which dimensions to sample from. The dimensions not + specified in `dims_to_sample` will be fixed to values given in + `condition`. + logits: Log weights of the MoG. (batch_size, n_mixtures) + means: Means of the MoG. (batch_size, n_mixtures, n_dims) + precfs: Precision factors of the MoG. + (batch_size, n_mixtures, n_dims, n_dims) + + Raises: + ValueError: The chosen condition is not within the prior support. + + Returns: + logits: Log weights of the conditioned MoG. (batch_size, n_mixtures) + means: Means of the conditioned MoG. (batch_size, n_mixtures, n_dims) + precfs_xx: Precision factors of the MoG. + (batch_size, n_mixtures, n_dims, n_dims) + sumlogdiag: Sum of the log of the diagonal of the precision factors + of the new conditional distribution. (batch_size, n_mixtures) + """ + + support = prior.support + + n_mixtures, n_dims = means.shape[1:] + + mask = torch.zeros(n_dims, dtype=bool) + mask[dims] = True + + # check whether the condition is within the prior bounds + if ( + type(prior) is torch.distributions.uniform.Uniform + or type(prior) is utils.torchutils.BoxUniform + ): + cond_ubound = support.upper_bound[~mask] + cond_lbound = support.lower_bound[~mask] + within_support = torch.logical_and( + cond_lbound <= condition[:, ~mask], cond_ubound >= condition[:, ~mask] + ) + if ~torch.all(within_support): + raise ValueError("The chosen condition is not within the prior support.") + + y = condition[:, ~mask] + mu_x = means[:, :, mask] + mu_y = means[:, :, ~mask] + + precfs_xx = precfs[:, :, mask] + precfs_xx = precfs_xx[:, :, :, mask] + precs_xx = precfs_xx.transpose(3, 2) @ precfs_xx + + precfs_yy = precfs[:, :, ~mask] + precfs_yy = precfs_yy[:, :, :, ~mask] + precs_yy = precfs_yy.transpose(3, 2) @ precfs_yy + + precs = precfs.transpose(3, 2) @ precfs + precs_xy = precs[:, :, mask] + precs_xy = precs_xy[:, :, :, ~mask] + + means = mu_x - ( + torch.inverse(precs_xx) @ precs_xy @ (y - mu_y).view(1, n_mixtures, -1, 1) + ).view(1, n_mixtures, -1) + + diags = torch.diagonal(precfs_yy, dim1=2, dim2=3) + sumlogdiag_yy = torch.sum(torch.log(diags), dim=2) + log_prob = mdn.log_prob_mog(y, torch.zeros((1, 1)), mu_y, precs_yy, sumlogdiag_yy) + + # Normalize the mixing coef: p(X|Y) = p(Y,X) / p(Y) using the marginal dist. + new_mcs = torch.exp(logits + log_prob) + new_mcs = new_mcs / new_mcs.sum() + logits = torch.log(new_mcs) + + sumlogdiag = torch.sum(torch.log(torch.diagonal(precfs_xx, dim1=2, dim2=3)), dim=2) + return logits, means, precfs_xx, sumlogdiag From c7e0904a9fd95db9bb18defc537e9d48ef610bb9 Mon Sep 17 00:00:00 2001 From: jonasbeck Date: Thu, 13 May 2021 22:27:09 +0200 Subject: [PATCH 13/20] looking for bug. 
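A note on the `sumlogdiag` terms that keep appearing in these patches: for a precision matrix P factored as P = U^T U with triangular U, the sum of the log-diagonal of U equals one half of log|P|, which is exactly the normalisation term `mdn.log_prob_mog` expects. That is why the code only ever needs `torch.diagonal(precfs, dim1=2, dim2=3)` and never a full determinant. A quick standalone check of the identity (illustrative only):

import torch

torch.manual_seed(0)
d = 3
A = torch.randn(d, d)
prec = A @ A.T + d * torch.eye(d)       # an SPD precision matrix
U = torch.cholesky(prec, upper=True)     # prec = U^T @ U, U upper-triangular

sumlogdiag = torch.log(torch.diagonal(U)).sum()
assert torch.allclose(sumlogdiag, 0.5 * torch.logdet(prec), atol=1e-4)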
--- sbi/inference/posteriors/direct_posterior.py | 10 +++++----- sbi/utils/conditional_density.py | 9 +++------ 2 files changed, 8 insertions(+), 11 deletions(-) diff --git a/sbi/inference/posteriors/direct_posterior.py b/sbi/inference/posteriors/direct_posterior.py index d9a2227ef..4ee52e69c 100644 --- a/sbi/inference/posteriors/direct_posterior.py +++ b/sbi/inference/posteriors/direct_posterior.py @@ -6,7 +6,7 @@ import numpy as np import torch from torch import Tensor, log, nn -import warnings +from warnings import warn from sbi import utils as utils from pyknos.mdn.mdn import MultivariateGaussianMDN as mdn @@ -426,7 +426,7 @@ def sample_conditional( ) # Currently difficult to integrate `sample_posterior_within_prior` - warnings.warn( + warn( "Sampling MoG analytically. Some of the samples might not be within the prior support!" ) samples = mdn.sample_mog(num_samples, logits, means, precfs) @@ -490,7 +490,7 @@ def log_prob_conditional( # Implementing leakage correction is difficult for conditioned MDNs, # because samples from self i.e. the full posterior are used rather # then from the new, conditioned posterior. - warnings.warn("Probabilities are not adjusted for leakage.") + warn("Probabilities are not adjusted for leakage.") log_prob = mdn.log_prob_mog( theta, @@ -591,7 +591,7 @@ class PotentialFunctionProvider: """ def __call__( - self, prior, posterior_nn: nn.Module, x: Tensor, mcmc_method: str, + self, prior, posterior_nn: nn.Module, x: Tensor, mcmc_method: str ) -> Callable: """Return potential function. @@ -627,7 +627,7 @@ def np_potential(self, theta: np.ndarray) -> ScalarFloat: with torch.set_grad_enabled(False): target_log_prob = self.posterior_nn.log_prob( - inputs=theta.to(self.x.device), context=x_repeated, + inputs=theta.to(self.x.device), context=x_repeated ) is_within_prior = torch.isfinite(self.prior.log_prob(theta)) target_log_prob[~is_within_prior] = -float("Inf") diff --git a/sbi/utils/conditional_density.py b/sbi/utils/conditional_density.py index 0ad80b54c..91b7d4518 100644 --- a/sbi/utils/conditional_density.py +++ b/sbi/utils/conditional_density.py @@ -10,7 +10,7 @@ from torch import Tensor from sbi.utils.torchutils import ensure_theta_batched -from sbi.inference.posteriors.direct_posterior import DirectPosterior +from sbi.utils.torchutils import BoxUniform def eval_conditional_density( @@ -368,7 +368,7 @@ def extract_and_transform_mog( precfs_transformed = A @ precfs sumlogdiag = torch.sum( - torch.log(torch.diagonal(precfs_transformed, dim1=2, dim2=3)), dim=2, + torch.log(torch.diagonal(precfs_transformed, dim1=2, dim2=3)), dim=2 ) return norm_logits, means_transformed, precfs_transformed, sumlogdiag @@ -418,10 +418,7 @@ def condition_mog( mask[dims] = True # check whether the condition is within the prior bounds - if ( - type(prior) is torch.distributions.uniform.Uniform - or type(prior) is utils.torchutils.BoxUniform - ): + if type(prior) is torch.distributions.uniform.Uniform or type(prior) is BoxUniform: cond_ubound = support.upper_bound[~mask] cond_lbound = support.lower_bound[~mask] within_support = torch.logical_and( From ab20b6b25bfb8670c95d1a13c21300cf91705069 Mon Sep 17 00:00:00 2001 From: jonasbeck Date: Thu, 13 May 2021 22:39:15 +0200 Subject: [PATCH 14/20] small fix. 
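For orientation, the user-facing entry points these patches wire up are `DirectPosterior.sample_conditional` and `DirectPosterior.log_prob_conditional`; both require an MDN density estimator, since the conditioning is done analytically. A minimal usage sketch (the simulator, prior and all numbers are placeholders, not taken from the patches):

import torch
from sbi.inference import SNPE, prepare_for_sbi, simulate_for_sbi
from sbi.utils.torchutils import BoxUniform

prior = BoxUniform(low=-2 * torch.ones(3), high=2 * torch.ones(3))

def simulator(theta):
    return theta + 0.1 * torch.randn_like(theta)

simulator, prior = prepare_for_sbi(simulator, prior)
inference = SNPE(prior, density_estimator="mdn")     # analytic conditioning needs an MDN
theta, x = simulate_for_sbi(simulator, prior, num_simulations=2000)
_ = inference.append_simulations(theta, x).train()
posterior = inference.build_posterior().set_default_x(torch.zeros(1, 3))

# Keep dimensions 1 and 2 fixed at the values given in `condition`, sample dimension 0.
condition = torch.tensor([[0.0, 0.5, -0.5]])
samples = posterior.sample_conditional((1000,), condition, dims_to_sample=[0])
log_probs = posterior.log_prob_conditional(samples, condition, dims_to_evaluate=[0])

As the warnings added above point out, the analytic MoG samples are not corrected for leakage outside the prior support.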
--- sbi/inference/posteriors/direct_posterior.py | 6 +++--- 1 file changed, 3 insertions(+), 3 deletions(-) diff --git a/sbi/inference/posteriors/direct_posterior.py b/sbi/inference/posteriors/direct_posterior.py index 4ee52e69c..dc3170a89 100644 --- a/sbi/inference/posteriors/direct_posterior.py +++ b/sbi/inference/posteriors/direct_posterior.py @@ -474,9 +474,9 @@ def log_prob_conditional( """ if type(self.net._distribution) == mdn: - logits, means, precfs, sumlogdiag = self.extract_and_transform_mog(x) - logits, means, precfs, sumlogdiag = self.condition_mog( - condition, dims_to_evaluate, logits, means, precfs + logits, means, precfs, sumlogdiag = extract_and_transform_mog(self, x) + logits, means, precfs, sumlogdiag = condition_mog( + self._prior, condition, dims_to_evaluate, logits, means, precfs ) batch_size, dim = theta.shape From 794a74cdc9e3bbaf2883bc73e641dd21dfde9b58 Mon Sep 17 00:00:00 2001 From: jonasbeck Date: Sun, 13 Jun 2021 15:00:02 +0200 Subject: [PATCH 15/20] BUGFIX import of DirectPosterior Type caused Problems. --- sbi/utils/conditional_density.py | 7 +++---- 1 file changed, 3 insertions(+), 4 deletions(-) diff --git a/sbi/utils/conditional_density.py b/sbi/utils/conditional_density.py index 91b7d4518..b61f9822d 100644 --- a/sbi/utils/conditional_density.py +++ b/sbi/utils/conditional_density.py @@ -1,14 +1,13 @@ # This file is part of sbi, a toolkit for simulation-based inference. sbi is licensed # under the Affero General Public License v3, see . - -from sbi.inference.posteriors.direct_posterior import DirectPosterior from typing import Any, Callable, List, Optional, Tuple, Union from warnings import warn import torch from torch import Tensor +from pyknos.mdn.mdn import MultivariateGaussianMDN as mdn from sbi.utils.torchutils import ensure_theta_batched from sbi.utils.torchutils import BoxUniform @@ -326,7 +325,7 @@ def _normalize_probs(probs: Tensor, limits: Tensor) -> Tensor: def extract_and_transform_mog( - posterior: DirectPosterior, context: Tensor = None + posterior: "DirectPosterior", context: Tensor = None ) -> Tuple[Tensor, Tensor, Tensor, Tensor]: """Extracts the Mixture of Gaussians (MoG) parameters from an MDN based DirectPosterior at either the default x or input x. @@ -349,7 +348,7 @@ def extract_and_transform_mog( # extract and rescale means, mixture componenets and covariances nn = posterior.net - dist = posterior_net._distribution + dist = nn._distribution if context == None: encoded_x = nn._embedding_net(posterior.default_x) From 293418f0126c1b6e3413baf44cfbe56d67c1430c Mon Sep 17 00:00:00 2001 From: jonasbeck Date: Mon, 14 Jun 2021 01:13:39 +0200 Subject: [PATCH 16/20] Added unittest. Works, but not sure how to integrate it properly. 
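The test added below checks the analytic conditioning against a linear-Gaussian toy problem, where the true joint posterior is itself Gaussian and can be conditioned in closed form. The ground truth it obtains from `true_posterior_linear_gaussian_mvn_prior` amounts to the standard conjugate-Gaussian update; a compact sketch of that update with the same illustrative numbers the test uses:

import torch

num_dim = 3
likelihood_shift = -1.0 * torch.ones(num_dim)    # likelihood mean is theta + likelihood_shift
likelihood_cov = 0.3 * torch.eye(num_dim)
prior_mean = torch.zeros(num_dim)
prior_cov = torch.eye(num_dim)
x_o = torch.zeros(num_dim)

# Precisions add, and the posterior mean is the precision-weighted combination.
post_prec = torch.inverse(prior_cov) + torch.inverse(likelihood_cov)
post_cov = torch.inverse(post_prec)
post_mean = post_cov @ (
    torch.inverse(likelihood_cov) @ (x_o - likelihood_shift)
    + torch.inverse(prior_cov) @ prior_mean
)

Conditioning this joint Gaussian posterior on the last `cond_dim` entries (done in the test via `conditional_of_mvn`) then reduces to the same Gaussian conditioning identity sketched earlier in the series.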
--- tests/testing_analytic_mdn_conditioning.py | 83 ++++++++++++++++++++++ 1 file changed, 83 insertions(+) create mode 100644 tests/testing_analytic_mdn_conditioning.py diff --git a/tests/testing_analytic_mdn_conditioning.py b/tests/testing_analytic_mdn_conditioning.py new file mode 100644 index 000000000..591282fca --- /dev/null +++ b/tests/testing_analytic_mdn_conditioning.py @@ -0,0 +1,83 @@ +from torch import eye, ones, zeros +from torch.distributions import MultivariateNormal + + +from sbi import utils +from tests.sbiutils_test import conditional_of_mvn +from sbi.inference import SNPE, prepare_for_sbi, simulate_for_sbi + +from sbi.simulators.linear_gaussian import ( + linear_gaussian, + true_posterior_linear_gaussian_mvn_prior, +) + +from tests.test_utils import check_c2st + + +def test_mdn_conditional_density(num_dim: int = 3, cond_dim: int = 1): + """Test whether the conditional density infered from MDN parameters of a `DirectPosterior` matches analytical results for MVN. This uses a n-D joint and conditions on the last m values to generate a conditional. + + Gaussian prior used for easier ground truthing of conditional posterior. + + Args: + num_dim: Dimensionality of the MVM. + cond_dim: Dimensionality of the condition. + """ + + assert ( + num_dim > cond_dim + ), "The number of dimensions needs to be greater than that of the condition!" + + x_o = zeros(1, num_dim) + num_samples = 1000 + num_simulations = 2500 + condition = 0.1 * ones(1, num_dim) + + dims = list(range(num_dim)) + dims2sample = dims[-cond_dim:] + dims2condition = dims[:-cond_dim] + + # likelihood_mean will be likelihood_shift+theta + likelihood_shift = -1.0 * ones(num_dim) + likelihood_cov = 0.3 * eye(num_dim) + + prior_mean = zeros(num_dim) + prior_cov = eye(num_dim) + prior = MultivariateNormal(loc=prior_mean, covariance_matrix=prior_cov) + + joint_posterior = true_posterior_linear_gaussian_mvn_prior( + x_o[0], likelihood_shift, likelihood_cov, prior_mean, prior_cov + ) + joint_cov = joint_posterior.covariance_matrix + joint_mean = joint_posterior.loc + + joint_samples = joint_posterior.sample((num_samples,)) + + conditional_mean, conditional_cov = conditional_of_mvn( + joint_mean, joint_cov, condition[0, dims2condition] + ) + conditional_dist_gt = MultivariateNormal(conditional_mean, conditional_cov) + + conditional_samples_gt = conditional_dist_gt.sample((num_samples,)) + + def simulator(theta): + return linear_gaussian(theta, likelihood_shift, likelihood_cov) + + simulator, prior = prepare_for_sbi(simulator, prior) + inference = SNPE(prior, show_progress_bars=False, density_estimator="mdn") + + theta, x = simulate_for_sbi( + simulator, prior, num_simulations, simulation_batch_size=1000 + ) + _ = inference.append_simulations(theta, x).train(training_batch_size=100) + posterior = inference.build_posterior().set_default_x(x_o) + + conditional_samples_sbi = posterior.sample_conditional( + (num_samples,), condition, dims2sample, x_o + ) + check_c2st( + conditional_samples_sbi, + conditional_samples_gt, + alg="analytic_mdn_conditioning_of_direct_posterior", + ) + From 449f1b7a027efb8d6e58b9c38666031a8b126741 Mon Sep 17 00:00:00 2001 From: jonasbeck Date: Mon, 14 Jun 2021 01:23:20 +0200 Subject: [PATCH 17/20] fixed docstring formatting. removed obsolete line. 
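For reference, the `check_c2st` call in the new test judges agreement between the two sample sets with a classifier two-sample test: a classifier is trained to tell the sets apart, and its held-out accuracy should stay close to 0.5 if the distributions match. A generic, simplified sketch of the idea (sklearn-based; not the exact implementation in `tests.test_utils`):

import numpy as np
from sklearn.model_selection import cross_val_score
from sklearn.neural_network import MLPClassifier

def c2st_accuracy(samples_p: np.ndarray, samples_q: np.ndarray, folds: int = 5) -> float:
    """Cross-validated accuracy of a classifier separating the two sample sets."""
    data = np.concatenate([samples_p, samples_q])
    labels = np.concatenate([np.zeros(len(samples_p)), np.ones(len(samples_q))])
    clf = MLPClassifier(hidden_layer_sizes=(50,), max_iter=1000)
    return cross_val_score(clf, data, labels, cv=folds).mean()

Accuracy near 0.5 means the conditional samples are statistically indistinguishable from the ground truth; accuracy near 1.0 means the two distributions clearly differ.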
--- tests/testing_analytic_mdn_conditioning.py | 6 +++--- 1 file changed, 3 insertions(+), 3 deletions(-) diff --git a/tests/testing_analytic_mdn_conditioning.py b/tests/testing_analytic_mdn_conditioning.py index 591282fca..e04144825 100644 --- a/tests/testing_analytic_mdn_conditioning.py +++ b/tests/testing_analytic_mdn_conditioning.py @@ -15,7 +15,9 @@ def test_mdn_conditional_density(num_dim: int = 3, cond_dim: int = 1): - """Test whether the conditional density infered from MDN parameters of a `DirectPosterior` matches analytical results for MVN. This uses a n-D joint and conditions on the last m values to generate a conditional. + """Test whether the conditional density infered from MDN parameters of a + `DirectPosterior` matches analytical results for MVN. This uses a n-D joint and + conditions on the last m values to generate a conditional. Gaussian prior used for easier ground truthing of conditional posterior. @@ -51,8 +53,6 @@ def test_mdn_conditional_density(num_dim: int = 3, cond_dim: int = 1): joint_cov = joint_posterior.covariance_matrix joint_mean = joint_posterior.loc - joint_samples = joint_posterior.sample((num_samples,)) - conditional_mean, conditional_cov = conditional_of_mvn( joint_mean, joint_cov, condition[0, dims2condition] ) From 8d63816e2765e613b77ae652e017d5106c73e244 Mon Sep 17 00:00:00 2001 From: jonasbeck Date: Tue, 6 Jul 2021 18:12:38 +0200 Subject: [PATCH 18/20] ran black, fixed formating, moved test to linearGaussian_snpe_test. --- sbi/inference/posteriors/direct_posterior.py | 9 +- sbi/utils/conditional_density.py | 2 +- tests/linearGaussian_snpe_test.py | 120 +++++++++++++------ tests/testing_analytic_mdn_conditioning.py | 83 ------------- 4 files changed, 89 insertions(+), 125 deletions(-) delete mode 100644 tests/testing_analytic_mdn_conditioning.py diff --git a/sbi/inference/posteriors/direct_posterior.py b/sbi/inference/posteriors/direct_posterior.py index dc3170a89..e9df473ff 100644 --- a/sbi/inference/posteriors/direct_posterior.py +++ b/sbi/inference/posteriors/direct_posterior.py @@ -427,7 +427,8 @@ def sample_conditional( # Currently difficult to integrate `sample_posterior_within_prior` warn( - "Sampling MoG analytically. Some of the samples might not be within the prior support!" + "Sampling MoG analytically. " + "Some of the samples might not be within the prior support!" ) samples = mdn.sample_mog(num_samples, logits, means, precfs) return samples.detach().reshape((*sample_shape, -1)) @@ -451,9 +452,11 @@ def log_prob_conditional( dims_to_evaluate: List[int], x: Optional[Tensor] = None, ) -> Tensor: - """Evaluates the conditional posterior probability of a MDN at a context x for a value theta given a condition. + """Evaluates the conditional posterior probability of a MDN at a context x for + a value theta given a condition. - This function only works for MDN based posteriors, becuase evaluation is done analytically. For all other density estimators a `NotImplementedError` will be + This function only works for MDN based posteriors, becuase evaluation is done + analytically. For all other density estimators a `NotImplementedError` will be raised! Args: diff --git a/sbi/utils/conditional_density.py b/sbi/utils/conditional_density.py index b61f9822d..5911f56cf 100644 --- a/sbi/utils/conditional_density.py +++ b/sbi/utils/conditional_density.py @@ -381,7 +381,7 @@ def condition_mog( means: Tensor, precfs: Tensor, ) -> Tuple[Tensor, Tensor, Tensor, Tensor]: - """Finds the conditional distribution p(X|Y) for a GMM. 
+ """Finds the conditional distribution p(X|Y) for a MoG. Args: prior: Prior Distribution. Used to check if condition within support. diff --git a/tests/linearGaussian_snpe_test.py b/tests/linearGaussian_snpe_test.py index bf1648e81..0b912efed 100644 --- a/tests/linearGaussian_snpe_test.py +++ b/tests/linearGaussian_snpe_test.py @@ -12,7 +12,7 @@ from sbi import analysis as analysis from sbi import utils as utils -from sbi.inference import SNPE_B, SNPE_C, prepare_for_sbi, simulate_for_sbi +from sbi.inference import SNPE, SNPE_B, SNPE_C, prepare_for_sbi, simulate_for_sbi from sbi.simulators.linear_gaussian import ( linear_gaussian, samples_true_posterior_linear_gaussian_mvn_prior_different_dims, @@ -26,20 +26,13 @@ get_prob_outside_uniform_prior, ) +from tests.sbiutils_test import conditional_of_mvn + @pytest.mark.parametrize( - "num_dim, prior_str", - ( - (2, "gaussian"), - (2, "uniform"), - (1, "gaussian"), - ), + "num_dim, prior_str", ((2, "gaussian"), (2, "uniform"), (1, "gaussian")) ) -def test_c2st_snpe_on_linearGaussian( - num_dim: int, - prior_str: str, - set_seed, -): +def test_c2st_snpe_on_linearGaussian(num_dim: int, prior_str: str, set_seed): """Test whether SNPE C infers well a simple example with available ground truth. Args: @@ -71,10 +64,7 @@ def simulator(theta): return linear_gaussian(theta, likelihood_shift, likelihood_cov) simulator, prior = prepare_for_sbi(simulator, prior) - inference = SNPE_C( - prior, - show_progress_bars=False, - ) + inference = SNPE_C(prior, show_progress_bars=False) theta, x = simulate_for_sbi(simulator, prior, 2000, simulation_batch_size=1000) _ = inference.append_simulations(theta, x).train(training_batch_size=100) @@ -176,7 +166,9 @@ def simulator(theta): show_progress_bars=False, ) - theta, x = simulate_for_sbi(simulator, prior, 2000, simulation_batch_size=1) # type: ignore + theta, x = simulate_for_sbi( + simulator, prior, 2000, simulation_batch_size=1 + ) # type: ignore inference = inference.append_simulations(theta, x) _ = inference.train(max_num_epochs=10) # Test whether we can stop and resume. _ = inference.train(resume_training=True) @@ -195,8 +187,7 @@ def simulator(theta): pytest.param( "snpe_b", marks=pytest.mark.xfail( - raises=NotImplementedError, - reason="""SNPE-B not implemented""", + raises=NotImplementedError, reason="""SNPE-B not implemented""" ), ), "snpe_c", @@ -366,11 +357,7 @@ def simulator(theta): net = utils.posterior_nn("maf", hidden_features=20) simulator, prior = prepare_for_sbi(simulator, prior) - inference = SNPE_C( - prior, - density_estimator=net, - show_progress_bars=False, - ) + inference = SNPE_C(prior, density_estimator=net, show_progress_bars=False) # We need a pretty big dataset to properly model the bimodality. theta, x = simulate_for_sbi(simulator, prior, 10000) @@ -396,16 +383,8 @@ def simulator(theta): density = gaussian_kde(cond_samples.numpy().T, bw_method="scott") X, Y = np.meshgrid( - np.linspace( - limits[0][0], - limits[0][1], - 50, - ), - np.linspace( - limits[1][0], - limits[1][1], - 50, - ), + np.linspace(limits[0][0], limits[0][1], 50), + np.linspace(limits[1][0], limits[1][1], 50), ) positions = np.vstack([X.ravel(), Y.ravel()]) sample_kde_grid = np.reshape(density(positions).T, X.shape) @@ -429,6 +408,74 @@ def simulator(theta): assert max_err < 0.0025 +def test_mdn_conditional_density(num_dim: int = 3, cond_dim: int = 1): + """Test whether the conditional density infered from MDN parameters of a + `DirectPosterior` matches analytical results for MVN. 
This uses a n-D joint and + conditions on the last m values to generate a conditional. + + Gaussian prior used for easier ground truthing of conditional posterior. + + Args: + num_dim: Dimensionality of the MVM. + cond_dim: Dimensionality of the condition. + """ + + assert ( + num_dim > cond_dim + ), "The number of dimensions needs to be greater than that of the condition!" + + x_o = zeros(1, num_dim) + num_samples = 1000 + num_simulations = 2500 + condition = 0.1 * ones(1, num_dim) + + dims = list(range(num_dim)) + dims2sample = dims[-cond_dim:] + dims2condition = dims[:-cond_dim] + + # likelihood_mean will be likelihood_shift+theta + likelihood_shift = -1.0 * ones(num_dim) + likelihood_cov = 0.3 * eye(num_dim) + + prior_mean = zeros(num_dim) + prior_cov = eye(num_dim) + prior = MultivariateNormal(loc=prior_mean, covariance_matrix=prior_cov) + + joint_posterior = true_posterior_linear_gaussian_mvn_prior( + x_o[0], likelihood_shift, likelihood_cov, prior_mean, prior_cov + ) + joint_cov = joint_posterior.covariance_matrix + joint_mean = joint_posterior.loc + + conditional_mean, conditional_cov = conditional_of_mvn( + joint_mean, joint_cov, condition[0, dims2condition] + ) + conditional_dist_gt = MultivariateNormal(conditional_mean, conditional_cov) + + conditional_samples_gt = conditional_dist_gt.sample((num_samples,)) + + def simulator(theta): + return linear_gaussian(theta, likelihood_shift, likelihood_cov) + + simulator, prior = prepare_for_sbi(simulator, prior) + inference = SNPE(prior, show_progress_bars=False, density_estimator="mdn") + + theta, x = simulate_for_sbi( + simulator, prior, num_simulations, simulation_batch_size=1000 + ) + _ = inference.append_simulations(theta, x).train(training_batch_size=100) + posterior = inference.build_posterior().set_default_x(x_o) + + conditional_samples_sbi = posterior.sample_conditional( + (num_samples,), condition, dims2sample, x_o + ) + check_c2st( + conditional_samples_sbi, + conditional_samples_gt, + alg="analytic_mdn_conditioning_of_direct_posterior", + ) + + def example_posterior(): """Return an inferred `NeuralPosterior` for interactive examination.""" num_dim = 2 @@ -446,10 +493,7 @@ def simulator(theta): return linear_gaussian(theta, likelihood_shift, likelihood_cov) simulator, prior = prepare_for_sbi(simulator, prior) - inference = SNPE_C( - prior, - show_progress_bars=False, - ) + inference = SNPE_C(prior, show_progress_bars=False) theta, x = simulate_for_sbi( simulator, prior, 1000, simulation_batch_size=10, num_workers=6 ) diff --git a/tests/testing_analytic_mdn_conditioning.py b/tests/testing_analytic_mdn_conditioning.py deleted file mode 100644 index e04144825..000000000 --- a/tests/testing_analytic_mdn_conditioning.py +++ /dev/null @@ -1,83 +0,0 @@ -from torch import eye, ones, zeros -from torch.distributions import MultivariateNormal - - -from sbi import utils -from tests.sbiutils_test import conditional_of_mvn -from sbi.inference import SNPE, prepare_for_sbi, simulate_for_sbi - -from sbi.simulators.linear_gaussian import ( - linear_gaussian, - true_posterior_linear_gaussian_mvn_prior, -) - -from tests.test_utils import check_c2st - - -def test_mdn_conditional_density(num_dim: int = 3, cond_dim: int = 1): - """Test whether the conditional density infered from MDN parameters of a - `DirectPosterior` matches analytical results for MVN. This uses a n-D joint and - conditions on the last m values to generate a conditional. - - Gaussian prior used for easier ground truthing of conditional posterior. 
- - Args: - num_dim: Dimensionality of the MVM. - cond_dim: Dimensionality of the condition. - """ - - assert ( - num_dim > cond_dim - ), "The number of dimensions needs to be greater than that of the condition!" - - x_o = zeros(1, num_dim) - num_samples = 1000 - num_simulations = 2500 - condition = 0.1 * ones(1, num_dim) - - dims = list(range(num_dim)) - dims2sample = dims[-cond_dim:] - dims2condition = dims[:-cond_dim] - - # likelihood_mean will be likelihood_shift+theta - likelihood_shift = -1.0 * ones(num_dim) - likelihood_cov = 0.3 * eye(num_dim) - - prior_mean = zeros(num_dim) - prior_cov = eye(num_dim) - prior = MultivariateNormal(loc=prior_mean, covariance_matrix=prior_cov) - - joint_posterior = true_posterior_linear_gaussian_mvn_prior( - x_o[0], likelihood_shift, likelihood_cov, prior_mean, prior_cov - ) - joint_cov = joint_posterior.covariance_matrix - joint_mean = joint_posterior.loc - - conditional_mean, conditional_cov = conditional_of_mvn( - joint_mean, joint_cov, condition[0, dims2condition] - ) - conditional_dist_gt = MultivariateNormal(conditional_mean, conditional_cov) - - conditional_samples_gt = conditional_dist_gt.sample((num_samples,)) - - def simulator(theta): - return linear_gaussian(theta, likelihood_shift, likelihood_cov) - - simulator, prior = prepare_for_sbi(simulator, prior) - inference = SNPE(prior, show_progress_bars=False, density_estimator="mdn") - - theta, x = simulate_for_sbi( - simulator, prior, num_simulations, simulation_batch_size=1000 - ) - _ = inference.append_simulations(theta, x).train(training_batch_size=100) - posterior = inference.build_posterior().set_default_x(x_o) - - conditional_samples_sbi = posterior.sample_conditional( - (num_samples,), condition, dims2sample, x_o - ) - check_c2st( - conditional_samples_sbi, - conditional_samples_gt, - alg="analytic_mdn_conditioning_of_direct_posterior", - ) - From 43f0c513a6a6f0a8fe2546a1d65d29601e9bf03f Mon Sep 17 00:00:00 2001 From: jnsbck Date: Sun, 18 Jul 2021 20:47:47 +0200 Subject: [PATCH 19/20] forgot arg for snpe_method in test_mdn_conditional_density. --- tests/linearGaussian_snpe_test.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/tests/linearGaussian_snpe_test.py b/tests/linearGaussian_snpe_test.py index 44eb76136..3d09113eb 100644 --- a/tests/linearGaussian_snpe_test.py +++ b/tests/linearGaussian_snpe_test.py @@ -438,7 +438,7 @@ def simulator(theta): @pytest.mark.slow @pytest.mark.parametrize("snpe_method", [SNPE_A, SNPE_C]) -def test_mdn_conditional_density(num_dim: int = 3, cond_dim: int = 1): +def test_mdn_conditional_density(snpe_method: type, num_dim: int = 3, cond_dim: int = 1): """Test whether the conditional density infered from MDN parameters of a `DirectPosterior` matches analytical results for MVN. This uses a n-D joint and conditions on the last m values to generate a conditional. From f9894431f39816a754babf8d88ca0aae8642075b Mon Sep 17 00:00:00 2001 From: jnsbck Date: Sun, 18 Jul 2021 21:27:41 +0200 Subject: [PATCH 20/20] fixed line that was accidently overwritten during resolve of conflicts. 
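
For context on the quantity these tests exercise: conditioning a mixture of Gaussians, which `sample_conditional` and `log_prob_conditional` do analytically for MDN posteriors, amounts to conditioning each component like a single Gaussian and reweighting the mixture coefficients by how well each component explains the condition. A rough, self-contained sketch of this idea is given below; it is independent of the actual `condition_mog` implementation in sbi/utils/conditional_density.py, which takes precision factors and further arguments, and `condition_mog_sketch` is a hypothetical name used only for illustration, assuming the condition covers the leading dimensions:

import torch
from torch.distributions import MultivariateNormal

def condition_mog_sketch(logits, means, covs, condition_y):
    # Condition a Gaussian mixture on its leading dimensions.
    # logits: (K,) unnormalized log mixture weights, means: (K, d) component
    # means, covs: (K, d, d) component covariances, condition_y: (m,) values
    # for the first m dimensions.
    m = condition_y.shape[0]
    new_logits, new_means, new_covs = [], [], []
    for logit, mu, S in zip(logits, means, covs):
        mu_y, mu_x = mu[:m], mu[m:]
        S_yy, S_xy, S_xx = S[:m, :m], S[m:, :m], S[m:, m:]
        S_yy_inv = torch.inverse(S_yy)
        # Per-component Gaussian conditioning.
        new_means.append(mu_x + S_xy @ S_yy_inv @ (condition_y - mu_y))
        new_covs.append(S_xx - S_xy @ S_yy_inv @ S_xy.T)
        # Reweight each component by its marginal density at the condition.
        marginal = MultivariateNormal(mu_y, S_yy)
        new_logits.append(logit + marginal.log_prob(condition_y))
    new_logits = torch.stack(new_logits)
    new_logits = new_logits - torch.logsumexp(new_logits, dim=0)  # normalize
    return new_logits, torch.stack(new_means), torch.stack(new_covs)

The conditional over the remaining dimensions is again a mixture of Gaussians, which is why sampling and log-probability evaluation stay exact for MDN-based posteriors and can be compared against the analytic MVN conditional in the test.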
--- tests/linearGaussian_snpe_test.py | 2 ++ 1 file changed, 2 insertions(+) diff --git a/tests/linearGaussian_snpe_test.py b/tests/linearGaussian_snpe_test.py index 3d09113eb..943ec269f 100644 --- a/tests/linearGaussian_snpe_test.py +++ b/tests/linearGaussian_snpe_test.py @@ -77,6 +77,8 @@ def test_c2st_snpe_on_linearGaussian( simulator, prior = prepare_for_sbi( lambda theta: linear_gaussian(theta, likelihood_shift, likelihood_cov), prior ) + + inference = snpe_method(prior, show_progress_bars=False) theta, x = simulate_for_sbi( simulator, prior, num_simulations, simulation_batch_size=1000