Skip to content

Commit

Permalink
[DirectML] Fix pow_.
Browse files Browse the repository at this point in the history
  • Loading branch information
lshqqytiger committed Apr 17, 2024
1 parent c113734 commit 8ee05a8
Showing 1 changed file with 3 additions and 2 deletions.
5 changes: 3 additions & 2 deletions modules/dml/hijack/torch.py
Original file line number Diff line number Diff line change
Expand Up @@ -3,10 +3,10 @@
from modules.sd_hijack_utils import CondFunc

CondFunc('torchsde._brownian.brownian_interval._randn', lambda _, size, dtype, device, seed: torch.randn(size, dtype=dtype, device=torch.device("cpu"), generator=torch.Generator(torch.device("cpu")).manual_seed(int(seed))).to(device), lambda _, size, dtype, device, seed: device.type == 'privateuseone')

# https://github.com/microsoft/DirectML/issues/400
CondFunc('torch.Tensor.new', lambda orig, self, *args, **kwargs: orig(self.cpu(), *args, **kwargs).to(self.device), lambda orig, self, *args, **kwargs: torch.dml.is_directml_device(self.device))


_lerp = torch.lerp
def lerp(*args, **kwargs) -> torch.Tensor:
rep = None
Expand Down Expand Up @@ -41,10 +41,11 @@ def lerp(*args, **kwargs) -> torch.Tensor:
return _lerp(*args, **kwargs)
torch.lerp = lerp


# https://github.com/lshqqytiger/stable-diffusion-webui-directml/issues/436
_pow_ = torch.Tensor.pow_
def pow_(self: torch.Tensor, *args, **kwargs):
if self.dtype == torch.float64:
return _pow_(self.cpu(), *args, **kwargs).to(self.device)
return _pow_(self, *args, **kwargs)
torch.Tensor.pow_ = _pow_
torch.Tensor.pow_ = pow_

0 comments on commit 8ee05a8

Please sign in to comment.