diff --git a/modules/dml/hijack/torch.py b/modules/dml/hijack/torch.py index f70479d0da0..ff5ee4df8e5 100644 --- a/modules/dml/hijack/torch.py +++ b/modules/dml/hijack/torch.py @@ -3,10 +3,10 @@ from modules.sd_hijack_utils import CondFunc CondFunc('torchsde._brownian.brownian_interval._randn', lambda _, size, dtype, device, seed: torch.randn(size, dtype=dtype, device=torch.device("cpu"), generator=torch.Generator(torch.device("cpu")).manual_seed(int(seed))).to(device), lambda _, size, dtype, device, seed: device.type == 'privateuseone') - # https://github.com/microsoft/DirectML/issues/400 CondFunc('torch.Tensor.new', lambda orig, self, *args, **kwargs: orig(self.cpu(), *args, **kwargs).to(self.device), lambda orig, self, *args, **kwargs: torch.dml.is_directml_device(self.device)) + _lerp = torch.lerp def lerp(*args, **kwargs) -> torch.Tensor: rep = None @@ -41,10 +41,11 @@ def lerp(*args, **kwargs) -> torch.Tensor: return _lerp(*args, **kwargs) torch.lerp = lerp + # https://github.com/lshqqytiger/stable-diffusion-webui-directml/issues/436 _pow_ = torch.Tensor.pow_ def pow_(self: torch.Tensor, *args, **kwargs): if self.dtype == torch.float64: return _pow_(self.cpu(), *args, **kwargs).to(self.device) return _pow_(self, *args, **kwargs) -torch.Tensor.pow_ = _pow_ +torch.Tensor.pow_ = pow_