Skip to content

Commit

Permalink
Additional fix for wheels on windows (include dirs)
Browse files Browse the repository at this point in the history
ghstack-source-id: 5d5e50ca22a2207675407a3ef1339ce69f578d28
Pull Request resolved: https://github.com/fairinternal/xformers/pull/422

__original_commit__ = fairinternal/xformers@56c23ff4c1d503a607ae7fa6931505953af2d19d
  • Loading branch information
danthe3rd authored and xFormers Bot committed Jan 12, 2023
1 parent 8fd184e commit ad240e2
Showing 1 changed file with 7 additions and 5 deletions.
12 changes: 7 additions & 5 deletions setup.py
Original file line number Diff line number Diff line change
Expand Up @@ -170,10 +170,12 @@ def get_flash_attention_extensions(cuda_version: int, extra_compile_args):
+ get_extra_nvcc_flags_for_build_type(),
},
include_dirs=[
Path(flash_root) / "csrc" / "flash_attn",
Path(flash_root) / "csrc" / "flash_attn" / "src",
# Path(flash_root) / 'csrc' / 'flash_attn' / 'cutlass' / 'include',
Path(this_dir) / "third_party" / "cutlass" / "include",
p.absolute()
for p in [
Path(flash_root) / "csrc" / "flash_attn",
Path(flash_root) / "csrc" / "flash_attn" / "src",
Path(this_dir) / "third_party" / "cutlass" / "include",
]
],
)
]
Expand Down Expand Up @@ -254,7 +256,7 @@ def get_extensions():
extension(
"xformers._C",
sorted(sources),
include_dirs=include_dirs,
include_dirs=[os.path.abspath(p) for p in include_dirs],
define_macros=define_macros,
extra_compile_args=extra_compile_args,
)
Expand Down

0 comments on commit ad240e2

Please sign in to comment.