Unverified Commit 7902f762 by Adrian Muresan Committed by GitHub

Fixed typo and type mismatch (#5259)

Co-authored-by: Adrian Muresan <muresan.adrian.bn@gmail.com>
parent 8df97ff6
...@@ -216,13 +216,13 @@ class CustomPipeline: ...@@ -216,13 +216,13 @@ class CustomPipeline:
obj = self obj = self
class ReplaceConstant(tvm.relay.ExprMutator): class ReplaceConstant(tvm.relay.ExprMutator):
def visit_const(self, c): def visit_constant(self, c):
return relay.multiply(obj.multiplier, c) return relay.multiply(obj.multiplier, c)
return ReplaceConstant().visit(func) return ReplaceConstant().visit(func)
f = example() f = example()
mod = tvm.IRModule.from_expr(f) mod = tvm.IRModule.from_expr(f)
custom_pass = CustomPipeline(multiplier=relay.const(3, "float")) custom_pass = CustomPipeline(multiplier=relay.const(3, "float32"))
assert custom_pass.info.name == "CustomPipeline" assert custom_pass.info.name == "CustomPipeline"
mod3 = custom_pass(mod) mod3 = custom_pass(mod)
print(mod3) print(mod3)
......
Markdown is supported
0% or
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment