Skip to content

Commit

Permalink
Update zluda.
Browse files Browse the repository at this point in the history
  • Loading branch information
lshqqytiger committed Sep 19, 2024
1 parent d8b7380 commit 04fb8b3
Show file tree
Hide file tree
Showing 3 changed files with 43 additions and 23 deletions.
20 changes: 19 additions & 1 deletion modules/rocm.py
Original file line number Diff line number Diff line change
Expand Up @@ -3,6 +3,7 @@
import ctypes
import shutil
import subprocess
import importlib.metadata
from typing import Union, List


Expand Down Expand Up @@ -81,6 +82,17 @@ def get_gfx_version(self) -> Union[str, None]:
return None


def get_version_torch() -> Union[str, None]:
version_ = None
try:
version_ = importlib.metadata.version("torch")
except importlib.metadata.PackageNotFoundError:
return None
if "+rocm" not in version_: # unofficial build, non-rocm torch.
return None
return version_.split("+rocm")[1]


if sys.platform == "win32":
def find() -> Union[str, None]:
hip_path = shutil.which("hipconfig")
Expand Down Expand Up @@ -161,6 +173,8 @@ def load_hsa_runtime() -> None:
try:
# Preload stdc++ library. This will ignore Anaconda stdc++ library.
load_library_global("/lib/x86_64-linux-gnu/libstdc++.so.6")
# Use tcmalloc if possible.
load_library_global("/usr/lib/x86_64-linux-gnu/libtcmalloc_minimal.so.4")
except OSError:
pass
# Preload HSA Runtime library.
Expand All @@ -173,10 +187,14 @@ def set_blaslt_enabled(enabled: bool) -> None:
else:
os.environ["TORCH_BLAS_PREFER_HIPBLASLT"] = "0"

is_wsl: bool = os.environ.get('WSL_DISTRO_NAME', None) is not None
def get_blaslt_enabled() -> bool:
return bool(int(os.environ.get("TORCH_BLAS_PREFER_HIPBLASLT", "1")))

is_wsl: bool = os.environ.get('WSL_DISTRO_NAME', 'unknown' if spawn('wslpath -w /') else None) is not None
path = find()
is_installed = False
version = None
version_torch = get_version_torch()
if path is not None:
is_installed = True
version = get_version()
20 changes: 6 additions & 14 deletions modules/zluda_hijacks.py
Original file line number Diff line number Diff line change
@@ -1,6 +1,5 @@
import os
import sys
import torch
from modules import rocm


_topk = torch.topk
Expand All @@ -10,19 +9,12 @@ def topk(tensor: torch.Tensor, *args, **kwargs):
return torch.return_types.topk((values.to(device), indices.to(device),))


def _join_rocm_home(*paths) -> str:
from torch.utils.cpp_extension import ROCM_HOME
return os.path.join(ROCM_HOME, *paths)
def jit_script(f, *_, **__): # experiment / provide dummy graph
f.graph = torch._C.Graph() # pylint: disable=protected-access
return f


def do_hijack():
torch.version.hip = "5.7"
torch.version.hip = rocm.version
torch.topk = topk
platform = sys.platform
sys.platform = ""
from torch.utils import cpp_extension
sys.platform = platform
cpp_extension.IS_WINDOWS = platform == "win32"
cpp_extension.IS_MACOS = False
cpp_extension.IS_LINUX = platform.startswith('linux')
cpp_extension._join_rocm_home = _join_rocm_home # pylint: disable=protected-access
torch.jit.script = jit_script
26 changes: 18 additions & 8 deletions modules/zluda_installer.py
Original file line number Diff line number Diff line change
@@ -1,4 +1,5 @@
import os
import sys
import ctypes
import shutil
import zipfile
Expand All @@ -23,14 +24,7 @@ def install(zluda_path: os.PathLike) -> None:
if os.path.exists(zluda_path):
return

default_hash = None
if rocm.version == "6.1":
default_hash = 'd7714d84c0c13bbf816eaaac32693e4e75e58a87'
elif rocm.version == "5.7":
default_hash = '11cc5844514f93161e0e74387f04e2c537705a82'
else:
raise RuntimeError(f'Unsupported HIP SDK version: {rocm.version}')
urllib.request.urlretrieve(f'https://github.com/lshqqytiger/ZLUDA/releases/download/rel.{os.environ.get("ZLUDA_HASH", default_hash)}/ZLUDA-windows-amd64.zip', '_zluda')
urllib.request.urlretrieve(f'https://github.com/lshqqytiger/ZLUDA/releases/download/rel.{os.environ.get("ZLUDA_HASH", "c0804ca624963aab420cb418412b1c7fbae3454b")}/ZLUDA-windows-rocm{rocm.version[0]}-amd64.zip', '_zluda')
with zipfile.ZipFile('_zluda', 'r') as archive:
infos = archive.infolist()
for info in infos:
Expand All @@ -55,9 +49,25 @@ def make_copy(zluda_path: os.PathLike) -> None:


def load(zluda_path: os.PathLike) -> None:
os.environ["ZLUDA_COMGR_LOG_LEVEL"] = "1"

for v in HIPSDK_TARGETS:
ctypes.windll.LoadLibrary(os.path.join(rocm.path, 'bin', v))
for v in ZLUDA_TARGETS:
ctypes.windll.LoadLibrary(os.path.join(zluda_path, v))
for v in DLL_MAPPING.values():
ctypes.windll.LoadLibrary(os.path.join(zluda_path, v))

def conceal():
import torch # pylint: disable=unused-import
platform = sys.platform
sys.platform = ""
from torch.utils import cpp_extension
sys.platform = platform
cpp_extension.IS_WINDOWS = platform == "win32"
cpp_extension.IS_MACOS = False
cpp_extension.IS_LINUX = platform.startswith('linux')
def _join_rocm_home(*paths) -> str:
return os.path.join(cpp_extension.ROCM_HOME, *paths)
cpp_extension._join_rocm_home = _join_rocm_home # pylint: disable=protected-access
rocm.conceal = conceal

0 comments on commit 04fb8b3

Please sign in to comment.