Commit 09efdc9d by 雾雨魔理沙 Committed by Yizhi Liu

[Relay] Fix format (#1957)

* save

* fix format
parent 390acc52
......@@ -33,7 +33,7 @@ SourceName SourceName::Get(const std::string& name) {
}
TVM_STATIC_IR_FUNCTOR_REGISTER(IRPrinter, vtable)
.set_dispatch<SourceNameNode>([](const SourceNameNode *node, tvm::IRPrinter *p) {
.set_dispatch<SourceNameNode>([](const SourceNameNode* node, tvm::IRPrinter* p) {
p->stream << "SourceName(" << node->name << ", " << node << ")";
});
......@@ -54,12 +54,12 @@ Span SpanNode::make(SourceName source, int lineno, int col_offset) {
TVM_REGISTER_NODE_TYPE(SpanNode);
TVM_REGISTER_API("relay._make.Span")
.set_body([](TVMArgs args, TVMRetValue *ret) {
.set_body([](TVMArgs args, TVMRetValue* ret) {
*ret = SpanNode::make(args[0], args[1], args[2]);
});
TVM_STATIC_IR_FUNCTOR_REGISTER(IRPrinter, vtable)
.set_dispatch<SpanNode>([](const SpanNode *node, tvm::IRPrinter *p) {
.set_dispatch<SpanNode>([](const SpanNode* node, tvm::IRPrinter* p) {
p->stream << "SpanNode(" << node->source << ", " << node->lineno << ", "
<< node->col_offset << ")";
});
......
......@@ -73,12 +73,12 @@ Function EnvironmentNode::Lookup(const GlobalVar& var) {
return (*it).second;
}
Function EnvironmentNode::Lookup(const std::string &name) {
Function EnvironmentNode::Lookup(const std::string& name) {
GlobalVar id = this->GetGlobalVar(name);
return this->Lookup(id);
}
void EnvironmentNode::Update(const Environment &env) {
void EnvironmentNode::Update(const Environment& env) {
for (auto pair : env->functions) {
this->Update(pair.first, pair.second);
}
......
......@@ -20,12 +20,12 @@ Constant ConstantNode::make(runtime::NDArray data) {
TVM_REGISTER_NODE_TYPE(ConstantNode);
TVM_REGISTER_API("relay._make.Constant")
.set_body([](TVMArgs args, TVMRetValue *ret) {
.set_body([](TVMArgs args, TVMRetValue* ret) {
*ret = ConstantNode::make(args[0]);
});
TVM_STATIC_IR_FUNCTOR_REGISTER(IRPrinter, vtable)
.set_dispatch<ConstantNode>([](const ConstantNode *node, tvm::IRPrinter *p) {
.set_dispatch<ConstantNode>([](const ConstantNode* node, tvm::IRPrinter* p) {
p->stream << "Constant(TODO)";
});
......@@ -49,12 +49,12 @@ Tuple TupleNode::make(tvm::Array<relay::Expr> fields) {
TVM_REGISTER_NODE_TYPE(TupleNode);
TVM_REGISTER_API("relay._make.Tuple")
.set_body([](TVMArgs args, TVMRetValue *ret) {
.set_body([](TVMArgs args, TVMRetValue* ret) {
*ret = TupleNode::make(args[0]);
});
TVM_STATIC_IR_FUNCTOR_REGISTER(IRPrinter, vtable)
.set_dispatch<TupleNode>([](const TupleNode *node, tvm::IRPrinter *p) {
.set_dispatch<TupleNode>([](const TupleNode* node, tvm::IRPrinter* p) {
p->stream << "Tuple(" << node->fields << ")";
});
......@@ -68,12 +68,12 @@ Var VarNode::make(std::string name_hint, Type type_annotation) {
TVM_REGISTER_NODE_TYPE(VarNode);
TVM_REGISTER_API("relay._make.Var")
.set_body([](TVMArgs args, TVMRetValue *ret) {
.set_body([](TVMArgs args, TVMRetValue* ret) {
*ret = VarNode::make(args[0], args[1]);
});
TVM_STATIC_IR_FUNCTOR_REGISTER(IRPrinter, vtable)
.set_dispatch<VarNode>([](const VarNode *node, tvm::IRPrinter *p) {
.set_dispatch<VarNode>([](const VarNode* node, tvm::IRPrinter* p) {
p->stream << "Var(" << node->name_hint;
if (node->type_annotation.defined()) {
p->stream << ", ty=";
......@@ -91,12 +91,12 @@ GlobalVar GlobalVarNode::make(std::string name_hint) {
TVM_REGISTER_NODE_TYPE(GlobalVarNode);
TVM_REGISTER_API("relay._make.GlobalVar")
.set_body([](TVMArgs args, TVMRetValue *ret) {
.set_body([](TVMArgs args, TVMRetValue* ret) {
*ret = GlobalVarNode::make(args[0]);
});
TVM_STATIC_IR_FUNCTOR_REGISTER(IRPrinter, vtable)
.set_dispatch<GlobalVarNode>([](const GlobalVarNode *node, tvm::IRPrinter *p) {
.set_dispatch<GlobalVarNode>([](const GlobalVarNode* node, tvm::IRPrinter* p) {
p->stream << "GlobalVar(" << node->name_hint << ")";
});
......@@ -124,13 +124,13 @@ FuncType FunctionNode::func_type_annotation() const {
TVM_REGISTER_NODE_TYPE(FunctionNode);
TVM_REGISTER_API("relay._make.Function")
.set_body([](TVMArgs args, TVMRetValue *ret) {
.set_body([](TVMArgs args, TVMRetValue* ret) {
*ret = FunctionNode::make(args[0], args[1], args[2], args[3]);
});
TVM_STATIC_IR_FUNCTOR_REGISTER(IRPrinter, vtable)
.set_dispatch<FunctionNode>([](const FunctionNode *node,
tvm::IRPrinter *p) {
.set_dispatch<FunctionNode>([](const FunctionNode* node,
tvm::IRPrinter* p) {
p->stream << "FunctionNode(" << node->params << ", " << node->ret_type
<< ", " << node->body << ", " << node->type_params << ")";
});
......@@ -148,12 +148,12 @@ Call CallNode::make(Expr op, Array<Expr> args, Attrs attrs,
TVM_REGISTER_NODE_TYPE(CallNode);
TVM_REGISTER_API("relay._make.Call")
.set_body([](TVMArgs args, TVMRetValue *ret) {
.set_body([](TVMArgs args, TVMRetValue* ret) {
*ret = CallNode::make(args[0], args[1], args[2], args[3]);
});
TVM_STATIC_IR_FUNCTOR_REGISTER(IRPrinter, vtable)
.set_dispatch<CallNode>([](const CallNode *node, tvm::IRPrinter *p) {
.set_dispatch<CallNode>([](const CallNode* node, tvm::IRPrinter* p) {
p->stream << "CallNode(" << node->op << ", " << node->args << ", "
<< node->attrs << ", " << node->type_args << ")";
});
......@@ -169,12 +169,12 @@ Let LetNode::make(Var var, Expr value, Expr body) {
TVM_REGISTER_NODE_TYPE(LetNode);
TVM_REGISTER_API("relay._make.Let")
.set_body([](TVMArgs args, TVMRetValue *ret) {
.set_body([](TVMArgs args, TVMRetValue* ret) {
*ret = LetNode::make(args[0], args[1], args[2]);
});
TVM_STATIC_IR_FUNCTOR_REGISTER(IRPrinter, vtable)
.set_dispatch<LetNode>([](const LetNode *node, tvm::IRPrinter *p) {
.set_dispatch<LetNode>([](const LetNode* node, tvm::IRPrinter* p) {
p->stream << "LetNode(" << node->var << ", " << node->value
<< ", " << node->body << ")";
});
......@@ -189,12 +189,12 @@ If IfNode::make(Expr cond, Expr true_branch, Expr false_branch) {
TVM_REGISTER_NODE_TYPE(IfNode);
TVM_REGISTER_API("relay._make.If").set_body([](TVMArgs args, TVMRetValue *ret) {
TVM_REGISTER_API("relay._make.If").set_body([](TVMArgs args, TVMRetValue* ret) {
*ret = IfNode::make(args[0], args[1], args[2]);
});
TVM_STATIC_IR_FUNCTOR_REGISTER(IRPrinter, vtable)
.set_dispatch<IfNode>([](const IfNode *node, tvm::IRPrinter *p) {
.set_dispatch<IfNode>([](const IfNode* node, tvm::IRPrinter* p) {
p->stream << "IfNode(" << node->cond << ", " << node->true_branch
<< ", " << node->false_branch << ")";
});
......
......@@ -25,14 +25,14 @@ TensorType TensorTypeNode::Scalar(DataType dtype) {
TVM_REGISTER_NODE_TYPE(TensorTypeNode);
TVM_REGISTER_API("relay._make.TensorType")
.set_body([](TVMArgs args, TVMRetValue *ret) {
.set_body([](TVMArgs args, TVMRetValue* ret) {
Array<IndexExpr> shape = args[0];
*ret = TensorTypeNode::make(shape, args[1]);
});
TVM_STATIC_IR_FUNCTOR_REGISTER(IRPrinter, vtable)
.set_dispatch<TensorTypeNode>([](const TensorTypeNode *node,
tvm::IRPrinter *p) {
.set_dispatch<TensorTypeNode>([](const TensorTypeNode* node,
tvm::IRPrinter* p) {
p->stream << "TensorType(" << node->shape << ", " << node->dtype << ")";
});
......@@ -46,15 +46,15 @@ TypeVar TypeVarNode::make(std::string name, TypeVarNode::Kind kind) {
TVM_REGISTER_NODE_TYPE(TypeVarNode);
TVM_REGISTER_API("relay._make.TypeVar")
.set_body([](TVMArgs args, TVMRetValue *ret) {
.set_body([](TVMArgs args, TVMRetValue* ret) {
int kind = args[1];
*ret =
TypeVarNode::make(args[0], static_cast<TypeVarNode::Kind>(kind));
});
TVM_STATIC_IR_FUNCTOR_REGISTER(IRPrinter, vtable)
.set_dispatch<TypeVarNode>([](const TypeVarNode *node,
tvm::IRPrinter *p) {
.set_dispatch<TypeVarNode>([](const TypeVarNode* node,
tvm::IRPrinter* p) {
p->stream << "TypeVarNode(" << node->var->name_hint << ", "
<< node->kind << ")";
});
......@@ -95,13 +95,13 @@ FuncType FuncTypeNode::make(tvm::Array<Type> arg_types,
TVM_REGISTER_NODE_TYPE(FuncTypeNode);
TVM_REGISTER_API("relay._make.FuncType")
.set_body([](TVMArgs args, TVMRetValue *ret) {
.set_body([](TVMArgs args, TVMRetValue* ret) {
*ret = FuncTypeNode::make(args[0], args[1], args[2], args[3]);
});
TVM_STATIC_IR_FUNCTOR_REGISTER(IRPrinter, vtable)
.set_dispatch<FuncTypeNode>([](const FuncTypeNode *node,
tvm::IRPrinter *p) {
.set_dispatch<FuncTypeNode>([](const FuncTypeNode* node,
tvm::IRPrinter* p) {
p->stream << "FuncTypeNode(" << node->type_params << ", "
<< node->arg_types << ", " << node->ret_type << ", "
<< node->type_constraints << ")";
......@@ -122,12 +122,12 @@ TypeRelation TypeRelationNode::make(TypeRelationFn func,
TVM_REGISTER_NODE_TYPE(TypeRelationNode);
TVM_REGISTER_API("relay._make.TypeRelation")
.set_body([](TVMArgs args, TVMRetValue *ret) {
.set_body([](TVMArgs args, TVMRetValue* ret) {
*ret = TypeRelationNode::make(args[0], args[1], args[2], args[3]);
});
TVM_STATIC_IR_FUNCTOR_REGISTER(IRPrinter, vtable)
.set_dispatch<TypeRelationNode>([](const TypeRelationNode *node, tvm::IRPrinter *p) {
.set_dispatch<TypeRelationNode>([](const TypeRelationNode* node, tvm::IRPrinter* p) {
p->stream << "TypeRelationNode("
<< node->func->name
<< ", " << node->args << ")";
......@@ -142,13 +142,13 @@ TupleType TupleTypeNode::make(Array<Type> fields) {
TVM_REGISTER_NODE_TYPE(TupleTypeNode);
TVM_REGISTER_API("relay._make.TupleType")
.set_body([](TVMArgs args, TVMRetValue *ret) {
.set_body([](TVMArgs args, TVMRetValue* ret) {
*ret = TupleTypeNode::make(args[0]);
});
TVM_STATIC_IR_FUNCTOR_REGISTER(IRPrinter, vtable)
.set_dispatch<TupleTypeNode>([](const TupleTypeNode *node,
tvm::IRPrinter *p) {
.set_dispatch<TupleTypeNode>([](const TupleTypeNode* node,
tvm::IRPrinter* p) {
p->stream << "TupleTypeNode(" << node->fields << ")";
});
......
......@@ -193,11 +193,13 @@ struct TypeAlphaEq : TypeVisitor<const Type&> {
};
bool AlphaEqual(const Type& t1, const Type& t2) {
if (t1.defined() != t2.defined())
if (t1.defined() != t2.defined()) {
return false;
}
if (!t1.defined())
if (!t1.defined()) {
return true;
}
TypeAlphaEq aeq;
aeq.VisitType(t1, t2);
......@@ -273,7 +275,10 @@ struct AlphaEq : ExprFunctor<void(const Expr&, const Expr&)> {
for (size_t i = 0; i < func1->params.size(); ++i) {
MergeVarDecl(func1->params[i], func2->params[i]);
}
if (!equal) return;
if (!equal) {
return;
}
for (size_t i = 0U; i < func1->type_params.size(); i++) {
equal = equal && AlphaEqual(func1->type_params[i], func2->type_params[i]);
......
......@@ -29,11 +29,11 @@ struct KindChecker : TypeVisitor<> {
// checks if t is an incomplete node of kind k or a type param of kind k
bool MatchKind(const Type& t, Kind k) {
if (const IncompleteTypeNode *tv = t.as<IncompleteTypeNode>()) {
if (const IncompleteTypeNode* tv = t.as<IncompleteTypeNode>()) {
return tv->kind == k;
}
if (const TypeVarNode *tp = t.as<TypeVarNode>()) {
if (const TypeVarNode* tp = t.as<TypeVarNode>()) {
return tp->kind == k;
}
......@@ -93,7 +93,7 @@ struct KindChecker : TypeVisitor<> {
}
}
bool Check(const Type &t) {
bool Check(const Type& t) {
this->VisitType(t);
return valid;
}
......
......@@ -379,7 +379,7 @@ class TypeInferencer::Resolver : public ExprMutator {
return new_e;
}
Type VisitType(const Type &t) final {
Type VisitType(const Type& t) final {
return solver_->Resolve(t);
}
......
......@@ -14,10 +14,10 @@ namespace relay {
class FreeVar;
class FreeTypeVar : private TypeVisitor<> {
std::unordered_set<TypeVar, NodeHash, NodeEqual> * free_vars;
std::unordered_set<TypeVar, NodeHash, NodeEqual> * bound_vars;
FreeTypeVar(std::unordered_set<TypeVar, NodeHash, NodeEqual> * free_vars,
std::unordered_set<TypeVar, NodeHash, NodeEqual> * bound_vars) :
std::unordered_set<TypeVar, NodeHash, NodeEqual>* free_vars;
std::unordered_set<TypeVar, NodeHash, NodeEqual>* bound_vars;
FreeTypeVar(std::unordered_set<TypeVar, NodeHash, NodeEqual>* free_vars,
std::unordered_set<TypeVar, NodeHash, NodeEqual>* bound_vars) :
free_vars(free_vars), bound_vars(bound_vars) { }
void VisitType_(const TypeVarNode* tp) final {
......@@ -45,7 +45,7 @@ class FreeTypeVar : private TypeVisitor<> {
};
class FreeVar : public ExprVisitor {
void VisitExpr_(const VarNode *v) final {
void VisitExpr_(const VarNode* v) final {
auto var = GetRef<Var>(v);
if (bound_vars.count(var) == 0) {
free_vars.insert(var);
......@@ -55,7 +55,7 @@ class FreeVar : public ExprVisitor {
}
}
void VisitExpr_(const FunctionNode *f) final {
void VisitExpr_(const FunctionNode* f) final {
for (const auto& tp : f->type_params) {
bound_types.insert(tp);
}
......@@ -66,7 +66,7 @@ class FreeVar : public ExprVisitor {
VisitType(f->ret_type);
}
void VisitExpr_(const LetNode *l) final {
void VisitExpr_(const LetNode* l) final {
bound_vars.insert(l->var);
VisitExpr(l->value);
VisitExpr(l->body);
......
......@@ -18,14 +18,14 @@ class WellFormedChecker : private ExprVisitor {
std::unordered_set<Var, NodeHash, NodeEqual> s;
void Check(const Var & v) {
void Check(const Var& v) {
if (s.count(v) != 0) {
well_formed = false;
}
s.insert(v);
}
void VisitExpr_(const LetNode * l) final {
void VisitExpr_(const LetNode* l) final {
// we do letrec only for FunctionNode,
// but shadowing let in let binding is likely programming error, and we should forbidden it.
Check(l->var);
......@@ -33,21 +33,21 @@ class WellFormedChecker : private ExprVisitor {
CheckWellFormed(l->body);
}
void VisitExpr_(const FunctionNode * f) final {
for (const Var & param : f->params) {
void VisitExpr_(const FunctionNode* f) final {
for (const Var& param : f->params) {
Check(param);
}
CheckWellFormed(f->body);
}
public:
bool CheckWellFormed(const Expr & e) {
bool CheckWellFormed(const Expr& e) {
this->VisitExpr(e);
return well_formed;
}
};
bool WellFormed(const Expr & e) {
bool WellFormed(const Expr& e) {
return WellFormedChecker().CheckWellFormed(e);
}
......
Markdown is supported
0% or
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment