Commit 78ea652d by Tianqi Chen, committed by Haichen Shen

[PASS] Schedule Ops init working version (#6)

* [PASS] Schedule Ops init working version

* bugfix in PassUp: use the inner domain's extent as the split/fuse factor
parent 302c2e64
Subproject commit 5d1bd103c2abe19392b4d8def7e3ff1c854e8683
Subproject commit 1ec478bbd0c20b8659f0c897363b5a76e13ef495
......@@ -17,6 +17,7 @@ namespace tvm {
using Halide::Type;
using Halide::Float;
using Halide::Bool;
using Halide::Int;
using Halide::UInt;
using Halide::Handle;
......@@ -29,6 +30,8 @@ using Halide::Internal::Stmt;
using Halide::Internal::IRPrinter;
using Halide::Internal::Variable;
using Halide::Internal::make_const;
/*! \brief a named variable in TVM */
class Var : public Halide::VarExpr {
public:
......
......@@ -18,6 +18,16 @@
namespace tvm {
namespace ir {
/*!
* \brief Schedule s's dependent operations.
*
* \param s The schedule to be realized.
* \param dom_map The domain of each IterVar.
* \return The result Stmt.
*/
Stmt ScheduleOps(Schedule s, Map<IterVar, Range> dom_map);
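// A minimal usage sketch (assumes a Schedule `s`; the dom_map is normally the
// result of schedule::InferBound, declared in tvm/schedule_pass.h):
//   Map<IterVar, Range> bounds = schedule::InferBound(s);
//   Stmt stmt = ScheduleOps(s, bounds);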
/*!
* \brief verifies whether the IR stmt or Expr is in SSA form.
* That is: each VarExpr is defined and assigned once(in Let/For)
......@@ -51,14 +61,6 @@ Stmt Inline(FunctionRef f,
Expr body,
Stmt stmt);
/*!
* \brief Schedule s' dependent operations.
*
* \param s The schedule to be realized
* \return the result Stmt
*/
Stmt ScheduelOps(Schedule s);
} // namespace ir
} // namespace tvm
......
......@@ -13,6 +13,36 @@
namespace tvm {
/*!
* \brief A placeholder op represents an input placeholder.
*/
class PlaceholderOpNode : public OperationNode {
public:
/*! \brief The shape of the input */
Array<Expr> shape;
/*! \brief The data type of the input. */
Type dtype;
int num_outputs() const final {
return 1;
}
Array<IterVar> root_iter_vars() const final;
Type output_dtype(size_t i) const final;
Array<Expr> output_shape(size_t i) const final;
void VisitAttrs(AttrVisitor* v) final {
v->Visit("name", &name);
v->Visit("shape", &shape);
v->Visit("dtype", &dtype);
}
static Operation make(std::string name,
Array<Expr> shape,
Type dtype);
static constexpr const char* _type_key = "PlaceholderOp";
TVM_DECLARE_NODE_TYPE_INFO(PlaceholderOpNode);
};
/*!
* \brief A Compute op that computes a tensor on a certain domain.
*/
class ComputeOpNode : public OperationNode {
......@@ -24,11 +54,10 @@ class ComputeOpNode : public OperationNode {
/*! \brief constructor */
ComputeOpNode() {}
size_t num_outputs() const final {
int num_outputs() const final {
return 1;
}
Array<IterVar> root_iter_vars() const final;
std::string output_name(size_t i) const final;
Type output_dtype(size_t i) const final;
Array<Expr> output_shape(size_t i) const final;
......@@ -50,6 +79,16 @@ class ComputeOpNode : public OperationNode {
using FCompute = std::function<Expr (const Array<Var>& i)>;
/*!
* \brief Create a placeholder tensor.
* \param shape The shape of the tensor.
* \param dtype The data type of the tensor.
* \param name The name of the Tensor.
*/
Tensor Placeholder(Array<Expr> shape,
Type dtype = Float(32),
std::string name = "placeholder");
/*!
* \brief Construct a new tensor by computing over shape,
* using the computation rule: result_tensor[axis] = fcompute(axis)
* \param shape Shape of the tensor.
......
/*!
* Copyright (c) 2016 by Contributors
* \file bound.h
* \brief The bound inference logics on the schedule.
* \file schedule_pass.h
* \brief Collection of Schedule pass functions.
*
* These passes work on the schedule hyper-graph
* and infer information such as bounds, check conditions,
* and read/write dependencies between the IterVars.
*/
#ifndef TVM_SCHEDULE_BOUND_H_
#define TVM_SCHEDULE_BOUND_H_
#ifndef TVM_SCHEDULE_PASS_H_
#define TVM_SCHEDULE_PASS_H_
#include <tvm/expr.h>
#include <tvm/schedule.h>
#include <unordered_map>
#include "./base.h"
#include "./schedule.h"
namespace tvm {
namespace schedule {
......@@ -23,5 +26,4 @@ Map<IterVar, Range> InferBound(Schedule sch);
} // namespace schedule
} // namespace tvm
#endif // TVM_SCHEDULE_BOUND_H_
#endif // TVM_SCHEDULE_PASS_H_
......@@ -28,20 +28,11 @@ using Halide::IR::FunctionRef;
* \brief Tensor structure representing a possible input,
* or intermediate computation result.
*/
class Tensor : public FunctionRef {
class Tensor : public NodeRef {
public:
/*! \brief default constructor, used internally */
Tensor() {}
explicit Tensor(std::shared_ptr<Node> n) : FunctionRef(n) {}
/*!
* \brief constructor of input tensor
* \param shape Shape of the tensor.
* \param name optional name of the Tensor.
* \param dtype The data type of the input tensor.
*/
explicit Tensor(Array<Expr> shape,
std::string name = "tensor",
Type dtype = Float(32));
explicit Tensor(std::shared_ptr<Node> n) : NodeRef(n) {}
/*!
* \brief access the internal node container
* \return the pointer to the internal node container
......@@ -116,11 +107,11 @@ class Tensor : public FunctionRef {
};
/*! \brief Operation that produces tensors */
class Operation : public NodeRef {
class Operation : public FunctionRef {
public:
/*! \brief default constructor */
Operation() {}
explicit Operation(std::shared_ptr<Node> n) : NodeRef(n) {}
explicit Operation(std::shared_ptr<Node> n) : FunctionRef(n) {}
/*!
* \brief access the internal node container
* \return the pointer to the internal node container
......@@ -137,12 +128,10 @@ class Operation : public NodeRef {
};
/*! \brief Node to represent a tensor */
class TensorNode : public FunctionBaseNode {
class TensorNode : public Node {
public:
/*! \brief The shape of the tensor */
Array<Expr> shape;
/*! \brief optional name of the tensor */
std::string name;
/*! \brief data type in the content of the tensor */
Type dtype;
/*! \brief the source operation, can be None */
......@@ -154,19 +143,11 @@ class TensorNode : public FunctionBaseNode {
void VisitAttrs(AttrVisitor* v) final {
v->Visit("shape", &shape);
v->Visit("name", &name);
v->Visit("dtype", &dtype);
v->Visit("op", &op);
v->Visit("value_index", &value_index);
}
const std::string& func_name() const final {
return name;
}
int outputs() const final {
return 1;
}
static Tensor make(Array<Expr> shape,
std::string name,
Type dtype,
Operation op,
int value_index);
......@@ -178,16 +159,18 @@ class TensorNode : public FunctionBaseNode {
/*!
* \brief base class of operation node.
*/
class OperationNode : public Node {
class OperationNode : public FunctionBaseNode {
public:
/*! \brief optional name of the operation */
std::string name;
/*! \return name of the operation */
const std::string& func_name() const final {
return name;
}
/*! \return number of outputs of this op */
virtual int num_outputs() const = 0;
/*! \return the list of iteration variable at root */
virtual Array<IterVar> root_iter_vars() const = 0;
/*! \return number of outputs of this op */
virtual size_t num_outputs() const = 0;
/*! \return name of i-th output */
virtual std::string output_name(size_t i) const = 0;
/*! \return type of i-th output */
virtual Type output_dtype(size_t i) const = 0;
/*! \return shape of i-th output */
......
......@@ -33,7 +33,7 @@ def Var(name="tindex", dtype=int32):
return _function_internal._Var(name, dtype)
def placeholder(shape, dtype = None, name="TensorObj"):
def placeholder(shape, dtype = None, name="placeholder"):
"""Construct an empty tensor object.
Parameters
......@@ -53,11 +53,11 @@ def placeholder(shape, dtype = None, name="TensorObj"):
The created tensor
"""
dtype = float32 if dtype is None else dtype
return _function_internal._Tensor(
shape, name, dtype, None, 0)
return _function_internal._Placeholder(
shape, dtype, name)
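# Usage sketch (names are illustrative; dtype defaults to float32):
#   A = tvm.placeholder((m, n), name='A')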
def compute(shape, fcompute, name="TensorCompute"):
def compute(shape, fcompute, name="compute"):
"""Construct a new tensor by computing over the shape domain.
The compute rule is result[axis] = fcompute(axis)
......
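# Usage sketch for compute (builds on the placeholder A above; illustrative):
#   B = tvm.compute((m, n), lambda i, j: A[i, j] + 1, name='B')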
......@@ -34,7 +34,9 @@ class Tensor(NodeBase):
else:
raise ValueError("The indices must be expression")
return _make.Call(self.dtype, self.name, args, _expr.Call.Halide, self, 0)
return _make.Call(self.dtype, self.op.name,
args, _expr.Call.Halide,
self.op, self.value_index)
def __getitem__(self, indices):
return TensorSlice(self, indices)
......@@ -71,3 +73,7 @@ class Operation(NodeBase):
@register_node
class ComputeOp(Operation):
pass
@register_node
class PlaceholderOp(Operation):
pass
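# Example (sketch): the op behind a placeholder tensor is now a PlaceholderOp:
#   A = tvm.placeholder((m, n), name='A')
#   assert isinstance(A.op, tvm.tensor.PlaceholderOp)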
......@@ -29,6 +29,17 @@ TVM_REGISTER_API(_make_For)
args.at(5));
});
TVM_REGISTER_API(_make_Realize)
.set_body([](const ArgStack& args, RetValue *ret) {
*ret = Realize::make(args.at(0),
args.at(1),
args.at(2),
args.at(3),
args.at(4),
args.at(5));
});
TVM_REGISTER_API(_make_Call)
.set_body([](const ArgStack& args, RetValue *ret) {
*ret = Call::make(args.at(0),
......@@ -113,9 +124,8 @@ REGISTER_MAKE3(LetStmt);
REGISTER_MAKE2(AssertStmt);
REGISTER_MAKE3(ProducerConsumer);
REGISTER_MAKE3(Store);
REGISTER_MAKE3(Provide);
REGISTER_MAKE4(Provide);
REGISTER_MAKE1(Free);
// TODO(tqchen) Realize;
REGISTER_MAKE2(Block);
REGISTER_MAKE3(IfThenElse);
REGISTER_MAKE1(Evaluate);
......
......@@ -143,7 +143,6 @@ TVM_REGISTER_API(Range)
TVM_REGISTER_API(_Tensor)
.set_body([](const ArgStack& args, RetValue *ret) {
*ret = TensorNode::make(args.at(0),
args.at(1),
args.at(2),
args.at(3),
args.at(4));
......@@ -160,6 +159,13 @@ TVM_REGISTER_API(_TensorHash)
std::hash<Tensor>()(args.at(0).operator Tensor()));
});
TVM_REGISTER_API(_Placeholder)
.set_body([](const ArgStack& args, RetValue *ret) {
*ret = Placeholder(args.at(0),
args.at(1),
args.at(2));
});
TVM_REGISTER_API(_ComputeOp)
.set_body([](const ArgStack& args, RetValue *ret) {
*ret = ComputeOpNode::make(args.at(0),
......
......@@ -7,7 +7,6 @@
#include <tvm/ir.h>
#include <tvm/ir_pass.h>
#include "./c_api_registry.h"
#include "../schedule/bound.h"
namespace tvm {
namespace ir {
......@@ -36,6 +35,7 @@ using RetValue = APIVariantValue;
REGISTER_PASS1(ConvertSSA);
REGISTER_PASS1(VerifySSA);
REGISTER_PASS4(Inline);
REGISTER_PASS2(ScheduleOps);
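// ScheduleOps takes two arguments (Schedule, Map<IterVar, Range>), hence PASS2.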
} // namespace ir
} // namespace tvm
......@@ -6,8 +6,8 @@
#include <tvm/expr.h>
#include <tvm/tensor.h>
#include <tvm/schedule.h>
#include <tvm/schedule_pass.h>
#include "./c_api_registry.h"
#include "../schedule/bound.h"
#include "../schedule/graph.h"
namespace tvm {
......
......@@ -73,5 +73,4 @@ TVM_STATIC_IR_FUNCTOR(IRPrinter, vtable)
TVM_REGISTER_NODE_TYPE(IterVarNode);
} // namespace tvm
......@@ -36,7 +36,8 @@ TVM_STATIC_IR_FUNCTOR(IRPrinter, vtable)
TVM_STATIC_IR_FUNCTOR(IRPrinter, vtable)
.set_dispatch<AttrStmt>([](const AttrStmt *op, IRPrinter *p) {
p->stream << "attr " << op->type_key << " = ";
p->do_indent();
p->stream << "// attr " << op->type_key << " = ";
p->print(op->value);
p->stream << '\n';
p->print(op->body);
......
......@@ -9,11 +9,73 @@
namespace tvm {
Tensor Operation::output(size_t i) const {
auto node = std::make_shared<TensorNode>();
node->op = *this;
node->value_index = 0;
node->dtype = (*this)->output_dtype(i);
node->shape = (*this)->output_shape(i);
return Tensor(node);
}
// PlaceholderOpNode
TVM_STATIC_IR_FUNCTOR(IRPrinter, vtable)
.set_dispatch<ComputeOpNode>([](const ComputeOpNode *op, IRPrinter *p) {
p->stream << "op(" << op << ")";
.set_dispatch<PlaceholderOpNode>([](const PlaceholderOpNode *op, IRPrinter *p) {
p->stream << "placeholder(" << op->name << ", " << op << ")";
});
TVM_REGISTER_NODE_TYPE(PlaceholderOpNode);
Array<IterVar> PlaceholderOpNode::root_iter_vars() const {
return {};
}
Type PlaceholderOpNode::output_dtype(size_t i) const {
CHECK_EQ(i, 0U);
return dtype;
}
Array<Expr> PlaceholderOpNode::output_shape(size_t i) const {
CHECK_EQ(i, 0U);
return shape;
}
Operation PlaceholderOpNode::make(std::string name,
Array<Expr> shape,
Type dtype) {
auto n = std::make_shared<PlaceholderOpNode>();
n->name = name;
n->shape = shape;
n->dtype = dtype;
return Operation(n);
}
Tensor Placeholder(Array<Expr> shape, Type dtype, std::string name) {
return PlaceholderOpNode::make(name, shape, dtype).output(0);
}
// ComputeOpNode
Array<IterVar> ComputeOpNode::root_iter_vars() const {
return axis;
}
Type ComputeOpNode::output_dtype(size_t i) const {
CHECK_EQ(i, 0U);
return body.type();
}
Array<Expr> ComputeOpNode::output_shape(size_t i) const {
CHECK_EQ(i, 0U);
std::vector<Expr> shape;
for (size_t i = 0; i < axis.size(); ++i) {
const Range& r = axis[i]->dom;
shape.push_back(r->extent);
}
return Array<Expr>(shape);
}
Tensor Compute(Array<Expr> shape, FCompute fcompute, std::string name) {
auto op_node = std::make_shared<ComputeOpNode>();
// compute dimension.
......@@ -43,39 +105,10 @@ Operation ComputeOpNode::make(std::string name,
return Operation(n);
}
Tensor Operation::output(size_t i) const {
auto node = std::make_shared<TensorNode>();
node->op = *this;
node->value_index = 0;
node->name = (*this)->output_name(i);
node->dtype = (*this)->output_dtype(i);
node->shape = (*this)->output_shape(i);
return Tensor(node);
}
Array<IterVar> ComputeOpNode::root_iter_vars() const {
return axis;
}
std::string ComputeOpNode::output_name(size_t i) const {
CHECK_EQ(i, 0U);
return name;
}
Type ComputeOpNode::output_dtype(size_t i) const {
CHECK_EQ(i, 0U);
return body.type();
}
Array<Expr> ComputeOpNode::output_shape(size_t i) const {
CHECK_EQ(i, 0U);
std::vector<Expr> shape;
for (size_t i = 0; i < axis.size(); ++i) {
const Range& r = axis[i]->dom;
shape.push_back(r->extent);
}
return Array<Expr>(shape);
}
TVM_STATIC_IR_FUNCTOR(IRPrinter, vtable)
.set_dispatch<ComputeOpNode>([](const ComputeOpNode *op, IRPrinter *p) {
p->stream << "compute(" << op->name << ", " << op << ")";
});
TVM_REGISTER_NODE_TYPE(ComputeOpNode);
......
......@@ -8,33 +8,24 @@
namespace tvm {
Tensor::Tensor(Array<Expr> shape, std::string name, Type dtype) {
auto node = std::make_shared<TensorNode>();
node->name = std::move(name);
node->dtype = dtype;
node->shape = std::move(shape);
node_ = std::move(node);
}
Expr Tensor::operator()(Array<Expr> indices) const {
using Halide::Internal::Call;
CHECK_EQ(ndim(), indices.size())
<< "Tensor dimension mismatch in read"
<< "ndim = " << ndim() << ", indices.size=" << indices.size();
auto n = Call::make(
(*this)->dtype, (*this)->name, indices, Call::Halide, *this);
(*this)->dtype, (*this)->op->name, indices, Call::Halide,
(*this)->op, (*this)->value_index);
return n;
}
Tensor TensorNode::make(Array<Expr> shape,
std::string name,
Type dtype,
Operation op,
int value_index) {
auto n = std::make_shared<TensorNode>();
n->shape = shape;
n->name = name;
n->dtype = dtype;
n->op = op;
n->value_index = value_index;
......@@ -44,7 +35,7 @@ Tensor TensorNode::make(Array<Expr> shape,
TVM_STATIC_IR_FUNCTOR(IRPrinter, vtable)
.set_dispatch<TensorNode>([](const TensorNode *t, IRPrinter *p) {
p->stream << "Tensor(shape=" << t->shape
<< ", name=" << t->name << ')';
<< ", op.name=" << t->op->name << ')';
});
TVM_REGISTER_NODE_TYPE(TensorNode);
......
......@@ -22,6 +22,7 @@ class IRInline : public IRMutator {
expr = IRMutator::Mutate(expr);
const Call* call = expr.as<Call>();
if (call != nullptr && call->func == f_) {
CHECK_EQ(call->value_index, 0);
return InlineCall(call);
} else {
return expr;
......@@ -55,6 +56,8 @@ Stmt Inline(FunctionRef f,
Array<Var> args,
Expr body,
Stmt stmt) {
CHECK_EQ(f->num_outputs(), 1)
<< "can only inline output single value operation";
return ConvertSSA(IRInline(f, args, body).Mutate(stmt));
}
} // namespace ir
......
......@@ -254,11 +254,11 @@ TVM_STATIC_IR_FUNCTOR(IRMutator, vtable_stmt)
})
.set_dispatch<Provide>([](const Provide *op, const Stmt& s, IRMutator* m) {
auto new_args = MutateArray(op->args, m);
auto new_values = MutateArray(op->values, m);
if (op->args.same_as(new_args) && op->values.same_as(new_values)) {
auto new_value = m->Mutate(op->value);
if (op->args.same_as(new_args) && op->value.same_as(new_value)) {
return s;
} else {
return Provide::make(op->func, new_values, new_args);
return Provide::make(op->func, op->value_index, new_value, new_args);
}
})
.set_dispatch<Allocate>([](const Allocate *op, const Stmt& s, IRMutator* m) {
......@@ -312,7 +312,8 @@ TVM_STATIC_IR_FUNCTOR(IRMutator, vtable_stmt)
condition.same_as(op->condition)) {
return s;
} else {
return Realize::make(op->func, op->types, new_bounds,
return Realize::make(op->func, op->value_index,
op->type, new_bounds,
condition, body);
}
})
......@@ -329,7 +330,10 @@ TVM_STATIC_IR_FUNCTOR(IRMutator, vtable_stmt)
.set_dispatch<IfThenElse>([](const IfThenElse *op, const Stmt& s, IRMutator* m) {
Expr condition = m->Mutate(op->condition);
Stmt then_case = m->Mutate(op->then_case);
Stmt else_case = m->Mutate(op->else_case);
Stmt else_case;
if (op->else_case.defined()) {
else_case = m->Mutate(op->else_case);
}
if (condition.same_as(op->condition) &&
then_case.same_as(op->then_case) &&
else_case.same_as(op->else_case)) {
......
......@@ -157,7 +157,7 @@ TVM_STATIC_IR_FUNCTOR(IRVisitor, vtable)
})
.set_dispatch<Provide>([](const Provide *op, IRVisitor* v) {
VisitArray(op->args, v);
VisitArray(op->values, v);
v->Visit(op->value);
})
.set_dispatch<Allocate>([](const Allocate *op, IRVisitor* v) {
for (size_t i = 0; i < op->extents.size(); i++) {
......
......@@ -36,7 +36,7 @@ class Scope {
*/
inline void Pop(const K& key) {
auto& v = data_[key];
CHECK_NE(v.size(), 0);
CHECK_NE(v.size(), 0U);
v.pop_back();
}
......
......@@ -5,8 +5,8 @@
*/
#include <tvm/ir.h>
#include <tvm/ir_visitor.h>
#include <tvm/schedule_pass.h>
#include "./int_set.h"
#include "./bound.h"
#include "./graph.h"
namespace tvm {
......@@ -113,7 +113,7 @@ void PassToOperation(
(*result)[root_iter_vars[i]].push_back(dim_bounds[i]);
}
} else {
LOG(FATAL) << "unknown operation mode";
LOG(FATAL) << "unknown operation mode " << tensor->op->type_key();
}
}
......@@ -140,8 +140,8 @@ BoundProp(const Array<Operation>& post_order,
auto fvisit = [p_state, &result](const NodeRef& n) {
auto *call = n.as<ir::Call>();
if (call != nullptr && call->func.defined()) {
Tensor t(call->func.node_);
if (t->op.defined()) {
Tensor t = Operation(call->func.node_).output(call->value_index);
if (t->op.defined() && !t->op.as<PlaceholderOpNode>()) {
std::vector<IntSet> arg_bounds;
for (size_t i = 0; i < t.ndim(); ++i) {
arg_bounds.push_back(EvalSet(call->args[i], result));
......
......@@ -27,18 +27,20 @@ ReadGraph CreateReadGraph(const Operation& root) {
auto fvisit = [&deps, &visited, &stack](const NodeRef& n) {
auto *call = n.as<ir::Call>();
if (call != nullptr && call->func.defined()) {
Tensor t(call->func.node_);
deps.push_back(t);
if (t->op.defined() && visited.count(t->op.get()) == 0) {
visited.insert(t->op.get());
stack.push_back(t->op);
Operation call_op(call->func.node_);
deps.push_back(call_op.output(call->value_index));
if (call_op.defined() && visited.count(call_op.get()) == 0) {
visited.insert(call_op.get());
stack.push_back(call_op);
}
}
};
ir::PostOrderVisit(op.as<ComputeOpNode>()->body, fvisit);
rmap.Set(op, deps);
} else {
LOG(FATAL) << "unknown operation mode";
if (!op.as<PlaceholderOpNode>()) {
LOG(FATAL) << "unknown Operation" << op->type_key();
}
}
}
return rmap;
......@@ -51,7 +53,7 @@ void PostDFSOrder(const Operation& op,
Array<Operation>* post_order) {
visited->insert(op);
for (const auto& t : g.at(op)) {
if (t->op.defined() && !visited->count(t->op)) {
if (!t->op.as<PlaceholderOpNode>() && !visited->count(t->op)) {
PostDFSOrder(t->op, g, visited, post_order);
}
}
......
......@@ -220,7 +220,7 @@ void PassUp(const SplitNode* s,
*parent = IntSet::make_range(dom_map.at(s->parent));
return;
}
Expr factor = dom_map.at(s->outer)->extent;
Expr factor = dom_map.at(s->inner)->extent;
CHECK(outer.defined());
CHECK(inner.defined());
CHECK(factor.defined());
......@@ -261,7 +261,7 @@ void PassUp(const FuseNode* s,
if (IsNumber(fused)) {
Expr value = AsNumber(fused);
Expr factor = dom_map.at(s->outer)->extent;
Expr factor = dom_map.at(s->inner)->extent;
*outer = IntSet::make_point(value / factor);
*inner = IntSet::make_point(value % factor);
} else {
......
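// Note on the PassUp fixes above: the factor that converts between (outer, inner)
// and the combined index is the extent of the *inner* domain:
//   split: parent = outer * factor + inner
//   fuse : outer  = fused / factor,  inner = fused % factor
// The previous code used the outer extent, which is only correct when the two
// extents happen to coincide.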
......@@ -5,8 +5,9 @@
TEST(Tensor, Basic) {
using namespace tvm;
Var m("m"), n("n"), l("l");
Tensor A({m, l}, "A");
Tensor B({n, l}, "B");
Tensor A = Placeholder({m, l}, Float(32), "A");
Tensor B = Placeholder({n, l}, Float(32), "B");
auto C = Compute({m, n}, [&](Var i, Var j) {
return A[i][j];
......@@ -19,8 +20,8 @@ TEST(Tensor, Basic) {
TEST(Tensor, Reduce) {
using namespace tvm;
Var m("m"), n("n"), l("l");
Tensor A({m, l}, "A");
Tensor B({n, l}, "B");
Tensor A = Placeholder({m, l}, Float(32), "A");
Tensor B = Placeholder({n, l}, Float(32), "B");
IterVar rv(Range{0, l}, "k");
auto C = Compute({m, n}, [&](Var i, Var j) {
......
......@@ -10,7 +10,7 @@ def test_tensor():
print(T)
print(T.op.body)
assert(tuple(T.shape) == (m, n, l))
assert(A.op is None)
assert(isinstance(A.op, tvm.tensor.PlaceholderOp))
assert(A == A)
assert(T.op.output(0) == T)
assert(T.op.output(0).__hash__() == T.__hash__())
......
......@@ -6,7 +6,7 @@ def test_inline():
T = tvm.compute((m,), lambda i,: A[i] + 10, name='T')
stmt = tvm.make.Evaluate(T[10] + 11 * T[100])
stmt = tvm.ir_pass.Inline(
T, [x.var for x in T.op.axis], T.op.body, stmt)
T.op, [x.var for x in T.op.axis], T.op.body, stmt)
print(stmt)
assert(tvm.ir_pass.VerifySSA(stmt))
......@@ -14,7 +14,7 @@ def test_inline():
# pass in int array(wrong argument type)
# must raise an error
stmt = tvm.ir_pass.Inline(
T, [1,2,3], T.op.body, stmt)
T.op, [1,2,3], T.op.body, stmt)
assert False
except tvm.TVMError:
pass
......
import tvm
def test_schedule0():
m = tvm.Var('m')
l = tvm.Var('l')
A = tvm.placeholder((m, l), name='A')
A1 = tvm.compute((m, l), lambda i, j: A[i, j], name='A1')
sA1 = tvm.Schedule(A1.op)
bounds = tvm.schedule.InferBound(sA1)
assert isinstance(bounds, tvm.collections.Map)
stmt = tvm.ir_pass.ScheduleOps(sA1, bounds)
print(stmt)
def test_schedule1():
m = tvm.Var('m')
l = tvm.Var('l')
A = tvm.placeholder((m, l), name='A')
A1 = tvm.compute((m, l), lambda i, j: A[i, j], name='A1')
sA1 = tvm.Schedule(A1.op)
xo, xi = sA1.split(A1.op.axis[0], 8)
bounds = tvm.schedule.InferBound(sA1)
assert isinstance(bounds, tvm.collections.Map)
stmt = tvm.ir_pass.ScheduleOps(sA1, bounds)
print(stmt)
def test_schedule2():
m = tvm.Var('m')
l = tvm.Var('l')
A = tvm.placeholder((m, l), name='A')
A1 = tvm.compute((m, l), lambda i, j: A[i, j], name='A1')
A2 = tvm.compute((m, l), lambda i, j: A1[i, j] + 3, name='A2')
sA1 = tvm.Schedule(A1.op)
sA2 = tvm.Schedule(A2.op)
xo, xi = sA2.split(A2.op.axis[0], 8)
sA1.compute_at(sA2, xo)
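# compute A1 at A2's outer axis xo: A1's loops are nested inside the xo loop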
bounds = tvm.schedule.InferBound(sA2)
assert isinstance(bounds, tvm.collections.Map)
stmt = tvm.ir_pass.ScheduleOps(sA2, bounds)
print(stmt)
if __name__ == "__main__":
test_schedule0()
test_schedule1()
test_schedule2()
......@@ -65,7 +65,7 @@ def test_create_read_graph():
if __name__ == "__main__":
test_create_read_graph()
test_bound3()
test_bound1()
test_bound2()
test_create_read_graph()