Commit 3d2a0560 by Steven S. Lyubomirsky Committed by Jared Roesch

[Relay][Module] Make tags for ADT constructors and ConstructorValues more robust (#3369)

* Use hash of ADT name and constructor idx to generate tag, add reverse mapping to module and use where appropriate

* Lint and build fixes

* Add round-tripping test for getting constructors by tag

* Use int64_t everywhere for tags

* Add additional identity check

* Bring out _arg_to_ast again

* Use 8-bit hash of GTV name as MSB of tag, index as LSB for more readable tags

* Use int32 instead of int64 for tag
parent 1efc84a4
......@@ -114,7 +114,7 @@ class ConstructorNode : public ExprNode {
/*! \brief The datatype the constructor will construct. */
GlobalTypeVar belong_to;
/*! \brief Index in the table of constructors (set when the type is registered). */
mutable int tag = -1;
mutable int32_t tag = -1;
ConstructorNode() {}
......
......@@ -182,7 +182,7 @@ RELAY_DEFINE_NODE_REF(RefValue, RefValueNode, Value);
class ConstructorValue;
struct ConstructorValueNode : ValueNode {
int tag;
int32_t tag;
tvm::Array<Value> fields;
......@@ -195,7 +195,7 @@ struct ConstructorValueNode : ValueNode {
v->Visit("constructor", &constructor);
}
TVM_DLL static ConstructorValue make(int tag,
TVM_DLL static ConstructorValue make(int32_t tag,
tvm::Array<Value> fields,
Constructor construtor = {});
......
......@@ -32,6 +32,7 @@
#include <tvm/relay/type.h>
#include <string>
#include <vector>
#include <unordered_map>
namespace tvm {
namespace relay {
......@@ -133,34 +134,41 @@ class ModuleNode : public RelayNode {
TVM_DLL GlobalTypeVar GetGlobalTypeVar(const std::string& str) const;
/*!
* \brief Lookup a global function by its variable.
* \brief Look up a global function by its variable.
* \param var The global var to lookup.
* \returns The function named by the variable argument.
*/
TVM_DLL Function Lookup(const GlobalVar& var) const;
/*!
* \brief Lookup a global function by its string name
* \brief Look up a global function by its string name
* \param name The name of the function.
* \returns The function named by the argument.
*/
TVM_DLL Function Lookup(const std::string& name) const;
/*!
* \brief Lookup a global type definition by its variable.
* \brief Look up a global type definition by its variable.
* \param var The var of the global type definition.
* \return The type definition.
*/
TVM_DLL TypeData LookupDef(const GlobalTypeVar& var) const;
/*!
* \brief Lookup a global type definition by its name.
* \brief Look up a global type definition by its name.
* \param var The name of the global type definition.
* \return The type definition.
*/
TVM_DLL TypeData LookupDef(const std::string& var) const;
/*!
* \brief Look up a constructor by its tag.
*
* Tags are assigned by RegisterConstructors when a type definition is registered
* with the module.
* \param tag The tag for the constructor.
* \return The constructor object.
* \note A CHECK failure is raised when no constructor is registered under the tag.
*/
TVM_DLL Constructor LookupTag(const int32_t tag);
/*!
* \brief Update the functions inside this environment by
* functions in another environment.
* \param other The other environment.
......@@ -185,6 +193,9 @@ class ModuleNode : public RelayNode {
TVM_DECLARE_NODE_TYPE_INFO(ModuleNode, Node);
private:
/*!
* \brief Helper function for registering a typedef's constructors.
*
* Assigns a tag to each constructor of the given type data and records the
* constructor in constructor_tag_map_ under that tag.
* \param var The global type var of the definition; its name hint is hashed to form the tag prefix.
* \param type The type data whose constructors receive tags.
*/
void RegisterConstructors(const GlobalTypeVar& var, const TypeData& type);
/*! \brief A map from string names to global variables that
* ensures global uniqueness.
*/
......@@ -194,6 +205,11 @@ class ModuleNode : public RelayNode {
* that ensures global uniqueness.
*/
tvm::Map<std::string, GlobalTypeVar> global_type_var_map_;
/*! \brief A map from constructor tags to constructor objects
* for convenient access. Entries are written by RegisterConstructors
* and read by LookupTag.
*/
std::unordered_map<int32_t, Constructor> constructor_tag_map_;
};
struct Module : public NodeRef {
......
......@@ -114,17 +114,18 @@ class RefValue(Value):
_make.RefValue, value)
def _arg_to_ast(arg):
def _arg_to_ast(mod, arg):
if isinstance(arg, TensorValue):
return Constant(arg.data.copyto(nd.cpu(0)))
elif isinstance(arg, TupleValue):
return Tuple([_arg_to_ast(field) for field in arg.fields])
return Tuple([_arg_to_ast(mod, field) for field in arg.fields])
elif isinstance(arg, tuple):
return Tuple([_arg_to_ast(field) for field in arg])
return Tuple([_arg_to_ast(mod, field) for field in arg])
elif isinstance(arg, RefValue):
return RefCreate(_arg_to_ast(arg.value))
return RefCreate(_arg_to_ast(mod, arg.value))
elif isinstance(arg, ConstructorValue):
return Call(arg.constructor, [_arg_to_ast(field) for field in arg.fields])
return Call(mod.get_constructor(arg.tag),
[_arg_to_ast(mod, field) for field in arg.fields])
elif isinstance(arg, np.ndarray):
return Constant(nd.array(arg))
elif isinstance(arg, Constant):
......@@ -231,7 +232,7 @@ class Executor(object):
if binds:
scope_builder = ScopeBuilder()
for key, value in binds.items():
scope_builder.let(key, _arg_to_ast(value))
scope_builder.let(key, _arg_to_ast(self.mod, value))
scope_builder.ret(expr)
expr = scope_builder.get()
......@@ -294,7 +295,7 @@ class Interpreter(Executor):
relay_args = []
for arg in args:
relay_args.append(_arg_to_ast(arg))
relay_args.append(_arg_to_ast(self.mod, arg))
# Set the entry function for the module.
if expr is None:
......
......@@ -156,6 +156,25 @@ class Module(RelayNode):
"""
return _module.Module_GetGlobalTypeVar(self, name)
def get_constructor(self, tag):
    """Fetch the ADT constructor registered under a tag.

    Parameters
    ----------
    tag: int
        The tag identifying the constructor.

    Returns
    -------
    constructor: Constructor
        The constructor associated with the given tag.

    Raises
    ------
    tvm.TVMError
        If the corresponding constructor cannot be found.
    """
    constructor = _module.Module_LookupTag(self, tag)
    return constructor
@staticmethod
def from_expr(expr):
    """Build a module from an expression.

    Delegates to the ``Module_FromExpr`` packed function.

    Parameters
    ----------
    expr: Expr
        The expression to build the module from.

    Returns
    -------
    mod: Module
        The module produced by ``Module_FromExpr``.
    """
    module = _module.Module_FromExpr(expr)
    return module
......@@ -103,7 +103,7 @@ TVM_STATIC_IR_FUNCTOR_REGISTER(IRPrinter, vtable)
p->stream << "RefValueNode(" << node->value << ")";
});
ConstructorValue ConstructorValueNode::make(int tag,
ConstructorValue ConstructorValueNode::make(int32_t tag,
tvm::Array<Value> fields,
Constructor constructor) {
NodePtr<ConstructorValueNode> n = make_node<ConstructorValueNode>();
......
......@@ -53,6 +53,7 @@ Module ModuleNode::make(tvm::Map<GlobalVar, Function> global_funcs,
CHECK(!n->global_type_var_map_.count(kv.first->var->name_hint))
<< "Duplicate global type definition name " << kv.first->var->name_hint;
n->global_type_var_map_.Set(kv.first->var->name_hint, kv.first);
n->RegisterConstructors(kv.first, kv.second);
}
return Module(n);
......@@ -108,15 +109,25 @@ void ModuleNode::Add(const GlobalVar& var,
AddUnchecked(var, checked_func);
}
/*!
 * \brief Assign a tag to every constructor of an ADT and index the constructors by tag.
 *
 * The tag layout is: low byte of the hash of the ADT's name hint in the most
 * significant byte, constructor index in the lower bytes.
 *
 * \param var The global type var naming the ADT; its name hint is hashed for the prefix.
 * \param type The type data whose constructors receive tags.
 */
void ModuleNode::RegisterConstructors(const GlobalTypeVar& var, const TypeData& type) {
  // We hash the global type var name to use as a prefix for tags.
  // The low byte of the hash is used as the most significant byte of the tag, with the
  // index of the constructor in the less significant bytes.
  // NOTE(review): only 8 bits of the hash are kept, so two ADTs with different names can
  // end up with the same prefix; the later registration then overwrites the earlier
  // entries in constructor_tag_map_. Confirm whether a collision check is wanted here.
  const size_t hash = std::hash<std::string>()(var->var->name_hint);
  // Build the prefix in unsigned arithmetic: left-shifting a signed 32-bit value so that
  // a bit reaches the sign position (hash byte >= 0x80) is not well-defined before C++20.
  const uint32_t prefix = static_cast<uint32_t>(hash & 0xff) << 24;
  for (size_t i = 0; i < type->constructors.size(); ++i) {
    auto ctor = type->constructors[i];
    // `tag` is a mutable field, so it can be set through the constructor reference.
    ctor->tag = static_cast<int32_t>(prefix | static_cast<uint32_t>(i));
    constructor_tag_map_[ctor->tag] = ctor;
  }
}
void ModuleNode::AddDef(const GlobalTypeVar& var, const TypeData& type) {
this->type_definitions.Set(var, type);
// set global type var map
CHECK(!global_type_var_map_.count(var->var->name_hint))
<< "Duplicate global type definition name " << var->var->name_hint;
global_type_var_map_.Set(var->var->name_hint, var);
for (size_t i = 0; i < type->constructors.size(); ++i) {
type->constructors[i]->tag = i;
}
RegisterConstructors(var, type);
// need to kind check at the end because the check can look up
// a definition potentially
......@@ -159,6 +170,13 @@ TypeData ModuleNode::LookupDef(const std::string& name) const {
return this->LookupDef(id);
}
/*!
 * \brief Find the constructor registered under a tag.
 * \param tag The tag to search for in constructor_tag_map_.
 * \return The constructor stored under the tag; a CHECK failure is raised when absent.
 */
Constructor ModuleNode::LookupTag(const int32_t tag) {
  const auto it = constructor_tag_map_.find(tag);
  // The condition is spelled out in full because CHECK reports its text on failure.
  CHECK(it != constructor_tag_map_.end()) << "There is no constructor with the tag " << tag;
  return it->second;
}
void ModuleNode::Update(const Module& mod) {
for (auto pair : mod->functions) {
this->Update(pair.first, pair.second);
......@@ -236,6 +254,11 @@ TVM_REGISTER_API("relay._module.Module_LookupDef_str")
return mod->LookupDef(var);
});
// Registers ModuleNode::LookupTag under the name "relay._module.Module_LookupTag".
// The Python method Module.get_constructor calls this entry with (module, tag).
TVM_REGISTER_API("relay._module.Module_LookupTag")
.set_body_typed<Constructor(Module, int32_t)>([](Module mod, int32_t tag) {
return mod->LookupTag(tag);
});
TVM_REGISTER_API("relay._module.Module_FromExpr")
.set_body_typed<Module(Expr)>([](Expr e) {
return ModuleNode::FromExpr(e);
......
# Licensed to the Apache Software Foundation (ASF) under one
# or more contributor license agreements. See the NOTICE file
# distributed with this work for additional information
# regarding copyright ownership. The ASF licenses this file
# to you under the Apache License, Version 2.0 (the
# "License"); you may not use this file except in compliance
# with the License. You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing,
# software distributed under the License is distributed on an
# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
# KIND, either express or implied. See the License for the
# specific language governing permissions and limitations
# under the License.
"""Tests for module functionality."""
import tvm
from tvm import relay
from tvm.relay import Module
from tvm.relay.prelude import Prelude
from tvm.relay.testing import add_nat_definitions
def constructor_list(p):
    """Collect the constructors under test from prelude-like object ``p``, in a fixed order."""
    ctor_names = ("nil", "cons", "rose", "some", "none", "z", "s")
    return [getattr(p, ctor_name) for ctor_name in ctor_names]
def adt_list(p):
    """Collect the ADT type vars under test from prelude-like object ``p``, in a fixed order."""
    adt_names = ("nat", "l", "optional", "tree")
    return [getattr(p, adt_name) for adt_name in adt_names]
def test_constructor_tag_round_trip():
    """A tag read from one module must resolve to the matching constructor in another."""
    def build_prelude():
        # Fresh module with the prelude and the nat definitions loaded.
        module = Module()
        prelude = Prelude(module)
        add_nat_definitions(prelude)
        return module, prelude

    _, source_prelude = build_prelude()
    target_mod, target_prelude = build_prelude()

    # ensure hashes match across modules
    source_ctors = constructor_list(source_prelude)
    target_ctors = constructor_list(target_prelude)

    for source_ctor, target_ctor in zip(source_ctors, target_ctors):
        resolved = target_mod.get_constructor(source_ctor.tag)
        assert resolved == target_ctor
        assert resolved.name_hint == source_ctor.name_hint
def test_constructor_tag_differences():
    """Tags of constructors within the same ADT must be consecutive offsets."""
    # ensure that if we have the type data for a given ADT, the tags
    # for the constructors of the *same ADT* are simple offsets from
    # each other
    mod = Module()
    prelude = Prelude(mod)
    add_nat_definitions(prelude)

    for adt in adt_list(prelude):
        ctors = mod[adt].constructors
        for pos in range(1, len(ctors)):
            earlier = ctors[pos - 1]
            later = ctors[pos]
            assert later.tag - earlier.tag == 1
            # make sure there is something present at the MSB
            assert earlier.tag - (pos - 1) != 0
            assert later.tag - pos != 0
Markdown is supported
0% or
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or sign in to comment