Commit ac54577f by Jian Weng, committed by Lianmin Zheng

[Hybrid Script] Supporting scheduling hybrid script (#2416)

* on the way to enable hybrid schedule

* I think I am done with imperfect loop split?

* copyright watermark

* loop annotation

* fix lint

* fix lint 1

* shit!

* loop reorder supported

* support bind to add some tests

* fused tested

* imperfect loop testcase

* fix lint

* add bind testcase

* fix doc

* fix online edit typo

* resolve @mercymercy review

* fix indent

* i should convince myself it is not flaky test first

* fix test hybrid

* how many flaky tests are you expecting; I beg you to let me pass

* rebase halide...
parent b9604671
......@@ -68,17 +68,23 @@ to LLVM module.
Tuning
~~~~~~
**Under construction, not supported yet.**
Following the example above, you can use some TVM-like interfaces to tune the code:
.. code-block:: python
i, j = c.op.axis
sch = tvm.create_schedule(c.op)
jo, ji = sch[c].split(j, 4)
sch[c].vectorize(ji)
``split``, ``reorder``, and loop_annotation will be supported!
For now, you can use loop annotations (``unroll``, ``parallel``, ``vectorize``, and ``bind``),
loop manipulation (``split`` and ``fuse``), and ``reorder``.
.. note::
This is a preliminary feature, so users are responsible for the correctness
of the functionality after tuning. Specifically, users should be careful when
fusing and reordering imperfect loops.
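For example, fusing and thread binding follow the same pattern (a sketch that
mirrors the test cases added in this commit):
.. code-block:: python

    # Fuse two axes into one.
    sch = tvm.create_schedule(c.op)
    fused = sch[c].fuse(c.op.axis[0], c.op.axis[1])

    # Bind an axis to a GPU thread.
    sch = tvm.create_schedule(c.op)
    sch[c].bind(c.op.axis[0], tvm.thread_axis('threadIdx.x'))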
Loops
~~~~~
......
......@@ -459,6 +459,8 @@ class HybridOpNode : public OperationNode {
Array<Tensor> inputs;
/*! \brief Symbolic placeholder representation of outputs */
Array<Tensor> outputs;
/*! \brief The axis of iterations */
Array<IterVar> axis;
/*! \brief the statement that generates the computation. This is
* slightly different from the body in ExternOpNode. All the output
* tensors keep their own names as specified by users in the script.
......@@ -500,6 +502,7 @@ class HybridOpNode : public OperationNode {
v->Visit("attrs", &attrs);
v->Visit("inputs", &inputs);
v->Visit("outputs", &outputs);
v->Visit("axis", &axis);
v->Visit("body", &body);
}
EXPORT static Operation make(std::string name,
......
......@@ -152,7 +152,7 @@ class ComputeOp(Operation):
"""Compute operation."""
@property
def axis(self):
"""Represent axis of IterVar, only defined when it is a ComputeOp"""
"""Represent axis of IterVar, defined when it is a ComputeOp"""
return self.__getattr__("axis")
@property
......@@ -184,4 +184,7 @@ class ExternOp(Operation):
@register_node
class HybridOp(Operation):
"""Hybrid operation."""
pass
@property
def axis(self):
"""Represent axis of IterVar, also defined when it is a HybridOp"""
return self.__getattr__("axis")
......@@ -212,6 +212,7 @@ void ComputeOpNode::GatherBound(
const Operation& self,
const std::unordered_map<Tensor, TensorDom>& tensor_dom,
std::unordered_map<IterVar, Range>* out_dom_map) const {
CHECK_EQ(self.operator->(), this);
const TensorDom& tdom = tensor_dom.at(self.output(0));
for (size_t i = 0; i < this->axis.size(); ++i) {
Range r = arith::Union(tdom.data.at(i)).cover_range(this->axis[i]->dom);
......
/*!
* Copyright (c) 2018 by Contributors
* Copyright (c) 2019 by Contributors
* \brief Hybrid computation rule.
* \file hybrid_op.cc
*/
......@@ -7,8 +7,13 @@
#include <tvm/arithmetic.h>
#include <tvm/ir.h>
#include <tvm/ir_mutator.h>
#include <tvm/ir_operator.h>
#include <tvm/ir_pass.h>
#include <ir/Expr.h>
#include <unordered_set>
#include <string>
#include "op_util.h"
#include "hybrid_op.h"
namespace tvm {
using namespace ir;
......@@ -25,7 +30,7 @@ int HybridOpNode::num_outputs() const {
}
Array<IterVar> HybridOpNode::root_iter_vars() const {
return {};
return this->axis;
}
Type HybridOpNode::output_dtype(size_t i) const {
......@@ -52,6 +57,7 @@ Operation HybridOpNode::make(std::string name,
n->attrs = std::move(attrs);
n->inputs = std::move(inputs);
n->outputs = std::move(outputs);
n->axis = op::GatherLoopVars(body);
n->body = std::move(body);
Operation res = Operation(n);
return res;
......@@ -62,8 +68,8 @@ Array<Tensor> HybridOpNode::InputTensors() const {
}
Operation HybridOpNode::ReplaceInputs(
const Operation& self,
const std::unordered_map<Tensor, Tensor>& rmap) const {
const Operation &self,
const std::unordered_map<Tensor, Tensor> &rmap) const {
CHECK_EQ(self.operator->(), this);
auto n = make_node<HybridOpNode>(*this);
n->body = op::ReplaceTensor(this->body, rmap);
......@@ -83,13 +89,13 @@ Operation HybridOpNode::ReplaceInputs(
}
void HybridOpNode::PropBoundToInputs(
const Operation& self,
const std::unordered_map<const Variable*, IntSet>& dom_map,
const Operation &self,
const std::unordered_map<const Variable*, IntSet> &dom_map,
std::unordered_map<Tensor, TensorDom>* out_dom_map) const {
for (Tensor t : this->inputs) {
auto it = out_dom_map->find(t);
if (it == out_dom_map->end()) continue;
TensorDom& dom = it->second;
TensorDom &dom = it->second;
for (size_t i = 0; i < t->shape.size(); ++i) {
dom.data[i].emplace_back(IntSet::range(
Range::make_by_min_extent(
......@@ -99,15 +105,20 @@ void HybridOpNode::PropBoundToInputs(
}
void HybridOpNode::GatherBound(
const Operation& self,
const std::unordered_map<Tensor, TensorDom>& tensor_dom,
const Operation &self,
const std::unordered_map<Tensor, TensorDom> &tensor_dom,
std::unordered_map<IterVar, Range>* out_dom_map) const {
for (auto iter_var : axis) {
CHECK(!out_dom_map->count(iter_var));
out_dom_map->operator[](iter_var) = iter_var->dom;
}
}
Stmt HybridOpNode::BuildRealize(
const Stage& stage,
const std::unordered_map<IterVar, Range>& realize_map,
const Stmt& body) const {
const Stage &stage,
const std::unordered_map<IterVar, Range> &realize_map,
const Stmt &body) const {
// TODO(@were): Add attribute inject here and remove it from hybrid parser.
CHECK_EQ(stage->op.get(), this);
Stmt realize_body = body;
for (int k = 0; k < num_outputs(); ++k) {
......@@ -126,8 +137,8 @@ Stmt HybridOpNode::BuildRealize(
}
Stmt HybridOpNode::BuildProvide(
const Stage& stage,
const std::unordered_map<IterVar, Range>& dom_map,
const Stage &stage,
const std::unordered_map<IterVar, Range> &dom_map,
bool debug_keep_trivial_loop) const {
CHECK_EQ(stage->op.operator->(), this);
Stmt ret = AttrStmt::make(make_zero(Int(32)), attr::extern_scope, 0, this->body);
......@@ -184,6 +195,302 @@ Stmt HybridOpNode::BuildProvide(
* */
ret = op::ReplaceTensor(ret, rmap);
ret = op::ReplaceProvideTensor(ret, rmap);
ret = op::ApplySchedule(stage, dom_map, ret);
return ret;
}
namespace op {
Stmt ApplyLoopShapes(const Stage &stage,
const std::unordered_map<IterVar, Range> &dom_map, Stmt stmt) {
class LoopSpliter : public IRMutator {
Expr factor;
const Variable *parent;
IterVar inner, outer;
public:
bool splitted;
LoopSpliter(const SplitNode *split,
const std::unordered_map<IterVar, Range> &dom_map) :
factor(split->factor), splitted(false) {
parent = split->parent->var.get();
auto &inner_ = split->inner;
CHECK(dom_map.count(inner_));
auto &inner_dom = dom_map.find(inner_)->second;
CHECK(is_const_int(inner_dom->min, 0));
auto &outer_ = split->outer;
CHECK(dom_map.count(outer_));
auto &outer_dom = dom_map.find(outer_)->second;
CHECK(is_const_int(outer_dom->min, 0));
inner = IterVarNode::make(inner_dom, inner_->var, inner_->iter_type);
outer = IterVarNode::make(outer_dom, outer_->var, outer_->iter_type);
}
Stmt Mutate_(const For *op, const Stmt &stmt) {
if (op->loop_var.get() == parent) {
std::unordered_map<const Variable *, Expr> rmap;
rmap[op->loop_var.get()] = inner + outer * factor;
Stmt ret = ir::Substitute(op->body, rmap);
Expr cond = likely(outer * factor < (op->extent - inner));
ret = IfThenElse::make(cond, ret);
ret = For::make(inner->var, Expr(0), inner->dom->extent,
IterVarTypeToForType(inner->iter_type), op->device_api, ret);
ret = For::make(outer->var, Expr(0), outer->dom->extent,
IterVarTypeToForType(outer->iter_type), op->device_api, ret);
splitted = true;
return ret;
}
return IRMutator::Mutate_(op, stmt);
}
};
class LoopFuser : public IRMutator {
const IterVar &parent;
const Variable *inner;
const Variable *outer;
bool under_outer;
Expr extent;
public:
bool fused;
explicit LoopFuser(const FuseNode *fuse_)
: parent(fuse_->fused), inner(fuse_->inner->var.get()),
outer(fuse_->outer->var.get()), under_outer(false),
extent(0), fused(false) {}
// TODO(@were): Handle imperfect loops
Stmt Mutate_(const For *op, const Stmt &stmt) {
if (op->loop_var.get() == inner) {
CHECK(under_outer);
std::unordered_map<const Variable *, Expr> rmap;
rmap[op->loop_var.get()] = parent % op->extent;
extent = op->extent;
fused = true;
return ir::Substitute(op->body, rmap);
} else if (op->loop_var.get() == outer) {
under_outer = true;
Stmt body = IRMutator::Mutate(op->body);
std::unordered_map<const Variable *, Expr> rmap;
rmap[op->loop_var.get()] = parent / extent;
body = ir::Substitute(body, rmap);
under_outer = false;
return For::make(parent->var, Expr(0), extent * op->extent,
op->for_type, op->device_api, body);
} else if (under_outer) {
Stmt body = IRMutator::Mutate(op->body);
std::unordered_map<const Variable *, Expr> rmap;
rmap[op->loop_var.get()] = parent / extent % op->extent;
body = ir::Substitute(body, rmap);
extent = extent * op->extent;
return body;
}
return IRMutator::Mutate_(op, stmt);
}
};
for (auto &rel : stage->relations) {
if (const SplitNode *split = rel.as<SplitNode>()) {
LoopSpliter Spliter(split, dom_map);
stmt = Spliter.Mutate(stmt);
CHECK(Spliter.splitted);
} else if (const FuseNode *fuse = rel.as<FuseNode>()) {
LoopFuser Fuser(fuse);
stmt = Fuser.Mutate(stmt);
CHECK(Fuser.fused);
}
}
return stmt;
}
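To make the index arithmetic concrete, below is a hedged, plain-Python model of
the two rewrites above (illustrative extents only, not TVM code): LoopSpliter
substitutes parent -> inner + outer * factor under a likely() guard, and
LoopFuser substitutes inner -> fused % inner_extent, outer -> fused / inner_extent.
# LoopSpliter model: the guard outer * factor < extent - inner skips the
# out-of-range tail iterations of an imperfect split.
n, factor = 10, 4  # illustrative values only
original = list(range(n))
outer_ext = (n + factor - 1) // factor  # ceil(n / factor)
split = [inner + outer * factor
         for outer in range(outer_ext)
         for inner in range(factor)
         if outer * factor < n - inner]
assert split == original
# LoopFuser model: the fused loop runs over outer_extent * inner_extent
# iterations and recovers (outer, inner) by division and modulo.
inner_ext, outer_ext = 4, 3
pairs = [(o, i) for o in range(outer_ext) for i in range(inner_ext)]
fused = [(f // inner_ext, f % inner_ext) for f in range(outer_ext * inner_ext)]
assert fused == pairs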
Stmt ApplyLoopAnnotations(const Stage &stage,
const std::unordered_map<IterVar, IterVar> &rebased, Stmt stmt) {
class LoopAnnotator : public IRMutator {
const Variable *var;
const IterVarAttr &attr;
public:
LoopAnnotator(const Variable *var_, const IterVarAttr &attr_) : var(var_), attr(attr_) {}
Stmt Mutate_(const For *op, const Stmt &stmt) {
if (op->loop_var.get() == var) {
if (attr->bind_thread.defined()) {
const auto &iter_var = attr->bind_thread;
if (iter_var->dom.defined()) {
CHECK(is_const_int(iter_var->dom->min, 0));
CHECK(Equal(iter_var->dom->extent, op->extent))
<< "Thread extent and loop extent mismatch!\n";
}
std::unordered_map<const Variable *, Expr> rmap;
rmap[op->loop_var.get()] = iter_var;
Stmt body = ir::Substitute(op->body, rmap);
return AttrStmt::make(iter_var, "thread_extent", op->extent, body);
} else {
return For::make(op->loop_var, op->min, op->extent,
IterVarTypeToForType(attr->iter_type), op->device_api, op->body);
}
}
return IRMutator::Mutate_(op, stmt);
}
};
for (auto &iter_var : stage->leaf_iter_vars) {
bool need_change = false;
int found = 0;
const IterVar &actual = rebased.count(iter_var) ? rebased.find(iter_var)->second : iter_var;
const Variable *var = actual->var.get();
ForType expected = IterVarTypeToForType(iter_var->iter_type);
IterVarAttr attr;
if (stage->iter_var_attrs.count(iter_var)) {
attr = stage->iter_var_attrs[iter_var];
expected = IterVarTypeToForType(attr->iter_type);
}
PostOrderVisit(stmt, [&found, &var, &attr, &expected, &need_change](const NodeRef &node) {
if (const For *op = node.as<For>()) {
if (op->loop_var.get() == var) {
++found;
need_change = expected != op->for_type || (attr.defined() && attr->bind_thread.defined());
}
}
});
CHECK_EQ(found, 1) << " iter var should be found exactly once!";
if (need_change) {
stmt = LoopAnnotator(var, attr).Mutate(stmt);
}
}
return stmt;
}
Stmt ApplyLoopOrder(const Stage &stage,
const std::unordered_map<IterVar, Range> &dom_map,
const std::unordered_map<IterVar, IterVar> &rebased, Stmt stmt) {
std::vector<const Variable*> current_order;
PostOrderVisit(stmt, [&current_order](const NodeRef &node) {
if (const For *op = node.as<For>())
current_order.push_back(op->loop_var.get());
});
std::reverse(current_order.begin(), current_order.end());
auto &required_ord = stage->leaf_iter_vars;
CHECK_EQ(current_order.size(), required_ord.size()) << "Cannot reorder the loops!";
std::unordered_map<const Variable *, IterVar> reorder;
bool need_reorder = false;
for (size_t i = 0; i < current_order.size(); ++i) {
auto &current = current_order[i];
const IterVar &iter_var = required_ord[i];
const IterVar &required = rebased.count(iter_var) ? rebased.find(iter_var)->second : iter_var;
CHECK(required->dom.defined() || dom_map.count(required)) << required << "\n";
reorder[current] = required;
if (current != required->var.get()) {
need_reorder = true;
}
}
class LoopReorder : public IRMutator {
const Stage &stage;
const std::unordered_map<IterVar, Range> &dom_map;
const std::unordered_map<const Variable *, IterVar> &reorder;
public:
LoopReorder(const Stage &stage,
const std::unordered_map<IterVar, Range> &dom_map,
const std::unordered_map<const Variable*, IterVar> &reorder)
: stage(stage), dom_map(dom_map), reorder(reorder) {}
Stmt Mutate_(const For *op, const Stmt &stmt) {
// Reorder the loops from the innermost to the outermost
Stmt body_ = IRMutator::Mutate(op->body);
CHECK(reorder.count(op->loop_var.get()));
auto target = reorder.find(op->loop_var.get())->second;
if (body_.same_as(op->body) && op->loop_var.get() == target->var.get())
return stmt;
const Stmt &body = op->body.same_as(body_) ? op->body : body_;
ForType for_type = IterVarTypeToForType(target->iter_type);
if (stage->iter_var_attrs.count(target)) {
for_type = IterVarTypeToForType(stage->iter_var_attrs[target]->iter_type);
}
const Range &range = target->dom.defined() ? target->dom : dom_map.find(target)->second;
return For::make(target->var, range->min, range->extent,
for_type, HalideIR::DeviceAPI::None, body);
}
};
if (need_reorder)
return LoopReorder(stage, dom_map, reorder).Mutate(stmt);
return stmt;
}
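As a hedged illustration of the invariant ApplyLoopOrder preserves (plain
Python, not TVM code): the nest is rebuilt so traversal follows
leaf_iter_vars, while the set of visited iterations is unchanged.
# Reordering (i, j) -> (j, i): same iteration space, different nesting order.
before = [(i, j) for i in range(3) for j in range(2)]
after = [(i, j) for j in range(2) for i in range(3)]
assert sorted(before) == sorted(after)  # identical iteration space
assert before != after                  # different traversal order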
Stmt ApplySchedule(const Stage &stage,
const std::unordered_map<IterVar, Range> &dom_map, Stmt stmt) {
// TODO(@were): Eliminate loop rebase in script parser and move the burden here
// Gather rebased variables
std::unordered_map<IterVar, IterVar> rebased;
for (auto rel : stage->relations) {
if (auto rebase = rel.as<RebaseNode>()) {
rebased[rebase->rebased] = rebase->parent;
CHECK(rebase->parent->dom.defined());
CHECK(dom_map.count(rebase->rebased));
}
}
stmt = ApplyLoopShapes(stage, dom_map, stmt);
stmt = ApplyLoopOrder(stage, dom_map, rebased, stmt);
stmt = ApplyLoopAnnotations(stage, rebased, stmt);
return stmt;
}
std::vector<IterVar> GatherLoopVars(Stmt stmt) {
// TODO(@were): Write a comprehensive pass to analyze iter var types
std::vector<IterVar> res_;
PostOrderVisit(stmt, [&res_](const NodeRef &node) {
if (const For *op = node.as<For>()) {
Var loop_var(op->loop_var);
Range dom = Range::make_by_min_extent(op->min, op->extent);
res_.push_back(IterVarNode::make(dom, loop_var, ForTypeToIterVarType(op->for_type)));
}
});
std::reverse(res_.begin(), res_.end());
return res_;
}
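GatherLoopVars is what backs the new HybridOpNode::axis: every For in the
parsed body becomes an IterVar, ordered from the outermost loop inward. A
hedged usage sketch on the Python side (the copy kernel is an assumed example,
not part of this commit):
# Hedged sketch: 'copy' is an assumed example kernel, not from this commit.
import tvm
from tvm.hybrid import script

@script
def copy(a):
    b = output_tensor((16, ), a.dtype)
    for i in range(16):
        b[i] = a[i]
    return b

a = tvm.placeholder((16, ), name='a', dtype='float32')
b = copy(a)
sch = tvm.create_schedule(b.op)
i, = b.op.axis               # gathered from the `for i` loop
io, ii = sch[b].split(i, 4)  # scheduling a HybridOp now works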
// replacer to replace tensors' usage in Provide
class ProviderReplacer : public ir::IRMutator {
public:
explicit ProviderReplacer(const std::unordered_map<Tensor, Tensor> &vmap)
: vmap_(vmap) {}
Stmt Mutate_(const ir::Provide* op, const Stmt &s) {
Tensor t = Operation(op->func.node_).output(op->value_index);
auto it = vmap_.find(t);
if (it != vmap_.end()) {
Stmt ret = ir::Provide::make(
it->second->op, it->second->value_index, op->value, op->args);
found = true;
return IRMutator::Mutate_(ret.as<ir::Provide>(), ret);
}
return IRMutator::Mutate_(op, s);
}
// whether it is found.
bool found{false};
private:
const std::unordered_map<Tensor, Tensor> &vmap_;
};
Stmt ReplaceProvideTensor(Stmt stmt,
const std::unordered_map<Tensor, Tensor> &replace) {
ProviderReplacer repl(replace);
Stmt ret = repl.Mutate(stmt);
return repl.found ? ret : stmt;
}
} // namespace op
} // namespace tvm
/*!
* Copyright (c) 2019 by Contributors
* \brief Helper utilities to implement hybrid_op.
* \file hybrid_op.h
*/
#ifndef TVM_OP_HYBRID_OP_H_
#define TVM_OP_HYBRID_OP_H_
#include <tvm/expr.h>
#include <tvm/ir_mutator.h>
#include <tvm/ir_visitor.h>
#include <tvm/schedule.h>
#include <unordered_map>
#include <unordered_set>
#include <vector>
#include "../pass/ir_util.h"
#include "../pass/arg_binder.h"
#include "../schedule/message_passing.h"
namespace tvm {
namespace op {
/*!
* \brief Find all the iteration variables in the given statement body.
* \param stmt The body to be inspected.
*/
std::vector<IterVar> GatherLoopVars(Stmt stmt);
/*!
* \brief Replace the tensor reference (especially in Provide's) in stmt by the replace map.
* \param stmt The statement to be processed.
* \param replace The replacement rule.
*/
Stmt ReplaceProvideTensor(Stmt stmt,
const std::unordered_map<Tensor, Tensor>& replace);
/*!
* \brief Apply the schedule manipulation on the function body.
* \param stmt The statement to be processed.
* \param dom_map The extents of the iteration variables that may be used.
* \param stage The schedule information to be applied.
*/
Stmt ApplySchedule(const Stage& stage,
const std::unordered_map<IterVar, Range>& dom_map, Stmt stmt);
/*!
* \brief Apply loop splits and fuses in the schedule on the function body.
* \param stage The schedule information to be applied.
* \param dom_map The extents of the iteration variables that may be used.
* \param stmt The statement to be processed.
*/
Stmt ApplyLoopShapes(const Stage &stage,
const std::unordered_map<IterVar, Range>& dom_map, Stmt stmt);
/*!
* \brief Apply loop annotation in the schedule on the function body.
* \param stage The schedule information to be applied.
* \param rebased The map that specifies the rebase (i.e., rename) relationship of these variables.
* \param stmt The statement to be processed.
*/
Stmt ApplyLoopAnnotations(const Stage &stage,
const std::unordered_map<IterVar, IterVar>& rebased, Stmt stmt);
/*!
* \brief Apply loop order in the schedule on the function body.
* \param stage The schedule information to be applied.
* \param dom_map The extents of the iteration variables that may be used.
* \param rebased The map that specifies the rebase (i.e., rename) relationship of these variables.
* \param stmt The statement to be processed.
*/
Stmt ApplyLoopOrder(const Stage &stage,
const std::unordered_map<IterVar, Range> &dom_map,
const std::unordered_map<IterVar, IterVar> &rebased, Stmt stmt);
} // namespace op
} // namespace tvm
#endif // TVM_OP_HYBRID_OP_H_
......@@ -164,38 +164,6 @@ std::vector<Stmt> MakeIfNest(const std::vector<Expr>& predicates) {
return nest;
}
// replacer to replace tensors' usage in Provide
class ProviderReplacer : public ir::IRMutator {
public:
explicit ProviderReplacer(const std::unordered_map<Tensor, Tensor>& vmap)
: vmap_(vmap) {}
Stmt Mutate_(const ir::Provide* op, const Stmt& s) {
Tensor t = Operation(op->func.node_).output(op->value_index);
auto it = vmap_.find(t);
if (it != vmap_.end()) {
Stmt ret = ir::Provide::make(
it->second->op, it->second->value_index, op->value, op->args);
found = true;
return IRMutator::Mutate_(ret.as<ir::Provide>(), ret);
}
return IRMutator::Mutate_(op, s);
}
// whether it is found.
bool found{false};
private:
const std::unordered_map<Tensor, Tensor>& vmap_;
};
Stmt ReplaceProvideTensor(Stmt stmt,
const std::unordered_map<Tensor, Tensor>& replace) {
ProviderReplacer repl(replace);
Stmt ret = repl.Mutate(stmt);
return repl.found ? ret : stmt;
}
// replacer to replace tensors
class TensorReplacer : public ir::IRMutator {
public:
......@@ -247,5 +215,35 @@ Stmt Substitute(Stmt s,
return ir::Substitute(s, init);
}
IterVarType ForTypeToIterVarType(ir::ForType for_type) {
switch (for_type) {
case ForType::Serial:
return kDataPar;
case ForType::Parallel:
return kParallelized;
case ForType::Vectorized:
return kVectorized;
case ForType::Unrolled:
return kUnrolled;
default:
return kDataPar;
}
}
ir::ForType IterVarTypeToForType(IterVarType iter_type) {
switch (iter_type) {
case kDataPar:
return ForType::Serial;
case kParallelized:
return ForType::Parallel;
case kVectorized:
return ForType::Vectorized;
case kUnrolled:
return ForType::Unrolled;
default:
return ForType::Serial;
}
}
} // namespace op
} // namespace tvm
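A hedged Python model of the two conversions above, checking that they are
mutually inverse on the four explicitly handled cases:
# Illustrative-only mapping tables; names mirror the C++ enum values.
for_to_iter = {'Serial': 'kDataPar', 'Parallel': 'kParallelized',
               'Vectorized': 'kVectorized', 'Unrolled': 'kUnrolled'}
iter_to_for = {v: k for k, v in for_to_iter.items()}
assert all(iter_to_for[for_to_iter[t]] == t for t in for_to_iter)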
......@@ -49,14 +49,6 @@ MakeLoopNest(const Stage& stage,
std::vector<Stmt> MakeIfNest(const std::vector<Expr>& predicates);
/*!
* \brief Replace the tensor reference (especially in Provide's) in stmt by the replace map.
* \param stmt The statement to be processed.
* \param replace The replacement rule.
*/
Stmt ReplaceProvideTensor(Stmt stmt,
const std::unordered_map<Tensor, Tensor>& replace);
/*!
* \brief Replace the tensor reference (especially in Call's) in stmt by the replace map.
* \param stmt The statement to be processed.
* \param replace The replacement rule.
......@@ -80,6 +72,18 @@ Expr ReplaceTensor(Expr expr,
Stmt Substitute(Stmt stmt,
const std::unordered_map<IterVar, Expr>& value_map);
/*!
* \brief Converts Halide ForType to its corresponding IterVarType
* \param for_type The ForType to be converted
*/
IterVarType ForTypeToIterVarType(ir::ForType for_type);
/*!
* \brief Converts IterVarType to its corresponding Halide ForType
* \param iter_type The IterVarType to be converted
*/
ir::ForType IterVarTypeToForType(IterVarType iter_type);
} // namespace op
} // namespace tvm
#endif // TVM_OP_OP_UTIL_H_
......@@ -3,7 +3,7 @@ from tvm.hybrid import script
from tvm.hybrid.intrin import HYBRID_GLOBALS
@nose.tools.nottest
def run_and_check(func, args, var_dict={}, target='llvm'):
def run_and_check(func, args, var_dict={}, target='llvm', sch=None, outs=None):
def tvm_val_2_py_val(val):
val = tvm.ir_pass.Substitute(val, var_dict)
val = tvm.ir_pass.Simplify(val)
......@@ -13,8 +13,14 @@ def run_and_check(func, args, var_dict={}, target='llvm'):
ctx = tvm.context(target, 0)
op = None
if sch is None:
outs = func(*tuple(tvm.convert(i) if isinstance(i, list) else i for i in args))
op = outs[0].op if isinstance(outs, list) else outs.op
sch = tvm.create_schedule(op)
else:
assert outs is not None
assert isinstance(outs, list)
op = outs[0].op
emu_args = []
nd_args = []
......@@ -30,10 +36,10 @@ def run_and_check(func, args, var_dict={}, target='llvm'):
assert isinstance(i, list)
emu_args.append(numpy.array(i))
sch = tvm.create_schedule(op)
compile_args = [i for i in args if isinstance(i, (tvm.tensor.Tensor, tvm.expr.Var))] + \
(outs if isinstance(outs, list) else [outs])
module = tvm.build(sch,
[i for i in args if isinstance(i, (tvm.tensor.Tensor, tvm.expr.Var))] + \
(outs if isinstance(outs, list) else [outs]),
compile_args,
target=target)
assert module
......@@ -282,9 +288,38 @@ def test_bind():
a = tvm.placeholder((1000, ), dtype='float32', name='a')
b = tvm.placeholder((1000, ), dtype='float32', name='b')
run_and_check(vec_add, [a, b], target='cuda')
@script
def raw(a, b):
c = output_tensor((1000, ), 'float32')
for i in range(1000):
c[i] = a[i] + b[i]
return c
c = raw(a, b)
sch = tvm.create_schedule(c.op)
x = tvm.thread_axis('threadIdx.x')
sch[c].bind(c.op.axis[0], x)
run_and_check(raw, [a, b], sch=sch, outs=[c], target='cuda')
# Test loop binds
@tvm.hybrid.script
def goo(a, b):
c = output_tensor(a.shape, a.dtype)
len_b = len(b)
for i in const_range(len_b * 2):
if i < len_b:
c[i] = a[i] + b[i]
else:
c[i - len_b] = a[i - len_b] + b[i - len_b]
return c
a = tvm.placeholder((5, ), name='a', dtype='int32')
b = [1, 2, 3, 4, 5]
c = goo(a, tvm.convert(b))
sch = tvm.create_schedule(c.op)
run_and_check(goo, [a, b], sch=sch, outs=[c])
def test_math_intrin():
@script
def intrin_real(a):
......@@ -593,6 +628,68 @@ def test_const_range():
b = [1, 2, 3, 4, 5]
run_and_check(hoo, [a, b])
def test_schedule():
@script
def outer_product(a, b):
c = output_tensor((64, 64), a.dtype)
for i in range(64):
for j in range(64):
c[i, j] = a[i] * b[j]
return c
a = tvm.placeholder((64,), name='a', dtype='float32')
b = tvm.placeholder((64,), name='b', dtype='float32')
c = outer_product(a, b)
# Test perfect loop split
# Test loop reorder
# Test loop annotation
sch = tvm.create_schedule(c.op)
i, j = c.op.axis
io, ii = sch[c].split(i, 4)
sch[c].parallel(ii)
jo, ji = sch[c].split(j, 4)
joo, joi = sch[c].split(jo, 4)
sch[c].vectorize(ji)
sch[c].reorder(ii, io, joo, joi, ji)
ir = tvm.lower(sch, [a, b, c], simple_mode=True)
assert isinstance(ir, tvm.stmt.ProducerConsumer)
ir = ir.body
assert isinstance(ir, tvm.stmt.AttrStmt)
ir = ir.body
assert isinstance(ir, tvm.stmt.For)
assert ir.loop_var.name == 'i.inner'
ir = ir.body
assert isinstance(ir, tvm.stmt.For)
assert ir.loop_var.name == 'i.outer'
ir = ir.body
assert isinstance(ir, tvm.stmt.For)
assert ir.loop_var.name == 'j.outer.outer'
ir = ir.body
assert isinstance(ir, tvm.stmt.For)
assert ir.loop_var.name == 'j.outer.inner'
ir = ir.body
run_and_check(outer_product, [a, b], sch=sch, outs=[c])
# Test fuse
sch = tvm.create_schedule(c.op)
sch[c].fuse(c.op.axis[0], c.op.axis[1])
ir = tvm.lower(sch, [a, b, c], simple_mode=True)
assert isinstance(ir, tvm.stmt.ProducerConsumer)
ir = ir.body
assert isinstance(ir, tvm.stmt.AttrStmt)
ir = ir.body
assert isinstance(ir, tvm.stmt.For)
assert ir.loop_var.name == 'i.j.fused'
run_and_check(outer_product, [a, b], sch=sch, outs=[c])
# Test imperfect loop split
sch = tvm.create_schedule(c.op)
sch[c].split(c.op.axis[0], 3)
ir = tvm.lower(sch, [a, b, c], simple_mode=True)
run_and_check(outer_product, [a, b], sch=sch, outs=[c])
# Test loop binds
if __name__ == "__main__":
test_outer_product()
......@@ -610,5 +707,6 @@ if __name__ == "__main__":
test_func_call()
test_bool()
test_const_range()
test_schedule()
# TODO:
# test_inplace()