
ELECTRA: Replaced Token Detection

ELECTRA (Efficiently Learning an Encoder that Classifies Token Replacements Accurately) is a novel pre-training method for language models. Its key innovation is the use of replaced token detection instead of the masked language modeling (MLM) used in BERT.


How it works


  1. Generator: A small generator model (usually a smaller version of BERT) is trained to predict masked tokens, similar to BERT's MLM task.

  2. Corruption: The generator is used to replace some tokens in the input sequence with plausible alternatives.

  3. Discriminator: The main ELECTRA model is trained as a discriminator to detect which tokens have been replaced and which are original.

  4. Training Objective: The model learns to classify each token as "original" or "replaced," which is a binary classification task for every token in the sequence (a minimal sketch of the corruption and labeling step follows below).

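To make the corruption and labeling steps concrete, here is a minimal Python sketch. The `toy_generator` and `VOCAB` below are hypothetical stand-ins; in practice the generator is a small trained masked language model and replacements are sampled from its output distribution.

```python
import random

# Toy vocabulary and "generator" (hypothetical stand-ins for a small masked LM).
VOCAB = ["the", "cat", "dog", "sat", "stood", "lay", "on", "mat", "rug", "."]

def toy_generator(context, position):
    """Stand-in for the generator: sample a plausible token for a masked slot."""
    return random.choice(VOCAB)

def corrupt(tokens, mask_prob=0.15):
    """Mask ~15% of positions, fill them with generator samples, and build labels."""
    corrupted = list(tokens)
    labels = [0] * len(tokens)            # 0 = original, 1 = replaced
    for i in range(len(tokens)):
        if random.random() < mask_prob:
            sampled = toy_generator(corrupted, i)
            if sampled != tokens[i]:      # if the generator guesses the original token,
                corrupted[i] = sampled    # the position still counts as "original"
                labels[i] = 1
    return corrupted, labels

tokens = ["the", "cat", "sat", "on", "the", "mat", "."]
corrupted, labels = corrupt(tokens)
print(corrupted)   # e.g. ['the', 'cat', 'stood', 'on', 'the', 'mat', '.']
print(labels)      # the discriminator is trained to predict these per-token labels
```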

ELECTRA: Detailed Training Process


The training process of ELECTRA involves jointly training two models: a generator (G) and a discriminator (D). This process is inspired by Generative Adversarial Networks (GANs) but with some key differences. Here's a detailed breakdown of the training process:


  1. Initialization

    1. Initialize both the generator (G) and discriminator (D) with random weights.

    2. G is typically smaller than D (e.g., 1/4 to 1/2 the size).

  2. Input Preparation

    1. Take a batch of input sequences.

    2. Randomly mask 15% of the tokens in each sequence (similar to BERT).

  3. Generator Training

    1. G is trained to predict the original tokens at the masked positions.

    2. Loss function for G: masked language modeling (MLM) loss, similar to BERT.

    3. G's weights are updated to minimize this loss.

  4. Token Replacement

    1. Using G, replace each masked token with a token sampled from G's output distribution.

    2. This creates a "corrupted" version of the input sequence.

  5. Discriminator Training

    1. D takes the corrupted sequence as input (the original sequence is used only to derive the per-token labels).

    2. For each token position, D outputs a probability indicating whether the token is "original" (0) or "replaced" (1).

    3. Loss function for D: binary cross-entropy loss summed over all token positions.

    4. D's weights are updated to minimize this loss.

  6. Joint Training: The key to ELECTRA's training is how G and D are trained together:

    1. Generator Objective:

      1. Minimize the MLM loss (predict the masked tokens accurately).

      2. Unlike a GAN generator, G is not trained to fool D; it receives no gradient from D's loss.

    2. Discriminator Objective:

      1. Accurately classify tokens as original or replaced.

  7. GAN-like Dynamic (not truly adversarial):

    1. As G's MLM accuracy improves, it produces more realistic replacements.

    2. This challenges D to become better at detecting subtle replacements.

    3. Unlike a GAN, however, G receives no feedback from D, so the two models do not compete directly.

  8. Weight Updates:

    1. Both models are updated in each training step, using a combined loss in which the discriminator loss is up-weighted.

    2. The gradients from D's loss are not propagated back to G, because the sampling step that produces the replacements is non-differentiable (see the sketch after this list).

  9. Training Loop (repeated for many iterations over the entire dataset). For each batch:

    1. Mask input tokens.

    2. G predicts the masked tokens.

    3. Create a corrupted sequence using G's predictions.

    4. D classifies each token in the corrupted sequence.

    5. Calculate the losses for both G and D.

    6. Update the weights of both models.

  10. Learning Rate and Optimization

    1. Usually uses different learning rates for G and D.

    2. Typically uses the Adam optimizer with weight decay.

    3. Learning rate warmup and decay schedules are often employed.

  11. Model Selection

    1. The final discriminator (D) is used for downstream tasks.

    2. The generator (G) is typically discarded after pre-training.

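To tie items 2 through 8 together, below is a minimal PyTorch-style sketch of one joint training step. The single-layer "models", sizes, and learning rate are placeholders (real ELECTRA uses Transformer encoders for both G and D); the structure, however, mirrors the description above: MLM loss at masked positions, sampling that blocks gradient flow from D back to G, per-token binary cross-entropy for D, and one combined weight update.

```python
import torch
import torch.nn as nn

# Minimal single-step sketch of ELECTRA's joint objective with toy components.
VOCAB_SIZE, HIDDEN, SEQ_LEN, MASK_ID = 1000, 64, 16, 0
LAMBDA = 50.0  # the discriminator loss is up-weighted (lambda = 50 in the ELECTRA paper)

generator = nn.Sequential(nn.Embedding(VOCAB_SIZE, HIDDEN), nn.Linear(HIDDEN, VOCAB_SIZE))
discriminator = nn.Sequential(nn.Embedding(VOCAB_SIZE, HIDDEN), nn.Linear(HIDDEN, 1))
optimizer = torch.optim.AdamW(
    list(generator.parameters()) + list(discriminator.parameters()), lr=2e-4
)

tokens = torch.randint(1, VOCAB_SIZE, (4, SEQ_LEN))      # a batch of token ids
mask = torch.rand(tokens.shape) < 0.15                   # mask ~15% of positions
masked_input = torch.where(mask, torch.full_like(tokens, MASK_ID), tokens)

# Generator: MLM loss at the masked positions only.
gen_logits = generator(masked_input)                     # (batch, seq, vocab)
mlm_loss = nn.functional.cross_entropy(gen_logits[mask], tokens[mask])

# Corruption: sample replacements. Sampling is non-differentiable, so no
# gradient from the discriminator loss can reach the generator.
with torch.no_grad():
    sampled = torch.distributions.Categorical(logits=gen_logits).sample()
corrupted = torch.where(mask, sampled, tokens)
is_replaced = (corrupted != tokens).float()              # per-token labels: 1 = replaced

# Discriminator: binary cross-entropy over every token position.
disc_logits = discriminator(corrupted).squeeze(-1)       # (batch, seq)
disc_loss = nn.functional.binary_cross_entropy_with_logits(disc_logits, is_replaced)

# Combined loss; both models are updated in the same step.
loss = mlm_loss + LAMBDA * disc_loss
optimizer.zero_grad()
loss.backward()
optimizer.step()
```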


ELECTRA: Replaced Token Detection Example


Let's walk through the process of replaced token detection in ELECTRA using a simple sentence:


Original sentence: "The cat sat on the mat."


  1. Step 1: Tokenization

    1. First, we tokenize the sentence: ["The", "cat", "sat", "on", "the", "mat", "."]

  2. Step 2: Masking

    1. We randomly mask 15% of the tokens (in this case, let's mask one token): ["The", "cat", "[MASK]", "on", "the", "mat", "."]

  3. Step 3: Generator Prediction

    1. The generator model predicts a replacement for the masked token. Let's say it predicts "stood": ["The", "cat", "stood", "on", "the", "mat", "."]

  4. Step 4: Creating Corrupted Sequence

    1. We now have a corrupted sequence where "sat" has been replaced with "stood": ["The", "cat", "stood", "on", "the", "mat", "."]

  5. Step 5: Discriminator Task

    1. The discriminator is given only the corrupted sequence. Its job is to identify which tokens have been replaced (the original sequence supplies the ground-truth labels).

    2. The discriminator outputs a probability for each token, indicating whether it thinks the token is original (0) or replaced (1), as shown in the table below.

  6. Step 6: Loss Calculation

    1. The discriminator's loss is calculated using binary cross-entropy between its predictions and the ground truth (whether each token was actually replaced or not); a worked computation follows the table below.

  7. Step 7: Training: Based on their respective losses, both the generator and the discriminator are updated:

    1. The generator is trained (via its MLM loss) to produce increasingly plausible replacements.

    2. The discriminator is trained to become better at detecting replaced tokens.


Through many iterations of this process over a large corpus, ELECTRA learns to understand the nuances of language, enabling it to perform well on various downstream tasks.


Discriminator outputs for the corrupted example sentence:

Token      Original?   Discriminator Output   Ideal Output
"The"      Yes         0.02                   0
"cat"      Yes         0.01                   0
"stood"    No          0.98                   1
"on"       Yes         0.03                   0
"the"      Yes         0.01                   0
"mat"      Yes         0.02                   0
"."        Yes         0.01                   0
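
As a quick check on Step 6, the snippet below computes the per-token binary cross-entropy for the discriminator outputs in the table (1 = replaced) and sums it over the seven positions.

```python
import math

# Binary cross-entropy for the worked example (values taken from the table above).
preds  = [0.02, 0.01, 0.98, 0.03, 0.01, 0.02, 0.01]   # discriminator outputs
labels = [0,    0,    1,    0,    0,    0,    0]       # ground truth: 1 = replaced ("stood")

per_token = [-(y * math.log(p) + (1 - y) * math.log(1 - p))
             for p, y in zip(preds, labels)]
print([round(l, 3) for l in per_token])   # small losses: the discriminator is confident and correct
print(round(sum(per_token), 3))           # total loss summed over the seven token positions
```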


Comparison Table: BERT vs ELECTRA

Feature                       BERT                             ELECTRA
Pre-training Task             Masked Language Modeling         Replaced Token Detection
Learning Signal               15% of tokens (masked)           All tokens
Architecture                  Single Transformer               Generator + Discriminator
Training Efficiency           Lower                            Higher
Sample Efficiency             Lower                            Higher
Computational Cost            Higher                           Lower for similar performance
Pre-train/Fine-tune Gap       Present (due to [MASK] token)    Minimal
Bidirectional Context         Yes                              Yes
Performance on Small Models   Good                             Often better than BERT
Adoption and Resources        Widely adopted                   Growing adoption
Interpretability              Relatively straightforward       More complex (two-model setup)

Compared with BERT's masked language modeling (MLM) objective, ELECTRA offers advantages mainly in training efficiency and in its ability to learn from all input tokens. BERT nevertheless remains widely used thanks to its simplicity and the abundance of pre-trained models and resources built around it. The choice between BERT and ELECTRA often depends on the specific use case, the available computational resources, and the need for sample efficiency.
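
For readers who want to try replaced token detection directly, the released ELECTRA discriminator can be queried through the Hugging Face transformers library. The sketch below assumes transformers and torch are installed and that the google/electra-small-discriminator checkpoint is available; it scores each token of the corrupted example sentence from earlier.

```python
import torch
from transformers import ElectraForPreTraining, ElectraTokenizerFast

# Load the released ELECTRA discriminator (checkpoint assumed available on the Hub).
tokenizer = ElectraTokenizerFast.from_pretrained("google/electra-small-discriminator")
model = ElectraForPreTraining.from_pretrained("google/electra-small-discriminator")

# "sat" has been swapped for "stood", as in the walkthrough above.
inputs = tokenizer("the cat stood on the mat .", return_tensors="pt")
with torch.no_grad():
    logits = model(**inputs).logits   # one logit per token; higher means "more likely replaced"

for token, p in zip(tokenizer.convert_ids_to_tokens(inputs["input_ids"][0]),
                    torch.sigmoid(logits[0])):
    print(f"{token:>8}  {p.item():.2f}")
```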


Pros & Cons of Replaced Token Detection


Pros


  1. More sample efficient than BERT

  2. Learns from all input tokens, not just masked ones

  3. Faster training and often better performance, especially with smaller models

  4. Addresses the pre-training and fine-tuning discrepancy


Cons


  1. More complex architecture (generator + discriminator)

  2. Potentially harder to interpret due to the two-model (generator-discriminator) training setup

  3. Less widely adopted compared to BERT (as of 2024)


Key Features


  1. Uses Replaced Token Detection instead of MLM

  2. Generator-Discriminator architecture

  3. Trains on detecting real vs. fake tokens
