Commit 78ea652d by Tianqi Chen Committed by Haichen Shen

[PASS] Schedule Ops init working version (#6)

* [PASS] Schedule Ops init working version

* bugfix in PassUp
parent 302c2e64
Subproject commit 5d1bd103c2abe19392b4d8def7e3ff1c854e8683 Subproject commit 1ec478bbd0c20b8659f0c897363b5a76e13ef495
...@@ -17,6 +17,7 @@ namespace tvm { ...@@ -17,6 +17,7 @@ namespace tvm {
using Halide::Type; using Halide::Type;
using Halide::Float; using Halide::Float;
using Halide::Bool;
using Halide::Int; using Halide::Int;
using Halide::UInt; using Halide::UInt;
using Halide::Handle; using Halide::Handle;
...@@ -29,6 +30,8 @@ using Halide::Internal::Stmt; ...@@ -29,6 +30,8 @@ using Halide::Internal::Stmt;
using Halide::Internal::IRPrinter; using Halide::Internal::IRPrinter;
using Halide::Internal::Variable; using Halide::Internal::Variable;
using Halide::Internal::make_const;
/*! \brief a named variable in TVM */ /*! \brief a named variable in TVM */
class Var : public Halide::VarExpr { class Var : public Halide::VarExpr {
public: public:
......
...@@ -18,6 +18,16 @@ ...@@ -18,6 +18,16 @@
namespace tvm { namespace tvm {
namespace ir { namespace ir {
/*!
* \brief Schedule s' dependent operations.
*
* \param s The schedule to be realized
* \param dom_map The domain of each iter vars.
* \return the result Stmt
*/
Stmt ScheduleOps(Schedule s, Map<IterVar, Range> dom_map);
/*! /*!
* \brief verifies whether the IR stmt or Expr is in SSA form. * \brief verifies whether the IR stmt or Expr is in SSA form.
* That is: each VarExpr is defined and assigned once(in Let/For) * That is: each VarExpr is defined and assigned once(in Let/For)
...@@ -51,14 +61,6 @@ Stmt Inline(FunctionRef f, ...@@ -51,14 +61,6 @@ Stmt Inline(FunctionRef f,
Expr body, Expr body,
Stmt stmt); Stmt stmt);
/*!
* \brief Schedule s' dependent operations.
*
* \param s The schedule to be realized
* \return the result Stmt
*/
Stmt ScheduelOps(Schedule s);
} // namespace ir } // namespace ir
} // namespace tvm } // namespace tvm
......
...@@ -13,6 +13,36 @@ ...@@ -13,6 +13,36 @@
namespace tvm { namespace tvm {
/*! /*!
* \brief A placeholder op represents an input placeholder.
*/
class PlaceholderOpNode : public OperationNode {
public:
/*! \brief The shape of the input */
Array<Expr> shape;
/*! \brief The data type of the input. */
Type dtype;
int num_outputs() const final {
return 1;
}
Array<IterVar> root_iter_vars() const final;
Type output_dtype(size_t i) const final;
Array<Expr> output_shape(size_t i) const final;
void VisitAttrs(AttrVisitor* v) final {
v->Visit("name", &name);
v->Visit("shape", &shape);
v->Visit("dtype", &dtype);
}
static Operation make(std::string name,
Array<Expr> shape,
Type dtype);
static constexpr const char* _type_key = "PlaceholderOp";
TVM_DECLARE_NODE_TYPE_INFO(PlaceholderOpNode);
};
/*!
* \brief A Compute op that compute a tensor on certain domain. * \brief A Compute op that compute a tensor on certain domain.
*/ */
class ComputeOpNode : public OperationNode { class ComputeOpNode : public OperationNode {
...@@ -24,11 +54,10 @@ class ComputeOpNode : public OperationNode { ...@@ -24,11 +54,10 @@ class ComputeOpNode : public OperationNode {
/*! \brief constructor */ /*! \brief constructor */
ComputeOpNode() {} ComputeOpNode() {}
size_t num_outputs() const final { int num_outputs() const final {
return 1; return 1;
} }
Array<IterVar> root_iter_vars() const final; Array<IterVar> root_iter_vars() const final;
std::string output_name(size_t i) const final;
Type output_dtype(size_t i) const final; Type output_dtype(size_t i) const final;
Array<Expr> output_shape(size_t i) const final; Array<Expr> output_shape(size_t i) const final;
...@@ -50,6 +79,16 @@ class ComputeOpNode : public OperationNode { ...@@ -50,6 +79,16 @@ class ComputeOpNode : public OperationNode {
using FCompute = std::function<Expr (const Array<Var>& i)>; using FCompute = std::function<Expr (const Array<Var>& i)>;
/*! /*!
* \brief create a place holder tensor.
* \param shape The shape of the tensor.
* \param dtype the data type of the tensor.
* \param name The name of the Tensor.
*/
Tensor Placeholder(Array<Expr> shape,
Type dtype = Float(32),
std::string name = "placeholder");
/*!
* \brief Construct a new tensor by computing over shape, * \brief Construct a new tensor by computing over shape,
* using the computation rule: result_tensor[axis] = fcompute(axis) * using the computation rule: result_tensor[axis] = fcompute(axis)
* \param shape Shape of the tensor. * \param shape Shape of the tensor.
......
/*! /*!
* Copyright (c) 2016 by Contributors * Copyright (c) 2016 by Contributors
* \file bound.h * \file schedule_pass.h
* \brief The bound inference logics on the schedule. * \brief Collection of Schedule pass functions.
*
* These passes works on the schedule hyper-graph
* and infers information such as bounds, check conditions
* read/write dependencies between the IterVar
*/ */
#ifndef TVM_SCHEDULE_BOUND_H_ #ifndef TVM_SCHEDULE_PASS_H_
#define TVM_SCHEDULE_BOUND_H_ #define TVM_SCHEDULE_PASS_H_
#include <tvm/expr.h> #include "./base.h"
#include <tvm/schedule.h> #include "./schedule.h"
#include <unordered_map>
namespace tvm { namespace tvm {
namespace schedule { namespace schedule {
...@@ -23,5 +26,4 @@ Map<IterVar, Range> InferBound(Schedule sch); ...@@ -23,5 +26,4 @@ Map<IterVar, Range> InferBound(Schedule sch);
} // namespace schedule } // namespace schedule
} // namespace tvm } // namespace tvm
#endif // TVM_SCHEDULE_PASS_H_
#endif // TVM_SCHEDULE_BOUND_H_
...@@ -28,20 +28,11 @@ using Halide::IR::FunctionRef; ...@@ -28,20 +28,11 @@ using Halide::IR::FunctionRef;
* \brief Tensor structure representing a possible input, * \brief Tensor structure representing a possible input,
* or intermediate computation result. * or intermediate computation result.
*/ */
class Tensor : public FunctionRef { class Tensor : public NodeRef {
public: public:
/*! \brief default constructor, used internally */ /*! \brief default constructor, used internally */
Tensor() {} Tensor() {}
explicit Tensor(std::shared_ptr<Node> n) : FunctionRef(n) {} explicit Tensor(std::shared_ptr<Node> n) : NodeRef(n) {}
/*!
* \brief constructor of input tensor
* \param shape Shape of the tensor.
* \param name optional name of the Tensor.
* \param dtype The data type of the input tensor.
*/
explicit Tensor(Array<Expr> shape,
std::string name = "tensor",
Type dtype = Float(32));
/*! /*!
* \brief access the internal node container * \brief access the internal node container
* \return the pointer to the internal node container * \return the pointer to the internal node container
...@@ -116,11 +107,11 @@ class Tensor : public FunctionRef { ...@@ -116,11 +107,11 @@ class Tensor : public FunctionRef {
}; };
/*! \brief Operation that produces tensors */ /*! \brief Operation that produces tensors */
class Operation : public NodeRef { class Operation : public FunctionRef {
public: public:
/*! \brief default constructor */ /*! \brief default constructor */
Operation() {} Operation() {}
explicit Operation(std::shared_ptr<Node> n) : NodeRef(n) {} explicit Operation(std::shared_ptr<Node> n) : FunctionRef(n) {}
/*! /*!
* \brief access the internal node container * \brief access the internal node container
* \return the pointer to the internal node container * \return the pointer to the internal node container
...@@ -137,12 +128,10 @@ class Operation : public NodeRef { ...@@ -137,12 +128,10 @@ class Operation : public NodeRef {
}; };
/*! \brief Node to represent a tensor */ /*! \brief Node to represent a tensor */
class TensorNode : public FunctionBaseNode { class TensorNode : public Node {
public: public:
/*! \brief The shape of the tensor */ /*! \brief The shape of the tensor */
Array<Expr> shape; Array<Expr> shape;
/*! \brief optional name of the tensor */
std::string name;
/*! \brief data type in the content of the tensor */ /*! \brief data type in the content of the tensor */
Type dtype; Type dtype;
/*! \brief the source operation, can be None */ /*! \brief the source operation, can be None */
...@@ -154,19 +143,11 @@ class TensorNode : public FunctionBaseNode { ...@@ -154,19 +143,11 @@ class TensorNode : public FunctionBaseNode {
void VisitAttrs(AttrVisitor* v) final { void VisitAttrs(AttrVisitor* v) final {
v->Visit("shape", &shape); v->Visit("shape", &shape);
v->Visit("name", &name);
v->Visit("dtype", &dtype); v->Visit("dtype", &dtype);
v->Visit("op", &op); v->Visit("op", &op);
v->Visit("value_index", &value_index); v->Visit("value_index", &value_index);
} }
const std::string& func_name() const final {
return name;
}
int outputs() const final {
return 1;
}
static Tensor make(Array<Expr> shape, static Tensor make(Array<Expr> shape,
std::string name,
Type dtype, Type dtype,
Operation op, Operation op,
int value_index); int value_index);
...@@ -178,16 +159,18 @@ class TensorNode : public FunctionBaseNode { ...@@ -178,16 +159,18 @@ class TensorNode : public FunctionBaseNode {
/*! /*!
* \brief base class of operation node. * \brief base class of operation node.
*/ */
class OperationNode : public Node { class OperationNode : public FunctionBaseNode {
public: public:
/*! \brief optional name of the operation */ /*! \brief optional name of the operation */
std::string name; std::string name;
/*! \return name of the operation */
const std::string& func_name() const final {
return name;
}
/*! \return number of outputs of this op */
virtual int num_outputs() const = 0;
/*! \return the list of iteration variable at root */ /*! \return the list of iteration variable at root */
virtual Array<IterVar> root_iter_vars() const = 0; virtual Array<IterVar> root_iter_vars() const = 0;
/*! \return number of outputs of this op */
virtual size_t num_outputs() const = 0;
/*! \return name of i-th output */
virtual std::string output_name(size_t i) const = 0;
/*! \return type of i-th output */ /*! \return type of i-th output */
virtual Type output_dtype(size_t i) const = 0; virtual Type output_dtype(size_t i) const = 0;
/*! \return shape of i-th output */ /*! \return shape of i-th output */
......
...@@ -33,7 +33,7 @@ def Var(name="tindex", dtype=int32): ...@@ -33,7 +33,7 @@ def Var(name="tindex", dtype=int32):
return _function_internal._Var(name, dtype) return _function_internal._Var(name, dtype)
def placeholder(shape, dtype = None, name="TensorObj"): def placeholder(shape, dtype = None, name="placeholder"):
"""Construct an empty tensor object. """Construct an empty tensor object.
Parameters Parameters
...@@ -53,11 +53,11 @@ def placeholder(shape, dtype = None, name="TensorObj"): ...@@ -53,11 +53,11 @@ def placeholder(shape, dtype = None, name="TensorObj"):
The created tensor The created tensor
""" """
dtype = float32 if dtype is None else dtype dtype = float32 if dtype is None else dtype
return _function_internal._Tensor( return _function_internal._Placeholder(
shape, name, dtype, None, 0) shape, dtype, name)
def compute(shape, fcompute, name="TensorCompute"): def compute(shape, fcompute, name="compute"):
"""Construct a new tensor by computing over the shape domain. """Construct a new tensor by computing over the shape domain.
The compute rule is result[axis] = fcompute(axis) The compute rule is result[axis] = fcompute(axis)
......
...@@ -34,7 +34,9 @@ class Tensor(NodeBase): ...@@ -34,7 +34,9 @@ class Tensor(NodeBase):
else: else:
raise ValueError("The indices must be expression") raise ValueError("The indices must be expression")
return _make.Call(self.dtype, self.name, args, _expr.Call.Halide, self, 0) return _make.Call(self.dtype, self.op.name,
args, _expr.Call.Halide,
self.op, self.value_index)
def __getitem__(self, indices): def __getitem__(self, indices):
return TensorSlice(self, indices) return TensorSlice(self, indices)
...@@ -71,3 +73,7 @@ class Operation(NodeBase): ...@@ -71,3 +73,7 @@ class Operation(NodeBase):
@register_node @register_node
class ComputeOp(Operation): class ComputeOp(Operation):
pass pass
@register_node
class PlaceholderOp(Operation):
pass
...@@ -29,6 +29,17 @@ TVM_REGISTER_API(_make_For) ...@@ -29,6 +29,17 @@ TVM_REGISTER_API(_make_For)
args.at(5)); args.at(5));
}); });
TVM_REGISTER_API(_make_Realize)
.set_body([](const ArgStack& args, RetValue *ret) {
*ret = Realize::make(args.at(0),
args.at(1),
args.at(2),
args.at(3),
args.at(4),
args.at(5));
});
TVM_REGISTER_API(_make_Call) TVM_REGISTER_API(_make_Call)
.set_body([](const ArgStack& args, RetValue *ret) { .set_body([](const ArgStack& args, RetValue *ret) {
*ret = Call::make(args.at(0), *ret = Call::make(args.at(0),
...@@ -113,9 +124,8 @@ REGISTER_MAKE3(LetStmt); ...@@ -113,9 +124,8 @@ REGISTER_MAKE3(LetStmt);
REGISTER_MAKE2(AssertStmt); REGISTER_MAKE2(AssertStmt);
REGISTER_MAKE3(ProducerConsumer); REGISTER_MAKE3(ProducerConsumer);
REGISTER_MAKE3(Store); REGISTER_MAKE3(Store);
REGISTER_MAKE3(Provide); REGISTER_MAKE4(Provide);
REGISTER_MAKE1(Free); REGISTER_MAKE1(Free);
// TODO(tqchen) Realize;
REGISTER_MAKE2(Block); REGISTER_MAKE2(Block);
REGISTER_MAKE3(IfThenElse); REGISTER_MAKE3(IfThenElse);
REGISTER_MAKE1(Evaluate); REGISTER_MAKE1(Evaluate);
......
...@@ -143,7 +143,6 @@ TVM_REGISTER_API(Range) ...@@ -143,7 +143,6 @@ TVM_REGISTER_API(Range)
TVM_REGISTER_API(_Tensor) TVM_REGISTER_API(_Tensor)
.set_body([](const ArgStack& args, RetValue *ret) { .set_body([](const ArgStack& args, RetValue *ret) {
*ret = TensorNode::make(args.at(0), *ret = TensorNode::make(args.at(0),
args.at(1),
args.at(2), args.at(2),
args.at(3), args.at(3),
args.at(4)); args.at(4));
...@@ -160,6 +159,13 @@ TVM_REGISTER_API(_TensorHash) ...@@ -160,6 +159,13 @@ TVM_REGISTER_API(_TensorHash)
std::hash<Tensor>()(args.at(0).operator Tensor())); std::hash<Tensor>()(args.at(0).operator Tensor()));
}); });
TVM_REGISTER_API(_Placeholder)
.set_body([](const ArgStack& args, RetValue *ret) {
*ret = Placeholder(args.at(0),
args.at(1),
args.at(2));
});
TVM_REGISTER_API(_ComputeOp) TVM_REGISTER_API(_ComputeOp)
.set_body([](const ArgStack& args, RetValue *ret) { .set_body([](const ArgStack& args, RetValue *ret) {
*ret = ComputeOpNode::make(args.at(0), *ret = ComputeOpNode::make(args.at(0),
......
...@@ -7,7 +7,6 @@ ...@@ -7,7 +7,6 @@
#include <tvm/ir.h> #include <tvm/ir.h>
#include <tvm/ir_pass.h> #include <tvm/ir_pass.h>
#include "./c_api_registry.h" #include "./c_api_registry.h"
#include "../schedule/bound.h"
namespace tvm { namespace tvm {
namespace ir { namespace ir {
...@@ -36,6 +35,7 @@ using RetValue = APIVariantValue; ...@@ -36,6 +35,7 @@ using RetValue = APIVariantValue;
REGISTER_PASS1(ConvertSSA); REGISTER_PASS1(ConvertSSA);
REGISTER_PASS1(VerifySSA); REGISTER_PASS1(VerifySSA);
REGISTER_PASS4(Inline); REGISTER_PASS4(Inline);
REGISTER_PASS2(ScheduleOps);
} // namespace ir } // namespace ir
} // namespace tvm } // namespace tvm
...@@ -6,8 +6,8 @@ ...@@ -6,8 +6,8 @@
#include <tvm/expr.h> #include <tvm/expr.h>
#include <tvm/tensor.h> #include <tvm/tensor.h>
#include <tvm/schedule.h> #include <tvm/schedule.h>
#include <tvm/schedule_pass.h>
#include "./c_api_registry.h" #include "./c_api_registry.h"
#include "../schedule/bound.h"
#include "../schedule/graph.h" #include "../schedule/graph.h"
namespace tvm { namespace tvm {
......
...@@ -73,5 +73,4 @@ TVM_STATIC_IR_FUNCTOR(IRPrinter, vtable) ...@@ -73,5 +73,4 @@ TVM_STATIC_IR_FUNCTOR(IRPrinter, vtable)
TVM_REGISTER_NODE_TYPE(IterVarNode); TVM_REGISTER_NODE_TYPE(IterVarNode);
} // namespace tvm } // namespace tvm
...@@ -36,7 +36,8 @@ TVM_STATIC_IR_FUNCTOR(IRPrinter, vtable) ...@@ -36,7 +36,8 @@ TVM_STATIC_IR_FUNCTOR(IRPrinter, vtable)
TVM_STATIC_IR_FUNCTOR(IRPrinter, vtable) TVM_STATIC_IR_FUNCTOR(IRPrinter, vtable)
.set_dispatch<AttrStmt>([](const AttrStmt *op, IRPrinter *p) { .set_dispatch<AttrStmt>([](const AttrStmt *op, IRPrinter *p) {
p->stream << "attr " << op->type_key << " = "; p->do_indent();
p->stream << "// attr " << op->type_key << " = ";
p->print(op->value); p->print(op->value);
p->stream << '\n'; p->stream << '\n';
p->print(op->body); p->print(op->body);
......
...@@ -9,11 +9,73 @@ ...@@ -9,11 +9,73 @@
namespace tvm { namespace tvm {
Tensor Operation::output(size_t i) const {
auto node = std::make_shared<TensorNode>();
node->op = *this;
node->value_index = 0;
node->dtype = (*this)->output_dtype(i);
node->shape = (*this)->output_shape(i);
return Tensor(node);
}
// PlaceholderOpNode
TVM_STATIC_IR_FUNCTOR(IRPrinter, vtable) TVM_STATIC_IR_FUNCTOR(IRPrinter, vtable)
.set_dispatch<ComputeOpNode>([](const ComputeOpNode *op, IRPrinter *p) { .set_dispatch<PlaceholderOpNode>([](const PlaceholderOpNode *op, IRPrinter *p) {
p->stream << "op(" << op << ")"; p->stream << "placeholder(" << op->name << ", " << op << ")";
}); });
TVM_REGISTER_NODE_TYPE(PlaceholderOpNode);
Array<IterVar> PlaceholderOpNode::root_iter_vars() const {
return {};
}
Type PlaceholderOpNode::output_dtype(size_t i) const {
CHECK_EQ(i, 0U);
return dtype;
}
Array<Expr> PlaceholderOpNode::output_shape(size_t i) const {
CHECK_EQ(i, 0U);
return shape;
}
Operation PlaceholderOpNode::make(std::string name,
Array<Expr> shape,
Type dtype) {
auto n = std::make_shared<PlaceholderOpNode>();
n->name = name;
n->shape = shape;
n->dtype = dtype;
return Operation(n);
}
Tensor Placeholder(Array<Expr> shape, Type dtype, std::string name) {
return PlaceholderOpNode::make(name, shape, dtype).output(0);
}
// ComputeOpNode
Array<IterVar> ComputeOpNode::root_iter_vars() const {
return axis;
}
Type ComputeOpNode::output_dtype(size_t i) const {
CHECK_EQ(i, 0U);
return body.type();
}
Array<Expr> ComputeOpNode::output_shape(size_t i) const {
CHECK_EQ(i, 0U);
std::vector<Expr> shape;
for (size_t i = 0; i < axis.size(); ++i) {
const Range& r = axis[i]->dom;
shape.push_back(r->extent);
}
return Array<Expr>(shape);
}
Tensor Compute(Array<Expr> shape, FCompute fcompute, std::string name) { Tensor Compute(Array<Expr> shape, FCompute fcompute, std::string name) {
auto op_node = std::make_shared<ComputeOpNode>(); auto op_node = std::make_shared<ComputeOpNode>();
// compute dimension. // compute dimension.
...@@ -43,39 +105,10 @@ Operation ComputeOpNode::make(std::string name, ...@@ -43,39 +105,10 @@ Operation ComputeOpNode::make(std::string name,
return Operation(n); return Operation(n);
} }
Tensor Operation::output(size_t i) const { TVM_STATIC_IR_FUNCTOR(IRPrinter, vtable)
auto node = std::make_shared<TensorNode>(); .set_dispatch<ComputeOpNode>([](const ComputeOpNode *op, IRPrinter *p) {
node->op = *this; p->stream << "compute(" << op->name << ", " << op << ")";
node->value_index = 0; });
node->name = (*this)->output_name(i);
node->dtype = (*this)->output_dtype(i);
node->shape = (*this)->output_shape(i);
return Tensor(node);
}
Array<IterVar> ComputeOpNode::root_iter_vars() const {
return axis;
}
std::string ComputeOpNode::output_name(size_t i) const {
CHECK_EQ(i, 0U);
return name;
}
Type ComputeOpNode::output_dtype(size_t i) const {
CHECK_EQ(i, 0U);
return body.type();
}
Array<Expr> ComputeOpNode::output_shape(size_t i) const {
CHECK_EQ(i, 0U);
std::vector<Expr> shape;
for (size_t i = 0; i < axis.size(); ++i) {
const Range& r = axis[i]->dom;
shape.push_back(r->extent);
}
return Array<Expr>(shape);
}
TVM_REGISTER_NODE_TYPE(ComputeOpNode); TVM_REGISTER_NODE_TYPE(ComputeOpNode);
......
...@@ -8,33 +8,24 @@ ...@@ -8,33 +8,24 @@
namespace tvm { namespace tvm {
Tensor::Tensor(Array<Expr> shape, std::string name, Type dtype) {
auto node = std::make_shared<TensorNode>();
node->name = std::move(name);
node->dtype = dtype;
node->shape = std::move(shape);
node_ = std::move(node);
}
Expr Tensor::operator()(Array<Expr> indices) const { Expr Tensor::operator()(Array<Expr> indices) const {
using Halide::Internal::Call; using Halide::Internal::Call;
CHECK_EQ(ndim(), indices.size()) CHECK_EQ(ndim(), indices.size())
<< "Tensor dimension mismatch in read" << "Tensor dimension mismatch in read"
<< "ndim = " << ndim() << ", indices.size=" << indices.size(); << "ndim = " << ndim() << ", indices.size=" << indices.size();
auto n = Call::make( auto n = Call::make(
(*this)->dtype, (*this)->name, indices, Call::Halide, *this); (*this)->dtype, (*this)->op->name, indices, Call::Halide,
(*this)->op, (*this)->value_index);
return n; return n;
} }
Tensor TensorNode::make(Array<Expr> shape, Tensor TensorNode::make(Array<Expr> shape,
std::string name,
Type dtype, Type dtype,
Operation op, Operation op,
int value_index) { int value_index) {
auto n = std::make_shared<TensorNode>(); auto n = std::make_shared<TensorNode>();
n->shape = shape; n->shape = shape;
n->name = name;
n->dtype = dtype; n->dtype = dtype;
n->op = op; n->op = op;
n->value_index = value_index; n->value_index = value_index;
...@@ -44,7 +35,7 @@ Tensor TensorNode::make(Array<Expr> shape, ...@@ -44,7 +35,7 @@ Tensor TensorNode::make(Array<Expr> shape,
TVM_STATIC_IR_FUNCTOR(IRPrinter, vtable) TVM_STATIC_IR_FUNCTOR(IRPrinter, vtable)
.set_dispatch<TensorNode>([](const TensorNode *t, IRPrinter *p) { .set_dispatch<TensorNode>([](const TensorNode *t, IRPrinter *p) {
p->stream << "Tensor(shape=" << t->shape p->stream << "Tensor(shape=" << t->shape
<< ", name=" << t->name << ')'; << ", op.name=" << t->op->name << ')';
}); });
TVM_REGISTER_NODE_TYPE(TensorNode); TVM_REGISTER_NODE_TYPE(TensorNode);
......
...@@ -22,6 +22,7 @@ class IRInline : public IRMutator { ...@@ -22,6 +22,7 @@ class IRInline : public IRMutator {
expr = IRMutator::Mutate(expr); expr = IRMutator::Mutate(expr);
const Call* call = expr.as<Call>(); const Call* call = expr.as<Call>();
if (call != nullptr && call->func == f_) { if (call != nullptr && call->func == f_) {
CHECK_EQ(call->value_index, 0);
return InlineCall(call); return InlineCall(call);
} else { } else {
return expr; return expr;
...@@ -55,6 +56,8 @@ Stmt Inline(FunctionRef f, ...@@ -55,6 +56,8 @@ Stmt Inline(FunctionRef f,
Array<Var> args, Array<Var> args,
Expr body, Expr body,
Stmt stmt) { Stmt stmt) {
CHECK_EQ(f->num_outputs(), 1)
<< "can only inline output single value operation";
return ConvertSSA(IRInline(f, args, body).Mutate(stmt)); return ConvertSSA(IRInline(f, args, body).Mutate(stmt));
} }
} // namespace ir } // namespace ir
......
...@@ -254,11 +254,11 @@ TVM_STATIC_IR_FUNCTOR(IRMutator, vtable_stmt) ...@@ -254,11 +254,11 @@ TVM_STATIC_IR_FUNCTOR(IRMutator, vtable_stmt)
}) })
.set_dispatch<Provide>([](const Provide *op, const Stmt& s, IRMutator* m) { .set_dispatch<Provide>([](const Provide *op, const Stmt& s, IRMutator* m) {
auto new_args = MutateArray(op->args, m); auto new_args = MutateArray(op->args, m);
auto new_values = MutateArray(op->values, m); auto new_value = m->Mutate(op->value);
if (op->args.same_as(new_args) && op->values.same_as(new_values)) { if (op->args.same_as(new_args) && op->value.same_as(new_value)) {
return s; return s;
} else { } else {
return Provide::make(op->func, new_values, new_args); return Provide::make(op->func, op->value_index, new_value, new_args);
} }
}) })
.set_dispatch<Allocate>([](const Allocate *op, const Stmt& s, IRMutator* m) { .set_dispatch<Allocate>([](const Allocate *op, const Stmt& s, IRMutator* m) {
...@@ -312,7 +312,8 @@ TVM_STATIC_IR_FUNCTOR(IRMutator, vtable_stmt) ...@@ -312,7 +312,8 @@ TVM_STATIC_IR_FUNCTOR(IRMutator, vtable_stmt)
condition.same_as(op->condition)) { condition.same_as(op->condition)) {
return s; return s;
} else { } else {
return Realize::make(op->func, op->types, new_bounds, return Realize::make(op->func, op->value_index,
op->type, new_bounds,
condition, body); condition, body);
} }
}) })
...@@ -329,7 +330,10 @@ TVM_STATIC_IR_FUNCTOR(IRMutator, vtable_stmt) ...@@ -329,7 +330,10 @@ TVM_STATIC_IR_FUNCTOR(IRMutator, vtable_stmt)
.set_dispatch<IfThenElse>([](const IfThenElse *op, const Stmt& s, IRMutator* m) { .set_dispatch<IfThenElse>([](const IfThenElse *op, const Stmt& s, IRMutator* m) {
Expr condition = m->Mutate(op->condition); Expr condition = m->Mutate(op->condition);
Stmt then_case = m->Mutate(op->then_case); Stmt then_case = m->Mutate(op->then_case);
Stmt else_case = m->Mutate(op->else_case); Stmt else_case;
if (else_case.defined()) {
else_case = m->Mutate(op->else_case);
}
if (condition.same_as(op->condition) && if (condition.same_as(op->condition) &&
then_case.same_as(op->then_case) && then_case.same_as(op->then_case) &&
else_case.same_as(op->else_case)) { else_case.same_as(op->else_case)) {
......
...@@ -157,7 +157,7 @@ TVM_STATIC_IR_FUNCTOR(IRVisitor, vtable) ...@@ -157,7 +157,7 @@ TVM_STATIC_IR_FUNCTOR(IRVisitor, vtable)
}) })
.set_dispatch<Provide>([](const Provide *op, IRVisitor* v) { .set_dispatch<Provide>([](const Provide *op, IRVisitor* v) {
VisitArray(op->args, v); VisitArray(op->args, v);
VisitArray(op->values, v); v->Visit(op->value);
}) })
.set_dispatch<Allocate>([](const Allocate *op, IRVisitor* v) { .set_dispatch<Allocate>([](const Allocate *op, IRVisitor* v) {
for (size_t i = 0; i < op->extents.size(); i++) { for (size_t i = 0; i < op->extents.size(); i++) {
......
...@@ -6,7 +6,10 @@ ...@@ -6,7 +6,10 @@
#include <tvm/ir_mutator.h> #include <tvm/ir_mutator.h>
#include <tvm/ir_pass.h> #include <tvm/ir_pass.h>
#include <tvm/ir_visitor.h> #include <tvm/ir_visitor.h>
#include <tvm/schedule_pass.h>
#include "./scope.h" #include "./scope.h"
#include "../schedule/graph.h"
namespace tvm { namespace tvm {
namespace ir { namespace ir {
...@@ -20,7 +23,7 @@ namespace { ...@@ -20,7 +23,7 @@ namespace {
* IterVar->The assignment. * IterVar->The assignment.
*/ */
void PassUpOffset(const Schedule& s, void PassUpOffset(const Schedule& s,
const std::unordered_map<IterVar, Range>& dom_map, const Map<IterVar, Range>& dom_map,
std::unordered_map<IterVar, Expr>* p_state) { std::unordered_map<IterVar, Expr>* p_state) {
auto& state = *p_state; auto& state = *p_state;
for (size_t i = s->relations.size(); i != 0; --i) { for (size_t i = s->relations.size(); i != 0; --i) {
...@@ -28,8 +31,8 @@ void PassUpOffset(const Schedule& s, ...@@ -28,8 +31,8 @@ void PassUpOffset(const Schedule& s,
if (rel.as<SplitNode>()) { if (rel.as<SplitNode>()) {
const SplitNode* s = rel.as<SplitNode>(); const SplitNode* s = rel.as<SplitNode>();
Expr outer = state.at(s->outer); Expr outer = state.at(s->outer);
Expr inner = state.at(s->outer); Expr inner = state.at(s->inner);
Expr factor = dom_map.at(s->outer)->extent; Expr factor = dom_map.at(s->inner)->extent;
Expr offset = inner + outer * factor; Expr offset = inner + outer * factor;
Expr outer_min = dom_map.at(s->parent)->min; Expr outer_min = dom_map.at(s->parent)->min;
if (!is_zero(outer_min)) { if (!is_zero(outer_min)) {
...@@ -39,7 +42,7 @@ void PassUpOffset(const Schedule& s, ...@@ -39,7 +42,7 @@ void PassUpOffset(const Schedule& s,
} else if (rel.as<FuseNode>()) { } else if (rel.as<FuseNode>()) {
const FuseNode* s = rel.as<FuseNode>(); const FuseNode* s = rel.as<FuseNode>();
Expr value = state.at(s->fused); Expr value = state.at(s->fused);
Expr factor = dom_map.at(s->outer)->extent; Expr factor = dom_map.at(s->inner)->extent;
state[s->outer] = value / factor; state[s->outer] = value / factor;
state[s->inner] = value % factor; state[s->inner] = value % factor;
} else { } else {
...@@ -84,24 +87,35 @@ void SplitByAdd(Expr expr, ...@@ -84,24 +87,35 @@ void SplitByAdd(Expr expr,
* \param nest A list of For and LetStmt, whose body is not defined. * \param nest A list of For and LetStmt, whose body is not defined.
* \param body body * \param body body
*/ */
Stmt CombineNest(std::vector<Stmt>&& nest, Stmt body) { Stmt MergeNest(std::vector<std::vector<Stmt> > nest, Stmt body) {
while (!nest.empty()) { // use reverse iteration
Stmt s = std::move(nest.back()); for (auto ri = nest.rbegin(); ri != nest.rend(); ++ri) {
nest.pop_back(); for (auto rj = ri->rbegin(); rj != ri->rend(); ++rj) {
if (s.as<For>()) { Stmt s = *rj;
auto n = std::make_shared<For>(*s.as<For>()); if (s.as<For>()) {
n->body = body; auto n = std::make_shared<For>(*s.as<For>());
body = Stmt(n); CHECK(is_no_op(n->body));
} else if (s.as<LetStmt>()) { n->body = body;
auto n = std::make_shared<LetStmt>(*s.as<LetStmt>()); body = Stmt(n);
n->body = body; } else if (s.as<LetStmt>()) {
body = Stmt(n); auto n = std::make_shared<LetStmt>(*s.as<LetStmt>());
} else if (s.as<AttrStmt>()) { CHECK(is_no_op(n->body));
auto n = std::make_shared<AttrStmt>(*s.as<AttrStmt>()); n->body = body;
n->body = body; body = Stmt(n);
body = Stmt(n); } else if (s.as<AttrStmt>()) {
} else { auto n = std::make_shared<AttrStmt>(*s.as<AttrStmt>());
LOG(FATAL) << "not supported nest type"; CHECK(is_no_op(n->body));
n->body = body;
body = Stmt(n);
} else if (s.as<IfThenElse>()) {
auto n = std::make_shared<IfThenElse>(*s.as<IfThenElse>());
CHECK(is_no_op(n->then_case));
CHECK(!n->else_case.defined());
n->then_case = body;
body = Stmt(n);
} else {
LOG(FATAL) << "not supported nest type";
}
} }
} }
return body; return body;
...@@ -111,119 +125,251 @@ Stmt CombineNest(std::vector<Stmt>&& nest, Stmt body) { ...@@ -111,119 +125,251 @@ Stmt CombineNest(std::vector<Stmt>&& nest, Stmt body) {
* \brief Make the loop nest of the correspondings schedule. * \brief Make the loop nest of the correspondings schedule.
* \param sch The schedule. * \param sch The schedule.
* \param dom_map The domain map. * \param dom_map The domain map.
*
* \return a nested representation of loop statements.
* The flattened Stmt are ordered from outmost to inner most order.
*/ */
std::vector<Stmt> MakeLoopNest( std::vector<std::vector<Stmt> > MakeLoopNest(
const Schedule& sch, const Schedule& sch,
const std::unordered_map<IterVar, Range>& dom_map) { const Map<IterVar, Range>& dom_map) {
// optional, use let to define some CSE in dom_map. // optional, use let to define some CSE in dom_map.
auto leaf_iter_vars = sch->leaf_iter_vars; auto leaf_iter_vars = sch->leaf_iter_vars;
std::unordered_map<IterVar, Expr> offset; std::unordered_map<IterVar, Expr> offset;
std::unordered_map<const Variable*, size_t> loop_level; std::unordered_map<const Variable*, size_t> loop_level;
Stmt no_op = Evaluate::make(0);
// create the loop nest // create the loop nest
std::vector<Stmt> nest; std::vector<std::vector<Stmt> > nest;
nest.resize(leaf_iter_vars.size() + 1, Stmt()); nest.resize(leaf_iter_vars.size() + 1);
for (size_t i = 0; i < leaf_iter_vars.size(); ++i) { for (size_t i = 0; i < leaf_iter_vars.size(); ++i) {
auto iv = leaf_iter_vars[i]; auto iv = leaf_iter_vars[i];
// initialize the offset and loop_level // initialize the offset and loop_level
offset[iv] = iv->var; offset[iv] = iv->var;
loop_level[iv->var.as<Variable>()] = i + 1; loop_level[iv->var.as<Variable>()] = i + 1;
// Mark the iter var in the IR, to remember the point
nest[i] = AttrStmt::make(iv->var, "scope", iv, Stmt());
if (iv->thread_tag.length() == 0) { if (iv->thread_tag.length() == 0) {
Range dom = dom_map.at(iv); Range dom = dom_map.at(iv);
nest[i] = For::make(iv->var, dom->min, dom->extent, nest[i + 1].emplace_back(
ForType::Serial, DeviceAPI::None, nest[i]); For::make(iv->var, dom->min, dom->extent,
ForType::Serial, DeviceAPI::None, no_op));
} }
nest[i + 1].emplace_back(
AttrStmt::make(iv, "scope", iv->var, no_op));
} }
// message passing to get offset of root iter vars. // message passing to get offset of root iter vars.
PassUpOffset(sch, dom_map, &offset); PassUpOffset(sch, dom_map, &offset);
for (IterVar iv : sch->op->root_iter_vars()) { for (IterVar iv : sch->op->root_iter_vars()) {
Expr value = offset.at(iv); Expr value = offset.at(iv);
if (value.same_as(iv->var)) continue; if (!value.same_as(iv->var)) {
using Entry = std::pair<size_t, Expr>; using Entry = std::pair<size_t, Expr>;
std::vector<Entry> splits; std::vector<Entry> splits;
SplitByAdd(value, loop_level, &splits); SplitByAdd(value, loop_level, &splits);
Expr offset = 0; Expr offset = 0;
for (size_t i = 0; i <= leaf_iter_vars.size(); ++i) { size_t nsplit_left = splits.size() - 1;
auto iv = leaf_iter_vars[i]; for (size_t i = 0; i <= leaf_iter_vars.size(); ++i) {
for (const auto& kv : splits) { size_t hit = 0;
if (kv.first == i) { for (const auto& kv : splits) {
offset = offset + splits[i].second; if (kv.first == i) {
if (is_zero(offset)) {
offset = kv.second;
} else {
offset = offset + kv.second;
++hit;
}
}
} }
nsplit_left -= hit;
if (hit != 0) {
std::ostringstream os;
os << iv->var->name_hint << ".at.l" << i;
Var base_offset(os.str());
if (nsplit_left == 0) {
base_offset = iv->var;
}
nest[i].emplace_back(
LetStmt::make(base_offset, offset, no_op));
offset = base_offset;
}
}
Range dom = dom_map.at(iv);
if (!offset.same_as(iv->var)) {
// define the iv->var
nest.back().emplace_back(
LetStmt::make(iv->var, offset, no_op));
} }
std::ostringstream os; Expr condition = (iv->var - dom->min) < dom->extent;
os << iv->var->name_hint << ".at.l" << i; // Boundary condition checking
Var base_offset(os.str()); // Need better boundary condition here.
nest[i] = LetStmt::make(base_offset, offset, nest[i]); nest.back().emplace_back(IfThenElse::make(condition, no_op));
offset = base_offset;
} }
nest.back() = LetStmt::make(iv->var, offset, nest.back());
} }
return nest; return nest;
} }
/*! /*!
* \brief Make the loop nest of the correspondings schedule. * \brief Make pipeline specifically for compute op node.
* \param op The operation. * \param op The compute node
* \param tensors The tensors generated by provide.
*/ */
Stmt MakeBody(const Operation& op) { Stmt MakeProvide(const ComputeOpNode* op,
Stmt body; const std::vector<Tensor>& tensors) {
if (op.as<ComputeOpNode>()) { Tensor t = tensors[0];
const ComputeOpNode* compute = op.as<ComputeOpNode>(); Array<Expr> args;
// Note: Tensor's address cannot uniquely for (IterVar iv : op->axis) {
Tensor t = op.output(0); args.push_back(iv->var);
Array<Expr> args; }
for (IterVar iv : compute->axis) { return Provide::make(t->op, t->value_index, op->body, args);
args.push_back(iv->var); }
}
body = Provide::make(t, {compute->body}, args); /*!
* \brief Make pipeline specifically for compute op node.
* \param op The compute node
* \param dom_map The domain map
* \param tensors The tensors generated by provide.
* \param body The content of the pipeline.
*/
Stmt MakeRealize(const ComputeOpNode* op,
const Map<IterVar, Range>& dom_map,
const std::vector<Tensor>& tensors,
Stmt body) {
Tensor t = tensors[0];
Halide::Internal::Region bounds;
for (IterVar iv : op->axis) {
bounds.push_back(dom_map.at(iv));
}
return Realize::make(t->op, t->value_index, t->dtype,
bounds, make_const(Bool(1), true), body);
}
Stmt MakePipeline(const Schedule& sch,
const Map<IterVar, Range>& dom_map,
Stmt consumer) {
std::vector<Tensor> tensors;
for (int i = 0; i < sch->op->num_outputs(); ++i) {
tensors.emplace_back(sch->op.output(i));
}
Stmt provide;
if (sch->op.as<ComputeOpNode>()) {
provide = MakeProvide(sch->op.as<ComputeOpNode>(), tensors);
} else { } else {
LOG(FATAL) << "not supported op"; LOG(FATAL) << "not supported op";
} }
return body; std::vector<std::vector<Stmt> > nest = MakeLoopNest(sch, dom_map);
} Stmt producer = MergeNest(nest, provide);
producer = ProducerConsumer::make(sch->op, true, producer);
Stmt MakePipeline(const Schedule& sch, Stmt body) { Stmt pipeline = producer;
return body; if (consumer.defined()) {
consumer = ProducerConsumer::make(sch->op, false, consumer);
pipeline = Block::make(producer, consumer);
}
if (sch->op.as<ComputeOpNode>()) {
return MakeRealize(sch->op.as<ComputeOpNode>(),
dom_map, tensors, pipeline);
} else {
LOG(FATAL) << "not supported op";
return Stmt();
}
} }
// inject the operator's realization on the stmt. // inject the operator's realization on the stmt.
class InjectRealize : public IRMutator { class InjectRealize : public IRMutator {
public: public:
explicit InjectRealize(Schedule sch) InjectRealize(Schedule schedule, Map<IterVar, Range> dom_map)
: sch_(sch) {} : schedule(schedule), dom_map(dom_map) {}
Stmt Mutate(Stmt stmt) final { Stmt Mutate(Stmt stmt) final {
CHECK(stmt.defined());
stmt = IRMutator::Mutate(stmt);
const AttrStmt* op = stmt.as<AttrStmt>(); const AttrStmt* op = stmt.as<AttrStmt>();
if (op != nullptr) {
attr_scope_.Push({op->node, op->type_key}, op->value);
stmt = IRMutator::Mutate(stmt);
attr_scope_.Pop({op->node, op->type_key});
} else {
stmt = IRMutator::Mutate(stmt);
}
if (op != nullptr && if (op != nullptr &&
op->type_key == "scope" && op->type_key == "scope") {
op->node == sch_->attach_parent) { if (op->node == schedule->attach_parent) {
return AttrStmt::make( CHECK(!found_attach);
op->node, op->type_key, op->value, found_attach = true;
MakePipeline(sch_, op->body)); stmt = AttrStmt::make(
} else { op->node, op->type_key, op->value,
return stmt; MakePipeline(schedule, dom_map,
IRMutator::Mutate(op->body)));
}
} }
return stmt;
} }
private:
// the operations to be carried // the operations to be carried
Schedule sch_; Schedule schedule;
Scope<AttrKey, Expr> attr_scope_; // domain map
Map<IterVar, Range> dom_map;
// whether attach point is found
bool found_attach{false};
}; };
void GetOpToScheduleMap(
Schedule s,
std::unordered_map<Operation, Schedule>* ret) {
CHECK(!ret->count(s->op))
<< "Duplicated schedule for op";
(*ret)[s->op] = s;
for (Schedule c : s->children) {
GetOpToScheduleMap(c, ret);
}
}
// order schedule by DFS calling order of ops
std::vector<Schedule> OrderSchedule(Schedule s) {
auto g = schedule::CreateReadGraph(s->op);
auto post_order = schedule::PostDFSOrder(s->op, g);
std::unordered_map<Operation, Schedule> op2sch;
GetOpToScheduleMap(s, &op2sch);
std::vector<Schedule> sorder;
// reverse iteration.
for (size_t i = post_order.size(); i != 0; --i) {
sorder.push_back(op2sch.at(post_order[i - 1]));
}
return sorder;
}
Stmt InjectInline(const Operation op, Stmt body) {
CHECK(body.defined());
const ComputeOpNode* compute = op.as<ComputeOpNode>();
CHECK(compute != nullptr)
<< "can only inline compute op";
Array<Var> args;
for (auto iv : compute->axis) {
args.push_back(iv->var);
}
return Inline(op, args, compute->body, body);
}
} // namespace } // namespace
Stmt ScheduleOps(
Schedule s, Map<IterVar, Range> dom_map) {
std::vector<Schedule> svec = OrderSchedule(s);
Stmt body = Stmt();
for (Schedule s : svec) {
if (s->attach_type == kInline) {
body = InjectInline(s->op, body);
} else if (s->attach_type == kRoot || s-> attach_type == kNone) {
body = MakePipeline(s, dom_map, body);
} else if (s->attach_type == kScope) {
CHECK(body.defined());
InjectRealize mutator(s, dom_map);
body = mutator.Mutate(body);
CHECK(mutator.found_attach)
<< "did not find attachment point";
}
}
return body;
}
} // namespace ir } // namespace ir
} // namespace tvm } // namespace tvm
...@@ -36,7 +36,7 @@ class Scope { ...@@ -36,7 +36,7 @@ class Scope {
*/ */
inline void Pop(const K& key) { inline void Pop(const K& key) {
auto& v = data_[key]; auto& v = data_[key];
CHECK_NE(v.size(), 0); CHECK_NE(v.size(), 0U);
v.pop_back(); v.pop_back();
} }
......
...@@ -5,8 +5,8 @@ ...@@ -5,8 +5,8 @@
*/ */
#include <tvm/ir.h> #include <tvm/ir.h>
#include <tvm/ir_visitor.h> #include <tvm/ir_visitor.h>
#include <tvm/schedule_pass.h>
#include "./int_set.h" #include "./int_set.h"
#include "./bound.h"
#include "./graph.h" #include "./graph.h"
namespace tvm { namespace tvm {
...@@ -113,7 +113,7 @@ void PassToOperation( ...@@ -113,7 +113,7 @@ void PassToOperation(
(*result)[root_iter_vars[i]].push_back(dim_bounds[i]); (*result)[root_iter_vars[i]].push_back(dim_bounds[i]);
} }
} else { } else {
LOG(FATAL) << "unknown operation mode"; LOG(FATAL) << "unknown operation mode " << tensor->op->type_key();
} }
} }
...@@ -140,8 +140,8 @@ BoundProp(const Array<Operation>& post_order, ...@@ -140,8 +140,8 @@ BoundProp(const Array<Operation>& post_order,
auto fvisit = [p_state, &result](const NodeRef& n) { auto fvisit = [p_state, &result](const NodeRef& n) {
auto *call = n.as<ir::Call>(); auto *call = n.as<ir::Call>();
if (call != nullptr && call->func.defined()) { if (call != nullptr && call->func.defined()) {
Tensor t(call->func.node_); Tensor t = Operation(call->func.node_).output(call->value_index);
if (t->op.defined()) { if (t->op.defined() && !t->op.as<PlaceholderOpNode>()) {
std::vector<IntSet> arg_bounds; std::vector<IntSet> arg_bounds;
for (size_t i = 0; i < t.ndim(); ++i) { for (size_t i = 0; i < t.ndim(); ++i) {
arg_bounds.push_back(EvalSet(call->args[i], result)); arg_bounds.push_back(EvalSet(call->args[i], result));
......
...@@ -27,18 +27,20 @@ ReadGraph CreateReadGraph(const Operation& root) { ...@@ -27,18 +27,20 @@ ReadGraph CreateReadGraph(const Operation& root) {
auto fvisit = [&deps, &visited, &stack](const NodeRef& n) { auto fvisit = [&deps, &visited, &stack](const NodeRef& n) {
auto *call = n.as<ir::Call>(); auto *call = n.as<ir::Call>();
if (call != nullptr && call->func.defined()) { if (call != nullptr && call->func.defined()) {
Tensor t(call->func.node_); Operation call_op(call->func.node_);
deps.push_back(t); deps.push_back(call_op.output(call->value_index));
if (t->op.defined() && visited.count(t->op.get()) == 0) { if (call_op.defined() && visited.count(call_op.get()) == 0) {
visited.insert(t->op.get()); visited.insert(call_op.get());
stack.push_back(t->op); stack.push_back(call_op);
} }
} }
}; };
ir::PostOrderVisit(op.as<ComputeOpNode>()->body, fvisit); ir::PostOrderVisit(op.as<ComputeOpNode>()->body, fvisit);
rmap.Set(op, deps); rmap.Set(op, deps);
} else { } else {
LOG(FATAL) << "unknown operation mode"; if (!op.as<PlaceholderOpNode>()) {
LOG(FATAL) << "unknown Operation" << op->type_key();
}
} }
} }
return rmap; return rmap;
...@@ -51,7 +53,7 @@ void PostDFSOrder(const Operation& op, ...@@ -51,7 +53,7 @@ void PostDFSOrder(const Operation& op,
Array<Operation>* post_order) { Array<Operation>* post_order) {
visited->insert(op); visited->insert(op);
for (const auto& t : g.at(op)) { for (const auto& t : g.at(op)) {
if (t->op.defined() && !visited->count(t->op)) { if (!t->op.as<PlaceholderOpNode>() && !visited->count(t->op)) {
PostDFSOrder(t->op, g, visited, post_order); PostDFSOrder(t->op, g, visited, post_order);
} }
} }
......
...@@ -220,7 +220,7 @@ void PassUp(const SplitNode* s, ...@@ -220,7 +220,7 @@ void PassUp(const SplitNode* s,
*parent = IntSet::make_range(dom_map.at(s->parent)); *parent = IntSet::make_range(dom_map.at(s->parent));
return; return;
} }
Expr factor = dom_map.at(s->outer)->extent; Expr factor = dom_map.at(s->inner)->extent;
CHECK(outer.defined()); CHECK(outer.defined());
CHECK(inner.defined()); CHECK(inner.defined());
CHECK(factor.defined()); CHECK(factor.defined());
...@@ -261,7 +261,7 @@ void PassUp(const FuseNode* s, ...@@ -261,7 +261,7 @@ void PassUp(const FuseNode* s,
if (IsNumber(fused)) { if (IsNumber(fused)) {
Expr value = AsNumber(fused); Expr value = AsNumber(fused);
Expr factor = dom_map.at(s->outer)->extent; Expr factor = dom_map.at(s->inner)->extent;
*outer = IntSet::make_point(value / factor); *outer = IntSet::make_point(value / factor);
*inner = IntSet::make_point(value % factor); *inner = IntSet::make_point(value % factor);
} else { } else {
......
...@@ -5,8 +5,9 @@ ...@@ -5,8 +5,9 @@
TEST(Tensor, Basic) { TEST(Tensor, Basic) {
using namespace tvm; using namespace tvm;
Var m("m"), n("n"), l("l"); Var m("m"), n("n"), l("l");
Tensor A({m, l}, "A");
Tensor B({n, l}, "B"); Tensor A = Placeholder({m, l}, Float(32), "A");
Tensor B = Placeholder({n, l}, Float(32), "B");
auto C = Compute({m, n}, [&](Var i, Var j) { auto C = Compute({m, n}, [&](Var i, Var j) {
return A[i][j]; return A[i][j];
...@@ -19,8 +20,8 @@ TEST(Tensor, Basic) { ...@@ -19,8 +20,8 @@ TEST(Tensor, Basic) {
TEST(Tensor, Reduce) { TEST(Tensor, Reduce) {
using namespace tvm; using namespace tvm;
Var m("m"), n("n"), l("l"); Var m("m"), n("n"), l("l");
Tensor A({m, l}, "A"); Tensor A = Placeholder({m, l}, Float(32), "A");
Tensor B({n, l}, "B"); Tensor B = Placeholder({n, l}, Float(32), "B");
IterVar rv(Range{0, l}, "k"); IterVar rv(Range{0, l}, "k");
auto C = Compute({m, n}, [&](Var i, Var j) { auto C = Compute({m, n}, [&](Var i, Var j) {
......
...@@ -10,7 +10,7 @@ def test_tensor(): ...@@ -10,7 +10,7 @@ def test_tensor():
print(T) print(T)
print(T.op.body) print(T.op.body)
assert(tuple(T.shape) == (m, n, l)) assert(tuple(T.shape) == (m, n, l))
assert(A.op is None) assert(isinstance(A.op, tvm.tensor.PlaceholderOp))
assert(A == A) assert(A == A)
assert(T.op.output(0) == T) assert(T.op.output(0) == T)
assert(T.op.output(0).__hash__() == T.__hash__()) assert(T.op.output(0).__hash__() == T.__hash__())
......
...@@ -6,7 +6,7 @@ def test_inline(): ...@@ -6,7 +6,7 @@ def test_inline():
T = tvm.compute((m,), lambda i,: A[i] + 10, name='T') T = tvm.compute((m,), lambda i,: A[i] + 10, name='T')
stmt = tvm.make.Evaluate(T[10] + 11 * T[100]) stmt = tvm.make.Evaluate(T[10] + 11 * T[100])
stmt = tvm.ir_pass.Inline( stmt = tvm.ir_pass.Inline(
T, [x.var for x in T.op.axis], T.op.body, stmt) T.op, [x.var for x in T.op.axis], T.op.body, stmt)
print(stmt) print(stmt)
assert(tvm.ir_pass.VerifySSA(stmt)) assert(tvm.ir_pass.VerifySSA(stmt))
...@@ -14,7 +14,7 @@ def test_inline(): ...@@ -14,7 +14,7 @@ def test_inline():
# pass in int array(wrong argument type) # pass in int array(wrong argument type)
# must raise an error # must raise an error
stmt = tvm.ir_pass.Inline( stmt = tvm.ir_pass.Inline(
T, [1,2,3], T.op.body, stmt) T.op, [1,2,3], T.op.body, stmt)
assert False assert False
except tvm.TVMError: except tvm.TVMError:
pass pass
......
import tvm
def test_schedule0():
m = tvm.Var('m')
l = tvm.Var('l')
A = tvm.placeholder((m, l), name='A')
A1 = tvm.compute((m, l), lambda i, j: A[i, j], name='A1')
sA1 = tvm.Schedule(A1.op)
bounds = tvm.schedule.InferBound(sA1)
assert isinstance(bounds, tvm.collections.Map)
stmt = tvm.ir_pass.ScheduleOps(sA1, bounds)
print(stmt)
def test_schedule1():
m = tvm.Var('m')
l = tvm.Var('l')
A = tvm.placeholder((m, l), name='A')
A1 = tvm.compute((m, l), lambda i, j: A[i, j], name='A1')
sA1 = tvm.Schedule(A1.op)
xo, xi = sA1.split(A1.op.axis[0], 8)
bounds = tvm.schedule.InferBound(sA1)
assert isinstance(bounds, tvm.collections.Map)
stmt = tvm.ir_pass.ScheduleOps(sA1, bounds)
print(stmt)
def test_schedule2():
m = tvm.Var('m')
l = tvm.Var('l')
A = tvm.placeholder((m, l), name='A')
A1 = tvm.compute((m, l), lambda i, j: A[i, j], name='A1')
A2 = tvm.compute((m, l), lambda i, j: A1[i, j] + 3, name='A2')
sA1 = tvm.Schedule(A1.op)
sA2 = tvm.Schedule(A2.op)
xo, xi = sA2.split(A2.op.axis[0], 8)
sA1.compute_at(sA2, xo)
bounds = tvm.schedule.InferBound(sA2)
assert isinstance(bounds, tvm.collections.Map)
stmt = tvm.ir_pass.ScheduleOps(sA2, bounds)
print(stmt)
if __name__ == "__main__":
test_schedule0()
test_schedule1()
test_schedule2()
...@@ -65,7 +65,7 @@ def test_create_read_graph(): ...@@ -65,7 +65,7 @@ def test_create_read_graph():
if __name__ == "__main__": if __name__ == "__main__":
test_create_read_graph()
test_bound3() test_bound3()
test_bound1() test_bound1()
test_bound2() test_bound2()
test_create_read_graph()
Markdown is supported
0% or
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment