Commit 7cc92ace by ziheng Committed by Tianqi Chen

[LANG] Expose tvm.cast (#195)

* [LANG] Expose tvm.cast

* Update

* Add unittest
parent 7c6a71ba
......@@ -467,6 +467,23 @@ def reduce_axis(dom, name="rv"):
"""
return _IterVar(dom, name, 2)
def cast(dtype, expr):
"""Cast an expression to other type
Parameters
----------
dtype : str, optional
The type of new expression
expr : Expr
The expression
Returns
-------
expr : Expr
Expression with new type
"""
return _make.Cast(dtype, expr)
def select(cond, t, f):
"""Construct a select branch
Parameters
......
......@@ -89,6 +89,20 @@ class ExprOp(object):
"""
return _make.EQ(self, other)
def astype(self, dtype):
"""Cast the expression to other type
Parameters
----------
dtype : str, optional
The type of new expression
Returns
-------
expr : Expr
Expression with new type
"""
return _make.Cast(dtype, self)
class Expr(NodeBase, ExprOp):
"""Base class of all tvm Expressions"""
......
......@@ -139,6 +139,7 @@ REGISTER_MAKE_BINARY_OP(Or);
REGISTER_MAKE1(Not);
REGISTER_MAKE3(Select);
REGISTER_MAKE3(Ramp);
REGISTER_MAKE2(Cast);
REGISTER_MAKE2(Broadcast);
REGISTER_MAKE3(Let);
REGISTER_MAKE3(LetStmt);
......
......@@ -16,6 +16,8 @@ def test_tensor():
assert(T.op.output(0).__hash__() == T.__hash__())
d = {T.op.output(0) : 1}
assert(d[T] == 1)
assert(tvm.cast('float16', T[0][0][0]).dtype == 'float16')
assert(T[0][0][0].astype('float16').dtype == 'float16')
def test_conv1d():
......
Markdown is supported
0% or
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment