Commit 24e6fcb6 by Tianqi Chen, committed by GitHub (signature unverified)

[REFACTOR][TYPE] Remove un-necessary var sub-field in GlobalTypeVar and TypeVar (#4615)

Currently, we use a tvm::Var to represent a placeholder for shapes in generic types.
This is not necessary for GlobalTypeVar (as we never parameterize by shape var),
and is a bit twisted for TypeVar.

As we move to a unified type system, we want to break the dependency
of the base TypeVar (which is shared across the languages) on the expression.
Note that it is fine for TensorType to depend on Expr.

One alternative to embedding the Var would be to introduce a TypeVarExpr,
which can wrap a TypeVar as an Expr. However, this alternative won't be
natural until we migrate the types to the global scope.

Luckily, we have not yet started to depend heavily on shape parameterization.

This PR removes the tvm::Var from the type vars. We will follow up with another
PR to migrate the types to a base location. After that, we should be able to
use the more elegant approach via TypeVarExpr.
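
For illustration, here is a minimal sketch of what the change means on the
Python side (relay API names as of this PR; a sketch, not part of the change
itself):

    # Type vars are now identified by their name hint alone,
    # instead of carrying a wrapped tvm::Var (previously: tv.var.name).
    from tvm.relay import ty

    tv = ty.TypeVar("a", ty.Kind.Type)
    assert tv.name_hint == "a"

    gtv = ty.GlobalTypeVar("mylist", ty.Kind.AdtHandle)
    assert gtv.name_hint == "mylist"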
parent 9c638f06
@@ -157,16 +157,13 @@ class TypeVar;
 /*! \brief TypeVar container node */
 class TypeVarNode : public TypeNode {
  public:
-  /*!
-   * \brief The variable itself is only meaningful when
-   *  kind is ShapeVar, otherwise, we only use the name.
-   */
-  tvm::Var var;
+  /*! \brief Name of the variable, it only acts as a hint. */
+  std::string name_hint;
   /*! \brief The kind of type parameter */
   Kind kind;

   void VisitAttrs(tvm::AttrVisitor* v) {
-    v->Visit("var", &var);
+    v->Visit("name_hint", &name_hint);
     v->Visit("kind", &kind);
     v->Visit("span", &span);
   }
@@ -189,16 +186,13 @@ class GlobalTypeVar;
 /*! \brief GlobalTypeVar container node */
 class GlobalTypeVarNode : public TypeNode {
  public:
-  /*!
-   * \brief The variable itself is only meaningful when
-   *  kind is ShapeVar; otherwise, we only use the name.
-   */
-  tvm::Var var;
+  /*! \brief Name of the variable, it only acts as a hint. */
+  std::string name_hint;
   /*! \brief The kind of type parameter */
   Kind kind;

   void VisitAttrs(tvm::AttrVisitor* v) {
-    v->Visit("var", &var);
+    v->Visit("name_hint", &name_hint);
     v->Visit("kind", &kind);
     v->Visit("span", &span);
   }
@@ -272,7 +272,7 @@ class ParseTreeToRelayIR(RelayVisitor):

     def _type_expr_name(self, e):
         if isinstance(e, adt.Constructor):
-            return "`{0}` ADT constructor".format(e.belong_to.var.name)
+            return "`{0}` ADT constructor".format(e.belong_to.name_hint)
         elif isinstance(e, ty.GlobalTypeVar):
             if e.kind == ty.Kind.AdtHandle:
                 return "ADT definition"
@@ -143,7 +143,7 @@ class TypeMutator(TypeFunctor):
     and reconstructs the AST.
     """
     def visit_type_var(self, tv):
-        return TypeVar(tv.var.name, tv.kind)
+        return TypeVar(tv.name_hint, tv.kind)

     def visit_incomplete_type(self, it):
         return IncompleteType(it.kind)
@@ -180,7 +180,7 @@ class TypeMutator(TypeFunctor):
         return RefType(self.visit(rt.value))

     def visit_global_type_var(self, gtv):
-        return GlobalTypeVar(gtv.var.name, gtv.kind)
+        return GlobalTypeVar(gtv.name_hint, gtv.kind)

     def visit_type_call(self, tc):
         return TypeCall(
@@ -69,8 +69,8 @@ class AlphaEqualHandler:
     }
     if (lhsm->type_definitions.size() != rhsm->type_definitions.size()) return false;
     for (const auto& p : lhsm->type_definitions) {
-      if (!rhsm->ContainGlobalTypeVar(p.first->var->name_hint) ||
-          !Equal(p.second, rhsm->LookupDef(p.first->var->name_hint))) {
+      if (!rhsm->ContainGlobalTypeVar(p.first->name_hint) ||
+          !Equal(p.second, rhsm->LookupDef(p.first->name_hint))) {
         return false;
       }
     }
@@ -233,11 +233,6 @@ class AlphaEqualHandler:
         return false;
       }
       equal_map_[lhs->type_params[i]] = rhs->type_params[i];
-      // set up type parameter equal
-      if (lhs->type_params[i]->kind == Kind::kShapeVar) {
-        // map variable
-        equal_map_[lhs->type_params[i]->var] = rhs->type_params[i]->var;
-      }
     }
     for (size_t i = 0; i < lhs->arg_types.size(); i++) {
       if (!TypeEqual(lhs->arg_types[i], rhs->arg_types[i])) return false;
@@ -228,11 +228,11 @@ class RelayHashHandler:
       hash = Combine(hash, TypeHash(var_node->type_annotation));
     }
     hash_map_[var] = hash;
-    const auto* ty_param = var.as<TypeVarNode>();
-    if (ty_param && ty_param->kind == Kind::kShapeVar) {
-      hash_map_[ty_param->var] = hash;
-    }
+    // TODO(tqchen) Introduce TypeVarExpr
+    // const auto* ty_param = var.as<TypeVarNode>();
+    // if (ty_param && ty_param->kind == Kind::kShapeVar) {
+    //   hash_map_[ty_param->var] = hash;
+    // }
     return hash;
   }
@@ -55,9 +55,9 @@ Module ModuleNode::make(tvm::Map<GlobalVar, Function> global_funcs,
   for (const auto& kv : n->type_definitions) {
     // set global typevar map
-    CHECK(n->global_type_var_map_.count(kv.first->var->name_hint) == 0)
-      << "Duplicate global type definition name " << kv.first->var->name_hint;
-    n->global_type_var_map_.Set(kv.first->var->name_hint, kv.first);
+    CHECK(n->global_type_var_map_.count(kv.first->name_hint) == 0)
+      << "Duplicate global type definition name " << kv.first->name_hint;
+    n->global_type_var_map_.Set(kv.first->name_hint, kv.first);
     n->RegisterConstructors(kv.first, kv.second);
   }
@@ -177,7 +177,7 @@ void ModuleNode::RegisterConstructors(const GlobalTypeVar& var, const TypeData&
   // We hash the global type var name to use as a globally unique prefix for tags.
   // The hash will be used as the most significant byte of the tag, with the index of
   // the constructor in the less significant bytes
-  size_t hash = std::hash<std::string>()(var->var->name_hint);
+  size_t hash = std::hash<std::string>()(var->name_hint);
   int32_t prefix = static_cast<int32_t>(hash & 0xff) << 24;
   for (size_t i = 0; i < type->constructors.size(); ++i) {
     type->constructors[i]->tag = prefix | static_cast<int32_t>(i);
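
As an aside, the tag built above packs a name-derived hash byte over the
constructor's index. A small Python sketch of the same bit arithmetic (0xAB
stands in for the std::hash output, which this sketch does not reproduce):

    # Most significant byte: hash of the type name; low bits: the
    # constructor's position within the type definition.
    def make_tag(name_hash, ctor_index):
        prefix = (name_hash & 0xff) << 24
        return prefix | ctor_index

    assert make_tag(0xAB, 2) == (0xAB << 24) | 2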
@@ -197,10 +197,10 @@ void ModuleNode::AddDefUnchecked(const GlobalTypeVar& var, const TypeData& type,
   this->type_definitions.Set(var, type);
   if (!update) {
     // set global type var map
-    CHECK(global_type_var_map_.count(var->var->name_hint) == 0)
-      << "Duplicate global type definition name " << var->var->name_hint;
+    CHECK(global_type_var_map_.count(var->name_hint) == 0)
+      << "Duplicate global type definition name " << var->name_hint;
   }
-  global_type_var_map_.Set(var->var->name_hint, var);
+  global_type_var_map_.Set(var->name_hint, var);
   RegisterConstructors(var, type);
 }
@@ -234,7 +234,7 @@ Function ModuleNode::Lookup(const std::string& name) const {

 TypeData ModuleNode::LookupDef(const GlobalTypeVar& var) const {
   auto it = type_definitions.find(var);
   CHECK(it != type_definitions.end())
-    << "There is no definition of " << var->var->name_hint;
+    << "There is no definition of " << var->name_hint;
   return (*it).second;
 }
@@ -312,7 +312,7 @@ class PrettyPrinter :
     val << "-malformed-ir";
     return val;
   }
-  std::string name = var->var->name_hint;
+  std::string name = var->name_hint;
   if (name.length() == 0 || !std::isalpha(name[0])) {
     name = "t" + name;
   }
@@ -493,7 +493,7 @@ class PrettyPrinter :
     doc << "[";
     std::vector<Doc> type_params;
     for (const TypeVar& tv : fn->type_params) {
-      type_params.push_back(Doc(tv->var->name_hint));
+      type_params.push_back(Doc(tv->name_hint));
     }
     doc << PrintSep(type_params);
     doc << "]";
@@ -701,11 +701,11 @@ class PrettyPrinter :
   }

   Doc VisitType_(const TypeVarNode* node) final {
-    return Doc(node->var->name_hint);
+    return Doc(node->name_hint);
   }

   Doc VisitType_(const GlobalTypeVarNode* node) final {
-    return Doc(node->var->name_hint);
+    return Doc(node->name_hint);
   }

   Doc VisitType_(const TypeCallNode* node) final {
@@ -65,7 +65,7 @@ TVM_STATIC_IR_FUNCTOR(IRPrinter, vtable)

 TypeVar TypeVarNode::make(std::string name, Kind kind) {
   ObjectPtr<TypeVarNode> n = make_object<TypeVarNode>();
-  n->var = tvm::Var(name);
+  n->name_hint = std::move(name);
   n->kind = std::move(kind);
   return TypeVar(n);
 }
@@ -74,19 +74,19 @@ TVM_REGISTER_NODE_TYPE(TypeVarNode);

 TVM_REGISTER_API("relay._make.TypeVar")
 .set_body_typed<TypeVar(std::string, int)>([](std::string name, int kind) {
-    return TypeVarNode::make(name, static_cast<Kind>(kind));
-  });
+  return TypeVarNode::make(name, static_cast<Kind>(kind));
+});

 TVM_STATIC_IR_FUNCTOR(IRPrinter, vtable)
 .set_dispatch<TypeVarNode>([](const ObjectRef& ref, IRPrinter* p) {
-    auto* node = static_cast<const TypeVarNode*>(ref.get());
-    p->stream << "TypeVarNode(" << node->var->name_hint << ", "
-              << node->kind << ")";
+  auto* node = static_cast<const TypeVarNode*>(ref.get());
+  p->stream << "TypeVarNode(" << node->name_hint << ", "
+            << node->kind << ")";
 });

 GlobalTypeVar GlobalTypeVarNode::make(std::string name, Kind kind) {
   ObjectPtr<GlobalTypeVarNode> n = make_object<GlobalTypeVarNode>();
-  n->var = tvm::Var(name);
+  n->name_hint = std::move(name);
   n->kind = std::move(kind);
   return GlobalTypeVar(n);
 }
@@ -101,7 +101,7 @@ TVM_REGISTER_API("relay._make.GlobalTypeVar")

 TVM_STATIC_IR_FUNCTOR(IRPrinter, vtable)
 .set_dispatch<GlobalTypeVarNode>([](const ObjectRef& ref, IRPrinter* p) {
     auto* node = static_cast<const GlobalTypeVarNode*>(ref.get());
-    p->stream << "GlobalTypeVarNode(" << node->var->name_hint << ", "
+    p->stream << "GlobalTypeVarNode(" << node->name_hint << ", "
               << node->kind << ")";
 });
@@ -37,7 +37,7 @@ Expr DeDup(const Expr& e) {
     public PatternMutator {
  public:
   TypeVar Fresh(const TypeVar& tv) {
-    TypeVar ret = TypeVarNode::make(tv->var->name_hint, tv->kind);
+    TypeVar ret = TypeVarNode::make(tv->name_hint, tv->kind);
     type_rename_[tv] = ret;
     return ret;
   }
@@ -334,7 +334,7 @@ Function UnCPS(const Function& f) {
   auto new_ret_type = Type(cont_type->arg_types[0]);
   std::vector<TypeVar> new_type_params;
   for (const auto& tp : f->type_params) {
-    new_type_params.push_back(TypeVarNode::make(tp->var->name_hint, tp->kind));
+    new_type_params.push_back(TypeVarNode::make(tp->name_hint, tp->kind));
   }
   auto answer_type = new_type_params.back();
   new_type_params.pop_back();
@@ -534,15 +534,16 @@ def test_fused_ops():
     tvm.testing.assert_allclose(result.asnumpy(), (data + 1) * 2)


 def test_arange_with_dynamic_shape():
-    m, n, k = relay.ShapeVar('m'), relay.ShapeVar('n'), relay.ShapeVar('k')
-    x = relay.var('x', shape=(m.var, n.var, k.var), dtype='float32')
+    # m, n, k = relay.ShapeVar('m'), relay.ShapeVar('n'), relay.ShapeVar('k')
+    m, n, k = relay.Any(), relay.Any(), relay.Any()
+    x = relay.var('x', shape=(m, n, k), dtype='float32')
     y0 = relay.shape_of(x)
     y1 = relay.take(y0, relay.const(0, 'int32'))
     y2 = relay.op.arange(y1, dtype="int32")
     y3 = y2 + relay.const(1, dtype="int32")
     data = np.random.rand(10, 5, 3).astype('float32')
     mod = relay.module.Module()
-    mod["main"] = relay.Function([x], y3, type_params=[m, n, k])
+    mod["main"] = relay.Function([x], y3)
     for kind in ["debug", "vm"]:
         ex = relay.create_executor(kind, mod=mod, ctx=tvm.cpu(), target="llvm")
         result = ex.evaluate()(data)
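
For context, relay.Any() marks a dimension whose extent is only known at
runtime, which is why the test no longer threads shape vars through
type_params. A minimal usage sketch with the same API as the test above:

    # Any() stands for a runtime-determined dimension; nothing has to
    # be carried in the function's type parameters anymore.
    from tvm import relay

    m = relay.Any()
    x = relay.var('x', shape=(m,), dtype='float32')
    f = relay.Function([x], relay.shape_of(x))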