Unverified Commit 24e6fcb6 by Tianqi Chen Committed by GitHub

[REFACTOR][TYPE] Remove un-necessary var sub-field in GlobalTypeVar and TypeVar (#4615)

Currently, we use a tvm::Var to represent a placeholder for shapes in generic types.
This is not necessary for GlobalTypeVar(as we never parameterize by shape var),
and is a bit twisted for TypeVar.

As we move to a unified type system, we want to break the dependency
from the base TypeVar(which is shared across the languages) from the expression.
Note that it is fine for TensorType to depend on Expr.

One alternative solution to embed the Var would be to introduce a TypeVarExpr,
which can wrap a TypeVar as Expr. However, this new alternative won't be
natural until we migrate the type to the global scope.

Lucikly, we have not yet start to depend on the shape parameterization heavily yet.

This PR removes the tvm::Var from the typevars. We will follow up with another
PR to migrate the types to a base location. After that, we should be able to
use the more elegant approach via TypeVarExpr.
parent 9c638f06
...@@ -157,16 +157,13 @@ class TypeVar; ...@@ -157,16 +157,13 @@ class TypeVar;
/*! \brief TypeVar container node */ /*! \brief TypeVar container node */
class TypeVarNode : public TypeNode { class TypeVarNode : public TypeNode {
public: public:
/*! /*! \brief Name of the variable, it only acts as a hint. */
* \brief The variable itself is only meaningful when std::string name_hint;
* kind is ShapeVar, otherwise, we only use the name.
*/
tvm::Var var;
/*! \brief The kind of type parameter */ /*! \brief The kind of type parameter */
Kind kind; Kind kind;
void VisitAttrs(tvm::AttrVisitor* v) { void VisitAttrs(tvm::AttrVisitor* v) {
v->Visit("var", &var); v->Visit("name_hint", &name_hint);
v->Visit("kind", &kind); v->Visit("kind", &kind);
v->Visit("span", &span); v->Visit("span", &span);
} }
...@@ -189,16 +186,13 @@ class GlobalTypeVar; ...@@ -189,16 +186,13 @@ class GlobalTypeVar;
/*! \brief GlobalTypeVar container node */ /*! \brief GlobalTypeVar container node */
class GlobalTypeVarNode : public TypeNode { class GlobalTypeVarNode : public TypeNode {
public: public:
/*! /*! \brief Name of the variable, it only acts as a hint. */
* \brief The variable itself is only meaningful when std::string name_hint;
* kind is ShapeVar; otherwise, we only use the name.
*/
tvm::Var var;
/*! \brief The kind of type parameter */ /*! \brief The kind of type parameter */
Kind kind; Kind kind;
void VisitAttrs(tvm::AttrVisitor* v) { void VisitAttrs(tvm::AttrVisitor* v) {
v->Visit("var", &var); v->Visit("name_hint", &name_hint);
v->Visit("kind", &kind); v->Visit("kind", &kind);
v->Visit("span", &span); v->Visit("span", &span);
} }
......
...@@ -272,7 +272,7 @@ class ParseTreeToRelayIR(RelayVisitor): ...@@ -272,7 +272,7 @@ class ParseTreeToRelayIR(RelayVisitor):
def _type_expr_name(self, e): def _type_expr_name(self, e):
if isinstance(e, adt.Constructor): if isinstance(e, adt.Constructor):
return "`{0}` ADT constructor".format(e.belong_to.var.name) return "`{0}` ADT constructor".format(e.belong_to.name_hint)
elif isinstance(e, ty.GlobalTypeVar): elif isinstance(e, ty.GlobalTypeVar):
if e.kind == ty.Kind.AdtHandle: if e.kind == ty.Kind.AdtHandle:
return "ADT definition" return "ADT definition"
......
...@@ -143,7 +143,7 @@ class TypeMutator(TypeFunctor): ...@@ -143,7 +143,7 @@ class TypeMutator(TypeFunctor):
and reconstructs the AST. and reconstructs the AST.
""" """
def visit_type_var(self, tv): def visit_type_var(self, tv):
return TypeVar(tv.var.name, tv.kind) return TypeVar(tv.name_hint, tv.kind)
def visit_incomplete_type(self, it): def visit_incomplete_type(self, it):
return IncompleteType(it.kind) return IncompleteType(it.kind)
...@@ -180,7 +180,7 @@ class TypeMutator(TypeFunctor): ...@@ -180,7 +180,7 @@ class TypeMutator(TypeFunctor):
return RefType(self.visit(rt.value)) return RefType(self.visit(rt.value))
def visit_global_type_var(self, gtv): def visit_global_type_var(self, gtv):
return GlobalTypeVar(gtv.var.name, gtv.kind) return GlobalTypeVar(gtv.name_hint, gtv.kind)
def visit_type_call(self, tc): def visit_type_call(self, tc):
return TypeCall( return TypeCall(
......
...@@ -69,8 +69,8 @@ class AlphaEqualHandler: ...@@ -69,8 +69,8 @@ class AlphaEqualHandler:
} }
if (lhsm->type_definitions.size() != rhsm->type_definitions.size()) return false; if (lhsm->type_definitions.size() != rhsm->type_definitions.size()) return false;
for (const auto& p : lhsm->type_definitions) { for (const auto& p : lhsm->type_definitions) {
if (!rhsm->ContainGlobalTypeVar(p.first->var->name_hint) || if (!rhsm->ContainGlobalTypeVar(p.first->name_hint) ||
!Equal(p.second, rhsm->LookupDef(p.first->var->name_hint))) { !Equal(p.second, rhsm->LookupDef(p.first->name_hint))) {
return false; return false;
} }
} }
...@@ -233,11 +233,6 @@ class AlphaEqualHandler: ...@@ -233,11 +233,6 @@ class AlphaEqualHandler:
return false; return false;
} }
equal_map_[lhs->type_params[i]] = rhs->type_params[i]; equal_map_[lhs->type_params[i]] = rhs->type_params[i];
// set up type parameter equal
if (lhs->type_params[i]->kind == Kind::kShapeVar) {
// map variable
equal_map_[lhs->type_params[i]->var] = rhs->type_params[i]->var;
}
} }
for (size_t i = 0; i < lhs->arg_types.size(); i++) { for (size_t i = 0; i < lhs->arg_types.size(); i++) {
if (!TypeEqual(lhs->arg_types[i], rhs->arg_types[i])) return false; if (!TypeEqual(lhs->arg_types[i], rhs->arg_types[i])) return false;
......
...@@ -228,11 +228,11 @@ class RelayHashHandler: ...@@ -228,11 +228,11 @@ class RelayHashHandler:
hash = Combine(hash, TypeHash(var_node->type_annotation)); hash = Combine(hash, TypeHash(var_node->type_annotation));
} }
hash_map_[var] = hash; hash_map_[var] = hash;
// TODO(tqchen) Introduce TypeVarExpr
const auto* ty_param = var.as<TypeVarNode>(); // const auto* ty_param = var.as<TypeVarNode>();
if (ty_param && ty_param->kind == Kind::kShapeVar) { // if (ty_param && ty_param->kind == Kind::kShapeVar) {
hash_map_[ty_param->var] = hash; // hash_map_[ty_param->var] = hash;
} // }
return hash; return hash;
} }
......
...@@ -55,9 +55,9 @@ Module ModuleNode::make(tvm::Map<GlobalVar, Function> global_funcs, ...@@ -55,9 +55,9 @@ Module ModuleNode::make(tvm::Map<GlobalVar, Function> global_funcs,
for (const auto& kv : n->type_definitions) { for (const auto& kv : n->type_definitions) {
// set global typevar map // set global typevar map
CHECK(n->global_type_var_map_.count(kv.first->var->name_hint) == 0) CHECK(n->global_type_var_map_.count(kv.first->name_hint) == 0)
<< "Duplicate global type definition name " << kv.first->var->name_hint; << "Duplicate global type definition name " << kv.first->name_hint;
n->global_type_var_map_.Set(kv.first->var->name_hint, kv.first); n->global_type_var_map_.Set(kv.first->name_hint, kv.first);
n->RegisterConstructors(kv.first, kv.second); n->RegisterConstructors(kv.first, kv.second);
} }
...@@ -177,7 +177,7 @@ void ModuleNode::RegisterConstructors(const GlobalTypeVar& var, const TypeData& ...@@ -177,7 +177,7 @@ void ModuleNode::RegisterConstructors(const GlobalTypeVar& var, const TypeData&
// We hash the global type var name to use as a globally unique prefix for tags. // We hash the global type var name to use as a globally unique prefix for tags.
// The hash will be used as the most significant byte of the tag, with the index of // The hash will be used as the most significant byte of the tag, with the index of
// the constructor in the less significant bytes // the constructor in the less significant bytes
size_t hash = std::hash<std::string>()(var->var->name_hint); size_t hash = std::hash<std::string>()(var->name_hint);
int32_t prefix = static_cast<int32_t>(hash & 0xff) << 24; int32_t prefix = static_cast<int32_t>(hash & 0xff) << 24;
for (size_t i = 0; i < type->constructors.size(); ++i) { for (size_t i = 0; i < type->constructors.size(); ++i) {
type->constructors[i]->tag = prefix | static_cast<int32_t>(i); type->constructors[i]->tag = prefix | static_cast<int32_t>(i);
...@@ -197,10 +197,10 @@ void ModuleNode::AddDefUnchecked(const GlobalTypeVar& var, const TypeData& type, ...@@ -197,10 +197,10 @@ void ModuleNode::AddDefUnchecked(const GlobalTypeVar& var, const TypeData& type,
this->type_definitions.Set(var, type); this->type_definitions.Set(var, type);
if (!update) { if (!update) {
// set global type var map // set global type var map
CHECK(global_type_var_map_.count(var->var->name_hint) == 0) CHECK(global_type_var_map_.count(var->name_hint) == 0)
<< "Duplicate global type definition name " << var->var->name_hint; << "Duplicate global type definition name " << var->name_hint;
} }
global_type_var_map_.Set(var->var->name_hint, var); global_type_var_map_.Set(var->name_hint, var);
RegisterConstructors(var, type); RegisterConstructors(var, type);
} }
...@@ -234,7 +234,7 @@ Function ModuleNode::Lookup(const std::string& name) const { ...@@ -234,7 +234,7 @@ Function ModuleNode::Lookup(const std::string& name) const {
TypeData ModuleNode::LookupDef(const GlobalTypeVar& var) const { TypeData ModuleNode::LookupDef(const GlobalTypeVar& var) const {
auto it = type_definitions.find(var); auto it = type_definitions.find(var);
CHECK(it != type_definitions.end()) CHECK(it != type_definitions.end())
<< "There is no definition of " << var->var->name_hint; << "There is no definition of " << var->name_hint;
return (*it).second; return (*it).second;
} }
......
...@@ -312,7 +312,7 @@ class PrettyPrinter : ...@@ -312,7 +312,7 @@ class PrettyPrinter :
val << "-malformed-ir"; val << "-malformed-ir";
return val; return val;
} }
std::string name = var->var->name_hint; std::string name = var->name_hint;
if (name.length() == 0 || !std::isalpha(name[0])) { if (name.length() == 0 || !std::isalpha(name[0])) {
name = "t" + name; name = "t" + name;
} }
...@@ -493,7 +493,7 @@ class PrettyPrinter : ...@@ -493,7 +493,7 @@ class PrettyPrinter :
doc << "["; doc << "[";
std::vector<Doc> type_params; std::vector<Doc> type_params;
for (const TypeVar& tv : fn->type_params) { for (const TypeVar& tv : fn->type_params) {
type_params.push_back(Doc(tv->var->name_hint)); type_params.push_back(Doc(tv->name_hint));
} }
doc << PrintSep(type_params); doc << PrintSep(type_params);
doc << "]"; doc << "]";
...@@ -701,11 +701,11 @@ class PrettyPrinter : ...@@ -701,11 +701,11 @@ class PrettyPrinter :
} }
Doc VisitType_(const TypeVarNode* node) final { Doc VisitType_(const TypeVarNode* node) final {
return Doc(node->var->name_hint); return Doc(node->name_hint);
} }
Doc VisitType_(const GlobalTypeVarNode* node) final { Doc VisitType_(const GlobalTypeVarNode* node) final {
return Doc(node->var->name_hint); return Doc(node->name_hint);
} }
Doc VisitType_(const TypeCallNode* node) final { Doc VisitType_(const TypeCallNode* node) final {
......
...@@ -65,7 +65,7 @@ TVM_STATIC_IR_FUNCTOR(IRPrinter, vtable) ...@@ -65,7 +65,7 @@ TVM_STATIC_IR_FUNCTOR(IRPrinter, vtable)
TypeVar TypeVarNode::make(std::string name, Kind kind) { TypeVar TypeVarNode::make(std::string name, Kind kind) {
ObjectPtr<TypeVarNode> n = make_object<TypeVarNode>(); ObjectPtr<TypeVarNode> n = make_object<TypeVarNode>();
n->var = tvm::Var(name); n->name_hint = std::move(name);
n->kind = std::move(kind); n->kind = std::move(kind);
return TypeVar(n); return TypeVar(n);
} }
...@@ -75,18 +75,18 @@ TVM_REGISTER_NODE_TYPE(TypeVarNode); ...@@ -75,18 +75,18 @@ TVM_REGISTER_NODE_TYPE(TypeVarNode);
TVM_REGISTER_API("relay._make.TypeVar") TVM_REGISTER_API("relay._make.TypeVar")
.set_body_typed<TypeVar(std::string, int)>([](std::string name, int kind) { .set_body_typed<TypeVar(std::string, int)>([](std::string name, int kind) {
return TypeVarNode::make(name, static_cast<Kind>(kind)); return TypeVarNode::make(name, static_cast<Kind>(kind));
}); });
TVM_STATIC_IR_FUNCTOR(IRPrinter, vtable) TVM_STATIC_IR_FUNCTOR(IRPrinter, vtable)
.set_dispatch<TypeVarNode>([](const ObjectRef& ref, IRPrinter* p) { .set_dispatch<TypeVarNode>([](const ObjectRef& ref, IRPrinter* p) {
auto* node = static_cast<const TypeVarNode*>(ref.get()); auto* node = static_cast<const TypeVarNode*>(ref.get());
p->stream << "TypeVarNode(" << node->var->name_hint << ", " p->stream << "TypeVarNode(" << node->name_hint << ", "
<< node->kind << ")"; << node->kind << ")";
}); });
GlobalTypeVar GlobalTypeVarNode::make(std::string name, Kind kind) { GlobalTypeVar GlobalTypeVarNode::make(std::string name, Kind kind) {
ObjectPtr<GlobalTypeVarNode> n = make_object<GlobalTypeVarNode>(); ObjectPtr<GlobalTypeVarNode> n = make_object<GlobalTypeVarNode>();
n->var = tvm::Var(name); n->name_hint = std::move(name);
n->kind = std::move(kind); n->kind = std::move(kind);
return GlobalTypeVar(n); return GlobalTypeVar(n);
} }
...@@ -101,7 +101,7 @@ TVM_REGISTER_API("relay._make.GlobalTypeVar") ...@@ -101,7 +101,7 @@ TVM_REGISTER_API("relay._make.GlobalTypeVar")
TVM_STATIC_IR_FUNCTOR(IRPrinter, vtable) TVM_STATIC_IR_FUNCTOR(IRPrinter, vtable)
.set_dispatch<GlobalTypeVarNode>([](const ObjectRef& ref, IRPrinter* p) { .set_dispatch<GlobalTypeVarNode>([](const ObjectRef& ref, IRPrinter* p) {
auto* node = static_cast<const GlobalTypeVarNode*>(ref.get()); auto* node = static_cast<const GlobalTypeVarNode*>(ref.get());
p->stream << "GlobalTypeVarNode(" << node->var->name_hint << ", " p->stream << "GlobalTypeVarNode(" << node->name_hint << ", "
<< node->kind << ")"; << node->kind << ")";
}); });
......
...@@ -37,7 +37,7 @@ Expr DeDup(const Expr& e) { ...@@ -37,7 +37,7 @@ Expr DeDup(const Expr& e) {
public PatternMutator { public PatternMutator {
public: public:
TypeVar Fresh(const TypeVar& tv) { TypeVar Fresh(const TypeVar& tv) {
TypeVar ret = TypeVarNode::make(tv->var->name_hint, tv->kind); TypeVar ret = TypeVarNode::make(tv->name_hint, tv->kind);
type_rename_[tv] = ret; type_rename_[tv] = ret;
return ret; return ret;
} }
......
...@@ -334,7 +334,7 @@ Function UnCPS(const Function& f) { ...@@ -334,7 +334,7 @@ Function UnCPS(const Function& f) {
auto new_ret_type = Type(cont_type->arg_types[0]); auto new_ret_type = Type(cont_type->arg_types[0]);
std::vector<TypeVar> new_type_params; std::vector<TypeVar> new_type_params;
for (const auto& tp : f->type_params) { for (const auto& tp : f->type_params) {
new_type_params.push_back(TypeVarNode::make(tp->var->name_hint, tp->kind)); new_type_params.push_back(TypeVarNode::make(tp->name_hint, tp->kind));
} }
auto answer_type = new_type_params.back(); auto answer_type = new_type_params.back();
new_type_params.pop_back(); new_type_params.pop_back();
......
...@@ -534,15 +534,16 @@ def test_fused_ops(): ...@@ -534,15 +534,16 @@ def test_fused_ops():
tvm.testing.assert_allclose(result.asnumpy(), (data + 1) * 2) tvm.testing.assert_allclose(result.asnumpy(), (data + 1) * 2)
def test_arange_with_dynamic_shape(): def test_arange_with_dynamic_shape():
m, n, k = relay.ShapeVar('m'), relay.ShapeVar('n'), relay.ShapeVar('k') # m, n, k = relay.ShapeVar('m'), relay.ShapeVar('n'), relay.ShapeVar('k')
x = relay.var('x', shape=(m.var, n.var, k.var), dtype='float32') m, n, k = relay.Any(), relay.Any(), relay.Any()
x = relay.var('x', shape=(m, n, k), dtype='float32')
y0 = relay.shape_of(x) y0 = relay.shape_of(x)
y1 = relay.take(y0, relay.const(0, 'int32')) y1 = relay.take(y0, relay.const(0, 'int32'))
y2 = relay.op.arange(y1, dtype="int32") y2 = relay.op.arange(y1, dtype="int32")
y3 = y2 + relay.const(1, dtype="int32") y3 = y2 + relay.const(1, dtype="int32")
data = np.random.rand(10, 5, 3).astype('float32') data = np.random.rand(10, 5, 3).astype('float32')
mod = relay.module.Module() mod = relay.module.Module()
mod["main"] = relay.Function([x], y3, type_params=[m, n, k]) mod["main"] = relay.Function([x], y3)
for kind in ["debug", "vm"]: for kind in ["debug", "vm"]:
ex = relay.create_executor(kind, mod=mod, ctx=tvm.cpu(), target="llvm") ex = relay.create_executor(kind, mod=mod, ctx=tvm.cpu(), target="llvm")
result = ex.evaluate()(data) result = ex.evaluate()(data)
......
Markdown is supported
0% or
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment