Commit 06108bed by 雾雨魔理沙 Committed by Tianqi Chen

[RELAY] First pass at pretty printer (#1749)

parent 32af4d28
...@@ -158,6 +158,8 @@ class RelayNode : public Node { ...@@ -158,6 +158,8 @@ class RelayNode : public Node {
TVM_DECLARE_BASE_NODE_INFO(RelayNode, Node); TVM_DECLARE_BASE_NODE_INFO(RelayNode, Node);
}; };
struct Environment;
} // namespace relay } // namespace relay
} // namespace tvm } // namespace tvm
......
...@@ -376,6 +376,14 @@ class IfNode : public ExprNode { ...@@ -376,6 +376,14 @@ class IfNode : public ExprNode {
RELAY_DEFINE_NODE_REF(If, IfNode, Expr); RELAY_DEFINE_NODE_REF(If, IfNode, Expr);
/*! \brief Print a debug representation of the expression to the stream.
* \param env The environment.
* \param e The expression
* \param os the stream
* \returns A reference to the stream.
*/
std::ostream& DebugPrint(const Environment& env, const Expr& e, std::ostream& os);
} // namespace relay } // namespace relay
} // namespace tvm } // namespace tvm
#endif // TVM_RELAY_EXPR_H_ #endif // TVM_RELAY_EXPR_H_
...@@ -11,6 +11,7 @@ ...@@ -11,6 +11,7 @@
#include <string> #include <string>
#include "./expr.h" #include "./expr.h"
#include "./op.h" #include "./op.h"
#include "./error.h"
namespace tvm { namespace tvm {
namespace relay { namespace relay {
...@@ -89,7 +90,7 @@ class ExprFunctor<R(const Expr& n, Args...)> { ...@@ -89,7 +90,7 @@ class ExprFunctor<R(const Expr& n, Args...)> {
virtual R VisitExpr_(const OpNode* op, virtual R VisitExpr_(const OpNode* op,
Args... args) EXPR_FUNCTOR_DEFAULT; Args... args) EXPR_FUNCTOR_DEFAULT;
virtual R VisitExprDefault_(const Node* op, Args...) { virtual R VisitExprDefault_(const Node* op, Args...) {
throw dmlc::Error(std::string("Do not have a default for ") + op->type_key()); throw Error(std::string("Do not have a default for ") + op->type_key());
} }
private: private:
......
...@@ -365,6 +365,14 @@ class TypeRelationNode : public TypeConstraintNode { ...@@ -365,6 +365,14 @@ class TypeRelationNode : public TypeConstraintNode {
RELAY_DEFINE_NODE_REF(TypeRelation, TypeRelationNode, TypeConstraint); RELAY_DEFINE_NODE_REF(TypeRelation, TypeRelationNode, TypeConstraint);
/*! \brief Print a debug representation of the type to the stream.
* \param env The environment.
* \param t The type
* \param os the stream
* \returns A reference to the stream.
*/
std::ostream& DebugPrint(const Environment& env, const Type& t, std::ostream& os);
// The following fields contains advanced typing // The following fields contains advanced typing
// Only keep the class name and reserved for future usage. // Only keep the class name and reserved for future usage.
class GenericTensorType; class GenericTensorType;
......
...@@ -33,4 +33,3 @@ Function = expr.Function ...@@ -33,4 +33,3 @@ Function = expr.Function
Call = expr.Call Call = expr.Call
Let = expr.Let Let = expr.Let
If = expr.If If = expr.If
Var = Var
# pylint: disable=no-else-return, unidiomatic-typecheck, undefined-variable
"""The interface of expr function exposed from C++."""
from tvm._ffi.function import _init_api
_init_api("relay._expr", __name__)
...@@ -2,6 +2,7 @@ ...@@ -2,6 +2,7 @@
"""The expression nodes of Relay.""" """The expression nodes of Relay."""
from __future__ import absolute_import from __future__ import absolute_import
from .base import NodeBase, register_relay_node from .base import NodeBase, register_relay_node
from . import _expr
from . import _make from . import _make
from .. import convert from .. import convert
...@@ -115,3 +116,5 @@ class If(Expr): ...@@ -115,3 +116,5 @@ class If(Expr):
def __init__(self, cond, true_value, false_value): def __init__(self, cond, true_value, false_value):
self.__init_handle_by_constructor__( self.__init_handle_by_constructor__(
_make.If, cond, true_value, false_value) _make.If, cond, true_value, false_value)
debug_print = _expr._debug_print
...@@ -111,4 +111,4 @@ class If(Expr): ...@@ -111,4 +111,4 @@ class If(Expr):
def __init__(self, cond, true_value, false_value): def __init__(self, cond, true_value, false_value):
# type: (Expr, Expr, Expr) -> None # type: (Expr, Expr, Expr) -> None
... ...
\ No newline at end of file
/*!
* Copyright (c) 2018 by Contributors
* \file debug_printer.cc
* \brief A pretty printer for the Relay IR.
* As we had not determined a formal syntax yet, right now it is only for debug purpose.
*/
#include <tvm/relay/expr_functor.h>
#include <tvm/relay/environment.h>
#include <tvm/relay/error.h>
#include <iostream>
#include <sstream>
#include <vector>
#include <unordered_map>
#include <string>
#include <vector>
#include <iostream>
#include "../pass/type_functor.h"
#include "doc.h"
namespace tvm {
namespace relay {
using namespace tvm::runtime;
Doc KindDocify(TypeParamNode::Kind k) {
switch (k) {
case TypeParamNode::kShapeVar:
return DocOfStr("ShapeVar");
case TypeParamNode::kShape:
return DocOfStr("Shape");
case TypeParamNode::kBaseType:
return DocOfStr("BaseType");
case TypeParamNode::kType:
return DocOfStr("Type");
default:
LOG(FATAL) << "unreachable code: case not handle in kind";
throw; // log fatal throw but compiler doesnt know
}
}
template<typename T>
std::vector<Doc> MapDocify(const tvm::Array<T>& arr, const std::function<Doc(const T&)>& f) {
std::vector<Doc> vec;
for (size_t i = 0; i < arr.size(); ++i) {
vec.push_back(f(arr[i]));
}
return vec;
}
template<typename T, typename Hash = std::hash<T>, typename Eq = std::equal_to<T>>
class Counter {
std::unordered_map<T, size_t, Hash, Eq> cnt_;
public:
Counter() = default;
Counter(const Counter&) = delete;
size_t operator()(const T& t) {
auto v = cnt_.count(t) == 0 ? 0 : cnt_.at(t) + 1;
cnt_[t] = v;
return v;
}
};
std::string Mangle(const std::string& str, size_t s) {
return str + "_" + std::to_string(s);
// return s == 0 ? str : str + "_" + std::to_string(s - 1);
// the above line look prettier but is dangerous:
// suppose we have x, x, x_0. mangling will give x, x_0, x_0!
// the save approach give x_0, x_1, x_0_1, and in fact never clash:
// stripping _([0-9]*) is invert of mangle under all circumstances.
// another problem is we need to prevent Var/TypeParam/GlobalVar clashing each other.
}
constexpr size_t indent = 2;
struct TypeParamName {
bool operator==(const TypeParamName&) const {
return true;
}
};
struct mhash {
size_t operator()(const ::tvm::relay::TypeParamName&) const noexcept {
return 0;
}
};
class TypeDocifier : private TypeFunctor<Doc(const Type& n)> {
Environment env;
Counter<TypeParamName, mhash> cnt;
std::unordered_map<TypeParam, Doc, NodeHash, NodeEqual> map;
std::vector<Doc> DocifyTypeArray(const tvm::Array<Type>& arr) {
return MapDocify<Type>(arr, [=](const Type& t) { return Docify(t); });
}
std::vector<Doc> DocifyTypeParam(const tvm::Array<TypeParam>& arr) {
return MapDocify<TypeParam>(arr, [=](const TypeParam& tp) { return Docify(tp); });
}
std::vector<Doc> DocifyTypeConstraint(const tvm::Array<TypeConstraint>& arr) {
return MapDocify<TypeConstraint>(arr, [=](const TypeConstraint& tc) { return Docify(tc); });
}
Doc VisitType_(const TensorTypeNode* t) final {
return DocOfStr("tensor");
}
Doc VisitType_(const TypeParamNode* p) final {
auto tp = GetRef<TypeParam>(p);
if (map.count(tp) == 0) {
auto name =
DocOfStr(Mangle("tp", cnt(TypeParamName())) +
std::string(":")) +
KindDocify(p->kind);
map.insert(std::pair<TypeParam, Doc>(tp, name));
}
return map.at(tp);
}
Doc Quantify(const tvm::Array<TypeParam>& tp, const Doc& d) {
if (tp.size() == 0) {
return d;
}
return Seq("forall", DocifyTypeParam(tp), ",") + Sep() + d;
}
Doc Constraint(const tvm::Array<TypeConstraint>& tc, const Doc& d) {
if (tc.size() == 0) {
return d;
}
return Seq("(", DocifyTypeConstraint(tc), ") =>") + Sep() + d;
}
Doc VisitType_(const FuncTypeNode* f) final {
auto inner = Seq("<", DocifyTypeArray(f->arg_types), ">") + Sep() +
DocOfStr("->") + Sep() + Docify(f->ret_type);
return Group(Quantify(f->type_params,
Constraint(f->type_constraints, inner)));
}
Doc VisitType_(const TypeRelationNode* r) final {
return DocOfStr("Relation") + Seq("(", DocifyTypeArray(r->args), ")");
}
Doc VisitType_(const TupleTypeNode* t) final {
return Seq("<", DocifyTypeArray(t->fields), ">");
}
Doc VisitType_(const IncompleteTypeNode* i) final {
return DocOfStr("_");
}
public:
TypeDocifier(const Environment& env) : env(env) { }
Doc Docify(const Type& t) { return t.get() ? (*this)(t) : DocOfStr("_"); }
};
class ExprDocifier : private ExprFunctor<Doc(const Expr& n)> {
Environment env;
Counter<std::string> cnt;
std::unordered_map<Var, std::string, NodeHash, NodeEqual> map;
TypeDocifier td;
std::string VarName(const Var& v) {
if (map.count(v) == 0) {
map.insert(std::pair<Var, std::string>(v, Mangle(v->name_hint, cnt(v->name_hint))));
}
return map.at(v);
}
Doc TypeAnnotation(const Doc& d, const Type& t) {
// test for t being null. probably shouldnt has null. should talk to jared.
if (!t.get() || t.as<IncompleteTypeNode>()) {
return d;
} else {
return d + DocOfStr(":") + td.Docify(t);
}
}
std::vector<Doc> DocifyExprArray(const tvm::Array<Expr>& arr) {
std::vector<Doc> vec;
for (size_t i = 0; i < arr.size(); ++i) {
vec.push_back(Docify(arr[i]));
}
return vec;
}
std::vector<Doc> DocifyParamArray(const tvm::Array<Param>& arr) {
std::vector<Doc> vec;
for (size_t i = 0; i < arr.size(); ++i) {
vec.push_back(Docify(arr[i]));
}
return vec;
}
Doc VisitExpr_(const ConstantNode* c) final {
return DocOfStr("some_constant");
}
Doc VisitExpr_(const TupleNode* t) final {
return Seq("<", DocifyExprArray(t->fields), ">");
}
Doc VisitExpr_(const VarNode* v) final {
return DocOfStr(VarName(GetRef<Var>(v)));
}
Doc VisitExpr_(const GlobalVarNode* g) final {
return DocOfStr(g->name_hint);
}
Doc VisitExpr_(const ParamNode* p) final {
return TypeAnnotation(Docify(p->var), p->type);
}
Doc VisitExpr_(const FunctionNode* f) final {
return Group(TypeAnnotation(Seq("(", DocifyParamArray(f->params), ")"), f->ret_type) + Sep() +
DocOfStr("=>") + Sep() +
Block(indent, "{", Docify(f->body), "}"));
}
Doc VisitExpr_(const CallNode* c) final {
auto args = DocifyExprArray(c->args);
return Docify(c->op) + Seq("<", DocifyExprArray(c->args), ">");
}
Doc VisitExpr_(const LetNode* l) final {
return Group(DocOfStr("let") + Sep() + TypeAnnotation(Docify(l->var), l->value_type) + Sep() +
DocOfStr("=") + Sep() + Docify(l->value) + DocOfStr(";") + Endl() +
Docify(l->body));
}
Doc VisitExpr_(const IfNode* i) final {
return Group(DocOfStr("if") + Sep() + Docify(i->cond) + Sep() +
Block(indent, "{", Docify(i->true_branch), "}") + Sep() +
DocOfStr("else") + Sep() +
Block(indent, "{", Docify(i->false_branch), "}"));
}
Doc VisitExpr_(const OpNode* o) final {
return DocOfStr(o->name);
}
public:
ExprDocifier(const Environment& env) : env(env), td(env) { }
Doc Docify(const Expr& e) { return (*this)(e); }
};
Doc DocOfExpr(const Environment& env, const Expr& expr) {
ExprDocifier d(env);
return d.Docify(expr);
}
Doc DocOfType(const Environment& env, const Type& expr) {
TypeDocifier d(env);
return d.Docify(expr);
}
RDoc ExprRDoc(const Environment& env, const Expr& expr) {
return Layout(DocOfExpr(env, expr));
}
RDoc TypeRDoc(const Environment& env, const Type& expr) {
return Layout(DocOfType(env, expr));
}
std::ostream & DebugPrint(const Environment& env, const Expr& e, std::ostream& os) {
return os << ExprRDoc(env, e);
}
std::ostream & DebugPrint(const Environment& env, const Type& t, std::ostream& os) {
return os << TypeRDoc(env, t);
}
std::string PrintExpr(const Environment& env, const Expr& e) {
std::stringstream ss;
ss << ExprRDoc(env, e);
return ss.str();
}
std::string PrintType(const Environment& env, const Type& t) {
std::stringstream ss;
ss << TypeRDoc(env, t);
return ss.str();
}
TVM_REGISTER_API("relay._expr._debug_print")
.set_body([](TVMArgs args, TVMRetValue* ret) {
NodeRef x = args[1];
std::cout << x << std::endl;
if (x.as<TypeNode>()) {
*ret = PrintType(args[0], Downcast<Type>(x));
} else {
*ret = PrintExpr(args[0], Downcast<Expr>(x));
}
});
} // namespace relay
} // namespace tvm
...@@ -190,7 +190,7 @@ TVM_REGISTER_API("relay._make.If").set_body([](TVMArgs args, TVMRetValue *ret) { ...@@ -190,7 +190,7 @@ TVM_REGISTER_API("relay._make.If").set_body([](TVMArgs args, TVMRetValue *ret) {
TVM_STATIC_IR_FUNCTOR_REGISTER(IRPrinter, vtable) TVM_STATIC_IR_FUNCTOR_REGISTER(IRPrinter, vtable)
.set_dispatch<IfNode>([](const IfNode *node, tvm::IRPrinter *p) { .set_dispatch<IfNode>([](const IfNode *node, tvm::IRPrinter *p) {
p->stream << "IfNode(" << node->cond << ", " << node->true_branch p->stream << "IfNode(" << node->cond << ", " << node->true_branch
<< node->false_branch << ")"; << ", " << node->false_branch << ")";
}); });
} // namespace relay } // namespace relay
......
...@@ -8,6 +8,8 @@ ...@@ -8,6 +8,8 @@
#include <tvm/node/ir_functor.h> #include <tvm/node/ir_functor.h>
#include <tvm/relay/expr.h> #include <tvm/relay/expr.h>
#include <tvm/relay/error.h>
#include <string>
namespace tvm { namespace tvm {
namespace relay { namespace relay {
...@@ -68,7 +70,7 @@ class TypeFunctor<R(const Type& n, Args...)> { ...@@ -68,7 +70,7 @@ class TypeFunctor<R(const Type& n, Args...)> {
virtual R VisitTypeDefault_(const Node* op, Args...) { virtual R VisitTypeDefault_(const Node* op, Args...) {
LOG(FATAL) << "Do not have a default for " << op->type_key(); LOG(FATAL) << "Do not have a default for " << op->type_key();
return R(); throw; // unreachable, written to stop compiler warning
} }
private: private:
......
import tvm
from tvm import relay
from tvm.relay.expr import debug_print
from tvm.relay.ir_builder import IRBuilder
ib = IRBuilder()
def show(e):
r = debug_print(ib.env, e)
assert r is not None
# print(r) # uncomment this line to debug
def test_constant():
arr = tvm.nd.array(10)
const = relay.Constant(arr)
show(const)
# should print the array inside?
def test_tuple():
fields = tvm.convert([])
tup = relay.Tuple(fields)
show(tup)
def test_local_var():
name_hint = 's'
lv = relay.Var(name_hint)
show(lv)
def test_dup_var():
lv = relay.Var('s')
rv = relay.Var('s')
show(relay.Tuple([lv, rv]))
def test_large_dup_var():
av = relay.Var('s')
bv = relay.Var('s')
cv = relay.Var('s')
show(relay.Tuple([av, bv, cv]))
def test_global_var():
name_hint = 'g'
gv = relay.GlobalVar(name_hint)
gv.name_hint == name_hint
show(gv)
def test_param():
lv = relay.Var('x')
ty = None
param = relay.Param(lv, ty)
show(lv)
def test_function():
param_names = ['a', 'b', 'c', 'd']
params = tvm.convert([relay.Param(relay.Var(n), None) for n in param_names])
ret_type = None
body = params[0].var
type_params = tvm.convert([])
fn = relay.Function(params, ret_type, body, type_params)
show(fn)
def test_call():
op = relay.Var('f')
arg_names = ['a', 'b', 'c', 'd']
args = tvm.convert([relay.Var(n) for n in arg_names])
call = relay.Call(op, args, None, None)
show(call)
def test_let():
lv = relay.Var('x')
ty = relay.ty.TensorType((10, 20), "float32")
arr = tvm.nd.array(10)
value = relay.Constant(arr)
let = relay.Let(lv, value, lv, ty)
show(let)
def test_if():
cond = relay.Var('cond')
left = relay.Var('left')
right = relay.Var('right')
ife = relay.If(cond, left, right)
show(ife)
Markdown is supported
0% or
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment