Commit 2548cedc by Tianqi Chen Committed by GitHub

[OP/LANG] Support Extern Call, more regression tests (#69)

* [OP/LANG] Support Extern Call, more regression tests

* [TEST] Include pylintrc
parent b19e01bf
ROOTDIR = $(CURDIR)
ifndef config ifndef config
ifneq ("$(wildcard ./config.mk)","") ifneq ("$(wildcard ./config.mk)","")
config ?= config.mk config ?= config.mk
...@@ -9,7 +11,7 @@ endif ...@@ -9,7 +11,7 @@ endif
include $(config) include $(config)
# specify tensor path # specify tensor path
.PHONY: clean all test doc .PHONY: clean all test doc pylint cpplint lint
all: lib/libtvm.so lib/libtvm_runtime.so lib/libtvm.a all: lib/libtvm.so lib/libtvm_runtime.so lib/libtvm.a
...@@ -99,8 +101,13 @@ $(LIB_HALIDE_IR): LIBHALIDEIR ...@@ -99,8 +101,13 @@ $(LIB_HALIDE_IR): LIBHALIDEIR
LIBHALIDEIR: LIBHALIDEIR:
+ cd HalideIR; make lib/libHalideIR.a ; cd $(ROOTDIR) + cd HalideIR; make lib/libHalideIR.a ; cd $(ROOTDIR)
lint: cpplint:
python2 dmlc-core/scripts/lint.py tvm all include src python python2 dmlc-core/scripts/lint.py tvm cpp include src
pylint:
pylint python/tvm --rcfile=$(ROOTDIR)/tests/lint/pylintrc
lint: cpplint pylint
doc: doc:
doxygen docs/Doxyfile doxygen docs/Doxyfile
......
...@@ -98,6 +98,8 @@ constexpr const char* loop_scope = "loop_scope"; ...@@ -98,6 +98,8 @@ constexpr const char* loop_scope = "loop_scope";
constexpr const char* scan_update_scope = "scan_update_scope"; constexpr const char* scan_update_scope = "scan_update_scope";
/*! \brief Mark of scan init scope */ /*! \brief Mark of scan init scope */
constexpr const char* scan_init_scope = "scan_init_scope"; constexpr const char* scan_init_scope = "scan_init_scope";
/*! \brief extern operator scope */
constexpr const char* extern_op_scope = "extern_op_scope";
// Pipeline related attributes // Pipeline related attributes
/*! \brief channel read scope */ /*! \brief channel read scope */
constexpr const char* channel_read_scope = "channel_read_scope"; constexpr const char* channel_read_scope = "channel_read_scope";
......
...@@ -13,6 +13,7 @@ ...@@ -13,6 +13,7 @@
#include "./tensor.h" #include "./tensor.h"
#include "./schedule.h" #include "./schedule.h"
#include "./arithmetic.h" #include "./arithmetic.h"
#include "./buffer.h"
namespace tvm { namespace tvm {
...@@ -307,6 +308,62 @@ class ScanOpNode : public OperationNode { ...@@ -307,6 +308,62 @@ class ScanOpNode : public OperationNode {
TVM_DECLARE_NODE_TYPE_INFO(ScanOpNode, OperationNode); TVM_DECLARE_NODE_TYPE_INFO(ScanOpNode, OperationNode);
}; };
/*!
 * \brief External computation that cannot be split.
 *
 *  The computation is described by an opaque statement (body) that reads
 *  from input_placeholders and writes to output_placeholders; the
 *  schedule system treats it as a black box with no iteration variables.
 */
class ExternOpNode : public OperationNode {
 public:
  /*! \brief The input tensors */
  Array<Tensor> inputs;
  /*! \brief Symbolic placeholder representation of inputs */
  Array<Buffer> input_placeholders;
  /*! \brief Symbolic placeholder representation of outputs */
  Array<Buffer> output_placeholders;
  /*! \brief the statement that generates the computation. */
  Stmt body;
  /*! \brief constructor */
  ExternOpNode() {}
  // override functions
  int num_outputs() const final;
  Array<IterVar> root_iter_vars() const final;
  Type output_dtype(size_t i) const final;
  Array<Expr> output_shape(size_t i) const final;
  Array<Tensor> InputTensors() const final;
  Operation ReplaceInputs(
      const Operation& self,
      const std::unordered_map<Tensor, Tensor>& rmap) const final;
  void PropBoundToInputs(
      const Operation& self,
      const std::unordered_map<const Variable*, IntSet>& dom_map,
      std::unordered_map<Tensor, TensorDom>* out_dom_map) const final;
  void GatherBound(
      const Operation& self,
      const GraphContext& graph_ctx,
      const std::unordered_map<Tensor, TensorDom>& tensor_dom,
      std::unordered_map<IterVar, Range>* out_dom_map) const final;
  Stmt BuildRealize(
      const Operation& self,
      const std::unordered_map<IterVar, Range>& realize_map,
      const Stmt& body) const final;
  Stmt BuildProvide(
      const Stage& stage,
      const std::unordered_map<IterVar, Range>& dom_map) const final;
  // NOTE(review): input_placeholders/output_placeholders are not visited
  // here, so they will not appear in reflection/serialization — confirm
  // this is intended.
  void VisitAttrs(AttrVisitor* v) final {
    v->Visit("name", &name);
    v->Visit("inputs", &inputs);
    v->Visit("body", &body);
  }
  /*!
   * \brief Construct an ExternOp.
   * \param name Name hint of the operation.
   * \param inputs The input tensors consumed by body.
   * \param input_placeholders Buffers standing for inputs inside body.
   * \param output_placeholders Buffers standing for outputs inside body.
   * \param body The statement performing the computation.
   */
  static Operation make(std::string name,
                        Array<Tensor> inputs,
                        Array<Buffer> input_placeholders,
                        Array<Buffer> output_placeholders,
                        Stmt body);

  static constexpr const char* _type_key = "ExternOp";
  TVM_DECLARE_NODE_TYPE_INFO(ExternOpNode, OperationNode);
};
/*! \brief The compute function to specify the input source of a Tensor */ /*! \brief The compute function to specify the input source of a Tensor */
using FCompute = std::function<Expr (const Array<Var>& i)>; using FCompute = std::function<Expr (const Array<Var>& i)>;
......
# coding: utf-8 # coding: utf-8
# pylint: disable=invalid-name, no-member # pylint: disable=invalid-name
""" ctypes library of nnvm and helper functions """ """ ctypes library of nnvm and helper functions """
from __future__ import absolute_import from __future__ import absolute_import
......
# pylint: disable=invalid-name
"""Util to compile with C++ code""" """Util to compile with C++ code"""
# pylint: disable=invalid-name
from __future__ import absolute_import as _abs from __future__ import absolute_import as _abs
import sys import sys
import subprocess import subprocess
......
# pylint: disable=invalid-name, too-many-locals # pylint: disable=invalid-name
"""Util to compile with NVCC""" """Util to compile with NVCC"""
from __future__ import absolute_import as _abs from __future__ import absolute_import as _abs
import os import os
......
# pylint: disable=protected-access, no-member, invalid-name
# pylint: disable=redefined-builtin, undefined-variable, unused-import
"""Functions defined in TVM.""" """Functions defined in TVM."""
# pylint: disable=invalid-name,unused-import,redefined-builtin
from __future__ import absolute_import as _abs from __future__ import absolute_import as _abs
from numbers import Integral as _Integral from numbers import Integral as _Integral
...@@ -162,8 +161,8 @@ def scan(init, update, state_placeholder, name="scan"): ...@@ -162,8 +161,8 @@ def scan(init, update, state_placeholder, name="scan"):
Returns Returns
------- -------
tensor: tensor.Tensor tensor: Tensor or list of Tensors
The created tensor The created tensor or tuple of tensors it it contains multiple outputs.
Example Example
------- -------
...@@ -187,7 +186,77 @@ def scan(init, update, state_placeholder, name="scan"): ...@@ -187,7 +186,77 @@ def scan(init, update, state_placeholder, name="scan"):
axis = _IterVar((init[0].shape[0], update[0].shape[0]), "%s.idx" % name, 3) axis = _IterVar((init[0].shape[0], update[0].shape[0]), "%s.idx" % name, 3)
op = _api_internal._ScanOp(name, axis, init, update, state_placeholder) op = _api_internal._ScanOp(name, axis, init, update, state_placeholder)
res = [op.output(i) for i in range(len(update))] res = [op.output(i) for i in range(len(update))]
return (res[0] if len(res) == 1 else res) return res[0] if len(res) == 1 else res
def extern(shape, inputs, fcompute,
           name="extern", dtype=None):
    """Compute several tensors via an extern function.

    Parameters
    ----------
    shape: Shape tuple or list of shapes.
        The shape of the outputs.

    inputs: list of Tensor
        The inputs

    fcompute: lambda function of inputs, outputs-> stmt
        Specifies the IR statement to do the computation.

    name: str, optional
        The name hint of the tensor

    dtype: str or list of str, optional
        The data types of outputs,
        by default dtype will be same as inputs.

    Returns
    -------
    tensor: Tensor or list of Tensors
        The created tensor, or a list of tensors if there are
        multiple outputs.
    """
    # a single output shape was given: normalize to a list of shapes
    if isinstance(shape[0], _expr.Expr):
        shape = [shape]
    input_placeholders = []
    output_placeholders = []
    types = set()
    for t in inputs:
        if not isinstance(t, _tensor.Tensor):
            raise ValueError("expect inputs to be tensor")
        # symbolic buffer standing for this input inside the extern body
        input_placeholders.append(
            Buffer(t.shape, t.dtype, t.op.name))
        types.add(t.dtype)

    if dtype is None:
        # infer output dtype only when all inputs agree on one type
        if len(types) != 1:
            raise ValueError("Cannot infer output type, please provide dtype argument")
        inferred_type = types.pop()
        dtype = [inferred_type for _ in shape]
    elif isinstance(dtype, str):
        # allow a single dtype string to apply to every output
        dtype = [dtype] * len(shape)

    for shp, dt in zip(shape, dtype):
        output_placeholders.append(Buffer(shp, dt, name))
    # user callback builds the IR body from the placeholder buffers
    body = fcompute(input_placeholders, output_placeholders)
    if isinstance(body, _expr.Expr):
        body = _make.Evaluate(body)

    op = _api_internal._ExternOp(
        name, inputs, input_placeholders, output_placeholders, body)
    res = [op.output(i) for i in range(len(output_placeholders))]
    return res[0] if len(res) == 1 else res
def call_packed(*args):
    """Build an expression that calls an external packed function.

    Parameters
    ----------
    args : list
        Positional arguments.

    Returns
    -------
    call : Expr
        An int32 call expression invoking the packed function.
    """
    args = convert(args)
    # Explicit "int32" return type string (same form the tests use with
    # tvm.make.Call); 4 is the Call type code for this intrinsic call,
    # None/0 mean no associated function node / value index.
    return _make.Call(
        "int32", "tvm_call_packed", args, 4, None, 0)
def Buffer(shape, dtype=None, def Buffer(shape, dtype=None,
......
# pylint: disable=protected-access, no-member
"""Arithmetic data structure and utility""" """Arithmetic data structure and utility"""
from __future__ import absolute_import as _abs from __future__ import absolute_import as _abs
......
...@@ -3,7 +3,6 @@ ...@@ -3,7 +3,6 @@
Eventually some of these pipelines will be moved to C++. Eventually some of these pipelines will be moved to C++.
But the first pipeline will be kept in python for ease of change and evolving. But the first pipeline will be kept in python for ease of change and evolving.
""" """
# pylint: disable=invalid-name, no-member, too-many-locals, too-many-arguments
from . import api from . import api
from . import tensor from . import tensor
......
# pylint: disable=protected-access, no-member
"""Collection structure in the high level DSL.""" """Collection structure in the high level DSL."""
from __future__ import absolute_import as _abs from __future__ import absolute_import as _abs
from ._ctypes._node import NodeBase, register_node from ._ctypes._node import NodeBase, register_node
......
# pylint: disable=protected-access, no-member, missing-docstring """Expression class"""
# pylint: disable=missing-docstring
from __future__ import absolute_import as _abs from __future__ import absolute_import as _abs
from ._ctypes._node import NodeBase, register_node from ._ctypes._node import NodeBase, register_node
from . import make as _make from . import make as _make
......
"""Runtime module related stuffs""" """Runtime module related stuffs"""
# pylint: disable=unused-import, invalid-name, undefined-variable
from __future__ import absolute_import as _abs from __future__ import absolute_import as _abs
from ._ctypes._function import ModuleBase, _init_module_module from ._ctypes._function import ModuleBase, _init_module_module
......
...@@ -2,7 +2,7 @@ ...@@ -2,7 +2,7 @@
This is a simplified runtime API for quick testing and proptyping. This is a simplified runtime API for quick testing and proptyping.
""" """
# pylint: disable=unused-import, invalid-name # pylint: disable=invalid-name,unused-import
from __future__ import absolute_import as _abs from __future__ import absolute_import as _abs
import numpy as _np import numpy as _np
......
# pylint: disable=protected-access, no-member
"""Collection structure in the high level DSL.""" """Collection structure in the high level DSL."""
from __future__ import absolute_import as _abs from __future__ import absolute_import as _abs
from ._ctypes._node import NodeBase, register_node from ._ctypes._node import NodeBase, register_node
......
# pylint: disable=protected-access, no-member, missing-docstring """Statement classes"""
from __future__ import absolute_import as _abs from __future__ import absolute_import as _abs
from ._ctypes._node import NodeBase, register_node from ._ctypes._node import NodeBase, register_node
......
# pylint: disable=protected-access, no-member, invalid-name
"""Tensor related abstractions""" """Tensor related abstractions"""
from __future__ import absolute_import as _abs from __future__ import absolute_import as _abs
from ._ctypes._node import NodeBase, SliceBase, register_node, convert_to_node from ._ctypes._node import NodeBase, SliceBase, register_node, convert_to_node
...@@ -90,3 +89,8 @@ class ComputeOp(Operation): ...@@ -90,3 +89,8 @@ class ComputeOp(Operation):
class ScanOp(Operation): class ScanOp(Operation):
"""Scan operation.""" """Scan operation."""
pass pass
@register_node
class ExternOp(Operation):
"""Extern operation."""
pass
...@@ -183,6 +183,15 @@ TVM_REGISTER_API(_ScanOp) ...@@ -183,6 +183,15 @@ TVM_REGISTER_API(_ScanOp)
args[4]); args[4]);
}); });
TVM_REGISTER_API(_ExternOp)
.set_body([](TVMArgs args, TVMRetValue* ret) {
*ret = ExternOpNode::make(args[0],
args[1],
args[2],
args[3],
args[4]);
});
TVM_REGISTER_API(_OpGetOutput) TVM_REGISTER_API(_OpGetOutput)
.set_body([](TVMArgs args, TVMRetValue* ret) { .set_body([](TVMArgs args, TVMRetValue* ret) {
*ret = args[0].operator Operation().output( *ret = args[0].operator Operation().output(
......
...@@ -1236,6 +1236,7 @@ void CodeGenLLVM::VisitStmt_(const Allocate* op) { ...@@ -1236,6 +1236,7 @@ void CodeGenLLVM::VisitStmt_(const Allocate* op) {
buf = builder_->CreatePointerCast(buf, LLVMType(op->type)->getPointerTo()); buf = builder_->CreatePointerCast(buf, LLVMType(op->type)->getPointerTo());
CHECK(!var_map_.count(op->buffer_var.get())); CHECK(!var_map_.count(op->buffer_var.get()));
var_map_[op->buffer_var.get()] = buf; var_map_[op->buffer_var.get()] = buf;
this->VisitStmt(op->body);
} }
void CodeGenLLVM::VisitStmt_(const AttrStmt* op) { void CodeGenLLVM::VisitStmt_(const AttrStmt* op) {
......
...@@ -8,9 +8,8 @@ ...@@ -8,9 +8,8 @@
#include <tvm/ir.h> #include <tvm/ir.h>
#include <tvm/ir_visitor.h> #include <tvm/ir_visitor.h>
#include <tvm/ir_pass.h> #include <tvm/ir_pass.h>
#include <tvm/ir_mutator.h>
#include <unordered_set> #include <unordered_set>
#include "./make_loop.h" #include "./op_util.h"
namespace tvm { namespace tvm {
...@@ -101,40 +100,12 @@ Array<Tensor> ComputeOpNode::InputTensors() const { ...@@ -101,40 +100,12 @@ Array<Tensor> ComputeOpNode::InputTensors() const {
return ret; return ret;
} }
// replacer to replace tensors
class TensorReplacer : public ir::IRMutator {
public:
explicit TensorReplacer(const std::unordered_map<Tensor, Tensor>& vmap)
: vmap_(vmap) {}
Expr Mutate_(const ir::Call* op, const Expr& e) {
if (op->call_type == ir::Call::Halide) {
Tensor t = Operation(op->func.node_).output(op->value_index);
auto it = vmap_.find(t);
if (it != vmap_.end()) {
Expr ret = ir::Call::make(
op->type, it->second->op->name, op->args,
op->call_type, it->second->op, it->second->value_index);
found = true;
return IRMutator::Mutate_(ret.as<ir::Call>(), ret);
}
}
return IRMutator::Mutate_(op, e);
}
// whether it is found.
bool found{false};
private:
const std::unordered_map<Tensor, Tensor>& vmap_;
};
Operation ComputeOpNode::ReplaceInputs( Operation ComputeOpNode::ReplaceInputs(
const Operation& self, const Operation& self,
const std::unordered_map<Tensor, Tensor>& rmap) const { const std::unordered_map<Tensor, Tensor>& rmap) const {
CHECK_EQ(self.operator->(), this); CHECK_EQ(self.operator->(), this);
TensorReplacer repl(rmap); Expr new_body = op::ReplaceTensor(this->body, rmap);
Expr new_body = repl.Mutate(this->body); if (!new_body.same_as(this->body)) {
if (repl.found) {
return ComputeOpNode::make(name, axis, new_body); return ComputeOpNode::make(name, axis, new_body);
} else { } else {
return self; return self;
......
/*!
* Copyright (c) 2017 by Contributors
* \brief External computation rule.
* \file extern_op.cc
*/
#include <tvm/operation.h>
#include <tvm/arithmetic.h>
#include <tvm/ir.h>
#include <unordered_set>
#include "./op_util.h"
namespace tvm {
using namespace ir;
// ExternOpNode
// Pretty-printer: show extern ops as "extern(<name>, <pointer>)".
TVM_STATIC_IR_FUNCTOR(IRPrinter, vtable)
.set_dispatch<ExternOpNode>([](const ExternOpNode *op, IRPrinter *p) {
    p->stream << "extern(" << op->name << ", " << op << ")";
  });

// Register the node type for reflection/creation machinery.
TVM_REGISTER_NODE_TYPE(ExternOpNode);
int ExternOpNode::num_outputs() const {
return static_cast<int>(output_placeholders.size());
}
// An extern op is opaque: it exposes no iteration variables to the scheduler.
Array<IterVar> ExternOpNode::root_iter_vars() const {
  Array<IterVar> empty;
  return empty;
}
// Output dtype is taken from the corresponding output placeholder buffer.
Type ExternOpNode::output_dtype(size_t i) const {
  const Buffer& buf = output_placeholders[i];
  return buf->dtype;
}
// Output shape is taken from the corresponding output placeholder buffer.
Array<Expr> ExternOpNode::output_shape(size_t i) const {
  const Buffer& buf = output_placeholders[i];
  return buf->shape;
}
/*!
 * \brief Create an extern operation.
 *
 *  Each input must match its placeholder element-wise in dtype, share the
 *  same shape array, and the placeholder must be compact (no explicit
 *  strides) — see the checks below.
 */
Operation ExternOpNode::make(std::string name,
                             Array<Tensor> inputs,
                             Array<Buffer> input_placeholders,
                             Array<Buffer> output_placeholders,
                             Stmt body) {
  auto n = std::make_shared<ExternOpNode>();
  n->name = name;
  // one placeholder per input, with matching type/shape
  CHECK_EQ(inputs.size(), input_placeholders.size());
  for (size_t i = 0; i < inputs.size(); ++i) {
    CHECK_EQ(inputs[i]->dtype, input_placeholders[i]->dtype);
    // same_as requires the placeholder to reuse the tensor's shape array
    CHECK(inputs[i]->shape.same_as(input_placeholders[i]->shape));
    // placeholders must be compact: no user-specified strides
    CHECK_EQ(input_placeholders[i]->strides.size(), 0U);
  }
  n->inputs = inputs;
  n->input_placeholders = input_placeholders;
  n->output_placeholders = output_placeholders;
  n->body = body;
  return Operation(n);
}
// The tensors read by this op are exactly the declared inputs.
Array<Tensor> ExternOpNode::InputTensors() const {
  Array<Tensor> ret = inputs;
  return ret;
}
/*
 * Replace tensor references in both the opaque body and the input list.
 * Returns the original operation unchanged when no replacement applied,
 * so callers can cheaply detect a no-op.
 */
Operation ExternOpNode::ReplaceInputs(
    const Operation& self,
    const std::unordered_map<Tensor, Tensor>& rmap) const {
  CHECK_EQ(self.operator->(), this);
  auto n = std::make_shared<ExternOpNode>(*this);
  n->body = op::ReplaceTensor(this->body, rmap);
  for (size_t i = 0; i < n->inputs.size(); ++i) {
    // single map lookup instead of count() followed by at()
    auto it = rmap.find(n->inputs[i]);
    if (it != rmap.end()) {
      n->inputs.Set(i, it->second);
    }
  }
  if (body.same_as(n->body) &&
      inputs.same_as(n->inputs)) {
    return self;
  } else {
    return Operation(n);
  }
}
// An extern op may touch any element of its inputs, so request the full
// domain [0, shape[i]) of every input tensor the caller is collecting
// bounds for.
void ExternOpNode::PropBoundToInputs(
    const Operation& self,
    const std::unordered_map<const Variable*, IntSet>& dom_map,
    std::unordered_map<Tensor, TensorDom>* out_dom_map) const {
  for (Tensor t : this->inputs) {
    auto it = out_dom_map->find(t);
    // tensors absent from the map are not being bounded in this pass
    if (it == out_dom_map->end()) continue;
    TensorDom& dom = it->second;
    for (size_t i = 0; i < t->shape.size(); ++i) {
      // whole extent of dimension i: [0, shape[i])
      dom.data[i].emplace_back(IntSet::range(
          Range::make_with_min_extent(
              make_const(t->shape[i].type(), 0), t->shape[i])));
    }
  }
}
// Nothing to do: an extern op exposes no root iteration variables
// (see root_iter_vars), so there are no ranges to gather.
void ExternOpNode::GatherBound(
    const Operation& self,
    const GraphContext& graph_ctx,
    const std::unordered_map<Tensor, TensorDom>& tensor_dom,
    std::unordered_map<IterVar, Range>* out_dom_map) const {
}
// Wrap the body in one Realize node per output, each covering the full
// output shape: an extern op always materializes its whole outputs.
Stmt ExternOpNode::BuildRealize(
    const Operation& self,
    const std::unordered_map<IterVar, Range>& realize_map,
    const Stmt& body) const {
  CHECK_EQ(self.operator->(), this);
  Stmt realize_body = body;
  for (int k = 0; k < num_outputs(); ++k) {
    Tensor t = self.output(k);
    // realize the full region [0, shape[i]) in every dimension
    Halide::Internal::Region bounds;
    for (size_t i = 0; i < t->shape.size(); ++i) {
      bounds.push_back(
          Range::make_with_min_extent(
              make_const(t->shape[i].type(), 0), t->shape[i]));
    }
    realize_body = ir::Realize::make(
        t->op, t->value_index, t->dtype,
        bounds, const_true(), realize_body);
  }
  return realize_body;
}
// Emit the opaque body wrapped in an extern_op_scope attribute so later
// passes (e.g. storage flattening) can locate it and remap its buffers.
Stmt ExternOpNode::BuildProvide(
    const Stage& stage,
    const std::unordered_map<IterVar, Range>& dom_map) const {
  CHECK_EQ(stage->op.operator->(), this);
  return AttrStmt::make(
      stage->op, ir::attr::extern_op_scope,
      StringImm::make(name), body);
}
} // namespace tvm
/*! /*!
* Copyright (c) 2017 by Contributors * Copyright (c) 2017 by Contributors
* \brief Utility to make loop nest. * \brief Utility to make loop nest.
* \file make_loop.cc * \file op_util.cc
*/ */
#include <tvm/ir.h> #include <tvm/ir.h>
#include <tvm/ir_pass.h> #include <tvm/ir_pass.h>
#include <tvm/operation.h> #include <tvm/operation.h>
#include "./make_loop.h" #include <tvm/ir_mutator.h>
#include "./op_util.h"
#include "../arithmetic/compute_expr.h" #include "../arithmetic/compute_expr.h"
namespace tvm { namespace tvm {
...@@ -231,5 +232,45 @@ std::vector<Stmt> MakeBoundCheck( ...@@ -231,5 +232,45 @@ std::vector<Stmt> MakeBoundCheck(
return nest; return nest;
} }
// Replacer that rewrites Halide tensor-read calls according to a
// Tensor -> Tensor map; `found` records whether any rewrite happened.
class TensorReplacer : public ir::IRMutator {
 public:
  explicit TensorReplacer(const std::unordered_map<Tensor, Tensor>& vmap)
      : vmap_(vmap) {}

  Expr Mutate_(const ir::Call* op, const Expr& e) {
    if (op->call_type == ir::Call::Halide) {
      // reconstruct the tensor this call reads from
      Tensor t = Operation(op->func.node_).output(op->value_index);
      auto it = vmap_.find(t);
      if (it != vmap_.end()) {
        // rebuild the call pointing at the replacement tensor's op
        Expr ret = ir::Call::make(
            op->type, it->second->op->name, op->args,
            op->call_type, it->second->op, it->second->value_index);
        found = true;
        // recurse into the rebuilt call: its args may also reference
        // tensors that need replacement
        return IRMutator::Mutate_(ret.as<ir::Call>(), ret);
      }
    }
    return IRMutator::Mutate_(op, e);
  }

  // whether any replacement was performed.
  bool found{false};

 private:
  // replacement map, held by reference; must outlive the replacer
  const std::unordered_map<Tensor, Tensor>& vmap_;
};
// Apply TensorReplacer to a statement; return the original untouched
// statement when no tensor reference was actually rewritten.
Stmt ReplaceTensor(Stmt stmt,
                   const std::unordered_map<Tensor, Tensor>& replace) {
  TensorReplacer repl(replace);
  Stmt mutated = repl.Mutate(stmt);
  if (repl.found) {
    return mutated;
  }
  return stmt;
}
// Apply TensorReplacer to an expression; return the original untouched
// expression when no tensor reference was actually rewritten.
Expr ReplaceTensor(Expr expr,
                   const std::unordered_map<Tensor, Tensor>& replace) {
  TensorReplacer repl(replace);
  Expr mutated = repl.Mutate(expr);
  if (repl.found) {
    return mutated;
  }
  return expr;
}
} // namespace op } // namespace op
} // namespace tvm } // namespace tvm
/*! /*!
* Copyright (c) 2017 by Contributors * Copyright (c) 2017 by Contributors
* \file make_loop.h * \file op_util.h
* \brief Utility to make loop nest from schedule stage info. * \brief Common utility used in operator construction.
*/ */
#ifndef TVM_OP_MAKE_LOOP_H_ #ifndef TVM_OP_OP_UTIL_H_
#define TVM_OP_MAKE_LOOP_H_ #define TVM_OP_OP_UTIL_H_
#include <tvm/expr.h> #include <tvm/expr.h>
#include <tvm/schedule.h> #include <tvm/schedule.h>
...@@ -50,6 +50,22 @@ MakeBoundCheck(const Stage& stage, ...@@ -50,6 +50,22 @@ MakeBoundCheck(const Stage& stage,
bool skip_ivar_domain, bool skip_ivar_domain,
const std::unordered_set<IterVar>& skip_iter, const std::unordered_set<IterVar>& skip_iter,
const std::unordered_map<IterVar, Expr>& value_map); const std::unordered_map<IterVar, Expr>& value_map);
/*!
* \brief Replace the tensor reference in stmt by the replace map.
* \param stmt The statement to be processed.
* \param replace The replacement rule.
*/
Stmt ReplaceTensor(Stmt stmt,
const std::unordered_map<Tensor, Tensor>& replace);
/*!
* \brief Replace the tensor reference in expr by the replace map.
* \param expr The expression to be processed.
* \param replace The replacement rule.
*/
Expr ReplaceTensor(Expr expr,
const std::unordered_map<Tensor, Tensor>& replace);
} // namespace op } // namespace op
} // namespace tvm } // namespace tvm
#endif // TVM_OP_MAKE_LOOP_H_ #endif // TVM_OP_OP_UTIL_H_
...@@ -6,7 +6,7 @@ ...@@ -6,7 +6,7 @@
#include <tvm/operation.h> #include <tvm/operation.h>
#include <tvm/ir.h> #include <tvm/ir.h>
#include <tvm/ir_pass.h> #include <tvm/ir_pass.h>
#include "./make_loop.h" #include "./op_util.h"
#include "../schedule/graph.h" #include "../schedule/graph.h"
namespace tvm { namespace tvm {
......
...@@ -89,7 +89,7 @@ class AllocateLifter : public IRMutator { ...@@ -89,7 +89,7 @@ class AllocateLifter : public IRMutator {
}; };
Stmt LiftAllocate(Stmt stmt) { Stmt LiftAllocate(Stmt stmt) {
return AllocateLifter().Mutate(stmt); return AllocateLifter().Lift(stmt);
} }
} // namespace ir } // namespace ir
......
...@@ -3,8 +3,11 @@ ...@@ -3,8 +3,11 @@
* \file storage_flatten.cc * \file storage_flatten.cc
*/ */
#include <tvm/ir.h> #include <tvm/ir.h>
#include <tvm/expr.h>
#include <tvm/ir_mutator.h> #include <tvm/ir_mutator.h>
#include <tvm/ir_pass.h> #include <tvm/ir_pass.h>
#include <tvm/buffer.h>
#include <tvm/operation.h>
#include <unordered_map> #include <unordered_map>
#include "../runtime/thread_storage_scope.h" #include "../runtime/thread_storage_scope.h"
...@@ -25,6 +28,16 @@ class StorageFlattener : public IRMutator { ...@@ -25,6 +28,16 @@ class StorageFlattener : public IRMutator {
buf_map_[TensorKey{kv.first->op, kv.first->value_index}] = e; buf_map_[TensorKey{kv.first->op, kv.first->value_index}] = e;
} }
} }
Stmt Mutate_(const Store* op, const Stmt& s) final {
Stmt stmt = IRMutator::Mutate_(op, s);
op = stmt.as<Store>();
auto it = extern_buf_remap_.find(op->buffer_var.get());
if (it != extern_buf_remap_.end()) {
return Store::make(it->second, op->value, op->index);
} else {
return stmt;
}
}
Stmt Mutate_(const AttrStmt* op, const Stmt& s) final { Stmt Mutate_(const AttrStmt* op, const Stmt& s) final {
if (op->type_key == attr::realize_scope) { if (op->type_key == attr::realize_scope) {
...@@ -37,6 +50,8 @@ class StorageFlattener : public IRMutator { ...@@ -37,6 +50,8 @@ class StorageFlattener : public IRMutator {
Stmt stmt = IRMutator::Mutate_(op, s); Stmt stmt = IRMutator::Mutate_(op, s);
curr_thread_scope_.pop_back(); curr_thread_scope_.pop_back();
return stmt; return stmt;
} else if (op->type_key == attr::extern_op_scope) {
return HandleExternOp(op);
} }
return IRMutator::Mutate_(op, s); return IRMutator::Mutate_(op, s);
} }
...@@ -95,6 +110,26 @@ class StorageFlattener : public IRMutator { ...@@ -95,6 +110,26 @@ class StorageFlattener : public IRMutator {
} }
} }
Expr Mutate_(const Load* op, const Expr& e) final {
Expr expr = IRMutator::Mutate_(op, e);
op = expr.as<Load>();
auto it = extern_buf_remap_.find(op->buffer_var.get());
if (it != extern_buf_remap_.end()) {
return Load::make(op->type, it->second, op->index);
} else {
return expr;
}
}
Expr Mutate_(const Variable* op, const Expr& e) final {
auto it = extern_buf_remap_.find(op);
if (it != extern_buf_remap_.end()) {
return it->second;
} else {
return e;
}
}
Expr Mutate_(const Call* op, const Expr& olde) final { Expr Mutate_(const Call* op, const Expr& olde) final {
Expr expr = IRMutator::Mutate_(op, olde); Expr expr = IRMutator::Mutate_(op, olde);
op = expr.as<Call>(); op = expr.as<Call>();
...@@ -113,6 +148,28 @@ class StorageFlattener : public IRMutator { ...@@ -113,6 +148,28 @@ class StorageFlattener : public IRMutator {
} }
private: private:
  // Enter an extern_op_scope: map the op's placeholder buffer vars to
  // the real flattened buffers, mutate the opaque body under that
  // mapping, then clear the mapping again.
  Stmt HandleExternOp(const AttrStmt* op) {
    const ExternOpNode* ext_op = op->node.as<ExternOpNode>();
    CHECK(ext_op);
    Operation func(op->node.node_);
    // remap must be empty on entry: extern op scopes do not nest
    CHECK_EQ(extern_buf_remap_.size(), 0U);
    // outputs: placeholder data var -> flattened output buffer var
    for (size_t i = 0; i < ext_op->output_placeholders.size(); ++i) {
      TensorKey key{func, static_cast<int>(i)};
      CHECK(buf_map_.count(key));
      extern_buf_remap_[ext_op->output_placeholders[i]->data.get()] =
          buf_map_.at(key).buffer->data;
    }
    // inputs: placeholder data var -> flattened input buffer var
    for (size_t i = 0; i < ext_op->inputs.size(); ++i) {
      TensorKey key{ext_op->inputs[i]->op, ext_op->inputs[i]->value_index};
      CHECK(buf_map_.count(key));
      extern_buf_remap_[ext_op->input_placeholders[i]->data.get()] =
          buf_map_.at(key).buffer->data;
    }
    Stmt ret = Mutate(op->body);
    extern_buf_remap_.clear();
    return ret;
  }
// The buffer entry in the flatten map // The buffer entry in the flatten map
struct BufferEntry { struct BufferEntry {
// the buffer of storage // the buffer of storage
...@@ -139,6 +196,7 @@ class StorageFlattener : public IRMutator { ...@@ -139,6 +196,7 @@ class StorageFlattener : public IRMutator {
} }
}; };
// The buffer assignment map // The buffer assignment map
std::unordered_map<const Variable*, Var> extern_buf_remap_;
std::unordered_map<TensorKey, BufferEntry> buf_map_; std::unordered_map<TensorKey, BufferEntry> buf_map_;
std::unordered_map<const Node*, std::string> storage_scope_; std::unordered_map<const Node*, std::string> storage_scope_;
// The current thread scope. // The current thread scope.
......
import tvm
import numpy as np
def test_add_pipeline():
    """Build C = A + 1 via tvm.extern with a hand-written IR body and
    verify the result through the LLVM backend."""
    nn = 1024
    n = tvm.convert(nn)
    A = tvm.placeholder((n,), name='A')
    def extern_generator(ins, outs):
        """Manually write the IR for the extern function, add pipeline"""
        i = tvm.Var('i')
        # for i in [0, n): outs[0][i] = ins[0][i] + 1
        stmt = tvm.make.For(
            i, 0, n, 0, 0,
            tvm.make.Store(outs[0].data,
                           tvm.make.Load(A.dtype, ins[0].data, i) +
                           1, i))
        return stmt
    C = tvm.extern(A.shape, [A], extern_generator, name='C')
    s = tvm.Schedule(C.op)

    def check_llvm():
        # skip silently when the LLVM backend is not compiled in
        if not tvm.codegen.enabled("llvm"):
            return
        # build and invoke the kernel.
        f = tvm.build(s, [A, C], "llvm")
        ctx = tvm.cpu(0)
        # launch the kernel.
        n = nn
        a = tvm.nd.array(np.random.uniform(size=n).astype(A.dtype), ctx)
        c = tvm.nd.array(np.zeros(n, dtype=C.dtype), ctx)
        f(a, c)
        np.testing.assert_allclose(
            c.asnumpy(), a.asnumpy() + 1)
    check_llvm()


if __name__ == "__main__":
    test_add_pipeline()
...@@ -8,10 +8,7 @@ def test_llvm_add_pipeline(): ...@@ -8,10 +8,7 @@ def test_llvm_add_pipeline():
B = tvm.placeholder((n,), name='B') B = tvm.placeholder((n,), name='B')
C = tvm.compute(A.shape, lambda *i: A(*i) + B(*i), name='C') C = tvm.compute(A.shape, lambda *i: A(*i) + B(*i), name='C')
s = tvm.Schedule(C.op) s = tvm.Schedule(C.op)
print(s[C])
print("a?")
xo, xi = s[C].split(C.op.axis[0], factor=4) xo, xi = s[C].split(C.op.axis[0], factor=4)
print("a?")
s[C].parallel(xo) s[C].parallel(xo)
s[C].vectorize(xi) s[C].vectorize(xi)
def check_llvm(): def check_llvm():
...@@ -83,12 +80,31 @@ def test_llvm_madd_pipeline(): ...@@ -83,12 +80,31 @@ def test_llvm_madd_pipeline():
check_llvm(4, 0, 1) check_llvm(4, 0, 1)
check_llvm(4, 0, 3) check_llvm(4, 0, 3)
def test_llvm_temp_space():
    """Two chained computes (A+1, then +1 again) require an intermediate
    buffer; check LLVM codegen handles the temporary storage correctly."""
    nn = 1024
    n = tvm.convert(nn)
    A = tvm.placeholder((n,), name='A')
    B = tvm.compute(A.shape, lambda i: A(i) + 1, name='B')
    C = tvm.compute(A.shape, lambda i: B(i) + 1, name='C')
    s = tvm.Schedule(C.op)

    def check_llvm():
        # skip silently when the LLVM backend is not compiled in
        if not tvm.codegen.enabled("llvm"):
            return
        # build and invoke the kernel.
        f = tvm.build(s, [A, C], "llvm")
        ctx = tvm.cpu(0)
        # launch the kernel.
        n = nn
        a = tvm.nd.array(np.random.uniform(size=n).astype(A.dtype), ctx)
        c = tvm.nd.array(np.zeros(n, dtype=C.dtype), ctx)
        f(a, c)
        np.testing.assert_allclose(
            c.asnumpy(), a.asnumpy() + 1 + 1)
    check_llvm()
if __name__ == "__main__": if __name__ == "__main__":
print("a")
test_llvm_add_pipeline() test_llvm_add_pipeline()
print("a")
test_llvm_flip_pipeline() test_llvm_flip_pipeline()
print("a")
test_llvm_madd_pipeline() test_llvm_madd_pipeline()
test_llvm_temp_space()
import tvm import tvm
import numpy as np import numpy as np
def tvm_call_packed(*args):
args = tvm.convert(args)
return tvm.make.Call("int32", "tvm_call_packed", args, 4, None, 0)
def run_jit(fapi, check): def run_jit(fapi, check):
for target in ["llvm", "stackvm"]: for target in ["llvm", "stackvm"]:
if not tvm.codegen.enabled(target): if not tvm.codegen.enabled(target):
...@@ -24,7 +19,7 @@ def test_stack_vm_basic(): ...@@ -24,7 +19,7 @@ def test_stack_vm_basic():
n = tvm.Var('n') n = tvm.Var('n')
Ab = tvm.Buffer((n, ), tvm.float32) Ab = tvm.Buffer((n, ), tvm.float32)
stmt = tvm.make.Evaluate(tvm_call_packed("tvm_call_back_get_shape", Ab.shape[0])) stmt = tvm.make.Evaluate(tvm.call_packed("tvm_call_back_get_shape", Ab.shape[0]))
fapi = tvm.ir_pass.MakeAPI(stmt, "print_shape", [Ab], 0) fapi = tvm.ir_pass.MakeAPI(stmt, "print_shape", [Ab], 0)
run_jit(fapi, lambda f: f(a)) run_jit(fapi, lambda f: f(a))
...@@ -46,7 +41,7 @@ def test_stack_vm_loop(): ...@@ -46,7 +41,7 @@ def test_stack_vm_loop():
tvm.make.Store(Ab.data, tvm.make.Store(Ab.data,
tvm.make.Load(dtype, Ab.data, i) + 1, tvm.make.Load(dtype, Ab.data, i) + 1,
i + 1), i + 1),
tvm.make.Evaluate(tvm_call_packed("tvm_stack_vm_print", i)))) tvm.make.Evaluate(tvm.call_packed("tvm_stack_vm_print", i))))
fapi = tvm.ir_pass.MakeAPI(stmt, "ramp", [Ab], 0) fapi = tvm.ir_pass.MakeAPI(stmt, "ramp", [Ab], 0)
a = tvm.nd.array(np.zeros(10, dtype=dtype)) a = tvm.nd.array(np.zeros(10, dtype=dtype))
def check(f): def check(f):
......
...@@ -80,6 +80,30 @@ def test_scan_multi_out(): ...@@ -80,6 +80,30 @@ def test_scan_multi_out():
zz = tvm.load_json(json_str) zz = tvm.load_json(json_str)
assert isinstance(zz, tvm.tensor.ScanOp) assert isinstance(zz, tvm.tensor.ScanOp)
def test_extern():
    """tvm.extern with one output: result shape matches the request."""
    m = tvm.Var('m')
    A = tvm.placeholder((m,), name='A')

    def extern_func(ins, outs):
        # inputs arrive as symbolic placeholder buffers
        assert isinstance(ins[0], tvm.schedule.Buffer)
        return tvm.call_packed("myadd", ins[0].data, outs[0].data, m)

    B = tvm.extern((m,), [A], extern_func)
    assert tuple(B.shape) == (m,)
def test_extern_multi_out():
    """tvm.extern with two outputs returns a list with correct indices."""
    m = tvm.Var('m')
    A = tvm.placeholder((m,), name='A')
    B = tvm.compute((m,), lambda i: A[i] * 10)

    def extern_func(ins, outs):
        # inputs arrive as symbolic placeholder buffers
        assert isinstance(ins[0], tvm.schedule.Buffer)
        return tvm.call_packed(
            "myadd", ins[0].data, outs[0].data, outs[1].data, m)

    res = tvm.extern([A.shape, A.shape], [A, B], extern_func)
    assert len(res) == 2
    assert res[1].value_index == 1
if __name__ == "__main__": if __name__ == "__main__":
test_conv1d() test_conv1d()
...@@ -88,3 +112,5 @@ if __name__ == "__main__": ...@@ -88,3 +112,5 @@ if __name__ == "__main__":
test_tensor_reduce() test_tensor_reduce()
test_tensor_scan() test_tensor_scan()
test_scan_multi_out() test_scan_multi_out()
test_extern()
test_extern_multi_out()
...@@ -2,7 +2,10 @@ ...@@ -2,7 +2,10 @@
if [ ${TASK} == "lint" ] || [ ${TASK} == "all_test" ]; then if [ ${TASK} == "lint" ] || [ ${TASK} == "all_test" ]; then
if [ ! ${TRAVIS_OS_NAME} == "osx" ]; then if [ ! ${TRAVIS_OS_NAME} == "osx" ]; then
make lint || exit -1 echo "Check codestyle of c++ code..."
make cpplint || exit -1
echo "Check codestyle of python code..."
make pylint || exit -1
echo "Check documentations of c++ code..." echo "Check documentations of c++ code..."
make doc 2>log.txt make doc 2>log.txt
(cat log.txt| grep -v ENABLE_PREPROCESSING |grep -v "unsupported tag") > logclean.txt (cat log.txt| grep -v ENABLE_PREPROCESSING |grep -v "unsupported tag") > logclean.txt
......
Markdown is supported
0% or
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment