/* * Licensed to the Apache Software Foundation (ASF) under one * or more contributor license agreements. See the NOTICE file * distributed with this work for additional information * regarding copyright ownership. The ASF licenses this file * to you under the Apache License, Version 2.0 (the * "License"); you may not use this file except in compliance * with the License. You may obtain a copy of the License at * * http://www.apache.org/licenses/LICENSE-2.0 * * Unless required by applicable law or agreed to in writing, * software distributed under the License is distributed on an * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY * KIND, either express or implied. See the License for the * specific language governing permissions and limitations * under the License. */ /*! * \file expr_operator.cc */ #include <tvm/tir/expr.h> #include <tvm/tir/op.h> #include <cmath> // Centralized header for constant folders. #include "../../arith/const_fold.h" namespace tvm { using namespace tir; // simple cast that only checks if type matches and cast inline PrimExpr SimpleCast(const DataType& t, PrimExpr value) { if (value.dtype() == t) return value; return tir::CastNode::make(t, value); } PrimExpr LargeUIntImm(DataType t, int64_t low, int64_t high) { return tir::CallNode::make( t, tir::intrinsic::tvm_large_uint_imm, {make_const(DataType::UInt(32), low), make_const(DataType::UInt(32), high)}, tir::CallNode::PureIntrinsic); } // The public function with a quick checking path. void BinaryOpMatchTypes(PrimExpr& lhs, PrimExpr& rhs) { // NOLINT(*) if (lhs.dtype() == rhs.dtype()) return; DataType ltype = lhs.dtype(); DataType rtype = rhs.dtype(); if (ltype.lanes() == 1 && rtype.lanes() != 1) { lhs = tir::BroadcastNode::make(lhs, rtype.lanes()); } else if (rtype.lanes() == 1 && ltype.lanes() != 1) { rhs = tir::BroadcastNode::make(rhs, ltype.lanes()); } else { CHECK(ltype.lanes() == rtype.lanes()) << "Cannot match type " << ltype << " vs " << rtype; } if (lhs.dtype() == rhs.dtype()) return; // Only do very simple type coversion // int->float, DataType::Int(32)->int(64) // require the types to be relatively consistent // This will the reduce amount code generated by operators // and also help user to find potential type conversion problems. if (!lhs.dtype().is_float() && rhs.dtype().is_float()) { // int->float lhs = cast(rhs.dtype(), lhs); } else if (lhs.dtype().is_float() && !rhs.dtype().is_float()) { // int->float rhs = cast(lhs.dtype(), rhs); } else if ((lhs.dtype().is_int() && rhs.dtype().is_int()) || (lhs.dtype().is_uint() && rhs.dtype().is_uint())) { // promote int to higher bits if (lhs.dtype().bits() < rhs.dtype().bits()) { lhs = cast(rhs.dtype(), lhs); } else { rhs = cast(lhs.dtype(), rhs); } } else if ((lhs.dtype().is_int() && rhs.dtype().is_uint()) || (lhs.dtype().is_uint() && rhs.dtype().is_int())) { int bits = std::max(lhs.dtype().bits(), rhs.dtype().bits()); lhs = SimpleCast(DataType::Int(bits, lhs.dtype().lanes()), lhs); rhs = SimpleCast(DataType::Int(bits, rhs.dtype().lanes()), rhs); } else { LOG(FATAL) << "Cannot match type " << ltype << " vs " << rtype; } } // maximum and min limits PrimExpr max_value(const DataType& dtype) { using namespace tir; CHECK_EQ(dtype.lanes(), 1); if (dtype.is_int()) { if (dtype.bits() == 64) { return IntImm(dtype, std::numeric_limits<int64_t>::max()); } else if (dtype.bits() < 64) { int64_t val = 1; val = (val << (dtype.bits() - 1)) - 1; return IntImm(dtype, val); } } else if (dtype.is_uint()) { if (dtype.bits() == 64) { return make_const(dtype, std::numeric_limits<uint64_t>::max()); } else if (dtype.bits() < 64) { uint64_t val = 1; val = (val << static_cast<uint64_t>(dtype.bits())) - 1; return IntImm(dtype, static_cast<int64_t>(val)); } } else if (dtype.is_float()) { if (dtype.bits() == 64) { return FloatImm(dtype, std::numeric_limits<double>::max()); } else if (dtype.bits() == 32) { return FloatImm(dtype, std::numeric_limits<float>::max()); } else if (dtype.bits() == 16) { return FloatImm(dtype, 65504.0); } } LOG(FATAL) << "Cannot decide max_value for type" << dtype; return PrimExpr(); } PrimExpr min_value(const DataType& dtype) { using namespace tir; CHECK_EQ(dtype.lanes(), 1); if (dtype.is_int()) { if (dtype.bits() == 64) { return IntImm(dtype, std::numeric_limits<int64_t>::lowest()); } else if (dtype.bits() < 64) { int64_t val = 1; val = -(val << (dtype.bits() - 1)); return IntImm(dtype, val); } } else if (dtype.is_uint()) { return IntImm(dtype, 0); } else if (dtype.is_float()) { if (dtype.bits() == 64) { return FloatImm(dtype, std::numeric_limits<double>::lowest()); } else if (dtype.bits() == 32) { return FloatImm(dtype, std::numeric_limits<float>::lowest()); } else if (dtype.bits() == 16) { return FloatImm(dtype, -65504.0); } } LOG(FATAL) << "Cannot decide min_value for type" << dtype; return PrimExpr(); } namespace tir { template<typename ValueType> inline bool ConstPowerHelper(ValueType val, int *shift) { if (val <= 0) return false; shift[0] = 0; while (val != 0) { if (val & 1) { return (val == 1); } ++shift[0]; val = val >> 1; } return true; } bool is_const_power_of_two_integer(const PrimExpr& x, int* shift) { if (const auto* op = x.as<tir::IntImmNode>()) { return ConstPowerHelper(op->value, shift); } else { return false; } } } // namespace tir PrimExpr cast(const DataType& t, PrimExpr value) { using tir::FloatImmNode; if (value.dtype() == t) return value; // const fold IntImm as they are used in index computations if (t.lanes() == 1) { if (const IntImmNode* op = value.as<IntImmNode>()) { return make_const(t, op->value); } else if (const FloatImmNode* op = value.as<FloatImmNode>()) { return make_const(t, op->value); } return tir::CastNode::make(t, value); } else { if (value.dtype().lanes() == 1) { // manually unroll cast DataType vtype = t.element_of(); if (value.dtype() != vtype) { if (const IntImmNode* op = value.as<IntImmNode>()) { value = make_const(vtype, op->value); } else if (const FloatImmNode* op = value.as<FloatImmNode>()) { value = make_const(vtype, op->value); } else { value = tir::CastNode::make(vtype, value); } } return tir::BroadcastNode::make(value, t.lanes()); } else { CHECK(value.dtype().lanes() == t.lanes()); return tir::CastNode::make(t, value); } } } PrimExpr reinterpret(const DataType& t, PrimExpr value) { if (value.dtype() == t) return value; return tir::CallNode::make( t, tir::CallNode::reinterpret, { value }, tir::CallNode::PureIntrinsic); } PrimExpr operator+(PrimExpr a, PrimExpr b) { BinaryOpMatchTypes(a, b); PrimExpr ret = arith::TryConstFold<tir::AddNode>(a, b); if (ret.defined()) return ret; return tir::AddNode::make(a, b); } // negation PrimExpr operator-(PrimExpr a) { using tir::IntImmNode; using tir::FloatImmNode; const IntImmNode* pa = a.as<IntImmNode>(); const FloatImmNode* fa = a.as<FloatImmNode>(); if (pa) return IntImm(a.dtype(), -pa->value); if (fa) return FloatImm(a.dtype(), -fa->value); return make_zero(a.dtype()) - a; } PrimExpr operator-(PrimExpr a, PrimExpr b) { BinaryOpMatchTypes(a, b); PrimExpr ret = arith::TryConstFold<tir::SubNode>(a, b); if (ret.defined()) return ret; return tir::SubNode::make(a, b); } PrimExpr operator*(PrimExpr a, PrimExpr b) { BinaryOpMatchTypes(a, b); PrimExpr ret = arith::TryConstFold<tir::MulNode>(a, b); if (ret.defined()) return ret; return tir::MulNode::make(a, b); } PrimExpr div(PrimExpr a, PrimExpr b) { BinaryOpMatchTypes(a, b); PrimExpr ret = arith::TryConstFold<tir::DivNode>(a, b); if (ret.defined()) return ret; return tir::DivNode::make(a, b); } PrimExpr truncdiv(PrimExpr a, PrimExpr b) { CHECK(a.dtype().is_int() || a.dtype().is_uint()) << a; CHECK(b.dtype().is_int() || b.dtype().is_uint()) << b; return div(a, b); } PrimExpr truncmod(PrimExpr a, PrimExpr b) { BinaryOpMatchTypes(a, b); PrimExpr ret = arith::TryConstFold<tir::ModNode>(a, b); if (ret.defined()) return ret; return tir::ModNode::make(a, b); } PrimExpr operator/(PrimExpr a, PrimExpr b) { return div(a, b); } PrimExpr operator%(PrimExpr a, PrimExpr b) { return truncmod(a, b); } // TODO(tqchen): switch to floordiv PrimExpr indexdiv(PrimExpr a, PrimExpr b) { return floordiv(a, b); } PrimExpr indexmod(PrimExpr a, PrimExpr b) { return floormod(a, b); } PrimExpr floordiv(PrimExpr a, PrimExpr b) { CHECK(a.dtype().is_int() || a.dtype().is_uint()) << a; CHECK(b.dtype().is_int() || b.dtype().is_uint()) << b; BinaryOpMatchTypes(a, b); PrimExpr ret = arith::TryConstFold<tir::FloorDivNode>(a, b); if (ret.defined()) return ret; return tir::FloorDivNode::make(a, b); } PrimExpr floormod(PrimExpr a, PrimExpr b) { CHECK(a.dtype().is_int() || a.dtype().is_uint()) << a; CHECK(b.dtype().is_int() || b.dtype().is_uint()) << b; BinaryOpMatchTypes(a, b); PrimExpr ret = arith::TryConstFold<tir::FloorModNode>(a, b); if (ret.defined()) return ret; return tir::FloorModNode::make(a, b); } PrimExpr min(PrimExpr a, PrimExpr b) { // inf-aware simplificaiton using arith::is_pos_inf; using arith::is_neg_inf; if (is_pos_inf(a)) return b; if (is_neg_inf(a)) return a; if (is_pos_inf(b)) return a; if (is_neg_inf(b)) return b; BinaryOpMatchTypes(a, b); PrimExpr ret = arith::TryConstFold<tir::MinNode>(a, b); if (ret.defined()) return ret; return tir::MinNode::make(a, b); } PrimExpr max(PrimExpr a, PrimExpr b) { // inf-aware simplificaiton using arith::is_pos_inf; using arith::is_neg_inf; if (is_pos_inf(a)) return a; if (is_neg_inf(a)) return b; if (is_pos_inf(b)) return b; if (is_neg_inf(b)) return a; BinaryOpMatchTypes(a, b); PrimExpr ret = arith::TryConstFold<tir::MaxNode>(a, b); if (ret.defined()) return ret; return tir::MaxNode::make(a, b); } PrimExpr if_then_else(PrimExpr cond, PrimExpr true_value, PrimExpr false_value) { CHECK(cond.dtype() == DataType::Bool(1)) << "if_then_else only accept the condition to be boolean type."; BinaryOpMatchTypes(true_value, false_value); if (const IntImmNode* op = cond.as<IntImmNode>()) { if (op->value != 0) { return true_value; } else { return false_value; } } return tir::CallNode::make( true_value.dtype(), tir::intrinsic::tvm_if_then_else, {cond, true_value, false_value}, tir::CallNode::PureIntrinsic); } PrimExpr likely(PrimExpr cond) { if (is_const(cond)) return cond; return tir::CallNode::make(cond.dtype(), tir::CallNode::likely, { cond }, tir::CallNode::PureIntrinsic); } PrimExpr operator>(PrimExpr a, PrimExpr b) { BinaryOpMatchTypes(a, b); PrimExpr ret = arith::TryConstFold<tir::GTNode>(a, b); if (ret.defined()) return ret; return tir::GTNode::make(a, b); } PrimExpr operator>=(PrimExpr a, PrimExpr b) { BinaryOpMatchTypes(a, b); PrimExpr ret = arith::TryConstFold<tir::GENode>(a, b); if (ret.defined()) return ret; return tir::GENode::make(a, b); } PrimExpr operator<(PrimExpr a, PrimExpr b) { BinaryOpMatchTypes(a, b); PrimExpr ret = arith::TryConstFold<tir::LTNode>(a, b); if (ret.defined()) return ret; return tir::LTNode::make(a, b); } PrimExpr operator<=(PrimExpr a, PrimExpr b) { BinaryOpMatchTypes(a, b); PrimExpr ret = arith::TryConstFold<tir::LENode>(a, b); if (ret.defined()) return ret; return tir::LENode::make(a, b); } PrimExpr operator==(PrimExpr a, PrimExpr b) { BinaryOpMatchTypes(a, b); PrimExpr ret = arith::TryConstFold<tir::EQNode>(a, b); if (ret.defined()) return ret; return tir::EQNode::make(a, b); } PrimExpr operator!=(PrimExpr a, PrimExpr b) { BinaryOpMatchTypes(a, b); PrimExpr ret = arith::TryConstFold<tir::NENode>(a, b); if (ret.defined()) return ret; return tir::NENode::make(a, b); } PrimExpr operator&&(PrimExpr a, PrimExpr b) { CHECK(a.dtype().is_bool()); CHECK(b.dtype().is_bool()); PrimExpr ret = arith::TryConstFold<tir::AndNode>(a, b); if (ret.defined()) return ret; return tir::AndNode::make(a, b); } PrimExpr operator||(PrimExpr a, PrimExpr b) { CHECK(a.dtype().is_bool()); CHECK(b.dtype().is_bool()); PrimExpr ret = arith::TryConstFold<tir::OrNode>(a, b); if (ret.defined()) return ret; return tir::OrNode::make(a, b); } PrimExpr operator!(PrimExpr a) { CHECK(a.dtype().is_bool()); PrimExpr ret = arith::TryConstFold<tir::NotNode>(a); if (ret.defined()) return ret; return tir::NotNode::make(a); } PrimExpr operator>>(PrimExpr a, PrimExpr b) { BinaryOpMatchTypes(a, b); TVM_INDEX_CONST_PROPAGATION({ const DataType& rtype = a.dtype(); if (pa && pb) return IntImm(rtype, (pa->value >> pb->value)); if (pb) { if (pb->value == 0) return a; } }); return tir::CallNode::make( a.dtype(), tir::CallNode::shift_right, { a, b }, tir::CallNode::PureIntrinsic); } PrimExpr operator<<(PrimExpr a, PrimExpr b) { BinaryOpMatchTypes(a, b); TVM_INDEX_CONST_PROPAGATION({ const DataType& rtype = a.dtype(); if (pa && pb) return IntImm(rtype, (pa->value << pb->value)); if (pb) { if (pb->value == 0) return a; } }); return tir::CallNode::make( a.dtype(), tir::CallNode::shift_left, { a, b }, tir::CallNode::PureIntrinsic); } PrimExpr operator&(PrimExpr a, PrimExpr b) { BinaryOpMatchTypes(a, b); TVM_INDEX_CONST_PROPAGATION({ const DataType& rtype = a.dtype(); if (pa && pb) return IntImm(rtype, (pa->value & pb->value)); }); return tir::CallNode::make( a.dtype(), tir::CallNode::bitwise_and, { a, b }, tir::CallNode::PureIntrinsic); } PrimExpr operator|(PrimExpr a, PrimExpr b) { BinaryOpMatchTypes(a, b); TVM_INDEX_CONST_PROPAGATION({ const DataType& rtype = a.dtype(); if (pa && pb) return IntImm(rtype, (pa->value | pb->value)); }); return tir::CallNode::make( a.dtype(), tir::CallNode::bitwise_or, { a, b }, tir::CallNode::PureIntrinsic); } PrimExpr operator^(PrimExpr a, PrimExpr b) { BinaryOpMatchTypes(a, b); TVM_INDEX_CONST_PROPAGATION({ const DataType& rtype = a.dtype(); if (pa && pb) return IntImm(rtype, (pa->value ^ pb->value)); }); return tir::CallNode::make( a.dtype(), tir::CallNode::bitwise_xor, { a, b }, tir::CallNode::PureIntrinsic); } PrimExpr operator~(PrimExpr a) { CHECK(a.dtype().is_int() || a.dtype().is_uint()); return tir::CallNode::make( a.dtype(), tir::CallNode::bitwise_not, { a }, tir::CallNode::PureIntrinsic); } PrimExpr pow(PrimExpr x, PrimExpr y) { BinaryOpMatchTypes(x, y); CHECK(x.dtype().is_float()) << "power only applies to float"; return tir::CallNode::make( x.dtype(), "pow", { x, y }, tir::CallNode::PureIntrinsic); } PrimExpr abs(PrimExpr x) { if (x.dtype().is_int()) { using tir::IntImmNode; const IntImmNode* px = x.as<IntImmNode>(); if (px) { return IntImm(x.dtype(), std::abs(px->value)); } return tir::SelectNode::make(x >= make_zero(x.dtype()), x, -x); } else if (x.dtype().is_float()) { using tir::FloatImmNode; const FloatImmNode* fx = x.as<FloatImmNode>(); if (fx) { return FloatImm(x.dtype(), std::fabs(fx->value)); } return tir::CallNode::make(x.dtype(), "fabs", {x}, tir::CallNode::PureIntrinsic); } else if (x.dtype().is_uint()) { return x; } else { LOG(FATAL) << "Data type " << x.dtype() <<" not supported for absolute op. Skipping absolute op..."; return x; } } PrimExpr isnan(PrimExpr x) { DataType t = DataType::Bool(x.dtype().lanes()); if (x.dtype().is_int() || x.dtype().is_uint()) { return make_const(t, false); } else if (x.dtype().is_float()) { using tir::FloatImmNode; const FloatImmNode* fx = x.as<FloatImmNode>(); if (fx) { return make_const(t, std::isnan(fx->value)); } if (x.dtype().bits() == 16) { return tir::CallNode::make(t, tir::CallNode::isnan, {cast(DataType::Float(32, t.lanes()), std::move(x))}, tir::CallNode::PureIntrinsic); } else { return tir::CallNode::make(t, tir::CallNode::isnan, {x}, tir::CallNode::PureIntrinsic); } } else { LOG(FATAL) << "Data type " << x.dtype() <<" not supported for isnan op. Skipping isnan op..."; return x; } } PrimExpr sum(PrimExpr source, Array<IterVar> rdom) { Var x("x", source.dtype()), y("y", source.dtype()); PrimExpr result = tir::AddNode::make(x, y); PrimExpr identity_element = make_zero(source.dtype()); tir::CommReducer combiner = tir::CommReducerNode::make({x}, {y}, {result}, {identity_element}); return tir::ReduceNode::make(combiner, {source}, rdom, make_const(DataType::Bool(1), true), 0); } PrimExpr all(PrimExpr source, Array<IterVar> rdom) { CHECK(source.dtype().is_bool()); Var x("x", source.dtype()), y("y", source.dtype()); PrimExpr result = tir::AndNode::make(x, y); PrimExpr identity_element = make_const(source.dtype(), true); tir::CommReducer combiner = tir::CommReducerNode::make({x}, {y}, {result}, {identity_element}); return tir::ReduceNode::make(combiner, {source}, rdom, make_const(DataType::Bool(1), true), 0); } PrimExpr any(PrimExpr source, Array<IterVar> rdom) { CHECK(source.dtype().is_bool()); Var x("x", source.dtype()), y("y", source.dtype()); PrimExpr result = tir::OrNode::make(x, y); PrimExpr identity_element = make_const(source.dtype(), false); tir::CommReducer combiner = tir::CommReducerNode::make({x}, {y}, {result}, {identity_element}); return tir::ReduceNode::make(combiner, {source}, rdom, make_const(DataType::Bool(1), true), 0); } PrimExpr max(PrimExpr source, Array<IterVar> rdom) { Var x("x", source.dtype()), y("y", source.dtype()); PrimExpr result = tir::MaxNode::make(x, y); PrimExpr identity_element = min_value(source.dtype()); tir::CommReducer combiner = tir::CommReducerNode::make({x}, {y}, {result}, {identity_element}); return tir::ReduceNode::make(combiner, {source}, rdom, make_const(DataType::Bool(1), true), 0); } PrimExpr min(PrimExpr source, Array<IterVar> rdom) { Var x("x", source.dtype()), y("y", source.dtype()); PrimExpr result = tir::MinNode::make(x, y); PrimExpr identity_element = max_value(source.dtype()); tir::CommReducer combiner = tir::CommReducerNode::make({x}, {y}, {result}, {identity_element}); return tir::ReduceNode::make(combiner, {source}, rdom, make_const(DataType::Bool(1), true), 0); } PrimExpr prod(PrimExpr source, Array<IterVar> rdom) { Var x("x", source.dtype()), y("y", source.dtype()); PrimExpr result = tir::MulNode::make(x, y); PrimExpr identity_element = make_const(source.dtype(), 1); tir::CommReducer combiner = tir::CommReducerNode::make({x}, {y}, {result}, {identity_element}); return tir::ReduceNode::make(combiner, {source}, rdom, make_const(DataType::Bool(1), true), 0); } PrimExpr fmod(PrimExpr x, PrimExpr y) { BinaryOpMatchTypes(x, y); CHECK(x.dtype().is_float()) << "fmod only applies to float"; return tir::CallNode::make(x.dtype(), "fmod", { x, y }, tir::CallNode::PureIntrinsic); } PrimExpr floor(PrimExpr x) { using tir::FloatImmNode; const FloatImmNode* fx = x.as<FloatImmNode>(); if (fx) return FloatImm(x.dtype(), std::floor(fx->value)); return tir::CallNode::make(x.dtype(), "floor", {x}, tir::CallNode::PureIntrinsic); } PrimExpr ceil(PrimExpr x) { using tir::FloatImmNode; const FloatImmNode* fx = x.as<FloatImmNode>(); if (fx) return FloatImm(x.dtype(), std::ceil(fx->value)); return tir::CallNode::make(x.dtype(), "ceil", {x}, tir::CallNode::PureIntrinsic); } PrimExpr round(PrimExpr x) { using tir::FloatImmNode; const FloatImmNode* fx = x.as<FloatImmNode>(); if (fx) return FloatImm(x.dtype(), std::nearbyint(fx->value)); return tir::CallNode::make(x.dtype(), "round", {x}, tir::CallNode::PureIntrinsic); } PrimExpr nearbyint(PrimExpr x) { using tir::FloatImmNode; const FloatImmNode* fx = x.as<FloatImmNode>(); if (fx) return FloatImm(x.dtype(), std::nearbyint(fx->value)); return tir::CallNode::make(x.dtype(), "nearbyint", {x}, tir::CallNode::PureIntrinsic); } PrimExpr trunc(PrimExpr x) { using tir::FloatImmNode; const FloatImmNode* fx = x.as<FloatImmNode>(); if (fx) { return FloatImm(x.dtype(), (fx->value < 0 ? std::ceil(fx->value) : std::floor(fx->value))); } return tir::CallNode::make(x.dtype(), "trunc", {x}, tir::CallNode::PureIntrinsic); } } // namespace tvm