Training for Single Correct Prediction Makes Your Model More Fragile

The abilities of language models become even more striking when we consider how simplified a training signal we provide them with. In this blog, we'll look beyond the default and show the potential of training objectives that reflect the semantics of your task more deeply.

Michal Štefánik
Jun 18, 2024 · 7 min read
In reality, a correct answer can take different forms, yet our training objectives don't leave space for ambiguity. Picture from DALL-E.

Warmup: How do we train generative models?

Since the rise of the first generative language models in machine translation, we have pre-trained and fine-tuned language models to predict a single next token from some reference text, given the input text and the previous tokens:

next_token = Model(prev_tokens, prompt)

Casting this into a training objective is relatively straightforward: with any parametrized model M whose output classes correspond to the tokens of its vocabulary, we maximise the probability the model assigns to the output class of the correct next token. Usually, we do that by minimising the cross-entropy between the distribution y = M(prev_tokens, prompt) that the model M produces and a "true" distribution Y = (0, 0, …, 0, 1, 0, …, 0), which assigns p = 1 to the true following token and p = 0 to every other token in M's vocabulary. We assume that this is the output distribution we want our model to produce.
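
Concretely, a minimal PyTorch sketch of this one-hot cross-entropy setup could look as follows (the vocabulary size and token indices are toy values, not an actual implementation):

import torch
import torch.nn.functional as F

# Toy setup: a vocabulary of 6 tokens and a single prediction step.
vocab_size = 6
logits = torch.randn(1, vocab_size)   # raw scores from M(prev_tokens, prompt)
true_next_token = torch.tensor([3])   # index of the correct next token

# One-hot target distribution Y = (0, ..., 0, 1, 0, ..., 0)
Y = F.one_hot(true_next_token, num_classes=vocab_size).float()

# Cross-entropy between the model distribution y = softmax(logits) and Y.
log_probs = F.log_softmax(logits, dim=-1)
loss_manual = -(Y * log_probs).sum(dim=-1).mean()

# In practice, we pass the token index directly; the result is the same.
loss_builtin = F.cross_entropy(logits, true_next_token)
assert torch.allclose(loss_manual, loss_builtin)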

A nice feature of this formulation is that it can be applied in a myriad of settings: even if we don't have reference texts (i.e. labels), we can simply use the input as a reference and train the model to predict each following token given just the preceding tokens of the input.

Note how efficiently we can do this: to construct the targets, we segment the reference text into tokens and just assign "1" to the correct positions in an all-zeros matrix. If the reference text is also the input text (the case of causal language modeling), we can first encode the model input and then create the labels by simply shifting the inputs by one position, so that each position's label is the token that follows it.
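
For illustration, here is a minimal sketch of this label construction for causal language modeling (the token ids below are made up):

import torch

# Toy token ids for one training example in causal language modeling,
# where the input text doubles as the reference text.
input_ids = torch.tensor([[101, 45, 872, 19, 6, 102]])

# The target at position i is the token at position i + 1, so the labels
# are simply the inputs shifted by one position.
model_inputs = input_ids[:, :-1]   # tokens the model conditions on
labels = input_ids[:, 1:]          # the "correct next token" at every position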

Fig: We commonly train language models to assign all probability to a single correct token from the reference.

So, are there any downsides?

You might have noticed that the described objective ignores the fact that not all bad predictions are equally bad. On the token level, for most generative tasks, predicting a synonym is as good as the correct prediction and certainly better than the vast majority of other tokens, which nevertheless get the same p = 0 as the synonyms. On the sequence level, an arbitrary paraphrase of the reference would fare comparably well to the reference for most tasks (except, perhaps, paraphrasing itself).
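
To make this concrete: with a one-hot target, the loss depends only on the probability assigned to the reference token, so it cannot distinguish a model that spends its remaining probability mass on a synonym from one that spends it on an unrelated word. A toy illustration (the vocabulary and probabilities are made up):

import torch
import torch.nn.functional as F

# Hypothetical 4-token vocabulary: ["happy", "glad", "table", "runs"]
reference = torch.tensor([0])   # the reference next token is "happy"

# Both models assign p = 0.3 to "happy" but distribute the rest differently:
favours_synonym = torch.tensor([[0.3, 0.60, 0.05, 0.05]])   # most mass on "glad"
favours_garbage = torch.tensor([[0.3, 0.05, 0.60, 0.05]])   # most mass on "table"

loss_synonym = F.nll_loss(favours_synonym.log(), reference)
loss_garbage = F.nll_loss(favours_garbage.log(), reference)
print(loss_synonym, loss_garbage)   # identical: the objective cannot tell them apart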

Training models for a single correct prediction neglects the semantic dimension of the task we hope the model will learn. But is this a problem?

With enough training data, not reflecting the full complexity of the training problem induces some irreducible portion of the training loss: a model that fits the problem perfectly would correctly distribute its output probability over the ambiguous predictions, yet it would still not reach zero loss against one-hot targets. However, thanks to large enough training batches, this does not derail the training: with enough averaging, the updates remain clean enough. It just costs us more compute and a lot of data to fit the model.

But the oversimplification takes its toll when training for tasks with limited data. Here, models tend to overfit along two dimensions:

  1. Models overfit the target distribution, assigning almost 100% probability to each next token of the reference texts and close-to-zero probabilities to all other tokens. The resulting models are very fragile on unseen prompts, where they have to draw the next tokens from the close-to-zero probabilities that were never well calibrated in training.
  2. Models overfit the distribution of previous tokens from the reference. Using the reference for the previous tokens (also called teacher forcing) allows us to keep the prediction aligned (i.e. we are sure that the predicted token is always a reasonable continuation). From a machine learning perspective, this is problematic because, in actual generation, the previous tokens come from a different distribution: they are no longer drawn from the reference but from the model's own predictions. This difference between training and actual use is also called exposure bias.

A real question is: Can we do any better?

Unfortunately, most training datasets today provide a single reference for a given prompt and don't annotate the "true" ambiguity for the model to learn; that would be prohibitively expensive to annotate. The good news is that, specifically for language, we already have pretty good tools to represent and compare the semantics of words: they are called language models!

In our ACL 2023 paper Soft Alignment Objectives for Robust Adaptation of Language Generation, we propose to train models to represent ambiguity by constructing training targets which respect the mutual similarity of semantic token embeddings. These can be obtained from a pre-trained language model; following the example of BERTScore, we use XLM-RoBERTa, which also works well for non-English texts.

Fig: In the *Align objectives, we use the semantic similarity of tokens to construct the training target distribution.

We infer the new target distribution from the cosine similarity between the embedding of the reference token and the embeddings of all other tokens in the model vocabulary. As a result, the original reference token still gets p = 1, but synonyms will get p = 0.95 or 0.88, while irrelevant tokens will be assigned minimal probabilities such as 0.002, 0.0012, etc. We call this objective TokenAlign, as the targets are constructed based on how well the tokens "align" with the reference.
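
A simplified sketch of how such targets can be constructed follows. This illustrates the idea rather than the paper's exact implementation: in particular, the clamping and normalisation below are my own simplifications, and the embeddings would come from a pre-trained model such as XLM-RoBERTa rather than from torch.randn.

import torch
import torch.nn.functional as F

def token_align_targets(embeddings: torch.Tensor, reference_token: int) -> torch.Tensor:
    """Soft target distribution from cosine similarity to the reference token.

    `embeddings` is a (vocab_size, dim) matrix of semantic token embeddings.
    """
    ref = embeddings[reference_token]                                  # (dim,)
    sims = F.cosine_similarity(embeddings, ref.unsqueeze(0), dim=-1)  # (vocab_size,)
    sims = sims.clamp(min=0.0)     # drop negative similarities (a simplification)
    return sims / sims.sum()       # normalise into a probability distribution

vocab_size, dim = 1000, 64
embeddings = torch.randn(vocab_size, dim)   # stand-in for real semantic embeddings
targets = token_align_targets(embeddings, reference_token=42)

# Training step: cross-entropy against the soft targets instead of a one-hot vector.
logits = torch.randn(1, vocab_size)          # model's prediction for the next token
loss = -(targets * F.log_softmax(logits, dim=-1)).sum()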

Since TokenAlign is only a relatively small adjustment of the standard one-hot targets, it remains similarly compute-efficient, but it still cannot mitigate the exposure bias (point 2) mentioned above. To mitigate the possible over-reliance on the training prefixes, in another objective called SeqAlign, we additionally propose to sample the prefixes from the model's own generations. Knowing the semantic similarity of all tokens in the vocabulary, we can again construct the target distribution from the similarity of the vocabulary tokens to the reference.
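
The following toy sketch illustrates the SeqAlign idea under strong simplifications: the "model" is a made-up scoring function, the embeddings are random, and the targets are built naively from the best-matching reference token. The actual objective in the paper is more involved.

import torch
import torch.nn.functional as F

torch.manual_seed(0)
vocab_size, dim = 50, 16
embeddings = torch.randn(vocab_size, dim)   # stand-in for semantic token embeddings
lm_head = torch.nn.Linear(dim, vocab_size)  # stand-in for an actual language model

def toy_next_token_logits(prefix_ids: torch.Tensor) -> torch.Tensor:
    """Toy "model": scores the next token from the mean embedding of the prefix."""
    return lm_head(embeddings[prefix_ids].mean(dim=0))

# 1) Instead of teacher-forcing the reference, sample a prefix from the model itself.
prefix_ids = torch.tensor([3, 17, 8])                       # the prompt
for _ in range(4):                                          # 4 self-generated tokens
    probs = F.softmax(toy_next_token_logits(prefix_ids), dim=-1)
    prefix_ids = torch.cat([prefix_ids, torch.multinomial(probs, 1)])

# 2) Build soft targets for the next step from the similarity of every vocabulary
#    token to its best-matching reference token.
reference_ids = torch.tensor([5, 22, 40, 11])
sims = F.cosine_similarity(embeddings.unsqueeze(1),                  # (vocab, 1, dim)
                           embeddings[reference_ids].unsqueeze(0),   # (1, ref_len, dim)
                           dim=-1)                                   # -> (vocab, ref_len)
targets = sims.max(dim=1).values.clamp(min=0)
targets = targets / targets.sum()

# 3) Cross-entropy of the next-token prediction (after the model's own prefix)
#    against the soft targets.
log_probs = F.log_softmax(toy_next_token_logits(prefix_ids), dim=-1)
loss = -(targets * log_probs).sum()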

Experiments: Can modeling ambiguity actually help?

We experiment with our Alignment objectives in low-resource domain adaptation for machine translation (MT), where we expect that a richer training signal could help compensate for the scarcity of data. Looking into machine translation also allows us to critically assess whether the quality of the semantic embeddings degrades when they are used for languages other than English.

We start with a general-purpose MT model for each language pair and adapt it on the data of each language+domain combination. The OPUS collection provides a wide catalogue of domains, and we try to pick a challengingly diverse, almost obscure set of domains available in all our target languages.
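
As a minimal illustration of this setup (the concrete checkpoints, language pairs, and domains in the paper may differ), the starting point can be any publicly available general-purpose MT model, e.g. from the Hugging Face hub:

from transformers import MarianMTModel, MarianTokenizer

# A hypothetical starting point: a general-purpose English-German MT model.
model_name = "Helsinki-NLP/opus-mt-en-de"
tokenizer = MarianTokenizer.from_pretrained(model_name)
model = MarianMTModel.from_pretrained(model_name)

# Domain adaptation then continues training on parallel data from a single
# OPUS domain; the *Align objectives replace the loss computed below.
batch = tokenizer(["The patient was administered 5 mg of the compound."],
                  text_target=["Dem Patienten wurden 5 mg der Substanz verabreicht."],
                  return_tensors="pt")
loss = model(**batch).loss   # standard MLE (cross-entropy) loss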

After the adaptation, we evaluate each adapted model for (1) improvements on the adapted domain (in-domain; ID) and (2) degradation on the other domains (out-of-domain; OOD), to uncover possible overfitting and to assess robustness to a broad variety of possible user queries.
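
A sketch of this evaluation protocol with sacrebleu (the data below is a toy stand-in; the experiments use held-out OPUS domain test sets and additionally report BERTScore):

import sacrebleu

# Toy stand-ins for the adapted model's translations and the references,
# one pair of lists per evaluated domain.
evaluation_sets = {
    "in-domain":     (["the model translation"], ["the reference translation"]),
    "out-of-domain": (["a translation from another domain"], ["its reference"]),
}

for name, (hypotheses, references) in evaluation_sets.items():
    bleu = sacrebleu.corpus_bleu(hypotheses, [references])
    print(f"{name}: BLEU = {bleu.score:.1f}")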

Table 1: Comparison of in-domain improvements and out-of-domain degradation for the Alignment objectives modeling ambiguity in their targets. Numbers denote the percentage change of BLEU compared to the original, fine-tuned model.

Table 1 shows that modeling ambiguity creates more robust models without compromising the performance gains of training. While the surface-level BLEU shows somewhat smaller adaptation gains compared to standard training (MLE), BERTScore, a metric that also reflects semantic similarity, shows that the quality of adaptation with the Alignment objectives is comparable to, or even better than, that of traditional training.

However, the most striking results come in the evaluation of OOD robustness: in terms of BLEU, the Alignment objectives avoid 88–95% of the performance loss incurred by traditional training, and even more in terms of BERTScore. Seen from the other side, this comparison shows that training language models without respecting the natural ambiguity of the task creates models that are much more fragile.

In our follow-up analyses, we also show that much of the robustness gains can be obtained simply by feeding the model its own outputs (i.e. by avoiding teacher forcing), but constructing semantically grounded targets still largely complements these improvements.

What's next?

A surprising finding was how much simply feeding the model its own outputs helps robustness. This makes me wonder to what extent the gains of the new Preference Optimisation methods (DPO, KTO, IPO), which avoid teacher forcing by design, come from the fact that Preference-Optimised models are simply less affected by exposure bias. Our recent work comparing Preference Optimisation to traditional training provides related evidence by showing that Preference Optimisation creates models that are less reliant on specific prompt styling.

Despite the nice performance results, the Soft Alignment objectives leave plenty of space for further technical improvements. For instance, while getting rid of teacher forcing, SeqAlign computes targets against the best-matching token from the reference, which likely induces noise as the sequences grow longer. This could be improved by more sophisticated alignment schemes involving optimal transport or other matching algorithms. It also remains an open question whether modeling semantic ambiguity could complement traditional training in much larger data settings as well, for instance, in pre-training.

Link to the original paper:

Soft Alignment Objectives for Robust Adaptation of Language Generation
