# Licensed to the Apache Software Foundation (ASF) under one
# or more contributor license agreements. See the NOTICE file
# distributed with this work for additional information
# regarding copyright ownership. The ASF licenses this file
# to you under the Apache License, Version 2.0 (the
# "License"); you may not use this file except in compliance
# with the License. You may obtain a copy of the License at
#
#   http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing,
# software distributed under the License is distributed on an
# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
# KIND, either express or implied. See the License for the
# specific language governing permissions and limitations
# under the License.
# pylint: disable=invalid-name
"""
Utilities for building Relay loops.
"""
from .scope_builder import ScopeBuilder
from . import expr as _expr


def while_loop(cond, loop_vars, loop_bodies):
    """
    Build a while loop as a self-referencing Relay function.

    Parameters
    ----------
    cond: Callable[Tuple[relay.Expr], relay.Expr]
        The condition of the loop. It is called with the fresh loop
        variables and its result guards the recursive branch.

    loop_vars: Tuple[relay.Expr]
        The variables being looped over. These are the initial values
        of the loop; in this function they are only used to derive the
        names and type annotations of the fresh loop variables.

    loop_bodies: Callable[Tuple[relay.Expr], Tuple[relay.Expr]]
        The body of the loop. Given the loop variables it should
        produce the next values, also as a tuple.

    Returns
    -------
    loop: relay.Expr
        A ``Let`` expression that binds the loop function to the
        ``while_loop`` variable and has that variable as its body.
    """
    builder = ScopeBuilder()
    # Variable the loop function is bound to; also used for the recursive call.
    loop_fn = _expr.Var("while_loop")

    def _fresh_param(index, value):
        # Keep the name of an existing Var; otherwise fall back to a
        # positional placeholder name.
        if isinstance(value, _expr.Var):
            hint = value.name_hint
        else:
            hint = "arg{}".format(index)
        return _expr.var(hint, type_annotation=builder.type_of(value))

    params = [_fresh_param(index, value) for index, value in enumerate(loop_vars)]

    predicate = cond(*params)
    with builder.if_scope(predicate):
        # Condition holds: recurse with the values produced by the body.
        next_values = loop_bodies(*params)
        builder.ret(loop_fn(*next_values))
    with builder.else_scope():
        # Condition fails: return the current loop variables as a tuple.
        builder.ret(_expr.Tuple(params))

    loop_def = _expr.Function(params, builder.get())
    return _expr.Let(loop_fn, loop_def, loop_fn)