Unverified Commit 2c1ca60e by masahi Committed by GitHub

add memoized expr translator for use by backend codegen (#5325)

parent 0ab18036
...@@ -21,29 +21,31 @@ ...@@ -21,29 +21,31 @@
* \file relay/backend/compile_engine.cc * \file relay/backend/compile_engine.cc
* \brief Internal compialtion engine. * \brief Internal compialtion engine.
*/ */
#include "compile_engine.h"
#include <topi/tags.h>
#include <tvm/driver/driver_api.h>
#include <tvm/ir/type_functor.h> #include <tvm/ir/type_functor.h>
#include <tvm/te/schedule.h>
#include <tvm/te/operation.h>
#include <tvm/te/schedule_pass.h>
#include <tvm/runtime/registry.h>
#include <tvm/runtime/container.h>
#include <tvm/relay/attrs/device_copy.h>
#include <tvm/relay/analysis.h> #include <tvm/relay/analysis.h>
#include <tvm/relay/attrs/device_copy.h>
#include <tvm/relay/expr.h> #include <tvm/relay/expr.h>
#include <tvm/relay/expr_functor.h> #include <tvm/relay/expr_functor.h>
#include <tvm/relay/op.h> #include <tvm/relay/op.h>
#include <tvm/relay/op_attr_types.h> #include <tvm/relay/op_attr_types.h>
#include <tvm/driver/driver_api.h> #include <tvm/runtime/container.h>
#include <tvm/runtime/registry.h>
#include <tvm/te/operation.h>
#include <tvm/te/schedule.h>
#include <tvm/te/schedule_pass.h>
#include <topi/tags.h> #include <functional>
#include <utility>
#include <limits> #include <limits>
#include <mutex> #include <mutex>
#include <functional>
#include <vector>
#include <unordered_map> #include <unordered_map>
#include <utility>
#include <vector>
#include "compile_engine.h" #include "utils.h"
namespace tvm { namespace tvm {
namespace relay { namespace relay {
...@@ -111,8 +113,7 @@ Array<IndexExpr> GetShape(const Array<IndexExpr>& shape) { ...@@ -111,8 +113,7 @@ Array<IndexExpr> GetShape(const Array<IndexExpr>& shape) {
// The getter to get schedule from compile engine. // The getter to get schedule from compile engine.
// Get schedule from functor. // Get schedule from functor.
class ScheduleGetter : class ScheduleGetter : public backend::MemoizedExprTranslator<Array<te::Tensor>> {
public ExprFunctor<Array<te::Tensor>(const Expr&)> {
public: public:
explicit ScheduleGetter(Target target) explicit ScheduleGetter(Target target)
: target_(target), device_copy_op_(Op::Get("device_copy")) {} : target_(target), device_copy_op_(Op::Get("device_copy")) {}
...@@ -179,17 +180,6 @@ class ScheduleGetter : ...@@ -179,17 +180,6 @@ class ScheduleGetter :
return CachedFunc(cache_node); return CachedFunc(cache_node);
} }
Array<te::Tensor> VisitExpr(const Expr& expr) {
auto it = memo_.find(expr);
if (it != memo_.end()) {
return it->second;
} else {
Array<te::Tensor> res = ExprFunctor::VisitExpr(expr);
memo_[expr] = res;
return res;
}
}
Array<te::Tensor> VisitExpr_(const VarNode* op) final { Array<te::Tensor> VisitExpr_(const VarNode* op) final {
LOG(FATAL) << "Free variable " << op->name_hint(); LOG(FATAL) << "Free variable " << op->name_hint();
return {}; return {};
...@@ -327,7 +317,6 @@ class ScheduleGetter : ...@@ -327,7 +317,6 @@ class ScheduleGetter :
int master_op_pattern_{0}; int master_op_pattern_{0};
OpImplementation master_implementation_; OpImplementation master_implementation_;
std::ostringstream readable_name_stream_; std::ostringstream readable_name_stream_;
std::unordered_map<Expr, Array<te::Tensor>, ObjectHash, ObjectEqual> memo_;
Array<te::Operation> scalars_; Array<te::Operation> scalars_;
// Cache device copy op for equivalence checking to reduce registry lookup // Cache device copy op for equivalence checking to reduce registry lookup
// overhead for each invocation of call node when retrieving schedules. // overhead for each invocation of call node when retrieving schedules.
...@@ -335,7 +324,7 @@ class ScheduleGetter : ...@@ -335,7 +324,7 @@ class ScheduleGetter :
}; };
// Creates shape function from functor. // Creates shape function from functor.
class MakeShapeFunc : public ExprFunctor<Array<te::Tensor>(const Expr&)> { class MakeShapeFunc : public backend::MemoizedExprTranslator<Array<te::Tensor>> {
public: public:
MakeShapeFunc() {} MakeShapeFunc() {}
...@@ -422,19 +411,14 @@ class MakeShapeFunc : public ExprFunctor<Array<te::Tensor>(const Expr&)> { ...@@ -422,19 +411,14 @@ class MakeShapeFunc : public ExprFunctor<Array<te::Tensor>(const Expr&)> {
return std::make_pair(schedule, cfunc); return std::make_pair(schedule, cfunc);
} }
Array<te::Tensor> VisitExpr(const Expr& expr) { Array<te::Tensor> VisitExpr(const Expr& expr) final {
auto it = memo_.find(expr); if (expr.as<VarNode>()) {
if (it != memo_.end()) { // Do not memoize vars because shape functions could use either the data
return it->second; // or the shape of a var each time.
} else { return ExprFunctor::VisitExpr(expr);
Array<te::Tensor> res = ExprFunctor::VisitExpr(expr);
if (expr.as<VarNode>() == nullptr) {
// Do not memoize vars because shape functions could use either the data
// or the shape of a var each time.
memo_[expr] = res;
}
return res;
} }
// For other case, do memoized visit
return backend::MemoizedExprTranslator<Array<te::Tensor>>::VisitExpr(expr);
} }
Array<te::Tensor> VisitExpr_(const VarNode* var_node) final { Array<te::Tensor> VisitExpr_(const VarNode* var_node) final {
...@@ -577,8 +561,6 @@ class MakeShapeFunc : public ExprFunctor<Array<te::Tensor>(const Expr&)> { ...@@ -577,8 +561,6 @@ class MakeShapeFunc : public ExprFunctor<Array<te::Tensor>(const Expr&)> {
std::unordered_map<Expr, Array<te::Tensor>, ObjectHash, ObjectEqual> param_data_; std::unordered_map<Expr, Array<te::Tensor>, ObjectHash, ObjectEqual> param_data_;
/*! \brief Map from parameter to list of shape placeholder */ /*! \brief Map from parameter to list of shape placeholder */
std::unordered_map<Expr, Array<te::Tensor>, ObjectHash, ObjectEqual> param_shapes_; std::unordered_map<Expr, Array<te::Tensor>, ObjectHash, ObjectEqual> param_shapes_;
/*! \brief Memoized visit result */
std::unordered_map<Expr, Array<te::Tensor>, ObjectHash, ObjectEqual> memo_;
/*! \brief Stack of data dependencies for shape function */ /*! \brief Stack of data dependencies for shape function */
std::vector<bool> data_dependants_; std::vector<bool> data_dependants_;
/*! \brief Scalars used in the shape function */ /*! \brief Scalars used in the shape function */
......
...@@ -40,18 +40,10 @@ using namespace backend; ...@@ -40,18 +40,10 @@ using namespace backend;
* purpose. Only several binary options are covered. Users * purpose. Only several binary options are covered. Users
* may need to extend them to cover more operators. * may need to extend them to cover more operators.
*/ */
class CodegenC : public ExprFunctor<std::vector<Output>(const Expr&)>, class CodegenC : public MemoizedExprTranslator<std::vector<Output>>, public CodegenCBase {
public CodegenCBase {
public: public:
explicit CodegenC(const std::string& id) { this->ext_func_id_ = id; } explicit CodegenC(const std::string& id) { this->ext_func_id_ = id; }
std::vector<Output> VisitExpr(const Expr& expr) final {
if (visited_.count(expr)) return visited_.at(expr);
std::vector<Output> output = ExprFunctor::VisitExpr(expr);
visited_[expr] = output;
return output;
}
std::vector<Output> VisitExprDefault_(const Object* op) final { std::vector<Output> VisitExprDefault_(const Object* op) final {
LOG(FATAL) << "C codegen doesn't support: " << op->GetTypeKey(); LOG(FATAL) << "C codegen doesn't support: " << op->GetTypeKey();
return {}; return {};
...@@ -208,8 +200,6 @@ class CodegenC : public ExprFunctor<std::vector<Output>(const Expr&)>, ...@@ -208,8 +200,6 @@ class CodegenC : public ExprFunctor<std::vector<Output>(const Expr&)>,
std::vector<std::string> func_decl_; std::vector<std::string> func_decl_;
/*! \brief The declaration statements of buffers. */ /*! \brief The declaration statements of buffers. */
std::vector<std::string> buf_decl_; std::vector<std::string> buf_decl_;
/*! \brief The name and index pairs for output. */
std::unordered_map<Expr, std::vector<Output>, ObjectHash, ObjectEqual> visited_;
}; };
class CSourceCodegen : public CSourceModuleCodegenBase { class CSourceCodegen : public CSourceModuleCodegenBase {
......
...@@ -128,18 +128,10 @@ std::vector<std::string> Add(const CallNode* call) { ...@@ -128,18 +128,10 @@ std::vector<std::string> Add(const CallNode* call) {
// TODO(@zhiics, @comaniac): This is a basic implementation. We should implement // TODO(@zhiics, @comaniac): This is a basic implementation. We should implement
// all utilities and make a base class for users to implement. // all utilities and make a base class for users to implement.
class CodegenDNNL : public ExprFunctor<std::vector<Output>(const Expr&)>, class CodegenDNNL : public MemoizedExprTranslator<std::vector<Output>>, public CodegenCBase {
public CodegenCBase {
public: public:
explicit CodegenDNNL(const std::string& id) { this->ext_func_id_ = id; } explicit CodegenDNNL(const std::string& id) { this->ext_func_id_ = id; }
std::vector<Output> VisitExpr(const Expr& expr) final {
if (visited_.count(expr)) return visited_.at(expr);
std::vector<Output> output = ExprFunctor::VisitExpr(expr);
visited_[expr] = output;
return output;
}
std::vector<Output> VisitExprDefault_(const Object* op) final { std::vector<Output> VisitExprDefault_(const Object* op) final {
LOG(FATAL) << "DNNL codegen doesn't support: " << op->GetTypeKey(); LOG(FATAL) << "DNNL codegen doesn't support: " << op->GetTypeKey();
return {}; return {};
...@@ -343,8 +335,6 @@ class CodegenDNNL : public ExprFunctor<std::vector<Output>(const Expr&)>, ...@@ -343,8 +335,6 @@ class CodegenDNNL : public ExprFunctor<std::vector<Output>(const Expr&)>,
std::vector<std::string> ext_func_body; std::vector<std::string> ext_func_body;
/*! \brief The declaration of intermeidate buffers. */ /*! \brief The declaration of intermeidate buffers. */
std::vector<std::string> buf_decl_; std::vector<std::string> buf_decl_;
/*! \brief The cached expressions. */
std::unordered_map<Expr, std::vector<Output>, ObjectHash, ObjectEqual> visited_;
}; };
/*! /*!
......
...@@ -28,13 +28,12 @@ ...@@ -28,13 +28,12 @@
#include <tvm/relay/expr_functor.h> #include <tvm/relay/expr_functor.h>
#include <tvm/runtime/device_api.h> #include <tvm/runtime/device_api.h>
#include <list> #include <list>
#include <string> #include <string>
#include <vector> #include <vector>
#include "utils.h"
#include "compile_engine.h" #include "compile_engine.h"
#include "utils.h"
namespace tvm { namespace tvm {
namespace relay { namespace relay {
...@@ -190,11 +189,9 @@ class GraphOpNode : public GraphNode { ...@@ -190,11 +189,9 @@ class GraphOpNode : public GraphNode {
}; };
/*! \brief Code generator for graph runtime */ /*! \brief Code generator for graph runtime */
class GraphRuntimeCodegen class GraphRuntimeCodegen : public backend::MemoizedExprTranslator<std::vector<GraphNodeRef>> {
: public ::tvm::relay::ExprFunctor<std::vector<GraphNodeRef>(const Expr&)> {
public: public:
GraphRuntimeCodegen(runtime::Module* mod, const TargetsMap& targets) GraphRuntimeCodegen(runtime::Module* mod, const TargetsMap& targets) : mod_(mod) {
: mod_(mod) {
compile_engine_ = CompileEngine::Global(); compile_engine_ = CompileEngine::Global();
targets_ = targets; targets_ = targets;
} }
...@@ -313,47 +310,6 @@ class GraphRuntimeCodegen ...@@ -313,47 +310,6 @@ class GraphRuntimeCodegen
return {GraphNodeRef(node_id, 0)}; return {GraphNodeRef(node_id, 0)};
} }
/*! \brief Visitors */
std::unordered_map<Expr, std::vector<GraphNodeRef>, ObjectHash, ObjectEqual> visitor_cache_;
std::vector<GraphNodeRef> VisitExpr(const Expr& expr) override {
if (visitor_cache_.count(expr)) return visitor_cache_.at(expr);
std::vector<GraphNodeRef> res;
if (expr.as<ConstantNode>()) {
res = VisitExpr_(expr.as<ConstantNode>());
} else if (expr.as<TupleNode>()) {
res = VisitExpr_(expr.as<TupleNode>());
} else if (expr.as<VarNode>()) {
res = VisitExpr_(expr.as<VarNode>());
} else if (expr.as<GlobalVarNode>()) {
res = VisitExpr_(expr.as<GlobalVarNode>());
} else if (expr.as<FunctionNode>()) {
res = VisitExpr_(expr.as<FunctionNode>());
} else if (expr.as<CallNode>()) {
res = VisitExpr_(expr.as<CallNode>());
} else if (expr.as<LetNode>()) {
res = VisitExpr_(expr.as<LetNode>());
} else if (expr.as<IfNode>()) {
res = VisitExpr_(expr.as<IfNode>());
} else if (expr.as<OpNode>()) {
res = VisitExpr_(expr.as<OpNode>());
} else if (expr.as<TupleGetItemNode>()) {
res = VisitExpr_(expr.as<TupleGetItemNode>());
} else if (expr.as<RefCreateNode>()) {
res = VisitExpr_(expr.as<RefCreateNode>());
} else if (expr.as<RefReadNode>()) {
res = VisitExpr_(expr.as<RefReadNode>());
} else if (expr.as<RefWriteNode>()) {
res = VisitExpr_(expr.as<RefWriteNode>());
} else if (expr.as<ConstructorNode>()) {
res = VisitExpr_(expr.as<ConstructorNode>());
} else if (expr.as<MatchNode>()) {
res = VisitExpr_(expr.as<MatchNode>());
}
visitor_cache_[expr] = res;
return res;
}
std::vector<GraphNodeRef> VisitExpr_(const VarNode* op) override { std::vector<GraphNodeRef> VisitExpr_(const VarNode* op) override {
Expr expr = GetRef<Expr>(op); Expr expr = GetRef<Expr>(op);
return var_map_[expr.get()]; return var_map_[expr.get()];
......
...@@ -244,11 +244,6 @@ class Interpreter : ...@@ -244,11 +244,6 @@ class Interpreter :
return VisitExpr(expr); return VisitExpr(expr);
} }
ObjectRef VisitExpr(const Expr& expr) final {
auto ret = ExprFunctor<ObjectRef(const Expr& n)>::VisitExpr(expr);
return ret;
}
ObjectRef VisitExpr_(const VarNode* var_node) final { ObjectRef VisitExpr_(const VarNode* var_node) final {
return Lookup(GetRef<Var>(var_node)); return Lookup(GetRef<Var>(var_node));
} }
......
...@@ -27,6 +27,7 @@ ...@@ -27,6 +27,7 @@
#include <dmlc/json.h> #include <dmlc/json.h>
#include <tvm/driver/driver_api.h> #include <tvm/driver/driver_api.h>
#include <tvm/relay/expr.h> #include <tvm/relay/expr.h>
#include <tvm/relay/expr_functor.h>
#include <tvm/relay/transform.h> #include <tvm/relay/transform.h>
#include <tvm/relay/type.h> #include <tvm/relay/type.h>
#include <tvm/target/codegen.h> #include <tvm/target/codegen.h>
...@@ -42,6 +43,40 @@ ...@@ -42,6 +43,40 @@
namespace tvm { namespace tvm {
namespace relay { namespace relay {
namespace backend { namespace backend {
/*!
* \brief A simple wrapper around ExprFunctor for a single argument case.
* The result of visit is memoized.
*/
template <typename OutputType>
class MemoizedExprTranslator : public ::tvm::relay::ExprFunctor<OutputType(const Expr&)> {
using BaseFunctor = ::tvm::relay::ExprFunctor<OutputType(const Expr&)>;
public:
/*! \brief virtual destructor */
virtual ~MemoizedExprTranslator() {}
/*!
* \brief The memoized call.
* \param n The expression node.
* \return The result of the call
*/
virtual OutputType VisitExpr(const Expr& n) {
CHECK(n.defined());
auto it = memo_.find(n);
if (it != memo_.end()) {
return it->second;
}
auto res = BaseFunctor::VisitExpr(n);
memo_[n] = res;
return res;
}
protected:
/*! \brief Internal map used for memoization. */
std::unordered_map<Expr, OutputType, ObjectHash, ObjectEqual> memo_;
};
/*! /*!
* \brief Get the Packed Func * \brief Get the Packed Func
* *
......
Markdown is supported
0% or
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment