Unverified Commit cf59b206 by Tianqi Chen Committed by GitHub

[REFACTOR] Establish tir (#4740)

TIR is the new namespace for low-level IR
for tensor-level optimizations and loop transformations.

This PR establishes the namespace and files.

- lowered_func.h,buffer.h,data_layout.h -> tir/buffer.h,tir/data_layout.h,tir/lowered_func.h
- ir.h -> tir/expr.h, tir/stmt.h
- ir_functor_ext.h -> tir/expr_functor.h, tir/stmt_functor.h
parent 7e392019
...@@ -132,12 +132,11 @@ file(GLOB_RECURSE COMPILER_SRCS ...@@ -132,12 +132,11 @@ file(GLOB_RECURSE COMPILER_SRCS
src/top/*.cc src/top/*.cc
src/api/*.cc src/api/*.cc
src/autotvm/*.cc src/autotvm/*.cc
src/lang/*.cc src/tir/*.cc
src/pass/*.cc
) )
file(GLOB CODEGEN_SRCS file(GLOB CODEGEN_SRCS
src/codegen/*.cc src/codegen/*.cc
) )
list(APPEND COMPILER_SRCS ${CODEGEN_SRCS}) list(APPEND COMPILER_SRCS ${CODEGEN_SRCS})
......
...@@ -27,9 +27,10 @@ ...@@ -27,9 +27,10 @@
#include <tvm/runtime/registry.h> #include <tvm/runtime/registry.h>
#include <tvm/runtime/ndarray.h> #include <tvm/runtime/ndarray.h>
#include <tvm/runtime/device_api.h> #include <tvm/runtime/device_api.h>
#include <tvm/expr_operator.h> #include <tvm/tir/op.h>
using namespace tvm; using namespace tvm;
using namespace tvm::tir;
using namespace tvm::runtime; using namespace tvm::runtime;
namespace tvm_ext { namespace tvm_ext {
......
...@@ -46,7 +46,7 @@ def __lldb_init_module(debugger, _): ...@@ -46,7 +46,7 @@ def __lldb_init_module(debugger, _):
"tvm::IterVarAttr", "tvm::IterVarAttr",
"tvm::IterVarRelation", "tvm::IterVarRelation",
"tvm::Layout", "tvm::Layout",
"tvm::LoweredFunc", "tir::LoweredFunc",
"tvm::Map", "tvm::Map",
"tvm::Map", "tvm::Map",
"tvm::MemoryInfo", "tvm::MemoryInfo",
...@@ -60,7 +60,7 @@ def __lldb_init_module(debugger, _): ...@@ -60,7 +60,7 @@ def __lldb_init_module(debugger, _):
"tvm::TensorIntrin", "tvm::TensorIntrin",
"tvm::TensorIntrinCall", "tvm::TensorIntrinCall",
"tvm::TypedEnvFunc", "tvm::TypedEnvFunc",
"tvm::Var", "tvm::tir::Var",
"tvm::ir::CommReducer", "tvm::ir::CommReducer",
"tvm::ir::FunctionRef", "tvm::ir::FunctionRef",
"tvm::relay::BaseTensorType", "tvm::relay::BaseTensorType",
......
...@@ -50,6 +50,8 @@ namespace arith { ...@@ -50,6 +50,8 @@ namespace arith {
// Forward declare Analyzer // Forward declare Analyzer
class Analyzer; class Analyzer;
using tir::Var;
/*! /*!
* \brief Constant integer up and lower bound(inclusive). * \brief Constant integer up and lower bound(inclusive).
* Useful for value bound analysis. * Useful for value bound analysis.
......
...@@ -26,7 +26,8 @@ ...@@ -26,7 +26,8 @@
#include <tvm/node/container.h> #include <tvm/node/container.h>
#include <tvm/ir/expr.h> #include <tvm/ir/expr.h>
#include <tvm/arith/int_set.h> #include <tvm/arith/int_set.h>
#include <tvm/expr.h> #include <tvm/tir/expr.h>
#include <tvm/tir/stmt.h>
#include <unordered_map> #include <unordered_map>
...@@ -37,6 +38,11 @@ class Tensor; ...@@ -37,6 +38,11 @@ class Tensor;
} }
namespace arith { namespace arith {
using tir::Var;
using tir::VarNode;
using tir::Domain;
using tir::Stmt;
/*! /*!
* \brief Deduce the bound of the target variable in a expression, * \brief Deduce the bound of the target variable in a expression,
* give the domain of each variables. Return undefined IntSet to * give the domain of each variables. Return undefined IntSet to
......
...@@ -25,12 +25,16 @@ ...@@ -25,12 +25,16 @@
#define TVM_ARITH_INT_SET_H_ #define TVM_ARITH_INT_SET_H_
#include <tvm/ir/expr.h> #include <tvm/ir/expr.h>
#include <tvm/expr.h> #include <tvm/tir/expr.h>
#include <unordered_map> #include <unordered_map>
namespace tvm { namespace tvm {
namespace arith { namespace arith {
using tir::Var;
using tir::VarNode;
using tir::IterVar;
//----------------------------------------------- //-----------------------------------------------
// Integer set data structure. // Integer set data structure.
// //
...@@ -165,7 +169,7 @@ IntSet EvalSet(PrimExpr e, ...@@ -165,7 +169,7 @@ IntSet EvalSet(PrimExpr e,
* \return An integer set that can cover all the possible values of e. * \return An integer set that can cover all the possible values of e.
*/ */
IntSet EvalSet(PrimExpr e, IntSet EvalSet(PrimExpr e,
const std::unordered_map<const VarNode*, IntSet>& dom_map); const std::unordered_map<const tir::VarNode*, IntSet>& dom_map);
/*! /*!
* \brief Find an symbolic integer set that contains is union over * \brief Find an symbolic integer set that contains is union over
......
...@@ -26,7 +26,7 @@ ...@@ -26,7 +26,7 @@
#include <tvm/node/container.h> #include <tvm/node/container.h>
#include <tvm/ir/expr.h> #include <tvm/ir/expr.h>
#include <tvm/expr.h> #include <tvm/tir/expr.h>
namespace tvm { namespace tvm {
namespace arith { namespace arith {
...@@ -39,7 +39,7 @@ namespace arith { ...@@ -39,7 +39,7 @@ namespace arith {
* \return [coeff[i]] if it is possible, empty array if it is not. * \return [coeff[i]] if it is possible, empty array if it is not.
*/ */
Array<PrimExpr> DetectLinearEquation(const PrimExpr& e, Array<PrimExpr> DetectLinearEquation(const PrimExpr& e,
const Array<Var>& vars); const Array<tir::Var>& vars);
/*! /*!
* \brief Detect if expression corresponds to clip bound of the vars * \brief Detect if expression corresponds to clip bound of the vars
...@@ -50,7 +50,7 @@ Array<PrimExpr> DetectLinearEquation(const PrimExpr& e, ...@@ -50,7 +50,7 @@ Array<PrimExpr> DetectLinearEquation(const PrimExpr& e,
* return empty if the e does not match the pattern. * return empty if the e does not match the pattern.
*/ */
Array<PrimExpr> DetectClipBound(const PrimExpr& e, Array<PrimExpr> DetectClipBound(const PrimExpr& e,
const Array<Var>& vars); const Array<tir::Var>& vars);
} // namespace arith } // namespace arith
} // namespace tvm } // namespace tvm
......
...@@ -24,9 +24,11 @@ ...@@ -24,9 +24,11 @@
#ifndef TVM_BUILD_MODULE_H_ #ifndef TVM_BUILD_MODULE_H_
#define TVM_BUILD_MODULE_H_ #define TVM_BUILD_MODULE_H_
#include <tvm/runtime/packed_func.h>
#include <tvm/target/target.h> #include <tvm/target/target.h>
#include <tvm/support/with.h> #include <tvm/support/with.h>
#include <tvm/top/schedule_pass.h> #include <tvm/top/schedule_pass.h>
#include <tvm/tir/lowered_func.h>
#include <string> #include <string>
#include <vector> #include <vector>
...@@ -34,10 +36,6 @@ ...@@ -34,10 +36,6 @@
#include <unordered_map> #include <unordered_map>
#include <unordered_set> #include <unordered_set>
#include "runtime/packed_func.h"
#include "lowered_func.h"
namespace tvm { namespace tvm {
/*! /*!
...@@ -174,11 +172,12 @@ class BuildConfig : public ::tvm::ObjectRef { ...@@ -174,11 +172,12 @@ class BuildConfig : public ::tvm::ObjectRef {
* \param config The build configuration. * \param config The build configuration.
* \return The lowered function. * \return The lowered function.
*/ */
TVM_DLL Array<LoweredFunc> lower(top::Schedule sch, TVM_DLL Array<tir::LoweredFunc> lower(
const Array<top::Tensor>& args, top::Schedule sch,
const std::string& name, const Array<top::Tensor>& args,
const std::unordered_map<top::Tensor, Buffer>& binds, const std::string& name,
const BuildConfig& config); const std::unordered_map<top::Tensor, tir::Buffer>& binds,
const BuildConfig& config);
/*! /*!
* \brief Split host/device function and running necessary pass before build * \brief Split host/device function and running necessary pass before build
* \param funcs The functions to be built. * \param funcs The functions to be built.
...@@ -188,10 +187,11 @@ TVM_DLL Array<LoweredFunc> lower(top::Schedule sch, ...@@ -188,10 +187,11 @@ TVM_DLL Array<LoweredFunc> lower(top::Schedule sch,
* \return The Array<Array<LoweredFunc>> with 2 elements. First is host function Array, * \return The Array<Array<LoweredFunc>> with 2 elements. First is host function Array,
second is device function array second is device function array
*/ */
TVM_DLL Array<Array<LoweredFunc> > split_dev_host_funcs(const Array<LoweredFunc>& funcs, TVM_DLL Array<Array<tir::LoweredFunc> > split_dev_host_funcs(
const Target& target, const Array<tir::LoweredFunc>& funcs,
const Target& target_host, const Target& target,
const BuildConfig& config); const Target& target_host,
const BuildConfig& config);
/*! /*!
* \brief Build a device and host module for a specific target from an array of lowered functions. * \brief Build a device and host module for a specific target from an array of lowered functions.
...@@ -201,7 +201,7 @@ TVM_DLL Array<Array<LoweredFunc> > split_dev_host_funcs(const Array<LoweredFunc> ...@@ -201,7 +201,7 @@ TVM_DLL Array<Array<LoweredFunc> > split_dev_host_funcs(const Array<LoweredFunc>
* \param config The build configuration. * \param config The build configuration.
* \return The built module. * \return The built module.
*/ */
TVM_DLL runtime::Module build(const Array<LoweredFunc>& funcs, TVM_DLL runtime::Module build(const Array<tir::LoweredFunc>& funcs,
const Target& target, const Target& target,
const Target& target_host, const Target& target_host,
const BuildConfig& config); const BuildConfig& config);
...@@ -216,7 +216,7 @@ TVM_DLL runtime::Module build(const Array<LoweredFunc>& funcs, ...@@ -216,7 +216,7 @@ TVM_DLL runtime::Module build(const Array<LoweredFunc>& funcs,
* \param config The build configuration. * \param config The build configuration.
* \return The built module that contains code for different processors. * \return The built module that contains code for different processors.
*/ */
TVM_DLL runtime::Module build(const Map<Target, Array<LoweredFunc>>& input, TVM_DLL runtime::Module build(const Map<Target, Array<tir::LoweredFunc>>& input,
const Target& target_host, const Target& target_host,
const BuildConfig& config); const BuildConfig& config);
...@@ -231,7 +231,7 @@ TVM_DLL runtime::Module build(const Map<Target, Array<LoweredFunc>>& input, ...@@ -231,7 +231,7 @@ TVM_DLL runtime::Module build(const Map<Target, Array<LoweredFunc>>& input,
* \param config The build configuration. * \param config The build configuration.
* \return The built module that contains code for different processors. * \return The built module that contains code for different processors.
*/ */
TVM_DLL runtime::Module build(const Map<std::string, Array<LoweredFunc>>& input, TVM_DLL runtime::Module build(const Map<std::string, Array<tir::LoweredFunc>>& input,
const Target& target_host, const Target& target_host,
const BuildConfig& config); const BuildConfig& config);
......
...@@ -24,10 +24,11 @@ ...@@ -24,10 +24,11 @@
#ifndef TVM_CODEGEN_H_ #ifndef TVM_CODEGEN_H_
#define TVM_CODEGEN_H_ #define TVM_CODEGEN_H_
#include <tvm/runtime/packed_func.h>
#include <tvm/tir/expr.h>
#include <tvm/tir/lowered_func.h>
#include <string> #include <string>
#include "expr.h"
#include "lowered_func.h"
#include "runtime/packed_func.h"
namespace tvm { namespace tvm {
/*! \brief namespace for lowlevel IR pass and codegen */ /*! \brief namespace for lowlevel IR pass and codegen */
...@@ -45,7 +46,7 @@ using runtime::TVMRetValue; ...@@ -45,7 +46,7 @@ using runtime::TVMRetValue;
* *
* \note Calls global API function "_codegen_build_" + target * \note Calls global API function "_codegen_build_" + target
*/ */
runtime::Module Build(const Array<LoweredFunc>& funcs, runtime::Module Build(const Array<tir::LoweredFunc>& funcs,
const std::string& target); const std::string& target);
/*! /*!
* \brief Pack imported device library to a C file. * \brief Pack imported device library to a C file.
......
...@@ -30,6 +30,7 @@ ...@@ -30,6 +30,7 @@
#include <tvm/ir/span.h> #include <tvm/ir/span.h>
#include <tvm/ir/type.h> #include <tvm/ir/type.h>
#include <string> #include <string>
#include <algorithm>
#include <limits> #include <limits>
namespace tvm { namespace tvm {
...@@ -123,76 +124,6 @@ class PrimExpr : public BaseExpr { ...@@ -123,76 +124,6 @@ class PrimExpr : public BaseExpr {
}; };
/*! /*!
* \brief Constant integer literals in the program.
* \sa IntImm
*/
class IntImmNode : public PrimExprNode {
public:
/*! \brief the Internal value. */
int64_t value;
void VisitAttrs(AttrVisitor* v) {
v->Visit("dtype", &dtype);
v->Visit("value", &value);
}
static constexpr const char* _type_key = "IntImm";
TVM_DECLARE_FINAL_OBJECT_INFO(IntImmNode, PrimExprNode);
};
/*!
* \brief Managed reference class to IntImmNode.
*
* \sa IntImmNode
*/
class IntImm : public PrimExpr {
public:
/*!
* \brief Constructor.
* \param dtype The data type of the value.
* \param value The internal value.
*/
TVM_DLL IntImm(DataType dtype, int64_t value);
TVM_DEFINE_OBJECT_REF_METHODS(IntImm, PrimExpr, IntImmNode);
};
/*!
* \brief Constant floating point literals in the program.
* \sa FloatImm
*/
class FloatImmNode : public PrimExprNode {
public:
/*! \brief The constant value content. */
double value;
void VisitAttrs(AttrVisitor* v) {
v->Visit("dtype", &dtype);
v->Visit("value", &value);
}
static constexpr const char* _type_key = "FloatImm";
TVM_DECLARE_FINAL_OBJECT_INFO(FloatImmNode, PrimExprNode);
};
/*!
* \brief Managed reference class to FloatImmNode.
*
* \sa FloatImmNode
*/
class FloatImm : public PrimExpr {
public:
/*!
* \brief Constructor.
* \param dtype The data type of the value.
* \param value The internal value.
*/
TVM_DLL FloatImm(DataType dtype, double value);
TVM_DEFINE_OBJECT_REF_METHODS(FloatImm, PrimExpr, FloatImmNode);
};
/*!
* \brief Base node of all non-primitive expressions. * \brief Base node of all non-primitive expressions.
* *
* RelayExpr supports tensor types, functions and ADT as * RelayExpr supports tensor types, functions and ADT as
...@@ -304,6 +235,163 @@ class BaseFunc : public RelayExpr { ...@@ -304,6 +235,163 @@ class BaseFunc : public RelayExpr {
TVM_DEFINE_OBJECT_REF_METHODS(BaseFunc, RelayExpr, BaseFuncNode); TVM_DEFINE_OBJECT_REF_METHODS(BaseFunc, RelayExpr, BaseFuncNode);
}; };
// PrimExprs that are useful as runtime containers.
//
/*!
* \brief Constant integer literals in the program.
* \sa IntImm
*/
class IntImmNode : public PrimExprNode {
public:
/*! \brief the Internal value. */
int64_t value;
void VisitAttrs(AttrVisitor* v) {
v->Visit("dtype", &dtype);
v->Visit("value", &value);
}
static constexpr const char* _type_key = "IntImm";
TVM_DECLARE_FINAL_OBJECT_INFO(IntImmNode, PrimExprNode);
};
/*!
* \brief Managed reference class to IntImmNode.
*
* \sa IntImmNode
*/
class IntImm : public PrimExpr {
public:
/*!
* \brief Constructor.
* \param dtype The data type of the value.
* \param value The internal value.
*/
TVM_DLL IntImm(DataType dtype, int64_t value);
TVM_DEFINE_OBJECT_REF_METHODS(IntImm, PrimExpr, IntImmNode);
};
/*!
* \brief Constant floating point literals in the program.
* \sa FloatImm
*/
class FloatImmNode : public PrimExprNode {
public:
/*! \brief The constant value content. */
double value;
void VisitAttrs(AttrVisitor* v) {
v->Visit("dtype", &dtype);
v->Visit("value", &value);
}
static constexpr const char* _type_key = "FloatImm";
TVM_DECLARE_FINAL_OBJECT_INFO(FloatImmNode, PrimExprNode);
};
/*!
* \brief Managed reference class to FloatImmNode.
*
* \sa FloatImmNode
*/
class FloatImm : public PrimExpr {
public:
/*!
* \brief Constructor.
* \param dtype The data type of the value.
* \param value The internal value.
*/
TVM_DLL FloatImm(DataType dtype, double value);
TVM_DEFINE_OBJECT_REF_METHODS(FloatImm, PrimExpr, FloatImmNode);
};
/*!
* \brief Container of constant int that adds more constructors.
*
* This is used to store and automate type check
* attributes that must be constant integer.
*
* \sa IntImm
*/
class Integer : public IntImm {
public:
Integer() {}
/*!
* \brief constructor from node.
*/
explicit Integer(ObjectPtr<Object> node) : IntImm(node) {}
/*!
* \brief Construct integer from int value.
*/
Integer(int value) : IntImm(DataType::Int(32), value) {} // NOLINT(*)
/*!
* \brief Construct integer from int imm.
* \param other The other value.
*/
Integer(IntImm other) : IntImm(std::move(other)) {} // NOLINT(*)
/*!
* \brief Assign an expression to integer.
* \param other another expression.
*/
Integer& operator=(const IntImm& other) {
data_ = ObjectRef::GetDataPtr<Object>(other);
return *this;
}
/*!
* \brief convert to int64_t
*/
operator int64_t() const {
CHECK(data_ != nullptr)
<< " Trying to reference a null Integer";
return (*this)->value;
}
};
/*! \brief range over one dimension */
class RangeNode : public Object {
public:
/*! \brief beginning of the node */
PrimExpr min;
/*! \brief the extend of range */
PrimExpr extent;
/*! \brief constructor */
RangeNode() {}
RangeNode(PrimExpr min, PrimExpr extent) : min(min), extent(extent) {}
void VisitAttrs(AttrVisitor* v) {
v->Visit("min", &min);
v->Visit("extent", &extent);
}
static constexpr const char* _type_key = "Range";
TVM_DECLARE_FINAL_OBJECT_INFO(RangeNode, Object);
};
/*! \brief Range constainer */
class Range : public ObjectRef {
public:
/*!
* \brief constructor by begin and end
* \param begin The begin of the range.
* \param end The end of the range.
*/
TVM_DLL Range(PrimExpr begin, PrimExpr end);
/*!
* \brief construct a new range with min and extent
* The corresponding constructor is removed,
* because that is counter convention of tradition meaning
* of range(begin, end)
*
* \param min The minimum range.
* \param extent The extent of the range.
*/
static Range make_by_min_extent(PrimExpr min, PrimExpr extent);
// declare range.
TVM_DEFINE_OBJECT_REF_METHODS(Range, ObjectRef, RangeNode);
};
// implementataions // implementataions
inline const Type& RelayExprNode::checked_type() const { inline const Type& RelayExprNode::checked_type() const {
CHECK(checked_type_.defined()) CHECK(checked_type_.defined())
......
...@@ -46,6 +46,13 @@ class NodePrinter { ...@@ -46,6 +46,13 @@ class NodePrinter {
using FType = NodeFunctor<void(const ObjectRef&, NodePrinter*)>; using FType = NodeFunctor<void(const ObjectRef&, NodePrinter*)>;
TVM_DLL static FType& vtable(); TVM_DLL static FType& vtable();
}; };
/*!
* \brief Dump the node to stderr, used for debug purposes.
* \param node The input node
*/
TVM_DLL void Dump(const ObjectRef& node);
} // namespace tvm } // namespace tvm
namespace tvm { namespace tvm {
......
...@@ -26,7 +26,7 @@ ...@@ -26,7 +26,7 @@
#include <tvm/ir/span.h> #include <tvm/ir/span.h>
#include <tvm/ir.h> #include <tvm/tir/expr.h>
#include <tvm/node/node.h> #include <tvm/node/node.h>
#include <string> #include <string>
#include <vector> #include <vector>
......
...@@ -29,11 +29,16 @@ ...@@ -29,11 +29,16 @@
#include <tvm/build_module.h> #include <tvm/build_module.h>
#include <tvm/relay/type.h> #include <tvm/relay/type.h>
#include <tvm/relay/expr.h> #include <tvm/relay/expr.h>
#include <tvm/tir/data_layout.h>
#include <string> #include <string>
namespace tvm { namespace tvm {
namespace relay { namespace relay {
using tir::Layout;
using tir::LayoutAxis;
using tir::BijectiveLayoutNode;
/*! \brief operator pattern used in graph fusion */ /*! \brief operator pattern used in graph fusion */
enum OpPatternKind { enum OpPatternKind {
// Elementwise operation // Elementwise operation
......
...@@ -29,7 +29,7 @@ ...@@ -29,7 +29,7 @@
#include <tvm/ir/attrs.h> #include <tvm/ir/attrs.h>
#include <tvm/runtime/registry.h> #include <tvm/runtime/registry.h>
#include <tvm/ir/env_func.h> #include <tvm/ir/env_func.h>
#include <tvm/ir.h> #include <tvm/tir/expr.h>
#include <string> #include <string>
#include "base.h" #include "base.h"
...@@ -40,7 +40,7 @@ namespace relay { ...@@ -40,7 +40,7 @@ namespace relay {
// namespace update for backward compact // namespace update for backward compact
// will be removed later. // will be removed later.
using Any = tvm::ir::AnyNode; using Any = tvm::tir::AnyNode;
using Kind = TypeKind; using Kind = TypeKind;
using Type = tvm::Type; using Type = tvm::Type;
using TypeNode = tvm::TypeNode; using TypeNode = tvm::TypeNode;
......
...@@ -18,20 +18,21 @@ ...@@ -18,20 +18,21 @@
*/ */
/*! /*!
* \file tvm/buffer.h * \file tvm/tir/buffer.h
* \brief Symbolic n-dimensional array, to represent a memory buffer. * \brief Symbolic n-dimensional array, to represent a memory buffer.
*/ */
#ifndef TVM_BUFFER_H_ #ifndef TVM_TIR_BUFFER_H_
#define TVM_BUFFER_H_ #define TVM_TIR_BUFFER_H_
#include <tvm/node/container.h>
#include <tvm/tir/expr.h>
#include <tvm/tir/op.h>
#include <string> #include <string>
#include "expr.h"
#include "expr_operator.h"
#include "tvm/node/container.h"
namespace tvm { namespace tvm {
namespace tir {
// Internal node container Buffer // Internal node container Buffer
class BufferNode; class BufferNode;
...@@ -186,5 +187,6 @@ inline const BufferNode* Buffer::operator->() const { ...@@ -186,5 +187,6 @@ inline const BufferNode* Buffer::operator->() const {
TVM_DLL Buffer decl_buffer(Array<PrimExpr> shape, TVM_DLL Buffer decl_buffer(Array<PrimExpr> shape,
DataType dtype = DataType::Float(32), DataType dtype = DataType::Float(32),
std::string name = "buffer"); std::string name = "buffer");
} // namespace tir
} // namespace tvm } // namespace tvm
#endif // TVM_BUFFER_H_ #endif // TVM_TIR_BUFFER_H_
...@@ -18,15 +18,16 @@ ...@@ -18,15 +18,16 @@
*/ */
/*! /*!
* \file tvm/data_layout.h * \file tvm/tir/data_layout.h
* \brief Layout expression to describe the data organization of a tensor. * \brief Layout expression to describe the data organization of a tensor.
* And BijectiveLayout to mapping two data layouts between each other. * And BijectiveLayout to mapping two data layouts between each other.
*/ */
#ifndef TVM_DATA_LAYOUT_H_ #ifndef TVM_TIR_DATA_LAYOUT_H_
#define TVM_DATA_LAYOUT_H_ #define TVM_TIR_DATA_LAYOUT_H_
#include <tvm/expr.h> #include <tvm/tir/expr.h>
#include <tvm/tir/op.h>
#include <string> #include <string>
#include <sstream> #include <sstream>
...@@ -34,16 +35,16 @@ ...@@ -34,16 +35,16 @@
#include <utility> #include <utility>
#include <algorithm> #include <algorithm>
#include "expr_operator.h"
namespace tvm { namespace tvm {
namespace tir {
class LayoutAxis { class LayoutAxis {
public: public:
static const LayoutAxis& Get(const char name); static const LayoutAxis& Get(const char name);
// Get the singleton LayoutAxis using itvar->var->name_hint // Get the singleton LayoutAxis using itvar->var->name_hint
static const LayoutAxis& Get(const IterVar& itvar); static const LayoutAxis& Get(const tir::IterVar& itvar);
// Get the singleton LayoutAxis using name[0] (size of name must be 1). // Get the singleton LayoutAxis using name[0] (size of name must be 1).
static const LayoutAxis& make(const std::string& name); static const LayoutAxis& make(const std::string& name);
...@@ -102,7 +103,7 @@ class LayoutNode : public Object { ...@@ -102,7 +103,7 @@ class LayoutNode : public Object {
* it is a variable for a primal axis, but a constant for a subordinate axis. * it is a variable for a primal axis, but a constant for a subordinate axis.
* Empty for scalar's layout. * Empty for scalar's layout.
*/ */
Array<IterVar> axes; Array<tir::IterVar> axes;
void VisitAttrs(AttrVisitor* v) { void VisitAttrs(AttrVisitor* v) {
v->Visit("name", &name); v->Visit("name", &name);
...@@ -132,7 +133,7 @@ class Layout : public ObjectRef { ...@@ -132,7 +133,7 @@ class Layout : public ObjectRef {
/*! \brief default constructor */ /*! \brief default constructor */
Layout() = default; Layout() = default;
explicit Layout(const Array<IterVar>& axes); explicit Layout(const Array<tir::IterVar>& axes);
/*! \brief construct from a string */ /*! \brief construct from a string */
Layout(const char* name) : Layout(std::string(name)) {} // NOLINT(*) Layout(const char* name) : Layout(std::string(name)) {} // NOLINT(*)
...@@ -264,7 +265,7 @@ class Layout : public ObjectRef { ...@@ -264,7 +265,7 @@ class Layout : public ObjectRef {
*/ */
bool Contains(const LayoutAxis& axis) const { bool Contains(const LayoutAxis& axis) const {
if (!defined()) return false; if (!defined()) return false;
for (const IterVar var : operator->()->axes) { for (const tir::IterVar var : operator->()->axes) {
if (var->var->name_hint == axis.name()) { if (var->var->name_hint == axis.name()) {
return true; return true;
} }
...@@ -276,7 +277,7 @@ class Layout : public ObjectRef { ...@@ -276,7 +277,7 @@ class Layout : public ObjectRef {
CHECK(defined()) << "Try to access axis from an undefined layout."; CHECK(defined()) << "Try to access axis from an undefined layout.";
int32_t index = i < 0 ? static_cast<int32_t>(ndim() + i) : i; int32_t index = i < 0 ? static_cast<int32_t>(ndim() + i) : i;
CHECK(index >= 0 && static_cast<size_t>(index) < ndim()) << "Invalid index " << i; CHECK(index >= 0 && static_cast<size_t>(index) < ndim()) << "Invalid index " << i;
const IterVar axis = operator->()->axes[index]; const tir::IterVar axis = operator->()->axes[index];
return LayoutAxis::Get(axis); return LayoutAxis::Get(axis);
} }
...@@ -371,7 +372,7 @@ class BijectiveLayout : public ObjectRef { ...@@ -371,7 +372,7 @@ class BijectiveLayout : public ObjectRef {
inline const BijectiveLayoutNode* BijectiveLayout::operator->() const { inline const BijectiveLayoutNode* BijectiveLayout::operator->() const {
return static_cast<const BijectiveLayoutNode*>(get()); return static_cast<const BijectiveLayoutNode*>(get());
} }
} // namespace tir
} // namespace tvm } // namespace tvm
#endif // TVM_DATA_LAYOUT_H_ #endif // TVM_TIR_DATA_LAYOUT_H_
...@@ -18,27 +18,28 @@ ...@@ -18,27 +18,28 @@
*/ */
/*! /*!
* \file tvm/ir_pass.h * \file tvm/tir/ir_pass.h
* \brief Collection of IR pass functions * \brief Collection of IR pass functions
* *
* When the pass functions in this file are for Stmt, * When the pass functions in this file are for Stmt,
* we can use PassFunction(Evaluate(expr)) to apply it to Expr * we can use PassFunction(Evaluate(expr)) to apply it to Expr
*/ */
#ifndef TVM_IR_PASS_H_ #ifndef TVM_TIR_IR_PASS_H_
#define TVM_IR_PASS_H_ #define TVM_TIR_IR_PASS_H_
#include <tvm/top/schedule.h> #include <tvm/top/schedule.h>
#include <tvm/tir/expr.h>
#include <tvm/tir/buffer.h>
#include <tvm/tir/lowered_func.h>
#include <unordered_map> #include <unordered_map>
#include <unordered_set> #include <unordered_set>
#include <vector> #include <vector>
#include <string> #include <string>
#include "expr.h"
#include "buffer.h"
#include "lowered_func.h"
namespace tvm { namespace tvm {
namespace ir { namespace tir {
/*! /*!
* \brief Simplify the expression. * \brief Simplify the expression.
...@@ -593,8 +594,6 @@ bool VerifyMemory(LoweredFunc func, int device_type); ...@@ -593,8 +594,6 @@ bool VerifyMemory(LoweredFunc func, int device_type);
bool VerifyGPUCode(Stmt stmt, bool VerifyGPUCode(Stmt stmt,
Map<std::string, PrimExpr> constraints); Map<std::string, PrimExpr> constraints);
} // namespace tir
} // namespace ir
} // namespace tvm } // namespace tvm
#endif // TVM_TIR_IR_PASS_H_
#endif // TVM_IR_PASS_H_
...@@ -18,21 +18,20 @@ ...@@ -18,21 +18,20 @@
*/ */
/*! /*!
* \file tvm/lowered_func.h * \file tvm/tir/lowered_func.h
* \brief Information about a lowered TVM function. * \brief Information about a lowered TVM function.
* This data structure is final step toward codegen. * This data structure is final step toward codegen.
*/ */
#ifndef TVM_LOWERED_FUNC_H_ #ifndef TVM_TIR_LOWERED_FUNC_H_
#define TVM_LOWERED_FUNC_H_ #define TVM_TIR_LOWERED_FUNC_H_
#include <tvm/top/tensor.h>
#include <tvm/node/container.h>
#include <tvm/tir/expr.h>
#include <tvm/tir/stmt.h>
#include <string> #include <string>
#include "expr.h"
#include "tvm/node/container.h"
namespace tvm { namespace tvm {
namespace tir {
// Internal node container of lowered function. // Internal node container of lowered function.
class LoweredFuncNode; class LoweredFuncNode;
...@@ -41,7 +40,7 @@ class LoweredFuncNode; ...@@ -41,7 +40,7 @@ class LoweredFuncNode;
* \brief LoweredFunc represents function after lowering. * \brief LoweredFunc represents function after lowering.
* This is the final IR representation before codegen. * This is the final IR representation before codegen.
*/ */
class LoweredFunc : public ir::FunctionRef { class LoweredFunc : public FunctionRef {
public: public:
LoweredFunc() {} LoweredFunc() {}
explicit LoweredFunc(ObjectPtr<Object> n) : FunctionRef(n) {} explicit LoweredFunc(ObjectPtr<Object> n) : FunctionRef(n) {}
...@@ -65,7 +64,7 @@ enum LoweredFuncType : int { ...@@ -65,7 +64,7 @@ enum LoweredFuncType : int {
}; };
/*! \brief Node container of LoweredFunc */ /*! \brief Node container of LoweredFunc */
class LoweredFuncNode : public ir::FunctionBaseNode { class LoweredFuncNode : public tir::FunctionBaseNode {
public: public:
/*! \brief The name of the function */ /*! \brief The name of the function */
std::string name; std::string name;
...@@ -138,13 +137,13 @@ class LoweredFuncNode : public ir::FunctionBaseNode { ...@@ -138,13 +137,13 @@ class LoweredFuncNode : public ir::FunctionBaseNode {
inline const LoweredFuncNode* LoweredFunc::operator->() const { inline const LoweredFuncNode* LoweredFunc::operator->() const {
return static_cast<const LoweredFuncNode*>(get()); return static_cast<const LoweredFuncNode*>(get());
} }
} // namespace tir
} // namespace tvm } // namespace tvm
namespace std { namespace std {
template <> template <>
struct hash<::tvm::LoweredFunc> : public tvm::ObjectHash { struct hash<::tvm::tir::LoweredFunc> : public tvm::ObjectHash {
}; };
} }
#endif // TVM_LOWERED_FUNC_H_ #endif // TVM_TIR_LOWERED_FUNC_H_
...@@ -28,9 +28,9 @@ ...@@ -28,9 +28,9 @@
#include <tvm/top/tensor.h> #include <tvm/top/tensor.h>
#include <tvm/top/schedule.h> #include <tvm/top/schedule.h>
#include <tvm/expr.h> #include <tvm/tir/expr.h>
#include <tvm/expr_operator.h> #include <tvm/tir/op.h>
#include <tvm/buffer.h> #include <tvm/tir/buffer.h>
#include <string> #include <string>
#include <vector> #include <vector>
...@@ -41,8 +41,6 @@ ...@@ -41,8 +41,6 @@
namespace tvm { namespace tvm {
namespace top { namespace top {
using arith::IntSet;
/*! /*!
* \brief Temporary data structure to store union * \brief Temporary data structure to store union
* of bounds of each axis of Tensor. * of bounds of each axis of Tensor.
...@@ -58,7 +56,7 @@ struct TensorDom { ...@@ -58,7 +56,7 @@ struct TensorDom {
/*! /*!
* \brief Base class of all operation nodes * \brief Base class of all operation nodes
*/ */
class OperationNode : public ir::FunctionBaseNode { class OperationNode : public tir::FunctionBaseNode {
public: public:
/*! \brief optional name of the operation */ /*! \brief optional name of the operation */
std::string name; std::string name;
...@@ -554,6 +552,29 @@ class HybridOpNode : public OperationNode { ...@@ -554,6 +552,29 @@ class HybridOpNode : public OperationNode {
TVM_DECLARE_FINAL_OBJECT_INFO(HybridOpNode, OperationNode); TVM_DECLARE_FINAL_OBJECT_INFO(HybridOpNode, OperationNode);
}; };
/*!
* \brief Construct a new Var expression
* \param name_hint The name hint for the expression
* \param t The type of the expression
*/
TVM_DLL Var var(std::string name_hint, DataType t = DataType::Int(32));
/*!
* \brief Create a new IterVar that represents an axis in thread.
*
* \param dom Optional, domain of the thread axis.
* \param tag The thread tag of the axis.
*/
TVM_DLL IterVar thread_axis(Range dom, std::string tag);
/*!
* \brief Create a new IterVar for reduction operations.
*
* \param dom The domain of the reduction axis.
* \param name The name of the reduction axis.
*/
TVM_DLL IterVar reduce_axis(Range dom, std::string name = "rv");
/*! \brief The compute function to specify the input source of a Tensor */ /*! \brief The compute function to specify the input source of a Tensor */
using FCompute = std::function<PrimExpr (const Array<Var>& i)>; using FCompute = std::function<PrimExpr (const Array<Var>& i)>;
......
...@@ -25,7 +25,7 @@ ...@@ -25,7 +25,7 @@
#ifndef TVM_TOP_SCHEDULE_H_ #ifndef TVM_TOP_SCHEDULE_H_
#define TVM_TOP_SCHEDULE_H_ #define TVM_TOP_SCHEDULE_H_
#include <tvm/expr.h> #include <tvm/tir/expr.h>
#include <tvm/top/tensor.h> #include <tvm/top/tensor.h>
#include <tvm/top/tensor_intrin.h> #include <tvm/top/tensor_intrin.h>
......
...@@ -26,8 +26,8 @@ ...@@ -26,8 +26,8 @@
#include <tvm/node/container.h> #include <tvm/node/container.h>
#include <tvm/arith/bound.h> #include <tvm/arith/bound.h>
#include <tvm/expr.h> #include <tvm/tir/expr.h>
#include <tvm/expr_operator.h> #include <tvm/tir/op.h>
#include <string> #include <string>
#include <vector> #include <vector>
...@@ -39,6 +39,9 @@ ...@@ -39,6 +39,9 @@
namespace tvm { namespace tvm {
namespace top { namespace top {
using arith::IntSet;
using namespace tvm::tir;
// Internal node container of Tensor // Internal node container of Tensor
class TensorNode; class TensorNode;
// internal node container for Operation // internal node container for Operation
...@@ -139,7 +142,7 @@ class Tensor : public ObjectRef { ...@@ -139,7 +142,7 @@ class Tensor : public ObjectRef {
}; };
/*! \brief Operation that produces tensors */ /*! \brief Operation that produces tensors */
class Operation : public ir::FunctionRef { class Operation : public tir::FunctionRef {
public: public:
/*! \brief default constructor */ /*! \brief default constructor */
Operation() {} Operation() {}
...@@ -215,18 +218,18 @@ inline bool Tensor::operator!=(const Tensor& other) const { ...@@ -215,18 +218,18 @@ inline bool Tensor::operator!=(const Tensor& other) const {
// macro to turn every operation of slice to expression // macro to turn every operation of slice to expression
#define DEFINE_OVERLOAD_SLICE_UNARY_OP(Op) \ #define DEFINE_OVERLOAD_SLICE_UNARY_OP(Op) \
inline PrimExpr operator Op (const Tensor::Slice& a) { \ inline PrimExpr operator Op (const Tensor::Slice& a) { \
return Op a.operator PrimExpr() ; \ return Op a.operator PrimExpr() ; \
} \ } \
#define DEFINE_OVERLOAD_SLICE_BINARY_OP(Op) \ #define DEFINE_OVERLOAD_SLICE_BINARY_OP(Op) \
template<typename T> \ template<typename T> \
inline PrimExpr operator Op (const Tensor::Slice& a, const T& b) { \ inline PrimExpr operator Op (const Tensor::Slice& a, const T& b) { \
return a.operator PrimExpr() Op b; \ return a.operator PrimExpr() Op b; \
} \ } \
template<typename T> \ template<typename T> \
inline PrimExpr operator Op (const T& a, const Tensor::Slice& b) { \ inline PrimExpr operator Op (const T& a, const Tensor::Slice& b) { \
return a Op b.operator PrimExpr(); \ return a Op b.operator PrimExpr(); \
} \ } \
inline PrimExpr operator Op (const Tensor::Slice& a, const Tensor::Slice& b) { \ inline PrimExpr operator Op (const Tensor::Slice& a, const Tensor::Slice& b) { \
return a.operator PrimExpr() Op b.operator PrimExpr(); \ return a.operator PrimExpr() Op b.operator PrimExpr(); \
......
...@@ -25,7 +25,7 @@ ...@@ -25,7 +25,7 @@
#define TVM_TOP_TENSOR_INTRIN_H_ #define TVM_TOP_TENSOR_INTRIN_H_
#include <tvm/top/tensor.h> #include <tvm/top/tensor.h>
#include <tvm/buffer.h> #include <tvm/tir/buffer.h>
#include <string> #include <string>
......
...@@ -20,7 +20,7 @@ This namespace is used for developers. While you do not see any declarations. ...@@ -20,7 +20,7 @@ This namespace is used for developers. While you do not see any declarations.
The functions are automatically exported from C++ side via PackedFunc. The functions are automatically exported from C++ side via PackedFunc.
Each api is a PackedFunc that can be called in a positional argument manner. Each api is a PackedFunc that can be called in a positional argument manner.
You can read "include/tvm/ir_pass.h" for the function signature and You can read "include/tvm/tir/ir_pass.h" for the function signature and
"src/api/api_pass.cc" for the PackedFunc's body of these functions. "src/api/api_pass.cc" for the PackedFunc's body of these functions.
""" """
from ._ffi.function import _init_api from ._ffi.function import _init_api
......
...@@ -24,12 +24,12 @@ There can be internal header files within each module that sit in src. ...@@ -24,12 +24,12 @@ There can be internal header files within each module that sit in src.
- support: Internal support utilities. - support: Internal support utilities.
- runtime: Minimum runtime related codes. - runtime: Minimum runtime related codes.
- node: base infra for IR/AST nodes that is dialect independent. - node: base infra for IR/AST nodes that is dialect independent.
- ir: Common IR infrastructure.
- tir: Tensor-level IR.
- arith: Arithmetic expression and set simplification. - arith: Arithmetic expression and set simplification.
- top: tensor operation DSL for compute and schedule. - top: tensor operation DSL for compute and schedule.
- relay: Implementation of Relay. The second generation of NNVM, a new IR for deep learning frameworks. - relay: Relay IR, high-level optimization.
- pass: The optimization pass on the IR structure.
- codegen: The code generator. - codegen: The code generator.
- autotvm: The auto-tuning module. - autotvm: The auto-tuning module.
- contrib: Contrib extension libraries. - contrib: Contrib extension libraries.
- api: API function registration. - api: API function registration.
- lang: The definition of DSL related data structure.
...@@ -26,8 +26,8 @@ ...@@ -26,8 +26,8 @@
#include <tvm/arith/pattern.h> #include <tvm/arith/pattern.h>
#include <tvm/arith/analyzer.h> #include <tvm/arith/analyzer.h>
#include <tvm/expr.h> #include <tvm/tir/expr.h>
#include <tvm/ir.h> #include <tvm/tir/expr.h>
#include <tvm/runtime/registry.h> #include <tvm/runtime/registry.h>
#include <tvm/top/tensor.h> #include <tvm/top/tensor.h>
......
...@@ -22,7 +22,7 @@ ...@@ -22,7 +22,7 @@
* \file api_base.cc * \file api_base.cc
*/ */
#include <dmlc/memory_io.h> #include <dmlc/memory_io.h>
#include <tvm/expr.h> #include <tvm/tir/expr.h>
#include <tvm/top/tensor.h> #include <tvm/top/tensor.h>
#include <tvm/runtime/registry.h> #include <tvm/runtime/registry.h>
#include <tvm/node/serialization.h> #include <tvm/node/serialization.h>
......
...@@ -21,10 +21,10 @@ ...@@ -21,10 +21,10 @@
* Implementation of API functions related to Codegen * Implementation of API functions related to Codegen
* \file c_api_codegen.cc * \file c_api_codegen.cc
*/ */
#include <tvm/expr.h> #include <tvm/tir/expr.h>
#include <tvm/ir.h> #include <tvm/tir/expr.h>
#include <tvm/codegen.h> #include <tvm/codegen.h>
#include <tvm/lowered_func.h> #include <tvm/tir/lowered_func.h>
#include <tvm/runtime/registry.h> #include <tvm/runtime/registry.h>
namespace tvm { namespace tvm {
...@@ -32,7 +32,7 @@ namespace codegen { ...@@ -32,7 +32,7 @@ namespace codegen {
TVM_REGISTER_GLOBAL("codegen._Build") TVM_REGISTER_GLOBAL("codegen._Build")
.set_body([](TVMArgs args, TVMRetValue *ret) { .set_body([](TVMArgs args, TVMRetValue *ret) {
if (args[0].IsObjectRef<LoweredFunc>()) { if (args[0].IsObjectRef<tir::LoweredFunc>()) {
*ret = Build({args[0]}, args[1]); *ret = Build({args[0]}, args[1]);
} else { } else {
*ret = Build(args[0], args[1]); *ret = Build(args[0], args[1]);
......
...@@ -21,14 +21,14 @@ ...@@ -21,14 +21,14 @@
* Implementation of API functions related to IR build * Implementation of API functions related to IR build
* \file api_ir.cc * \file api_ir.cc
*/ */
#include <tvm/expr.h> #include <tvm/tir/expr.h>
#include <tvm/ir.h> #include <tvm/tir/expr.h>
#include <tvm/runtime/registry.h> #include <tvm/runtime/registry.h>
#include <tvm/expr_operator.h> #include <tvm/tir/op.h>
namespace tvm { namespace tvm {
namespace ir { namespace tir {
TVM_REGISTER_GLOBAL("_Var") TVM_REGISTER_GLOBAL("_Var")
.set_body_typed([](std::string s, DataType t) { .set_body_typed([](std::string s, DataType t) {
...@@ -233,5 +233,5 @@ TVM_REGISTER_GLOBAL("make._OpIfThenElse") ...@@ -233,5 +233,5 @@ TVM_REGISTER_GLOBAL("make._OpIfThenElse")
return if_then_else(cond, true_value, false_value); return if_then_else(cond, true_value, false_value);
}); });
} // namespace ir } // namespace tir
} // namespace tvm } // namespace tvm
...@@ -21,16 +21,16 @@ ...@@ -21,16 +21,16 @@
* Implementation of API functions related to Higher DSL build. * Implementation of API functions related to Higher DSL build.
* \file api_lang.cc * \file api_lang.cc
*/ */
#include <tvm/expr.h> #include <tvm/tir/expr.h>
#include <tvm/ir.h> #include <tvm/tir/expr.h>
#include <tvm/top/tensor.h> #include <tvm/top/tensor.h>
#include <tvm/top/operation.h> #include <tvm/top/operation.h>
#include <tvm/buffer.h> #include <tvm/tir/buffer.h>
#include <tvm/top/schedule.h> #include <tvm/top/schedule.h>
#include <tvm/runtime/registry.h> #include <tvm/runtime/registry.h>
#include <tvm/build_module.h> #include <tvm/build_module.h>
#include <tvm/data_layout.h> #include <tvm/tir/data_layout.h>
namespace tvm { namespace tvm {
...@@ -44,9 +44,9 @@ TVM_REGISTER_GLOBAL("_max_value") ...@@ -44,9 +44,9 @@ TVM_REGISTER_GLOBAL("_max_value")
TVM_REGISTER_GLOBAL("_const") TVM_REGISTER_GLOBAL("_const")
.set_body([](TVMArgs args, TVMRetValue* ret) { .set_body([](TVMArgs args, TVMRetValue* ret) {
if (args[0].type_code() == kDLInt) { if (args[0].type_code() == kDLInt) {
*ret = make_const(args[1], args[0].operator int64_t()); *ret = tir::make_const(args[1], args[0].operator int64_t());
} else if (args[0].type_code() == kDLFloat) { } else if (args[0].type_code() == kDLFloat) {
*ret = make_const(args[1], args[0].operator double()); *ret = tir::make_const(args[1], args[0].operator double());
} else { } else {
LOG(FATAL) << "only accept int or float"; LOG(FATAL) << "only accept int or float";
} }
...@@ -56,7 +56,7 @@ TVM_REGISTER_GLOBAL("_LargeUIntImm") ...@@ -56,7 +56,7 @@ TVM_REGISTER_GLOBAL("_LargeUIntImm")
.set_body_typed(LargeUIntImm); .set_body_typed(LargeUIntImm);
TVM_REGISTER_GLOBAL("_str") TVM_REGISTER_GLOBAL("_str")
.set_body_typed(ir::StringImmNode::make); .set_body_typed(tir::StringImmNode::make);
TVM_REGISTER_GLOBAL("_Array") TVM_REGISTER_GLOBAL("_Array")
...@@ -200,7 +200,7 @@ TVM_REGISTER_GLOBAL("_MapItems") ...@@ -200,7 +200,7 @@ TVM_REGISTER_GLOBAL("_MapItems")
auto* n = static_cast<const StrMapNode*>(ptr); auto* n = static_cast<const StrMapNode*>(ptr);
auto rkvs = make_object<ArrayNode>(); auto rkvs = make_object<ArrayNode>();
for (const auto& kv : n->data) { for (const auto& kv : n->data) {
rkvs->data.push_back(ir::StringImmNode::make(kv.first)); rkvs->data.push_back(tir::StringImmNode::make(kv.first));
rkvs->data.push_back(kv.second); rkvs->data.push_back(kv.second);
} }
*ret = Array<ObjectRef>(rkvs); *ret = Array<ObjectRef>(rkvs);
...@@ -216,6 +216,8 @@ TVM_REGISTER_GLOBAL("Range") ...@@ -216,6 +216,8 @@ TVM_REGISTER_GLOBAL("Range")
} }
}); });
namespace tir {
TVM_REGISTER_GLOBAL("_Buffer") TVM_REGISTER_GLOBAL("_Buffer")
.set_body([](TVMArgs args, TVMRetValue* ret) { .set_body([](TVMArgs args, TVMRetValue* ret) {
CHECK_EQ(args.size(), 10); CHECK_EQ(args.size(), 10);
...@@ -272,6 +274,7 @@ TVM_REGISTER_GLOBAL("_BijectiveLayoutForwardShape") ...@@ -272,6 +274,7 @@ TVM_REGISTER_GLOBAL("_BijectiveLayoutForwardShape")
TVM_REGISTER_GLOBAL("_BijectiveLayoutBackwardShape") TVM_REGISTER_GLOBAL("_BijectiveLayoutBackwardShape")
.set_body_method(&BijectiveLayout::BackwardShape); .set_body_method(&BijectiveLayout::BackwardShape);
} // namespace tir
namespace top { namespace top {
TVM_REGISTER_GLOBAL("_Tensor") TVM_REGISTER_GLOBAL("_Tensor")
...@@ -444,6 +447,6 @@ TVM_REGISTER_GLOBAL("_ScheduleRFactor") ...@@ -444,6 +447,6 @@ TVM_REGISTER_GLOBAL("_ScheduleRFactor")
} // namespace top } // namespace top
TVM_REGISTER_GLOBAL("_CommReducerCombine") TVM_REGISTER_GLOBAL("_CommReducerCombine")
.set_body_method<ir::CommReducer>(&ir::CommReducerNode::operator()); .set_body_method<tir::CommReducer>(&tir::CommReducerNode::operator());
} // namespace tvm } // namespace tvm
...@@ -21,15 +21,16 @@ ...@@ -21,15 +21,16 @@
* Exposure of pass functions. * Exposure of pass functions.
* \file api_pass.cc * \file api_pass.cc
*/ */
#include <tvm/expr.h> #include <tvm/tir/expr.h>
#include <tvm/ir.h> #include <tvm/tir/stmt.h>
#include <tvm/ir/attrs.h> #include <tvm/ir/attrs.h>
#include <tvm/ir_pass.h> #include <tvm/tir/ir_pass.h>
#include <tvm/ir_functor_ext.h> #include <tvm/tir/expr_functor.h>
#include <tvm/tir/stmt_functor.h>
#include <tvm/runtime/registry.h> #include <tvm/runtime/registry.h>
namespace tvm { namespace tvm {
namespace ir { namespace tir {
TVM_REGISTER_GLOBAL("ir_pass.Simplify") TVM_REGISTER_GLOBAL("ir_pass.Simplify")
.set_body([](TVMArgs args, TVMRetValue *ret) { .set_body([](TVMArgs args, TVMRetValue *ret) {
...@@ -120,7 +121,7 @@ TVM_REGISTER_GLOBAL("ir_pass.ExprUseVar") ...@@ -120,7 +121,7 @@ TVM_REGISTER_GLOBAL("ir_pass.ExprUseVar")
TVM_REGISTER_GLOBAL("ir_pass.PostOrderVisit") TVM_REGISTER_GLOBAL("ir_pass.PostOrderVisit")
.set_body([](TVMArgs args, TVMRetValue *ret) { .set_body([](TVMArgs args, TVMRetValue *ret) {
PackedFunc f = args[1]; PackedFunc f = args[1];
ir::PostOrderVisit(args[0], [f](const ObjectRef& n) { tir::PostOrderVisit(args[0], [f](const ObjectRef& n) {
f(n); f(n);
}); });
}); });
...@@ -176,5 +177,5 @@ REGISTER_PASS(InstrumentBoundCheckers); ...@@ -176,5 +177,5 @@ REGISTER_PASS(InstrumentBoundCheckers);
REGISTER_PASS(VerifyCompactBuffer); REGISTER_PASS(VerifyCompactBuffer);
REGISTER_PASS(HoistIfThenElse); REGISTER_PASS(HoistIfThenElse);
REGISTER_PASS(InferFragment) REGISTER_PASS(InferFragment)
} // namespace ir } // namespace tir
} // namespace tvm } // namespace tvm
...@@ -21,7 +21,7 @@ ...@@ -21,7 +21,7 @@
* Implementation of API functions related to schedule pass. * Implementation of API functions related to schedule pass.
* \file api_schedule.cc * \file api_schedule.cc
*/ */
#include <tvm/expr.h> #include <tvm/tir/expr.h>
#include <tvm/top/tensor.h> #include <tvm/top/tensor.h>
#include <tvm/top/schedule.h> #include <tvm/top/schedule.h>
#include <tvm/top/schedule_pass.h> #include <tvm/top/schedule_pass.h>
......
...@@ -21,7 +21,7 @@ ...@@ -21,7 +21,7 @@
* Code mainly used for test purposes. * Code mainly used for test purposes.
* \file api_test.cc * \file api_test.cc
*/ */
#include <tvm/expr.h> #include <tvm/tir/expr.h>
#include <tvm/top/tensor.h> #include <tvm/top/tensor.h>
#include <tvm/ir/attrs.h> #include <tvm/ir/attrs.h>
#include <tvm/runtime/registry.h> #include <tvm/runtime/registry.h>
......
...@@ -20,9 +20,9 @@ ...@@ -20,9 +20,9 @@
/*! /*!
* \file tvm/arith/analyzer.cc * \file tvm/arith/analyzer.cc
*/ */
#include <tvm/ir.h> #include <tvm/tir/expr.h>
#include <tvm/arith/analyzer.h> #include <tvm/arith/analyzer.h>
#include <tvm/expr_operator.h> #include <tvm/tir/op.h>
namespace tvm { namespace tvm {
namespace arith { namespace arith {
...@@ -48,7 +48,7 @@ void Analyzer::Bind(const Var& var, const PrimExpr& expr) { ...@@ -48,7 +48,7 @@ void Analyzer::Bind(const Var& var, const PrimExpr& expr) {
void Analyzer::Bind(const Var& var, const Range& range) { void Analyzer::Bind(const Var& var, const Range& range) {
CHECK(range.defined()); CHECK(range.defined());
if (is_one(range->extent)) { if (tir::is_one(range->extent)) {
this->Bind(var, range->min); this->Bind(var, range->min);
} else { } else {
this->const_int_bound.Bind(var, range); this->const_int_bound.Bind(var, range);
...@@ -78,7 +78,7 @@ void ConstraintContext::ExitWithScope() { ...@@ -78,7 +78,7 @@ void ConstraintContext::ExitWithScope() {
} }
bool Analyzer::CanProveGreaterEqual(const PrimExpr& expr, int64_t lower_bound) { bool Analyzer::CanProveGreaterEqual(const PrimExpr& expr, int64_t lower_bound) {
if (const auto* ptr = expr.as<ir::IntImmNode>()) { if (const auto* ptr = expr.as<tir::IntImmNode>()) {
return ptr->value >= lower_bound; return ptr->value >= lower_bound;
} }
auto bd = this->const_int_bound(this->rewrite_simplify(expr)); auto bd = this->const_int_bound(this->rewrite_simplify(expr));
...@@ -102,9 +102,9 @@ bool Analyzer::CanProve(const PrimExpr& expr) { ...@@ -102,9 +102,9 @@ bool Analyzer::CanProve(const PrimExpr& expr) {
} }
PrimExpr Analyzer::Simplify(const PrimExpr& expr) { PrimExpr Analyzer::Simplify(const PrimExpr& expr) {
if (is_const(expr)) return expr; if (tir::is_const(expr)) return expr;
auto res = this->rewrite_simplify(expr); auto res = this->rewrite_simplify(expr);
if (is_const(res)) return res; if (tir::is_const(res)) return res;
res = this->canonical_simplify(res); res = this->canonical_simplify(res);
return res; return res;
} }
......
...@@ -21,9 +21,9 @@ ...@@ -21,9 +21,9 @@
* \file bound_deducer.cc * \file bound_deducer.cc
* \brief Utility to deduce bound of expression * \brief Utility to deduce bound of expression
*/ */
#include <tvm/expr.h> #include <tvm/tir/expr.h>
#include <tvm/ir_pass.h> #include <tvm/tir/ir_pass.h>
#include <tvm/ir_functor_ext.h> #include <tvm/tir/expr_functor.h>
#include <tvm/arith/analyzer.h> #include <tvm/arith/analyzer.h>
#include <tvm/runtime/registry.h> #include <tvm/runtime/registry.h>
...@@ -34,7 +34,7 @@ ...@@ -34,7 +34,7 @@
namespace tvm { namespace tvm {
namespace arith { namespace arith {
using namespace ir; using namespace tir;
// a visitor to find the path to the target variable // a visitor to find the path to the target variable
// from a expression. // from a expression.
......
...@@ -22,7 +22,7 @@ ...@@ -22,7 +22,7 @@
* \brief Canonical form based simplification. * \brief Canonical form based simplification.
*/ */
#include <tvm/arith/analyzer.h> #include <tvm/arith/analyzer.h>
#include <tvm/expr_operator.h> #include <tvm/tir/op.h>
#include "const_fold.h" #include "const_fold.h"
#include "pattern_match.h" #include "pattern_match.h"
#include "rewrite_simplify.h" #include "rewrite_simplify.h"
...@@ -30,7 +30,7 @@ ...@@ -30,7 +30,7 @@
namespace tvm { namespace tvm {
namespace arith { namespace arith {
using namespace ir; using namespace tir;
class SumExpr; class SumExpr;
class SplitExpr; class SplitExpr;
...@@ -157,7 +157,7 @@ class SplitExpr : public PrimExpr { ...@@ -157,7 +157,7 @@ class SplitExpr : public PrimExpr {
inline bool SplitExprNode::IndexEqual(const SplitExpr& other) const { inline bool SplitExprNode::IndexEqual(const SplitExpr& other) const {
if (index.same_as(other->index)) return true; if (index.same_as(other->index)) return true;
return ir::Equal(index, other->index); return tir::Equal(index, other->index);
} }
inline bool SplitExprNode::DivModeCompatibleTo(DivMode mode) const { inline bool SplitExprNode::DivModeCompatibleTo(DivMode mode) const {
......
...@@ -24,7 +24,7 @@ ...@@ -24,7 +24,7 @@
#ifndef TVM_ARITH_COMPUTE_EXPR_H_ #ifndef TVM_ARITH_COMPUTE_EXPR_H_
#define TVM_ARITH_COMPUTE_EXPR_H_ #define TVM_ARITH_COMPUTE_EXPR_H_
#include <tvm/ir.h> #include <tvm/tir/expr.h>
#include <limits> #include <limits>
#include <algorithm> #include <algorithm>
...@@ -57,7 +57,7 @@ inline PrimExpr ComputeReduce( ...@@ -57,7 +57,7 @@ inline PrimExpr ComputeReduce(
inline bool GetConst(PrimExpr e, int64_t* out) { inline bool GetConst(PrimExpr e, int64_t* out) {
if (e.dtype().is_vector()) return false; if (e.dtype().is_vector()) return false;
const int64_t* v = as_const_int(e); const int64_t* v = tir::as_const_int(e);
if (v) { if (v) {
*out = *v; return true; *out = *v; return true;
} else { } else {
...@@ -77,37 +77,37 @@ inline bool GetConstInt(PrimExpr e, int* out) { ...@@ -77,37 +77,37 @@ inline bool GetConstInt(PrimExpr e, int* out) {
} }
template<> template<>
inline PrimExpr Compute<ir::AddNode>(PrimExpr a, PrimExpr b) { inline PrimExpr Compute<tir::AddNode>(PrimExpr a, PrimExpr b) {
return a + b; return a + b;
} }
template<> template<>
inline PrimExpr Compute<ir::SubNode>(PrimExpr a, PrimExpr b) { inline PrimExpr Compute<tir::SubNode>(PrimExpr a, PrimExpr b) {
return a - b; return a - b;
} }
template<> template<>
inline PrimExpr Compute<ir::MulNode>(PrimExpr a, PrimExpr b) { inline PrimExpr Compute<tir::MulNode>(PrimExpr a, PrimExpr b) {
return a * b; return a * b;
} }
template<> template<>
inline PrimExpr Compute<ir::DivNode>(PrimExpr a, PrimExpr b) { inline PrimExpr Compute<tir::DivNode>(PrimExpr a, PrimExpr b) {
return truncdiv(a, b); return truncdiv(a, b);
} }
template<> template<>
inline PrimExpr Compute<ir::ModNode>(PrimExpr a, PrimExpr b) { inline PrimExpr Compute<tir::ModNode>(PrimExpr a, PrimExpr b) {
return truncmod(a, b); return truncmod(a, b);
} }
template<> template<>
inline PrimExpr Compute<ir::MaxNode>(PrimExpr a, PrimExpr b) { inline PrimExpr Compute<tir::MaxNode>(PrimExpr a, PrimExpr b) {
return max(a, b); return max(a, b);
} }
template<> template<>
inline PrimExpr Compute<ir::MinNode>(PrimExpr a, PrimExpr b) { inline PrimExpr Compute<tir::MinNode>(PrimExpr a, PrimExpr b) {
return min(a, b); return min(a, b);
} }
......
...@@ -24,8 +24,8 @@ ...@@ -24,8 +24,8 @@
#ifndef TVM_ARITH_CONST_FOLD_H_ #ifndef TVM_ARITH_CONST_FOLD_H_
#define TVM_ARITH_CONST_FOLD_H_ #define TVM_ARITH_CONST_FOLD_H_
#include <tvm/ir.h> #include <tvm/tir/expr.h>
#include <tvm/expr_operator.h> #include <tvm/tir/op.h>
#include <algorithm> #include <algorithm>
#include <cmath> #include <cmath>
#include "int_operator.h" #include "int_operator.h"
...@@ -76,7 +76,7 @@ inline bool IsIndexType(const DataType& type) { ...@@ -76,7 +76,7 @@ inline bool IsIndexType(const DataType& type) {
#define TVM_ARITH_CONST_PROPAGATION(BODY) \ #define TVM_ARITH_CONST_PROPAGATION(BODY) \
using ir::FloatImmNode; \ using tir::FloatImmNode; \
const IntImmNode* pa = a.as<IntImmNode>(); \ const IntImmNode* pa = a.as<IntImmNode>(); \
const IntImmNode* pb = b.as<IntImmNode>(); \ const IntImmNode* pb = b.as<IntImmNode>(); \
const FloatImmNode* fa = a.as<FloatImmNode>(); \ const FloatImmNode* fa = a.as<FloatImmNode>(); \
...@@ -96,7 +96,7 @@ inline bool IsIndexType(const DataType& type) { ...@@ -96,7 +96,7 @@ inline bool IsIndexType(const DataType& type) {
// specialization of constant folders. // specialization of constant folders.
template<> template<>
inline PrimExpr TryConstFold<ir::AddNode>(PrimExpr a, PrimExpr b) { inline PrimExpr TryConstFold<tir::AddNode>(PrimExpr a, PrimExpr b) {
TVM_ARITH_CONST_PROPAGATION({ TVM_ARITH_CONST_PROPAGATION({
const DataType& rtype = a.dtype(); const DataType& rtype = a.dtype();
if (pa && pb) return IntImm(rtype, pa->value + pb->value); if (pa && pb) return IntImm(rtype, pa->value + pb->value);
...@@ -110,7 +110,7 @@ inline PrimExpr TryConstFold<ir::AddNode>(PrimExpr a, PrimExpr b) { ...@@ -110,7 +110,7 @@ inline PrimExpr TryConstFold<ir::AddNode>(PrimExpr a, PrimExpr b) {
} }
template<> template<>
inline PrimExpr TryConstFold<ir::SubNode>(PrimExpr a, PrimExpr b) { inline PrimExpr TryConstFold<tir::SubNode>(PrimExpr a, PrimExpr b) {
TVM_ARITH_CONST_PROPAGATION({ TVM_ARITH_CONST_PROPAGATION({
const DataType& rtype = a.dtype(); const DataType& rtype = a.dtype();
if (pa && pb) return IntImm(rtype, pa->value - pb->value); if (pa && pb) return IntImm(rtype, pa->value - pb->value);
...@@ -122,7 +122,7 @@ inline PrimExpr TryConstFold<ir::SubNode>(PrimExpr a, PrimExpr b) { ...@@ -122,7 +122,7 @@ inline PrimExpr TryConstFold<ir::SubNode>(PrimExpr a, PrimExpr b) {
} }
template<> template<>
inline PrimExpr TryConstFold<ir::MulNode>(PrimExpr a, PrimExpr b) { inline PrimExpr TryConstFold<tir::MulNode>(PrimExpr a, PrimExpr b) {
TVM_ARITH_CONST_PROPAGATION({ TVM_ARITH_CONST_PROPAGATION({
const DataType& rtype = a.dtype(); const DataType& rtype = a.dtype();
if (pa && pb) return IntImm(rtype, pa->value * pb->value); if (pa && pb) return IntImm(rtype, pa->value * pb->value);
...@@ -148,7 +148,7 @@ inline PrimExpr TryConstFold<ir::MulNode>(PrimExpr a, PrimExpr b) { ...@@ -148,7 +148,7 @@ inline PrimExpr TryConstFold<ir::MulNode>(PrimExpr a, PrimExpr b) {
} }
template<> template<>
inline PrimExpr TryConstFold<ir::DivNode>(PrimExpr a, PrimExpr b) { inline PrimExpr TryConstFold<tir::DivNode>(PrimExpr a, PrimExpr b) {
TVM_ARITH_CONST_PROPAGATION({ TVM_ARITH_CONST_PROPAGATION({
const DataType& rtype = a.dtype(); const DataType& rtype = a.dtype();
if (pa && pb) { if (pa && pb) {
...@@ -177,7 +177,7 @@ inline PrimExpr TryConstFold<ir::DivNode>(PrimExpr a, PrimExpr b) { ...@@ -177,7 +177,7 @@ inline PrimExpr TryConstFold<ir::DivNode>(PrimExpr a, PrimExpr b) {
} }
template<> template<>
inline PrimExpr TryConstFold<ir::ModNode>(PrimExpr a, PrimExpr b) { inline PrimExpr TryConstFold<tir::ModNode>(PrimExpr a, PrimExpr b) {
TVM_INDEX_CONST_PROPAGATION({ TVM_INDEX_CONST_PROPAGATION({
const DataType& rtype = a.dtype(); const DataType& rtype = a.dtype();
if (pa && pb) { if (pa && pb) {
...@@ -187,7 +187,7 @@ inline PrimExpr TryConstFold<ir::ModNode>(PrimExpr a, PrimExpr b) { ...@@ -187,7 +187,7 @@ inline PrimExpr TryConstFold<ir::ModNode>(PrimExpr a, PrimExpr b) {
if (pa->value == 0) return a; if (pa->value == 0) return a;
} }
if (pb) { if (pb) {
if (pb->value == 1) return make_zero(rtype); if (pb->value == 1) return tir::make_zero(rtype);
CHECK_NE(pb->value, 0) << "Divide by zero"; CHECK_NE(pb->value, 0) << "Divide by zero";
} }
}); });
...@@ -195,7 +195,7 @@ inline PrimExpr TryConstFold<ir::ModNode>(PrimExpr a, PrimExpr b) { ...@@ -195,7 +195,7 @@ inline PrimExpr TryConstFold<ir::ModNode>(PrimExpr a, PrimExpr b) {
} }
template<> template<>
inline PrimExpr TryConstFold<ir::FloorDivNode>(PrimExpr a, PrimExpr b) { inline PrimExpr TryConstFold<tir::FloorDivNode>(PrimExpr a, PrimExpr b) {
TVM_ARITH_CONST_PROPAGATION({ TVM_ARITH_CONST_PROPAGATION({
const DataType& rtype = a.dtype(); const DataType& rtype = a.dtype();
if (pa && pb) { if (pa && pb) {
...@@ -222,17 +222,17 @@ inline PrimExpr TryConstFold<ir::FloorDivNode>(PrimExpr a, PrimExpr b) { ...@@ -222,17 +222,17 @@ inline PrimExpr TryConstFold<ir::FloorDivNode>(PrimExpr a, PrimExpr b) {
} }
template<> template<>
inline PrimExpr TryConstFold<ir::FloorModNode>(PrimExpr a, PrimExpr b) { inline PrimExpr TryConstFold<tir::FloorModNode>(PrimExpr a, PrimExpr b) {
TVM_INDEX_CONST_PROPAGATION({ TVM_INDEX_CONST_PROPAGATION({
const DataType& rtype = a.dtype(); const DataType& rtype = a.dtype();
if (pa && pb) { if (pa && pb) {
return IntImm(rtype, arith::floormod(pa->value, pb->value)); return IntImm(rtype, floormod(pa->value, pb->value));
} }
if (pa) { if (pa) {
if (pa->value == 0) return a; if (pa->value == 0) return a;
} }
if (pb) { if (pb) {
if (pb->value == 1) return make_zero(rtype); if (pb->value == 1) return tir::make_zero(rtype);
CHECK_NE(pb->value, 0) << "Divide by zero"; CHECK_NE(pb->value, 0) << "Divide by zero";
} }
}); });
...@@ -240,7 +240,7 @@ inline PrimExpr TryConstFold<ir::FloorModNode>(PrimExpr a, PrimExpr b) { ...@@ -240,7 +240,7 @@ inline PrimExpr TryConstFold<ir::FloorModNode>(PrimExpr a, PrimExpr b) {
} }
template<> template<>
inline PrimExpr TryConstFold<ir::MinNode>(PrimExpr a, PrimExpr b) { inline PrimExpr TryConstFold<tir::MinNode>(PrimExpr a, PrimExpr b) {
TVM_ARITH_CONST_PROPAGATION({ TVM_ARITH_CONST_PROPAGATION({
const DataType& rtype = a.dtype(); const DataType& rtype = a.dtype();
if (pa && pb) return IntImm(rtype, std::min(pa->value, pb->value)); if (pa && pb) return IntImm(rtype, std::min(pa->value, pb->value));
...@@ -251,7 +251,7 @@ inline PrimExpr TryConstFold<ir::MinNode>(PrimExpr a, PrimExpr b) { ...@@ -251,7 +251,7 @@ inline PrimExpr TryConstFold<ir::MinNode>(PrimExpr a, PrimExpr b) {
} }
template<> template<>
inline PrimExpr TryConstFold<ir::MaxNode>(PrimExpr a, PrimExpr b) { inline PrimExpr TryConstFold<tir::MaxNode>(PrimExpr a, PrimExpr b) {
TVM_ARITH_CONST_PROPAGATION({ TVM_ARITH_CONST_PROPAGATION({
const DataType& rtype = a.dtype(); const DataType& rtype = a.dtype();
if (pa && pb) return IntImm(rtype, std::max(pa->value, pb->value)); if (pa && pb) return IntImm(rtype, std::max(pa->value, pb->value));
...@@ -262,7 +262,7 @@ inline PrimExpr TryConstFold<ir::MaxNode>(PrimExpr a, PrimExpr b) { ...@@ -262,7 +262,7 @@ inline PrimExpr TryConstFold<ir::MaxNode>(PrimExpr a, PrimExpr b) {
} }
template<> template<>
inline PrimExpr TryConstFold<ir::GTNode>(PrimExpr a, PrimExpr b) { inline PrimExpr TryConstFold<tir::GTNode>(PrimExpr a, PrimExpr b) {
TVM_ARITH_CONST_PROPAGATION({ TVM_ARITH_CONST_PROPAGATION({
if (pa && pb) return IntImm(DataType::UInt(1), pa->value > pb->value); if (pa && pb) return IntImm(DataType::UInt(1), pa->value > pb->value);
if (fa && fb) return IntImm(DataType::UInt(1), fa->value > fb->value); if (fa && fb) return IntImm(DataType::UInt(1), fa->value > fb->value);
...@@ -271,7 +271,7 @@ inline PrimExpr TryConstFold<ir::GTNode>(PrimExpr a, PrimExpr b) { ...@@ -271,7 +271,7 @@ inline PrimExpr TryConstFold<ir::GTNode>(PrimExpr a, PrimExpr b) {
} }
template<> template<>
inline PrimExpr TryConstFold<ir::GENode>(PrimExpr a, PrimExpr b) { inline PrimExpr TryConstFold<tir::GENode>(PrimExpr a, PrimExpr b) {
TVM_ARITH_CONST_PROPAGATION({ TVM_ARITH_CONST_PROPAGATION({
if (pa && pb) return IntImm(DataType::UInt(1), pa->value >= pb->value); if (pa && pb) return IntImm(DataType::UInt(1), pa->value >= pb->value);
if (fa && fb) return IntImm(DataType::UInt(1), fa->value >= fb->value); if (fa && fb) return IntImm(DataType::UInt(1), fa->value >= fb->value);
...@@ -280,7 +280,7 @@ inline PrimExpr TryConstFold<ir::GENode>(PrimExpr a, PrimExpr b) { ...@@ -280,7 +280,7 @@ inline PrimExpr TryConstFold<ir::GENode>(PrimExpr a, PrimExpr b) {
} }
template<> template<>
inline PrimExpr TryConstFold<ir::LTNode>(PrimExpr a, PrimExpr b) { inline PrimExpr TryConstFold<tir::LTNode>(PrimExpr a, PrimExpr b) {
TVM_ARITH_CONST_PROPAGATION({ TVM_ARITH_CONST_PROPAGATION({
if (pa && pb) return IntImm(DataType::UInt(1), pa->value < pb->value); if (pa && pb) return IntImm(DataType::UInt(1), pa->value < pb->value);
if (fa && fb) return IntImm(DataType::UInt(1), fa->value < fb->value); if (fa && fb) return IntImm(DataType::UInt(1), fa->value < fb->value);
...@@ -289,7 +289,7 @@ inline PrimExpr TryConstFold<ir::LTNode>(PrimExpr a, PrimExpr b) { ...@@ -289,7 +289,7 @@ inline PrimExpr TryConstFold<ir::LTNode>(PrimExpr a, PrimExpr b) {
} }
template<> template<>
inline PrimExpr TryConstFold<ir::LENode>(PrimExpr a, PrimExpr b) { inline PrimExpr TryConstFold<tir::LENode>(PrimExpr a, PrimExpr b) {
TVM_ARITH_CONST_PROPAGATION({ TVM_ARITH_CONST_PROPAGATION({
if (pa && pb) return IntImm(DataType::UInt(1), pa->value <= pb->value); if (pa && pb) return IntImm(DataType::UInt(1), pa->value <= pb->value);
if (fa && fb) return IntImm(DataType::UInt(1), fa->value <= fb->value); if (fa && fb) return IntImm(DataType::UInt(1), fa->value <= fb->value);
...@@ -298,7 +298,7 @@ inline PrimExpr TryConstFold<ir::LENode>(PrimExpr a, PrimExpr b) { ...@@ -298,7 +298,7 @@ inline PrimExpr TryConstFold<ir::LENode>(PrimExpr a, PrimExpr b) {
} }
template<> template<>
inline PrimExpr TryConstFold<ir::EQNode>(PrimExpr a, PrimExpr b) { inline PrimExpr TryConstFold<tir::EQNode>(PrimExpr a, PrimExpr b) {
TVM_ARITH_CONST_PROPAGATION({ TVM_ARITH_CONST_PROPAGATION({
if (pa && pb) return IntImm(DataType::UInt(1), pa->value == pb->value); if (pa && pb) return IntImm(DataType::UInt(1), pa->value == pb->value);
if (fa && fb) return IntImm(DataType::UInt(1), fa->value == fb->value); if (fa && fb) return IntImm(DataType::UInt(1), fa->value == fb->value);
...@@ -307,7 +307,7 @@ inline PrimExpr TryConstFold<ir::EQNode>(PrimExpr a, PrimExpr b) { ...@@ -307,7 +307,7 @@ inline PrimExpr TryConstFold<ir::EQNode>(PrimExpr a, PrimExpr b) {
} }
template<> template<>
inline PrimExpr TryConstFold<ir::NENode>(PrimExpr a, PrimExpr b) { inline PrimExpr TryConstFold<tir::NENode>(PrimExpr a, PrimExpr b) {
TVM_ARITH_CONST_PROPAGATION({ TVM_ARITH_CONST_PROPAGATION({
if (pa && pb) return IntImm(DataType::UInt(1), pa->value != pb->value); if (pa && pb) return IntImm(DataType::UInt(1), pa->value != pb->value);
if (fa && fb) return IntImm(DataType::UInt(1), fa->value != fb->value); if (fa && fb) return IntImm(DataType::UInt(1), fa->value != fb->value);
...@@ -316,7 +316,7 @@ inline PrimExpr TryConstFold<ir::NENode>(PrimExpr a, PrimExpr b) { ...@@ -316,7 +316,7 @@ inline PrimExpr TryConstFold<ir::NENode>(PrimExpr a, PrimExpr b) {
} }
template<> template<>
inline PrimExpr TryConstFold<ir::AndNode>(PrimExpr a, PrimExpr b) { inline PrimExpr TryConstFold<tir::AndNode>(PrimExpr a, PrimExpr b) {
const IntImmNode* pa = a.as<IntImmNode>(); const IntImmNode* pa = a.as<IntImmNode>();
const IntImmNode* pb = b.as<IntImmNode>(); const IntImmNode* pb = b.as<IntImmNode>();
if (pa && pa->value) return b; if (pa && pa->value) return b;
...@@ -327,7 +327,7 @@ inline PrimExpr TryConstFold<ir::AndNode>(PrimExpr a, PrimExpr b) { ...@@ -327,7 +327,7 @@ inline PrimExpr TryConstFold<ir::AndNode>(PrimExpr a, PrimExpr b) {
} }
template<> template<>
inline PrimExpr TryConstFold<ir::OrNode>(PrimExpr a, PrimExpr b) { inline PrimExpr TryConstFold<tir::OrNode>(PrimExpr a, PrimExpr b) {
const IntImmNode* pa = a.as<IntImmNode>(); const IntImmNode* pa = a.as<IntImmNode>();
const IntImmNode* pb = b.as<IntImmNode>(); const IntImmNode* pb = b.as<IntImmNode>();
if (pa && pa->value) return a; if (pa && pa->value) return a;
...@@ -338,7 +338,7 @@ inline PrimExpr TryConstFold<ir::OrNode>(PrimExpr a, PrimExpr b) { ...@@ -338,7 +338,7 @@ inline PrimExpr TryConstFold<ir::OrNode>(PrimExpr a, PrimExpr b) {
} }
template<> template<>
inline PrimExpr TryConstFold<ir::NotNode>(PrimExpr a) { inline PrimExpr TryConstFold<tir::NotNode>(PrimExpr a) {
const IntImmNode* pa = a.as<IntImmNode>(); const IntImmNode* pa = a.as<IntImmNode>();
if (pa) { if (pa) {
return IntImm(DataType::UInt(1), !(pa->value)); return IntImm(DataType::UInt(1), !(pa->value));
......
...@@ -21,7 +21,7 @@ ...@@ -21,7 +21,7 @@
* \file tvm/arith/const_int_bound.cc * \file tvm/arith/const_int_bound.cc
*/ */
#include <tvm/arith/analyzer.h> #include <tvm/arith/analyzer.h>
#include <tvm/ir_functor_ext.h> #include <tvm/tir/expr_functor.h>
#include <algorithm> #include <algorithm>
#include "int_operator.h" #include "int_operator.h"
#include "pattern_match.h" #include "pattern_match.h"
...@@ -29,7 +29,7 @@ ...@@ -29,7 +29,7 @@
namespace tvm { namespace tvm {
namespace arith { namespace arith {
using namespace ir; using namespace tir;
TVM_REGISTER_NODE_TYPE(ConstIntBoundNode); TVM_REGISTER_NODE_TYPE(ConstIntBoundNode);
...@@ -133,7 +133,7 @@ class ConstIntBoundAnalyzer::Impl : ...@@ -133,7 +133,7 @@ class ConstIntBoundAnalyzer::Impl :
// a linear search over additional info // a linear search over additional info
// assume we won't have a lot of conditions // assume we won't have a lot of conditions
for (const BoundInfo& info : additional_info_) { for (const BoundInfo& info : additional_info_) {
if (ir::Equal(expr, info.expr)) { if (tir::Equal(expr, info.expr)) {
res = Intersect(res, info.bound); res = Intersect(res, info.bound);
} }
} }
......
...@@ -21,15 +21,16 @@ ...@@ -21,15 +21,16 @@
* \file detect_linear_equation.cc * \file detect_linear_equation.cc
* \brief Utility to detect patterns in the expression. * \brief Utility to detect patterns in the expression.
*/ */
#include <tvm/expr.h> #include <tvm/tir/expr.h>
#include <tvm/ir_pass.h> #include <tvm/tir/ir_pass.h>
#include <tvm/ir_functor_ext.h> #include <tvm/tir/expr_functor.h>
#include <tvm/tir/stmt_functor.h>
#include <tvm/arith/analyzer.h> #include <tvm/arith/analyzer.h>
namespace tvm { namespace tvm {
namespace arith { namespace arith {
using namespace ir; using namespace tir;
// Linear equation, the components can be undefined. // Linear equation, the components can be undefined.
struct LinearEqEntry { struct LinearEqEntry {
...@@ -211,7 +212,7 @@ bool DetectClipBound( ...@@ -211,7 +212,7 @@ bool DetectClipBound(
if (is_const_int(ret.coeff, 1)) { if (is_const_int(ret.coeff, 1)) {
// var + shift >=0 -> var >= -shift // var + shift >=0 -> var >= -shift
if (p.min_value.defined()) { if (p.min_value.defined()) {
p.min_value = ir::MaxNode::make(p.min_value, -ret.base); p.min_value = tir::MaxNode::make(p.min_value, -ret.base);
} else { } else {
p.min_value = -ret.base; p.min_value = -ret.base;
} }
...@@ -220,7 +221,7 @@ bool DetectClipBound( ...@@ -220,7 +221,7 @@ bool DetectClipBound(
if (is_const_int(ret.coeff, -1)) { if (is_const_int(ret.coeff, -1)) {
// -var + shift >=0 -> var <= shift // -var + shift >=0 -> var <= shift
if (p.max_value.defined()) { if (p.max_value.defined()) {
p.max_value = ir::MinNode::make(p.max_value, ret.base); p.max_value = tir::MinNode::make(p.max_value, ret.base);
} else { } else {
p.max_value = ret.base; p.max_value = ret.base;
} }
...@@ -244,7 +245,7 @@ void SplitCommExpr(const PrimExpr& e, std::vector<PrimExpr>* ret) { ...@@ -244,7 +245,7 @@ void SplitCommExpr(const PrimExpr& e, std::vector<PrimExpr>* ret) {
// e must be connected by and. // e must be connected by and.
Array<PrimExpr> DetectClipBound(const PrimExpr& e, const Array<Var>& vars) { Array<PrimExpr> DetectClipBound(const PrimExpr& e, const Array<Var>& vars) {
std::vector<PrimExpr> splits; std::vector<PrimExpr> splits;
SplitCommExpr<ir::AndNode>(e, &splits); SplitCommExpr<tir::AndNode>(e, &splits);
std::unordered_map<const VarNode*, IntervalEntry> rmap; std::unordered_map<const VarNode*, IntervalEntry> rmap;
for (Var v : vars) { for (Var v : vars) {
rmap[v.get()] = IntervalEntry(); rmap[v.get()] = IntervalEntry();
......
...@@ -21,9 +21,9 @@ ...@@ -21,9 +21,9 @@
* \file bound_deducer.cc * \file bound_deducer.cc
* \brief Utility to deduce bound of expression * \brief Utility to deduce bound of expression
*/ */
#include <tvm/expr.h> #include <tvm/tir/expr.h>
#include <tvm/ir_pass.h> #include <tvm/tir/ir_pass.h>
#include <tvm/ir_functor_ext.h> #include <tvm/tir/stmt_functor.h>
#include <tvm/top/tensor.h> #include <tvm/top/tensor.h>
#include <tvm/runtime/registry.h> #include <tvm/runtime/registry.h>
...@@ -33,7 +33,7 @@ ...@@ -33,7 +33,7 @@
namespace tvm { namespace tvm {
namespace arith { namespace arith {
using namespace ir; using namespace tir;
// Find Read region of the tensor in the stmt. // Find Read region of the tensor in the stmt.
class FuncTouchedDomain final : public StmtExprVisitor { class FuncTouchedDomain final : public StmtExprVisitor {
......
...@@ -47,7 +47,7 @@ inline bool WillOverflow(int64_t x, ...@@ -47,7 +47,7 @@ inline bool WillOverflow(int64_t x,
} }
template<> template<>
inline bool WillOverflow<ir::AddNode>(int64_t x, inline bool WillOverflow<tir::AddNode>(int64_t x,
int64_t y, int64_t y,
int64_t min_value, int64_t min_value,
int64_t max_value) { int64_t max_value) {
...@@ -57,7 +57,7 @@ inline bool WillOverflow<ir::AddNode>(int64_t x, ...@@ -57,7 +57,7 @@ inline bool WillOverflow<ir::AddNode>(int64_t x,
} }
template<> template<>
inline bool WillOverflow<ir::SubNode>(int64_t x, inline bool WillOverflow<tir::SubNode>(int64_t x,
int64_t y, int64_t y,
int64_t min_value, int64_t min_value,
int64_t max_value) { int64_t max_value) {
...@@ -67,7 +67,7 @@ inline bool WillOverflow<ir::SubNode>(int64_t x, ...@@ -67,7 +67,7 @@ inline bool WillOverflow<ir::SubNode>(int64_t x,
} }
template<> template<>
inline bool WillOverflow<ir::MulNode>(int64_t x, inline bool WillOverflow<tir::MulNode>(int64_t x,
int64_t y, int64_t y,
int64_t min_value, int64_t min_value,
int64_t max_value) { int64_t max_value) {
...@@ -84,7 +84,7 @@ inline bool WillOverflow<ir::MulNode>(int64_t x, ...@@ -84,7 +84,7 @@ inline bool WillOverflow<ir::MulNode>(int64_t x,
} }
template<> template<>
inline bool WillOverflow<ir::ModNode>(int64_t x, inline bool WillOverflow<tir::ModNode>(int64_t x,
int64_t y, int64_t y,
int64_t min_value, int64_t min_value,
int64_t max_value) { int64_t max_value) {
......
...@@ -22,8 +22,8 @@ ...@@ -22,8 +22,8 @@
* \brief The integer set functions * \brief The integer set functions
*/ */
#include <tvm/arith/int_set.h> #include <tvm/arith/int_set.h>
#include <tvm/ir.h> #include <tvm/tir/expr.h>
#include <tvm/ir_functor_ext.h> #include <tvm/tir/expr_functor.h>
#include <tvm/runtime/registry.h> #include <tvm/runtime/registry.h>
#include <utility> #include <utility>
...@@ -35,6 +35,11 @@ ...@@ -35,6 +35,11 @@
namespace tvm { namespace tvm {
namespace arith { namespace arith {
using tir::make_const;
using tir::make_zero;
using tir::is_zero;
using tir::is_one;
PrimExpr SymbolicLimits::pos_inf_ = Var("pos_inf", DataType::Handle()); PrimExpr SymbolicLimits::pos_inf_ = Var("pos_inf", DataType::Handle());
PrimExpr SymbolicLimits::neg_inf_ = Var("neg_inf", DataType::Handle()); PrimExpr SymbolicLimits::neg_inf_ = Var("neg_inf", DataType::Handle());
...@@ -79,7 +84,7 @@ struct is_logical_op { ...@@ -79,7 +84,7 @@ struct is_logical_op {
#define TVM_DECLARE_LOGICAL_OP(OP) \ #define TVM_DECLARE_LOGICAL_OP(OP) \
template<> \ template<> \
struct is_logical_op<ir::OP> { \ struct is_logical_op<tir::OP> { \
static const bool value = true; \ static const bool value = true; \
}; };
...@@ -118,7 +123,7 @@ inline IntervalSet Combine(Analyzer* analyzer, ...@@ -118,7 +123,7 @@ inline IntervalSet Combine(Analyzer* analyzer,
} }
template<> template<>
inline IntervalSet Combine<ir::AddNode>(Analyzer* analyer, inline IntervalSet Combine<tir::AddNode>(Analyzer* analyer,
IntervalSet a, IntervalSet a,
IntervalSet b) { IntervalSet b) {
if (a->IsSinglePoint() && b->IsSinglePoint()) { if (a->IsSinglePoint() && b->IsSinglePoint()) {
...@@ -136,7 +141,7 @@ inline IntervalSet Combine<ir::AddNode>(Analyzer* analyer, ...@@ -136,7 +141,7 @@ inline IntervalSet Combine<ir::AddNode>(Analyzer* analyer,
} }
template<> template<>
inline IntervalSet Combine<ir::SubNode>(Analyzer* analyer, inline IntervalSet Combine<tir::SubNode>(Analyzer* analyer,
IntervalSet a, IntervalSet a,
IntervalSet b) { IntervalSet b) {
if (a->IsSinglePoint() && b->IsSinglePoint()) { if (a->IsSinglePoint() && b->IsSinglePoint()) {
...@@ -155,7 +160,7 @@ inline IntervalSet Combine<ir::SubNode>(Analyzer* analyer, ...@@ -155,7 +160,7 @@ inline IntervalSet Combine<ir::SubNode>(Analyzer* analyer,
template<> template<>
inline IntervalSet Combine<ir::MulNode>(Analyzer* analyzer, inline IntervalSet Combine<tir::MulNode>(Analyzer* analyzer,
IntervalSet a, IntervalSet a,
IntervalSet b) { IntervalSet b) {
if (a->IsSinglePoint() && b->IsSinglePoint()) { if (a->IsSinglePoint() && b->IsSinglePoint()) {
...@@ -178,7 +183,7 @@ inline IntervalSet Combine<ir::MulNode>(Analyzer* analyzer, ...@@ -178,7 +183,7 @@ inline IntervalSet Combine<ir::MulNode>(Analyzer* analyzer,
PrimExpr max_value = a->HasLowerBound() ? a->min_value * b->min_value : pos_inf(); PrimExpr max_value = a->HasLowerBound() ? a->min_value * b->min_value : pos_inf();
return IntervalSet(min_value, max_value); return IntervalSet(min_value, max_value);
} else if (a->HasUpperBound() && a->HasLowerBound()) { } else if (a->HasUpperBound() && a->HasLowerBound()) {
using ir::SelectNode; using tir::SelectNode;
PrimExpr sign = b->min_value >= make_zero(b->min_value.dtype().element_of()); PrimExpr sign = b->min_value >= make_zero(b->min_value.dtype().element_of());
PrimExpr e1 = a->min_value * b->min_value; PrimExpr e1 = a->min_value * b->min_value;
PrimExpr e2 = a->max_value * b->min_value; PrimExpr e2 = a->max_value * b->min_value;
...@@ -190,7 +195,7 @@ inline IntervalSet Combine<ir::MulNode>(Analyzer* analyzer, ...@@ -190,7 +195,7 @@ inline IntervalSet Combine<ir::MulNode>(Analyzer* analyzer,
} }
template<> template<>
inline IntervalSet Combine<ir::DivNode>(Analyzer* analyzer, inline IntervalSet Combine<tir::DivNode>(Analyzer* analyzer,
IntervalSet a, IntervalSet a,
IntervalSet b) { IntervalSet b) {
if (a->IsSinglePoint() && b->IsSinglePoint()) { if (a->IsSinglePoint() && b->IsSinglePoint()) {
...@@ -213,7 +218,7 @@ inline IntervalSet Combine<ir::DivNode>(Analyzer* analyzer, ...@@ -213,7 +218,7 @@ inline IntervalSet Combine<ir::DivNode>(Analyzer* analyzer,
PrimExpr max_value = a->HasLowerBound() ? a->min_value / b->min_value : pos_inf(); PrimExpr max_value = a->HasLowerBound() ? a->min_value / b->min_value : pos_inf();
return IntervalSet(min_value, max_value); return IntervalSet(min_value, max_value);
} else if (a->HasUpperBound() && a->HasLowerBound()) { } else if (a->HasUpperBound() && a->HasLowerBound()) {
using ir::SelectNode; using tir::SelectNode;
PrimExpr sign = b->min_value >= make_zero(b->min_value.dtype().element_of()); PrimExpr sign = b->min_value >= make_zero(b->min_value.dtype().element_of());
PrimExpr e1 = a->min_value / b->min_value; PrimExpr e1 = a->min_value / b->min_value;
PrimExpr e2 = a->max_value / b->min_value; PrimExpr e2 = a->max_value / b->min_value;
...@@ -225,7 +230,7 @@ inline IntervalSet Combine<ir::DivNode>(Analyzer* analyzer, ...@@ -225,7 +230,7 @@ inline IntervalSet Combine<ir::DivNode>(Analyzer* analyzer,
} }
template<> template<>
inline IntervalSet Combine<ir::ModNode>(Analyzer* analyzer, inline IntervalSet Combine<tir::ModNode>(Analyzer* analyzer,
IntervalSet a, IntervalSet a,
IntervalSet b) { IntervalSet b) {
if (a->IsSinglePoint() && b->IsSinglePoint()) { if (a->IsSinglePoint() && b->IsSinglePoint()) {
...@@ -256,7 +261,7 @@ inline IntervalSet Combine<ir::ModNode>(Analyzer* analyzer, ...@@ -256,7 +261,7 @@ inline IntervalSet Combine<ir::ModNode>(Analyzer* analyzer,
template<> template<>
inline IntervalSet Combine<ir::FloorDivNode>(Analyzer* analyzer, inline IntervalSet Combine<tir::FloorDivNode>(Analyzer* analyzer,
IntervalSet a, IntervalSet a,
IntervalSet b) { IntervalSet b) {
if (a->IsSinglePoint() && b->IsSinglePoint()) { if (a->IsSinglePoint() && b->IsSinglePoint()) {
...@@ -279,7 +284,7 @@ inline IntervalSet Combine<ir::FloorDivNode>(Analyzer* analyzer, ...@@ -279,7 +284,7 @@ inline IntervalSet Combine<ir::FloorDivNode>(Analyzer* analyzer,
PrimExpr max_value = a->HasLowerBound() ? floordiv(a->min_value, b->min_value) : pos_inf(); PrimExpr max_value = a->HasLowerBound() ? floordiv(a->min_value, b->min_value) : pos_inf();
return IntervalSet(min_value, max_value); return IntervalSet(min_value, max_value);
} else if (a->HasUpperBound() && a->HasLowerBound()) { } else if (a->HasUpperBound() && a->HasLowerBound()) {
using ir::SelectNode; using tir::SelectNode;
PrimExpr sign = b->min_value >= make_zero(b->min_value.dtype().element_of()); PrimExpr sign = b->min_value >= make_zero(b->min_value.dtype().element_of());
PrimExpr e1 = floordiv(a->min_value, b->min_value); PrimExpr e1 = floordiv(a->min_value, b->min_value);
PrimExpr e2 = floordiv(a->max_value, b->min_value); PrimExpr e2 = floordiv(a->max_value, b->min_value);
...@@ -291,7 +296,7 @@ inline IntervalSet Combine<ir::FloorDivNode>(Analyzer* analyzer, ...@@ -291,7 +296,7 @@ inline IntervalSet Combine<ir::FloorDivNode>(Analyzer* analyzer,
} }
template<> template<>
inline IntervalSet Combine<ir::FloorModNode>(Analyzer* analyzer, inline IntervalSet Combine<tir::FloorModNode>(Analyzer* analyzer,
IntervalSet a, IntervalSet a,
IntervalSet b) { IntervalSet b) {
if (a->IsSinglePoint() && b->IsSinglePoint()) { if (a->IsSinglePoint() && b->IsSinglePoint()) {
...@@ -317,7 +322,7 @@ inline IntervalSet Combine<ir::FloorModNode>(Analyzer* analyzer, ...@@ -317,7 +322,7 @@ inline IntervalSet Combine<ir::FloorModNode>(Analyzer* analyzer,
} }
template<> template<>
inline IntervalSet Combine<ir::MaxNode>(Analyzer* analzyer, inline IntervalSet Combine<tir::MaxNode>(Analyzer* analzyer,
IntervalSet a, IntervalSet a,
IntervalSet b) { IntervalSet b) {
if (a->IsSinglePoint() && b->IsSinglePoint()) { if (a->IsSinglePoint() && b->IsSinglePoint()) {
...@@ -330,7 +335,7 @@ inline IntervalSet Combine<ir::MaxNode>(Analyzer* analzyer, ...@@ -330,7 +335,7 @@ inline IntervalSet Combine<ir::MaxNode>(Analyzer* analzyer,
} }
template<> template<>
inline IntervalSet Combine<ir::MinNode>(Analyzer* analzyer, inline IntervalSet Combine<tir::MinNode>(Analyzer* analzyer,
IntervalSet a, IntervalSet a,
IntervalSet b) { IntervalSet b) {
if (a->IsSinglePoint() && b->IsSinglePoint()) { if (a->IsSinglePoint() && b->IsSinglePoint()) {
...@@ -351,7 +356,7 @@ IntervalSet ToIntervalSet(IntSet set) { ...@@ -351,7 +356,7 @@ IntervalSet ToIntervalSet(IntSet set) {
return IntervalSet::Everything(); return IntervalSet::Everything();
} }
using namespace ir; using namespace tir;
// Simplified version of int set evaluator that operates on IntervalSet // Simplified version of int set evaluator that operates on IntervalSet
// We might use better set analysis in the future to replace the intervalset. // We might use better set analysis in the future to replace the intervalset.
...@@ -603,17 +608,17 @@ bool IntSet::is_single_point() const { ...@@ -603,17 +608,17 @@ bool IntSet::is_single_point() const {
bool IntSet::can_prove_positive() const { bool IntSet::can_prove_positive() const {
const IntervalSetNode* s_int = (*this).as<IntervalSetNode>(); const IntervalSetNode* s_int = (*this).as<IntervalSetNode>();
return (s_int && is_positive_const(ir::Simplify(s_int->min_value))); return (s_int && is_positive_const(tir::Simplify(s_int->min_value)));
} }
bool IntSet::can_prove_negative() const { bool IntSet::can_prove_negative() const {
const IntervalSetNode* s_int = (*this).as<IntervalSetNode>(); const IntervalSetNode* s_int = (*this).as<IntervalSetNode>();
return (s_int && is_negative_const(ir::Simplify(s_int->max_value))); return (s_int && is_negative_const(tir::Simplify(s_int->max_value)));
} }
bool IntSet::can_prove_non_positive() const { bool IntSet::can_prove_non_positive() const {
if (const auto* s_int = (*this).as<IntervalSetNode>()) { if (const auto* s_int = (*this).as<IntervalSetNode>()) {
auto max = ir::Simplify(s_int->max_value); auto max = tir::Simplify(s_int->max_value);
return is_zero(max) || is_negative_const(max); return is_zero(max) || is_negative_const(max);
} }
return false; return false;
...@@ -621,7 +626,7 @@ bool IntSet::can_prove_non_positive() const { ...@@ -621,7 +626,7 @@ bool IntSet::can_prove_non_positive() const {
bool IntSet::can_prove_non_negative() const { bool IntSet::can_prove_non_negative() const {
if (const IntervalSetNode* s_int = (*this).as<IntervalSetNode>()) { if (const IntervalSetNode* s_int = (*this).as<IntervalSetNode>()) {
auto min = ir::Simplify(s_int->min_value); auto min = tir::Simplify(s_int->min_value);
return is_zero(min) || is_positive_const(min); return is_zero(min) || is_positive_const(min);
} }
return false; return false;
...@@ -665,7 +670,7 @@ IntSet IntSet::interval(PrimExpr min, PrimExpr max) { ...@@ -665,7 +670,7 @@ IntSet IntSet::interval(PrimExpr min, PrimExpr max) {
// Range related code // Range related code
inline bool ProveEqual(PrimExpr lhs, PrimExpr rhs) { inline bool ProveEqual(PrimExpr lhs, PrimExpr rhs) {
return is_zero(ir::Simplify(lhs - rhs)); return is_zero(tir::Simplify(lhs - rhs));
} }
IntSet IntSet::range(Range r) { IntSet IntSet::range(Range r) {
...@@ -692,8 +697,8 @@ IntSet Union(const Array<IntSet>& sets) { ...@@ -692,8 +697,8 @@ IntSet Union(const Array<IntSet>& sets) {
for (size_t i = 1; i < sets.size(); ++i) { for (size_t i = 1; i < sets.size(); ++i) {
x = Union(&ana, x, ToIntervalSet(sets[i])); x = Union(&ana, x, ToIntervalSet(sets[i]));
} }
return IntervalSet(ir::Simplify(x->min_value), return IntervalSet(tir::Simplify(x->min_value),
ir::Simplify(x->max_value)); tir::Simplify(x->max_value));
} }
IntSet Intersect(const Array<IntSet>& sets) { IntSet Intersect(const Array<IntSet>& sets) {
...@@ -704,8 +709,8 @@ IntSet Intersect(const Array<IntSet>& sets) { ...@@ -704,8 +709,8 @@ IntSet Intersect(const Array<IntSet>& sets) {
for (size_t i = 1; i < sets.size(); ++i) { for (size_t i = 1; i < sets.size(); ++i) {
x = Intersect(&ana, x, ToIntervalSet(sets[i])); x = Intersect(&ana, x, ToIntervalSet(sets[i]));
} }
return IntervalSet(ir::Simplify(x->min_value), return IntervalSet(tir::Simplify(x->min_value),
ir::Simplify(x->max_value)); tir::Simplify(x->max_value));
} }
Map<Var, IntSet> ConvertDomMap(const Map<IterVar, IntSet>& dom_map) { Map<Var, IntSet> ConvertDomMap(const Map<IterVar, IntSet>& dom_map) {
......
...@@ -25,7 +25,7 @@ ...@@ -25,7 +25,7 @@
#define TVM_ARITH_INTERVAL_SET_H_ #define TVM_ARITH_INTERVAL_SET_H_
#include <tvm/arith/analyzer.h> #include <tvm/arith/analyzer.h>
#include <tvm/expr_operator.h> #include <tvm/tir/op.h>
#include <limits> #include <limits>
#include "const_fold.h" #include "const_fold.h"
......
...@@ -20,14 +20,14 @@ ...@@ -20,14 +20,14 @@
/*! /*!
* \file tvm/arith/ir_mutator_with_analyzer.cc * \file tvm/arith/ir_mutator_with_analyzer.cc
*/ */
#include <tvm/ir_pass.h> #include <tvm/tir/ir_pass.h>
#include <tvm/expr_operator.h> #include <tvm/tir/op.h>
#include "ir_mutator_with_analyzer.h" #include "ir_mutator_with_analyzer.h"
namespace tvm { namespace tvm {
namespace arith { namespace arith {
using namespace ir; using namespace tir;
Stmt IRMutatorWithAnalyzer:: Stmt IRMutatorWithAnalyzer::
VisitStmt_(const ForNode* op) { VisitStmt_(const ForNode* op) {
...@@ -39,7 +39,7 @@ VisitStmt_(const ForNode* op) { ...@@ -39,7 +39,7 @@ VisitStmt_(const ForNode* op) {
Stmt IRMutatorWithAnalyzer:: Stmt IRMutatorWithAnalyzer::
VisitStmt_(const LetStmtNode* op) { VisitStmt_(const LetStmtNode* op) {
PrimExpr value = this->VisitExpr(op->value); PrimExpr value = this->VisitExpr(op->value);
if (!ir::HasSideEffect(value)) { if (!tir::HasSideEffect(value)) {
analyzer_->Bind(op->var, value); analyzer_->Bind(op->var, value);
} }
// We keep the let-binding here // We keep the let-binding here
...@@ -128,7 +128,7 @@ VisitStmt_(const AssertStmtNode* op) { ...@@ -128,7 +128,7 @@ VisitStmt_(const AssertStmtNode* op) {
PrimExpr IRMutatorWithAnalyzer:: PrimExpr IRMutatorWithAnalyzer::
VisitExpr_(const CallNode* op) { VisitExpr_(const CallNode* op) {
// add condition context to if_then_else // add condition context to if_then_else
if (op->is_intrinsic(ir::intrinsic::tvm_if_then_else)) { if (op->is_intrinsic(tir::intrinsic::tvm_if_then_else)) {
PrimExpr cond = this->VisitExpr(op->args[0]); PrimExpr cond = this->VisitExpr(op->args[0]);
PrimExpr true_value, false_value; PrimExpr true_value, false_value;
{ {
...@@ -162,7 +162,7 @@ VisitExpr_(const CallNode* op) { ...@@ -162,7 +162,7 @@ VisitExpr_(const CallNode* op) {
PrimExpr IRMutatorWithAnalyzer:: PrimExpr IRMutatorWithAnalyzer::
VisitExpr_(const LetNode* op) { VisitExpr_(const LetNode* op) {
PrimExpr value = this->VisitExpr(op->value); PrimExpr value = this->VisitExpr(op->value);
if (!ir::HasSideEffect(value)) { if (!tir::HasSideEffect(value)) {
analyzer_->Bind(op->var, value); analyzer_->Bind(op->var, value);
} }
// We keep the let-binding here // We keep the let-binding here
......
...@@ -24,7 +24,7 @@ ...@@ -24,7 +24,7 @@
#ifndef TVM_ARITH_IR_MUTATOR_WITH_ANALYZER_H_ #ifndef TVM_ARITH_IR_MUTATOR_WITH_ANALYZER_H_
#define TVM_ARITH_IR_MUTATOR_WITH_ANALYZER_H_ #define TVM_ARITH_IR_MUTATOR_WITH_ANALYZER_H_
#include <tvm/ir_functor_ext.h> #include <tvm/tir/stmt_functor.h>
#include <tvm/arith/analyzer.h> #include <tvm/arith/analyzer.h>
#include <utility> #include <utility>
...@@ -40,7 +40,7 @@ namespace arith { ...@@ -40,7 +40,7 @@ namespace arith {
* *
* \sa src/arithmetic/ir_mutator_with_analyzer.cc * \sa src/arithmetic/ir_mutator_with_analyzer.cc
*/ */
class IRMutatorWithAnalyzer : public ir::StmtExprMutator { class IRMutatorWithAnalyzer : public tir::StmtExprMutator {
public: public:
explicit IRMutatorWithAnalyzer(Analyzer* analyzer) explicit IRMutatorWithAnalyzer(Analyzer* analyzer)
: analyzer_(analyzer) {} : analyzer_(analyzer) {}
...@@ -49,15 +49,15 @@ class IRMutatorWithAnalyzer : public ir::StmtExprMutator { ...@@ -49,15 +49,15 @@ class IRMutatorWithAnalyzer : public ir::StmtExprMutator {
using StmtExprMutator::VisitExpr_; using StmtExprMutator::VisitExpr_;
// override functions that need to populate the context information. // override functions that need to populate the context information.
Stmt VisitStmt_(const ir::ForNode* op) override; Stmt VisitStmt_(const tir::ForNode* op) override;
Stmt VisitStmt_(const ir::LetStmtNode* op) override; Stmt VisitStmt_(const tir::LetStmtNode* op) override;
Stmt VisitStmt_(const ir::IfThenElseNode* op) override; Stmt VisitStmt_(const tir::IfThenElseNode* op) override;
Stmt VisitStmt_(const ir::AttrStmtNode* op) override; Stmt VisitStmt_(const tir::AttrStmtNode* op) override;
Stmt VisitStmt_(const ir::AssertStmtNode* op) override; Stmt VisitStmt_(const tir::AssertStmtNode* op) override;
PrimExpr VisitExpr_(const ir::LetNode* op) override; PrimExpr VisitExpr_(const tir::LetNode* op) override;
PrimExpr VisitExpr_(const ir::SelectNode* op) override; PrimExpr VisitExpr_(const tir::SelectNode* op) override;
PrimExpr VisitExpr_(const ir::CallNode* op) override; PrimExpr VisitExpr_(const tir::CallNode* op) override;
PrimExpr VisitExpr_(const ir::ReduceNode* op) override; PrimExpr VisitExpr_(const tir::ReduceNode* op) override;
protected: protected:
/*! \brief internal analyzer field. */ /*! \brief internal analyzer field. */
......
...@@ -26,11 +26,11 @@ ...@@ -26,11 +26,11 @@
#define TVM_ARITH_IR_VISITOR_WITH_ANALYZER_H_ #define TVM_ARITH_IR_VISITOR_WITH_ANALYZER_H_
#include <tvm/arith/analyzer.h> #include <tvm/arith/analyzer.h>
#include <tvm/ir.h> #include <tvm/tir/expr.h>
#include <tvm/ir_functor_ext.h> #include <tvm/tir/stmt_functor.h>
namespace tvm { namespace tvm {
namespace ir { namespace tir {
class IRVisitorWithAnalyzer final : public StmtExprVisitor { class IRVisitorWithAnalyzer final : public StmtExprVisitor {
public: public:
...@@ -71,6 +71,6 @@ class IRVisitorWithAnalyzer final : public StmtExprVisitor { ...@@ -71,6 +71,6 @@ class IRVisitorWithAnalyzer final : public StmtExprVisitor {
arith::Analyzer analyzer_; arith::Analyzer analyzer_;
}; };
} // namespace ir } // namespace tir
} // namespace tvm } // namespace tvm
#endif // TVM_ARITH_IR_VISITOR_WITH_ANALYZER_H_ #endif // TVM_ARITH_IR_VISITOR_WITH_ANALYZER_H_
...@@ -22,8 +22,8 @@ ...@@ -22,8 +22,8 @@
* \brief Modular set analysis * \brief Modular set analysis
*/ */
#include <tvm/arith/analyzer.h> #include <tvm/arith/analyzer.h>
#include <tvm/expr_operator.h> #include <tvm/tir/op.h>
#include <tvm/ir_functor_ext.h> #include <tvm/tir/expr_functor.h>
#include <limits> #include <limits>
#include <utility> #include <utility>
#include <unordered_map> #include <unordered_map>
...@@ -32,7 +32,7 @@ ...@@ -32,7 +32,7 @@
namespace tvm { namespace tvm {
namespace arith { namespace arith {
using namespace ir; using namespace tir;
TVM_REGISTER_NODE_TYPE(ModularSetNode); TVM_REGISTER_NODE_TYPE(ModularSetNode);
......
...@@ -44,7 +44,7 @@ ...@@ -44,7 +44,7 @@
* return (max(x, y) + z).Eval(); * return (max(x, y) + z).Eval();
* } * }
* *
* tvm::Var tx, ty; * tvm::tir::Var tx, ty;
* arith::PVar<IntImm> c; * arith::PVar<IntImm> c;
* arith::PVar<Var> v; * arith::PVar<Var> v;
* // We can match integer and Var, both of which are * // We can match integer and Var, both of which are
...@@ -65,7 +65,7 @@ ...@@ -65,7 +65,7 @@
#ifndef TVM_ARITH_PATTERN_MATCH_H_ #ifndef TVM_ARITH_PATTERN_MATCH_H_
#define TVM_ARITH_PATTERN_MATCH_H_ #define TVM_ARITH_PATTERN_MATCH_H_
#include <tvm/ir_pass.h> #include <tvm/tir/ir_pass.h>
#include <tuple> #include <tuple>
#include "const_fold.h" #include "const_fold.h"
...@@ -135,7 +135,7 @@ class PEqualChecker<PrimExpr> { ...@@ -135,7 +135,7 @@ class PEqualChecker<PrimExpr> {
public: public:
bool operator()(const PrimExpr& lhs, const PrimExpr& rhs) const { bool operator()(const PrimExpr& lhs, const PrimExpr& rhs) const {
if (lhs.same_as(rhs)) return true; if (lhs.same_as(rhs)) return true;
return ir::Equal(lhs, rhs); return tir::Equal(lhs, rhs);
} }
}; };
...@@ -283,7 +283,7 @@ class PConstWithTypeLike : ...@@ -283,7 +283,7 @@ class PConstWithTypeLike :
void InitMatch_() const {} void InitMatch_() const {}
bool Match_(const ObjectRef& node) const { bool Match_(const ObjectRef& node) const {
if (const ir::IntImmNode* ptr = node.as<ir::IntImmNode>()) { if (const tir::IntImmNode* ptr = node.as<tir::IntImmNode>()) {
return ptr->value == value_; return ptr->value == value_;
} else { } else {
return false; return false;
...@@ -291,7 +291,7 @@ class PConstWithTypeLike : ...@@ -291,7 +291,7 @@ class PConstWithTypeLike :
} }
PrimExpr Eval() const { PrimExpr Eval() const {
return make_const(ref_.Eval().dtype(), value_); return tir::make_const(ref_.Eval().dtype(), value_);
} }
private: private:
...@@ -325,30 +325,30 @@ class PConstWithTypeLike : ...@@ -325,30 +325,30 @@ class PConstWithTypeLike :
// raise ambiguity error for operator overload of / and % // raise ambiguity error for operator overload of / and %
TVM_PATTERN_BINARY_OP_EX(operator/, ir::DivNode, DivAmbiguityError(a)); TVM_PATTERN_BINARY_OP_EX(operator/, tir::DivNode, DivAmbiguityError(a));
TVM_PATTERN_BINARY_OP_EX(operator%, ir::ModNode, DivAmbiguityError(a)); TVM_PATTERN_BINARY_OP_EX(operator%, tir::ModNode, DivAmbiguityError(a));
// arithmetic expressions // arithmetic expressions
TVM_PATTERN_BINARY_OP(operator+, ir::AddNode); TVM_PATTERN_BINARY_OP(operator+, tir::AddNode);
TVM_PATTERN_BINARY_OP(operator-, ir::SubNode); TVM_PATTERN_BINARY_OP(operator-, tir::SubNode);
TVM_PATTERN_BINARY_OP(operator*, ir::MulNode); TVM_PATTERN_BINARY_OP(operator*, tir::MulNode);
TVM_PATTERN_BINARY_OP(min, ir::MinNode); TVM_PATTERN_BINARY_OP(min, tir::MinNode);
TVM_PATTERN_BINARY_OP(max, ir::MaxNode); TVM_PATTERN_BINARY_OP(max, tir::MaxNode);
TVM_PATTERN_BINARY_OP(div, ir::DivNode); TVM_PATTERN_BINARY_OP(div, tir::DivNode);
TVM_PATTERN_BINARY_OP(truncdiv, ir::DivNode); TVM_PATTERN_BINARY_OP(truncdiv, tir::DivNode);
TVM_PATTERN_BINARY_OP(truncmod, ir::ModNode); TVM_PATTERN_BINARY_OP(truncmod, tir::ModNode);
TVM_PATTERN_BINARY_OP(floordiv, ir::FloorDivNode); TVM_PATTERN_BINARY_OP(floordiv, tir::FloorDivNode);
TVM_PATTERN_BINARY_OP(floormod, ir::FloorModNode); TVM_PATTERN_BINARY_OP(floormod, tir::FloorModNode);
// logical expressions // logical expressions
TVM_PATTERN_BINARY_OP(operator>, ir::GTNode); TVM_PATTERN_BINARY_OP(operator>, tir::GTNode);
TVM_PATTERN_BINARY_OP(operator>=, ir::GENode); TVM_PATTERN_BINARY_OP(operator>=, tir::GENode);
TVM_PATTERN_BINARY_OP(operator<, ir::LTNode); TVM_PATTERN_BINARY_OP(operator<, tir::LTNode);
TVM_PATTERN_BINARY_OP(operator<=, ir::LENode); TVM_PATTERN_BINARY_OP(operator<=, tir::LENode);
TVM_PATTERN_BINARY_OP(operator==, ir::EQNode); TVM_PATTERN_BINARY_OP(operator==, tir::EQNode);
TVM_PATTERN_BINARY_OP(operator!=, ir::NENode); TVM_PATTERN_BINARY_OP(operator!=, tir::NENode);
TVM_PATTERN_BINARY_OP(operator&&, ir::AndNode); TVM_PATTERN_BINARY_OP(operator&&, tir::AndNode);
TVM_PATTERN_BINARY_OP(operator||, ir::OrNode); TVM_PATTERN_BINARY_OP(operator||, tir::OrNode);
/*! /*!
* \brief Pattern not expression. * \brief Pattern not expression.
...@@ -365,7 +365,7 @@ class PNotExpr : public Pattern<PNotExpr<TA> > { ...@@ -365,7 +365,7 @@ class PNotExpr : public Pattern<PNotExpr<TA> > {
} }
bool Match_(const ObjectRef& node) const { bool Match_(const ObjectRef& node) const {
if (const ir::NotNode* ptr = node.as<ir::NotNode>()) { if (const tir::NotNode* ptr = node.as<tir::NotNode>()) {
if (!value_.Match_(ptr->a)) return false; if (!value_.Match_(ptr->a)) return false;
return true; return true;
} else { } else {
...@@ -374,7 +374,7 @@ class PNotExpr : public Pattern<PNotExpr<TA> > { ...@@ -374,7 +374,7 @@ class PNotExpr : public Pattern<PNotExpr<TA> > {
} }
PrimExpr Eval() const { PrimExpr Eval() const {
return ir::NotNode::make(value_.Eval()); return tir::NotNode::make(value_.Eval());
} }
private: private:
...@@ -411,7 +411,7 @@ class PSelectExpr : ...@@ -411,7 +411,7 @@ class PSelectExpr :
} }
bool Match_(const ObjectRef& node) const { bool Match_(const ObjectRef& node) const {
if (const ir::SelectNode* ptr = node.as<ir::SelectNode>()) { if (const tir::SelectNode* ptr = node.as<tir::SelectNode>()) {
if (!condition_.Match_(ptr->condition)) return false; if (!condition_.Match_(ptr->condition)) return false;
if (!true_value_.Match_(ptr->true_value)) return false; if (!true_value_.Match_(ptr->true_value)) return false;
if (!false_value_.Match_(ptr->false_value)) return false; if (!false_value_.Match_(ptr->false_value)) return false;
...@@ -422,7 +422,7 @@ class PSelectExpr : ...@@ -422,7 +422,7 @@ class PSelectExpr :
} }
PrimExpr Eval() const { PrimExpr Eval() const {
return ir::SelectNode::make( return tir::SelectNode::make(
condition_.Eval(), true_value_.Eval(), false_value_.Eval()); condition_.Eval(), true_value_.Eval(), false_value_.Eval());
} }
...@@ -473,7 +473,7 @@ class PCastExpr : ...@@ -473,7 +473,7 @@ class PCastExpr :
} }
bool Match_(const ObjectRef& node) const { bool Match_(const ObjectRef& node) const {
if (const ir::CastNode* ptr = node.as<ir::CastNode>()) { if (const tir::CastNode* ptr = node.as<tir::CastNode>()) {
if (!dtype_.Match_(ptr->dtype)) return false; if (!dtype_.Match_(ptr->dtype)) return false;
if (!value_.Match_(ptr->value)) return false; if (!value_.Match_(ptr->value)) return false;
return true; return true;
...@@ -483,7 +483,7 @@ class PCastExpr : ...@@ -483,7 +483,7 @@ class PCastExpr :
} }
PrimExpr Eval() const { PrimExpr Eval() const {
return ir::CastNode::make(dtype_.Eval(), value_.Eval()); return tir::CastNode::make(dtype_.Eval(), value_.Eval());
} }
private: private:
...@@ -531,7 +531,7 @@ class PRampExpr : ...@@ -531,7 +531,7 @@ class PRampExpr :
} }
bool Match_(const ObjectRef& node) const { bool Match_(const ObjectRef& node) const {
if (const ir::RampNode* ptr = node.as<ir::RampNode>()) { if (const tir::RampNode* ptr = node.as<tir::RampNode>()) {
if (!base_.Match_(ptr->base)) return false; if (!base_.Match_(ptr->base)) return false;
if (!stride_.Match_(ptr->stride)) return false; if (!stride_.Match_(ptr->stride)) return false;
if (!lanes_.Match_(ptr->lanes)) return false; if (!lanes_.Match_(ptr->lanes)) return false;
...@@ -542,7 +542,7 @@ class PRampExpr : ...@@ -542,7 +542,7 @@ class PRampExpr :
} }
PrimExpr Eval() const { PrimExpr Eval() const {
return ir::RampNode::make(base_.Eval(), stride_.Eval(), lanes_.Eval()); return tir::RampNode::make(base_.Eval(), stride_.Eval(), lanes_.Eval());
} }
private: private:
...@@ -593,7 +593,7 @@ class PBroadcastExpr : ...@@ -593,7 +593,7 @@ class PBroadcastExpr :
} }
bool Match_(const ObjectRef& node) const { bool Match_(const ObjectRef& node) const {
if (const ir::BroadcastNode* ptr = node.as<ir::BroadcastNode>()) { if (const tir::BroadcastNode* ptr = node.as<tir::BroadcastNode>()) {
if (!value_.Match_(ptr->value)) return false; if (!value_.Match_(ptr->value)) return false;
if (!lanes_.Match_(ptr->lanes)) return false; if (!lanes_.Match_(ptr->lanes)) return false;
return true; return true;
...@@ -603,7 +603,7 @@ class PBroadcastExpr : ...@@ -603,7 +603,7 @@ class PBroadcastExpr :
} }
PrimExpr Eval() const { PrimExpr Eval() const {
return ir::BroadcastNode::make(value_.Eval(), lanes_.Eval()); return tir::BroadcastNode::make(value_.Eval(), lanes_.Eval());
} }
private: private:
...@@ -662,10 +662,10 @@ struct PCallExprInitMatchFunctor { ...@@ -662,10 +662,10 @@ struct PCallExprInitMatchFunctor {
}; };
struct PCallExprMatchFunctor { struct PCallExprMatchFunctor {
const ir::CallNode* call_; const tir::CallNode* call_;
bool matched_{true}; bool matched_{true};
explicit PCallExprMatchFunctor(const ir::CallNode* call) explicit PCallExprMatchFunctor(const tir::CallNode* call)
: call_(call) {} : call_(call) {}
template<typename T> template<typename T>
...@@ -705,7 +705,7 @@ class PCallExpr : ...@@ -705,7 +705,7 @@ class PCallExpr :
} }
bool Match_(const ObjectRef& node) const { bool Match_(const ObjectRef& node) const {
if (const ir::CallNode* ptr = node.as<ir::CallNode>()) { if (const tir::CallNode* ptr = node.as<tir::CallNode>()) {
if (ptr->args.size() != sizeof...(TArgs)) return false; if (ptr->args.size() != sizeof...(TArgs)) return false;
if (ptr->name != Op::kName) return false; if (ptr->name != Op::kName) return false;
detail::PCallExprMatchFunctor fmatch(ptr); detail::PCallExprMatchFunctor fmatch(ptr);
...@@ -730,8 +730,8 @@ class PCallExpr : ...@@ -730,8 +730,8 @@ class PCallExpr :
#define TVM_PATTERN_BINARY_INTRIN(FuncName, OpName, IntrinStr) \ #define TVM_PATTERN_BINARY_INTRIN(FuncName, OpName, IntrinStr) \
struct OpName { \ struct OpName { \
static PrimExpr Eval(Array<PrimExpr> args) { \ static PrimExpr Eval(Array<PrimExpr> args) { \
return ir::CallNode::make(args[0].dtype(), kName, args, \ return tir::CallNode::make(args[0].dtype(), kName, args, \
ir::CallNode::PureIntrinsic); \ tir::CallNode::PureIntrinsic); \
} \ } \
static constexpr const char* kName = IntrinStr; \ static constexpr const char* kName = IntrinStr; \
}; \ }; \
...@@ -751,8 +751,8 @@ TVM_PATTERN_BINARY_INTRIN(operator^, PBitwiseXorOp, "bitwise_xor"); ...@@ -751,8 +751,8 @@ TVM_PATTERN_BINARY_INTRIN(operator^, PBitwiseXorOp, "bitwise_xor");
#define TVM_PATTERN_UNARY_INTRIN(FuncName, OpName, IntrinStr) \ #define TVM_PATTERN_UNARY_INTRIN(FuncName, OpName, IntrinStr) \
struct OpName { \ struct OpName { \
static PrimExpr Eval(Array<PrimExpr> args) { \ static PrimExpr Eval(Array<PrimExpr> args) { \
return ir::CallNode::make(args[0].dtype(), kName, args, \ return tir::CallNode::make(args[0].dtype(), kName, args, \
ir::CallNode::PureIntrinsic); \ tir::CallNode::PureIntrinsic); \
} \ } \
static constexpr const char* kName = IntrinStr; \ static constexpr const char* kName = IntrinStr; \
}; \ }; \
...@@ -767,9 +767,9 @@ TVM_PATTERN_UNARY_INTRIN(operator~, PBitwiseNotOp, "bitwise_not"); ...@@ -767,9 +767,9 @@ TVM_PATTERN_UNARY_INTRIN(operator~, PBitwiseNotOp, "bitwise_not");
// if_then_else // if_then_else
struct PIfThenElseOp { struct PIfThenElseOp {
static PrimExpr Eval(Array<PrimExpr> args) { static PrimExpr Eval(Array<PrimExpr> args) {
return ir::CallNode::make( return tir::CallNode::make(
args[1].dtype(), kName, args, args[1].dtype(), kName, args,
ir::CallNode::PureIntrinsic); tir::CallNode::PureIntrinsic);
} }
static constexpr const char* kName = "tvm_if_then_else"; static constexpr const char* kName = "tvm_if_then_else";
}; };
......
...@@ -23,7 +23,7 @@ ...@@ -23,7 +23,7 @@
*/ */
// Acknowledgement: Most rewrite-rules are from Halide. // Acknowledgement: Most rewrite-rules are from Halide.
#include <tvm/arith/analyzer.h> #include <tvm/arith/analyzer.h>
#include <tvm/expr_operator.h> #include <tvm/tir/op.h>
#include <algorithm> #include <algorithm>
#include "const_fold.h" #include "const_fold.h"
#include "pattern_match.h" #include "pattern_match.h"
...@@ -32,7 +32,7 @@ ...@@ -32,7 +32,7 @@
namespace tvm { namespace tvm {
namespace arith { namespace arith {
using namespace ir; using namespace tir;
// macro for doing simple rewrite // macro for doing simple rewrite
#define TVM_TRY_REWRITE(SrcExpr, ResExpr) \ #define TVM_TRY_REWRITE(SrcExpr, ResExpr) \
...@@ -1747,7 +1747,7 @@ VisitExpr_(const CastNode* op) { ...@@ -1747,7 +1747,7 @@ VisitExpr_(const CastNode* op) {
PrimExpr RewriteSimplifier::Impl:: PrimExpr RewriteSimplifier::Impl::
VisitExpr_(const LetNode* op) { VisitExpr_(const LetNode* op) {
PrimExpr value = this->VisitExpr(op->value); PrimExpr value = this->VisitExpr(op->value);
if (!ir::HasSideEffect(value)) { if (!tir::HasSideEffect(value)) {
// it is fine to discard the let binding // it is fine to discard the let binding
// because the value will always be inlined in the simplifier. // because the value will always be inlined in the simplifier.
analyzer_->Bind(op->var, value); analyzer_->Bind(op->var, value);
......
...@@ -25,7 +25,7 @@ ...@@ -25,7 +25,7 @@
#define TVM_ARITH_REWRITE_SIMPLIFY_H_ #define TVM_ARITH_REWRITE_SIMPLIFY_H_
#include <tvm/arith/analyzer.h> #include <tvm/arith/analyzer.h>
#include <tvm/expr_operator.h> #include <tvm/tir/op.h>
#include <unordered_map> #include <unordered_map>
#include <vector> #include <vector>
#include "const_fold.h" #include "const_fold.h"
...@@ -35,7 +35,7 @@ ...@@ -35,7 +35,7 @@
namespace tvm { namespace tvm {
namespace arith { namespace arith {
using namespace ir; using namespace tir;
/*! /*!
* \brief Rewrite-based simplifier. * \brief Rewrite-based simplifier.
......
...@@ -21,17 +21,17 @@ ...@@ -21,17 +21,17 @@
* \file stmt_simplify.cc * \file stmt_simplify.cc
* \brief Statement simplifier based on analyzer * \brief Statement simplifier based on analyzer
*/ */
#include <tvm/ir.h> #include <tvm/tir/expr.h>
#include <tvm/ir_pass.h> #include <tvm/tir/ir_pass.h>
#include <tvm/arith/analyzer.h> #include <tvm/arith/analyzer.h>
#include <tvm/expr_operator.h> #include <tvm/tir/op.h>
#include <tvm/arith/analyzer.h> #include <tvm/arith/analyzer.h>
#include "ir_mutator_with_analyzer.h" #include "ir_mutator_with_analyzer.h"
namespace tvm { namespace tvm {
namespace arith { namespace arith {
using namespace ir; using namespace tir;
class StmtSimplifier : public IRMutatorWithAnalyzer { class StmtSimplifier : public IRMutatorWithAnalyzer {
public: public:
...@@ -59,7 +59,7 @@ class StmtSimplifier : public IRMutatorWithAnalyzer { ...@@ -59,7 +59,7 @@ class StmtSimplifier : public IRMutatorWithAnalyzer {
Stmt VisitStmt_(const LetStmtNode* op) { Stmt VisitStmt_(const LetStmtNode* op) {
PrimExpr value = this->VisitExpr(op->value); PrimExpr value = this->VisitExpr(op->value);
if (!ir::HasSideEffect(value)) { if (!tir::HasSideEffect(value)) {
// it is fine to discard the let binding // it is fine to discard the let binding
// because the call to simplify will always inline the var. // because the call to simplify will always inline the var.
analyzer_->Bind(op->var, value); analyzer_->Bind(op->var, value);
...@@ -93,7 +93,7 @@ class StmtSimplifier : public IRMutatorWithAnalyzer { ...@@ -93,7 +93,7 @@ class StmtSimplifier : public IRMutatorWithAnalyzer {
} // namespace arith } // namespace arith
namespace ir { namespace tir {
Stmt CanonicalSimplify(Stmt stmt, Map<Var, Range> vrange) { Stmt CanonicalSimplify(Stmt stmt, Map<Var, Range> vrange) {
arith::Analyzer analyzer; arith::Analyzer analyzer;
...@@ -123,5 +123,5 @@ PrimExpr Simplify(PrimExpr expr, Map<Var, Range> vrange) { ...@@ -123,5 +123,5 @@ PrimExpr Simplify(PrimExpr expr, Map<Var, Range> vrange) {
Stmt Simplify(Stmt stmt, Map<Var, Range> vrange) { Stmt Simplify(Stmt stmt, Map<Var, Range> vrange) {
return CanonicalSimplify(std::move(stmt), vrange); return CanonicalSimplify(std::move(stmt), vrange);
} }
} // namespace ir } // namespace tir
} // namespace tvm } // namespace tvm
...@@ -60,7 +60,7 @@ void FeatureVisitor::VisitStmt_(const ForNode* op) { ...@@ -60,7 +60,7 @@ void FeatureVisitor::VisitStmt_(const ForNode* op) {
void FeatureVisitor::VisitStmt_(const AttrStmtNode* op) { void FeatureVisitor::VisitStmt_(const AttrStmtNode* op) {
if (op->attr_key == attr::thread_extent || if (op->attr_key == attr::thread_extent ||
op->attr_key == attr::virtual_thread) { op->attr_key == attr::virtual_thread) {
Var var = op->node.as<tvm::IterVarNode>()->var; Var var = op->node.as<tir::IterVarNode>()->var;
const auto *extent = op->value.as<IntImmNode>(); const auto *extent = op->value.as<IntImmNode>();
CHECK(extent); CHECK(extent);
......
...@@ -26,14 +26,15 @@ ...@@ -26,14 +26,15 @@
#ifndef TVM_AUTOTVM_FEATURE_VISITOR_H_ #ifndef TVM_AUTOTVM_FEATURE_VISITOR_H_
#define TVM_AUTOTVM_FEATURE_VISITOR_H_ #define TVM_AUTOTVM_FEATURE_VISITOR_H_
#include <tvm/ir.h> #include <tvm/tir/expr.h>
#include <tvm/ir_functor_ext.h> #include <tvm/tir/stmt.h>
#include <tvm/tir/stmt_functor.h>
#include <string> #include <string>
namespace tvm { namespace tvm {
namespace autotvm { namespace autotvm {
using namespace tvm::ir; using namespace tvm::tir;
/*! /*!
* \brief Type of for loop, used as one-hot encoding in features * \brief Type of for loop, used as one-hot encoding in features
...@@ -69,7 +70,7 @@ class FeatureVisitor : public StmtExprVisitor { ...@@ -69,7 +70,7 @@ class FeatureVisitor : public StmtExprVisitor {
* \param ann_type The type for the for loop * \param ann_type The type for the for loop
* \return skip Whether skip this node * \return skip Whether skip this node
*/ */
virtual bool EnterItervar_(tvm::Var var, int64_t length, AnnotationType ann_type) = 0; virtual bool EnterItervar_(tir::Var var, int64_t length, AnnotationType ann_type) = 0;
/*! \brief Exit a for loop subtree */ /*! \brief Exit a for loop subtree */
virtual void ExitItervar_() = 0; virtual void ExitItervar_() = 0;
/*! /*!
...@@ -77,7 +78,7 @@ class FeatureVisitor : public StmtExprVisitor { ...@@ -77,7 +78,7 @@ class FeatureVisitor : public StmtExprVisitor {
* \param buffer_var The buffer to access. * \param buffer_var The buffer to access.
* \param index Index expression * \param index Index expression
*/ */
virtual void EnterMem_(tvm::Var buffer_var, tvm::PrimExpr index) = 0; virtual void EnterMem_(tir::Var buffer_var, tvm::PrimExpr index) = 0;
/*! \brief Exit a memory access node */ /*! \brief Exit a memory access node */
virtual void ExitMem_() = 0; virtual void ExitMem_() = 0;
}; };
......
...@@ -25,8 +25,8 @@ ...@@ -25,8 +25,8 @@
#ifndef TVM_AUTOTVM_TOUCH_EXTRACTOR_H_ #ifndef TVM_AUTOTVM_TOUCH_EXTRACTOR_H_
#define TVM_AUTOTVM_TOUCH_EXTRACTOR_H_ #define TVM_AUTOTVM_TOUCH_EXTRACTOR_H_
#include <tvm/ir.h> #include <tvm/tir/expr.h>
#include <tvm/ir_functor_ext.h> #include <tvm/tir/expr_functor.h>
#include <tvm/runtime/registry.h> #include <tvm/runtime/registry.h>
#include <stack> #include <stack>
......
...@@ -26,7 +26,8 @@ ...@@ -26,7 +26,8 @@
#include <tvm/codegen.h> #include <tvm/codegen.h>
#include <tvm/runtime/registry.h> #include <tvm/runtime/registry.h>
#include <tvm/ir.h> #include <tvm/tir/expr.h>
#include <tvm/tir/stmt.h>
#include <unordered_map> #include <unordered_map>
#include <string> #include <string>
#include "../runtime/meta_data.h" #include "../runtime/meta_data.h"
......
...@@ -24,7 +24,7 @@ ...@@ -24,7 +24,7 @@
#include <dmlc/thread_local.h> #include <dmlc/thread_local.h>
#include <tvm/build_module.h> #include <tvm/build_module.h>
#include <tvm/top/operation.h> #include <tvm/top/operation.h>
#include <tvm/ir_pass.h> #include <tvm/tir/ir_pass.h>
#include <tvm/codegen.h> #include <tvm/codegen.h>
#include <tvm/runtime/registry.h> #include <tvm/runtime/registry.h>
...@@ -37,6 +37,7 @@ namespace tvm { ...@@ -37,6 +37,7 @@ namespace tvm {
using runtime::TVMArgs; using runtime::TVMArgs;
using runtime::TVMRetValue; using runtime::TVMRetValue;
using runtime::PackedFunc; using runtime::PackedFunc;
using tir::LoweredFunc;
TVM_REGISTER_NODE_TYPE(GenericFuncNode); TVM_REGISTER_NODE_TYPE(GenericFuncNode);
...@@ -58,39 +59,39 @@ Target DefaultTargetHost(Target target) { ...@@ -58,39 +59,39 @@ Target DefaultTargetHost(Target target) {
} }
} }
Buffer BufferWithOffsetAlignment(Array<PrimExpr> shape, tir::Buffer BufferWithOffsetAlignment(Array<PrimExpr> shape,
DataType dtype, DataType dtype,
std::string name, std::string name,
int data_alignment, int data_alignment,
int offset_factor, int offset_factor,
bool compact) { bool compact) {
auto data = Var(name, DataType::Handle()); auto data = tir::Var(name, DataType::Handle());
bool has_any = false; bool has_any = false;
if (!compact) { if (!compact) {
for (const auto& it : shape) { for (const auto& it : shape) {
if (it.as<VarNode>()) { if (it.as<tir::VarNode>()) {
has_any = true; has_any = true;
break; break;
} }
} }
} }
BufferType buffer_type = has_any ? kAutoBroadcast : kDefault; tir::BufferType buffer_type = has_any ? tir::kAutoBroadcast : tir::kDefault;
PrimExpr elem_offset; PrimExpr elem_offset;
if (offset_factor != 0) { if (offset_factor != 0) {
elem_offset = Var(name + "_elem_offset", shape[0].dtype()); elem_offset = tir::Var(name + "_elem_offset", shape[0].dtype());
} else { } else {
elem_offset = PrimExpr(); elem_offset = PrimExpr();
} }
return BufferNode::make(data, dtype, shape, Array<PrimExpr>(), elem_offset, name, "", return tir::BufferNode::make(data, dtype, shape, Array<PrimExpr>(), elem_offset, name, "",
data_alignment, offset_factor, buffer_type); data_alignment, offset_factor, buffer_type);
} }
void GetBinds(const Array<top::Tensor>& args, void GetBinds(const Array<top::Tensor>& args,
bool compact, bool compact,
const std::unordered_map<top::Tensor, Buffer>& binds, const std::unordered_map<top::Tensor, tir::Buffer>& binds,
Map<top::Tensor, Buffer>* out_binds, Map<top::Tensor, tir::Buffer>* out_binds,
Array<ObjectRef>* out_arg_list, Array<ObjectRef>* out_arg_list,
const BuildConfig& config) { const BuildConfig& config) {
*out_binds = binds; *out_binds = binds;
...@@ -117,50 +118,50 @@ void GetBinds(const Array<top::Tensor>& args, ...@@ -117,50 +118,50 @@ void GetBinds(const Array<top::Tensor>& args,
* \param config The build configuration. * \param config The build configuration.
* \return The built Stmt. * \return The built Stmt.
*/ */
Stmt BuildStmt(top::Schedule sch, tir::Stmt BuildStmt(top::Schedule sch,
const Array<top::Tensor>& args, const Array<top::Tensor>& args,
const std::unordered_map<top::Tensor, Buffer>& binds, const std::unordered_map<top::Tensor, tir::Buffer>& binds,
bool loop_partition, bool loop_partition,
Array<ObjectRef> *out_arg_list, Array<ObjectRef> *out_arg_list,
const BuildConfig& config) { const BuildConfig& config) {
sch = sch.normalize(); sch = sch.normalize();
// Phase 0 // Phase 0
auto bounds = top::InferBound(sch); auto bounds = top::InferBound(sch);
auto stmt = top::ScheduleOps(sch, bounds, false); auto stmt = top::ScheduleOps(sch, bounds, false);
stmt = ir::InjectPrefetch(stmt); stmt = tir::InjectPrefetch(stmt);
bool compact = ir::VerifyCompactBuffer(stmt); bool compact = tir::VerifyCompactBuffer(stmt);
Map<top::Tensor, Buffer> out_binds; Map<top::Tensor, tir::Buffer> out_binds;
GetBinds(args, compact, binds, &out_binds, out_arg_list, config); GetBinds(args, compact, binds, &out_binds, out_arg_list, config);
// Phase 1 // Phase 1
stmt = ir::StorageFlatten(stmt, out_binds, 64, stmt = tir::StorageFlatten(stmt, out_binds, 64,
config->instrument_bound_checkers); config->instrument_bound_checkers);
stmt = ir::CanonicalSimplify(stmt); stmt = tir::CanonicalSimplify(stmt);
if (loop_partition) { if (loop_partition) {
stmt = ir::LoopPartition(stmt, config->partition_const_loop); stmt = tir::LoopPartition(stmt, config->partition_const_loop);
} }
if (config->disable_vectorize) { if (config->disable_vectorize) {
stmt = ir::SkipVectorize(stmt); stmt = tir::SkipVectorize(stmt);
} else { } else {
stmt = ir::VectorizeLoop(stmt); stmt = tir::VectorizeLoop(stmt);
} }
stmt = ir::InjectVirtualThread(stmt); stmt = tir::InjectVirtualThread(stmt);
stmt = ir::InjectDoubleBuffer(stmt, config->double_buffer_split_loop); stmt = tir::InjectDoubleBuffer(stmt, config->double_buffer_split_loop);
stmt = ir::StorageRewrite(stmt); stmt = tir::StorageRewrite(stmt);
stmt = ir::UnrollLoop(stmt, config->auto_unroll_max_step, config->auto_unroll_max_depth, stmt = tir::UnrollLoop(stmt, config->auto_unroll_max_step, config->auto_unroll_max_depth,
config->auto_unroll_max_extent, config->unroll_explicit); config->auto_unroll_max_extent, config->unroll_explicit);
// Phase 2 // Phase 2
stmt = ir::Simplify(stmt); stmt = tir::Simplify(stmt);
stmt = ir::RemoveNoOp(stmt); stmt = tir::RemoveNoOp(stmt);
if (!(config->disable_select_rewriting)) if (!(config->disable_select_rewriting))
stmt = ir::RewriteUnsafeSelect(stmt); stmt = tir::RewriteUnsafeSelect(stmt);
if (config->instrument_bound_checkers) if (config->instrument_bound_checkers)
stmt = ir::InstrumentBoundCheckers(stmt); stmt = tir::InstrumentBoundCheckers(stmt);
return stmt; return stmt;
} }
...@@ -168,11 +169,11 @@ Stmt BuildStmt(top::Schedule sch, ...@@ -168,11 +169,11 @@ Stmt BuildStmt(top::Schedule sch,
Array<LoweredFunc> lower(top::Schedule sch, Array<LoweredFunc> lower(top::Schedule sch,
const Array<top::Tensor>& args, const Array<top::Tensor>& args,
const std::string& name, const std::string& name,
const std::unordered_map<top::Tensor, Buffer>& binds, const std::unordered_map<top::Tensor, tir::Buffer>& binds,
const BuildConfig& config) { const BuildConfig& config) {
Array<ObjectRef> out_arg_list; Array<ObjectRef> out_arg_list;
auto stmt = BuildStmt(sch, args, binds, true, &out_arg_list, config); auto stmt = BuildStmt(sch, args, binds, true, &out_arg_list, config);
return Array<LoweredFunc>({ ir::MakeAPI(stmt, name, out_arg_list, 0, config->restricted_func) }); return Array<LoweredFunc>({ tir::MakeAPI(stmt, name, out_arg_list, 0, config->restricted_func) });
} }
Array<Array<LoweredFunc> > split_dev_host_funcs(const Array<LoweredFunc>& funcs, Array<Array<LoweredFunc> > split_dev_host_funcs(const Array<LoweredFunc>& funcs,
...@@ -190,27 +191,27 @@ Array<Array<LoweredFunc> > split_dev_host_funcs(const Array<LoweredFunc>& funcs, ...@@ -190,27 +191,27 @@ Array<Array<LoweredFunc> > split_dev_host_funcs(const Array<LoweredFunc>& funcs,
Array<LoweredFunc> fdevice; Array<LoweredFunc> fdevice;
for (const auto& x : funcs) { for (const auto& x : funcs) {
CHECK(ir::VerifyMemory(x, target->device_type)) CHECK(tir::VerifyMemory(x, target->device_type))
<< "Direct host side access to device memory is detected in " << "Direct host side access to device memory is detected in "
<< x->func_name() << ". Did you forget to bind?"; << x->func_name() << ". Did you forget to bind?";
if (x->func_type == kMixedFunc) { if (x->func_type == tir::kMixedFunc) {
auto func = x; auto func = x;
if (config->detect_global_barrier) { if (config->detect_global_barrier) {
func = ir::ThreadSync(func, "global"); func = tir::ThreadSync(func, "global");
} }
func = ir::ThreadSync(func, "shared"); func = tir::ThreadSync(func, "shared");
func = ir::ThreadSync(func, "warp"); func = tir::ThreadSync(func, "warp");
func = ir::LowerThreadAllreduce(func, target->thread_warp_size); func = tir::LowerThreadAllreduce(func, target->thread_warp_size);
auto fsplits = ir::SplitHostDevice(func); auto fsplits = tir::SplitHostDevice(func);
fhost.push_back(fsplits[0]); fhost.push_back(fsplits[0]);
for (auto f = fsplits.begin() + 1; f != fsplits.end(); ++f) { for (auto f = fsplits.begin() + 1; f != fsplits.end(); ++f) {
fdevice.push_back(*f); fdevice.push_back(*f);
} }
} else if (x->func_type == kHostFunc) { } else if (x->func_type == tir::kHostFunc) {
fhost.push_back(x); fhost.push_back(x);
} else if (x->func_type == kDeviceFunc) { } else if (x->func_type == tir::kDeviceFunc) {
fdevice.push_back(x); fdevice.push_back(x);
} else { } else {
LOG(FATAL) << "unknown function type " << x->func_type; LOG(FATAL) << "unknown function type " << x->func_type;
...@@ -220,7 +221,7 @@ Array<Array<LoweredFunc> > split_dev_host_funcs(const Array<LoweredFunc>& funcs, ...@@ -220,7 +221,7 @@ Array<Array<LoweredFunc> > split_dev_host_funcs(const Array<LoweredFunc>& funcs,
for (size_t i = 0; i < fdevice.size(); i++) { for (size_t i = 0; i < fdevice.size(); i++) {
auto warp_size = target->thread_warp_size; auto warp_size = target->thread_warp_size;
auto func = fdevice[i]; auto func = fdevice[i];
func = ir::LowerWarpMemory(fdevice[i], warp_size); func = tir::LowerWarpMemory(fdevice[i], warp_size);
fdevice.Set(i, func); fdevice.Set(i, func);
} }
...@@ -234,7 +235,7 @@ Array<Array<LoweredFunc> > split_dev_host_funcs(const Array<LoweredFunc>& funcs, ...@@ -234,7 +235,7 @@ Array<Array<LoweredFunc> > split_dev_host_funcs(const Array<LoweredFunc>& funcs,
for (size_t i = 0; i < fdevice.size(); ++i) { for (size_t i = 0; i < fdevice.size(); ++i) {
auto func = fdevice[i]; auto func = fdevice[i];
func = ir::LowerIntrin(func, target->target_name); func = tir::LowerIntrin(func, target->target_name);
fdevice.Set(i, func); fdevice.Set(i, func);
} }
...@@ -247,17 +248,17 @@ Array<Array<LoweredFunc> > split_dev_host_funcs(const Array<LoweredFunc>& funcs, ...@@ -247,17 +248,17 @@ Array<Array<LoweredFunc> > split_dev_host_funcs(const Array<LoweredFunc>& funcs,
for (size_t i = 0; i < fhost.size(); ++i) { for (size_t i = 0; i < fhost.size(); ++i) {
auto func = fhost[i]; auto func = fhost[i];
func = ir::BindDeviceType(func, target->device_type); func = tir::BindDeviceType(func, target->device_type);
func = ir::LowerDeviceStorageAccessInfo(func); func = tir::LowerDeviceStorageAccessInfo(func);
func = ir::LowerTVMBuiltin(func); func = tir::LowerTVMBuiltin(func);
fhost.Set(i, func); fhost.Set(i, func);
} }
for (size_t i = 0; i < fhost.size(); ++i) { for (size_t i = 0; i < fhost.size(); ++i) {
auto func = fhost[i]; auto func = fhost[i];
func = ir::LowerIntrin(func, target_host->target_name); func = tir::LowerIntrin(func, target_host->target_name);
func = ir::LowerDeviceStorageAccessInfo(func); func = tir::LowerDeviceStorageAccessInfo(func);
func = ir::CombineContextCall(func); func = tir::CombineContextCall(func);
fhost.Set(i, func); fhost.Set(i, func);
} }
return {fhost, fdevice}; return {fhost, fdevice};
...@@ -580,7 +581,7 @@ TVM_REGISTER_GLOBAL("_GenericFuncRegisterFunc") ...@@ -580,7 +581,7 @@ TVM_REGISTER_GLOBAL("_GenericFuncRegisterFunc")
std::vector<std::string> tags_vector; std::vector<std::string> tags_vector;
for (auto& tag : tags) { for (auto& tag : tags) {
tags_vector.push_back(tag.as<tvm::ir::StringImmNode>()->value); tags_vector.push_back(tag.as<tvm::tir::StringImmNode>()->value);
} }
generic_func generic_func
......
...@@ -6,9 +6,9 @@ ...@@ -6,9 +6,9 @@
* to you under the Apache License, Version 2.0 (the * to you under the Apache License, Version 2.0 (the
* "License"); you may not use this file except in compliance * "License"); you may not use this file except in compliance
* with the License. You may obtain a copy of the License at * with the License. You may obtain a copy of the License at
* *
* http://www.apache.org/licenses/LICENSE-2.0 * http://www.apache.org/licenses/LICENSE-2.0
* *
* Unless required by applicable law or agreed to in writing, * Unless required by applicable law or agreed to in writing,
* software distributed under the License is distributed on an * software distributed under the License is distributed on an
* "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
...@@ -22,7 +22,7 @@ ...@@ -22,7 +22,7 @@
* \brief Common utilities to generated C style code. * \brief Common utilities to generated C style code.
*/ */
#include <tvm/codegen.h> #include <tvm/codegen.h>
#include <tvm/ir_pass.h> #include <tvm/tir/ir_pass.h>
#include <tvm/runtime/registry.h> #include <tvm/runtime/registry.h>
#include <tvm/runtime/module.h> #include <tvm/runtime/module.h>
#include <tvm/runtime/c_runtime_api.h> #include <tvm/runtime/c_runtime_api.h>
...@@ -37,17 +37,17 @@ ...@@ -37,17 +37,17 @@
namespace tvm { namespace tvm {
namespace codegen { namespace codegen {
runtime::Module Build(const Array<LoweredFunc>& funcs, runtime::Module Build(const Array<tir::LoweredFunc>& funcs,
const std::string& target) { const std::string& target) {
std::string mode = target; std::string mode = target;
size_t pos = mode.find(' '); size_t pos = mode.find(' ');
if (pos != std::string::npos) { if (pos != std::string::npos) {
mode = mode.substr(0, pos); mode = mode.substr(0, pos);
} }
Array<LoweredFunc> transformed_funcs; Array<tir::LoweredFunc> transformed_funcs;
if (BuildConfig::Current()->disable_assert) { if (BuildConfig::Current()->disable_assert) {
for (const auto& x : funcs) { for (const auto& x : funcs) {
auto func = ir::SkipAssert(x); auto func = tir::SkipAssert(x);
transformed_funcs.push_back(func); transformed_funcs.push_back(func);
} }
} }
......
...@@ -23,13 +23,13 @@ ...@@ -23,13 +23,13 @@
#include <iomanip> #include <iomanip>
#include <cctype> #include <cctype>
#include "codegen_c.h" #include "codegen_c.h"
#include "../pass/ir_util.h"
#include "../arith/compute_expr.h" #include "../arith/compute_expr.h"
#include "../tir/pass/ir_util.h"
namespace tvm { namespace tvm {
namespace codegen { namespace codegen {
using namespace ir; using namespace tir;
void CodeGenC::Init(bool output_ssa) { void CodeGenC::Init(bool output_ssa) {
print_ssa_form_ = output_ssa; print_ssa_form_ = output_ssa;
...@@ -809,18 +809,18 @@ void CodeGenC::VisitStmt_(const AllocateNode* op) { ...@@ -809,18 +809,18 @@ void CodeGenC::VisitStmt_(const AllocateNode* op) {
} }
void CodeGenC::VisitStmt_(const AttrStmtNode* op) { void CodeGenC::VisitStmt_(const AttrStmtNode* op) {
if (op->attr_key == ir::attr::thread_extent) { if (op->attr_key == tir::attr::thread_extent) {
IterVar iv = Downcast<IterVar>(op->node); IterVar iv = Downcast<IterVar>(op->node);
if (iv->thread_tag.length() != 0) { if (iv->thread_tag.length() != 0) {
if (!var_idmap_.count(iv->var.get())) { if (!var_idmap_.count(iv->var.get())) {
BindThreadIndex(iv); BindThreadIndex(iv);
} }
} }
} else if (op->attr_key == ir::attr::storage_scope) { } else if (op->attr_key == tir::attr::storage_scope) {
const VarNode* v = op->node.as<VarNode>(); const VarNode* v = op->node.as<VarNode>();
CHECK(v); CHECK(v);
alloc_storage_scope_[v] = op->value.as<StringImmNode>()->value; alloc_storage_scope_[v] = op->value.as<StringImmNode>()->value;
} else if (op->attr_key == ir::attr::volatile_scope) { } else if (op->attr_key == tir::attr::volatile_scope) {
const VarNode* v = op->node.as<VarNode>(); const VarNode* v = op->node.as<VarNode>();
CHECK(v); CHECK(v);
volatile_buf_.insert(v); volatile_buf_.insert(v);
......
...@@ -24,10 +24,11 @@ ...@@ -24,10 +24,11 @@
#ifndef TVM_CODEGEN_CODEGEN_C_H_ #ifndef TVM_CODEGEN_CODEGEN_C_H_
#define TVM_CODEGEN_CODEGEN_C_H_ #define TVM_CODEGEN_CODEGEN_C_H_
#include <tvm/ir.h> #include <tvm/tir/expr.h>
#include <tvm/ir_functor_ext.h> #include <tvm/tir/stmt.h>
#include <tvm/tir/stmt_functor.h>
#include <tvm/codegen.h> #include <tvm/codegen.h>
#include <tvm/lowered_func.h> #include <tvm/tir/lowered_func.h>
#include <string> #include <string>
#include <vector> #include <vector>
#include <unordered_map> #include <unordered_map>
...@@ -37,7 +38,7 @@ ...@@ -37,7 +38,7 @@
namespace tvm { namespace tvm {
namespace codegen { namespace codegen {
using namespace ir; using namespace tir;
/*! /*!
* \brief A base class to generate C code. * \brief A base class to generate C code.
* *
......
...@@ -25,7 +25,7 @@ ...@@ -25,7 +25,7 @@
#define TVM_CODEGEN_CODEGEN_C_HOST_H_ #define TVM_CODEGEN_CODEGEN_C_HOST_H_
#include <tvm/codegen.h> #include <tvm/codegen.h>
#include <tvm/ir.h> #include <tvm/tir/expr.h>
#include <string> #include <string>
#include "codegen_c.h" #include "codegen_c.h"
......
...@@ -22,6 +22,7 @@ ...@@ -22,6 +22,7 @@
*/ */
#include <tvm/runtime/registry.h> #include <tvm/runtime/registry.h>
#include <cmath> #include <cmath>
#include <vector> #include <vector>
#include <string> #include <string>
...@@ -93,9 +94,9 @@ std::string CodeGenCUDA::Finish() { ...@@ -93,9 +94,9 @@ std::string CodeGenCUDA::Finish() {
return CodeGenC::Finish(); return CodeGenC::Finish();
} }
void CodeGenCUDA::VisitStmt_(const ir::ForNode* op) { void CodeGenCUDA::VisitStmt_(const tir::ForNode* op) {
CHECK(is_const_int(op->min, 0)); CHECK(is_const_int(op->min, 0));
if (op->for_type == ir::ForType::Unrolled) { if (op->for_type == tir::ForType::Unrolled) {
PrintIndent(); PrintIndent();
stream << "#pragma unroll\n"; stream << "#pragma unroll\n";
} }
......
...@@ -25,7 +25,7 @@ ...@@ -25,7 +25,7 @@
#define TVM_CODEGEN_CODEGEN_CUDA_H_ #define TVM_CODEGEN_CODEGEN_CUDA_H_
#include <tvm/codegen.h> #include <tvm/codegen.h>
#include <tvm/ir.h> #include <tvm/tir/expr.h>
#include <string> #include <string>
#include <unordered_map> #include <unordered_map>
#include "codegen_c.h" #include "codegen_c.h"
...@@ -43,7 +43,7 @@ class CodeGenCUDA final : public CodeGenC { ...@@ -43,7 +43,7 @@ class CodeGenCUDA final : public CodeGenC {
return (enable_fp16_ || enable_int8_ || need_math_constants_h_ || need_mma_h_); return (enable_fp16_ || enable_int8_ || need_math_constants_h_ || need_mma_h_);
} }
// override behavior // override behavior
void VisitStmt_(const ir::ForNode* op) final; void VisitStmt_(const tir::ForNode* op) final;
void PrintStorageSync(const CallNode* op) final; void PrintStorageSync(const CallNode* op) final;
void PrintStorageScope(const std::string& scope, std::ostream& os) final; // NOLINT(*) void PrintStorageScope(const std::string& scope, std::ostream& os) final; // NOLINT(*)
void PrintVecBinaryOp( void PrintVecBinaryOp(
......
...@@ -69,7 +69,7 @@ std::string CodeGenSourceBase::SSAGetID(std::string src, DataType t) { ...@@ -69,7 +69,7 @@ std::string CodeGenSourceBase::SSAGetID(std::string src, DataType t) {
return e.vid; return e.vid;
} }
std::string CodeGenSourceBase::AllocVarID(const VarNode* v) { std::string CodeGenSourceBase::AllocVarID(const tir::VarNode* v) {
CHECK(!var_idmap_.count(v)) CHECK(!var_idmap_.count(v))
<< "Need input to be in SSA form dup " << v->name_hint; << "Need input to be in SSA form dup " << v->name_hint;
std::string key = v->name_hint; std::string key = v->name_hint;
...@@ -78,7 +78,7 @@ std::string CodeGenSourceBase::AllocVarID(const VarNode* v) { ...@@ -78,7 +78,7 @@ std::string CodeGenSourceBase::AllocVarID(const VarNode* v) {
return vid; return vid;
} }
std::string CodeGenSourceBase::GetVarID(const VarNode* v) const { std::string CodeGenSourceBase::GetVarID(const tir::VarNode* v) const {
auto it = var_idmap_.find(v); auto it = var_idmap_.find(v);
CHECK(it != var_idmap_.end()) CHECK(it != var_idmap_.end())
<< "Find undefined Variable " << v->name_hint; << "Find undefined Variable " << v->name_hint;
......
...@@ -24,7 +24,8 @@ ...@@ -24,7 +24,8 @@
#ifndef TVM_CODEGEN_CODEGEN_SOURCE_BASE_H_ #ifndef TVM_CODEGEN_CODEGEN_SOURCE_BASE_H_
#define TVM_CODEGEN_CODEGEN_SOURCE_BASE_H_ #define TVM_CODEGEN_CODEGEN_SOURCE_BASE_H_
#include <tvm/ir.h> #include <tvm/tir/expr.h>
#include <tvm/tir/op.h>
#include <tvm/codegen.h> #include <tvm/codegen.h>
#include <string> #include <string>
#include <vector> #include <vector>
...@@ -66,13 +67,13 @@ class CodeGenSourceBase { ...@@ -66,13 +67,13 @@ class CodeGenSourceBase {
* \param v The variable. * \param v The variable.
* \return the variable name. * \return the variable name.
*/ */
std::string AllocVarID(const VarNode* v); std::string AllocVarID(const tir::VarNode* v);
/*! /*!
* \brief Get a variable name. * \brief Get a variable name.
* \param v The variable. * \param v The variable.
* \return the variable name. * \return the variable name.
*/ */
std::string GetVarID(const VarNode* v) const; std::string GetVarID(const tir::VarNode* v) const;
/*! /*!
* \brief Get the SSA ID corresponds to src * \brief Get the SSA ID corresponds to src
* If necessary, generate new assignment * If necessary, generate new assignment
...@@ -110,7 +111,7 @@ class CodeGenSourceBase { ...@@ -110,7 +111,7 @@ class CodeGenSourceBase {
/*! \brief the stream to be printed */ /*! \brief the stream to be printed */
std::ostringstream stream; std::ostringstream stream;
/*! \brief name of each variable */ /*! \brief name of each variable */
std::unordered_map<const VarNode*, std::string> var_idmap_; std::unordered_map<const tir::VarNode*, std::string> var_idmap_;
private: private:
/*! \brief assignment map of ssa */ /*! \brief assignment map of ssa */
......
...@@ -25,7 +25,7 @@ ...@@ -25,7 +25,7 @@
#define TVM_CODEGEN_CODEGEN_VHLS_H_ #define TVM_CODEGEN_CODEGEN_VHLS_H_
#include <tvm/codegen.h> #include <tvm/codegen.h>
#include <tvm/ir.h> #include <tvm/tir/expr.h>
#include <string> #include <string>
#include "codegen_c.h" #include "codegen_c.h"
......
...@@ -21,7 +21,7 @@ ...@@ -21,7 +21,7 @@
* \file intrin_rule_default.cc * \file intrin_rule_default.cc
* \brief Default intrinsic rules. * \brief Default intrinsic rules.
*/ */
#include <tvm/expr_operator.h> #include <tvm/tir/op.h>
#include "intrin_rule.h" #include "intrin_rule.h"
namespace tvm { namespace tvm {
......
...@@ -24,15 +24,15 @@ ...@@ -24,15 +24,15 @@
#ifndef TVM_CODEGEN_INTRIN_RULE_H_ #ifndef TVM_CODEGEN_INTRIN_RULE_H_
#define TVM_CODEGEN_INTRIN_RULE_H_ #define TVM_CODEGEN_INTRIN_RULE_H_
#include <tvm/ir.h> #include <tvm/tir/expr.h>
#include <tvm/expr.h> #include <tvm/tir/expr.h>
#include <tvm/runtime/registry.h> #include <tvm/runtime/registry.h>
#include <string> #include <string>
namespace tvm { namespace tvm {
namespace codegen { namespace codegen {
namespace intrin { namespace intrin {
using namespace ir; using namespace tir;
// Add float suffix to the intrinsics // Add float suffix to the intrinsics
struct FloatSuffix { struct FloatSuffix {
......
...@@ -29,7 +29,6 @@ ...@@ -29,7 +29,6 @@
#include "codegen_llvm.h" #include "codegen_llvm.h"
#include "../build_common.h" #include "../build_common.h"
#include "../codegen_source_base.h" #include "../codegen_source_base.h"
#include "../../pass/ir_util.h"
#include "../../runtime/rocm/rocm_module.h" #include "../../runtime/rocm/rocm_module.h"
namespace tvm { namespace tvm {
......
...@@ -58,7 +58,7 @@ llvm::Value* CodeGenARM::CreateIntrinsic(const CallNode* op) { ...@@ -58,7 +58,7 @@ llvm::Value* CodeGenARM::CreateIntrinsic(const CallNode* op) {
} }
PrimExpr CodeGenARM::ARMPopcount(const CallNode *call) { PrimExpr CodeGenARM::ARMPopcount(const CallNode *call) {
using namespace ir; using namespace tir;
const PrimExpr& e = call->args[2]; const PrimExpr& e = call->args[2];
::llvm::Intrinsic::ID ctpop_id = ::llvm::Intrinsic::ctpop; ::llvm::Intrinsic::ID ctpop_id = ::llvm::Intrinsic::ctpop;
::llvm::Intrinsic::ID vpaddlu_id = ::llvm::Intrinsic::arm_neon_vpaddlu; ::llvm::Intrinsic::ID vpaddlu_id = ::llvm::Intrinsic::arm_neon_vpaddlu;
...@@ -71,7 +71,7 @@ PrimExpr CodeGenARM::ARMPopcount(const CallNode *call) { ...@@ -71,7 +71,7 @@ PrimExpr CodeGenARM::ARMPopcount(const CallNode *call) {
vcnt_args.push_back(IntImm(DataType::UInt(32), ctpop_id)); vcnt_args.push_back(IntImm(DataType::UInt(32), ctpop_id));
vcnt_args.push_back(IntImm(DataType::UInt(32), 1)); vcnt_args.push_back(IntImm(DataType::UInt(32), 1));
vcnt_args.push_back(e); vcnt_args.push_back(e);
return ir::CallNode::make(call->dtype, "llvm_intrin", vcnt_args, CallNode::PureIntrinsic); return tir::CallNode::make(call->dtype, "llvm_intrin", vcnt_args, CallNode::PureIntrinsic);
} }
// Popcount lowering rule: // Popcount lowering rule:
...@@ -96,7 +96,7 @@ PrimExpr CodeGenARM::ARMPopcount(const CallNode *call) { ...@@ -96,7 +96,7 @@ PrimExpr CodeGenARM::ARMPopcount(const CallNode *call) {
vcnt8_args.push_back(IntImm(DataType::UInt(32), ctpop_id)); vcnt8_args.push_back(IntImm(DataType::UInt(32), ctpop_id));
vcnt8_args.push_back(IntImm(DataType::UInt(32), 1)); vcnt8_args.push_back(IntImm(DataType::UInt(32), 1));
vcnt8_args.push_back(input8); vcnt8_args.push_back(input8);
PrimExpr vcnt8 = ir::CallNode::make( PrimExpr vcnt8 = tir::CallNode::make(
uint8_type, "llvm_intrin", vcnt8_args, CallNode::PureIntrinsic); uint8_type, "llvm_intrin", vcnt8_args, CallNode::PureIntrinsic);
// Accumulation 8->16bit // Accumulation 8->16bit
...@@ -104,7 +104,7 @@ PrimExpr CodeGenARM::ARMPopcount(const CallNode *call) { ...@@ -104,7 +104,7 @@ PrimExpr CodeGenARM::ARMPopcount(const CallNode *call) {
vcnt16_args.push_back(IntImm(DataType::UInt(32), vpaddlu_id)); vcnt16_args.push_back(IntImm(DataType::UInt(32), vpaddlu_id));
vcnt16_args.push_back(IntImm(DataType::UInt(32), 1)); vcnt16_args.push_back(IntImm(DataType::UInt(32), 1));
vcnt16_args.push_back(vcnt8); vcnt16_args.push_back(vcnt8);
PrimExpr vcnt16 = ir::CallNode::make( PrimExpr vcnt16 = tir::CallNode::make(
uint16_type, "llvm_intrin", vcnt16_args, CallNode::PureIntrinsic); uint16_type, "llvm_intrin", vcnt16_args, CallNode::PureIntrinsic);
if (call->dtype.bits() == 16) { if (call->dtype.bits() == 16) {
return vcnt16; return vcnt16;
...@@ -115,7 +115,7 @@ PrimExpr CodeGenARM::ARMPopcount(const CallNode *call) { ...@@ -115,7 +115,7 @@ PrimExpr CodeGenARM::ARMPopcount(const CallNode *call) {
vcnt32_args.push_back(IntImm(DataType::UInt(32), vpaddlu_id)); vcnt32_args.push_back(IntImm(DataType::UInt(32), vpaddlu_id));
vcnt32_args.push_back(IntImm(DataType::UInt(32), 1)); vcnt32_args.push_back(IntImm(DataType::UInt(32), 1));
vcnt32_args.push_back(vcnt16); vcnt32_args.push_back(vcnt16);
PrimExpr vcnt32 = ir::CallNode::make( PrimExpr vcnt32 = tir::CallNode::make(
uint32_type, "llvm_intrin", vcnt32_args, CallNode::PureIntrinsic); uint32_type, "llvm_intrin", vcnt32_args, CallNode::PureIntrinsic);
if (call->dtype.bits() == 32) { if (call->dtype.bits() == 32) {
return vcnt32; return vcnt32;
...@@ -126,7 +126,7 @@ PrimExpr CodeGenARM::ARMPopcount(const CallNode *call) { ...@@ -126,7 +126,7 @@ PrimExpr CodeGenARM::ARMPopcount(const CallNode *call) {
vcnt64_args.push_back(IntImm(DataType::UInt(32), vpaddlu_id)); vcnt64_args.push_back(IntImm(DataType::UInt(32), vpaddlu_id));
vcnt64_args.push_back(IntImm(DataType::UInt(32), 1)); vcnt64_args.push_back(IntImm(DataType::UInt(32), 1));
vcnt64_args.push_back(vcnt32); vcnt64_args.push_back(vcnt32);
return ir::CallNode::make( return tir::CallNode::make(
call->dtype, "llvm_intrin", vcnt64_args, CallNode::PureIntrinsic); call->dtype, "llvm_intrin", vcnt64_args, CallNode::PureIntrinsic);
} }
......
...@@ -23,11 +23,10 @@ ...@@ -23,11 +23,10 @@
#ifdef TVM_LLVM_VERSION #ifdef TVM_LLVM_VERSION
#include <tvm/runtime/c_runtime_api.h> #include <tvm/runtime/c_runtime_api.h>
#include <tvm/ir_pass.h> #include <tvm/tir/ir_pass.h>
#include <memory> #include <memory>
#include <unordered_map> #include <unordered_map>
#include "codegen_cpu.h" #include "codegen_cpu.h"
#include "../../pass/ir_util.h"
namespace tvm { namespace tvm {
namespace codegen { namespace codegen {
...@@ -423,7 +422,7 @@ void CodeGenCPU::CreateComputeScope(const AttrStmtNode* op) { ...@@ -423,7 +422,7 @@ void CodeGenCPU::CreateComputeScope(const AttrStmtNode* op) {
// - Set noalias on all the pointer arguments, some of them are loaded from TVMArgs. // - Set noalias on all the pointer arguments, some of them are loaded from TVMArgs.
// This is easier than set the alias scope manually. // This is easier than set the alias scope manually.
using llvm::BasicBlock; using llvm::BasicBlock;
Array<Var> vargs = ir::UndefinedVars(op->body, {}); Array<Var> vargs = tir::UndefinedVars(op->body, {});
std::vector<llvm::Value*> arg_values; std::vector<llvm::Value*> arg_values;
std::vector<llvm::Type*> arg_types; std::vector<llvm::Type*> arg_types;
for (Var v : vargs) { for (Var v : vargs) {
...@@ -513,7 +512,7 @@ void CodeGenCPU::CreateParallelLaunch(const Stmt& body, int num_task) { ...@@ -513,7 +512,7 @@ void CodeGenCPU::CreateParallelLaunch(const Stmt& body, int num_task) {
llvm::Function::PrivateLinkage, llvm::Function::PrivateLinkage,
"__tvm_parallel_lambda", module_.get()); "__tvm_parallel_lambda", module_.get());
// allocate and setup the closure, call the closure. // allocate and setup the closure, call the closure.
Array<Var> vfields = ir::UndefinedVars(body, {}); Array<Var> vfields = tir::UndefinedVars(body, {});
uint64_t nbytes; uint64_t nbytes;
llvm::Value* cdata = PackClosureData(vfields, &nbytes); llvm::Value* cdata = PackClosureData(vfields, &nbytes);
BasicBlock* par_launch_end = CheckCallSuccess( BasicBlock* par_launch_end = CheckCallSuccess(
...@@ -582,7 +581,7 @@ void CodeGenCPU::CreateStaticInit(const std::string& init_fname, const Stmt& bod ...@@ -582,7 +581,7 @@ void CodeGenCPU::CreateStaticInit(const std::string& init_fname, const Stmt& bod
} }
// allocate and setup the closure, call the closure. // allocate and setup the closure, call the closure.
uint64_t nbytes; uint64_t nbytes;
Array<Var> vfields = ir::UndefinedVars(body, {}); Array<Var> vfields = tir::UndefinedVars(body, {});
llvm::Value* cdata = PackClosureData(vfields, &nbytes); llvm::Value* cdata = PackClosureData(vfields, &nbytes);
BasicBlock* init_end = CheckCallSuccess( BasicBlock* init_end = CheckCallSuccess(
builder_->CreateCall( builder_->CreateCall(
...@@ -692,7 +691,7 @@ CodeGenCPU::MakeCallPacked(const Array<PrimExpr> &args, llvm::Value **rvalue, ...@@ -692,7 +691,7 @@ CodeGenCPU::MakeCallPacked(const Array<PrimExpr> &args, llvm::Value **rvalue,
BasicBlock *end_block = CheckCallSuccess(builder_->CreateCall( BasicBlock *end_block = CheckCallSuccess(builder_->CreateCall(
RuntimeTVMFuncCall(), {handle, arg_value, arg_tcode, ConstInt32(nargs), RuntimeTVMFuncCall(), {handle, arg_value, arg_tcode, ConstInt32(nargs),
ret_value, *ret_tcode})); ret_value, *ret_tcode}));
DataType r_api_type = ir::APIType(r_type); DataType r_api_type = tir::APIType(r_type);
*rvalue = builder_->CreateAlignedLoad( *rvalue = builder_->CreateAlignedLoad(
builder_->CreatePointerCast(ret_value, builder_->CreatePointerCast(ret_value,
LLVMType(r_api_type)->getPointerTo()), LLVMType(r_api_type)->getPointerTo()),
...@@ -870,9 +869,9 @@ void CodeGenCPU::VisitStmt_(const AssertStmtNode* op) { ...@@ -870,9 +869,9 @@ void CodeGenCPU::VisitStmt_(const AssertStmtNode* op) {
} }
void CodeGenCPU::VisitStmt_(const AttrStmtNode* op) { void CodeGenCPU::VisitStmt_(const AttrStmtNode* op) {
if (op->attr_key == ir::attr::coproc_uop_scope) { if (op->attr_key == tir::attr::coproc_uop_scope) {
this->CreateStaticInit(op->value.as<StringImmNode>()->value, op->body); this->CreateStaticInit(op->value.as<StringImmNode>()->value, op->body);
} else if (op->attr_key == ir::attr::compute_scope) { } else if (op->attr_key == tir::attr::compute_scope) {
this->CreateComputeScope(op); this->CreateComputeScope(op);
} else if (attr::IsPragmaKey(op->attr_key)) { } else if (attr::IsPragmaKey(op->attr_key)) {
if (op->attr_key == "pragma_parallel_stride_pattern") { if (op->attr_key == "pragma_parallel_stride_pattern") {
...@@ -892,7 +891,7 @@ void CodeGenCPU::VisitStmt_(const AttrStmtNode* op) { ...@@ -892,7 +891,7 @@ void CodeGenCPU::VisitStmt_(const AttrStmtNode* op) {
builder_->CreateCall( builder_->CreateCall(
RuntimeTVMParallelBarrier(), RuntimeTVMParallelBarrier(),
{MakeValue(parallel_env_.task_id), parallel_env_.penv}); {MakeValue(parallel_env_.task_id), parallel_env_.penv});
} else if (op->attr_key == ir::attr::pragma_import_llvm) { } else if (op->attr_key == tir::attr::pragma_import_llvm) {
const StringImmNode* value = op->value.as<StringImmNode>(); const StringImmNode* value = op->value.as<StringImmNode>();
CHECK(value != nullptr); CHECK(value != nullptr);
this->HandleImport(value->value); this->HandleImport(value->value);
......
...@@ -30,9 +30,6 @@ ...@@ -30,9 +30,6 @@
#include "codegen_llvm.h" #include "codegen_llvm.h"
#include "codegen_cpu.h" #include "codegen_cpu.h"
#include "../build_common.h" #include "../build_common.h"
#include "../../pass/ir_util.h"
#include "../../arith/compute_expr.h"
namespace tvm { namespace tvm {
namespace codegen { namespace codegen {
...@@ -1179,17 +1176,17 @@ void CodeGenLLVM::VisitStmt_(const AttrStmtNode* op) { ...@@ -1179,17 +1176,17 @@ void CodeGenLLVM::VisitStmt_(const AttrStmtNode* op) {
analyzer_->Bind(iv->var, Range::make_by_min_extent(0, op->value)); analyzer_->Bind(iv->var, Range::make_by_min_extent(0, op->value));
} }
} }
} else if (op->attr_key == ir::attr::storage_scope) { } else if (op->attr_key == tir::attr::storage_scope) {
const VarNode* v = op->node.as<VarNode>(); const VarNode* v = op->node.as<VarNode>();
CHECK(v); CHECK(v);
alloc_storage_info_[v].scope = alloc_storage_info_[v].scope =
runtime::StorageScope::make(op->value.as<StringImmNode>()->value); runtime::StorageScope::make(op->value.as<StringImmNode>()->value);
} else if (op->attr_key == ir::attr::storage_alignment) { } else if (op->attr_key == tir::attr::storage_alignment) {
const VarNode* v = op->node.as<VarNode>(); const VarNode* v = op->node.as<VarNode>();
CHECK(v); CHECK(v);
alloc_storage_info_[v].alignment = alloc_storage_info_[v].alignment =
static_cast<int>(op->value.as<IntImmNode>()->value); static_cast<int>(op->value.as<IntImmNode>()->value);
} else if (op->attr_key == ir::attr::volatile_scope) { } else if (op->attr_key == tir::attr::volatile_scope) {
const VarNode* v = op->node.as<VarNode>(); const VarNode* v = op->node.as<VarNode>();
CHECK(v); CHECK(v);
volatile_buf_.insert(v); volatile_buf_.insert(v);
......
...@@ -26,8 +26,10 @@ ...@@ -26,8 +26,10 @@
#ifdef TVM_LLVM_VERSION #ifdef TVM_LLVM_VERSION
#include <tvm/arith/analyzer.h> #include <tvm/arith/analyzer.h>
#include <tvm/ir.h> #include <tvm/tir/expr.h>
#include <tvm/ir_functor_ext.h> #include <tvm/tir/stmt.h>
#include <tvm/tir/op.h>
#include <tvm/tir/stmt_functor.h>
#include <tvm/codegen.h> #include <tvm/codegen.h>
#include <memory> #include <memory>
#include <utility> #include <utility>
...@@ -37,11 +39,13 @@ ...@@ -37,11 +39,13 @@
#include <unordered_set> #include <unordered_set>
#include "llvm_common.h" #include "llvm_common.h"
#include "../../runtime/thread_storage_scope.h" #include "../../runtime/thread_storage_scope.h"
#include "../../arith/compute_expr.h"
#include "../../tir/pass/ir_util.h"
namespace tvm { namespace tvm {
namespace codegen { namespace codegen {
using namespace ir; using namespace tir;
/*! /*!
* \brief A base class to generate a LLVM. * \brief A base class to generate a LLVM.
......
...@@ -26,7 +26,6 @@ ...@@ -26,7 +26,6 @@
#include <tvm/runtime/device_api.h> #include <tvm/runtime/device_api.h>
#include "codegen_llvm.h" #include "codegen_llvm.h"
#include "../build_common.h" #include "../build_common.h"
#include "../../pass/ir_util.h"
#include "../../runtime/cuda/cuda_module.h" #include "../../runtime/cuda/cuda_module.h"
namespace tvm { namespace tvm {
......
...@@ -90,11 +90,11 @@ llvm::Value* CodeGenX86_64::VisitExpr_(const CastNode* op) { ...@@ -90,11 +90,11 @@ llvm::Value* CodeGenX86_64::VisitExpr_(const CastNode* op) {
::llvm::Intrinsic::x86_avx512_mask_vcvtph2ps_512, 16, ::llvm::Intrinsic::x86_avx512_mask_vcvtph2ps_512, 16,
LLVMType(DataType::Float(32, from.lanes())), LLVMType(DataType::Float(32, from.lanes())),
{ {
MakeValue(ir::CallNode::make( MakeValue(tir::CallNode::make(
DataType::Int(16, from.lanes()), ir::CallNode::reinterpret, {op->value}, DataType::Int(16, from.lanes()), tir::CallNode::reinterpret, {op->value},
ir::CallNode::PureIntrinsic)), tir::CallNode::PureIntrinsic)),
MakeValue( MakeValue(
ir::BroadcastNode::make( tir::BroadcastNode::make(
FloatImm(DataType::Float(32), 0), from.lanes())), FloatImm(DataType::Float(32), 0), from.lanes())),
/*mask=*/MakeValue(IntImm(DataType::Int(16), -1)), /*mask=*/MakeValue(IntImm(DataType::Int(16), -1)),
/*rounding-mode=*/MakeValue(IntImm(DataType::Int(32), 4)), /*rounding-mode=*/MakeValue(IntImm(DataType::Int(32), 4)),
...@@ -104,9 +104,9 @@ llvm::Value* CodeGenX86_64::VisitExpr_(const CastNode* op) { ...@@ -104,9 +104,9 @@ llvm::Value* CodeGenX86_64::VisitExpr_(const CastNode* op) {
if (from.lanes() >= 8 && has_f16c) { if (from.lanes() >= 8 && has_f16c) {
return CallVectorIntrin( return CallVectorIntrin(
::llvm::Intrinsic::x86_vcvtph2ps_256, 8, LLVMType(DataType::Float(32, from.lanes())), ::llvm::Intrinsic::x86_vcvtph2ps_256, 8, LLVMType(DataType::Float(32, from.lanes())),
{MakeValue(ir::CallNode::make( {MakeValue(tir::CallNode::make(
DataType::Int(16, from.lanes()), ir::CallNode::reinterpret, {op->value}, DataType::Int(16, from.lanes()), tir::CallNode::reinterpret, {op->value},
ir::CallNode::PureIntrinsic))}); tir::CallNode::PureIntrinsic))});
} }
} }
......
...@@ -22,6 +22,7 @@ ...@@ -22,6 +22,7 @@
*/ */
#ifdef TVM_LLVM_VERSION #ifdef TVM_LLVM_VERSION
#include <tvm/tir/op.h>
#include "intrin_rule_llvm.h" #include "intrin_rule_llvm.h"
namespace tvm { namespace tvm {
...@@ -63,22 +64,24 @@ TVM_REGISTER_GLOBAL("tvm.intrin.rule.llvm.nearbyint") ...@@ -63,22 +64,24 @@ TVM_REGISTER_GLOBAL("tvm.intrin.rule.llvm.nearbyint")
TVM_REGISTER_GLOBAL("tvm.intrin.rule.llvm.tanh") TVM_REGISTER_GLOBAL("tvm.intrin.rule.llvm.tanh")
.set_body([](const TVMArgs& targs, TVMRetValue* rv) { .set_body([](const TVMArgs& targs, TVMRetValue* rv) {
using tir::make_const;
using tir::make_zero;
PrimExpr e = targs[0]; PrimExpr e = targs[0];
const ir::CallNode* call = e.as<ir::CallNode>(); const tir::CallNode* call = e.as<tir::CallNode>();
CHECK(call != nullptr); CHECK(call != nullptr);
const PrimExpr& x = call->args[0]; const PrimExpr& x = call->args[0];
PrimExpr one = make_const(x.dtype(), 1); PrimExpr one = make_const(x.dtype(), 1);
PrimExpr two = make_const(x.dtype(), 2); PrimExpr two = make_const(x.dtype(), 2);
PrimExpr neg_two = make_const(x.dtype(), -2); PrimExpr neg_two = make_const(x.dtype(), -2);
PrimExpr exp_neg2x = ir::CallNode::make( PrimExpr exp_neg2x = tir::CallNode::make(
x.dtype(), "exp", {neg_two * x}, ir::CallNode::PureIntrinsic); x.dtype(), "exp", {neg_two * x}, tir::CallNode::PureIntrinsic);
PrimExpr exp_pos2x = ir::CallNode::make( PrimExpr exp_pos2x = tir::CallNode::make(
x.dtype(), "exp", {two * x}, ir::CallNode::PureIntrinsic); x.dtype(), "exp", {two * x}, tir::CallNode::PureIntrinsic);
PrimExpr tanh_pos = (one - exp_neg2x) / (one + exp_neg2x); PrimExpr tanh_pos = (one - exp_neg2x) / (one + exp_neg2x);
PrimExpr tanh_neg = (exp_pos2x - one) / (exp_pos2x + one); PrimExpr tanh_neg = (exp_pos2x - one) / (exp_pos2x + one);
*rv = ir::SelectNode::make( *rv = tir::SelectNode::make(
x >= make_zero(x.dtype()), tanh_pos, tanh_neg); x >= make_zero(x.dtype()), tanh_pos, tanh_neg);
}); });
......
...@@ -25,7 +25,7 @@ ...@@ -25,7 +25,7 @@
#define TVM_CODEGEN_LLVM_INTRIN_RULE_LLVM_H_ #define TVM_CODEGEN_LLVM_INTRIN_RULE_LLVM_H_
#ifdef TVM_LLVM_VERSION #ifdef TVM_LLVM_VERSION
#include <tvm/ir.h> #include <tvm/tir/expr.h>
#include <tvm/runtime/registry.h> #include <tvm/runtime/registry.h>
#include <tvm/codegen.h> #include <tvm/codegen.h>
...@@ -38,7 +38,7 @@ namespace codegen { ...@@ -38,7 +38,7 @@ namespace codegen {
template<unsigned id, int num_signature> template<unsigned id, int num_signature>
inline void DispatchLLVMPureIntrin(const TVMArgs& targs, TVMRetValue* rv) { inline void DispatchLLVMPureIntrin(const TVMArgs& targs, TVMRetValue* rv) {
PrimExpr e = targs[0]; PrimExpr e = targs[0];
const ir::CallNode* call = e.as<ir::CallNode>(); const tir::CallNode* call = e.as<tir::CallNode>();
CHECK(call != nullptr); CHECK(call != nullptr);
Array<PrimExpr> cargs; Array<PrimExpr> cargs;
// intrin id. // intrin id.
...@@ -48,14 +48,14 @@ inline void DispatchLLVMPureIntrin(const TVMArgs& targs, TVMRetValue* rv) { ...@@ -48,14 +48,14 @@ inline void DispatchLLVMPureIntrin(const TVMArgs& targs, TVMRetValue* rv) {
for (PrimExpr arg : call->args) { for (PrimExpr arg : call->args) {
cargs.push_back(arg); cargs.push_back(arg);
} }
*rv = ir::CallNode::make( *rv = tir::CallNode::make(
call->dtype, "llvm_intrin", cargs, ir::CallNode::PureIntrinsic); call->dtype, "llvm_intrin", cargs, tir::CallNode::PureIntrinsic);
} }
template<unsigned id, int num_signature> template<unsigned id, int num_signature>
inline void DispatchLLVMIntrin(const TVMArgs& targs, TVMRetValue* rv) { inline void DispatchLLVMIntrin(const TVMArgs& targs, TVMRetValue* rv) {
PrimExpr e = targs[0]; PrimExpr e = targs[0];
const ir::CallNode* call = e.as<ir::CallNode>(); const tir::CallNode* call = e.as<tir::CallNode>();
CHECK(call != nullptr); CHECK(call != nullptr);
Array<PrimExpr> cargs; Array<PrimExpr> cargs;
// intrin id. // intrin id.
...@@ -64,8 +64,8 @@ inline void DispatchLLVMIntrin(const TVMArgs& targs, TVMRetValue* rv) { ...@@ -64,8 +64,8 @@ inline void DispatchLLVMIntrin(const TVMArgs& targs, TVMRetValue* rv) {
for (PrimExpr arg : call->args) { for (PrimExpr arg : call->args) {
cargs.push_back(arg); cargs.push_back(arg);
} }
*rv = ir::CallNode::make( *rv = tir::CallNode::make(
call->dtype, "llvm_intrin", cargs, ir::CallNode::Intrinsic); call->dtype, "llvm_intrin", cargs, tir::CallNode::Intrinsic);
} }
} // namespace codegen } // namespace codegen
......
...@@ -22,8 +22,8 @@ ...@@ -22,8 +22,8 @@
*/ */
#ifdef TVM_LLVM_VERSION #ifdef TVM_LLVM_VERSION
#include <tvm/ir.h> #include <tvm/tir/expr.h>
#include <tvm/expr.h> #include <tvm/tir/expr.h>
#include <tvm/runtime/registry.h> #include <tvm/runtime/registry.h>
#include <sstream> #include <sstream>
...@@ -32,7 +32,7 @@ namespace codegen { ...@@ -32,7 +32,7 @@ namespace codegen {
inline void DispatchExternLibDevice(const TVMArgs& args, TVMRetValue* rv) { inline void DispatchExternLibDevice(const TVMArgs& args, TVMRetValue* rv) {
PrimExpr e = args[0]; PrimExpr e = args[0];
using namespace ir; using namespace tir;
const CallNode* call = e.as<CallNode>(); const CallNode* call = e.as<CallNode>();
CHECK(call != nullptr); CHECK(call != nullptr);
CHECK(call->dtype.bits() == 32 || call->dtype.bits() == 64) << "Only support float32 or float64."; CHECK(call->dtype.bits() == 32 || call->dtype.bits() == 64) << "Only support float32 or float64.";
......
...@@ -22,8 +22,8 @@ ...@@ -22,8 +22,8 @@
*/ */
#ifdef TVM_LLVM_VERSION #ifdef TVM_LLVM_VERSION
#include <tvm/ir.h> #include <tvm/tir/expr.h>
#include <tvm/expr.h> #include <tvm/tir/expr.h>
#include <tvm/runtime/registry.h> #include <tvm/runtime/registry.h>
#include <sstream> #include <sstream>
...@@ -33,7 +33,7 @@ namespace codegen { ...@@ -33,7 +33,7 @@ namespace codegen {
inline void DispatchExternOCML(const TVMArgs& args, TVMRetValue* rv) { inline void DispatchExternOCML(const TVMArgs& args, TVMRetValue* rv) {
PrimExpr e = args[0]; PrimExpr e = args[0];
using namespace ir; using namespace tir;
const CallNode* call = e.as<CallNode>(); const CallNode* call = e.as<CallNode>();
CHECK(call != nullptr); CHECK(call != nullptr);
std::ostringstream intrinsic_name; std::ostringstream intrinsic_name;
......
...@@ -24,7 +24,7 @@ ...@@ -24,7 +24,7 @@
// Use libspirv for parsing and validating code. // Use libspirv for parsing and validating code.
#include <libspirv.h> #include <libspirv.h>
#include <dmlc/memory_io.h> #include <dmlc/memory_io.h>
#include <tvm/ir_pass.h> #include <tvm/tir/ir_pass.h>
#include "codegen_spirv.h" #include "codegen_spirv.h"
#include "../build_common.h" #include "../build_common.h"
......
...@@ -21,8 +21,8 @@ ...@@ -21,8 +21,8 @@
* \file codegen_spirv.cc * \file codegen_spirv.cc
* \brief Generate SPIRV block * \brief Generate SPIRV block
*/ */
#include <tvm/ir.h> #include <tvm/tir/expr.h>
#include <tvm/ir_pass.h> #include <tvm/tir/ir_pass.h>
#include <string> #include <string>
#include "codegen_spirv.h" #include "codegen_spirv.h"
#include "../../arith/compute_expr.h" #include "../../arith/compute_expr.h"
...@@ -406,7 +406,7 @@ spirv::Value CodeGenSPIRV::VisitExpr_(const LoadNode* op) { ...@@ -406,7 +406,7 @@ spirv::Value CodeGenSPIRV::VisitExpr_(const LoadNode* op) {
CHECK((me->coeff % ramp->lanes) == 0 && CHECK((me->coeff % ramp->lanes) == 0 &&
(me->base % ramp->lanes) == 0) (me->base % ramp->lanes) == 0)
<< "Only aligned vector access is allowed in SPIRV"; << "Only aligned vector access is allowed in SPIRV";
PrimExpr vec_index = ir::Simplify( PrimExpr vec_index = tir::Simplify(
ramp->base / make_const(ramp->base.dtype(), ramp->lanes)); ramp->base / make_const(ramp->base.dtype(), ramp->lanes));
spirv::Value ptr = builder_->StructArrayAccess( spirv::Value ptr = builder_->StructArrayAccess(
ptr_type, buffer, MakeValue(vec_index)); ptr_type, buffer, MakeValue(vec_index));
...@@ -484,7 +484,7 @@ void CodeGenSPIRV::VisitStmt_(const StoreNode* op) { ...@@ -484,7 +484,7 @@ void CodeGenSPIRV::VisitStmt_(const StoreNode* op) {
CHECK((me->coeff % ramp->lanes) == 0 && CHECK((me->coeff % ramp->lanes) == 0 &&
(me->base % ramp->lanes) == 0) (me->base % ramp->lanes) == 0)
<< "Only aligned vector access is allowed in SPIRV"; << "Only aligned vector access is allowed in SPIRV";
PrimExpr vec_index = ir::Simplify( PrimExpr vec_index = tir::Simplify(
ramp->base / make_const(ramp->base.dtype(), ramp->lanes)); ramp->base / make_const(ramp->base.dtype(), ramp->lanes));
spirv::Value ptr = builder_->StructArrayAccess( spirv::Value ptr = builder_->StructArrayAccess(
ptr_type, buffer, MakeValue(vec_index)); ptr_type, buffer, MakeValue(vec_index));
...@@ -615,12 +615,12 @@ void CodeGenSPIRV::VisitStmt_(const AttrStmtNode* op) { ...@@ -615,12 +615,12 @@ void CodeGenSPIRV::VisitStmt_(const AttrStmtNode* op) {
analyzer_->Bind(iv->var, Range::make_by_min_extent(0, op->value)); analyzer_->Bind(iv->var, Range::make_by_min_extent(0, op->value));
} }
} }
} else if (op->attr_key == ir::attr::storage_scope) { } else if (op->attr_key == tir::attr::storage_scope) {
const VarNode* v = op->node.as<VarNode>(); const VarNode* v = op->node.as<VarNode>();
CHECK(v); CHECK(v);
storage_info_[v].scope = storage_info_[v].scope =
runtime::StorageScope::make(op->value.as<StringImmNode>()->value); runtime::StorageScope::make(op->value.as<StringImmNode>()->value);
} else if (op->attr_key == ir::attr::volatile_scope) { } else if (op->attr_key == tir::attr::volatile_scope) {
const VarNode* v = op->node.as<VarNode>(); const VarNode* v = op->node.as<VarNode>();
CHECK(v); CHECK(v);
storage_info_[v].is_volatile = true; storage_info_[v].is_volatile = true;
......
...@@ -25,9 +25,9 @@ ...@@ -25,9 +25,9 @@
#define TVM_CODEGEN_SPIRV_CODEGEN_SPIRV_H_ #define TVM_CODEGEN_SPIRV_CODEGEN_SPIRV_H_
#include <tvm/arith/analyzer.h> #include <tvm/arith/analyzer.h>
#include <tvm/ir.h> #include <tvm/tir/expr.h>
#include <tvm/ir_functor_ext.h> #include <tvm/tir/stmt_functor.h>
#include <tvm/lowered_func.h> #include <tvm/tir/lowered_func.h>
#include <vector> #include <vector>
#include <memory> #include <memory>
...@@ -39,7 +39,7 @@ ...@@ -39,7 +39,7 @@
namespace tvm { namespace tvm {
namespace codegen { namespace codegen {
using namespace ir; using namespace tir;
/*! /*!
* \brief Code generator into SPIRV * \brief Code generator into SPIRV
......
...@@ -21,7 +21,7 @@ ...@@ -21,7 +21,7 @@
* \file intrin_rule_spirv.cc * \file intrin_rule_spirv.cc
*/ */
#include <tvm/runtime/registry.h> #include <tvm/runtime/registry.h>
#include <tvm/ir.h> #include <tvm/tir/expr.h>
#include <GLSL.std.450.h> #include <GLSL.std.450.h>
namespace tvm { namespace tvm {
...@@ -34,7 +34,7 @@ using namespace runtime; ...@@ -34,7 +34,7 @@ using namespace runtime;
template<unsigned id> template<unsigned id>
inline void DispatchGLSLPureIntrin(const TVMArgs& targs, TVMRetValue* rv) { inline void DispatchGLSLPureIntrin(const TVMArgs& targs, TVMRetValue* rv) {
PrimExpr e = targs[0]; PrimExpr e = targs[0];
const ir::CallNode* call = e.as<ir::CallNode>(); const tir::CallNode* call = e.as<tir::CallNode>();
CHECK(call != nullptr); CHECK(call != nullptr);
Array<PrimExpr> cargs; Array<PrimExpr> cargs;
// intrin id. // intrin id.
...@@ -43,8 +43,8 @@ inline void DispatchGLSLPureIntrin(const TVMArgs& targs, TVMRetValue* rv) { ...@@ -43,8 +43,8 @@ inline void DispatchGLSLPureIntrin(const TVMArgs& targs, TVMRetValue* rv) {
for (PrimExpr arg : call->args) { for (PrimExpr arg : call->args) {
cargs.push_back(arg); cargs.push_back(arg);
} }
*rv = ir::CallNode::make( *rv = tir::CallNode::make(
call->dtype, "spirv_glsl450", cargs, ir::CallNode::PureIntrinsic); call->dtype, "spirv_glsl450", cargs, tir::CallNode::PureIntrinsic);
} }
TVM_REGISTER_GLOBAL("tvm.intrin.rule.vulkan.floor") TVM_REGISTER_GLOBAL("tvm.intrin.rule.vulkan.floor")
......
...@@ -25,7 +25,7 @@ ...@@ -25,7 +25,7 @@
#define TVM_CODEGEN_SPIRV_IR_BUILDER_H_ #define TVM_CODEGEN_SPIRV_IR_BUILDER_H_
#include <tvm/runtime/packed_func.h> #include <tvm/runtime/packed_func.h>
#include <tvm/ir.h> #include <tvm/tir/expr.h>
#include <algorithm> #include <algorithm>
#include <utility> #include <utility>
......
...@@ -21,6 +21,7 @@ ...@@ -21,6 +21,7 @@
* \file codegen_stackvm.cc * \file codegen_stackvm.cc
*/ */
#include <tvm/runtime/registry.h> #include <tvm/runtime/registry.h>
#include <tvm/tir/op.h>
#include <limits> #include <limits>
#include <utility> #include <utility>
#include "codegen_stackvm.h" #include "codegen_stackvm.h"
...@@ -29,7 +30,7 @@ ...@@ -29,7 +30,7 @@
namespace tvm { namespace tvm {
namespace codegen { namespace codegen {
using namespace ir; using namespace tir;
// map struct field kind to runtime variants // map struct field kind to runtime variants
// We keep two separate enums to ensure runtime/compiler isolation. // We keep two separate enums to ensure runtime/compiler isolation.
......
...@@ -24,9 +24,9 @@ ...@@ -24,9 +24,9 @@
#ifndef TVM_CODEGEN_STACKVM_CODEGEN_STACKVM_H_ #ifndef TVM_CODEGEN_STACKVM_CODEGEN_STACKVM_H_
#define TVM_CODEGEN_STACKVM_CODEGEN_STACKVM_H_ #define TVM_CODEGEN_STACKVM_CODEGEN_STACKVM_H_
#include <tvm/ir.h> #include <tvm/tir/expr.h>
#include <tvm/ir_functor_ext.h> #include <tvm/tir/stmt_functor.h>
#include <tvm/lowered_func.h> #include <tvm/tir/lowered_func.h>
#include <tvm/codegen.h> #include <tvm/codegen.h>
#include <string> #include <string>
#include <vector> #include <vector>
...@@ -37,7 +37,7 @@ ...@@ -37,7 +37,7 @@
namespace tvm { namespace tvm {
namespace codegen { namespace codegen {
using namespace ir; using namespace tir;
using runtime::StackVM; using runtime::StackVM;
/*! /*!
......
...@@ -31,7 +31,7 @@ namespace contrib { ...@@ -31,7 +31,7 @@ namespace contrib {
using runtime::TVMArgs; using runtime::TVMArgs;
using runtime::TVMRetValue; using runtime::TVMRetValue;
using namespace ir; using namespace tir;
std::string dot_to_underscore(std::string s) { std::string dot_to_underscore(std::string s) {
for (auto &ch : s) for (auto &ch : s)
...@@ -288,7 +288,7 @@ void CodeGenHybrid::VisitStmt_(const LetStmtNode* op) { ...@@ -288,7 +288,7 @@ void CodeGenHybrid::VisitStmt_(const LetStmtNode* op) {
} }
void CodeGenHybrid::VisitStmt_(const AttrStmtNode* op) { void CodeGenHybrid::VisitStmt_(const AttrStmtNode* op) {
if (op->attr_key == ir::attr::thread_extent) { if (op->attr_key == tir::attr::thread_extent) {
auto iter_var = op->node.as<IterVarNode>(); auto iter_var = op->node.as<IterVarNode>();
CHECK(iter_var); CHECK(iter_var);
binds_[iter_var->var.get()] = dot_to_underscore(iter_var->var->name_hint); binds_[iter_var->var.get()] = dot_to_underscore(iter_var->var->name_hint);
...@@ -300,7 +300,7 @@ void CodeGenHybrid::VisitStmt_(const AttrStmtNode* op) { ...@@ -300,7 +300,7 @@ void CodeGenHybrid::VisitStmt_(const AttrStmtNode* op) {
indent_ += tab_; indent_ += tab_;
PrintStmt(op->body); PrintStmt(op->body);
indent_ -= tab_; indent_ -= tab_;
} else if (op->attr_key == ir::attr::realize_scope) { } else if (op->attr_key == tir::attr::realize_scope) {
auto v = Downcast<FunctionRef>(op->node); auto v = Downcast<FunctionRef>(op->node);
alloc_storage_scope_[v] = op->value.as<StringImmNode>()->value; alloc_storage_scope_[v] = op->value.as<StringImmNode>()->value;
PrintStmt(op->body); PrintStmt(op->body);
......
...@@ -24,10 +24,10 @@ ...@@ -24,10 +24,10 @@
#ifndef TVM_CONTRIB_HYBRID_CODEGEN_HYBRID_H_ #ifndef TVM_CONTRIB_HYBRID_CODEGEN_HYBRID_H_
#define TVM_CONTRIB_HYBRID_CODEGEN_HYBRID_H_ #define TVM_CONTRIB_HYBRID_CODEGEN_HYBRID_H_
#include <tvm/ir.h> #include <tvm/tir/expr.h>
#include <tvm/ir_functor_ext.h> #include <tvm/tir/stmt_functor.h>
#include <tvm/codegen.h> #include <tvm/codegen.h>
#include <tvm/lowered_func.h> #include <tvm/tir/lowered_func.h>
#include <tvm/top/schedule.h> #include <tvm/top/schedule.h>
#include <map> #include <map>
#include <string> #include <string>
...@@ -39,7 +39,7 @@ namespace tvm { ...@@ -39,7 +39,7 @@ namespace tvm {
namespace contrib { namespace contrib {
using namespace top; using namespace top;
using namespace ir; using namespace tir;
/*! /*!
* \brief A base class to generate Hybrid Script. * \brief A base class to generate Hybrid Script.
* *
......
...@@ -70,7 +70,7 @@ TVM_REGISTER_NODE_TYPE(DictAttrsNode); ...@@ -70,7 +70,7 @@ TVM_REGISTER_NODE_TYPE(DictAttrsNode);
TVM_REGISTER_NODE_TYPE(AttrFieldInfoNode); TVM_REGISTER_NODE_TYPE(AttrFieldInfoNode);
using namespace ir; using namespace tir;
// Equal handler. // Equal handler.
bool AttrsEqualHandler::Equal(const ObjectRef& lhs, const ObjectRef& rhs) { bool AttrsEqualHandler::Equal(const ObjectRef& lhs, const ObjectRef& rhs) {
if (lhs.same_as(rhs)) return true; if (lhs.same_as(rhs)) return true;
......
...@@ -22,7 +22,7 @@ ...@@ -22,7 +22,7 @@
*/ */
#include <tvm/ir/env_func.h> #include <tvm/ir/env_func.h>
#include <tvm/runtime/registry.h> #include <tvm/runtime/registry.h>
#include <tvm/expr.h> #include <tvm/tir/expr.h>
namespace tvm { namespace tvm {
......
...@@ -29,14 +29,23 @@ ...@@ -29,14 +29,23 @@
// //
// Rationale: convert from IterVar and top::Tensor // Rationale: convert from IterVar and top::Tensor
#include <tvm/top/tensor.h> #include <tvm/top/tensor.h>
#include <tvm/expr.h> #include <tvm/tir/expr.h>
namespace tvm { namespace tvm {
PrimExpr::PrimExpr(int32_t value)
: PrimExpr(IntImm(DataType::Int(32), value)) {}
PrimExpr::PrimExpr(float value)
: PrimExpr(FloatImm(DataType::Float(32), value)) {}
PrimExpr::PrimExpr(std::string str)
: PrimExpr(tir::StringImmNode::make(str)) {}
PrimExpr PrimExpr::FromObject_(ObjectPtr<Object> ptr) { PrimExpr PrimExpr::FromObject_(ObjectPtr<Object> ptr) {
using runtime::ObjectTypeChecker; using runtime::ObjectTypeChecker;
if (ptr->IsInstance<IterVarNode>()) { if (ptr->IsInstance<tir::IterVarNode>()) {
return IterVar(ptr)->var; return tir::IterVar(ptr)->var;
} }
if (ptr->IsInstance<top::TensorNode>()) { if (ptr->IsInstance<top::TensorNode>()) {
return top::Tensor(ptr)(); return top::Tensor(ptr)();
...@@ -47,6 +56,7 @@ PrimExpr PrimExpr::FromObject_(ObjectPtr<Object> ptr) { ...@@ -47,6 +56,7 @@ PrimExpr PrimExpr::FromObject_(ObjectPtr<Object> ptr) {
return PrimExpr(ptr); return PrimExpr(ptr);
} }
IntImm::IntImm(DataType dtype, int64_t value) { IntImm::IntImm(DataType dtype, int64_t value) {
CHECK(dtype.is_scalar()) CHECK(dtype.is_scalar())
<< "ValueError: IntImm can only take scalar."; << "ValueError: IntImm can only take scalar.";
...@@ -66,6 +76,17 @@ TVM_REGISTER_GLOBAL("make.IntImm") ...@@ -66,6 +76,17 @@ TVM_REGISTER_GLOBAL("make.IntImm")
return IntImm(dtype, value); return IntImm(dtype, value);
}); });
TVM_REGISTER_NODE_TYPE(IntImmNode);
TVM_STATIC_IR_FUNCTOR(NodePrinter, vtable)
.set_dispatch<IntImmNode>([](const ObjectRef& node, NodePrinter* p) {
auto* op = static_cast<const IntImmNode*>(node.get());
if (op->dtype == DataType::Int(32)) {
p->stream << op->value;
} else {
p->stream << "(" << op->dtype << ")" << op->value;
}
});
FloatImm::FloatImm(DataType dtype, double value) { FloatImm::FloatImm(DataType dtype, double value) {
CHECK_EQ(dtype.lanes(), 1) CHECK_EQ(dtype.lanes(), 1)
...@@ -81,6 +102,49 @@ TVM_REGISTER_GLOBAL("make.FloatImm") ...@@ -81,6 +102,49 @@ TVM_REGISTER_GLOBAL("make.FloatImm")
return FloatImm(dtype, value); return FloatImm(dtype, value);
}); });
TVM_REGISTER_NODE_TYPE(FloatImmNode);
TVM_STATIC_IR_FUNCTOR(NodePrinter, vtable)
.set_dispatch<FloatImmNode>([](const ObjectRef& node, NodePrinter* p) {
auto* op = static_cast<const FloatImmNode*>(node.get());
auto& stream = p->stream;
switch (op->dtype.bits()) {
case 64:
stream << op->value;
break;
case 32:
stream << op->value << 'f';
break;
case 16:
stream << op->value << 'h';
break;
default:
LOG(FATAL) << "Unknown float type bits=" << op->dtype.bits();
}
});
Range::Range(PrimExpr begin, PrimExpr end)
: Range(make_object<RangeNode>(
begin,
tir::is_zero(begin) ? end : (end - begin))) {
}
Range Range::make_by_min_extent(PrimExpr min, PrimExpr extent) {
return Range(make_object<RangeNode>(min, extent));
}
TVM_STATIC_IR_FUNCTOR(NodePrinter, vtable)
.set_dispatch<RangeNode>([](const ObjectRef& node, NodePrinter* p) {
auto* op = static_cast<const RangeNode*>(node.get());
p->stream << "range(min=" << op->min << ", ext=" << op->extent << ')';
});
TVM_REGISTER_NODE_TYPE(ArrayNode);
TVM_REGISTER_NODE_TYPE(MapNode);
TVM_REGISTER_NODE_TYPE(StrMapNode);
TVM_REGISTER_NODE_TYPE(RangeNode);
GlobalVar::GlobalVar(std::string name_hint) { GlobalVar::GlobalVar(std::string name_hint) {
ObjectPtr<GlobalVarNode> n = make_object<GlobalVarNode>(); ObjectPtr<GlobalVarNode> n = make_object<GlobalVarNode>();
...@@ -101,4 +165,46 @@ TVM_STATIC_IR_FUNCTOR(NodePrinter, vtable) ...@@ -101,4 +165,46 @@ TVM_STATIC_IR_FUNCTOR(NodePrinter, vtable)
p->stream << "GlobalVar(" << node->name_hint << ")"; p->stream << "GlobalVar(" << node->name_hint << ")";
}); });
// Container printer
TVM_STATIC_IR_FUNCTOR(NodePrinter, vtable)
.set_dispatch<ArrayNode>([](const ObjectRef& node, NodePrinter* p) {
auto* op = static_cast<const ArrayNode*>(node.get());
p->stream << '[';
for (size_t i = 0 ; i < op->data.size(); ++i) {
if (i != 0) {
p->stream << ", ";
}
p->Print(op->data[i]);
}
p->stream << ']';
});
TVM_STATIC_IR_FUNCTOR(NodePrinter, vtable)
.set_dispatch<MapNode>([](const ObjectRef& node, NodePrinter* p) {
auto* op = static_cast<const MapNode*>(node.get());
p->stream << '{';
for (auto it = op->data.begin(); it != op->data.end(); ++it) {
if (it != op->data.begin()) {
p->stream << ", ";
}
p->Print(it->first);
p->stream << ": ";
p->Print(it->second);
}
p->stream << '}';
});
TVM_STATIC_IR_FUNCTOR(NodePrinter, vtable)
.set_dispatch<StrMapNode>([](const ObjectRef& node, NodePrinter* p) {
auto* op = static_cast<const StrMapNode*>(node.get());
p->stream << '{';
for (auto it = op->data.begin(); it != op->data.end(); ++it) {
if (it != op->data.begin()) {
p->stream << ", ";
}
p->stream << '\"' << it->first << "\": ";
p->Print(it->second);
}
p->stream << '}';
});
} // namespace tvm } // namespace tvm
...@@ -28,7 +28,7 @@ ...@@ -28,7 +28,7 @@
#include <tvm/ir/transform.h> #include <tvm/ir/transform.h>
// TODO(tqchen): Update to use String container after it is merged. // TODO(tqchen): Update to use String container after it is merged.
#include <tvm/ir.h> #include <tvm/tir/expr.h>
#include <stack> #include <stack>
#include <unordered_set> #include <unordered_set>
...@@ -268,7 +268,7 @@ void SequentialNode::ResolveDependency(const IRModule& mod) { ...@@ -268,7 +268,7 @@ void SequentialNode::ResolveDependency(const IRModule& mod) {
inline bool PassArrayContains(const Array<tvm::PrimExpr>& pass_array, inline bool PassArrayContains(const Array<tvm::PrimExpr>& pass_array,
const std::string& pass_name) { const std::string& pass_name) {
for (auto x : pass_array) { for (auto x : pass_array) {
auto* str_name = x.as<ir::StringImmNode>(); auto* str_name = x.as<tir::StringImmNode>();
CHECK(str_name) << "pass name must be str"; CHECK(str_name) << "pass name must be str";
if (str_name->value == pass_name) return true; if (str_name->value == pass_name) return true;
} }
...@@ -310,7 +310,7 @@ IRModule SequentialNode::operator()(const IRModule& module, ...@@ -310,7 +310,7 @@ IRModule SequentialNode::operator()(const IRModule& module,
if (!PassEnabled(pass_info)) continue; if (!PassEnabled(pass_info)) continue;
// resolve dependencies // resolve dependencies
for (const auto& it : pass_info->required) { for (const auto& it : pass_info->required) {
const auto* name = it.as<tvm::ir::StringImmNode>(); const auto* name = it.as<tvm::tir::StringImmNode>();
CHECK(name); CHECK(name);
mod = GetPass(name->value)(mod, pass_ctx); mod = GetPass(name->value)(mod, pass_ctx);
} }
...@@ -349,7 +349,7 @@ TVM_STATIC_IR_FUNCTOR(NodePrinter, vtable) ...@@ -349,7 +349,7 @@ TVM_STATIC_IR_FUNCTOR(NodePrinter, vtable)
p->stream << "opt_level: " << node->opt_level; p->stream << "opt_level: " << node->opt_level;
p->stream << "required passes: [" << "\n"; p->stream << "required passes: [" << "\n";
for (const auto& it : node->required) { for (const auto& it : node->required) {
const auto* str = it.as<tvm::ir::StringImmNode>(); const auto* str = it.as<tvm::tir::StringImmNode>();
p->stream << str->value << ", "; p->stream << str->value << ", ";
} }
p->stream << "]\n"; p->stream << "]\n";
......
/*
* Licensed to the Apache Software Foundation (ASF) under one
* or more contributor license agreements. See the NOTICE file
* distributed with this work for additional information
* regarding copyright ownership. The ASF licenses this file
* to you under the Apache License, Version 2.0 (the
* "License"); you may not use this file except in compliance
* with the License. You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing,
* software distributed under the License is distributed on an
* "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
* KIND, either express or implied. See the License for the
* specific language governing permissions and limitations
* under the License.
*/
/*!
* \file expr.cc
*/
#include <tvm/expr.h>
#include <tvm/ir.h>
#include <tvm/expr_operator.h>
#include <memory>
#include <limits>
namespace tvm {
PrimExpr::PrimExpr(int32_t value)
: PrimExpr(IntImm(DataType::Int(32), value)) {}
PrimExpr::PrimExpr(float value)
: PrimExpr(FloatImm(DataType::Float(32), value)) {}
PrimExpr::PrimExpr(std::string str)
: PrimExpr(ir::StringImmNode::make(str)) {}
Var::Var(std::string name_hint, DataType t)
: Var(make_object<VarNode>(t, name_hint)) {}
VarNode::VarNode(DataType t, std::string name_hint) {
this->dtype = t;
this->name_hint = std::move(name_hint);
}
SizeVar::SizeVar(std::string name_hint, DataType t)
: SizeVar(make_object<SizeVarNode>(t, name_hint)) {}
SizeVarNode::SizeVarNode(DataType t, std::string name_hint)
: VarNode(t, std::move(name_hint)) {}
Range::Range(PrimExpr begin, PrimExpr end)
: Range(make_object<RangeNode>(
begin,
is_zero(begin) ? end : (end - begin))) {
}
Range Range::make_by_min_extent(PrimExpr min, PrimExpr extent) {
return Range(make_object<RangeNode>(min, extent));
}
IterVar IterVarNode::make(Range dom,
Var var,
IterVarType t,
std::string thread_tag) {
ObjectPtr<IterVarNode> n = make_object<IterVarNode>();
n->dom = dom;
n->var = var;
n->iter_type = t;
n->thread_tag = thread_tag;
return IterVar(n);
}
IterVar thread_axis(Range dom, std::string tag) {
return IterVarNode::make(
dom, Var(tag), kThreadIndex, tag);
}
IterVar reduce_axis(Range dom, std::string name) {
return IterVarNode::make(
dom, Var(name), kCommReduce);
}
void Dump(const ObjectRef& n) {
std::cerr << n << "\n";
}
Var var(std::string name_hint, DataType t) {
return Var(name_hint, t);
}
TVM_STATIC_IR_FUNCTOR(NodePrinter, vtable)
.set_dispatch<IntImmNode>([](const ObjectRef& node, NodePrinter* p) {
auto* op = static_cast<const IntImmNode*>(node.get());
if (op->dtype == DataType::Int(32)) {
p->stream << op->value;
} else {
p->stream << "(" << op->dtype << ")" << op->value;
}
});
TVM_STATIC_IR_FUNCTOR(NodePrinter, vtable)
.set_dispatch<IterVarNode>([](const ObjectRef& node, NodePrinter* p) {
auto* op = static_cast<const IterVarNode*>(node.get());
p->stream << "iter_var(";
if (op->var->name_hint.length() != 0) {
p->stream << op->var->name_hint << ", ";
}
if (op->dom.defined()) {
p->stream << op->dom;
}
if (op->thread_tag.length() != 0) {
p->stream << ", " << op->thread_tag;
}
p->stream << ")";
});
TVM_STATIC_IR_FUNCTOR(NodePrinter, vtable)
.set_dispatch<RangeNode>([](const ObjectRef& node, NodePrinter* p) {
auto* op = static_cast<const RangeNode*>(node.get());
p->stream << "range(min=" << op->min << ", ext=" << op->extent << ')';
});
TVM_REGISTER_NODE_TYPE(ArrayNode);
TVM_REGISTER_NODE_TYPE(MapNode);
TVM_REGISTER_NODE_TYPE(StrMapNode);
TVM_REGISTER_NODE_TYPE(RangeNode);
TVM_REGISTER_NODE_TYPE(IterVarNode);
} // namespace tvm
...@@ -49,4 +49,8 @@ NodePrinter::FType& NodePrinter::vtable() { ...@@ -49,4 +49,8 @@ NodePrinter::FType& NodePrinter::vtable() {
static FType inst; static FType inst;
return inst; return inst;
} }
void Dump(const ObjectRef& n) {
std::cerr << n << "\n";
}
} // namespace tvm } // namespace tvm
Markdown is supported
0% or
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment