Commit 09efdc9d by 雾雨魔理沙 Committed by Yizhi Liu

[Relay] Fix format (#1957)

* save

* fix format
parent 390acc52
...@@ -33,7 +33,7 @@ SourceName SourceName::Get(const std::string& name) { ...@@ -33,7 +33,7 @@ SourceName SourceName::Get(const std::string& name) {
} }
TVM_STATIC_IR_FUNCTOR_REGISTER(IRPrinter, vtable) TVM_STATIC_IR_FUNCTOR_REGISTER(IRPrinter, vtable)
.set_dispatch<SourceNameNode>([](const SourceNameNode *node, tvm::IRPrinter *p) { .set_dispatch<SourceNameNode>([](const SourceNameNode* node, tvm::IRPrinter* p) {
p->stream << "SourceName(" << node->name << ", " << node << ")"; p->stream << "SourceName(" << node->name << ", " << node << ")";
}); });
...@@ -54,12 +54,12 @@ Span SpanNode::make(SourceName source, int lineno, int col_offset) { ...@@ -54,12 +54,12 @@ Span SpanNode::make(SourceName source, int lineno, int col_offset) {
TVM_REGISTER_NODE_TYPE(SpanNode); TVM_REGISTER_NODE_TYPE(SpanNode);
TVM_REGISTER_API("relay._make.Span") TVM_REGISTER_API("relay._make.Span")
.set_body([](TVMArgs args, TVMRetValue *ret) { .set_body([](TVMArgs args, TVMRetValue* ret) {
*ret = SpanNode::make(args[0], args[1], args[2]); *ret = SpanNode::make(args[0], args[1], args[2]);
}); });
TVM_STATIC_IR_FUNCTOR_REGISTER(IRPrinter, vtable) TVM_STATIC_IR_FUNCTOR_REGISTER(IRPrinter, vtable)
.set_dispatch<SpanNode>([](const SpanNode *node, tvm::IRPrinter *p) { .set_dispatch<SpanNode>([](const SpanNode* node, tvm::IRPrinter* p) {
p->stream << "SpanNode(" << node->source << ", " << node->lineno << ", " p->stream << "SpanNode(" << node->source << ", " << node->lineno << ", "
<< node->col_offset << ")"; << node->col_offset << ")";
}); });
......
...@@ -73,12 +73,12 @@ Function EnvironmentNode::Lookup(const GlobalVar& var) { ...@@ -73,12 +73,12 @@ Function EnvironmentNode::Lookup(const GlobalVar& var) {
return (*it).second; return (*it).second;
} }
Function EnvironmentNode::Lookup(const std::string &name) { Function EnvironmentNode::Lookup(const std::string& name) {
GlobalVar id = this->GetGlobalVar(name); GlobalVar id = this->GetGlobalVar(name);
return this->Lookup(id); return this->Lookup(id);
} }
void EnvironmentNode::Update(const Environment &env) { void EnvironmentNode::Update(const Environment& env) {
for (auto pair : env->functions) { for (auto pair : env->functions) {
this->Update(pair.first, pair.second); this->Update(pair.first, pair.second);
} }
......
...@@ -20,12 +20,12 @@ Constant ConstantNode::make(runtime::NDArray data) { ...@@ -20,12 +20,12 @@ Constant ConstantNode::make(runtime::NDArray data) {
TVM_REGISTER_NODE_TYPE(ConstantNode); TVM_REGISTER_NODE_TYPE(ConstantNode);
TVM_REGISTER_API("relay._make.Constant") TVM_REGISTER_API("relay._make.Constant")
.set_body([](TVMArgs args, TVMRetValue *ret) { .set_body([](TVMArgs args, TVMRetValue* ret) {
*ret = ConstantNode::make(args[0]); *ret = ConstantNode::make(args[0]);
}); });
TVM_STATIC_IR_FUNCTOR_REGISTER(IRPrinter, vtable) TVM_STATIC_IR_FUNCTOR_REGISTER(IRPrinter, vtable)
.set_dispatch<ConstantNode>([](const ConstantNode *node, tvm::IRPrinter *p) { .set_dispatch<ConstantNode>([](const ConstantNode* node, tvm::IRPrinter* p) {
p->stream << "Constant(TODO)"; p->stream << "Constant(TODO)";
}); });
...@@ -49,12 +49,12 @@ Tuple TupleNode::make(tvm::Array<relay::Expr> fields) { ...@@ -49,12 +49,12 @@ Tuple TupleNode::make(tvm::Array<relay::Expr> fields) {
TVM_REGISTER_NODE_TYPE(TupleNode); TVM_REGISTER_NODE_TYPE(TupleNode);
TVM_REGISTER_API("relay._make.Tuple") TVM_REGISTER_API("relay._make.Tuple")
.set_body([](TVMArgs args, TVMRetValue *ret) { .set_body([](TVMArgs args, TVMRetValue* ret) {
*ret = TupleNode::make(args[0]); *ret = TupleNode::make(args[0]);
}); });
TVM_STATIC_IR_FUNCTOR_REGISTER(IRPrinter, vtable) TVM_STATIC_IR_FUNCTOR_REGISTER(IRPrinter, vtable)
.set_dispatch<TupleNode>([](const TupleNode *node, tvm::IRPrinter *p) { .set_dispatch<TupleNode>([](const TupleNode* node, tvm::IRPrinter* p) {
p->stream << "Tuple(" << node->fields << ")"; p->stream << "Tuple(" << node->fields << ")";
}); });
...@@ -68,12 +68,12 @@ Var VarNode::make(std::string name_hint, Type type_annotation) { ...@@ -68,12 +68,12 @@ Var VarNode::make(std::string name_hint, Type type_annotation) {
TVM_REGISTER_NODE_TYPE(VarNode); TVM_REGISTER_NODE_TYPE(VarNode);
TVM_REGISTER_API("relay._make.Var") TVM_REGISTER_API("relay._make.Var")
.set_body([](TVMArgs args, TVMRetValue *ret) { .set_body([](TVMArgs args, TVMRetValue* ret) {
*ret = VarNode::make(args[0], args[1]); *ret = VarNode::make(args[0], args[1]);
}); });
TVM_STATIC_IR_FUNCTOR_REGISTER(IRPrinter, vtable) TVM_STATIC_IR_FUNCTOR_REGISTER(IRPrinter, vtable)
.set_dispatch<VarNode>([](const VarNode *node, tvm::IRPrinter *p) { .set_dispatch<VarNode>([](const VarNode* node, tvm::IRPrinter* p) {
p->stream << "Var(" << node->name_hint; p->stream << "Var(" << node->name_hint;
if (node->type_annotation.defined()) { if (node->type_annotation.defined()) {
p->stream << ", ty="; p->stream << ", ty=";
...@@ -91,12 +91,12 @@ GlobalVar GlobalVarNode::make(std::string name_hint) { ...@@ -91,12 +91,12 @@ GlobalVar GlobalVarNode::make(std::string name_hint) {
TVM_REGISTER_NODE_TYPE(GlobalVarNode); TVM_REGISTER_NODE_TYPE(GlobalVarNode);
TVM_REGISTER_API("relay._make.GlobalVar") TVM_REGISTER_API("relay._make.GlobalVar")
.set_body([](TVMArgs args, TVMRetValue *ret) { .set_body([](TVMArgs args, TVMRetValue* ret) {
*ret = GlobalVarNode::make(args[0]); *ret = GlobalVarNode::make(args[0]);
}); });
TVM_STATIC_IR_FUNCTOR_REGISTER(IRPrinter, vtable) TVM_STATIC_IR_FUNCTOR_REGISTER(IRPrinter, vtable)
.set_dispatch<GlobalVarNode>([](const GlobalVarNode *node, tvm::IRPrinter *p) { .set_dispatch<GlobalVarNode>([](const GlobalVarNode* node, tvm::IRPrinter* p) {
p->stream << "GlobalVar(" << node->name_hint << ")"; p->stream << "GlobalVar(" << node->name_hint << ")";
}); });
...@@ -124,13 +124,13 @@ FuncType FunctionNode::func_type_annotation() const { ...@@ -124,13 +124,13 @@ FuncType FunctionNode::func_type_annotation() const {
TVM_REGISTER_NODE_TYPE(FunctionNode); TVM_REGISTER_NODE_TYPE(FunctionNode);
TVM_REGISTER_API("relay._make.Function") TVM_REGISTER_API("relay._make.Function")
.set_body([](TVMArgs args, TVMRetValue *ret) { .set_body([](TVMArgs args, TVMRetValue* ret) {
*ret = FunctionNode::make(args[0], args[1], args[2], args[3]); *ret = FunctionNode::make(args[0], args[1], args[2], args[3]);
}); });
TVM_STATIC_IR_FUNCTOR_REGISTER(IRPrinter, vtable) TVM_STATIC_IR_FUNCTOR_REGISTER(IRPrinter, vtable)
.set_dispatch<FunctionNode>([](const FunctionNode *node, .set_dispatch<FunctionNode>([](const FunctionNode* node,
tvm::IRPrinter *p) { tvm::IRPrinter* p) {
p->stream << "FunctionNode(" << node->params << ", " << node->ret_type p->stream << "FunctionNode(" << node->params << ", " << node->ret_type
<< ", " << node->body << ", " << node->type_params << ")"; << ", " << node->body << ", " << node->type_params << ")";
}); });
...@@ -148,12 +148,12 @@ Call CallNode::make(Expr op, Array<Expr> args, Attrs attrs, ...@@ -148,12 +148,12 @@ Call CallNode::make(Expr op, Array<Expr> args, Attrs attrs,
TVM_REGISTER_NODE_TYPE(CallNode); TVM_REGISTER_NODE_TYPE(CallNode);
TVM_REGISTER_API("relay._make.Call") TVM_REGISTER_API("relay._make.Call")
.set_body([](TVMArgs args, TVMRetValue *ret) { .set_body([](TVMArgs args, TVMRetValue* ret) {
*ret = CallNode::make(args[0], args[1], args[2], args[3]); *ret = CallNode::make(args[0], args[1], args[2], args[3]);
}); });
TVM_STATIC_IR_FUNCTOR_REGISTER(IRPrinter, vtable) TVM_STATIC_IR_FUNCTOR_REGISTER(IRPrinter, vtable)
.set_dispatch<CallNode>([](const CallNode *node, tvm::IRPrinter *p) { .set_dispatch<CallNode>([](const CallNode* node, tvm::IRPrinter* p) {
p->stream << "CallNode(" << node->op << ", " << node->args << ", " p->stream << "CallNode(" << node->op << ", " << node->args << ", "
<< node->attrs << ", " << node->type_args << ")"; << node->attrs << ", " << node->type_args << ")";
}); });
...@@ -169,12 +169,12 @@ Let LetNode::make(Var var, Expr value, Expr body) { ...@@ -169,12 +169,12 @@ Let LetNode::make(Var var, Expr value, Expr body) {
TVM_REGISTER_NODE_TYPE(LetNode); TVM_REGISTER_NODE_TYPE(LetNode);
TVM_REGISTER_API("relay._make.Let") TVM_REGISTER_API("relay._make.Let")
.set_body([](TVMArgs args, TVMRetValue *ret) { .set_body([](TVMArgs args, TVMRetValue* ret) {
*ret = LetNode::make(args[0], args[1], args[2]); *ret = LetNode::make(args[0], args[1], args[2]);
}); });
TVM_STATIC_IR_FUNCTOR_REGISTER(IRPrinter, vtable) TVM_STATIC_IR_FUNCTOR_REGISTER(IRPrinter, vtable)
.set_dispatch<LetNode>([](const LetNode *node, tvm::IRPrinter *p) { .set_dispatch<LetNode>([](const LetNode* node, tvm::IRPrinter* p) {
p->stream << "LetNode(" << node->var << ", " << node->value p->stream << "LetNode(" << node->var << ", " << node->value
<< ", " << node->body << ")"; << ", " << node->body << ")";
}); });
...@@ -189,12 +189,12 @@ If IfNode::make(Expr cond, Expr true_branch, Expr false_branch) { ...@@ -189,12 +189,12 @@ If IfNode::make(Expr cond, Expr true_branch, Expr false_branch) {
TVM_REGISTER_NODE_TYPE(IfNode); TVM_REGISTER_NODE_TYPE(IfNode);
TVM_REGISTER_API("relay._make.If").set_body([](TVMArgs args, TVMRetValue *ret) { TVM_REGISTER_API("relay._make.If").set_body([](TVMArgs args, TVMRetValue* ret) {
*ret = IfNode::make(args[0], args[1], args[2]); *ret = IfNode::make(args[0], args[1], args[2]);
}); });
TVM_STATIC_IR_FUNCTOR_REGISTER(IRPrinter, vtable) TVM_STATIC_IR_FUNCTOR_REGISTER(IRPrinter, vtable)
.set_dispatch<IfNode>([](const IfNode *node, tvm::IRPrinter *p) { .set_dispatch<IfNode>([](const IfNode* node, tvm::IRPrinter* p) {
p->stream << "IfNode(" << node->cond << ", " << node->true_branch p->stream << "IfNode(" << node->cond << ", " << node->true_branch
<< ", " << node->false_branch << ")"; << ", " << node->false_branch << ")";
}); });
......
...@@ -25,14 +25,14 @@ TensorType TensorTypeNode::Scalar(DataType dtype) { ...@@ -25,14 +25,14 @@ TensorType TensorTypeNode::Scalar(DataType dtype) {
TVM_REGISTER_NODE_TYPE(TensorTypeNode); TVM_REGISTER_NODE_TYPE(TensorTypeNode);
TVM_REGISTER_API("relay._make.TensorType") TVM_REGISTER_API("relay._make.TensorType")
.set_body([](TVMArgs args, TVMRetValue *ret) { .set_body([](TVMArgs args, TVMRetValue* ret) {
Array<IndexExpr> shape = args[0]; Array<IndexExpr> shape = args[0];
*ret = TensorTypeNode::make(shape, args[1]); *ret = TensorTypeNode::make(shape, args[1]);
}); });
TVM_STATIC_IR_FUNCTOR_REGISTER(IRPrinter, vtable) TVM_STATIC_IR_FUNCTOR_REGISTER(IRPrinter, vtable)
.set_dispatch<TensorTypeNode>([](const TensorTypeNode *node, .set_dispatch<TensorTypeNode>([](const TensorTypeNode* node,
tvm::IRPrinter *p) { tvm::IRPrinter* p) {
p->stream << "TensorType(" << node->shape << ", " << node->dtype << ")"; p->stream << "TensorType(" << node->shape << ", " << node->dtype << ")";
}); });
...@@ -46,15 +46,15 @@ TypeVar TypeVarNode::make(std::string name, TypeVarNode::Kind kind) { ...@@ -46,15 +46,15 @@ TypeVar TypeVarNode::make(std::string name, TypeVarNode::Kind kind) {
TVM_REGISTER_NODE_TYPE(TypeVarNode); TVM_REGISTER_NODE_TYPE(TypeVarNode);
TVM_REGISTER_API("relay._make.TypeVar") TVM_REGISTER_API("relay._make.TypeVar")
.set_body([](TVMArgs args, TVMRetValue *ret) { .set_body([](TVMArgs args, TVMRetValue* ret) {
int kind = args[1]; int kind = args[1];
*ret = *ret =
TypeVarNode::make(args[0], static_cast<TypeVarNode::Kind>(kind)); TypeVarNode::make(args[0], static_cast<TypeVarNode::Kind>(kind));
}); });
TVM_STATIC_IR_FUNCTOR_REGISTER(IRPrinter, vtable) TVM_STATIC_IR_FUNCTOR_REGISTER(IRPrinter, vtable)
.set_dispatch<TypeVarNode>([](const TypeVarNode *node, .set_dispatch<TypeVarNode>([](const TypeVarNode* node,
tvm::IRPrinter *p) { tvm::IRPrinter* p) {
p->stream << "TypeVarNode(" << node->var->name_hint << ", " p->stream << "TypeVarNode(" << node->var->name_hint << ", "
<< node->kind << ")"; << node->kind << ")";
}); });
...@@ -95,13 +95,13 @@ FuncType FuncTypeNode::make(tvm::Array<Type> arg_types, ...@@ -95,13 +95,13 @@ FuncType FuncTypeNode::make(tvm::Array<Type> arg_types,
TVM_REGISTER_NODE_TYPE(FuncTypeNode); TVM_REGISTER_NODE_TYPE(FuncTypeNode);
TVM_REGISTER_API("relay._make.FuncType") TVM_REGISTER_API("relay._make.FuncType")
.set_body([](TVMArgs args, TVMRetValue *ret) { .set_body([](TVMArgs args, TVMRetValue* ret) {
*ret = FuncTypeNode::make(args[0], args[1], args[2], args[3]); *ret = FuncTypeNode::make(args[0], args[1], args[2], args[3]);
}); });
TVM_STATIC_IR_FUNCTOR_REGISTER(IRPrinter, vtable) TVM_STATIC_IR_FUNCTOR_REGISTER(IRPrinter, vtable)
.set_dispatch<FuncTypeNode>([](const FuncTypeNode *node, .set_dispatch<FuncTypeNode>([](const FuncTypeNode* node,
tvm::IRPrinter *p) { tvm::IRPrinter* p) {
p->stream << "FuncTypeNode(" << node->type_params << ", " p->stream << "FuncTypeNode(" << node->type_params << ", "
<< node->arg_types << ", " << node->ret_type << ", " << node->arg_types << ", " << node->ret_type << ", "
<< node->type_constraints << ")"; << node->type_constraints << ")";
...@@ -122,12 +122,12 @@ TypeRelation TypeRelationNode::make(TypeRelationFn func, ...@@ -122,12 +122,12 @@ TypeRelation TypeRelationNode::make(TypeRelationFn func,
TVM_REGISTER_NODE_TYPE(TypeRelationNode); TVM_REGISTER_NODE_TYPE(TypeRelationNode);
TVM_REGISTER_API("relay._make.TypeRelation") TVM_REGISTER_API("relay._make.TypeRelation")
.set_body([](TVMArgs args, TVMRetValue *ret) { .set_body([](TVMArgs args, TVMRetValue* ret) {
*ret = TypeRelationNode::make(args[0], args[1], args[2], args[3]); *ret = TypeRelationNode::make(args[0], args[1], args[2], args[3]);
}); });
TVM_STATIC_IR_FUNCTOR_REGISTER(IRPrinter, vtable) TVM_STATIC_IR_FUNCTOR_REGISTER(IRPrinter, vtable)
.set_dispatch<TypeRelationNode>([](const TypeRelationNode *node, tvm::IRPrinter *p) { .set_dispatch<TypeRelationNode>([](const TypeRelationNode* node, tvm::IRPrinter* p) {
p->stream << "TypeRelationNode(" p->stream << "TypeRelationNode("
<< node->func->name << node->func->name
<< ", " << node->args << ")"; << ", " << node->args << ")";
...@@ -142,13 +142,13 @@ TupleType TupleTypeNode::make(Array<Type> fields) { ...@@ -142,13 +142,13 @@ TupleType TupleTypeNode::make(Array<Type> fields) {
TVM_REGISTER_NODE_TYPE(TupleTypeNode); TVM_REGISTER_NODE_TYPE(TupleTypeNode);
TVM_REGISTER_API("relay._make.TupleType") TVM_REGISTER_API("relay._make.TupleType")
.set_body([](TVMArgs args, TVMRetValue *ret) { .set_body([](TVMArgs args, TVMRetValue* ret) {
*ret = TupleTypeNode::make(args[0]); *ret = TupleTypeNode::make(args[0]);
}); });
TVM_STATIC_IR_FUNCTOR_REGISTER(IRPrinter, vtable) TVM_STATIC_IR_FUNCTOR_REGISTER(IRPrinter, vtable)
.set_dispatch<TupleTypeNode>([](const TupleTypeNode *node, .set_dispatch<TupleTypeNode>([](const TupleTypeNode* node,
tvm::IRPrinter *p) { tvm::IRPrinter* p) {
p->stream << "TupleTypeNode(" << node->fields << ")"; p->stream << "TupleTypeNode(" << node->fields << ")";
}); });
......
...@@ -193,11 +193,13 @@ struct TypeAlphaEq : TypeVisitor<const Type&> { ...@@ -193,11 +193,13 @@ struct TypeAlphaEq : TypeVisitor<const Type&> {
}; };
bool AlphaEqual(const Type& t1, const Type& t2) { bool AlphaEqual(const Type& t1, const Type& t2) {
if (t1.defined() != t2.defined()) if (t1.defined() != t2.defined()) {
return false; return false;
}
if (!t1.defined()) if (!t1.defined()) {
return true; return true;
}
TypeAlphaEq aeq; TypeAlphaEq aeq;
aeq.VisitType(t1, t2); aeq.VisitType(t1, t2);
...@@ -273,7 +275,10 @@ struct AlphaEq : ExprFunctor<void(const Expr&, const Expr&)> { ...@@ -273,7 +275,10 @@ struct AlphaEq : ExprFunctor<void(const Expr&, const Expr&)> {
for (size_t i = 0; i < func1->params.size(); ++i) { for (size_t i = 0; i < func1->params.size(); ++i) {
MergeVarDecl(func1->params[i], func2->params[i]); MergeVarDecl(func1->params[i], func2->params[i]);
} }
if (!equal) return;
if (!equal) {
return;
}
for (size_t i = 0U; i < func1->type_params.size(); i++) { for (size_t i = 0U; i < func1->type_params.size(); i++) {
equal = equal && AlphaEqual(func1->type_params[i], func2->type_params[i]); equal = equal && AlphaEqual(func1->type_params[i], func2->type_params[i]);
......
...@@ -29,11 +29,11 @@ struct KindChecker : TypeVisitor<> { ...@@ -29,11 +29,11 @@ struct KindChecker : TypeVisitor<> {
// checks if t is an incomplete node of kind k or a type param of kind k // checks if t is an incomplete node of kind k or a type param of kind k
bool MatchKind(const Type& t, Kind k) { bool MatchKind(const Type& t, Kind k) {
if (const IncompleteTypeNode *tv = t.as<IncompleteTypeNode>()) { if (const IncompleteTypeNode* tv = t.as<IncompleteTypeNode>()) {
return tv->kind == k; return tv->kind == k;
} }
if (const TypeVarNode *tp = t.as<TypeVarNode>()) { if (const TypeVarNode* tp = t.as<TypeVarNode>()) {
return tp->kind == k; return tp->kind == k;
} }
...@@ -93,7 +93,7 @@ struct KindChecker : TypeVisitor<> { ...@@ -93,7 +93,7 @@ struct KindChecker : TypeVisitor<> {
} }
} }
bool Check(const Type &t) { bool Check(const Type& t) {
this->VisitType(t); this->VisitType(t);
return valid; return valid;
} }
......
...@@ -379,7 +379,7 @@ class TypeInferencer::Resolver : public ExprMutator { ...@@ -379,7 +379,7 @@ class TypeInferencer::Resolver : public ExprMutator {
return new_e; return new_e;
} }
Type VisitType(const Type &t) final { Type VisitType(const Type& t) final {
return solver_->Resolve(t); return solver_->Resolve(t);
} }
......
...@@ -14,10 +14,10 @@ namespace relay { ...@@ -14,10 +14,10 @@ namespace relay {
class FreeVar; class FreeVar;
class FreeTypeVar : private TypeVisitor<> { class FreeTypeVar : private TypeVisitor<> {
std::unordered_set<TypeVar, NodeHash, NodeEqual> * free_vars; std::unordered_set<TypeVar, NodeHash, NodeEqual>* free_vars;
std::unordered_set<TypeVar, NodeHash, NodeEqual> * bound_vars; std::unordered_set<TypeVar, NodeHash, NodeEqual>* bound_vars;
FreeTypeVar(std::unordered_set<TypeVar, NodeHash, NodeEqual> * free_vars, FreeTypeVar(std::unordered_set<TypeVar, NodeHash, NodeEqual>* free_vars,
std::unordered_set<TypeVar, NodeHash, NodeEqual> * bound_vars) : std::unordered_set<TypeVar, NodeHash, NodeEqual>* bound_vars) :
free_vars(free_vars), bound_vars(bound_vars) { } free_vars(free_vars), bound_vars(bound_vars) { }
void VisitType_(const TypeVarNode* tp) final { void VisitType_(const TypeVarNode* tp) final {
...@@ -45,7 +45,7 @@ class FreeTypeVar : private TypeVisitor<> { ...@@ -45,7 +45,7 @@ class FreeTypeVar : private TypeVisitor<> {
}; };
class FreeVar : public ExprVisitor { class FreeVar : public ExprVisitor {
void VisitExpr_(const VarNode *v) final { void VisitExpr_(const VarNode* v) final {
auto var = GetRef<Var>(v); auto var = GetRef<Var>(v);
if (bound_vars.count(var) == 0) { if (bound_vars.count(var) == 0) {
free_vars.insert(var); free_vars.insert(var);
...@@ -55,7 +55,7 @@ class FreeVar : public ExprVisitor { ...@@ -55,7 +55,7 @@ class FreeVar : public ExprVisitor {
} }
} }
void VisitExpr_(const FunctionNode *f) final { void VisitExpr_(const FunctionNode* f) final {
for (const auto& tp : f->type_params) { for (const auto& tp : f->type_params) {
bound_types.insert(tp); bound_types.insert(tp);
} }
...@@ -66,7 +66,7 @@ class FreeVar : public ExprVisitor { ...@@ -66,7 +66,7 @@ class FreeVar : public ExprVisitor {
VisitType(f->ret_type); VisitType(f->ret_type);
} }
void VisitExpr_(const LetNode *l) final { void VisitExpr_(const LetNode* l) final {
bound_vars.insert(l->var); bound_vars.insert(l->var);
VisitExpr(l->value); VisitExpr(l->value);
VisitExpr(l->body); VisitExpr(l->body);
......
...@@ -18,14 +18,14 @@ class WellFormedChecker : private ExprVisitor { ...@@ -18,14 +18,14 @@ class WellFormedChecker : private ExprVisitor {
std::unordered_set<Var, NodeHash, NodeEqual> s; std::unordered_set<Var, NodeHash, NodeEqual> s;
void Check(const Var & v) { void Check(const Var& v) {
if (s.count(v) != 0) { if (s.count(v) != 0) {
well_formed = false; well_formed = false;
} }
s.insert(v); s.insert(v);
} }
void VisitExpr_(const LetNode * l) final { void VisitExpr_(const LetNode* l) final {
// we do letrec only for FunctionNode, // we do letrec only for FunctionNode,
// but shadowing let in let binding is likely programming error, and we should forbidden it. // but shadowing let in let binding is likely programming error, and we should forbidden it.
Check(l->var); Check(l->var);
...@@ -33,21 +33,21 @@ class WellFormedChecker : private ExprVisitor { ...@@ -33,21 +33,21 @@ class WellFormedChecker : private ExprVisitor {
CheckWellFormed(l->body); CheckWellFormed(l->body);
} }
void VisitExpr_(const FunctionNode * f) final { void VisitExpr_(const FunctionNode* f) final {
for (const Var & param : f->params) { for (const Var& param : f->params) {
Check(param); Check(param);
} }
CheckWellFormed(f->body); CheckWellFormed(f->body);
} }
public: public:
bool CheckWellFormed(const Expr & e) { bool CheckWellFormed(const Expr& e) {
this->VisitExpr(e); this->VisitExpr(e);
return well_formed; return well_formed;
} }
}; };
bool WellFormed(const Expr & e) { bool WellFormed(const Expr& e) {
return WellFormedChecker().CheckWellFormed(e); return WellFormedChecker().CheckWellFormed(e);
} }
......
Markdown is supported
0% or
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment