Commit cc7a8fcf authored by Ding, committed by Tianqi Chen

[TOPI] Overload operators of Tensor when TOPI is imported (#1029)

parent 63a3477a
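In effect, once topi is imported, arithmetic on tvm.Tensor objects is routed through TOPI's broadcast and element-wise operators instead of falling back to plain expression construction. A minimal sketch of the resulting behavior (mirroring the new unit tests below; illustrative only, not part of the patch):

import tvm
import topi  # importing topi rebinds tvm.generic.add/subtract/multiply/divide

A = tvm.placeholder((5, 10), name='A')
B = tvm.placeholder((5, 10), name='B')

C = A + B    # Tensor + Tensor -> broadcast add, returns a tvm.tensor.Tensor
D = A * 2.0  # Tensor * scalar -> element-wise multiply

assert isinstance(C, tvm.tensor.Tensor)
assert C.op.tag == topi.tag.BROADCAST
assert D.op.tag == topi.tag.ELEMWISE

The changes below wire this up: tvm.expr.ExprOp delegates to a new tvm.generic module, and topi rebinds those generic operators when it is imported.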
......
@@ -15,6 +15,7 @@ from . import module
 from . import node
 from . import ir_builder
 from . import target
+from . import generic
 from . import ndarray as nd
 from .ndarray import context, cpu, gpu, opencl, cl, vulkan, metal, mtl
......
......
@@ -18,32 +18,34 @@ For example, you can use addexp.a to get the left operand of an Add node.
 from __future__ import absolute_import as _abs
 from ._ffi.node import NodeBase, NodeGeneric, register_node
 from . import make as _make
+from . import generic as _generic
 from . import _api_internal
 class ExprOp(object):
     def __add__(self, other):
-        return _make.Add(self, other)
+        return _generic.add(self, other)
     def __radd__(self, other):
         return self.__add__(other)
     def __sub__(self, other):
-        return _make.Sub(self, other)
+        return _generic.subtract(self, other)
     def __rsub__(self, other):
-        return _make.Sub(other, self)
+        return _generic.subtract(other, self)
     def __mul__(self, other):
-        return _make.Mul(self, other)
+        return _generic.multiply(self, other)
     def __rmul__(self, other):
-        return _make.Mul(other, self)
+        return _generic.multiply(other, self)
     def __div__(self, other):
-        return _make.Div(self, other)
+        return _generic.divide(self, other)
     def __rdiv__(self, other):
-        return _make.Div(other, self)
+        return _generic.divide(other, self)
     def __truediv__(self, other):
         return self.__div__(other)
......
"""Generic opertors in TVM.
We follow the numpy naming convention for this interface
(e.g., tvm.generic.multitply ~ numpy.multiply).
The default implementation is used by tvm.ExprOp.
"""
# pylint: disable=unused-argument
from . import make as _make
#Operator precedence used when overloading.
__op_priority__ = 0
def add(lhs, rhs):
"""Generic add operator.
Parameters
----------
lhs : object
The left operand.
rhs : object
The right operand.
Returns
-------
op : tvm.Expr
The result Expr of add operaton.
"""
return _make.Add(lhs, rhs)
def subtract(lhs, rhs):
"""Generic subtract operator.
Parameters
----------
lhs : object
The left operand.
rhs : object
The right operand.
Returns
-------
op : tvm.Expr
The result Expr of subtract operaton.
"""
return _make.Sub(lhs, rhs)
def multiply(lhs, rhs):
"""Generic multiply operator.
Parameters
----------
lhs : object
The left operand.
rhs : object
The right operand.
Returns
-------
op : tvm.Expr
The result Expr of multiply operaton.
"""
return _make.Mul(lhs, rhs)
def divide(lhs, rhs):
"""Generic divide operator.
Parameters
----------
lhs : object
The left operand.
rhs : object
The right operand.
Returns
-------
op : tvm.Expr
The result Expr of divide operaton.
"""
return _make.Div(lhs, rhs)
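For reference, these defaults preserve the existing behavior when topi has not been imported: the ExprOp overloads route through tvm.generic.*, which simply builds expression nodes. A small illustration (assuming a plain tvm install with topi not yet loaded):

import tvm

n = tvm.var('n')

# The overloaded operators on expressions still produce plain Exprs.
assert isinstance(1 + n, tvm.expr.Expr)
assert isinstance(n * n, tvm.expr.Expr)
assert isinstance(tvm.generic.add(n, n), tvm.expr.Expr)

The new unit tests that follow exercise the overloaded behavior end to end once topi is imported.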
import numpy as np
import tvm
import topi
import topi.testing
from topi.util import get_const_tuple


def test_operator_type_and_tags():
    k = 1
    n = tvm.var('n')
    A = tvm.placeholder((), name='A')
    B = tvm.placeholder((10, 5), name='B')
    B1 = B[0]
    B2 = B[0, 0]

    assert isinstance(k + n, tvm.expr.Expr)
    assert isinstance(n + n, tvm.expr.Expr)
    assert isinstance(k + A, tvm.expr.Expr)
    assert isinstance(A + k, tvm.expr.Expr)
    assert isinstance(n + A, tvm.expr.Expr)
    assert isinstance(A + n, tvm.expr.Expr)
    assert isinstance(A + A, tvm.expr.Expr)

    assert isinstance(k + B, tvm.tensor.Tensor)
    assert isinstance(B + k, tvm.tensor.Tensor)
    assert isinstance(n + B, tvm.tensor.Tensor)
    assert isinstance(B + n, tvm.tensor.Tensor)
    assert isinstance(A + B, tvm.tensor.Tensor)
    assert isinstance(B + A, tvm.tensor.Tensor)
    assert isinstance(B + B, tvm.tensor.Tensor)

    assert (k + B).op.tag == topi.tag.ELEMWISE
    assert (B + k).op.tag == topi.tag.ELEMWISE
    assert (n + B).op.tag == topi.tag.ELEMWISE
    assert (B + n).op.tag == topi.tag.ELEMWISE
    assert (A + B).op.tag == topi.tag.ELEMWISE
    assert (B + A).op.tag == topi.tag.ELEMWISE
    assert (B + B).op.tag == topi.tag.BROADCAST

    assert isinstance(k + B2, tvm.expr.Expr)
    assert isinstance(B2 + k, tvm.expr.Expr)
    assert isinstance(n + B2, tvm.expr.Expr)
    assert isinstance(B2 + n, tvm.expr.Expr)
    assert isinstance(B2 + B2, tvm.expr.Expr)
    assert isinstance(B2 + A, tvm.expr.Expr)
    assert isinstance(A + B2, tvm.expr.Expr)
    assert isinstance(B2 + B, tvm.tensor.Tensor)
    assert isinstance(B + B2, tvm.tensor.Tensor)


def test_combination():
    k = 3
    n = 5
    m = 10
    x = tvm.var('x')
    A = tvm.placeholder((n, m), name='A')
    B = tvm.placeholder((n, m), name='B')
    C = tvm.placeholder((n, m), name='C')
    D = k + A - B * C / x
    s = tvm.create_schedule(D.op)
    foo = tvm.build(s, [x, A, B, C, D], "llvm")
    ctx = tvm.cpu(0)
    x = 2
    a = tvm.nd.array(np.random.uniform(size=(n, m)).astype(A.dtype), ctx)
    b = tvm.nd.array(np.random.uniform(size=(n, m)).astype(B.dtype), ctx)
    c = tvm.nd.array(np.random.uniform(size=(n, m)).astype(C.dtype), ctx)
    d = tvm.nd.array(np.zeros((n, m), dtype=D.dtype), ctx)
    foo(x, a, b, c, d)
    np.testing.assert_allclose(d.asnumpy(), k + a.asnumpy() - b.asnumpy() * c.asnumpy() / x)


def verify_tensor_scalar_bop(shape, typ="add"):
    """Verify non-constant Tensor and scalar binary operations."""
    sh = [tvm.var('n%d' % i) for i in range(0, len(shape))]
    k = tvm.var('k')
    A = tvm.placeholder(sh, name='A')
    if typ == "add":
        B = A + k
    elif typ == "sub":
        B = A - k
    elif typ == "mul":
        B = A * k
    elif typ == "div":
        B = A / k
    else:
        raise NotImplementedError()

    def check_device(device):
        ctx = tvm.context(device, 0)
        if not ctx.exist:
            print("Skip because %s is not enabled" % device)
            return
        print("Running on target: %s" % device)
        with tvm.target.create(device):
            s = topi.generic.schedule_elemwise(B)

        k_ = 2
        foo = tvm.build(s, [A, B, k] + sh, device, name="tensor_scalar_" + typ)
        a_npy = np.random.uniform(size=shape).astype(A.dtype)
        if typ == "add":
            b_npy = a_npy + k_
        elif typ == "sub":
            b_npy = a_npy - k_
        elif typ == "mul":
            b_npy = a_npy * k_
        elif typ == "div":
            b_npy = a_npy / k_
        else:
            raise NotImplementedError()

        a_nd = tvm.nd.array(a_npy, ctx)
        b_nd = tvm.nd.array(np.empty(b_npy.shape).astype(B.dtype), ctx)
        foo(a_nd, b_nd, k_, *shape)
        np.testing.assert_allclose(b_nd.asnumpy(), b_npy, rtol=1e-5)

    for device in ['llvm', 'cuda', 'opencl', 'metal', 'rocm', 'vulkan']:
        check_device(device)


def verify_broadcast_bop(lhs_shape, rhs_shape, typ="add"):
    A = tvm.placeholder(shape=lhs_shape, name="A")
    B = tvm.placeholder(shape=rhs_shape, name="B")
    if typ == "add":
        C = A + B
    elif typ == "sub":
        C = A - B
    elif typ == "mul":
        C = A * B
    elif typ == "div":
        C = A / B
    else:
        raise NotImplementedError()

    def check_device(device):
        ctx = tvm.context(device, 0)
        if not ctx.exist:
            print("Skip because %s is not enabled" % device)
            return
        print("Running on target: %s" % device)
        with tvm.target.create(device):
            s = topi.generic.schedule_broadcast(C)

        foo = tvm.build(s, [A, B, C], device, name="broadcast_binary" + "_" + typ)
        lhs_npy = np.random.uniform(size=lhs_shape).astype(A.dtype)
        rhs_npy = np.random.uniform(size=rhs_shape).astype(A.dtype)
        if typ == "add":
            out_npy = lhs_npy + rhs_npy
        elif typ == "sub":
            out_npy = lhs_npy - rhs_npy
        elif typ == "mul":
            out_npy = lhs_npy * rhs_npy
        elif typ == "div":
            rhs_npy = np.abs(rhs_npy) + 0.001
            out_npy = lhs_npy / rhs_npy
        else:
            raise NotImplementedError()

        lhs_nd = tvm.nd.array(lhs_npy, ctx)
        rhs_nd = tvm.nd.array(rhs_npy, ctx)
        out_nd = tvm.nd.array(np.empty(out_npy.shape).astype(B.dtype), ctx)
        for _ in range(1):
            foo(lhs_nd, rhs_nd, out_nd)
        np.testing.assert_allclose(out_nd.asnumpy(), out_npy, rtol=1E-4, atol=1E-4)

    for device in ['llvm', 'cuda', 'opencl', 'metal', 'rocm', 'vulkan']:
        check_device(device)


def verify_conv2d_scalar_bop(batch, in_size, in_channel, num_filter, kernel, stride, padding, typ="add"):
    def check_device(device):
        ctx = tvm.context(device, 0)
        if not ctx.exist:
            print("Skip because %s is not enabled" % device)
            return
        print("Running on target: %s" % device)

        k = 10.0
        with tvm.target.create(device):
            A = tvm.placeholder((batch, in_channel, in_size, in_size), name='A')
            W = tvm.placeholder((num_filter, in_channel, kernel, kernel), name='W')
            B = topi.nn.conv2d(A, W, stride, padding)
            if typ == "add":
                C = B + k
            elif typ == "sub":
                C = B - k
            elif typ == "mul":
                C = B * k
            elif typ == "div":
                C = B / k
            else:
                raise NotImplementedError()
            s = topi.generic.schedule_conv2d_nchw([C])

        foo = tvm.build(s, [A, W, B, C], device, name="conv2d_scalar_" + typ)
        a_npy = np.random.uniform(size=get_const_tuple(A.shape)).astype(A.dtype)
        w_npy = np.random.uniform(size=get_const_tuple(W.shape)).astype(W.dtype)
        b_npy = topi.testing.conv2d_nchw_python(a_npy, w_npy, stride, padding)
        c_npy = np.random.uniform(size=get_const_tuple(B.shape)).astype(B.dtype)
        if typ == "add":
            c_npy = b_npy + k
        elif typ == "sub":
            c_npy = b_npy - k
        elif typ == "mul":
            c_npy = b_npy * k
        elif typ == "div":
            c_npy = b_npy / k
        else:
            raise NotImplementedError()

        a_nd = tvm.nd.array(a_npy, ctx)
        w_nd = tvm.nd.array(w_npy, ctx)
        b_nd = tvm.nd.array(np.empty(b_npy.shape).astype(B.dtype), ctx)
        c_nd = tvm.nd.array(np.empty(c_npy.shape).astype(C.dtype), ctx)
        foo(a_nd, w_nd, b_nd, c_nd)
        np.testing.assert_allclose(c_nd.asnumpy(), c_npy, rtol=1E-4, atol=1E-4)

    for device in ['llvm', 'cuda', 'opencl', 'metal', 'rocm', 'vulkan']:
        check_device(device)


def test_tensor_scalar_bop():
    verify_tensor_scalar_bop((1,), typ="add")
    verify_tensor_scalar_bop((3, 5), typ="sub")
    verify_tensor_scalar_bop((1, 3, 5), typ="mul")
    verify_tensor_scalar_bop((2, 3, 1, 32), typ="div")


def test_broadcast_bop():
    verify_broadcast_bop((2, 3), (), typ="add")
    verify_broadcast_bop((5, 2, 3), (1,), typ="add")
    verify_broadcast_bop((1, 32), (64, 32), typ="sub")
    verify_broadcast_bop((5, 64, 128), (2, 5, 64, 1), typ="mul")
    verify_broadcast_bop((2, 3, 1, 32), (64, 32), typ="div")


def test_conv2d_scalar_bop():
    verify_conv2d_scalar_bop(1, 16, 4, 4, 3, 1, 1, typ="add")
    verify_conv2d_scalar_bop(1, 32, 2, 1, 3, 1, 1, typ="sub")
    verify_conv2d_scalar_bop(1, 32, 1, 1, 3, 1, 1, typ="mul")
    verify_conv2d_scalar_bop(1, 16, 2, 1, 3, 1, 1, typ="div")


if __name__ == "__main__":
    test_operator_type_and_tags()
    test_combination()
    test_tensor_scalar_bop()
    test_broadcast_bop()
    test_conv2d_scalar_bop()
\ No newline at end of file
......
@@ -17,6 +17,7 @@ from . import cpp
 from .math import *
 from .tensor import *
+from .generic_op_impl import *
 from .reduction import *
 from .transform import *
 from .broadcast import *
......
"""Implementation of generic operators in the presence of Tensor"""
# pylint: disable=invalid-name, too-many-arguments
from __future__ import absolute_import as _abs
import tvm
from . import broadcast as _broadcast
from . import tag
def _make_bop(elementwise_bop, broadcast_bop, orig_bop):
"""Make a specific overloaded binary operator of Tensor when applicable;
apply the original operator if it is not supposed to be overloaded.
Consider the following scenario:
OP : + | - | * | /
R0 : int | float | Expr | TensorSlice | Tensor (rank zero)
R1 : Tensor (positive rank)
In terms of (LHS OP RHS), we apply the following overloading rules:
(1) We use broadcast_OP(LHS, RHS), when both LHS and RHS are R1.
(2) We perform element-wise operation of Tensor and scalar,
when one of LHS and RHS is R1 and another is R0.
(3) We do not overload OP (i.e. stick to orig_bop) otherwise.
Parameters
----------
elementwise_bop : operator function
Operator for element-wise tensor-scalar operation, for rule (2).
broadcast_bop : operator function
Operator for broadcast tensor-tensor operation, for rule (1).
orig_bop: operator function
Operator before overloading, for rule (3).
Returns
-------
ret : operator function
The overloaded operator function if applicable or orig_bop otherwise.
"""
name = orig_bop.__name__
def _tensor_bop_impl(lhs, rhs):
"""Overloaded {op} operator.
If both operands are non-zero-rank Tensors, it performs
tensor-tensor {op} operation, and broadcasts inputs when necessary.
If one operand is non-zero-rank Tensor, while the other operand is
scalar like type (e.g., numeric types, Expr, or TensorSlice),
it performs tensor-scalar {op} operation on an element-wise basis.
Otherwise, it performs default generic.{op} operation, as defined
in tvm.generic module.
Parameters
----------
lhs : object
Left operand.
rhs : object
Right operand.
Returns
-------
ret : tvm.Tensor (if at least one operand is non-zero-rank Tensor)
tvm.Expr (otherwise)
The result of {op} operation.
"""
def _get_rank(x):
"""Get the rank of a value.
If x is Tensor, then return its rank;
if x is scalar_like (i.e., numeric types, Expr, or TensorSlice), return 0;
otherwise, return -1.
"""
if isinstance(x, tvm.tensor.Tensor):
return len(x.shape)
elif isinstance(x, (int, float, tvm.expr.Expr, tvm.tensor.TensorSlice)):
return 0
return -1
rl = _get_rank(lhs)
rr = _get_rank(rhs)
if rl == -1 or rr == -1 or (rl == 0 and rr == 0):
return orig_bop(lhs, rhs)
elif rl > 0 and rr > 0:
return broadcast_bop(lhs, rhs)
elif rl == 0:
f = lambda *i: elementwise_bop(lhs, rhs(*i))
with tvm.tag_scope(tag=tag.ELEMWISE):
return tvm.compute(rhs.shape, f, "tensor_" + name)
elif rr == 0:
f = lambda *i: elementwise_bop(lhs(*i), rhs)
with tvm.tag_scope(tag=tag.ELEMWISE):
return tvm.compute(lhs.shape, f, "tensor_" + name)
else:
raise AssertionError("Cannot reach this line.")
_tensor_bop_impl.__doc__ = _tensor_bop_impl.__doc__.format(op=name)
return _tensor_bop_impl
def _bind_generic_ops():
"""Bind generic operators for Tensor."""
# Check __op_priority__ to make sure the binding happens only once.
__op_priority__ = 1
if __op_priority__ > tvm.generic.__op_priority__:
tvm.generic.__op_priority__ = __op_priority__
tvm.generic.add = _make_bop(lambda x, y: x + y,
_broadcast.broadcast_add,
tvm.generic.add)
tvm.generic.subtract = _make_bop(lambda x, y: x - y,
_broadcast.broadcast_sub,
tvm.generic.subtract)
tvm.generic.multiply = _make_bop(lambda x, y: x * y,
_broadcast.broadcast_mul,
tvm.generic.multiply)
tvm.generic.divide = _make_bop(lambda x, y: x / y,
_broadcast.broadcast_div,
tvm.generic.divide)
_bind_generic_ops()
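To make the dispatch rules in _make_bop concrete, the sketch below shows which rule each operand combination hits; the operand setup is taken from the unit tests above and is illustrative rather than part of the patch:

import tvm
import topi

n = tvm.var('n')                        # Expr             -> rank 0
A = tvm.placeholder((), name='A')       # zero-rank Tensor -> rank 0
B = tvm.placeholder((10, 5), name='B')  # rank-2 Tensor    -> rank 2

# Rule (1): both operands have positive rank -> broadcast operator.
assert (B + B).op.tag == topi.tag.BROADCAST

# Rule (2): one positive-rank Tensor, one scalar-like operand -> element-wise compute.
assert (B + n).op.tag == topi.tag.ELEMWISE
assert (A + B).op.tag == topi.tag.ELEMWISE

# Rule (3): no positive-rank Tensor involved -> fall back to the original operator.
assert isinstance(n + A, tvm.expr.Expr)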