Commit 7a01476a by Jian Weng, committed by Tianqi Chen


[HYBRID FRONTEND] Modify hybrid script to new interface; hybrid op supported; enable compilation_database in CMakeList.txt (#1757)
parent 79735eb2
...@@ -57,6 +57,7 @@ include_directories("3rdparty/compiler-rt") ...@@ -57,6 +57,7 @@ include_directories("3rdparty/compiler-rt")
# initial variables # initial variables
set(TVM_LINKER_LIBS "") set(TVM_LINKER_LIBS "")
set(TVM_RUNTIME_LINKER_LIBS "") set(TVM_RUNTIME_LINKER_LIBS "")
set(CMAKE_EXPORT_COMPILE_COMMANDS ON)
# Generic compilation options # Generic compilation options
if(MSVC) if(MSVC)
......
...@@ -22,13 +22,15 @@ you need to use ``tvm.hybrid.script`` decorator to indicate this is a hybrid fun ...@@ -22,13 +22,15 @@ you need to use ``tvm.hybrid.script`` decorator to indicate this is a hybrid fun
@tvm.hybrid.script @tvm.hybrid.script
def outer_product(a, b, c): def outer_product(a, b, c):
c = output_tensor((100, 99), 'float32')
for i in range(a.shape[0]): for i in range(a.shape[0]):
for j in range(b.shape[0]): for j in range(b.shape[0]):
c[i, j] = a[i] * b[j] c[i, j] = a[i] * b[j]
a = numpy.random.rand(100) return c
b = numpy.random.rand(99) a = numpy.random.randn(100)
c = numpy.zeros((100, 99)) b = numpy.random.randn(99)
outer_product(a, b, c) c = outer_product(a, b)
This decorator will automatically import the required `Keywords`_ during software emulation.
After software emulation is done, the imported keywords will be cleaned up. Users do not need
...@@ -40,25 +42,25 @@ or ``numpy`` numeric type. ...@@ -40,25 +42,25 @@ or ``numpy`` numeric type.
Backend Compilation Backend Compilation
~~~~~~~~~~~~~~~~~~~ ~~~~~~~~~~~~~~~~~~~
Direct use of this function is discouraged; users are encouraged to use the second interface.
The current parse interface looks like: The current parse interface looks like:
.. code-block:: python .. code-block:: python
a = tvm.placeholder((100, ), name='a') a = tvm.placeholder((100, ), name='a')
b = tvm.placeholder((99, ), name='b') b = tvm.placeholder((99, ), name='b')
c = tvm.placeholder((100, 99), name='c') parser = tvm.hybrid.parse(outer_product, [a, b]) # return the parser of this function
tvm.hybrid.parse(outer_product, [a, b, c]) # return an ir root of this function
If we pass these tvm tensors to this function, it returns an op node:
**Under construction, we are still deciding what kind of node should be returned.** If we pass these tvm tensors to this function, it returns an op node:
.. code-block:: python .. code-block:: python
a = tvm.placeholder((100, ), name='a') a = tvm.placeholder((100, ), name='a')
b = tvm.placeholder((99, ), name='b') b = tvm.placeholder((99, ), name='b')
c = tvm.placeholder((100, 99), name='c') c = outer_product(a, b) # return the output tensor(s) of the operator
op = outer_product(a, b, c) # return the corresponding op node
**Under construction, we are still deciding what kind of node should be returned.**
Tuning Tuning
~~~~~~ ~~~~~~
......
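To make the new interface above concrete, here is a minimal end-to-end sketch (not part of the diff itself) showing how a hybrid script is parsed into an op, scheduled, built, and run; the schedule/build/run steps follow the pattern of run_and_check in the updated test file below.

import tvm
import numpy

@tvm.hybrid.script
def outer_product(a, b):
    c = output_tensor((100, 99), 'float32')
    for i in range(a.shape[0]):
        for j in range(b.shape[0]):
            c[i, j] = a[i] * b[j]
    return c

a = tvm.placeholder((100, ), name='a', dtype='float32')
b = tvm.placeholder((99, ), name='b', dtype='float32')
c = outer_product(a, b)              # returns the output tensor of the new hybrid op
sch = tvm.create_schedule(c.op)      # the op participates in scheduling like any other op
module = tvm.build(sch, [a, b, c])

np_a = numpy.random.randn(100).astype('float32')
np_b = numpy.random.randn(99).astype('float32')
nd_a, nd_b = tvm.nd.array(np_a), tvm.nd.array(np_b)
nd_c = tvm.nd.array(numpy.zeros((100, 99), dtype='float32'))
module(nd_a, nd_b, nd_c)
tvm.testing.assert_allclose(nd_c.asnumpy(), numpy.outer(np_a, np_b), rtol=1e-5, atol=1e-5)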
...@@ -450,6 +450,69 @@ class ExternOpNode : public OperationNode { ...@@ -450,6 +450,69 @@ class ExternOpNode : public OperationNode {
TVM_DECLARE_NODE_TYPE_INFO(ExternOpNode, OperationNode); TVM_DECLARE_NODE_TYPE_INFO(ExternOpNode, OperationNode);
}; };
/*!
 * \brief A computation operator generated by hybrid script.
*/
class HybridOpNode : public OperationNode {
public:
/*! \brief The input tensors */
Array<Tensor> inputs;
/*! \brief Symbolic placeholder representation of outputs */
Array<Tensor> outputs;
/*! \brief the statement that generates the computation. This is
* slightly different from the body in ExternOpNode. All the output
 * tensors keep their own names specified by users in the script.
 * However, during compilation, these tensors will be replaced by the
 * actual output tensors. */
Stmt body;
/*! \brief constructor */
HybridOpNode() {}
// override functions
int num_outputs() const final;
Array<IterVar> root_iter_vars() const final;
Type output_dtype(size_t i) const final;
Array<Expr> output_shape(size_t i) const final;
Array<Tensor> InputTensors() const final;
Operation ReplaceInputs(
const Operation& self,
const std::unordered_map<Tensor, Tensor>& rmap) const final;
void PropBoundToInputs(
const Operation& self,
const std::unordered_map<const Variable*, IntSet>& dom_map,
std::unordered_map<Tensor, TensorDom>* out_dom_map) const final;
void GatherBound(
const Operation& self,
const std::unordered_map<Tensor, TensorDom>& tensor_dom,
std::unordered_map<IterVar, Range>* out_dom_map) const final;
Stmt BuildRealize(
const Stage& stage,
const std::unordered_map<IterVar, Range>& realize_map,
const Stmt& body) const final;
Stmt BuildProvide(
const Stage& stage,
const std::unordered_map<IterVar, Range>& dom_map,
bool debug_keep_trivial_loop) const final;
void VisitAttrs(AttrVisitor* v) final {
v->Visit("name", &name);
v->Visit("tag", &tag);
v->Visit("attrs", &attrs);
v->Visit("inputs", &inputs);
v->Visit("outputs", &outputs);
v->Visit("body", &body);
}
EXPORT static Operation make(std::string name,
std::string tag,
Map<std::string, NodeRef> attrs,
Array<Tensor> inputs,
Array<Tensor> outputs,
Stmt body);
static constexpr const char* _type_key = "HybridOp";
TVM_DECLARE_NODE_TYPE_INFO(HybridOpNode, OperationNode);
};
/*! \brief The compute function to specify the input source of a Tensor */ /*! \brief The compute function to specify the input source of a Tensor */
using FCompute = std::function<Expr (const Array<Var>& i)>; using FCompute = std::function<Expr (const Array<Var>& i)>;
......
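From the Python side, the new node can be inspected directly on the tensors a hybrid script returns; a small sketch, reusing a, b, and outer_product from the documentation example above:

c = outer_product(a, b)     # a, b are the tvm placeholders from the previous sketch
print(type(c.op))           # the HybridOp node registered further down in this commit
print(c.op.inputs)          # input tensors captured by the op
print(c.op.body)            # the HalideIR statement parsed from the script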
...@@ -340,11 +340,6 @@ def lower(sch, ...@@ -340,11 +340,6 @@ def lower(sch,
bounds = schedule.InferBound(sch) bounds = schedule.InferBound(sch)
stmt = schedule.ScheduleOps(sch, bounds) stmt = schedule.ScheduleOps(sch, bounds)
stmt = ir_pass.InjectPrefetch(stmt) stmt = ir_pass.InjectPrefetch(stmt)
else:
#So far there is no op for hybrid script, so a plain ir body is given
if not isinstance(sch, _stmt.Stmt):
raise ValueError("sch should be either a Schedule or a Stmt")
stmt = sch
for f in lower_phase0: for f in lower_phase0:
stmt = f(stmt) stmt = f(stmt)
......
...@@ -7,4 +7,5 @@ python semantic emulation. ...@@ -7,4 +7,5 @@ python semantic emulation.
2. Developers can build HalideIR by writing Python code. 2. Developers can build HalideIR by writing Python code.
""" """
from .api import script, parse from .api import script
from .parser import parse_python
"""APIs of lowering the Python subset to HalideIR""" """APIs of lowering the Python subset to HalideIR"""
from __future__ import absolute_import as _abs from __future__ import absolute_import as _abs
import types
from .._ffi.base import decorate from .._ffi.base import decorate
from .. import _api_internal as _tvm_internal
from ..tensor import Tensor
from .parser import parse_python from .parser import parse_python
from .util import _pruned_source
def script(pyfunc): def script(pyfunc):
...@@ -17,40 +20,26 @@ def script(pyfunc): ...@@ -17,40 +20,26 @@ def script(pyfunc):
hybrid_func : function hybrid_func : function
A decorated hybrid script function. A decorated hybrid script function.
""" """
def wrapped_func(func, *args, **kwargs): def wrapped_func(func, *args, **kwargs): #pylint: disable=missing-docstring
from .util import _enter_hybrid_runtime, _restore_runtime, _is_tvm_arg_types from .util import _enter_hybrid_runtime, _restore_runtime, _is_tvm_arg_types
if _is_tvm_arg_types(args): if _is_tvm_arg_types(args):
return parse(func, args) src = _pruned_source(func)
parser = parse_python(src, args)
input_tensors = []
for i in args:
if isinstance(i, Tensor):
input_tensors.append(i)
op = _tvm_internal._HybridOp(parser.func_name, "HybridOp", None, input_tensors,
parser.outputs, parser.parsed_body)
res = [op.output(i) for i in range(len(parser.outputs))]
return res[0] if len(res) == 1 else res
intersect = _enter_hybrid_runtime(func) intersect = _enter_hybrid_runtime(func)
value = func(*args, **kwargs) value = func(*args, **kwargs)
_restore_runtime(func, intersect) _restore_runtime(func, intersect)
return value return value
return decorate(pyfunc, wrapped_func)
def parse(func, args):
"""Parse a subset of Python to HalideIR
Parameters return decorate(pyfunc, wrapped_func)
----------
func : str or types.FunctionType
If it is a string, parse the source code
If it is a function, parse the function
args : list of Buffer or Tensor or Var
The argument lists to the function.
Leave it None if no buffer is related to the function to be parsed
Returns
-------
root : Stmt
The result Halide IR and the parser class instance.
"""
from .util import _pruned_source
if isinstance(func, str):
src = func
else:
assert isinstance(func, types.FunctionType)
src = _pruned_source(func)
return parse_python(src, args)
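Since the decorator returns op.output(i) for every tensor named in the return statement, a script with several outputs yields a list of tensors when invoked on tvm arguments. A hedged sketch (split_copy below is a hypothetical example, modeled on the looptype test later in this commit):

@tvm.hybrid.script
def split_copy(a):
    b = output_tensor((8, ), a.dtype)
    c = output_tensor((8, ), a.dtype)
    for i in range(8):
        b[i] = a[i]
        c[i] = a[i] + a[i]
    return b, c

a = tvm.placeholder((8, ), name='a', dtype='int32')
b, c = split_copy(a)        # both outputs are produced by the same HybridOp
print(b.op.num_outputs)     # 2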
...@@ -48,6 +48,7 @@ def allocate(shape, dtype='float32', scope='global'): #pylint: disable=unused-ar ...@@ -48,6 +48,7 @@ def allocate(shape, dtype='float32', scope='global'): #pylint: disable=unused-ar
""" """
return numpy.zeros(shape).astype(dtype) return numpy.zeros(shape).astype(dtype)
output_tensor = allocate #pylint: disable=invalid-name
def popcount(x): def popcount(x):
""" """
...@@ -87,18 +88,19 @@ def sigmoid(x): ...@@ -87,18 +88,19 @@ def sigmoid(x):
HYBRID_GLOBALS = { HYBRID_GLOBALS = {
'unroll' : unroll, 'unroll' : unroll,
'vectorize' : vectorize, 'vectorize' : vectorize,
'parallel' : parallel, 'parallel' : parallel,
'allocate' : allocate, 'allocate' : allocate,
'bind' : bind, 'output_tensor': output_tensor,
'sqrt' : numpy.sqrt, 'bind' : bind,
'log' : numpy.log, 'sqrt' : numpy.sqrt,
'tanh' : numpy.tanh, 'log' : numpy.log,
'power' : numpy.power, 'tanh' : numpy.tanh,
'exp' : numpy.exp, 'power' : numpy.power,
'sigmoid' : sigmoid, 'exp' : numpy.exp,
'popcount' : popcount 'sigmoid' : sigmoid,
'popcount' : popcount
} }
......
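Because output_tensor is simply an alias of allocate in this emulation table, a decorated function invoked on numpy values still executes as ordinary Python; a minimal sketch, assuming the outer_product definition from the documentation example above:

import numpy

np_a = numpy.random.randn(100)
np_b = numpy.random.randn(99)
np_c = outer_product(np_a, np_b)    # output_tensor allocates a numpy.zeros buffer during emulation
assert numpy.allclose(np_c, numpy.outer(np_a, np_b), atol=1e-5)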
...@@ -2,8 +2,9 @@ ...@@ -2,8 +2,9 @@
import ast import ast
import operator import operator
import logging
import sys import sys
from .util import make_nop, halide_imm_types, is_docstring from .util import make_nop, halide_imm_types, is_docstring, _internal_assert
from .intrin import LOOP_INTRIN, MATH_INTRIN from .intrin import LOOP_INTRIN, MATH_INTRIN
from .var_decl import determine_variable_usage from .var_decl import determine_variable_usage
from ..api import thread_axis from ..api import thread_axis
...@@ -72,15 +73,17 @@ class HybridParser(ast.NodeVisitor): ...@@ -72,15 +73,17 @@ class HybridParser(ast.NodeVisitor):
The name of the function to be lowered; if not provided, The name of the function to be lowered; if not provided,
the compiler will use the name in the AST the compiler will use the name in the AST
""" """
self.args = args[:] self.args = list(args)
self.usage = usage.copy() self.usage = usage.copy()
self._args = {} # Dict maps arg name to actual arg instance (either a var or a buffer) self._args = {} # Dict maps arg name to actual arg instance (either a var or a buffer)
self.var_buffers = {} # Buffers formed by mutatble variables
self.alloc_buffers = {} # Buffers formed by allocate instructions self.alloc_buffers = {} # Buffers formed by allocate instructions
self.loops_above = {} # State variable that indicates loop levels above the current node self.loops_above = {} # State variable that indicates loop levels above the current node
self.var_consts = {} # Variables that are determined as readonly in previous stage self.var_consts = {} # Variables that are determined as readonly in previous stage
self.func_name = func_name # The name of the function to be lowered self.func_name = func_name # The name of the function to be lowered
self.iter_axis = [] self.outputs = [] # Output tensors' name
self.side_effect = set() # Tensors with side effects
self.parsed_body = None # The parsed HalideIR body
self.returned = False
def wrap_up_realize(self, node, body): def wrap_up_realize(self, node, body):
...@@ -90,9 +93,8 @@ class HybridParser(ast.NodeVisitor): ...@@ -90,9 +93,8 @@ class HybridParser(ast.NodeVisitor):
continue continue
_, level, _ = val _, level, _ = val
if level == node: if level == node:
if key in self.var_buffers.keys(): if key in self._args.keys():
_buf = self.var_buffers[key] continue
_scope = 'global'
else: else:
_buf, _scope = self.alloc_buffers[key] _buf, _scope = self.alloc_buffers[key]
_domain = [_make.range_by_min_extent(0, i) for i in _buf.shape] _domain = [_make.range_by_min_extent(0, i) for i in _buf.shape]
...@@ -103,12 +105,13 @@ class HybridParser(ast.NodeVisitor): ...@@ -103,12 +105,13 @@ class HybridParser(ast.NodeVisitor):
return body return body
def _get_buffer_from_id(self, s): def _get_buffer_from_id(self, s, for_provide=False):
if s not in self._args.keys() and s not in self.alloc_buffers.keys(): _internal_assert((s in self._args.keys()) + (s in self.alloc_buffers.keys()) == 1,
raise ValueError("This %s is expected to be in argument list or allocated buffer!" % s) "This %s is expected to be in either \
if s in self._args.keys() and s in self.alloc_buffers.keys(): argument list or allocated buffer!" % s)
raise ValueError("%s, a buffer cannot be both argument and allocated!" % s)
if s in self._args.keys(): if s in self._args.keys():
if for_provide:
self.side_effect.add(self._args[s])
return self._args[s] return self._args[s]
return self.alloc_buffers[s][0] return self.alloc_buffers[s][0]
...@@ -116,15 +119,15 @@ class HybridParser(ast.NodeVisitor): ...@@ -116,15 +119,15 @@ class HybridParser(ast.NodeVisitor):
#pylint: disable=invalid-name, missing-docstring #pylint: disable=invalid-name, missing-docstring
def visit_Module(self, node): def visit_Module(self, node):
if len(node.body) != 1: _internal_assert(len(node.body) == 1, \
raise ValueError("Only one-function source code can be fed to this parser!") "Only one-function source code can be fed to this parser!")
return self.visit(node.body[0]) return self.visit(node.body[0])
def visit_FunctionDef(self, node): def visit_FunctionDef(self, node):
if len(node.args.args) != len(self.args): _internal_assert(len(node.args.args) == len(self.args), \
raise ValueError("The number of arguments passed to the function\ "The number of arguments passed to the \
should be the same as it is defined!") function should be the same as it is defined!")
for idx, arg in enumerate(node.args.args): for idx, arg in enumerate(node.args.args):
_attr = 'id' if sys.version_info[0] < 3 else 'arg' # To make py2 and 3 compatible _attr = 'id' if sys.version_info[0] < 3 else 'arg' # To make py2 and 3 compatible
self._args[getattr(arg, _attr)] = self.args[idx] self._args[getattr(arg, _attr)] = self.args[idx]
...@@ -145,17 +148,17 @@ class HybridParser(ast.NodeVisitor): ...@@ -145,17 +148,17 @@ class HybridParser(ast.NodeVisitor):
return self._args[_id] return self._args[_id]
elif _id in self.loops_above.keys(): elif _id in self.loops_above.keys():
return self.loops_above[_id] return self.loops_above[_id]
if _id in self._args.keys(): _internal_assert(_id not in self._args.keys(), \
raise ValueError("This id %s should be handled in visit_Subscript!" % _id) "This id %s should be handled in visit_Subscript!" % _id)
if _id not in self.usage.keys(): _internal_assert(_id in self.usage.keys(), \
raise ValueError("This id %s is expected to be a defined variable!" % _id) "This id %s is expected to be a defined variable!" % _id)
# Buffer # Buffer
if _id in self.var_buffers.keys(): if _id in self.alloc_buffers.keys():
_buf = self.var_buffers[_id] _buf, _ = self.alloc_buffers[_id]
return _make.Call(_buf.dtype, _id, [_api.const(0)], _expr.Call.Halide, _buf.op, 0) return _make.Call(_buf.dtype, _id, [_api.const(0)], _expr.Call.Halide, _buf.op, 0)
# Compilation time constant # Compilation time constant
if _id not in self.var_consts.keys(): _internal_assert(_id in self.var_consts.keys(),
raise ValueError("This id %s is expected to a compilation time constant!" % _id) "This id %s is expected to a compilation time constant!" % _id)
return self.var_consts[_id] return self.var_consts[_id]
...@@ -164,8 +167,7 @@ class HybridParser(ast.NodeVisitor): ...@@ -164,8 +167,7 @@ class HybridParser(ast.NodeVisitor):
def visit_Assign(self, node): def visit_Assign(self, node):
if len(node.targets) != 1: _internal_assert(len(node.targets) == 1, "So far only one-valued assignment is supported!")
raise ValueError("So far only one-valued assignment is supported!")
lhs = node.targets[0] lhs = node.targets[0]
rhs = self.visit(node.value) rhs = self.visit(node.value)
if isinstance(rhs, _expr.Expr): if isinstance(rhs, _expr.Expr):
...@@ -174,36 +176,40 @@ class HybridParser(ast.NodeVisitor): ...@@ -174,36 +176,40 @@ class HybridParser(ast.NodeVisitor):
#TODO: support defined intermediate buffer later #TODO: support defined intermediate buffer later
lhs_ = lhs lhs_ = lhs
lhs = lhs.id lhs = lhs.id
if lhs in self.loops_above.keys(): _internal_assert(lhs not in self.loops_above.keys(), \
raise ValueError("You CAN NEVER overwrite a loop variable!") "Loop variable cannot be overwritten!")
decl, _, rw = self.usage[lhs] decl, _, rw = self.usage[lhs]
if decl == lhs_: if decl == lhs_:
if lhs in self.var_consts.keys(): _internal_assert(lhs not in self.var_consts.keys(), \
raise ValueError("BUG: A constant cannot be overwritten!") "A constant cannot be overwritten!")
if lhs in self.var_buffers.keys() or lhs in self.alloc_buffers.keys(): _internal_assert(lhs not in self.alloc_buffers.keys(), \
raise ValueError("BUG: This value should not be defined before this point!") "This value should not be defined before this point!")
if isinstance(rhs, tuple): if isinstance(rhs, tuple):
shape, dtype, scope = rhs shape, dtype, scope = rhs
ph = _api.placeholder(shape, dtype=dtype, name=lhs) ph = _api.placeholder(shape, dtype=dtype, name=lhs)
self.alloc_buffers[lhs] = (ph, scope) if scope != 'output':
self.alloc_buffers[lhs] = (ph, scope)
else:
self._args[lhs] = ph
self.outputs.append(lhs)
return make_nop() return make_nop()
if isinstance(rhs, halide_imm_types) and ast.Store not in rw: if isinstance(rhs, halide_imm_types) and ast.Store not in rw:
self.var_consts[lhs] = rhs self.var_consts[lhs] = rhs
else: else:
self.var_buffers[lhs] = _api.placeholder((1, ), dtype=rhs.dtype, name=lhs) ph = _api.placeholder((1, ), dtype=rhs.dtype, name=lhs)
self.alloc_buffers[lhs] = (ph, 'global')
if lhs in self.var_consts.keys(): if lhs in self.var_consts.keys():
return make_nop() return make_nop()
else: _internal_assert(lhs in self.alloc_buffers.keys(), \
if lhs not in self.var_buffers.keys(): "This variable should be defined before!")
raise ValueError("BUG: This variable should be defined before!") tgt, _ = self.alloc_buffers[lhs]
tgt = self.var_buffers[lhs] return _make.Provide(tgt.op, 0, rhs, [_api.const(0, dtype=rhs.dtype)])
return _make.Provide(tgt.op, 0, rhs, [_api.const(0, dtype=rhs.dtype)])
else: else:
lhs = self.visit(lhs) lhs = self.visit(lhs)
if not isinstance(lhs, _expr.Call): _internal_assert(isinstance(lhs, _expr.Call), \
raise ValueError("An array access's LHS is expected to be a expr.Call!") "An array access's LHS is expected to be a expr.Call!")
#TODO: support slice later #TODO: support slice later
buf = self._get_buffer_from_id(lhs.name) buf = self._get_buffer_from_id(lhs.name, for_provide=True)
return _make.Provide(buf.op, 0, rhs, lhs.args) return _make.Provide(buf.op, 0, rhs, lhs.args)
...@@ -219,21 +225,20 @@ class HybridParser(ast.NodeVisitor): ...@@ -219,21 +225,20 @@ class HybridParser(ast.NodeVisitor):
array = node.value.id array = node.value.id
_buf = self._get_buffer_from_id(array) _buf = self._get_buffer_from_id(array)
return _make.Call(_buf.dtype, array, args, _expr.Call.Halide, _buf.op, 0) return _make.Call(_buf.dtype, array, args, _expr.Call.Halide, _buf.op, 0)
elif isinstance(node.value, ast.Attribute):
if not isinstance(node.value.value, ast.Name): _internal_assert(isinstance(node.value, ast.Attribute), \
raise ValueError("The root of array access is expect to be a id!") "Only variable and attribute's subscript supported so far")
if node.value.attr != "shape": _internal_assert(isinstance(node.value.value, ast.Name), \
raise ValueError("Attribute access so far only 'shape' is supported!") "The root of array access is expect to be a id!")
if len(args) != 1: _internal_assert(node.value.attr == "shape", \
raise ValueError("For 'shape' access the argument should be only one!") "Attribute access so far only 'shape' is supported!")
args = args[0] _internal_assert(len(args) == 1, "For 'shape' access the argument should be only one!")
#TODO: maybe support non-constant value later? args = args[0]
if not isinstance(args, (_expr.IntImm, _expr.UIntImm)): #TODO: maybe support non-constant value later?
raise ValueError("So far only constant shape access supported!") _internal_assert(isinstance(args, (_expr.IntImm, _expr.UIntImm)), \
buf = self._get_buffer_from_id(node.value.value.id) "So far only constant shape access supported!")
return buf.shape[args.value] buf = self._get_buffer_from_id(node.value.value.id)
else: return buf.shape[args.value]
raise ValueError("Not supported yet!")
def visit_With(self, node): def visit_With(self, node):
...@@ -241,14 +246,11 @@ class HybridParser(ast.NodeVisitor): ...@@ -241,14 +246,11 @@ class HybridParser(ast.NodeVisitor):
context = node.context_expr context = node.context_expr
option = node.optional_vars option = node.optional_vars
else: else:
if len(node.items) != 1: _internal_assert(len(node.items) == 1, "Only one with element is supported so far!")
raise ValueError("Only one with element is supported so far!")
context = node.items[0].context_expr context = node.items[0].context_expr
option = node.items[0].optional_vars option = node.items[0].optional_vars
if not isinstance(context, ast.Call): _internal_assert(isinstance(context, ast.Call), "The object must be a Python func call!")
raise ValueError("The object must be a Python function call!") _internal_assert(isinstance(option, ast.Name), "The object after 'as' must be an id!")
if not isinstance(option, ast.Name):
raise ValueError("The object after 'as' must be an id!")
self.annotation[option.id] = context.func.id self.annotation[option.id] = context.func.id
return list_to_block(self.visit, node.body) return list_to_block(self.visit, node.body)
...@@ -272,10 +274,8 @@ class HybridParser(ast.NodeVisitor): ...@@ -272,10 +274,8 @@ class HybridParser(ast.NodeVisitor):
def visit_Compare(self, node): def visit_Compare(self, node):
lhs = self.visit(node.left) lhs = self.visit(node.left)
if len(node.ops) != 1: _internal_assert(len(node.ops) == 1, "Only one compare op is supported!")
raise ValueError("Only one compare op is supported!") _internal_assert(len(node.comparators) == 1, "Only one comparator is supported!")
if len(node.comparators) != 1:
raise ValueError("Only one comparator is supported!")
rhs = self.visit(node.comparators[0]) rhs = self.visit(node.comparators[0])
return HybridParser._binop_maker[type(node.ops[0])](lhs, rhs) return HybridParser._binop_maker[type(node.ops[0])](lhs, rhs)
...@@ -293,16 +293,15 @@ class HybridParser(ast.NodeVisitor): ...@@ -293,16 +293,15 @@ class HybridParser(ast.NodeVisitor):
def visit_Call(self, node): def visit_Call(self, node):
# Yet, no function pointer supported # Yet, no function pointer supported
if not isinstance(node.func, ast.Name): _internal_assert(isinstance(node.func, ast.Name), \
raise ValueError("Only id-function function call is supported so far!") "Only id-function function call is supported so far!")
func_id = node.func.id func_id = node.func.id
n = len(node.args) n = len(node.args)
if func_id in LOOP_INTRIN.keys() and func_id != 'bind': if func_id in LOOP_INTRIN.keys() and func_id != 'bind':
if n == 1: if n == 1:
low, ext = _api.const(0, dtype='int32'), self.visit(node.args[0]) low, ext = _api.const(0, dtype='int32'), self.visit(node.args[0])
else: else:
if n != 2: _internal_assert(n == 2, "A loop intrinsic should only have 1 or 2 arguments!")
raise ValueError("A loop intrinsic should only have 1 or 2 arguments!")
low, ext = self.visit(node.args[0]), self.visit(node.args[1]) low, ext = self.visit(node.args[0]), self.visit(node.args[1])
if not _ir_pass.Equal(low, _api.const(0, dtype='int32')): if not _ir_pass.Equal(low, _api.const(0, dtype='int32')):
ext = ext - low ext = ext - low
...@@ -310,10 +309,9 @@ class HybridParser(ast.NodeVisitor): ...@@ -310,10 +309,9 @@ class HybridParser(ast.NodeVisitor):
iter_var = None iter_var = None
return iter_var, low, ext, for_type return iter_var, low, ext, for_type
elif func_id == 'bind': elif func_id == 'bind':
if n != 2: _internal_assert(n == 2, "A loop bind should only have 2 arguments!")
raise ValueError("A loop bind should only have 2 arguments!") _internal_assert(isinstance(node.args[0], ast.Str), \
if not isinstance(node.args[0], ast.Str): "A loop bind's first argument should be a string!")
raise ValueError("A loop bind's first argument should be a string!")
_vn = node.args[0].s _vn = node.args[0].s
iter_var = thread_axis(node.args[0].s) iter_var = thread_axis(node.args[0].s)
low, ext = _api.const(0, dtype='int32'), self.visit(node.args[1]) low, ext = _api.const(0, dtype='int32'), self.visit(node.args[1])
...@@ -321,29 +319,39 @@ class HybridParser(ast.NodeVisitor): ...@@ -321,29 +319,39 @@ class HybridParser(ast.NodeVisitor):
return iter_var, low, ext, for_type return iter_var, low, ext, for_type
elif func_id in MATH_INTRIN: elif func_id in MATH_INTRIN:
return getattr(intrin, func_id)(*[self.visit(arg) for arg in node.args]) return getattr(intrin, func_id)(*[self.visit(arg) for arg in node.args])
elif func_id == 'allocate': elif func_id in ['allocate', 'output_tensor']:
if not isinstance(node.args[0], ast.Tuple): _internal_assert(isinstance(node.args[0], ast.Tuple), \
raise ValueError("allocate's first argument should be a tuple of shape!") "allocate's first argument should be a tuple of shape!")
shape = tuple(self.visit(i) for i in node.args[0].elts) shape = tuple(self.visit(i) for i in node.args[0].elts)
if func_id == 'output_tensor':
_internal_assert(not self.loops_above, \
"Are you sure to allocate a output buffer multiple times?")
for i in shape: for i in shape:
if not isinstance(i, _expr.Expr): _internal_assert(isinstance(i, _expr.Expr), "The shape should be an expression")
raise ValueError("The shape should be an expression")
if n > 1: if n > 1:
if not isinstance(node.args[1], ast.Str): if isinstance(node.args[1], ast.Str):
raise ValueError("The data type should be an string") dtype = node.args[1].s
dtype = node.args[1].s else:
_internal_assert(isinstance(node.args[1], ast.Attribute), \
"Unable to evaluate to get data type")
to_eval = node.args[1]
_internal_assert(isinstance(to_eval.value, ast.Name), \
"Unable to evaluate the attribute to get data type")
_internal_assert(to_eval.attr == 'dtype', \
"Only dtype attribute is supported so far")
dtype = self._get_buffer_from_id(to_eval.value.id).dtype
else: else:
dtype = 'float32' dtype = 'float32'
if n > 2: if n > 2:
if not isinstance(node.args[2], ast.Str): _internal_assert(isinstance(node.args[2], ast.Str), \
raise ValueError("The data type should be an string") "The data scope should be an string")
_internal_assert(func_id != 'output_tensor', "Output tensor cannot specify scope")
scope = node.args[2].s scope = node.args[2].s
else: else:
scope = 'global' scope = 'global' if func_id != 'output_tensor' else 'output'
return (shape, dtype, scope) return (shape, dtype, scope)
elif func_id == 'max' or func_id == 'min': elif func_id == 'max' or func_id == 'min':
if n != 2: _internal_assert(n == 2, "Max/Min function should have 2 elements")
raise ValueError("Max/Min function should have 2 elements")
a, b = self.visit(node.args[0]), self.visit(node.args[1]) a, b = self.visit(node.args[0]), self.visit(node.args[1])
return getattr(_make, func_id.title())(a, b) return getattr(_make, func_id.title())(a, b)
else: else:
...@@ -352,19 +360,17 @@ class HybridParser(ast.NodeVisitor): ...@@ -352,19 +360,17 @@ class HybridParser(ast.NodeVisitor):
def visit_For(self, node): def visit_For(self, node):
iter_var, low, ext, for_type = self.visit(node.iter) iter_var, low, ext, for_type = self.visit(node.iter)
if not isinstance(node.target, ast.Name): _internal_assert(isinstance(node.target, ast.Name), \
raise ValueError("The loop iterator should be a variable!") "The loop iterator should be a variable!")
_name = node.target.id _name = node.target.id
if iter_var is None: if iter_var is None:
if for_type is None: _internal_assert(for_type is not None, "The loop bind function parse error!")
raise ValueError("The loop bind function parse error!")
offset = iter_var = _api.var(_name) offset = iter_var = _api.var(_name)
if not _ir_pass.Equal(low, _api.const(0, dtype='int32')): if not _ir_pass.Equal(low, _api.const(0, dtype='int32')):
offset = iter_var + low offset = iter_var + low
self.loops_above[_name] = offset self.loops_above[_name] = offset
else: else:
if for_type is not None: _internal_assert(for_type is None, "The loop iterating function parse error!")
raise ValueError("The loop iterating function parse error!")
self.loops_above[_name] = iter_var.var self.loops_above[_name] = iter_var.var
_body = list_to_block(self.visit, node.body) _body = list_to_block(self.visit, node.body)
_body = self.wrap_up_realize(node, _body) _body = self.wrap_up_realize(node, _body)
...@@ -376,10 +382,46 @@ class HybridParser(ast.NodeVisitor): ...@@ -376,10 +382,46 @@ class HybridParser(ast.NodeVisitor):
return res return res
def visit_Return(self, node):
_internal_assert(not self.loops_above, "Return should not be in a loop body!")
ids = []
if isinstance(node.value, ast.Name):
ids.append(node.value.id)
else:
_internal_assert(isinstance(node.value, ast.Tuple), \
"You should return either a single tensor or a tuple")
for i in node.value.elts:
_internal_assert(isinstance(i, ast.Name), "Only tensor identifiers can be returned!")
ids.append(i.id)
_internal_assert(len(set(ids)) == len(ids), "Duplicated tensors in the return tuples")
if len(ids) != len(self.outputs):
logging.log(logging.CRITICAL, '[Warning] Not all the output buffers returned!')
self.outputs = [self._args[i] for i in ids]
self.returned = True
return make_nop()
def parse_python(src, args): def parse_python(src, args):
"""The helper function of calling the AST visitor""" """The helper function of calling the AST visitor
Parameters
----------
src : str
The source code of the function to be parsed.
args : list of Tensors or Vars
The argument lists to the function.
It is NOT encouraged to write a function without arguments.
It is NOT encouraged to write a function with side effects.
Returns
-------
parser : HybridParser
The parser instance, carrying the parsed HalideIR body (parser.parsed_body) and the output tensors (parser.outputs).
"""
root = ast.parse(src) root = ast.parse(src)
var_usage = determine_variable_usage(root, args) var_usage = determine_variable_usage(root, args)
parser = HybridParser(args, var_usage) parser = HybridParser(args, var_usage)
halide_ir = parser.visit(root) parser.parsed_body = parser.visit(root)
return halide_ir _internal_assert(parser.returned, 'No valid return found in the function body!')
return parser
...@@ -2,6 +2,8 @@ ...@@ -2,6 +2,8 @@
import ast import ast
import inspect import inspect
import logging
import sys
import numpy import numpy
from .intrin import HYBRID_GLOBALS from .intrin import HYBRID_GLOBALS
from .._ffi.base import numeric_types from .._ffi.base import numeric_types
...@@ -30,10 +32,17 @@ def is_docstring(node): ...@@ -30,10 +32,17 @@ def is_docstring(node):
def _pruned_source(func): def _pruned_source(func):
"""Prune source code's extra leading spaces""" """Prune source code's extra leading spaces"""
lines = inspect.getsource(func).split('\n') try:
leading_space = len(lines[0]) - len(lines[0].lstrip(' ')) lines = inspect.getsource(func).split('\n')
lines = [line[leading_space:] for line in lines] leading_space = len(lines[0]) - len(lines[0].lstrip(' '))
return '\n'.join(lines) lines = [line[leading_space:] for line in lines]
return '\n'.join(lines)
except IOError as err:
if sys.version_info[0] == 2 and str(err) == 'could not get source code':
logging.log(logging.CRITICAL, \
'This module does not fully work under Python 2... ' \
'Please move to Python3!')
raise err
def _is_tvm_arg_types(args): def _is_tvm_arg_types(args):
...@@ -70,3 +79,12 @@ def _restore_runtime(func, intersect): ...@@ -70,3 +79,12 @@ def _restore_runtime(func, intersect):
_globals.pop(elem) _globals.pop(elem)
for k, v in intersect: for k, v in intersect:
_globals[k] = v _globals[k] = v
def _internal_assert(cond, err):
"""Simplify the code segment like if not XXX then raise an error"""
if not cond:
raise ValueError(err)
# Almost the same functionality as the one above, but in this case,
# the error is caused by the user's improper usage.
_user_assert = _internal_assert
...@@ -3,6 +3,7 @@ ...@@ -3,6 +3,7 @@
import ast import ast
import sys import sys
from .intrin import HYBRID_GLOBALS from .intrin import HYBRID_GLOBALS
from .util import _internal_assert
class PyVariableUsage(ast.NodeVisitor): class PyVariableUsage(ast.NodeVisitor):
...@@ -18,8 +19,8 @@ class PyVariableUsage(ast.NodeVisitor): ...@@ -18,8 +19,8 @@ class PyVariableUsage(ast.NodeVisitor):
def visit_FunctionDef(self, node): def visit_FunctionDef(self, node):
self.scope_level.append(node) self.scope_level.append(node)
if len(node.args.args) != len(self.args): _internal_assert(len(node.args.args) == len(self.args), \
raise ValueError('#arguments passed should be the same as #arguments defined') '#arguments passed should be the same as #arguments defined')
for idx, arg in enumerate(node.args.args): for idx, arg in enumerate(node.args.args):
_attr = 'id' if sys.version_info[0] < 3 else 'arg' # To make py2 and 3 compatible _attr = 'id' if sys.version_info[0] < 3 else 'arg' # To make py2 and 3 compatible
self._args[getattr(arg, _attr)] = self.args[idx] self._args[getattr(arg, _attr)] = self.args[idx]
...@@ -28,8 +29,8 @@ class PyVariableUsage(ast.NodeVisitor): ...@@ -28,8 +29,8 @@ class PyVariableUsage(ast.NodeVisitor):
def visit_For(self, node): def visit_For(self, node):
if not isinstance(node.target, ast.Name): _internal_assert(isinstance(node.target, ast.Name), \
raise ValueError("For's iterator should be an id") "For's iterator should be an id")
self.visit(node.iter) self.visit(node.iter)
self.scope_level.append(node) self.scope_level.append(node)
for i in node.body: for i in node.body:
...@@ -39,11 +40,10 @@ class PyVariableUsage(ast.NodeVisitor): ...@@ -39,11 +40,10 @@ class PyVariableUsage(ast.NodeVisitor):
def visit_Call(self, node): def visit_Call(self, node):
#No function pointer supported so far #No function pointer supported so far
if not isinstance(node.func, ast.Name): _internal_assert(isinstance(node.func, ast.Name), "Function call should be an id")
raise ValueError("Function call should be an id")
func_id = node.func.id func_id = node.func.id
if func_id not in list(HYBRID_GLOBALS.keys()) + ['range', 'max', 'min']: _internal_assert(func_id in list(HYBRID_GLOBALS.keys()) + ['range', 'max', 'min'], \
raise ValueError("Function call id not in intrinsics' list") "Function call id not in intrinsics' list")
for elem in node.args: for elem in node.args:
self.visit(elem) self.visit(elem)
...@@ -56,12 +56,12 @@ class PyVariableUsage(ast.NodeVisitor): ...@@ -56,12 +56,12 @@ class PyVariableUsage(ast.NodeVisitor):
if node.id in fors: if node.id in fors:
return return
# The loop variable cannot be overwritten when iteration # The loop variable cannot be overwritten when iteration
if isinstance(node.ctx, ast.Store) and node.id in fors: _internal_assert(not isinstance(node.ctx, ast.Store) or node.id not in fors, \
raise ValueError("Iter var cannot be overwritten") "Iter var cannot be overwritten")
if node.id not in self.status.keys(): if node.id not in self.status.keys():
if not isinstance(node.ctx, ast.Store): _internal_assert(isinstance(node.ctx, ast.Store), \
raise ValueError('In Python, "first store" indicates "declaration"') 'Undeclared variable %s' % node.id)
self.status[node.id] = (node, self.scope_level[-1], set()) self.status[node.id] = (node, self.scope_level[-1], set())
else: else:
decl, loop, usage = self.status[node.id] decl, loop, usage = self.status[node.id]
......
...@@ -180,3 +180,8 @@ class ScanOp(Operation): ...@@ -180,3 +180,8 @@ class ScanOp(Operation):
class ExternOp(Operation): class ExternOp(Operation):
"""Extern operation.""" """Extern operation."""
pass pass
@register_node
class HybridOp(Operation):
"""Hybrid operation."""
pass
...@@ -313,6 +313,16 @@ TVM_REGISTER_API("_ExternOp") ...@@ -313,6 +313,16 @@ TVM_REGISTER_API("_ExternOp")
args[6]); args[6]);
}); });
TVM_REGISTER_API("_HybridOp")
.set_body([](TVMArgs args, TVMRetValue* ret) {
*ret = HybridOpNode::make(args[0],
args[1],
args[2],
args[3],
args[4],
args[5]);
});
TVM_REGISTER_API("_OpGetOutput") TVM_REGISTER_API("_OpGetOutput")
.set_body([](TVMArgs args, TVMRetValue* ret) { .set_body([](TVMArgs args, TVMRetValue* ret) {
*ret = args[0].operator Operation().output( *ret = args[0].operator Operation().output(
......
/*!
* Copyright (c) 2018 by Contributors
* \brief Hybrid computation rule.
* \file hybrid_op.cc
*/
#include <tvm/operation.h>
#include <tvm/arithmetic.h>
#include <tvm/ir.h>
#include <tvm/ir_mutator.h>
#include <unordered_set>
#include "op_util.h"
namespace tvm {
using namespace ir;
// HybridOpNode
TVM_STATIC_IR_FUNCTOR(IRPrinter, vtable)
.set_dispatch<HybridOpNode>([](const HybridOpNode *op, IRPrinter *p) {
p->stream << "hybrid(" << op->name << ", " << op << ")";
});
TVM_REGISTER_NODE_TYPE(HybridOpNode);
int HybridOpNode::num_outputs() const {
return static_cast<int>(outputs.size());
}
Array<IterVar> HybridOpNode::root_iter_vars() const {
return {};
}
Type HybridOpNode::output_dtype(size_t i) const {
return outputs[i]->dtype;
}
Array<Expr> HybridOpNode::output_shape(size_t i) const {
return outputs[i]->shape;
}
Operation HybridOpNode::make(std::string name,
std::string tag,
Map<std::string, NodeRef> attrs,
Array<Tensor> inputs,
Array<Tensor> outputs,
Stmt body) {
if (!attrs.defined()) {
attrs = Map<std::string, NodeRef>();
}
auto n = make_node<HybridOpNode>();
n->name = std::move(name);
n->tag = std::move(tag);
n->attrs = std::move(attrs);
n->inputs = std::move(inputs);
n->outputs = std::move(outputs);
n->body = std::move(body);
Operation res = Operation(n);
return res;
}
Array<Tensor> HybridOpNode::InputTensors() const {
return inputs;
}
Operation HybridOpNode::ReplaceInputs(
const Operation& self,
const std::unordered_map<Tensor, Tensor>& rmap) const {
CHECK_EQ(self.operator->(), this);
auto n = make_node<HybridOpNode>(*this);
n->body = op::ReplaceTensor(this->body, rmap);
for (size_t i = 0; i < n->inputs.size(); ++i) {
Tensor t = n->inputs[i];
if (rmap.count(t)) {
n->inputs.Set(i, rmap.at(t));
}
}
if (body.same_as(n->body) &&
inputs.same_as(n->inputs)) {
return self;
} else {
return Operation(n);
}
}
void HybridOpNode::PropBoundToInputs(
const Operation& self,
const std::unordered_map<const Variable*, IntSet>& dom_map,
std::unordered_map<Tensor, TensorDom>* out_dom_map) const {
for (Tensor t : this->inputs) {
auto it = out_dom_map->find(t);
if (it == out_dom_map->end()) continue;
TensorDom& dom = it->second;
for (size_t i = 0; i < t->shape.size(); ++i) {
dom.data[i].emplace_back(IntSet::range(
Range::make_by_min_extent(
make_const(t->shape[i].type(), 0), t->shape[i])));
}
}
}
void HybridOpNode::GatherBound(
const Operation& self,
const std::unordered_map<Tensor, TensorDom>& tensor_dom,
std::unordered_map<IterVar, Range>* out_dom_map) const {
}
Stmt HybridOpNode::BuildRealize(
const Stage& stage,
const std::unordered_map<IterVar, Range>& realize_map,
const Stmt& body) const {
CHECK_EQ(stage->op.get(), this);
Stmt realize_body = body;
for (int k = 0; k < num_outputs(); ++k) {
Tensor t = stage->op.output(k);
HalideIR::Internal::Region bounds;
for (size_t i = 0; i < t->shape.size(); ++i) {
bounds.push_back(
Range::make_by_min_extent(
make_const(t->shape[i].type(), 0), t->shape[i]));
}
realize_body = ir::Realize::make(
t->op, t->value_index, t->dtype,
bounds, const_true(), realize_body);
}
return realize_body;
}
Stmt HybridOpNode::BuildProvide(
const Stage& stage,
const std::unordered_map<IterVar, Range>& dom_map,
bool debug_keep_trivial_loop) const {
CHECK_EQ(stage->op.operator->(), this);
Stmt ret = AttrStmt::make(make_zero(Int(32)), attr::extern_scope, 0, this->body);
auto f_push_bind = [&ret](Buffer buffer, Tensor tensor) {
Array<NodeRef> bind_spec;
Array<Expr> tuple;
bind_spec.push_back(buffer);
bind_spec.push_back(tensor);
for (size_t k = 0; k < buffer->shape.size(); ++k) {
tuple.push_back(make_const(buffer->shape[k].type(), 0));
tuple.push_back(buffer->shape[k]);
}
ret = AttrStmt::make(
bind_spec, attr::buffer_bind_scope,
Call::make(Handle(), intrinsic::tvm_tuple, tuple, Call::Intrinsic), ret);
};
for (int i = static_cast<int>(outputs.size()) - 1; i >= 0; --i) {
Buffer buffer = decl_buffer(
outputs[i]->shape,
outputs[i]->dtype);
f_push_bind(buffer, stage->op.output(i));
}
for (int i = static_cast<int>(inputs.size()) - 1; i >= 0; --i) {
Buffer buffer = decl_buffer(
inputs[i]->shape,
inputs[i]->dtype);
f_push_bind(buffer, inputs[i]);
}
std::unordered_map<Tensor, Tensor> rmap;
for (int i = 0; i < this->num_outputs(); ++i) {
rmap[outputs[i]] = stage->op.output(i);
}
auto n = make_node<HybridOpNode>(*this);
/*
 * These two lines of code replace tensors' reads & writes.
* This is the simplest way I (@were) can come up with to glue
* hybrid scripts to the structure of TVM op.
* NAMING CONFLICT: In hybrid script all the tensors have their own
 * names specified by the users. However, in TVM op, all the output
* tensors' names are the same as the op's name. I cannot change the
* name to the op's name in the function body after the op node is
* formed, because:
* 1. Output tensors all point to the corresponding op node.
* 2. Once OpNode is wrapped up by an Operation node, it can
* no longer be changed.
 * This is a chicken-and-egg paradox. It is impossible to put the output
* tensors into the function body without forming the op node. The
* function body is immutable after the node is formed.
*
* Finally, I decided to resolve this issue "lazily". During the
* pipeline of compilation, these tensors will be replaced when
* forming the function body and passing to next stage of compilation.
* */
ret = op::ReplaceTensor(ret, rmap);
ret = op::ReplaceProvideTensor(ret, rmap);
return ret;
}
} // namespace tvm
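The laziness described in the comment above can be observed from Python: the op's stored body still provides into the user-named tensor, and the outputs are only substituted when the schedule is lowered. A rough illustrative sketch, again assuming a, b, and outer_product from the documentation example:

c = outer_product(a, b)
print(c.op.body)                                      # Provide nodes still target the user-named tensor 'c'
sch = tvm.create_schedule(c.op)
print(tvm.lower(sch, [a, b, c], simple_mode=True))    # after lowering, stores target the op's output buffer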
...@@ -164,6 +164,37 @@ std::vector<Stmt> MakeIfNest(const std::vector<Expr>& predicates) { ...@@ -164,6 +164,37 @@ std::vector<Stmt> MakeIfNest(const std::vector<Expr>& predicates) {
return nest; return nest;
} }
// replacer to replace tensors' usage in Provide
class ProviderReplacer : public ir::IRMutator {
public:
explicit ProviderReplacer(const std::unordered_map<Tensor, Tensor>& vmap)
: vmap_(vmap) {}
Stmt Mutate_(const ir::Provide* op, const Stmt& s) {
Tensor t = Operation(op->func.node_).output(op->value_index);
auto it = vmap_.find(t);
if (it != vmap_.end()) {
Stmt ret = ir::Provide::make(
it->second->op, it->second->value_index, op->value, op->args);
found = true;
return IRMutator::Mutate_(ret.as<ir::Provide>(), ret);
}
return IRMutator::Mutate_(op, s);
}
// whether it is found.
bool found{false};
private:
const std::unordered_map<Tensor, Tensor>& vmap_;
};
Stmt ReplaceProvideTensor(Stmt stmt,
const std::unordered_map<Tensor, Tensor>& replace) {
ProviderReplacer repl(replace);
Stmt ret = repl.Mutate(stmt);
return repl.found ? ret : stmt;
}
// replacer to replace tensors // replacer to replace tensors
class TensorReplacer : public ir::IRMutator { class TensorReplacer : public ir::IRMutator {
......
...@@ -49,14 +49,22 @@ MakeLoopNest(const Stage& stage, ...@@ -49,14 +49,22 @@ MakeLoopNest(const Stage& stage,
std::vector<Stmt> MakeIfNest(const std::vector<Expr>& predicates); std::vector<Stmt> MakeIfNest(const std::vector<Expr>& predicates);
/*! /*!
* \brief Replace the tensor reference in stmt by the replace map. * \brief Replace the tensor reference (especially in Provide's) in stmt by the replace map.
* \param stmt The statement to be processed.
* \param replace The replacement rule.
*/
Stmt ReplaceProvideTensor(Stmt stmt,
const std::unordered_map<Tensor, Tensor>& replace);
/*!
* \brief Replace the tensor reference (especially in Call's) in stmt by the replace map.
* \param stmt The statement to be processed. * \param stmt The statement to be processed.
* \param replace The replacement rule. * \param replace The replacement rule.
*/ */
Stmt ReplaceTensor(Stmt stmt, Stmt ReplaceTensor(Stmt stmt,
const std::unordered_map<Tensor, Tensor>& replace); const std::unordered_map<Tensor, Tensor>& replace);
/*! /*!
* \brief Replace the tensor reference in expr by the replace map. * \brief Replace the tensor reference (especially in Call's) in expr by the replace map.
* \param expr The expression to be processed. * \param expr The expression to be processed.
* \param replace The replacement rule. * \param replace The replacement rule.
*/ */
......
...@@ -3,7 +3,7 @@ from tvm.hybrid import script ...@@ -3,7 +3,7 @@ from tvm.hybrid import script
from tvm.hybrid.intrin import HYBRID_GLOBALS from tvm.hybrid.intrin import HYBRID_GLOBALS
@nose.tools.nottest @nose.tools.nottest
def run_and_check(func, args, outs, var_dict={}, target='llvm'): def run_and_check(func, args, var_dict={}, target='llvm'):
def tvm_val_2_py_val(val): def tvm_val_2_py_val(val):
val = tvm.ir_pass.Substitute(val, var_dict) val = tvm.ir_pass.Substitute(val, var_dict)
val = tvm.ir_pass.Simplify(val) val = tvm.ir_pass.Simplify(val)
...@@ -14,39 +14,50 @@ def run_and_check(func, args, outs, var_dict={}, target='llvm'): ...@@ -14,39 +14,50 @@ def run_and_check(func, args, outs, var_dict={}, target='llvm'):
emu_args = [] emu_args = []
nd_args = [] nd_args = []
to_check = []
for i in args: for i in args:
if isinstance(i, tvm.tensor.Tensor): if isinstance(i, tvm.tensor.Tensor):
shape = [tvm_val_2_py_val(j) for j in i.shape] shape = [tvm_val_2_py_val(j) for j in i.shape]
if i in outs: emu_args.append(numpy.random.randn(*shape).astype(i.dtype))
emu_args.append(numpy.zeros(shape).astype(i.dtype)) nd_args.append(tvm.nd.array(emu_args[-1], ctx))
nd_args.append(tvm.nd.array(emu_args[-1], ctx))
to_check.append((nd_args[-1], emu_args[-1]))
else:
emu_args.append(numpy.random.randn(*shape).astype(i.dtype))
nd_args.append(tvm.nd.array(emu_args[-1], ctx))
else: else:
assert isinstance(i, tvm.expr.Var) assert isinstance(i, tvm.expr.Var)
emu_args.append(tvm_val_2_py_val(i)) emu_args.append(tvm_val_2_py_val(i))
nd_args.append(emu_args[-1]) nd_args.append(emu_args[-1])
func(*emu_args) outs = func(*args)
op = outs[0].op if isinstance(outs, list) else outs.op
lowerd_func = tvm.lower(func(*args), args) sch = tvm.create_schedule(op)
module = tvm.build(lowerd_func, target=target) module = tvm.build(sch, args + (outs if isinstance(outs, list) else [outs]), target=target)
assert module assert module
out_tensors = []
for i in range(op.num_outputs):
output = op.output(i)
shape = [tvm_val_2_py_val(j) for j in output.shape]
nd_args.append(tvm.nd.array(numpy.zeros(shape).astype(output.dtype), ctx))
out_tensors.append(nd_args[-1])
ref_data = func(*emu_args)
if isinstance(ref_data, numpy.ndarray):
ref_data = [ref_data]
module(*nd_args) module(*nd_args)
for nd, np in to_check: for nd, np in zip(out_tensors, ref_data):
tvm.testing.assert_allclose(nd.asnumpy(), np, rtol=1e-5, atol=1e-5) tvm.testing.assert_allclose(nd.asnumpy(), np, rtol=1e-5, atol=1e-5)
@script @script
def outer_product(n, m, a, b, c): def outer_product(n, m, a, b):
"""This is a simple outer product""" """This is a simple outer product.
Actually this function is not required to be documented.
I write this docstring to test skipping docstring functionality.
"""
c = output_tensor((n, m), a.dtype)
for i in range(n): for i in range(n):
for j in range(m): for j in range(m):
c[i, j] = a[i] * b[j] c[i, j] = a[i] * b[j]
return c
#Test global function #Test global function
#Test bridge between frontend and backend #Test bridge between frontend and backend
...@@ -55,8 +66,14 @@ def test_outer_product(): ...@@ -55,8 +66,14 @@ def test_outer_product():
m = tvm.var('m') m = tvm.var('m')
a = tvm.placeholder((n, ), name='a') a = tvm.placeholder((n, ), name='a')
b = tvm.placeholder((m, ), name='b') b = tvm.placeholder((m, ), name='b')
c = tvm.placeholder((n, m), name='c')
ir = outer_product(n, m, a, b, c) try:
c = outer_product(n, m, a, b)
ir = c.op.body
except IOError as err:
assert sys.version_info[0] == 2 and str(err) == 'could not get source code'
return
#Check for i in (0, n) #Check for i in (0, n)
assert isinstance(ir, tvm.stmt.For) assert isinstance(ir, tvm.stmt.For)
assert ir.loop_var.name == 'i' assert ir.loop_var.name == 'i'
...@@ -81,10 +98,8 @@ def test_outer_product(): ...@@ -81,10 +98,8 @@ def test_outer_product():
assert mul.a.name == 'a' assert mul.a.name == 'a'
assert mul.b.name == 'b' assert mul.b.name == 'b'
func = tvm.lower(ir, [n, m, a, b, c])
func = tvm.build(func)
run_and_check(outer_product, [n, m, a, b, c], [c], {n: 999, m: 1001}) run_and_check(outer_product, [n, m, a, b], {n: 99, m: 101})
for key, _ in HYBRID_GLOBALS.items(): for key, _ in HYBRID_GLOBALS.items():
assert key not in globals().keys() assert key not in globals().keys()
...@@ -94,19 +109,25 @@ def test_outer_product(): ...@@ -94,19 +109,25 @@ def test_outer_product():
#Test allocation of local variable #Test allocation of local variable
def test_fanout(): def test_fanout():
@script @script
def fanout(n, a, b): def fanout(n, a):
three = 3.0 three = 3.0
b = output_tensor((a.shape[0] - 3, ), a.dtype)
for i in range(a.shape[0] - 3): for i in range(a.shape[0] - 3):
sigma = 0.0 sigma = 0.0
for j in range(3): for j in range(3):
sigma = sigma + a[i + j] sigma = sigma + a[i + j]
sigma = sigma / three sigma = sigma / three
b[i] = sigma b[i] = sigma
return b
n = tvm.var('n') n = tvm.var('n')
a = tvm.placeholder((n, ), 'float32', name='a') a = tvm.placeholder((n, ), 'float32', name='a')
b = tvm.placeholder((n-3, ), 'float32', name='b') try:
ir = fanout(n, a, b) b = fanout(n, a)
ir = b.op.body
except IOError as err:
assert sys.version_info[0] == 2 and str(err) == 'could not get source code'
return
#Check for i in (0, n-3) #Check for i in (0, n-3)
assert isinstance(ir, tvm.stmt.For) assert isinstance(ir, tvm.stmt.For)
...@@ -163,38 +184,31 @@ def test_fanout(): ...@@ -163,38 +184,31 @@ def test_fanout():
assert len(write.value.args) == 1 assert len(write.value.args) == 1
assert write.value.args[0].value == 0 assert write.value.args[0].value == 0
run_and_check(fanout, [n, a, b], [b], {n: 10}) run_and_check(fanout, [n, a], {n: 10})
@script
def failure():
for i in range(1, 100):
i = 0
def test_failure():
try:
tvm.hybrid.parse(failure, [])
except IOError as err:
assert sys.version_info[0] == 2
print('[Warning] Case test_failure is skipped by Python2 because "%s"' % str(err))
except Exception as err:
assert str(err) == 'You CAN NEVER overwrite a loop variable!'
def test_looptype(): def test_looptype():
@script @script
def looptype(a, b, c): def looptype(a, b, c):
d = output_tensor((8, ), 'int32')
e = output_tensor((8, ), 'int32')
f = output_tensor((8, ), 'int32')
for i in parallel(8): for i in parallel(8):
a[i] = i d[i] = a[i]
for j in vectorize(8): for j in vectorize(8):
b[j] = j e[j] = b[j]
for k in unroll(8): for k in unroll(8):
c[k] = k f[k] = c[k]
return d, e, f
a = tvm.placeholder((8, ), name='a', dtype='int32') a = tvm.placeholder((8, ), name='a', dtype='int32')
b = tvm.placeholder((8, ), name='b', dtype='int32') b = tvm.placeholder((8, ), name='b', dtype='int32')
c = tvm.placeholder((8, ), name='c', dtype='int32') c = tvm.placeholder((8, ), name='c', dtype='int32')
ir = looptype(a, b, c) try:
d, e, f = looptype(a, b, c)
ir = d.op.body
except:
return
iloop = ir.first iloop = ir.first
jloop = ir.rest.first jloop = ir.rest.first
kloop = ir.rest.rest kloop = ir.rest.rest
...@@ -202,24 +216,26 @@ def test_looptype(): ...@@ -202,24 +216,26 @@ def test_looptype():
assert jloop.for_type == tvm.stmt.For.Vectorized assert jloop.for_type == tvm.stmt.For.Vectorized
assert kloop.for_type == tvm.stmt.For.Unrolled assert kloop.for_type == tvm.stmt.For.Unrolled
run_and_check(looptype, [a, b, c], [a, b, c]) run_and_check(looptype, [a, b, c])
def test_if(): def test_if():
@script @script
def if_then_else(a, b): def if_then_else(a):
b = output_tensor((10, ), 'int32')
c = output_tensor((10, ), 'int32')
for i in range(10): for i in range(10):
if i % 2 == 0: if i % 2 == 0:
a[i] = -1 c[i] = a[i]
else: else:
a[i] = 1 c[i] = b[i]
for i in unroll(10): for i in unroll(10):
b[i] = -1 if i % 2 == 0 else 1 b[i] = -1 if i % 2 == 0 else 1
return b, c
a = tvm.placeholder((10, ), dtype='int32', name='a') a = tvm.placeholder((10, ), dtype='int32', name='a')
b = tvm.placeholder((10, ), dtype='int32', name='b')
run_and_check(if_then_else, [a, b], [a, b]) run_and_check(if_then_else, [a])
def test_bind(): def test_bind():
...@@ -227,55 +243,66 @@ def test_bind(): ...@@ -227,55 +243,66 @@ def test_bind():
print('[Warning] No GPU found! Skip bind test!') print('[Warning] No GPU found! Skip bind test!')
return return
@script @script
def vec_add(a, b, c): def vec_add(a, b):
c = output_tensor((1000, ), dtype='float32')
for tx in bind('threadIdx.x', 1000): for tx in bind('threadIdx.x', 1000):
c[tx] = b[tx] + c[tx] c[tx] = b[tx] + c[tx]
return c
a = tvm.placeholder((1000, ), dtype='float32', name='a') a = tvm.placeholder((1000, ), dtype='float32', name='a')
b = tvm.placeholder((1000, ), dtype='float32', name='b') b = tvm.placeholder((1000, ), dtype='float32', name='b')
c = tvm.placeholder((1000, ), dtype='float32', name='c')
run_and_check(vec_add, [a, b, c], [c], target='cuda') run_and_check(vec_add, [a, b], target='cuda')
def test_math_intrin(): def test_math_intrin():
@script @script
def intrin_real(a): def intrin_real(a):
a[0] = sqrt(a[0]) b = output_tensor((8, ), 'float32')
a[1] = log(a[1]) b[0] = sqrt(a[0])
a[2] = exp(a[2]) b[1] = log(a[1])
a[3] = sigmoid(a[3]) b[2] = exp(a[2])
a[4] = power(a[4], a[5]) b[3] = sigmoid(a[3])
a[5] = tanh(a[5]) b[4] = power(a[4], a[5])
a[6] = min(a[4], a[5]) b[5] = tanh(a[5])
a[7] = max(a[5], a[6]) b[6] = min(a[4], a[5])
b[7] = max(a[5], a[6])
return b
a8 = tvm.placeholder((8, ), dtype='float32', name='a') a8 = tvm.placeholder((8, ), dtype='float32', name='a')
ir = intrin_real(a8) b8 = intrin_real(a8)
func = tvm.build(tvm.lower(ir, [a8])) sch = tvm.create_schedule(b8.op)
func = tvm.build(sch, [a8, b8])
assert func assert func
a = numpy.arange(2, 10).astype('float32') a = numpy.arange(2, 10).astype('float32')
tvm_a = tvm.ndarray.array(a) tvm_a = tvm.ndarray.array(a)
func(tvm_a) tvm_b = tvm.ndarray.array(numpy.zeros((8, ), dtype='float32'))
intrin_real(a) b = intrin_real(a)
tvm.testing.assert_allclose(a, tvm_a.asnumpy(), rtol=1e-5) func(tvm_a, tvm_b)
tvm.testing.assert_allclose(b, tvm_b.asnumpy(), rtol=1e-5)
@script @script
def intrin_int(a): def intrin_int(a):
a[0] = popcount(a[0]) b = output_tensor((1, ), 'int32')
b[0] = popcount(a[0])
return b
a1 = tvm.placeholder((1, ), dtype='int32') a1 = tvm.placeholder((1, ), dtype='int32')
ir = intrin_int(a1) b1 = intrin_int(a1)
func = tvm.build(tvm.lower(ir, [a1])) sch = tvm.create_schedule(b1.op)
func = tvm.build(sch, [a1, b1])
assert func assert func
a = numpy.array([1234567890]).astype('int32') a = numpy.array([114514]).astype('int32')
tvm_a = tvm.ndarray.array(a) tvm_a = tvm.ndarray.array(a)
intrin_int(a) tvm_b = tvm.ndarray.array(numpy.array([0]).astype('int32'))
func(tvm_a) b = intrin_int(a)
assert tvm_a.asnumpy()[0] == a[0] func(tvm_a, tvm_b)
assert tvm_b.asnumpy()[0] == b[0]
# test non canonical loops
def test_non_zero(): def test_non_zero():
@tvm.hybrid.script @tvm.hybrid.script
def blur(a, b): def blur(a):
b = output_tensor((30, 30), 'float32')
for i in range(2, 32): for i in range(2, 32):
for j in range(2, 32): for j in range(2, 32):
s = 0.0 s = 0.0
...@@ -283,29 +310,28 @@ def test_non_zero(): ...@@ -283,29 +310,28 @@ def test_non_zero():
for dj in range(3): for dj in range(3):
s = s + a[i-di, j-dj] s = s + a[i-di, j-dj]
b[i-2, j-2] = s / 9.0 b[i-2, j-2] = s / 9.0
try: return b
a = tvm.placeholder((32, 32), 'float32', 'a')
b = tvm.placeholder((30, 30), 'float32', 'b') a = tvm.placeholder((32, 32), 'float32', 'a')
run_and_check(blur, [a, b], [b]) run_and_check(blur, [a])
except IOError as err:
assert sys.version_info[0] == 2
print('[Warning] Case test_non_zero is skipped by Python2 because "%s"' % str(err))
@tvm.hybrid.script @tvm.hybrid.script
def triangle(a, b, c): def triangle(a, b):
c = output_tensor((10, 10), dtype='float32')
for i in range(10): for i in range(10):
for j in range(i, 10): for j in range(i, 10):
c[i, j] = a[i] * b[j] c[i, j] = a[i] * b[j]
return c
a = tvm.placeholder((10, ), dtype='float32', name='a') a = tvm.placeholder((10, ), dtype='float32', name='a')
b = tvm.placeholder((10, ), dtype='float32', name='b') b = tvm.placeholder((10, ), dtype='float32', name='b')
c = tvm.placeholder((10, 10), dtype='float32', name='c')
run_and_check(triangle, [a, b, c], [c]) run_and_check(triangle, [a, b])
def test_allocate(): def test_allocate():
@tvm.hybrid.script @tvm.hybrid.script
def blur2d(a, b): def blur2d(a):
b = output_tensor((30, 30), 'float32')
for i in range(30): for i in range(30):
ha = allocate((3, 30), 'float32') ha = allocate((3, 30), 'float32')
for j in range(3): for j in range(3):
...@@ -313,15 +339,15 @@ def test_allocate(): ...@@ -313,15 +339,15 @@ def test_allocate():
ha[j, k] = a[i+j, k] + a[i+j, k+1] + a[i+j, k+2] ha[j, k] = a[i+j, k] + a[i+j, k+1] + a[i+j, k+2]
for j in range(30): for j in range(30):
b[i, j] = (ha[0, j] + ha[1, j] + ha[2, j]) / 9.0 b[i, j] = (ha[0, j] + ha[1, j] + ha[2, j]) / 9.0
return b
a = tvm.placeholder((32, 32), 'float32', 'a') a = tvm.placeholder((32, 32), 'float32', 'a')
b = tvm.placeholder((30, 30), 'float32', 'b') run_and_check(blur2d, [a])
run_and_check(blur2d, [a, b], [b])
if tvm.gpu().exist: if tvm.gpu().exist:
@tvm.hybrid.script @tvm.hybrid.script
def share_vec_add(a, b, c): def share_vec_add(a, b):
c = output_tensor((256, ), 'float32')
shared = allocate((256, ), 'float32', 'shared') shared = allocate((256, ), 'float32', 'shared')
for i in bind("threadIdx.x", 256): for i in bind("threadIdx.x", 256):
shared[i] = a[i] shared[i] = a[i]
...@@ -330,23 +356,81 @@ def test_allocate(): ...@@ -330,23 +356,81 @@ def test_allocate():
local[i] = b[i] local[i] = b[i]
for i in bind("threadIdx.x", 256): for i in bind("threadIdx.x", 256):
c[i] = shared[i] + local[i] c[i] = shared[i] + local[i]
return c
a = tvm.placeholder((256, ), dtype='float32', name='a') a = tvm.placeholder((256, ), dtype='float32', name='a')
b = tvm.placeholder((256, ), dtype='float32', name='b') b = tvm.placeholder((256, ), dtype='float32', name='b')
c = tvm.placeholder((256, ), dtype='float32', name='c') run_and_check(share_vec_add, [a, b], target='cuda')
run_and_check(share_vec_add, [a, b, c], [c], target='cuda')
else: else:
print('[Warning] No GPU found! Skip shared mem test!') print('[Warning] No GPU found! Skip shared mem test!')
def test_upstream():
@tvm.hybrid.script
def upstream(a):
b = output_tensor((20, ), 'float32')
for i in range(20):
b[i] = a[i] * i
return b
a = tvm.placeholder((20, ), 'float32')
b = tvm.placeholder((20, ), 'float32')
c = tvm.compute((20, ), lambda x: a[x] + b[x])
d = upstream(c)
sch = tvm.create_schedule([c.op, d.op])
ir = tvm.lower(sch, [a, b, d], simple_mode=True)
func = tvm.build(sch, [a, b, d])
assert(func)
a = numpy.random.randn(20).astype('float32')
b = numpy.random.randn(20).astype('float32')
ref = numpy.zeros((20, ), 'float32')
for i in range(20):
ref[i] = (a[i] + b[i]) * i
tvm_a = tvm.nd.array(a)
tvm_b = tvm.nd.array(b)
tvm_d = tvm.nd.array(numpy.zeros((20, )).astype('float32'))
func(tvm_a, tvm_b, tvm_d)
tvm.testing.assert_allclose(tvm_d.asnumpy(), ref, 1e-5, 1e-5)
def test_downstream():
@tvm.hybrid.script
def downstream(a):
b = output_tensor((20, ), 'float32')
for i in range(20):
b[i] = a[i] * i
return b
a = tvm.placeholder((20, ), 'float32')
b = downstream(a)
c = tvm.compute((20, ), lambda x: b[x] + 1.0)
sch = tvm.create_schedule(c.op)
module = tvm.build(sch, [a, c])
assert module
a = numpy.random.randn(20).astype('float32')
ref = numpy.zeros((20, )).astype('float32')
for i in range(20):
ref[i] = (a[i] * i) + 1.0
tvm_a = tvm.nd.array(a)
tvm_c = tvm.nd.array(numpy.zeros((20, )).astype('float32'))
module(tvm_a, tvm_c)
tvm.testing.assert_allclose(tvm_c.asnumpy(), ref, 1e-5, 1e-5)
if __name__ == "__main__": if __name__ == "__main__":
test_outer_product() test_outer_product()
test_fanout() test_fanout()
test_failure()
test_looptype() test_looptype()
test_if() test_if()
test_bind() test_bind()
test_math_intrin() test_math_intrin()
test_non_zero() test_non_zero()
test_allocate() test_allocate()
#test_inplace()
test_upstream()
test_downstream()