Commit fa709832 by lixiaoquan Committed by MORITA Kazutaka

[CODEGEN][OPENCL] Fix compile error about ternary expression. (#2821)

Code like this can't be built with NV OpenCL, and it needs an explicit type
  converison for ternary expression if return type is uchar.

       uchar i = 0, j = 0;
       uchar t = max((uchar)j, ((i > 0) ? (uchar)1 : (uchar)0));
parent 0f6989f9
......@@ -208,6 +208,25 @@ void CodeGenOpenCL::VisitExpr_(const Broadcast* op, std::ostream& os) { // NOL
os << "))";
}
void CodeGenOpenCL::VisitExpr_(const Call *op, std::ostream& os) { // NOLINT(*)
/* Return type of ternary expression is not always same as its sub-expressions,
* add a cast */
if (op->is_intrinsic(intrinsic::tvm_if_then_else)) {
os << "(";
PrintType(op->args[2].type(), os);
os << ")";
}
CodeGenC::VisitExpr_(op, os);
}
void CodeGenOpenCL::VisitExpr_(const Select* op, std::ostream& os) { // NOLINT(*)
/* Return type of ternary expression is not always same as its sub-expressions,
* add a cast */
os << "(";
PrintType(op->true_value.type(), os);
os << ")";
CodeGenC::VisitExpr_(op, os);
}
runtime::Module BuildOpenCL(Array<LoweredFunc> funcs) {
using tvm::runtime::Registry;
......
......@@ -38,6 +38,8 @@ class CodeGenOpenCL final : public CodeGenC {
// overload visitor
void VisitExpr_(const Broadcast* op, std::ostream& os) final; // NOLINT(*)
void VisitExpr_(const Call* op, std::ostream& os) final; // NOLINT(*)
void VisitExpr_(const Select* op, std::ostream& os) final; // NOLINT(*)
private:
// whether enable fp16 and fp64 extension
......
import tvm
target = 'opencl'
def test_opencl_ternary_expression():
def check_if_then_else(ctx, n, dtype):
A = tvm.placeholder((n,), name='A', dtype=dtype)
true_value = tvm.const(1, dtype=dtype)
false_value = tvm.const(3, dtype=dtype)
max_lhs = tvm.const(2, dtype=dtype)
max_rhs = tvm.if_then_else(A[0] > 0, true_value, false_value)
C = tvm.compute((n,), lambda i: tvm.max(max_lhs, max_rhs), name='C')
s = tvm.create_schedule(C.op)
s[C].bind(s[C].op.axis[0], tvm.thread_axis("threadIdx.x"))
fun = tvm.build(s, [A, C], target)
a = tvm.nd.empty((n,), A.dtype, ctx)
c = tvm.nd.empty((n,), A.dtype, ctx)
# Only need to test compiling here
fun(a, c)
def check_select(ctx, n, dtype):
A = tvm.placeholder((n,), name='A', dtype=dtype)
true_value = tvm.const(1, dtype=dtype)
false_value = tvm.const(3, dtype=dtype)
max_lhs = tvm.const(2, dtype=dtype)
max_rhs = tvm.expr.Select(A[0] > 0, true_value, false_value)
C = tvm.compute((n,), lambda i: tvm.max(max_lhs, max_rhs), name='C')
s = tvm.create_schedule(C.op)
s[C].bind(s[C].op.axis[0], tvm.thread_axis("threadIdx.x"))
fun = tvm.build(s, [A, C], target)
a = tvm.nd.empty((n,), A.dtype, ctx)
c = tvm.nd.empty((n,), A.dtype, ctx)
# Only need to test compiling here
fun(a, c)
if not tvm.module.enabled(target):
print("skip because opencl is not enabled..")
return
ctx = tvm.context(target, 0)
check_if_then_else(ctx, 1, 'int8')
check_if_then_else(ctx, 1, 'uint8')
check_if_then_else(ctx, 1, 'int16')
check_if_then_else(ctx, 1, 'uint16')
check_select(ctx, 1, 'int8')
check_select(ctx, 1, 'uint8')
check_select(ctx, 1, 'int16')
check_select(ctx, 1, 'uint16')
if __name__ == "__main__":
test_opencl_ternary_expression()
Markdown is supported
0% or
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment