Unverified Commit 2c1ca60e by masahi Committed by GitHub

add memoized expr translator for use by backend codegen (#5325)

parent 0ab18036
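This change factors the memoization that several backends duplicated on top of ExprFunctor into a single helper, backend::MemoizedExprTranslator (see the utils.h hunk at the end of this diff), and moves ScheduleGetter, MakeShapeFunc, CodegenC, CodegenDNNL, and GraphRuntimeCodegen onto it. Below is a minimal sketch of how a backend translator can use the helper; the class MyTranslator, its std::vector<std::string> output type, and the include path are illustrative assumptions, not identifiers added by this commit.

// Sketch only: a hypothetical backend translator built on the new helper.
// "MyTranslator" and its string output are illustrative, not part of this commit.
#include <string>
#include <vector>

#include <tvm/relay/expr_functor.h>

#include "utils.h"  // assumed include path for backend::MemoizedExprTranslator

namespace tvm {
namespace relay {

class MyTranslator : public backend::MemoizedExprTranslator<std::vector<std::string>> {
 public:
  std::vector<std::string> VisitExpr_(const VarNode* op) final {
    // Leaf case: emit one entry per variable.
    return {"var:" + std::string(op->name_hint())};
  }

  std::vector<std::string> VisitExpr_(const CallNode* op) final {
    std::vector<std::string> out;
    for (const auto& arg : op->args) {
      // The inherited VisitExpr consults memo_ before dispatching, so an
      // argument shared by several calls is translated only once.
      for (const auto& piece : VisitExpr(arg)) {
        out.push_back(piece);
      }
    }
    out.push_back("call");
    return out;
  }

  std::vector<std::string> VisitExprDefault_(const Object* op) final {
    LOG(FATAL) << "Node type not handled by this sketch: " << op->GetTypeKey();
    return {};
  }
};

}  // namespace relay
}  // namespace tvm

A deriving translator implements only the VisitExpr_ overloads it cares about; the inherited VisitExpr checks memo_ first, so a sub-expression with several consumers is translated exactly once. MakeShapeFunc below still overrides VisitExpr to skip memoizing vars, since a shape function may need either the data or the shape of the same var.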
@@ -21,29 +21,31 @@
* \file relay/backend/compile_engine.cc
* \brief Internal compilation engine.
*/
#include "compile_engine.h"
#include <topi/tags.h>
#include <tvm/driver/driver_api.h>
#include <tvm/ir/type_functor.h>
#include <tvm/te/schedule.h>
#include <tvm/te/operation.h>
#include <tvm/te/schedule_pass.h>
#include <tvm/runtime/registry.h>
#include <tvm/runtime/container.h>
#include <tvm/relay/attrs/device_copy.h>
#include <tvm/relay/analysis.h>
#include <tvm/relay/attrs/device_copy.h>
#include <tvm/relay/expr.h>
#include <tvm/relay/expr_functor.h>
#include <tvm/relay/op.h>
#include <tvm/relay/op_attr_types.h>
#include <tvm/driver/driver_api.h>
#include <tvm/runtime/container.h>
#include <tvm/runtime/registry.h>
#include <tvm/te/operation.h>
#include <tvm/te/schedule.h>
#include <tvm/te/schedule_pass.h>
#include <topi/tags.h>
#include <utility>
#include <functional>
#include <limits>
#include <mutex>
#include <functional>
#include <vector>
#include <unordered_map>
#include <utility>
#include <vector>
#include "compile_engine.h"
#include "utils.h"
namespace tvm {
namespace relay {
@@ -111,8 +113,7 @@ Array<IndexExpr> GetShape(const Array<IndexExpr>& shape) {
// The getter to get schedule from compile engine.
// Get schedule from functor.
class ScheduleGetter :
public ExprFunctor<Array<te::Tensor>(const Expr&)> {
class ScheduleGetter : public backend::MemoizedExprTranslator<Array<te::Tensor>> {
public:
explicit ScheduleGetter(Target target)
: target_(target), device_copy_op_(Op::Get("device_copy")) {}
@@ -179,17 +180,6 @@ class ScheduleGetter :
return CachedFunc(cache_node);
}
Array<te::Tensor> VisitExpr(const Expr& expr) {
auto it = memo_.find(expr);
if (it != memo_.end()) {
return it->second;
} else {
Array<te::Tensor> res = ExprFunctor::VisitExpr(expr);
memo_[expr] = res;
return res;
}
}
Array<te::Tensor> VisitExpr_(const VarNode* op) final {
LOG(FATAL) << "Free variable " << op->name_hint();
return {};
@@ -327,7 +317,6 @@ class ScheduleGetter :
int master_op_pattern_{0};
OpImplementation master_implementation_;
std::ostringstream readable_name_stream_;
std::unordered_map<Expr, Array<te::Tensor>, ObjectHash, ObjectEqual> memo_;
Array<te::Operation> scalars_;
// Cache device copy op for equivalence checking to reduce registry lookup
// overhead for each invocation of call node when retrieving schedules.
@@ -335,7 +324,7 @@ class ScheduleGetter :
};
// Creates shape function from functor.
class MakeShapeFunc : public ExprFunctor<Array<te::Tensor>(const Expr&)> {
class MakeShapeFunc : public backend::MemoizedExprTranslator<Array<te::Tensor>> {
public:
MakeShapeFunc() {}
@@ -422,19 +411,14 @@ class MakeShapeFunc : public ExprFunctor<Array<te::Tensor>(const Expr&)> {
return std::make_pair(schedule, cfunc);
}
Array<te::Tensor> VisitExpr(const Expr& expr) {
auto it = memo_.find(expr);
if (it != memo_.end()) {
return it->second;
} else {
Array<te::Tensor> res = ExprFunctor::VisitExpr(expr);
if (expr.as<VarNode>() == nullptr) {
// Do not memoize vars because shape functions could use either the data
// or the shape of a var each time.
memo_[expr] = res;
}
return res;
Array<te::Tensor> VisitExpr(const Expr& expr) final {
if (expr.as<VarNode>()) {
// Do not memoize vars because shape functions could use either the data
// or the shape of a var each time.
return ExprFunctor::VisitExpr(expr);
}
// For other cases, do the memoized visit
return backend::MemoizedExprTranslator<Array<te::Tensor>>::VisitExpr(expr);
}
Array<te::Tensor> VisitExpr_(const VarNode* var_node) final {
@@ -577,8 +561,6 @@ class MakeShapeFunc : public ExprFunctor<Array<te::Tensor>(const Expr&)> {
std::unordered_map<Expr, Array<te::Tensor>, ObjectHash, ObjectEqual> param_data_;
/*! \brief Map from parameter to list of shape placeholder */
std::unordered_map<Expr, Array<te::Tensor>, ObjectHash, ObjectEqual> param_shapes_;
/*! \brief Memoized visit result */
std::unordered_map<Expr, Array<te::Tensor>, ObjectHash, ObjectEqual> memo_;
/*! \brief Stack of data dependencies for shape function */
std::vector<bool> data_dependants_;
/*! \brief Scalars used in the shape function */
@@ -40,18 +40,10 @@ using namespace backend;
* purpose. Only a few binary operators are covered. Users
* may need to extend them to cover more operators.
*/
class CodegenC : public ExprFunctor<std::vector<Output>(const Expr&)>,
public CodegenCBase {
class CodegenC : public MemoizedExprTranslator<std::vector<Output>>, public CodegenCBase {
public:
explicit CodegenC(const std::string& id) { this->ext_func_id_ = id; }
std::vector<Output> VisitExpr(const Expr& expr) final {
if (visited_.count(expr)) return visited_.at(expr);
std::vector<Output> output = ExprFunctor::VisitExpr(expr);
visited_[expr] = output;
return output;
}
std::vector<Output> VisitExprDefault_(const Object* op) final {
LOG(FATAL) << "C codegen doesn't support: " << op->GetTypeKey();
return {};
@@ -208,8 +200,6 @@ class CodegenC : public ExprFunctor<std::vector<Output>(const Expr&)>,
std::vector<std::string> func_decl_;
/*! \brief The declaration statements of buffers. */
std::vector<std::string> buf_decl_;
/*! \brief The name and index pairs for output. */
std::unordered_map<Expr, std::vector<Output>, ObjectHash, ObjectEqual> visited_;
};
class CSourceCodegen : public CSourceModuleCodegenBase {
@@ -128,18 +128,10 @@ std::vector<std::string> Add(const CallNode* call) {
// TODO(@zhiics, @comaniac): This is a basic implementation. We should implement
// all utilities and make a base class for users to implement.
class CodegenDNNL : public ExprFunctor<std::vector<Output>(const Expr&)>,
public CodegenCBase {
class CodegenDNNL : public MemoizedExprTranslator<std::vector<Output>>, public CodegenCBase {
public:
explicit CodegenDNNL(const std::string& id) { this->ext_func_id_ = id; }
std::vector<Output> VisitExpr(const Expr& expr) final {
if (visited_.count(expr)) return visited_.at(expr);
std::vector<Output> output = ExprFunctor::VisitExpr(expr);
visited_[expr] = output;
return output;
}
std::vector<Output> VisitExprDefault_(const Object* op) final {
LOG(FATAL) << "DNNL codegen doesn't support: " << op->GetTypeKey();
return {};
@@ -343,8 +335,6 @@ class CodegenDNNL : public ExprFunctor<std::vector<Output>(const Expr&)>,
std::vector<std::string> ext_func_body;
/*! \brief The declaration of intermediate buffers. */
std::vector<std::string> buf_decl_;
/*! \brief The cached expressions. */
std::unordered_map<Expr, std::vector<Output>, ObjectHash, ObjectEqual> visited_;
};
/*!
@@ -28,13 +28,12 @@
#include <tvm/relay/expr_functor.h>
#include <tvm/runtime/device_api.h>
#include <list>
#include <string>
#include <vector>
#include "utils.h"
#include "compile_engine.h"
#include "utils.h"
namespace tvm {
namespace relay {
@@ -190,11 +189,9 @@ class GraphOpNode : public GraphNode {
};
/*! \brief Code generator for graph runtime */
class GraphRuntimeCodegen
: public ::tvm::relay::ExprFunctor<std::vector<GraphNodeRef>(const Expr&)> {
class GraphRuntimeCodegen : public backend::MemoizedExprTranslator<std::vector<GraphNodeRef>> {
public:
GraphRuntimeCodegen(runtime::Module* mod, const TargetsMap& targets)
: mod_(mod) {
GraphRuntimeCodegen(runtime::Module* mod, const TargetsMap& targets) : mod_(mod) {
compile_engine_ = CompileEngine::Global();
targets_ = targets;
}
@@ -313,47 +310,6 @@ class GraphRuntimeCodegen
return {GraphNodeRef(node_id, 0)};
}
/*! \brief Visitors */
std::unordered_map<Expr, std::vector<GraphNodeRef>, ObjectHash, ObjectEqual> visitor_cache_;
std::vector<GraphNodeRef> VisitExpr(const Expr& expr) override {
if (visitor_cache_.count(expr)) return visitor_cache_.at(expr);
std::vector<GraphNodeRef> res;
if (expr.as<ConstantNode>()) {
res = VisitExpr_(expr.as<ConstantNode>());
} else if (expr.as<TupleNode>()) {
res = VisitExpr_(expr.as<TupleNode>());
} else if (expr.as<VarNode>()) {
res = VisitExpr_(expr.as<VarNode>());
} else if (expr.as<GlobalVarNode>()) {
res = VisitExpr_(expr.as<GlobalVarNode>());
} else if (expr.as<FunctionNode>()) {
res = VisitExpr_(expr.as<FunctionNode>());
} else if (expr.as<CallNode>()) {
res = VisitExpr_(expr.as<CallNode>());
} else if (expr.as<LetNode>()) {
res = VisitExpr_(expr.as<LetNode>());
} else if (expr.as<IfNode>()) {
res = VisitExpr_(expr.as<IfNode>());
} else if (expr.as<OpNode>()) {
res = VisitExpr_(expr.as<OpNode>());
} else if (expr.as<TupleGetItemNode>()) {
res = VisitExpr_(expr.as<TupleGetItemNode>());
} else if (expr.as<RefCreateNode>()) {
res = VisitExpr_(expr.as<RefCreateNode>());
} else if (expr.as<RefReadNode>()) {
res = VisitExpr_(expr.as<RefReadNode>());
} else if (expr.as<RefWriteNode>()) {
res = VisitExpr_(expr.as<RefWriteNode>());
} else if (expr.as<ConstructorNode>()) {
res = VisitExpr_(expr.as<ConstructorNode>());
} else if (expr.as<MatchNode>()) {
res = VisitExpr_(expr.as<MatchNode>());
}
visitor_cache_[expr] = res;
return res;
}
std::vector<GraphNodeRef> VisitExpr_(const VarNode* op) override {
Expr expr = GetRef<Expr>(op);
return var_map_[expr.get()];
@@ -244,11 +244,6 @@ class Interpreter :
return VisitExpr(expr);
}
ObjectRef VisitExpr(const Expr& expr) final {
auto ret = ExprFunctor<ObjectRef(const Expr& n)>::VisitExpr(expr);
return ret;
}
ObjectRef VisitExpr_(const VarNode* var_node) final {
return Lookup(GetRef<Var>(var_node));
}
@@ -27,6 +27,7 @@
#include <dmlc/json.h>
#include <tvm/driver/driver_api.h>
#include <tvm/relay/expr.h>
#include <tvm/relay/expr_functor.h>
#include <tvm/relay/transform.h>
#include <tvm/relay/type.h>
#include <tvm/target/codegen.h>
@@ -42,6 +43,40 @@
namespace tvm {
namespace relay {
namespace backend {
/*!
* \brief A simple wrapper around ExprFunctor for a single argument case.
* The result of visit is memoized.
*/
template <typename OutputType>
class MemoizedExprTranslator : public ::tvm::relay::ExprFunctor<OutputType(const Expr&)> {
using BaseFunctor = ::tvm::relay::ExprFunctor<OutputType(const Expr&)>;
public:
/*! \brief virtual destructor */
virtual ~MemoizedExprTranslator() {}
/*!
* \brief The memoized call.
* \param n The expression node.
* \return The result of the call
*/
virtual OutputType VisitExpr(const Expr& n) {
CHECK(n.defined());
auto it = memo_.find(n);
if (it != memo_.end()) {
return it->second;
}
auto res = BaseFunctor::VisitExpr(n);
memo_[n] = res;
return res;
}
protected:
/*! \brief Internal map used for memoization. */
std::unordered_map<Expr, OutputType, ObjectHash, ObjectEqual> memo_;
};
/*!
* \brief Get the Packed Func
*