Unverified Commit f4c5f93b by Tianqi Chen Committed by GitHub

[REFACTOR][IR] Add Node suffix to low-level IR nodes (#4649)

* [REFACTOR][IR] Variable -> VarNode

* [REFACTOR][IR] Add/Sub/Mul/Div -> AddNode/SubNode etc.

* [REFACTOR][IR] Min/Max/FloorDiv/FloorMod -> MinNode/MaxNode etc.

* [REFACTOR][IR] EQ/NE/LT/LE/GT/GE/Select -> EQNode/NENode etc.

* [REFACTOR][IR] Add Node suffix to Select/Call/Load/Ramp/Shuffle/Let

* [REFACTOR][IR] Add node suffix to IntImm/UIntImm/FloatImm/StringImm

* [REFACTOR][IR] Add Node suffix to Any, AttrStmt, AssertStmt

* [REFACTOR][IR] Add Node suffix to Store/Provide/Allocate/Free

* [REFACTOR][IR] Add Node suffix to ProducerConsumer

* Fix lint

* style updates, test fixes
parent df02e730
......@@ -564,7 +564,7 @@ IntSet EvalSet(Expr e,
* \return An integer set that can cover all the possible values of e.
*/
IntSet EvalSet(Expr e,
const std::unordered_map<const Variable*, IntSet>& dom_map);
const std::unordered_map<const VarNode*, IntSet>& dom_map);
/*!
* \brief Find a symbolic integer set that contains the union over
......@@ -586,7 +586,7 @@ IntSet EvalSet(Range r,
* \return An integer set that can cover all the possible values.
*/
IntSet EvalSet(IntSet s,
const std::unordered_map<const Variable*, IntSet>& dom_map);
const std::unordered_map<const VarNode*, IntSet>& dom_map);
/*!
* \brief Same as EvalSet, but takes unordered_map
*
......@@ -595,7 +595,7 @@ IntSet EvalSet(IntSet s,
* \return An integer set that can cover all the possible values of e.
*/
IntSet EvalSet(Range r,
const std::unordered_map<const Variable*, IntSet>& dom_map);
const std::unordered_map<const VarNode*, IntSet>& dom_map);
/*! \brief Map from Expr to IntSet */
using ExprIntSetMap = std::unordered_map<Expr, IntSet, ObjectHash, ObjectEqual>;
......@@ -609,7 +609,7 @@ using ExprIntSetMap = std::unordered_map<Expr, IntSet, ObjectHash, ObjectEqual>;
*/
ExprIntSetMap EvalSetForEachSubExpr(
Expr e,
const std::unordered_map<const Variable*, IntSet>& dom_map);
const std::unordered_map<const VarNode*, IntSet>& dom_map);
/*!
* \brief Create an union set of all sets
......@@ -654,8 +654,8 @@ IntSet DeduceBound(Expr v, Expr cond,
* \return An integer set that always satisfies the condition.
*/
IntSet DeduceBound(Expr v, Expr cond,
const std::unordered_map<const Variable*, IntSet>& hint_map,
const std::unordered_map<const Variable*, IntSet>& relax_map);
const std::unordered_map<const VarNode*, IntSet>& hint_map,
const std::unordered_map<const VarNode*, IntSet>& relax_map);
/*!
* \brief Infer a regular domain that covers all the calls or provides within the given statement.
......
......@@ -488,9 +488,9 @@ inline void SetIntValue(T* ptr, const TVMArgValue& val) {
} else {
Expr expr = val;
CHECK(expr.defined());
if (const ir::IntImm* op = expr.as<ir::IntImm>()) {
if (const ir::IntImmNode* op = expr.as<ir::IntImmNode>()) {
*ptr = static_cast<T>(op->value);
} else if (const ir::UIntImm* op = expr.as<ir::UIntImm>()) {
} else if (const ir::UIntImmNode* op = expr.as<ir::UIntImmNode>()) {
*ptr = static_cast<T>(op->value);
} else {
LOG(FATAL) << "Expect int value, but get " << expr->GetTypeKey();
......@@ -503,7 +503,7 @@ inline void SetValue<std::string>(std::string* ptr, const TVMArgValue& val) {
*ptr = val.operator std::string();
} else {
Expr expr = val;
const ir::StringImm* op = expr.as<ir::StringImm>();
const ir::StringImmNode* op = expr.as<ir::StringImmNode>();
CHECK(op != nullptr);
*ptr = op->value;
}
......@@ -519,11 +519,11 @@ inline void SetValue<double>(double* ptr, const TVMArgValue& val) {
} else {
Expr expr = val;
CHECK(expr.defined());
if (const ir::IntImm* op = expr.as<ir::IntImm>()) {
if (const ir::IntImmNode* op = expr.as<ir::IntImmNode>()) {
*ptr = static_cast<double>(op->value);
} else if (const ir::IntImm* op = expr.as<ir::IntImm>()) {
} else if (const ir::IntImmNode* op = expr.as<ir::IntImmNode>()) {
*ptr = static_cast<double>(op->value);
} else if (const ir::UIntImm* op = expr.as<ir::UIntImm>()) {
} else if (const ir::UIntImmNode* op = expr.as<ir::UIntImmNode>()) {
*ptr = static_cast<double>(op->value);
} else {
LOG(FATAL) << "Expect float value, but get " << expr->GetTypeKey();
......
......@@ -102,7 +102,7 @@ class Var;
* - Let
* - LetStmt
*/
class Variable : public ExprNode {
class VarNode : public ExprNode {
public:
/*!
* \brief The hint to the variable name.
......@@ -118,7 +118,7 @@ class Variable : public ExprNode {
}
static constexpr const char* _type_key = "Variable";
TVM_DECLARE_FINAL_OBJECT_INFO(Variable, ExprNode);
TVM_DECLARE_FINAL_OBJECT_INFO(VarNode, ExprNode);
};
/*! \brief a named variable in TVM */
......@@ -139,18 +139,18 @@ class Var : public Expr {
* \brief Get pointer to the internal value.
* \return the corresponding Variable.
*/
const Variable* operator->() const {
const VarNode* operator->() const {
return get();
}
/*!
* \brief Get pointer to the internal value.
* \return the corresponding Variable.
*/
const Variable* get() const {
return static_cast<const Variable*>(data_.get());
const VarNode* get() const {
return static_cast<const VarNode*>(data_.get());
}
/*! \brief type indicate the container type */
using ContainerType = Variable;
using ContainerType = VarNode;
};
// Backward compatibility, will be removed later.
......@@ -161,7 +161,7 @@ using ExprEqual = ObjectEqual;
class Integer;
/*! \brief ExprNode: constant integer. */
class IntImm : public ExprNode {
class IntImmNode : public ExprNode {
public:
/*! \brief the Internal value. */
int64_t value;
......@@ -174,7 +174,7 @@ class IntImm : public ExprNode {
TVM_DLL static Integer make(DataType t, int64_t value);
static constexpr const char* _type_key = "IntImm";
TVM_DECLARE_FINAL_OBJECT_INFO(IntImm, ExprNode);
TVM_DECLARE_FINAL_OBJECT_INFO(IntImmNode, ExprNode);
};
/*!
......@@ -206,8 +206,8 @@ class Integer : public Expr {
* \brief Get pointer to the internal value.
* \return the content of the integer.
*/
const IntImm* operator->() const {
return static_cast<const IntImm*>(get());
const IntImmNode* operator->() const {
return static_cast<const IntImmNode*>(get());
}
/*!
* \brief convert to int64_t
......@@ -218,7 +218,7 @@ class Integer : public Expr {
return (*this)->value;
}
/*! \brief type indicate the container type */
using ContainerType = IntImm;
using ContainerType = IntImmNode;
};
/*! \brief range over one dimension */
......
......@@ -75,7 +75,7 @@ inline Expr const_false(int lanes = 1) {
*/
inline const int64_t* as_const_int(const Expr& x) {
if (!x.defined()) return nullptr;
if (const ir::IntImm* op = x.as<ir::IntImm>()) {
if (const ir::IntImmNode* op = x.as<ir::IntImmNode>()) {
return &(op->value);
} else {
return nullptr;
......@@ -90,7 +90,7 @@ inline const int64_t* as_const_int(const Expr& x) {
*/
inline const uint64_t* as_const_uint(const Expr& x) {
if (!x.defined()) return nullptr;
if (const ir::UIntImm* op = x.as<ir::UIntImm>()) {
if (const ir::UIntImmNode* op = x.as<ir::UIntImmNode>()) {
return &(op->value);
} else {
return nullptr;
......@@ -600,7 +600,7 @@ TVM_DLL Expr trunc(Expr x);
// Intrinsic operators
#define TVM_DECLARE_INTRIN_UNARY(OpName) \
inline Expr OpName(Expr x) { \
return ir::Call::make(x.dtype(), #OpName, {x}, ir::Call::PureIntrinsic); \
return ir::CallNode::make(x.dtype(), #OpName, {x}, ir::CallNode::PureIntrinsic); \
} \
TVM_DECLARE_INTRIN_UNARY(exp);
......@@ -617,11 +617,11 @@ TVM_DECLARE_INTRIN_UNARY(atan);
// Implementation details after this
inline bool is_const(const Expr& x) {
if (x.as<ir::IntImm>() || x.as<ir::UIntImm>()) {
if (x.as<ir::IntImmNode>() || x.as<ir::UIntImmNode>()) {
return true;
} else if (const auto* op = x.as<ir::Broadcast>()) {
} else if (const auto* op = x.as<ir::BroadcastNode>()) {
const Expr& val = op->value;
if (val.as<ir::IntImm>() || val.as<ir::UIntImm>()) {
if (val.as<ir::IntImmNode>() || val.as<ir::UIntImmNode>()) {
return true;
}
}
......@@ -629,9 +629,9 @@ inline bool is_const(const Expr& x) {
}
inline bool is_positive_const(const Expr& a) {
if (const ir::IntImm* op = a.as<ir::IntImm>()) {
if (const ir::IntImmNode* op = a.as<ir::IntImmNode>()) {
return op->value > 0;
} else if (const ir::UIntImm* op = a.as<ir::UIntImm>()) {
} else if (const ir::UIntImmNode* op = a.as<ir::UIntImmNode>()) {
return op->value > 0;
} else {
return false;
......@@ -639,7 +639,7 @@ inline bool is_positive_const(const Expr& a) {
}
inline bool is_negative_const(const Expr& a) {
if (const ir::IntImm* op = a.as<ir::IntImm>()) {
if (const ir::IntImmNode* op = a.as<ir::IntImmNode>()) {
return op->value < 0;
} else {
return false;
......@@ -647,15 +647,15 @@ inline bool is_negative_const(const Expr& a) {
}
inline bool is_const_int(const Expr& x, int64_t value) {
if (const auto* op = x.as<ir::IntImm>()) {
if (const auto* op = x.as<ir::IntImmNode>()) {
return op->value == value;
} else if (const auto* op = x.as<ir::UIntImm>()) {
} else if (const auto* op = x.as<ir::UIntImmNode>()) {
return op->value == static_cast<uint64_t>(value);
} else if (const auto* op = x.as<ir::Broadcast>()) {
} else if (const auto* op = x.as<ir::BroadcastNode>()) {
const Expr& val = op->value;
if (const auto* opv = val.as<ir::IntImm>()) {
if (const auto* opv = val.as<ir::IntImmNode>()) {
return opv->value == value;
} else if (const auto* opv = val.as<ir::UIntImm>()) {
} else if (const auto* opv = val.as<ir::UIntImmNode>()) {
return opv->value == static_cast<uint64_t>(value);
}
}
......@@ -664,7 +664,7 @@ inline bool is_const_int(const Expr& x, int64_t value) {
inline bool is_no_op(const Stmt& stmt) {
if (!stmt.defined()) return true;
if (const auto* op = stmt.as<ir::Evaluate>()) {
if (const auto* op = stmt.as<ir::EvaluateNode>()) {
return is_const(op->value);
}
if (const auto* op = stmt.as<ir::SeqStmtNode>()) {
......@@ -675,15 +675,15 @@ inline bool is_no_op(const Stmt& stmt) {
/*!
 * \brief Construct a constant scalar expression of type t holding value.
 * \param t The target scalar data type (caller ensures lanes() == 1).
 * \param value The constant value to encode.
 * \return The constant expression; an undefined Expr is returned only
 *         after LOG(FATAL) when the type cannot represent a constant.
 */
template<typename ValueType>
inline Expr MakeConstScalar(DataType t, ValueType value) {
  if (t.is_int()) return ir::IntImmNode::make(t, static_cast<int64_t>(value));
  if (t.is_uint()) return ir::UIntImmNode::make(t, static_cast<uint64_t>(value));
  if (t.is_float()) return ir::FloatImmNode::make(t, static_cast<double>(value));
  // For now, we store const scalar values of custom datatypes within doubles; later, during the
  // datatypes lowering pass, we will lower the value to its true representation in the format
  // specified by the datatype.
  // TODO(gus) when do we need to start worrying about doubles not being precise enough?
  if (static_cast<uint8_t>(t.code()) >= static_cast<uint8_t>(kCustomBegin))
    return ir::FloatImmNode::make(t, static_cast<double>(value));
  LOG(FATAL) << "cannot make const for type " << t;
  return Expr();
}
......@@ -693,7 +693,7 @@ inline Expr make_const(DataType t, ValueType value) {
if (t.lanes() == 1) {
return MakeConstScalar(t, value);
} else {
return ir::Broadcast::make(
return ir::BroadcastNode::make(
MakeConstScalar(t.element_of(), value), t.lanes());
}
}
......
......@@ -132,7 +132,7 @@ bool ExprUseVar(const Expr& e, const Var& v);
* \param vset The variable set.
* \return Whether e uses vset.
*/
bool ExprUseVar(const Expr& e, const std::unordered_set<const Variable*>& vset);
bool ExprUseVar(const Expr& e, const std::unordered_set<const VarNode*>& vset);
/*!
* \brief Convert a IR node to be SSA form.
......@@ -148,7 +148,7 @@ TVM_DLL Stmt ConvertSSA(Stmt stmt);
* \return The converted form.
*/
Stmt Substitute(Stmt stmt,
const std::unordered_map<const Variable*, Expr>& value_map);
const std::unordered_map<const VarNode*, Expr>& value_map);
/*!
* \brief Substitute the var specified in key->var to be value.
......@@ -157,7 +157,7 @@ Stmt Substitute(Stmt stmt,
* \return The converted expression.
*/
Expr Substitute(Expr expr,
const std::unordered_map<const Variable*, Expr>& value_map);
const std::unordered_map<const VarNode*, Expr>& value_map);
/*!
* \brief Substitute the var specified in key->var to be value.
......
......@@ -109,7 +109,7 @@ class OperationNode : public ir::FunctionBaseNode {
virtual void PropBoundToInputs(
const Operation& self,
arith::Analyzer* analyzer,
const std::unordered_map<const Variable*, IntSet>& dom_map,
const std::unordered_map<const VarNode*, IntSet>& dom_map,
std::unordered_map<Tensor, TensorDom>* out_dom_map) const = 0;
/*!
* \brief Gather the bound from output tensor.
......@@ -173,7 +173,7 @@ class PlaceholderOpNode : public OperationNode {
void PropBoundToInputs(
const Operation& self,
arith::Analyzer* analyzer,
const std::unordered_map<const Variable*, IntSet>& dom_map,
const std::unordered_map<const VarNode*, IntSet>& dom_map,
std::unordered_map<Tensor, TensorDom>* out_dom_map) const final;
void GatherBound(
const Operation& self,
......@@ -251,7 +251,7 @@ class TVM_DLL ComputeOpNode : public BaseComputeOpNode {
void PropBoundToInputs(
const Operation& self,
arith::Analyzer* analyzer,
const std::unordered_map<const Variable*, IntSet>& dom_map,
const std::unordered_map<const VarNode*, IntSet>& dom_map,
std::unordered_map<Tensor, TensorDom>* out_dom_map) const final;
Stmt BuildProvide(
const Stage& stage,
......@@ -304,7 +304,7 @@ class TensorComputeOpNode : public BaseComputeOpNode {
void PropBoundToInputs(
const Operation& self,
arith::Analyzer* analyzer,
const std::unordered_map<const Variable*, IntSet>& dom_map,
const std::unordered_map<const VarNode*, IntSet>& dom_map,
std::unordered_map<Tensor, TensorDom>* out_dom_map) const final;
Stmt BuildProvide(
const Stage& stage,
......@@ -379,7 +379,7 @@ class ScanOpNode : public OperationNode {
void PropBoundToInputs(
const Operation& self,
arith::Analyzer* analyzer,
const std::unordered_map<const Variable*, IntSet>& dom_map,
const std::unordered_map<const VarNode*, IntSet>& dom_map,
std::unordered_map<Tensor, TensorDom>* out_dom_map) const final;
void GatherBound(
const Operation& self,
......@@ -446,7 +446,7 @@ class ExternOpNode : public OperationNode {
void PropBoundToInputs(
const Operation& self,
arith::Analyzer* analyzer,
const std::unordered_map<const Variable*, IntSet>& dom_map,
const std::unordered_map<const VarNode*, IntSet>& dom_map,
std::unordered_map<Tensor, TensorDom>* out_dom_map) const final;
void GatherBound(
const Operation& self,
......@@ -514,7 +514,7 @@ class HybridOpNode : public OperationNode {
void PropBoundToInputs(
const Operation& self,
arith::Analyzer* analyzer,
const std::unordered_map<const Variable*, IntSet>& dom_map,
const std::unordered_map<const VarNode*, IntSet>& dom_map,
std::unordered_map<Tensor, TensorDom>* out_dom_map) const final;
void GatherBound(
const Operation& self,
......
......@@ -39,7 +39,7 @@
namespace tvm {
namespace relay {
// relay::Any aliases the low-level IR AnyNode after the Node-suffix refactor.
using Any = tvm::ir::AnyNode;
using Kind = TypeKind;
using Type = tvm::Type;
using TypeNode = tvm::TypeNode;
......
......@@ -33,7 +33,7 @@ namespace ir {
// FFI hook: construct a low-level IR variable with name hint s and dtype t.
TVM_REGISTER_GLOBAL("_Var")
.set_body_typed([](std::string s, DataType t) {
    return VarNode::make(t, s);
});
TVM_REGISTER_GLOBAL("make.abs")
......@@ -73,7 +73,7 @@ TVM_REGISTER_GLOBAL("make.For")
.set_body_typed([](
VarExpr loop_var, Expr min, Expr extent,
int for_type, int device_api, Stmt body) {
return For::make(loop_var,
return ForNode::make(loop_var,
min,
extent,
static_cast<ForType>(for_type),
......@@ -85,9 +85,9 @@ TVM_REGISTER_GLOBAL("make.Load")
.set_body([](TVMArgs args, TVMRetValue *ret) {
DataType t = args[0];
if (args.size() == 3) {
*ret = Load::make(t, args[1], args[2], const_true(t.lanes()));
*ret = LoadNode::make(t, args[1], args[2], const_true(t.lanes()));
} else {
*ret = Load::make(t, args[1], args[2], args[3]);
*ret = LoadNode::make(t, args[1], args[2], args[3]);
}
});
......@@ -95,14 +95,14 @@ TVM_REGISTER_GLOBAL("make.Store")
.set_body([](TVMArgs args, TVMRetValue *ret) {
Expr value = args[1];
if (args.size() == 3) {
*ret = Store::make(args[0], value, args[2], const_true(value.dtype().lanes()));
*ret = StoreNode::make(args[0], value, args[2], const_true(value.dtype().lanes()));
} else {
*ret = Store::make(args[0], value, args[2], args[3]);
*ret = StoreNode::make(args[0], value, args[2], args[3]);
}
});
// FFI hook: build a Realize statement node directly from RealizeNode::make.
TVM_REGISTER_GLOBAL("make.Realize")
.set_body_typed(RealizeNode::make);
TVM_REGISTER_GLOBAL("make.Call")
.set_body_typed([](
......@@ -110,10 +110,10 @@ TVM_REGISTER_GLOBAL("make.Call")
Array<Expr> args, int call_type,
FunctionRef func, int value_index
) {
return Call::make(type,
return CallNode::make(type,
name,
args,
static_cast<Call::CallType>(call_type),
static_cast<CallNode::CallType>(call_type),
func,
value_index);
});
......@@ -122,9 +122,10 @@ TVM_REGISTER_GLOBAL("make.CommReducer")
.set_body_typed(CommReducerNode::make);
// make from two arguments
// Registers the FFI global "make.<NodeName>" bound to <NodeName>Node::make,
// so e.g. REGISTER_MAKE(Reduce) exposes ReduceNode::make as "make.Reduce".
#define REGISTER_MAKE(NodeName)                 \
  TVM_REGISTER_GLOBAL("make."#NodeName)         \
  .set_body_typed(NodeName ## Node::make);      \
REGISTER_MAKE(Reduce);
REGISTER_MAKE(AttrStmt);
......@@ -174,7 +175,7 @@ TVM_REGISTER_GLOBAL("make.Allocate")
.set_body_typed([](
VarExpr buffer_var, DataType type, Array<Expr> extents, Expr condition, Stmt body
){
return Allocate::make(buffer_var, type, extents, condition, body);
return AllocateNode::make(buffer_var, type, extents, condition, body);
});
// operator overloading, smarter than make
......
......@@ -54,7 +54,7 @@ TVM_REGISTER_GLOBAL("_const")
});
// FFI hook: construct a StringImm expression from a Python/frontend string.
TVM_REGISTER_GLOBAL("_str")
.set_body_typed(ir::StringImmNode::make);
TVM_REGISTER_GLOBAL("_Array")
......@@ -198,7 +198,7 @@ TVM_REGISTER_GLOBAL("_MapItems")
auto* n = static_cast<const StrMapNode*>(ptr);
auto rkvs = make_object<ArrayNode>();
for (const auto& kv : n->data) {
rkvs->data.push_back(ir::StringImm::make(kv.first));
rkvs->data.push_back(ir::StringImmNode::make(kv.first));
rkvs->data.push_back(kv.second);
}
*ret = Array<ObjectRef>(rkvs);
......
......@@ -78,7 +78,7 @@ void ConstraintContext::ExitWithScope() {
}
bool Analyzer::CanProveGreaterEqual(const Expr& expr, int64_t lower_bound) {
if (const auto* ptr = expr.as<ir::IntImm>()) {
if (const auto* ptr = expr.as<ir::IntImmNode>()) {
return ptr->value >= lower_bound;
}
auto bd = this->const_int_bound(this->rewrite_simplify(expr));
......@@ -87,15 +87,15 @@ bool Analyzer::CanProveGreaterEqual(const Expr& expr, int64_t lower_bound) {
}
bool Analyzer::CanProve(const Expr& expr) {
if (const auto* ptr = expr.as<ir::UIntImm>()) {
if (const auto* ptr = expr.as<ir::UIntImmNode>()) {
return ptr->value != 0;
}
auto res = this->rewrite_simplify(expr);
if (const auto* ptr = res.as<ir::UIntImm>()) {
if (const auto* ptr = res.as<ir::UIntImmNode>()) {
return ptr->value != 0;
}
res = this->canonical_simplify(expr);
if (const auto* ptr = res.as<ir::UIntImm>()) {
if (const auto* ptr = res.as<ir::UIntImmNode>()) {
return ptr->value != 0;
}
return false;
......
......@@ -78,8 +78,8 @@ class BoundDeducer: public ExprVisitor {
friend class BoundDeduceInputChecker;
friend class Converter;
BoundDeducer(Expr target, Expr expr,
const std::unordered_map<const Variable*, IntSet>& hint_map,
const std::unordered_map<const Variable*, IntSet>& relax_map)
const std::unordered_map<const VarNode*, IntSet>& hint_map,
const std::unordered_map<const VarNode*, IntSet>& relax_map)
: target_(target), expr_(expr), hint_map_(hint_map), relax_map_(relax_map) {}
void Deduce();
......@@ -94,29 +94,29 @@ class BoundDeducer: public ExprVisitor {
}
}
void VisitExpr_(const LT* op) final {
void VisitExpr_(const LTNode* op) final {
LOG(FATAL) << "unable to deduce due to multiple comparison operator";
}
void VisitExpr_(const LE* op) final {
void VisitExpr_(const LENode* op) final {
LOG(FATAL) << "unable to deduce due to multiple comparison operator";
}
void VisitExpr_(const GT* op) final {
void VisitExpr_(const GTNode* op) final {
LOG(FATAL) << "unable to deduce due to multiple comparison operator";
}
void VisitExpr_(const GE* op) final {
void VisitExpr_(const GENode* op) final {
LOG(FATAL) << "unable to deduce due to multiple comparison operator";
}
void VisitExpr_(const Add* op) final {
void VisitExpr_(const AddNode* op) final {
bool left = op->a.get() == path_[iter_];
result_ -= left ? op->b : op->a;
this->VisitExpr(left ? op->a : op->b);
}
void VisitExpr_(const Sub* op) final {
void VisitExpr_(const SubNode* op) final {
bool left = op->a.get() == path_[iter_];
if (left) {
result_ += op->b;
......@@ -128,7 +128,7 @@ class BoundDeducer: public ExprVisitor {
this->VisitExpr(left ? op->a : op->b);
}
void VisitExpr_(const Mul* op) final {
void VisitExpr_(const MulNode* op) final {
bool left = op->a.get() == path_[iter_];
Expr operand = left ? op->b : op->a;
Expr target_var = left ? op->a : op->b;
......@@ -187,8 +187,8 @@ class BoundDeducer: public ExprVisitor {
CompareOp ReverseOp(CompareOp comp_op);
Expr target_;
Expr expr_;
const std::unordered_map<const Variable*, IntSet>& hint_map_;
const std::unordered_map<const Variable*, IntSet>& relax_map_;
const std::unordered_map<const VarNode*, IntSet>& hint_map_;
const std::unordered_map<const VarNode*, IntSet>& relax_map_;
ExprIntSetMap expr_map_;
std::vector<const Object*> path_;
size_t iter_{0};
......@@ -233,7 +233,7 @@ CompareOp BoundDeducer::ReverseOp(CompareOp comp_op) {
void BoundDeducer::Transform() {
// We will ensure to set expr_ such that it contains target_
if (const LT* op = expr_.as<LT>()) {
if (const LTNode* op = expr_.as<LTNode>()) {
if (GetPath(target_, op->a).empty()) {
// a < b -> b >= a + 1
comp_op = kGreater;
......@@ -245,7 +245,7 @@ void BoundDeducer::Transform() {
expr_ = op->a;
result_ = op->b - 1;
}
} else if (const LE* op = expr_.as<LE>()) {
} else if (const LENode* op = expr_.as<LENode>()) {
if (GetPath(target_, op->a).empty()) {
// a <= b -> b >= a
comp_op = kGreater;
......@@ -256,7 +256,7 @@ void BoundDeducer::Transform() {
expr_ = op->a;
result_ = op->b;
}
} else if (const GT* op = expr_.as<GT>()) {
} else if (const GTNode* op = expr_.as<GTNode>()) {
if (GetPath(target_, op->a).empty()) {
// a > b -> b <= a - 1
comp_op = kLess;
......@@ -268,7 +268,7 @@ void BoundDeducer::Transform() {
expr_ = op->a;
result_ = op->b + 1;
}
} else if (const GE* op = expr_.as<GE>()) {
} else if (const GENode* op = expr_.as<GENode>()) {
if (GetPath(target_, op->a).empty()) {
// a >= b -> b <= a
comp_op = kLess;
......@@ -279,7 +279,7 @@ void BoundDeducer::Transform() {
expr_ = op->a;
result_ = op->b;
}
} else if (const EQ* op = expr_.as<EQ>()) {
} else if (const EQNode* op = expr_.as<EQNode>()) {
comp_op = kEqual;
if (GetPath(target_, op->a).empty()) {
// if the b == a -> a == b
......@@ -330,8 +330,8 @@ void BoundDeducer::Relax() {
}
IntSet DeduceBound(Expr v, Expr e,
const std::unordered_map<const Variable*, IntSet>& hint_map,
const std::unordered_map<const Variable*, IntSet>& relax_map) {
const std::unordered_map<const VarNode*, IntSet>& hint_map,
const std::unordered_map<const VarNode*, IntSet>& relax_map) {
BoundDeducer d(v, e, hint_map, relax_map);
d.Deduce();
if (!d.success_) return IntSet::nothing();
......@@ -352,11 +352,11 @@ IntSet DeduceBound(Expr v, Expr e,
IntSet DeduceBound(Expr v, Expr e,
const Map<Var, IntSet>& hint_map,
const Map<Var, IntSet>& relax_map) {
std::unordered_map<const Variable*, IntSet> hmap;
std::unordered_map<const VarNode*, IntSet> hmap;
for (auto kv : hint_map) {
hmap[kv.first.get()] = kv.second;
}
std::unordered_map<const Variable*, IntSet> rmap;
std::unordered_map<const VarNode*, IntSet> rmap;
for (auto kv : relax_map) {
rmap[kv.first.get()] = kv.second;
}
......
......@@ -450,14 +450,14 @@ class CanonicalSimplifier::Impl : public RewriteSimplifier::Impl {
}
using Rewriter::VisitExpr_;
Expr VisitExpr_(const Add* op) final;
Expr VisitExpr_(const Sub* op) final;
Expr VisitExpr_(const Mul* op) final;
Expr VisitExpr_(const Div* op) final;
Expr VisitExpr_(const Mod* op) final;
Expr VisitExpr_(const FloorDiv* op) final;
Expr VisitExpr_(const FloorMod* op) final;
Expr VisitExpr_(const Reduce* op) final;
Expr VisitExpr_(const AddNode* op) final;
Expr VisitExpr_(const SubNode* op) final;
Expr VisitExpr_(const MulNode* op) final;
Expr VisitExpr_(const DivNode* op) final;
Expr VisitExpr_(const ModNode* op) final;
Expr VisitExpr_(const FloorDivNode* op) final;
Expr VisitExpr_(const FloorModNode* op) final;
Expr VisitExpr_(const ReduceNode* op) final;
private:
/*!
......@@ -553,7 +553,7 @@ class CanonicalSimplifier::Impl : public RewriteSimplifier::Impl {
}
ObjectPtr<SumExprNode> n = make_object<SumExprNode>();
n->dtype = expr.dtype();
if (const auto* op = expr.as<IntImm>()) {
if (const auto* op = expr.as<IntImmNode>()) {
n->base = op->value;
return SumExpr(n);
} else {
......@@ -562,11 +562,11 @@ class CanonicalSimplifier::Impl : public RewriteSimplifier::Impl {
}
}
// Simplify the combiner used in reduce.
Expr SimplifyReduceCombiner(const Reduce* op);
Expr SimplifyReduceCombiner(const ReduceNode* op);
};
Expr CanonicalSimplifier::Impl::
VisitExpr_(const Add* op) {
VisitExpr_(const AddNode* op) {
if (!IsIndexType(op->dtype)) {
return Rewriter::VisitExpr_(op);
}
......@@ -575,13 +575,13 @@ VisitExpr_(const Add* op) {
Expr b = this->CanonicalMutate(op->b);
// const folding
Expr const_res = TryConstFold<Add>(a, b);
Expr const_res = TryConstFold<AddNode>(a, b);
if (const_res.defined()) return const_res;
// canonical form simplification.
SumExpr ret = ToSumExpr(std::move(a));
if (const auto* op = b.as<IntImm>()) {
if (const auto* op = b.as<IntImmNode>()) {
ret.CopyOnWrite()->AddToSelf(op->value);
} else if (const auto* op = b.as<SumExprNode>()) {
ret.CopyOnWrite()->AddToSelf(GetRef<SumExpr>(op), 1);
......@@ -592,7 +592,7 @@ VisitExpr_(const Add* op) {
}
Expr CanonicalSimplifier::Impl::
VisitExpr_(const Sub* op) {
VisitExpr_(const SubNode* op) {
if (!IsIndexType(op->dtype)) {
return Rewriter::VisitExpr_(op);
}
......@@ -601,13 +601,13 @@ VisitExpr_(const Sub* op) {
Expr b = this->CanonicalMutate(op->b);
// const folding
Expr const_res = TryConstFold<Sub>(a, b);
Expr const_res = TryConstFold<SubNode>(a, b);
if (const_res.defined()) return const_res;
// canonical form simplification.
SumExpr ret = ToSumExpr(std::move(a));
if (const auto* op = b.as<IntImm>()) {
if (const auto* op = b.as<IntImmNode>()) {
ret.CopyOnWrite()->AddToSelf(-op->value);
} else if (const auto* op = b.as<SumExprNode>()) {
ret.CopyOnWrite()->AddToSelf(GetRef<SumExpr>(op), -1);
......@@ -619,7 +619,7 @@ VisitExpr_(const Sub* op) {
Expr CanonicalSimplifier::Impl::
VisitExpr_(const Mul* op) {
VisitExpr_(const MulNode* op) {
if (!IsIndexType(op->dtype)) {
return Rewriter::VisitExpr_(op);
}
......@@ -628,14 +628,14 @@ VisitExpr_(const Mul* op) {
Expr b = this->CanonicalMutate(op->b);
// const folding
Expr const_res = TryConstFold<Mul>(a, b);
Expr const_res = TryConstFold<MulNode>(a, b);
if (const_res.defined()) return const_res;
// x * c
if (a.as<IntImm>()) {
if (a.as<IntImmNode>()) {
std::swap(a, b);
}
if (const auto* bconst = b.as<IntImm>()) {
if (const auto* bconst = b.as<IntImmNode>()) {
if (a.as<SumExprNode>()) {
SumExpr ret = Downcast<SumExpr>(std::move(a));
ret.CopyOnWrite()->MulToSelf(bconst->value);
......@@ -653,7 +653,7 @@ VisitExpr_(const Mul* op) {
if (op->a.same_as(a) && op->b.same_as(b)) {
return GetRef<Expr>(op);
} else {
return Mul::make(a, b);
return MulNode::make(a, b);
}
}
......@@ -726,7 +726,7 @@ SplitDivConst(SplitExpr lhs, int64_t cval, DivMode div_mode) {
}
Expr CanonicalSimplifier::Impl::
VisitExpr_(const Div* op) {
VisitExpr_(const DivNode* op) {
if (!IsIndexType(op->dtype)) {
return Rewriter::VisitExpr_(op);
}
......@@ -735,7 +735,7 @@ VisitExpr_(const Div* op) {
Expr b = this->CanonicalMutate(op->b);
// const folding
Expr const_res = TryConstFold<Div>(a, b);
Expr const_res = TryConstFold<DivNode>(a, b);
if (const_res.defined()) return const_res;
PVar<Integer> c1;
// x / c1
......@@ -756,7 +756,7 @@ VisitExpr_(const Div* op) {
analyzer_->CanProveGreaterEqual(extra->Normalize(), 0)) {
lhs.CopyOnWrite()->DivideBy(cval);
Expr temp = Normalize(extra);
if (const auto* pconst = temp.as<IntImm>()) {
if (const auto* pconst = temp.as<IntImmNode>()) {
lhs.CopyOnWrite()->AddToSelf(pconst->value / cval);
} else {
// if 0 <= extra < cval, it means the extra can be eliminated.
......@@ -782,12 +782,12 @@ VisitExpr_(const Div* op) {
if (op->a.same_as(a) && op->b.same_as(b)) {
return GetRef<Expr>(op);
} else {
return Div::make(a, b);
return DivNode::make(a, b);
}
}
Expr CanonicalSimplifier::Impl::
VisitExpr_(const FloorDiv* op) {
VisitExpr_(const FloorDivNode* op) {
if (!IsIndexType(op->dtype)) {
return Rewriter::VisitExpr_(op);
}
......@@ -795,7 +795,7 @@ VisitExpr_(const FloorDiv* op) {
Expr b = this->CanonicalMutate(op->b);
// const folding
Expr const_res = TryConstFold<FloorDiv>(a, b);
Expr const_res = TryConstFold<FloorDivNode>(a, b);
if (const_res.defined()) return const_res;
PVar<Integer> c1;
// x / c1
......@@ -813,7 +813,7 @@ VisitExpr_(const FloorDiv* op) {
// continue simplification.
lhs.CopyOnWrite()->DivideBy(cval);
Expr temp = Normalize(extra);
if (const auto* pconst = temp.as<IntImm>()) {
if (const auto* pconst = temp.as<IntImmNode>()) {
lhs.CopyOnWrite()->AddToSelf(floordiv(pconst->value, cval));
} else {
// if 0 <= extra < cval, it means the extra can be eliminated.
......@@ -838,7 +838,7 @@ VisitExpr_(const FloorDiv* op) {
if (op->a.same_as(a) && op->b.same_as(b)) {
return GetRef<Expr>(op);
} else {
return FloorDiv::make(a, b);
return FloorDivNode::make(a, b);
}
}
......@@ -893,7 +893,7 @@ SplitModConst(SplitExpr lhs, int64_t cval, DivMode div_mode) {
}
Expr CanonicalSimplifier::Impl::
VisitExpr_(const Mod* op) {
VisitExpr_(const ModNode* op) {
if (!IsIndexType(op->dtype)) {
return Rewriter::VisitExpr_(op);
}
......@@ -902,7 +902,7 @@ VisitExpr_(const Mod* op) {
Expr b = this->CanonicalMutate(op->b);
// const folding
Expr const_res = TryConstFold<Mod>(a, b);
Expr const_res = TryConstFold<ModNode>(a, b);
if (const_res.defined()) return const_res;
PVar<Integer> c1;
......@@ -919,7 +919,7 @@ VisitExpr_(const Mod* op) {
if (analyzer_->CanProveGreaterEqual(lhs->Normalize(), 0) &&
analyzer_->CanProveGreaterEqual(extra->Normalize(), 0)) {
Expr temp = Normalize(extra);
if (temp.as<IntImm>()) {
if (temp.as<IntImmNode>()) {
return truncmod(temp, c1.Eval());
} else {
// If temp < cval && temp >=0 then can remove the mod.
......@@ -958,12 +958,12 @@ VisitExpr_(const Mod* op) {
if (op->a.same_as(a) && op->b.same_as(b)) {
return GetRef<Expr>(op);
} else {
return Mod::make(a, b);
return ModNode::make(a, b);
}
}
Expr CanonicalSimplifier::Impl::
VisitExpr_(const FloorMod* op) {
VisitExpr_(const FloorModNode* op) {
if (!IsIndexType(op->dtype)) {
return Rewriter::VisitExpr_(op);
}
......@@ -972,7 +972,7 @@ VisitExpr_(const FloorMod* op) {
Expr b = this->CanonicalMutate(op->b);
// const folding
Expr const_res = TryConstFold<FloorMod>(a, b);
Expr const_res = TryConstFold<FloorModNode>(a, b);
if (const_res.defined()) return const_res;
PVar<Integer> c1;
......@@ -983,7 +983,7 @@ VisitExpr_(const FloorMod* op) {
SumExpr lhs, extra;
SeparateDivisibleParts(psum, cval, &lhs, &extra);
Expr temp = Normalize(extra);
if (temp.as<IntImm>()) {
if (temp.as<IntImmNode>()) {
return floormod(temp, c1.Eval());
} else {
// If temp < cval && temp >=0 then can remove the mod.
......@@ -1018,13 +1018,13 @@ VisitExpr_(const FloorMod* op) {
if (op->a.same_as(a) && op->b.same_as(b)) {
return GetRef<Expr>(op);
} else {
return FloorMod::make(a, b);
return FloorModNode::make(a, b);
}
}
// Simplify reduce expression.
Expr CanonicalSimplifier::Impl::
SimplifyReduceCombiner(const Reduce* op) {
SimplifyReduceCombiner(const ReduceNode* op) {
// First simplify the results
Array<Expr> simplified_result;
for (const auto& res : op->combiner->result) {
......@@ -1089,15 +1089,15 @@ SimplifyReduceCombiner(const Reduce* op) {
CommReducer new_combiner =
CommReducerNode::make(new_lhs, new_rhs, new_result, new_identity);
return Reduce::make(
return ReduceNode::make(
new_combiner, new_source, op->axis, op->condition, new_value_index);
}
Expr CanonicalSimplifier::Impl::
VisitExpr_(const Reduce* op) {
VisitExpr_(const ReduceNode* op) {
// Recursively call simplification when necessary.
Expr ret = RewriteSimplifier::Impl::VisitExpr_(op);
op = ret.as<Reduce>();
op = ret.as<ReduceNode>();
// already been simplified by const reduction axis removal
if (op == nullptr) return ret;
if (op->axis.empty()) {
......@@ -1106,7 +1106,7 @@ VisitExpr_(const Reduce* op) {
// `(*op->combiner.get())(op->combineop->identity_element, op->source)[op->value_index]`
// instead of `op->source[op->value_index]`. The former may be more difficult to simplify.
return this->VisitExpr(
Select::make(op->condition,
SelectNode::make(op->condition,
op->source[op->value_index],
op->combiner->identity_element[op->value_index]));
}
......
......@@ -77,37 +77,37 @@ inline bool GetConstInt(Expr e, int* out) {
}
template<>
inline Expr Compute<ir::Add>(Expr a, Expr b) {
inline Expr Compute<ir::AddNode>(Expr a, Expr b) {
return a + b;
}
template<>
inline Expr Compute<ir::Sub>(Expr a, Expr b) {
inline Expr Compute<ir::SubNode>(Expr a, Expr b) {
return a - b;
}
template<>
inline Expr Compute<ir::Mul>(Expr a, Expr b) {
inline Expr Compute<ir::MulNode>(Expr a, Expr b) {
return a * b;
}
template<>
inline Expr Compute<ir::Div>(Expr a, Expr b) {
inline Expr Compute<ir::DivNode>(Expr a, Expr b) {
return truncdiv(a, b);
}
template<>
inline Expr Compute<ir::Mod>(Expr a, Expr b) {
inline Expr Compute<ir::ModNode>(Expr a, Expr b) {
return truncmod(a, b);
}
template<>
inline Expr Compute<ir::Max>(Expr a, Expr b) {
inline Expr Compute<ir::MaxNode>(Expr a, Expr b) {
return max(a, b);
}
template<>
inline Expr Compute<ir::Min>(Expr a, Expr b) {
inline Expr Compute<ir::MinNode>(Expr a, Expr b) {
return min(a, b);
}
......
......@@ -140,17 +140,17 @@ class ConstIntBoundAnalyzer::Impl :
return res;
}
Entry VisitExpr_(const Cast* op) final {
Entry VisitExpr_(const CastNode* op) final {
Entry a = VisitExpr(op->value);
Entry b = Everything(op->dtype);
return Intersect(a, b);
}
Entry VisitExpr_(const IntImm* op) final {
Entry VisitExpr_(const IntImmNode* op) final {
return MakeBound(op->value, op->value);
}
Entry VisitExpr_(const UIntImm* op) final {
Entry VisitExpr_(const UIntImmNode* op) final {
if (op->value <= static_cast<uint64_t>(kPosInf)) {
return MakeBound(op->value, op->value);
} else {
......@@ -158,7 +158,7 @@ class ConstIntBoundAnalyzer::Impl :
}
}
Entry VisitExpr_(const Add* op) final {
Entry VisitExpr_(const AddNode* op) final {
Entry a = VisitExpr(op->a);
Entry b = VisitExpr(op->b);
Entry ret;
......@@ -167,7 +167,7 @@ class ConstIntBoundAnalyzer::Impl :
return ret;
}
Entry VisitExpr_(const Sub* op) final {
Entry VisitExpr_(const SubNode* op) final {
Entry a = VisitExpr(op->a);
Entry b = VisitExpr(op->b);
Entry ret;
......@@ -176,13 +176,13 @@ class ConstIntBoundAnalyzer::Impl :
return ret;
}
Entry VisitExpr_(const Mul* op) final {
Entry VisitExpr_(const MulNode* op) final {
Entry a = VisitExpr(op->a);
Entry b = VisitExpr(op->b);
return BinaryOpBoundry(a, b, InfAwareMul);
}
Entry VisitExpr_(const Div* op) final {
Entry VisitExpr_(const DivNode* op) final {
Entry a = VisitExpr(op->a);
Entry b = VisitExpr(op->b);
CHECK(!b.is_const(0)) << "divide by zero";
......@@ -192,7 +192,7 @@ class ConstIntBoundAnalyzer::Impl :
return BinaryOpBoundry(a, b, InfAwareDiv);
}
Entry VisitExpr_(const Mod* op) final {
Entry VisitExpr_(const ModNode* op) final {
Entry a = VisitExpr(op->a);
Entry b = VisitExpr(op->b);
if (b.min_value > 0) {
......@@ -215,7 +215,7 @@ class ConstIntBoundAnalyzer::Impl :
}
}
Entry VisitExpr_(const FloorDiv* op) final {
Entry VisitExpr_(const FloorDivNode* op) final {
Entry a = VisitExpr(op->a);
Entry b = VisitExpr(op->b);
CHECK(!b.is_const(0)) << "floordiv by zero";
......@@ -225,7 +225,7 @@ class ConstIntBoundAnalyzer::Impl :
return BinaryOpBoundry(a, b, InfAwareFloorDiv);
}
Entry VisitExpr_(const FloorMod* op) final {
Entry VisitExpr_(const FloorModNode* op) final {
Entry a = VisitExpr(op->a);
Entry b = VisitExpr(op->b);
if (b.min_value > 0) {
......@@ -246,7 +246,7 @@ class ConstIntBoundAnalyzer::Impl :
}
}
Entry VisitExpr_(const Min* op) final {
Entry VisitExpr_(const MinNode* op) final {
Entry a = VisitExpr(op->a);
Entry b = VisitExpr(op->b);
Entry ret;
......@@ -255,7 +255,7 @@ class ConstIntBoundAnalyzer::Impl :
return ret;
}
Entry VisitExpr_(const Max* op) final {
Entry VisitExpr_(const MaxNode* op) final {
Entry a = VisitExpr(op->a);
Entry b = VisitExpr(op->b);
Entry ret;
......@@ -264,25 +264,25 @@ class ConstIntBoundAnalyzer::Impl :
return ret;
}
Entry VisitExpr_(const Select* op) final {
Entry VisitExpr_(const SelectNode* op) final {
Entry a = VisitExpr(op->true_value);
Entry b = VisitExpr(op->false_value);
return Union(a, b);
}
Entry VisitExpr_(const Call* op) final {
Entry VisitExpr_(const CallNode* op) final {
// only special handle >> and & which can be
// used for index calculation.
if (op->is_intrinsic(Call::shift_right)) {
if (op->is_intrinsic(CallNode::shift_right)) {
return VisitRightShift(op);
} else if (op->is_intrinsic(Call::bitwise_and)) {
} else if (op->is_intrinsic(CallNode::bitwise_and)) {
return VisitBitwiseAnd(op);
} else {
return Everything(op->dtype);
}
}
Entry VisitExpr_(const Variable* op) final {
Entry VisitExpr_(const VarNode* op) final {
Var v = GetRef<Var>(op);
auto it = var_map_.find(v);
if (it != var_map_.end()) {
......@@ -292,13 +292,13 @@ class ConstIntBoundAnalyzer::Impl :
}
}
Entry VisitRightShift(const Call* op) {
Entry VisitRightShift(const CallNode* op) {
Entry a = VisitExpr(op->args[0]);
Entry b = VisitExpr(op->args[1]);
return BinaryOpBoundry(a, b, InfAwareRightShift);
}
Entry VisitBitwiseAnd(const Call* op) {
Entry VisitBitwiseAnd(const CallNode* op) {
Entry a = VisitExpr(op->args[0]);
Entry b = VisitExpr(op->args[1]);
// handle positive index case.
......@@ -375,7 +375,7 @@ class ConstIntBoundAnalyzer::Impl :
return kNegInf;
}
if (y == kPosInf || y == kNegInf) return y;
if (WillOverflow<Add>(x, y, kNegInf, kPosInf)) {
if (WillOverflow<AddNode>(x, y, kNegInf, kPosInf)) {
if (x > 0) return kPosInf;
return kNegInf;
}
......@@ -388,7 +388,7 @@ class ConstIntBoundAnalyzer::Impl :
* \return the result.
*/
static int64_t InfAwareMul(int64_t x, int64_t y) {
if (!WillOverflow<Mul>(x, y, kNegInf, kPosInf)) return x * y;
if (!WillOverflow<MulNode>(x, y, kNegInf, kPosInf)) return x * y;
if ((x > 0 && y > 0) || (x < 0 && y < 0)) return kPosInf;
return kNegInf;
}
......
......@@ -60,7 +60,7 @@ class LinearEqDetector
return true;
}
LinearEqEntry VisitExpr_(const Add* op, const Expr& e) final {
LinearEqEntry VisitExpr_(const AddNode* op, const Expr& e) final {
if (fail_) return LinearEqEntry();
LinearEqEntry a = VisitExpr(op->a, op->a);
LinearEqEntry b = VisitExpr(op->b, op->b);
......@@ -70,7 +70,7 @@ class LinearEqDetector
return ret;
}
LinearEqEntry VisitExpr_(const Sub* op, const Expr& e) final {
LinearEqEntry VisitExpr_(const SubNode* op, const Expr& e) final {
if (fail_) return LinearEqEntry();
LinearEqEntry a = VisitExpr(op->a, op->a);
LinearEqEntry b = VisitExpr(op->b, op->b);
......@@ -80,7 +80,7 @@ class LinearEqDetector
return ret;
}
LinearEqEntry VisitExpr_(const Mul* op, const Expr& e) final {
LinearEqEntry VisitExpr_(const MulNode* op, const Expr& e) final {
if (fail_) return LinearEqEntry();
LinearEqEntry a = VisitExpr(op->a, op->a);
LinearEqEntry b = VisitExpr(op->b, op->b);
......@@ -96,7 +96,7 @@ class LinearEqDetector
ret.coeff = MulCombine(a.base, b.coeff);
return ret;
}
LinearEqEntry VisitExpr_(const Variable* op, const Expr& e) final {
LinearEqEntry VisitExpr_(const VarNode* op, const Expr& e) final {
LinearEqEntry ret;
if (op == var_.get()) {
ret.coeff = make_const(op->dtype, 1);
......@@ -152,7 +152,7 @@ Array<Expr> DetectLinearEquation(const Expr& e, const Array<Var>& vars) {
base = std::move(ret.base);
}
std::unordered_set<const Variable*> vset;
std::unordered_set<const VarNode*> vset;
for (size_t i = vars.size(); i > 1; --i) {
vset.insert(vars[i - 1].get());
// The previous coeff contains the variable
......@@ -167,11 +167,11 @@ Array<Expr> DetectLinearEquation(const Expr& e, const Array<Var>& vars) {
// Detect clip condition as min max value
bool DetectClipBound(
const Expr& cond,
std::unordered_map<const Variable*, IntervalEntry>* bmap) {
std::unordered_map<const VarNode*, IntervalEntry>* bmap) {
int flag = 0;
Var var;
auto fvisit = [&bmap, &flag, &var](const ObjectRef& n) {
if (const Variable* v = n.as<Variable>()) {
if (const VarNode* v = n.as<VarNode>()) {
if (bmap->count(v)) {
if (flag == 0) {
var = Downcast<Var>(n);
......@@ -188,16 +188,16 @@ bool DetectClipBound(
if (flag != 1) return false;
// canonical form: exp >= 0
Expr canonical;
if (const LT* op = cond.as<LT>()) {
if (const LTNode* op = cond.as<LTNode>()) {
if (!op->a.dtype().is_int()) return false;
canonical = op->b - op->a - make_const(op->a.dtype(), 1);
} else if (const LE* op = cond.as<LE>()) {
} else if (const LENode* op = cond.as<LENode>()) {
if (!op->a.dtype().is_int()) return false;
canonical = op->b - op->a;
} else if (const GT* op = cond.as<GT>()) {
} else if (const GTNode* op = cond.as<GTNode>()) {
if (!op->a.dtype().is_int()) return false;
canonical = op->a - op->b - make_const(op->a.dtype(), 1);
} else if (const GE* op = cond.as<GE>()) {
} else if (const GENode* op = cond.as<GENode>()) {
if (!op->a.dtype().is_int()) return false;
canonical = op->a - op->b;
} else {
......@@ -210,7 +210,7 @@ bool DetectClipBound(
if (is_const_int(ret.coeff, 1)) {
// var + shift >=0 -> var >= -shift
if (p.min_value.defined()) {
p.min_value = ir::Max::make(p.min_value, -ret.base);
p.min_value = ir::MaxNode::make(p.min_value, -ret.base);
} else {
p.min_value = -ret.base;
}
......@@ -219,7 +219,7 @@ bool DetectClipBound(
if (is_const_int(ret.coeff, -1)) {
// -var + shift >=0 -> var <= shift
if (p.max_value.defined()) {
p.max_value = ir::Min::make(p.max_value, ret.base);
p.max_value = ir::MinNode::make(p.max_value, ret.base);
} else {
p.max_value = ret.base;
}
......@@ -243,8 +243,8 @@ void SplitCommExpr(const Expr& e, std::vector<Expr>* ret) {
// e must be connected by and.
Array<Expr> DetectClipBound(const Expr& e, const Array<Var>& vars) {
std::vector<Expr> splits;
SplitCommExpr<ir::And>(e, &splits);
std::unordered_map<const Variable*, IntervalEntry> rmap;
SplitCommExpr<ir::AndNode>(e, &splits);
std::unordered_map<const VarNode*, IntervalEntry> rmap;
for (Var v : vars) {
rmap[v.get()] = IntervalEntry();
}
......
......@@ -53,15 +53,15 @@ class FuncTouchedDomain final : public StmtExprVisitor {
return ret;
}
void VisitStmt_(const For *op) final {
const Variable* var = op->loop_var.get();
void VisitStmt_(const ForNode *op) final {
const VarNode* var = op->loop_var.get();
dom_map_[var] = IntSet::range(
Range::make_by_min_extent(op->min, op->extent));
StmtExprVisitor::VisitStmt_(op);
dom_map_.erase(var);
}
void VisitStmt_(const LetStmt* op) final {
void VisitStmt_(const LetStmtNode* op) final {
dom_map_[op->var.get()] =
arith::EvalSet(op->value, dom_map_);
StmtExprVisitor::VisitStmt_(op);
......@@ -69,11 +69,11 @@ class FuncTouchedDomain final : public StmtExprVisitor {
}
/* TODO: Thread extent unitest not generated.*/
void VisitStmt_(const AttrStmt* op) final {
void VisitStmt_(const AttrStmtNode* op) final {
if (op->attr_key == attr::thread_extent) {
const IterVarNode* thread_axis = op->node.as<IterVarNode>();
CHECK(thread_axis);
const Variable* var = thread_axis->var.get();
const VarNode* var = thread_axis->var.get();
dom_map_[var] = IntSet::range(Range(make_zero(op->value.dtype()), op->value));
StmtExprVisitor::VisitStmt_(op);
dom_map_.erase(var);
......@@ -82,7 +82,7 @@ class FuncTouchedDomain final : public StmtExprVisitor {
}
}
void VisitExpr_(const Call* op) final {
void VisitExpr_(const CallNode* op) final {
if (consider_calls_ && tensor_->op.same_as(op->func)
&& tensor_->value_index == op->value_index) {
Touch(op->args);
......@@ -90,7 +90,7 @@ class FuncTouchedDomain final : public StmtExprVisitor {
StmtExprVisitor::VisitExpr_(op);
}
void VisitStmt_(const Provide* op) final {
void VisitStmt_(const ProvideNode* op) final {
if (consider_provides_ && tensor_->op.same_as(op->func)
&& tensor_->value_index == op->value_index) {
Touch(op->args);
......@@ -111,7 +111,7 @@ class FuncTouchedDomain final : public StmtExprVisitor {
const Tensor &tensor_;
bool consider_calls_, consider_provides_;
std::vector<std::vector<IntSet> > bounds_;
std::unordered_map<const Variable*, IntSet> dom_map_;
std::unordered_map<const VarNode*, IntSet> dom_map_;
};
Domain DomainTouched(Stmt stmt, const Tensor &tensor, bool consider_calls, bool consider_provides) {
......
......@@ -47,7 +47,7 @@ inline bool WillOverflow(int64_t x,
}
template<>
inline bool WillOverflow<ir::Add>(int64_t x,
inline bool WillOverflow<ir::AddNode>(int64_t x,
int64_t y,
int64_t min_value,
int64_t max_value) {
......@@ -57,7 +57,7 @@ inline bool WillOverflow<ir::Add>(int64_t x,
}
template<>
inline bool WillOverflow<ir::Sub>(int64_t x,
inline bool WillOverflow<ir::SubNode>(int64_t x,
int64_t y,
int64_t min_value,
int64_t max_value) {
......@@ -67,7 +67,7 @@ inline bool WillOverflow<ir::Sub>(int64_t x,
}
template<>
inline bool WillOverflow<ir::Mul>(int64_t x,
inline bool WillOverflow<ir::MulNode>(int64_t x,
int64_t y,
int64_t min_value,
int64_t max_value) {
......@@ -84,7 +84,7 @@ inline bool WillOverflow<ir::Mul>(int64_t x,
}
template<>
inline bool WillOverflow<ir::Mod>(int64_t x,
inline bool WillOverflow<ir::ModNode>(int64_t x,
int64_t y,
int64_t min_value,
int64_t max_value) {
......
......@@ -30,14 +30,14 @@ namespace arith {
using namespace ir;
Stmt IRMutatorWithAnalyzer::
VisitStmt_(const For* op) {
VisitStmt_(const ForNode* op) {
analyzer_->Bind(op->loop_var,
Range::make_by_min_extent(op->min, op->extent));
return StmtExprMutator::VisitStmt_(op);
}
Stmt IRMutatorWithAnalyzer::
VisitStmt_(const LetStmt* op) {
VisitStmt_(const LetStmtNode* op) {
Expr value = this->VisitExpr(op->value);
if (!ir::HasSideEffect(value)) {
analyzer_->Bind(op->var, value);
......@@ -57,7 +57,7 @@ VisitStmt_(const LetStmt* op) {
}
Stmt IRMutatorWithAnalyzer::
VisitStmt_(const IfThenElse* op) {
VisitStmt_(const IfThenElseNode* op) {
Expr condition = this->VisitExpr(op->condition);
Stmt then_case, else_case;
{
......@@ -66,7 +66,7 @@ VisitStmt_(const IfThenElse* op) {
}
if (op->else_case.defined()) {
With<ConstraintContext> ctx(analyzer_,
analyzer_->rewrite_simplify(Not::make(condition)));
analyzer_->rewrite_simplify(NotNode::make(condition)));
else_case = this->VisitStmt(op->else_case);
}
if (is_one(condition)) return then_case;
......@@ -74,7 +74,7 @@ VisitStmt_(const IfThenElse* op) {
if (else_case.defined()) {
return else_case;
}
return Evaluate::make(0);
return EvaluateNode::make(0);
}
if (condition.same_as(op->condition) &&
......@@ -91,7 +91,7 @@ VisitStmt_(const IfThenElse* op) {
}
Stmt IRMutatorWithAnalyzer::
VisitStmt_(const AttrStmt* op) {
VisitStmt_(const AttrStmtNode* op) {
if (op->attr_key == attr::thread_extent ||
op->attr_key == attr::virtual_thread) {
IterVar iv = Downcast<IterVar>(op->node);
......@@ -106,7 +106,7 @@ VisitStmt_(const AttrStmt* op) {
}
Stmt IRMutatorWithAnalyzer::
VisitStmt_(const AssertStmt* op) {
VisitStmt_(const AssertStmtNode* op) {
Expr condition = this->VisitExpr(op->condition);
Expr message = this->VisitExpr(op->message);
With<ConstraintContext> ctx(analyzer_, condition);
......@@ -126,7 +126,7 @@ VisitStmt_(const AssertStmt* op) {
}
Expr IRMutatorWithAnalyzer::
VisitExpr_(const Call* op) {
VisitExpr_(const CallNode* op) {
// add condition context to if_then_else
if (op->is_intrinsic(ir::intrinsic::tvm_if_then_else)) {
Expr cond = this->VisitExpr(op->args[0]);
......@@ -137,7 +137,7 @@ VisitExpr_(const Call* op) {
}
{
With<ConstraintContext> constraint(analyzer_,
analyzer_->rewrite_simplify(Not::make(cond)));
analyzer_->rewrite_simplify(NotNode::make(cond)));
false_value = this->VisitExpr(op->args[2]);
}
if (is_zero(cond)) {
......@@ -151,7 +151,7 @@ VisitExpr_(const Call* op) {
false_value.same_as(op->args[2])) {
return GetRef<Expr>(op);
} else {
return Call::make(op->dtype, op->name,
return CallNode::make(op->dtype, op->name,
{cond, true_value, false_value},
op->call_type);
}
......@@ -160,7 +160,7 @@ VisitExpr_(const Call* op) {
}
Expr IRMutatorWithAnalyzer::
VisitExpr_(const Let* op) {
VisitExpr_(const LetNode* op) {
Expr value = this->VisitExpr(op->value);
if (!ir::HasSideEffect(value)) {
analyzer_->Bind(op->var, value);
......@@ -172,12 +172,12 @@ VisitExpr_(const Let* op) {
body.same_as(op->body)) {
return GetRef<Expr>(op);
} else {
return Let::make(op->var, value, body);
return LetNode::make(op->var, value, body);
}
}
Expr IRMutatorWithAnalyzer::
VisitExpr_(const Select* op) {
VisitExpr_(const SelectNode* op) {
Expr cond = this->VisitExpr(op->condition);
Expr true_value, false_value;
{
......@@ -186,7 +186,7 @@ VisitExpr_(const Select* op) {
}
{
With<ConstraintContext> constraint(analyzer_,
analyzer_->rewrite_simplify(Not::make(cond)));
analyzer_->rewrite_simplify(NotNode::make(cond)));
false_value = VisitExpr(op->false_value);
}
if (is_zero(cond)) {
......@@ -201,12 +201,12 @@ VisitExpr_(const Select* op) {
false_value.same_as(op->false_value)) {
return GetRef<Expr>(op);
} else {
return Select::make(cond, true_value, false_value);
return SelectNode::make(cond, true_value, false_value);
}
}
Expr IRMutatorWithAnalyzer::
VisitExpr_(const Reduce* op) {
VisitExpr_(const ReduceNode* op) {
// Setup the domain information before simplification.
for (const IterVar& iv : op->axis) {
analyzer_->Bind(iv->var, iv->dom);
......
......@@ -49,15 +49,15 @@ class IRMutatorWithAnalyzer : public ir::StmtExprMutator {
using StmtExprMutator::VisitExpr_;
// override functions that need to populate the context information.
Stmt VisitStmt_(const ir::For* op) override;
Stmt VisitStmt_(const ir::LetStmt* op) override;
Stmt VisitStmt_(const ir::IfThenElse* op) override;
Stmt VisitStmt_(const ir::AttrStmt* op) override;
Stmt VisitStmt_(const ir::AssertStmt* op) override;
Expr VisitExpr_(const ir::Let* op) override;
Expr VisitExpr_(const ir::Select* op) override;
Expr VisitExpr_(const ir::Call* op) override;
Expr VisitExpr_(const ir::Reduce* op) override;
Stmt VisitStmt_(const ir::ForNode* op) override;
Stmt VisitStmt_(const ir::LetStmtNode* op) override;
Stmt VisitStmt_(const ir::IfThenElseNode* op) override;
Stmt VisitStmt_(const ir::AttrStmtNode* op) override;
Stmt VisitStmt_(const ir::AssertStmtNode* op) override;
Expr VisitExpr_(const ir::LetNode* op) override;
Expr VisitExpr_(const ir::SelectNode* op) override;
Expr VisitExpr_(const ir::CallNode* op) override;
Expr VisitExpr_(const ir::ReduceNode* op) override;
protected:
/*! \brief internal analyzer field. */
......
......@@ -38,13 +38,13 @@ class IRVisitorWithAnalyzer final : public StmtExprVisitor {
return analyzer_.Simplify(expr);
}
void VisitStmt_(const For* op) {
void VisitStmt_(const ForNode* op) {
analyzer_.Bind(op->loop_var,
Range::make_by_min_extent(op->min, op->extent));
return StmtExprVisitor::VisitStmt_(op);
}
void VisitStmt_(const AttrStmt* op) {
void VisitStmt_(const AttrStmtNode* op) {
if (op->attr_key == attr::thread_extent ||
op->attr_key == attr::virtual_thread) {
IterVar iv = Downcast<IterVar>(op->node);
......@@ -57,7 +57,7 @@ class IRVisitorWithAnalyzer final : public StmtExprVisitor {
}
}
void VisitExpr_(const Reduce* op) {
void VisitExpr_(const ReduceNode* op) {
// Setup the domain information before simplification.
for (const IterVar& iv : op->axis) {
analyzer_.Bind(iv->var, iv->dom);
......
......@@ -124,15 +124,15 @@ class ModularSetAnalyzer::Impl :
return Everything();
}
Entry VisitExpr_(const Cast* op) final {
Entry VisitExpr_(const CastNode* op) final {
return VisitExpr(op->value);
}
Entry VisitExpr_(const IntImm* op) final {
Entry VisitExpr_(const IntImmNode* op) final {
return Entry(0, op->value);
}
Entry VisitExpr_(const UIntImm* op) final {
Entry VisitExpr_(const UIntImmNode* op) final {
if (op->value < std::numeric_limits<int64_t>::max()) {
return Entry(0, static_cast<int>(op->value));
} else {
......@@ -140,21 +140,21 @@ class ModularSetAnalyzer::Impl :
}
}
Entry VisitExpr_(const Add* op) final {
Entry VisitExpr_(const AddNode* op) final {
Entry a = VisitExpr(op->a);
Entry b = VisitExpr(op->b);
int64_t coeff = ZeroAwareGCD(a.coeff, b.coeff);
return Entry(coeff, a.base + b.base);
}
Entry VisitExpr_(const Sub* op) final {
Entry VisitExpr_(const SubNode* op) final {
Entry a = VisitExpr(op->a);
Entry b = VisitExpr(op->b);
int64_t coeff = ZeroAwareGCD(a.coeff, b.coeff);
return Entry(coeff, a.base - b.base);
}
Entry VisitExpr_(const Mul* op) final {
Entry VisitExpr_(const MulNode* op) final {
Entry a = VisitExpr(op->a);
Entry b = VisitExpr(op->b);
// Simplification rule, x, y, z are in Z
......@@ -188,7 +188,7 @@ class ModularSetAnalyzer::Impl :
return Everything();
}
Entry VisitExpr_(const Div* op) final {
Entry VisitExpr_(const DivNode* op) final {
Entry b = VisitExpr(op->b);
if (b.is_const()) {
return DivByConst(op->a, b.base, false);
......@@ -196,7 +196,7 @@ class ModularSetAnalyzer::Impl :
return Everything();
}
Entry VisitExpr_(const FloorDiv* op) final {
Entry VisitExpr_(const FloorDivNode* op) final {
Entry b = VisitExpr(op->b);
if (b.is_const()) {
return DivByConst(op->a, b.base, true);
......@@ -204,35 +204,35 @@ class ModularSetAnalyzer::Impl :
return Everything();
}
Entry VisitExpr_(const Min* op) final {
Entry VisitExpr_(const MinNode* op) final {
Entry a = VisitExpr(op->a);
Entry b = VisitExpr(op->b);
return Union(a, b);
}
Entry VisitExpr_(const Max* op) final {
Entry VisitExpr_(const MaxNode* op) final {
Entry a = VisitExpr(op->a);
Entry b = VisitExpr(op->b);
return Union(a, b);
}
Entry VisitExpr_(const Select* op) final {
Entry VisitExpr_(const SelectNode* op) final {
Entry a = VisitExpr(op->true_value);
Entry b = VisitExpr(op->false_value);
return Union(a, b);
}
Entry VisitExpr_(const Call* op) final {
Entry VisitExpr_(const CallNode* op) final {
// only special handle >> which can be
// used for index calculation.
if (op->is_intrinsic(Call::shift_right)) {
if (op->is_intrinsic(CallNode::shift_right)) {
return VisitRightShift(op);
} else {
return Everything();
}
}
Entry VisitExpr_(const Variable* op) final {
Entry VisitExpr_(const VarNode* op) final {
Var v = GetRef<Var>(op);
auto it = var_map_.find(v);
if (it != var_map_.end()) {
......@@ -242,7 +242,7 @@ class ModularSetAnalyzer::Impl :
}
}
Entry VisitRightShift(const Call* op) {
Entry VisitRightShift(const CallNode* op) {
Entry b = VisitExpr(op->args[1]);
// a c x / c -> a x
if (b.is_const()) {
......
......@@ -283,7 +283,7 @@ class PConstWithTypeLike :
void InitMatch_() const {}
bool Match_(const ObjectRef& node) const {
if (const ir::IntImm* ptr = node.as<ir::IntImm>()) {
if (const ir::IntImmNode* ptr = node.as<ir::IntImmNode>()) {
return ptr->value == value_;
} else {
return false;
......@@ -325,30 +325,30 @@ class PConstWithTypeLike :
// raise ambiguity error for operator overload of / and %
TVM_PATTERN_BINARY_OP_EX(operator/, ir::Div, DivAmbiguityError(a));
TVM_PATTERN_BINARY_OP_EX(operator%, ir::Mod, DivAmbiguityError(a));
TVM_PATTERN_BINARY_OP_EX(operator/, ir::DivNode, DivAmbiguityError(a));
TVM_PATTERN_BINARY_OP_EX(operator%, ir::ModNode, DivAmbiguityError(a));
// arithmetic expressions
TVM_PATTERN_BINARY_OP(operator+, ir::Add);
TVM_PATTERN_BINARY_OP(operator-, ir::Sub);
TVM_PATTERN_BINARY_OP(operator*, ir::Mul);
TVM_PATTERN_BINARY_OP(min, ir::Min);
TVM_PATTERN_BINARY_OP(max, ir::Max);
TVM_PATTERN_BINARY_OP(div, ir::Div);
TVM_PATTERN_BINARY_OP(truncdiv, ir::Div);
TVM_PATTERN_BINARY_OP(truncmod, ir::Mod);
TVM_PATTERN_BINARY_OP(floordiv, ir::FloorDiv);
TVM_PATTERN_BINARY_OP(floormod, ir::FloorMod);
TVM_PATTERN_BINARY_OP(operator+, ir::AddNode);
TVM_PATTERN_BINARY_OP(operator-, ir::SubNode);
TVM_PATTERN_BINARY_OP(operator*, ir::MulNode);
TVM_PATTERN_BINARY_OP(min, ir::MinNode);
TVM_PATTERN_BINARY_OP(max, ir::MaxNode);
TVM_PATTERN_BINARY_OP(div, ir::DivNode);
TVM_PATTERN_BINARY_OP(truncdiv, ir::DivNode);
TVM_PATTERN_BINARY_OP(truncmod, ir::ModNode);
TVM_PATTERN_BINARY_OP(floordiv, ir::FloorDivNode);
TVM_PATTERN_BINARY_OP(floormod, ir::FloorModNode);
// logical expressions
TVM_PATTERN_BINARY_OP(operator>, ir::GT);
TVM_PATTERN_BINARY_OP(operator>=, ir::GE);
TVM_PATTERN_BINARY_OP(operator<, ir::LT);
TVM_PATTERN_BINARY_OP(operator<=, ir::LE);
TVM_PATTERN_BINARY_OP(operator==, ir::EQ);
TVM_PATTERN_BINARY_OP(operator!=, ir::NE);
TVM_PATTERN_BINARY_OP(operator&&, ir::And);
TVM_PATTERN_BINARY_OP(operator||, ir::Or);
TVM_PATTERN_BINARY_OP(operator>, ir::GTNode);
TVM_PATTERN_BINARY_OP(operator>=, ir::GENode);
TVM_PATTERN_BINARY_OP(operator<, ir::LTNode);
TVM_PATTERN_BINARY_OP(operator<=, ir::LENode);
TVM_PATTERN_BINARY_OP(operator==, ir::EQNode);
TVM_PATTERN_BINARY_OP(operator!=, ir::NENode);
TVM_PATTERN_BINARY_OP(operator&&, ir::AndNode);
TVM_PATTERN_BINARY_OP(operator||, ir::OrNode);
/*!
* \brief Pattern not expression.
......@@ -365,7 +365,7 @@ class PNotExpr : public Pattern<PNotExpr<TA> > {
}
bool Match_(const ObjectRef& node) const {
if (const ir::Not* ptr = node.as<ir::Not>()) {
if (const ir::NotNode* ptr = node.as<ir::NotNode>()) {
if (!value_.Match_(ptr->a)) return false;
return true;
} else {
......@@ -374,7 +374,7 @@ class PNotExpr : public Pattern<PNotExpr<TA> > {
}
Expr Eval() const {
return ir::Not::make(value_.Eval());
return ir::NotNode::make(value_.Eval());
}
private:
......@@ -411,7 +411,7 @@ class PSelectExpr :
}
bool Match_(const ObjectRef& node) const {
if (const ir::Select* ptr = node.as<ir::Select>()) {
if (const ir::SelectNode* ptr = node.as<ir::SelectNode>()) {
if (!condition_.Match_(ptr->condition)) return false;
if (!true_value_.Match_(ptr->true_value)) return false;
if (!false_value_.Match_(ptr->false_value)) return false;
......@@ -422,7 +422,7 @@ class PSelectExpr :
}
Expr Eval() const {
return ir::Select::make(
return ir::SelectNode::make(
condition_.Eval(), true_value_.Eval(), false_value_.Eval());
}
......@@ -473,7 +473,7 @@ class PCastExpr :
}
bool Match_(const ObjectRef& node) const {
if (const ir::Cast* ptr = node.as<ir::Cast>()) {
if (const ir::CastNode* ptr = node.as<ir::CastNode>()) {
if (!dtype_.Match_(ptr->dtype)) return false;
if (!value_.Match_(ptr->value)) return false;
return true;
......@@ -483,7 +483,7 @@ class PCastExpr :
}
Expr Eval() const {
return ir::Cast::make(dtype_.Eval(), value_.Eval());
return ir::CastNode::make(dtype_.Eval(), value_.Eval());
}
private:
......@@ -531,7 +531,7 @@ class PRampExpr :
}
bool Match_(const ObjectRef& node) const {
if (const ir::Ramp* ptr = node.as<ir::Ramp>()) {
if (const ir::RampNode* ptr = node.as<ir::RampNode>()) {
if (!base_.Match_(ptr->base)) return false;
if (!stride_.Match_(ptr->stride)) return false;
if (!lanes_.Match_(ptr->lanes)) return false;
......@@ -542,7 +542,7 @@ class PRampExpr :
}
Expr Eval() const {
return ir::Ramp::make(base_.Eval(), stride_.Eval(), lanes_.Eval());
return ir::RampNode::make(base_.Eval(), stride_.Eval(), lanes_.Eval());
}
private:
......@@ -593,7 +593,7 @@ class PBroadcastExpr :
}
bool Match_(const ObjectRef& node) const {
if (const ir::Broadcast* ptr = node.as<ir::Broadcast>()) {
if (const ir::BroadcastNode* ptr = node.as<ir::BroadcastNode>()) {
if (!value_.Match_(ptr->value)) return false;
if (!lanes_.Match_(ptr->lanes)) return false;
return true;
......@@ -603,7 +603,7 @@ class PBroadcastExpr :
}
Expr Eval() const {
return ir::Broadcast::make(value_.Eval(), lanes_.Eval());
return ir::BroadcastNode::make(value_.Eval(), lanes_.Eval());
}
private:
......@@ -662,10 +662,10 @@ struct PCallExprInitMatchFunctor {
};
struct PCallExprMatchFunctor {
const ir::Call* call_;
const ir::CallNode* call_;
bool matched_{true};
explicit PCallExprMatchFunctor(const ir::Call* call)
explicit PCallExprMatchFunctor(const ir::CallNode* call)
: call_(call) {}
template<typename T>
......@@ -705,7 +705,7 @@ class PCallExpr :
}
bool Match_(const ObjectRef& node) const {
if (const ir::Call* ptr = node.as<ir::Call>()) {
if (const ir::CallNode* ptr = node.as<ir::CallNode>()) {
if (ptr->args.size() != sizeof...(TArgs)) return false;
if (ptr->name != Op::kName) return false;
detail::PCallExprMatchFunctor fmatch(ptr);
......@@ -730,8 +730,8 @@ class PCallExpr :
#define TVM_PATTERN_BINARY_INTRIN(FuncName, OpName, IntrinStr) \
struct OpName { \
static Expr Eval(Array<Expr> args) { \
return ir::Call::make(args[0].dtype(), kName, args, \
ir::Call::PureIntrinsic); \
return ir::CallNode::make(args[0].dtype(), kName, args, \
ir::CallNode::PureIntrinsic); \
} \
static constexpr const char* kName = IntrinStr; \
}; \
......@@ -751,8 +751,8 @@ TVM_PATTERN_BINARY_INTRIN(operator^, PBitwiseXorOp, "bitwise_xor");
#define TVM_PATTERN_UNARY_INTRIN(FuncName, OpName, IntrinStr) \
struct OpName { \
static Expr Eval(Array<Expr> args) { \
return ir::Call::make(args[0].dtype(), kName, args, \
ir::Call::PureIntrinsic); \
return ir::CallNode::make(args[0].dtype(), kName, args, \
ir::CallNode::PureIntrinsic); \
} \
static constexpr const char* kName = IntrinStr; \
}; \
......@@ -767,9 +767,9 @@ TVM_PATTERN_UNARY_INTRIN(operator~, PBitwiseNotOp, "bitwise_not");
// if_then_else
struct PIfThenElseOp {
static Expr Eval(Array<Expr> args) {
return ir::Call::make(
return ir::CallNode::make(
args[1].dtype(), kName, args,
ir::Call::PureIntrinsic);
ir::CallNode::PureIntrinsic);
}
static constexpr const char* kName = "tvm_if_then_else";
};
......
......@@ -50,29 +50,29 @@ class RewriteSimplifier::Impl : public IRMutatorWithAnalyzer {
: IRMutatorWithAnalyzer(parent) {}
void Update(const Var& var, const Expr& info, bool override_info);
Expr VisitExpr_(const Add* op) override;
Expr VisitExpr_(const Sub* op) override;
Expr VisitExpr_(const Mul* op) override;
Expr VisitExpr_(const Div* op) override;
Expr VisitExpr_(const Mod* op) override;
Expr VisitExpr_(const FloorDiv* op) override;
Expr VisitExpr_(const FloorMod* op) override;
Expr VisitExpr_(const Min* op) override;
Expr VisitExpr_(const Max* op) override;
Expr VisitExpr_(const EQ* op) override;
Expr VisitExpr_(const NE* op) override;
Expr VisitExpr_(const LT* op) override;
Expr VisitExpr_(const LE* op) override;
Expr VisitExpr_(const GT* op) override;
Expr VisitExpr_(const GE* op) override;
Expr VisitExpr_(const And* op) override;
Expr VisitExpr_(const Or* op) override;
Expr VisitExpr_(const Not* op) override;
Expr VisitExpr_(const Select* op) override;
Expr VisitExpr_(const Call* op) override;
Expr VisitExpr_(const Variable* op) override;
Expr VisitExpr_(const Cast* op) override;
Expr VisitExpr_(const Let* op) override;
Expr VisitExpr_(const AddNode* op) override;
Expr VisitExpr_(const SubNode* op) override;
Expr VisitExpr_(const MulNode* op) override;
Expr VisitExpr_(const DivNode* op) override;
Expr VisitExpr_(const ModNode* op) override;
Expr VisitExpr_(const FloorDivNode* op) override;
Expr VisitExpr_(const FloorModNode* op) override;
Expr VisitExpr_(const MinNode* op) override;
Expr VisitExpr_(const MaxNode* op) override;
Expr VisitExpr_(const EQNode* op) override;
Expr VisitExpr_(const NENode* op) override;
Expr VisitExpr_(const LTNode* op) override;
Expr VisitExpr_(const LENode* op) override;
Expr VisitExpr_(const GTNode* op) override;
Expr VisitExpr_(const GENode* op) override;
Expr VisitExpr_(const AndNode* op) override;
Expr VisitExpr_(const OrNode* op) override;
Expr VisitExpr_(const NotNode* op) override;
Expr VisitExpr_(const SelectNode* op) override;
Expr VisitExpr_(const CallNode* op) override;
Expr VisitExpr_(const VarNode* op) override;
Expr VisitExpr_(const CastNode* op) override;
Expr VisitExpr_(const LetNode* op) override;
std::function<void()> EnterConstraint(const Expr& constraint);
......
......@@ -50,14 +50,14 @@ class StmtSimplifier : public IRMutatorWithAnalyzer {
return operator()(std::move(stmt));
}
Stmt VisitStmt_(const For* op) final {
Stmt VisitStmt_(const ForNode* op) final {
analyzer_->Bind(op->loop_var, Range::make_by_min_extent(op->min, op->extent));
With<ConstraintContext> ctx1(analyzer_, op->loop_var >= op->min);
With<ConstraintContext> ctx2(analyzer_, op->loop_var < op->min + op->extent);
return Parent::VisitStmt_(op);
}
Stmt VisitStmt_(const LetStmt* op) {
Stmt VisitStmt_(const LetStmtNode* op) {
Expr value = this->VisitExpr(op->value);
if (!ir::HasSideEffect(value)) {
// it is fine to discard the let binding
......@@ -78,13 +78,13 @@ class StmtSimplifier : public IRMutatorWithAnalyzer {
}
// eliminate useless stores
Stmt VisitStmt_(const Store* op) final {
Stmt VisitStmt_(const StoreNode* op) final {
Stmt stmt = Parent::VisitStmt_(op);
op = stmt.as<Store>();
if (const Load* load = op->value.as<Load>()) {
op = stmt.as<StoreNode>();
if (const LoadNode* load = op->value.as<LoadNode>()) {
if (load->buffer_var.same_as(op->buffer_var) &&
Equal(load->index, op->index)) {
return Evaluate::make(0);
return EvaluateNode::make(0);
}
}
return GetRef<Stmt>(op);
......
......@@ -29,8 +29,8 @@ namespace tvm {
namespace autotvm {
// for loop
void FeatureVisitor::VisitStmt_(const For* op) {
const auto *extent = op->extent.as<IntImm>();
void FeatureVisitor::VisitStmt_(const ForNode* op) {
const auto *extent = op->extent.as<IntImmNode>();
int64_t loop_extent = -1;
if (extent != nullptr)
loop_extent = extent->value;
......@@ -57,11 +57,11 @@ void FeatureVisitor::VisitStmt_(const For* op) {
}
// parallel axis, virtual thread
void FeatureVisitor::VisitStmt_(const AttrStmt* op) {
void FeatureVisitor::VisitStmt_(const AttrStmtNode* op) {
if (op->attr_key == attr::thread_extent ||
op->attr_key == attr::virtual_thread) {
VarExpr var = op->node.as<tvm::IterVarNode>()->var;
const auto *extent = op->value.as<IntImm>();
const auto *extent = op->value.as<IntImmNode>();
CHECK(extent);
std::string name = var.get()->name_hint;
......@@ -95,13 +95,13 @@ void FeatureVisitor::VisitStmt_(const AttrStmt* op) {
}
// memory access
void FeatureVisitor::VisitExpr_(const Load* op) {
void FeatureVisitor::VisitExpr_(const LoadNode* op) {
EnterMem_(op->buffer_var, op->index);
StmtExprVisitor::VisitExpr_(op);
ExitMem_();
}
void FeatureVisitor::VisitStmt_(const Store* op) {
void FeatureVisitor::VisitStmt_(const StoreNode* op) {
EnterMem_(op->buffer_var, op->index);
StmtExprVisitor::VisitStmt_(op);
ExitMem_();
......
......@@ -51,12 +51,12 @@ enum AnnotationType {
class FeatureVisitor : public StmtExprVisitor {
public:
// for loop
void VisitStmt_(const For* op) final;
void VisitStmt_(const AttrStmt* op) final;
void VisitStmt_(const ForNode* op) final;
void VisitStmt_(const AttrStmtNode* op) final;
// memory access
void VisitExpr_(const Load* op) final;
void VisitStmt_(const Store* op) final;
void VisitExpr_(const LoadNode* op) final;
void VisitStmt_(const StoreNode* op) final;
using StmtExprVisitor::VisitStmt_;
using StmtExprVisitor::VisitExpr_;
......
......@@ -51,7 +51,7 @@ class IndexParser: public ExprVisitor {
this->VisitExpr(expr);
}
void VisitExpr_(const Variable* op) final {
void VisitExpr_(const VarNode* op) final {
// TODO(lmzheng): handle more index types (multiple occurrence)
if (pattern_map.count(op) == 0) {
pattern_map[op] = TouchPattern();
......@@ -60,16 +60,16 @@ class IndexParser: public ExprVisitor {
}
}
void VisitExpr_(const Mul* op) final {
if (op->a.as<Variable>()) {
if (const auto stride = op->b.as<IntImm>()) {
void VisitExpr_(const MulNode* op) final {
if (op->a.as<VarNode>()) {
if (const auto stride = op->b.as<IntImmNode>()) {
next_stride_ = stride->value;
}
}
ExprVisitor::VisitExpr_(op);
}
std::unordered_map<const Variable*, TouchPattern> pattern_map;
std::unordered_map<const VarNode*, TouchPattern> pattern_map;
private:
int64_t next_stride_ = 1;
......@@ -255,10 +255,10 @@ void GetItervarFeature(Stmt stmt, bool take_log, Array<Array<Array<Expr> > > *re
feature_row.push_back(Array<Expr>{std::string("_itervar_"), var});
Array<Expr> attr{std::string("_attr_"),
FloatImm::make(DataType::Float(32), trans(fea.length)),
IntImm::make(DataType::Int(32), fea.nest_level),
FloatImm::make(DataType::Float(32), trans(fea.topdown_product)),
FloatImm::make(DataType::Float(32), trans(fea.bottomup_product)),
FloatImmNode::make(DataType::Float(32), trans(fea.length)),
IntImmNode::make(DataType::Int(32), fea.nest_level),
FloatImmNode::make(DataType::Float(32), trans(fea.topdown_product)),
FloatImmNode::make(DataType::Float(32), trans(fea.bottomup_product)),
};
// one hot annotation
for (int i = 0; i < kNum; i++) {
......@@ -268,9 +268,9 @@ void GetItervarFeature(Stmt stmt, bool take_log, Array<Array<Array<Expr> > > *re
// arithmetic
feature_row.push_back(Array<Expr>{std::string("_arith_"),
FloatImm::make(DataType::Float(32), trans(fea.add_ct)),
FloatImm::make(DataType::Float(32), trans(fea.mul_ct)),
FloatImm::make(DataType::Float(32), trans(fea.div_ct)),
FloatImmNode::make(DataType::Float(32), trans(fea.add_ct)),
FloatImmNode::make(DataType::Float(32), trans(fea.mul_ct)),
FloatImmNode::make(DataType::Float(32), trans(fea.div_ct)),
});
// touch map
......@@ -281,13 +281,14 @@ void GetItervarFeature(Stmt stmt, bool take_log, Array<Array<Array<Expr> > > *re
std::sort(bufs.begin(), bufs.end());
for (auto k : bufs) {
TouchPattern &v = fea.touch_feature[k];
feature_row.push_back(Array<Expr>{k,
FloatImm::make(DataType::Float(32), trans(v.stride)),
FloatImm::make(DataType::Float(32), trans(v.mod)),
FloatImm::make(DataType::Float(32), trans(v.count)),
FloatImm::make(DataType::Float(32), trans(v.reuse)),
FloatImm::make(DataType::Float(32), trans(v.thread_count)),
FloatImm::make(DataType::Float(32), trans(v.thread_reuse)),
feature_row.push_back(
Array<Expr>{k,
FloatImmNode::make(DataType::Float(32), trans(v.stride)),
FloatImmNode::make(DataType::Float(32), trans(v.mod)),
FloatImmNode::make(DataType::Float(32), trans(v.count)),
FloatImmNode::make(DataType::Float(32), trans(v.reuse)),
FloatImmNode::make(DataType::Float(32), trans(v.thread_count)),
FloatImmNode::make(DataType::Float(32), trans(v.thread_reuse)),
});
}
......
......@@ -92,31 +92,31 @@ class TouchExtractor : public FeatureVisitor {
}
// arithmetic stats
void VisitExpr_(const Add* op) final {
void VisitExpr_(const AddNode* op) final {
if (op->dtype.is_float())
itervar_map[itervar_stack_.back()].add_ct++;
FeatureVisitor::VisitExpr_(op);
}
void VisitExpr_(const Sub* op) final {
void VisitExpr_(const SubNode* op) final {
if (op->dtype.is_float())
itervar_map[itervar_stack_.back()].add_ct++;
FeatureVisitor::VisitExpr_(op);
}
void VisitExpr_(const Mul* op) final {
void VisitExpr_(const MulNode* op) final {
if (op->dtype.is_float())
itervar_map[itervar_stack_.back()].mul_ct++;
FeatureVisitor::VisitExpr_(op);
}
void VisitExpr_(const Div* op) final {
void VisitExpr_(const DivNode* op) final {
if (op->dtype.is_float())
itervar_map[itervar_stack_.back()].div_ct++;
FeatureVisitor::VisitExpr_(op);
}
void VisitExpr_(const Mod* op) final {
void VisitExpr_(const ModNode* op) final {
if (op->dtype.is_float())
itervar_map[itervar_stack_.back()].div_ct++;
FeatureVisitor::VisitExpr_(op);
......
......@@ -65,39 +65,39 @@ Target CreateTarget(const std::string& target_name,
std::string device_flag = "-device=";
std::string keys_flag = "-keys=";
for (auto& item : options) {
t->options_array.push_back(ir::StringImm::make(item));
t->options_array.push_back(ir::StringImmNode::make(item));
if (item.find(libs_flag) == 0) {
std::stringstream ss(item.substr(libs_flag.length()));
std::string lib_item;
while (std::getline(ss, lib_item, ',')) {
t->libs_array.push_back(ir::StringImm::make(lib_item));
t->libs_array.push_back(ir::StringImmNode::make(lib_item));
}
} else if (item.find(device_flag) == 0) {
t->device_name = item.substr(device_flag.length());
t->keys_array.push_back(ir::StringImm::make(t->device_name));
t->keys_array.push_back(ir::StringImmNode::make(t->device_name));
} else if (item.find(keys_flag) == 0) {
std::stringstream ss(item.substr(keys_flag.length()));
std::string key_item;
while (std::getline(ss, key_item, ',')) {
t->keys_array.push_back(ir::StringImm::make(key_item));
t->keys_array.push_back(ir::StringImmNode::make(key_item));
}
}
}
if (t->device_name.length() > 0) {
t->keys_array.push_back(ir::StringImm::make(t->device_name));
t->keys_array.push_back(ir::StringImmNode::make(t->device_name));
}
t->device_type = kDLCPU;
t->thread_warp_size = 1;
if (target_name == "c" && t->device_name == "micro_dev") {
t->device_type = kDLMicroDev;
} else if (target_name == "c" || target_name == "llvm") {
t->keys_array.push_back(ir::StringImm::make("cpu"));
t->keys_array.push_back(ir::StringImmNode::make("cpu"));
} else if (target_name == "cuda" || target_name == "nvptx") {
t->device_type = kDLGPU;
t->keys_array.push_back(ir::StringImm::make("cuda"));
t->keys_array.push_back(ir::StringImm::make("gpu"));
t->keys_array.push_back(ir::StringImmNode::make("cuda"));
t->keys_array.push_back(ir::StringImmNode::make("gpu"));
t->max_num_threads = 1024;
t->thread_warp_size = 32;
} else if (target_name == "rocm" || target_name == "opencl") {
......@@ -107,8 +107,8 @@ Target CreateTarget(const std::string& target_name,
} else {
t->device_type = kDLROCM;
}
t->keys_array.push_back(ir::StringImm::make(target_name));
t->keys_array.push_back(ir::StringImm::make("gpu"));
t->keys_array.push_back(ir::StringImmNode::make(target_name));
t->keys_array.push_back(ir::StringImmNode::make("gpu"));
t->max_num_threads = 256;
if (t->device_name == "intel_graphics") {
t->thread_warp_size = 16;
......@@ -119,20 +119,20 @@ Target CreateTarget(const std::string& target_name,
} else {
t->device_type = kDLVulkan;
}
t->keys_array.push_back(ir::StringImm::make(target_name));
t->keys_array.push_back(ir::StringImm::make("gpu"));
t->keys_array.push_back(ir::StringImmNode::make(target_name));
t->keys_array.push_back(ir::StringImmNode::make("gpu"));
t->max_num_threads = 256;
} else if (target_name == "sdaccel") {
t->device_type = kDLOpenCL;
t->keys_array.push_back(ir::StringImm::make("sdaccel"));
t->keys_array.push_back(ir::StringImm::make("hls"));
t->keys_array.push_back(ir::StringImmNode::make("sdaccel"));
t->keys_array.push_back(ir::StringImmNode::make("hls"));
} else if (target_name == "aocl" || target_name == "aocl_sw_emu") {
t->device_type = kDLAOCL;
t->keys_array.push_back(ir::StringImm::make("aocl"));
t->keys_array.push_back(ir::StringImm::make("hls"));
t->keys_array.push_back(ir::StringImmNode::make("aocl"));
t->keys_array.push_back(ir::StringImmNode::make("hls"));
} else if (target_name == "opengl") {
t->device_type = kOpenGL;
t->keys_array.push_back(ir::StringImm::make("opengl"));
t->keys_array.push_back(ir::StringImmNode::make("opengl"));
} else if (target_name == "stackvm") {
t->device_type = kDLCPU;
} else if (target_name == "ext_dev") {
......@@ -168,7 +168,7 @@ TVM_REGISTER_GLOBAL("_TargetFromString")
std::vector<std::string> TargetNode::keys() const {
std::vector<std::string> result;
for (auto& expr : keys_array) {
result.push_back(expr.as<ir::StringImm>()->value);
result.push_back(expr.as<ir::StringImmNode>()->value);
}
return result;
}
......@@ -176,7 +176,7 @@ std::vector<std::string> TargetNode::keys() const {
std::vector<std::string> TargetNode::options() const {
std::vector<std::string> result;
for (auto& expr : options_array) {
result.push_back(expr.as<ir::StringImm>()->value);
result.push_back(expr.as<ir::StringImmNode>()->value);
}
return result;
}
......@@ -184,7 +184,7 @@ std::vector<std::string> TargetNode::options() const {
std::unordered_set<std::string> TargetNode::libs() const {
std::unordered_set<std::string> result;
for (auto& expr : libs_array) {
result.insert(expr.as<ir::StringImm>()->value);
result.insert(expr.as<ir::StringImmNode>()->value);
}
return result;
}
......@@ -348,7 +348,7 @@ Buffer BufferWithOffsetAlignment(Array<Expr> shape,
bool has_any = false;
if (!compact) {
for (const auto& it : shape) {
if (it.as<Variable>()) {
if (it.as<VarNode>()) {
has_any = true;
break;
}
......@@ -860,7 +860,7 @@ TVM_REGISTER_GLOBAL("_GenericFuncRegisterFunc")
std::vector<std::string> tags_vector;
for (auto& tag : tags) {
tags_vector.push_back(tag.as<tvm::ir::StringImm>()->value);
tags_vector.push_back(tag.as<tvm::ir::StringImmNode>()->value);
}
generic_func
......
......@@ -102,46 +102,46 @@ class CodeGenC :
*/
virtual void InitFuncState(LoweredFunc f);
// expression
void VisitExpr_(const Variable* op, std::ostream& os) override; // NOLINT(*)
void VisitExpr_(const Load* op, std::ostream& os) override; // NOLINT(*)
void VisitExpr_(const Let* op, std::ostream& os) override; // NOLINT(*)
void VisitExpr_(const Call* op, std::ostream& os) override; // NOLINT(*)
void VisitExpr_(const Add* op, std::ostream& os) override; // NOLINT(*)
void VisitExpr_(const Sub* op, std::ostream& os) override; // NOLINT(*)
void VisitExpr_(const Mul* op, std::ostream& os) override; // NOLINT(*)
void VisitExpr_(const Div* op, std::ostream& os) override; // NOLINT(*)
void VisitExpr_(const Mod* op, std::ostream& os) override; // NOLINT(*)
void VisitExpr_(const Min* op, std::ostream& os) override; // NOLINT(*)
void VisitExpr_(const Max* op, std::ostream& os) override; // NOLINT(*)
void VisitExpr_(const EQ* op, std::ostream& os) override; // NOLINT(*)
void VisitExpr_(const NE* op, std::ostream& os) override; // NOLINT(*)
void VisitExpr_(const LT* op, std::ostream& os) override; // NOLINT(*)
void VisitExpr_(const LE* op, std::ostream& os) override; // NOLINT(*)
void VisitExpr_(const GT* op, std::ostream& os) override; // NOLINT(*)
void VisitExpr_(const GE* op, std::ostream& os) override; // NOLINT(*)
void VisitExpr_(const And* op, std::ostream& os) override; // NOLINT(*)
void VisitExpr_(const Or* op, std::ostream& os) override; // NOLINT(*)
void VisitExpr_(const Cast* op, std::ostream& os) override; // NOLINT(*)
void VisitExpr_(const Not* op, std::ostream& os) override; // NOLINT(*)
void VisitExpr_(const Select* op, std::ostream& os) override; // NOLINT(*)
void VisitExpr_(const Ramp* op, std::ostream& os) override; // NOLINT(*)
void VisitExpr_(const Shuffle* op, std::ostream& os) override; // NOLINT(*)
void VisitExpr_(const Broadcast* op, std::ostream& os) override; // NOLINT(*)
void VisitExpr_(const IntImm* op, std::ostream& os) override; // NOLINT(*)
void VisitExpr_(const UIntImm* op, std::ostream& os) override; // NOLINT(*)
void VisitExpr_(const FloatImm* op, std::ostream& os) override; // NOLINT(*)
void VisitExpr_(const StringImm* op, std::ostream& os) override; // NOLINT(*)
void VisitExpr_(const VarNode* op, std::ostream& os) override; // NOLINT(*)
void VisitExpr_(const LoadNode* op, std::ostream& os) override; // NOLINT(*)
void VisitExpr_(const LetNode* op, std::ostream& os) override; // NOLINT(*)
void VisitExpr_(const CallNode* op, std::ostream& os) override; // NOLINT(*)
void VisitExpr_(const AddNode* op, std::ostream& os) override; // NOLINT(*)
void VisitExpr_(const SubNode* op, std::ostream& os) override; // NOLINT(*)
void VisitExpr_(const MulNode* op, std::ostream& os) override; // NOLINT(*)
void VisitExpr_(const DivNode* op, std::ostream& os) override; // NOLINT(*)
void VisitExpr_(const ModNode* op, std::ostream& os) override; // NOLINT(*)
void VisitExpr_(const MinNode* op, std::ostream& os) override; // NOLINT(*)
void VisitExpr_(const MaxNode* op, std::ostream& os) override; // NOLINT(*)
void VisitExpr_(const EQNode* op, std::ostream& os) override; // NOLINT(*)
void VisitExpr_(const NENode* op, std::ostream& os) override; // NOLINT(*)
void VisitExpr_(const LTNode* op, std::ostream& os) override; // NOLINT(*)
void VisitExpr_(const LENode* op, std::ostream& os) override; // NOLINT(*)
void VisitExpr_(const GTNode* op, std::ostream& os) override; // NOLINT(*)
void VisitExpr_(const GENode* op, std::ostream& os) override; // NOLINT(*)
void VisitExpr_(const AndNode* op, std::ostream& os) override; // NOLINT(*)
void VisitExpr_(const OrNode* op, std::ostream& os) override; // NOLINT(*)
void VisitExpr_(const CastNode* op, std::ostream& os) override; // NOLINT(*)
void VisitExpr_(const NotNode* op, std::ostream& os) override; // NOLINT(*)
void VisitExpr_(const SelectNode* op, std::ostream& os) override; // NOLINT(*)
void VisitExpr_(const RampNode* op, std::ostream& os) override; // NOLINT(*)
void VisitExpr_(const ShuffleNode* op, std::ostream& os) override; // NOLINT(*)
void VisitExpr_(const BroadcastNode* op, std::ostream& os) override; // NOLINT(*)
void VisitExpr_(const IntImmNode* op, std::ostream& os) override; // NOLINT(*)
void VisitExpr_(const UIntImmNode* op, std::ostream& os) override; // NOLINT(*)
void VisitExpr_(const FloatImmNode* op, std::ostream& os) override; // NOLINT(*)
void VisitExpr_(const StringImmNode* op, std::ostream& os) override; // NOLINT(*)
// statment
void VisitStmt_(const LetStmt* op) override;
void VisitStmt_(const Store* op) override;
void VisitStmt_(const For* op) override;
void VisitStmt_(const IfThenElse* op) override;
void VisitStmt_(const Allocate* op) override;
void VisitStmt_(const AttrStmt* op) override;
void VisitStmt_(const AssertStmt* op) override;
void VisitStmt_(const Evaluate* op) override;
void VisitStmt_(const LetStmtNode* op) override;
void VisitStmt_(const StoreNode* op) override;
void VisitStmt_(const ForNode* op) override;
void VisitStmt_(const IfThenElseNode* op) override;
void VisitStmt_(const AllocateNode* op) override;
void VisitStmt_(const AttrStmtNode* op) override;
void VisitStmt_(const AssertStmtNode* op) override;
void VisitStmt_(const EvaluateNode* op) override;
void VisitStmt_(const SeqStmtNode* op) override;
void VisitStmt_(const ProducerConsumer* op) override;
void VisitStmt_(const ProducerConsumerNode* op) override;
/*!
* Print Type represetnation of type t.
* \param t The type representation.
......@@ -154,15 +154,15 @@ class CodeGenC :
*/
virtual void BindThreadIndex(const IterVar& iv); // NOLINT(*)
virtual void PrintStorageScope(const std::string& scope, std::ostream& os); // NOLINT(*)
virtual void PrintStorageSync(const Call* op); // NOLINT(*)
virtual void PrintStorageSync(const CallNode* op); // NOLINT(*)
// Binary vector op.
virtual void PrintVecBinaryOp(
const std::string&op, DataType op_type,
Expr lhs, Expr rhs, std::ostream& os); // NOLINT(*)
// print vector load
virtual std::string GetVecLoad(DataType t, const Variable* buffer, Expr base);
virtual std::string GetVecLoad(DataType t, const VarNode* buffer, Expr base);
// print vector store
virtual void PrintVecStore(const Variable* buffer,
virtual void PrintVecStore(const VarNode* buffer,
DataType t, Expr base,
const std::string& value); // NOLINT(*)
// print load of single element
......@@ -180,28 +180,28 @@ class CodeGenC :
DataType t, const Expr& buffer, const Expr& index, int kind);
// print reference to a buffer as type t in index.
virtual std::string GetBufferRef(
DataType t, const Variable* buffer, Expr index);
DataType t, const VarNode* buffer, Expr index);
/*!
* \brief If buffer is allocated as type t.
* \param buf_var The buffer variable.
* \param t The type to be checked.
*/
bool HandleTypeMatch(const Variable* buf_var, DataType t) const;
bool HandleTypeMatch(const VarNode* buf_var, DataType t) const;
/*!
* \brief Register the data type of buf_var
* \param buf_var The buffer variable.
* \param t The type to be checked.
*/
void RegisterHandleType(const Variable* buf_var, DataType t);
void RegisterHandleType(const VarNode* buf_var, DataType t);
// override
void PrintSSAAssign(
const std::string& target, const std::string& src, DataType t) final;
/*! \brief restrict keyword */
std::string restrict_keyword_{""};
/*! \brief the storage scope of allocation */
std::unordered_map<const Variable*, std::string> alloc_storage_scope_;
std::unordered_map<const VarNode*, std::string> alloc_storage_scope_;
/*! \brief the data type of allocated buffers */
std::unordered_map<const Variable*, DataType> handle_data_type_;
std::unordered_map<const VarNode*, DataType> handle_data_type_;
/*! \brief reserves common C keywords */
void ReserveKeywordsAsUnique();
......@@ -209,7 +209,7 @@ class CodeGenC :
/*! \brief whether to print in SSA form */
bool print_ssa_form_{false};
/*! \brief set of volatile buf access */
std::unordered_set<const Variable*> volatile_buf_;
std::unordered_set<const VarNode*> volatile_buf_;
};
} // namespace codegen
......
......@@ -142,7 +142,7 @@ void CodeGenCHost::PrintType(DataType t, std::ostream& os) { // NOLINT(*)
LOG(FATAL) << "Cannot convert type " << t << " to C type";
}
void CodeGenCHost::VisitExpr_(const Broadcast* op, std::ostream& os) { // NOLINT(*)
void CodeGenCHost::VisitExpr_(const BroadcastNode* op, std::ostream& os) { // NOLINT(*)
std::string v = PrintExpr(op->value);
os << "((";
PrintType(op->dtype, os);
......@@ -194,11 +194,11 @@ void CodeGenCHost::PrintFuncCall(const std::string& packed_func_name, int num_ar
this->stream << "}\n";
}
void CodeGenCHost::VisitExpr_(const Call *op, std::ostream& os) { // NOLINT(*)
void CodeGenCHost::VisitExpr_(const CallNode *op, std::ostream& os) { // NOLINT(*)
if (op->is_intrinsic(intrinsic::tvm_stack_alloca)) {
std::string stack_name = GetUniqueName("stack");
const std::string& type = op->args[0].as<StringImm>()->value;
const IntImm* num = op->args[1].as<IntImm>();
const std::string& type = op->args[0].as<StringImmNode>()->value;
const IntImmNode* num = op->args[1].as<IntImmNode>();
CHECK(num != nullptr);
static_assert(alignof(TVMValue) % alignof(TVMArray) == 0, "invariant");
size_t unit = sizeof(TVMValue);
......@@ -218,10 +218,10 @@ void CodeGenCHost::VisitExpr_(const Call *op, std::ostream& os) { // NOLINT(*)
this->stream << "TVMValue " << stack_name << "[" << size << "];\n";
os << stack_name;
} else if (op->is_intrinsic(intrinsic::tvm_call_packed_lowered)) {
const StringImm* s = op->args[0].as<StringImm>();
const StringImmNode* s = op->args[0].as<StringImmNode>();
CHECK(s != nullptr) << "tvm_call_packed_lowered expects first argument as function name";
int64_t begin = op->args[3].as<IntImm>()->value;
int64_t end = op->args[4].as<IntImm>()->value;
int64_t begin = op->args[3].as<IntImmNode>()->value;
int64_t end = op->args[4].as<IntImmNode>()->value;
int64_t num_args = end - begin;
CHECK_GE(num_args, 0);
std::string func_name = s->value;
......@@ -237,14 +237,14 @@ void CodeGenCHost::VisitExpr_(const Call *op, std::ostream& os) { // NOLINT(*)
}
}
void CodeGenCHost::VisitStmt_(const AssertStmt *op) { // NOLINT(*)
void CodeGenCHost::VisitStmt_(const AssertStmtNode *op) { // NOLINT(*)
if (emit_asserts_) {
std::string cond = PrintExpr(op->condition);
PrintIndent();
stream << "if (!(" << cond << ")) {\n";
int assert_if_scope = this->BeginScope();
PrintIndent();
stream << "TVMAPISetLastError(\"" << op->message.as<StringImm>()->value << "\");\n";
stream << "TVMAPISetLastError(\"" << op->message.as<StringImmNode>()->value << "\");\n";
PrintIndent();
stream << "return -1;\n";
this->EndScope(assert_if_scope);
......@@ -254,11 +254,11 @@ void CodeGenCHost::VisitStmt_(const AssertStmt *op) { // NOLINT(*)
this->PrintStmt(op->body);
}
void CodeGenCHost::VisitExpr_(const Min *op, std::ostream& os) { // NOLINT(*)
void CodeGenCHost::VisitExpr_(const MinNode *op, std::ostream& os) { // NOLINT(*)
PrintTernaryCondExpr(op, "<", os);
}
void CodeGenCHost::VisitExpr_(const Max *op, std::ostream& os) { // NOLINT(*)
void CodeGenCHost::VisitExpr_(const MaxNode *op, std::ostream& os) { // NOLINT(*)
PrintTernaryCondExpr(op, ">", os);
}
......
......@@ -42,14 +42,14 @@ class CodeGenCHost final : public CodeGenC {
void PrintType(DataType t, std::ostream& os) final; // NOLINT(*)
// overload visitor functions
void VisitExpr_(const Broadcast* op, std::ostream& os) final; // NOLINT(*)
void VisitExpr_(const Call *op, std::ostream& os) final; // NOLINT(*)
void VisitExpr_(const BroadcastNode* op, std::ostream& os) final; // NOLINT(*)
void VisitExpr_(const CallNode *op, std::ostream& os) final; // NOLINT(*)
// overload min and max to use the ternary operator, so we don't rely on the
// standard library implementations
void VisitExpr_(const Min *op, std::ostream& os) final; // NOLINT(*)
void VisitExpr_(const Max *op, std::ostream& os) final; // NOLINT(*)
void VisitExpr_(const MinNode *op, std::ostream& os) final; // NOLINT(*)
void VisitExpr_(const MaxNode *op, std::ostream& os) final; // NOLINT(*)
void VisitStmt_(const AssertStmt *op) final; // NOLINT(*)
void VisitStmt_(const AssertStmtNode *op) final; // NOLINT(*)
private:
std::string module_name_;
......
......@@ -93,7 +93,7 @@ std::string CodeGenCUDA::Finish() {
return CodeGenC::Finish();
}
void CodeGenCUDA::VisitStmt_(const ir::For* op) {
void CodeGenCUDA::VisitStmt_(const ir::ForNode* op) {
CHECK(is_const_int(op->min, 0));
if (op->for_type == ir::ForType::Unrolled) {
PrintIndent();
......@@ -265,8 +265,8 @@ void CodeGenCUDA::PrintVecElemStore(
}
}
void CodeGenCUDA::PrintStorageSync(const Call* op) {
const std::string& sync = op->args[0].as<StringImm>()->value;
void CodeGenCUDA::PrintStorageSync(const CallNode* op) {
const std::string& sync = op->args[0].as<StringImmNode>()->value;
if (sync == "warp") {
// DO nothing.
} else if (sync == "shared") {
......@@ -314,7 +314,7 @@ void CodeGenCUDA::PrintStorageScope(
}
}
void CodeGenCUDA::VisitExpr_(const Call *op, std::ostream& os) {
void CodeGenCUDA::VisitExpr_(const CallNode *op, std::ostream& os) {
if (op->is_intrinsic(intrinsic::tvm_fill_fragment)) {
need_mma_h_ = true;
CHECK_EQ(op->args.size(), 6U);
......@@ -348,7 +348,7 @@ void CodeGenCUDA::VisitExpr_(const Call *op, std::ostream& os) {
this->PrintExpr(op->args[4], os);
os << "], ";
this->PrintExpr(op->args[6], os);
if (const StringImm *str = op->args[7].as<StringImm>()) {
if (const StringImmNode *str = op->args[7].as<StringImmNode>()) {
os << ", nvcuda::wmma::mem_" << str->value;
} else {
LOG(FATAL) << "Invalid parameters";
......@@ -369,20 +369,20 @@ void CodeGenCUDA::VisitExpr_(const Call *op, std::ostream& os) {
}
}
void CodeGenCUDA::VisitStmt_(const AttrStmt* op) {
void CodeGenCUDA::VisitStmt_(const AttrStmtNode* op) {
if (op->attr_key == attr::fragment_shape) {
const Variable* buffer = op->node.as<Variable>();
const StringImm* shape_str = op->value.as<StringImm>();
const VarNode* buffer = op->node.as<VarNode>();
const StringImmNode* shape_str = op->value.as<StringImmNode>();
fragment_shapes[buffer] = shape_str->value;
} else if (op->attr_key == attr::fragment_layout) {
const Variable* buffer = op->node.as<Variable>();
const StringImm* layout_str = op->value.as<StringImm>();
const VarNode* buffer = op->node.as<VarNode>();
const StringImmNode* layout_str = op->value.as<StringImmNode>();
fragment_layouts[buffer] = layout_str->value;
}
CodeGenC::VisitStmt_(op);
}
void CodeGenCUDA::VisitStmt_(const Allocate* op) {
void CodeGenCUDA::VisitStmt_(const AllocateNode* op) {
CHECK(!is_zero(op->condition));
std::string vid = AllocVarID(op->buffer_var.get());
if (op->new_expr.defined()) {
......@@ -397,7 +397,7 @@ void CodeGenCUDA::VisitStmt_(const Allocate* op) {
int32_t constant_size = op->constant_allocation_size();
CHECK_GT(constant_size, 0)
<< "Can only handle constant size stack allocation for now";
const Variable* buffer = op->buffer_var.as<Variable>();
const VarNode* buffer = op->buffer_var.as<VarNode>();
std::string scope = alloc_storage_scope_.at(buffer);
if (scope.find("wmma.") == 0) {
if (scope == "wmma.matrix_a" || scope == "wmma.matrix_b") {
......@@ -425,9 +425,9 @@ void CodeGenCUDA::VisitStmt_(const Allocate* op) {
this->PrintStmt(op->body);
}
void CodeGenCUDA::VisitStmt_(const Evaluate *op) {
void CodeGenCUDA::VisitStmt_(const EvaluateNode *op) {
if (is_const(op->value)) return;
const Call* call = op->value.as<Call>();
const CallNode* call = op->value.as<CallNode>();
if (call && call->is_intrinsic(intrinsic::tvm_global_barrier_kinit)) {
PrintIndent();
stream << "__shared__ unsigned " << vid_global_barrier_expect_ << ";\n";
......@@ -442,7 +442,7 @@ void CodeGenCUDA::VisitStmt_(const Evaluate *op) {
}
}
void CodeGenCUDA::VisitExpr_(const Ramp* op, std::ostream& os) {
void CodeGenCUDA::VisitExpr_(const RampNode* op, std::ostream& os) {
os << "((make_int" << op->lanes << ")(";
for (int i = 0; i < op->lanes; i++) {
os << "(" << PrintExpr(op->base) << ")" << "+(" << PrintExpr(op->stride) << "*" << i <<")";
......@@ -452,7 +452,7 @@ void CodeGenCUDA::VisitExpr_(const Ramp* op, std::ostream& os) {
os << "))";
}
void CodeGenCUDA::VisitExpr_(const Broadcast* op, std::ostream& os) { // NOLINT(*)
void CodeGenCUDA::VisitExpr_(const BroadcastNode* op, std::ostream& os) { // NOLINT(*)
if (op->dtype.is_int() && op->dtype.bits() == 8 && op->lanes == 4) {
// make_int8x4
const int64_t *p = as_const_int(op->value);
......@@ -474,7 +474,7 @@ void CodeGenCUDA::VisitExpr_(const Broadcast* op, std::ostream& os) { // NOLIN
os << ')';
}
void CodeGenCUDA::VisitExpr_(const Shuffle* op, std::ostream &os) {
void CodeGenCUDA::VisitExpr_(const ShuffleNode* op, std::ostream &os) {
std::vector<std::string> to_shuffle(op->vectors.size());
for (int i = 0, e = op->vectors.size(); i < e; ++i) {
CHECK(op->vectors[i].dtype().lanes() == 1) << "Only scalars can be shuffled in CUDA!";
......@@ -492,7 +492,7 @@ void CodeGenCUDA::VisitExpr_(const Shuffle* op, std::ostream &os) {
os << ')';
}
inline void PrintConst(const FloatImm* op, std::ostream& os, CodeGenCUDA* p) { // NOLINT(*)
inline void PrintConst(const FloatImmNode* op, std::ostream& os, CodeGenCUDA* p) { // NOLINT(*)
switch (op->dtype.bits()) {
case 64: case 32: {
std::ostringstream temp;
......@@ -523,12 +523,12 @@ inline void PrintConst(const FloatImm* op, std::ostream& os, CodeGenCUDA* p) { /
}
void CodeGenCUDA::VisitExpr_(const FloatImm *op, std::ostream& os) { // NOLINT(*)
void CodeGenCUDA::VisitExpr_(const FloatImmNode *op, std::ostream& os) { // NOLINT(*)
PrintConst(op, os, this);
}
void CodeGenCUDA::PrintWmmaScope(const std::string &scope, DataType t,
const Variable* variable, std::ostream &os) {
const VarNode* variable, std::ostream &os) {
std::stringstream type;
PrintType(t, type);
std::string shape_str = fragment_shapes[variable];
......@@ -550,7 +550,7 @@ void CodeGenCUDA::PrintWmmaScope(const std::string &scope, DataType t,
}
int32_t CodeGenCUDA::GetWmmaFragmentSize(const std::string &scope,
const Variable* variable, int32_t size) {
const VarNode* variable, int32_t size) {
std::string shape_str = fragment_shapes[variable];
size_t m, n, k;
size_t last_pos = 0, pos = 0;
......
......@@ -43,8 +43,8 @@ class CodeGenCUDA final : public CodeGenC {
return (enable_fp16_ || enable_int8_ || need_math_constants_h_ || need_mma_h_);
}
// override behavior
void VisitStmt_(const ir::For* op) final;
void PrintStorageSync(const Call* op) final;
void VisitStmt_(const ir::ForNode* op) final;
void PrintStorageSync(const CallNode* op) final;
void PrintStorageScope(const std::string& scope, std::ostream& os) final; // NOLINT(*)
void PrintVecBinaryOp(
const std::string&op, DataType t,
......@@ -56,14 +56,14 @@ class CodeGenCUDA final : public CodeGenC {
const std::string& vec, DataType t, int i, const std::string& value) final;
void BindThreadIndex(const IterVar& iv) final; // NOLINT(*)
// overload visitor
void VisitExpr_(const Ramp* op, std::ostream& os) final; // NOLINT(*)
void VisitExpr_(const Shuffle* op, std::ostream& os) final; // NOLINT(*)
void VisitExpr_(const Broadcast* op, std::ostream& os) final; // NOLINT(*)
void VisitExpr_(const FloatImm *op, std::ostream& os) final;
void VisitExpr_(const Call *op, std::ostream& os) final;
void VisitStmt_(const Evaluate *op) final;
void VisitStmt_(const Allocate *op) final;
void VisitStmt_(const AttrStmt *op) final;
void VisitExpr_(const RampNode* op, std::ostream& os) final; // NOLINT(*)
void VisitExpr_(const ShuffleNode* op, std::ostream& os) final; // NOLINT(*)
void VisitExpr_(const BroadcastNode* op, std::ostream& os) final; // NOLINT(*)
void VisitExpr_(const FloatImmNode *op, std::ostream& os) final;
void VisitExpr_(const CallNode *op, std::ostream& os) final;
void VisitStmt_(const EvaluateNode *op) final;
void VisitStmt_(const AllocateNode *op) final;
void VisitStmt_(const AttrStmtNode *op) final;
private:
// Whether global barrier is needed.
......@@ -81,13 +81,13 @@ class CodeGenCUDA final : public CodeGenC {
// whether need mma.h
bool need_mma_h_{false};
std::unordered_map<const Variable*, std::string> fragment_shapes;
std::unordered_map<const Variable*, std::string> fragment_layouts;
friend void PrintConst(const FloatImm* op, std::ostream& os, CodeGenCUDA* p);
std::unordered_map<const VarNode*, std::string> fragment_shapes;
std::unordered_map<const VarNode*, std::string> fragment_layouts;
friend void PrintConst(const FloatImmNode* op, std::ostream& os, CodeGenCUDA* p);
void PrintWmmaScope(
const std::string& scope, DataType t, const Variable* variable, std::ostream& os);
const std::string& scope, DataType t, const VarNode* variable, std::ostream& os);
int32_t GetWmmaFragmentSize(
const std::string &scope, const Variable* variable, int32_t size);
const std::string &scope, const VarNode* variable, int32_t size);
};
} // namespace codegen
......
......@@ -196,8 +196,8 @@ void CodeGenMetal::PrintType(DataType t, std::ostream& os) { // NOLINT(*)
LOG(FATAL) << "Cannot convert type " << t << " to Metal type";
}
void CodeGenMetal::PrintStorageSync(const Call* op) {
const std::string& sync = op->args[0].as<StringImm>()->value;
void CodeGenMetal::PrintStorageSync(const CallNode* op) {
const std::string& sync = op->args[0].as<StringImmNode>()->value;
if (sync == "warp") {
this->PrintIndent();
this->stream << "simdgroup_barrier(mem_flags::mem_threadgroup);\n";
......@@ -234,7 +234,7 @@ void CodeGenMetal::PrintStorageScope(
}
}
void CodeGenMetal::VisitExpr_(const Broadcast* op, std::ostream& os) { // NOLINT(*)
void CodeGenMetal::VisitExpr_(const BroadcastNode* op, std::ostream& os) { // NOLINT(*)
std::string v = PrintExpr(op->value);
PrintType(op->dtype, os);
os << "(";
......@@ -245,8 +245,8 @@ void CodeGenMetal::VisitExpr_(const Broadcast* op, std::ostream& os) { // NOLI
os << ')';
}
void CodeGenMetal::VisitExpr_(const Call* op, std::ostream& os) { // NOLINT(*)
if (op->is_intrinsic(Call::reinterpret)) {
void CodeGenMetal::VisitExpr_(const CallNode* op, std::ostream& os) { // NOLINT(*)
if (op->is_intrinsic(CallNode::reinterpret)) {
// generate as_type<TYPE>(ARG)
os << "(as_type<";
this->PrintType(op->dtype, os);
......
......@@ -40,7 +40,7 @@ class CodeGenMetal final : public CodeGenC {
void PrintArgUnionDecl();
void InitFuncState(LoweredFunc f) final;
void PrintStorageScope(const std::string& scope, std::ostream& os) final; // NOLINT(*)
void PrintStorageSync(const Call* op) final; // NOLINT(*)
void PrintStorageSync(const CallNode* op) final; // NOLINT(*)
void PrintType(DataType t, std::ostream& os) final; // NOLINT(*)
void BindThreadIndex(const IterVar& iv) final; // NOLINT(*)
// print load of single element
......@@ -50,10 +50,10 @@ class CodeGenMetal final : public CodeGenC {
void PrintVecElemStore(
const std::string& vec, DataType t, int i, const std::string& value) final;
// overload visitor
void VisitExpr_(const Broadcast* op, std::ostream& os) final; // NOLINT(*)
void VisitExpr_(const BroadcastNode* op, std::ostream& os) final; // NOLINT(*)
// overload visitor
void VisitExpr_(const Call* op, std::ostream& os) final; // NOLINT(*)
void VisitExpr_(const CallNode* op, std::ostream& os) final; // NOLINT(*)
private:
int thread_index_bits_{32};
......
......@@ -144,7 +144,7 @@ void CodeGenOpenCL::PrintType(DataType t, std::ostream& os) { // NOLINT(*)
LOG(FATAL) << "Cannot convert type " << t << " to OpenCL type";
}
void CodeGenOpenCL::PrintVecAddr(const Variable* buffer, DataType t,
void CodeGenOpenCL::PrintVecAddr(const VarNode* buffer, DataType t,
Expr base, std::ostream& os) { // NOLINT(*)
if (!HandleTypeMatch(buffer, t.element_of())) {
os << '(';
......@@ -160,7 +160,7 @@ void CodeGenOpenCL::PrintVecAddr(const Variable* buffer, DataType t,
PrintExpr(base, os);
}
std::string CodeGenOpenCL::GetVecLoad(
DataType t, const Variable* buffer, Expr base) {
DataType t, const VarNode* buffer, Expr base) {
std::ostringstream os;
os << "vload" << t.lanes() << "(0, ";
PrintVecAddr(buffer, t, base, os);
......@@ -168,7 +168,7 @@ std::string CodeGenOpenCL::GetVecLoad(
return os.str();
}
void CodeGenOpenCL::PrintVecStore(const Variable* buffer,
void CodeGenOpenCL::PrintVecStore(const VarNode* buffer,
DataType t, Expr base,
const std::string& value) {
this->PrintIndent();
......@@ -177,8 +177,8 @@ void CodeGenOpenCL::PrintVecStore(const Variable* buffer,
stream << ");\n";
}
void CodeGenOpenCL::PrintStorageSync(const Call* op) {
const std::string& sync = op->args[0].as<StringImm>()->value;
void CodeGenOpenCL::PrintStorageSync(const CallNode* op) {
const std::string& sync = op->args[0].as<StringImmNode>()->value;
if (sync == "warp") {
this->PrintIndent();
this->stream << "barrier(CLK_LOCAL_MEM_FENCE);\n";
......@@ -215,7 +215,7 @@ std::string CodeGenOpenCL::CastFromTo(std::string value, DataType from, DataType
return os.str();
}
void CodeGenOpenCL::VisitExpr_(const Broadcast* op, std::ostream& os) { // NOLINT(*)
void CodeGenOpenCL::VisitExpr_(const BroadcastNode* op, std::ostream& os) { // NOLINT(*)
std::string v = PrintExpr(op->value);
os << "((";
PrintType(op->dtype, os);
......@@ -227,7 +227,7 @@ void CodeGenOpenCL::VisitExpr_(const Broadcast* op, std::ostream& os) { // NOL
os << "))";
}
void CodeGenOpenCL::VisitExpr_(const Call *op, std::ostream& os) { // NOLINT(*)
void CodeGenOpenCL::VisitExpr_(const CallNode *op, std::ostream& os) { // NOLINT(*)
/* Return type of ternary expression is not always same as its sub-expressions,
* add a cast */
if (op->is_intrinsic(intrinsic::tvm_if_then_else)) {
......@@ -238,7 +238,7 @@ void CodeGenOpenCL::VisitExpr_(const Call *op, std::ostream& os) { // NOLINT(*)
CodeGenC::VisitExpr_(op, os);
}
void CodeGenOpenCL::VisitExpr_(const Select* op, std::ostream& os) { // NOLINT(*)
void CodeGenOpenCL::VisitExpr_(const SelectNode* op, std::ostream& os) { // NOLINT(*)
/* Return type of ternary expression is not always same as its sub-expressions,
* add a cast */
os << "(";
......@@ -247,7 +247,7 @@ void CodeGenOpenCL::VisitExpr_(const Select* op, std::ostream& os) { // NOLINT(
CodeGenC::VisitExpr_(op, os);
}
void CodeGenOpenCL::VisitExpr_(const FloatImm *op, std::ostream& os) { // NOLINT(*)
void CodeGenOpenCL::VisitExpr_(const FloatImmNode *op, std::ostream& os) { // NOLINT(*)
if (std::isinf(op->value)) {
if (op->value < 0) {
os << "-";
......
......@@ -42,23 +42,23 @@ class CodeGenOpenCL final : public CodeGenC {
void InitFuncState(LoweredFunc f) final;
void BindThreadIndex(const IterVar& iv) final; // NOLINT(*)
void PrintStorageScope(const std::string& scope, std::ostream& os) final; // NOLINT(*)
void PrintStorageSync(const Call* op) final; // NOLINT(*)
void PrintStorageSync(const CallNode* op) final; // NOLINT(*)
void PrintType(DataType t, std::ostream& os) final; // NOLINT(*)
std::string GetVecLoad(DataType t, const Variable* buffer,
std::string GetVecLoad(DataType t, const VarNode* buffer,
Expr base) final;
void PrintVecStore(const Variable* buffer,
void PrintVecStore(const VarNode* buffer,
DataType t, Expr base,
const std::string& value) final; // NOLINT(*)
// the address of load/store
void PrintVecAddr(const Variable* buffer, DataType t,
void PrintVecAddr(const VarNode* buffer, DataType t,
Expr base, std::ostream& os); // NOLINT(*)
std::string CastFromTo(std::string value, DataType from, DataType target); // NOLINT(*)
// overload visitor
void VisitExpr_(const Broadcast* op, std::ostream& os) final; // NOLINT(*)
void VisitExpr_(const Call* op, std::ostream& os) final; // NOLINT(*)
void VisitExpr_(const Select* op, std::ostream& os) final; // NOLINT(*)
void VisitExpr_(const FloatImm *op, std::ostream& os) final; // NOLINT(*)
void VisitExpr_(const BroadcastNode* op, std::ostream& os) final; // NOLINT(*)
void VisitExpr_(const CallNode* op, std::ostream& os) final; // NOLINT(*)
void VisitExpr_(const SelectNode* op, std::ostream& os) final; // NOLINT(*)
void VisitExpr_(const FloatImmNode *op, std::ostream& os) final; // NOLINT(*)
private:
// whether enable fp16 and fp64 extension
......
......@@ -188,13 +188,13 @@ void CodeGenOpenGL::BindThreadIndex(const IterVar& iv) {
this->stream << "}\n";
}
void CodeGenOpenGL::VisitStmt_(const Store* op) {
void CodeGenOpenGL::VisitStmt_(const StoreNode* op) {
LOG(FATAL) << "Store statement not supported in OpenGL."
<< " Texture store should be a Call statement.";
}
// texelFetch(tex, ivec2(idx & kTextureRowMask, idx >> kTextureRowBits), 0).r
std::string CodeGenOpenGL::TexelFetch(const Variable* buffer, Expr index) {
std::string CodeGenOpenGL::TexelFetch(const VarNode* buffer, Expr index) {
std::ostringstream os;
os << "texelFetch(" << GetVarID(buffer) << ", ivec2(int(";
PrintExpr(index, os);
......@@ -207,7 +207,7 @@ std::string CodeGenOpenGL::TexelFetch(const Variable* buffer, Expr index) {
// Print a reference expression to a buffer.
// Format: texelFetch(buffer, index, 0).r
std::string CodeGenOpenGL::GetBufferRef(
DataType t, const Variable* buffer, Expr index) {
DataType t, const VarNode* buffer, Expr index) {
CHECK_EQ(t.lanes(), 1) << "Vector type not supported.";
CHECK(HandleTypeMatch(buffer, t)) << "Type mismatch not supported.";
......@@ -242,34 +242,34 @@ void CodeGenOpenGL::PrintType(DataType t, std::ostream& os) {
// Codegen for immediate values
void CodeGenOpenGL::VisitExpr_(const IntImm* op, std::ostream& os) {
void CodeGenOpenGL::VisitExpr_(const IntImmNode* op, std::ostream& os) {
CHECK_EQ(op->dtype, DataType::Int(32)) << "GLSL 3.0 only supports 32-bit ints.";
CodeGenC::VisitExpr_(op, os);
}
void CodeGenOpenGL::VisitExpr_(const UIntImm* op, std::ostream& os) {
void CodeGenOpenGL::VisitExpr_(const UIntImmNode* op, std::ostream& os) {
CHECK_EQ(op->dtype, DataType::UInt(32)) << "GLSL 3.0 only supports 32-bit uints.";
CodeGenC::VisitExpr_(op, os);
}
void CodeGenOpenGL::VisitExpr_(const FloatImm* op, std::ostream& os) {
void CodeGenOpenGL::VisitExpr_(const FloatImmNode* op, std::ostream& os) {
CHECK_EQ(op->dtype, DataType::Float(32)) << "GLSL 3.0 only supports 32-bit floats.";
CodeGenC::VisitExpr_(op, os);
}
void CodeGenOpenGL::VisitExpr_(const StringImm*, std::ostream& os) {
void CodeGenOpenGL::VisitExpr_(const StringImmNode*, std::ostream& os) {
LOG(FATAL) << "GLSL 3.0 doesn't support strings.";
}
void CodeGenOpenGL::VisitStmt_(const Evaluate* op) {
auto call = op->value.as<Call>();
if (call == nullptr || call->name != Call::glsl_texture_store) {
void CodeGenOpenGL::VisitStmt_(const EvaluateNode* op) {
auto call = op->value.as<CallNode>();
if (call == nullptr || call->name != CallNode::glsl_texture_store) {
// Fallback to normal logic.
CodeGenC::VisitStmt_(op);
}
CHECK_EQ(call->args.size(), 2);
auto buffer = call->args[0].as<Variable>();
auto buffer = call->args[0].as<VarNode>();
auto value = call->args[1];
// Doesn't support store to vector.
......
......@@ -43,24 +43,24 @@ class CodeGenOpenGL final : public CodeGenC {
void InitFuncState(LoweredFunc f) final;
void BindThreadIndex(const IterVar& iv) final;
void VisitStmt_(const Store* op) final;
std::string TexelFetch(const Variable* buffer, Expr index);
std::string GetBufferRef(DataType t, const Variable* buffer, Expr index) final;
void VisitStmt_(const StoreNode* op) final;
std::string TexelFetch(const VarNode* buffer, Expr index);
std::string GetBufferRef(DataType t, const VarNode* buffer, Expr index) final;
void PrintType(DataType t, std::ostream& os) final; // NOLINT(*)
// Codegen for immediate values
void VisitExpr_(const IntImm* op, std::ostream& os) final; // NOLINT(*)
void VisitExpr_(const UIntImm* op, std::ostream& os) final; // NOLINT(*)
void VisitExpr_(const FloatImm* op, std::ostream& os) final; // NOLINT(*)
void VisitExpr_(const StringImm* op, std::ostream& os) final; // NOLINT(*)
void VisitExpr_(const IntImmNode* op, std::ostream& os) final; // NOLINT(*)
void VisitExpr_(const UIntImmNode* op, std::ostream& os) final; // NOLINT(*)
void VisitExpr_(const FloatImmNode* op, std::ostream& os) final; // NOLINT(*)
void VisitExpr_(const StringImmNode* op, std::ostream& os) final; // NOLINT(*)
// Match glsl_texture_store Call.
void VisitStmt_(const Evaluate* op) final; // NOLINT(*)
void VisitStmt_(const EvaluateNode* op) final; // NOLINT(*)
private:
const Variable* output_{nullptr};
std::unordered_set<const Variable*> inputs_;
const Variable* output_iter_var_{nullptr};
const VarNode* output_{nullptr};
std::unordered_set<const VarNode*> inputs_;
const VarNode* output_iter_var_{nullptr};
std::unordered_map<std::string, runtime::OpenGLShader> shaders_;
std::string thread_extent_var_;
};
......
......@@ -69,7 +69,7 @@ std::string CodeGenSourceBase::SSAGetID(std::string src, DataType t) {
return e.vid;
}
std::string CodeGenSourceBase::AllocVarID(const Variable* v) {
std::string CodeGenSourceBase::AllocVarID(const VarNode* v) {
CHECK(!var_idmap_.count(v))
<< "Need input to be in SSA form dup " << v->name_hint;
std::string key = v->name_hint;
......@@ -78,7 +78,7 @@ std::string CodeGenSourceBase::AllocVarID(const Variable* v) {
return vid;
}
std::string CodeGenSourceBase::GetVarID(const Variable* v) const {
std::string CodeGenSourceBase::GetVarID(const VarNode* v) const {
auto it = var_idmap_.find(v);
CHECK(it != var_idmap_.end())
<< "Find undefined Variable " << v->name_hint;
......
......@@ -66,13 +66,13 @@ class CodeGenSourceBase {
* \param v The variable.
* \return the variable name.
*/
std::string AllocVarID(const Variable* v);
std::string AllocVarID(const VarNode* v);
/*!
* \brief Get a variable name.
* \param v The variable.
* \return the variable name.
*/
std::string GetVarID(const Variable* v) const;
std::string GetVarID(const VarNode* v) const;
/*!
* \brief Get the SSA ID corresponds to src
* If necessary, generate new assignment
......@@ -110,7 +110,7 @@ class CodeGenSourceBase {
/*! \brief the stream to be printed */
std::ostringstream stream;
/*! \brief name of each variable */
std::unordered_map<const Variable*, std::string> var_idmap_;
std::unordered_map<const VarNode*, std::string> var_idmap_;
private:
/*! \brief assignment map of ssa */
......
......@@ -98,7 +98,7 @@ inline void PrintBinaryExpr(const T* op,
os << ')';
}
void CodeGenVivadoHLS::VisitExpr_(const Min *op, std::ostream& os) { // NOLINT(*)
void CodeGenVivadoHLS::VisitExpr_(const MinNode *op, std::ostream& os) { // NOLINT(*)
const char *opstr = "std::min";
if (op->dtype.is_float()) {
switch (op->dtype.bits()) {
......@@ -112,7 +112,7 @@ void CodeGenVivadoHLS::VisitExpr_(const Min *op, std::ostream& os) { // NOLINT(
PrintBinaryExpr(op, opstr, os, this);
}
void CodeGenVivadoHLS::VisitExpr_(const Max *op, std::ostream& os) { // NOLINT(*)
void CodeGenVivadoHLS::VisitExpr_(const MaxNode *op, std::ostream& os) { // NOLINT(*)
const char *opstr = "std::max";
if (op->dtype.is_float()) {
switch (op->dtype.bits()) {
......
......@@ -38,8 +38,8 @@ class CodeGenVivadoHLS final : public CodeGenC {
void PrintType(DataType t, std::ostream& os);
void AddFunction(LoweredFunc f);
void PreFunctionBody(LoweredFunc f);
void VisitExpr_(const Min *op, std::ostream& os);
void VisitExpr_(const Max *op, std::ostream& os);
void VisitExpr_(const MinNode *op, std::ostream& os);
void VisitExpr_(const MaxNode *op, std::ostream& os);
};
} // namespace codegen
......
......@@ -54,7 +54,7 @@ TVM_REGISTER_GLOBAL("tvm.intrin.rule.default.sqrt")
TVM_REGISTER_GLOBAL("tvm.intrin.rule.default.rsqrt")
.set_body([](const TVMArgs& args, TVMRetValue* rv){
Expr e = args[0];
const Call* call = e.as<Call>();
const CallNode* call = e.as<CallNode>();
CHECK(call != nullptr);
auto one = make_const(call->args[0].dtype(), 1);
......@@ -67,7 +67,7 @@ TVM_REGISTER_GLOBAL("tvm.intrin.rule.default.pow")
TVM_REGISTER_GLOBAL("tvm.intrin.rule.default.sigmoid")
.set_body([](const TVMArgs& args, TVMRetValue* rv){
Expr e = args[0];
const Call* call = e.as<Call>();
const CallNode* call = e.as<CallNode>();
CHECK(call != nullptr);
auto one = make_const(call->args[0].dtype(), 1);
......
......@@ -61,12 +61,12 @@ struct Direct {
template<typename T>
inline void DispatchExtern(const TVMArgs& args, TVMRetValue* rv) {
Expr e = args[0];
const Call* call = e.as<Call>();
const CallNode* call = e.as<CallNode>();
CHECK(call != nullptr);
std::string name = T()(call->dtype, call->name);
if (name.length() != 0) {
*rv = Call::make(
call->dtype, name, call->args, Call::PureExtern);
*rv = CallNode::make(
call->dtype, name, call->args, CallNode::PureExtern);
} else {
*rv = e;
}
......
......@@ -70,7 +70,7 @@ class CodeGenAMDGPU : public CodeGenLLVM {
function_->addFnAttr("amdgpu-flat-work-group-size", attr.str());
}
void VisitStmt_(const Allocate* op) final {
void VisitStmt_(const AllocateNode* op) final {
CHECK(!is_zero(op->condition));
llvm::Value* buf = nullptr;
if (op->new_expr.defined()) {
......@@ -153,8 +153,8 @@ class CodeGenAMDGPU : public CodeGenLLVM {
return builder_->CreateCall(f, {});
}
llvm::Value* CreateStorageSync(const Call* op) final {
const std::string& sync = op->args[0].as<StringImm>()->value;
llvm::Value* CreateStorageSync(const CallNode* op) final {
const std::string& sync = op->args[0].as<StringImmNode>()->value;
if (sync == "warp") {
return nullptr;
} else if (sync == "shared") {
......@@ -234,7 +234,7 @@ runtime::Module BuildAMDGPU(Array<LoweredFunc> funcs, std::string target) {
Array<Expr> bitcode_files = (*find_rocm_bitcodes)();
for (auto &bitcode : bitcode_files) {
std::string path = bitcode.as<StringImm>()->value;
std::string path = bitcode.as<StringImmNode>()->value;
llvm::SMDiagnostic err;
std::unique_ptr<llvm::Module> mlib = llvm::parseIRFile(path, err, *ctx);
if (mlib.get() == nullptr) {
......
......@@ -39,25 +39,25 @@ class CodeGenARM final : public CodeGenCPU {
native_vector_bits_ = 16 * 8;
CodeGenCPU::InitTarget(tm);
}
llvm::Value* CreateIntrinsic(const Call* op) override;
llvm::Value* CreateIntrinsic(const CallNode* op) override;
private:
Expr ARMPopcount(const Call* op);
Expr ARMPopcount(const CallNode* op);
};
llvm::Value* CodeGenARM::CreateIntrinsic(const Call* op) {
llvm::Value* CodeGenARM::CreateIntrinsic(const CallNode* op) {
if (op->is_intrinsic("llvm_intrin")) {
llvm::Intrinsic::ID id = static_cast<llvm::Intrinsic::ID>(
op->args[0].as<UIntImm>()->value);
op->args[0].as<UIntImmNode>()->value);
if (id == ::llvm::Intrinsic::ctpop) {
Expr e = ARMPopcount(op);
return CodeGenCPU::CreateIntrinsic(e.as<Call>());
return CodeGenCPU::CreateIntrinsic(e.as<CallNode>());
}
}
return CodeGenCPU::CreateIntrinsic(op);
}
Expr CodeGenARM::ARMPopcount(const Call *call) {
Expr CodeGenARM::ARMPopcount(const CallNode *call) {
using namespace ir;
const Expr& e = call->args[2];
::llvm::Intrinsic::ID ctpop_id = ::llvm::Intrinsic::ctpop;
......@@ -68,10 +68,10 @@ Expr CodeGenARM::ARMPopcount(const Call *call) {
if (!call->dtype.is_vector() || call->dtype.bits() == 8 ||
(total_size != 128 && total_size != 64)) {
Array<Expr> vcnt_args;
vcnt_args.push_back(ir::UIntImm::make(DataType::UInt(32), ctpop_id));
vcnt_args.push_back(ir::UIntImm::make(DataType::UInt(32), 1));
vcnt_args.push_back(ir::UIntImmNode::make(DataType::UInt(32), ctpop_id));
vcnt_args.push_back(ir::UIntImmNode::make(DataType::UInt(32), 1));
vcnt_args.push_back(e);
return ir::Call::make(call->dtype, "llvm_intrin", vcnt_args, Call::PureIntrinsic);
return ir::CallNode::make(call->dtype, "llvm_intrin", vcnt_args, CallNode::PureIntrinsic);
}
// Popcount lowering rule:
......@@ -90,40 +90,44 @@ Expr CodeGenARM::ARMPopcount(const Call *call) {
// Interpret input as vector of 8bit values
Expr input8 = reinterpret(uint8_type, e);
// Popcount 8bit->8bit
const Call* c0 = input8.as<Call>();
const CallNode* c0 = input8.as<CallNode>();
CHECK(c0 != nullptr);
Array<Expr> vcnt8_args;
vcnt8_args.push_back(ir::UIntImm::make(DataType::UInt(32), ctpop_id));
vcnt8_args.push_back(ir::UIntImm::make(DataType::UInt(32), 1));
vcnt8_args.push_back(ir::UIntImmNode::make(DataType::UInt(32), ctpop_id));
vcnt8_args.push_back(ir::UIntImmNode::make(DataType::UInt(32), 1));
vcnt8_args.push_back(input8);
Expr vcnt8 = ir::Call::make(uint8_type, "llvm_intrin", vcnt8_args, Call::PureIntrinsic);
Expr vcnt8 = ir::CallNode::make(
uint8_type, "llvm_intrin", vcnt8_args, CallNode::PureIntrinsic);
// Accumulation 8->16bit
Array<Expr> vcnt16_args;
vcnt16_args.push_back(ir::UIntImm::make(DataType::UInt(32), vpaddlu_id));
vcnt16_args.push_back(ir::UIntImm::make(DataType::UInt(32), 1));
vcnt16_args.push_back(ir::UIntImmNode::make(DataType::UInt(32), vpaddlu_id));
vcnt16_args.push_back(ir::UIntImmNode::make(DataType::UInt(32), 1));
vcnt16_args.push_back(vcnt8);
Expr vcnt16 = ir::Call::make(uint16_type, "llvm_intrin", vcnt16_args, Call::PureIntrinsic);
Expr vcnt16 = ir::CallNode::make(
uint16_type, "llvm_intrin", vcnt16_args, CallNode::PureIntrinsic);
if (call->dtype.bits() == 16) {
return vcnt16;
}
// Accumulation 16->32bit
Array<Expr> vcnt32_args;
vcnt32_args.push_back(ir::UIntImm::make(DataType::UInt(32), vpaddlu_id));
vcnt32_args.push_back(ir::UIntImm::make(DataType::UInt(32), 1));
vcnt32_args.push_back(ir::UIntImmNode::make(DataType::UInt(32), vpaddlu_id));
vcnt32_args.push_back(ir::UIntImmNode::make(DataType::UInt(32), 1));
vcnt32_args.push_back(vcnt16);
Expr vcnt32 = ir::Call::make(uint32_type, "llvm_intrin", vcnt32_args, Call::PureIntrinsic);
Expr vcnt32 = ir::CallNode::make(
uint32_type, "llvm_intrin", vcnt32_args, CallNode::PureIntrinsic);
if (call->dtype.bits() == 32) {
return vcnt32;
}
// Accumulation 32->64bit
Array<Expr> vcnt64_args;
vcnt64_args.push_back(ir::UIntImm::make(DataType::UInt(32), vpaddlu_id));
vcnt64_args.push_back(ir::UIntImm::make(DataType::UInt(32), 1));
vcnt64_args.push_back(ir::UIntImmNode::make(DataType::UInt(32), vpaddlu_id));
vcnt64_args.push_back(ir::UIntImmNode::make(DataType::UInt(32), 1));
vcnt64_args.push_back(vcnt32);
return ir::Call::make(call->dtype, "llvm_intrin", vcnt64_args, Call::PureIntrinsic);
return ir::CallNode::make(
call->dtype, "llvm_intrin", vcnt64_args, CallNode::PureIntrinsic);
}
TVM_REGISTER_GLOBAL("tvm.codegen.llvm.target_arm")
......
......@@ -319,7 +319,7 @@ llvm::Value* CodeGenCPU::CreateStructRefPtr(
}
}
llvm::Value* CodeGenCPU::CreateCallExtern(const Call* op) {
llvm::Value* CodeGenCPU::CreateCallExtern(const CallNode* op) {
std::vector<llvm::Value*> arg_values(op->args.size());
for (size_t i = 0; i < op->args.size(); ++i) {
arg_values[i] = MakeValue(op->args[i]);
......@@ -417,7 +417,7 @@ llvm::BasicBlock* CodeGenCPU::CheckCallSuccess(llvm::Value* retcode) {
return end_block;
}
void CodeGenCPU::CreateComputeScope(const AttrStmt* op) {
void CodeGenCPU::CreateComputeScope(const AttrStmtNode* op) {
// There are two reasons why we create another function for compute_scope
// - Make sure the generated compute function is clearly separately(though it can get inlined)
// - Set noalias on all the pointer arguments, some of them are loaded from TVMArgs.
......@@ -436,12 +436,12 @@ void CodeGenCPU::CreateComputeScope(const AttrStmt* op) {
llvm::Function* fcompute =
llvm::Function::Create(ftype,
llvm::Function::PrivateLinkage,
op->value.as<StringImm>()->value,
op->value.as<StringImmNode>()->value,
module_.get());
BasicBlock* compute_call_end = CheckCallSuccess(
builder_->CreateCall(fcompute, arg_values));
// setup compute fuinction.
std::unordered_map<const Variable*, llvm::Value*> new_vmap;
std::unordered_map<const VarNode*, llvm::Value*> new_vmap;
size_t idx = 0;
for (auto it = fcompute->arg_begin();
it != fcompute->arg_end(); ++it, ++idx) {
......@@ -497,7 +497,7 @@ llvm::Value* CodeGenCPU::PackClosureData(const Array<Var>& vfields, uint64_t* nu
void CodeGenCPU::UnpackClosureData(llvm::Value* cdata,
const Array<Var>& vfields,
std::unordered_map<const Variable*, llvm::Value*>* vmap) {
std::unordered_map<const VarNode*, llvm::Value*>* vmap) {
for (size_t i = 0; i < vfields.size(); ++i) {
(*vmap)[vfields[i].get()] =
builder_->CreateLoad(builder_->CreateInBoundsGEP(
......@@ -528,7 +528,7 @@ void CodeGenCPU::CreateParallelLaunch(const Stmt& body, int num_task) {
llvm::Value* penv = &(*it++);
cdata = builder_->CreatePointerCast(&(*it++), cdata->getType());
// setup new variable map, swap it with current var context.
std::unordered_map<const Variable*, llvm::Value*> new_vmap;
std::unordered_map<const VarNode*, llvm::Value*> new_vmap;
UnpackClosureData(cdata, vfields, &new_vmap);
// setup parallel env
ParallelEnv par_env;
......@@ -594,7 +594,7 @@ void CodeGenCPU::CreateStaticInit(const std::string& init_fname, const Stmt& bod
auto it = f->arg_begin();
cdata = builder_->CreatePointerCast(&(*it++), cdata->getType());
// setup new variable map, swap it with current var context.
std::unordered_map<const Variable*, llvm::Value*> new_vmap;
std::unordered_map<const VarNode*, llvm::Value*> new_vmap;
UnpackClosureData(cdata, vfields, &new_vmap);
CHECK(parallel_env_.penv == nullptr);
std::swap(function_, f);
......@@ -673,7 +673,7 @@ CodeGenCPU::MakeCallPacked(const Array<Expr> &args, llvm::Value **rvalue,
llvm::Value **ret_tcode, const DataType &r_type,
const int64_t begin, const int64_t end) {
using llvm::BasicBlock;
std::string func_name = args[0].as<StringImm>()->value;
std::string func_name = args[0].as<StringImmNode>()->value;
llvm::Value *handle = GetPackedFuncHandle(func_name);
// call the function
int64_t nargs = end - begin;
......@@ -701,24 +701,24 @@ CodeGenCPU::MakeCallPacked(const Array<Expr> &args, llvm::Value **rvalue,
return end_block;
}
llvm::Value *CodeGenCPU::CreateCallPacked(const Call *op) {
llvm::Value *CodeGenCPU::CreateCallPacked(const CallNode *op) {
CHECK_EQ(op->args.size(), 5U);
llvm::Value *rvalue = nullptr;
llvm::Value *ret_tcode = nullptr;
MakeCallPacked(op->args, &rvalue, &ret_tcode, op->dtype,
op->args[3].as<IntImm>()->value,
op->args[4].as<IntImm>()->value);
op->args[3].as<IntImmNode>()->value,
op->args[4].as<IntImmNode>()->value);
return rvalue;
}
llvm::Value *CodeGenCPU::CreateCallTracePacked(const Call *op) {
llvm::Value *CodeGenCPU::CreateCallTracePacked(const CallNode *op) {
using llvm::BasicBlock;
CHECK_EQ(op->args.size(), 6U);
llvm::Value *rvalue = nullptr;
llvm::Value *ret_tcode = nullptr;
BasicBlock *end_block = MakeCallPacked(
op->args, &rvalue, &ret_tcode, op->dtype, op->args[3].as<IntImm>()->value,
op->args[4].as<IntImm>()->value);
op->args, &rvalue, &ret_tcode, op->dtype, op->args[3].as<IntImmNode>()->value,
op->args[4].as<IntImmNode>()->value);
// Get traced value.
llvm::Value *traced_value = MakeValue(op->args[5]);
// The update_block handles case when we need to update the return value.
......@@ -786,7 +786,7 @@ void CodeGenCPU::AddStartupFunction() {
}
}
llvm::Value* CodeGenCPU::CreateIntrinsic(const Call* op) {
llvm::Value* CodeGenCPU::CreateIntrinsic(const CallNode* op) {
if (op->is_intrinsic(intrinsic::tvm_call_packed_lowered)) {
return CreateCallPacked(op);
} else if (op->is_intrinsic(intrinsic::tvm_call_trace_packed_lowered)) {
......@@ -798,7 +798,7 @@ llvm::Value* CodeGenCPU::CreateIntrinsic(const Call* op) {
return ConstInt32(-1);
} else if (op->is_intrinsic(intrinsic::tvm_struct_get)) {
CHECK_EQ(op->args.size(), 3U);
int kind = op->args[2].as<IntImm>()->value;
int kind = op->args[2].as<IntImmNode>()->value;
llvm::Value* ref = this->CreateStructRefPtr(
op->dtype, MakeValue(op->args[0]),
MakeValue(op->args[1]), kind);
......@@ -809,7 +809,7 @@ llvm::Value* CodeGenCPU::CreateIntrinsic(const Call* op) {
}
} else if (op->is_intrinsic(intrinsic::tvm_struct_set)) {
CHECK_EQ(op->args.size(), 4U);
int kind = op->args[2].as<IntImm>()->value;
int kind = op->args[2].as<IntImmNode>()->value;
llvm::Value* value = MakeValue(op->args[3]);
llvm::Value* ref = this->CreateStructRefPtr(
op->args[3].dtype(), MakeValue(op->args[0]),
......@@ -823,7 +823,7 @@ llvm::Value* CodeGenCPU::CreateIntrinsic(const Call* op) {
return ConstInt32(0);
} else if (op->is_intrinsic(intrinsic::tvm_stack_alloca)) {
CHECK_EQ(op->args.size(), 2U);
const std::string& type = op->args[0].as<StringImm>()->value;
const std::string& type = op->args[0].as<StringImmNode>()->value;
return WithFunctionEntry([&]() -> llvm::AllocaInst* {
const int64_t* pval = as_const_int(op->args[1]);
CHECK(pval) << "require stack alloca to contain constant value";
......@@ -846,13 +846,13 @@ llvm::Value* CodeGenCPU::CreateIntrinsic(const Call* op) {
}
}
void CodeGenCPU::VisitStmt_(const AssertStmt* op) {
void CodeGenCPU::VisitStmt_(const AssertStmtNode* op) {
using llvm::BasicBlock;
llvm::Value* cond = MakeValue(op->condition);
std::ostringstream os;
os << "Assert fail: " << op->condition;
if (op->message.as<StringImm>()) {
os << ", " << op->message.as<StringImm>()->value;
if (op->message.as<StringImmNode>()) {
os << ", " << op->message.as<StringImmNode>()->value;
}
llvm::Value* msg = GetConstString(os.str());
BasicBlock* fail_block = BasicBlock::Create(
......@@ -869,9 +869,9 @@ void CodeGenCPU::VisitStmt_(const AssertStmt* op) {
CodeGenLLVM::VisitStmt_(op);
}
void CodeGenCPU::VisitStmt_(const AttrStmt* op) {
void CodeGenCPU::VisitStmt_(const AttrStmtNode* op) {
if (op->attr_key == ir::attr::coproc_uop_scope) {
this->CreateStaticInit(op->value.as<StringImm>()->value, op->body);
this->CreateStaticInit(op->value.as<StringImmNode>()->value, op->body);
} else if (op->attr_key == ir::attr::compute_scope) {
this->CreateComputeScope(op);
} else if (attr::IsPragmaKey(op->attr_key)) {
......@@ -893,7 +893,7 @@ void CodeGenCPU::VisitStmt_(const AttrStmt* op) {
RuntimeTVMParallelBarrier(),
{MakeValue(parallel_env_.task_id), parallel_env_.penv});
} else if (op->attr_key == ir::attr::pragma_import_llvm) {
const StringImm* value = op->value.as<StringImm>();
const StringImmNode* value = op->value.as<StringImmNode>();
CHECK(value != nullptr);
this->HandleImport(value->value);
this->VisitStmt(op->body);
......@@ -906,7 +906,7 @@ void CodeGenCPU::VisitStmt_(const AttrStmt* op) {
}
}
void CodeGenCPU::VisitStmt_(const For* op) {
void CodeGenCPU::VisitStmt_(const ForNode* op) {
CHECK(is_zero(op->min));
if (op->for_type == ForType::Serial ||
op->for_type == ForType::Unrolled) {
......@@ -914,7 +914,7 @@ void CodeGenCPU::VisitStmt_(const For* op) {
} else if (op->for_type == ForType::Parallel) {
if (parallel_env_.penv == nullptr) {
CreateParallelLaunch(
For::make(
ForNode::make(
op->loop_var, op->min, op->extent,
op->for_type, op->device_api, op->body), 0);
} else {
......@@ -936,8 +936,8 @@ void CodeGenCPU::VisitStmt_(const For* op) {
op->body);
} else {
Expr step = (op->extent + num_task - make_const(t, 1)) / num_task;
Expr begin = Min::make(task_id * step, op->extent);
Expr end = Min::make((task_id + make_const(t, 1)) * step, op->extent);
Expr begin = MinNode::make(task_id * step, op->extent);
Expr end = MinNode::make((task_id + make_const(t, 1)) * step, op->extent);
CreateSerialFor(MakeValue(begin),
MakeValue(end),
ConstInt32(1),
......
......@@ -45,11 +45,11 @@ class CodeGenCPU : public CodeGenLLVM {
void AddFunction(const LoweredFunc& f) override;
void AddMainFunction(const std::string& entry_func_name) override;
std::unique_ptr<llvm::Module> Finish() override;
void VisitStmt_(const AssertStmt* op) override;
void VisitStmt_(const AttrStmt* op) override;
void VisitStmt_(const For* op) override;
llvm::Value* CreateIntrinsic(const Call* op) override;
llvm::Value* CreateCallExtern(const Call* op) override;
void VisitStmt_(const AssertStmtNode* op) override;
void VisitStmt_(const AttrStmtNode* op) override;
void VisitStmt_(const ForNode* op) override;
llvm::Value* CreateIntrinsic(const CallNode* op) override;
llvm::Value* CreateCallExtern(const CallNode* op) override;
protected:
void AddStartupFunction() final;
......@@ -99,22 +99,22 @@ class CodeGenCPU : public CodeGenLLVM {
llvm::Value* CreateStructRefPtr(DataType t, llvm::Value* buffer, llvm::Value* index, int kind);
void UnpackClosureData(llvm::Value*cdata,
const Array<Var>& fields,
std::unordered_map<const Variable*, llvm::Value*>* vmap);
std::unordered_map<const VarNode*, llvm::Value*>* vmap);
// Make packed call.
llvm::BasicBlock *MakeCallPacked(const Array<Expr> &args,
llvm::Value **rvalue,
llvm::Value **ret_tcode, const DataType &r_type,
const int64_t begin, const int64_t end);
// create call into tvm packed function.
llvm::Value* CreateCallPacked(const Call* op);
llvm::Value* CreateCallPacked(const CallNode* op);
// Create trace call into tvm packed function.
llvm::Value* CreateCallTracePacked(const Call *op);
llvm::Value* CreateCallTracePacked(const CallNode *op);
// Create static initialization
void CreateStaticInit(const std::string& init_fname, const Stmt& body);
// Create parallel launch
void CreateParallelLaunch(const Stmt& body, int num_task);
// Create a new compute scope.
void CreateComputeScope(const AttrStmt* op);
void CreateComputeScope(const AttrStmtNode* op);
// Check if the call to packed function is successful
// if not directly finalize function and pass on return code.
// return the end block after the check
......
......@@ -103,46 +103,46 @@ class CodeGenLLVM :
return llvm::ConstantInt::getSigned(t_int32_, value);
}
// override codegen
llvm::Value* VisitExpr_(const Variable* op) override;
llvm::Value* VisitExpr_(const Cast* op) override;
llvm::Value* VisitExpr_(const IntImm* op) override;
llvm::Value* VisitExpr_(const UIntImm* op) override;
llvm::Value* VisitExpr_(const FloatImm* op) override;
llvm::Value* VisitExpr_(const StringImm* op) override;
llvm::Value* VisitExpr_(const Add* op) override;
llvm::Value* VisitExpr_(const Sub* op) override;
llvm::Value* VisitExpr_(const Mul* op) override;
llvm::Value* VisitExpr_(const Div* op) override;
llvm::Value* VisitExpr_(const Mod* op) override;
llvm::Value* VisitExpr_(const Min* op) override;
llvm::Value* VisitExpr_(const Max* op) override;
llvm::Value* VisitExpr_(const LT* op) override;
llvm::Value* VisitExpr_(const LE* op) override;
llvm::Value* VisitExpr_(const GT* op) override;
llvm::Value* VisitExpr_(const GE* op) override;
llvm::Value* VisitExpr_(const EQ* op) override;
llvm::Value* VisitExpr_(const NE* op) override;
llvm::Value* VisitExpr_(const And* op) override;
llvm::Value* VisitExpr_(const Or* op) override;
llvm::Value* VisitExpr_(const Not* op) override;
llvm::Value* VisitExpr_(const Select* op) override;
llvm::Value* VisitExpr_(const Let* op) override;
llvm::Value* VisitExpr_(const Load* op) override;
llvm::Value* VisitExpr_(const Call* op) override;
llvm::Value* VisitExpr_(const Ramp* op) override;
llvm::Value* VisitExpr_(const Shuffle* op) override;
llvm::Value* VisitExpr_(const Broadcast* op) override;
llvm::Value* VisitExpr_(const VarNode* op) override;
llvm::Value* VisitExpr_(const CastNode* op) override;
llvm::Value* VisitExpr_(const IntImmNode* op) override;
llvm::Value* VisitExpr_(const UIntImmNode* op) override;
llvm::Value* VisitExpr_(const FloatImmNode* op) override;
llvm::Value* VisitExpr_(const StringImmNode* op) override;
llvm::Value* VisitExpr_(const AddNode* op) override;
llvm::Value* VisitExpr_(const SubNode* op) override;
llvm::Value* VisitExpr_(const MulNode* op) override;
llvm::Value* VisitExpr_(const DivNode* op) override;
llvm::Value* VisitExpr_(const ModNode* op) override;
llvm::Value* VisitExpr_(const MinNode* op) override;
llvm::Value* VisitExpr_(const MaxNode* op) override;
llvm::Value* VisitExpr_(const LTNode* op) override;
llvm::Value* VisitExpr_(const LENode* op) override;
llvm::Value* VisitExpr_(const GTNode* op) override;
llvm::Value* VisitExpr_(const GENode* op) override;
llvm::Value* VisitExpr_(const EQNode* op) override;
llvm::Value* VisitExpr_(const NENode* op) override;
llvm::Value* VisitExpr_(const AndNode* op) override;
llvm::Value* VisitExpr_(const OrNode* op) override;
llvm::Value* VisitExpr_(const NotNode* op) override;
llvm::Value* VisitExpr_(const SelectNode* op) override;
llvm::Value* VisitExpr_(const LetNode* op) override;
llvm::Value* VisitExpr_(const LoadNode* op) override;
llvm::Value* VisitExpr_(const CallNode* op) override;
llvm::Value* VisitExpr_(const RampNode* op) override;
llvm::Value* VisitExpr_(const ShuffleNode* op) override;
llvm::Value* VisitExpr_(const BroadcastNode* op) override;
// stmt
void VisitStmt_(const Store* op) override;
void VisitStmt_(const For* op) override;
void VisitStmt_(const IfThenElse* op) override;
void VisitStmt_(const Allocate* op) override;
void VisitStmt_(const AttrStmt* op) override;
void VisitStmt_(const AssertStmt* op) override;
void VisitStmt_(const LetStmt* op) override;
void VisitStmt_(const StoreNode* op) override;
void VisitStmt_(const ForNode* op) override;
void VisitStmt_(const IfThenElseNode* op) override;
void VisitStmt_(const AllocateNode* op) override;
void VisitStmt_(const AttrStmtNode* op) override;
void VisitStmt_(const AssertStmtNode* op) override;
void VisitStmt_(const LetStmtNode* op) override;
void VisitStmt_(const SeqStmtNode* op) override;
void VisitStmt_(const Evaluate* op) override;
void VisitStmt_(const ProducerConsumer* op) override;
void VisitStmt_(const EvaluateNode* op) override;
void VisitStmt_(const ProducerConsumerNode* op) override;
protected:
/*! \brief The storage information */
......@@ -173,13 +173,13 @@ class CodeGenLLVM :
return res;
}
// create intrinstic given call
virtual llvm::Value* CreateIntrinsic(const Call* op);
virtual llvm::Value* CreateIntrinsic(const CallNode* op);
// create extern function call
virtual llvm::Value* CreateCallExtern(const Call* op);
virtual llvm::Value* CreateCallExtern(const CallNode* op);
// Get the corresponding thread index
virtual llvm::Value* GetThreadIndex(const IterVar& iv);
// Get the corresponding thread index
virtual llvm::Value* CreateStorageSync(const Call* op);
virtual llvm::Value* CreateStorageSync(const CallNode* op);
// apply optimization on the module.
virtual void InitPassManagerBuilder(llvm::PassManagerBuilder* builder);
// Scalarize by iterating elements of e.
......@@ -211,19 +211,19 @@ class CodeGenLLVM :
void InitFuncState();
// Get alignment given index.
void GetAlignment(
DataType t, const Variable* buf_var, const Expr& index,
DataType t, const VarNode* buf_var, const Expr& index,
int* p_alignment, int* p_native_bits);
// Get constant string
llvm::Value* GetConstString(const std::string& str);
// do a scalarize call with f
llvm::Value* CreateScalarizedCall(
const Call* op, llvm::Function* f, const std::vector<llvm::Value*>& args);
const CallNode* op, llvm::Function* f, const std::vector<llvm::Value*>& args);
// handle module import
void HandleImport(const std::string& code);
// cast operatpr
llvm::Value* CreateCast(DataType from, DataType to, llvm::Value* value);
// comparison op
llvm::Value* GetVarValue(const Variable* v) const;
llvm::Value* GetVarValue(const VarNode* v) const;
llvm::Value* CreateLT(DataType t, llvm::Value* a, llvm::Value* b);
llvm::Value* CreateLE(DataType t, llvm::Value* a, llvm::Value* b);
llvm::Value* CreateGT(DataType t, llvm::Value* a, llvm::Value* b);
......@@ -245,7 +245,7 @@ class CodeGenLLVM :
llvm::Value* stride,
const VarExpr& loop_var, const Stmt& body);
// add alias information.
void AddAliasInfo(llvm::Instruction* load, const Variable* buffer, Expr index, DataType type);
void AddAliasInfo(llvm::Instruction* load, const VarNode* buffer, Expr index, DataType type);
// The IRBuilder.
using IRBuilder = llvm::IRBuilder<llvm::ConstantFolder, llvm::IRBuilderDefaultInserter>;
// The current function
......@@ -280,9 +280,9 @@ class CodeGenLLVM :
/*! \brief native vector bits of current targetx*/
int native_vector_bits_{0};
/*! \brief the storage scope of allocation */
std::unordered_map<const Variable*, StorageInfo> alloc_storage_info_;
std::unordered_map<const VarNode*, StorageInfo> alloc_storage_info_;
// The definition of local variable.
std::unordered_map<const Variable*, llvm::Value*> var_map_;
std::unordered_map<const VarNode*, llvm::Value*> var_map_;
// global strings
std::unordered_map<std::string, llvm::Constant*> str_map_;
// Whether current function is restricted
......@@ -290,9 +290,9 @@ class CodeGenLLVM :
// The analyzer information
std::unique_ptr<arith::Analyzer> analyzer_;
// set of var that are not restricted(can alias)
std::unordered_set<const Variable*> alias_var_set_;
std::unordered_set<const VarNode*> alias_var_set_;
// set of volatile buffer.
std::unordered_set<const Variable*> volatile_buf_;
std::unordered_set<const VarNode*> volatile_buf_;
/*! \brief Helper struct for debug infos. */
struct DebugInfo {
std::unique_ptr<llvm::DIBuilder> di_builder_;
......
......@@ -46,7 +46,7 @@ class CodeGenNVPTX : public CodeGenLLVM {
llvm::ValueAsMetadata::get(ConstInt32(1)) }));
}
void VisitStmt_(const Allocate* op) final {
void VisitStmt_(const AllocateNode* op) final {
CHECK(!is_zero(op->condition));
llvm::Value* buf = nullptr;
if (op->new_expr.defined()) {
......@@ -129,8 +129,8 @@ class CodeGenNVPTX : public CodeGenLLVM {
return builder_->CreateCall(f, {});
}
llvm::Value* CreateStorageSync(const Call* op) final {
const std::string& sync = op->args[0].as<StringImm>()->value;
llvm::Value* CreateStorageSync(const CallNode* op) final {
const std::string& sync = op->args[0].as<StringImmNode>()->value;
if (sync == "warp") {
// TODO(tqchen) warp sync in CUDA9
return nullptr;
......
......@@ -65,14 +65,14 @@ bool TargetHasFeature(const llvm::TargetMachine& tm, const std::string& feature)
class CodeGenX86_64 final : public CodeGenCPU {
public:
llvm::Value* VisitExpr_(const Cast* op) override;
llvm::Value* VisitExpr_(const CastNode* op) override;
private:
llvm::Value* CallVectorIntrin(llvm::Intrinsic::ID id, size_t intrin_lanes, llvm::Type* result_ty,
const std::vector<llvm::Value*>& args);
};
llvm::Value* CodeGenX86_64::VisitExpr_(const Cast* op) {
llvm::Value* CodeGenX86_64::VisitExpr_(const CastNode* op) {
// LLVM does not automatically generate the correct instruction sequences for
// half -> float conversion (i.e. using AVX2/AVX-512 vectorized variants of
// vcvtph2ps), so we explicitly generate them ourselves.
......@@ -90,22 +90,23 @@ llvm::Value* CodeGenX86_64::VisitExpr_(const Cast* op) {
::llvm::Intrinsic::x86_avx512_mask_vcvtph2ps_512, 16,
LLVMType(DataType::Float(32, from.lanes())),
{
MakeValue(ir::Call::make(
DataType::Int(16, from.lanes()), ir::Call::reinterpret, {op->value},
ir::Call::PureIntrinsic)),
MakeValue(ir::CallNode::make(
DataType::Int(16, from.lanes()), ir::CallNode::reinterpret, {op->value},
ir::CallNode::PureIntrinsic)),
MakeValue(
ir::Broadcast::make(ir::FloatImm::make(DataType::Float(32), 0), from.lanes())),
/*mask=*/MakeValue(ir::IntImm::make(DataType::Int(16), -1)),
/*rounding-mode=*/MakeValue(ir::IntImm::make(DataType::Int(32), 4)),
ir::BroadcastNode::make(
ir::FloatImmNode::make(DataType::Float(32), 0), from.lanes())),
/*mask=*/MakeValue(ir::IntImmNode::make(DataType::Int(16), -1)),
/*rounding-mode=*/MakeValue(ir::IntImmNode::make(DataType::Int(32), 4)),
});
}
if (from.lanes() >= 8 && has_f16c) {
return CallVectorIntrin(
::llvm::Intrinsic::x86_vcvtph2ps_256, 8, LLVMType(DataType::Float(32, from.lanes())),
{MakeValue(ir::Call::make(
DataType::Int(16, from.lanes()), ir::Call::reinterpret, {op->value},
ir::Call::PureIntrinsic))});
{MakeValue(ir::CallNode::make(
DataType::Int(16, from.lanes()), ir::CallNode::reinterpret, {op->value},
ir::CallNode::PureIntrinsic))});
}
}
......
......@@ -64,21 +64,21 @@ TVM_REGISTER_GLOBAL("tvm.intrin.rule.llvm.nearbyint")
TVM_REGISTER_GLOBAL("tvm.intrin.rule.llvm.tanh")
.set_body([](const TVMArgs& targs, TVMRetValue* rv) {
Expr e = targs[0];
const ir::Call* call = e.as<ir::Call>();
const ir::CallNode* call = e.as<ir::CallNode>();
CHECK(call != nullptr);
const Expr& x = call->args[0];
Expr one = make_const(x.dtype(), 1);
Expr two = make_const(x.dtype(), 2);
Expr neg_two = make_const(x.dtype(), -2);
Expr exp_neg2x = ir::Call::make(
x.dtype(), "exp", {neg_two * x}, ir::Call::PureIntrinsic);
Expr exp_pos2x = ir::Call::make(
x.dtype(), "exp", {two * x}, ir::Call::PureIntrinsic);
Expr exp_neg2x = ir::CallNode::make(
x.dtype(), "exp", {neg_two * x}, ir::CallNode::PureIntrinsic);
Expr exp_pos2x = ir::CallNode::make(
x.dtype(), "exp", {two * x}, ir::CallNode::PureIntrinsic);
Expr tanh_pos = (one - exp_neg2x) / (one + exp_neg2x);
Expr tanh_neg = (exp_pos2x - one) / (exp_pos2x + one);
*rv = ir::Select::make(
*rv = ir::SelectNode::make(
x >= make_zero(x.dtype()), tanh_pos, tanh_neg);
});
......
......@@ -39,34 +39,34 @@ namespace codegen {
template<unsigned id, int num_signature>
inline void DispatchLLVMPureIntrin(const TVMArgs& targs, TVMRetValue* rv) {
Expr e = targs[0];
const ir::Call* call = e.as<ir::Call>();
const ir::CallNode* call = e.as<ir::CallNode>();
CHECK(call != nullptr);
Array<Expr> cargs;
// intrin id.
cargs.push_back(ir::UIntImm::make(DataType::UInt(32), id));
cargs.push_back(ir::UIntImm::make(DataType::UInt(32), num_signature));
cargs.push_back(ir::UIntImmNode::make(DataType::UInt(32), id));
cargs.push_back(ir::UIntImmNode::make(DataType::UInt(32), num_signature));
for (Expr arg : call->args) {
cargs.push_back(arg);
}
*rv = ir::Call::make(
call->dtype, "llvm_intrin", cargs, ir::Call::PureIntrinsic);
*rv = ir::CallNode::make(
call->dtype, "llvm_intrin", cargs, ir::CallNode::PureIntrinsic);
}
template<unsigned id, int num_signature>
inline void DispatchLLVMIntrin(const TVMArgs& targs, TVMRetValue* rv) {
Expr e = targs[0];
const ir::Call* call = e.as<ir::Call>();
const ir::CallNode* call = e.as<ir::CallNode>();
CHECK(call != nullptr);
Array<Expr> cargs;
// intrin id.
cargs.push_back(ir::UIntImm::make(DataType::UInt(32), id));
cargs.push_back(ir::UIntImm::make(DataType::UInt(32), num_signature));
cargs.push_back(ir::UIntImmNode::make(DataType::UInt(32), id));
cargs.push_back(ir::UIntImmNode::make(DataType::UInt(32), num_signature));
for (Expr arg : call->args) {
cargs.push_back(arg);
}
*rv = ir::Call::make(
call->dtype, "llvm_intrin", cargs, ir::Call::Intrinsic);
*rv = ir::CallNode::make(
call->dtype, "llvm_intrin", cargs, ir::CallNode::Intrinsic);
}
} // namespace codegen
......
......@@ -35,14 +35,14 @@ namespace codegen {
inline void DispatchExternLibDevice(const TVMArgs& args, TVMRetValue* rv) {
Expr e = args[0];
using namespace ir;
const Call* call = e.as<Call>();
const CallNode* call = e.as<CallNode>();
CHECK(call != nullptr);
CHECK(call->dtype.bits() == 32 || call->dtype.bits() == 64) << "Only support float32 or float64.";
std::ostringstream intrinsic_name;
intrinsic_name << "__nv_" << call->name;
if (call->dtype.bits() == 32) intrinsic_name << "f";
*rv = Call::make(call->dtype, intrinsic_name.str(), call->args,
Call::PureExtern);
*rv = CallNode::make(call->dtype, intrinsic_name.str(), call->args,
CallNode::PureExtern);
}
namespace llvm {
......
......@@ -35,12 +35,12 @@ namespace codegen {
inline void DispatchExternOCML(const TVMArgs& args, TVMRetValue* rv) {
Expr e = args[0];
using namespace ir;
const Call* call = e.as<Call>();
const CallNode* call = e.as<CallNode>();
CHECK(call != nullptr);
std::ostringstream intrinsic_name;
intrinsic_name << "__ocml_" << call->name << "_f" << call->dtype.bits();
*rv = Call::make(call->dtype, intrinsic_name.str(), call->args,
Call::PureExtern);
*rv = CallNode::make(call->dtype, intrinsic_name.str(), call->args,
CallNode::PureExtern);
}
namespace llvm {
......
......@@ -62,45 +62,45 @@ class CodeGenSPIRV:
return VisitExpr(e);
}
// override codegen
spirv::Value VisitExpr_(const Variable* op) override;
spirv::Value VisitExpr_(const Cast* op) override;
spirv::Value VisitExpr_(const IntImm* op) override;
spirv::Value VisitExpr_(const UIntImm* op) override;
spirv::Value VisitExpr_(const FloatImm* op) override;
spirv::Value VisitExpr_(const StringImm* op) override;
spirv::Value VisitExpr_(const Add* op) override;
spirv::Value VisitExpr_(const Sub* op) override;
spirv::Value VisitExpr_(const Mul* op) override;
spirv::Value VisitExpr_(const Div* op) override;
spirv::Value VisitExpr_(const Mod* op) override;
spirv::Value VisitExpr_(const Min* op) override;
spirv::Value VisitExpr_(const Max* op) override;
spirv::Value VisitExpr_(const LT* op) override;
spirv::Value VisitExpr_(const LE* op) override;
spirv::Value VisitExpr_(const GT* op) override;
spirv::Value VisitExpr_(const GE* op) override;
spirv::Value VisitExpr_(const EQ* op) override;
spirv::Value VisitExpr_(const NE* op) override;
spirv::Value VisitExpr_(const And* op) override;
spirv::Value VisitExpr_(const Or* op) override;
spirv::Value VisitExpr_(const Not* op) override;
spirv::Value VisitExpr_(const Select* op) override;
spirv::Value VisitExpr_(const Let* op) override;
spirv::Value VisitExpr_(const Call* op) override;
spirv::Value VisitExpr_(const Ramp* op) override;
spirv::Value VisitExpr_(const Broadcast* op) override;
spirv::Value VisitExpr_(const Load* op) override;
spirv::Value VisitExpr_(const VarNode* op) override;
spirv::Value VisitExpr_(const CastNode* op) override;
spirv::Value VisitExpr_(const IntImmNode* op) override;
spirv::Value VisitExpr_(const UIntImmNode* op) override;
spirv::Value VisitExpr_(const FloatImmNode* op) override;
spirv::Value VisitExpr_(const StringImmNode* op) override;
spirv::Value VisitExpr_(const AddNode* op) override;
spirv::Value VisitExpr_(const SubNode* op) override;
spirv::Value VisitExpr_(const MulNode* op) override;
spirv::Value VisitExpr_(const DivNode* op) override;
spirv::Value VisitExpr_(const ModNode* op) override;
spirv::Value VisitExpr_(const MinNode* op) override;
spirv::Value VisitExpr_(const MaxNode* op) override;
spirv::Value VisitExpr_(const LTNode* op) override;
spirv::Value VisitExpr_(const LENode* op) override;
spirv::Value VisitExpr_(const GTNode* op) override;
spirv::Value VisitExpr_(const GENode* op) override;
spirv::Value VisitExpr_(const EQNode* op) override;
spirv::Value VisitExpr_(const NENode* op) override;
spirv::Value VisitExpr_(const AndNode* op) override;
spirv::Value VisitExpr_(const OrNode* op) override;
spirv::Value VisitExpr_(const NotNode* op) override;
spirv::Value VisitExpr_(const SelectNode* op) override;
spirv::Value VisitExpr_(const LetNode* op) override;
spirv::Value VisitExpr_(const CallNode* op) override;
spirv::Value VisitExpr_(const RampNode* op) override;
spirv::Value VisitExpr_(const BroadcastNode* op) override;
spirv::Value VisitExpr_(const LoadNode* op) override;
// stmt
void VisitStmt_(const Store* op) override;
void VisitStmt_(const For* op) override;
void VisitStmt_(const IfThenElse* op) override;
void VisitStmt_(const Allocate* op) override;
void VisitStmt_(const AttrStmt* op) override;
void VisitStmt_(const AssertStmt* op) override;
void VisitStmt_(const LetStmt* op) override;
void VisitStmt_(const StoreNode* op) override;
void VisitStmt_(const ForNode* op) override;
void VisitStmt_(const IfThenElseNode* op) override;
void VisitStmt_(const AllocateNode* op) override;
void VisitStmt_(const AttrStmtNode* op) override;
void VisitStmt_(const AssertStmtNode* op) override;
void VisitStmt_(const LetStmtNode* op) override;
void VisitStmt_(const SeqStmtNode* op) override;
void VisitStmt_(const Evaluate* op) override;
void VisitStmt_(const ProducerConsumer* op) override;
void VisitStmt_(const EvaluateNode* op) override;
void VisitStmt_(const ProducerConsumerNode* op) override;
protected:
/*! \brief The storage information */
......@@ -129,7 +129,7 @@ class CodeGenSPIRV:
void InitFuncState();
// Get the thread index
spirv::Value GetThreadIndex(const IterVar& iv, const Expr& extent);
spirv::Value CreateStorageSync(const Call* op);
spirv::Value CreateStorageSync(const CallNode* op);
void Scalarize(const Expr& e,
std::function<void(int i, spirv::Value v)> f);
// The builder
......@@ -139,9 +139,9 @@ class CodeGenSPIRV:
// Likely branch
uint32_t weight_likely_branch_{128};
// the storage scope of allocation
std::unordered_map<const Variable*, StorageInfo> storage_info_;
std::unordered_map<const VarNode*, StorageInfo> storage_info_;
// The definition of local variable.
std::unordered_map<const Variable*, spirv::Value> var_map_;
std::unordered_map<const VarNode*, spirv::Value> var_map_;
// The analyzer.
std::unique_ptr<arith::Analyzer> analyzer_;
};
......
......@@ -35,17 +35,17 @@ using namespace runtime;
template<unsigned id>
inline void DispatchGLSLPureIntrin(const TVMArgs& targs, TVMRetValue* rv) {
Expr e = targs[0];
const ir::Call* call = e.as<ir::Call>();
const ir::CallNode* call = e.as<ir::CallNode>();
CHECK(call != nullptr);
Array<Expr> cargs;
// intrin id.
cargs.push_back(ir::UIntImm::make(DataType::UInt(32), id));
cargs.push_back(ir::UIntImmNode::make(DataType::UInt(32), id));
for (Expr arg : call->args) {
cargs.push_back(arg);
}
*rv = ir::Call::make(
call->dtype, "spirv_glsl450", cargs, ir::Call::PureIntrinsic);
*rv = ir::CallNode::make(
call->dtype, "spirv_glsl450", cargs, ir::CallNode::PureIntrinsic);
}
TVM_REGISTER_GLOBAL("tvm.intrin.rule.vulkan.floor")
......
......@@ -96,13 +96,13 @@ class CodeGenStackVM
* \param v The variable.
* \return the heap index of the var.
*/
int AllocVarID(const Variable* v);
int AllocVarID(const VarNode* v);
/*!
* \brief Get a variable name.
* \param v The variable.
* \return the heap index of the var.
*/
int GetVarID(const Variable* v) const;
int GetVarID(const VarNode* v) const;
// Push binary operator
void PushBinary(StackVM::OpCode op_int64,
const Expr& a,
......@@ -111,52 +111,52 @@ class CodeGenStackVM
void PushCast(DataType dst, DataType src);
// overloadable functions
// expression
void VisitExpr_(const Variable* op) final;
void VisitExpr_(const Load* op) final;
void VisitExpr_(const Let* op) final;
void VisitExpr_(const Call* op) final;
void VisitExpr_(const Add* op) final;
void VisitExpr_(const Sub* op) final;
void VisitExpr_(const Mul* op) final;
void VisitExpr_(const Div* op) final;
void VisitExpr_(const Mod* op) final;
void VisitExpr_(const Min* op) final;
void VisitExpr_(const Max* op) final;
void VisitExpr_(const EQ* op) final;
void VisitExpr_(const NE* op) final;
void VisitExpr_(const LT* op) final;
void VisitExpr_(const LE* op) final;
void VisitExpr_(const GT* op) final;
void VisitExpr_(const GE* op) final;
void VisitExpr_(const And* op) final;
void VisitExpr_(const Or* op) final;
void VisitExpr_(const Cast* op) final;
void VisitExpr_(const Not* op) final;
void VisitExpr_(const Select* op) final;
void VisitExpr_(const Ramp* op) final;
void VisitExpr_(const Broadcast* op) final;
void VisitExpr_(const IntImm* op) final;
void VisitExpr_(const UIntImm* op) final;
void VisitExpr_(const FloatImm* op) final;
void VisitExpr_(const StringImm* op) final;
void VisitExpr_(const VarNode* op) final;
void VisitExpr_(const LoadNode* op) final;
void VisitExpr_(const LetNode* op) final;
void VisitExpr_(const CallNode* op) final;
void VisitExpr_(const AddNode* op) final;
void VisitExpr_(const SubNode* op) final;
void VisitExpr_(const MulNode* op) final;
void VisitExpr_(const DivNode* op) final;
void VisitExpr_(const ModNode* op) final;
void VisitExpr_(const MinNode* op) final;
void VisitExpr_(const MaxNode* op) final;
void VisitExpr_(const EQNode* op) final;
void VisitExpr_(const NENode* op) final;
void VisitExpr_(const LTNode* op) final;
void VisitExpr_(const LENode* op) final;
void VisitExpr_(const GTNode* op) final;
void VisitExpr_(const GENode* op) final;
void VisitExpr_(const AndNode* op) final;
void VisitExpr_(const OrNode* op) final;
void VisitExpr_(const CastNode* op) final;
void VisitExpr_(const NotNode* op) final;
void VisitExpr_(const SelectNode* op) final;
void VisitExpr_(const RampNode* op) final;
void VisitExpr_(const BroadcastNode* op) final;
void VisitExpr_(const IntImmNode* op) final;
void VisitExpr_(const UIntImmNode* op) final;
void VisitExpr_(const FloatImmNode* op) final;
void VisitExpr_(const StringImmNode* op) final;
// statment
void VisitStmt_(const LetStmt* op) final;
void VisitStmt_(const Store* op) final;
void VisitStmt_(const For* op) final;
void VisitStmt_(const IfThenElse* op) final;
void VisitStmt_(const Allocate* op) final;
void VisitStmt_(const AttrStmt* op) final;
void VisitStmt_(const AssertStmt* op) final;
void VisitStmt_(const Evaluate* op) final;
void VisitStmt_(const LetStmtNode* op) final;
void VisitStmt_(const StoreNode* op) final;
void VisitStmt_(const ForNode* op) final;
void VisitStmt_(const IfThenElseNode* op) final;
void VisitStmt_(const AllocateNode* op) final;
void VisitStmt_(const AttrStmtNode* op) final;
void VisitStmt_(const AssertStmtNode* op) final;
void VisitStmt_(const EvaluateNode* op) final;
void VisitStmt_(const SeqStmtNode* op) final;
void VisitStmt_(const ProducerConsumer* op) final;
void VisitStmt_(const ProducerConsumerNode* op) final;
private:
bool debug_{false};
/*! \brief The vm to be generated */
StackVM vm_;
/*! \brief id of each variable */
std::unordered_map<const Variable*, int> var_idmap_;
std::unordered_map<const VarNode*, int> var_idmap_;
/*! \brief id of each string */
std::unordered_map<std::string, int> str_idmap_;
/*! \brief id of each global function */
......
......@@ -90,49 +90,49 @@ class CodeGenHybrid :
return os.str();
}
// expression
void VisitExpr_(const Variable* op, std::ostream& os) override; // NOLINT(*)
void VisitExpr_(const Load* op, std::ostream& os) override; // NOLINT(*)
void VisitExpr_(const Let* op, std::ostream& os) override; // NOLINT(*)
void VisitExpr_(const Call* op, std::ostream& os) override; // NOLINT(*)
void VisitExpr_(const Add* op, std::ostream& os) override; // NOLINT(*)
void VisitExpr_(const Sub* op, std::ostream& os) override; // NOLINT(*)
void VisitExpr_(const Mul* op, std::ostream& os) override; // NOLINT(*)
void VisitExpr_(const Div* op, std::ostream& os) override; // NOLINT(*)
void VisitExpr_(const Mod* op, std::ostream& os) override; // NOLINT(*)
void VisitExpr_(const FloorDiv* op, std::ostream& os) override; // NOLINT(*)
void VisitExpr_(const FloorMod* op, std::ostream& os) override; // NOLINT(*)
void VisitExpr_(const Min* op, std::ostream& os) override; // NOLINT(*)
void VisitExpr_(const Max* op, std::ostream& os) override; // NOLINT(*)
void VisitExpr_(const EQ* op, std::ostream& os) override; // NOLINT(*)
void VisitExpr_(const NE* op, std::ostream& os) override; // NOLINT(*)
void VisitExpr_(const LT* op, std::ostream& os) override; // NOLINT(*)
void VisitExpr_(const LE* op, std::ostream& os) override; // NOLINT(*)
void VisitExpr_(const GT* op, std::ostream& os) override; // NOLINT(*)
void VisitExpr_(const GE* op, std::ostream& os) override; // NOLINT(*)
void VisitExpr_(const And* op, std::ostream& os) override; // NOLINT(*)
void VisitExpr_(const Or* op, std::ostream& os) override; // NOLINT(*)
void VisitExpr_(const Cast* op, std::ostream& os) override; // NOLINT(*)
void VisitExpr_(const Not* op, std::ostream& os) override; // NOLINT(*)
void VisitExpr_(const Select* op, std::ostream& os) override; // NOLINT(*)
void VisitExpr_(const Ramp* op, std::ostream& os) override; // NOLINT(*)
void VisitExpr_(const Broadcast* op, std::ostream& os) override; // NOLINT(*)
void VisitExpr_(const IntImm* op, std::ostream& os) override; // NOLINT(*)
void VisitExpr_(const UIntImm* op, std::ostream& os) override; // NOLINT(*)
void VisitExpr_(const FloatImm* op, std::ostream& os) override; // NOLINT(*)
void VisitExpr_(const StringImm* op, std::ostream& os) override; // NOLINT(*)
void VisitExpr_(const VarNode* op, std::ostream& os) override; // NOLINT(*)
void VisitExpr_(const LoadNode* op, std::ostream& os) override; // NOLINT(*)
void VisitExpr_(const LetNode* op, std::ostream& os) override; // NOLINT(*)
void VisitExpr_(const CallNode* op, std::ostream& os) override; // NOLINT(*)
void VisitExpr_(const AddNode* op, std::ostream& os) override; // NOLINT(*)
void VisitExpr_(const SubNode* op, std::ostream& os) override; // NOLINT(*)
void VisitExpr_(const MulNode* op, std::ostream& os) override; // NOLINT(*)
void VisitExpr_(const DivNode* op, std::ostream& os) override; // NOLINT(*)
void VisitExpr_(const ModNode* op, std::ostream& os) override; // NOLINT(*)
void VisitExpr_(const FloorDivNode* op, std::ostream& os) override; // NOLINT(*)
void VisitExpr_(const FloorModNode* op, std::ostream& os) override; // NOLINT(*)
void VisitExpr_(const MinNode* op, std::ostream& os) override; // NOLINT(*)
void VisitExpr_(const MaxNode* op, std::ostream& os) override; // NOLINT(*)
void VisitExpr_(const EQNode* op, std::ostream& os) override; // NOLINT(*)
void VisitExpr_(const NENode* op, std::ostream& os) override; // NOLINT(*)
void VisitExpr_(const LTNode* op, std::ostream& os) override; // NOLINT(*)
void VisitExpr_(const LENode* op, std::ostream& os) override; // NOLINT(*)
void VisitExpr_(const GTNode* op, std::ostream& os) override; // NOLINT(*)
void VisitExpr_(const GENode* op, std::ostream& os) override; // NOLINT(*)
void VisitExpr_(const AndNode* op, std::ostream& os) override; // NOLINT(*)
void VisitExpr_(const OrNode* op, std::ostream& os) override; // NOLINT(*)
void VisitExpr_(const CastNode* op, std::ostream& os) override; // NOLINT(*)
void VisitExpr_(const NotNode* op, std::ostream& os) override; // NOLINT(*)
void VisitExpr_(const SelectNode* op, std::ostream& os) override; // NOLINT(*)
void VisitExpr_(const RampNode* op, std::ostream& os) override; // NOLINT(*)
void VisitExpr_(const BroadcastNode* op, std::ostream& os) override; // NOLINT(*)
void VisitExpr_(const IntImmNode* op, std::ostream& os) override; // NOLINT(*)
void VisitExpr_(const UIntImmNode* op, std::ostream& os) override; // NOLINT(*)
void VisitExpr_(const FloatImmNode* op, std::ostream& os) override; // NOLINT(*)
void VisitExpr_(const StringImmNode* op, std::ostream& os) override; // NOLINT(*)
// statment
void VisitStmt_(const LetStmt* op) override;
void VisitStmt_(const Store* op) override;
void VisitStmt_(const Provide* op) override;
void VisitStmt_(const For* op) override;
void VisitStmt_(const IfThenElse* op) override;
void VisitStmt_(const Allocate* op) override;
void VisitStmt_(const Realize* op) override;
void VisitStmt_(const AttrStmt* op) override;
void VisitStmt_(const AssertStmt* op) override;
void VisitStmt_(const Evaluate* op) override;
void VisitStmt_(const LetStmtNode* op) override;
void VisitStmt_(const StoreNode* op) override;
void VisitStmt_(const ProvideNode* op) override;
void VisitStmt_(const ForNode* op) override;
void VisitStmt_(const IfThenElseNode* op) override;
void VisitStmt_(const AllocateNode* op) override;
void VisitStmt_(const RealizeNode* op) override;
void VisitStmt_(const AttrStmtNode* op) override;
void VisitStmt_(const AssertStmtNode* op) override;
void VisitStmt_(const EvaluateNode* op) override;
void VisitStmt_(const SeqStmtNode* op) override;
void VisitStmt_(const ProducerConsumer* op) override;
void VisitStmt_(const ProducerConsumerNode* op) override;
/*!
* \brief Print Type represetnation of type t.
* \param t The type representation.
......@@ -154,7 +154,7 @@ class CodeGenHybrid :
* Values are the corresponding IDs.*/
std::map<std::pair<const Object *, int>, std::string> id_map_;
/*! \brief Variables (keys) binded to the threads (values). */
std::map<const Variable *, std::string> binds_;
std::map<const VarNode *, std::string> binds_;
/*!
* \brief Find an unallocated name for the given prefix.
* \param prefix The given prefix.
......@@ -166,7 +166,7 @@ class CodeGenHybrid :
* \brief Get or allocate the ID for the given variable.
* \param v The given variable.
*/
std::string GetVarID(const Variable *v);
std::string GetVarID(const VarNode *v);
/*!
* \brief Get or allocate the ID for the given tensor.
* \param func The tensor to allocate a name.
......
......@@ -90,29 +90,29 @@ bool AttrsEqualHandler::VisitAttrDefault_(const Object* lhs, const ObjectRef& ot
return lhs == other.get();
}
bool AttrsEqualHandler::VisitAttr_(const IntImm* lhs, const ObjectRef& other) {
if (const auto* rhs = other.as<IntImm>()) {
bool AttrsEqualHandler::VisitAttr_(const IntImmNode* lhs, const ObjectRef& other) {
if (const auto* rhs = other.as<IntImmNode>()) {
return lhs->value == rhs->value;
}
return false;
}
bool AttrsEqualHandler::VisitAttr_(const UIntImm* lhs, const ObjectRef& other) {
if (const auto* rhs = other.as<UIntImm>()) {
bool AttrsEqualHandler::VisitAttr_(const UIntImmNode* lhs, const ObjectRef& other) {
if (const auto* rhs = other.as<UIntImmNode>()) {
return lhs->value == rhs->value;
}
return false;
}
bool AttrsEqualHandler::VisitAttr_(const FloatImm* lhs, const ObjectRef& other) {
if (const auto* rhs = other.as<FloatImm>()) {
bool AttrsEqualHandler::VisitAttr_(const FloatImmNode* lhs, const ObjectRef& other) {
if (const auto* rhs = other.as<FloatImmNode>()) {
return lhs->value == rhs->value;
}
return false;
}
bool AttrsEqualHandler::VisitAttr_(const StringImm* lhs, const ObjectRef& other) {
if (const auto* rhs = other.as<StringImm>()) {
bool AttrsEqualHandler::VisitAttr_(const StringImmNode* lhs, const ObjectRef& other) {
if (const auto* rhs = other.as<StringImmNode>()) {
return lhs->value == rhs->value;
}
return false;
......@@ -151,34 +151,34 @@ bool AttrsEqualHandler::VisitAttr_(const StrMapNode* lhs, const ObjectRef& other
} \
} \
TVM_DEFINE_ATTRS_BINOP_EQUAL(Add);
TVM_DEFINE_ATTRS_BINOP_EQUAL(Sub);
TVM_DEFINE_ATTRS_BINOP_EQUAL(Mul);
TVM_DEFINE_ATTRS_BINOP_EQUAL(Div);
TVM_DEFINE_ATTRS_BINOP_EQUAL(Mod);
TVM_DEFINE_ATTRS_BINOP_EQUAL(FloorDiv);
TVM_DEFINE_ATTRS_BINOP_EQUAL(FloorMod);
TVM_DEFINE_ATTRS_BINOP_EQUAL(Max);
TVM_DEFINE_ATTRS_BINOP_EQUAL(Min);
TVM_DEFINE_ATTRS_BINOP_EQUAL(GE);
TVM_DEFINE_ATTRS_BINOP_EQUAL(GT);
TVM_DEFINE_ATTRS_BINOP_EQUAL(LE);
TVM_DEFINE_ATTRS_BINOP_EQUAL(LT);
TVM_DEFINE_ATTRS_BINOP_EQUAL(EQ);
TVM_DEFINE_ATTRS_BINOP_EQUAL(NE);
TVM_DEFINE_ATTRS_BINOP_EQUAL(And);
TVM_DEFINE_ATTRS_BINOP_EQUAL(Or);
bool AttrsEqualHandler::VisitAttr_(const Not* lhs, const ObjectRef& other) {
if (const auto* rhs = other.as<Not>()) {
TVM_DEFINE_ATTRS_BINOP_EQUAL(AddNode);
TVM_DEFINE_ATTRS_BINOP_EQUAL(SubNode);
TVM_DEFINE_ATTRS_BINOP_EQUAL(MulNode);
TVM_DEFINE_ATTRS_BINOP_EQUAL(DivNode);
TVM_DEFINE_ATTRS_BINOP_EQUAL(ModNode);
TVM_DEFINE_ATTRS_BINOP_EQUAL(FloorDivNode);
TVM_DEFINE_ATTRS_BINOP_EQUAL(FloorModNode);
TVM_DEFINE_ATTRS_BINOP_EQUAL(MaxNode);
TVM_DEFINE_ATTRS_BINOP_EQUAL(MinNode);
TVM_DEFINE_ATTRS_BINOP_EQUAL(GENode);
TVM_DEFINE_ATTRS_BINOP_EQUAL(GTNode);
TVM_DEFINE_ATTRS_BINOP_EQUAL(LENode);
TVM_DEFINE_ATTRS_BINOP_EQUAL(LTNode);
TVM_DEFINE_ATTRS_BINOP_EQUAL(EQNode);
TVM_DEFINE_ATTRS_BINOP_EQUAL(NENode);
TVM_DEFINE_ATTRS_BINOP_EQUAL(AndNode);
TVM_DEFINE_ATTRS_BINOP_EQUAL(OrNode);
bool AttrsEqualHandler::VisitAttr_(const NotNode* lhs, const ObjectRef& other) {
if (const auto* rhs = other.as<NotNode>()) {
return Equal(lhs->a, rhs->a);
} else {
return false;
}
}
bool AttrsEqualHandler::VisitAttr_(const Cast* lhs, const ObjectRef& other) {
if (const auto* rhs = other.as<Cast>()) {
bool AttrsEqualHandler::VisitAttr_(const CastNode* lhs, const ObjectRef& other) {
if (const auto* rhs = other.as<CastNode>()) {
if (lhs->dtype != rhs->dtype) return false;
return Equal(lhs->value, rhs->value);
} else {
......@@ -186,8 +186,8 @@ bool AttrsEqualHandler::VisitAttr_(const Cast* lhs, const ObjectRef& other) {
}
}
bool AttrsEqualHandler::VisitAttr_(const Call* lhs, const ObjectRef& other) {
if (const auto* rhs = other.as<Call>()) {
bool AttrsEqualHandler::VisitAttr_(const CallNode* lhs, const ObjectRef& other) {
if (const auto* rhs = other.as<CallNode>()) {
return
lhs->name == rhs->name &&
lhs->dtype == rhs->dtype &&
......@@ -198,8 +198,8 @@ bool AttrsEqualHandler::VisitAttr_(const Call* lhs, const ObjectRef& other) {
}
}
bool AttrsEqualHandler::VisitAttr_(const Select* lhs, const ObjectRef& other) {
if (const auto* rhs = other.as<Select>()) {
bool AttrsEqualHandler::VisitAttr_(const SelectNode* lhs, const ObjectRef& other) {
if (const auto* rhs = other.as<SelectNode>()) {
return
Equal(lhs->condition, rhs->condition) &&
Equal(lhs->true_value, rhs->true_value) &&
......@@ -220,19 +220,19 @@ size_t AttrsHashHandler::VisitAttrDefault_(const Object* value) {
}
}
size_t AttrsHashHandler::VisitAttr_(const IntImm* op) {
size_t AttrsHashHandler::VisitAttr_(const IntImmNode* op) {
return std::hash<int64_t>()(op->value);
}
size_t AttrsHashHandler::VisitAttr_(const UIntImm* op) {
size_t AttrsHashHandler::VisitAttr_(const UIntImmNode* op) {
return std::hash<uint64_t>()(op->value);
}
size_t AttrsHashHandler::VisitAttr_(const FloatImm* op) {
size_t AttrsHashHandler::VisitAttr_(const FloatImmNode* op) {
return std::hash<double>()(op->value);
}
size_t AttrsHashHandler::VisitAttr_(const StringImm* op) {
size_t AttrsHashHandler::VisitAttr_(const StringImmNode* op) {
return std::hash<std::string>()(op->value);
}
......@@ -265,31 +265,31 @@ size_t AttrsHashHandler::VisitAttr_(const StrMapNode* lhs) {
return Combine(key, Combine(Hash(op->a), Hash(op->b))); \
} \
TVM_DEFINE_ATTRS_BINOP_HASH(Add);
TVM_DEFINE_ATTRS_BINOP_HASH(Sub);
TVM_DEFINE_ATTRS_BINOP_HASH(Mul);
TVM_DEFINE_ATTRS_BINOP_HASH(Div);
TVM_DEFINE_ATTRS_BINOP_HASH(Mod);
TVM_DEFINE_ATTRS_BINOP_HASH(FloorDiv);
TVM_DEFINE_ATTRS_BINOP_HASH(FloorMod);
TVM_DEFINE_ATTRS_BINOP_HASH(Max);
TVM_DEFINE_ATTRS_BINOP_HASH(Min);
TVM_DEFINE_ATTRS_BINOP_HASH(GE);
TVM_DEFINE_ATTRS_BINOP_HASH(GT);
TVM_DEFINE_ATTRS_BINOP_HASH(LE);
TVM_DEFINE_ATTRS_BINOP_HASH(LT);
TVM_DEFINE_ATTRS_BINOP_HASH(EQ);
TVM_DEFINE_ATTRS_BINOP_HASH(NE);
TVM_DEFINE_ATTRS_BINOP_HASH(And);
TVM_DEFINE_ATTRS_BINOP_HASH(Or);
size_t AttrsHashHandler::VisitAttr_(const Not* op) {
static size_t key = std::hash<std::string>()(Not::_type_key);
TVM_DEFINE_ATTRS_BINOP_HASH(AddNode);
TVM_DEFINE_ATTRS_BINOP_HASH(SubNode);
TVM_DEFINE_ATTRS_BINOP_HASH(MulNode);
TVM_DEFINE_ATTRS_BINOP_HASH(DivNode);
TVM_DEFINE_ATTRS_BINOP_HASH(ModNode);
TVM_DEFINE_ATTRS_BINOP_HASH(FloorDivNode);
TVM_DEFINE_ATTRS_BINOP_HASH(FloorModNode);
TVM_DEFINE_ATTRS_BINOP_HASH(MaxNode);
TVM_DEFINE_ATTRS_BINOP_HASH(MinNode);
TVM_DEFINE_ATTRS_BINOP_HASH(GENode);
TVM_DEFINE_ATTRS_BINOP_HASH(GTNode);
TVM_DEFINE_ATTRS_BINOP_HASH(LENode);
TVM_DEFINE_ATTRS_BINOP_HASH(LTNode);
TVM_DEFINE_ATTRS_BINOP_HASH(EQNode);
TVM_DEFINE_ATTRS_BINOP_HASH(NENode);
TVM_DEFINE_ATTRS_BINOP_HASH(AndNode);
TVM_DEFINE_ATTRS_BINOP_HASH(OrNode);
size_t AttrsHashHandler::VisitAttr_(const NotNode* op) {
static size_t key = std::hash<std::string>()(NotNode::_type_key);
return Combine(key, Hash(op->a));
}
size_t AttrsHashHandler::VisitAttr_(const Cast* op) {
static size_t key = std::hash<std::string>()(Cast::_type_key);
size_t AttrsHashHandler::VisitAttr_(const CastNode* op) {
static size_t key = std::hash<std::string>()(CastNode::_type_key);
AttrsHash hasher;
size_t res = key;
res = Combine(res, hasher(op->dtype));
......@@ -297,8 +297,8 @@ size_t AttrsHashHandler::VisitAttr_(const Cast* op) {
return res;
}
size_t AttrsHashHandler::VisitAttr_(const Call* op) {
static size_t key = std::hash<std::string>()(Call::_type_key);
size_t AttrsHashHandler::VisitAttr_(const CallNode* op) {
static size_t key = std::hash<std::string>()(CallNode::_type_key);
AttrsHash hasher;
size_t res = key;
res = Combine(res, hasher(op->name));
......@@ -307,8 +307,8 @@ size_t AttrsHashHandler::VisitAttr_(const Call* op) {
return res;
}
size_t AttrsHashHandler::VisitAttr_(const Select* op) {
static size_t key = std::hash<std::string>()(Select::_type_key);
size_t AttrsHashHandler::VisitAttr_(const SelectNode* op) {
static size_t key = std::hash<std::string>()(SelectNode::_type_key);
size_t res = key;
res = Combine(res, Hash(op->condition));
res = Combine(res, Hash(op->true_value));
......
......@@ -31,8 +31,8 @@
namespace tvm {
// TODO(tqchen): change to floormod/div
using IndexMod = ir::FloorMod;
using IndexDiv = ir::FloorDiv;
using IndexMod = ir::FloorModNode;
using IndexDiv = ir::FloorDivNode;
Array<Expr> SimplifyArray(Array<Expr> array) {
for (size_t i = 0; i < array.size(); ++i) {
......@@ -65,7 +65,7 @@ inline std::vector<const Expr*> ExprSplitAddition(const Expr &expr) {
while (!split_buffer.empty()) {
const Expr* top_ele = split_buffer.top();
split_buffer.pop();
auto expr_add_match = top_ele->as<Add>();
auto expr_add_match = top_ele->as<AddNode>();
if (expr_add_match) {
split_buffer.push(&expr_add_match->b);
split_buffer.push(&expr_add_match->a);
......@@ -88,13 +88,13 @@ inline std::pair<bool, Expr> MergeMulModInner(const Expr &mult_expr,
const Expr &mod_l_expr,
const Expr &mod_r_expr) {
using namespace ir;
const Mul* mult_ptr = mult_expr.as<Mul>();
const MulNode* mult_ptr = mult_expr.as<MulNode>();
if (!mult_ptr) return std::make_pair(false, Expr());
Expr mult_outer = mult_ptr->b;
const Expr* inner = &(mult_ptr->a);
// 1. Calculate the outer multiplier
while (true) {
mult_ptr = inner->as<Mul>();
mult_ptr = inner->as<MulNode>();
if (mult_ptr) {
inner = &(mult_ptr->a);
mult_outer = mult_ptr->b * mult_outer;
......@@ -113,8 +113,8 @@ inline std::pair<bool, Expr> MergeMulModInner(const Expr &mult_expr,
Expr no_opt_sum; // Sum of the exprs that cannot be optimized
while (true) {
auto inner_div_ptr = search_ptr->as<IndexDiv>();
auto inner_mult_ptr = search_ptr->as<Mul>();
auto inner_add_ptr = search_ptr->as<Add>();
auto inner_mult_ptr = search_ptr->as<MulNode>();
auto inner_add_ptr = search_ptr->as<AddNode>();
if (!inner_div_ptr && !inner_mult_ptr && !inner_add_ptr) {
return std::make_pair(false, Expr());
} else if (inner_div_ptr) {
......@@ -160,7 +160,7 @@ inline void MergeMulModInsertElements(const std::vector<const Expr*>& eles,
*has_mod = false;
for (const Expr* ele : eles) {
auto mod_ptr = ele->as<IndexMod>();
auto mult_ptr = ele->as<Mul>();
auto mult_ptr = ele->as<MulNode>();
if (mod_ptr) {
*has_mod = true;
mod_exprs->emplace_back(std::make_pair(std::move(mod_ptr->a), std::move(mod_ptr->b)));
......@@ -252,7 +252,7 @@ inline Expr ElemOffset(const BufferNode* n, Array<Expr> index) {
if (n->strides.size() == 0) {
// Scalar case
if (n->shape.size() == 0 && index.size() == 1) {
auto is_int = index[0].as<IntImm>();
auto is_int = index[0].as<IntImmNode>();
CHECK(is_int && is_int->value == 0);
base = base + index[0];
} else {
......@@ -285,7 +285,7 @@ inline Expr BufferOffset(const BufferNode* n, Array<Expr> index, DataType dtype)
offset = offset * make_const(offset.dtype(), dtype.lanes());
}
if (dtype.lanes() != 1) {
return ir::Ramp::make(offset, make_const(offset.dtype(), 1), dtype.lanes());
return ir::RampNode::make(offset, make_const(offset.dtype(), 1), dtype.lanes());
} else {
return offset;
}
......@@ -299,13 +299,13 @@ Expr Buffer::vload(Array<Expr> begin, DataType dtype) const {
<< "Cannot load " << dtype
<< " from buffer of " << n->dtype;
if (dtype == DataType::Bool()) {
return ir::Cast::make(
return ir::CastNode::make(
DataType::Bool(),
ir::Load::make(
ir::LoadNode::make(
DataType::Int(8), n->data, BufferOffset(n, begin, DataType::Int(8)),
const_true()));
} else {
return ir::Load::make(
return ir::LoadNode::make(
dtype, n->data, BufferOffset(n, begin, dtype),
const_true(dtype.lanes()));
}
......@@ -320,12 +320,12 @@ Stmt Buffer::vstore(Array<Expr> begin, Expr value) const {
<< "Cannot load " << dtype
<< " from buffer of " << n->dtype;
if (value.dtype() == DataType::Bool()) {
return ir::Store::make(n->data,
ir::Cast::make(DataType::Int(8), value),
return ir::StoreNode::make(n->data,
ir::CastNode::make(DataType::Int(8), value),
BufferOffset(n, begin, DataType::Int(8)),
const_true());
} else {
return ir::Store::make(n->data, value, BufferOffset(n, begin, dtype),
return ir::StoreNode::make(n->data, value, BufferOffset(n, begin, dtype),
const_true(dtype.lanes()));
}
}
......@@ -391,7 +391,7 @@ Expr Buffer::access_ptr(int access_mask, DataType ptr_type, int content_lanes, E
int highest_dim = 0;
extent = self->strides[highest_dim] * self->shape[highest_dim] - offset;
} else {
extent = arith::ComputeReduce<ir::Mul>(self->shape, Expr()) - offset;
extent = arith::ComputeReduce<ir::MulNode>(self->shape, Expr()) - offset;
}
Expr elem_offset = self->elem_offset + offset;
if (content_lanes > 1) {
......@@ -405,8 +405,8 @@ Expr Buffer::access_ptr(int access_mask, DataType ptr_type, int content_lanes, E
Array<Expr> acc_args{
e_dtype, self->data, elem_offset,
extent, make_const(DataType::Int(32), access_mask)};
return ir::Call::make(
ptr_type, ir::intrinsic::tvm_access_ptr, acc_args, ir::Call::Intrinsic);
return ir::CallNode::make(
ptr_type, ir::intrinsic::tvm_access_ptr, acc_args, ir::CallNode::Intrinsic);
}
Buffer BufferNode::make(Var data,
......
......@@ -72,7 +72,7 @@ Layout::Layout(const Array<IterVar>& axes) {
node->axes = axes;
std::ostringstream repr;
for (const IterVar& axis : axes) {
if (const auto* factor = axis->dom->extent.as<IntImm>()) {
if (const auto* factor = axis->dom->extent.as<IntImmNode>()) {
CHECK_GT(factor->value, 0);
repr << factor->value;
}
......@@ -186,7 +186,7 @@ int32_t Layout::FactorOf(const LayoutAxis& axis) const {
if (!this->defined()) return -1;
for (const IterVar& itvar : operator->()->axes) {
if (sub == LayoutAxis::Get(itvar)) {
const auto* factor = itvar->dom->extent.as<IntImm>();
const auto* factor = itvar->dom->extent.as<IntImmNode>();
CHECK(factor);
return factor->value;
}
......@@ -251,7 +251,7 @@ inline Array<Expr> TransformIndex(const Array<Expr>& src_index,
const Array<IterVar>& src_axis,
const Array<Expr>& transform_rule) {
Array<Expr> result;
std::unordered_map<const Variable*, Expr> bind_map;
std::unordered_map<const VarNode*, Expr> bind_map;
for (size_t i = 0; i < src_index.size(); ++i) {
bind_map[src_axis[i]->var.get()] = src_index[i];
}
......@@ -287,18 +287,18 @@ inline Array<Expr> TransformShape(const Array<Expr>& src_shape,
// for major-axis, bind the corresponding size
// for minor-axis, simply bind it as 0, so that we can reuse forward/backward_rule,
// e.g., (C * 16 + c) / 32
std::unordered_map<const Variable*, Expr> bind_map;
std::unordered_map<const VarNode*, Expr> bind_map;
std::unordered_set<size_t> symbolic_var_set;
for (size_t i = 0; i < src_shape.size(); ++i) {
Expr orig_shape = src_shape[i];
IterVar orig_axis = src_axis[i];
if (orig_shape.as<ir::Any>()) {
if (orig_shape.as<ir::AnyNode>()) {
symbolic_var_set.insert(i);
}
if (!LayoutAxis::Get(orig_axis).IsPrimal()) {
if (orig_shape.defined()) {
const auto* orig_shape_const = orig_shape.as<IntImm>();
const auto* orig_axis_extent = orig_axis->dom->extent.as<IntImm>();
const auto* orig_shape_const = orig_shape.as<IntImmNode>();
const auto* orig_axis_extent = orig_axis->dom->extent.as<IntImmNode>();
if (orig_shape_const) {
CHECK_EQ(orig_shape_const->value, orig_axis_extent->value)
<< "Input shape mismatch at index " << i << ". Expected "
......@@ -322,7 +322,7 @@ inline Array<Expr> TransformShape(const Array<Expr>& src_shape,
result.push_back(axis->dom->extent);
} else {
if (symbolic_var_set.count(i)) {
result.push_back(ir::Any::make());
result.push_back(ir::AnyNode::make());
} else {
result.push_back(ir::Simplify(ir::Substitute(rule, bind_map)));
}
......
......@@ -30,19 +30,19 @@
namespace tvm {
Expr::Expr(int32_t value)
: Expr(IntImm::make(DataType::Int(32), value)) {}
: Expr(IntImmNode::make(DataType::Int(32), value)) {}
Expr::Expr(float value)
: Expr(ir::FloatImm::make(DataType::Float(32), value)) {}
: Expr(ir::FloatImmNode::make(DataType::Float(32), value)) {}
Expr::Expr(std::string str)
: Expr(ir::StringImm::make(str)) {}
: Expr(ir::StringImmNode::make(str)) {}
Var::Var(std::string name_hint, DataType t)
: Var(Variable::make(t, name_hint)) {}
: Var(VarNode::make(t, name_hint)) {}
Var Variable::make(DataType t, std::string name_hint) {
ObjectPtr<Variable> node = make_object<Variable>();
Var VarNode::make(DataType t, std::string name_hint) {
ObjectPtr<VarNode> node = make_object<VarNode>();
node->dtype = t;
node->name_hint = std::move(name_hint);
return Var(node);
......@@ -54,10 +54,10 @@ Range::Range(Expr begin, Expr end)
is_zero(begin) ? end : (end - begin))) {
}
Integer IntImm::make(DataType t, int64_t value) {
Integer IntImmNode::make(DataType t, int64_t value) {
CHECK(t.is_int() && t.is_scalar())
<< "ValueError: IntImm can only take scalar.";
ObjectPtr<IntImm> node = make_object<IntImm>();
ObjectPtr<IntImmNode> node = make_object<IntImmNode>();
node->dtype = t;
node->value = value;
return Integer(node);
......@@ -98,8 +98,8 @@ Var var(std::string name_hint, DataType t) {
}
TVM_STATIC_IR_FUNCTOR(NodePrinter, vtable)
.set_dispatch<IntImm>([](const ObjectRef& node, NodePrinter* p) {
auto* op = static_cast<const IntImm*>(node.get());
.set_dispatch<IntImmNode>([](const ObjectRef& node, NodePrinter* p) {
auto* op = static_cast<const IntImmNode*>(node.get());
if (op->dtype == DataType::Int(32)) {
p->stream << op->value;
} else {
......
Markdown is supported
0% or
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment