A quantitative-genetic decomposition of a neural network

Published

October 30, 2025

Summary

We tested equivalent linear mapping (ELM) on a neural network trained to predict phenotypes from genotypes in simulated data. We show that ELM successfully recapitulates additive and epistatic effects learned by the model, even in data with substantial environmental noise.

We used Claude (Opus 4.1, Sonnet 4.5) to help write code and to review our code, selectively incorporating its feedback. We also used ChatGPT (GPT-5) to review our code and selectively incorporated its feedback.

Purpose

Understanding the behaviour of deep learning (DL) models when applied to genotype-phenotype mapping has been a recent focus at Arcadia (Sandler and York, 2025; York et al., 2025). To date, our efforts have largely focused on exploring the predictive performance of various types of models (but see York and Mets, 2025). In this notebook, however, we extend our work to understanding how we might extract functional information from trained DL models. We apply a recently described method for obtaining equivalent linear representations of DL models to assess feature importance for each input sample. We show that, under certain conditions, the distribution of feature importance values across samples can be used to infer familiar quantitative genetic parameters such as additive effect sizes, epistatic interactions, and genetic variance components, opening the door to using DL models for tasks such as genomic target identification and interaction mapping.

Introduction

DL models have been increasingly applied to predicting phenotypes from genotypes (Rijal et al., 2025; Sigurdsson et al., 2023; Zeng et al., 2021), with the hope that their ability to model nonlinear interactions such as epistasis might improve our ability to capture genotype-phenotype (G-P) maps. However, the “black box” nature of DL models has made it challenging to extract meaningful biological insights from them, even when they capture useful interactions. This hampers our ability to describe novel biology using deep learning. Recent advances in mechanistic interpretability may be changing this. A variety of new techniques now allow for the evaluation of input feature importance, neuron feature specialization, and model steering (Adams et al., 2025; Bricken et al., 2023; Elhage et al., 2021) in otherwise black box models. With these tools, inferring scientific insights by looking under the hood of DL models has become increasingly feasible.

An especially interesting approach allows DL models to be studied as linear systems by computing the Jacobian of a trained model conditioned on a given input sample (Golden, 2025). This sample-specific Jacobian represents a local, ‘linearized’ version of the model that produces output values almost identical to the predictions of the original model. We refer to this approach as equivalent linear mapping (ELM). ELM retains the full expressive power of DL models, while the per-input Jacobian allows the model to be interpreted as a collection of linear systems.

Since quantitative genetics models are also fundamentally systems of linear equations, we reasoned that ELM could be useful for extracting quantitative genetic parameters from DL models trained on genomic prediction tasks. ELM therefore potentially lets us combine the ability of DL models to implicitly learn higher-order interactions with the interpretability of linear quantitative genetics models.

In this notebook, we apply ELM to a simple two-layer neural network trained on simulated genotype-phenotype (G-P) data. We use these analyses as a proof of concept of ELM in a context where a neural network has the best chance to learn ‘ground-truth’ G-P mappings and consider how we might apply this method to real biological data.

The dataset

This notebook uses data we previously generated in Sandler and York (2025), where we investigated the scaling behaviour of neural networks in genotype-phenotype mapping tasks. We focus our analysis on a single simulation replicate from this study. It includes 10,000 haploid individuals, each with a genome of 64 loci underlying 5 phenotypes (labelled as ‘traits’ in the simulations). These phenotypes range from additive (\(\frac{V_A}{V_G} = 1\)) to highly epistatic (\(\frac{V_A}{V_G} = 0.12\)). This dataset represents a clean ‘best-case’ scenario for the application of neural networks to G-P mapping. This replicate has:

1) no environmental noise in phenotypes,
2) no linkage disequilibrium among genotypes,
3) strictly 50/50 allele frequencies, and
4) enough data for the model to learn to predict almost all phenotypic variation in withheld data.

While this may seem like an excessively easy case, we suspect many of these conditions can be violated or managed in strategic ways to apply this method to real-world data (see Discussion and Appendix). We include the raw simulation data in the GitHub repo associated with this notebook pub, but alternative simulation scenarios from our previous study can also be found here.
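Two of the ‘best-case’ conditions, allele frequencies and linkage disequilibrium, can be checked directly on any genotype matrix. The sketch below uses randomly generated genotypes as a hypothetical stand-in for the simulated data; it is not the AlphaSimR output. It computes the frequency of the +1 allele at each locus, and it summarizes linkage disequilibrium as the squared correlation (r²) between every pair of loci.

```python
import numpy as np

# Hypothetical stand-in for the simulated genotypes (not the AlphaSimR data):
# independent loci coded -1/1, with the +1 allele drawn at frequency 0.5
rng = np.random.default_rng(3)
n_ind, n_loci = 10_000, 64
genotypes = rng.choice([-1.0, 1.0], size=(n_ind, n_loci))

# Frequency p of the +1 allele at each locus; with -1/1 coding, mean(x) = p - q = 2p - 1
p = (genotypes.mean(axis=0) + 1) / 2

# Linkage disequilibrium summarized as squared correlation (r^2) between every pair of loci
r2 = np.corrcoef(genotypes, rowvar=False) ** 2
off_diag = r2[~np.eye(n_loci, dtype=bool)]

print(f"allele frequency range: {p.min():.3f} to {p.max():.3f}")
print(f"max pairwise r^2: {off_diag.max():.4f}")
```

On real data, allele frequencies far from 0.5 or large off-diagonal r² values would indicate that the corresponding best-case condition does not hold.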

Imports and constants
import os
from collections import defaultdict
from pathlib import Path
from typing import cast

import arcadia_pycolor as apc
import h5py
import matplotlib as mpl
import matplotlib.pyplot as plt
import numpy as np
import pandas as pd
import torch
import torch.nn as nn
import torch.nn.functional as F
import torch.optim as optim
from matplotlib.ticker import FormatStrFormatter, MaxNLocator, ScalarFormatter
from numpy.polynomial.polynomial import polyfit
from scipy.stats import pearsonr
from sklearn.metrics import r2_score
from torch.utils.data import Dataset

np.random.seed(42)
torch.manual_seed(42)
torch.cuda.manual_seed_all(42)

apc.mpl.setup()


plot_style = {
    "figure.facecolor": apc.parchment,
    "axes.facecolor": apc.parchment,
    "axes.titleweight": "bold",  # Make titles bold
    "axes.labelweight": "bold",  # Make axis labels bold
    "axes.spines.top": False,  # Remove top spine
    "axes.spines.right": False,  # Remove right spine
}

os.environ["MKL_VERBOSE"] = "0"

device = torch.device("cuda" if torch.cuda.is_available() else "cpu")

# model params
hidden_size = 1024
learning_rate = 0.001
EPS = 1e-15
regularization = 1

# data params
sample_size = 10000
qtl_n = 64
rep = 1
n_alleles = 2
heritability = 1

input_data_dir = Path("input_data")
sim_name = f"qhaplo_{qtl_n}qtl_{sample_size}n_rep{rep}"
base_file_name = str(input_data_dir / f"{sim_name}_")
# ground-truth additive and epistatic effect sizes for this simulation replicate
true_eff = pd.read_csv(input_data_dir / f"{sim_name}_eff.txt", delimiter=" ")
Define datasets and dataloaders
# base class for metadata
class BaseDataset(Dataset):
    """
    Base dataset class for loading data from HDF5 files.

    Args:
        hdf5_path: Path to the HDF5 file containing the data
    """

    def __init__(self, hdf5_path: Path) -> None:
        self.hdf5_path = hdf5_path
        self.h5 = None
        self._strain_group = None
        self.strains = None
        # Open temporarily to get keys and length for initialization
        with h5py.File(self.hdf5_path, "r") as temp_h5:
            temp_strain_group = cast(h5py.Group, temp_h5["strains"])
            self._strain_keys: list[str] = list(temp_strain_group.keys())
            self._len = len(temp_strain_group)

    def _init_h5(self):
        if self.h5 is None:
            self.h5 = h5py.File(self.hdf5_path, "r")
            self._strain_group = cast(h5py.Group, self.h5["strains"])
            self.strains = self._strain_keys

    def __len__(self) -> int:
        return self._len


# adds geno and pheno data
class GenoPhenoDataset(BaseDataset):
    """
    Dataset for loading both genotype and phenotype data.

    Returns tuples of (phenotype, genotype) tensors.

    Args:
        hdf5_path: Path to the HDF5 file containing the data
        encode_minus_plus_one: If True, recode genotypes from 0/1 to -1/1
    """

    def __init__(self, hdf5_path: Path, encode_minus_plus_one: bool = True) -> None:
        super().__init__(hdf5_path)
        self.encode_minus_plus_one = encode_minus_plus_one

    def __getitem__(self, idx: int) -> tuple[torch.Tensor, torch.Tensor]:
        self._init_h5()
        strain = self.strains[idx]
        strain_data = cast(h5py.Group, self._strain_group[strain])

        phens = torch.tensor(strain_data["phenotype"][:], dtype=torch.float32)
        gens = torch.tensor(strain_data["genotype"][:], dtype=torch.float32).flatten()

        # Recode genotypes from 0/1 to -1/1 to match true sims GP model
        if self.encode_minus_plus_one:
            gens = 2 * gens - 1

        return phens, gens


# adds ability to adjust heritability of phenotypes by adding noise
class HeritabilityAdjustedDataset(torch.utils.data.Dataset):
    """
    Dataset wrapper that adjusts phenotypes to achieve desired heritability
    while maintaining a total variance of 1.

    Note: Assumes input phenotypes are normalized to variance = 1.


    Args:
        base_dataset: Original dataset providing (phenotype, genotype) tuples
        target_heritability: Desired broad-sense heritability (0-1)
        seed: Random seed for reproducibility
    """

    def __init__(self, base_dataset, target_heritability=1.0, seed=42):
        self.base_dataset = base_dataset
        self.target_heritability = target_heritability

        if target_heritability < 1.0:
            # Calculate genetic scaling factor to achieve target heritability
            # while maintaining total variance = 1
            self.genetic_scale = np.sqrt(target_heritability)

            # Calculate noise variance to make total variance = 1
            self.noise_variance = 1 - target_heritability

            # Pre-generate all noise patterns
            rng = np.random.RandomState(seed)
            self.noise_patterns = []

            # Get shape of phenotypes by looking at first item
            sample_phens, _ = base_dataset[0]
            phen_shape = sample_phens.shape

            # Generate noise for each sample
            for _ in range(len(base_dataset)):
                noise = torch.tensor(
                    rng.normal(0, np.sqrt(self.noise_variance), size=phen_shape),
                    dtype=torch.float32,
                )
                self.noise_patterns.append(noise)
        else:
            # No adjustment needed if heritability = 1
            self.genetic_scale = 1.0
            self.noise_patterns = None

    def __getitem__(self, idx):
        phens, gens = self.base_dataset[idx]

        if self.target_heritability < 1.0:
            # Scale down genetic component
            scaled_phens = phens * self.genetic_scale

            # Add noise component
            adjusted_phens = scaled_phens + self.noise_patterns[idx]
        else:
            adjusted_phens = phens

        return adjusted_phens, gens

    def __len__(self):
        return len(self.base_dataset)


# dataloaders for pytorch neural network
def create_data_loaders_with_heritability(
    base_file_name,
    heritability=1.0,
    batch_size=128,
    num_workers=0,
    shuffle=True,
    seed=42,
    encode_minus_plus_one=True,
):
    """
    Create DataLoaders with adjusted heritability.

    Args:
        base_file_name: Base path for HDF5 files
        heritability: Target broad-sense heritability (0-1)
        batch_size: Batch size for DataLoaders
        num_workers: Number of worker processes for DataLoaders
        shuffle: Whether to shuffle the data
        seed: Random seed for reproducibility
        encode_minus_plus_one: If True, encode genotypes as -1/1 instead of 0/1

    Returns:
        Dictionary containing DataLoaders with adjusted heritability
    """
    # Create base datasets
    train_data_gp = GenoPhenoDataset(
        Path(f"{base_file_name}train.hdf5"), encode_minus_plus_one=encode_minus_plus_one
    )
    test_data_gp = GenoPhenoDataset(
        Path(f"{base_file_name}test.hdf5"), encode_minus_plus_one=encode_minus_plus_one
    )

    # Wrap with heritability adjustment
    train_data_gp_adjusted = HeritabilityAdjustedDataset(train_data_gp, heritability, seed)
    test_data_gp_adjusted = HeritabilityAdjustedDataset(test_data_gp, heritability, seed)

    # Create DataLoaders
    train_loader_gp = torch.utils.data.DataLoader(
        dataset=train_data_gp_adjusted,
        batch_size=batch_size,
        num_workers=num_workers,
        shuffle=shuffle,
    )
    test_loader_gp = torch.utils.data.DataLoader(
        dataset=test_data_gp_adjusted,
        batch_size=batch_size,
        num_workers=num_workers,
        shuffle=shuffle,
    )

    return {
        "train_loader_gp": train_loader_gp,
        "test_loader_gp": test_loader_gp,
        "test_data_gp": test_data_gp_adjusted,
    }
Initialize loaders for model training
loaders = create_data_loaders_with_heritability(
    base_file_name, heritability=heritability
)  # choose heritability
train_loader_gp = loaders["train_loader_gp"]
test_loader_gp = loaders["test_loader_gp"]
Extract metadata from the HDF5 file
# Automatically extract n_phen and n_loci from the input file
with h5py.File(f"{base_file_name}train.hdf5", "r") as f:
    # Get number of phenotypes
    n_phen = len(f["metadata"]["phenotype_names"])
    print(f"Number of phenotypes: {n_phen}")

    # Get number of loci
    n_loci = len(f["metadata"]["loci"])
    print(f"Number of loci: {n_loci}")

    # Optional: Print phenotype names
    phenotype_names = [name.decode("utf-8") for name in f["metadata"]["phenotype_names"]]
    print(f"Phenotype names: {phenotype_names}")
Number of phenotypes: 5
Number of loci: 64
Phenotype names: ['Trait1', 'Trait2', 'Trait3', 'Trait4', 'Trait5']

Equivalent linear mapping

Building on previous work (Mohan et al., 2020; Wang et al., 2016), Golden (2025) demonstrated that computing gradients with respect to an input sequence allows a large language model’s output to be reconstructed from the model’s Jacobian, such that: \[ \hat{y}(x) = J(x)\cdot x \] where \(J(x) = \frac{\partial \hat{y}}{\partial x}\) is evaluated at the input \(x\).

This Jacobian reconstruction usually holds only in a narrow neighborhood around any given input (the specific linear piece on which the input resides). Even so, it allows us to assess the relative importance of input features for the prediction of any given input sample. For example, we can compare the magnitudes of the partial derivatives of a phenotype with respect to each input locus. We refer to these as locus sensitivity values, and comparing them lets us estimate the relative importance of loci to phenotype prediction.

How does this work in practice? To understand how ELM will behave when applied to our simulated data, we first need to consider how our simulated phenotypes are generated. We used AlphaSimR (v1.6.1) (Gaynor et al., 2021) to generate phenotypes with both additive and epistatic effects using the following quantitative genetic G-P model:

\[ P = \sum_{i=1}^{n} \beta_i x_i + \sum_{i=1}^{n} \sum_{j > i}^{n} \beta_{ij} x_i x_j \tag{1}\]

Here, \(P\) is the value of a quantitative phenotype, \(x_i\) is the allelic state of a locus (taking values -1/1), and \(n\) is the number of loci in the genome. The additive and epistatic locus effects, \(\beta_i\) and \(\beta_{ij}\), are drawn from separate Gaussian distributions. Each locus participates in exactly one epistatic interaction, giving us 32 epistatic pairs for our 64 loci. This simplifies the epistatic sum in Equation 1 to just one non-zero \(\beta_{ij}\) term per locus. The five phenotypes we simulate are completely independent of one another.

Suppose we compute the Jacobian of a model trained on these data. For each individual and each of their phenotypes, we get a vector of sensitivity values. Each value is the partial derivative of the predicted phenotype \(\hat{P}\) with respect to one input locus \(x_i\), and together they reconstruct the prediction: \[ \hat{P} = \frac{\partial \hat{P}}{\partial x} \cdot x \]

Assume the model learns to explain most of the genetic variance in a phenotype from the true allelic state of each locus. The sensitivity of any given locus should then capture a mixture of additive and epistatic effects as follows: \[ \frac{\partial P}{\partial x_i} = \beta_i + \beta_{ij} x_j \tag{2}\]

Here we have simplified the epistatic summation to the single locus \(x_j\) with which the focal locus \(x_i\) interacts. Further down, we show how computing \(\frac{\partial P}{\partial x_i}\) across a population of test samples can be used to estimate \(\beta_i\) and \(\beta_{ij}\).
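A small toy model makes Equation 2 concrete. The sketch below builds an instance of Equation 1 with 8 loci and arbitrary Gaussian effects. It uses a hypothetical pairing scheme in which locus 0 interacts with locus 1, locus 2 with locus 3, and so on; none of these choices are taken from the AlphaSimR simulation. The sketch then checks that the analytic sensitivities from Equation 2 match finite-difference derivatives of the phenotype.

```python
import numpy as np

# Toy instance of Equation 1 (illustrative only, not the AlphaSimR simulation):
# 8 loci coded -1/1 and a hypothetical pairing in which locus 0 interacts with 1, 2 with 3, ...
rng = np.random.default_rng(0)
n_loci = 8
beta_add = rng.normal(size=n_loci)                 # additive effects beta_i
pairs = [(i, i + 1) for i in range(0, n_loci, 2)]  # each locus in exactly one pair
beta_epi = rng.normal(size=len(pairs))             # epistatic effects beta_ij, one per pair

# For each locus, record its interaction partner j and the shared beta_ij
partner = np.empty(n_loci, dtype=int)
beta_pair = np.empty(n_loci)
for k, (i, j) in enumerate(pairs):
    partner[i], partner[j] = j, i
    beta_pair[i] = beta_pair[j] = beta_epi[k]


def phenotype(x):
    """Equation 1, with each interacting pair counted once."""
    epistatic = sum(b * x[i] * x[j] for b, (i, j) in zip(beta_epi, pairs))
    return beta_add @ x + epistatic


def analytic_sensitivity(x):
    """Equation 2: dP/dx_i = beta_i + beta_ij * x_j for every locus i."""
    return beta_add + beta_pair * x[partner]


def numeric_sensitivity(x, h=1e-3):
    """Central finite differences; exact up to rounding because P is linear in each x_i."""
    grad = np.empty(n_loci)
    for i in range(n_loci):
        step = np.zeros(n_loci)
        step[i] = h
        grad[i] = (phenotype(x + step) - phenotype(x - step)) / (2 * h)
    return grad


x = rng.choice([-1.0, 1.0], size=n_loci)
print(np.allclose(analytic_sensitivity(x), numeric_sensitivity(x)))  # True
```

Because each locus has a single partner, the sensitivity of an epistatic locus takes one of two values across a population, \(\beta_i + \beta_{ij}\) or \(\beta_i - \beta_{ij}\), depending on the allele at \(x_j\).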

Model fitting

Before applying ELM, we first have to train a genomic prediction model. As in our previous pub, we fit a simple 2-layer multilayer perceptron (MLP) to our simulated data using an 85/15 train-test split. However, we have to make one minor modification to allow us to apply ELM to this model: we remove all bias terms. A bias adds a constant offset that does not scale with the input, so the product of the Jacobian and the input cannot capture it. Keeping biases would therefore prevent ELM from accurately reproducing model outputs.¹ In our experience, removing biases has a negligible impact on model performance.

¹ Some activation functions are not positively homogeneous of degree one (e.g., SwiGLU, softmax). For ELM to hold with these, their nonlinear terms would also need to be detached from the computational graph during Jacobian computation (as in Golden (2025)). Our model uses only Leaky ReLU activations, which are positively homogeneous of degree one, i.e., \(f(ax) = a f(x)\) for \(a > 0\). It also has no bias terms. As a result, the standard Jacobian computed via automatic differentiation already provides exact reconstruction.
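The bias and homogeneity arguments above can be checked numerically. The sketch below is a NumPy-only, hypothetical stand-in for GP_net: it uses random (untrained) weights and smaller hidden layers. On the linear piece containing x, each Leaky ReLU acts as a diagonal matrix of slopes (1 or leak), so the Jacobian has the closed form \(J = W_3 D_2 W_2 D_1 W_1\). The sketch confirms that \(J(x) \cdot x\) reproduces the output of the bias-free network. It also shows that adding a first-layer bias leaves an offset that \(J(x) \cdot x\) does not capture.

```python
import numpy as np

# Hypothetical NumPy stand-in for GP_net: random (untrained) weights, smaller hidden layers
rng = np.random.default_rng(1)
n_loci, hidden, n_pheno, leak = 64, 32, 5, 0.01
W1 = rng.normal(size=(hidden, n_loci)) / np.sqrt(n_loci)
W2 = rng.normal(size=(hidden, hidden)) / np.sqrt(hidden)
W3 = rng.normal(size=(n_pheno, hidden)) / np.sqrt(hidden)
b1 = rng.normal(size=hidden)  # first-layer bias, used only in the biased variant


def leaky_relu(z):
    return np.where(z > 0, z, leak * z)


def leaky_relu_slope(z):
    return np.where(z > 0, 1.0, leak)


def forward_and_jacobian(x, bias=None):
    """Output y and exact Jacobian dy/dx of a two-hidden-layer Leaky ReLU MLP."""
    z1 = W1 @ x if bias is None else W1 @ x + bias
    z2 = W2 @ leaky_relu(z1)
    y = W3 @ leaky_relu(z2)
    # On the linear piece containing x, each activation acts as a diagonal matrix
    # of slopes (1 or leak), so J = W3 @ D2 @ W2 @ D1 @ W1
    D1 = np.diag(leaky_relu_slope(z1))
    D2 = np.diag(leaky_relu_slope(z2))
    return y, W3 @ D2 @ W2 @ D1 @ W1


x = rng.choice([-1.0, 1.0], size=n_loci)

y, J = forward_and_jacobian(x)
print(np.allclose(y, J @ x))  # True: without biases, y = J(x) x exactly

y_b, J_b = forward_and_jacobian(x, bias=b1)
print(np.abs(y_b - J_b @ x).max())  # nonzero: the bias offset is missed by J(x) x
```

GP_net has the same structure, with 1,024 units per hidden layer. In the notebook its Jacobian is computed with PyTorch autograd rather than in closed form.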

With this modified MLP setup, we can train our neural network and verify its performance on withheld (test-set) data. Our training loop uses learning rate scheduling and early stopping to help achieve a good fit without overtraining.

Define model architecture
class GP_net(nn.Module):
    """
    Modified fully connected G -> P network with Leaky ReLU for Jacobian analysis

    Args:
        n_loci: number of loci (QTLs), i.e., input features
        hidden_layer_size: size of the hidden layers (2nd hidden layer only if hidden_layer1_size is set)
        hidden_layer1_size: optional different size for 1st hidden layer
        n_pheno: number of phenotypes to output/predict
        leak: slope of the negative part of Leaky ReLU (default: 0.01)
    """

    def __init__(self, n_loci, hidden_layer_size, n_pheno, hidden_layer1_size=None, leak=0.01):
        super().__init__()

        if hidden_layer1_size is None:
            hidden_layer1_size = hidden_layer_size

        # Create individual layers without bias
        self.layer1 = nn.Linear(in_features=n_loci, out_features=hidden_layer1_size, bias=False)
        self.layer2 = nn.Linear(
            in_features=hidden_layer1_size, out_features=hidden_layer_size, bias=False
        )
        self.layer3 = nn.Linear(in_features=hidden_layer_size, out_features=n_pheno, bias=False)

        # Store leak parameter
        self.leak = leak

    def forward(self, x):
        # First layer with Leaky ReLU activation
        x = self.layer1(x)
        x = F.leaky_relu(x, negative_slope=self.leak)

        # Second layer with Leaky ReLU activation
        x = self.layer2(x)
        x = F.leaky_relu(x, negative_slope=self.leak)

        # Output layer (no activation)
        x = self.layer3(x)

        return x
Define training loop
# Model initialization and training loop
def train_gpnet(
    model,
    train_loader,
    test_loader=None,
    n_loci=n_loci,
    n_alleles=2,
    max_epochs=100,  # Set a generous upper limit
    patience=10,  # Number of epochs to wait for improvement
    min_delta=0.003,  # Minimum change to count as improvement
    verbose=True,
    learning_rate=learning_rate,
    weight_decay=regularization,
    device=device,
):
    """
    Train model with early stopping to prevent overtraining
    """
    # Move model to device
    model = model.to(device)

    # Initialize optimizer with proper weight decay
    optimizer = optim.AdamW(model.parameters(), lr=learning_rate, weight_decay=weight_decay)

    # Learning rate scheduler
    scheduler = optim.lr_scheduler.ReduceLROnPlateau(optimizer, mode="min", factor=0.5, patience=3)

    history = {"train_loss": [], "test_loss": [], "epochs_trained": 0}

    # Early stopping variables
    best_loss = float("inf")
    best_epoch = 0
    best_model_state = None
    patience_counter = 0

    # Training loop
    for epoch in range(max_epochs):
        # Training
        model.train()
        train_loss = 0

        for phens, gens in train_loader:
            phens = phens.to(device)
            # keep the one-hot genotype columns (two per locus); the 0/1 -> -1/1 recoding
            # that matches the simulator ground truth already happened in GenoPhenoDataset
            gens = gens[:, : n_loci * n_alleles].to(device)
            # keep one column per locus, since the second one-hot column is redundant
            gens_simplified = gens[:, 1::2]

            # Forward pass
            optimizer.zero_grad()
            output = model(gens_simplified)

            # mean absolute error (L1) loss
            g_p_recon_loss = F.l1_loss(output + EPS, phens + EPS)

            # Backward and optimize
            g_p_recon_loss.backward()
            optimizer.step()

            train_loss += g_p_recon_loss.item()

        avg_train_loss = train_loss / len(train_loader)

        # Validation
        if test_loader is not None:
            model.eval()
            test_loss = 0

            with torch.no_grad():
                for phens, gens in test_loader:
                    phens = phens.to(device)
                    gens = gens[:, : n_loci * n_alleles]
                    gens = gens.to(device)
                    gens_simplified = gens[:, 1::2]

                    output = model(gens_simplified)
                    test_loss += F.l1_loss(output + EPS, phens + EPS).item()

            avg_test_loss = test_loss / len(test_loader)

            if (epoch + 1) % 10 == 0 and verbose:
                print(
                    f"Epoch: {epoch + 1}/{max_epochs}, Train Loss: {avg_train_loss:.6f}, "
                    f"Test Loss: {avg_test_loss:.6f}"
                )

            # Update learning rate
            scheduler.step(avg_test_loss)

            # Check for improvement
            if avg_test_loss < (best_loss - min_delta):
                best_loss = avg_test_loss
                best_epoch = epoch
                patience_counter = 0
                # Save best model state
                best_model_state = {
                    k: v.cpu().detach().clone() for k, v in model.state_dict().items()
                }
            else:
                patience_counter += 1

            # Early stopping check
            if patience_counter >= patience:
                if verbose:
                    print(f"Early stopping triggered after {epoch + 1} epochs")
                break

    # Record how many epochs were actually used
    history["epochs_trained"] = epoch + 1

    # Restore best model
    if best_model_state is not None:
        if verbose:
            print(f"Restoring best model from epoch {best_epoch + 1}")
        model.load_state_dict(best_model_state)

    return model, best_loss
Run neural network training
# expect model training to take a couple minutes to run on CPU
model = GP_net(n_loci=n_loci, hidden_layer_size=hidden_size, n_pheno=n_phen)

# Use early stopping with appropriate patience
model, best_loss_gp = train_gpnet(
    model=model,
    train_loader=train_loader_gp,
    test_loader=test_loader_gp,
    n_loci=n_loci,
    learning_rate=learning_rate,
    device=device,
)
model.eval()
Epoch: 10/100, Train Loss: 0.142251, Test Loss: 0.172582
Epoch: 20/100, Train Loss: 0.129432, Test Loss: 0.167517
Epoch: 30/100, Train Loss: 0.081292, Test Loss: 0.124460
Epoch: 40/100, Train Loss: 0.048542, Test Loss: 0.110764
Epoch: 50/100, Train Loss: 0.029345, Test Loss: 0.099643
Epoch: 60/100, Train Loss: 0.026389, Test Loss: 0.099313
Early stopping triggered after 67 epochs
Restoring best model from epoch 57
GP_net(
  (layer1): Linear(in_features=64, out_features=1024, bias=False)
  (layer2): Linear(in_features=1024, out_features=1024, bias=False)
  (layer3): Linear(in_features=1024, out_features=5, bias=False)
)
Test set prediction statistics
true_phenotypes = []
predicted_phenotypes = []

with torch.no_grad():
    for phens, gens in test_loader_gp:
        phens = phens.to(device)
        gens = gens.to(device)
        gens_simplified = gens[:, 1::2]

        # Get predictions
        predictions = model(gens_simplified)

        # Store results
        true_phenotypes.append(phens.cpu().numpy())
        predicted_phenotypes.append(predictions.cpu().numpy())

    # Concatenate batches
    true_phenotypes = np.concatenate(true_phenotypes)
    predicted_phenotypes = np.concatenate(predicted_phenotypes)

    # Calculate correlations for each phenotype
    correlations = []
    r2s = []

    for i in range(n_phen):
        corr, _ = pearsonr(true_phenotypes[:, i], predicted_phenotypes[:, i])
        correlations.append(corr)
        r2 = r2_score(true_phenotypes[:, i], predicted_phenotypes[:, i])
        r2s.append(r2)

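    # simulated V_A/V_G per phenotype; equals h² here because heritability = 1 (no environmental noise)
    true_h2 = [1, 0.787, 0.505, 0.403, 0.122]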

    # Create a detailed DataFrame with all results
    results_df_fit = pd.DataFrame(
        {
            "Phenotype": range(1, n_phen + 1),
            "Pearson correlation": correlations,
            "r²": r2s,
            "h²": true_h2,
        }
    )

results_df_fit = results_df_fit.round(3)

results_df_fit
Table 1: Test dataset prediction performance of a neural network trained on simulated genotype-phenotype data. Phenotype narrow-sense heritability (\(h^2\)) also included.

Phenotype  Pearson correlation  r²     h²
1          0.999                0.997  1.000
2          0.994                0.988  0.787
3          0.991                0.982  0.505
4          0.988                0.975  0.403
5          0.987                0.970  0.122

As the test-set metrics show (Table 1), our modified MLP learns to predict all five phenotypes almost perfectly. Phenotypes are numbered from most additive to most epistatic. Performance declines slightly as the G-P map becomes more nonlinear from phenotype 1 to phenotype 5 (test-set r² from 0.997 to 0.970).

Applying ELM

Now that we’ve trained a model, we can apply ELM. For each test-set sample, we use PyTorch (v2.2.2) autograd to compute the Jacobian of each predicted phenotype with respect to the input genotype. We can check that the Jacobian accurately captures the model’s behaviour by comparing the output of the trained model to the dot product of the Jacobian and the input genotype.

Function comparing ELM and model outputs
def plot_jacobian_vs_predictions(model, test_loader, n_phenotypes=5, max_batches=20, device=device):
    dot_products = [[] for _ in range(n_phenotypes)]
    predictions = [[] for _ in range(n_phenotypes)]

    def get_jacobian(x, idx):
        def fn(inp):
            return model(inp)[:, idx]

        return torch.autograd.functional.jacobian(fn, x, vectorize=True).squeeze()

    with torch.no_grad():
        for batch_idx, (_, gens) in enumerate(test_loader):
            if batch_idx >= max_batches:
                break

            gens = gens.to(device)
            gens_simplified = gens[:, 1::2] if gens.shape[1] > n_loci else gens

            for i in range(gens.shape[0]):
                single_gens = gens_simplified[i : i + 1]
                model_preds = model(single_gens)[0]

                for p_idx in range(n_phenotypes):
                    predictions[p_idx].append(model_preds[p_idx].item())
                    jacobian = get_jacobian(single_gens, p_idx)
                    dot_products[p_idx].append(torch.sum(jacobian * single_gens).item())

    # Scatter model output against the Jacobian reconstruction for each phenotype
    for p_idx in range(n_phenotypes):
        dots = np.array(dot_products[p_idx])
        preds = np.array(predictions[p_idx])

        with mpl.rc_context(plot_style):
            plt.figure(figsize=(6, 6))
            ax = plt.gca()

            ax.scatter(preds, dots, alpha=0.5, color=apc.steel)
            min_val = min(dots.min(), preds.min())
            max_val = max(dots.max(), preds.max())
            ax.plot([min_val, max_val], [min_val, max_val], "r--")  # 1:1 line
            ax.set_xlabel("Model output", fontsize=16)
            ax.set_ylabel("Jacobian*input", fontsize=16)
            ax.set_aspect("equal")

            ax.xaxis.set_major_locator(MaxNLocator(4))
            ax.yaxis.set_major_locator(MaxNLocator(4))
            ax.xaxis.set_major_formatter(FormatStrFormatter("%.2f"))
            ax.yaxis.set_major_formatter(FormatStrFormatter("%.2f"))

            apc.mpl.style_plot(monospaced_axes="both")
            plt.show()


plot_jacobian_vs_predictions(model, test_loader_gp)
[Figure 1 panels, scatter plots of neural network output vs. Jacobian reconstruction: (a) Phenotype 1, (b) Phenotype 2, (c) Phenotype 3, (d) Phenotype 4, (e) Phenotype 5]
Figure 1: Trained MLP model output plotted against summed product of model Jacobian and input feature values for test-set samples. Dashed line denotes 1:1 relationship. Results presented for 5 phenotypes (a-e) ranging from purely additive (a) to almost purely epistatic (e).

As expected, the product of the Jacobian and the input reproduces model outputs essentially perfectly (Figure 1). This indicates that our implementation of ELM works for our simple G-P neural network.

Estimating additive effects

Equation 2 tells us that locus-specific sensitivity values obtained from the Jacobian will reflect both additive and epistatic effects, depending on the allelic states in any given individual input sample. What about the expected locus sensitivity if we average these values across a large number of test-set individuals?

\[ \begin{align*} E\left[\frac{\partial P}{\partial x_i}\right] &= \beta_i + \beta_{ij} \cdot E[x_j] \\ &= \beta_i + \beta_{ij} (p_j - q_j) \\ &\approx \beta_i \quad \text{if } p_j \approx q_j \approx 0.5 \end{align*} \tag{3}\]

The expected locus sensitivity reflects the true underlying additive effect (\(\beta_i\)), plus an epistatic contribution (\(\beta_{ij}\)). That contribution is weighted by the difference between the frequencies of the two alleles (\(p_j\) and \(q_j\)) at the interacting locus \(x_j\). Unsurprisingly, this mirrors the quantitative genetic definition of the average effect of allelic substitution at locus \(x_i\) (i.e., the additive effect), which can statistically capture both additive and epistatic effects (Mackay, 2014). In our simulations, all allele frequencies are roughly 0.5 (so \(p_j - q_j \approx 0\)), and this equation conveniently simplifies to the true additive effect.

Can we actually back out additive effects from our neural network by looking at the average locus sensitivity values in our test set? We can easily check this, since we have ground-truth effect sizes for our simulation.
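A quick simulation can check Equation 3. The sketch below uses made-up effect sizes for a single focal locus and its partner; these are not values from our simulation. It averages the Equation 2 sensitivity over a large population at two allele frequencies for the partner locus. At \(p_j = 0.5\) the mean should recover \(\beta_i\), whereas at \(p_j = 0.8\) it should be shifted by \(\beta_{ij}(p_j - q_j)\).

```python
import numpy as np

# Illustrative effect sizes for a focal locus i and its partner locus j (hypothetical values)
rng = np.random.default_rng(2)
n_ind = 200_000
beta_i, beta_ij = 0.3, 0.8


def mean_sensitivity(p_j):
    """Population average of Equation 2, dP/dx_i = beta_i + beta_ij * x_j,
    when the +1 allele at locus j (coded -1/1) has frequency p_j."""
    x_j = np.where(rng.random(n_ind) < p_j, 1.0, -1.0)
    return np.mean(beta_i + beta_ij * x_j)


for p_j in (0.5, 0.8):
    q_j = 1 - p_j
    expected = beta_i + beta_ij * (p_j - q_j)  # Equation 3
    print(f"p_j = {p_j}: simulated {mean_sensitivity(p_j):.3f}, Equation 3 {expected:.3f}")
```

The second case shows why unequal allele frequencies change the interpretation. Average sensitivities would then estimate the average effect of allelic substitution rather than the pure additive effect \(\beta_i\).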

Estimate mean locus sensitivity
def analyze_feature_importance_across_validation(
    model, validation_loader, true_effects_df, phenotype_idx=None, device=device, max_samples=None
):
    """
    Calculate feature importance for each sample in the validation set and visualize distributions.

    Args:
        model: Trained neural network model
        validation_loader: DataLoader for validation data
        true_effects_df: DataFrame containing ground truth effects
        phenotype_idx: List of phenotype indices to analyze (1-based; default: [1], the first phenotype)
        device: Device to run calculations on
        max_samples: Maximum number of samples to process (None for all)

    Returns:
        List of DataFrames with importance values for each locus across all samples, one per phenotype
    """
    # default analyze first phenotype
    if phenotype_idx is None:
        phenotype_idx = [1]

    # Function to get importance for a specific phenotype
    def get_phenotype_importance(x, pheno_idx):
        def phenotype_fn(inp):
            return model(inp)[:, pheno_idx - 1]  # Convert 1-based to 0-based

        # Calculate Jacobian
        jacobian = torch.autograd.functional.jacobian(phenotype_fn, x, vectorize=True).squeeze()

        # Return unmodified Jacobian value
        return jacobian

    # Store importance values for each sample and each phenotype
    all_importance_values = {idx: [] for idx in phenotype_idx}
    sample_count = 0

    # Process validation samples
    with torch.no_grad():  # jacobian() enables grad internally, so sensitivities are unaffected
        for _, gens in validation_loader:
            # Process each sample in the batch
            for i in range(gens.shape[0]):
                # Check if we've reached the maximum samples
                if max_samples is not None and sample_count >= max_samples:
                    break

                # Get single sample
                single_gens = gens[i : i + 1]

                # Flatten and simplify (remove one-hot encoding)
                flattened_gens = single_gens.reshape(single_gens.shape[0], -1)
                simplified_gens = flattened_gens[:, 1::2].to(device)

                # Calculate importance for each phenotype
                for idx in phenotype_idx:
                    importance = get_phenotype_importance(simplified_gens, idx)
                    # Store the importance values
                    all_importance_values[idx].append(importance.cpu().numpy())

                sample_count += 1

                # Print progress
                # if sample_count % 1000 == 0:
                #    print(f"Processed {sample_count} samples")

            # Check again after batch
            if max_samples is not None and sample_count >= max_samples:
                break

    # print(f"Completed importance analysis for {sample_count} samples")

    # Convert to DataFrames for easier analysis
    importance_dfs = {}
    for idx in phenotype_idx:
        importance_df = pd.DataFrame(all_importance_values[idx])
        importance_df.columns = [f"Locus_{i}" for i in range(importance_df.shape[1])]
        importance_dfs[idx] = importance_df

    # Create individual plots for each phenotype
    for idx in phenotype_idx:
        with mpl.rc_context(plot_style):
            plt.figure(figsize=(6, 6))
            # Get the dataframe for this phenotype
            importance_df = importance_dfs[idx]

            # Extract ground truth for the specific phenotype (idx is already 1-based)
            true_effects = true_effects_df[true_effects_df["trait"] == idx]

            # Generate correlation plot between mean importance and true effects
            if "add_eff" in true_effects.columns:
                # Get data for scatter plot
                true_effects_values = true_effects["add_eff"] * np.sqrt(heritability)
                mean_importance = importance_df.mean().values

                # Create scatter plot
                plt.scatter(true_effects_values, mean_importance, alpha=0.7, color=apc.steel)

                # Add best fit line
                z = np.polyfit(true_effects_values, mean_importance, 1)
                p = np.poly1d(z)
                plt.plot(true_effects_values, p(true_effects_values), "r-", alpha=0.7)

                # Add coefficient of determination (agreement of mean sensitivity with the 1:1 line)
                r2 = r2_score(true_effects_values, mean_importance)
                plt.xlabel(r"Additive effect size ($\beta_i$)", fontsize=16)
                plt.ylabel("Mean Locus Sensitivity", fontsize=16)

                # plt.gca().set_aspect('equal')

                # Add slope information
                slope = z[0]  # Extract the slope from the polyfit result
                plt.text(
                    0.05,
                    0.95,
                    f"Slope = {slope:.3f}\nr² = {r2:.2f}",
                    transform=plt.gca().transAxes,
                    fontsize=15,
                    verticalalignment="top",
                    bbox=dict(
                        boxstyle="round", edgecolor="none", facecolor=apc.parchment, alpha=0.7
                    ),
                )

                # Get current axes for formatting
                ax = plt.gca()
                ax.xaxis.set_major_locator(MaxNLocator(4))
                ax.yaxis.set_major_locator(MaxNLocator(4))
                ax.xaxis.set_major_formatter(FormatStrFormatter("%.2f"))
                ax.yaxis.set_major_formatter(FormatStrFormatter("%.2f"))

                apc.mpl.style_plot(monospaced_axes="both")

            plt.show()

    # Return the DataFrames for further analysis
    return list(importance_dfs.values())


importance_results = analyze_feature_importance_across_validation(
    model=model,
    validation_loader=test_loader_gp,
    true_effects_df=true_eff,
    phenotype_idx=[1, 2, 3, 4, 5],
    device=device,
    max_samples=2000,  # (optionally) limit number of samples for faster processing
)
Figure 2 panels (a) to (e), Phenotypes 1 to 5: scatter plots of ground truth additive effect vs. mean locus sensitivity.
Figure 2: Ground truth locus additive effect size plotted against mean locus sensitivity value for test-set samples. Results are presented for 5 phenotypes (a-e), ranging from purely additive (a) to almost purely epistatic (e).

The answer is a resounding yes for all phenotypes (Figure 2). The most epistatic phenotypes show slightly more scatter. Even so, mean locus sensitivity tracks true additive effect size extremely closely in all cases. This suggests that the ELM decomposition of our MLP is backing out the underlying (additive) quantitative genetic framework used to simulate these data.

Estimating total epistatic effects

So we can back out additive effect sizes. This is a good start, but it could just as easily be achieved by fitting a linear model to these data. What's more exciting is the possibility of extracting nonlinear locus interactions from our model. As a first pass, we can look beyond the mean and calculate the variance in locus sensitivity across test-set individuals. Our baseline expectation for the variance of locus sensitivity is as follows:

\[ \begin{align*} Var\left[\frac{\partial P}{\partial x_i}\right] &= Var[\beta_i] + Var[\beta_{ij} \cdot x_j] \\ &= \beta_{ij}^2 \cdot Var[x_j] \\ &= \beta_{ij}^2 \cdot 4 p_j q_j \\ &\approx \beta_{ij}^2 \quad \text{if } p_j \approx q_j \approx 0.5 \end{align*} \tag{4}\]

The variance in locus sensitivity should depend on the square of the epistatic effect, because a constant factor is squared when it is pulled out of a variance. It should also depend on the allele frequencies at the interacting locus. Again, in our simple case of allele frequencies of 0.5 (i.e., \(p_j = q_j = 0.5\), so \(4 p_j q_j = 1\)), this equation simplifies to just the square of \(\beta_{ij}\). Intuitively, this expression makes sense. A locus whose effect doesn't depend on genetic background should have the same estimated effect in every individual. A locus that interacts with other alleles should vary in effect, and the stronger those interactions are, the more it should vary. Does this match our data?
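A similar minimal sketch, again not part of the original analysis, checks Equation 4. Because \(\beta_i\) is a constant, only the \(\beta_{ij} x_j\) term contributes to the variance, which should equal \(\beta_{ij}^2 \cdot 4 p_j q_j\). The helper name and effect sizes are hypothetical.

```python
import numpy as np


def sensitivity_variance(beta_i, beta_ij, p_j, n=200_000, seed=1):
    """Hypothetical helper: variance of dP/dx_i = beta_i + beta_ij * x_j across individuals."""
    rng = np.random.default_rng(seed)
    x_j = np.where(rng.random(n) < p_j, 1.0, -1.0)  # +1 with frequency p_j
    sens = beta_i + beta_ij * x_j                  # Equation 2
    return sens.var()


beta_ij = 0.4  # illustrative value
for p_j in (0.5, 0.9):
    q_j = 1 - p_j
    empirical = sensitivity_variance(beta_i=0.2, beta_ij=beta_ij, p_j=p_j)
    predicted = beta_ij**2 * 4 * p_j * q_j  # Equation 4
    print(f"p_j = {p_j}: empirical variance = {empirical:.4f}, Equation 4 = {predicted:.4f}")
```

Changing `beta_i` leaves the variance unchanged. With \(p_j = 0.9\), the variance drops to \(0.36\,\beta_{ij}^2\), so a locus whose partner has skewed allele frequencies would look less epistatic by this measure.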

Estimate variance in locus sensitivity
def plot_sensitivity_variance_vs_epistasis(importance_dfs, true_effects_df, phenotype_idx=None):
    """
    Plot the variance in locus sensitivity across samples versus epistatic effects,
    with a quadratic fit curve and equation.

    Args:
        importance_dfs: List of DataFrames with importance values, one per phenotype
        true_effects_df: DataFrame containing ground truth effects
        phenotype_idx: List of indexes of phenotypes to analyze (1-based, default: [1] for first phenotype)

    Returns:
        None
    """
    # default analyze first phenotype
    if phenotype_idx is None:
        phenotype_idx = [1]

    # Create individual plots for each phenotype
    for i, idx in enumerate(phenotype_idx):
        with mpl.rc_context(plot_style):
            plt.figure(figsize=(6, 6))

            # Get the dataframe for this phenotype
            if isinstance(importance_dfs, list):
                importance_df = importance_dfs[i]
            else:
                importance_df = importance_dfs[idx]

            # Calculate variance for each locus
            locus_variance = importance_df.var().values

            # Extract ground truth for the specific phenotype (idx is already 1-based)
            true_effects = true_effects_df[true_effects_df["trait"] == idx]

            # Extract epistatic effects
            epistatic_effects = true_effects["epi_eff"].values
            epistatic_effects = epistatic_effects * np.sqrt(heritability)

            # Create scatter plot
            plt.scatter(epistatic_effects, locus_variance, alpha=0.7, color=apc.steel)

            # Fit quadratic curve (y = ax² + bx + c) to epistatic phenotypes
            if len(epistatic_effects) > 2 and np.std(epistatic_effects) > 1e-6:
                # Polynomial fit (degree 2 for quadratic); numpy.polynomial.polynomial.polyfit
                # returns coefficients from lowest to highest degree
                c, b, a = polyfit(epistatic_effects, locus_variance, 2)

                # Generate points for the fit curve
                x_fit = np.linspace(min(epistatic_effects), max(epistatic_effects), 100)
                y_fit = a * x_fit**2 + b * x_fit + c

                # Plot the fit curve
                plt.plot(x_fit, y_fit, "r-", linewidth=2)

                # Add the equation to the plot
                equation = f"y = {a:.2f}x² + ..."

                # Calculate R-squared
                y_pred = a * np.array(epistatic_effects) ** 2 + b * np.array(epistatic_effects) + c
                ss_total = np.sum((locus_variance - np.mean(locus_variance)) ** 2)
                ss_residual = np.sum((locus_variance - y_pred) ** 2)
                r_squared = 1 - (ss_residual / ss_total)

                # Add equation and R² to the plot
                plt.text(
                    0.1,
                    0.95,
                    f"{equation}\nr² = {r_squared:.2f}",
                    transform=plt.gca().transAxes,
                    fontsize=15,
                    verticalalignment="top",
                    bbox=dict(
                        boxstyle="round", facecolor=apc.parchment, alpha=0.7, edgecolor="none"
                    ),
                )

            # Add labels and title
            plt.xlabel(r"Epistatic Effect Size ($\beta_{ij}$)", fontsize=16)
            plt.ylabel("Locus Sensitivity Variance", fontsize=16)

            # Get current axes for formatting
            ax = plt.gca()
            ax.xaxis.set_major_locator(MaxNLocator(4))
            ax.yaxis.set_major_locator(MaxNLocator(4))
            # ax.xaxis.set_major_formatter(FormatStrFormatter("%.2f"))
            # ax.yaxis.set_major_formatter(FormatStrFormatter("%.2f"))
            ax.yaxis.set_major_formatter(ScalarFormatter(useOffset=False))
            ax.xaxis.set_major_formatter(ScalarFormatter(useOffset=False))

            ax.set_aspect("auto")

            plt.tight_layout(pad=0.5)

            apc.mpl.style_plot(monospaced_axes="both")

            plt.show()


plot_sensitivity_variance_vs_epistasis(importance_results, true_eff, phenotype_idx=[1, 2, 3, 4, 5])
Figure 3 panels (a) to (e), Phenotypes 1 to 5: scatter plots of ground truth epistatic effect vs. variance in locus sensitivity.
Figure 3: Ground truth locus epistatic effect size plotted against variance in locus sensitivity value for test-set samples. Results are presented for 5 phenotypes (a-e), ranging from purely additive (a) to almost purely epistatic (e).

Again, we recover a very clean relationship between the ground truth epistatic effect size for each locus and the variance in locus sensitivity (Figure 3). The exception is phenotype 1, which has no epistasis. This is a particularly useful result, because it lets us identify candidate epistatic loci without knowing anything about the identity of their interacting partners. Identifying those partners normally requires laborious pairwise (or higher-order) perturbation analyses. By selecting a subset of loci based on their individual sensitivity variances, we can perform post-hoc feature selection and restrict more expensive analyses to a more promising subset of loci.

If you look closely at the fitted parabola, you'll notice that the quadratic term is scaled by a factor of 0.8. Equation 4 predicts a 1:1 relationship, so this is an underestimate. We'll see this moderate underestimate of the interaction effect size again later on, so keep it in mind.
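Here is a minimal sketch of what the post-hoc feature selection described above might look like. It assumes a samples × loci DataFrame of locus sensitivities with `Locus_<k>` column names, which is the layout returned by `analyze_feature_importance_across_validation`. It ranks loci by sensitivity variance and keeps the top candidates. The toy sensitivities are generated from Equation 2 with a hypothetical matrix of pairwise effects, not from our simulation. `rank_candidate_epistatic_loci` is a hypothetical helper name.

```python
import numpy as np
import pandas as pd


def rank_candidate_epistatic_loci(importance_df: pd.DataFrame, top_k: int = 5) -> pd.Series:
    """Hypothetical helper: rank loci (columns) by sensitivity variance across samples (rows)."""
    return importance_df.var().sort_values(ascending=False).head(top_k)


# Toy sensitivities from Equation 2: dP/dx_i = beta_i + sum_j beta_ij * x_j
rng = np.random.default_rng(2)
n_samples, n_toy_loci = 2000, 6
x = rng.choice([-1.0, 1.0], size=(n_samples, n_toy_loci))
beta = np.array([0.3, -0.2, 0.1, 0.4, 0.25, -0.1])  # hypothetical additive effects
B = np.zeros((n_toy_loci, n_toy_loci))               # symmetric pairwise epistatic effects
B[0, 1] = B[1, 0] = 0.5
B[2, 3] = B[3, 2] = 0.2
toy_sens = beta + x @ B  # column i = beta_i + sum_j beta_ij * x_j
toy_importance_df = pd.DataFrame(toy_sens, columns=[f"Locus_{k}" for k in range(n_toy_loci)])

print(rank_candidate_epistatic_loci(toy_importance_df, top_k=4))
```

In this toy example, loci 0 and 1 (\(\beta_{ij} = 0.5\)) rank first, with variance near 0.25. Loci 2 and 3 follow, with variance near 0.04. The purely additive loci 4 and 5 have zero variance and drop out.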

Inferring pairwise epistatic interactions

Looking at the variance in locus sensitivity tells us which loci we expect to participate in epistatic interactions in general. But can we infer which specific pairs of loci actually interact with each other? To start, we can think about pairwise epistasis by considering pairs of locus sensitivity values conditioned on specific allelic combinations. In our two-locus haploid system, each locus takes an allelic value of either -1 or 1, giving four possible genotypic combinations (1/1, 1/-1, -1/1, -1/-1). The table below shows the expected sensitivity values of each of two interacting loci (\(x_i\) and \(x_j\)) for these four genotypes, obtained by substituting the relevant genotypic values into Equation 2.

\[ \begin{array}{|c|c|c|c|} \hline x_i & x_j & \frac{\partial P}{\partial x_i} &\frac{\partial P}{\partial x_j} \\ \hline 1 & 1 & \beta_i + \beta_{ij} & \beta_j + \beta_{ij} \\ 1 & -1 & \beta_i - \beta_{ij} & \beta_j + \beta_{ij} \\ -1 & 1 & \beta_i + \beta_{ij} & \beta_j - \beta_{ij} \\ -1 & -1 & \beta_i - \beta_{ij} & \beta_j - \beta_{ij} \\ \hline \end{array} \tag{5}\]

Notice that the locus sensitivity value for locus \(x_i\) depends on the allelic state at locus \(x_j\) and vice versa, assuming that \(\beta_{ij}\) is a non-zero term (i.e., the two loci genuinely interact). Below is a graphical representation of how we expect the sensitivity values of a pair of loci to change as a function of the strength of epistasis between them.

(a) Locus sensitivities for a pair of loci with no epistatic interactions.
(b) Locus sensitivities for a pair of loci that epistatically interact with each other.
Figure 4: Graphical illustration of expected locus sensitivity patterns for a pair of loci conditioned on genotypic state.

If the two loci do not interact epistatically, all four genotypes should share a single pair of sensitivity values for \(x_i\) and \(x_j\), located at the additive effects of the two loci, \((\beta_i, \beta_j)\) (Figure 4 (a)). If we introduce an epistatic interaction between the pair, the four genotypes will be pushed to the four corners of a square centered on \((\beta_i, \beta_j)\) with side length \(2|\beta_{ij}|\) (Figure 4 (b)). This leaves us with a simple prediction. If two loci epistatically interact, we should be able to capture and quantify this interaction term by plotting their locus sensitivity values and observing four masses of points, one for each genotypic combination. As a proof of concept, we can check this by visualizing genotype-specific locus sensitivities for a few epistatic pairs of loci for phenotype 2.
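The prediction illustrated in Figure 4 follows directly from Equation 5. This minimal sketch maps each genotype \((x_i, x_j)\) to its expected pair of sensitivities, \((\beta_i + \beta_{ij} x_j,\ \beta_j + \beta_{ij} x_i)\). The helper name and effect sizes are hypothetical.

```python
from itertools import product


def expected_sensitivity_corners(beta_i, beta_j, beta_ij):
    """Hypothetical helper: expected (dP/dx_i, dP/dx_j) for each genotype, per Equation 5."""
    return {
        (x_i, x_j): (beta_i + beta_ij * x_j, beta_j + beta_ij * x_i)
        for x_i, x_j in product((1, -1), repeat=2)
    }


# No epistasis: all four genotypes collapse onto (beta_i, beta_j)
print(expected_sensitivity_corners(0.3, -0.1, 0.0))
# Epistasis: genotypes spread to the corners of a square with side 2 * |beta_ij|
print(expected_sensitivity_corners(0.3, -0.1, 0.5))
```

Flipping the sign of `beta_ij` swaps which genotypes sit in which corners. Changing `beta_i` or `beta_j` only moves the square.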

Plot sensitivity values for epistatically interacting pairs
def extract_top_epistatic_pairs(
    model,
    validation_loader,
    true_effects_df,
    top_n=10,
    phenotype_idx=None,
    device="cuda",
    max_samples=1000,
):
    """
    Extract the top N strongest epistatic pairs for each phenotype and plot them.
    Assumes genotype encoding is 1/-1.

    Args:
        model: Trained neural network model
        validation_loader: DataLoader for validation data
        true_effects_df: DataFrame containing ground truth effects (must have epi_eff column)
        top_n: Number of top epistatic pairs to extract per phenotype
        phenotype_idx: List of phenotype indices to analyze (1-based, None for all phenotypes)
        device: Device to run calculations on
        max_samples: Maximum number of samples to process
    """
    # Ensure required columns exist
    required_cols = ["trait", "locus", "epi_loc", "epi_eff"]
    if not all(col in true_effects_df.columns for col in required_cols):
        raise ValueError(f"true_effects_df must have these columns: {required_cols}")

    # Determine phenotype indices to analyze
    if phenotype_idx is None:
        all_traits = true_effects_df["trait"].unique()
        phenotype_idx = list(all_traits)  # Keep 1-based indexing
    elif isinstance(phenotype_idx, int):
        phenotype_idx = [phenotype_idx]

    print(f"Analyzing phenotype: {phenotype_idx}")

    def get_phenotype_importance(x, pheno_idx):
        def phenotype_fn(inp):
            return model(inp)[:, pheno_idx - 1]  # Convert 1-based to 0-based

        jacobian = torch.autograd.functional.jacobian(phenotype_fn, x, vectorize=True).squeeze()
        return jacobian

    # Find top N epistatic pairs for each phenotype
    top_pairs = {}
    for idx in phenotype_idx:
        pheno_effects = true_effects_df[true_effects_df["trait"] == idx]  # idx is already 1-based
        # .copy() avoids pandas' SettingWithCopyWarning when adding a column below
        pheno_effects = pheno_effects[pheno_effects["epi_eff"] != 0].copy()
        pheno_effects["abs_epi_eff"] = np.abs(pheno_effects["epi_eff"])

        # Sort by largest absolute effects first
        sorted_effects = pheno_effects.sort_values("abs_epi_eff", ascending=False)

        unique_pairs = set()
        pairs = []

        for _, row in sorted_effects.iterrows():
            # Convert 1-based locus IDs to 0-based column indices
            locus1 = int(row["locus"]) - 1 if row["locus"] >= 1 else int(row["locus"])
            locus2 = int(row["epi_loc"]) - 1 if row["epi_loc"] >= 1 else int(row["epi_loc"])
            pair_key = tuple(sorted([locus1, locus2]))

            if pair_key in unique_pairs:
                continue

            unique_pairs.add(pair_key)
            pairs.append(
                {
                    "locus1": locus1,
                    "locus2": locus2,
                    "epi_eff": row["epi_eff"],
                    "abs_epi_eff": row["abs_epi_eff"],
                }
            )

            if len(pairs) >= top_n:
                break

        top_pairs[idx] = pairs
        print(f"Selected {len(pairs)} top epistatic pairs for phenotype {idx}")

    # Process validation samples
    results = {}
    sample_count = 0

    with torch.no_grad():
        for _batch_idx, (_, gens) in enumerate(validation_loader):
            for i in range(gens.shape[0]):
                if max_samples is not None and sample_count >= max_samples:
                    break

                single_gens = gens[i : i + 1]
                flattened_gens = single_gens.reshape(single_gens.shape[0], -1)
                simplified_gens = flattened_gens[:, 1::2].to(device)
                genotype = simplified_gens.cpu().numpy().squeeze()

                for idx in phenotype_idx:
                    importance = get_phenotype_importance(simplified_gens, idx)
                    importance_np = importance.cpu().numpy()

                    for pair in top_pairs[idx]:
                        locus1, locus2 = pair["locus1"], pair["locus2"]
                        pair_key = (idx, locus1, locus2)

                        if pair_key not in results:
                            results[pair_key] = {
                                "locus1_sensitivities": [],
                                "locus2_sensitivities": [],
                                "genotype1_values": [],
                                "genotype2_values": [],
                                "epi_eff": pair["epi_eff"],
                            }

                        results[pair_key]["locus1_sensitivities"].append(importance_np[locus1])
                        results[pair_key]["locus2_sensitivities"].append(importance_np[locus2])
                        results[pair_key]["genotype1_values"].append(genotype[locus1])
                        results[pair_key]["genotype2_values"].append(genotype[locus2])

                sample_count += 1
                if sample_count % 1000 == 0:
                    print(f"Processed {sample_count} samples")

            if max_samples is not None and sample_count >= max_samples:
                break

    print(f"Completed analysis for {sample_count} samples")

    # Convert lists to arrays
    for key in results:
        for array_key in [
            "locus1_sensitivities",
            "locus2_sensitivities",
            "genotype1_values",
            "genotype2_values",
        ]:
            results[key][array_key] = np.array(results[key][array_key])

    # Create plots
    max_pairs = max(len(pairs) for pairs in top_pairs.values()) if top_pairs else 0
    if max_pairs == 0:
        print("No pairs to plot.")
        return

    # Calculate grid dimensions
    n_cols = min(3, max_pairs)
    n_rows_per_pheno = (max_pairs + n_cols - 1) // n_cols
    total_rows = len(phenotype_idx) * n_rows_per_pheno

    with mpl.rc_context(plot_style):
        fig, axes = plt.subplots(total_rows, n_cols, figsize=(12, 4))
        if total_rows == 1 and n_cols == 1:
            axes = np.array([[axes]])
        elif total_rows == 1:
            axes = axes.reshape(1, -1)
        elif n_cols == 1:
            axes = axes.reshape(-1, 1)

        plot_idx = 0
        for _, idx in enumerate(phenotype_idx):
            for _, pair in enumerate(top_pairs[idx]):
                row = plot_idx // n_cols
                col = plot_idx % n_cols
                ax = axes[row, col]

                locus1, locus2 = pair["locus1"], pair["locus2"]
                pair_key = (idx, locus1, locus2)

                if pair_key not in results:
                    continue

                pair_data = results[pair_key]
                locus1_sens = pair_data["locus1_sensitivities"]
                locus2_sens = pair_data["locus2_sensitivities"]
                geno1_vals = pair_data["genotype1_values"]
                geno2_vals = pair_data["genotype2_values"]

                # Create scatter plot with different colors for genotype combinations
                colors = [apc.wish, apc.canary, apc.amber, apc.vital]
                labels = ["-1/-1", "-1/1", "1/-1", "1/1"]

                for k, (g1, g2) in enumerate([(-1, -1), (-1, 1), (1, -1), (1, 1)]):
                    mask = (geno1_vals == g1) & (geno2_vals == g2)
                    if np.any(mask):
                        ax.scatter(
                            locus1_sens[mask],
                            locus2_sens[mask],
                            c=colors[k],
                            label=labels[k],
                            alpha=0.7,
                        )

                ax.set_xlabel(f"Locus {locus1} Sensitivity", fontsize=13)
                ax.set_ylabel(f"Locus {locus2} Sensitivity", fontsize=13)
                ax.set_title(
                    rf"Loci {locus1}/{locus2}, $\beta_{{ij}}$= {pair['epi_eff']:.2f}", fontsize=15
                )
                if plot_idx == 0:
                    handles, labels_list = ax.get_legend_handles_labels()

                ax.xaxis.set_major_locator(MaxNLocator(4))
                ax.yaxis.set_major_locator(MaxNLocator(4))
                ax.xaxis.set_major_formatter(FormatStrFormatter("%.2f"))
                ax.yaxis.set_major_formatter(FormatStrFormatter("%.2f"))
                plt.sca(ax)
                apc.mpl.style_plot(monospaced_axes="both")

                plot_idx += 1

        # Hide empty subplots
        total_plots_needed = sum(len(pairs) for pairs in top_pairs.values())
        for plot_idx in range(total_plots_needed, total_rows * n_cols):
            row = plot_idx // n_cols
            col = plot_idx % n_cols
            if row < axes.shape[0] and col < axes.shape[1]:
                axes[row, col].set_visible(False)

        # Add single legend in the margin (outside all subplots)
        fig.legend(
            handles,
            labels_list,
            title="Genotypes",
            loc="center left",
            bbox_to_anchor=(1.0, 0.5),
            fontsize=13,
            title_fontsize=14,
            frameon=False,
        )

        # Make title bold
        legend = fig.legends[-1]
        legend.get_title().set_fontweight("bold")

        plt.tight_layout()
        plt.show()


extract_top_epistatic_pairs(
    model, test_loader_gp, true_eff, top_n=3, phenotype_idx=[2], device=device, max_samples=2000
)
Analyzing phenotype: [2]
Selected 3 top epistatic pairs for phenotype 2
Processed 1000 samples
Completed analysis for 1500 samples
Scatter plots of three pairs of locus sensitivity values showing how epistasis causes sensitivity values to segregate by genotype
Figure 5: Locus sensitivity values plotted for three pairs of epistatically interacting loci, labelled by genotypic state.

As expected, we get four clusters of points that cleanly correspond to the four genotypes (Figure 5). The exact configuration of the square formed by the genotype-specific locus sensitivities depends on the signs of the epistatic and additive terms. Following Equation 5, the additive effects set the center of the square, and the sign of \(\beta_{ij}\) determines which genotype occupies which corner. At times, the expected square of point masses instead appears as a parallelogram. We haven't yet been able to figure out why the shape wobbles in this way, but it may be related to some form of regularization.

Our results imply that we should be able to back out \(\beta_{ij}\) directly from the genotype-specific locus sensitivity values. Specifically, we can take the sensitivity of \(x_i\) for each of the four combinations of \(x_i\) and \(x_j\) and combine them with alternating signs so that the \(\beta_i\) terms cancel. This gives an estimate of \(\beta_{ij}\) as follows:

\[ \begin{align*} &\left.\frac{\partial P}{\partial x_i}\right|_{x_i = -1, x_j = -1} - \left.\frac{\partial P}{\partial x_i}\right|_{x_i = -1, x_j = 1} + \left.\frac{\partial P}{\partial x_i}\right|_{x_i = 1, x_j = -1} - \left.\frac{\partial P}{\partial x_i}\right|_{x_i = 1, x_j = 1} \\ &\quad = (\beta_{i} - \beta_{ij} ) - (\beta_{i} + \beta_{ij} ) + (\beta_{i} - \beta_{ij} ) - (\beta_{i} + \beta_{ij} )\\ &\quad = -4\beta_{ij} \end{align*} \tag{6}\]

In fact, we should be able to do this with the genotype-specific sensitivity values of both locus \(x_i\) and locus \(x_j\) and get roughly the same estimate of \(\beta_{ij}\). For \(\frac{\partial P}{\partial x_j}\), the roles of the two loci are swapped, so the signs in the sum follow the allelic state of \(x_i\) rather than \(x_j\). We can test this for the full set of simulated loci in our dataset, since we have only 64 positions (and therefore 2016 possible unique pairs). First, for every possible pairing of loci, we calculate a mean locus sensitivity value for each of the four genotypes for both \(x_i\) and \(x_j\). Next, we combine these means as in Equation 6 to obtain \(-4\beta_{ij}\), then divide by \(-4\). Finally, we average the two estimates of \(\beta_{ij}\) calculated from \(\frac{\partial P}{\partial x_i}\) and \(\frac{\partial P}{\partial x_j}\). Do these estimates match the epistatic effects from our simulations?
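The procedure above can be sketched in a few lines of NumPy. This is a minimal illustration under stated assumptions, not the notebook's `extract_all_epistatic_pairs` implementation. It takes per-sample genotypes (\(\pm 1\)) and per-sample sensitivities for one pair of loci, and it assumes all four genotype combinations are present. It applies Equation 6 to both loci, with the sign pattern following \(x_j\) for \(\frac{\partial P}{\partial x_i}\) and \(x_i\) for \(\frac{\partial P}{\partial x_j}\). The function name and the noisy toy data are hypothetical.

```python
import numpy as np


def estimate_beta_ij(x_i, x_j, sens_i, sens_j):
    """Hypothetical helper: estimate beta_ij from genotype-specific mean sensitivities (Equation 6).

    x_i, x_j: per-sample genotypes coded as +1/-1.
    sens_i, sens_j: per-sample sensitivities dP/dx_i and dP/dx_j.
    Assumes every (x_i, x_j) combination occurs at least once.
    """
    x_i, x_j = np.asarray(x_i), np.asarray(x_j)
    sens_i, sens_j = np.asarray(sens_i), np.asarray(sens_j)

    def cell_mean(sens, g_i, g_j):
        # Mean sensitivity among samples with genotype (x_i, x_j) = (g_i, g_j)
        return sens[(x_i == g_i) & (x_j == g_j)].mean()

    # Equation 6 for dP/dx_i: the sign of each term is -x_j, so the beta_i terms cancel
    sum_i = (cell_mean(sens_i, -1, -1) - cell_mean(sens_i, -1, 1)
             + cell_mean(sens_i, 1, -1) - cell_mean(sens_i, 1, 1))
    # Same contrast for dP/dx_j with the loci swapped: the sign is -x_i, so beta_j cancels
    sum_j = (cell_mean(sens_j, -1, -1) - cell_mean(sens_j, 1, -1)
             + cell_mean(sens_j, -1, 1) - cell_mean(sens_j, 1, 1))
    # Each sum estimates -4 * beta_ij; average the two resulting estimates
    return (sum_i / -4 + sum_j / -4) / 2


# Toy data from Equation 2 plus Gaussian noise (illustrative values only)
rng = np.random.default_rng(3)
n = 5000
x_i = rng.choice([-1, 1], size=n)
x_j = rng.choice([-1, 1], size=n)
beta_i, beta_j, beta_ij = 0.2, -0.3, 0.15
sens_i = beta_i + beta_ij * x_j + rng.normal(0, 0.05, n)
sens_j = beta_j + beta_ij * x_i + rng.normal(0, 0.05, n)
print(f"estimated beta_ij = {estimate_beta_ij(x_i, x_j, sens_i, sens_j):.3f} (true value {beta_ij})")
```

Looping this over all 2016 pairs of loci, with per-sample sensitivities taken from the Jacobian, gives one estimate per pair. Pairs without a simulated interaction would be expected to land near zero.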

Estimate genotype-specific mean sensitivities for all pairs of loci
# expect a minute or so runtime for this function
def extract_all_epistatic_pairs(
    model,
    validation_loader,
    true_effects_df,
    phenotype_idx=None,
    device=device,
    verbose=True,
    max_samples=1000,
):
    """
    Extract sensitivities for all possible unique pairs of loci for each phenotype.
    For pairs with known epistatic effects in true_effects_df, include those effects.
    For pairs without known epistatic effects, set the effect to 0.

    Uses 1-based locus indexing throughout.

    Args:
        model: Trained neural network model
        validation_loader: DataLoader for validation data
        true_effects_df: DataFrame containing ground truth effects with columns:
                         trait (1-based phenotype), locus (1-based), add_eff,
                         epi_loc (1-based interacting locus), epi_eff
        phenotype_idx: List of phenotype indices to analyze (None for all phenotypes)
        device: Device to run calculations on
        verbose: Whether to print progress messages
        max_samples: Maximum number of samples to process

    Returns:
        DataFrame with one row per phenotype/genotype/locus pair
    """
    # Ensure true_effects_df has required columns
    required_cols = ["trait", "locus", "epi_loc", "epi_eff"]
    if not all(col in true_effects_df.columns for col in required_cols):
        raise ValueError(f"true_effects_df must have these columns: {required_cols}")

    # Check for additive effects columns (None if absent, so the check below cannot raise NameError)
    add_eff_col = None
    if "add_eff" in true_effects_df.columns:
        add_eff_col = "add_eff"

    # Determine phenotype indices to analyze
    if phenotype_idx is None:
        # Get all unique phenotypes from true_effects_df
        all_traits = true_effects_df["trait"].unique()
        # Use 1-based indexing for phenotypes
        phenotype_idx = list(all_traits)
    elif isinstance(phenotype_idx, int):
        phenotype_idx = [phenotype_idx]

    if verbose:
        print(f"Analyzing phenotypes: {phenotype_idx}")

    # Function to get importance for a specific phenotype
    def get_phenotype_importance(x, pheno_idx):
        # Adjust for 1-based phenotype indexing
        model_pheno_idx = pheno_idx - 1  # Convert 1-based to 0-based for the model

        def phenotype_fn(inp):
            return model(inp)[:, model_pheno_idx]

        # Calculate Jacobian
        jacobian = torch.autograd.functional.jacobian(phenotype_fn, x, vectorize=True).squeeze()

        # Return unmodified Jacobian value
        return jacobian

    # Get number of loci from the model
    n_loci_model = n_loci  # This should be defined in your environment or passed to the function

    # Create mapping of known epistatic effects
    epistatic_effects = {}
    additive_effects = {}

    # First, collect all additive effects
    for idx in phenotype_idx:
        additive_effects[idx] = {}
        pheno_effects = true_effects_df[true_effects_df["trait"] == idx]

        if add_eff_col is not None:
            # Get unique loci and their additive effects
            for _, row in pheno_effects.drop_duplicates("locus").iterrows():
                locus = int(row["locus"])
                add_eff = row[add_eff_col]
                additive_effects[idx][locus] = add_eff

    # Now collect all epistatic effects
    for idx in phenotype_idx:
        epistatic_effects[idx] = {}
        pheno_effects = true_effects_df[true_effects_df["trait"] == idx]

        # Process each row to extract epistatic effects
        for _, row in pheno_effects.iterrows():
            locus1 = int(row["locus"])
            locus2 = int(row["epi_loc"])

            # Create a canonical representation of the pair (smaller locus first)
            pair_key = tuple(sorted([locus1, locus2]))

            # Store the epistatic effect
            epistatic_effects[idx][pair_key] = row["epi_eff"]

    # Generate all possible locus pairs
    all_pairs = {}
    for idx in phenotype_idx:
        pairs = []

        # Generate all possible pairs of loci (using 1-based indexing)
        for locus1 in range(1, n_loci_model + 1):  # Start from 1, not 0
            for locus2 in range(locus1 + 1, n_loci_model + 1):  # Only use each pair once
                # Create canonical pair key
                pair_key = (locus1, locus2)

                # Get epistatic effect if it exists, otherwise use 0
                epi_eff = epistatic_effects[idx].get(pair_key, 0.0)
                abs_epi_eff = abs(epi_eff)

                # Get additive effects
                add_eff1 = additive_effects[idx].get(locus1, 0.0)
                add_eff2 = additive_effects[idx].get(locus2, 0.0)

                # Store pair info
                pairs.append(
                    {
                        "locus1": locus1,
                        "locus2": locus2,
                        "epi_eff": epi_eff,
                        "abs_epi_eff": abs_epi_eff,
                        "add_eff1": add_eff1,
                        "add_eff2": add_eff2,
                        "has_epistasis": epi_eff != 0.0,
                    }
                )

        all_pairs[idx] = pairs
        epistatic_count = sum(1 for p in pairs if p["has_epistasis"])
        if verbose:
            print(
                f"Generated {len(pairs)} total locus pairs for phenotype {idx}, {epistatic_count} with known epistatic effects"
            )

    # Process validation samples to get sensitivities and genotypes
    sample_indices = []
    genotypes_dict = {}
    sensitivities_dict = defaultdict(lambda: defaultdict(list))

    # Process validation samples
    sample_count = 0
    with torch.no_grad():
        for _batch_idx, (_, gens) in enumerate(validation_loader):
            # Process each sample in the batch
            for i in range(gens.shape[0]):
                # Check if we've reached the maximum samples
                if max_samples is not None and sample_count >= max_samples:
                    break

                # Store sample index
                sample_indices.append(sample_count)

                # Get single sample
                single_gens = gens[i : i + 1]

                # Flatten and simplify (remove one-hot encoding)
                flattened_gens = single_gens.reshape(single_gens.shape[0], -1)
                simplified_gens = flattened_gens[:, 1::2].to(device)

                # Store the raw genotype data
                genotype = simplified_gens.cpu().numpy().squeeze()
                genotypes_dict[sample_count] = genotype

                # Calculate importance for each phenotype
                for idx in phenotype_idx:
                    importance = get_phenotype_importance(simplified_gens, idx)
                    importance_np = importance.cpu().numpy()

                    # Store sensitivity values for each locus in all pairs
                    for pair in all_pairs[idx]:
                        locus1 = pair["locus1"]
                        locus2 = pair["locus2"]
                        pair_key = (idx, locus1, locus2)

                        # Convert from 1-based to 0-based for accessing the importance array
                        model_locus1 = locus1 - 1
                        model_locus2 = locus2 - 1

                        # Store sensitivity values
                        sensitivities_dict[pair_key]["locus1"].append(importance_np[model_locus1])
                        sensitivities_dict[pair_key]["locus2"].append(importance_np[model_locus2])

                sample_count += 1

                # Print progress
                if sample_count % 1000 == 0 and verbose:
                    print(f"Processed {sample_count} samples")

            # Check again after batch
            if max_samples is not None and sample_count >= max_samples:
                break

    # Prepare final results
    summary_rows = []  # For the summarized DataFrame

    # Get phenotype names (trait names) if available
    trait_names = {}
    if "trait_name" in true_effects_df.columns:
        for _, row in true_effects_df.drop_duplicates("trait").iterrows():
            trait_names[row["trait"]] = row["trait_name"]  # Keep 1-based indexing

    for idx in phenotype_idx:
        # Get phenotype name if available, otherwise use index
        phenotype_name = trait_names.get(idx, f"Phenotype_{idx}")

        for pair in all_pairs[idx]:
            locus1 = pair["locus1"]
            locus2 = pair["locus2"]
            pair_key = (idx, locus1, locus2)

            # Get sensitivity values
            locus1_sensitivities = np.array(sensitivities_dict[pair_key]["locus1"])
            locus2_sensitivities = np.array(sensitivities_dict[pair_key]["locus2"])

            # Get genotypes for this pair
            genotype1_values = []
            genotype2_values = []
            combined_genotypes = []

            for s_idx in sample_indices:
                genotype = genotypes_dict[s_idx]

                # Convert from 1-based to 0-based for accessing genotype array
                model_locus1 = locus1 - 1
                model_locus2 = locus2 - 1

                # Extract genotype values for the two loci
                geno1 = genotype[model_locus1]
                geno2 = genotype[model_locus2]

                # Store individual genotypes
                genotype1_values.append(geno1)
                genotype2_values.append(geno2)

                # Encode combined genotype for -1/1 encoding
                # Map (-1,-1) → 0, (-1,1) → 1, (1,-1) → 2, (1,1) → 3
                combined = ((geno1 + 1) // 2) * 2 + ((geno2 + 1) // 2)
                combined_genotypes.append(combined)

            # Convert to numpy arrays
            genotype1_values = np.array(genotype1_values)
            genotype2_values = np.array(genotype2_values)
            combined_genotypes = np.array(combined_genotypes)

            # Calculate mean sensitivity values for each genotype combination
            for geno1 in [-1, 1]:
                for geno2 in [-1, 1]:
                    # Create mask for this genotype combination
                    mask = (genotype1_values == geno1) & (genotype2_values == geno2)

                    if np.any(mask):
                        # Get mean sensitivities for this genotype combination
                        mean_sens1 = np.mean(locus1_sensitivities[mask])
                        mean_sens2 = np.mean(locus2_sensitivities[mask])
                        count = np.sum(mask)

                        # Create a row for the summary DataFrame
                        summary_rows.append(
                            {
                                "phenotype": phenotype_name,
                                "phenotype_idx": idx,
                                "locus1": locus1,
                                "locus2": locus2,
                                "genotype1": geno1,
                                "genotype2": geno2,
                                "add_eff1": pair["add_eff1"],
                                "add_eff2": pair["add_eff2"],
                                "epi_eff": pair["epi_eff"],
                                "abs_epi_eff": pair["abs_epi_eff"],
                                "mean_sens1": mean_sens1,
                                "mean_sens2": mean_sens2,
                                "count": count,
                                "has_epistasis": pair["has_epistasis"],
                            }
                        )

    # Create the summary DataFrame
    summary_df = pd.DataFrame(summary_rows)

    # Add some useful calculated columns
    if len(summary_df) > 0:

        # Add combined genotype column for convenience
        summary_df["genotype_combined"] = summary_df["genotype1"].astype(str) + summary_df[
            "genotype2"
        ].astype(str)

        # Calculate weighted means across genotypes
        # Create unique pair identifiers
        summary_df["pair_id"] = (
            summary_df["phenotype_idx"].astype(str)
            + "_"
            + summary_df["locus1"].astype(str)
            + "_"
            + summary_df["locus2"].astype(str)
        )

        # Add weighted mean sensitivities
        weighted_means = summary_df.groupby("pair_id").apply(
            lambda x: pd.Series(
                {
                    "weighted_mean_sens1": (x["mean_sens1"] * x["count"]).sum() / x["count"].sum(),
                    "weighted_mean_sens2": (x["mean_sens2"] * x["count"]).sum() / x["count"].sum(),
                    "total_samples": x["count"].sum(),
                }
            ),
            include_groups=False,
        )

        # Merge weighted means back to summary_df
        summary_df = summary_df.merge(weighted_means, on="pair_id")

    return summary_df


paired_all = extract_all_epistatic_pairs(
    model, test_loader_gp, true_eff, phenotype_idx=[1, 2, 3, 4, 5], device=device, max_samples=2000
)
Analyzing phenotypes: [1, 2, 3, 4, 5]
Generated 2016 total locus pairs for phenotype 1, 0 with known epistatic effects
Generated 2016 total locus pairs for phenotype 2, 32 with known epistatic effects
Generated 2016 total locus pairs for phenotype 3, 32 with known epistatic effects
Generated 2016 total locus pairs for phenotype 4, 32 with known epistatic effects
Generated 2016 total locus pairs for phenotype 5, 32 with known epistatic effects
Processed 1000 samples
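The per-pair loop in `extract_all_epistatic_pairs` visits every stored sample once per locus pair. As a minimal sketch, assuming sensitivities are held in a samples × loci array and genotypes in a matching -1/1 array (the names `S_toy`, `G_toy` and `genotype_cell_means` are hypothetical, not notebook variables), the four genotype-cell means for one pair reduce to boolean masks. The toy sensitivities come from a hypothetical model with a single interaction, so \(\partial y / \partial x_0 = a_0 + \beta x_1\) and \(\partial y / \partial x_1 = a_1 + \beta x_0\).

```python
import numpy as np


def genotype_cell_means(S, G, i, j):
    """Mean sensitivity of loci i and j within each (g_i, g_j) genotype cell.

    S: (n_samples, n_loci) per-sample sensitivities (Jacobian rows).
    G: (n_samples, n_loci) genotypes coded as -1/1.
    Returns {(g_i, g_j): (mean S[:, i], mean S[:, j], count)} for non-empty cells,
    analogous to the mean_sens1 / mean_sens2 / count columns of the summary table.
    """
    out = {}
    for gi in (-1, 1):
        for gj in (-1, 1):
            mask = (G[:, i] == gi) & (G[:, j] == gj)
            if mask.any():
                out[(gi, gj)] = (S[mask, i].mean(), S[mask, j].mean(), int(mask.sum()))
    return out


# Hypothetical toy data: 3 loci, one pairwise interaction between loci 0 and 1.
rng = np.random.default_rng(0)
G_toy = rng.choice([-1, 1], size=(1000, 3))
a_toy = np.array([0.3, -0.2, 0.1])   # hypothetical additive effects
beta = 0.25                          # hypothetical interaction between loci 0 and 1
S_toy = np.tile(a_toy, (len(G_toy), 1))
S_toy[:, 0] += beta * G_toy[:, 1]    # dy/dx0 = a0 + beta * x1
S_toy[:, 1] += beta * G_toy[:, 0]    # dy/dx1 = a1 + beta * x0

cells = genotype_cell_means(S_toy, G_toy, 0, 1)
```

Each cell mean for locus 0 shifts by \(\pm\beta\) with the genotype at locus 1, which is the signal the estimator in the next cell contrasts.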
Estimate epistatic effect sizes for all pairs of loci
# expect a minute or so runtime for this function
def estimate_effects_from_sensitivities(summary_df, heritability=1.0, verbose=True):
    """
    Estimate additive and epistatic effects from locus sensitivity means across genotype combinations.
    Assumes -1/1 genotype encoding (the simulation ground truth, and the genotype values summarized by
    extract_all_epistatic_pairs); the sign corrections below are derived for that encoding.
    Applies sign corrections and heritability scaling.

    Args:
        summary_df: DataFrame output from extract_top_epistatic_pairs or extract_all_epistatic_pairs,
                   containing mean sensitivity values for each locus/genotype combination
        heritability: The broad-sense heritability value used for the phenotype (default=1.0)
        verbose: If True, print the heritability scaling factor being applied

    Returns:
        DataFrame with estimated additive and epistatic effects for each locus pair,
        alongside true values (raw and heritability-scaled) for comparison
    """
    # Create unique identifier for each locus pair
    summary_df["pair_id"] = (
        summary_df["phenotype_idx"].astype(str)
        + "_"
        + summary_df["locus1"].astype(str)
        + "_"
        + summary_df["locus2"].astype(str)
    )

    # Calculate heritability scaling factor
    h_sqrt = np.sqrt(heritability)
    if verbose:
        print(f"Using heritability scaling factor: {h_sqrt:.4f} (h² = {heritability:.2f})")

    encoding_type = "-1/1"  # simulation ground truth
    genotype_mapping = {(-1, -1): "00", (-1, 1): "01", (1, -1): "10", (1, 1): "11"}

    # Initialize results list
    results = []

    # Process each unique locus pair
    for pair_id in summary_df["pair_id"].unique():
        # Get data for this pair
        pair_data = summary_df[summary_df["pair_id"] == pair_id].copy()

        # Extract true effects (same for all rows of this pair)
        true_add_eff1 = pair_data["add_eff1"].iloc[0]
        true_add_eff2 = pair_data["add_eff2"].iloc[0]
        true_epi_eff = pair_data["epi_eff"].iloc[0]
        phenotype = pair_data["phenotype"].iloc[0]
        phenotype_idx = pair_data["phenotype_idx"].iloc[0]
        locus1 = pair_data["locus1"].iloc[0]
        locus2 = pair_data["locus2"].iloc[0]

        # Apply heritability scaling to true effects
        scaled_true_add_eff1 = true_add_eff1 * h_sqrt if true_add_eff1 is not None else None
        scaled_true_add_eff2 = true_add_eff2 * h_sqrt if true_add_eff2 is not None else None
        scaled_true_epi_eff = true_epi_eff * h_sqrt if true_epi_eff is not None else None

        # Create dictionary to store sensitivity means by genotype
        sens1_by_genotype = {}
        sens2_by_genotype = {}

        # Fill dictionary with sensitivity means for each genotype combination
        for _, row in pair_data.iterrows():
            geno1 = row["genotype1"]
            geno2 = row["genotype2"]
            genotype_key = genotype_mapping.get((geno1, geno2), f"{geno1}{geno2}")
            sens1_by_genotype[genotype_key] = row["mean_sens1"]
            sens2_by_genotype[genotype_key] = row["mean_sens2"]

        # Check if we have all four genotype combinations
        required_genotypes = ["00", "01", "10", "11"]
        if not all(g in sens1_by_genotype and g in sens2_by_genotype for g in required_genotypes):
            print(f"Skipping pair {pair_id} - missing genotype combinations")
            continue

        # 1. Estimate additive effects from mean sensitivity across all genotypes
        # For locus 1: a₁ = mean(s₁)/2
        est_add_eff1 = np.mean([sens1_by_genotype[g] for g in required_genotypes]) / 2

        # For locus 2: a₂ = mean(s₂)/2
        est_add_eff2 = np.mean([sens2_by_genotype[g] for g in required_genotypes]) / 2

        # 2. Estimate epistatic effect with sign correction
        # Method 1: From locus 1 sensitivities with sign correction
        epi_est1 = (
            -1
            * (
                sens1_by_genotype["00"]
                - sens1_by_genotype["01"]
                + sens1_by_genotype["10"]
                - sens1_by_genotype["11"]
            )
            / 4
        )

        # Method 2: From locus 2 sensitivities with sign correction
        epi_est2 = (
            -1
            * (
                sens2_by_genotype["00"]
                - sens2_by_genotype["10"]
                + sens2_by_genotype["01"]
                - sens2_by_genotype["11"]
            )
            / 4
        )

        # Average both estimates
        est_epi_eff = (epi_est1 + epi_est2) / 2

        # 3. Store results
        results.append(
            {
                "phenotype": phenotype,
                "phenotype_idx": phenotype_idx,
                "locus1": locus1,
                "locus2": locus2,
                "pair_id": pair_id,
                # Original true effects
                "true_add_eff1": true_add_eff1,
                "true_add_eff2": true_add_eff2,
                "true_epi_eff": true_epi_eff,
                # Scaled true effects (adjusted for heritability)
                "scaled_true_add_eff1": scaled_true_add_eff1,
                "scaled_true_add_eff2": scaled_true_add_eff2,
                "scaled_true_epi_eff": scaled_true_epi_eff,
                # Estimated effects (already sign-corrected)
                "est_add_eff1": est_add_eff1,
                "est_add_eff2": est_add_eff2,
                "est_epi_eff": est_epi_eff,
                # Record heritability used
                "heritability": heritability,
                # Record encoding type
                "encoding_type": encoding_type,
            }
        )

    # Create results DataFrame
    results_df = pd.DataFrame(results)

    return results_df


qg_estimates = estimate_effects_from_sensitivities(paired_all, heritability=heritability)
Using heritability scaling factor: 1.0000 (h² = 1.00)
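To see why the sign-corrected contrasts in `estimate_effects_from_sensitivities` recover the interaction term, assume the model behaves locally like \(y = a_1 x_1 + a_2 x_2 + \beta_{12} x_1 x_2\) with \(x \in \{-1, 1\}\) (an assumption about how the pairwise term enters, not something derived in this section). Then \(s_1 = a_1 + \beta_{12} x_2\) and \(s_2 = a_2 + \beta_{12} x_1\), and the additive terms cancel in each contrast. The sketch below reuses the same formulas on exact, hypothetical sensitivities; it checks only the epistatic estimate, since the factor of 1/2 in the additive estimate depends on how the simulation defines \(a_i\).

```python
def epistasis_contrasts(s1, s2):
    """Same contrasts as estimate_effects_from_sensitivities.

    s1, s2: dicts keyed by "00", "01", "10", "11"; the first digit is locus 1 and
    the second is locus 2, with "0" meaning genotype -1 and "1" meaning +1.
    """
    est1 = -(s1["00"] - s1["01"] + s1["10"] - s1["11"]) / 4
    est2 = -(s2["00"] - s2["10"] + s2["01"] - s2["11"]) / 4
    return est1, est2, (est1 + est2) / 2


a1, a2, beta12 = 0.4, -0.1, 0.3  # hypothetical effect sizes
code = {"0": -1, "1": 1}
keys = ("00", "01", "10", "11")
s1 = {g: a1 + beta12 * code[g[1]] for g in keys}  # locus 1 sensitivity depends on x2
s2 = {g: a2 + beta12 * code[g[0]] for g in keys}  # locus 2 sensitivity depends on x1

est1, est2, est = epistasis_contrasts(s1, s2)
```

With real cell means, interactions with third loci only average out in expectation, so pairs with \(\beta_{ij} = 0\) need not come out exactly zero.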
Plot epistatic effect size estimates vs. ground truth
def plot_epistatic_estimates(qg_estimates):
    """
    Plot estimated vs true epistatic effects for each phenotype.

    Args:
        qg_estimates: DataFrame with columns 'phenotype', 'true_epi_eff', 'est_epi_eff'
    """
    unique_phenotypes = qg_estimates["phenotype"].unique()

    for phenotype in unique_phenotypes:
        with mpl.rc_context(plot_style):
            plt.figure(figsize=(6, 6))

            # Filter data for this phenotype
            phenotype_data = qg_estimates[qg_estimates["phenotype"] == phenotype]

            # Create scatter plot
            plt.scatter(
                phenotype_data["true_epi_eff"],
                phenotype_data["est_epi_eff"],
                alpha=0.7,
                color=apc.steel,
            )

            # Add best fit line and slope; skipped for Phenotype_1, whose true epistatic effects are all 0
            if phenotype != "Phenotype_1":
                z = np.polyfit(phenotype_data["true_epi_eff"], phenotype_data["est_epi_eff"], 1)
                p = np.poly1d(z)
                xlims = plt.xlim()
                x_smooth = np.linspace(xlims[0], xlims[1], 100)
                plt.plot(x_smooth, p(x_smooth), "r-", alpha=0.7, linewidth=2)

                slope = z[0]
                r2 = r2_score(phenotype_data["true_epi_eff"], phenotype_data["est_epi_eff"])

                plt.text(
                    0.05,
                    0.95,
                    f"Slope = {slope:.3f}\nr² = {r2:.2f}",
                    transform=plt.gca().transAxes,
                    verticalalignment="top",
                    fontsize=15,
                    bbox=dict(
                        boxstyle="round", facecolor=apc.parchment, alpha=0.7, edgecolor="none"
                    ),
                )

            # Add axis labels
            plt.xlabel("True Epistatic Effect", fontsize=16)
            plt.ylabel("Estimated Epistatic Effect", fontsize=16)

            # Get current axes for formatting
            ax = plt.gca()
            ax.xaxis.set_major_locator(MaxNLocator(4))
            ax.yaxis.set_major_locator(MaxNLocator(4))
            ax.xaxis.set_major_formatter(FormatStrFormatter("%.2f"))
            ax.yaxis.set_major_formatter(FormatStrFormatter("%.2f"))

            apc.mpl.style_plot(monospaced_axes="both")

            plt.show()
            plt.close()


plot_epistatic_estimates(qg_estimates)
[Figure: scatter plots of ground truth epistatic interactions vs. interactions estimated through genotype-specific ELM locus sensitivities. Panels: (a) Phenotype 1, (b) Phenotype 2, (c) Phenotype 3, (d) Phenotype 4, (e) Phenotype 5.]
Figure 6: Ground truth epistatic interaction terms plotted against epistatic terms estimated from paired genotype specific locus sensitivity shifts. Points shown for all possible pairs of loci. Locus pairs with no interaction have a ground truth epistatic term set to 0. Results presented for 5 phenotypes (a-e) ranging from purely additive (a) to almost purely epistatic (e).

Again, the answer is a strong yes (Figure 6). The estimated and ground truth epistatic effect sizes are highly concordant, and this holds even for most of the locus pairs with no true interaction (\(\beta_{ij} = 0\)), which could otherwise appear as false positives. This implies that we can reconstruct the quantitative genetic parameters underlying our simulated data with very high accuracy, at least in this well-behaved simulated dataset.

Once more, it’s interesting to note that the relationship between estimated and ground truth \(\beta_{ij}\) isn’t 1:1 but closer to 0.86:1, similar to the scaling we observed when we compared variance in locus sensitivity to epistatic effect size. This agreement is actually surprising, since the variance-based estimate should scale with the square of the direct \(\beta_{ij}\) estimate. In other words, if direct estimates are scaled by 0.86, we would expect a scaling of ~0.74 (0.86²) for the variance-based estimate. One explanation for this discrepancy is that the direct \(\beta_{ij}\) estimate we derive here pertains to a specific pair of loci, while the variance-based estimate captures all epistatic variance associated with a focal locus. As a result, the variance-based estimate probably also absorbs the low levels of false-positive epistatic variance we observe here, which might inflate it.
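The squared-scaling argument can be checked numerically. Assuming -1/1 genotypes at equal frequency, unlinked loci and purely pairwise interactions, locus \(i\) has sensitivity \(s_i = a_i + \sum_j \beta_{ij} x_j\), so \(\mathrm{Var}(s_i) = \sum_j \beta_{ij}^2\). If every \(\beta_{ij}\) is shrunk by a common factor \(c\), this variance shrinks by \(c^2\). The sketch enumerates every genotype of a small hypothetical 4-locus model, so the variances are exact rather than sampled.

```python
import itertools

import numpy as np

# Hypothetical 4-locus model: symmetric interaction matrix with zero diagonal.
B = np.zeros((4, 4))
B[0, 1] = B[1, 0] = 0.3
B[0, 2] = B[2, 0] = -0.2
a_toy = np.array([0.5, 0.1, -0.3, 0.0])  # hypothetical additive effects

# Every -1/1 genotype exactly once: equal allele frequencies, unlinked loci.
G_toy = np.array(list(itertools.product([-1, 1], repeat=4)), dtype=float)


def sensitivity_variance(B):
    """Population variance of s_i = a_i + sum_j B[i, j] * x_j across all genotypes."""
    S = a_toy + G_toy @ B  # B is symmetric, so (G @ B)[k, i] = sum_j B[i, j] * x_kj
    return S.var(axis=0)


c = 0.86  # common shrinkage factor applied to every beta
var_true = sensitivity_variance(B)
var_shrunk = sensitivity_variance(c * B)
ratio = var_shrunk[0] / var_true[0]  # equals c**2 for every locus with nonzero variance
```

For locus 0, \(\mathrm{Var}(s_0) = 0.3^2 + 0.2^2 = 0.13\), and the shrunk model gives 0.7396 times that, matching the ~0.74 expectation above.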

Either way, our analyses strongly suggest that second-order effects (pairwise epistasis) estimated from our model are biased downward more strongly than first-order (additive) effects. We strongly suspect this is a product of regularization through the high weight decay parameter we used when training our MLP. Through initial experimentation, we found that adding weight decay gave us a final small boost in predictive performance and reduced the magnitude of false positive pairwise epistatic terms. But this likely comes at the relatively small cost of shrinking true parameter estimates (a trade-off also made with genomic prediction models like GBLUP). Perhaps higher-order effects are regularized more strongly because more neurons must be recruited to capture them, so weight decay penalizes them more heavily.
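For comparison, an L2 penalty on an explicit linear model (ridge regression, closely related to the RR-BLUP/GBLUP family) shrinks true effects in a predictable way: with orthogonal -1/1 predictors, each coefficient is the least-squares value times \(n / (n + \lambda)\). This is only an analogy, with a hypothetical design and penalty: weight decay acts on network weights, not directly on \(\beta_{ij}\). In the explicit model below, the interaction column is shrunk by exactly the same factor as the additive columns, so stronger shrinkage of pairwise terms in the MLP would have to come from how those terms are represented in the network.

```python
import itertools

import numpy as np

# Full enumeration of -1/1 genotypes at 3 hypothetical loci: the columns below are
# mutually orthogonal, including the pairwise product column x0 * x1.
G_toy = np.array(list(itertools.product([-1, 1], repeat=3)), dtype=float)
X_toy = np.column_stack([G_toy, G_toy[:, 0] * G_toy[:, 1]])  # 3 additive + 1 interaction
true_coef = np.array([0.5, -0.2, 0.1, 0.3])                 # hypothetical effects
y_toy = X_toy @ true_coef                                    # noise-free phenotype


def ridge(X, y, lam):
    """Closed-form ridge solution (X'X + lam*I)^-1 X'y, without an intercept."""
    return np.linalg.solve(X.T @ X + lam * np.eye(X.shape[1]), X.T @ y)


lam = 4.0
coef = ridge(X_toy, y_toy, lam)
shrink = len(y_toy) / (len(y_toy) + lam)  # 8 / 12 = 2/3 for this design
```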

Discussion

We applied ELM to a simple genotype-phenotype neural network to test whether we could back out interpretable quantitative genetic parameters from a “black box” model. Our results turned out to be surprisingly interpretable. ELM allowed us to infer both additive and epistatic effects with high accuracy, demonstrating that our neural network had learned the underlying quantitative genetic model used to generate the data and that we could extract usable effect size information from it. We’re now moving on to testing this approach on real-world biological datasets. Our hope is that ELM can allow us to harness some of the interpretability of simpler linear models while taking advantage of the power of deep learning. For example, by fitting a neural network and then filtering loci by their sensitivity variance, we might find evidence of epistatically interacting loci while testing only a small fraction of possible pairs. This would be a substantial advantage over traditional GWAS approaches, where every combinatorial test for epistasis has to be performed and each model is fit to only a handful of loci at a time.
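The filtering idea can be sketched as follows: rank loci by the variance of their per-sample sensitivities, keep the top \(k\), and test only pairs among them. The sensitivity matrix, the helper `candidate_pairs` and the cutoff \(k\) are all hypothetical. With the 64 loci used here, exhaustive testing means 2,016 pairs (the count printed earlier), whereas keeping 3 loci leaves 3 pairs.

```python
import itertools
from math import comb

import numpy as np


def candidate_pairs(S, k):
    """Locus pairs among the k loci with the highest sensitivity variance.

    S: (n_samples, n_loci) array of per-sample sensitivities (hypothetical input).
    """
    variances = S.var(axis=0)
    top = np.sort(np.argsort(variances)[::-1][:k])
    return list(itertools.combinations(top.tolist(), 2))


rng = np.random.default_rng(1)
n_toy_samples, n_toy_loci = 500, 64
S_toy = rng.normal(0.0, 0.01, size=(n_toy_samples, n_toy_loci))  # near-constant sensitivities
partner = rng.choice([-1, 1], size=n_toy_samples)
# Hypothetical: three loci whose sensitivity shifts with the genotype at a shared partner locus.
S_toy[:, [3, 17, 40]] += 0.3 * partner[:, None]

pairs = candidate_pairs(S_toy, k=3)
n_all, n_kept = comb(n_toy_loci, 2), len(pairs)
```

In practice the cutoff would need to be chosen from the observed variance distribution; a hard `k` is used here only to keep the sketch short.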

While our results look promising, there are several limitations to this study we’d like to highlight. The main limitation is that our simulated data have unrealistically low levels of noise, and we have many more genomes than interacting loci, conditions we previously identified as being suitable for exceptional neural network predictive performance. Few real-world biological datasets will ever provide such convenient conditions for mapping genotypes to phenotypes. However, we believe real-world limitations can be overcome. First, we supply an Appendix to this notebook where we test adding substantial amounts of environmental noise to the observed phenotypes and report that ELM still works surprisingly effectively even in regimes where overall predictive performance seems to suggest no epistatic variance has been learned. Second, we advocate for applying DL model training and ELM not to all possible loci in a given dataset, but to first perform feature selection to narrow down the input to a set of candidate loci from which a model is trained. Given this, we expect ELM might allow us to uncover more biological nonlinearity than would normally be revealed when fitting explicit linear models to G-P datasets.

Finally, we want to briefly comment on how else we might use the information provided by ELM applied to a nonlinear G-P model. In our analysis, we framed our exploration around inferring classical quantitative genetics parameters, for two reasons. First, the simulated data are generated by a true quantitative genetics model, making this an obvious framework to use. Second, this is the primary framework for understanding genotype-to-phenotype mapping, and, as far as we know, it has not previously been possible to link it to deep learning models as we have done here. However, there may be other creative ways to use the rich feature importance information that the Jacobian and ELM provide. Ultimately, the sensitivity values within a sample reflect the model’s approximation of an individual’s G-P map. Most of our analyses here average out much of this variation, and we suspect there may be valuable opportunities to develop methods that instead leverage it. For example, applying approaches such as PCA or UMAP to the full space of sensitivity values might yield interesting biological insights about the structure of locus sensitivity variation, particularly in real-world complex biological data. Further, mapping the boundaries where the ELM changes (essentially finding discontinuities in the Jacobian across genotype space) could reveal epistatic interactions without requiring an explicit quantitative genetics model. These approaches could be particularly powerful for understanding complex phenotypes where the classical additive model fails to capture the full genetic architecture.
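As one possible starting point for the PCA idea, the samples × loci sensitivity matrix can be centered and decomposed with an SVD. The matrix below is simulated from a hypothetical single interaction between loci 0 and 1, so its variation should lie almost entirely in the two components loading on those loci. UMAP would need a separate library and is not shown.

```python
import numpy as np


def sensitivity_pca(S, n_components=2):
    """PCA of an (n_samples, n_loci) sensitivity matrix via SVD of the centered data."""
    Sc = S - S.mean(axis=0)
    U, sing, Vt = np.linalg.svd(Sc, full_matrices=False)
    explained = sing**2 / np.sum(sing**2)  # fraction of total variance per component
    scores = U[:, :n_components] * sing[:n_components]
    return scores, Vt[:n_components], explained[:n_components]


rng = np.random.default_rng(2)
G_toy = rng.choice([-1, 1], size=(300, 5))
S_toy = np.tile([0.2, -0.1, 0.05, 0.0, 0.3], (300, 1))  # hypothetical additive sensitivities
S_toy[:, 0] += 0.4 * G_toy[:, 1]  # locus 0 sensitivity depends on locus 1 genotype
S_toy[:, 1] += 0.4 * G_toy[:, 0]  # and vice versa (a single pairwise interaction)

scores, loadings, explained = sensitivity_pca(S_toy)
```

The loadings identify which loci carry sample-to-sample sensitivity variation; in real data the interesting question would be whether the components group into biologically coherent sets of loci.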

Appendix: Adding environmental noise

Our main analysis demonstrated that ELM is effective in a regime where phenotypes have no environmental noise, and model predictions from genotypes almost perfectly recover phenotypes. In reality, no phenotype of interest will conform to these generous conditions, so in this appendix, we test the performance of ELM when we add a substantial amount of environmental noise to our simulated phenotypes.

As a test case, we lower the broad-sense heritability (\(H^2\)) of our phenotypes to 0.25, mimicking the modest level of heritability/repeatability observed in quantitative phenotypes like crop yield (Tucker et al., 2020). We implement this by adding Gaussian noise to all our phenotypes and then rescaling the total phenotypic variance back down. This keeps the total phenotypic variance the same as in the main analysis (\(V_P = 1\)) and preserves the relative contributions of additive and epistatic effects to genetic variance, while introducing the desired level of environmental noise. We consequently also shrink all ground truth effect sizes by \(\sqrt{H^2}\) to account for their smaller contributions to phenotypic variance.
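The noise-addition step can be sketched as below. This is a hypothetical stand-in for what `create_data_loaders_with_heritability` might do, not its actual implementation: add Gaussian noise with variance \(V_G (1 - H^2) / H^2\), then rescale so \(V_P = 1\). With \(V_G \approx 1\), the rescaling factor is \(\sqrt{H^2} = 0.5\), the same factor applied to ground truth effects, and even a perfect genetic predictor reaches only \(r^2 \approx H^2\).

```python
import numpy as np


def add_environmental_noise(g, H2, rng):
    """Add Gaussian noise so that V_G / V_P = H2, then rescale to V_P = 1.

    g: genetic values (noise-free phenotypes). Returns the noisy phenotype and the
    factor applied to the genetic values (and hence to their effect sizes).
    """
    v_g = g.var()
    v_e = v_g * (1.0 - H2) / H2              # environmental variance for the target H2
    p = g + rng.normal(0.0, np.sqrt(v_e), size=g.shape)
    scale = 1.0 / np.sqrt(v_g + v_e)         # brings V_P back to ~1
    return p * scale, scale


rng = np.random.default_rng(3)
g_toy = rng.normal(0.0, 1.0, size=200_000)    # hypothetical genetic values, V_G ~ 1
p_toy, scale = add_environmental_noise(g_toy, H2=0.25, rng=rng)

r2_best = np.corrcoef(g_toy, p_toy)[0, 1] ** 2  # r² of a perfect genetic predictor
```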

Create new dataloaders with environmental noise added
#| cache: true

new_heritability = 0.25

loaders_hnew = create_data_loaders_with_heritability(base_file_name, heritability=new_heritability)
train_loader_gp_hnew = loaders_hnew["train_loader_gp"]
test_loader_gp_hnew = loaders_hnew["test_loader_gp"]

Model fitting

As before, we first fit our neural network to these new noisy data and observe its performance on the test set. Keep in mind that the upper limit on \(r^2\) for any phenotype is now just 0.25 (the model can’t learn to predict random environmental noise).

Train model on noisy data
# expect a minute or so to train model on noisy data on CPU
model_H05 = GP_net(n_loci=n_loci, hidden_layer_size=hidden_size, n_pheno=n_phen)

# train_gpnet applies early stopping internally and restores the best epoch
model_H05, best_loss_gp = train_gpnet(
    model=model_H05,
    train_loader=train_loader_gp_hnew,
    test_loader=test_loader_gp_hnew,
    n_loci=n_loci,
    learning_rate=learning_rate,
    device=device,
)
model_H05.eval()
Epoch: 10/100, Train Loss: 0.541953, Test Loss: 0.758366
Early stopping triggered after 13 epochs
Restoring best model from epoch 3
GP_net(
  (layer1): Linear(in_features=64, out_features=1024, bias=False)
  (layer2): Linear(in_features=1024, out_features=1024, bias=False)
  (layer3): Linear(in_features=1024, out_features=5, bias=False)
)
Evaluate model performance on test set
true_phenotypes = []
predicted_phenotypes = []

with torch.no_grad():
    for phens, gens in test_loader_gp_hnew:
        phens = phens.to(device)
        gens = gens.to(device)
        gens_simplified = gens[:, 1::2]

        # Get predictions
        predictions = model_H05(gens_simplified)

        # Store results
        true_phenotypes.append(phens.cpu().numpy())
        predicted_phenotypes.append(predictions.cpu().numpy())

    # Concatenate batches
    true_phenotypes = np.concatenate(true_phenotypes)
    predicted_phenotypes = np.concatenate(predicted_phenotypes)

    # Calculate correlations for each phenotype
    correlations = []
    r2s = []

    for i in range(n_phen):
        corr, _ = pearsonr(true_phenotypes[:, i], predicted_phenotypes[:, i])
        correlations.append(corr)

        r2 = r2_score(true_phenotypes[:, i], predicted_phenotypes[:, i])
        r2s.append(r2)

    relAA = [1, 0.787, 0.505, 0.403, 0.122]  # additive share of genetic variance (h²/H²) per phenotype

    # Create a detailed DataFrame with all results
    results_df_fit = pd.DataFrame(
        {
            "Phenotype": range(1, n_phen + 1),
            "Pearson correlation": correlations,
            "r²": r2s,
            "h²": relAA,
        }
    )

results_df_fit["h²"] = results_df_fit["h²"] * new_heritability  # correct for added noise

results_df_fit = results_df_fit.round(3)

results_df_fit.head()
Table 2: Test dataset prediction performance of a neural network trained on simulated genotype-phenotype data with added environmental noise. Phenotype narrow-sense heritability (\(h^2\)) also included.

| Phenotype | Pearson correlation | r² | h² |
|---|---|---|---|
| 1 | 0.462 | 0.209 | 0.250 |
| 2 | 0.429 | 0.154 | 0.197 |
| 3 | 0.356 | 0.119 | 0.126 |
| 4 | 0.327 | 0.103 | 0.101 |
| 5 | 0.232 | 0.051 | 0.030 |

As expected, model performance is considerably worse for these noisier data, and there’s a pronounced difference between more additive and more epistatic phenotypes (Table 2). Let’s take a closer look at phenotype 2. This phenotype has a narrow-sense heritability (\(h^2\)) of ~0.2, which is higher than the observed \(r^2\) (~0.15), suggesting the model has learned to predict less genetic variance than additive effects alone should account for. It’s a similar story for the other phenotypes, with \(r^2\) generally hovering slightly below \(h^2\) and never substantially above it. At first glance, it might look like the model hasn’t learned anything about epistasis for these phenotypes. But we’ll see further down that this isn’t necessarily the case, so keep this result in mind.
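Using the values reported in Table 2, the comparison can be made explicit. The gap \(r^2 - h^2\) is positive only if predictions capture more variance than additive effects alone would explain (subject to sampling noise in the test set), while \(r^2 - H^2\) must stay at or below zero.

```python
import numpy as np

# Values copied from Table 2 (test set, H² = 0.25).
table2_phenotype = np.array([1, 2, 3, 4, 5])
table2_r2 = np.array([0.209, 0.154, 0.119, 0.103, 0.051])
table2_h2 = np.array([0.250, 0.197, 0.126, 0.101, 0.030])
H2_noisy = 0.25

gap_vs_h2 = table2_r2 - table2_h2  # > 0 suggests some non-additive signal (or noise)
gap_vs_H2 = table2_r2 - H2_noisy   # genetic variance caps achievable r², so <= 0
above_h2 = table2_phenotype[gap_vs_h2 > 0].tolist()
```

Only phenotypes 4 and 5 exceed \(h^2\), by 0.002 and 0.021, so this comparison on its own gives little evidence that epistasis was learned.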

Effects estimation

Next, let’s again apply ELM to a version of the trained model with detached activation functions to get locus sensitivity values.
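For a network like `GP_net`, whose linear layers have no bias terms (see the printed model above), detaching the activations has a simple closed form if the activations are ReLU. That is an assumption here, since the activation function is not shown in this section. Freezing each ReLU’s on/off pattern at a given input turns the network into a product of matrices, \(J(x) = W_3 D_2(x) W_2 D_1(x) W_1\), and because ReLU is positively homogeneous, \(f(x) = J(x)\,x\) exactly. The NumPy sketch uses small random weights rather than the trained model.

```python
import numpy as np

rng = np.random.default_rng(4)
toy_loci, toy_hidden, toy_pheno = 8, 16, 2
W1 = rng.normal(size=(toy_hidden, toy_loci))   # (out, in) layout, like torch.nn.Linear.weight
W2 = rng.normal(size=(toy_hidden, toy_hidden))
W3 = rng.normal(size=(toy_pheno, toy_hidden))


def forward(x):
    """Bias-free 3-layer ReLU MLP (hypothetical stand-in for GP_net)."""
    h1 = np.maximum(W1 @ x, 0.0)
    h2 = np.maximum(W2 @ h1, 0.0)
    return W3 @ h2


def elm(x):
    """Equivalent linear map: freeze each ReLU's gate at its value for input x."""
    d1 = (W1 @ x > 0).astype(float)
    d2 = (W2 @ (d1 * (W1 @ x)) > 0).astype(float)
    return W3 @ (d2[:, None] * W2) @ (d1[:, None] * W1)  # shape (toy_pheno, toy_loci)


x = rng.choice([-1.0, 1.0], size=toy_loci)  # one -1/1 genotype
J = elm(x)
```

For a bias-free ReLU network, this matrix coincides with the ordinary Jacobian away from gate boundaries, which is what `get_phenotype_importance` computes with `torch.autograd.functional.jacobian`; each row holds one phenotype’s locus sensitivities.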

Taking the average of locus sensitivities across all test set individuals, we can check whether we can still back out ground truth additive effects for our loci.

Estimate mean locus sensitivity
importance_results_hnew = analyze_feature_importance_across_validation(
    model=model_H05,
    validation_loader=test_loader_gp_hnew,
    true_effects_df=true_eff,
    phenotype_idx=[1, 2, 3, 4, 5],  # 1-based phenotype indices
    device=device,
    max_samples=2000,  # Limit number of samples for faster processing
)
[Figure: scatter plots of ground truth additive effect vs. average locus sensitivity for data with high environmental noise. Panels: (a) Phenotype 1, (b) Phenotype 2, (c) Phenotype 3, (d) Phenotype 4, (e) Phenotype 5.]
Figure 7: Ground truth locus additive effect size plotted against mean locus sensitivity value for test-set samples with added environmental noise. Results presented for 5 phenotypes (a-e) ranging from purely additive (a) to almost purely epistatic (e).

While the fit isn't as clean as before, the results still look quite good (Figure 7). For the most part, ELM appears capable of extracting ground truth additive effect sizes from these noisy data. Interestingly, though, the slope of the fit between ground truth effects and average locus sensitivity tends to be ~0.45. This is considerably lower than the 1:1 expectation that was matched in the main analysis. It seems we're again seeing the effects of regularization in this noisier phenotypic regime.
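The link between regularization and a slope below 1 can be illustrated with a linear analogue. The sketch below fits closed-form ridge regression to a simulated additive trait with \(h^2 = 0.25\). Ridge uses an L2 penalty, playing the role that weight decay (`weight_decay=regularization` in the training call) plays in the network. The sample size, number of loci and penalty strength are arbitrary assumptions.

```python
import numpy as np

rng = np.random.default_rng(1)
n, n_loci, h2 = 2000, 50, 0.25

G = rng.binomial(2, 0.5, size=(n, n_loci)).astype(float)
G -= G.mean(axis=0)                          # center genotypes (no intercept needed)
beta_true = rng.normal(size=n_loci)
g = G @ beta_true
noise_sd = np.sqrt(g.var() * (1 - h2) / h2)  # environmental noise giving h2 = 0.25
y = g + rng.normal(scale=noise_sd, size=n)


def ridge(X, y, lam):
    """Closed-form ridge (L2-penalized) regression coefficients."""
    p = X.shape[1]
    return np.linalg.solve(X.T @ X + lam * np.eye(p), X.T @ y)


beta_ols = ridge(G, y, lam=0.0)       # unpenalized least squares
beta_ridge = ridge(G, y, lam=1000.0)  # heavily penalized

slope_ols = np.polyfit(beta_true, beta_ols, 1)[0]
slope_ridge = np.polyfit(beta_true, beta_ridge, 1)[0]
print(f"OLS slope: {slope_ols:.2f}, ridge slope: {slope_ridge:.2f}")
```

With no penalty the slope of estimated vs. true effects is close to 1. Increasing `lam` pulls it toward 0 while leaving the ranking of effects largely intact, which would be consistent with the shrunken slope seen here.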

What about epistatic effects? First, let’s look at the variance in locus sensitivity to see if we get the same parabola shapes as before.
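Why a parabola? Consider a hypothetical phenotype with additive effects \(a_i\) and pairwise terms \(\beta_{ij} g_i g_j\). The sensitivity of locus \(i\) is \(\partial y / \partial g_i = a_i + \sum_j \beta_{ij} g_j\). When loci are independent, its variance across individuals is \(\sum_j \beta_{ij}^2 \operatorname{Var}(g_j)\), so it grows quadratically with the size of the epistatic effect.

The sketch below checks this numerically on a toy model in which each locus has a single interaction partner, an assumption chosen for clarity. With genotypes drawn as Binomial(2, 0.5), \(\operatorname{Var}(g_j) = 0.5\), so the fitted quadratic coefficient should be close to 0.5.

```python
import numpy as np

rng = np.random.default_rng(2)
n_ind, n_pairs = 20_000, 10
n_loci = 2 * n_pairs

a = rng.normal(size=n_loci)              # additive effects (do not affect the variance)
betas = np.linspace(-2.0, 2.0, n_pairs)  # one epistatic coefficient per locus pair
B = np.zeros((n_loci, n_loci))
for k, beta in enumerate(betas):
    i, j = 2 * k, 2 * k + 1              # loci i and j interact with strength beta
    B[i, j] = B[j, i] = beta

G = rng.binomial(2, 0.5, size=(n_ind, n_loci)).astype(float)

# y = a.g + sum over pairs of beta_ij * g_i * g_j, so dy/dg_i = a_i + sum_j B[i, j] * g_j
sensitivity = a + G @ B                  # (n_ind, n_loci)
sensitivity_var = sensitivity.var(axis=0)

epi_per_locus = B.sum(axis=1)            # each locus's single epistatic coefficient
quad, lin, const = np.polyfit(epi_per_locus, sensitivity_var, 2)
print(f"fitted variance = {quad:.3f}*b^2 + {lin:.3f}*b + {const:.3f}")
```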

Estimate variance in locus sensitivity
plot_sensitivity_variance_vs_epistasis(
    importance_results_hnew, true_eff, phenotype_idx=[1, 2, 3, 4, 5]
)
Figure 8 panels (a) to (e), Phenotypes 1 to 5: scatter plots of ground truth epistatic effect vs. variance in locus sensitivity.
Figure 8: Ground truth locus epistatic effect size plotted against variance in locus sensitivity value for test-set samples with added environmental noise. Results presented for 5 phenotypes (a-e) ranging from purely additive (a) to almost purely epistatic (e).

We observe a fairly noisy fit between locus sensitivity variance and ground truth epistasis for the phenotypes with epistasis, particularly those with low epistatic variance (Figure 8). Again, we observe strong regularization: the quadratic term of the polynomial fit is scaled well below the 1:1 expectation. This suggests that sensitivity variance becomes a less reliable readout of per-locus epistasis in noisy phenotypic regimes.

But what about pairwise epistatic effects? Let's compare all possible pairs of loci and use the means of their genotype-specific sensitivities to estimate \(\beta_{ij}\).
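One way to turn genotype-specific sensitivities into an interaction estimate is sketched here under simplifying assumptions:

1. Group individuals by their genotype at locus \(j\).
2. Average the sensitivity of locus \(i\) within each genotype class.
3. Take the slope of those class means against the genotype.

For the toy model \(y = \sum_i a_i g_i + \beta_{01} g_0 g_1\), the sensitivity of locus 0 is \(a_0 + \beta_{01} g_1\). The slope therefore recovers \(\beta_{01}\) exactly, while pairs without an interaction give slopes near 0.

The helper `pairwise_epistasis` is hypothetical. It is not the notebook's `extract_all_epistatic_pairs` / `estimate_effects_from_sensitivities` pipeline, which additionally applies a heritability scaling factor.

```python
import numpy as np

rng = np.random.default_rng(3)
n_ind, n_loci, beta_01 = 20_000, 6, 0.8

a = rng.normal(size=n_loci)
B = np.zeros((n_loci, n_loci))
B[0, 1] = B[1, 0] = beta_01              # the only interacting pair is (0, 1)

G = rng.binomial(2, 0.5, size=(n_ind, n_loci)).astype(float)
sensitivity = a + G @ B                  # dy/dg_i for each individual and locus


def pairwise_epistasis(sensitivity, G, i, j):
    """Hypothetical estimator: slope of locus i's mean sensitivity across locus j's genotypes."""
    classes = np.array([0.0, 1.0, 2.0])
    class_means = np.array([sensitivity[G[:, j] == k, i].mean() for k in classes])
    return np.polyfit(classes, class_means, 1)[0]


est = np.zeros((n_loci, n_loci))
for i in range(n_loci):
    for j in range(n_loci):
        if i != j:
            est[i, j] = pairwise_epistasis(sensitivity, G, i, j)
```

In a trained network the sensitivities come from ELM rather than from a known formula. Noise in them then spreads into the null pairs, which is where false-positive interactions would be expected to appear.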

Estimate epistatic effect size between all pairs of loci
# these functions looping over all pairs of loci will take ~2-3min to run on CPU

# calculate genotype specific sensitivity means for all possible pairs of loci
paired_all_hnew = extract_all_epistatic_pairs(
    model_H05,
    test_loader_gp_hnew,
    true_eff,
    phenotype_idx=[1, 2, 3, 4, 5],
    device=device,
    max_samples=2000,
)

# estimate quant-gen parameters
qg_estimates_hnew = estimate_effects_from_sensitivities(
    paired_all_hnew, heritability=new_heritability
)
Analyzing phenotypes: [1, 2, 3, 4, 5]
Generated 2016 total locus pairs for phenotype 1, 0 with known epistatic effects
Generated 2016 total locus pairs for phenotype 2, 32 with known epistatic effects
Generated 2016 total locus pairs for phenotype 3, 32 with known epistatic effects
Generated 2016 total locus pairs for phenotype 4, 32 with known epistatic effects
Generated 2016 total locus pairs for phenotype 5, 32 with known epistatic effects
Processed 1000 samples
Using heritability scaling factor: 0.5000 (h² = 0.25)
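Two numbers in this log can be checked by hand. The 2016 pairs per phenotype equal \(\binom{64}{2}\), which implies 64 loci. The reported scaling factor of 0.5 equals \(\sqrt{h^2}\) for \(h^2 = 0.25\).

Our reading of the scaling factor is an assumption, since the body of `estimate_effects_from_sensitivities` is not shown here. We take it that ground truth effects are rescaled by \(\sqrt{h^2}\) because shrinking genetic variance to a fraction \(h^2\) of a unit phenotypic variance scales every genetic effect by \(\sqrt{h^2}\).

```python
import math

n_loci = 64                      # inferred from the log, not read from the notebook
n_pairs = math.comb(n_loci, 2)   # number of unordered locus pairs
h2 = 0.25
scaling_factor = math.sqrt(h2)   # assumed relationship between h2 and the scaling factor
print(n_pairs, scaling_factor)   # 2016 0.5
```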
Plot epistatic effect size between all pairs of loci
# plot estimates vs. ground truth
plot_epistatic_estimates(qg_estimates_hnew)
Figure 9 panels (a) to (e), Phenotypes 1 to 5: scatter plots of ground truth epistatic interactions vs. interactions estimated from genotype-specific ELM locus sensitivities.
Figure 9: Ground truth epistatic interaction terms plotted against epistatic terms estimated from paired genotype-specific locus sensitivity shifts. Points shown for all possible pairs of loci. Locus pairs with no interaction have a ground truth epistatic term set to 0. Environmental noise added to phenotypes. Results presented for 5 phenotypes (a-e) ranging from purely additive (a) to almost purely epistatic (e).

Once more, we recover a good fit between ground truth pairwise epistatic terms and the estimates from ELM (Figure 9), albeit with a stronger underestimate of absolute effect sizes (again, regularization). In fact, the main effect of adding noise seems not to be a loss of the model's ability to learn the relative ranks of epistatic terms. Rather, noise inflates the false positive epistatic terms.

Focusing on phenotype 2, where predictive performance suggested that no epistasis had been learned, there's a very clean relationship between ground truth and learned epistatic couplings. This relationship seems to be masked by noise from false positive interactions, which might explain why this phenotype had an \(r^2\) lower than \(h^2\). This is a particularly promising result, as it implies that we might be able to extract information on epistatic interactions even when predictive performance suggests that epistasis hasn't been learned.

One way to do that is to fit multiple independent replicates of our neural network, estimate epistatic coefficients with ELM for each, and then average the results. This approach should work well if the noise from false positive epistatic interactions varies more stochastically across replicates than the signal learned from genuine interactions. Below, we show an attempt at this with 25 replicate runs.

Function to run replicates of ELM analysis
def run_epistasis_replicates(
    n_replicates,
    base_file_name,
    true_eff,
    phenotype_idx=None,
    heritability=1.0,
    hidden_size=1024,
    learning_rate=0.001,
    regularization=1,
    max_samples=2000,
    device=device,
    verbose=False,
):
    """
    Run epistasis interaction estimation across multiple model replicates.

    Args:
        n_replicates: Number of training replicates to run
        base_file_name: Base path for HDF5 files
        true_eff: DataFrame with true effects
        phenotype_idx: List of phenotype indices (1-based)
        heritability: Heritability for data generation
        hidden_size: Hidden layer size for neural network
        learning_rate: Learning rate for training
        regularization: Weight decay for training
        max_samples: Max samples for epistasis analysis
        device: Device to run on
        verbose: Whether to print progress

    Returns:
        dict: Contains replicate results and aggregated estimates
    """
    if phenotype_idx is None:
        phenotype_idx = [1, 2, 3, 4, 5]

    print(f"Starting {n_replicates} replicates of epistasis analysis...")
    print(f"Phenotypes: {phenotype_idx}, Heritability: {heritability}")

    # Storage for all replicate results
    all_qg_estimates = []

    for rep in range(n_replicates):
        if verbose:
            print(f"\n=== Replicate {rep + 1}/{n_replicates} ===")

        # Create data loaders with a different seed for each replicate
        loaders = create_data_loaders_with_heritability(
            base_file_name, heritability=heritability, seed=42 + rep
        )
        train_loader_gp = loaders["train_loader_gp"]
        test_loader_gp = loaders["test_loader_gp"]

        # Initialize and train the model (n_loci and n_phen come from the notebook's global scope)
        model = GP_net(n_loci=n_loci, hidden_layer_size=hidden_size, n_pheno=n_phen)

        print(f"Training model for replicate {rep + 1}...")

        model, best_loss = train_gpnet(
            model=model,
            train_loader=train_loader_gp,
            test_loader=test_loader_gp,
            n_loci=n_loci,
            verbose=False,
            learning_rate=learning_rate,
            weight_decay=regularization,
            device=device,
        )

        if verbose:
            print(f"Model trained. Best loss: {best_loss:.6f}")

        # Extract all epistatic pairs
        if verbose:
            print("Extracting epistatic pairs...")

        paired_all = extract_all_epistatic_pairs(
            model,
            test_loader_gp,
            true_eff,
            verbose=False,
            phenotype_idx=phenotype_idx,
            device=device,
            max_samples=max_samples,
        )

        qg_estimates = estimate_effects_from_sensitivities(
            paired_all,
            heritability=heritability,
            verbose=False,
        )

        # Add replicate identifier
        qg_estimates["replicate"] = rep + 1

        all_qg_estimates.append(qg_estimates)

        if verbose:
            # Quick summary stats for this replicate
            epi_pairs = qg_estimates[qg_estimates["true_epi_eff"] != 0]
            if len(epi_pairs) > 0:
                corr = np.corrcoef(epi_pairs["scaled_true_epi_eff"], epi_pairs["est_epi_eff"])[0, 1]
                print(f"  Epistatic correlation: {corr:.3f}")

    # Combine all replicates
    combined_results = pd.concat(all_qg_estimates, ignore_index=True)

    if verbose:
        print(f"\n=== Aggregating across {n_replicates} replicates ===")

    # Calculate mean estimates across replicates for each locus pair
    group_cols = ["phenotype", "phenotype_idx", "locus1", "locus2"]

    # Get mean estimates
    mean_estimates = (
        combined_results.groupby(group_cols)
        .agg(
            {
                "true_add_eff1": "first",
                "true_add_eff2": "first",
                "true_epi_eff": "first",
                "scaled_true_add_eff1": "first",
                "scaled_true_add_eff2": "first",
                "scaled_true_epi_eff": "first",
                "est_add_eff1": "mean",
                "est_add_eff2": "mean",
                "est_epi_eff": "mean",
                "heritability": "first",
                "encoding_type": "first",
            }
        )
        .reset_index()
    )

    # Get standard deviations
    std_estimates = (
        combined_results.groupby(group_cols)
        .agg({"est_add_eff1": "std", "est_add_eff2": "std", "est_epi_eff": "std"})
        .reset_index()
    )

    # Rename std columns
    std_estimates.columns = [
        col if col in group_cols else f"std_{col}" for col in std_estimates.columns
    ]

    # Merge means and stds
    final_estimates = mean_estimates.merge(std_estimates, on=group_cols)
    final_estimates["n_replicates"] = n_replicates

    if verbose:
        print(f"Final results: {len(final_estimates)} unique locus pairs")

        # Overall correlation for epistatic effects
        epi_pairs = final_estimates[final_estimates["true_epi_eff"] != 0]
        if len(epi_pairs) > 0:
            overall_corr = np.corrcoef(epi_pairs["scaled_true_epi_eff"], epi_pairs["est_epi_eff"])[
                0, 1
            ]
            print(f"Overall epistatic correlation: {overall_corr:.3f}")

            # Show correlation by phenotype
            for pheno in epi_pairs["phenotype"].unique():
                pheno_data = epi_pairs[epi_pairs["phenotype"] == pheno]
                if len(pheno_data) > 1:
                    pheno_corr = np.corrcoef(
                        pheno_data["scaled_true_epi_eff"], pheno_data["est_epi_eff"]
                    )[0, 1]
                    print(f"  {pheno}: r = {pheno_corr:.3f}")

    return {
        "all_replicates": all_qg_estimates,
        "combined_results": combined_results,
        "mean_estimates": final_estimates,
        "n_replicates": n_replicates,
        "phenotype_idx": phenotype_idx,
        "heritability": heritability,
    }
Run epistasis analysis across 25 replicates
# Note: 25 replicates (as used in our study) take ~45 min to run on CPU; runtime grows with n_replicates
epistasis_rep_results = run_epistasis_replicates(
    n_replicates=25,
    base_file_name=base_file_name,
    true_eff=true_eff,
    phenotype_idx=[1, 2, 3, 4, 5],
    heritability=new_heritability,
    hidden_size=hidden_size,
    learning_rate=learning_rate,
    regularization=regularization,
    max_samples=2000,
    device=device,
    verbose=False,
)
Starting 25 replicates of epistasis analysis...
Phenotypes: [1, 2, 3, 4, 5], Heritability: 0.25
Training model for replicate 1...
Training model for replicate 2...
Training model for replicate 3...
Training model for replicate 4...
Training model for replicate 5...
Training model for replicate 6...
Training model for replicate 7...
Training model for replicate 8...
Training model for replicate 9...
Training model for replicate 10...
Training model for replicate 11...
Training model for replicate 12...
Training model for replicate 13...
Training model for replicate 14...
Training model for replicate 15...
Training model for replicate 16...
Training model for replicate 17...
Training model for replicate 18...
Training model for replicate 19...
Training model for replicate 20...
Training model for replicate 21...
Training model for replicate 22...
Training model for replicate 23...
Training model for replicate 24...
Training model for replicate 25...
Plot epistasis estimates pooled over analysis replicates
def plot_replicate_epistasis_estimates(results, show_error_bars=False):
    """
    Plot mean epistatic estimates across replicates vs true effects.

    Args:
        results: Output from run_epistasis_replicates
        show_error_bars: Whether to show error bars from replicate variation
    """
    mean_estimates = results["mean_estimates"]
    unique_phenotypes = mean_estimates["phenotype"].unique()

    for phenotype in unique_phenotypes:
        with mpl.rc_context(plot_style):
            plt.figure(figsize=(6, 6))

            phenotype_data = mean_estimates[mean_estimates["phenotype"] == phenotype]

            if show_error_bars and "std_est_epi_eff" in phenotype_data.columns:
                plt.errorbar(
                    phenotype_data["scaled_true_epi_eff"],
                    phenotype_data["est_epi_eff"],
                    yerr=phenotype_data["std_est_epi_eff"],
                    fmt="o",
                    alpha=0.7,
                    capsize=3,
                    color=apc.steel,
                )
            else:
                plt.scatter(
                    phenotype_data["scaled_true_epi_eff"],
                    phenotype_data["est_epi_eff"],
                    alpha=0.7,
                    color=apc.steel,
                )

            # Add best-fit line and slope (skipped for Phenotype_1, which has no epistatic effects)
            if phenotype != "Phenotype_1" and len(phenotype_data) > 1:
                z = np.polyfit(
                    phenotype_data["scaled_true_epi_eff"], phenotype_data["est_epi_eff"], 1
                )
                p = np.poly1d(z)
                xlims = plt.xlim()
                x_smooth = np.linspace(xlims[0], xlims[1], 100)
                plt.plot(x_smooth, p(x_smooth), "r-", alpha=0.7, linewidth=2)

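                slope = z[0]
                # r2_score scores the estimates as predictions of the scaled true effects,
                # so shrinkage of the estimates (slope < 1) also lowers this value
                r2 = r2_score(phenotype_data["scaled_true_epi_eff"], phenotype_data["est_epi_eff"])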

                plt.text(
                    0.05,
                    0.95,
                    f"Slope = {slope:.3f}\nr² = {r2:.2f}",
                    transform=plt.gca().transAxes,
                    verticalalignment="top",
                    fontsize=15,
                    bbox=dict(
                        boxstyle="round", facecolor=apc.parchment, alpha=0.7, edgecolor="none"
                    ),
                )

            plt.xlabel("True Epistatic Effect (scaled)", fontsize=16)
            plt.ylabel("Estimated Epistatic Effect", fontsize=16)

            # Get current axes for formatting
            ax = plt.gca()
            ax.xaxis.set_major_locator(MaxNLocator(4))
            ax.yaxis.set_major_locator(MaxNLocator(4))
            ax.xaxis.set_major_formatter(FormatStrFormatter("%.2f"))
            ax.yaxis.set_major_formatter(FormatStrFormatter("%.2f"))

            plt.tight_layout(pad=0.5)
            apc.mpl.style_plot(monospaced_axes="both")

            plt.show()


plot_replicate_epistasis_estimates(epistasis_rep_results)
Figure 10 panels (a) to (e), Phenotypes 1 to 5: scatter plots of ground truth epistatic interactions vs. interactions estimated from genotype-specific ELM locus sensitivities, averaged over replicates.
Figure 10: Ground truth epistatic interaction terms plotted against epistatic terms estimated from paired genotype-specific locus sensitivity shifts. Points shown for all possible pairs of loci. Epistatic terms estimated by averaging results from 25 replicates of model fitting. Environmental noise added to phenotypes. Results presented for 5 phenotypes (a-e) ranging from purely additive (a) to almost purely epistatic (e).

Sure enough, averaging over model replicates neatly cleans up the epistatic results, both by shrinking the false positives toward 0 and by improving the fit to the true interactions (Figure 10). More replicates would likely make the result cleaner still. This supports our intuition that epistatic signal can be separated from noise by averaging, and implies that we might be able to learn about epistasis under a variety of model performance conditions.
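A toy simulation (not the notebook's pipeline) illustrates why averaging helps. It rests on the key assumption that false-positive noise is independent across replicates while the learned signal is shared.

Each simulated replicate returns a shrunken copy of the true interaction terms plus independent noise. The 2016 pairs and 32 true interactions mirror the log above. The shrinkage of 0.45 is borrowed from the additive slope noted earlier, and the noise and effect scales are illustrative choices. Averaging 25 replicates should cut the spread of null-pair estimates by roughly \(1/\sqrt{25} = 0.2\).

```python
import numpy as np

rng = np.random.default_rng(4)
n_pairs, n_true, n_reps = 2016, 32, 25
shrinkage, noise_sd = 0.45, 0.1                          # illustrative values

true_epi = np.zeros(n_pairs)
true_epi[:n_true] = rng.normal(scale=0.5, size=n_true)   # genuinely interacting pairs

# Each replicate: shared (shrunken) signal + independent false-positive noise
replicates = shrinkage * true_epi + rng.normal(scale=noise_sd, size=(n_reps, n_pairs))
single = replicates[0]
averaged = replicates.mean(axis=0)

null = true_epi == 0
sd_ratio = averaged[null].std() / single[null].std()
r_single = np.corrcoef(true_epi[~null], single[~null])[0, 1]
r_averaged = np.corrcoef(true_epi[~null], averaged[~null])[0, 1]
print(f"null-pair SD ratio: {sd_ratio:.2f} (expected ~{1 / np.sqrt(n_reps):.2f})")
print(f"correlation with truth: single {r_single:.3f}, averaged {r_averaged:.3f}")
```

The same calculation also shows the limit of the approach: any error component shared by every replicate would not shrink with averaging, and neither would the shrinkage factor itself.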

References

Adams E, Bai L, Lee M, Yu Y, AlQuraishi M. (2025). From Mechanistic Interpretability to Mechanistic Biology: Training, Evaluating, and Interpreting Sparse Autoencoders on Protein Language Models. https://doi.org/10.1101/2025.02.06.636901
Bricken T, Templeton A, Batson J, Chen B, Jermyn A, Conerly T, Turner N, Anil C, Denison C, Askell A, Lasenby R, Wu Y, Kravec S, Schiefer N, Maxwell T, Joseph N, Hatfield-Dodds Z, Tamkin A, Nguyen K, McLean B, Burke JE, Hume T, Carter S, Henighan T, Olah C. (2023). Towards monosemanticity: Decomposing language models with dictionary learning
Elhage N, Nanda N, Olsson C, Henighan T, Joseph N, Mann B, Askell A, Bai Y, Chen A, Conerly T, DasSarma N, Drain D, Ganguli D, Hatfield-Dodds Z, Hernandez D, Jones A, Kernion J, Lovitt L, Ndousse K, Amodei D, Brown T, Clark J, Kaplan J, McCandlish S, Olah C. (2021). A mathematical framework for transformer circuits
Gaynor RC, Gorjanc G, Hickey JM. (2021). AlphaSimR: An R package for breeding program simulations. https://doi.org/10.1093/g3journal/jkaa017
Golden JR. (2025). Equivalent linear mappings of large language models
Mackay TFC. (2014). Epistasis and quantitative traits: Using model organisms to study gene–gene interactions. https://doi.org/10.1038/nrg3627
Mohan S, Kadkhodaie Z, Simoncelli EP, Fernandez-Granda C. (2020). Robust and interpretable blind image denoising via bias-free convolutional neural networks
Rijal K, Holmes CM, Petti S, Reddy G, Desai MM, Mehta P. (2025). Inferring genotype-phenotype maps using attention models
Sandler G, York R. (2025). Epistasis and deep learning in quantitative genetics. https://doi.org/10.57844/arcadia-25nt-guw3
Sigurdsson AI, Louloudis I, Banasik K, Westergaard D, Winther O, Lund O, Ostrowski SR, Erikstrup C, Pedersen OBV, Nyegaard M, DBDS Genomic Consortium, Brunak S, Vilhjálmsson BJ, Rasmussen S. (2023). Deep integrative models for large-scale human genomics. https://doi.org/10.1093/nar/gkad373
Tucker SL, Dohleman FG, Grapov D, Flagel L, Yang S, Wegener KM, Kosola K, Swarup S, Rapp RA, Bedair M, Halls SC, Glenn KC, Hall MA, Allen E, Rice EA. (2020). Evaluating maize phenotypic variance, heritability, and yield relationships at multiple biological scales across agronomically relevant environments. https://doi.org/10.1111/pce.13681
Wang S, Mohamed A-R, Caruana R, Bilmes J, Philipose M, Richardson M, Geras K, Urban G, Aslan O. (2016). Analysis of deep neural networks with the extended data jacobian matrix
York R, Kiefl E, Bigge BM, McGeever E. (2025). Cross-trait learning with a canonical transformer tops custom attention in genotype–phenotype mapping. https://doi.org/10.57844/arcadia-bmb9-fzxd
York R, Mets DG. (2025). G–P Atlas: A neural network framework for mapping genotypes to many phenotypes. https://doi.org/10.57844/arcadia-d316-721f
Zeng S, Mao Z, Ren Y, Wang D, Xu D, Joshi T. (2021). G2PDeep: A web-based deep-learning framework for quantitative phenotype prediction and discovery of genomic markers. https://doi.org/10.1093/nar/gkab407