Commit 820a8597 by Tianqi Chen Committed by GitHub

[LANG] Introduce Scan, Bugfix Canonical (#43)

parent f8f02829
......@@ -49,12 +49,27 @@ struct Reduce : public ExprNode<Reduce> {
static constexpr const char* Min = "Min";
};
/*! \brief namespace of possible attribute sin AttrStmt.type_key */
namespace attr {
/*!
* \brief Mark scope of iteration variable, used by Schedule.
* \brief Auxiliary data structure used in IR Pass to indicate a tensor.
*/
constexpr const char* scope = "scope";
struct TensorKey {
FunctionRef f;
int value_index;
inline bool operator==(const TensorKey& other) const {
return f == other.f && value_index == other.value_index;
}
inline std::string GetName() const {
if (f->num_outputs() == 1) return f->func_name();
std::ostringstream os;
os << f->func_name() << ".v" << value_index;
return os.str();
}
};
/*! \brief namespace of possible attribute sin AttrStmt.type_key */
namespace attr {
// The above attr does not pass to ir stage.
/*!
* \brief Mark launching extent of thread, used by device API.
*/
......@@ -189,4 +204,16 @@ using Halide::Internal::Evaluate;
} // namespace ir
} // namespace tvm
namespace std {
template <>
struct hash<::tvm::ir::TensorKey> {
std::size_t operator()(const ::tvm::ir::TensorKey& k) const {
size_t lhs = k.f.hash();
size_t rhs = static_cast<size_t>(k.value_index);
lhs ^= rhs + 0x9e3779b9 + (lhs << 6) + (lhs >> 2);
return lhs;
}
};
} // namespace std
#endif // TVM_IR_H_
......@@ -77,6 +77,55 @@ class ComputeOpNode : public OperationNode {
TVM_DECLARE_NODE_TYPE_INFO(ComputeOpNode);
};
/*!
* \brief Symbolic scan.
*/
class ScanOpNode : public OperationNode {
public:
/*! \brief IterVar to scan over */
IterVar scan_axis;
/*! \brief the initialization tensors */
Array<Tensor> init;
/*! \brief the update function represented by tensor */
Array<Tensor> update;
/*! \brief The placeholder to refer as states in update. */
Array<Tensor> state_placeholder;
/*!
* \brief Spatial axis to indicate spatial dimension of each output.
* They corresponds to flattened spatial axis of the outputs.
*
* [output[0].axis[1], output[0].axis[2]... output[k].axis[j]...]
* These are auxiliary data structure for storing result of bound inference.
* They do not corresponds to splittable iterations, thus the name comes
* with underscore.
*/
Array<IterVar> spatial_axis_;
/*! \brief constructor */
ScanOpNode() {}
// override behavior.
int num_outputs() const final;
Array<IterVar> root_iter_vars() const final;
Type output_dtype(size_t i) const final;
Array<Expr> output_shape(size_t i) const final;
void VisitAttrs(AttrVisitor* v) final {
v->Visit("name", &name);
v->Visit("scan_axis", &scan_axis);
v->Visit("init", &init);
v->Visit("update", &update);
v->Visit("state_placeholder", &state_placeholder);
v->Visit("spatial_axis_", &spatial_axis_);
}
static Operation make(std::string name,
IterVar axis,
Array<Tensor> init,
Array<Tensor> update,
Array<Tensor> state_placeholder);
static constexpr const char* _type_key = "ScanOp";
TVM_DECLARE_NODE_TYPE_INFO(ScanOpNode);
};
/*! \brief The compute function to specify the input source of a Tensor */
using FCompute = std::function<Expr (const Array<Var>& i)>;
......@@ -100,6 +149,21 @@ Tensor Placeholder(Array<Expr> shape,
*/
Tensor Compute(Array<Expr> shape, FCompute fcompute, std::string name = "tensor");
/*!
* \brief Construct new tensors by scan over scan_axis.
*
* \param scan_axis The iteration representing the scan.
* \param init The intialize tensor of first K steps.
* \param update The update tensor indicated the updated result after each timestamp.
* \param state_placeholder The placeholder for the states.
* \param name The optional name of the tensor.
*/
Array<Tensor> Scan(IterVar scan_axis,
Array<Tensor> init,
Array<Tensor> update,
Array<Tensor> state_placeholder,
std::string name = "scan");
// same as compute, specialized for different fcompute function
inline Tensor Compute(Array<Expr> shape,
std::function<Expr(Var)> f,
......
......@@ -14,6 +14,7 @@ from ._ctypes._function import convert_to_tvm_func as _convert_tvm_func
from . import _api_internal
from . import make as _make
from . import expr as _expr
from . import tensor as _tensor
from . import collections as _collections
int32 = "int32"
......@@ -111,7 +112,6 @@ def compute(shape, fcompute, name="compute"):
shape: Tuple of Expr
The shape of the tensor
fcompute: lambda function of *indices-> value
Specifies the input source expression
......@@ -137,8 +137,57 @@ def compute(shape, fcompute, name="compute"):
body = convert(body)
op_node = _api_internal._ComputeOp(
name, dim_var, body)
return _api_internal._Tensor(
shape, body.dtype, op_node, 0)
return op_node.output(0)
def scan(axis, init, update, state_placeholder, name="scan"):
"""Construct new tensors by scanning over axis.
Parameters
----------
axis: IterVar
The scanning axis.
init: Tensor or list of Tensor
The initial condition of first init.shape[0] timestamps
update: Tensor or list of Tensor
The update rule of the scan given by symbolic tensor.
state_placeholder: Tensor or list of Tensor
The placeholder variables used by update.
name: str, optional
The name hint of the tensor
Returns
-------
tensor: tensor.Tensor
The created tensor
Example
-------
# The following code is equivalent to numpy.cumsum
m = tvm.Var("m")
n = tvm.Var("n")
t = tvm.IterVar((1, m), name="t")
X = tvm.placeholder((m, n), name="X")
s_state = tvm.placeholder((m, n))
s_init = tvm.compute((1, n), lambda _, i: X[0, i])
s_update = tvm.compute((n,), lambda i: s_state[t-1, i] + X[t, i])
res = tvm.scan(t, s_init, s_update, s_state)
"""
if isinstance(init, _tensor.Tensor):
init = [init]
if isinstance(update, _tensor.Tensor):
update = [update]
if isinstance(state_placeholder, _tensor.Tensor):
state_placeholder = [state_placeholder]
if len(init) != len(update) or len(init) != len(state_placeholder):
raise ValueError("init, update, state_placeholder must have same length")
op = _api_internal._ScanOp(name, axis, init, update, state_placeholder)
res = [op.output(i) for i in range(len(update))]
return (res[0] if len(res) == 1 else res)
def Buffer(shape, dtype=None,
......
......@@ -75,11 +75,16 @@ class Operation(NodeBase):
return _api_internal._OpGetOutput(self, index)
@register_node
class PlaceholderOp(Operation):
"""Placeholder operation."""
pass
@register_node
class ComputeOp(Operation):
"""Compute operation."""
pass
@register_node
class PlaceholderOp(Operation):
"""Placeholder operation."""
class ScanOp(Operation):
"""Scan operation."""
pass
......@@ -173,6 +173,15 @@ TVM_REGISTER_API(_ComputeOp)
args[2]);
});
TVM_REGISTER_API(_ScanOp)
.set_body([](TVMArgs args, TVMRetValue* ret) {
*ret = ScanOpNode::make(args[0],
args[1],
args[2],
args[3],
args[4]);
});
TVM_REGISTER_API(_OpGetOutput)
.set_body([](TVMArgs args, TVMRetValue* ret) {
*ret = args[0].operator Operation().output(
......
......@@ -365,7 +365,7 @@ class Canonical::Internal : public IRMutator {
const ComExpr& sumb,
int bscale) {
std::shared_ptr<ComExprNode> n = std::make_shared<ComExprNode>();
n->base = suma->base + sumb->base;
n->base = suma->base + sumb->base * bscale;
// merge of suma and sumb;
size_t i = 0, j = 0;
while (i < suma->elem.size() && j < sumb->elem.size()) {
......@@ -417,7 +417,7 @@ class Canonical::Internal : public IRMutator {
// convert sum to expr
Expr Sum2Expr(const ComExpr& com, Type t) {
Expr vsum;
if (com->base != 0) {
if (com->base > 0) {
vsum = make_const(t, com->base);
}
for (const ComExprEntry& e : com->elem) {
......@@ -433,6 +433,13 @@ class Canonical::Internal : public IRMutator {
}
}
}
if (com->base < 0) {
if (vsum.defined()) {
vsum = Sub::make(vsum, make_const(t, -com->base));
} else {
vsum = make_const(t, com->base);
}
}
for (const ComExprEntry& e : com->elem) {
if (e.scale < 0) {
Expr v = e.value;
......
......@@ -168,7 +168,7 @@ MakeNVRTC(Array<LoweredFunc> funcs) {
const auto& f = PackedFunc::GetGlobal("tvm_callback_cuda_postproc");
code = f(code).operator std::string();
}
LOG(INFO) << code;
std::string ptx;
if (PackedFunc::GlobalExist("tvm_callback_cuda_compile")) {
const auto& f = PackedFunc::GetGlobal("tvm_callback_cuda_compile");
......
......@@ -42,7 +42,7 @@ class CodeGenCUDA : public CodeGenC {
private:
// magic number to add pragma unroll to it.
// used to generate code that is compact but still unrolls.
int max_auto_unroll_{8};
int max_auto_unroll_{1025};
};
} // namespace codegen
......
......@@ -5,6 +5,7 @@
#include <tvm/operation.h>
#include <tvm/tensor.h>
#include <tvm/ir.h>
#include <tvm/ir_pass.h>
#include <memory>
namespace tvm {
......@@ -120,4 +121,90 @@ TVM_STATIC_IR_FUNCTOR(IRPrinter, vtable)
TVM_REGISTER_NODE_TYPE(ComputeOpNode);
// Scan
inline bool prove_equal(Expr lhs, Expr rhs) {
return is_zero(ir::Simplify(lhs - rhs));
}
int ScanOpNode::num_outputs() const {
return update.size();
}
Array<IterVar> ScanOpNode::root_iter_vars() const {
return Array<IterVar>{scan_axis};
}
Type ScanOpNode::output_dtype(size_t i) const {
return update[i]->dtype;
}
Array<Expr> ScanOpNode::output_shape(size_t i) const {
CHECK_LT(i, state_placeholder.size());
return state_placeholder[i]->shape;
}
Operation ScanOpNode::make(std::string name,
IterVar axis,
Array<Tensor> init,
Array<Tensor> update,
Array<Tensor> state_placeholder) {
auto n = std::make_shared<ScanOpNode>();
CHECK_EQ(init.size(), update.size());
CHECK_EQ(init.size(), state_placeholder.size());
for (size_t i = 0; i < init.size(); ++i) {
CHECK_EQ(init[i]->dtype, state_placeholder[i]->dtype);
CHECK_EQ(init[i]->dtype, update[i]->dtype);
CHECK(can_prove(init[i]->shape[0] == axis->dom->min))
<< "init.shape[0] need to match scan_axis.dom.min";
CHECK(prove_equal(
state_placeholder[i]->shape[0], axis->dom->min + axis->dom->extent))
<< "shate_placeholder.shape[0] need to match"
<< " scan_axis.dom.min + scan_axis.dom.extent";
CHECK_EQ(state_placeholder[i].ndim(), init[i].ndim())
<< "The dimension of init need to match state_placeholder";
CHECK_EQ(update[i].ndim() + 1, state_placeholder[i].ndim())
<< "The update.ndim need to be state_placeholder.ndim - 1";
for (size_t k = 0; k < update[i].ndim(); ++k) {
CHECK(prove_equal(
update[i]->shape[k], state_placeholder[i]->shape[k + 1]));
// setup spatial axis
std::ostringstream spatial_name;
spatial_name << name << ".out" << i << ".i" << k + 1;
n->spatial_axis_.push_back(
IterVar(Range::make_with_min_extent(0, update[i]->shape[k]),
spatial_name.str()));
}
for (size_t k = 1; k < init[i].ndim(); ++k) {
CHECK(prove_equal(
init[i]->shape[k], state_placeholder[i]->shape[k]));
}
}
n->name = name;
n->scan_axis = axis;
n->init = init;
n->update = update;
n->state_placeholder = state_placeholder;
return Operation(n);
}
Array<Tensor> Scan(IterVar scan_axis,
Array<Tensor> init,
Array<Tensor> update,
Array<Tensor> state_placeholder,
std::string name) {
Operation op = ScanOpNode::make(
name, scan_axis, init, update, state_placeholder);
Array<Tensor> res;
for (int i = 0; i < op->num_outputs(); ++i) {
res.push_back(op.output(i));
}
return res;
}
TVM_STATIC_IR_FUNCTOR(IRPrinter, vtable)
.set_dispatch<ScanOpNode>([](const ScanOpNode *op, IRPrinter *p) {
p->stream << "scan(" << op->name << ", " << op << ")";
});
} // namespace tvm
......@@ -191,20 +191,16 @@ class VTInjector : public IRMutator {
}
// Attribute
Stmt Mutate_(const AttrStmt* op, const Stmt& s) final {
if (op->type_key == attr::scope) {
return Mutate(op->body);
Expr value = Mutate(op->value);
if (visit_touched_var_) {
return InjectVTLoop(s, true);
} else {
Expr value = Mutate(op->value);
if (visit_touched_var_) {
return InjectVTLoop(s, true);
Stmt body = Mutate(op->body);
if (value.same_as(op->value) &&
body.same_as(op->body)) {
return s;
} else {
Stmt body = Mutate(op->body);
if (value.same_as(op->value) &&
body.same_as(op->body)) {
return s;
} else {
return AttrStmt::make(op->node, op->type_key, value, body);
}
return AttrStmt::make(op->node, op->type_key, value, body);
}
}
}
......
......@@ -11,40 +11,6 @@
namespace tvm {
namespace ir {
// key of function buffer
struct TensorKey {
FunctionRef f;
int value_index;
inline bool operator==(const TensorKey& other) const {
return f == other.f && value_index == other.value_index;
}
inline std::string GetName() const {
if (f->num_outputs() == 1) return f->func_name();
std::ostringstream os;
os << f->func_name() << ".v" << value_index;
return os.str();
}
};
} // namespace ir
} // namespace tvm
namespace std {
template <>
struct hash<::tvm::ir::TensorKey> {
std::size_t operator()(const ::tvm::ir::TensorKey& k) const {
size_t lhs = k.f.hash();
size_t rhs = static_cast<size_t>(k.value_index);
lhs ^= rhs + 0x9e3779b9 + (lhs << 6) + (lhs >> 2);
return lhs;
}
};
} // namespace std
namespace tvm {
namespace ir {
using Halide::Internal::Region;
using runtime::StorageScope;
using runtime::ThreadScope;
......
......@@ -23,6 +23,10 @@ inline Expr DivCeil(Expr a, Expr b) {
return ir::Simplify((a + b - 1) / b);
}
inline bool prove_equal(Expr lhs, Expr rhs) {
return is_zero(ir::Simplify(lhs - rhs));
}
// Downward message passing algorithm on stage schedule s,
// pass the range state down from the root to the leaves
// after this pass, every IterVar in the stage hyper graph will have a range(domain)
......@@ -41,9 +45,18 @@ void PassDown(const Stage& s,
if (r->outer->dom.defined()) {
state[r->outer] = r->outer->dom;
} else {
CHECK(!state.count(r->outer));
state[r->outer] = Range::make_with_min_extent(
0, DivCeil(range_parent->extent, r->factor));
if (!state.count(r->outer)) {
state[r->outer] = Range::make_with_min_extent(
0, DivCeil(range_parent->extent, r->factor));
} else {
Expr outer_ext = DivCeil(range_parent->extent, r->factor);
Range outer_rng = state.at(r->outer);
bool match = is_zero(outer_rng->min);
if (!prove_equal(outer_ext, outer_rng->extent)) match = false;
CHECK(match)
<< "IterVar is used in two places as outer scope,"
<< " cannot prove their extents are the same";
}
}
} else {
CHECK(r->outer->dom.defined());
......@@ -181,6 +194,21 @@ void PassUp(const Stage& s,
}
}
// All the itervars that are needed to output bound of op.
// For most op, it is root_iter_vars
// For Scan, it also contains the additional spatial axis.
Array<IterVar> OutputRelatedIterVars(const Operation& op) {
if (op.as<ScanOpNode>()) {
const ScanOpNode* scan = op.as<ScanOpNode>();
Array<IterVar> ret{scan->scan_axis};
for (IterVar iv : scan->spatial_axis_) {
ret.push_back(iv);
}
return ret;
} else {
return op->root_iter_vars();
}
}
/*! \brief temporary data structure to store Tensor domain */
struct TensorDom {
......@@ -214,6 +242,34 @@ void BoundProp(const Operation& op,
}
};
ir::PostOrderVisit(op.as<ComputeOpNode>()->body, fvisit);
} else if (op.as<ScanOpNode>()) {
const ScanOpNode* scan = op.as<ScanOpNode>();
size_t sp_idx = 0;
for (size_t i = 0; i < scan->init.size(); ++i) {
TensorDom* init_dom = nullptr;
TensorDom* update_dom = nullptr;
if (out->count(scan->init[i])) {
init_dom = &out->at(scan->init[i]);
}
if (out->count(scan->update[i])) {
update_dom = &out->at(scan->update[i]);
}
// first dimension, always needed.
if (init_dom) {
init_dom->data[0].push_back(IntSet::range(
Range::make_with_min_extent(0, scan->init[i]->shape[0])));
}
// The update dimensions
for (size_t k = 0; k < scan->update[i]->shape.size(); ++k, ++sp_idx) {
IterVar sp_ax = scan->spatial_axis_[sp_idx];
if (init_dom) {
init_dom->data[k + 1].push_back(dom_map.at(sp_ax->var.get()));
}
if (update_dom) {
update_dom->data[k].push_back(dom_map.at(sp_ax->var.get()));
}
}
}
} else if (op.as<PlaceholderOpNode>()) {
// do nothing
} else {
......@@ -221,14 +277,49 @@ void BoundProp(const Operation& op,
}
}
void InferOpBound(const Operation& op,
const std::unordered_map<Tensor, TensorDom>& tmap,
std::unordered_map<IterVar, Range>* rmap) {
// Given the bound of output of op
// Pass the bound to the related axis in op.
void GatherOpBound(const ScanOpNode* scan,
const Operation& op,
const std::unordered_map<Tensor, TensorDom>& tmap,
std::unordered_map<IterVar, Range>* rmap) {
CHECK(!rmap->count(scan->scan_axis));
std::vector<Tensor> output(op->num_outputs());
for (size_t i = 0; i < output.size(); ++i) {
output[i] = op.output(i);
}
// Update for time axis.
std::vector<IntSet> time_dom;
for (size_t i = 0; i < output.size(); ++i) {
const TensorDom& d = tmap.at(output[i]);
time_dom.insert(time_dom.end(), d.data[0].begin(), d.data[0].end());
}
LOG(INFO) << time_dom.size();
CHECK(!rmap->count(scan->scan_axis));
Range sdom = scan->scan_axis->dom;
Range r = arith::Union(time_dom).cover_range(sdom);
(*rmap)[scan->scan_axis] = Range::make_with_min_extent(
sdom->min, ir::Simplify(r->extent + r->min - sdom->min));
// Update for spatial axis.
size_t sp_idx = 0;
for (size_t i = 0; i < output.size(); ++i) {
for (size_t k = 0; k < scan->update[i]->shape.size(); ++k, ++sp_idx) {
IterVar sp_ax = scan->spatial_axis_[sp_idx];
CHECK(!rmap->count(sp_ax));
// In default, we always need all spatial axis
// Unless that axis only refers back to itself as a fixed point.
// TODO(tqchen): Add fix point detection.
(*rmap)[sp_ax] = sp_ax->dom;
}
}
}
void GatherOpBound(const Operation& op,
const std::unordered_map<Tensor, TensorDom>& tmap,
std::unordered_map<IterVar, Range>* rmap) {
if (op.as<ComputeOpNode>()) {
auto root_iter_vars = op->root_iter_vars();
const ComputeOpNode* compute = op.as<ComputeOpNode>();
const TensorDom& tdom = tmap.at(op.output(0));
for (size_t i = 0; i < compute->axis.size(); ++i) {
Range r = arith::Union(tdom.data[i]).cover_range(compute->axis[i]->dom);
CHECK(!rmap->count(compute->axis[i]));
......@@ -238,6 +329,8 @@ void InferOpBound(const Operation& op,
CHECK(!rmap->count(compute->reduce_axis[i]));
(*rmap)[compute->reduce_axis[i]] = compute->reduce_axis[i]->dom;
}
} else if (op.as<ScanOpNode>()) {
GatherOpBound(op.as<ScanOpNode>(), op, tmap, rmap);
} else if (op.as<PlaceholderOpNode>()) {
// dp nothing
} else {
......@@ -269,8 +362,7 @@ void InferRootBound(const Stage& stage,
std::unordered_map<IterVar, Range>* rmap) {
if (stage->attach_type == kInline) return;
if (stage->attach_type == kRoot || stage->attach_type == kNone) {
auto root_iter_vars = stage->op->root_iter_vars();
for (auto iv : root_iter_vars) {
for (auto iv : OutputRelatedIterVars(stage->op)) {
CHECK(iv->dom.defined());
CHECK(!rmap->count(iv));
(*rmap)[iv] = iv->dom;
......@@ -338,8 +430,13 @@ void InferRootBound(const Stage& stage,
PassUp(parent, *rmap, &up_state);
std::unordered_map<const Variable*, IntSet> dom_map;
for (auto iv : parent->op->root_iter_vars()) {
Range r = up_state.at(iv).cover_range(iv->dom);
for (auto iv : OutputRelatedIterVars(parent->op)) {
Range r;
if (up_state.count(iv)) {
r = up_state.at(iv).cover_range(iv->dom);
} else {
r = iv->dom;
}
if (relax_set.size() != 0) {
dom_map[iv->var.get()] = EvalSet(r, relax_set);
} else {
......@@ -379,13 +476,13 @@ void InferRootBound(const Stage& stage,
CHECK(found)
<< "Invalid Schedule, cannot find the producer " << stage->op
<< " along the loop nest specified by compute_at of consumer " << op;
for (auto iv : op->root_iter_vars()) {
for (auto iv : OutputRelatedIterVars(op)) {
Range r = rmap->at(iv);
dom_map[iv->var.get()] = EvalSet(r, relax_set);
}
BoundProp(op, dom_map, &tmap);
}
InferOpBound(stage->op, tmap, rmap);
GatherOpBound(stage->op, tmap, rmap);
}
FeedGraph CreateFeedGraph(const Schedule& sch) {
......
......@@ -33,20 +33,28 @@ ReadGraph CreateReadGraph(const Array<Operation>& roots) {
if (call != nullptr && call->func.defined()) {
Operation call_op(call->func.node_);
deps.push_back(call_op.output(call->value_index));
if (call_op.defined() && visited.count(call_op.get()) == 0) {
visited.insert(call_op.get());
stack.push_back(call_op);
}
}
};
ir::PostOrderVisit(op.as<ComputeOpNode>()->body, fvisit);
rmap.Set(op, deps);
} else if (op.as<ScanOpNode>()) {
const ScanOpNode* scan = op.as<ScanOpNode>();
for (Tensor t : scan->init) {
deps.push_back(t);
}
for (Tensor t : scan->update) {
deps.push_back(t);
}
} else if (op.as<PlaceholderOpNode>()) {
// empty set of deps
rmap.Set(op, deps);
} else {
LOG(FATAL) << "unknown Operation" << op->type_key();
}
rmap.Set(op, deps);
for (Tensor t : deps) {
if (t->op.defined() && visited.count(t->op.get()) == 0) {
visited.insert(t->op.get());
stack.push_back(t->op);
}
}
}
return rmap;
}
......
......@@ -146,6 +146,8 @@ Stage& Stage::fuse(IterVar inner, IterVar outer, IterVar* p_target) { // NOLINT
Stage& Stage::reorder(const Array<IterVar>& order) { // NOLINT(*)
StageNode* self = operator->();
CHECK(!self->op.as<ScanOpNode>())
<< "Cannot reorder axis of scan";
ArrayNode* all_vars = self->all_iter_vars.CopyOnWrite();
ArrayNode* leaf_vars = self->leaf_iter_vars.CopyOnWrite();
std::vector<size_t> pos;
......
import tvm
import numpy as np
def test_scan():
m = tvm.Var("m")
n = tvm.Var("n")
t = tvm.IterVar((1, m), name="t")
X = tvm.placeholder((m, n), name="X")
s_state = tvm.placeholder((m, n))
s_init = tvm.compute((1, n), lambda _, i: X[0, i])
s_update = tvm.compute((n,), lambda i: s_state[t-1, i] + X[t, i])
res = tvm.scan(t, s_init, s_update, s_state)
# schedule
s = tvm.Schedule(res.op)
num_thread = 256
block_x = tvm.IterVar(thread_tag="blockIdx.x")
thread_x = tvm.IterVar((0, num_thread), thread_tag="threadIdx.x")
_, x = s[s_init].split(s_init.op.axis[1], factor=num_thread, outer=block_x)
_, x = s[s_init].split(x, outer=thread_x)
_, x = s[s_update].split(s_update.op.axis[0], factor=num_thread, outer=block_x)
_, x = s[s_update].split(x, outer=thread_x)
# one line to build the function.
def check_device(target):
codes = []
fscan = tvm.build(s, [X, res],
target, record_codes=codes,
name="myscan")
if target == "cuda":
ctx = tvm.gpu(0)
else:
ctx = tvm.cl(0)
if not ctx.enabled:
return
for c in codes[1:]:
print(c)
# launch the kernel.
n = 1024
m = 10
a_np = np.random.uniform(size=(m, n)).astype(res.dtype)
a = tvm.nd.array(a_np, ctx)
b = tvm.nd.array(np.zeros((m, n), dtype=res.dtype), ctx)
fscan(a, b)
np.testing.assert_allclose(
b.asnumpy(), np.cumsum(a_np, axis=0))
tvm.init_opencl()
check_device("cuda")
if __name__ == "__main__":
test_scan()
......@@ -34,6 +34,20 @@ def test_tensor_reduce():
assert(str(C_loaded) == str(C))
def test_tensor_scan():
m = tvm.Var("m")
n = tvm.Var("n")
t = tvm.IterVar((1, m), "t")
x = tvm.placeholder((m, n))
s = tvm.placeholder((m, n))
res = tvm.scan(t,
tvm.compute((1, n), lambda _, i: x[0, i]),
tvm.compute((n,), lambda i: s[t-1, i] + x[t, i]),
s)
assert tuple(res.shape) == (m, n)
if __name__ == "__main__":
test_tensor()
test_tensor_reduce()
test_tensor_scan()
......@@ -18,9 +18,15 @@ def test_simplify():
tvm.make.Load(dtype, Ab.data, i + 4) + 1,
(j + 1) * 4 - 4 * j + i),
None)))
print(stmt)
stmt = tvm.ir_pass.CanonicalSimplify(stmt)
print(stmt)
def test_basic():
m = tvm.Var('m')
ret = tvm.ir_pass.CanonicalSimplify(tvm.make.Evaluate(m-1))
assert str(ret.value) == "(m - 1)"
if __name__ == "__main__":
test_basic()
test_simplify()
......@@ -6,13 +6,11 @@ def test_schedule0():
l = tvm.Var('l')
A = tvm.placeholder((m, l), name='A')
A1 = tvm.compute((m, l), lambda i, j: A[i, j], name='A1')
s = tvm.Schedule(A1.op)
bounds = tvm.schedule.InferBound(s)
assert isinstance(bounds, tvm.collections.Map)
stmt = tvm.schedule.ScheduleOps(s, bounds)
print(stmt)
def test_schedule1():
m = tvm.Var('m')
......@@ -25,7 +23,7 @@ def test_schedule1():
bounds = tvm.schedule.InferBound(s)
assert isinstance(bounds, tvm.collections.Map)
stmt = tvm.schedule.ScheduleOps(s, bounds)
print(stmt)
def test_schedule2():
m = tvm.Var('m')
......@@ -40,25 +38,45 @@ def test_schedule2():
bounds = tvm.schedule.InferBound(s)
assert isinstance(bounds, tvm.collections.Map)
stmt = tvm.schedule.ScheduleOps(s, bounds)
def test_schedule_scan():
m = tvm.Var("m")
n = tvm.Var("n")
l = tvm.Var("l")
t = tvm.IterVar((1, m), name="t")
x = tvm.compute((m, n), lambda i, j: tvm.const(1, "float32"), name="x")
s_state = tvm.placeholder((m, n))
s_init = tvm.compute((1, n), lambda _, i: x[0, i])
s_update = tvm.compute((n,), lambda i: s_state[t-1, i] + x[t, i])
res = tvm.scan(t, s_init, s_update, s_state)
assert tuple(res.shape) == (m, n)
s = tvm.Schedule(res.op)
s.normalize()
bounds = tvm.schedule.InferBound(s)
assert(bounds[res.op.scan_axis].min.value == 1)
stmt = tvm.schedule.ScheduleOps(s, bounds)
print(stmt)
def test_auto_inline():
m = tvm.Var('m')
n = tvm.Var('n')
A = tvm.placeholder((m, n), name='A')
B = tvm.placeholder((m, n), name='B')
C = tvm.placeholder((m, n), name='C')
T1 = tvm.compute((m, n), lambda i, j: A(i, j) * B(i, j), name='T1')
T2 = tvm.compute((m, n), lambda i, j: T1(i, j) + C(i, j), name='T2')
m = tvm.Var('m')
n = tvm.Var('n')
A = tvm.placeholder((m, n), name='A')
B = tvm.placeholder((m, n), name='B')
C = tvm.placeholder((m, n), name='C')
T1 = tvm.compute((m, n), lambda i, j: A(i, j) * B(i, j), name='T1')
T2 = tvm.compute((m, n), lambda i, j: T1(i, j) + C(i, j), name='T2')
s = tvm.Schedule(T2.op)
tvm.schedule.AutoInlineElemWise(s)
bounds = tvm.schedule.InferBound(s)
stmt = tvm.schedule.ScheduleOps(s, bounds)
print(stmt)
s = tvm.Schedule(T2.op)
tvm.schedule.AutoInlineElemWise(s)
bounds = tvm.schedule.InferBound(s)
stmt = tvm.schedule.ScheduleOps(s, bounds)
if __name__ == "__main__":
test_schedule_scan()
test_schedule0()
test_schedule1()
test_schedule2()
......
Markdown is supported
0% or
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment