Commit c41d9d23 by tqchen

Temp checkin c++ code.

parent 1a18f08e
...@@ -88,4 +88,6 @@ ENV/ ...@@ -88,4 +88,6 @@ ENV/
# Rope project settings # Rope project settings
.ropeproject .ropeproject
*~ *~
*.pyc *.pyc
\ No newline at end of file *~
build
[submodule "dmlc-core"]
path = dmlc-core
url = https://github.com/dmlc/dmlc-core
export LDFLAGS = -pthread -lm
export CFLAGS = -std=c++11 -Wall -O2 -Wno-unknown-pragmas -funroll-loops\
-Iinclude -Idmlc-core/include -fPIC
# specify tensor path
.PHONY: clean all
all: lib/libtvm.a
SRC = $(wildcard src/*.cc src/*/*.cc)
ALL_OBJ = $(patsubst src/%.cc, build/%.o, $(SRC))
ALL_DEP = $(ALL_OBJ)
build/%.o: src/%.cc
@mkdir -p $(@D)
$(CXX) $(CFLAGS) -MM -MT build/$*.o $< >build/$*.d
$(CXX) -c $(CFLAGS) -c $< -o $@
lib/libtvm.a: $(ALL_DEP)
@mkdir -p $(@D)
ar crv $@ $(filter %.o, $?)
lint:
python2 dmlc-core/scripts/lint.py tvm cpp include src
clean:
$(RM) -rf build lib bin *~ */*~ */*/*~ */*/*/*~ */*.o */*/*.o */*/*/*.o
-include build/*.d
-include build/*/*.d
Subproject commit 39007ac49b6087339dc3104324cb4e0de47f1c5f
/*!
* Copyright (c) 2016 by Contributors
* \file base.h
* \brief Defines the base data structure
*/
#ifndef TVM_BASE_H_
#define TVM_BASE_H_
#include <dmlc/logging.h>
#include <dmlc/registry.h>
#include <string>
#include <memory>
#include <functional>
#include <typeinfo>
namespace tvm {
// forward declaration
class Node;
class NodeRef;
class UnaryOp;
class BinaryOp;
/*! \brief list of all supported data types */
enum DataType {
kUnknown,
kInt32,
kFloat32
};
/*!
* \brief List of subset node types used for quick runtime switch.
*
* \note The value of NodeType is not used for serialization type_key is used instead.
* \note is_type and type_key can be used to do type checking for all types
* \note kOtherNodes could mean more than one node type.
*/
enum NodeType {
kVarNode,
kIntNode,
kFloatNode,
kUnaryOpNode,
kBinaryOpNode,
kReduceNode,
kTensorReadNode,
kOtherNodes
};
/*!
* \brief Visitor class to each node content.
* The content is going to be called for each field.
*/
class AttrVisitor {
public:
//! \cond Doxygen_Suppress
virtual void Visit(const char* key, double* value) = 0;
virtual void Visit(const char* key, int64_t* value) = 0;
virtual void Visit(const char* key, DataType* value) = 0;
virtual void Visit(const char* key, std::string* value) = 0;
virtual void Visit(const char* key, const UnaryOp** value) = 0;
virtual void Visit(const char* key, const BinaryOp** value) = 0;
//! \endcond
};
/*!
* \brief A function to be applied when visit each NodeRef Field.
* \param ref The child to be visited.
*/
using FNodeRefVisit = std::function<void(const char* key, NodeRef* ref)>;
/*!
* \brief base class of node container in DSL AST.
* All object's internal is stored as std::shared_ptr<Node>
*/
class Node {
public:
/*! \brief virtual destructor */
virtual ~Node();
/*! \return The unique type key of the node */
virtual const char* type_key() const = 0;
/*! \brief verify the correctness of node struct after it get mutated by visitor */
virtual void Verify() const {}
/*!
* \brief Apply visitor to each field of the Node
* Visitor could mutate the content of the node.
* override if Node contains attribute fields.
* \param visitor The visitor
*/
virtual void VisitAttrs(AttrVisitor* visitor) {}
/*!
* \brief Apply visitor to each field of the Node
* Visitor could mutate the content of the node.
* override if Node contains NodeRefFields.
* \param visitor The visitor
*/
virtual void VisitNodeRefFields(FNodeRefVisit visitor) {}
/*!
* \tparam NodeType the type to be checked.
* \return whether the stored type is node type
*/
template<typename TNode>
inline bool is_type() const;
/*! \return the node type */
inline NodeType node_type() const;
protected:
// node ref can see this
friend class NodeRef;
/*! \brief the node type enum */
NodeType node_type_{kOtherNodes};
};
/*! \brief base class of all node reference object */
class NodeRef {
public:
/*!
* \return typed pointer of the node
* \tparam TNode the type of the node.
*/
template<typename TNode>
inline const TNode* Get() const;
/*! \return wheyjer the expression is null */
inline bool is_null() const;
protected:
NodeRef() = default;
explicit NodeRef(std::shared_ptr<Node> node) : node_(node) {}
/*! \brief the internal node */
std::shared_ptr<Node> node_;
};
/*! \brief typedef the factory function of data iterator */
using NodeFactory = std::function<std::shared_ptr<Node> ()>;
/*!
* \brief Registry entry for DataIterator factory functions.
*/
struct NodeFactoryReg
: public dmlc::FunctionRegEntryBase<NodeFactoryReg,
NodeFactory> {
};
#define TVM_REGISTER_NODE_TYPE(TypeName) \
DMLC_REGISTRY_REGISTER(::tvm::NodeFactoryReg, NodeFactoryReg, TypeName) \
.set_body([]() { return std::make_shared<TypeName>(); })
// implementations of inline functions after this
inline NodeType Node::node_type() const {
return node_type_;
}
template<typename TNode>
inline bool Node::is_type() const {
const std::type_info& tinfo = typeid(*this);
if (&typeid(TNode) == &tinfo) return true;
return typeid(TNode) == tinfo;
}
template<typename TNode>
inline const TNode* NodeRef::Get() const {
CHECK(node_->is_type<TNode>())
<< " type inconsistent, expected " << typeid(TNode).name()
<< " given " << typeid(*this).name();
return static_cast<const TNode*>(node_.get());
}
inline bool NodeRef::is_null() const {
return node_.get() == nullptr;
}
} // namespace tvm
#endif // TVM_BASE_H_
/*!
* Copyright (c) 2016 by Contributors
* \file c_api.h
* \brief C API of TVM DSL
*/
#ifndef TVM_C_API_H_
#define TVM_C_API_H_
#ifdef __cplusplus
#define TVM_EXTERN_C extern "C"
#else
#define TVM_EXTERN_C
#endif
/*! \brief TVM_DLL prefix for windows */
#ifdef _WIN32
#ifdef TVM_EXPORTS
#define TVM_DLL TVM_EXTERN_C __declspec(dllexport)
#else
#define TVM_DLL TVM_EXTERN_C __declspec(dllimport)
#endif
#else
#define TVM_DLL TVM_EXTERN_C
#endif
/*! \brief handle to node creator */
typedef void* NodeCreatorHandle;
/*! \brief handle to node */
typedef void* NodeHandle;
TVM_DLL int TVMNodeCreatorGet(const char* node_type,
NodeCreatorHandle *handle);
TVM_DLL int TVMNodeCreate(NodeCreatorHandle handle,
int num_child_ref,
const char* child_ref_keys,
NodeHandle* child_node_refs,
int num_attrs,
const char* attr_keys,
const char* attr_vals,
NodeHandle* handle);
TVM_DLL int TVMNodeGetAttr(const char* key,
const char** value);
TVM_DLL int TVMNodeGetChildNodeRef(const char* key,
NodeHandle* out);
#endif // TVM_C_API_H_
/*!
* Copyright (c) 2016 by Contributors
* \file domain.h
* \brief Defines the AST
*/
#ifndef TVM_DOMAIN_H_
#define TVM_DOMAIN_H_
#include <memory>
namespace tvm {
class RDom {
};
} // namespace tvm
#endif // TVM_DOMAIN_H_
/*!
* Copyright (c) 2016 by Contributors
* \file expr.h
* \brief Defines the expressions in AST.
*/
#ifndef TVM_EXPR_H_
#define TVM_EXPR_H_
#include <type_traits>
#include "./base.h"
namespace tvm {
// forward declare Expr
class Expr;
/*!
* \brief create a constant expression
* \tparam T the value type
* \param value The value to the constant.
* \return The created expression
*/
template<typename T,
typename = typename std::enable_if<std::is_arithmetic<T>::value>::type >
inline Expr constant(T value);
/*!
* \brief a expression type, holds a ref to root of an AST
*/
class Expr : public NodeRef {
public:
/*! \brief default constructor */
Expr() = default;
/*!
* \brief copy constructor
* \param other the input
*/
Expr(const Expr& other) = default; // NOLINT(*)
/*!
* \brief move constructor
* \param other the input
*/
Expr(Expr&& other) = default; // NOLINT(*)
/*!
* \brief assign operator.
* \param other the input.
* \return reference to self
*/
Expr& operator=(const Expr& other) = default;
/*!
* \brief assign move operator.
* \param other the input.
* \return reference to self
*/
Expr& operator=(Expr&& other) = default;
/*!
* \brief constructor from constant value
* \param value the constant value
* \tparam T The constant type
*/
template<typename T,
typename = typename std::enable_if<std::is_arithmetic<T>::value>::type >
Expr(T value) { // NOLINT(*)
*this = std::move(constant<T>(value));
}
/*!
* \brief constructor from node pointer
* \param nptr Another node shared pointer
*/
explicit Expr(std::shared_ptr<Node> nptr) : NodeRef(nptr) {}
/*! \return the expression type of the expression */
inline DataType dtype() const;
};
/*! \brief Variable class */
class Var : public Expr {
public:
Var(std::string name="", DataType dtype=kInt32); // NOLINT(*)
};
/*! \brief */
Expr IntConstant(int64_t value);
Expr FloatConstant(int64_t value);
Expr operator+(Expr lhs, Expr rhs);
/*! \brief base of expression node */
class ExprNode : public Node {
public:
/*! \brief type of data stored in expression */
DataType dtype_{kUnknown};
};
// inline implementations
inline DataType Expr::dtype() const {
return static_cast<const ExprNode*>(node_.get())->dtype_;
}
template<typename T,
typename = typename std::enable_if<std::is_arithmetic<T>::value>::type >
inline Expr constant(T value) {
if (std::is_integral<T>::value) {
return IntConstant(static_cast<int64_t>(value));
} else {
return FloatConstant(static_cast<double>(value));
}
}
} // namespace tvm
#endif // TVM_EXPR_H_
/*!
* Copyright (c) 2016 by Contributors
* \file expr_node.h
* \brief Defines the expression nodes in AST.
*/
#ifndef TVM_EXPR_NODE_H_
#define TVM_EXPR_NODE_H_
#include <string>
#include "./domain.h"
#include "./tensor.h"
#include "./expr.h"
namespace tvm {
/*! \brief variable node for symbolic variables */
class VarNode : public ExprNode {
public:
/*! \brief hint name of the variable */
std::string name;
/*! \brief constructor */
VarNode() {
node_type_ = kVarNode;
}
const char* type_key() const override {
return "VarNode";
}
void VisitAttrs(AttrVisitor* visitor) override {
visitor->Visit("name", &name);
}
};
/*! \brief integer constant node */
class IntNode : public ExprNode {
public:
/*! \brief the value field */
int64_t value;
/*! \brief constructor */
IntNode() {
node_type_ = kIntNode;
dtype_ = kInt32;
}
const char* type_key() const override {
return "IntNode";
}
void VisitAttrs(AttrVisitor* visitor) override {
visitor->Visit("value", &value);
}
};
/*! \brief float constant node */
class FloatNode : public ExprNode {
public:
/*! \brief the value field */
double value;
/*! \brief constructor */
FloatNode() {
node_type_ = kFloatNode;
dtype_ = kFloat32;
}
const char* type_key() const override {
return "IntNode";
}
void VisitAttrs(AttrVisitor* visitor) override {
visitor->Visit("value", &value);
}
};
/*! \brief Unary mapping operator */
class UnaryOpNode : public ExprNode {
public:
/*! \brief The operator */
const UnaryOp* op;
/*! \brief The source expression */
Expr src;
/*! \brief constructor */
UnaryOpNode() {
node_type_ = kUnaryOpNode;
}
const char* type_key() const override {
return "UnaryOpNode";
}
void Verify() const override {
CHECK_EQ(dtype_, src.dtype());
}
void VisitAttrs(AttrVisitor* visitor) override {
visitor->Visit("op", &op);
}
void VisitNodeRefFields(FNodeRefVisit fvisit) override {
fvisit("src", &src);
}
};
/*! \brief Binary mapping operator */
struct BinaryOpNode : public ExprNode {
public:
/*! \brief The operator */
const BinaryOp* op;
/*! \brief The left operand */
Expr lhs;
/*! \brief The right operand */
Expr rhs;
/*! \brief constructor, do not use constructor */
BinaryOpNode() {
node_type_ = kBinaryOpNode;
}
BinaryOpNode(const BinaryOp* op, Expr lhs, Expr rhs)
: op(op), lhs(lhs), rhs(rhs) {
node_type_ = kBinaryOpNode;
dtype_ = lhs.dtype();
}
const char* type_key() const override {
return "BinaryOpNode";
}
void Verify() const override {
CHECK_EQ(dtype_, lhs.dtype());
CHECK_EQ(dtype_, rhs.dtype());
}
void VisitAttrs(AttrVisitor* visitor) override {
visitor->Visit("op", &op);
}
void VisitNodeRefFields(FNodeRefVisit fvisit) override {
fvisit("lhs", &lhs);
fvisit("rhs", &rhs);
}
};
} // namespace tvm
#endif // TVM_EXPR_NODE_H_
/*!
* Copyright (c) 2016 by Contributors
* \file op.h
* \brief Defines the operators
*/
#ifndef TVM_OP_H_
#define TVM_OP_H_
#include <string>
#include "./expr.h"
namespace tvm {
class BinaryOp {
public:
virtual std::string Format(const std::string& lhs, const std::string& rhs);
};
class UnaryOp {
public:
};
class AddOp : public BinaryOp {
public:
static AddOp* Get();
};
class SubOp : public BinaryOp {
public:
static SubOp* Get();
};
class MulOp : public BinaryOp {
public:
static SubOp* Get();
};
class DivOp : public BinaryOp {
public:
static DivOp* Get();
};
} // namespace tvm
#endif // TVM_OP_H_
/*!
* Copyright (c) 2016 by Contributors
* \file tensor.h
* \brief Dataflow tensor object
*/
#ifndef TVM_TENSOR_H_
#define TVM_TENSOR_H_
#include "./expr.h"
namespace tvm {
class Tensor {
private:
/*! \brief The shape of the tensor */
/*! \brief source expression */
Expr src_expr;
};
} // namespace tvm
#endif // TVM_TENSOR_H_
/*!
* Copyright (c) 2016 by Contributors
* \file expr.cc
*/
#include <tvm/expr.h>
#include <tvm/op.h>
#include <tvm/expr_node.h>
namespace tvm {
Var::Var(std::string name, DataType dtype) {
auto node_ = std::make_shared<VarNode>();
node_->name = std::move(name);
node_->dtype_ = dtype;
}
Expr IntConstant(int64_t value) {
auto nptr = std::make_shared<IntNode>();
nptr->value = value;
return Expr(std::move(nptr));
}
Expr FloatConstant(double value) {
auto nptr = std::make_shared<FloatNode>();
nptr->value = value;
return Expr(std::move(nptr));
}
Expr operator+(Expr lhs, Expr rhs) {
auto nptr = std::make_shared<BinaryOpNode>(AddOp::Get(), lhs, rhs);
nptr->Verify();
return Expr(std::move(nptr));
}
} // namespace tvm
/*!
* Copyright (c) 2016 by Contributors
* \file expr_node.cc
*/
#include <tvm/expr_node.h>
#include <memory>
namespace tvm {
TVM_REGISTER_NODE_TYPE(VarNode);
TVM_REGISTER_NODE_TYPE(IntNode);
TVM_REGISTER_NODE_TYPE(FloatNode);
TVM_REGISTER_NODE_TYPE(UnaryOpNode);
TVM_REGISTER_NODE_TYPE(BinaryOpNode);
} // namespace tvm
/*!
* Copyright (c) 2016 by Contributors
* \file op.cc
*/
#include <tvm/op.h>
namespace tvm {
}
Markdown is supported
0% or
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment