Unverified Commit a2edd01b by Zhi Committed by GitHub

relay::StructuralHash to tvm::StructuralHash (#5166)

parent 919ae889
......@@ -225,28 +225,6 @@ TVM_DLL Map<Expr, Integer> CollectDeviceAnnotationOps(const Expr& expr);
*/
TVM_DLL Array<Pattern> UnmatchedCases(const Match& match, const IRModule& mod);
/*! \brief A hashing structure in the style of std::hash. */
struct StructuralHash {
/*! \brief Hash a Relay type.
*
* Implements structural hashing of a Relay type.
*
* \param type the type to hash.
*
* \return the hash value.
*/
size_t operator()(const Type& type) const;
/*! \brief Hash a Relay expression.
*
* Implements structural hashing of a Relay expression.
*
* \param expr the expression to hash.
*
* \return the hash value.
*/
size_t operator()(const Expr& expr) const;
};
} // namespace relay
} // namespace tvm
......
......@@ -20,11 +20,10 @@
This file contains the set of passes for Relay, which exposes an interface for
configuring the passes and scripting them in Python.
"""
from tvm.ir import RelayExpr, IRModule
from tvm.ir import IRModule
from . import _ffi_api
from .feature import Feature
from ..ty import Type
def post_order_visit(expr, fvisit):
......@@ -314,29 +313,6 @@ def detect_feature(a, b=None):
return {Feature(int(x)) for x in _ffi_api.detect_feature(a, b)}
def structural_hash(value):
"""Hash a Relay expression structurally.
Parameters
----------
expr : Union[tvm.relay.Expr, tvm.relay.Type]
The expression to hash.
Returns
-------
result : int
The hash value
"""
if isinstance(value, RelayExpr):
return int(_ffi_api._expr_hash(value))
elif isinstance(value, Type):
return int(_ffi_api._type_hash(value))
else:
msg = ("found value of type {0} expected" +
"relay.Expr or relay.Type").format(type(value))
raise TypeError(msg)
def extract_fused_functions(mod):
"""Pass to extract IRModule of only fused primitive functions.
......
......@@ -27,7 +27,7 @@ import tvm
from tvm.ir import IRModule
from tvm.relay.prelude import Prelude
from tvm.relay.analysis import structural_hash as s_hash
from tvm.ir import structural_hash as s_hash
from .. import analysis
from .. import expr as _expr
......
......@@ -238,7 +238,7 @@ class PythonConverter(ExprFunctor):
# compile the function and register globally
cc_key = compile_engine.CCacheKey(op, self.tgt)
func_hash = relay.analysis.structural_hash(op)
func_hash = tvm.ir.structural_hash(op)
op_name = '_lowered_op_{}'.format(func_hash)
if not tvm.get_global_func(op_name, allow_missing=True):
jitted = self.engine.jit(cc_key, self.tgt)
......
......@@ -21,6 +21,7 @@
* \file extract_fused_functions.cc
* \brief Apply fusion and extract fused primitive functions from an IRModule
*/
#include <tvm/node/structural_hash.h>
#include <tvm/relay/analysis.h>
#include <tvm/relay/expr.h>
#include <tvm/relay/expr_functor.h>
......@@ -55,7 +56,7 @@ class FusedFunctionExtractorWrapper : private ExprVisitor {
if (n->HasNonzeroAttr(attr::kPrimitive)) {
// Add function to functions, keyed by function hash string
Function func = Function(n->params, n->body, n->ret_type, n->type_params, n->attrs);
size_t hash_ = StructuralHash()(func);
size_t hash_ = tvm::StructuralHash()(func);
this->functions.Set(std::to_string(hash_), func);
}
......
......@@ -26,6 +26,7 @@
#define TVM_RELAY_BACKEND_COMPILE_ENGINE_H_
#include <tvm/node/structural_equal.h>
#include <tvm/node/structural_hash.h>
#include <tvm/tir/lowered_func.h>
#include <tvm/runtime/module.h>
#include <tvm/relay/analysis.h>
......@@ -258,7 +259,7 @@ bool IsDynamic(const Type& ty);
inline size_t CCacheKeyNode::Hash() const {
if (hash_ != 0) return hash_;
// do structral hash, avoid 0.
hash_ = StructuralHash()(this->source_func);
hash_ = tvm::StructuralHash()(this->source_func);
hash_ = dmlc::HashCombine(
hash_, std::hash<std::string>()(target->str()));
if (hash_ == 0) hash_ = 1;
......
......@@ -23,6 +23,7 @@
*/
#include <tvm/node/structural_equal.h>
#include <tvm/node/structural_hash.h>
#include <tvm/relay/expr.h>
#include <tvm/relay/expr_functor.h>
#include <tvm/support/logging.h>
......@@ -39,7 +40,7 @@ namespace relay {
namespace vm {
inline std::string GenerateName(const Function& func) {
size_t hash = StructuralHash()(func);
size_t hash = tvm::StructuralHash()(func);
return std::string("lifted_name") + std::to_string(hash);
}
......
......@@ -31,7 +31,8 @@ def alpha_equal(x, y):
"""
x = x['main']
y = y['main']
return tvm.ir.structural_equal(x, y) and analysis.structural_hash(x) == analysis.structural_hash(y)
return tvm.ir.structural_equal(x, y) and \
tvm.ir.structural_hash(x) == tvm.ir.structural_hash(y)
def run_opt_pass(expr, passes):
passes = passes if isinstance(passes, list) else [passes]
......
Markdown is supported
0% or
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment