Commit a5def36f by Andrew Tulloch Committed by Tianqi Chen

codegen_spirv support Call::reinterpret (#3795)

parent 61d19ccc
...@@ -283,6 +283,9 @@ spirv::Value CodeGenSPIRV::VisitExpr_(const Call* op) { ...@@ -283,6 +283,9 @@ spirv::Value CodeGenSPIRV::VisitExpr_(const Call* op) {
} else { } else {
return builder_->MakeValue(spv::OpShiftRightLogical, a.stype, a, b); return builder_->MakeValue(spv::OpShiftRightLogical, a.stype, a, b);
} }
} else if (op->is_intrinsic(Call::reinterpret)) {
return builder_->MakeValue(spv::OpBitcast, builder_->GetSType(op->type),
MakeValue(op->args[0]));
} else if (op->is_intrinsic(intrinsic::tvm_storage_sync)) { } else if (op->is_intrinsic(intrinsic::tvm_storage_sync)) {
return this->CreateStorageSync(op); return this->CreateStorageSync(op);
} else if (op->is_intrinsic(intrinsic::tvm_if_then_else)) { } else if (op->is_intrinsic(intrinsic::tvm_if_then_else)) {
......
...@@ -526,12 +526,12 @@ Value IRBuilder::Cast(const SType& dst_type, spirv::Value value) { ...@@ -526,12 +526,12 @@ Value IRBuilder::Cast(const SType& dst_type, spirv::Value value) {
Value IRBuilder::_OpName(Value a, Value b) { \ Value IRBuilder::_OpName(Value a, Value b) { \
CHECK_EQ(a.stype.id, b.stype.id); \ CHECK_EQ(a.stype.id, b.stype.id); \
if (a.stype.type.is_int()) { \ if (a.stype.type.is_int()) { \
return MakeValue(spv::OpS ## _Op, a.stype, a, b); \ return MakeValue(spv::OpS##_Op, a.stype, a, b); \
} else if (a.stype.type.is_uint()) { \ } else if (a.stype.type.is_uint()) { \
return MakeValue(spv::OpU ## _Op, a.stype, a, b); \ return MakeValue(spv::OpU##_Op, a.stype, a, b); \
} else { \ } else { \
CHECK(a.stype.type.is_float()); \ CHECK(a.stype.type.is_float()); \
return MakeValue(spv::OpF ## _Op, a.stype, a, b); \ return MakeValue(spv::OpF##_Op, a.stype, a, b); \
} \ } \
} }
...@@ -552,20 +552,18 @@ Value IRBuilder::Mod(Value a, Value b) { ...@@ -552,20 +552,18 @@ Value IRBuilder::Mod(Value a, Value b) {
} }
} }
#define DEFINE_BUILDER_CMP_OP(_OpName, _Op) \ #define DEFINE_BUILDER_CMP_OP(_OpName, _Op) \
Value IRBuilder:: _OpName(Value a, Value b) { \ Value IRBuilder::_OpName(Value a, Value b) { \
CHECK_EQ(a.stype.id, b.stype.id); \ CHECK_EQ(a.stype.id, b.stype.id); \
if (t_bool_.id == 0) { \ CHECK_EQ(a.stype.type.lanes(), b.stype.type.lanes()); \
t_bool_ = DeclareType(UInt(1)); \ const auto& bool_type = this->GetSType(UInt(1).with_lanes(a.stype.type.lanes())); \
} \
if (a.stype.type.is_int()) { \ if (a.stype.type.is_int()) { \
return MakeValue(spv::OpS ## _Op, t_bool_, a, b); \ return MakeValue(spv::OpS##_Op, bool_type, a, b); \
} else if (a.stype.type.is_uint()) { \ } else if (a.stype.type.is_uint()) { \
return MakeValue(spv::OpU ## _Op, t_bool_, a, b); \ return MakeValue(spv::OpU##_Op, bool_type, a, b); \
} else { \ } else { \
CHECK(a.stype.type.is_float()); \ CHECK(a.stype.type.is_float()); \
return MakeValue(spv::OpFOrd ## _Op, t_bool_, a, b); \ return MakeValue(spv::OpFOrd##_Op, bool_type, a, b); \
} \ } \
} }
...@@ -575,16 +573,15 @@ DEFINE_BUILDER_CMP_OP(GT, GreaterThan); ...@@ -575,16 +573,15 @@ DEFINE_BUILDER_CMP_OP(GT, GreaterThan);
DEFINE_BUILDER_CMP_OP(GE, GreaterThanEqual); DEFINE_BUILDER_CMP_OP(GE, GreaterThanEqual);
#define DEFINE_BUILDER_CMP_UOP(_OpName, _Op) \ #define DEFINE_BUILDER_CMP_UOP(_OpName, _Op) \
Value IRBuilder:: _OpName(Value a, Value b) { \ Value IRBuilder::_OpName(Value a, Value b) { \
CHECK_EQ(a.stype.id, b.stype.id); \ CHECK_EQ(a.stype.id, b.stype.id); \
if (t_bool_.id == 0) { \ CHECK_EQ(a.stype.type.lanes(), b.stype.type.lanes()); \
t_bool_ = DeclareType(UInt(1)); \ const auto& bool_type = this->GetSType(UInt(1).with_lanes(a.stype.type.lanes())); \
} \
if (a.stype.type.is_int() || a.stype.type.is_uint()) { \ if (a.stype.type.is_int() || a.stype.type.is_uint()) { \
return MakeValue(spv::OpI ## _Op, t_bool_, a, b); \ return MakeValue(spv::OpI##_Op, bool_type, a, b); \
} else { \ } else { \
CHECK(a.stype.type.is_float()); \ CHECK(a.stype.type.is_float()); \
return MakeValue(spv::OpFOrd ## _Op, t_bool_, a, b); \ return MakeValue(spv::OpFOrd##_Op, bool_type, a, b); \
} \ } \
} }
...@@ -593,7 +590,7 @@ DEFINE_BUILDER_CMP_UOP(NE, NotEqual); ...@@ -593,7 +590,7 @@ DEFINE_BUILDER_CMP_UOP(NE, NotEqual);
Value IRBuilder::Select(Value cond, Value a, Value b) { Value IRBuilder::Select(Value cond, Value a, Value b) {
CHECK_EQ(a.stype.id, b.stype.id); CHECK_EQ(a.stype.id, b.stype.id);
CHECK_EQ(cond.stype.type, UInt(1)); CHECK_EQ(cond.stype.type.element_of(), UInt(1));
return MakeValue(spv::OpSelect, a.stype, cond, a, b); return MakeValue(spv::OpSelect, a.stype, cond, a, b);
} }
......
# Licensed to the Apache Software Foundation (ASF) under one
# or more contributor license agreements. See the NOTICE file
# distributed with this work for additional information
# regarding copyright ownership. The ASF licenses this file
# to you under the Apache License, Version 2.0 (the
# "License"); you may not use this file except in compliance
# with the License. You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing,
# software distributed under the License is distributed on an
# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
# KIND, either express or implied. See the License for the
# specific language governing permissions and limitations
# under the License.
import tvm
import re
def test_vector_comparison():
if not tvm.module.enabled("vulkan"):
print("Skipping due to no Vulkan module")
return
target = 'vulkan'
def check_correct_assembly(dtype):
n = (1024,)
A = tvm.placeholder(n, dtype=dtype, name='A')
B = tvm.compute(
A.shape,
lambda i: tvm.expr.Select(
A[i] >= 0, A[i] + tvm.const(1, dtype),
tvm.const(0, dtype)), name='B')
s = tvm.create_schedule(B.op)
(bx, tx) = s[B].split(s[B].op.axis[0], factor=128)
(tx, vx) = s[B].split(tx, factor=4)
s[B].bind(bx, tvm.thread_axis("blockIdx.x"))
s[B].bind(tx, tvm.thread_axis("threadIdx.x"))
s[B].vectorize(vx)
f = tvm.build(s, [A, B], target)
# Verify we generate the boolx4 type declaration and the OpSelect
# v4{float,half,int} instruction
assembly = f.imported_modules[0].get_source()
matches = re.findall("%v4bool = OpTypeVector %bool 4", assembly)
assert len(matches) == 1
matches = re.findall("OpSelect %v4.*", assembly)
assert len(matches) == 1
check_correct_assembly('float32')
check_correct_assembly('int32')
check_correct_assembly('float16')
if __name__ == "__main__":
test_vector_comparison()
Markdown is supported
0% or
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment