Commit eca4f88a by reminisce Committed by Tianqi Chen

Fix broadcast add and subtract grad (#2465)

parent 5194da65
...@@ -83,7 +83,7 @@ else(MSVC) ...@@ -83,7 +83,7 @@ else(MSVC)
include(CheckCXXCompilerFlag) include(CheckCXXCompilerFlag)
check_cxx_compiler_flag("-std=c++11" SUPPORT_CXX11) check_cxx_compiler_flag("-std=c++11" SUPPORT_CXX11)
if ("${CMAKE_BUILD_TYPE}" STREQUAL "Debug") if ("${CMAKE_BUILD_TYPE}" STREQUAL "Debug")
add_compile_options(-Wall -fPIC -std=c++11) add_compile_options(-O0 -Wall -fPIC -std=c++11)
else() else()
set(CMAKE_C_FLAGS "-O2 -Wall -fPIC ${CMAKE_C_FLAGS}") set(CMAKE_C_FLAGS "-O2 -Wall -fPIC ${CMAKE_C_FLAGS}")
set(CMAKE_CXX_FLAGS "-O2 -Wall -fPIC -std=c++11 ${CMAKE_CXX_FLAGS}") set(CMAKE_CXX_FLAGS "-O2 -Wall -fPIC -std=c++11 ${CMAKE_CXX_FLAGS}")
......
...@@ -5,17 +5,21 @@ import topi ...@@ -5,17 +5,21 @@ import topi
from .op import register_compute, register_schedule, register_pattern from .op import register_compute, register_schedule, register_pattern
from .op import register_gradient from .op import register_gradient
from .op import schedule_injective, OpPattern from .op import schedule_injective, OpPattern
from .transform import collapse_sum_like
from .tensor import negative
def add_grad(orig, grad): def add_grad(orig, grad):
from tvm.relay import op return [collapse_sum_like(grad, orig.args[0]), collapse_sum_like(grad, orig.args[1])]
return [op.broadcast_to_like(grad, orig.args[0]), op.broadcast_to_like(grad, orig.args[1])]
register_gradient("add", add_grad) register_gradient("add", add_grad)
def subtract_grad(orig, grad): def subtract_grad(orig, grad):
from tvm.relay import op return [collapse_sum_like(grad, orig.args[0]),
return [op.broadcast_to_like(grad, orig.args[0]), collapse_sum_like(negative(grad), orig.args[1])]
op.broadcast_to_like(op.negative(grad), orig.args[1])]
register_gradient("subtract", subtract_grad) register_gradient("subtract", subtract_grad)
......
...@@ -69,8 +69,64 @@ def test_sub(): ...@@ -69,8 +69,64 @@ def test_sub():
np.testing.assert_allclose(grad.asnumpy(), np.zeros_like(x.asnumpy())) np.testing.assert_allclose(grad.asnumpy(), np.zeros_like(x.asnumpy()))
def test_broadcast_add():
shape1 = (3, 4, 1)
shape2 = (1, 5)
dtype = 'float32'
x_nd = rand(dtype, *shape1)
y_nd = rand(dtype, *shape2)
x_np = x_nd.asnumpy()
y_np = y_nd.asnumpy()
expected_forward = x_np + y_np
t1 = relay.TensorType(shape1, dtype)
t2 = relay.TensorType(shape2, dtype)
x = relay.var("x", t1)
y = relay.var("y", t2)
func = relay.Function([x, y], x + y)
full_func = relay.ir_pass.infer_type(gradient(func))
assert full_func.checked_type == relay.FuncType([t1, t2],
relay.TupleType([relay.TensorType(expected_forward.shape, dtype),
relay.TupleType([t1, t2])]))
ex = create_executor()
forward, (grad_x, grad_y) = ex.evaluate(full_func)(x_nd, y_nd)
np.testing.assert_allclose(forward.asnumpy(), expected_forward)
np.testing.assert_allclose(grad_x.asnumpy(),
np.ones_like(expected_forward).sum(axis=2, keepdims=True))
np.testing.assert_allclose(grad_y.asnumpy(),
np.ones_like(expected_forward).sum(axis=(0, 1), keepdims=True).squeeze(axis=0))
def test_broadcast_subtract():
shape1 = (3, 4, 1)
shape2 = (1, 5)
dtype = 'float32'
x_nd = rand(dtype, *shape1)
y_nd = rand(dtype, *shape2)
x_np = x_nd.asnumpy()
y_np = y_nd.asnumpy()
expected_forward = x_np - y_np
t1 = relay.TensorType(shape1, dtype)
t2 = relay.TensorType(shape2, dtype)
x = relay.var("x", t1)
y = relay.var("y", t2)
func = relay.Function([x, y], x - y)
full_func = relay.ir_pass.infer_type(gradient(func))
assert full_func.checked_type == relay.FuncType([t1, t2],
relay.TupleType([relay.TensorType(expected_forward.shape, dtype),
relay.TupleType([t1, t2])]))
ex = create_executor()
forward, (grad_x, grad_y) = ex.evaluate(full_func)(x_nd, y_nd)
np.testing.assert_allclose(forward.asnumpy(), expected_forward)
np.testing.assert_allclose(grad_x.asnumpy(),
np.ones_like(expected_forward).sum(axis=2, keepdims=True))
np.testing.assert_allclose(grad_y.asnumpy(),
-np.ones_like(expected_forward).sum(axis=(0, 1), keepdims=True).squeeze(axis=0))
if __name__ == "__main__": if __name__ == "__main__":
test_id() test_id()
test_add() test_add()
test_temp_add() test_temp_add()
test_sub() test_sub()
test_broadcast_add()
test_broadcast_subtract()
Markdown is supported
0% or
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment