Skip to content

Commit

Permalink
ruff auto fixes.
Browse files Browse the repository at this point in the history
  • Loading branch information
janfb committed Apr 11, 2024
1 parent 2fdbf25 commit f9405af
Show file tree
Hide file tree
Showing 2 changed files with 7 additions and 13 deletions.
15 changes: 6 additions & 9 deletions pyknos/mdn/mdn.py
Original file line number Diff line number Diff line change
Expand Up @@ -4,6 +4,7 @@
Taken from http://github.com/conormdurkan/lfi. See there for copyright.
"""

import warnings
from typing import Optional, Tuple

Expand Down Expand Up @@ -161,9 +162,9 @@ def get_mixture_components(
torch.transpose(precision_factors, 2, 3), precision_factors
)
# Add epsilon to diagnonal for numerical stability.
precisions[
..., torch.arange(self._features), torch.arange(self._features)
] += self._epsilon
precisions[..., torch.arange(self._features), torch.arange(self._features)] += (
self._epsilon
)

# The sum of the log diagonal of A is used in the likelihood calculation.
sumlogdiag = torch.sum(torch.log(diagonal), dim=-1)
Expand Down Expand Up @@ -288,9 +289,7 @@ def sample_mog(
# Choose num_samples mixture components per example in the batch.
choices = torch.multinomial(
coefficients, num_samples=num_samples, replacement=True
).view(
-1
) # [batch_size, num_samples]
).view(-1) # [batch_size, num_samples]

# Create dummy index for indexing means and precision factors.
ix = torchutils.repeat_rows(torch.arange(batch_size), num_samples)
Expand Down Expand Up @@ -338,9 +337,7 @@ def _initialize(self) -> None:
torch.exp(torch.tensor([1 - self._epsilon])) - 1
) * torch.ones(
self._num_components * self._features
) + self._epsilon * torch.randn(
self._num_components * self._features
)
) + self._epsilon * torch.randn(self._num_components * self._features)

# Initialize off-diagonal of precision factors to zero.
self._upper_layer.weight.data = self._epsilon * torch.randn(
Expand Down
5 changes: 1 addition & 4 deletions tests/mdn_test.py
Original file line number Diff line number Diff line change
@@ -1,8 +1,7 @@
import pytest
import torch
from torch import Tensor, eye, ones, zeros
from torch import Tensor, eye
import torch.nn as nn
from torch.distributions import MultivariateNormal

from pyknos.mdn.mdn import MultivariateGaussianMDN

Expand All @@ -12,7 +11,6 @@ def linear_gaussian(
likelihood_shift: Tensor,
likelihood_cov: Tensor,
) -> Tensor:

chol_factor = torch.cholesky(likelihood_cov)

return likelihood_shift + theta + torch.mm(chol_factor, torch.randn_like(theta).T).T
Expand All @@ -26,7 +24,6 @@ def linear_gaussian(
def test_mdn_for_diff_dimension_data(
dim: int, device: str, hidden_features: int = 50, num_components: int = 10
) -> None:

if device == "cuda:0" and not torch.cuda.is_available():
pass
else:
Expand Down

0 comments on commit f9405af

Please sign in to comment.