Commit a5acca92 by Josh Fromm Committed by Tianqi Chen

[Vulkan] Added conversion from bool to float. (#3513)

* Added bool to float conversion support to spirv ir builder.

* Added unittest for vulkan bool conversion.

* Typo fix.
parent 7fb9557b
......@@ -462,6 +462,9 @@ Value IRBuilder::Cast(const SType& dst_type, spirv::Value value) {
return Select(value, IntImm(dst_type, 1), IntImm(dst_type, 0));
} else if (to.is_uint()) {
return Select(value, UIntImm(dst_type, 1), UIntImm(dst_type, 0));
} else if (to.is_float()) {
return MakeValue(spv::OpConvertUToF, dst_type,
Select(value, UIntImm(t_uint32_, 1), UIntImm(t_uint32_, 0)));
} else {
LOG(FATAL) << "cannot cast from " << from << " to " << to;
return Value();
......
......@@ -683,7 +683,7 @@ void VulkanWorkspace::Init() {
try {
instance_ = CreateInstance();
context_ = GetContext(instance_);
LOG(INFO) << "Initialzie Vulkan with " << context_.size() << " devices..";
LOG(INFO) << "Initialize Vulkan with " << context_.size() << " devices..";
for (size_t i = 0; i < context_.size(); ++i) {
LOG(INFO) << "vulkan(" << i
<< ")=\'" << context_[i].phy_device_prop.deviceName
......
......@@ -24,7 +24,8 @@ def test_cmp_load_store():
A = tvm.placeholder((n,), name='A')
B = tvm.placeholder((n,), name='B')
C = tvm.compute(A.shape, lambda *i: A(*i) > B(*i), name='C')
D = tvm.compute(C.shape, lambda *i: tvm.all(C(*i), A(*i) > 1), name="D")
D = tvm.compute(C.shape, lambda *i: tvm.all(C(*i),
A(*i) > 1).astype('float32'), name="D")
def check_llvm():
......@@ -43,7 +44,7 @@ def test_cmp_load_store():
d = tvm.nd.array(np.zeros(n, dtype=D.dtype), ctx)
f(a, b, d)
np.testing.assert_equal(
d.asnumpy(), np.logical_and(a.asnumpy()> b.asnumpy(), a.asnumpy() > 1))
d.asnumpy(), np.logical_and(a.asnumpy() > b.asnumpy(), a.asnumpy() > 1).astype('float32'))
def check_device(device):
ctx = tvm.context(device, 0)
......@@ -61,7 +62,7 @@ def test_cmp_load_store():
d = tvm.nd.array(np.zeros(n, dtype=D.dtype), ctx)
f(a, b, d)
np.testing.assert_equal(
d.asnumpy(), np.logical_and(a.asnumpy()> b.asnumpy(), a.asnumpy() > 1))
d.asnumpy(), np.logical_and(a.asnumpy() > b.asnumpy(), a.asnumpy() > 1).astype('float32'))
check_llvm()
......
Markdown is supported
0% or
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment