Commit 78ea652d by Tianqi Chen Committed by Haichen Shen

[PASS] Schedule Ops init working version (#6)

* [PASS] Schedule Ops init working version

* bugfix in PassUp
parent 302c2e64
Subproject commit 5d1bd103c2abe19392b4d8def7e3ff1c854e8683
Subproject commit 1ec478bbd0c20b8659f0c897363b5a76e13ef495
......@@ -17,6 +17,7 @@ namespace tvm {
using Halide::Type;
using Halide::Float;
using Halide::Bool;
using Halide::Int;
using Halide::UInt;
using Halide::Handle;
......@@ -29,6 +30,8 @@ using Halide::Internal::Stmt;
using Halide::Internal::IRPrinter;
using Halide::Internal::Variable;
using Halide::Internal::make_const;
/*! \brief a named variable in TVM */
class Var : public Halide::VarExpr {
public:
......
......@@ -18,6 +18,16 @@
namespace tvm {
namespace ir {
/*!
* \brief Schedule s' dependent operations.
*
* \param s The schedule to be realized
* \param dom_map The domain of each iter vars.
* \return the result Stmt
*/
Stmt ScheduleOps(Schedule s, Map<IterVar, Range> dom_map);
/*!
* \brief verifies whether the IR stmt or Expr is in SSA form.
* That is: each VarExpr is defined and assigned once(in Let/For)
......@@ -51,14 +61,6 @@ Stmt Inline(FunctionRef f,
Expr body,
Stmt stmt);
/*!
* \brief Schedule s' dependent operations.
*
* \param s The schedule to be realized
* \return the result Stmt
*/
Stmt ScheduelOps(Schedule s);
} // namespace ir
} // namespace tvm
......
......@@ -13,6 +13,36 @@
namespace tvm {
/*!
* \brief A placeholder op represents an input placeholder.
*/
class PlaceholderOpNode : public OperationNode {
public:
/*! \brief The shape of the input */
Array<Expr> shape;
/*! \brief The data type of the input. */
Type dtype;
/*! \return number of outputs of this op; a placeholder reports exactly one. */
int num_outputs() const final {
return 1;
}
/*! \return the list of iteration variables at root (defined out of line). */
Array<IterVar> root_iter_vars() const final;
/*! \return type of i-th output (defined out of line). */
Type output_dtype(size_t i) const final;
/*! \return shape of i-th output (defined out of line). */
Array<Expr> output_shape(size_t i) const final;
/*!
 * \brief Expose the fields of the node to an attribute visitor.
 * \param v The visitor; it is handed "name", "shape" and "dtype", in that order.
 */
void VisitAttrs(AttrVisitor* v) final {
v->Visit("name", &name);
v->Visit("shape", &shape);
v->Visit("dtype", &dtype);
}
/*!
 * \brief Create a placeholder operation.
 * \param name The name of the operation.
 * \param shape The shape of the input.
 * \param dtype The data type of the input.
 * \return The created Operation.
 */
static Operation make(std::string name,
Array<Expr> shape,
Type dtype);
/*! \brief Type key string of this node type. */
static constexpr const char* _type_key = "PlaceholderOp";
TVM_DECLARE_NODE_TYPE_INFO(PlaceholderOpNode);
};
/*!
* \brief A Compute op that compute a tensor on certain domain.
*/
class ComputeOpNode : public OperationNode {
......@@ -24,11 +54,10 @@ class ComputeOpNode : public OperationNode {
/*! \brief constructor */
ComputeOpNode() {}
size_t num_outputs() const final {
int num_outputs() const final {
return 1;
}
Array<IterVar> root_iter_vars() const final;
std::string output_name(size_t i) const final;
Type output_dtype(size_t i) const final;
Array<Expr> output_shape(size_t i) const final;
......@@ -50,6 +79,16 @@ class ComputeOpNode : public OperationNode {
using FCompute = std::function<Expr (const Array<Var>& i)>;
/*!
* \brief create a place holder tensor.
* \param shape The shape of the tensor.
* \param dtype the data type of the tensor.
* \param name The name of the Tensor.
*/
Tensor Placeholder(Array<Expr> shape,
Type dtype = Float(32),
std::string name = "placeholder");
/*!
* \brief Construct a new tensor by computing over shape,
* using the computation rule: result_tensor[axis] = fcompute(axis)
* \param shape Shape of the tensor.
......
/*!
* Copyright (c) 2016 by Contributors
* \file bound.h
* \brief The bound inference logics on the schedule.
* \file schedule_pass.h
* \brief Collection of Schedule pass functions.
*
* These passes work on the schedule hyper-graph
* and infer information such as bounds, check conditions,
* and read/write dependencies between the IterVars.
*/
#ifndef TVM_SCHEDULE_BOUND_H_
#define TVM_SCHEDULE_BOUND_H_
#ifndef TVM_SCHEDULE_PASS_H_
#define TVM_SCHEDULE_PASS_H_
#include <tvm/expr.h>
#include <tvm/schedule.h>
#include <unordered_map>
#include "./base.h"
#include "./schedule.h"
namespace tvm {
namespace schedule {
......@@ -23,5 +26,4 @@ Map<IterVar, Range> InferBound(Schedule sch);
} // namespace schedule
} // namespace tvm
#endif // TVM_SCHEDULE_BOUND_H_
#endif // TVM_SCHEDULE_PASS_H_
......@@ -28,20 +28,11 @@ using Halide::IR::FunctionRef;
* \brief Tensor structure representing a possible input,
* or intermediate computation result.
*/
class Tensor : public FunctionRef {
class Tensor : public NodeRef {
public:
/*! \brief default constructor, used internally */
Tensor() {}
explicit Tensor(std::shared_ptr<Node> n) : FunctionRef(n) {}
/*!
* \brief constructor of input tensor
* \param shape Shape of the tensor.
* \param name optional name of the Tensor.
* \param dtype The data type of the input tensor.
*/
explicit Tensor(Array<Expr> shape,
std::string name = "tensor",
Type dtype = Float(32));
explicit Tensor(std::shared_ptr<Node> n) : NodeRef(n) {}
/*!
* \brief access the internal node container
* \return the pointer to the internal node container
......@@ -116,11 +107,11 @@ class Tensor : public FunctionRef {
};
/*! \brief Operation that produces tensors */
class Operation : public NodeRef {
class Operation : public FunctionRef {
public:
/*! \brief default constructor */
Operation() {}
explicit Operation(std::shared_ptr<Node> n) : NodeRef(n) {}
explicit Operation(std::shared_ptr<Node> n) : FunctionRef(n) {}
/*!
* \brief access the internal node container
* \return the pointer to the internal node container
......@@ -137,12 +128,10 @@ class Operation : public NodeRef {
};
/*! \brief Node to represent a tensor */
class TensorNode : public FunctionBaseNode {
class TensorNode : public Node {
public:
/*! \brief The shape of the tensor */
Array<Expr> shape;
/*! \brief optional name of the tensor */
std::string name;
/*! \brief data type in the content of the tensor */
Type dtype;
/*! \brief the source operation, can be None */
......@@ -154,19 +143,11 @@ class TensorNode : public FunctionBaseNode {
void VisitAttrs(AttrVisitor* v) final {
v->Visit("shape", &shape);
v->Visit("name", &name);
v->Visit("dtype", &dtype);
v->Visit("op", &op);
v->Visit("value_index", &value_index);
}
const std::string& func_name() const final {
return name;
}
int outputs() const final {
return 1;
}
static Tensor make(Array<Expr> shape,
std::string name,
Type dtype,
Operation op,
int value_index);
......@@ -178,16 +159,18 @@ class TensorNode : public FunctionBaseNode {
/*!
* \brief base class of operation node.
*/
class OperationNode : public Node {
class OperationNode : public FunctionBaseNode {
public:
/*! \brief optional name of the operation */
std::string name;
/*! \return name of the operation */
const std::string& func_name() const final {
return name;
}
/*! \return number of outputs of this op */
virtual int num_outputs() const = 0;
/*! \return the list of iteration variable at root */
virtual Array<IterVar> root_iter_vars() const = 0;
/*! \return number of outputs of this op */
virtual size_t num_outputs() const = 0;
/*! \return name of i-th output */
virtual std::string output_name(size_t i) const = 0;
/*! \return type of i-th output */
virtual Type output_dtype(size_t i) const = 0;
/*! \return shape of i-th output */
......
......@@ -33,7 +33,7 @@ def Var(name="tindex", dtype=int32):
return _function_internal._Var(name, dtype)
def placeholder(shape, dtype = None, name="TensorObj"):
def placeholder(shape, dtype = None, name="placeholder"):
"""Construct an empty tensor object.
Parameters
......@@ -53,11 +53,11 @@ def placeholder(shape, dtype = None, name="TensorObj"):
The created tensor
"""
dtype = float32 if dtype is None else dtype
return _function_internal._Tensor(
shape, name, dtype, None, 0)
return _function_internal._Placeholder(
shape, dtype, name)
def compute(shape, fcompute, name="TensorCompute"):
def compute(shape, fcompute, name="compute"):
"""Construct a new tensor by computing over the shape domain.
The compute rule is result[axis] = fcompute(axis)
......
......@@ -34,7 +34,9 @@ class Tensor(NodeBase):
else:
raise ValueError("The indices must be expression")
return _make.Call(self.dtype, self.name, args, _expr.Call.Halide, self, 0)
return _make.Call(self.dtype, self.op.name,
args, _expr.Call.Halide,
self.op, self.value_index)
def __getitem__(self, indices):
return TensorSlice(self, indices)
......@@ -71,3 +73,7 @@ class Operation(NodeBase):
@register_node
class ComputeOp(Operation):
pass
@register_node
class PlaceholderOp(Operation):
pass
......@@ -29,6 +29,17 @@ TVM_REGISTER_API(_make_For)
args.at(5));
});
// Register an API entry named _make_Realize whose body forwards its six
// positional arguments, unchanged and in order, to Realize::make.
TVM_REGISTER_API(_make_Realize)
.set_body([](const ArgStack& args, RetValue *ret) {
*ret = Realize::make(args.at(0),
args.at(1),
args.at(2),
args.at(3),
args.at(4),
args.at(5));
});
TVM_REGISTER_API(_make_Call)
.set_body([](const ArgStack& args, RetValue *ret) {
*ret = Call::make(args.at(0),
......@@ -113,9 +124,8 @@ REGISTER_MAKE3(LetStmt);
REGISTER_MAKE2(AssertStmt);
REGISTER_MAKE3(ProducerConsumer);
REGISTER_MAKE3(Store);
REGISTER_MAKE3(Provide);
REGISTER_MAKE4(Provide);
REGISTER_MAKE1(Free);
// TODO(tqchen) Realize;
REGISTER_MAKE2(Block);
REGISTER_MAKE3(IfThenElse);
REGISTER_MAKE1(Evaluate);
......
......@@ -143,7 +143,6 @@ TVM_REGISTER_API(Range)
TVM_REGISTER_API(_Tensor)
.set_body([](const ArgStack& args, RetValue *ret) {
*ret = TensorNode::make(args.at(0),
args.at(1),
args.at(2),
args.at(3),
args.at(4));
......@@ -160,6 +159,13 @@ TVM_REGISTER_API(_TensorHash)
std::hash<Tensor>()(args.at(0).operator Tensor()));
});
// Register an API entry named _Placeholder whose body forwards its three
// positional arguments, in order, to Placeholder (shape, dtype, name).
TVM_REGISTER_API(_Placeholder)
.set_body([](const ArgStack& args, RetValue *ret) {
*ret = Placeholder(args.at(0),
args.at(1),
args.at(2));
});
TVM_REGISTER_API(_ComputeOp)
.set_body([](const ArgStack& args, RetValue *ret) {
*ret = ComputeOpNode::make(args.at(0),
......
......@@ -7,7 +7,6 @@
#include <tvm/ir.h>
#include <tvm/ir_pass.h>
#include "./c_api_registry.h"
#include "../schedule/bound.h"
namespace tvm {
namespace ir {
......@@ -36,6 +35,7 @@ using RetValue = APIVariantValue;
REGISTER_PASS1(ConvertSSA);
REGISTER_PASS1(VerifySSA);
REGISTER_PASS4(Inline);
REGISTER_PASS2(ScheduleOps);
} // namespace ir
} // namespace tvm
......@@ -6,8 +6,8 @@
#include <tvm/expr.h>
#include <tvm/tensor.h>
#include <tvm/schedule.h>
#include <tvm/schedule_pass.h>
#include "./c_api_registry.h"
#include "../schedule/bound.h"
#include "../schedule/graph.h"
namespace tvm {
......
......@@ -73,5 +73,4 @@ TVM_STATIC_IR_FUNCTOR(IRPrinter, vtable)
TVM_REGISTER_NODE_TYPE(IterVarNode);
} // namespace tvm
......@@ -36,7 +36,8 @@ TVM_STATIC_IR_FUNCTOR(IRPrinter, vtable)
TVM_STATIC_IR_FUNCTOR(IRPrinter, vtable)
.set_dispatch<AttrStmt>([](const AttrStmt *op, IRPrinter *p) {
p->stream << "attr " << op->type_key << " = ";
p->do_indent();
p->stream << "// attr " << op->type_key << " = ";
p->print(op->value);
p->stream << '\n';
p->print(op->body);
......
......@@ -9,11 +9,73 @@
namespace tvm {
/*!
 * \brief Get the i-th output of the operation as a Tensor.
 *
 * Builds a new TensorNode that records this operation as its source,
 * the requested output index, and the dtype and shape the operation
 * reports for that output.
 *
 * \param i The index of the requested output.
 * \return A Tensor referring to output i of this operation.
 */
Tensor Operation::output(size_t i) const {
  auto node = std::make_shared<TensorNode>();
  node->op = *this;
  // Record the output that was actually requested. Hard-coding 0 here would
  // make every output of a multi-output operation refer to output 0.
  node->value_index = static_cast<int>(i);
  node->dtype = (*this)->output_dtype(i);
  node->shape = (*this)->output_shape(i);
  return Tensor(node);
}
// PlaceholderOpNode
TVM_STATIC_IR_FUNCTOR(IRPrinter, vtable)
.set_dispatch<ComputeOpNode>([](const ComputeOpNode *op, IRPrinter *p) {
p->stream << "op(" << op << ")";
.set_dispatch<PlaceholderOpNode>([](const PlaceholderOpNode *op, IRPrinter *p) {
p->stream << "placeholder(" << op->name << ", " << op << ")";
});
TVM_REGISTER_NODE_TYPE(PlaceholderOpNode);
/*! \return an empty list of root iteration variables. */
Array<IterVar> PlaceholderOpNode::root_iter_vars() const {
  return Array<IterVar>();
}
/*!
 * \brief Data type of the placeholder output.
 * \param i Output index; checked to be 0.
 * \return The dtype stored on this node.
 */
Type PlaceholderOpNode::output_dtype(size_t i) const {
  CHECK_EQ(i, 0U);
  return this->dtype;
}
/*!
 * \brief Shape of the placeholder output.
 * \param i Output index; checked to be 0.
 * \return The shape stored on this node.
 */
Array<Expr> PlaceholderOpNode::output_shape(size_t i) const {
  CHECK_EQ(i, 0U);
  return this->shape;
}
/*!
 * \brief Create a placeholder operation.
 * \param name The name stored on the new node.
 * \param shape The shape stored on the new node.
 * \param dtype The data type stored on the new node.
 * \return An Operation wrapping the new PlaceholderOpNode.
 */
Operation PlaceholderOpNode::make(std::string name,
                                  Array<Expr> shape,
                                  Type dtype) {
  std::shared_ptr<PlaceholderOpNode> node =
      std::make_shared<PlaceholderOpNode>();
  node->dtype = dtype;
  node->shape = shape;
  node->name = name;
  return Operation(node);
}
/*!
 * \brief Create a placeholder tensor.
 *
 * Makes a placeholder operation from the arguments and returns its
 * output 0.
 */
Tensor Placeholder(Array<Expr> shape, Type dtype, std::string name) {
  Operation placeholder_op = PlaceholderOpNode::make(name, shape, dtype);
  return placeholder_op.output(0);
}
// ComputeOpNode
/*! \return the axis of the compute op as its root iteration variables. */
Array<IterVar> ComputeOpNode::root_iter_vars() const {
  return this->axis;
}
/*!
 * \brief Data type of the compute output.
 * \param i Output index; checked to be 0.
 * \return The type of the body expression.
 */
Type ComputeOpNode::output_dtype(size_t i) const {
  CHECK_EQ(i, 0U);
  return this->body.type();
}
/*!
 * \brief Shape of the compute output.
 * \param i Output index; checked to be 0.
 * \return The extent of each axis domain, in axis order.
 */
Array<Expr> ComputeOpNode::output_shape(size_t i) const {
  CHECK_EQ(i, 0U);
  std::vector<Expr> extents;
  // One entry per axis: the extent of that axis' domain.
  for (IterVar axis_iv : axis) {
    extents.push_back(axis_iv->dom->extent);
  }
  return Array<Expr>(extents);
}
Tensor Compute(Array<Expr> shape, FCompute fcompute, std::string name) {
auto op_node = std::make_shared<ComputeOpNode>();
// compute dimension.
......@@ -43,39 +105,10 @@ Operation ComputeOpNode::make(std::string name,
return Operation(n);
}
Tensor Operation::output(size_t i) const {
auto node = std::make_shared<TensorNode>();
node->op = *this;
node->value_index = 0;
node->name = (*this)->output_name(i);
node->dtype = (*this)->output_dtype(i);
node->shape = (*this)->output_shape(i);
return Tensor(node);
}
Array<IterVar> ComputeOpNode::root_iter_vars() const {
return axis;
}
std::string ComputeOpNode::output_name(size_t i) const {
CHECK_EQ(i, 0U);
return name;
}
Type ComputeOpNode::output_dtype(size_t i) const {
CHECK_EQ(i, 0U);
return body.type();
}
Array<Expr> ComputeOpNode::output_shape(size_t i) const {
CHECK_EQ(i, 0U);
std::vector<Expr> shape;
for (size_t i = 0; i < axis.size(); ++i) {
const Range& r = axis[i]->dom;
shape.push_back(r->extent);
}
return Array<Expr>(shape);
}
TVM_STATIC_IR_FUNCTOR(IRPrinter, vtable)
.set_dispatch<ComputeOpNode>([](const ComputeOpNode *op, IRPrinter *p) {
p->stream << "compute(" << op->name << ", " << op << ")";
});
TVM_REGISTER_NODE_TYPE(ComputeOpNode);
......
......@@ -8,33 +8,24 @@
namespace tvm {
Tensor::Tensor(Array<Expr> shape, std::string name, Type dtype) {
auto node = std::make_shared<TensorNode>();
node->name = std::move(name);
node->dtype = dtype;
node->shape = std::move(shape);
node_ = std::move(node);
}
Expr Tensor::operator()(Array<Expr> indices) const {
using Halide::Internal::Call;
CHECK_EQ(ndim(), indices.size())
<< "Tensor dimension mismatch in read"
<< "ndim = " << ndim() << ", indices.size=" << indices.size();
auto n = Call::make(
(*this)->dtype, (*this)->name, indices, Call::Halide, *this);
(*this)->dtype, (*this)->op->name, indices, Call::Halide,
(*this)->op, (*this)->value_index);
return n;
}
Tensor TensorNode::make(Array<Expr> shape,
std::string name,
Type dtype,
Operation op,
int value_index) {
auto n = std::make_shared<TensorNode>();
n->shape = shape;
n->name = name;
n->dtype = dtype;
n->op = op;
n->value_index = value_index;
......@@ -44,7 +35,7 @@ Tensor TensorNode::make(Array<Expr> shape,
TVM_STATIC_IR_FUNCTOR(IRPrinter, vtable)
.set_dispatch<TensorNode>([](const TensorNode *t, IRPrinter *p) {
p->stream << "Tensor(shape=" << t->shape
<< ", name=" << t->name << ')';
<< ", op.name=" << t->op->name << ')';
});
TVM_REGISTER_NODE_TYPE(TensorNode);
......
......@@ -22,6 +22,7 @@ class IRInline : public IRMutator {
expr = IRMutator::Mutate(expr);
const Call* call = expr.as<Call>();
if (call != nullptr && call->func == f_) {
CHECK_EQ(call->value_index, 0);
return InlineCall(call);
} else {
return expr;
......@@ -55,6 +56,8 @@ Stmt Inline(FunctionRef f,
Array<Var> args,
Expr body,
Stmt stmt) {
CHECK_EQ(f->num_outputs(), 1)
<< "can only inline output single value operation";
return ConvertSSA(IRInline(f, args, body).Mutate(stmt));
}
} // namespace ir
......
......@@ -254,11 +254,11 @@ TVM_STATIC_IR_FUNCTOR(IRMutator, vtable_stmt)
})
.set_dispatch<Provide>([](const Provide *op, const Stmt& s, IRMutator* m) {
auto new_args = MutateArray(op->args, m);
auto new_values = MutateArray(op->values, m);
if (op->args.same_as(new_args) && op->values.same_as(new_values)) {
auto new_value = m->Mutate(op->value);
if (op->args.same_as(new_args) && op->value.same_as(new_value)) {
return s;
} else {
return Provide::make(op->func, new_values, new_args);
return Provide::make(op->func, op->value_index, new_value, new_args);
}
})
.set_dispatch<Allocate>([](const Allocate *op, const Stmt& s, IRMutator* m) {
......@@ -312,7 +312,8 @@ TVM_STATIC_IR_FUNCTOR(IRMutator, vtable_stmt)
condition.same_as(op->condition)) {
return s;
} else {
return Realize::make(op->func, op->types, new_bounds,
return Realize::make(op->func, op->value_index,
op->type, new_bounds,
condition, body);
}
})
......@@ -329,7 +330,10 @@ TVM_STATIC_IR_FUNCTOR(IRMutator, vtable_stmt)
.set_dispatch<IfThenElse>([](const IfThenElse *op, const Stmt& s, IRMutator* m) {
Expr condition = m->Mutate(op->condition);
Stmt then_case = m->Mutate(op->then_case);
Stmt else_case = m->Mutate(op->else_case);
// Only mutate the else branch when the original statement has one.
// The guard must test op->else_case: the local below is default-constructed,
// so testing it would always skip the mutation and lose the else branch
// whenever the statement is rebuilt.
Stmt else_case;
if (op->else_case.defined()) {
  else_case = m->Mutate(op->else_case);
}
if (condition.same_as(op->condition) &&
then_case.same_as(op->then_case) &&
else_case.same_as(op->else_case)) {
......
......@@ -157,7 +157,7 @@ TVM_STATIC_IR_FUNCTOR(IRVisitor, vtable)
})
.set_dispatch<Provide>([](const Provide *op, IRVisitor* v) {
VisitArray(op->args, v);
VisitArray(op->values, v);
v->Visit(op->value);
})
.set_dispatch<Allocate>([](const Allocate *op, IRVisitor* v) {
for (size_t i = 0; i < op->extents.size(); i++) {
......
......@@ -6,7 +6,10 @@
#include <tvm/ir_mutator.h>
#include <tvm/ir_pass.h>
#include <tvm/ir_visitor.h>
#include <tvm/schedule_pass.h>
#include "./scope.h"
#include "../schedule/graph.h"
namespace tvm {
namespace ir {
......@@ -20,7 +23,7 @@ namespace {
* IterVar->The assignment.
*/
void PassUpOffset(const Schedule& s,
const std::unordered_map<IterVar, Range>& dom_map,
const Map<IterVar, Range>& dom_map,
std::unordered_map<IterVar, Expr>* p_state) {
auto& state = *p_state;
for (size_t i = s->relations.size(); i != 0; --i) {
......@@ -28,8 +31,8 @@ void PassUpOffset(const Schedule& s,
if (rel.as<SplitNode>()) {
const SplitNode* s = rel.as<SplitNode>();
Expr outer = state.at(s->outer);
Expr inner = state.at(s->outer);
Expr factor = dom_map.at(s->outer)->extent;
Expr inner = state.at(s->inner);
Expr factor = dom_map.at(s->inner)->extent;
Expr offset = inner + outer * factor;
Expr outer_min = dom_map.at(s->parent)->min;
if (!is_zero(outer_min)) {
......@@ -39,7 +42,7 @@ void PassUpOffset(const Schedule& s,
} else if (rel.as<FuseNode>()) {
const FuseNode* s = rel.as<FuseNode>();
Expr value = state.at(s->fused);
Expr factor = dom_map.at(s->outer)->extent;
Expr factor = dom_map.at(s->inner)->extent;
state[s->outer] = value / factor;
state[s->inner] = value % factor;
} else {
......@@ -84,26 +87,37 @@ void SplitByAdd(Expr expr,
* \param nest A list of For and LetStmt, whose body is not defined.
* \param body body
*/
Stmt CombineNest(std::vector<Stmt>&& nest, Stmt body) {
while (!nest.empty()) {
Stmt s = std::move(nest.back());
nest.pop_back();
Stmt MergeNest(std::vector<std::vector<Stmt> > nest, Stmt body) {
  // Walk the levels, and the entries within each level, back to front:
  // each entry is copied and the body built so far is placed inside it,
  // so the first entry of the first level ends up outermost.
  for (auto level = nest.rbegin(); level != nest.rend(); ++level) {
    for (auto entry = level->rbegin(); entry != level->rend(); ++entry) {
      Stmt s = *entry;
      if (const For* for_op = s.as<For>()) {
        auto n = std::make_shared<For>(*for_op);
        CHECK(is_no_op(n->body));
        n->body = body;
        body = Stmt(n);
      } else if (const LetStmt* let_op = s.as<LetStmt>()) {
        auto n = std::make_shared<LetStmt>(*let_op);
        CHECK(is_no_op(n->body));
        n->body = body;
        body = Stmt(n);
      } else if (const AttrStmt* attr_op = s.as<AttrStmt>()) {
        auto n = std::make_shared<AttrStmt>(*attr_op);
        CHECK(is_no_op(n->body));
        n->body = body;
        body = Stmt(n);
      } else if (const IfThenElse* if_op = s.as<IfThenElse>()) {
        auto n = std::make_shared<IfThenElse>(*if_op);
        // The body goes into the then-branch; an else-branch is not allowed.
        CHECK(is_no_op(n->then_case));
        CHECK(!n->else_case.defined());
        n->then_case = body;
        body = Stmt(n);
      } else {
        LOG(FATAL) << "not supported nest type";
      }
    }
  }
  return body;
}
......@@ -111,119 +125,251 @@ Stmt CombineNest(std::vector<Stmt>&& nest, Stmt body) {
* \brief Make the loop nest of the corresponding schedule.
* \param sch The schedule.
* \param dom_map The domain map.
*
* \return a nested representation of loop statements.
* The flattened Stmts are ordered from outermost to innermost.
*/
std::vector<Stmt> MakeLoopNest(
std::vector<std::vector<Stmt> > MakeLoopNest(
const Schedule& sch,
const std::unordered_map<IterVar, Range>& dom_map) {
const Map<IterVar, Range>& dom_map) {
// optional, use let to define some CSE in dom_map.
auto leaf_iter_vars = sch->leaf_iter_vars;
std::unordered_map<IterVar, Expr> offset;
std::unordered_map<const Variable*, size_t> loop_level;
Stmt no_op = Evaluate::make(0);
// create the loop nest
std::vector<Stmt> nest;
nest.resize(leaf_iter_vars.size() + 1, Stmt());
std::vector<std::vector<Stmt> > nest;
nest.resize(leaf_iter_vars.size() + 1);
for (size_t i = 0; i < leaf_iter_vars.size(); ++i) {
auto iv = leaf_iter_vars[i];
// initialize the offset and loop_level
offset[iv] = iv->var;
loop_level[iv->var.as<Variable>()] = i + 1;
nest[i] = AttrStmt::make(iv->var, "scope", iv, Stmt());
// Mark the iter var in the IR, to remember the point
if (iv->thread_tag.length() == 0) {
Range dom = dom_map.at(iv);
nest[i] = For::make(iv->var, dom->min, dom->extent,
ForType::Serial, DeviceAPI::None, nest[i]);
nest[i + 1].emplace_back(
For::make(iv->var, dom->min, dom->extent,
ForType::Serial, DeviceAPI::None, no_op));
}
nest[i + 1].emplace_back(
AttrStmt::make(iv, "scope", iv->var, no_op));
}
// message passing to get offset of root iter vars.
PassUpOffset(sch, dom_map, &offset);
for (IterVar iv : sch->op->root_iter_vars()) {
Expr value = offset.at(iv);
if (value.same_as(iv->var)) continue;
if (!value.same_as(iv->var)) {
using Entry = std::pair<size_t, Expr>;
std::vector<Entry> splits;
SplitByAdd(value, loop_level, &splits);
Expr offset = 0;
size_t nsplit_left = splits.size() - 1;
for (size_t i = 0; i <= leaf_iter_vars.size(); ++i) {
auto iv = leaf_iter_vars[i];
size_t hit = 0;
for (const auto& kv : splits) {
if (kv.first == i) {
offset = offset + splits[i].second;
if (is_zero(offset)) {
offset = kv.second;
} else {
offset = offset + kv.second;
++hit;
}
}
}
nsplit_left -= hit;
if (hit != 0) {
std::ostringstream os;
os << iv->var->name_hint << ".at.l" << i;
Var base_offset(os.str());
nest[i] = LetStmt::make(base_offset, offset, nest[i]);
if (nsplit_left == 0) {
base_offset = iv->var;
}
nest[i].emplace_back(
LetStmt::make(base_offset, offset, no_op));
offset = base_offset;
}
nest.back() = LetStmt::make(iv->var, offset, nest.back());
}
Range dom = dom_map.at(iv);
if (!offset.same_as(iv->var)) {
// define the iv->var
nest.back().emplace_back(
LetStmt::make(iv->var, offset, no_op));
}
Expr condition = (iv->var - dom->min) < dom->extent;
// Boundary condition checking
// Need better boundary condition here.
nest.back().emplace_back(IfThenElse::make(condition, no_op));
}
}
return nest;
}
/*!
* \brief Make the loop nest of the correspondings schedule.
* \param op The operation.
* \brief Make pipeline specifically for compute op node.
* \param op The compute node
* \param tensors The tensors generated by provide.
*/
Stmt MakeBody(const Operation& op) {
Stmt body;
if (op.as<ComputeOpNode>()) {
const ComputeOpNode* compute = op.as<ComputeOpNode>();
// Note: Tensor's address cannot uniquely
Tensor t = op.output(0);
Stmt MakeProvide(const ComputeOpNode* op,
const std::vector<Tensor>& tensors) {
Tensor t = tensors[0];
Array<Expr> args;
for (IterVar iv : compute->axis) {
for (IterVar iv : op->axis) {
args.push_back(iv->var);
}
body = Provide::make(t, {compute->body}, args);
return Provide::make(t->op, t->value_index, op->body, args);
}
/*!
* \brief Make pipeline specifically for compute op node.
* \param op The compute node
* \param dom_map The domain map
* \param tensors The tensors generated by provide.
* \param body The content of the pipeline.
*/
Stmt MakeRealize(const ComputeOpNode* op,
                 const Map<IterVar, Range>& dom_map,
                 const std::vector<Tensor>& tensors,
                 Stmt body) {
  // Only the first tensor is realized here.
  const Tensor& out = tensors[0];
  // Collect the domain of every axis of the op, in axis order.
  Halide::Internal::Region region;
  for (IterVar axis_iv : op->axis) {
    region.push_back(dom_map.at(axis_iv));
  }
  // The realize condition is the boolean constant true.
  auto always_true = make_const(Bool(1), true);
  return Realize::make(out->op, out->value_index, out->dtype,
                       region, always_true, body);
}
Stmt MakePipeline(const Schedule& sch,
const Map<IterVar, Range>& dom_map,
Stmt consumer) {
std::vector<Tensor> tensors;
for (int i = 0; i < sch->op->num_outputs(); ++i) {
tensors.emplace_back(sch->op.output(i));
}
Stmt provide;
if (sch->op.as<ComputeOpNode>()) {
provide = MakeProvide(sch->op.as<ComputeOpNode>(), tensors);
} else {
LOG(FATAL) << "not supported op";
}
return body;
}
std::vector<std::vector<Stmt> > nest = MakeLoopNest(sch, dom_map);
Stmt producer = MergeNest(nest, provide);
producer = ProducerConsumer::make(sch->op, true, producer);
Stmt MakePipeline(const Schedule& sch, Stmt body) {
return body;
Stmt pipeline = producer;
if (consumer.defined()) {
consumer = ProducerConsumer::make(sch->op, false, consumer);
pipeline = Block::make(producer, consumer);
}
if (sch->op.as<ComputeOpNode>()) {
return MakeRealize(sch->op.as<ComputeOpNode>(),
dom_map, tensors, pipeline);
} else {
LOG(FATAL) << "not supported op";
return Stmt();
}
}
// inject the operator's realization on the stmt.
class InjectRealize : public IRMutator {
public:
explicit InjectRealize(Schedule sch)
: sch_(sch) {}
InjectRealize(Schedule schedule, Map<IterVar, Range> dom_map)
: schedule(schedule), dom_map(dom_map) {}
Stmt Mutate(Stmt stmt) final {
const AttrStmt* op = stmt.as<AttrStmt>();
if (op != nullptr) {
attr_scope_.Push({op->node, op->type_key}, op->value);
stmt = IRMutator::Mutate(stmt);
attr_scope_.Pop({op->node, op->type_key});
} else {
CHECK(stmt.defined());
stmt = IRMutator::Mutate(stmt);
}
const AttrStmt* op = stmt.as<AttrStmt>();
if (op != nullptr &&
op->type_key == "scope" &&
op->node == sch_->attach_parent) {
return AttrStmt::make(
op->type_key == "scope") {
if (op->node == schedule->attach_parent) {
CHECK(!found_attach);
found_attach = true;
stmt = AttrStmt::make(
op->node, op->type_key, op->value,
MakePipeline(sch_, op->body));
} else {
return stmt;
MakePipeline(schedule, dom_map,
IRMutator::Mutate(op->body)));
}
}
private:
return stmt;
}
// the operations to be carried
Schedule sch_;
Scope<AttrKey, Expr> attr_scope_;
Schedule schedule;
// domain map
Map<IterVar, Range> dom_map;
// whether attach point is found
bool found_attach{false};
};
/*!
 * \brief Record the schedule of s->op and, recursively, of all children of s.
 * \param s The schedule to start from.
 * \param ret The map from operation to schedule that is filled in;
 *  an operation that is already present triggers a check failure.
 */
void GetOpToScheduleMap(
    Schedule s,
    std::unordered_map<Operation, Schedule>* ret) {
  std::unordered_map<Operation, Schedule>& table = *ret;
  CHECK(!table.count(s->op))
      << "Duplicated schedule for op";
  table[s->op] = s;
  for (Schedule child : s->children) {
    GetOpToScheduleMap(child, ret);
  }
}
// order schedule by DFS calling order of ops
std::vector<Schedule> OrderSchedule(Schedule s) {
  auto read_graph = schedule::CreateReadGraph(s->op);
  auto post_order = schedule::PostDFSOrder(s->op, read_graph);
  std::unordered_map<Operation, Schedule> op2sch;
  GetOpToScheduleMap(s, &op2sch);
  // Emit the schedules in the reverse of the post-DFS order of their ops.
  std::vector<Schedule> ordered;
  size_t remaining = post_order.size();
  while (remaining != 0) {
    --remaining;
    ordered.push_back(op2sch.at(post_order[remaining]));
  }
  return ordered;
}
/*!
 * \brief Inline a compute operation into a statement.
 * \param op The operation to inline; it must be a ComputeOpNode.
 * \param body The statement to inline into; it must be defined.
 * \return The result of Inline called with the op, its axis variables,
 *  its body expression and the given statement.
 */
Stmt InjectInline(const Operation op, Stmt body) {
  CHECK(body.defined());
  const ComputeOpNode* compute_op = op.as<ComputeOpNode>();
  CHECK(compute_op != nullptr)
      << "can only inline compute op";
  // The inline arguments are the variables of the compute axis.
  Array<Var> axis_vars;
  for (auto axis_iv : compute_op->axis) {
    axis_vars.push_back(axis_iv->var);
  }
  return Inline(op, axis_vars, compute_op->body, body);
}
} // namespace
Stmt ScheduleOps(
    Schedule s, Map<IterVar, Range> dom_map) {
  // Visit the schedules in the order produced by OrderSchedule and
  // grow the statement one schedule at a time.
  Stmt result;
  for (Schedule sch : OrderSchedule(s)) {
    const auto attach = sch->attach_type;
    if (attach == kInline) {
      result = InjectInline(sch->op, result);
    } else if (attach == kRoot || attach == kNone) {
      result = MakePipeline(sch, dom_map, result);
    } else if (attach == kScope) {
      // The pipeline is injected at the attach point inside the statement
      // built so far, so that statement must already exist.
      CHECK(result.defined());
      InjectRealize mutator(sch, dom_map);
      result = mutator.Mutate(result);
      CHECK(mutator.found_attach)
          << "did not find attachment point";
    }
  }
  return result;
}
} // namespace ir
} // namespace tvm
......@@ -36,7 +36,7 @@ class Scope {
*/
inline void Pop(const K& key) {
auto& v = data_[key];
CHECK_NE(v.size(), 0);
CHECK_NE(v.size(), 0U);
v.pop_back();
}
......
......@@ -5,8 +5,8 @@
*/
#include <tvm/ir.h>
#include <tvm/ir_visitor.h>
#include <tvm/schedule_pass.h>
#include "./int_set.h"
#include "./bound.h"
#include "./graph.h"
namespace tvm {
......@@ -113,7 +113,7 @@ void PassToOperation(
(*result)[root_iter_vars[i]].push_back(dim_bounds[i]);
}
} else {
LOG(FATAL) << "unknown operation mode";
LOG(FATAL) << "unknown operation mode " << tensor->op->type_key();
}
}
......@@ -140,8 +140,8 @@ BoundProp(const Array<Operation>& post_order,
auto fvisit = [p_state, &result](const NodeRef& n) {
auto *call = n.as<ir::Call>();
if (call != nullptr && call->func.defined()) {
Tensor t(call->func.node_);
if (t->op.defined()) {
Tensor t = Operation(call->func.node_).output(call->value_index);
if (t->op.defined() && !t->op.as<PlaceholderOpNode>()) {
std::vector<IntSet> arg_bounds;
for (size_t i = 0; i < t.ndim(); ++i) {
arg_bounds.push_back(EvalSet(call->args[i], result));
......
......@@ -27,18 +27,20 @@ ReadGraph CreateReadGraph(const Operation& root) {
auto fvisit = [&deps, &visited, &stack](const NodeRef& n) {
auto *call = n.as<ir::Call>();
if (call != nullptr && call->func.defined()) {
Tensor t(call->func.node_);
deps.push_back(t);
if (t->op.defined() && visited.count(t->op.get()) == 0) {
visited.insert(t->op.get());
stack.push_back(t->op);
Operation call_op(call->func.node_);
deps.push_back(call_op.output(call->value_index));
if (call_op.defined() && visited.count(call_op.get()) == 0) {
visited.insert(call_op.get());
stack.push_back(call_op);
}
}
};
ir::PostOrderVisit(op.as<ComputeOpNode>()->body, fvisit);
rmap.Set(op, deps);
} else {
LOG(FATAL) << "unknown operation mode";
if (!op.as<PlaceholderOpNode>()) {
LOG(FATAL) << "unknown Operation" << op->type_key();
}
}
}
return rmap;
......@@ -51,7 +53,7 @@ void PostDFSOrder(const Operation& op,
Array<Operation>* post_order) {
visited->insert(op);
for (const auto& t : g.at(op)) {
if (t->op.defined() && !visited->count(t->op)) {
if (!t->op.as<PlaceholderOpNode>() && !visited->count(t->op)) {
PostDFSOrder(t->op, g, visited, post_order);
}
}
......
......@@ -220,7 +220,7 @@ void PassUp(const SplitNode* s,
*parent = IntSet::make_range(dom_map.at(s->parent));
return;
}
Expr factor = dom_map.at(s->outer)->extent;
Expr factor = dom_map.at(s->inner)->extent;
CHECK(outer.defined());
CHECK(inner.defined());
CHECK(factor.defined());
......@@ -261,7 +261,7 @@ void PassUp(const FuseNode* s,
if (IsNumber(fused)) {
Expr value = AsNumber(fused);
Expr factor = dom_map.at(s->outer)->extent;
Expr factor = dom_map.at(s->inner)->extent;
*outer = IntSet::make_point(value / factor);
*inner = IntSet::make_point(value % factor);
} else {
......
......@@ -5,8 +5,9 @@
TEST(Tensor, Basic) {
using namespace tvm;
Var m("m"), n("n"), l("l");
Tensor A({m, l}, "A");
Tensor B({n, l}, "B");
Tensor A = Placeholder({m, l}, Float(32), "A");
Tensor B = Placeholder({n, l}, Float(32), "B");
auto C = Compute({m, n}, [&](Var i, Var j) {
return A[i][j];
......@@ -19,8 +20,8 @@ TEST(Tensor, Basic) {
TEST(Tensor, Reduce) {
using namespace tvm;
Var m("m"), n("n"), l("l");
Tensor A({m, l}, "A");
Tensor B({n, l}, "B");
Tensor A = Placeholder({m, l}, Float(32), "A");
Tensor B = Placeholder({n, l}, Float(32), "B");
IterVar rv(Range{0, l}, "k");
auto C = Compute({m, n}, [&](Var i, Var j) {
......
......@@ -10,7 +10,7 @@ def test_tensor():
print(T)
print(T.op.body)
assert(tuple(T.shape) == (m, n, l))
assert(A.op is None)
assert(isinstance(A.op, tvm.tensor.PlaceholderOp))
assert(A == A)
assert(T.op.output(0) == T)
assert(T.op.output(0).__hash__() == T.__hash__())
......
......@@ -6,7 +6,7 @@ def test_inline():
T = tvm.compute((m,), lambda i,: A[i] + 10, name='T')
stmt = tvm.make.Evaluate(T[10] + 11 * T[100])
stmt = tvm.ir_pass.Inline(
T, [x.var for x in T.op.axis], T.op.body, stmt)
T.op, [x.var for x in T.op.axis], T.op.body, stmt)
print(stmt)
assert(tvm.ir_pass.VerifySSA(stmt))
......@@ -14,7 +14,7 @@ def test_inline():
# pass in int array(wrong argument type)
# must raise an error
stmt = tvm.ir_pass.Inline(
T, [1,2,3], T.op.body, stmt)
T.op, [1,2,3], T.op.body, stmt)
assert False
except tvm.TVMError:
pass
......
import tvm
def test_schedule0():
    """Run InferBound and ScheduleOps on a single compute op with no split."""
    rows = tvm.Var('m')
    cols = tvm.Var('l')
    A = tvm.placeholder((rows, cols), name='A')
    A1 = tvm.compute((rows, cols), lambda i, j: A[i, j], name='A1')
    sched = tvm.Schedule(A1.op)
    dom_map = tvm.schedule.InferBound(sched)
    assert isinstance(dom_map, tvm.collections.Map)
    lowered = tvm.ir_pass.ScheduleOps(sched, dom_map)
    print(lowered)
def test_schedule1():
    """Run InferBound and ScheduleOps after splitting the first axis by 8."""
    rows = tvm.Var('m')
    cols = tvm.Var('l')
    A = tvm.placeholder((rows, cols), name='A')
    A1 = tvm.compute((rows, cols), lambda i, j: A[i, j], name='A1')
    sched = tvm.Schedule(A1.op)
    outer, inner = sched.split(A1.op.axis[0], 8)
    dom_map = tvm.schedule.InferBound(sched)
    assert isinstance(dom_map, tvm.collections.Map)
    lowered = tvm.ir_pass.ScheduleOps(sched, dom_map)
    print(lowered)
def test_schedule2():
    """Run InferBound and ScheduleOps with A1 computed at the outer split axis of A2."""
    rows = tvm.Var('m')
    cols = tvm.Var('l')
    A = tvm.placeholder((rows, cols), name='A')
    A1 = tvm.compute((rows, cols), lambda i, j: A[i, j], name='A1')
    A2 = tvm.compute((rows, cols), lambda i, j: A1[i, j] + 3, name='A2')
    producer_sched = tvm.Schedule(A1.op)
    consumer_sched = tvm.Schedule(A2.op)
    outer, inner = consumer_sched.split(A2.op.axis[0], 8)
    producer_sched.compute_at(consumer_sched, outer)
    dom_map = tvm.schedule.InferBound(consumer_sched)
    assert isinstance(dom_map, tvm.collections.Map)
    lowered = tvm.ir_pass.ScheduleOps(consumer_sched, dom_map)
    print(lowered)
if __name__ == "__main__":
test_schedule0()
test_schedule1()
test_schedule2()
......@@ -65,7 +65,7 @@ def test_create_read_graph():
if __name__ == "__main__":
test_create_read_graph()
test_bound3()
test_bound1()
test_bound2()
test_create_read_graph()
Markdown is supported
0% or
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment