Commit 622cee7a by tqchen

Add most of IR constructors

parent a41d644a
Subproject commit 2a1001108b9112c4e594c456ffd364b57db10b6b Subproject commit 872099363b9f16a6cd4a4e8e46b9bd8dd1b861e9
...@@ -8,12 +8,12 @@ ...@@ -8,12 +8,12 @@
#include <dmlc/logging.h> #include <dmlc/logging.h>
#include <dmlc/registry.h> #include <dmlc/registry.h>
#include <tvm/node.h>
#include <string> #include <string>
#include <memory> #include <memory>
#include <functional> #include <functional>
#include <typeinfo> #include <typeinfo>
#include <type_traits> #include <type_traits>
#include <tvm/node.h>
namespace tvm { namespace tvm {
......
...@@ -6,8 +6,8 @@ ...@@ -6,8 +6,8 @@
#ifndef TVM_EXPR_H_ #ifndef TVM_EXPR_H_
#define TVM_EXPR_H_ #define TVM_EXPR_H_
#include <type_traits>
#include <ir/Expr.h> #include <ir/Expr.h>
#include <type_traits>
#include "./base.h" #include "./base.h"
namespace tvm { namespace tvm {
...@@ -15,5 +15,5 @@ namespace tvm { ...@@ -15,5 +15,5 @@ namespace tvm {
using Halide::Type; using Halide::Type;
using Halide::Expr; using Halide::Expr;
} // namespace std } // namespace tvm
#endif // TVM_EXPR_H_ #endif // TVM_EXPR_H_
...@@ -4,3 +4,5 @@ from __future__ import absolute_import as _abs ...@@ -4,3 +4,5 @@ from __future__ import absolute_import as _abs
from .function import * from .function import *
from ._ctypes._api import register_node from ._ctypes._api import register_node
from . import expr from . import expr
from . import stmt
from . import make
...@@ -162,19 +162,23 @@ def _make_function(handle, name): ...@@ -162,19 +162,23 @@ def _make_function(handle, name):
return func return func
def register_node(type_key): def register_node(type_key=None):
"""register node type """register node type
Parameters Parameters
---------- ----------
type_key : str type_key : str or cls
The type key of the node The type key of the node
""" """
def register(cls): if isinstance(type_key, str):
NODE_TYPE[type_key] = cls def register(cls):
NODE_TYPE[type_key] = cls
return cls
return register
else:
cls = type_key
NODE_TYPE[cls.__name__] = cls
return cls return cls
return register
def _init_function_module(root_namespace): def _init_function_module(root_namespace):
"""List and add all the functions to current module.""" """List and add all the functions to current module."""
...@@ -189,11 +193,21 @@ def _init_function_module(root_namespace): ...@@ -189,11 +193,21 @@ def _init_function_module(root_namespace):
module_obj = sys.modules["%s.function" % root_namespace] module_obj = sys.modules["%s.function" % root_namespace]
module_internal = sys.modules["%s._function_internal" % root_namespace] module_internal = sys.modules["%s._function_internal" % root_namespace]
module_make = sys.modules["%s.make" % root_namespace]
for name in op_names: for name in op_names:
hdl = FunctionHandle() hdl = FunctionHandle()
check_call(_LIB.TVMGetFunctionHandle(c_str(name), ctypes.byref(hdl))) check_call(_LIB.TVMGetFunctionHandle(c_str(name), ctypes.byref(hdl)))
function = _make_function(hdl, name) if name.startswith("_make_"):
if function.__name__.startswith('_'): fname = name[6:]
else:
fname = name
function = _make_function(hdl, fname)
if name.startswith("_make_"):
setattr(module_make, function.__name__, function)
elif function.__name__.startswith('_'):
setattr(module_internal, function.__name__, function) setattr(module_internal, function.__name__, function)
else: else:
setattr(module_obj, function.__name__, function) setattr(module_obj, function.__name__, function)
from __future__ import absolute_import as _abs from __future__ import absolute_import as _abs
from ._ctypes._api import NodeBase, register_node from ._ctypes._api import NodeBase, register_node
from . import function as _func from . import function as _func
from . import make as _make
class Expr(NodeBase): class Expr(NodeBase):
def __repr__(self): def __repr__(self):
return _func.format_str(self) return _func.format_str(self)
def __add__(self, other): def __add__(self, other):
return binary_op('+', self, other) return _make.Add(self, other)
def __radd__(self, other): def __radd__(self, other):
return self.__add__(other) return self.__add__(other)
def __sub__(self, other): def __sub__(self, other):
return binary_op('-', self, other) return _make.Sub(self, other)
def __rsub__(self, other): def __rsub__(self, other):
return binary_op('-', other, self) return _make.Sub(other, self)
def __mul__(self, other): def __mul__(self, other):
return binary_op('*', self, other) return _make.Mul(self, other)
def __rmul__(self, other): def __rmul__(self, other):
return binary_op('*', other, self) return _make.Mul(other, self)
def __div__(self, other): def __div__(self, other):
return binary_op('/', self, other) return _make.Div(self, other)
def __rdiv__(self, other): def __rdiv__(self, other):
return binary_op('/', other, self) return _make.Div(other, self)
def __truediv__(self, other): def __truediv__(self, other):
return self.__div__(other) return self.__div__(other)
...@@ -39,15 +40,126 @@ class Expr(NodeBase): ...@@ -39,15 +40,126 @@ class Expr(NodeBase):
def __neg__(self): def __neg__(self):
return self.__mul__(-1) return self.__mul__(-1)
class ConstExpr(Expr):
pass
class BinaryOpExpr(Expr):
pass
class CmpExpr(Expr):
pass
class LogicalExpr(Expr):
pass
@register_node
class FloatImm(ConstExpr):
pass
@register_node
class IntImm(ConstExpr):
pass
@register_node
class UIntImm(ConstExpr):
pass
@register_node
class StringImm(ConstExpr):
pass
@register_node
class Cast(Expr):
pass
@register_node
class Variable(Expr):
pass
@register_node
class Add(BinaryOpExpr):
pass
@register_node
class Sub(BinaryOpExpr):
pass
@register_node
class Mul(BinaryOpExpr):
pass
@register_node
class Div(BinaryOpExpr):
pass
@register_node
class Mod(BinaryOpExpr):
pass
@register_node
class Min(BinaryOpExpr):
pass
@register_node
class Max(BinaryOpExpr):
pass
@register_node
class EQ(CmpExpr):
pass
@register_node
class NE(CmpExpr):
pass
@register_node
class LT(CmpExpr):
pass
@register_node
class LE(CmpExpr):
pass
@register_node
class GT(CmpExpr):
pass
@register_node
class GE(CmpExpr):
pass
@register_node
class And(LogicalExpr):
pass
@register_node
class Or(LogicalExpr):
pass
@register_node
class Not(LogicalExpr):
pass
@register_node
class Select(Expr):
pass
@register_node
class Load(Expr):
pass
@register_node
class Ramp(Expr):
pass
@register_node("IntImm") @register_node
class IntImm(Expr): class Broadcast(Expr):
pass pass
@register_node("UIntImm") @register_node
class UIntImm(Expr): class Call(Expr):
pass pass
@register_node("FloatImm") @register_node
class FloatImm(Expr): class Let(Expr):
pass pass
...@@ -2,6 +2,7 @@ from __future__ import absolute_import as _abs ...@@ -2,6 +2,7 @@ from __future__ import absolute_import as _abs
from numbers import Number as _Number, Integral as _Integral from numbers import Number as _Number, Integral as _Integral
from ._ctypes._api import _init_function_module from ._ctypes._api import _init_function_module
from .import _function_internal from .import _function_internal
from .import make as _make
int32 = "int32" int32 = "int32"
float32 = "float32" float32 = "float32"
......
"""namespace of IR node builder make function"""
from __future__ import absolute_import as _abs
from ._ctypes._api import NodeBase, register_node
from . import function as _func
from . import make as _make
class Stmt(NodeBase):
def __repr__(self):
return _func.format_str(self)
@register_node
class LetStmt(Stmt):
pass
@register_node
class AssertStmt(Stmt):
pass
@register_node
class ProducerConsumer(Stmt):
pass
@register_node
class For(Stmt):
pass
@register_node
class Store(Stmt):
pass
@register_node
class Provide(Stmt):
pass
@register_node
class Allocate(Stmt):
pass
@register_node
class Free(Stmt):
pass
@register_node
class Realize(Stmt):
pass
@register_node
class Block(Stmt):
pass
@register_node
class IfThenElse(Stmt):
pass
@register_node
class Evaluate(Stmt):
pass
...@@ -12,19 +12,20 @@ ...@@ -12,19 +12,20 @@
#include <tvm/c_api.h> #include <tvm/c_api.h>
#include <vector> #include <vector>
#include <string> #include <string>
#include <exception>
#include "./c_api_registry.h" #include "./c_api_registry.h"
/*! \brief macro to guard beginning and end section of all functions */ /*! \brief macro to guard beginning and end section of all functions */
#define API_BEGIN() try { #define API_BEGIN() try {
/*! \brief every function starts with API_BEGIN(); /*! \brief every function starts with API_BEGIN();
and finishes with API_END() or API_END_HANDLE_ERROR */ and finishes with API_END() or API_END_HANDLE_ERROR */
#define API_END() } catch(dmlc::Error &_except_) { return TVMAPIHandleException(_except_); } return 0; // NOLINT(*) #define API_END() } catch(std::runtime_error &_except_) { return TVMAPIHandleException(_except_); } return 0; // NOLINT(*)
/*! /*!
* \brief every function starts with API_BEGIN(); * \brief every function starts with API_BEGIN();
* and finishes with API_END() or API_END_HANDLE_ERROR * and finishes with API_END() or API_END_HANDLE_ERROR
* The finally clause contains procedure to cleanup states when an error happens. * The finally clause contains procedure to cleanup states when an error happens.
*/ */
#define API_END_HANDLE_ERROR(Finalize) } catch(dmlc::Error &_except_) { Finalize; return TVMAPIHandleException(_except_); } return 0; // NOLINT(*) #define API_END_HANDLE_ERROR(Finalize) } catch(std::runtime_error &_except_) { Finalize; return TVMAPIHandleException(_except_); } return 0; // NOLINT(*)
void TVMAPISetLastError(const char* msg); void TVMAPISetLastError(const char* msg);
...@@ -33,7 +34,7 @@ void TVMAPISetLastError(const char* msg); ...@@ -33,7 +34,7 @@ void TVMAPISetLastError(const char* msg);
* \param e the exception * \param e the exception
* \return the return value of API after exception is handled * \return the return value of API after exception is handled
*/ */
inline int TVMAPIHandleException(const dmlc::Error &e) { inline int TVMAPIHandleException(const std::runtime_error &e) {
TVMAPISetLastError(e.what()); TVMAPISetLastError(e.what());
return -1; return -1;
} }
......
/*!
* Copyright (c) 2016 by Contributors
* Implementation of API functions related to IR build
* \file c_api_ir.cc
*/
#include <tvm/expr.h>
#include <ir/IROperator.h>
#include "./c_api_registry.h"
namespace tvm {
using namespace Halide::Internal;
using ArgStack = const std::vector<APIVariantValue>;
using RetValue = APIVariantValue;
// make from two arguments
#define REGISTER_MAKE1(Node) \
TVM_REGISTER_API(_make_## Node) \
.set_body([](const ArgStack& args, RetValue *ret) { \
*ret = Node::make(args.at(0)); \
}) \
#define REGISTER_MAKE2(Node) \
TVM_REGISTER_API(_make_## Node) \
.set_body([](const ArgStack& args, RetValue *ret) { \
*ret = Node::make(args.at(0), args.at(1)); \
}) \
#define REGISTER_MAKE3(Node) \
TVM_REGISTER_API(_make_## Node) \
.set_body([](const ArgStack& args, RetValue *ret) { \
*ret = Node::make(args.at(0), args.at(1), args.at(2)); \
}) \
#define REGISTER_MAKE_BINARY_OP(Node) \
TVM_REGISTER_API(_make_## Node) \
.set_body([](const ArgStack& args, RetValue *ret) { \
Expr a = args.at(0), b = args.at(1); \
match_types(a, b); \
*ret = Node::make(a, b); \
}) \
.add_argument("lhs", "Expr", "left operand") \
.add_argument("rhs", "Expr", "right operand")
REGISTER_MAKE2(IntImm);
REGISTER_MAKE2(UIntImm);
REGISTER_MAKE2(FloatImm);
REGISTER_MAKE1(StringImm);
REGISTER_MAKE_BINARY_OP(Add);
REGISTER_MAKE_BINARY_OP(Sub);
REGISTER_MAKE_BINARY_OP(Mul);
REGISTER_MAKE_BINARY_OP(Div);
REGISTER_MAKE_BINARY_OP(Mod);
REGISTER_MAKE_BINARY_OP(Min);
REGISTER_MAKE_BINARY_OP(Max);
REGISTER_MAKE_BINARY_OP(EQ);
REGISTER_MAKE_BINARY_OP(NE);
REGISTER_MAKE_BINARY_OP(LT);
REGISTER_MAKE_BINARY_OP(LE);
REGISTER_MAKE_BINARY_OP(GT);
REGISTER_MAKE_BINARY_OP(GE);
REGISTER_MAKE_BINARY_OP(And);
REGISTER_MAKE_BINARY_OP(Or);
REGISTER_MAKE1(Not);
REGISTER_MAKE3(Select);
REGISTER_MAKE3(Ramp);
REGISTER_MAKE2(Broadcast);
REGISTER_MAKE3(Let);
REGISTER_MAKE3(LetStmt);
REGISTER_MAKE2(AssertStmt);
REGISTER_MAKE3(ProducerConsumer);
// TODO(tqchen) For;
REGISTER_MAKE3(Store);
// TODO(tqchen) Provide;
// TODO(tqchen) Allocate;
REGISTER_MAKE1(Free);
// TODO(tqchen) Realize;
REGISTER_MAKE2(Block);
REGISTER_MAKE3(IfThenElse);
REGISTER_MAKE1(Evaluate);
} // namespace tvm
...@@ -24,7 +24,7 @@ inline std::string Type2String(const Type& t) { ...@@ -24,7 +24,7 @@ inline std::string Type2String(const Type& t) {
inline Type String2Type(std::string s) { inline Type String2Type(std::string s) {
std::istringstream is(s); std::istringstream is(s);
halide_type_code_t code; halide_type_code_t code = Type::Int;
if (s.substr(0, 3) == "int") { if (s.substr(0, 3) == "int") {
code = Type::Int; s = s.substr(3); code = Type::Int; s = s.substr(3);
} else if (s.substr(0, 4) == "uint") { } else if (s.substr(0, 4) == "uint") {
...@@ -36,7 +36,7 @@ inline Type String2Type(std::string s) { ...@@ -36,7 +36,7 @@ inline Type String2Type(std::string s) {
} else { } else {
LOG(FATAL) << "unknown type " << s; LOG(FATAL) << "unknown type " << s;
} }
int bits, lanes = 1; int bits = 32, lanes = 1;
if (sscanf(s.c_str(), "%dx%d", &bits, &lanes) == 0) { if (sscanf(s.c_str(), "%dx%d", &bits, &lanes) == 0) {
LOG(FATAL) << "unknown type " << s; LOG(FATAL) << "unknown type " << s;
} }
...@@ -109,12 +109,20 @@ struct APIVariantValue { ...@@ -109,12 +109,20 @@ struct APIVariantValue {
CHECK_EQ(type_id, kLong); CHECK_EQ(type_id, kLong);
return v_union.v_long; return v_union.v_long;
} }
inline operator uint64_t() const {
CHECK_EQ(type_id, kLong);
return v_union.v_long;
}
inline operator int() const { inline operator int() const {
CHECK_EQ(type_id, kLong); CHECK_EQ(type_id, kLong);
CHECK_LE(v_union.v_long, CHECK_LE(v_union.v_long,
std::numeric_limits<int>::max()); std::numeric_limits<int>::max());
return v_union.v_long; return v_union.v_long;
} }
inline operator bool() const {
CHECK_EQ(type_id, kLong);
return v_union.v_long != 0;
}
inline operator std::string() const { inline operator std::string() const {
CHECK_EQ(type_id, kStr); CHECK_EQ(type_id, kStr);
return str; return str;
......
...@@ -2,8 +2,24 @@ import tvm ...@@ -2,8 +2,24 @@ import tvm
def test_const(): def test_const():
x = tvm.const(1) x = tvm.const(1)
assert x.type == 'int32' assert x.dtype == 'int32'
assert isinstance(x, tvm.expr.IntImm) assert isinstance(x, tvm.expr.IntImm)
def test_make():
x = tvm.const(1)
y = tvm.make.IntImm('int32', 1)
z = x + y
print tvm.format_str(z)
def test_ir():
x = tvm.const(1)
y = tvm.make.IntImm('int32', 1)
z = x + y
stmt = tvm.make.Evaluate(z)
assert isinstance(stmt, tvm.stmt.Evaluate)
print tvm.format_str(stmt)
if __name__ == "__main__": if __name__ == "__main__":
test_const() test_const()
test_make()
test_ir()
Markdown is supported
0% or
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment