Commit 7a01476a by Jian Weng, committed by Tianqi Chen

[HYBRID FRONTEND] Modify hybrid script to new interface; hybrid op supported; enable compilation_database in CMakeLists.txt (#1757)
parent 79735eb2
@@ -57,6 +57,7 @@ include_directories("3rdparty/compiler-rt")
# initial variables
set(TVM_LINKER_LIBS "")
set(TVM_RUNTIME_LINKER_LIBS "")
set(CMAKE_EXPORT_COMPILE_COMMANDS ON)
# Generic compilation options
if(MSVC)
...
@@ -22,13 +22,15 @@ you need to use ``tvm.hybrid.script`` decorator to indicate this is a hybrid fun

    @tvm.hybrid.script
    def outer_product(a, b, c):
+       c = output_tensor((100, 99), 'float32')
        for i in range(a.shape[0]):
            for j in range(b.shape[0]):
                c[i, j] = a[i] * b[j]
+       return c

-   a = numpy.random.rand(100)
-   b = numpy.random.rand(99)
-   c = numpy.zeros((100, 99))
-   outer_product(a, b, c)
+   a = numpy.random.randn(100)
+   b = numpy.random.randn(99)
+   c = outer_product(a, b)
This decorator will import `Keywords`_ required spontaneously when software emulation.
After software emulation is done, the imported keywords will be cleaned up. Users do not need
@@ -40,25 +42,25 @@ or ``numpy`` numeric type.

Backend Compilation
~~~~~~~~~~~~~~~~~~~

+This function is not encouraged to use, users are encouraged to use the second interface.
The current parse interface looks like:

.. code-block:: python

   a = tvm.placeholder((100, ), name='a')
   b = tvm.placeholder((99, ), name='b')
-  c = tvm.placeholder((100, 99), name='c')
-  tvm.hybrid.parse(outer_product, [a, b, c]) # return an ir root of this function
+  parser = tvm.hybrid.parse(outer_product, [a, b]) # return the parser of this function

If we pass these tvm tensors to this function, it returns a op node:

-**Under construction, we are still deciding what kind of node should be returned.**

.. code-block:: python

   a = tvm.placeholder((100, ), name='a')
   b = tvm.placeholder((99, ), name='b')
-  c = tvm.placeholder((100, 99), name='c')
-  op = outer_product(a, b, c) # return the corresponding op node
+  c = outer_product(a, b, c) # return the output tensor(s) of the operator

-**Under construction, we are still deciding what kind of node should be returned.**

Tuning
~~~~~~

...
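To make the new interface above concrete, here is a minimal end-to-end sketch (not part of this patch) of the backend compilation path. It assumes the standard `tvm.placeholder` / `tvm.create_schedule` / `tvm.build` flow of this TVM generation; note that the output is declared inside the body with `output_tensor` and returned, so only the inputs are parameters.

```python
import tvm

@tvm.hybrid.script
def outer_product(a, b):
    c = output_tensor((100, 99), 'float32')
    for i in range(a.shape[0]):
        for j in range(b.shape[0]):
            c[i, j] = a[i] * b[j]
    return c

# Calling with placeholders makes the decorator build a HybridOp and
# hand back its output tensor, which plugs into the normal pipeline.
a = tvm.placeholder((100, ), name='a')
b = tvm.placeholder((99, ), name='b')
c = outer_product(a, b)            # output tensor of the generated HybridOp
sch = tvm.create_schedule(c.op)    # schedule it like any other operation
module = tvm.build(sch, [a, b, c]) # lower and build as usual
```

The point of the change is that the hybrid call itself now returns tensors, so the resulting op goes through the ordinary schedule/lower/build pipeline instead of a special-cased Stmt path.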
@@ -450,6 +450,69 @@ class ExternOpNode : public OperationNode {
  TVM_DECLARE_NODE_TYPE_INFO(ExternOpNode, OperationNode);
};
/*!
* \brief A computation operator that generated by hybrid script.
*/
class HybridOpNode : public OperationNode {
public:
/*! \brief The input tensors */
Array<Tensor> inputs;
/*! \brief Symbolic placeholder representation of outputs */
Array<Tensor> outputs;
/*! \brief The statement that generates the computation. This is
 * slightly different from the body in ExternOpNode. All the output
 * tensors keep their own names as specified by users in the script.
 * However, during compilation, these tensors will be replaced by the
 * actual output tensors. */
Stmt body;
/*! \brief constructor */
HybridOpNode() {}
// override functions
int num_outputs() const final;
Array<IterVar> root_iter_vars() const final;
Type output_dtype(size_t i) const final;
Array<Expr> output_shape(size_t i) const final;
Array<Tensor> InputTensors() const final;
Operation ReplaceInputs(
const Operation& self,
const std::unordered_map<Tensor, Tensor>& rmap) const final;
void PropBoundToInputs(
const Operation& self,
const std::unordered_map<const Variable*, IntSet>& dom_map,
std::unordered_map<Tensor, TensorDom>* out_dom_map) const final;
void GatherBound(
const Operation& self,
const std::unordered_map<Tensor, TensorDom>& tensor_dom,
std::unordered_map<IterVar, Range>* out_dom_map) const final;
Stmt BuildRealize(
const Stage& stage,
const std::unordered_map<IterVar, Range>& realize_map,
const Stmt& body) const final;
Stmt BuildProvide(
const Stage& stage,
const std::unordered_map<IterVar, Range>& dom_map,
bool debug_keep_trivial_loop) const final;
void VisitAttrs(AttrVisitor* v) final {
v->Visit("name", &name);
v->Visit("tag", &tag);
v->Visit("attrs", &attrs);
v->Visit("inputs", &inputs);
v->Visit("outputs", &outputs);
v->Visit("body", &body);
}
EXPORT static Operation make(std::string name,
std::string tag,
Map<std::string, NodeRef> attrs,
Array<Tensor> inputs,
Array<Tensor> outputs,
Stmt body);
static constexpr const char* _type_key = "HybridOp";
TVM_DECLARE_NODE_TYPE_INFO(HybridOpNode, OperationNode);
};
/*! \brief The compute function to specify the input source of a Tensor */
using FCompute = std::function<Expr (const Array<Var>& i)>;
...
@@ -340,11 +340,6 @@ def lower(sch,
        bounds = schedule.InferBound(sch)
        stmt = schedule.ScheduleOps(sch, bounds)
        stmt = ir_pass.InjectPrefetch(stmt)
-    else:
-        #So far there is no op for hybrid script, so a plain ir body is given
-        if not isinstance(sch, _stmt.Stmt):
-            raise ValueError("sch should be either a Schedule or a Stmt")
-        stmt = sch
    for f in lower_phase0:
        stmt = f(stmt)
...
@@ -7,4 +7,5 @@ python semantic emulation.
2. Developers can build HalideIR by writing Python code.
"""
-from .api import script, parse
+from .api import script
+from .parser import parse_python
...
"""APIs of lowering the Python subset to HalideIR""" """APIs of lowering the Python subset to HalideIR"""
from __future__ import absolute_import as _abs from __future__ import absolute_import as _abs
import types
from .._ffi.base import decorate
+from .. import _api_internal as _tvm_internal
+from ..tensor import Tensor
from .parser import parse_python
+from .util import _pruned_source


def script(pyfunc):
@@ -17,40 +20,26 @@ def script(pyfunc):
    hybrid_func : function
        A decorated hybrid script function.
    """
-    def wrapped_func(func, *args, **kwargs):
+    def wrapped_func(func, *args, **kwargs): #pylint: disable=missing-docstring
        from .util import _enter_hybrid_runtime, _restore_runtime, _is_tvm_arg_types
        if _is_tvm_arg_types(args):
-            return parse(func, args)
+            src = _pruned_source(func)
+            parser = parse_python(src, args)
+
+            input_tensors = []
+            for i in args:
+                if isinstance(i, Tensor):
+                    input_tensors.append(i)
+            op = _tvm_internal._HybridOp(parser.func_name, "HybridOp", None, input_tensors,
+                                         parser.outputs, parser.parsed_body)
+            res = [op.output(i) for i in range(len(parser.outputs))]
+            return res[0] if len(res) == 1 else res

        intersect = _enter_hybrid_runtime(func)
        value = func(*args, **kwargs)
        _restore_runtime(func, intersect)
        return value
-    return decorate(pyfunc, wrapped_func)
-
-def parse(func, args):
-    """Parse a subset of Python to HalideIR
-
-    Parameters
-    ----------
-    func : str or types.FunctionType
-        If it is a string, parse the source code
-        If it is a function, parse the function
-
-    args : list of Buffer or Tensor or Var
-        The argument lists to the function.
-        Leave it None if no buffer is related to the function to be parsed
-
-    Returns
-    -------
-    root : Stmt
-        The result Halide IR and the parser class instance.
-    """
-    from .util import _pruned_source
-    if isinstance(func, str):
-        src = func
-    else:
-        assert isinstance(func, types.FunctionType)
-        src = _pruned_source(func)
-    return parse_python(src, args)
+    return decorate(pyfunc, wrapped_func)
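For reference, a small sketch (not from this patch) of how the rewritten `wrapped_func` behaves from the user's side: with plain numpy arguments the body is emulated in Python, while with TVM-typed arguments the source is pruned, parsed by `parse_python`, and wrapped into a `HybridOp` whose output tensor(s) are returned. The function name `add` and the shapes below are made up for illustration.

```python
import numpy
import tvm
from tvm import hybrid

@hybrid.script
def add(a, b):
    c = output_tensor((100, ), 'float32')
    for i in range(100):
        c[i] = a[i] + b[i]
    return c

# Non-TVM arguments: the body runs directly in Python, with output_tensor
# and the other keywords temporarily injected into the function's globals.
x = numpy.random.rand(100).astype('float32')
y = numpy.random.rand(100).astype('float32')
print(add(x, y)[:5])

# TVM-typed arguments: a HybridOp node is constructed and its output
# tensor is returned instead of a numpy result.
a = tvm.placeholder((100, ), name='a', dtype='float32')
b = tvm.placeholder((100, ), name='b', dtype='float32')
c = add(a, b)
print(type(c.op))   # expected: tvm.tensor.HybridOp
```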
@@ -48,6 +48,7 @@ def allocate(shape, dtype='float32', scope='global'): #pylint: disable=unused-ar
    """
    return numpy.zeros(shape).astype(dtype)

+output_tensor = allocate #pylint: disable=invalid-name

def popcount(x):
    """
@@ -87,18 +88,19 @@ def sigmoid(x):

HYBRID_GLOBALS = {
    'unroll'        : unroll,
    'vectorize'     : vectorize,
    'parallel'      : parallel,
    'allocate'      : allocate,
+   'output_tensor' : output_tensor,
    'bind'          : bind,
    'sqrt'          : numpy.sqrt,
    'log'           : numpy.log,
    'tanh'          : numpy.tanh,
    'power'         : numpy.power,
    'exp'           : numpy.exp,
    'sigmoid'       : sigmoid,
    'popcount'      : popcount
}
...
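Since `output_tensor` is only an alias of `allocate` in the software-emulation dictionary, its behavior outside a script can be checked directly. A small sketch, with the module path `tvm.hybrid.intrin` assumed from this diff:

```python
import numpy
from tvm.hybrid.intrin import HYBRID_GLOBALS

# During emulation, output_tensor hands back a zero-initialized numpy
# buffer of the requested shape and dtype, exactly like allocate.
out = HYBRID_GLOBALS['output_tensor']((2, 3), 'int32')
assert isinstance(out, numpy.ndarray)
assert out.shape == (2, 3) and out.dtype == numpy.int32
assert HYBRID_GLOBALS['output_tensor'] is HYBRID_GLOBALS['allocate']
```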
@@ -2,6 +2,8 @@
import ast
import inspect
+import logging
+import sys
import numpy

from .intrin import HYBRID_GLOBALS
from .._ffi.base import numeric_types
@@ -30,10 +32,17 @@ def is_docstring(node):

def _pruned_source(func):
    """Prune source code's extra leading spaces"""
-    lines = inspect.getsource(func).split('\n')
-    leading_space = len(lines[0]) - len(lines[0].lstrip(' '))
-    lines = [line[leading_space:] for line in lines]
-    return '\n'.join(lines)
+    try:
+        lines = inspect.getsource(func).split('\n')
+        leading_space = len(lines[0]) - len(lines[0].lstrip(' '))
+        lines = [line[leading_space:] for line in lines]
+        return '\n'.join(lines)
+    except IOError as err:
+        if sys.version_info[0] == 2 and str(err) == 'could not get source code':
+            logging.log(logging.CRITICAL, \
+                        'This module is not fully operated under Python2... ' \
+                        'Please move to Python3!')
+        raise err


def _is_tvm_arg_types(args):
@@ -70,3 +79,12 @@ def _restore_runtime(func, intersect):
            _globals.pop(elem)
    for k, v in intersect:
        _globals[k] = v
def _internal_assert(cond, err):
    """Simplify the code segment like if not XXX then raise an error"""
    if not cond:
        raise ValueError(err)

# Almost the same functionality as the one above, but in this case,
# the error is caused by the user's improper usage.
_user_assert = _internal_assert
@@ -3,6 +3,7 @@
import ast
import sys
from .intrin import HYBRID_GLOBALS
+from .util import _internal_assert


class PyVariableUsage(ast.NodeVisitor):
@@ -18,8 +19,8 @@ class PyVariableUsage(ast.NodeVisitor):

    def visit_FunctionDef(self, node):
        self.scope_level.append(node)
-        if len(node.args.args) != len(self.args):
-            raise ValueError('#arguments passed should be the same as #arguments defined')
+        _internal_assert(len(node.args.args) == len(self.args), \
+                         '#arguments passed should be the same as #arguments defined')
        for idx, arg in enumerate(node.args.args):
            _attr = 'id' if sys.version_info[0] < 3 else 'arg' # To make py2 and 3 compatible
            self._args[getattr(arg, _attr)] = self.args[idx]

@@ -28,8 +29,8 @@ class PyVariableUsage(ast.NodeVisitor):

    def visit_For(self, node):
-        if not isinstance(node.target, ast.Name):
-            raise ValueError("For's iterator should be an id")
+        _internal_assert(isinstance(node.target, ast.Name), \
+                         "For's iterator should be an id")
        self.visit(node.iter)
        self.scope_level.append(node)
        for i in node.body:

@@ -39,11 +40,10 @@ class PyVariableUsage(ast.NodeVisitor):

    def visit_Call(self, node):
        #No function pointer supported so far
-        if not isinstance(node.func, ast.Name):
-            raise ValueError("Function call should be an id")
+        _internal_assert(isinstance(node.func, ast.Name), "Function call should be an id")
        func_id = node.func.id
-        if func_id not in list(HYBRID_GLOBALS.keys()) + ['range', 'max', 'min']:
-            raise ValueError("Function call id not in intrinsics' list")
+        _internal_assert(func_id in list(HYBRID_GLOBALS.keys()) + ['range', 'max', 'min'], \
+                         "Function call id not in intrinsics' list")
        for elem in node.args:
            self.visit(elem)

@@ -56,12 +56,12 @@ class PyVariableUsage(ast.NodeVisitor):

        if node.id in fors:
            return
        # The loop variable cannot be overwritten when iteration
-        if isinstance(node.ctx, ast.Store) and node.id in fors:
-            raise ValueError("Iter var cannot be overwritten")
+        _internal_assert(not isinstance(node.ctx, ast.Store) or node.id not in fors, \
+                         "Iter var cannot be overwritten")
        if node.id not in self.status.keys():
-            if not isinstance(node.ctx, ast.Store):
-                raise ValueError('In Python, "first store" indicates "declaration"')
+            _internal_assert(isinstance(node.ctx, ast.Store), \
+                             'Undeclared variable %s' % node.id)
            self.status[node.id] = (node, self.scope_level[-1], set())
        else:
            decl, loop, usage = self.status[node.id]
...
@@ -180,3 +180,8 @@ class ScanOp(Operation):

class ExternOp(Operation):
    """Extern operation."""
    pass

+@register_node
+class HybridOp(Operation):
+    """Hybrid operation."""
+    pass
@@ -313,6 +313,16 @@ TVM_REGISTER_API("_ExternOp")
args[6]);
});
TVM_REGISTER_API("_HybridOp")
.set_body([](TVMArgs args, TVMRetValue* ret) {
*ret = HybridOpNode::make(args[0],
args[1],
args[2],
args[3],
args[4],
args[5]);
});
TVM_REGISTER_API("_OpGetOutput") TVM_REGISTER_API("_OpGetOutput")
.set_body([](TVMArgs args, TVMRetValue* ret) { .set_body([](TVMArgs args, TVMRetValue* ret) {
*ret = args[0].operator Operation().output( *ret = args[0].operator Operation().output(
......
/*!
* Copyright (c) 2018 by Contributors
* \brief Hybrid computation rule.
* \file hybrid_op.cc
*/
#include <tvm/operation.h>
#include <tvm/arithmetic.h>
#include <tvm/ir.h>
#include <tvm/ir_mutator.h>
#include <unordered_set>
#include "op_util.h"
namespace tvm {
using namespace ir;
// HybridOpNode
TVM_STATIC_IR_FUNCTOR(IRPrinter, vtable)
.set_dispatch<HybridOpNode>([](const HybridOpNode *op, IRPrinter *p) {
p->stream << "hybrid(" << op->name << ", " << op << ")";
});
TVM_REGISTER_NODE_TYPE(HybridOpNode);
int HybridOpNode::num_outputs() const {
return static_cast<int>(outputs.size());
}
Array<IterVar> HybridOpNode::root_iter_vars() const {
return {};
}
Type HybridOpNode::output_dtype(size_t i) const {
return outputs[i]->dtype;
}
Array<Expr> HybridOpNode::output_shape(size_t i) const {
return outputs[i]->shape;
}
Operation HybridOpNode::make(std::string name,
std::string tag,
Map<std::string, NodeRef> attrs,
Array<Tensor> inputs,
Array<Tensor> outputs,
Stmt body) {
if (!attrs.defined()) {
attrs = Map<std::string, NodeRef>();
}
auto n = make_node<HybridOpNode>();
n->name = std::move(name);
n->tag = std::move(tag);
n->attrs = std::move(attrs);
n->inputs = std::move(inputs);
n->outputs = std::move(outputs);
n->body = std::move(body);
Operation res = Operation(n);
return res;
}
Array<Tensor> HybridOpNode::InputTensors() const {
return inputs;
}
Operation HybridOpNode::ReplaceInputs(
const Operation& self,
const std::unordered_map<Tensor, Tensor>& rmap) const {
CHECK_EQ(self.operator->(), this);
auto n = make_node<HybridOpNode>(*this);
n->body = op::ReplaceTensor(this->body, rmap);
for (size_t i = 0; i < n->inputs.size(); ++i) {
Tensor t = n->inputs[i];
if (rmap.count(t)) {
n->inputs.Set(i, rmap.at(t));
}
}
if (body.same_as(n->body) &&
inputs.same_as(n->inputs)) {
return self;
} else {
return Operation(n);
}
}
void HybridOpNode::PropBoundToInputs(
const Operation& self,
const std::unordered_map<const Variable*, IntSet>& dom_map,
std::unordered_map<Tensor, TensorDom>* out_dom_map) const {
for (Tensor t : this->inputs) {
auto it = out_dom_map->find(t);
if (it == out_dom_map->end()) continue;
TensorDom& dom = it->second;
for (size_t i = 0; i < t->shape.size(); ++i) {
dom.data[i].emplace_back(IntSet::range(
Range::make_by_min_extent(
make_const(t->shape[i].type(), 0), t->shape[i])));
}
}
}
void HybridOpNode::GatherBound(
const Operation& self,
const std::unordered_map<Tensor, TensorDom>& tensor_dom,
std::unordered_map<IterVar, Range>* out_dom_map) const {
}
Stmt HybridOpNode::BuildRealize(
const Stage& stage,
const std::unordered_map<IterVar, Range>& realize_map,
const Stmt& body) const {
CHECK_EQ(stage->op.get(), this);
Stmt realize_body = body;
for (int k = 0; k < num_outputs(); ++k) {
Tensor t = stage->op.output(k);
HalideIR::Internal::Region bounds;
for (size_t i = 0; i < t->shape.size(); ++i) {
bounds.push_back(
Range::make_by_min_extent(
make_const(t->shape[i].type(), 0), t->shape[i]));
}
realize_body = ir::Realize::make(
t->op, t->value_index, t->dtype,
bounds, const_true(), realize_body);
}
return realize_body;
}
Stmt HybridOpNode::BuildProvide(
const Stage& stage,
const std::unordered_map<IterVar, Range>& dom_map,
bool debug_keep_trivial_loop) const {
CHECK_EQ(stage->op.operator->(), this);
Stmt ret = AttrStmt::make(make_zero(Int(32)), attr::extern_scope, 0, this->body);
auto f_push_bind = [&ret](Buffer buffer, Tensor tensor) {
Array<NodeRef> bind_spec;
Array<Expr> tuple;
bind_spec.push_back(buffer);
bind_spec.push_back(tensor);
for (size_t k = 0; k < buffer->shape.size(); ++k) {
tuple.push_back(make_const(buffer->shape[k].type(), 0));
tuple.push_back(buffer->shape[k]);
}
ret = AttrStmt::make(
bind_spec, attr::buffer_bind_scope,
Call::make(Handle(), intrinsic::tvm_tuple, tuple, Call::Intrinsic), ret);
};
for (int i = static_cast<int>(outputs.size()) - 1; i >= 0; --i) {
Buffer buffer = decl_buffer(
outputs[i]->shape,
outputs[i]->dtype);
f_push_bind(buffer, stage->op.output(i));
}
for (int i = static_cast<int>(inputs.size()) - 1; i >= 0; --i) {
Buffer buffer = decl_buffer(
inputs[i]->shape,
inputs[i]->dtype);
f_push_bind(buffer, inputs[i]);
}
std::unordered_map<Tensor, Tensor> rmap;
for (int i = 0; i < this->num_outputs(); ++i) {
rmap[outputs[i]] = stage->op.output(i);
}
auto n = make_node<HybridOpNode>(*this);
/*
* These two lines of code replace tensors' reads & writes.
* This is the simplest way I (@were) can come up with to glue
* hybrid scripts to the structure of TVM op.
* NAMING CONFLICT: In hybrid script all the tensors have their own
* names specified by the users. However, In TVM op, all the output
* tensors' names are the same as the op's name. I cannot change the
* name to the op's name in the function body after the op node is
* formed, because:
* 1. Output tensors all point to the corresponding op node.
* 2. Once OpNode is wrapped up by an Operation node, it can
* no longer be changed.
* This is a chicken-and-egg paradox. It is impossible to put the output
* tensors into the function body without forming the op node. The
* function body is immutable after the node is formed.
*
* Finally, I decided to resolve this issue "lazily". During the
* pipeline of compilation, these tensors will be replaced when
* forming the function body and passing to next stage of compilation.
* */
ret = op::ReplaceTensor(ret, rmap);
ret = op::ReplaceProvideTensor(ret, rmap);
return ret;
}
} // namespace tvm
@@ -164,6 +164,37 @@ std::vector<Stmt> MakeIfNest(const std::vector<Expr>& predicates) {
  return nest;
}
// replacer to replace tensors' usage in Provide
class ProviderReplacer : public ir::IRMutator {
public:
explicit ProviderReplacer(const std::unordered_map<Tensor, Tensor>& vmap)
: vmap_(vmap) {}
Stmt Mutate_(const ir::Provide* op, const Stmt& s) {
Tensor t = Operation(op->func.node_).output(op->value_index);
auto it = vmap_.find(t);
if (it != vmap_.end()) {
Stmt ret = ir::Provide::make(
it->second->op, it->second->value_index, op->value, op->args);
found = true;
return IRMutator::Mutate_(ret.as<ir::Provide>(), ret);
}
return IRMutator::Mutate_(op, s);
}
// whether it is found.
bool found{false};
private:
const std::unordered_map<Tensor, Tensor>& vmap_;
};
Stmt ReplaceProvideTensor(Stmt stmt,
const std::unordered_map<Tensor, Tensor>& replace) {
ProviderReplacer repl(replace);
Stmt ret = repl.Mutate(stmt);
return repl.found ? ret : stmt;
}
// replacer to replace tensors
class TensorReplacer : public ir::IRMutator {
...
@@ -49,14 +49,22 @@ MakeLoopNest(const Stage& stage,
std::vector<Stmt> MakeIfNest(const std::vector<Expr>& predicates);

/*!
- * \brief Replace the tensor reference in stmt by the replace map.
+ * \brief Replace the tensor reference (especially in Provide's) in stmt by the replace map.
+ * \param stmt The statement to be processed.
+ * \param replace The replacement rule.
+ */
+Stmt ReplaceProvideTensor(Stmt stmt,
+                          const std::unordered_map<Tensor, Tensor>& replace);
+
+/*!
+ * \brief Replace the tensor reference (especially in Call's) in stmt by the replace map.
 * \param stmt The statement to be processed.
 * \param replace The replacement rule.
 */
Stmt ReplaceTensor(Stmt stmt,
                   const std::unordered_map<Tensor, Tensor>& replace);

/*!
- * \brief Replace the tensor reference in expr by the replace map.
+ * \brief Replace the tensor reference (especially in Call's) in stmt by the replace map.
 * \param expr The expression to be processed.
 * \param replace The replacement rule.
 */
...