Skip to content

Commit

Permalink
PR review
Browse files Browse the repository at this point in the history
  • Loading branch information
fduguet-nv committed Jul 15, 2022
1 parent 90135c9 commit 7364d28
Show file tree
Hide file tree
Showing 5 changed files with 52 additions and 15 deletions.
28 changes: 14 additions & 14 deletions cunumeric/eager.py
Original file line number Diff line number Diff line change
Expand Up @@ -659,20 +659,6 @@ def bitgenerator_integers(
a = np.random.random_integers(low, high, size=self.array.shape)
self.array[:] = a

def bitgenerator_uniform(
self, handle, generatorType, seed, flags, low, high
) -> None:
if self.deferred is not None:
self.deferred.bitgenerator_uniform(
handle, generatorType, seed, flags, low, high
)
else:
if self.array.size == 1:
self.array.fill(np.random.uniform(low, high))
else:
a = np.random.uniform(low, high, size=self.array.shape)
self.array[:] = a

def bitgenerator_lognormal(
self, handle, generatorType, seed, flags, mean, sigma
) -> None:
Expand Down Expand Up @@ -701,6 +687,20 @@ def bitgenerator_normal(
a = np.random.normal(mean, sigma, size=self.array.shape)
self.array[:] = a

def bitgenerator_uniform(
self, handle, generatorType, seed, flags, low, high
) -> None:
if self.deferred is not None:
self.deferred.bitgenerator_uniform(
handle, generatorType, seed, flags, low, high
)
else:
if self.array.size == 1:
self.array.fill(np.random.uniform(low, high))
else:
a = np.random.uniform(low, high, size=self.array.shape)
self.array[:] = a

def bitgenerator_poisson(
self, handle, generatorType, seed, flags, lam
) -> None:
Expand Down
11 changes: 11 additions & 0 deletions cunumeric/random/bitgenerator.py
Original file line number Diff line number Diff line change
Expand Up @@ -110,6 +110,17 @@ def normal(self, mean=0.0, sigma=1.0, shape=None, dtype=np.float64):
)
return res

def uniform(self, low=0.0, high=1.0, shape=None, dtype=np.float64):
if shape is None:
shape = (1,)
if not isinstance(shape, tuple):
shape = (shape,)
res = ndarray(shape, dtype=np.dtype(dtype))
res._thunk.bitgenerator_uniform(
self.handle, self.generatorType, self.seed, self.flags, low, high
)
return res

def poisson(self, lam, shape=None):
if shape is None:
shape = (1,)
Expand Down
3 changes: 3 additions & 0 deletions cunumeric/random/generator.py
Original file line number Diff line number Diff line change
Expand Up @@ -41,6 +41,9 @@ def normal(self, loc=0.0, scale=1.0, size=None, dtype=np.float64):
mean=loc, sigma=scale, shape=size, dtype=dtype
)

def uniform(self, low=0.0, high=1.0, size=None, dtype=np.float64):
return self.bit_generator.uniform(low, high, size, dtype)

def poisson(self, lam=1.0, size=None):
return self.bit_generator.poisson(lam, size)

Expand Down
1 change: 0 additions & 1 deletion src/cunumeric/random/bitgenerator_curand.inl
Original file line number Diff line number Diff line change
Expand Up @@ -932,7 +932,6 @@ struct BitGeneratorImplBody {
std::vector<legate::Store>& args)
{
generator_map_t& genmap = get_generator_map();
// printtid((int)op);
switch (op) {
case BitGeneratorOperation::CREATE: {
genmap.create(generatorID, generatorType, seed, flags);
Expand Down
24 changes: 24 additions & 0 deletions tests/integration/test_random_bitgenerator.py
Original file line number Diff line number Diff line change
Expand Up @@ -153,6 +153,30 @@ def test_normal_float64(t):
assert_distribution(a, theo_mean, theo_std)


@pytest.mark.parametrize("t", BITGENERATOR_ARGS, ids=str)
def test_uniform_float32(t):
bitgen = t(seed=42)
gen = num.random.Generator(bitgen)
low = 1.414
high = 3.14
a = gen.uniform(low, high, size=(1024 * 1024,), dtype=np.float32)
theo_mean = (low + high) / 2.0
theo_std = (high - low) / np.sqrt(12.0)
assert_distribution(a, theo_mean, theo_std)


@pytest.mark.parametrize("t", BITGENERATOR_ARGS, ids=str)
def test_uniform_float64(t):
bitgen = t(seed=42)
gen = num.random.Generator(bitgen)
low = 1.414
high = 3.14
a = gen.uniform(low, high, size=(1024 * 1024,), dtype=np.float64)
theo_mean = (low + high) / 2.0
theo_std = (high - low) / np.sqrt(12.0)
assert_distribution(a, theo_mean, theo_std)


@pytest.mark.parametrize("t", BITGENERATOR_ARGS, ids=str)
def test_poisson(t):
bitgen = t(seed=42)
Expand Down

0 comments on commit 7364d28

Please sign in to comment.