Commit 622cee7a by tqchen

Add most of IR constructors

parent a41d644a
Subproject commit 2a1001108b9112c4e594c456ffd364b57db10b6b
Subproject commit 872099363b9f16a6cd4a4e8e46b9bd8dd1b861e9
......@@ -8,12 +8,12 @@
#include <dmlc/logging.h>
#include <dmlc/registry.h>
#include <tvm/node.h>
#include <string>
#include <memory>
#include <functional>
#include <typeinfo>
#include <type_traits>
#include <tvm/node.h>
namespace tvm {
......
......@@ -6,8 +6,8 @@
#ifndef TVM_EXPR_H_
#define TVM_EXPR_H_
#include <type_traits>
#include <ir/Expr.h>
#include <type_traits>
#include "./base.h"
namespace tvm {
......@@ -15,5 +15,5 @@ namespace tvm {
using Halide::Type;
using Halide::Expr;
} // namespace std
} // namespace tvm
#endif // TVM_EXPR_H_
......@@ -4,3 +4,5 @@ from __future__ import absolute_import as _abs
from .function import *
from ._ctypes._api import register_node
from . import expr
from . import stmt
from . import make
......@@ -162,19 +162,23 @@ def _make_function(handle, name):
return func
def register_node(type_key):
def register_node(type_key=None):
"""register node type
Parameters
----------
type_key : str
type_key : str or cls
The type key of the node
"""
if isinstance(type_key, str):
def register(cls):
NODE_TYPE[type_key] = cls
return cls
return register
else:
cls = type_key
NODE_TYPE[cls.__name__] = cls
return cls
def _init_function_module(root_namespace):
"""List and add all the functions to current module."""
......@@ -189,11 +193,21 @@ def _init_function_module(root_namespace):
module_obj = sys.modules["%s.function" % root_namespace]
module_internal = sys.modules["%s._function_internal" % root_namespace]
module_make = sys.modules["%s.make" % root_namespace]
for name in op_names:
hdl = FunctionHandle()
check_call(_LIB.TVMGetFunctionHandle(c_str(name), ctypes.byref(hdl)))
function = _make_function(hdl, name)
if function.__name__.startswith('_'):
if name.startswith("_make_"):
fname = name[6:]
else:
fname = name
function = _make_function(hdl, fname)
if name.startswith("_make_"):
setattr(module_make, function.__name__, function)
elif function.__name__.startswith('_'):
setattr(module_internal, function.__name__, function)
else:
setattr(module_obj, function.__name__, function)
from __future__ import absolute_import as _abs
from ._ctypes._api import NodeBase, register_node
from . import function as _func
from . import make as _make
class Expr(NodeBase):
def __repr__(self):
return _func.format_str(self)
def __add__(self, other):
return binary_op('+', self, other)
return _make.Add(self, other)
def __radd__(self, other):
return self.__add__(other)
def __sub__(self, other):
return binary_op('-', self, other)
return _make.Sub(self, other)
def __rsub__(self, other):
return binary_op('-', other, self)
return _make.Sub(other, self)
def __mul__(self, other):
return binary_op('*', self, other)
return _make.Mul(self, other)
def __rmul__(self, other):
return binary_op('*', other, self)
return _make.Mul(other, self)
def __div__(self, other):
return binary_op('/', self, other)
return _make.Div(self, other)
def __rdiv__(self, other):
return binary_op('/', other, self)
return _make.Div(other, self)
def __truediv__(self, other):
return self.__div__(other)
......@@ -39,15 +40,126 @@ class Expr(NodeBase):
def __neg__(self):
return self.__mul__(-1)
class ConstExpr(Expr):
pass
class BinaryOpExpr(Expr):
pass
class CmpExpr(Expr):
pass
class LogicalExpr(Expr):
pass
@register_node
class FloatImm(ConstExpr):
pass
@register_node
class IntImm(ConstExpr):
pass
@register_node
class UIntImm(ConstExpr):
pass
@register_node
class StringImm(ConstExpr):
pass
@register_node
class Cast(Expr):
pass
@register_node
class Variable(Expr):
pass
@register_node
class Add(BinaryOpExpr):
pass
@register_node
class Sub(BinaryOpExpr):
pass
@register_node
class Mul(BinaryOpExpr):
pass
@register_node
class Div(BinaryOpExpr):
pass
@register_node
class Mod(BinaryOpExpr):
pass
@register_node
class Min(BinaryOpExpr):
pass
@register_node
class Max(BinaryOpExpr):
pass
@register_node
class EQ(CmpExpr):
pass
@register_node
class NE(CmpExpr):
pass
@register_node
class LT(CmpExpr):
pass
@register_node
class LE(CmpExpr):
pass
@register_node
class GT(CmpExpr):
pass
@register_node
class GE(CmpExpr):
pass
@register_node
class And(LogicalExpr):
pass
@register_node
class Or(LogicalExpr):
pass
@register_node
class Not(LogicalExpr):
pass
@register_node
class Select(Expr):
pass
@register_node
class Load(Expr):
pass
@register_node
class Ramp(Expr):
pass
@register_node("IntImm")
class IntImm(Expr):
@register_node
class Broadcast(Expr):
pass
@register_node("UIntImm")
class UIntImm(Expr):
@register_node
class Call(Expr):
pass
@register_node("FloatImm")
class FloatImm(Expr):
@register_node
class Let(Expr):
pass
......@@ -2,6 +2,7 @@ from __future__ import absolute_import as _abs
from numbers import Number as _Number, Integral as _Integral
from ._ctypes._api import _init_function_module
from .import _function_internal
from .import make as _make
int32 = "int32"
float32 = "float32"
......
"""namespace of IR node builder make function"""
from __future__ import absolute_import as _abs
from ._ctypes._api import NodeBase, register_node
from . import function as _func
from . import make as _make
class Stmt(NodeBase):
def __repr__(self):
return _func.format_str(self)
@register_node
class LetStmt(Stmt):
pass
@register_node
class AssertStmt(Stmt):
pass
@register_node
class ProducerConsumer(Stmt):
pass
@register_node
class For(Stmt):
pass
@register_node
class Store(Stmt):
pass
@register_node
class Provide(Stmt):
pass
@register_node
class Allocate(Stmt):
pass
@register_node
class Free(Stmt):
pass
@register_node
class Realize(Stmt):
pass
@register_node
class Block(Stmt):
pass
@register_node
class IfThenElse(Stmt):
pass
@register_node
class Evaluate(Stmt):
pass
......@@ -12,19 +12,20 @@
#include <tvm/c_api.h>
#include <vector>
#include <string>
#include <exception>
#include "./c_api_registry.h"
/*! \brief macro to guard beginning and end section of all functions */
#define API_BEGIN() try {
/*! \brief every function starts with API_BEGIN();
and finishes with API_END() or API_END_HANDLE_ERROR */
#define API_END() } catch(dmlc::Error &_except_) { return TVMAPIHandleException(_except_); } return 0; // NOLINT(*)
#define API_END() } catch(std::runtime_error &_except_) { return TVMAPIHandleException(_except_); } return 0; // NOLINT(*)
/*!
* \brief every function starts with API_BEGIN();
* and finishes with API_END() or API_END_HANDLE_ERROR
* The finally clause contains procedure to cleanup states when an error happens.
*/
#define API_END_HANDLE_ERROR(Finalize) } catch(dmlc::Error &_except_) { Finalize; return TVMAPIHandleException(_except_); } return 0; // NOLINT(*)
#define API_END_HANDLE_ERROR(Finalize) } catch(std::runtime_error &_except_) { Finalize; return TVMAPIHandleException(_except_); } return 0; // NOLINT(*)
void TVMAPISetLastError(const char* msg);
......@@ -33,7 +34,7 @@ void TVMAPISetLastError(const char* msg);
* \param e the exception
* \return the return value of API after exception is handled
*/
inline int TVMAPIHandleException(const dmlc::Error &e) {
inline int TVMAPIHandleException(const std::runtime_error &e) {
TVMAPISetLastError(e.what());
return -1;
}
......
/*!
* Copyright (c) 2016 by Contributors
* Implementation of API functions related to IR build
* \file c_api_ir.cc
*/
#include <tvm/expr.h>
#include <ir/IROperator.h>
#include "./c_api_registry.h"
namespace tvm {
using namespace Halide::Internal;
using ArgStack = const std::vector<APIVariantValue>;
using RetValue = APIVariantValue;
// make from two arguments
#define REGISTER_MAKE1(Node) \
TVM_REGISTER_API(_make_## Node) \
.set_body([](const ArgStack& args, RetValue *ret) { \
*ret = Node::make(args.at(0)); \
}) \
#define REGISTER_MAKE2(Node) \
TVM_REGISTER_API(_make_## Node) \
.set_body([](const ArgStack& args, RetValue *ret) { \
*ret = Node::make(args.at(0), args.at(1)); \
}) \
#define REGISTER_MAKE3(Node) \
TVM_REGISTER_API(_make_## Node) \
.set_body([](const ArgStack& args, RetValue *ret) { \
*ret = Node::make(args.at(0), args.at(1), args.at(2)); \
}) \
#define REGISTER_MAKE_BINARY_OP(Node) \
TVM_REGISTER_API(_make_## Node) \
.set_body([](const ArgStack& args, RetValue *ret) { \
Expr a = args.at(0), b = args.at(1); \
match_types(a, b); \
*ret = Node::make(a, b); \
}) \
.add_argument("lhs", "Expr", "left operand") \
.add_argument("rhs", "Expr", "right operand")
REGISTER_MAKE2(IntImm);
REGISTER_MAKE2(UIntImm);
REGISTER_MAKE2(FloatImm);
REGISTER_MAKE1(StringImm);
REGISTER_MAKE_BINARY_OP(Add);
REGISTER_MAKE_BINARY_OP(Sub);
REGISTER_MAKE_BINARY_OP(Mul);
REGISTER_MAKE_BINARY_OP(Div);
REGISTER_MAKE_BINARY_OP(Mod);
REGISTER_MAKE_BINARY_OP(Min);
REGISTER_MAKE_BINARY_OP(Max);
REGISTER_MAKE_BINARY_OP(EQ);
REGISTER_MAKE_BINARY_OP(NE);
REGISTER_MAKE_BINARY_OP(LT);
REGISTER_MAKE_BINARY_OP(LE);
REGISTER_MAKE_BINARY_OP(GT);
REGISTER_MAKE_BINARY_OP(GE);
REGISTER_MAKE_BINARY_OP(And);
REGISTER_MAKE_BINARY_OP(Or);
REGISTER_MAKE1(Not);
REGISTER_MAKE3(Select);
REGISTER_MAKE3(Ramp);
REGISTER_MAKE2(Broadcast);
REGISTER_MAKE3(Let);
REGISTER_MAKE3(LetStmt);
REGISTER_MAKE2(AssertStmt);
REGISTER_MAKE3(ProducerConsumer);
// TODO(tqchen) For;
REGISTER_MAKE3(Store);
// TODO(tqchen) Provide;
// TODO(tqchen) Allocate;
REGISTER_MAKE1(Free);
// TODO(tqchen) Realize;
REGISTER_MAKE2(Block);
REGISTER_MAKE3(IfThenElse);
REGISTER_MAKE1(Evaluate);
} // namespace tvm
......@@ -24,7 +24,7 @@ inline std::string Type2String(const Type& t) {
inline Type String2Type(std::string s) {
std::istringstream is(s);
halide_type_code_t code;
halide_type_code_t code = Type::Int;
if (s.substr(0, 3) == "int") {
code = Type::Int; s = s.substr(3);
} else if (s.substr(0, 4) == "uint") {
......@@ -36,7 +36,7 @@ inline Type String2Type(std::string s) {
} else {
LOG(FATAL) << "unknown type " << s;
}
int bits, lanes = 1;
int bits = 32, lanes = 1;
if (sscanf(s.c_str(), "%dx%d", &bits, &lanes) == 0) {
LOG(FATAL) << "unknown type " << s;
}
......@@ -109,12 +109,20 @@ struct APIVariantValue {
CHECK_EQ(type_id, kLong);
return v_union.v_long;
}
inline operator uint64_t() const {
CHECK_EQ(type_id, kLong);
return v_union.v_long;
}
inline operator int() const {
CHECK_EQ(type_id, kLong);
CHECK_LE(v_union.v_long,
std::numeric_limits<int>::max());
return v_union.v_long;
}
inline operator bool() const {
CHECK_EQ(type_id, kLong);
return v_union.v_long != 0;
}
inline operator std::string() const {
CHECK_EQ(type_id, kStr);
return str;
......
......@@ -2,8 +2,24 @@ import tvm
def test_const():
x = tvm.const(1)
assert x.type == 'int32'
assert x.dtype == 'int32'
assert isinstance(x, tvm.expr.IntImm)
def test_make():
x = tvm.const(1)
y = tvm.make.IntImm('int32', 1)
z = x + y
print tvm.format_str(z)
def test_ir():
x = tvm.const(1)
y = tvm.make.IntImm('int32', 1)
z = x + y
stmt = tvm.make.Evaluate(z)
assert isinstance(stmt, tvm.stmt.Evaluate)
print tvm.format_str(stmt)
if __name__ == "__main__":
test_const()
test_make()
test_ir()
Markdown is supported
0% or
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment