Commit a41d644a by tqchen

promote C API lib to root, pass basic_test

parent 34f2adb9
Subproject commit 79a09d0fd60ae7fb6917a647832664212f7cc844
Subproject commit 2a1001108b9112c4e594c456ffd364b57db10b6b
"""Init proptype of the TVM"""
"""C++ backend related python scripts"""
from __future__ import absolute_import as _abs
from .op import *
from .expr import Var, const
from .expr_util import *
from .tensor import Tensor
from .domain import Range, RDom, infer_range
from .split import Split
from .buffer import Scope, Buffer
from .schedule import Schedule
from .function import *
from ._ctypes._api import register_node
from . import expr
from __future__ import absolute_import as _abs
from . import expr as _expr
from . import expr_util as _expr_util
from . import var_name as _name
def enum(*sequential, **named):
    """Build a simple enumeration class.

    Positional names are assigned the values 0..len-1; keyword
    arguments supply explicit values.  Returns a new class whose
    class attributes are the enum members.
    """
    members = {label: idx for idx, label in enumerate(sequential)}
    members.update(named)
    return type('Enum', (), members)
"""Scope defines the scope of a buffer
Types
-----
Thread : thread private buffer (registers)
Shared : shared buffer within a thread block (shared memory)
Global : buffer in the global GPU RAM
"""
Scope = enum('Thread', 'Shared', 'Global')
class Buffer(object):
    """A storage buffer addressed through a flattened 1-D index.

    Parameters
    ----------
    scope : int
        One of the Scope enum values (Thread/Shared/Global).
    name : str, optional
        Name hint; a unique name is generated from it.
    """
    def __init__(self, scope, name=None):
        self.scope = scope
        buf_name = 'Buffer_'
        if name: buf_name += name
        # NameManager guarantees a unique name per hint
        self.name = _name.NameManager.current.get(buf_name)
        self.shape = []
        self.offset_index = []
    def reshape(self, domain):
        """Set shape/offset per dimension from a list of Range objects."""
        for r in domain:
            self.shape.append(r.extent)
            self.offset_index.append(r.begin)
    def __call__(self, *global_index):
        """Index the buffer with global indices.

        Returns a TensorRefExpr carrying a single flattened,
        buffer-local (row-major) index.
        """
        if len(global_index) != len(self.shape):
            raise ValueError("Need to provide %d index in buffer slice" % len(self.shape))
        # row-major strides: stride[i] = prod(shape[i+1:])
        stride = [1]
        for i in reversed(range(1, len(self.shape))):
            stride.insert(0, self.shape[i] * stride[0])
        # translate global indices into buffer-local ones
        local_index = []
        for i in range(0, len(global_index)):
            local_index.append(global_index[i] - self.offset_index[i])
        index = local_index[0] * stride[0]
        for i in range(1, len(local_index)):
            index = index + local_index[i] * stride[i]
        index = _expr_util.simplify(index)
        return _expr.TensorRefExpr(self, [index])
class BufferManager(object):
    """Tracks which Buffer each tensor is bound to.

    Usable as a context manager so a schedule can temporarily
    install its own manager as ``BufferManager.current``.
    """
    def __init__(self):
        self._buffer_map = {}
        self._old_manager = None

    def get(self, tensor):
        """Return the buffer bound to ``tensor``, or None when unbound."""
        return self._buffer_map.get(tensor)

    def bind(self, tensor, buf):
        """Bind ``tensor`` to buffer ``buf``."""
        self._buffer_map[tensor] = buf

    def __enter__(self):
        self._old_manager = BufferManager.current
        BufferManager.current = self
        return self

    def __exit__(self, ptype, value, trace):
        assert self._old_manager
        BufferManager.current = self._old_manager

# initialize the default buffer manager
BufferManager.current = BufferManager()
from __future__ import absolute_import as _abs
from . import buffer as _buffer
from . import expr as _expr
from . import expr_util as _expr_util
def gen_code(expr):
    """change expression to string.

    Unlike expr_util.format_str, tensor references are printed through
    the buffer currently bound to the tensor (if any) so generated code
    addresses the buffer's flattened index.

    Parameters
    ----------
    expr : Expr
        Input expression

    Returns
    -------
    s : str
        The string representation of expr
    """
    def make_str(e, result_children):
        # dispatch on node type; result_children are already strings
        if isinstance(e, _expr.BinaryOpExpr):
            return e.op.format_str(result_children[0], result_children[1])
        elif isinstance(e, _expr.UnaryOpExpr):
            return e.op.format_str(result_children[0])
        elif isinstance(e, _expr.ConstExpr):
            return str(e.value)
        elif isinstance(e, _expr.Var):
            return e.name
        elif isinstance(e, _expr.TensorRefExpr):
            # prefer the buffer bound to this tensor, if any
            buf = _buffer.BufferManager.current.get(e.tensor)
            if buf:
                return _expr_util.format_str(buf(*e.indices))
            # fall back to the tensor itself with a flattened index
            return _expr_util.format_str(e.tensor(*e.indices, flatten=True))
        elif isinstance(e, _expr.ReduceExpr):
            return e.op.format_reduce_stmt_str(result_children[0])
        else:
            raise TypeError("Do not know how to handle type " + str(type(e)))
    return _expr_util.transform(expr, make_str)
"""C++ backend related python scripts"""
from __future__ import absolute_import as _abs
from .function import *
from ._ctypes._api import register_node
from . import expr
from . import domain
from __future__ import absolute_import as _abs
from ._ctypes._api import NodeBase, register_node
from . import _function_internal
@register_node("RangeNode")
class Range(NodeBase):
    """Range proxy backed by the C++ RangeNode."""
    pass

@register_node("ArrayNode")
class Array(NodeBase):
    """Array container backed by the C++ ArrayNode."""
    def __getitem__(self, i):
        # element access is delegated to the C API
        return _function_internal._ArrayGetItem(self, i)
    def __len__(self):
        return _function_internal._ArraySize(self)
from __future__ import absolute_import as _abs
from ._ctypes._api import NodeBase, register_node
from .function import binary_op
class Expr(NodeBase):
    """Expression proxy over the C++ node, with arithmetic overloads.

    All arithmetic dispatches through ``binary_op`` so either operand
    may be an Expr or a plain number.
    """
    def __add__(self, other):
        return binary_op('+', self, other)
    def __radd__(self, other):
        # addition is commutative; reuse __add__
        return self.__add__(other)
    def __sub__(self, other):
        return binary_op('-', self, other)
    def __rsub__(self, other):
        return binary_op('-', other, self)
    def __mul__(self, other):
        return binary_op('*', self, other)
    def __rmul__(self, other):
        return binary_op('*', other, self)
    def __div__(self, other):
        return binary_op('/', self, other)
    def __rdiv__(self, other):
        return binary_op('/', other, self)
    def __truediv__(self, other):
        # Python 3 division delegates to the Python 2 implementation
        return self.__div__(other)
    def __rtruediv__(self, other):
        return self.__rdiv__(other)
    def __neg__(self):
        # negation is expressed as multiplication by -1
        return self.__mul__(-1)
@register_node("VarNode")
class Var(Expr):
    """Symbolic variable node."""
    pass

@register_node("IntNode")
class IntExpr(Expr):
    """Integer constant node."""
    pass

@register_node("FloatNode")
class FloatExpr(Expr):
    """Floating point constant node."""
    pass

@register_node("UnaryOpNode")
class UnaryOpExpr(Expr):
    """Unary operator node."""
    pass

@register_node("BinaryOpNode")
class BinaryOpExpr(Expr):
    """Binary operator node."""
    pass

@register_node("ReduceNode")
class ReduceExpr(Expr):
    """Reduction node."""
    pass

@register_node("TensorReadNode")
class TensorReadExpr(Expr):
    """Tensor read node."""
    pass
from __future__ import absolute_import as _abs
from numbers import Number as _Number
from ._ctypes._api import _init_function_module
from .import _function_internal
int32 = 1
float32 = 2
def Var(name="tindex", dtype=int32):
    """Create a new variable with specified name and dtype

    Parameters
    ----------
    name : str
        The name

    dtype : int
        The data type

    Returns
    -------
    The C-backed variable node.
    """
    return _function_internal._Var(name, dtype)
def _symbol(value):
    """Convert a value to expression."""
    if isinstance(value, _Number):
        # NOTE(review): `constant` is not defined in this module; presumably
        # it is registered by _init_function_module — confirm, otherwise
        # numeric input raises NameError here.
        return constant(value)
    elif isinstance(value, list):
        # recursively convert elements and wrap in a C-side Array
        value = [_symbol(x) for x in value]
        return _function_internal._Array(*value)
    else:
        return value
def max(lhs, rhs):
    """Max of two expressions

    Parameters
    ----------
    lhs : Expr/number
        The left operand

    rhs : Expr/number
        The right operand
    """
    # dispatches to the C-registered "max" binary op
    return binary_op("max", lhs, rhs)
def min(lhs, rhs):
    """Min of two expressions

    Parameters
    ----------
    lhs : Expr/number
        The left operand

    rhs : Expr/number
        The right operand

    Returns
    -------
    The expression min(lhs, rhs).
    """
    # BUG FIX: previously dispatched to "max", silently computing the
    # wrong operation.
    return binary_op("min", lhs, rhs)
_init_function_module("tvm.cpp")
from __future__ import absolute_import as _abs
from . import expr as _expr
from . import expr_util as _expr_util
from . import op as _op
class Range(object):
    """Represent a range in one dimension.

    The range is half-open: [begin, end).  With a single argument
    the range is [0, begin).
    """
    def __init__(self, begin, end=None):
        if end is None:
            # Range(n) is shorthand for [0, n)
            end = begin
            begin = _expr.const(0)
        begin = _expr_util.simplify(_expr._symbol(begin))
        end = _expr_util.simplify(_expr._symbol(end))
        self.begin = begin
        self.end = end
        # extent = end - begin, stored in simplified form
        self.extent = _expr_util.simplify(end - begin)
    def is_value(self):
        """True when the range covers exactly one point (extent == 1)."""
        return isinstance(self.extent, _expr.ConstExpr) and self.extent.value == 1
    def __str__(self):
        return "(%s, %s)" % (
            _expr_util.format_str(self.begin),
            _expr_util.format_str(self.end))
    def __repr__(self):
        return self.__str__()
class RangeInferError(ValueError):
    """Raised when the range of an expression cannot be inferred."""
    pass

class RDom(object):
    """Reduction Domain."""
    def __init__(self, domain):
        # a single Range is shorthand for a 1-D domain
        if isinstance(domain, Range):
            domain = [domain]
        self.index = []
        self.domain = domain
        # one reduction index variable per dimension
        for i in range(len(domain)):
            self.index.append(_expr.Var("rd_index_%d_" % i))
"""Use list of ranges as domain"""
Domain = list
def _combine_range_binary_op(op, lhs, rhs):
    """Combine the ranges of two operands under binary op ``op``.

    Parameters
    ----------
    op : BinaryOp
        One of the singleton ops from the op module.
    lhs, rhs : Range
        Ranges of the two operands (end exclusive).

    Returns
    -------
    Range
        Range of op(lhs, rhs).

    Raises
    ------
    RangeInferError
        When no combination rule is known for ``op``.
    """
    if op == _op.add:
        return Range(lhs.begin + rhs.begin, lhs.end + rhs.end - 1)
    elif op == _op.sub:
        return Range(lhs.begin - rhs.end + 1, lhs.end - rhs.begin)
    elif op == _op.mul:
        v = None
        if lhs.is_value():
            v = lhs.begin.value
            e = rhs
        elif rhs.is_value():
            v = rhs.begin.value
            e = lhs
        # only multiplication by -1 (negation) is supported for now
        if v == -1:
            return Range(-e.end, -e.begin)
    # BUG FIX: was `InferRangeError`, an undefined name (the class is
    # RangeInferError), which turned this into a NameError at runtime.
    raise RangeInferError("donot know how to infer range for %s" % type(op))
def infer_range(e, range_dict, allow_unbind_var=True):
    """Infer the range of result e given range of variables.

    Parameters
    ----------
    e : Expr
        Input expression

    range_dict : dict of Var->Range
        The variables to be replaced.

    allow_unbind_var: bool
        Whether allow unbinded variables

    Returns
    -------
    Range
        The inferred range of ``e``.

    Raises
    ------
    RangeInferError
        If the expression contains a node whose range cannot be inferred.
    ValueError
        If a variable is unbound and allow_unbind_var is False.
    """
    def combine_range(e, result_children):
        if isinstance(e, _expr.ConstExpr):
            # a constant covers the single-point range [e, e + 1)
            return Range(e, e + 1)
        elif isinstance(e, _expr.BinaryOpExpr):
            return _combine_range_binary_op(e.op, result_children[0], result_children[1])
        elif isinstance(e, _expr.Var):
            if e in range_dict:
                return range_dict[e]
            else:
                if allow_unbind_var:
                    # an unbound variable is treated as a single point
                    return Range(e, e + 1)
                else:
                    raise ValueError("Cannot find var %s in range_dict" % e.name)
        else:
            # BUG FIX: was `InferRangeError`, an undefined name (the class
            # is RangeInferError), which raised NameError instead.
            raise RangeInferError("cannot infer range for %s" % _expr_util.format_str(e))
    return _expr_util.transform(e, combine_range)
def union_range(lhs, rhs):
    """Return the union of two ranges; either argument may be None."""
    if lhs is None:
        return rhs
    if rhs is None:
        return lhs
    # widen to the smallest begin and the largest end
    lo = _op.min(lhs.begin, rhs.begin)
    hi = _op.max(rhs.end, lhs.end)
    return Range(lo, hi)
"""Base class of symbolic expression"""
from __future__ import absolute_import as _abs
from numbers import Number as _Number
from . import var_name as _name
from ._ctypes._api import NodeBase, register_node
from . import function as _func
__addop__ = None
__subop__ = None
__mulop__ = None
__divop__ = None
# NOTE(review): this pure-Python base class is shadowed by the
# NodeBase-backed Expr defined immediately below in this file.
class Expr(object):
    """Base class of expression.

    Expression object should be in general immutable.
    """
    def children(self):
        """get children of this expression.

        Returns
        -------
        children : generator of children
        """
        # leaf by default; subclasses with operands override this
        return ()
class Expr(NodeBase):
def __repr__(self):
return _func.format_str(self)
def __add__(self, other):
return BinaryOpExpr(__addop__, self, other)
return binary_op('+', self, other)
def __radd__(self, other):
return self.__add__(other)
def __sub__(self, other):
return BinaryOpExpr(__subop__, self, other)
return binary_op('-', self, other)
def __rsub__(self, other):
return BinaryOpExpr(__subop__, other, self)
return binary_op('-', other, self)
def __mul__(self, other):
return BinaryOpExpr(__mulop__, self, other)
return binary_op('*', self, other)
def __rmul__(self, other):
return BinaryOpExpr(__mulop__, other, self)
return binary_op('*', other, self)
def __div__(self, other):
return BinaryOpExpr(__divop__, self, other)
return binary_op('/', self, other)
def __rdiv__(self, other):
return BinaryOpExpr(__divop__, other, self)
return binary_op('/', other, self)
def __truediv__(self, other):
return self.__div__(other)
......@@ -57,80 +40,14 @@ class Expr(object):
return self.__mul__(-1)
def _symbol(value):
    """Convert a value to an Expr.

    Existing Expr objects pass through; numbers become ConstExpr.

    Raises
    ------
    TypeError
        For unsupported input types.
    """
    if isinstance(value, Expr):
        return value
    elif isinstance(value, _Number):
        return ConstExpr(value)
    else:
        # BUG FIX: the message referenced the undefined name `other`,
        # so this raised NameError instead of the intended TypeError.
        raise TypeError("type %s not supported" % str(type(value)))
class Var(Expr):
    """Variable, is a symbolic placeholder.

    Each variable is uniquely identified by its address
    Note that name alone is not able to uniquely identify the var.

    Parameters
    ----------
    name : str
        optional name to the var.
    """
    def __init__(self, name=None):
        if name is None: name = 'index'
        # NameManager appends a unique counter to the hint
        self.name = _name.NameManager.current.get(name)

class ConstExpr(Expr):
    """Constant expression."""
    def __init__(self, value):
        assert isinstance(value, _Number)
        self.value = value
class BinaryOpExpr(Expr):
    """Binary operator expression."""
    def __init__(self, op, lhs, rhs):
        self.op = op
        # operands are coerced to Expr (plain numbers become ConstExpr)
        self.lhs = _symbol(lhs)
        self.rhs = _symbol(rhs)
    def children(self):
        return (self.lhs, self.rhs)

class UnaryOpExpr(Expr):
    """Unary operator expression."""
    def __init__(self, op, src):
        self.op = op
        self.src = _symbol(src)
    def children(self):
        return (self.src,)

class ReduceExpr(Expr):
    """Reduction of ``src`` over domain ``rdom`` using ``op``."""
    def __init__(self, op, src, rdom):
        self.op = op
        self.src = src
        self.rdom = rdom
    def children(self):
        # rdom is deliberately not a child; traversal covers src only
        return (self.src,)

class TensorRefExpr(Expr):
    """Tensor reference expression, tensor[indices]"""
    def __init__(self, tensor, indices):
        self.tensor = tensor
        self.indices = indices
    def children(self):
        return self.indices
@register_node("IntImm")
class IntImm(Expr):
    """Integer immediate, registered against the C++ IntImm node."""
    pass

@register_node("UIntImm")
class UIntImm(Expr):
    """Unsigned integer immediate, registered against the C++ UIntImm node."""
    pass

def const(value):
    """Return a constant value"""
    return ConstExpr(value)

@register_node("FloatImm")
class FloatImm(Expr):
    """Floating point immediate, registered against the C++ FloatImm node."""
    pass
"""Utilities to manipulate expression"""
from __future__ import absolute_import as _abs
from . import expr as _expr
from . import op as _op
def expr_with_new_children(e, children):
    """Returns same expr as e but with new children

    A shallow copy of e will happen if children differs from current children

    Parameters
    ----------
    e : Expr
        The input expression

    children : list of Expr
        The new children

    Returns
    -------
    new_e : Expr
        Expression with the new children
    """
    if children:
        if isinstance(e, _expr.BinaryOpExpr):
            # reuse e when both operands are unchanged
            return (e if children[0] == e.lhs and children[1] == e.rhs
                    else _expr.BinaryOpExpr(e.op, children[0], children[1]))
        elif isinstance(e, _expr.UnaryOpExpr):
            return e if children[0] == e.src else _expr.UnaryOpExpr(e.op, children[0])
        elif isinstance(e, _expr.TensorRefExpr):
            return e if children == e.indices else _expr.TensorRefExpr(e.tensor, children)
        elif isinstance(e, _expr.ReduceExpr):
            return e if children[0] == e.src else _expr.ReduceExpr(e.op, children[0], e.rdom)
        else:
            raise TypeError("do not know how to handle Expr %s" % type(e))
    else:
        # leaf node: nothing to replace
        return e
def transform(e, f):
    """Apply f recursively to e and collect the result

    Parameters
    ----------
    e : Expr
        The input expression.

    f : function with signature (e, ret_children)
        ret_children is the result of transform from children

    Returns
    -------
    result : return value of f
        The final result of transformation.
    """
    if not isinstance(e, _expr.Expr):
        raise TypeError("Cannot handle type %s" % type(e))
    # post-order: children are transformed before their parent
    return f(e, [transform(c, f) for c in e.children()])
def visit(e, f):
    """Apply f to each element of e

    Parameters
    ----------
    e : Expr
        The input expression.

    f : function with signature (e)
    """
    assert isinstance(e, _expr.Expr)
    # post-order traversal: children first, then the node itself
    for c in e.children():
        visit(c, f)
    f(e)
def format_str(expr):
    """change expression to string.

    Parameters
    ----------
    expr : Expr
        Input expression

    Returns
    -------
    s : str
        The string representation of expr
    """
    def make_str(e, result_children):
        # dispatch on node type; result_children are already strings
        if isinstance(e, _expr.BinaryOpExpr):
            return e.op.format_str(result_children[0], result_children[1])
        elif isinstance(e, _expr.UnaryOpExpr):
            return e.op.format_str(result_children[0])
        elif isinstance(e, _expr.ConstExpr):
            return str(e.value)
        elif isinstance(e, _expr.Var):
            return e.name
        elif isinstance(e, _expr.TensorRefExpr):
            # multi-dimensional reference prints as name[i0,i1,...]
            return "%s[%s]" % (e.tensor.name, ','.join(result_children))
        elif isinstance(e, _expr.ReduceExpr):
            return e.op.format_reduce_str(result_children[0], e.rdom.domain)
        else:
            raise TypeError("Do not know how to handle type " + str(type(e)))
    return transform(expr, make_str)
def simplify(expr):
    """simplify expression

    Works by lowering the expression into a canonical dict form
    (term -> coefficient) and converting back to an expression.

    Parameters
    ----------
    expr : Expr
        Input expression

    Returns
    -------
    e : Expr
        Simplified expression
    """
    def canonical(e, result_children):
        if isinstance(e, _expr.BinaryOpExpr):
            # each op knows how to merge canonical forms
            return e.op.canonical(result_children[0], result_children[1])
        elif isinstance(e, _expr.UnaryOpExpr):
            return e.op.canonical(result_children[0])
        elif isinstance(e, _expr.ConstExpr):
            # the constant term lives under a reserved key
            return {_op.const_canonical_key: e.value}
        elif isinstance(e, _expr.Var):
            # a variable is itself with coefficient 1
            return {e: 1}
        else:
            raise TypeError("Do not know how to handle type " + str(type(e)))
    return _op.canonical_to_expr(transform(expr, canonical))
def bind(expr, update_dict):
    """Replace the variable in expr by specification from update_dict

    Parameters
    ----------
    expr : Expr
        Input expression

    update_dict : dict of Var->Expr
        The variables to be replaced.

    Examples
    --------
    eout = bind(e, update_dict={v1: (x+1)} )
    """
    def replace(e, result_children):
        if isinstance(e, _expr.Var) and e in update_dict:
            return update_dict[e]
        else:
            # keep the node, substituting any rewritten children
            return expr_with_new_children(e, result_children)
    return transform(expr, replace)
from __future__ import absolute_import as _abs
from numbers import Number as _Number, Integral as _Integral
from ._ctypes._api import _init_function_module
from .import _function_internal
int32 = "int32"
float32 = "float32"
def const(value, dtype=None):
    """Create a constant expression, inferring dtype when not given.

    Integral values default to 'int32', everything else to 'float32'.
    """
    if dtype is None:
        if isinstance(value, _Integral):
            dtype = 'int32'
        else:
            dtype = 'float32'
    return _function_internal._const(value, dtype)
def _symbol(value):
    """Convert a value to expression."""
    if isinstance(value, _Number):
        return const(value)
    elif isinstance(value, list):
        # recursively convert elements and wrap in a C-side Array
        value = [_symbol(x) for x in value]
        return _function_internal._Array(*value)
    else:
        # assume already an expression / node
        return value
_init_function_module("tvm")
......@@ -13,7 +13,7 @@ def find_lib_path():
List of all found path to the libraries
"""
curr_path = os.path.dirname(os.path.abspath(os.path.expanduser(__file__)))
api_path = os.path.join(curr_path, '../../../lib/')
api_path = os.path.join(curr_path, '../../lib/')
cmake_build_path = os.path.join(curr_path, '../../build/Release/')
dll_path = [curr_path, api_path, cmake_build_path]
if os.name == 'nt':
......
from __future__ import absolute_import as _abs
from . import expr as _expr
const_canonical_key = '__constant__'
def canonical_to_expr(c):
    """Convert a canonical dict (term -> coefficient) back to an Expr.

    The reserved key ``const_canonical_key`` carries the constant term.
    Zero coefficients are dropped; an empty result yields const(0).
    """
    elements = []
    # NOTE(review): sorting items with mixed key types (str constant key
    # and Var/Expr keys) only works on Python 2 — confirm target version.
    for k, v in sorted(c.items()):
        if k == const_canonical_key and v != 0:
            elements.append(_expr.const(v))
        elif v == 0:
            # dropped term
            continue
        elif v == 1:
            elements.append(k)
        else:
            elements.append(k * v)
    if elements:
        # fold the terms back into a sum
        expr = elements[0]
        for i in range(1, len(elements)):
            expr = expr + elements[i]
        return expr
    else:
        return _expr.const(0)
class BinaryOp(object):
    """Base class of binary operator"""
    def __call__(self, lhs, rhs):
        # calling an op builds the corresponding expression node
        return _expr.BinaryOpExpr(self, lhs, rhs)
class AddOp(BinaryOp):
    """Addition operator; also serves as the reduce-sum reduction op."""
    def format_str(self, lhs, rhs):
        return '(%s + %s)' % (lhs, rhs)
    def format_reduce_str(self, src, rd):
        return "reduce_sum(%s, rdom=%s)" % (src, str(rd))
    def format_reduce_stmt_str(self, src):
        # a temporary hack for now
        return "+ %s" % (src)
    def canonical(self, lhs, rhs):
        # merge canonical dicts by summing coefficients per term
        merged = dict(lhs)
        for term, coeff in rhs.items():
            merged[term] = merged.get(term, 0) + coeff
        return merged
class SubOp(BinaryOp):
    """Subtraction operator."""
    def format_str(self, lhs, rhs):
        return '(%s - %s)' % (lhs, rhs)
    def canonical(self, lhs, rhs):
        # merge canonical dicts by subtracting rhs coefficients
        merged = dict(lhs)
        for term, coeff in rhs.items():
            merged[term] = merged.get(term, 0) - coeff
        return merged
class MulOp(BinaryOp):
    """Multiplication operator."""
    def format_str(self, lhs, rhs):
        return '(%s * %s)' % (lhs, rhs)
    def canonical(self, lhs, rhs):
        # if either side reduces to a constant, scale the other side's
        # coefficients; otherwise keep the product as a single opaque term
        elhs = canonical_to_expr(lhs)
        erhs = canonical_to_expr(rhs)
        if isinstance(erhs, _expr.ConstExpr):
            lhs = lhs.copy()
            for k, v in lhs.items():
                lhs[k] *= erhs.value
            return lhs
        if isinstance(elhs, _expr.ConstExpr):
            rhs = rhs.copy()
            for k, v in rhs.items():
                rhs[k] *= elhs.value
            return rhs
        return {elhs * erhs: 1}
class DivOp(BinaryOp):
    """Division operator."""
    def format_str(self, lhs, rhs):
        return '(%s / %s)' % (lhs, rhs)
    def canonical(self, lhs, rhs):
        erhs = canonical_to_expr(rhs)
        if isinstance(erhs, _expr.ConstExpr):
            lhs = lhs.copy()
            remove = []
            # NOTE(review): inserting keys while iterating items() is only
            # safe on Python 2 (snapshot list) — confirm target version.
            for k, v in lhs.items():
                if k == const_canonical_key:
                    # constant term: divide the value directly
                    lhs[k] = v / erhs.value
                else:
                    # NOTE(review): the coefficient v is dropped here and the
                    # key is divided by the ConstExpr (not its value) —
                    # verify this is intended.
                    lhs[k / erhs] = 1
                    remove.append(k)
            for k in remove:
                del lhs[k]
            return lhs
        # non-constant divisor: keep the quotient as an opaque term
        elhs = canonical_to_expr(lhs)
        return {elhs / erhs: 1}
class MaxOp(BinaryOp):
    """Maximum operator."""
    def format_str(self, lhs, rhs):
        return 'max(%s, %s)' % (lhs, rhs)
    def canonical(self, lhs, rhs):
        # if lhs - rhs folds to a constant, the winner is known statically
        diff = SubOp().canonical(lhs, rhs)
        ediff = canonical_to_expr(diff)
        if isinstance(ediff, _expr.ConstExpr):
            return lhs if ediff.value >= 0 else rhs
        return {MaxOp()(lhs, rhs): 1}

class MinOp(BinaryOp):
    """Minimum operator."""
    def format_str(self, lhs, rhs):
        return 'min(%s, %s)' % (lhs, rhs)
    def canonical(self, lhs, rhs):
        diff = SubOp().canonical(lhs, rhs)
        ediff = canonical_to_expr(diff)
        if isinstance(ediff, _expr.ConstExpr):
            # lhs - rhs >= 0 means rhs is the minimum
            return rhs if ediff.value >= 0 else lhs
        return {MinOp()(lhs, rhs): 1}
add = AddOp()
sub = SubOp()
mul = MulOp()
div = DivOp()
max = MaxOp()
min = MinOp()
_expr.__addop__ = add
_expr.__subop__ = sub
_expr.__mulop__ = mul
_expr.__divop__ = div
def reduce_sum(expr, rdom):
    """Sum ``expr`` over reduction domain ``rdom``."""
    return _expr.ReduceExpr(add, expr, rdom)

def reduce_prod(expr, rdom):
    """Multiply ``expr`` over reduction domain ``rdom``."""
    return _expr.ReduceExpr(mul, expr, rdom)

def reduce_min(expr, rdom):
    """Minimum of ``expr`` over reduction domain ``rdom``."""
    return _expr.ReduceExpr(min, expr, rdom)

def reduce_max(expr, rdom):
    """Maximum of ``expr`` over reduction domain ``rdom``."""
    return _expr.ReduceExpr(max, expr, rdom)
from __future__ import absolute_import as _abs
from . import domain as _dom
from . import expr as _expr
from . import expr_util as _expr_util
from . import split as _split
from . import buffer as _buffer
from . import codegen as _gen
start_point_key = '__start__'
TAB = ' '
class Schedule(object):
    """SUnit defines the compute schedule of a tensor

    Parameters
    ----------
    tensor: tensor
        The tensor this schedule computes.
    buffer: Buffer, optional
        Storage buffer the result is written into.
    """
    def __init__(self, tensor, buffer=None):
        self.tensor = tensor
        self.buffer = buffer
        self.parent = None
        #self.children = []
        self.splits = []
        # attachment slots: start_point_key holds schedules attached
        # before any explicit split
        self.split_attach = {start_point_key: []}
        # one implicit unit split per output dimension
        self.implicit_splits = [_split.Split(i, 1) for i in range(tensor.ndim)]
        if isinstance(tensor.expr, _expr.ReduceExpr):
            # plus one implicit split per reduction dimension
            for i in range(len(tensor.expr.rdom.domain)):
                self.implicit_splits.append(_split.Split(i, 1, rdom=True))
    def add_split(self, split):
        """Register an explicit split and create its attachment slot."""
        self.splits.append(split)
        self.split_attach[split] = []
    def set_buffer(self, buf):
        """Set the output buffer."""
        self.buffer = buf
    def attach(self, split, other):
        """Attach schedule ``other`` under ``split`` (None = before all splits)."""
        other.parent = self
        if split is None:
            self.split_attach[start_point_key].append(other)
        else:
            self.split_attach[split].append(other)
    def infer_inner_domain(self, domain):
        # apply all explicit splits in order to narrow the domain
        for split in self.splits:
            domain = split.infer_inner_domain(domain)
        return domain
    def realize(self, domain=None, indent=''):
        """Generate pseudo-C loop-nest source lines for this schedule.

        Returns
        -------
        body : list of str
            Generated lines of code.
        """
        def realize_attach(lst):
            # emit attached schedules at the current loop depth, with input
            # domains inferred from the (partially split) current domain
            attach_tensors = [sch.tensor for sch in lst]
            attach_domains = self.tensor.infer_input_domains(domain, attach_tensors, red_domain=red_domain)
            for sch in lst:
                body.extend(sch.realize(attach_domains[sch.tensor], indent))
        # init domain and red_domain
        if domain is None:
            domain = self.tensor.domain
        red_domain = self.tensor.expr.rdom.domain if isinstance(self.tensor.expr, _expr.ReduceExpr) else None
        # init buffer shape
        if self.buffer:
            if self.buffer.scope == _buffer.Scope.Global:
                self.buffer.reshape(self.tensor.domain)
            else:
                # don't handle shared buffer for now
                self.buffer.reshape(domain)
            # register the binding so gen_code prints via the buffer
            _buffer.BufferManager.current.bind(self.tensor, self.buffer)
        body = []
        if self.split_attach[start_point_key]:
            realize_attach(self.split_attach[start_point_key])
        # add loop conditions for splits
        for split in self.splits:
            if split.rdom:
                red_domain = split.generate_loop_condition(red_domain, body, indent)
            else:
                domain = split.generate_loop_condition(domain, body, indent)
            indent += TAB
            if self.split_attach[split]:
                realize_attach(self.split_attach[split])
        # add implicit loop conditions
        for split in self.implicit_splits:
            if split.rdom:
                red_domain = split.generate_loop_condition(red_domain, body, indent)
            else:
                domain = split.generate_loop_condition(domain, body, indent)
            indent += TAB
        # add loop body
        expr = self.tensor.expr
        global_index = [r.begin for r in domain]
        global_rdom_index = [r.begin for r in red_domain] if red_domain else []
        if expr is None:
            # input tensor: copy from tensor storage into the buffer
            if self.buffer:
                lhs = self.buffer(*global_index)
                rhs = self.tensor(*global_index, flatten=True)
                body.append('%s%s = %s;' % (indent, _expr_util.format_str(lhs), _expr_util.format_str(rhs)))
        else:
            if self.buffer:
                lhs = self.buffer(*global_index)
            else:
                lhs = self.tensor(*global_index, flatten=True)
            # substitute loop indices for the tensor's symbolic dimensions
            bind_dict = {}
            for i in range(self.tensor.ndim):
                bind_dict[self.tensor.dim_index[i]] = global_index[i]
            if isinstance(expr, _expr.ReduceExpr):
                for i in range(len(expr.rdom.domain)):
                    bind_dict[expr.rdom.index[i]] = global_rdom_index[i]
            rhs = _expr_util.bind(expr, bind_dict)
            body.append('%s%s = %s;' % (indent, _expr_util.format_str(lhs), _gen.gen_code(rhs)))
        # add right brackets
        for split in self.implicit_splits:
            indent = indent[:-len(TAB)]
            body.append('%s}' % indent)
        for split in self.splits:
            indent = indent[:-len(TAB)]
            body.append('%s}' % indent)
        return body
from __future__ import absolute_import as _abs
from . import expr as _expr
from . import expr_util as _expr_util
from . import domain as _dom
from . import tensor as _tensor
class Split(object):
    """Split one dimension of a loop nest by a constant factor.

    Parameters
    ----------
    dim : int
        Index of the dimension to split.
    factor : int
        Extent of the inner loop produced by the split.
    name : str, optional
        Name hint for the generated loop index variable.
    rdom : bool
        Whether this split applies to the reduction domain.
    """
    def __init__(self, dim, factor, name=None, rdom=False):
        self.dim = dim
        self.factor = factor
        self.rdom = rdom
        if name is None:
            name = 'loop_index_%d_' % dim
        self.loop_index = _expr.Var(name)
    def infer_inner_domain(self, out_domain):
        """Return the domain seen inside the loop generated by this split."""
        assert self.dim < len(out_domain)
        inner_domain = out_domain[:]
        dim_out_range = out_domain[self.dim]
        # inner range starts at begin + loop_index * factor, width = factor
        dim_inner_begin = dim_out_range.begin + self.loop_index * self.factor
        inner_domain[self.dim] = _dom.Range(dim_inner_begin, dim_inner_begin + self.factor)
        return inner_domain
    def generate_loop_condition(self, out_domain, body, indent):
        """Append a C-style ``for`` header to body; return the inner domain."""
        assert self.dim < len(out_domain)
        # iteration count = extent / factor
        loop_range = _dom.Range(out_domain[self.dim].extent / self.factor)
        stmt = '%sfor (int %s = 0; %s < %s; %s += 1) {' % (
            indent,
            self.loop_index.name,
            self.loop_index.name,
            _expr_util.format_str(loop_range.end),
            self.loop_index.name)
        body.append(stmt)
        return self.infer_inner_domain(out_domain)
from __future__ import absolute_import as _abs
from . import expr as _expr
from . import expr_util as _expr_util
from . import domain as _dom
from . import var_name as _name
class Tensor(object):
    """A symbolic multi-dimensional tensor, optionally defined by a compute rule.

    Parameters
    ----------
    ndim : int
        Number of dimensions.
    fcompute : callable, optional
        Maps index variables to the defining expression.
    name : str, optional
        Name hint.
    shape : tuple, optional
        Shape expressions; required when fcompute is given.
    """
    def __init__(self, ndim, fcompute=None, name=None, shape=None):
        self.ndim = ndim
        if fcompute:
            # NOTE(review): func_code is Python-2 only (__code__ on Py3) —
            # confirm target interpreter.
            arg_names = fcompute.func_code.co_varnames
            assert(len(arg_names) == ndim)
            self.dim_index = [_expr.Var(n) for n in arg_names]
            self.expr = fcompute(*self.dim_index)
            if shape is None:
                raise ValueError("argument shape need to be given for intermediate tensor")
            self.shape = shape
        else:
            self.expr = None
            self.dim_index = None
            shape_name = '_shape'
            if name: shape_name = name + shape_name
            # unknown dimensions become symbolic shape variables
            self.shape = shape if shape else tuple(
                _expr.Var("%s_%d_" % (shape_name, i)) for i in range(ndim))
        self.name = name if name else _name.NameManager.current.get("TensorObj")
        self.inputs = None
    def __call__(self, *indices, **option):
        """Index the tensor; flatten=True yields a single row-major index."""
        if len(indices) != self.ndim:
            raise ValueError("Need to provide %d index in tensor slice" % self.ndim)
        if 'flatten' in option and option['flatten']:
            # row-major flattening: index = sum(indices[i] * stride[i])
            stride = [1]
            for i in reversed(range(1, len(indices))):
                stride.insert(0, self.shape[i] * stride[0])
            index = indices[0] * stride[0]
            for i in range(1, len(indices)):
                index = index + indices[i] * stride[i]
            index = _expr_util.simplify(index)
            return _expr.TensorRefExpr(self, [index])
        return _expr.TensorRefExpr(self, indices)
    @property
    def domain(self):
        # full domain: [0, shape[i]) per dimension
        return _dom.Domain([_dom.Range(self.shape[i]) for i in range(self.ndim)])
    def input_tensors(self):
        """List of input tensors to this tensor.

        Returns
        -------
        inputs : list of input tensors
        """
        if self.inputs is not None:
            return self.inputs
        inputs = []
        if self.expr:
            def collect(e):
                if isinstance(e, _expr.TensorRefExpr):
                    inputs.append(e.tensor)
            _expr_util.visit(self.expr, collect)
        # cached after the first call
        self.inputs = set(inputs)
        return self.inputs
    def infer_input_domains(self, out_domain, inputs, red_domain=None):
        """Infer the input domains of each domain in given inputs list.

        Parameters
        ----------
        out_domain : list of Range
            Domain of each dimension.

        inputs : list of Tensor
            Input tensors whose accessed domains are wanted.

        red_domain : list of Range
            Domain of reduction variables, if this tensor
            this can only be specified if
            self.expr finishes with an ReduceExpr, and we can schedule
            over the last reduction that creates this tensor.

        Returns
        -------
        in_domains: dict Tensor->Domain
        """
        assert self.expr
        assert len(out_domain) == len(self.dim_index)
        index_domains = {
            self.dim_index[i] : out_domain[i] for i in range(len(out_domain))
        }
        begin_expr = self.expr
        if red_domain:
            if not isinstance(self.expr, _expr.ReduceExpr):
                raise ValueError("red_domain must work with tensor that stores a reduction")
            rdom = self.expr.rdom
            # skip the outer reduce node; traverse its source expression
            begin_expr = self.expr.src
            assert len(red_domain) == len(rdom.index)
            for i in range(len(red_domain)):
                index_domains[rdom.index[i]] = red_domain[i]
        iset = {}
        for t in inputs:
            assert t in self.input_tensors()
            iset[t] = []
        def prepare(e):
            if isinstance(e, _expr.ReduceExpr):
                # inner reductions contribute their own index ranges
                rd = e.rdom
                for i in range(len(rd.domain)):
                    index_domains[rd.index[i]] = rd.domain[i]
            elif isinstance(e, _expr.TensorRefExpr):
                if e.tensor in iset:
                    iset[e.tensor].append(e)
        _expr_util.visit(begin_expr, prepare)
        result = {}
        for k, v in iset.items():
            # union the inferred ranges of every access to the same tensor
            dm = [None] * len(v[0].indices)
            for e in v:
                for i, idx in enumerate(e.indices):
                    dm[i] = _dom.union_range(
                        dm[i], _dom.infer_range(idx, index_domains, allow_unbind_var=False))
            result[k] = dm
        return result
    @property
    def is_rtensor(self):
        """Whether this tensor is a result of reduction.

        Returns
        -------
        is_rtensor : Whether the tensor is RTensor
        """
        return self.expr and isinstance(self.expr, _expr.ReduceExpr)
"""Name manager to make sure name is unique."""
from __future__ import absolute_import as _abs
class NameManager(object):
    """NameManager to do automatic naming.

    User can also inherit this object to change naming behavior.
    """
    current = None

    def __init__(self):
        self._counter = {}
        self._old_manager = None

    def get(self, hint):
        """Get the canonical name for a symbol.

        The first request for a hint yields ``hint0``, then ``hint1``,
        and so on, guaranteeing uniqueness per hint.

        Parameters
        ----------
        hint : str
            A hint string, which can be used to generate name.

        Returns
        -------
        full_name : str
            A canonical name for the user.
        """
        count = self._counter.get(hint, 0)
        self._counter[hint] = count + 1
        return '%s%d' % (hint, count)

    def __enter__(self):
        self._old_manager = NameManager.current
        NameManager.current = self
        return self

    def __exit__(self, ptype, value, trace):
        assert self._old_manager
        NameManager.current = self._old_manager

# initialize the default name manager
NameManager.current = NameManager()
......@@ -27,36 +27,6 @@ struct TVMAPIThreadLocalEntry {
inline void SetReturn(ArgVariant* ret_val, int* ret_typeid);
};
namespace tvm {
inline std::string Type2String(const Type& t) {
std::ostringstream os;
os << t;
return os.str();
}
// Parse a type string such as "int32" or "float32x4" back into a Type.
inline Type String2Type(std::string s) {
  std::istringstream is(s);
  halide_type_code_t code;
  if (s.substr(0, 3) == "int") {
    code = Type::Int; s = s.substr(3);
  } else if (s.substr(0, 4) == "uint") {
    code = Type::UInt; s = s.substr(4);
  } else if (s.substr(0, 5) == "float") {
    code = Type::Float; s = s.substr(5);
  } else {
    // BUG FIX: removed a duplicated, unreachable "float" branch.
    LOG(FATAL) << "unknown type " << s;
  }
  // BUG FIX: lanes must default to 1 (scalar). It was 0, so inputs
  // without an explicit "xN" suffix produced an invalid 0-lane type.
  int bits, lanes = 1;
  // sscanf returns the number of converted fields; 0 means not even
  // the bit width could be parsed.
  if (sscanf(s.c_str(), "%dx%d", &bits, &lanes) == 0) {
    LOG(FATAL) << "unknown type " << s;
  }
  return Type(code, bits, lanes);
}
}
using namespace tvm;
/*! \brief Thread local store that can be used to hold return values. */
......@@ -86,7 +56,7 @@ struct APIAttrGetter : public AttrVisitor {
if (skey == key) *ret = static_cast<int64_t>(value[0]);
}
void Visit(const char* key, Type* value) final {
if (skey == key) *ret = Type2String(value[0]);
if (skey == key) *ret = value[0];
}
void Visit(const char* key, std::string* value) final {
if (skey == key) *ret = value[0];
......
......@@ -4,6 +4,7 @@
* \file c_api_impl.cc
*/
#include <tvm/expr.h>
#include <ir/IROperator.h>
#include "./c_api_registry.h"
namespace dmlc {
......@@ -12,7 +13,36 @@ DMLC_REGISTRY_ENABLE(::tvm::APIFunctionReg);
namespace tvm {
using namespace Halide::Internal;
using ArgStack = const std::vector<APIVariantValue>;
using RetValue = APIVariantValue;
TVM_REGISTER_API(_const)
.set_body([](const ArgStack& args, RetValue *ret) {
if (args.at(0).type_id == kLong) {
*ret = make_const(args.at(1), args.at(0).operator int64_t());
} else if (args.at(0).type_id == kDouble) {
*ret = make_const(args.at(1), args.at(0).operator double());
} else {
LOG(FATAL) << "only accept int or float";
}
})
.add_argument("src", "Number", "source number")
.add_argument("dtype", "str", "data type");
TVM_REGISTER_API(format_str)
.set_body([](const ArgStack& args, RetValue *ret) {
CHECK(args.at(0).type_id == kNodeHandle);
std::ostringstream os;
auto& sptr = args.at(0).sptr;
if (dynamic_cast<const BaseExprNode*>(sptr.get())) {
os << args.at(0).operator Expr();
} else if (dynamic_cast<const BaseStmtNode*>(sptr.get())) {
os << args.at(0).operator Stmt();
}
*ret = os.str();
})
.add_argument("expr", "Node", "expression to be printed");
} // namespace tvm
......@@ -16,6 +16,33 @@
namespace tvm {
inline std::string Type2String(const Type& t) {
std::ostringstream os;
os << t;
return os.str();
}
// Parse a type string such as "int32" or "float32x4" back into a Type.
inline Type String2Type(std::string s) {
  std::istringstream is(s);
  halide_type_code_t code;
  if (s.substr(0, 3) == "int") {
    code = Type::Int; s = s.substr(3);
  } else if (s.substr(0, 4) == "uint") {
    code = Type::UInt; s = s.substr(4);
  } else if (s.substr(0, 5) == "float") {
    code = Type::Float; s = s.substr(5);
  } else {
    // BUG FIX: removed a duplicated, unreachable "float" branch.
    LOG(FATAL) << "unknown type " << s;
  }
  // lanes defaults to 1 (scalar) when no "xN" suffix is present
  int bits, lanes = 1;
  if (sscanf(s.c_str(), "%dx%d", &bits, &lanes) == 0) {
    LOG(FATAL) << "unknown type " << s;
  }
  return Type(code, bits, lanes);
}
/*! \brief Variant container for API calls */
struct APIVariantValue {
/*! \brief the type id */
......@@ -57,6 +84,9 @@ struct APIVariantValue {
this->sptr = ref.node_;
return *this;
}
inline APIVariantValue& operator=(const Type& value) {
return operator=(Type2String(value));
}
template<typename T,
typename = typename std::enable_if<std::is_base_of<NodeRef, T>::value>::type>
inline operator T() const {
......@@ -89,6 +119,9 @@ struct APIVariantValue {
CHECK_EQ(type_id, kStr);
return str;
}
inline operator Type() const {
return String2Type(operator std::string());
}
};
// common defintiion of API function.
......
import tvm
def test_bind():
    """bind() substitutes a variable with another expression."""
    x = tvm.Var('x')
    y = x + 1
    z = tvm.bind(y, {x: tvm.const(10) + 9})
    assert tvm.format_str(z) == '((10 + 9) + 1)'

def test_basic():
    """Adding two vars formats as an infix expression."""
    a = tvm.Var('a')
    b = tvm.Var('b')
    c = a + b
    assert tvm.format_str(c) == '(%s + %s)' % (a.name, b.name)

def test_simplify():
    """simplify() folds constants and collects like terms."""
    a = tvm.Var('a')
    b = tvm.Var('b')
    e1 = a * (2 + 1) + b * 1
    e2 = a * (2 + 1) - b * 1
    e3 = tvm.max(a * 3.3 + 5, 3 + 3.3 * a)
    e4 = a - a
    assert tvm.format_str(tvm.simplify(e1)) == '((%s * 3) + %s)' % (a.name, b.name)
    assert tvm.format_str(tvm.simplify(e2)) == '((%s * 3) + (%s * -1))' % (a.name, b.name)
    assert tvm.format_str(tvm.simplify(e3)) == '((%s * 3.3) + 5)' % (a.name)
    assert tvm.format_str(tvm.simplify(e4)) == '0'

def test_const():
    """const() infers int32 and returns a C-backed IntImm node."""
    x = tvm.const(1)
    assert x.type == 'int32'
    assert isinstance(x, tvm.expr.IntImm)

if __name__ == "__main__":
    test_basic()
    test_bind()
    test_simplify()
    test_const()
import tvm
def test_buffer():
    """Buffer indexing emits a flattened row-major index expression."""
    buf = tvm.Buffer(tvm.Scope.Thread)
    shape = [32, 16]
    domain = [tvm.Range(v) for v in shape]
    buf.reshape(domain)
    x = tvm.Var('x')
    y = tvm.Var('y')
    # flattened index of (y, x) with row-major layout is x + y * shape[1]
    assert tvm.format_str(buf(y, x)) == '%s[(%s + (%s * %s))]' % (buf.name, x.name, y.name, shape[1])

if __name__ == '__main__':
    test_buffer()
from tvm import cpp as tvm
def test_basic():
    """C++-backed expressions expose lhs/dtype and format like infix."""
    a = tvm.Var('a')
    b = tvm.Var('b')
    c = a + b
    assert a == c.lhs
    assert c.dtype == tvm.int32
    assert tvm.format_str(c) == '(%s + %s)' % (a.name, b.name)

def test_array():
    # NOTE(review): no assertions here — only checks that building an
    # Array from mixed numbers and vars does not raise.
    a = tvm.Var('a')
    x = tvm.function._symbol([1,2,a])

def assert_equal(x, y):
    """Assert that x - y simplifies to the integer constant 0."""
    z = tvm.simplify(x - y)
    assert isinstance(z, tvm.expr.IntExpr)
    assert z.value == 0

def test_simplify():
    """Simplification identities hold for the C++-backed ops."""
    a = tvm.Var('a')
    b = tvm.Var('b')
    e1 = a * (2 + 1) + b * 1
    e2 = a * (2 + 1) - b * 1
    e3 = tvm.max(a * 3 + 5, 3 + 3 * a)
    e4 = a - a
    assert_equal(e1, a * 3 + b)
    assert_equal(e2, a * 3 - b)
    assert_equal(e3, a * 3 + 5)
    assert_equal(e4, 0)

if __name__ == "__main__":
    test_basic()
    test_array()
    test_simplify()
Markdown is supported
0% or
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment