Commit 77345051 by tqchen

Move Var back to Expr, add format str test

parent f5b8196d
......@@ -2,3 +2,5 @@
from __future__ import absolute_import as _abs
from .op import *
from .expr import Var, const
from .expr_util import *
......@@ -2,13 +2,16 @@
from __future__ import absolute_import as _abs
from numbers import Number as _Number
from . import op as _op
from . import var_name as _name
class Expr(object):
"""Base class of expression."""
"""Base class of expression.
Expression object should be in general immutable.
"""
def children(self):
"""All expr must define this.
"""get children of this expression.
Returns
-------
......@@ -60,6 +63,21 @@ def _symbol(value):
raise TypeError("type %s not supported" % str(type(other)))
class Var(Expr):
"""Variable, is a symbolic placeholder.
Each variable is uniquely identified by its address
Note that name alone is not able to uniquely identify the var.
Parameters
----------
name : str
optional name to the var.
"""
def __init__(self, name=None):
self.name = name if name else _name.NameManager.current.get(name)
class ConstExpr(Expr):
"""Constant expression."""
def __init__(self, value):
......@@ -77,7 +95,6 @@ class BinaryOpExpr(Expr):
def children(self):
return (self.lhs, self.rhs)
_op.binary_op_cls = BinaryOpExpr
class UnaryOpExpr(Expr):
......@@ -88,3 +105,8 @@ class UnaryOpExpr(Expr):
def children(self):
return (self.src)
def const(value):
"""Return a constant value"""
return ConstExpr(value)
"""Utilities to manipulate expression"""
from __future__ import absolute_import as _abs
from . import expr as _expr
def expr_with_new_children(e, children):
"""Returns same expr as e but with new children
A shallow copy of e will happen if children differs from current children
Parameters
----------
e : Expr
The input expression
children : list of Expr
The new children
Returns
-------
new_e : Expr
Expression with the new children
"""
if children:
if isinstance(e, _expr.BinaryOpExpr):
return (e if children[0] == e.lhs and children[1] == e.rhs
else _expr.BinaryOpExpr(e.op, children[0], children[1]))
elif isinstance(e, _expr.UnaryOpExpr):
return e if children[0] == e.src else _expr.UnaryOpExpr(e.op, children[0])
else:
raise TypeError("donnot know how to handle Expr %s" % type(e))
else:
return e
def transform(e, f):
"""Apply f recursively to e and collect the resulr
Parameters
----------
e : Expr
The input expression.
f : function with signiture (e, ret_children)
ret_children is the result of transform from children
Returns
-------
result : return value of f
The final result of transformation.
"""
return f(e , [transform(c, f) for c in e.children()])
def format_str(expr):
"""change expression to string.
Parameters
----------
expr : Expr
Input expression
Returns
-------
s : str
The string representation of expr
"""
def make_str(e, result_children):
if isinstance(e, _expr.BinaryOpExpr):
return e.op.format_str(result_children[0], result_children[1])
elif isinstance(e, _expr.UnaryOpExpr):
return e.op.format_str(result_children[0])
elif isinstance(e, _expr.ConstExpr):
return str(e.value)
elif isinstance(e, _expr.Var):
return e.name
else:
raise TypeError("Do not know how to handle type " + str(type(e)))
return transform(expr, make_str)
def bind(expr, update_dict):
"""Replace the variable in e by specification from kwarg
Parameters
----------
expr : Expr
Input expression
update_dict : dict of Var->Expr
The variables to be replaced.
Examples
--------
eout = bind(e, update_dict={v1: (x+1)} )
"""
def replace(e, result_children):
if isinstance(e, _expr.Var) and e in update_dict:
return update_dict[e]
else:
return expr_with_new_children(e, result_children)
return transform(expr, replace)
......@@ -8,16 +8,20 @@ class BinaryOp(object):
return _binary_op_cls(self, lhs, rhs)
class AddOp(BinaryOp):
pass
def format_str(self, lhs, rhs):
return '(%s + %s)' % (lhs, rhs)
class SubOp(BinaryOp):
pass
def format_str(self, lhs, rhs):
return '(%s - %s)' % (lhs, rhs)
class MulOp(BinaryOp):
pass
def format_str(self, lhs, rhs):
return '(%s * %s)' % (lhs, rhs)
class DivOp(BinaryOp):
pass
def format_str(self, lhs, rhs):
return '(%s / %s)' % (lhs, rhs)
add = AddOp()
......
from __future__ import absolute_import as _abs
from .expr import Expr
class Var(Expr):
"""Variables"""
def __init__(self, name, expr=None):
self.name = name
self.expr = expr
def assign(self, expr):
self.expr = expr
def children(self):
if self.expr is None:
return ()
return self.expr.children()
def same_as(self, other):
return (self.name == other.name)
"""Name manager to make sure name is unique."""
from __future__ import absolute_import as _abs
class NameManager(object):
"""NameManager to do automatic naming.
User can also inherit this object to change naming behavior.
"""
current = None
def __init__(self):
self._counter = {}
self._old_manager = None
def get(self, hint):
"""Get the canonical name for a symbol.
This is default implementation.
When user specified a name,
the user specified name will be used.
When user did not, we will automatically generate a
name based on hint string.
Parameters
----------
hint : str
A hint string, which can be used to generate name.
Returns
-------
full_name : str
A canonical name for the user.
"""
if hint not in self._counter:
self._counter[hint] = 0
name = '%s%d' % (hint, self._counter[hint])
self._counter[hint] += 1
return name
def __enter__(self):
self._old_manager = NameManager.current
NameManager.current = self
return self
def __exit__(self, ptype, value, trace):
assert self._old_manager
NameManager.current = self._old_manager
# initialize the default name manager
NameManager.current = NameManager()
import tvm
from tvm import expr
def test_const():
x = expr.ConstExpr(1)
x + 1
print x
def test_bind():
x = tvm.Var('x')
y = x + 1
z = tvm.bind(y, {x: tvm.const(10) + 9})
assert tvm.format_str(z) == '((10 + 9) + 1)'
test_const()
def test_basic():
a= tvm.Var('a')
b = tvm.Var('b')
c = a + b
assert tvm.format_str(c) == '(%s + %s)' % (a.name, b.name)
if __name__ == "__main__":
test_basic()
test_bind()
Markdown is supported
0% or
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment