# Licensed to the Apache Software Foundation (ASF) under one
# or more contributor license agreements.  See the NOTICE file
# distributed with this work for additional information
# regarding copyright ownership.  The ASF licenses this file
# to you under the Apache License, Version 2.0 (the
# "License"); you may not use this file except in compliance
# with the License.  You may obtain a copy of the License at
#
#   http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing,
# software distributed under the License is distributed on an
# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
# KIND, either express or implied.  See the License for the
# specific language governing permissions and limitations
# under the License.
"""
Implementation of a Long Short-Term Memory (LSTM) cell.

Adapted from:
https://gist.github.com/merrymercy/5eb24e3b019f84200645bd001e9caae9
"""

from tvm import relay
from . import layers
from .init import create_workload


def lstm_cell(num_hidden, batch_size=1, dtype="float32", name=""):
    """Long Short-Term Memory (LSTM) network cell.

    Parameters
    ----------
    num_hidden : int
        Number of units in the output symbol.

    batch_size : int
        Batch size (length of states).

    dtype : str
        Data type of the tensors in the cell.

    name : str
        Prefix used for the names of the cell's dense layers.

    Returns
    -------
    result : tvm.relay.Function
        A Relay function that evaluates an LSTM cell.
        The function takes in a tensor of input data, a tuple of two
        states, and weights and biases for dense operations on the inputs
        and on the state.  It returns a tuple with two members, an output
        tensor and a tuple of two new states.
    """
    builder = relay.ScopeBuilder()

    input_type = relay.TensorType((batch_size, num_hidden), dtype)
    weight_type = relay.TensorType((4 * num_hidden, num_hidden), dtype)
    bias_type = relay.TensorType((4 * num_hidden,), dtype)

    dense_type = relay.TensorType((batch_size, 4 * num_hidden), dtype)
    slice_type = relay.TupleType([input_type, input_type, input_type, input_type])
    ret_type = relay.TupleType([input_type, relay.TupleType([input_type, input_type])])

    inputs = relay.Var("inputs", input_type)
    states = relay.Var("states", relay.TupleType([input_type, input_type]))

    i2h_weight = relay.Var("i2h_weight", weight_type)
    i2h_bias = relay.Var("i2h_bias", bias_type)

    h2h_weight = relay.Var("h2h_weight", weight_type)
    h2h_bias = relay.Var("h2h_bias", bias_type)

    i2h = builder.let(("i2h", dense_type),
                      layers.dense_add_bias(
                          data=inputs,
                          units=num_hidden * 4,
                          weight=i2h_weight,
                          bias=i2h_bias,
                          name="%si2h" % name))
    h2h = builder.let(("h2h", dense_type),
                      layers.dense_add_bias(
                          data=relay.TupleGetItem(states, 0),
                          units=num_hidden * 4,
                          weight=h2h_weight,
                          bias=h2h_bias,
                          name="%sh2h" % name))

    gates = builder.let(("gates", dense_type), relay.add(i2h, h2h))
    slice_gates = builder.let(("slice_gates", slice_type),
                              relay.split(gates,
                                          indices_or_sections=4,
                                          axis=1).astuple())

    in_gate = builder.let(("in_gate", input_type),
                          relay.sigmoid(relay.TupleGetItem(slice_gates, 0)))
    forget_gate = builder.let(("forget_gate", input_type),
                              relay.sigmoid(relay.TupleGetItem(slice_gates, 1)))
    in_transform = builder.let(("in_transform", input_type),
                               relay.tanh(relay.TupleGetItem(slice_gates, 2)))
    out_gate = builder.let(("out_gate", input_type),
                           relay.sigmoid(relay.TupleGetItem(slice_gates, 3)))

    next_c = builder.let(("next_c", input_type),
                         relay.add(relay.multiply(forget_gate,
                                                  relay.TupleGetItem(states, 1)),
                                   relay.multiply(in_gate, in_transform)))
    next_h = builder.let(("next_h", input_type),
                         relay.multiply(out_gate, relay.tanh(next_c)))
    ret = builder.let(("ret", ret_type),
                      relay.Tuple([next_h, relay.Tuple([next_h, next_c])]))
    builder.ret(ret)

    body = builder.get()

    return relay.Function([inputs, states, i2h_weight, i2h_bias, h2h_weight, h2h_bias],
                          body, ret_type)
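
# Illustrative sketch (not part of the original module): how the function
# returned by ``lstm_cell`` is typically wired into a larger Relay graph by
# calling it with input, state, weight, and bias variables.  The helper name
# ``_example_cell_call`` and the sizes used below are hypothetical.
def _example_cell_call(num_hidden=8, batch_size=1, dtype="float32"):
    """Build a single-step LSTM graph by calling the cell function directly."""
    input_type = relay.TensorType((batch_size, num_hidden), dtype)
    weight_type = relay.TensorType((4 * num_hidden, num_hidden), dtype)
    bias_type = relay.TensorType((4 * num_hidden,), dtype)

    # Free variables for the data and the cell parameters.
    data = relay.Var("data", input_type)
    i2h_weight = relay.Var("i2h_weight", weight_type)
    i2h_bias = relay.Var("i2h_bias", bias_type)
    h2h_weight = relay.Var("h2h_weight", weight_type)
    h2h_bias = relay.Var("h2h_bias", bias_type)

    # Initial hidden and cell states are all zeros.
    states = relay.Tuple([relay.zeros((batch_size, num_hidden), dtype)] * 2)

    # The cell is a Relay Function; calling it yields (output, (next_h, next_c)).
    cell = lstm_cell(num_hidden, batch_size, dtype, "demo_")
    call = relay.Call(cell, [data, states, i2h_weight, i2h_bias,
                             h2h_weight, h2h_bias])
    out = relay.TupleGetItem(call, 0)  # keep only the output tensor
    return relay.Function(relay.analysis.free_vars(out), out)
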
def get_net(iterations, num_hidden, batch_size=1, dtype="float32"):
    '''Constructs an unrolled RNN with LSTM cells'''
    input_type = relay.TensorType((batch_size, num_hidden), dtype)
    weight_type = relay.TensorType((4 * num_hidden, num_hidden), dtype)
    bias_type = relay.TensorType((4 * num_hidden,), dtype)

    state_type = relay.TupleType([input_type, input_type])
    cell_type = relay.TupleType([input_type, state_type])

    builder = relay.ScopeBuilder()

    zeros = builder.let(("zeros", input_type),
                        relay.zeros((batch_size, num_hidden), dtype))
    init_states = builder.let(("init_states", state_type),
                              relay.Tuple([zeros, zeros]))

    states = init_states
    out = None

    for i in range(iterations):
        inputs = relay.Var("data", input_type)
        i2h_weight = relay.Var("i2h_%s_weight" % i, weight_type)
        i2h_bias = relay.Var("i2h_%s_bias" % i, bias_type)
        h2h_weight = relay.Var("h2h_%s_weight" % i, weight_type)
        h2h_bias = relay.Var("h2h_%s_bias" % i, bias_type)

        cell_fn = lstm_cell(num_hidden, batch_size, dtype, "lstm_%s" % i)

        call = builder.let(("call_%s" % i, cell_type),
                           relay.Call(cell_fn,
                                      [inputs, states, i2h_weight,
                                       i2h_bias, h2h_weight, h2h_bias]))
        new_out = builder.let(("out_%s" % i, input_type),
                              relay.TupleGetItem(call, 0))
        new_states = builder.let(("states_%s" % i, state_type),
                                 relay.TupleGetItem(call, 1))

        states = new_states
        out = new_out

    builder.ret(out)

    body = builder.get()
    args = relay.analysis.free_vars(body)
    return relay.Function(args, body, input_type)


def get_workload(iterations, num_hidden, batch_size=1, dtype="float32"):
    """Get benchmark workload for an LSTM RNN.

    Parameters
    ----------
    iterations : int
        The number of iterations in the desired LSTM RNN.

    num_hidden : int
        The size of the hidden state.

    batch_size : int, optional (default 1)
        The batch size used in the model.

    dtype : str, optional (default "float32")
        The data type.

    Returns
    -------
    mod : tvm.relay.Module
        The Relay module that contains an LSTM network.

    params : dict of str to NDArray
        The parameters.
    """
    net = get_net(iterations, num_hidden, batch_size, dtype)
    return create_workload(net)
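

if __name__ == "__main__":
    # Minimal usage sketch (not part of the original module): construct the
    # benchmark workload and inspect it.  The iteration count and hidden size
    # below are arbitrary choices for illustration.
    mod, params = get_workload(iterations=5, num_hidden=16)
    print(mod)  # the unrolled Relay program
    for param_name, arr in params.items():
        print(param_name, arr.shape)  # randomly initialized weights and biases
    # The module can then be compiled, e.g. with
    # relay.build(mod, "llvm", params=params), on an LLVM-enabled TVM build.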