DL/LSTM: explanation of the tf.contrib.rnn.BasicLSTMCell(rnn_unit) function
Contents
Explanation of the tf.contrib.rnn.BasicLSTMCell(rnn_unit) function
Function description
Function source code
Explanation of the tf.contrib.rnn.BasicLSTMCell(rnn_unit) function
Function description
| Original docstring | Explanation |
| --- | --- |
| """Basic LSTM recurrent network cell. The implementation is based on: http://arxiv.org/abs/1409.2329. We add forget_bias (default: 1) to the biases of the forget gate in order to reduce the scale of forgetting in the beginning of the training. It does not allow cell clipping, a projection layer, and does not use peep-hole connections: it is the basic baseline. For advanced models, please use the full @{tf.nn.rnn_cell.LSTMCell} that follows.""" | A basic LSTM recurrent network cell, implemented following http://arxiv.org/abs/1409.2329. A forget_bias (default 1.0) is added to the bias of the forget gate so that the cell forgets less at the start of training. The cell supports neither cell-state clipping nor a projection layer, and it does not use peephole connections; it is the basic baseline. For more advanced models, use the full `tf.nn.rnn_cell.LSTMCell`. |
| Args: … When restoring from CudnnLSTM-trained checkpoints, must use `CudnnCompatibleLSTMCell` instead. | Arguments (documented one by one in the `__init__` docstring of the source below). When restoring from a checkpoint trained with CudnnLSTM, `CudnnCompatibleLSTMCell` must be used instead of this cell. |
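
The forget_bias remark can be checked with a little arithmetic. In `call`, the previous cell state `c` is multiplied by `sigmoid(f + forget_bias)`, where `f` is the forget-gate pre-activation. The minimal sketch below makes the simplifying assumption that `f` is 0 (a stand-in for an untrained cell; in practice `f` depends on the randomly initialized kernel and the input) and compares how much of `c` survives one step for forget_bias = 0.0 and the default 1.0. The helpers `sigmoid` and `kept_fraction` are illustrative names, not TensorFlow functions.

```python
import math


def sigmoid(x: float) -> float:
    """Logistic sigmoid, the squashing function used for the LSTM gates."""
    return 1.0 / (1.0 + math.exp(-x))


def kept_fraction(f_preact: float, forget_bias: float) -> float:
    """Fraction of the previous cell state c kept by the forget gate.

    Mirrors the factor sigmoid(f + forget_bias) applied to c in
    BasicLSTMCell.call.
    """
    return sigmoid(f_preact + forget_bias)


# Assumption: forget-gate pre-activation f == 0 (untrained cell).
for bias in (0.0, 1.0):
    print(f"forget_bias={bias}: keeps {kept_fraction(0.0, bias):.4f} of c")
# forget_bias=0.0: keeps 0.5000 of c
# forget_bias=1.0: keeps 0.7311 of c
```

Under this assumption the default forget_bias of 1.0 keeps about 73% of `c` per step instead of 50%, which is the "reduce the scale of forgetting" effect the docstring describes. The docstring also asks for forget_bias to be set to `0.0` by hand when restoring CudnnLSTM-trained checkpoints.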
Function source code
```python
# Excerpt from tensorflow/python/ops/rnn_cell_impl.py (TensorFlow 1.x).
# LayerRNNCell, LSTMStateTuple, _WEIGHTS_VARIABLE_NAME ("kernel") and
# _BIAS_VARIABLE_NAME ("bias") are defined elsewhere in the same module.
from tensorflow.python.framework import constant_op
from tensorflow.python.framework import dtypes
from tensorflow.python.layers import base as base_layer
from tensorflow.python.ops import array_ops
from tensorflow.python.ops import init_ops
from tensorflow.python.ops import math_ops
from tensorflow.python.ops import nn_ops
from tensorflow.python.platform import tf_logging as logging
from tensorflow.python.util.tf_export import tf_export


@tf_export("nn.rnn_cell.BasicLSTMCell")
class BasicLSTMCell(LayerRNNCell):
  """Basic LSTM recurrent network cell.

  The implementation is based on: http://arxiv.org/abs/1409.2329.

  We add forget_bias (default: 1) to the biases of the forget gate in order to
  reduce the scale of forgetting in the beginning of the training.

  It does not allow cell clipping, a projection layer, and does not
  use peep-hole connections: it is the basic baseline.

  For advanced models, please use the full @{tf.nn.rnn_cell.LSTMCell}
  that follows.
  """

  def __init__(self,
               num_units,
               forget_bias=1.0,
               state_is_tuple=True,
               activation=None,
               reuse=None,
               name=None,
               dtype=None):
    """Initialize the basic LSTM cell.

    Args:
      num_units: int, The number of units in the LSTM cell.
      forget_bias: float, The bias added to forget gates (see above).
        Must set to `0.0` manually when restoring from CudnnLSTM-trained
        checkpoints.
      state_is_tuple: If True, accepted and returned states are 2-tuples of
        the `c_state` and `m_state`. If False, they are concatenated
        along the column axis. The latter behavior will soon be deprecated.
      activation: Activation function of the inner states. Default: `tanh`.
      reuse: (optional) Python boolean describing whether to reuse variables
        in an existing scope. If not `True`, and the existing scope already has
        the given variables, an error is raised.
      name: String, the name of the layer. Layers with the same name will
        share weights, but to avoid mistakes we require reuse=True in such
        cases.
      dtype: Default dtype of the layer (default of `None` means use the type
        of the first input). Required when `build` is called before `call`.

      When restoring from CudnnLSTM-trained checkpoints, must use
      `CudnnCompatibleLSTMCell` instead.
    """
    super(BasicLSTMCell, self).__init__(_reuse=reuse, name=name, dtype=dtype)
    if not state_is_tuple:
      logging.warn("%s: Using a concatenated state is slower and will soon be "
                   "deprecated. Use state_is_tuple=True.", self)

    # Inputs must be 2-dimensional.
    self.input_spec = base_layer.InputSpec(ndim=2)

    self._num_units = num_units
    self._forget_bias = forget_bias
    self._state_is_tuple = state_is_tuple
    self._activation = activation or math_ops.tanh

  @property
  def state_size(self):
    return (LSTMStateTuple(self._num_units, self._num_units)
            if self._state_is_tuple else 2 * self._num_units)

  @property
  def output_size(self):
    return self._num_units

  def build(self, inputs_shape):
    if inputs_shape[1].value is None:
      raise ValueError("Expected inputs.shape[-1] to be known, saw shape: %s"
                       % inputs_shape)

    input_depth = inputs_shape[1].value
    h_depth = self._num_units
    self._kernel = self.add_variable(
        _WEIGHTS_VARIABLE_NAME,
        shape=[input_depth + h_depth, 4 * self._num_units])
    self._bias = self.add_variable(
        _BIAS_VARIABLE_NAME,
        shape=[4 * self._num_units],
        initializer=init_ops.zeros_initializer(dtype=self.dtype))

    self.built = True

  def call(self, inputs, state):
    """Long short-term memory cell (LSTM).

    Args:
      inputs: `2-D` tensor with shape `[batch_size, input_size]`.
      state: An `LSTMStateTuple` of state tensors, each shaped
        `[batch_size, num_units]`, if `state_is_tuple` has been set to
        `True`. Otherwise, a `Tensor` shaped `[batch_size, 2 * num_units]`.

    Returns:
      A pair containing the new hidden state, and the new state (either a
        `LSTMStateTuple` or a concatenated state, depending on
        `state_is_tuple`).
    """
    sigmoid = math_ops.sigmoid
    one = constant_op.constant(1, dtype=dtypes.int32)
    # Parameters of gates are concatenated into one multiply for efficiency.
    if self._state_is_tuple:
      c, h = state
    else:
      c, h = array_ops.split(value=state, num_or_size_splits=2, axis=one)

    gate_inputs = math_ops.matmul(
        array_ops.concat([inputs, h], 1), self._kernel)
    gate_inputs = nn_ops.bias_add(gate_inputs, self._bias)

    # i = input_gate, j = new_input, f = forget_gate, o = output_gate
    i, j, f, o = array_ops.split(
        value=gate_inputs, num_or_size_splits=4, axis=one)

    forget_bias_tensor = constant_op.constant(self._forget_bias, dtype=f.dtype)
    # Note that using `add` and `multiply` instead of `+` and `*` gives a
    # performance improvement. So using those at the cost of readability.
    add = math_ops.add
    multiply = math_ops.multiply
    new_c = add(multiply(c, sigmoid(add(f, forget_bias_tensor))),
                multiply(sigmoid(i), self._activation(j)))
    new_h = multiply(self._activation(new_c), sigmoid(o))

    if self._state_is_tuple:
      new_state = LSTMStateTuple(new_c, new_h)
    else:
      new_state = array_ops.concat([new_c, new_h], 1)
    return new_h, new_state
```
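
To make the `build` and `call` logic concrete outside a TensorFlow graph, here is a minimal NumPy sketch of a single time step. It is a hypothetical re-implementation for illustration only: the class `NumpyBasicLSTMCell`, its `seed` parameter and its `zero_state` helper are made-up names, not TensorFlow APIs, and it assumes a small Gaussian kernel initialization whereas TensorFlow uses its own default variable initializer. It does follow the source: the kernel has shape `[input_depth + num_units, 4 * num_units]`, the bias starts at zero, the gates are split in the order `i, j, f, o`, and the default `tanh` activation is used. The `rnn_unit` in the title is simply the value passed positionally as `num_units`.

```python
import numpy as np


def sigmoid(x):
    """Element-wise logistic sigmoid."""
    return 1.0 / (1.0 + np.exp(-x))


class NumpyBasicLSTMCell:
    """Hypothetical NumPy re-implementation of BasicLSTMCell's math.

    Not a TensorFlow API: it only mirrors the shapes and equations of
    BasicLSTMCell.build() and BasicLSTMCell.call() for illustration.
    """

    def __init__(self, num_units, input_depth, forget_bias=1.0,
                 state_is_tuple=True, seed=0):
        rng = np.random.default_rng(seed)
        self.num_units = num_units
        self.forget_bias = forget_bias
        self.state_is_tuple = state_is_tuple
        # Same shape as in build(): [input_depth + h_depth, 4 * num_units].
        # Assumption: small Gaussian init; TensorFlow uses its own default.
        self.kernel = rng.normal(scale=0.1,
                                 size=(input_depth + num_units, 4 * num_units))
        # build() initializes the bias with zeros_initializer.
        self.bias = np.zeros(4 * num_units)

    def zero_state(self, batch_size):
        """All-zero (c, h), tupled or concatenated per state_is_tuple."""
        c = np.zeros((batch_size, self.num_units))
        h = np.zeros((batch_size, self.num_units))
        return (c, h) if self.state_is_tuple else np.concatenate([c, h], axis=1)

    def __call__(self, inputs, state):
        """One time step; inputs has shape [batch_size, input_depth]."""
        if self.state_is_tuple:
            c, h = state
        else:
            c, h = np.split(state, 2, axis=1)
        # One matmul for all four gates, then add the bias.
        gate_inputs = np.concatenate([inputs, h], axis=1) @ self.kernel + self.bias
        # i = input_gate, j = new_input, f = forget_gate, o = output_gate
        i, j, f, o = np.split(gate_inputs, 4, axis=1)
        new_c = c * sigmoid(f + self.forget_bias) + sigmoid(i) * np.tanh(j)
        new_h = np.tanh(new_c) * sigmoid(o)
        if self.state_is_tuple:
            new_state = (new_c, new_h)
        else:
            new_state = np.concatenate([new_c, new_h], axis=1)
        return new_h, new_state


# Usage: a batch of 4 inputs with 2 features each, 3 LSTM units.
cell = NumpyBasicLSTMCell(num_units=3, input_depth=2)
out, (c, h) = cell(np.ones((4, 2)), cell.zero_state(4))
print(out.shape, c.shape)  # (4, 3) (4, 3)
```

Because `state_is_tuple=False` only changes how `c` and `h` are packed (concatenated along axis 1 into a `[batch_size, 2 * num_units]` array), both settings produce the same output for the same weights; the source itself warns that the concatenated form is slower and will soon be deprecated.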