Unverified Commit c3b89b76 by MORITA Kazutaka Committed by GitHub

[CODEGEN][OPENCL] Explicitly cast min/max operands (#5090)

* [CODEGEN][OPENCL] Explicitly cast min/max operands

* retrigger CI
parent 7ca3212f
......@@ -226,26 +226,6 @@ void CodeGenOpenCL::VisitExpr_(const BroadcastNode* op, std::ostream& os) { //
os << "))";
}
void CodeGenOpenCL::VisitExpr_(const CallNode *op, std::ostream& os) { // NOLINT(*)
/* Return type of ternary expression is not always same as its sub-expressions,
* add a cast */
if (op->is_intrinsic(intrinsic::tvm_if_then_else)) {
os << "(";
PrintType(op->args[2].dtype(), os);
os << ")";
}
CodeGenC::VisitExpr_(op, os);
}
void CodeGenOpenCL::VisitExpr_(const SelectNode* op, std::ostream& os) { // NOLINT(*)
/* Return type of ternary expression is not always same as its sub-expressions,
* add a cast */
os << "(";
PrintType(op->true_value.dtype(), os);
os << ")";
CodeGenC::VisitExpr_(op, os);
}
void CodeGenOpenCL::VisitExpr_(const FloatImmNode *op, std::ostream& os) { // NOLINT(*)
if (std::isinf(op->value)) {
if (op->value < 0) {
......@@ -259,6 +239,34 @@ void CodeGenOpenCL::VisitExpr_(const FloatImmNode *op, std::ostream& os) { // NO
}
}
/*
 * Emit a call to the binary builtin `opstr` for the node `op`.
 *
 * Vector expressions (more than one lane) are handed to PrintVecBinaryOp.
 * Scalar expressions are printed as
 *   opstr((<type of a>)a, (<type of b>)b)
 * i.e. each operand is explicitly cast to its own dtype.
 *
 * op    - node providing `dtype`, `a` and `b`
 * opstr - name of the builtin to emit
 * os    - stream receiving the generated text
 * p     - code generator used to print types and sub-expressions
 */
template<typename T>
inline void PrintBinaryExpr(const T* op,
                            const char* opstr,
                            std::ostream& os,
                            CodeGenOpenCL* p) {
  // Vector case first: nothing is cast here.
  if (op->dtype.lanes() != 1) {
    p->PrintVecBinaryOp(opstr, op->dtype, op->a, op->b, os);
    return;
  }

  // Scalar case: builtin name, then the first operand behind its cast.
  os << opstr;
  os << "((";
  p->PrintType(op->a->dtype, os);
  os << ")";
  p->PrintExpr(op->a, os);

  // Second operand behind its cast, then close the call.
  os << ", (";
  p->PrintType(op->b->dtype, os);
  os << ")";
  p->PrintExpr(op->b, os);
  os << ')';
}
// Min is printed through PrintBinaryExpr so that scalar operands carry an
// explicit cast, avoiding ambiguous calls to the overloaded builtin.
void CodeGenOpenCL::VisitExpr_(const MinNode *op, std::ostream& os) {
  const char* const builtin_name = "min";
  PrintBinaryExpr(op, builtin_name, os, this);
}
// Max is printed through PrintBinaryExpr so that scalar operands carry an
// explicit cast, avoiding ambiguous calls to the overloaded builtin.
void CodeGenOpenCL::VisitExpr_(const MaxNode *op, std::ostream& os) {
  const char* const builtin_name = "max";
  PrintBinaryExpr(op, builtin_name, os, this);
}
runtime::Module BuildOpenCL(Array<LoweredFunc> funcs) {
using tvm::runtime::Registry;
bool output_ssa = false;
......
......@@ -55,9 +55,10 @@ class CodeGenOpenCL final : public CodeGenC {
// overload visitor
// Expression visitors specialised for OpenCL; each one writes the generated
// source text for `op` to `os`.
void VisitExpr_(const BroadcastNode* op, std::ostream& os) final; // NOLINT(*)
// Prefixes the tvm_if_then_else intrinsic with a cast to the dtype of its
// third argument, then defers to CodeGenC.
void VisitExpr_(const CallNode* op, std::ostream& os) final; // NOLINT(*)
// Prefixes the select expression with a cast to the dtype of its true value,
// then defers to CodeGenC.
void VisitExpr_(const SelectNode* op, std::ostream& os) final; // NOLINT(*)
// NOTE(review): the implementation special-cases infinite values; the rest of
// its behaviour is not visible from this declaration.
void VisitExpr_(const FloatImmNode *op, std::ostream& os) final; // NOLINT(*)
// overload min and max to avoid ambiguous call errors
// (scalar operands are emitted with an explicit cast to their own dtype)
void VisitExpr_(const MinNode *op, std::ostream& os) final;
void VisitExpr_(const MaxNode *op, std::ostream& os) final;
private:
// whether enable fp16 and fp64 extension
......
......@@ -94,6 +94,35 @@ def test_opencl_inf_nan():
check_inf_nan(ctx, 1, float('nan'), 'float64')
def test_opencl_max():
    """Build and launch an OpenCL kernel computing max(A[0] + 1, 0) per dtype.

    The kernel output is never inspected; the test only needs the generated
    code to build and launch. Prints a message and returns early when the
    target runtime is not enabled.
    """
    def check_max(ctx, n, dtype):
        # max() over a non-constant left operand and a constant right operand.
        A = te.placeholder((n,), name='A', dtype=dtype)
        lhs = A[0] + tvm.tir.const(1, dtype=dtype)
        rhs = tvm.tir.const(0, dtype=dtype)
        C = te.compute((n,), lambda i: tvm.te.max(lhs, rhs), name='C')

        sched = te.create_schedule(C.op)
        sched[C].bind(sched[C].op.axis[0], te.thread_axis("threadIdx.x"))
        kernel = tvm.build(sched, [A, C], target)

        in_buf = tvm.nd.empty((n,), A.dtype, ctx)
        out_buf = tvm.nd.empty((n,), A.dtype, ctx)
        # Only need to test compiling here
        kernel(in_buf, out_buf)

    if not tvm.runtime.enabled(target):
        print("skip because opencl is not enabled..")
        return

    ctx = tvm.context(target, 0)
    for dtype in ('int8', 'uint8', 'int16', 'uint16', 'float32', 'float64'):
        check_max(ctx, 1, dtype)
if __name__ == "__main__":
    # Script entry point: run the OpenCL codegen tests directly.
    # test_opencl_max is included so the min/max cast check also runs when
    # this file is executed as a script rather than through a test runner.
    test_opencl_ternary_expression()
    test_opencl_inf_nan()
    test_opencl_max()
Markdown is supported
0% or
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or sign in to comment