We analyse the inconsistencies between the description of the BART model in the literature and the actual BART implementation.
In section 2.1 of the BART paper, the authors state that BART uses the standard Transformer architecture, except for the activation function and the initialisation. We examine other notable differences between BART and the standard Transformer below.
Layer Norm is calculated by this formula:

$$\mathrm{LayerNorm}(x) = \frac{x - \mathrm{E}[x]}{\sqrt{\mathrm{Var}[x] + \epsilon}} \cdot \gamma + \beta$$

in which $\gamma$ and $\beta$ are learnable parameters and $\epsilon$ is a small constant for numerical stability.
In section 7 of the Transformer paper, the authors state that the Transformer architecture is implemented in the tensorflow/tensor2tensor library. In that library, LayerNormalization does not contain learnable parameters.
Similar to Layer Norm, BART also has extra bias parameters for the Q, K and V projections in self-attention.
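Both differences can be checked directly on the Hugging Face checkpoint (a quick check on facebook/bart-base):

from transformers import BartModel

model = BartModel.from_pretrained('facebook/bart-base')
layer = model.encoder.layers[0]
# learnable Layer Norm parameters
print(layer.self_attn_layer_norm.weight.shape)  # torch.Size([768])
print(layer.self_attn_layer_norm.bias.shape)    # torch.Size([768])
# extra bias on the Q, K and V projections
print(layer.self_attn.q_proj.bias.shape)        # torch.Size([768])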
In Section 3.5 of the Transformer paper, it is said that the positional encoding is fixed and is calculated by sine and cosine functions:

$$PE_{(pos, 2i)} = \sin\left(pos / 10000^{2i / d_{\mathrm{model}}}\right)$$
$$PE_{(pos, 2i+1)} = \cos\left(pos / 10000^{2i / d_{\mathrm{model}}}\right)$$
In BART, however, the positional embedding is a learned parameter. The authors of BART seem to be aware of this, since they write in Section 3.4 of the BART paper that they update the BART positional embeddings in the first training stage for machine translation.
In BART, the positional encoding has an offset of 2: the token at position 0 uses the positional encoding at index 2, the token at position 1 uses index 3, and so on. The first two entries of the positional encoding are never used.
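The offset is visible in the shape of the learned positional embedding (a quick check on facebook/bart-base, whose maximum sequence length is 1024):

from transformers import BartModel

model = BartModel.from_pretrained('facebook/bart-base')
print(model.config.max_position_embeddings)        # 1024
# 1024 positions plus the 2 unused entries from the offset
print(model.encoder.embed_positions.weight.shape)  # torch.Size([1026, 768])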
BART uses tied word embeddings on top of the output of the final decoder layer. In the regular Transformer architecture, this layer is a separate linear layer used for classification over the vocabulary; in BART, however, its weight is the transpose of the word embedding matrix.
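In the Hugging Face implementation this tying can be observed directly (a quick check; tie_word_embeddings is enabled by default):

from transformers import BartForConditionalGeneration

model = BartForConditionalGeneration.from_pretrained('facebook/bart-base')
# the LM head shares its weight with the shared word embedding matrix
print(model.lm_head.weight is model.model.shared.weight)  # True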
BART has extra dropout after the activation function in the feed-forward layers, while the Transformer does not.
TODO: Confirm that the Transformer does not use this.
Related: https://stackoverflow.com/q/64904840
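This extra dropout corresponds to the activation_dropout field in the Hugging Face BartConfig (a quick check; we print the value rather than assert it):

from transformers import BartConfig

config = BartConfig.from_pretrained('facebook/bart-base')
# dropout applied after the activation function inside the feed-forward block
print(config.activation_dropout)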
The first word of the input is tokenized differently from the following words, because the following words carry a leading space:

from transformers import BartTokenizer

tokenizer = BartTokenizer.from_pretrained('facebook/bart-base')
inputs = tokenizer(['go go go'], return_tensors='np')
# the first 'go' maps to a different token id than the following ' go's
print(inputs.input_ids.tolist())
TODO: Does Chinese BART have this issue?
Related: https://discuss.huggingface.co/t/bpe-tokenizers-and-spaces-before-words/475
TODO: Add more information about our customised class, BartTokenizerWithoutOverflowEOS.
This section records the problems we encountered during our implementation of the BART model and the final solutions.
This issue is reported in huggingface/transformers#15559. As a consequence, we focus only on implementing bart-base in this project, not bart-large.
import torch

x = torch.tensor([[-1., 1.]])
# torch.std defaults to the unbiased estimator (denominator n - 1)
print(x.std(-1).numpy())  # [1.4142135]
# np.std defaults to the biased estimator (denominator n)
print(x.numpy().std(-1))  # [1.]
This is because the denominator in np.std is n, while in torch.std it is n - 1 (Bessel's correction). See pytorch/pytorch#1854 for details.
However, for the standard deviation in Layer Norm, the denominator is always n in either PyTorch or NumPy.
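A quick way to verify this in PyTorch: torch.nn.LayerNorm matches a manual computation only when the variance uses the biased estimator (unbiased=False):

import torch

x = torch.tensor([[-1., 1.]])
layer_norm = torch.nn.LayerNorm(2, elementwise_affine=False)
# reproduce LayerNorm by hand with the biased variance (denominator n)
manual = (x - x.mean(-1, keepdim=True)) / torch.sqrt(x.var(-1, unbiased=False, keepdim=True) + 1e-5)
print(layer_norm(x))  # tensor([[-1.0000,  1.0000]])
print(manual)         # identical values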
JAX uses bfloat16 for matrix multiplication on TPU by default, even if the data type is float32. See google/jax#9973 for details.
import jax.numpy as np

# the exact product of the significand digits: 0.4176 * 0.5996 = 0.25039296
print(4176 * 5996)  # 25039296

a = np.array(0.4176, dtype=np.float32)
b = np.array(0.5996, dtype=np.float32)
# even in float32 the result is slightly off; bfloat16 would lose far more digits
print((a * b).item())  # 0.25039297342300415
For neural network training, however, reducing the accuracy is worthwhile because it can significantly reduce the training time, according to Tom's comments in the above issue.
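If full float32 precision is needed for a particular computation, JAX lets you request it explicitly, either per call or globally (a minimal sketch):

import jax
import jax.numpy as np

a = np.ones((128, 128), dtype=np.float32)
b = np.ones((128, 128), dtype=np.float32)

# per-call: request full float32 precision for this matmul only
c = np.matmul(a, b, precision=jax.lax.Precision.HIGHEST)

# global: make float32 the default matmul precision
jax.config.update('jax_default_matmul_precision', 'float32')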
The weight matrix of a linear layer is stored transposed in PyTorch, but not in Flax. Therefore, when converting model parameters between PyTorch and Flax, the weight matrices always need to be transposed.
In Flax:
import flax.linen as nn
import jax.numpy as np
import jax.random as rand
linear = nn.Dense(5)
key = rand.PRNGKey(42)
params = linear.init(key, np.zeros((3,)))
# Flax stores the kernel as (in_features, out_features)
print(params['params']['kernel'].shape)  # (3, 5)
In PyTorch:
import torch.nn as nn
linear = nn.Linear(3, 5)
# PyTorch stores the weight as (out_features, in_features)
print(linear.weight.shape)  # (5, 3), not (3, 5)
This can cause sneaky bugs in bart-base, in which the Q, K and V matrices are square. If the matrices are not transposed, there is no shape error, but the results are silently incorrect.
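As a minimal illustration (the helper name is ours, not part of either library), converting a PyTorch Linear weight into a Flax Dense kernel requires an explicit transpose:

import torch

# hypothetical helper: PyTorch stores (out_features, in_features),
# while Flax expects (in_features, out_features)
def pt_linear_weight_to_flax_kernel(weight: torch.Tensor):
    return weight.detach().cpu().numpy().T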
BartTokenizerWithoutOverflowEOS: changed two behaviours:

- add_prefix_space=True
- without overflow EOS: huggingface/transformers#19742
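A sketch of what the first behaviour changes, using the stock BartTokenizer for comparison (the customised class itself is project-specific):

from transformers import BartTokenizer

default = BartTokenizer.from_pretrained('facebook/bart-base')
prefixed = BartTokenizer.from_pretrained('facebook/bart-base', add_prefix_space=True)

# with add_prefix_space=True, the first word is tokenized as if it were
# preceded by a space, consistent with the words that follow it
print(default.tokenize('go go go'))   # ['go', 'Ġgo', 'Ġgo']
print(prefixed.tokenize('go go go'))  # ['Ġgo', 'Ġgo', 'Ġgo']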
The raw JAX random API is replaced with the project's wrapper functions (a usage sketch follows the list):

- rand.PRNGKey → random.wrapper.seed2key
- rand.split → random.wrapper.split_key
- rand.KeyArray → lib.random.wrapper.KeyArray
- ...; del key
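A sketch of the intended usage pattern, shown here with the raw JAX API (the del is the point: a key that has been consumed should never be reused):

import jax.random as rand

key = rand.PRNGKey(42)
key, subkey = rand.split(key)
x = rand.normal(subkey, (3,)); del subkey  # consume the subkey, then discard it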
Git
Create a new branch such as tuning/learning_rate, then make a series of commits:
- Change learning rate to xxx
- Change learning rate to xxx
- Change learning rate to xxx
- Change learning rate to xxx
After finding the best learning rate, squash and merge into the main branch, but also keep the old branch.
If committing on the main branch and the behaviour is unchanged (including the random seed), add [chore] to the commit title.