Commit 06108bed by 雾雨魔理沙 Committed by Tianqi Chen

[RELAY] First pass at pretty printer (#1749)

parent 32af4d28
...@@ -158,6 +158,8 @@ class RelayNode : public Node { ...@@ -158,6 +158,8 @@ class RelayNode : public Node {
TVM_DECLARE_BASE_NODE_INFO(RelayNode, Node); TVM_DECLARE_BASE_NODE_INFO(RelayNode, Node);
}; };
struct Environment;
} // namespace relay } // namespace relay
} // namespace tvm } // namespace tvm
......
...@@ -376,6 +376,14 @@ class IfNode : public ExprNode { ...@@ -376,6 +376,14 @@ class IfNode : public ExprNode {
RELAY_DEFINE_NODE_REF(If, IfNode, Expr); RELAY_DEFINE_NODE_REF(If, IfNode, Expr);
/*! \brief Print a debug representation of the expression to the stream.
* \param env The environment.
* \param e The expression
* \param os the stream
* \returns A reference to the stream.
*/
std::ostream& DebugPrint(const Environment& env, const Expr& e, std::ostream& os);
} // namespace relay } // namespace relay
} // namespace tvm } // namespace tvm
#endif // TVM_RELAY_EXPR_H_ #endif // TVM_RELAY_EXPR_H_
...@@ -11,6 +11,7 @@ ...@@ -11,6 +11,7 @@
#include <string> #include <string>
#include "./expr.h" #include "./expr.h"
#include "./op.h" #include "./op.h"
#include "./error.h"
namespace tvm { namespace tvm {
namespace relay { namespace relay {
...@@ -89,7 +90,7 @@ class ExprFunctor<R(const Expr& n, Args...)> { ...@@ -89,7 +90,7 @@ class ExprFunctor<R(const Expr& n, Args...)> {
virtual R VisitExpr_(const OpNode* op, virtual R VisitExpr_(const OpNode* op,
Args... args) EXPR_FUNCTOR_DEFAULT; Args... args) EXPR_FUNCTOR_DEFAULT;
virtual R VisitExprDefault_(const Node* op, Args...) { virtual R VisitExprDefault_(const Node* op, Args...) {
throw dmlc::Error(std::string("Do not have a default for ") + op->type_key()); throw Error(std::string("Do not have a default for ") + op->type_key());
} }
private: private:
......
...@@ -365,6 +365,14 @@ class TypeRelationNode : public TypeConstraintNode { ...@@ -365,6 +365,14 @@ class TypeRelationNode : public TypeConstraintNode {
RELAY_DEFINE_NODE_REF(TypeRelation, TypeRelationNode, TypeConstraint); RELAY_DEFINE_NODE_REF(TypeRelation, TypeRelationNode, TypeConstraint);
/*! \brief Print a debug representation of the type to the stream.
* \param env The environment.
* \param t The type
* \param os the stream
* \returns A reference to the stream.
*/
std::ostream& DebugPrint(const Environment& env, const Type& t, std::ostream& os);
// The following fields contains advanced typing // The following fields contains advanced typing
// Only keep the class name and reserved for future usage. // Only keep the class name and reserved for future usage.
class GenericTensorType; class GenericTensorType;
......
...@@ -33,4 +33,3 @@ Function = expr.Function ...@@ -33,4 +33,3 @@ Function = expr.Function
Call = expr.Call Call = expr.Call
Let = expr.Let Let = expr.Let
If = expr.If If = expr.If
Var = Var
# pylint: disable=no-else-return, unidiomatic-typecheck, undefined-variable
"""The interface of expr function exposed from C++."""
from tvm._ffi.function import _init_api
_init_api("relay._expr", __name__)
...@@ -2,6 +2,7 @@ ...@@ -2,6 +2,7 @@
"""The expression nodes of Relay.""" """The expression nodes of Relay."""
from __future__ import absolute_import from __future__ import absolute_import
from .base import NodeBase, register_relay_node from .base import NodeBase, register_relay_node
from . import _expr
from . import _make from . import _make
from .. import convert from .. import convert
...@@ -115,3 +116,5 @@ class If(Expr): ...@@ -115,3 +116,5 @@ class If(Expr):
def __init__(self, cond, true_value, false_value): def __init__(self, cond, true_value, false_value):
self.__init_handle_by_constructor__( self.__init_handle_by_constructor__(
_make.If, cond, true_value, false_value) _make.If, cond, true_value, false_value)
debug_print = _expr._debug_print
...@@ -111,4 +111,4 @@ class If(Expr): ...@@ -111,4 +111,4 @@ class If(Expr):
def __init__(self, cond, true_value, false_value): def __init__(self, cond, true_value, false_value):
# type: (Expr, Expr, Expr) -> None # type: (Expr, Expr, Expr) -> None
... ...
\ No newline at end of file
/*!
* Copyright (c) 2018 by Contributors
* \file debug_printer.cc
* \brief A pretty printer for the Relay IR.
* As we had not determined a formal syntax yet, right now it is only for debug purpose.
*/
#include <tvm/relay/expr_functor.h>
#include <tvm/relay/environment.h>
#include <tvm/relay/error.h>
#include <iostream>
#include <sstream>
#include <vector>
#include <unordered_map>
#include <string>
#include <vector>
#include <iostream>
#include "../pass/type_functor.h"
#include "doc.h"
namespace tvm {
namespace relay {
using namespace tvm::runtime;
Doc KindDocify(TypeParamNode::Kind k) {
switch (k) {
case TypeParamNode::kShapeVar:
return DocOfStr("ShapeVar");
case TypeParamNode::kShape:
return DocOfStr("Shape");
case TypeParamNode::kBaseType:
return DocOfStr("BaseType");
case TypeParamNode::kType:
return DocOfStr("Type");
default:
LOG(FATAL) << "unreachable code: case not handle in kind";
throw; // log fatal throw but compiler doesnt know
}
}
template<typename T>
std::vector<Doc> MapDocify(const tvm::Array<T>& arr, const std::function<Doc(const T&)>& f) {
std::vector<Doc> vec;
for (size_t i = 0; i < arr.size(); ++i) {
vec.push_back(f(arr[i]));
}
return vec;
}
template<typename T, typename Hash = std::hash<T>, typename Eq = std::equal_to<T>>
class Counter {
std::unordered_map<T, size_t, Hash, Eq> cnt_;
public:
Counter() = default;
Counter(const Counter&) = delete;
size_t operator()(const T& t) {
auto v = cnt_.count(t) == 0 ? 0 : cnt_.at(t) + 1;
cnt_[t] = v;
return v;
}
};
std::string Mangle(const std::string& str, size_t s) {
return str + "_" + std::to_string(s);
// return s == 0 ? str : str + "_" + std::to_string(s - 1);
// the above line look prettier but is dangerous:
// suppose we have x, x, x_0. mangling will give x, x_0, x_0!
// the save approach give x_0, x_1, x_0_1, and in fact never clash:
// stripping _([0-9]*) is invert of mangle under all circumstances.
// another problem is we need to prevent Var/TypeParam/GlobalVar clashing each other.
}
constexpr size_t indent = 2;
struct TypeParamName {
bool operator==(const TypeParamName&) const {
return true;
}
};
struct mhash {
size_t operator()(const ::tvm::relay::TypeParamName&) const noexcept {
return 0;
}
};
class TypeDocifier : private TypeFunctor<Doc(const Type& n)> {
Environment env;
Counter<TypeParamName, mhash> cnt;
std::unordered_map<TypeParam, Doc, NodeHash, NodeEqual> map;
std::vector<Doc> DocifyTypeArray(const tvm::Array<Type>& arr) {
return MapDocify<Type>(arr, [=](const Type& t) { return Docify(t); });
}
std::vector<Doc> DocifyTypeParam(const tvm::Array<TypeParam>& arr) {
return MapDocify<TypeParam>(arr, [=](const TypeParam& tp) { return Docify(tp); });
}
std::vector<Doc> DocifyTypeConstraint(const tvm::Array<TypeConstraint>& arr) {
return MapDocify<TypeConstraint>(arr, [=](const TypeConstraint& tc) { return Docify(tc); });
}
Doc VisitType_(const TensorTypeNode* t) final {
return DocOfStr("tensor");
}
Doc VisitType_(const TypeParamNode* p) final {
auto tp = GetRef<TypeParam>(p);
if (map.count(tp) == 0) {
auto name =
DocOfStr(Mangle("tp", cnt(TypeParamName())) +
std::string(":")) +
KindDocify(p->kind);
map.insert(std::pair<TypeParam, Doc>(tp, name));
}
return map.at(tp);
}
Doc Quantify(const tvm::Array<TypeParam>& tp, const Doc& d) {
if (tp.size() == 0) {
return d;
}
return Seq("forall", DocifyTypeParam(tp), ",") + Sep() + d;
}
Doc Constraint(const tvm::Array<TypeConstraint>& tc, const Doc& d) {
if (tc.size() == 0) {
return d;
}
return Seq("(", DocifyTypeConstraint(tc), ") =>") + Sep() + d;
}
Doc VisitType_(const FuncTypeNode* f) final {
auto inner = Seq("<", DocifyTypeArray(f->arg_types), ">") + Sep() +
DocOfStr("->") + Sep() + Docify(f->ret_type);
return Group(Quantify(f->type_params,
Constraint(f->type_constraints, inner)));
}
Doc VisitType_(const TypeRelationNode* r) final {
return DocOfStr("Relation") + Seq("(", DocifyTypeArray(r->args), ")");
}
Doc VisitType_(const TupleTypeNode* t) final {
return Seq("<", DocifyTypeArray(t->fields), ">");
}
Doc VisitType_(const IncompleteTypeNode* i) final {
return DocOfStr("_");
}
public:
TypeDocifier(const Environment& env) : env(env) { }
Doc Docify(const Type& t) { return t.get() ? (*this)(t) : DocOfStr("_"); }
};
class ExprDocifier : private ExprFunctor<Doc(const Expr& n)> {
Environment env;
Counter<std::string> cnt;
std::unordered_map<Var, std::string, NodeHash, NodeEqual> map;
TypeDocifier td;
std::string VarName(const Var& v) {
if (map.count(v) == 0) {
map.insert(std::pair<Var, std::string>(v, Mangle(v->name_hint, cnt(v->name_hint))));
}
return map.at(v);
}
Doc TypeAnnotation(const Doc& d, const Type& t) {
// test for t being null. probably shouldnt has null. should talk to jared.
if (!t.get() || t.as<IncompleteTypeNode>()) {
return d;
} else {
return d + DocOfStr(":") + td.Docify(t);
}
}
std::vector<Doc> DocifyExprArray(const tvm::Array<Expr>& arr) {
std::vector<Doc> vec;
for (size_t i = 0; i < arr.size(); ++i) {
vec.push_back(Docify(arr[i]));
}
return vec;
}
std::vector<Doc> DocifyParamArray(const tvm::Array<Param>& arr) {
std::vector<Doc> vec;
for (size_t i = 0; i < arr.size(); ++i) {
vec.push_back(Docify(arr[i]));
}
return vec;
}
Doc VisitExpr_(const ConstantNode* c) final {
return DocOfStr("some_constant");
}
Doc VisitExpr_(const TupleNode* t) final {
return Seq("<", DocifyExprArray(t->fields), ">");
}
Doc VisitExpr_(const VarNode* v) final {
return DocOfStr(VarName(GetRef<Var>(v)));
}
Doc VisitExpr_(const GlobalVarNode* g) final {
return DocOfStr(g->name_hint);
}
Doc VisitExpr_(const ParamNode* p) final {
return TypeAnnotation(Docify(p->var), p->type);
}
Doc VisitExpr_(const FunctionNode* f) final {
return Group(TypeAnnotation(Seq("(", DocifyParamArray(f->params), ")"), f->ret_type) + Sep() +
DocOfStr("=>") + Sep() +
Block(indent, "{", Docify(f->body), "}"));
}
Doc VisitExpr_(const CallNode* c) final {
auto args = DocifyExprArray(c->args);
return Docify(c->op) + Seq("<", DocifyExprArray(c->args), ">");
}
Doc VisitExpr_(const LetNode* l) final {
return Group(DocOfStr("let") + Sep() + TypeAnnotation(Docify(l->var), l->value_type) + Sep() +
DocOfStr("=") + Sep() + Docify(l->value) + DocOfStr(";") + Endl() +
Docify(l->body));
}
Doc VisitExpr_(const IfNode* i) final {
return Group(DocOfStr("if") + Sep() + Docify(i->cond) + Sep() +
Block(indent, "{", Docify(i->true_branch), "}") + Sep() +
DocOfStr("else") + Sep() +
Block(indent, "{", Docify(i->false_branch), "}"));
}
Doc VisitExpr_(const OpNode* o) final {
return DocOfStr(o->name);
}
public:
ExprDocifier(const Environment& env) : env(env), td(env) { }
Doc Docify(const Expr& e) { return (*this)(e); }
};
Doc DocOfExpr(const Environment& env, const Expr& expr) {
ExprDocifier d(env);
return d.Docify(expr);
}
Doc DocOfType(const Environment& env, const Type& expr) {
TypeDocifier d(env);
return d.Docify(expr);
}
RDoc ExprRDoc(const Environment& env, const Expr& expr) {
return Layout(DocOfExpr(env, expr));
}
RDoc TypeRDoc(const Environment& env, const Type& expr) {
return Layout(DocOfType(env, expr));
}
std::ostream & DebugPrint(const Environment& env, const Expr& e, std::ostream& os) {
return os << ExprRDoc(env, e);
}
std::ostream & DebugPrint(const Environment& env, const Type& t, std::ostream& os) {
return os << TypeRDoc(env, t);
}
std::string PrintExpr(const Environment& env, const Expr& e) {
std::stringstream ss;
ss << ExprRDoc(env, e);
return ss.str();
}
std::string PrintType(const Environment& env, const Type& t) {
std::stringstream ss;
ss << TypeRDoc(env, t);
return ss.str();
}
TVM_REGISTER_API("relay._expr._debug_print")
.set_body([](TVMArgs args, TVMRetValue* ret) {
NodeRef x = args[1];
std::cout << x << std::endl;
if (x.as<TypeNode>()) {
*ret = PrintType(args[0], Downcast<Type>(x));
} else {
*ret = PrintExpr(args[0], Downcast<Expr>(x));
}
});
} // namespace relay
} // namespace tvm
/*!
* Copyright (c) 2018 by Contributors
* \file doc.h
* \brief A pretty printer DSL for constructing (Doc) and formatting (RDoc) documents.
* It is based heavily on Philip Wadler's "A prettier printer."
* See https://homepages.inf.ed.ac.uk/wadler/papers/prettier/prettier.pdf
* for more details.
*
* Since the original paper uses call by value for efficiency, everything doc function is maximally lazy.
* You can probably yank speed by doing strict analysis and removing some Lazy (if this is bottleneck).
*/
#ifndef TVM_RELAY_IR_DOC_H_
#define TVM_RELAY_IR_DOC_H_
#include <unordered_map>
#include <utility>
#include <string>
#include <functional>
#include <vector>
#include <memory>
#include <ostream>
#include <map>
#include "error.h"
namespace tvm {
namespace relay {
/*! \brief A Document represent structured text.
* beside having unstructured string, it capture different ways to compose them -
* line break, space, indentation, representation choice.
*/
struct Doc;
/*! \brief RDoc represent rendered document.
* all the high level detail on the document, such as indentation, choice, has been removed.
* there is only one single, straight forward way to print it.
*/
struct RDoc;
//! \brief Empty document
inline Doc Nil();
//! \brief Concatenate two documents
inline Doc App(const Doc& l, const Doc& r);
//! \brief Indent a document
inline Doc Nest(size_t width, const Doc& doc);
//! \brief Lift string to a document
inline Doc DocOfStr(const std::string& text);
//! \brief New line
inline Doc Endl();
//! \brief Remove all line break from the Document.
inline Doc Flatten(const Doc& d);
/*! \brief Choose between two possible layouts.
* assume Flatten(l) == Flatten(r), and l need to be more compact.
*/
inline Doc Choose(const Doc& l, const Doc& r);
//! \brief Use a single line if possible
inline Doc Group(const Doc& d);
//! \brief print an RDoc
inline std::ostream& operator<<(std::ostream& os, const RDoc& rdoc);
/*! \brief Joins a vector of documents with a given separator document
* \example Join(["a", "b, "c"], ", ") => "a, b, c"
* \param vec the vector of documents
* \param sep the separator between documents
*/
inline Doc Join(const std::vector<Doc>& vec, const Doc& sep);
/*! \brief Creates an indented block.
* \param indent the indentation size
* \param open the opening string
* \param body the body of the block
* \param close the closing string
*/
inline Doc Block(size_t indent, const std::string& open,
const Doc& body, const std::string& close);
/*! \brief Creates a comma-separated sequence with opening and closing strings.
* \param open the opening string
* \param body the body of the Block
* \param close the closing string
*/
inline Doc Seq(const std::string& open,
const std::vector<Doc>& body, const std::string& close);
//! \brief Either a space or a new line
inline Doc Sep();
/*! \brief Layout a document to a given width
* \param d the document to render
* \param width the line width
*/
inline RDoc Layout(const Doc& d, size_t width = 80);
// end of API, start of implementation
template<typename T>
struct LazyNode {
mutable std::function<T()> thunk;
explicit LazyNode(const std::function<T()>& thunk) : thunk(thunk) { }
};
//! \brief denote a value that will be computed (at most once) on need.
template<typename T>
struct Lazy {
std::shared_ptr<LazyNode<T> > lazy_node;
explicit Lazy(const std::function<T()>& thunk) :
lazy_node(std::make_shared<LazyNode<T>>(thunk)) { }
explicit Lazy(const T& value) : Lazy([=]() { return value; }) { }
explicit Lazy(const Lazy<Lazy<T>>& thunk) : Lazy([=]() { return thunk.get().get(); }) { }
// calculate the result.
// memoize it by replacing the thunk with a constant function which immediate return.
T get() const {
T res = lazy_node->thunk();
lazy_node->thunk = [=]() { return res; };
return res;
}
template<typename R>
Lazy<R> map(const std::function<R(const T&)>& func) const {
Lazy<T> self(*this);
return Lazy<R>([=]() -> R { return func(self.get()); });
}
};
struct NilNode;
struct AppNode;
struct NestNode;
struct TextNode;
struct LineNode;
struct ChoiceNode;
/*! \brief The inner representation of Doc.
* a doc represent structured text,
* and can be rendered onto screen while keeping the structure.
*/
struct DocNode {
/* a docnode is a union of the below node.
* exactly one of them will be non null.
* their meaning is denoted by the construction function of the same name.
* so for example, the meaning of AppNode is exactly a node construct by App.
*/
std::shared_ptr<NilNode> nil;
std::shared_ptr<AppNode> app;
std::shared_ptr<NestNode> nest;
std::shared_ptr<TextNode> text; // construct by DocOfStr
std::shared_ptr<LineNode> line;
std::shared_ptr<ChoiceNode> choice;
DocNode(std::shared_ptr<NilNode> nil,
std::shared_ptr<AppNode> app,
std::shared_ptr<NestNode> nest,
std::shared_ptr<TextNode> text,
std::shared_ptr<LineNode> line,
std::shared_ptr<ChoiceNode> choice) :
nil(nil),
app(app),
nest(nest),
text(text),
line(line),
choice(choice) { }
};
struct Doc {
Lazy<DocNode> doc;
explicit Doc(const DocNode& ed) : doc(ed) { }
explicit Doc(const Lazy<Doc>& ldoc) :
doc(ldoc.map<Lazy<DocNode> >([](const Doc& d){ return d.doc; })) { }
Doc operator+(const Doc& r) const {
return App(*this, r);
}
template<typename T>
Lazy<T> Match(
const std::function<T()>& nilf,
const std::function<T(const Doc&, const Doc&)>& appf,
const std::function<T(size_t, const Doc&)>& nestf,
const std::function<T(const std::string&)>& textf,
const std::function<T()>& linef,
const std::function<T(const Doc&, const Doc&)>& choicef) const;
};
struct NilNode { };
struct AppNode {
Doc left, right;
AppNode(const Doc& left, const Doc& right) : left(left), right(right) { }
};
struct NestNode {
size_t space;
Doc doc;
NestNode(size_t space, const Doc& doc) : space(space), doc(doc) { }
};
struct TextNode {
std::string text;
explicit TextNode(const std::string& text) : text(text) { }
};
struct LineNode { };
struct ChoiceNode {
Doc left, right;
ChoiceNode(const Doc& left, const Doc& right) : left(left), right(right) { }
};
template<typename T>
Lazy<T> Doc::Match(
const std::function<T()>& nilf,
const std::function<T(const Doc&, const Doc&)>& appf,
const std::function<T(size_t, const Doc&)>& nestf,
const std::function<T(const std::string&)>& textf,
const std::function<T()>& linef,
const std::function<T(const Doc&, const Doc&)>& choicef) const {
return doc.map<T>([=](const DocNode& d) {
if (d.nil) {
return nilf();
} else if (d.app) {
return appf(d.app->left, d.app->right);
} else if (d.nest) {
return nestf(d.nest->space, d.nest->doc);
} else if (d.text) {
return textf(d.text->text);
} else if (d.line) {
return linef();
} else {
return choicef(d.choice->left, d.choice->right);
}
});
}
//! \brief Empty document
inline Doc Nil() {
return Doc(DocNode(std::make_shared<NilNode>(), nullptr, nullptr, nullptr, nullptr, nullptr));
}
//! \brief Concatenate two documents
inline Doc App(const Doc& l, const Doc& r) {
return Doc(DocNode(
nullptr,
std::make_shared<AppNode>(l, r),
nullptr,
nullptr,
nullptr,
nullptr));
}
//! \brief Indent a document
inline Doc Nest(size_t width, const Doc& doc) {
auto x = std::make_shared<NestNode>(width, doc);
return Doc(DocNode(
nullptr,
nullptr,
std::make_shared<NestNode>(width, doc),
nullptr,
nullptr,
nullptr));
}
//! \brief Lift string to a document
inline Doc DocOfStr(const std::string& text) {
return Doc(DocNode(nullptr, nullptr, nullptr,
std::make_shared<TextNode>(text), nullptr, nullptr));
}
//! \brief New line
inline Doc Endl() {
return Doc(DocNode(nullptr, nullptr, nullptr, nullptr, std::make_shared<LineNode>(), nullptr));
}
/*! \brief Choose between two possible layouts.
* assume Flatten(l) == Flatten(r), and l need to be more compact.
*/
inline Doc Choose(const Doc& l, const Doc& r) {
return Doc(DocNode(
nullptr,
nullptr,
nullptr,
nullptr,
nullptr,
std::make_shared<ChoiceNode>(l, r)));
}
//! \brief Remove new line from the whole document.
inline Doc Flatten(const Doc& d) {
return Doc(d.Match<Doc>(
[]() { return Nil(); },
[](const Doc& l, const Doc& r) { return Flatten(l) + Flatten(r); },
[](size_t space, const Doc& doc) { return Flatten(doc); },
[](const std::string& str) { return DocOfStr(str); },
[]() { return DocOfStr(" "); },
[](const Doc& l, const Doc& r) { return Flatten(l); }));
}
//! \brief Use a single line if possible
inline Doc Group(const Doc& d) {
return Choose(Flatten(d), d);
}
struct RNilNode;
struct RTextNode;
struct RLineNode;
struct RDocNode {
std::shared_ptr<RNilNode> rnil;
std::shared_ptr<RTextNode> rtext;
std::shared_ptr<RLineNode> rline;
RDocNode(const std::shared_ptr<RNilNode>& rnil,
const std::shared_ptr<RTextNode>& rtext,
const std::shared_ptr<RLineNode>& rline) :
rnil(rnil), rtext(rtext), rline(rline) { }
};
/*! \brief RDoc represent rendered document.
* all the high level detail on the document, such as indentation, alternative, has been removed.
* there is only one single, straight forward way to print it.
*/
struct RDoc {
Lazy<RDocNode> doc;
explicit RDoc(const RDocNode& d) : doc(d) { }
explicit RDoc(const Lazy<RDoc>& ldoc) :
doc(ldoc.map<Lazy<RDocNode>>([](const RDoc& d){ return d.doc; })) { }
template<typename T>
Lazy<T> Match(
const std::function<T()> &rnilf,
const std::function<T(const std::string&, const RDoc&)>& rtextf,
const std::function<T(size_t, const RDoc&)>& rlinef) const;
};
inline std::ostream& operator<<(std::ostream& os, const RDoc& rdoc) {
return *rdoc.Match<std::ostream*>(
[&]() { return & os; },
[&](const std::string& text, const RDoc& r) {
return & (os << text << r);
},
[&](size_t space, const RDoc& r) {
return & (os << std::endl << std::string(space, ' ') << r);
}).get();
}
struct RNilNode { };
struct RTextNode {
std::string text;
RDoc rest;
RTextNode(const std::string& text, const RDoc& rest) : text(text), rest(rest) { }
};
struct RLineNode {
size_t space;
RDoc rest;
RLineNode(size_t space, const RDoc& rest) : space(space), rest(rest) { }
};
//! \brief Empty RDoc
inline RDoc RNil() { return RDoc(RDocNode(std::make_shared<RNilNode>(), nullptr, nullptr)); }
//! \brief RDoc that begin with std::string
inline RDoc RText(const std::string& text, const RDoc& rest) {
return RDoc(RDocNode(nullptr, std::make_shared<RTextNode>(text, rest), nullptr));
}
//! \brief RDoc that begin with a new line, followed by space
inline RDoc RLine(size_t space, const RDoc& rest) {
return RDoc(RDocNode(nullptr, nullptr, std::make_shared<RLineNode>(space, rest)));
}
template<typename T>
Lazy<T> RDoc::Match(
const std::function<T()>& rnilf,
const std::function<T(const std::string&, const RDoc&)>& rtextf,
const std::function<T(size_t, const RDoc&)>& rlinef) const {
return doc.map<T>([=](const RDocNode& rdoc) {
if (rdoc.rnil) {
return rnilf();
} else if (rdoc.rtext) {
return rtextf(rdoc.rtext->text, rdoc.rtext->rest);
} else {
return rlinef(rdoc.rline->space, rdoc.rline->rest);
}
});
}
template<typename T>
struct List;
template<typename T>
struct EagerList {
const std::shared_ptr<std::pair<T, List<T>>> cons;
};
//! \brief lazy list
template<typename T>
struct List {
Lazy<EagerList<T> > l;
List() : l([]() { return EagerList<T>({nullptr}); }) { }
List(const T& t, const List<T>& l) :
l([=]() { return EagerList<T>({std::make_shared<std::pair<T, List<T>>>(t, l)}); }) { }
template<typename R>
Lazy<R> Match(const std::function<R()>& nullf,
const std::function<R(const T&, const List<T>&)>& consf) const {
return l.template map<R>([=](const EagerList<T>& l) {
if (l.cons) {
return consf(l.cons->first, l.cons->second);
} else {
return nullf();
}
});
}
};
//! \brief Does x fit into line of size w?
inline bool Fits(int w, const RDoc& x) {
return (w >= 0) && x.Match<bool>(
[]() { return true; },
[=](const std::string& s, const RDoc& x) { return Fits(w - s.size(), x); },
[](size_t space, const RDoc& x) { return true; }).get();
}
//! \brief Choose the one that fits best.
inline RDoc Better(size_t w, size_t k, const RDoc& x, const RDoc& y) {
return Fits(w-k, x) ? x : y;
}
typedef std::pair<size_t/*indent size*/, Doc> best_arg;
inline RDoc Best(size_t w/*wrap width*/, size_t k/*space used*/,
const List<best_arg>& l/*to be rendered*/) {
return RDoc(l.Match<RDoc>(
[]() { return RNil(); },
[=](const best_arg& p, const List<best_arg>& z) {
return RDoc(p.second.Match<RDoc>(
[=]() { return Best(w, k, z); },
[=](const Doc& x, const Doc& y) {
return Best(
w,
k,
List<best_arg>(best_arg(p.first, x), List<best_arg>(best_arg(p.first, y), z))); },
[=](size_t j, const Doc& x) {
return Best(w, k, List<best_arg>(best_arg(p.first + j, x), z)); },
[=](const std::string& text) { return RText(text, Best(w, k + text.size(), z)); },
[=]() { return RLine(p.first, Best(w, p.first, z)); },
[=](const Doc& x, const Doc& y) {
return Better(
w,
k,
Best(w, k, List<best_arg>(best_arg(p.first, x), z)),
Best(w, k, List<best_arg>(best_arg(p.first, y), z))); }));
}));
}
/*! \brief Joins a vector of documents with a given separator document
* \example Join(["a", "b, "c"], ", ") => "a, b, c"
* \param vec the vector of documents
* \param sep the separator between documents
*/
inline Doc Join(const std::vector<Doc>& vec, const Doc& sep) {
// https://www.safaribooksonline.com/library/view/c-cookbook/0596007612/ch04s09.html
Doc output = Nil();
for (auto p = vec.begin(); p != vec.end(); ++p) {
output = output + *p;
if (p != vec.end() - 1) {
output = output + sep;
}
}
return output;
}
/*! \brief Creates an indented block.
* \param indent the indentation size
* \param open the opening string
* \param body the body of the block
* \param close the closing string
*/
inline Doc Block(size_t indent, const std::string& open,
const Doc& body, const std::string& close) {
return DocOfStr(open) + Nest(indent, Endl() + body) + Endl() + DocOfStr(close);
}
/*! \brief Creates a comma-separated sequence with opening and closing strings.
* \param open the opening string
* \param body the body of the Block
* \param close the closing string
*/
inline Doc Seq(const std::string& open,
const std::vector<Doc>& body, const std::string& close) {
return Group(DocOfStr(open) +
Nest(open.size(), Join(body, DocOfStr(",") + Endl())) +
DocOfStr(close));
}
//! \brief Either a space or a new line
inline Doc Sep() {
return Choose(DocOfStr(" "), Endl());
}
/*! \brief Layout a document to a given width
* \param d the document to render
* \param width the line width
*/
inline RDoc Layout(const Doc& d, size_t width) {
return Best(width, 0, List<best_arg>(best_arg(0, d), List<best_arg>()));
}
} // namespace relay
} // namespace tvm
#endif // TVM_RELAY_IR_DOC_H_
...@@ -190,7 +190,7 @@ TVM_REGISTER_API("relay._make.If").set_body([](TVMArgs args, TVMRetValue *ret) { ...@@ -190,7 +190,7 @@ TVM_REGISTER_API("relay._make.If").set_body([](TVMArgs args, TVMRetValue *ret) {
TVM_STATIC_IR_FUNCTOR_REGISTER(IRPrinter, vtable) TVM_STATIC_IR_FUNCTOR_REGISTER(IRPrinter, vtable)
.set_dispatch<IfNode>([](const IfNode *node, tvm::IRPrinter *p) { .set_dispatch<IfNode>([](const IfNode *node, tvm::IRPrinter *p) {
p->stream << "IfNode(" << node->cond << ", " << node->true_branch p->stream << "IfNode(" << node->cond << ", " << node->true_branch
<< node->false_branch << ")"; << ", " << node->false_branch << ")";
}); });
} // namespace relay } // namespace relay
......
...@@ -8,6 +8,8 @@ ...@@ -8,6 +8,8 @@
#include <tvm/node/ir_functor.h> #include <tvm/node/ir_functor.h>
#include <tvm/relay/expr.h> #include <tvm/relay/expr.h>
#include <tvm/relay/error.h>
#include <string>
namespace tvm { namespace tvm {
namespace relay { namespace relay {
...@@ -68,7 +70,7 @@ class TypeFunctor<R(const Type& n, Args...)> { ...@@ -68,7 +70,7 @@ class TypeFunctor<R(const Type& n, Args...)> {
virtual R VisitTypeDefault_(const Node* op, Args...) { virtual R VisitTypeDefault_(const Node* op, Args...) {
LOG(FATAL) << "Do not have a default for " << op->type_key(); LOG(FATAL) << "Do not have a default for " << op->type_key();
return R(); throw; // unreachable, written to stop compiler warning
} }
private: private:
......
import tvm
from tvm import relay
from tvm.relay.expr import debug_print
from tvm.relay.ir_builder import IRBuilder
ib = IRBuilder()
def show(e):
r = debug_print(ib.env, e)
assert r is not None
# print(r) # uncomment this line to debug
def test_constant():
arr = tvm.nd.array(10)
const = relay.Constant(arr)
show(const)
# should print the array inside?
def test_tuple():
fields = tvm.convert([])
tup = relay.Tuple(fields)
show(tup)
def test_local_var():
name_hint = 's'
lv = relay.Var(name_hint)
show(lv)
def test_dup_var():
lv = relay.Var('s')
rv = relay.Var('s')
show(relay.Tuple([lv, rv]))
def test_large_dup_var():
av = relay.Var('s')
bv = relay.Var('s')
cv = relay.Var('s')
show(relay.Tuple([av, bv, cv]))
def test_global_var():
name_hint = 'g'
gv = relay.GlobalVar(name_hint)
gv.name_hint == name_hint
show(gv)
def test_param():
lv = relay.Var('x')
ty = None
param = relay.Param(lv, ty)
show(lv)
def test_function():
param_names = ['a', 'b', 'c', 'd']
params = tvm.convert([relay.Param(relay.Var(n), None) for n in param_names])
ret_type = None
body = params[0].var
type_params = tvm.convert([])
fn = relay.Function(params, ret_type, body, type_params)
show(fn)
def test_call():
op = relay.Var('f')
arg_names = ['a', 'b', 'c', 'd']
args = tvm.convert([relay.Var(n) for n in arg_names])
call = relay.Call(op, args, None, None)
show(call)
def test_let():
lv = relay.Var('x')
ty = relay.ty.TensorType((10, 20), "float32")
arr = tvm.nd.array(10)
value = relay.Constant(arr)
let = relay.Let(lv, value, lv, ty)
show(let)
def test_if():
cond = relay.Var('cond')
left = relay.Var('left')
right = relay.Var('right')
ife = relay.If(cond, left, right)
show(ife)
Markdown is supported
0% or
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment