# Licensed to the Apache Software Foundation (ASF) under one
# or more contributor license agreements.  See the NOTICE file
# distributed with this work for additional information
# regarding copyright ownership.  The ASF licenses this file
# to you under the Apache License, Version 2.0 (the
# "License"); you may not use this file except in compliance
# with the License.  You may obtain a copy of the License at
#
#   http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing,
# software distributed under the License is distributed on an
# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
# KIND, either express or implied.  See the License for the
# specific language governing permissions and limitations
# under the License.
from tvm.relay import var, const, create_executor
from tvm.relay.op import debug


# Module-level flag flipped to True by the ``debug_func`` callbacks in the
# tests below; each test resets it to False before evaluating its program and
# then asserts that it was set, proving the callback actually ran.
_test_debug_hit: bool = False

def test_debug():
    """Check that ``debug`` runs its callback and passes the value through.

    A callback that sets the module-level ``_test_debug_hit`` flag is attached
    to a ``debug`` node wrapping a scalar int32 variable. Evaluating the
    program with ``x = 1`` must flip the flag and produce 1.
    """
    global _test_debug_hit

    def mark_hit(_value):
        # Only records that the callback fired; the argument is not inspected.
        global _test_debug_hit
        _test_debug_hit = True

    executor = create_executor()
    x = var('x', shape=(), dtype='int32')
    _test_debug_hit = False  # clear any state left over from another test
    wrapped = debug(x, debug_func=mark_hit)
    bindings = {x: const(1, 'int32')}
    result = executor.evaluate(wrapped, bindings)
    assert _test_debug_hit
    assert result.asnumpy() == 1


def test_debug_with_expr():
    """Check that ``debug`` works when wrapping a compound expression.

    The debug node wraps ``x + x * x``. Evaluating it with ``x = 2`` must
    invoke the callback (setting the module-level ``_test_debug_hit`` flag)
    and return ``2 + 2 * 2 == 6``.
    """
    global _test_debug_hit
    ex = create_executor()
    x = var('x', shape=(), dtype='int32')
    # Reset exactly once, right before building the program, so a flag left
    # set by another test cannot mask a callback that never fires.
    _test_debug_hit = False

    def did_exec(_value):
        # Expected to be invoked during evaluation; the argument is not
        # inspected. Named ``_value`` to avoid shadowing the Relay var ``x``.
        global _test_debug_hit
        _test_debug_hit = True

    prog = debug(x + x * x, debug_func=did_exec)
    result = ex.evaluate(prog, {x: const(2, 'int32')})
    assert _test_debug_hit
    assert result.asnumpy() == 6