From f9405afc27a94907daaabb99b3eb0b37632156e0 Mon Sep 17 00:00:00 2001 From: Jan Boelts Date: Thu, 11 Apr 2024 08:12:20 +0200 Subject: [PATCH] ruff auto fixes. --- pyknos/mdn/mdn.py | 15 ++++++--------- tests/mdn_test.py | 5 +---- 2 files changed, 7 insertions(+), 13 deletions(-) diff --git a/pyknos/mdn/mdn.py b/pyknos/mdn/mdn.py index 38d12e3..04dcb28 100644 --- a/pyknos/mdn/mdn.py +++ b/pyknos/mdn/mdn.py @@ -4,6 +4,7 @@ Taken from http://github.com/conormdurkan/lfi. See there for copyright. """ + import warnings from typing import Optional, Tuple @@ -161,9 +162,9 @@ def get_mixture_components( torch.transpose(precision_factors, 2, 3), precision_factors ) # Add epsilon to diagnonal for numerical stability. - precisions[ - ..., torch.arange(self._features), torch.arange(self._features) - ] += self._epsilon + precisions[..., torch.arange(self._features), torch.arange(self._features)] += ( + self._epsilon + ) # The sum of the log diagonal of A is used in the likelihood calculation. sumlogdiag = torch.sum(torch.log(diagonal), dim=-1) @@ -288,9 +289,7 @@ def sample_mog( # Choose num_samples mixture components per example in the batch. choices = torch.multinomial( coefficients, num_samples=num_samples, replacement=True - ).view( - -1 - ) # [batch_size, num_samples] + ).view(-1) # [batch_size, num_samples] # Create dummy index for indexing means and precision factors. ix = torchutils.repeat_rows(torch.arange(batch_size), num_samples) @@ -338,9 +337,7 @@ def _initialize(self) -> None: torch.exp(torch.tensor([1 - self._epsilon])) - 1 ) * torch.ones( self._num_components * self._features - ) + self._epsilon * torch.randn( - self._num_components * self._features - ) + ) + self._epsilon * torch.randn(self._num_components * self._features) # Initialize off-diagonal of precision factors to zero. self._upper_layer.weight.data = self._epsilon * torch.randn( diff --git a/tests/mdn_test.py b/tests/mdn_test.py index 2e8d2b3..49eb148 100644 --- a/tests/mdn_test.py +++ b/tests/mdn_test.py @@ -1,8 +1,7 @@ import pytest import torch -from torch import Tensor, eye, ones, zeros +from torch import Tensor, eye import torch.nn as nn -from torch.distributions import MultivariateNormal from pyknos.mdn.mdn import MultivariateGaussianMDN @@ -12,7 +11,6 @@ def linear_gaussian( likelihood_shift: Tensor, likelihood_cov: Tensor, ) -> Tensor: - chol_factor = torch.cholesky(likelihood_cov) return likelihood_shift + theta + torch.mm(chol_factor, torch.randn_like(theta).T).T @@ -26,7 +24,6 @@ def linear_gaussian( def test_mdn_for_diff_dimension_data( dim: int, device: str, hidden_features: int = 50, num_components: int = 10 ) -> None: - if device == "cuda:0" and not torch.cuda.is_available(): pass else: