Commit a5def36f by Andrew Tulloch Committed by Tianqi Chen

codegen_spirv support Call::reinterpret (#3795)

parent 61d19ccc
......@@ -283,6 +283,9 @@ spirv::Value CodeGenSPIRV::VisitExpr_(const Call* op) {
} else {
return builder_->MakeValue(spv::OpShiftRightLogical, a.stype, a, b);
}
} else if (op->is_intrinsic(Call::reinterpret)) {
return builder_->MakeValue(spv::OpBitcast, builder_->GetSType(op->type),
MakeValue(op->args[0]));
} else if (op->is_intrinsic(intrinsic::tvm_storage_sync)) {
return this->CreateStorageSync(op);
} else if (op->is_intrinsic(intrinsic::tvm_if_then_else)) {
......
......@@ -6,9 +6,9 @@
* to you under the Apache License, Version 2.0 (the
* "License"); you may not use this file except in compliance
* with the License. You may obtain a copy of the License at
*
*
* http://www.apache.org/licenses/LICENSE-2.0
*
*
* Unless required by applicable law or agreed to in writing,
* software distributed under the License is distributed on an
* "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
......@@ -522,17 +522,17 @@ Value IRBuilder::Cast(const SType& dst_type, spirv::Value value) {
} \
}
#define DEFINE_BUILDER_BINARY_SIGN_OP(_OpName, _Op) \
Value IRBuilder::_OpName(Value a, Value b) { \
CHECK_EQ(a.stype.id, b.stype.id); \
if (a.stype.type.is_int()) { \
return MakeValue(spv::OpS ## _Op, a.stype, a, b); \
} else if (a.stype.type.is_uint()) { \
return MakeValue(spv::OpU ## _Op, a.stype, a, b); \
} else { \
CHECK(a.stype.type.is_float()); \
return MakeValue(spv::OpF ## _Op, a.stype, a, b); \
} \
#define DEFINE_BUILDER_BINARY_SIGN_OP(_OpName, _Op) \
Value IRBuilder::_OpName(Value a, Value b) { \
CHECK_EQ(a.stype.id, b.stype.id); \
if (a.stype.type.is_int()) { \
return MakeValue(spv::OpS##_Op, a.stype, a, b); \
} else if (a.stype.type.is_uint()) { \
return MakeValue(spv::OpU##_Op, a.stype, a, b); \
} else { \
CHECK(a.stype.type.is_float()); \
return MakeValue(spv::OpF##_Op, a.stype, a, b); \
} \
}
DEFINE_BUILDER_BINARY_USIGN_OP(Add, Add);
......@@ -552,21 +552,19 @@ Value IRBuilder::Mod(Value a, Value b) {
}
}
#define DEFINE_BUILDER_CMP_OP(_OpName, _Op) \
Value IRBuilder:: _OpName(Value a, Value b) { \
CHECK_EQ(a.stype.id, b.stype.id); \
if (t_bool_.id == 0) { \
t_bool_ = DeclareType(UInt(1)); \
} \
if (a.stype.type.is_int()) { \
return MakeValue(spv::OpS ## _Op, t_bool_, a, b); \
} else if (a.stype.type.is_uint()) { \
return MakeValue(spv::OpU ## _Op, t_bool_, a, b); \
} else { \
CHECK(a.stype.type.is_float()); \
return MakeValue(spv::OpFOrd ## _Op, t_bool_, a, b); \
} \
#define DEFINE_BUILDER_CMP_OP(_OpName, _Op) \
Value IRBuilder::_OpName(Value a, Value b) { \
CHECK_EQ(a.stype.id, b.stype.id); \
CHECK_EQ(a.stype.type.lanes(), b.stype.type.lanes()); \
const auto& bool_type = this->GetSType(UInt(1).with_lanes(a.stype.type.lanes())); \
if (a.stype.type.is_int()) { \
return MakeValue(spv::OpS##_Op, bool_type, a, b); \
} else if (a.stype.type.is_uint()) { \
return MakeValue(spv::OpU##_Op, bool_type, a, b); \
} else { \
CHECK(a.stype.type.is_float()); \
return MakeValue(spv::OpFOrd##_Op, bool_type, a, b); \
} \
}
DEFINE_BUILDER_CMP_OP(LT, LessThan);
......@@ -574,18 +572,17 @@ DEFINE_BUILDER_CMP_OP(LE, LessThanEqual);
DEFINE_BUILDER_CMP_OP(GT, GreaterThan);
DEFINE_BUILDER_CMP_OP(GE, GreaterThanEqual);
#define DEFINE_BUILDER_CMP_UOP(_OpName, _Op) \
Value IRBuilder:: _OpName(Value a, Value b) { \
CHECK_EQ(a.stype.id, b.stype.id); \
if (t_bool_.id == 0) { \
t_bool_ = DeclareType(UInt(1)); \
} \
if (a.stype.type.is_int() || a.stype.type.is_uint()) { \
return MakeValue(spv::OpI ## _Op, t_bool_, a, b); \
} else { \
CHECK(a.stype.type.is_float()); \
return MakeValue(spv::OpFOrd ## _Op, t_bool_, a, b); \
} \
#define DEFINE_BUILDER_CMP_UOP(_OpName, _Op) \
Value IRBuilder::_OpName(Value a, Value b) { \
CHECK_EQ(a.stype.id, b.stype.id); \
CHECK_EQ(a.stype.type.lanes(), b.stype.type.lanes()); \
const auto& bool_type = this->GetSType(UInt(1).with_lanes(a.stype.type.lanes())); \
if (a.stype.type.is_int() || a.stype.type.is_uint()) { \
return MakeValue(spv::OpI##_Op, bool_type, a, b); \
} else { \
CHECK(a.stype.type.is_float()); \
return MakeValue(spv::OpFOrd##_Op, bool_type, a, b); \
} \
}
DEFINE_BUILDER_CMP_UOP(EQ, Equal);
......@@ -593,7 +590,7 @@ DEFINE_BUILDER_CMP_UOP(NE, NotEqual);
Value IRBuilder::Select(Value cond, Value a, Value b) {
CHECK_EQ(a.stype.id, b.stype.id);
CHECK_EQ(cond.stype.type, UInt(1));
CHECK_EQ(cond.stype.type.element_of(), UInt(1));
return MakeValue(spv::OpSelect, a.stype, cond, a, b);
}
......
# Licensed to the Apache Software Foundation (ASF) under one
# or more contributor license agreements. See the NOTICE file
# distributed with this work for additional information
# regarding copyright ownership. The ASF licenses this file
# to you under the Apache License, Version 2.0 (the
# "License"); you may not use this file except in compliance
# with the License. You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing,
# software distributed under the License is distributed on an
# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
# KIND, either express or implied. See the License for the
# specific language governing permissions and limitations
# under the License.
import tvm
import re
def test_vector_comparison():
if not tvm.module.enabled("vulkan"):
print("Skipping due to no Vulkan module")
return
target = 'vulkan'
def check_correct_assembly(dtype):
n = (1024,)
A = tvm.placeholder(n, dtype=dtype, name='A')
B = tvm.compute(
A.shape,
lambda i: tvm.expr.Select(
A[i] >= 0, A[i] + tvm.const(1, dtype),
tvm.const(0, dtype)), name='B')
s = tvm.create_schedule(B.op)
(bx, tx) = s[B].split(s[B].op.axis[0], factor=128)
(tx, vx) = s[B].split(tx, factor=4)
s[B].bind(bx, tvm.thread_axis("blockIdx.x"))
s[B].bind(tx, tvm.thread_axis("threadIdx.x"))
s[B].vectorize(vx)
f = tvm.build(s, [A, B], target)
# Verify we generate the boolx4 type declaration and the OpSelect
# v4{float,half,int} instruction
assembly = f.imported_modules[0].get_source()
matches = re.findall("%v4bool = OpTypeVector %bool 4", assembly)
assert len(matches) == 1
matches = re.findall("OpSelect %v4.*", assembly)
assert len(matches) == 1
check_correct_assembly('float32')
check_correct_assembly('int32')
check_correct_assembly('float16')
if __name__ == "__main__":
test_vector_comparison()
Markdown is supported
0% or
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment