Unverified Commit 0218557c by Tianqi Chen Committed by GitHub

[INFRA][IR] Build and Evolve Low-level IR. Remove HalideIR dep. (#3533)

* [INFRA][IR] Build and Evolve Low-level IR. Remove dep from HalideIR.


* Update include/tvm/node/ir_functor.h

Co-Authored-By: Jared Roesch <roeschinc@gmail.com>

* Update include/tvm/node/ir_functor.h

Co-Authored-By: Jared Roesch <roeschinc@gmail.com>
parent 2d53f84d
......@@ -76,7 +76,6 @@ if(MSVC)
add_definitions(-D_CRT_SECURE_NO_WARNINGS)
add_definitions(-D_SCL_SECURE_NO_WARNINGS)
add_definitions(-D_ENABLE_EXTENDED_ALIGNED_STORAGE)
add_definitions(-DHalide_SHARED)
set(CMAKE_CXX_FLAGS "${CMAKE_CXX_FLAGS} /EHsc")
set(CMAKE_CXX_FLAGS "${CMAKE_CXX_FLAGS} /MP")
set(CMAKE_C_FLAGS "${CMAKE_C_FLAGS} /bigobj")
......@@ -112,8 +111,8 @@ else(MSVC)
endif(MSVC)
# add source group
FILE(GLOB_RECURSE GROUP_SOURCE "src/*.cc" "3rdparty/HalideIR/src/*.cpp" "nnvm/src/*.cc")
FILE(GLOB_RECURSE GROUP_INCLUDE "src/*.h" "include/*.h" "3rdparty/HalideIR/src/*.h"
FILE(GLOB_RECURSE GROUP_SOURCE "src/*.cc" "nnvm/src/*.cc")
FILE(GLOB_RECURSE GROUP_INCLUDE "src/*.h" "include/*.h"
"nnvm/src/*.h" "nnvm/include/*.h")
assign_source_group("Source" ${GROUP_SOURCE})
assign_source_group("Include" ${GROUP_INCLUDE})
......@@ -127,6 +126,7 @@ file(GLOB COMPILER_SRCS
src/lang/*.cc
src/pass/*.cc
src/op/*.cc
src/node/*.cc
src/schedule/*.cc
)
......@@ -154,12 +154,7 @@ file(GLOB_RECURSE NNVM_COMPILER_SRCS
file(GLOB TOPI_SRCS
topi/src/*.cc
)
file(GLOB_RECURSE HALIDEIR_SRCS
3rdparty/HalideIR/src/base/*.cpp
3rdparty/HalideIR/src/ir/*.cpp
3rdparty/HalideIR/src/tvm/*.cpp
)
list(APPEND COMPILER_SRCS ${HALIDEIR_SRCS})
file(GLOB RUNTIME_SRCS
src/runtime/*.cc
src/runtime/vm/*.cc
......@@ -245,7 +240,6 @@ target_link_libraries(nnvm_compiler tvm)
# Related headers
target_include_directories(
tvm
PUBLIC "3rdparty/HalideIR/src"
PUBLIC "topi/include")
target_include_directories(
tvm_topi
......@@ -295,11 +289,6 @@ if (INSTALL_DEV)
PATTERN "*.h"
)
install(
DIRECTORY "3rdparty/HalideIR/src/." DESTINATION "include/HalideIR"
FILES_MATCHING
PATTERN "*.h"
)
install(
DIRECTORY "3rdparty/dlpack/include/." DESTINATION "include"
FILES_MATCHING
PATTERN "*.h"
......@@ -319,8 +308,6 @@ endif(INSTALL_DEV)
# More target definitions
if(MSVC)
target_compile_definitions(tvm PRIVATE -DHalide_EXPORTS)
target_compile_definitions(tvm_runtime PRIVATE -DHalide_EXPORTS)
target_compile_definitions(tvm PRIVATE -DTVM_EXPORTS)
target_compile_definitions(tvm_runtime PRIVATE -DTVM_EXPORTS)
target_compile_definitions(nnvm_compiler PRIVATE -DNNVM_EXPORTS)
......
......@@ -6,9 +6,9 @@
* to you under the Apache License, Version 2.0 (the
* "License"); you may not use this file except in compliance
* with the License. You may obtain a copy of the License at
*
*
* http://www.apache.org/licenses/LICENSE-2.0
*
*
* Unless required by applicable law or agreed to in writing,
* software distributed under the License is distributed on an
* "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
......
......@@ -591,7 +591,7 @@ IntSet EvalSet(Range r,
const std::unordered_map<const Variable*, IntSet>& dom_map);
/*! \brief Map from Expr to IntSet */
using ExprIntSetMap = std::unordered_map<Expr, IntSet, ExprHash, ExprEqual>;
using ExprIntSetMap = std::unordered_map<Expr, IntSet, NodeHash, NodeEqual>;
/*!
* \brief Find the integer set of every sub-expression, given the
* domain of each iteration variables.
......
......@@ -89,8 +89,8 @@ inline TNodeRef NullValue() {
}
template<>
inline Type NullValue<Type>() {
return Type(Type::Handle, 0, 0);
inline DataType NullValue<DataType>() {
return DataType(kHandle, 0, 0);
}
/*! \brief Error thrown during attribute checking. */
......
......@@ -221,7 +221,7 @@ class Layout : public NodeRef {
if (!this->defined()) return -1;
const auto axes = operator->()->axes;
for (size_t i = 0; i < axes.size(); ++i) {
if (axes[i]->var.get()->name_hint == axis.name()) return static_cast<int32_t>(i);
if (axes[i]->var->name_hint == axis.name()) return static_cast<int32_t>(i);
}
return -1;
}
......@@ -243,7 +243,7 @@ class Layout : public NodeRef {
bool Contains(const LayoutAxis& axis) const {
if (!defined()) return false;
for (const IterVar var : operator->()->axes) {
if (var->var.get()->name_hint == axis.name()) {
if (var->var->name_hint == axis.name()) {
return true;
}
}
......
/*
* Licensed to the Apache Software Foundation (ASF) under one
* or more contributor license agreements. See the NOTICE file
* distributed with this work for additional information
* regarding copyright ownership. The ASF licenses this file
* to you under the Apache License, Version 2.0 (the
* "License"); you may not use this file except in compliance
* with the License. You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing,
* software distributed under the License is distributed on an
* "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
* KIND, either express or implied. See the License for the
* specific language governing permissions and limitations
* under the License.
*/
/*
* \file tvm/dtype.h
* \brief Data type used in IR.
*/
#ifndef TVM_DTYPE_H_
#define TVM_DTYPE_H_
#include "runtime/packed_func.h"
namespace tvm {
class Expr;
/*!
 * \brief Primitive data types in tvm.
 *
 * Thin value wrapper around DLDataType (code, bits, lanes)
 * with convenience predicates and derived-type constructors.
 */
class DataType {
 public:
  /*! \brief default constructor */
  DataType() {}
  /*!
   * \brief Constructor
   * \param dtype The DLDataType
   */
  explicit DataType(DLDataType dtype)
      : data_(dtype) {}
  /*!
   * \brief Constructor
   * \param code The type code.
   * \param bits The number of bits in the type.
   * \param lanes The number of lanes.
   */
  DataType(int code, int bits, int lanes) {
    data_.code = static_cast<uint8_t>(code);
    data_.bits = static_cast<uint8_t>(bits);
    data_.lanes = static_cast<uint16_t>(lanes);
  }
  /*! \return The type code. */
  int code() const {
    return static_cast<int>(data_.code);
  }
  /*! \return number of bits in the data. */
  int bits() const {
    return static_cast<int>(data_.bits);
  }
  /*! \return number of bytes to store each scalar. */
  int bytes() const {
    // round bits up to whole bytes
    return (bits() + 7) / 8;
  }
  /*! \return number of lanes in the data. */
  int lanes() const {
    return static_cast<int>(data_.lanes);
  }
  /*! \return whether type is a scalar type. */
  bool is_scalar() const {
    return lanes() == 1;
  }
  /*! \return whether type is a boolean (1-bit unsigned) type. */
  bool is_bool() const {
    return code() == kDLUInt && bits() == 1;
  }
  /*! \return whether type is a float type. */
  bool is_float() const {
    return code() == kDLFloat;
  }
  /*! \return whether type is an int type. */
  bool is_int() const {
    return code() == kDLInt;
  }
  /*! \return whether type is an uint type. */
  bool is_uint() const {
    return code() == kDLUInt;
  }
  /*! \return whether type is a handle type. */
  bool is_handle() const {
    return code() == kHandle;
  }
  /*! \return whether type is a vector type. */
  bool is_vector() const {
    return lanes() > 1;
  }
  /*!
   * \brief Create a new data type by changing lanes to a specified value.
   * \param lanes The target number of lanes.
   * \return the result type.
   */
  DataType with_lanes(int lanes) const {
    return DataType(data_.code, data_.bits, lanes);
  }
  /*!
   * \brief Create a new data type by changing bits to a specified value.
   * \param bits The target number of bits.
   * \return the result type.
   */
  DataType with_bits(int bits) const {
    return DataType(data_.code, bits, data_.lanes);
  }
  /*!
   * \brief Get the scalar version of the type.
   * \return the result type.
   */
  DataType element_of() const {
    return with_lanes(1);
  }
  // operator overloads
  /*! \return whether all of code, bits and lanes are equal. */
  bool operator==(const DataType& other) const {
    return
        data_.code == other.data_.code &&
        data_.bits == other.data_.bits &&
        data_.lanes == other.data_.lanes;
  }
  bool operator!=(const DataType& other) const {
    return !operator==(other);
  }
  // implicit conversion back to the underlying DLDataType
  operator DLDataType () const {
    return data_;
  }
  /*! \return the maximum possible value in this format. */
  TVM_DLL Expr max() const;
  /*! \return the minimum possible value in this format. */
  TVM_DLL Expr min() const;

 private:
  // the underlying DLPack data type descriptor
  DLDataType data_;
};
/*!
 * \brief Construct an int type.
 * \param bits The number of bits in the type.
 * \param lanes The number of lanes. Defaults to 1 (scalar).
 * \return The constructed data type.
 */
inline DataType Int(int bits, int lanes = 1) {
  return DataType(kDLInt, bits, lanes);
}
/*!
 * \brief Construct an uint type.
 * \param bits The number of bits in the type.
 * \param lanes The number of lanes. Defaults to 1 (scalar).
 * \return The constructed data type.
 */
inline DataType UInt(int bits, int lanes = 1) {
  return DataType(kDLUInt, bits, lanes);
}
/*!
 * \brief Construct a bool type (represented as a 1-bit uint).
 * \param lanes The number of lanes. Defaults to 1 (scalar).
 * \return The constructed data type.
 */
inline DataType Bool(int lanes = 1) {
  return UInt(1, lanes);
}
/*!
 * \brief Construct a float type.
 * \param bits The number of bits in the type.
 * \param lanes The number of lanes. Defaults to 1 (scalar).
 * \return The constructed data type.
 */
inline DataType Float(int bits, int lanes = 1) {
  return DataType(kDLFloat, bits, lanes);
}
/*!
 * \brief Construct a handle (opaque pointer) type.
 * \param bits The number of bits in the type. Defaults to 64 (pointer width).
 * \param lanes The number of lanes. Defaults to 1 (scalar).
 * \return The constructed data type.
 */
inline DataType Handle(int bits = 64, int lanes = 1) {
  return DataType(kHandle, bits, lanes);
}
/*!
 * \brief Get the corresponding type of TVMShapeIndex.
 * \return The type of TVM shape index.
 */
inline DataType TVMShapeIndexType() {
  // tvm_index_t's width decides the bit count; its signedness picks Int vs UInt.
  constexpr int index_bits = sizeof(tvm_index_t) * 8;
  return std::is_signed<tvm_index_t>::value
      ? Int(index_bits)
      : UInt(index_bits);
}
/*!
 * \brief Convert DLDataType to DataType.
 * \param t The original type.
 * \return The conversion result.
 */
inline DataType TVMType2Type(DLDataType t) {
  return DataType(t.code, t.bits, t.lanes);
}
/*!
 * \brief Convert DataType to DLDataType.
 * \param t The original type.
 * \return The conversion result.
 */
inline DLDataType Type2TVMType(DataType t) {
  return t.operator DLDataType();
}
/*!
 * \brief Get the number of bytes needed in a vector.
 * \param dtype The data type.
 * \return Number of bytes needed.
 */
inline int GetVectorBytes(DataType dtype) {
  // boolean vectors are special-cased: stored one byte per value
  if (dtype == Bool()) return 1;
  const int total_bits = dtype.bits() * dtype.lanes();
  CHECK_EQ(total_bits % 8, 0U)
      << "Need to load/store by multiple of bytes";
  return total_bits / 8;
}
// Overload print function for DataType.
inline std::ostream& operator<<(std::ostream& os, DataType dtype) {  // NOLINT(*)
  // bring runtime's stream operator for DLDataType into scope for the call below
  using namespace tvm::runtime;
  return os << dtype.operator DLDataType();
}
// Backward compatibility
using Type = DataType;
} // namespace tvm
#endif // TVM_DTYPE_H_
/*
* Licensed to the Apache Software Foundation (ASF) under one
* or more contributor license agreements. See the NOTICE file
......@@ -25,72 +24,107 @@
#ifndef TVM_EXPR_H_
#define TVM_EXPR_H_
#include <ir/Expr.h>
#include <ir/IRPrinter.h>
#include <string>
#include <algorithm>
#include <unordered_map>
#include "base.h"
#include "dtype.h"
#include "node/container.h"
#include "node/ir_functor.h"
#include "runtime/c_runtime_api.h"
namespace tvm {
using HalideIR::Type;
using HalideIR::Float;
using HalideIR::Bool;
using HalideIR::Int;
using HalideIR::UInt;
using HalideIR::Handle;
using HalideIR::ExprHash;
using HalideIR::ExprEqual;
using HalideIR::Expr;
using HalideIR::VarExpr;
using HalideIR::IR::RangeNode;
using HalideIR::IR::FunctionRef;
using HalideIR::IR::FunctionBaseNode;
using HalideIR::Internal::IntImm;
using HalideIR::Internal::Stmt;
using HalideIR::Internal::IRPrinter;
using HalideIR::Internal::Variable;
inline Type TVMShapeIndexType() {
if (std::is_signed<tvm_index_t>::value) {
return Int(sizeof(tvm_index_t) * 8);
} else {
return UInt(sizeof(tvm_index_t) * 8);
/*! \brief Base node of all expressions. */
class ExprNode : public Node {
public:
/*! \brief The data type of the expression. */
DataType type;
static constexpr const char* _type_key = "Expr";
TVM_DECLARE_BASE_NODE_INFO(ExprNode, Node);
};
/*! \brief Container of all expressions. */
class Expr : public NodeRef {
public:
Expr() {}
explicit Expr(NodePtr<Node> ptr) : NodeRef(ptr) {}
/*!
* \brief construct from integer.
* \param value The value to be constructed.
*/
TVM_DLL Expr(int32_t value); // NOLINT(*)
/*!
* \brief construct from float.
* \param value The value to be constructed.
*/
TVM_DLL Expr(float value); // NOLINT(*)
/*!
* \brief construct from string.
* \param str The value to be constructed.
*/
TVM_DLL Expr(std::string str); // NOLINT(*)
/*! \return the data type of this expression. */
DataType type() const {
return static_cast<const ExprNode*>(get())->type;
}
}
inline Type TVMType2Type(TVMType t) {
return Type(static_cast<halideir_type_code_t>(t.code), t.bits, t.lanes);
}
/*! \brief type indicate the container type */
using ContainerType = ExprNode;
};
inline TVMType Type2TVMType(Type t) {
TVMType ret;
ret.code = static_cast<uint8_t>(t.code());
ret.bits = static_cast<uint8_t>(t.bits());
ret.lanes = static_cast<uint16_t>(t.lanes());
return ret;
}
/*! \brief Base node of all statements. */
class StmtNode : public Node {
public:
static constexpr const char* _type_key = "Stmt";
TVM_DECLARE_BASE_NODE_INFO(StmtNode, Node);
};
// Get number of bytes considering vector type.
inline int GetVectorBytes(Type dtype) {
int data_bits = dtype.bits() * dtype.lanes();
// allow bool to exist
if (dtype == Bool()) return 1;
CHECK_EQ(data_bits % 8, 0U)
<< "Need to load/store by multiple of bytes";
return data_bits / 8;
}
/*! \brief Container of all statements */
class Stmt : public NodeRef {
public:
TVM_DEFINE_NODE_REF_METHODS(Stmt, NodeRef, StmtNode);
};
class Var;
/*!
* \brief A variable node in the IR.
*
 * A variable is uniquely identified by its address.
 *
 * Each variable is only bound once in the following nodes:
* - Allocate
* - For
* - Let
* - LetStmt
*/
class Variable : public ExprNode {
public:
/*!
* \brief The hint to the variable name.
* \note Each variable is uniquely identified by its address.
*/
std::string name_hint;
static Var make(DataType dtype, std::string name_hint);
void VisitAttrs(AttrVisitor* v) final {
v->Visit("dtype", &type);
v->Visit("name", &name_hint);
}
static constexpr const char* _type_key = "Variable";
TVM_DECLARE_NODE_TYPE_INFO(Variable, ExprNode);
};
/*! \brief a named variable in TVM */
class Var : public HalideIR::VarExpr {
class Var : public Expr {
public:
EXPORT explicit Var(const std::string& name_hint = "v",
Type t = Int(32)) : VarExpr(name_hint, t) {}
explicit Var(NodePtr<Node> n) : VarExpr(n) {}
explicit Var(VarExpr v) : VarExpr(v) {}
explicit Var(NodePtr<Node> n) : Expr(n) {}
TVM_DLL explicit Var(std::string name_hint = "v",
Type t = Int(32));
/*!
* \brief Make a new copy of var with same type, append suffix
* \param suffix The suffix to be appended.
......@@ -99,10 +133,47 @@ class Var : public HalideIR::VarExpr {
Var copy_with_suffix(const std::string& suffix) const {
return Var((*this)->name_hint + suffix, (*this)->type);
}
/*!
* \brief Get pointer to the internal value.
* \return the corresponding Variable.
*/
const Variable* operator->() const {
return get();
}
/*!
* \brief Get pointer to the internal value.
* \return the corresponding Variable.
*/
const Variable* get() const {
return static_cast<Variable*>(node_.get());
}
/*! \brief type indicate the container type */
using ContainerType = Variable;
};
// Backward compatibility, will be removed later.
using VarExpr = Var;
using BaseExprNode = ExprNode;
using ExprHash = NodeHash;
using ExprEqual = NodeEqual;
class Integer;
/*! \brief ExprNode: constant integer. */
class IntImm : public ExprNode {
public:
/*! \brief the Internal value. */
int64_t value;
void VisitAttrs(AttrVisitor* v) final {
v->Visit("dtype", &type);
v->Visit("value", &value);
}
TVM_DLL static Integer make(DataType t, int64_t value);
static constexpr const char* _type_key = "IntImm";
TVM_DECLARE_NODE_TYPE_INFO(IntImm, ExprNode);
};
/*!
* \brief Container of constant integer (IntImm).
......@@ -148,34 +219,52 @@ class Integer : public Expr {
using ContainerType = IntImm;
};
/*! \brief range over one dimension */
class RangeNode : public Node {
public:
/*! \brief beginning of the node */
Expr min;
/*! \brief the extend of range */
Expr extent;
/*! \brief constructor */
RangeNode() {}
RangeNode(Expr min, Expr extent) : min(min), extent(extent) {}
/*! \brief container class of iteration variable. */
class IterVarNode;
void VisitAttrs(AttrVisitor* v) final {
v->Visit("min", &min);
v->Visit("extent", &extent);
}
/*!
* \brief same as HalideIR::IR::Range
* except it provide an constructor with (begin, end)
*
* \note Traditional Halide's Range have a constructor with
* (begin, extent), which does not match the convention in e.g. python.
* We decided to correct it by removing the constructor in HalideIR,
* and add it back in TVM's range.
*/
class Range : public HalideIR::IR::Range {
static constexpr const char* _type_key = "Range";
TVM_DECLARE_NODE_TYPE_INFO(RangeNode, Node);
};
/*! \brief Range constainer */
class Range : public NodeRef {
public:
/*! \brief constructor */
Range() {}
explicit Range(NodePtr<Node> n) : HalideIR::IR::Range(n) {}
/*!
* \brief constructor by begin and end
* \param begin The begin of the range.
* \param end The end of the range.
*/
TVM_DLL Range(Expr begin, Expr end);
TVM_DLL static Range make_by_min_extent(Expr min, Expr extent);
/*!
* \brief construct a new range with min and extent
* The corresponding constructor is removed,
* because that is counter convention of tradition meaning
* of range(begin, end)
*
* \param min The minimum range.
* \param extent The extent of the range.
*/
static Range make_by_min_extent(Expr min, Expr extent);
// declare range.
TVM_DEFINE_NODE_REF_METHODS(Range, NodeRef, RangeNode);
};
/*! \brief container class of iteration variable. */
class IterVarNode;
using Region = Array<Range>;
/*!
......@@ -289,9 +378,6 @@ TVM_DLL IterVar reduce_axis(Range dom, std::string name = "rv");
using Domain = Array<Range>;
// print functions for expr
TVM_DLL std::ostream& operator<<(std::ostream& os, const NodeRef& n); // NOLINT(*)
/*!
* \brief Dump the node to stderr, used for debug purposes.
* \param node The input node
......@@ -364,7 +450,7 @@ inline const char* IterVarType2String(IterVarType t) {
* \param name_hint The name hint for the expression
* \param t The type of the expression
*/
TVM_DLL Var var(const std::string& name_hint, Type t = Int(32));
TVM_DLL Var var(std::string name_hint, Type t = Int(32));
/*
* \brief Template function to convert Map to unordered_map
......@@ -382,6 +468,32 @@ inline std::unordered_map<K, V> as_unordered_map(const Map<K, V>& dmap) {
}
return ret;
}
// Printer infra.
/*! \brief A Pretty printer class to print the IR. */
class IRPrinter {
public:
/*! \brief The output stream */
std::ostream& stream;
/*! \brief The indentation level. */
int indent{0};
explicit IRPrinter(std::ostream& stream) // NOLINT(*)
: stream(stream) {}
/*! \brief The node to be printed. */
TVM_DLL void Print(const NodeRef& node);
/*! \brief Print indent to the stream */
TVM_DLL void PrintIndent();
// Allow registration to be printer.
using FType = IRFunctor<void(const NodeRef&, IRPrinter *)>;
TVM_DLL static FType& vtable();
};
// default print function for all nodes
inline std::ostream& operator<<(std::ostream& os, const NodeRef& n) { // NOLINT(*)
IRPrinter(os).Print(n);
return os;
}
} // namespace tvm
namespace std {
......
......@@ -25,6 +25,7 @@
#define TVM_IR_MUTATOR_H_
#include <unordered_map>
#include <utility>
#include "expr.h"
#include "ir.h"
#include "tvm/node/ir_functor.h"
......
......@@ -25,7 +25,6 @@
#ifndef TVM_LOWERED_FUNC_H_
#define TVM_LOWERED_FUNC_H_
#include <ir/FunctionBase.h>
#include <string>
#include "base.h"
......@@ -42,7 +41,7 @@ class LoweredFuncNode;
* \brief LoweredFunc represents function after lowering.
* This is the final IR representation before codegen.
*/
class LoweredFunc : public FunctionRef {
class LoweredFunc : public ir::FunctionRef {
public:
LoweredFunc() {}
explicit LoweredFunc(NodePtr<Node> n) : FunctionRef(n) {}
......@@ -66,7 +65,7 @@ enum LoweredFuncType : int {
};
/*! \brief Node container of LoweredFunc */
class LoweredFuncNode : public FunctionBaseNode {
class LoweredFuncNode : public ir::FunctionBaseNode {
public:
/*! \brief The name of the function */
std::string name;
......
/*
* Licensed to the Apache Software Foundation (ASF) under one
* or more contributor license agreements. See the NOTICE file
* distributed with this work for additional information
* regarding copyright ownership. The ASF licenses this file
* to you under the Apache License, Version 2.0 (the
* "License"); you may not use this file except in compliance
* with the License. You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing,
* software distributed under the License is distributed on an
* "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
* KIND, either express or implied. See the License for the
* specific language governing permissions and limitations
* under the License.
*/
/*!
* \file tvm/node/ir_functor.h
* \brief Defines the IRFunctor data structures.
*/
#ifndef TVM_NODE_IR_FUNCTOR_H_
#define TVM_NODE_IR_FUNCTOR_H_
#include <dmlc/logging.h>
#include <string>
#include <vector>
#include <memory>
#include <type_traits>
#include <utility>
#include <functional>
#include "node.h"
namespace tvm {
/*!
* \brief A dynamically dispatched functor on NodeRef in the first argument.
*
* \code
* IRFunctor<std::string (const NodeRef& n, std::string prefix)> tostr;
* tostr.set_dispatch<Add>([](const Add* op, std::string prefix) {
* return prefix + "Add";
* });
 * tostr.set_dispatch<IntImm>([](const IntImm* op, std::string prefix) {
 *   return prefix + "IntImm";
 * });
 *
 * Expr x = make_const(1);
 * Expr y = x + x;
 * // dispatch to IntImm, outputs "MyIntImm"
 * LOG(INFO) << tostr(x, "My");
 * // dispatch to Add, outputs "MyAdd"
 * LOG(INFO) << tostr(y, "My");
* \endcode
*
 * \tparam FType function signature
 * This type is only defined for FType with a function signature of the
 * form R(const NodeRef&, Args...).
 */
template<typename FType>
class IRFunctor;

template<typename R, typename ...Args>
class IRFunctor<R(const NodeRef& n, Args...)> {
 private:
  /*! \brief signature of a stored dispatch function */
  using Function = std::function<R(const NodeRef& n, Args...)>;
  using TSelf = IRFunctor<R(const NodeRef& n, Args...)>;
  /*! \brief internal function table, indexed by node type index */
  std::vector<Function> func_;

 public:
  /*! \brief the result type of this functor */
  using result_type = R;
  /*!
   * \brief Whether the functor can dispatch the corresponding Node
   * \param n The node to be dispatched
   * \return Whether dispatching function is registered for n's type.
   */
  inline bool can_dispatch(const NodeRef& n) const {
    uint32_t type_index = n.type_index();
    return type_index < func_.size() && func_[type_index] != nullptr;
  }
  /*!
   * \brief invoke the functor, dispatch on type of n
   * \param n The Node argument
   * \param args The additional arguments
   * \return The result.
   */
  inline R operator()(const NodeRef& n, Args... args) const {
    uint32_t type_index = n.type_index();
    CHECK(type_index < func_.size() &&
          func_[type_index] != nullptr)
        << "IRFunctor calls un-registered function on type "
        << Node::TypeIndex2Key(type_index);
    return func_[type_index](n, std::forward<Args>(args)...);
  }
  /*!
   * \brief set the dispatcher for type TNode
   * \param f The function to be set.
   * \tparam TNode the type of Node to be dispatched.
   * \return reference to self.
   */
  template<typename TNode>
  inline TSelf& set_dispatch(Function f) {  // NOLINT(*)
    uint32_t tindex = Node::TypeKey2Index(TNode::_type_key);
    if (func_.size() <= tindex) {
      func_.resize(tindex + 1, nullptr);
    }
    CHECK(func_[tindex] == nullptr)
        << "Dispatch for " << Node::TypeIndex2Key(tindex)
        << " is already set";
    // move instead of copy: copying a std::function may allocate
    func_[tindex] = std::move(f);
    return *this;
  }
  /*!
   * \brief set the dispatcher for type TNode
   * This allows f to use a typed const Node pointer instead of NodeRef
   *
   * \param f The function to be set.
   * \tparam TNode the type of Node to be dispatched.
   * \return reference to self.
   */
  template<typename TNode>
  inline TSelf& set_dispatch(std::function<R(const TNode* n, Args...)> f) {  // NOLINT(*)
    // wrap the typed callback: downcast the node before forwarding
    Function fun = [f](const NodeRef& n, Args... args) {
      return f(static_cast<const TNode*>(n.node_.get()),
               std::forward<Args>(args)...);
    };
    return this->set_dispatch<TNode>(fun);
  }
  /*!
   * \brief unset the dispatcher for type TNode
   *
   * \tparam TNode the type of Node to be dispatched.
   * \return reference to self.
   */
  template<typename TNode>
  inline TSelf& clear_dispatch() {  // NOLINT(*)
    uint32_t tindex = Node::TypeKey2Index(TNode::_type_key);
    CHECK_LT(tindex, func_.size()) << "clear_dispatch: index out of range";
    func_[tindex] = nullptr;
    return *this;
  }
};
#if defined(__GNUC__)
#define TVM_ATTRIBUTE_UNUSED __attribute__((unused))
#else
#define TVM_ATTRIBUTE_UNUSED
#endif
/*! \brief helper macro to generate string concat */
#define TVM_STR_CONCAT_(__x, __y) __x##__y
#define TVM_STR_CONCAT(__x, __y) TVM_STR_CONCAT_(__x, __y)
#define TVM_REGISTER_VAR_DEF(ClsName) \
static TVM_ATTRIBUTE_UNUSED auto & __make_functor ## _ ## ClsName
/*!
* \brief Useful macro to set IRFunctor dispatch in a global static field.
*
* \code
* // Use IRFunctor to implement IRPrinter similar to Visitor Pattern.
* // vtable allows easy patch in of new Node types, without changing
* // interface of IRPrinter.
*
* class IRPrinter {
* public:
* std::ostream& stream;
* // the dispatch function.
* void print(Expr e) {
 * static const FType& f = vtable();
* f(e, this);
* }
*
* using FType = IRFunctor<void (const NodeRef&, IRPrinter *)>;
* // function to return global function table
* static FType& vtable();
* };
*
* // in cpp/cc file
* IRPrinter::FType& IRPrinter::vtable() { // NOLINT(*)
* static FType inst; return inst;
* }
*
* TVM_STATIC_IR_FUNCTOR(IRPrinter, vtable)
* .set_dispatch<Add>([](const Add* n, IRPrinter* p) {
* p->print(n->a);
 * p->stream << '+';
* p->print(n->b);
* });
*
*
* \endcode
*
* \param ClsName The name of the class
* \param FField The static function that returns a singleton of IRFunctor.
*/
#define TVM_STATIC_IR_FUNCTOR(ClsName, FField) \
TVM_STR_CONCAT(TVM_REGISTER_VAR_DEF(ClsName), __COUNTER__) = \
ClsName::FField()
/*!
 * \brief A container for a list of callbacks. All callbacks are invoked when
 * the object is destructed.
 */
class IRFunctorCleanList {
 public:
  /*! \brief Invoke every registered callback, in registration order. */
  ~IRFunctorCleanList() {
    for (auto& f : clean_items) {
      f();
    }
  }

  /*!
   * \brief Register a callback to be run at destruction time.
   * \param func The callback.
   */
  void append(std::function<void()> func) {
    // take by value and move in: avoids copying the std::function's
    // captured state a second time on insertion
    clean_items.push_back(std::move(func));
  }

 private:
  /*! \brief callbacks pending invocation at destruction */
  std::vector<std::function<void()> > clean_items;
};
/*!
* \brief A wrapper around IRFunctor that will record calls to set_dispatch
* and make a corresponding call to clear_dispatch when the last copy of
* the IRFunctorStaticRegistry is destructed. When assigned to a static variable,
* this can be used by NNVM and other libraries to unregister callbacks when
* the library is unloaded. This prevents crashes when the underlying IRFunctor
* is destructed as it will no longer contain std::function instances allocated
* by a library that has been unloaded.
*/
template<typename FType>
class IRFunctorStaticRegistry;
template<typename R, typename ...Args>
class IRFunctorStaticRegistry<R(const NodeRef& n, Args...)> {
private:
IRFunctor<R(const NodeRef& n, Args...)> *irf_;
std::shared_ptr<IRFunctorCleanList> free_list;
using TSelf = IRFunctorStaticRegistry<R(const NodeRef& n, Args...)>;
public:
IRFunctorStaticRegistry(IRFunctor<R(const NodeRef& n, Args...)> *irf) {
irf_ = irf;
free_list = std::make_shared<IRFunctorCleanList>();
}
template<typename TNode>
inline TSelf& set_dispatch(std::function<R(const TNode* n, Args...)> f) { // NOLINT(*)
irf_->template set_dispatch<TNode>(f);
auto irf_copy = irf_;
free_list.get()->append([irf_copy] {
irf_copy->template clear_dispatch<TNode>();
});
return *this;
}
};
/*!
 * \brief Helper function for constructing an IRFunctorStaticRegistry. This allows
 * the compiler to deduce the template types.
 * \param irf Pointer to the IRFunctor to wrap (not owned by the registry).
 * \return The registry wrapping irf.
 */
template<typename R, typename ...Args>
IRFunctorStaticRegistry<R(const NodeRef& n, Args...)> MakeIRFunctorStaticRegistry(
    IRFunctor<R(const NodeRef& n, Args...)> *irf) {
  return IRFunctorStaticRegistry<R(const NodeRef& n, Args...)>(irf);
}
#define TVM_AUTO_REGISTER_VAR_DEF(ClsName) \
static TVM_ATTRIBUTE_UNUSED auto __make_functor ## _ ## ClsName
/*!
* \brief Macro to set IRFunctor dispatch in a global static field using an IRFunctorStaticRegistry.
* Usage is exactly the same as TVM_STATIC_IR_FUNCTOR. Libraries should use this instead of
* TVM_STATIC_IR_FUNCTOR.
*/
#define TVM_STATIC_IR_FUNCTOR_REGISTER(ClsName, FField) \
TVM_STR_CONCAT(TVM_AUTO_REGISTER_VAR_DEF(ClsName), __COUNTER__) = \
MakeIRFunctorStaticRegistry(&ClsName::FField())
} // namespace tvm
#endif // TVM_NODE_IR_FUNCTOR_H_
/*
* Licensed to the Apache Software Foundation (ASF) under one
* or more contributor license agreements. See the NOTICE file
* distributed with this work for additional information
* regarding copyright ownership. The ASF licenses this file
* to you under the Apache License, Version 2.0 (the
* "License"); you may not use this file except in compliance
* with the License. You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing,
* software distributed under the License is distributed on an
* "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
* KIND, either express or implied. See the License for the
* specific language governing permissions and limitations
* under the License.
*/
/*!
* \file tvm/node/memory.h
* \brief Node memory management.
*/
#ifndef TVM_NODE_MEMORY_H_
#define TVM_NODE_MEMORY_H_
#include <utility>
#include "node.h"
namespace tvm {
/*!
* \brief Allocate a node object.
* \param args arguments to the constructor.
* \tparam T the node type.
* \return The NodePtr to the allocated object.
*/
template<typename T, typename... Args>
inline NodePtr<T> make_node(Args&&... args);
// Detail implementations after this
//
// The current design allows swapping the
// allocator pattern when necessary.
//
// Possible future allocator optimizations:
// - Arena allocator that gives ownership of memory to arena (deleter_= nullptr)
// - Thread-local object pools: one pool per size and alignment requirement.
// - Can specialize by type of object to give the specific allocator to each object.
//
/*!
 * \brief Default node allocator: plain operator new/delete per object.
 * \tparam T the concrete node type being allocated.
 */
template<typename T>
class SimpleNodeAllocator {
 public:
  /*!
   * \brief Allocate and construct a T.
   * \param args Arguments forwarded to T's constructor.
   * \return Raw pointer to the newly constructed node.
   */
  template<typename... Args>
  static T* New(Args&&... args) {
    return new T(std::forward<Args>(args)...);
  }
  /*! \return The deleter that matches this allocator's New. */
  static NodeBase::FDeleter Deleter() {
    return Deleter_;
  }

 private:
  // type-erased deleter: restores the concrete type before delete so
  // the correct destructor runs
  static void Deleter_(NodeBase* ptr) {
    delete static_cast<T*>(ptr);
  }
};
/*!
 * \brief Allocate a node object and attach its deleter.
 * \param args Arguments forwarded to T's constructor.
 * \tparam T the node type; must derive from NodeBase.
 * \return The NodePtr owning the allocated object.
 */
template<typename T, typename... Args>
inline NodePtr<T> make_node(Args&&... args) {
  using Allocator = SimpleNodeAllocator<T>;
  static_assert(std::is_base_of<NodeBase, T>::value,
                "make_node can only be used to create types derived from NodeBase");
  T* node = Allocator::New(std::forward<Args>(args)...);
  // install the matching deleter so NodePtr knows how to free this allocation
  node->deleter_ = Allocator::Deleter();
  return NodePtr<T>(node);
}
} // namespace tvm
#endif // TVM_NODE_MEMORY_H_
......@@ -53,7 +53,7 @@ struct TensorDom {
/*!
* \brief Base class of all operation nodes
*/
class OperationNode : public FunctionBaseNode {
class OperationNode : public ir::FunctionBaseNode {
public:
/*! \brief optional name of the operation */
std::string name;
......@@ -463,7 +463,7 @@ class ExternOpNode : public OperationNode {
v->Visit("output_placeholders", &output_placeholders);
v->Visit("body", &body);
}
EXPORT static Operation make(std::string name,
TVM_DLL static Operation make(std::string name,
std::string tag,
Map<std::string, NodeRef> attrs,
Array<Tensor> inputs,
......@@ -530,12 +530,12 @@ class HybridOpNode : public OperationNode {
v->Visit("axis", &axis);
v->Visit("body", &body);
}
EXPORT static Operation make(std::string name,
std::string tag,
Map<std::string, NodeRef> attrs,
Array<Tensor> inputs,
Array<Tensor> outputs,
Stmt body);
TVM_DLL static Operation make(std::string name,
std::string tag,
Map<std::string, NodeRef> attrs,
Array<Tensor> inputs,
Array<Tensor> outputs,
Stmt body);
static constexpr const char* _type_key = "HybridOp";
TVM_DECLARE_NODE_TYPE_INFO(HybridOpNode, OperationNode);
......
......@@ -70,7 +70,9 @@ struct NodeTypeChecker<Array<T> > {
if (!sptr->is_type<ArrayNode>()) return false;
ArrayNode* n = static_cast<ArrayNode*>(sptr);
for (const auto& p : n->data) {
if (!NodeTypeChecker<T>::Check(p.get())) return false;
if (!NodeTypeChecker<T>::Check(p.get())) {
return false;
}
}
return true;
}
......@@ -144,7 +146,7 @@ inline TNodeRef TVMArgValue::AsNodeRef() const {
return TNodeRef(sptr);
}
inline TVMArgValue::operator HalideIR::Expr() const {
inline TVMArgValue::operator tvm::Expr() const {
if (type_code_ == kNull) return Expr();
if (type_code_ == kDLInt) {
CHECK_LE(value_.v_int64, std::numeric_limits<int>::max());
......@@ -240,21 +242,21 @@ inline void TVMArgsSetter::operator()(size_t i, const NodeRef& other) const { /
}
// type related stuffs
inline TVMRetValue& TVMRetValue::operator=(const HalideIR::Type& t) {
return this->operator=(Type2TVMType(t));
inline TVMRetValue& TVMRetValue::operator=(const DataType& t) {
return this->operator=(t.operator DLDataType());
}
inline TVMRetValue::operator HalideIR::Type() const {
return TVMType2Type(operator TVMType());
inline TVMRetValue::operator tvm::DataType() const {
return DataType(operator DLDataType());
}
inline TVMArgValue::operator HalideIR::Type() const {
return TVMType2Type(operator TVMType());
inline TVMArgValue::operator tvm::DataType() const {
return DataType(operator DLDataType());
}
inline void TVMArgsSetter::operator()(
size_t i, const HalideIR::Type& t) const {
this->operator()(i, Type2TVMType(t));
size_t i, const DataType& t) const {
this->operator()(i, t.operator DLDataType());
}
} // namespace runtime
} // namespace tvm
......
......@@ -42,14 +42,6 @@
#include "object.h"
#include "node_base.h"
namespace HalideIR {
// Forward declare type for extensions
// The header works fine without depending on this.
struct Type;
struct Expr;
}
// Whether use TVM runtime in header only mode.
#ifndef TVM_RUNTIME_HEADER_ONLY
#define TVM_RUNTIME_HEADER_ONLY 0
......@@ -58,6 +50,8 @@ struct Expr;
namespace tvm {
// forward declarations
class Integer;
class DataType;
class Expr;
namespace runtime {
......@@ -626,8 +620,8 @@ class TVMArgValue : public TVMPODValue_ {
typename = typename std::enable_if<
std::is_class<TNodeRef>::value>::type>
inline bool IsNodeType() const;
inline operator HalideIR::Type() const;
inline operator HalideIR::Expr() const;
inline operator tvm::DataType() const;
inline operator tvm::Expr() const;
inline operator tvm::Integer() const;
// get internal node ptr, if it is node
inline NodePtr<Node>& node_sptr();
......@@ -835,8 +829,8 @@ class TVMRetValue : public TVMPODValue_ {
inline TVMRetValue& operator=(const NodeRef& other);
inline TVMRetValue& operator=(const NodePtr<Node>& other);
// type related
inline operator HalideIR::Type() const;
inline TVMRetValue& operator=(const HalideIR::Type& other);
inline operator tvm::DataType() const;
inline TVMRetValue& operator=(const tvm::DataType& other);
private:
template<typename T>
......@@ -1184,7 +1178,7 @@ class TVMArgsSetter {
inline void operator()(size_t i, const T& value) const;
// NodeRef related extenstions: in tvm/packed_func_ext.h
inline void operator()(size_t i, const NodeRef& other) const; // NOLINT(*)
inline void operator()(size_t i, const HalideIR::Type& t) const;
inline void operator()(size_t i, const tvm::DataType& t) const;
private:
/*! \brief The values fields */
......
......@@ -70,7 +70,7 @@ void AutoInlineElemWise(Schedule sch);
*
* \param sch The schedule to be inlined.
*/
EXPORT void AutoInlineInjective(Schedule sch);
TVM_DLL void AutoInlineInjective(Schedule sch);
} // namespace schedule
} // namespace tvm
......
......@@ -24,7 +24,6 @@
#ifndef TVM_TENSOR_H_
#define TVM_TENSOR_H_
#include <ir/FunctionBase.h>
#include <tvm/node/container.h>
#include <string>
#include <vector>
......@@ -43,8 +42,6 @@ class TensorNode;
// internal node container for Operation
class OperationNode;
using HalideIR::IR::FunctionRef;
/*!
* \brief Tensor structure representing a possible input,
* or intermediate computation result.
......@@ -140,7 +137,7 @@ class Tensor : public NodeRef {
};
/*! \brief Operation that produces tensors */
class Operation : public FunctionRef {
class Operation : public ir::FunctionRef {
public:
/*! \brief default constructor */
Operation() {}
......
......@@ -6,9 +6,9 @@
* to you under the Apache License, Version 2.0 (the
* "License"); you may not use this file except in compliance
* with the License. You may obtain a copy of the License at
*
*
* http://www.apache.org/licenses/LICENSE-2.0
*
*
* Unless required by applicable law or agreed to in writing,
* software distributed under the License is distributed on an
* "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
......@@ -59,14 +59,13 @@ TVM_REGISTER_API("make._range_by_min_extent")
TVM_REGISTER_API("make.For")
.set_body_typed<Stmt(VarExpr, Expr, Expr, int, int, Stmt)>([](
VarExpr loop_var, Expr min, Expr extent,
int for_type, int device_api, Stmt body
) {
int for_type, int device_api, Stmt body) {
return For::make(loop_var,
min,
extent,
static_cast<ForType>(for_type),
static_cast<HalideIR::DeviceAPI>(device_api),
body);
min,
extent,
static_cast<ForType>(for_type),
static_cast<DeviceAPI>(device_api),
body);
});
TVM_REGISTER_API("make.Load")
......
......@@ -6,9 +6,9 @@
* to you under the Apache License, Version 2.0 (the
* "License"); you may not use this file except in compliance
* with the License. You may obtain a copy of the License at
*
*
* http://www.apache.org/licenses/LICENSE-2.0
*
*
* Unless required by applicable law or agreed to in writing,
* software distributed under the License is distributed on an
* "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
......@@ -18,7 +18,6 @@
*/
/*!
* Copyright (c) 2016 by Contributors
* Implementation of API functions related to Higher DSL build.
* \file api_lang.cc
*/
......@@ -36,10 +35,10 @@
namespace tvm {
TVM_REGISTER_API("_min_value")
.set_body_method(&Type::min);
.set_body_method(&DataType::min);
TVM_REGISTER_API("_max_value")
.set_body_method(&Type::max);
.set_body_method(&DataType::max);
TVM_REGISTER_API("_const")
.set_body([](TVMArgs args, TVMRetValue* ret) {
......
......@@ -52,12 +52,6 @@ class CanonicalExprNode : public BaseExprNode {
// overrides
void VisitAttrs(tvm::AttrVisitor* v) final {
}
void accept(HalideIR::Internal::IRVisitor* v, const Expr& e) const final {
LOG(FATAL) << "not supported";
}
IRNodeType type_info() const final {
return IRNodeType::ExtensionExpr;
}
static constexpr const char* _type_key = "arith.CanonicalExpr";
TVM_DECLARE_BASE_NODE_INFO(CanonicalExprNode, BaseExprNode);
......
......@@ -125,7 +125,7 @@ class ConstIntBoundAnalyzer::Impl :
// Override visitor behaviors
Entry VisitExprDefault_(const Node* op) final {
return Everything(
static_cast<const ir::BaseExprNode*>(op)->type);
static_cast<const ExprNode*>(op)->type);
}
Entry VisitExpr(const Expr& expr) final {
......
......@@ -6,9 +6,9 @@
* to you under the Apache License, Version 2.0 (the
* "License"); you may not use this file except in compliance
* with the License. You may obtain a copy of the License at
*
*
* http://www.apache.org/licenses/LICENSE-2.0
*
*
* Unless required by applicable law or agreed to in writing,
* software distributed under the License is distributed on an
* "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
......@@ -224,15 +224,15 @@ std::string CodeGenOpenGL::GetBufferRef(
void CodeGenOpenGL::PrintType(Type t, std::ostream& os) {
switch (t.code()) {
case halideir_type_int:
case kDLInt:
CHECK_EQ(t.bits(), 32) << "Only support 32-bit int.";
os << "int";
break;
case halideir_type_uint:
case kDLUInt:
CHECK_EQ(t.bits(), 32) << "Only support 32-bit uint.";
os << "uint";
break;
case halideir_type_float:
case kDLFloat:
CHECK_EQ(t.bits(), 32) << "Only support 32-bit float.";
os << "float";
break;
......
......@@ -6,9 +6,9 @@
* to you under the Apache License, Version 2.0 (the
* "License"); you may not use this file except in compliance
* with the License. You may obtain a copy of the License at
*
*
* http://www.apache.org/licenses/LICENSE-2.0
*
*
* Unless required by applicable law or agreed to in writing,
* software distributed under the License is distributed on an
* "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
......@@ -18,19 +18,94 @@
*/
/*!
* Copyright (c) 2016 by Contributors
* \file expr.cc
*/
#include <tvm/base.h>
#include <tvm/expr.h>
#include <tvm/ir.h>
#include <tvm/expr_operator.h>
#include <ir/IRPrinter.h>
#include <memory>
#include <limits>
namespace tvm {
using HalideIR::IR::RangeNode;
// maximum and min values
/*!
 * \brief Construct the largest representable value of this scalar type
 *        as a constant expression.
 * \return An IntImm/UIntImm/FloatImm holding the maximum value.
 * \note Aborts (LOG(FATAL)) on vector types or unsupported bit widths.
 */
Expr DataType::max() const {
  using namespace ir;
  CHECK_EQ(lanes(), 1);  // only scalar types have a single max value
  if (is_int()) {
    if (bits() == 64) {
      return IntImm::make(*this, std::numeric_limits<int64_t>::max());
    } else if (bits() < 64) {
      // 2^(bits-1) - 1, computed in int64 to avoid overflow for bits < 64.
      int64_t val = 1;
      val = (val << (bits() - 1)) - 1;
      return IntImm::make(*this, val);
    }
  } else if (is_uint()) {
    if (bits() == 64) {
      return UIntImm::make(*this, std::numeric_limits<uint64_t>::max());
    } else if (bits() < 64) {
      // 2^bits - 1
      uint64_t val = 1;
      val = (val << static_cast<uint64_t>(bits())) - 1;
      return UIntImm::make(*this, val);
    }
  } else if (is_float()) {
    if (bits() == 64) {
      return FloatImm::make(*this, std::numeric_limits<double>::max());
    } else if (bits() == 32) {
      return FloatImm::make(*this, std::numeric_limits<float>::max());
    } else if (bits() == 16) {
      // IEEE 754 binary16 maximum finite value.
      return FloatImm::make(*this, 65504.0);
    }
  }
  // Fix: original message lacked a separating space before the type,
  // producing e.g. "for typeint32".
  LOG(FATAL) << "Cannot decide max_value for type " << *this;
  return Expr();
}
/*!
 * \brief Construct the smallest representable value of this scalar type
 *        as a constant expression.
 * \return An IntImm/UIntImm/FloatImm holding the minimum value.
 * \note Aborts (LOG(FATAL)) on vector types or unsupported bit widths.
 */
Expr DataType::min() const {
  using namespace ir;
  CHECK_EQ(lanes(), 1);  // only scalar types have a single min value
  if (is_int()) {
    if (bits() == 64) {
      return IntImm::make(*this, std::numeric_limits<int64_t>::lowest());
    } else if (bits() < 64) {
      // -2^(bits-1)
      int64_t val = 1;
      val = -(val << (bits() - 1));
      return IntImm::make(*this, val);
    }
  } else if (is_uint()) {
    // Unsigned minimum is always zero regardless of width.
    return UIntImm::make(*this, 0);
  } else if (is_float()) {
    if (bits() == 64) {
      return FloatImm::make(*this, std::numeric_limits<double>::lowest());
    } else if (bits() == 32) {
      return FloatImm::make(*this, std::numeric_limits<float>::lowest());
    } else if (bits() == 16) {
      // Negative of the IEEE 754 binary16 maximum finite value.
      return FloatImm::make(*this, -65504.0);
    }
  }
  // Fix: original message lacked a separating space before the type,
  // producing e.g. "for typeint32".
  LOG(FATAL) << "Cannot decide min_value for type " << *this;
  return Expr();
}
// Converting constructors: wrap POD values as constant IR immediates so that
// plain C++ literals can be used directly where an Expr is expected.
Expr::Expr(int32_t value)
    : Expr(IntImm::make(Int(32), value)) {}

Expr::Expr(float value)
    : Expr(ir::FloatImm::make(Float(32), value)) {}

Expr::Expr(std::string str)
    : Expr(ir::StringImm::make(str)) {}

// A Var is a handle to a freshly allocated Variable node.
Var::Var(std::string name_hint, DataType t)
    : Var(Variable::make(t, name_hint)) {}

// Allocate a new Variable node with the given dtype and name hint.
// Each call creates a distinct variable, even for identical name hints.
Var Variable::make(DataType t, std::string name_hint) {
  NodePtr<Variable> node = make_node<Variable>();
  node->type = t;
  node->name_hint = std::move(name_hint);
  return Var(node);
}
Range::Range(Expr begin, Expr end)
: Range(make_node<RangeNode>(
......@@ -38,12 +113,23 @@ Range::Range(Expr begin, Expr end)
is_zero(begin) ? end : (end - begin))) {
}
// Create a signed integer immediate of type t holding `value`.
// Only scalar signed integer types are accepted.
Integer IntImm::make(Type t, int64_t value) {
  CHECK(t.is_int() && t.is_scalar())
      << "ValueError: IntImm can only take scalar.";
  NodePtr<IntImm> n = make_node<IntImm>();
  n->value = value;
  n->type = t;
  return Integer(n);
}
Range Range::make_by_min_extent(Expr min, Expr extent) {
return Range(make_node<HalideIR::IR::RangeNode>(min, extent));
return Range(make_node<RangeNode>(min, extent));
}
IterVar IterVarNode::make(Range dom, Var var,
IterVarType t, std::string thread_tag) {
IterVar IterVarNode::make(Range dom,
Var var,
IterVarType t,
std::string thread_tag) {
NodePtr<IterVarNode> n = make_node<IterVarNode>();
n->dom = dom;
n->var = var;
......@@ -62,19 +148,48 @@ IterVar reduce_axis(Range dom, std::string name) {
dom, Var(name), kCommReduce);
}
// Pretty-print any NodeRef to a stream through the IRPrinter dispatch table.
std::ostream& operator<<(std::ostream& os, const NodeRef& n) {  // NOLINT(*)
  // Consistency fix: the printer entry point defined in this file is
  // IRPrinter::Print (capital P); the lowercase `print` call was a leftover
  // from the pre-rename API.
  IRPrinter(os).Print(n);
  return os;
}
// Dump a node to stderr; convenient to call from a debugger.
void Dump(const NodeRef& n) {
  std::cerr << n << "\n";
}
Var var(const std::string& name_hint, Type t) {
Var var(std::string name_hint, Type t) {
return Var(name_hint, t);
}
// Print an IR node by dispatching through the static functor table.
// Undefined refs print as "(nullptr)"; nodes without a registered printer
// fall back to "<type_key>(<address>)".
void IRPrinter::Print(const NodeRef& ir) {
  static const FType& f = vtable();
  if (!ir.defined()) {
    stream << "(nullptr)";
    return;
  }
  if (f.can_dispatch(ir)) {
    f(ir, this);
  } else {
    // No printer registered: show the type key and the node address.
    stream << ir->type_key() << "(" << ir.get() << ")";
  }
}
// Emit one space per current indentation level.
void IRPrinter::PrintIndent() {
  int remaining = indent;
  while (remaining-- > 0) {
    stream << ' ';
  }
}
// Accessor for the printer dispatch table (function-local static, so it is
// initialized on first use and shared across all IRPrinter instances).
IRPrinter::FType& IRPrinter::vtable() {
  static FType inst;
  return inst;
}
// Printer for integer immediates: Int(32) values print bare (the common
// case), any other integer dtype is prefixed with "(dtype)".
TVM_STATIC_IR_FUNCTOR(IRPrinter, vtable)
.set_dispatch<IntImm>([](const IntImm *op, IRPrinter *p) {
    if (op->type == Int(32)) {
      p->stream << op->value;
    } else {
      p->stream << "(" << op->type << ")" << op->value;
    }
  });
TVM_STATIC_IR_FUNCTOR(IRPrinter, vtable)
.set_dispatch<IterVarNode>([](const IterVarNode *op, IRPrinter *p) {
p->stream << "iter_var(";
......@@ -91,11 +206,10 @@ TVM_STATIC_IR_FUNCTOR(IRPrinter, vtable)
});
TVM_STATIC_IR_FUNCTOR(IRPrinter, vtable)
.set_dispatch<RangeNode>([](const HalideIR::IR::RangeNode *op, IRPrinter *p) {
.set_dispatch<RangeNode>([](const RangeNode* op, IRPrinter* p) {
p->stream << "range(min=" << op->min << ", ext=" << op->extent << ')';
});
TVM_REGISTER_NODE_TYPE(ArrayNode);
TVM_REGISTER_NODE_TYPE(MapNode);
TVM_REGISTER_NODE_TYPE(StrMapNode);
......
......@@ -24,26 +24,23 @@
#include <tvm/tensor.h>
#include <tvm/operation.h>
#include <tvm/tensor_intrin.h>
#include <ir/IR.h>
#include <memory>
namespace tvm {
// Tensor
Expr Tensor::operator()(Array<Var> indices) const {
Array<Expr> arr(indices.begin(), indices.end());
return operator()(arr);
}
Expr Tensor::operator()(Array<Expr> indices) const {
using HalideIR::Internal::Call;
using ir::Call;
if (ndim() != 0) {
CHECK_EQ(ndim(), indices.size())
<< "Tensor dimension mismatch in read"
<< "ndim = " << ndim() << ", indices.size=" << indices.size();
}
auto n = Call::make(
(*this)->dtype, (*this)->op->name, indices, Call::Halide,
(*this)->op, (*this)->value_index);
......
/*
* Licensed to the Apache Software Foundation (ASF) under one
* or more contributor license agreements. See the NOTICE file
* distributed with this work for additional information
* regarding copyright ownership. The ASF licenses this file
* to you under the Apache License, Version 2.0 (the
* "License"); you may not use this file except in compliance
* with the License. You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing,
* software distributed under the License is distributed on an
* "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
* KIND, either express or implied. See the License for the
* specific language governing permissions and limitations
* under the License.
*/
/*!
* Implementation of Node API
* \file node.cc
*/
#include <tvm/node/node.h>
#include <atomic>
#include <memory>
#include <mutex>
#include <string>
#include <unordered_map>
#include <vector>
// TODO(tqchen):
// Think of re-organize and consolidate with object.
namespace tvm {
namespace {
// Single process-wide manager of node type information.
// Maps string type keys to dense 1-based integer indices and back.
struct TypeManager {
  // Guards key2index/index2key against concurrent registration.
  std::mutex mutex;
  // Monotonically increasing counter handing out new type indices.
  std::atomic<uint32_t> type_counter{0};
  // Forward map: type key string -> assigned (1-based) index.
  std::unordered_map<std::string, uint32_t> key2index;
  // Reverse map: (index - 1) -> type key string.
  std::vector<std::string> index2key;
  // Get the process-wide singleton instance.
  static TypeManager* Global() {
    static TypeManager inst;
    return &inst;
  }
};
} // namespace
// Whether this node's type derives from the type with index `tid`.
// Node is the root of the hierarchy, so at this level only Node's own
// index matches; subclasses override/extend the check.
TVM_DLL const bool Node::_DerivedFrom(uint32_t tid) const {
  // Cache Node's own index; TypeKey2Index is comparatively slow.
  static uint32_t tindex = TypeKey2Index(Node::_type_key);
  return tid == tindex;
}
// this is slow, usually caller always hold the result in a static variable.
/*!
 * \brief Translate a type key string into its dense integer index,
 *        registering the key on first use.
 * \param key The type key (e.g. Node::_type_key).
 * \return The 1-based index assigned to the key (0 is reserved).
 */
TVM_DLL uint32_t Node::TypeKey2Index(const char* key) {
  TypeManager *t = TypeManager::Global();
  // Bug fix: `std::lock_guard<std::mutex>(t->mutex);` created an *unnamed
  // temporary* that was destroyed at the end of that statement, so the
  // map lookups/inserts below ran without the lock held. Naming the guard
  // keeps the mutex locked for the rest of the function.
  std::lock_guard<std::mutex> lock(t->mutex);
  std::string skey = key;
  auto it = t->key2index.find(skey);
  if (it != t->key2index.end()) {
    return it->second;
  }
  // First time this key is seen: assign the next index (indices start at 1).
  uint32_t tid = ++(t->type_counter);
  t->key2index[skey] = tid;
  t->index2key.push_back(skey);
  return tid;
}
/*!
 * \brief Translate a registered type index back to its type key string.
 * \param index A 1-based index previously returned by TypeKey2Index.
 * \return The corresponding type key; aborts if index is 0 or unregistered.
 */
TVM_DLL const char* Node::TypeIndex2Key(uint32_t index) {
  TypeManager *t = TypeManager::Global();
  // Bug fix: the unnamed temporary `std::lock_guard<std::mutex>(t->mutex);`
  // unlocked immediately, leaving the read of index2key unprotected.
  // Naming the guard holds the mutex for the whole function.
  std::lock_guard<std::mutex> lock(t->mutex);
  CHECK_NE(index, 0);  // 0 is reserved as "unregistered"
  // NOTE(review): this returns c_str() of a std::string stored in a vector
  // that may reallocate on later registrations; for short (SSO) strings the
  // returned pointer could dangle. Callers appear to use it immediately —
  // confirm, or switch index2key to a node-stable container.
  return t->index2key.at(index - 1).c_str();
}
} // namespace tvm
......@@ -6,9 +6,9 @@
* to you under the Apache License, Version 2.0 (the
* "License"); you may not use this file except in compliance
* with the License. You may obtain a copy of the License at
*
*
* http://www.apache.org/licenses/LICENSE-2.0
*
*
* Unless required by applicable law or agreed to in writing,
* software distributed under the License is distributed on an
* "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
......@@ -18,7 +18,6 @@
*/
/*!
* Copyright (c) 2017 by Contributors
* \brief Compute Op.
* \file compute_op.cc
*/
......@@ -250,7 +249,7 @@ Stmt BaseComputeOpNode::BuildRealize(
const std::unordered_map<IterVar, Range>& realize_map,
const Stmt& body) const {
CHECK_EQ(stage->op.get(), this);
HalideIR::Internal::Region bounds;
Region bounds;
for (IterVar iv : this->axis) {
bounds.push_back(realize_map.at(iv));
}
......
......@@ -6,9 +6,9 @@
* to you under the Apache License, Version 2.0 (the
* "License"); you may not use this file except in compliance
* with the License. You may obtain a copy of the License at
*
*
* http://www.apache.org/licenses/LICENSE-2.0
*
*
* Unless required by applicable law or agreed to in writing,
* software distributed under the License is distributed on an
* "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
......@@ -18,7 +18,6 @@
*/
/*!
* Copyright (c) 2017 by Contributors
* \brief External computation rule.
* \file extern_op.cc
*/
......@@ -140,7 +139,7 @@ Stmt ExternOpNode::BuildRealize(
Stmt realize_body = body;
for (int k = 0; k < num_outputs(); ++k) {
Tensor t = stage->op.output(k);
HalideIR::Internal::Region bounds;
Region bounds;
for (size_t i = 0; i < t->shape.size(); ++i) {
bounds.push_back(
Range::make_by_min_extent(
......
......@@ -6,9 +6,9 @@
* to you under the Apache License, Version 2.0 (the
* "License"); you may not use this file except in compliance
* with the License. You may obtain a copy of the License at
*
*
* http://www.apache.org/licenses/LICENSE-2.0
*
*
* Unless required by applicable law or agreed to in writing,
* software distributed under the License is distributed on an
* "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
......@@ -18,7 +18,6 @@
*/
/*!
* Copyright (c) 2019 by Contributors
* \brief Hybrid computation rule.
* \file hybrid_op.cc
*/
......@@ -28,7 +27,6 @@
#include <tvm/ir_mutator.h>
#include <tvm/ir_pass.h>
#include <tvm/expr_operator.h>
#include <ir/Expr.h>
#include <unordered_set>
#include <string>
#include <utility>
......@@ -143,7 +141,7 @@ Stmt HybridOpNode::BuildRealize(
Stmt realize_body = body;
for (int k = 0; k < num_outputs(); ++k) {
Tensor t = stage->op.output(k);
HalideIR::Internal::Region bounds;
Region bounds;
for (size_t i = 0; i < t->shape.size(); ++i) {
bounds.push_back(
Range::make_by_min_extent(
......@@ -442,7 +440,7 @@ Stmt ApplyLoopOrder(const Stage &stage,
}
const Range &range = target->dom.defined() ? target->dom : dom_map.find(target)->second;
return For::make(target->var, range->min, range->extent,
for_type, HalideIR::DeviceAPI::None, body);
for_type, DeviceAPI::None, body);
}
};
......
......@@ -6,9 +6,9 @@
* to you under the Apache License, Version 2.0 (the
* "License"); you may not use this file except in compliance
* with the License. You may obtain a copy of the License at
*
*
* http://www.apache.org/licenses/LICENSE-2.0
*
*
* Unless required by applicable law or agreed to in writing,
* software distributed under the License is distributed on an
* "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
......@@ -18,7 +18,6 @@
*/
/*!
* Copyright (c) 2017 by Contributors
* \brief Utility to make loop nest.
* \file op_util.cc
*/
......
......@@ -18,7 +18,6 @@
*/
/*!
* Copyright (c) 2017 by Contributors
* \brief Scan Operator.
* \file scan_op.cc
*/
......@@ -264,7 +263,7 @@ Stmt ScanOpNode::BuildRealize(
for (size_t i = 0; i < update.size(); ++i) {
Tensor t = stage->op.output(i);
CHECK_EQ(static_cast<size_t>(t->value_index), i);
HalideIR::Internal::Region bounds;
Region bounds;
bounds.push_back(tdom);
for (size_t k = 1; k < this->update[i]->shape.size(); ++k, ++sp_idx) {
IterVar sp_ax = this->spatial_axis_[sp_idx];
......
......@@ -6,9 +6,9 @@
* to you under the Apache License, Version 2.0 (the
* "License"); you may not use this file except in compliance
* with the License. You may obtain a copy of the License at
*
*
* http://www.apache.org/licenses/LICENSE-2.0
*
*
* Unless required by applicable law or agreed to in writing,
* software distributed under the License is distributed on an
* "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
......@@ -18,7 +18,6 @@
*/
/*!
* Copyright (c) 2017 by Contributors
* \file inject_prefetch.cc
*/
// Inject prefetch op in HalideIR
......@@ -34,7 +33,6 @@ namespace ir {
using arith::IntSet;
using arith::DomainTouched;
using HalideIR::Internal::Region;
class PrefetchInjector : public IRMutator {
public:
......
......@@ -6,9 +6,9 @@
* to you under the Apache License, Version 2.0 (the
* "License"); you may not use this file except in compliance
* with the License. You may obtain a copy of the License at
*
*
* http://www.apache.org/licenses/LICENSE-2.0
*
*
* Unless required by applicable law or agreed to in writing,
* software distributed under the License is distributed on an
* "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
......@@ -347,8 +347,7 @@ class IRDeepCompare :
return order_;
}
int CompareRegion(const HalideIR::Internal::Region& lhs,
const HalideIR::Internal::Region& rhs) {
int CompareRegion(const Region& lhs, const Region& rhs) {
if (order_ != 0) return order_;
if (CompareValue(lhs.size(), rhs.size()) != 0) return order_;
for (size_t i = 0; i < lhs.size(); ++i) {
......
......@@ -6,9 +6,9 @@
* to you under the Apache License, Version 2.0 (the
* "License"); you may not use this file except in compliance
* with the License. You may obtain a copy of the License at
*
*
* http://www.apache.org/licenses/LICENSE-2.0
*
*
* Unless required by applicable law or agreed to in writing,
* software distributed under the License is distributed on an
* "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
......@@ -225,7 +225,7 @@ Stmt IRMutator::Mutate_(const Provide* op, const Stmt& s) {
Stmt IRMutator::Mutate_(const Realize* op, const Stmt& s) {
IRMutator* m = this;
HalideIR::Internal::Region new_bounds;
Region new_bounds;
bool bounds_changed = false;
// Mutate the bounds
......@@ -255,7 +255,7 @@ Stmt IRMutator::Mutate_(const Realize* op, const Stmt& s) {
Stmt IRMutator::Mutate_(const Prefetch* op, const Stmt& s) {
IRMutator* m = this;
HalideIR::Internal::Region new_bounds;
Region new_bounds;
bool bounds_changed = false;
// Mutate the bounds
......
......@@ -6,9 +6,9 @@
* to you under the Apache License, Version 2.0 (the
* "License"); you may not use this file except in compliance
* with the License. You may obtain a copy of the License at
*
*
* http://www.apache.org/licenses/LICENSE-2.0
*
*
* Unless required by applicable law or agreed to in writing,
* software distributed under the License is distributed on an
* "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
......@@ -26,6 +26,7 @@
#define TVM_PASS_STORAGE_ACCESS_H_
#include <tvm/ir.h>
#include <tvm/attrs.h>
#include <tvm/ir_pass.h>
#include <tvm/ir_visitor.h>
#include <vector>
......@@ -56,7 +57,7 @@ class StorageAccessVisitor : public IRVisitor {
/*! \brief The thread index that access this entry */
Array<IterVar> threads;
/*! \brief The buffer variable, if any */
VarExpr buffer;
Var buffer = NullValue<Var>();
/*! \brief The access data type */
Type dtype;
/*! \brief The touched access range */
......@@ -66,7 +67,7 @@ class StorageAccessVisitor : public IRVisitor {
/*! \brief The storage scope */
StorageScope scope;
/*! \brief Whether the access is double buffer write */
bool double_buffer_write{false};
bool double_buffer_write = false;
};
/*! \brief Access pattern about a single statement */
struct StmtEntry {
......
......@@ -41,7 +41,6 @@
namespace tvm {
namespace ir {
using HalideIR::Internal::Region;
using runtime::StorageRank;
using runtime::StorageScope;
using runtime::ThreadScope;
......@@ -186,7 +185,7 @@ class StorageFlattener : public IRMutator {
}
// use small alignment for small arrays
int32_t const_size = Allocate::constant_allocation_size(shape, key.GetName());
int32_t const_size = Allocate::constant_allocation_size(shape);
int align = GetTempAllocaAlignment(op->type, const_size);
if (skey.tag.length() != 0) {
MemoryInfo info = GetMemoryInfo(skey.to_string());
......@@ -348,14 +347,14 @@ class StorageFlattener : public IRMutator {
for (int i = starts; i >= 0; --i) {
if (i < starts) {
stmt = For::make(
vars[i], 0, op->bounds[i]->extent, ForType::Serial, DeviceAPI::Host, stmt);
vars[i], 0, op->bounds[i]->extent, ForType::Serial, DeviceAPI::None, stmt);
} else {
Expr load = e.buffer.vload(e.RelIndex(args), e.buffer->dtype);
Expr address = Call::make(Handle(), tvm_address_of, {load}, Call::PureIntrinsic);
Expr prefetch = Call::make(op->type, Call::prefetch, {address, 0, 3, 1}, Call::Intrinsic);
stmt = Evaluate::make(prefetch);
Expr extent = (op->bounds[i]->extent - 1) / stride + 1;
stmt = For::make(vars[i], 0, extent, ForType::Serial, DeviceAPI::Host, stmt);
stmt = For::make(vars[i], 0, extent, ForType::Serial, DeviceAPI::None, stmt);
}
}
return stmt;
......
......@@ -77,7 +77,7 @@ struct GraphCodegen {
std::unordered_map<std::string, tvm::runtime::NDArray> GetParams() {
std::unordered_map<std::string, tvm::runtime::NDArray> ret;
auto names = CallFunc<Array<HalideIR::Expr> >("list_params_name", nullptr);
auto names = CallFunc<Array<tvm::Expr> >("list_params_name", nullptr);
for (auto expr : names) {
auto key = expr.as<ir::StringImm>()->value;
ret[key] = CallFunc<runtime::NDArray>("get_param_by_name", key);
......@@ -289,7 +289,7 @@ class RelayBuildModule : public runtime::ModuleNode {
auto op_node = call_node->op.as<OpNode>();
if (op_node->name == "cast") {
auto attrs = call_node->attrs.as<CastAttrs>();
if (attrs->dtype == HalideIR::Int(32)) {
if (attrs->dtype == Int(32)) {
*rv = true;
}
}
......
......@@ -6,9 +6,9 @@
* to you under the Apache License, Version 2.0 (the
* "License"); you may not use this file except in compliance
* with the License. You may obtain a copy of the License at
*
*
* http://www.apache.org/licenses/LICENSE-2.0
*
*
* Unless required by applicable law or agreed to in writing,
* software distributed under the License is distributed on an
* "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
......@@ -18,7 +18,6 @@
*/
/*!
* Copyright (c) 2018 by Contributors
* \file relay/backend/graph_codegen.cc
* \brief Graph runtime codegen
*/
......@@ -238,7 +237,7 @@ class GraphRuntimeCodegen
* \param shape
* \return std::vector<int64_t>
*/
std::vector<int64_t> _ShapeToJSON(tvm::Array<HalideIR::Expr> shape) {
std::vector<int64_t> _ShapeToJSON(tvm::Array<IndexExpr> shape) {
std::vector<int64_t> ret;
for (IndexExpr dim : shape) {
const int64_t* pval = as_const_int(dim);
......@@ -623,9 +622,9 @@ class GraphRuntimeCodegenModule : public runtime::ModuleNode {
});
} else if (name == "list_params_name") {
return PackedFunc([sptr_to_self, this](TVMArgs args, TVMRetValue* rv) {
Array<HalideIR::Expr> ret;
Array<tvm::Expr> ret;
for (const auto &kv : this->output_.params) {
HalideIR::Expr name = ir::StringImm::make(kv.first);
tvm::Expr name = ir::StringImm::make(kv.first);
ret.push_back(name);
}
*rv = ret;
......
......@@ -102,7 +102,7 @@ TVM_STATIC_IR_FUNCTOR_REGISTER(IRPrinter, vtable)
p->stream << "Var(" << node->name_hint();
if (node->type_annotation.defined()) {
p->stream << ", ty=";
p->print(node->type_annotation);
p->Print(node->type_annotation);
}
p->stream << ")";
});
......
......@@ -6,9 +6,9 @@
* to you under the Apache License, Version 2.0 (the
* "License"); you may not use this file except in compliance
* with the License. You may obtain a copy of the License at
*
*
* http://www.apache.org/licenses/LICENSE-2.0
*
*
* Unless required by applicable law or agreed to in writing,
* software distributed under the License is distributed on an
* "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
......@@ -150,7 +150,7 @@ RELAY_REGISTER_OP("nn.upsampling")
CHECK(base_layout == "NCHW" || layout == "NHWC")
<< "unknown layout: " << uattrs->layout;
Array<HalideIR::Expr> oshape;
Array<IndexExpr> oshape;
if (base_layout == "NCHW") {
oshape.push_back(out_tt->shape[2]);
oshape.push_back(out_tt->shape[3]);
......
......@@ -6,9 +6,9 @@
* to you under the Apache License, Version 2.0 (the
* "License"); you may not use this file except in compliance
* with the License. You may obtain a copy of the License at
*
*
* http://www.apache.org/licenses/LICENSE-2.0
*
*
* Unless required by applicable law or agreed to in writing,
* software distributed under the License is distributed on an
* "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
......@@ -18,7 +18,6 @@
*/
/*!
* Copyright (c) 2016 by Contributors
* \file graph.cc
* \brief Utilities to get information about schedule graph.
*/
......@@ -34,7 +33,7 @@ namespace tvm {
namespace schedule {
// key to specific tensor dimension.
struct TensorDimKey {
FunctionRef f;
ir::FunctionRef f;
int value_index;
int dim;
TensorDimKey() {}
......
......@@ -6,9 +6,9 @@
* to you under the Apache License, Version 2.0 (the
* "License"); you may not use this file except in compliance
* with the License. You may obtain a copy of the License at
*
*
* http://www.apache.org/licenses/LICENSE-2.0
*
*
* Unless required by applicable law or agreed to in writing,
* software distributed under the License is distributed on an
* "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
......@@ -18,7 +18,6 @@
*/
/*!
* Copyright (c) 2016 by Contributors
* \file schedule_lang.cc
*/
#include <tvm/schedule.h>
......@@ -813,34 +812,34 @@ TVM_STATIC_IR_FUNCTOR(IRPrinter, vtable)
})
.set_dispatch<SplitNode>([](const SplitNode *op, IRPrinter *p) {
p->stream << "split(parent=";
p->print(op->parent);
p->Print(op->parent);
p->stream << ", outer=";
p->print(op->outer);
p->Print(op->outer);
p->stream << ", inner=";
p->print(op->inner);
p->Print(op->inner);
p->stream << ')';
})
.set_dispatch<FuseNode>([](const FuseNode *op, IRPrinter *p) {
p->stream << "split(";
p->stream << "outer=";
p->print(op->outer);
p->Print(op->outer);
p->stream << ", inner=";
p->print(op->inner);
p->Print(op->inner);
p->stream << ", fused=";
p->print(op->fused);
p->Print(op->fused);
p->stream << ')';
})
.set_dispatch<RebaseNode>([](const RebaseNode *op, IRPrinter *p) {
p->stream << "rebase(";
p->stream << "parent=";
p->print(op->parent);
p->Print(op->parent);
p->stream << ", rebased=";
p->print(op->rebased);
p->Print(op->rebased);
p->stream << ')';
})
.set_dispatch<SingletonNode>([](const SingletonNode *op, IRPrinter *p) {
p->stream << "singleton(";
p->print(op->iter);
p->Print(op->iter);
p->stream << ')';
})
.set_dispatch<ScheduleNode>([](const ScheduleNode *op, IRPrinter *p) {
......
......@@ -6,9 +6,9 @@
* to you under the Apache License, Version 2.0 (the
* "License"); you may not use this file except in compliance
* with the License. You may obtain a copy of the License at
*
*
* http://www.apache.org/licenses/LICENSE-2.0
*
*
* Unless required by applicable law or agreed to in writing,
* software distributed under the License is distributed on an
* "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
......@@ -23,14 +23,13 @@
#include <tvm/expr_operator.h>
namespace {
using namespace tvm;
using namespace tvm::ir;
using namespace HalideIR::Internal;
using namespace HalideIR;
// replace variable to constant
class IRVar2Const : public IRMutator {
public:
VarExpr var;
Var var;
int int_val;
Expr Mutate(Expr expr) final {
static const FMutateExpr& f = IRVar2Const::vtable_expr();
......@@ -49,7 +48,7 @@ TVM_STATIC_IR_FUNCTOR(IRVar2Const, vtable_expr)
.set_dispatch<Variable>([](const Variable* op, const Expr &e, IRMutator* m) {
IRVar2Const* vm = static_cast<IRVar2Const*>(m);
if (e.same_as(vm->var)) {
return IntImm::make(Int(32), vm->int_val);
return Expr(IntImm::make(Int(32), vm->int_val));
} else {
return e;
}
......@@ -58,7 +57,7 @@ TVM_STATIC_IR_FUNCTOR(IRVar2Const, vtable_expr)
} // namespace
TEST(IRMutator, Basic) {
using namespace HalideIR::Internal;
using namespace tvm::ir;
using namespace tvm;
Var x("x"), y;
auto z = x + y;
......
......@@ -23,8 +23,8 @@
TEST(IRSSA, Convert) {
using namespace HalideIR::Internal;
using namespace tvm;
using namespace tvm::ir;
Var x("x"), y;
Expr let = Let::make(x, 1, x + 1);
......@@ -35,7 +35,7 @@ TEST(IRSSA, Convert) {
}
TEST(IRSSA, Basic) {
using namespace HalideIR::Internal;
using namespace tvm::ir;
using namespace tvm;
Var x("x"), y;
auto z = Evaluate::make(x + y);
......
......@@ -23,7 +23,6 @@
#include <tvm/ir_pass.h>
TEST(IRVisitor, CountVar) {
using namespace HalideIR::Internal;
using namespace tvm;
int n_var = 0;
Var x("x"), y;
......
......@@ -18,7 +18,6 @@
*/
/*!
* Copyright (c) 2017 by Contributors
* \file topi/image/resize.h
* \brief image resize constructors
*/
......@@ -55,17 +54,17 @@ inline Expr bilinear_sample_nchw(const Tensor& input, const Array<Expr>& indices
const Expr max_y, const Expr max_x) {
auto in_y = indices[2];
auto yf = tvm::floor(in_y);
auto yc = HalideIR::Internal::Cast::make(Int(32), tvm::ceil(in_y));
auto yc = tvm::cast(Int(32), tvm::ceil(in_y));
auto y0 = HalideIR::Internal::Cast::make(Int(32), tvm::floor(in_y));
auto y0 = tvm::cast(Int(32), tvm::floor(in_y));
auto y1 = tvm::if_then_else((yc > max_y), max_y, yc);
auto y_lerp = in_y - yf;
auto in_x = indices[3];
auto xf = tvm::floor(in_x);
auto xc = HalideIR::Internal::Cast::make(Int(32), tvm::ceil(in_x));
auto xc = tvm::cast(Int(32), tvm::ceil(in_x));
auto x0 = HalideIR::Internal::Cast::make(Int(32), tvm::floor(in_x));
auto x0 = tvm::cast(Int(32), tvm::floor(in_x));
auto x1 = tvm::if_then_else((xc > max_x), max_x, xc);
auto x_lerp = in_x - xf;
......@@ -268,17 +267,17 @@ inline Tensor resize_bilinear_nhwc(const Tensor& input,
out_shape, [&](const Array<Var>& indices) {
auto in_y = indices[1] * y_ratio;
auto yf = tvm::floor(in_y);
auto yc = HalideIR::Internal::Cast::make(Int(32), tvm::ceil(in_y));
auto yc = tvm::cast(Int(32), tvm::ceil(in_y));
auto y0 = HalideIR::Internal::Cast::make(Int(32), tvm::floor(in_y));
auto y0 = tvm::cast(Int(32), tvm::floor(in_y));
auto y1 = tvm::if_then_else((yc > other_y), other_y, yc);
auto y_lerp = in_y - yf;
auto in_x = indices[2] * x_ratio;
auto xf = tvm::floor(in_x);
auto xc = HalideIR::Internal::Cast::make(Int(32), tvm::ceil(in_x));
auto xc = tvm::cast(Int(32), tvm::ceil(in_x));
auto x0 = HalideIR::Internal::Cast::make(Int(32), tvm::floor(in_x));
auto x0 = tvm::cast(Int(32), tvm::floor(in_x));
auto x1 = tvm::if_then_else((xc > other_x), other_x, xc);
auto x_lerp = in_x - xf;
......
......@@ -689,7 +689,7 @@ inline Tensor sequence_mask(const Tensor& data,
auto bid = out_index[1 - axis];
len_index.push_back(bid);
Expr ret = tvm::if_then_else(tvm::cast(valid_length->dtype, tid) >= valid_length(len_index),
tvm::cast(data->dtype, Expr(mask_value)), data(out_index));
tvm::make_const(data->dtype, mask_value), data(out_index));
return ret;
}, name, tag);
return out;
......
Markdown is supported
0% or
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment