Commit 34f2adb9 by tqchen

Switch to HalideIR, with C API compile

parent 151707e0
[submodule "dmlc-core"] [submodule "dmlc-core"]
path = dmlc-core path = dmlc-core
url = https://github.com/dmlc/dmlc-core url = https://github.com/dmlc/dmlc-core
[submodule "HalideIR"]
path = HalideIR
url = ssh://git@github.com/tqchen/HalideIR
Subproject commit 79a09d0fd60ae7fb6917a647832664212f7cc844
export LDFLAGS = -pthread -lm export LDFLAGS = -pthread -lm
export CFLAGS = -std=c++11 -Wall -O2 -Wno-unknown-pragmas -funroll-loops\ export CFLAGS = -std=c++11 -Wall -O2 -Wno-unknown-pragmas -funroll-loops\
-Iinclude -Idmlc-core/include -fPIC -Iinclude -Idmlc-core/include -IHalideIR/src -fPIC
# specify tensor path # specify tensor path
.PHONY: clean all test doc .PHONY: clean all test doc
all: lib/libtvm.a lib/libtvm.so all: lib/libtvm.a lib/libtvm.so
LIB_HALIDE_IR = HalideIR/lib/libHalideIR.a
SRC = $(wildcard src/*.cc src/*/*.cc) SRC = $(wildcard src/*.cc src/*/*.cc)
ALL_OBJ = $(patsubst src/%.cc, build/%.o, $(SRC)) ALL_OBJ = $(patsubst src/%.cc, build/%.o, $(SRC))
ALL_DEP = $(ALL_OBJ) ALL_DEP = $(ALL_OBJ) $(LIB_HALIDE_IR)
include tests/cpp/unittest.mk include tests/cpp/unittest.mk
...@@ -28,6 +31,11 @@ lib/libtvm.so: $(ALL_DEP) ...@@ -28,6 +31,11 @@ lib/libtvm.so: $(ALL_DEP)
@mkdir -p $(@D) @mkdir -p $(@D)
$(CXX) $(CFLAGS) -shared -o $@ $(filter %.o %.a, $^) $(LDFLAGS) $(CXX) $(CFLAGS) -shared -o $@ $(filter %.o %.a, $^) $(LDFLAGS)
$(LIB_HALIDE_IR): LIBHALIDEIR
LIBHALIDEIR:
+ cd HalideIR; make lib/libHalideIR.a ; cd $(ROOTDIR)
lint: lint:
python2 dmlc-core/scripts/lint.py tvm cpp include src python2 dmlc-core/scripts/lint.py tvm cpp include src
......
/*!
* Copyright (c) 2016 by Contributors
* \file array.h
* \brief Array container in the DSL graph.
*/
#ifndef TVM_ARRAY_H_
#define TVM_ARRAY_H_
#include <type_traits>
#include <vector>
#include <initializer_list>
#include "./base.h"
namespace tvm {
/*! \brief node content in array */
class ArrayNode : public Node {
public:
/*! \brief the data content */
std::vector<std::shared_ptr<Node> > data;
const char* type_key() const override {
return "ArrayNode";
}
void VisitAttrs(AttrVisitor* visitor) override {
LOG(FATAL) << "need to specially handle list attrs";
}
void VisitNodeRefFields(FNodeRefVisit fvisit) override {
// Do nothing, specially handled
}
};
/*!
* \brief Array container of NodeRef in DSL graph.
* Array implements copy on write semantics, which means array is mutable
* but copy will happen when array is referenced in more than two places.
* \tparam T The content NodeRef type.
*/
template<typename T,
typename = typename std::enable_if<std::is_base_of<NodeRef, T>::value>::type >
class Array : public NodeRef {
public:
/*!
* \brief default constructor
*/
Array() {}
/*!
* \brief move constructor
* \param other source
*/
Array(Array<T> && other) { // NOLINT(*)
node_ = std::move(other.node_);
}
/*!
* \brief copy constructor
* \param other source
*/
Array(const Array<T> &other) { // NOLINT(*)
node_ = other.node_;
}
/*!
* \brief constructor from iterator
* \param begin begin of iterator
* \param end end of iterator
* \tparam IterType The type of iterator
*/
template<typename IterType>
Array(IterType begin, IterType end) {
assign(begin, end);
}
/*!
* \brief constructor from initializer list
* \param init The initalizer list
*/
Array(std::initializer_list<T> init) { // NOLINT(*)
assign(init.begin(), init.end());
}
/*!
* \brief constructor from vector
* \param init The vector
*/
Array(const std::vector<T>& init) { // NOLINT(*)
assign(init.begin(), init.end());
}
/*!
* \brief move assign operator
* \param other The source of assignment
* \return reference to self.
*/
Array<T>& operator=(Array<T> && other) {
node_ = std::move(other.node_);
return *this;
}
/*!
* \brief copy assign operator
* \param other The source of assignment
* \return reference to self.
*/
Array<T>& operator=(const Array<T> & other) {
node_ = std::move(other.node_);
return *this;
}
/*!
* \brief reset the array to content from iterator.
* \param begin begin of iterator
* \param end end of iterator
* \tparam IterType The type of iterator
*/
template<typename IterType>
void assign(IterType begin, IterType end) {
auto n = std::make_shared<ArrayNode>();
n->data.reserve(end - begin);
for (IterType i = begin; i < end; ++i) {
n->data.push_back(i->node_);
}
node_ = std::move(n);
}
/*!
* \brief Read i-th element from array.
* \param i The index
* \return the i-th element.
*/
inline T operator[](size_t i) const {
T inst;
inst.node_ = static_cast<const ArrayNode*>(node_.get())->data[i];
return inst;
}
/*! \return The size of the array */
inline size_t size() const {
if (node_.get() == nullptr) return 0;
return static_cast<const ArrayNode*>(node_.get())->data.size();
}
/*! \brief copy on write semantics */
inline void CopyOnWrite() {
if (node_.get() == nullptr || node_.unique()) return;
node_ = std::make_shared<ArrayNode>(
*static_cast<const ArrayNode*>(node_.get()));
}
/*!
* \brief push a new item to the back of the list
* \param item The item to be pushed.
*/
inline void push_back(const T& item) {
this->CopyOnWrite();
static_cast<ArrayNode*>(node_.get())->data.push_back(item.node_);
}
/*!
* \brief set i-th element of the array.
* \param i The index
* \param value The value to be setted.
*/
inline void Set(size_t i, const T& value) {
this->CopyOnWrite();
static_cast<ArrayNode*>(node_.get())->data[i] = value.node_;
}
/*! \brief wrapper class to represent an array reference */
struct ArrayItemRef {
/*! \brief reference to parent */
Array<T>* parent;
/*! \brief The index */
size_t index;
/*!
* \brief assign operator
* \param other The value to be assigned
* \return reference to self.
*/
inline ArrayItemRef& operator=(const T& other) {
parent->Set(index, other);
return *this;
}
/*! \brief The conversion operator */
inline operator T() const {
return (*static_cast<const Array<T>*>(parent))[index];
}
// overload print
friend std::ostream& operator<<(
std::ostream &os, const typename Array<T>::ArrayItemRef& r) { // NOLINT(*0
return os << r.operator T();
}
};
/*!
* \brief Get reference of i-th element from array.
* \param i The index
* \return the ref to i-th element.
*/
inline ArrayItemRef operator[](size_t i) {
return ArrayItemRef{this, i};
}
friend std::ostream& operator<<(std::ostream &os, const Array<T>& r) { // NOLINT(*)
for (size_t i = 0; i < r.size(); ++i) {
if (i == 0) {
os << '[';
} else {
os << ", ";
}
os << r[i];
}
os << ']';
return os;
}
};
} // namespace tvm
#endif // TVM_ARRAY_H_
...@@ -13,183 +13,16 @@ ...@@ -13,183 +13,16 @@
#include <functional> #include <functional>
#include <typeinfo> #include <typeinfo>
#include <type_traits> #include <type_traits>
#include <tvm/node.h>
namespace tvm { namespace tvm {
// forward declaration using ::tvm::Node;
class Node; using ::tvm::NodeRef;
class NodeRef; using ::tvm::AttrVisitor;
class UnaryOp;
class BinaryOp;
/*! \brief pointer type mask */
const int kPtrTypeMask = 16;
/*! \brief list of all supported data types */
enum DataType : int {
kUnknown = 0,
kInt32 = 1,
kFloat32 = 2,
kInt32Buffer = kInt32 | kPtrTypeMask,
kFloat32Buffer = kFloat32 | kPtrTypeMask
};
/*!
* \brief convert pointer type to data type
* \param ptr_type The pointer type.
* \return The corresponding data type.
*/
inline DataType Ptr2DataType(DataType ptr_type) {
CHECK_GE(ptr_type, kPtrTypeMask);
return static_cast<DataType>(ptr_type & (kPtrTypeMask -1));
}
/*!
* \brief List of subset node types used for quick runtime switch.
*
* \note The value of NodeType is not used for serialization type_key is used instead.
* \note is_type and type_key can be used to do type checking for all types
* \note kOtherNodes could mean more than one node type.
*/
enum NodeType {
// expr nodes
kVarNode,
kIntNode,
kFloatNode,
kUnaryOpNode,
kBinaryOpNode,
kReduceNode,
kTensorReadNode,
kBufferReadNode,
// stmt nodes
kStoreNode,
kForRangeNode,
kIfThenElseNode,
kOtherNodes
};
/*!
* \brief Visitor class to each node content.
* The content is going to be called for each field.
*/
class AttrVisitor {
public:
//! \cond Doxygen_Suppress
virtual void Visit(const char* key, double* value) = 0;
virtual void Visit(const char* key, int64_t* value) = 0;
virtual void Visit(const char* key, int* value) = 0;
virtual void Visit(const char* key, std::string* value) = 0;
virtual void Visit(const char* key, const UnaryOp** value) = 0;
virtual void Visit(const char* key, const BinaryOp** value) = 0;
template<typename ENum,
typename = typename std::enable_if<std::is_enum<ENum>::value>::type>
void Visit(const char* key, ENum* ptr) {
static_assert(std::is_same<int, typename std::underlying_type<ENum>::type>::value,
"declare enum to be enum int to use visitor");
this->Visit(key, reinterpret_cast<int*>(ptr));
}
//! \endcond
};
/*!
* \brief A function to be applied when visit each NodeRef Field.
* \param ref The child to be visited.
*/
using FNodeRefVisit = std::function<void(const char* key, NodeRef* ref)>;
/*!
* \brief base class of node container in DSL AST.
* All object's internal is stored as std::shared_ptr<Node>
*/
class Node {
public:
/*! \brief virtual destructor */
virtual ~Node() {}
/*! \return The unique type key of the node */
virtual const char* type_key() const = 0;
/*! \brief verify the correctness of node struct after it get mutated by visitor */
virtual void Verify() const {}
/*!
* \brief Apply visitor to each field of the Node
* Visitor could mutate the content of the node.
* override if Node contains attribute fields.
* \param visitor The visitor
*/
virtual void VisitAttrs(AttrVisitor* visitor) {}
/*!
* \brief Apply visitor to each field of the Node
* Visitor could mutate the content of the node.
* override if Node contains NodeRefFields.
* \param visitor The visitor
*/
virtual void VisitNodeRefFields(FNodeRefVisit visitor) {}
/*!
* \tparam NodeType the type to be checked.
* \return whether the stored type is node type
*/
template<typename TNode>
inline bool is_type() const;
protected:
// node ref can see this
friend class NodeRef;
/*!
* \brief optional: safe destruction function
* Can be called in destructor of composite types.
* This can be used to avoid stack overflow when
* recursive destruction long graph(1M nodes),
*
* It is totally OK to not call this in destructor.
*/
void Destroy();
/*! \brief the node type enum */
NodeType node_type_{kOtherNodes};
};
/*! \brief base class of all node reference object */
class NodeRef {
public:
/*!
* \return typed pointer of the node
* \tparam TNode the type of the node.
*/
template<typename TNode>
inline const TNode* Get() const;
/*! \return the node type */
inline NodeType node_type() const;
/*! \return wheyjer the expression is null */
inline bool is_null() const;
/*!
* \brief Comparator
* \param other Another node ref.
* \return the compare result.
*/
inline bool operator==(const NodeRef& other) const;
/*!
* \brief Comparator
* \param other Another node ref.
* \return the compare result.
*/
inline bool operator!=(const NodeRef& other) const;
/*! \return the hash function for NodeRef */
inline size_t hash() const;
/*! \return the raw internal pointer of the node */
inline Node* node_ptr() const;
protected:
template<typename T, typename>
friend class Array;
friend class APIVariantValue;
friend class Node;
NodeRef() = default;
explicit NodeRef(std::shared_ptr<Node>&& node) : node_(std::move(node)) {}
/*! \brief the internal node */
std::shared_ptr<Node> node_;
};
/*! \brief typedef the factory function of data iterator */ /*! \brief typedef the factory function of data iterator */
using NodeFactory = std::function<std::shared_ptr<Node> ()>; using NodeFactory = std::function<std::shared_ptr<Node> ()>;
/*! /*!
* \brief Registry entry for NodeFactory * \brief Registry entry for NodeFactory
*/ */
...@@ -202,54 +35,5 @@ struct NodeFactoryReg ...@@ -202,54 +35,5 @@ struct NodeFactoryReg
DMLC_REGISTRY_REGISTER(::tvm::NodeFactoryReg, NodeFactoryReg, TypeName) \ DMLC_REGISTRY_REGISTER(::tvm::NodeFactoryReg, NodeFactoryReg, TypeName) \
.set_body([]() { return std::make_shared<TypeName>(); }) .set_body([]() { return std::make_shared<TypeName>(); })
// implementations of inline functions after this
inline NodeType NodeRef::node_type() const {
return node_->node_type_;
}
template<typename TNode>
inline bool Node::is_type() const {
const std::type_info& tinfo = typeid(*this);
if (&typeid(TNode) == &tinfo) return true;
return typeid(TNode) == tinfo;
}
template<typename TNode>
inline const TNode* NodeRef::Get() const {
CHECK(node_->is_type<TNode>())
<< " type inconsistent, expected " << typeid(TNode).name()
<< " given " << typeid(*this).name();
return static_cast<const TNode*>(node_.get());
}
inline bool NodeRef::is_null() const {
return node_.get() == nullptr;
}
inline bool NodeRef::operator==(const NodeRef& other) const {
return node_.get() == other.node_.get();
}
inline bool NodeRef::operator!=(const NodeRef& other) const {
return node_.get() != other.node_.get();
}
inline size_t NodeRef::hash() const {
return std::hash<Node*>()(node_ptr());
}
inline Node* NodeRef::node_ptr() const {
return node_.get();
}
} // namespace tvm } // namespace tvm
namespace std {
template <>
struct hash<::tvm::NodeRef> {
std::size_t operator()(const ::tvm::NodeRef& k) const {
return k.hash();
}
};
} // namespace std
#endif // TVM_BASE_H_ #endif // TVM_BASE_H_
/*!
* Copyright (c) 2016 by Contributors
* \file domain.h
* \brief Defines the AST
*/
#ifndef TVM_DOMAIN_H_
#define TVM_DOMAIN_H_
#include <memory>
#include "./base.h"
#include "./array.h"
#include "./expr.h"
namespace tvm {
// Internal node container of Range
class RangeNode;
// Internal node container of RDomain
class RDomainNode;
/*! \brief Node range */
class Range : public NodeRef {
public:
/*! \brief constructor */
Range() {}
/*!
* \brief constructor
* \param begin start of the range.
* \param end end of the range.
*/
Range(Expr begin, Expr end);
/*! \return The extent of the range */
Expr extent() const;
/*!
* \brief access the internal node container
* \return the pointer to the internal node container
*/
inline const RangeNode* operator->() const;
/*! \return the begining of the range */
inline const Expr& begin() const;
/*! \return the end of the range */
inline const Expr& end() const;
// overload print function
friend std::ostream& operator<<(std::ostream &os, const Range& r) { // NOLINT(*)
os << '[' << r.begin() << ", " << r.end() <<')';
return os;
}
};
/*! \brief Domain is a multi-dimensional range */
using Domain = Array<Range>;
/*! \brief reduction domain */
class RDomain : public NodeRef {
public:
/*! \brief constructor*/
RDomain() {}
/*!
* constructor by domain
* \param domain The domain of reduction.
*/
explicit RDomain(Domain domain);
/*!
* \brief constructor by list of ranges
* \param domain The reduction domain
*/
explicit RDomain(std::initializer_list<Range> domain)
: RDomain(Domain(domain)) {}
/*!
* \brief access the internal node container
* \return the pointer to the internal node container
*/
inline const RDomainNode* operator->() const;
/*! \return The dimension of the RDomain */
inline size_t ndim() const;
/*!
* \param i the index.
* \return i-th index variable in the RDomain
*/
inline Var index(size_t i) const;
/*! \return the 0-th index of the domain */
inline Var i0() const {
return index(0);
}
/*!
* \return The domain of the reduction.
*/
inline const Domain& domain() const;
// overload print function
friend std::ostream& operator<<(std::ostream &os, const RDomain& r){ // NOLINT(*)
os << "rdomain(" << r.domain() << ")";
return os;
}
};
/*! \brief use RDom as alias of RDomain */
using RDom = RDomain;
/*! \brief range over one dimension */
class RangeNode : public Node {
public:
/*! \brief beginning of the node */
Expr begin;
/*! \brief end of the node */
Expr end;
/*! \brief constructor */
RangeNode() {}
RangeNode(Expr && begin, Expr && end)
: begin(std::move(begin)), end(std::move(end)) {
}
const char* type_key() const override {
return "RangeNode";
}
void VisitNodeRefFields(FNodeRefVisit fvisit) override {
fvisit("begin", &begin);
fvisit("end", &end);
}
void VisitAttrs(AttrVisitor* visitor) override {}
};
/*! \brief reduction domain node */
class RDomainNode : public Node {
public:
/*! \brief internal index */
Array<Var> index;
/*! \brief The inernal domain */
Domain domain;
/*! \brief constructor */
RDomainNode() {}
RDomainNode(Array<Var> && index, Domain && domain)
: index(std::move(index)), domain(std::move(domain)) {
}
const char* type_key() const override {
return "RDomainNode";
}
void VisitNodeRefFields(FNodeRefVisit fvisit) override {
fvisit("index", &index);
fvisit("domain", &domain);
}
void VisitAttrs(AttrVisitor* visitor) override {}
};
// implements of inline functions
inline const RangeNode* Range::operator->() const {
return static_cast<const RangeNode*>(node_.get());
}
inline const Expr& Range::begin() const {
return (*this)->begin;
}
inline const Expr& Range::end() const {
return (*this)->end;
}
inline const RDomainNode* RDomain::operator->() const {
return static_cast<const RDomainNode*>(node_.get());
}
inline size_t RDomain::ndim() const {
return (*this)->index.size();
}
inline Var RDomain::index(size_t i) const {
return (*this)->index[i];
}
inline const Domain& RDomain::domain() const {
return (*this)->domain;
}
} // namespace tvm
#endif // TVM_DOMAIN_H_
...@@ -7,121 +7,13 @@ ...@@ -7,121 +7,13 @@
#define TVM_EXPR_H_ #define TVM_EXPR_H_
#include <type_traits> #include <type_traits>
#include <ir/Expr.h>
#include "./base.h" #include "./base.h"
namespace tvm { namespace tvm {
// Forward declare Expr
class Expr;
class Var;
/*! using Halide::Type;
* \brief create a constant expression using Halide::Expr;
* \tparam T the value type
* \param value The value to the constant.
* \return The created expression
*/
template<typename T,
typename = typename std::enable_if<std::is_arithmetic<T>::value>::type >
inline Expr constant(T value);
/*!
* \brief create a integer expression
* \param value The value to the expression
* \return the expression.
*/
Expr IntConstant(int64_t value);
/*!
* \brief create a float expression.
* \param value The value to the expression
* \return the expression.
*/
Expr FloatConstant(double value);
/*!
* \brief create a float expression.
* \param value The value to the expression
* \return the expression.
*/
Expr BufferRead(Var buffer, Expr offset);
/*!
* \brief a expression type, holds a ref to root of an AST
*/
class Expr : public NodeRef {
public:
/*! \brief default constructor */
Expr() {}
/*!
* \brief constructor from constant value
* \param value the constant value
* \tparam T The constant type
*/
template<typename T,
typename = typename std::enable_if<std::is_arithmetic<T>::value>::type >
Expr(T value) { // NOLINT(*)
*this = std::move(constant<T>(value));
}
/*!
* \brief constructor from node pointer
* \param nptr Another node shared pointer
*/
explicit Expr(std::shared_ptr<Node>&& nptr) : NodeRef(std::move(nptr)) {
CHECK(node_.get() != nullptr);
}
/*! \return the expression type of the expression */
inline DataType dtype() const;
// print the expression.
friend std::ostream& operator<<(std::ostream &os, const Expr& e) { // NOLINT(*)
e.Print(os);
return os;
}
private:
// print the expression
void Print(std::ostream& os) const; // NOLINT(*)
};
/*!
* \brief Variable class to represent the symbolic placeholder
* in the DSL, internally it is a VarNode.
*
* The Variable is uniquely identified by the address of VarNode.
*/
class Var : public Expr {
public:
Var(std::string name="", DataType dtype=kInt32); // NOLINT(*)
};
/*! \brief base of expression node */
class ExprNode : public Node {
public:
/*! \brief type of data stored in expression */
DataType dtype_{kUnknown};
};
// implementations
inline DataType Expr::dtype() const {
return static_cast<const ExprNode*>(node_.get())->dtype_;
}
template<typename T,
typename = typename std::enable_if<std::is_arithmetic<T>::value>::type >
inline Expr constant(T value) {
if (std::is_integral<T>::value) {
return IntConstant(static_cast<int64_t>(value));
} else {
return FloatConstant(static_cast<double>(value));
}
}
} // namespace tvm
namespace std {
template <>
struct hash<::tvm::Expr> {
std::size_t operator()(const ::tvm::NodeRef& k) const {
return k.hash();
}
};
} // namespace std } // namespace std
#endif // TVM_EXPR_H_ #endif // TVM_EXPR_H_
/*!
* Copyright (c) 2016 by Contributors
* \file expr_node.h
* \brief Defines the expression nodes in AST.
*/
#ifndef TVM_EXPR_NODE_H_
#define TVM_EXPR_NODE_H_
#include <string>
#include "./domain.h"
#include "./tensor.h"
#include "./expr.h"
namespace tvm {
/*! \brief variable node for symbolic variables */
struct VarNode : public ExprNode {
/*! \brief hint name of the variable */
std::string name;
/*! \brief constructor */
VarNode() {
node_type_ = kVarNode;
}
const char* type_key() const override {
return "VarNode";
}
void VisitAttrs(AttrVisitor* visitor) override {
visitor->Visit("name", &name);
visitor->Visit("dtype", &dtype_);
}
};
/*! \brief integer constant node */
struct IntNode : public ExprNode {
public:
/*! \brief the value field */
int64_t value;
/*! \brief constructor */
IntNode() {
node_type_ = kIntNode;
dtype_ = kInt32;
}
const char* type_key() const override {
return "IntNode";
}
void VisitAttrs(AttrVisitor* visitor) override {
visitor->Visit("value", &value);
visitor->Visit("dtype", &dtype_);
}
};
/*! \brief float constant node */
struct FloatNode : public ExprNode {
/*! \brief the value field */
double value;
/*! \brief constructor */
FloatNode() {
node_type_ = kFloatNode;
dtype_ = kFloat32;
}
const char* type_key() const override {
return "FloatNode";
}
void VisitAttrs(AttrVisitor* visitor) override {
visitor->Visit("value", &value);
visitor->Visit("dtype", &dtype_);
}
};
/*! \brief Unary mapping operator */
struct UnaryOpNode : public ExprNode {
/*! \brief The operator */
const UnaryOp* op;
/*! \brief The source expression */
Expr src;
/*! \brief constructor */
UnaryOpNode() {
node_type_ = kUnaryOpNode;
}
UnaryOpNode(const UnaryOp* op, Expr && src)
: op(op), src(std::move(src)) {
node_type_ = kUnaryOpNode;
dtype_ = this->src.dtype();
}
~UnaryOpNode() {
this->Destroy();
}
const char* type_key() const override {
return "UnaryOpNode";
}
void Verify() const override {
CHECK_EQ(dtype_, src.dtype());
}
void VisitAttrs(AttrVisitor* visitor) override {
visitor->Visit("op", &op);
visitor->Visit("dtype", &dtype_);
}
void VisitNodeRefFields(FNodeRefVisit fvisit) override {
fvisit("src", &src);
}
};
/*! \brief Binary mapping operator */
struct BinaryOpNode : public ExprNode {
/*! \brief The operator */
const BinaryOp* op;
/*! \brief The left operand */
Expr lhs;
/*! \brief The right operand */
Expr rhs;
/*! \brief constructor, do not use constructor */
BinaryOpNode() {
node_type_ = kBinaryOpNode;
}
BinaryOpNode(const BinaryOp* op, Expr && lhs, Expr && rhs)
: op(op), lhs(std::move(lhs)), rhs(std::move(rhs)) {
node_type_ = kBinaryOpNode;
dtype_ = this->lhs.dtype();
}
~BinaryOpNode() {
this->Destroy();
}
const char* type_key() const override {
return "BinaryOpNode";
}
void Verify() const override {
CHECK_EQ(dtype_, lhs.dtype());
CHECK_EQ(dtype_, rhs.dtype());
}
void VisitAttrs(AttrVisitor* visitor) override {
visitor->Visit("op", &op);
visitor->Visit("dtype", &dtype_);
}
void VisitNodeRefFields(FNodeRefVisit fvisit) override {
fvisit("lhs", &lhs);
fvisit("rhs", &rhs);
}
};
/*! \brief Reduction operator operator */
struct ReduceNode : public ExprNode {
/*! \brief The operator */
const BinaryOp* op;
/*! \brief The source operand */
Expr src;
/*! \brief The reduction domain */
RDomain rdom;
/*! \brief constructor, do not use constructor */
ReduceNode() {
node_type_ = kReduceNode;
}
ReduceNode(const BinaryOp* op, Expr && src, RDomain && rdom)
: op(op), src(std::move(src)), rdom(std::move(rdom)) {
node_type_ = kReduceNode;
dtype_ = this->src.dtype();
}
~ReduceNode() {
this->Destroy();
}
const char* type_key() const override {
return "ReduceNode";
}
void Verify() const override {
CHECK_EQ(dtype_, src.dtype());
}
void VisitAttrs(AttrVisitor* visitor) override {
visitor->Visit("op", &op);
visitor->Visit("dtype", &dtype_);
}
void VisitNodeRefFields(FNodeRefVisit fvisit) override {
fvisit("src", &src);
fvisit("rdom", &rdom);
}
};
/*! \brief Tensor read operator */
struct TensorReadNode : public ExprNode {
/*! \brief The tensor to be read from */
Tensor tensor;
/*! \brief The indices of read */
Array<Expr> indices;
/*! \brief constructor, do not use constructor */
TensorReadNode() {
node_type_ = kTensorReadNode;
}
TensorReadNode(Tensor && tensor, Array<Expr> && indices)
: tensor(std::move(tensor)), indices(std::move(indices)) {
node_type_ = kReduceNode;
dtype_ = tensor->dtype;
}
~TensorReadNode() {
this->Destroy();
}
const char* type_key() const override {
return "TensorReadNode";
}
void Verify() const override {
CHECK_EQ(dtype_, tensor->dtype);
for (size_t i = 0; i < indices.size(); ++i) {
CHECK_EQ(indices[i].dtype(), kInt32);
}
}
void VisitAttrs(AttrVisitor* visitor) override {
visitor->Visit("dtype", &dtype_);
}
void VisitNodeRefFields(FNodeRefVisit fvisit) override {
fvisit("tensor", &tensor);
fvisit("indices", &indices);
}
};
/*! \brief Buffer read node */
struct BufferReadNode : public ExprNode {
/*! \brief The buffer variable to be read from */
Var buffer;
/*! \brief The offset to be read from */
Expr offset;
/*! \brief constructor, do not use constructor */
BufferReadNode() {
node_type_ = kBufferReadNode;
}
const char* type_key() const override {
return "BufferReadNode";
}
void Verify() const override {
CHECK_EQ(dtype_, Ptr2DataType(buffer.dtype()));
CHECK_EQ(offset.dtype(), kInt32);
}
void VisitAttrs(AttrVisitor* visitor) override {
visitor->Visit("dtype", &dtype_);
}
void VisitNodeRefFields(FNodeRefVisit fvisit) override {
fvisit("buffer", &buffer);
fvisit("offset", &offset);
}
};
} // namespace tvm
#endif // TVM_EXPR_NODE_H_
/*!
* Copyright (c) 2016 by Contributors
* \file expr_util.h
* \brief Expression util
*/
#ifndef TVM_EXPR_UTIL_H_
#define TVM_EXPR_UTIL_H_
#include <vector>
#include "./expr.h"
#include "./expr_node.h"
namespace tvm {
/*!
* \brief simplify the expression src
* \param src The source expression
* \return the simplified expression.
*/
Expr Simplify(Expr src);
/*!
* \brief replace the variables in expression src by specification from dict
* \param src The source expression
* \param dict The specification for variable replacement
* \return the new expression with variable replaced
*/
Expr Bind(Expr src, std::unordered_map<Expr, Expr> dict);
/*!
* \brief visit the exression node in expr tree in post DFS order.
* \param expr The expression tree
* \param fvisit The visit function.
*/
template<typename FVisit>
inline void Visit(const Expr& expr, FVisit fvisit) {
// TODO(tqchen) change to stack based impl.
switch (expr.node_type()) {
case kBinaryOpNode: {
const auto* n = expr.Get<BinaryOpNode>();
Visit(n->lhs, fvisit);
Visit(n->rhs, fvisit);
break;
}
case kUnaryOpNode: {
const auto* n = expr.Get<UnaryOpNode>();
Visit(n->src, fvisit);
break;
}
case kReduceNode: {
const auto* n = expr.Get<ReduceNode>();
Visit(n->src, fvisit);
break;
}
case kTensorReadNode: {
const auto* n = expr.Get<TensorReadNode>();
for (size_t i = 0; i < n->indices.size(); ++i) {
Visit(n->indices[i], fvisit);
}
break;
}
default: break;
}
fvisit(expr);
}
/*!
* \brief transform the exression node in expr tree in post DFS order.
* \param expr The expression tree
* \param fvisit The visit function.
* \return the new expression after transformation
*/
template<typename FVisit>
inline Expr Transform(const Expr& expr, FVisit fvisit) {
// TODO(tqchen) change to stack based impl.
std::vector<Expr> children;
switch (expr.node_type()) {
case kBinaryOpNode: {
const auto* n = expr.Get<BinaryOpNode>();
Expr e = Transform(n->lhs, fvisit);
children.push_back(e);
children.push_back(Transform(n->rhs, fvisit));
break;
}
case kUnaryOpNode: {
const auto* n = expr.Get<UnaryOpNode>();
children.push_back(Transform(n->src, fvisit));
break;
}
case kReduceNode: {
const auto* n = expr.Get<ReduceNode>();
children.push_back(Transform(n->src, fvisit));
break;
}
case kTensorReadNode: {
const auto* n = expr.Get<TensorReadNode>();
for (size_t i = 0; i < n->indices.size(); ++i) {
children.push_back(Transform(n->indices[i], fvisit));
}
break;
}
default: break;
}
Expr ret = fvisit(expr, children);
return ret;
}
} // namespace tvm
#endif // TVM_EXPR_UTIL_H_
/*!
* Copyright (c) 2016 by Contributors
* \file op.h
* \brief Defines the operators
*/
#ifndef TVM_OP_H_
#define TVM_OP_H_
#include <dmlc/registry.h>
#include <string>
#include "./expr.h"
#include "./domain.h"
namespace tvm {
/*! \brief binary operator */
class BinaryOp {
public:
// virtual destructor
virtual ~BinaryOp() {}
/*! \return the function name to be called in binary op */
virtual const char* FunctionName() const = 0;
/*!
* \brief apply the binary op
* \param lhs left operand
* \param rhs right operand
* \return the result expr
*/
Expr operator()(Expr lhs, Expr rhs) const;
/*!
* \brief make a reduction of src over rdom,
* \param src Source expression.
* \param rdom reduction domain.
* \return the result expr
*/
Expr Reduce(Expr src, RDomain rdom) const;
/*!
* \brief get binary op by name
* \param name name of operator
*/
static const BinaryOp* Get(const char* name);
};
/*! \brief unary operator */
class UnaryOp {
public:
/*! \return the function name to be called in unary op */
virtual const char* FunctionName() const = 0;
/*!
* \brief apply the unary op
* \param src left operand
* \return the result expr
*/
Expr operator()(Expr src) const;
/*!
* \brief get unary op by name
* \param name name of operator
*/
static const UnaryOp* Get(const char* name);
};
class AddOp : public BinaryOp {
public:
const char* FunctionName() const override {
return "+";
}
};
class SubOp : public BinaryOp {
public:
const char* FunctionName() const override {
return "-";
}
};
class MulOp : public BinaryOp {
public:
const char* FunctionName() const override {
return "*";
}
};
class DivOp : public BinaryOp {
public:
const char* FunctionName() const override {
return "/";
}
};
class MaxOp : public BinaryOp {
public:
const char* FunctionName() const override {
return "max";
}
};
class MinOp : public BinaryOp {
public:
const char* FunctionName() const override {
return "min";
}
};
#define DEFINE_BINARY_OP_OVERLOAD(OpChar) \
inline Expr operator OpChar (Expr lhs, Expr rhs) { \
static const BinaryOp* op = BinaryOp::Get(#OpChar); \
return (*op)(lhs, rhs); \
}
#define DEFINE_BINARY_OP_FUNCTION(FuncName) \
inline Expr FuncName(Expr lhs, Expr rhs) { \
static const BinaryOp* op = BinaryOp::Get(#FuncName); \
return (*op)(lhs, rhs); \
}
#define DEFINE_REDUCE_FUNCTION(FuncName, OpName) \
inline Expr FuncName(Expr src, RDomain rdom) { \
static const BinaryOp* op = BinaryOp::Get(#OpName); \
return op->Reduce(src, rdom); \
}
DEFINE_BINARY_OP_OVERLOAD(+);
DEFINE_BINARY_OP_OVERLOAD(-);
DEFINE_BINARY_OP_OVERLOAD(*);
DEFINE_BINARY_OP_OVERLOAD(/);
DEFINE_BINARY_OP_FUNCTION(max);
DEFINE_BINARY_OP_FUNCTION(min);
DEFINE_REDUCE_FUNCTION(max, max);
DEFINE_REDUCE_FUNCTION(min, min);
DEFINE_REDUCE_FUNCTION(sum, +);
// overload negation
inline Expr operator-(Expr src) {
return src * (-1);
}
// template of op registry
template<typename Op>
struct OpReg {
std::string name;
std::unique_ptr<Op> op;
inline OpReg& set(Op* op) {
this->op.reset(op);
return *this;
}
};
using UnaryOpReg = OpReg<UnaryOp>;
using BinaryOpReg = OpReg<BinaryOp>;
#define TVM_REGISTER_BINARY_OP(FunctionName, TypeName) \
static DMLC_ATTRIBUTE_UNUSED ::tvm::BinaryOpReg & __make_ ## _BinOp_ ## TypeName = \
::dmlc::Registry<::tvm::BinaryOpReg>::Get()->__REGISTER_OR_GET__(#FunctionName) \
.set(new TypeName())
#define TVM_REGISTER_UNARY_OP(FunctionName, TypeName) \
static DMLC_ATTRIBUTE_UNUSED ::tvm::BinaryOpReg & __make_ ## _BinOp_ ## TypeName = \
::dmlc::Registry<::tvm::UnaryOpReg>::Get()->__REGISTER_OR_GET__(#FunctionName) \
.set(new TypeName())
} // namespace tvm
#endif // TVM_OP_H_
/*!
* Copyright (c) 2016 by Contributors
* \file stmt.h
* \brief The statement creation functions.
* The underlying container are defined in stmt_node.h
*/
#ifndef TVM_STMT_H_
#define TVM_STMT_H_
#include <type_traits>
#include "./base.h"
#include "./domain.h"
namespace tvm {
/*!
* \brief a expression type, holds a ref to root of an AST
*/
class Stmt : public NodeRef {
public:
/*! \brief default constructor */
Stmt() {}
/*!
* \brief constructor from node pointer
* \param nptr Another node shared pointer
*/
explicit Stmt(std::shared_ptr<Node>&& nptr) : NodeRef(std::move(nptr)) {
CHECK(node_.get() != nullptr);
}
};
/*!
* \brief construct Store Stmt.
* \param buffer The variable representing the buffer.
* \param offset The offset in the buffer
* \param src The source expression.
*/
Stmt Store(Var buffer, Expr offset, Expr src);
/*!
* \brief construct ForRange Stmt
* \param loop_var The loop variable
* \param range The loop range
* \param body The loop body
*/
Stmt ForRange(Var loop_var, Range range, Stmt body);
/*!
* \brief construct a IfThenElse
* \param cond The condition.
* \param then_body The body to go to in then condition.
* \param else_body The body to go to in else condition.
*/
Stmt IfThenElse(Expr cond, Stmt then_body, Stmt else_body);
} // namespace tvm
#endif // TVM_STMT_H_
/*!
* Copyright (c) 2016 by Contributors
* \file stmt.h
* \brief Common data structure for codegen
*/
#ifndef TVM_STMT_NODE_H_
#define TVM_STMT_NODE_H_
#include "./base.h"
#include "./domain.h"
namespace tvm {
/*!
* \brief The internal base class of StmtNode
* So far no extra stuffs in here.
*/
struct StmtNode : public Node {
};
/*! \brief Store data into buffer */
struct StoreNode : public StmtNode {
/*! \brief the variable representing the buffer */
Var buffer;
/*! \brief the buffer offset */
Expr offset;
/*! \brief The source expression*/
Expr src;
/*! \brief constructor */
StoreNode() {
node_type_ = kStoreNode;
}
const char* type_key() const override {
return "StoreNode";
}
void VisitNodeRefFields(FNodeRefVisit fvisit) override {
fvisit("buffer", &buffer);
fvisit("offset", &offset);
fvisit("src", &src);
}
void Verify() const override {
CHECK_EQ(Ptr2DataType(buffer.dtype()), src.dtype());
CHECK_EQ(offset.dtype(), kInt32);
}
};
/*! \brief for loop in range */
struct ForRangeNode : public StmtNode {
/*! \brief loop variable */
Var loop_var;
/*! \brief The loop range */
Range range;
/*! \brief body of the loop */
Stmt body;
/*! \brief constructor */
ForRangeNode() {
node_type_ = kForRangeNode;
}
const char* type_key() const override {
return "ForRangeNode";
}
void VisitNodeRefFields(FNodeRefVisit fvisit) override {
fvisit("loop_var", &loop_var);
fvisit("range", &range);
fvisit("body", &body);
}
void Verify() const override {
CHECK_EQ(loop_var.dtype(), kInt32);
CHECK_EQ(this->range->begin.dtype(), loop_var.dtype());
CHECK_EQ(this->range->end.dtype(), loop_var.dtype());
}
};
/*! \brief conditional expression */
struct IfThenElseNode : public StmtNode {
/*! \brief The condition */
Expr cond;
/*! \brief The statement in then */
Stmt then_body;
/*! \brief The statement in else */
Stmt else_body;
/*! \brief constructor */
IfThenElseNode() {
node_type_ = kIfThenElseNode;
}
const char* type_key() const override {
return "IfThenElseNode";
}
void VisitNodeRefFields(FNodeRefVisit fvisit) override {
fvisit("cond", &cond);
fvisit("then_body", &then_body);
fvisit("else_body", &else_body);
}
void Verify() const override {
CHECK_EQ(cond.dtype(), kInt32);
}
};
} // namespace tvm
#endif // TVM_STMT_NODE_H_
...@@ -4,7 +4,6 @@ ...@@ -4,7 +4,6 @@
* \file c_api.cc * \file c_api.cc
*/ */
#include <tvm/c_api.h> #include <tvm/c_api.h>
#include <tvm/op.h>
#include "./c_api_common.h" #include "./c_api_common.h"
#include "./c_api_registry.h" #include "./c_api_registry.h"
...@@ -28,6 +27,36 @@ struct TVMAPIThreadLocalEntry { ...@@ -28,6 +27,36 @@ struct TVMAPIThreadLocalEntry {
inline void SetReturn(ArgVariant* ret_val, int* ret_typeid); inline void SetReturn(ArgVariant* ret_val, int* ret_typeid);
}; };
namespace tvm {
inline std::string Type2String(const Type& t) {
std::ostringstream os;
os << t;
return os.str();
}
inline Type String2Type(std::string s) {
std::istringstream is(s);
halide_type_code_t code;
if (s.substr(0, 3) == "int") {
code = Type::Int; s = s.substr(3);
} else if (s.substr(0, 4) == "uint") {
code = Type::UInt; s = s.substr(4);
} else if (s.substr(0, 5) == "float") {
code = Type::Float; s = s.substr(5);
} else if (s.substr(0, 5) == "float") {
code = Type::Float; s = s.substr(5);
} else {
LOG(FATAL) << "unknown type " << s;
}
int bits, lanes = 0;
if (sscanf(s.c_str(), "%dx%d", &bits, &lanes) == 0) {
LOG(FATAL) << "unknown type " << s;
}
return Type(code, bits, lanes);
}
}
using namespace tvm; using namespace tvm;
/*! \brief Thread local store that can be used to hold return values. */ /*! \brief Thread local store that can be used to hold return values. */
...@@ -39,45 +68,59 @@ struct APIAttrGetter : public AttrVisitor { ...@@ -39,45 +68,59 @@ struct APIAttrGetter : public AttrVisitor {
std::string skey; std::string skey;
APIVariantValue* ret; APIVariantValue* ret;
void Visit(const char* key, double* value) override { void Visit(const char* key, double* value) final {
if (skey == key) *ret = value[0]; if (skey == key) *ret = value[0];
} }
void Visit(const char* key, int64_t* value) override { void Visit(const char* key, int64_t* value) final {
if (skey == key) *ret = value[0]; if (skey == key) *ret = value[0];
} }
void Visit(const char* key, int* value) override { void Visit(const char* key, uint64_t* value) final {
CHECK_LE(value[0], std::numeric_limits<int64_t>::max())
<< "cannot return too big constant";
if (skey == key) *ret = static_cast<int64_t>(value[0]); if (skey == key) *ret = static_cast<int64_t>(value[0]);
} }
void Visit(const char* key, std::string* value) override { void Visit(const char* key, int* value) final {
if (skey == key) *ret = value[0]; if (skey == key) *ret = static_cast<int64_t>(value[0]);
}
void Visit(const char* key, bool* value) final {
if (skey == key) *ret = static_cast<int64_t>(value[0]);
}
void Visit(const char* key, Type* value) final {
if (skey == key) *ret = Type2String(value[0]);
} }
void Visit(const char* key, const UnaryOp** value) override { void Visit(const char* key, std::string* value) final {
if (skey == key) *ret = value[0]->FunctionName(); if (skey == key) *ret = value[0];
} }
void Visit(const char* key, const BinaryOp** value) override { void Visit(const char* key, NodeRef* value) final {
if (skey == key) *ret = value[0]->FunctionName(); if (skey == key) *ret = value[0];
} }
}; };
struct APIAttrDir : public AttrVisitor { struct APIAttrDir : public AttrVisitor {
std::vector<std::string>* names; std::vector<std::string>* names;
void Visit(const char* key, double* value) override { void Visit(const char* key, double* value) final {
names->push_back(key);
}
void Visit(const char* key, int64_t* value) final {
names->push_back(key); names->push_back(key);
} }
void Visit(const char* key, int64_t* value) override { void Visit(const char* key, uint64_t* value) final {
names->push_back(key); names->push_back(key);
} }
void Visit(const char* key, int* value) override { void Visit(const char* key, bool* value) final {
names->push_back(key); names->push_back(key);
} }
void Visit(const char* key, std::string* value) override { void Visit(const char* key, int* value) final {
names->push_back(key); names->push_back(key);
} }
void Visit(const char* key, const UnaryOp** value) override { void Visit(const char* key, Type* value) final {
names->push_back(key); names->push_back(key);
} }
void Visit(const char* key, const BinaryOp** value) override { void Visit(const char* key, std::string* value) final {
names->push_back(key);
}
void Visit(const char* key, NodeRef* value) final {
names->push_back(key); names->push_back(key);
} }
}; };
...@@ -199,17 +242,9 @@ int TVMNodeGetAttr(NodeHandle handle, ...@@ -199,17 +242,9 @@ int TVMNodeGetAttr(NodeHandle handle,
if (ret->ret_value.type_id != kNull) { if (ret->ret_value.type_id != kNull) {
ret->SetReturn(ret_val, ret_typeid); ret->SetReturn(ret_val, ret_typeid);
} else { } else {
const std::string& skey = getter.skey;
(*tnode)->VisitNodeRefFields([&skey, ret](const char* key, NodeRef* ref) {
if (key == skey) ret->ret_value = *ref;
});
if (ret->ret_value.type_id != kNull) {
ret->SetReturn(ret_val, ret_typeid);
} else {
*ret_typeid = kNull; *ret_typeid = kNull;
} }
} }
}
API_END_HANDLE_ERROR(ret->Clear()); API_END_HANDLE_ERROR(ret->Clear());
} }
...@@ -223,9 +258,6 @@ int TVMNodeListAttrNames(NodeHandle handle, ...@@ -223,9 +258,6 @@ int TVMNodeListAttrNames(NodeHandle handle,
APIAttrDir dir; APIAttrDir dir;
dir.names = &(ret->ret_vec_str); dir.names = &(ret->ret_vec_str);
(*tnode)->VisitAttrs(&dir); (*tnode)->VisitAttrs(&dir);
(*tnode)->VisitNodeRefFields([ret](const char* key, NodeRef* ref) {
ret->ret_vec_str.push_back(key);
});
ret->ret_vec_charp.clear(); ret->ret_vec_charp.clear();
for (size_t i = 0; i < ret->ret_vec_str.size(); ++i) { for (size_t i = 0; i < ret->ret_vec_str.size(); ++i) {
ret->ret_vec_charp.push_back(ret->ret_vec_str[i].c_str()); ret->ret_vec_charp.push_back(ret->ret_vec_str[i].c_str());
......
...@@ -4,9 +4,6 @@ ...@@ -4,9 +4,6 @@
* \file c_api_impl.cc * \file c_api_impl.cc
*/ */
#include <tvm/expr.h> #include <tvm/expr.h>
#include <tvm/op.h>
#include <tvm/tensor.h>
#include <tvm/expr_util.h>
#include "./c_api_registry.h" #include "./c_api_registry.h"
namespace dmlc { namespace dmlc {
...@@ -18,115 +15,4 @@ namespace tvm { ...@@ -18,115 +15,4 @@ namespace tvm {
using ArgStack = const std::vector<APIVariantValue>; using ArgStack = const std::vector<APIVariantValue>;
using RetValue = APIVariantValue; using RetValue = APIVariantValue;
// expression logic x
TVM_REGISTER_API(_Var)
.set_body([](const ArgStack& args, RetValue *ret) {
*ret = Var(args.at(0),
static_cast<DataType>(static_cast<int>(args.at(1))));
})
.add_argument("name", "str", "name of the var")
.add_argument("dtype", "int", "data type of var");
TVM_REGISTER_API(constant)
.set_body([](const ArgStack& args, RetValue *ret) {
if (args.at(0).type_id == kLong) {
*ret = IntConstant(args.at(0));
} else if (args.at(0).type_id == kDouble) {
*ret = FloatConstant(args.at(0));
} else {
LOG(FATAL) << "only accept int or float";
}
})
.add_argument("src", "Number", "source number");
TVM_REGISTER_API(binary_op)
.set_body([](const ArgStack& args, RetValue *ret) {
CHECK(args.at(0).type_id == kStr);
*ret = (*BinaryOp::Get(args.at(0).str.c_str()))(args.at(1), args.at(2));
})
.add_argument("op", "str", "operator")
.add_argument("lhs", "Expr", "left operand")
.add_argument("rhs", "Expr", "right operand");
TVM_REGISTER_API(_raw_ptr)
.set_body([](const ArgStack& args, RetValue *ret) {
CHECK(args.at(0).type_id == kNodeHandle);
*ret = reinterpret_cast<int64_t>(args.at(0).sptr.get());
})
.add_argument("src", "NodeBase", "the node base");
TVM_REGISTER_API(Range)
.set_body([](const ArgStack& args, RetValue *ret) {
*ret = Range(args.at(0), args.at(1));
})
.add_argument("begin", "Expr", "beginning of the range.")
.add_argument("end", "Expr", "end of the range");
TVM_REGISTER_API(_Array)
.set_body([](const ArgStack& args, RetValue *ret) {
std::vector<std::shared_ptr<Node> > data;
for (size_t i = 0; i < args.size(); ++i) {
CHECK(args.at(i).type_id == kNodeHandle);
data.push_back(args.at(i).sptr);
}
auto node = std::make_shared<ArrayNode>();
node->data = std::move(data);
ret->type_id = kNodeHandle;
ret->sptr = node;
});
TVM_REGISTER_API(_ArrayGetItem)
.set_body([](const ArgStack& args, RetValue *ret) {
CHECK(args.at(0).type_id == kNodeHandle);
int64_t i = args.at(1);
auto& sptr = args.at(0).sptr;
CHECK(sptr->is_type<ArrayNode>());
auto* n = static_cast<const ArrayNode*>(sptr.get());
CHECK_LT(static_cast<size_t>(i), n->data.size())
<< "out of bound of array";
ret->sptr = n->data[i];
ret->type_id = kNodeHandle;
});
TVM_REGISTER_API(_ArraySize)
.set_body([](const ArgStack& args, RetValue *ret) {
CHECK(args.at(0).type_id == kNodeHandle);
auto& sptr = args.at(0).sptr;
CHECK(sptr->is_type<ArrayNode>());
*ret = static_cast<int64_t>(
static_cast<const ArrayNode*>(sptr.get())->data.size());
});
TVM_REGISTER_API(_TensorInput)
.set_body([](const ArgStack& args, RetValue *ret) {
*ret = Tensor(
static_cast<Array<Expr> >(args.at(0)),
static_cast<std::string>(args.at(1)),
static_cast<DataType>(static_cast<int>(args.at(1))));
});
TVM_REGISTER_API(simplify)
.set_body([](const ArgStack& args, RetValue *ret) {
*ret = Simplify(args.at(0));
});
// transformations
TVM_REGISTER_API(format_str)
.set_body([](const ArgStack& args, RetValue *ret) {
CHECK(args.at(0).type_id == kNodeHandle);
std::ostringstream os;
auto& sptr = args.at(0).sptr;
if (sptr->is_type<TensorNode>()) {
os << args.at(0).operator Tensor();
} else if (sptr->is_type<RDomainNode>()) {
os << args.at(0).operator RDomain();
} else if (sptr->is_type<RangeNode>()) {
os << args.at(0).operator Range();
} else {
os << args.at(0).operator Expr();
}
*ret = os.str();
})
.add_argument("expr", "Expr", "expression to be printed");
} // namespace tvm } // namespace tvm
...@@ -62,18 +62,14 @@ struct APIVariantValue { ...@@ -62,18 +62,14 @@ struct APIVariantValue {
inline operator T() const { inline operator T() const {
if (type_id == kNull) return T(); if (type_id == kNull) return T();
CHECK_EQ(type_id, kNodeHandle); CHECK_EQ(type_id, kNodeHandle);
std::shared_ptr<Node> x = sptr; return T(sptr);
T inst;
inst.node_ = std::move(x);
return inst;
} }
inline operator Expr() const { inline operator Expr() const {
if (type_id == kNull) return Expr(); if (type_id == kNull) return Expr();
if (type_id == kLong) return IntConstant(operator int64_t()); if (type_id == kLong) return Expr(operator int64_t());
if (type_id == kDouble) return FloatConstant(operator double()); if (type_id == kDouble) return Expr(operator double());
CHECK_EQ(type_id, kNodeHandle); CHECK_EQ(type_id, kNodeHandle);
std::shared_ptr<Node> x = sptr; return Expr(sptr);
return Expr(std::move(x));
} }
inline operator double() const { inline operator double() const {
CHECK_EQ(type_id, kDouble); CHECK_EQ(type_id, kDouble);
......
/*!
* Copyright (c) 2016 by Contributors
* \file domain.cc
*/
#include <tvm/domain.h>
#include <tvm/op.h>
#include <tvm/expr_node.h>
#include <tvm/expr_util.h>
namespace tvm {
Range::Range(Expr begin, Expr end) {
node_ = std::make_shared<RangeNode>(
std::move(begin), std::move(end));
}
Expr Range::extent() const {
return Simplify(end() - begin());
}
RDomain::RDomain(Domain domain) {
std::vector<Var> index;
for (size_t i = 0; i < domain.size(); ++i) {
std::ostringstream os;
os << "reduction_index" << i;
index.push_back(Var(os.str()));
}
Array<Var> idx(index);
node_ = std::make_shared<RDomainNode>(
std::move(idx), std::move(domain));
}
TVM_REGISTER_NODE_TYPE(RangeNode);
TVM_REGISTER_NODE_TYPE(ArrayNode);
TVM_REGISTER_NODE_TYPE(RDomainNode);
} // namespace tvm
/*!
* Copyright (c) 2016 by Contributors
* \file expr.cc
*/
#include <tvm/expr.h>
#include <tvm/op.h>
#include <tvm/expr_node.h>
namespace tvm {
Var::Var(std::string name, DataType dtype) {
auto node = std::make_shared<VarNode>();
node->name = std::move(name);
node->dtype_ = dtype;
node_ = std::move(node);
}
Expr IntConstant(int64_t value) {
auto nptr = std::make_shared<IntNode>();
nptr->value = value;
return Expr(std::move(nptr));
}
Expr FloatConstant(double value) {
auto nptr = std::make_shared<FloatNode>();
nptr->value = value;
return Expr(std::move(nptr));
}
Expr BufferRead(Var buffer, Expr offset) {
auto nptr = std::make_shared<BufferReadNode>();
nptr->buffer = std::move(buffer);
nptr->offset = std::move(offset);
nptr->Verify();
return Expr(std::move(nptr));
}
} // namespace tvm
/*!
* Copyright (c) 2016 by Contributors
* \file expr_node.cc
*/
#include <tvm/expr_node.h>
#include <memory>
namespace dmlc {
DMLC_REGISTRY_ENABLE(::tvm::NodeFactoryReg);
} // namespace dmlc
namespace tvm {
void Node::Destroy() {
bool safe = true;
this->VisitNodeRefFields([&safe](const char* k, NodeRef* r) {
if (r->node_.get() != nullptr) safe = false;
});
if (!safe) {
// explicit deletion via DFS
// this is used to avoid stackoverflow caused by chain of deletions
std::vector<Node*> stack{this};
std::vector<std::shared_ptr<Node> > to_delete;
while (!stack.empty()) {
Node* n = stack.back();
stack.pop_back();
n->VisitNodeRefFields([&safe, &stack, &to_delete](const char* k, NodeRef* r) {
if (r->node_.unique()) {
stack.push_back(r->node_.get());
to_delete.emplace_back(std::move(r->node_));
} else {
r->node_.reset();
}
});
}
}
}
TVM_REGISTER_NODE_TYPE(VarNode);
TVM_REGISTER_NODE_TYPE(IntNode);
TVM_REGISTER_NODE_TYPE(FloatNode);
TVM_REGISTER_NODE_TYPE(UnaryOpNode);
TVM_REGISTER_NODE_TYPE(BinaryOpNode);
TVM_REGISTER_NODE_TYPE(ReduceNode);
TVM_REGISTER_NODE_TYPE(TensorReadNode);
} // namespace tvm
/*!
* Copyright (c) 2016 by Contributors
* \file expr_util.cc
*/
#include <tvm/expr_util.h>
#include <tvm/op.h>
namespace tvm {
inline bool is_ingeter(DataType t) {
return t == kInt32;
}
/*! \brief Canonical form of expression */
struct CanonicalExpr {
/*! \brief the e->value */
std::unordered_map<Expr, int64_t> dict;
/*! \brief constant value in the expresssion */
int64_t constant{0};
// change CanonicalExpr as expr
inline Expr AsExpr() const {
Expr e;
using KV = std::pair<Expr, int64_t>;
std::vector<KV> tlist(dict.begin(), dict.end());
std::sort(tlist.begin(), tlist.end(), [](const KV& lhs, const KV& rhs) {
return lhs.first.hash() < rhs.first.hash();
});
for (auto &kv : tlist) {
if (kv.second == 0) continue;
Expr tmp;
if (kv.second == 1) {
tmp = kv.first;
} else {
tmp = kv.first * kv.second;
}
if (e.is_null()) {
e = tmp;
} else {
e = e + tmp;
}
}
if (e.is_null()) {
return IntConstant(constant);
} else {
if (constant != 0) e = e + constant;
return e;
}
}
inline void Add(const Expr& e, int beta) {
auto it = dict.find(e);
if (it != dict.end()) {
it->second += beta;
if (it->second == 0) dict.erase(it);
} else {
dict[e] = beta;
}
}
};
// out += beta * Canonicalize(e)
void AddCanonical(const Expr& e,
CanonicalExpr* out,
int beta) {
static const BinaryOp* add_op = BinaryOp::Get("+");
static const BinaryOp* sub_op = BinaryOp::Get("-");
static const BinaryOp* mul_op = BinaryOp::Get("*");
static const BinaryOp* max_op = BinaryOp::Get("max");
static const BinaryOp* min_op = BinaryOp::Get("min");
CHECK(!e.is_null()) << "cannot simplify null";
switch (e.node_type()) {
case kIntNode: {
out->constant += (e.Get<IntNode>()->value) * beta; return;
}
case kBinaryOpNode: {
const auto* n = e.Get<BinaryOpNode>();
if (n->op == add_op) {
AddCanonical(n->lhs, out, beta);
AddCanonical(n->rhs, out, beta);
return;
}
if (n->op == sub_op) {
AddCanonical(n->lhs, out, beta);
AddCanonical(n->rhs, out, -beta);
return;
}
if (n->op == mul_op) {
if (n->lhs.node_type() == kIntNode) {
AddCanonical(n->rhs, out, beta * (n->lhs.Get<IntNode>()->value)); return;
} else if (n->rhs.node_type() == kIntNode) {
AddCanonical(n->lhs, out, beta * (n->rhs.Get<IntNode>()->value)); return;
}
CanonicalExpr clhs, crhs;
AddCanonical(n->lhs, &clhs, 1);
if (clhs.dict.size() == 0) {
AddCanonical(n->rhs, out, beta * clhs.constant); return;
}
AddCanonical(n->rhs, &crhs, 1);
if (crhs.dict.size() == 0) {
AddCanonical(n->lhs, out, beta * crhs.constant); return;
}
out->Add(e, beta); return;
}
if (n->op == max_op) {
CanonicalExpr res;
AddCanonical(n->lhs, &res, 1);
AddCanonical(n->rhs, &res, -1);
if (res.dict.size() == 0) {
if (res.constant > 0) {
AddCanonical(n->lhs, out, beta); return;
} else {
AddCanonical(n->rhs, out, beta); return;
}
} else {
out->Add(e, beta); return;
}
}
if (n->op == min_op) {
CanonicalExpr res;
AddCanonical(n->lhs, &res, 1);
AddCanonical(n->rhs, &res, -1);
if (res.dict.size() == 0) {
if (res.constant <= 0) {
AddCanonical(n->lhs, out, beta); return;
} else {
AddCanonical(n->rhs, out, beta); return;
}
} else {
out->Add(e, beta); return;
}
}
out->Add(e, beta);
return;
}
default: {
out->Add(e, beta); return;
}
}
}
Expr Simplify(Expr src) {
CanonicalExpr cexpr;
AddCanonical(src, &cexpr, 1);
return cexpr.AsExpr();
}
Expr ExprWithNewChildren(Expr src, std::vector<Expr> children) {
if (children.size()) {
switch (src.node_type()) {
case kBinaryOpNode: {
const auto* n = src.Get<BinaryOpNode>();
if (n->lhs == children[0] && n->rhs == children[0])
return src;
return (*n->op)(children[0], children[1]);
}
case kUnaryOpNode: {
const auto* n = src.Get<UnaryOpNode>();
if (n->src == children[0])
return src;
return (*n->op)(children[0]);
}
case kReduceNode: {
const auto* n = src.Get<ReduceNode>();
if (n->src == children[0])
return src;
return (n->op)->Reduce(children[0], n->rdom);
}
case kTensorReadNode: {
const auto* n = src.Get<TensorReadNode>();
bool same = true;
for (size_t i = 0; i < n->indices.size(); ++i) {
if (n->indices[i] != children[i]) {
same = false;
break;
}
}
if (same)
return src;
Array<Expr> indices(children);
return n->tensor(indices);
}
default: {
return src;
}
}
}
return src;
}
Expr Bind(Expr src, std::unordered_map<Expr, Expr> dict) {
auto replace = [&](Expr e, std::vector<Expr> children) {
switch (e.node_type()) {
case kVarNode: {
auto it = dict.find(e);
if (it != dict.end()) {
return it->second;
}
return e;
}
default: {
return ExprWithNewChildren(e, children);
}
}
};
return Transform(src, replace);
}
void Expr::Print(std::ostream& os) const {
if (is_null()) {
os << "null"; return;
}
switch (this->node_type()) {
case kVarNode: {
os << Get<VarNode>()->name; return;
}
case kIntNode: {
os << Get<IntNode>()->value; return;
}
case kFloatNode: {
os << Get<FloatNode>()->value; return;
}
case kBinaryOpNode: {
const auto* n = Get<BinaryOpNode>();
const char* fname = n->op->FunctionName();
if (fname[1] == '\0' && !isalpha(fname[0])) {
os << '(';
n->lhs.Print(os);
os << ' ' << fname[0] << ' ';
n->rhs.Print(os);
os << ')';
} else {
os << fname << '(';
n->lhs.Print(os);
os << ", ";
n->rhs.Print(os);
os << ')';
}
return;
}
case kUnaryOpNode: {
const auto* n = Get<UnaryOpNode>();
os << n->op->FunctionName() << '(';
n->src.Print(os);
os << ')';
return;
}
case kReduceNode: {
const auto* n = Get<ReduceNode>();
os << "reduce("<< n->op->FunctionName() << ", ";
n->src.Print(os);
os << ", " << n->rdom << ')';
return;
}
case kTensorReadNode: {
const auto* n = Get<TensorReadNode>();
os << n->tensor->name << n->indices;
return;
}
default: {
LOG(FATAL) << "not able to handle type " << typeid(node_.get()).name();
}
}
}
} // namespace tvm
/*!
* Copyright (c) 2016 by Contributors
* \file op.cc
*/
#include <tvm/op.h>
#include <tvm/expr_node.h>
namespace dmlc {
DMLC_REGISTRY_ENABLE(::tvm::BinaryOpReg);
DMLC_REGISTRY_ENABLE(::tvm::UnaryOpReg);
} // namespace dmlc
namespace tvm {
Expr UnaryOp::operator()(Expr src) const {
auto nptr = std::make_shared<UnaryOpNode>(this, std::move(src));
nptr->Verify();
return Expr(std::move(nptr));
}
Expr BinaryOp::operator()(Expr lhs, Expr rhs) const {
auto nptr = std::make_shared<BinaryOpNode>(
this, std::move(lhs), std::move(rhs));
nptr->Verify();
return Expr(std::move(nptr));
}
Expr BinaryOp::Reduce(Expr src, RDomain rdom) const {
auto nptr = std::make_shared<ReduceNode>(
this, std::move(src), std::move(rdom));
nptr->Verify();
return Expr(std::move(nptr));
}
const BinaryOp* BinaryOp::Get(const char* name) {
const auto* op = dmlc::Registry<BinaryOpReg>::Find(name);
CHECK(op != nullptr) << "cannot find " << name;
return op->op.get();
}
TVM_REGISTER_BINARY_OP(+, AddOp);
TVM_REGISTER_BINARY_OP(-, SubOp);
TVM_REGISTER_BINARY_OP(*, MulOp);
TVM_REGISTER_BINARY_OP(/, DivOp);
TVM_REGISTER_BINARY_OP(max, MaxOp);
TVM_REGISTER_BINARY_OP(min, MinOp);
} // namespace tvm
/*!
* Copyright (c) 2016 by Contributors
* \file stmt.cc
*/
#include <tvm/expr.h>
#include <tvm/stmt.h>
#include <tvm/stmt_node.h>
namespace tvm {
Stmt Store(Var buffer, Expr offset, Expr src) {
auto nptr = std::make_shared<StoreNode>();
nptr->buffer = std::move(buffer);
nptr->offset = std::move(offset);
nptr->src = std::move(src);
nptr->Verify();
return Stmt(std::move(nptr));
}
Stmt ForRange(Var loop_var, Range range, Stmt body) {
auto nptr = std::make_shared<ForRangeNode>();
nptr->loop_var = std::move(loop_var);
nptr->range = std::move(range);
nptr->body = std::move(body);
nptr->Verify();
return Stmt(std::move(nptr));
}
Stmt IfThenElse(Expr cond, Stmt then_body, Stmt else_body) {
auto nptr = std::make_shared<IfThenElseNode>();
nptr->cond = std::move(cond);
nptr->then_body = std::move(then_body);
nptr->else_body = std::move(else_body);
nptr->Verify();
return Stmt(std::move(nptr));
}
} // namespace tvm
/*!
* Copyright (c) 2016 by Contributors
* \file tensor.cc
*/
#include <tvm/tensor.h>
#include <tvm/expr_node.h>
#include <tvm/expr_util.h>
#include <memory>
namespace tvm {
Tensor::Tensor(Array<Expr> shape, std::string name, DataType dtype) {
auto node = std::make_shared<TensorNode>();
node->name = std::move(name);
node->dtype = dtype;
node->shape = std::move(shape);
node_ = std::move(node);
}
Tensor::Tensor(Array<Expr> shape, FCompute fcompute, std::string name) {
auto node = std::make_shared<TensorNode>();
node->name = std::move(name);
node->shape = std::move(shape);
size_t ndim = node->shape.size();
std::vector<Var> dim_index;
for (size_t i = 0; i < ndim; ++i) {
std::ostringstream os;
os << "dim_index" << i;
dim_index.push_back(Var(os.str()));
}
node->dim_index = Array<Var>(dim_index);
node->source = fcompute(node->dim_index);
node->dtype = node->source.dtype();
node_ = std::move(node);
}
Expr Tensor::operator()(Array<Expr> indices) const {
CHECK_EQ(ndim(), indices.size())
<< "Tensor dimension mismatch in read"
<< "ndim = " << ndim() << ", indices.size=" << indices.size();
auto node = std::make_shared<TensorReadNode>();
node->tensor = *this;
node->indices = std::move(indices);
return Expr(std::move(node));
}
std::vector<Tensor> Tensor::InputTensors() const {
const TensorNode* n = static_cast<const TensorNode*>(node_.get());
std::vector<Tensor> inputs;
if (n->source.is_null()) return inputs;
Visit(n->source, [&inputs](const Expr& e) {
if (e.node_type() == kTensorReadNode) {
inputs.push_back(e.Get<TensorReadNode>()->tensor);
}
});
return inputs;
}
bool Tensor::IsRTensor() const {
const TensorNode* n = static_cast<const TensorNode*>(node_.get());
if (n->source.is_null()) return false;
return n->source.node_type() == kReduceNode;
}
TVM_REGISTER_NODE_TYPE(TensorNode);
} // namespace tvm
/*!
* Copyright (c) 2016 by Contributors
* \file expr_node.cc
*/
#include <tvm/base.h>
#include <tvm/expr.h>
#include <ir/IR.h>
#include <memory>
namespace dmlc {
DMLC_REGISTRY_ENABLE(::tvm::NodeFactoryReg);
} // namespace dmlc
namespace tvm {
using namespace Halide::Internal;
TVM_REGISTER_NODE_TYPE(FloatImm);
TVM_REGISTER_NODE_TYPE(IntImm);
TVM_REGISTER_NODE_TYPE(UIntImm);
TVM_REGISTER_NODE_TYPE(StringImm);
TVM_REGISTER_NODE_TYPE(Cast);
TVM_REGISTER_NODE_TYPE(Variable);
TVM_REGISTER_NODE_TYPE(Add);
TVM_REGISTER_NODE_TYPE(Sub);
TVM_REGISTER_NODE_TYPE(Mul);
TVM_REGISTER_NODE_TYPE(Div);
TVM_REGISTER_NODE_TYPE(Mod);
TVM_REGISTER_NODE_TYPE(Min);
TVM_REGISTER_NODE_TYPE(Max);
TVM_REGISTER_NODE_TYPE(EQ);
TVM_REGISTER_NODE_TYPE(NE);
TVM_REGISTER_NODE_TYPE(LT);
TVM_REGISTER_NODE_TYPE(LE);
TVM_REGISTER_NODE_TYPE(GT);
TVM_REGISTER_NODE_TYPE(GE);
TVM_REGISTER_NODE_TYPE(And);
TVM_REGISTER_NODE_TYPE(Or);
TVM_REGISTER_NODE_TYPE(Not);
TVM_REGISTER_NODE_TYPE(Select);
TVM_REGISTER_NODE_TYPE(Load);
TVM_REGISTER_NODE_TYPE(Ramp);
TVM_REGISTER_NODE_TYPE(Broadcast);
TVM_REGISTER_NODE_TYPE(Call);
TVM_REGISTER_NODE_TYPE(Let);
TVM_REGISTER_NODE_TYPE(LetStmt);
TVM_REGISTER_NODE_TYPE(AssertStmt);
TVM_REGISTER_NODE_TYPE(ProducerConsumer);
TVM_REGISTER_NODE_TYPE(For);
TVM_REGISTER_NODE_TYPE(Store);
TVM_REGISTER_NODE_TYPE(Provide);
TVM_REGISTER_NODE_TYPE(Allocate);
TVM_REGISTER_NODE_TYPE(Free);
TVM_REGISTER_NODE_TYPE(Realize);
TVM_REGISTER_NODE_TYPE(Block);
TVM_REGISTER_NODE_TYPE(IfThenElse);
TVM_REGISTER_NODE_TYPE(Evaluate);
} // namespace tvm
/*!
* Copyright (c) 2016 by Contributors
* \file schedule.cc
*/
#include <tvm/schedule.h>
Markdown is supported
0% or
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment