Commit a2c8a29b by Tianqi Chen, committed by GitHub

[SCHEDULE] Improve bound inference, support reduce codegen. (#30)

parent d4af7ad6
@@ -32,6 +32,9 @@ using Halide::Internal::IRPrinter;
 using Halide::Internal::Variable;
 using Halide::Internal::make_const;
+using Halide::Internal::make_zero;
+using Halide::Internal::as_const_int;
+using Halide::Internal::as_const_uint;
 inline Type TVMType2Type(TVMType t) {
@@ -126,25 +129,25 @@ using Halide::abs;
 using Halide::select;
 /*!
- * \brief sum of of source expression over rdom
+ * \brief sum of source expression over axis
  * \param source The source expression.
- * \param rdom List of iteration variables that will be used for reduction.
+ * \param axis List of iteration variables that will be used for reduction.
  */
-Expr sum(Expr source, Array<IterVar> rdom);
+Expr sum(Expr source, Array<IterVar> axis);
 /*!
- * \brief max of of source expression over rdom
+ * \brief max of source expression over axis
  * \param source The source expression.
- * \param rdom List of iteration variables that will be used for reduction.
+ * \param axis List of iteration variables that will be used for reduction.
  */
-Expr max(Expr source, Array<IterVar> rdom);
+Expr max(Expr source, Array<IterVar> axis);
 /*!
- * \brief max of of source expression over rdom
+ * \brief min of source expression over axis
  * \param source The source expression.
- * \param rdom List of iteration variables that will be used for reduction.
+ * \param axis List of iteration variables that will be used for reduction.
  */
-Expr min(Expr source, Array<IterVar> rdom);
+Expr min(Expr source, Array<IterVar> axis);
 // print functions for expr
......
@@ -30,8 +30,8 @@ struct Reduce : public ExprNode<Reduce> {
   std::string op;
   /*! \brief The source operand */
   Expr source;
-  /*! \brief The reduction domains */
-  Array<IterVar> rdom;
+  /*! \brief The reduction axis */
+  Array<IterVar> axis;
   /*! \brief construct expr from op and rdom */
   static Expr make(std::string op, Expr src, Array<IterVar> rdom);
@@ -40,7 +40,7 @@ struct Reduce : public ExprNode<Reduce> {
     v->Visit("dtype", &type);
     v->Visit("op", &op);
     v->Visit("source", &source);
-    v->Visit("rdom", &rdom);
+    v->Visit("axis", &axis);
   }
   static const IRNodeType _type_info = IRNodeType::ExtensionExpr;
   static constexpr const char* _type_key = "Reduce";
......
@@ -3,8 +3,8 @@
  * \file ir_pass.h
  * \brief Collection of IR pass functions
  *
- * All the pass functions in this file are for Stmt,
- * We can use PassFunction(Evaluate(expr)) to apply it to Expr
+ * When a pass function in this file operates on Stmt,
+ * we can use PassFunction(Evaluate(expr)) to apply it to an Expr.
  */
 #ifndef TVM_IR_PASS_H_
 #define TVM_IR_PASS_H_
@@ -38,15 +38,6 @@ inline Stmt Simplify(Stmt a) {
 }
 /*!
- * \brief Schedule s' dependent operations.
- *
- * \param s The schedule to be realized
- * \param dom_map The domain of each iter vars.
- * \return the result Stmt
- */
-Stmt ScheduleOps(Schedule s, Map<IterVar, Range> dom_map);
-/*!
  * \brief Verifies whether the IR stmt or Expr is in SSA form.
  *  That is: each VarExpr is defined and assigned once (in Let/For).
  *
@@ -70,6 +61,14 @@ bool HasSideEffect(const Expr& e);
 Stmt ConvertSSA(Stmt stmt);
 /*!
+ * \brief Substitute each occurrence of the variable specified by key->var
+ *  with the corresponding mapped value.
+ * \param stmt The source statement to be substituted
+ * \param value_map The map of new values.
+ * \return The converted form.
+ */
+Stmt Substitute(Stmt stmt, const Map<IterVar, Expr>& value_map);
+/*!
  * \brief inline all calls of f in stmt.
  *
  * \param f The function reference to be inlined
......
@@ -49,6 +49,8 @@ class ComputeOpNode : public OperationNode {
  public:
   /*! \brief IterVar on each axis */
   Array<IterVar> axis;
+  /*! \brief IterVar on each reduction axis, if the body is a Reduce */
+  Array<IterVar> reduce_axis;
   /*! \brief the compute expression */
   Expr body;
   /*! \brief constructor */
@@ -64,6 +66,7 @@ class ComputeOpNode : public OperationNode {
   void VisitAttrs(AttrVisitor* v) final {
     v->Visit("name", &name);
     v->Visit("axis", &axis);
+    v->Visit("reduce_axis", &reduce_axis);
     v->Visit("body", &body);
   }
   static Operation make(std::string name,
......
@@ -123,6 +123,8 @@ class Stage : public NodeRef {
              IterVar* p_x_outer, IterVar* p_y_outer,
              IterVar* p_x_inner, IterVar* p_y_inner,
              Expr x_factor, Expr y_factor);
+  // declare container type
+  using ContainerType = StageNode;
 };
 /*!
@@ -153,10 +155,21 @@ class Schedule : public NodeRef {
     return this->operator[](tensor->op);
   }
   /*!
+   * \brief Normalize the schedule.
+   *  This is needed before bound inference.
+   *  Insert the necessary RebaseNodes to make sure all leaf_iter_vars
+   *  are in the form [0, extent).
+   *
+   *  The schedule is normalized in place; the result can be
+   *  the same as the current one.
+   */
+  void normalize();
+  /*!
    * \brief access the internal node container
    * \return the pointer to the internal node container
    */
   inline const ScheduleNode* operator->() const;
+  // declare container type
+  using ContainerType = ScheduleNode;
 };
 /*!
@@ -308,6 +321,30 @@ class FuseNode : public IterVarRelationNode {
   TVM_DECLARE_NODE_TYPE_INFO(FuseNode);
 };
+/*!
+ * \brief Rebase the iteration to make the min to be 0.
+ *  This is useful for normalizing the Schedule,
+ *  so that every leaf variable's min is 0.
+ */
+class RebaseNode : public IterVarRelationNode {
+ public:
+  /*! \brief The parent domain */
+  IterVar parent;
+  /*! \brief The rebased domain */
+  IterVar rebased;
+  void VisitAttrs(AttrVisitor* v) final {
+    v->Visit("parent", &parent);
+    v->Visit("rebased", &rebased);
+  }
+  static IterVarRelation make(IterVar parent, IterVar rebased);
+  static constexpr const char* _type_key = "Rebase";
+  TVM_DECLARE_NODE_TYPE_INFO(RebaseNode);
+};
 // implementations
 inline const StageNode* Stage::operator->() const {
   return static_cast<const StageNode*>(node_.get());
......
@@ -24,6 +24,15 @@ namespace schedule {
  */
 Map<IterVar, Range> InferBound(Schedule sch);
+/*!
+ * \brief Schedule s's dependent operations.
+ *
+ * \param s The schedule to be realized
+ * \param dom_map The domain of each iter var.
+ * \return the result Stmt
+ */
+Stmt ScheduleOps(Schedule s, Map<IterVar, Range> dom_map);
+
 }  // namespace schedule
 }  // namespace tvm
 #endif  // TVM_SCHEDULE_PASS_H_
@@ -212,51 +212,51 @@ def IterVar(dom=None, name=None, thread_tag=''):
     return _api_internal._IterVar(dom, name, thread_tag)
-def sum(expr, rdom):
-    """Create a sum expression over rdom
+def sum(expr, axis):
+    """Create a sum expression over axis
     Parameters
     ----------
     expr : Expr
         The source expression.
-    rdom : RDomain
-        The reduction domainx
+    axis : IterVar or list of IterVar
+        The reduction axis.
     """
-    rdom = rdom if isinstance(rdom, list) else [rdom]
-    x = _make.Reduce("Add", expr, rdom)
+    axis = axis if isinstance(axis, list) else [axis]
+    x = _make.Reduce("Add", expr, axis)
     return x
-def min(expr, rdom):
-    """Create a min expression over rdom
+def min(expr, axis):
+    """Create a min expression over axis
     Parameters
     ----------
     expr : Expr
         The source expression.
-    rdom : RDomain
-        The reduction domainx
+    axis : IterVar or list of IterVar
+        The reduction axis.
     """
-    rdom = rdom if isinstance(rdom, list) else [rdom]
-    x = _make.Reduce("Min", expr, rdom)
+    axis = axis if isinstance(axis, list) else [axis]
+    x = _make.Reduce("Min", expr, axis)
     return x
-def max(expr, rdom):
-    """Create a min expression over rdom
+def max(expr, axis):
+    """Create a max expression over axis
     Parameters
     ----------
     expr : Expr
         The source expression.
-    rdom : RDomain
-        The reduction domainx
+    axis : IterVar or list of IterVar
+        The reduction axis.
     """
-    rdom = rdom if isinstance(rdom, list) else [rdom]
-    x = _make.Reduce("Max", expr, rdom)
+    axis = axis if isinstance(axis, list) else [axis]
+    x = _make.Reduce("Max", expr, axis)
    return x
......
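The reducers now take the reduction IterVar (or a list of them) directly. A minimal usage sketch, modeled on the test_sum test added by this commit (see the new test file below):

```python
import tvm

n = tvm.Var('n')
m = tvm.Var('m')
A = tvm.placeholder((n, m), name='A')
# the reduction axis is a plain IterVar now, not an RDomain
k = tvm.IterVar((0, m), name='k')
# row sum: B[i] = sum over k of A[i, k]
B = tvm.compute((n,), lambda i: tvm.sum(A[i, k], axis=k), name='B')
```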
@@ -62,9 +62,10 @@ def build(sch,
     # lowering
     bounds = schedule.InferBound(sch)
-    stmt = ir_pass.ScheduleOps(sch, bounds)
+    stmt = schedule.ScheduleOps(sch, bounds)
     stmt = ir_pass.StorageFlatten(stmt, binds)
     stmt = ir_pass.Simplify(stmt)
+    print(stmt)
     fapi = codegen.MakeAPI(stmt, name, arg_list, len(arg_list))
     fsplits = codegen.SplitHostDevice(fapi)
@@ -73,7 +74,8 @@ def build(sch,
     for i, f in enumerate(fsplits):
         t = target if i >= 1 else "c"
         record_codes.append(codegen.CompileToC(f, output_ssa, t))
+    for c in record_codes:
+        print(c)
     if target == "cuda":
         ret = codegen.BuildNVRTC(fsplits, "stackvm")
     elif target == "opencl":
......
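Since ScheduleOps is now registered as a schedule pass rather than an IR pass, manual lowering follows the same sequence build uses above. A sketch assuming a schedule s and buffers Ab, Bb bound to tensors A, B, as in the updated tests further down; the function name "myfunc" is illustrative:

```python
bounds = tvm.schedule.InferBound(s)
stmt = tvm.schedule.ScheduleOps(s, bounds)   # moved from tvm.ir_pass
stmt = tvm.ir_pass.StorageFlatten(stmt, {A: Ab, B: Bb})
stmt = tvm.ir_pass.Simplify(stmt)
fapi = tvm.codegen.MakeAPI(stmt, "myfunc", [Ab, Bb], 2)
fsplits = tvm.codegen.SplitHostDevice(fapi)
```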
@@ -33,6 +33,14 @@ class Schedule(NodeBase):
             raise ValueError("Cannot find the operation %s in schedule" % (str(k)))
         return self.stage_map[k]
+    def normalize(self):
+        """Normalize the schedule in place.
+        Insert the necessary rebases so that every leaf iter var starts from 0.
+        This is needed before bound inference and the follow-up steps.
+        """
+        _api_internal._ScheduleNormalize(self)
+
 @register_node
 class Stage(NodeBase):
     """A Stage represents schedule for one operation."""
......
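A sketch of the intended call order: normalize rebases the leaf iter vars before bound inference, which the new CHECK in bound.cc (below) enforces. B is assumed to be a compute tensor as in the earlier sketch:

```python
s = tvm.Schedule(B.op)
xo, xi = s[B].split(B.op.axis[0], factor=8)
# insert Rebase relations so every leaf iter var is in the form [0, extent)
s.normalize()
bounds = tvm.schedule.InferBound(s)
```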
@@ -253,4 +253,10 @@ TVM_REGISTER_API(_StageTile)
     *ret = Array<IterVar>({x_outer, y_outer, x_inner, y_inner});
   });
+TVM_REGISTER_API(_ScheduleNormalize)
+.set_body([](TVMArgs args, TVMRetValue* ret) {
+    args[0].operator Schedule()
+        .normalize();
+  });
 }  // namespace tvm
@@ -51,7 +51,6 @@ TVM_REGISTER_API(_pass_Equal)
 REGISTER_PASS1(ConvertSSA);
 REGISTER_PASS1(VerifySSA);
 REGISTER_PASS4(Inline);
-REGISTER_PASS2(ScheduleOps);
 REGISTER_PASS2(StorageFlatten);
 }  // namespace ir
......
@@ -29,6 +29,7 @@ namespace schedule {
 REGISTER_SCHEDULE_PASS1(InferBound);
 REGISTER_SCHEDULE_PASS1(CreateReadGraph);
 REGISTER_SCHEDULE_PASS2(PostDFSOrder);
+REGISTER_SCHEDULE_PASS2(ScheduleOps);
 }  // namespace schedule
 }  // namespace tvm
@@ -2,6 +2,7 @@
  * Copyright (c) 2017 by Contributors
  * \file codegen_c.cc
  */
+#include <iomanip>
 #include "./codegen_c.h"
 namespace tvm {
@@ -216,7 +217,7 @@ inline void PrintConst(const FloatImm* op, std::ostream& os, CodeGenC* p) {  // NOLINT(*)
   switch (op->type.bits()) {
     case 64: case 32: {
       std::ostringstream temp;
-      temp << op->value;
+      temp << std::scientific << op->value;
       if (op->type.bits() == 32) temp << 'f';
       p->MarkConst(temp.str());
       os << temp.str();
@@ -225,7 +226,7 @@ inline void PrintConst(const FloatImm* op, std::ostream& os, CodeGenC* p) {  // NOLINT(*)
     case 16: {
       os << '(';
       p->PrintType(op->type, os);
-      os << ')' << op->value << 'f';
+      os << ')' << std::scientific << op->value << 'f';
       break;
     }
     default: LOG(FATAL) << "Bad bit-width for float: " << op->type << "\n";
......
@@ -26,7 +26,7 @@ TVM_STATIC_IR_FUNCTOR(IRPrinter, vtable)
        << op->op
        << ", ";
     p->print(op->source);
-    p->stream << ", rdom=" << op->rdom << ")";
+    p->stream << ", axis=" << op->axis << ")";
   });
 }  // namespace Internal
@@ -35,16 +35,16 @@ TVM_STATIC_IR_FUNCTOR(IRPrinter, vtable)
 namespace tvm {
 namespace ir {
-Expr Reduce::make(std::string op, Expr source, Array<IterVar> rdom) {
+Expr Reduce::make(std::string op, Expr source, Array<IterVar> axis) {
   auto n = std::make_shared<Reduce>();
   CHECK(source.defined());
-  for (size_t i = 0; i < rdom.size(); ++i) {
-    CHECK(rdom[i].defined());
+  for (size_t i = 0; i < axis.size(); ++i) {
+    CHECK(axis[i].defined());
   }
   n->type = source.type();
   n->source = source;
   n->op = op;
-  n->rdom = rdom;
+  n->axis = axis;
   return Expr(n);
 }
......
@@ -4,6 +4,7 @@
  */
 #include <tvm/operation.h>
 #include <tvm/tensor.h>
+#include <tvm/ir.h>
 #include <memory>
 namespace tvm {
@@ -57,7 +58,12 @@ Tensor Placeholder(Array<Expr> shape, Type dtype, std::string name) {
 // ComputeOpNode
 Array<IterVar> ComputeOpNode::root_iter_vars() const {
-  return axis;
+  if (reduce_axis.size() == 0) return axis;
+  Array<IterVar> ret = axis;
+  for (IterVar iv : reduce_axis) {
+    ret.push_back(iv);
+  }
+  return ret;
 }
 Type ComputeOpNode::output_dtype(size_t i) const {
@@ -101,6 +107,9 @@ Operation ComputeOpNode::make(std::string name,
   n->name = name;
   n->axis = axis;
   n->body = body;
+  if (n->body->is_type<ir::Reduce>()) {
+    n->reduce_axis = n->body.as<ir::Reduce>()->axis;
+  }
   return Operation(n);
 }
......
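With this change, a compute op whose body is a Reduce records the reduction IterVars separately, and root_iter_vars() returns the output axis followed by reduce_axis. Continuing the earlier sketch, and assuming the new reduce_axis attribute is readable from Python the same way axis is:

```python
B = tvm.compute((n,), lambda i: tvm.sum(A[i, k], axis=k), name='B')
# B.op.axis        -> [IterVar over the output dimension n]
# B.op.reduce_axis -> [k], extracted from the Reduce body by ComputeOpNode::make
```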
@@ -37,7 +37,7 @@ inline Array<Expr> MutateArray(Array<Expr> arr, IRMutator *m) {
   }
 }
-inline Array<IterVar> MutateRDom(Array<IterVar> rdom, IRMutator *m) {
+inline Array<IterVar> MutateIterVarArr(Array<IterVar> rdom, IRMutator *m) {
   std::vector<IterVar> new_dom(rdom.size());
   bool changed = false;
   for (size_t i = 0; i < rdom.size(); i++) {
@@ -237,13 +237,13 @@ Expr IRMutator::Mutate_(const Let *op, const Expr& e) {
 TVM_STATIC_IR_FUNCTOR(IRMutator, vtable_expr)
 .set_dispatch<Reduce>([](const Reduce* op, const Expr& e, IRMutator* m) {
-    Array<IterVar> new_rdom = MutateRDom(op->rdom, m);
+    Array<IterVar> new_axis = MutateIterVarArr(op->axis, m);
     Expr new_source = m->Mutate(op->source);
-    if (op->rdom.same_as(new_rdom) &&
+    if (op->axis.same_as(new_axis) &&
        op->source.same_as(new_source)) {
      return e;
    } else {
-      return Reduce::make(op->op, new_source, new_rdom);
+      return Reduce::make(op->op, new_source, new_axis);
    }
  });
......
@@ -120,7 +120,7 @@ void IRVisitor::Visit_(const Call *op) {
 TVM_STATIC_IR_FUNCTOR(IRVisitor, vtable)
 .set_dispatch<Reduce>([](const Reduce* op, IRVisitor* v) {
-    VisitRDom(op->rdom, v);
+    VisitRDom(op->axis, v);
     v->Visit(op->source);
   })
 .set_dispatch<IntImm>(NoOp)
......
@@ -5,6 +5,7 @@
  */
 #include <tvm/ir.h>
 #include <tvm/ir_visitor.h>
+#include <tvm/ir_mutator.h>
 #include <tvm/ir_pass.h>
 namespace tvm {
@@ -32,5 +33,26 @@ bool HasSideEffect(const Expr& e) {
   v.Visit(e);
   return v.has_side_effect_;
 }
+
+class IRSubstitute : public IRMutator {
+ public:
+  Expr Mutate_(const Variable* op, const Expr& e) final {
+    auto it = smap.find(op);
+    if (it != smap.end()) {
+      return it->second;
+    } else {
+      return e;
+    }
+  }
+  std::unordered_map<const Variable*, Expr> smap;
+};
+
+Stmt Substitute(Stmt stmt, const Map<IterVar, Expr>& value_map) {
+  IRSubstitute m;
+  for (auto kv : value_map) {
+    m.smap[kv.first->var.get()] = kv.second;
+  }
+  return m.Mutate(stmt);
+}
 }  // namespace ir
 }  // namespace tvm
@@ -54,6 +54,11 @@ void PassDown(const Stage& s,
       const Range& range_inner = state.at(r->inner);
       state[r->fused] = Range::make_with_min_extent(
           0, range_outer->extent * range_inner->extent);
+    } else if (rel.as<RebaseNode>()) {
+      const RebaseNode* r = rel.as<RebaseNode>();
+      CHECK(state.count(r->parent));
+      state[r->rebased] = Range::make_with_min_extent(
+          0, state.at(r->parent)->extent);
     } else {
       LOG(FATAL) << "unknown relation type";
     }
@@ -85,6 +90,13 @@ void PassUp(const Stage& s,
              &outer, &inner);
       state[r->outer] = outer;
       state[r->inner] = inner;
+    } else if (rel.as<RebaseNode>()) {
+      IntSet parent;
+      const RebaseNode* r = rel.as<RebaseNode>();
+      PassUp(r, dom_map,
+             state.at(r->rebased),
+             &parent);
+      state[r->parent] = parent;
     } else {
       LOG(FATAL) << "unknown relation type";
     }
@@ -109,9 +121,15 @@ void PassToOperation(
   // Eventually, we need to change the inference to be a Pull style inference
   if (tensor->op.as<ComputeOpNode>()) {
     auto root_iter_vars = tensor->op->root_iter_vars();
-    CHECK_EQ(tensor.ndim(), root_iter_vars.size());
-    for (size_t i = 0; i < tensor.ndim(); ++i) {
-      (*result)[root_iter_vars[i]].push_back(dim_bounds[i]);
+    const ComputeOpNode* op = tensor->op.as<ComputeOpNode>();
+    CHECK_EQ(op->axis.size() + op->reduce_axis.size(), root_iter_vars.size());
+    for (size_t i = 0; i < op->axis.size(); ++i) {
+      (*result)[op->axis[i]].push_back(dim_bounds[i]);
+    }
+    // reduction.
+    for (size_t i = 0; i < op->reduce_axis.size(); ++i) {
+      (*result)[op->reduce_axis[i]].push_back(
+          IntSet::range(op->reduce_axis[i]->dom));
     }
   } else {
     LOG(FATAL) << "unknown operation mode " << tensor->op->type_key();
@@ -173,9 +191,9 @@ bool ScopeRelax(const IterVar& iv, const std::string& scope) {
     {"local", 2}
   };
   static std::unordered_map<std::string, int> thread_tag_rank{
-    {"gridIdx.x", 0},
-    {"gridIdx.y", 0},
-    {"gridIdx.z", 0},
+    {"blockIdx.x", 0},
+    {"blockIdx.y", 0},
+    {"blockIdx.z", 0},
     {"threadIdx.x", 1},
     {"threadIdx.y", 1},
     {"threadIdx.z", 1}
@@ -194,8 +212,6 @@ void InferBound(const Stage& stage,
       (*rmap)[iv] = iv->dom;
     }
   }
-  // get range of all child iter vars.
-  PassDown(stage, rmap);
   if (stage->attach_type == kScope) {
     Stage parent = stage->attach_stage;
@@ -206,10 +222,18 @@ void InferBound(const Stage& stage,
     bool fix_value = true;
     for (auto iv : parent->leaf_iter_vars) {
+      Range vrange = rmap->at(iv);
+      CHECK(is_zero(vrange->min))
+          << "InferBound requires every leaf iter var's min equals 0, "
+          << "call schedule.normalize to achieve this.";
+      // special optimization to remove trivial loop
+      if (is_one(vrange->extent)) {
+        up_state[iv] = IntSet::single_point(vrange->min);
+      }
       if (fix_value && !ScopeRelax(iv, stage->scope)) {
-        up_state[iv] = IntSet::make_point(iv->var);
+        up_state[iv] = IntSet::single_point(iv->var);
       } else {
-        up_state[iv] = IntSet::make_range(rmap->at(iv));
+        up_state[iv] = IntSet::range(vrange);
       }
       if (stage->attach_ivar == iv) {
         fix_value = false;
@@ -223,12 +247,30 @@ void InferBound(const Stage& stage,
       bp_state[iv] = {up_state.at(iv)};
     }
     auto result = BoundProp(post_order, &bp_state);
+    // Set relaxation
+    Map<IterVar, IntSet> relax_set;
+    Stage s = stage;
+    while (s->attach_type == kScope) {
+      s = s->attach_stage;
+      for (auto iv : s->leaf_iter_vars) {
+        if (ScopeRelax(iv, stage->scope)) {
+          relax_set.Set(iv, IntSet::range(rmap->at(iv)));
+        }
+      }
+    }
     for (auto iv : stage->op->root_iter_vars()) {
       CHECK(result.count(iv));
      CHECK(!rmap->count(iv));
-      (*rmap)[iv] = result.at(iv).GetCoverRange();
+      Range r = result.at(iv).cover_range(iv->dom);
+      if (relax_set.size() != 0) {
+        r = EvalSet(r, relax_set).cover_range(iv->dom);
+      }
+      (*rmap)[iv] = r;
     }
   }
+  // get range of all child iter vars.
+  PassDown(stage, rmap);
 }
......
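End-to-end, the new inference runs on a normalized schedule: the CHECK above rejects leaf iter vars whose min is not 0, and PassDown now runs after the attachment bounds are resolved. A sketch modeled on the test_schedule2 update below; the compute bodies are illustrative:

```python
m = tvm.Var('m')
A = tvm.placeholder((m,), name='A')
A1 = tvm.compute((m,), lambda i: A[i], name='A1')
A2 = tvm.compute((m,), lambda i: A1[i] + 3, name='A2')

s = tvm.Schedule(A2.op)
xo, xi = s[A2].split(A2.op.axis[0], 8)
s[A1].compute_at(s[A2], xo)
s.normalize()                          # leaf iter vars now start at 0
bounds = tvm.schedule.InferBound(s)    # e.g. A1's root iter var gets extent 8
stmt = tvm.schedule.ScheduleOps(s, bounds)
```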
/*!
 * Copyright (c) 2017 by Contributors
 * \file compute_expr.h
 * \brief Utility integer expressions with quick eager simplification.
 *  This is weaker than Simplify but can be done eagerly.
 */
#ifndef TVM_SCHEDULE_COMPUTE_EXPR_H_
#define TVM_SCHEDULE_COMPUTE_EXPR_H_

#include <tvm/ir.h>
#include <pass/Interval.h>

namespace tvm {
namespace schedule {

using Halide::Internal::add_would_overflow;
using Halide::Internal::sub_would_overflow;
using Halide::Internal::mul_would_overflow;

/*!
 * \brief Compute the expression with the given binary op.
 * \param lhs The left operand
 * \param rhs The right operand
 * \return The result.
 */
template<typename OP>
inline Expr ComputeExpr(Expr lhs, Expr rhs) {
  return OP::make(lhs, rhs);
}

template<typename T>
inline bool GetConst(Expr e, T* out);

template<>
inline bool GetConst<int64_t>(Expr e, int64_t* out) {
  if (e.type().is_vector()) return false;
  const int64_t* v = as_const_int(e);
  if (v) {
    *out = *v; return true;
  } else {
    return false;
  }
}

template<>
inline bool GetConst<uint64_t>(Expr e, uint64_t* out) {
  if (e.type().is_vector()) return false;
  const uint64_t* v = as_const_uint(e);
  if (v) {
    *out = *v; return true;
  } else {
    return false;
  }
}

#define TVM_CONST_PROPAGATION(OP_NAME, OP)                       \
  int64_t ia = 0, ib = 0;                                        \
  if (GetConst(a, &ia) && GetConst(b, &ib)) {                    \
    if (OP_NAME ## _would_overflow(a.type().bits(), ia, ib)) {   \
      LOG(FATAL) << "signed int overflow";                       \
    }                                                            \
    return ir::IntImm::make(a.type(), ia OP ib);                 \
  }                                                              \
  uint64_t ua = 0, ub = 0;                                       \
  if (GetConst(a, &ua) && GetConst(b, &ub)) {                    \
    return ir::UIntImm::make(a.type(), ua OP ub);                \
  }

template<>
inline Expr ComputeExpr<ir::Add>(Expr a, Expr b) {
  if (is_zero(a)) return b;
  if (is_zero(b)) return a;
  TVM_CONST_PROPAGATION(add, +);
  return ir::Add::make(a, b);
}

template<>
inline Expr ComputeExpr<ir::Sub>(Expr a, Expr b) {
  if (is_zero(b)) return a;
  TVM_CONST_PROPAGATION(sub, -);
  return ir::Sub::make(a, b);
}

template<>
inline Expr ComputeExpr<ir::Mul>(Expr a, Expr b) {
  if (is_one(a)) return b;
  if (is_one(b)) return a;
  TVM_CONST_PROPAGATION(mul, *);
  return ir::Mul::make(a, b);
}

template<>
inline Expr ComputeExpr<ir::Div>(Expr a, Expr b) {
  if (is_one(b)) return a;
  return ir::Div::make(a, b);
}

template<>
inline Expr ComputeExpr<ir::Max>(Expr a, Expr b) {
  return Halide::Internal::Interval::make_max(a, b);
}

template<>
inline Expr ComputeExpr<ir::Min>(Expr a, Expr b) {
  return Halide::Internal::Interval::make_min(a, b);
}

}  // namespace schedule
}  // namespace tvm
#endif  // TVM_SCHEDULE_COMPUTE_EXPR_H_
@@ -22,35 +22,48 @@ class IntSet : public NodeRef {
  public:
   /*! \brief constructor */
   IntSet() {}
-  // constructor from not deontainer.
+  // constructor from node container.
   explicit IntSet(std::shared_ptr<Node> n) : NodeRef(n) {}
-  /*! \return whether the set is empty */
-  inline bool is_empty() const {
-    return !defined();
-  }
-  /*!
-   * \return a range that covers the IntSet
-   */
-  Range GetCoverRange() const;
   /*!
    * \brief access the internal node container
    * \return the pointer to the internal node container
    */
   inline const IntSetNode* operator->() const;
   /*!
-   * \param dom The domain to be created.
-   * \return create integer set from existing domain
+   * \brief Find a range that covers the region.
+   * \param max_range The range to be covered.
+   * \return The covering range.
+   */
+  Range cover_range(Range max_range) const;
+  /*!
+   * \brief find an interval that covers the set.
+   * \return The covering interval set.
    */
-  static IntSet make_range(Range dom);
+  IntSet cover_interval() const;
+  /*! \return Whether the set represents everything */
+  bool is_everything() const;
+  /*! \return Whether the set is a single point */
+  bool is_single_point() const;
+  /*! \return The set that contains everything */
+  static IntSet everything();
   /*!
-   * \param point
-   * \return create integer set that only contains one point
+   * \brief construct a point set.
+   * \param point The point in the set.
+   * \return the constructed single point set
    */
-  static IntSet make_point(Expr point);
+  static IntSet single_point(Expr point);
   /*!
-   * \return create integer set that represents everything
+   * \brief Construct a set representing a range.
+   * \param r The range
+   * \return the constructed set.
    */
-  static IntSet make_all_set();
+  static IntSet range(Range r);
+};
+
+/*!
+ * \brief Base class of all IntSet containers.
+ */
+struct IntSetNode : public Node {
 };
 /*!
@@ -63,6 +76,18 @@ class IntSet : public NodeRef {
  */
 IntSet EvalSet(Expr e,
                const Map<IterVar, IntSet>& dom_map);
+/*!
+ * \brief Find a symbolic integer set that contains the union over
+ *  all the possible conditional values in dom_map.
+ *
+ * \param r The initial range.
+ * \param dom_map The domain of each variable.
+ * \return An integer set that can cover all the possible values.
+ */
+IntSet EvalSet(Range r,
+               const Map<IterVar, IntSet>& dom_map);
 /*!
  * \brief Conditional upward message passing.
  *
@@ -99,6 +124,23 @@ void PassUp(const FuseNode* s,
             const IntSet& fused,
             IntSet* outer,
             IntSet* inner);
+/*!
+ * \brief Conditional upward message passing.
+ *
+ *  Get domain of parent, condition on domain of children.
+ *  Domain is represented as IntSet.
+ *
+ * \param s The rebase relation node.
+ * \param dom_map The old domain result from downward message passing.
+ *    Contains the domain set if all the children are full set.
+ * \param rebased The domain of the rebased iteration.
+ * \param parent The result domain of the parent iteration.
+ */
+void PassUp(const RebaseNode* s,
+            const std::unordered_map<IterVar, Range>& dom_map,
+            const IntSet& rebased,
+            IntSet* parent);
 /*!
  * \brief Create an union set of all sets
  * \param sets The sets to be unioned
@@ -106,6 +148,11 @@ void PassUp(const FuseNode* s,
  */
 IntSet Union(const Array<IntSet>& sets);
+
+// implementation
+inline const IntSetNode* IntSet::operator->() const {
+  return static_cast<const IntSetNode*>(node_.get());
+}
 }  // namespace schedule
 }  // namespace tvm
......
@@ -81,7 +81,7 @@ Stage& Stage::compute_at(Stage parent, IterVar scope) {  // NOLINT(*)
     }
   }
   CHECK(found)
-      << "Cannot compute at a iteration variable that is not part of parent leaf vars";
+      << "Cannot find the specified axis in parent stage's leaf_iter_vars";
   return *this;
 }
@@ -165,7 +165,6 @@ Stage& Stage::tile(IterVar x_parent, IterVar y_parent,
   return *this;
 }
 Schedule::Schedule(Array<Operation> ops) {
   auto n = std::make_shared<ScheduleNode>();
   n->roots = ops;
@@ -203,9 +202,53 @@ IterVarRelation FuseNode::make(
   return IterVarRelation(n);
 }
+IterVarRelation RebaseNode::make(IterVar parent, IterVar rebased) {
+  auto n = std::make_shared<RebaseNode>();
+  n->parent = parent;
+  n->rebased = rebased;
+  return IterVarRelation(n);
+}
+
+void Schedule::normalize() {
+  std::unordered_map<IterVar, IterVar> rebase_map;
+  std::unordered_map<const Node*, int> attach_mark;
+
+  for (Stage s : (*this)->stages) {
+    if (s->attach_type == kScope) {
+      attach_mark[s->attach_stage.get()] = 1;
+    }
+  }
+  for (Stage s : (*this)->stages) {
+    if (!attach_mark.count(s.get())) continue;
+    auto root_iter_vars = s->op->root_iter_vars();
+    ArrayNode* leaf_vars = s->leaf_iter_vars.CopyOnWrite();
+    for (IterVar iv : root_iter_vars) {
+      size_t idx = FindIterVar(leaf_vars, iv);
+      if (idx < leaf_vars->data.size()) {
+        // insert rebase
+        IterVar rebased(Range(), iv->var->name_hint + ".rb");
+        s->relations.push_back(RebaseNode::make(iv, rebased));
+        leaf_vars->data[idx] = rebased.node_;
+        rebase_map[iv] = rebased;
+      }
+    }
+  }
+  // remap the parent relation
+  for (Stage s : (*this)->stages) {
+    if (s->attach_type != kScope) continue;
+    if (rebase_map.count(s->attach_ivar)) {
+      s->attach_ivar = rebase_map.at(s->attach_ivar);
+    }
+  }
+}
+
 TVM_REGISTER_NODE_TYPE(StageNode);
 TVM_REGISTER_NODE_TYPE(SplitNode);
 TVM_REGISTER_NODE_TYPE(FuseNode);
+TVM_REGISTER_NODE_TYPE(RebaseNode);
 TVM_REGISTER_NODE_TYPE(ScheduleNode);
 }  // namespace tvm
@@ -18,7 +18,8 @@ def test_add():
     # one line to build the function.
     codes = []
-    fadd = tvm.build(s, args=[A, B, C],
+    fadd = tvm.build(s,
+                     args=[A, B, C],
                      target="cuda", name="myadd",
                      record_codes=codes)
     for c in codes:
......
import tvm
import numpy as np

def test_sum():
    # graph
    n = tvm.Var('n')
    m = tvm.Var('m')
    A = tvm.placeholder((n, m), name='A')
    k = tvm.IterVar((0, m))
    B = tvm.compute((n,), lambda i: tvm.sum(A[i, k], axis=k), name='B')
    # schedule
    s = tvm.Schedule(B.op)
    # create iter var and assign them tags.
    num_thread = 1
    block_x = tvm.IterVar(thread_tag="blockIdx.x")
    thread_x = tvm.IterVar((0, num_thread), thread_tag="threadIdx.x")
    _, x = s[B].split(B.op.axis[0], factor=num_thread, outer=block_x)
    _, x = s[B].split(x, outer=thread_x)

    tvm.init_opencl()
    codes = []
    fsum = tvm.build(s,
                     args=[A, B],
                     target="opencl", name="myadd",
                     record_codes=codes)
    for c in codes:
        print(c)

    num_device = 1
    for i in range(num_device):
        ctx = tvm.opencl(i)
        if not ctx.enabled:
            continue
        # launch the kernel.
        n = 1028
        m = 129
        #a = tvm.nd.array(np.zeros((n, m)).astype(A.dtype), ctx)
        a = tvm.nd.array(np.random.uniform(size=(n, m)).astype(A.dtype), ctx)
        b = tvm.nd.array(np.zeros(n, dtype=B.dtype), ctx)
        fsum(a, b)
        np.testing.assert_allclose(
            b.asnumpy(), np.sum(a.asnumpy(), axis=1), rtol=1e-4)

if __name__ == "__main__":
    test_sum()
@@ -18,8 +18,7 @@ def test_add_pipeline():
     # compile to IR
     bounds = tvm.schedule.InferBound(s)
-    stmt = tvm.ir_pass.ScheduleOps(s, bounds)
+    stmt = tvm.schedule.ScheduleOps(s, bounds)
     Ab = tvm.Buffer(A.shape, A.dtype, name='A')
     Bb = tvm.Buffer(B.shape, B.dtype, name='B')
     Cb = tvm.Buffer(C.shape, C.dtype, name='C')
......
@@ -10,12 +10,13 @@ def test_makeapi():
     s = tvm.Schedule(C.op)
     bounds = tvm.schedule.InferBound(s)
-    stmt = tvm.ir_pass.ScheduleOps(s, bounds)
+    stmt = tvm.schedule.ScheduleOps(s, bounds)
     Ab = tvm.Buffer(A.shape, A.dtype, name='A')
     Bb = tvm.Buffer(B.shape, B.dtype, name='B')
     Cb = tvm.Buffer(C.shape, C.dtype, name='C')
     stmt = tvm.ir_pass.StorageFlatten(stmt, {A: Ab, B: Bb, C: Cb})
     num_packed_args = 2
     f = tvm.codegen.MakeAPI(stmt, "myadd", [n, Ab, Bb, Cb], num_packed_args)
     assert(f.handle_data_type[Ab.data].dtype == Ab.dtype)
......
@@ -26,7 +26,7 @@ def test_tensor_reduce():
     B = tvm.placeholder((n, l), name='B')
     T = tvm.compute((m, n, l), lambda i, j, k: A[i, k] * B[j, k])
     rv = tvm.IterVar((0, A.shape[1]), name="k")
-    C = tvm.compute((m, n), lambda i, j: tvm.sum(T(i, j, rv+1), rdom=rv))
+    C = tvm.compute((m, n), lambda i, j: tvm.sum(T(i, j, rv+1), axis=rv))
     # json load save
     C_json = tvm.save_json(C)
     C_loaded = tvm.load_json(C_json)
......
@@ -12,7 +12,7 @@ def test_flatten2():
     s[A1].compute_at(s[A2], xo)
     bounds = tvm.schedule.InferBound(s)
     assert isinstance(bounds, tvm.collections.Map)
-    stmt = tvm.ir_pass.ScheduleOps(s, bounds)
+    stmt = tvm.schedule.ScheduleOps(s, bounds)
     print(stmt)
     Ab = tvm.Buffer(A.shape, A.dtype, name='A')
......
@@ -11,7 +11,7 @@ def test_schedule0():
     bounds = tvm.schedule.InferBound(s)
     assert isinstance(bounds, tvm.collections.Map)
-    stmt = tvm.ir_pass.ScheduleOps(s, bounds)
+    stmt = tvm.schedule.ScheduleOps(s, bounds)
     print(stmt)
 def test_schedule1():
@@ -24,7 +24,7 @@ def test_schedule1():
     xo, xi = s[A1].split(A1.op.axis[0], 8)
     bounds = tvm.schedule.InferBound(s)
     assert isinstance(bounds, tvm.collections.Map)
-    stmt = tvm.ir_pass.ScheduleOps(s, bounds)
+    stmt = tvm.schedule.ScheduleOps(s, bounds)
     print(stmt)
 def test_schedule2():
@@ -39,7 +39,7 @@ def test_schedule2():
     s[A1].compute_at(s[A2], xo)
     bounds = tvm.schedule.InferBound(s)
     assert isinstance(bounds, tvm.collections.Map)
-    stmt = tvm.ir_pass.ScheduleOps(s, bounds)
+    stmt = tvm.schedule.ScheduleOps(s, bounds)
     print(stmt)
......