Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Implementing cunumeric.random.BitGenerator #254

Merged
merged 96 commits into from
Jul 26, 2022
Merged
Show file tree
Hide file tree
Changes from 25 commits
Commits
Show all changes
96 commits
Select commit Hold shift + click to select a range
d8537be
Implementing BitGenerator - GPU part - creation
fduguet-nv Mar 21, 2022
6f8437d
Plugging destroy
fduguet-nv Mar 23, 2022
e997e8d
skip ahead for GPU BitGenerator random_raw
fduguet-nv Mar 25, 2022
e805210
Refectoring test
fduguet-nv Mar 28, 2022
536c871
Piping done for BitGenerator.random_raw
fduguet-nv Mar 28, 2022
14304e8
LGT-260 -- implemented random_raw for CUDA
fduguet-nv Mar 29, 2022
75fea76
Implementing BitGenerator - GPU part - creation
fduguet-nv Mar 21, 2022
c436346
Plugging destroy
fduguet-nv Mar 23, 2022
bec87f7
skip ahead for GPU BitGenerator random_raw
fduguet-nv Mar 25, 2022
c3b1000
Refectoring test
fduguet-nv Mar 28, 2022
7133cf8
Piping done for BitGenerator.random_raw
fduguet-nv Mar 28, 2022
55c2d4c
LGT-260 -- implemented random_raw for CUDA
fduguet-nv Mar 29, 2022
fdb284b
Updating pre-commit
fduguet-nv Mar 29, 2022
80e2c66
Updating pre-commit
fduguet-nv Mar 29, 2022
3ffbd6c
Adding Multi-GPU support
fduguet-nv Mar 31, 2022
4b7ee28
Merge branch 'nv-legate:branch-22.05' into fduguet
fduguet-nv Mar 31, 2022
c2e71e0
Refactoring
fduguet-nv Apr 1, 2022
601dbbb
refactoring
fduguet-nv Apr 1, 2022
006cd04
CPU support - issues remain on multi-cpu mapper
fduguet-nv Apr 1, 2022
f1ae1be
Merge branch 'fduguet' of https://github.com/fduguet-nv/cunumeric int…
fduguet-nv Apr 1, 2022
e8be760
Attempts to support multi-dimension
fduguet-nv Apr 1, 2022
a6944a9
Finalizing multi-cpu implementation of BitGenerator
fduguet-nv Apr 1, 2022
3bed169
merging
fduguet-nv Apr 1, 2022
4ad2adb
Repro behavior of BitGenerator
fduguet-nv Apr 1, 2022
65d38c4
Merge branch 'nv-legate:branch-22.05' into fduguet
fduguet-nv Apr 8, 2022
5f283f2
Adding random in tests
fduguet-nv Apr 8, 2022
7b8702e
Merge branch 'fduguet' of https://github.com/fduguet-nv/cunumeric int…
fduguet-nv Apr 8, 2022
b7d103a
Fixes from code review - part 1
fduguet-nv Apr 9, 2022
b4891dc
use of logger
fduguet-nv Apr 14, 2022
7fc1643
Enums improvement
fduguet-nv Apr 15, 2022
c405020
Merge branch 'nv-legate:branch-22.05' into fduguet
fduguet-nv Apr 15, 2022
bd504af
Merge branch 'fduguet' of https://github.com/fduguet-nv/cunumeric int…
fduguet-nv Apr 15, 2022
31f44cc
More on code review
fduguet-nv Apr 19, 2022
d36140a
Merge
fduguet-nv Apr 19, 2022
7f3b9e5
Merge branch 'nv-legate:branch-22.05' into fduguet
fduguet-nv Apr 20, 2022
f373e36
code review
fduguet-nv Apr 21, 2022
1e8c615
Removing task in python destructor
fduguet-nv Apr 21, 2022
9f335ed
Merge branch 'nv-legate:branch-22.05' into fduguet
fduguet-nv Apr 21, 2022
f0a9f12
Merge branch 'fduguet' of https://github.com/fduguet-nv/cunumeric int…
fduguet-nv Apr 21, 2022
16b5bf2
Using Shape
fduguet-nv Apr 21, 2022
7b216e4
Lazy destroy also at destroy
fduguet-nv Apr 21, 2022
23ecf09
Using create_buffer for temporary buffer creation
fduguet-nv May 18, 2022
c956669
Merging with branch 22.05
fduguet-nv May 18, 2022
de7522a
Fixing bugs following merge
fduguet-nv May 18, 2022
ce39f7f
Merge remote-tracking branch 'origin/branch-22.05' into fduguet
fduguet-nv May 24, 2022
b1abb4e
Adding alternate lazy init implementation
fduguet-nv May 24, 2022
392c9db
Adding force create and force destroy
fduguet-nv May 24, 2022
57ef74b
Removing constraint in mapper
fduguet-nv May 24, 2022
00c7557
Updating test
fduguet-nv May 25, 2022
4214b1c
Fixing Eager code branch
fduguet-nv May 25, 2022
565f52c
Updating license information and adding integer generator
fduguet-nv May 25, 2022
67ffae0
Merging with branch 22.07
fduguet-nv May 30, 2022
b0b250a
PR review
fduguet-nv Jun 1, 2022
5580280
Removing reference to removed omp file
fduguet-nv Jun 1, 2022
d77163a
Changes from PR comments
fduguet-nv Jun 20, 2022
0a93761
Merge
fduguet-nv Jun 20, 2022
ad58247
Merge remote-tracking branch 'origin/branch-22.07' into fduguet
fduguet-nv Jun 28, 2022
3727c75
Adding generator.random implementation
fduguet-nv Jun 28, 2022
2cda5c4
Adding generator.lognormal implementation
fduguet-nv Jun 28, 2022
60b688c
Adding generator.normal implementation
fduguet-nv Jun 28, 2022
6218e17
Adding generator.poisson implementation
fduguet-nv Jun 28, 2022
9d23034
Refectoring randutil to allow host-only compilation without CUDA enabled
fduguet-nv Jun 29, 2022
f9e928a
Moving distributions to HOST only capable
fduguet-nv Jun 29, 2022
73e9ac1
Minor refactoring to split compilation of further distributions
fduguet-nv Jun 29, 2022
5f3e9c6
Merge remote-tracking branch 'origin/branch-22.07' into fduguet
fduguet-nv Jun 29, 2022
4cd1531
Merging
fduguet-nv Jun 30, 2022
186ca84
Merge remote-tracking branch 'origin/branch-22.07' into fduguet
fduguet-nv Jul 4, 2022
a2787ec
Fixing openmp test failure
fduguet-nv Jul 4, 2022
e87f6db
Adding random.generator.exponential
fduguet-nv Jul 4, 2022
69075cd
Adding random.generator.gumbel
fduguet-nv Jul 4, 2022
666484d
Splitting distributions in several files
fduguet-nv Jul 4, 2022
19e47ca
Actually applying the WAR
fduguet-nv Jul 5, 2022
d5ccf56
Adding random.generator.laplace
fduguet-nv Jul 5, 2022
4bd8cb0
Adding random.generator.logistic
fduguet-nv Jul 5, 2022
c1546e0
Adding random.generator.pareto
fduguet-nv Jul 5, 2022
370e94c
Adding random.generator.power
fduguet-nv Jul 5, 2022
192fb4a
Adding random.generator.rayleigh
fduguet-nv Jul 5, 2022
4ccd7d2
Adding random.generator.standard_cauchy
fduguet-nv Jul 5, 2022
a54c3df
Merge remote-tracking branch 'origin/branch-22.07' into fduguet
fduguet-nv Jul 7, 2022
7fbdf60
Adding random.generatortriangular
fduguet-nv Jul 7, 2022
78a7c87
Adding random.generator.weibul
fduguet-nv Jul 8, 2022
1449e91
Adding random.generator.bytes
fduguet-nv Jul 8, 2022
90135c9
PR review
fduguet-nv Jul 15, 2022
7364d28
PR review
fduguet-nv Jul 15, 2022
cacee28
PR review
fduguet-nv Jul 15, 2022
37df124
PR review
fduguet-nv Jul 15, 2022
f530847
merge
fduguet-nv Jul 15, 2022
2450038
fixing glitch
fduguet-nv Jul 15, 2022
ee2c203
Merge remote-tracking branch 'origin/branch-22.07' into fduguet
fduguet-nv Jul 18, 2022
1a7ff8c
More work on PR
fduguet-nv Jul 20, 2022
a40a94f
Merge remote-tracking branch 'origin/branch-22.07' into fduguet
fduguet-nv Jul 20, 2022
c3c7d7e
Fixing CURANDAPI
fduguet-nv Jul 20, 2022
6dfb548
Moving logger symbol from .cu to .cc
fduguet-nv Jul 21, 2022
185c354
removing destroy
fduguet-nv Jul 25, 2022
997e848
Merge
fduguet-nv Jul 25, 2022
68b3a01
Merge remote-tracking branch 'origin/branch-22.07' into fduguet
fduguet-nv Jul 26, 2022
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
1 change: 1 addition & 0 deletions cunumeric/config.py
Original file line number Diff line number Diff line change
Expand Up @@ -82,6 +82,7 @@ class CuNumericOpCode(IntEnum):
BINARY_OP = _cunumeric.CUNUMERIC_BINARY_OP
BINARY_RED = _cunumeric.CUNUMERIC_BINARY_RED
BINCOUNT = _cunumeric.CUNUMERIC_BINCOUNT
BITGENERATOR = _cunumeric.CUNUMERIC_BITGENERATOR
CHOOSE = _cunumeric.CUNUMERIC_CHOOSE
CONTRACT = _cunumeric.CUNUMERIC_CONTRACT
CONVERT = _cunumeric.CUNUMERIC_CONVERT
Expand Down
18 changes: 18 additions & 0 deletions cunumeric/deferred.py
Original file line number Diff line number Diff line change
Expand Up @@ -38,6 +38,8 @@
from .thunk import NumPyThunk
from .utils import get_arg_value_dtype

# from defusedxml import NotSupportedError
fduguet-nv marked this conversation as resolved.
Show resolved Hide resolved


def _complex_field_dtype(dtype):
if dtype == np.complex64:
Expand Down Expand Up @@ -1326,6 +1328,22 @@ def nonzero(self):
task.execute()
return results

def bitgenerator_random_raw(self, handle):
task = self.context.create_task(CuNumericOpCode.BITGENERATOR)
task.add_output(self.base)
task.add_scalar_arg(3, ty.int32) # OP_RAND_RAW
fduguet-nv marked this conversation as resolved.
Show resolved Hide resolved
task.add_scalar_arg(handle, ty.uint32)
# TODO: check if no function does this
totalsize = 1
for sz in self.shape:
fduguet-nv marked this conversation as resolved.
Show resolved Hide resolved
totalsize = totalsize * sz
task.add_scalar_arg(totalsize, ty.uint64)
# strides
task.add_scalar_arg(self.compute_strides(self.shape), (ty.int64,))

task.add_broadcast(self.base, axes=tuple(range(1, self.ndim)))
fduguet-nv marked this conversation as resolved.
Show resolved Hide resolved
task.execute()

def random(self, gen_code, args=[]):
task = self.context.create_task(CuNumericOpCode.RAND)

Expand Down
14 changes: 14 additions & 0 deletions cunumeric/eager.py
Original file line number Diff line number Diff line change
Expand Up @@ -564,6 +564,20 @@ def sort(self, rhs, argsort=False, axis=-1, kind="quicksort", order=None):
else:
self.array = np.sort(rhs.array, axis, kind, order)

def bitgenerator_random_raw(self, handle):
if self.deferred is not None:
print("eager.py - bitgenerator_random_raw - deferred is not None")
fduguet-nv marked this conversation as resolved.
Show resolved Hide resolved
self.deferred.bitgenerator_random_raw(handle)
else:
print("eager.py - bitgenerator_random_raw - deferred is None")
fduguet-nv marked this conversation as resolved.
Show resolved Hide resolved
if self.array.size == 1:
self.array.fill(np.random.rand())
fduguet-nv marked this conversation as resolved.
Show resolved Hide resolved
else:
a = np.random.randint(
fduguet-nv marked this conversation as resolved.
Show resolved Hide resolved
0, 2**32 - 1, *(self.array.shape), dtype=self.array.dtype
)
self.array[:] = a[:]

def random_uniform(self):
if self.deferred is not None:
self.deferred.random_uniform()
Expand Down
87 changes: 87 additions & 0 deletions cunumeric/random/BitGenerator.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,87 @@
# Copyright 2021-2022 NVIDIA Corporation
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
#

import numpy as np
from cunumeric.array import ndarray
from cunumeric.runtime import runtime


class BitGenerator:
# see bitgenerator_util.h
OP_CREATE = 1
OP_DESTROY = 2
OP_RAND_RAW = 3
fduguet-nv marked this conversation as resolved.
Show resolved Hide resolved

# see bitgenerator_util.h
fduguet-nv marked this conversation as resolved.
Show resolved Hide resolved
DEFAULT = 0
XORWOW = 1
MRG32K3A = 2
MTGP32 = 3
MT19937 = 4
PHILOX4_32_10 = 5

__slots__ = [
"handle", # handle to the runtime id
fduguet-nv marked this conversation as resolved.
Show resolved Hide resolved
]

def __init__(self, seed=None, generatorType=DEFAULT):
if type(self) is BitGenerator:
fduguet-nv marked this conversation as resolved.
Show resolved Hide resolved
raise NotImplementedError(
"BitGenerator is a base class and cannot be instantized"
)
self.handle = runtime.bitgenerator_create(generatorType)
if seed is not None:
runtime.bitgenerator_set_seed(self.handle, seed)

def __del__(self):
runtime.bitgenerator_destroy(self.handle)
fduguet-nv marked this conversation as resolved.
Show resolved Hide resolved

# when output is false => skip ahead
def random_raw(self, shape=None, output=True):
if shape is None:
raise NotImplementedError("Empty shape not implemented")
fduguet-nv marked this conversation as resolved.
Show resolved Hide resolved
if not isinstance(shape, tuple):
shape = (shape,)
if output:
res = ndarray(shape, dtype=np.dtype(np.uint32))
res._thunk.bitgenerator_random_raw(self.handle)
return res
else:
runtime.bitgenerator_random_raw(self.handle, shape)


class XORWOW(BitGenerator):
def __init__(self, seed=None):
fduguet-nv marked this conversation as resolved.
Show resolved Hide resolved
super().__init__(seed, BitGenerator.XORWOW)


class MRG32k3a(BitGenerator):
def __init__(self, seed=None):
fduguet-nv marked this conversation as resolved.
Show resolved Hide resolved
super().__init__(seed, BitGenerator.MRG32K3A)


class MTGP32(BitGenerator):
def __init__(self, seed=None):
super().__init__(seed, BitGenerator.MTGP32)


class MT19937(BitGenerator):
def __init__(self, seed=None):
super().__init__(seed, BitGenerator.MT19937)


class PHILOX4_32_10(BitGenerator):
def __init__(self, seed=None):
super().__init__(seed, BitGenerator.PHILOX4_32_10)
1 change: 1 addition & 0 deletions cunumeric/random/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -16,6 +16,7 @@
import numpy.random as _nprandom
from cunumeric.random.random import *
from cunumeric.coverage import clone_module
from cunumeric.random.BitGenerator import *
fduguet-nv marked this conversation as resolved.
Show resolved Hide resolved

clone_module(_nprandom, globals())

Expand Down
57 changes: 57 additions & 0 deletions cunumeric/runtime.py
Original file line number Diff line number Diff line change
Expand Up @@ -61,6 +61,7 @@ def __init__(self, legate_context):
self.legate_context = legate_context
self.legate_runtime = get_legate_runtime()
self.current_random_epoch = 0
self.current_random_bitgenid = 0
self.destroyed = False
self.api_calls = []

Expand Down Expand Up @@ -218,6 +219,62 @@ def create_scalar(self, array: memoryview, dtype, shape=None, wrap=False):
result = future
return result

def bitgenerator_create(self, generatorType):
task = self.legate_context.create_task(
CuNumericOpCode.BITGENERATOR,
manual=True,
launch_domain=Rect(lo=(0,), hi=(self.num_procs,)),
)
self.current_random_bitgenid = self.current_random_bitgenid + 1
task.add_scalar_arg(1, ty.int32) # OP_CREATE
fduguet-nv marked this conversation as resolved.
Show resolved Hide resolved
task.add_scalar_arg(self.current_random_bitgenid, ty.uint32)
task.add_scalar_arg(generatorType, ty.uint64)
task.execute()
self.legate_runtime.issue_execution_fence()
fduguet-nv marked this conversation as resolved.
Show resolved Hide resolved
return self.current_random_bitgenid

def bitgenerator_destroy(self, handle):
self.legate_runtime.issue_execution_fence()
task = self.legate_context.create_task(
CuNumericOpCode.BITGENERATOR,
manual=True,
launch_domain=Rect(lo=(0,), hi=(self.num_procs,)),
)
task.add_scalar_arg(2, ty.int32) # OP_DESTROY
task.add_scalar_arg(handle, ty.uint32)
task.add_scalar_arg(0, ty.uint64)
task.execute()

def bitgenerator_set_seed(self, handle, seed):
if not isinstance(seed, int):
raise NotImplementedError("Non integer seed is not implemented")
task = self.legate_context.create_task(
CuNumericOpCode.BITGENERATOR,
manual=True,
launch_domain=Rect(lo=(0,), hi=(self.num_procs,)),
)
task.add_scalar_arg(4, ty.int32) # OP_SET_SEED
task.add_scalar_arg(handle, ty.uint32)
task.add_scalar_arg(seed, ty.uint64)
task.execute()

def bitgenerator_random_raw(self, handle, size):
# here, no output: we discard generated numbers... - just a skipahead
task = self.legate_context.create_task(
CuNumericOpCode.BITGENERATOR,
manual=True,
launch_domain=Rect(lo=(0,), hi=(self.num_procs,)),
fduguet-nv marked this conversation as resolved.
Show resolved Hide resolved
)
task.add_scalar_arg(3, ty.int32) # OP_RAND_RAW
task.add_scalar_arg(handle, ty.uint32)
gencount = 1
fduguet-nv marked this conversation as resolved.
Show resolved Hide resolved
for sz in size:
gencount = gencount * sz
task.add_scalar_arg(gencount, ty.uint64) # size of the output
task.execute()
# for consistent random ordering
self.legate_runtime.issue_execution_fence()
fduguet-nv marked this conversation as resolved.
Show resolved Hide resolved

def set_next_random_epoch(self, epoch):
self.current_random_epoch = epoch

Expand Down
4 changes: 4 additions & 0 deletions cunumeric/thunk.py
Original file line number Diff line number Diff line change
Expand Up @@ -189,6 +189,10 @@ def bincount(self, rhs, weights=None):
def nonzero(self):
...

@abstractmethod
def bitgenerator_random_raw(self, bitgen):
...

magnatelee marked this conversation as resolved.
Show resolved Hide resolved
@abstractmethod
def random_uniform(self):
...
Expand Down
2 changes: 1 addition & 1 deletion src/Makefile
Original file line number Diff line number Diff line change
Expand Up @@ -50,7 +50,7 @@ LD_FLAGS += -L$(OPENBLAS_PATH)/lib -l$(OPENBLAS_LIBNAME) -Wl,-rpath,$(OPENBLAS_P
LD_FLAGS += -L$(TBLIS_PATH)/lib -ltblis -Wl,-rpath,$(TBLIS_PATH)/lib
ifeq ($(strip $(USE_CUDA)),1)
DEVICE_LD_FLAGS += -lcufft_static
LD_FLAGS += -lcublas -lcusolver -lcufft_static -lculibos
LD_FLAGS += -lcublas -lcusolver -lcufft_static -lculibos -lcurand
LD_FLAGS += -L$(CUTENSOR_PATH)/lib -lcutensor -Wl,-rpath,$(CUTENSOR_PATH)/lib
LD_FLAGS += -L$(NCCL_PATH)/lib -lnccl -Wl,-rpath,$(NCCL_PATH)/lib
endif
Expand Down
3 changes: 3 additions & 0 deletions src/cunumeric.mk
Original file line number Diff line number Diff line change
Expand Up @@ -43,6 +43,7 @@ GEN_CPU_SRC += cunumeric/ternary/where.cc \
cunumeric/matrix/trsm.cc \
cunumeric/matrix/util.cc \
cunumeric/random/rand.cc \
cunumeric/random/bitgenerator.cc \
cunumeric/search/nonzero.cc \
cunumeric/set/unique.cc \
cunumeric/set/unique_reduce.cc \
Expand Down Expand Up @@ -80,6 +81,7 @@ GEN_CPU_SRC += cunumeric/ternary/where_omp.cc \
cunumeric/matrix/trsm_omp.cc \
cunumeric/matrix/util_omp.cc \
cunumeric/random/rand_omp.cc \
cunumeric/random/bitgenerator_omp.cc \
cunumeric/search/nonzero_omp.cc \
cunumeric/set/unique_omp.cc \
cunumeric/sort/sort_omp.cc \
Expand Down Expand Up @@ -119,6 +121,7 @@ GEN_GPU_SRC += cunumeric/ternary/where.cu \
cunumeric/matrix/trilu.cu \
cunumeric/matrix/trsm.cu \
cunumeric/random/rand.cu \
cunumeric/random/bitgenerator.cu \
fduguet-nv marked this conversation as resolved.
Show resolved Hide resolved
cunumeric/search/nonzero.cu \
cunumeric/set/unique.cu \
cunumeric/sort/sort.cu \
Expand Down
1 change: 1 addition & 0 deletions src/cunumeric/cunumeric_c.h
Original file line number Diff line number Diff line change
Expand Up @@ -27,6 +27,7 @@ enum CuNumericOpCode {
CUNUMERIC_BINARY_OP,
CUNUMERIC_BINARY_RED,
CUNUMERIC_BINCOUNT,
CUNUMERIC_BITGENERATOR,
CUNUMERIC_CHOOSE,
CUNUMERIC_CONTRACT,
CUNUMERIC_CONVERT,
Expand Down
10 changes: 10 additions & 0 deletions src/cunumeric/mapper.cc
Original file line number Diff line number Diff line change
Expand Up @@ -160,6 +160,16 @@ std::vector<StoreMapping> CuNumericMapper::store_mappings(
mappings.back().policy.exact = true;
return std::move(mappings);
}
case CUNUMERIC_BITGENERATOR: {
fduguet-nv marked this conversation as resolved.
Show resolved Hide resolved
std::vector<StoreMapping> mappings;
auto& outputs = task.outputs();
for (auto& output : outputs) {
mappings.push_back(StoreMapping::default_mapping(output, options.front()));
mappings.back().policy.ordering.c_order();
fduguet-nv marked this conversation as resolved.
Show resolved Hide resolved
mappings.back().policy.exact = true;
}
return std::move(mappings);
}
default: {
return {};
}
Expand Down
78 changes: 78 additions & 0 deletions src/cunumeric/random/bitgenerator.cc
Original file line number Diff line number Diff line change
@@ -0,0 +1,78 @@
/* Copyright 2021-2022 NVIDIA Corporation
*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*
*/

#include "cunumeric/random/bitgenerator.h"
#include "cunumeric/random/bitgenerator_template.inl"
#include "cunumeric/random/bitgenerator_util.h"

#include "cunumeric/random/curand_help.h"

#include "cunumeric/random/bitgenerator_curand.inl"

namespace cunumeric {

using namespace Legion;
using namespace legate;

template <>
struct CURANDGeneratorBuilder<VariantKind::CPU> {
static CURANDGenerator* build(BitGeneratorType gentype)
{
curandGenerator_t gen;
CHECK_CURAND(::curandCreateGeneratorHost(&gen, get_curandRngType(gentype)));
CURANDGenerator* cugenptr = new CURANDGenerator();
CURANDGenerator& cugen = *cugenptr;
cugen.gen = gen;
fduguet-nv marked this conversation as resolved.
Show resolved Hide resolved
cugen.offset = 0;
cugen.type = get_curandRngType(gentype);
cugen.supports_skipahead = supportsSkipAhead(cugen.type);
cugen.dev_buffer_size = cugen.DEFAULT_DEV_BUFFER_SIZE;
cugen.dev_buffer = (uint32_t*)::malloc(cugen.dev_buffer_size * sizeof(uint32_t));
fduguet-nv marked this conversation as resolved.
Show resolved Hide resolved
return cugenptr;
}

static void destroy(CURANDGenerator* cugenptr)
{
// wait for rand jobs and clean-up resources
std::lock_guard<std::mutex> guard(cugenptr->lock);
fduguet-nv marked this conversation as resolved.
Show resolved Hide resolved
::free(cugenptr->dev_buffer);
fduguet-nv marked this conversation as resolved.
Show resolved Hide resolved
CHECK_CURAND(::curandDestroyGenerator(cugenptr->gen));
}
};

template <>
std::map<Legion::Processor, std::unique_ptr<generatormap<VariantKind::CPU>>>
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

So this is a data structure you should protect using a lock, as this map is initially empty and will be populated by tasks running on several different processors. There's a way to avoid that lock as well, but I wouldn't mind using a lock in this pull request.

Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

I tried thread local but this fails because tasks on the same processor can be called by different threads. If you have another option without lock, I'll be happy to use it.

BitGeneratorImplBody<VariantKind::CPU>::m_generators = {};

template <>
std::mutex BitGeneratorImplBody<VariantKind::CPU>::lock_generators = {};

/*static*/ void BitGeneratorTask::cpu_variant(TaskContext& context)
{
bitgenerator_template<VariantKind::CPU>(context);
}

namespace // unnamed
{

static void __attribute__((constructor)) register_tasks(void)
{
BitGeneratorTask::register_variants();
}

} // namespace

} // namespace cunumeric
Loading