# Licensed to the Apache Software Foundation (ASF) under one
# or more contributor license agreements.  See the NOTICE file
# distributed with this work for additional information
# regarding copyright ownership.  The ASF licenses this file
# to you under the Apache License, Version 2.0 (the
# "License"); you may not use this file except in compliance
# with the License.  You may obtain a copy of the License at
#
#   http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing,
# software distributed under the License is distributed on an
# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
# KIND, either express or implied.  See the License for the
# specific language governing permissions and limitations
# under the License.
"""The scope builder interface."""
from __future__ import absolute_import

from . import ty as _ty
from . import expr as _expr
from .._ffi import base as _base


class WithScope(object):
    """A wrapper for builder methods which introduce scoping.

    Parameters
    ----------
    enter_value: object
        The value returned by enter.

    exit_cb: Callable[[], None]
        Callback invoked with no arguments when the ``with`` block
        finishes without an exception.
    """

    def __init__(self, enter_value, exit_cb):
        self._enter_value = enter_value
        self._exit_cb = exit_cb

    def __enter__(self):
        return self._enter_value

    def __exit__(self, ptype, value, trace):
        if ptype is not None:
            # An exception is in flight: skip the exit callback and return a
            # falsy value so the interpreter re-raises the original exception
            # with its traceback intact.  Testing ``ptype`` rather than the
            # truthiness of ``value`` also covers exception instances that
            # evaluate to False.
            return False
        self._exit_cb()
        return False


def _make_lets(bindings, ret_value):
    """Make a nested let expression.

    Parameters
    ----------
    bindings: List[Tuple[tvm.relay.Var,tvm.relay.Expr]]
        The sequence of let bindings

    ret_value: tvm.relay.Expr
        The final value of the expression.

    Returns
    -------
    lets: tvm.relay.Expr
        A nested let expression.

    Raises
    ------
    RuntimeError
        If ``ret_value`` is None, or if it is an ``If`` expression whose
        false branch has not been filled in.
    """
    if ret_value is None:
        raise RuntimeError("ret is not called in this scope")
    if isinstance(ret_value, _expr.If) and ret_value.false_branch is None:
        raise RuntimeError("Creating an If expression without else.")
    let_expr = ret_value
    # Wrap from the innermost binding outwards so that the first binding
    # ends up as the outermost Let.
    for var, value in reversed(bindings):
        let_expr = _expr.Let(var, value, let_expr)
    return let_expr


class ScopeBuilder(object):
    """Scope builder class.

    Enables users to build up a nested scope(let, if) expression easily.

    Examples
    --------
    .. code-block:: python

        sb = relay.ScopeBuilder()
        cond = relay.var("cond", 'bool')
        x = relay.var("x")
        y = relay.var("y")

        with sb.if_scope(cond):
            one = relay.const(1, "float32")
            t1 = sb.let("t1", relay.add(x, one))
            sb.ret(t1)
        with sb.else_scope():
            sb.ret(y)

        print(sb.get().astext())
    """

    def __init__(self):
        # Parallel stacks, one entry per open scope: the let bindings
        # collected so far and the return value (None until set).
        self._bindings = [[]]
        self._ret_values = [None]

    def _enter_scope(self):
        """Push a fresh, empty scope onto both stacks."""
        self._bindings.append([])
        self._ret_values.append(None)

    def _exit_scope(self):
        """Pop the innermost scope.

        Returns
        -------
        result: Tuple[list, object]
            The bindings and the return value recorded for that scope.
        """
        bindings = self._bindings.pop()
        ret_value = self._ret_values.pop()
        return bindings, ret_value

    def let(self, var, value):
        """Create a new let binding.

        Parameters
        ----------
        var: Union[Tuple[str, relay.Type], tvm.relay.Var]
            The variable or name of variable.

        value: tvm.relay.Expr
            The value to be bound

        Returns
        -------
        var: tvm.relay.Var
            The variable that was bound; newly created when ``var`` was
            given as a name or as a tuple/list.

        Raises
        ------
        ValueError
            If ``var`` is a tuple or list with more than two elements.
        """
        if isinstance(var, (tuple, list)):
            if len(var) > 2:
                raise ValueError("Expect var to be Tuple[str, relay.Type]")
            var = _expr.var(*var)
        elif isinstance(var, _base.string_types):
            var = _expr.var(var)
        self._bindings[-1].append((var, value))
        return var

    def if_scope(self, cond):
        """Create a new if scope.

        Parameters
        ----------
        cond: tvm.relay.expr.Expr
            The condition

        Returns
        -------
        scope: WithScope
            The if scope.

        Note
        ----
        The user must follow with an else scope.
        """
        self._enter_scope()

        def _on_exit():
            bindings, ret_value = self._exit_scope()
            if self._ret_values[-1] is not None:
                raise RuntimeError("result already returned before if scope")
            true_branch = _make_lets(bindings, ret_value)
            # Store a partial If (no false branch yet) as the enclosing
            # scope's return value; else_scope completes it.
            self._ret_values[-1] = _expr.If(cond, true_branch, None)

        return WithScope(None, _on_exit)

    def else_scope(self):
        """Create a new else scope.

        Returns
        -------
        scope: WithScope
            The else scope.
        """
        self._enter_scope()

        def _on_exit():
            bindings, ret_value = self._exit_scope()
            partial_if = self._ret_values[-1]
            # The enclosing scope must currently hold an If that is still
            # missing its false branch.
            no_else = (not isinstance(partial_if, _expr.If) or
                       partial_if.false_branch is not None)
            if no_else:
                raise RuntimeError("else scope must follows")
            false_branch = _make_lets(bindings, ret_value)
            self._ret_values[-1] = _expr.If(
                partial_if.cond,
                partial_if.true_branch,
                false_branch)

        return WithScope(None, _on_exit)

    def type_of(self, expr):
        """Compute the type of an expression.

        Parameters
        ----------
        expr: relay.Expr
            The expression to compute the type of.

        Returns
        -------
        ty: relay.Type
            The type annotation of ``expr`` when it is a variable.
            Otherwise a fresh incomplete type; as a side effect ``expr`` is
            let-bound in the current scope to a new variable named
            "unify" that is annotated with that incomplete type.
        """
        if isinstance(expr, _expr.Var):
            return expr.type_annotation

        ity = _ty.IncompleteType()
        var = _expr.var("unify", ity)
        self.let(var, expr)
        return ity

    def ret(self, value):
        """Set the return value of this scope.

        Parameters
        ----------
        value: tvm.relay.expr.Expr
            The return value.

        Raises
        ------
        RuntimeError
            If a return value has already been recorded for this scope.
        """
        if self._ret_values[-1] is not None:
            raise RuntimeError("ret value is already set in this scope.")
        self._ret_values[-1] = value

    def get(self):
        """Get the generated result.

        Returns
        -------
        value: tvm.relay.expr.Expr
            The final result of the expression.

        Raises
        ------
        RuntimeError
            If called while a nested scope is still open.
        """
        if len(self._bindings) != 1:
            raise RuntimeError("can only call get at the outmost scope")
        return _make_lets(self._bindings[-1], self._ret_values[-1])