Commit c53dd102 by Jian Weng Committed by Leyuan Wang

[Hybrid script] Backend support (#2477)

* a preliminary version is done?

* we no longer need the redundant hybrid/api.py

* support assert stmt

* cast supported

* intrin -> runtime; util is mainly in charge of compilation time

* assert statement

* fix python lint

* fix cpp lint

* on the way to module

* rollback .cc

* fix typo, no direct expose then

* @vinx13 ceil is added i guess?

* wip...

* temp commit

* fix import

* i preliminary version is done?

* on the way to build hybrid module

* nearly fixed...

* dumped python are equiv as original python

* on the way to bootstrap

* cpu bootstrap done

* bootstrap!

* fix lint

* fix doc

* resolve some review concerns

* support load/save

* fix lint

* thanks to xqdan fixed my typo

* fix build, make dump non-optional

* add vthread

* jesus why i added this
parent 7e2a9fcf
...@@ -190,6 +190,7 @@ include(cmake/modules/contrib/BLAS.cmake) ...@@ -190,6 +190,7 @@ include(cmake/modules/contrib/BLAS.cmake)
include(cmake/modules/contrib/Random.cmake) include(cmake/modules/contrib/Random.cmake)
include(cmake/modules/contrib/Sort.cmake) include(cmake/modules/contrib/Sort.cmake)
include(cmake/modules/contrib/NNPack.cmake) include(cmake/modules/contrib/NNPack.cmake)
include(cmake/modules/contrib/HybridDump.cmake)
add_library(tvm SHARED ${COMPILER_SRCS} ${RUNTIME_SRCS}) add_library(tvm SHARED ${COMPILER_SRCS} ${RUNTIME_SRCS})
add_library(tvm_topi SHARED ${TOPI_SRCS}) add_library(tvm_topi SHARED ${TOPI_SRCS})
......
message(STATUS "Build with contrib.hybriddump")
file(GLOB HYBRID_CONTRIB_SRC src/contrib/hybrid/*.cc)
list(APPEND COMPILER_SRCS ${HYBRID_CONTRIB_SRC})
...@@ -197,6 +197,20 @@ You can also do loop-thread bind by writing code like this: ...@@ -197,6 +197,20 @@ You can also do loop-thread bind by writing code like this:
a[tx] = b[tx] a[tx] = b[tx]
Assert Statement
~~~~~~~~~~~~~~~~
Assert statement is supported, you can simply use it as it is in standard Python.
.. code-block:: python
assert cond, mesg
.. note::
``Assert`` is NOT a function call. Users are encouraged to use assert in the way
presented above --- condition followed by message. It fits both Python AST and HalideIR.
Keywords Keywords
~~~~~~~~ ~~~~~~~~
- For keywords: ``serial``, ``range``, ``unroll``, ``parallel``, ``vectorize``, ``bind``, ``const_expr`` - For keywords: ``serial``, ``range``, ``unroll``, ``parallel``, ``vectorize``, ``bind``, ``const_expr``
......
...@@ -292,6 +292,25 @@ def get_binds(args, binds=None): ...@@ -292,6 +292,25 @@ def get_binds(args, binds=None):
return binds, arg_list return binds, arg_list
def form_body(sch):
"""According to the given schedule, form the raw body
Parameters
----------
sch : tvm.schedule.Schedule
The given scheduler to form the raw body
Returns
-------
The body formed according to the given schedule
"""
# normalize schedule first
sch = sch.normalize()
bounds = schedule.InferBound(sch)
stmt = schedule.ScheduleOps(sch, bounds)
stmt = ir_pass.InjectPrefetch(stmt)
return stmt
def lower(sch, def lower(sch,
args, args,
name="default_function", name="default_function",
...@@ -337,11 +356,7 @@ def lower(sch, ...@@ -337,11 +356,7 @@ def lower(sch,
# Phase 0 # Phase 0
if isinstance(sch, schedule.Schedule): if isinstance(sch, schedule.Schedule):
# normalize schedule first stmt = form_body(sch)
sch = sch.normalize()
bounds = schedule.InferBound(sch)
stmt = schedule.ScheduleOps(sch, bounds)
stmt = ir_pass.InjectPrefetch(stmt)
for f in lower_phase0: for f in lower_phase0:
stmt = f(stmt) stmt = f(stmt)
......
...@@ -4,8 +4,77 @@ This package maps a subset of python to HalideIR so that: ...@@ -4,8 +4,77 @@ This package maps a subset of python to HalideIR so that:
1. Users can write some preliminary versions of the computation patterns 1. Users can write some preliminary versions of the computation patterns
have not been supported yet and verify it across the real execution and have not been supported yet and verify it across the real execution and
python semantic emulation. python semantic emulation.
2. Developers can build HalideIR by writing Python code. 2. So far, it is a text format dedicated to HalideIR Phase 0. Refer tvm.lower
for more details. A larger ambition of this module is to support all levels of
HalideIR.
""" """
from .api import script # TODO(@were): Make this module more complete.
from .parser import parse_python # 1. Support HalideIR dumping to Hybrid Script
# 2. Support multi-level HalideIR
from __future__ import absolute_import as _abs
from .._ffi.base import decorate
from .._ffi.function import _init_api
from ..build_module import form_body
from .module import HybridModule
from .parser import source_to_op
from .util import _pruned_source
def script(pyfunc):
"""Decorate a python function function as hybrid script.
The hybrid function support emulation mode and parsing to
the internal language IR.
Returns
-------
hybrid_func : function
A decorated hybrid script function.
"""
def wrapped_func(func, *args, **kwargs): #pylint: disable=missing-docstring
from .util import _is_tvm_arg_types
if _is_tvm_arg_types(args):
src = _pruned_source(func)
return source_to_op(src, func.__globals__, args)
from .runtime import _enter_hybrid_runtime, _restore_runtime
intersect = _enter_hybrid_runtime(func)
value = func(*args, **kwargs)
_restore_runtime(func, intersect)
return value
return decorate(pyfunc, wrapped_func)
def build(sch, inputs, outputs, name="hybrid_func"):
"""Dump the corrent schedule to hybrid module
Parameters
----------
sch: Schedule
The schedule to be dumped
inputs: An array of Tensors or Vars
The inputs of the function body
outputs: An array of Tensors
The outputs of the function body
Returns
-------
module: HybridModule
The built results is wrapped in a HybridModule.
The usage of HybridModule is roughly the same as normal TVM-built modules.
"""
stmt = form_body(sch)
src = _Dump(stmt, inputs, outputs, name)
return HybridModule(src, name)
_init_api("tvm.hybrid")
"""APIs of lowering the Python subset to HalideIR"""
from __future__ import absolute_import as _abs
from .._ffi.base import decorate
from .. import _api_internal as _tvm_internal
from ..tensor import Tensor
from .parser import parse_python
from .util import _pruned_source
def script(pyfunc):
"""Decorate a python function function as hybrid script.
The hybrid function support emulation mode and parsing to
the internal language IR.
Returns
-------
hybrid_func : function
A decorated hybrid script function.
"""
def wrapped_func(func, *args, **kwargs): #pylint: disable=missing-docstring
from .util import _enter_hybrid_runtime, _restore_runtime, _is_tvm_arg_types
if _is_tvm_arg_types(args):
src = _pruned_source(func)
parser = parse_python(src, func.__globals__, args)
input_tensors = []
for i in args:
if isinstance(i, Tensor):
input_tensors.append(i)
op = _tvm_internal._HybridOp(parser.func_name, "HybridOp", None, input_tensors,
parser.outputs, parser.parsed_body)
res = [op.output(i) for i in range(len(parser.outputs))]
return res[0] if len(res) == 1 else res
intersect = _enter_hybrid_runtime(func)
value = func(*args, **kwargs)
_restore_runtime(func, intersect)
return value
return decorate(pyfunc, wrapped_func)
...@@ -8,6 +8,7 @@ from ..container import Array ...@@ -8,6 +8,7 @@ from ..container import Array
from .. import ir_pass from .. import ir_pass
from ..stmt import For from ..stmt import For
from .util import _internal_assert from .util import _internal_assert
from ..intrin import call_pure_intrin
#pylint: disable=redefined-builtin #pylint: disable=redefined-builtin
...@@ -104,3 +105,29 @@ def len(func_id, args): ...@@ -104,3 +105,29 @@ def len(func_id, args):
except: #pylint: disable=bare-except except: #pylint: disable=bare-except
_internal_assert(args[0].shape.__len__() == 1, "Only one-dimension array can get len") _internal_assert(args[0].shape.__len__() == 1, "Only one-dimension array can get len")
return _api.convert(args[0].shape[0]) return _api.convert(args[0].shape[0])
def _cast(func_id, args):
_internal_assert(args.__len__() == 1 and isinstance(args[0], _expr.Expr), \
"Only one expression can be cast")
return _make.Cast(func_id, args[0])
float16 = float32 = float64 = _cast #pylint: disable=invalid-name
int8 = int16 = int32 = int64 = _cast #pylint: disable=invalid-name
uint8 = uint16 = uint32 = uint64 = _cast #pylint: disable=invalid-name
def ceil_div(func_id, args):
_internal_assert(func_id == "ceil_div", "This function cannot be directly invoked!")
_internal_assert(args.__len__() == 2, "2 arguments expected for division!")
_internal_assert(isinstance(args[0], _expr.Expr), "Only expressions can div")
_internal_assert(isinstance(args[1], _expr.Expr), "Only expressions can div")
a, b = args[0], args[1]
return (a + b - 1) / b
def likely(func_id, args):
_internal_assert(args.__len__() == 1, \
"Only one expression can be likely")
_internal_assert(func_id == "likely", "This function cannot be directly invoked!")
return call_pure_intrin(args[0].dtype, 'likely', *args)
"""Methods and data structures to support dumping HalideIR to Hybrid Script.
This allows users to do quick hack to generated HalideIR and cast it back to
TVM modules.
To enable this feature, you need to build with -DUSE_HYBRID_DUMP=ON.
"""
import ast
import imp
from ..contrib import util
from .util import _internal_assert
from .util import _is_tvm_arg_types
from .parser import source_to_op
class HybridModule(object):
"""The usage of Hybrid Module is very similar to conventional TVM module,
but conventional TVM module requires a function body which is already fully
lowered. This contradicts to the fact that Hybrid Module is originally a text
format for Phase 0 HalideIR. Thus, a totally separated module is defined."""
def __init__(self, src=None, name=None):
"""The constructor of this a hybrid module
Parameters
----------
src : str
The source code of this module
name : str
The name of this module
"""
self.src_ = self.name = self.func_ = self.root_ = None
if src is not None:
temp = util.tempdir()
dst = temp.relpath("script.py")
with open(dst, 'w') as f:
f.write("import tvm\n@tvm.hybrid.script\n%s" % src)
if name is not None:
self.name = name
self.load(dst)
def __call__(self, *args):
if _is_tvm_arg_types(args):
return source_to_op(self.root_, globals(), args)
return self.func_(*args)
def get_source(self):
return self.src_
def save(self, path):
if not path.endswith('.py'):
path = path + '.py'
with open(path, 'w') as f:
f.write(self.src_)
def load(self, path):
"""Load the module from a python file
Parameters
----------
path : str
Path to the given python file
"""
with open(path, 'r') as f:
self.src_ = f.read()
src = self.src_
class FindFunc(ast.NodeVisitor):
""" Find the function in module to be loaded module. """
#pylint: disable=invalid-name
def __init__(self):
self.name = None
self.root = None
def visit_FunctionDef(self, node):
_internal_assert(self.name is None, "For now, only one function supported!")
self.name = node.name
_internal_assert(self.root is None, "For now, only one function supported!")
self.root = node
root = ast.parse(src)
finder = FindFunc()
finder.visit(root)
_internal_assert(finder.name is not None and finder.root is not None, \
"No function found!")
if self.name is None:
self.name = finder.name
self.root_ = finder.root
py_module = imp.load_source(self.name, path)
self.func_ = getattr(py_module, self.name)
...@@ -17,28 +17,36 @@ from ..api import all as _all ...@@ -17,28 +17,36 @@ from ..api import all as _all
from ..api import any as _any from ..api import any as _any
from ..container import Array from ..container import Array
from ..tensor import Tensor, Operation from ..tensor import Tensor, Operation
from .. import _api_internal as _tvm_internal
from .. import expr as _expr from .. import expr as _expr
from .. import stmt as _stmt
from .. import make as _make from .. import make as _make
from .. import api as _api from .. import api as _api
from .. import ir_pass as _ir_pass from .. import ir_pass as _ir_pass
def pack_list_to_block(lst): def concat_list_to_block(lst):
if len(lst) == 1: """Concatenate a list of Python IR nodes to HalideIR Block"""
n = len(lst)
if n == 1:
return lst[0] return lst[0]
body = lst[0] body = lst[n - 1]
for i in lst[1:]: for i in range(1, n):
body = _make.Block(body, i) stmt = lst[n - 1 - i]
if isinstance(stmt, _stmt.AssertStmt):
body = _make.AssertStmt(stmt.condition, stmt.message, body)
else:
body = _make.Block(stmt, body)
return body return body
def visit_list_to_block(visit, lst): def visit_list_to_block(visit, lst):
"""Convert a list of Python IR nodes to HalideIR Block""" """Visit and concatenate a list of Python IR nodes to HalideIR Block"""
lst = [visit(stmt) for stmt in lst if not util.is_docstring(stmt)] lst = [visit(stmt) for stmt in lst if not util.is_docstring(stmt)]
lst = [stmt for stmt in lst if not _ir_pass.Equal(stmt, util.make_nop())] lst = [stmt for stmt in lst if not _ir_pass.Equal(stmt, util.make_nop())]
if not lst: if not lst:
return util.make_nop() return util.make_nop()
return pack_list_to_block(lst) return concat_list_to_block(lst)
class Symbol(Enum): class Symbol(Enum):
...@@ -441,7 +449,7 @@ class HybridParser(ast.NodeVisitor): ...@@ -441,7 +449,7 @@ class HybridParser(ast.NodeVisitor):
body = visit_list_to_block(self.visit, node.body) body = visit_list_to_block(self.visit, node.body)
body = self.wrap_up_realize(node, body) body = self.wrap_up_realize(node, body)
bodies.append(body) bodies.append(body)
return pack_list_to_block(bodies) return concat_list_to_block(bodies)
elif iter_var is None: elif iter_var is None:
_internal_assert(for_type is not None, "The loop bind function parse error!") _internal_assert(for_type is not None, "The loop bind function parse error!")
...@@ -496,15 +504,22 @@ class HybridParser(ast.NodeVisitor): ...@@ -496,15 +504,22 @@ class HybridParser(ast.NodeVisitor):
return node.s return node.s
def visit_Assert(self, node):
test = self.visit(node.test)
mesg = _api.convert(self.visit(node.msg))
return _make.AssertStmt(test, mesg, util.make_nop())
def parse_python(src, symbols, args): def parse_python(src, symbols, args):
"""The helper function of calling the AST visitor """The helper function of calling the AST visitor
Parameters Parameters
---------- ----------
src : str src : ast.node or str
The source code of the function to be parsed. If an ast.node, then directly lower it.
If a str, then parse it to ast and lower it.
src : str symbols : str
The symbol list of the global context of the function. The symbol list of the global context of the function.
args : list of Tensors or Vars args : list of Tensors or Vars
...@@ -517,9 +532,44 @@ def parse_python(src, symbols, args): ...@@ -517,9 +532,44 @@ def parse_python(src, symbols, args):
root : Stmt root : Stmt
The result Halide IR and the parser class instance. The result Halide IR and the parser class instance.
""" """
root = ast.parse(src) root = ast.parse(src) if isinstance(src, str) else src
_internal_assert(root, ast.AST)
var_usage = determine_variable_usage(root, args, symbols) var_usage = determine_variable_usage(root, args, symbols)
parser = HybridParser(args, var_usage, symbols) parser = HybridParser(args, var_usage, symbols)
parser.parsed_body = parser.visit(root) parser.parsed_body = parser.visit(root)
_internal_assert(parser.returned, 'No valid return found in the function body!') _internal_assert(parser.returned, 'No valid return found in the function body!')
return parser return parser
def source_to_op(src, symbols, args):
"""Another level of wrapper
Parameters
----------
src : ast.node or str
If an ast.node, then directly lower it.
If a str, then parse it to ast and lower it.
symbols : str
The symbol list of the global context of the function.
args : list of Tensors or Vars
The argument lists to the function.
It is NOT encouraged to write a function without arguments.
It is NOT encouraged to write a function with side effect.
Returns
-------
res : list of output tensors
The result of output tensors of the formed OpNode.
"""
parser = parse_python(src, symbols, args)
input_tensors = []
for i in args:
if isinstance(i, Tensor):
input_tensors.append(i)
op = _tvm_internal._HybridOp(parser.func_name, "HybridOp", None, input_tensors,
parser.outputs, parser.parsed_body)
res = [op.output(i) for i in range(len(parser.outputs))]
return res[0] if len(res) == 1 else res
...@@ -73,7 +73,6 @@ def sigmoid(x): ...@@ -73,7 +73,6 @@ def sigmoid(x):
HYBRID_GLOBALS = { HYBRID_GLOBALS = {
'len' : len,
'unroll' : range, 'unroll' : range,
'vectorize' : range, 'vectorize' : range,
'parallel' : range, 'parallel' : range,
...@@ -88,4 +87,37 @@ HYBRID_GLOBALS = { ...@@ -88,4 +87,37 @@ HYBRID_GLOBALS = {
'exp' : numpy.exp, 'exp' : numpy.exp,
'sigmoid' : sigmoid, 'sigmoid' : sigmoid,
'popcount' : popcount, 'popcount' : popcount,
'likely' : lambda cond: cond,
'uint8' : numpy.uint8,
'uint16' : numpy.uint16,
'uint32' : numpy.uint32,
'uint64' : numpy.uint64,
'int8' : numpy.int8,
'int16' : numpy.int16,
'int32' : numpy.int32,
'int64' : numpy.int64,
'float16' : numpy.float16,
'float32' : numpy.float32,
'float64' : numpy.float64,
'ceil_div' : lambda a, b: (a + b - 1) / b
} }
def _enter_hybrid_runtime(func):
"""Put hybrid runtime variables into the global scope"""
_globals = func.__globals__
intersect = []
for elem in list(HYBRID_GLOBALS.keys()):
if elem in _globals.keys():
intersect.append((elem, _globals[elem]))
_globals[elem] = HYBRID_GLOBALS[elem]
return intersect
def _restore_runtime(func, intersect):
"""Rollback the modification caused by hybrid runtime"""
_globals = func.__globals__
for elem in list(HYBRID_GLOBALS.keys()):
_globals.pop(elem)
for k, v in intersect:
_globals[k] = v
...@@ -5,14 +5,13 @@ import inspect ...@@ -5,14 +5,13 @@ import inspect
import logging import logging
import sys import sys
import numpy import numpy
from .intrin import HYBRID_GLOBALS
from .._ffi.base import numeric_types
from .. import api as _api from .. import api as _api
from .. import make as _make from .. import make as _make
from .. import expr as _expr from .. import expr as _expr
from .. import stmt as _stmt from .. import stmt as _stmt
from ..container import Array from .._ffi.base import numeric_types
from ..tensor import Tensor from ..tensor import Tensor
from ..container import Array
#pylint: disable=invalid-name #pylint: disable=invalid-name
...@@ -20,6 +19,7 @@ np_arg_types = tuple(list(numeric_types) + [numpy.ndarray]) ...@@ -20,6 +19,7 @@ np_arg_types = tuple(list(numeric_types) + [numpy.ndarray])
tvm_arg_types = (Tensor, Array, _expr.Var, _expr.ConstExpr) tvm_arg_types = (Tensor, Array, _expr.Var, _expr.ConstExpr)
halide_imm_types = (_expr.IntImm, _expr.FloatImm, _expr.UIntImm) halide_imm_types = (_expr.IntImm, _expr.FloatImm, _expr.UIntImm)
def _internal_assert(cond, err): def _internal_assert(cond, err):
"""Simplify the code segment like if not XXX then raise an error""" """Simplify the code segment like if not XXX then raise an error"""
if not cond: if not cond:
...@@ -52,6 +52,23 @@ def _pruned_source(func): ...@@ -52,6 +52,23 @@ def _pruned_source(func):
raise err raise err
def replace_io(body, rmap):
"""Replacing tensors usage according to the dict given"""
from .. import ir_pass
def replace(op):
if isinstance(op, _stmt.Provide) and op.func in rmap.keys():
buf = rmap[op.func]
return _make.Provide(buf.op, op.value_index, op.value, op.args)
elif isinstance(op, _expr.Call) and op.func in rmap.keys():
buf = rmap[op.func]
return _make.Call(buf.dtype, buf.name, op.args, \
_expr.Call.Halide, buf.op, buf.value_index)
return None
return ir_pass.IRTransform(body, None, replace, ['Provide', 'Call'])
def _is_tvm_arg_types(args): def _is_tvm_arg_types(args):
"""Determine a list of element is either a list of tvm arguments of a list of numpy arguments. """Determine a list of element is either a list of tvm arguments of a list of numpy arguments.
If neither is true, raise a value error.""" If neither is true, raise a value error."""
...@@ -68,40 +85,3 @@ def _is_tvm_arg_types(args): ...@@ -68,40 +85,3 @@ def _is_tvm_arg_types(args):
_internal_assert(isinstance(elem, np_arg_types), \ _internal_assert(isinstance(elem, np_arg_types), \
"Expect a numpy type but %s get!" % str(type(elem))) "Expect a numpy type but %s get!" % str(type(elem)))
return False return False
def _enter_hybrid_runtime(func):
"""Put hybrid runtime variables into the global scope"""
_globals = func.__globals__
intersect = []
for elem in list(HYBRID_GLOBALS.keys()):
if elem in _globals.keys():
intersect.append((elem, _globals[elem]))
_globals[elem] = HYBRID_GLOBALS[elem]
return intersect
def _restore_runtime(func, intersect):
"""Rollback the modification caused by hybrid runtime"""
_globals = func.__globals__
for elem in list(HYBRID_GLOBALS.keys()):
_globals.pop(elem)
for k, v in intersect:
_globals[k] = v
def replace_io(body, rmap):
"""Replacing tensors usage according to the dict given"""
from .. import ir_pass
def replace(op):
if isinstance(op, _stmt.Provide) and op.func in rmap.keys():
buf = rmap[op.func]
return _make.Provide(buf.op, op.value_index, op.value, op.args)
elif isinstance(op, _expr.Call) and op.func in rmap.keys():
buf = rmap[op.func]
return _make.Call(buf.dtype, buf.name, op.args, \
_expr.Call.Halide, buf.op, buf.value_index)
return None
return ir_pass.IRTransform(body, None, replace, ['Provide', 'Call'])
...@@ -2,7 +2,7 @@ ...@@ -2,7 +2,7 @@
import ast import ast
import sys import sys
from .intrin import HYBRID_GLOBALS from .runtime import HYBRID_GLOBALS
from .util import _internal_assert from .util import _internal_assert
...@@ -45,7 +45,7 @@ class PyVariableUsage(ast.NodeVisitor): ...@@ -45,7 +45,7 @@ class PyVariableUsage(ast.NodeVisitor):
_internal_assert(isinstance(node.func, ast.Name), "Function call should be an id") _internal_assert(isinstance(node.func, ast.Name), "Function call should be an id")
func_id = node.func.id func_id = node.func.id
_internal_assert(func_id in list(HYBRID_GLOBALS.keys()) + \ _internal_assert(func_id in list(HYBRID_GLOBALS.keys()) + \
['range', 'max', 'min'] + \ ['range', 'max', 'min', 'len'] + \
list(self.symbols.keys()), \ list(self.symbols.keys()), \
"Function call id not in intrinsics' list") "Function call id not in intrinsics' list")
for elem in node.args: for elem in node.args:
......
...@@ -103,6 +103,8 @@ Target CreateTarget(const std::string& target_name, ...@@ -103,6 +103,8 @@ Target CreateTarget(const std::string& target_name,
t->device_type = kDLCPU; t->device_type = kDLCPU;
} else if (target_name == "ext_dev") { } else if (target_name == "ext_dev") {
t->device_type = kDLExtDev; t->device_type = kDLExtDev;
} else if (target_name == "hybrid") {
t->device_type = kDLCPU;
} else { } else {
LOG(ERROR) << "Unknown target name " << target_name; LOG(ERROR) << "Unknown target name " << target_name;
return target::stackvm(); return target::stackvm();
......
/*! Copyright (c) 2019 by Contributors
* \file codegen_hybrid.cc
*/
#include <iomanip>
#include <cctype>
#include "codegen_hybrid.h"
namespace tvm {
namespace contrib {
using namespace ir;
std::string dot_to_underscore(std::string s) {
for (auto &ch : s)
if (ch == '.') ch = '_';
return s;
}
std::string CodeGenHybrid::GetUniqueName(std::string prefix) {
prefix = dot_to_underscore(prefix);
auto it = ids_allocated_.find(prefix);
if (it != ids_allocated_.end()) {
while (true) {
std::ostringstream os;
os << prefix << (++it->second);
std::string name = os.str();
if (ids_allocated_.count(name) == 0) {
prefix = name;
break;
}
}
}
ids_allocated_[prefix] = 0;
return prefix;
}
std::string CodeGenHybrid::Finish() {
return stream.str();
}
void CodeGenHybrid::PrintType(Type t, std::ostream &os) {
if (t.is_float()) {
os << "float";
CHECK(t.bits() == 16 || t.bits() == 32 || t.bits() == 64);
} else if (t.is_int()) {
os << "int";
CHECK(t.bits() == 8 || t.bits() == 16 || t.bits() == 32 || t.bits() == 64);
} else {
CHECK(t.is_uint()) << "Unsupported type " << t;
os << "uint";
CHECK(t.bits() == 8 || t.bits() == 16 || t.bits() == 32 || t.bits() == 64);
}
os << t.bits();
}
void CodeGenHybrid::VisitExpr_(const IntImm *op, std::ostream& os) { // NOLINT(*)
os << op->value;
}
void CodeGenHybrid::VisitExpr_(const UIntImm *op, std::ostream& os) { // NOLINT(*)
PrintType(op->type, os);
os << "(" << op->value << ")";
}
void CodeGenHybrid::VisitExpr_(const FloatImm *op, std::ostream& os) { // NOLINT(*)
PrintType(op->type, os);
os << "(" << std::setprecision(20) << op->value << ")";
}
void CodeGenHybrid::VisitExpr_(const StringImm *op, std::ostream& os) { // NOLINT(*)
os << "'" << op->value << "'";
}
template<typename T>
inline void PrintBinaryExpr(const T* op,
const char *opstr,
std::ostream& os, // NOLINT(*)
CodeGenHybrid* p) {
CHECK(op->type.lanes() == 1) << "vec bin op not implemented";
if (isalpha(opstr[0])) {
os << opstr << '(';
p->PrintExpr(op->a, os);
os << ", ";
p->PrintExpr(op->b, os);
os << ')';
} else {
os << '(';
p->PrintExpr(op->a, os);
if (!strcmp(opstr, "&&")) opstr = "and";
if (!strcmp(opstr, "||")) opstr = "or";
os << ' ' << opstr << ' ';
p->PrintExpr(op->b, os);
os << ')';
}
}
inline void PrintBinaryIntrinsitc(const Call* op,
const char *opstr,
std::ostream& os, // NOLINT(*)
CodeGenHybrid* p) {
CHECK(op->type.lanes() == 1) << "vec bin intrin not implemented";
CHECK_EQ(op->args.size(), 2U);
os << '(';
p->PrintExpr(op->args[0], os);
os << opstr;
p->PrintExpr(op->args[1], os);
os << ')';
}
void CodeGenHybrid::VisitExpr_(const Cast *op, std::ostream& os) { // NOLINT(*)
if (op->type == op->value.type()) {
PrintExpr(op->value, stream);
} else {
PrintType(op->type, os);
os << "(";
PrintExpr(op->value, os);
os << ")";
}
}
void CodeGenHybrid::VisitExpr_(const Variable *op, std::ostream& os) { // NOLINT(*)
os << GetVarID(op);
}
void CodeGenHybrid::VisitExpr_(const Add *op, std::ostream& os) { // NOLINT(*)
PrintBinaryExpr(op, "+", os, this);
}
void CodeGenHybrid::VisitExpr_(const Sub *op, std::ostream& os) { // NOLINT(*)
PrintBinaryExpr(op, "-", os, this);
}
void CodeGenHybrid::VisitExpr_(const Mul *op, std::ostream& os) { // NOLINT(*)
PrintBinaryExpr(op, "*", os, this);
}
void CodeGenHybrid::VisitExpr_(const Div *op, std::ostream& os) { // NOLINT(*)
if (op->type.is_int())
PrintBinaryExpr(op, "//", os, this);
else
PrintBinaryExpr(op, "/", os, this);
}
void CodeGenHybrid::VisitExpr_(const Mod *op, std::ostream& os) { // NOLINT(*)
PrintBinaryExpr(op, "%", os, this);
}
void CodeGenHybrid::VisitExpr_(const Min *op, std::ostream& os) { // NOLINT(*)
PrintBinaryExpr(op, "min", os, this);
}
void CodeGenHybrid::VisitExpr_(const Max *op, std::ostream& os) { // NOLINT(*)
PrintBinaryExpr(op, "max", os, this);
}
void CodeGenHybrid::VisitExpr_(const EQ *op, std::ostream& os) { // NOLINT(*)
PrintBinaryExpr(op, "==", os, this);
}
void CodeGenHybrid::VisitExpr_(const NE *op, std::ostream& os) { // NOLINT(*)
PrintBinaryExpr(op, "!=", os, this);
}
void CodeGenHybrid::VisitExpr_(const LT *op, std::ostream& os) { // NOLINT(*)
PrintBinaryExpr(op, "<", os, this);
}
void CodeGenHybrid::VisitExpr_(const LE *op, std::ostream& os) { // NOLINT(*)
PrintBinaryExpr(op, "<=", os, this);
}
void CodeGenHybrid::VisitExpr_(const GT *op, std::ostream& os) { // NOLINT(*)
PrintBinaryExpr(op, ">", os, this);
}
void CodeGenHybrid::VisitExpr_(const GE *op, std::ostream& os) { // NOLINT(*)
PrintBinaryExpr(op, ">=", os, this);
}
void CodeGenHybrid::VisitExpr_(const And *op, std::ostream& os) { // NOLINT(*)
PrintBinaryExpr(op, "&&", os, this);
}
void CodeGenHybrid::VisitExpr_(const Or *op, std::ostream& os) { // NOLINT(*)
PrintBinaryExpr(op, "||", os, this);
}
void CodeGenHybrid::VisitExpr_(const Not *op, std::ostream& os) { // NOLINT(*)
os << "not ";
PrintExpr(op->a, os);
}
void CodeGenHybrid::VisitExpr_(const Call *op, std::ostream& os) { // NOLINT(*)
if (op->call_type == Call::Halide) {
os << GetTensorID(op->func, op->value_index);
os << "[";
for (size_t i = 0; i < op->args.size(); ++i) {
if (i) os << ", ";
std::stringstream idx;
PrintExpr(op->args[i], idx);
os << idx.str();
}
os << "]";
} else if (op->is_intrinsic(Call::bitwise_and)) {
PrintBinaryIntrinsitc(op, "&", os, this);
} else if (op->is_intrinsic(Call::bitwise_xor)) {
PrintBinaryIntrinsitc(op, "^", os, this);
} else if (op->is_intrinsic(Call::bitwise_or)) {
PrintBinaryIntrinsitc(op, "|", os, this);
} else if (op->is_intrinsic(Call::shift_left)) {
PrintBinaryIntrinsitc(op, "<<", os, this);
} else if (op->is_intrinsic(Call::shift_right)) {
PrintBinaryIntrinsitc(op, ">>", os, this);
} else if (op->is_intrinsic(Call::bitwise_not)) {
CHECK_EQ(op->args.size(), 1U);
os << "(~";
PrintExpr(op->args[0], os);
os << ')';
} else if (op->is_intrinsic(intrinsic::tvm_if_then_else)) {
PrintExpr(op->args[1], os);
os << " if ";
PrintExpr(op->args[0], os);
os << " else ";
PrintExpr(op->args[2], os);
} else {
os << op->name << "(";
for (size_t i = 0; i < op->args.size(); i++) {
PrintExpr(op->args[i], os);
if (i < op->args.size() - 1) {
os << ", ";
}
}
os << ")";
}
}
void CodeGenHybrid::VisitExpr_(const Load* op, std::ostream& os) { // NOLINT(*)
LOG(FATAL) << "Phase 0 has no Load(s)!";
}
void CodeGenHybrid::VisitStmt_(const Store* op) {
LOG(FATAL) << "Phase 0 has no Store(s)!";
}
void CodeGenHybrid::VisitExpr_(const Let* op, std::ostream& os) { // NOLINT(*)
LOG(FATAL) << "Phase 0 has no Let(s)!";
}
void CodeGenHybrid::VisitStmt_(const Allocate* op) {
LOG(FATAL) << "Phase 0 has no Allocate(s)!";
}
void CodeGenHybrid::VisitExpr_(const Ramp* op, std::ostream& os) { // NOLINT(*)
LOG(FATAL) << "Ramp to be supported yet";
}
void CodeGenHybrid::VisitExpr_(const Broadcast* op, std::ostream& os) { // NOLINT(*)
LOG(FATAL) << "Broadcast: not supported ";
}
void CodeGenHybrid::VisitExpr_(const Select* op, std::ostream& os) { // NOLINT(*)
PrintExpr(op->true_value, os);
os << " if ";
PrintExpr(op->condition, os);
os << " else ";
PrintExpr(op->false_value, os);
os << "\n";
}
void CodeGenHybrid::VisitStmt_(const LetStmt* op) {
std::string value = PrintExpr(op->value);
stream << GetVarID(op->var.get()) << " = " << value << ";\n";
PrintStmt(op->body);
}
void CodeGenHybrid::VisitStmt_(const AttrStmt* op) {
if (op->attr_key == ir::attr::thread_extent) {
auto iter_var = op->node.as<IterVarNode>();
CHECK(iter_var);
binds_[iter_var->var.get()] = dot_to_underscore(iter_var->var->name_hint);
PrintIndent();
stream << "for " << binds_[iter_var->var.get()] << " in bind('"
<< iter_var->var->name_hint << "', ";
PrintExpr(op->value, stream);
stream << "):\n";
indent_ += tab_;
PrintStmt(op->body);
indent_ -= tab_;
} else if (op->attr_key == ir::attr::realize_scope) {
auto v = FunctionRef(op->node.node_);
alloc_storage_scope_[v] = op->value.as<StringImm>()->value;
PrintStmt(op->body);
} else {
// For now we ignore the unsupported AttrStmt
PrintStmt(op->body);
}
}
void CodeGenHybrid::VisitStmt_(const Realize *op) {
CHECK(alloc_storage_scope_.count(op->func));
if (!alloc_storage_scope_[op->func].empty()) {
PrintIndent();
stream << GetTensorID(op->func, op->value_index) << " = allocate((";
for (size_t i = 0; i < op->bounds.size(); ++i) {
if (i) stream << ", ";
stream << PrintExpr(op->bounds[i]->extent);
}
if (op->bounds.size() == 1) stream << ", ";
stream << "), '";
PrintType(op->type, stream);
stream << "', '";
stream << alloc_storage_scope_[op->func] << "')\n";
}
PrintStmt(op->body);
}
void CodeGenHybrid::VisitStmt_(const AssertStmt* op) {
PrintIndent();
stream << "assert ";
PrintExpr(op->condition, stream);
stream << ", ";
PrintExpr(op->message, stream);
stream << "\n";
PrintStmt(op->body);
}
void CodeGenHybrid::VisitStmt_(const Provide* op) {
PrintIndent();
stream << GetTensorID(op->func, op->value_index);
stream << "[";
for (size_t i = 0; i < op->args.size(); ++i) {
if (i) stream << ", ";
PrintExpr(op->args[i], stream);
}
stream << "] = ";
PrintExpr(op->value, stream);
stream << "\n";
}
void CodeGenHybrid::VisitStmt_(const For* op) {
std::string extent = PrintExpr(op->extent);
PrintIndent();
std::string vid = GetVarID(op->loop_var.get());
stream << "for " << vid << " in " << "range(" << extent << "):\n";
indent_ += tab_;
PrintStmt(op->body);
indent_ -= tab_;
}
bool is_noop(const Stmt &stmt) {
if (!stmt.defined())
return true;
if (auto eval = stmt.as<Evaluate>())
return is_const(eval->value);
return false;
}
void CodeGenHybrid::VisitStmt_(const IfThenElse* op) {
std::string cond = PrintExpr(op->condition);
PrintIndent();
stream << "if " << cond << ":\n";
indent_ += tab_;
PrintStmt(op->then_case);
indent_ -= tab_;
if (!is_noop(op->else_case)) {
PrintIndent();
stream << "else:\n";
indent_ += tab_;
PrintStmt(op->else_case);
indent_ -= tab_;
}
}
void CodeGenHybrid::VisitStmt_(const Block *op) {
PrintStmt(op->first);
if (op->rest.defined()) PrintStmt(op->rest);
}
void CodeGenHybrid::VisitStmt_(const Evaluate *op) {
if (is_const(op->value)) return;
std::string str = PrintExpr(op->value);
if (!str.empty())
stream << str << "\n";
}
void CodeGenHybrid::VisitStmt_(const ProducerConsumer *op) {
PrintStmt(op->body);
}
void CodeGenHybrid::PrintIndent() {
stream << std::string(indent_, ' ');
}
std::string CodeGenHybrid::GetVarID(const Variable *v) {
if (binds_.count(v))
return binds_[v];
auto key = std::make_pair(v->GetNodePtr().get(), 0);
if (id_map_.count(key)) {
return id_map_[key];
}
return id_map_[key] = GetUniqueName(v->name_hint);
}
std::string CodeGenHybrid::GetTensorID(const FunctionRef &func, int value_index) {
auto key = std::make_pair(func.get(), value_index);
if (id_map_.count(key)) {
return id_map_[key];
}
std::string name_hint = func->func_name();
if (func->num_outputs() > 1) {
name_hint += "_v" + std::to_string(value_index);
}
return id_map_[key] = GetUniqueName(name_hint);
}
void CodeGenHybrid::ReserveKeywords() {
GetUniqueName("def");
GetUniqueName("for");
GetUniqueName("in");
GetUniqueName("range");
GetUniqueName("unroll");
GetUniqueName("const_range");
GetUniqueName("parallel");
GetUniqueName("vectorize");
GetUniqueName("bind");
GetUniqueName("threadIdx.x");
GetUniqueName("threadIdx.y");
GetUniqueName("threadIdx.z");
GetUniqueName("blockIdx.x");
GetUniqueName("blockIdx.y");
GetUniqueName("blockIdx.z");
GetUniqueName("vthread");
GetUniqueName("allocate");
GetUniqueName("output_tensor");
GetUniqueName("sqrt");
GetUniqueName("log");
GetUniqueName("tanh");
GetUniqueName("power");
GetUniqueName("exp");
GetUniqueName("sigmoid");
GetUniqueName("popcount");
GetUniqueName("likely");
GetUniqueName("int8");
GetUniqueName("int16");
GetUniqueName("int32");
GetUniqueName("int64");
GetUniqueName("uint8");
GetUniqueName("uint16");
GetUniqueName("uint32");
GetUniqueName("uint64");
GetUniqueName("float16");
GetUniqueName("float32");
GetUniqueName("float64");
GetUniqueName("ceil_div");
}
void CodeGenHybrid::DumpStmt(const Stmt &stmt,
const Array<NodeRef> &inputs,
const Array<Tensor> &outputs,
const std::string &name) {
ReserveKeywords();
GetUniqueName(name);
stream << "def " << name << "(";
for (size_t i = 0; i < inputs.size(); ++i) {
if (i) stream << ", ";
if (auto tensor = inputs[i].as<TensorNode>()) {
stream << GetTensorID(tensor->op, tensor->value_index);
} else {
auto var = inputs[i].as<Variable>();
CHECK(var) << "Input should either be a tensor or a variable!";
stream << GetVarID(var);
}
}
stream << "):\n";
indent_ += tab_;
for (size_t i = 0; i < outputs.size(); ++i) {
PrintIndent();
stream << GetTensorID(outputs[i]->op, outputs[i]->value_index)
<< " = output_tensor((";
for (size_t j = 0; j < outputs[i]->shape.size(); ++j) {
if (j) stream << ", ";
PrintExpr(outputs[i]->shape[j], stream);
}
if (outputs[i]->shape.size() == 1)
stream << ", ";
stream << "), '" << outputs[i]->dtype << "')\n";
}
PrintStmt(stmt);
PrintIndent();
stream << "return ";
for (size_t i = 0; i < outputs.size(); ++i) {
if (i) stream << ", ";
stream << GetTensorID(outputs[i]->op, outputs[i]->value_index);
}
stream << "\n";
}
TVM_REGISTER_GLOBAL("hybrid._Dump")
.set_body([](TVMArgs args, TVMRetValue* rv) {
CodeGenHybrid codegen;
if (args.size() == 4)
codegen.DumpStmt(args[0], args[1], args[2], args[3]);
else
codegen.DumpStmt(args[0], args[1], args[2]);
*rv = codegen.Finish();
});
} // namespace contrib
} // namespace tvm
/*!
* Copyright (c) 2019 by Contributors
* \file codegen_hybrid.h
* \brief Common utilities to generated C style code.
*/
#ifndef TVM_CONTRIB_HYBRID_CODEGEN_HYBRID_H_
#define TVM_CONTRIB_HYBRID_CODEGEN_HYBRID_H_
#include <tvm/ir.h>
#include <tvm/ir_functor_ext.h>
#include <tvm/codegen.h>
#include <tvm/lowered_func.h>
#include <tvm/schedule.h>
#include <map>
#include <string>
#include <unordered_map>
#include <utility>
#include <vector>
namespace tvm {
namespace contrib {
using namespace ir;
/*!
* \brief A base class to generate Hybrid Script.
*
* **NOTE** CodeGenHybrid does not aim at generating Python scripts consumed by Python2/3.
* For runtime support, please refer the decorator in ``tvm/python/hybrid/api.py``.
*/
class CodeGenHybrid :
public ExprFunctor<void(const Expr&, std::ostream&)>,
public StmtFunctor<void(const Stmt&)> {
public:
/*!
* \brief Dump the given function body to hybrid script.
* \param stmt The function body to be dumped to hybrid script.
* \param inputs Input tensors of this schedule.
* \param outputs Output tensors of this schedule.
* \param name The name of the function.
*/
void DumpStmt(const Stmt &stmt, const Array<NodeRef> &inputs, const Array<Tensor> &outputs,
const std::string &name = "hybrid_func");
/*!
* \brief Finalize the compilation and return the code.
* \return The code.
*/
std::string Finish();
/*! \brief Reserve keywords in avoid of name conflict. */
void ReserveKeywords();
/*!
* \brief Print the Stmt n to CodeGenHybrid->stream
* \param n The statement to be printed.
*/
void PrintStmt(const Stmt &n) {
this->VisitStmt(n);
}
/*!
* \brief Print the expression n(or its ssa id if in ssa mode) into os
* \param n The expression to be printed.
* \param os The output stream
*/
void PrintExpr(const Expr &n, std::ostream &os) {
this->VisitExpr(n, os);
}
/*!
* \brief Same as PrintExpr, but simply returns result string
* \param n The expression to be printed.
*/
std::string PrintExpr(const Expr &n) {
std::ostringstream os;
PrintExpr(n, os);
return os.str();
}
// expression
void VisitExpr_(const Variable* op, std::ostream& os) override; // NOLINT(*)
void VisitExpr_(const Load* op, std::ostream& os) override; // NOLINT(*)
void VisitExpr_(const Let* op, std::ostream& os) override; // NOLINT(*)
void VisitExpr_(const Call* op, std::ostream& os) override; // NOLINT(*)
void VisitExpr_(const Add* op, std::ostream& os) override; // NOLINT(*)
void VisitExpr_(const Sub* op, std::ostream& os) override; // NOLINT(*)
void VisitExpr_(const Mul* op, std::ostream& os) override; // NOLINT(*)
void VisitExpr_(const Div* op, std::ostream& os) override; // NOLINT(*)
void VisitExpr_(const Mod* op, std::ostream& os) override; // NOLINT(*)
void VisitExpr_(const Min* op, std::ostream& os) override; // NOLINT(*)
void VisitExpr_(const Max* op, std::ostream& os) override; // NOLINT(*)
void VisitExpr_(const EQ* op, std::ostream& os) override; // NOLINT(*)
void VisitExpr_(const NE* op, std::ostream& os) override; // NOLINT(*)
void VisitExpr_(const LT* op, std::ostream& os) override; // NOLINT(*)
void VisitExpr_(const LE* op, std::ostream& os) override; // NOLINT(*)
void VisitExpr_(const GT* op, std::ostream& os) override; // NOLINT(*)
void VisitExpr_(const GE* op, std::ostream& os) override; // NOLINT(*)
void VisitExpr_(const And* op, std::ostream& os) override; // NOLINT(*)
void VisitExpr_(const Or* op, std::ostream& os) override; // NOLINT(*)
void VisitExpr_(const Cast* op, std::ostream& os) override; // NOLINT(*)
void VisitExpr_(const Not* op, std::ostream& os) override; // NOLINT(*)
void VisitExpr_(const Select* op, std::ostream& os) override; // NOLINT(*)
void VisitExpr_(const Ramp* op, std::ostream& os) override; // NOLINT(*)
void VisitExpr_(const Broadcast* op, std::ostream& os) override; // NOLINT(*)
void VisitExpr_(const IntImm* op, std::ostream& os) override; // NOLINT(*)
void VisitExpr_(const UIntImm* op, std::ostream& os) override; // NOLINT(*)
void VisitExpr_(const FloatImm* op, std::ostream& os) override; // NOLINT(*)
void VisitExpr_(const StringImm* op, std::ostream& os) override; // NOLINT(*)
// statment
void VisitStmt_(const LetStmt* op) override;
void VisitStmt_(const Store* op) override;
void VisitStmt_(const Provide* op) override;
void VisitStmt_(const For* op) override;
void VisitStmt_(const IfThenElse* op) override;
void VisitStmt_(const Allocate* op) override;
void VisitStmt_(const Realize* op) override;
void VisitStmt_(const AttrStmt* op) override;
void VisitStmt_(const AssertStmt* op) override;
void VisitStmt_(const Evaluate* op) override;
void VisitStmt_(const Block* op) override;
void VisitStmt_(const ProducerConsumer* op) override;
/*!
* \brief Print Type represetnation of type t.
* \param t The type representation.
* \param os The stream to print the ctype into
*/
virtual void PrintType(Type t, std::ostream& os); // NOLINT(*)
private:
/*! \brief The current indent of the code dump. */
int indent_{0};
/*! \brief The tab size of code indent. */
const int tab_{4};
/*! \brief Print the current indent spaces. */
inline void PrintIndent();
/*! \brief Keys are ids allocated, and values are the suffix to prevent double-name. */
std::map<std::string, int> ids_allocated_;
/*!
* \brief Keys are either (tensors, value_index) or (variables, 0).
* Values are the corresponding IDs.*/
std::map<std::pair<const Node *, int>, std::string> id_map_;
/*! \brief Variables (keys) binded to the threads (values). */
std::map<const Variable *, std::string> binds_;
/*!
* \brief Find an unallocated name for the given prefix.
* \param prefix The given prefix.
*/
std::string GetUniqueName(std::string prefix);
/*! \brief The output code string builder. */
std::stringstream stream;
/*!
* \brief Get or allocate the ID for the given variable.
* \param v The given variable.
*/
std::string GetVarID(const Variable *v);
/*!
* \brief Get or allocate the ID for the given tensor.
* \param func The tensor to allocate a name.
* \param value_index The value index of the given tensor.
*/
std::string GetTensorID(const FunctionRef &func, int value_index);
/*! \brief the storage scope of allocation */
std::map<FunctionRef, std::string> alloc_storage_scope_;
};
} // namespace contrib
} // namespace tvm
#endif // TVM_CONTRIB_HYBRID_CODEGEN_HYBRID_H_
...@@ -173,25 +173,28 @@ Stmt HybridOpNode::BuildProvide( ...@@ -173,25 +173,28 @@ Stmt HybridOpNode::BuildProvide(
rmap[outputs[i]] = stage->op.output(i); rmap[outputs[i]] = stage->op.output(i);
} }
auto n = make_node<HybridOpNode>(*this); auto n = make_node<HybridOpNode>(*this);
/* /* This is a story little bit complicated.
* These two lines of codes replace tensors' reads & writes. * The following two lines of codes replace output tensors' usage.
* This is the simplest way I (@were) can come up with to glue * This is the simplest way I (@were) can come up with to glue
* hybrid scripts to the structure of TVM op. * hybrid operation node to TVM op system.
* NAMING CONFLICT: In hybrid script all the tensors have their own * In hybrid script all the tensors, especially the output tensors,
* names specified by the users. However, In TVM op, all the output * have their own names defined by the users. However, In TVM
* tensors' names are the same as the op's name. I cannot change the * conventional ops:
* name to the op's name in the function body after the op node is * 1. Output tensors refer the corresponding op node so that the output
* formed, because: * tensors have the same names as the operation produces them.
* 1. Output tensors all point to the corresponding op node. * 2. Once OpNode is wrapped up by an Operation node, it is finalized.
* 2. Once OpNode is wrapped up by an Operation node, it can * Later access will be from a const OpNode*.
* no longer be changed.
* This is a chiken-egg paradox. It is impossible to put the output * This is a chiken-egg paradox. It is impossible to put the output
* tensors into the function body without forming the op node. The * tensors into the function body without forming the op node. The
* function body is immutable after the node is formed. * function body is immutable after the node is formed.
* *
* Finally, I decided to resolve this issue "lazily". During the * Finally, I decided to resolve this issue "lazily". During the
* pipeline of compilation, these tensors will be replaced when * pipeline of compilation, this stage is a very preliminary stage.
* forming the function body and passing to next stage of compilation. * Technically, it is before Phase 0. The actual tensors will be replaced
* here.
* Thus, the operation body is slightly different from the Phase 0 body.
* This is a major difference that HybridOpNode is NOT the same as
* ExternOpNode.
* */ * */
ret = op::ReplaceTensor(ret, rmap); ret = op::ReplaceTensor(ret, rmap);
ret = op::ReplaceProvideTensor(ret, rmap); ret = op::ReplaceProvideTensor(ret, rmap);
......
import tvm, inspect, sys, traceback, numpy, nose, types import tvm, inspect, sys, traceback, numpy, nose, types, os
from tvm.contrib import util
from tvm.hybrid import script from tvm.hybrid import script
from tvm.hybrid.intrin import HYBRID_GLOBALS from tvm.hybrid.runtime import HYBRID_GLOBALS
@nose.tools.nottest @nose.tools.nottest
def run_and_check(func, args, var_dict={}, target='llvm', sch=None, outs=None): def run_and_check(func, args, var_dict={}, target='llvm', sch=None, outs=None):
...@@ -59,6 +60,11 @@ def run_and_check(func, args, var_dict={}, target='llvm', sch=None, outs=None): ...@@ -59,6 +60,11 @@ def run_and_check(func, args, var_dict={}, target='llvm', sch=None, outs=None):
for nd, np in zip(out_tensors, ref_data): for nd, np in zip(out_tensors, ref_data):
tvm.testing.assert_allclose(nd.asnumpy(), np, rtol=1e-5, atol=1e-5) tvm.testing.assert_allclose(nd.asnumpy(), np, rtol=1e-5, atol=1e-5)
module_args = [i for i in args if isinstance(i, (tvm.tensor.Tensor, tvm.expr.Var))]
module_outs = [outs] if isinstance(outs, tvm.tensor.Tensor) else outs
h_module = tvm.hybrid.build(sch, module_args, module_outs)
return h_module, module_args, module_outs
@script @script
def outer_product(n, m, a, b): def outer_product(n, m, a, b):
...@@ -69,6 +75,7 @@ def outer_product(n, m, a, b): ...@@ -69,6 +75,7 @@ def outer_product(n, m, a, b):
c = output_tensor((n, m), a.dtype) c = output_tensor((n, m), a.dtype)
for i in range(n): for i in range(n):
for j in range(m): for j in range(m):
assert i < n and j < m, "index out of range!"
c[i, j] = a[i] * b[j] c[i, j] = a[i] * b[j]
return c return c
...@@ -100,6 +107,10 @@ def test_outer_product(): ...@@ -100,6 +107,10 @@ def test_outer_product():
assert ibody.extent.name == 'm' assert ibody.extent.name == 'm'
#Check loop body #Check loop body
jbody = ibody.body jbody = ibody.body
assert isinstance(jbody, tvm.stmt.AssertStmt)
assert isinstance(jbody.message, tvm.expr.StringImm)
assert jbody.message.value == "index out of range!"
jbody = jbody.body
assert isinstance(jbody, tvm.stmt.Provide) assert isinstance(jbody, tvm.stmt.Provide)
assert jbody.func.name == 'c' assert jbody.func.name == 'c'
assert len(jbody.args) == 2 assert len(jbody.args) == 2
...@@ -111,8 +122,13 @@ def test_outer_product(): ...@@ -111,8 +122,13 @@ def test_outer_product():
assert mul.a.name == 'a' assert mul.a.name == 'a'
assert mul.b.name == 'b' assert mul.b.name == 'b'
func, ins, outs = run_and_check(outer_product, [n, m, a, b], {n: 99, m: 101})
run_and_check(outer_product, [n, m, a, b], {n: 99, m: 101}) temp = util.tempdir()
path = temp.relpath('%s.py' % func.name)
func.save(path)
func_ = tvm.hybrid.HybridModule()
func_.load(path)
run_and_check(func_, ins, {n: 99, m: 101}, outs=outs)
for key, _ in HYBRID_GLOBALS.items(): for key, _ in HYBRID_GLOBALS.items():
assert key not in globals().keys() assert key not in globals().keys()
...@@ -197,7 +213,8 @@ def test_fanout(): ...@@ -197,7 +213,8 @@ def test_fanout():
assert len(write.value.args) == 1 assert len(write.value.args) == 1
assert write.value.args[0].value == 0 assert write.value.args[0].value == 0
run_and_check(fanout, [n, a], {n: 10}) func, ins, outs = run_and_check(fanout, [n, a], {n: 10})
run_and_check(func, ins, {n: 10}, outs=outs)
def test_looptype(): def test_looptype():
...@@ -229,7 +246,8 @@ def test_looptype(): ...@@ -229,7 +246,8 @@ def test_looptype():
assert jloop.for_type == tvm.stmt.For.Vectorized assert jloop.for_type == tvm.stmt.For.Vectorized
assert kloop.for_type == tvm.stmt.For.Unrolled assert kloop.for_type == tvm.stmt.For.Unrolled
run_and_check(looptype, [a, b, c]) func, ins, outs = run_and_check(looptype, [a, b, c])
run_and_check(func, ins, outs=outs)
def test_if(): def test_if():
...@@ -248,7 +266,8 @@ def test_if(): ...@@ -248,7 +266,8 @@ def test_if():
a = tvm.placeholder((10, ), dtype='int32', name='a') a = tvm.placeholder((10, ), dtype='int32', name='a')
run_and_check(if_then_else, [a]) func, ins, outs = run_and_check(if_then_else, [a])
run_and_check(func, ins, outs=outs)
@script @script
def if_triple_condition(a): def if_triple_condition(a):
...@@ -260,7 +279,8 @@ def test_if(): ...@@ -260,7 +279,8 @@ def test_if():
b[i] = a[i] + 1 b[i] = a[i] + 1
return b return b
run_and_check(if_triple_condition, [a]) func, ins, outs = run_and_check(if_triple_condition, [a])
run_and_check(func, ins, outs=outs)
@script @script
def if_and(a): def if_and(a):
...@@ -272,7 +292,8 @@ def test_if(): ...@@ -272,7 +292,8 @@ def test_if():
b[i] = a[i] + 1 b[i] = a[i] + 1
return b return b
run_and_check(if_and, [a]) func, ins, outs = run_and_check(if_and, [a])
run_and_check(func, ins, outs=outs)
def test_bind(): def test_bind():
...@@ -288,7 +309,8 @@ def test_bind(): ...@@ -288,7 +309,8 @@ def test_bind():
a = tvm.placeholder((1000, ), dtype='float32', name='a') a = tvm.placeholder((1000, ), dtype='float32', name='a')
b = tvm.placeholder((1000, ), dtype='float32', name='b') b = tvm.placeholder((1000, ), dtype='float32', name='b')
run_and_check(vec_add, [a, b], target='cuda') func, ins, outs = run_and_check(vec_add, [a, b], target='cuda')
run_and_check(func, ins, outs=outs, target='cuda')
@script @script
def raw(a, b): def raw(a, b):
...@@ -301,7 +323,8 @@ def test_bind(): ...@@ -301,7 +323,8 @@ def test_bind():
sch = tvm.create_schedule(c.op) sch = tvm.create_schedule(c.op)
x = tvm.thread_axis('threadIdx.x') x = tvm.thread_axis('threadIdx.x')
sch[c].bind(c.op.axis[0], x) sch[c].bind(c.op.axis[0], x)
run_and_check(raw, [a, b], sch=sch, outs=[c], target='cuda') func, ins, outs = run_and_check(raw, [a, b], sch=sch, outs=[c], target='cuda')
run_and_check(func, ins, outs=outs, target='cuda')
# Test loop binds # Test loop binds
@tvm.hybrid.script @tvm.hybrid.script
...@@ -318,7 +341,8 @@ def test_bind(): ...@@ -318,7 +341,8 @@ def test_bind():
b = [1, 2, 3, 4, 5] b = [1, 2, 3, 4, 5]
c = goo(a, tvm.convert(b)) c = goo(a, tvm.convert(b))
sch = tvm.create_schedule(c.op) sch = tvm.create_schedule(c.op)
run_and_check(goo, [a, b], sch=sch, outs=[c]) func, ins, outs = run_and_check(goo, [a, b], sch=sch, outs=[c])
run_and_check(func, ins, outs=outs)
def test_math_intrin(): def test_math_intrin():
@script @script
...@@ -379,7 +403,8 @@ def test_non_zero(): ...@@ -379,7 +403,8 @@ def test_non_zero():
return b return b
a = tvm.placeholder((32, 32), 'float32', 'a') a = tvm.placeholder((32, 32), 'float32', 'a')
run_and_check(blur, [a]) func, ins, outs = run_and_check(blur, [a])
run_and_check(func, ins, outs=outs)
@tvm.hybrid.script @tvm.hybrid.script
def triangle(a, b): def triangle(a, b):
...@@ -392,7 +417,8 @@ def test_non_zero(): ...@@ -392,7 +417,8 @@ def test_non_zero():
a = tvm.placeholder((10, ), dtype='float32', name='a') a = tvm.placeholder((10, ), dtype='float32', name='a')
b = tvm.placeholder((10, ), dtype='float32', name='b') b = tvm.placeholder((10, ), dtype='float32', name='b')
run_and_check(triangle, [a, b]) func, ins, outs = run_and_check(triangle, [a, b])
run_and_check(func, ins, outs=outs)
def test_allocate(): def test_allocate():
@tvm.hybrid.script @tvm.hybrid.script
...@@ -408,7 +434,10 @@ def test_allocate(): ...@@ -408,7 +434,10 @@ def test_allocate():
return b return b
a = tvm.placeholder((32, 32), 'float32', 'a') a = tvm.placeholder((32, 32), 'float32', 'a')
run_and_check(blur2d, [a]) b = blur2d(a)
sch = tvm.create_schedule(b.op)
func, ins, outs = run_and_check(blur2d, [a])
run_and_check(func, ins, outs=outs)
if tvm.gpu().exist: if tvm.gpu().exist:
@tvm.hybrid.script @tvm.hybrid.script
...@@ -426,7 +455,8 @@ def test_allocate(): ...@@ -426,7 +455,8 @@ def test_allocate():
a = tvm.placeholder((256, ), dtype='float32', name='a') a = tvm.placeholder((256, ), dtype='float32', name='a')
b = tvm.placeholder((256, ), dtype='float32', name='b') b = tvm.placeholder((256, ), dtype='float32', name='b')
run_and_check(share_vec_add, [a, b], target='cuda') func, ins, outs = run_and_check(share_vec_add, [a, b], target='cuda')
run_and_check(func, ins, outs=outs, target='cuda')
else: else:
print('[Warning] No GPU found! Skip shared mem test!') print('[Warning] No GPU found! Skip shared mem test!')
...@@ -562,7 +592,8 @@ def test_func_call(): ...@@ -562,7 +592,8 @@ def test_func_call():
a = tvm.placeholder((10, ), name='a') a = tvm.placeholder((10, ), name='a')
b = tvm.placeholder((10, ), name='b') b = tvm.placeholder((10, ), name='b')
run_and_check(foo, [a, b]) func, ins, outs = run_and_check(foo, [a, b])
run_and_check(func, ins, outs=outs)
def test_bool(): def test_bool():
@tvm.hybrid.script @tvm.hybrid.script
...@@ -576,27 +607,29 @@ def test_bool(): ...@@ -576,27 +607,29 @@ def test_bool():
b[i] = 0.0 b[i] = 0.0
return b return b
a = tvm.placeholder((10, ), name='a') a = tvm.placeholder((10, ), name='a')
run_and_check(foo, [a]) func, ins, outs = run_and_check(foo, [a])
run_and_check(func, ins, outs=outs)
def test_const_range(): def test_const_range():
@tvm.hybrid.script @tvm.hybrid.script
def foo(a, b): def foo(a, b):
c = output_tensor(a.shape, a.dtype) c = output_tensor(a.shape, a.dtype)
d = output_tensor(a.shape, a.dtype) d = output_tensor(a.shape, 'int32')
for i in const_range(2): for i in const_range(2):
for j in const_range(5): for j in const_range(5):
c[i, j] = a[i, j] + b[i, j] c[i, j] = float32(int32(a[i, j]) + b[i, j])
for i in const_range(len(b)): for i in const_range(len(b)):
for j in const_range(len(b[0])): for j in const_range(len(b[0])):
d[i, j] = a[i, j] + b[i, j] d[i, j] = int32(a[i, j] + b[i, j])
return c, d return c, d
a = tvm.placeholder((2, 5), name='a', dtype='int32') a = tvm.placeholder((2, 5), name='a', dtype='float32')
b = [[1, 2, 3, 4, 5], [5, 4, 3, 2, 1]] b = [[1, 2, 3, 4, 5], [5, 4, 3, 2, 1]]
run_and_check(foo, [a, b]) func, ins, outs = run_and_check(foo, [a, b])
run_and_check(func, ins, outs=outs)
@tvm.hybrid.script @tvm.hybrid.script
def goo(a, b): def goo(a, b):
...@@ -612,7 +645,8 @@ def test_const_range(): ...@@ -612,7 +645,8 @@ def test_const_range():
b = [1, 2, 3, 4, 5] b = [1, 2, 3, 4, 5]
c = goo(a, tvm.convert(b)) c = goo(a, tvm.convert(b))
sch = tvm.create_schedule(c.op) sch = tvm.create_schedule(c.op)
run_and_check(goo, [a, b]) func, ins, outs = run_and_check(goo, [a, b])
run_and_check(func, ins, outs=outs)
@tvm.hybrid.script @tvm.hybrid.script
def hoo(a, b): def hoo(a, b):
...@@ -626,7 +660,8 @@ def test_const_range(): ...@@ -626,7 +660,8 @@ def test_const_range():
return c return c
a = tvm.placeholder((5, ), name='a', dtype='int32') a = tvm.placeholder((5, ), name='a', dtype='int32')
b = [1, 2, 3, 4, 5] b = [1, 2, 3, 4, 5]
run_and_check(hoo, [a, b]) func, ins, outs = run_and_check(hoo, [a, b])
run_and_check(func, ins, outs=outs)
def test_schedule(): def test_schedule():
@script @script
...@@ -668,7 +703,8 @@ def test_schedule(): ...@@ -668,7 +703,8 @@ def test_schedule():
assert isinstance(ir, tvm.stmt.For) assert isinstance(ir, tvm.stmt.For)
assert ir.loop_var.name == 'j.outer.inner' assert ir.loop_var.name == 'j.outer.inner'
ir = ir.body ir = ir.body
run_and_check(outer_product, [a, b], sch=sch, outs=[c]) func, ins, outs = run_and_check(outer_product, [a, b], sch=sch, outs=[c])
run_and_check(func, ins, outs=outs)
# Test fuse # Test fuse
sch = tvm.create_schedule(c.op) sch = tvm.create_schedule(c.op)
...@@ -680,13 +716,15 @@ def test_schedule(): ...@@ -680,13 +716,15 @@ def test_schedule():
ir = ir.body ir = ir.body
assert isinstance(ir, tvm.stmt.For) assert isinstance(ir, tvm.stmt.For)
assert ir.loop_var.name == 'i.j.fused' assert ir.loop_var.name == 'i.j.fused'
run_and_check(outer_product, [a, b], sch=sch, outs=[c]) func, ins, outs = run_and_check(outer_product, [a, b], sch=sch, outs=[c])
run_and_check(func, ins, outs=outs)
# Test imperfect loop split # Test imperfect loop split
sch = tvm.create_schedule(c.op) sch = tvm.create_schedule(c.op)
sch[c].split(c.op.axis[0], 3) sch[c].split(c.op.axis[0], 3)
ir = tvm.lower(sch, [a, b, c], simple_mode=True) ir = tvm.lower(sch, [a, b, c], simple_mode=True)
run_and_check(outer_product, [a, b], sch=sch, outs=[c]) func, ins, outs = run_and_check(outer_product, [a, b], sch=sch, outs=[c])
run_and_check(func, ins, outs=outs)
# Test loop binds # Test loop binds
......
Markdown is supported
0% or
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment