Paired residue prediction dependencies in ESM2

Published March 10, 2025

Summary

During a quick analysis of the ESM2 models for masked token prediction, we noticed that the amino acid probability distributions of masked residues affect each other in a pattern that mirrors the protein’s 3D contact map, though less so for the larger model sizes. Our question to you is: why?

We used ChatGPT and Claude to help write, clean up, and comment our code; to suggest wording ideas, from which we chose which small phrases or sentence-structure ideas to use; to write text that we edited; to expand on summary text that we provided, after which we edited the text they produced; and to help clarify and streamline text that we wrote. We also used ChatGPT for general search and sanity checks. Additionally, we used Grammarly Business to suggest wording ideas, again choosing which small phrases or sentence-structure ideas to use.

Introduction

This notebook details a quick analysis we performed on the ESM2 (Evolutionary Scale Modeling) models (Lin et al., 2023) for masked token prediction. We stumbled upon a counterintuitive result related to the effect that masking one residue has on the distribution of another.

Before jumping into our results, let’s first establish what masked token prediction is, how it works, and our consequent motivation for this analysis.

Brief primer

At its core, masked token prediction is a fundamental training technique in natural language processing. The basic idea is simple: take a sequence of text, hide some of its tokens (by replacing them with a special “mask” token), and ask the model to predict what should go in those masked positions. This fill-in-the-blanks approach, formally known as masked language modeling (MLM), is remarkably effective at enabling models to learn underlying patterns and dependencies in sequential data1. In the protein language model space, nearly every state-of-the-art model is trained with masked token prediction.

1 The BERT paper is the key seminal paper that popularized MLM (Devlin et al., 2019).

Once a model is trained on this task, you can mutate a sequence of interest by masking residues and having the model predict what amino acid the mask should be replaced with. This is called unmasking, and the procedure goes like this. First, an incomplete sequence is passed through the model. It’s incomplete because, at one or more residues, a masked token is placed instead of an amino acid. The collection of masked positions is denoted as \(M\), and the partially masked sequence is denoted as \(S_{-M}\). Here’s an example:

\[ S_{-M} = \text{A R <mask> A D I <mask>} \]

Here, \(M = \lbrace 2, 6 \rbrace\), which is \(0\)-indexed. When this sequence is passed through a model with an attached language modeling head (e.g., an ESM2 model), the output is an array of unnormalized probabilities for each amino acid at the masked positions. We call these values logits, which can be readily converted into amino acid probabilities for each masked residue2. The amino acid probability distribution of the \(i^\text{th}\) residue is denoted as \(p(i|M)\), where \(i \in M\).

2 What are logits? Inside a protein language model, each layer produces neuronal activations as it processes the sequence. These intermediate activations represent increasingly abstract features of the amino acid patterns. The final layer produces specific activations dubbed logits—one logit for each possible amino acid that could appear at a masked position. A high logit value for an amino acid corresponds to a high probability of that amino acid. This blog has more details.
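To make this concrete, here’s a minimal sketch of unmasking with the smallest ESM2 checkpoint hosted on HuggingFace (purely illustrative; the analysis below builds up the same machinery more carefully):

import torch
from transformers import AutoTokenizer, EsmForMaskedLM

checkpoint = "facebook/esm2_t6_8M_UR50D"
tokenizer = AutoTokenizer.from_pretrained(checkpoint)
model = EsmForMaskedLM.from_pretrained(checkpoint)

# S_{-M} = A R <mask> A D I <mask>, i.e., M = {2, 6} (0-indexed).
inputs = tokenizer("AR<mask>ADI<mask>", return_tensors="pt")

with torch.no_grad():
    logits = model(**inputs).logits  # Shape (1, seq_len, vocab_size).

# p(i | M) for each masked position i. Note that these probabilities span the
# full 33-token vocabulary; later in the analysis, we restrict them to the 20
# amino acids before normalizing.
mask_positions = (inputs["input_ids"][0] == tokenizer.mask_token_id).nonzero().squeeze(-1)
probs = torch.softmax(logits[0, mask_positions], dim=-1)  # Shape (2, vocab_size).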

Motivation

With this understanding of masked token prediction, we can explore an important trade-off in making these predictions. When performing masked token prediction, there’s a balance between efficiency and information utilization. On one extreme, you could take \(S_{-M}\), pass it through the model once, and then unmask each masked token. This would be very computationally efficient. However, this approach precludes the model from leveraging the information gained from each newly unmasked position to inform subsequent predictions. At the other end of the spectrum, you could unmask tokens one at a time: predict the first position, update the sequence with that prediction, pass the updated sequence through the model again to predict the second position, and so on. This iterative approach allows for maximum information utilization but can be computationally expensive. There’s also a concern that by unmasking individually, you only make conservative choices that prevent a deeper exploration of sequence space.
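To make the two extremes concrete, here’s a minimal sketch of the fully iterative strategy (a hypothetical helper for illustration, not code we use in this analysis). The one-pass strategy would instead fill every masked position from a single forward pass.

import torch


def unmask_iteratively(model, tokenizer, input_ids):
    """Fill one masked position per forward pass, always choosing the position
    where the model is most confident. For simplicity, predictions are taken
    over the full vocabulary rather than being restricted to the 20 amino acids."""
    input_ids = input_ids.clone()
    while True:
        mask_positions = (input_ids[0] == tokenizer.mask_token_id).nonzero().squeeze(-1)
        if len(mask_positions) == 0:
            return input_ids

        with torch.no_grad():
            logits = model(input_ids=input_ids).logits[0]

        probs = torch.softmax(logits[mask_positions], dim=-1)
        best_probs, best_tokens = probs.max(dim=-1)

        # Commit only the single most confident prediction, then re-run the model.
        pick = best_probs.argmax()
        input_ids[0, mask_positions[pick]] = best_tokens[pick]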

All this is to say that developing effective unmasking strategies is an important concern for maximizing the utility of masked token prediction models. And to do that, it’s worth gaining some insight into how masking at one position affects the outcome at another masked position. Our analysis approaches this question by masking two locations individually and jointly, and quantifying any changes in their predicted amino acid probabilities. Specifically, we’ll consider a pair of positions \(i\) and \(j\) and ask:

  1. How does \(p(i | \lbrace i \rbrace)\) compare with \(p(i | \lbrace i, j \rbrace)\)?
  2. How does \(p(j | \lbrace j \rbrace)\) compare with \(p(j | \lbrace i, j \rbrace)\)?

By doing this repeatedly for each pair of residues in the sequence, will we start to capture interdependencies between the sequence positions? Do these statistical dependencies relate to things we expect, like the model’s attention map or the amino acids’ contact map in 3D space? Are there any interpretable patterns at all? We’ll assess the similarity of these distributions with Jensen–Shannon divergence to find out.

Loading the model

Our analysis will focus on the ESM2 model series (Lin et al., 2023).

A quick survey of the available ESM2 model sizes is a good place to start. The ESM2 architecture was created and trained at six different model sizes, which we list here:

from analysis.utils import ModelName
from rich.console import Console

console = Console()

for model_name in ModelName:
    console.print(model_name, style="strong")
ModelName.ESM2_8M
ModelName.ESM2_35M
ModelName.ESM2_150M
ModelName.ESM2_650M
ModelName.ESM2_3B
ModelName.ESM2_15B

All of these models are hosted on HuggingFace and can be loaded using HuggingFace’s transformers Python package.

The larger models need GPUs to run. We’ll do that later. For now, let’s load up the smallest model for prototyping the analysis.

Note

You can still follow along on your own computer without GPU access.

from transformers import AutoTokenizer, EsmForMaskedLM

small_model = EsmForMaskedLM.from_pretrained(ModelName.ESM2_8M.value)
print(small_model)
EsmForMaskedLM(
  (esm): EsmModel(
    (embeddings): EsmEmbeddings(
      (word_embeddings): Embedding(33, 320, padding_idx=1)
      (dropout): Dropout(p=0.0, inplace=False)
      (position_embeddings): Embedding(1026, 320, padding_idx=1)
    )
    (encoder): EsmEncoder(
      (layer): ModuleList(
        (0-5): 6 x EsmLayer(
          (attention): EsmAttention(
            (self): EsmSelfAttention(
              (query): Linear(in_features=320, out_features=320, bias=True)
              (key): Linear(in_features=320, out_features=320, bias=True)
              (value): Linear(in_features=320, out_features=320, bias=True)
              (dropout): Dropout(p=0.0, inplace=False)
              (rotary_embeddings): RotaryEmbedding()
            )
            (output): EsmSelfOutput(
              (dense): Linear(in_features=320, out_features=320, bias=True)
              (dropout): Dropout(p=0.0, inplace=False)
            )
            (LayerNorm): LayerNorm((320,), eps=1e-05, elementwise_affine=True)
          )
          (intermediate): EsmIntermediate(
            (dense): Linear(in_features=320, out_features=1280, bias=True)
          )
          (output): EsmOutput(
            (dense): Linear(in_features=1280, out_features=320, bias=True)
            (dropout): Dropout(p=0.0, inplace=False)
          )
          (LayerNorm): LayerNorm((320,), eps=1e-05, elementwise_affine=True)
        )
      )
      (emb_layer_norm_after): LayerNorm((320,), eps=1e-05, elementwise_affine=True)
    )
    (contact_head): EsmContactPredictionHead(
      (regression): Linear(in_features=120, out_features=1, bias=True)
      (activation): Sigmoid()
    )
  )
  (lm_head): EsmLMHead(
    (dense): Linear(in_features=320, out_features=320, bias=True)
    (layer_norm): LayerNorm((320,), eps=1e-05, elementwise_affine=True)
    (decoder): Linear(in_features=320, out_features=33, bias=False)
  )
)

In addition to the base ESM architecture, this model has an attached language model decoder head, which you can see at the bottom of the printed model architecture. This part of the model converts the latent embedding produced by the model into the logits array, which we will then transform into amino acid probability distributions.

Getting the probability array \(P\)

Table 1. All 33 tokens in the model’s vocabulary and their indices in the logits array.

Index  Token
0 <cls>
1 <pad>
2 <eos>
3 <unk>
4 L
5 A
6 G
7 V
8 S
9 E
10 R
11 T
12 I
13 D
14 P
15 K
16 Q
17 N
18 F
19 Y
20 M
21 H
22 W
23 C
24 X
25 B
26 U
27 Z
28 O
29 .
30 -
31 <null_1>
32 <mask>

A forward pass of small_model yields a logits array, which we need to transform into our desired array of amino acid probability distributions (\(P\)). The transformation requires two key steps. Since we only want amino acid probabilities, we first subset the logits to include just the 20 amino acids from the model’s 33-token vocabulary (Table 1). And second, we need to apply a softmax to convert these selected logits into probabilities. Here’s the code to perform both steps:

import pandas as pd
import torch
from analysis.utils import amino_acids

tokenizer = AutoTokenizer.from_pretrained(
    ModelName.ESM2_8M.value,
    clean_up_tokenization_spaces=True,
)
aa_token_indices = tokenizer.convert_tokens_to_ids(amino_acids)


def get_probs_array(logits):
    """Given a logits array, return an amino acid probability distribution array.

    Args:
        logits:
            The output from the language modelling head of an ESM2 model. Shape is (batch_size, seq_len, vocab_size).

    Returns:
        A probability array. Shape is (batch_size, seq_len, 20).
    """
    if not isinstance(logits, torch.Tensor):
        logits = torch.tensor(logits)

    # Only include tokens that correspond to amino acids. This excludes special, non-amino acid tokens.
    logits_subset = logits[..., aa_token_indices]

    return torch.nn.functional.softmax(logits_subset, dim=-1).detach().numpy()

As a sanity check, let’s create a sequence of repeating leucines. Then, we’ll mask one of the residues and see what the model predicts at that position. Although there’s no theoretically “correct” answer for what the model should predict, given that the entire sequence is homogeneously leucine, it would be bizarre if the model predicted anything but leucine.

import numpy as np
import pandas as pd

# Create a sequence of all leucine, then tokenize it.
sequence = "L" * 200
tokenized = tokenizer([sequence], return_tensors="pt")

# Place a mask token in the middle of the tokenized sequence. The residue at
# sequence position `mask_pos` sits at token index `mask_pos + 1` because a
# "CLS" token prefixes the tokenized sequence.
mask_pos = 100
tokenized["input_ids"][0, mask_pos + 1] = tokenizer.mask_token_id

# Pass the output of the model through our `get_probs_array` function.
logits_array = small_model(**tokenized).logits  # Shape (batch_size=1, seq_len=202, vocab_size=33)
probs_array = get_probs_array(logits_array)  # Shape (batch_size=1, seq_len=202, vocab_size=20)

# We are only interested in the position we masked, so let's subset to that.
# Note that the position is indexed with `mask_pos + 1` because of a preceding "CLS"
# token that prefixes the tokenized sequence.
probs_array_at_mask = probs_array[0, mask_pos + 1, :]

# Like a probability distribution should, the elements sum to one (or very nearly).
assert np.isclose(probs_array_at_mask.sum().item(), 1.0)

probs = dict(zip(amino_acids, probs_array_at_mask, strict=False))
probs = dict(
    sorted(((aa, prob.item()) for aa, prob in probs.items()), key=lambda x: x[1], reverse=True)
)
pd.Series(probs).to_frame(name="Probability")

Table 2. The probability masses predicted by the model for the masked residue.

Amino acid  Probability
L 0.990615
P 0.001851
I 0.001022
S 0.000988
T 0.000935
R 0.000911
A 0.000728
V 0.000715
F 0.000417
Q 0.000303
H 0.000272
Y 0.000250
D 0.000236
G 0.000217
M 0.000147
K 0.000143
E 0.000115
N 0.000060
W 0.000057
C 0.000018

Encouragingly, the elements sum to one and the model confidently (> 99%) predicts leucine (Table 2).

Sequence of interest

Now, let’s introduce a real protein to work with. Any protein will do, really, but we’ve been exploring the use of human adenosine deaminase (ADA) as a template for protein design, so we’ve decided to use that.

Note

While we haven’t performed this analysis on any other protein, we have no reason to believe our conclusions are contingent on our protein of choice.

from pathlib import Path

from biotite.sequence.io.fasta import FastaFile

sequence_path = Path("input/P00813.fasta")
sequence = FastaFile().read(sequence_path)["P00813"]
console.print(f"{sequence=}")
console.print(f"{len(sequence)=}")
sequence='MAQTPAFDKPKVELHVHLDGSIKPETILYYGRRRGIALPANTAEGLLNVIGMDKPLTLPDFLAKFDYYMPAIAGCREAIKRIAYEFVEMKAKEGVVYVEVRYSPH
LLANSKVEPIPWNQAEGDLTPDEVVALVGQGLQEGERDFGVKARSILCCMRHQPNWSPKVVELCKKYQQQTVVAIDLAGDETIPGSSLLPGHVQAYQEAVKSGIHRTVHAGEVGS
AEVVKEAVDILKTERLGHGYHTLEDQALYNRLRQENMHFEICPWSSYLTGAWKPDTEHAVIRLKNDQANYSLNTDDPLIFKSTLDTDYQMTKRDMGFTEEEFKRLNINAAKSSFL
PEDEKRELLDLLYKAYGMPPSASAGQNL'
len(sequence)=363

Creating the mask libraries

We now know how to turn the model output, a logits array, into an array of amino acid probability distributions. This marks our launch point for comparing amino acid probability distributions for single- versus double-residue masking. We’ll do this exhaustively for our sequence of interest by establishing two different variant libraries: the singly masked variants, of which there are \(N\), and the doubly masked variants, of which there are \(N \choose 2\).

from itertools import combinations

N = len(sequence)
single_masks = list(combinations(range(N), 1))
double_masks = list(combinations(range(N), 2))

console.print(f"{single_masks[:5]=}")
console.print(f"{double_masks[:5]=}")

assert len(single_masks) == N
assert len(double_masks) == N * (N - 1) / 2
single_masks[:5]=[(0,), (1,), (2,), (3,), (4,)]
double_masks[:5]=[(0, 1), (0, 2), (0, 3), (0, 4), (0, 5)]

Calculating all the logits

Now comes the compute-heavy portion of the analysis. We’re using Modal’s cloud infrastructure to distribute our computations across GPUs of varying power (from T4s to H100s). For each model in our analysis, we predict logit arrays twice: once for the single-mask library and again for the double-mask library. The computation becomes increasingly demanding with larger models: ESM2-8M costs just pennies to run, while ESM2-15B costs around $30 due to its need for more powerful GPUs and smaller batch sizes. We’ve written all this functionality in the calculate_or_load_logits function imported below.

Note

If you’re following along on your own computer, don’t worry—the logits are tracked in the GitHub repository, so the function call below will simply load pre-computed logits from disk without requiring Modal or GPU access.

from analysis.logits import LogitsConfig, calculate_or_load_logits

config = LogitsConfig(
    sequence=sequence,
    single_masks=single_masks,
    double_masks=double_masks,
    single_logits_path=Path("output/logits_single.npz"),
    double_logits_path=Path("output/logits_double.npz"),
)

single_logits, double_logits = calculate_or_load_logits(config)
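For illustration, here’s a minimal, local sketch of how the masked logits could be computed in batches (a hypothetical helper; the actual calculate_or_load_logits additionally distributes the work across Modal GPUs, handles both mask libraries, and caches results to disk).

import numpy as np
import torch


def compute_masked_logits(model, tokenizer, sequence, masks, batch_size=8):
    """Return the logits for every masked variant defined by `masks`."""
    base_ids = tokenizer([sequence], return_tensors="pt")["input_ids"]

    all_logits = []
    with torch.no_grad():
        for start in range(0, len(masks), batch_size):
            batch_masks = masks[start : start + batch_size]

            # One masked copy of the sequence per mask tuple. The +1 offsets past
            # the "CLS" token that prefixes the tokenized sequence.
            input_ids = base_ids.repeat(len(batch_masks), 1)
            for row, mask in enumerate(batch_masks):
                for pos in mask:
                    input_ids[row, pos + 1] = tokenizer.mask_token_id

            logits = model(input_ids=input_ids).logits
            all_logits.append(logits.cpu().numpy())

    return np.concatenate(all_logits, axis=0)


# For example, with the small model:
# single_logits_8m = compute_masked_logits(small_model, tokenizer, sequence, single_masks)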

Conversion to amino acid probabilities

We’ve already written a method to convert logits to amino acid probabilities, so let’s run that conversion for all the data.

single_probs: dict[ModelName, np.ndarray] = {}
double_probs: dict[ModelName, np.ndarray] = {}

for model in ModelName:
    single_probs[model] = get_probs_array(single_logits[model])
    double_probs[model] = get_probs_array(double_logits[model])

Calculating probability metrics for each residue pair

With that, we finally have amino acid probability distributions for (a) each individually masked residue and (b) every pairwise combination of doubly masked residues.

As a reminder, we’re interested in comparing \(p(i | \lbrace i \rbrace)\) to \(p( i | \lbrace i, j \rbrace)\) and \(p(j | \lbrace j \rbrace)\) to \(p( j | \lbrace i, j \rbrace)\). To compare the degree of similarity between probability distributions, we’re going to use Jensen–Shannon (JS) divergence3. We’ll use base-2 logarithms so that \(0\) corresponds to identical distributions and \(1\) corresponds to maximally diverged distributions.

3 We originally planned to use Kullback–Leibler divergence, but because it’s not symmetric (i.e., \(KL(x,y) \ne KL(y, x)\)), it’s not, formally speaking, a distance measure. JS divergence was explicitly designed to resolve this.
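For concreteness, here’s a minimal sketch of a base-2 JS divergence (our assumption of an equivalent implementation; the exact code lives in the analysis package). Note that scipy’s jensenshannon returns the JS distance, i.e., the square root of the divergence, so we square it.

import numpy as np
from scipy.spatial.distance import jensenshannon


def js_divergence(p: np.ndarray, q: np.ndarray) -> float:
    """Base-2 Jensen–Shannon divergence between two probability distributions."""
    return float(jensenshannon(p, q, base=2) ** 2)


# Identical distributions give 0; distributions with disjoint support give 1.
uniform = np.full(20, 1 / 20)
assert np.isclose(js_divergence(uniform, uniform), 0.0)
assert np.isclose(js_divergence(np.eye(20)[0], np.eye(20)[1]), 1.0)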

Since we’ll have to loop through each residue pair to calculate the JS divergence, it’ll prove useful to calculate some other metrics along the way:

import pandas as pd
from analysis.residue_pair_data import calculate_pairwise_data

pairwise_data: dict[ModelName, pd.DataFrame] = {}

for model_name in ModelName:
    pairwise_dataframe = calculate_pairwise_data(
        sequence=sequence,
        single_probs=single_probs[model_name],
        double_probs=double_probs[model_name],
        double_masks=double_masks,
    )
    pairwise_data[model_name] = pairwise_dataframe

pairwise_data is a dictionary of DataFrames, one for each model. Each DataFrame holds the following metrics for each residue pair.

Calculated metrics:

Field Description
position_i The \(i^{\text{th}}\) residue position.
position_j The \(j^{\text{th}}\) residue position.
amino_acid_i The original amino acid at the \(i^{\text{th}}\) residue.
most_probable_i_i The amino acid with the highest probability in \(p(i \vert \lbrace i \rbrace)\).
most_probable_i_ij The amino acid with the highest probability in \(p( i \vert \lbrace i, j \rbrace)\).
amino_acid_j The original amino acid at the \(j^{\text{th}}\) residue.
most_probable_j_j The amino acid with the highest probability in \(p(j \vert \lbrace j \rbrace)\).
most_probable_j_ij The amino acid with the highest probability in \(p( j \vert \lbrace i, j \rbrace)\).
perplex_i_i The perplexity4 of \(p(i \vert \lbrace i \rbrace)\).
perplex_i_ij The perplexity of \(p( i \vert \lbrace i, j \rbrace)\).
perplex_j_j The perplexity of \(p(j \vert \lbrace j \rbrace)\).
perplex_j_ij The perplexity of \(p( j \vert \lbrace i, j \rbrace)\).
js_div_i The Jensen–Shannon divergence between \(p(i \vert \lbrace i \rbrace)\) and \(p( i \vert \lbrace i, j \rbrace)\).
js_div_j The Jensen–Shannon divergence between \(p(j \vert \lbrace j \rbrace)\) and \(p( j \vert \lbrace i, j \rbrace)\).
js_div_avg (js_div_i + js_div_j) / 2

4 What is perplexity? The perplexity of a probability distribution measures how uncertain or “surprised” the model is about its predictions. For amino acid predictions, a perplexity of 20 indicates maximum uncertainty (equal probability for all 20 amino acids), while a perplexity of 1 indicates complete certainty (100% probability for a single amino acid).

Also, note that the perplexities in this table refer to single residues, but in the field, (pseudo-)perplexity is often used to score a model’s confidence in an entire sequence [e.g., the ESM2 paper (Lin et al., 2023)] via methods developed in Salazar et al. (2020).
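As a concrete illustration of the single-site perplexity described in footnote 4, here’s a sketch of the definition we have in mind (exponentiated Shannon entropy; the exact implementation lives in calculate_pairwise_data):

import numpy as np


def site_perplexity(p: np.ndarray) -> float:
    """Perplexity of a single-site amino acid probability distribution."""
    p = p[p > 0]  # Zero-probability entries contribute nothing to the entropy.
    return float(np.exp(-np.sum(p * np.log(p))))


# A uniform distribution over the 20 amino acids has perplexity 20 (maximum
# uncertainty); a one-hot distribution has perplexity 1 (complete certainty).
assert np.isclose(site_perplexity(np.full(20, 1 / 20)), 20.0)
assert np.isclose(site_perplexity(np.array([1.0])), 1.0)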

Let’s inspect some of the data for the largest model.

Code
pairwise_15b = pairwise_data[ModelName.ESM2_15B]
pairwise_15b
position_i position_j amino_acid_i most_probable_i_i most_probable_i_ij amino_acid_j most_probable_j_j most_probable_j_ij perplex_i_i perplex_i_ij perplex_j_j perplex_j_ij js_div_i js_div_j js_div_avg
0 0 1 M M M A A A 1.874111 1.742441 3.191700 2.850710 0.084210 0.089429 0.086820
1 0 2 M M M Q Q Q 1.874111 1.664312 2.018497 1.940485 0.045490 0.015094 0.030292
2 0 3 M M M T T T 1.874111 1.743261 3.114773 2.992301 0.025954 0.030157 0.028055
3 0 4 M M M P P P 1.874111 2.434045 1.156383 1.146090 0.092536 0.005108 0.048822
4 0 5 M M M A A A 1.874111 1.830846 2.304582 2.197421 0.015346 0.016212 0.015779
... ... ... ... ... ... ... ... ... ... ... ... ... ... ... ...
131395 362 357 L L L S S S 2.279995 1.893270 2.858127 3.379927 0.057617 0.092329 0.074973
131399 362 358 L L L A A A 2.279995 1.852332 1.807500 1.649580 0.063063 0.054722 0.058892
131402 362 359 L L L G G G 2.279995 3.412535 4.348093 6.531841 0.137119 0.191148 0.164134
131404 362 360 L L L Q S S 2.279995 1.731056 5.642571 6.295696 0.089178 0.145349 0.117264
131405 362 361 L L L N H Q 2.279995 2.111470 7.408498 10.722980 0.088659 0.240463 0.164561

131406 rows × 15 columns

Before comparing single- to double-masking results, let’s see the reconstruction accuracy5 for the protein. Relatedly, we can look at the average site perplexity to get an indication of the model’s confidence.

5 Reconstruction accuracy measures how well a model predicts masked amino acids. It’s calculated as the percentage of masked positions where the model’s most likely predicted amino acid matches the original amino acid that was masked.

Code
import arcadia_pycolor as apc
import matplotlib.pyplot as plt
import pandas as pd
from analysis.plotting import blue_palette_6

apc.mpl.setup()


def get_recon_acc(df: pd.DataFrame) -> float:
    matches = df["most_probable_i_i"] == df["amino_acid_i"]
    return 100 * matches.sum() / len(matches)


def avg_perplexity(df: pd.DataFrame) -> float:
    return df.perplex_i_i.mean()


# Prepare model names and data
model_names = [model.name for model in ModelName]
recon_acc_values = [get_recon_acc(pairwise_data[model]) for model in ModelName]
perplexity_values = [avg_perplexity(pairwise_data[model]) for model in ModelName]

fig, (ax1, ax2) = plt.subplots(1, 2, figsize=(12, 5.25))

# Bar plot for Reconstruction Accuracy
ax1.bar(model_names, recon_acc_values, color=blue_palette_6.colors)
ax1.set_xlabel("Model")
ax1.set_ylabel("Reconstruction Accuracy (%)")
ax1.set_ylim(0)

# Bar plot for Perplexity
ax2.bar(model_names, perplexity_values, color=blue_palette_6.colors)
ax2.set_xlabel("Model")
ax2.set_ylabel("<Perplexity>")
ax2.set_ylim(0)

plt.setp(ax1.get_xticklabels(), rotation=20, ha="center")
plt.setp(ax2.get_xticklabels(), rotation=20, ha="center")

apc.mpl.style_plot(ax1, monospaced_axes="y")
apc.mpl.style_plot(ax2, monospaced_axes="y")

plt.tight_layout()
plt.show()
Figure 1: Single-masked token reconstruction accuracy and confidence for all ESM2 model sizes. (left) The percentage of positions for which, when masked, the model predicts the original amino acid as most likely. (right) The perplexity of \(p(i \vert \lbrace i \rbrace)\) averaged over sites.

The effects of single- versus double-masking

In the single-masking case, the model has access to all context except the masked token \(i\). However, in double-masking, the model loses access to tokens \(i\) and \(j\), reducing the available context for making predictions. This reduction in context should lead to greater uncertainty in the model’s predictions. We can quantify this uncertainty using perplexity. We hypothesize that \(p(i | \lbrace i,j \rbrace)\) will show higher perplexity than \(p(i | \lbrace i \rbrace)\) because (1) the model has less contextual information to work with and (2) the presence of multiple masks creates more ambiguity about the relationships between the masked tokens.

Code
import matplotlib.pyplot as plt
import numpy as np

x_min = float("inf")
x_max = float("-inf")
for model in ModelName:
    df = pairwise_data[model]
    diff = df["perplex_i_ij"] - df["perplex_i_i"]
    x_min = min(x_min, diff.min())
    x_max = max(x_max, diff.max())

for model, color in zip(ModelName, blue_palette_6.colors, strict=True):
    fig, ax = plt.subplots(figsize=(6, 5.25))

    df = pairwise_data[model]
    diff = df["perplex_i_ij"] - df["perplex_i_i"]

    ax.hist(diff, bins=100, edgecolor="black", color=color)
    ax.set_title(f"Average shift: +{diff.mean():.3f}")
    ax.set_xlabel(r"$ PP_{i| \lbrace i, j \rbrace} - PP_{i| \lbrace i \rbrace} $")
    ax.set_yscale("log")
    ax.set_ylabel("Count")

    ax.set_xlim(x_min, x_max)

    plt.tight_layout()
    plt.show()
(a) The ESM2 8M-parameter model.
(b) The ESM2 35M-parameter model.
(c) The ESM2 150M-parameter model.
(d) The ESM2 650M-parameter model.
(e) The ESM2 3B-parameter model.
(f) The ESM2 15B-parameter model.
Figure 2: Histograms of the difference in perplexity between \(p(i | \lbrace i, j \rbrace)\) and \(p(i | \lbrace i \rbrace)\) for all residue pairs.

While the results show a shift towards higher perplexity when double-masking, the average effect is very slight. These data suggest that, on average, the amino acid prediction probabilities at a masked position are nearly independent of whether a second position is also masked, especially for the larger models.

However, these averages might mask important position-specific dependencies. Even if masking residue \(j\) typically has little effect on predicting residue \(i\), there could be specific pairs of positions where masking \(j\) substantially impacts the prediction of residue \(i\). To identify how prevalent such position pairs are, we can examine a more targeted metric: for each position \(i\), what’s the maximum increase in perplexity caused by masking any other position \(j\)? This analysis will help reveal whether certain amino acid positions have strong dependencies that are hidden when looking at average effects.

Code
import matplotlib.pyplot as plt
import numpy as np

for model, color in zip(ModelName, blue_palette_6.colors, strict=True):
    fig, ax = plt.subplots(figsize=(6, 5.25))

    df = pairwise_data[model]
    df["diff"] = df["perplex_i_ij"] - df["perplex_i_i"]
    diff = df.groupby("position_i")["diff"].max()

    ax.hist(diff, bins=np.linspace(0, 16, 50), edgecolor="black", color=color)
    ax.set_title(f"Average max shift: +{diff.mean():.3f}")
    ax.set_xlabel(r"$ \max_j(PP_{i| \lbrace i, j \rbrace} - PP_{i| \lbrace i \rbrace}) $")
    ax.set_yscale("log")
    ax.set_ylabel("Count")
    ax.set_xlim(0)

    apc.mpl.style_plot(ax, monospaced_axes="both")
    plt.tight_layout()
    plt.show()
(a) The ESM2 8M-parameter model.
(b) The ESM2 35M-parameter model.
(c) The ESM2 150M-parameter model.
(d) The ESM2 650M-parameter model.
(e) The ESM2 3B-parameter model.
(f) The ESM2 15B-parameter model.
Figure 3: Histograms of the max difference in perplexity between \(p(i | \lbrace i, j \rbrace)\) and \(p(i | \lbrace i \rbrace)\). Compared to Figure 2, only the position \(j\) that causes the largest increase in the perplexity of \(p(i | \lbrace i, j \rbrace)\) for a given \(i\) is included.

These position-specific maximum shifts show that while random position pairs, on average, share little dependence for any given position \(i\), there tends to exist at least one position \(j\) that noticeably impacts the model’s confidence in predicting \(i\). Double-masking can significantly perturb perplexity even in the 15B parameter model, which has an average perplexity of less than two. To understand the magnitude of this effect in practical terms, let’s see whether \(p(i | \lbrace i \rbrace)\) and \(p(i | \lbrace i,j \rbrace)\) ever predict different amino acids as most probable, as that would indicate positions where context truly alters the model’s understanding of the protein sequence, rather than just lowering its confidence.

Code
import matplotlib.pyplot as plt
import numpy as np

aa_changes = []
for model in ModelName:
    df = pairwise_data[model]
    # Flag pairs where double-masking changes the most probable amino acid at position i.
    df["prediction_changed"] = df["most_probable_i_i"] != df["most_probable_i_ij"]
    aa_change = (df.groupby("position_i")["prediction_changed"].any().sum() / len(sequence)) * 100
    aa_changes.append(aa_change)

plt.figure(figsize=(6, 5.25))
plt.bar(
    [model.name for model in ModelName],
    aa_changes,
    color=blue_palette_6.colors,
)
plt.xlabel("Model")
plt.ylabel("% of positions affected")

plt.xticks(rotation=20)

apc.mpl.style_plot(monospaced_axes="y")
plt.show()
Figure 4: Frequency that double-masking changes the predicted amino acid. Each bar shows the percentage of positions where the most probable amino acid differs between the distributions \(p(i | \lbrace i \rbrace)\) and \(p(i | \lbrace i,j \rbrace)\) for a given model.

In the largest model, more than 10% of positions have at least one partner position whose masking changes the amino acid predicted as most likely. While this reveals clear cases of strong positional dependence, it only captures the most extreme effects—cases where context actually shifts the model’s top prediction. To capture more subtle interactions, where masking position \(j\) meaningfully shifts the probability distribution at position \(i\) without necessarily changing the most likely amino acid, we finally arrive at Jensen–Shannon (JS) divergence.

Visualizing Jensen–Shannon divergence

We can effectively visualize the relationships between all pairs of positions using heatmaps, where each cell \((i,j)\) is colored by the average of the JS divergence between \(p(i|\lbrace i \rbrace)\) and \(p(i| \lbrace i,j \rbrace)\) and the JS divergence between \(p(j|\lbrace j \rbrace)\) and \(p(j| \lbrace i,j \rbrace)\).

With that in mind, below is a series of heatmaps of this average Jensen–Shannon divergence with increasing model size.

Code
import analysis.plotting as plotting


def get_js_div_matrix(df: pd.DataFrame) -> np.ndarray:
    indices = sorted(set(df["position_i"]).union(df["position_j"]))
    matrix = pd.DataFrame(index=indices, columns=indices)  # type: ignore

    for _, row in df.iterrows():
        matrix.at[row.position_i, row.position_j] = row.js_div_avg
        matrix.at[row.position_j, row.position_i] = row.js_div_avg

    np.fill_diagonal(matrix.values, np.nan)

    matrix_values = matrix.to_numpy()
    return matrix_values.astype(np.float32)


for model in ModelName:
    size = 700
    fig = plotting.visualize_js_div_matrix(get_js_div_matrix(pairwise_data[model]))
    fig.update_layout(
        width=size,
        height=size * 80 / 100,
    )
    fig.show()
(a) The ESM2 8M-parameter model.
(b) The ESM2 35M-parameter model.
(c) The ESM2 150M-parameter model.
(d) The ESM2 650M-parameter model.
(e) The ESM2 3B-parameter model.
(f) The ESM2 15B-parameter model.
Figure 5: Heatmap of average JS divergence for each model size. Each cell is colored based on the average of \(JS(p(i \vert \lbrace i \rbrace), p(i \vert \lbrace i,j \rbrace))\) and \(JS(p(j \vert \lbrace j \rbrace), p(j \vert \lbrace i,j \rbrace))\).

Comparison to 3D contact map

A striking pattern emerges, and interestingly, it’s most prominent for the intermediately sized 150M-parameter model. It looks like a structural contact map, so let’s load up the structure, calculate the pairwise residue distances, and compare.

We’ll plot the JS divergence for the 150M model in the upper left triangular region of the heatmap and the contact map in the lower right triangular region.

Code
import biotite.structure as struc
from biotite.structure.io import load_structure

# Use alpha carbons
array = load_structure("input/P00813.pdb")
array = array[array.atom_name == "CA"]

distance_map = np.zeros((len(array), len(array)))
for i in range(len(array)):
    distance_map[i, :] = struc.distance(array[i], array)

# Define the contact map as the inverse of the distances
contact_map = 1 / distance_map

m150_divergence = get_js_div_matrix(pairwise_data[ModelName.ESM2_150M])
plotting.compare_to_contact_map(m150_divergence, contact_map)
Figure 6: Comparison between JS divergence and contact map. (upper-left) The heatmap values in Figure 5 for the 150M-parameter model. (bottom-right) The contact map, calculated as the inverse distance between alpha-carbon atoms for each pair of residues and expressed in units of \(Å^{-1}\).

This illustrates that some pairwise epistasis is being captured, whereby masking one residue affects the probability distribution of a second masked residue, and that the strength of this effect inversely correlates with the physical distance between the residues. Given that the attention maps learned in the ESM models recapitulate protein contact maps (Lin et al., 2023), seeing this propagate to the logits isn’t too surprising. However, unlike the attention maps, which recapitulate contact maps increasingly well with increasing model size, the correlation with the contact map is strongest at intermediate model sizes.
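If one wanted to quantify this visual agreement, a rank correlation between the off-diagonal JS divergences and the inverse distances would be one option. Here’s a sketch (we didn’t compute this for the pub; it’s shown only to illustrate the idea and assumes the structure resolves one alpha carbon per sequence position):

import numpy as np
from scipy.stats import spearmanr


def js_contact_correlation(js_matrix: np.ndarray, contacts: np.ndarray) -> float:
    """Spearman correlation between JS divergence and inverse residue distance,
    computed over the off-diagonal upper triangle."""
    iu = np.triu_indices_from(js_matrix, k=1)
    js_values, contact_values = js_matrix[iu], contacts[iu]
    valid = np.isfinite(js_values) & np.isfinite(contact_values)
    rho, _ = spearmanr(js_values[valid], contact_values[valid])
    return rho


# For example: js_contact_correlation(m150_divergence, contact_map)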

To be sure, here’s the same plot for the largest, most capable ESM2 model.

Code
b15_divergence = get_js_div_matrix(pairwise_data[ModelName.ESM2_15B])
plotting.compare_to_contact_map(b15_divergence, contact_map, js_div_zmax=0.05)
Figure 7: Comparison between JS divergence and contact map. (upper-left) The heatmap values in Figure 5 for the 15B-parameter model. (bottom-right) The contact map, calculated as the inverse distance between alpha-carbon atoms for each pair of residues and expressed in units of \(Å^{-1}\).

Conclusion

The contrast between these two models is quite interesting. The 15B-parameter model has a 90% single-mask reconstruction accuracy, whereas the 150M-parameter model’s reconstruction accuracy is only 50%. Given this, we expected that the 150M-parameter model would struggle to capture biologically interpretable residue-pair dependencies and that the 15B-parameter model might provide more interpretable answers. However, the above results show the opposite.

Your feedback requested!

We’re having trouble interpreting this result and are seeking your opinion. So our question to you is:

Why do you think the contact map pattern diminishes in the JS divergence for the larger models?

Please leave your comments below, or consider helping us improve this pub with your ideas.

References

Devlin J, Chang M-W, Lee K, Toutanova K. (2019). BERT: Pre-training of Deep Bidirectional Transformers for Language Understanding. https://doi.org/10.48550/arXiv.1810.04805
Lin Z, Akin H, Rao R, Hie B, Zhu Z, Lu W, Smetanin N, Verkuil R, Kabeli O, Shmueli Y, Santos Costa A dos, Fazel-Zarandi M, Sercu T, Candido S, Rives A. (2023). Evolutionary-scale prediction of atomic-level protein structure with a language model. https://doi.org/10.1126/science.ade2574
Salazar J, Liang D, Nguyen TQ, Kirchhoff K. (2020). Masked Language Model Scoring. https://doi.org/10.18653/v1/2020.acl-main.240