Commit 7cc92ace by ziheng Committed by Tianqi Chen

[LANG] Expose tvm.cast (#195)

* [LANG] Expose tvm.cast

* Update

* Add unittest
parent 7c6a71ba
...@@ -467,6 +467,23 @@ def reduce_axis(dom, name="rv"): ...@@ -467,6 +467,23 @@ def reduce_axis(dom, name="rv"):
""" """
return _IterVar(dom, name, 2) return _IterVar(dom, name, 2)
def cast(dtype, expr):
"""Cast an expression to other type
Parameters
----------
dtype : str, optional
The type of new expression
expr : Expr
The expression
Returns
-------
expr : Expr
Expression with new type
"""
return _make.Cast(dtype, expr)
def select(cond, t, f): def select(cond, t, f):
"""Construct a select branch """Construct a select branch
Parameters Parameters
......
...@@ -89,6 +89,20 @@ class ExprOp(object): ...@@ -89,6 +89,20 @@ class ExprOp(object):
""" """
return _make.EQ(self, other) return _make.EQ(self, other)
def astype(self, dtype):
"""Cast the expression to other type
Parameters
----------
dtype : str, optional
The type of new expression
Returns
-------
expr : Expr
Expression with new type
"""
return _make.Cast(dtype, self)
class Expr(NodeBase, ExprOp): class Expr(NodeBase, ExprOp):
"""Base class of all tvm Expressions""" """Base class of all tvm Expressions"""
......
...@@ -139,6 +139,7 @@ REGISTER_MAKE_BINARY_OP(Or); ...@@ -139,6 +139,7 @@ REGISTER_MAKE_BINARY_OP(Or);
REGISTER_MAKE1(Not); REGISTER_MAKE1(Not);
REGISTER_MAKE3(Select); REGISTER_MAKE3(Select);
REGISTER_MAKE3(Ramp); REGISTER_MAKE3(Ramp);
REGISTER_MAKE2(Cast);
REGISTER_MAKE2(Broadcast); REGISTER_MAKE2(Broadcast);
REGISTER_MAKE3(Let); REGISTER_MAKE3(Let);
REGISTER_MAKE3(LetStmt); REGISTER_MAKE3(LetStmt);
......
...@@ -16,6 +16,8 @@ def test_tensor(): ...@@ -16,6 +16,8 @@ def test_tensor():
assert(T.op.output(0).__hash__() == T.__hash__()) assert(T.op.output(0).__hash__() == T.__hash__())
d = {T.op.output(0) : 1} d = {T.op.output(0) : 1}
assert(d[T] == 1) assert(d[T] == 1)
assert(tvm.cast('float16', T[0][0][0]).dtype == 'float16')
assert(T[0][0][0].astype('float16').dtype == 'float16')
def test_conv1d(): def test_conv1d():
......
Markdown is supported
0% or
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment