Commit d45b6d4b by ziheng Committed by Tianqi Chen

[DOC] Add intro to 'comm_reducer' in tutorial; fix doc (#108)

* [DOC] Add intro to 'comm_reducer' in tutorial; fix doc

* Fix

* Fix
parent 6ab6bb3f
......@@ -84,6 +84,13 @@ Stmt CanonicalSimplify(Stmt stmt);
* \return The converted form.
*/
Stmt Substitute(Stmt stmt, const Map<Var, Expr>& value_map);
/*!
* \brief Substitute the var specified in key->var to be value.
* \param expr The source expression to be substituted
* \param value_map The map of new values.
* \return The converted expression.
*/
Expr Substitute(Expr expr, const Map<Var, Expr>& value_map);
/*!
......
......@@ -70,7 +70,7 @@ setuptools.setup(
],
zip_safe=False,
packages=[
'tvm', 'tvm.addon',
'tvm', 'tvm.contrib',
'tvm._ffi', 'tvm._ffi._ctypes',
'tvm._ffi._cy2', 'tvm._ffi._cy3'
],
......
......@@ -23,10 +23,12 @@ handle = "handle"
def min_value(dtype):
"""minimum value of dtype"""
return _api_internal._min_value(dtype)
def max_value(dtype):
"""maximum value of dtype"""
return _api_internal._max_value(dtype)
......@@ -438,7 +440,7 @@ def comm_reducer(fcombine, fidentity, name="reduce"):
-------
reducer : function
A function which creates a reduce expression over axis.
There are two to use it:
There are two ways to use it:
1. accept (expr, axis, where) to produce an Reduce Expr on
specified axis;
......
......@@ -125,9 +125,27 @@ np.testing.assert_allclose(
b.asnumpy(), np.sum(a.asnumpy(), axis=1), rtol=1e-4)
######################################################################
# Define General Commutative Reduction Operation
# ----------------------------------------------
# Besides the built-in reduction operations like :any:`tvm.sum`,
# :any:`tvm.min` and :any:`tvm.max`, you can also define your
# commutative reduction operation by :any:`tvm.comm_reducer`.
#
n = tvm.var('n')
m = tvm.var('m')
product = tvm.comm_reducer(lambda x, y: x*y,
lambda t: tvm.const(1, dtype=t), name="product")
A = tvm.placeholder((n, m), name='A')
k = tvm.reduce_axis((0, m), name='k')
B = tvm.compute((n,), lambda i: product(A[i, k], axis=k), name='B')
######################################################################
# Summary
# -------
# This tutorial provides a walk through of reduction schedule.
#
# - Describe reduction with reduce_axis.
# - Use rfactor to factor out axis if we need parallelism.
# - Define new reduction operation by :any:`tvm.comm_reducer`
Markdown is supported
0% or
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment