import subprocess
from pathlib import Path

dataset_dir = Path("datasets/")
if not dataset_dir.exists():
    zenodo_url = "https://zenodo.org/records/15313069/files/datasets.zip"
    subprocess.run(["wget", zenodo_url])
    subprocess.run(["unzip", "datasets.zip"])
    Path("datasets.zip").unlink()
We used Claude to help write and clean up code. We used ChatGPT to help write and clean up code, write text that we edited, suggest wording ideas (we then chose which small phrases or sentence-structure ideas to use), expand on summary text that we provided (we then edited the text it produced), clarify and streamline text that we wrote, interpret model training results, and suggest papers on relevant science; we did further reading and cited some of this literature. We also provided ChatGPT with starting text and had it rearrange that text to fit the structure of one of our pub templates. We used Gemini in similar ways: to help write and clean up code, write text that we edited, suggest wording ideas (we then chose which small phrases or sentence-structure ideas to use), clarify and streamline text that we wrote, interpret model training results, and suggest papers on relevant science; again, we did further reading and cited some of this literature. We also provided Gemini with starting text and had it rearrange that text to fit the structure of one of our pub templates, and had it expand on summary text that we provided, after which we edited the text it produced.
Introduction
On April 11th, Rijal et al. (2025) released a preprint that introduces an application of attention mechanisms for inferring genotype-phenotype maps, particularly focusing on capturing complex epistatic interactions. This work sparked considerable interest given our ongoing exploration of nonlinear genotype-phenotype models.
By training on 100,000 yeast segregants over 18 growth phenotypes, Rijal et al. (2025) demonstrated that an attention-based architecture extracts more epistatic signals than conventional linear methods, and does so with far fewer parameters than the full second-order regression that has become a norm in quantitative genetics. Their study therefore provides an important proof-of-principle for deep-learning architectures being able to cope with linkage disequilibrium, noise, and sparse high-order interactions in large-scale genotype-phenotype data.
That success naturally raised our curiosity. The network they used omits several standard transformer components that brought attention into the limelight: skip connections, layer normalization, and feed-forward sub-layers (Vaswani et al., 2017). Given the cross-domain success that transformers have seen, we wondered: could the performance be further improved by replacing their model with a standard “vanilla” transformer architecture? We found out. Along the way, we also uncovered an opportunity that Rijal et al. (2025) missed: reinforcing the learning signal by properly leveraging cross-phenotype genetic correlations, which led to significant performance gains.
The dataset
The experimental data used in Rijal et al. (2025) comes from the work of Nguyen Ba et al. (2022), who performed a large-scale quantitative trait locus (QTL) study in yeast. In short, they measured the growth rates of ~100,000 yeast segregants across 18 conditions and for ~40,000 loci, creating a massive dataset suitable for mapping genotype to phenotype.
Due to extensive linkage disequilibrium (LD), the loci in the dataset are highly correlated with each other. To create a set of approximately independent loci, Rijal et al. (2025) selected loci such that the correlation between the SNPs at any pair of selected loci is less than 94%, yielding 1164 “independent” loci.
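To give a sense of what this pruning involves, here’s a minimal greedy sketch. It’s illustrative only, not the authors’ actual procedure (that lives in the notebook referenced below), and `genotypes` is a hypothetical matrix of ±1 alleles:

```python
import numpy as np


def prune_correlated_loci(genotypes: np.ndarray, max_corr: float = 0.94) -> list[int]:
    """Greedily keep loci whose allele vectors correlate (in absolute value)
    below `max_corr` with every locus already kept.

    `genotypes` is an (n_segregants, n_loci) matrix of ±1 alleles.
    Note: quadratic in the number of loci, so fine as an illustration but
    slow at the ~40,000 loci in the full dataset.
    """
    kept: list[int] = []
    for locus in range(genotypes.shape[1]):
        candidate = genotypes[:, locus]
        if all(
            abs(np.corrcoef(candidate, genotypes[:, k])[0, 1]) < max_corr
            for k in kept
        ):
            kept.append(locus)
    return kept
```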
We were unable to find the set of loci, nor the genotypic and phenotypic data used for training, so we located the raw data that Nguyen Ba et al. (2022) originally uploaded alongside their study, then used this notebook uploaded by Rijal et al. (2025) to recapitulate the 1164 loci. To save everyone else the trouble, we uploaded the train, test, and validation datasets we’re pretty sure Rijal et al. (2025) used in their study.
You can find the data here: https://zenodo.org/records/15313069
We’ll use this data in what follows; the code block at the top of this notebook downloads it into the current working directory. For reference, the table below lists the per-phenotype \(R^2\) values (“source R²”) that we scraped from Figure 3 of Rijal et al. (2025) and compare against below:
 | Phenotype | Source R² |
---|---|---|
0 | 23C | 0.612 |
1 | 25C | 0.621 |
2 | 27C | 0.632 |
3 | 30C | 0.609 |
4 | 33C | 0.615 |
5 | 35C | 0.582 |
6 | 37C | 0.566 |
7 | cu | 0.610 |
8 | suloc | 0.638 |
9 | ynb | 0.487 |
10 | eth | 0.604 |
11 | gu | 0.566 |
12 | li | 0.630 |
13 | mann | 0.600 |
14 | mol | 0.618 |
15 | raff | 0.653 |
16 | sds | 0.630 |
17 | 4NQO | 0.609 |
Reproducing single phenotype results
Let’s first try and reproduce the single-phenotype attention model performances observed in Figure 3 (red dots). This will give us the confidence to know we both (a) correctly reverse-engineered the specifics of their training/validation/test datasets and (b) accurately implemented their model.
Note that a separate model is trained on each phenotype, so reproducing this figure involves training 18 models. To do this, we need to create config objects specifying each model architecture and how the training should proceed.
Code
import attrs

from analysis.base import ModelConfig, TrainConfig
from analysis.dataset import phenotype_names

model_config = ModelConfig(
    model_type="rijal_et_al",
    seq_length=1164,
    embedding_dim=13,
    num_layers=3,
)

# This template config sets all the shared parameters.
train_template_config = TrainConfig(
    data_dir=dataset_dir,
    save_dir=Path("models/fig3"),
    name_prefix="fig3_23C",
    phenotypes=["23C"],
    optimizer="adam",
    batch_size=64,
    learning_rate=0.001,
    lr_schedule=False,
    weight_decay=0.0,
    max_epochs=200,
    gradient_clip_val=0,
    use_cache=True,
    use_modal=False,
)

jobs = []
for phenotype in phenotype_names:
    # Each train config needs to set its corresponding phenotype(s).
    phenotype_config = attrs.evolve(
        train_template_config, phenotypes=[phenotype], name_prefix=f"fig3_{phenotype}"
    )
    jobs.append((model_config, phenotype_config))
With each job properly configured, we can train a model for each phenotype:
Code
from analysis.train import run_trainings

model_dirs = run_trainings(jobs)
Pre-trained model 'fig3_23C' found. Returning path.
Pre-trained model 'fig3_25C' found. Returning path.
Pre-trained model 'fig3_27C' found. Returning path.
Pre-trained model 'fig3_30C' found. Returning path.
Pre-trained model 'fig3_33C' found. Returning path.
Pre-trained model 'fig3_35C' found. Returning path.
Pre-trained model 'fig3_37C' found. Returning path.
Pre-trained model 'fig3_cu' found. Returning path.
Pre-trained model 'fig3_suloc' found. Returning path.
Pre-trained model 'fig3_ynb' found. Returning path.
Pre-trained model 'fig3_eth' found. Returning path.
Pre-trained model 'fig3_gu' found. Returning path.
Pre-trained model 'fig3_li' found. Returning path.
Pre-trained model 'fig3_mann' found. Returning path.
Pre-trained model 'fig3_mol' found. Returning path.
Pre-trained model 'fig3_raff' found. Returning path.
Pre-trained model 'fig3_sds' found. Returning path.
Pre-trained model 'fig3_4NQO' found. Returning path.
The above code will initiate training for all configured jobs, with important considerations:
- Caching: Since `train_config.use_cache = True` and these models are stored in the GitHub repository, executing this locally will use cached models instead of performing expensive retraining.
- Training Duration: Each training job takes approximately 2.5 hours to complete on an A10G GPU, so running all jobs without caching would require significant time.
- Compute Configuration: In our production environment, we set `train_config.use_modal = True` to distribute compute jobs via Modal. For compatibility with different compute architectures, this notebook uses `train_config.use_modal = False` by default.
Inside each run directory is a `metrics.csv` from which we can pull the test dataset \(R^2\) and compare it directly against Figure 3 from Rijal et al. (2025).
Code
from pathlib import Path

import arcadia_pycolor as apc
import matplotlib as mpl
import matplotlib.pyplot as plt
import numpy as np
import pandas as pd

apc.mpl.setup()


def get_test_r2(model_dir: Path) -> float:
    metrics = pd.read_csv(model_dir / "metrics.csv")
    return float(metrics.loc[metrics["metric"] == "test_r2", "value"].iloc[0])


fig3_reproduction_r2 = [get_test_r2(d) for d in model_dirs]
fig3_results["reproduction R2"] = fig3_reproduction_r2
df = pd.DataFrame(fig3_results)

x = np.arange(len(df))
width = 0.35

with mpl.rc_context({"figure.facecolor": apc.parchment}):
    plt.figure(dpi=300)

    plt.bar(x - width / 2, df["source R2"], width, label="Source R²", color=apc.cloud)
    plt.bar(x + width / 2, df["reproduction R2"], width, label="Reproduction R²", color=apc.steel)

    plt.xlabel("Phenotype")
    plt.ylabel("R² Score")
    plt.ylim(0.46, 0.67)
    plt.xticks(x, df["phenotype"], rotation=90)
    plt.legend(loc=(0.05, 0.90), ncol=2)

    plt.tight_layout()
    plt.show()

With an \(n\) of 1, we can’t assess reproducibility rigorously. Even so, our re-implementation matches the published numbers very closely. Here are some numbers on the phenotype with the largest deviation:
Code
deltas = (df["source R2"] - df["reproduction R2"]).abs()
max_idx = deltas.argmax()
max_delta = deltas.max()
phenotype_with_largest_delta = df["phenotype"].iloc[max_idx]
percent_diff = 100 * max_delta / df["source R2"].iloc[max_idx]

print("Biggest discrepancy:")
print(f"- Phenotype: {phenotype_with_largest_delta}")
print(f"- R2 difference: {max_delta:.3f}")
print(f"- Percent difference: {percent_diff:.1f}")
Biggest discrepancy:
- Phenotype: 33C
- R2 difference: 0.008
- Percent difference: 1.3
Overall, these results help assure us that:
- The train/validation/test partitions are identical, or at least functionally equivalent
- Our code reproduces the authors’ architecture and training procedure.
With this validated baseline in place, we’re now confident in using it as the starting point for modifying the architecture.
Canonical ML components outperform bespoke customization
The Rijal et al. (2025) model architecture uses non-standard components. To test whether these idiosyncratic choices are actually helpful, we replaced each one with a standard, “textbook” alternative and measured the collective impact on predictive accuracy.
Non-standard element in Rijal et al. | What it does | Canonical replacement we tried |
---|---|---|
Random projection of a diagonal genotype matrix `X(g)·R` | Encodes each locus as a row of a fixed random matrix `R`, multiplied by the allele sign (±1). | A learned embedding table `nn.Embedding(L, D)` whose rows are multiplied by the allele sign. |
Concatenated column of ones | Appends a constant 1 to every token to mimic an explicit bias term (Appendix F in the original paper). | Simply enable the bias in the linear Q, K, V projections (`nn.Linear(..., bias=True)`). |
Phenotype represented as an input token | Adds one-hot phenotype tokens to the sequence, forcing attention layers to discover gene × phenotype interactions. | Condition the model after the attention block: predict all phenotypes jointly from a pooled sequence representation (see the sketch below). |
Flattened fitness matrix | Treats each (genotype, phenotype) pair as an independent sample; the network outputs one scalar at a time. | Mean-pool the token embeddings → (B, D) and use a single linear layer D → 18 so all phenotypes are predicted simultaneously. |
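To make these replacements concrete, here’s a minimal sketch of the resulting embedding → mean-pool → linear-head model. This isn’t the exact code in our repository; the class and argument names are illustrative and the attention blocks are simplified, but it shows how the pieces fit together:

```python
import torch
import torch.nn as nn


class CanonicalModel(nn.Module):
    """Sketch: learned locus embeddings scaled by allele sign, self-attention
    with biased Q/K/V projections, mean pooling, and a single linear head that
    predicts all 18 phenotypes jointly."""

    def __init__(self, num_loci=1164, dim=128, num_layers=3, num_phenotypes=18):
        super().__init__()
        self.embedding = nn.Embedding(num_loci, dim)  # replaces the fixed random projection
        self.attention_layers = nn.ModuleList(
            [nn.MultiheadAttention(dim, num_heads=1, bias=True, batch_first=True)
             for _ in range(num_layers)]
        )
        self.head = nn.Linear(dim, num_phenotypes)  # one output per phenotype

    def forward(self, genotypes):  # genotypes: (B, L) float tensor of ±1 alleles
        locus_ids = torch.arange(genotypes.shape[1], device=genotypes.device)
        x = self.embedding(locus_ids) * genotypes.unsqueeze(-1)  # sign-scaled embeddings
        for attn in self.attention_layers:
            x, _ = attn(x, x, x)
        pooled = x.mean(dim=1)  # (B, D) pooled sequence representation
        return self.head(pooled)  # (B, 18): all phenotypes in one forward pass
```

The important difference from the original setup is in the final two lines: one forward pass yields predictions for all 18 phenotypes, so every genotype contributes gradient signal to every phenotype simultaneously.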
Why predict all phenotypes at once?
The measured phenotypes are likely to be genetically correlated due to biological processes such as pleiotropy. Consequently, knowing that a mutation hurts growth in one condition has the potential to inform its effect in another. A shared output head lets the network exploit this mutual information, whereas the original set-up can only share knowledge through the shared attention weights. Our phenotype-phenotype autoencoder study (Avasthi et al., 2023) showed significant prediction benefits when accounting for phenotypic covariation, so we’re expecting the same thing to be true here, too.
Experimental protocol
The codebase contains a new architecture with the above modifications. Let’s test its performance with the following changes to the training:
- Increase the number of phenotypes from \(1\) to \(18\) (all).
- Correspondingly increase the hidden dimension from \(d=12\) to \(d=128\).
- Decrease the learning rate 10-fold, from \(1 \times 10^{-3}\) to \(1 \times 10^{-4}\).
Finally, because Figure 1 hints at potentially significant run-to-run variance, let’s run five replicates, where each replicate differs only in its initializing seed (the train/validation/test splits are held constant).
Code
NUM_REPLICATES = 5

jobs = []
for i in range(NUM_REPLICATES):
    replicate_id = f"{i:02d}"
    job_name_prefix = f"std_d128_rep_{replicate_id}"

    model_config = ModelConfig(
        embedding_dim=128,
        model_type="modified",
        seq_length=1164,
        num_layers=3,
    )
    train_config = attrs.evolve(
        train_template_config,
        save_dir=Path("models/canonical"),
        phenotypes=phenotype_names,
        name_prefix=job_name_prefix,
        optimizer="adamw",
        max_epochs=100,
    )
    jobs.append((model_config, train_config))

print(f"\nGenerated {len(jobs)} job configurations.")
Generated 5 job configurations.
If you’re running this yourself and aren’t interested in replicates, you can reduce the amount of required compute by setting `NUM_REPLICATES` in the above cell to 1.
Now, let’s run the experiment:
Code
model_dirs = run_trainings(jobs)
Pre-trained model 'std_d128_rep_00' found. Returning path.
Pre-trained model 'std_d128_rep_01' found. Returning path.
Pre-trained model 'std_d128_rep_02' found. Returning path.
Pre-trained model 'std_d128_rep_03' found. Returning path.
Pre-trained model 'std_d128_rep_04' found. Returning path.
Consistent gains across phenotypes
Figure 2 shows that swapping the custom choices in Rijal et al. (2025) for a canonical embedding \(\rightarrow\) mean-pool \(\rightarrow\) linear-head model and predicting all 18 phenotypes jointly yields a boost in predictive power across the board:
Code
label_lookup = {
    "std": "Canonical",
}


def get_phenotype_r2_data(model_dir: Path) -> pd.DataFrame:
    # Determine the architecture and replicate number from model directory name
    dir_name = model_dir.parent.parent.name  # e.g., "std_d128_rep_09"
    variant_key, _, rep_str = dir_name.rpartition("_rep_")
    base_arch_key, _, _ = variant_key.rpartition("_d")
    architecture_label = label_lookup[base_arch_key]

    # Load and wrangle the metrics.csv. Keep only the per-phenotype test R2 values.
    df = pd.read_csv(model_dir / "metrics.csv")
    df = df[df["metric"].str.startswith("test_r2_")]
    df["phenotype"] = df["metric"].str[8:]
    df.drop("metric", axis=1, inplace=True)
    df.rename(columns={"value": "r2"}, inplace=True)
    df["architecture"] = architecture_label
    df["replicate"] = int(rep_str)
    df["r2"] = df["r2"].astype(float)

    return df


# Concat the phenotype data across all models
canonical_plot_data = pd.concat([get_phenotype_r2_data(model_dir) for model_dir in model_dirs])

# Calculate the mean and standard deviation of R2 over replicates
canonical_plot_data = (
    canonical_plot_data.groupby(["architecture", "phenotype"])["r2"]
    .agg(mean_r2="mean", std_r2="std")
    .reset_index()
)

# Concatenate the manually scraped Fig3 R2s
rijal_r2s = fig3_results.copy(deep=True)
rijal_r2s.drop("reproduction R2", axis=1, inplace=True)
rijal_r2s.rename(columns={"source R2": "mean_r2"}, inplace=True)
rijal_r2s["std_r2"] = 0.0
rijal_r2s["architecture"] = "Rijal"
canonical_plot_data = pd.concat([canonical_plot_data, rijal_r2s])

# Plotting code.
canonical_plot_data["phenotype"] = pd.Categorical(
    canonical_plot_data["phenotype"], categories=phenotype_names, ordered=True
)
arch_order = ["Rijal", "Canonical"]

x = np.arange(len(phenotype_names))
width = 0.25

color_map = {
    "Rijal": apc.cloud,
    "Canonical": apc.mars,
}

with mpl.rc_context({"figure.facecolor": apc.parchment}):
    plt.figure(dpi=300)

    for i, arch in enumerate(arch_order):
        sub = (
            canonical_plot_data[canonical_plot_data["architecture"] == arch]
            .set_index("phenotype")
            .reindex(phenotype_names)
        )
        plt.bar(
            x + (i - 0.5) * width,  # centre bars around tick
            sub["mean_r2"],
            width,
            yerr=sub["std_r2"],
            ecolor="#00000099",
            error_kw=dict(elinewidth=1.5),
            label=arch,
            color=color_map[arch],
        )

    plt.xticks(x, phenotype_names, rotation=90)
    plt.xlabel("Phenotype")
    plt.ylabel("Mean $R^2$ (± SD)")
    plt.ylim(0.43, 0.69)
    plt.legend(loc=(0.05, 0.92), ncol=3)

    plt.tight_layout()
    plt.show()

The numbers labeled “Rijal” in Figure 2 come from the single-phenotype models reported in their Figure 3. Rijal et al. (2025) do develop a multi-phenotype architecture, but careful inspection of their code shows that it still emits a scalar output per forward pass and simply loops over phenotypes at train time. In other words, each (genotype, phenotype) pair is treated as an independent sample. The predictive power of this model is shown in their Figure 4, and is actually worse than the single-phenotype models at predicting any given phenotype.
Because our goal is to test whether a genuine multi-output head + canonical components confer an advantage, we chose the best-performing baseline available (their single-phenotype models). The comparison is therefore conservative: any gains we report would be larger, not smaller, if the authors’ scalar multi-phenotype model were used as the reference.
Quantifying which changes contributed to these gains would require an ablation study; however, our hunch is that the primary gains come from cross-phenotype genetic correlations. For instance, we note that some of the most consistent gains in our implementation come from fitness measurements at different temperatures, a set of phenotypes that is almost certainly impacted by pleiotropic sets of genes. Joint training allows the network to transfer information among genetically correlated traits, an advantage the single-output baseline can’t exploit beyond implicit relationships learned through shared attention weights.
Taken together, these results validate the canonical multi-output model as a stronger starting point than the original design.
A vanilla transformer architecture
The current architecture still departs from the reference transformer in several respects: it omits residual connections, layer normalization, and position-wise feed-forward blocks. The logical next experiment is therefore to level up to a bona-fide vanilla transformer, preserving the current tokenization while adding:
- Residual (skip) connections and pre-LayerNorm,
- Feed-forward sub-layers with ReLU activations,
- Scaled dot-product attention with the canonical \(1/\sqrt{d}\) factor,
- Dropout and weight decay for regularization.
It’s worth noting that transformers excel at sequence modeling, and sequences are ordered collections. While the loci in this dataset could be assigned chromosomal coordinates, and therefore an ordering of sorts, we follow Rijal et al. (2025) in treating the loci as an unordered collection. Thus, we don’t add any positional encodings, absolute or relative.
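Conceptually, this amounts to stacking standard encoder layers. Here’s a rough sketch using PyTorch’s built-in modules (illustrative only; our actual implementation is the `transformer` model type configured below):

```python
import torch.nn as nn

# Rough sketch of the vanilla variant: pre-LayerNorm encoder layers with
# residual connections and scaled dot-product attention, no positional
# encoding, followed by the same mean-pool + joint linear head as before.
encoder_layer = nn.TransformerEncoderLayer(
    d_model=256,            # embedding_dim
    nhead=4,                # attention heads
    dim_feedforward=1024,   # position-wise feed-forward width
    activation="relu",
    norm_first=True,        # pre-LayerNorm
    batch_first=True,
)
encoder = nn.TransformerEncoder(encoder_layer, num_layers=3)
head = nn.Linear(256, 18)   # joint prediction of all 18 phenotypes
```

Because no positional encodings are added, the encoder treats the loci as an unordered set; each locus is distinguished only by its learned embedding.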
Code
NUM_REPLICATES = 5

jobs = []
for i in range(NUM_REPLICATES):
    replicate_id = f"{i:02d}"
    job_name_prefix = f"xformer_rep_{replicate_id}"

    model_config = ModelConfig(
        embedding_dim=256,
        model_type="transformer",
        seq_length=1164,
        num_layers=3,
        nhead=4,
        dim_feedforward=1024,
    )
    train_config = attrs.evolve(
        train_template_config,
        save_dir=Path("models/transformer"),
        phenotypes=phenotype_names,
        name_prefix=job_name_prefix,
        optimizer="adamw",
        max_epochs=80,
    )
    jobs.append((model_config, train_config))

print(f"\nGenerated {len(jobs)} job configurations.")
Generated 5 job configurations.
Code
model_dirs = run_trainings(jobs)
Pre-trained model 'xformer_rep_00' found. Returning path.
Pre-trained model 'xformer_rep_01' found. Returning path.
Pre-trained model 'xformer_rep_02' found. Returning path.
Pre-trained model 'xformer_rep_03' found. Returning path.
Pre-trained model 'xformer_rep_04' found. Returning path.
Figure 3 shows the change in performance:
Code
label_lookup = {
    "std": "Canonical",
    "xformer": "Transformer",
}


def get_phenotype_r2_data_xformer(model_dir: Path) -> pd.DataFrame:
    # Determine the architecture and replicate number from model directory name
    dir_name = model_dir.parent.parent.name  # e.g., "xformer_rep_00"
    variant_key, _, rep_str = dir_name.rpartition("_rep_")
    architecture_label = label_lookup[variant_key]

    # Load and wrangle the metrics.csv. Keep only the per-phenotype test R2 values.
    df = pd.read_csv(model_dir / "metrics.csv")
    df = df[df["metric"].str.startswith("test_r2_")]
    df["phenotype"] = df["metric"].str[8:]
    df.drop("metric", axis=1, inplace=True)
    df.rename(columns={"value": "r2"}, inplace=True)
    df["architecture"] = architecture_label
    df["replicate"] = int(rep_str)
    df["r2"] = df["r2"].astype(float)

    return df


# Concat the phenotype data across all models
df = pd.concat([get_phenotype_r2_data_xformer(model_dir) for model_dir in model_dirs])
df = df.groupby(["architecture", "phenotype"])["r2"].agg(mean_r2="mean", std_r2="std").reset_index()

# Concat with existing canonical plot data
xformer_plot_data = pd.concat([df, canonical_plot_data])

# Plotting code.
xformer_plot_data["phenotype"] = pd.Categorical(
    xformer_plot_data["phenotype"], categories=phenotype_names, ordered=True
)
arch_order = ["Rijal", "Canonical", "Transformer"]

x = np.arange(len(phenotype_names))
width = 0.25

color_map = {
    "Rijal": apc.cloud,
    "Canonical": apc.mars,
    "Transformer": apc.dragon,
}

with mpl.rc_context({"figure.facecolor": apc.parchment}):
    plt.figure(dpi=300)

    for i, arch in enumerate(arch_order):
        sub = (
            xformer_plot_data[xformer_plot_data["architecture"] == arch]
            .set_index("phenotype")
            .reindex(phenotype_names)
        )
        plt.bar(
            x + (i - 1) * width,  # centre bars around tick
            sub["mean_r2"],
            width,
            yerr=sub["std_r2"],
            ecolor="#00000099",
            error_kw=dict(elinewidth=1.5),
            label=arch,
            color=color_map[arch],
        )

    plt.xticks(x, phenotype_names, rotation=90)
    plt.xlabel("Phenotype")
    plt.ylabel("Mean $R^2$ (± SD)")
    plt.ylim(0.43, 0.69)
    plt.legend(loc=(0.05, 0.92), ncol=3)

    plt.tight_layout()
    plt.show()

The full transformer nudges performance above the already-improved canonical model across most phenotypes (Figure 3). The lift is consistent but modest; we don’t see a step change in performance. There’s undoubtedly a lot of room for improvement: we made no attempt at hyperparameter tuning beyond some cursory investigations of dropout and weight decay (neither of which improved performance), and all other parameters were first guesses.
Conclusion and outlook
Our central question was simple: do off-the-shelf transformer components—and a genuine multi-output head—move the needle on genotype to phenotype prediction?
The answer is an emphatic yes.
Why multi-output matters
Leveraging mutual information between genetically correlated phenotypes represents a natural way of boosting model performance when training data is limited. In the context of plant and animal breeding, multi-trait genomic prediction is well established to improve the predictive power of linear models, particularly for low-heritability traits (Jia and Jannink, 2012). Our own work with phenotype-phenotype autoencoders has demonstrated that encoding multiple phenotypes jointly lets an autoencoder predict individual phenotypes with very high accuracy, particularly as the number of phenotypes considered simultaneously increases (Avasthi et al., 2023). These observations are examples of the benefits of multi-task learning, which have long been appreciated in the ML literature (Caruana, 1997). Our results are no different: when the network can see all 18 phenotypes at once, shared genetic effects, whether through pleiotropy or in some instances through LD (Chebib and Guillaume, 2021), reinforce, rather than fragment, the learning signal.
A foundation, not a finish line
A natural next step given these results is to test cross-environment transfer learning in our multi-task model. Rijal et al. (2025) demonstrated that transfer learning was possible across temperature growth conditions even with a multi-environment model that did not outperform a single-environment model. Given the overall superior performance of the models we present here, they may be particularly well suited for fitness prediction for novel phenotypes with even fewer fine-tuning observations than used by Rijal et al. (2025).
Finally, it would be interesting to evaluate, from a quantitative genetics perspective, where the performance boost of multi-task learning comes from. Phenotypic variance can generally be broken down into additive (\(G\)), epistatic (\(G \times G\)), and gene-by-environment (\(G \times E\)) components (ignoring higher-order terms for simplicity). Intuitively, we might expect multi-trait learning to excel at explaining the \(G\) and \(G \times G\) variance components that are constant across environments; however, it’s possible that \(G \times E\) variance is also better explained, as the benefits of multi-task learning are known to extend even to unrelated tasks (Paredes et al., 2012). We suspect this could be investigated through ablation studies that track environment-specific prediction outcomes.
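Written out, with a residual term \(V_e\) added here for the remaining unexplained variance (a symbol we introduce for completeness, not one used above) and higher-order terms dropped, this decomposition is roughly:

\[
V_P \approx V_G + V_{G \times G} + V_{G \times E} + V_e.
\]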
There is also still plenty of headroom for improving the model:
- No hyper-parameter search was attempted for the transformer; depth, head count, LR schedules, dropout rates, and layer widths all remain untouched.
- How far can the canonical model be pushed? Maybe the transformer is overkill.
- The model treats the loci as an unordered collection. It’s possible that adding chromosomal coordinates or some kind of other positional encoding could let the model sing.
- We haven’t experimented with more sophisticated feature selection methods for reducing the number of loci prior to model training, a task that may be particularly fruitful for improving the ability of models to capture pairwise and higher order epistasis.
An invitation to build further
All code, configs, and cached model checkpoints are available in this notebook’s repository. The Appendix documents the engineering decisions that should help orient you to the codebase.
Appendix: a codebase primer
This section elaborates on engineering details that will be of interest to those planning to reproduce, modify, or build off of this research.
Rather than building off the Rijal et al. (2025) notebook files, we re-implemented their code into our own codebase to improve code quality and make room for our modifications and experimentation. Here is a very high-level summary of the changes we made:
- Added configuration dataclasses to co-localize tunables
- Saved the training/validation/test datasets to file to avoid splitting on-the-fly
- Created PyTorch dataloaders to manage data access, batching, and shuffling
- Automated the training loop with PyTorch Lightning, creating separation between the training loop and the model logic
- Added canonical learning parameters like early stopping, learning rate scheduling, gradient norm clipping, weight decay, and more
- Generalized the Rijal et al. (2025) model with toggleable skip connections, layer normalization, scaled attention, dropout rate, and more
The upshot is that we’re proud of this code and think it establishes a much-needed foundation that can be used to build off the research seeded by Rijal et al. (2025).
Training models
In the analysis above, we illustrated how multiple training jobs can be run using the high-level entrypoint, `run_trainings`, which:
- Trains a model for a given set of phenotypes
- Determines the best model, defined as the model with the lowest loss (MSE) calculated over the validation dataset
- Reports the \(R^2\) for the test dataset using the best model
- Saves the model to file for downstream use
The codebase also exposes equivalent behavior through a command-line interface (CLI). It can be accessed via:
$ python -m analysis.train --help
Distributed computing
We performed the analysis using Modal’s cloud infrastructure to distribute computations across GPUs, allowing us to rapidly measure performance across many different model architectures and training specifications. Whether you train with or without Modal is toggled by the attribute `train_config.use_modal`. By default, Modal execution is disabled; the downside is that your training jobs will run serially rather than being distributed across different machines.
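For example, starting from any of the train configs built above, you can enable Modal execution by evolving the config (a minimal illustration; `train_config` stands in for whichever config you’re about to run):

```python
import attrs

# Send this training job to Modal instead of running it locally.
train_config = attrs.evolve(train_config, use_modal=True)
```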
Caching behavior
We implemented a simple cache mechanism that avoids training if a model directory for a given training config already exists. We did this so that GPUs aren’t a requirement for engaging with this research.
- `train_config.use_cache = True` (default): Skips retraining if a model with the same configuration already exists
- `train_config.use_cache = False`: Forces retraining regardless of existing models
All training runs in this analysis use the default caching mode (`train_config.use_cache = True`), and the results are git-tracked. As a result, if you execute this notebook locally, these models will be loaded from cache rather than retrained.
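If you’d rather retrain a model from scratch, evolve the config before submitting the job (again, `model_config` and `train_config` stand in for the configs built earlier in the notebook):

```python
# Force retraining even if a cached model directory already exists.
fresh_config = attrs.evolve(train_config, use_cache=False)
model_dirs = run_trainings([(model_config, fresh_config)])
```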