Abstract: In the 1990s, the constant error carousel and gating were introduced as the central ideas of the Long Short-Term Memory (LSTM). Since then, LSTMs have stood the test of time and contributed to numerous deep learning success stories; in particular, they constituted the first Large Language Models (LLMs). However, the advent of the Transformer technology with parallelizable self-attention at its core marked the dawn of a new era, outpacing LSTMs at scale. We now raise a simple question: How far do we get in language modeling when scaling LSTMs to billions of parameters, leveraging the latest techniques from modern LLMs, but mitigating known limitations of LSTMs? Firstly, we introduce exponential gating with appropriate normalization and stabilization techniques. Secondly, we modify the LSTM memory structure, obtaining: (i) sLSTM with a scalar memory, a scalar update, and new memory mixing, (ii) mLSTM that is fully parallelizable with a matrix memory and a covariance update rule. Integrating these LSTM extensions into residual block backbones yields xLSTM blocks that are then residually stacked into xLSTM architectures. Exponential gating and modified memory structures boost xLSTM capabilities to perform favorably when compared to state-of-the-art Transformers and State Space Models, both in performance and scaling.
https://arxiv.org/abs/2405.04517
By applying the element-wise exponential input gate activation function naively, we obtain the unstabilized gate activation matrix $D \in \mathbb{R}^{T \times T}$ as

$D = F \odot \exp(I)$ .   (73)

In order to avoid overflow due to the exponential function, we apply the same stabilization as in the recurrent sLSTM, see Equation 15. In the parallel formulation of the mLSTM we obtain a numerically stable gate activation matrix $D' \in \mathbb{R}^{T \times T}$ by taking the logarithm of $D$ element-wise and subtracting the row-wise maximum of $\tilde{D}$ from each element:

$\tilde{D} = \log D = \log \big( F \odot \exp(I) \big) = \log F + I$ ,   (74)

$D' = \exp \big( \tilde{D} - \max \tilde{D} \big)$ .   (75)

Given the queries, keys, and values $Q, K, V \in \mathbb{R}^{T \times d}$, for a full sequence we can compute all hidden pre-activation states $\tilde{H} \in \mathbb{R}^{T \times d}$ in parallel for the unstabilized version by

$\tilde{H} = C \, V$ , with $C = \dfrac{\tilde{C}}{\max \left( \left| \sum_{j=1}^{T} \tilde{C}_{ij} \right| , \, 1 \right)}$ , and $\tilde{C} = \dfrac{Q K^{\top}}{\sqrt{d}} \odot D$ .   (76)

Note that we extract the $\tfrac{1}{\sqrt{d}}$ factor for $K$ explicitly here and further on.
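To make the log-space stabilization concrete, here is a minimal NumPy sketch of Equations (73)-(75). Since the construction of the forget gate matrix $F$ (Equation 72) is not part of this excerpt, the cumulative-sum construction of $\log F$ from per-step log forget gates, as well as the names `gate_matrices`, `log_f`, and `i_pre`, are illustrative assumptions.

```python
import numpy as np

def gate_matrices(log_f: np.ndarray, i_pre: np.ndarray):
    """Gate activation matrices for one sequence.

    log_f : (T,) per-step log forget gate activations (assumed given).
    i_pre : (T,) per-step input gate pre-activations.
    Returns the naive D (Eq. 73), the stabilized D' (Eq. 75),
    D_tilde = log F + I (Eq. 74), and the row-wise maxima m.
    """
    T = log_f.shape[0]
    # Assumed construction of log F: entry (t, j) is the sum of log f_k for
    # k = j+1, ..., t when j <= t, and -inf above the diagonal (lower-triangular F).
    cs = np.cumsum(log_f)
    log_F = cs[:, None] - cs[None, :]
    log_F[np.triu_indices(T, k=1)] = -np.inf

    D_tilde = log_F + i_pre[None, :]           # Eq. (74): log D = log F + I
    D_naive = np.exp(D_tilde)                  # Eq. (73): elementwise F * exp(I); can overflow
    m = D_tilde.max(axis=1, keepdims=True)     # row-wise maximum of D_tilde
    D_stab = np.exp(D_tilde - m)               # Eq. (75): entries in (0, 1], no overflow
    return D_naive, D_stab, D_tilde, m
```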
For the stabilized version this yields

$\tilde{H} = C \, V$ , with $C = \dfrac{\tilde{C}}{\max \left( \left| \sum_{j=1}^{T} \tilde{C}_{ij} \right| , \, \exp(-\max \tilde{D}) \right)}$ , and $\tilde{C} = \dfrac{Q K^{\top}}{\sqrt{d}} \odot D'$ ,   (77)

where for both versions the hidden pre-activation states $\tilde{H}$ are identical. With the output gate pre-activations $\tilde{O} \in \mathbb{R}^{T \times d}$ we can compute the hidden states $H \in \mathbb{R}^{T \times d}$ for all timesteps by applying the output gate in parallel for each timestep element-wise:

$H = \sigma(\tilde{O}) \odot \tilde{H}$ .   (78)

This gives the parallel forward pass of the mLSTM for a full input sequence $X \in \mathbb{R}^{T \times d}$.
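Continuing the sketch above, the following lines check numerically that the unstabilized normalizer of Equation (76) and the stabilized normalizer of Equation (77) produce identical pre-activations $\tilde{H}$, and then apply the output gate of Equation (78). The random inputs, the sigmoid forget gate, and the reuse of `gate_matrices` are assumptions for illustration only.

```python
import numpy as np

def sigmoid(x):
    return 1.0 / (1.0 + np.exp(-x))

rng = np.random.default_rng(0)
T, d = 8, 4
Q, K, V = rng.normal(size=(3, T, d))
O_pre = rng.normal(size=(T, d))                # output gate pre-activations
log_f = np.log(sigmoid(rng.normal(size=T)))    # assumed sigmoid forget gates
i_pre = rng.normal(size=T)                     # exponential input gate pre-activations

D_naive, D_stab, D_tilde, m = gate_matrices(log_f, i_pre)
S = (Q @ K.T) / np.sqrt(d)                     # scores with the 1/sqrt(d) factor extracted

# Unstabilized version, Eq. (76): normalizer lower-bounded by 1.
C_u = S * D_naive
H_u = (C_u / np.maximum(np.abs(C_u.sum(axis=1, keepdims=True)), 1.0)) @ V

# Stabilized version, Eq. (77): normalizer lower-bounded by exp(-m).
C_s = S * D_stab
H_s = (C_s / np.maximum(np.abs(C_s.sum(axis=1, keepdims=True)), np.exp(-m))) @ V

assert np.allclose(H_u, H_s)                   # identical pre-activations H_tilde
H = sigmoid(O_pre) * H_s                       # Eq. (78): element-wise output gate
```

The identity holds because the row-wise factor $\exp(-m)$ scales the numerator and the normalizer of row $i$ equally, so it cancels.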
Parallel mLSTM Backward Pass. We present the backward pass of the mLSTM for the stabilized version only. For completeness, we summarize the forward pass in the stabilized version before we present the backward pass. Given the forget gate matrix $F \in \mathbb{R}^{T \times T}$ and the input gate matrix $I \in \mathbb{R}^{T \times T}$ as introduced above, the logarithm of the forget gate matrix $\tilde{F} = \log F \in \mathbb{R}^{T \times T}$, together with the queries, keys, and values $Q, K, V \in \mathbb{R}^{T \times d}$, we can write the forward pass of the mLSTM in the stabilized version as:

$\tilde{D} = \tilde{F} + I$ ,   (79)

$m = \max_j \tilde{D}_{ij}$ ,  row-wise maximum,   (80)

$D' = \exp \big( \tilde{D} - m \otimes \mathbf{1} \big)$ ,   (81)

$\tilde{C} = \dfrac{Q K^{\top}}{\sqrt{d}} \odot D'$ ,   (82)

$b = \tilde{C} \, \mathbf{1} = \sum_{j=1}^{T} \tilde{C}_{ij}$ ,   (83)

$n = \max \big( |b| , \, \exp(-m) \big)$ ,   (84)

$C = \tilde{C} \odot \big( n^{-1} \otimes \mathbf{1} \big)$ ,   (85)

$\tilde{H} = C \, V$ .   (86)
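A compact sketch of this stabilized forward pass, Equations (79)-(86), written so that it also returns the intermediate quantities reused by the backward pass below; the function name and the convention of passing $\tilde{F} = \log F$ and $I$ as precomputed $T \times T$ matrices are assumptions.

```python
import numpy as np

def mlstm_parallel_forward(log_F, I, Q, K, V):
    """Stabilized parallel mLSTM forward pass, Eqs. (79)-(86).

    log_F, I : (T, T) log forget gate matrix (with -inf above the diagonal)
               and input gate matrix.
    Q, K, V  : (T, d) queries, keys, and values.
    Returns the pre-activations H_tilde and the intermediates for the backward pass.
    """
    d = Q.shape[1]
    D_tilde = log_F + I                            # Eq. (79)
    m = D_tilde.max(axis=1, keepdims=True)         # Eq. (80), row-wise maximum
    D = np.exp(D_tilde - m)                        # Eq. (81), stabilized D'
    C_tilde = (Q @ K.T) / np.sqrt(d) * D           # Eq. (82)
    b = C_tilde.sum(axis=1, keepdims=True)         # Eq. (83), row sums
    n = np.maximum(np.abs(b), np.exp(-m))          # Eq. (84)
    C = C_tilde / n                                # Eq. (85), column-wise broadcast of 1/n
    H_tilde = C @ V                                # Eq. (86)
    return H_tilde, (D, C_tilde, b, n, C, m)
```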
We denote the gradient with respect to a variable $a$ as $\delta a$. Given the output gradient $\delta \tilde{H} \in \mathbb{R}^{T \times d}$, we can compute the backward pass for the intermediate gradients as:

$\delta C = \delta \tilde{H} \, V^{\top}$ ,   (87)

$\delta n = - \big( \big( \delta C \odot \tilde{C} \big) \odot \big( n^{-2} \otimes \mathbf{1} \big) \big) \, \mathbf{1}$ ,   (88)

$\dfrac{\partial n}{\partial |b|} = \begin{cases} 1 & \text{if } |b| > \exp(-m) \\ 0 & \text{otherwise} \end{cases}$ ,   (89)

$\delta b = \operatorname{sign}(b) \odot \delta n \odot \dfrac{\partial n}{\partial |b|}$ ,   (90)

$\delta \tilde{C}_{,C} = \delta C \odot \big( n^{-1} \otimes \mathbf{1} \big)$ ,  column-wise broadcast,   (91)

$\delta \tilde{C}_{,b} = \delta b \otimes \mathbf{1}$ ,  column-wise broadcast,   (92)

$\delta \tilde{C} = \delta \tilde{C}_{,C} + \delta \tilde{C}_{,b}$ ,   (93)

$\delta D' = \dfrac{Q K^{\top}}{\sqrt{d}} \odot \delta \tilde{C}$ ,   (94)

$\delta \tilde{D} = \delta D' \odot D' = \delta D' \odot \exp \big( \tilde{D} - m \otimes \mathbf{1} \big)$ .   (95)
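The corresponding intermediate gradients, Equations (87)-(95), can be sketched as follows; the gradients with respect to $Q$, $K$, $V$ and the gate matrices follow from these but lie outside this excerpt, and all names are again illustrative.

```python
import numpy as np

def mlstm_parallel_backward_intermediates(dH_tilde, Q, K, V, intermediates):
    """Intermediate gradients of the stabilized parallel mLSTM, Eqs. (87)-(95).

    dH_tilde      : (T, d) gradient of the loss w.r.t. the pre-activations H_tilde.
    intermediates : tuple (D', C_tilde, b, n, C, m) from mlstm_parallel_forward.
    """
    D, C_tilde, b, n, C, m = intermediates
    d = Q.shape[1]
    dC = dH_tilde @ V.T                                        # Eq. (87)
    dn = -((dC * C_tilde) / n**2).sum(axis=1, keepdims=True)   # Eq. (88)
    dn_dabs_b = (np.abs(b) > np.exp(-m)).astype(dC.dtype)      # Eq. (89)
    db = np.sign(b) * dn * dn_dabs_b                           # Eq. (90)
    dC_tilde_C = dC / n                                        # Eq. (91), column-wise broadcast
    dC_tilde_b = db * np.ones_like(C_tilde)                    # Eq. (92), column-wise broadcast
    dC_tilde = dC_tilde_C + dC_tilde_b                         # Eq. (93)
    dDp = (Q @ K.T) / np.sqrt(d) * dC_tilde                    # Eq. (94), gradient w.r.t. D'
    dD_tilde = dDp * D                                         # Eq. (95), since D' = exp(D_tilde - m)
    return dC, dn, db, dC_tilde, dDp, dD_tilde
```

Gradient-checking these quantities against finite differences of the forward sketch above is a quick way to validate the signs and the broadcast directions.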