Commit 78ea652d by Tianqi Chen Committed by Haichen Shen

[PASS] Schedule Ops init working version (#6)

* [PASS] Schedule Ops init working version

* bugfix in PassUp
parent 302c2e64
Subproject commit 5d1bd103c2abe19392b4d8def7e3ff1c854e8683 Subproject commit 1ec478bbd0c20b8659f0c897363b5a76e13ef495
...@@ -17,6 +17,7 @@ namespace tvm { ...@@ -17,6 +17,7 @@ namespace tvm {
using Halide::Type; using Halide::Type;
using Halide::Float; using Halide::Float;
using Halide::Bool;
using Halide::Int; using Halide::Int;
using Halide::UInt; using Halide::UInt;
using Halide::Handle; using Halide::Handle;
...@@ -29,6 +30,8 @@ using Halide::Internal::Stmt; ...@@ -29,6 +30,8 @@ using Halide::Internal::Stmt;
using Halide::Internal::IRPrinter; using Halide::Internal::IRPrinter;
using Halide::Internal::Variable; using Halide::Internal::Variable;
using Halide::Internal::make_const;
/*! \brief a named variable in TVM */ /*! \brief a named variable in TVM */
class Var : public Halide::VarExpr { class Var : public Halide::VarExpr {
public: public:
......
...@@ -18,6 +18,16 @@ ...@@ -18,6 +18,16 @@
namespace tvm { namespace tvm {
namespace ir { namespace ir {
/*!
* \brief Schedule s' dependent operations.
*
* \param s The schedule to be realized
* \param dom_map The domain of each iter vars.
* \return the result Stmt
*/
Stmt ScheduleOps(Schedule s, Map<IterVar, Range> dom_map);
/*! /*!
* \brief verifies whether the IR stmt or Expr is in SSA form. * \brief verifies whether the IR stmt or Expr is in SSA form.
* That is: each VarExpr is defined and assigned once(in Let/For) * That is: each VarExpr is defined and assigned once(in Let/For)
...@@ -51,14 +61,6 @@ Stmt Inline(FunctionRef f, ...@@ -51,14 +61,6 @@ Stmt Inline(FunctionRef f,
Expr body, Expr body,
Stmt stmt); Stmt stmt);
/*!
* \brief Schedule s' dependent operations.
*
* \param s The schedule to be realized
* \return the result Stmt
*/
Stmt ScheduelOps(Schedule s);
} // namespace ir } // namespace ir
} // namespace tvm } // namespace tvm
......
...@@ -13,6 +13,36 @@ ...@@ -13,6 +13,36 @@
namespace tvm { namespace tvm {
/*! /*!
* \brief A placeholder op represents an input placeholder.
*/
class PlaceholderOpNode : public OperationNode {
public:
/*! \brief The shape of the input */
Array<Expr> shape;
/*! \brief The data type of the input. */
Type dtype;
int num_outputs() const final {
return 1;
}
Array<IterVar> root_iter_vars() const final;
Type output_dtype(size_t i) const final;
Array<Expr> output_shape(size_t i) const final;
void VisitAttrs(AttrVisitor* v) final {
v->Visit("name", &name);
v->Visit("shape", &shape);
v->Visit("dtype", &dtype);
}
static Operation make(std::string name,
Array<Expr> shape,
Type dtype);
static constexpr const char* _type_key = "PlaceholderOp";
TVM_DECLARE_NODE_TYPE_INFO(PlaceholderOpNode);
};
/*!
* \brief A Compute op that compute a tensor on certain domain. * \brief A Compute op that compute a tensor on certain domain.
*/ */
class ComputeOpNode : public OperationNode { class ComputeOpNode : public OperationNode {
...@@ -24,11 +54,10 @@ class ComputeOpNode : public OperationNode { ...@@ -24,11 +54,10 @@ class ComputeOpNode : public OperationNode {
/*! \brief constructor */ /*! \brief constructor */
ComputeOpNode() {} ComputeOpNode() {}
size_t num_outputs() const final { int num_outputs() const final {
return 1; return 1;
} }
Array<IterVar> root_iter_vars() const final; Array<IterVar> root_iter_vars() const final;
std::string output_name(size_t i) const final;
Type output_dtype(size_t i) const final; Type output_dtype(size_t i) const final;
Array<Expr> output_shape(size_t i) const final; Array<Expr> output_shape(size_t i) const final;
...@@ -50,6 +79,16 @@ class ComputeOpNode : public OperationNode { ...@@ -50,6 +79,16 @@ class ComputeOpNode : public OperationNode {
using FCompute = std::function<Expr (const Array<Var>& i)>; using FCompute = std::function<Expr (const Array<Var>& i)>;
/*! /*!
* \brief create a place holder tensor.
* \param shape The shape of the tensor.
* \param dtype the data type of the tensor.
* \param name The name of the Tensor.
*/
Tensor Placeholder(Array<Expr> shape,
Type dtype = Float(32),
std::string name = "placeholder");
/*!
* \brief Construct a new tensor by computing over shape, * \brief Construct a new tensor by computing over shape,
* using the computation rule: result_tensor[axis] = fcompute(axis) * using the computation rule: result_tensor[axis] = fcompute(axis)
* \param shape Shape of the tensor. * \param shape Shape of the tensor.
......
/*! /*!
* Copyright (c) 2016 by Contributors * Copyright (c) 2016 by Contributors
* \file bound.h * \file schedule_pass.h
* \brief The bound inference logics on the schedule. * \brief Collection of Schedule pass functions.
*
* These passes works on the schedule hyper-graph
* and infers information such as bounds, check conditions
* read/write dependencies between the IterVar
*/ */
#ifndef TVM_SCHEDULE_BOUND_H_ #ifndef TVM_SCHEDULE_PASS_H_
#define TVM_SCHEDULE_BOUND_H_ #define TVM_SCHEDULE_PASS_H_
#include <tvm/expr.h> #include "./base.h"
#include <tvm/schedule.h> #include "./schedule.h"
#include <unordered_map>
namespace tvm { namespace tvm {
namespace schedule { namespace schedule {
...@@ -23,5 +26,4 @@ Map<IterVar, Range> InferBound(Schedule sch); ...@@ -23,5 +26,4 @@ Map<IterVar, Range> InferBound(Schedule sch);
} // namespace schedule } // namespace schedule
} // namespace tvm } // namespace tvm
#endif // TVM_SCHEDULE_PASS_H_
#endif // TVM_SCHEDULE_BOUND_H_
...@@ -28,20 +28,11 @@ using Halide::IR::FunctionRef; ...@@ -28,20 +28,11 @@ using Halide::IR::FunctionRef;
* \brief Tensor structure representing a possible input, * \brief Tensor structure representing a possible input,
* or intermediate computation result. * or intermediate computation result.
*/ */
class Tensor : public FunctionRef { class Tensor : public NodeRef {
public: public:
/*! \brief default constructor, used internally */ /*! \brief default constructor, used internally */
Tensor() {} Tensor() {}
explicit Tensor(std::shared_ptr<Node> n) : FunctionRef(n) {} explicit Tensor(std::shared_ptr<Node> n) : NodeRef(n) {}
/*!
* \brief constructor of input tensor
* \param shape Shape of the tensor.
* \param name optional name of the Tensor.
* \param dtype The data type of the input tensor.
*/
explicit Tensor(Array<Expr> shape,
std::string name = "tensor",
Type dtype = Float(32));
/*! /*!
* \brief access the internal node container * \brief access the internal node container
* \return the pointer to the internal node container * \return the pointer to the internal node container
...@@ -116,11 +107,11 @@ class Tensor : public FunctionRef { ...@@ -116,11 +107,11 @@ class Tensor : public FunctionRef {
}; };
/*! \brief Operation that produces tensors */ /*! \brief Operation that produces tensors */
class Operation : public NodeRef { class Operation : public FunctionRef {
public: public:
/*! \brief default constructor */ /*! \brief default constructor */
Operation() {} Operation() {}
explicit Operation(std::shared_ptr<Node> n) : NodeRef(n) {} explicit Operation(std::shared_ptr<Node> n) : FunctionRef(n) {}
/*! /*!
* \brief access the internal node container * \brief access the internal node container
* \return the pointer to the internal node container * \return the pointer to the internal node container
...@@ -137,12 +128,10 @@ class Operation : public NodeRef { ...@@ -137,12 +128,10 @@ class Operation : public NodeRef {
}; };
/*! \brief Node to represent a tensor */ /*! \brief Node to represent a tensor */
class TensorNode : public FunctionBaseNode { class TensorNode : public Node {
public: public:
/*! \brief The shape of the tensor */ /*! \brief The shape of the tensor */
Array<Expr> shape; Array<Expr> shape;
/*! \brief optional name of the tensor */
std::string name;
/*! \brief data type in the content of the tensor */ /*! \brief data type in the content of the tensor */
Type dtype; Type dtype;
/*! \brief the source operation, can be None */ /*! \brief the source operation, can be None */
...@@ -154,19 +143,11 @@ class TensorNode : public FunctionBaseNode { ...@@ -154,19 +143,11 @@ class TensorNode : public FunctionBaseNode {
void VisitAttrs(AttrVisitor* v) final { void VisitAttrs(AttrVisitor* v) final {
v->Visit("shape", &shape); v->Visit("shape", &shape);
v->Visit("name", &name);
v->Visit("dtype", &dtype); v->Visit("dtype", &dtype);
v->Visit("op", &op); v->Visit("op", &op);
v->Visit("value_index", &value_index); v->Visit("value_index", &value_index);
} }
const std::string& func_name() const final {
return name;
}
int outputs() const final {
return 1;
}
static Tensor make(Array<Expr> shape, static Tensor make(Array<Expr> shape,
std::string name,
Type dtype, Type dtype,
Operation op, Operation op,
int value_index); int value_index);
...@@ -178,16 +159,18 @@ class TensorNode : public FunctionBaseNode { ...@@ -178,16 +159,18 @@ class TensorNode : public FunctionBaseNode {
/*! /*!
* \brief base class of operation node. * \brief base class of operation node.
*/ */
class OperationNode : public Node { class OperationNode : public FunctionBaseNode {
public: public:
/*! \brief optional name of the operation */ /*! \brief optional name of the operation */
std::string name; std::string name;
/*! \return name of the operation */
const std::string& func_name() const final {
return name;
}
/*! \return number of outputs of this op */
virtual int num_outputs() const = 0;
/*! \return the list of iteration variable at root */ /*! \return the list of iteration variable at root */
virtual Array<IterVar> root_iter_vars() const = 0; virtual Array<IterVar> root_iter_vars() const = 0;
/*! \return number of outputs of this op */
virtual size_t num_outputs() const = 0;
/*! \return name of i-th output */
virtual std::string output_name(size_t i) const = 0;
/*! \return type of i-th output */ /*! \return type of i-th output */
virtual Type output_dtype(size_t i) const = 0; virtual Type output_dtype(size_t i) const = 0;
/*! \return shape of i-th output */ /*! \return shape of i-th output */
......
...@@ -33,7 +33,7 @@ def Var(name="tindex", dtype=int32): ...@@ -33,7 +33,7 @@ def Var(name="tindex", dtype=int32):
return _function_internal._Var(name, dtype) return _function_internal._Var(name, dtype)
def placeholder(shape, dtype = None, name="TensorObj"): def placeholder(shape, dtype = None, name="placeholder"):
"""Construct an empty tensor object. """Construct an empty tensor object.
Parameters Parameters
...@@ -53,11 +53,11 @@ def placeholder(shape, dtype = None, name="TensorObj"): ...@@ -53,11 +53,11 @@ def placeholder(shape, dtype = None, name="TensorObj"):
The created tensor The created tensor
""" """
dtype = float32 if dtype is None else dtype dtype = float32 if dtype is None else dtype
return _function_internal._Tensor( return _function_internal._Placeholder(
shape, name, dtype, None, 0) shape, dtype, name)
def compute(shape, fcompute, name="TensorCompute"): def compute(shape, fcompute, name="compute"):
"""Construct a new tensor by computing over the shape domain. """Construct a new tensor by computing over the shape domain.
The compute rule is result[axis] = fcompute(axis) The compute rule is result[axis] = fcompute(axis)
......
...@@ -34,7 +34,9 @@ class Tensor(NodeBase): ...@@ -34,7 +34,9 @@ class Tensor(NodeBase):
else: else:
raise ValueError("The indices must be expression") raise ValueError("The indices must be expression")
return _make.Call(self.dtype, self.name, args, _expr.Call.Halide, self, 0) return _make.Call(self.dtype, self.op.name,
args, _expr.Call.Halide,
self.op, self.value_index)
def __getitem__(self, indices): def __getitem__(self, indices):
return TensorSlice(self, indices) return TensorSlice(self, indices)
...@@ -71,3 +73,7 @@ class Operation(NodeBase): ...@@ -71,3 +73,7 @@ class Operation(NodeBase):
@register_node @register_node
class ComputeOp(Operation): class ComputeOp(Operation):
pass pass
@register_node
class PlaceholderOp(Operation):
pass
...@@ -29,6 +29,17 @@ TVM_REGISTER_API(_make_For) ...@@ -29,6 +29,17 @@ TVM_REGISTER_API(_make_For)
args.at(5)); args.at(5));
}); });
TVM_REGISTER_API(_make_Realize)
.set_body([](const ArgStack& args, RetValue *ret) {
*ret = Realize::make(args.at(0),
args.at(1),
args.at(2),
args.at(3),
args.at(4),
args.at(5));
});
TVM_REGISTER_API(_make_Call) TVM_REGISTER_API(_make_Call)
.set_body([](const ArgStack& args, RetValue *ret) { .set_body([](const ArgStack& args, RetValue *ret) {
*ret = Call::make(args.at(0), *ret = Call::make(args.at(0),
...@@ -113,9 +124,8 @@ REGISTER_MAKE3(LetStmt); ...@@ -113,9 +124,8 @@ REGISTER_MAKE3(LetStmt);
REGISTER_MAKE2(AssertStmt); REGISTER_MAKE2(AssertStmt);
REGISTER_MAKE3(ProducerConsumer); REGISTER_MAKE3(ProducerConsumer);
REGISTER_MAKE3(Store); REGISTER_MAKE3(Store);
REGISTER_MAKE3(Provide); REGISTER_MAKE4(Provide);
REGISTER_MAKE1(Free); REGISTER_MAKE1(Free);
// TODO(tqchen) Realize;
REGISTER_MAKE2(Block); REGISTER_MAKE2(Block);
REGISTER_MAKE3(IfThenElse); REGISTER_MAKE3(IfThenElse);
REGISTER_MAKE1(Evaluate); REGISTER_MAKE1(Evaluate);
......
...@@ -143,7 +143,6 @@ TVM_REGISTER_API(Range) ...@@ -143,7 +143,6 @@ TVM_REGISTER_API(Range)
TVM_REGISTER_API(_Tensor) TVM_REGISTER_API(_Tensor)
.set_body([](const ArgStack& args, RetValue *ret) { .set_body([](const ArgStack& args, RetValue *ret) {
*ret = TensorNode::make(args.at(0), *ret = TensorNode::make(args.at(0),
args.at(1),
args.at(2), args.at(2),
args.at(3), args.at(3),
args.at(4)); args.at(4));
...@@ -160,6 +159,13 @@ TVM_REGISTER_API(_TensorHash) ...@@ -160,6 +159,13 @@ TVM_REGISTER_API(_TensorHash)
std::hash<Tensor>()(args.at(0).operator Tensor())); std::hash<Tensor>()(args.at(0).operator Tensor()));
}); });
TVM_REGISTER_API(_Placeholder)
.set_body([](const ArgStack& args, RetValue *ret) {
*ret = Placeholder(args.at(0),
args.at(1),
args.at(2));
});
TVM_REGISTER_API(_ComputeOp) TVM_REGISTER_API(_ComputeOp)
.set_body([](const ArgStack& args, RetValue *ret) { .set_body([](const ArgStack& args, RetValue *ret) {
*ret = ComputeOpNode::make(args.at(0), *ret = ComputeOpNode::make(args.at(0),
......
...@@ -7,7 +7,6 @@ ...@@ -7,7 +7,6 @@
#include <tvm/ir.h> #include <tvm/ir.h>
#include <tvm/ir_pass.h> #include <tvm/ir_pass.h>
#include "./c_api_registry.h" #include "./c_api_registry.h"
#include "../schedule/bound.h"
namespace tvm { namespace tvm {
namespace ir { namespace ir {
...@@ -36,6 +35,7 @@ using RetValue = APIVariantValue; ...@@ -36,6 +35,7 @@ using RetValue = APIVariantValue;
REGISTER_PASS1(ConvertSSA); REGISTER_PASS1(ConvertSSA);
REGISTER_PASS1(VerifySSA); REGISTER_PASS1(VerifySSA);
REGISTER_PASS4(Inline); REGISTER_PASS4(Inline);
REGISTER_PASS2(ScheduleOps);
} // namespace ir } // namespace ir
} // namespace tvm } // namespace tvm
...@@ -6,8 +6,8 @@ ...@@ -6,8 +6,8 @@
#include <tvm/expr.h> #include <tvm/expr.h>
#include <tvm/tensor.h> #include <tvm/tensor.h>
#include <tvm/schedule.h> #include <tvm/schedule.h>
#include <tvm/schedule_pass.h>
#include "./c_api_registry.h" #include "./c_api_registry.h"
#include "../schedule/bound.h"
#include "../schedule/graph.h" #include "../schedule/graph.h"
namespace tvm { namespace tvm {
......
...@@ -73,5 +73,4 @@ TVM_STATIC_IR_FUNCTOR(IRPrinter, vtable) ...@@ -73,5 +73,4 @@ TVM_STATIC_IR_FUNCTOR(IRPrinter, vtable)
TVM_REGISTER_NODE_TYPE(IterVarNode); TVM_REGISTER_NODE_TYPE(IterVarNode);
} // namespace tvm } // namespace tvm
...@@ -36,7 +36,8 @@ TVM_STATIC_IR_FUNCTOR(IRPrinter, vtable) ...@@ -36,7 +36,8 @@ TVM_STATIC_IR_FUNCTOR(IRPrinter, vtable)
TVM_STATIC_IR_FUNCTOR(IRPrinter, vtable) TVM_STATIC_IR_FUNCTOR(IRPrinter, vtable)
.set_dispatch<AttrStmt>([](const AttrStmt *op, IRPrinter *p) { .set_dispatch<AttrStmt>([](const AttrStmt *op, IRPrinter *p) {
p->stream << "attr " << op->type_key << " = "; p->do_indent();
p->stream << "// attr " << op->type_key << " = ";
p->print(op->value); p->print(op->value);
p->stream << '\n'; p->stream << '\n';
p->print(op->body); p->print(op->body);
......
...@@ -9,11 +9,73 @@ ...@@ -9,11 +9,73 @@
namespace tvm { namespace tvm {
Tensor Operation::output(size_t i) const {
auto node = std::make_shared<TensorNode>();
node->op = *this;
node->value_index = 0;
node->dtype = (*this)->output_dtype(i);
node->shape = (*this)->output_shape(i);
return Tensor(node);
}
// PlaceholderOpNode
TVM_STATIC_IR_FUNCTOR(IRPrinter, vtable) TVM_STATIC_IR_FUNCTOR(IRPrinter, vtable)
.set_dispatch<ComputeOpNode>([](const ComputeOpNode *op, IRPrinter *p) { .set_dispatch<PlaceholderOpNode>([](const PlaceholderOpNode *op, IRPrinter *p) {
p->stream << "op(" << op << ")"; p->stream << "placeholder(" << op->name << ", " << op << ")";
}); });
TVM_REGISTER_NODE_TYPE(PlaceholderOpNode);
Array<IterVar> PlaceholderOpNode::root_iter_vars() const {
return {};
}
Type PlaceholderOpNode::output_dtype(size_t i) const {
CHECK_EQ(i, 0U);
return dtype;
}
Array<Expr> PlaceholderOpNode::output_shape(size_t i) const {
CHECK_EQ(i, 0U);
return shape;
}
Operation PlaceholderOpNode::make(std::string name,
Array<Expr> shape,
Type dtype) {
auto n = std::make_shared<PlaceholderOpNode>();
n->name = name;
n->shape = shape;
n->dtype = dtype;
return Operation(n);
}
Tensor Placeholder(Array<Expr> shape, Type dtype, std::string name) {
return PlaceholderOpNode::make(name, shape, dtype).output(0);
}
// ComputeOpNode
Array<IterVar> ComputeOpNode::root_iter_vars() const {
return axis;
}
Type ComputeOpNode::output_dtype(size_t i) const {
CHECK_EQ(i, 0U);
return body.type();
}
Array<Expr> ComputeOpNode::output_shape(size_t i) const {
CHECK_EQ(i, 0U);
std::vector<Expr> shape;
for (size_t i = 0; i < axis.size(); ++i) {
const Range& r = axis[i]->dom;
shape.push_back(r->extent);
}
return Array<Expr>(shape);
}
Tensor Compute(Array<Expr> shape, FCompute fcompute, std::string name) { Tensor Compute(Array<Expr> shape, FCompute fcompute, std::string name) {
auto op_node = std::make_shared<ComputeOpNode>(); auto op_node = std::make_shared<ComputeOpNode>();
// compute dimension. // compute dimension.
...@@ -43,39 +105,10 @@ Operation ComputeOpNode::make(std::string name, ...@@ -43,39 +105,10 @@ Operation ComputeOpNode::make(std::string name,
return Operation(n); return Operation(n);
} }
Tensor Operation::output(size_t i) const { TVM_STATIC_IR_FUNCTOR(IRPrinter, vtable)
auto node = std::make_shared<TensorNode>(); .set_dispatch<ComputeOpNode>([](const ComputeOpNode *op, IRPrinter *p) {
node->op = *this; p->stream << "compute(" << op->name << ", " << op << ")";
node->value_index = 0; });
node->name = (*this)->output_name(i);
node->dtype = (*this)->output_dtype(i);
node->shape = (*this)->output_shape(i);
return Tensor(node);
}
Array<IterVar> ComputeOpNode::root_iter_vars() const {
return axis;
}
std::string ComputeOpNode::output_name(size_t i) const {
CHECK_EQ(i, 0U);
return name;
}
Type ComputeOpNode::output_dtype(size_t i) const {
CHECK_EQ(i, 0U);
return body.type();
}
Array<Expr> ComputeOpNode::output_shape(size_t i) const {
CHECK_EQ(i, 0U);
std::vector<Expr> shape;
for (size_t i = 0; i < axis.size(); ++i) {
const Range& r = axis[i]->dom;
shape.push_back(r->extent);
}
return Array<Expr>(shape);
}
TVM_REGISTER_NODE_TYPE(ComputeOpNode); TVM_REGISTER_NODE_TYPE(ComputeOpNode);
......
...@@ -8,33 +8,24 @@ ...@@ -8,33 +8,24 @@
namespace tvm { namespace tvm {
Tensor::Tensor(Array<Expr> shape, std::string name, Type dtype) {
auto node = std::make_shared<TensorNode>();
node->name = std::move(name);
node->dtype = dtype;
node->shape = std::move(shape);
node_ = std::move(node);
}
Expr Tensor::operator()(Array<Expr> indices) const { Expr Tensor::operator()(Array<Expr> indices) const {
using Halide::Internal::Call; using Halide::Internal::Call;
CHECK_EQ(ndim(), indices.size()) CHECK_EQ(ndim(), indices.size())
<< "Tensor dimension mismatch in read" << "Tensor dimension mismatch in read"
<< "ndim = " << ndim() << ", indices.size=" << indices.size(); << "ndim = " << ndim() << ", indices.size=" << indices.size();
auto n = Call::make( auto n = Call::make(
(*this)->dtype, (*this)->name, indices, Call::Halide, *this); (*this)->dtype, (*this)->op->name, indices, Call::Halide,
(*this)->op, (*this)->value_index);
return n; return n;
} }
Tensor TensorNode::make(Array<Expr> shape, Tensor TensorNode::make(Array<Expr> shape,
std::string name,
Type dtype, Type dtype,
Operation op, Operation op,
int value_index) { int value_index) {
auto n = std::make_shared<TensorNode>(); auto n = std::make_shared<TensorNode>();
n->shape = shape; n->shape = shape;
n->name = name;
n->dtype = dtype; n->dtype = dtype;
n->op = op; n->op = op;
n->value_index = value_index; n->value_index = value_index;
...@@ -44,7 +35,7 @@ Tensor TensorNode::make(Array<Expr> shape, ...@@ -44,7 +35,7 @@ Tensor TensorNode::make(Array<Expr> shape,
TVM_STATIC_IR_FUNCTOR(IRPrinter, vtable) TVM_STATIC_IR_FUNCTOR(IRPrinter, vtable)
.set_dispatch<TensorNode>([](const TensorNode *t, IRPrinter *p) { .set_dispatch<TensorNode>([](const TensorNode *t, IRPrinter *p) {
p->stream << "Tensor(shape=" << t->shape p->stream << "Tensor(shape=" << t->shape
<< ", name=" << t->name << ')'; << ", op.name=" << t->op->name << ')';
}); });
TVM_REGISTER_NODE_TYPE(TensorNode); TVM_REGISTER_NODE_TYPE(TensorNode);
......
...@@ -22,6 +22,7 @@ class IRInline : public IRMutator { ...@@ -22,6 +22,7 @@ class IRInline : public IRMutator {
expr = IRMutator::Mutate(expr); expr = IRMutator::Mutate(expr);
const Call* call = expr.as<Call>(); const Call* call = expr.as<Call>();
if (call != nullptr && call->func == f_) { if (call != nullptr && call->func == f_) {
CHECK_EQ(call->value_index, 0);
return InlineCall(call); return InlineCall(call);
} else { } else {
return expr; return expr;
...@@ -55,6 +56,8 @@ Stmt Inline(FunctionRef f, ...@@ -55,6 +56,8 @@ Stmt Inline(FunctionRef f,
Array<Var> args, Array<Var> args,
Expr body, Expr body,
Stmt stmt) { Stmt stmt) {
CHECK_EQ(f->num_outputs(), 1)
<< "can only inline output single value operation";
return ConvertSSA(IRInline(f, args, body).Mutate(stmt)); return ConvertSSA(IRInline(f, args, body).Mutate(stmt));
} }
} // namespace ir } // namespace ir
......
...@@ -254,11 +254,11 @@ TVM_STATIC_IR_FUNCTOR(IRMutator, vtable_stmt) ...@@ -254,11 +254,11 @@ TVM_STATIC_IR_FUNCTOR(IRMutator, vtable_stmt)
}) })
.set_dispatch<Provide>([](const Provide *op, const Stmt& s, IRMutator* m) { .set_dispatch<Provide>([](const Provide *op, const Stmt& s, IRMutator* m) {
auto new_args = MutateArray(op->args, m); auto new_args = MutateArray(op->args, m);
auto new_values = MutateArray(op->values, m); auto new_value = m->Mutate(op->value);
if (op->args.same_as(new_args) && op->values.same_as(new_values)) { if (op->args.same_as(new_args) && op->value.same_as(new_value)) {
return s; return s;
} else { } else {
return Provide::make(op->func, new_values, new_args); return Provide::make(op->func, op->value_index, new_value, new_args);
} }
}) })
.set_dispatch<Allocate>([](const Allocate *op, const Stmt& s, IRMutator* m) { .set_dispatch<Allocate>([](const Allocate *op, const Stmt& s, IRMutator* m) {
...@@ -312,7 +312,8 @@ TVM_STATIC_IR_FUNCTOR(IRMutator, vtable_stmt) ...@@ -312,7 +312,8 @@ TVM_STATIC_IR_FUNCTOR(IRMutator, vtable_stmt)
condition.same_as(op->condition)) { condition.same_as(op->condition)) {
return s; return s;
} else { } else {
return Realize::make(op->func, op->types, new_bounds, return Realize::make(op->func, op->value_index,
op->type, new_bounds,
condition, body); condition, body);
} }
}) })
...@@ -329,7 +330,10 @@ TVM_STATIC_IR_FUNCTOR(IRMutator, vtable_stmt) ...@@ -329,7 +330,10 @@ TVM_STATIC_IR_FUNCTOR(IRMutator, vtable_stmt)
.set_dispatch<IfThenElse>([](const IfThenElse *op, const Stmt& s, IRMutator* m) { .set_dispatch<IfThenElse>([](const IfThenElse *op, const Stmt& s, IRMutator* m) {
Expr condition = m->Mutate(op->condition); Expr condition = m->Mutate(op->condition);
Stmt then_case = m->Mutate(op->then_case); Stmt then_case = m->Mutate(op->then_case);
Stmt else_case = m->Mutate(op->else_case); Stmt else_case;
if (else_case.defined()) {
else_case = m->Mutate(op->else_case);
}
if (condition.same_as(op->condition) && if (condition.same_as(op->condition) &&
then_case.same_as(op->then_case) && then_case.same_as(op->then_case) &&
else_case.same_as(op->else_case)) { else_case.same_as(op->else_case)) {
......
...@@ -157,7 +157,7 @@ TVM_STATIC_IR_FUNCTOR(IRVisitor, vtable) ...@@ -157,7 +157,7 @@ TVM_STATIC_IR_FUNCTOR(IRVisitor, vtable)
}) })
.set_dispatch<Provide>([](const Provide *op, IRVisitor* v) { .set_dispatch<Provide>([](const Provide *op, IRVisitor* v) {
VisitArray(op->args, v); VisitArray(op->args, v);
VisitArray(op->values, v); v->Visit(op->value);
}) })
.set_dispatch<Allocate>([](const Allocate *op, IRVisitor* v) { .set_dispatch<Allocate>([](const Allocate *op, IRVisitor* v) {
for (size_t i = 0; i < op->extents.size(); i++) { for (size_t i = 0; i < op->extents.size(); i++) {
......
...@@ -36,7 +36,7 @@ class Scope { ...@@ -36,7 +36,7 @@ class Scope {
*/ */
inline void Pop(const K& key) { inline void Pop(const K& key) {
auto& v = data_[key]; auto& v = data_[key];
CHECK_NE(v.size(), 0); CHECK_NE(v.size(), 0U);
v.pop_back(); v.pop_back();
} }
......
...@@ -5,8 +5,8 @@ ...@@ -5,8 +5,8 @@
*/ */
#include <tvm/ir.h> #include <tvm/ir.h>
#include <tvm/ir_visitor.h> #include <tvm/ir_visitor.h>
#include <tvm/schedule_pass.h>
#include "./int_set.h" #include "./int_set.h"
#include "./bound.h"
#include "./graph.h" #include "./graph.h"
namespace tvm { namespace tvm {
...@@ -113,7 +113,7 @@ void PassToOperation( ...@@ -113,7 +113,7 @@ void PassToOperation(
(*result)[root_iter_vars[i]].push_back(dim_bounds[i]); (*result)[root_iter_vars[i]].push_back(dim_bounds[i]);
} }
} else { } else {
LOG(FATAL) << "unknown operation mode"; LOG(FATAL) << "unknown operation mode " << tensor->op->type_key();
} }
} }
...@@ -140,8 +140,8 @@ BoundProp(const Array<Operation>& post_order, ...@@ -140,8 +140,8 @@ BoundProp(const Array<Operation>& post_order,
auto fvisit = [p_state, &result](const NodeRef& n) { auto fvisit = [p_state, &result](const NodeRef& n) {
auto *call = n.as<ir::Call>(); auto *call = n.as<ir::Call>();
if (call != nullptr && call->func.defined()) { if (call != nullptr && call->func.defined()) {
Tensor t(call->func.node_); Tensor t = Operation(call->func.node_).output(call->value_index);
if (t->op.defined()) { if (t->op.defined() && !t->op.as<PlaceholderOpNode>()) {
std::vector<IntSet> arg_bounds; std::vector<IntSet> arg_bounds;
for (size_t i = 0; i < t.ndim(); ++i) { for (size_t i = 0; i < t.ndim(); ++i) {
arg_bounds.push_back(EvalSet(call->args[i], result)); arg_bounds.push_back(EvalSet(call->args[i], result));
......
...@@ -27,18 +27,20 @@ ReadGraph CreateReadGraph(const Operation& root) { ...@@ -27,18 +27,20 @@ ReadGraph CreateReadGraph(const Operation& root) {
auto fvisit = [&deps, &visited, &stack](const NodeRef& n) { auto fvisit = [&deps, &visited, &stack](const NodeRef& n) {
auto *call = n.as<ir::Call>(); auto *call = n.as<ir::Call>();
if (call != nullptr && call->func.defined()) { if (call != nullptr && call->func.defined()) {
Tensor t(call->func.node_); Operation call_op(call->func.node_);
deps.push_back(t); deps.push_back(call_op.output(call->value_index));
if (t->op.defined() && visited.count(t->op.get()) == 0) { if (call_op.defined() && visited.count(call_op.get()) == 0) {
visited.insert(t->op.get()); visited.insert(call_op.get());
stack.push_back(t->op); stack.push_back(call_op);
} }
} }
}; };
ir::PostOrderVisit(op.as<ComputeOpNode>()->body, fvisit); ir::PostOrderVisit(op.as<ComputeOpNode>()->body, fvisit);
rmap.Set(op, deps); rmap.Set(op, deps);
} else { } else {
LOG(FATAL) << "unknown operation mode"; if (!op.as<PlaceholderOpNode>()) {
LOG(FATAL) << "unknown Operation" << op->type_key();
}
} }
} }
return rmap; return rmap;
...@@ -51,7 +53,7 @@ void PostDFSOrder(const Operation& op, ...@@ -51,7 +53,7 @@ void PostDFSOrder(const Operation& op,
Array<Operation>* post_order) { Array<Operation>* post_order) {
visited->insert(op); visited->insert(op);
for (const auto& t : g.at(op)) { for (const auto& t : g.at(op)) {
if (t->op.defined() && !visited->count(t->op)) { if (!t->op.as<PlaceholderOpNode>() && !visited->count(t->op)) {
PostDFSOrder(t->op, g, visited, post_order); PostDFSOrder(t->op, g, visited, post_order);
} }
} }
......
...@@ -220,7 +220,7 @@ void PassUp(const SplitNode* s, ...@@ -220,7 +220,7 @@ void PassUp(const SplitNode* s,
*parent = IntSet::make_range(dom_map.at(s->parent)); *parent = IntSet::make_range(dom_map.at(s->parent));
return; return;
} }
Expr factor = dom_map.at(s->outer)->extent; Expr factor = dom_map.at(s->inner)->extent;
CHECK(outer.defined()); CHECK(outer.defined());
CHECK(inner.defined()); CHECK(inner.defined());
CHECK(factor.defined()); CHECK(factor.defined());
...@@ -261,7 +261,7 @@ void PassUp(const FuseNode* s, ...@@ -261,7 +261,7 @@ void PassUp(const FuseNode* s,
if (IsNumber(fused)) { if (IsNumber(fused)) {
Expr value = AsNumber(fused); Expr value = AsNumber(fused);
Expr factor = dom_map.at(s->outer)->extent; Expr factor = dom_map.at(s->inner)->extent;
*outer = IntSet::make_point(value / factor); *outer = IntSet::make_point(value / factor);
*inner = IntSet::make_point(value % factor); *inner = IntSet::make_point(value % factor);
} else { } else {
......
...@@ -5,8 +5,9 @@ ...@@ -5,8 +5,9 @@
TEST(Tensor, Basic) { TEST(Tensor, Basic) {
using namespace tvm; using namespace tvm;
Var m("m"), n("n"), l("l"); Var m("m"), n("n"), l("l");
Tensor A({m, l}, "A");
Tensor B({n, l}, "B"); Tensor A = Placeholder({m, l}, Float(32), "A");
Tensor B = Placeholder({n, l}, Float(32), "B");
auto C = Compute({m, n}, [&](Var i, Var j) { auto C = Compute({m, n}, [&](Var i, Var j) {
return A[i][j]; return A[i][j];
...@@ -19,8 +20,8 @@ TEST(Tensor, Basic) { ...@@ -19,8 +20,8 @@ TEST(Tensor, Basic) {
TEST(Tensor, Reduce) { TEST(Tensor, Reduce) {
using namespace tvm; using namespace tvm;
Var m("m"), n("n"), l("l"); Var m("m"), n("n"), l("l");
Tensor A({m, l}, "A"); Tensor A = Placeholder({m, l}, Float(32), "A");
Tensor B({n, l}, "B"); Tensor B = Placeholder({n, l}, Float(32), "B");
IterVar rv(Range{0, l}, "k"); IterVar rv(Range{0, l}, "k");
auto C = Compute({m, n}, [&](Var i, Var j) { auto C = Compute({m, n}, [&](Var i, Var j) {
......
...@@ -10,7 +10,7 @@ def test_tensor(): ...@@ -10,7 +10,7 @@ def test_tensor():
print(T) print(T)
print(T.op.body) print(T.op.body)
assert(tuple(T.shape) == (m, n, l)) assert(tuple(T.shape) == (m, n, l))
assert(A.op is None) assert(isinstance(A.op, tvm.tensor.PlaceholderOp))
assert(A == A) assert(A == A)
assert(T.op.output(0) == T) assert(T.op.output(0) == T)
assert(T.op.output(0).__hash__() == T.__hash__()) assert(T.op.output(0).__hash__() == T.__hash__())
......
...@@ -6,7 +6,7 @@ def test_inline(): ...@@ -6,7 +6,7 @@ def test_inline():
T = tvm.compute((m,), lambda i,: A[i] + 10, name='T') T = tvm.compute((m,), lambda i,: A[i] + 10, name='T')
stmt = tvm.make.Evaluate(T[10] + 11 * T[100]) stmt = tvm.make.Evaluate(T[10] + 11 * T[100])
stmt = tvm.ir_pass.Inline( stmt = tvm.ir_pass.Inline(
T, [x.var for x in T.op.axis], T.op.body, stmt) T.op, [x.var for x in T.op.axis], T.op.body, stmt)
print(stmt) print(stmt)
assert(tvm.ir_pass.VerifySSA(stmt)) assert(tvm.ir_pass.VerifySSA(stmt))
...@@ -14,7 +14,7 @@ def test_inline(): ...@@ -14,7 +14,7 @@ def test_inline():
# pass in int array(wrong argument type) # pass in int array(wrong argument type)
# must raise an error # must raise an error
stmt = tvm.ir_pass.Inline( stmt = tvm.ir_pass.Inline(
T, [1,2,3], T.op.body, stmt) T.op, [1,2,3], T.op.body, stmt)
assert False assert False
except tvm.TVMError: except tvm.TVMError:
pass pass
......
import tvm
def test_schedule0():
m = tvm.Var('m')
l = tvm.Var('l')
A = tvm.placeholder((m, l), name='A')
A1 = tvm.compute((m, l), lambda i, j: A[i, j], name='A1')
sA1 = tvm.Schedule(A1.op)
bounds = tvm.schedule.InferBound(sA1)
assert isinstance(bounds, tvm.collections.Map)
stmt = tvm.ir_pass.ScheduleOps(sA1, bounds)
print(stmt)
def test_schedule1():
m = tvm.Var('m')
l = tvm.Var('l')
A = tvm.placeholder((m, l), name='A')
A1 = tvm.compute((m, l), lambda i, j: A[i, j], name='A1')
sA1 = tvm.Schedule(A1.op)
xo, xi = sA1.split(A1.op.axis[0], 8)
bounds = tvm.schedule.InferBound(sA1)
assert isinstance(bounds, tvm.collections.Map)
stmt = tvm.ir_pass.ScheduleOps(sA1, bounds)
print(stmt)
def test_schedule2():
m = tvm.Var('m')
l = tvm.Var('l')
A = tvm.placeholder((m, l), name='A')
A1 = tvm.compute((m, l), lambda i, j: A[i, j], name='A1')
A2 = tvm.compute((m, l), lambda i, j: A1[i, j] + 3, name='A2')
sA1 = tvm.Schedule(A1.op)
sA2 = tvm.Schedule(A2.op)
xo, xi = sA2.split(A2.op.axis[0], 8)
sA1.compute_at(sA2, xo)
bounds = tvm.schedule.InferBound(sA2)
assert isinstance(bounds, tvm.collections.Map)
stmt = tvm.ir_pass.ScheduleOps(sA2, bounds)
print(stmt)
if __name__ == "__main__":
test_schedule0()
test_schedule1()
test_schedule2()
...@@ -65,7 +65,7 @@ def test_create_read_graph(): ...@@ -65,7 +65,7 @@ def test_create_read_graph():
if __name__ == "__main__": if __name__ == "__main__":
test_create_read_graph()
test_bound3() test_bound3()
test_bound1() test_bound1()
test_bound2() test_bound2()
test_create_read_graph()
Markdown is supported
0% or
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment