Commit c3f02c4b by Altan Haan Committed by Jared Roesch

add missing gradient check to gradient pass (#4169)

parent 5a177070
...@@ -351,8 +351,6 @@ struct ReverseAD : ExprMutator { ...@@ -351,8 +351,6 @@ struct ReverseAD : ExprMutator {
Expr VisitExpr_(const CallNode* op) final { Expr VisitExpr_(const CallNode* op) final {
if (const OpNode* op_node = op->op.as<OpNode>()) { if (const OpNode* op_node = op->op.as<OpNode>()) {
Op op_ref = GetRef<Op>(op_node); Op op_ref = GetRef<Op>(op_node);
CHECK(rev_map.count(op_ref))
<< op_node->name << " does not have reverse mode defined";
return LetList::With([&](LetList* ll) { return LetList::With([&](LetList* ll) {
std::vector<Var> args; std::vector<Var> args;
for (const auto& arg : op->args) { for (const auto& arg : op->args) {
...@@ -408,6 +406,34 @@ Expr BPEmpty() { ...@@ -408,6 +406,34 @@ Expr BPEmpty() {
return RefCreateNode::make(unitF); return RefCreateNode::make(unitF);
} }
bool MissingGrad(const Expr& e) {
struct MGVisitor : ExprVisitor {
const OpMap<FPrimalGradient> rev_map = Op::GetAttr<FPrimalGradient>("FPrimalGradient");
std::unordered_set<std::string> op_names;
void VisitExpr_(const OpNode* op) final {
Op op_ref = GetRef<Op>(op);
if (!rev_map.count(op_ref)) {
op_names.insert(op_ref->name);
}
ExprVisitor::VisitExpr_(op);
}
};
MGVisitor mg;
mg.VisitExpr(e);
if (mg.op_names.size() > 0) {
LOG(WARNING) << "found operators with missing gradients:";
for (const auto& op : mg.op_names) {
LOG(WARNING) << " " << op;
}
return true;
}
return false;
}
Expr Gradient(const Expr& re, const Module& mod) { Expr Gradient(const Expr& re, const Module& mod) {
auto e = DeGlobal(mod, re); auto e = DeGlobal(mod, re);
auto f = e.as<FunctionNode>(); auto f = e.as<FunctionNode>();
...@@ -416,6 +442,7 @@ Expr Gradient(const Expr& re, const Module& mod) { ...@@ -416,6 +442,7 @@ Expr Gradient(const Expr& re, const Module& mod) {
for (const auto& p : f->params) { for (const auto& p : f->params) {
CHECK(p->checked_type().as<TensorTypeNode>()) << "input parameters need to be tensor"; CHECK(p->checked_type().as<TensorTypeNode>()) << "input parameters need to be tensor";
} }
CHECK(!MissingGrad(e)) << "input has operators with missing gradients";
Expr body = LetList::With([&](LetList* ll) { Expr body = LetList::With([&](LetList* ll) {
Var bp = ll->Push(BPEmpty()); Var bp = ll->Push(BPEmpty());
Expr rev = ReverseAD(bp)(e); Expr rev = ReverseAD(bp)(e);
......
Markdown is supported
0% or
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment