Commit 7a01476a by Jian Weng, committed by Tianqi Chen

[HYBRID FRONTEND] Modify hybrid script to new interface; hybrid op supported;…

[HYBRID FRONTEND] Modify hybrid script to new interface; hybrid op supported; enable compilation_database in CMakeList.txt (#1757)
parent 79735eb2
......@@ -57,6 +57,7 @@ include_directories("3rdparty/compiler-rt")
# initial variables
set(TVM_LINKER_LIBS "")
set(TVM_RUNTIME_LINKER_LIBS "")
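# Emit compile_commands.json (with the Makefile/Ninja generators) so tools such as clangd and clang-tidy can pick up the build flags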
set(CMAKE_EXPORT_COMPILE_COMMANDS ON)
# Generic compilation options
if(MSVC)
......
......@@ -22,13 +22,15 @@ you need to use ``tvm.hybrid.script`` decorator to indicate this is a hybrid fun
@tvm.hybrid.script
def outer_product(a, b, c):
c = output_tensor((100, 99), 'float32')
for i in range(a.shape[0]):
for j in range(b.shape[0]):
c[i, j] = a[i] * b[j]
a = numpy.random.rand(100)
b = numpy.random.rand(99)
c = numpy.zeros((100, 99))
outer_product(a, b, c)
return c
a = numpy.random.randn(100)
b = numpy.random.randn(99)
c = outer_product(a, b)
This decorator will automatically import the required `Keywords`_ during software emulation.
After software emulation is done, the imported keywords will be cleaned up. Users do not need
......@@ -40,25 +42,25 @@ or ``numpy`` numeric type.
Backend Compilation
~~~~~~~~~~~~~~~~~~~
Calling this function directly is discouraged; users are encouraged to use the second interface instead.
The current parse interface looks like:
.. code-block:: python
a = tvm.placeholder((100, ), name='a')
b = tvm.placeholder((99, ), name='b')
c = tvm.placeholder((100, 99), name='c')
tvm.hybrid.parse(outer_product, [a, b, c]) # return an ir root of this function
parser = tvm.hybrid.parse(outer_product, [a, b]) # return the parser of this function
If we pass these tvm tensors to this function, it returns an op node:
**Under construction, we are still deciding what kind of node should be returned.**
If we pass these tvm tensors to this function, it returns an op node:
.. code-block:: python
a = tvm.placeholder((100, ), name='a')
b = tvm.placeholder((99, ), name='b')
c = tvm.placeholder((100, 99), name='c')
op = outer_product(a, b, c) # return the corresponding op node
c = outer_product(a, b, c) # return the output tensor(s) of the operator
**Under construction, we are still deciding what kind of node should be returned.**
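As a rough sketch (mirroring the unit test added in this commit), the returned tensor can then be scheduled and compiled like any other TVM tensor:
.. code-block:: python
sch = tvm.create_schedule(c.op)  # c is the tensor returned by outer_product
module = tvm.build(sch, [a, b, c])  # build the hybrid op together with its inputs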
Tuning
~~~~~~
......
......@@ -450,6 +450,69 @@ class ExternOpNode : public OperationNode {
TVM_DECLARE_NODE_TYPE_INFO(ExternOpNode, OperationNode);
};
/*!
* \brief A computation operator generated by a hybrid script.
*/
class HybridOpNode : public OperationNode {
public:
/*! \brief The input tensors */
Array<Tensor> inputs;
/*! \brief Symbolic placeholder representation of outputs */
Array<Tensor> outputs;
/*! \brief The statement that generates the computation. This is
* slightly different from the body in ExternOpNode. All the output
* tensors keep the names specified by users in the script.
* However, during compilation, these tensors will be replaced by the
* actual output tensors. */
Stmt body;
/*! \brief constructor */
HybridOpNode() {}
// override functions
int num_outputs() const final;
Array<IterVar> root_iter_vars() const final;
Type output_dtype(size_t i) const final;
Array<Expr> output_shape(size_t i) const final;
Array<Tensor> InputTensors() const final;
Operation ReplaceInputs(
const Operation& self,
const std::unordered_map<Tensor, Tensor>& rmap) const final;
void PropBoundToInputs(
const Operation& self,
const std::unordered_map<const Variable*, IntSet>& dom_map,
std::unordered_map<Tensor, TensorDom>* out_dom_map) const final;
void GatherBound(
const Operation& self,
const std::unordered_map<Tensor, TensorDom>& tensor_dom,
std::unordered_map<IterVar, Range>* out_dom_map) const final;
Stmt BuildRealize(
const Stage& stage,
const std::unordered_map<IterVar, Range>& realize_map,
const Stmt& body) const final;
Stmt BuildProvide(
const Stage& stage,
const std::unordered_map<IterVar, Range>& dom_map,
bool debug_keep_trivial_loop) const final;
void VisitAttrs(AttrVisitor* v) final {
v->Visit("name", &name);
v->Visit("tag", &tag);
v->Visit("attrs", &attrs);
v->Visit("inputs", &inputs);
v->Visit("outputs", &outputs);
v->Visit("body", &body);
}
EXPORT static Operation make(std::string name,
std::string tag,
Map<std::string, NodeRef> attrs,
Array<Tensor> inputs,
Array<Tensor> outputs,
Stmt body);
static constexpr const char* _type_key = "HybridOp";
TVM_DECLARE_NODE_TYPE_INFO(HybridOpNode, OperationNode);
};
/*! \brief The compute function to specify the input source of a Tensor */
using FCompute = std::function<Expr (const Array<Var>& i)>;
......
......@@ -340,11 +340,6 @@ def lower(sch,
bounds = schedule.InferBound(sch)
stmt = schedule.ScheduleOps(sch, bounds)
stmt = ir_pass.InjectPrefetch(stmt)
else:
#So far there is no op for hybrid script, so a plain ir body is given
if not isinstance(sch, _stmt.Stmt):
raise ValueError("sch should be either a Schedule or a Stmt")
stmt = sch
for f in lower_phase0:
stmt = f(stmt)
......
......@@ -7,4 +7,5 @@ python semantic emulation.
2. Developers can build HalideIR by writing Python code.
"""
from .api import script, parse
from .api import script
from .parser import parse_python
"""APIs of lowering the Python subset to HalideIR"""
from __future__ import absolute_import as _abs
import types
from .._ffi.base import decorate
from .. import _api_internal as _tvm_internal
from ..tensor import Tensor
from .parser import parse_python
from .util import _pruned_source
def script(pyfunc):
......@@ -17,40 +20,26 @@ def script(pyfunc):
hybrid_func : function
A decorated hybrid script function.
"""
def wrapped_func(func, *args, **kwargs):
def wrapped_func(func, *args, **kwargs): #pylint: disable=missing-docstring
from .util import _enter_hybrid_runtime, _restore_runtime, _is_tvm_arg_types
if _is_tvm_arg_types(args):
return parse(func, args)
src = _pruned_source(func)
parser = parse_python(src, args)
input_tensors = []
for i in args:
if isinstance(i, Tensor):
input_tensors.append(i)
op = _tvm_internal._HybridOp(parser.func_name, "HybridOp", None, input_tensors,
parser.outputs, parser.parsed_body)
res = [op.output(i) for i in range(len(parser.outputs))]
return res[0] if len(res) == 1 else res
intersect = _enter_hybrid_runtime(func)
value = func(*args, **kwargs)
_restore_runtime(func, intersect)
return value
return decorate(pyfunc, wrapped_func)
def parse(func, args):
"""Parse a subset of Python to HalideIR
Parameters
----------
func : str or types.FunctionType
If it is a string, parse the source code
If it is a function, parse the function
args : list of Buffer or Tensor or Var
The argument lists to the function.
Leave it None if no buffer is related to the function to be parsed
Returns
-------
root : Stmt
The result Halide IR and the parser class instance.
"""
from .util import _pruned_source
if isinstance(func, str):
src = func
else:
assert isinstance(func, types.FunctionType)
src = _pruned_source(func)
return parse_python(src, args)
return decorate(pyfunc, wrapped_func)
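A rough usage sketch of the two call paths handled by ``wrapped_func`` above, reusing the ``outer_product`` example from the docs (illustrative only):
# numpy arguments trigger software emulation: the Python body runs directly
c = outer_product(numpy.random.rand(100), numpy.random.rand(99))
# tvm.placeholder arguments trigger the backend path: the source is parsed and
# the HybridOp's output tensor(s) are returned
a, b = tvm.placeholder((100, ), name='a'), tvm.placeholder((99, ), name='b')
c = outer_product(a, b)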
......@@ -48,6 +48,7 @@ def allocate(shape, dtype='float32', scope='global'): #pylint: disable=unused-ar
"""
return numpy.zeros(shape).astype(dtype)
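# During software emulation an output tensor is just another numpy buffer, so output_tensor can reuse allocate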
output_tensor = allocate #pylint: disable=invalid-name
def popcount(x):
"""
......@@ -87,18 +88,19 @@ def sigmoid(x):
HYBRID_GLOBALS = {
'unroll' : unroll,
'vectorize' : vectorize,
'parallel' : parallel,
'allocate' : allocate,
'bind' : bind,
'sqrt' : numpy.sqrt,
'log' : numpy.log,
'tanh' : numpy.tanh,
'power' : numpy.power,
'exp' : numpy.exp,
'sigmoid' : sigmoid,
'popcount' : popcount
'unroll' : unroll,
'vectorize' : vectorize,
'parallel' : parallel,
'allocate' : allocate,
'output_tensor': output_tensor,
'bind' : bind,
'sqrt' : numpy.sqrt,
'log' : numpy.log,
'tanh' : numpy.tanh,
'power' : numpy.power,
'exp' : numpy.exp,
'sigmoid' : sigmoid,
'popcount' : popcount
}
......
......@@ -2,8 +2,9 @@
import ast
import operator
import logging
import sys
from .util import make_nop, halide_imm_types, is_docstring
from .util import make_nop, halide_imm_types, is_docstring, _internal_assert
from .intrin import LOOP_INTRIN, MATH_INTRIN
from .var_decl import determine_variable_usage
from ..api import thread_axis
......@@ -72,15 +73,17 @@ class HybridParser(ast.NodeVisitor):
The name of the function to be lowered; if not provided,
the compiler will use the name in the AST
"""
self.args = args[:]
self.args = list(args)
self.usage = usage.copy()
self._args = {} # Dict maps arg name to actual arg instance (either a var or a buffer)
self.var_buffers = {} # Buffers formed by mutable variables
self.alloc_buffers = {} # Buffers formed by allocate instructions
self.loops_above = {} # State variable that indicates loop levels above the current node
self.var_consts = {} # Variables that are determined as readonly in previous stage
self.func_name = func_name # The name of the function to be lowered
self.iter_axis = []
self.outputs = [] # Output tensors' names
self.side_effect = set() # Tensors with side effects
self.parsed_body = None # The parsed HalideIR body
self.returned = False
def wrap_up_realize(self, node, body):
......@@ -90,9 +93,8 @@ class HybridParser(ast.NodeVisitor):
continue
_, level, _ = val
if level == node:
if key in self.var_buffers.keys():
_buf = self.var_buffers[key]
_scope = 'global'
if key in self._args.keys():
continue
else:
_buf, _scope = self.alloc_buffers[key]
_domain = [_make.range_by_min_extent(0, i) for i in _buf.shape]
......@@ -103,12 +105,13 @@ class HybridParser(ast.NodeVisitor):
return body
def _get_buffer_from_id(self, s):
if s not in self._args.keys() and s not in self.alloc_buffers.keys():
raise ValueError("This %s is expected to be in argument list or allocated buffer!" % s)
if s in self._args.keys() and s in self.alloc_buffers.keys():
raise ValueError("%s, a buffer cannot be both argument and allocated!" % s)
def _get_buffer_from_id(self, s, for_provide=False):
_internal_assert((s in self._args.keys()) + (s in self.alloc_buffers.keys()) == 1,
"This %s is expected to be in either \
argument list or allocated buffer!" % s)
if s in self._args.keys():
if for_provide:
self.side_effect.add(self._args[s])
return self._args[s]
return self.alloc_buffers[s][0]
......@@ -116,15 +119,15 @@ class HybridParser(ast.NodeVisitor):
#pylint: disable=invalid-name, missing-docstring
def visit_Module(self, node):
if len(node.body) != 1:
raise ValueError("Only one-function source code can be fed to this parser!")
_internal_assert(len(node.body) == 1, \
"Only one-function source code can be fed to this parser!")
return self.visit(node.body[0])
def visit_FunctionDef(self, node):
if len(node.args.args) != len(self.args):
raise ValueError("The number of arguments passed to the function\
should be the same as it is defined!")
_internal_assert(len(node.args.args) == len(self.args), \
"The number of arguments passed to the \
function should be the same as it is defined!")
for idx, arg in enumerate(node.args.args):
_attr = 'id' if sys.version_info[0] < 3 else 'arg' # To make py2 and 3 compatible
self._args[getattr(arg, _attr)] = self.args[idx]
......@@ -145,17 +148,17 @@ class HybridParser(ast.NodeVisitor):
return self._args[_id]
elif _id in self.loops_above.keys():
return self.loops_above[_id]
if _id in self._args.keys():
raise ValueError("This id %s should be handled in visit_Subscript!" % _id)
if _id not in self.usage.keys():
raise ValueError("This id %s is expected to be a defined variable!" % _id)
_internal_assert(_id not in self._args.keys(), \
"This id %s should be handled in visit_Subscript!" % _id)
_internal_assert(_id in self.usage.keys(), \
"This id %s is expected to be a defined variable!" % _id)
# Buffer
if _id in self.var_buffers.keys():
_buf = self.var_buffers[_id]
if _id in self.alloc_buffers.keys():
_buf, _ = self.alloc_buffers[_id]
return _make.Call(_buf.dtype, _id, [_api.const(0)], _expr.Call.Halide, _buf.op, 0)
# Compilation time constant
if _id not in self.var_consts.keys():
raise ValueError("This id %s is expected to a compilation time constant!" % _id)
_internal_assert(_id in self.var_consts.keys(),
"This id %s is expected to a compilation time constant!" % _id)
return self.var_consts[_id]
......@@ -164,8 +167,7 @@ class HybridParser(ast.NodeVisitor):
def visit_Assign(self, node):
if len(node.targets) != 1:
raise ValueError("So far only one-valued assignment is supported!")
_internal_assert(len(node.targets) == 1, "So far only one-valued assignment is supported!")
lhs = node.targets[0]
rhs = self.visit(node.value)
if isinstance(rhs, _expr.Expr):
......@@ -174,36 +176,40 @@ class HybridParser(ast.NodeVisitor):
#TODO: support defined intermediate buffer later
lhs_ = lhs
lhs = lhs.id
if lhs in self.loops_above.keys():
raise ValueError("You CAN NEVER overwrite a loop variable!")
_internal_assert(lhs not in self.loops_above.keys(), \
"Loop variable cannot be overwritten!")
decl, _, rw = self.usage[lhs]
if decl == lhs_:
if lhs in self.var_consts.keys():
raise ValueError("BUG: A constant cannot be overwritten!")
if lhs in self.var_buffers.keys() or lhs in self.alloc_buffers.keys():
raise ValueError("BUG: This value should not be defined before this point!")
_internal_assert(lhs not in self.var_consts.keys(), \
"A constant cannot be overwritten!")
_internal_assert(lhs not in self.alloc_buffers.keys(), \
"This value should not be defined before this point!")
if isinstance(rhs, tuple):
shape, dtype, scope = rhs
ph = _api.placeholder(shape, dtype=dtype, name=lhs)
self.alloc_buffers[lhs] = (ph, scope)
if scope != 'output':
self.alloc_buffers[lhs] = (ph, scope)
else:
self._args[lhs] = ph
self.outputs.append(lhs)
return make_nop()
if isinstance(rhs, halide_imm_types) and ast.Store not in rw:
self.var_consts[lhs] = rhs
else:
self.var_buffers[lhs] = _api.placeholder((1, ), dtype=rhs.dtype, name=lhs)
ph = _api.placeholder((1, ), dtype=rhs.dtype, name=lhs)
self.alloc_buffers[lhs] = (ph, 'global')
if lhs in self.var_consts.keys():
return make_nop()
else:
if lhs not in self.var_buffers.keys():
raise ValueError("BUG: This variable should be defined before!")
tgt = self.var_buffers[lhs]
return _make.Provide(tgt.op, 0, rhs, [_api.const(0, dtype=rhs.dtype)])
_internal_assert(lhs in self.alloc_buffers.keys(), \
"This variable should be defined before!")
tgt, _ = self.alloc_buffers[lhs]
return _make.Provide(tgt.op, 0, rhs, [_api.const(0, dtype=rhs.dtype)])
else:
lhs = self.visit(lhs)
if not isinstance(lhs, _expr.Call):
raise ValueError("An array access's LHS is expected to be a expr.Call!")
_internal_assert(isinstance(lhs, _expr.Call), \
"An array access's LHS is expected to be a expr.Call!")
#TODO: support slice later
buf = self._get_buffer_from_id(lhs.name)
buf = self._get_buffer_from_id(lhs.name, for_provide=True)
return _make.Provide(buf.op, 0, rhs, lhs.args)
......@@ -219,21 +225,20 @@ class HybridParser(ast.NodeVisitor):
array = node.value.id
_buf = self._get_buffer_from_id(array)
return _make.Call(_buf.dtype, array, args, _expr.Call.Halide, _buf.op, 0)
elif isinstance(node.value, ast.Attribute):
if not isinstance(node.value.value, ast.Name):
raise ValueError("The root of array access is expect to be a id!")
if node.value.attr != "shape":
raise ValueError("Attribute access so far only 'shape' is supported!")
if len(args) != 1:
raise ValueError("For 'shape' access the argument should be only one!")
args = args[0]
#TODO: maybe support non-constant value later?
if not isinstance(args, (_expr.IntImm, _expr.UIntImm)):
raise ValueError("So far only constant shape access supported!")
buf = self._get_buffer_from_id(node.value.value.id)
return buf.shape[args.value]
else:
raise ValueError("Not supported yet!")
_internal_assert(isinstance(node.value, ast.Attribute), \
"Only variable and attribute's subscript supported so far")
_internal_assert(isinstance(node.value.value, ast.Name), \
"The root of array access is expect to be a id!")
_internal_assert(node.value.attr == "shape", \
"Attribute access so far only 'shape' is supported!")
_internal_assert(len(args) == 1, "For 'shape' access there should be exactly one argument!")
args = args[0]
#TODO: maybe support non-constant value later?
_internal_assert(isinstance(args, (_expr.IntImm, _expr.UIntImm)), \
"So far only constant shape access supported!")
buf = self._get_buffer_from_id(node.value.value.id)
return buf.shape[args.value]
def visit_With(self, node):
......@@ -241,14 +246,11 @@ class HybridParser(ast.NodeVisitor):
context = node.context_expr
option = node.optional_vars
else:
if len(node.items) != 1:
raise ValueError("Only one with element is supported so far!")
_internal_assert(len(node.items) == 1, "Only one with element is supported so far!")
context = node.items[0].context_expr
option = node.items[0].optional_vars
if not isinstance(context, ast.Call):
raise ValueError("The object must be a Python function call!")
if not isinstance(option, ast.Name):
raise ValueError("The object after 'as' must be an id!")
_internal_assert(isinstance(context, ast.Call), "The object must be a Python func call!")
_internal_assert(isinstance(option, ast.Name), "The object after 'as' must be an id!")
self.annotation[option.id] = context.func.id
return list_to_block(self.visit, node.body)
......@@ -272,10 +274,8 @@ class HybridParser(ast.NodeVisitor):
def visit_Compare(self, node):
lhs = self.visit(node.left)
if len(node.ops) != 1:
raise ValueError("Only one compare op is supported!")
if len(node.comparators) != 1:
raise ValueError("Only one comparator is supported!")
_internal_assert(len(node.ops) == 1, "Only one compare op is supported!")
_internal_assert(len(node.comparators) == 1, "Only one comparator is supported!")
rhs = self.visit(node.comparators[0])
return HybridParser._binop_maker[type(node.ops[0])](lhs, rhs)
......@@ -293,16 +293,15 @@ class HybridParser(ast.NodeVisitor):
def visit_Call(self, node):
# Yet, no function pointer supported
if not isinstance(node.func, ast.Name):
raise ValueError("Only id-function function call is supported so far!")
_internal_assert(isinstance(node.func, ast.Name), \
"Only id-function function call is supported so far!")
func_id = node.func.id
n = len(node.args)
if func_id in LOOP_INTRIN.keys() and func_id != 'bind':
if n == 1:
low, ext = _api.const(0, dtype='int32'), self.visit(node.args[0])
else:
if n != 2:
raise ValueError("A loop intrinsic should only have 1 or 2 arguments!")
_internal_assert(n == 2, "A loop intrinsic should only have 1 or 2 arguments!")
low, ext = self.visit(node.args[0]), self.visit(node.args[1])
if not _ir_pass.Equal(low, _api.const(0, dtype='int32')):
ext = ext - low
......@@ -310,10 +309,9 @@ class HybridParser(ast.NodeVisitor):
iter_var = None
return iter_var, low, ext, for_type
elif func_id == 'bind':
if n != 2:
raise ValueError("A loop bind should only have 2 arguments!")
if not isinstance(node.args[0], ast.Str):
raise ValueError("A loop bind's first argument should be a string!")
_internal_assert(n == 2, "A loop bind should only have 2 arguments!")
_internal_assert(isinstance(node.args[0], ast.Str), \
"A loop bind's first argument should be a string!")
_vn = node.args[0].s
iter_var = thread_axis(node.args[0].s)
low, ext = _api.const(0, dtype='int32'), self.visit(node.args[1])
......@@ -321,29 +319,39 @@ class HybridParser(ast.NodeVisitor):
return iter_var, low, ext, for_type
elif func_id in MATH_INTRIN:
return getattr(intrin, func_id)(*[self.visit(arg) for arg in node.args])
elif func_id == 'allocate':
if not isinstance(node.args[0], ast.Tuple):
raise ValueError("allocate's first argument should be a tuple of shape!")
elif func_id in ['allocate', 'output_tensor']:
_internal_assert(isinstance(node.args[0], ast.Tuple), \
"allocate's first argument should be a tuple of shape!")
shape = tuple(self.visit(i) for i in node.args[0].elts)
if func_id == 'output_tensor':
_internal_assert(not self.loops_above, \
"Are you sure to allocate a output buffer multiple times?")
for i in shape:
if not isinstance(i, _expr.Expr):
raise ValueError("The shape should be an expression")
_internal_assert(isinstance(i, _expr.Expr), "The shape should be an expression")
if n > 1:
if not isinstance(node.args[1], ast.Str):
raise ValueError("The data type should be an string")
dtype = node.args[1].s
if isinstance(node.args[1], ast.Str):
dtype = node.args[1].s
else:
_internal_assert(isinstance(node.args[1], ast.Attribute), \
"Unable to evaluate to get data type")
to_eval = node.args[1]
_internal_assert(isinstance(to_eval.value, ast.Name), \
"Unable to evaluate the attribute to get data type")
_internal_assert(to_eval.attr == 'dtype', \
"Only dtype attribute is supported so far")
dtype = self._get_buffer_from_id(to_eval.value.id).dtype
else:
dtype = 'float32'
if n > 2:
if not isinstance(node.args[2], ast.Str):
raise ValueError("The data type should be an string")
_internal_assert(isinstance(node.args[2], ast.Str), \
"The data scope should be an string")
_internal_assert(func_id != 'output_tensor', "Output tensor cannot specify scope")
scope = node.args[2].s
else:
scope = 'global'
scope = 'global' if func_id != 'output_tensor' else 'output'
return (shape, dtype, scope)
elif func_id == 'max' or func_id == 'min':
if n != 2:
raise ValueError("Max/Min function should have 2 elements")
_internal_assert(n == 2, "Max/Min function should have 2 elements")
a, b = self.visit(node.args[0]), self.visit(node.args[1])
return getattr(_make, func_id.title())(a, b)
else:
......@@ -352,19 +360,17 @@ class HybridParser(ast.NodeVisitor):
def visit_For(self, node):
iter_var, low, ext, for_type = self.visit(node.iter)
if not isinstance(node.target, ast.Name):
raise ValueError("The loop iterator should be a variable!")
_internal_assert(isinstance(node.target, ast.Name), \
"The loop iterator should be a variable!")
_name = node.target.id
if iter_var is None:
if for_type is None:
raise ValueError("The loop bind function parse error!")
_internal_assert(for_type is not None, "Error parsing the loop bind function!")
offset = iter_var = _api.var(_name)
if not _ir_pass.Equal(low, _api.const(0, dtype='int32')):
offset = iter_var + low
self.loops_above[_name] = offset
else:
if for_type is not None:
raise ValueError("The loop iterating function parse error!")
_internal_assert(for_type is None, "Error parsing the loop iterating function!")
self.loops_above[_name] = iter_var.var
_body = list_to_block(self.visit, node.body)
_body = self.wrap_up_realize(node, _body)
......@@ -376,10 +382,46 @@ class HybridParser(ast.NodeVisitor):
return res
def visit_Return(self, node):
_internal_assert(not self.loops_above, "Return should not be in a loop body!")
ids = []
if isinstance(node.value, ast.Name):
ids.append(node.value.id)
else:
_internal_assert(isinstance(node.value, ast.Tuple), \
"You should return either a single tensor or a tuple")
for i in node.value.elts:
_internal_assert(isinstance(i, ast.Name), "Only variables can be returned!")
ids.append(i.id)
_internal_assert(len(set(ids)) == len(ids), "Duplicated tensors in the return tuples")
if len(ids) != len(self.outputs):
logging.log(logging.CRITICAL, '[Warning] Not all the output buffers returned!')
self.outputs = [self._args[i] for i in ids]
self.returned = True
return make_nop()
def parse_python(src, args):
"""The helper function of calling the AST visitor"""
"""The helper function of calling the AST visitor
Parameters
----------
src : str
The source code of the function to be parsed.
args : list of Tensors or Vars
The argument lists to the function.
It is NOT encouraged to write a function without arguments.
It is NOT encouraged to write a function with side effects.
Returns
-------
parser : HybridParser
The parser instance, which carries the parsed HalideIR body and the output placeholders.
"""
root = ast.parse(src)
var_usage = determine_variable_usage(root, args)
parser = HybridParser(args, var_usage)
halide_ir = parser.visit(root)
return halide_ir
parser.parsed_body = parser.visit(root)
_internal_assert(parser.returned, 'No valid return found in the function body!')
return parser
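# Sketch of how the returned parser is consumed by the frontend (see api.py above):
#   parser = parse_python(src, args)
#   op = _tvm_internal._HybridOp(parser.func_name, "HybridOp", None, input_tensors,
#                                parser.outputs, parser.parsed_body)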
......@@ -2,6 +2,8 @@
import ast
import inspect
import logging
import sys
import numpy
from .intrin import HYBRID_GLOBALS
from .._ffi.base import numeric_types
......@@ -30,10 +32,17 @@ def is_docstring(node):
def _pruned_source(func):
"""Prune source code's extra leading spaces"""
lines = inspect.getsource(func).split('\n')
leading_space = len(lines[0]) - len(lines[0].lstrip(' '))
lines = [line[leading_space:] for line in lines]
return '\n'.join(lines)
try:
lines = inspect.getsource(func).split('\n')
leading_space = len(lines[0]) - len(lines[0].lstrip(' '))
lines = [line[leading_space:] for line in lines]
return '\n'.join(lines)
except IOError as err:
if sys.version_info[0] == 2 and str(err) == 'could not get source code':
logging.log(logging.CRITICAL, \
'This module does not fully work under Python 2... ' \
'Please move to Python3!')
raise err
def _is_tvm_arg_types(args):
......@@ -70,3 +79,12 @@ def _restore_runtime(func, intersect):
_globals.pop(elem)
for k, v in intersect:
_globals[k] = v
def _internal_assert(cond, err):
"""Simplify the code segment like if not XXX then raise an error"""
if not cond:
raise ValueError(err)
# Almost the same functionality as the one above, but in this case,
# the error is caused by the user's improper usage.
_user_assert = _internal_assert
......@@ -3,6 +3,7 @@
import ast
import sys
from .intrin import HYBRID_GLOBALS
from .util import _internal_assert
class PyVariableUsage(ast.NodeVisitor):
......@@ -18,8 +19,8 @@ class PyVariableUsage(ast.NodeVisitor):
def visit_FunctionDef(self, node):
self.scope_level.append(node)
if len(node.args.args) != len(self.args):
raise ValueError('#arguments passed should be the same as #arguments defined')
_internal_assert(len(node.args.args) == len(self.args), \
'#arguments passed should be the same as #arguments defined')
for idx, arg in enumerate(node.args.args):
_attr = 'id' if sys.version_info[0] < 3 else 'arg' # To make py2 and 3 compatible
self._args[getattr(arg, _attr)] = self.args[idx]
......@@ -28,8 +29,8 @@ class PyVariableUsage(ast.NodeVisitor):
def visit_For(self, node):
if not isinstance(node.target, ast.Name):
raise ValueError("For's iterator should be an id")
_internal_assert(isinstance(node.target, ast.Name), \
"For's iterator should be an id")
self.visit(node.iter)
self.scope_level.append(node)
for i in node.body:
......@@ -39,11 +40,10 @@ class PyVariableUsage(ast.NodeVisitor):
def visit_Call(self, node):
#No function pointer supported so far
if not isinstance(node.func, ast.Name):
raise ValueError("Function call should be an id")
_internal_assert(isinstance(node.func, ast.Name), "Function call should be an id")
func_id = node.func.id
if func_id not in list(HYBRID_GLOBALS.keys()) + ['range', 'max', 'min']:
raise ValueError("Function call id not in intrinsics' list")
_internal_assert(func_id in list(HYBRID_GLOBALS.keys()) + ['range', 'max', 'min'], \
"Function call id not in intrinsics' list")
for elem in node.args:
self.visit(elem)
......@@ -56,12 +56,12 @@ class PyVariableUsage(ast.NodeVisitor):
if node.id in fors:
return
# The loop variable cannot be overwritten during iteration
if isinstance(node.ctx, ast.Store) and node.id in fors:
raise ValueError("Iter var cannot be overwritten")
_internal_assert(not isinstance(node.ctx, ast.Store) or node.id not in fors, \
"Iter var cannot be overwritten")
if node.id not in self.status.keys():
if not isinstance(node.ctx, ast.Store):
raise ValueError('In Python, "first store" indicates "declaration"')
_internal_assert(isinstance(node.ctx, ast.Store), \
'Undeclared variable %s' % node.id)
self.status[node.id] = (node, self.scope_level[-1], set())
else:
decl, loop, usage = self.status[node.id]
......
......@@ -180,3 +180,8 @@ class ScanOp(Operation):
class ExternOp(Operation):
"""Extern operation."""
pass
@register_node
class HybridOp(Operation):
"""Hybrid operation."""
pass
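# With this registration, tensors returned by a hybrid call expose their op as a HybridOp,
# e.g. (sketch): c = outer_product(a, b); isinstance(c.op, HybridOp); ir = c.op.body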
......@@ -313,6 +313,16 @@ TVM_REGISTER_API("_ExternOp")
args[6]);
});
TVM_REGISTER_API("_HybridOp")
.set_body([](TVMArgs args, TVMRetValue* ret) {
*ret = HybridOpNode::make(args[0],
args[1],
args[2],
args[3],
args[4],
args[5]);
});
TVM_REGISTER_API("_OpGetOutput")
.set_body([](TVMArgs args, TVMRetValue* ret) {
*ret = args[0].operator Operation().output(
......
/*!
* Copyright (c) 2018 by Contributors
* \brief Hybrid computation rule.
* \file hybrid_op.cc
*/
#include <tvm/operation.h>
#include <tvm/arithmetic.h>
#include <tvm/ir.h>
#include <tvm/ir_mutator.h>
#include <unordered_set>
#include "op_util.h"
namespace tvm {
using namespace ir;
// HybridOpNode
TVM_STATIC_IR_FUNCTOR(IRPrinter, vtable)
.set_dispatch<HybridOpNode>([](const HybridOpNode *op, IRPrinter *p) {
p->stream << "hybrid(" << op->name << ", " << op << ")";
});
TVM_REGISTER_NODE_TYPE(HybridOpNode);
int HybridOpNode::num_outputs() const {
return static_cast<int>(outputs.size());
}
Array<IterVar> HybridOpNode::root_iter_vars() const {
return {};
}
Type HybridOpNode::output_dtype(size_t i) const {
return outputs[i]->dtype;
}
Array<Expr> HybridOpNode::output_shape(size_t i) const {
return outputs[i]->shape;
}
Operation HybridOpNode::make(std::string name,
std::string tag,
Map<std::string, NodeRef> attrs,
Array<Tensor> inputs,
Array<Tensor> outputs,
Stmt body) {
if (!attrs.defined()) {
attrs = Map<std::string, NodeRef>();
}
auto n = make_node<HybridOpNode>();
n->name = std::move(name);
n->tag = std::move(tag);
n->attrs = std::move(attrs);
n->inputs = std::move(inputs);
n->outputs = std::move(outputs);
n->body = std::move(body);
Operation res = Operation(n);
return res;
}
Array<Tensor> HybridOpNode::InputTensors() const {
return inputs;
}
Operation HybridOpNode::ReplaceInputs(
const Operation& self,
const std::unordered_map<Tensor, Tensor>& rmap) const {
CHECK_EQ(self.operator->(), this);
auto n = make_node<HybridOpNode>(*this);
n->body = op::ReplaceTensor(this->body, rmap);
for (size_t i = 0; i < n->inputs.size(); ++i) {
Tensor t = n->inputs[i];
if (rmap.count(t)) {
n->inputs.Set(i, rmap.at(t));
}
}
if (body.same_as(n->body) &&
inputs.same_as(n->inputs)) {
return self;
} else {
return Operation(n);
}
}
void HybridOpNode::PropBoundToInputs(
const Operation& self,
const std::unordered_map<const Variable*, IntSet>& dom_map,
std::unordered_map<Tensor, TensorDom>* out_dom_map) const {
for (Tensor t : this->inputs) {
auto it = out_dom_map->find(t);
if (it == out_dom_map->end()) continue;
TensorDom& dom = it->second;
for (size_t i = 0; i < t->shape.size(); ++i) {
dom.data[i].emplace_back(IntSet::range(
Range::make_by_min_extent(
make_const(t->shape[i].type(), 0), t->shape[i])));
}
}
}
void HybridOpNode::GatherBound(
const Operation& self,
const std::unordered_map<Tensor, TensorDom>& tensor_dom,
std::unordered_map<IterVar, Range>* out_dom_map) const {
}
Stmt HybridOpNode::BuildRealize(
const Stage& stage,
const std::unordered_map<IterVar, Range>& realize_map,
const Stmt& body) const {
CHECK_EQ(stage->op.get(), this);
Stmt realize_body = body;
for (int k = 0; k < num_outputs(); ++k) {
Tensor t = stage->op.output(k);
HalideIR::Internal::Region bounds;
for (size_t i = 0; i < t->shape.size(); ++i) {
bounds.push_back(
Range::make_by_min_extent(
make_const(t->shape[i].type(), 0), t->shape[i]));
}
realize_body = ir::Realize::make(
t->op, t->value_index, t->dtype,
bounds, const_true(), realize_body);
}
return realize_body;
}
Stmt HybridOpNode::BuildProvide(
const Stage& stage,
const std::unordered_map<IterVar, Range>& dom_map,
bool debug_keep_trivial_loop) const {
CHECK_EQ(stage->op.operator->(), this);
Stmt ret = AttrStmt::make(make_zero(Int(32)), attr::extern_scope, 0, this->body);
auto f_push_bind = [&ret](Buffer buffer, Tensor tensor) {
Array<NodeRef> bind_spec;
Array<Expr> tuple;
bind_spec.push_back(buffer);
bind_spec.push_back(tensor);
for (size_t k = 0; k < buffer->shape.size(); ++k) {
tuple.push_back(make_const(buffer->shape[k].type(), 0));
tuple.push_back(buffer->shape[k]);
}
ret = AttrStmt::make(
bind_spec, attr::buffer_bind_scope,
Call::make(Handle(), intrinsic::tvm_tuple, tuple, Call::Intrinsic), ret);
};
for (int i = static_cast<int>(outputs.size()) - 1; i >= 0; --i) {
Buffer buffer = decl_buffer(
outputs[i]->shape,
outputs[i]->dtype);
f_push_bind(buffer, stage->op.output(i));
}
for (int i = static_cast<int>(inputs.size()) - 1; i >= 0; --i) {
Buffer buffer = decl_buffer(
inputs[i]->shape,
inputs[i]->dtype);
f_push_bind(buffer, inputs[i]);
}
std::unordered_map<Tensor, Tensor> rmap;
for (int i = 0; i < this->num_outputs(); ++i) {
rmap[outputs[i]] = stage->op.output(i);
}
auto n = make_node<HybridOpNode>(*this);
/*
* These two lines of code replace the tensors' reads & writes.
* This is the simplest way I (@were) can come up with to glue
* hybrid scripts to the structure of TVM op.
* NAMING CONFLICT: In hybrid script all the tensors have their own
* names specified by the users. However, in a TVM op, all the output
* tensors' names are the same as the op's name. I cannot change the
* name to the op's name in the function body after the op node is
* formed, because:
* 1. Output tensors all point to the corresponding op node.
* 2. Once OpNode is wrapped up by an Operation node, it can
* no longer be changed.
* This is a chicken-and-egg paradox. It is impossible to put the output
* tensors into the function body without forming the op node. The
* function body is immutable after the node is formed.
*
* Finally, I decided to resolve this issue "lazily". During the
* pipeline of compilation, these tensors will be replaced when
* forming the function body and passing it to the next stage of compilation.
* */
ret = op::ReplaceTensor(ret, rmap);
ret = op::ReplaceProvideTensor(ret, rmap);
return ret;
}
} // namespace tvm
......@@ -164,6 +164,37 @@ std::vector<Stmt> MakeIfNest(const std::vector<Expr>& predicates) {
return nest;
}
// replacer to replace tensors' usage in Provide
class ProviderReplacer : public ir::IRMutator {
public:
explicit ProviderReplacer(const std::unordered_map<Tensor, Tensor>& vmap)
: vmap_(vmap) {}
Stmt Mutate_(const ir::Provide* op, const Stmt& s) {
Tensor t = Operation(op->func.node_).output(op->value_index);
auto it = vmap_.find(t);
if (it != vmap_.end()) {
Stmt ret = ir::Provide::make(
it->second->op, it->second->value_index, op->value, op->args);
found = true;
return IRMutator::Mutate_(ret.as<ir::Provide>(), ret);
}
return IRMutator::Mutate_(op, s);
}
// whether it is found.
bool found{false};
private:
const std::unordered_map<Tensor, Tensor>& vmap_;
};
Stmt ReplaceProvideTensor(Stmt stmt,
const std::unordered_map<Tensor, Tensor>& replace) {
ProviderReplacer repl(replace);
Stmt ret = repl.Mutate(stmt);
return repl.found ? ret : stmt;
}
// replacer to replace tensors
class TensorReplacer : public ir::IRMutator {
......
......@@ -49,14 +49,22 @@ MakeLoopNest(const Stage& stage,
std::vector<Stmt> MakeIfNest(const std::vector<Expr>& predicates);
/*!
* \brief Replace the tensor reference in stmt by the replace map.
* \brief Replace the tensor reference (especially in Provide's) in stmt by the replace map.
* \param stmt The statement to be processed.
* \param replace The replacement rule.
*/
Stmt ReplaceProvideTensor(Stmt stmt,
const std::unordered_map<Tensor, Tensor>& replace);
/*!
* \brief Replace the tensor reference (especially in Call's) in stmt by the replace map.
* \param stmt The statement to be processed.
* \param replace The replacement rule.
*/
Stmt ReplaceTensor(Stmt stmt,
const std::unordered_map<Tensor, Tensor>& replace);
/*!
* \brief Replace the tensor reference in expr by the replace map.
* \brief Replace the tensor reference (especially in Call's) in expr by the replace map.
* \param expr The expression to be processed.
* \param replace The replacement rule.
*/
......
......@@ -3,7 +3,7 @@ from tvm.hybrid import script
from tvm.hybrid.intrin import HYBRID_GLOBALS
@nose.tools.nottest
def run_and_check(func, args, outs, var_dict={}, target='llvm'):
def run_and_check(func, args, var_dict={}, target='llvm'):
def tvm_val_2_py_val(val):
val = tvm.ir_pass.Substitute(val, var_dict)
val = tvm.ir_pass.Simplify(val)
......@@ -14,39 +14,50 @@ def run_and_check(func, args, outs, var_dict={}, target='llvm'):
emu_args = []
nd_args = []
to_check = []
for i in args:
if isinstance(i, tvm.tensor.Tensor):
shape = [tvm_val_2_py_val(j) for j in i.shape]
if i in outs:
emu_args.append(numpy.zeros(shape).astype(i.dtype))
nd_args.append(tvm.nd.array(emu_args[-1], ctx))
to_check.append((nd_args[-1], emu_args[-1]))
else:
emu_args.append(numpy.random.randn(*shape).astype(i.dtype))
nd_args.append(tvm.nd.array(emu_args[-1], ctx))
emu_args.append(numpy.random.randn(*shape).astype(i.dtype))
nd_args.append(tvm.nd.array(emu_args[-1], ctx))
else:
assert isinstance(i, tvm.expr.Var)
emu_args.append(tvm_val_2_py_val(i))
nd_args.append(emu_args[-1])
func(*emu_args)
lowerd_func = tvm.lower(func(*args), args)
module = tvm.build(lowerd_func, target=target)
outs = func(*args)
op = outs[0].op if isinstance(outs, list) else outs.op
sch = tvm.create_schedule(op)
module = tvm.build(sch, args + (outs if isinstance(outs, list) else [outs]), target=target)
assert module
out_tensors = []
for i in range(op.num_outputs):
output = op.output(i)
shape = [tvm_val_2_py_val(j) for j in output.shape]
nd_args.append(tvm.nd.array(numpy.zeros(shape).astype(output.dtype), ctx))
out_tensors.append(nd_args[-1])
ref_data = func(*emu_args)
if isinstance(ref_data, numpy.ndarray):
ref_data = [ref_data]
module(*nd_args)
for nd, np in to_check:
for nd, np in zip(out_tensors, ref_data):
tvm.testing.assert_allclose(nd.asnumpy(), np, rtol=1e-5, atol=1e-5)
@script
def outer_product(n, m, a, b, c):
"""This is a simple outer product"""
def outer_product(n, m, a, b):
"""This is a simple outer product.
Actually this function is not required to be documented.
I write this docstring to test skipping docstring functionality.
"""
c = output_tensor((n, m), a.dtype)
for i in range(n):
for j in range(m):
c[i, j] = a[i] * b[j]
return c
#Test global function
#Test bridge between frontend and backend
......@@ -55,8 +66,14 @@ def test_outer_product():
m = tvm.var('m')
a = tvm.placeholder((n, ), name='a')
b = tvm.placeholder((m, ), name='b')
c = tvm.placeholder((n, m), name='c')
ir = outer_product(n, m, a, b, c)
try:
c = outer_product(n, m, a, b)
ir = c.op.body
except IOError as err:
assert sys.version_info[0] == 2 and str(err) == 'could not get source code'
return
#Check for i in (0, n)
assert isinstance(ir, tvm.stmt.For)
assert ir.loop_var.name == 'i'
......@@ -81,10 +98,8 @@ def test_outer_product():
assert mul.a.name == 'a'
assert mul.b.name == 'b'
func = tvm.lower(ir, [n, m, a, b, c])
func = tvm.build(func)
run_and_check(outer_product, [n, m, a, b, c], [c], {n: 999, m: 1001})
run_and_check(outer_product, [n, m, a, b], {n: 99, m: 101})
for key, _ in HYBRID_GLOBALS.items():
assert key not in globals().keys()
......@@ -94,19 +109,25 @@ def test_outer_product():
#Test allocation of local variable
def test_fanout():
@script
def fanout(n, a, b):
def fanout(n, a):
three = 3.0
b = output_tensor((a.shape[0] - 3, ), a.dtype)
for i in range(a.shape[0] - 3):
sigma = 0.0
for j in range(3):
sigma = sigma + a[i + j]
sigma = sigma / three
b[i] = sigma
return b
n = tvm.var('n')
a = tvm.placeholder((n, ), 'float32', name='a')
b = tvm.placeholder((n-3, ), 'float32', name='b')
ir = fanout(n, a, b)
try:
b = fanout(n, a)
ir = b.op.body
except IOError as err:
assert sys.version_info[0] == 2 and str(err) == 'could not get source code'
return
#Check for i in (0, n-3)
assert isinstance(ir, tvm.stmt.For)
......@@ -163,38 +184,31 @@ def test_fanout():
assert len(write.value.args) == 1
assert write.value.args[0].value == 0
run_and_check(fanout, [n, a, b], [b], {n: 10})
@script
def failure():
for i in range(1, 100):
i = 0
def test_failure():
try:
tvm.hybrid.parse(failure, [])
except IOError as err:
assert sys.version_info[0] == 2
print('[Warning] Case test_failure is skipped by Python2 because "%s"' % str(err))
except Exception as err:
assert str(err) == 'You CAN NEVER overwrite a loop variable!'
run_and_check(fanout, [n, a], {n: 10})
def test_looptype():
@script
def looptype(a, b, c):
d = output_tensor((8, ), 'int32')
e = output_tensor((8, ), 'int32')
f = output_tensor((8, ), 'int32')
for i in parallel(8):
a[i] = i
d[i] = a[i]
for j in vectorize(8):
b[j] = j
e[j] = b[j]
for k in unroll(8):
c[k] = k
f[k] = c[k]
return d, e, f
a = tvm.placeholder((8, ), name='a', dtype='int32')
b = tvm.placeholder((8, ), name='b', dtype='int32')
c = tvm.placeholder((8, ), name='c', dtype='int32')
ir = looptype(a, b, c)
try:
d, e, f = looptype(a, b, c)
ir = d.op.body
except:
return
iloop = ir.first
jloop = ir.rest.first
kloop = ir.rest.rest
......@@ -202,24 +216,26 @@ def test_looptype():
assert jloop.for_type == tvm.stmt.For.Vectorized
assert kloop.for_type == tvm.stmt.For.Unrolled
run_and_check(looptype, [a, b, c], [a, b, c])
run_and_check(looptype, [a, b, c])
def test_if():
@script
def if_then_else(a, b):
def if_then_else(a):
b = output_tensor((10, ), 'int32')
c = output_tensor((10, ), 'int32')
for i in range(10):
if i % 2 == 0:
a[i] = -1
c[i] = a[i]
else:
a[i] = 1
c[i] = b[i]
for i in unroll(10):
b[i] = -1 if i % 2 == 0 else 1
return b, c
a = tvm.placeholder((10, ), dtype='int32', name='a')
b = tvm.placeholder((10, ), dtype='int32', name='b')
run_and_check(if_then_else, [a, b], [a, b])
run_and_check(if_then_else, [a])
def test_bind():
......@@ -227,55 +243,66 @@ def test_bind():
print('[Warning] No GPU found! Skip bind test!')
return
@script
def vec_add(a, b, c):
def vec_add(a, b):
c = output_tensor((1000, ), dtype='float32')
for tx in bind('threadIdx.x', 1000):
c[tx] = b[tx] + c[tx]
return c
a = tvm.placeholder((1000, ), dtype='float32', name='a')
b = tvm.placeholder((1000, ), dtype='float32', name='b')
c = tvm.placeholder((1000, ), dtype='float32', name='c')
run_and_check(vec_add, [a, b, c], [c], target='cuda')
run_and_check(vec_add, [a, b], target='cuda')
def test_math_intrin():
@script
def intrin_real(a):
a[0] = sqrt(a[0])
a[1] = log(a[1])
a[2] = exp(a[2])
a[3] = sigmoid(a[3])
a[4] = power(a[4], a[5])
a[5] = tanh(a[5])
a[6] = min(a[4], a[5])
a[7] = max(a[5], a[6])
b = output_tensor((8, ), 'float32')
b[0] = sqrt(a[0])
b[1] = log(a[1])
b[2] = exp(a[2])
b[3] = sigmoid(a[3])
b[4] = power(a[4], a[5])
b[5] = tanh(a[5])
b[6] = min(a[4], a[5])
b[7] = max(a[5], a[6])
return b
a8 = tvm.placeholder((8, ), dtype='float32', name='a')
ir = intrin_real(a8)
func = tvm.build(tvm.lower(ir, [a8]))
b8 = intrin_real(a8)
sch = tvm.create_schedule(b8.op)
func = tvm.build(sch, [a8, b8])
assert func
a = numpy.arange(2, 10).astype('float32')
tvm_a = tvm.ndarray.array(a)
func(tvm_a)
intrin_real(a)
tvm.testing.assert_allclose(a, tvm_a.asnumpy(), rtol=1e-5)
tvm_b = tvm.ndarray.array(numpy.zeros((8, ), dtype='float32'))
b = intrin_real(a)
func(tvm_a, tvm_b)
tvm.testing.assert_allclose(b, tvm_b.asnumpy(), rtol=1e-5)
@script
def intrin_int(a):
a[0] = popcount(a[0])
b = output_tensor((1, ), 'int32')
b[0] = popcount(a[0])
return b
a1 = tvm.placeholder((1, ), dtype='int32')
ir = intrin_int(a1)
func = tvm.build(tvm.lower(ir, [a1]))
b1 = intrin_int(a1)
sch = tvm.create_schedule(b1.op)
func = tvm.build(sch, [a1, b1])
assert func
a = numpy.array([1234567890]).astype('int32')
a = numpy.array([114514]).astype('int32')
tvm_a = tvm.ndarray.array(a)
intrin_int(a)
func(tvm_a)
assert tvm_a.asnumpy()[0] == a[0]
tvm_b = tvm.ndarray.array(numpy.array([0]).astype('int32'))
b = intrin_int(a)
func(tvm_a, tvm_b)
assert tvm_b.asnumpy()[0] == b[0]
# test non-canonical loops
def test_non_zero():
@tvm.hybrid.script
def blur(a, b):
def blur(a):
b = output_tensor((30, 30), 'float32')
for i in range(2, 32):
for j in range(2, 32):
s = 0.0
......@@ -283,29 +310,28 @@ def test_non_zero():
for dj in range(3):
s = s + a[i-di, j-dj]
b[i-2, j-2] = s / 9.0
try:
a = tvm.placeholder((32, 32), 'float32', 'a')
b = tvm.placeholder((30, 30), 'float32', 'b')
run_and_check(blur, [a, b], [b])
except IOError as err:
assert sys.version_info[0] == 2
print('[Warning] Case test_non_zero is skipped by Python2 because "%s"' % str(err))
return b
a = tvm.placeholder((32, 32), 'float32', 'a')
run_and_check(blur, [a])
@tvm.hybrid.script
def triangle(a, b, c):
def triangle(a, b):
c = output_tensor((10, 10), dtype='float32')
for i in range(10):
for j in range(i, 10):
c[i, j] = a[i] * b[j]
return c
a = tvm.placeholder((10, ), dtype='float32', name='a')
b = tvm.placeholder((10, ), dtype='float32', name='b')
c = tvm.placeholder((10, 10), dtype='float32', name='c')
run_and_check(triangle, [a, b, c], [c])
run_and_check(triangle, [a, b])
def test_allocate():
@tvm.hybrid.script
def blur2d(a, b):
def blur2d(a):
b = output_tensor((30, 30), 'float32')
for i in range(30):
ha = allocate((3, 30), 'float32')
for j in range(3):
......@@ -313,15 +339,15 @@ def test_allocate():
ha[j, k] = a[i+j, k] + a[i+j, k+1] + a[i+j, k+2]
for j in range(30):
b[i, j] = (ha[0, j] + ha[1, j] + ha[2, j]) / 9.0
return b
a = tvm.placeholder((32, 32), 'float32', 'a')
b = tvm.placeholder((30, 30), 'float32', 'b')
run_and_check(blur2d, [a, b], [b])
run_and_check(blur2d, [a])
if tvm.gpu().exist:
@tvm.hybrid.script
def share_vec_add(a, b, c):
def share_vec_add(a, b):
c = output_tensor((256, ), 'float32')
shared = allocate((256, ), 'float32', 'shared')
for i in bind("threadIdx.x", 256):
shared[i] = a[i]
......@@ -330,23 +356,81 @@ def test_allocate():
local[i] = b[i]
for i in bind("threadIdx.x", 256):
c[i] = shared[i] + local[i]
return c
a = tvm.placeholder((256, ), dtype='float32', name='a')
b = tvm.placeholder((256, ), dtype='float32', name='b')
c = tvm.placeholder((256, ), dtype='float32', name='c')
run_and_check(share_vec_add, [a, b, c], [c], target='cuda')
run_and_check(share_vec_add, [a, b], target='cuda')
else:
print('[Warning] No GPU found! Skip shared mem test!')
def test_upstream():
@tvm.hybrid.script
def upstream(a):
b = output_tensor((20, ), 'float32')
for i in range(20):
b[i] = a[i] * i
return b
a = tvm.placeholder((20, ), 'float32')
b = tvm.placeholder((20, ), 'float32')
c = tvm.compute((20, ), lambda x: a[x] + b[x])
d = upstream(c)
sch = tvm.create_schedule([c.op, d.op])
ir = tvm.lower(sch, [a, b, d], simple_mode=True)
func = tvm.build(sch, [a, b, d])
assert(func)
a = numpy.random.randn(20).astype('float32')
b = numpy.random.randn(20).astype('float32')
ref = numpy.zeros((20, ), 'float32')
for i in range(20):
ref[i] = (a[i] + b[i]) * i
tvm_a = tvm.nd.array(a)
tvm_b = tvm.nd.array(b)
tvm_d = tvm.nd.array(numpy.zeros((20, )).astype('float32'))
func(tvm_a, tvm_b, tvm_d)
tvm.testing.assert_allclose(tvm_d.asnumpy(), ref, 1e-5, 1e-5)
def test_downstream():
@tvm.hybrid.script
def downstream(a):
b = output_tensor((20, ), 'float32')
for i in range(20):
b[i] = a[i] * i
return b
a = tvm.placeholder((20, ), 'float32')
b = downstream(a)
c = tvm.compute((20, ), lambda x: b[x] + 1.0)
sch = tvm.create_schedule(c.op)
module = tvm.build(sch, [a, c])
assert module
a = numpy.random.randn(20).astype('float32')
ref = numpy.zeros((20, )).astype('float32')
for i in range(20):
ref[i] = (a[i] * i) + 1.0
tvm_a = tvm.nd.array(a)
tvm_c = tvm.nd.array(numpy.zeros((20, )).astype('float32'))
module(tvm_a, tvm_c)
tvm.testing.assert_allclose(tvm_c.asnumpy(), ref, 1e-5, 1e-5)
if __name__ == "__main__":
test_outer_product()
test_fanout()
test_failure()
test_looptype()
test_if()
test_bind()
test_math_intrin()
test_non_zero()
test_allocate()
#test_inplace()
test_upstream()
test_downstream()