Commit ac54577f by Jian Weng Committed by Lianmin Zheng

[Hybrid Script] Supporting scheduling hybrid script (#2416)

* on the way to enable hybrid schedule

* I think I am done with imperfect loop split?

* copyright watermark

* loop annotation

* fix lint

* fix lint 1

* shit!

* loop reorder supported

* support bind to add some tests

* fused tested

* imperfect loop testcase

* fix lint

* add bind testcase

* fix doc

* fix online edit typo

* resolve @mercymercy review

* fix indent

* i should convince myself it is not flaky test first

* fix test hybrid

* how many flaky test are you expecting; i ball ball u to let me pass

* rebase halide...
parent b9604671
...@@ -68,17 +68,23 @@ to LLVM module. ...@@ -68,17 +68,23 @@ to LLVM module.
Tuning Tuning
~~~~~~ ~~~~~~
**Under construction, not supported yet.**
Follow up the example above, you can use some tvm like interfaces to tune the code: Follow up the example above, you can use some tvm like interfaces to tune the code:
.. code-block:: python .. code-block:: python
i, j = c.op.axis
sch = tvm.create_schedule(op) sch = tvm.create_schedule(op)
jo, ji = sch.split(j, 4) jo, ji = sch.split(j, 4)
sch.vectorize(ji) sch.vectorize(ji)
``split``, ``reorder``, and loop_annotation will be supported! For now, you can use loop annotations (``unroll``, ``parallel``, ``vectorize``, and ``bind``),
loop manipulation (``split`` and ``fuse``), and ``reorder``.
.. note::
This is a preliminary function, so users should be in charge of the correctness
of the functionality after tuning. Specifically, users should be careful when
fusing and reorderding imperfect loops.
Loops Loops
~~~~~ ~~~~~
......
...@@ -459,6 +459,8 @@ class HybridOpNode : public OperationNode { ...@@ -459,6 +459,8 @@ class HybridOpNode : public OperationNode {
Array<Tensor> inputs; Array<Tensor> inputs;
/*! \brief Symbolic placeholder representation of outputs */ /*! \brief Symbolic placeholder representation of outputs */
Array<Tensor> outputs; Array<Tensor> outputs;
/*! \brief The axis of iterations */
Array<IterVar> axis;
/*! \brief the statement that generates the computation. This is /*! \brief the statement that generates the computation. This is
* slightly different from the body in ExternOpNode. All the output * slightly different from the body in ExternOpNode. All the output
* tensors keep its own name specified by users in the script. * tensors keep its own name specified by users in the script.
...@@ -500,6 +502,7 @@ class HybridOpNode : public OperationNode { ...@@ -500,6 +502,7 @@ class HybridOpNode : public OperationNode {
v->Visit("attrs", &attrs); v->Visit("attrs", &attrs);
v->Visit("inputs", &inputs); v->Visit("inputs", &inputs);
v->Visit("outputs", &outputs); v->Visit("outputs", &outputs);
v->Visit("axis", &axis);
v->Visit("body", &body); v->Visit("body", &body);
} }
EXPORT static Operation make(std::string name, EXPORT static Operation make(std::string name,
......
...@@ -152,7 +152,7 @@ class ComputeOp(Operation): ...@@ -152,7 +152,7 @@ class ComputeOp(Operation):
"""Compute operation.""" """Compute operation."""
@property @property
def axis(self): def axis(self):
"""Represent axis of IterVar, only defined when it is a ComputeOp""" """Represent axis of IterVar, defined when it is a ComputeOp"""
return self.__getattr__("axis") return self.__getattr__("axis")
@property @property
...@@ -184,4 +184,7 @@ class ExternOp(Operation): ...@@ -184,4 +184,7 @@ class ExternOp(Operation):
@register_node @register_node
class HybridOp(Operation): class HybridOp(Operation):
"""Hybrid operation.""" """Hybrid operation."""
pass @property
def axis(self):
"""Represent axis of IterVar, also defined when it is a HybridOp"""
return self.__getattr__("axis")
...@@ -212,6 +212,7 @@ void ComputeOpNode::GatherBound( ...@@ -212,6 +212,7 @@ void ComputeOpNode::GatherBound(
const Operation& self, const Operation& self,
const std::unordered_map<Tensor, TensorDom>& tensor_dom, const std::unordered_map<Tensor, TensorDom>& tensor_dom,
std::unordered_map<IterVar, Range>* out_dom_map) const { std::unordered_map<IterVar, Range>* out_dom_map) const {
CHECK_EQ(self.operator->(), this);
const TensorDom& tdom = tensor_dom.at(self.output(0)); const TensorDom& tdom = tensor_dom.at(self.output(0));
for (size_t i = 0; i < this->axis.size(); ++i) { for (size_t i = 0; i < this->axis.size(); ++i) {
Range r = arith::Union(tdom.data.at(i)).cover_range(this->axis[i]->dom); Range r = arith::Union(tdom.data.at(i)).cover_range(this->axis[i]->dom);
......
/*!
* Copyright (c) 2019 by Contributors
* \brief Helper utilities to implement hybrid_op.
* \file hybrid_op.h
*/
#ifndef TVM_OP_HYBRID_OP_H_
#define TVM_OP_HYBRID_OP_H_
#include <tvm/expr.h>
#include <tvm/ir_mutator.h>
#include <tvm/ir_visitor.h>
#include <tvm/schedule.h>
#include <unordered_map>
#include <unordered_set>
#include <vector>
#include "../pass/ir_util.h"
#include "../pass/arg_binder.h"
#include "../schedule/message_passing.h"
namespace tvm {
namespace op {
/*!
* \brief Find all the iteration variables in the given statement body.
* \param stmt The body to be inspected.
*/
std::vector<IterVar> GatherLoopVars(Stmt stmt);
/*!
* \brief Replace the tensor reference (especially in Provide's) in stmt by the replace map.
* \param stmt The statement to be processed.
* \param replace The replacement rule.
*/
Stmt ReplaceProvideTensor(Stmt stmt,
const std::unordered_map<Tensor, Tensor>& replace);
/*!
* \brief Apply the schedule manipulation on the function body.
* \param stmt The statement to be processed.
* \param dom_map The extents of the iterative variables may be used.
* \param stage The schedule information to be applied.
*/
Stmt ApplySchedule(const Stage& stage,
const std::unordered_map<IterVar, Range>& dom_map, Stmt stmt);
/*!
* \brief Apply loop splits and fuses in the schedule on the function body.
* \param stage The schedule information to be applied.
* \param dom_map The extents of the iterative variables may be used.
* \param stmt The statement to be processed.
*/
Stmt ApplyLoopShapes(const Stage &stage,
const std::unordered_map<IterVar, Range>& dom_map, Stmt stmt);
/*!
* \brief Apply loop annotation in the schedule on the function body.
* \param stage The schedule information to be applied.
* \param rebased The map specifies the rebase, a.k.a rename, relationship of these variables.
* \param stmt The statement to be processed.
*/
Stmt ApplyLoopAnnotations(const Stage &stage,
const std::unordered_map<IterVar, IterVar>& rebased, Stmt stmt);
/*!
* \brief Apply loop order in the schedule on the function body.
* \param stage The schedule information to be applied.
* \param dom_map The extents of the iterative variables may be used.
* \param rebased The map specifies the rebase, a.k.a rename, relationship of these variables.
* \param stmt The statement to be processed.
*/
Stmt ApplyLoopOrder(const Stage &stage,
const std::unordered_map<IterVar, Range> &dom_map,
const std::unordered_map<IterVar, IterVar> &rebased, Stmt stmt);
} // namespace op
} // namespace tvm
#endif // TVM_OP_HYBRID_OP_H_
...@@ -164,38 +164,6 @@ std::vector<Stmt> MakeIfNest(const std::vector<Expr>& predicates) { ...@@ -164,38 +164,6 @@ std::vector<Stmt> MakeIfNest(const std::vector<Expr>& predicates) {
return nest; return nest;
} }
// replacer to replace tensors' usage in Provide
class ProviderReplacer : public ir::IRMutator {
public:
explicit ProviderReplacer(const std::unordered_map<Tensor, Tensor>& vmap)
: vmap_(vmap) {}
Stmt Mutate_(const ir::Provide* op, const Stmt& s) {
Tensor t = Operation(op->func.node_).output(op->value_index);
auto it = vmap_.find(t);
if (it != vmap_.end()) {
Stmt ret = ir::Provide::make(
it->second->op, it->second->value_index, op->value, op->args);
found = true;
return IRMutator::Mutate_(ret.as<ir::Provide>(), ret);
}
return IRMutator::Mutate_(op, s);
}
// whether it is found.
bool found{false};
private:
const std::unordered_map<Tensor, Tensor>& vmap_;
};
Stmt ReplaceProvideTensor(Stmt stmt,
const std::unordered_map<Tensor, Tensor>& replace) {
ProviderReplacer repl(replace);
Stmt ret = repl.Mutate(stmt);
return repl.found ? ret : stmt;
}
// replacer to replace tensors // replacer to replace tensors
class TensorReplacer : public ir::IRMutator { class TensorReplacer : public ir::IRMutator {
public: public:
...@@ -247,5 +215,35 @@ Stmt Substitute(Stmt s, ...@@ -247,5 +215,35 @@ Stmt Substitute(Stmt s,
return ir::Substitute(s, init); return ir::Substitute(s, init);
} }
IterVarType ForTypeToIterVarType(ir::ForType for_type) {
switch (for_type) {
case ForType::Serial:
return kDataPar;
case ForType::Parallel:
return kParallelized;
case ForType::Vectorized:
return kVectorized;
case ForType::Unrolled:
return kUnrolled;
default:
return kDataPar;
}
}
ir::ForType IterVarTypeToForType(IterVarType iter_type) {
switch (iter_type) {
case kDataPar:
return ForType::Serial;
case kParallelized:
return ForType::Parallel;
case kVectorized:
return ForType::Vectorized;
case kUnrolled:
return ForType::Unrolled;
default:
return ForType::Serial;
}
}
} // namespace op } // namespace op
} // namespace tvm } // namespace tvm
...@@ -49,14 +49,6 @@ MakeLoopNest(const Stage& stage, ...@@ -49,14 +49,6 @@ MakeLoopNest(const Stage& stage,
std::vector<Stmt> MakeIfNest(const std::vector<Expr>& predicates); std::vector<Stmt> MakeIfNest(const std::vector<Expr>& predicates);
/*! /*!
* \brief Replace the tensor reference (especially in Provide's) in stmt by the replace map.
* \param stmt The statement to be processed.
* \param replace The replacement rule.
*/
Stmt ReplaceProvideTensor(Stmt stmt,
const std::unordered_map<Tensor, Tensor>& replace);
/*!
* \brief Replace the tensor reference (especially in Call's) in stmt by the replace map. * \brief Replace the tensor reference (especially in Call's) in stmt by the replace map.
* \param stmt The statement to be processed. * \param stmt The statement to be processed.
* \param replace The replacement rule. * \param replace The replacement rule.
...@@ -80,6 +72,18 @@ Expr ReplaceTensor(Expr expr, ...@@ -80,6 +72,18 @@ Expr ReplaceTensor(Expr expr,
Stmt Substitute(Stmt stmt, Stmt Substitute(Stmt stmt,
const std::unordered_map<IterVar, Expr>& value_map); const std::unordered_map<IterVar, Expr>& value_map);
/*!
* \brief Converts Halide ForType to its corresponding IterVarType
* \param for_type The ForType to be converted
*/
IterVarType ForTypeToIterVarType(ir::ForType for_type);
/*!
* \brief Converts IterVarType to its corresponding Halide ForType
* \param iter_type The IterVarType to be converted
*/
ir::ForType IterVarTypeToForType(IterVarType iter_type);
} // namespace op } // namespace op
} // namespace tvm } // namespace tvm
#endif // TVM_OP_OP_UTIL_H_ #endif // TVM_OP_OP_UTIL_H_
...@@ -3,7 +3,7 @@ from tvm.hybrid import script ...@@ -3,7 +3,7 @@ from tvm.hybrid import script
from tvm.hybrid.intrin import HYBRID_GLOBALS from tvm.hybrid.intrin import HYBRID_GLOBALS
@nose.tools.nottest @nose.tools.nottest
def run_and_check(func, args, var_dict={}, target='llvm'): def run_and_check(func, args, var_dict={}, target='llvm', sch=None, outs=None):
def tvm_val_2_py_val(val): def tvm_val_2_py_val(val):
val = tvm.ir_pass.Substitute(val, var_dict) val = tvm.ir_pass.Substitute(val, var_dict)
val = tvm.ir_pass.Simplify(val) val = tvm.ir_pass.Simplify(val)
...@@ -13,8 +13,14 @@ def run_and_check(func, args, var_dict={}, target='llvm'): ...@@ -13,8 +13,14 @@ def run_and_check(func, args, var_dict={}, target='llvm'):
ctx = tvm.context(target, 0) ctx = tvm.context(target, 0)
op = None op = None
outs = func(*tuple(tvm.convert(i) if isinstance(i, list) else i for i in args)) if sch is None:
op = outs[0].op if isinstance(outs, list) else outs.op outs = func(*tuple(tvm.convert(i) if isinstance(i, list) else i for i in args))
op = outs[0].op if isinstance(outs, list) else outs.op
sch = tvm.create_schedule(op)
else:
assert outs is not None
assert isinstance(outs, list)
op = outs[0].op
emu_args = [] emu_args = []
nd_args = [] nd_args = []
...@@ -30,13 +36,13 @@ def run_and_check(func, args, var_dict={}, target='llvm'): ...@@ -30,13 +36,13 @@ def run_and_check(func, args, var_dict={}, target='llvm'):
assert isinstance(i, list) assert isinstance(i, list)
emu_args.append(numpy.array(i)) emu_args.append(numpy.array(i))
sch = tvm.create_schedule(op) compile_args = [i for i in args if isinstance(i, (tvm.tensor.Tensor, tvm.expr.Var))] + \
(outs if isinstance(outs, list) else [outs])
module = tvm.build(sch, module = tvm.build(sch,
[i for i in args if isinstance(i, (tvm.tensor.Tensor, tvm.expr.Var))] + \ compile_args,
(outs if isinstance(outs, list) else [outs]),
target=target) target=target)
assert module assert module
out_tensors = [] out_tensors = []
for i in range(op.num_outputs): for i in range(op.num_outputs):
output = op.output(i) output = op.output(i)
...@@ -47,7 +53,7 @@ def run_and_check(func, args, var_dict={}, target='llvm'): ...@@ -47,7 +53,7 @@ def run_and_check(func, args, var_dict={}, target='llvm'):
ref_data = func(*emu_args) ref_data = func(*emu_args)
if isinstance(ref_data, numpy.ndarray): if isinstance(ref_data, numpy.ndarray):
ref_data = [ref_data] ref_data = [ref_data]
module(*nd_args) module(*nd_args)
for nd, np in zip(out_tensors, ref_data): for nd, np in zip(out_tensors, ref_data):
...@@ -282,9 +288,38 @@ def test_bind(): ...@@ -282,9 +288,38 @@ def test_bind():
a = tvm.placeholder((1000, ), dtype='float32', name='a') a = tvm.placeholder((1000, ), dtype='float32', name='a')
b = tvm.placeholder((1000, ), dtype='float32', name='b') b = tvm.placeholder((1000, ), dtype='float32', name='b')
run_and_check(vec_add, [a, b], target='cuda') run_and_check(vec_add, [a, b], target='cuda')
@script
def raw(a, b):
c = output_tensor((1000, ), 'float32')
for i in range(1000):
c[i] = a[i] + b[i]
return c
c = raw(a, b)
sch = tvm.create_schedule(c.op)
x = tvm.thread_axis('threadIdx.x')
sch[c].bind(c.op.axis[0], x)
run_and_check(raw, [a, b], sch=sch, outs=[c], target='cuda')
# Test loop binds
@tvm.hybrid.script
def goo(a, b):
c = output_tensor(a.shape, a.dtype)
len_b = len(b)
for i in const_range(len_b * 2):
if i < len_b:
c[i] = a[i] + b[i]
else:
c[i - len_b] = a[i - len_b] + b[i - len_b]
return c
a = tvm.placeholder((5, ), name='a', dtype='int32')
b = [1, 2, 3, 4, 5]
c = goo(a, tvm.convert(b))
sch = tvm.create_schedule(c.op)
run_and_check(goo, [a, b], sch=sch, outs=[c])
def test_math_intrin(): def test_math_intrin():
@script @script
def intrin_real(a): def intrin_real(a):
...@@ -593,6 +628,68 @@ def test_const_range(): ...@@ -593,6 +628,68 @@ def test_const_range():
b = [1, 2, 3, 4, 5] b = [1, 2, 3, 4, 5]
run_and_check(hoo, [a, b]) run_and_check(hoo, [a, b])
def test_schedule():
@script
def outer_product(a, b):
c = output_tensor((64, 64), a.dtype)
for i in range(64):
for j in range(64):
c[i, j] = a[i] * b[j]
return c
a = tvm.placeholder((64,), name='a', dtype='float32')
b = tvm.placeholder((64,), name='b', dtype='float32')
c = outer_product(a, b)
# Test perfect loop split
# Test loop reorder
# Test loop annotation
sch = tvm.create_schedule(c.op)
i, j = c.op.axis
io, ii = sch[c].split(i, 4)
sch[c].parallel(ii)
jo, ji = sch[c].split(j, 4)
joo, joi = sch[c].split(jo, 4)
sch[c].vectorize(ji)
sch[c].reorder(ii, io, joo, joi, ji)
ir = tvm.lower(sch, [a, b, c], simple_mode=True)
assert isinstance(ir, tvm.stmt.ProducerConsumer)
ir = ir.body
assert isinstance(ir, tvm.stmt.AttrStmt)
ir = ir.body
assert isinstance(ir, tvm.stmt.For)
assert ir.loop_var.name == 'i.inner'
ir = ir.body
assert isinstance(ir, tvm.stmt.For)
assert ir.loop_var.name == 'i.outer'
ir = ir.body
assert isinstance(ir, tvm.stmt.For)
assert ir.loop_var.name == 'j.outer.outer'
ir = ir.body
assert isinstance(ir, tvm.stmt.For)
assert ir.loop_var.name == 'j.outer.inner'
ir = ir.body
run_and_check(outer_product, [a, b], sch=sch, outs=[c])
# Test fuse
sch = tvm.create_schedule(c.op)
sch[c].fuse(c.op.axis[0], c.op.axis[1])
ir = tvm.lower(sch, [a, b, c], simple_mode=True)
assert isinstance(ir, tvm.stmt.ProducerConsumer)
ir = ir.body
assert isinstance(ir, tvm.stmt.AttrStmt)
ir = ir.body
assert isinstance(ir, tvm.stmt.For)
assert ir.loop_var.name == 'i.j.fused'
run_and_check(outer_product, [a, b], sch=sch, outs=[c])
# Test imperfect loop split
sch = tvm.create_schedule(c.op)
sch[c].split(c.op.axis[0], 3)
ir = tvm.lower(sch, [a, b, c], simple_mode=True)
run_and_check(outer_product, [a, b], sch=sch, outs=[c])
# Test loop binds
if __name__ == "__main__": if __name__ == "__main__":
test_outer_product() test_outer_product()
...@@ -610,5 +707,6 @@ if __name__ == "__main__": ...@@ -610,5 +707,6 @@ if __name__ == "__main__":
test_func_call() test_func_call()
test_bool() test_bool()
test_const_range() test_const_range()
test_schedule()
# TODO: # TODO:
# test_inplace() # test_inplace()
Markdown is supported
0% or
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment