Unverified commit e6dd8e1e authored by Andrew Liu, committed by GitHub

[Relay] GradientCell Relay Pass (#5039)

* save

* gradient.rly

* fix

* NOT WORKING: gradient cell pass

* test gradient pass

* fixed basic call ops

* more tests

* fix bug

* transform calls to one ones_like zero zero_like

* maintenance stuff

* fix linting

* linting

* linting

* throw default

* remove unrelated changes

* import gradent.rly in pass

* comment

* linting

* remove changes to test files

* move gradient_cell.cc to transforms

* revert change

* update files with new commits

* type

* wrapper function to main outermost function type

* fix linting

* fix unsigned and signed int comparison

* review

* GetConstructor definition in module and change op comparison

* update node instantiations

* increase code readability

Co-authored-by: Marisa Kirisame <lolisa@marisa.moe>
parent a6de507b
...@@ -163,6 +163,14 @@ class IRModuleNode : public Object {
TVM_DLL Array<GlobalTypeVar> GetGlobalTypeVars() const;
/*!
* \brief Find constructor of ADT using name
* \param adt name of the ADT the constructor belongs to
* \param cons name of the constructor
* \returns Constructor of ADT, error if not found
*/
TVM_DLL Constructor GetConstructor(const std::string& adt, const std::string& cons) const;
/*!
* \brief Look up a global function by its variable.
* \param var The global var to lookup.
* \returns The function named by the variable argument.
......
...@@ -78,6 +78,20 @@ TVM_DLL Pass CreateFunctionPass(const runtime::TypedPackedFunc<
TVM_DLL Pass DeadCodeElimination(bool inline_once = false);
/*!
* \brief Convert all expressions of TensorType into GradCell,
* an algebraic data type defined in gradient.rly.
*
* This will delay or decrease memory usage. Calls to
* ones, ones_like, zeros, and zeros_like do not immediately instantiate a tensor in memory;
* the tensor is only instantiated when needed. The pass also defines + and *
* operations between GradCell values, which can improve performance when using
* zero-filled or one-filled tensors, as is common in reverse-mode AD.
*
* \return The pass.
*/
TVM_DLL Pass LazyGradientInit();
/*!
* \brief Fold constant expressions.
*
* \return The pass.
......
/*
* Licensed to the Apache Software Foundation (ASF) under one
* or more contributor license agreements. See the NOTICE file
* distributed with this work for additional information
* regarding copyright ownership. The ASF licenses this file
* to you under the Apache License, Version 2.0 (the
* "License"); you may not use this file except in compliance
* with the License. You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing,
* software distributed under the License is distributed on an
* "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
* KIND, either express or implied. See the License for the
* specific language governing permissions and limitations
* under the License.
*/
v0.0.4
/*
* Stores the gradient value of a tensor of type T.
* Note that the gradient of T is stored inside a Ref(GradCell[T]) rather than directly as GradCell[T].
*/
type GradCell[T] {
Raw(T),
One(fn() -> T),
Zero(fn() -> T)
}
def @FromGradCell[T](%g: GradCell[T]) -> T {
match (%g) {
Raw(%x) => %x,
One(%x) => %x(),
Zero(%x) => %x()
}
}
def @MultiplyGradCell[T](%multiply: fn(T, T) -> T, %l: GradCell[T], %r: GradCell[T]) -> GradCell[T] {
match((%l, %r)) {
(Zero(_), _) => %l,
(_, Zero(_)) => %r,
(One(_), _) => %r,
(_, One(_)) => %l,
_ => Raw(%multiply(@FromGradCell(%l), @FromGradCell(%r)))
}
}
def @AddGradCell[T](%add: fn(T, T) -> T, %l: GradCell[T], %r: GradCell[T]) -> GradCell[T] {
match ((%l, %r)) {
(Zero(_), _) => %r,
(_, Zero(_)) => %l,
_ => Raw(%add(@FromGradCell(%l), @FromGradCell(%r)))
}
}
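For intuition, here is a rough Python analogue of the GradCell semantics above (illustrative only, not part of the patch; the GradCell class and the add/multiply helpers are hypothetical stand-ins for the ADT and the overloaded functions):

import numpy as np

# Hypothetical analogue of the GradCell ADT: a value is either a materialized
# tensor (Raw), a lazily-built all-ones tensor (One), or a lazily-built
# all-zeros tensor (Zero).
class GradCell:
    def __init__(self, tag, payload):
        self.tag = tag          # "raw", "one", or "zero"
        self.payload = payload  # ndarray for "raw", thunk () -> ndarray otherwise

    def force(self):
        # Analogue of @FromGradCell: materialize the tensor.
        return self.payload if self.tag == "raw" else self.payload()

def add(l, r):
    # Analogue of @AddGradCell: adding zero returns the other operand unchanged.
    if l.tag == "zero":
        return r
    if r.tag == "zero":
        return l
    return GradCell("raw", l.force() + r.force())

def multiply(l, r):
    # Analogue of @MultiplyGradCell: zero annihilates, one is the identity.
    if l.tag == "zero":
        return l
    if r.tag == "zero":
        return r
    if l.tag == "one":
        return r
    if r.tag == "one":
        return l
    return GradCell("raw", l.force() * r.force())

x = GradCell("raw", np.full((2, 2), 3.0))
zeros = GradCell("zero", lambda: np.zeros((2, 2)))  # never materialized below
print(add(x, zeros).force())   # returns x's tensor without allocating the zeros
print(multiply(x, zeros).tag)  # "zero": no multiplication is performed

The point is that adding or multiplying by a Zero/One cell never materializes the zero- or one-filled tensor.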
...@@ -219,6 +219,19 @@ def DeadCodeElimination(inline_once=False):
"""
return _ffi_api.DeadCodeElimination(inline_once)
def LazyGradientInit():
"""Reduce memory usage of gradient tensors.
Returns
-------
ret : tvm.relay.Pass
A pass which delays and/or reduces memory allocation by lazily
allocating zero- or one-filled tensors.
"""
return _ffi_api.LazyGradientInit()
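A minimal usage sketch, mirroring the unit tests added with this pass (the shape and expression are illustrative):

import tvm
from tvm import relay

t = relay.TensorType((10, 10), "float32")
x = relay.var("x", t)
mod = tvm.IRModule()
mod["main"] = relay.Function([x], x + x)
mod = relay.transform.LazyGradientInit()(mod)
# Input/output types of "main" are unchanged by the pass.
assert mod["main"].checked_type == relay.FuncType([t], t)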
def FoldConstant():
"""Fold the constant expressions in a Relay program.
......
...@@ -96,6 +96,18 @@ GlobalTypeVar IRModuleNode::GetGlobalTypeVar(const std::string& name) const {
return (*it).second;
}
Constructor IRModuleNode::GetConstructor(const std::string& adt, const std::string& cons) const {
TypeData typeDef = this->LookupTypeDef(adt);
for (Constructor c : typeDef->constructors) {
if (cons.compare(c->name_hint) == 0) {
return c;
}
}
LOG(FATAL) << adt << " does not contain constructor " << cons;
throw std::runtime_error("Constructor Not Found.");
}
tvm::Array<GlobalTypeVar> IRModuleNode::GetGlobalTypeVars() const {
std::vector<GlobalTypeVar> global_type_vars;
for (const auto& pair : global_type_var_map_) {
......
/*
* Licensed to the Apache Software Foundation (ASF) under one
* or more contributor license agreements. See the NOTICE file
* distributed with this work for additional information
* regarding copyright ownership. The ASF licenses this file
* to you under the Apache License, Version 2.0 (the
* "License"); you may not use this file except in compliance
* with the License. You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing,
* software distributed under the License is distributed on an
* "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
* KIND, either express or implied. See the License for the
* specific language governing permissions and limitations
* under the License.
*/
/*!
*
* \file lazy_gradient_init.cc
*
* \brief Lazily instantiate 0-filled or 1-filled tensors.
* This pass should be used after reverse-mode AD so that gradient tensors
* are not instantiated until after the forward pass.
*
* This pass delays or removes memory allocation by converting tensors into
* GradCell, an algebraic data type defined in gradient.rly.
*
* This will delay or decrease memory usage. All calls to
* ones, ones_like, zeros, and zeros_like will call the One or Zero constructor
* of GradCell, so the tensor is not instantiated in memory until needed.
* All other cases use the Raw constructor, meaning the tensor is instantiated in memory.
*
* It also overloads the + and * operations, which can increase performance for
* operations involving tensors filled entirely with 0s or 1s.
*
* Note: this pass can only be used with functions whose input/output types are
* a combination of TupleTypes and TensorTypes.
*
* This pass optimizes 6 ops:
* - add
* - multiply
* - ones
* - ones_like
* - zeros
* - zeros_like
*
* This pass makes use of three visitors. The most important one mutates the entire
* function; the second wraps the inputs and the third unwraps the outputs.
*
* For example:
* fn: TensorType[(10,10), float32] -> TensorType[(10,10), float32]
*
* After this pass:
* fn: GradCell[TensorType[(10,10), float32]] -> GradCell[TensorType[(10,10), float32]]
*
* Thus, it is necessary to wrap the outer function so that the input/output types remain the same.
*/
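For example, a minimal sketch of the intended pipeline, mirroring the unit tests added with this pass: reverse-mode AD followed by LazyGradientInit, with the outer function type preserved (shapes are illustrative):

import tvm
from tvm import relay
from tvm.relay import transform
from tvm.relay.testing import run_infer_type

t = relay.TensorType((10, 10), "float32")
x = relay.var("x", t)
func = run_infer_type(relay.Function([x], x * x))

# Reverse-mode AD first, then lazily initialize the gradient tensors.
mod = tvm.IRModule()
mod["main"] = run_infer_type(transform.gradient(func))
mod = transform.LazyGradientInit()(mod)

# The wrapper keeps the outer input/output types unchanged.
assert mod["main"].checked_type == relay.FuncType(
    [t], relay.TupleType([t, relay.TupleType([t])]))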
#include <tvm/relay/analysis.h>
#include <tvm/relay/expr_functor.h>
#include <tvm/ir/type_functor.h>
#include <tvm/relay/transform.h>
#include "let_list.h"
namespace tvm {
namespace relay {
/*!
* \brief Visitor that wraps tensors with the Raw constructor as appropriate.
*
* Recursively looks at the type of the expression (only TensorType and TupleType
* are supported for now) and either calls the GradCell Raw constructor (for TensorType)
* or unfolds the tuple and recursively visits its fields (for TupleType).
*/
class InputVisitor: public ExprFunctor<Expr(const Expr&, const Type&)> {
public:
explicit InputVisitor(IRModule module): module_(module) {}
Expr VisitExpr_(const VarNode* op, const Type& t) final {
return WrapExpr(GetRef<Var>(op), op->type_annotation);
}
Expr VisitExpr_(const TupleGetItemNode* op, const Type& t) final {
return WrapExpr(GetRef<TupleGetItem>(op), t);
}
private:
IRModule module_;
Expr WrapExpr(const Expr expr, const Type& type) {
if (type.as<TensorTypeNode>()) {
return Call(module_->GetConstructor("GradCell", "Raw"),
{expr}, Attrs(), {type});
} else if (auto* type_anno = type.as<TupleTypeNode>()) {
tvm::Array<Expr> fields;
for (size_t i = 0; i < type_anno->fields.size(); i++) {
const Type& t = type_anno->fields[i];
fields.push_back(this->VisitExpr(TupleGetItem(expr, i), t));
}
Expr tuple = Tuple(fields);
return tuple;
}
return expr;
}
};
/*!
* \brief Visitor that unwraps expressions of GradCell type back into tensors.
*
* Recursively looks at the type of the expression and either calls the
* FromGradCell function (for a TypeCall to GradCell) or unfolds the tuple
* and recursively visits its fields (for TupleType).
*/
class OutputVisitor: public ExprFunctor<Expr(const Expr&, const Type&)> {
public:
explicit OutputVisitor(IRModule module): module_(module) {}
Expr VisitExpr_(const CallNode* op, const Type& t) final {
return UnwrapExpr(GetRef<Call>(op), t);
}
Expr VisitExpr_(const TupleGetItemNode* op, const Type& t) final {
return UnwrapExpr(GetRef<TupleGetItem>(op), t);
}
private:
IRModule module_;
Expr UnwrapExpr(const Expr expr, const Type& type) {
if (auto* type_call = type.as<TypeCallNode>()) {
if (type_call->func.same_as(module_->GetGlobalTypeVar("GradCell"))) {
return Call(module_->GetGlobalVar("FromGradCell"), {expr});
}
return expr;
} else if (auto* type_anno = type.as<TupleTypeNode>()) {
tvm::Array<Expr> fields;
for (size_t i = 0; i < type_anno->fields.size(); i++) {
const Type& t = type_anno->fields[i];
fields.push_back(this->VisitExpr(TupleGetItem(expr, i), t));
}
Expr tuple = Tuple(fields);
return tuple;
}
return expr;
}
};
class LazyGradientInitializer: public ExprMutator, public TypeMutator {
public:
explicit LazyGradientInitializer(IRModule module):
module_(module) {
module_->ImportFromStd("gradient.rly");
}
/*!
* \brief Apply the LazyGradientInit transformation and wrap the function
* so that the function type stays the same.
*
* Input/output types should only be a combination of TupleTypes and TensorTypes.
*/
Expr Transform(const Expr& e) {
auto* f = e.as<FunctionNode>();
// keep the mutated expression alive while we hold a raw pointer to its node
Expr mutated = this->Mutate(e);
auto* transformed = mutated.as<FunctionNode>();
if (e.same_as(mutated)) {
return mutated;
}
// wrap inputs of Tensor type using InputVisitor class
tvm::Array<Expr> args;
for (Var var : f->params) {
Expr wrappedInput = InputVisitor(module_).VisitExpr(var, var->checked_type());
args.push_back(wrappedInput);
}
Expr transformedExpr = Call(GetRef<Function>(transformed), args);
// unwrap outputs of GradCell type into Tensor type using OutputVisitor class
Expr tensorOutput = OutputVisitor(module_).VisitExpr(transformedExpr, transformed->ret_type);
return Function(f->params, tensorOutput, f->ret_type, Array<TypeVar>());
}
Expr VisitExpr_(const ConstantNode* op) final {
return Call(module_->GetConstructor("GradCell", "Raw"),
{GetRef<Constant>(op)}, Attrs(), {op->checked_type()});
}
Expr VisitExpr_(const CallNode* call_node) final {
if (auto* op = (call_node->op).as<OpNode>()) {
Expr op_expr = GetRef<Op>(op);
if (op_expr == Op::Get("add")) {
return CallGradCellFunction(call_node, module_->GetGlobalVar("AddGradCell"));
}
if (op_expr == Op::Get("multiply")) {
return CallGradCellFunction(call_node, module_->GetGlobalVar("MultiplyGradCell"));
}
if (op_expr == Op::Get("ones") || op_expr == Op::Get("zeros")) {
// fn() -> T, function returns result of the operation
Expr func = Function({}, {ExprMutator::VisitExpr_(call_node)},
{call_node->checked_type()}, {});
// call appropriate GradCell constructor
std::string constructor_name = op_expr == Op::Get("ones") ? "One" : "Zero";
return Call(module_->GetConstructor("GradCell", constructor_name),
{func}, Attrs(), {call_node->checked_type()});
}
if (op_expr == Op::Get("ones_like") || op_expr == Op::Get("zeros_like")) {
// ones_like and zeros_like need TensorType input
Expr result = CallPrimitiveOp(call_node);
// fn() -> T, function returns result of operation
Expr func = Function({}, result,
{call_node->checked_type()}, Array<TypeVar>());
// call appropriate GradCell constructor
std::string constructor_name = op_expr == Op::Get("ones_like") ? "One" : "Zero";
return Call(module_->GetConstructor("GradCell", "One"),
{func}, Attrs(), {call_node->checked_type()});
}
// handle all other ops
Expr result = CallPrimitiveOp(call_node);
// wrap result with Raw constructor
return Call(module_->GetConstructor("GradCell", "Raw"), {result},
Attrs(), {call_node->checked_type()});
}
// not an op
return ExprMutator::VisitExpr_(call_node);
}
Type VisitType(const Type& t) final {
return TypeMutator::VisitType(t);
}
Type VisitType_(const TensorTypeNode* op) {
GlobalTypeVar gradCell = module_->GetGlobalTypeVar("GradCell");
tvm::Array<Type> args;
args.push_back(GetRef<TensorType>(op));
return TypeCall(gradCell, args);
}
private:
// Module
IRModule module_;
/*!
* \brief Convert a call to the add/multiply op into a call to the corresponding overloaded GradCell function
*/
Expr CallGradCellFunction(const CallNode* call_node, GlobalVar overloaded_op) {
// can only use overloaded functions if 2 arguments of same type
if (call_node->args.size() != 2 ||
!AlphaEqual(call_node->args[0]->checked_type(), call_node->args[1]->checked_type())) {
Expr result = CallPrimitiveOp(call_node);
return Call(module_->GetConstructor("GradCell", "Raw"), {result},
Attrs(), {call_node->checked_type()});
}
tvm::Array<Expr> args;
// create "fallback" function for overloaded function
Type paramType = call_node->args[0]->checked_type();
tvm::Array<Var> params = {Var("lhs", paramType),
Var("rhs", paramType)};
// use primitive op in this case
Expr callOp = Call(call_node->op, {params[0], params[1]});
Expr func = Function(params, callOp, paramType,
Array<TypeVar>());
// pass "fallback" function and tensors as arguments
args.push_back(func);
for (Expr expr : call_node->args) {
args.push_back(VisitExpr(expr));
}
// return new call to overloaded function
return Call(overloaded_op, args, Attrs(), {paramType});
}
/*!
* \brief Convert calls to other ops by converting args into TensorType
* \return call expr returning result of op
*/
Expr CallPrimitiveOp(const CallNode* call_node) {
const auto fromFunc = module_->GetGlobalVar("FromGradCell");
tvm::Array<Expr> args;
// use FromGradCell to convert args to Tensor
for (Expr expr : call_node->args) {
args.push_back(Call(fromFunc,
{VisitExpr(expr)}, Attrs(), {expr->checked_type()}));
}
// result of operation
return Call(call_node->op, args);
}
};
Expr LazyGradientInit(const Expr& e, IRModule mod) {
return LazyGradientInitializer(mod).Transform(e);
}
namespace transform {
Pass LazyGradientInit() {
runtime::TypedPackedFunc<Function(Function, IRModule, PassContext)> pass_func =
[=](Function f, IRModule m, PassContext pc) {
return Downcast<Function>(LazyGradientInit(f, m));
};
return CreateFunctionPass(pass_func, 2, "LazyGradientInit", {});
}
TVM_REGISTER_GLOBAL("relay._transform.LazyGradientInit")
.set_body_typed(LazyGradientInit);
} // namespace transform
} // namespace relay
} // namespace tvm
...@@ -867,7 +867,9 @@ def test_extern_adt_defn():
""",
mod
)
def test_import_grad():
mod = tvm.IRModule()
mod.import_from_std("gradient.rly")
if __name__ == "__main__":
test_comments()
...@@ -903,3 +905,4 @@ if __name__ == "__main__":
test_duplicate_adt_cons_defn()
test_duplicate_global_var()
test_extern_adt_defn()
test_import_grad()
\ No newline at end of file
# Licensed to the Apache Software Foundation (ASF) under one
# or more contributor license agreements. See the NOTICE file
# distributed with this work for additional information
# regarding copyright ownership. The ASF licenses this file
# to you under the Apache License, Version 2.0 (the
# "License"); you may not use this file except in compliance
# with the License. You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing,
# software distributed under the License is distributed on an
# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
# KIND, either express or implied. See the License for the
# specific language governing permissions and limitations
# under the License.
import numpy as np
import tvm
from tvm import relay
from tvm.relay import create_executor, transform
from tvm.relay.testing import rand, run_infer_type
from tvm.testing import assert_allclose
import pytest
def test_tc():
"""Simple testcase, check that transformation typechecks."""
mod = tvm.IRModule()
shape = (20, 20)
dtype = 'float32'
t = relay.TensorType(shape, dtype)
x1 = relay.var("x1", t)
x2 = relay.var("x2", t)
# f(x1,x2) = (x1-x2)*x2
y = relay.Function([x1, x2], (x1 - x2) * x2)
mod["main"] = y
mod = transform.LazyGradientInit()(mod)
# function input/output types should remain the same
assert mod["main"].checked_type == relay.FuncType([t, t], t)
def test_add():
"""Simple add testcase. Check types and semantic equivalence."""
mod = tvm.IRModule()
shape = (10, 10)
dtype = 'float32'
t = relay.TensorType(shape, dtype)
x = relay.var("x", t)
# f(x) = x+x
y = relay.Function([x], x+x)
mod["main"] = y
mod = transform.LazyGradientInit()(mod)
y = mod["main"]
assert mod["main"].checked_type == relay.FuncType([t], t)
ex = create_executor(mod=mod)
x = rand(dtype, *shape)
y = ex.evaluate(y)(x)
assert_allclose(y.asnumpy(), x.asnumpy() + x.asnumpy())
def test_add_tuple():
"""Add elements of tuple. Check types and semantic equivalence."""
mod = tvm.IRModule()
shape = (10, 10)
dtype = 'float32'
tensor_type = relay.TensorType(shape, dtype)
t = relay.TupleType([tensor_type, tensor_type])
x = relay.var("x", t)
# f((x1,x2)) = x1 + x2
y = relay.Function([x], relay.TupleGetItem(x, 0) + relay.TupleGetItem(x, 1))
mod["main"] = y
mod = transform.LazyGradientInit()(mod)
y = mod["main"]
assert mod["main"].checked_type == relay.FuncType([t], tensor_type)
ex = create_executor(mod=mod)
x = (rand(dtype, *shape), rand(dtype, *shape))
y = ex.evaluate(y)(x)
assert_allclose(y.asnumpy(), x[0].asnumpy() + x[1].asnumpy())
def test_mult():
"""Simple multiplication testcase. Check types and semantic equivalence."""
mod = tvm.IRModule()
shape = (15, 15)
dtype = 'float32'
t = relay.TensorType(shape, dtype)
x = relay.var("x", t)
# f(x) = x*x
y = relay.Function([x], x * x)
mod["main"] = y
mod = transform.LazyGradientInit()(mod)
y = mod["main"]
assert mod["main"].checked_type == relay.FuncType([t], t)
ex = create_executor(mod=mod)
x = rand(dtype, *shape)
y = ex.evaluate(y)(x)
assert_allclose(y.asnumpy(), x.asnumpy() * x.asnumpy())
def test_ret_tuple():
"""Test tuple return type. Check types and semantic equivalence."""
mod = tvm.IRModule()
shape = (10, 10)
dtype = 'float32'
t = relay.TensorType(shape, dtype)
x = relay.var("x", t)
# f(x) = (x,x)
func = relay.Function([x], relay.Tuple([x,x * relay.const(2.0)]))
func = run_infer_type(func)
mod["main"] = func
mod = transform.LazyGradientInit()(mod)
func = mod["main"]
assert mod["main"].checked_type == relay.FuncType([t], relay.TupleType([t, t]))
ex = create_executor(mod=mod)
x = rand(dtype, *shape)
y = ex.evaluate(func)(x)
assert_allclose(y[0].asnumpy(), x.asnumpy())
assert_allclose(y[1].asnumpy(), x.asnumpy() * 2.0)
def test_add_broadcast():
"""Test adding matrices of different size. Check types and semantic equivalence."""
mod = tvm.IRModule()
shape1 = (3, 4, 1)
shape2 = (1, 5)
dtype = 'float32'
t1 = relay.TensorType(shape1, dtype)
t2 = relay.TensorType(shape2, dtype)
x1 = relay.var("x1", t1)
x2 = relay.var("x2", t2)
func = relay.Function([x1,x2], x1 + x2)
func = run_infer_type(func)
mod["main"] = func
mod = transform.LazyGradientInit()(mod)
func = mod["main"]
x1_np = rand(dtype, *shape1).asnumpy()
x2_np = rand(dtype, *shape2).asnumpy()
expected_forward = x1_np + x2_np
expected_forward_type = relay.TensorType(expected_forward.shape, dtype)
assert mod["main"].checked_type == relay.FuncType([t1, t2], expected_forward_type)
ex = create_executor(mod=mod)
forward = ex.evaluate(func)(x1_np, x2_np)
assert_allclose(forward.asnumpy(), expected_forward)
def test_reverse_ad_identity():
"""Simple test with reverse mode ad."""
# of f(x) = x
mod = tvm.IRModule()
shape = (10, 10)
dtype = 'float32'
t = relay.TensorType(shape, dtype)
x = relay.var("x", t)
func = relay.Function([x], x)
func = run_infer_type(func)
back_func = transform.gradient(func)
back_func = run_infer_type(back_func)
mod["main"] = back_func
mod = transform.LazyGradientInit()(mod)
back_func = mod["main"]
assert mod["main"].checked_type == relay.FuncType([t],
relay.TupleType([t, relay.TupleType([t])]))
ex = create_executor(mod=mod)
x = rand(dtype, *shape)
(forward), (grad,) = ex.evaluate(back_func)(x)
assert_allclose(forward.asnumpy(), x.asnumpy())
assert_allclose(grad.asnumpy(), np.ones_like(x.asnumpy()))
def test_multivar_reverse_ad():
"""Simple test with multivariate reverse mode ad."""
mod = tvm.IRModule()
shape = (10, 10)
dtype = 'float32'
t = relay.TensorType(shape, dtype)
x = relay.var("x", t)
y = relay.var("y", t)
func = relay.Function([x, y], (x * y) * relay.const(np.ones(shape, dtype)))
func = run_infer_type(func)
back_func = transform.gradient(func)
back_func = run_infer_type(back_func)
mod["main"] = back_func
mod = transform.LazyGradientInit()(mod)
back_func = mod["main"]
assert mod["main"].checked_type == relay.FuncType([t, t],
relay.TupleType([t, relay.TupleType([t, t])]))
ex = create_executor(mod=mod)
x = rand(dtype, *shape)
y = rand(dtype, *shape)
(forward), (grad_x, grad_y, ) = ex.evaluate(back_func)(x, y)
assert_allclose(forward.asnumpy(), x.asnumpy() * y.asnumpy())
assert_allclose(grad_x.asnumpy(), y.asnumpy())
assert_allclose(grad_y.asnumpy(), x.asnumpy())
def test_after_partial_eval():
"""Test transformation following reverse mode ad and PartialEval"""
mod = tvm.IRModule()
shape = (10, 10)
dtype = 'float32'
t = relay.TensorType(shape, dtype)
x = relay.var("x", t)
y = relay.var("y", t)
func = relay.Function([x, y], (x * y) * relay.const(np.ones(shape, dtype)))
func = run_infer_type(func)
back_func = transform.gradient(func)
back_func = run_infer_type(back_func)
mod["main"] = back_func
back_func = mod["main"]
seq = transform.Sequential([
transform.PartialEvaluate(),
transform.LazyGradientInit(),
transform.DeadCodeElimination()
])
mod = seq(mod)
assert mod["main"].checked_type == relay.FuncType([t, t],
relay.TupleType([t, relay.TupleType([t, t])]))
ex = create_executor(mod=mod)
x = rand(dtype, *shape)
y = rand(dtype, *shape)
(forward), (grad_x, grad_y,) = ex.evaluate(back_func)(x, y)
assert_allclose(forward.asnumpy(), x.asnumpy() * y.asnumpy())
assert_allclose(grad_x.asnumpy(), y.asnumpy())
assert_allclose(grad_y.asnumpy(), x.asnumpy())
def test_before_partial_eval():
"""Test transformation before PartialEval"""
mod = tvm.IRModule()
shape = (10, 10)
dtype = 'float32'
t = relay.TensorType(shape, dtype)
x = relay.var("x", t)
y = relay.var("y", t)
func = relay.Function([x, y], x * y)
func = run_infer_type(func)
back_func = transform.gradient(func)
back_func = run_infer_type(back_func)
mod["main"] = back_func
seq = transform.Sequential([
transform.LazyGradientInit(),
transform.PartialEvaluate(),
transform.DeadCodeElimination()
])
mod = seq(mod)
back_func = mod["main"]
assert mod["main"].checked_type == relay.FuncType([t, t],
relay.TupleType([t, relay.TupleType([t, t])]))
ex = create_executor(mod=mod)
x = rand(dtype, *shape)
y = rand(dtype, *shape)
(forward), (grad_x, grad_y,) = ex.evaluate(back_func)(x, y)
assert_allclose(forward.asnumpy(), x.asnumpy() * y.asnumpy())
assert_allclose(grad_x.asnumpy(), y.asnumpy())
assert_allclose(grad_y.asnumpy(), x.asnumpy())
def test_zeros():
"""Simple test using "zeros" op"""
mod = tvm.IRModule()
shape = (10, 10)
dtype = 'float32'
t = relay.TensorType(shape, dtype)
x = relay.var("x", t)
y = relay.Function([x], x + relay.zeros(shape, dtype))
mod["main"] = y
mod = transform.LazyGradientInit()(mod)
y = mod["main"]
assert mod["main"].checked_type == relay.FuncType([t], t)
ex = create_executor(mod=mod)
x = rand(dtype, *shape)
y = ex.evaluate(y)(x)
assert_allclose(y.asnumpy(), x.asnumpy())
def test_ones():
"""Simple test using "ones" op"""
mod = tvm.IRModule()
shape = (10, 10)
dtype = 'float32'
t = relay.TensorType(shape, dtype)
x = relay.var("x", t)
y = relay.Function([x], x + relay.ones(shape, dtype))
mod["main"] = y
mod = transform.LazyGradientInit()(mod)
y = mod["main"]
assert mod["main"].checked_type == relay.FuncType([t], t)
ex = create_executor(mod=mod)
x = rand(dtype, *shape)
y = ex.evaluate(y)(x)
assert_allclose(y.asnumpy(), x.asnumpy() + np.ones_like(x.asnumpy()))
def test_zeros_like():
"""Simple test using "zeros_like" op"""
mod = tvm.IRModule()
shape = (10, 10)
dtype = 'float32'
t = relay.TensorType(shape, dtype)
x = relay.var("x", t)
y = relay.Function([x], x + relay.zeros_like(x))
mod["main"] = y
mod = transform.LazyGradientInit()(mod)
y = mod["main"]
assert mod["main"].checked_type == relay.FuncType([t], t)
ex = create_executor(mod=mod)
x = rand(dtype, *shape)
y = ex.evaluate(y)(x)
assert_allclose(y.asnumpy(), x.asnumpy())
def test_ones_like():
"""Simple test using "ones_like" op"""
mod = tvm.IRModule()
shape = (10, 10)
dtype = 'float32'
t = relay.TensorType(shape, dtype)
x = relay.var("x", t)
y = relay.Function([x], x + relay.ones_like(x))
mod["main"] = y
mod = transform.LazyGradientInit()(mod)
y = mod["main"]
assert mod["main"].checked_type == relay.FuncType([t], t)
ex = create_executor(mod=mod)
x = rand(dtype, *shape)
y = ex.evaluate(y)(x)
assert_allclose(y.asnumpy(), x.asnumpy() + np.ones_like(x.asnumpy()))
if __name__ == "__main__":
pytest.main([__file__])