Commit 8278e02f by tqchen

checkin basic cpp test

parent 05e871d4
...@@ -3,13 +3,17 @@ export CFLAGS = -std=c++11 -Wall -O2 -Wno-unknown-pragmas -funroll-loops\ ...@@ -3,13 +3,17 @@ export CFLAGS = -std=c++11 -Wall -O2 -Wno-unknown-pragmas -funroll-loops\
-Iinclude -Idmlc-core/include -fPIC -Iinclude -Idmlc-core/include -fPIC
# specify tensor path # specify tensor path
.PHONY: clean all .PHONY: clean all test
all: lib/libtvm.a all: lib/libtvm.a
SRC = $(wildcard src/*.cc src/*/*.cc) SRC = $(wildcard src/*.cc src/*/*.cc)
ALL_OBJ = $(patsubst src/%.cc, build/%.o, $(SRC)) ALL_OBJ = $(patsubst src/%.cc, build/%.o, $(SRC))
ALL_DEP = $(ALL_OBJ) ALL_DEP = $(ALL_OBJ)
include tests/cpp/unittest.mk
test: $(TEST)
build/%.o: src/%.cc build/%.o: src/%.cc
@mkdir -p $(@D) @mkdir -p $(@D)
$(CXX) $(CFLAGS) -MM -MT build/$*.o $< >build/$*.d $(CXX) $(CFLAGS) -MM -MT build/$*.o $< >build/$*.d
...@@ -24,7 +28,7 @@ lint: ...@@ -24,7 +28,7 @@ lint:
python2 dmlc-core/scripts/lint.py tvm cpp include src python2 dmlc-core/scripts/lint.py tvm cpp include src
clean: clean:
$(RM) -rf build lib bin *~ */*~ */*/*~ */*/*/*~ */*.o */*/*.o */*/*/*.o $(RM) -rf build lib bin *~ */*~ */*/*~ */*/*/*~ */*.o */*/*.o */*/*/*.o */*.d */*/*.d */*/*/*.d
-include build/*.d -include build/*.d
-include build/*/*.d -include build/*/*.d
...@@ -76,7 +76,7 @@ using FNodeRefVisit = std::function<void(const char* key, NodeRef* ref)>; ...@@ -76,7 +76,7 @@ using FNodeRefVisit = std::function<void(const char* key, NodeRef* ref)>;
class Node { class Node {
public: public:
/*! \brief virtual destructor */ /*! \brief virtual destructor */
virtual ~Node(); virtual ~Node() {}
/*! \return The unique type key of the node */ /*! \return The unique type key of the node */
virtual const char* type_key() const = 0; virtual const char* type_key() const = 0;
/*! \brief verify the correctness of node struct after it get mutated by visitor */ /*! \brief verify the correctness of node struct after it get mutated by visitor */
...@@ -101,8 +101,6 @@ class Node { ...@@ -101,8 +101,6 @@ class Node {
*/ */
template<typename TNode> template<typename TNode>
inline bool is_type() const; inline bool is_type() const;
/*! \return the node type */
inline NodeType node_type() const;
protected: protected:
// node ref can see this // node ref can see this
...@@ -120,12 +118,15 @@ class NodeRef { ...@@ -120,12 +118,15 @@ class NodeRef {
*/ */
template<typename TNode> template<typename TNode>
inline const TNode* Get() const; inline const TNode* Get() const;
/*! \return the node type */
inline NodeType node_type() const;
/*! \return wheyjer the expression is null */ /*! \return wheyjer the expression is null */
inline bool is_null() const; inline bool is_null() const;
protected:
NodeRef() = default; NodeRef() = default;
explicit NodeRef(std::shared_ptr<Node> node) : node_(node) {} explicit NodeRef(std::shared_ptr<Node>&& node) : node_(std::move(node)) {}
protected:
/*! \brief the internal node */ /*! \brief the internal node */
std::shared_ptr<Node> node_; std::shared_ptr<Node> node_;
}; };
...@@ -146,8 +147,8 @@ struct NodeFactoryReg ...@@ -146,8 +147,8 @@ struct NodeFactoryReg
.set_body([]() { return std::make_shared<TypeName>(); }) .set_body([]() { return std::make_shared<TypeName>(); })
// implementations of inline functions after this // implementations of inline functions after this
inline NodeType Node::node_type() const { inline NodeType NodeRef::node_type() const {
return node_type_; return node_->node_type_;
} }
template<typename TNode> template<typename TNode>
......
...@@ -34,12 +34,12 @@ class Expr : public NodeRef { ...@@ -34,12 +34,12 @@ class Expr : public NodeRef {
* \brief copy constructor * \brief copy constructor
* \param other the input * \param other the input
*/ */
Expr(const Expr& other) = default; // NOLINT(*) Expr(const Expr& other) = default;
/*! /*!
* \brief move constructor * \brief move constructor
* \param other the input * \param other the input
*/ */
Expr(Expr&& other) = default; // NOLINT(*) Expr(Expr&& other) = default;
/*! /*!
* \brief assign operator. * \brief assign operator.
* \param other the input. * \param other the input.
...@@ -66,9 +66,20 @@ class Expr : public NodeRef { ...@@ -66,9 +66,20 @@ class Expr : public NodeRef {
* \brief constructor from node pointer * \brief constructor from node pointer
* \param nptr Another node shared pointer * \param nptr Another node shared pointer
*/ */
explicit Expr(std::shared_ptr<Node> nptr) : NodeRef(nptr) {} explicit Expr(std::shared_ptr<Node>&& nptr) : NodeRef(std::move(nptr)) {
CHECK(node_.get() != nullptr);
}
/*! \return the expression type of the expression */ /*! \return the expression type of the expression */
inline DataType dtype() const; inline DataType dtype() const;
// print the expression.
friend std::ostream& operator<<(std::ostream &os, const Expr& e) { // NOLINT(*)
e.Print(os);
return os;
}
private:
// print the expression
void Print(std::ostream& os) const; // NOLINT(*)
}; };
/*! \brief Variable class */ /*! \brief Variable class */
...@@ -77,10 +88,8 @@ class Var : public Expr { ...@@ -77,10 +88,8 @@ class Var : public Expr {
Var(std::string name="", DataType dtype=kInt32); // NOLINT(*) Var(std::string name="", DataType dtype=kInt32); // NOLINT(*)
}; };
/*! \brief */
Expr IntConstant(int64_t value); Expr IntConstant(int64_t value);
Expr FloatConstant(int64_t value); Expr FloatConstant(int64_t value);
Expr operator+(Expr lhs, Expr rhs);
/*! \brief base of expression node */ /*! \brief base of expression node */
class ExprNode : public Node { class ExprNode : public Node {
......
...@@ -77,6 +77,11 @@ class UnaryOpNode : public ExprNode { ...@@ -77,6 +77,11 @@ class UnaryOpNode : public ExprNode {
UnaryOpNode() { UnaryOpNode() {
node_type_ = kUnaryOpNode; node_type_ = kUnaryOpNode;
} }
UnaryOpNode(const UnaryOp* op, Expr && src)
: op(op), src(std::move(src)) {
node_type_ = kUnaryOpNode;
dtype_ = this->src.dtype();
}
const char* type_key() const override { const char* type_key() const override {
return "UnaryOpNode"; return "UnaryOpNode";
} }
...@@ -104,10 +109,10 @@ struct BinaryOpNode : public ExprNode { ...@@ -104,10 +109,10 @@ struct BinaryOpNode : public ExprNode {
BinaryOpNode() { BinaryOpNode() {
node_type_ = kBinaryOpNode; node_type_ = kBinaryOpNode;
} }
BinaryOpNode(const BinaryOp* op, Expr lhs, Expr rhs) BinaryOpNode(const BinaryOp* op, Expr && lhs, Expr && rhs)
: op(op), lhs(lhs), rhs(rhs) { : op(op), lhs(std::move(lhs)), rhs(std::move(rhs)) {
node_type_ = kBinaryOpNode; node_type_ = kBinaryOpNode;
dtype_ = lhs.dtype(); dtype_ = this->lhs.dtype();
} }
const char* type_key() const override { const char* type_key() const override {
return "BinaryOpNode"; return "BinaryOpNode";
......
/*!
* Copyright (c) 2016 by Contributors
* \file expr_util.h
* \brief Expression util
*/
#ifndef TVM_EXPR_UTIL_H_
#define TVM_EXPR_UTIL_H_
#include "./expr.h"
namespace tvm {
} // namespace tvm
#endif // TVM_EXPR_UTIL_H_
...@@ -11,35 +11,106 @@ ...@@ -11,35 +11,106 @@
namespace tvm { namespace tvm {
/*! \brief binary operator */
class BinaryOp { class BinaryOp {
public: public:
virtual std::string Format(const std::string& lhs, const std::string& rhs); /*! \return the function name to be called in binary op */
virtual const char* FunctionName() const = 0;
/*!
* \brief apply the binary op
* \param lhs left operand
* \param rhs right operand
* \return the result expr
*/
Expr operator()(Expr lhs, Expr rhs) const;
}; };
/*! \brief unary operator */
class UnaryOp { class UnaryOp {
public: public:
/*! \return the function name to be called in unary op */
virtual const char* FunctionName() const = 0;
/*!
* \brief apply the unary op
* \param src left operand
* \return the result expr
*/
Expr operator()(Expr lhs, Expr rhs) const;
}; };
class AddOp : public BinaryOp { class AddOp : public BinaryOp {
public: public:
const char* FunctionName() const override {
return "+";
}
static AddOp* Get(); static AddOp* Get();
}; };
class SubOp : public BinaryOp { class SubOp : public BinaryOp {
public: public:
const char* FunctionName() const override {
return "-";
}
static SubOp* Get(); static SubOp* Get();
}; };
class MulOp : public BinaryOp { class MulOp : public BinaryOp {
public: public:
static SubOp* Get(); const char* FunctionName() const override {
return "*";
}
static MulOp* Get();
}; };
class DivOp : public BinaryOp { class DivOp : public BinaryOp {
public: public:
const char* FunctionName() const override {
return "/";
}
static DivOp* Get(); static DivOp* Get();
}; };
class MaxOp : public BinaryOp {
public:
const char* FunctionName() const override {
return "max";
}
static MaxOp* Get();
};
class MinOp : public BinaryOp {
public:
const char* FunctionName() const override {
return "min";
}
static MinOp* Get();
};
#define DEFINE_OP_OVERLOAD(OpChar, OpName) \
inline Expr operator OpChar (Expr lhs, Expr rhs) { \
return (*OpName::Get())(lhs, rhs); \
}
#define DEFINE_BINARY_OP_FUNCTION(FuncName, OpName) \
inline Expr FuncName(Expr lhs, Expr rhs) { \
return (*OpName::Get())(lhs, rhs); \
}
DEFINE_OP_OVERLOAD(+, AddOp);
DEFINE_OP_OVERLOAD(-, SubOp);
DEFINE_OP_OVERLOAD(*, MulOp);
DEFINE_OP_OVERLOAD(/, DivOp);
DEFINE_BINARY_OP_FUNCTION(max, MaxOp);
DEFINE_BINARY_OP_FUNCTION(min, MinOp);
} // namespace tvm } // namespace tvm
#endif // TVM_OP_H_ #endif // TVM_OP_H_
/*!
* Copyright (c) 2016 by Contributors
* \file tvm.h
* \brief Header to include all C++ API.
*/
#ifndef TVM_TVM_H_
#define TVM_TVM_H_
#include "./base.h"
#include "./expr.h"
#include "./op.h"
#include "./tensor.h"
#endif // TVM_TVM_H_
...@@ -5,13 +5,60 @@ ...@@ -5,13 +5,60 @@
#include <tvm/expr.h> #include <tvm/expr.h>
#include <tvm/op.h> #include <tvm/op.h>
#include <tvm/expr_node.h> #include <tvm/expr_node.h>
#include <cctype>
namespace tvm { namespace tvm {
void Expr::Print(std::ostream& os) const {
if (is_null()) {
os << "null"; return;
}
switch (this->node_type()) {
case kVarNode: {
os << Get<VarNode>()->name; return;
}
case kIntNode: {
os << Get<IntNode>()->value; return;
}
case kFloatNode: {
os << Get<FloatNode>()->value; return;
}
case kBinaryOpNode: {
const auto* n = Get<BinaryOpNode>();
const char* fname = n->op->FunctionName();
if (fname[1] == '\0' && !isalpha(fname[0])) {
os << '(';
n->lhs.Print(os);
os << ' ' << fname[0] << ' ';
n->rhs.Print(os);
os << ')';
} else {
os << fname << '(';
n->lhs.Print(os);
os << ", ";
n->rhs.Print(os);
os << ')';
}
return;
}
case kUnaryOpNode: {
const auto* n = Get<UnaryOpNode>();
os << n->op->FunctionName() << '(';
n->src.Print(os);
os << ')';
return;
}
default: {
LOG(FATAL) << "not able to handle type " << typeid(node_.get()).name();
}
}
}
Var::Var(std::string name, DataType dtype) { Var::Var(std::string name, DataType dtype) {
auto node_ = std::make_shared<VarNode>(); auto node = std::make_shared<VarNode>();
node_->name = std::move(name); node->name = std::move(name);
node_->dtype_ = dtype; node->dtype_ = dtype;
node_ = std::move(node);
} }
Expr IntConstant(int64_t value) { Expr IntConstant(int64_t value) {
...@@ -26,10 +73,4 @@ Expr FloatConstant(double value) { ...@@ -26,10 +73,4 @@ Expr FloatConstant(double value) {
return Expr(std::move(nptr)); return Expr(std::move(nptr));
} }
Expr operator+(Expr lhs, Expr rhs) {
auto nptr = std::make_shared<BinaryOpNode>(AddOp::Get(), lhs, rhs);
nptr->Verify();
return Expr(std::move(nptr));
}
} // namespace tvm } // namespace tvm
/*!
* Copyright (c) 2016 by Contributors
* \file expr_util.cc
*/
#include <tvm/expr_util.h>
namespace tvm {
} // namespace tvm
...@@ -3,9 +3,28 @@ ...@@ -3,9 +3,28 @@
* \file op.cc * \file op.cc
*/ */
#include <tvm/op.h> #include <tvm/op.h>
#include <tvm/expr_node.h>
namespace tvm { namespace tvm {
Expr BinaryOp::operator()(Expr lhs, Expr rhs) const {
auto nptr = std::make_shared<BinaryOpNode>(
this, std::move(lhs), std::move(rhs));
nptr->Verify();
return Expr(std::move(nptr));
} }
#define DEFINE_SINGLETON_GET(TypeName) \
TypeName* TypeName::Get() { \
static TypeName inst; \
return &inst; \
}
DEFINE_SINGLETON_GET(AddOp);
DEFINE_SINGLETON_GET(SubOp);
DEFINE_SINGLETON_GET(MulOp);
DEFINE_SINGLETON_GET(DivOp);
DEFINE_SINGLETON_GET(MaxOp);
DEFINE_SINGLETON_GET(MinOp);
} // namespace tvm
unittest
*.d
*_test
#include <dmlc/logging.h>
#include <gtest/gtest.h>
#include <tvm/tvm.h>
TEST(Expr, Basic) {
using namespace tvm;
Var x("x");
auto z = max(x + 1 + 2, 100);
std::ostringstream os;
os << z;
CHECK(os.str() == "max(((x + 1) + 2), 100)");
}
int main(int argc, char ** argv) {
testing::InitGoogleTest(&argc, argv);
testing::FLAGS_gtest_death_test_style = "threadsafe";
return RUN_ALL_TESTS();
}
GTEST_LIB=$(GTEST_PATH)/lib/
GTEST_INC=$(GTEST_PATH)/include/
TEST_SRC = $(wildcard tests/cpp/*_test.cc)
TEST = $(patsubst tests/cpp/%_test.cc, tests/cpp/%_test, $(TEST_SRC))
tests/cpp/%_test: tests/cpp/%_test.cc lib/libtvm.a
$(CXX) -std=c++11 $(CFLAGS) -MM -MT tests/cpp/$* $< >tests/cpp/$*.d
$(CXX) -std=c++11 $(CFLAGS) -I$(GTEST_INC) -o $@ $(filter %.cc %.a, $^) \
-L$(GTEST_LIB) $(LDFLAGS) -lgtest
-include tests/cpp/*.d
Markdown is supported
0% or
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment