Unverified Commit c113712d by Tianqi Chen Committed by GitHub

[RELAY][BACKEND] Enable PlanMemory in the graph runtime. (#2120)

parent 6edb3564
...@@ -458,12 +458,14 @@ inline const TTypeNode* ExprNode::type_as() const { ...@@ -458,12 +458,14 @@ inline const TTypeNode* ExprNode::type_as() const {
/*! /*!
* \brief Print node as text format. * \brief Print node as text format.
* \param node The node to be printed. * \param node The node to be printed.
* \param show_meta_data Whether to print meta data section.
* \param annotate An optional callback function for attaching * \param annotate An optional callback function for attaching
* additional comment block to an expr. * additional comment block to an expr.
* \return The text representation. * \return The text representation.
*/ */
std::string RelayPrint( std::string RelayPrint(
const NodeRef& node, const NodeRef& node,
bool show_meta_data = true,
runtime::TypedPackedFunc<std::string(Expr)> annotate = nullptr); runtime::TypedPackedFunc<std::string(Expr)> annotate = nullptr);
} // namespace relay } // namespace relay
} // namespace tvm } // namespace tvm
......
...@@ -55,6 +55,7 @@ def build(funcs, target, target_host=None): ...@@ -55,6 +55,7 @@ def build(funcs, target, target_host=None):
funcs : List[tvm.LoweredFunc] funcs : List[tvm.LoweredFunc]
The list of lowered functions. The list of lowered functions.
target : tvm.Target target : tvm.Target
The target to run the code on. The target to run the code on.
......
...@@ -21,6 +21,7 @@ contrib.graph_runtime or any other TVM runtime comptatible system. ...@@ -21,6 +21,7 @@ contrib.graph_runtime or any other TVM runtime comptatible system.
from __future__ import absolute_import from __future__ import absolute_import
import json import json
import attr import attr
from . import _backend
from . import compile_engine from . import compile_engine
from ..op import Op from ..op import Op
from ..expr import Function, GlobalVar, ExprFunctor from ..expr import Function, GlobalVar, ExprFunctor
...@@ -103,11 +104,12 @@ class GraphRuntimeCodegen(ExprFunctor): ...@@ -103,11 +104,12 @@ class GraphRuntimeCodegen(ExprFunctor):
self.nodes = [] self.nodes = []
self.var_map = {} self.var_map = {}
self.params = {} self.params = {}
self.storage_map = None
self.compile_engine = compile_engine.get() self.compile_engine = compile_engine.get()
self.lowered_funcs = set() self.lowered_funcs = set()
self._name_map = {} self._name_map = {}
def add_node(self, node, checked_type): def add_node(self, node, expr):
""" """
Add a node to the graph. Add a node to the graph.
...@@ -116,14 +118,21 @@ class GraphRuntimeCodegen(ExprFunctor): ...@@ -116,14 +118,21 @@ class GraphRuntimeCodegen(ExprFunctor):
node: Node node: Node
The node to add to the graph. The node to add to the graph.
checked_type: Type expr: tvm.relay.Expr
The type of the node. The corresponding expression.
Returns Returns
------- -------
node_ref: Union[NodeRef, List[NodeRef]] node_ref: Union[NodeRef, List[NodeRef]]
A reference to the node. A reference to the node.
""" """
checked_type = expr.checked_type
# setup storage ids
assert expr in self.storage_map
node.attrs["storage_id"] = [
x.value for x in self.storage_map[expr]
]
node_id = len(self.nodes) node_id = len(self.nodes)
self.nodes.append(node) self.nodes.append(node)
# Tuple return value, flatten as tuple # Tuple return value, flatten as tuple
...@@ -168,7 +177,7 @@ class GraphRuntimeCodegen(ExprFunctor): ...@@ -168,7 +177,7 @@ class GraphRuntimeCodegen(ExprFunctor):
name = "p%d" % index name = "p%d" % index
self.params[name] = op.data self.params[name] = op.data
node = InputNode(name, {}) node = InputNode(name, {})
return self.add_node(node, op.checked_type) return self.add_node(node, op)
def visit_function(self, _): def visit_function(self, _):
raise RuntimeError("function not supported") raise RuntimeError("function not supported")
...@@ -244,7 +253,7 @@ class GraphRuntimeCodegen(ExprFunctor): ...@@ -244,7 +253,7 @@ class GraphRuntimeCodegen(ExprFunctor):
op_name = cached_func.func_name op_name = cached_func.func_name
op_node = OpNode(self._get_unique_name(op_name), {}, op_node = OpNode(self._get_unique_name(op_name), {},
op_name, inputs, {}) op_name, inputs, {})
return self.add_node(op_node, call.checked_type) return self.add_node(op_node, call)
def _get_json(self): def _get_json(self):
""" """
...@@ -281,8 +290,7 @@ class GraphRuntimeCodegen(ExprFunctor): ...@@ -281,8 +290,7 @@ class GraphRuntimeCodegen(ExprFunctor):
assert node.num_outputs == len(node.attrs["shape"]) assert node.num_outputs == len(node.attrs["shape"])
shapes += node.attrs["shape"] shapes += node.attrs["shape"]
dltypes += node.attrs["dtype"] dltypes += node.attrs["dtype"]
for i in range(node.num_outputs): storage_ids += node.attrs["storage_id"]
storage_ids.append(i + num_entry)
num_entry += node.num_outputs num_entry += node.num_outputs
node_row_ptr.append(num_entry) node_row_ptr.append(num_entry)
...@@ -302,6 +310,14 @@ class GraphRuntimeCodegen(ExprFunctor): ...@@ -302,6 +310,14 @@ class GraphRuntimeCodegen(ExprFunctor):
return json.dumps(json_dict, indent=2) return json.dumps(json_dict, indent=2)
def debug_dump_memory_plan(self, func):
"""Debug function to dump memory plan."""
def _annotate(expr):
if expr in self.storage_map:
return str(self.storage_map[expr])
return ""
return func.astext(show_meta_data=False, annotate=_annotate)
def codegen(self, func): def codegen(self, func):
"""Compile a single function into a graph. """Compile a single function into a graph.
...@@ -321,11 +337,12 @@ class GraphRuntimeCodegen(ExprFunctor): ...@@ -321,11 +337,12 @@ class GraphRuntimeCodegen(ExprFunctor):
params : Dict[str, tvm.nd.NDArray] params : Dict[str, tvm.nd.NDArray]
Additional constant parameters. Additional constant parameters.
""" """
self.storage_map = _backend.GraphPlanMemory(func)
# First we convert all the parameters into input nodes. # First we convert all the parameters into input nodes.
for param in func.params: for param in func.params:
node = InputNode(param.name_hint, {}) node = InputNode(param.name_hint, {})
self.var_map[param] = self.add_node( self.var_map[param] = self.add_node(
node, param.type_annotation) node, param)
# Then we compile the body into a graph which can depend # Then we compile the body into a graph which can depend
# on input variables. # on input variables.
......
...@@ -23,7 +23,7 @@ def register_relay_node(type_key=None): ...@@ -23,7 +23,7 @@ def register_relay_node(type_key=None):
class RelayNode(NodeBase): class RelayNode(NodeBase):
"""Base class of all relay node.""" """Base class of all relay node."""
def astext(self, annotate=None): def astext(self, show_meta_data=True, annotate=None):
"""Get the text format of the expression. """Get the text format of the expression.
Returns Returns
...@@ -31,11 +31,21 @@ class RelayNode(NodeBase): ...@@ -31,11 +31,21 @@ class RelayNode(NodeBase):
text : str text : str
The text format of the expression. The text format of the expression.
show_meta_data : bool
Whether to include meta data section in the text
if there is meta data.
annotate: Optional[relay.Expr->str] annotate: Optional[relay.Expr->str]
Optional annotate function to provide additional Optional annotate function to provide additional
information in the comment block. information in the comment block.
Note
----
meta data section is necessary to fully parse the text format.
However, it can contain dumps that are big(constat weights),
so it can be helpful to skip printing the meta data section.
""" """
return _expr.RelayPrint(self, annotate) return _expr.RelayPrint(self, show_meta_data, annotate)
@register_relay_node @register_relay_node
......
/*!
* Copyright (c) 2018 by Contributors
* \file relay/backend/graph_mem_alloca.cc
* \brief Memory index assignment pass for executing
* the program in the graph runtime.
*/
#include <tvm/relay/expr.h>
#include <tvm/relay/expr_functor.h>
#include "../../common/arena.h"
namespace tvm {
namespace relay {
struct StorageToken {
/*! \brief Reference counter */
int ref_counter{0};
/*! \brief number of bytes */
size_t max_bytes{0};
/*! \brief The corresponding tensor type node. */
const TensorTypeNode* ttype{nullptr};
/*! \brief virtual device index */
int device_id{0};
/*! \brief The storage id */
int64_t storage_id{-1};
};
class StorageAllocaBaseVisitor : public ExprVisitor {
public:
// run the visitor on a function.
void Run(const Function& func) {
for (Var param : func->params) {
CreateToken(param.operator->(), false);
}
this->VisitExpr(func->body);
}
void VisitExpr_(const ConstantNode* op) final {
this->CreateToken(op, false);
}
void VisitExpr_(const VarNode* op) final {
// Do nothing.
}
void VisitExpr_(const FunctionNode* op) final {
// do not recursive into sub function.
}
void VisitExpr_(const GlobalVarNode* op) final {
// Do nothing.
}
void VisitExpr_(const OpNode* op) final {
// Do nothing.
}
void VisitExpr_(const TupleNode* op) final {
std::vector<StorageToken*> fields;
for (Expr field : op->fields) {
auto tok = GetToken(field);
CHECK_EQ(tok.size(), 1U);
fields.push_back(tok[0]);
}
token_map_[op] = fields;
}
void VisitExpr_(const TupleGetItemNode* op) final {
const auto& tok = GetToken(op->tuple);
CHECK_LT(static_cast<size_t>(op->index), tok.size());
token_map_[op] = {tok[op->index]};
}
void VisitExpr_(const IfNode* op) final {
LOG(FATAL) << "if is not supported.";
}
void VisitExpr_(const LetNode* op) final {
auto token = GetToken(op->value);
token_map_[op->var.operator->()] = token;
token_map_[op] = GetToken(op->body);
}
protected:
/*! \brief internal token map */
std::unordered_map<const ExprNode*, std::vector<StorageToken*> > token_map_;
/*!
* \brief Get the necessary token.
* \param expr The expression.
* \return The corresponding token.
*/
const std::vector<StorageToken*>& GetToken(const Expr& expr) {
this->VisitExpr(expr);
auto it = token_map_.find(expr.operator->());
CHECK(it != token_map_.end());
return it->second;
}
/*!
* \brief Populate the token map to set op's tokens
* \param op The node to be processed.
* \param can_realloc Whether we can re-allocate the memory.
*/
virtual void CreateToken(const ExprNode* op, bool can_realloc) = 0;
};
class StorageAllocaInit : protected StorageAllocaBaseVisitor {
public:
explicit StorageAllocaInit(common::Arena* arena)
: arena_(arena) {}
/*! \return The internal token map */
std::unordered_map<const ExprNode*, std::vector<StorageToken*> >
GetInitTokenMap(const Function& func) {
this->Run(func);
return std::move(token_map_);
}
protected:
using StorageAllocaBaseVisitor::VisitExpr_;
void CreateToken(const ExprNode* op, bool can_realloc) final {
CHECK(!token_map_.count(op));
std::vector<StorageToken*> tokens;
if (const auto* tuple_type = op->checked_type().as<TupleTypeNode>()) {
for (Type t : tuple_type->fields) {
const auto* ttype = t.as<TensorTypeNode>();
CHECK(ttype);
StorageToken* token = arena_->make<StorageToken>();
token->ttype = ttype;
tokens.push_back(token);
}
} else {
const auto* ttype = op->checked_type().as<TensorTypeNode>();
CHECK(ttype);
StorageToken* token = arena_->make<StorageToken>();
token->ttype = ttype;
tokens.push_back(token);
}
token_map_[op] = tokens;
}
void VisitExpr_(const CallNode* op) final {
// create token for the call node.
CreateToken(op, true);
// for each input, visit argument token.
for (Expr arg : op->args) {
for (StorageToken* tok : GetToken(arg)) {
tok->ref_counter += 1;
}
}
}
private:
// allocator
common::Arena* arena_;
};
class StorageAllocator : public StorageAllocaBaseVisitor {
public:
/*!
* \return totoal number of bytes allocated
*/
size_t TotalAllocBytes() const {
size_t total = 0;
for (const auto* p : data_) {
total += p->max_bytes;
}
return total;
}
// Run storage allocation for a function.
Map<Expr, Array<Integer> > Plan(const Function& func) {
prototype_ = StorageAllocaInit(&arena_).GetInitTokenMap(func);
this->Run(func);
Map<Expr, Array<Integer> > smap;
for (const auto& kv : token_map_) {
Array<Integer> vec;
for (StorageToken* tok : kv.second) {
vec.push_back(tok->storage_id);
}
smap.Set(GetRef<Expr>(kv.first), vec);
}
return smap;
}
protected:
using StorageAllocaBaseVisitor::VisitExpr_;
// override create token by getting token as prototype requirements.
void CreateToken(const ExprNode* op, bool can_realloc) final {
CHECK(!token_map_.count(op));
auto it = prototype_.find(op);
CHECK(it != prototype_.end());
std::vector<StorageToken*> tokens;
for (StorageToken* tok : it->second) {
if (can_realloc) {
tokens.push_back(Request(tok));
} else {
// Allocate a new token,
StorageToken* allocated_tok = Alloc(tok, GetMemorySize(tok));
// ensure it never get de-allocated.
allocated_tok->ref_counter += 1;
tokens.push_back(allocated_tok);
}
}
token_map_[op] = tokens;
}
// The call map
void VisitExpr_(const CallNode* op) final {
std::vector<StorageToken*> args;
// for each input, visit argument token.
for (Expr arg : op->args) {
for (StorageToken* tok : GetToken(arg)) {
args.push_back(tok);
}
}
// create token for the call node.
CreateToken(op, true);
// check if there is orphaned output that can be released immediately.
for (StorageToken* tok : token_map_.at(op)) {
CheckForRelease(tok);
}
for (StorageToken* tok : args) {
tok->ref_counter -= 1;
CheckForRelease(tok);
}
}
/*!
* \brief ceil(size/word_size) to get number of words.
* \param size The original size.
* \param word_size The element size.
*/
static size_t DivRoundUp(size_t size, size_t word_size) {
return (size + word_size - 1) / word_size;
}
/*!
* \brief Get the memory requirement.
* \param prototype The prototype token.
* \return The required memory size.
*/
size_t GetMemorySize(StorageToken* prototype) {
const TensorTypeNode* ttype = prototype->ttype;
CHECK(ttype != nullptr);
size_t size = 1;
for (IndexExpr dim : ttype->shape) {
const int64_t* pval = as_const_int(dim);
CHECK(pval != nullptr)
<< "Cannot allocate memory symbolic tensor shape "
<< ttype->shape;
size *= static_cast<size_t>(pval[0]);
}
size *= DivRoundUp(ttype->dtype.bits() * ttype->dtype.lanes(), 8);
return size;
}
/*!
* \brief Request a storage token for a given prototype.
* \param prototype. The prototype storage token.
* \return The result token.
*/
StorageToken* Request(StorageToken* prototype) {
// calculate the size;
size_t size = GetMemorySize(prototype);
// search memory block in [size / match_range_, size * match_range_)
if (match_range_ == 0) {
return this->Alloc(prototype, size);
}
auto begin = free_.lower_bound(size / match_range_);
auto mid = free_.lower_bound(size);
auto end = free_.upper_bound(size * match_range_);
// search for memory blocks larger than requested
for (auto it = mid; it != end; ++it) {
StorageToken *tok = it->second;
if (tok->device_id != prototype->device_id) continue;
CHECK_EQ(tok->ref_counter, 0);
// Use exect matching strategy
tok->max_bytes = std::max(size, tok->max_bytes);
tok->ref_counter = prototype->ref_counter;
// find a exact match, erase from map and return
free_.erase(it);
return tok;
}
// then search for memory blocks smaller than requested space
for (auto it = mid; it != begin;) {
--it;
StorageToken *tok = it->second;
if (tok->device_id != prototype->device_id) continue;
CHECK_EQ(tok->ref_counter, 0);
// Use exect matching strategy
tok->max_bytes = std::max(size, tok->max_bytes);
tok->ref_counter = prototype->ref_counter;
// erase from map and return
free_.erase(it);
return tok;
}
// cannot find anything return a new one.
return this->Alloc(prototype, size);
}
/*!
* \brief Allocate a storage token by consuming prototype
* \param prototype The prototype token.
* \param size The size of memory being requested.
*/
StorageToken* Alloc(StorageToken* prototype, size_t size) {
prototype->max_bytes = size;
prototype->storage_id = static_cast<int64_t>(data_.size());
data_.push_back(prototype);
return prototype;
}
/*!
* \brief Check if we can release token.
* \tok The token to be released.
*/
void CheckForRelease(StorageToken* tok) {
CHECK_GE(tok->storage_id, 0);
CHECK_GE(tok->ref_counter, 0);
if (tok->ref_counter == 0) {
free_.insert({tok->max_bytes, tok});
}
}
private:
// allocator
common::Arena arena_;
// scale used for rough match
size_t match_range_{16};
// free list of storage entry
std::multimap<size_t, StorageToken*> free_;
// all the storage resources available
std::vector<StorageToken*> data_;
/*! \brief internal prototype token map */
std::unordered_map<const ExprNode*, std::vector<StorageToken*> > prototype_;
};
Map<Expr, Array<Integer> > GraphPlanMemory(const Function& func) {
return StorageAllocator().Plan(func);
}
TVM_REGISTER_GLOBAL("relay.backend.GraphPlanMemory")
.set_body_typed<Map<Expr, Array<Integer> >(const Function&)>(GraphPlanMemory);
} // namespace relay
} // namespace tvm
...@@ -113,6 +113,11 @@ class TextMetaDataContext { ...@@ -113,6 +113,11 @@ class TextMetaDataContext {
return SaveJSON(Array<NodeRef>(meta_data_)); return SaveJSON(Array<NodeRef>(meta_data_));
} }
/*! \return whether the meta data context is empty. */
bool empty() const {
return meta_data_.empty();
}
private: private:
/*! \brief additional metadata stored in TVM json format */ /*! \brief additional metadata stored in TVM json format */
std::vector<NodeRef> meta_data_; std::vector<NodeRef> meta_data_;
...@@ -125,8 +130,9 @@ class TextPrinter : ...@@ -125,8 +130,9 @@ class TextPrinter :
public TypeFunctor<void (const Type&, std::ostream& os)>, // NOLINT(*) public TypeFunctor<void (const Type&, std::ostream& os)>, // NOLINT(*)
public AttrFunctor<void (const NodeRef&, std::ostream& os)> { // NOLINT(*) public AttrFunctor<void (const NodeRef&, std::ostream& os)> { // NOLINT(*)
public: public:
explicit TextPrinter(runtime::TypedPackedFunc<std::string(Expr)> annotate) explicit TextPrinter(bool show_meta_data,
: annotate_(annotate) {} runtime::TypedPackedFunc<std::string(Expr)> annotate)
: show_meta_data_(show_meta_data), annotate_(annotate) {}
/*! /*!
* \brief Print a node to string. * \brief Print a node to string.
* \param node. * \param node.
...@@ -144,13 +150,17 @@ class TextPrinter : ...@@ -144,13 +150,17 @@ class TextPrinter :
} else { } else {
stream_ << node; stream_ << node;
} }
if (!meta_.empty()) {
if (show_meta_data_) {
std::string meta_json = meta_.GetMetaSection(); std::string meta_json = meta_.GetMetaSection();
if (meta_json.length() != 0) {
// append meta data in the end. // append meta data in the end.
stream_ << "# meta data\n" stream_ << "# meta data\n"
<< "r\"\"\"\n" << "r\"\"\"\n"
<< meta_json << "\n" << meta_json << "\n"
<< "\"\"\""; << "\"\"\"";
} else {
stream_ << "# meta data omitted. you can use show_meta_data=True to include meta-data\n";
}
} }
return stream_.str(); return stream_.str();
} }
...@@ -227,7 +237,9 @@ class TextPrinter : ...@@ -227,7 +237,9 @@ class TextPrinter :
TextValue id = this->AllocTempVar(); TextValue id = this->AllocTempVar();
this->PrintIndent(); this->PrintIndent();
stream_ << id << " = " << meta_.GetMetaNode(GetRef<NodeRef>(op)); stream_ << id << " = " << meta_.GetMetaNode(GetRef<NodeRef>(op));
this->PrintEndInst("\n"); this->PrintEndInst("");
this->PrintOptionalInfo(GetRef<Expr>(op));
stream_ << '\n';
return id; return id;
} }
...@@ -697,6 +709,8 @@ class TextPrinter : ...@@ -697,6 +709,8 @@ class TextPrinter :
private: private:
class AttrPrinter; class AttrPrinter;
friend class AttrPrinter; friend class AttrPrinter;
/*! \brief Whether to print meta data. */
bool show_meta_data_;
/*! \brief additional comment function */ /*! \brief additional comment function */
runtime::TypedPackedFunc<std::string(Expr)> annotate_; runtime::TypedPackedFunc<std::string(Expr)> annotate_;
/*! \brief meta data context */ /*! \brief meta data context */
...@@ -790,13 +804,14 @@ void TextPrinter::PrintCallAttrs(const Expr& op, ...@@ -790,13 +804,14 @@ void TextPrinter::PrintCallAttrs(const Expr& op,
} }
std::string RelayPrint(const NodeRef& node, std::string RelayPrint(const NodeRef& node,
bool show_meta_data,
runtime::TypedPackedFunc<std::string(Expr)> annotate) { runtime::TypedPackedFunc<std::string(Expr)> annotate) {
return TextPrinter(annotate).Print(node); return TextPrinter(show_meta_data, annotate).Print(node);
} }
TVM_REGISTER_API("relay._expr.RelayPrint") TVM_REGISTER_API("relay._expr.RelayPrint")
.set_body_typed<std::string( .set_body_typed<std::string(
const NodeRef&, const NodeRef&, bool,
runtime::TypedPackedFunc<std::string(Expr)>)>(RelayPrint); runtime::TypedPackedFunc<std::string(Expr)>)>(RelayPrint);
} // namespace relay } // namespace relay
......
...@@ -749,7 +749,7 @@ class FuseMutator : private ExprMutator { ...@@ -749,7 +749,7 @@ class FuseMutator : private ExprMutator {
} }
// Debug function, dump the group assignment in text. // Debug function, dump the group assignment in text.
void DebugDumpGroup(const Expr& body) { void DebugDumpGroup(const Expr& body) {
std::string text = RelayPrint(body, [this](const Expr& expr) -> std::string { std::string text = RelayPrint(body, false, [this](const Expr& expr) -> std::string {
auto it = gmap_.find(expr.get()); auto it = gmap_.find(expr.get());
if (it == gmap_.end()) return ""; if (it == gmap_.end()) return "";
std::ostringstream os; std::ostringstream os;
......
...@@ -77,7 +77,9 @@ def test_add_op_broadcast(): ...@@ -77,7 +77,9 @@ def test_add_op_broadcast():
def test_with_params(): def test_with_params():
x = relay.var('x', shape=(10, 5)) x = relay.var('x', shape=(10, 5))
y = relay.var('y', shape=(1, 5)) y = relay.var('y', shape=(1, 5))
func = relay.Function([x, y], add(x, y)) z = relay.add(x, y)
z = relay.exp(z)
func = relay.Function([x, y], z)
x_data = np.random.rand(10, 5).astype('float32') x_data = np.random.rand(10, 5).astype('float32')
y_data = np.random.rand(1, 5).astype('float32') y_data = np.random.rand(1, 5).astype('float32')
params = {"y": y_data} params = {"y": y_data}
...@@ -87,11 +89,40 @@ def test_with_params(): ...@@ -87,11 +89,40 @@ def test_with_params():
mod.set_input(x=x_data) mod.set_input(x=x_data)
mod.run() mod.run()
res = mod.get_output(0).asnumpy() res = mod.get_output(0).asnumpy()
ref_res = y_data + x_data ref_res = np.exp(y_data + x_data)
tvm.testing.assert_allclose(res, ref_res) tvm.testing.assert_allclose(res, ref_res)
def test_plan_memory():
# it is sufficient to cycle through two memories.
x = relay.var("x", shape=(10,))
y = relay.var("x", shape=(1,))
y2 = relay.exp(y)
z = relay.add(x, y2)
z = relay.exp(z)
z = relay.exp(z)
z = relay.exp(z)
z = relay.exp(z)
z = relay.exp(z)
func = relay.Function([x, y], z)
func = relay.ir_pass.infer_type(func)
func = relay.ir_pass.fuse_ops(func, opt_level=0)
func = relay.ir_pass.infer_type(func)
smap = relay.backend._backend.GraphPlanMemory(func)
storage_ids = set()
for k, v in smap.items():
for x in v:
storage_ids.add(x.value)
# Current rule requires vars have unique storage id
# because we don't do inplace, we will need another
# two alternating temporary space.
assert len(storage_ids) == 4
if __name__ == "__main__": if __name__ == "__main__":
test_plan_memory()
test_with_params() test_with_params()
test_add_op_scalar() test_add_op_scalar()
test_add_op_tensor() test_add_op_tensor()
......
Markdown is supported
0% or
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment