Unverified Commit f4c5f93b by Tianqi Chen Committed by GitHub

[REFACTOR][IR] Add Node suffix to low-level IR nodes (#4649)

* [REFACTOR][IR] Variable -> VarNode

* [REFACTOR][IR] Add/Sub/Mul/Div -> AddNode/SubNode etc.

* [REFACTOR][IR] Min/Max/FloorDiv/FloorMod -> MinNode/MaxNode etc.

* [REFACTOR][IR] EQ/NE/LT/LE/GT/GE/Select -> EQNode/NENode etc.

* [REFACTOR][IR] Add Node suffix to Select/Call/Load/Ramp/Shuffle/Let

* [REFACTOR][IR] Add node suffix to IntImm/UIntImm/FloatImm/StringImm

* [REFACTOR][IR] Add Node suffix to Any, AttrStmt, AssertStmt

* [REFACTOR][IR] Add Node suffix to Store/Provide/Allocate/Free

* [REFACTOR][IR] Add Node suffix to ProducerConsumer

* Fix lint

* style updates, test fixes
parent df02e730
...@@ -564,7 +564,7 @@ IntSet EvalSet(Expr e, ...@@ -564,7 +564,7 @@ IntSet EvalSet(Expr e,
* \return An integer set that can cover all the possible values of e. * \return An integer set that can cover all the possible values of e.
*/ */
IntSet EvalSet(Expr e, IntSet EvalSet(Expr e,
const std::unordered_map<const Variable*, IntSet>& dom_map); const std::unordered_map<const VarNode*, IntSet>& dom_map);
/*! /*!
* \brief Find an symbolic integer set that contains is union over * \brief Find an symbolic integer set that contains is union over
...@@ -586,7 +586,7 @@ IntSet EvalSet(Range r, ...@@ -586,7 +586,7 @@ IntSet EvalSet(Range r,
* \return An integer set that can cover all the possible values. * \return An integer set that can cover all the possible values.
*/ */
IntSet EvalSet(IntSet s, IntSet EvalSet(IntSet s,
const std::unordered_map<const Variable*, IntSet>& dom_map); const std::unordered_map<const VarNode*, IntSet>& dom_map);
/*! /*!
* \brief Same as EvalSet, but takes unordered_map * \brief Same as EvalSet, but takes unordered_map
* *
...@@ -595,7 +595,7 @@ IntSet EvalSet(IntSet s, ...@@ -595,7 +595,7 @@ IntSet EvalSet(IntSet s,
* \return An integer set that can cover all the possible values of e. * \return An integer set that can cover all the possible values of e.
*/ */
IntSet EvalSet(Range r, IntSet EvalSet(Range r,
const std::unordered_map<const Variable*, IntSet>& dom_map); const std::unordered_map<const VarNode*, IntSet>& dom_map);
/*! \brief Map from Expr to IntSet */ /*! \brief Map from Expr to IntSet */
using ExprIntSetMap = std::unordered_map<Expr, IntSet, ObjectHash, ObjectEqual>; using ExprIntSetMap = std::unordered_map<Expr, IntSet, ObjectHash, ObjectEqual>;
...@@ -609,7 +609,7 @@ using ExprIntSetMap = std::unordered_map<Expr, IntSet, ObjectHash, ObjectEqual>; ...@@ -609,7 +609,7 @@ using ExprIntSetMap = std::unordered_map<Expr, IntSet, ObjectHash, ObjectEqual>;
*/ */
ExprIntSetMap EvalSetForEachSubExpr( ExprIntSetMap EvalSetForEachSubExpr(
Expr e, Expr e,
const std::unordered_map<const Variable*, IntSet>& dom_map); const std::unordered_map<const VarNode*, IntSet>& dom_map);
/*! /*!
* \brief Create an union set of all sets * \brief Create an union set of all sets
...@@ -654,8 +654,8 @@ IntSet DeduceBound(Expr v, Expr cond, ...@@ -654,8 +654,8 @@ IntSet DeduceBound(Expr v, Expr cond,
* \return An integer set that always satisfies the condition. * \return An integer set that always satisfies the condition.
*/ */
IntSet DeduceBound(Expr v, Expr cond, IntSet DeduceBound(Expr v, Expr cond,
const std::unordered_map<const Variable*, IntSet>& hint_map, const std::unordered_map<const VarNode*, IntSet>& hint_map,
const std::unordered_map<const Variable*, IntSet>& relax_map); const std::unordered_map<const VarNode*, IntSet>& relax_map);
/*! /*!
* \brief Infer a regular domain that covers all the calls or provides within the given statement. * \brief Infer a regular domain that covers all the calls or provides within the given statement.
......
...@@ -488,9 +488,9 @@ inline void SetIntValue(T* ptr, const TVMArgValue& val) { ...@@ -488,9 +488,9 @@ inline void SetIntValue(T* ptr, const TVMArgValue& val) {
} else { } else {
Expr expr = val; Expr expr = val;
CHECK(expr.defined()); CHECK(expr.defined());
if (const ir::IntImm* op = expr.as<ir::IntImm>()) { if (const ir::IntImmNode* op = expr.as<ir::IntImmNode>()) {
*ptr = static_cast<T>(op->value); *ptr = static_cast<T>(op->value);
} else if (const ir::UIntImm* op = expr.as<ir::UIntImm>()) { } else if (const ir::UIntImmNode* op = expr.as<ir::UIntImmNode>()) {
*ptr = static_cast<T>(op->value); *ptr = static_cast<T>(op->value);
} else { } else {
LOG(FATAL) << "Expect int value, but get " << expr->GetTypeKey(); LOG(FATAL) << "Expect int value, but get " << expr->GetTypeKey();
...@@ -503,7 +503,7 @@ inline void SetValue<std::string>(std::string* ptr, const TVMArgValue& val) { ...@@ -503,7 +503,7 @@ inline void SetValue<std::string>(std::string* ptr, const TVMArgValue& val) {
*ptr = val.operator std::string(); *ptr = val.operator std::string();
} else { } else {
Expr expr = val; Expr expr = val;
const ir::StringImm* op = expr.as<ir::StringImm>(); const ir::StringImmNode* op = expr.as<ir::StringImmNode>();
CHECK(op != nullptr); CHECK(op != nullptr);
*ptr = op->value; *ptr = op->value;
} }
...@@ -519,11 +519,11 @@ inline void SetValue<double>(double* ptr, const TVMArgValue& val) { ...@@ -519,11 +519,11 @@ inline void SetValue<double>(double* ptr, const TVMArgValue& val) {
} else { } else {
Expr expr = val; Expr expr = val;
CHECK(expr.defined()); CHECK(expr.defined());
if (const ir::IntImm* op = expr.as<ir::IntImm>()) { if (const ir::IntImmNode* op = expr.as<ir::IntImmNode>()) {
*ptr = static_cast<double>(op->value); *ptr = static_cast<double>(op->value);
} else if (const ir::IntImm* op = expr.as<ir::IntImm>()) { } else if (const ir::IntImmNode* op = expr.as<ir::IntImmNode>()) {
*ptr = static_cast<double>(op->value); *ptr = static_cast<double>(op->value);
} else if (const ir::UIntImm* op = expr.as<ir::UIntImm>()) { } else if (const ir::UIntImmNode* op = expr.as<ir::UIntImmNode>()) {
*ptr = static_cast<double>(op->value); *ptr = static_cast<double>(op->value);
} else { } else {
LOG(FATAL) << "Expect float value, but get " << expr->GetTypeKey(); LOG(FATAL) << "Expect float value, but get " << expr->GetTypeKey();
......
...@@ -102,7 +102,7 @@ class Var; ...@@ -102,7 +102,7 @@ class Var;
* - Let * - Let
* - LetStmt * - LetStmt
*/ */
class Variable : public ExprNode { class VarNode : public ExprNode {
public: public:
/*! /*!
* \brief The hint to the variable name. * \brief The hint to the variable name.
...@@ -118,7 +118,7 @@ class Variable : public ExprNode { ...@@ -118,7 +118,7 @@ class Variable : public ExprNode {
} }
static constexpr const char* _type_key = "Variable"; static constexpr const char* _type_key = "Variable";
TVM_DECLARE_FINAL_OBJECT_INFO(Variable, ExprNode); TVM_DECLARE_FINAL_OBJECT_INFO(VarNode, ExprNode);
}; };
/*! \brief a named variable in TVM */ /*! \brief a named variable in TVM */
...@@ -139,18 +139,18 @@ class Var : public Expr { ...@@ -139,18 +139,18 @@ class Var : public Expr {
* \brief Get pointer to the internal value. * \brief Get pointer to the internal value.
* \return the corresponding Variable. * \return the corresponding Variable.
*/ */
const Variable* operator->() const { const VarNode* operator->() const {
return get(); return get();
} }
/*! /*!
* \brief Get pointer to the internal value. * \brief Get pointer to the internal value.
* \return the corresponding Variable. * \return the corresponding Variable.
*/ */
const Variable* get() const { const VarNode* get() const {
return static_cast<const Variable*>(data_.get()); return static_cast<const VarNode*>(data_.get());
} }
/*! \brief type indicate the container type */ /*! \brief type indicate the container type */
using ContainerType = Variable; using ContainerType = VarNode;
}; };
// Backward compatibility, will be removed later. // Backward compatibility, will be removed later.
...@@ -161,7 +161,7 @@ using ExprEqual = ObjectEqual; ...@@ -161,7 +161,7 @@ using ExprEqual = ObjectEqual;
class Integer; class Integer;
/*! \brief ExprNode: constant integer. */ /*! \brief ExprNode: constant integer. */
class IntImm : public ExprNode { class IntImmNode : public ExprNode {
public: public:
/*! \brief the Internal value. */ /*! \brief the Internal value. */
int64_t value; int64_t value;
...@@ -174,7 +174,7 @@ class IntImm : public ExprNode { ...@@ -174,7 +174,7 @@ class IntImm : public ExprNode {
TVM_DLL static Integer make(DataType t, int64_t value); TVM_DLL static Integer make(DataType t, int64_t value);
static constexpr const char* _type_key = "IntImm"; static constexpr const char* _type_key = "IntImm";
TVM_DECLARE_FINAL_OBJECT_INFO(IntImm, ExprNode); TVM_DECLARE_FINAL_OBJECT_INFO(IntImmNode, ExprNode);
}; };
/*! /*!
...@@ -206,8 +206,8 @@ class Integer : public Expr { ...@@ -206,8 +206,8 @@ class Integer : public Expr {
* \brief Get pointer to the internal value. * \brief Get pointer to the internal value.
* \return the content of the integer. * \return the content of the integer.
*/ */
const IntImm* operator->() const { const IntImmNode* operator->() const {
return static_cast<const IntImm*>(get()); return static_cast<const IntImmNode*>(get());
} }
/*! /*!
* \brief convert to int64_t * \brief convert to int64_t
...@@ -218,7 +218,7 @@ class Integer : public Expr { ...@@ -218,7 +218,7 @@ class Integer : public Expr {
return (*this)->value; return (*this)->value;
} }
/*! \brief type indicate the container type */ /*! \brief type indicate the container type */
using ContainerType = IntImm; using ContainerType = IntImmNode;
}; };
/*! \brief range over one dimension */ /*! \brief range over one dimension */
......
...@@ -75,7 +75,7 @@ inline Expr const_false(int lanes = 1) { ...@@ -75,7 +75,7 @@ inline Expr const_false(int lanes = 1) {
*/ */
inline const int64_t* as_const_int(const Expr& x) { inline const int64_t* as_const_int(const Expr& x) {
if (!x.defined()) return nullptr; if (!x.defined()) return nullptr;
if (const ir::IntImm* op = x.as<ir::IntImm>()) { if (const ir::IntImmNode* op = x.as<ir::IntImmNode>()) {
return &(op->value); return &(op->value);
} else { } else {
return nullptr; return nullptr;
...@@ -90,7 +90,7 @@ inline const int64_t* as_const_int(const Expr& x) { ...@@ -90,7 +90,7 @@ inline const int64_t* as_const_int(const Expr& x) {
*/ */
inline const uint64_t* as_const_uint(const Expr& x) { inline const uint64_t* as_const_uint(const Expr& x) {
if (!x.defined()) return nullptr; if (!x.defined()) return nullptr;
if (const ir::UIntImm* op = x.as<ir::UIntImm>()) { if (const ir::UIntImmNode* op = x.as<ir::UIntImmNode>()) {
return &(op->value); return &(op->value);
} else { } else {
return nullptr; return nullptr;
...@@ -600,7 +600,7 @@ TVM_DLL Expr trunc(Expr x); ...@@ -600,7 +600,7 @@ TVM_DLL Expr trunc(Expr x);
// Intrinsic operators // Intrinsic operators
#define TVM_DECLARE_INTRIN_UNARY(OpName) \ #define TVM_DECLARE_INTRIN_UNARY(OpName) \
inline Expr OpName(Expr x) { \ inline Expr OpName(Expr x) { \
return ir::Call::make(x.dtype(), #OpName, {x}, ir::Call::PureIntrinsic); \ return ir::CallNode::make(x.dtype(), #OpName, {x}, ir::CallNode::PureIntrinsic); \
} \ } \
TVM_DECLARE_INTRIN_UNARY(exp); TVM_DECLARE_INTRIN_UNARY(exp);
...@@ -617,11 +617,11 @@ TVM_DECLARE_INTRIN_UNARY(atan); ...@@ -617,11 +617,11 @@ TVM_DECLARE_INTRIN_UNARY(atan);
// Implementation details after this // Implementation details after this
inline bool is_const(const Expr& x) { inline bool is_const(const Expr& x) {
if (x.as<ir::IntImm>() || x.as<ir::UIntImm>()) { if (x.as<ir::IntImmNode>() || x.as<ir::UIntImmNode>()) {
return true; return true;
} else if (const auto* op = x.as<ir::Broadcast>()) { } else if (const auto* op = x.as<ir::BroadcastNode>()) {
const Expr& val = op->value; const Expr& val = op->value;
if (val.as<ir::IntImm>() || val.as<ir::UIntImm>()) { if (val.as<ir::IntImmNode>() || val.as<ir::UIntImmNode>()) {
return true; return true;
} }
} }
...@@ -629,9 +629,9 @@ inline bool is_const(const Expr& x) { ...@@ -629,9 +629,9 @@ inline bool is_const(const Expr& x) {
} }
inline bool is_positive_const(const Expr& a) { inline bool is_positive_const(const Expr& a) {
if (const ir::IntImm* op = a.as<ir::IntImm>()) { if (const ir::IntImmNode* op = a.as<ir::IntImmNode>()) {
return op->value > 0; return op->value > 0;
} else if (const ir::UIntImm* op = a.as<ir::UIntImm>()) { } else if (const ir::UIntImmNode* op = a.as<ir::UIntImmNode>()) {
return op->value > 0; return op->value > 0;
} else { } else {
return false; return false;
...@@ -639,7 +639,7 @@ inline bool is_positive_const(const Expr& a) { ...@@ -639,7 +639,7 @@ inline bool is_positive_const(const Expr& a) {
} }
inline bool is_negative_const(const Expr& a) { inline bool is_negative_const(const Expr& a) {
if (const ir::IntImm* op = a.as<ir::IntImm>()) { if (const ir::IntImmNode* op = a.as<ir::IntImmNode>()) {
return op->value < 0; return op->value < 0;
} else { } else {
return false; return false;
...@@ -647,15 +647,15 @@ inline bool is_negative_const(const Expr& a) { ...@@ -647,15 +647,15 @@ inline bool is_negative_const(const Expr& a) {
} }
inline bool is_const_int(const Expr& x, int64_t value) { inline bool is_const_int(const Expr& x, int64_t value) {
if (const auto* op = x.as<ir::IntImm>()) { if (const auto* op = x.as<ir::IntImmNode>()) {
return op->value == value; return op->value == value;
} else if (const auto* op = x.as<ir::UIntImm>()) { } else if (const auto* op = x.as<ir::UIntImmNode>()) {
return op->value == static_cast<uint64_t>(value); return op->value == static_cast<uint64_t>(value);
} else if (const auto* op = x.as<ir::Broadcast>()) { } else if (const auto* op = x.as<ir::BroadcastNode>()) {
const Expr& val = op->value; const Expr& val = op->value;
if (const auto* opv = val.as<ir::IntImm>()) { if (const auto* opv = val.as<ir::IntImmNode>()) {
return opv->value == value; return opv->value == value;
} else if (const auto* opv = val.as<ir::UIntImm>()) { } else if (const auto* opv = val.as<ir::UIntImmNode>()) {
return opv->value == static_cast<uint64_t>(value); return opv->value == static_cast<uint64_t>(value);
} }
} }
...@@ -664,7 +664,7 @@ inline bool is_const_int(const Expr& x, int64_t value) { ...@@ -664,7 +664,7 @@ inline bool is_const_int(const Expr& x, int64_t value) {
inline bool is_no_op(const Stmt& stmt) { inline bool is_no_op(const Stmt& stmt) {
if (!stmt.defined()) return true; if (!stmt.defined()) return true;
if (const auto* op = stmt.as<ir::Evaluate>()) { if (const auto* op = stmt.as<ir::EvaluateNode>()) {
return is_const(op->value); return is_const(op->value);
} }
if (const auto* op = stmt.as<ir::SeqStmtNode>()) { if (const auto* op = stmt.as<ir::SeqStmtNode>()) {
...@@ -675,15 +675,15 @@ inline bool is_no_op(const Stmt& stmt) { ...@@ -675,15 +675,15 @@ inline bool is_no_op(const Stmt& stmt) {
template<typename ValueType> template<typename ValueType>
inline Expr MakeConstScalar(DataType t, ValueType value) { inline Expr MakeConstScalar(DataType t, ValueType value) {
if (t.is_int()) return ir::IntImm::make(t, static_cast<int64_t>(value)); if (t.is_int()) return ir::IntImmNode::make(t, static_cast<int64_t>(value));
if (t.is_uint()) return ir::UIntImm::make(t, static_cast<uint64_t>(value)); if (t.is_uint()) return ir::UIntImmNode::make(t, static_cast<uint64_t>(value));
if (t.is_float()) return ir::FloatImm::make(t, static_cast<double>(value)); if (t.is_float()) return ir::FloatImmNode::make(t, static_cast<double>(value));
// For now, we store const scalar values of custom datatypes within doubles; later, during the // For now, we store const scalar values of custom datatypes within doubles; later, during the
// datatypes lowering pass, we will lower the value to its true representation in the format // datatypes lowering pass, we will lower the value to its true representation in the format
// specified by the datatype. // specified by the datatype.
// TODO(gus) when do we need to start worrying about doubles not being precise enough? // TODO(gus) when do we need to start worrying about doubles not being precise enough?
if (static_cast<uint8_t>(t.code()) >= static_cast<uint8_t>(kCustomBegin)) if (static_cast<uint8_t>(t.code()) >= static_cast<uint8_t>(kCustomBegin))
return ir::FloatImm::make(t, static_cast<double>(value)); return ir::FloatImmNode::make(t, static_cast<double>(value));
LOG(FATAL) << "cannot make const for type " << t; LOG(FATAL) << "cannot make const for type " << t;
return Expr(); return Expr();
} }
...@@ -693,7 +693,7 @@ inline Expr make_const(DataType t, ValueType value) { ...@@ -693,7 +693,7 @@ inline Expr make_const(DataType t, ValueType value) {
if (t.lanes() == 1) { if (t.lanes() == 1) {
return MakeConstScalar(t, value); return MakeConstScalar(t, value);
} else { } else {
return ir::Broadcast::make( return ir::BroadcastNode::make(
MakeConstScalar(t.element_of(), value), t.lanes()); MakeConstScalar(t.element_of(), value), t.lanes());
} }
} }
......
...@@ -132,7 +132,7 @@ bool ExprUseVar(const Expr& e, const Var& v); ...@@ -132,7 +132,7 @@ bool ExprUseVar(const Expr& e, const Var& v);
* \param vset The variable set. * \param vset The variable set.
* \return Whether e uses vset. * \return Whether e uses vset.
*/ */
bool ExprUseVar(const Expr& e, const std::unordered_set<const Variable*>& vset); bool ExprUseVar(const Expr& e, const std::unordered_set<const VarNode*>& vset);
/*! /*!
* \brief Convert a IR node to be SSA form. * \brief Convert a IR node to be SSA form.
...@@ -148,7 +148,7 @@ TVM_DLL Stmt ConvertSSA(Stmt stmt); ...@@ -148,7 +148,7 @@ TVM_DLL Stmt ConvertSSA(Stmt stmt);
* \return The converted form. * \return The converted form.
*/ */
Stmt Substitute(Stmt stmt, Stmt Substitute(Stmt stmt,
const std::unordered_map<const Variable*, Expr>& value_map); const std::unordered_map<const VarNode*, Expr>& value_map);
/*! /*!
* \brief Substitute the var specified in key->var to be value. * \brief Substitute the var specified in key->var to be value.
...@@ -157,7 +157,7 @@ Stmt Substitute(Stmt stmt, ...@@ -157,7 +157,7 @@ Stmt Substitute(Stmt stmt,
* \return The converted expression. * \return The converted expression.
*/ */
Expr Substitute(Expr expr, Expr Substitute(Expr expr,
const std::unordered_map<const Variable*, Expr>& value_map); const std::unordered_map<const VarNode*, Expr>& value_map);
/*! /*!
* \brief Substitute the var specified in key->var to be value. * \brief Substitute the var specified in key->var to be value.
......
...@@ -109,7 +109,7 @@ class OperationNode : public ir::FunctionBaseNode { ...@@ -109,7 +109,7 @@ class OperationNode : public ir::FunctionBaseNode {
virtual void PropBoundToInputs( virtual void PropBoundToInputs(
const Operation& self, const Operation& self,
arith::Analyzer* analyzer, arith::Analyzer* analyzer,
const std::unordered_map<const Variable*, IntSet>& dom_map, const std::unordered_map<const VarNode*, IntSet>& dom_map,
std::unordered_map<Tensor, TensorDom>* out_dom_map) const = 0; std::unordered_map<Tensor, TensorDom>* out_dom_map) const = 0;
/*! /*!
* \brief Gather the bound from output tensor. * \brief Gather the bound from output tensor.
...@@ -173,7 +173,7 @@ class PlaceholderOpNode : public OperationNode { ...@@ -173,7 +173,7 @@ class PlaceholderOpNode : public OperationNode {
void PropBoundToInputs( void PropBoundToInputs(
const Operation& self, const Operation& self,
arith::Analyzer* analyzer, arith::Analyzer* analyzer,
const std::unordered_map<const Variable*, IntSet>& dom_map, const std::unordered_map<const VarNode*, IntSet>& dom_map,
std::unordered_map<Tensor, TensorDom>* out_dom_map) const final; std::unordered_map<Tensor, TensorDom>* out_dom_map) const final;
void GatherBound( void GatherBound(
const Operation& self, const Operation& self,
...@@ -251,7 +251,7 @@ class TVM_DLL ComputeOpNode : public BaseComputeOpNode { ...@@ -251,7 +251,7 @@ class TVM_DLL ComputeOpNode : public BaseComputeOpNode {
void PropBoundToInputs( void PropBoundToInputs(
const Operation& self, const Operation& self,
arith::Analyzer* analyzer, arith::Analyzer* analyzer,
const std::unordered_map<const Variable*, IntSet>& dom_map, const std::unordered_map<const VarNode*, IntSet>& dom_map,
std::unordered_map<Tensor, TensorDom>* out_dom_map) const final; std::unordered_map<Tensor, TensorDom>* out_dom_map) const final;
Stmt BuildProvide( Stmt BuildProvide(
const Stage& stage, const Stage& stage,
...@@ -304,7 +304,7 @@ class TensorComputeOpNode : public BaseComputeOpNode { ...@@ -304,7 +304,7 @@ class TensorComputeOpNode : public BaseComputeOpNode {
void PropBoundToInputs( void PropBoundToInputs(
const Operation& self, const Operation& self,
arith::Analyzer* analyzer, arith::Analyzer* analyzer,
const std::unordered_map<const Variable*, IntSet>& dom_map, const std::unordered_map<const VarNode*, IntSet>& dom_map,
std::unordered_map<Tensor, TensorDom>* out_dom_map) const final; std::unordered_map<Tensor, TensorDom>* out_dom_map) const final;
Stmt BuildProvide( Stmt BuildProvide(
const Stage& stage, const Stage& stage,
...@@ -379,7 +379,7 @@ class ScanOpNode : public OperationNode { ...@@ -379,7 +379,7 @@ class ScanOpNode : public OperationNode {
void PropBoundToInputs( void PropBoundToInputs(
const Operation& self, const Operation& self,
arith::Analyzer* analyzer, arith::Analyzer* analyzer,
const std::unordered_map<const Variable*, IntSet>& dom_map, const std::unordered_map<const VarNode*, IntSet>& dom_map,
std::unordered_map<Tensor, TensorDom>* out_dom_map) const final; std::unordered_map<Tensor, TensorDom>* out_dom_map) const final;
void GatherBound( void GatherBound(
const Operation& self, const Operation& self,
...@@ -446,7 +446,7 @@ class ExternOpNode : public OperationNode { ...@@ -446,7 +446,7 @@ class ExternOpNode : public OperationNode {
void PropBoundToInputs( void PropBoundToInputs(
const Operation& self, const Operation& self,
arith::Analyzer* analyzer, arith::Analyzer* analyzer,
const std::unordered_map<const Variable*, IntSet>& dom_map, const std::unordered_map<const VarNode*, IntSet>& dom_map,
std::unordered_map<Tensor, TensorDom>* out_dom_map) const final; std::unordered_map<Tensor, TensorDom>* out_dom_map) const final;
void GatherBound( void GatherBound(
const Operation& self, const Operation& self,
...@@ -514,7 +514,7 @@ class HybridOpNode : public OperationNode { ...@@ -514,7 +514,7 @@ class HybridOpNode : public OperationNode {
void PropBoundToInputs( void PropBoundToInputs(
const Operation& self, const Operation& self,
arith::Analyzer* analyzer, arith::Analyzer* analyzer,
const std::unordered_map<const Variable*, IntSet>& dom_map, const std::unordered_map<const VarNode*, IntSet>& dom_map,
std::unordered_map<Tensor, TensorDom>* out_dom_map) const final; std::unordered_map<Tensor, TensorDom>* out_dom_map) const final;
void GatherBound( void GatherBound(
const Operation& self, const Operation& self,
......
...@@ -39,7 +39,7 @@ ...@@ -39,7 +39,7 @@
namespace tvm { namespace tvm {
namespace relay { namespace relay {
using Any = tvm::ir::Any; using Any = tvm::ir::AnyNode;
using Kind = TypeKind; using Kind = TypeKind;
using Type = tvm::Type; using Type = tvm::Type;
using TypeNode = tvm::TypeNode; using TypeNode = tvm::TypeNode;
......
...@@ -33,7 +33,7 @@ namespace ir { ...@@ -33,7 +33,7 @@ namespace ir {
TVM_REGISTER_GLOBAL("_Var") TVM_REGISTER_GLOBAL("_Var")
.set_body_typed([](std::string s, DataType t) { .set_body_typed([](std::string s, DataType t) {
return Variable::make(t, s); return VarNode::make(t, s);
}); });
TVM_REGISTER_GLOBAL("make.abs") TVM_REGISTER_GLOBAL("make.abs")
...@@ -73,7 +73,7 @@ TVM_REGISTER_GLOBAL("make.For") ...@@ -73,7 +73,7 @@ TVM_REGISTER_GLOBAL("make.For")
.set_body_typed([]( .set_body_typed([](
VarExpr loop_var, Expr min, Expr extent, VarExpr loop_var, Expr min, Expr extent,
int for_type, int device_api, Stmt body) { int for_type, int device_api, Stmt body) {
return For::make(loop_var, return ForNode::make(loop_var,
min, min,
extent, extent,
static_cast<ForType>(for_type), static_cast<ForType>(for_type),
...@@ -85,9 +85,9 @@ TVM_REGISTER_GLOBAL("make.Load") ...@@ -85,9 +85,9 @@ TVM_REGISTER_GLOBAL("make.Load")
.set_body([](TVMArgs args, TVMRetValue *ret) { .set_body([](TVMArgs args, TVMRetValue *ret) {
DataType t = args[0]; DataType t = args[0];
if (args.size() == 3) { if (args.size() == 3) {
*ret = Load::make(t, args[1], args[2], const_true(t.lanes())); *ret = LoadNode::make(t, args[1], args[2], const_true(t.lanes()));
} else { } else {
*ret = Load::make(t, args[1], args[2], args[3]); *ret = LoadNode::make(t, args[1], args[2], args[3]);
} }
}); });
...@@ -95,14 +95,14 @@ TVM_REGISTER_GLOBAL("make.Store") ...@@ -95,14 +95,14 @@ TVM_REGISTER_GLOBAL("make.Store")
.set_body([](TVMArgs args, TVMRetValue *ret) { .set_body([](TVMArgs args, TVMRetValue *ret) {
Expr value = args[1]; Expr value = args[1];
if (args.size() == 3) { if (args.size() == 3) {
*ret = Store::make(args[0], value, args[2], const_true(value.dtype().lanes())); *ret = StoreNode::make(args[0], value, args[2], const_true(value.dtype().lanes()));
} else { } else {
*ret = Store::make(args[0], value, args[2], args[3]); *ret = StoreNode::make(args[0], value, args[2], args[3]);
} }
}); });
TVM_REGISTER_GLOBAL("make.Realize") TVM_REGISTER_GLOBAL("make.Realize")
.set_body_typed(Realize::make); .set_body_typed(RealizeNode::make);
TVM_REGISTER_GLOBAL("make.Call") TVM_REGISTER_GLOBAL("make.Call")
.set_body_typed([]( .set_body_typed([](
...@@ -110,10 +110,10 @@ TVM_REGISTER_GLOBAL("make.Call") ...@@ -110,10 +110,10 @@ TVM_REGISTER_GLOBAL("make.Call")
Array<Expr> args, int call_type, Array<Expr> args, int call_type,
FunctionRef func, int value_index FunctionRef func, int value_index
) { ) {
return Call::make(type, return CallNode::make(type,
name, name,
args, args,
static_cast<Call::CallType>(call_type), static_cast<CallNode::CallType>(call_type),
func, func,
value_index); value_index);
}); });
...@@ -122,9 +122,10 @@ TVM_REGISTER_GLOBAL("make.CommReducer") ...@@ -122,9 +122,10 @@ TVM_REGISTER_GLOBAL("make.CommReducer")
.set_body_typed(CommReducerNode::make); .set_body_typed(CommReducerNode::make);
// make from two arguments // make from two arguments
#define REGISTER_MAKE(Node) \ #define REGISTER_MAKE(NodeName) \
TVM_REGISTER_GLOBAL("make."#Node) \ TVM_REGISTER_GLOBAL("make."#NodeName) \
.set_body_typed(Node::make); \ .set_body_typed(NodeName ## Node::make); \
REGISTER_MAKE(Reduce); REGISTER_MAKE(Reduce);
REGISTER_MAKE(AttrStmt); REGISTER_MAKE(AttrStmt);
...@@ -174,7 +175,7 @@ TVM_REGISTER_GLOBAL("make.Allocate") ...@@ -174,7 +175,7 @@ TVM_REGISTER_GLOBAL("make.Allocate")
.set_body_typed([]( .set_body_typed([](
VarExpr buffer_var, DataType type, Array<Expr> extents, Expr condition, Stmt body VarExpr buffer_var, DataType type, Array<Expr> extents, Expr condition, Stmt body
){ ){
return Allocate::make(buffer_var, type, extents, condition, body); return AllocateNode::make(buffer_var, type, extents, condition, body);
}); });
// operator overloading, smarter than make // operator overloading, smarter than make
......
...@@ -54,7 +54,7 @@ TVM_REGISTER_GLOBAL("_const") ...@@ -54,7 +54,7 @@ TVM_REGISTER_GLOBAL("_const")
}); });
TVM_REGISTER_GLOBAL("_str") TVM_REGISTER_GLOBAL("_str")
.set_body_typed(ir::StringImm::make); .set_body_typed(ir::StringImmNode::make);
TVM_REGISTER_GLOBAL("_Array") TVM_REGISTER_GLOBAL("_Array")
...@@ -198,7 +198,7 @@ TVM_REGISTER_GLOBAL("_MapItems") ...@@ -198,7 +198,7 @@ TVM_REGISTER_GLOBAL("_MapItems")
auto* n = static_cast<const StrMapNode*>(ptr); auto* n = static_cast<const StrMapNode*>(ptr);
auto rkvs = make_object<ArrayNode>(); auto rkvs = make_object<ArrayNode>();
for (const auto& kv : n->data) { for (const auto& kv : n->data) {
rkvs->data.push_back(ir::StringImm::make(kv.first)); rkvs->data.push_back(ir::StringImmNode::make(kv.first));
rkvs->data.push_back(kv.second); rkvs->data.push_back(kv.second);
} }
*ret = Array<ObjectRef>(rkvs); *ret = Array<ObjectRef>(rkvs);
......
...@@ -78,7 +78,7 @@ void ConstraintContext::ExitWithScope() { ...@@ -78,7 +78,7 @@ void ConstraintContext::ExitWithScope() {
} }
bool Analyzer::CanProveGreaterEqual(const Expr& expr, int64_t lower_bound) { bool Analyzer::CanProveGreaterEqual(const Expr& expr, int64_t lower_bound) {
if (const auto* ptr = expr.as<ir::IntImm>()) { if (const auto* ptr = expr.as<ir::IntImmNode>()) {
return ptr->value >= lower_bound; return ptr->value >= lower_bound;
} }
auto bd = this->const_int_bound(this->rewrite_simplify(expr)); auto bd = this->const_int_bound(this->rewrite_simplify(expr));
...@@ -87,15 +87,15 @@ bool Analyzer::CanProveGreaterEqual(const Expr& expr, int64_t lower_bound) { ...@@ -87,15 +87,15 @@ bool Analyzer::CanProveGreaterEqual(const Expr& expr, int64_t lower_bound) {
} }
bool Analyzer::CanProve(const Expr& expr) { bool Analyzer::CanProve(const Expr& expr) {
if (const auto* ptr = expr.as<ir::UIntImm>()) { if (const auto* ptr = expr.as<ir::UIntImmNode>()) {
return ptr->value != 0; return ptr->value != 0;
} }
auto res = this->rewrite_simplify(expr); auto res = this->rewrite_simplify(expr);
if (const auto* ptr = res.as<ir::UIntImm>()) { if (const auto* ptr = res.as<ir::UIntImmNode>()) {
return ptr->value != 0; return ptr->value != 0;
} }
res = this->canonical_simplify(expr); res = this->canonical_simplify(expr);
if (const auto* ptr = res.as<ir::UIntImm>()) { if (const auto* ptr = res.as<ir::UIntImmNode>()) {
return ptr->value != 0; return ptr->value != 0;
} }
return false; return false;
......
...@@ -78,8 +78,8 @@ class BoundDeducer: public ExprVisitor { ...@@ -78,8 +78,8 @@ class BoundDeducer: public ExprVisitor {
friend class BoundDeduceInputChecker; friend class BoundDeduceInputChecker;
friend class Converter; friend class Converter;
BoundDeducer(Expr target, Expr expr, BoundDeducer(Expr target, Expr expr,
const std::unordered_map<const Variable*, IntSet>& hint_map, const std::unordered_map<const VarNode*, IntSet>& hint_map,
const std::unordered_map<const Variable*, IntSet>& relax_map) const std::unordered_map<const VarNode*, IntSet>& relax_map)
: target_(target), expr_(expr), hint_map_(hint_map), relax_map_(relax_map) {} : target_(target), expr_(expr), hint_map_(hint_map), relax_map_(relax_map) {}
void Deduce(); void Deduce();
...@@ -94,29 +94,29 @@ class BoundDeducer: public ExprVisitor { ...@@ -94,29 +94,29 @@ class BoundDeducer: public ExprVisitor {
} }
} }
void VisitExpr_(const LT* op) final { void VisitExpr_(const LTNode* op) final {
LOG(FATAL) << "unable to deduce due to multiple comparison operator"; LOG(FATAL) << "unable to deduce due to multiple comparison operator";
} }
void VisitExpr_(const LE* op) final { void VisitExpr_(const LENode* op) final {
LOG(FATAL) << "unable to deduce due to multiple comparison operator"; LOG(FATAL) << "unable to deduce due to multiple comparison operator";
} }
void VisitExpr_(const GT* op) final { void VisitExpr_(const GTNode* op) final {
LOG(FATAL) << "unable to deduce due to multiple comparison operator"; LOG(FATAL) << "unable to deduce due to multiple comparison operator";
} }
void VisitExpr_(const GE* op) final { void VisitExpr_(const GENode* op) final {
LOG(FATAL) << "unable to deduce due to multiple comparison operator"; LOG(FATAL) << "unable to deduce due to multiple comparison operator";
} }
void VisitExpr_(const Add* op) final { void VisitExpr_(const AddNode* op) final {
bool left = op->a.get() == path_[iter_]; bool left = op->a.get() == path_[iter_];
result_ -= left ? op->b : op->a; result_ -= left ? op->b : op->a;
this->VisitExpr(left ? op->a : op->b); this->VisitExpr(left ? op->a : op->b);
} }
void VisitExpr_(const Sub* op) final { void VisitExpr_(const SubNode* op) final {
bool left = op->a.get() == path_[iter_]; bool left = op->a.get() == path_[iter_];
if (left) { if (left) {
result_ += op->b; result_ += op->b;
...@@ -128,7 +128,7 @@ class BoundDeducer: public ExprVisitor { ...@@ -128,7 +128,7 @@ class BoundDeducer: public ExprVisitor {
this->VisitExpr(left ? op->a : op->b); this->VisitExpr(left ? op->a : op->b);
} }
void VisitExpr_(const Mul* op) final { void VisitExpr_(const MulNode* op) final {
bool left = op->a.get() == path_[iter_]; bool left = op->a.get() == path_[iter_];
Expr operand = left ? op->b : op->a; Expr operand = left ? op->b : op->a;
Expr target_var = left ? op->a : op->b; Expr target_var = left ? op->a : op->b;
...@@ -187,8 +187,8 @@ class BoundDeducer: public ExprVisitor { ...@@ -187,8 +187,8 @@ class BoundDeducer: public ExprVisitor {
CompareOp ReverseOp(CompareOp comp_op); CompareOp ReverseOp(CompareOp comp_op);
Expr target_; Expr target_;
Expr expr_; Expr expr_;
const std::unordered_map<const Variable*, IntSet>& hint_map_; const std::unordered_map<const VarNode*, IntSet>& hint_map_;
const std::unordered_map<const Variable*, IntSet>& relax_map_; const std::unordered_map<const VarNode*, IntSet>& relax_map_;
ExprIntSetMap expr_map_; ExprIntSetMap expr_map_;
std::vector<const Object*> path_; std::vector<const Object*> path_;
size_t iter_{0}; size_t iter_{0};
...@@ -233,7 +233,7 @@ CompareOp BoundDeducer::ReverseOp(CompareOp comp_op) { ...@@ -233,7 +233,7 @@ CompareOp BoundDeducer::ReverseOp(CompareOp comp_op) {
void BoundDeducer::Transform() { void BoundDeducer::Transform() {
// We will ensure to set expr_ such that it contains target_ // We will ensure to set expr_ such that it contains target_
if (const LT* op = expr_.as<LT>()) { if (const LTNode* op = expr_.as<LTNode>()) {
if (GetPath(target_, op->a).empty()) { if (GetPath(target_, op->a).empty()) {
// a < b -> b >= a + 1 // a < b -> b >= a + 1
comp_op = kGreater; comp_op = kGreater;
...@@ -245,7 +245,7 @@ void BoundDeducer::Transform() { ...@@ -245,7 +245,7 @@ void BoundDeducer::Transform() {
expr_ = op->a; expr_ = op->a;
result_ = op->b - 1; result_ = op->b - 1;
} }
} else if (const LE* op = expr_.as<LE>()) { } else if (const LENode* op = expr_.as<LENode>()) {
if (GetPath(target_, op->a).empty()) { if (GetPath(target_, op->a).empty()) {
// a <= b -> b >= a // a <= b -> b >= a
comp_op = kGreater; comp_op = kGreater;
...@@ -256,7 +256,7 @@ void BoundDeducer::Transform() { ...@@ -256,7 +256,7 @@ void BoundDeducer::Transform() {
expr_ = op->a; expr_ = op->a;
result_ = op->b; result_ = op->b;
} }
} else if (const GT* op = expr_.as<GT>()) { } else if (const GTNode* op = expr_.as<GTNode>()) {
if (GetPath(target_, op->a).empty()) { if (GetPath(target_, op->a).empty()) {
// a > b -> b <= a - 1 // a > b -> b <= a - 1
comp_op = kLess; comp_op = kLess;
...@@ -268,7 +268,7 @@ void BoundDeducer::Transform() { ...@@ -268,7 +268,7 @@ void BoundDeducer::Transform() {
expr_ = op->a; expr_ = op->a;
result_ = op->b + 1; result_ = op->b + 1;
} }
} else if (const GE* op = expr_.as<GE>()) { } else if (const GENode* op = expr_.as<GENode>()) {
if (GetPath(target_, op->a).empty()) { if (GetPath(target_, op->a).empty()) {
// a >= b -> b <= a // a >= b -> b <= a
comp_op = kLess; comp_op = kLess;
...@@ -279,7 +279,7 @@ void BoundDeducer::Transform() { ...@@ -279,7 +279,7 @@ void BoundDeducer::Transform() {
expr_ = op->a; expr_ = op->a;
result_ = op->b; result_ = op->b;
} }
} else if (const EQ* op = expr_.as<EQ>()) { } else if (const EQNode* op = expr_.as<EQNode>()) {
comp_op = kEqual; comp_op = kEqual;
if (GetPath(target_, op->a).empty()) { if (GetPath(target_, op->a).empty()) {
// if the b == a -> a == b // if the b == a -> a == b
...@@ -330,8 +330,8 @@ void BoundDeducer::Relax() { ...@@ -330,8 +330,8 @@ void BoundDeducer::Relax() {
} }
IntSet DeduceBound(Expr v, Expr e, IntSet DeduceBound(Expr v, Expr e,
const std::unordered_map<const Variable*, IntSet>& hint_map, const std::unordered_map<const VarNode*, IntSet>& hint_map,
const std::unordered_map<const Variable*, IntSet>& relax_map) { const std::unordered_map<const VarNode*, IntSet>& relax_map) {
BoundDeducer d(v, e, hint_map, relax_map); BoundDeducer d(v, e, hint_map, relax_map);
d.Deduce(); d.Deduce();
if (!d.success_) return IntSet::nothing(); if (!d.success_) return IntSet::nothing();
...@@ -352,11 +352,11 @@ IntSet DeduceBound(Expr v, Expr e, ...@@ -352,11 +352,11 @@ IntSet DeduceBound(Expr v, Expr e,
IntSet DeduceBound(Expr v, Expr e, IntSet DeduceBound(Expr v, Expr e,
const Map<Var, IntSet>& hint_map, const Map<Var, IntSet>& hint_map,
const Map<Var, IntSet>& relax_map) { const Map<Var, IntSet>& relax_map) {
std::unordered_map<const Variable*, IntSet> hmap; std::unordered_map<const VarNode*, IntSet> hmap;
for (auto kv : hint_map) { for (auto kv : hint_map) {
hmap[kv.first.get()] = kv.second; hmap[kv.first.get()] = kv.second;
} }
std::unordered_map<const Variable*, IntSet> rmap; std::unordered_map<const VarNode*, IntSet> rmap;
for (auto kv : relax_map) { for (auto kv : relax_map) {
rmap[kv.first.get()] = kv.second; rmap[kv.first.get()] = kv.second;
} }
......
...@@ -450,14 +450,14 @@ class CanonicalSimplifier::Impl : public RewriteSimplifier::Impl { ...@@ -450,14 +450,14 @@ class CanonicalSimplifier::Impl : public RewriteSimplifier::Impl {
} }
using Rewriter::VisitExpr_; using Rewriter::VisitExpr_;
Expr VisitExpr_(const Add* op) final; Expr VisitExpr_(const AddNode* op) final;
Expr VisitExpr_(const Sub* op) final; Expr VisitExpr_(const SubNode* op) final;
Expr VisitExpr_(const Mul* op) final; Expr VisitExpr_(const MulNode* op) final;
Expr VisitExpr_(const Div* op) final; Expr VisitExpr_(const DivNode* op) final;
Expr VisitExpr_(const Mod* op) final; Expr VisitExpr_(const ModNode* op) final;
Expr VisitExpr_(const FloorDiv* op) final; Expr VisitExpr_(const FloorDivNode* op) final;
Expr VisitExpr_(const FloorMod* op) final; Expr VisitExpr_(const FloorModNode* op) final;
Expr VisitExpr_(const Reduce* op) final; Expr VisitExpr_(const ReduceNode* op) final;
private: private:
/*! /*!
...@@ -553,7 +553,7 @@ class CanonicalSimplifier::Impl : public RewriteSimplifier::Impl { ...@@ -553,7 +553,7 @@ class CanonicalSimplifier::Impl : public RewriteSimplifier::Impl {
} }
ObjectPtr<SumExprNode> n = make_object<SumExprNode>(); ObjectPtr<SumExprNode> n = make_object<SumExprNode>();
n->dtype = expr.dtype(); n->dtype = expr.dtype();
if (const auto* op = expr.as<IntImm>()) { if (const auto* op = expr.as<IntImmNode>()) {
n->base = op->value; n->base = op->value;
return SumExpr(n); return SumExpr(n);
} else { } else {
...@@ -562,11 +562,11 @@ class CanonicalSimplifier::Impl : public RewriteSimplifier::Impl { ...@@ -562,11 +562,11 @@ class CanonicalSimplifier::Impl : public RewriteSimplifier::Impl {
} }
} }
// Simplify the combiner used in reduce. // Simplify the combiner used in reduce.
Expr SimplifyReduceCombiner(const Reduce* op); Expr SimplifyReduceCombiner(const ReduceNode* op);
}; };
Expr CanonicalSimplifier::Impl:: Expr CanonicalSimplifier::Impl::
VisitExpr_(const Add* op) { VisitExpr_(const AddNode* op) {
if (!IsIndexType(op->dtype)) { if (!IsIndexType(op->dtype)) {
return Rewriter::VisitExpr_(op); return Rewriter::VisitExpr_(op);
} }
...@@ -575,13 +575,13 @@ VisitExpr_(const Add* op) { ...@@ -575,13 +575,13 @@ VisitExpr_(const Add* op) {
Expr b = this->CanonicalMutate(op->b); Expr b = this->CanonicalMutate(op->b);
// const folding // const folding
Expr const_res = TryConstFold<Add>(a, b); Expr const_res = TryConstFold<AddNode>(a, b);
if (const_res.defined()) return const_res; if (const_res.defined()) return const_res;
// canonical form simplification. // canonical form simplification.
SumExpr ret = ToSumExpr(std::move(a)); SumExpr ret = ToSumExpr(std::move(a));
if (const auto* op = b.as<IntImm>()) { if (const auto* op = b.as<IntImmNode>()) {
ret.CopyOnWrite()->AddToSelf(op->value); ret.CopyOnWrite()->AddToSelf(op->value);
} else if (const auto* op = b.as<SumExprNode>()) { } else if (const auto* op = b.as<SumExprNode>()) {
ret.CopyOnWrite()->AddToSelf(GetRef<SumExpr>(op), 1); ret.CopyOnWrite()->AddToSelf(GetRef<SumExpr>(op), 1);
...@@ -592,7 +592,7 @@ VisitExpr_(const Add* op) { ...@@ -592,7 +592,7 @@ VisitExpr_(const Add* op) {
} }
Expr CanonicalSimplifier::Impl:: Expr CanonicalSimplifier::Impl::
VisitExpr_(const Sub* op) { VisitExpr_(const SubNode* op) {
if (!IsIndexType(op->dtype)) { if (!IsIndexType(op->dtype)) {
return Rewriter::VisitExpr_(op); return Rewriter::VisitExpr_(op);
} }
...@@ -601,13 +601,13 @@ VisitExpr_(const Sub* op) { ...@@ -601,13 +601,13 @@ VisitExpr_(const Sub* op) {
Expr b = this->CanonicalMutate(op->b); Expr b = this->CanonicalMutate(op->b);
// const folding // const folding
Expr const_res = TryConstFold<Sub>(a, b); Expr const_res = TryConstFold<SubNode>(a, b);
if (const_res.defined()) return const_res; if (const_res.defined()) return const_res;
// canonical form simplification. // canonical form simplification.
SumExpr ret = ToSumExpr(std::move(a)); SumExpr ret = ToSumExpr(std::move(a));
if (const auto* op = b.as<IntImm>()) { if (const auto* op = b.as<IntImmNode>()) {
ret.CopyOnWrite()->AddToSelf(-op->value); ret.CopyOnWrite()->AddToSelf(-op->value);
} else if (const auto* op = b.as<SumExprNode>()) { } else if (const auto* op = b.as<SumExprNode>()) {
ret.CopyOnWrite()->AddToSelf(GetRef<SumExpr>(op), -1); ret.CopyOnWrite()->AddToSelf(GetRef<SumExpr>(op), -1);
...@@ -619,7 +619,7 @@ VisitExpr_(const Sub* op) { ...@@ -619,7 +619,7 @@ VisitExpr_(const Sub* op) {
Expr CanonicalSimplifier::Impl:: Expr CanonicalSimplifier::Impl::
VisitExpr_(const Mul* op) { VisitExpr_(const MulNode* op) {
if (!IsIndexType(op->dtype)) { if (!IsIndexType(op->dtype)) {
return Rewriter::VisitExpr_(op); return Rewriter::VisitExpr_(op);
} }
...@@ -628,14 +628,14 @@ VisitExpr_(const Mul* op) { ...@@ -628,14 +628,14 @@ VisitExpr_(const Mul* op) {
Expr b = this->CanonicalMutate(op->b); Expr b = this->CanonicalMutate(op->b);
// const folding // const folding
Expr const_res = TryConstFold<Mul>(a, b); Expr const_res = TryConstFold<MulNode>(a, b);
if (const_res.defined()) return const_res; if (const_res.defined()) return const_res;
// x * c // x * c
if (a.as<IntImm>()) { if (a.as<IntImmNode>()) {
std::swap(a, b); std::swap(a, b);
} }
if (const auto* bconst = b.as<IntImm>()) { if (const auto* bconst = b.as<IntImmNode>()) {
if (a.as<SumExprNode>()) { if (a.as<SumExprNode>()) {
SumExpr ret = Downcast<SumExpr>(std::move(a)); SumExpr ret = Downcast<SumExpr>(std::move(a));
ret.CopyOnWrite()->MulToSelf(bconst->value); ret.CopyOnWrite()->MulToSelf(bconst->value);
...@@ -653,7 +653,7 @@ VisitExpr_(const Mul* op) { ...@@ -653,7 +653,7 @@ VisitExpr_(const Mul* op) {
if (op->a.same_as(a) && op->b.same_as(b)) { if (op->a.same_as(a) && op->b.same_as(b)) {
return GetRef<Expr>(op); return GetRef<Expr>(op);
} else { } else {
return Mul::make(a, b); return MulNode::make(a, b);
} }
} }
...@@ -726,7 +726,7 @@ SplitDivConst(SplitExpr lhs, int64_t cval, DivMode div_mode) { ...@@ -726,7 +726,7 @@ SplitDivConst(SplitExpr lhs, int64_t cval, DivMode div_mode) {
} }
Expr CanonicalSimplifier::Impl:: Expr CanonicalSimplifier::Impl::
VisitExpr_(const Div* op) { VisitExpr_(const DivNode* op) {
if (!IsIndexType(op->dtype)) { if (!IsIndexType(op->dtype)) {
return Rewriter::VisitExpr_(op); return Rewriter::VisitExpr_(op);
} }
...@@ -735,7 +735,7 @@ VisitExpr_(const Div* op) { ...@@ -735,7 +735,7 @@ VisitExpr_(const Div* op) {
Expr b = this->CanonicalMutate(op->b); Expr b = this->CanonicalMutate(op->b);
// const folding // const folding
Expr const_res = TryConstFold<Div>(a, b); Expr const_res = TryConstFold<DivNode>(a, b);
if (const_res.defined()) return const_res; if (const_res.defined()) return const_res;
PVar<Integer> c1; PVar<Integer> c1;
// x / c1 // x / c1
...@@ -756,7 +756,7 @@ VisitExpr_(const Div* op) { ...@@ -756,7 +756,7 @@ VisitExpr_(const Div* op) {
analyzer_->CanProveGreaterEqual(extra->Normalize(), 0)) { analyzer_->CanProveGreaterEqual(extra->Normalize(), 0)) {
lhs.CopyOnWrite()->DivideBy(cval); lhs.CopyOnWrite()->DivideBy(cval);
Expr temp = Normalize(extra); Expr temp = Normalize(extra);
if (const auto* pconst = temp.as<IntImm>()) { if (const auto* pconst = temp.as<IntImmNode>()) {
lhs.CopyOnWrite()->AddToSelf(pconst->value / cval); lhs.CopyOnWrite()->AddToSelf(pconst->value / cval);
} else { } else {
// if 0 <= extra < cval, it means the extra can be eliminated. // if 0 <= extra < cval, it means the extra can be eliminated.
...@@ -782,12 +782,12 @@ VisitExpr_(const Div* op) { ...@@ -782,12 +782,12 @@ VisitExpr_(const Div* op) {
if (op->a.same_as(a) && op->b.same_as(b)) { if (op->a.same_as(a) && op->b.same_as(b)) {
return GetRef<Expr>(op); return GetRef<Expr>(op);
} else { } else {
return Div::make(a, b); return DivNode::make(a, b);
} }
} }
Expr CanonicalSimplifier::Impl:: Expr CanonicalSimplifier::Impl::
VisitExpr_(const FloorDiv* op) { VisitExpr_(const FloorDivNode* op) {
if (!IsIndexType(op->dtype)) { if (!IsIndexType(op->dtype)) {
return Rewriter::VisitExpr_(op); return Rewriter::VisitExpr_(op);
} }
...@@ -795,7 +795,7 @@ VisitExpr_(const FloorDiv* op) { ...@@ -795,7 +795,7 @@ VisitExpr_(const FloorDiv* op) {
Expr b = this->CanonicalMutate(op->b); Expr b = this->CanonicalMutate(op->b);
// const folding // const folding
Expr const_res = TryConstFold<FloorDiv>(a, b); Expr const_res = TryConstFold<FloorDivNode>(a, b);
if (const_res.defined()) return const_res; if (const_res.defined()) return const_res;
PVar<Integer> c1; PVar<Integer> c1;
// x / c1 // x / c1
...@@ -813,7 +813,7 @@ VisitExpr_(const FloorDiv* op) { ...@@ -813,7 +813,7 @@ VisitExpr_(const FloorDiv* op) {
// continue simplification. // continue simplification.
lhs.CopyOnWrite()->DivideBy(cval); lhs.CopyOnWrite()->DivideBy(cval);
Expr temp = Normalize(extra); Expr temp = Normalize(extra);
if (const auto* pconst = temp.as<IntImm>()) { if (const auto* pconst = temp.as<IntImmNode>()) {
lhs.CopyOnWrite()->AddToSelf(floordiv(pconst->value, cval)); lhs.CopyOnWrite()->AddToSelf(floordiv(pconst->value, cval));
} else { } else {
// if 0 <= extra < cval, it means the extra can be eliminated. // if 0 <= extra < cval, it means the extra can be eliminated.
...@@ -838,7 +838,7 @@ VisitExpr_(const FloorDiv* op) { ...@@ -838,7 +838,7 @@ VisitExpr_(const FloorDiv* op) {
if (op->a.same_as(a) && op->b.same_as(b)) { if (op->a.same_as(a) && op->b.same_as(b)) {
return GetRef<Expr>(op); return GetRef<Expr>(op);
} else { } else {
return FloorDiv::make(a, b); return FloorDivNode::make(a, b);
} }
} }
...@@ -893,7 +893,7 @@ SplitModConst(SplitExpr lhs, int64_t cval, DivMode div_mode) { ...@@ -893,7 +893,7 @@ SplitModConst(SplitExpr lhs, int64_t cval, DivMode div_mode) {
} }
Expr CanonicalSimplifier::Impl:: Expr CanonicalSimplifier::Impl::
VisitExpr_(const Mod* op) { VisitExpr_(const ModNode* op) {
if (!IsIndexType(op->dtype)) { if (!IsIndexType(op->dtype)) {
return Rewriter::VisitExpr_(op); return Rewriter::VisitExpr_(op);
} }
...@@ -902,7 +902,7 @@ VisitExpr_(const Mod* op) { ...@@ -902,7 +902,7 @@ VisitExpr_(const Mod* op) {
Expr b = this->CanonicalMutate(op->b); Expr b = this->CanonicalMutate(op->b);
// const folding // const folding
Expr const_res = TryConstFold<Mod>(a, b); Expr const_res = TryConstFold<ModNode>(a, b);
if (const_res.defined()) return const_res; if (const_res.defined()) return const_res;
PVar<Integer> c1; PVar<Integer> c1;
...@@ -919,7 +919,7 @@ VisitExpr_(const Mod* op) { ...@@ -919,7 +919,7 @@ VisitExpr_(const Mod* op) {
if (analyzer_->CanProveGreaterEqual(lhs->Normalize(), 0) && if (analyzer_->CanProveGreaterEqual(lhs->Normalize(), 0) &&
analyzer_->CanProveGreaterEqual(extra->Normalize(), 0)) { analyzer_->CanProveGreaterEqual(extra->Normalize(), 0)) {
Expr temp = Normalize(extra); Expr temp = Normalize(extra);
if (temp.as<IntImm>()) { if (temp.as<IntImmNode>()) {
return truncmod(temp, c1.Eval()); return truncmod(temp, c1.Eval());
} else { } else {
// If temp < cval && temp >=0 then can remove the mod. // If temp < cval && temp >=0 then can remove the mod.
...@@ -958,12 +958,12 @@ VisitExpr_(const Mod* op) { ...@@ -958,12 +958,12 @@ VisitExpr_(const Mod* op) {
if (op->a.same_as(a) && op->b.same_as(b)) { if (op->a.same_as(a) && op->b.same_as(b)) {
return GetRef<Expr>(op); return GetRef<Expr>(op);
} else { } else {
return Mod::make(a, b); return ModNode::make(a, b);
} }
} }
Expr CanonicalSimplifier::Impl:: Expr CanonicalSimplifier::Impl::
VisitExpr_(const FloorMod* op) { VisitExpr_(const FloorModNode* op) {
if (!IsIndexType(op->dtype)) { if (!IsIndexType(op->dtype)) {
return Rewriter::VisitExpr_(op); return Rewriter::VisitExpr_(op);
} }
...@@ -972,7 +972,7 @@ VisitExpr_(const FloorMod* op) { ...@@ -972,7 +972,7 @@ VisitExpr_(const FloorMod* op) {
Expr b = this->CanonicalMutate(op->b); Expr b = this->CanonicalMutate(op->b);
// const folding // const folding
Expr const_res = TryConstFold<FloorMod>(a, b); Expr const_res = TryConstFold<FloorModNode>(a, b);
if (const_res.defined()) return const_res; if (const_res.defined()) return const_res;
PVar<Integer> c1; PVar<Integer> c1;
...@@ -983,7 +983,7 @@ VisitExpr_(const FloorMod* op) { ...@@ -983,7 +983,7 @@ VisitExpr_(const FloorMod* op) {
SumExpr lhs, extra; SumExpr lhs, extra;
SeparateDivisibleParts(psum, cval, &lhs, &extra); SeparateDivisibleParts(psum, cval, &lhs, &extra);
Expr temp = Normalize(extra); Expr temp = Normalize(extra);
if (temp.as<IntImm>()) { if (temp.as<IntImmNode>()) {
return floormod(temp, c1.Eval()); return floormod(temp, c1.Eval());
} else { } else {
// If temp < cval && temp >=0 then can remove the mod. // If temp < cval && temp >=0 then can remove the mod.
...@@ -1018,13 +1018,13 @@ VisitExpr_(const FloorMod* op) { ...@@ -1018,13 +1018,13 @@ VisitExpr_(const FloorMod* op) {
if (op->a.same_as(a) && op->b.same_as(b)) { if (op->a.same_as(a) && op->b.same_as(b)) {
return GetRef<Expr>(op); return GetRef<Expr>(op);
} else { } else {
return FloorMod::make(a, b); return FloorModNode::make(a, b);
} }
} }
// Simplify reduce expression. // Simplify reduce expression.
Expr CanonicalSimplifier::Impl:: Expr CanonicalSimplifier::Impl::
SimplifyReduceCombiner(const Reduce* op) { SimplifyReduceCombiner(const ReduceNode* op) {
// First simplify the results // First simplify the results
Array<Expr> simplified_result; Array<Expr> simplified_result;
for (const auto& res : op->combiner->result) { for (const auto& res : op->combiner->result) {
...@@ -1089,15 +1089,15 @@ SimplifyReduceCombiner(const Reduce* op) { ...@@ -1089,15 +1089,15 @@ SimplifyReduceCombiner(const Reduce* op) {
CommReducer new_combiner = CommReducer new_combiner =
CommReducerNode::make(new_lhs, new_rhs, new_result, new_identity); CommReducerNode::make(new_lhs, new_rhs, new_result, new_identity);
return Reduce::make( return ReduceNode::make(
new_combiner, new_source, op->axis, op->condition, new_value_index); new_combiner, new_source, op->axis, op->condition, new_value_index);
} }
Expr CanonicalSimplifier::Impl:: Expr CanonicalSimplifier::Impl::
VisitExpr_(const Reduce* op) { VisitExpr_(const ReduceNode* op) {
// Recursively call simplification when necessary. // Recursively call simplification when necessary.
Expr ret = RewriteSimplifier::Impl::VisitExpr_(op); Expr ret = RewriteSimplifier::Impl::VisitExpr_(op);
op = ret.as<Reduce>(); op = ret.as<ReduceNode>();
// already been simplified by const reduction axis removal // already been simplified by const reduction axis removal
if (op == nullptr) return ret; if (op == nullptr) return ret;
if (op->axis.empty()) { if (op->axis.empty()) {
...@@ -1106,7 +1106,7 @@ VisitExpr_(const Reduce* op) { ...@@ -1106,7 +1106,7 @@ VisitExpr_(const Reduce* op) {
// `(*op->combiner.get())(op->combineop->identity_element, op->source)[op->value_index]` // `(*op->combiner.get())(op->combineop->identity_element, op->source)[op->value_index]`
// instead of `op->source[op->value_index]`. The former may be more difficult to simplify. // instead of `op->source[op->value_index]`. The former may be more difficult to simplify.
return this->VisitExpr( return this->VisitExpr(
Select::make(op->condition, SelectNode::make(op->condition,
op->source[op->value_index], op->source[op->value_index],
op->combiner->identity_element[op->value_index])); op->combiner->identity_element[op->value_index]));
} }
......
...@@ -77,37 +77,37 @@ inline bool GetConstInt(Expr e, int* out) { ...@@ -77,37 +77,37 @@ inline bool GetConstInt(Expr e, int* out) {
} }
template<> template<>
inline Expr Compute<ir::Add>(Expr a, Expr b) { inline Expr Compute<ir::AddNode>(Expr a, Expr b) {
return a + b; return a + b;
} }
template<> template<>
inline Expr Compute<ir::Sub>(Expr a, Expr b) { inline Expr Compute<ir::SubNode>(Expr a, Expr b) {
return a - b; return a - b;
} }
template<> template<>
inline Expr Compute<ir::Mul>(Expr a, Expr b) { inline Expr Compute<ir::MulNode>(Expr a, Expr b) {
return a * b; return a * b;
} }
template<> template<>
inline Expr Compute<ir::Div>(Expr a, Expr b) { inline Expr Compute<ir::DivNode>(Expr a, Expr b) {
return truncdiv(a, b); return truncdiv(a, b);
} }
template<> template<>
inline Expr Compute<ir::Mod>(Expr a, Expr b) { inline Expr Compute<ir::ModNode>(Expr a, Expr b) {
return truncmod(a, b); return truncmod(a, b);
} }
template<> template<>
inline Expr Compute<ir::Max>(Expr a, Expr b) { inline Expr Compute<ir::MaxNode>(Expr a, Expr b) {
return max(a, b); return max(a, b);
} }
template<> template<>
inline Expr Compute<ir::Min>(Expr a, Expr b) { inline Expr Compute<ir::MinNode>(Expr a, Expr b) {
return min(a, b); return min(a, b);
} }
......
...@@ -140,17 +140,17 @@ class ConstIntBoundAnalyzer::Impl : ...@@ -140,17 +140,17 @@ class ConstIntBoundAnalyzer::Impl :
return res; return res;
} }
Entry VisitExpr_(const Cast* op) final { Entry VisitExpr_(const CastNode* op) final {
Entry a = VisitExpr(op->value); Entry a = VisitExpr(op->value);
Entry b = Everything(op->dtype); Entry b = Everything(op->dtype);
return Intersect(a, b); return Intersect(a, b);
} }
Entry VisitExpr_(const IntImm* op) final { Entry VisitExpr_(const IntImmNode* op) final {
return MakeBound(op->value, op->value); return MakeBound(op->value, op->value);
} }
Entry VisitExpr_(const UIntImm* op) final { Entry VisitExpr_(const UIntImmNode* op) final {
if (op->value <= static_cast<uint64_t>(kPosInf)) { if (op->value <= static_cast<uint64_t>(kPosInf)) {
return MakeBound(op->value, op->value); return MakeBound(op->value, op->value);
} else { } else {
...@@ -158,7 +158,7 @@ class ConstIntBoundAnalyzer::Impl : ...@@ -158,7 +158,7 @@ class ConstIntBoundAnalyzer::Impl :
} }
} }
Entry VisitExpr_(const Add* op) final { Entry VisitExpr_(const AddNode* op) final {
Entry a = VisitExpr(op->a); Entry a = VisitExpr(op->a);
Entry b = VisitExpr(op->b); Entry b = VisitExpr(op->b);
Entry ret; Entry ret;
...@@ -167,7 +167,7 @@ class ConstIntBoundAnalyzer::Impl : ...@@ -167,7 +167,7 @@ class ConstIntBoundAnalyzer::Impl :
return ret; return ret;
} }
Entry VisitExpr_(const Sub* op) final { Entry VisitExpr_(const SubNode* op) final {
Entry a = VisitExpr(op->a); Entry a = VisitExpr(op->a);
Entry b = VisitExpr(op->b); Entry b = VisitExpr(op->b);
Entry ret; Entry ret;
...@@ -176,13 +176,13 @@ class ConstIntBoundAnalyzer::Impl : ...@@ -176,13 +176,13 @@ class ConstIntBoundAnalyzer::Impl :
return ret; return ret;
} }
Entry VisitExpr_(const Mul* op) final { Entry VisitExpr_(const MulNode* op) final {
Entry a = VisitExpr(op->a); Entry a = VisitExpr(op->a);
Entry b = VisitExpr(op->b); Entry b = VisitExpr(op->b);
return BinaryOpBoundry(a, b, InfAwareMul); return BinaryOpBoundry(a, b, InfAwareMul);
} }
Entry VisitExpr_(const Div* op) final { Entry VisitExpr_(const DivNode* op) final {
Entry a = VisitExpr(op->a); Entry a = VisitExpr(op->a);
Entry b = VisitExpr(op->b); Entry b = VisitExpr(op->b);
CHECK(!b.is_const(0)) << "divide by zero"; CHECK(!b.is_const(0)) << "divide by zero";
...@@ -192,7 +192,7 @@ class ConstIntBoundAnalyzer::Impl : ...@@ -192,7 +192,7 @@ class ConstIntBoundAnalyzer::Impl :
return BinaryOpBoundry(a, b, InfAwareDiv); return BinaryOpBoundry(a, b, InfAwareDiv);
} }
Entry VisitExpr_(const Mod* op) final { Entry VisitExpr_(const ModNode* op) final {
Entry a = VisitExpr(op->a); Entry a = VisitExpr(op->a);
Entry b = VisitExpr(op->b); Entry b = VisitExpr(op->b);
if (b.min_value > 0) { if (b.min_value > 0) {
...@@ -215,7 +215,7 @@ class ConstIntBoundAnalyzer::Impl : ...@@ -215,7 +215,7 @@ class ConstIntBoundAnalyzer::Impl :
} }
} }
Entry VisitExpr_(const FloorDiv* op) final { Entry VisitExpr_(const FloorDivNode* op) final {
Entry a = VisitExpr(op->a); Entry a = VisitExpr(op->a);
Entry b = VisitExpr(op->b); Entry b = VisitExpr(op->b);
CHECK(!b.is_const(0)) << "floordiv by zero"; CHECK(!b.is_const(0)) << "floordiv by zero";
...@@ -225,7 +225,7 @@ class ConstIntBoundAnalyzer::Impl : ...@@ -225,7 +225,7 @@ class ConstIntBoundAnalyzer::Impl :
return BinaryOpBoundry(a, b, InfAwareFloorDiv); return BinaryOpBoundry(a, b, InfAwareFloorDiv);
} }
Entry VisitExpr_(const FloorMod* op) final { Entry VisitExpr_(const FloorModNode* op) final {
Entry a = VisitExpr(op->a); Entry a = VisitExpr(op->a);
Entry b = VisitExpr(op->b); Entry b = VisitExpr(op->b);
if (b.min_value > 0) { if (b.min_value > 0) {
...@@ -246,7 +246,7 @@ class ConstIntBoundAnalyzer::Impl : ...@@ -246,7 +246,7 @@ class ConstIntBoundAnalyzer::Impl :
} }
} }
Entry VisitExpr_(const Min* op) final { Entry VisitExpr_(const MinNode* op) final {
Entry a = VisitExpr(op->a); Entry a = VisitExpr(op->a);
Entry b = VisitExpr(op->b); Entry b = VisitExpr(op->b);
Entry ret; Entry ret;
...@@ -255,7 +255,7 @@ class ConstIntBoundAnalyzer::Impl : ...@@ -255,7 +255,7 @@ class ConstIntBoundAnalyzer::Impl :
return ret; return ret;
} }
Entry VisitExpr_(const Max* op) final { Entry VisitExpr_(const MaxNode* op) final {
Entry a = VisitExpr(op->a); Entry a = VisitExpr(op->a);
Entry b = VisitExpr(op->b); Entry b = VisitExpr(op->b);
Entry ret; Entry ret;
...@@ -264,25 +264,25 @@ class ConstIntBoundAnalyzer::Impl : ...@@ -264,25 +264,25 @@ class ConstIntBoundAnalyzer::Impl :
return ret; return ret;
} }
Entry VisitExpr_(const Select* op) final { Entry VisitExpr_(const SelectNode* op) final {
Entry a = VisitExpr(op->true_value); Entry a = VisitExpr(op->true_value);
Entry b = VisitExpr(op->false_value); Entry b = VisitExpr(op->false_value);
return Union(a, b); return Union(a, b);
} }
Entry VisitExpr_(const Call* op) final { Entry VisitExpr_(const CallNode* op) final {
// only special handle >> and & which can be // only special handle >> and & which can be
// used for index calculation. // used for index calculation.
if (op->is_intrinsic(Call::shift_right)) { if (op->is_intrinsic(CallNode::shift_right)) {
return VisitRightShift(op); return VisitRightShift(op);
} else if (op->is_intrinsic(Call::bitwise_and)) { } else if (op->is_intrinsic(CallNode::bitwise_and)) {
return VisitBitwiseAnd(op); return VisitBitwiseAnd(op);
} else { } else {
return Everything(op->dtype); return Everything(op->dtype);
} }
} }
Entry VisitExpr_(const Variable* op) final { Entry VisitExpr_(const VarNode* op) final {
Var v = GetRef<Var>(op); Var v = GetRef<Var>(op);
auto it = var_map_.find(v); auto it = var_map_.find(v);
if (it != var_map_.end()) { if (it != var_map_.end()) {
...@@ -292,13 +292,13 @@ class ConstIntBoundAnalyzer::Impl : ...@@ -292,13 +292,13 @@ class ConstIntBoundAnalyzer::Impl :
} }
} }
Entry VisitRightShift(const Call* op) { Entry VisitRightShift(const CallNode* op) {
Entry a = VisitExpr(op->args[0]); Entry a = VisitExpr(op->args[0]);
Entry b = VisitExpr(op->args[1]); Entry b = VisitExpr(op->args[1]);
return BinaryOpBoundry(a, b, InfAwareRightShift); return BinaryOpBoundry(a, b, InfAwareRightShift);
} }
Entry VisitBitwiseAnd(const Call* op) { Entry VisitBitwiseAnd(const CallNode* op) {
Entry a = VisitExpr(op->args[0]); Entry a = VisitExpr(op->args[0]);
Entry b = VisitExpr(op->args[1]); Entry b = VisitExpr(op->args[1]);
// handle positive index case. // handle positive index case.
...@@ -375,7 +375,7 @@ class ConstIntBoundAnalyzer::Impl : ...@@ -375,7 +375,7 @@ class ConstIntBoundAnalyzer::Impl :
return kNegInf; return kNegInf;
} }
if (y == kPosInf || y == kNegInf) return y; if (y == kPosInf || y == kNegInf) return y;
if (WillOverflow<Add>(x, y, kNegInf, kPosInf)) { if (WillOverflow<AddNode>(x, y, kNegInf, kPosInf)) {
if (x > 0) return kPosInf; if (x > 0) return kPosInf;
return kNegInf; return kNegInf;
} }
...@@ -388,7 +388,7 @@ class ConstIntBoundAnalyzer::Impl : ...@@ -388,7 +388,7 @@ class ConstIntBoundAnalyzer::Impl :
* \return the result. * \return the result.
*/ */
static int64_t InfAwareMul(int64_t x, int64_t y) { static int64_t InfAwareMul(int64_t x, int64_t y) {
if (!WillOverflow<Mul>(x, y, kNegInf, kPosInf)) return x * y; if (!WillOverflow<MulNode>(x, y, kNegInf, kPosInf)) return x * y;
if ((x > 0 && y > 0) || (x < 0 && y < 0)) return kPosInf; if ((x > 0 && y > 0) || (x < 0 && y < 0)) return kPosInf;
return kNegInf; return kNegInf;
} }
......
...@@ -60,7 +60,7 @@ class LinearEqDetector ...@@ -60,7 +60,7 @@ class LinearEqDetector
return true; return true;
} }
LinearEqEntry VisitExpr_(const Add* op, const Expr& e) final { LinearEqEntry VisitExpr_(const AddNode* op, const Expr& e) final {
if (fail_) return LinearEqEntry(); if (fail_) return LinearEqEntry();
LinearEqEntry a = VisitExpr(op->a, op->a); LinearEqEntry a = VisitExpr(op->a, op->a);
LinearEqEntry b = VisitExpr(op->b, op->b); LinearEqEntry b = VisitExpr(op->b, op->b);
...@@ -70,7 +70,7 @@ class LinearEqDetector ...@@ -70,7 +70,7 @@ class LinearEqDetector
return ret; return ret;
} }
LinearEqEntry VisitExpr_(const Sub* op, const Expr& e) final { LinearEqEntry VisitExpr_(const SubNode* op, const Expr& e) final {
if (fail_) return LinearEqEntry(); if (fail_) return LinearEqEntry();
LinearEqEntry a = VisitExpr(op->a, op->a); LinearEqEntry a = VisitExpr(op->a, op->a);
LinearEqEntry b = VisitExpr(op->b, op->b); LinearEqEntry b = VisitExpr(op->b, op->b);
...@@ -80,7 +80,7 @@ class LinearEqDetector ...@@ -80,7 +80,7 @@ class LinearEqDetector
return ret; return ret;
} }
LinearEqEntry VisitExpr_(const Mul* op, const Expr& e) final { LinearEqEntry VisitExpr_(const MulNode* op, const Expr& e) final {
if (fail_) return LinearEqEntry(); if (fail_) return LinearEqEntry();
LinearEqEntry a = VisitExpr(op->a, op->a); LinearEqEntry a = VisitExpr(op->a, op->a);
LinearEqEntry b = VisitExpr(op->b, op->b); LinearEqEntry b = VisitExpr(op->b, op->b);
...@@ -96,7 +96,7 @@ class LinearEqDetector ...@@ -96,7 +96,7 @@ class LinearEqDetector
ret.coeff = MulCombine(a.base, b.coeff); ret.coeff = MulCombine(a.base, b.coeff);
return ret; return ret;
} }
LinearEqEntry VisitExpr_(const Variable* op, const Expr& e) final { LinearEqEntry VisitExpr_(const VarNode* op, const Expr& e) final {
LinearEqEntry ret; LinearEqEntry ret;
if (op == var_.get()) { if (op == var_.get()) {
ret.coeff = make_const(op->dtype, 1); ret.coeff = make_const(op->dtype, 1);
...@@ -152,7 +152,7 @@ Array<Expr> DetectLinearEquation(const Expr& e, const Array<Var>& vars) { ...@@ -152,7 +152,7 @@ Array<Expr> DetectLinearEquation(const Expr& e, const Array<Var>& vars) {
base = std::move(ret.base); base = std::move(ret.base);
} }
std::unordered_set<const Variable*> vset; std::unordered_set<const VarNode*> vset;
for (size_t i = vars.size(); i > 1; --i) { for (size_t i = vars.size(); i > 1; --i) {
vset.insert(vars[i - 1].get()); vset.insert(vars[i - 1].get());
// The previous coeff contains the variable // The previous coeff contains the variable
...@@ -167,11 +167,11 @@ Array<Expr> DetectLinearEquation(const Expr& e, const Array<Var>& vars) { ...@@ -167,11 +167,11 @@ Array<Expr> DetectLinearEquation(const Expr& e, const Array<Var>& vars) {
// Detect clip condition as min max value // Detect clip condition as min max value
bool DetectClipBound( bool DetectClipBound(
const Expr& cond, const Expr& cond,
std::unordered_map<const Variable*, IntervalEntry>* bmap) { std::unordered_map<const VarNode*, IntervalEntry>* bmap) {
int flag = 0; int flag = 0;
Var var; Var var;
auto fvisit = [&bmap, &flag, &var](const ObjectRef& n) { auto fvisit = [&bmap, &flag, &var](const ObjectRef& n) {
if (const Variable* v = n.as<Variable>()) { if (const VarNode* v = n.as<VarNode>()) {
if (bmap->count(v)) { if (bmap->count(v)) {
if (flag == 0) { if (flag == 0) {
var = Downcast<Var>(n); var = Downcast<Var>(n);
...@@ -188,16 +188,16 @@ bool DetectClipBound( ...@@ -188,16 +188,16 @@ bool DetectClipBound(
if (flag != 1) return false; if (flag != 1) return false;
// canonical form: exp >= 0 // canonical form: exp >= 0
Expr canonical; Expr canonical;
if (const LT* op = cond.as<LT>()) { if (const LTNode* op = cond.as<LTNode>()) {
if (!op->a.dtype().is_int()) return false; if (!op->a.dtype().is_int()) return false;
canonical = op->b - op->a - make_const(op->a.dtype(), 1); canonical = op->b - op->a - make_const(op->a.dtype(), 1);
} else if (const LE* op = cond.as<LE>()) { } else if (const LENode* op = cond.as<LENode>()) {
if (!op->a.dtype().is_int()) return false; if (!op->a.dtype().is_int()) return false;
canonical = op->b - op->a; canonical = op->b - op->a;
} else if (const GT* op = cond.as<GT>()) { } else if (const GTNode* op = cond.as<GTNode>()) {
if (!op->a.dtype().is_int()) return false; if (!op->a.dtype().is_int()) return false;
canonical = op->a - op->b - make_const(op->a.dtype(), 1); canonical = op->a - op->b - make_const(op->a.dtype(), 1);
} else if (const GE* op = cond.as<GE>()) { } else if (const GENode* op = cond.as<GENode>()) {
if (!op->a.dtype().is_int()) return false; if (!op->a.dtype().is_int()) return false;
canonical = op->a - op->b; canonical = op->a - op->b;
} else { } else {
...@@ -210,7 +210,7 @@ bool DetectClipBound( ...@@ -210,7 +210,7 @@ bool DetectClipBound(
if (is_const_int(ret.coeff, 1)) { if (is_const_int(ret.coeff, 1)) {
// var + shift >=0 -> var >= -shift // var + shift >=0 -> var >= -shift
if (p.min_value.defined()) { if (p.min_value.defined()) {
p.min_value = ir::Max::make(p.min_value, -ret.base); p.min_value = ir::MaxNode::make(p.min_value, -ret.base);
} else { } else {
p.min_value = -ret.base; p.min_value = -ret.base;
} }
...@@ -219,7 +219,7 @@ bool DetectClipBound( ...@@ -219,7 +219,7 @@ bool DetectClipBound(
if (is_const_int(ret.coeff, -1)) { if (is_const_int(ret.coeff, -1)) {
// -var + shift >=0 -> var <= shift // -var + shift >=0 -> var <= shift
if (p.max_value.defined()) { if (p.max_value.defined()) {
p.max_value = ir::Min::make(p.max_value, ret.base); p.max_value = ir::MinNode::make(p.max_value, ret.base);
} else { } else {
p.max_value = ret.base; p.max_value = ret.base;
} }
...@@ -243,8 +243,8 @@ void SplitCommExpr(const Expr& e, std::vector<Expr>* ret) { ...@@ -243,8 +243,8 @@ void SplitCommExpr(const Expr& e, std::vector<Expr>* ret) {
// e must be connected by and. // e must be connected by and.
Array<Expr> DetectClipBound(const Expr& e, const Array<Var>& vars) { Array<Expr> DetectClipBound(const Expr& e, const Array<Var>& vars) {
std::vector<Expr> splits; std::vector<Expr> splits;
SplitCommExpr<ir::And>(e, &splits); SplitCommExpr<ir::AndNode>(e, &splits);
std::unordered_map<const Variable*, IntervalEntry> rmap; std::unordered_map<const VarNode*, IntervalEntry> rmap;
for (Var v : vars) { for (Var v : vars) {
rmap[v.get()] = IntervalEntry(); rmap[v.get()] = IntervalEntry();
} }
......
...@@ -53,15 +53,15 @@ class FuncTouchedDomain final : public StmtExprVisitor { ...@@ -53,15 +53,15 @@ class FuncTouchedDomain final : public StmtExprVisitor {
return ret; return ret;
} }
void VisitStmt_(const For *op) final { void VisitStmt_(const ForNode *op) final {
const Variable* var = op->loop_var.get(); const VarNode* var = op->loop_var.get();
dom_map_[var] = IntSet::range( dom_map_[var] = IntSet::range(
Range::make_by_min_extent(op->min, op->extent)); Range::make_by_min_extent(op->min, op->extent));
StmtExprVisitor::VisitStmt_(op); StmtExprVisitor::VisitStmt_(op);
dom_map_.erase(var); dom_map_.erase(var);
} }
void VisitStmt_(const LetStmt* op) final { void VisitStmt_(const LetStmtNode* op) final {
dom_map_[op->var.get()] = dom_map_[op->var.get()] =
arith::EvalSet(op->value, dom_map_); arith::EvalSet(op->value, dom_map_);
StmtExprVisitor::VisitStmt_(op); StmtExprVisitor::VisitStmt_(op);
...@@ -69,11 +69,11 @@ class FuncTouchedDomain final : public StmtExprVisitor { ...@@ -69,11 +69,11 @@ class FuncTouchedDomain final : public StmtExprVisitor {
} }
/* TODO: Thread extent unitest not generated.*/ /* TODO: Thread extent unitest not generated.*/
void VisitStmt_(const AttrStmt* op) final { void VisitStmt_(const AttrStmtNode* op) final {
if (op->attr_key == attr::thread_extent) { if (op->attr_key == attr::thread_extent) {
const IterVarNode* thread_axis = op->node.as<IterVarNode>(); const IterVarNode* thread_axis = op->node.as<IterVarNode>();
CHECK(thread_axis); CHECK(thread_axis);
const Variable* var = thread_axis->var.get(); const VarNode* var = thread_axis->var.get();
dom_map_[var] = IntSet::range(Range(make_zero(op->value.dtype()), op->value)); dom_map_[var] = IntSet::range(Range(make_zero(op->value.dtype()), op->value));
StmtExprVisitor::VisitStmt_(op); StmtExprVisitor::VisitStmt_(op);
dom_map_.erase(var); dom_map_.erase(var);
...@@ -82,7 +82,7 @@ class FuncTouchedDomain final : public StmtExprVisitor { ...@@ -82,7 +82,7 @@ class FuncTouchedDomain final : public StmtExprVisitor {
} }
} }
void VisitExpr_(const Call* op) final { void VisitExpr_(const CallNode* op) final {
if (consider_calls_ && tensor_->op.same_as(op->func) if (consider_calls_ && tensor_->op.same_as(op->func)
&& tensor_->value_index == op->value_index) { && tensor_->value_index == op->value_index) {
Touch(op->args); Touch(op->args);
...@@ -90,7 +90,7 @@ class FuncTouchedDomain final : public StmtExprVisitor { ...@@ -90,7 +90,7 @@ class FuncTouchedDomain final : public StmtExprVisitor {
StmtExprVisitor::VisitExpr_(op); StmtExprVisitor::VisitExpr_(op);
} }
void VisitStmt_(const Provide* op) final { void VisitStmt_(const ProvideNode* op) final {
if (consider_provides_ && tensor_->op.same_as(op->func) if (consider_provides_ && tensor_->op.same_as(op->func)
&& tensor_->value_index == op->value_index) { && tensor_->value_index == op->value_index) {
Touch(op->args); Touch(op->args);
...@@ -111,7 +111,7 @@ class FuncTouchedDomain final : public StmtExprVisitor { ...@@ -111,7 +111,7 @@ class FuncTouchedDomain final : public StmtExprVisitor {
const Tensor &tensor_; const Tensor &tensor_;
bool consider_calls_, consider_provides_; bool consider_calls_, consider_provides_;
std::vector<std::vector<IntSet> > bounds_; std::vector<std::vector<IntSet> > bounds_;
std::unordered_map<const Variable*, IntSet> dom_map_; std::unordered_map<const VarNode*, IntSet> dom_map_;
}; };
Domain DomainTouched(Stmt stmt, const Tensor &tensor, bool consider_calls, bool consider_provides) { Domain DomainTouched(Stmt stmt, const Tensor &tensor, bool consider_calls, bool consider_provides) {
......
...@@ -47,30 +47,30 @@ inline bool WillOverflow(int64_t x, ...@@ -47,30 +47,30 @@ inline bool WillOverflow(int64_t x,
} }
template<> template<>
inline bool WillOverflow<ir::Add>(int64_t x, inline bool WillOverflow<ir::AddNode>(int64_t x,
int64_t y, int64_t y,
int64_t min_value, int64_t min_value,
int64_t max_value) { int64_t max_value) {
if ((y > 0) && (x > max_value - y)) return true; if ((y > 0) && (x > max_value - y)) return true;
if ((y < 0) && (x < min_value - y)) return true; if ((y < 0) && (x < min_value - y)) return true;
return false; return false;
} }
template<> template<>
inline bool WillOverflow<ir::Sub>(int64_t x, inline bool WillOverflow<ir::SubNode>(int64_t x,
int64_t y, int64_t y,
int64_t min_value, int64_t min_value,
int64_t max_value) { int64_t max_value) {
if ((y > 0) && (x < min_value + y)) return true; if ((y > 0) && (x < min_value + y)) return true;
if ((y < 0) && (x > max_value + y)) return true; if ((y < 0) && (x > max_value + y)) return true;
return false; return false;
} }
template<> template<>
inline bool WillOverflow<ir::Mul>(int64_t x, inline bool WillOverflow<ir::MulNode>(int64_t x,
int64_t y, int64_t y,
int64_t min_value, int64_t min_value,
int64_t max_value) { int64_t max_value) {
if (y == 0) return false; if (y == 0) return false;
if (y > 0) { if (y > 0) {
if (x < min_value / y) return true; if (x < min_value / y) return true;
...@@ -84,10 +84,10 @@ inline bool WillOverflow<ir::Mul>(int64_t x, ...@@ -84,10 +84,10 @@ inline bool WillOverflow<ir::Mul>(int64_t x,
} }
template<> template<>
inline bool WillOverflow<ir::Mod>(int64_t x, inline bool WillOverflow<ir::ModNode>(int64_t x,
int64_t y, int64_t y,
int64_t min_value, int64_t min_value,
int64_t max_value) { int64_t max_value) {
return y == 0; return y == 0;
} }
......
...@@ -30,14 +30,14 @@ namespace arith { ...@@ -30,14 +30,14 @@ namespace arith {
using namespace ir; using namespace ir;
Stmt IRMutatorWithAnalyzer:: Stmt IRMutatorWithAnalyzer::
VisitStmt_(const For* op) { VisitStmt_(const ForNode* op) {
analyzer_->Bind(op->loop_var, analyzer_->Bind(op->loop_var,
Range::make_by_min_extent(op->min, op->extent)); Range::make_by_min_extent(op->min, op->extent));
return StmtExprMutator::VisitStmt_(op); return StmtExprMutator::VisitStmt_(op);
} }
Stmt IRMutatorWithAnalyzer:: Stmt IRMutatorWithAnalyzer::
VisitStmt_(const LetStmt* op) { VisitStmt_(const LetStmtNode* op) {
Expr value = this->VisitExpr(op->value); Expr value = this->VisitExpr(op->value);
if (!ir::HasSideEffect(value)) { if (!ir::HasSideEffect(value)) {
analyzer_->Bind(op->var, value); analyzer_->Bind(op->var, value);
...@@ -57,7 +57,7 @@ VisitStmt_(const LetStmt* op) { ...@@ -57,7 +57,7 @@ VisitStmt_(const LetStmt* op) {
} }
Stmt IRMutatorWithAnalyzer:: Stmt IRMutatorWithAnalyzer::
VisitStmt_(const IfThenElse* op) { VisitStmt_(const IfThenElseNode* op) {
Expr condition = this->VisitExpr(op->condition); Expr condition = this->VisitExpr(op->condition);
Stmt then_case, else_case; Stmt then_case, else_case;
{ {
...@@ -66,7 +66,7 @@ VisitStmt_(const IfThenElse* op) { ...@@ -66,7 +66,7 @@ VisitStmt_(const IfThenElse* op) {
} }
if (op->else_case.defined()) { if (op->else_case.defined()) {
With<ConstraintContext> ctx(analyzer_, With<ConstraintContext> ctx(analyzer_,
analyzer_->rewrite_simplify(Not::make(condition))); analyzer_->rewrite_simplify(NotNode::make(condition)));
else_case = this->VisitStmt(op->else_case); else_case = this->VisitStmt(op->else_case);
} }
if (is_one(condition)) return then_case; if (is_one(condition)) return then_case;
...@@ -74,7 +74,7 @@ VisitStmt_(const IfThenElse* op) { ...@@ -74,7 +74,7 @@ VisitStmt_(const IfThenElse* op) {
if (else_case.defined()) { if (else_case.defined()) {
return else_case; return else_case;
} }
return Evaluate::make(0); return EvaluateNode::make(0);
} }
if (condition.same_as(op->condition) && if (condition.same_as(op->condition) &&
...@@ -91,7 +91,7 @@ VisitStmt_(const IfThenElse* op) { ...@@ -91,7 +91,7 @@ VisitStmt_(const IfThenElse* op) {
} }
Stmt IRMutatorWithAnalyzer:: Stmt IRMutatorWithAnalyzer::
VisitStmt_(const AttrStmt* op) { VisitStmt_(const AttrStmtNode* op) {
if (op->attr_key == attr::thread_extent || if (op->attr_key == attr::thread_extent ||
op->attr_key == attr::virtual_thread) { op->attr_key == attr::virtual_thread) {
IterVar iv = Downcast<IterVar>(op->node); IterVar iv = Downcast<IterVar>(op->node);
...@@ -106,7 +106,7 @@ VisitStmt_(const AttrStmt* op) { ...@@ -106,7 +106,7 @@ VisitStmt_(const AttrStmt* op) {
} }
Stmt IRMutatorWithAnalyzer:: Stmt IRMutatorWithAnalyzer::
VisitStmt_(const AssertStmt* op) { VisitStmt_(const AssertStmtNode* op) {
Expr condition = this->VisitExpr(op->condition); Expr condition = this->VisitExpr(op->condition);
Expr message = this->VisitExpr(op->message); Expr message = this->VisitExpr(op->message);
With<ConstraintContext> ctx(analyzer_, condition); With<ConstraintContext> ctx(analyzer_, condition);
...@@ -126,7 +126,7 @@ VisitStmt_(const AssertStmt* op) { ...@@ -126,7 +126,7 @@ VisitStmt_(const AssertStmt* op) {
} }
Expr IRMutatorWithAnalyzer:: Expr IRMutatorWithAnalyzer::
VisitExpr_(const Call* op) { VisitExpr_(const CallNode* op) {
// add condition context to if_then_else // add condition context to if_then_else
if (op->is_intrinsic(ir::intrinsic::tvm_if_then_else)) { if (op->is_intrinsic(ir::intrinsic::tvm_if_then_else)) {
Expr cond = this->VisitExpr(op->args[0]); Expr cond = this->VisitExpr(op->args[0]);
...@@ -137,7 +137,7 @@ VisitExpr_(const Call* op) { ...@@ -137,7 +137,7 @@ VisitExpr_(const Call* op) {
} }
{ {
With<ConstraintContext> constraint(analyzer_, With<ConstraintContext> constraint(analyzer_,
analyzer_->rewrite_simplify(Not::make(cond))); analyzer_->rewrite_simplify(NotNode::make(cond)));
false_value = this->VisitExpr(op->args[2]); false_value = this->VisitExpr(op->args[2]);
} }
if (is_zero(cond)) { if (is_zero(cond)) {
...@@ -151,7 +151,7 @@ VisitExpr_(const Call* op) { ...@@ -151,7 +151,7 @@ VisitExpr_(const Call* op) {
false_value.same_as(op->args[2])) { false_value.same_as(op->args[2])) {
return GetRef<Expr>(op); return GetRef<Expr>(op);
} else { } else {
return Call::make(op->dtype, op->name, return CallNode::make(op->dtype, op->name,
{cond, true_value, false_value}, {cond, true_value, false_value},
op->call_type); op->call_type);
} }
...@@ -160,7 +160,7 @@ VisitExpr_(const Call* op) { ...@@ -160,7 +160,7 @@ VisitExpr_(const Call* op) {
} }
Expr IRMutatorWithAnalyzer:: Expr IRMutatorWithAnalyzer::
VisitExpr_(const Let* op) { VisitExpr_(const LetNode* op) {
Expr value = this->VisitExpr(op->value); Expr value = this->VisitExpr(op->value);
if (!ir::HasSideEffect(value)) { if (!ir::HasSideEffect(value)) {
analyzer_->Bind(op->var, value); analyzer_->Bind(op->var, value);
...@@ -172,12 +172,12 @@ VisitExpr_(const Let* op) { ...@@ -172,12 +172,12 @@ VisitExpr_(const Let* op) {
body.same_as(op->body)) { body.same_as(op->body)) {
return GetRef<Expr>(op); return GetRef<Expr>(op);
} else { } else {
return Let::make(op->var, value, body); return LetNode::make(op->var, value, body);
} }
} }
Expr IRMutatorWithAnalyzer:: Expr IRMutatorWithAnalyzer::
VisitExpr_(const Select* op) { VisitExpr_(const SelectNode* op) {
Expr cond = this->VisitExpr(op->condition); Expr cond = this->VisitExpr(op->condition);
Expr true_value, false_value; Expr true_value, false_value;
{ {
...@@ -186,7 +186,7 @@ VisitExpr_(const Select* op) { ...@@ -186,7 +186,7 @@ VisitExpr_(const Select* op) {
} }
{ {
With<ConstraintContext> constraint(analyzer_, With<ConstraintContext> constraint(analyzer_,
analyzer_->rewrite_simplify(Not::make(cond))); analyzer_->rewrite_simplify(NotNode::make(cond)));
false_value = VisitExpr(op->false_value); false_value = VisitExpr(op->false_value);
} }
if (is_zero(cond)) { if (is_zero(cond)) {
...@@ -201,12 +201,12 @@ VisitExpr_(const Select* op) { ...@@ -201,12 +201,12 @@ VisitExpr_(const Select* op) {
false_value.same_as(op->false_value)) { false_value.same_as(op->false_value)) {
return GetRef<Expr>(op); return GetRef<Expr>(op);
} else { } else {
return Select::make(cond, true_value, false_value); return SelectNode::make(cond, true_value, false_value);
} }
} }
Expr IRMutatorWithAnalyzer:: Expr IRMutatorWithAnalyzer::
VisitExpr_(const Reduce* op) { VisitExpr_(const ReduceNode* op) {
// Setup the domain information before simplification. // Setup the domain information before simplification.
for (const IterVar& iv : op->axis) { for (const IterVar& iv : op->axis) {
analyzer_->Bind(iv->var, iv->dom); analyzer_->Bind(iv->var, iv->dom);
......
...@@ -49,15 +49,15 @@ class IRMutatorWithAnalyzer : public ir::StmtExprMutator { ...@@ -49,15 +49,15 @@ class IRMutatorWithAnalyzer : public ir::StmtExprMutator {
using StmtExprMutator::VisitExpr_; using StmtExprMutator::VisitExpr_;
// override functions that need to populate the context information. // override functions that need to populate the context information.
Stmt VisitStmt_(const ir::For* op) override; Stmt VisitStmt_(const ir::ForNode* op) override;
Stmt VisitStmt_(const ir::LetStmt* op) override; Stmt VisitStmt_(const ir::LetStmtNode* op) override;
Stmt VisitStmt_(const ir::IfThenElse* op) override; Stmt VisitStmt_(const ir::IfThenElseNode* op) override;
Stmt VisitStmt_(const ir::AttrStmt* op) override; Stmt VisitStmt_(const ir::AttrStmtNode* op) override;
Stmt VisitStmt_(const ir::AssertStmt* op) override; Stmt VisitStmt_(const ir::AssertStmtNode* op) override;
Expr VisitExpr_(const ir::Let* op) override; Expr VisitExpr_(const ir::LetNode* op) override;
Expr VisitExpr_(const ir::Select* op) override; Expr VisitExpr_(const ir::SelectNode* op) override;
Expr VisitExpr_(const ir::Call* op) override; Expr VisitExpr_(const ir::CallNode* op) override;
Expr VisitExpr_(const ir::Reduce* op) override; Expr VisitExpr_(const ir::ReduceNode* op) override;
protected: protected:
/*! \brief internal analyzer field. */ /*! \brief internal analyzer field. */
......
...@@ -38,13 +38,13 @@ class IRVisitorWithAnalyzer final : public StmtExprVisitor { ...@@ -38,13 +38,13 @@ class IRVisitorWithAnalyzer final : public StmtExprVisitor {
return analyzer_.Simplify(expr); return analyzer_.Simplify(expr);
} }
void VisitStmt_(const For* op) { void VisitStmt_(const ForNode* op) {
analyzer_.Bind(op->loop_var, analyzer_.Bind(op->loop_var,
Range::make_by_min_extent(op->min, op->extent)); Range::make_by_min_extent(op->min, op->extent));
return StmtExprVisitor::VisitStmt_(op); return StmtExprVisitor::VisitStmt_(op);
} }
void VisitStmt_(const AttrStmt* op) { void VisitStmt_(const AttrStmtNode* op) {
if (op->attr_key == attr::thread_extent || if (op->attr_key == attr::thread_extent ||
op->attr_key == attr::virtual_thread) { op->attr_key == attr::virtual_thread) {
IterVar iv = Downcast<IterVar>(op->node); IterVar iv = Downcast<IterVar>(op->node);
...@@ -57,7 +57,7 @@ class IRVisitorWithAnalyzer final : public StmtExprVisitor { ...@@ -57,7 +57,7 @@ class IRVisitorWithAnalyzer final : public StmtExprVisitor {
} }
} }
void VisitExpr_(const Reduce* op) { void VisitExpr_(const ReduceNode* op) {
// Setup the domain information before simplification. // Setup the domain information before simplification.
for (const IterVar& iv : op->axis) { for (const IterVar& iv : op->axis) {
analyzer_.Bind(iv->var, iv->dom); analyzer_.Bind(iv->var, iv->dom);
......
...@@ -124,15 +124,15 @@ class ModularSetAnalyzer::Impl : ...@@ -124,15 +124,15 @@ class ModularSetAnalyzer::Impl :
return Everything(); return Everything();
} }
Entry VisitExpr_(const Cast* op) final { Entry VisitExpr_(const CastNode* op) final {
return VisitExpr(op->value); return VisitExpr(op->value);
} }
Entry VisitExpr_(const IntImm* op) final { Entry VisitExpr_(const IntImmNode* op) final {
return Entry(0, op->value); return Entry(0, op->value);
} }
Entry VisitExpr_(const UIntImm* op) final { Entry VisitExpr_(const UIntImmNode* op) final {
if (op->value < std::numeric_limits<int64_t>::max()) { if (op->value < std::numeric_limits<int64_t>::max()) {
return Entry(0, static_cast<int>(op->value)); return Entry(0, static_cast<int>(op->value));
} else { } else {
...@@ -140,21 +140,21 @@ class ModularSetAnalyzer::Impl : ...@@ -140,21 +140,21 @@ class ModularSetAnalyzer::Impl :
} }
} }
Entry VisitExpr_(const Add* op) final { Entry VisitExpr_(const AddNode* op) final {
Entry a = VisitExpr(op->a); Entry a = VisitExpr(op->a);
Entry b = VisitExpr(op->b); Entry b = VisitExpr(op->b);
int64_t coeff = ZeroAwareGCD(a.coeff, b.coeff); int64_t coeff = ZeroAwareGCD(a.coeff, b.coeff);
return Entry(coeff, a.base + b.base); return Entry(coeff, a.base + b.base);
} }
Entry VisitExpr_(const Sub* op) final { Entry VisitExpr_(const SubNode* op) final {
Entry a = VisitExpr(op->a); Entry a = VisitExpr(op->a);
Entry b = VisitExpr(op->b); Entry b = VisitExpr(op->b);
int64_t coeff = ZeroAwareGCD(a.coeff, b.coeff); int64_t coeff = ZeroAwareGCD(a.coeff, b.coeff);
return Entry(coeff, a.base - b.base); return Entry(coeff, a.base - b.base);
} }
Entry VisitExpr_(const Mul* op) final { Entry VisitExpr_(const MulNode* op) final {
Entry a = VisitExpr(op->a); Entry a = VisitExpr(op->a);
Entry b = VisitExpr(op->b); Entry b = VisitExpr(op->b);
// Simplification rule, x, y, z are in Z // Simplification rule, x, y, z are in Z
...@@ -188,7 +188,7 @@ class ModularSetAnalyzer::Impl : ...@@ -188,7 +188,7 @@ class ModularSetAnalyzer::Impl :
return Everything(); return Everything();
} }
Entry VisitExpr_(const Div* op) final { Entry VisitExpr_(const DivNode* op) final {
Entry b = VisitExpr(op->b); Entry b = VisitExpr(op->b);
if (b.is_const()) { if (b.is_const()) {
return DivByConst(op->a, b.base, false); return DivByConst(op->a, b.base, false);
...@@ -196,7 +196,7 @@ class ModularSetAnalyzer::Impl : ...@@ -196,7 +196,7 @@ class ModularSetAnalyzer::Impl :
return Everything(); return Everything();
} }
Entry VisitExpr_(const FloorDiv* op) final { Entry VisitExpr_(const FloorDivNode* op) final {
Entry b = VisitExpr(op->b); Entry b = VisitExpr(op->b);
if (b.is_const()) { if (b.is_const()) {
return DivByConst(op->a, b.base, true); return DivByConst(op->a, b.base, true);
...@@ -204,35 +204,35 @@ class ModularSetAnalyzer::Impl : ...@@ -204,35 +204,35 @@ class ModularSetAnalyzer::Impl :
return Everything(); return Everything();
} }
Entry VisitExpr_(const Min* op) final { Entry VisitExpr_(const MinNode* op) final {
Entry a = VisitExpr(op->a); Entry a = VisitExpr(op->a);
Entry b = VisitExpr(op->b); Entry b = VisitExpr(op->b);
return Union(a, b); return Union(a, b);
} }
Entry VisitExpr_(const Max* op) final { Entry VisitExpr_(const MaxNode* op) final {
Entry a = VisitExpr(op->a); Entry a = VisitExpr(op->a);
Entry b = VisitExpr(op->b); Entry b = VisitExpr(op->b);
return Union(a, b); return Union(a, b);
} }
Entry VisitExpr_(const Select* op) final { Entry VisitExpr_(const SelectNode* op) final {
Entry a = VisitExpr(op->true_value); Entry a = VisitExpr(op->true_value);
Entry b = VisitExpr(op->false_value); Entry b = VisitExpr(op->false_value);
return Union(a, b); return Union(a, b);
} }
Entry VisitExpr_(const Call* op) final { Entry VisitExpr_(const CallNode* op) final {
// only special handle >> which can be // only special handle >> which can be
// used for index calculation. // used for index calculation.
if (op->is_intrinsic(Call::shift_right)) { if (op->is_intrinsic(CallNode::shift_right)) {
return VisitRightShift(op); return VisitRightShift(op);
} else { } else {
return Everything(); return Everything();
} }
} }
Entry VisitExpr_(const Variable* op) final { Entry VisitExpr_(const VarNode* op) final {
Var v = GetRef<Var>(op); Var v = GetRef<Var>(op);
auto it = var_map_.find(v); auto it = var_map_.find(v);
if (it != var_map_.end()) { if (it != var_map_.end()) {
...@@ -242,7 +242,7 @@ class ModularSetAnalyzer::Impl : ...@@ -242,7 +242,7 @@ class ModularSetAnalyzer::Impl :
} }
} }
Entry VisitRightShift(const Call* op) { Entry VisitRightShift(const CallNode* op) {
Entry b = VisitExpr(op->args[1]); Entry b = VisitExpr(op->args[1]);
// a c x / c -> a x // a c x / c -> a x
if (b.is_const()) { if (b.is_const()) {
......
...@@ -50,29 +50,29 @@ class RewriteSimplifier::Impl : public IRMutatorWithAnalyzer { ...@@ -50,29 +50,29 @@ class RewriteSimplifier::Impl : public IRMutatorWithAnalyzer {
: IRMutatorWithAnalyzer(parent) {} : IRMutatorWithAnalyzer(parent) {}
void Update(const Var& var, const Expr& info, bool override_info); void Update(const Var& var, const Expr& info, bool override_info);
Expr VisitExpr_(const Add* op) override; Expr VisitExpr_(const AddNode* op) override;
Expr VisitExpr_(const Sub* op) override; Expr VisitExpr_(const SubNode* op) override;
Expr VisitExpr_(const Mul* op) override; Expr VisitExpr_(const MulNode* op) override;
Expr VisitExpr_(const Div* op) override; Expr VisitExpr_(const DivNode* op) override;
Expr VisitExpr_(const Mod* op) override; Expr VisitExpr_(const ModNode* op) override;
Expr VisitExpr_(const FloorDiv* op) override; Expr VisitExpr_(const FloorDivNode* op) override;
Expr VisitExpr_(const FloorMod* op) override; Expr VisitExpr_(const FloorModNode* op) override;
Expr VisitExpr_(const Min* op) override; Expr VisitExpr_(const MinNode* op) override;
Expr VisitExpr_(const Max* op) override; Expr VisitExpr_(const MaxNode* op) override;
Expr VisitExpr_(const EQ* op) override; Expr VisitExpr_(const EQNode* op) override;
Expr VisitExpr_(const NE* op) override; Expr VisitExpr_(const NENode* op) override;
Expr VisitExpr_(const LT* op) override; Expr VisitExpr_(const LTNode* op) override;
Expr VisitExpr_(const LE* op) override; Expr VisitExpr_(const LENode* op) override;
Expr VisitExpr_(const GT* op) override; Expr VisitExpr_(const GTNode* op) override;
Expr VisitExpr_(const GE* op) override; Expr VisitExpr_(const GENode* op) override;
Expr VisitExpr_(const And* op) override; Expr VisitExpr_(const AndNode* op) override;
Expr VisitExpr_(const Or* op) override; Expr VisitExpr_(const OrNode* op) override;
Expr VisitExpr_(const Not* op) override; Expr VisitExpr_(const NotNode* op) override;
Expr VisitExpr_(const Select* op) override; Expr VisitExpr_(const SelectNode* op) override;
Expr VisitExpr_(const Call* op) override; Expr VisitExpr_(const CallNode* op) override;
Expr VisitExpr_(const Variable* op) override; Expr VisitExpr_(const VarNode* op) override;
Expr VisitExpr_(const Cast* op) override; Expr VisitExpr_(const CastNode* op) override;
Expr VisitExpr_(const Let* op) override; Expr VisitExpr_(const LetNode* op) override;
std::function<void()> EnterConstraint(const Expr& constraint); std::function<void()> EnterConstraint(const Expr& constraint);
......
...@@ -50,14 +50,14 @@ class StmtSimplifier : public IRMutatorWithAnalyzer { ...@@ -50,14 +50,14 @@ class StmtSimplifier : public IRMutatorWithAnalyzer {
return operator()(std::move(stmt)); return operator()(std::move(stmt));
} }
Stmt VisitStmt_(const For* op) final { Stmt VisitStmt_(const ForNode* op) final {
analyzer_->Bind(op->loop_var, Range::make_by_min_extent(op->min, op->extent)); analyzer_->Bind(op->loop_var, Range::make_by_min_extent(op->min, op->extent));
With<ConstraintContext> ctx1(analyzer_, op->loop_var >= op->min); With<ConstraintContext> ctx1(analyzer_, op->loop_var >= op->min);
With<ConstraintContext> ctx2(analyzer_, op->loop_var < op->min + op->extent); With<ConstraintContext> ctx2(analyzer_, op->loop_var < op->min + op->extent);
return Parent::VisitStmt_(op); return Parent::VisitStmt_(op);
} }
Stmt VisitStmt_(const LetStmt* op) { Stmt VisitStmt_(const LetStmtNode* op) {
Expr value = this->VisitExpr(op->value); Expr value = this->VisitExpr(op->value);
if (!ir::HasSideEffect(value)) { if (!ir::HasSideEffect(value)) {
// it is fine to discard the let binding // it is fine to discard the let binding
...@@ -78,13 +78,13 @@ class StmtSimplifier : public IRMutatorWithAnalyzer { ...@@ -78,13 +78,13 @@ class StmtSimplifier : public IRMutatorWithAnalyzer {
} }
// eliminate useless stores // eliminate useless stores
Stmt VisitStmt_(const Store* op) final { Stmt VisitStmt_(const StoreNode* op) final {
Stmt stmt = Parent::VisitStmt_(op); Stmt stmt = Parent::VisitStmt_(op);
op = stmt.as<Store>(); op = stmt.as<StoreNode>();
if (const Load* load = op->value.as<Load>()) { if (const LoadNode* load = op->value.as<LoadNode>()) {
if (load->buffer_var.same_as(op->buffer_var) && if (load->buffer_var.same_as(op->buffer_var) &&
Equal(load->index, op->index)) { Equal(load->index, op->index)) {
return Evaluate::make(0); return EvaluateNode::make(0);
} }
} }
return GetRef<Stmt>(op); return GetRef<Stmt>(op);
......
...@@ -29,8 +29,8 @@ namespace tvm { ...@@ -29,8 +29,8 @@ namespace tvm {
namespace autotvm { namespace autotvm {
// for loop // for loop
void FeatureVisitor::VisitStmt_(const For* op) { void FeatureVisitor::VisitStmt_(const ForNode* op) {
const auto *extent = op->extent.as<IntImm>(); const auto *extent = op->extent.as<IntImmNode>();
int64_t loop_extent = -1; int64_t loop_extent = -1;
if (extent != nullptr) if (extent != nullptr)
loop_extent = extent->value; loop_extent = extent->value;
...@@ -57,11 +57,11 @@ void FeatureVisitor::VisitStmt_(const For* op) { ...@@ -57,11 +57,11 @@ void FeatureVisitor::VisitStmt_(const For* op) {
} }
// parallel axis, virtual thread // parallel axis, virtual thread
void FeatureVisitor::VisitStmt_(const AttrStmt* op) { void FeatureVisitor::VisitStmt_(const AttrStmtNode* op) {
if (op->attr_key == attr::thread_extent || if (op->attr_key == attr::thread_extent ||
op->attr_key == attr::virtual_thread) { op->attr_key == attr::virtual_thread) {
VarExpr var = op->node.as<tvm::IterVarNode>()->var; VarExpr var = op->node.as<tvm::IterVarNode>()->var;
const auto *extent = op->value.as<IntImm>(); const auto *extent = op->value.as<IntImmNode>();
CHECK(extent); CHECK(extent);
std::string name = var.get()->name_hint; std::string name = var.get()->name_hint;
...@@ -95,13 +95,13 @@ void FeatureVisitor::VisitStmt_(const AttrStmt* op) { ...@@ -95,13 +95,13 @@ void FeatureVisitor::VisitStmt_(const AttrStmt* op) {
} }
// memory access // memory access
void FeatureVisitor::VisitExpr_(const Load* op) { void FeatureVisitor::VisitExpr_(const LoadNode* op) {
EnterMem_(op->buffer_var, op->index); EnterMem_(op->buffer_var, op->index);
StmtExprVisitor::VisitExpr_(op); StmtExprVisitor::VisitExpr_(op);
ExitMem_(); ExitMem_();
} }
void FeatureVisitor::VisitStmt_(const Store* op) { void FeatureVisitor::VisitStmt_(const StoreNode* op) {
EnterMem_(op->buffer_var, op->index); EnterMem_(op->buffer_var, op->index);
StmtExprVisitor::VisitStmt_(op); StmtExprVisitor::VisitStmt_(op);
ExitMem_(); ExitMem_();
......
...@@ -51,12 +51,12 @@ enum AnnotationType { ...@@ -51,12 +51,12 @@ enum AnnotationType {
class FeatureVisitor : public StmtExprVisitor { class FeatureVisitor : public StmtExprVisitor {
public: public:
// for loop // for loop
void VisitStmt_(const For* op) final; void VisitStmt_(const ForNode* op) final;
void VisitStmt_(const AttrStmt* op) final; void VisitStmt_(const AttrStmtNode* op) final;
// memory access // memory access
void VisitExpr_(const Load* op) final; void VisitExpr_(const LoadNode* op) final;
void VisitStmt_(const Store* op) final; void VisitStmt_(const StoreNode* op) final;
using StmtExprVisitor::VisitStmt_; using StmtExprVisitor::VisitStmt_;
using StmtExprVisitor::VisitExpr_; using StmtExprVisitor::VisitExpr_;
......
...@@ -51,7 +51,7 @@ class IndexParser: public ExprVisitor { ...@@ -51,7 +51,7 @@ class IndexParser: public ExprVisitor {
this->VisitExpr(expr); this->VisitExpr(expr);
} }
void VisitExpr_(const Variable* op) final { void VisitExpr_(const VarNode* op) final {
// TODO(lmzheng): handle more index types (multiple occurrence) // TODO(lmzheng): handle more index types (multiple occurrence)
if (pattern_map.count(op) == 0) { if (pattern_map.count(op) == 0) {
pattern_map[op] = TouchPattern(); pattern_map[op] = TouchPattern();
...@@ -60,16 +60,16 @@ class IndexParser: public ExprVisitor { ...@@ -60,16 +60,16 @@ class IndexParser: public ExprVisitor {
} }
} }
void VisitExpr_(const Mul* op) final { void VisitExpr_(const MulNode* op) final {
if (op->a.as<Variable>()) { if (op->a.as<VarNode>()) {
if (const auto stride = op->b.as<IntImm>()) { if (const auto stride = op->b.as<IntImmNode>()) {
next_stride_ = stride->value; next_stride_ = stride->value;
} }
} }
ExprVisitor::VisitExpr_(op); ExprVisitor::VisitExpr_(op);
} }
std::unordered_map<const Variable*, TouchPattern> pattern_map; std::unordered_map<const VarNode*, TouchPattern> pattern_map;
private: private:
int64_t next_stride_ = 1; int64_t next_stride_ = 1;
...@@ -255,10 +255,10 @@ void GetItervarFeature(Stmt stmt, bool take_log, Array<Array<Array<Expr> > > *re ...@@ -255,10 +255,10 @@ void GetItervarFeature(Stmt stmt, bool take_log, Array<Array<Array<Expr> > > *re
feature_row.push_back(Array<Expr>{std::string("_itervar_"), var}); feature_row.push_back(Array<Expr>{std::string("_itervar_"), var});
Array<Expr> attr{std::string("_attr_"), Array<Expr> attr{std::string("_attr_"),
FloatImm::make(DataType::Float(32), trans(fea.length)), FloatImmNode::make(DataType::Float(32), trans(fea.length)),
IntImm::make(DataType::Int(32), fea.nest_level), IntImmNode::make(DataType::Int(32), fea.nest_level),
FloatImm::make(DataType::Float(32), trans(fea.topdown_product)), FloatImmNode::make(DataType::Float(32), trans(fea.topdown_product)),
FloatImm::make(DataType::Float(32), trans(fea.bottomup_product)), FloatImmNode::make(DataType::Float(32), trans(fea.bottomup_product)),
}; };
// one hot annotation // one hot annotation
for (int i = 0; i < kNum; i++) { for (int i = 0; i < kNum; i++) {
...@@ -268,9 +268,9 @@ void GetItervarFeature(Stmt stmt, bool take_log, Array<Array<Array<Expr> > > *re ...@@ -268,9 +268,9 @@ void GetItervarFeature(Stmt stmt, bool take_log, Array<Array<Array<Expr> > > *re
// arithmetic // arithmetic
feature_row.push_back(Array<Expr>{std::string("_arith_"), feature_row.push_back(Array<Expr>{std::string("_arith_"),
FloatImm::make(DataType::Float(32), trans(fea.add_ct)), FloatImmNode::make(DataType::Float(32), trans(fea.add_ct)),
FloatImm::make(DataType::Float(32), trans(fea.mul_ct)), FloatImmNode::make(DataType::Float(32), trans(fea.mul_ct)),
FloatImm::make(DataType::Float(32), trans(fea.div_ct)), FloatImmNode::make(DataType::Float(32), trans(fea.div_ct)),
}); });
// touch map // touch map
...@@ -281,14 +281,15 @@ void GetItervarFeature(Stmt stmt, bool take_log, Array<Array<Array<Expr> > > *re ...@@ -281,14 +281,15 @@ void GetItervarFeature(Stmt stmt, bool take_log, Array<Array<Array<Expr> > > *re
std::sort(bufs.begin(), bufs.end()); std::sort(bufs.begin(), bufs.end());
for (auto k : bufs) { for (auto k : bufs) {
TouchPattern &v = fea.touch_feature[k]; TouchPattern &v = fea.touch_feature[k];
feature_row.push_back(Array<Expr>{k, feature_row.push_back(
FloatImm::make(DataType::Float(32), trans(v.stride)), Array<Expr>{k,
FloatImm::make(DataType::Float(32), trans(v.mod)), FloatImmNode::make(DataType::Float(32), trans(v.stride)),
FloatImm::make(DataType::Float(32), trans(v.count)), FloatImmNode::make(DataType::Float(32), trans(v.mod)),
FloatImm::make(DataType::Float(32), trans(v.reuse)), FloatImmNode::make(DataType::Float(32), trans(v.count)),
FloatImm::make(DataType::Float(32), trans(v.thread_count)), FloatImmNode::make(DataType::Float(32), trans(v.reuse)),
FloatImm::make(DataType::Float(32), trans(v.thread_reuse)), FloatImmNode::make(DataType::Float(32), trans(v.thread_count)),
}); FloatImmNode::make(DataType::Float(32), trans(v.thread_reuse)),
});
} }
ret_feature->push_back(feature_row); ret_feature->push_back(feature_row);
......
...@@ -92,31 +92,31 @@ class TouchExtractor : public FeatureVisitor { ...@@ -92,31 +92,31 @@ class TouchExtractor : public FeatureVisitor {
} }
// arithmetic stats // arithmetic stats
void VisitExpr_(const Add* op) final { void VisitExpr_(const AddNode* op) final {
if (op->dtype.is_float()) if (op->dtype.is_float())
itervar_map[itervar_stack_.back()].add_ct++; itervar_map[itervar_stack_.back()].add_ct++;
FeatureVisitor::VisitExpr_(op); FeatureVisitor::VisitExpr_(op);
} }
void VisitExpr_(const Sub* op) final { void VisitExpr_(const SubNode* op) final {
if (op->dtype.is_float()) if (op->dtype.is_float())
itervar_map[itervar_stack_.back()].add_ct++; itervar_map[itervar_stack_.back()].add_ct++;
FeatureVisitor::VisitExpr_(op); FeatureVisitor::VisitExpr_(op);
} }
void VisitExpr_(const Mul* op) final { void VisitExpr_(const MulNode* op) final {
if (op->dtype.is_float()) if (op->dtype.is_float())
itervar_map[itervar_stack_.back()].mul_ct++; itervar_map[itervar_stack_.back()].mul_ct++;
FeatureVisitor::VisitExpr_(op); FeatureVisitor::VisitExpr_(op);
} }
void VisitExpr_(const Div* op) final { void VisitExpr_(const DivNode* op) final {
if (op->dtype.is_float()) if (op->dtype.is_float())
itervar_map[itervar_stack_.back()].div_ct++; itervar_map[itervar_stack_.back()].div_ct++;
FeatureVisitor::VisitExpr_(op); FeatureVisitor::VisitExpr_(op);
} }
void VisitExpr_(const Mod* op) final { void VisitExpr_(const ModNode* op) final {
if (op->dtype.is_float()) if (op->dtype.is_float())
itervar_map[itervar_stack_.back()].div_ct++; itervar_map[itervar_stack_.back()].div_ct++;
FeatureVisitor::VisitExpr_(op); FeatureVisitor::VisitExpr_(op);
......
...@@ -65,39 +65,39 @@ Target CreateTarget(const std::string& target_name, ...@@ -65,39 +65,39 @@ Target CreateTarget(const std::string& target_name,
std::string device_flag = "-device="; std::string device_flag = "-device=";
std::string keys_flag = "-keys="; std::string keys_flag = "-keys=";
for (auto& item : options) { for (auto& item : options) {
t->options_array.push_back(ir::StringImm::make(item)); t->options_array.push_back(ir::StringImmNode::make(item));
if (item.find(libs_flag) == 0) { if (item.find(libs_flag) == 0) {
std::stringstream ss(item.substr(libs_flag.length())); std::stringstream ss(item.substr(libs_flag.length()));
std::string lib_item; std::string lib_item;
while (std::getline(ss, lib_item, ',')) { while (std::getline(ss, lib_item, ',')) {
t->libs_array.push_back(ir::StringImm::make(lib_item)); t->libs_array.push_back(ir::StringImmNode::make(lib_item));
} }
} else if (item.find(device_flag) == 0) { } else if (item.find(device_flag) == 0) {
t->device_name = item.substr(device_flag.length()); t->device_name = item.substr(device_flag.length());
t->keys_array.push_back(ir::StringImm::make(t->device_name)); t->keys_array.push_back(ir::StringImmNode::make(t->device_name));
} else if (item.find(keys_flag) == 0) { } else if (item.find(keys_flag) == 0) {
std::stringstream ss(item.substr(keys_flag.length())); std::stringstream ss(item.substr(keys_flag.length()));
std::string key_item; std::string key_item;
while (std::getline(ss, key_item, ',')) { while (std::getline(ss, key_item, ',')) {
t->keys_array.push_back(ir::StringImm::make(key_item)); t->keys_array.push_back(ir::StringImmNode::make(key_item));
} }
} }
} }
if (t->device_name.length() > 0) { if (t->device_name.length() > 0) {
t->keys_array.push_back(ir::StringImm::make(t->device_name)); t->keys_array.push_back(ir::StringImmNode::make(t->device_name));
} }
t->device_type = kDLCPU; t->device_type = kDLCPU;
t->thread_warp_size = 1; t->thread_warp_size = 1;
if (target_name == "c" && t->device_name == "micro_dev") { if (target_name == "c" && t->device_name == "micro_dev") {
t->device_type = kDLMicroDev; t->device_type = kDLMicroDev;
} else if (target_name == "c" || target_name == "llvm") { } else if (target_name == "c" || target_name == "llvm") {
t->keys_array.push_back(ir::StringImm::make("cpu")); t->keys_array.push_back(ir::StringImmNode::make("cpu"));
} else if (target_name == "cuda" || target_name == "nvptx") { } else if (target_name == "cuda" || target_name == "nvptx") {
t->device_type = kDLGPU; t->device_type = kDLGPU;
t->keys_array.push_back(ir::StringImm::make("cuda")); t->keys_array.push_back(ir::StringImmNode::make("cuda"));
t->keys_array.push_back(ir::StringImm::make("gpu")); t->keys_array.push_back(ir::StringImmNode::make("gpu"));
t->max_num_threads = 1024; t->max_num_threads = 1024;
t->thread_warp_size = 32; t->thread_warp_size = 32;
} else if (target_name == "rocm" || target_name == "opencl") { } else if (target_name == "rocm" || target_name == "opencl") {
...@@ -107,8 +107,8 @@ Target CreateTarget(const std::string& target_name, ...@@ -107,8 +107,8 @@ Target CreateTarget(const std::string& target_name,
} else { } else {
t->device_type = kDLROCM; t->device_type = kDLROCM;
} }
t->keys_array.push_back(ir::StringImm::make(target_name)); t->keys_array.push_back(ir::StringImmNode::make(target_name));
t->keys_array.push_back(ir::StringImm::make("gpu")); t->keys_array.push_back(ir::StringImmNode::make("gpu"));
t->max_num_threads = 256; t->max_num_threads = 256;
if (t->device_name == "intel_graphics") { if (t->device_name == "intel_graphics") {
t->thread_warp_size = 16; t->thread_warp_size = 16;
...@@ -119,20 +119,20 @@ Target CreateTarget(const std::string& target_name, ...@@ -119,20 +119,20 @@ Target CreateTarget(const std::string& target_name,
} else { } else {
t->device_type = kDLVulkan; t->device_type = kDLVulkan;
} }
t->keys_array.push_back(ir::StringImm::make(target_name)); t->keys_array.push_back(ir::StringImmNode::make(target_name));
t->keys_array.push_back(ir::StringImm::make("gpu")); t->keys_array.push_back(ir::StringImmNode::make("gpu"));
t->max_num_threads = 256; t->max_num_threads = 256;
} else if (target_name == "sdaccel") { } else if (target_name == "sdaccel") {
t->device_type = kDLOpenCL; t->device_type = kDLOpenCL;
t->keys_array.push_back(ir::StringImm::make("sdaccel")); t->keys_array.push_back(ir::StringImmNode::make("sdaccel"));
t->keys_array.push_back(ir::StringImm::make("hls")); t->keys_array.push_back(ir::StringImmNode::make("hls"));
} else if (target_name == "aocl" || target_name == "aocl_sw_emu") { } else if (target_name == "aocl" || target_name == "aocl_sw_emu") {
t->device_type = kDLAOCL; t->device_type = kDLAOCL;
t->keys_array.push_back(ir::StringImm::make("aocl")); t->keys_array.push_back(ir::StringImmNode::make("aocl"));
t->keys_array.push_back(ir::StringImm::make("hls")); t->keys_array.push_back(ir::StringImmNode::make("hls"));
} else if (target_name == "opengl") { } else if (target_name == "opengl") {
t->device_type = kOpenGL; t->device_type = kOpenGL;
t->keys_array.push_back(ir::StringImm::make("opengl")); t->keys_array.push_back(ir::StringImmNode::make("opengl"));
} else if (target_name == "stackvm") { } else if (target_name == "stackvm") {
t->device_type = kDLCPU; t->device_type = kDLCPU;
} else if (target_name == "ext_dev") { } else if (target_name == "ext_dev") {
...@@ -168,7 +168,7 @@ TVM_REGISTER_GLOBAL("_TargetFromString") ...@@ -168,7 +168,7 @@ TVM_REGISTER_GLOBAL("_TargetFromString")
std::vector<std::string> TargetNode::keys() const { std::vector<std::string> TargetNode::keys() const {
std::vector<std::string> result; std::vector<std::string> result;
for (auto& expr : keys_array) { for (auto& expr : keys_array) {
result.push_back(expr.as<ir::StringImm>()->value); result.push_back(expr.as<ir::StringImmNode>()->value);
} }
return result; return result;
} }
...@@ -176,7 +176,7 @@ std::vector<std::string> TargetNode::keys() const { ...@@ -176,7 +176,7 @@ std::vector<std::string> TargetNode::keys() const {
std::vector<std::string> TargetNode::options() const { std::vector<std::string> TargetNode::options() const {
std::vector<std::string> result; std::vector<std::string> result;
for (auto& expr : options_array) { for (auto& expr : options_array) {
result.push_back(expr.as<ir::StringImm>()->value); result.push_back(expr.as<ir::StringImmNode>()->value);
} }
return result; return result;
} }
...@@ -184,7 +184,7 @@ std::vector<std::string> TargetNode::options() const { ...@@ -184,7 +184,7 @@ std::vector<std::string> TargetNode::options() const {
std::unordered_set<std::string> TargetNode::libs() const { std::unordered_set<std::string> TargetNode::libs() const {
std::unordered_set<std::string> result; std::unordered_set<std::string> result;
for (auto& expr : libs_array) { for (auto& expr : libs_array) {
result.insert(expr.as<ir::StringImm>()->value); result.insert(expr.as<ir::StringImmNode>()->value);
} }
return result; return result;
} }
...@@ -348,7 +348,7 @@ Buffer BufferWithOffsetAlignment(Array<Expr> shape, ...@@ -348,7 +348,7 @@ Buffer BufferWithOffsetAlignment(Array<Expr> shape,
bool has_any = false; bool has_any = false;
if (!compact) { if (!compact) {
for (const auto& it : shape) { for (const auto& it : shape) {
if (it.as<Variable>()) { if (it.as<VarNode>()) {
has_any = true; has_any = true;
break; break;
} }
...@@ -860,7 +860,7 @@ TVM_REGISTER_GLOBAL("_GenericFuncRegisterFunc") ...@@ -860,7 +860,7 @@ TVM_REGISTER_GLOBAL("_GenericFuncRegisterFunc")
std::vector<std::string> tags_vector; std::vector<std::string> tags_vector;
for (auto& tag : tags) { for (auto& tag : tags) {
tags_vector.push_back(tag.as<tvm::ir::StringImm>()->value); tags_vector.push_back(tag.as<tvm::ir::StringImmNode>()->value);
} }
generic_func generic_func
......
...@@ -102,46 +102,46 @@ class CodeGenC : ...@@ -102,46 +102,46 @@ class CodeGenC :
*/ */
virtual void InitFuncState(LoweredFunc f); virtual void InitFuncState(LoweredFunc f);
// expression // expression
void VisitExpr_(const Variable* op, std::ostream& os) override; // NOLINT(*) void VisitExpr_(const VarNode* op, std::ostream& os) override; // NOLINT(*)
void VisitExpr_(const Load* op, std::ostream& os) override; // NOLINT(*) void VisitExpr_(const LoadNode* op, std::ostream& os) override; // NOLINT(*)
void VisitExpr_(const Let* op, std::ostream& os) override; // NOLINT(*) void VisitExpr_(const LetNode* op, std::ostream& os) override; // NOLINT(*)
void VisitExpr_(const Call* op, std::ostream& os) override; // NOLINT(*) void VisitExpr_(const CallNode* op, std::ostream& os) override; // NOLINT(*)
void VisitExpr_(const Add* op, std::ostream& os) override; // NOLINT(*) void VisitExpr_(const AddNode* op, std::ostream& os) override; // NOLINT(*)
void VisitExpr_(const Sub* op, std::ostream& os) override; // NOLINT(*) void VisitExpr_(const SubNode* op, std::ostream& os) override; // NOLINT(*)
void VisitExpr_(const Mul* op, std::ostream& os) override; // NOLINT(*) void VisitExpr_(const MulNode* op, std::ostream& os) override; // NOLINT(*)
void VisitExpr_(const Div* op, std::ostream& os) override; // NOLINT(*) void VisitExpr_(const DivNode* op, std::ostream& os) override; // NOLINT(*)
void VisitExpr_(const Mod* op, std::ostream& os) override; // NOLINT(*) void VisitExpr_(const ModNode* op, std::ostream& os) override; // NOLINT(*)
void VisitExpr_(const Min* op, std::ostream& os) override; // NOLINT(*) void VisitExpr_(const MinNode* op, std::ostream& os) override; // NOLINT(*)
void VisitExpr_(const Max* op, std::ostream& os) override; // NOLINT(*) void VisitExpr_(const MaxNode* op, std::ostream& os) override; // NOLINT(*)
void VisitExpr_(const EQ* op, std::ostream& os) override; // NOLINT(*) void VisitExpr_(const EQNode* op, std::ostream& os) override; // NOLINT(*)
void VisitExpr_(const NE* op, std::ostream& os) override; // NOLINT(*) void VisitExpr_(const NENode* op, std::ostream& os) override; // NOLINT(*)
void VisitExpr_(const LT* op, std::ostream& os) override; // NOLINT(*) void VisitExpr_(const LTNode* op, std::ostream& os) override; // NOLINT(*)
void VisitExpr_(const LE* op, std::ostream& os) override; // NOLINT(*) void VisitExpr_(const LENode* op, std::ostream& os) override; // NOLINT(*)
void VisitExpr_(const GT* op, std::ostream& os) override; // NOLINT(*) void VisitExpr_(const GTNode* op, std::ostream& os) override; // NOLINT(*)
void VisitExpr_(const GE* op, std::ostream& os) override; // NOLINT(*) void VisitExpr_(const GENode* op, std::ostream& os) override; // NOLINT(*)
void VisitExpr_(const And* op, std::ostream& os) override; // NOLINT(*) void VisitExpr_(const AndNode* op, std::ostream& os) override; // NOLINT(*)
void VisitExpr_(const Or* op, std::ostream& os) override; // NOLINT(*) void VisitExpr_(const OrNode* op, std::ostream& os) override; // NOLINT(*)
void VisitExpr_(const Cast* op, std::ostream& os) override; // NOLINT(*) void VisitExpr_(const CastNode* op, std::ostream& os) override; // NOLINT(*)
void VisitExpr_(const Not* op, std::ostream& os) override; // NOLINT(*) void VisitExpr_(const NotNode* op, std::ostream& os) override; // NOLINT(*)
void VisitExpr_(const Select* op, std::ostream& os) override; // NOLINT(*) void VisitExpr_(const SelectNode* op, std::ostream& os) override; // NOLINT(*)
void VisitExpr_(const Ramp* op, std::ostream& os) override; // NOLINT(*) void VisitExpr_(const RampNode* op, std::ostream& os) override; // NOLINT(*)
void VisitExpr_(const Shuffle* op, std::ostream& os) override; // NOLINT(*) void VisitExpr_(const ShuffleNode* op, std::ostream& os) override; // NOLINT(*)
void VisitExpr_(const Broadcast* op, std::ostream& os) override; // NOLINT(*) void VisitExpr_(const BroadcastNode* op, std::ostream& os) override; // NOLINT(*)
void VisitExpr_(const IntImm* op, std::ostream& os) override; // NOLINT(*) void VisitExpr_(const IntImmNode* op, std::ostream& os) override; // NOLINT(*)
void VisitExpr_(const UIntImm* op, std::ostream& os) override; // NOLINT(*) void VisitExpr_(const UIntImmNode* op, std::ostream& os) override; // NOLINT(*)
void VisitExpr_(const FloatImm* op, std::ostream& os) override; // NOLINT(*) void VisitExpr_(const FloatImmNode* op, std::ostream& os) override; // NOLINT(*)
void VisitExpr_(const StringImm* op, std::ostream& os) override; // NOLINT(*) void VisitExpr_(const StringImmNode* op, std::ostream& os) override; // NOLINT(*)
// statment // statment
void VisitStmt_(const LetStmt* op) override; void VisitStmt_(const LetStmtNode* op) override;
void VisitStmt_(const Store* op) override; void VisitStmt_(const StoreNode* op) override;
void VisitStmt_(const For* op) override; void VisitStmt_(const ForNode* op) override;
void VisitStmt_(const IfThenElse* op) override; void VisitStmt_(const IfThenElseNode* op) override;
void VisitStmt_(const Allocate* op) override; void VisitStmt_(const AllocateNode* op) override;
void VisitStmt_(const AttrStmt* op) override; void VisitStmt_(const AttrStmtNode* op) override;
void VisitStmt_(const AssertStmt* op) override; void VisitStmt_(const AssertStmtNode* op) override;
void VisitStmt_(const Evaluate* op) override; void VisitStmt_(const EvaluateNode* op) override;
void VisitStmt_(const SeqStmtNode* op) override; void VisitStmt_(const SeqStmtNode* op) override;
void VisitStmt_(const ProducerConsumer* op) override; void VisitStmt_(const ProducerConsumerNode* op) override;
/*! /*!
* Print Type represetnation of type t. * Print Type represetnation of type t.
* \param t The type representation. * \param t The type representation.
...@@ -154,15 +154,15 @@ class CodeGenC : ...@@ -154,15 +154,15 @@ class CodeGenC :
*/ */
virtual void BindThreadIndex(const IterVar& iv); // NOLINT(*) virtual void BindThreadIndex(const IterVar& iv); // NOLINT(*)
virtual void PrintStorageScope(const std::string& scope, std::ostream& os); // NOLINT(*) virtual void PrintStorageScope(const std::string& scope, std::ostream& os); // NOLINT(*)
virtual void PrintStorageSync(const Call* op); // NOLINT(*) virtual void PrintStorageSync(const CallNode* op); // NOLINT(*)
// Binary vector op. // Binary vector op.
virtual void PrintVecBinaryOp( virtual void PrintVecBinaryOp(
const std::string&op, DataType op_type, const std::string&op, DataType op_type,
Expr lhs, Expr rhs, std::ostream& os); // NOLINT(*) Expr lhs, Expr rhs, std::ostream& os); // NOLINT(*)
// print vector load // print vector load
virtual std::string GetVecLoad(DataType t, const Variable* buffer, Expr base); virtual std::string GetVecLoad(DataType t, const VarNode* buffer, Expr base);
// print vector store // print vector store
virtual void PrintVecStore(const Variable* buffer, virtual void PrintVecStore(const VarNode* buffer,
DataType t, Expr base, DataType t, Expr base,
const std::string& value); // NOLINT(*) const std::string& value); // NOLINT(*)
// print load of single element // print load of single element
...@@ -180,28 +180,28 @@ class CodeGenC : ...@@ -180,28 +180,28 @@ class CodeGenC :
DataType t, const Expr& buffer, const Expr& index, int kind); DataType t, const Expr& buffer, const Expr& index, int kind);
// print reference to a buffer as type t in index. // print reference to a buffer as type t in index.
virtual std::string GetBufferRef( virtual std::string GetBufferRef(
DataType t, const Variable* buffer, Expr index); DataType t, const VarNode* buffer, Expr index);
/*! /*!
* \brief If buffer is allocated as type t. * \brief If buffer is allocated as type t.
* \param buf_var The buffer variable. * \param buf_var The buffer variable.
* \param t The type to be checked. * \param t The type to be checked.
*/ */
bool HandleTypeMatch(const Variable* buf_var, DataType t) const; bool HandleTypeMatch(const VarNode* buf_var, DataType t) const;
/*! /*!
* \brief Register the data type of buf_var * \brief Register the data type of buf_var
* \param buf_var The buffer variable. * \param buf_var The buffer variable.
* \param t The type to be checked. * \param t The type to be checked.
*/ */
void RegisterHandleType(const Variable* buf_var, DataType t); void RegisterHandleType(const VarNode* buf_var, DataType t);
// override // override
void PrintSSAAssign( void PrintSSAAssign(
const std::string& target, const std::string& src, DataType t) final; const std::string& target, const std::string& src, DataType t) final;
/*! \brief restrict keyword */ /*! \brief restrict keyword */
std::string restrict_keyword_{""}; std::string restrict_keyword_{""};
/*! \brief the storage scope of allocation */ /*! \brief the storage scope of allocation */
std::unordered_map<const Variable*, std::string> alloc_storage_scope_; std::unordered_map<const VarNode*, std::string> alloc_storage_scope_;
/*! \brief the data type of allocated buffers */ /*! \brief the data type of allocated buffers */
std::unordered_map<const Variable*, DataType> handle_data_type_; std::unordered_map<const VarNode*, DataType> handle_data_type_;
/*! \brief reserves common C keywords */ /*! \brief reserves common C keywords */
void ReserveKeywordsAsUnique(); void ReserveKeywordsAsUnique();
...@@ -209,7 +209,7 @@ class CodeGenC : ...@@ -209,7 +209,7 @@ class CodeGenC :
/*! \brief whether to print in SSA form */ /*! \brief whether to print in SSA form */
bool print_ssa_form_{false}; bool print_ssa_form_{false};
/*! \brief set of volatile buf access */ /*! \brief set of volatile buf access */
std::unordered_set<const Variable*> volatile_buf_; std::unordered_set<const VarNode*> volatile_buf_;
}; };
} // namespace codegen } // namespace codegen
......
...@@ -142,7 +142,7 @@ void CodeGenCHost::PrintType(DataType t, std::ostream& os) { // NOLINT(*) ...@@ -142,7 +142,7 @@ void CodeGenCHost::PrintType(DataType t, std::ostream& os) { // NOLINT(*)
LOG(FATAL) << "Cannot convert type " << t << " to C type"; LOG(FATAL) << "Cannot convert type " << t << " to C type";
} }
void CodeGenCHost::VisitExpr_(const Broadcast* op, std::ostream& os) { // NOLINT(*) void CodeGenCHost::VisitExpr_(const BroadcastNode* op, std::ostream& os) { // NOLINT(*)
std::string v = PrintExpr(op->value); std::string v = PrintExpr(op->value);
os << "(("; os << "((";
PrintType(op->dtype, os); PrintType(op->dtype, os);
...@@ -194,11 +194,11 @@ void CodeGenCHost::PrintFuncCall(const std::string& packed_func_name, int num_ar ...@@ -194,11 +194,11 @@ void CodeGenCHost::PrintFuncCall(const std::string& packed_func_name, int num_ar
this->stream << "}\n"; this->stream << "}\n";
} }
void CodeGenCHost::VisitExpr_(const Call *op, std::ostream& os) { // NOLINT(*) void CodeGenCHost::VisitExpr_(const CallNode *op, std::ostream& os) { // NOLINT(*)
if (op->is_intrinsic(intrinsic::tvm_stack_alloca)) { if (op->is_intrinsic(intrinsic::tvm_stack_alloca)) {
std::string stack_name = GetUniqueName("stack"); std::string stack_name = GetUniqueName("stack");
const std::string& type = op->args[0].as<StringImm>()->value; const std::string& type = op->args[0].as<StringImmNode>()->value;
const IntImm* num = op->args[1].as<IntImm>(); const IntImmNode* num = op->args[1].as<IntImmNode>();
CHECK(num != nullptr); CHECK(num != nullptr);
static_assert(alignof(TVMValue) % alignof(TVMArray) == 0, "invariant"); static_assert(alignof(TVMValue) % alignof(TVMArray) == 0, "invariant");
size_t unit = sizeof(TVMValue); size_t unit = sizeof(TVMValue);
...@@ -218,10 +218,10 @@ void CodeGenCHost::VisitExpr_(const Call *op, std::ostream& os) { // NOLINT(*) ...@@ -218,10 +218,10 @@ void CodeGenCHost::VisitExpr_(const Call *op, std::ostream& os) { // NOLINT(*)
this->stream << "TVMValue " << stack_name << "[" << size << "];\n"; this->stream << "TVMValue " << stack_name << "[" << size << "];\n";
os << stack_name; os << stack_name;
} else if (op->is_intrinsic(intrinsic::tvm_call_packed_lowered)) { } else if (op->is_intrinsic(intrinsic::tvm_call_packed_lowered)) {
const StringImm* s = op->args[0].as<StringImm>(); const StringImmNode* s = op->args[0].as<StringImmNode>();
CHECK(s != nullptr) << "tvm_call_packed_lowered expects first argument as function name"; CHECK(s != nullptr) << "tvm_call_packed_lowered expects first argument as function name";
int64_t begin = op->args[3].as<IntImm>()->value; int64_t begin = op->args[3].as<IntImmNode>()->value;
int64_t end = op->args[4].as<IntImm>()->value; int64_t end = op->args[4].as<IntImmNode>()->value;
int64_t num_args = end - begin; int64_t num_args = end - begin;
CHECK_GE(num_args, 0); CHECK_GE(num_args, 0);
std::string func_name = s->value; std::string func_name = s->value;
...@@ -237,14 +237,14 @@ void CodeGenCHost::VisitExpr_(const Call *op, std::ostream& os) { // NOLINT(*) ...@@ -237,14 +237,14 @@ void CodeGenCHost::VisitExpr_(const Call *op, std::ostream& os) { // NOLINT(*)
} }
} }
void CodeGenCHost::VisitStmt_(const AssertStmt *op) { // NOLINT(*) void CodeGenCHost::VisitStmt_(const AssertStmtNode *op) { // NOLINT(*)
if (emit_asserts_) { if (emit_asserts_) {
std::string cond = PrintExpr(op->condition); std::string cond = PrintExpr(op->condition);
PrintIndent(); PrintIndent();
stream << "if (!(" << cond << ")) {\n"; stream << "if (!(" << cond << ")) {\n";
int assert_if_scope = this->BeginScope(); int assert_if_scope = this->BeginScope();
PrintIndent(); PrintIndent();
stream << "TVMAPISetLastError(\"" << op->message.as<StringImm>()->value << "\");\n"; stream << "TVMAPISetLastError(\"" << op->message.as<StringImmNode>()->value << "\");\n";
PrintIndent(); PrintIndent();
stream << "return -1;\n"; stream << "return -1;\n";
this->EndScope(assert_if_scope); this->EndScope(assert_if_scope);
...@@ -254,11 +254,11 @@ void CodeGenCHost::VisitStmt_(const AssertStmt *op) { // NOLINT(*) ...@@ -254,11 +254,11 @@ void CodeGenCHost::VisitStmt_(const AssertStmt *op) { // NOLINT(*)
this->PrintStmt(op->body); this->PrintStmt(op->body);
} }
void CodeGenCHost::VisitExpr_(const Min *op, std::ostream& os) { // NOLINT(*) void CodeGenCHost::VisitExpr_(const MinNode *op, std::ostream& os) { // NOLINT(*)
PrintTernaryCondExpr(op, "<", os); PrintTernaryCondExpr(op, "<", os);
} }
void CodeGenCHost::VisitExpr_(const Max *op, std::ostream& os) { // NOLINT(*) void CodeGenCHost::VisitExpr_(const MaxNode *op, std::ostream& os) { // NOLINT(*)
PrintTernaryCondExpr(op, ">", os); PrintTernaryCondExpr(op, ">", os);
} }
......
...@@ -42,14 +42,14 @@ class CodeGenCHost final : public CodeGenC { ...@@ -42,14 +42,14 @@ class CodeGenCHost final : public CodeGenC {
void PrintType(DataType t, std::ostream& os) final; // NOLINT(*) void PrintType(DataType t, std::ostream& os) final; // NOLINT(*)
// overload visitor functions // overload visitor functions
void VisitExpr_(const Broadcast* op, std::ostream& os) final; // NOLINT(*) void VisitExpr_(const BroadcastNode* op, std::ostream& os) final; // NOLINT(*)
void VisitExpr_(const Call *op, std::ostream& os) final; // NOLINT(*) void VisitExpr_(const CallNode *op, std::ostream& os) final; // NOLINT(*)
// overload min and max to use the ternary operator, so we don't rely on the // overload min and max to use the ternary operator, so we don't rely on the
// standard library implementations // standard library implementations
void VisitExpr_(const Min *op, std::ostream& os) final; // NOLINT(*) void VisitExpr_(const MinNode *op, std::ostream& os) final; // NOLINT(*)
void VisitExpr_(const Max *op, std::ostream& os) final; // NOLINT(*) void VisitExpr_(const MaxNode *op, std::ostream& os) final; // NOLINT(*)
void VisitStmt_(const AssertStmt *op) final; // NOLINT(*) void VisitStmt_(const AssertStmtNode *op) final; // NOLINT(*)
private: private:
std::string module_name_; std::string module_name_;
......
...@@ -93,7 +93,7 @@ std::string CodeGenCUDA::Finish() { ...@@ -93,7 +93,7 @@ std::string CodeGenCUDA::Finish() {
return CodeGenC::Finish(); return CodeGenC::Finish();
} }
void CodeGenCUDA::VisitStmt_(const ir::For* op) { void CodeGenCUDA::VisitStmt_(const ir::ForNode* op) {
CHECK(is_const_int(op->min, 0)); CHECK(is_const_int(op->min, 0));
if (op->for_type == ir::ForType::Unrolled) { if (op->for_type == ir::ForType::Unrolled) {
PrintIndent(); PrintIndent();
...@@ -265,8 +265,8 @@ void CodeGenCUDA::PrintVecElemStore( ...@@ -265,8 +265,8 @@ void CodeGenCUDA::PrintVecElemStore(
} }
} }
void CodeGenCUDA::PrintStorageSync(const Call* op) { void CodeGenCUDA::PrintStorageSync(const CallNode* op) {
const std::string& sync = op->args[0].as<StringImm>()->value; const std::string& sync = op->args[0].as<StringImmNode>()->value;
if (sync == "warp") { if (sync == "warp") {
// DO nothing. // DO nothing.
} else if (sync == "shared") { } else if (sync == "shared") {
...@@ -314,7 +314,7 @@ void CodeGenCUDA::PrintStorageScope( ...@@ -314,7 +314,7 @@ void CodeGenCUDA::PrintStorageScope(
} }
} }
void CodeGenCUDA::VisitExpr_(const Call *op, std::ostream& os) { void CodeGenCUDA::VisitExpr_(const CallNode *op, std::ostream& os) {
if (op->is_intrinsic(intrinsic::tvm_fill_fragment)) { if (op->is_intrinsic(intrinsic::tvm_fill_fragment)) {
need_mma_h_ = true; need_mma_h_ = true;
CHECK_EQ(op->args.size(), 6U); CHECK_EQ(op->args.size(), 6U);
...@@ -348,7 +348,7 @@ void CodeGenCUDA::VisitExpr_(const Call *op, std::ostream& os) { ...@@ -348,7 +348,7 @@ void CodeGenCUDA::VisitExpr_(const Call *op, std::ostream& os) {
this->PrintExpr(op->args[4], os); this->PrintExpr(op->args[4], os);
os << "], "; os << "], ";
this->PrintExpr(op->args[6], os); this->PrintExpr(op->args[6], os);
if (const StringImm *str = op->args[7].as<StringImm>()) { if (const StringImmNode *str = op->args[7].as<StringImmNode>()) {
os << ", nvcuda::wmma::mem_" << str->value; os << ", nvcuda::wmma::mem_" << str->value;
} else { } else {
LOG(FATAL) << "Invalid parameters"; LOG(FATAL) << "Invalid parameters";
...@@ -369,20 +369,20 @@ void CodeGenCUDA::VisitExpr_(const Call *op, std::ostream& os) { ...@@ -369,20 +369,20 @@ void CodeGenCUDA::VisitExpr_(const Call *op, std::ostream& os) {
} }
} }
void CodeGenCUDA::VisitStmt_(const AttrStmt* op) { void CodeGenCUDA::VisitStmt_(const AttrStmtNode* op) {
if (op->attr_key == attr::fragment_shape) { if (op->attr_key == attr::fragment_shape) {
const Variable* buffer = op->node.as<Variable>(); const VarNode* buffer = op->node.as<VarNode>();
const StringImm* shape_str = op->value.as<StringImm>(); const StringImmNode* shape_str = op->value.as<StringImmNode>();
fragment_shapes[buffer] = shape_str->value; fragment_shapes[buffer] = shape_str->value;
} else if (op->attr_key == attr::fragment_layout) { } else if (op->attr_key == attr::fragment_layout) {
const Variable* buffer = op->node.as<Variable>(); const VarNode* buffer = op->node.as<VarNode>();
const StringImm* layout_str = op->value.as<StringImm>(); const StringImmNode* layout_str = op->value.as<StringImmNode>();
fragment_layouts[buffer] = layout_str->value; fragment_layouts[buffer] = layout_str->value;
} }
CodeGenC::VisitStmt_(op); CodeGenC::VisitStmt_(op);
} }
void CodeGenCUDA::VisitStmt_(const Allocate* op) { void CodeGenCUDA::VisitStmt_(const AllocateNode* op) {
CHECK(!is_zero(op->condition)); CHECK(!is_zero(op->condition));
std::string vid = AllocVarID(op->buffer_var.get()); std::string vid = AllocVarID(op->buffer_var.get());
if (op->new_expr.defined()) { if (op->new_expr.defined()) {
...@@ -397,7 +397,7 @@ void CodeGenCUDA::VisitStmt_(const Allocate* op) { ...@@ -397,7 +397,7 @@ void CodeGenCUDA::VisitStmt_(const Allocate* op) {
int32_t constant_size = op->constant_allocation_size(); int32_t constant_size = op->constant_allocation_size();
CHECK_GT(constant_size, 0) CHECK_GT(constant_size, 0)
<< "Can only handle constant size stack allocation for now"; << "Can only handle constant size stack allocation for now";
const Variable* buffer = op->buffer_var.as<Variable>(); const VarNode* buffer = op->buffer_var.as<VarNode>();
std::string scope = alloc_storage_scope_.at(buffer); std::string scope = alloc_storage_scope_.at(buffer);
if (scope.find("wmma.") == 0) { if (scope.find("wmma.") == 0) {
if (scope == "wmma.matrix_a" || scope == "wmma.matrix_b") { if (scope == "wmma.matrix_a" || scope == "wmma.matrix_b") {
...@@ -425,9 +425,9 @@ void CodeGenCUDA::VisitStmt_(const Allocate* op) { ...@@ -425,9 +425,9 @@ void CodeGenCUDA::VisitStmt_(const Allocate* op) {
this->PrintStmt(op->body); this->PrintStmt(op->body);
} }
void CodeGenCUDA::VisitStmt_(const Evaluate *op) { void CodeGenCUDA::VisitStmt_(const EvaluateNode *op) {
if (is_const(op->value)) return; if (is_const(op->value)) return;
const Call* call = op->value.as<Call>(); const CallNode* call = op->value.as<CallNode>();
if (call && call->is_intrinsic(intrinsic::tvm_global_barrier_kinit)) { if (call && call->is_intrinsic(intrinsic::tvm_global_barrier_kinit)) {
PrintIndent(); PrintIndent();
stream << "__shared__ unsigned " << vid_global_barrier_expect_ << ";\n"; stream << "__shared__ unsigned " << vid_global_barrier_expect_ << ";\n";
...@@ -442,7 +442,7 @@ void CodeGenCUDA::VisitStmt_(const Evaluate *op) { ...@@ -442,7 +442,7 @@ void CodeGenCUDA::VisitStmt_(const Evaluate *op) {
} }
} }
void CodeGenCUDA::VisitExpr_(const Ramp* op, std::ostream& os) { void CodeGenCUDA::VisitExpr_(const RampNode* op, std::ostream& os) {
os << "((make_int" << op->lanes << ")("; os << "((make_int" << op->lanes << ")(";
for (int i = 0; i < op->lanes; i++) { for (int i = 0; i < op->lanes; i++) {
os << "(" << PrintExpr(op->base) << ")" << "+(" << PrintExpr(op->stride) << "*" << i <<")"; os << "(" << PrintExpr(op->base) << ")" << "+(" << PrintExpr(op->stride) << "*" << i <<")";
...@@ -452,7 +452,7 @@ void CodeGenCUDA::VisitExpr_(const Ramp* op, std::ostream& os) { ...@@ -452,7 +452,7 @@ void CodeGenCUDA::VisitExpr_(const Ramp* op, std::ostream& os) {
os << "))"; os << "))";
} }
void CodeGenCUDA::VisitExpr_(const Broadcast* op, std::ostream& os) { // NOLINT(*) void CodeGenCUDA::VisitExpr_(const BroadcastNode* op, std::ostream& os) { // NOLINT(*)
if (op->dtype.is_int() && op->dtype.bits() == 8 && op->lanes == 4) { if (op->dtype.is_int() && op->dtype.bits() == 8 && op->lanes == 4) {
// make_int8x4 // make_int8x4
const int64_t *p = as_const_int(op->value); const int64_t *p = as_const_int(op->value);
...@@ -474,7 +474,7 @@ void CodeGenCUDA::VisitExpr_(const Broadcast* op, std::ostream& os) { // NOLIN ...@@ -474,7 +474,7 @@ void CodeGenCUDA::VisitExpr_(const Broadcast* op, std::ostream& os) { // NOLIN
os << ')'; os << ')';
} }
void CodeGenCUDA::VisitExpr_(const Shuffle* op, std::ostream &os) { void CodeGenCUDA::VisitExpr_(const ShuffleNode* op, std::ostream &os) {
std::vector<std::string> to_shuffle(op->vectors.size()); std::vector<std::string> to_shuffle(op->vectors.size());
for (int i = 0, e = op->vectors.size(); i < e; ++i) { for (int i = 0, e = op->vectors.size(); i < e; ++i) {
CHECK(op->vectors[i].dtype().lanes() == 1) << "Only scalars can be shuffled in CUDA!"; CHECK(op->vectors[i].dtype().lanes() == 1) << "Only scalars can be shuffled in CUDA!";
...@@ -492,7 +492,7 @@ void CodeGenCUDA::VisitExpr_(const Shuffle* op, std::ostream &os) { ...@@ -492,7 +492,7 @@ void CodeGenCUDA::VisitExpr_(const Shuffle* op, std::ostream &os) {
os << ')'; os << ')';
} }
inline void PrintConst(const FloatImm* op, std::ostream& os, CodeGenCUDA* p) { // NOLINT(*) inline void PrintConst(const FloatImmNode* op, std::ostream& os, CodeGenCUDA* p) { // NOLINT(*)
switch (op->dtype.bits()) { switch (op->dtype.bits()) {
case 64: case 32: { case 64: case 32: {
std::ostringstream temp; std::ostringstream temp;
...@@ -523,12 +523,12 @@ inline void PrintConst(const FloatImm* op, std::ostream& os, CodeGenCUDA* p) { / ...@@ -523,12 +523,12 @@ inline void PrintConst(const FloatImm* op, std::ostream& os, CodeGenCUDA* p) { /
} }
void CodeGenCUDA::VisitExpr_(const FloatImm *op, std::ostream& os) { // NOLINT(*) void CodeGenCUDA::VisitExpr_(const FloatImmNode *op, std::ostream& os) { // NOLINT(*)
PrintConst(op, os, this); PrintConst(op, os, this);
} }
void CodeGenCUDA::PrintWmmaScope(const std::string &scope, DataType t, void CodeGenCUDA::PrintWmmaScope(const std::string &scope, DataType t,
const Variable* variable, std::ostream &os) { const VarNode* variable, std::ostream &os) {
std::stringstream type; std::stringstream type;
PrintType(t, type); PrintType(t, type);
std::string shape_str = fragment_shapes[variable]; std::string shape_str = fragment_shapes[variable];
...@@ -550,7 +550,7 @@ void CodeGenCUDA::PrintWmmaScope(const std::string &scope, DataType t, ...@@ -550,7 +550,7 @@ void CodeGenCUDA::PrintWmmaScope(const std::string &scope, DataType t,
} }
int32_t CodeGenCUDA::GetWmmaFragmentSize(const std::string &scope, int32_t CodeGenCUDA::GetWmmaFragmentSize(const std::string &scope,
const Variable* variable, int32_t size) { const VarNode* variable, int32_t size) {
std::string shape_str = fragment_shapes[variable]; std::string shape_str = fragment_shapes[variable];
size_t m, n, k; size_t m, n, k;
size_t last_pos = 0, pos = 0; size_t last_pos = 0, pos = 0;
......
...@@ -43,8 +43,8 @@ class CodeGenCUDA final : public CodeGenC { ...@@ -43,8 +43,8 @@ class CodeGenCUDA final : public CodeGenC {
return (enable_fp16_ || enable_int8_ || need_math_constants_h_ || need_mma_h_); return (enable_fp16_ || enable_int8_ || need_math_constants_h_ || need_mma_h_);
} }
// override behavior // override behavior
void VisitStmt_(const ir::For* op) final; void VisitStmt_(const ir::ForNode* op) final;
void PrintStorageSync(const Call* op) final; void PrintStorageSync(const CallNode* op) final;
void PrintStorageScope(const std::string& scope, std::ostream& os) final; // NOLINT(*) void PrintStorageScope(const std::string& scope, std::ostream& os) final; // NOLINT(*)
void PrintVecBinaryOp( void PrintVecBinaryOp(
const std::string&op, DataType t, const std::string&op, DataType t,
...@@ -56,14 +56,14 @@ class CodeGenCUDA final : public CodeGenC { ...@@ -56,14 +56,14 @@ class CodeGenCUDA final : public CodeGenC {
const std::string& vec, DataType t, int i, const std::string& value) final; const std::string& vec, DataType t, int i, const std::string& value) final;
void BindThreadIndex(const IterVar& iv) final; // NOLINT(*) void BindThreadIndex(const IterVar& iv) final; // NOLINT(*)
// overload visitor // overload visitor
void VisitExpr_(const Ramp* op, std::ostream& os) final; // NOLINT(*) void VisitExpr_(const RampNode* op, std::ostream& os) final; // NOLINT(*)
void VisitExpr_(const Shuffle* op, std::ostream& os) final; // NOLINT(*) void VisitExpr_(const ShuffleNode* op, std::ostream& os) final; // NOLINT(*)
void VisitExpr_(const Broadcast* op, std::ostream& os) final; // NOLINT(*) void VisitExpr_(const BroadcastNode* op, std::ostream& os) final; // NOLINT(*)
void VisitExpr_(const FloatImm *op, std::ostream& os) final; void VisitExpr_(const FloatImmNode *op, std::ostream& os) final;
void VisitExpr_(const Call *op, std::ostream& os) final; void VisitExpr_(const CallNode *op, std::ostream& os) final;
void VisitStmt_(const Evaluate *op) final; void VisitStmt_(const EvaluateNode *op) final;
void VisitStmt_(const Allocate *op) final; void VisitStmt_(const AllocateNode *op) final;
void VisitStmt_(const AttrStmt *op) final; void VisitStmt_(const AttrStmtNode *op) final;
private: private:
// Whether global barrier is needed. // Whether global barrier is needed.
...@@ -81,13 +81,13 @@ class CodeGenCUDA final : public CodeGenC { ...@@ -81,13 +81,13 @@ class CodeGenCUDA final : public CodeGenC {
// whether need mma.h // whether need mma.h
bool need_mma_h_{false}; bool need_mma_h_{false};
std::unordered_map<const Variable*, std::string> fragment_shapes; std::unordered_map<const VarNode*, std::string> fragment_shapes;
std::unordered_map<const Variable*, std::string> fragment_layouts; std::unordered_map<const VarNode*, std::string> fragment_layouts;
friend void PrintConst(const FloatImm* op, std::ostream& os, CodeGenCUDA* p); friend void PrintConst(const FloatImmNode* op, std::ostream& os, CodeGenCUDA* p);
void PrintWmmaScope( void PrintWmmaScope(
const std::string& scope, DataType t, const Variable* variable, std::ostream& os); const std::string& scope, DataType t, const VarNode* variable, std::ostream& os);
int32_t GetWmmaFragmentSize( int32_t GetWmmaFragmentSize(
const std::string &scope, const Variable* variable, int32_t size); const std::string &scope, const VarNode* variable, int32_t size);
}; };
} // namespace codegen } // namespace codegen
......
...@@ -196,8 +196,8 @@ void CodeGenMetal::PrintType(DataType t, std::ostream& os) { // NOLINT(*) ...@@ -196,8 +196,8 @@ void CodeGenMetal::PrintType(DataType t, std::ostream& os) { // NOLINT(*)
LOG(FATAL) << "Cannot convert type " << t << " to Metal type"; LOG(FATAL) << "Cannot convert type " << t << " to Metal type";
} }
void CodeGenMetal::PrintStorageSync(const Call* op) { void CodeGenMetal::PrintStorageSync(const CallNode* op) {
const std::string& sync = op->args[0].as<StringImm>()->value; const std::string& sync = op->args[0].as<StringImmNode>()->value;
if (sync == "warp") { if (sync == "warp") {
this->PrintIndent(); this->PrintIndent();
this->stream << "simdgroup_barrier(mem_flags::mem_threadgroup);\n"; this->stream << "simdgroup_barrier(mem_flags::mem_threadgroup);\n";
...@@ -234,7 +234,7 @@ void CodeGenMetal::PrintStorageScope( ...@@ -234,7 +234,7 @@ void CodeGenMetal::PrintStorageScope(
} }
} }
void CodeGenMetal::VisitExpr_(const Broadcast* op, std::ostream& os) { // NOLINT(*) void CodeGenMetal::VisitExpr_(const BroadcastNode* op, std::ostream& os) { // NOLINT(*)
std::string v = PrintExpr(op->value); std::string v = PrintExpr(op->value);
PrintType(op->dtype, os); PrintType(op->dtype, os);
os << "("; os << "(";
...@@ -245,8 +245,8 @@ void CodeGenMetal::VisitExpr_(const Broadcast* op, std::ostream& os) { // NOLI ...@@ -245,8 +245,8 @@ void CodeGenMetal::VisitExpr_(const Broadcast* op, std::ostream& os) { // NOLI
os << ')'; os << ')';
} }
void CodeGenMetal::VisitExpr_(const Call* op, std::ostream& os) { // NOLINT(*) void CodeGenMetal::VisitExpr_(const CallNode* op, std::ostream& os) { // NOLINT(*)
if (op->is_intrinsic(Call::reinterpret)) { if (op->is_intrinsic(CallNode::reinterpret)) {
// generate as_type<TYPE>(ARG) // generate as_type<TYPE>(ARG)
os << "(as_type<"; os << "(as_type<";
this->PrintType(op->dtype, os); this->PrintType(op->dtype, os);
......
...@@ -6,9 +6,9 @@ ...@@ -6,9 +6,9 @@
* to you under the Apache License, Version 2.0 (the * to you under the Apache License, Version 2.0 (the
* "License"); you may not use this file except in compliance * "License"); you may not use this file except in compliance
* with the License. You may obtain a copy of the License at * with the License. You may obtain a copy of the License at
* *
* http://www.apache.org/licenses/LICENSE-2.0 * http://www.apache.org/licenses/LICENSE-2.0
* *
* Unless required by applicable law or agreed to in writing, * Unless required by applicable law or agreed to in writing,
* software distributed under the License is distributed on an * software distributed under the License is distributed on an
* "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
...@@ -40,7 +40,7 @@ class CodeGenMetal final : public CodeGenC { ...@@ -40,7 +40,7 @@ class CodeGenMetal final : public CodeGenC {
void PrintArgUnionDecl(); void PrintArgUnionDecl();
void InitFuncState(LoweredFunc f) final; void InitFuncState(LoweredFunc f) final;
void PrintStorageScope(const std::string& scope, std::ostream& os) final; // NOLINT(*) void PrintStorageScope(const std::string& scope, std::ostream& os) final; // NOLINT(*)
void PrintStorageSync(const Call* op) final; // NOLINT(*) void PrintStorageSync(const CallNode* op) final; // NOLINT(*)
void PrintType(DataType t, std::ostream& os) final; // NOLINT(*) void PrintType(DataType t, std::ostream& os) final; // NOLINT(*)
void BindThreadIndex(const IterVar& iv) final; // NOLINT(*) void BindThreadIndex(const IterVar& iv) final; // NOLINT(*)
// print load of single element // print load of single element
...@@ -50,10 +50,10 @@ class CodeGenMetal final : public CodeGenC { ...@@ -50,10 +50,10 @@ class CodeGenMetal final : public CodeGenC {
void PrintVecElemStore( void PrintVecElemStore(
const std::string& vec, DataType t, int i, const std::string& value) final; const std::string& vec, DataType t, int i, const std::string& value) final;
// overload visitor // overload visitor
void VisitExpr_(const Broadcast* op, std::ostream& os) final; // NOLINT(*) void VisitExpr_(const BroadcastNode* op, std::ostream& os) final; // NOLINT(*)
// overload visitor // overload visitor
void VisitExpr_(const Call* op, std::ostream& os) final; // NOLINT(*) void VisitExpr_(const CallNode* op, std::ostream& os) final; // NOLINT(*)
private: private:
int thread_index_bits_{32}; int thread_index_bits_{32};
......
...@@ -144,7 +144,7 @@ void CodeGenOpenCL::PrintType(DataType t, std::ostream& os) { // NOLINT(*) ...@@ -144,7 +144,7 @@ void CodeGenOpenCL::PrintType(DataType t, std::ostream& os) { // NOLINT(*)
LOG(FATAL) << "Cannot convert type " << t << " to OpenCL type"; LOG(FATAL) << "Cannot convert type " << t << " to OpenCL type";
} }
void CodeGenOpenCL::PrintVecAddr(const Variable* buffer, DataType t, void CodeGenOpenCL::PrintVecAddr(const VarNode* buffer, DataType t,
Expr base, std::ostream& os) { // NOLINT(*) Expr base, std::ostream& os) { // NOLINT(*)
if (!HandleTypeMatch(buffer, t.element_of())) { if (!HandleTypeMatch(buffer, t.element_of())) {
os << '('; os << '(';
...@@ -160,7 +160,7 @@ void CodeGenOpenCL::PrintVecAddr(const Variable* buffer, DataType t, ...@@ -160,7 +160,7 @@ void CodeGenOpenCL::PrintVecAddr(const Variable* buffer, DataType t,
PrintExpr(base, os); PrintExpr(base, os);
} }
std::string CodeGenOpenCL::GetVecLoad( std::string CodeGenOpenCL::GetVecLoad(
DataType t, const Variable* buffer, Expr base) { DataType t, const VarNode* buffer, Expr base) {
std::ostringstream os; std::ostringstream os;
os << "vload" << t.lanes() << "(0, "; os << "vload" << t.lanes() << "(0, ";
PrintVecAddr(buffer, t, base, os); PrintVecAddr(buffer, t, base, os);
...@@ -168,7 +168,7 @@ std::string CodeGenOpenCL::GetVecLoad( ...@@ -168,7 +168,7 @@ std::string CodeGenOpenCL::GetVecLoad(
return os.str(); return os.str();
} }
void CodeGenOpenCL::PrintVecStore(const Variable* buffer, void CodeGenOpenCL::PrintVecStore(const VarNode* buffer,
DataType t, Expr base, DataType t, Expr base,
const std::string& value) { const std::string& value) {
this->PrintIndent(); this->PrintIndent();
...@@ -177,8 +177,8 @@ void CodeGenOpenCL::PrintVecStore(const Variable* buffer, ...@@ -177,8 +177,8 @@ void CodeGenOpenCL::PrintVecStore(const Variable* buffer,
stream << ");\n"; stream << ");\n";
} }
void CodeGenOpenCL::PrintStorageSync(const Call* op) { void CodeGenOpenCL::PrintStorageSync(const CallNode* op) {
const std::string& sync = op->args[0].as<StringImm>()->value; const std::string& sync = op->args[0].as<StringImmNode>()->value;
if (sync == "warp") { if (sync == "warp") {
this->PrintIndent(); this->PrintIndent();
this->stream << "barrier(CLK_LOCAL_MEM_FENCE);\n"; this->stream << "barrier(CLK_LOCAL_MEM_FENCE);\n";
...@@ -215,7 +215,7 @@ std::string CodeGenOpenCL::CastFromTo(std::string value, DataType from, DataType ...@@ -215,7 +215,7 @@ std::string CodeGenOpenCL::CastFromTo(std::string value, DataType from, DataType
return os.str(); return os.str();
} }
void CodeGenOpenCL::VisitExpr_(const Broadcast* op, std::ostream& os) { // NOLINT(*) void CodeGenOpenCL::VisitExpr_(const BroadcastNode* op, std::ostream& os) { // NOLINT(*)
std::string v = PrintExpr(op->value); std::string v = PrintExpr(op->value);
os << "(("; os << "((";
PrintType(op->dtype, os); PrintType(op->dtype, os);
...@@ -227,7 +227,7 @@ void CodeGenOpenCL::VisitExpr_(const Broadcast* op, std::ostream& os) { // NOL ...@@ -227,7 +227,7 @@ void CodeGenOpenCL::VisitExpr_(const Broadcast* op, std::ostream& os) { // NOL
os << "))"; os << "))";
} }
void CodeGenOpenCL::VisitExpr_(const Call *op, std::ostream& os) { // NOLINT(*) void CodeGenOpenCL::VisitExpr_(const CallNode *op, std::ostream& os) { // NOLINT(*)
/* Return type of ternary expression is not always same as its sub-expressions, /* Return type of ternary expression is not always same as its sub-expressions,
* add a cast */ * add a cast */
if (op->is_intrinsic(intrinsic::tvm_if_then_else)) { if (op->is_intrinsic(intrinsic::tvm_if_then_else)) {
...@@ -238,7 +238,7 @@ void CodeGenOpenCL::VisitExpr_(const Call *op, std::ostream& os) { // NOLINT(*) ...@@ -238,7 +238,7 @@ void CodeGenOpenCL::VisitExpr_(const Call *op, std::ostream& os) { // NOLINT(*)
CodeGenC::VisitExpr_(op, os); CodeGenC::VisitExpr_(op, os);
} }
void CodeGenOpenCL::VisitExpr_(const Select* op, std::ostream& os) { // NOLINT(*) void CodeGenOpenCL::VisitExpr_(const SelectNode* op, std::ostream& os) { // NOLINT(*)
/* Return type of ternary expression is not always same as its sub-expressions, /* Return type of ternary expression is not always same as its sub-expressions,
* add a cast */ * add a cast */
os << "("; os << "(";
...@@ -247,7 +247,7 @@ void CodeGenOpenCL::VisitExpr_(const Select* op, std::ostream& os) { // NOLINT( ...@@ -247,7 +247,7 @@ void CodeGenOpenCL::VisitExpr_(const Select* op, std::ostream& os) { // NOLINT(
CodeGenC::VisitExpr_(op, os); CodeGenC::VisitExpr_(op, os);
} }
void CodeGenOpenCL::VisitExpr_(const FloatImm *op, std::ostream& os) { // NOLINT(*) void CodeGenOpenCL::VisitExpr_(const FloatImmNode *op, std::ostream& os) { // NOLINT(*)
if (std::isinf(op->value)) { if (std::isinf(op->value)) {
if (op->value < 0) { if (op->value < 0) {
os << "-"; os << "-";
......
...@@ -6,9 +6,9 @@ ...@@ -6,9 +6,9 @@
* to you under the Apache License, Version 2.0 (the * to you under the Apache License, Version 2.0 (the
* "License"); you may not use this file except in compliance * "License"); you may not use this file except in compliance
* with the License. You may obtain a copy of the License at * with the License. You may obtain a copy of the License at
* *
* http://www.apache.org/licenses/LICENSE-2.0 * http://www.apache.org/licenses/LICENSE-2.0
* *
* Unless required by applicable law or agreed to in writing, * Unless required by applicable law or agreed to in writing,
* software distributed under the License is distributed on an * software distributed under the License is distributed on an
* "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
...@@ -42,23 +42,23 @@ class CodeGenOpenCL final : public CodeGenC { ...@@ -42,23 +42,23 @@ class CodeGenOpenCL final : public CodeGenC {
void InitFuncState(LoweredFunc f) final; void InitFuncState(LoweredFunc f) final;
void BindThreadIndex(const IterVar& iv) final; // NOLINT(*) void BindThreadIndex(const IterVar& iv) final; // NOLINT(*)
void PrintStorageScope(const std::string& scope, std::ostream& os) final; // NOLINT(*) void PrintStorageScope(const std::string& scope, std::ostream& os) final; // NOLINT(*)
void PrintStorageSync(const Call* op) final; // NOLINT(*) void PrintStorageSync(const CallNode* op) final; // NOLINT(*)
void PrintType(DataType t, std::ostream& os) final; // NOLINT(*) void PrintType(DataType t, std::ostream& os) final; // NOLINT(*)
std::string GetVecLoad(DataType t, const Variable* buffer, std::string GetVecLoad(DataType t, const VarNode* buffer,
Expr base) final; Expr base) final;
void PrintVecStore(const Variable* buffer, void PrintVecStore(const VarNode* buffer,
DataType t, Expr base, DataType t, Expr base,
const std::string& value) final; // NOLINT(*) const std::string& value) final; // NOLINT(*)
// the address of load/store // the address of load/store
void PrintVecAddr(const Variable* buffer, DataType t, void PrintVecAddr(const VarNode* buffer, DataType t,
Expr base, std::ostream& os); // NOLINT(*) Expr base, std::ostream& os); // NOLINT(*)
std::string CastFromTo(std::string value, DataType from, DataType target); // NOLINT(*) std::string CastFromTo(std::string value, DataType from, DataType target); // NOLINT(*)
// overload visitor // overload visitor
void VisitExpr_(const Broadcast* op, std::ostream& os) final; // NOLINT(*) void VisitExpr_(const BroadcastNode* op, std::ostream& os) final; // NOLINT(*)
void VisitExpr_(const Call* op, std::ostream& os) final; // NOLINT(*) void VisitExpr_(const CallNode* op, std::ostream& os) final; // NOLINT(*)
void VisitExpr_(const Select* op, std::ostream& os) final; // NOLINT(*) void VisitExpr_(const SelectNode* op, std::ostream& os) final; // NOLINT(*)
void VisitExpr_(const FloatImm *op, std::ostream& os) final; // NOLINT(*) void VisitExpr_(const FloatImmNode *op, std::ostream& os) final; // NOLINT(*)
private: private:
// whether enable fp16 and fp64 extension // whether enable fp16 and fp64 extension
......
...@@ -188,13 +188,13 @@ void CodeGenOpenGL::BindThreadIndex(const IterVar& iv) { ...@@ -188,13 +188,13 @@ void CodeGenOpenGL::BindThreadIndex(const IterVar& iv) {
this->stream << "}\n"; this->stream << "}\n";
} }
void CodeGenOpenGL::VisitStmt_(const Store* op) { void CodeGenOpenGL::VisitStmt_(const StoreNode* op) {
LOG(FATAL) << "Store statement not supported in OpenGL." LOG(FATAL) << "Store statement not supported in OpenGL."
<< " Texture store should be a Call statement."; << " Texture store should be a Call statement.";
} }
// texelFetch(tex, ivec2(idx & kTextureRowMask, idx >> kTextureRowBits), 0).r // texelFetch(tex, ivec2(idx & kTextureRowMask, idx >> kTextureRowBits), 0).r
std::string CodeGenOpenGL::TexelFetch(const Variable* buffer, Expr index) { std::string CodeGenOpenGL::TexelFetch(const VarNode* buffer, Expr index) {
std::ostringstream os; std::ostringstream os;
os << "texelFetch(" << GetVarID(buffer) << ", ivec2(int("; os << "texelFetch(" << GetVarID(buffer) << ", ivec2(int(";
PrintExpr(index, os); PrintExpr(index, os);
...@@ -207,7 +207,7 @@ std::string CodeGenOpenGL::TexelFetch(const Variable* buffer, Expr index) { ...@@ -207,7 +207,7 @@ std::string CodeGenOpenGL::TexelFetch(const Variable* buffer, Expr index) {
// Print a reference expression to a buffer. // Print a reference expression to a buffer.
// Format: texelFetch(buffer, index, 0).r // Format: texelFetch(buffer, index, 0).r
std::string CodeGenOpenGL::GetBufferRef( std::string CodeGenOpenGL::GetBufferRef(
DataType t, const Variable* buffer, Expr index) { DataType t, const VarNode* buffer, Expr index) {
CHECK_EQ(t.lanes(), 1) << "Vector type not supported."; CHECK_EQ(t.lanes(), 1) << "Vector type not supported.";
CHECK(HandleTypeMatch(buffer, t)) << "Type mismatch not supported."; CHECK(HandleTypeMatch(buffer, t)) << "Type mismatch not supported.";
...@@ -242,34 +242,34 @@ void CodeGenOpenGL::PrintType(DataType t, std::ostream& os) { ...@@ -242,34 +242,34 @@ void CodeGenOpenGL::PrintType(DataType t, std::ostream& os) {
// Codegen for immediate values // Codegen for immediate values
void CodeGenOpenGL::VisitExpr_(const IntImm* op, std::ostream& os) { void CodeGenOpenGL::VisitExpr_(const IntImmNode* op, std::ostream& os) {
CHECK_EQ(op->dtype, DataType::Int(32)) << "GLSL 3.0 only supports 32-bit ints."; CHECK_EQ(op->dtype, DataType::Int(32)) << "GLSL 3.0 only supports 32-bit ints.";
CodeGenC::VisitExpr_(op, os); CodeGenC::VisitExpr_(op, os);
} }
void CodeGenOpenGL::VisitExpr_(const UIntImm* op, std::ostream& os) { void CodeGenOpenGL::VisitExpr_(const UIntImmNode* op, std::ostream& os) {
CHECK_EQ(op->dtype, DataType::UInt(32)) << "GLSL 3.0 only supports 32-bit uints."; CHECK_EQ(op->dtype, DataType::UInt(32)) << "GLSL 3.0 only supports 32-bit uints.";
CodeGenC::VisitExpr_(op, os); CodeGenC::VisitExpr_(op, os);
} }
void CodeGenOpenGL::VisitExpr_(const FloatImm* op, std::ostream& os) { void CodeGenOpenGL::VisitExpr_(const FloatImmNode* op, std::ostream& os) {
CHECK_EQ(op->dtype, DataType::Float(32)) << "GLSL 3.0 only supports 32-bit floats."; CHECK_EQ(op->dtype, DataType::Float(32)) << "GLSL 3.0 only supports 32-bit floats.";
CodeGenC::VisitExpr_(op, os); CodeGenC::VisitExpr_(op, os);
} }
void CodeGenOpenGL::VisitExpr_(const StringImm*, std::ostream& os) { void CodeGenOpenGL::VisitExpr_(const StringImmNode*, std::ostream& os) {
LOG(FATAL) << "GLSL 3.0 doesn't support strings."; LOG(FATAL) << "GLSL 3.0 doesn't support strings.";
} }
void CodeGenOpenGL::VisitStmt_(const Evaluate* op) { void CodeGenOpenGL::VisitStmt_(const EvaluateNode* op) {
auto call = op->value.as<Call>(); auto call = op->value.as<CallNode>();
if (call == nullptr || call->name != Call::glsl_texture_store) { if (call == nullptr || call->name != CallNode::glsl_texture_store) {
// Fallback to normal logic. // Fallback to normal logic.
CodeGenC::VisitStmt_(op); CodeGenC::VisitStmt_(op);
} }
CHECK_EQ(call->args.size(), 2); CHECK_EQ(call->args.size(), 2);
auto buffer = call->args[0].as<Variable>(); auto buffer = call->args[0].as<VarNode>();
auto value = call->args[1]; auto value = call->args[1];
// Doesn't support store to vector. // Doesn't support store to vector.
......
...@@ -6,9 +6,9 @@ ...@@ -6,9 +6,9 @@
* to you under the Apache License, Version 2.0 (the * to you under the Apache License, Version 2.0 (the
* "License"); you may not use this file except in compliance * "License"); you may not use this file except in compliance
* with the License. You may obtain a copy of the License at * with the License. You may obtain a copy of the License at
* *
* http://www.apache.org/licenses/LICENSE-2.0 * http://www.apache.org/licenses/LICENSE-2.0
* *
* Unless required by applicable law or agreed to in writing, * Unless required by applicable law or agreed to in writing,
* software distributed under the License is distributed on an * software distributed under the License is distributed on an
* "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
...@@ -43,24 +43,24 @@ class CodeGenOpenGL final : public CodeGenC { ...@@ -43,24 +43,24 @@ class CodeGenOpenGL final : public CodeGenC {
void InitFuncState(LoweredFunc f) final; void InitFuncState(LoweredFunc f) final;
void BindThreadIndex(const IterVar& iv) final; void BindThreadIndex(const IterVar& iv) final;
void VisitStmt_(const Store* op) final; void VisitStmt_(const StoreNode* op) final;
std::string TexelFetch(const Variable* buffer, Expr index); std::string TexelFetch(const VarNode* buffer, Expr index);
std::string GetBufferRef(DataType t, const Variable* buffer, Expr index) final; std::string GetBufferRef(DataType t, const VarNode* buffer, Expr index) final;
void PrintType(DataType t, std::ostream& os) final; // NOLINT(*) void PrintType(DataType t, std::ostream& os) final; // NOLINT(*)
// Codegen for immediate values // Codegen for immediate values
void VisitExpr_(const IntImm* op, std::ostream& os) final; // NOLINT(*) void VisitExpr_(const IntImmNode* op, std::ostream& os) final; // NOLINT(*)
void VisitExpr_(const UIntImm* op, std::ostream& os) final; // NOLINT(*) void VisitExpr_(const UIntImmNode* op, std::ostream& os) final; // NOLINT(*)
void VisitExpr_(const FloatImm* op, std::ostream& os) final; // NOLINT(*) void VisitExpr_(const FloatImmNode* op, std::ostream& os) final; // NOLINT(*)
void VisitExpr_(const StringImm* op, std::ostream& os) final; // NOLINT(*) void VisitExpr_(const StringImmNode* op, std::ostream& os) final; // NOLINT(*)
// Match glsl_texture_store Call. // Match glsl_texture_store Call.
void VisitStmt_(const Evaluate* op) final; // NOLINT(*) void VisitStmt_(const EvaluateNode* op) final; // NOLINT(*)
private: private:
const Variable* output_{nullptr}; const VarNode* output_{nullptr};
std::unordered_set<const Variable*> inputs_; std::unordered_set<const VarNode*> inputs_;
const Variable* output_iter_var_{nullptr}; const VarNode* output_iter_var_{nullptr};
std::unordered_map<std::string, runtime::OpenGLShader> shaders_; std::unordered_map<std::string, runtime::OpenGLShader> shaders_;
std::string thread_extent_var_; std::string thread_extent_var_;
}; };
......
...@@ -69,7 +69,7 @@ std::string CodeGenSourceBase::SSAGetID(std::string src, DataType t) { ...@@ -69,7 +69,7 @@ std::string CodeGenSourceBase::SSAGetID(std::string src, DataType t) {
return e.vid; return e.vid;
} }
std::string CodeGenSourceBase::AllocVarID(const Variable* v) { std::string CodeGenSourceBase::AllocVarID(const VarNode* v) {
CHECK(!var_idmap_.count(v)) CHECK(!var_idmap_.count(v))
<< "Need input to be in SSA form dup " << v->name_hint; << "Need input to be in SSA form dup " << v->name_hint;
std::string key = v->name_hint; std::string key = v->name_hint;
...@@ -78,7 +78,7 @@ std::string CodeGenSourceBase::AllocVarID(const Variable* v) { ...@@ -78,7 +78,7 @@ std::string CodeGenSourceBase::AllocVarID(const Variable* v) {
return vid; return vid;
} }
std::string CodeGenSourceBase::GetVarID(const Variable* v) const { std::string CodeGenSourceBase::GetVarID(const VarNode* v) const {
auto it = var_idmap_.find(v); auto it = var_idmap_.find(v);
CHECK(it != var_idmap_.end()) CHECK(it != var_idmap_.end())
<< "Find undefined Variable " << v->name_hint; << "Find undefined Variable " << v->name_hint;
......
...@@ -6,9 +6,9 @@ ...@@ -6,9 +6,9 @@
* to you under the Apache License, Version 2.0 (the * to you under the Apache License, Version 2.0 (the
* "License"); you may not use this file except in compliance * "License"); you may not use this file except in compliance
* with the License. You may obtain a copy of the License at * with the License. You may obtain a copy of the License at
* *
* http://www.apache.org/licenses/LICENSE-2.0 * http://www.apache.org/licenses/LICENSE-2.0
* *
* Unless required by applicable law or agreed to in writing, * Unless required by applicable law or agreed to in writing,
* software distributed under the License is distributed on an * software distributed under the License is distributed on an
* "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
...@@ -66,13 +66,13 @@ class CodeGenSourceBase { ...@@ -66,13 +66,13 @@ class CodeGenSourceBase {
* \param v The variable. * \param v The variable.
* \return the variable name. * \return the variable name.
*/ */
std::string AllocVarID(const Variable* v); std::string AllocVarID(const VarNode* v);
/*! /*!
* \brief Get a variable name. * \brief Get a variable name.
* \param v The variable. * \param v The variable.
* \return the variable name. * \return the variable name.
*/ */
std::string GetVarID(const Variable* v) const; std::string GetVarID(const VarNode* v) const;
/*! /*!
* \brief Get the SSA ID corresponds to src * \brief Get the SSA ID corresponds to src
* If necessary, generate new assignment * If necessary, generate new assignment
...@@ -110,7 +110,7 @@ class CodeGenSourceBase { ...@@ -110,7 +110,7 @@ class CodeGenSourceBase {
/*! \brief the stream to be printed */ /*! \brief the stream to be printed */
std::ostringstream stream; std::ostringstream stream;
/*! \brief name of each variable */ /*! \brief name of each variable */
std::unordered_map<const Variable*, std::string> var_idmap_; std::unordered_map<const VarNode*, std::string> var_idmap_;
private: private:
/*! \brief assignment map of ssa */ /*! \brief assignment map of ssa */
......
...@@ -98,7 +98,7 @@ inline void PrintBinaryExpr(const T* op, ...@@ -98,7 +98,7 @@ inline void PrintBinaryExpr(const T* op,
os << ')'; os << ')';
} }
void CodeGenVivadoHLS::VisitExpr_(const Min *op, std::ostream& os) { // NOLINT(*) void CodeGenVivadoHLS::VisitExpr_(const MinNode *op, std::ostream& os) { // NOLINT(*)
const char *opstr = "std::min"; const char *opstr = "std::min";
if (op->dtype.is_float()) { if (op->dtype.is_float()) {
switch (op->dtype.bits()) { switch (op->dtype.bits()) {
...@@ -112,7 +112,7 @@ void CodeGenVivadoHLS::VisitExpr_(const Min *op, std::ostream& os) { // NOLINT( ...@@ -112,7 +112,7 @@ void CodeGenVivadoHLS::VisitExpr_(const Min *op, std::ostream& os) { // NOLINT(
PrintBinaryExpr(op, opstr, os, this); PrintBinaryExpr(op, opstr, os, this);
} }
void CodeGenVivadoHLS::VisitExpr_(const Max *op, std::ostream& os) { // NOLINT(*) void CodeGenVivadoHLS::VisitExpr_(const MaxNode *op, std::ostream& os) { // NOLINT(*)
const char *opstr = "std::max"; const char *opstr = "std::max";
if (op->dtype.is_float()) { if (op->dtype.is_float()) {
switch (op->dtype.bits()) { switch (op->dtype.bits()) {
......
...@@ -6,9 +6,9 @@ ...@@ -6,9 +6,9 @@
* to you under the Apache License, Version 2.0 (the * to you under the Apache License, Version 2.0 (the
* "License"); you may not use this file except in compliance * "License"); you may not use this file except in compliance
* with the License. You may obtain a copy of the License at * with the License. You may obtain a copy of the License at
* *
* http://www.apache.org/licenses/LICENSE-2.0 * http://www.apache.org/licenses/LICENSE-2.0
* *
* Unless required by applicable law or agreed to in writing, * Unless required by applicable law or agreed to in writing,
* software distributed under the License is distributed on an * software distributed under the License is distributed on an
* "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
...@@ -38,8 +38,8 @@ class CodeGenVivadoHLS final : public CodeGenC { ...@@ -38,8 +38,8 @@ class CodeGenVivadoHLS final : public CodeGenC {
void PrintType(DataType t, std::ostream& os); void PrintType(DataType t, std::ostream& os);
void AddFunction(LoweredFunc f); void AddFunction(LoweredFunc f);
void PreFunctionBody(LoweredFunc f); void PreFunctionBody(LoweredFunc f);
void VisitExpr_(const Min *op, std::ostream& os); void VisitExpr_(const MinNode *op, std::ostream& os);
void VisitExpr_(const Max *op, std::ostream& os); void VisitExpr_(const MaxNode *op, std::ostream& os);
}; };
} // namespace codegen } // namespace codegen
......
...@@ -6,9 +6,9 @@ ...@@ -6,9 +6,9 @@
* to you under the Apache License, Version 2.0 (the * to you under the Apache License, Version 2.0 (the
* "License"); you may not use this file except in compliance * "License"); you may not use this file except in compliance
* with the License. You may obtain a copy of the License at * with the License. You may obtain a copy of the License at
* *
* http://www.apache.org/licenses/LICENSE-2.0 * http://www.apache.org/licenses/LICENSE-2.0
* *
* Unless required by applicable law or agreed to in writing, * Unless required by applicable law or agreed to in writing,
* software distributed under the License is distributed on an * software distributed under the License is distributed on an
* "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
...@@ -54,7 +54,7 @@ TVM_REGISTER_GLOBAL("tvm.intrin.rule.default.sqrt") ...@@ -54,7 +54,7 @@ TVM_REGISTER_GLOBAL("tvm.intrin.rule.default.sqrt")
TVM_REGISTER_GLOBAL("tvm.intrin.rule.default.rsqrt") TVM_REGISTER_GLOBAL("tvm.intrin.rule.default.rsqrt")
.set_body([](const TVMArgs& args, TVMRetValue* rv){ .set_body([](const TVMArgs& args, TVMRetValue* rv){
Expr e = args[0]; Expr e = args[0];
const Call* call = e.as<Call>(); const CallNode* call = e.as<CallNode>();
CHECK(call != nullptr); CHECK(call != nullptr);
auto one = make_const(call->args[0].dtype(), 1); auto one = make_const(call->args[0].dtype(), 1);
...@@ -67,7 +67,7 @@ TVM_REGISTER_GLOBAL("tvm.intrin.rule.default.pow") ...@@ -67,7 +67,7 @@ TVM_REGISTER_GLOBAL("tvm.intrin.rule.default.pow")
TVM_REGISTER_GLOBAL("tvm.intrin.rule.default.sigmoid") TVM_REGISTER_GLOBAL("tvm.intrin.rule.default.sigmoid")
.set_body([](const TVMArgs& args, TVMRetValue* rv){ .set_body([](const TVMArgs& args, TVMRetValue* rv){
Expr e = args[0]; Expr e = args[0];
const Call* call = e.as<Call>(); const CallNode* call = e.as<CallNode>();
CHECK(call != nullptr); CHECK(call != nullptr);
auto one = make_const(call->args[0].dtype(), 1); auto one = make_const(call->args[0].dtype(), 1);
......
...@@ -61,12 +61,12 @@ struct Direct { ...@@ -61,12 +61,12 @@ struct Direct {
template<typename T> template<typename T>
inline void DispatchExtern(const TVMArgs& args, TVMRetValue* rv) { inline void DispatchExtern(const TVMArgs& args, TVMRetValue* rv) {
Expr e = args[0]; Expr e = args[0];
const Call* call = e.as<Call>(); const CallNode* call = e.as<CallNode>();
CHECK(call != nullptr); CHECK(call != nullptr);
std::string name = T()(call->dtype, call->name); std::string name = T()(call->dtype, call->name);
if (name.length() != 0) { if (name.length() != 0) {
*rv = Call::make( *rv = CallNode::make(
call->dtype, name, call->args, Call::PureExtern); call->dtype, name, call->args, CallNode::PureExtern);
} else { } else {
*rv = e; *rv = e;
} }
......
...@@ -70,7 +70,7 @@ class CodeGenAMDGPU : public CodeGenLLVM { ...@@ -70,7 +70,7 @@ class CodeGenAMDGPU : public CodeGenLLVM {
function_->addFnAttr("amdgpu-flat-work-group-size", attr.str()); function_->addFnAttr("amdgpu-flat-work-group-size", attr.str());
} }
void VisitStmt_(const Allocate* op) final { void VisitStmt_(const AllocateNode* op) final {
CHECK(!is_zero(op->condition)); CHECK(!is_zero(op->condition));
llvm::Value* buf = nullptr; llvm::Value* buf = nullptr;
if (op->new_expr.defined()) { if (op->new_expr.defined()) {
...@@ -153,8 +153,8 @@ class CodeGenAMDGPU : public CodeGenLLVM { ...@@ -153,8 +153,8 @@ class CodeGenAMDGPU : public CodeGenLLVM {
return builder_->CreateCall(f, {}); return builder_->CreateCall(f, {});
} }
llvm::Value* CreateStorageSync(const Call* op) final { llvm::Value* CreateStorageSync(const CallNode* op) final {
const std::string& sync = op->args[0].as<StringImm>()->value; const std::string& sync = op->args[0].as<StringImmNode>()->value;
if (sync == "warp") { if (sync == "warp") {
return nullptr; return nullptr;
} else if (sync == "shared") { } else if (sync == "shared") {
...@@ -234,7 +234,7 @@ runtime::Module BuildAMDGPU(Array<LoweredFunc> funcs, std::string target) { ...@@ -234,7 +234,7 @@ runtime::Module BuildAMDGPU(Array<LoweredFunc> funcs, std::string target) {
Array<Expr> bitcode_files = (*find_rocm_bitcodes)(); Array<Expr> bitcode_files = (*find_rocm_bitcodes)();
for (auto &bitcode : bitcode_files) { for (auto &bitcode : bitcode_files) {
std::string path = bitcode.as<StringImm>()->value; std::string path = bitcode.as<StringImmNode>()->value;
llvm::SMDiagnostic err; llvm::SMDiagnostic err;
std::unique_ptr<llvm::Module> mlib = llvm::parseIRFile(path, err, *ctx); std::unique_ptr<llvm::Module> mlib = llvm::parseIRFile(path, err, *ctx);
if (mlib.get() == nullptr) { if (mlib.get() == nullptr) {
......
...@@ -39,25 +39,25 @@ class CodeGenARM final : public CodeGenCPU { ...@@ -39,25 +39,25 @@ class CodeGenARM final : public CodeGenCPU {
native_vector_bits_ = 16 * 8; native_vector_bits_ = 16 * 8;
CodeGenCPU::InitTarget(tm); CodeGenCPU::InitTarget(tm);
} }
llvm::Value* CreateIntrinsic(const Call* op) override; llvm::Value* CreateIntrinsic(const CallNode* op) override;
private: private:
Expr ARMPopcount(const Call* op); Expr ARMPopcount(const CallNode* op);
}; };
llvm::Value* CodeGenARM::CreateIntrinsic(const Call* op) { llvm::Value* CodeGenARM::CreateIntrinsic(const CallNode* op) {
if (op->is_intrinsic("llvm_intrin")) { if (op->is_intrinsic("llvm_intrin")) {
llvm::Intrinsic::ID id = static_cast<llvm::Intrinsic::ID>( llvm::Intrinsic::ID id = static_cast<llvm::Intrinsic::ID>(
op->args[0].as<UIntImm>()->value); op->args[0].as<UIntImmNode>()->value);
if (id == ::llvm::Intrinsic::ctpop) { if (id == ::llvm::Intrinsic::ctpop) {
Expr e = ARMPopcount(op); Expr e = ARMPopcount(op);
return CodeGenCPU::CreateIntrinsic(e.as<Call>()); return CodeGenCPU::CreateIntrinsic(e.as<CallNode>());
} }
} }
return CodeGenCPU::CreateIntrinsic(op); return CodeGenCPU::CreateIntrinsic(op);
} }
Expr CodeGenARM::ARMPopcount(const Call *call) { Expr CodeGenARM::ARMPopcount(const CallNode *call) {
using namespace ir; using namespace ir;
const Expr& e = call->args[2]; const Expr& e = call->args[2];
::llvm::Intrinsic::ID ctpop_id = ::llvm::Intrinsic::ctpop; ::llvm::Intrinsic::ID ctpop_id = ::llvm::Intrinsic::ctpop;
...@@ -68,10 +68,10 @@ Expr CodeGenARM::ARMPopcount(const Call *call) { ...@@ -68,10 +68,10 @@ Expr CodeGenARM::ARMPopcount(const Call *call) {
if (!call->dtype.is_vector() || call->dtype.bits() == 8 || if (!call->dtype.is_vector() || call->dtype.bits() == 8 ||
(total_size != 128 && total_size != 64)) { (total_size != 128 && total_size != 64)) {
Array<Expr> vcnt_args; Array<Expr> vcnt_args;
vcnt_args.push_back(ir::UIntImm::make(DataType::UInt(32), ctpop_id)); vcnt_args.push_back(ir::UIntImmNode::make(DataType::UInt(32), ctpop_id));
vcnt_args.push_back(ir::UIntImm::make(DataType::UInt(32), 1)); vcnt_args.push_back(ir::UIntImmNode::make(DataType::UInt(32), 1));
vcnt_args.push_back(e); vcnt_args.push_back(e);
return ir::Call::make(call->dtype, "llvm_intrin", vcnt_args, Call::PureIntrinsic); return ir::CallNode::make(call->dtype, "llvm_intrin", vcnt_args, CallNode::PureIntrinsic);
} }
// Popcount lowering rule: // Popcount lowering rule:
...@@ -90,40 +90,44 @@ Expr CodeGenARM::ARMPopcount(const Call *call) { ...@@ -90,40 +90,44 @@ Expr CodeGenARM::ARMPopcount(const Call *call) {
// Interpret input as vector of 8bit values // Interpret input as vector of 8bit values
Expr input8 = reinterpret(uint8_type, e); Expr input8 = reinterpret(uint8_type, e);
// Popcount 8bit->8bit // Popcount 8bit->8bit
const Call* c0 = input8.as<Call>(); const CallNode* c0 = input8.as<CallNode>();
CHECK(c0 != nullptr); CHECK(c0 != nullptr);
Array<Expr> vcnt8_args; Array<Expr> vcnt8_args;
vcnt8_args.push_back(ir::UIntImm::make(DataType::UInt(32), ctpop_id)); vcnt8_args.push_back(ir::UIntImmNode::make(DataType::UInt(32), ctpop_id));
vcnt8_args.push_back(ir::UIntImm::make(DataType::UInt(32), 1)); vcnt8_args.push_back(ir::UIntImmNode::make(DataType::UInt(32), 1));
vcnt8_args.push_back(input8); vcnt8_args.push_back(input8);
Expr vcnt8 = ir::Call::make(uint8_type, "llvm_intrin", vcnt8_args, Call::PureIntrinsic); Expr vcnt8 = ir::CallNode::make(
uint8_type, "llvm_intrin", vcnt8_args, CallNode::PureIntrinsic);
// Accumulation 8->16bit // Accumulation 8->16bit
Array<Expr> vcnt16_args; Array<Expr> vcnt16_args;
vcnt16_args.push_back(ir::UIntImm::make(DataType::UInt(32), vpaddlu_id)); vcnt16_args.push_back(ir::UIntImmNode::make(DataType::UInt(32), vpaddlu_id));
vcnt16_args.push_back(ir::UIntImm::make(DataType::UInt(32), 1)); vcnt16_args.push_back(ir::UIntImmNode::make(DataType::UInt(32), 1));
vcnt16_args.push_back(vcnt8); vcnt16_args.push_back(vcnt8);
Expr vcnt16 = ir::Call::make(uint16_type, "llvm_intrin", vcnt16_args, Call::PureIntrinsic); Expr vcnt16 = ir::CallNode::make(
uint16_type, "llvm_intrin", vcnt16_args, CallNode::PureIntrinsic);
if (call->dtype.bits() == 16) { if (call->dtype.bits() == 16) {
return vcnt16; return vcnt16;
} }
// Accumulation 16->32bit // Accumulation 16->32bit
Array<Expr> vcnt32_args; Array<Expr> vcnt32_args;
vcnt32_args.push_back(ir::UIntImm::make(DataType::UInt(32), vpaddlu_id)); vcnt32_args.push_back(ir::UIntImmNode::make(DataType::UInt(32), vpaddlu_id));
vcnt32_args.push_back(ir::UIntImm::make(DataType::UInt(32), 1)); vcnt32_args.push_back(ir::UIntImmNode::make(DataType::UInt(32), 1));
vcnt32_args.push_back(vcnt16); vcnt32_args.push_back(vcnt16);
Expr vcnt32 = ir::Call::make(uint32_type, "llvm_intrin", vcnt32_args, Call::PureIntrinsic); Expr vcnt32 = ir::CallNode::make(
uint32_type, "llvm_intrin", vcnt32_args, CallNode::PureIntrinsic);
if (call->dtype.bits() == 32) { if (call->dtype.bits() == 32) {
return vcnt32; return vcnt32;
} }
// Accumulation 32->64bit // Accumulation 32->64bit
Array<Expr> vcnt64_args; Array<Expr> vcnt64_args;
vcnt64_args.push_back(ir::UIntImm::make(DataType::UInt(32), vpaddlu_id)); vcnt64_args.push_back(ir::UIntImmNode::make(DataType::UInt(32), vpaddlu_id));
vcnt64_args.push_back(ir::UIntImm::make(DataType::UInt(32), 1)); vcnt64_args.push_back(ir::UIntImmNode::make(DataType::UInt(32), 1));
vcnt64_args.push_back(vcnt32); vcnt64_args.push_back(vcnt32);
return ir::Call::make(call->dtype, "llvm_intrin", vcnt64_args, Call::PureIntrinsic); return ir::CallNode::make(
call->dtype, "llvm_intrin", vcnt64_args, CallNode::PureIntrinsic);
} }
TVM_REGISTER_GLOBAL("tvm.codegen.llvm.target_arm") TVM_REGISTER_GLOBAL("tvm.codegen.llvm.target_arm")
......
...@@ -319,7 +319,7 @@ llvm::Value* CodeGenCPU::CreateStructRefPtr( ...@@ -319,7 +319,7 @@ llvm::Value* CodeGenCPU::CreateStructRefPtr(
} }
} }
llvm::Value* CodeGenCPU::CreateCallExtern(const Call* op) { llvm::Value* CodeGenCPU::CreateCallExtern(const CallNode* op) {
std::vector<llvm::Value*> arg_values(op->args.size()); std::vector<llvm::Value*> arg_values(op->args.size());
for (size_t i = 0; i < op->args.size(); ++i) { for (size_t i = 0; i < op->args.size(); ++i) {
arg_values[i] = MakeValue(op->args[i]); arg_values[i] = MakeValue(op->args[i]);
...@@ -417,7 +417,7 @@ llvm::BasicBlock* CodeGenCPU::CheckCallSuccess(llvm::Value* retcode) { ...@@ -417,7 +417,7 @@ llvm::BasicBlock* CodeGenCPU::CheckCallSuccess(llvm::Value* retcode) {
return end_block; return end_block;
} }
void CodeGenCPU::CreateComputeScope(const AttrStmt* op) { void CodeGenCPU::CreateComputeScope(const AttrStmtNode* op) {
// There are two reasons why we create another function for compute_scope // There are two reasons why we create another function for compute_scope
// - Make sure the generated compute function is clearly separately(though it can get inlined) // - Make sure the generated compute function is clearly separately(though it can get inlined)
// - Set noalias on all the pointer arguments, some of them are loaded from TVMArgs. // - Set noalias on all the pointer arguments, some of them are loaded from TVMArgs.
...@@ -436,12 +436,12 @@ void CodeGenCPU::CreateComputeScope(const AttrStmt* op) { ...@@ -436,12 +436,12 @@ void CodeGenCPU::CreateComputeScope(const AttrStmt* op) {
llvm::Function* fcompute = llvm::Function* fcompute =
llvm::Function::Create(ftype, llvm::Function::Create(ftype,
llvm::Function::PrivateLinkage, llvm::Function::PrivateLinkage,
op->value.as<StringImm>()->value, op->value.as<StringImmNode>()->value,
module_.get()); module_.get());
BasicBlock* compute_call_end = CheckCallSuccess( BasicBlock* compute_call_end = CheckCallSuccess(
builder_->CreateCall(fcompute, arg_values)); builder_->CreateCall(fcompute, arg_values));
// setup compute fuinction. // setup compute fuinction.
std::unordered_map<const Variable*, llvm::Value*> new_vmap; std::unordered_map<const VarNode*, llvm::Value*> new_vmap;
size_t idx = 0; size_t idx = 0;
for (auto it = fcompute->arg_begin(); for (auto it = fcompute->arg_begin();
it != fcompute->arg_end(); ++it, ++idx) { it != fcompute->arg_end(); ++it, ++idx) {
...@@ -497,7 +497,7 @@ llvm::Value* CodeGenCPU::PackClosureData(const Array<Var>& vfields, uint64_t* nu ...@@ -497,7 +497,7 @@ llvm::Value* CodeGenCPU::PackClosureData(const Array<Var>& vfields, uint64_t* nu
void CodeGenCPU::UnpackClosureData(llvm::Value* cdata, void CodeGenCPU::UnpackClosureData(llvm::Value* cdata,
const Array<Var>& vfields, const Array<Var>& vfields,
std::unordered_map<const Variable*, llvm::Value*>* vmap) { std::unordered_map<const VarNode*, llvm::Value*>* vmap) {
for (size_t i = 0; i < vfields.size(); ++i) { for (size_t i = 0; i < vfields.size(); ++i) {
(*vmap)[vfields[i].get()] = (*vmap)[vfields[i].get()] =
builder_->CreateLoad(builder_->CreateInBoundsGEP( builder_->CreateLoad(builder_->CreateInBoundsGEP(
...@@ -528,7 +528,7 @@ void CodeGenCPU::CreateParallelLaunch(const Stmt& body, int num_task) { ...@@ -528,7 +528,7 @@ void CodeGenCPU::CreateParallelLaunch(const Stmt& body, int num_task) {
llvm::Value* penv = &(*it++); llvm::Value* penv = &(*it++);
cdata = builder_->CreatePointerCast(&(*it++), cdata->getType()); cdata = builder_->CreatePointerCast(&(*it++), cdata->getType());
// setup new variable map, swap it with current var context. // setup new variable map, swap it with current var context.
std::unordered_map<const Variable*, llvm::Value*> new_vmap; std::unordered_map<const VarNode*, llvm::Value*> new_vmap;
UnpackClosureData(cdata, vfields, &new_vmap); UnpackClosureData(cdata, vfields, &new_vmap);
// setup parallel env // setup parallel env
ParallelEnv par_env; ParallelEnv par_env;
...@@ -594,7 +594,7 @@ void CodeGenCPU::CreateStaticInit(const std::string& init_fname, const Stmt& bod ...@@ -594,7 +594,7 @@ void CodeGenCPU::CreateStaticInit(const std::string& init_fname, const Stmt& bod
auto it = f->arg_begin(); auto it = f->arg_begin();
cdata = builder_->CreatePointerCast(&(*it++), cdata->getType()); cdata = builder_->CreatePointerCast(&(*it++), cdata->getType());
// setup new variable map, swap it with current var context. // setup new variable map, swap it with current var context.
std::unordered_map<const Variable*, llvm::Value*> new_vmap; std::unordered_map<const VarNode*, llvm::Value*> new_vmap;
UnpackClosureData(cdata, vfields, &new_vmap); UnpackClosureData(cdata, vfields, &new_vmap);
CHECK(parallel_env_.penv == nullptr); CHECK(parallel_env_.penv == nullptr);
std::swap(function_, f); std::swap(function_, f);
...@@ -673,7 +673,7 @@ CodeGenCPU::MakeCallPacked(const Array<Expr> &args, llvm::Value **rvalue, ...@@ -673,7 +673,7 @@ CodeGenCPU::MakeCallPacked(const Array<Expr> &args, llvm::Value **rvalue,
llvm::Value **ret_tcode, const DataType &r_type, llvm::Value **ret_tcode, const DataType &r_type,
const int64_t begin, const int64_t end) { const int64_t begin, const int64_t end) {
using llvm::BasicBlock; using llvm::BasicBlock;
std::string func_name = args[0].as<StringImm>()->value; std::string func_name = args[0].as<StringImmNode>()->value;
llvm::Value *handle = GetPackedFuncHandle(func_name); llvm::Value *handle = GetPackedFuncHandle(func_name);
// call the function // call the function
int64_t nargs = end - begin; int64_t nargs = end - begin;
...@@ -701,24 +701,24 @@ CodeGenCPU::MakeCallPacked(const Array<Expr> &args, llvm::Value **rvalue, ...@@ -701,24 +701,24 @@ CodeGenCPU::MakeCallPacked(const Array<Expr> &args, llvm::Value **rvalue,
return end_block; return end_block;
} }
llvm::Value *CodeGenCPU::CreateCallPacked(const Call *op) { llvm::Value *CodeGenCPU::CreateCallPacked(const CallNode *op) {
CHECK_EQ(op->args.size(), 5U); CHECK_EQ(op->args.size(), 5U);
llvm::Value *rvalue = nullptr; llvm::Value *rvalue = nullptr;
llvm::Value *ret_tcode = nullptr; llvm::Value *ret_tcode = nullptr;
MakeCallPacked(op->args, &rvalue, &ret_tcode, op->dtype, MakeCallPacked(op->args, &rvalue, &ret_tcode, op->dtype,
op->args[3].as<IntImm>()->value, op->args[3].as<IntImmNode>()->value,
op->args[4].as<IntImm>()->value); op->args[4].as<IntImmNode>()->value);
return rvalue; return rvalue;
} }
llvm::Value *CodeGenCPU::CreateCallTracePacked(const Call *op) { llvm::Value *CodeGenCPU::CreateCallTracePacked(const CallNode *op) {
using llvm::BasicBlock; using llvm::BasicBlock;
CHECK_EQ(op->args.size(), 6U); CHECK_EQ(op->args.size(), 6U);
llvm::Value *rvalue = nullptr; llvm::Value *rvalue = nullptr;
llvm::Value *ret_tcode = nullptr; llvm::Value *ret_tcode = nullptr;
BasicBlock *end_block = MakeCallPacked( BasicBlock *end_block = MakeCallPacked(
op->args, &rvalue, &ret_tcode, op->dtype, op->args[3].as<IntImm>()->value, op->args, &rvalue, &ret_tcode, op->dtype, op->args[3].as<IntImmNode>()->value,
op->args[4].as<IntImm>()->value); op->args[4].as<IntImmNode>()->value);
// Get traced value. // Get traced value.
llvm::Value *traced_value = MakeValue(op->args[5]); llvm::Value *traced_value = MakeValue(op->args[5]);
// The update_block handles case when we need to update the return value. // The update_block handles case when we need to update the return value.
...@@ -786,7 +786,7 @@ void CodeGenCPU::AddStartupFunction() { ...@@ -786,7 +786,7 @@ void CodeGenCPU::AddStartupFunction() {
} }
} }
llvm::Value* CodeGenCPU::CreateIntrinsic(const Call* op) { llvm::Value* CodeGenCPU::CreateIntrinsic(const CallNode* op) {
if (op->is_intrinsic(intrinsic::tvm_call_packed_lowered)) { if (op->is_intrinsic(intrinsic::tvm_call_packed_lowered)) {
return CreateCallPacked(op); return CreateCallPacked(op);
} else if (op->is_intrinsic(intrinsic::tvm_call_trace_packed_lowered)) { } else if (op->is_intrinsic(intrinsic::tvm_call_trace_packed_lowered)) {
...@@ -798,7 +798,7 @@ llvm::Value* CodeGenCPU::CreateIntrinsic(const Call* op) { ...@@ -798,7 +798,7 @@ llvm::Value* CodeGenCPU::CreateIntrinsic(const Call* op) {
return ConstInt32(-1); return ConstInt32(-1);
} else if (op->is_intrinsic(intrinsic::tvm_struct_get)) { } else if (op->is_intrinsic(intrinsic::tvm_struct_get)) {
CHECK_EQ(op->args.size(), 3U); CHECK_EQ(op->args.size(), 3U);
int kind = op->args[2].as<IntImm>()->value; int kind = op->args[2].as<IntImmNode>()->value;
llvm::Value* ref = this->CreateStructRefPtr( llvm::Value* ref = this->CreateStructRefPtr(
op->dtype, MakeValue(op->args[0]), op->dtype, MakeValue(op->args[0]),
MakeValue(op->args[1]), kind); MakeValue(op->args[1]), kind);
...@@ -809,7 +809,7 @@ llvm::Value* CodeGenCPU::CreateIntrinsic(const Call* op) { ...@@ -809,7 +809,7 @@ llvm::Value* CodeGenCPU::CreateIntrinsic(const Call* op) {
} }
} else if (op->is_intrinsic(intrinsic::tvm_struct_set)) { } else if (op->is_intrinsic(intrinsic::tvm_struct_set)) {
CHECK_EQ(op->args.size(), 4U); CHECK_EQ(op->args.size(), 4U);
int kind = op->args[2].as<IntImm>()->value; int kind = op->args[2].as<IntImmNode>()->value;
llvm::Value* value = MakeValue(op->args[3]); llvm::Value* value = MakeValue(op->args[3]);
llvm::Value* ref = this->CreateStructRefPtr( llvm::Value* ref = this->CreateStructRefPtr(
op->args[3].dtype(), MakeValue(op->args[0]), op->args[3].dtype(), MakeValue(op->args[0]),
...@@ -823,7 +823,7 @@ llvm::Value* CodeGenCPU::CreateIntrinsic(const Call* op) { ...@@ -823,7 +823,7 @@ llvm::Value* CodeGenCPU::CreateIntrinsic(const Call* op) {
return ConstInt32(0); return ConstInt32(0);
} else if (op->is_intrinsic(intrinsic::tvm_stack_alloca)) { } else if (op->is_intrinsic(intrinsic::tvm_stack_alloca)) {
CHECK_EQ(op->args.size(), 2U); CHECK_EQ(op->args.size(), 2U);
const std::string& type = op->args[0].as<StringImm>()->value; const std::string& type = op->args[0].as<StringImmNode>()->value;
return WithFunctionEntry([&]() -> llvm::AllocaInst* { return WithFunctionEntry([&]() -> llvm::AllocaInst* {
const int64_t* pval = as_const_int(op->args[1]); const int64_t* pval = as_const_int(op->args[1]);
CHECK(pval) << "require stack alloca to contain constant value"; CHECK(pval) << "require stack alloca to contain constant value";
...@@ -846,13 +846,13 @@ llvm::Value* CodeGenCPU::CreateIntrinsic(const Call* op) { ...@@ -846,13 +846,13 @@ llvm::Value* CodeGenCPU::CreateIntrinsic(const Call* op) {
} }
} }
void CodeGenCPU::VisitStmt_(const AssertStmt* op) { void CodeGenCPU::VisitStmt_(const AssertStmtNode* op) {
using llvm::BasicBlock; using llvm::BasicBlock;
llvm::Value* cond = MakeValue(op->condition); llvm::Value* cond = MakeValue(op->condition);
std::ostringstream os; std::ostringstream os;
os << "Assert fail: " << op->condition; os << "Assert fail: " << op->condition;
if (op->message.as<StringImm>()) { if (op->message.as<StringImmNode>()) {
os << ", " << op->message.as<StringImm>()->value; os << ", " << op->message.as<StringImmNode>()->value;
} }
llvm::Value* msg = GetConstString(os.str()); llvm::Value* msg = GetConstString(os.str());
BasicBlock* fail_block = BasicBlock::Create( BasicBlock* fail_block = BasicBlock::Create(
...@@ -869,9 +869,9 @@ void CodeGenCPU::VisitStmt_(const AssertStmt* op) { ...@@ -869,9 +869,9 @@ void CodeGenCPU::VisitStmt_(const AssertStmt* op) {
CodeGenLLVM::VisitStmt_(op); CodeGenLLVM::VisitStmt_(op);
} }
void CodeGenCPU::VisitStmt_(const AttrStmt* op) { void CodeGenCPU::VisitStmt_(const AttrStmtNode* op) {
if (op->attr_key == ir::attr::coproc_uop_scope) { if (op->attr_key == ir::attr::coproc_uop_scope) {
this->CreateStaticInit(op->value.as<StringImm>()->value, op->body); this->CreateStaticInit(op->value.as<StringImmNode>()->value, op->body);
} else if (op->attr_key == ir::attr::compute_scope) { } else if (op->attr_key == ir::attr::compute_scope) {
this->CreateComputeScope(op); this->CreateComputeScope(op);
} else if (attr::IsPragmaKey(op->attr_key)) { } else if (attr::IsPragmaKey(op->attr_key)) {
...@@ -893,7 +893,7 @@ void CodeGenCPU::VisitStmt_(const AttrStmt* op) { ...@@ -893,7 +893,7 @@ void CodeGenCPU::VisitStmt_(const AttrStmt* op) {
RuntimeTVMParallelBarrier(), RuntimeTVMParallelBarrier(),
{MakeValue(parallel_env_.task_id), parallel_env_.penv}); {MakeValue(parallel_env_.task_id), parallel_env_.penv});
} else if (op->attr_key == ir::attr::pragma_import_llvm) { } else if (op->attr_key == ir::attr::pragma_import_llvm) {
const StringImm* value = op->value.as<StringImm>(); const StringImmNode* value = op->value.as<StringImmNode>();
CHECK(value != nullptr); CHECK(value != nullptr);
this->HandleImport(value->value); this->HandleImport(value->value);
this->VisitStmt(op->body); this->VisitStmt(op->body);
...@@ -906,7 +906,7 @@ void CodeGenCPU::VisitStmt_(const AttrStmt* op) { ...@@ -906,7 +906,7 @@ void CodeGenCPU::VisitStmt_(const AttrStmt* op) {
} }
} }
void CodeGenCPU::VisitStmt_(const For* op) { void CodeGenCPU::VisitStmt_(const ForNode* op) {
CHECK(is_zero(op->min)); CHECK(is_zero(op->min));
if (op->for_type == ForType::Serial || if (op->for_type == ForType::Serial ||
op->for_type == ForType::Unrolled) { op->for_type == ForType::Unrolled) {
...@@ -914,7 +914,7 @@ void CodeGenCPU::VisitStmt_(const For* op) { ...@@ -914,7 +914,7 @@ void CodeGenCPU::VisitStmt_(const For* op) {
} else if (op->for_type == ForType::Parallel) { } else if (op->for_type == ForType::Parallel) {
if (parallel_env_.penv == nullptr) { if (parallel_env_.penv == nullptr) {
CreateParallelLaunch( CreateParallelLaunch(
For::make( ForNode::make(
op->loop_var, op->min, op->extent, op->loop_var, op->min, op->extent,
op->for_type, op->device_api, op->body), 0); op->for_type, op->device_api, op->body), 0);
} else { } else {
...@@ -936,8 +936,8 @@ void CodeGenCPU::VisitStmt_(const For* op) { ...@@ -936,8 +936,8 @@ void CodeGenCPU::VisitStmt_(const For* op) {
op->body); op->body);
} else { } else {
Expr step = (op->extent + num_task - make_const(t, 1)) / num_task; Expr step = (op->extent + num_task - make_const(t, 1)) / num_task;
Expr begin = Min::make(task_id * step, op->extent); Expr begin = MinNode::make(task_id * step, op->extent);
Expr end = Min::make((task_id + make_const(t, 1)) * step, op->extent); Expr end = MinNode::make((task_id + make_const(t, 1)) * step, op->extent);
CreateSerialFor(MakeValue(begin), CreateSerialFor(MakeValue(begin),
MakeValue(end), MakeValue(end),
ConstInt32(1), ConstInt32(1),
......
...@@ -45,11 +45,11 @@ class CodeGenCPU : public CodeGenLLVM { ...@@ -45,11 +45,11 @@ class CodeGenCPU : public CodeGenLLVM {
void AddFunction(const LoweredFunc& f) override; void AddFunction(const LoweredFunc& f) override;
void AddMainFunction(const std::string& entry_func_name) override; void AddMainFunction(const std::string& entry_func_name) override;
std::unique_ptr<llvm::Module> Finish() override; std::unique_ptr<llvm::Module> Finish() override;
void VisitStmt_(const AssertStmt* op) override; void VisitStmt_(const AssertStmtNode* op) override;
void VisitStmt_(const AttrStmt* op) override; void VisitStmt_(const AttrStmtNode* op) override;
void VisitStmt_(const For* op) override; void VisitStmt_(const ForNode* op) override;
llvm::Value* CreateIntrinsic(const Call* op) override; llvm::Value* CreateIntrinsic(const CallNode* op) override;
llvm::Value* CreateCallExtern(const Call* op) override; llvm::Value* CreateCallExtern(const CallNode* op) override;
protected: protected:
void AddStartupFunction() final; void AddStartupFunction() final;
...@@ -99,22 +99,22 @@ class CodeGenCPU : public CodeGenLLVM { ...@@ -99,22 +99,22 @@ class CodeGenCPU : public CodeGenLLVM {
llvm::Value* CreateStructRefPtr(DataType t, llvm::Value* buffer, llvm::Value* index, int kind); llvm::Value* CreateStructRefPtr(DataType t, llvm::Value* buffer, llvm::Value* index, int kind);
void UnpackClosureData(llvm::Value*cdata, void UnpackClosureData(llvm::Value*cdata,
const Array<Var>& fields, const Array<Var>& fields,
std::unordered_map<const Variable*, llvm::Value*>* vmap); std::unordered_map<const VarNode*, llvm::Value*>* vmap);
// Make packed call. // Make packed call.
llvm::BasicBlock *MakeCallPacked(const Array<Expr> &args, llvm::BasicBlock *MakeCallPacked(const Array<Expr> &args,
llvm::Value **rvalue, llvm::Value **rvalue,
llvm::Value **ret_tcode, const DataType &r_type, llvm::Value **ret_tcode, const DataType &r_type,
const int64_t begin, const int64_t end); const int64_t begin, const int64_t end);
// create call into tvm packed function. // create call into tvm packed function.
llvm::Value* CreateCallPacked(const Call* op); llvm::Value* CreateCallPacked(const CallNode* op);
// Create trace call into tvm packed function. // Create trace call into tvm packed function.
llvm::Value* CreateCallTracePacked(const Call *op); llvm::Value* CreateCallTracePacked(const CallNode *op);
// Create static initialization // Create static initialization
void CreateStaticInit(const std::string& init_fname, const Stmt& body); void CreateStaticInit(const std::string& init_fname, const Stmt& body);
// Create parallel launch // Create parallel launch
void CreateParallelLaunch(const Stmt& body, int num_task); void CreateParallelLaunch(const Stmt& body, int num_task);
// Create a new compute scope. // Create a new compute scope.
void CreateComputeScope(const AttrStmt* op); void CreateComputeScope(const AttrStmtNode* op);
// Check if the call to packed function is successful // Check if the call to packed function is successful
// if not directly finalize function and pass on return code. // if not directly finalize function and pass on return code.
// return the end block after the check // return the end block after the check
......
...@@ -103,46 +103,46 @@ class CodeGenLLVM : ...@@ -103,46 +103,46 @@ class CodeGenLLVM :
return llvm::ConstantInt::getSigned(t_int32_, value); return llvm::ConstantInt::getSigned(t_int32_, value);
} }
// override codegen // override codegen
llvm::Value* VisitExpr_(const Variable* op) override; llvm::Value* VisitExpr_(const VarNode* op) override;
llvm::Value* VisitExpr_(const Cast* op) override; llvm::Value* VisitExpr_(const CastNode* op) override;
llvm::Value* VisitExpr_(const IntImm* op) override; llvm::Value* VisitExpr_(const IntImmNode* op) override;
llvm::Value* VisitExpr_(const UIntImm* op) override; llvm::Value* VisitExpr_(const UIntImmNode* op) override;
llvm::Value* VisitExpr_(const FloatImm* op) override; llvm::Value* VisitExpr_(const FloatImmNode* op) override;
llvm::Value* VisitExpr_(const StringImm* op) override; llvm::Value* VisitExpr_(const StringImmNode* op) override;
llvm::Value* VisitExpr_(const Add* op) override; llvm::Value* VisitExpr_(const AddNode* op) override;
llvm::Value* VisitExpr_(const Sub* op) override; llvm::Value* VisitExpr_(const SubNode* op) override;
llvm::Value* VisitExpr_(const Mul* op) override; llvm::Value* VisitExpr_(const MulNode* op) override;
llvm::Value* VisitExpr_(const Div* op) override; llvm::Value* VisitExpr_(const DivNode* op) override;
llvm::Value* VisitExpr_(const Mod* op) override; llvm::Value* VisitExpr_(const ModNode* op) override;
llvm::Value* VisitExpr_(const Min* op) override; llvm::Value* VisitExpr_(const MinNode* op) override;
llvm::Value* VisitExpr_(const Max* op) override; llvm::Value* VisitExpr_(const MaxNode* op) override;
llvm::Value* VisitExpr_(const LT* op) override; llvm::Value* VisitExpr_(const LTNode* op) override;
llvm::Value* VisitExpr_(const LE* op) override; llvm::Value* VisitExpr_(const LENode* op) override;
llvm::Value* VisitExpr_(const GT* op) override; llvm::Value* VisitExpr_(const GTNode* op) override;
llvm::Value* VisitExpr_(const GE* op) override; llvm::Value* VisitExpr_(const GENode* op) override;
llvm::Value* VisitExpr_(const EQ* op) override; llvm::Value* VisitExpr_(const EQNode* op) override;
llvm::Value* VisitExpr_(const NE* op) override; llvm::Value* VisitExpr_(const NENode* op) override;
llvm::Value* VisitExpr_(const And* op) override; llvm::Value* VisitExpr_(const AndNode* op) override;
llvm::Value* VisitExpr_(const Or* op) override; llvm::Value* VisitExpr_(const OrNode* op) override;
llvm::Value* VisitExpr_(const Not* op) override; llvm::Value* VisitExpr_(const NotNode* op) override;
llvm::Value* VisitExpr_(const Select* op) override; llvm::Value* VisitExpr_(const SelectNode* op) override;
llvm::Value* VisitExpr_(const Let* op) override; llvm::Value* VisitExpr_(const LetNode* op) override;
llvm::Value* VisitExpr_(const Load* op) override; llvm::Value* VisitExpr_(const LoadNode* op) override;
llvm::Value* VisitExpr_(const Call* op) override; llvm::Value* VisitExpr_(const CallNode* op) override;
llvm::Value* VisitExpr_(const Ramp* op) override; llvm::Value* VisitExpr_(const RampNode* op) override;
llvm::Value* VisitExpr_(const Shuffle* op) override; llvm::Value* VisitExpr_(const ShuffleNode* op) override;
llvm::Value* VisitExpr_(const Broadcast* op) override; llvm::Value* VisitExpr_(const BroadcastNode* op) override;
// stmt // stmt
void VisitStmt_(const Store* op) override; void VisitStmt_(const StoreNode* op) override;
void VisitStmt_(const For* op) override; void VisitStmt_(const ForNode* op) override;
void VisitStmt_(const IfThenElse* op) override; void VisitStmt_(const IfThenElseNode* op) override;
void VisitStmt_(const Allocate* op) override; void VisitStmt_(const AllocateNode* op) override;
void VisitStmt_(const AttrStmt* op) override; void VisitStmt_(const AttrStmtNode* op) override;
void VisitStmt_(const AssertStmt* op) override; void VisitStmt_(const AssertStmtNode* op) override;
void VisitStmt_(const LetStmt* op) override; void VisitStmt_(const LetStmtNode* op) override;
void VisitStmt_(const SeqStmtNode* op) override; void VisitStmt_(const SeqStmtNode* op) override;
void VisitStmt_(const Evaluate* op) override; void VisitStmt_(const EvaluateNode* op) override;
void VisitStmt_(const ProducerConsumer* op) override; void VisitStmt_(const ProducerConsumerNode* op) override;
protected: protected:
/*! \brief The storage information */ /*! \brief The storage information */
...@@ -173,13 +173,13 @@ class CodeGenLLVM : ...@@ -173,13 +173,13 @@ class CodeGenLLVM :
return res; return res;
} }
// create intrinstic given call // create intrinstic given call
virtual llvm::Value* CreateIntrinsic(const Call* op); virtual llvm::Value* CreateIntrinsic(const CallNode* op);
// create extern function call // create extern function call
virtual llvm::Value* CreateCallExtern(const Call* op); virtual llvm::Value* CreateCallExtern(const CallNode* op);
// Get the corresponding thread index // Get the corresponding thread index
virtual llvm::Value* GetThreadIndex(const IterVar& iv); virtual llvm::Value* GetThreadIndex(const IterVar& iv);
// Get the corresponding thread index // Get the corresponding thread index
virtual llvm::Value* CreateStorageSync(const Call* op); virtual llvm::Value* CreateStorageSync(const CallNode* op);
// apply optimization on the module. // apply optimization on the module.
virtual void InitPassManagerBuilder(llvm::PassManagerBuilder* builder); virtual void InitPassManagerBuilder(llvm::PassManagerBuilder* builder);
// Scalarize by iterating elements of e. // Scalarize by iterating elements of e.
...@@ -211,19 +211,19 @@ class CodeGenLLVM : ...@@ -211,19 +211,19 @@ class CodeGenLLVM :
void InitFuncState(); void InitFuncState();
// Get alignment given index. // Get alignment given index.
void GetAlignment( void GetAlignment(
DataType t, const Variable* buf_var, const Expr& index, DataType t, const VarNode* buf_var, const Expr& index,
int* p_alignment, int* p_native_bits); int* p_alignment, int* p_native_bits);
// Get constant string // Get constant string
llvm::Value* GetConstString(const std::string& str); llvm::Value* GetConstString(const std::string& str);
// do a scalarize call with f // do a scalarize call with f
llvm::Value* CreateScalarizedCall( llvm::Value* CreateScalarizedCall(
const Call* op, llvm::Function* f, const std::vector<llvm::Value*>& args); const CallNode* op, llvm::Function* f, const std::vector<llvm::Value*>& args);
// handle module import // handle module import
void HandleImport(const std::string& code); void HandleImport(const std::string& code);
// cast operatpr // cast operatpr
llvm::Value* CreateCast(DataType from, DataType to, llvm::Value* value); llvm::Value* CreateCast(DataType from, DataType to, llvm::Value* value);
// comparison op // comparison op
llvm::Value* GetVarValue(const Variable* v) const; llvm::Value* GetVarValue(const VarNode* v) const;
llvm::Value* CreateLT(DataType t, llvm::Value* a, llvm::Value* b); llvm::Value* CreateLT(DataType t, llvm::Value* a, llvm::Value* b);
llvm::Value* CreateLE(DataType t, llvm::Value* a, llvm::Value* b); llvm::Value* CreateLE(DataType t, llvm::Value* a, llvm::Value* b);
llvm::Value* CreateGT(DataType t, llvm::Value* a, llvm::Value* b); llvm::Value* CreateGT(DataType t, llvm::Value* a, llvm::Value* b);
...@@ -245,7 +245,7 @@ class CodeGenLLVM : ...@@ -245,7 +245,7 @@ class CodeGenLLVM :
llvm::Value* stride, llvm::Value* stride,
const VarExpr& loop_var, const Stmt& body); const VarExpr& loop_var, const Stmt& body);
// add alias information. // add alias information.
void AddAliasInfo(llvm::Instruction* load, const Variable* buffer, Expr index, DataType type); void AddAliasInfo(llvm::Instruction* load, const VarNode* buffer, Expr index, DataType type);
// The IRBuilder. // The IRBuilder.
using IRBuilder = llvm::IRBuilder<llvm::ConstantFolder, llvm::IRBuilderDefaultInserter>; using IRBuilder = llvm::IRBuilder<llvm::ConstantFolder, llvm::IRBuilderDefaultInserter>;
// The current function // The current function
...@@ -280,9 +280,9 @@ class CodeGenLLVM : ...@@ -280,9 +280,9 @@ class CodeGenLLVM :
/*! \brief native vector bits of current targetx*/ /*! \brief native vector bits of current targetx*/
int native_vector_bits_{0}; int native_vector_bits_{0};
/*! \brief the storage scope of allocation */ /*! \brief the storage scope of allocation */
std::unordered_map<const Variable*, StorageInfo> alloc_storage_info_; std::unordered_map<const VarNode*, StorageInfo> alloc_storage_info_;
// The definition of local variable. // The definition of local variable.
std::unordered_map<const Variable*, llvm::Value*> var_map_; std::unordered_map<const VarNode*, llvm::Value*> var_map_;
// global strings // global strings
std::unordered_map<std::string, llvm::Constant*> str_map_; std::unordered_map<std::string, llvm::Constant*> str_map_;
// Whether current function is restricted // Whether current function is restricted
...@@ -290,9 +290,9 @@ class CodeGenLLVM : ...@@ -290,9 +290,9 @@ class CodeGenLLVM :
// The analyzer information // The analyzer information
std::unique_ptr<arith::Analyzer> analyzer_; std::unique_ptr<arith::Analyzer> analyzer_;
// set of var that are not restricted(can alias) // set of var that are not restricted(can alias)
std::unordered_set<const Variable*> alias_var_set_; std::unordered_set<const VarNode*> alias_var_set_;
// set of volatile buffer. // set of volatile buffer.
std::unordered_set<const Variable*> volatile_buf_; std::unordered_set<const VarNode*> volatile_buf_;
/*! \brief Helper struct for debug infos. */ /*! \brief Helper struct for debug infos. */
struct DebugInfo { struct DebugInfo {
std::unique_ptr<llvm::DIBuilder> di_builder_; std::unique_ptr<llvm::DIBuilder> di_builder_;
......
...@@ -46,7 +46,7 @@ class CodeGenNVPTX : public CodeGenLLVM { ...@@ -46,7 +46,7 @@ class CodeGenNVPTX : public CodeGenLLVM {
llvm::ValueAsMetadata::get(ConstInt32(1)) })); llvm::ValueAsMetadata::get(ConstInt32(1)) }));
} }
void VisitStmt_(const Allocate* op) final { void VisitStmt_(const AllocateNode* op) final {
CHECK(!is_zero(op->condition)); CHECK(!is_zero(op->condition));
llvm::Value* buf = nullptr; llvm::Value* buf = nullptr;
if (op->new_expr.defined()) { if (op->new_expr.defined()) {
...@@ -129,8 +129,8 @@ class CodeGenNVPTX : public CodeGenLLVM { ...@@ -129,8 +129,8 @@ class CodeGenNVPTX : public CodeGenLLVM {
return builder_->CreateCall(f, {}); return builder_->CreateCall(f, {});
} }
llvm::Value* CreateStorageSync(const Call* op) final { llvm::Value* CreateStorageSync(const CallNode* op) final {
const std::string& sync = op->args[0].as<StringImm>()->value; const std::string& sync = op->args[0].as<StringImmNode>()->value;
if (sync == "warp") { if (sync == "warp") {
// TODO(tqchen) warp sync in CUDA9 // TODO(tqchen) warp sync in CUDA9
return nullptr; return nullptr;
......
...@@ -65,14 +65,14 @@ bool TargetHasFeature(const llvm::TargetMachine& tm, const std::string& feature) ...@@ -65,14 +65,14 @@ bool TargetHasFeature(const llvm::TargetMachine& tm, const std::string& feature)
class CodeGenX86_64 final : public CodeGenCPU { class CodeGenX86_64 final : public CodeGenCPU {
public: public:
llvm::Value* VisitExpr_(const Cast* op) override; llvm::Value* VisitExpr_(const CastNode* op) override;
private: private:
llvm::Value* CallVectorIntrin(llvm::Intrinsic::ID id, size_t intrin_lanes, llvm::Type* result_ty, llvm::Value* CallVectorIntrin(llvm::Intrinsic::ID id, size_t intrin_lanes, llvm::Type* result_ty,
const std::vector<llvm::Value*>& args); const std::vector<llvm::Value*>& args);
}; };
llvm::Value* CodeGenX86_64::VisitExpr_(const Cast* op) { llvm::Value* CodeGenX86_64::VisitExpr_(const CastNode* op) {
// LLVM does not automatically generate the correct instruction sequences for // LLVM does not automatically generate the correct instruction sequences for
// half -> float conversion (i.e. using AVX2/AVX-512 vectorized variants of // half -> float conversion (i.e. using AVX2/AVX-512 vectorized variants of
// vcvtph2ps), so we explicitly generate them ourselves. // vcvtph2ps), so we explicitly generate them ourselves.
...@@ -90,22 +90,23 @@ llvm::Value* CodeGenX86_64::VisitExpr_(const Cast* op) { ...@@ -90,22 +90,23 @@ llvm::Value* CodeGenX86_64::VisitExpr_(const Cast* op) {
::llvm::Intrinsic::x86_avx512_mask_vcvtph2ps_512, 16, ::llvm::Intrinsic::x86_avx512_mask_vcvtph2ps_512, 16,
LLVMType(DataType::Float(32, from.lanes())), LLVMType(DataType::Float(32, from.lanes())),
{ {
MakeValue(ir::Call::make( MakeValue(ir::CallNode::make(
DataType::Int(16, from.lanes()), ir::Call::reinterpret, {op->value}, DataType::Int(16, from.lanes()), ir::CallNode::reinterpret, {op->value},
ir::Call::PureIntrinsic)), ir::CallNode::PureIntrinsic)),
MakeValue( MakeValue(
ir::Broadcast::make(ir::FloatImm::make(DataType::Float(32), 0), from.lanes())), ir::BroadcastNode::make(
/*mask=*/MakeValue(ir::IntImm::make(DataType::Int(16), -1)), ir::FloatImmNode::make(DataType::Float(32), 0), from.lanes())),
/*rounding-mode=*/MakeValue(ir::IntImm::make(DataType::Int(32), 4)), /*mask=*/MakeValue(ir::IntImmNode::make(DataType::Int(16), -1)),
/*rounding-mode=*/MakeValue(ir::IntImmNode::make(DataType::Int(32), 4)),
}); });
} }
if (from.lanes() >= 8 && has_f16c) { if (from.lanes() >= 8 && has_f16c) {
return CallVectorIntrin( return CallVectorIntrin(
::llvm::Intrinsic::x86_vcvtph2ps_256, 8, LLVMType(DataType::Float(32, from.lanes())), ::llvm::Intrinsic::x86_vcvtph2ps_256, 8, LLVMType(DataType::Float(32, from.lanes())),
{MakeValue(ir::Call::make( {MakeValue(ir::CallNode::make(
DataType::Int(16, from.lanes()), ir::Call::reinterpret, {op->value}, DataType::Int(16, from.lanes()), ir::CallNode::reinterpret, {op->value},
ir::Call::PureIntrinsic))}); ir::CallNode::PureIntrinsic))});
} }
} }
......
...@@ -64,21 +64,21 @@ TVM_REGISTER_GLOBAL("tvm.intrin.rule.llvm.nearbyint") ...@@ -64,21 +64,21 @@ TVM_REGISTER_GLOBAL("tvm.intrin.rule.llvm.nearbyint")
TVM_REGISTER_GLOBAL("tvm.intrin.rule.llvm.tanh") TVM_REGISTER_GLOBAL("tvm.intrin.rule.llvm.tanh")
.set_body([](const TVMArgs& targs, TVMRetValue* rv) { .set_body([](const TVMArgs& targs, TVMRetValue* rv) {
Expr e = targs[0]; Expr e = targs[0];
const ir::Call* call = e.as<ir::Call>(); const ir::CallNode* call = e.as<ir::CallNode>();
CHECK(call != nullptr); CHECK(call != nullptr);
const Expr& x = call->args[0]; const Expr& x = call->args[0];
Expr one = make_const(x.dtype(), 1); Expr one = make_const(x.dtype(), 1);
Expr two = make_const(x.dtype(), 2); Expr two = make_const(x.dtype(), 2);
Expr neg_two = make_const(x.dtype(), -2); Expr neg_two = make_const(x.dtype(), -2);
Expr exp_neg2x = ir::Call::make( Expr exp_neg2x = ir::CallNode::make(
x.dtype(), "exp", {neg_two * x}, ir::Call::PureIntrinsic); x.dtype(), "exp", {neg_two * x}, ir::CallNode::PureIntrinsic);
Expr exp_pos2x = ir::Call::make( Expr exp_pos2x = ir::CallNode::make(
x.dtype(), "exp", {two * x}, ir::Call::PureIntrinsic); x.dtype(), "exp", {two * x}, ir::CallNode::PureIntrinsic);
Expr tanh_pos = (one - exp_neg2x) / (one + exp_neg2x); Expr tanh_pos = (one - exp_neg2x) / (one + exp_neg2x);
Expr tanh_neg = (exp_pos2x - one) / (exp_pos2x + one); Expr tanh_neg = (exp_pos2x - one) / (exp_pos2x + one);
*rv = ir::Select::make( *rv = ir::SelectNode::make(
x >= make_zero(x.dtype()), tanh_pos, tanh_neg); x >= make_zero(x.dtype()), tanh_pos, tanh_neg);
}); });
......
...@@ -39,34 +39,34 @@ namespace codegen { ...@@ -39,34 +39,34 @@ namespace codegen {
template<unsigned id, int num_signature> template<unsigned id, int num_signature>
inline void DispatchLLVMPureIntrin(const TVMArgs& targs, TVMRetValue* rv) { inline void DispatchLLVMPureIntrin(const TVMArgs& targs, TVMRetValue* rv) {
Expr e = targs[0]; Expr e = targs[0];
const ir::Call* call = e.as<ir::Call>(); const ir::CallNode* call = e.as<ir::CallNode>();
CHECK(call != nullptr); CHECK(call != nullptr);
Array<Expr> cargs; Array<Expr> cargs;
// intrin id. // intrin id.
cargs.push_back(ir::UIntImm::make(DataType::UInt(32), id)); cargs.push_back(ir::UIntImmNode::make(DataType::UInt(32), id));
cargs.push_back(ir::UIntImm::make(DataType::UInt(32), num_signature)); cargs.push_back(ir::UIntImmNode::make(DataType::UInt(32), num_signature));
for (Expr arg : call->args) { for (Expr arg : call->args) {
cargs.push_back(arg); cargs.push_back(arg);
} }
*rv = ir::Call::make( *rv = ir::CallNode::make(
call->dtype, "llvm_intrin", cargs, ir::Call::PureIntrinsic); call->dtype, "llvm_intrin", cargs, ir::CallNode::PureIntrinsic);
} }
template<unsigned id, int num_signature> template<unsigned id, int num_signature>
inline void DispatchLLVMIntrin(const TVMArgs& targs, TVMRetValue* rv) { inline void DispatchLLVMIntrin(const TVMArgs& targs, TVMRetValue* rv) {
Expr e = targs[0]; Expr e = targs[0];
const ir::Call* call = e.as<ir::Call>(); const ir::CallNode* call = e.as<ir::CallNode>();
CHECK(call != nullptr); CHECK(call != nullptr);
Array<Expr> cargs; Array<Expr> cargs;
// intrin id. // intrin id.
cargs.push_back(ir::UIntImm::make(DataType::UInt(32), id)); cargs.push_back(ir::UIntImmNode::make(DataType::UInt(32), id));
cargs.push_back(ir::UIntImm::make(DataType::UInt(32), num_signature)); cargs.push_back(ir::UIntImmNode::make(DataType::UInt(32), num_signature));
for (Expr arg : call->args) { for (Expr arg : call->args) {
cargs.push_back(arg); cargs.push_back(arg);
} }
*rv = ir::Call::make( *rv = ir::CallNode::make(
call->dtype, "llvm_intrin", cargs, ir::Call::Intrinsic); call->dtype, "llvm_intrin", cargs, ir::CallNode::Intrinsic);
} }
} // namespace codegen } // namespace codegen
......
...@@ -35,14 +35,14 @@ namespace codegen { ...@@ -35,14 +35,14 @@ namespace codegen {
inline void DispatchExternLibDevice(const TVMArgs& args, TVMRetValue* rv) { inline void DispatchExternLibDevice(const TVMArgs& args, TVMRetValue* rv) {
Expr e = args[0]; Expr e = args[0];
using namespace ir; using namespace ir;
const Call* call = e.as<Call>(); const CallNode* call = e.as<CallNode>();
CHECK(call != nullptr); CHECK(call != nullptr);
CHECK(call->dtype.bits() == 32 || call->dtype.bits() == 64) << "Only support float32 or float64."; CHECK(call->dtype.bits() == 32 || call->dtype.bits() == 64) << "Only support float32 or float64.";
std::ostringstream intrinsic_name; std::ostringstream intrinsic_name;
intrinsic_name << "__nv_" << call->name; intrinsic_name << "__nv_" << call->name;
if (call->dtype.bits() == 32) intrinsic_name << "f"; if (call->dtype.bits() == 32) intrinsic_name << "f";
*rv = Call::make(call->dtype, intrinsic_name.str(), call->args, *rv = CallNode::make(call->dtype, intrinsic_name.str(), call->args,
Call::PureExtern); CallNode::PureExtern);
} }
namespace llvm { namespace llvm {
......
...@@ -35,12 +35,12 @@ namespace codegen { ...@@ -35,12 +35,12 @@ namespace codegen {
inline void DispatchExternOCML(const TVMArgs& args, TVMRetValue* rv) { inline void DispatchExternOCML(const TVMArgs& args, TVMRetValue* rv) {
Expr e = args[0]; Expr e = args[0];
using namespace ir; using namespace ir;
const Call* call = e.as<Call>(); const CallNode* call = e.as<CallNode>();
CHECK(call != nullptr); CHECK(call != nullptr);
std::ostringstream intrinsic_name; std::ostringstream intrinsic_name;
intrinsic_name << "__ocml_" << call->name << "_f" << call->dtype.bits(); intrinsic_name << "__ocml_" << call->name << "_f" << call->dtype.bits();
*rv = Call::make(call->dtype, intrinsic_name.str(), call->args, *rv = CallNode::make(call->dtype, intrinsic_name.str(), call->args,
Call::PureExtern); CallNode::PureExtern);
} }
namespace llvm { namespace llvm {
......
...@@ -62,45 +62,45 @@ class CodeGenSPIRV: ...@@ -62,45 +62,45 @@ class CodeGenSPIRV:
return VisitExpr(e); return VisitExpr(e);
} }
// override codegen // override codegen
spirv::Value VisitExpr_(const Variable* op) override; spirv::Value VisitExpr_(const VarNode* op) override;
spirv::Value VisitExpr_(const Cast* op) override; spirv::Value VisitExpr_(const CastNode* op) override;
spirv::Value VisitExpr_(const IntImm* op) override; spirv::Value VisitExpr_(const IntImmNode* op) override;
spirv::Value VisitExpr_(const UIntImm* op) override; spirv::Value VisitExpr_(const UIntImmNode* op) override;
spirv::Value VisitExpr_(const FloatImm* op) override; spirv::Value VisitExpr_(const FloatImmNode* op) override;
spirv::Value VisitExpr_(const StringImm* op) override; spirv::Value VisitExpr_(const StringImmNode* op) override;
spirv::Value VisitExpr_(const Add* op) override; spirv::Value VisitExpr_(const AddNode* op) override;
spirv::Value VisitExpr_(const Sub* op) override; spirv::Value VisitExpr_(const SubNode* op) override;
spirv::Value VisitExpr_(const Mul* op) override; spirv::Value VisitExpr_(const MulNode* op) override;
spirv::Value VisitExpr_(const Div* op) override; spirv::Value VisitExpr_(const DivNode* op) override;
spirv::Value VisitExpr_(const Mod* op) override; spirv::Value VisitExpr_(const ModNode* op) override;
spirv::Value VisitExpr_(const Min* op) override; spirv::Value VisitExpr_(const MinNode* op) override;
spirv::Value VisitExpr_(const Max* op) override; spirv::Value VisitExpr_(const MaxNode* op) override;
spirv::Value VisitExpr_(const LT* op) override; spirv::Value VisitExpr_(const LTNode* op) override;
spirv::Value VisitExpr_(const LE* op) override; spirv::Value VisitExpr_(const LENode* op) override;
spirv::Value VisitExpr_(const GT* op) override; spirv::Value VisitExpr_(const GTNode* op) override;
spirv::Value VisitExpr_(const GE* op) override; spirv::Value VisitExpr_(const GENode* op) override;
spirv::Value VisitExpr_(const EQ* op) override; spirv::Value VisitExpr_(const EQNode* op) override;
spirv::Value VisitExpr_(const NE* op) override; spirv::Value VisitExpr_(const NENode* op) override;
spirv::Value VisitExpr_(const And* op) override; spirv::Value VisitExpr_(const AndNode* op) override;
spirv::Value VisitExpr_(const Or* op) override; spirv::Value VisitExpr_(const OrNode* op) override;
spirv::Value VisitExpr_(const Not* op) override; spirv::Value VisitExpr_(const NotNode* op) override;
spirv::Value VisitExpr_(const Select* op) override; spirv::Value VisitExpr_(const SelectNode* op) override;
spirv::Value VisitExpr_(const Let* op) override; spirv::Value VisitExpr_(const LetNode* op) override;
spirv::Value VisitExpr_(const Call* op) override; spirv::Value VisitExpr_(const CallNode* op) override;
spirv::Value VisitExpr_(const Ramp* op) override; spirv::Value VisitExpr_(const RampNode* op) override;
spirv::Value VisitExpr_(const Broadcast* op) override; spirv::Value VisitExpr_(const BroadcastNode* op) override;
spirv::Value VisitExpr_(const Load* op) override; spirv::Value VisitExpr_(const LoadNode* op) override;
// stmt // stmt
void VisitStmt_(const Store* op) override; void VisitStmt_(const StoreNode* op) override;
void VisitStmt_(const For* op) override; void VisitStmt_(const ForNode* op) override;
void VisitStmt_(const IfThenElse* op) override; void VisitStmt_(const IfThenElseNode* op) override;
void VisitStmt_(const Allocate* op) override; void VisitStmt_(const AllocateNode* op) override;
void VisitStmt_(const AttrStmt* op) override; void VisitStmt_(const AttrStmtNode* op) override;
void VisitStmt_(const AssertStmt* op) override; void VisitStmt_(const AssertStmtNode* op) override;
void VisitStmt_(const LetStmt* op) override; void VisitStmt_(const LetStmtNode* op) override;
void VisitStmt_(const SeqStmtNode* op) override; void VisitStmt_(const SeqStmtNode* op) override;
void VisitStmt_(const Evaluate* op) override; void VisitStmt_(const EvaluateNode* op) override;
void VisitStmt_(const ProducerConsumer* op) override; void VisitStmt_(const ProducerConsumerNode* op) override;
protected: protected:
/*! \brief The storage information */ /*! \brief The storage information */
...@@ -129,7 +129,7 @@ class CodeGenSPIRV: ...@@ -129,7 +129,7 @@ class CodeGenSPIRV:
void InitFuncState(); void InitFuncState();
// Get the thread index // Get the thread index
spirv::Value GetThreadIndex(const IterVar& iv, const Expr& extent); spirv::Value GetThreadIndex(const IterVar& iv, const Expr& extent);
spirv::Value CreateStorageSync(const Call* op); spirv::Value CreateStorageSync(const CallNode* op);
void Scalarize(const Expr& e, void Scalarize(const Expr& e,
std::function<void(int i, spirv::Value v)> f); std::function<void(int i, spirv::Value v)> f);
// The builder // The builder
...@@ -139,9 +139,9 @@ class CodeGenSPIRV: ...@@ -139,9 +139,9 @@ class CodeGenSPIRV:
// Likely branch // Likely branch
uint32_t weight_likely_branch_{128}; uint32_t weight_likely_branch_{128};
// the storage scope of allocation // the storage scope of allocation
std::unordered_map<const Variable*, StorageInfo> storage_info_; std::unordered_map<const VarNode*, StorageInfo> storage_info_;
// The definition of local variable. // The definition of local variable.
std::unordered_map<const Variable*, spirv::Value> var_map_; std::unordered_map<const VarNode*, spirv::Value> var_map_;
// The analyzer. // The analyzer.
std::unique_ptr<arith::Analyzer> analyzer_; std::unique_ptr<arith::Analyzer> analyzer_;
}; };
......
...@@ -35,17 +35,17 @@ using namespace runtime; ...@@ -35,17 +35,17 @@ using namespace runtime;
template<unsigned id> template<unsigned id>
inline void DispatchGLSLPureIntrin(const TVMArgs& targs, TVMRetValue* rv) { inline void DispatchGLSLPureIntrin(const TVMArgs& targs, TVMRetValue* rv) {
Expr e = targs[0]; Expr e = targs[0];
const ir::Call* call = e.as<ir::Call>(); const ir::CallNode* call = e.as<ir::CallNode>();
CHECK(call != nullptr); CHECK(call != nullptr);
Array<Expr> cargs; Array<Expr> cargs;
// intrin id. // intrin id.
cargs.push_back(ir::UIntImm::make(DataType::UInt(32), id)); cargs.push_back(ir::UIntImmNode::make(DataType::UInt(32), id));
for (Expr arg : call->args) { for (Expr arg : call->args) {
cargs.push_back(arg); cargs.push_back(arg);
} }
*rv = ir::Call::make( *rv = ir::CallNode::make(
call->dtype, "spirv_glsl450", cargs, ir::Call::PureIntrinsic); call->dtype, "spirv_glsl450", cargs, ir::CallNode::PureIntrinsic);
} }
TVM_REGISTER_GLOBAL("tvm.intrin.rule.vulkan.floor") TVM_REGISTER_GLOBAL("tvm.intrin.rule.vulkan.floor")
......
...@@ -96,13 +96,13 @@ class CodeGenStackVM ...@@ -96,13 +96,13 @@ class CodeGenStackVM
* \param v The variable. * \param v The variable.
* \return the heap index of the var. * \return the heap index of the var.
*/ */
int AllocVarID(const Variable* v); int AllocVarID(const VarNode* v);
/*! /*!
* \brief Get a variable name. * \brief Get a variable name.
* \param v The variable. * \param v The variable.
* \return the heap index of the var. * \return the heap index of the var.
*/ */
int GetVarID(const Variable* v) const; int GetVarID(const VarNode* v) const;
// Push binary operator // Push binary operator
void PushBinary(StackVM::OpCode op_int64, void PushBinary(StackVM::OpCode op_int64,
const Expr& a, const Expr& a,
...@@ -111,52 +111,52 @@ class CodeGenStackVM ...@@ -111,52 +111,52 @@ class CodeGenStackVM
void PushCast(DataType dst, DataType src); void PushCast(DataType dst, DataType src);
// overloadable functions // overloadable functions
// expression // expression
void VisitExpr_(const Variable* op) final; void VisitExpr_(const VarNode* op) final;
void VisitExpr_(const Load* op) final; void VisitExpr_(const LoadNode* op) final;
void VisitExpr_(const Let* op) final; void VisitExpr_(const LetNode* op) final;
void VisitExpr_(const Call* op) final; void VisitExpr_(const CallNode* op) final;
void VisitExpr_(const Add* op) final; void VisitExpr_(const AddNode* op) final;
void VisitExpr_(const Sub* op) final; void VisitExpr_(const SubNode* op) final;
void VisitExpr_(const Mul* op) final; void VisitExpr_(const MulNode* op) final;
void VisitExpr_(const Div* op) final; void VisitExpr_(const DivNode* op) final;
void VisitExpr_(const Mod* op) final; void VisitExpr_(const ModNode* op) final;
void VisitExpr_(const Min* op) final; void VisitExpr_(const MinNode* op) final;
void VisitExpr_(const Max* op) final; void VisitExpr_(const MaxNode* op) final;
void VisitExpr_(const EQ* op) final; void VisitExpr_(const EQNode* op) final;
void VisitExpr_(const NE* op) final; void VisitExpr_(const NENode* op) final;
void VisitExpr_(const LT* op) final; void VisitExpr_(const LTNode* op) final;
void VisitExpr_(const LE* op) final; void VisitExpr_(const LENode* op) final;
void VisitExpr_(const GT* op) final; void VisitExpr_(const GTNode* op) final;
void VisitExpr_(const GE* op) final; void VisitExpr_(const GENode* op) final;
void VisitExpr_(const And* op) final; void VisitExpr_(const AndNode* op) final;
void VisitExpr_(const Or* op) final; void VisitExpr_(const OrNode* op) final;
void VisitExpr_(const Cast* op) final; void VisitExpr_(const CastNode* op) final;
void VisitExpr_(const Not* op) final; void VisitExpr_(const NotNode* op) final;
void VisitExpr_(const Select* op) final; void VisitExpr_(const SelectNode* op) final;
void VisitExpr_(const Ramp* op) final; void VisitExpr_(const RampNode* op) final;
void VisitExpr_(const Broadcast* op) final; void VisitExpr_(const BroadcastNode* op) final;
void VisitExpr_(const IntImm* op) final; void VisitExpr_(const IntImmNode* op) final;
void VisitExpr_(const UIntImm* op) final; void VisitExpr_(const UIntImmNode* op) final;
void VisitExpr_(const FloatImm* op) final; void VisitExpr_(const FloatImmNode* op) final;
void VisitExpr_(const StringImm* op) final; void VisitExpr_(const StringImmNode* op) final;
// statment // statment
void VisitStmt_(const LetStmt* op) final; void VisitStmt_(const LetStmtNode* op) final;
void VisitStmt_(const Store* op) final; void VisitStmt_(const StoreNode* op) final;
void VisitStmt_(const For* op) final; void VisitStmt_(const ForNode* op) final;
void VisitStmt_(const IfThenElse* op) final; void VisitStmt_(const IfThenElseNode* op) final;
void VisitStmt_(const Allocate* op) final; void VisitStmt_(const AllocateNode* op) final;
void VisitStmt_(const AttrStmt* op) final; void VisitStmt_(const AttrStmtNode* op) final;
void VisitStmt_(const AssertStmt* op) final; void VisitStmt_(const AssertStmtNode* op) final;
void VisitStmt_(const Evaluate* op) final; void VisitStmt_(const EvaluateNode* op) final;
void VisitStmt_(const SeqStmtNode* op) final; void VisitStmt_(const SeqStmtNode* op) final;
void VisitStmt_(const ProducerConsumer* op) final; void VisitStmt_(const ProducerConsumerNode* op) final;
private: private:
bool debug_{false}; bool debug_{false};
/*! \brief The vm to be generated */ /*! \brief The vm to be generated */
StackVM vm_; StackVM vm_;
/*! \brief id of each variable */ /*! \brief id of each variable */
std::unordered_map<const Variable*, int> var_idmap_; std::unordered_map<const VarNode*, int> var_idmap_;
/*! \brief id of each string */ /*! \brief id of each string */
std::unordered_map<std::string, int> str_idmap_; std::unordered_map<std::string, int> str_idmap_;
/*! \brief id of each global function */ /*! \brief id of each global function */
......
...@@ -90,49 +90,49 @@ class CodeGenHybrid : ...@@ -90,49 +90,49 @@ class CodeGenHybrid :
return os.str(); return os.str();
} }
// expression // expression
void VisitExpr_(const Variable* op, std::ostream& os) override; // NOLINT(*) void VisitExpr_(const VarNode* op, std::ostream& os) override; // NOLINT(*)
void VisitExpr_(const Load* op, std::ostream& os) override; // NOLINT(*) void VisitExpr_(const LoadNode* op, std::ostream& os) override; // NOLINT(*)
void VisitExpr_(const Let* op, std::ostream& os) override; // NOLINT(*) void VisitExpr_(const LetNode* op, std::ostream& os) override; // NOLINT(*)
void VisitExpr_(const Call* op, std::ostream& os) override; // NOLINT(*) void VisitExpr_(const CallNode* op, std::ostream& os) override; // NOLINT(*)
void VisitExpr_(const Add* op, std::ostream& os) override; // NOLINT(*) void VisitExpr_(const AddNode* op, std::ostream& os) override; // NOLINT(*)
void VisitExpr_(const Sub* op, std::ostream& os) override; // NOLINT(*) void VisitExpr_(const SubNode* op, std::ostream& os) override; // NOLINT(*)
void VisitExpr_(const Mul* op, std::ostream& os) override; // NOLINT(*) void VisitExpr_(const MulNode* op, std::ostream& os) override; // NOLINT(*)
void VisitExpr_(const Div* op, std::ostream& os) override; // NOLINT(*) void VisitExpr_(const DivNode* op, std::ostream& os) override; // NOLINT(*)
void VisitExpr_(const Mod* op, std::ostream& os) override; // NOLINT(*) void VisitExpr_(const ModNode* op, std::ostream& os) override; // NOLINT(*)
void VisitExpr_(const FloorDiv* op, std::ostream& os) override; // NOLINT(*) void VisitExpr_(const FloorDivNode* op, std::ostream& os) override; // NOLINT(*)
void VisitExpr_(const FloorMod* op, std::ostream& os) override; // NOLINT(*) void VisitExpr_(const FloorModNode* op, std::ostream& os) override; // NOLINT(*)
void VisitExpr_(const Min* op, std::ostream& os) override; // NOLINT(*) void VisitExpr_(const MinNode* op, std::ostream& os) override; // NOLINT(*)
void VisitExpr_(const Max* op, std::ostream& os) override; // NOLINT(*) void VisitExpr_(const MaxNode* op, std::ostream& os) override; // NOLINT(*)
void VisitExpr_(const EQ* op, std::ostream& os) override; // NOLINT(*) void VisitExpr_(const EQNode* op, std::ostream& os) override; // NOLINT(*)
void VisitExpr_(const NE* op, std::ostream& os) override; // NOLINT(*) void VisitExpr_(const NENode* op, std::ostream& os) override; // NOLINT(*)
void VisitExpr_(const LT* op, std::ostream& os) override; // NOLINT(*) void VisitExpr_(const LTNode* op, std::ostream& os) override; // NOLINT(*)
void VisitExpr_(const LE* op, std::ostream& os) override; // NOLINT(*) void VisitExpr_(const LENode* op, std::ostream& os) override; // NOLINT(*)
void VisitExpr_(const GT* op, std::ostream& os) override; // NOLINT(*) void VisitExpr_(const GTNode* op, std::ostream& os) override; // NOLINT(*)
void VisitExpr_(const GE* op, std::ostream& os) override; // NOLINT(*) void VisitExpr_(const GENode* op, std::ostream& os) override; // NOLINT(*)
void VisitExpr_(const And* op, std::ostream& os) override; // NOLINT(*) void VisitExpr_(const AndNode* op, std::ostream& os) override; // NOLINT(*)
void VisitExpr_(const Or* op, std::ostream& os) override; // NOLINT(*) void VisitExpr_(const OrNode* op, std::ostream& os) override; // NOLINT(*)
void VisitExpr_(const Cast* op, std::ostream& os) override; // NOLINT(*) void VisitExpr_(const CastNode* op, std::ostream& os) override; // NOLINT(*)
void VisitExpr_(const Not* op, std::ostream& os) override; // NOLINT(*) void VisitExpr_(const NotNode* op, std::ostream& os) override; // NOLINT(*)
void VisitExpr_(const Select* op, std::ostream& os) override; // NOLINT(*) void VisitExpr_(const SelectNode* op, std::ostream& os) override; // NOLINT(*)
void VisitExpr_(const Ramp* op, std::ostream& os) override; // NOLINT(*) void VisitExpr_(const RampNode* op, std::ostream& os) override; // NOLINT(*)
void VisitExpr_(const Broadcast* op, std::ostream& os) override; // NOLINT(*) void VisitExpr_(const BroadcastNode* op, std::ostream& os) override; // NOLINT(*)
void VisitExpr_(const IntImm* op, std::ostream& os) override; // NOLINT(*) void VisitExpr_(const IntImmNode* op, std::ostream& os) override; // NOLINT(*)
void VisitExpr_(const UIntImm* op, std::ostream& os) override; // NOLINT(*) void VisitExpr_(const UIntImmNode* op, std::ostream& os) override; // NOLINT(*)
void VisitExpr_(const FloatImm* op, std::ostream& os) override; // NOLINT(*) void VisitExpr_(const FloatImmNode* op, std::ostream& os) override; // NOLINT(*)
void VisitExpr_(const StringImm* op, std::ostream& os) override; // NOLINT(*) void VisitExpr_(const StringImmNode* op, std::ostream& os) override; // NOLINT(*)
// statment // statment
void VisitStmt_(const LetStmt* op) override; void VisitStmt_(const LetStmtNode* op) override;
void VisitStmt_(const Store* op) override; void VisitStmt_(const StoreNode* op) override;
void VisitStmt_(const Provide* op) override; void VisitStmt_(const ProvideNode* op) override;
void VisitStmt_(const For* op) override; void VisitStmt_(const ForNode* op) override;
void VisitStmt_(const IfThenElse* op) override; void VisitStmt_(const IfThenElseNode* op) override;
void VisitStmt_(const Allocate* op) override; void VisitStmt_(const AllocateNode* op) override;
void VisitStmt_(const Realize* op) override; void VisitStmt_(const RealizeNode* op) override;
void VisitStmt_(const AttrStmt* op) override; void VisitStmt_(const AttrStmtNode* op) override;
void VisitStmt_(const AssertStmt* op) override; void VisitStmt_(const AssertStmtNode* op) override;
void VisitStmt_(const Evaluate* op) override; void VisitStmt_(const EvaluateNode* op) override;
void VisitStmt_(const SeqStmtNode* op) override; void VisitStmt_(const SeqStmtNode* op) override;
void VisitStmt_(const ProducerConsumer* op) override; void VisitStmt_(const ProducerConsumerNode* op) override;
/*! /*!
* \brief Print Type represetnation of type t. * \brief Print Type represetnation of type t.
* \param t The type representation. * \param t The type representation.
...@@ -154,7 +154,7 @@ class CodeGenHybrid : ...@@ -154,7 +154,7 @@ class CodeGenHybrid :
* Values are the corresponding IDs.*/ * Values are the corresponding IDs.*/
std::map<std::pair<const Object *, int>, std::string> id_map_; std::map<std::pair<const Object *, int>, std::string> id_map_;
/*! \brief Variables (keys) binded to the threads (values). */ /*! \brief Variables (keys) binded to the threads (values). */
std::map<const Variable *, std::string> binds_; std::map<const VarNode *, std::string> binds_;
/*! /*!
* \brief Find an unallocated name for the given prefix. * \brief Find an unallocated name for the given prefix.
* \param prefix The given prefix. * \param prefix The given prefix.
...@@ -166,7 +166,7 @@ class CodeGenHybrid : ...@@ -166,7 +166,7 @@ class CodeGenHybrid :
* \brief Get or allocate the ID for the given variable. * \brief Get or allocate the ID for the given variable.
* \param v The given variable. * \param v The given variable.
*/ */
std::string GetVarID(const Variable *v); std::string GetVarID(const VarNode *v);
/*! /*!
* \brief Get or allocate the ID for the given tensor. * \brief Get or allocate the ID for the given tensor.
* \param func The tensor to allocate a name. * \param func The tensor to allocate a name.
......
...@@ -90,29 +90,29 @@ bool AttrsEqualHandler::VisitAttrDefault_(const Object* lhs, const ObjectRef& ot ...@@ -90,29 +90,29 @@ bool AttrsEqualHandler::VisitAttrDefault_(const Object* lhs, const ObjectRef& ot
return lhs == other.get(); return lhs == other.get();
} }
bool AttrsEqualHandler::VisitAttr_(const IntImm* lhs, const ObjectRef& other) { bool AttrsEqualHandler::VisitAttr_(const IntImmNode* lhs, const ObjectRef& other) {
if (const auto* rhs = other.as<IntImm>()) { if (const auto* rhs = other.as<IntImmNode>()) {
return lhs->value == rhs->value; return lhs->value == rhs->value;
} }
return false; return false;
} }
bool AttrsEqualHandler::VisitAttr_(const UIntImm* lhs, const ObjectRef& other) { bool AttrsEqualHandler::VisitAttr_(const UIntImmNode* lhs, const ObjectRef& other) {
if (const auto* rhs = other.as<UIntImm>()) { if (const auto* rhs = other.as<UIntImmNode>()) {
return lhs->value == rhs->value; return lhs->value == rhs->value;
} }
return false; return false;
} }
bool AttrsEqualHandler::VisitAttr_(const FloatImm* lhs, const ObjectRef& other) { bool AttrsEqualHandler::VisitAttr_(const FloatImmNode* lhs, const ObjectRef& other) {
if (const auto* rhs = other.as<FloatImm>()) { if (const auto* rhs = other.as<FloatImmNode>()) {
return lhs->value == rhs->value; return lhs->value == rhs->value;
} }
return false; return false;
} }
bool AttrsEqualHandler::VisitAttr_(const StringImm* lhs, const ObjectRef& other) { bool AttrsEqualHandler::VisitAttr_(const StringImmNode* lhs, const ObjectRef& other) {
if (const auto* rhs = other.as<StringImm>()) { if (const auto* rhs = other.as<StringImmNode>()) {
return lhs->value == rhs->value; return lhs->value == rhs->value;
} }
return false; return false;
...@@ -151,34 +151,34 @@ bool AttrsEqualHandler::VisitAttr_(const StrMapNode* lhs, const ObjectRef& other ...@@ -151,34 +151,34 @@ bool AttrsEqualHandler::VisitAttr_(const StrMapNode* lhs, const ObjectRef& other
} \ } \
} \ } \
TVM_DEFINE_ATTRS_BINOP_EQUAL(Add); TVM_DEFINE_ATTRS_BINOP_EQUAL(AddNode);
TVM_DEFINE_ATTRS_BINOP_EQUAL(Sub); TVM_DEFINE_ATTRS_BINOP_EQUAL(SubNode);
TVM_DEFINE_ATTRS_BINOP_EQUAL(Mul); TVM_DEFINE_ATTRS_BINOP_EQUAL(MulNode);
TVM_DEFINE_ATTRS_BINOP_EQUAL(Div); TVM_DEFINE_ATTRS_BINOP_EQUAL(DivNode);
TVM_DEFINE_ATTRS_BINOP_EQUAL(Mod); TVM_DEFINE_ATTRS_BINOP_EQUAL(ModNode);
TVM_DEFINE_ATTRS_BINOP_EQUAL(FloorDiv); TVM_DEFINE_ATTRS_BINOP_EQUAL(FloorDivNode);
TVM_DEFINE_ATTRS_BINOP_EQUAL(FloorMod); TVM_DEFINE_ATTRS_BINOP_EQUAL(FloorModNode);
TVM_DEFINE_ATTRS_BINOP_EQUAL(Max); TVM_DEFINE_ATTRS_BINOP_EQUAL(MaxNode);
TVM_DEFINE_ATTRS_BINOP_EQUAL(Min); TVM_DEFINE_ATTRS_BINOP_EQUAL(MinNode);
TVM_DEFINE_ATTRS_BINOP_EQUAL(GE); TVM_DEFINE_ATTRS_BINOP_EQUAL(GENode);
TVM_DEFINE_ATTRS_BINOP_EQUAL(GT); TVM_DEFINE_ATTRS_BINOP_EQUAL(GTNode);
TVM_DEFINE_ATTRS_BINOP_EQUAL(LE); TVM_DEFINE_ATTRS_BINOP_EQUAL(LENode);
TVM_DEFINE_ATTRS_BINOP_EQUAL(LT); TVM_DEFINE_ATTRS_BINOP_EQUAL(LTNode);
TVM_DEFINE_ATTRS_BINOP_EQUAL(EQ); TVM_DEFINE_ATTRS_BINOP_EQUAL(EQNode);
TVM_DEFINE_ATTRS_BINOP_EQUAL(NE); TVM_DEFINE_ATTRS_BINOP_EQUAL(NENode);
TVM_DEFINE_ATTRS_BINOP_EQUAL(And); TVM_DEFINE_ATTRS_BINOP_EQUAL(AndNode);
TVM_DEFINE_ATTRS_BINOP_EQUAL(Or); TVM_DEFINE_ATTRS_BINOP_EQUAL(OrNode);
bool AttrsEqualHandler::VisitAttr_(const Not* lhs, const ObjectRef& other) { bool AttrsEqualHandler::VisitAttr_(const NotNode* lhs, const ObjectRef& other) {
if (const auto* rhs = other.as<Not>()) { if (const auto* rhs = other.as<NotNode>()) {
return Equal(lhs->a, rhs->a); return Equal(lhs->a, rhs->a);
} else { } else {
return false; return false;
} }
} }
bool AttrsEqualHandler::VisitAttr_(const Cast* lhs, const ObjectRef& other) { bool AttrsEqualHandler::VisitAttr_(const CastNode* lhs, const ObjectRef& other) {
if (const auto* rhs = other.as<Cast>()) { if (const auto* rhs = other.as<CastNode>()) {
if (lhs->dtype != rhs->dtype) return false; if (lhs->dtype != rhs->dtype) return false;
return Equal(lhs->value, rhs->value); return Equal(lhs->value, rhs->value);
} else { } else {
...@@ -186,8 +186,8 @@ bool AttrsEqualHandler::VisitAttr_(const Cast* lhs, const ObjectRef& other) { ...@@ -186,8 +186,8 @@ bool AttrsEqualHandler::VisitAttr_(const Cast* lhs, const ObjectRef& other) {
} }
} }
bool AttrsEqualHandler::VisitAttr_(const Call* lhs, const ObjectRef& other) { bool AttrsEqualHandler::VisitAttr_(const CallNode* lhs, const ObjectRef& other) {
if (const auto* rhs = other.as<Call>()) { if (const auto* rhs = other.as<CallNode>()) {
return return
lhs->name == rhs->name && lhs->name == rhs->name &&
lhs->dtype == rhs->dtype && lhs->dtype == rhs->dtype &&
...@@ -198,8 +198,8 @@ bool AttrsEqualHandler::VisitAttr_(const Call* lhs, const ObjectRef& other) { ...@@ -198,8 +198,8 @@ bool AttrsEqualHandler::VisitAttr_(const Call* lhs, const ObjectRef& other) {
} }
} }
bool AttrsEqualHandler::VisitAttr_(const Select* lhs, const ObjectRef& other) { bool AttrsEqualHandler::VisitAttr_(const SelectNode* lhs, const ObjectRef& other) {
if (const auto* rhs = other.as<Select>()) { if (const auto* rhs = other.as<SelectNode>()) {
return return
Equal(lhs->condition, rhs->condition) && Equal(lhs->condition, rhs->condition) &&
Equal(lhs->true_value, rhs->true_value) && Equal(lhs->true_value, rhs->true_value) &&
...@@ -220,19 +220,19 @@ size_t AttrsHashHandler::VisitAttrDefault_(const Object* value) { ...@@ -220,19 +220,19 @@ size_t AttrsHashHandler::VisitAttrDefault_(const Object* value) {
} }
} }
size_t AttrsHashHandler::VisitAttr_(const IntImm* op) { size_t AttrsHashHandler::VisitAttr_(const IntImmNode* op) {
return std::hash<int64_t>()(op->value); return std::hash<int64_t>()(op->value);
} }
size_t AttrsHashHandler::VisitAttr_(const UIntImm* op) { size_t AttrsHashHandler::VisitAttr_(const UIntImmNode* op) {
return std::hash<uint64_t>()(op->value); return std::hash<uint64_t>()(op->value);
} }
size_t AttrsHashHandler::VisitAttr_(const FloatImm* op) { size_t AttrsHashHandler::VisitAttr_(const FloatImmNode* op) {
return std::hash<double>()(op->value); return std::hash<double>()(op->value);
} }
size_t AttrsHashHandler::VisitAttr_(const StringImm* op) { size_t AttrsHashHandler::VisitAttr_(const StringImmNode* op) {
return std::hash<std::string>()(op->value); return std::hash<std::string>()(op->value);
} }
...@@ -265,31 +265,31 @@ size_t AttrsHashHandler::VisitAttr_(const StrMapNode* lhs) { ...@@ -265,31 +265,31 @@ size_t AttrsHashHandler::VisitAttr_(const StrMapNode* lhs) {
return Combine(key, Combine(Hash(op->a), Hash(op->b))); \ return Combine(key, Combine(Hash(op->a), Hash(op->b))); \
} \ } \
TVM_DEFINE_ATTRS_BINOP_HASH(Add); TVM_DEFINE_ATTRS_BINOP_HASH(AddNode);
TVM_DEFINE_ATTRS_BINOP_HASH(Sub); TVM_DEFINE_ATTRS_BINOP_HASH(SubNode);
TVM_DEFINE_ATTRS_BINOP_HASH(Mul); TVM_DEFINE_ATTRS_BINOP_HASH(MulNode);
TVM_DEFINE_ATTRS_BINOP_HASH(Div); TVM_DEFINE_ATTRS_BINOP_HASH(DivNode);
TVM_DEFINE_ATTRS_BINOP_HASH(Mod); TVM_DEFINE_ATTRS_BINOP_HASH(ModNode);
TVM_DEFINE_ATTRS_BINOP_HASH(FloorDiv); TVM_DEFINE_ATTRS_BINOP_HASH(FloorDivNode);
TVM_DEFINE_ATTRS_BINOP_HASH(FloorMod); TVM_DEFINE_ATTRS_BINOP_HASH(FloorModNode);
TVM_DEFINE_ATTRS_BINOP_HASH(Max); TVM_DEFINE_ATTRS_BINOP_HASH(MaxNode);
TVM_DEFINE_ATTRS_BINOP_HASH(Min); TVM_DEFINE_ATTRS_BINOP_HASH(MinNode);
TVM_DEFINE_ATTRS_BINOP_HASH(GE); TVM_DEFINE_ATTRS_BINOP_HASH(GENode);
TVM_DEFINE_ATTRS_BINOP_HASH(GT); TVM_DEFINE_ATTRS_BINOP_HASH(GTNode);
TVM_DEFINE_ATTRS_BINOP_HASH(LE); TVM_DEFINE_ATTRS_BINOP_HASH(LENode);
TVM_DEFINE_ATTRS_BINOP_HASH(LT); TVM_DEFINE_ATTRS_BINOP_HASH(LTNode);
TVM_DEFINE_ATTRS_BINOP_HASH(EQ); TVM_DEFINE_ATTRS_BINOP_HASH(EQNode);
TVM_DEFINE_ATTRS_BINOP_HASH(NE); TVM_DEFINE_ATTRS_BINOP_HASH(NENode);
TVM_DEFINE_ATTRS_BINOP_HASH(And); TVM_DEFINE_ATTRS_BINOP_HASH(AndNode);
TVM_DEFINE_ATTRS_BINOP_HASH(Or); TVM_DEFINE_ATTRS_BINOP_HASH(OrNode);
size_t AttrsHashHandler::VisitAttr_(const Not* op) { size_t AttrsHashHandler::VisitAttr_(const NotNode* op) {
static size_t key = std::hash<std::string>()(Not::_type_key); static size_t key = std::hash<std::string>()(NotNode::_type_key);
return Combine(key, Hash(op->a)); return Combine(key, Hash(op->a));
} }
size_t AttrsHashHandler::VisitAttr_(const Cast* op) { size_t AttrsHashHandler::VisitAttr_(const CastNode* op) {
static size_t key = std::hash<std::string>()(Cast::_type_key); static size_t key = std::hash<std::string>()(CastNode::_type_key);
AttrsHash hasher; AttrsHash hasher;
size_t res = key; size_t res = key;
res = Combine(res, hasher(op->dtype)); res = Combine(res, hasher(op->dtype));
...@@ -297,8 +297,8 @@ size_t AttrsHashHandler::VisitAttr_(const Cast* op) { ...@@ -297,8 +297,8 @@ size_t AttrsHashHandler::VisitAttr_(const Cast* op) {
return res; return res;
} }
size_t AttrsHashHandler::VisitAttr_(const Call* op) { size_t AttrsHashHandler::VisitAttr_(const CallNode* op) {
static size_t key = std::hash<std::string>()(Call::_type_key); static size_t key = std::hash<std::string>()(CallNode::_type_key);
AttrsHash hasher; AttrsHash hasher;
size_t res = key; size_t res = key;
res = Combine(res, hasher(op->name)); res = Combine(res, hasher(op->name));
...@@ -307,8 +307,8 @@ size_t AttrsHashHandler::VisitAttr_(const Call* op) { ...@@ -307,8 +307,8 @@ size_t AttrsHashHandler::VisitAttr_(const Call* op) {
return res; return res;
} }
size_t AttrsHashHandler::VisitAttr_(const Select* op) { size_t AttrsHashHandler::VisitAttr_(const SelectNode* op) {
static size_t key = std::hash<std::string>()(Select::_type_key); static size_t key = std::hash<std::string>()(SelectNode::_type_key);
size_t res = key; size_t res = key;
res = Combine(res, Hash(op->condition)); res = Combine(res, Hash(op->condition));
res = Combine(res, Hash(op->true_value)); res = Combine(res, Hash(op->true_value));
......
...@@ -31,8 +31,8 @@ ...@@ -31,8 +31,8 @@
namespace tvm { namespace tvm {
// TODO(tqchen): change to floormod/div // TODO(tqchen): change to floormod/div
using IndexMod = ir::FloorMod; using IndexMod = ir::FloorModNode;
using IndexDiv = ir::FloorDiv; using IndexDiv = ir::FloorDivNode;
Array<Expr> SimplifyArray(Array<Expr> array) { Array<Expr> SimplifyArray(Array<Expr> array) {
for (size_t i = 0; i < array.size(); ++i) { for (size_t i = 0; i < array.size(); ++i) {
...@@ -65,7 +65,7 @@ inline std::vector<const Expr*> ExprSplitAddition(const Expr &expr) { ...@@ -65,7 +65,7 @@ inline std::vector<const Expr*> ExprSplitAddition(const Expr &expr) {
while (!split_buffer.empty()) { while (!split_buffer.empty()) {
const Expr* top_ele = split_buffer.top(); const Expr* top_ele = split_buffer.top();
split_buffer.pop(); split_buffer.pop();
auto expr_add_match = top_ele->as<Add>(); auto expr_add_match = top_ele->as<AddNode>();
if (expr_add_match) { if (expr_add_match) {
split_buffer.push(&expr_add_match->b); split_buffer.push(&expr_add_match->b);
split_buffer.push(&expr_add_match->a); split_buffer.push(&expr_add_match->a);
...@@ -88,13 +88,13 @@ inline std::pair<bool, Expr> MergeMulModInner(const Expr &mult_expr, ...@@ -88,13 +88,13 @@ inline std::pair<bool, Expr> MergeMulModInner(const Expr &mult_expr,
const Expr &mod_l_expr, const Expr &mod_l_expr,
const Expr &mod_r_expr) { const Expr &mod_r_expr) {
using namespace ir; using namespace ir;
const Mul* mult_ptr = mult_expr.as<Mul>(); const MulNode* mult_ptr = mult_expr.as<MulNode>();
if (!mult_ptr) return std::make_pair(false, Expr()); if (!mult_ptr) return std::make_pair(false, Expr());
Expr mult_outer = mult_ptr->b; Expr mult_outer = mult_ptr->b;
const Expr* inner = &(mult_ptr->a); const Expr* inner = &(mult_ptr->a);
// 1. Calculate the outer multiplier // 1. Calculate the outer multiplier
while (true) { while (true) {
mult_ptr = inner->as<Mul>(); mult_ptr = inner->as<MulNode>();
if (mult_ptr) { if (mult_ptr) {
inner = &(mult_ptr->a); inner = &(mult_ptr->a);
mult_outer = mult_ptr->b * mult_outer; mult_outer = mult_ptr->b * mult_outer;
...@@ -113,8 +113,8 @@ inline std::pair<bool, Expr> MergeMulModInner(const Expr &mult_expr, ...@@ -113,8 +113,8 @@ inline std::pair<bool, Expr> MergeMulModInner(const Expr &mult_expr,
Expr no_opt_sum; // Sum of the exprs that cannot be optimized Expr no_opt_sum; // Sum of the exprs that cannot be optimized
while (true) { while (true) {
auto inner_div_ptr = search_ptr->as<IndexDiv>(); auto inner_div_ptr = search_ptr->as<IndexDiv>();
auto inner_mult_ptr = search_ptr->as<Mul>(); auto inner_mult_ptr = search_ptr->as<MulNode>();
auto inner_add_ptr = search_ptr->as<Add>(); auto inner_add_ptr = search_ptr->as<AddNode>();
if (!inner_div_ptr && !inner_mult_ptr && !inner_add_ptr) { if (!inner_div_ptr && !inner_mult_ptr && !inner_add_ptr) {
return std::make_pair(false, Expr()); return std::make_pair(false, Expr());
} else if (inner_div_ptr) { } else if (inner_div_ptr) {
...@@ -160,7 +160,7 @@ inline void MergeMulModInsertElements(const std::vector<const Expr*>& eles, ...@@ -160,7 +160,7 @@ inline void MergeMulModInsertElements(const std::vector<const Expr*>& eles,
*has_mod = false; *has_mod = false;
for (const Expr* ele : eles) { for (const Expr* ele : eles) {
auto mod_ptr = ele->as<IndexMod>(); auto mod_ptr = ele->as<IndexMod>();
auto mult_ptr = ele->as<Mul>(); auto mult_ptr = ele->as<MulNode>();
if (mod_ptr) { if (mod_ptr) {
*has_mod = true; *has_mod = true;
mod_exprs->emplace_back(std::make_pair(std::move(mod_ptr->a), std::move(mod_ptr->b))); mod_exprs->emplace_back(std::make_pair(std::move(mod_ptr->a), std::move(mod_ptr->b)));
...@@ -252,7 +252,7 @@ inline Expr ElemOffset(const BufferNode* n, Array<Expr> index) { ...@@ -252,7 +252,7 @@ inline Expr ElemOffset(const BufferNode* n, Array<Expr> index) {
if (n->strides.size() == 0) { if (n->strides.size() == 0) {
// Scalar case // Scalar case
if (n->shape.size() == 0 && index.size() == 1) { if (n->shape.size() == 0 && index.size() == 1) {
auto is_int = index[0].as<IntImm>(); auto is_int = index[0].as<IntImmNode>();
CHECK(is_int && is_int->value == 0); CHECK(is_int && is_int->value == 0);
base = base + index[0]; base = base + index[0];
} else { } else {
...@@ -285,7 +285,7 @@ inline Expr BufferOffset(const BufferNode* n, Array<Expr> index, DataType dtype) ...@@ -285,7 +285,7 @@ inline Expr BufferOffset(const BufferNode* n, Array<Expr> index, DataType dtype)
offset = offset * make_const(offset.dtype(), dtype.lanes()); offset = offset * make_const(offset.dtype(), dtype.lanes());
} }
if (dtype.lanes() != 1) { if (dtype.lanes() != 1) {
return ir::Ramp::make(offset, make_const(offset.dtype(), 1), dtype.lanes()); return ir::RampNode::make(offset, make_const(offset.dtype(), 1), dtype.lanes());
} else { } else {
return offset; return offset;
} }
...@@ -299,13 +299,13 @@ Expr Buffer::vload(Array<Expr> begin, DataType dtype) const { ...@@ -299,13 +299,13 @@ Expr Buffer::vload(Array<Expr> begin, DataType dtype) const {
<< "Cannot load " << dtype << "Cannot load " << dtype
<< " from buffer of " << n->dtype; << " from buffer of " << n->dtype;
if (dtype == DataType::Bool()) { if (dtype == DataType::Bool()) {
return ir::Cast::make( return ir::CastNode::make(
DataType::Bool(), DataType::Bool(),
ir::Load::make( ir::LoadNode::make(
DataType::Int(8), n->data, BufferOffset(n, begin, DataType::Int(8)), DataType::Int(8), n->data, BufferOffset(n, begin, DataType::Int(8)),
const_true())); const_true()));
} else { } else {
return ir::Load::make( return ir::LoadNode::make(
dtype, n->data, BufferOffset(n, begin, dtype), dtype, n->data, BufferOffset(n, begin, dtype),
const_true(dtype.lanes())); const_true(dtype.lanes()));
} }
...@@ -320,12 +320,12 @@ Stmt Buffer::vstore(Array<Expr> begin, Expr value) const { ...@@ -320,12 +320,12 @@ Stmt Buffer::vstore(Array<Expr> begin, Expr value) const {
<< "Cannot load " << dtype << "Cannot load " << dtype
<< " from buffer of " << n->dtype; << " from buffer of " << n->dtype;
if (value.dtype() == DataType::Bool()) { if (value.dtype() == DataType::Bool()) {
return ir::Store::make(n->data, return ir::StoreNode::make(n->data,
ir::Cast::make(DataType::Int(8), value), ir::CastNode::make(DataType::Int(8), value),
BufferOffset(n, begin, DataType::Int(8)), BufferOffset(n, begin, DataType::Int(8)),
const_true()); const_true());
} else { } else {
return ir::Store::make(n->data, value, BufferOffset(n, begin, dtype), return ir::StoreNode::make(n->data, value, BufferOffset(n, begin, dtype),
const_true(dtype.lanes())); const_true(dtype.lanes()));
} }
} }
...@@ -391,7 +391,7 @@ Expr Buffer::access_ptr(int access_mask, DataType ptr_type, int content_lanes, E ...@@ -391,7 +391,7 @@ Expr Buffer::access_ptr(int access_mask, DataType ptr_type, int content_lanes, E
int highest_dim = 0; int highest_dim = 0;
extent = self->strides[highest_dim] * self->shape[highest_dim] - offset; extent = self->strides[highest_dim] * self->shape[highest_dim] - offset;
} else { } else {
extent = arith::ComputeReduce<ir::Mul>(self->shape, Expr()) - offset; extent = arith::ComputeReduce<ir::MulNode>(self->shape, Expr()) - offset;
} }
Expr elem_offset = self->elem_offset + offset; Expr elem_offset = self->elem_offset + offset;
if (content_lanes > 1) { if (content_lanes > 1) {
...@@ -405,8 +405,8 @@ Expr Buffer::access_ptr(int access_mask, DataType ptr_type, int content_lanes, E ...@@ -405,8 +405,8 @@ Expr Buffer::access_ptr(int access_mask, DataType ptr_type, int content_lanes, E
Array<Expr> acc_args{ Array<Expr> acc_args{
e_dtype, self->data, elem_offset, e_dtype, self->data, elem_offset,
extent, make_const(DataType::Int(32), access_mask)}; extent, make_const(DataType::Int(32), access_mask)};
return ir::Call::make( return ir::CallNode::make(
ptr_type, ir::intrinsic::tvm_access_ptr, acc_args, ir::Call::Intrinsic); ptr_type, ir::intrinsic::tvm_access_ptr, acc_args, ir::CallNode::Intrinsic);
} }
Buffer BufferNode::make(Var data, Buffer BufferNode::make(Var data,
......
Markdown is supported
0% or
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment