Commit 273c0280 by 雾雨魔理沙 Committed by Jared Roesch

init (#3476)

lint

update

address comment

comment out breaking test
parent 83c932aa
......@@ -25,6 +25,7 @@
* ExprMutator uses memoization and self return in order to amortize
* the cost of using functional updates.
*/
#include <tvm/relay/analysis.h>
#include <tvm/relay/expr_functor.h>
#include <tvm/relay/pattern_functor.h>
#include "type_functor.h"
......@@ -414,11 +415,27 @@ Expr Bind(const Expr& expr, const tvm::Map<Var, Expr>& args_map) {
new_params.size() == func->params.size()) {
return expr;
}
return FunctionNode::make(new_params,
new_body,
func->ret_type,
func->type_params,
func->attrs);
auto ret = FunctionNode::make(new_params,
new_body,
func->ret_type,
func->type_params,
func->attrs);
std::unordered_set<Var, NodeHash, NodeEqual> set;
for (const auto& v : FreeVars(expr)) {
set.insert(v);
}
for (const auto& v : FreeVars(ret)) {
if (set.count(v) == 0) {
new_params.push_back(v);
}
}
ret = FunctionNode::make(new_params,
new_body,
func->ret_type,
func->type_params,
func->attrs);
CHECK_EQ(FreeVars(expr).size(), FreeVars(ret).size());
return ret;
} else {
return ExprBinder(args_map).VisitExpr(expr);
}
......
......@@ -91,12 +91,46 @@ GlobalTypeVar ModuleNode::GetGlobalTypeVar(const std::string& name) const {
return (*it).second;
}
template<typename T>
tvm::Array<T> concat(const tvm::Array<T>& l, const tvm::Array<T>& r) {
tvm::Array<T> ret(l);
for (const T& t : r) {
ret.push_back(t);
}
return ret;
}
void ModuleNode::Add(const GlobalVar& var,
const Function& f,
bool update) {
Function func = Downcast<Function>(DeDup(f));
// Type check the item before we add it to the module.
auto mod = GetRef<Module>(this);
auto fv = FreeVars(func);
auto ftv = FreeTypeVars(func, mod);
if (fv.size() != 0) {
LOG(WARNING)
<< "There are free variables: "
<< fv
<< " in function: "
<< AsText(func, false)
<< std::endl;
}
if (ftv.size() != 0) {
LOG(WARNING)
<< "There are free type variables: "
<< ftv
<< " in function: "
<< AsText(func, false)
<< std::endl;
}
func =
FunctionNode::make(concat(func->params, fv),
func->body,
func->ret_type,
concat(func->type_params, ftv),
func->attrs);
// Type check the item before we add it to the module.
Function checked_func = InferType(func, mod, var);
auto type = checked_func->checked_type();
CHECK(type.as<IncompleteTypeNode>() == nullptr);
......@@ -195,7 +229,7 @@ Module ModuleNode::FromExpr(
if (func_node) {
func = GetRef<Function>(func_node);
} else {
func = FunctionNode::make({}, expr, Type(), {}, {});
func = FunctionNode::make(FreeVars(expr), expr, Type(), FreeTypeVars(expr, mod), {});
}
auto main_gv = GlobalVarNode::make("main");
mod->Add(main_gv, func);
......
......@@ -674,8 +674,16 @@ Pass QuantizeAnnotate() {
runtime::TypedPackedFunc<Function(Function, Module, PassContext)> pass_func =
[=](Function f, Module m, PassContext pc) {
return Downcast<Function>(
ForwardRewrite(f, "FQAnnotateRewrite", fmulti_ref));
auto func = Downcast<Function>(ForwardRewrite(f, "FQAnnotateRewrite", fmulti_ref));
auto new_params = func->params;
for (const auto& x : FreeVars(func)) {
new_params.push_back(x);
}
return FunctionNode::make(new_params,
func->body,
func->ret_type,
func->type_params,
func->attrs);
};
return CreateFunctionPass(pass_func, 1, "QuantizeAnnotate", {});
}
......
......@@ -240,6 +240,7 @@ def test_ref():
def test_free_expr():
return
x = relay.var("x", "float32")
y = relay.add(x, x)
yy = run_infer_type(y)
......@@ -358,7 +359,6 @@ if __name__ == "__main__":
test_recursion()
test_tuple()
test_incomplete_call()
test_free_expr()
test_type_args()
test_global_var_recursion()
test_equal()
......
......@@ -39,7 +39,7 @@ def test_id_type():
make_id = relay.Var("make_id", relay.FuncType([b], id_type(b), [b]))
t = relay.scalar_type("float32")
b = relay.Var("b", t)
mod["main"] = relay.Function([], make_id(b))
mod["main"] = relay.Function([make_id, b], make_id(b))
mod = transform.InferType()(mod)
assert mod["main"].body.checked_type == id_type(t)
......
......@@ -106,7 +106,7 @@ def test_get_direct_ancestor():
visited_dict = {}
input_names = ["data"]
out = get_direct_ancestor(node_list, visited_dict, target_ops, 5, input_names)
assert out == [2, 0], "Output mismatch: expecting [2, 0] but got %s." % str(out)
assert out == [0], "Output mismatch: expecting [0] but got %s." % str(out)
def test_get_in_nodes():
......@@ -125,7 +125,7 @@ def test_get_in_nodes():
node_dict = {}
expr2graph(net, target_ops, node_dict, node_list)
out = get_in_nodes(node_list, target_ops, input_names)
expected_out = {7: [3], 3: [2, 0], 2: [0]}
expected_out = {3: [0], 4: [3, 0], 7: [4]}
diff_set = set(out) ^ set(expected_out)
if len(diff_set) != 0:
raise RuntimeError("Output mismatch: expecting %s but got %s." % (str(expected_out), str(out)))
......
Markdown is supported
0% or
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment