Commit 02ddb5a9 by 雾雨魔理沙 Committed by Jared Roesch

save (#3901)

parent 19f8c123
...@@ -290,6 +290,12 @@ def dense_grad(orig, grad): ...@@ -290,6 +290,12 @@ def dense_grad(orig, grad):
collapse_sum_like(data * transpose(grad), weight)] collapse_sum_like(data * transpose(grad), weight)]
@register_gradient("reshape")
def reshape_grad(orig, grad):
"""Gradient of reshape"""
return [reshape_like(grad, orig.args[0])]
@register_gradient("nn.batch_flatten") @register_gradient("nn.batch_flatten")
def batch_flatten_grad(orig, grad): def batch_flatten_grad(orig, grad):
"""Returns grad reshaped to data dims""" """Returns grad reshaped to data dims"""
......
...@@ -15,6 +15,7 @@ ...@@ -15,6 +15,7 @@
# specific language governing permissions and limitations # specific language governing permissions and limitations
# under the License. # under the License.
import numpy as np import numpy as np
import pytest
import tvm import tvm
from tvm import relay from tvm import relay
...@@ -58,6 +59,4 @@ def test_negative_grad(): ...@@ -58,6 +59,4 @@ def test_negative_grad():
if __name__ == "__main__": if __name__ == "__main__":
test_clip() pytest.main()
test_transpose_grad()
test_negative_grad()
...@@ -21,7 +21,7 @@ from nose.tools import raises ...@@ -21,7 +21,7 @@ from nose.tools import raises
import tvm import tvm
from tvm import relay from tvm import relay
from tvm.relay import create_executor, transform from tvm.relay import create_executor, transform
from tvm.relay.testing import ctx_list from tvm.relay.testing import ctx_list, check_grad
def run_infer_type(expr): def run_infer_type(expr):
mod = relay.Module.from_expr(expr) mod = relay.Module.from_expr(expr)
...@@ -247,6 +247,7 @@ def test_reshape(): ...@@ -247,6 +247,7 @@ def test_reshape():
assert zz.checked_type == relay.ty.TensorType(oshape, "float32") assert zz.checked_type == relay.ty.TensorType(oshape, "float32")
func = relay.Function([x], z) func = relay.Function([x], z)
check_grad(func)
x_data = np.random.uniform(low=-1, high=1, size=shape).astype("float32") x_data = np.random.uniform(low=-1, high=1, size=shape).astype("float32")
ref_res = np.reshape(x_data, oshape) ref_res = np.reshape(x_data, oshape)
for target, ctx in ctx_list(): for target, ctx in ctx_list():
......
Markdown is supported
0% or
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment