Commit b20678b0 by ziheng Committed by Tianqi Chen

[TOPI] Fix declaration for different dtypes (#546)

parent b384cd4a
...@@ -18,6 +18,7 @@ For example, you can use addexp.a to get the left operand of an Add node. ...@@ -18,6 +18,7 @@ For example, you can use addexp.a to get the left operand of an Add node.
from __future__ import absolute_import as _abs from __future__ import absolute_import as _abs
from ._ffi.node import NodeBase, register_node from ._ffi.node import NodeBase, register_node
from . import make as _make from . import make as _make
from . import _api_internal
class ExprOp(object): class ExprOp(object):
def __add__(self, other): def __add__(self, other):
...@@ -60,7 +61,8 @@ class ExprOp(object): ...@@ -60,7 +61,8 @@ class ExprOp(object):
return _make.Mod(self, other) return _make.Mod(self, other)
def __neg__(self): def __neg__(self):
return self.__mul__(-1) neg_one = _api_internal._const(-1, self.dtype)
return self.__mul__(neg_one)
def __lshift__(self, other): def __lshift__(self, other):
return _make.Call(self.dtype, "shift_left", [self, other], Call.PureIntrinsic, None, 0) return _make.Call(self.dtype, "shift_left", [self, other], Call.PureIntrinsic, None, 0)
......
...@@ -17,7 +17,7 @@ def relu(x): ...@@ -17,7 +17,7 @@ def relu(x):
y : tvm.Tensor y : tvm.Tensor
The result. The result.
""" """
return tvm.compute(x.shape, lambda *i: tvm.max(x(*i), 0)) return tvm.compute(x.shape, lambda *i: tvm.max(x(*i), tvm.const(0, x.dtype)))
@tvm.tag_scope(tag=tag.ELEMWISE) @tvm.tag_scope(tag=tag.ELEMWISE)
......
...@@ -38,7 +38,7 @@ def global_pool(data, pool_type): ...@@ -38,7 +38,7 @@ def global_pool(data, pool_type):
tvm.sum(data[n, c, dheight, dwidth], axis=[dheight, dwidth]), \ tvm.sum(data[n, c, dheight, dwidth], axis=[dheight, dwidth]), \
tag="global_pool_sum") tag="global_pool_sum")
return tvm.compute((batch, channel, 1, 1), lambda n, c, h, w: \ return tvm.compute((batch, channel, 1, 1), lambda n, c, h, w: \
tsum[n, c, h, w] / (height*width), \ tsum[n, c, h, w] / (height*width).astype(tsum.dtype), \
tag=tag.ELEMWISE) tag=tag.ELEMWISE)
else: else:
raise ValueError("Pool type should be 'avg' or 'max'.") raise ValueError("Pool type should be 'avg' or 'max'.")
......
Markdown is supported
0% or
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment