Commit befd8c1e by ziheng Committed by Tianqi Chen

[LANG] Comparison operators support for Imm expressions (#3283)

parent 072f8cc7
......@@ -349,6 +349,16 @@ class StringImm(ConstExpr):
self.__init_handle_by_constructor__(
_make.StringImm, value)
def __eq__(self, other):
    """Equality for StringImm, compared by the underlying string value.

    ``other`` may be another ConstExpr node (compared via its ``.value``)
    or a plain Python value (compared directly).
    """
    rhs = other.value if isinstance(other, ConstExpr) else other
    return self.value == rhs
def __ne__(self, other):
    """Inequality for StringImm; the logical inverse of ``__eq__``.

    Compares the underlying string value against either another
    ConstExpr's ``.value`` or a plain Python value.
    """
    rhs = other.value if isinstance(other, ConstExpr) else other
    return self.value != rhs
@register_node
class Cast(Expr):
......
......@@ -22,7 +22,6 @@ import numpy as np
from . import _quantize
from .. import expr as _expr
from .. import ir_pass as _ir_pass
from .. import transform as _transform
from .. import op as _op
from ... import make as _make
from ..base import NodeBase, register_relay_node
......@@ -301,8 +300,6 @@ def optimize(func, params=None):
"FoldConstant",
"CanonicalizeOps"]
cfg = _transform.build_config(required_pass=opt_passes)
if params:
name_dict = {}
for arg in func.params:
......@@ -321,25 +318,25 @@ def optimize(func, params=None):
bind_dict[arg] = _expr.const(v)
func = _expr.bind(func, bind_dict)
if "SimplifyInference" in cfg.required_pass:
if "SimplifyInference" in opt_passes:
func = _ir_pass.infer_type(func)
func = _ir_pass.simplify_inference(func)
if "FoldConstant" in cfg.required_pass:
if "FoldConstant" in opt_passes:
func = _ir_pass.fold_constant(func)
if "FoldScaleAxis" in cfg.required_pass:
if "FoldScaleAxis" in opt_passes:
func = _ir_pass.infer_type(func)
func = _ir_pass.backward_fold_scale_axis(func)
func = _ir_pass.infer_type(func)
func = _ir_pass.forward_fold_scale_axis(func)
func = _ir_pass.fold_constant(func)
if "CanonicalizeOps" in cfg.required_pass:
if "CanonicalizeOps" in opt_passes:
func = _ir_pass.infer_type(func)
func = _ir_pass.canonicalize_ops(func)
if "FoldConstant" in cfg.required_pass:
if "FoldConstant" in opt_passes:
func = _ir_pass.fold_constant(func)
return func
......
......@@ -108,8 +108,8 @@ bool BroadcastRel(const Array<Type>& types,
const Attrs& attrs,
const TypeReporter& reporter) {
CHECK_EQ(types.size(), 3);
DLOG(INFO) << "In1:" << types[0] << ",In2:" << types[1]
<< ",Out:" << types[2] << std::endl;
// DLOG(INFO) << "In1:" << types[0] << ",In2:" << types[1]
// << ",Out:" << types[2] << std::endl;
if (auto t0 = ToTensorType(types[0])) {
if (auto t1 = ToTensorType(types[1])) {
CHECK_EQ(t0->dtype, t1->dtype);
......@@ -126,8 +126,8 @@ bool BroadcastCompRel(const Array<Type>& types,
const Attrs& attrs,
const TypeReporter& reporter) {
CHECK_EQ(types.size(), 3);
DLOG(INFO) << "In1:" << types[0] << ",In2:" << types[1]
<< ",Out:" << types[2] << std::endl;
// DLOG(INFO) << "In1:" << types[0] << ",In2:" << types[1]
// << ",Out:" << types[2] << std::endl;
if (auto t0 = ToTensorType(types[0])) {
if (auto t1 = ToTensorType(types[1])) {
CHECK_EQ(t0->dtype, t1->dtype);
......
......@@ -163,6 +163,14 @@ def test_equality():
d = (c != c)
assert not d
def test_equality_string_imm():
    """Check that a StringImm compares equal to a plain Python string.

    The original body evaluated ``x == y.value`` and ``x == y`` but
    discarded both results, so the test could never fail. Assert them so
    a broken ``StringImm.__eq__`` is actually detected.
    """
    x = 'a'
    y = tvm.make.StringImm(x)
    assert x == y.value
    assert x == y
if __name__ == "__main__":
test_cast()
test_attr()
......@@ -178,3 +186,4 @@ if __name__ == "__main__":
test_all()
test_bitwise()
test_equality()
test_equality_string_imm()
......@@ -65,9 +65,16 @@ def test_map_save_load_json():
assert(dd == {"a": 2, "b": 3})
def test_in_container():
    """Membership (`in`) checks on a TVM array converted from strings."""
    arr = tvm.convert(['a', 'b', 'c'])
    # Both a plain string and an equivalent StringImm node must be found.
    for member in ('a', tvm.make.StringImm('a')):
        assert member in arr
    assert 'd' not in arr
if __name__ == "__main__":
    # Run every container test, in the same order, when executed as a script.
    for _case in (
        test_str_map,
        test_array,
        test_map,
        test_array_save_load_json,
        test_map_save_load_json,
        test_in_container,
    ):
        _case()
Markdown is supported
0% or
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or sign in to comment