Commit e1283ff5 by nhynes Committed by Tianqi Chen

Add normal distribution to random engines (#1352)

parent 10b7757a
...@@ -55,4 +55,29 @@ def uniform(low, high, size): ...@@ -55,4 +55,29 @@ def uniform(low, high, size):
return _api.extern(size, [], lambda ins, outs: _intrin.call_packed( return _api.extern(size, [], lambda ins, outs: _intrin.call_packed(
"tvm.contrib.random.uniform", float(low), float(high), outs[0]), dtype='float32') "tvm.contrib.random.uniform", float(low), float(high), outs[0]), dtype='float32')
def normal(loc, scale, size):
"""Draw samples from a normal distribution.
Return random samples from a normal distribution.
Parameters
----------
loc : float
loc of the distribution.
scale : float
Standard deviation of the distribution.
size : tuple of ints
Output shape. If the given shape is, e.g., (m, n, k), then m * n * k
samples are drawn.
Returns
------
out : Tensor
A tensor with specified size and dtype
"""
return _api.extern(size, [], lambda ins, outs: _intrin.call_packed(
"tvm.contrib.random.normal", float(loc), float(scale), outs[0]), dtype='float32')
_init_api("tvm.contrib.random") _init_api("tvm.contrib.random")
...@@ -73,7 +73,32 @@ class RandomEngine { ...@@ -73,7 +73,32 @@ class RandomEngine {
return uniform_dist(rnd_engine_); return uniform_dist(rnd_engine_);
}); });
} else { } else {
LOG(FATAL) << "Do not support random.randint on this device yet"; LOG(FATAL) << "Do not support random.uniform on this device yet";
}
}
/*!
* \brief Fills a tensor with values drawn from Normal(loc, scale**2)
*/
void SampleNormal(DLTensor* data, float loc, float scale) {
CHECK_GT(scale, 0) << "standard deviation must be positive";
CHECK(data->strides == nullptr);
DLDataType dtype = data->dtype;
int64_t size = 1;
for (int i = 0; i < data->ndim; ++i) {
size *= data->shape[i];
}
CHECK(dtype.code == kDLFloat && dtype.bits == 32 && dtype.lanes == 1);
if (data->ctx.device_type == kDLCPU) {
std::normal_distribution<float> normal_dist(loc, scale);
std::generate_n(static_cast<float*>(data->data), size, [&] () {
return normal_dist(rnd_engine_);
});
} else {
LOG(FATAL) << "Do not support random.normal on this device yet";
} }
} }
......
...@@ -87,6 +87,7 @@ TVM_REGISTER_GLOBAL("tvm.contrib.random.randint") ...@@ -87,6 +87,7 @@ TVM_REGISTER_GLOBAL("tvm.contrib.random.randint")
}) })
}); });
TVM_REGISTER_GLOBAL("tvm.contrib.random.uniform") TVM_REGISTER_GLOBAL("tvm.contrib.random.uniform")
.set_body([](TVMArgs args, TVMRetValue *ret) { .set_body([](TVMArgs args, TVMRetValue *ret) {
RandomThreadLocalEntry *entry = RandomThreadLocalEntry::ThreadLocal(); RandomThreadLocalEntry *entry = RandomThreadLocalEntry::ThreadLocal();
...@@ -97,5 +98,15 @@ TVM_REGISTER_GLOBAL("tvm.contrib.random.uniform") ...@@ -97,5 +98,15 @@ TVM_REGISTER_GLOBAL("tvm.contrib.random.uniform")
}); });
TVM_REGISTER_GLOBAL("tvm.contrib.random.normal")
.set_body([](TVMArgs args, TVMRetValue *ret) {
RandomThreadLocalEntry *entry = RandomThreadLocalEntry::ThreadLocal();
double loc = args[0];
double scale = args[1];
DLTensor* out = args[2];
entry->random_engine.SampleNormal(out, loc, scale);
});
} // namespace contrib } // namespace contrib
} // namespace tvm } // namespace tvm
...@@ -50,6 +50,30 @@ def test_uniform(): ...@@ -50,6 +50,30 @@ def test_uniform():
verify() verify()
def test_normal():
m = 1024
n = 1024
A = random.normal(3, 4, size=(m, n))
s = tvm.create_schedule(A.op)
def verify(target="llvm"):
if not tvm.module.enabled(target):
print("skip because %s is not enabled..." % target)
return
if not tvm.get_global_func("tvm.contrib.random.normal", True):
print("skip because extern function is not avalable")
return
ctx = tvm.cpu(0)
f = tvm.build(s, [A], target)
a = tvm.nd.array(np.zeros((m, n), dtype=A.dtype), ctx)
f(a)
na = a.asnumpy()
assert abs(np.mean(na) - 3) < 1e-2
assert abs(np.std(na) - 4) < 1e-2
verify()
if __name__ == "__main__": if __name__ == "__main__":
test_randint() test_randint()
test_uniform() test_uniform()
test_normal()
Markdown is supported
0% or
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment