Commit 02ddb5a9 by 雾雨魔理沙, committed by Jared Roesch

save (#3901)

parent 19f8c123
@@ -290,6 +290,12 @@ def dense_grad(orig, grad):
             collapse_sum_like(data * transpose(grad), weight)]
 
 
+@register_gradient("reshape")
+def reshape_grad(orig, grad):
+    """Gradient of reshape"""
+    return [reshape_like(grad, orig.args[0])]
+
+
 @register_gradient("nn.batch_flatten")
 def batch_flatten_grad(orig, grad):
     """Returns grad reshaped to data dims"""
......
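Editorial note on the hunk above (not part of the patch): reshape only rearranges elements without combining or scaling them, so the gradient with respect to the input is just the incoming gradient laid back out in the input's shape, which is what reshape_like(grad, orig.args[0]) expresses. A minimal NumPy sketch of that fact using a finite-difference probe; every name here is illustrative, none of it is TVM API:

import numpy as np

x = np.random.rand(2, 3, 4)
dy = np.random.rand(6, 4)        # upstream gradient dL/dy for y = reshape(x, (6, 4))
dx = dy.reshape(x.shape)         # the claimed gradient: dL/dx = reshape_like(dy, x)

# Finite-difference probe of one input element under the loss L = sum(reshape(x) * dy),
# for which dL/dy = dy and hence the analytic dL/dx is the dx computed above.
i, eps = (1, 2, 3), 1e-6
x_pert = x.copy()
x_pert[i] += eps
l0 = np.sum(x.reshape(6, 4) * dy)
l1 = np.sum(x_pert.reshape(6, 4) * dy)
assert abs((l1 - l0) / eps - dx[i]) < 1e-4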
@@ -15,6 +15,7 @@
 # specific language governing permissions and limitations
 # under the License.
 import numpy as np
+import pytest
 import tvm
 from tvm import relay
@@ -58,6 +59,4 @@ def test_negative_grad():
 if __name__ == "__main__":
-    test_clip()
-    test_transpose_grad()
-    test_negative_grad()
+    pytest.main()
@@ -21,7 +21,7 @@ from nose.tools import raises
 import tvm
 from tvm import relay
 from tvm.relay import create_executor, transform
-from tvm.relay.testing import ctx_list
+from tvm.relay.testing import ctx_list, check_grad
 
 def run_infer_type(expr):
     mod = relay.Module.from_expr(expr)
@@ -247,6 +247,7 @@ def test_reshape():
     assert zz.checked_type == relay.ty.TensorType(oshape, "float32")
     func = relay.Function([x], z)
+    check_grad(func)
     x_data = np.random.uniform(low=-1, high=1, size=shape).astype("float32")
     ref_res = np.reshape(x_data, oshape)
     for target, ctx in ctx_list():
......
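Editorial note (not part of the patch): check_grad, imported from tvm.relay.testing in the hunk above and called on the reshape function, numerically validates the registered gradient. The sketch below shows the general idea behind such a check with a central-difference estimate; numerical_grad is a hypothetical helper written here for illustration, not a TVM API, and the example function stands in for the compiled Relay function.

import numpy as np

def numerical_grad(f, x, eps=1e-4):
    """Central-difference estimate of df/dx for a scalar-valued f and array input x."""
    grad = np.zeros_like(x)
    it = np.nditer(x, flags=["multi_index"])
    while not it.finished:
        idx = it.multi_index
        orig = x[idx]
        x[idx] = orig + eps
        f_plus = f(x)
        x[idx] = orig - eps
        f_minus = f(x)
        x[idx] = orig                      # restore the perturbed element
        grad[idx] = (f_plus - f_minus) / (2 * eps)
        it.iternext()
    return grad

# Example: f(x) = sum(reshape(x, (6, 4)) ** 2), whose analytic gradient is 2 * x.
x = np.random.rand(2, 3, 4)
f = lambda a: np.sum(a.reshape(6, 4) ** 2)
np.testing.assert_allclose(numerical_grad(f, x), 2 * x, rtol=1e-3, atol=1e-3)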