Commit c53dd102 by Jian Weng Committed by Leyuan Wang

[Hybrid script] Backend support (#2477)

* a preliminary version is done?

* we no longer need the redundant hybrid/

* support assert stmt

* cast supported

* intrin -> runtime; util is mainly in charge of compilation time

* assert statement

* fix python lint

* fix cpp lint

* on the way to module

* rollback .cc

* fix typo, no direct expose then

* @vinx13 ceil is added i guess?

* wip...

* temp commit

* fix import

* i preliminary version is done?

* on the way to build hybrid module

* nearly fixed...

* dumped python are equiv as original python

* on the way to bootstrap

* cpu bootstrap done

* bootstrap!

* fix lint

* fix doc

* resolve some review concerns

* support load/save

* fix lint

* thanks to xqdan fixed my typo

* fix build, make dump non-optional

* add vthread

* jesus why i added this
parent 7e2a9fcf
......@@ -190,6 +190,7 @@ include(cmake/modules/contrib/BLAS.cmake)
add_library(tvm_topi SHARED ${TOPI_SRCS})
message(STATUS "Build with contrib.hybriddump")
file(GLOB HYBRID_CONTRIB_SRC src/contrib/hybrid/*.cc)
......@@ -197,6 +197,20 @@ You can also do loop-thread bind by writing code like this:
a[tx] = b[tx]
Assert Statement
The assert statement is supported; you can use it exactly as you would in standard Python.
.. code-block:: python
assert cond, mesg
.. note::
``Assert`` is NOT a function call. Users are encouraged to use assert in the way
presented above --- condition followed by message. It fits both Python AST and HalideIR.
- For keywords: ``serial``, ``range``, ``unroll``, ``parallel``, ``vectorize``, ``bind``, ``const_expr``
......@@ -292,6 +292,25 @@ def get_binds(args, binds=None):
return binds, arg_list
def form_body(sch):
    """According to the given schedule, form the raw body

    Parameters
    ----------
    sch : tvm.schedule.Schedule
        The given scheduler to form the raw body

    Returns
    -------
    The body formed according to the given schedule
    """
    # normalize schedule first
    sch = sch.normalize()
    # Infer the bounds of the normalized schedule, then turn it into a statement.
    bounds = schedule.InferBound(sch)
    stmt = schedule.ScheduleOps(sch, bounds)
    stmt = ir_pass.InjectPrefetch(stmt)
    return stmt
def lower(sch,
......@@ -337,11 +356,7 @@ def lower(sch,
# Phase 0
if isinstance(sch, schedule.Schedule):
# normalize schedule first
sch = sch.normalize()
bounds = schedule.InferBound(sch)
stmt = schedule.ScheduleOps(sch, bounds)
stmt = ir_pass.InjectPrefetch(stmt)
stmt = form_body(sch)
for f in lower_phase0:
stmt = f(stmt)
......@@ -4,8 +4,77 @@ This package maps a subset of python to HalideIR so that:
1. Users can write some preliminary versions of the computation patterns
that have not been supported yet and verify them across the real execution and
Python semantic emulation.
2. Developers can build HalideIR by writing Python code.
2. So far, it is a text format dedicated to HalideIR Phase 0. Refer to tvm.lower
for more details. A larger ambition of this module is to support all levels of
from .api import script
from .parser import parse_python
# TODO(@were): Make this module more complete.
# 1. Support HalideIR dumping to Hybrid Script
# 2. Support multi-level HalideIR
from __future__ import absolute_import as _abs
from .._ffi.base import decorate
from .._ffi.function import _init_api
from ..build_module import form_body
from .module import HybridModule
from .parser import source_to_op
from .util import _pruned_source
def script(pyfunc):
    """Decorate a Python function as a hybrid script.

    The hybrid function supports emulation mode and parsing to
    the internal language IR.

    Parameters
    ----------
    pyfunc : function
        The Python function to be decorated.

    Returns
    -------
    hybrid_func : function
        A decorated hybrid script function.
    """
    def wrapped_func(func, *args, **kwargs): #pylint: disable=missing-docstring
        from .util import _is_tvm_arg_types
        if _is_tvm_arg_types(args):
            # TVM arguments: parse the function source into an operation.
            src = _pruned_source(func)
            return source_to_op(src, func.__globals__, args)

        # Otherwise run the function itself, with the hybrid runtime symbols
        # temporarily placed into its global scope.
        from .runtime import _enter_hybrid_runtime, _restore_runtime
        intersect = _enter_hybrid_runtime(func)
        try:
            value = func(*args, **kwargs)
        finally:
            # Roll the globals back even when the function raises, so a failed
            # call does not leave the caller's module patched.
            _restore_runtime(func, intersect)
        return value

    return decorate(pyfunc, wrapped_func)
def build(sch, inputs, outputs, name="hybrid_func"):
    """Dump the current schedule to a hybrid module

    Parameters
    ----------
    sch: Schedule
        The schedule to be dumped

    inputs: An array of Tensors or Vars
        The inputs of the function body

    outputs: An array of Tensors
        The outputs of the function body

    name: str
        The name handed to both the dumper and the resulting HybridModule

    Returns
    -------
    module: HybridModule
        The built result is wrapped in a HybridModule.
        The usage of HybridModule is roughly the same as normal TVM-built modules.
    """
    # Form the raw body from the schedule, dump it to source text,
    # and wrap that text in a module.
    stmt = form_body(sch)
    src = _Dump(stmt, inputs, outputs, name)
    return HybridModule(src, name)
"""APIs of lowering the Python subset to HalideIR"""
from __future__ import absolute_import as _abs
from .._ffi.base import decorate
from .. import _api_internal as _tvm_internal
from ..tensor import Tensor
from .parser import parse_python
from .util import _pruned_source
def script(pyfunc):
"""Decorate a Python function as a hybrid script.
The hybrid function supports emulation mode and parsing to
the internal language IR.
hybrid_func : function
A decorated hybrid script function.
def wrapped_func(func, *args, **kwargs): #pylint: disable=missing-docstring
from .util import _enter_hybrid_runtime, _restore_runtime, _is_tvm_arg_types
if _is_tvm_arg_types(args):
src = _pruned_source(func)
parser = parse_python(src, func.__globals__, args)
input_tensors = []
for i in args:
if isinstance(i, Tensor):
op = _tvm_internal._HybridOp(parser.func_name, "HybridOp", None, input_tensors,
parser.outputs, parser.parsed_body)
res = [op.output(i) for i in range(len(parser.outputs))]
return res[0] if len(res) == 1 else res
intersect = _enter_hybrid_runtime(func)
value = func(*args, **kwargs)
_restore_runtime(func, intersect)
return value
return decorate(pyfunc, wrapped_func)
......@@ -8,6 +8,7 @@ from ..container import Array
from .. import ir_pass
from ..stmt import For
from .util import _internal_assert
from ..intrin import call_pure_intrin
#pylint: disable=redefined-builtin
......@@ -104,3 +105,29 @@ def len(func_id, args):
except: #pylint: disable=bare-except
_internal_assert(args[0].shape.__len__() == 1, "Only one-dimension array can get len")
return _api.convert(args[0].shape[0])
def _cast(func_id, args):
    """Lower a cast call to a Cast node.

    ``func_id`` is forwarded to ``_make.Cast`` as the target type; presumably
    it is the dtype name the call was written with (e.g. ``float32``) --
    verify against the caller. ``args`` must hold exactly one expression.
    """
    # NOTE: ``len`` is redefined in this module, hence the explicit ``__len__``.
    is_single_expr = args.__len__() == 1 and isinstance(args[0], _expr.Expr)
    _internal_assert(is_single_expr, "Only one expression can be cast")
    operand = args[0]
    return _make.Cast(func_id, operand)
float16 = float32 = float64 = _cast #pylint: disable=invalid-name
int8 = int16 = int32 = int64 = _cast #pylint: disable=invalid-name
uint8 = uint16 = uint32 = uint64 = _cast #pylint: disable=invalid-name
def ceil_div(func_id, args):
    """Build the expression ``(a + b - 1) / b`` from two expression operands.

    ``func_id`` must be the string ``"ceil_div"`` and ``args`` must contain
    exactly two expressions; every violation goes through ``_internal_assert``.
    """
    _internal_assert(func_id == "ceil_div", "This function cannot be directly invoked!")
    _internal_assert(args.__len__() == 2, "2 arguments expected for division!")
    # Both operands are checked in order, with the same message.
    for operand in (args[0], args[1]):
        _internal_assert(isinstance(operand, _expr.Expr), "Only expressions can div")
    dividend, divisor = args[0], args[1]
    return (dividend + divisor - 1) / divisor
def likely(func_id, args):
    """Wrap the single argument in a pure ``likely`` intrinsic call.

    The resulting call carries the dtype of the wrapped argument.
    ``func_id`` must be the string ``"likely"``.
    """
    _internal_assert(args.__len__() == 1, "Only one expression can be likely")
    _internal_assert(func_id == "likely", "This function cannot be directly invoked!")
    cond = args[0]
    return call_pure_intrin(cond.dtype, 'likely', *args)
"""Methods and data structures to support dumping HalideIR to Hybrid Script.
This allows users to quickly hack the generated HalideIR and cast it back to
TVM modules.
To enable this feature, you need to build with -DUSE_HYBRID_DUMP=ON.
import ast
import imp
from ..contrib import util
from .util import _internal_assert
from .util import _is_tvm_arg_types
from .parser import source_to_op
class HybridModule(object):
"""The usage of Hybrid Module is very similar to conventional TVM module,
but conventional TVM module requires a function body which is already fully
lowered. This contradicts the fact that Hybrid Module is originally a text
format for Phase 0 HalideIR. Thus, a totally separated module is defined."""
def __init__(self, src=None, name=None):
"""The constructor of a hybrid module
src : str
The source code of this module
name : str
The name of this module
self.src_ = = self.func_ = self.root_ = None
if src is not None:
temp = util.tempdir()
dst = temp.relpath("")
with open(dst, 'w') as f:
f.write("import tvm\n@tvm.hybrid.script\n%s" % src)
if name is not None: = name
def __call__(self, *args):
if _is_tvm_arg_types(args):
return source_to_op(self.root_, globals(), args)
return self.func_(*args)
def get_source(self):
return self.src_
def save(self, path):
if not path.endswith('.py'):
path = path + '.py'
with open(path, 'w') as f:
def load(self, path):
"""Load the module from a python file
path : str
Path to the given python file
with open(path, 'r') as f:
self.src_ =
src = self.src_
class FindFunc(ast.NodeVisitor):
"""Find the function in the module to be loaded."""
#pylint: disable=invalid-name
def __init__(self): = None
self.root = None
def visit_FunctionDef(self, node):
_internal_assert( is None, "For now, only one function supported!") =
_internal_assert(self.root is None, "For now, only one function supported!")
self.root = node
root = ast.parse(src)
finder = FindFunc()
_internal_assert( is not None and finder.root is not None, \
"No function found!")
if is None: =
self.root_ = finder.root
py_module = imp.load_source(, path)
self.func_ = getattr(py_module,
......@@ -17,28 +17,36 @@ from ..api import all as _all
from ..api import any as _any
from ..container import Array
from ..tensor import Tensor, Operation
from .. import _api_internal as _tvm_internal
from .. import expr as _expr
from .. import stmt as _stmt
from .. import make as _make
from .. import api as _api
from .. import ir_pass as _ir_pass
def pack_list_to_block(lst):
if len(lst) == 1:
def concat_list_to_block(lst):
"""Concatenate a list of Python IR nodes to HalideIR Block"""
n = len(lst)
if n == 1:
return lst[0]
body = lst[0]
for i in lst[1:]:
body = _make.Block(body, i)
body = lst[n - 1]
for i in range(1, n):
stmt = lst[n - 1 - i]
if isinstance(stmt, _stmt.AssertStmt):
body = _make.AssertStmt(stmt.condition, stmt.message, body)
body = _make.Block(stmt, body)
return body
def visit_list_to_block(visit, lst):
"""Convert a list of Python IR nodes to HalideIR Block"""
"""Visit and concatenate a list of Python IR nodes to HalideIR Block"""
lst = [visit(stmt) for stmt in lst if not util.is_docstring(stmt)]
lst = [stmt for stmt in lst if not _ir_pass.Equal(stmt, util.make_nop())]
if not lst:
return util.make_nop()
return pack_list_to_block(lst)
return concat_list_to_block(lst)
class Symbol(Enum):
......@@ -441,7 +449,7 @@ class HybridParser(ast.NodeVisitor):
body = visit_list_to_block(self.visit, node.body)
body = self.wrap_up_realize(node, body)
return pack_list_to_block(bodies)
return concat_list_to_block(bodies)
elif iter_var is None:
_internal_assert(for_type is not None, "The loop bind function parse error!")
......@@ -496,15 +504,22 @@ class HybridParser(ast.NodeVisitor):
return node.s
def visit_Assert(self, node):
    """Lower a Python ``assert cond, mesg`` statement to an AssertStmt.

    The statement is created with a no-op body; the visited message is
    converted with ``_api.convert`` before being attached.
    """
    condition = self.visit(node.test)
    message = _api.convert(self.visit(node.msg))
    placeholder_body = util.make_nop()
    return _make.AssertStmt(condition, message, placeholder_body)
def parse_python(src, symbols, args):
"""The helper function of calling the AST visitor
src : str
The source code of the function to be parsed.
src : ast.node or str
If an ast.node, then directly lower it.
If a str, then parse it to ast and lower it.
src : str
symbols : str
The symbol list of the global context of the function.
args : list of Tensors or Vars
......@@ -517,9 +532,44 @@ def parse_python(src, symbols, args):
root : Stmt
The result Halide IR and the parser class instance.
root = ast.parse(src)
root = ast.parse(src) if isinstance(src, str) else src
_internal_assert(root, ast.AST)
var_usage = determine_variable_usage(root, args, symbols)
parser = HybridParser(args, var_usage, symbols)
parser.parsed_body = parser.visit(root)
_internal_assert(parser.returned, 'No valid return found in the function body!')
return parser
def source_to_op(src, symbols, args):
"""Another level of wrapper
src : ast.node or str
If an ast.node, then directly lower it.
If a str, then parse it to ast and lower it.
symbols : str
The symbol list of the global context of the function.
args : list of Tensors or Vars
The argument lists to the function.
It is NOT encouraged to write a function without arguments.
It is NOT encouraged to write a function with side effect.
res : list of output tensors
The result of output tensors of the formed OpNode.
parser = parse_python(src, symbols, args)
input_tensors = []
for i in args:
if isinstance(i, Tensor):
op = _tvm_internal._HybridOp(parser.func_name, "HybridOp", None, input_tensors,
parser.outputs, parser.parsed_body)
res = [op.output(i) for i in range(len(parser.outputs))]
return res[0] if len(res) == 1 else res
......@@ -73,7 +73,6 @@ def sigmoid(x):
'len' : len,
'unroll' : range,
'vectorize' : range,
'parallel' : range,
......@@ -88,4 +87,37 @@ HYBRID_GLOBALS = {
'exp' : numpy.exp,
'sigmoid' : sigmoid,
'popcount' : popcount,
'likely' : lambda cond: cond,
'uint8' : numpy.uint8,
'uint16' : numpy.uint16,
'uint32' : numpy.uint32,
'uint64' : numpy.uint64,
'int8' : numpy.int8,
'int16' : numpy.int16,
'int32' : numpy.int32,
'int64' : numpy.int64,
'float16' : numpy.float16,
'float32' : numpy.float32,
'float64' : numpy.float64,
'ceil_div' : lambda a, b: (a + b - 1) / b
def _enter_hybrid_runtime(func):
    """Install the hybrid runtime symbols into the globals of ``func``.

    Returns a list of ``(name, previous_value)`` pairs for every runtime
    symbol that was already present in those globals before being overwritten.
    """
    scope = func.__globals__
    runtime_names = list(HYBRID_GLOBALS.keys())
    # Record the values about to be shadowed before touching the scope.
    shadowed = [(name, scope[name]) for name in runtime_names if name in scope.keys()]
    for name in runtime_names:
        scope[name] = HYBRID_GLOBALS[name]
    return shadowed
def _restore_runtime(func, intersect):
"""Rollback the modification caused by hybrid runtime"""
_globals = func.__globals__
for elem in list(HYBRID_GLOBALS.keys()):
for k, v in intersect:
_globals[k] = v
......@@ -5,14 +5,13 @@ import inspect
import logging
import sys
import numpy
from .intrin import HYBRID_GLOBALS
from .._ffi.base import numeric_types
from .. import api as _api
from .. import make as _make
from .. import expr as _expr
from .. import stmt as _stmt
from ..container import Array
from .._ffi.base import numeric_types
from ..tensor import Tensor
from ..container import Array
#pylint: disable=invalid-name
......@@ -20,6 +19,7 @@ np_arg_types = tuple(list(numeric_types) + [numpy.ndarray])
tvm_arg_types = (Tensor, Array, _expr.Var, _expr.ConstExpr)
halide_imm_types = (_expr.IntImm, _expr.FloatImm, _expr.UIntImm)
def _internal_assert(cond, err):
"""Simplify the code segment like if not XXX then raise an error"""
if not cond:
......@@ -52,6 +52,23 @@ def _pruned_source(func):
raise err
def replace_io(body, rmap):
"""Replacing tensors usage according to the dict given"""
from .. import ir_pass
def replace(op):
if isinstance(op, _stmt.Provide) and op.func in rmap.keys():
buf = rmap[op.func]
return _make.Provide(buf.op, op.value_index, op.value, op.args)
elif isinstance(op, _expr.Call) and op.func in rmap.keys():
buf = rmap[op.func]
return _make.Call(buf.dtype,, op.args, \
_expr.Call.Halide, buf.op, buf.value_index)
return None
return ir_pass.IRTransform(body, None, replace, ['Provide', 'Call'])
def _is_tvm_arg_types(args):
"""Determine whether the arguments are all tvm arguments or all numpy arguments.
If neither is true, raise a value error."""
......@@ -68,40 +85,3 @@ def _is_tvm_arg_types(args):
_internal_assert(isinstance(elem, np_arg_types), \
"Expect a numpy type but %s get!" % str(type(elem)))
return False
def _enter_hybrid_runtime(func):
"""Put hybrid runtime variables into the global scope"""
_globals = func.__globals__
intersect = []
for elem in list(HYBRID_GLOBALS.keys()):
if elem in _globals.keys():
intersect.append((elem, _globals[elem]))
_globals[elem] = HYBRID_GLOBALS[elem]
return intersect
def _restore_runtime(func, intersect):
"""Rollback the modification caused by hybrid runtime"""
_globals = func.__globals__
for elem in list(HYBRID_GLOBALS.keys()):
for k, v in intersect:
_globals[k] = v
def replace_io(body, rmap):
"""Replacing tensors usage according to the dict given"""
from .. import ir_pass
def replace(op):
if isinstance(op, _stmt.Provide) and op.func in rmap.keys():
buf = rmap[op.func]
return _make.Provide(buf.op, op.value_index, op.value, op.args)
elif isinstance(op, _expr.Call) and op.func in rmap.keys():
buf = rmap[op.func]
return _make.Call(buf.dtype,, op.args, \
_expr.Call.Halide, buf.op, buf.value_index)
return None
return ir_pass.IRTransform(body, None, replace, ['Provide', 'Call'])
......@@ -2,7 +2,7 @@
import ast
import sys
from .intrin import HYBRID_GLOBALS
from .runtime import HYBRID_GLOBALS
from .util import _internal_assert
......@@ -45,7 +45,7 @@ class PyVariableUsage(ast.NodeVisitor):
_internal_assert(isinstance(node.func, ast.Name), "Function call should be an id")
func_id =
_internal_assert(func_id in list(HYBRID_GLOBALS.keys()) + \
['range', 'max', 'min'] + \
['range', 'max', 'min', 'len'] + \
list(self.symbols.keys()), \
"Function call id not in intrinsics' list")
for elem in node.args:
......@@ -103,6 +103,8 @@ Target CreateTarget(const std::string& target_name,
t->device_type = kDLCPU;
} else if (target_name == "ext_dev") {
t->device_type = kDLExtDev;
} else if (target_name == "hybrid") {
t->device_type = kDLCPU;
} else {
LOG(ERROR) << "Unknown target name " << target_name;
return target::stackvm();
/*! Copyright (c) 2019 by Contributors
* \file
#include <iomanip>
#include <cctype>
#include "codegen_hybrid.h"
namespace tvm {
namespace contrib {
using namespace ir;
std::string dot_to_underscore(std::string s) {
for (auto &ch : s)
if (ch == '.') ch = '_';
return s;
std::string CodeGenHybrid::GetUniqueName(std::string prefix) {
prefix = dot_to_underscore(prefix);
auto it = ids_allocated_.find(prefix);
if (it != ids_allocated_.end()) {
while (true) {
std::ostringstream os;
os << prefix << (++it->second);
std::string name = os.str();
if (ids_allocated_.count(name) == 0) {
prefix = name;
ids_allocated_[prefix] = 0;
return prefix;
std::string CodeGenHybrid::Finish() {
return stream.str();
void CodeGenHybrid::PrintType(Type t, std::ostream &os) {
if (t.is_float()) {
os << "float";
CHECK(t.bits() == 16 || t.bits() == 32 || t.bits() == 64);
} else if (t.is_int()) {
os << "int";
CHECK(t.bits() == 8 || t.bits() == 16 || t.bits() == 32 || t.bits() == 64);
} else {
CHECK(t.is_uint()) << "Unsupported type " << t;
os << "uint";
CHECK(t.bits() == 8 || t.bits() == 16 || t.bits() == 32 || t.bits() == 64);
os << t.bits();
void CodeGenHybrid::VisitExpr_(const IntImm *op, std::ostream& os) { // NOLINT(*)
os << op->value;
void CodeGenHybrid::VisitExpr_(const UIntImm *op, std::ostream& os) { // NOLINT(*)
PrintType(op->type, os);
os << "(" << op->value << ")";
void CodeGenHybrid::VisitExpr_(const FloatImm *op, std::ostream& os) { // NOLINT(*)
PrintType(op->type, os);
os << "(" << std::setprecision(20) << op->value << ")";
void CodeGenHybrid::VisitExpr_(const StringImm *op, std::ostream& os) { // NOLINT(*)
os << "'" << op->value << "'";
template<typename T>
inline void PrintBinaryExpr(const T* op,
const char *opstr,
std::ostream& os, // NOLINT(*)
CodeGenHybrid* p) {
CHECK(op->type.lanes() == 1) << "vec bin op not implemented";
if (isalpha(opstr[0])) {
os << opstr << '(';
p->PrintExpr(op->a, os);
os << ", ";
p->PrintExpr(op->b, os);
os << ')';
} else {
os << '(';
p->PrintExpr(op->a, os);
if (!strcmp(opstr, "&&")) opstr = "and";
if (!strcmp(opstr, "||")) opstr = "or";
os << ' ' << opstr << ' ';
p->PrintExpr(op->b, os);
os << ')';
inline void PrintBinaryIntrinsitc(const Call* op,
const char *opstr,
std::ostream& os, // NOLINT(*)
CodeGenHybrid* p) {
CHECK(op->type.lanes() == 1) << "vec bin intrin not implemented";
CHECK_EQ(op->args.size(), 2U);
os << '(';
p->PrintExpr(op->args[0], os);
os << opstr;
p->PrintExpr(op->args[1], os);
os << ')';
void CodeGenHybrid::VisitExpr_(const Cast *op, std::ostream& os) { // NOLINT(*)
if (op->type == op->value.type()) {
PrintExpr(op->value, stream);
} else {
PrintType(op->type, os);
os << "(";
PrintExpr(op->value, os);
os << ")";
void CodeGenHybrid::VisitExpr_(const Variable *op, std::ostream& os) { // NOLINT(*)
os << GetVarID(op);
void CodeGenHybrid::VisitExpr_(const Add *op, std::ostream& os) { // NOLINT(*)
PrintBinaryExpr(op, "+", os, this);
void CodeGenHybrid::VisitExpr_(const Sub *op, std::ostream& os) { // NOLINT(*)
PrintBinaryExpr(op, "-", os, this);
void CodeGenHybrid::VisitExpr_(const Mul *op, std::ostream& os) { // NOLINT(*)
PrintBinaryExpr(op, "*", os, this);
void CodeGenHybrid::VisitExpr_(const Div *op, std::ostream& os) { // NOLINT(*)
if (op->type.is_int())
PrintBinaryExpr(op, "//", os, this);
PrintBinaryExpr(op, "/", os, this);
void CodeGenHybrid::VisitExpr_(const Mod *op, std::ostream& os) { // NOLINT(*)
PrintBinaryExpr(op, "%", os, this);
void CodeGenHybrid::VisitExpr_(const Min *op, std::ostream& os) { // NOLINT(*)
PrintBinaryExpr(op, "min", os, this);
void CodeGenHybrid::VisitExpr_(const Max *op, std::ostream& os) { // NOLINT(*)
PrintBinaryExpr(op, "max", os, this);
void CodeGenHybrid::VisitExpr_(const EQ *op, std::ostream& os) { // NOLINT(*)
PrintBinaryExpr(op, "==", os, this);
void CodeGenHybrid::VisitExpr_(const NE *op, std::ostream& os) { // NOLINT(*)
PrintBinaryExpr(op, "!=", os, this);
void CodeGenHybrid::VisitExpr_(const LT *op, std::ostream& os) { // NOLINT(*)
PrintBinaryExpr(op, "<", os, this);
void CodeGenHybrid::VisitExpr_(const LE *op, std::ostream& os) { // NOLINT(*)
PrintBinaryExpr(op, "<=", os, this);
void CodeGenHybrid::VisitExpr_(const GT *op, std::ostream& os) { // NOLINT(*)
PrintBinaryExpr(op, ">", os, this);
void CodeGenHybrid::VisitExpr_(const GE *op, std::ostream& os) { // NOLINT(*)
PrintBinaryExpr(op, ">=", os, this);
void CodeGenHybrid::VisitExpr_(const And *op, std::ostream& os) { // NOLINT(*)
PrintBinaryExpr(op, "&&", os, this);
void CodeGenHybrid::VisitExpr_(const Or *op, std::ostream& os) { // NOLINT(*)
PrintBinaryExpr(op, "||", os, this);
void CodeGenHybrid::VisitExpr_(const Not *op, std::ostream& os) { // NOLINT(*)
os << "not ";
PrintExpr(op->a, os);
void CodeGenHybrid::VisitExpr_(const Call *op, std::ostream& os) { // NOLINT(*)
if (op->call_type == Call::Halide) {
os << GetTensorID(op->func, op->value_index);
os << "[";
for (size_t i = 0; i < op->args.size(); ++i) {
if (i) os << ", ";
std::stringstream idx;
PrintExpr(op->args[i], idx);
os << idx.str();
os << "]";
} else if (op->is_intrinsic(Call::bitwise_and)) {
PrintBinaryIntrinsitc(op, "&", os, this);
} else if (op->is_intrinsic(Call::bitwise_xor)) {
PrintBinaryIntrinsitc(op, "^", os, this);
} else if (op->is_intrinsic(Call::bitwise_or)) {
PrintBinaryIntrinsitc(op, "|", os, this);
} else if (op->is_intrinsic(Call::shift_left)) {
PrintBinaryIntrinsitc(op, "<<", os, this);
} else if (op->is_intrinsic(Call::shift_right)) {
PrintBinaryIntrinsitc(op, ">>", os, this);
} else if (op->is_intrinsic(Call::bitwise_not)) {
CHECK_EQ(op->args.size(), 1U);
os << "(~";
PrintExpr(op->args[0], os);
os << ')';
} else if (op->is_intrinsic(intrinsic::tvm_if_then_else)) {
PrintExpr(op->args[1], os);
os << " if ";
PrintExpr(op->args[0], os);
os << " else ";
PrintExpr(op->args[2], os);
} else {
os << op->name << "(";
for (size_t i = 0; i < op->args.size(); i++) {
PrintExpr(op->args[i], os);
if (i < op->args.size() - 1) {
os << ", ";
os << ")";
void CodeGenHybrid::VisitExpr_(const Load* op, std::ostream& os) { // NOLINT(*)
LOG(FATAL) << "Phase 0 has no Load(s)!";
void CodeGenHybrid::VisitStmt_(const Store* op) {
LOG(FATAL) << "Phase 0 has no Store(s)!";
void CodeGenHybrid::VisitExpr_(const Let* op, std::ostream& os) { // NOLINT(*)
LOG(FATAL) << "Phase 0 has no Let(s)!";
void CodeGenHybrid::VisitStmt_(const Allocate* op) {
LOG(FATAL) << "Phase 0 has no Allocate(s)!";
void CodeGenHybrid::VisitExpr_(const Ramp* op, std::ostream& os) { // NOLINT(*)
LOG(FATAL) << "Ramp to be supported yet";
void CodeGenHybrid::VisitExpr_(const Broadcast* op, std::ostream& os) { // NOLINT(*)
LOG(FATAL) << "Broadcast: not supported ";
void CodeGenHybrid::VisitExpr_(const Select* op, std::ostream& os) { // NOLINT(*)
PrintExpr(op->true_value, os);
os << " if ";
PrintExpr(op->condition, os);
os << " else ";
PrintExpr(op->false_value, os);
os << "\n";
void CodeGenHybrid::VisitStmt_(const LetStmt* op) {
std::string value = PrintExpr(op->value);
stream << GetVarID(op->var.get()) << " = " << value << ";\n";
void CodeGenHybrid::VisitStmt_(const AttrStmt* op) {
if (op->attr_key == ir::attr::thread_extent) {
auto iter_var = op-><IterVarNode>();
binds_[iter_var->var.get()] = dot_to_underscore(iter_var->var->name_hint);
stream << "for " << binds_[iter_var->var.get()] << " in bind('"
<< iter_var->var->name_hint << "', ";
PrintExpr(op->value, stream);
stream << "):\n";
indent_ += tab_;
indent_ -= tab_;
} else if (op->attr_key == ir::attr::realize_scope) {
auto v = FunctionRef(op->node.node_);
alloc_storage_scope_[v] = op-><StringImm>()->value;
} else {
// For now we ignore the unsupported AttrStmt
void CodeGenHybrid::VisitStmt_(const Realize *op) {
if (!alloc_storage_scope_[op->func].empty()) {
stream << GetTensorID(op->func, op->value_index) << " = allocate((";
for (size_t i = 0; i < op->bounds.size(); ++i) {
if (i) stream << ", ";
stream << PrintExpr(op->bounds[i]->extent);
if (op->bounds.size() == 1) stream << ", ";
stream << "), '";
PrintType(op->type, stream);
stream << "', '";
stream << alloc_storage_scope_[op->func] << "')\n";
void CodeGenHybrid::VisitStmt_(const AssertStmt* op) {
stream << "assert ";
PrintExpr(op->condition, stream);
stream << ", ";
PrintExpr(op->message, stream);
stream << "\n";
void CodeGenHybrid::VisitStmt_(const Provide* op) {
stream << GetTensorID(op->func, op->value_index);
stream << "[";
for (size_t i = 0; i < op->args.size(); ++i) {
if (i) stream << ", ";
PrintExpr(op->args[i], stream);
stream << "] = ";
PrintExpr(op->value, stream);
stream << "\n";
void CodeGenHybrid::VisitStmt_(const For* op) {
std::string extent = PrintExpr(op->extent);
std::string vid = GetVarID(op->loop_var.get());
stream << "for " << vid << " in " << "range(" << extent << "):\n";
indent_ += tab_;
indent_ -= tab_;
bool is_noop(const Stmt &stmt) {
if (!stmt.defined())
return true;
if (auto eval =<Evaluate>())
return is_const(eval->value);
return false;
void CodeGenHybrid::VisitStmt_(const IfThenElse* op) {
std::string cond = PrintExpr(op->condition);
stream << "if " << cond << ":\n";
indent_ += tab_;
indent_ -= tab_;
if (!is_noop(op->else_case)) {
stream << "else:\n";
indent_ += tab_;
indent_ -= tab_;
void CodeGenHybrid::VisitStmt_(const Block *op) {
if (op->rest.defined()) PrintStmt(op->rest);
void CodeGenHybrid::VisitStmt_(const Evaluate *op) {
if (is_const(op->value)) return;
std::string str = PrintExpr(op->value);
if (!str.empty())
stream << str << "\n";
void CodeGenHybrid::VisitStmt_(const ProducerConsumer *op) {
void CodeGenHybrid::PrintIndent() {
stream << std::string(indent_, ' ');
std::string CodeGenHybrid::GetVarID(const Variable *v) {
if (binds_.count(v))
return binds_[v];
auto key = std::make_pair(v->GetNodePtr().get(), 0);
if (id_map_.count(key)) {
return id_map_[key];
return id_map_[key] = GetUniqueName(v->name_hint);
std::string CodeGenHybrid::GetTensorID(const FunctionRef &func, int value_index) {
auto key = std::make_pair(func.get(), value_index);
if (id_map_.count(key)) {
return id_map_[key];
std::string name_hint = func->func_name();
if (func->num_outputs() > 1) {
name_hint += "_v" + std::to_string(value_index);
return id_map_[key] = GetUniqueName(name_hint);
void CodeGenHybrid::ReserveKeywords() {
void CodeGenHybrid::DumpStmt(const Stmt &stmt,
const Array<NodeRef> &inputs,
const Array<Tensor> &outputs,
const std::string &name) {
stream << "def " << name << "(";
for (size_t i = 0; i < inputs.size(); ++i) {
if (i) stream << ", ";
if (auto tensor = inputs[i].as<TensorNode>()) {
stream << GetTensorID(tensor->op, tensor->value_index);
} else {
auto var = inputs[i].as<Variable>();
CHECK(var) << "Input should either be a tensor or a variable!";
stream << GetVarID(var);
stream << "):\n";
indent_ += tab_;
for (size_t i = 0; i < outputs.size(); ++i) {
stream << GetTensorID(outputs[i]->op, outputs[i]->value_index)
<< " = output_tensor((";
for (size_t j = 0; j < outputs[i]->shape.size(); ++j) {
if (j) stream << ", ";
PrintExpr(outputs[i]->shape[j], stream);
if (outputs[i]->shape.size() == 1)
stream << ", ";
stream << "), '" << outputs[i]->dtype << "')\n";
stream << "return ";
for (size_t i = 0; i < outputs.size(); ++i) {
if (i) stream << ", ";
stream << GetTensorID(outputs[i]->op, outputs[i]->value_index);
stream << "\n";
.set_body([](TVMArgs args, TVMRetValue* rv) {
CodeGenHybrid codegen;
if (args.size() == 4)
codegen.DumpStmt(args[0], args[1], args[2], args[3]);
codegen.DumpStmt(args[0], args[1], args[2]);
*rv = codegen.Finish();
} // namespace contrib
} // namespace tvm
* Copyright (c) 2019 by Contributors
* \file codegen_hybrid.h
* \brief Common utilities to generate Hybrid Script.
#include <tvm/ir.h>
#include <tvm/ir_functor_ext.h>
#include <tvm/codegen.h>
#include <tvm/lowered_func.h>
#include <tvm/schedule.h>
#include <map>
#include <string>
#include <unordered_map>
#include <utility>
#include <vector>
namespace tvm {
namespace contrib {
using namespace ir;
* \brief A base class to generate Hybrid Script.
* **NOTE** CodeGenHybrid does not aim at generating Python scripts consumed by Python2/3.
* For runtime support, please refer the decorator in ``tvm/python/hybrid/``.
class CodeGenHybrid :
public ExprFunctor<void(const Expr&, std::ostream&)>,
public StmtFunctor<void(const Stmt&)> {
* \brief Dump the given function body to hybrid script.
* \param stmt The function body to be dumped to hybrid script.
* \param inputs Input tensors of this schedule.
* \param outputs Output tensors of this schedule.
* \param name The name of the function.
void DumpStmt(const Stmt &stmt, const Array<NodeRef> &inputs, const Array<Tensor> &outputs,
const std::string &name = "hybrid_func");
* \brief Finalize the compilation and return the code.
* \return The code.
std::string Finish();
/*! \brief Reserve keywords to avoid name conflicts. */
void ReserveKeywords();
* \brief Print the Stmt n to CodeGenHybrid->stream
* \param n The statement to be printed.
void PrintStmt(const Stmt &n) {
* \brief Print the expression n(or its ssa id if in ssa mode) into os
* \param n The expression to be printed.
* \param os The output stream
void PrintExpr(const Expr &n, std::ostream &os) {
this->VisitExpr(n, os);
* \brief Same as PrintExpr, but simply returns result string
* \param n The expression to be printed.
std::string PrintExpr(const Expr &n) {
std::ostringstream os;
PrintExpr(n, os);
return os.str();
// expression
void VisitExpr_(const Variable* op, std::ostream& os) override; // NOLINT(*)
void VisitExpr_(const Load* op, std::ostream& os) override; // NOLINT(*)
void VisitExpr_(const Let* op, std::ostream& os) override; // NOLINT(*)
void VisitExpr_(const Call* op, std::ostream& os) override; // NOLINT(*)
void VisitExpr_(const Add* op, std::ostream& os) override; // NOLINT(*)
void VisitExpr_(const Sub* op, std::ostream& os) override; // NOLINT(*)
void VisitExpr_(const Mul* op, std::ostream& os) override; // NOLINT(*)
void VisitExpr_(const Div* op, std::ostream& os) override; // NOLINT(*)
void VisitExpr_(const Mod* op, std::ostream& os) override; // NOLINT(*)
void VisitExpr_(const Min* op, std::ostream& os) override; // NOLINT(*)
void VisitExpr_(const Max* op, std::ostream& os) override; // NOLINT(*)
void VisitExpr_(const EQ* op, std::ostream& os) override; // NOLINT(*)
void VisitExpr_(const NE* op, std::ostream& os) override; // NOLINT(*)
void VisitExpr_(const LT* op, std::ostream& os) override; // NOLINT(*)
void VisitExpr_(const LE* op, std::ostream& os) override; // NOLINT(*)
void VisitExpr_(const GT* op, std::ostream& os) override; // NOLINT(*)
void VisitExpr_(const GE* op, std::ostream& os) override; // NOLINT(*)
void VisitExpr_(const And* op, std::ostream& os) override; // NOLINT(*)
void VisitExpr_(const Or* op, std::ostream& os) override; // NOLINT(*)
void VisitExpr_(const Cast* op, std::ostream& os) override; // NOLINT(*)
void VisitExpr_(const Not* op, std::ostream& os) override; // NOLINT(*)
void VisitExpr_(const Select* op, std::ostream& os) override; // NOLINT(*)
void VisitExpr_(const Ramp* op, std::ostream& os) override; // NOLINT(*)
void VisitExpr_(const Broadcast* op, std::ostream& os) override; // NOLINT(*)
void VisitExpr_(const IntImm* op, std::ostream& os) override; // NOLINT(*)
void VisitExpr_(const UIntImm* op, std::ostream& os) override; // NOLINT(*)
void VisitExpr_(const FloatImm* op, std::ostream& os) override; // NOLINT(*)
void VisitExpr_(const StringImm* op, std::ostream& os) override; // NOLINT(*)
// statement
void VisitStmt_(const LetStmt* op) override;
void VisitStmt_(const Store* op) override;
void VisitStmt_(const Provide* op) override;
void VisitStmt_(const For* op) override;
void VisitStmt_(const IfThenElse* op) override;
void VisitStmt_(const Allocate* op) override;
void VisitStmt_(const Realize* op) override;
void VisitStmt_(const AttrStmt* op) override;
void VisitStmt_(const AssertStmt* op) override;
void VisitStmt_(const Evaluate* op) override;
void VisitStmt_(const Block* op) override;
void VisitStmt_(const ProducerConsumer* op) override;
* \brief Print Type representation of type t.
* \param t The type representation.
* \param os The stream to print the ctype into
virtual void PrintType(Type t, std::ostream& os); // NOLINT(*)
/*! \brief The current indent of the code dump. */
int indent_{0};
/*! \brief The tab size of code indent. */
const int tab_{4};
/*! \brief Print the current indent spaces. */
inline void PrintIndent();
/*! \brief Keys are ids allocated, and values are the suffix to prevent double-name. */
std::map<std::string, int> ids_allocated_;
* \brief Keys are either (tensors, value_index) or (variables, 0).
* Values are the corresponding IDs.*/
std::map<std::pair<const Node *, int>, std::string> id_map_;
/*! \brief Variables (keys) bound to the threads (values). */
std::map<const Variable *, std::string> binds_;
* \brief Find an unallocated name for the given prefix.
* \param prefix The given prefix.
std::string GetUniqueName(std::string prefix);
/*! \brief The output code string builder. */
std::stringstream stream;
* \brief Get or allocate the ID for the given variable.
* \param v The given variable.
std::string GetVarID(const Variable *v);
* \brief Get or allocate the ID for the given tensor.
* \param func The tensor to allocate a name.
* \param value_index The value index of the given tensor.
std::string GetTensorID(const FunctionRef &func, int value_index);
/*! \brief the storage scope of allocation */
std::map<FunctionRef, std::string> alloc_storage_scope_;
} // namespace contrib
} // namespace tvm
......@@ -173,25 +173,28 @@ Stmt HybridOpNode::BuildProvide(
rmap[outputs[i]] = stage->op.output(i);
auto n = make_node<HybridOpNode>(*this);
* These two lines of codes replace tensors' reads & writes.
/* This is a slightly complicated story.
* The following two lines of codes replace output tensors' usage.
* This is the simplest way I (@were) can come up with to glue
* hybrid scripts to the structure of TVM op.
* NAMING CONFLICT: In hybrid script all the tensors have their own
* names specified by the users. However, In TVM op, all the output
* tensors' names are the same as the op's name. I cannot change the
* name to the op's name in the function body after the op node is
* formed, because:
* 1. Output tensors all point to the corresponding op node.
* 2. Once OpNode is wrapped up by an Operation node, it can
* no longer be changed.
* hybrid operation node to TVM op system.
* In hybrid script all the tensors, especially the output tensors,
* have their own names defined by the users. However, In TVM
* conventional ops:
* 1. Output tensors refer to the corresponding op node, so the output
* tensors have the same name as the operation that produces them.
* 2. Once OpNode is wrapped up by an Operation node, it is finalized.
* Later access will be from a const OpNode*.
* This is a chicken-and-egg paradox. It is impossible to put the output
* tensors into the function body without forming the op node. The
* function body is immutable after the node is formed.
* Finally, I decided to resolve this issue "lazily". During the
* pipeline of compilation, these tensors will be replaced when
* forming the function body and passing to next stage of compilation.
* pipeline of compilation, this stage is a very preliminary stage.
* Technically, it is before Phase 0. The actual tensors will be replaced
* here.
* Thus, the operation body is slightly different from the Phase 0 body.
* This is a major difference that HybridOpNode is NOT the same as
* ExternOpNode.
* */
ret = op::ReplaceTensor(ret, rmap);
ret = op::ReplaceProvideTensor(ret, rmap);
import tvm, inspect, sys, traceback, numpy, nose, types
import tvm, inspect, sys, traceback, numpy, nose, types, os
from tvm.contrib import util
from tvm.hybrid import script
from tvm.hybrid.intrin import HYBRID_GLOBALS
from tvm.hybrid.runtime import HYBRID_GLOBALS
def run_and_check(func, args, var_dict={}, target='llvm', sch=None, outs=None):
......@@ -59,6 +60,11 @@ def run_and_check(func, args, var_dict={}, target='llvm', sch=None, outs=None):
for nd, np in zip(out_tensors, ref_data):
tvm.testing.assert_allclose(nd.asnumpy(), np, rtol=1e-5, atol=1e-5)
module_args = [i for i in args if isinstance(i, (tvm.tensor.Tensor, tvm.expr.Var))]
module_outs = [outs] if isinstance(outs, tvm.tensor.Tensor) else outs
h_module =, module_args, module_outs)
return h_module, module_args, module_outs
def outer_product(n, m, a, b):
......@@ -69,6 +75,7 @@ def outer_product(n, m, a, b):
c = output_tensor((n, m), a.dtype)
for i in range(n):
for j in range(m):
assert i < n and j < m, "index out of range!"
c[i, j] = a[i] * b[j]
return c
......@@ -100,6 +107,10 @@ def test_outer_product():
assert == 'm'
#Check loop body
jbody = ibody.body
assert isinstance(jbody, tvm.stmt.AssertStmt)
assert isinstance(jbody.message, tvm.expr.StringImm)
assert jbody.message.value == "index out of range!"
jbody = jbody.body
assert isinstance(jbody, tvm.stmt.Provide)
assert == 'c'
assert len(jbody.args) == 2
......@@ -111,8 +122,13 @@ def test_outer_product():
assert == 'a'
assert == 'b'
run_and_check(outer_product, [n, m, a, b], {n: 99, m: 101})
func, ins, outs = run_and_check(outer_product, [n, m, a, b], {n: 99, m: 101})
temp = util.tempdir()
path = temp.relpath('' %
func_ = tvm.hybrid.HybridModule()
run_and_check(func_, ins, {n: 99, m: 101}, outs=outs)
for key, _ in HYBRID_GLOBALS.items():
assert key not in globals().keys()
......@@ -197,7 +213,8 @@ def test_fanout():
assert len(write.value.args) == 1
assert write.value.args[0].value == 0
run_and_check(fanout, [n, a], {n: 10})
func, ins, outs = run_and_check(fanout, [n, a], {n: 10})
run_and_check(func, ins, {n: 10}, outs=outs)
def test_looptype():
......@@ -229,7 +246,8 @@ def test_looptype():
assert jloop.for_type == tvm.stmt.For.Vectorized
assert kloop.for_type == tvm.stmt.For.Unrolled
run_and_check(looptype, [a, b, c])
func, ins, outs = run_and_check(looptype, [a, b, c])
run_and_check(func, ins, outs=outs)
def test_if():
......@@ -248,7 +266,8 @@ def test_if():
a = tvm.placeholder((10, ), dtype='int32', name='a')
run_and_check(if_then_else, [a])
func, ins, outs = run_and_check(if_then_else, [a])
run_and_check(func, ins, outs=outs)
def if_triple_condition(a):
......@@ -260,7 +279,8 @@ def test_if():
b[i] = a[i] + 1
return b
run_and_check(if_triple_condition, [a])
func, ins, outs = run_and_check(if_triple_condition, [a])
run_and_check(func, ins, outs=outs)
def if_and(a):
......@@ -272,7 +292,8 @@ def test_if():
b[i] = a[i] + 1
return b
run_and_check(if_and, [a])
func, ins, outs = run_and_check(if_and, [a])
run_and_check(func, ins, outs=outs)
def test_bind():
......@@ -288,7 +309,8 @@ def test_bind():
a = tvm.placeholder((1000, ), dtype='float32', name='a')
b = tvm.placeholder((1000, ), dtype='float32', name='b')
run_and_check(vec_add, [a, b], target='cuda')
func, ins, outs = run_and_check(vec_add, [a, b], target='cuda')
run_and_check(func, ins, outs=outs, target='cuda')
def raw(a, b):
......@@ -301,7 +323,8 @@ def test_bind():
sch = tvm.create_schedule(c.op)
x = tvm.thread_axis('threadIdx.x')
sch[c].bind(c.op.axis[0], x)
run_and_check(raw, [a, b], sch=sch, outs=[c], target='cuda')
func, ins, outs = run_and_check(raw, [a, b], sch=sch, outs=[c], target='cuda')
run_and_check(func, ins, outs=outs, target='cuda')
# Test loop binds
......@@ -318,7 +341,8 @@ def test_bind():
b = [1, 2, 3, 4, 5]
c = goo(a, tvm.convert(b))
sch = tvm.create_schedule(c.op)
run_and_check(goo, [a, b], sch=sch, outs=[c])
func, ins, outs = run_and_check(goo, [a, b], sch=sch, outs=[c])
run_and_check(func, ins, outs=outs)
def test_math_intrin():
......@@ -379,7 +403,8 @@ def test_non_zero():
return b
a = tvm.placeholder((32, 32), 'float32', 'a')
run_and_check(blur, [a])
func, ins, outs = run_and_check(blur, [a])
run_and_check(func, ins, outs=outs)
def triangle(a, b):
......@@ -392,7 +417,8 @@ def test_non_zero():
a = tvm.placeholder((10, ), dtype='float32', name='a')
b = tvm.placeholder((10, ), dtype='float32', name='b')
run_and_check(triangle, [a, b])
func, ins, outs = run_and_check(triangle, [a, b])
run_and_check(func, ins, outs=outs)
def test_allocate():
......@@ -408,7 +434,10 @@ def test_allocate():
return b
a = tvm.placeholder((32, 32), 'float32', 'a')
run_and_check(blur2d, [a])
b = blur2d(a)
sch = tvm.create_schedule(b.op)
func, ins, outs = run_and_check(blur2d, [a])
run_and_check(func, ins, outs=outs)
if tvm.gpu().exist:
......@@ -426,7 +455,8 @@ def test_allocate():
a = tvm.placeholder((256, ), dtype='float32', name='a')
b = tvm.placeholder((256, ), dtype='float32', name='b')
run_and_check(share_vec_add, [a, b], target='cuda')
func, ins, outs = run_and_check(share_vec_add, [a, b], target='cuda')
run_and_check(func, ins, outs=outs, target='cuda')
print('[Warning] No GPU found! Skip shared mem test!')
......@@ -562,7 +592,8 @@ def test_func_call():
a = tvm.placeholder((10, ), name='a')
b = tvm.placeholder((10, ), name='b')
run_and_check(foo, [a, b])
func, ins, outs = run_and_check(foo, [a, b])
run_and_check(func, ins, outs=outs)
def test_bool():
......@@ -576,27 +607,29 @@ def test_bool():
b[i] = 0.0
return b
a = tvm.placeholder((10, ), name='a')
run_and_check(foo, [a])
func, ins, outs = run_and_check(foo, [a])
run_and_check(func, ins, outs=outs)
def test_const_range():
def foo(a, b):
c = output_tensor(a.shape, a.dtype)
d = output_tensor(a.shape, a.dtype)
d = output_tensor(a.shape, 'int32')
for i in const_range(2):
for j in const_range(5):
c[i, j] = a[i, j] + b[i, j]
c[i, j] = float32(int32(a[i, j]) + b[i, j])
for i in const_range(len(b)):
for j in const_range(len(b[0])):
d[i, j] = a[i, j] + b[i, j]
d[i, j] = int32(a[i, j] + b[i, j])
return c, d
a = tvm.placeholder((2, 5), name='a', dtype='int32')
a = tvm.placeholder((2, 5), name='a', dtype='float32')
b = [[1, 2, 3, 4, 5], [5, 4, 3, 2, 1]]
run_and_check(foo, [a, b])
func, ins, outs = run_and_check(foo, [a, b])
run_and_check(func, ins, outs=outs)
def goo(a, b):
......@@ -612,7 +645,8 @@ def test_const_range():
b = [1, 2, 3, 4, 5]
c = goo(a, tvm.convert(b))
sch = tvm.create_schedule(c.op)
run_and_check(goo, [a, b])
func, ins, outs = run_and_check(goo, [a, b])
run_and_check(func, ins, outs=outs)
def hoo(a, b):
......@@ -626,7 +660,8 @@ def test_const_range():
return c
a = tvm.placeholder((5, ), name='a', dtype='int32')
b = [1, 2, 3, 4, 5]
run_and_check(hoo, [a, b])
func, ins, outs = run_and_check(hoo, [a, b])
run_and_check(func, ins, outs=outs)
def test_schedule():
......@@ -668,7 +703,8 @@ def test_schedule():
assert isinstance(ir, tvm.stmt.For)
assert == 'j.outer.inner'
ir = ir.body
run_and_check(outer_product, [a, b], sch=sch, outs=[c])
func, ins, outs = run_and_check(outer_product, [a, b], sch=sch, outs=[c])
run_and_check(func, ins, outs=outs)
# Test fuse
sch = tvm.create_schedule(c.op)
......@@ -680,13 +716,15 @@ def test_schedule():
ir = ir.body
assert isinstance(ir, tvm.stmt.For)
assert == 'i.j.fused'
run_and_check(outer_product, [a, b], sch=sch, outs=[c])
func, ins, outs = run_and_check(outer_product, [a, b], sch=sch, outs=[c])
run_and_check(func, ins, outs=outs)
# Test imperfect loop split
sch = tvm.create_schedule(c.op)
sch[c].split(c.op.axis[0], 3)
ir = tvm.lower(sch, [a, b, c], simple_mode=True)
run_and_check(outer_product, [a, b], sch=sch, outs=[c])
func, ins, outs = run_and_check(outer_product, [a, b], sch=sch, outs=[c])
run_and_check(func, ins, outs=outs)
# Test loop binds
Markdown is supported
0% or
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment