Commit dc97e527 by Josh Pollock Committed by Tianqi Chen

[Relay][Text Format] Pretty Printer Smart Inlining (#2881)

parent fefbb006
......@@ -925,39 +925,6 @@ def eliminate_common_subexpr(expr, fskip=None):
"""
return _ir_pass.eliminate_common_subexpr(expr, fskip)
def pass_debug_print(ast, show_meta_data=True, annotate=None, gnf=True):
"""
THIS SHOULD BE USED ONLY FOR DEBUGGING, NOT AS AN INTERCHANGE FORMAT!
USE `.astext()` INSTEAD!
A version of the pretty printer intended for debugging passes. Contains
advanced printing options.
Parameters
----------
ast : Union[relay.Expr, relay.Module, relay.Type]
The relay fragment to be turned into text.
show_meta_data : bool
Whether to include meta data section in the text
if there is meta data.
annotate: Optional[relay.Expr->str]
Optional annotate function to provide additional
information in the comment block.
gnf : bool
Whether to print in GNF. If it is disabled, pointers are left implicit.
Returns
-------
text : str
A text representation of `ast`.
"""
return _ir_pass.pass_debug_print(ast, show_meta_data, annotate, gnf)
def partial_evaluate(expr):
"""
Evaluate the static fragment of the code.
......
......@@ -22,12 +22,22 @@
* \file pretty_printer.cc
* \brief Pretty printer for Relay programs
* Supports ANF, GNF, and metadata.
*
* Inlining heuristics:
* - Always inline:
* - GlobalVar
* - Constant
* - Op
* - Var
* - Otherwise, inline if the node is at the end of a scope and is used at most once.
*/
#include <tvm/relay/expr_functor.h>
#include <tvm/relay/module.h>
#include <tvm/relay/pattern_functor.h>
#include "doc.h"
#include "type_functor.h"
#include "../pass/dependency_graph.h"
#include "../../lang/attr_functor.h"
namespace tvm {
......@@ -135,10 +145,8 @@ class PrettyPrinter :
public TypeFunctor<Doc(const Type&)>,
public AttrFunctor<Doc(const NodeRef&)> {
public:
explicit PrettyPrinter(bool GNF,
bool show_meta_data,
explicit PrettyPrinter(bool show_meta_data,
runtime::TypedPackedFunc<std::string(Expr)> annotate) :
GNF_(GNF),
show_meta_data_(show_meta_data),
annotate_(annotate) {}
......@@ -150,10 +158,9 @@ class PrettyPrinter :
Doc doc;
// additional information in comment.
if (annotate_ != nullptr) {
return doc << " // " << annotate_(expr);
return doc << " /* " << annotate_(expr) << " */";
} else if (expr->checked_type_.defined()) {
doc << " // ty=";
return doc << Print(expr->checked_type());
return doc << " /* ty=" << Print(expr->checked_type()) << " */";
} else {
return doc;
}
......@@ -176,13 +183,18 @@ class PrettyPrinter :
// print in a new scope
doc_stack_.push_back(Doc());
// must print first so doc_stack_.back() reference doesn't become stale
Doc doc = Print(node);
Doc doc = Print(node, false, true);
doc = doc_stack_.back() << doc;
doc_stack_.pop_back();
return doc;
}
Doc PrintFinal(const NodeRef& node) {
if (node.as_derived<ExprNode>()) {
Expr expr = Downcast<Expr>(node);
dg_ = DependencyGraph::Create(&arena_, expr);
}
Doc doc;
doc << PrintScope(node);
if (!meta_.empty()) {
......@@ -200,9 +212,9 @@ class PrettyPrinter :
Doc PrintAttrs(const Attrs& attrs, const Expr& op);
Doc Print(const NodeRef& node, bool meta = false) {
Doc Print(const NodeRef& node, bool meta = false, bool try_inline = false) {
if (node.as_derived<ExprNode>()) {
return PrintExpr(Downcast<Expr>(node), meta);
return PrintExpr(Downcast<Expr>(node), meta, try_inline);
} else if (node.as_derived<TypeNode>()) {
return PrintType(Downcast<Type>(node), meta);
} else if (node.as_derived<ModuleNode>()) {
......@@ -308,7 +320,12 @@ class PrettyPrinter :
return val;
}
inline bool IsAtomicExpr(const Expr& expr) {
bool IsUnique(const Expr& expr) {
return !(dg_.expr_node.at(expr)->parents.head &&
dg_.expr_node.at(expr)->parents.head->next);
}
bool AlwaysInline(const Expr& expr) {
return expr.as<GlobalVarNode>() || expr.as<ConstantNode>() ||
expr.as<OpNode>() || expr.as<VarNode>();
}
......@@ -316,17 +333,25 @@ class PrettyPrinter :
//------------------------------------
// Overload of Expr printing functions
//------------------------------------
Doc PrintExpr(const Expr& expr, bool meta) {
Doc PrintExpr(const Expr& expr, bool meta, bool try_inline) {
// Exploit memoization to print GNF.
// The first time we visit an expression, we need to allocate a temp var
// for it. Every subsequent time we can just use its assigned variable.
// This works since hashing uses pointer equality.
// determine whether to inline
bool inline_expr = AlwaysInline(expr);
if (try_inline) {
inline_expr |= IsUnique(expr);
}
auto it = memo_.find(expr);
if (it != memo_.end()) return it->second;
Doc printed_expr;
if (meta) {
printed_expr = meta_.GetMetaNode(GetRef<NodeRef>(expr.get()));
} else if (GNF_ && expr.as<LetNode>()) {
} else if (!inline_expr && expr.as<LetNode>()) {
// wrap GNFed let in brackets
Doc body;
printed_expr << "{";
......@@ -335,28 +360,26 @@ class PrettyPrinter :
} else {
printed_expr = VisitExpr(expr);
}
// we choose to inline atomic exprs
if (GNF_ && !IsAtomicExpr(expr)) {
Doc temp_var = AllocTemp();
memo_[expr] = temp_var;
doc_stack_.back() << temp_var << " = " << printed_expr;
if (expr.as<CallNode>()) {
doc_stack_.back() << PrintOptionalInfo(expr);
printed_expr << PrintOptionalInfo(expr);
}
doc_stack_.back() << "\n";
return temp_var;
} else if (expr.as<VarNode>()) {
// add expr to doc
if (expr.as<VarNode>()) {
// This is our first time visiting the var and we hit the VarNode case
// in the visitor. Thus the variable is free.
doc_stack_.back() << "free_var " << printed_expr << "\n";
// Memoization is done in AllocVar.
return memo_[expr];
} else {
} else if (inline_expr) {
memo_[expr] = printed_expr;
if (GNF_ && expr.as<CallNode>()) {
printed_expr << PrintOptionalInfo(expr);
}
return printed_expr;
} else {
Doc temp_var = AllocTemp();
memo_[expr] = temp_var;
doc_stack_.back() << temp_var << " = " << printed_expr << "\n";
return temp_var;
}
}
......@@ -420,8 +443,9 @@ class PrettyPrinter :
Doc VisitExpr_(const LetNode* op) final {
Doc doc;
doc << "let " << AllocVar(op->var) << " = " << Print(op->value) << "\n";
doc << "let " << AllocVar(op->var) << " = " << Print(op->value, false, true) << "\n";
// we use a scope here so GNF hoisting doesn't escape too far
// and nested, unique lets are not hoisted
doc << PrintScope(op->body);
return doc;
}
......@@ -456,6 +480,8 @@ class PrettyPrinter :
Doc doc;
int counter = 0;
for (const auto& kv : mod->functions) {
dg_ = DependencyGraph::Create(&arena_, kv.second);
std::ostringstream os;
if (counter++ != 0) {
doc << "\n";
......@@ -664,8 +690,6 @@ class PrettyPrinter :
}
private:
/*! \brief Whether to use GNF. */
bool GNF_;
/*! \brief Whether to print meta data. */
bool show_meta_data_;
/*! \brief additional comment function */
......@@ -682,6 +706,10 @@ class PrettyPrinter :
TextMetaDataContext meta_;
/*! \brief counter of temporary variable */
size_t temp_var_counter_{0};
/*! \brief arena for dependency graph */
common::Arena arena_;
/*! \brief dependency graph of the expr */
DependencyGraph dg_;
class AttrPrinter;
friend class AttrPrinter;
};
......@@ -751,25 +779,17 @@ Doc PrettyPrinter::PrintAttrs(const Attrs& attrs, const Expr& op) {
std::string PrettyPrint_(const NodeRef& node,
bool show_meta_data,
runtime::TypedPackedFunc<std::string(Expr)> annotate,
bool gnf) {
runtime::TypedPackedFunc<std::string(Expr)> annotate) {
Doc doc;
doc << "v0.0.1" << "\n"
<< PrettyPrinter(gnf, show_meta_data, annotate).PrintFinal(node);
<< PrettyPrinter(show_meta_data, annotate).PrintFinal(node);
return doc.str();
}
std::string AsText(const NodeRef& node,
bool show_meta_data,
runtime::TypedPackedFunc<std::string(Expr)> annotate) {
return PrettyPrint_(node, show_meta_data, annotate, true);
}
std::string PassDebugPrint(const NodeRef& node,
bool show_meta_data,
runtime::TypedPackedFunc<std::string(Expr)> annotate,
bool gnf) {
return PrettyPrint_(node, show_meta_data, annotate, gnf);
return PrettyPrint_(node, show_meta_data, annotate);
}
TVM_REGISTER_API("relay._expr.AsText")
......@@ -777,11 +797,5 @@ TVM_REGISTER_API("relay._expr.AsText")
bool,
runtime::TypedPackedFunc<std::string(Expr)>)>(AsText);
TVM_REGISTER_API("relay._ir_pass.pass_debug_print")
.set_body_typed<std::string(const NodeRef&,
bool,
runtime::TypedPackedFunc<std::string(Expr)>,
bool)>(PassDebugPrint);
} // namespace relay
} // namespace tvm
/*!
* Copyright (c) 2019 by Contributors
* \file tvm/relay/pass/dependency_graph.cc
* \brief
*/
#include "dependency_graph.h"
#include <tvm/relay/expr_functor.h>
#include <unordered_set>
#include <utility>
namespace tvm {
namespace relay {
// Creator of DependencyGraph
class DependencyGraph::Creator : private ExprFunctor<void(const Expr& e)> {
public:
explicit Creator(common::Arena* arena)
: arena_(arena) {}
DependencyGraph Create(const Expr& body) {
this->VisitExpr(body);
return std::move(graph_);
}
private:
/*! \brief allocator of all the internal node object */
common::Arena* arena_;
// The output.
DependencyGraph graph_;
// Update the message stored at the node.
void Depend(DependencyGraph::Node* parent, const Expr& child) {
VisitExpr(child);
CHECK_NE(graph_.expr_node.count(child), 0);
Depend(parent, graph_.expr_node[child]);
}
void Depend(DependencyGraph::Node* parent, DependencyGraph::Node* child) {
auto* parent_link = arena_->make<LinkNode<DependencyGraph::Node*> >();
parent_link->value = parent;
child->parents.Push(parent_link);
auto* child_link = arena_->make<LinkNode<DependencyGraph::Node*> >();
child_link->value = child;
parent->children.Push(child_link);
}
std::unordered_set<Expr, NodeHash, NodeEqual> visited_;
DependencyGraph::Node* NewNode(bool new_scope) {
auto* ret = arena_->make<DependencyGraph::Node>();
ret->new_scope = new_scope;
return ret;
}
void VisitExpr(const Expr& e) final {
if (visited_.count(e) == 0) {
if (graph_.expr_node.count(e) == 0) {
graph_.expr_node[e] = NewNode(false);
}
visited_.insert(e);
ExprFunctor<void(const Expr&)>::VisitExpr(e);
graph_.post_dfs_order.push_back(graph_.expr_node[e]);
}
}
void VisitExpr_(const CallNode* c) final {
DependencyGraph::Node* n = graph_.expr_node[GetRef<Expr>(c)];
Depend(n, c->op);
for (const auto& a : c->args) {
Depend(n, a);
}
}
void VisitExpr_(const TupleNode* t) final {
DependencyGraph::Node* n = graph_.expr_node[GetRef<Expr>(t)];
for (const auto& a : t->fields) {
Depend(n, a);
}
}
void VisitExpr_(const TupleGetItemNode* t) final {
DependencyGraph::Node* n = graph_.expr_node[GetRef<Expr>(t)];
Depend(n, t->tuple);
}
void VisitExpr_(const RefCreateNode* r) final {
DependencyGraph::Node* n = graph_.expr_node[GetRef<Expr>(r)];
Depend(n, r->value);
}
void VisitExpr_(const RefReadNode* r) final {
DependencyGraph::Node* n = graph_.expr_node[GetRef<Expr>(r)];
Depend(n, r->ref);
}
void VisitExpr_(const RefWriteNode* r) final {
DependencyGraph::Node* n = graph_.expr_node[GetRef<Expr>(r)];
Depend(n, r->ref);
Depend(n, r->value);
}
void VisitExpr_(const IfNode* i) final {
DependencyGraph::Node* n = graph_.expr_node[GetRef<Expr>(i)];
DependencyGraph::Node* t = NewNode(true);
DependencyGraph::Node* f = NewNode(true);
Depend(n, i->cond);
Depend(n, t);
Depend(n, f);
Depend(t, i->true_branch);
Depend(f, i->false_branch);
graph_.post_dfs_order.push_back(f);
graph_.post_dfs_order.push_back(t);
}
void VisitExpr_(const FunctionNode* f) final {
DependencyGraph::Node* n = graph_.expr_node[GetRef<Expr>(f)];
DependencyGraph::Node* b = NewNode(true);
Depend(n, b);
Depend(b, f->body);
graph_.post_dfs_order.push_back(b);
}
void VisitExpr_(const LetNode* l) final {
DependencyGraph::Node* n = graph_.expr_node[GetRef<Expr>(l)];
DependencyGraph::Node* b = NewNode(true);
Depend(n, b);
Depend(b, l->value);
Depend(b, l->body);
graph_.post_dfs_order.push_back(b);
}
void VisitExpr_(const MatchNode* m) final {
DependencyGraph::Node* n = graph_.expr_node[GetRef<Expr>(m)];
Depend(n, m->data);
std::vector<DependencyGraph::Node*> v;
for (const Clause& c : m->clauses) {
DependencyGraph::Node* b = NewNode(true);
Depend(n, b);
Depend(b, c->rhs);
v.push_back(b);
}
for (auto it = v.rbegin(); it != v.rend(); ++it) {
graph_.post_dfs_order.push_back(*it);
}
}
void VisitExpr_(const VarNode* v) final { }
void VisitExpr_(const GlobalVarNode* v) final { }
void VisitExpr_(const ConstantNode* c) final { }
void VisitExpr_(const OpNode* o) final { }
void VisitExpr_(const ConstructorNode* c) final { }
};
DependencyGraph DependencyGraph::Create(common::Arena* arena, const Expr& body) {
return Creator(arena).Create(body);
}
} // namespace relay
} // namespace tvm
/*!
* Copyright (c) 2019 by Contributors.
* \file tvm/relay/pass/dependency_graph.h
* \brief
*/
#ifndef TVM_RELAY_PASS_DEPENDENCY_GRAPH_H_
#define TVM_RELAY_PASS_DEPENDENCY_GRAPH_H_
#include <tvm/relay/expr.h>
#include <unordered_map>
#include <vector>
#include "let_list.h"
#include "../../common/arena.h"
namespace tvm {
namespace relay {
using common::LinkNode;
using common::LinkedList;
/* DependencyGraph track input and output of an Expr.
* Additionally, dummy scope is created to model scope.
* It allow us to traverse the graph in reverse order.
*/
class DependencyGraph {
public:
/*! \brief A node in the graph. */
struct Node {
// Determine scope boundaries. Used for calculating scopes, not for
// constructing dependency graph.
bool new_scope = false;
// incoming edges
LinkedList<Node*> children;
// outgoing edges
LinkedList<Node*> parents;
};
/*! \brief Maps a Relay Expr to its node in the dependency graph. */
std::unordered_map<Expr, Node*, NodeHash, NodeEqual> expr_node;
/*! \brief The dependency graph in post DFS order. */
std::vector<Node*> post_dfs_order;
/*!
* \brief Create a dependency graph.
* \param arena The arena used for data allocation.
* \param body The body of the expression to create a graph.
*/
static DependencyGraph Create(common::Arena* arena, const Expr& body);
private:
class Creator;
};
} // namespace relay
} // namespace tvm
#endif // TVM_RELAY_PASS_DEPENDENCY_GRAPH_H_
......@@ -29,193 +29,11 @@
#include "let_list.h"
#include "../../common/arena.h"
#include "pass_util.h"
#include "dependency_graph.h"
namespace tvm {
namespace relay {
using common::LinkNode;
using common::LinkedList;
/* DependencyGraph track input and output of an Expr.
* Additionally, dummy scope is created to model scope.
* It allow us to traverse the graph in reverse order.
*/
class DependencyGraph {
public:
/*! \brief A node in the graph. */
struct Node {
bool new_scope = false;
LinkedList<Node*> input;
LinkedList<Node*> output;
};
/*! \brief The node map that maps node to graph */
std::unordered_map<Expr, Node*, NodeHash, NodeEqual> expr_node;
/*! \brief All the nodes in post DFS order */
std::vector<Node*> post_dfs_order;
/*!
* \brief create a dependency graph.
* \param arena The arena used for data allocation.
* \param body The body of the expression to create a graph.
*/
static DependencyGraph Create(common::Arena* arena, const Expr& body);
private:
class Creator;
};
// Creator of DependencyGraph
class DependencyGraph::Creator : private ExprFunctor<void(const Expr& e)> {
public:
explicit Creator(common::Arena* arena)
: arena_(arena) {}
DependencyGraph Create(const Expr& body) {
this->VisitExpr(body);
return std::move(graph_);
}
private:
/*! \brief allocator of all the internal node object */
common::Arena* arena_;
// The output.
DependencyGraph graph_;
// Update the message stored at the node.
void Depend(DependencyGraph::Node* parent, const Expr& child) {
VisitExpr(child);
CHECK_NE(graph_.expr_node.count(child), 0);
Depend(parent, graph_.expr_node[child]);
}
void Depend(DependencyGraph::Node* parent, DependencyGraph::Node* child) {
auto* parent_link = arena_->make<LinkNode<DependencyGraph::Node*> >();
parent_link->value = parent;
child->output.Push(parent_link);
auto* child_link = arena_->make<LinkNode<DependencyGraph::Node*> >();
child_link->value = child;
parent->input.Push(child_link);
}
std::unordered_set<Expr, NodeHash, NodeEqual> visited_;
DependencyGraph::Node* NewNode(bool new_scope) {
auto* ret = arena_->make<DependencyGraph::Node>();
ret->new_scope = new_scope;
return ret;
}
void VisitExpr(const Expr& e) final {
if (visited_.count(e) == 0) {
if (graph_.expr_node.count(e) == 0) {
graph_.expr_node[e] = NewNode(false);
}
visited_.insert(e);
ExprFunctor<void(const Expr&)>::VisitExpr(e);
graph_.post_dfs_order.push_back(graph_.expr_node[e]);
}
}
void VisitExpr_(const CallNode* c) final {
DependencyGraph::Node* n = graph_.expr_node[GetRef<Expr>(c)];
Depend(n, c->op);
for (const auto& a : c->args) {
Depend(n, a);
}
}
void VisitExpr_(const TupleNode* t) final {
DependencyGraph::Node* n = graph_.expr_node[GetRef<Expr>(t)];
for (const auto& a : t->fields) {
Depend(n, a);
}
}
void VisitExpr_(const TupleGetItemNode* t) final {
DependencyGraph::Node* n = graph_.expr_node[GetRef<Expr>(t)];
Depend(n, t->tuple);
}
void VisitExpr_(const RefCreateNode* r) final {
DependencyGraph::Node* n = graph_.expr_node[GetRef<Expr>(r)];
Depend(n, r->value);
}
void VisitExpr_(const RefReadNode* r) final {
DependencyGraph::Node* n = graph_.expr_node[GetRef<Expr>(r)];
Depend(n, r->ref);
}
void VisitExpr_(const RefWriteNode* r) final {
DependencyGraph::Node* n = graph_.expr_node[GetRef<Expr>(r)];
Depend(n, r->ref);
Depend(n, r->value);
}
void VisitExpr_(const IfNode* i) final {
DependencyGraph::Node* n = graph_.expr_node[GetRef<Expr>(i)];
DependencyGraph::Node* t = NewNode(true);
DependencyGraph::Node* f = NewNode(true);
Depend(n, i->cond);
Depend(n, t);
Depend(n, f);
Depend(t, i->true_branch);
Depend(f, i->false_branch);
graph_.post_dfs_order.push_back(f);
graph_.post_dfs_order.push_back(t);
}
void VisitExpr_(const FunctionNode* f) final {
DependencyGraph::Node* n = graph_.expr_node[GetRef<Expr>(f)];
DependencyGraph::Node* b = NewNode(true);
Depend(n, b);
Depend(b, f->body);
graph_.post_dfs_order.push_back(b);
}
void VisitExpr_(const LetNode* l) final {
DependencyGraph::Node* n = graph_.expr_node[GetRef<Expr>(l)];
DependencyGraph::Node* b = NewNode(true);
Depend(n, b);
Depend(b, l->value);
Depend(b, l->body);
graph_.post_dfs_order.push_back(b);
}
void VisitExpr_(const MatchNode* m) final {
DependencyGraph::Node* n = graph_.expr_node[GetRef<Expr>(m)];
Depend(n, m->data);
std::vector<DependencyGraph::Node*> v;
for (const Clause& c : m->clauses) {
DependencyGraph::Node* b = NewNode(true);
Depend(n, b);
Depend(b, c->rhs);
v.push_back(b);
}
for (auto it = v.rbegin(); it != v.rend(); ++it) {
graph_.post_dfs_order.push_back(*it);
}
}
void VisitExpr_(const VarNode* v) final { }
void VisitExpr_(const GlobalVarNode* v) final { }
void VisitExpr_(const ConstantNode* c) final { }
void VisitExpr_(const OpNode* o) final { }
void VisitExpr_(const ConstructorNode* c) final { }
};
DependencyGraph DependencyGraph::Create(common::Arena* arena, const Expr& body) {
return Creator(arena).Create(body);
}
Expr ToANormalForm(const Expr& e, const Module& m, std::set<GlobalVar>* gv);
struct ScopeNode;
......@@ -256,7 +74,7 @@ std::unordered_map<DependencyGraph::Node*, Scope> CalcScope(const DependencyGrap
Scope global_scope = std::make_shared<ScopeNode>();
for (auto it = dg.post_dfs_order.rbegin(); it != dg.post_dfs_order.rend(); ++it) {
DependencyGraph::Node* n = *it;
auto iit = n->output.head;
auto iit = n->parents.head;
Scope s;
if (iit == nullptr) {
s = global_scope;
......@@ -313,7 +131,7 @@ class Fill : ExprFunctor<Expr(const Expr&, const Var&)> {
Scope GetSubScope(const Expr& e, size_t i) {
DependencyGraph::Node* n = dg_.expr_node.at(e);
auto h = n->input.head;
auto h = n->children.head;
while (i != 0) {
CHECK(h);
--i;
......
......@@ -50,8 +50,8 @@ def test_env():
text = env.astext()
assert "def @myf" in text
assert "def @myf" in str(env)
assert "%1 = add(%0, %0) // ty=float32" in text
assert "%1 = add(%0, %0) // ty=float32" in str(env)
assert "add(%0, %0) /* ty=float32 */" in text
assert "add(%0, %0) /* ty=float32 */" in str(env)
show(env.astext(annotate=lambda x: str(x.checked_type.dtype)))
show(text)
......@@ -112,7 +112,7 @@ def test_let_if_scope():
f = relay.Function([x, y, cond], result)
text = f.astext()
assert text.count("{") == 6
assert text.count("{") == 4
assert "%cond: bool" in text
show(f.astext())
......@@ -180,8 +180,19 @@ def test_call_node_order():
"%2 = fn (%x) {\n"
" %x\n"
"}\n"
"%3 = %2(%1)\n"
"%3")
"%2(%1)")
def test_let_inlining():
tup = relay.Tuple([relay.const(0), relay.const(0)])
x = relay.var("x")
assert relay.Let(x, tup, tup).astext() == SEMVER + \
("%0 = (0, 0)\n"
"let %x = %0\n"
"%0")
assert relay.Let(x, tup, x).astext() == SEMVER + \
("let %x = (0, 0)\n"
"%x")
if __name__ == "__main__":
do_print[0] = True
......@@ -201,3 +212,4 @@ if __name__ == "__main__":
test_let_if_scope()
test_variable_name()
test_call_node_order()
test_let_inlining()
......@@ -38,7 +38,7 @@ def test_unary_op():
x = relay.var("x", tp)
y = opfunc(x)
# test printer
assert ("%0 = {}(%x)".format(y.op.name)) in y.astext()
assert ("{}(%x)".format(y.op.name)) in y.astext()
# test type inference
assert relay.ir_pass.infer_type(y).checked_type == tp
......@@ -78,7 +78,7 @@ def test_binary_op():
y = relay.var("y", t2)
z = opfunc(x, y)
# test printer
assert ("%0 = {}(%x, %y)".format(z.op.name)) in z.astext()
assert ("{}(%x, %y)".format(z.op.name)) in z.astext()
assert relay.ir_pass.infer_type(z).checked_type == t1
if ref is not None:
......
......@@ -29,7 +29,7 @@ def test_binary_op():
y = relay.var("y", t2)
z = opfunc(x, y)
# test printer
assert ("%0 = {}(%x, %y)".format(z.op.name)) in z.astext()
assert ("{}(%x, %y)".format(z.op.name)) in z.astext()
assert relay.ir_pass.infer_type(z).checked_type == t1
if ref is not None:
......
......@@ -44,7 +44,7 @@ def initialize_box_adt(mod):
def test_monomorphic_let():
"Program: let x = 1; x"
"Program: let %x = 1; %x"
sb = relay.ScopeBuilder()
x = sb.let('x', relay.const(1.0, "float64"))
sb.ret(x)
......@@ -53,7 +53,7 @@ def test_monomorphic_let():
def test_single_op():
"Program: fn (x : float32) { let t1 = f(x); t1 }"
"Program: fn (%x : float32) { let %t1 = f(%x); %t1 }"
x = relay.var('x', shape=[])
func = relay.Function([x], op.log(x))
ttype = relay.TensorType([], dtype='float32')
......@@ -63,8 +63,9 @@ def test_single_op():
def test_add_broadcast_op():
"""
Program:
fn (x: Tensor[(10, 4), f32], y: Tensor[(5, 10, 1), f32]) -> Tensor[(5, 10, 4), f32] {
x + y
fn (%x: Tensor[(10, 4), float32], %y: Tensor[(5, 10, 1), float32])
-> Tensor[(5, 10, 4), float32] {
%x + %y
}
"""
x = relay.var('x', shape=(10, 4))
......@@ -80,10 +81,10 @@ def test_add_broadcast_op():
def test_dual_op():
"""Program:
fn (x : Tensor[f32, (10, 10)]) {
let t1 = log(x);
let t2 = add(t1, x);
t1
fn (%x : Tensor[(10, 10), float32]) {
let %t1 = log(x);
let %t2 = add(%t1, %x);
%t1
}
"""
tp = relay.TensorType((10, 10), "float32")
......@@ -99,8 +100,8 @@ def test_dual_op():
def test_decl():
"""Program:
def f(x : Tensor[(10, 10), f32]) {
log(x)
def @f(%x : Tensor[(10, 10), float32]) {
log(%x)
}
"""
tp = relay.TensorType((10, 10))
......@@ -113,11 +114,11 @@ def test_decl():
def test_recursion():
"""
Program:
def f(n: i32, data: f32) -> f32 {
if (n == 0) {
data
def @f(%n: int32, %data: float32) -> float32 {
if (%n == 0) {
%data
} else {
f(n - 1, log(data))
@f(%n - 1, log(%data))
}
}
"""
......@@ -134,7 +135,7 @@ def test_recursion():
sb.ret(f(relay.subtract(n, relay.const(1, ti32)), relay.log(data)))
mod = relay.Module()
mod[f] = relay.Function([n, data], sb.get())
assert "%3 = @f(%1, %2)" in mod.astext()
assert "@f(%1, %2) /* ty=float32 */" in mod.astext()
assert mod[f].checked_type == relay.FuncType([ti32, tf32], tf32)
......
Markdown is supported
0% or
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment