Commit 8278e02f by tqchen

checkin basic cpp test

parent 05e871d4
......@@ -3,13 +3,17 @@ export CFLAGS = -std=c++11 -Wall -O2 -Wno-unknown-pragmas -funroll-loops\
-Iinclude -Idmlc-core/include -fPIC
# specify tensor path
.PHONY: clean all
.PHONY: clean all test
all: lib/libtvm.a
SRC = $(wildcard src/*.cc src/*/*.cc)
ALL_OBJ = $(patsubst src/%.cc, build/%.o, $(SRC))
ALL_DEP = $(ALL_OBJ)
include tests/cpp/unittest.mk
test: $(TEST)
build/%.o: src/%.cc
@mkdir -p $(@D)
$(CXX) $(CFLAGS) -MM -MT build/$*.o $< >build/$*.d
......@@ -24,7 +28,7 @@ lint:
python2 dmlc-core/scripts/lint.py tvm cpp include src
clean:
$(RM) -rf build lib bin *~ */*~ */*/*~ */*/*/*~ */*.o */*/*.o */*/*/*.o
$(RM) -rf build lib bin *~ */*~ */*/*~ */*/*/*~ */*.o */*/*.o */*/*/*.o */*.d */*/*.d */*/*/*.d
-include build/*.d
-include build/*/*.d
......@@ -76,7 +76,7 @@ using FNodeRefVisit = std::function<void(const char* key, NodeRef* ref)>;
class Node {
public:
/*! \brief virtual destructor */
virtual ~Node();
virtual ~Node() {}
/*! \return The unique type key of the node */
virtual const char* type_key() const = 0;
/*! \brief verify the correctness of node struct after it get mutated by visitor */
......@@ -101,8 +101,6 @@ class Node {
*/
template<typename TNode>
inline bool is_type() const;
/*! \return the node type */
inline NodeType node_type() const;
protected:
// node ref can see this
......@@ -120,12 +118,15 @@ class NodeRef {
*/
template<typename TNode>
inline const TNode* Get() const;
/*! \return the node type */
inline NodeType node_type() const;
/*! \return wheyjer the expression is null */
inline bool is_null() const;
protected:
NodeRef() = default;
explicit NodeRef(std::shared_ptr<Node> node) : node_(node) {}
explicit NodeRef(std::shared_ptr<Node>&& node) : node_(std::move(node)) {}
protected:
/*! \brief the internal node */
std::shared_ptr<Node> node_;
};
......@@ -146,8 +147,8 @@ struct NodeFactoryReg
.set_body([]() { return std::make_shared<TypeName>(); })
// implementations of inline functions after this
inline NodeType Node::node_type() const {
return node_type_;
inline NodeType NodeRef::node_type() const {
return node_->node_type_;
}
template<typename TNode>
......
......@@ -34,12 +34,12 @@ class Expr : public NodeRef {
* \brief copy constructor
* \param other the input
*/
Expr(const Expr& other) = default; // NOLINT(*)
Expr(const Expr& other) = default;
/*!
* \brief move constructor
* \param other the input
*/
Expr(Expr&& other) = default; // NOLINT(*)
Expr(Expr&& other) = default;
/*!
* \brief assign operator.
* \param other the input.
......@@ -66,9 +66,20 @@ class Expr : public NodeRef {
* \brief constructor from node pointer
* \param nptr Another node shared pointer
*/
explicit Expr(std::shared_ptr<Node> nptr) : NodeRef(nptr) {}
explicit Expr(std::shared_ptr<Node>&& nptr) : NodeRef(std::move(nptr)) {
CHECK(node_.get() != nullptr);
}
/*! \return the expression type of the expression */
inline DataType dtype() const;
// print the expression.
friend std::ostream& operator<<(std::ostream &os, const Expr& e) { // NOLINT(*)
e.Print(os);
return os;
}
private:
// print the expression
void Print(std::ostream& os) const; // NOLINT(*)
};
/*! \brief Variable class */
......@@ -77,10 +88,8 @@ class Var : public Expr {
Var(std::string name="", DataType dtype=kInt32); // NOLINT(*)
};
/*! \brief */
Expr IntConstant(int64_t value);
Expr FloatConstant(int64_t value);
Expr operator+(Expr lhs, Expr rhs);
/*! \brief base of expression node */
class ExprNode : public Node {
......
......@@ -77,6 +77,11 @@ class UnaryOpNode : public ExprNode {
UnaryOpNode() {
node_type_ = kUnaryOpNode;
}
UnaryOpNode(const UnaryOp* op, Expr && src)
: op(op), src(std::move(src)) {
node_type_ = kUnaryOpNode;
dtype_ = this->src.dtype();
}
const char* type_key() const override {
return "UnaryOpNode";
}
......@@ -104,10 +109,10 @@ struct BinaryOpNode : public ExprNode {
BinaryOpNode() {
node_type_ = kBinaryOpNode;
}
BinaryOpNode(const BinaryOp* op, Expr lhs, Expr rhs)
: op(op), lhs(lhs), rhs(rhs) {
BinaryOpNode(const BinaryOp* op, Expr && lhs, Expr && rhs)
: op(op), lhs(std::move(lhs)), rhs(std::move(rhs)) {
node_type_ = kBinaryOpNode;
dtype_ = lhs.dtype();
dtype_ = this->lhs.dtype();
}
const char* type_key() const override {
return "BinaryOpNode";
......
/*!
* Copyright (c) 2016 by Contributors
* \file expr_util.h
* \brief Expression util
*/
#ifndef TVM_EXPR_UTIL_H_
#define TVM_EXPR_UTIL_H_
#include "./expr.h"
namespace tvm {
} // namespace tvm
#endif // TVM_EXPR_UTIL_H_
......@@ -11,35 +11,106 @@
namespace tvm {
/*! \brief binary operator */
class BinaryOp {
public:
virtual std::string Format(const std::string& lhs, const std::string& rhs);
/*! \return the function name to be called in binary op */
virtual const char* FunctionName() const = 0;
/*!
* \brief apply the binary op
* \param lhs left operand
* \param rhs right operand
* \return the result expr
*/
Expr operator()(Expr lhs, Expr rhs) const;
};
/*! \brief unary operator */
class UnaryOp {
public:
/*! \return the function name to be called in unary op */
virtual const char* FunctionName() const = 0;
/*!
* \brief apply the unary op
* \param src left operand
* \return the result expr
*/
Expr operator()(Expr lhs, Expr rhs) const;
};
class AddOp : public BinaryOp {
public:
const char* FunctionName() const override {
return "+";
}
static AddOp* Get();
};
class SubOp : public BinaryOp {
public:
const char* FunctionName() const override {
return "-";
}
static SubOp* Get();
};
class MulOp : public BinaryOp {
public:
static SubOp* Get();
const char* FunctionName() const override {
return "*";
}
static MulOp* Get();
};
class DivOp : public BinaryOp {
public:
const char* FunctionName() const override {
return "/";
}
static DivOp* Get();
};
class MaxOp : public BinaryOp {
public:
const char* FunctionName() const override {
return "max";
}
static MaxOp* Get();
};
class MinOp : public BinaryOp {
public:
const char* FunctionName() const override {
return "min";
}
static MinOp* Get();
};
#define DEFINE_OP_OVERLOAD(OpChar, OpName) \
inline Expr operator OpChar (Expr lhs, Expr rhs) { \
return (*OpName::Get())(lhs, rhs); \
}
#define DEFINE_BINARY_OP_FUNCTION(FuncName, OpName) \
inline Expr FuncName(Expr lhs, Expr rhs) { \
return (*OpName::Get())(lhs, rhs); \
}
DEFINE_OP_OVERLOAD(+, AddOp);
DEFINE_OP_OVERLOAD(-, SubOp);
DEFINE_OP_OVERLOAD(*, MulOp);
DEFINE_OP_OVERLOAD(/, DivOp);
DEFINE_BINARY_OP_FUNCTION(max, MaxOp);
DEFINE_BINARY_OP_FUNCTION(min, MinOp);
} // namespace tvm
#endif // TVM_OP_H_
/*!
* Copyright (c) 2016 by Contributors
* \file tvm.h
* \brief Header to include all C++ API.
*/
#ifndef TVM_TVM_H_
#define TVM_TVM_H_
#include "./base.h"
#include "./expr.h"
#include "./op.h"
#include "./tensor.h"
#endif // TVM_TVM_H_
......@@ -5,13 +5,60 @@
#include <tvm/expr.h>
#include <tvm/op.h>
#include <tvm/expr_node.h>
#include <cctype>
namespace tvm {
void Expr::Print(std::ostream& os) const {
if (is_null()) {
os << "null"; return;
}
switch (this->node_type()) {
case kVarNode: {
os << Get<VarNode>()->name; return;
}
case kIntNode: {
os << Get<IntNode>()->value; return;
}
case kFloatNode: {
os << Get<FloatNode>()->value; return;
}
case kBinaryOpNode: {
const auto* n = Get<BinaryOpNode>();
const char* fname = n->op->FunctionName();
if (fname[1] == '\0' && !isalpha(fname[0])) {
os << '(';
n->lhs.Print(os);
os << ' ' << fname[0] << ' ';
n->rhs.Print(os);
os << ')';
} else {
os << fname << '(';
n->lhs.Print(os);
os << ", ";
n->rhs.Print(os);
os << ')';
}
return;
}
case kUnaryOpNode: {
const auto* n = Get<UnaryOpNode>();
os << n->op->FunctionName() << '(';
n->src.Print(os);
os << ')';
return;
}
default: {
LOG(FATAL) << "not able to handle type " << typeid(node_.get()).name();
}
}
}
Var::Var(std::string name, DataType dtype) {
auto node_ = std::make_shared<VarNode>();
node_->name = std::move(name);
node_->dtype_ = dtype;
auto node = std::make_shared<VarNode>();
node->name = std::move(name);
node->dtype_ = dtype;
node_ = std::move(node);
}
Expr IntConstant(int64_t value) {
......@@ -26,10 +73,4 @@ Expr FloatConstant(double value) {
return Expr(std::move(nptr));
}
Expr operator+(Expr lhs, Expr rhs) {
auto nptr = std::make_shared<BinaryOpNode>(AddOp::Get(), lhs, rhs);
nptr->Verify();
return Expr(std::move(nptr));
}
} // namespace tvm
/*!
* Copyright (c) 2016 by Contributors
* \file expr_util.cc
*/
#include <tvm/expr_util.h>
namespace tvm {
} // namespace tvm
......@@ -3,9 +3,28 @@
* \file op.cc
*/
#include <tvm/op.h>
#include <tvm/expr_node.h>
namespace tvm {
Expr BinaryOp::operator()(Expr lhs, Expr rhs) const {
auto nptr = std::make_shared<BinaryOpNode>(
this, std::move(lhs), std::move(rhs));
nptr->Verify();
return Expr(std::move(nptr));
}
#define DEFINE_SINGLETON_GET(TypeName) \
TypeName* TypeName::Get() { \
static TypeName inst; \
return &inst; \
}
DEFINE_SINGLETON_GET(AddOp);
DEFINE_SINGLETON_GET(SubOp);
DEFINE_SINGLETON_GET(MulOp);
DEFINE_SINGLETON_GET(DivOp);
DEFINE_SINGLETON_GET(MaxOp);
DEFINE_SINGLETON_GET(MinOp);
} // namespace tvm
unittest
*.d
*_test
#include <dmlc/logging.h>
#include <gtest/gtest.h>
#include <tvm/tvm.h>
TEST(Expr, Basic) {
using namespace tvm;
Var x("x");
auto z = max(x + 1 + 2, 100);
std::ostringstream os;
os << z;
CHECK(os.str() == "max(((x + 1) + 2), 100)");
}
int main(int argc, char ** argv) {
testing::InitGoogleTest(&argc, argv);
testing::FLAGS_gtest_death_test_style = "threadsafe";
return RUN_ALL_TESTS();
}
GTEST_LIB=$(GTEST_PATH)/lib/
GTEST_INC=$(GTEST_PATH)/include/
TEST_SRC = $(wildcard tests/cpp/*_test.cc)
TEST = $(patsubst tests/cpp/%_test.cc, tests/cpp/%_test, $(TEST_SRC))
tests/cpp/%_test: tests/cpp/%_test.cc lib/libtvm.a
$(CXX) -std=c++11 $(CFLAGS) -MM -MT tests/cpp/$* $< >tests/cpp/$*.d
$(CXX) -std=c++11 $(CFLAGS) -I$(GTEST_INC) -o $@ $(filter %.cc %.a, $^) \
-L$(GTEST_LIB) $(LDFLAGS) -lgtest
-include tests/cpp/*.d
Markdown is supported
0% or
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment