Commit befd8c1e by ziheng Committed by Tianqi Chen

[LANG] Comparison operators support for Imm expressions (#3283)

parent 072f8cc7
...@@ -349,6 +349,16 @@ class StringImm(ConstExpr): ...@@ -349,6 +349,16 @@ class StringImm(ConstExpr):
self.__init_handle_by_constructor__( self.__init_handle_by_constructor__(
_make.StringImm, value) _make.StringImm, value)
def __eq__(self, other):
    """Equality by wrapped string value.

    A StringImm compares equal to another ConstExpr holding the same
    value, or to a plain object (typically ``str``) equal to its value.
    """
    if isinstance(other, ConstExpr):
        return self.value == other.value
    return self.value == other

def __ne__(self, other):
    """Inverse of __eq__ (spelled out explicitly so Python 2 gets the
    same semantics; Python 3 would derive it automatically)."""
    if isinstance(other, ConstExpr):
        return self.value != other.value
    return self.value != other

def __hash__(self):
    # Defining __eq__ implicitly sets __hash__ to None in Python 3,
    # which would make StringImm unhashable (unusable as a dict key or
    # set member). Hash by the wrapped value so that objects comparing
    # equal also hash equal, preserving the hash/eq contract.
    return hash(self.value)
@register_node @register_node
class Cast(Expr): class Cast(Expr):
......
...@@ -22,7 +22,6 @@ import numpy as np ...@@ -22,7 +22,6 @@ import numpy as np
from . import _quantize from . import _quantize
from .. import expr as _expr from .. import expr as _expr
from .. import ir_pass as _ir_pass from .. import ir_pass as _ir_pass
from .. import transform as _transform
from .. import op as _op from .. import op as _op
from ... import make as _make from ... import make as _make
from ..base import NodeBase, register_relay_node from ..base import NodeBase, register_relay_node
...@@ -301,8 +300,6 @@ def optimize(func, params=None): ...@@ -301,8 +300,6 @@ def optimize(func, params=None):
"FoldConstant", "FoldConstant",
"CanonicalizeOps"] "CanonicalizeOps"]
cfg = _transform.build_config(required_pass=opt_passes)
if params: if params:
name_dict = {} name_dict = {}
for arg in func.params: for arg in func.params:
...@@ -321,25 +318,25 @@ def optimize(func, params=None): ...@@ -321,25 +318,25 @@ def optimize(func, params=None):
bind_dict[arg] = _expr.const(v) bind_dict[arg] = _expr.const(v)
func = _expr.bind(func, bind_dict) func = _expr.bind(func, bind_dict)
if "SimplifyInference" in cfg.required_pass: if "SimplifyInference" in opt_passes:
func = _ir_pass.infer_type(func) func = _ir_pass.infer_type(func)
func = _ir_pass.simplify_inference(func) func = _ir_pass.simplify_inference(func)
if "FoldConstant" in cfg.required_pass: if "FoldConstant" in opt_passes:
func = _ir_pass.fold_constant(func) func = _ir_pass.fold_constant(func)
if "FoldScaleAxis" in cfg.required_pass: if "FoldScaleAxis" in opt_passes:
func = _ir_pass.infer_type(func) func = _ir_pass.infer_type(func)
func = _ir_pass.backward_fold_scale_axis(func) func = _ir_pass.backward_fold_scale_axis(func)
func = _ir_pass.infer_type(func) func = _ir_pass.infer_type(func)
func = _ir_pass.forward_fold_scale_axis(func) func = _ir_pass.forward_fold_scale_axis(func)
func = _ir_pass.fold_constant(func) func = _ir_pass.fold_constant(func)
if "CanonicalizeOps" in cfg.required_pass: if "CanonicalizeOps" in opt_passes:
func = _ir_pass.infer_type(func) func = _ir_pass.infer_type(func)
func = _ir_pass.canonicalize_ops(func) func = _ir_pass.canonicalize_ops(func)
if "FoldConstant" in cfg.required_pass: if "FoldConstant" in opt_passes:
func = _ir_pass.fold_constant(func) func = _ir_pass.fold_constant(func)
return func return func
......
...@@ -108,8 +108,8 @@ bool BroadcastRel(const Array<Type>& types, ...@@ -108,8 +108,8 @@ bool BroadcastRel(const Array<Type>& types,
const Attrs& attrs, const Attrs& attrs,
const TypeReporter& reporter) { const TypeReporter& reporter) {
CHECK_EQ(types.size(), 3); CHECK_EQ(types.size(), 3);
DLOG(INFO) << "In1:" << types[0] << ",In2:" << types[1] // DLOG(INFO) << "In1:" << types[0] << ",In2:" << types[1]
<< ",Out:" << types[2] << std::endl; // << ",Out:" << types[2] << std::endl;
if (auto t0 = ToTensorType(types[0])) { if (auto t0 = ToTensorType(types[0])) {
if (auto t1 = ToTensorType(types[1])) { if (auto t1 = ToTensorType(types[1])) {
CHECK_EQ(t0->dtype, t1->dtype); CHECK_EQ(t0->dtype, t1->dtype);
...@@ -126,8 +126,8 @@ bool BroadcastCompRel(const Array<Type>& types, ...@@ -126,8 +126,8 @@ bool BroadcastCompRel(const Array<Type>& types,
const Attrs& attrs, const Attrs& attrs,
const TypeReporter& reporter) { const TypeReporter& reporter) {
CHECK_EQ(types.size(), 3); CHECK_EQ(types.size(), 3);
DLOG(INFO) << "In1:" << types[0] << ",In2:" << types[1] // DLOG(INFO) << "In1:" << types[0] << ",In2:" << types[1]
<< ",Out:" << types[2] << std::endl; // << ",Out:" << types[2] << std::endl;
if (auto t0 = ToTensorType(types[0])) { if (auto t0 = ToTensorType(types[0])) {
if (auto t1 = ToTensorType(types[1])) { if (auto t1 = ToTensorType(types[1])) {
CHECK_EQ(t0->dtype, t1->dtype); CHECK_EQ(t0->dtype, t1->dtype);
......
...@@ -163,6 +163,14 @@ def test_equality(): ...@@ -163,6 +163,14 @@ def test_equality():
d = (c != c) d = (c != c)
assert not d assert not d
def test_equality_string_imm():
    """StringImm should compare equal both to its own value and to a raw str."""
    x = 'a'
    y = tvm.make.StringImm(x)
    # The comparisons must be asserted — as bare expressions their results
    # are discarded and the test would pass even if __eq__ were broken.
    assert x == y.value
    # str.__eq__ returns NotImplemented for StringImm, so this exercises
    # the reflected StringImm.__eq__ path.
    assert x == y
if __name__ == "__main__": if __name__ == "__main__":
test_cast() test_cast()
test_attr() test_attr()
...@@ -178,3 +186,4 @@ if __name__ == "__main__": ...@@ -178,3 +186,4 @@ if __name__ == "__main__":
test_all() test_all()
test_bitwise() test_bitwise()
test_equality() test_equality()
test_equality_string_imm()
...@@ -65,9 +65,16 @@ def test_map_save_load_json(): ...@@ -65,9 +65,16 @@ def test_map_save_load_json():
assert(dd == {"a": 2, "b": 3}) assert(dd == {"a": 2, "b": 3})
def test_in_container():
    """Membership tests on a converted string array: both raw strings and
    StringImm nodes must be found via the container's __contains__."""
    items = tvm.convert(['a', 'b', 'c'])
    assert 'a' in items
    assert tvm.make.StringImm('a') in items
    assert 'd' not in items
if __name__ == "__main__": if __name__ == "__main__":
test_str_map() test_str_map()
test_array() test_array()
test_map() test_map()
test_array_save_load_json() test_array_save_load_json()
test_map_save_load_json() test_map_save_load_json()
test_in_container()
Markdown is supported
0% or
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or sign in to comment