Unverified Commit e4a5441d by Yizhi Liu Committed by GitHub

[TE] reverse-mode autodiff without any optimization (#5121)

* [TE] reverse-mode autodiff without any optimization

Co-authored-by: Sergei Grechanik <sergei.grechanik+h@gmail.com>

* address review comments

* add comments and retrigger CI

* move unittest to debug ci

* move test back and add seed

Co-authored-by: Sergei Grechanik <sergei.grechanik+h@gmail.com>
parent ff7bab80
/*
* Licensed to the Apache Software Foundation (ASF) under one
* or more contributor license agreements. See the NOTICE file
* distributed with this work for additional information
* regarding copyright ownership. The ASF licenses this file
* to you under the Apache License, Version 2.0 (the
* "License"); you may not use this file except in compliance
* with the License. You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing,
* software distributed under the License is distributed on an
* "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
* KIND, either express or implied. See the License for the
* specific language governing permissions and limitations
* under the License.
*/
/*!
* \file tvm/te/autodiff.h
* \brief Automatic differentiation of tensor expressions.
*/
#ifndef TVM_TE_AUTODIFF_H_
#define TVM_TE_AUTODIFF_H_
#include <tvm/runtime/object.h>
#include <tvm/tir/expr.h>
#include "tensor.h"
namespace tvm {
/*! \brief Tensor expression language DSL. */
namespace te {
/*!
* \brief Take the derivative of the expression with respect to the given variable.
* \param expr The expression to differentiate.
* \param var The variable to differentiate with respect to.
* \return The expression for the derivative.
*/
PrimExpr Derivative(const PrimExpr& expr, const Var& var);
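// For example, Derivative(x * y + exp(x), x) yields an expression equivalent to
// y + exp(x); since this pass performs no simplification, the returned form may be
// an unsimplified application of the sum/product/chain rules.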
/*!
* \brief Get the tensor representing the Jacobian of the output with respect to the input.
*
* Note that if \p output depends on \p input indirectly (by using some other tensor
* depending on \p input), this dependency won't contribute to the resulting Jacobian.
* For such cases use the function ::Gradient.
*
* \param output The tensor to differentiate.
* \param input The input tensor, which \p output should directly use.
* \return The tensor representing the Jacobian of shape `output.shape + input.shape`.
*/
Tensor Jacobian(const Tensor& output, const Tensor& input);
/*!
* \brief The building block for reverse-mode AD.
*
* Differentiate \p output wrt \p input and multiply the result by \p head on the left using a
* tensor dot product. \p input must be an immediate dependency of \p output (it must be used
* directly in the body of \p output). That is, the function computes one summand of the adjoint
* for \p input given the adjoint for \p output (which is called \p head here).
*
* \param output The tensor to differentiate.
* \param input The input tensor, which \p output should directly use.
* \param head The adjoint of \p output. Must be of shape `prefix + output.shape`.
* \return The tensor of shape `prefix + input.shape`
* representing the partial adjoint of \p input wrt one of its consumers (\p output).
*/
Tensor VectorJacobianProduct(const Tensor &output, const Tensor &input, const Tensor &head);
/*!
* \brief Perform reverse mode automatic differentiation.
*
* Each item of the result is an adjoint for the corresponding item of \p inputs,
* i.e. \p head multiplied by the Jacobian of \p output with respect to that input.
*
* \param output The tensor to differentiate.
* \param inputs The array of input tensors. When the array is empty, differentiation
* will be performed wrt all tensors the output depends on.
* \param head The adjoint of the output, in other words, some tensor by which the Jacobians
* will be multiplied (using tensordot over the last `len(output.shape)` axes).
* Its shape must be of the form `prefix + output.shape`. If a null pointer is provided,
* the identity tensor of shape `output.shape + output.shape` will be used.
* \return An array of adjoints corresponding to \p inputs.
*/
TVM_DLL Array<Tensor> Gradient(
const Tensor& output,
const Array<Tensor>& inputs,
const Tensor& head = Tensor());
} // namespace te
} // namespace tvm
#endif // TVM_TE_AUTODIFF_H_
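To make the shape contract above concrete, here is a minimal Python sketch using te.gradient, the binding added by this change. The tensor names and shapes are invented for illustration; the asserted shapes follow from the documentation above (a null head defaults to the identity of shape `output.shape + output.shape`, and a head of shape `prefix + output.shape` yields adjoints of shape `prefix + input.shape`).

# Minimal sketch (not part of this change); names and shapes are illustrative.
import tvm
from tvm import te
from topi.util import get_const_tuple

k = te.reduce_axis((0, 3), name="k")
x = te.placeholder((5, 3), name="x")
y = te.compute((5,), lambda i: te.sum(x[i, k], axis=k), name="y")

# No head: the identity of shape y.shape + y.shape is used, so the result is the
# full Jacobian of shape y.shape + x.shape.
[jac] = te.gradient(y, [x])
assert get_const_tuple(jac.shape) == (5, 5, 3)

# A head of shape prefix + y.shape yields an adjoint of shape prefix + x.shape.
head = te.placeholder((7, 5), name="head")
[adj] = te.gradient(y, [x], head)
assert get_const_tuple(adj.shape) == (7, 5, 3)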
@@ -33,3 +33,4 @@ from .operation import placeholder, compute, scan, extern, var, size_var
from .operation import thread_axis, reduce_axis
from .tensor import PlaceholderOp, ComputeOp, TensorComputeOp, ScanOp, ExternOp, HybridOp
from .autodiff import gradient
# Licensed to the Apache Software Foundation (ASF) under one
# or more contributor license agreements. See the NOTICE file
# distributed with this work for additional information
# regarding copyright ownership. The ASF licenses this file
# to you under the Apache License, Version 2.0 (the
# "License"); you may not use this file except in compliance
# with the License. You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing,
# software distributed under the License is distributed on an
# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
# KIND, either express or implied. See the License for the
# specific language governing permissions and limitations
# under the License.
"""Automatic differentiation of tensor expressions."""
from . import _ffi_api
def gradient(output, inputs, head=None):
"""Perform reverse-mode automatic differentiation.
Parameters
----------
output : Tensor
The tensor to differentiate.
inputs : List[Tensor]
The list of input tensors to be differentiated wrt.
head : Tensor
The adjoint of the output, in other words, some tensor, by which the Jacobians
will be multiplied. Its shape must be of the form `prefix + output.shape`.
If `None` is passed, the identity tensor of shape `output.shape + output.shape`
will be used.
Returns
-------
tensors: List[Tensor]
The resulting gradients, in the same order as the inputs.
Example
-------
.. code-block:: python
x = te.placeholder((32, 3, 28, 28), name='x')
w1 = te.placeholder((10, 3, 3, 3), name='w1')
w2 = te.placeholder((10, 10, 3, 3), name='w2')
z1 = topi.nn.conv2d(x, w1, 1, 1, 1)
z2 = topi.nn.conv2d(z1, w2, 1, 1, 1)
y = topi.sum(z2)
# produce gradients
[dw1, dw2] = te.gradient(y, [w1, w2])
# produce Jacobians
[jw1, jw2] = te.gradient(z2, [w1, w2])
# produce gradients, the head adjoint for z2 is provided manually
[dw1, dw2] = te.gradient(z2, [w1, w2], topi.full_like(z2, 1.0))
"""
if not isinstance(inputs, list):
inputs = [inputs]
return _ffi_api.Gradient(output, inputs, head)
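For context, the gradients returned above are ordinary tensor expressions and can be scheduled, built and run like any other te computation. The following hedged sketch (with made-up names, mirroring check_grad in the test file below) checks that a head of ones reproduces the gradient of the summed output.

# Minimal end-to-end sketch (illustrative; not part of this change).
import numpy as np
import tvm
from tvm import te
import topi

x = te.placeholder((4, 4), name="x")
y = te.compute((4, 4), lambda i, j: x[i, j] * x[i, j], name="y")
# a head of ones sums over y, so dx is d(sum(y))/dx = 2 * x
[dx] = te.gradient(y, [x], head=topi.full_like(y, 1.0))

s = te.create_schedule(dx.op)
f = tvm.build(s, [x, dx], target="llvm")
x_np = np.random.uniform(size=(4, 4)).astype("float32")
dx_nd = tvm.nd.empty((4, 4), "float32")
f(tvm.nd.array(x_np), dx_nd)
np.testing.assert_allclose(dx_nd.asnumpy(), 2 * x_np, rtol=1e-5)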
/*
* Licensed to the Apache Software Foundation (ASF) under one
* or more contributor license agreements. See the NOTICE file
* distributed with this work for additional information
* regarding copyright ownership. The ASF licenses this file
* to you under the Apache License, Version 2.0 (the
* "License"); you may not use this file except in compliance
* with the License. You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing,
* software distributed under the License is distributed on an
* "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
* KIND, either express or implied. See the License for the
* specific language governing permissions and limitations
* under the License.
*/
/*!
* \file ad_util.cc
* \brief Utility for tensor-level auto-differentiation.
*/
#include <tvm/tir/expr.h>
#include <tvm/tir/ir_pass.h>
#include <string>
#include "ad_util.h"
namespace tvm {
namespace te {
std::pair<Array<IterVar>, Map<Var, PrimExpr>> CloneIterVars(const Array<IterVar>& vars) {
Array<IterVar> new_vars;
Map<Var, PrimExpr> vmap;
for (const IterVar& iv : vars) {
IterVar new_v =
IterVarNode::make(iv->dom, iv->var.copy_with_suffix(""),
iv->iter_type, iv->thread_tag);
new_vars.push_back(new_v);
vmap.Set(iv->var, new_v->var);
}
return std::make_pair(std::move(new_vars), std::move(vmap));
}
PrimExpr CloneReduction(const PrimExpr& expr) {
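// Give the cloned Reduce fresh axis IterVars and remap them in the sources and the
// condition, so the clone does not share reduction IterVars with the original expression.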
if (const ReduceNode* red = expr.as<ReduceNode>()) {
Array<IterVar> new_axis;
Map<Var, PrimExpr> vmap;
std::tie(new_axis, vmap) = CloneIterVars(red->axis);
Array<PrimExpr> src_with_newaxis;
for (const auto& src : red->source) {
src_with_newaxis.push_back(tir::Substitute(src, vmap));
}
return ReduceNode::make(red->combiner, src_with_newaxis,
new_axis, tir::Substitute(red->condition, vmap), red->value_index);
} else {
return expr;
}
}
} // namespace te
} // namespace tvm
/*
* Licensed to the Apache Software Foundation (ASF) under one
* or more contributor license agreements. See the NOTICE file
* distributed with this work for additional information
* regarding copyright ownership. The ASF licenses this file
* to you under the Apache License, Version 2.0 (the
* "License"); you may not use this file except in compliance
* with the License. You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing,
* software distributed under the License is distributed on an
* "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
* KIND, either express or implied. See the License for the
* specific language governing permissions and limitations
* under the License.
*/
/*!
* \file ad_util.h
* \brief Helper utilities to implement auto-differentiation.
*/
#ifndef TVM_TE_AUTODIFF_AD_UTIL_H_
#define TVM_TE_AUTODIFF_AD_UTIL_H_
#include <tvm/tir/expr.h>
#include <tvm/te/operation.h>
#include <vector>
#include <unordered_map>
#include <utility>
namespace tvm {
namespace te {
/*!
* \brief Clone iter vars and return both the new vars and the substitution from old to new.
*
* \param vars The original iter vars.
* \return A pair containing the array of new iter vars and the map from old vars to new ones.
*/
std::pair<Array<IterVar>, Map<Var, PrimExpr>> CloneIterVars(const Array<IterVar>& vars);
/*!
* \brief Clone reduction by cloning the axis variables.
* \param expr A reduction expr to clone. Non-reduction expressions are left intact.
*/
PrimExpr CloneReduction(const PrimExpr& expr);
} // namespace te
} // namespace tvm
#endif // TVM_TE_AUTODIFF_AD_UTIL_H_
/*
* Licensed to the Apache Software Foundation (ASF) under one
* or more contributor license agreements. See the NOTICE file
* distributed with this work for additional information
* regarding copyright ownership. The ASF licenses this file
* to you under the Apache License, Version 2.0 (the
* "License"); you may not use this file except in compliance
* with the License. You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing,
* software distributed under the License is distributed on an
* "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
* KIND, either express or implied. See the License for the
* specific language governing permissions and limitations
* under the License.
*/
/*!
* \file adjoint.cc
* \brief Perform reverse-mode autodiff.
* Suppose we have f(x) = g(h1(x), h2(x), ..., hn(x)); then by the chain rule
*   df/dx = \sum_i df/dhi * dhi/dx.
* We call df/dx the adjoint of x, df/dhi the adjoint of hi, and dhi/dx the Jacobian.
* The idea is to first construct the reverse dependency map {input -> outputs} between tensors,
* then, starting from one input,
* (1) collect the adjoints of all tensors that immediately use it (its outputs),
* (2) multiply each adjoint by the corresponding Jacobian (VectorJacobianProduct),
* (3) and sum the products to get the adjoint of the input itself.
* The three steps are computed recursively.
*/
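// A concrete instance of the scheme above, for a chain x -> t -> y with y the
// differentiated output (names are only illustrative):
//   adjoint(y) = head                               (by definition)
//   adjoint(t) = adjoint(y) "tensordot" Jacobian(y, t)
//   adjoint(x) = adjoint(t) "tensordot" Jacobian(t, x)
// Each multiplication is performed by VectorJacobianProduct; when a tensor has
// several consumers, the per-consumer products are summed.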
#include <tvm/runtime/registry.h>
#include <tvm/te/autodiff.h>
#include <tvm/tir/stmt_functor.h>
#include <topi/transform.h>
#include <topi/elemwise.h>
#include <memory>
#include <vector>
namespace tvm {
namespace te {
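// Construct the default head: a diagonal "identity" tensor of shape
// output.shape + output.shape whose element is 1 (cast to output->dtype) when the
// first half of the indices equals the second half, and 0 otherwise.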
Tensor Identity(const Tensor& output) {
Array<PrimExpr> shape = output->shape;
for (auto e : output->shape) {
// add extra dimension for Jacobian
shape.push_back(e);
}
auto func =
[&output](const Array<Var>& input_indices) {
PrimExpr res = const_true();
for (size_t i = 0; i < output->shape.size(); ++i) {
res = res && (PrimExpr(input_indices[i]) ==
PrimExpr(input_indices[output->shape.size() + i]));
}
return CastNode::make(output->dtype, res);
};
return te::compute(shape, func, "identity");
}
Tensor VectorJacobianProduct(const Tensor &output, const Tensor &input, const Tensor &head) {
Tensor jac = Jacobian(output, input);
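// jac has shape output.shape + input.shape. Contracting the trailing
// len(output.shape) axes of head (shape prefix + output.shape) with the leading
// axes of jac yields the partial adjoint of shape prefix + input.shape.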
Tensor result = topi::tensordot(head, jac, /*axes=*/output->shape.size(),
output->op->name + "." + input->op->name + ".grad");
return result;
}
Array<Tensor> Gradient(const Tensor& output,
const Array<Tensor>& inputs,
const Tensor& head_or_null) {
// Diagonal identity tensor
Tensor head = head_or_null.get() ? head_or_null : Identity(output);
// This Map{input -> outputs} maps a tensor to the list of tensors
// immediately depending on it (using it in their bodies)
std::unordered_map<Tensor, std::vector<Tensor>> reverse_dependencies;
std::vector<Tensor> stack({output});
while (!stack.empty()) {
Tensor tensor = stack.back();
stack.pop_back();
for (const Tensor& input : tensor->op->InputTensors()) {
if (!reverse_dependencies.count(input)) {
stack.push_back(input);
}
reverse_dependencies[input].push_back(tensor);
}
}
// This map maps tensors to the corresponding adjoints (dLoss/dTensor)
std::unordered_map<Tensor, Tensor> adjoints;
// head is the adjoint of output by definition
adjoints[output] = head;
// This is a recursive function that does all the work. It computes the adjoint for a given
// tensor, adds it to the map, and returns it
std::function<Tensor(const Tensor&)> compute_adjoint;
compute_adjoint =
[&compute_adjoint, &adjoints, &reverse_dependencies, &head, &output]
(const Tensor& tensor) {
if (!adjoints.count(tensor)) {
// Here the adjoint hasn't been computed yet
Tensor res_adjoint;
std::vector<Tensor> direct_consumers = reverse_dependencies[tensor];
if (direct_consumers.empty()) {
// No reverse dependencies means that the output does not depend on this tensor,
// so return a zero tensor of the appropriate shape
// (i.e., output shape + tensor shape, aka the shape of the Jacobian)
Array<PrimExpr> result_shape(head->shape.begin(),
head->shape.end() + (-output->shape.size()));
for (auto e : tensor->shape) {
result_shape.push_back(e);
}
res_adjoint = topi::full(result_shape, output->dtype, make_zero(output->dtype));
} else {
// The new adjoint is computed as a sum of the reverse dependencies' adjoints multiplied
// by the corresponding "local" Jacobians (dDep/dTensor). The computation of the Jacobian
// and the multiplication are done in the function VectorJacobianProduct.
for (const Tensor& direct_consumer : direct_consumers) {
// part = (adjoint of direct_consumer) * Jacobian(direct_consumer, tensor)
Tensor part = VectorJacobianProduct(
direct_consumer, tensor, compute_adjoint(direct_consumer));
res_adjoint = res_adjoint.get() ? topi::add(res_adjoint, part) : part;
}
}
adjoints[tensor] = res_adjoint;
return res_adjoint;
} else {
return adjoints[tensor];
}
};
// Adjoints corresponding to inputs
Array<Tensor> result;
// Compute an adjoint for each input
for (const Tensor& input : inputs) {
result.push_back(compute_adjoint(input));
}
return result;
}
TVM_REGISTER_GLOBAL("te.Gradient")
.set_body([](TVMArgs args, TVMRetValue *ret) {
LOG(WARNING) << "te.Gradient is an experimental feature.";
if (args.size() == 2) {
*ret = Gradient(args[0], args[1]);
} else if (args.size() == 3) {
*ret = Gradient(args[0], args[1], args[2]);
}
});
} // namespace te
} // namespace tvm
@@ -25,7 +25,6 @@
#define TVM_TE_OPERATION_COMPUTE_OP_H_
#include <tvm/tir/expr.h>
#include <tvm/te/operation.h>
#include <vector>
#include <unordered_map>
# Licensed to the Apache Software Foundation (ASF) under one
# or more contributor license agreements. See the NOTICE file
# distributed with this work for additional information
# regarding copyright ownership. The ASF licenses this file
# to you under the Apache License, Version 2.0 (the
# "License"); you may not use this file except in compliance
# with the License. You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing,
# software distributed under the License is distributed on an
# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
# KIND, either express or implied. See the License for the
# specific language governing permissions and limitations
# under the License.
import tvm
from tvm import te
from tvm.testing import check_numerical_grads, assert_allclose
import topi
from topi.util import get_const_tuple
import numpy as np
def check_grad(out, inputs, data_range=(-10, 10), desired_grads=None):
inputs = inputs if isinstance(inputs, list) else [inputs]
def check_device(device, host="llvm"):
ctx = tvm.context(device, 0)
if not tvm.runtime.enabled(host):
return
if not ctx.exist:
print("skip because %s is not enabled.." % device)
return
sout = te.create_schedule(out.op)
mout = tvm.build(sout, [out] + inputs)
out_shape = get_const_tuple(out.shape)
l, h = data_range
input_data = [tvm.nd.array(
np.random.uniform(l, h, size=get_const_tuple(input.shape)).astype(input.dtype))
for input in inputs]
ones = topi.full_like(out, 1.0)
# we provide a head of ones to sum over the output dimensions,
# which is equivalent to grad(out.sum(), inputs)
grads = te.gradient(out, inputs, head=ones)
grad_sched = te.create_schedule([grad.op for grad in grads])
mgrad = tvm.build(grad_sched, list(grads) + inputs)
# print(tvm.lower(grad_sched, list(grads) + inputs, simple_mode=True))
grad_data = [tvm.nd.empty(get_const_tuple(i.shape), g.dtype)
for i, g in zip(inputs, grads)]
mgrad(*grad_data, *input_data)
g_res = [g.asnumpy() for g in grad_data]
if desired_grads:
assert isinstance(desired_grads, list)
for actual, desired in zip(g_res, desired_grads):
assert_allclose(actual, desired, rtol=0.1, atol=1e-2)
else:
def forward(*in_data):
out_data = tvm.nd.empty(out_shape, out.dtype)
mout(out_data, *[tvm.nd.array(d) for d in list(in_data)])
return out_data.asnumpy().sum()
check_numerical_grads(forward, [d.asnumpy() for d in input_data], g_res)
check_device("cpu")
def test_basic_operation():
np.random.seed(0)
shape = (10, 10)
x = te.var("x", dtype='float32')
k = te.reduce_axis((0, 10), name="k")
l = te.reduce_axis((0, 10), name="l")
A0 = te.placeholder(shape, name='A0')
A1 = te.placeholder(shape, name='A1')
zeros = np.zeros(shape)
B = te.compute(shape, lambda i, j: A0[i, j], name='B')
check_grad(B, [A0])
B = te.compute(shape, lambda i, j: A0[i, j] + A1[i, j], name='B')
check_grad(B, [A0, A1])
B = te.compute(shape, lambda i, j: A0[i, j] + A0[j, i], name='B')
check_grad(B, A0)
B = te.compute(shape, lambda i, j: te.floor(A0[i, j]), name='B')
check_grad(B, A0, desired_grads=[zeros])
B = te.compute(shape, lambda i, j: te.ceil(A0[i, j]), name='B')
check_grad(B, A0, desired_grads=[zeros])
B = te.compute(shape, lambda i, j: te.trunc(A0[i, j]), name='B')
check_grad(B, A0, desired_grads=[zeros])
B = te.compute(shape, lambda i, j: te.round(A0[i, j]), name='B')
check_grad(B, A0, desired_grads=[zeros])
B = te.compute(shape, lambda i, j: A0[i, j] + te.exp(A0[j, i]), name='B')
check_grad(B, A0)
B = te.compute(shape, lambda i, j: te.log(0.1 + te.abs(A0[i, j] + te.exp(A0[j, i]))), name='B')
check_grad(B, A0)
B = te.compute(shape, lambda i, j: te.sigmoid(A0[i, j]*A0[i, j]*A0[j, i]), name='B')
check_grad(B, A0)
B = te.compute(shape, lambda i, j: te.tanh(A0[i, j]*A0[i, j]*A0[j, i]), name='B')
check_grad(B, A0)
B = te.compute(shape, lambda i, j: te.sqrt(A0[i, j]*A0[i, j]*A0[j, i]), name='B')
check_grad(B, A0, data_range=(0.1, 10))
B = te.compute(shape, lambda i, j: te.power(te.abs(A0[i, j]), A0[j, i]), name='B')
check_grad(B, A0, data_range=(-4, 4))
B = te.compute(shape, lambda i, j: A0[i, j] * A0[j, i], name='B')
check_grad(B, A0)
B = te.compute((10,), lambda i: te.sum(A0[i, k]*A0[k, i], axis=k), name='B')
check_grad(B, A0)
B = te.compute(shape, lambda i, j: te.sum(A0[i, k]*A0[k, i] + 5, axis=k), name='B')
check_grad(B, A0)
B = te.compute(shape, lambda i, j: te.max(A0[i, k]*A0[k, j] + 5, axis=k), name='B')
check_grad(B, A0)
B = te.compute(shape, lambda i, j: A0[i, j] * (A1[j, i] + A0[j, i]), name='B')
check_grad(B, [A0, A1])
B = te.compute(shape, lambda i, j: te.sum(A0[k, k] -
A0[te.min(j + k, 9), j]*A0[i, k],
axis=k), name='B')
check_grad(B, A0)
def fcombine(x, y):
return x*y
def fidentity(t0):
return tvm.tir.const(1, t0)
prod = te.comm_reducer(fcombine, fidentity, name='prod')
B = te.compute((10, 10), lambda i, j: prod(A0[i, k] + A0[k, i], axis=k), name='B')
check_grad(B, A0)
X = te.placeholder((10,), name='X')
A = te.compute((10,), lambda i: X[i] + X[9 - i])
B = te.compute((10,), lambda i: X[i] * X[9 - i])
Y = topi.tensordot(A, B, 1)
check_grad(Y, X)
def test_conv2d():
np.random.seed(0)
X = te.placeholder((1, 2, 4, 4), name='X')
W = te.placeholder((5, 2, 3, 3), name='W')
R = topi.nn.conv2d(X, W, 1, 1, 1)
check_grad(R, [X, W])
if __name__ == "__main__":
test_basic_operation()
test_conv2d()
@@ -34,6 +34,8 @@ def test_check_numerical_grads():
lambda x: (np.tan(x), 1.0 / (np.cos(x) * np.cos(x))),
]
np.random.seed(0)
# Avoid values too close to 0, since our functions have singularities there
min_x = 0.5