Unverified Commit 0218557c by Tianqi Chen Committed by GitHub

[INFA][IR] Build and Evolve Low-level IR. Remove HalideIR dep. (#3533)

* [INFA][IR] Build and Evolve Low-level IR. Remove dep from HalideIR.


* Update include/tvm/node/ir_functor.h

Co-Authored-By: Jared Roesch <roeschinc@gmail.com>

* Update include/tvm/node/ir_functor.h

Co-Authored-By: Jared Roesch <roeschinc@gmail.com>
parent 2d53f84d
......@@ -76,7 +76,6 @@ if(MSVC)
add_definitions(-D_CRT_SECURE_NO_WARNINGS)
add_definitions(-D_SCL_SECURE_NO_WARNINGS)
add_definitions(-D_ENABLE_EXTENDED_ALIGNED_STORAGE)
add_definitions(-DHalide_SHARED)
set(CMAKE_CXX_FLAGS "${CMAKE_CXX_FLAGS} /EHsc")
set(CMAKE_CXX_FLAGS "${CMAKE_CXX_FLAGS} /MP")
set(CMAKE_C_FLAGS "${CMAKE_C_FLAGS} /bigobj")
......@@ -112,8 +111,8 @@ else(MSVC)
endif(MSVC)
# add source group
FILE(GLOB_RECURSE GROUP_SOURCE "src/*.cc" "3rdparty/HalideIR/src/*.cpp" "nnvm/src/*.cc")
FILE(GLOB_RECURSE GROUP_INCLUDE "src/*.h" "include/*.h" "3rdparty/HalideIR/src/*.h"
FILE(GLOB_RECURSE GROUP_SOURCE "src/*.cc" "nnvm/src/*.cc")
FILE(GLOB_RECURSE GROUP_INCLUDE "src/*.h" "include/*.h"
"nnvm/src/*.h" "nnvm/include/*.h")
assign_source_group("Source" ${GROUP_SOURCE})
assign_source_group("Include" ${GROUP_INCLUDE})
......@@ -127,6 +126,7 @@ file(GLOB COMPILER_SRCS
src/lang/*.cc
src/pass/*.cc
src/op/*.cc
src/node/*.cc
src/schedule/*.cc
)
......@@ -154,12 +154,7 @@ file(GLOB_RECURSE NNVM_COMPILER_SRCS
file(GLOB TOPI_SRCS
topi/src/*.cc
)
file(GLOB_RECURSE HALIDEIR_SRCS
3rdparty/HalideIR/src/base/*.cpp
3rdparty/HalideIR/src/ir/*.cpp
3rdparty/HalideIR/src/tvm/*.cpp
)
list(APPEND COMPILER_SRCS ${HALIDEIR_SRCS})
file(GLOB RUNTIME_SRCS
src/runtime/*.cc
src/runtime/vm/*.cc
......@@ -245,7 +240,6 @@ target_link_libraries(nnvm_compiler tvm)
# Related headers
target_include_directories(
tvm
PUBLIC "3rdparty/HalideIR/src"
PUBLIC "topi/include")
target_include_directories(
tvm_topi
......@@ -295,11 +289,6 @@ if (INSTALL_DEV)
PATTERN "*.h"
)
install(
DIRECTORY "3rdparty/HalideIR/src/." DESTINATION "include/HalideIR"
FILES_MATCHING
PATTERN "*.h"
)
install(
DIRECTORY "3rdparty/dlpack/include/." DESTINATION "include"
FILES_MATCHING
PATTERN "*.h"
......@@ -319,8 +308,6 @@ endif(INSTALL_DEV)
# More target definitions
if(MSVC)
target_compile_definitions(tvm PRIVATE -DHalide_EXPORTS)
target_compile_definitions(tvm_runtime PRIVATE -DHalide_EXPORTS)
target_compile_definitions(tvm PRIVATE -DTVM_EXPORTS)
target_compile_definitions(tvm_runtime PRIVATE -DTVM_EXPORTS)
target_compile_definitions(nnvm_compiler PRIVATE -DNNVM_EXPORTS)
......
......@@ -591,7 +591,7 @@ IntSet EvalSet(Range r,
const std::unordered_map<const Variable*, IntSet>& dom_map);
/*! \brief Map from Expr to IntSet */
using ExprIntSetMap = std::unordered_map<Expr, IntSet, ExprHash, ExprEqual>;
using ExprIntSetMap = std::unordered_map<Expr, IntSet, NodeHash, NodeEqual>;
/*!
* \brief Find the integer set of every sub-expression, given the
* domain of each iteration variables.
......
......@@ -89,8 +89,8 @@ inline TNodeRef NullValue() {
}
template<>
inline Type NullValue<Type>() {
return Type(Type::Handle, 0, 0);
inline DataType NullValue<DataType>() {
return DataType(kHandle, 0, 0);
}
/*! \brief Error thrown during attribute checking. */
......
......@@ -221,7 +221,7 @@ class Layout : public NodeRef {
if (!this->defined()) return -1;
const auto axes = operator->()->axes;
for (size_t i = 0; i < axes.size(); ++i) {
if (axes[i]->var.get()->name_hint == axis.name()) return static_cast<int32_t>(i);
if (axes[i]->var->name_hint == axis.name()) return static_cast<int32_t>(i);
}
return -1;
}
......@@ -243,7 +243,7 @@ class Layout : public NodeRef {
bool Contains(const LayoutAxis& axis) const {
if (!defined()) return false;
for (const IterVar var : operator->()->axes) {
if (var->var.get()->name_hint == axis.name()) {
if (var->var->name_hint == axis.name()) {
return true;
}
}
......
/*
* Licensed to the Apache Software Foundation (ASF) under one
* or more contributor license agreements. See the NOTICE file
* distributed with this work for additional information
* regarding copyright ownership. The ASF licenses this file
* to you under the Apache License, Version 2.0 (the
* "License"); you may not use this file except in compliance
* with the License. You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing,
* software distributed under the License is distributed on an
* "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
* KIND, either express or implied. See the License for the
* specific language governing permissions and limitations
* under the License.
*/
/*
* \file tvm/dtype.h
* \brief Data type used in IR.
*/
#ifndef TVM_DTYPE_H_
#define TVM_DTYPE_H_
#include "runtime/packed_func.h"
namespace tvm {
class Expr;
/*!
* \brief Primitive data types in tvm.
*/
class DataType {
public:
/*! \brief default constructor */
DataType() {}
/*!
* \brief Constructor
* \param dtype The DLDataType
*/
explicit DataType(DLDataType dtype)
: data_(dtype) {}
/*!
* \brief Constructor
* \param code The type code.
* \param bits The number of bits in the type.
* \param lanes The number of lanes.
*/
DataType(int code, int bits, int lanes) {
data_.code = static_cast<uint8_t>(code);
data_.bits = static_cast<uint8_t>(bits);
data_.lanes = static_cast<uint16_t>(lanes);
}
/*! \return The type code. */
int code() const {
return static_cast<int>(data_.code);
}
/*! \return number of bits in the data. */
int bits() const {
return static_cast<int>(data_.bits);
}
/*! \return number of bytes to store each scalar. */
int bytes() const {
return (bits() + 7) / 8;
}
/*! \return number of lanes in the data. */
int lanes() const {
return static_cast<int>(data_.lanes);
}
/*! \return whether type is a scalar type. */
bool is_scalar() const {
return lanes() == 1;
}
/*! \return whether type is a scalar type. */
bool is_bool() const {
return code() == kDLUInt && bits() == 1;
}
/*! \return whether type is a float type. */
bool is_float() const {
return code() == kDLFloat;
}
/*! \return whether type is an int type. */
bool is_int() const {
return code() == kDLInt;
}
/*! \return whether type is an uint type. */
bool is_uint() const {
return code() == kDLUInt;
}
/*! \return whether type is a handle type. */
bool is_handle() const {
return code() == kHandle;
}
/*! \return whether type is a vector type. */
bool is_vector() const {
return lanes() > 1;
}
/*!
* \brief Create a new data type by change lanes to a specified value.
* \param lanes The target number of lanes.
* \return the result type.
*/
DataType with_lanes(int lanes) const {
return DataType(data_.code, data_.bits, lanes);
}
/*!
* \brief Create a new data type by change bits to a specified value.
* \param bits The target number of bits.
* \return the result type.
*/
DataType with_bits(int bits) const {
return DataType(data_.code, bits, data_.lanes);
}
/*!
* \brief Get the scalar version of the type.
* \return the result type.
*/
DataType element_of() const {
return with_lanes(1);
}
// operator overloadings
bool operator==(const DataType& other) const {
return
data_.code == other.data_.code &&
data_.bits == other.data_.bits &&
data_.lanes == other.data_.lanes;
}
bool operator!=(const DataType& other) const {
return !operator==(other);
}
operator DLDataType () const {
return data_;
}
/*! \return the maximum possible value in this format. */
TVM_DLL Expr max() const;
/*! \return the minimum possible value in this format. */
TVM_DLL Expr min() const;
private:
DLDataType data_;
};
/*!
* \brief Construct an int type.
* \param bits The number of bits in the type.
* \param lanes The number of lanes.
* \return The constructed data type.
*/
inline DataType Int(int bits, int lanes = 1) {
return DataType(kDLInt, bits, lanes);
}
/*!
* \brief Construct an uint type.
* \param bits The number of bits in the type.
* \param lanes The number of lanes
* \return The constructed data type.
*/
inline DataType UInt(int bits, int lanes = 1) {
return DataType(kDLUInt, bits, lanes);
}
/*!
* \brief Construct a bool type.
* \param lanes The number of lanes
* \return The constructed data type.
*/
inline DataType Bool(int lanes = 1) {
return UInt(1, lanes);
}
/*!
* \brief Construct an uint type.
* \param bits The number of bits in the type.
* \param lanes The number of lanes
* \return The constructed data type.
*/
inline DataType Float(int bits, int lanes = 1) {
return DataType(kDLFloat, bits, lanes);
}
/*!
* \brief Construct a handle type.
* \param bits The number of bits in the type.
* \param lanes The number of lanes
* \return The constructed data type.
*/
inline DataType Handle(int bits = 64, int lanes = 1) {
return DataType(kHandle, bits, lanes);
}
/*!
* \brief Get the corresponding type of TVMShapeIndex.
* \return The type of TVM shape index.
*/
inline DataType TVMShapeIndexType() {
if (std::is_signed<tvm_index_t>::value) {
return Int(sizeof(tvm_index_t) * 8);
} else {
return UInt(sizeof(tvm_index_t) * 8);
}
}
/*!
* \brief Convert DLDataType to DataType.
* \param t The original type.
* \return The conversion result.
*/
inline DataType TVMType2Type(DLDataType t) {
return DataType(t.code, t.bits, t.lanes);
}
/*!
* \brief Convert DataType to DataType.
* \param t The original type.
* \return The conversion result.
*/
inline DLDataType Type2TVMType(DataType t) {
return t.operator DLDataType();
}
/*!
* \brief Get the number of bytes needed in a vector.
* \param dtype The data type.
* \return Number of bytes needed.
*/
inline int GetVectorBytes(DataType dtype) {
int data_bits = dtype.bits() * dtype.lanes();
// allow bool to exist
if (dtype == Bool()) return 1;
CHECK_EQ(data_bits % 8, 0U)
<< "Need to load/store by multiple of bytes";
return data_bits / 8;
}
// Overload print function.
inline std::ostream& operator<<(std::ostream& os, DataType dtype) { // NOLINT(*)
using namespace tvm::runtime;
return os << dtype.operator DLDataType();
}
// Backward compatibility
using Type = DataType;
} // namespace tvm
#endif // TVM_DTYPE_H_
/*
* Licensed to the Apache Software Foundation (ASF) under one
* or more contributor license agreements. See the NOTICE file
......@@ -25,72 +24,107 @@
#ifndef TVM_EXPR_H_
#define TVM_EXPR_H_
#include <ir/Expr.h>
#include <ir/IRPrinter.h>
#include <string>
#include <algorithm>
#include <unordered_map>
#include "base.h"
#include "dtype.h"
#include "node/container.h"
#include "node/ir_functor.h"
#include "runtime/c_runtime_api.h"
namespace tvm {
using HalideIR::Type;
using HalideIR::Float;
using HalideIR::Bool;
using HalideIR::Int;
using HalideIR::UInt;
using HalideIR::Handle;
using HalideIR::ExprHash;
using HalideIR::ExprEqual;
using HalideIR::Expr;
using HalideIR::VarExpr;
using HalideIR::IR::RangeNode;
using HalideIR::IR::FunctionRef;
using HalideIR::IR::FunctionBaseNode;
using HalideIR::Internal::IntImm;
using HalideIR::Internal::Stmt;
using HalideIR::Internal::IRPrinter;
using HalideIR::Internal::Variable;
inline Type TVMShapeIndexType() {
if (std::is_signed<tvm_index_t>::value) {
return Int(sizeof(tvm_index_t) * 8);
} else {
return UInt(sizeof(tvm_index_t) * 8);
/*! \brief Base node of all expressions. */
class ExprNode : public Node {
public:
/*! \brief The data type of the expression. */
DataType type;
static constexpr const char* _type_key = "Expr";
TVM_DECLARE_BASE_NODE_INFO(ExprNode, Node);
};
/*! \brief Container of all expressions. */
class Expr : public NodeRef {
public:
Expr() {}
explicit Expr(NodePtr<Node> ptr) : NodeRef(ptr) {}
/*!
* \brief construct from integer.
* \param value The value to be constructed.
*/
TVM_DLL Expr(int32_t value); // NOLINT(*)
/*!
* \brief construct from float.
* \param value The value to be constructed.
*/
TVM_DLL Expr(float value); // NOLINT(*)
/*!
* \brief construct from string.
* \param str The value to be constructed.
*/
TVM_DLL Expr(std::string str); // NOLINT(*)
/*! \return the data type of this expression. */
DataType type() const {
return static_cast<const ExprNode*>(get())->type;
}
}
inline Type TVMType2Type(TVMType t) {
return Type(static_cast<halideir_type_code_t>(t.code), t.bits, t.lanes);
}
/*! \brief type indicate the container type */
using ContainerType = ExprNode;
};
inline TVMType Type2TVMType(Type t) {
TVMType ret;
ret.code = static_cast<uint8_t>(t.code());
ret.bits = static_cast<uint8_t>(t.bits());
ret.lanes = static_cast<uint16_t>(t.lanes());
return ret;
}
/*! \brief Base node of all statements. */
class StmtNode : public Node {
public:
static constexpr const char* _type_key = "Stmt";
TVM_DECLARE_BASE_NODE_INFO(StmtNode, Node);
};
// Get number of bytes considering vector type.
inline int GetVectorBytes(Type dtype) {
int data_bits = dtype.bits() * dtype.lanes();
// allow bool to exist
if (dtype == Bool()) return 1;
CHECK_EQ(data_bits % 8, 0U)
<< "Need to load/store by multiple of bytes";
return data_bits / 8;
}
/*! \brief Container of all statements */
class Stmt : public NodeRef {
public:
TVM_DEFINE_NODE_REF_METHODS(Stmt, NodeRef, StmtNode);
};
class Var;
/*!
* \brief A variable node in the IR.
*
* A vraible is uniquely identified by its address.
*
* Each variable is only binded once in the following nodes:
* - Allocate
* - For
* - Let
* - LetStmt
*/
class Variable : public ExprNode {
public:
/*!
* \brief The hint to the variable name.
* \note Each variable is uniquely identified by its address.
*/
std::string name_hint;
static Var make(DataType dtype, std::string name_hint);
void VisitAttrs(AttrVisitor* v) final {
v->Visit("dtype", &type);
v->Visit("name", &name_hint);
}
static constexpr const char* _type_key = "Variable";
TVM_DECLARE_NODE_TYPE_INFO(Variable, ExprNode);
};
/*! \brief a named variable in TVM */
class Var : public HalideIR::VarExpr {
class Var : public Expr {
public:
EXPORT explicit Var(const std::string& name_hint = "v",
Type t = Int(32)) : VarExpr(name_hint, t) {}
explicit Var(NodePtr<Node> n) : VarExpr(n) {}
explicit Var(VarExpr v) : VarExpr(v) {}
explicit Var(NodePtr<Node> n) : Expr(n) {}
TVM_DLL explicit Var(std::string name_hint = "v",
Type t = Int(32));
/*!
* \brief Make a new copy of var with same type, append suffix
* \param suffix The suffix to be appended.
......@@ -99,10 +133,47 @@ class Var : public HalideIR::VarExpr {
Var copy_with_suffix(const std::string& suffix) const {
return Var((*this)->name_hint + suffix, (*this)->type);
}
/*!
* \brief Get pointer to the internal value.
* \return the corresponding Variable.
*/
const Variable* operator->() const {
return get();
}
/*!
* \brief Get pointer to the internal value.
* \return the corresponding Variable.
*/
const Variable* get() const {
return static_cast<Variable*>(node_.get());
}
/*! \brief type indicate the container type */
using ContainerType = Variable;
};
// Backward compatibility, will be removed later.
using VarExpr = Var;
using BaseExprNode = ExprNode;
using ExprHash = NodeHash;
using ExprEqual = NodeEqual;
class Integer;
/*! \brief ExprNode: constant integer. */
class IntImm : public ExprNode {
public:
/*! \brief the Internal value. */
int64_t value;
void VisitAttrs(AttrVisitor* v) final {
v->Visit("dtype", &type);
v->Visit("value", &value);
}
TVM_DLL static Integer make(DataType t, int64_t value);
static constexpr const char* _type_key = "IntImm";
TVM_DECLARE_NODE_TYPE_INFO(IntImm, ExprNode);
};
/*!
* \brief Container of constant integer (IntImm).
......@@ -148,34 +219,52 @@ class Integer : public Expr {
using ContainerType = IntImm;
};
/*! \brief range over one dimension */
class RangeNode : public Node {
public:
/*! \brief beginning of the node */
Expr min;
/*! \brief the extend of range */
Expr extent;
/*! \brief constructor */
RangeNode() {}
RangeNode(Expr min, Expr extent) : min(min), extent(extent) {}
/*! \brief container class of iteration variable. */
class IterVarNode;
void VisitAttrs(AttrVisitor* v) final {
v->Visit("min", &min);
v->Visit("extent", &extent);
}
/*!
* \brief same as HalideIR::IR::Range
* except it provide an constructor with (begin, end)
*
* \note Traditional Halide's Range have a constructor with
* (begin, extent), which does not match the convention in e.g. python.
* We decided to correct it by removing the constructor in HalideIR,
* and add it back in TVM's range.
*/
class Range : public HalideIR::IR::Range {
static constexpr const char* _type_key = "Range";
TVM_DECLARE_NODE_TYPE_INFO(RangeNode, Node);
};
/*! \brief Range constainer */
class Range : public NodeRef {
public:
/*! \brief constructor */
Range() {}
explicit Range(NodePtr<Node> n) : HalideIR::IR::Range(n) {}
/*!
* \brief constructor by begin and end
* \param begin The begin of the range.
* \param end The end of the range.
*/
TVM_DLL Range(Expr begin, Expr end);
TVM_DLL static Range make_by_min_extent(Expr min, Expr extent);
/*!
* \brief construct a new range with min and extent
* The corresponding constructor is removed,
* because that is counter convention of tradition meaning
* of range(begin, end)
*
* \param min The minimum range.
* \param extent The extent of the range.
*/
static Range make_by_min_extent(Expr min, Expr extent);
// declare range.
TVM_DEFINE_NODE_REF_METHODS(Range, NodeRef, RangeNode);
};
/*! \brief container class of iteration variable. */
class IterVarNode;
using Region = Array<Range>;
/*!
......@@ -289,9 +378,6 @@ TVM_DLL IterVar reduce_axis(Range dom, std::string name = "rv");
using Domain = Array<Range>;
// print functions for expr
TVM_DLL std::ostream& operator<<(std::ostream& os, const NodeRef& n); // NOLINT(*)
/*!
* \brief Dump the node to stderr, used for debug purposes.
* \param node The input node
......@@ -364,7 +450,7 @@ inline const char* IterVarType2String(IterVarType t) {
* \param name_hint The name hint for the expression
* \param t The type of the expression
*/
TVM_DLL Var var(const std::string& name_hint, Type t = Int(32));
TVM_DLL Var var(std::string name_hint, Type t = Int(32));
/*
* \brief Template function to convert Map to unordered_map
......@@ -382,6 +468,32 @@ inline std::unordered_map<K, V> as_unordered_map(const Map<K, V>& dmap) {
}
return ret;
}
// Printer infra.
/*! \brief A Pretty printer class to print the IR. */
class IRPrinter {
public:
/*! \brief The output stream */
std::ostream& stream;
/*! \brief The indentation level. */
int indent{0};
explicit IRPrinter(std::ostream& stream) // NOLINT(*)
: stream(stream) {}
/*! \brief The node to be printed. */
TVM_DLL void Print(const NodeRef& node);
/*! \brief Print indent to the stream */
TVM_DLL void PrintIndent();
// Allow registration to be printer.
using FType = IRFunctor<void(const NodeRef&, IRPrinter *)>;
TVM_DLL static FType& vtable();
};
// default print function for all nodes
inline std::ostream& operator<<(std::ostream& os, const NodeRef& n) { // NOLINT(*)
IRPrinter(os).Print(n);
return os;
}
} // namespace tvm
namespace std {
......
......@@ -16,18 +16,19 @@
* specific language governing permissions and limitations
* under the License.
*/
/*!
* \file tvm/ir.h
* \brief Additional high level nodes in the IR
*/
// Acknowledgement: Most low-level IR nodes originate from Halide.
#ifndef TVM_IR_H_
#define TVM_IR_H_
#include <ir/Expr.h>
#include <ir/IR.h>
#include <type_traits>
#include <string>
#include <vector>
#include <utility>
#include "base.h"
#include "expr.h"
#include "runtime/util.h"
......@@ -35,17 +36,561 @@
namespace tvm {
namespace ir {
using HalideIR::Internal::BaseExprNode;
using HalideIR::Internal::ExprNode;
using HalideIR::Internal::StmtNode;
using HalideIR::Internal::IRNodeType;
using HalideIR::Internal::ForType;
using HalideIR::DeviceAPI;
using IntImm = tvm::IntImm;
using Variable = tvm::Variable;
/*! \brief constant unsigned integer. */
class UIntImm : public ExprNode {
public:
/*! \brief The constant value content. */
uint64_t value;
void VisitAttrs(AttrVisitor* v) final {
v->Visit("dtype", &type);
v->Visit("value", &value);
}
TVM_DLL static Expr make(Type t, uint64_t value);
static constexpr const char* _type_key = "UIntImm";
TVM_DECLARE_NODE_TYPE_INFO(UIntImm, ExprNode);
};
/*! \brief Floating point constants. */
class FloatImm : public ExprNode {
public:
/*! \brief The constant value content. */
double value;
void VisitAttrs(AttrVisitor* v) final {
v->Visit("dtype", &type);
v->Visit("value", &value);
}
TVM_DLL static Expr make(Type t, double value);
static constexpr const char* _type_key = "FloatImm";
TVM_DECLARE_NODE_TYPE_INFO(FloatImm, ExprNode);
};
/*! \brief String constants, only used in asserts. */
class StringImm : public ExprNode {
public:
/*! \brief The constant value content. */
std::string value;
void VisitAttrs(AttrVisitor* v) final {
v->Visit("dtype", &type);
v->Visit("value", &value);
}
TVM_DLL Expr static make(std::string value);
static constexpr const char* _type_key = "StringImm";
TVM_DECLARE_NODE_TYPE_INFO(StringImm, ExprNode);
};
/*!
* \brief Cast value from one data type to another.
* \note The lanes of value should keep fixed.
*/
class Cast : public ExprNode {
public:
/*! \brief Original data type. */
Expr value;
void VisitAttrs(AttrVisitor* v) final {
v->Visit("dtype", &type);
v->Visit("value", &value);
}
TVM_DLL static Expr make(Type t, Expr v);
static constexpr const char* _type_key = "Cast";
TVM_DECLARE_NODE_TYPE_INFO(Cast, ExprNode);
};
/*!
* \brief Base template to implement binary ops.
* \tparam T The type of the child class.
*/
template<typename T>
class BinaryOpNode : public ExprNode {
public:
/*! \brief The left operand. */
Expr a;
/*! \brief The right operand. */
Expr b;
void VisitAttrs(AttrVisitor* v) final {
v->Visit("dtype", &(this->type));
v->Visit("a", &a);
v->Visit("b", &b);
}
static Expr make(Expr a, Expr b) {
CHECK(a.defined()) << "ValueError: a is undefined\n";
CHECK(b.defined()) << "ValueError: b is undefined\n";
CHECK(a.type() == b.type()) << "TypeError: mismatched types\n";
NodePtr<T> node = make_node<T>();
node->type = a.type();
node->a = std::move(a);
node->b = std::move(b);
return Expr(node);
}
TVM_DECLARE_NODE_TYPE_INFO(T, ExprNode);
};
/*! \brief a + b */
class Add : public BinaryOpNode<Add> {
public:
static constexpr const char* _type_key = "Add";
};
/*! \brief a - b */
class Sub : public BinaryOpNode<Sub> {
public:
static constexpr const char* _type_key = "Sub";
};
/*! \brief a * b */
class Mul : public BinaryOpNode<Mul> {
public:
static constexpr const char* _type_key = "Mul";
};
/*!
* \brief a / b in the C semnatics.
* \note For integer division, C standard uses trunc div.
*/
class Div : public BinaryOpNode<Div> {
public:
static constexpr const char* _type_key = "Div";
};
/*!
* \brief a % b in the C semnatics.
* \note For integer division, C standard uses trunc div.
*/
class Mod : public BinaryOpNode<Mod> {
public:
static constexpr const char* _type_key = "Mod";
};
/*! \brief min(a, b) */
class Min : public BinaryOpNode<Min> {
public:
static constexpr const char* _type_key = "Min";
};
/*! \brief max(a, b) */
class Max : public BinaryOpNode<Max> {
public:
static constexpr const char* _type_key = "Max";
};
/*!
* \brief Base template to implement comparison ops.
* \tparam T The type of the child class.
*/
template<typename T>
class CmpOpNode : public ExprNode {
public:
/*! \brief The left operand. */
Expr a;
/*! \brief The right operand. */
Expr b;
void VisitAttrs(AttrVisitor* v) final {
v->Visit("dtype", &(this->type));
v->Visit("a", &a);
v->Visit("b", &b);
}
static Expr make(Expr a, Expr b) {
CHECK(a.defined()) << "ValueError: a is undefined\n";
CHECK(b.defined()) << "ValueError: b is undefined\n";
CHECK(a.type() == b.type()) << "TypeError: mismatched types\n";
NodePtr<T> node = make_node<T>();
node->type = Bool(a.type().lanes());
node->a = std::move(a);
node->b = std::move(b);
return Expr(node);
}
TVM_DECLARE_NODE_TYPE_INFO(T, ExprNode);
};
/*! \brief a == b */
class EQ : public CmpOpNode<EQ> {
public:
static constexpr const char* _type_key = "EQ";
};
/*! \brief a != b */
class NE : public CmpOpNode<NE> {
public:
static constexpr const char* _type_key = "NE";
};
/*! \brief a < b */
class LT : public CmpOpNode<LT> {
public:
static constexpr const char* _type_key = "LT";
};
/*! \brief a <= b */
struct LE : public CmpOpNode<LE> {
public:
static constexpr const char* _type_key = "LE";
};
/*! \brief a > b */
class GT : public CmpOpNode<GT> {
public:
static constexpr const char* _type_key = "GT";
};
/*! \brief a >= b */
class GE : public CmpOpNode<GE> {
public:
static constexpr const char* _type_key = "GE";
};
/*! \brief a && b */
class And : public ExprNode {
public:
/*! \brief The left operand. */
Expr a;
/*! \brief The right operand. */
Expr b;
void VisitAttrs(AttrVisitor* v) final {
v->Visit("dtype", &(this->type));
v->Visit("a", &a);
v->Visit("b", &b);
}
TVM_DLL static Expr make(Expr a, Expr b);
static constexpr const char* _type_key = "And";
TVM_DECLARE_NODE_TYPE_INFO(And, ExprNode);
};
/*! \brief a || b */
class Or : public ExprNode {
public:
/*! \brief The left operand. */
Expr a;
/*! \brief The right operand. */
Expr b;
void VisitAttrs(AttrVisitor* v) final {
v->Visit("dtype", &type);
v->Visit("a", &a);
v->Visit("b", &b);
}
TVM_DLL static Expr make(Expr a, Expr b);
static constexpr const char* _type_key = "Or";
TVM_DECLARE_NODE_TYPE_INFO(Or, ExprNode);
};
/*! \brief !a */
class Not : public ExprNode {
public:
/*! \brief The input operand. */
Expr a;
void VisitAttrs(AttrVisitor* v) final {
v->Visit("dtype", &type);
v->Visit("a", &a);
}
TVM_DLL static Expr make(Expr a);
static constexpr const char* _type_key = "Not";
TVM_DECLARE_NODE_TYPE_INFO(Not, ExprNode);
};
/*!
* \brief return true_value if condition is true, otherwise return false_value.
* \note Both true_value and false_value could be evaluated
* regardless of the condition value.
* Do not use it to guard against out of bound access,
* please use if_then_else instead.
*/
class Select : public ExprNode {
public:
/*! \brief The condition */
Expr condition;
/*! \brief value to be returned when condition is true. */
Expr true_value;
/*! \brief value to be returned when condition is false. */
Expr false_value;
void VisitAttrs(AttrVisitor* v) final {
v->Visit("dtype", &type);
v->Visit("condition", &condition);
v->Visit("true_value", &true_value);
v->Visit("false_value", &false_value);
}
TVM_DLL static Expr make(Expr condition, Expr true_value, Expr false_value);
static constexpr const char* _type_key = "Select";
TVM_DECLARE_NODE_TYPE_INFO(Select, ExprNode);
};
/*!
* \brief Load the value from buffer_var.
*
* Equivalent to ((DType*)buffer_var)[index]
* where DType is the type specified by type().element_of().
*
* For example, if type = float32x3, then the load will corresponds to
*
* \code
*
* auto buffer = static_cast<float*>(buffer_var);
* auto loaded_val = float32x3(buffer[index.v0], buffer[index.v1], buffer[index.v2]);
*
* \endcode
*/
class Load : public ExprNode {
public:
/*! \brief The buffer variable. */
Var buffer_var;
/*! \brief The index locations to be loaded. */
Expr index;
/*! \brief The predicate to mask which lanes would be loaded. */
Expr predicate;
void VisitAttrs(AttrVisitor* v) final {
v->Visit("dtype", &type);
v->Visit("buffer_var", &buffer_var);
v->Visit("index", &index);
v->Visit("predicate", &predicate);
}
TVM_DLL static Expr make(Type type, Var buffer_var, Expr index, Expr predicate);
static constexpr const char* _type_key = "Load";
TVM_DECLARE_NODE_TYPE_INFO(Load, ExprNode);
};
/*!
* \brief Construct a vector with lanes elements
* where its i-th element equals base + i * stride.
* This is useful to construct a index for a continuous vector load.
*
* Examples:
* - ramp(0, 1, 3) = [0, 1, 2]
* - ramp(1, 2, 4) = [1, 3, 5, 7]
*/
class Ramp : public ExprNode {
public:
/*! \brief The base value. */
Expr base;
/*! \brief The stride of each step. */
Expr stride;
/*! \brief Total number of lanes. */
int lanes;
void VisitAttrs(AttrVisitor* v) final {
v->Visit("dtype", &type);
v->Visit("base", &base);
v->Visit("stride", &stride);
v->Visit("lanes", &lanes);
}
TVM_DLL static Expr make(Expr base, Expr stride, int lanes);
static constexpr const char* _type_key = "Ramp";
TVM_DECLARE_NODE_TYPE_INFO(Ramp, ExprNode);
};
/*! \brief Create a vector where all the elements are value. */
class Broadcast : public ExprNode {
public:
/*! \brief The base value. */
Expr value;
/*! \brief The numerb of lanes. */
int lanes;
void VisitAttrs(AttrVisitor* v) final {
v->Visit("dtype", &type);
v->Visit("value", &value);
v->Visit("lanes", &lanes);
}
TVM_DLL static Expr make(Expr value, int lanes);
static constexpr const char* _type_key = "Broadcast";
TVM_DECLARE_NODE_TYPE_INFO(Broadcast, ExprNode);
};
/*!
* \brief Let binding. Bind var to value then evaluate body.
*/
class Let : public ExprNode {
public:
/*! \brief The variable. */
Var var;
/*! \brief The value to be binded. */
Expr value;
/*! \brief The result expression. */
Expr body;
void VisitAttrs(AttrVisitor* v) final {
v->Visit("dtype", &type);
v->Visit("var", &var);
v->Visit("value", &value);
v->Visit("body", &body);
}
TVM_DLL static Expr make(Var var, Expr value, Expr body);
static constexpr const char* _type_key = "Let";
TVM_DECLARE_NODE_TYPE_INFO(Let, ExprNode);
};
// Call node, represent a function call or a multi-dimensional array load.
//
// TODO(tvm-team):
// Refactor call with more explicit property registrations.
// rather than calling a string symbol.
// We should move most information into function itself and remove name.
/*! \brief Base node of internal functions. */
class FunctionBaseNode : public Node {
public:
/*! \return the name of the function */
virtual const std::string& func_name() const = 0;
/*! \return the number of outputs of this function */
virtual int num_outputs() const = 0;
};
/*! \brief reference to a function */
class FunctionRef : public NodeRef {
public:
TVM_DEFINE_NODE_REF_METHODS(FunctionRef, NodeRef, FunctionBaseNode);
};
/*!
* \brief Call node.
*/
class Call : public ExprNode {
public:
/*! \brief Possible types of calls. */
enum CallType : int {
/*! \brief Extern "C" function. */
Extern = 0,
/*! \brief Extern CXX function. */
ExternCPlusPlus = 1,
/*! \brief Extern "C" without side-effect. */
PureExtern = 2,
/*! \brief Halide-style call, evaluates func(args). */
Halide = 3,
/*! \brief Intrinsic functions. */
Intrinsic = 4,
/*! \brief Intrinsic functions that are pure. */
PureIntrinsic = 5
};
/*! \brief The name of the function/intrinsic. */
std::string name;
/*! \brief The arguments. */
Array<Expr> args;
/*! \brief Type of calls. */
CallType call_type;
/*! \brief The function to be called. */
FunctionRef func;
/*! \brief The output value index if func's value is a tuple. */
int value_index{0};
void VisitAttrs(AttrVisitor* v) final {
v->Visit("dtype", &type);
v->Visit("name", &name);
v->Visit("args", &args);
v->Visit("call_type", &call_type);
v->Visit("func", &func);
v->Visit("value_index", &value_index);
}
TVM_DLL static Expr make(Type type,
std::string name,
Array<Expr> args,
CallType call_type,
FunctionRef func = FunctionRef(),
int value_index = 0);
/*! \return Whether call node is pure. */
bool is_pure() const {
return (call_type == PureExtern ||
call_type == PureIntrinsic ||
call_type == Halide);
}
/*!
* \return Whether call node corresponds to a defined intrinsic.
* \param intrin_name The name of the intrinsic.
*/
bool is_intrinsic(const char* intrin_name) const {
return
((call_type == Intrinsic ||
call_type == PureIntrinsic) &&
name == intrin_name);
}
static constexpr const char* _type_key = "Call";
TVM_DECLARE_NODE_TYPE_INFO(Call, ExprNode);
// Build-in intrinsics
static constexpr const char* reinterpret = "reinterpret";
static constexpr const char* bitwise_and = "bitwise_and";
static constexpr const char* bitwise_not = "bitwise_not";
static constexpr const char* bitwise_xor = "bitwise_xor";
static constexpr const char* bitwise_or = "bitwise_or";
static constexpr const char* shift_left = "shift_left";
static constexpr const char* shift_right = "shift_right";
static constexpr const char* popcount = "popcount";
static constexpr const char* likely = "likely";
static constexpr const char* glsl_texture_store = "glsl_texture_store";
static constexpr const char* prefetch = "prefetch";
};
/*!
* \brief Shuffle instruction.
* vec = concat(vectors)
* result = (vec[indices[0]], vec[indices[1]] ...)
*/
class Shuffle : public ExprNode {
public:
/*! \brief the input vectors. */
Array<Expr> vectors;
/*! \brief The indices of each element. */
Array<Expr> indices;
void VisitAttrs(AttrVisitor* v) final {
v->Visit("vectors", &vectors);
v->Visit("indices", &indices);
}
TVM_DLL static Expr make(Array<Expr> vectors, Array<Expr> indices);
TVM_DLL static Expr make_concat(Array<Expr> vectors);
TVM_DLL static Expr make_extract_element(Expr vector, int index);
// Node container for CommReducer
struct CommReducerNode;
static constexpr const char* _type_key = "Shuffle";
TVM_DECLARE_NODE_TYPE_INFO(Shuffle, ExprNode);
};
struct CommReducer : public NodeRef {
// Reduce operator
class CommReducerNode;
class CommReducer : public NodeRef {
public:
CommReducer() {}
explicit CommReducer(NodePtr<Node> n) : NodeRef(n) {}
/*!
......@@ -66,7 +611,8 @@ struct CommReducer : public NodeRef {
* \brief A commutative reducer node to represent a commutative
* binary operator with identity element
*/
struct CommReducerNode : public Node {
class CommReducerNode : public Node {
public:
/*! \brief The left argument of reducer */
Array<Var> lhs;
/*! \brief The right argument of reducer */
......@@ -82,8 +628,10 @@ struct CommReducerNode : public Node {
/*! \brief Function call operator to combine a and b */
Array<Expr> operator()(Array<Expr> a, Array<Expr> b) const;
/*! \brief construct CommReducer from args, result and identity_element */
TVM_DLL static CommReducer make(Array<Var> lhs, Array<Var> rhs,
Array<Expr> result, Array<Expr> identity_element);
TVM_DLL static CommReducer make(Array<Var> lhs,
Array<Var> rhs,
Array<Expr> result,
Array<Expr> identity_element);
void VisitAttrs(AttrVisitor* v) final {
v->Visit("lhs", &lhs);
......@@ -104,7 +652,8 @@ inline const CommReducerNode* CommReducer::operator->() const {
}
/*! \brief Reduction operator operator */
struct Reduce : public ExprNode<Reduce> {
class Reduce : public ExprNode {
public:
/*! \brief The commutative combiner */
CommReducer combiner;
/*! \brief The source operand */
......@@ -134,17 +683,483 @@ struct Reduce : public ExprNode<Reduce> {
v->Visit("condition", &condition);
v->Visit("value_index", &value_index);
}
static const IRNodeType _type_info = IRNodeType::ExtensionExpr;
static constexpr const char* _type_key = "Reduce";
TVM_DECLARE_NODE_TYPE_INFO(Reduce, ExprNode);
};
/*! \brief Any shape. */
struct Any : public ExprNode<Any> {
class Any : public ExprNode {
public:
void VisitAttrs(AttrVisitor* v) final {}
TVM_DLL static Expr make();
void VisitAttrs(AttrVisitor* v) final {}
static const IRNodeType _type_info = IRNodeType::ExtensionExpr;
static constexpr const char* _type_key = "Any";
TVM_DECLARE_NODE_TYPE_INFO(Any, ExprNode);
};
// Statements
/*!
* \brief Let binding, bind var to value, then run body.
*/
class LetStmt : public StmtNode {
public:
/*! \brief The variable. */
Var var;
/*! \brief The value to be binded. */
Expr value;
/*! \brief The body block. */
Stmt body;
void VisitAttrs(AttrVisitor* v) final {
v->Visit("var", &var);
v->Visit("value", &value);
v->Visit("body", &body);
}
TVM_DLL static Stmt make(Var var, Expr value, Stmt body);
static constexpr const char* _type_key = "LetStmt";
TVM_DECLARE_NODE_TYPE_INFO(LetStmt, StmtNode);
};
/*!
* \brief Define certain auxiliary attribute for the body to be a symbolic value.
* This provide auxiliary information for IR passes that transforms body.
*
* In terms of effect, this is equivalent to Block(Evaluate(value), body).
*
* Examples of possible usage:
* - Bound of function, variables.
* - Hint which block corresponds to a parallel region.
*/
class AttrStmt : public StmtNode {
public:
/*! \brief this is attribute about certain node */
NodeRef node;
/*! \brief the type key of the attribute */
std::string attr_key;
/*! \brief The attribute value, value is well defined at current scope. */
Expr value;
/*! \brief The body statement to be executed */
Stmt body;
void VisitAttrs(AttrVisitor* v) final {
v->Visit("node", &node);
v->Visit("attr_key", &attr_key);
v->Visit("value", &value);
v->Visit("body", &body);
}
TVM_DLL static Stmt make(NodeRef node,
std::string type_key,
Expr value,
Stmt body);
static constexpr const char* _type_key = "AttrStmt";
TVM_DECLARE_NODE_TYPE_INFO(AttrStmt, StmtNode);
};
/*!
* \brief Assert condition, if an error occurs, return the error message.
*/
class AssertStmt : public StmtNode {
public:
/*! \brief Condition to be checked. */
Expr condition;
/*! \brief Error message when assertion failed. */
Expr message;
/*!
* \brief Body which this assertion holds true.
* Will be executed after the assertion.
*/
Stmt body;
void VisitAttrs(AttrVisitor* v) final {
v->Visit("condition", &condition);
v->Visit("message", &message);
v->Visit("body", &body);
}
TVM_DLL static Stmt make(Expr condition, Expr message, Stmt body);
static constexpr const char* _type_key = "AssertStmt";
TVM_DECLARE_NODE_TYPE_INFO(AssertStmt, StmtNode);
};
// TODO(tvm-team): consider consolidate with AttrStmt.
/*! \brief annotation node of producer/consumer relation. */
class ProducerConsumer : public StmtNode {
public:
/*! \brief The corresponding tensor. */
FunctionRef func;
/*! \brief Whether the relation is producer. */
bool is_producer;
/*! \brief Body to be executed. */
Stmt body;
void VisitAttrs(AttrVisitor* v) final {
v->Visit("func", &func);
v->Visit("is_producer", &is_producer);
v->Visit("body", &body);
}
TVM_DLL static Stmt make(FunctionRef func, bool is_producer, Stmt body);
static constexpr const char* _type_key = "ProducerConsumer";
TVM_DECLARE_NODE_TYPE_INFO(ProducerConsumer, StmtNode);
};
/*!
* \brief Store value to the buffer.
*
* Equivalent to ((DType*)buffer_var)[index] = value.
* where DType is the type specified by type().element_of().
*
* For example, if type = float32x3, then the load will corresponds to
*
* \code
*
* auto buffer = static_cast<float*>(buffer_var);
* buffer[index.v0] = value.v0;
* buffer[index.v1] = value.v1;
* buffer[index.v2] = value.v2;
*
* \endcode
* \sa Load
*/
class Store : public StmtNode {
public:
/*! \brief The buffer variable. */
Var buffer_var;
/*! \brief The value to be stored. */
Expr value;
/*! \brief The index locations to be stored. */
Expr index;
/*! \brief The predicate to mask which lanes would be stored. */
Expr predicate;
void VisitAttrs(AttrVisitor* v) final {
v->Visit("buffer_var", &buffer_var);
v->Visit("value", &value);
v->Visit("index", &index);
v->Visit("predicate", &predicate);
}
TVM_DLL static Stmt make(Var buffer_var,
Expr value,
Expr index,
Expr predicate);
static constexpr const char* _type_key = "Store";
TVM_DECLARE_NODE_TYPE_INFO(Store, StmtNode);
};
/*!
* \brief Store value into mult-dimensional array defined by func.
*/
class Provide : public StmtNode {
public:
/*! \brief The function to be updated. */
FunctionRef func;
/*! \brief The output value index if func's value is a tuple. */
int value_index{0};
/*! \brief The value to be stored. */
Expr value;
/*! \brief The index arguments of the function. */
Array<Expr> args;
void VisitAttrs(AttrVisitor* v) final {
v->Visit("func", &func);
v->Visit("value_index", &value_index);
v->Visit("value", &value);
v->Visit("args", &args);
}
TVM_DLL static Stmt make(FunctionRef func,
int value_index,
Expr value,
Array<Expr> args);
static constexpr const char* _type_key = "Provide";
TVM_DECLARE_NODE_TYPE_INFO(Provide, StmtNode);
};
/*!
* \brief Allocate a buffer that can be used in body.
*/
class Allocate : public StmtNode {
public:
/*! \brief The buffer variable. */
Var buffer_var;
/*! \brief The type of the buffer. */
DataType type;
/*! \brief The extents of the buffer. */
Array<Expr> extents;
/*! \brief Only allocate buffer when condition is satisfied. */
Expr condition;
/*! \brief The body to be executed. */
Stmt body;
// The following two fields are deprecated
// kept for backward compatibility and will be refactored later.
Expr new_expr;
std::string free_function;
void VisitAttrs(AttrVisitor* v) final {
v->Visit("buffer_var", &buffer_var);
v->Visit("dtype", &type);
v->Visit("extents", &extents);
v->Visit("condition", &condition);
v->Visit("body", &body);
}
TVM_DLL static Stmt make(Var buffer_var,
DataType type,
Array<Expr> extents,
Expr condition,
Stmt body,
Expr new_expr = Expr(),
std::string free_function = std::string());
/*!
* \brief If the buffer size is constant, return the size.
* Otherwise return 0.
* \return The result.
*/
int32_t constant_allocation_size() const {
return constant_allocation_size(extents);
}
/*!
* \brief If the buffer size is constant, return the size.
* Otherwise return 0.
* \param extents The extents of the buffer.
* \return The result.
*/
TVM_DLL static int32_t constant_allocation_size(
const Array<Expr>& extents);
static constexpr const char* _type_key = "Allocate";
TVM_DECLARE_NODE_TYPE_INFO(Allocate, StmtNode);
};
/*! \brief Free the resources in the buffer before the scope ends. */
class Free : public StmtNode {
public:
/*! \brief The buffer variable. */
Var buffer_var;
void VisitAttrs(AttrVisitor* v) final {
v->Visit("buffer_var", &buffer_var);
}
TVM_DLL static Stmt make(Var buffer_var);
static constexpr const char* _type_key = "Free";
TVM_DECLARE_NODE_TYPE_INFO(Free, StmtNode);
};
/*!
* \brief Annotate the bounds where func need to be written and read in body.
* We will need to allocate space for the corresponding regions.
*/
class Realize : public StmtNode {
public:
/*! \brief The function to be realized. */
FunctionRef func;
/*! \brief The output value index if func's value is a tuple. */
int value_index;
/*! \brief The data type of the array. */
DataType type;
/*! \brief Bounds to be realized. */
Region bounds;
/*! \brief Only realize if condition holds. */
Expr condition;
/*! \brief The body of realization. */
Stmt body;
void VisitAttrs(AttrVisitor* v) final {
v->Visit("func", &func);
v->Visit("value_index", &value_index);
v->Visit("dtype", &type);
v->Visit("bounds", &bounds);
v->Visit("condition", &condition);
v->Visit("body", &body);
}
TVM_DLL static Stmt make(FunctionRef func,
int value_index,
DataType type,
Region bounds,
Expr condition,
Stmt body);
static constexpr const char* _type_key = "Realize";
TVM_DECLARE_NODE_TYPE_INFO(Realize, StmtNode);
};
/*!
* \brief A sequence of statements.
*/
class Block : public StmtNode {
public:
/*! \brief The first statement. */
Stmt first;
/*! \brief The restof statments. */
Stmt rest;
void VisitAttrs(AttrVisitor* v) final {
v->Visit("first", &first);
v->Visit("rest", &rest);
}
TVM_DLL static Stmt make(Stmt first, Stmt rest);
TVM_DLL static Stmt make(const std::vector<Stmt> &stmts);
static constexpr const char* _type_key = "Block";
TVM_DECLARE_NODE_TYPE_INFO(Block, StmtNode);
};
/*!
* \brief IfThenElse statment.
*/
class IfThenElse : public StmtNode {
public:
/*! \brief The condition. */
Expr condition;
/*! \brief The branch to be executed when condition is true. */
Stmt then_case;
/*! \brief The branch to be executed when condition is false, can be null. */
Stmt else_case;
void VisitAttrs(AttrVisitor* v) final {
v->Visit("condition", &condition);
v->Visit("then_case", &then_case);
v->Visit("else_case", &else_case);
}
TVM_DLL static Stmt make(Expr condition, Stmt then_case, Stmt else_case = Stmt());
static constexpr const char* _type_key = "IfThenElse";
TVM_DECLARE_NODE_TYPE_INFO(IfThenElse, StmtNode);
};
/*!
* \brief Evaluates an expression.
* This is mostly used for putting a Call node into Stmt.
*
* If value do not have side-effect, this node can be safely removed.
*/
class Evaluate : public StmtNode {
public:
/*! \brief The expression to be evaluated. */
Expr value;
void VisitAttrs(AttrVisitor* v) final {
v->Visit("value", &value);
}
TVM_DLL static Stmt make(Expr v);
static constexpr const char* _type_key = "Evaluate";
TVM_DECLARE_NODE_TYPE_INFO(Evaluate, StmtNode);
};
/*! \brief Additional annotation of for loop. */
enum class ForType : int {
/*! \brief serial execution. */
Serial = 0,
/*! \brief parallel execution on CPU. */
Parallel = 1,
/*! \brief Vector SIMD loop annotaion. */
Vectorized = 2,
/*! \brief Unroll annotation. */
Unrolled = 3
};
// Kevice api of for loop
// kept for backward compatibility
// consider refactor and remove later.
enum class DeviceAPI: int {
None = 0
};
/*!
* \brief A for loop, with poissible type annotations.
*
* \code
*
* for (loop_var = min; loop_var < min + extent; ++loop_var) {
* // body
* }
* \endcode
*/
class For : public StmtNode {
public:
/*! \brief The loop variable. */
Var loop_var;
/*! \brief The minimum value of iteration. */
Expr min;
/*! \brief The extent of the iteration. */
Expr extent;
/*! \brief The type of the for loop. */
ForType for_type;
/*!
* \brief Deprecated, reserved for backward compatibility.
* Consider refactor and remove later.
*/
DeviceAPI device_api;
/*! \brief The body of the for loop. */
Stmt body;
TVM_DLL static Stmt make(Var loop_var,
Expr min,
Expr extent,
ForType for_type,
DeviceAPI device_api,
Stmt body);
void VisitAttrs(AttrVisitor* v) final {
v->Visit("loop_var", &loop_var);
v->Visit("min", &min);
v->Visit("extent", &extent);
v->Visit("for_type", &for_type);
v->Visit("device_api", &device_api);
v->Visit("body", &body);
}
static constexpr const char* _type_key = "For";
TVM_DECLARE_NODE_TYPE_INFO(For, StmtNode);
};
/*!
* \brief A prefetch hint of func.
*/
class Prefetch : public StmtNode {
public:
/*! \brief The function to be prefetched. */
FunctionRef func;
/*! \brief The output value index if func's value is a tuple. */
int value_index;
/*! \brief The data type of the array. */
DataType type;
/*! \brief Bounds to be prefetched. */
Region bounds;
void VisitAttrs(AttrVisitor* v) final {
v->Visit("func", &func);
v->Visit("value_index", &value_index);
v->Visit("type", &type);
v->Visit("bounds", &bounds);
}
TVM_DLL static Stmt make(FunctionRef func,
int value_index,
DataType type,
Region bounds);
static constexpr const char* _type_key = "Prefetch";
TVM_DECLARE_NODE_TYPE_INFO(Prefetch, StmtNode);
};
/*!
......@@ -517,50 +1532,6 @@ constexpr const char* tvm_thread_allreduce = "tvm_thread_allreduce";
} // namespace intrinsic
// Reuse IR node defintiion from HalideIR
using HalideIR::Internal::IntImm;
using HalideIR::Internal::UIntImm;
using HalideIR::Internal::FloatImm;
using HalideIR::Internal::StringImm;
using HalideIR::Internal::Cast;
using HalideIR::Internal::Add;
using HalideIR::Internal::Sub;
using HalideIR::Internal::Mul;
using HalideIR::Internal::Div;
using HalideIR::Internal::Mod;
using HalideIR::Internal::Min;
using HalideIR::Internal::Max;
using HalideIR::Internal::EQ;
using HalideIR::Internal::NE;
using HalideIR::Internal::LT;
using HalideIR::Internal::LE;
using HalideIR::Internal::GT;
using HalideIR::Internal::GE;
using HalideIR::Internal::And;
using HalideIR::Internal::Or;
using HalideIR::Internal::Not;
using HalideIR::Internal::Select;
using HalideIR::Internal::Load;
using HalideIR::Internal::Ramp;
using HalideIR::Internal::Broadcast;
using HalideIR::Internal::Call;
using HalideIR::Internal::Let;
using HalideIR::Internal::LetStmt;
using HalideIR::Internal::AttrStmt;
using HalideIR::Internal::AssertStmt;
using HalideIR::Internal::ProducerConsumer;
using HalideIR::Internal::For;
using HalideIR::Internal::Store;
using HalideIR::Internal::Provide;
using HalideIR::Internal::Allocate;
using HalideIR::Internal::Free;
using HalideIR::Internal::Realize;
using HalideIR::Internal::Prefetch;
using HalideIR::Internal::Block;
using HalideIR::Internal::IfThenElse;
using HalideIR::Internal::Evaluate;
using HalideIR::Internal::Shuffle;
/*!
* \brief Create a type annotation expression
* \param dtype The data type
......@@ -571,6 +1542,10 @@ inline Expr TypeAnnotation(Type dtype) {
"type_annotation", {},
ir::Call::PureIntrinsic);
}
// overload printing of for type.
TVM_DLL std::ostream& operator<<(std::ostream& os, ForType for_type);
} // namespace ir
} // namespace tvm
......
......@@ -25,6 +25,7 @@
#define TVM_IR_MUTATOR_H_
#include <unordered_map>
#include <utility>
#include "expr.h"
#include "ir.h"
#include "tvm/node/ir_functor.h"
......
......@@ -25,7 +25,6 @@
#ifndef TVM_LOWERED_FUNC_H_
#define TVM_LOWERED_FUNC_H_
#include <ir/FunctionBase.h>
#include <string>
#include "base.h"
......@@ -42,7 +41,7 @@ class LoweredFuncNode;
* \brief LoweredFunc represents function after lowering.
* This is the final IR representation before codegen.
*/
class LoweredFunc : public FunctionRef {
class LoweredFunc : public ir::FunctionRef {
public:
LoweredFunc() {}
explicit LoweredFunc(NodePtr<Node> n) : FunctionRef(n) {}
......@@ -66,7 +65,7 @@ enum LoweredFuncType : int {
};
/*! \brief Node container of LoweredFunc */
class LoweredFuncNode : public FunctionBaseNode {
class LoweredFuncNode : public ir::FunctionBaseNode {
public:
/*! \brief The name of the function */
std::string name;
......
/*
* Licensed to the Apache Software Foundation (ASF) under one
* or more contributor license agreements. See the NOTICE file
* distributed with this work for additional information
* regarding copyright ownership. The ASF licenses this file
* to you under the Apache License, Version 2.0 (the
* "License"); you may not use this file except in compliance
* with the License. You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing,
* software distributed under the License is distributed on an
* "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
* KIND, either express or implied. See the License for the
* specific language governing permissions and limitations
* under the License.
*/
/*!
* \file tvm/node/container.h
* \brief Array/Map container in the DSL graph.
*/
#ifndef TVM_NODE_CONTAINER_H_
#define TVM_NODE_CONTAINER_H_
#include <type_traits>
#include <vector>
#include <initializer_list>
#include <unordered_map>
#include <utility>
#include <string>
#include "node.h"
#include "memory.h"
namespace tvm {
/*! \brief array node content in array */
class ArrayNode : public Node {
public:
/*! \brief the data content */
std::vector<NodePtr<Node> > data;
void VisitAttrs(AttrVisitor* visitor) final {
// Visitor to array have no effect.
}
static constexpr const char* _type_key = "Array";
TVM_DECLARE_NODE_TYPE_INFO(ArrayNode, Node);
};
/*! \brief map node content */
class MapNode : public Node {
public:
void VisitAttrs(AttrVisitor* visitor) final {
// Visitor to map have no effect.
}
// hash function
struct Hash {
size_t operator()(const NodePtr<Node>& n) const {
return std::hash<Node*>()(n.get());
}
};
// comparator
struct Equal {
bool operator()(
const NodePtr<Node>& a,
const NodePtr<Node>& b) const {
return a.get() == b.get();
}
};
/*! \brief The corresponding conatiner type */
using ContainerType = std::unordered_map<
NodePtr<Node>,
NodePtr<Node>,
Hash, Equal>;
/*! \brief the data content */
ContainerType data;
static constexpr const char* _type_key = "Map";
TVM_DECLARE_NODE_TYPE_INFO(MapNode, Node);
};
/*! \brief specialized map node with string as key */
class StrMapNode : public Node {
public:
void VisitAttrs(AttrVisitor* visitor) final {
// Visitor to map have no effect.
}
/*! \brief The corresponding conatiner type */
using ContainerType = std::unordered_map<
std::string,
NodePtr<Node> >;
/*! \brief the data content */
ContainerType data;
static constexpr const char* _type_key = "StrMap";
TVM_DECLARE_NODE_TYPE_INFO(StrMapNode, Node);
};
/*!
* \brief iterator adapter that adapts TIter to return another type.
* \tparam Converter a struct that contains converting function
* \tparam TIter the content iterator type.
*/
template<typename Converter,
typename TIter>
class IterAdapter {
public:
explicit IterAdapter(TIter iter) : iter_(iter) {}
inline IterAdapter& operator++() { // NOLINT(*)
++iter_;
return *this;
}
inline IterAdapter& operator++(int) { // NOLINT(*)
++iter_;
return *this;
}
inline IterAdapter operator+(int offset) const { // NOLINT(*)
return IterAdapter(iter_ + offset);
}
inline bool operator==(IterAdapter other) const {
return iter_ == other.iter_;
}
inline bool operator!=(IterAdapter other) const {
return !(*this == other);
}
inline const typename Converter::ResultType operator*() const {
return Converter::convert(*iter_);
}
private:
TIter iter_;
};
/*!
* \brief Array container of NodeRef in DSL graph.
* Array implements copy on write semantics, which means array is mutable
* but copy will happen when array is referenced in more than two places.
*
* operator[] only provide const acces, use Set to mutate the content.
* \tparam T The content NodeRef type.
*/
template<typename T,
typename = typename std::enable_if<std::is_base_of<NodeRef, T>::value>::type >
class Array : public NodeRef {
public:
/*!
* \brief default constructor
*/
Array() {
node_ = make_node<ArrayNode>();
}
/*!
* \brief move constructor
* \param other source
*/
Array(Array<T> && other) { // NOLINT(*)
node_ = std::move(other.node_);
}
/*!
* \brief copy constructor
* \param other source
*/
Array(const Array<T> &other) : NodeRef(other.node_) { // NOLINT(*)
}
/*!
* \brief constructor from pointer
* \param n the container pointer
*/
explicit Array(NodePtr<Node> n) : NodeRef(n) {}
/*!
* \brief constructor from iterator
* \param begin begin of iterator
* \param end end of iterator
* \tparam IterType The type of iterator
*/
template<typename IterType>
Array(IterType begin, IterType end) {
assign(begin, end);
}
/*!
* \brief constructor from initializer list
* \param init The initalizer list
*/
Array(std::initializer_list<T> init) { // NOLINT(*)
assign(init.begin(), init.end());
}
/*!
* \brief constructor from vector
* \param init The vector
*/
Array(const std::vector<T>& init) { // NOLINT(*)
assign(init.begin(), init.end());
}
/*!
* \brief Constructs a container with n elements. Each element is a copy of val
* \param n The size of the container
* \param val The init value
*/
explicit Array(size_t n, const T& val) {
auto tmp_node = make_node<ArrayNode>();
for (size_t i = 0; i < n; ++i) {
tmp_node->data.push_back(val.node_);
}
node_ = std::move(tmp_node);
}
/*!
* \brief move assign operator
* \param other The source of assignment
* \return reference to self.
*/
Array<T>& operator=(Array<T> && other) {
node_ = std::move(other.node_);
return *this;
}
/*!
* \brief copy assign operator
* \param other The source of assignment
* \return reference to self.
*/
Array<T>& operator=(const Array<T> & other) {
node_ = other.node_;
return *this;
}
/*!
* \brief reset the array to content from iterator.
* \param begin begin of iterator
* \param end end of iterator
* \tparam IterType The type of iterator
*/
template<typename IterType>
void assign(IterType begin, IterType end) {
auto n = make_node<ArrayNode>();
for (IterType it = begin; it != end; ++it) {
n->data.push_back((*it).node_);
}
node_ = std::move(n);
}
/*!
* \brief Read i-th element from array.
* \param i The index
* \return the i-th element.
*/
inline const T operator[](size_t i) const {
return T(static_cast<const ArrayNode*>(node_.get())->data[i]);
}
/*! \return The size of the array */
inline size_t size() const {
if (node_.get() == nullptr) return 0;
return static_cast<const ArrayNode*>(node_.get())->data.size();
}
/*!
* \brief copy on write semantics
* Do nothing if current handle is the unique copy of the array.
* Otherwise make a new copy of the array to ensure the current handle
* hold a unique copy.
*
* \return Handle to the internal node container(which ganrantees to be unique)
*/
inline ArrayNode* CopyOnWrite() {
if (node_.get() == nullptr || !node_.unique()) {
NodePtr<ArrayNode> n = make_node<ArrayNode>();
n->data = static_cast<ArrayNode*>(node_.get())->data;
NodePtr<Node>(std::move(n)).swap(node_);
}
return static_cast<ArrayNode*>(node_.get());
}
/*!
* \brief push a new item to the back of the list
* \param item The item to be pushed.
*/
inline void push_back(const T& item) {
ArrayNode* n = this->CopyOnWrite();
n->data.push_back(item.node_);
}
/*!
* \brief set i-th element of the array.
* \param i The index
* \param value The value to be setted.
*/
inline void Set(size_t i, const T& value) {
ArrayNode* n = this->CopyOnWrite();
n->data[i] = value.node_;
}
/*! \return whether array is empty */
inline bool empty() const {
return size() == 0;
}
/*! \brief specify container node */
using ContainerType = ArrayNode;
struct Ptr2NodeRef {
using ResultType = T;
static inline T convert(const NodePtr<Node>& n) {
return T(n);
}
};
using iterator = IterAdapter<Ptr2NodeRef,
std::vector<NodePtr<Node> >::const_iterator>;
using reverse_iterator = IterAdapter<
Ptr2NodeRef,
std::vector<NodePtr<Node> >::const_reverse_iterator>;
/*! \return begin iterator */
inline iterator begin() const {
return iterator(static_cast<const ArrayNode*>(node_.get())->data.begin());
}
/*! \return end iterator */
inline iterator end() const {
return iterator(static_cast<const ArrayNode*>(node_.get())->data.end());
}
/*! \return rbegin iterator */
inline reverse_iterator rbegin() const {
return reverse_iterator(static_cast<const ArrayNode*>(node_.get())->data.rbegin());
}
/*! \return rend iterator */
inline reverse_iterator rend() const {
return reverse_iterator(static_cast<const ArrayNode*>(node_.get())->data.rend());
}
};
/*!
* \brief Map container of NodeRef->NodeRef in DSL graph.
* Map implements copy on write semantics, which means map is mutable
* but copy will happen when array is referenced in more than two places.
*
* operator[] only provide const acces, use Set to mutate the content.
* \tparam K The key NodeRef type.
* \tparam V The value NodeRef type.
*/
template<typename K,
typename V,
typename = typename std::enable_if<
std::is_base_of<NodeRef, K>::value ||
std::is_base_of<std::string, K>::value >::type,
typename = typename std::enable_if<std::is_base_of<NodeRef, V>::value>::type>
class Map : public NodeRef {
public:
/*!
* \brief default constructor
*/
Map() {
node_ = make_node<MapNode>();
}
/*!
* \brief move constructor
* \param other source
*/
Map(Map<K, V> && other) { // NOLINT(*)
node_ = std::move(other.node_);
}
/*!
* \brief copy constructor
* \param other source
*/
Map(const Map<K, V> &other) : NodeRef(other.node_) { // NOLINT(*)
}
/*!
* \brief constructor from pointer
* \param n the container pointer
*/
explicit Map(NodePtr<Node> n) : NodeRef(n) {}
/*!
* \brief constructor from iterator
* \param begin begin of iterator
* \param end end of iterator
* \tparam IterType The type of iterator
*/
template<typename IterType>
Map(IterType begin, IterType end) {
assign(begin, end);
}
/*!
* \brief constructor from initializer list
* \param init The initalizer list
*/
Map(std::initializer_list<std::pair<K, V> > init) { // NOLINT(*)
assign(init.begin(), init.end());
}
/*!
* \brief constructor from vector
* \param init The vector
*/
template<typename Hash, typename Equal>
Map(const std::unordered_map<K, V, Hash, Equal>& init) { // NOLINT(*)
assign(init.begin(), init.end());
}
/*!
* \brief move assign operator
* \param other The source of assignment
* \return reference to self.
*/
Map<K, V>& operator=(Map<K, V> && other) {
node_ = std::move(other.node_);
return *this;
}
/*!
* \brief copy assign operator
* \param other The source of assignment
* \return reference to self.
*/
Map<K, V>& operator=(const Map<K, V> & other) {
node_ = other.node_;
return *this;
}
/*!
* \brief reset the array to content from iterator.
* \param begin begin of iterator
* \param end end of iterator
* \tparam IterType The type of iterator
*/
template<typename IterType>
void assign(IterType begin, IterType end) {
NodePtr<MapNode> n = make_node<MapNode>();
for (IterType i = begin; i != end; ++i) {
n->data.emplace(std::make_pair(i->first.node_,
i->second.node_));
}
node_ = std::move(n);
}
/*!
* \brief Read element from map.
* \param key The key
* \return the corresonding element.
*/
inline const V operator[](const K& key) const {
return V(static_cast<const MapNode*>(node_.get())->data.at(key.node_));
}
/*!
* \brief Read element from map.
* \param key The key
* \return the corresonding element.
*/
inline const V at(const K& key) const {
return V(static_cast<const MapNode*>(node_.get())->data.at(key.node_));
}
/*! \return The size of the array */
inline size_t size() const {
if (node_.get() == nullptr) return 0;
return static_cast<const MapNode*>(node_.get())->data.size();
}
/*! \return The number of elements of the key */
inline size_t count(const K& key) const {
if (node_.get() == nullptr) return 0;
return static_cast<const MapNode*>(node_.get())->data.count(key.node_);
}
/*!
* \brief copy on write semantics
* Do nothing if current handle is the unique copy of the array.
* Otherwise make a new copy of the array to ensure the current handle
* hold a unique copy.
*
* \return Handle to the internal node container(which ganrantees to be unique)
*/
inline MapNode* CopyOnWrite() {
if (node_.get() == nullptr || !node_.unique()) {
NodePtr<MapNode> n = make_node<MapNode>();
n->data = static_cast<const MapNode*>(node_.get())->data;
NodePtr<Node>(std::move(n)).swap(node_);
}
return static_cast<MapNode*>(node_.get());
}
/*!
* \brief set the Map.
* \param key The index key.
* \param value The value to be setted.
*/
inline void Set(const K& key, const V& value) {
MapNode* n = this->CopyOnWrite();
n->data[key.node_] = value.node_;
}
/*! \return whether array is empty */
inline bool empty() const {
return size() == 0;
}
/*! \brief specify container node */
using ContainerType = MapNode;
struct Ptr2NodeRef {
using ResultType = std::pair<K, V>;
static inline ResultType convert(const std::pair<
NodePtr<Node>,
NodePtr<Node> >& n) {
return std::make_pair(K(n.first), V(n.second));
}
};
using iterator = IterAdapter<
Ptr2NodeRef, MapNode::ContainerType::const_iterator>;
/*! \return begin iterator */
inline iterator begin() const {
return iterator(static_cast<const MapNode*>(node_.get())->data.begin());
}
/*! \return end iterator */
inline iterator end() const {
return iterator(static_cast<const MapNode*>(node_.get())->data.end());
}
/*! \return begin iterator */
inline iterator find(const K& key) const {
return iterator(static_cast<const MapNode*>(node_.get())->data.find(key.node_));
}
};
// specialize of string map
template<typename V, typename T1, typename T2>
class Map<std::string, V, T1, T2> : public NodeRef {
public:
// for code reuse
Map() {
node_ = make_node<StrMapNode>();
}
Map(Map<std::string, V> && other) { // NOLINT(*)
node_ = std::move(other.node_);
}
Map(const Map<std::string, V> &other) : NodeRef(other.node_) { // NOLINT(*)
}
explicit Map(NodePtr<Node> n) : NodeRef(n) {}
template<typename IterType>
Map(IterType begin, IterType end) {
assign(begin, end);
}
Map(std::initializer_list<std::pair<std::string, V> > init) { // NOLINT(*)
assign(init.begin(), init.end());
}
template<typename Hash, typename Equal>
Map(const std::unordered_map<std::string, V, Hash, Equal>& init) { // NOLINT(*)
assign(init.begin(), init.end());
}
Map<std::string, V>& operator=(Map<std::string, V> && other) {
node_ = std::move(other.node_);
return *this;
}
Map<std::string, V>& operator=(const Map<std::string, V> & other) {
node_ = other.node_;
return *this;
}
template<typename IterType>
void assign(IterType begin, IterType end) {
auto n = make_node<StrMapNode>();
for (IterType i = begin; i != end; ++i) {
n->data.emplace(std::make_pair(i->first,
i->second.node_));
}
node_ = std::move(n);
}
inline const V operator[](const std::string& key) const {
return V(static_cast<const StrMapNode*>(node_.get())->data.at(key));
}
inline const V at(const std::string& key) const {
return V(static_cast<const StrMapNode*>(node_.get())->data.at(key));
}
inline size_t size() const {
if (node_.get() == nullptr) return 0;
return static_cast<const StrMapNode*>(node_.get())->data.size();
}
inline size_t count(const std::string& key) const {
if (node_.get() == nullptr) return 0;
return static_cast<const StrMapNode*>(node_.get())->data.count(key);
}
inline StrMapNode* CopyOnWrite() {
if (node_.get() == nullptr || !node_.unique()) {
NodePtr<StrMapNode> n = make_node<StrMapNode>();
n->data = static_cast<const StrMapNode*>(node_.get())->data;
NodePtr<Node>(std::move(n)).swap(node_);
}
return static_cast<StrMapNode*>(node_.get());
}
inline void Set(const std::string& key, const V& value) {
StrMapNode* n = this->CopyOnWrite();
n->data[key] = value.node_;
}
inline bool empty() const {
return size() == 0;
}
using ContainerType = StrMapNode;
struct Ptr2NodeRef {
using ResultType = std::pair<std::string, V>;
static inline ResultType convert(const std::pair<
std::string,
NodePtr<Node> >& n) {
return std::make_pair(n.first, V(n.second));
}
};
using iterator = IterAdapter<
Ptr2NodeRef, StrMapNode::ContainerType::const_iterator>;
/*! \return begin iterator */
inline iterator begin() const {
return iterator(static_cast<const StrMapNode*>(node_.get())->data.begin());
}
/*! \return end iterator */
inline iterator end() const {
return iterator(static_cast<const StrMapNode*>(node_.get())->data.end());
}
/*! \return begin iterator */
inline iterator find(const std::string& key) const {
return iterator(static_cast<const StrMapNode*>(node_.get())->data.find(key));
}
};
} // namespace tvm
#endif // TVM_NODE_CONTAINER_H_
/*
* Licensed to the Apache Software Foundation (ASF) under one
* or more contributor license agreements. See the NOTICE file
* distributed with this work for additional information
* regarding copyright ownership. The ASF licenses this file
* to you under the Apache License, Version 2.0 (the
* "License"); you may not use this file except in compliance
* with the License. You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing,
* software distributed under the License is distributed on an
* "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
* KIND, either express or implied. See the License for the
* specific language governing permissions and limitations
* under the License.
*/
/*!
* \file tvm/node/ir_functor.h
* \brief Defines the IRFunctor data structures.
*/
#ifndef TVM_NODE_IR_FUNCTOR_H_
#define TVM_NODE_IR_FUNCTOR_H_
#include <dmlc/logging.h>
#include <string>
#include <vector>
#include <memory>
#include <type_traits>
#include <utility>
#include <functional>
#include "node.h"
namespace tvm {
/*!
* \brief A dynamically dispatched functor on NodeRef in the first argument.
*
* \code
* IRFunctor<std::string (const NodeRef& n, std::string prefix)> tostr;
* tostr.set_dispatch<Add>([](const Add* op, std::string prefix) {
* return prefix + "Add";
* });
* tostr.set_dispatch<IntImm>([](const IntImm* op) {
* return prefix + "IntImm"
* });
*
* Expr x = make_const(1);
* Expr y = x + x;
* // dispatch to IntImm, outputs "MyIntImm"
* LOG(INFO) << tostr(x, "My");
* // dispatch to IntImm, outputs "MyAdd"
* LOG(INFO) << tostr(y, "My");
* \endcode
*
* \tparam FType function signiture
* This type if only defined for FType with function signature
*/
template<typename FType>
class IRFunctor;
template<typename R, typename ...Args>
class IRFunctor<R(const NodeRef& n, Args...)> {
private:
using Function = std::function<R (const NodeRef&n, Args...)>;
using TSelf = IRFunctor<R (const NodeRef& n, Args...)>;
/*! \brief internal function table */
std::vector<Function> func_;
public:
/*! \brief the result type of this functor */
using result_type = R;
/*!
* \brief Whether the functor can dispatch the corresponding Node
* \param n The node to be dispatched
* \return Whether dispatching function is registered for n's type.
*/
inline bool can_dispatch(const NodeRef& n) const {
uint32_t type_index = n.type_index();
return type_index < func_.size() && func_[type_index] != nullptr;
}
/*!
* \brief invoke the functor , dispatch on type of n
* \param n The Node argument
* \param args The additional arguments
* \return The result.
*/
inline R operator()(const NodeRef& n, Args... args) const {
uint32_t type_index = n.type_index();
CHECK(type_index < func_.size() &&
func_[type_index] != nullptr)
<< "IRFunctor calls un-registered function on type "
<< Node::TypeIndex2Key(type_index);
return func_[type_index](n, std::forward<Args>(args)...);
}
/*!
* \brief set the dispacher for type TNode
* \param f The function to be set.
* \tparam TNode the type of Node to be dispatched.
* \return reference to self.
*/
template<typename TNode>
inline TSelf& set_dispatch(Function f) { // NOLINT(*)
uint32_t tindex = Node::TypeKey2Index(TNode::_type_key);
if (func_.size() <= tindex) {
func_.resize(tindex + 1, nullptr);
}
CHECK(func_[tindex] == nullptr)
<< "Dispatch for " << Node::TypeIndex2Key(tindex)
<< " is already set";
func_[tindex] = f;
return *this;
}
/*!
* \brief set the dispacher for type TNode
* This allows f to used detailed const Node pointer to replace NodeRef
*
* \param f The function to be set.
* \tparam TNode the type of Node to be dispatched.
* \return reference to self.
*/
template<typename TNode>
inline TSelf& set_dispatch(std::function<R(const TNode* n, Args...)> f) { // NOLINT(*)
Function fun = [f](const NodeRef& n, Args... args) {
return f(static_cast<const TNode*>(n.node_.get()),
std::forward<Args>(args)...);
};
return this->set_dispatch<TNode>(fun);
}
/*!
* \brief unset the dispacher for type TNode
*
* \tparam TNode the type of Node to be dispatched.
* \return reference to self.
*/
template<typename TNode>
inline TSelf& clear_dispatch() { // NOLINT(*)
uint32_t tindex = Node::TypeKey2Index(TNode::_type_key);
CHECK_LT(tindex, func_.size()) << "clear_dispatch: index out of range";
func_[tindex] = nullptr;
return *this;
}
};
#if defined(__GNUC__)
#define TVM_ATTRIBUTE_UNUSED __attribute__((unused))
#else
#define TVM_ATTRIBUTE_UNUSED
#endif
/*! \brief helper macro to generate string concat */
#define TVM_STR_CONCAT_(__x, __y) __x##__y
#define TVM_STR_CONCAT(__x, __y) TVM_STR_CONCAT_(__x, __y)
#define TVM_REGISTER_VAR_DEF(ClsName) \
static TVM_ATTRIBUTE_UNUSED auto & __make_functor ## _ ## ClsName
/*!
* \brief Useful macro to set IRFunctor dispatch in a global static field.
*
* \code
* // Use IRFunctor to implement IRPrinter similar to Visitor Pattern.
* // vtable allows easy patch in of new Node types, without changing
* // interface of IRPrinter.
*
* class IRPrinter {
* public:
* std::ostream& stream;
* // the dispatch function.
* void print(Expr e) {
* const static FType& f = *vtable();
* f(e, this);
* }
*
* using FType = IRFunctor<void (const NodeRef&, IRPrinter *)>;
* // function to return global function table
* static FType& vtable();
* };
*
* // in cpp/cc file
* IRPrinter::FType& IRPrinter::vtable() { // NOLINT(*)
* static FType inst; return inst;
* }
*
* TVM_STATIC_IR_FUNCTOR(IRPrinter, vtable)
* .set_dispatch<Add>([](const Add* n, IRPrinter* p) {
* p->print(n->a);
* p->stream << '+'
* p->print(n->b);
* });
*
*
* \endcode
*
* \param ClsName The name of the class
* \param FField The static function that returns a singleton of IRFunctor.
*/
#define TVM_STATIC_IR_FUNCTOR(ClsName, FField) \
TVM_STR_CONCAT(TVM_REGISTER_VAR_DEF(ClsName), __COUNTER__) = \
ClsName::FField()
/*!
* \brief A container for a list of callbacks. All callbacks are invoked when
* the object is destructed.
*/
class IRFunctorCleanList {
public:
~IRFunctorCleanList() {
for (auto &f : clean_items) {
f();
}
}
void append(std::function<void()> func) {
clean_items.push_back(func);
}
private:
std::vector< std::function<void()> > clean_items;
};
/*!
* \brief A wrapper around IRFunctor that will record calls to set_dispatch
* and make a corresponding call to clear_dispatch when the last copy of
* the IRFunctorStaticRegistry is destructed. When assigned to a static variable,
* this can be used by NNVM and other libraries to unregister callbacks when
* the library is unloaded. This prevents crashes when the underlying IRFunctor
* is destructed as it will no longer contain std::function instances allocated
* by a library that has been unloaded.
*/
template<typename FType>
class IRFunctorStaticRegistry;
template<typename R, typename ...Args>
class IRFunctorStaticRegistry<R(const NodeRef& n, Args...)> {
private:
IRFunctor<R(const NodeRef& n, Args...)> *irf_;
std::shared_ptr<IRFunctorCleanList> free_list;
using TSelf = IRFunctorStaticRegistry<R(const NodeRef& n, Args...)>;
public:
IRFunctorStaticRegistry(IRFunctor<R(const NodeRef& n, Args...)> *irf) {
irf_ = irf;
free_list = std::make_shared<IRFunctorCleanList>();
}
template<typename TNode>
inline TSelf& set_dispatch(std::function<R(const TNode* n, Args...)> f) { // NOLINT(*)
irf_->template set_dispatch<TNode>(f);
auto irf_copy = irf_;
free_list.get()->append([irf_copy] {
irf_copy->template clear_dispatch<TNode>();
});
return *this;
}
};
/*!
* \brief Helper function for constructing an IRFunctorStaticRegistry. This allows
* the compiler to deduce the template types.
*/
template<typename R, typename ...Args>
IRFunctorStaticRegistry<R(const NodeRef& n, Args...)> MakeIRFunctorStaticRegistry(
IRFunctor<R(const NodeRef& n, Args...)> *irf) {
return IRFunctorStaticRegistry<R(const NodeRef& n, Args...)>(irf);
}
#define TVM_AUTO_REGISTER_VAR_DEF(ClsName) \
static TVM_ATTRIBUTE_UNUSED auto __make_functor ## _ ## ClsName
/*!
* \brief Macro to set IRFunctor dispatch in a global static field using an IRFunctorStaticRegistry.
* Usage is exactly the same as TVM_STATIC_IR_FUNCTOR. Libraries should use this instead of
* TVM_STATIC_IR_FUNCTOR.
*/
#define TVM_STATIC_IR_FUNCTOR_REGISTER(ClsName, FField) \
TVM_STR_CONCAT(TVM_AUTO_REGISTER_VAR_DEF(ClsName), __COUNTER__) = \
MakeIRFunctorStaticRegistry(&ClsName::FField())
} // namespace tvm
#endif // TVM_NODE_IR_FUNCTOR_H_
/*
* Licensed to the Apache Software Foundation (ASF) under one
* or more contributor license agreements. See the NOTICE file
* distributed with this work for additional information
* regarding copyright ownership. The ASF licenses this file
* to you under the Apache License, Version 2.0 (the
* "License"); you may not use this file except in compliance
* with the License. You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing,
* software distributed under the License is distributed on an
* "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
* KIND, either express or implied. See the License for the
* specific language governing permissions and limitations
* under the License.
*/
/*!
* \file tvm/node/memory.h
* \brief Node memory management.
*/
#ifndef TVM_NODE_MEMORY_H_
#define TVM_NODE_MEMORY_H_
#include <utility>
#include "node.h"
namespace tvm {
/*!
* \brief Allocate a node object.
* \param args arguments to the constructor.
* \tparam T the node type.
* \return The NodePtr to the allocated object.
*/
template<typename T, typename... Args>
inline NodePtr<T> make_node(Args&&... args);
// Detail implementations after this
//
// The current design allows swapping the
// allocator pattern when necessary.
//
// Possible future allocator optimizations:
// - Arena allocator that gives ownership of memory to arena (deleter_= nullptr)
// - Thread-local object pools: one pool per size and alignment requirement.
// - Can specialize by type of object to give the specific allocator to each object.
//
template<typename T>
class SimpleNodeAllocator {
public:
template<typename... Args>
static T* New(Args&&... args) {
return new T(std::forward<Args>(args)...);
}
static NodeBase::FDeleter Deleter() {
return Deleter_;
}
private:
static void Deleter_(NodeBase* ptr) {
delete static_cast<T*>(ptr);
}
};
template<typename T, typename... Args>
inline NodePtr<T> make_node(Args&&... args) {
using Allocator = SimpleNodeAllocator<T>;
static_assert(std::is_base_of<NodeBase, T>::value,
"make_node can only be used to create NodeBase");
T* node = Allocator::New(std::forward<Args>(args)...);
node->deleter_ = Allocator::Deleter();
return NodePtr<T>(node);
}
} // namespace tvm
#endif // TVM_NODE_MEMORY_H_
/*
* Licensed to the Apache Software Foundation (ASF) under one
* or more contributor license agreements. See the NOTICE file
* distributed with this work for additional information
* regarding copyright ownership. The ASF licenses this file
* to you under the Apache License, Version 2.0 (the
* "License"); you may not use this file except in compliance
* with the License. You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing,
* software distributed under the License is distributed on an
* "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
* KIND, either express or implied. See the License for the
* specific language governing permissions and limitations
* under the License.
*/
/*!
* \file tvm/node/node.h
* \brief Node system data structure.
*/
#ifndef TVM_NODE_NODE_H_
#define TVM_NODE_NODE_H_
#include <dmlc/logging.h>
#include <tvm/runtime/c_runtime_api.h>
#include <tvm/runtime/node_base.h>
#include <string>
#include <vector>
#include <utility>
#include <type_traits>
namespace tvm {
// forward declaration
class DataType;
class Node;
class NodeRef;
namespace runtime {
// forward declaration
class NDArray;
// forward declaration
class Object;
} // namespace runtime
/*!
* \brief Visitor class to each node content.
* The content is going to be called for each field.
*/
class TVM_DLL AttrVisitor {
public:
//! \cond Doxygen_Suppress
virtual ~AttrVisitor() = default;
virtual void Visit(const char* key, double* value) = 0;
virtual void Visit(const char* key, int64_t* value) = 0;
virtual void Visit(const char* key, uint64_t* value) = 0;
virtual void Visit(const char* key, int* value) = 0;
virtual void Visit(const char* key, bool* value) = 0;
virtual void Visit(const char* key, std::string* value) = 0;
virtual void Visit(const char* key, void** value) = 0;
virtual void Visit(const char* key, DataType* value) = 0;
virtual void Visit(const char* key, NodeRef* value) = 0;
virtual void Visit(const char* key, runtime::NDArray* value) = 0;
virtual void Visit(const char* key, runtime::Object* value) = 0;
template<typename ENum,
typename = typename std::enable_if<std::is_enum<ENum>::value>::type>
void Visit(const char* key, ENum* ptr) {
static_assert(std::is_same<int, typename std::underlying_type<ENum>::type>::value,
"declare enum to be enum int to use visitor");
this->Visit(key, reinterpret_cast<int*>(ptr));
}
//! \endcond
};
/*!
* \brief base class of node container in DSL AST.
*/
class TVM_DLL Node : public NodeBase {
public:
/*! \brief virtual destructor */
virtual ~Node() {}
/*! \return The unique type key of the node */
virtual const char* type_key() const = 0;
/*!
* \brief Apply visitor to each field of the Node
* Visitor could mutate the content of the node.
* override if Node contains attribute fields.
* \param visitor The visitor
*/
virtual void VisitAttrs(AttrVisitor* visitor) {}
/*! \return the type index of the node */
virtual const uint32_t type_index() const = 0;
/*!
* \brief Whether this node derives from node with type_index=tid.
* Implemented by TVM_DECLARE_NODE_TYPE_INFO
*
* \param tid The type index.
* \return the check result.
*/
virtual const bool _DerivedFrom(uint32_t tid) const;
/*!
* \brief get a runtime unique type index given a type key
* \param type_key Type key of a type.
* \return the corresponding type index.
*/
static uint32_t TypeKey2Index(const char* type_key);
/*!
* \brief get type key from type index.
* \param index The type index
* \return the corresponding type key.
*/
static const char* TypeIndex2Key(uint32_t index);
/*!
* \return whether the type is derived from
*/
template<typename T>
inline bool derived_from() const;
/*!
* \return whether the node is of type T
* \tparam The type to be checked.
*/
template<typename T>
inline bool is_type() const;
/*!
* \brief Get a NodePtr that holds reference to this Node.
* \return the NodePtr
*/
inline NodePtr<Node> GetNodePtr() const;
// node ref can see this
friend class NodeRef;
static constexpr const char* _type_key = "Node";
};
/*! \brief Base class of all node reference object */
class NodeRef {
public:
/*! \brief type indicate the container type */
using ContainerType = Node;
/*!
* \brief Comparator
* \param other Another node ref.
* \return the compare result.
*/
inline bool operator==(const NodeRef& other) const;
/*!
* \brief Comparator
* \param other Another node ref.
* \return the compare result.
*/
inline bool same_as(const NodeRef& other) const;
/*!
* \brief Comparator
* \param other Another node ref.
* \return the compare result.
*/
inline bool operator<(const NodeRef& other) const;
/*!
* \brief Comparator
* \param other Another node ref.
* \return the compare result.
*/
inline bool operator!=(const NodeRef& other) const;
/*! \return the hash function for NodeRef */
inline size_t hash() const;
/*! \return whether the expression is null */
inline bool defined() const;
/*! \return the internal type index of IRNode */
inline uint32_t type_index() const;
/*! \return the internal node pointer */
inline const Node* get() const;
/*! \return the internal node pointer */
inline const Node* operator->() const;
/*!
* \brief Downcast this ir node to its actual type (e.g. Add, or
* Select). This returns nullptr if the node is not of the requested
* type. Example usage:
*
* if (const Add *add = node->as<Add>()) {
* // This is an add node
* }
* \tparam T the target type, must be subtype of IRNode
*/
template<typename T>
inline const T *as() const;
/*!
* \brief A more powerful version of as that also works with
* intermediate base types.
* \tparam T the target type, must be subtype of IRNode
*/
template<typename T>
inline const T *as_derived() const;
/*! \brief default constructor */
NodeRef() = default;
explicit NodeRef(NodePtr<Node> node) : node_(node) {}
/*! \brief the internal node object, do not touch */
NodePtr<Node> node_;
};
/*!
* \brief Get a reference type from a Node ptr type
*
* It is always important to get a reference type
* if we want to return a value as reference or keep
* the node alive beyond the scope of the function.
*
* \param ptr The node pointer
* \tparam RefType The reference type
* \tparam NodeType The node type
* \return The corresponding RefType
*/
template <typename RefType, typename NodeType>
inline RefType GetRef(const NodeType* ptr);
/*!
* \brief Downcast a base reference type to a more specific type.
*
* \param ref The inptut reference
* \return The corresponding SubRef.
* \tparam SubRef The target specific reference type.
* \tparam BaseRef the current reference type.
*/
template <typename SubRef, typename BaseRef>
inline SubRef Downcast(BaseRef ref);
/*!
* \brief helper macro to declare type information in a base node.
*/
#define TVM_DECLARE_BASE_NODE_INFO(TypeName, Parent) \
const bool _DerivedFrom(uint32_t tid) const override { \
static uint32_t tidx = TypeKey2Index(TypeName::_type_key); \
if (tidx == tid) return true; \
return Parent::_DerivedFrom(tid); \
}
/*!
* \brief helper macro to declare type information in a terminal node
*/
#define TVM_DECLARE_NODE_TYPE_INFO(TypeName, Parent) \
const char* type_key() const final { \
return TypeName::_type_key; \
} \
const uint32_t type_index() const final { \
static uint32_t tidx = TypeKey2Index(TypeName::_type_key); \
return tidx; \
} \
const bool _DerivedFrom(uint32_t tid) const final { \
static uint32_t tidx = TypeKey2Index(TypeName::_type_key); \
if (tidx == tid) return true; \
return Parent::_DerivedFrom(tid); \
}
// implementations of inline functions after this
template<typename T>
inline bool Node::derived_from() const {
// use static field so query only happens once.
static uint32_t type_id = Node::TypeKey2Index(T::_type_key);
return this->_DerivedFrom(type_id);
}
template<typename T>
inline bool Node::is_type() const {
// use static field so query only happens once.
static uint32_t type_id = Node::TypeKey2Index(T::_type_key);
return type_id == this->type_index();
}
inline NodePtr<Node> Node::GetNodePtr() const {
return NodePtr<Node>(const_cast<Node*>(this));
}
template <typename RefType, typename NodeType>
inline RefType GetRef(const NodeType* ptr) {
static_assert(std::is_base_of<typename RefType::ContainerType, NodeType>::value,
"Can only cast to the ref of same container type");
return RefType(ptr->GetNodePtr());
}
template <typename SubRef, typename BaseRef>
inline SubRef Downcast(BaseRef ref) {
CHECK(ref->template is_type<typename SubRef::ContainerType>() ||
ref->template derived_from<typename SubRef::ContainerType>())
<< "Downcast from " << ref->type_key() << " to "
<< SubRef::ContainerType::_type_key << " failed.";
return SubRef(std::move(ref.node_));
}
inline const Node* NodeRef::get() const {
return node_.get();
}
inline const Node* NodeRef::operator->() const {
return node_.get();
}
inline bool NodeRef::defined() const {
return node_.get() != nullptr;
}
inline bool NodeRef::operator==(const NodeRef& other) const {
return node_.get() == other.node_.get();
}
inline bool NodeRef::same_as(const NodeRef& other) const {
return node_.get() == other.node_.get();
}
inline bool NodeRef::operator<(const NodeRef& other) const {
return node_.get() < other.node_.get();
}
inline bool NodeRef::operator!=(const NodeRef& other) const {
return node_.get() != other.node_.get();
}
inline size_t NodeRef::hash() const {
return std::hash<Node*>()(node_.get());
}
inline uint32_t NodeRef::type_index() const {
CHECK(node_.get() != nullptr)
<< "null type";
return get()->type_index();
}
template<typename T>
inline const T* NodeRef::as() const {
const Node* ptr = static_cast<const Node*>(get());
if (ptr && ptr->is_type<T>()) {
return static_cast<const T*>(ptr);
}
return nullptr;
}
template<typename T>
inline const T* NodeRef::as_derived() const {
const Node* ptr = static_cast<const Node*>(get());
if (ptr && (ptr->is_type<T>() || ptr->derived_from<T>())) {
return static_cast<const T*>(ptr);
}
return nullptr;
}
/*! \brief The hash function for nodes */
struct NodeHash {
size_t operator()(const NodeRef& a) const {
return a.hash();
}
};
/*! \brief The equal comparator for nodes */
struct NodeEqual {
bool operator()(const NodeRef& a, const NodeRef& b) const {
return a.get() == b.get();
}
};
} // namespace tvm
#endif // TVM_NODE_NODE_H_
......@@ -53,7 +53,7 @@ struct TensorDom {
/*!
* \brief Base class of all operation nodes
*/
class OperationNode : public FunctionBaseNode {
class OperationNode : public ir::FunctionBaseNode {
public:
/*! \brief optional name of the operation */
std::string name;
......@@ -463,7 +463,7 @@ class ExternOpNode : public OperationNode {
v->Visit("output_placeholders", &output_placeholders);
v->Visit("body", &body);
}
EXPORT static Operation make(std::string name,
TVM_DLL static Operation make(std::string name,
std::string tag,
Map<std::string, NodeRef> attrs,
Array<Tensor> inputs,
......@@ -530,7 +530,7 @@ class HybridOpNode : public OperationNode {
v->Visit("axis", &axis);
v->Visit("body", &body);
}
EXPORT static Operation make(std::string name,
TVM_DLL static Operation make(std::string name,
std::string tag,
Map<std::string, NodeRef> attrs,
Array<Tensor> inputs,
......
......@@ -70,7 +70,9 @@ struct NodeTypeChecker<Array<T> > {
if (!sptr->is_type<ArrayNode>()) return false;
ArrayNode* n = static_cast<ArrayNode*>(sptr);
for (const auto& p : n->data) {
if (!NodeTypeChecker<T>::Check(p.get())) return false;
if (!NodeTypeChecker<T>::Check(p.get())) {
return false;
}
}
return true;
}
......@@ -144,7 +146,7 @@ inline TNodeRef TVMArgValue::AsNodeRef() const {
return TNodeRef(sptr);
}
inline TVMArgValue::operator HalideIR::Expr() const {
inline TVMArgValue::operator tvm::Expr() const {
if (type_code_ == kNull) return Expr();
if (type_code_ == kDLInt) {
CHECK_LE(value_.v_int64, std::numeric_limits<int>::max());
......@@ -240,21 +242,21 @@ inline void TVMArgsSetter::operator()(size_t i, const NodeRef& other) const { /
}
// type related stuffs
inline TVMRetValue& TVMRetValue::operator=(const HalideIR::Type& t) {
return this->operator=(Type2TVMType(t));
inline TVMRetValue& TVMRetValue::operator=(const DataType& t) {
return this->operator=(t.operator DLDataType());
}
inline TVMRetValue::operator HalideIR::Type() const {
return TVMType2Type(operator TVMType());
inline TVMRetValue::operator tvm::DataType() const {
return DataType(operator DLDataType());
}
inline TVMArgValue::operator HalideIR::Type() const {
return TVMType2Type(operator TVMType());
inline TVMArgValue::operator tvm::DataType() const {
return DataType(operator DLDataType());
}
inline void TVMArgsSetter::operator()(
size_t i, const HalideIR::Type& t) const {
this->operator()(i, Type2TVMType(t));
size_t i, const DataType& t) const {
this->operator()(i, t.operator DLDataType());
}
} // namespace runtime
} // namespace tvm
......
......@@ -42,14 +42,6 @@
#include "object.h"
#include "node_base.h"
namespace HalideIR {
// Forward declare type for extensions
// The header works fine without depending on this.
struct Type;
struct Expr;
}
// Whether use TVM runtime in header only mode.
#ifndef TVM_RUNTIME_HEADER_ONLY
#define TVM_RUNTIME_HEADER_ONLY 0
......@@ -58,6 +50,8 @@ struct Expr;
namespace tvm {
// forward declarations
class Integer;
class DataType;
class Expr;
namespace runtime {
......@@ -626,8 +620,8 @@ class TVMArgValue : public TVMPODValue_ {
typename = typename std::enable_if<
std::is_class<TNodeRef>::value>::type>
inline bool IsNodeType() const;
inline operator HalideIR::Type() const;
inline operator HalideIR::Expr() const;
inline operator tvm::DataType() const;
inline operator tvm::Expr() const;
inline operator tvm::Integer() const;
// get internal node ptr, if it is node
inline NodePtr<Node>& node_sptr();
......@@ -835,8 +829,8 @@ class TVMRetValue : public TVMPODValue_ {
inline TVMRetValue& operator=(const NodeRef& other);
inline TVMRetValue& operator=(const NodePtr<Node>& other);
// type related
inline operator HalideIR::Type() const;
inline TVMRetValue& operator=(const HalideIR::Type& other);
inline operator tvm::DataType() const;
inline TVMRetValue& operator=(const tvm::DataType& other);
private:
template<typename T>
......@@ -1184,7 +1178,7 @@ class TVMArgsSetter {
inline void operator()(size_t i, const T& value) const;
// NodeRef related extenstions: in tvm/packed_func_ext.h
inline void operator()(size_t i, const NodeRef& other) const; // NOLINT(*)
inline void operator()(size_t i, const HalideIR::Type& t) const;
inline void operator()(size_t i, const tvm::DataType& t) const;
private:
/*! \brief The values fields */
......
......@@ -75,24 +75,24 @@ class Stage : public NodeRef {
* \brief set the memory scope of the stage
* \param scope The memory scope.
*/
EXPORT Stage& set_scope(std::string scope); // NOLINT(*)
TVM_DLL Stage& set_scope(std::string scope); // NOLINT(*)
/*!
* \brief specify the schedule to be computed at the parent schedule's scope.
* \param parent The parent schedule.
* \param scope The iteration point to carry the schedule.
* \return reference to self.
*/
EXPORT Stage& compute_at(Stage parent, IterVar scope); // NOLINT(*)
TVM_DLL Stage& compute_at(Stage parent, IterVar scope); // NOLINT(*)
/*!
* \brief Compute the function inline.
* \return reference to self.
*/
EXPORT Stage& compute_inline(); // NOLINT(*)
TVM_DLL Stage& compute_inline(); // NOLINT(*)
/*!
* \brief Compute the function at group root.
* \return reference to self.
*/
EXPORT Stage& compute_root(); // NOLINT(*)
TVM_DLL Stage& compute_root(); // NOLINT(*)
/*!
* \brief Bind the IterVar to thread index.
*
......@@ -100,7 +100,7 @@ class Stage : public NodeRef {
* \param thread_ivar The thread axis to be bound.
* \return reference to self.
*/
EXPORT Stage& bind(IterVar ivar, IterVar thread_ivar);
TVM_DLL Stage& bind(IterVar ivar, IterVar thread_ivar);
/*!
* \brief Set the predicate to determine whether a store to the array should be performed.
* Use this when there are multiple threads performing the same store and we only
......@@ -111,7 +111,7 @@ class Stage : public NodeRef {
* \param predicate The condition to be checked.
* \return reference to self.
*/
EXPORT Stage& set_store_predicate(Expr predicate);
TVM_DLL Stage& set_store_predicate(Expr predicate);
/*!
* \brief Specify environment threads that launched around the group's scope.
* This can only be used in group stage.
......@@ -120,7 +120,7 @@ class Stage : public NodeRef {
* This is a beta feature.
* \return reference to self.
*/
EXPORT Stage& env_threads(Array<IterVar> threads);
TVM_DLL Stage& env_threads(Array<IterVar> threads);
/*!
* \brief Split the parent by factor, generate
* \param parent The parent iteration domain.
......@@ -129,7 +129,7 @@ class Stage : public NodeRef {
* \param p_inner The result inner domain.
* \return reference to self.
*/
EXPORT Stage& split(IterVar parent, Expr factor, IterVar* p_outer, IterVar* p_inner); // NOLINT(*)
TVM_DLL Stage& split(IterVar parent, Expr factor, IterVar* p_outer, IterVar* p_inner); // NOLINT(*)
/*!
* \brief Split the iteration with given number of parts.
*
......@@ -139,7 +139,7 @@ class Stage : public NodeRef {
* \param p_inner The result inner domain.
* \return reference to self.
*/
EXPORT Stage& split_by_nparts(IterVar parent, Expr nparts, IterVar* p_outer, IterVar* p_inner); // NOLINT(*)
TVM_DLL Stage& split_by_nparts(IterVar parent, Expr nparts, IterVar* p_outer, IterVar* p_inner); // NOLINT(*)
/*!
* \brief Fuse the inner outer domain to the target
* \param outer The outer domain to be fused.
......@@ -147,7 +147,7 @@ class Stage : public NodeRef {
* \param p_target The result target domain.
* \return reference to self.
*/
EXPORT Stage& fuse(IterVar outer, IterVar inner, IterVar* p_target); // NOLINT(*)
TVM_DLL Stage& fuse(IterVar outer, IterVar inner, IterVar* p_target); // NOLINT(*)
/*!
* \brief Fuse all the axes together into a single axis.
*
......@@ -161,13 +161,13 @@ class Stage : public NodeRef {
*
* \return reference to self.
*/
EXPORT Stage& fuse(const Array<IterVar>& axes, IterVar* p_target); // NOLINT(*)
TVM_DLL Stage& fuse(const Array<IterVar>& axes, IterVar* p_target); // NOLINT(*)
/*!
* \brief Reorder the iteration
* \param order The order of iteration variable.
* \return reference to self.
*/
EXPORT Stage& reorder(const Array<IterVar>& order); // NOLINT(*)
TVM_DLL Stage& reorder(const Array<IterVar>& order); // NOLINT(*)
/*!
* \brief Perform tiling on two dimensions
* The final loop order from outmost to inner most are
......@@ -183,7 +183,7 @@ class Stage : public NodeRef {
* \param p_y_inner Inner axis of y dimension
* \return reference to self.
*/
EXPORT Stage& tile(IterVar x_parent, IterVar y_parent, // NOLINT(*)
TVM_DLL Stage& tile(IterVar x_parent, IterVar y_parent, // NOLINT(*)
Expr x_factor, Expr y_factor,
IterVar* p_x_outer, IterVar* p_y_outer,
IterVar* p_x_inner, IterVar* p_y_inner);
......@@ -192,7 +192,7 @@ class Stage : public NodeRef {
* \param var The axis to be vectorized.
* \return reference to self.
*/
EXPORT Stage& vectorize(IterVar var); // NOLINT(*)
TVM_DLL Stage& vectorize(IterVar var); // NOLINT(*)
/*!
* \brief Replace computation of the current stage by tensor intrinsic f.
* \param var The axis marks beginning of tensorization.
......@@ -200,19 +200,19 @@ class Stage : public NodeRef {
* \param f The Tensor compute intrinsics.
* \return reference to self.
*/
EXPORT Stage& tensorize(IterVar var, TensorIntrin f); // NOLINT(*)
TVM_DLL Stage& tensorize(IterVar var, TensorIntrin f); // NOLINT(*)
/*!
* \brief Unroll iteration.
* \param var The axis to be unrolled.
* \return reference to self.
*/
EXPORT Stage& unroll(IterVar var); // NOLINT(*)
TVM_DLL Stage& unroll(IterVar var); // NOLINT(*)
/*!
* \brief Parallelize iteration.
* \param var The axis to be parallelized.
* \return reference to self.
*/
EXPORT Stage& parallel(IterVar var); // NOLINT(*)
TVM_DLL Stage& parallel(IterVar var); // NOLINT(*)
/*!
* \brief Annotate the iteration with pragma
*
......@@ -222,7 +222,7 @@ class Stage : public NodeRef {
*
* \return reference to self.
*/
EXPORT Stage& pragma(IterVar var,
TVM_DLL Stage& pragma(IterVar var,
const std::string& pragma_type,
const Expr& pragma_value = Expr()); // NOLINT(*)
/*!
......@@ -232,7 +232,7 @@ class Stage : public NodeRef {
* \param offset the number of iterations be to fetched in advance
* \return reference to self
*/
EXPORT Stage& prefetch(const Tensor &domain, IterVar var, Expr offset); //NOLINT(*)
TVM_DLL Stage& prefetch(const Tensor &domain, IterVar var, Expr offset); //NOLINT(*)
/*!
* \brief Set alignment requirement for specific dimension.
*
......@@ -243,12 +243,12 @@ class Stage : public NodeRef {
* \param offset The required offset factor.
* \return reference to self
*/
EXPORT Stage& storage_align(IterVar axis, int factor, int offset); //NOLINT(*)
TVM_DLL Stage& storage_align(IterVar axis, int factor, int offset); //NOLINT(*)
/*!
* \brief Compute current stage with double buffering.
* \return reference to self.
*/
EXPORT Stage& double_buffer(); // NOLINT(*)
TVM_DLL Stage& double_buffer(); // NOLINT(*)
/*!
* \brief Schedule for OpenGL fragment shader.
* \return reference to self.
......@@ -289,13 +289,13 @@ class Schedule : public NodeRef {
* \brief Get the stage corresponds to the op
* \param op The operation.
*/
EXPORT Stage operator[](const Operation& op);
TVM_DLL Stage operator[](const Operation& op);
/*!
* \brief Short hand for getting the stage of tensor's operation.
* \param tensor The tensor
* \return The stage corresponding to the tensor's op
*/
EXPORT Stage operator[](const Tensor& tensor) {
TVM_DLL Stage operator[](const Tensor& tensor) {
return this->operator[](tensor->op);
}
/*!
......@@ -307,7 +307,7 @@ class Schedule : public NodeRef {
* \param include_inputs Whether include inputs if they are reachable from outputs.
* \return The new grouped stage.
*/
EXPORT Stage create_group(const Array<Tensor>& outputs,
TVM_DLL Stage create_group(const Array<Tensor>& outputs,
const Array<Tensor>& inputs,
bool include_inputs = false);
/*!
......@@ -319,7 +319,7 @@ class Schedule : public NodeRef {
* \param readers The readers to redirect to the tensor.
* \return The created tensor.
*/
EXPORT Tensor cache_read(const Tensor& tensor,
TVM_DLL Tensor cache_read(const Tensor& tensor,
const std::string& scope,
const Array<Operation>& readers);
/*!
......@@ -338,7 +338,7 @@ class Schedule : public NodeRef {
* \param scope The scope of the storage.
* \return The created tensor.
*/
EXPORT Array<Tensor> cache_write(const Array<Tensor>& tensor, const std::string& scope);
TVM_DLL Array<Tensor> cache_write(const Array<Tensor>& tensor, const std::string& scope);
/*!
* \brief Create a cache write tensor for producing tensor.
* The the tensor will take over body of original tensor op.
......@@ -355,7 +355,7 @@ class Schedule : public NodeRef {
* \param scope The scope of the storage.
* \return The created tensor.
*/
EXPORT Tensor cache_write(const Tensor& tensor, const std::string& scope);
TVM_DLL Tensor cache_write(const Tensor& tensor, const std::string& scope);
/*!
* \brief Factor a reduction axis in tensor's schedule to be an explicit axis.
* This will create a new stage that generated the new tensor with axis
......@@ -369,7 +369,7 @@ class Schedule : public NodeRef {
* \param factor_axis The position where the new axis is placed.
* \return The created factored tensors.
*/
EXPORT Array<Tensor> rfactor(const Tensor& tensor,
TVM_DLL Array<Tensor> rfactor(const Tensor& tensor,
const IterVar& axis,
int factor_axis = 0);
/*!
......@@ -556,14 +556,14 @@ class ScheduleNode : public Node {
* \param op The candidate Operation.
* \return true if the schedule has the Operation. Otherwise, false.
*/
EXPORT bool Contain(const Operation& op) const;
TVM_DLL bool Contain(const Operation& op) const;
/*!
* \brief Check if the schedule contains a Tensor.
* \param tensor The candidate tensor.
* \return true if the schedule has the tensor. Otherwise, false.
*/
EXPORT bool Contain(const Tensor& tensor) const {
TVM_DLL bool Contain(const Tensor& tensor) const {
return Contain(tensor->op);
}
......@@ -572,7 +572,7 @@ class ScheduleNode : public Node {
* \param ops The ops to be scheduled.
* \return sch The created Schedule.
*/
EXPORT static Schedule make(Array<Operation> ops);
TVM_DLL static Schedule make(Array<Operation> ops);
static constexpr const char* _type_key = "Schedule";
TVM_DECLARE_NODE_TYPE_INFO(ScheduleNode, Node);
......
......@@ -70,7 +70,7 @@ void AutoInlineElemWise(Schedule sch);
*
* \param sch The schedule to be inlined.
*/
EXPORT void AutoInlineInjective(Schedule sch);
TVM_DLL void AutoInlineInjective(Schedule sch);
} // namespace schedule
} // namespace tvm
......
......@@ -24,7 +24,6 @@
#ifndef TVM_TENSOR_H_
#define TVM_TENSOR_H_
#include <ir/FunctionBase.h>
#include <tvm/node/container.h>
#include <string>
#include <vector>
......@@ -43,8 +42,6 @@ class TensorNode;
// internal node container for Operation
class OperationNode;
using HalideIR::IR::FunctionRef;
/*!
* \brief Tensor structure representing a possible input,
* or intermediate computation result.
......@@ -140,7 +137,7 @@ class Tensor : public NodeRef {
};
/*! \brief Operation that produces tensors */
class Operation : public FunctionRef {
class Operation : public ir::FunctionRef {
public:
/*! \brief default constructor */
Operation() {}
......
......@@ -59,13 +59,12 @@ TVM_REGISTER_API("make._range_by_min_extent")
TVM_REGISTER_API("make.For")
.set_body_typed<Stmt(VarExpr, Expr, Expr, int, int, Stmt)>([](
VarExpr loop_var, Expr min, Expr extent,
int for_type, int device_api, Stmt body
) {
int for_type, int device_api, Stmt body) {
return For::make(loop_var,
min,
extent,
static_cast<ForType>(for_type),
static_cast<HalideIR::DeviceAPI>(device_api),
static_cast<DeviceAPI>(device_api),
body);
});
......
......@@ -18,7 +18,6 @@
*/
/*!
* Copyright (c) 2016 by Contributors
* Implementation of API functions related to Higher DSL build.
* \file api_lang.cc
*/
......@@ -36,10 +35,10 @@
namespace tvm {
TVM_REGISTER_API("_min_value")
.set_body_method(&Type::min);
.set_body_method(&DataType::min);
TVM_REGISTER_API("_max_value")
.set_body_method(&Type::max);
.set_body_method(&DataType::max);
TVM_REGISTER_API("_const")
.set_body([](TVMArgs args, TVMRetValue* ret) {
......
......@@ -52,12 +52,6 @@ class CanonicalExprNode : public BaseExprNode {
// overrides
void VisitAttrs(tvm::AttrVisitor* v) final {
}
void accept(HalideIR::Internal::IRVisitor* v, const Expr& e) const final {
LOG(FATAL) << "not supported";
}
IRNodeType type_info() const final {
return IRNodeType::ExtensionExpr;
}
static constexpr const char* _type_key = "arith.CanonicalExpr";
TVM_DECLARE_BASE_NODE_INFO(CanonicalExprNode, BaseExprNode);
......
......@@ -125,7 +125,7 @@ class ConstIntBoundAnalyzer::Impl :
// Override visitor behaviors
Entry VisitExprDefault_(const Node* op) final {
return Everything(
static_cast<const ir::BaseExprNode*>(op)->type);
static_cast<const ExprNode*>(op)->type);
}
Entry VisitExpr(const Expr& expr) final {
......
......@@ -224,15 +224,15 @@ std::string CodeGenOpenGL::GetBufferRef(
void CodeGenOpenGL::PrintType(Type t, std::ostream& os) {
switch (t.code()) {
case halideir_type_int:
case kDLInt:
CHECK_EQ(t.bits(), 32) << "Only support 32-bit int.";
os << "int";
break;
case halideir_type_uint:
case kDLUInt:
CHECK_EQ(t.bits(), 32) << "Only support 32-bit uint.";
os << "uint";
break;
case halideir_type_float:
case kDLFloat:
CHECK_EQ(t.bits(), 32) << "Only support 32-bit float.";
os << "float";
break;
......
......@@ -18,19 +18,94 @@
*/
/*!
* Copyright (c) 2016 by Contributors
* \file expr.cc
*/
#include <tvm/base.h>
#include <tvm/expr.h>
#include <tvm/ir.h>
#include <tvm/expr_operator.h>
#include <ir/IRPrinter.h>
#include <memory>
#include <limits>
namespace tvm {
using HalideIR::IR::RangeNode;
// maximum and min values
Expr DataType::max() const {
using namespace ir;
CHECK_EQ(lanes(), 1);
if (is_int()) {
if (bits() == 64) {
return IntImm::make(*this, std::numeric_limits<int64_t>::max());
} else if (bits() < 64) {
int64_t val = 1;
val = (val << (bits() - 1)) - 1;
return IntImm::make(*this, val);
}
} else if (is_uint()) {
if (bits() == 64) {
return UIntImm::make(*this, std::numeric_limits<uint64_t>::max());
} else if (bits() < 64) {
uint64_t val = 1;
val = (val << static_cast<uint64_t>(bits())) - 1;
return UIntImm::make(*this, val);
}
} else if (is_float()) {
if (bits() == 64) {
return FloatImm::make(*this, std::numeric_limits<double>::max());
} else if (bits() == 32) {
return FloatImm::make(*this, std::numeric_limits<float>::max());
} else if (bits() == 16) {
return FloatImm::make(*this, 65504.0);
}
}
LOG(FATAL) << "Cannot decide max_value for type" << *this;
return Expr();
}
Expr DataType::min() const {
using namespace ir;
CHECK_EQ(lanes(), 1);
if (is_int()) {
if (bits() == 64) {
return IntImm::make(*this, std::numeric_limits<int64_t>::lowest());
} else if (bits() < 64) {
int64_t val = 1;
val = -(val << (bits() - 1));
return IntImm::make(*this, val);
}
} else if (is_uint()) {
return UIntImm::make(*this, 0);
} else if (is_float()) {
if (bits() == 64) {
return FloatImm::make(*this, std::numeric_limits<double>::lowest());
} else if (bits() == 32) {
return FloatImm::make(*this, std::numeric_limits<float>::lowest());
} else if (bits() == 16) {
return FloatImm::make(*this, -65504.0);
}
}
LOG(FATAL) << "Cannot decide min_value for type" << *this;
return Expr();
}
Expr::Expr(int32_t value)
: Expr(IntImm::make(Int(32), value)) {}
Expr::Expr(float value)
: Expr(ir::FloatImm::make(Float(32), value)) {}
Expr::Expr(std::string str)
: Expr(ir::StringImm::make(str)) {}
Var::Var(std::string name_hint, DataType t)
: Var(Variable::make(t, name_hint)) {}
Var Variable::make(DataType t, std::string name_hint) {
NodePtr<Variable> node = make_node<Variable>();
node->type = t;
node->name_hint = std::move(name_hint);
return Var(node);
}
Range::Range(Expr begin, Expr end)
: Range(make_node<RangeNode>(
......@@ -38,12 +113,23 @@ Range::Range(Expr begin, Expr end)
is_zero(begin) ? end : (end - begin))) {
}
Integer IntImm::make(Type t, int64_t value) {
CHECK(t.is_int() && t.is_scalar())
<< "ValueError: IntImm can only take scalar.";
NodePtr<IntImm> node = make_node<IntImm>();
node->type = t;
node->value = value;
return Integer(node);
}
Range Range::make_by_min_extent(Expr min, Expr extent) {
return Range(make_node<HalideIR::IR::RangeNode>(min, extent));
return Range(make_node<RangeNode>(min, extent));
}
IterVar IterVarNode::make(Range dom, Var var,
IterVarType t, std::string thread_tag) {
IterVar IterVarNode::make(Range dom,
Var var,
IterVarType t,
std::string thread_tag) {
NodePtr<IterVarNode> n = make_node<IterVarNode>();
n->dom = dom;
n->var = var;
......@@ -62,19 +148,48 @@ IterVar reduce_axis(Range dom, std::string name) {
dom, Var(name), kCommReduce);
}
std::ostream& operator<<(std::ostream& os, const NodeRef& n) { // NOLINT(*)
IRPrinter(os).print(n);
return os;
}
void Dump(const NodeRef& n) {
std::cerr << n << "\n";
}
Var var(const std::string& name_hint, Type t) {
Var var(std::string name_hint, Type t) {
return Var(name_hint, t);
}
void IRPrinter::Print(const NodeRef& ir) {
static const FType& f = vtable();
if (!ir.defined()) {
stream << "(nullptr)";
} else {
if (f.can_dispatch(ir)) {
f(ir, this);
} else {
// default value, output type key and addr.
stream << ir->type_key() << "(" << ir.get() << ")";
}
}
}
void IRPrinter::PrintIndent() {
for (int i = 0; i < indent; ++i) {
stream << ' ';
}
}
IRPrinter::FType& IRPrinter::vtable() {
static FType inst;
return inst;
}
TVM_STATIC_IR_FUNCTOR(IRPrinter, vtable)
.set_dispatch<IntImm>([](const IntImm *op, IRPrinter *p) {
if (op->type == Int(32)) {
p->stream << op->value;
} else {
p->stream << "(" << op->type << ")" << op->value;
}
});
TVM_STATIC_IR_FUNCTOR(IRPrinter, vtable)
.set_dispatch<IterVarNode>([](const IterVarNode *op, IRPrinter *p) {
p->stream << "iter_var(";
......@@ -91,11 +206,10 @@ TVM_STATIC_IR_FUNCTOR(IRPrinter, vtable)
});
TVM_STATIC_IR_FUNCTOR(IRPrinter, vtable)
.set_dispatch<RangeNode>([](const HalideIR::IR::RangeNode *op, IRPrinter *p) {
.set_dispatch<RangeNode>([](const RangeNode* op, IRPrinter* p) {
p->stream << "range(min=" << op->min << ", ext=" << op->extent << ')';
});
TVM_REGISTER_NODE_TYPE(ArrayNode);
TVM_REGISTER_NODE_TYPE(MapNode);
TVM_REGISTER_NODE_TYPE(StrMapNode);
......
......@@ -18,65 +18,231 @@
*/
/*!
* Copyright (c) 2016 by Contributors
* \file ir.cc
*/
#include <tvm/base.h>
#include <tvm/expr.h>
#include <tvm/ir.h>
#include <tvm/ir_pass.h>
#include <ir/IR.h>
#include <ir/IRPrinter.h>
#include <memory>
#include "../pass/ir_util.h"
namespace HalideIR {
namespace Internal {
namespace tvm {
namespace ir {
using tvm::ir::CommReducerNode;
using tvm::ir::Reduce;
using tvm::ir::Any;
using tvm::ir::AttrStmt;
// constructors
Expr UIntImm::make(DataType t, uint64_t value) {
CHECK(t.is_uint() && t.lanes() == 1)
<< "ValueError: UIntImm can only take scalar";
NodePtr<UIntImm> node = make_node<UIntImm>();
node->type = t;
node->value = value;
return Expr(node);
}
template<>
void ExprNode<Reduce>::accept(IRVisitor *v, const Expr&) const {
LOG(FATAL) << "Reduce does not work with old Visitor, use IRFunctor style visitor";
Expr FloatImm::make(DataType t, double value) {
CHECK_EQ(t.lanes(), 1)
<< "ValueError: FloatImm can only take scalar";
NodePtr<FloatImm> node = make_node<FloatImm>();
node->type = t;
node->value = value;
return Expr(node);
}
template<>
void ExprNode<Any>::accept(IRVisitor *v, const Expr&) const {
LOG(FATAL) << "Any does not work with old Visitor, use IRFunctor style visitor";
Expr StringImm::make(std::string value) {
NodePtr<StringImm> node = make_node<StringImm>();
node->type = Handle();
node->value = std::move(value);
return Expr(node);
}
TVM_STATIC_IR_FUNCTOR(IRPrinter, vtable)
.set_dispatch<Any>([](const Any *op, IRPrinter *p) {
p->stream << "?";
});
Expr Cast::make(DataType t, Expr value) {
CHECK(value.defined());
CHECK_EQ(t.lanes(), value.type().lanes());
TVM_STATIC_IR_FUNCTOR(IRPrinter, vtable)
.set_dispatch<Reduce>([](const Reduce *op, IRPrinter *p) {
p->stream << "reduce(combiner="
<< op->combiner;
p->stream << ", source=" << op->source;
p->stream << ", axis=" << op->axis;
p->stream << ", where=" << op->condition;
p->stream << ", value_index=" << op->value_index;
p->stream << ")";
});
NodePtr<Cast> node = make_node<Cast>();
node->type = t;
node->value = std::move(value);
return Expr(node);
}
TVM_STATIC_IR_FUNCTOR(IRPrinter, vtable)
.set_dispatch<CommReducerNode>([](const CommReducerNode *op, IRPrinter *p) {
p->stream << "comm_reducer(result=" << op->result
<< ", lhs=" << op->lhs
<< ", rhs=" << op->rhs
<< ", identity_element=" << op->identity_element
<< ")";
});
} // namespace Internal
} // namespace HalideIR
namespace tvm {
namespace ir {
Expr And::make(Expr a, Expr b) {
CHECK(a.defined()) << "ValueError: a is undefined";
CHECK(b.defined()) << "ValueError: b is undefined";
CHECK(a.type().is_bool());
CHECK(b.type().is_bool());
CHECK(a.type() == b.type()) << "TypeError: mismatched types";
NodePtr<And> node = make_node<And>();
node->type = Bool(a.type().lanes());
node->a = std::move(a);
node->b = std::move(b);
return Expr(node);
}
Expr Or::make(Expr a, Expr b) {
CHECK(a.defined()) << "ValueError: a is undefined";
CHECK(b.defined()) << "ValueError: b is undefined";
CHECK(a.type().is_bool());
CHECK(b.type().is_bool());
CHECK(a.type() == b.type()) << "TypeError: mismatched types";
NodePtr<Or> node = make_node<Or>();
node->type = Bool(a.type().lanes());
node->a = std::move(a);
node->b = std::move(b);
return Expr(node);
}
Expr Not::make(Expr a) {
CHECK(a.defined()) << "ValueError: a is undefined";
CHECK(a.type().is_bool());
NodePtr<Not> node = make_node<Not>();
node->type = Bool(a.type().lanes());
node->a = std::move(a);
return Expr(node);
}
Expr Select::make(Expr condition, Expr true_value, Expr false_value) {
CHECK(condition.defined()) << "ValueError: condition is undefined";
CHECK(true_value.defined()) << "ValueError: true_value is undefined";
CHECK(false_value.defined()) << "ValueError: true_value is undefined";
CHECK(condition.type().is_bool());
CHECK_EQ(condition.type().lanes(), true_value.type().lanes());
CHECK(false_value.type() == true_value.type()) << "TypeError: mismatched types";
NodePtr<Select> node = make_node<Select>();
node->type = true_value.type();
node->condition = std::move(condition);
node->true_value = std::move(true_value);
node->false_value = std::move(false_value);
return Expr(node);
}
Expr Load::make(DataType type, Var buffer_var, Expr index, Expr predicate) {
CHECK(buffer_var.defined());
CHECK(predicate.defined());
CHECK(index.defined());
CHECK_EQ(type.lanes(), index.type().lanes());
CHECK_EQ(type.lanes(), predicate.type().lanes());
NodePtr<Load> node = make_node<Load>();
node->type = type;
node->buffer_var = std::move(buffer_var);
node->index = std::move(index);
node->predicate = std::move(predicate);
return Expr(node);
}
Expr Ramp::make(Expr base, Expr stride, int lanes) {
CHECK(base.defined());
CHECK(stride.defined());
CHECK(base.type().is_scalar());
CHECK(stride.type().is_scalar());
CHECK_GT(lanes, 1);
CHECK_EQ(stride.type(), base.type());
NodePtr<Ramp> node = make_node<Ramp>();
node->type = base.type().with_lanes(lanes);
node->base = base;
node->stride = stride;
node->lanes = lanes;
return Expr(node);
}
Expr Broadcast::make(Expr value, int lanes) {
CHECK(value.defined());
CHECK(value.type().is_scalar());
CHECK_GT(lanes, 1);
NodePtr<Broadcast> node = make_node<Broadcast>();
node->type = value.type().with_lanes(lanes);
node->value = std::move(value);
node->lanes = lanes;
return Expr(node);
}
Expr Let::make(Var var, Expr value, Expr body) {
CHECK(value.defined());
CHECK(body.defined());
CHECK_EQ(value.type(), var.type());
NodePtr<Let> node = make_node<Let>();
node->type = body.type();
node->var = std::move(var);
node->value = std::move(value);
node->body = std::move(body);
return Expr(node);
}
Expr Call::make(DataType type,
std::string name,
Array<Expr> args,
CallType call_type,
FunctionRef func,
int value_index) {
for (size_t i = 0; i < args.size(); ++i) {
CHECK(args[i].defined());
}
if (call_type == Halide) {
for (size_t i = 0; i < args.size(); ++i) {
CHECK(args[i].type().is_int());
}
}
NodePtr<Call> node = make_node<Call>();
node->type = type;
node->name = std::move(name);
node->args = std::move(args);
node->call_type = call_type;
node->func = std::move(func);
node->value_index = value_index;
return Expr(node);
}
Expr Shuffle::make(Array<Expr> vectors,
Array<Expr> indices) {
CHECK_NE(vectors.size(), 0U);
CHECK_NE(indices.size(), 0U);
Type base_type = vectors[0].type().element_of();
int total_lanes = 0;
for (Expr val : vectors) {
CHECK(val.type().element_of() == base_type);
total_lanes += val.type().lanes();
}
CHECK_LE(indices.size(), static_cast<size_t>(total_lanes));
NodePtr<Shuffle> node = make_node<Shuffle>();
node->type = base_type.with_lanes(static_cast<int>(indices.size()));
node->vectors = std::move(vectors);
node->indices = std::move(indices);
return Expr(node);
}
Expr Shuffle::make_concat(Array<Expr> vectors) {
CHECK_NE(vectors.size(), 0);
if (vectors.size() == 1) {
return vectors[0];
}
Array<Expr> indices;
int index = 0;
for (const Expr& e : vectors) {
for (int i = 0; i < e.type().lanes(); ++i) {
indices.push_back(IntImm::make(Int(32), index++));
}
}
return make(vectors, indices);
}
Expr Shuffle::make_extract_element(Expr vector, int index) {
return make({vector}, {Integer(index)});
}
CommReducer CommReducerNode::make(Array<Var> lhs,
Array<Var> rhs,
......@@ -132,6 +298,802 @@ Expr Any::make() {
return Expr(n);
}
Stmt LetStmt::make(Var var, Expr value, Stmt body) {
CHECK(value.defined());
CHECK(body.defined());
CHECK_EQ(value.type(), var.type());
NodePtr<LetStmt> node = make_node<LetStmt>();
node->var = std::move(var);
node->value = std::move(value);
node->body = std::move(body);
return Stmt(node);
}
Stmt AttrStmt::make(NodeRef node,
std::string attr_key,
Expr value,
Stmt body) {
auto n = make_node<AttrStmt>();
n->node = node;
n->attr_key = std::move(attr_key);
n->value = std::move(value);
n->body = std::move(body);
return Stmt(n);
}
Stmt AssertStmt::make(Expr condition, Expr message, Stmt body) {
CHECK(condition.defined());
CHECK(message.type() == Int(32) ||
message.as<StringImm>())
<< "TypeError: AssertStmt message must be an int or string:"
<< message << "\n";
NodePtr<AssertStmt> node = make_node<AssertStmt>();
node->condition = std::move(condition);
node->message = std::move(message);
node->body = std::move(body);
return Stmt(node);
}
Stmt ProducerConsumer::make(FunctionRef func, bool is_producer, Stmt body) {
CHECK(body.defined());
NodePtr<ProducerConsumer> node = make_node<ProducerConsumer>();
node->func = std::move(func);
node->is_producer = is_producer;
node->body = std::move(body);
return Stmt(node);
}
Stmt For::make(Var loop_var,
Expr min,
Expr extent,
ForType for_type,
DeviceAPI device_api,
Stmt body) {
CHECK(min.defined());
CHECK(extent.defined());
CHECK(min.type().is_scalar());
CHECK(extent.type().is_scalar());
CHECK(loop_var.type().is_scalar());
CHECK(body.defined());
NodePtr<For> node = make_node<For>();
node->loop_var = std::move(loop_var);
node->min = std::move(min);
node->extent = std::move(extent);
node->for_type = for_type;
node->device_api = device_api;
node->body = std::move(body);
return Stmt(node);
}
Stmt Store::make(Var buffer_var, Expr value, Expr index, Expr predicate) {
CHECK(value.defined());
CHECK(index.defined());
CHECK(predicate.defined());
CHECK_EQ(value.type().lanes(), index.type().lanes());
CHECK_EQ(value.type().lanes(), predicate.type().lanes());
NodePtr<Store> node = make_node<Store>();
node->buffer_var = std::move(buffer_var);
node->value = std::move(value);
node->index = std::move(index);
node->predicate = std::move(predicate);
return Stmt(node);
}
Stmt Provide::make(FunctionRef func, int value_index, Expr value, Array<Expr> args) {
CHECK(value_index >=0 && value_index < func->num_outputs())
<< "value index output function return value bound";
CHECK(value.defined()) << "Provide of undefined value\n";
for (size_t i = 0; i < args.size(); ++i) {
CHECK(args[i].defined()) << "Provide to undefined location\n";
}
NodePtr<Provide> node = make_node<Provide>();
node->func = std::move(func);
node->value_index = value_index;
node->value = std::move(value);
node->args = std::move(args);
return Stmt(node);
}
Stmt Allocate::make(Var buffer_var,
DataType type,
Array<Expr> extents,
Expr condition,
Stmt body,
Expr new_expr,
std::string free_function) {
for (size_t i = 0; i < extents.size(); ++i) {
CHECK(extents[i].defined());
CHECK(extents[i].type().is_scalar());
}
CHECK(body.defined());
CHECK(condition.defined());
CHECK(condition.type().is_bool());
NodePtr<Allocate> node = make_node<Allocate>();
node->buffer_var = std::move(buffer_var);
node->type = type;
node->extents = std::move(extents);
node->condition = std::move(condition);
node->body = std::move(body);
node->new_expr = std::move(new_expr);
node->free_function = std::move(free_function);
return Stmt(node);
}
int32_t Allocate::constant_allocation_size(const Array<Expr>& extents) {
int64_t result = 1;
for (size_t i = 0; i < extents.size(); ++i) {
if (const IntImm *int_size = extents[i].as<IntImm>()) {
result *= int_size->value;
if (result > std::numeric_limits<int32_t>::max()) {
return 0;
}
} else {
return 0;
}
}
return static_cast<int32_t>(result);
}
Stmt Free::make(Var buffer_var) {
NodePtr<Free> node = make_node<Free>();
node->buffer_var = buffer_var;
return Stmt(node);
}
Stmt Realize::make(FunctionRef func,
int value_index,
DataType type,
Region bounds,
Expr condition,
Stmt body) {
for (size_t i = 0; i < bounds.size(); ++i) {
CHECK(bounds[i]->min.defined());
CHECK(bounds[i]->extent.defined());
CHECK(bounds[i]->min.type().is_scalar());
CHECK(bounds[i]->extent.type().is_scalar());
}
CHECK(body.defined());
CHECK(condition.defined());
CHECK(condition.type().is_bool());
NodePtr<Realize> node = make_node<Realize>();
node->func = std::move(func);
node->value_index = value_index;
node->type = type;
node->bounds = std::move(bounds);
node->condition = std::move(condition);
node->body = std::move(body);
return Stmt(node);
}
Stmt Prefetch::make(FunctionRef func, int value_index, DataType type, Region bounds) {
for (size_t i = 0; i < bounds.size(); ++i) {
CHECK(bounds[i]->min.defined());
CHECK(bounds[i]->extent.defined());
CHECK(bounds[i]->min.type().is_scalar());
CHECK(bounds[i]->extent.type().is_scalar());
}
NodePtr<Prefetch> node = make_node<Prefetch>();
node->func = std::move(func);
node->value_index = value_index;
node->type = type;
node->bounds = std::move(bounds);
return Stmt(node);
}
Stmt Block::make(Stmt first, Stmt rest) {
CHECK(first.defined());
CHECK(rest.defined());
NodePtr<Block> node = make_node<Block>();
// canonicalize.
if (const Block* b = first.as<Block>()) {
node->first = b->first;
node->rest = Block::make(b->rest, rest);
} else {
node->first = std::move(first);
node->rest = std::move(rest);
}
return Stmt(node);
}
Stmt Block::make(const std::vector<Stmt>& stmts) {
if (stmts.empty()) {
return Stmt();
}
Stmt result = stmts.back();
for (size_t i = stmts.size() - 1; i != 0; --i) {
result = Block::make(stmts[i - 1], result);
}
return result;
}
Stmt IfThenElse::make(Expr condition, Stmt then_case, Stmt else_case) {
CHECK(condition.defined());
CHECK(then_case.defined());
// else_case may be null.
NodePtr<IfThenElse> node = make_node<IfThenElse>();
node->condition = std::move(condition);
node->then_case = std::move(then_case);
node->else_case = std::move(else_case);
return Stmt(node);
}
Stmt Evaluate::make(Expr value) {
CHECK(value.defined());
NodePtr<Evaluate> node = make_node<Evaluate>();
node->value = std::move(value);
return Stmt(node);
}
// Printers
TVM_STATIC_IR_FUNCTOR(IRPrinter, vtable)
.set_dispatch<UIntImm>([](const UIntImm* op, IRPrinter* p) {
p->stream << "(" << op->type << ")" << op->value;
});
TVM_STATIC_IR_FUNCTOR(IRPrinter, vtable)
.set_dispatch<FloatImm>([](const FloatImm* op, IRPrinter* p) {
auto& stream = p->stream;
switch (op->type.bits()) {
case 64:
stream << op->value;
break;
case 32:
stream << op->value << 'f';
break;
case 16:
stream << op->value << 'h';
break;
default:
LOG(FATAL) << "Unknown float type bits=" << op->type.bits();
}
});
TVM_STATIC_IR_FUNCTOR(IRPrinter, vtable)
.set_dispatch<StringImm>([](const StringImm* op, IRPrinter* p) {
auto& stream = p->stream;
stream << '"';
for (size_t i = 0; i < op->value.size(); ++i) {
unsigned char c = op->value[i];
if (c >= ' ' && c <= '~' && c != '\\' && c != '"') {
stream << c;
} else {
stream << '\\';
switch (c) {
case '"':
stream << '"';
break;
case '\\':
stream << '\\';
break;
case '\t':
stream << 't';
break;
case '\r':
stream << 'r';
break;
case '\n':
stream << 'n';
break;
default:
const char* hex_digits = "0123456789ABCDEF";
stream << 'x' << hex_digits[c >> 4] << hex_digits[c & 0xf];
}
}
}
stream << '"';
});
TVM_STATIC_IR_FUNCTOR(IRPrinter, vtable)
.set_dispatch<Cast>([](const Cast* op, IRPrinter* p) {
p->stream << op->type << '(';
p->Print(op->value);
p->stream << ')';
})
.set_dispatch<Variable>([](const Variable* op, IRPrinter* p) {
// omit the type
// stream << op->name << "." << op->type;
p->stream << op->name_hint;
})
.set_dispatch<Add>([](const Add* op, IRPrinter* p) {
p->stream << '(';
p->Print(op->a);
p->stream << " + ";
p->Print(op->b);
p->stream << ')';
})
.set_dispatch<Sub>([](const Sub* op, IRPrinter* p) {
p->stream << '(';
p->Print(op->a);
p->stream << " - ";
p->Print(op->b);
p->stream << ')';
})
.set_dispatch<Mul>([](const Mul* op, IRPrinter* p) {
p->stream << '(';
p->Print(op->a);
p->stream << "*";
p->Print(op->b);
p->stream << ')';
})
.set_dispatch<Div>([](const Div* op, IRPrinter* p) {
p->stream << '(';
p->Print(op->a);
p->stream << "/";
p->Print(op->b);
p->stream << ')';
})
.set_dispatch<Mod>([](const Mod* op, IRPrinter* p) {
p->stream << '(';
p->Print(op->a);
p->stream << " % ";
p->Print(op->b);
p->stream << ')';
})
.set_dispatch<Min>([](const Min* op, IRPrinter* p) {
p->stream << "min(";
p->Print(op->a);
p->stream << ", ";
p->Print(op->b);
p->stream << ")";
})
.set_dispatch<Max>([](const Max* op, IRPrinter* p) {
p->stream << "max(";
p->Print(op->a);
p->stream << ", ";
p->Print(op->b);
p->stream << ")";
})
.set_dispatch<EQ>([](const EQ* op, IRPrinter* p) {
p->stream << '(';
p->Print(op->a);
p->stream << " == ";
p->Print(op->b);
p->stream << ')';
})
.set_dispatch<NE>([](const NE* op, IRPrinter* p) {
p->stream << '(';
p->Print(op->a);
p->stream << " != ";
p->Print(op->b);
p->stream << ')';
})
.set_dispatch<LT>([](const LT* op, IRPrinter* p) {
p->stream << '(';
p->Print(op->a);
p->stream << " < ";
p->Print(op->b);
p->stream << ')';
})
.set_dispatch<LE>([](const LE* op, IRPrinter* p) {
p->stream << '(';
p->Print(op->a);
p->stream << " <= ";
p->Print(op->b);
p->stream << ')';
})
.set_dispatch<GT>([](const GT* op, IRPrinter* p) {
p->stream << '(';
p->Print(op->a);
p->stream << " > ";
p->Print(op->b);
p->stream << ')';
})
.set_dispatch<GE>([](const GE* op, IRPrinter* p) {
p->stream << '(';
p->Print(op->a);
p->stream << " >= ";
p->Print(op->b);
p->stream << ')';
});
TVM_STATIC_IR_FUNCTOR(IRPrinter, vtable)
.set_dispatch<And>([](const And* op, IRPrinter* p) {
p->stream << '(';
p->Print(op->a);
p->stream << " && ";
p->Print(op->b);
p->stream << ')';
});
TVM_STATIC_IR_FUNCTOR(IRPrinter, vtable)
.set_dispatch<Or>([](const Or* op, IRPrinter* p) {
p->stream << '(';
p->Print(op->a);
p->stream << " || ";
p->Print(op->b);
p->stream << ')';
});
TVM_STATIC_IR_FUNCTOR(IRPrinter, vtable)
.set_dispatch<Not>([](const Not* op, IRPrinter* p) {
p->stream << '!';
p->Print(op->a);
});
TVM_STATIC_IR_FUNCTOR(IRPrinter, vtable)
.set_dispatch<Select>([](const Select* op, IRPrinter* p) {
p->stream << "select(";
p->Print(op->condition);
p->stream << ", ";
p->Print(op->true_value);
p->stream << ", ";
p->Print(op->false_value);
p->stream << ")";
});
TVM_STATIC_IR_FUNCTOR(IRPrinter, vtable)
.set_dispatch<Load>([](const Load* op, IRPrinter* p) {
p->stream << op->buffer_var << "[";
p->Print(op->index);
p->stream << "]";
if (!is_one(op->predicate)) {
p->stream << " if ";
p->Print(op->predicate);
}
});
TVM_STATIC_IR_FUNCTOR(IRPrinter, vtable)
.set_dispatch<Ramp>([](const Ramp* op, IRPrinter* p) {
p->stream << "ramp(";
p->Print(op->base);
p->stream << ", ";
p->Print(op->stride);
p->stream << ", " << op->lanes << ")";
});
TVM_STATIC_IR_FUNCTOR(IRPrinter, vtable)
.set_dispatch<Broadcast>([](const Broadcast* op, IRPrinter* p) {
p->stream << "x" << op->lanes << "(";
p->Print(op->value);
p->stream << ")";
});
TVM_STATIC_IR_FUNCTOR(IRPrinter, vtable)
.set_dispatch<Call>([](const Call* op, IRPrinter* p) {
p->stream << op->name << "(";
for (size_t i = 0; i < op->args.size(); ++i) {
p->Print(op->args[i]);
if (i < op->args.size() - 1) {
p->stream << ", ";
}
}
p->stream << ")";
});
TVM_STATIC_IR_FUNCTOR(IRPrinter, vtable)
.set_dispatch<Let>([](const Let* op, IRPrinter* p) {
p->stream << "(let " << op->var << " = ";
p->Print(op->value);
p->stream << " in ";
p->Print(op->body);
p->stream << ")";
});
TVM_STATIC_IR_FUNCTOR(IRPrinter, vtable)
.set_dispatch<LetStmt>([](const LetStmt* op, IRPrinter* p) {
p->PrintIndent();
p->stream << "let " << op->var << " = ";
p->Print(op->value);
p->stream << '\n';
p->Print(op->body);
});
TVM_STATIC_IR_FUNCTOR(IRPrinter, vtable)
.set_dispatch<AttrStmt>([](const AttrStmt* op, IRPrinter* p) {
p->PrintIndent();
p->stream << "// attr [";
p->Print(op->node);
p->stream << "] "
<< op->attr_key << " = ";
p->Print(op->value);
p->stream << '\n';
p->Print(op->body);
});
TVM_STATIC_IR_FUNCTOR(IRPrinter, vtable)
.set_dispatch<AssertStmt>([](const AssertStmt* op, IRPrinter* p) {
p->PrintIndent();
p->stream << "assert(";
p->Print(op->condition);
p->stream << ", ";
p->Print(op->message);
p->stream << ")\n";
p->Print(op->body);
});
TVM_STATIC_IR_FUNCTOR(IRPrinter, vtable)
.set_dispatch<ProducerConsumer>([](const ProducerConsumer* op, IRPrinter* p) {
if (op->is_producer) {
p->PrintIndent();
p->stream << "produce " << op->func->func_name() << " {\n";
p->indent += 2;
p->Print(op->body);
p->indent -= 2;
p->PrintIndent();
p->stream << "}\n";
} else {
p->Print(op->body);
}
});
std::ostream &operator<<(std::ostream& out, ForType type) { // NOLINT(*)
switch (type) {
case ForType::Serial:
out << "for";
break;
case ForType::Parallel:
out << "parallel";
break;
case ForType::Unrolled:
out << "unrolled";
break;
case ForType::Vectorized:
out << "vectorized";
break;
}
return out;
}
TVM_STATIC_IR_FUNCTOR(IRPrinter, vtable)
.set_dispatch<For>([](const For* op, IRPrinter* p) {
p->PrintIndent();
p->stream << op->for_type << " (" << op->loop_var << ", ";
p->Print(op->min);
p->stream << ", ";
p->Print(op->extent);
p->stream << ") {\n";
p->indent += 2;
p->Print(op->body);
p->indent -= 2;
p->PrintIndent();
p->stream << "}\n";
});
TVM_STATIC_IR_FUNCTOR(IRPrinter, vtable)
.set_dispatch<Store>([](const Store* op, IRPrinter* p) {
p->PrintIndent();
p->stream << op->buffer_var << "[";
p->Print(op->index);
p->stream << "] = ";
p->Print(op->value);
if (!is_one(op->predicate)) {
p->stream << " if ";
p->Print(op->predicate);
}
p->stream << '\n';
});
TVM_STATIC_IR_FUNCTOR(IRPrinter, vtable)
.set_dispatch<Provide>([](const Provide* op, IRPrinter* p) {
p->PrintIndent();
p->stream << op->func->func_name() << "(";
for (size_t i = 0; i < op->args.size(); ++i) {
p->Print(op->args[i]);
if (i < op->args.size() - 1) p->stream << ", ";
}
p->stream << ")";
if (op->func->num_outputs() != 1) {
p->stream << ".value[" << op->value_index << "]";
}
p->stream << " =";
p->Print(op->value);
p->stream << '\n';
});
TVM_STATIC_IR_FUNCTOR(IRPrinter, vtable)
.set_dispatch<Allocate>([](const Allocate* op, IRPrinter* p) {
p->PrintIndent();
p->stream << "allocate " << op->buffer_var << "[" << op->type;
for (size_t i = 0; i < op->extents.size(); ++i) {
p->stream << " * ";
p->Print(op->extents[i]);
}
p->stream << "]";
if (!is_one(op->condition)) {
p->stream << " if ";
p->Print(op->condition);
}
p->stream << "\n";
p->Print(op->body);
});
TVM_STATIC_IR_FUNCTOR(IRPrinter, vtable)
.set_dispatch<Free>([](const Free* op, IRPrinter* p) {
p->PrintIndent();
p->stream << "free " << op->buffer_var;
p->stream << '\n';
});
TVM_STATIC_IR_FUNCTOR(IRPrinter, vtable)
.set_dispatch<Realize>([](const Realize* op, IRPrinter* p) {
p->PrintIndent();
p->stream << "realize " << op->func->func_name() << "(";
for (size_t i = 0; i < op->bounds.size(); ++i) {
p->stream << "[";
p->Print(op->bounds[i]->min);
p->stream << ", ";
p->Print(op->bounds[i]->extent);
p->stream << "]";
if (i < op->bounds.size() - 1) p->stream << ", ";
}
p->stream << ")";
if (op->func->num_outputs() != 1) {
p->stream << ".value[" << op->value_index << "]";
}
if (!is_one(op->condition)) {
p->stream << " if ";
p->Print(op->condition);
}
p->stream << " {\n";
p->indent += 2;
p->Print(op->body);
p->indent -= 2;
p->PrintIndent();
p->stream << "}\n";
});
TVM_STATIC_IR_FUNCTOR(IRPrinter, vtable)
.set_dispatch<Prefetch>([](const Prefetch* op, IRPrinter* p) {
p->PrintIndent();
p->stream << "prefetch " << op->func->func_name() << "(";
for (size_t i = 0; i < op->bounds.size(); ++i) {
p->stream << "[";
p->Print(op->bounds[i]->min);
p->stream << ", ";
p->Print(op->bounds[i]->extent);
p->stream << "]";
if (i < op->bounds.size() - 1) p->stream << ", ";
}
p->stream << ")";
if (op->func->num_outputs() != 1) {
p->stream << ".value[" << op->value_index << "]";
}
});
TVM_STATIC_IR_FUNCTOR(IRPrinter, vtable)
.set_dispatch<Block>([](const Block* op, IRPrinter* p) {
p->Print(op->first);
if (op->rest.defined()) p->Print(op->rest);
});
TVM_STATIC_IR_FUNCTOR(IRPrinter, vtable)
.set_dispatch<IfThenElse>([](const IfThenElse* op, IRPrinter* p) {
p->PrintIndent();
while (true) {
p->stream << "if (" << op->condition << ") {\n";
p->indent += 2;
p->Print(op->then_case);
p->indent -= 2;
if (!op->else_case.defined()) {
break;
}
if (const IfThenElse *nested_if = op->else_case.as<IfThenElse>()) {
p->PrintIndent();
p->stream << "} else ";
op = nested_if;
} else {
p->PrintIndent();
p->stream << "} else {\n";
p->indent += 2;
p->Print(op->else_case);
p->indent -= 2;
break;
}
}
p->PrintIndent();
p->stream << "}\n";
});
TVM_STATIC_IR_FUNCTOR(IRPrinter, vtable)
.set_dispatch<Evaluate>([](const Evaluate* op, IRPrinter* p) {
p->PrintIndent();
p->Print(op->value);
p->stream << "\n";
});
template<typename T>
void PrintList(const Array<T> &exprs, IRPrinter* p) {
for (size_t i = 0; i < exprs.size(); ++i) {
p->Print(exprs[i]);
if (i < exprs.size() - 1) {
p->stream << ", ";
}
}
}
TVM_STATIC_IR_FUNCTOR(IRPrinter, vtable)
.set_dispatch<Shuffle>([](const Shuffle* op, IRPrinter* p) {
p->stream << "shuffle(";
PrintList(op->vectors, p);
p->stream << ", ";
PrintList(op->indices, p);
p->stream << ")";
});
// Container printer
TVM_STATIC_IR_FUNCTOR(IRPrinter, vtable)
.set_dispatch<ArrayNode>([](const ArrayNode* op, IRPrinter* p) {
p->stream << '[';
for (size_t i = 0 ; i < op->data.size(); ++i) {
if (i != 0) {
p->stream << ", ";
}
p->Print(NodeRef(op->data[i]));
}
p->stream << ']';
});
TVM_STATIC_IR_FUNCTOR(IRPrinter, vtable)
.set_dispatch<MapNode>([](const MapNode* op, IRPrinter* p) {
p->stream << '{';
for (auto it = op->data.begin(); it != op->data.end(); ++it) {
if (it != op->data.begin()) {
p->stream << ", ";
}
p->Print(NodeRef(it->first));
p->stream << ": ";
p->Print(NodeRef(it->second));
}
p->stream << '}';
});
TVM_STATIC_IR_FUNCTOR(IRPrinter, vtable)
.set_dispatch<StrMapNode>([](const StrMapNode* op, IRPrinter* p) {
p->stream << '{';
for (auto it = op->data.begin(); it != op->data.end(); ++it) {
if (it != op->data.begin()) {
p->stream << ", ";
}
p->stream << '\"' << it->first << "\": ";
p->Print(NodeRef(it->second));
}
p->stream << '}';
});
TVM_STATIC_IR_FUNCTOR(IRPrinter, vtable)
.set_dispatch<Reduce>([](const Reduce* op, IRPrinter* p) {
p->stream << "reduce(combiner="
<< op->combiner;
p->stream << ", source=" << op->source;
p->stream << ", axis=" << op->axis;
p->stream << ", where=" << op->condition;
p->stream << ", value_index=" << op->value_index;
p->stream << ")";
});
TVM_STATIC_IR_FUNCTOR(IRPrinter, vtable)
.set_dispatch<CommReducerNode>([](const CommReducerNode* op, IRPrinter* p) {
p->stream << "comm_reducer(result=" << op->result
<< ", lhs=" << op->lhs
<< ", rhs=" << op->rhs
<< ", identity_element=" << op->identity_element
<< ")";
});
TVM_STATIC_IR_FUNCTOR(IRPrinter, vtable)
.set_dispatch<Any>([](const Any *op, IRPrinter *p) {
p->stream << "?";
});
TVM_REGISTER_NODE_TYPE(CommReducerNode);
TVM_REGISTER_NODE_TYPE(Reduce);
TVM_REGISTER_NODE_TYPE(Any);
......
......@@ -24,26 +24,23 @@
#include <tvm/tensor.h>
#include <tvm/operation.h>
#include <tvm/tensor_intrin.h>
#include <ir/IR.h>
#include <memory>
namespace tvm {
// Tensor
Expr Tensor::operator()(Array<Var> indices) const {
Array<Expr> arr(indices.begin(), indices.end());
return operator()(arr);
}
Expr Tensor::operator()(Array<Expr> indices) const {
using HalideIR::Internal::Call;
using ir::Call;
if (ndim() != 0) {
CHECK_EQ(ndim(), indices.size())
<< "Tensor dimension mismatch in read"
<< "ndim = " << ndim() << ", indices.size=" << indices.size();
}
auto n = Call::make(
(*this)->dtype, (*this)->op->name, indices, Call::Halide,
(*this)->op, (*this)->value_index);
......
/*
* Licensed to the Apache Software Foundation (ASF) under one
* or more contributor license agreements. See the NOTICE file
* distributed with this work for additional information
* regarding copyright ownership. The ASF licenses this file
* to you under the Apache License, Version 2.0 (the
* "License"); you may not use this file except in compliance
* with the License. You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing,
* software distributed under the License is distributed on an
* "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
* KIND, either express or implied. See the License for the
* specific language governing permissions and limitations
* under the License.
*/
/*!
* Implementation of Node API
* \file node.cc
*/
#include <tvm/node/node.h>
#include <memory>
#include <atomic>
#include <mutex>
#include <unordered_map>
// TODO(tqchen):
// Think of re-organize and consolidate with object.
namespace tvm {
namespace {
// single manager of operator information.
struct TypeManager {
// mutex to avoid registration from multiple threads.
// recursive is needed for trigger(which calls UpdateAttrMap)
std::mutex mutex;
std::atomic<uint32_t> type_counter{0};
std::unordered_map<std::string, uint32_t> key2index;
std::vector<std::string> index2key;
// get singleton of the
static TypeManager* Global() {
static TypeManager inst;
return &inst;
}
};
} // namespace
TVM_DLL const bool Node::_DerivedFrom(uint32_t tid) const {
static uint32_t tindex = TypeKey2Index(Node::_type_key);
return tid == tindex;
}
// this is slow, usually caller always hold the result in a static variable.
TVM_DLL uint32_t Node::TypeKey2Index(const char* key) {
TypeManager *t = TypeManager::Global();
std::lock_guard<std::mutex>(t->mutex);
std::string skey = key;
auto it = t->key2index.find(skey);
if (it != t->key2index.end()) {
return it->second;
}
uint32_t tid = ++(t->type_counter);
t->key2index[skey] = tid;
t->index2key.push_back(skey);
return tid;
}
TVM_DLL const char* Node::TypeIndex2Key(uint32_t index) {
TypeManager *t = TypeManager::Global();
std::lock_guard<std::mutex>(t->mutex);
CHECK_NE(index, 0);
return t->index2key.at(index - 1).c_str();
}
} // namespace tvm
......@@ -18,7 +18,6 @@
*/
/*!
* Copyright (c) 2017 by Contributors
* \brief Compute Op.
* \file compute_op.cc
*/
......@@ -250,7 +249,7 @@ Stmt BaseComputeOpNode::BuildRealize(
const std::unordered_map<IterVar, Range>& realize_map,
const Stmt& body) const {
CHECK_EQ(stage->op.get(), this);
HalideIR::Internal::Region bounds;
Region bounds;
for (IterVar iv : this->axis) {
bounds.push_back(realize_map.at(iv));
}
......
......@@ -18,7 +18,6 @@
*/
/*!
* Copyright (c) 2017 by Contributors
* \brief External computation rule.
* \file extern_op.cc
*/
......@@ -140,7 +139,7 @@ Stmt ExternOpNode::BuildRealize(
Stmt realize_body = body;
for (int k = 0; k < num_outputs(); ++k) {
Tensor t = stage->op.output(k);
HalideIR::Internal::Region bounds;
Region bounds;
for (size_t i = 0; i < t->shape.size(); ++i) {
bounds.push_back(
Range::make_by_min_extent(
......
......@@ -18,7 +18,6 @@
*/
/*!
* Copyright (c) 2019 by Contributors
* \brief Hybrid computation rule.
* \file hybrid_op.cc
*/
......@@ -28,7 +27,6 @@
#include <tvm/ir_mutator.h>
#include <tvm/ir_pass.h>
#include <tvm/expr_operator.h>
#include <ir/Expr.h>
#include <unordered_set>
#include <string>
#include <utility>
......@@ -143,7 +141,7 @@ Stmt HybridOpNode::BuildRealize(
Stmt realize_body = body;
for (int k = 0; k < num_outputs(); ++k) {
Tensor t = stage->op.output(k);
HalideIR::Internal::Region bounds;
Region bounds;
for (size_t i = 0; i < t->shape.size(); ++i) {
bounds.push_back(
Range::make_by_min_extent(
......@@ -442,7 +440,7 @@ Stmt ApplyLoopOrder(const Stage &stage,
}
const Range &range = target->dom.defined() ? target->dom : dom_map.find(target)->second;
return For::make(target->var, range->min, range->extent,
for_type, HalideIR::DeviceAPI::None, body);
for_type, DeviceAPI::None, body);
}
};
......
......@@ -18,7 +18,6 @@
*/
/*!
* Copyright (c) 2017 by Contributors
* \brief Utility to make loop nest.
* \file op_util.cc
*/
......
......@@ -18,7 +18,6 @@
*/
/*!
* Copyright (c) 2017 by Contributors
* \brief Scan Operator.
* \file scan_op.cc
*/
......@@ -264,7 +263,7 @@ Stmt ScanOpNode::BuildRealize(
for (size_t i = 0; i < update.size(); ++i) {
Tensor t = stage->op.output(i);
CHECK_EQ(static_cast<size_t>(t->value_index), i);
HalideIR::Internal::Region bounds;
Region bounds;
bounds.push_back(tdom);
for (size_t k = 1; k < this->update[i]->shape.size(); ++k, ++sp_idx) {
IterVar sp_ax = this->spatial_axis_[sp_idx];
......
......@@ -18,7 +18,6 @@
*/
/*!
* Copyright (c) 2017 by Contributors
* \file inject_prefetch.cc
*/
// Inject prefetch op in HalideIR
......@@ -34,7 +33,6 @@ namespace ir {
using arith::IntSet;
using arith::DomainTouched;
using HalideIR::Internal::Region;
class PrefetchInjector : public IRMutator {
public:
......
......@@ -347,8 +347,7 @@ class IRDeepCompare :
return order_;
}
int CompareRegion(const HalideIR::Internal::Region& lhs,
const HalideIR::Internal::Region& rhs) {
int CompareRegion(const Region& lhs, const Region& rhs) {
if (order_ != 0) return order_;
if (CompareValue(lhs.size(), rhs.size()) != 0) return order_;
for (size_t i = 0; i < lhs.size(); ++i) {
......
......@@ -225,7 +225,7 @@ Stmt IRMutator::Mutate_(const Provide* op, const Stmt& s) {
Stmt IRMutator::Mutate_(const Realize* op, const Stmt& s) {
IRMutator* m = this;
HalideIR::Internal::Region new_bounds;
Region new_bounds;
bool bounds_changed = false;
// Mutate the bounds
......@@ -255,7 +255,7 @@ Stmt IRMutator::Mutate_(const Realize* op, const Stmt& s) {
Stmt IRMutator::Mutate_(const Prefetch* op, const Stmt& s) {
IRMutator* m = this;
HalideIR::Internal::Region new_bounds;
Region new_bounds;
bool bounds_changed = false;
// Mutate the bounds
......
......@@ -26,6 +26,7 @@
#define TVM_PASS_STORAGE_ACCESS_H_
#include <tvm/ir.h>
#include <tvm/attrs.h>
#include <tvm/ir_pass.h>
#include <tvm/ir_visitor.h>
#include <vector>
......@@ -56,7 +57,7 @@ class StorageAccessVisitor : public IRVisitor {
/*! \brief The thread index that access this entry */
Array<IterVar> threads;
/*! \brief The buffer variable, if any */
VarExpr buffer;
Var buffer = NullValue<Var>();
/*! \brief The access data type */
Type dtype;
/*! \brief The touched access range */
......@@ -66,7 +67,7 @@ class StorageAccessVisitor : public IRVisitor {
/*! \brief The storage scope */
StorageScope scope;
/*! \brief Whether the access is double buffer write */
bool double_buffer_write{false};
bool double_buffer_write = false;
};
/*! \brief Access pattern about a single statement */
struct StmtEntry {
......
......@@ -41,7 +41,6 @@
namespace tvm {
namespace ir {
using HalideIR::Internal::Region;
using runtime::StorageRank;
using runtime::StorageScope;
using runtime::ThreadScope;
......@@ -186,7 +185,7 @@ class StorageFlattener : public IRMutator {
}
// use small alignment for small arrays
int32_t const_size = Allocate::constant_allocation_size(shape, key.GetName());
int32_t const_size = Allocate::constant_allocation_size(shape);
int align = GetTempAllocaAlignment(op->type, const_size);
if (skey.tag.length() != 0) {
MemoryInfo info = GetMemoryInfo(skey.to_string());
......@@ -348,14 +347,14 @@ class StorageFlattener : public IRMutator {
for (int i = starts; i >= 0; --i) {
if (i < starts) {
stmt = For::make(
vars[i], 0, op->bounds[i]->extent, ForType::Serial, DeviceAPI::Host, stmt);
vars[i], 0, op->bounds[i]->extent, ForType::Serial, DeviceAPI::None, stmt);
} else {
Expr load = e.buffer.vload(e.RelIndex(args), e.buffer->dtype);
Expr address = Call::make(Handle(), tvm_address_of, {load}, Call::PureIntrinsic);
Expr prefetch = Call::make(op->type, Call::prefetch, {address, 0, 3, 1}, Call::Intrinsic);
stmt = Evaluate::make(prefetch);
Expr extent = (op->bounds[i]->extent - 1) / stride + 1;
stmt = For::make(vars[i], 0, extent, ForType::Serial, DeviceAPI::Host, stmt);
stmt = For::make(vars[i], 0, extent, ForType::Serial, DeviceAPI::None, stmt);
}
}
return stmt;
......
......@@ -77,7 +77,7 @@ struct GraphCodegen {
std::unordered_map<std::string, tvm::runtime::NDArray> GetParams() {
std::unordered_map<std::string, tvm::runtime::NDArray> ret;
auto names = CallFunc<Array<HalideIR::Expr> >("list_params_name", nullptr);
auto names = CallFunc<Array<tvm::Expr> >("list_params_name", nullptr);
for (auto expr : names) {
auto key = expr.as<ir::StringImm>()->value;
ret[key] = CallFunc<runtime::NDArray>("get_param_by_name", key);
......@@ -289,7 +289,7 @@ class RelayBuildModule : public runtime::ModuleNode {
auto op_node = call_node->op.as<OpNode>();
if (op_node->name == "cast") {
auto attrs = call_node->attrs.as<CastAttrs>();
if (attrs->dtype == HalideIR::Int(32)) {
if (attrs->dtype == Int(32)) {
*rv = true;
}
}
......
......@@ -18,7 +18,6 @@
*/
/*!
* Copyright (c) 2018 by Contributors
* \file relay/backend/graph_codegen.cc
* \brief Graph runtime codegen
*/
......@@ -238,7 +237,7 @@ class GraphRuntimeCodegen
* \param shape
* \return std::vector<int64_t>
*/
std::vector<int64_t> _ShapeToJSON(tvm::Array<HalideIR::Expr> shape) {
std::vector<int64_t> _ShapeToJSON(tvm::Array<IndexExpr> shape) {
std::vector<int64_t> ret;
for (IndexExpr dim : shape) {
const int64_t* pval = as_const_int(dim);
......@@ -623,9 +622,9 @@ class GraphRuntimeCodegenModule : public runtime::ModuleNode {
});
} else if (name == "list_params_name") {
return PackedFunc([sptr_to_self, this](TVMArgs args, TVMRetValue* rv) {
Array<HalideIR::Expr> ret;
Array<tvm::Expr> ret;
for (const auto &kv : this->output_.params) {
HalideIR::Expr name = ir::StringImm::make(kv.first);
tvm::Expr name = ir::StringImm::make(kv.first);
ret.push_back(name);
}
*rv = ret;
......
......@@ -102,7 +102,7 @@ TVM_STATIC_IR_FUNCTOR_REGISTER(IRPrinter, vtable)
p->stream << "Var(" << node->name_hint();
if (node->type_annotation.defined()) {
p->stream << ", ty=";
p->print(node->type_annotation);
p->Print(node->type_annotation);
}
p->stream << ")";
});
......
......@@ -150,7 +150,7 @@ RELAY_REGISTER_OP("nn.upsampling")
CHECK(base_layout == "NCHW" || layout == "NHWC")
<< "unknown layout: " << uattrs->layout;
Array<HalideIR::Expr> oshape;
Array<IndexExpr> oshape;
if (base_layout == "NCHW") {
oshape.push_back(out_tt->shape[2]);
oshape.push_back(out_tt->shape[3]);
......
......@@ -18,7 +18,6 @@
*/
/*!
* Copyright (c) 2016 by Contributors
* \file graph.cc
* \brief Utilities to get information about schedule graph.
*/
......@@ -34,7 +33,7 @@ namespace tvm {
namespace schedule {
// key to specific tensor dimension.
struct TensorDimKey {
FunctionRef f;
ir::FunctionRef f;
int value_index;
int dim;
TensorDimKey() {}
......
......@@ -18,7 +18,6 @@
*/
/*!
* Copyright (c) 2016 by Contributors
* \file schedule_lang.cc
*/
#include <tvm/schedule.h>
......@@ -813,34 +812,34 @@ TVM_STATIC_IR_FUNCTOR(IRPrinter, vtable)
})
.set_dispatch<SplitNode>([](const SplitNode *op, IRPrinter *p) {
p->stream << "split(parent=";
p->print(op->parent);
p->Print(op->parent);
p->stream << ", outer=";
p->print(op->outer);
p->Print(op->outer);
p->stream << ", inner=";
p->print(op->inner);
p->Print(op->inner);
p->stream << ')';
})
.set_dispatch<FuseNode>([](const FuseNode *op, IRPrinter *p) {
p->stream << "split(";
p->stream << "outer=";
p->print(op->outer);
p->Print(op->outer);
p->stream << ", inner=";
p->print(op->inner);
p->Print(op->inner);
p->stream << ", fused=";
p->print(op->fused);
p->Print(op->fused);
p->stream << ')';
})
.set_dispatch<RebaseNode>([](const RebaseNode *op, IRPrinter *p) {
p->stream << "rebase(";
p->stream << "parent=";
p->print(op->parent);
p->Print(op->parent);
p->stream << ", rebased=";
p->print(op->rebased);
p->Print(op->rebased);
p->stream << ')';
})
.set_dispatch<SingletonNode>([](const SingletonNode *op, IRPrinter *p) {
p->stream << "singleton(";
p->print(op->iter);
p->Print(op->iter);
p->stream << ')';
})
.set_dispatch<ScheduleNode>([](const ScheduleNode *op, IRPrinter *p) {
......
......@@ -23,14 +23,13 @@
#include <tvm/expr_operator.h>
namespace {
using namespace tvm;
using namespace tvm::ir;
using namespace HalideIR::Internal;
using namespace HalideIR;
// replace variable to constant
class IRVar2Const : public IRMutator {
public:
VarExpr var;
Var var;
int int_val;
Expr Mutate(Expr expr) final {
static const FMutateExpr& f = IRVar2Const::vtable_expr();
......@@ -49,7 +48,7 @@ TVM_STATIC_IR_FUNCTOR(IRVar2Const, vtable_expr)
.set_dispatch<Variable>([](const Variable* op, const Expr &e, IRMutator* m) {
IRVar2Const* vm = static_cast<IRVar2Const*>(m);
if (e.same_as(vm->var)) {
return IntImm::make(Int(32), vm->int_val);
return Expr(IntImm::make(Int(32), vm->int_val));
} else {
return e;
}
......@@ -58,7 +57,7 @@ TVM_STATIC_IR_FUNCTOR(IRVar2Const, vtable_expr)
} // namespace
TEST(IRMutator, Basic) {
using namespace HalideIR::Internal;
using namespace tvm::ir;
using namespace tvm;
Var x("x"), y;
auto z = x + y;
......
......@@ -23,8 +23,8 @@
TEST(IRSSA, Convert) {
using namespace HalideIR::Internal;
using namespace tvm;
using namespace tvm::ir;
Var x("x"), y;
Expr let = Let::make(x, 1, x + 1);
......@@ -35,7 +35,7 @@ TEST(IRSSA, Convert) {
}
TEST(IRSSA, Basic) {
using namespace HalideIR::Internal;
using namespace tvm::ir;
using namespace tvm;
Var x("x"), y;
auto z = Evaluate::make(x + y);
......
......@@ -23,7 +23,6 @@
#include <tvm/ir_pass.h>
TEST(IRVisitor, CountVar) {
using namespace HalideIR::Internal;
using namespace tvm;
int n_var = 0;
Var x("x"), y;
......
......@@ -18,7 +18,6 @@
*/
/*!
* Copyright (c) 2017 by Contributors
* \file topi/image/resize.h
* \brief image resize constructors
*/
......@@ -55,17 +54,17 @@ inline Expr bilinear_sample_nchw(const Tensor& input, const Array<Expr>& indices
const Expr max_y, const Expr max_x) {
auto in_y = indices[2];
auto yf = tvm::floor(in_y);
auto yc = HalideIR::Internal::Cast::make(Int(32), tvm::ceil(in_y));
auto yc = tvm::cast(Int(32), tvm::ceil(in_y));
auto y0 = HalideIR::Internal::Cast::make(Int(32), tvm::floor(in_y));
auto y0 = tvm::cast(Int(32), tvm::floor(in_y));
auto y1 = tvm::if_then_else((yc > max_y), max_y, yc);
auto y_lerp = in_y - yf;
auto in_x = indices[3];
auto xf = tvm::floor(in_x);
auto xc = HalideIR::Internal::Cast::make(Int(32), tvm::ceil(in_x));
auto xc = tvm::cast(Int(32), tvm::ceil(in_x));
auto x0 = HalideIR::Internal::Cast::make(Int(32), tvm::floor(in_x));
auto x0 = tvm::cast(Int(32), tvm::floor(in_x));
auto x1 = tvm::if_then_else((xc > max_x), max_x, xc);
auto x_lerp = in_x - xf;
......@@ -268,17 +267,17 @@ inline Tensor resize_bilinear_nhwc(const Tensor& input,
out_shape, [&](const Array<Var>& indices) {
auto in_y = indices[1] * y_ratio;
auto yf = tvm::floor(in_y);
auto yc = HalideIR::Internal::Cast::make(Int(32), tvm::ceil(in_y));
auto yc = tvm::cast(Int(32), tvm::ceil(in_y));
auto y0 = HalideIR::Internal::Cast::make(Int(32), tvm::floor(in_y));
auto y0 = tvm::cast(Int(32), tvm::floor(in_y));
auto y1 = tvm::if_then_else((yc > other_y), other_y, yc);
auto y_lerp = in_y - yf;
auto in_x = indices[2] * x_ratio;
auto xf = tvm::floor(in_x);
auto xc = HalideIR::Internal::Cast::make(Int(32), tvm::ceil(in_x));
auto xc = tvm::cast(Int(32), tvm::ceil(in_x));
auto x0 = HalideIR::Internal::Cast::make(Int(32), tvm::floor(in_x));
auto x0 = tvm::cast(Int(32), tvm::floor(in_x));
auto x1 = tvm::if_then_else((xc > other_x), other_x, xc);
auto x_lerp = in_x - xf;
......
......@@ -689,7 +689,7 @@ inline Tensor sequence_mask(const Tensor& data,
auto bid = out_index[1 - axis];
len_index.push_back(bid);
Expr ret = tvm::if_then_else(tvm::cast(valid_length->dtype, tid) >= valid_length(len_index),
tvm::cast(data->dtype, Expr(mask_value)), data(out_index));
tvm::make_const(data->dtype, mask_value), data(out_index));
return ret;
}, name, tag);
return out;
......
Markdown is supported
0% or
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment