Commit a2c8a29b by Tianqi Chen (committed by GitHub)

[SCHEDULE] Improve bound inference, support reduce codegen. (#30)

parent d4af7ad6
......@@ -32,6 +32,9 @@ using Halide::Internal::IRPrinter;
using Halide::Internal::Variable;
using Halide::Internal::make_const;
using Halide::Internal::make_zero;
using Halide::Internal::as_const_int;
using Halide::Internal::as_const_uint;
inline Type TVMType2Type(TVMType t) {
......@@ -126,25 +129,25 @@ using Halide::abs;
using Halide::select;
/*!
* \brief sum of of source expression over rdom
* \brief sum of source expression over axis
* \param source The source expression.
* \param rdom List of iteration variables that will be used for reduction.
* \param axis List of iteration variables that will be used for reduction.
*/
Expr sum(Expr source, Array<IterVar> rdom);
Expr sum(Expr source, Array<IterVar> axis);
/*!
* \brief max of of source expression over rdom
* \brief max of source expression over axis
* \param source The source expression.
* \param rdom List of iteration variables that will be used for reduction.
* \param axis List of iteration variables that will be used for reduction.
*/
Expr max(Expr source, Array<IterVar> rdom);
Expr max(Expr source, Array<IterVar> axis);
/*!
* \brief max of of source expression over rdom
* \brief min of source expression over axis
* \param source The source expression.
* \param rdom List of iteration variables that will be used for reduction.
* \param axis List of iteration variables that will be used for reduction.
*/
Expr min(Expr source, Array<IterVar> rdom);
Expr min(Expr source, Array<IterVar> axis);
// print functions for expr
......
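As a usage sketch of the renamed reduction builders: a minimal, hypothetical C++ fragment. The tensor `A`, index expression `i`, and extent `m` are illustrative and not part of this diff.

```cpp
// Minimal sketch: build a sum Reduce over one reduction axis.
// A (Tensor), i (index Expr), and m (extent Expr) are assumed to exist.
IterVar k(Range::make_with_min_extent(0, m), "k");  // reduction iter var
Expr val = tvm::sum(A(i, k->var), {k});  // Reduce("Add", A(i, k), axis=[k])
```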
......@@ -30,8 +30,8 @@ struct Reduce : public ExprNode<Reduce> {
std::string op;
/*! \brief The source operand */
Expr source;
/*! \brief The reduction domains */
Array<IterVar> rdom;
/*! \brief The reduction axis */
Array<IterVar> axis;
/*! \brief construct expr from op and axis */
static Expr make(std::string op, Expr src, Array<IterVar> axis);
......@@ -40,7 +40,7 @@ struct Reduce : public ExprNode<Reduce> {
v->Visit("dtype", &type);
v->Visit("op", &op);
v->Visit("source", &source);
v->Visit("rdom", &rdom);
v->Visit("axis", &axis);
}
static const IRNodeType _type_info = IRNodeType::ExtensionExpr;
static constexpr const char* _type_key = "Reduce";
......
......@@ -3,8 +3,8 @@
* \file ir_pass.h
* \brief Collection of IR pass functions
*
* All the pass functions in this file are for Stmt,
* We can use PassFunction(Evaluate(expr)) to apply it to Expr
* When a pass function in this file only takes Stmt,
* we can use PassFunction(Evaluate(expr)) to apply it to an Expr
*/
#ifndef TVM_IR_PASS_H_
#define TVM_IR_PASS_H_
......@@ -38,15 +38,6 @@ inline Stmt Simplify(Stmt a) {
}
/*!
* \brief Schedule s' dependent operations.
*
* \param s The schedule to be realized
* \param dom_map The domain of each iter vars.
* \return the result Stmt
*/
Stmt ScheduleOps(Schedule s, Map<IterVar, Range> dom_map);
/*!
* \brief verifies whether the IR stmt or Expr is in SSA form.
* That is: each VarExpr is defined and assigned once(in Let/For)
*
......@@ -70,6 +61,14 @@ bool HasSideEffect(const Expr& e);
Stmt ConvertSSA(Stmt stmt);
/*!
* \brief Substitute the var specified in key->var to be value.
* \param stmt The source statement to be substituted
* \param value_map The map of new values.
* \return The converted form.
*/
Stmt Substitute(Stmt stmt, const Map<IterVar, Expr>& value_map);
/*!
* \brief inline all calls of f in stmt.
*
* \param f The function reference to be inlined
......
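The new `Substitute` pass replaces each key IterVar's underlying variable with the mapped expression. A minimal sketch, assuming a Stmt `body` and an IterVar `loop_var` that appears in it (both assumptions, not from this diff):

```cpp
// Hypothetical: pin one iteration variable of `body` to zero.
Map<IterVar, Expr> vmap;
vmap.Set(loop_var, make_zero(Int(32)));  // loop_var is an assumed IterVar
Stmt specialized = ir::Substitute(body, vmap);
```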
......@@ -49,6 +49,8 @@ class ComputeOpNode : public OperationNode {
public:
/*! \brief IterVar on each axis */
Array<IterVar> axis;
/*! \brief IterVar on each reduction axis, if the body is a Reduce */
Array<IterVar> reduce_axis;
/*! \brief the compute expression */
Expr body;
/*! \brief constructor */
......@@ -64,6 +66,7 @@ class ComputeOpNode : public OperationNode {
void VisitAttrs(AttrVisitor* v) final {
v->Visit("name", &name);
v->Visit("axis", &axis);
v->Visit("reduce_axis", &reduce_axis);
v->Visit("body", &body);
}
static Operation make(std::string name,
......
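With the new `reduce_axis` field, `root_iter_vars()` concatenates the output axes and the reduction axes (see the operation.cc hunk below). A small sketch, assuming `B` is a tensor whose compute body is a Reduce:

```cpp
// For B(i) = sum(A(i, k), axis=k): axis = [i], reduce_axis = [k],
// so root_iter_vars() yields [i, k].
const ComputeOpNode* op = B->op.as<ComputeOpNode>();
CHECK_EQ(op->root_iter_vars().size(),
         op->axis.size() + op->reduce_axis.size());
```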
......@@ -123,6 +123,8 @@ class Stage : public NodeRef {
IterVar* p_x_outer, IterVar* p_y_outer,
IterVar* p_x_inner, IterVar* p_y_inner,
Expr x_factor, Expr y_factor);
// declare container type
using ContainerType = StageNode;
};
/*!
......@@ -153,10 +155,21 @@ class Schedule : public NodeRef {
return this->operator[](tensor->op);
}
/*!
* \brief Normalize the schedule.
* This is needed before bound inference.
* Insert necessary RebaseNode relations to make sure all leaf_iter_vars
* are in the form [0, extent)
*
* \return A normalized schedule, can be same as current one.
*/
void normalize();
/*!
* \brief access the internal node container
* \return the pointer to the internal node container
*/
inline const ScheduleNode* operator->() const;
// declare container type
using ContainerType = ScheduleNode;
};
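A sketch of the intended call order for the new `normalize` (assuming a Schedule `sch` exists); bound inference below checks that every leaf iter var starts at 0:

```cpp
// normalize() inserts RebaseNode relations so all leaf iter vars
// are of the form [0, extent); InferBound relies on this.
sch.normalize();
Map<IterVar, Range> bounds = schedule::InferBound(sch);
```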
/*!
......@@ -308,6 +321,30 @@ class FuseNode : public IterVarRelationNode {
TVM_DECLARE_NODE_TYPE_INFO(FuseNode);
};
/*!
* \brief Rebase the iteration to make its min 0.
* This is useful for normalizing the Schedule,
* so that every leaf variable's min is 0.
*/
class RebaseNode : public IterVarRelationNode {
public:
/*! \brief The parent domain */
IterVar parent;
/*! \brief The rebased domain */
IterVar rebased;
void VisitAttrs(AttrVisitor* v) final {
v->Visit("parent", &parent);
v->Visit("rebased", &rebased);
}
static IterVarRelation make(IterVar parent, IterVar rebased);
static constexpr const char* _type_key = "Rebase";
TVM_DECLARE_NODE_TYPE_INFO(RebaseNode);
};
// implementations
inline const StageNode* Stage::operator->() const {
return static_cast<const StageNode*>(node_.get());
......
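To make the rebase semantics concrete, a hedged sketch of what `PassDown` later computes for a `RebaseNode` (the ranges are illustrative):

```cpp
// parent spans [4, 4 + 16); the rebased var spans [0, 16).
IterVar parent(Range::make_with_min_extent(4, 16), "i");
IterVar rebased(Range(), "i.rb");  // range filled in by bound inference
IterVarRelation rel = RebaseNode::make(parent, rebased);
// After PassDown: state[rebased] == Range::make_with_min_extent(0, 16).
```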
......@@ -24,6 +24,15 @@ namespace schedule {
*/
Map<IterVar, Range> InferBound(Schedule sch);
/*!
* \brief Schedule s's dependent operations.
*
* \param s The schedule to be realized
* \param dom_map The domain of each iter var.
* \return the result Stmt
*/
Stmt ScheduleOps(Schedule s, Map<IterVar, Range> dom_map);
} // namespace schedule
} // namespace tvm
#endif // TVM_SCHEDULE_PASS_H_
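With `ScheduleOps` moved from `ir` into the `schedule` namespace, the C++ lowering entry point now reads as follows (sketch, assuming a Schedule `sch`):

```cpp
// Bound inference, then loop-nest construction; both live in tvm::schedule.
Map<IterVar, Range> bounds = schedule::InferBound(sch);
Stmt stmt = schedule::ScheduleOps(sch, bounds);
```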
......@@ -212,51 +212,51 @@ def IterVar(dom=None, name=None, thread_tag=''):
return _api_internal._IterVar(dom, name, thread_tag)
def sum(expr, rdom):
"""Create a sum expression over rdom
def sum(expr, axis):
"""Create a sum expression over axis
Parameters
----------
expr : Expr
The source expression.
rdom : RDomain
The reduction domainx
axis : IterVar or list of IterVar
The reduction axis
"""
rdom = rdom if isinstance(rdom, list) else [rdom]
x = _make.Reduce("Add", expr, rdom)
axis = axis if isinstance(axis, list) else [axis]
x = _make.Reduce("Add", expr, axis)
return x
def min(expr, rdom):
"""Create a min expression over rdom
def min(expr, axis):
"""Create a min expression over axis
Parameters
----------
expr : Expr
The source expression.
rdom : RDomain
The reduction domainx
axis : IterVar or list of IterVar
The reduction axis
"""
rdom = rdom if isinstance(rdom, list) else [rdom]
x = _make.Reduce("Min", expr, rdom)
axis = axis if isinstance(axis, list) else [axis]
x = _make.Reduce("Min", expr, axis)
return x
def max(expr, rdom):
"""Create a min expression over rdom
def max(expr, axis):
"""Create a min expression over axis
Parameters
----------
expr : Expr
The source expression.
rdom : RDomain
The reduction domainx
axis : IterVar or list of IterVar
The reduction axis
"""
rdom = rdom if isinstance(rdom, list) else [rdom]
x = _make.Reduce("Max", expr, rdom)
axis = axis if isinstance(axis, list) else [axis]
x = _make.Reduce("Max", expr, axis)
return x
......
......@@ -62,9 +62,10 @@ def build(sch,
# lowering
bounds = schedule.InferBound(sch)
stmt = ir_pass.ScheduleOps(sch, bounds)
stmt = schedule.ScheduleOps(sch, bounds)
stmt = ir_pass.StorageFlatten(stmt, binds)
stmt = ir_pass.Simplify(stmt)
print(stmt)
fapi = codegen.MakeAPI(stmt, name, arg_list, len(arg_list))
fsplits = codegen.SplitHostDevice(fapi)
......@@ -73,7 +74,8 @@ def build(sch,
for i, f in enumerate(fsplits):
t = target if i >= 1 else "c"
record_codes.append(codegen.CompileToC(f, output_ssa, t))
for c in record_codes:
print(c)
if target == "cuda":
ret = codegen.BuildNVRTC(fsplits, "stackvm")
elif target == "opencl":
......
......@@ -33,6 +33,14 @@ class Schedule(NodeBase):
raise ValueError("Cannot find the operation %s in schedule" % (str(k)))
return self.stage_map[k]
def normalize(self):
"""Build a normalized schedule.
Insert the necessary rebases to make certain iter vars start from 0.
This is needed before bound inference and subsequent steps.
"""
_api_internal._ScheduleNormalize(self)
@register_node
class Stage(NodeBase):
"""A Stage represents schedule for one operation."""
......
......@@ -253,4 +253,10 @@ TVM_REGISTER_API(_StageTile)
*ret = Array<IterVar>({x_outer, y_outer, x_inner, y_inner});
});
TVM_REGISTER_API(_ScheduleNormalize)
.set_body([](TVMArgs args, TVMRetValue* ret) {
args[0].operator Schedule()
.normalize();
});
} // namespace tvm
......@@ -51,7 +51,6 @@ TVM_REGISTER_API(_pass_Equal)
REGISTER_PASS1(ConvertSSA);
REGISTER_PASS1(VerifySSA);
REGISTER_PASS4(Inline);
REGISTER_PASS2(ScheduleOps);
REGISTER_PASS2(StorageFlatten);
} // namespace ir
......
......@@ -29,6 +29,7 @@ namespace schedule {
REGISTER_SCHEDULE_PASS1(InferBound);
REGISTER_SCHEDULE_PASS1(CreateReadGraph);
REGISTER_SCHEDULE_PASS2(PostDFSOrder);
REGISTER_SCHEDULE_PASS2(ScheduleOps);
} // namespace schedule
} // namespace tvm
......@@ -2,6 +2,7 @@
* Copyright (c) 2017 by Contributors
* \file codegen_c.cc
*/
#include <iomanip>
#include "./codegen_c.h"
namespace tvm {
......@@ -216,7 +217,7 @@ inline void PrintConst(const FloatImm* op, std::ostream& os, CodeGenC* p) { // N
switch (op->type.bits()) {
case 64: case 32: {
std::ostringstream temp;
temp << op->value;
temp << std::scientific << op->value;
if (op->type.bits() == 32) temp << 'f';
p->MarkConst(temp.str());
os << temp.str();
......@@ -225,7 +226,7 @@ inline void PrintConst(const FloatImm* op, std::ostream& os, CodeGenC* p) { // N
case 16: {
os << '(';
p->PrintType(op->type, os);
os << ')' << op->value << 'f';
os << ')' << std::scientific << op->value << 'f';
break;
}
default: LOG(FATAL) << "Bad bit-width for float: " << op->type << "\n";
......
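The switch to `std::scientific` keeps emitted float literals unambiguous. A self-contained illustration of the formatting (standalone C++, not TVM code):

```cpp
#include <iostream>
#include <sstream>

int main() {
  std::ostringstream temp;
  temp << std::scientific << 0.1;   // default precision is 6
  std::cout << temp.str() << "\n";  // prints "1.000000e-01"
  return 0;
}
```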
......@@ -26,7 +26,7 @@ TVM_STATIC_IR_FUNCTOR(IRPrinter, vtable)
<< op->op
<< ", ";
p->print(op->source);
p->stream << ", rdom=" << op->rdom << ")";
p->stream << ", axis=" << op->axis << ")";
});
} // namespace Internal
......@@ -35,16 +35,16 @@ TVM_STATIC_IR_FUNCTOR(IRPrinter, vtable)
namespace tvm {
namespace ir {
Expr Reduce::make(std::string op, Expr source, Array<IterVar> rdom) {
Expr Reduce::make(std::string op, Expr source, Array<IterVar> axis) {
auto n = std::make_shared<Reduce>();
CHECK(source.defined());
for (size_t i = 0; i < rdom.size(); ++i) {
CHECK(rdom[i].defined());
for (size_t i = 0; i < axis.size(); ++i) {
CHECK(axis[i].defined());
}
n->type = source.type();
n->source = source;
n->op = op;
n->rdom = rdom;
n->axis = axis;
return Expr(n);
}
......
......@@ -4,6 +4,7 @@
*/
#include <tvm/operation.h>
#include <tvm/tensor.h>
#include <tvm/ir.h>
#include <memory>
namespace tvm {
......@@ -57,7 +58,12 @@ Tensor Placeholder(Array<Expr> shape, Type dtype, std::string name) {
// ComputeOpNode
Array<IterVar> ComputeOpNode::root_iter_vars() const {
return axis;
if (reduce_axis.size() == 0) return axis;
Array<IterVar> ret = axis;
for (IterVar iv : reduce_axis) {
ret.push_back(iv);
}
return ret;
}
Type ComputeOpNode::output_dtype(size_t i) const {
......@@ -101,6 +107,9 @@ Operation ComputeOpNode::make(std::string name,
n->name = name;
n->axis = axis;
n->body = body;
if (n->body->is_type<ir::Reduce>()) {
n->reduce_axis = n->body.as<ir::Reduce>()->axis;
}
return Operation(n);
}
......
......@@ -37,7 +37,7 @@ inline Array<Expr> MutateArray(Array<Expr> arr, IRMutator *m) {
}
}
inline Array<IterVar> MutateRDom(Array<IterVar> rdom, IRMutator *m) {
inline Array<IterVar> MutateIterVarArr(Array<IterVar> rdom, IRMutator *m) {
std::vector<IterVar> new_dom(rdom.size());
bool changed = false;
for (size_t i = 0; i < rdom.size(); i++) {
......@@ -237,13 +237,13 @@ Expr IRMutator::Mutate_(const Let *op, const Expr& e) {
TVM_STATIC_IR_FUNCTOR(IRMutator, vtable_expr)
.set_dispatch<Reduce>([](const Reduce* op, const Expr& e, IRMutator* m) {
Array<IterVar> new_rdom = MutateRDom(op->rdom, m);
Array<IterVar> new_axis = MutateIterVarArr(op->axis, m);
Expr new_source = m->Mutate(op->source);
if (op->rdom.same_as(new_rdom) &&
if (op->axis.same_as(new_axis) &&
op->source.same_as(new_source)) {
return e;
} else {
return Reduce::make(op->op, new_source, new_rdom);
return Reduce::make(op->op, new_source, new_axis);
}
});
......
......@@ -120,7 +120,7 @@ void IRVisitor::Visit_(const Call *op) {
TVM_STATIC_IR_FUNCTOR(IRVisitor, vtable)
.set_dispatch<Reduce>([](const Reduce* op, IRVisitor* v) {
VisitRDom(op->rdom, v);
VisitRDom(op->axis, v);
v->Visit(op->source);
})
.set_dispatch<IntImm>(NoOp)
......
......@@ -5,6 +5,7 @@
*/
#include <tvm/ir.h>
#include <tvm/ir_visitor.h>
#include <tvm/ir_mutator.h>
#include <tvm/ir_pass.h>
namespace tvm {
......@@ -32,5 +33,26 @@ bool HasSideEffect(const Expr& e) {
v.Visit(e);
return v.has_side_effect_;
}
class IRSubstitute : public IRMutator {
public:
Expr Mutate_(const Variable* op, const Expr& e) final {
auto it = smap.find(op);
if (it != smap.end()) {
return it->second;
} else {
return e;
}
}
std::unordered_map<const Variable*, Expr> smap;
};
Stmt Substitute(Stmt stmt, const Map<IterVar, Expr>& value_map) {
IRSubstitute m;
for (auto kv : value_map) {
m.smap[kv.first->var.get()] = kv.second;
}
return m.Mutate(stmt);
}
} // namespace ir
} // namespace tvm
......@@ -54,6 +54,11 @@ void PassDown(const Stage& s,
const Range& range_inner = state.at(r->inner);
state[r->fused] = Range::make_with_min_extent(
0, range_outer->extent * range_inner->extent);
} else if (rel.as<RebaseNode>()) {
const RebaseNode* r = rel.as<RebaseNode>();
CHECK(state.count(r->parent));
state[r->rebased] = Range::make_with_min_extent(
0, state.at(r->parent)->extent);
} else {
LOG(FATAL) << "unknown relation type";
}
......@@ -85,6 +90,13 @@ void PassUp(const Stage& s,
&outer, &inner);
state[r->outer] = outer;
state[r->inner] = inner;
} else if (rel.as<RebaseNode>()) {
IntSet parent;
const RebaseNode* r = rel.as<RebaseNode>();
PassUp(r, dom_map,
state.at(r->rebased),
&parent);
state[r->parent] = parent;
} else {
LOG(FATAL) << "unknown relation type";
}
......@@ -109,9 +121,15 @@ void PassToOperation(
// Eventually, we need to change the inference to be a Pull style inference
if (tensor->op.as<ComputeOpNode>()) {
auto root_iter_vars = tensor->op->root_iter_vars();
CHECK_EQ(tensor.ndim(), root_iter_vars.size());
for (size_t i = 0; i < tensor.ndim(); ++i) {
(*result)[root_iter_vars[i]].push_back(dim_bounds[i]);
const ComputeOpNode* op = tensor->op.as<ComputeOpNode>();
CHECK_EQ(op->axis.size() + op->reduce_axis.size(), root_iter_vars.size());
for (size_t i = 0; i < op->axis.size(); ++i) {
(*result)[op->axis[i]].push_back(dim_bounds[i]);
}
// reduction.
for (size_t i = 0; i < op->reduce_axis.size(); ++i) {
(*result)[op->reduce_axis[i]].push_back(
IntSet::range(op->reduce_axis[i]->dom));
}
} else {
LOG(FATAL) << "unknown operation mode " << tensor->op->type_key();
......@@ -173,9 +191,9 @@ bool ScopeRelax(const IterVar& iv, const std::string& scope) {
{"local", 2}
};
static std::unordered_map<std::string, int> thread_tag_rank{
{"gridIdx.x", 0},
{"gridIdx.y", 0},
{"gridIdx.z", 0},
{"blockIdx.x", 0},
{"blockIdx.y", 0},
{"blockIdx.z", 0},
{"threadIdx.x", 1},
{"threadIdx.y", 1},
{"threadIdx.z", 1}
......@@ -194,8 +212,6 @@ void InferBound(const Stage& stage,
(*rmap)[iv] = iv->dom;
}
}
// get range of all child iter vars.
PassDown(stage, rmap);
if (stage->attach_type == kScope) {
Stage parent = stage->attach_stage;
......@@ -206,10 +222,18 @@ void InferBound(const Stage& stage,
bool fix_value = true;
for (auto iv : parent->leaf_iter_vars) {
Range vrange = rmap->at(iv);
CHECK(is_zero(vrange->min))
<< "InferBound requires every leaf iter var's min equals 0, "
<< "call schedule.normalize to achieve this.";
// special optimization to remove trivial loop
if (is_one(vrange->extent)) {
up_state[iv] = IntSet::single_point(vrange->min);
} else if (fix_value && !ScopeRelax(iv, stage->scope)) {
up_state[iv] = IntSet::make_point(iv->var);
up_state[iv] = IntSet::single_point(iv->var);
} else {
up_state[iv] = IntSet::make_range(rmap->at(iv));
up_state[iv] = IntSet::range(vrange);
}
if (stage->attach_ivar == iv) {
fix_value = false;
......@@ -223,12 +247,30 @@ void InferBound(const Stage& stage,
bp_state[iv] = {up_state.at(iv)};
}
auto result = BoundProp(post_order, &bp_state);
// Set relaxation
Map<IterVar, IntSet> relax_set;
Stage s = stage;
while (s->attach_type == kScope) {
s = s->attach_stage;
for (auto iv : s->leaf_iter_vars) {
if (ScopeRelax(iv, stage->scope)) {
relax_set.Set(iv, IntSet::range(rmap->at(iv)));
}
}
}
for (auto iv : stage->op->root_iter_vars()) {
CHECK(result.count(iv));
CHECK(!rmap->count(iv));
(*rmap)[iv] = result.at(iv).GetCoverRange();
Range r = result.at(iv).cover_range(iv->dom);
if (relax_set.size() != 0) {
r = EvalSet(r, relax_set).cover_range(iv->dom);
}
(*rmap)[iv] = r;
}
}
// get range of all child iter vars.
PassDown(stage, rmap);
}
......
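The new relax-set step widens bounds that depend on thread indices which must be relaxed for the consumer's storage scope. A hedged sketch with hypothetical ranges, where `tx` stands for a threadIdx-bound IterVar over [0, 8):

```cpp
// If a producer's needed range is [tx, tx + 4) and tx must be relaxed,
// EvalSet unions the range over all tx values.
Map<IterVar, IntSet> relax_set;
relax_set.Set(tx, IntSet::range(Range::make_with_min_extent(0, 8)));
Range r = Range::make_with_min_extent(tx->var, 4);  // [tx, tx + 4)
Range cover = EvalSet(r, relax_set)
                  .cover_range(Range::make_with_min_extent(0, 16));
// cover is expected to be [0, 11): everything any thread touches.
```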
/*!
* Copyright (c) 2017 by Contributors
* \file compute_expr.h
* \brief Utility integer expression with quick eager simplification.
* This is weaker than Simplify but can be done eagerly.
*/
#ifndef TVM_SCHEDULE_COMPUTE_EXPR_H_
#define TVM_SCHEDULE_COMPUTE_EXPR_H_
#include <tvm/ir.h>
#include <pass/Interval.h>
namespace tvm {
namespace schedule {
using Halide::Internal::add_would_overflow;
using Halide::Internal::sub_would_overflow;
using Halide::Internal::mul_would_overflow;
/*!
* \brief Compute the expression with the given binary op.
* \param lhs The left operand
* \param rhs The right operand
* \return The result.
*/
template<typename OP>
inline Expr ComputeExpr(Expr lhs, Expr rhs) {
return OP::make(lhs, rhs);
}
template<typename T>
inline bool GetConst(Expr e, T* out);
template<>
inline bool GetConst<int64_t>(Expr e, int64_t *out) {
if (e.type().is_vector()) return false;
const int64_t *v = as_const_int(e);
if (v) {
*out = *v; return true;
} else {
return false;
}
}
template<>
inline bool GetConst<uint64_t>(Expr e, uint64_t *out) {
if (e.type().is_vector()) return false;
const uint64_t *v = as_const_uint(e);
if (v) {
*out = *v; return true;
} else {
return false;
}
}
#define TVM_CONST_PROPAGATION(OP_NAME, OP) \
int64_t ia = 0, ib = 0; \
if (GetConst(a, &ia) && GetConst(b, &ib)) { \
if (OP_NAME ## _would_overflow(a.type().bits(), ia, ib)) { \
LOG(FATAL) << "signed int overflow"; \
} \
return ir::IntImm::make(a.type(), ia OP ib); \
} \
uint64_t ua = 0, ub = 0; \
if (GetConst(a, &ua) && GetConst(b, &ub)) { \
return ir::UIntImm::make(a.type(), ua OP ub); \
} \
template<>
inline Expr ComputeExpr<ir::Add>(Expr a, Expr b) {
if (is_zero(a)) return b;
if (is_zero(b)) return a;
TVM_CONST_PROPAGATION(add, +);
return ir::Add::make(a, b);
}
template<>
inline Expr ComputeExpr<ir::Sub>(Expr a, Expr b) {
if (is_zero(b)) return a;
TVM_CONST_PROPAGATION(sub, -);
return ir::Sub::make(a, b);
}
template<>
inline Expr ComputeExpr<ir::Mul>(Expr a, Expr b) {
if (is_one(a)) return b;
if (is_one(b)) return a;
TVM_CONST_PROPAGATION(mul, *);
return ir::Mul::make(a, b);
}
template<>
inline Expr ComputeExpr<ir::Div>(Expr a, Expr b) {
if (is_one(b)) return a;
return ir::Div::make(a, b);
}
template<>
inline Expr ComputeExpr<ir::Max>(Expr a, Expr b) {
return Halide::Internal::Interval::make_max(a, b);
}
template<>
inline Expr ComputeExpr<ir::Min>(Expr a, Expr b) {
return Halide::Internal::Interval::make_min(a, b);
}
} // namespace schedule
} // namespace tvm
#endif // TVM_SCHEDULE_COMPUTE_EXPR_H_
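A usage sketch for the eager helpers (the Var `x` and the constant values are illustrative):

```cpp
// Constant operands fold eagerly; symbolic ones fall through to IR nodes.
Var x("x");
Expr four = ComputeExpr<ir::Add>(make_const(Int(32), 1),
                                 make_const(Int(32), 3));    // IntImm 4
Expr same = ComputeExpr<ir::Mul>(x, make_const(Int(32), 1)); // returns x
```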
......@@ -22,35 +22,48 @@ class IntSet : public NodeRef {
public:
/*! \brief constructor */
IntSet() {}
// constructor from not deontainer.
// constructor from node container.
explicit IntSet(std::shared_ptr<Node> n) : NodeRef(n) {}
/*! \return whether the set is empty */
inline bool is_empty() const {
return !defined();
}
/*!
* \return a range that covers the IntSet
*/
Range GetCoverRange() const;
/*!
* \brief access the internal node container
* \return the pointer to the internal node container
*/
inline const IntSetNode* operator->() const;
/*!
* \param dom The domain to be created.
* \return create integer set from existing domain
* \brief Find a range that covers the region.
* \param max_range The range to be covered.
* \return The covering range.
*/
Range cover_range(Range max_range) const;
/*!
* \brief find an interval that covers the set.
* \return The covering interval set.
*/
static IntSet make_range(Range dom);
IntSet cover_interval() const;
/*! \return Whether the set represents everything */
bool is_everything() const;
/*! \return Whether the set is a single point */
bool is_single_point() const;
/*! \return A set that represents everything */
static IntSet everything();
/*!
* \param point
* \return create integer set that only contains one point
* \brief construct a point set.
* \param point The point in the set.
* \return The constructed single point set
*/
static IntSet make_point(Expr point);
static IntSet single_point(Expr point);
/*!
* \return create integer set that represents everything
* \brief Construct a set representing a range.
* \param r The range
* \return constructed set.
*/
static IntSet make_all_set();
static IntSet range(Range r);
};
/*!
* \brief Base class of all IntSet containers.
*/
struct IntSetNode : public Node {
};
/*!
......@@ -63,6 +76,18 @@ class IntSet : public NodeRef {
*/
IntSet EvalSet(Expr e,
const Map<IterVar, IntSet>& dom_map);
/*!
* \brief Find a symbolic integer set that is the union over
* all the possible conditional values in dom_map.
*
* \param r The initial range.
* \param dom_map The domain of each variable.
* \return An integer set that can cover all the possible values.
*/
IntSet EvalSet(Range r,
const Map<IterVar, IntSet>& dom_map);
/*!
* \brief Conditional upward message passing.
*
......@@ -99,6 +124,23 @@ void PassUp(const FuseNode* s,
const IntSet& fused,
IntSet* outer,
IntSet* inner);
/*!
* \brief Conditional upward message passing.
*
* Get domain of parent, condition on domain of children.
* Domain is represented as IntSet.
*
* \param s The Rebase relation node.
* \param dom_map The old domain result from downward message passing.
*  Contains the domain set if all the children are full sets.
* \param rebased The domain of the rebased iteration.
* \param parent The result domain of parent iteration.
*/
void PassUp(const RebaseNode* s,
const std::unordered_map<IterVar, Range>& dom_map,
const IntSet& rebased,
IntSet* parent);
/*!
* \brief Create a union set of all sets
* \param sets The sets to be unioned
......@@ -106,6 +148,11 @@ void PassUp(const FuseNode* s,
*/
IntSet Union(const Array<IntSet>& sets);
// implementation
inline const IntSetNode* IntSet::operator->() const {
return static_cast<const IntSetNode*>(node_.get());
}
} // namespace schedule
} // namespace tvm
......
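A sketch of the renamed IntSet constructors together with the new `cover_range` (ranges chosen for illustration):

```cpp
// single_point/range replace make_point/make_range; cover_range returns
// a Range covering the set, with max_range as the fallback bound.
IntSet pt = IntSet::single_point(make_const(Int(32), 3));
IntSet rg = IntSet::range(Range::make_with_min_extent(0, 8));
Range cover = rg.cover_range(Range::make_with_min_extent(0, 16));
// cover is expected to be [0, 8).
```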
......@@ -81,7 +81,7 @@ Stage& Stage::compute_at(Stage parent, IterVar scope) { // NOLINT(*)
}
}
CHECK(found)
<< "Cannot compute at a iteration variable that is not part of parent leaf vars";
<< "Cannot find the specified axis in parent stage's leaf_iter_vars";
return *this;
}
......@@ -165,7 +165,6 @@ Stage& Stage::tile(IterVar x_parent, IterVar y_parent,
return *this;
}
Schedule::Schedule(Array<Operation> ops) {
auto n = std::make_shared<ScheduleNode>();
n->roots = ops;
......@@ -203,9 +202,53 @@ IterVarRelation FuseNode::make(
return IterVarRelation(n);
}
IterVarRelation RebaseNode::make(IterVar parent, IterVar rebased) {
auto n = std::make_shared<RebaseNode>();
n->parent = parent;
n->rebased = rebased;
return IterVarRelation(n);
}
void Schedule::normalize() {
std::unordered_map<IterVar, IterVar> rebase_map;
std::unordered_map<const Node*, int> attach_mark;
for (Stage s : (*this)->stages) {
if (s->attach_type == kScope) {
attach_mark[s->attach_stage.get()] = 1;
}
}
for (Stage s : (*this)->stages) {
if (!attach_mark.count(s.get())) continue;
auto root_iter_vars = s->op->root_iter_vars();
ArrayNode* leaf_vars = s->leaf_iter_vars.CopyOnWrite();
for (IterVar iv : root_iter_vars) {
size_t idx = FindIterVar(leaf_vars, iv);
if (idx < leaf_vars->data.size()) {
// insert rebase
IterVar rebased(Range(), iv->var->name_hint + ".rb");
s->relations.push_back(RebaseNode::make(iv, rebased));
leaf_vars->data[idx] = rebased.node_;
rebase_map[iv] = rebased;
}
}
}
// remap the parent relation
for (Stage s : (*this)->stages) {
if (s->attach_type != kScope) continue;
if (rebase_map.count(s->attach_ivar)) {
s->attach_ivar = rebase_map.at(s->attach_ivar);
}
}
}
TVM_REGISTER_NODE_TYPE(StageNode);
TVM_REGISTER_NODE_TYPE(SplitNode);
TVM_REGISTER_NODE_TYPE(FuseNode);
TVM_REGISTER_NODE_TYPE(RebaseNode);
TVM_REGISTER_NODE_TYPE(ScheduleNode);
} // namespace tvm
......@@ -18,7 +18,8 @@ def test_add():
# one line to build the function.
codes = []
fadd = tvm.build(s, args=[A, B, C],
fadd = tvm.build(s,
args=[A, B, C],
target="cuda", name="myadd",
record_codes=codes)
for c in codes:
......
import tvm
import numpy as np
def test_sum():
# graph
n = tvm.Var('n')
m = tvm.Var('m')
A = tvm.placeholder((n, m), name='A')
k = tvm.IterVar((0, m))
B = tvm.compute((n,), lambda i: tvm.sum(A[i, k], axis=k), name='B')
# schedule
s = tvm.Schedule(B.op)
# create iter var and assign them tags.
num_thread = 1
block_x = tvm.IterVar(thread_tag="blockIdx.x")
thread_x = tvm.IterVar((0, num_thread), thread_tag="threadIdx.x")
_, x = s[B].split(B.op.axis[0], factor=num_thread, outer=block_x)
_, x = s[B].split(x, outer=thread_x)
tvm.init_opencl()
codes = []
fsum = tvm.build(s,
args=[A, B],
target="opencl", name="myadd",
record_codes=codes)
for c in codes:
print(c)
num_device = 1
for i in range(num_device):
ctx = tvm.opencl(i)
if not ctx.enabled:
continue
# launch the kernel.
n = 1028
m = 129
#a = tvm.nd.array(np.zeros((n, m)).astype(A.dtype), ctx)
a = tvm.nd.array(np.random.uniform(size=(n, m)).astype(A.dtype), ctx)
b = tvm.nd.array(np.zeros(n, dtype=B.dtype), ctx)
fsum(a, b)
np.testing.assert_allclose(
b.asnumpy(), np.sum(a.asnumpy(), axis=1), rtol=1e-4)
if __name__ == "__main__":
test_sum()
......@@ -18,8 +18,7 @@ def test_add_pipeline():
# compile to IR
bounds = tvm.schedule.InferBound(s)
stmt = tvm.ir_pass.ScheduleOps(s, bounds)
stmt = tvm.schedule.ScheduleOps(s, bounds)
Ab = tvm.Buffer(A.shape, A.dtype, name='A')
Bb = tvm.Buffer(B.shape, B.dtype, name='B')
Cb = tvm.Buffer(C.shape, C.dtype, name='C')
......
......@@ -10,12 +10,13 @@ def test_makeapi():
s = tvm.Schedule(C.op)
bounds = tvm.schedule.InferBound(s)
stmt = tvm.ir_pass.ScheduleOps(s, bounds)
stmt = tvm.schedule.ScheduleOps(s, bounds)
Ab = tvm.Buffer(A.shape, A.dtype, name='A')
Bb = tvm.Buffer(B.shape, B.dtype, name='B')
Cb = tvm.Buffer(C.shape, C.dtype, name='C')
stmt = tvm.ir_pass.StorageFlatten(stmt, {A: Ab, B:Bb, C:Cb})
num_packed_args = 2
f = tvm.codegen.MakeAPI(stmt, "myadd", [n, Ab, Bb, Cb], num_packed_args)
assert(f.handle_data_type[Ab.data].dtype == Ab.dtype)
......
......@@ -26,7 +26,7 @@ def test_tensor_reduce():
B = tvm.placeholder((n, l), name='B')
T = tvm.compute((m, n, l), lambda i, j, k: A[i, k] * B[j, k])
rv = tvm.IterVar((0, A.shape[1]), name="k")
C = tvm.compute((m, n), lambda i, j: tvm.sum(T(i, j, rv+1), rdom=rv))
C = tvm.compute((m, n), lambda i, j: tvm.sum(T(i, j, rv+1), axis=rv))
# json load save
C_json = tvm.save_json(C)
C_loaded = tvm.load_json(C_json)
......
......@@ -12,7 +12,7 @@ def test_flatten2():
s[A1].compute_at(s[A2], xo)
bounds = tvm.schedule.InferBound(s)
assert isinstance(bounds, tvm.collections.Map)
stmt = tvm.ir_pass.ScheduleOps(s, bounds)
stmt = tvm.schedule.ScheduleOps(s, bounds)
print(stmt)
Ab = tvm.Buffer(A.shape, A.dtype, name='A')
......
......@@ -11,7 +11,7 @@ def test_schedule0():
bounds = tvm.schedule.InferBound(s)
assert isinstance(bounds, tvm.collections.Map)
stmt = tvm.ir_pass.ScheduleOps(s, bounds)
stmt = tvm.schedule.ScheduleOps(s, bounds)
print(stmt)
def test_schedule1():
......@@ -24,7 +24,7 @@ def test_schedule1():
xo, xi = s[A1].split(A1.op.axis[0], 8)
bounds = tvm.schedule.InferBound(s)
assert isinstance(bounds, tvm.collections.Map)
stmt = tvm.ir_pass.ScheduleOps(s, bounds)
stmt = tvm.schedule.ScheduleOps(s, bounds)
print(stmt)
def test_schedule2():
......@@ -39,7 +39,7 @@ def test_schedule2():
s[A1].compute_at(s[A2], xo)
bounds = tvm.schedule.InferBound(s)
assert isinstance(bounds, tvm.collections.Map)
stmt = tvm.ir_pass.ScheduleOps(s, bounds)
stmt = tvm.schedule.ScheduleOps(s, bounds)
print(stmt)
......