Unverified Commit e6dd8e1e by Andrew Liu Committed by GitHub

[Relay] GradientCell Relay Pass (#5039)

* save

* gradient.rly

* fix

* NOT WORKING: gradient cell pass

* test gradient pass

* fixed basic call ops

* more tests

* fix bug

* transform calls to one ones_like zero zero_like

* maintenance stuff

* fix linting

* linting

* linting

* throw default

* remove unrelated changes

* import gradient.rly in pass

* comment

* linting

* remove changes to test files

* move gradient_cell.cc to transforms

* revert change

* update files with new commits

* type

* wrapper function to main outermost function type

* fix linting

* fix unsigned and signed int comparison

* review

* GetConstructor definition in module and change op comparison

* update node instantiations

* increase code readability

Co-authored-by: Marisa Kirisame <lolisa@marisa.moe>
parent a6de507b
@@ -163,6 +163,14 @@ class IRModuleNode : public Object {
  TVM_DLL Array<GlobalTypeVar> GetGlobalTypeVars() const;
  /*!
   * \brief Find constructor of ADT using name
   * \param adt name of the ADT the constructor belongs to
   * \param cons name of the constructor
   * \returns Constructor of ADT, error if not found
   */
  TVM_DLL Constructor GetConstructor(const std::string& adt, const std::string& cons) const;
  /*!
   * \brief Look up a global function by its variable.
   * \param var The global var to lookup.
   * \returns The function named by the variable argument.
...
@@ -78,6 +78,20 @@ TVM_DLL Pass CreateFunctionPass(const runtime::TypedPackedFunc<
TVM_DLL Pass DeadCodeElimination(bool inline_once = false);
/*!
 * \brief Convert all expressions of TensorType into GradCell,
 * an algebraic data type defined in gradient.rly.
 *
 * This can delay or reduce memory usage. Calls to ones, ones_like, zeros, and
 * zeros_like no longer instantiate a tensor in memory immediately; the tensor is
 * only materialized when needed. The pass also defines + and * operations between
 * GradCell values, which can improve performance when operating on zero-filled or
 * one-filled tensors, as is common in reverse-mode AD.
 *
 * \return The pass.
 */
TVM_DLL Pass LazyGradientInit();
/*!
 * \brief Fold constant expressions.
 *
 * \return The pass.
...
/*
* Licensed to the Apache Software Foundation (ASF) under one
* or more contributor license agreements. See the NOTICE file
* distributed with this work for additional information
* regarding copyright ownership. The ASF licenses this file
* to you under the Apache License, Version 2.0 (the
* "License"); you may not use this file except in compliance
* with the License. You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing,
* software distributed under the License is distributed on an
* "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
* KIND, either express or implied. See the License for the
* specific language governing permissions and limitations
* under the License.
*/
v0.0.4
/*
 * Store the gradient value of a tensor of type T.
 * Note that the gradient of T is stored inside a Ref(GradCell[T]) rather than a bare GradCell[T].
 */
type GradCell[T] {
  Raw(T),
  One(fn() -> T),
  Zero(fn() -> T)
}

def @FromGradCell[T](%g: GradCell[T]) -> T {
  match (%g) {
    Raw(%x) => %x,
    One(%x) => %x(),
    Zero(%x) => %x()
  }
}

def @MultiplyGradCell[T](%multiply: fn(T, T) -> T, %l: GradCell[T], %r: GradCell[T]) -> GradCell[T] {
  match ((%l, %r)) {
    (Zero(_), _) => %l,
    (_, Zero(_)) => %r,
    (One(_), _) => %r,
    (_, One(_)) => %l,
    _ => Raw(%multiply(@FromGradCell(%l), @FromGradCell(%r)))
  }
}

def @AddGradCell[T](%add: fn(T, T) -> T, %l: GradCell[T], %r: GradCell[T]) -> GradCell[T] {
  match ((%l, %r)) {
    (Zero(_), _) => %r,
    (_, Zero(_)) => %l,
    _ => Raw(%add(@FromGradCell(%l), @FromGradCell(%r)))
  }
}
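
To make the semantics above concrete, here is a plain-Python analogue of GradCell and its helper functions. This is only an illustrative sketch using NumPy arrays in place of Relay tensors; none of these names are part of TVM.

import numpy as np

# Sketch of GradCell: a tensor is either held directly (Raw) or represented by
# a thunk that can build an all-ones / all-zeros tensor on demand (One / Zero).
class Raw:
    def __init__(self, value): self.value = value

class One:
    def __init__(self, thunk): self.thunk = thunk

class Zero:
    def __init__(self, thunk): self.thunk = thunk

def from_grad_cell(g):                        # mirrors @FromGradCell
    return g.value if isinstance(g, Raw) else g.thunk()

def add_grad_cell(add, l, r):                 # mirrors @AddGradCell
    if isinstance(l, Zero): return r          # 0 + r = r
    if isinstance(r, Zero): return l          # l + 0 = l
    return Raw(add(from_grad_cell(l), from_grad_cell(r)))

def multiply_grad_cell(mul, l, r):            # mirrors @MultiplyGradCell
    if isinstance(l, Zero): return l          # 0 * r = 0
    if isinstance(r, Zero): return r          # l * 0 = 0
    if isinstance(l, One): return r           # 1 * r = r
    if isinstance(r, One): return l           # l * 1 = l
    return Raw(mul(from_grad_cell(l), from_grad_cell(r)))

# The large zero gradient is never allocated: the Zero case short-circuits.
g = add_grad_cell(np.add, Zero(lambda: np.zeros((4096, 4096))), Raw(np.eye(3)))
assert isinstance(g, Raw)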
@@ -219,6 +219,19 @@ def DeadCodeElimination(inline_once=False):
    """
    return _ffi_api.DeadCodeElimination(inline_once)


def LazyGradientInit():
    """Reduces memory usage of gradient tensors.

    Returns
    -------
    ret : tvm.relay.Pass
        A pass which delays and/or reduces memory allocation
        by lazily allocating zero- or one-filled tensors.
    """
    return _ffi_api.LazyGradientInit()
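
For context, applying the pass from Python would look roughly like the sketch below. The example function, the prior InferType call, and the printout are assumptions for illustration; only the LazyGradientInit wrapper itself comes from this commit.

import tvm
from tvm import relay

# Hypothetical input: a small typed module over TensorType.
x = relay.var("x", shape=(3, 3), dtype="float32")
mod = tvm.IRModule.from_expr(relay.Function([x], x * x))
mod = relay.transform.InferType()(mod)

# Rewrite TensorType expressions into the GradCell ADT from gradient.rly,
# so one-/zero-filled tensors are only materialized when actually needed.
mod = relay.transform.LazyGradientInit()(mod)
print(mod)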

def FoldConstant():
    """Fold the constant expressions in a Relay program.
...
@@ -96,6 +96,18 @@ GlobalTypeVar IRModuleNode::GetGlobalTypeVar(const std::string& name) const {
  return (*it).second;
}

Constructor IRModuleNode::GetConstructor(const std::string& adt, const std::string& cons) const {
  TypeData typeDef = this->LookupTypeDef(adt);
  for (Constructor c : typeDef->constructors) {
    if (cons.compare(c->name_hint) == 0) {
      return c;
    }
  }
  LOG(FATAL) << adt << " does not contain constructor " << cons;
  // LOG(FATAL) does not return; the throw keeps compilers from warning about a missing return value.
  throw std::runtime_error("Constructor Not Found.");
}
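
For comparison, the same lookup can be expressed on the Python side with existing IRModule methods. The helper below is a hypothetical illustration rather than part of this commit:

import tvm

def find_constructor(mod, adt_name, cons_name):
    # Mirror of IRModuleNode::GetConstructor: fetch the ADT's TypeData and
    # scan its constructors by name_hint.
    type_data = mod[mod.get_global_type_var(adt_name)]
    for cons in type_data.constructors:
        if cons.name_hint == cons_name:
            return cons
    raise ValueError("%s does not contain constructor %s" % (adt_name, cons_name))

mod = tvm.IRModule()
mod.import_from_std("gradient.rly")
raw = find_constructor(mod, "GradCell", "Raw")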

tvm::Array<GlobalTypeVar> IRModuleNode::GetGlobalTypeVars() const {
  std::vector<GlobalTypeVar> global_type_vars;
  for (const auto& pair : global_type_var_map_) {
...
@@ -867,7 +867,9 @@ def test_extern_adt_defn():
        """,
        mod
    )


def test_import_grad():
    mod = tvm.IRModule()
    mod.import_from_std("gradient.rly")


if __name__ == "__main__":
    test_comments()
@@ -903,3 +905,4 @@ if __name__ == "__main__":
    test_duplicate_adt_cons_defn()
    test_duplicate_global_var()
    test_extern_adt_defn()
    test_import_grad()
\ No newline at end of file