Commit a2c8a29b by Tianqi Chen Committed by GitHub

[SCHEDULE] Improve bound inference, support reduce codegen. (#30)

parent d4af7ad6
......@@ -32,6 +32,9 @@ using Halide::Internal::IRPrinter;
using Halide::Internal::Variable;
using Halide::Internal::make_const;
using Halide::Internal::make_zero;
using Halide::Internal::as_const_int;
using Halide::Internal::as_const_uint;
inline Type TVMType2Type(TVMType t) {
......@@ -126,25 +129,25 @@ using Halide::abs;
using Halide::select;
/*!
* \brief sum of of source expression over rdom
* \brief sum of of source expression over axis
* \param source The source expression.
* \param rdom List of iteration variables that will be used for reduction.
* \param axis List of iteration variables that will be used for reduction.
*/
Expr sum(Expr source, Array<IterVar> rdom);
Expr sum(Expr source, Array<IterVar> axis);
/*!
* \brief max of of source expression over rdom
* \brief max of of source expression over axis
* \param source The source expression.
* \param rdom List of iteration variables that will be used for reduction.
* \param axis List of iteration variables that will be used for reduction.
*/
Expr max(Expr source, Array<IterVar> rdom);
Expr max(Expr source, Array<IterVar> axis);
/*!
* \brief max of of source expression over rdom
* \brief max of of source expression over axis
* \param source The source expression.
* \param rdom List of iteration variables that will be used for reduction.
* \param axis List of iteration variables that will be used for reduction.
*/
Expr min(Expr source, Array<IterVar> rdom);
Expr min(Expr source, Array<IterVar> axis);
// print functions for expr
......
......@@ -30,8 +30,8 @@ struct Reduce : public ExprNode<Reduce> {
std::string op;
/*! \brief The source operand */
Expr source;
/*! \brief The reduction domains */
Array<IterVar> rdom;
/*! \brief The reduction axis */
Array<IterVar> axis;
/*! \brief construct expr from op and rdom */
static Expr make(std::string op, Expr src, Array<IterVar> rdom);
......@@ -40,7 +40,7 @@ struct Reduce : public ExprNode<Reduce> {
v->Visit("dtype", &type);
v->Visit("op", &op);
v->Visit("source", &source);
v->Visit("rdom", &rdom);
v->Visit("axis", &axis);
}
static const IRNodeType _type_info = IRNodeType::ExtensionExpr;
static constexpr const char* _type_key = "Reduce";
......
......@@ -3,8 +3,8 @@
* \file ir_pass.h
* \brief Collection of IR pass functions
*
* All the pass functions in this file are for Stmt,
* We can use PassFunction(Evaluate(expr)) to apply it to Expr
* When the pass functions in this file are for Stmt,
* we can use PassFunction(Evaluate(expr)) to apply it to Expr
*/
#ifndef TVM_IR_PASS_H_
#define TVM_IR_PASS_H_
......@@ -38,15 +38,6 @@ inline Stmt Simplify(Stmt a) {
}
/*!
* \brief Schedule s' dependent operations.
*
* \param s The schedule to be realized
* \param dom_map The domain of each iter vars.
* \return the result Stmt
*/
Stmt ScheduleOps(Schedule s, Map<IterVar, Range> dom_map);
/*!
* \brief verifies whether the IR stmt or Expr is in SSA form.
* That is: each VarExpr is defined and assigned once(in Let/For)
*
......@@ -70,6 +61,14 @@ bool HasSideEffect(const Expr& e);
Stmt ConvertSSA(Stmt stmt);
/*!
* \brief Substitute the var specified in key->var to be value.
* \param stmt The source statement to be substituted
* \param value_map The map of new values.
* \return The converted form.
*/
Stmt Substitute(Stmt stmt, const Map<IterVar, Expr>& value_map);
/*!
* \brief inline all calls of f in stmt.
*
* \param f The function reference to be inlined
......
......@@ -49,6 +49,8 @@ class ComputeOpNode : public OperationNode {
public:
/*! \brief IterVar on each axis */
Array<IterVar> axis;
/*! \brief IterVar on each reduction axis, if the body is a Reduce */
Array<IterVar> reduce_axis;
/*! \brief the compute expression */
Expr body;
/*! \brief constructor */
......@@ -64,6 +66,7 @@ class ComputeOpNode : public OperationNode {
void VisitAttrs(AttrVisitor* v) final {
v->Visit("name", &name);
v->Visit("axis", &axis);
v->Visit("reduce_axis", &reduce_axis);
v->Visit("body", &body);
}
static Operation make(std::string name,
......
......@@ -123,6 +123,8 @@ class Stage : public NodeRef {
IterVar* p_x_outer, IterVar* p_y_outer,
IterVar* p_x_inner, IterVar* p_y_inner,
Expr x_factor, Expr y_factor);
// declare container type
using ContainerType = StageNode;
};
/*!
......@@ -153,10 +155,21 @@ class Schedule : public NodeRef {
return this->operator[](tensor->op);
}
/*!
* \brief Normalize the schedule.
* This is needed before bound inference.
* Insert necessary RebaseNode to make sure all leaf_iter_vars
* are in form [0, extent)
*
* \return A normalized schedule, can be same as current one.
*/
void normalize();
/*!
* \brief access the internal node container
* \return the pointer to the internal node container
*/
inline const ScheduleNode* operator->() const;
// declare container type
using ContainerType = ScheduleNode;
};
/*!
......@@ -308,6 +321,30 @@ class FuseNode : public IterVarRelationNode {
TVM_DECLARE_NODE_TYPE_INFO(FuseNode);
};
/*!
* \brief Rebase the iteration to make min to be 0.
* This is useful to normalize the Schedule
* to make every leaf variable's min to be 0.
*/
class RebaseNode : public IterVarRelationNode {
public:
/*! \brief The parent domain */
IterVar parent;
/*! \brief The inner domain */
IterVar rebased;
void VisitAttrs(AttrVisitor* v) final {
v->Visit("parent", &parent);
v->Visit("rebased", &rebased);
}
static IterVarRelation make(IterVar parent, IterVar rebased);
static constexpr const char* _type_key = "Rebase";
TVM_DECLARE_NODE_TYPE_INFO(RebaseNode);
};
// implementations
inline const StageNode* Stage::operator->() const {
return static_cast<const StageNode*>(node_.get());
......
......@@ -24,6 +24,15 @@ namespace schedule {
*/
Map<IterVar, Range> InferBound(Schedule sch);
/*!
* \brief Schedule s' dependent operations.
*
* \param s The schedule to be realized
* \param dom_map The domain of each iter vars.
* \return the result Stmt
*/
Stmt ScheduleOps(Schedule s, Map<IterVar, Range> dom_map);
} // namespace schedule
} // namespace tvm
#endif // TVM_SCHEDULE_PASS_H_
......@@ -212,51 +212,51 @@ def IterVar(dom=None, name=None, thread_tag=''):
return _api_internal._IterVar(dom, name, thread_tag)
def sum(expr, rdom):
"""Create a sum expression over rdom
def sum(expr, axis):
"""Create a sum expression over axis
Parameters
----------
expr : Expr
The source expression.
rdom : RDomain
The reduction domainx
axis : IterVar
The reduction IterVar axis
"""
rdom = rdom if isinstance(rdom, list) else [rdom]
x = _make.Reduce("Add", expr, rdom)
axis = axis if isinstance(axis, list) else [axis]
x = _make.Reduce("Add", expr, axis)
return x
def min(expr, rdom):
"""Create a min expression over rdom
def min(expr, axis):
"""Create a min expression over axis
Parameters
----------
expr : Expr
The source expression.
rdom : RDomain
The reduction domainx
axis : IterVar
The reduction IterVar axis
"""
rdom = rdom if isinstance(rdom, list) else [rdom]
x = _make.Reduce("Min", expr, rdom)
axis = axis if isinstance(axis, list) else [axis]
x = _make.Reduce("Min", expr, axis)
return x
def max(expr, rdom):
"""Create a min expression over rdom
def max(expr, axis):
"""Create a min expression over axis
Parameters
----------
expr : Expr
The source expression.
rdom : RDomain
The reduction domainx
axis : IterVar
The reduction IterVar axis
"""
rdom = rdom if isinstance(rdom, list) else [rdom]
x = _make.Reduce("Max", expr, rdom)
axis = axis if isinstance(axis, list) else [axis]
x = _make.Reduce("Max", expr, axis)
return x
......
......@@ -62,9 +62,10 @@ def build(sch,
# lowering
bounds = schedule.InferBound(sch)
stmt = ir_pass.ScheduleOps(sch, bounds)
stmt = schedule.ScheduleOps(sch, bounds)
stmt = ir_pass.StorageFlatten(stmt, binds)
stmt = ir_pass.Simplify(stmt)
print(stmt)
fapi = codegen.MakeAPI(stmt, name, arg_list, len(arg_list))
fsplits = codegen.SplitHostDevice(fapi)
......@@ -73,7 +74,8 @@ def build(sch,
for i, f in enumerate(fsplits):
t = target if i >= 1 else "c"
record_codes.append(codegen.CompileToC(f, output_ssa, t))
for c in record_codes:
print(c)
if target == "cuda":
ret = codegen.BuildNVRTC(fsplits, "stackvm")
elif target == "opencl":
......
......@@ -33,6 +33,14 @@ class Schedule(NodeBase):
raise ValueError("Cannot find the operation %s in schedule" % (str(k)))
return self.stage_map[k]
def normalize(self):
"""Build a normalized schedule.
Insert necessary rebase to make certain iter var to start from 0.
This is needed before bound inference and followup step.
"""
_api_internal._ScheduleNormalize(self)
@register_node
class Stage(NodeBase):
"""A Stage represents schedule for one operation."""
......
......@@ -253,4 +253,10 @@ TVM_REGISTER_API(_StageTile)
*ret = Array<IterVar>({x_outer, y_outer, x_inner, y_inner});
});
TVM_REGISTER_API(_ScheduleNormalize)
.set_body([](TVMArgs args, TVMRetValue* ret) {
args[0].operator Schedule()
.normalize();
});
} // namespace tvm
......@@ -51,7 +51,6 @@ TVM_REGISTER_API(_pass_Equal)
REGISTER_PASS1(ConvertSSA);
REGISTER_PASS1(VerifySSA);
REGISTER_PASS4(Inline);
REGISTER_PASS2(ScheduleOps);
REGISTER_PASS2(StorageFlatten);
} // namespace ir
......
......@@ -29,6 +29,7 @@ namespace schedule {
REGISTER_SCHEDULE_PASS1(InferBound);
REGISTER_SCHEDULE_PASS1(CreateReadGraph);
REGISTER_SCHEDULE_PASS2(PostDFSOrder);
REGISTER_SCHEDULE_PASS2(ScheduleOps);
} // namespace schedule
} // namespace tvm
......@@ -2,6 +2,7 @@
* Copyright (c) 2017 by Contributors
* \file codegen_c.cc
*/
#include <iomanip>
#include "./codegen_c.h"
namespace tvm {
......@@ -216,7 +217,7 @@ inline void PrintConst(const FloatImm* op, std::ostream& os, CodeGenC* p) { // N
switch (op->type.bits()) {
case 64: case 32: {
std::ostringstream temp;
temp << op->value;
temp << std::scientific << op->value;
if (op->type.bits() == 32) temp << 'f';
p->MarkConst(temp.str());
os << temp.str();
......@@ -225,7 +226,7 @@ inline void PrintConst(const FloatImm* op, std::ostream& os, CodeGenC* p) { // N
case 16: {
os << '(';
p->PrintType(op->type, os);
os << ')' << op->value << 'f';
os << ')' << std::scientific <<op->value << 'f';
break;
}
default: LOG(FATAL) << "Bad bit-width for float: " << op->type << "\n";
......
......@@ -26,7 +26,7 @@ TVM_STATIC_IR_FUNCTOR(IRPrinter, vtable)
<< op->op
<< ", ";
p->print(op->source);
p->stream << ", rdom=" << op->rdom << ")";
p->stream << ", axis=" << op->axis << ")";
});
} // namespace Internal
......@@ -35,16 +35,16 @@ TVM_STATIC_IR_FUNCTOR(IRPrinter, vtable)
namespace tvm {
namespace ir {
Expr Reduce::make(std::string op, Expr source, Array<IterVar> rdom) {
Expr Reduce::make(std::string op, Expr source, Array<IterVar> axis) {
auto n = std::make_shared<Reduce>();
CHECK(source.defined());
for (size_t i = 0; i < rdom.size(); ++i) {
CHECK(rdom[i].defined());
for (size_t i = 0; i < axis.size(); ++i) {
CHECK(axis[i].defined());
}
n->type = source.type();
n->source = source;
n->op = op;
n->rdom = rdom;
n->axis = axis;
return Expr(n);
}
......
......@@ -4,6 +4,7 @@
*/
#include <tvm/operation.h>
#include <tvm/tensor.h>
#include <tvm/ir.h>
#include <memory>
namespace tvm {
......@@ -57,7 +58,12 @@ Tensor Placeholder(Array<Expr> shape, Type dtype, std::string name) {
// ComputeOpNode
Array<IterVar> ComputeOpNode::root_iter_vars() const {
return axis;
if (reduce_axis.size() == 0) return axis;
Array<IterVar> ret = axis;
for (IterVar iv : reduce_axis) {
ret.push_back(iv);
}
return ret;
}
Type ComputeOpNode::output_dtype(size_t i) const {
......@@ -101,6 +107,9 @@ Operation ComputeOpNode::make(std::string name,
n->name = name;
n->axis = axis;
n->body = body;
if (n->body->is_type<ir::Reduce>()) {
n->reduce_axis = n->body.as<ir::Reduce>()->axis;
}
return Operation(n);
}
......
......@@ -37,7 +37,7 @@ inline Array<Expr> MutateArray(Array<Expr> arr, IRMutator *m) {
}
}
inline Array<IterVar> MutateRDom(Array<IterVar> rdom, IRMutator *m) {
inline Array<IterVar> MutateIterVarArr(Array<IterVar> rdom, IRMutator *m) {
std::vector<IterVar> new_dom(rdom.size());
bool changed = false;
for (size_t i = 0; i < rdom.size(); i++) {
......@@ -237,13 +237,13 @@ Expr IRMutator::Mutate_(const Let *op, const Expr& e) {
TVM_STATIC_IR_FUNCTOR(IRMutator, vtable_expr)
.set_dispatch<Reduce>([](const Reduce* op, const Expr& e, IRMutator* m) {
Array<IterVar> new_rdom = MutateRDom(op->rdom, m);
Array<IterVar> new_axis = MutateIterVarArr(op->axis, m);
Expr new_source = m->Mutate(op->source);
if (op->rdom.same_as(new_rdom) &&
if (op->axis.same_as(new_axis) &&
op->source.same_as(new_source)) {
return e;
} else {
return Reduce::make(op->op, new_source, new_rdom);
return Reduce::make(op->op, new_source, new_axis);
}
});
......
......@@ -120,7 +120,7 @@ void IRVisitor::Visit_(const Call *op) {
TVM_STATIC_IR_FUNCTOR(IRVisitor, vtable)
.set_dispatch<Reduce>([](const Reduce* op, IRVisitor* v) {
VisitRDom(op->rdom, v);
VisitRDom(op->axis, v);
v->Visit(op->source);
})
.set_dispatch<IntImm>(NoOp)
......
......@@ -5,6 +5,7 @@
*/
#include <tvm/ir.h>
#include <tvm/ir_visitor.h>
#include <tvm/ir_mutator.h>
#include <tvm/ir_pass.h>
namespace tvm {
......@@ -32,5 +33,26 @@ bool HasSideEffect(const Expr& e) {
v.Visit(e);
return v.has_side_effect_;
}
class IRSubstitue : public IRMutator {
public:
Expr Mutate_(const Variable* op, const Expr& e) final {
auto it = smap.find(op);
if (it != smap.end()) {
return it->second;
} else {
return e;
}
}
std::unordered_map<const Variable*, Expr> smap;
};
Stmt Substitute(Stmt stmt, const Map<IterVar, Expr>& value_map) {
IRSubstitue m;
for (auto kv : value_map) {
m.smap[kv.first->var.get()] = kv.second;
}
return m.Mutate(stmt);
}
} // namespace ir
} // namespace tvm
......@@ -54,6 +54,11 @@ void PassDown(const Stage& s,
const Range& range_inner = state.at(r->inner);
state[r->fused] = Range::make_with_min_extent(
0, range_outer->extent * range_inner->extent);
} else if (rel.as<RebaseNode>()) {
const RebaseNode* r = rel.as<RebaseNode>();
CHECK(state.count(r->parent));
state[r->rebased] = Range::make_with_min_extent(
0, state.at(r->parent)->extent);
} else {
LOG(FATAL) << "unknown relation type";
}
......@@ -85,6 +90,13 @@ void PassUp(const Stage& s,
&outer, &inner);
state[r->outer] = outer;
state[r->inner] = inner;
} else if (rel.as<RebaseNode>()) {
IntSet parent;
const RebaseNode* r = rel.as<RebaseNode>();
PassUp(r, dom_map,
state.at(r->rebased),
&parent);
state[r->parent] = parent;
} else {
LOG(FATAL) << "unknown relation type";
}
......@@ -109,9 +121,15 @@ void PassToOperation(
// Eventually, we need to change the inference to be a Pull style inference
if (tensor->op.as<ComputeOpNode>()) {
auto root_iter_vars = tensor->op->root_iter_vars();
CHECK_EQ(tensor.ndim(), root_iter_vars.size());
for (size_t i = 0; i < tensor.ndim(); ++i) {
(*result)[root_iter_vars[i]].push_back(dim_bounds[i]);
const ComputeOpNode* op = tensor->op.as<ComputeOpNode>();
CHECK_EQ(op->axis.size() + op->reduce_axis.size(), root_iter_vars.size());
for (size_t i = 0; i < op->axis.size(); ++i) {
(*result)[op->axis[i]].push_back(dim_bounds[i]);
}
// reduction.
for (size_t i = 0; i < op->reduce_axis.size(); ++i) {
(*result)[op->reduce_axis[i]].push_back(
IntSet::range(op->reduce_axis[i]->dom));
}
} else {
LOG(FATAL) << "unknown operation mode " << tensor->op->type_key();
......@@ -173,9 +191,9 @@ bool ScopeRelax(const IterVar& iv, const std::string& scope) {
{"local", 2}
};
static std::unordered_map<std::string, int> thread_tag_rank{
{"gridIdx.x", 0},
{"gridIdx.y", 0},
{"gridIdx.z", 0},
{"blockIdx.x", 0},
{"blockIdx.y", 0},
{"blockIdx.z", 0},
{"threadIdx.x", 1},
{"threadIdx.y", 1},
{"threadIdx.z", 1}
......@@ -194,8 +212,6 @@ void InferBound(const Stage& stage,
(*rmap)[iv] = iv->dom;
}
}
// get range of all child iter vars.
PassDown(stage, rmap);
if (stage->attach_type == kScope) {
Stage parent = stage->attach_stage;
......@@ -206,10 +222,18 @@ void InferBound(const Stage& stage,
bool fix_value = true;
for (auto iv : parent->leaf_iter_vars) {
Range vrange = rmap->at(iv);
CHECK(is_zero(vrange->min))
<< "InferBound requires every leaf iter var's min equals 0, "
<< "call schedule.normalize to achieve this.";
// special optimization to remove trivial loop
if (is_one(vrange->extent)) {
up_state[iv] = IntSet::single_point(vrange->min);
}
if (fix_value && !ScopeRelax(iv, stage->scope)) {
up_state[iv] = IntSet::make_point(iv->var);
up_state[iv] = IntSet::single_point(iv->var);
} else {
up_state[iv] = IntSet::make_range(rmap->at(iv));
up_state[iv] = IntSet::range(vrange);
}
if (stage->attach_ivar == iv) {
fix_value = false;
......@@ -223,12 +247,30 @@ void InferBound(const Stage& stage,
bp_state[iv] = {up_state.at(iv)};
}
auto result = BoundProp(post_order, &bp_state);
// Set relaxation
Map<IterVar, IntSet> relax_set;
Stage s = stage;
while (s->attach_type == kScope) {
s = s->attach_stage;
for (auto iv : s->leaf_iter_vars) {
if (ScopeRelax(iv, stage->scope)) {
relax_set.Set(iv, IntSet::range(rmap->at(iv)));
}
}
}
for (auto iv : stage->op->root_iter_vars()) {
CHECK(result.count(iv));
CHECK(!rmap->count(iv));
(*rmap)[iv] = result.at(iv).GetCoverRange();
Range r = result.at(iv).cover_range(iv->dom);
if (relax_set.size() != 0) {
r = EvalSet(r, relax_set).cover_range(iv->dom);
}
(*rmap)[iv] = r;
}
}
// get range of all child iter vars.
PassDown(stage, rmap);
}
......
/*!
* Copyright (c) 2017 by Contributors
* \file compute_expr.h
* \brief Utility integer expression with quick eager simplification.
* This is weaker than Simplify but can be done Eagerly.
*/
#ifndef TVM_SCHEDULE_COMPUTE_EXPR_H_
#define TVM_SCHEDULE_COMPUTE_EXPR_H_
#include <tvm/ir.h>
#include <pass/Interval.h>
namespace tvm {
namespace schedule {
using Halide::Internal::add_would_overflow;
using Halide::Internal::sub_would_overflow;
using Halide::Internal::mul_would_overflow;
/*!
* \brief Compute the expression with the given binary op.
* \param lhs The left operand
* \param rhs The right operand
* \return The result.
*/
template<typename OP>
inline Expr ComputeExpr(Expr lhs, Expr rhs) {
return OP::make(lhs, rhs);
}
template<typename T>
inline bool GetConst(Expr e, T* out);
template<>
bool GetConst<int64_t>(Expr e, int64_t *out) {
if (e.type().is_vector()) return false;
const int64_t *v = as_const_int(e);
if (v) {
*out = *v; return true;
} else {
return false;
}
}
template<>
bool GetConst<uint64_t>(Expr e, uint64_t *out) {
if (e.type().is_vector()) return false;
const uint64_t *v = as_const_uint(e);
if (v) {
*out = *v; return true;
} else {
return false;
}
}
#define TVM_CONST_PROPAGATION(OP_NAME, OP) \
int64_t ia = 0, ib = 0; \
if (GetConst(a, &ia) && GetConst(b, &ib)) { \
if (OP_NAME ## _would_overflow(a.type().bits(), ia, ib)) { \
LOG(FATAL) << "signed int overflow"; \
} \
return ir::IntImm::make(a.type(), ia OP ib); \
} \
uint64_t ua = 0, ub = 0; \
if (GetConst(a, &ua) && GetConst(b, &ub)) { \
return ir::UIntImm::make(a.type(), ua + ub); \
} \
template<>
inline Expr ComputeExpr<ir::Add>(Expr a, Expr b) {
if (is_zero(a)) return b;
if (is_zero(b)) return a;
TVM_CONST_PROPAGATION(add, +);
return ir::Add::make(a, b);
}
template<>
inline Expr ComputeExpr<ir::Sub>(Expr a, Expr b) {
if (is_zero(b)) return a;
TVM_CONST_PROPAGATION(sub, -);
return ir::Add::make(a, b);
}
template<>
inline Expr ComputeExpr<ir::Mul>(Expr a, Expr b) {
if (is_one(a)) return b;
if (is_one(b)) return a;
TVM_CONST_PROPAGATION(mul, *);
return ir::Mul::make(a, b);
}
template<>
inline Expr ComputeExpr<ir::Div>(Expr a, Expr b) {
if (is_one(b)) return a;
return ir::Mul::make(a, b);
}
template<>
inline Expr ComputeExpr<ir::Max>(Expr a, Expr b) {
return Halide::Internal::Interval::make_max(a, b);
}
template<>
inline Expr ComputeExpr<ir::Min>(Expr a, Expr b) {
return Halide::Internal::Interval::make_min(a, b);
}
} // namespace schedule
} // namespace tvm
#endif // TVM_SCHEDULE_COMPUTE_EXPR_H_
/*!
* Copyright (c) 2016 by Contributors
* \file int_set.cc
* \file int_set_impl.cc
* \brief The integer set functions
*/
#include <tvm/ir.h>
#include <tvm/ir_pass.h>
#include <pass/Interval.h>
#include "./int_set.h"
#include "./compute_expr.h"
namespace tvm {
namespace schedule {
using Halide::Internal::Interval;
using namespace ir;
/*! \brief Set of continuous interval */
struct IntervalSet : public IntSetNode {
/*! \brief the internal interval*/
Interval i;
static IntSet make(Interval i) {
std::shared_ptr<IntervalSet> n =
std::make_shared<IntervalSet>();
n->i = i;
return IntSet(n);
}
static IntSet make(Expr min, Expr max) {
std::shared_ptr<IntervalSet> n =
std::make_shared<IntervalSet>();
n->i.min = min;
n->i.max = max;
return IntSet(n);
}
static constexpr const char* _type_key = "IntervalSet";
TVM_DECLARE_NODE_TYPE_INFO(IntervalSet);
};
/*!
* \brief Internal node container of int set.
* \brief set represented by strided integers
* Reserved for cases where strided access is supported.
*/
class IntSetNode : public Node {
public:
/*! \brief The base range scope */
Range base;
/*! \brief additional strided domain */
Array<Range> domain;
/*! \brief The stride of each strided domain */
Array<Expr> stride;
/*!
* \brief The concrete set,
* used when concrete execution is enabled.
*/
std::vector<int32_t> concrete;
void VisitAttrs(AttrVisitor* v) final {
v->Visit("base", &base);
v->Visit("domain", &domain);
v->Visit("stride", &stride);
}
static constexpr const char* _type_key = "IntSet";
TVM_DECLARE_NODE_TYPE_INFO(IntSetNode);
struct StrideSet : public IntSetNode {
/*! \brief the base inetrval */
Interval base;
/*! \brief additional extents in positive number */
Array<Expr> extents;
/*! \brief additional strides in positive number */
Array<Expr> strides;
static constexpr const char* _type_key = "StrideSet";
TVM_DECLARE_NODE_TYPE_INFO(StrideSet);
};
TVM_REGISTER_NODE_TYPE(IntSetNode);
inline IntSet IntSet::cover_interval() const {
if ((*this).as<IntervalSet>()) return *this;
const StrideSet* s = (*this).as<StrideSet>();
if (s) {
CHECK_NE(s->extents.size(), 0U);
Expr max = s->base.max;
for (size_t i = 0; i < s->extents.size(); ++i) {
max = max + s->extents[i] * s->strides[i] - s->strides[i];
}
return IntervalSet::make(s->base.min, max);
}
LOG(FATAL) << "cannot convert set " << (*this)->type_key() << " to interval";
return IntSet::everything();
}
Range IntSet::cover_range(Range max_range) const {
IntSet temp;
const IntervalSet* s_int = (*this).as<IntervalSet>();
if (s_int == nullptr) {
temp = this->cover_interval();
s_int = temp.as<IntervalSet>();
}
if (s_int->i.is_bounded()) {
return Range::make_with_min_extent(
s_int->i.min, Simplify(s_int->i.max + 1 - s_int->i.min));
}
return max_range;
}
namespace {
bool IntSet::is_everything() const {
const IntervalSet* s_int = (*this).as<IntervalSet>();
return (s_int && s_int->i.is_everything());
}
inline bool Match(const Expr& e, int64_t value) {
const ir::IntImm* v = e.as<ir::IntImm>();
return v != nullptr && v->value;
bool IntSet::is_single_point() const {
const IntervalSet* s_int = (*this).as<IntervalSet>();
return (s_int && s_int->i.is_single_point());
}
// whether a exactly matches b.
inline bool Match(const IntSet& a,
const Range& b) {
if (a->base == b &&
a->domain.size() == 0 &&
a->concrete.size() == 0) {
return true;
} else {
return false;
}
IntSet IntSet::everything() {
return IntervalSet::make(Interval::everything());
}
// whether a exactly matches b.
inline bool Match(const IntSet& a,
const Expr& b) {
if (a->domain.size() == 0 &&
a->concrete.size() == 0) {
return Match(a->base->extent, 1) && a->base->min.same_as(b);
} else {
return false;
}
IntSet IntSet::single_point(Expr x) {
return IntervalSet::make(Interval::single_point(x));
}
inline bool IsNumber(const IntSet& s) {
if (s->domain.size() != 0) return false;
if (s->concrete.size() != 0) {
return s->concrete.size() == 1;
IntSet IntSet::range(Range r) {
// must make sure it can be matched back by MatchRange.
if (is_one(r->extent)) {
return IntSet::single_point(r->min);
}
if (is_positive_const(r->extent) && is_const(r->min)) {
return IntervalSet::make(
r->min, ComputeExpr<Sub>(ComputeExpr<Add>(r->extent, r->min), 1));
}
return Match(s->base->extent, 1);
return IntervalSet::make(r->min, (r->extent + r->min) - 1);
}
inline Expr AsNumber(const IntSet& s) {
return s->base->min;
// Check if a is created from b.
inline bool MatchRange(const IntSet& a,
const Range& b) {
const IntervalSet* a_int = a.as<IntervalSet>();
if (!a_int) return false;
const Interval& i = a_int->i;
if (!i.min.same_as(b)) return false;
if (is_one(b->extent)) return i.is_single_point();
if (is_positive_const(b->extent) && is_const(b->min)) {
// deep equality
return Equal(
ComputeExpr<Sub>(ComputeExpr<Add>(b->extent, b->min), 1),
a_int->i.max);
}
const Sub* sub = i.max.as<Sub>();
if (!sub) return false;
if (is_one(sub->b)) return false;
const Add* add = sub->a.as<Add>();
return add &&
add->a.same_as(b->min) &&
add->b.same_as(b->extent);
}
// set combination rule by operators
template<typename T>
inline IntSet BinaryCombine(IntSet a, IntSet b) {
LOG(WARNING) << "cannot evaluate binary op " << T::_type_key;
return IntSet::make_all_set();
inline bool MatchPoint(const IntSet& a,
const Expr& b) {
const IntervalSet* a_int = a.as<IntervalSet>();
if (!a_int) return false;
const Interval& i = a_int->i;
return i.is_single_point() && i.min.same_as(b);
}
template<>
inline IntSet BinaryCombine<Add>(IntSet a, IntSet b) {
auto n = std::make_shared<IntSetNode>(*(a.operator->()));
for (size_t i = 0; i < b->domain.size(); ++i) {
n->domain.push_back(b->domain[i]);
n->stride.push_back(b->stride[i]);
}
if (IsNumber(a)) {
n->base = Range::make_with_min_extent(
a->base->min + b->base->min,
b->base->extent);
} else if (IsNumber(b)) {
n->base = Range::make_with_min_extent(
a->base->min + b->base->min,
a->base->extent);
} else {
n->base = Range::make_with_min_extent(
a->base->min + b->base->min,
a->base->extent + b->base->extent - 1);
IntSet Union(const Array<IntSet>& set) {
if (set.size() == 1) return set[0];
Interval x = set[0].cover_interval().as<IntervalSet>()->i;
for (size_t i = 1; i < set.size(); ++i) {
x.include(set[i].cover_interval().as<IntervalSet>()->i);
}
return IntSet(n);
return IntervalSet::make(x);
}
inline Range Negation(Range a) {
if (Match(a->extent, 1)) {
return Range::make_with_min_extent(-a->min, a->extent);
} else {
return Range::make_with_min_extent(-(a->min + a->extent - 1), a->extent);
// type traits
template<typename OP>
struct is_logical_op {
static const bool value = false;
};
#define TVM_DECLARE_LOGICAL_OP(OP) \
template<> \
struct is_logical_op<ir::OP> { \
static const bool value = true; \
};
// interval related.
template<typename OP>
inline IntSet CombineInterval(Interval a, Interval b) {
if (a.is_single_point() && b.is_single_point()) {
return IntSet::single_point(ComputeExpr<OP>(a.min, b.min));
}
LOG(WARNING) << "Return Everything in CombineInterval " << OP::_type_key;
return IntSet::everything();
}
inline IntSet Negation(IntSet a) {
CHECK_EQ(a->concrete.size(), 0U);
auto n = std::make_shared<IntSetNode>();
n->base = Negation(a->base);
for (size_t i = 0; i < a->domain.size(); ++i) {
n->domain.push_back(Negation(a->domain[i]));
n->stride.push_back(a->stride[i]);
template<>
inline IntSet CombineInterval<Add>(Interval a, Interval b) {
if (a.is_single_point() && b.is_single_point()) {
return IntSet::single_point(ComputeExpr<Add>(a.min, b.min));
}
Interval r = Interval::everything();
if (a.has_lower_bound() && b.has_lower_bound()) {
r.min = ComputeExpr<Add>(a.min, b.min);
}
return IntSet(a);
if (a.has_upper_bound() && b.has_upper_bound()) {
r.max = ComputeExpr<Add>(a.max, b.max);
}
return IntervalSet::make(r);
}
template<>
inline IntSet BinaryCombine<Sub>(IntSet a, IntSet b) {
return BinaryCombine<Add>(a, Negation(b));
inline IntSet CombineInterval<Sub>(Interval a, Interval b) {
if (a.is_single_point() && b.is_single_point()) {
return IntSet::single_point(ComputeExpr<Sub>(a.min, b.min));
}
Interval r = Interval::everything();
if (a.has_lower_bound() && b.has_upper_bound()) {
r.min = ComputeExpr<Sub>(a.min, b.max);
}
if (a.has_upper_bound() && b.has_lower_bound()) {
r.max = ComputeExpr<Sub>(a.max, b.min);
}
return IntervalSet::make(r);
}
inline IntSet BinaryMul(IntSet a, Expr b) {
// copy construct
if (Match(b, 1)) return a;
if (Match(b, -1)) return Negation(a);
auto n = std::make_shared<IntSetNode>();
n->base = Range::make_with_min_extent(0, 1);
n->domain.push_back(a->base);
n->stride.push_back(b);
for (size_t i = 0; i < a->domain.size(); ++i) {
n->domain.push_back(a->domain[i]);
n->stride.push_back(a->stride[i] * b);
}
return IntSet(a);
template<>
inline IntSet CombineInterval<Mul>(Interval a, Interval b) {
if (a.is_single_point() && b.is_single_point()) {
return IntSet::single_point(ComputeExpr<Mul>(a.min, b.min));
}
if (a.is_single_point() && !b.is_single_point()) {
std::swap(a, b);
}
if (b.is_single_point()) {
if (is_zero(b.min)) return IntSet::single_point(0);
if (is_one(b.min)) return IntervalSet::make(a);
Expr e1 = a.has_lower_bound() ? ComputeExpr<Mul>(a.min, b.min) : a.min;
Expr e2 = a.has_upper_bound() ? ComputeExpr<Mul>(a.max, b.min) : a.max;
// This is relaxiation
// TODO(tqchen): consider convert to StrideSet.
if (is_positive_const(b.min)) {
return IntervalSet::make(e1, e2);
} else if (is_negative_const(b.min)) {
return IntervalSet::make(e2, e1);
} else if (a.is_bounded()) {
Expr cmp = b.min >= make_zero(b.min.type().element_of());
return IntervalSet::make(select(cmp, e1, e2), select(cmp, e2, e1));
}
}
LOG(WARNING) << "Return Everything in CombineInterval Mul";
return IntSet::everything();
}
template<>
inline IntSet BinaryCombine<Mul>(IntSet a, IntSet b) {
if (IsNumber(a)) {
return BinaryMul(a, AsNumber(b));
} else if (IsNumber(b)) {
return BinaryMul(b, AsNumber(a));
} else {
return IntSet::make_all_set();
inline IntSet CombineInterval<Max>(Interval a, Interval b) {
if (a.is_single_point() && b.is_single_point()) {
return IntSet::single_point(ComputeExpr<Max>(a.min, b.min));
}
return IntervalSet::make(Interval::make_max(a.min, b.min),
Interval::make_max(a.max, b.max));
}
} // namespace
inline const IntSetNode* IntSet::operator->() const {
return static_cast<const IntSetNode*>(node_.get());
template<>
inline IntSet CombineInterval<Min>(Interval a, Interval b) {
if (a.is_single_point() && b.is_single_point()) {
return IntSet::single_point(ComputeExpr<Min>(a.min, b.min));
}
return IntervalSet::make(Interval::make_min(a.min, b.min),
Interval::make_min(a.max, b.max));
}
TVM_STATIC_IR_FUNCTOR(IRPrinter, vtable)
.set_dispatch<IntSetNode>([](const IntSetNode *op, IRPrinter *p) {
p->stream << "int-set(base=";
p->print(op->base);
p->stream << ')';
});
template<typename OP>
inline IntSet CombineInterval_(IntSet a, IntSet b) {
return CombineInterval<OP>(
a.as<IntervalSet>()->i, b.as<IntervalSet>()->i);
}
IntSet IntSet::make_range(Range dom) {
auto n = std::make_shared<IntSetNode>();
n->base = dom;
// stride related
inline IntSet AsStrideSet(IntSet a) {
if (a.as<StrideSet>()) return a;
const IntervalSet* s = a.as<IntervalSet>();
CHECK(s->i.is_bounded());
std::shared_ptr<StrideSet> n = std::make_shared<StrideSet>();
n->base = s->i;
return IntSet(n);
}
template<typename OP>
inline IntSet CombineSets(IntSet a, IntSet b) {
return CombineInterval_<OP>(a.cover_interval(), b.cover_interval());
}
Range IntSet::GetCoverRange() const {
const IntSetNode* s = operator->();
CHECK(s != nullptr) << "empty set";
if (s->domain.size() == 0 && s->concrete.size() == 0) {
return s->base;
template<>
inline IntSet CombineSets<Add>(IntSet a, IntSet b) {
const IntervalSet* a_int = a.as<IntervalSet>();
const IntervalSet* b_int = b.as<IntervalSet>();
if (a_int && is_zero(a_int->i.min)) return b;
if (b_int && is_zero(b_int->i.min)) return a;
a = AsStrideSet(a);
b = AsStrideSet(b);
const StrideSet* a_stride = a.as<StrideSet>();
const StrideSet* b_stride = b.as<StrideSet>();
auto n = std::make_shared<StrideSet>(*a_stride);
for (size_t i = 0; i < b_stride->extents.size(); ++i) {
n->extents.push_back(b_stride->extents[i]);
n->strides.push_back(b_stride->strides[i]);
}
LOG(FATAL) << "not yet implemented";
return Range();
n->base = CombineInterval<Add>(
a_stride->base, b_stride->base).as<IntervalSet>()->i;
return IntSet(n);
}
IntSet IntSet::make_point(Expr point) {
return IntSet::make_range(Range::make_with_min_extent(point, 1));
inline IntSet NegateSet(IntSet a) {
const IntervalSet* a_int = a.as<IntervalSet>();
if (a_int) {
if (a_int->i.is_single_point()) {
return IntSet::single_point(-a_int->i.min);
} else {
Interval r = Interval::everything();
if (a_int->i.has_upper_bound()) {
r.min = -(a_int->i.max);
}
if (a_int->i.has_lower_bound()) {
r.max = -(a_int->i.min);
}
return IntervalSet::make(r);
}
} else {
return NegateSet(a.cover_interval());
}
}
IntSet IntSet::make_all_set() {
LOG(FATAL) << "TODO";
return IntSet();
template<>
inline IntSet CombineSets<Sub>(IntSet a, IntSet b) {
return CombineSets<Add>(a, NegateSet(b));
}
IntSet Union(const Array<IntSet>& set) {
if (set.size() == 1) return set[0];
LOG(FATAL) << "TODO";
return IntSet();
TVM_DECLARE_LOGICAL_OP(And);
TVM_DECLARE_LOGICAL_OP(Or);
TVM_DECLARE_LOGICAL_OP(EQ);
TVM_DECLARE_LOGICAL_OP(NE);
TVM_DECLARE_LOGICAL_OP(GE);
TVM_DECLARE_LOGICAL_OP(GT);
TVM_DECLARE_LOGICAL_OP(LE);
TVM_DECLARE_LOGICAL_OP(LT);
TVM_DECLARE_LOGICAL_OP(Not);
// generic combine operations of two sets
template<typename OP>
inline IntSet Combine(const IntSet& a, const IntSet &b) {
if (is_logical_op<OP>::value) {
return IntervalSet::make(0, 1);
}
const IntervalSet* a_int = a.as<IntervalSet>();
const IntervalSet* b_int = b.as<IntervalSet>();
if (a_int && a_int->i.is_everything()) return a;
if (b_int && b_int->i.is_everything()) return b;
if (a_int && b_int) {
return CombineInterval<OP>(a_int->i, b_int->i);
}
if (a_int && !(a_int->i.is_bounded())) {
return CombineInterval_<OP>(a, b.cover_interval());
}
if (b_int && !(b_int->i.is_bounded())) {
return CombineInterval_<OP>(a.cover_interval(), b);
}
return CombineSets<OP>(a, b);
}
// Implementation of Evaluations and passing.
void PassUp(const SplitNode* s,
const std::unordered_map<IterVar, Range>& dom_map,
const IntSet& outer,
......@@ -215,33 +358,21 @@ void PassUp(const SplitNode* s,
if (dom_map.count(s->outer) &&
dom_map.count(s->inner) &&
dom_map.count(s->parent) &&
Match(outer, dom_map.at(s->outer)) &&
Match(inner, dom_map.at(s->inner))) {
*parent = IntSet::make_range(dom_map.at(s->parent));
MatchRange(outer, dom_map.at(s->outer)) &&
MatchRange(inner, dom_map.at(s->inner))) {
*parent = IntSet::range(dom_map.at(s->parent));
return;
}
Expr factor = dom_map.at(s->inner)->extent;
Expr parent_min = dom_map.at(s->parent)->min;
CHECK(outer.defined());
CHECK(inner.defined());
CHECK(factor.defined());
// copy construct
auto n = std::make_shared<IntSetNode>(*(inner.operator->()));
if (IsNumber(outer)) {
// shift the base offset
n->base = Range::make_with_min_extent(
AsNumber(outer) * factor + inner->base->min,
inner->base->extent);
} else {
// default use all domains in the data.
n->domain.push_back(outer->base);
n->stride.push_back(factor);
for (size_t i = 0; i < outer->domain.size(); ++i) {
n->domain.push_back(outer->domain[i]);
n->stride.push_back(outer->stride[i] * factor);
}
}
*parent = IntSet(n);
*parent = Combine<Add>(
Combine<Add>(
Combine<Mul>(outer, IntSet::single_point(factor)), inner),
IntSet::single_point(parent_min));
}
void PassUp(const FuseNode* s,
......@@ -253,29 +384,51 @@ void PassUp(const FuseNode* s,
CHECK(dom_map.count(s->inner));
CHECK(dom_map.count(s->fused));
if (Match(fused, dom_map.at(s->fused))) {
*outer = IntSet::make_range(dom_map.at(s->outer));
*inner = IntSet::make_range(dom_map.at(s->inner));
if (MatchRange(fused, dom_map.at(s->fused))) {
*outer = IntSet::range(dom_map.at(s->outer));
*inner = IntSet::range(dom_map.at(s->inner));
return;
}
if (IsNumber(fused)) {
Expr value = AsNumber(fused);
Expr outer_min = dom_map.at(s->outer)->min;
Expr inner_min = dom_map.at(s->inner)->min;
const IntervalSet* fused_int = fused.as<IntervalSet>();
if (fused_int && fused_int->i.is_single_point()) {
Expr value = fused_int->i.min;
Expr factor = dom_map.at(s->inner)->extent;
*outer = IntSet::make_point(value / factor);
*inner = IntSet::make_point(value % factor);
Expr v_outer = value / factor;
Expr v_inner = value % factor;
if (!is_zero(outer_min)) v_outer = v_outer + outer_min;
if (!is_zero(inner_min)) v_inner = v_inner + inner_min;
*outer = IntSet::single_point(v_outer);
*inner = IntSet::single_point(v_inner);
} else {
LOG(WARNING) << "use fallback inference rule in fuse";
// simply use the entire set, this rule can be enhanced.
*outer = IntSet::make_range(dom_map.at(s->outer));
*inner = IntSet::make_range(dom_map.at(s->inner));
*outer = IntSet::range(dom_map.at(s->outer));
*inner = IntSet::range(dom_map.at(s->inner));
return;
}
}
void PassUp(const RebaseNode* s,
const std::unordered_map<IterVar, Range>& dom_map,
const IntSet& rebased,
IntSet* parent) {
CHECK(dom_map.count(s->parent));
if (MatchRange(rebased, dom_map.at(s->rebased))) {
*parent = IntSet::range(dom_map.at(s->parent));
return;
}
Expr parent_min = dom_map.at(s->parent)->min;
*parent = Combine<Add>(rebased, IntSet::single_point(parent_min));
}
namespace {
// evaluator to evaluate the int set
class IRSetEvaluator {
// Evaluator to evalute the epxression.
class IntSetEvaluator {
public:
inline IntSet Eval(Expr expr) {
static const FType& f = vtable();
......@@ -283,11 +436,11 @@ class IRSetEvaluator {
return f(expr, expr, this);
} else {
LOG(WARNING) << "cannot evaluate set type " << expr->type_key();
return IntSet::make_all_set();
return IntSet::everything();
}
}
using FType = tvm::IRFunctor<IntSet (const NodeRef&, const Expr&, IRSetEvaluator *)>;
using FType = tvm::IRFunctor<IntSet (const NodeRef&, const Expr&, IntSetEvaluator *)>;
static FType& vtable() { // NOLINT(*)
static FType inst; return inst;
}
......@@ -295,76 +448,84 @@ class IRSetEvaluator {
std::unordered_map<const Variable*, IntSet> dom_map;
};
inline IntSet ConstOp(const NodeRef&, const Expr& e, IRSetEvaluator*) {
return IntSet::make_point(e);
inline IntSet ConstOp(const NodeRef&, const Expr& e, IntSetEvaluator*) {
return IntSet::single_point(e);
}
TVM_STATIC_IR_FUNCTOR(IRSetEvaluator, vtable)
TVM_STATIC_IR_FUNCTOR(IntSetEvaluator, vtable)
.set_dispatch<IntImm>(ConstOp)
.set_dispatch<UIntImm>(ConstOp)
.set_dispatch<FloatImm>(ConstOp);
TVM_STATIC_IR_FUNCTOR(IRSetEvaluator, vtable)
.set_dispatch<Variable>([](const Variable* op, const Expr& e, IRSetEvaluator* m) {
TVM_STATIC_IR_FUNCTOR(IntSetEvaluator, vtable)
.set_dispatch<Variable>([](const Variable* op, const Expr& e, IntSetEvaluator* m) {
auto it = m->dom_map.find(op);
if (it != m->dom_map.end()) {
return it->second;
} else {
return IntSet::make_point(e);
return IntSet::single_point(e);
}
});
// binary operator
template<typename T>
inline IntSet Binary(const T* op, const Expr& e, IRSetEvaluator* m) {
inline IntSet Binary(const T* op, const Expr& e, IntSetEvaluator* m) {
IntSet a = m->Eval(op->a);
IntSet b = m->Eval(op->b);
if (IsNumber(a) && IsNumber(b)) {
if (Match(a, op->a) &&
Match(b, op->b)) {
return IntSet::make_point(e);
} else {
return IntSet::make_point(T::make(AsNumber(a), AsNumber(b)));
}
} else {
return BinaryCombine<T>(a, b);
if (MatchPoint(a, op->a) && MatchPoint(b, op->b)) {
return IntSet::single_point(e);
}
IntSet r = Combine<T>(a, b);
return r;
}
TVM_STATIC_IR_FUNCTOR(IRSetEvaluator, vtable)
TVM_STATIC_IR_FUNCTOR(IntSetEvaluator, vtable)
.set_dispatch<Add>(Binary<Add>)
.set_dispatch<Sub>(Binary<Sub>)
.set_dispatch<Mul>(Binary<Mul>)
.set_dispatch<Div>(Binary<Div>)
.set_dispatch<Mod>(Binary<Mod>)
.set_dispatch<Min>(Binary<Min>)
.set_dispatch<Max>(Binary<Max>);
// use simply bound for logical expressions for now.
inline IntSet Logical(const NodeRef&, const Expr& e, IRSetEvaluator*) {
return IntSet::make_range(Range::make_with_min_extent(0, 2));
}
TVM_STATIC_IR_FUNCTOR(IRSetEvaluator, vtable)
.set_dispatch<EQ>(Logical)
.set_dispatch<NE>(Logical)
.set_dispatch<LT>(Logical)
.set_dispatch<LE>(Logical)
.set_dispatch<GT>(Logical)
.set_dispatch<GE>(Logical)
.set_dispatch<And>(Logical)
.set_dispatch<Or>(Logical);
} // namespace
.set_dispatch<Max>(Binary<Max>)
.set_dispatch<EQ>(Binary<EQ>)
.set_dispatch<NE>(Binary<NE>)
.set_dispatch<LT>(Binary<LT>)
.set_dispatch<LE>(Binary<LE>)
.set_dispatch<GT>(Binary<GT>)
.set_dispatch<GE>(Binary<GE>)
.set_dispatch<And>(Binary<And>)
.set_dispatch<Or>(Binary<Or>);
IntSet EvalSet(Expr e,
const Map<IterVar, IntSet>& dom_map) {
IRSetEvaluator m;
IntSetEvaluator m;
for (auto kv : dom_map) {
m.dom_map[kv.first->var.as<Variable>()] = kv.second;
}
return m.Eval(e);
}
IntSet EvalSet(Range r,
const Map<IterVar, IntSet>& dom_map) {
IntSetEvaluator m;
for (auto kv : dom_map) {
m.dom_map[kv.first->var.as<Variable>()] = kv.second;
}
IntSet min_set = m.Eval(r->min);
IntSet ext_set = m.Eval(r->extent).cover_interval();
const Interval& ei = ext_set.as<IntervalSet>()->i;
if (!ei.has_upper_bound()) return IntSet::everything();
ext_set = IntervalSet::make(0, ComputeExpr<Sub>(ei.max, 1));
return Combine<Add>(min_set, ext_set);
}
TVM_STATIC_IR_FUNCTOR(IRPrinter, vtable)
.set_dispatch<IntervalSet>([](const IntervalSet *op, IRPrinter *p) {
p->stream << "interval-set["
<< "[" << op->i.min << ", "
<< op->i.max << ']';
});
} // namespace schedule
} // namespace tvm
......@@ -22,35 +22,48 @@ class IntSet : public NodeRef {
public:
/*! \brief constructor */
IntSet() {}
// constructor from not deontainer.
// constructor from not container.
explicit IntSet(std::shared_ptr<Node> n) : NodeRef(n) {}
/*! \return whether the set is empty */
inline bool is_empty() const {
return !defined();
}
/*!
* \return a range that covers the IntSet
*/
Range GetCoverRange() const;
/*!
* \brief access the internal node container
* \return the pointer to the internal node container
*/
inline const IntSetNode* operator->() const;
/*!
* \param dom The domain to be created.
* \return create integer set from existing domain
* \brief Find a range that covers the region.
* \param max_range The range to be covered.
* \return The covering range.
*/
Range cover_range(Range max_range) const;
/*!
* \brief find an interval that covers the set.
* \return The covering interval set.
*/
static IntSet make_range(Range dom);
IntSet cover_interval() const;
/*! \return Whether the set represent everything */
bool is_everything() const;
/*! \return Whether the set is a single point */
bool is_single_point() const;
/*! \return Whether the set contains everything */
static IntSet everything();
/*!
* \param point
* \return create integer set that only contains one point
* \brief construct a point set.
* \param point The point in the set.
* \return construct a single point set
*/
static IntSet make_point(Expr point);
static IntSet single_point(Expr point);
/*!
* \return create integer set that represents everything
* \brief Construct a set representing a range.
* \param r The range
* \return constructed set.
*/
static IntSet make_all_set();
static IntSet range(Range r);
};
/*!
* \brief Base class of all IntSet containers.
*/
struct IntSetNode : public Node {
};
/*!
......@@ -63,6 +76,18 @@ class IntSet : public NodeRef {
*/
IntSet EvalSet(Expr e,
const Map<IterVar, IntSet>& dom_map);
/*!
* \brief Find an symbolic integer set that contains is union over
* all the possible conditional values in dom_map.
*
* \param r The initial range.
* \param dom_map The domain of each variable.
* \return An integer set that can cover all the possible values.
*/
IntSet EvalSet(Range r,
const Map<IterVar, IntSet>& dom_map);
/*!
* \brief Conditional upward message passing.
*
......@@ -99,6 +124,23 @@ void PassUp(const FuseNode* s,
const IntSet& fused,
IntSet* outer,
IntSet* inner);
/*!
* \brief Conditional upward message passing.
*
* Get domain of parent, condition on domain of children.
* Domain is represented as IntSet.
*
* \param s The Fuse relation node.
* \param dom_map The old domain result from downward message passing.
* Contains the domain set if all the children are full set.
* \param rebased domain of rebased iteration.
* \param parent The result domain of parent iteration.
*/
void PassUp(const RebaseNode* s,
const std::unordered_map<IterVar, Range>& dom_map,
const IntSet& fused,
IntSet* parent);
/*!
* \brief Create an union set of all sets
* \param sets The sets to be unioned
......@@ -106,6 +148,11 @@ void PassUp(const FuseNode* s,
*/
IntSet Union(const Array<IntSet>& sets);
// implementation
inline const IntSetNode* IntSet::operator->() const {
return static_cast<const IntSetNode*>(node_.get());
}
} // namespace schedule
} // namespace tvm
......
......@@ -81,7 +81,7 @@ Stage& Stage::compute_at(Stage parent, IterVar scope) { // NOLINT(*)
}
}
CHECK(found)
<< "Cannot compute at a iteration variable that is not part of parent leaf vars";
<< "Cannot find the specified axis in parent stage's leaf_iter_vars";
return *this;
}
......@@ -165,7 +165,6 @@ Stage& Stage::tile(IterVar x_parent, IterVar y_parent,
return *this;
}
Schedule::Schedule(Array<Operation> ops) {
auto n = std::make_shared<ScheduleNode>();
n->roots = ops;
......@@ -203,9 +202,53 @@ IterVarRelation FuseNode::make(
return IterVarRelation(n);
}
IterVarRelation RebaseNode::make(IterVar parent, IterVar rebased) {
auto n = std::make_shared<RebaseNode>();
n->parent = parent;
n->rebased = rebased;
return IterVarRelation(n);
}
void Schedule::normalize() {
std::unordered_map<IterVar, IterVar> rebase_map;
std::unordered_map<const Node*, int> attach_mark;
for (Stage s : (*this)->stages) {
if (s->attach_type == kScope) {
attach_mark[s->attach_stage.get()] = 1;
}
}
for (Stage s : (*this)->stages) {
if (!attach_mark.count(s.get())) continue;
auto root_iter_vars = s->op->root_iter_vars();
ArrayNode* leaf_vars = s->leaf_iter_vars.CopyOnWrite();
for (IterVar iv : root_iter_vars) {
size_t idx = FindIterVar(leaf_vars, iv);
if (idx < leaf_vars->data.size()) {
// insert rebase
IterVar rebased(Range(), iv->var->name_hint + ".rb");
s->relations.push_back(RebaseNode::make(iv, rebased));
leaf_vars->data[idx] = rebased.node_;
rebase_map[iv] = rebased;
}
}
}
// remap the parent relation
for (Stage s : (*this)->stages) {
if (s->attach_type != kScope) continue;
if (rebase_map.count(s->attach_ivar)) {
s->attach_ivar = rebase_map.at(s->attach_ivar);
}
}
}
TVM_REGISTER_NODE_TYPE(StageNode);
TVM_REGISTER_NODE_TYPE(SplitNode);
TVM_REGISTER_NODE_TYPE(FuseNode);
TVM_REGISTER_NODE_TYPE(RebaseNode);
TVM_REGISTER_NODE_TYPE(ScheduleNode);
} // namespace tvm
......@@ -8,12 +8,44 @@
#include <tvm/ir_visitor.h>
#include <tvm/schedule_pass.h>
#include "./scope.h"
#include "./ir_util.h"
#include "../schedule/graph.h"
#include "../pass/ir_util.h"
#include "./int_set.h"
#include "./graph.h"
namespace tvm {
namespace ir {
namespace schedule {
using namespace ir;
/*!
* \brief message passing to find if IterVar is related to reduction.
* \param s The stage to be used.
* \param p_state The message passing state
* IterVar->flag
*/
void PassDownFlag(const Stage& s,
std::unordered_map<IterVar, int>* p_state) {
auto& state = *p_state;
for (IterVarRelation rel : s->relations) {
if (rel.as<SplitNode>()) {
const SplitNode* s = rel.as<SplitNode>();
int flag = state.at(s->parent);
state[s->outer] = flag;
state[s->inner] = flag;
} else if (rel.as<FuseNode>()) {
const FuseNode* s = rel.as<FuseNode>();
int flag_outer = state.at(s->outer);
int flag_inner = state.at(s->inner);
state[s->fused] = flag_outer | flag_inner;
} else if (rel.as<RebaseNode>()) {
const RebaseNode* s = rel.as<RebaseNode>();
int flag = state.at(s->parent);
state[s->rebased] = flag;
} else {
LOG(FATAL) << "unknown relation type";
}
}
}
/*!
* \brief use message passing to calculate the assignment of each Var inside the loop body.
......@@ -37,7 +69,7 @@ void PassUpOffset(const Stage& s,
state[s->parent] = inner + outer * factor;
// add min if they exist
if (!is_zero(parent_min)) {
state[s->parent] = parent_min + state[s->parent];
state[s->parent] = state[s->parent] + parent_min;
}
} else if (rel.as<FuseNode>()) {
const FuseNode* s = rel.as<FuseNode>();
......@@ -49,10 +81,20 @@ void PassUpOffset(const Stage& s,
state[s->inner] = value % factor;
// add min if they exist
if (!is_zero(outer_min)) {
state[s->outer] = outer_min + state[s->outer];
state[s->outer] = state[s->outer] + outer_min;
}
if (!is_zero(inner_min)) {
state[s->inner] = outer_min + state[s->inner];
state[s->inner] = state[s->inner] + inner_min;
}
} else if (rel.as<RebaseNode>()) {
const RebaseNode* s = rel.as<RebaseNode>();
Expr value = state.at(s->rebased);
Expr parent_min = dom_map.at(s->parent)->min;
// add min if they exist
if (!is_zero(parent_min)) {
state[s->parent] = value + parent_min;
} else {
state[s->parent] = value;
}
} else {
LOG(FATAL) << "unknown relation type";
......@@ -60,76 +102,54 @@ void PassUpOffset(const Stage& s,
}
}
/*!
* \brief split the expr by addition.
* \param expr The expression to be splitted.
* \param loop_level The loop level of each Variable
* \param result vector of (level, expr)
* The level gives the mimimum loop level this expression need to be computed.
* The Expr gives the expression content.
*/
void SplitByAdd(Expr expr,
const std::unordered_map<const Variable*, size_t>& loop_level,
std::vector<std::pair<size_t, Expr> > *result) {
const Add* op = expr.as<Add>();
if (op != nullptr) {
SplitByAdd(op->a, loop_level, result);
SplitByAdd(op->b, loop_level, result);
} else {
size_t max_level = 0;
auto fvisit = [&max_level, &loop_level](const NodeRef& n) {
const Variable* op = n.as<Variable>();
if (op != nullptr) {
auto it = loop_level.find(op);
if (it != loop_level.end()) {
max_level = std::max(max_level, it->second);
}
}
};
PostOrderVisit(expr, fvisit);
result->push_back(std::make_pair(max_level, expr));
}
}
/*!
* \brief Make the loop nest of the correspondings schedule.
* \param sch The schedule.
* \param dom_map The domain map.
*
* \return a nested representation of loop statements.
* The flattened Stmt are ordered from outmost to inner most order.
*/
std::vector<std::vector<Stmt> > MakeLoopNest(
const Stage& sch,
const Map<IterVar, Range>& dom_map) {
// optional, use let to define some CSE in dom_map.
std::vector<std::vector<Stmt> >
MakeLoopNest(const Stage& sch,
const Map<IterVar, Range>& dom_map,
size_t begin_loop,
bool reduce_init_loop,
std::unordered_map<IterVar, Expr>* p_value_map,
const std::unordered_map<IterVar, bool>& skip_iter) {
auto leaf_iter_vars = sch->leaf_iter_vars;
std::unordered_map<IterVar, Expr> offset;
std::unordered_map<const Variable*, size_t> loop_level;
Stmt no_op = Evaluate::make(0);
// create the loop nest
std::vector<std::vector<Stmt> > nest;
nest.resize(leaf_iter_vars.size() + 1);
std::unordered_map<IterVar, Expr>& value_map = *p_value_map;
for (size_t i = 0; i < leaf_iter_vars.size(); ++i) {
for (size_t i = begin_loop; i < leaf_iter_vars.size(); ++i) {
auto iv = leaf_iter_vars[i];
if (skip_iter.count(iv) && skip_iter.at(iv)) {
// skip this iteration.
value_map[iv] = iv->var;
continue;
}
Range dom = dom_map.at(iv);
// initialize the offset and loop_level
offset[iv] = iv->var;
loop_level[iv->var.as<Variable>()] = i + 1;
Var var = iv->var;
if (reduce_init_loop) {
var = Var(iv->var->name_hint + ".init", iv->var.type());
}
// Mark the iter var in the IR, to remember the point
if (iv->thread_tag.length() == 0) {
if (is_zero(dom->min)) {
if (is_one(dom->extent)) {
nest[i + 1].emplace_back(
LetStmt::make(var, dom->min, no_op));
value_map[iv] = dom->min;
} else if (is_zero(dom->min)) {
nest[i + 1].emplace_back(
For::make(iv->var, 0, dom->extent,
For::make(var, 0, dom->extent,
ForType::Serial, DeviceAPI::None, no_op));
value_map[iv] = var;
} else {
Var idx(iv->var->name_hint + ".idx", iv->var.type());
nest[i + 1].emplace_back(
For::make(idx, 0, dom->extent,
ForType::Serial, DeviceAPI::None, no_op));
Expr new_value = dom->min + idx;
value_map[iv] = new_value;
nest[i + 1].emplace_back(
LetStmt::make(iv->var, dom->min + idx, no_op));
LetStmt::make(var, new_value, no_op));
}
} else {
// Always restrict threaded IterVar to starts from 0.
......@@ -137,69 +157,73 @@ std::vector<std::vector<Stmt> > MakeLoopNest(
// annotate the extent of the IterVar
nest[i + 1].emplace_back(
AttrStmt::make(iv, "thread_extent", dom->extent, no_op));
value_map[iv] = var;
}
if (!reduce_init_loop) {
// annotate the extent of the IterVar
nest[i + 1].emplace_back(
AttrStmt::make(iv, "scope", iv->var, no_op));
}
// annotate the extent of the IterVar
nest[i + 1].emplace_back(
AttrStmt::make(iv, "scope", iv->var, no_op));
}
// message passing to get offset of root iter vars.
PassUpOffset(sch, dom_map, &offset);
for (IterVar iv : sch->op->root_iter_vars()) {
Expr value = offset.at(iv);
if (!value.same_as(iv->var)) {
using Entry = std::pair<size_t, Expr>;
std::vector<Entry> splits;
SplitByAdd(value, loop_level, &splits);
PassUpOffset(sch, dom_map, &value_map);
return nest;
}
Expr offset = 0;
size_t nsplit_left = splits.size() - 1;
for (size_t i = 0; i <= leaf_iter_vars.size(); ++i) {
size_t hit = 0;
for (const auto& kv : splits) {
if (kv.first == i) {
if (is_zero(offset)) {
offset = kv.second;
} else {
offset = offset + kv.second;
++hit;
}
}
}
nsplit_left -= hit;
if (hit != 0) {
std::ostringstream os;
os << iv->var->name_hint << ".at.l" << i;
Var base_offset(os.str());
if (nsplit_left == 0) {
base_offset = iv->var;
}
nest[i].emplace_back(
LetStmt::make(base_offset, offset, no_op));
offset = base_offset;
}
}
Range dom = dom_map.at(iv);
if (!offset.same_as(iv->var)) {
// define the iv->var
nest.back().emplace_back(
LetStmt::make(iv->var, offset, no_op));
Stmt MakeLoop(const Stage& s,
const Map<IterVar, Range>& dom_map,
Stmt provide,
Stmt init) {
std::unordered_map<IterVar, Expr> value_map;
auto nest = MakeLoopNest(s, dom_map, 0, false, &value_map, {});
provide = Substitute(provide, value_map);
if (init.defined()) {
// try to find the location to insert the initialization.
// Fuse the initialization and provide loop when possible.
std::unordered_map<IterVar, int> reduce_state;
const ComputeOpNode* compute = s->op.as<ComputeOpNode>();
for (IterVar iv : compute->reduce_axis) {
reduce_state[iv] = 2;
}
for (IterVar iv : compute->axis) {
reduce_state[iv] = 1;
}
// find which iter var is related to reduction and which is related to axis.
PassDownFlag(s, &reduce_state);
auto leaf_iter_vars = s->leaf_iter_vars;
std::unordered_map<IterVar, Expr> init_value_map;
// first first loop that is related to reduction.
size_t begin_loop = leaf_iter_vars.size();
for (size_t i = 0; i < leaf_iter_vars.size(); ++i) {
auto iv = leaf_iter_vars[i];
int flag = reduce_state.at(iv);
if ((flag & 2) != 0) {
begin_loop = i; break;
}
Expr condition = (iv->var - dom->min) < dom->extent;
// Boundary condition checking
// Need better boundary condition here.
nest.back().emplace_back(IfThenElse::make(condition, no_op));
init_value_map[iv] = value_map.at(iv);
}
// skip loops that does not relates to axis.
std::unordered_map<IterVar, bool> skip_iter;
for (size_t i = begin_loop; i < leaf_iter_vars.size(); ++i) {
auto iv = leaf_iter_vars[i];
int flag = reduce_state.at(iv);
if ((flag & 1) == 0) skip_iter[iv] = true;
}
auto init_nest = MakeLoopNest(
s, dom_map, begin_loop, true, &init_value_map, skip_iter);
init = Substitute(init, init_value_map);
init = MergeNest(init_nest, init);
// common nest
std::vector<std::vector<Stmt> > common(nest.begin(), nest.begin() + begin_loop);
std::vector<std::vector<Stmt> > reduce(nest.begin() + begin_loop, nest.end());
provide = MergeNest(reduce, provide);
return MergeNest(
common, Block::make(init, provide));
} else {
return MergeNest(nest, provide);
}
return nest;
}
/*!
* \brief Make pipeline specifically for compute op node.
* \param op The compute node
* \param tensors The tensors generated by provide.
*/
Stmt MakeProvide(const ComputeOpNode* op,
const std::vector<Tensor>& tensors) {
Tensor t = tensors[0];
......@@ -210,13 +234,6 @@ Stmt MakeProvide(const ComputeOpNode* op,
return Provide::make(t->op, t->value_index, op->body, args);
}
/*!
* \brief Make pipeline specifically for compute op node.
* \param op The compute node
* \param dom_map The domain map
* \param tensors The tensors generated by provide.
* \param body The content of the pipeline.
*/
Stmt MakeRealize(const ComputeOpNode* op,
const Map<IterVar, Range>& dom_map,
const std::vector<Tensor>& tensors,
......@@ -230,6 +247,38 @@ Stmt MakeRealize(const ComputeOpNode* op,
bounds, make_const(Bool(1), true), body);
}
void MakeReduction(const ComputeOpNode* op,
const std::vector<Tensor>& tensors,
const Map<IterVar, Range>& dom_map,
Stmt* init,
Stmt* provide) {
Stmt no_op = Evaluate::make(0);
Tensor t = tensors[0];
std::vector<Stmt> nest;
Array<Expr> args;
for (IterVar iv : op->axis) {
args.push_back(iv->var);
}
const Reduce* reduce = op->body.as<Reduce>();
CHECK(reduce);
Expr init_value, update_value;
if (reduce->op == "Add") {
init_value = make_zero(reduce->type);
update_value = Add::make(t(args), reduce->source);
} else if (reduce->op == "Max") {
init_value = reduce->type.min();
update_value = Max::make(t(args), reduce->source);
} else if (reduce->op == "Min") {
init_value = reduce->type.max();
update_value = Min::make(t(args), reduce->source);
} else {
LOG(FATAL) << "Unsupported reduction " << reduce->op;
}
*init = Provide::make(t->op, t->value_index, init_value, args);
*provide = Provide::make(t->op, t->value_index, update_value, args);
}
Stmt MakePipeline(const Stage& sch,
const Map<IterVar, Range>& dom_map,
Stmt consumer) {
......@@ -238,14 +287,20 @@ Stmt MakePipeline(const Stage& sch,
tensors.emplace_back(sch->op.output(i));
}
Stmt provide;
if (sch->op.as<ComputeOpNode>()) {
provide = MakeProvide(sch->op.as<ComputeOpNode>(), tensors);
Stmt init, provide;
const ComputeOpNode* compute = sch->op.as<ComputeOpNode>();
if (compute) {
if (compute->reduce_axis.size() == 0) {
provide = MakeProvide(compute, tensors);
} else {
MakeReduction(compute, tensors, dom_map, &init, &provide);
}
} else {
LOG(FATAL) << "not supported op " << sch->op->type_key();
}
std::vector<std::vector<Stmt> > nest = MakeLoopNest(sch, dom_map);
Stmt producer = MergeNest(nest, provide);
Stmt producer = MakeLoop(sch, dom_map, provide, init);
producer = ProducerConsumer::make(sch->op, true, producer);
Stmt pipeline = producer;
......@@ -306,7 +361,6 @@ Stmt InjectInline(const Operation op, Stmt body) {
return Inline(body, op, args, compute->body);
}
Stmt ScheduleOps(
Schedule sch, Map<IterVar, Range> dom_map) {
Stmt body = Stmt();
......@@ -330,5 +384,5 @@ Stmt ScheduleOps(
return body;
}
} // namespace ir
} // namespace schedule
} // namespace tvm
......@@ -18,7 +18,8 @@ def test_add():
# one line to build the function.
codes = []
fadd = tvm.build(s, args=[A, B, C],
fadd = tvm.build(s,
args=[A, B, C],
target="cuda", name="myadd",
record_codes=codes)
for c in codes:
......
import tvm
import numpy as np
def test_sum():
# graph
n = tvm.Var('n')
m = tvm.Var('m')
A = tvm.placeholder((n, m), name='A')
k = tvm.IterVar((0, m))
B = tvm.compute((n,), lambda i: tvm.sum(A[i, k], axis=k), name='B')
# schedule
s = tvm.Schedule(B.op)
# create iter var and assign them tags.
num_thread = 1
block_x = tvm.IterVar(thread_tag="blockIdx.x")
thread_x = tvm.IterVar((0, num_thread), thread_tag="threadIdx.x")
_, x = s[B].split(B.op.axis[0], factor=num_thread, outer=block_x)
_, x = s[B].split(x, outer=thread_x)
tvm.init_opencl()
codes = []
fsum = tvm.build(s,
args=[A, B],
target="opencl", name="myadd",
record_codes=codes)
for c in codes:
print(c)
num_device = 1
for i in range(num_device):
ctx = tvm.opencl(i)
if not ctx.enabled:
continue
# launch the kernel.
n = 1028
m = 129
#a = tvm.nd.array(np.zeros((n, m)).astype(A.dtype), ctx)
a = tvm.nd.array(np.random.uniform(size=(n, m)).astype(A.dtype), ctx)
b = tvm.nd.array(np.zeros(n, dtype=B.dtype), ctx)
fsum(a, b)
np.testing.assert_allclose(
b.asnumpy(), np.sum(a.asnumpy(), axis=1), rtol=1e-4)
if __name__ == "__main__":
test_sum()
......@@ -18,8 +18,7 @@ def test_add_pipeline():
# compile to IR
bounds = tvm.schedule.InferBound(s)
stmt = tvm.ir_pass.ScheduleOps(s, bounds)
stmt = tvm.schedule.ScheduleOps(s, bounds)
Ab = tvm.Buffer(A.shape, A.dtype, name='A')
Bb = tvm.Buffer(B.shape, B.dtype, name='B')
Cb = tvm.Buffer(C.shape, C.dtype, name='C')
......
......@@ -10,12 +10,13 @@ def test_makeapi():
s = tvm.Schedule(C.op)
bounds = tvm.schedule.InferBound(s)
stmt = tvm.ir_pass.ScheduleOps(s, bounds)
stmt = tvm.schedule.ScheduleOps(s, bounds)
Ab = tvm.Buffer(A.shape, A.dtype, name='A')
Bb = tvm.Buffer(B.shape, B.dtype, name='B')
Cb = tvm.Buffer(C.shape, C.dtype, name='C')
stmt = tvm.ir_pass.StorageFlatten(stmt, {A: Ab, B:Bb, C:Cb})
num_packed_args = 2
f = tvm.codegen.MakeAPI(stmt, "myadd", [n, Ab, Bb, Cb], num_packed_args)
assert(f.handle_data_type[Ab.data].dtype == Ab.dtype)
......
......@@ -26,7 +26,7 @@ def test_tensor_reduce():
B = tvm.placeholder((n, l), name='B')
T = tvm.compute((m, n, l), lambda i, j, k: A[i, k] * B[j, k])
rv = tvm.IterVar((0, A.shape[1]), name="k")
C = tvm.compute((m, n), lambda i, j: tvm.sum(T(i, j, rv+1), rdom=rv))
C = tvm.compute((m, n), lambda i, j: tvm.sum(T(i, j, rv+1), axis=rv))
# json load save
C_json = tvm.save_json(C)
C_loaded = tvm.load_json(C_json)
......
......@@ -12,7 +12,7 @@ def test_flatten2():
s[A1].compute_at(s[A2], xo)
bounds = tvm.schedule.InferBound(s)
assert isinstance(bounds, tvm.collections.Map)
stmt = tvm.ir_pass.ScheduleOps(s, bounds)
stmt = tvm.schedule.ScheduleOps(s, bounds)
print(stmt)
Ab = tvm.Buffer(A.shape, A.dtype, name='A')
......
......@@ -11,7 +11,7 @@ def test_schedule0():
bounds = tvm.schedule.InferBound(s)
assert isinstance(bounds, tvm.collections.Map)
stmt = tvm.ir_pass.ScheduleOps(s, bounds)
stmt = tvm.schedule.ScheduleOps(s, bounds)
print(stmt)
def test_schedule1():
......@@ -24,7 +24,7 @@ def test_schedule1():
xo, xi = s[A1].split(A1.op.axis[0], 8)
bounds = tvm.schedule.InferBound(s)
assert isinstance(bounds, tvm.collections.Map)
stmt = tvm.ir_pass.ScheduleOps(s, bounds)
stmt = tvm.schedule.ScheduleOps(s, bounds)
print(stmt)
def test_schedule2():
......@@ -39,7 +39,7 @@ def test_schedule2():
s[A1].compute_at(s[A2], xo)
bounds = tvm.schedule.InferBound(s)
assert isinstance(bounds, tvm.collections.Map)
stmt = tvm.ir_pass.ScheduleOps(s, bounds)
stmt = tvm.schedule.ScheduleOps(s, bounds)
print(stmt)
......
Markdown is supported
0% or
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment