Commit 7a01476a by Jian Weng, committed by Tianqi Chen


[HYBRID FRONTEND] Modify hybrid script to new interface; hybrid op supported; enable compilation_database in CMakeList.txt (#1757)
parent 79735eb2
......@@ -57,6 +57,7 @@ include_directories("3rdparty/compiler-rt")
# initial variables
set(TVM_LINKER_LIBS "")
set(TVM_RUNTIME_LINKER_LIBS "")
set(CMAKE_EXPORT_COMPILE_COMMANDS ON)
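# (CMAKE_EXPORT_COMPILE_COMMANDS makes CMake emit compile_commands.json,
#  the compilation database consumed by clang-based tools)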
# Generic compilation options
if(MSVC)
......
......@@ -22,13 +22,15 @@ you need to use ``tvm.hybrid.script`` decorator to indicate this is a hybrid fun
 @tvm.hybrid.script
-def outer_product(a, b, c):
+def outer_product(a, b):
+    c = output_tensor((100, 99), 'float32')
     for i in range(a.shape[0]):
         for j in range(b.shape[0]):
             c[i, j] = a[i] * b[j]
+    return c

-a = numpy.random.rand(100)
-b = numpy.random.rand(99)
-c = numpy.zeros((100, 99))
-outer_product(a, b, c)
+a = numpy.random.randn(100)
+b = numpy.random.randn(99)
+c = outer_product(a, b)
During software emulation, this decorator automatically imports the required `Keywords`_.
After software emulation is done, the imported keywords will be cleaned up. Users do not need
......@@ -40,25 +42,25 @@ or ``numpy`` numeric type.
Backend Compilation
~~~~~~~~~~~~~~~~~~~
Calling this function directly is discouraged; users are encouraged to use the second interface instead.
The current parse interface looks like:

.. code-block:: python

   a = tvm.placeholder((100, ), name='a')
   b = tvm.placeholder((99, ), name='b')
-  c = tvm.placeholder((100, 99), name='c')
-  tvm.hybrid.parse(outer_product, [a, b, c]) # return an ir root of this function
+  parser = tvm.hybrid.parse(outer_product, [a, b]) # return the parser of this function

-If we pass these tvm tensors to this function, it returns a op node:
-**Under construction, we are still deciding what kind of node should be returned.**
+If we pass these tvm tensors to this function, it returns an op node:

.. code-block:: python

   a = tvm.placeholder((100, ), name='a')
   b = tvm.placeholder((99, ), name='b')
-  c = tvm.placeholder((100, 99), name='c')
-  op = outer_product(a, b, c) # return the corresponding op node
+  c = outer_product(a, b) # return the output tensor(s) of the operator

-**Under construction, we are still deciding what kind of node should be returned.**
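From here the returned tensor behaves like any other op output; a minimal sketch, assuming the `outer_product` definition and placeholders above:

.. code-block:: python

   sch = tvm.create_schedule(c.op)
   func = tvm.build(sch, [a, b, c])   # compile the hybrid op like any other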
Tuning
~~~~~~
......
......@@ -450,6 +450,69 @@ class ExternOpNode : public OperationNode {
TVM_DECLARE_NODE_TYPE_INFO(ExternOpNode, OperationNode);
};
/*!
 * \brief A computation operator generated by a hybrid script.
 */
class HybridOpNode : public OperationNode {
public:
/*! \brief The input tensors */
Array<Tensor> inputs;
/*! \brief Symbolic placeholder representation of outputs */
Array<Tensor> outputs;
/*! \brief The statement that generates the computation. This is
 * slightly different from the body in ExternOpNode: all the output
 * tensors keep the names given by the user in the script, and
 * during compilation these placeholders are replaced by the actual
 * output tensors. */
Stmt body;
/*! \brief constructor */
HybridOpNode() {}
// override functions
int num_outputs() const final;
Array<IterVar> root_iter_vars() const final;
Type output_dtype(size_t i) const final;
Array<Expr> output_shape(size_t i) const final;
Array<Tensor> InputTensors() const final;
Operation ReplaceInputs(
const Operation& self,
const std::unordered_map<Tensor, Tensor>& rmap) const final;
void PropBoundToInputs(
const Operation& self,
const std::unordered_map<const Variable*, IntSet>& dom_map,
std::unordered_map<Tensor, TensorDom>* out_dom_map) const final;
void GatherBound(
const Operation& self,
const std::unordered_map<Tensor, TensorDom>& tensor_dom,
std::unordered_map<IterVar, Range>* out_dom_map) const final;
Stmt BuildRealize(
const Stage& stage,
const std::unordered_map<IterVar, Range>& realize_map,
const Stmt& body) const final;
Stmt BuildProvide(
const Stage& stage,
const std::unordered_map<IterVar, Range>& dom_map,
bool debug_keep_trivial_loop) const final;
void VisitAttrs(AttrVisitor* v) final {
v->Visit("name", &name);
v->Visit("tag", &tag);
v->Visit("attrs", &attrs);
v->Visit("inputs", &inputs);
v->Visit("outputs", &outputs);
v->Visit("body", &body);
}
EXPORT static Operation make(std::string name,
std::string tag,
Map<std::string, NodeRef> attrs,
Array<Tensor> inputs,
Array<Tensor> outputs,
Stmt body);
static constexpr const char* _type_key = "HybridOp";
TVM_DECLARE_NODE_TYPE_INFO(HybridOpNode, OperationNode);
};
/*! \brief The compute function to specify the input source of a Tensor */
using FCompute = std::function<Expr (const Array<Var>& i)>;
......
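To make the role of the `outputs` placeholders concrete, a hedged Python-side sketch (names follow the doc example above; the node itself is built in `api.py` below):

.. code-block:: python

   c = outer_product(a, b)       # a, b are tvm placeholders
   op = c.op                     # the underlying HybridOp node
   assert op.num_outputs == 1    # HybridOpNode::num_outputs()
   assert c.dtype == 'float32'   # HybridOpNode::output_dtype(0)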
......@@ -340,11 +340,6 @@ def lower(sch,
bounds = schedule.InferBound(sch)
stmt = schedule.ScheduleOps(sch, bounds)
stmt = ir_pass.InjectPrefetch(stmt)
-    else:
-        # So far there is no op for hybrid script, so a plain ir body is given
-        if not isinstance(sch, _stmt.Stmt):
-            raise ValueError("sch should be either a Schedule or a Stmt")
-        stmt = sch
for f in lower_phase0:
stmt = f(stmt)
......
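The removed branch was a stopgap from the days when a hybrid script handed `lower` a raw `Stmt`; with `HybridOp`, hybrid functions yield real tensors, so `lower` always receives a `Schedule`. A hedged sketch of the new flow, placeholders as in the docs:

.. code-block:: python

   c = outer_product(a, b)           # tensor backed by a HybridOp, not a Stmt
   sch = tvm.create_schedule(c.op)   # the ordinary schedule path now applies
   print(tvm.lower(sch, [a, b, c], simple_mode=True))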
......@@ -7,4 +7,5 @@ python semantic emulation.
2. Developers can build HalideIR by writing Python code.
"""
-from .api import script, parse
+from .api import script
+from .parser import parse_python
"""APIs of lowering the Python subset to HalideIR"""
from __future__ import absolute_import as _abs
-import types
from .._ffi.base import decorate
+from .. import _api_internal as _tvm_internal
+from ..tensor import Tensor
from .parser import parse_python
+from .util import _pruned_source
def script(pyfunc):
......@@ -17,40 +20,26 @@ def script(pyfunc):
hybrid_func : function
A decorated hybrid script function.
"""
-    def wrapped_func(func, *args, **kwargs):
+    def wrapped_func(func, *args, **kwargs): #pylint: disable=missing-docstring
         from .util import _enter_hybrid_runtime, _restore_runtime, _is_tvm_arg_types
         if _is_tvm_arg_types(args):
-            return parse(func, args)
+            src = _pruned_source(func)
+            parser = parse_python(src, args)
+            input_tensors = []
+            for i in args:
+                if isinstance(i, Tensor):
+                    input_tensors.append(i)
+            op = _tvm_internal._HybridOp(parser.func_name, "HybridOp", None, input_tensors,
+                                         parser.outputs, parser.parsed_body)
+            res = [op.output(i) for i in range(len(parser.outputs))]
+            return res[0] if len(res) == 1 else res
intersect = _enter_hybrid_runtime(func)
value = func(*args, **kwargs)
_restore_runtime(func, intersect)
return value
-    return decorate(pyfunc, wrapped_func)
-
-def parse(func, args):
-    """Parse a subset of Python to HalideIR
-    Parameters
-    ----------
-    func : str or types.FunctionType
-        If it is a string, parse the source code
-        If it is a function, parse the function
-    args : list of Buffer or Tensor or Var
-        The argument lists to the function.
-        Leave it None if no buffer is related to the function to be parsed
-    Returns
-    -------
-    root : Stmt
-        The result Halide IR and the parser class instance.
-    """
-    from .util import _pruned_source
-    if isinstance(func, str):
-        src = func
-    else:
-        assert isinstance(func, types.FunctionType)
-        src = _pruned_source(func)
-    return parse_python(src, args)
+    return decorate(pyfunc, wrapped_func)
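Taken together, `script` now dispatches on argument types: plain values run the function as ordinary Python (software emulation), while TVM tensors trigger parsing and `HybridOp` construction. A hedged end-to-end sketch:

.. code-block:: python

   import numpy
   import tvm

   @tvm.hybrid.script
   def outer_product(a, b):
       c = output_tensor((100, 99), 'float32')
       for i in range(a.shape[0]):
           for j in range(b.shape[0]):
               c[i, j] = a[i] * b[j]
       return c

   # numpy arguments: _is_tvm_arg_types(args) is False -> software emulation
   c_np = outer_product(numpy.random.randn(100), numpy.random.randn(99))

   # tvm placeholders: the source is parsed and a HybridOp node is created
   a = tvm.placeholder((100, ), name='a')
   b = tvm.placeholder((99, ), name='b')
   c = outer_product(a, b)   # op.output(0) of the new HybridOp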
......@@ -48,6 +48,7 @@ def allocate(shape, dtype='float32', scope='global'): #pylint: disable=unused-ar
"""
return numpy.zeros(shape).astype(dtype)
output_tensor = allocate #pylint: disable=invalid-name
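# In software emulation, output_tensor therefore behaves exactly like allocate:
# it returns a numpy.zeros buffer of the requested shape and dtype.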
def popcount(x):
"""
......@@ -87,18 +88,19 @@ def sigmoid(x):
HYBRID_GLOBALS = {
-    'unroll'    : unroll,
-    'vectorize' : vectorize,
-    'parallel'  : parallel,
-    'allocate'  : allocate,
-    'bind'      : bind,
-    'sqrt'      : numpy.sqrt,
-    'log'       : numpy.log,
-    'tanh'      : numpy.tanh,
-    'power'     : numpy.power,
-    'exp'       : numpy.exp,
-    'sigmoid'   : sigmoid,
-    'popcount'  : popcount
+    'unroll'       : unroll,
+    'vectorize'    : vectorize,
+    'parallel'     : parallel,
+    'allocate'     : allocate,
+    'output_tensor': output_tensor,
+    'bind'         : bind,
+    'sqrt'         : numpy.sqrt,
+    'log'          : numpy.log,
+    'tanh'         : numpy.tanh,
+    'power'        : numpy.power,
+    'exp'          : numpy.exp,
+    'sigmoid'      : sigmoid,
+    'popcount'     : popcount
}
......
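For context on how these globals come and go around emulation: the `_enter_hybrid_runtime` / `_restore_runtime` pair used by `wrapped_func` in `api.py` injects this table into the function's module globals and later removes it, restoring anything it shadowed. A simplified sketch (the tail of the real `_restore_runtime` appears in the next hunk; details may differ from the actual `util.py`):

.. code-block:: python

   def _enter_hybrid_runtime(func):
       _globals = func.__globals__
       intersect = []
       for elem in list(HYBRID_GLOBALS.keys()):
           if elem in _globals.keys():
               intersect.append((elem, _globals[elem]))
           _globals[elem] = HYBRID_GLOBALS[elem]
       return intersect

   def _restore_runtime(func, intersect):
       _globals = func.__globals__
       for elem in list(HYBRID_GLOBALS.keys()):
           _globals.pop(elem)
       for k, v in intersect:
           _globals[k] = v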
......@@ -2,6 +2,8 @@
import ast
import inspect
import logging
import sys
import numpy
from .intrin import HYBRID_GLOBALS
from .._ffi.base import numeric_types
......@@ -30,10 +32,17 @@ def is_docstring(node):
def _pruned_source(func):
"""Prune source code's extra leading spaces"""
-    lines = inspect.getsource(func).split('\n')
-    leading_space = len(lines[0]) - len(lines[0].lstrip(' '))
-    lines = [line[leading_space:] for line in lines]
-    return '\n'.join(lines)
+    try:
+        lines = inspect.getsource(func).split('\n')
+        leading_space = len(lines[0]) - len(lines[0].lstrip(' '))
+        lines = [line[leading_space:] for line in lines]
+        return '\n'.join(lines)
+    except IOError as err:
+        if sys.version_info[0] == 2 and str(err) == 'could not get source code':
+            logging.log(logging.CRITICAL, \
+                        'This module is not fully supported on Python 2... ' \
+                        'Please move to Python 3!')
+        raise err
def _is_tvm_arg_types(args):
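Why the pruning above matters: `inspect.getsource` on a method (or any nested function) returns lines that share leading indentation, which `ast.parse` rejects at the top level. A small standalone illustration:

.. code-block:: python

   import ast
   import inspect

   class Demo:
       def method(self):
           return 1

   src = inspect.getsource(Demo.method)   # every line starts with 4 spaces
   lines = src.split('\n')
   lead = len(lines[0]) - len(lines[0].lstrip(' '))
   ast.parse('\n'.join(line[lead:] for line in lines))  # parses after pruning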
......@@ -70,3 +79,12 @@ def _restore_runtime(func, intersect):
_globals.pop(elem)
for k, v in intersect:
_globals[k] = v
def _internal_assert(cond, err):
    """Simplify code segments of the form 'if not cond: raise ValueError(err)'"""
    if not cond:
        raise ValueError(err)

# Almost the same functionality as the one above, but in this case the
# error is caused by the user's improper usage.
_user_assert = _internal_assert
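A quick usage sketch of the helper (module path per this commit):

.. code-block:: python

   from tvm.hybrid.util import _internal_assert

   _internal_assert(isinstance(1, int), 'never raises')
   try:
       _internal_assert(False, 'this message reaches the caller')
   except ValueError as err:
       print(err)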
......@@ -3,6 +3,7 @@
import ast
import sys
from .intrin import HYBRID_GLOBALS
from .util import _internal_assert
class PyVariableUsage(ast.NodeVisitor):
......@@ -18,8 +19,8 @@ class PyVariableUsage(ast.NodeVisitor):
def visit_FunctionDef(self, node):
self.scope_level.append(node)
-        if len(node.args.args) != len(self.args):
-            raise ValueError('#arguments passed should be the same as #arguments defined')
+        _internal_assert(len(node.args.args) == len(self.args), \
+                         '#arguments passed should be the same as #arguments defined')
for idx, arg in enumerate(node.args.args):
_attr = 'id' if sys.version_info[0] < 3 else 'arg' # To make py2 and 3 compatible
self._args[getattr(arg, _attr)] = self.args[idx]
......@@ -28,8 +29,8 @@ class PyVariableUsage(ast.NodeVisitor):
def visit_For(self, node):
-        if not isinstance(node.target, ast.Name):
-            raise ValueError("For's iterator should be an id")
+        _internal_assert(isinstance(node.target, ast.Name), \
+                         "For's iterator should be an id")
self.visit(node.iter)
self.scope_level.append(node)
for i in node.body:
......@@ -39,11 +40,10 @@ class PyVariableUsage(ast.NodeVisitor):
def visit_Call(self, node):
#No function pointer supported so far
-        if not isinstance(node.func, ast.Name):
-            raise ValueError("Function call should be an id")
+        _internal_assert(isinstance(node.func, ast.Name), "Function call should be an id")
         func_id = node.func.id
-        if func_id not in list(HYBRID_GLOBALS.keys()) + ['range', 'max', 'min']:
-            raise ValueError("Function call id not in intrinsics' list")
+        _internal_assert(func_id in list(HYBRID_GLOBALS.keys()) + ['range', 'max', 'min'], \
+                         "Function call id not in intrinsics' list")
for elem in node.args:
self.visit(elem)
......@@ -56,12 +56,12 @@ class PyVariableUsage(ast.NodeVisitor):
if node.id in fors:
return
         # The loop variable cannot be overwritten during iteration
-        if isinstance(node.ctx, ast.Store) and node.id in fors:
-            raise ValueError("Iter var cannot be overwritten")
+        _internal_assert(not isinstance(node.ctx, ast.Store) or node.id not in fors, \
+                         "Iter var cannot be overwritten")
         if node.id not in self.status.keys():
-            if not isinstance(node.ctx, ast.Store):
-                raise ValueError('In Python, "first store" indicates "declaration"')
+            _internal_assert(isinstance(node.ctx, ast.Store), \
+                             'Undeclared variable %s' % node.id)
self.status[node.id] = (node, self.scope_level[-1], set())
else:
decl, loop, usage = self.status[node.id]
......
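One subtlety above worth a standalone demo: `visit_FunctionDef` reads an argument's name via `getattr(arg, _attr)` because Python 2 stores it on `ast.Name.id` while Python 3 uses `ast.arg.arg`:

.. code-block:: python

   import ast
   import sys

   tree = ast.parse('def f(x, y): pass')
   _attr = 'id' if sys.version_info[0] < 3 else 'arg'
   print([getattr(a, _attr) for a in tree.body[0].args.args])  # ['x', 'y']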
......@@ -180,3 +180,8 @@ class ScanOp(Operation):
class ExternOp(Operation):
"""Extern operation."""
pass
@register_node
class HybridOp(Operation):
"""Hybrid operation."""
pass
......@@ -313,6 +313,16 @@ TVM_REGISTER_API("_ExternOp")
args[6]);
});
TVM_REGISTER_API("_HybridOp")
.set_body([](TVMArgs args, TVMRetValue* ret) {
*ret = HybridOpNode::make(args[0],
args[1],
args[2],
args[3],
args[4],
args[5]);
});
TVM_REGISTER_API("_OpGetOutput")
.set_body([](TVMArgs args, TVMRetValue* ret) {
*ret = args[0].operator Operation().output(
......
/*!
* Copyright (c) 2018 by Contributors
* \brief Hybrid computation rule.
* \file hybrid_op.cc
*/
#include <tvm/operation.h>
#include <tvm/arithmetic.h>
#include <tvm/ir.h>
#include <tvm/ir_mutator.h>
#include <unordered_set>
#include "op_util.h"
namespace tvm {
using namespace ir;
// HybridOpNode
TVM_STATIC_IR_FUNCTOR(IRPrinter, vtable)
.set_dispatch<HybridOpNode>([](const HybridOpNode *op, IRPrinter *p) {
p->stream << "hybrid(" << op->name << ", " << op << ")";
});
TVM_REGISTER_NODE_TYPE(HybridOpNode);
int HybridOpNode::num_outputs() const {
return static_cast<int>(outputs.size());
}
Array<IterVar> HybridOpNode::root_iter_vars() const {
return {};
}
Type HybridOpNode::output_dtype(size_t i) const {
return outputs[i]->dtype;
}
Array<Expr> HybridOpNode::output_shape(size_t i) const {
return outputs[i]->shape;
}
Operation HybridOpNode::make(std::string name,
std::string tag,
Map<std::string, NodeRef> attrs,
Array<Tensor> inputs,
Array<Tensor> outputs,
Stmt body) {
if (!attrs.defined()) {
attrs = Map<std::string, NodeRef>();
}
auto n = make_node<HybridOpNode>();
n->name = std::move(name);
n->tag = std::move(tag);
n->attrs = std::move(attrs);
n->inputs = std::move(inputs);
n->outputs = std::move(outputs);
n->body = std::move(body);
Operation res = Operation(n);
return res;
}
Array<Tensor> HybridOpNode::InputTensors() const {
return inputs;
}
Operation HybridOpNode::ReplaceInputs(
const Operation& self,
const std::unordered_map<Tensor, Tensor>& rmap) const {
CHECK_EQ(self.operator->(), this);
auto n = make_node<HybridOpNode>(*this);
n->body = op::ReplaceTensor(this->body, rmap);
for (size_t i = 0; i < n->inputs.size(); ++i) {
Tensor t = n->inputs[i];
if (rmap.count(t)) {
n->inputs.Set(i, rmap.at(t));
}
}
if (body.same_as(n->body) &&
inputs.same_as(n->inputs)) {
return self;
} else {
return Operation(n);
}
}
void HybridOpNode::PropBoundToInputs(
const Operation& self,
const std::unordered_map<const Variable*, IntSet>& dom_map,
std::unordered_map<Tensor, TensorDom>* out_dom_map) const {
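  // The hybrid body is opaque to bound inference, so conservatively demand
  // the full region [0, shape[i]) of every input tensor.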
for (Tensor t : this->inputs) {
auto it = out_dom_map->find(t);
if (it == out_dom_map->end()) continue;
TensorDom& dom = it->second;
for (size_t i = 0; i < t->shape.size(); ++i) {
dom.data[i].emplace_back(IntSet::range(
Range::make_by_min_extent(
make_const(t->shape[i].type(), 0), t->shape[i])));
}
}
}
void HybridOpNode::GatherBound(
const Operation& self,
const std::unordered_map<Tensor, TensorDom>& tensor_dom,
std::unordered_map<IterVar, Range>* out_dom_map) const {
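  // Nothing to gather: root_iter_vars() is empty, so a hybrid op exposes
  // no iteration ranges to the scheduler.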
}
Stmt HybridOpNode::BuildRealize(
const Stage& stage,
const std::unordered_map<IterVar, Range>& realize_map,
const Stmt& body) const {
CHECK_EQ(stage->op.get(), this);
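  // Realize each output over its full shape; as with the input bounds above,
  // no tighter region can be inferred from the opaque hybrid body.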
Stmt realize_body = body;
for (int k = 0; k < num_outputs(); ++k) {
Tensor t = stage->op.output(k);
HalideIR::Internal::Region bounds;
for (size_t i = 0; i < t->shape.size(); ++i) {
bounds.push_back(
Range::make_by_min_extent(
make_const(t->shape[i].type(), 0), t->shape[i]));
}
realize_body = ir::Realize::make(
t->op, t->value_index, t->dtype,
bounds, const_true(), realize_body);
}
return realize_body;
}
Stmt HybridOpNode::BuildProvide(
const Stage& stage,
const std::unordered_map<IterVar, Range>& dom_map,
bool debug_keep_trivial_loop) const {
CHECK_EQ(stage->op.operator->(), this);
Stmt ret = AttrStmt::make(make_zero(Int(32)), attr::extern_scope, 0, this->body);
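  // Like an extern op, wrap the body in extern_scope and (below) bind a
  // buffer to every output and input tensor via buffer_bind_scope.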
auto f_push_bind = [&ret](Buffer buffer, Tensor tensor) {
Array<NodeRef> bind_spec;
Array<Expr> tuple;
bind_spec.push_back(buffer);
bind_spec.push_back(tensor);
for (size_t k = 0; k < buffer->shape.size(); ++k) {
tuple.push_back(make_const(buffer->shape[k].type(), 0));
tuple.push_back(buffer->shape[k]);
}
ret = AttrStmt::make(
bind_spec, attr::buffer_bind_scope,
Call::make(Handle(), intrinsic::tvm_tuple, tuple, Call::Intrinsic), ret);
};
for (int i = static_cast<int>(outputs.size()) - 1; i >= 0; --i) {
Buffer buffer = decl_buffer(
outputs[i]->shape,
outputs[i]->dtype);
f_push_bind(buffer, stage->op.output(i));
}
for (int i = static_cast<int>(inputs.size()) - 1; i >= 0; --i) {
Buffer buffer = decl_buffer(
inputs[i]->shape,
inputs[i]->dtype);
f_push_bind(buffer, inputs[i]);
}
std::unordered_map<Tensor, Tensor> rmap;
for (int i = 0; i < this->num_outputs(); ++i) {
rmap[outputs[i]] = stage->op.output(i);
}
auto n = make_node<HybridOpNode>(*this);
/*
 * These two lines of code replace the tensors' reads & writes.
 * This is the simplest way I (@were) can come up with to glue
 * hybrid scripts to the structure of TVM ops.
 * NAMING CONFLICT: In hybrid script, all tensors have their own
 * names specified by the users. However, in TVM ops, all the output
 * tensors' names are the same as the op's name. I cannot change the
 * names to the op's name in the function body after the op node is
 * formed, because:
 * 1. Output tensors all point to the corresponding op node.
 * 2. Once an OpNode is wrapped up by an Operation node, it can
 *    no longer be changed.
 * This is a chicken-and-egg paradox: it is impossible to put the output
 * tensors into the function body without forming the op node, yet the
 * function body is immutable once the node is formed.
 *
 * Finally, I decided to resolve this issue "lazily". During the
 * compilation pipeline, these tensors are replaced when the function
 * body is formed and passed to the next stage of compilation.
 * */
ret = op::ReplaceTensor(ret, rmap);
ret = op::ReplaceProvideTensor(ret, rmap);
return ret;
}
} // namespace tvm
......@@ -164,6 +164,37 @@ std::vector<Stmt> MakeIfNest(const std::vector<Expr>& predicates) {
return nest;
}
// replacer to replace tensors' usage in Provide
class ProviderReplacer : public ir::IRMutator {
public:
explicit ProviderReplacer(const std::unordered_map<Tensor, Tensor>& vmap)
: vmap_(vmap) {}
Stmt Mutate_(const ir::Provide* op, const Stmt& s) {
Tensor t = Operation(op->func.node_).output(op->value_index);
auto it = vmap_.find(t);
if (it != vmap_.end()) {
Stmt ret = ir::Provide::make(
it->second->op, it->second->value_index, op->value, op->args);
found = true;
return IRMutator::Mutate_(ret.as<ir::Provide>(), ret);
}
return IRMutator::Mutate_(op, s);
}
// whether it is found.
bool found{false};
private:
const std::unordered_map<Tensor, Tensor>& vmap_;
};
Stmt ReplaceProvideTensor(Stmt stmt,
const std::unordered_map<Tensor, Tensor>& replace) {
ProviderReplacer repl(replace);
Stmt ret = repl.Mutate(stmt);
return repl.found ? ret : stmt;
}
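// Note the division of labor: ReplaceProvideTensor rewrites writes
// (ir::Provide nodes), while TensorReplacer below rewrites reads (ir::Call's).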
// replacer to replace tensors
class TensorReplacer : public ir::IRMutator {
......
......@@ -49,14 +49,22 @@ MakeLoopNest(const Stage& stage,
std::vector<Stmt> MakeIfNest(const std::vector<Expr>& predicates);
/*!
- * \brief Replace the tensor reference in stmt by the replace map.
+ * \brief Replace the tensor reference (especially in Provide's) in stmt by the replace map.
* \param stmt The statement to be processed.
* \param replace The replacement rule.
*/
Stmt ReplaceProvideTensor(Stmt stmt,
const std::unordered_map<Tensor, Tensor>& replace);
/*!
* \brief Replace the tensor reference (especially in Call's) in stmt by the replace map.
* \param stmt The statement to be processed.
* \param replace The replacement rule.
*/
Stmt ReplaceTensor(Stmt stmt,
const std::unordered_map<Tensor, Tensor>& replace);
/*!
- * \brief Replace the tensor reference in expr by the replace map.
+ * \brief Replace the tensor reference (especially in Call's) in expr by the replace map.
* \param expr The expression to be processed.
* \param replace The replacement rule.
*/
......