Spike Analysis

This notebook contains the entire analysis seen in our manuscript Jointly modeling deep mutational scans identifies shifted mutational effects among SARS-CoV-2 spike homologs. Some cells' input, output, or both are hidden for brevity; you may toggle the contents using the drop-down bars.

See the repository README for more instructions on how to run this notebook.

Computational platform environment

This section records the attributes of the machine that ran this notebook and imports the necessary dependencies.

Operating system

Hide content
! grep -E '^(VERSION|NAME)=' /etc/os-release
NAME="Ubuntu"
VERSION="18.04.6 LTS (Bionic Beaver)"

Hardware (Processors and RAM)

Hide content
! lshw -class memory -class processor
WARNING: you should run this program as super-user.
  *-memory                  
       description: System memory
       physical id: 0
       size: 996GiB
  *-cpu
       product: AMD EPYC 75F3 32-Core Processor
       vendor: Advanced Micro Devices [AMD]
       physical id: 1
       bus info: cpu@0
       size: 1499MHz
       capacity: 2950MHz
       width: 64 bits
       capabilities: fpu fpu_exception wp vme de pse tsc msr pae mce cx8 apic sep mtrr pge mca cmov pat pse36 clflush mmx fxsr sse sse2 ht syscall nx mmxext fxsr_opt pdpe1gb rdtscp x86-64 constant_tsc rep_good nopl xtopology nonstop_tsc cpuid extd_apicid aperfmperf pni pclmulqdq monitor ssse3 fma cx16 pcid sse4_1 sse4_2 movbe popcnt aes xsave avx f16c rdrand lahf_lm cmp_legacy svm extapic cr8_legacy abm sse4a misalignsse 3dnowprefetch osvw ibs skinit wdt tce topoext perfctr_core perfctr_nb bpext perfctr_llc mwaitx cpb cat_l3 cdp_l3 invpcid_single hw_pstate ssbd ibrs ibpb stibp vmmcall fsgsbase bmi1 avx2 smep bmi2 invpcid cqm rdt_a rdseed adx smap clflushopt clwb sha_ni xsaveopt xsavec xgetbv1 xsaves cqm_llc cqm_occup_llc cqm_mbm_total cqm_mbm_local clzero irperf xsaveerptr arat npt lbrv svm_lock nrip_save tsc_scale vmcb_clean flushbyasid decodeassists pausefilter pfthreshold v_vmsave_vmload vgif umip pku ospke vaes vpclmulqdq rdpid overflow_recov succor smca cpufreq
WARNING: output may be incomplete or inaccurate, you should run this program as super-user.

GPUs

Hide content
%env CUDA_VISIBLE_DEVICES=0
! nvidia-smi -L
env: CUDA_VISIBLE_DEVICES=0
GPU 0: NVIDIA A100 80GB PCIe (UUID: GPU-414cb1bd-372a-4926-b140-b734687c927f)
GPU 1: NVIDIA A100 80GB PCIe (UUID: GPU-e54c2054-5be3-ebd0-e22e-b98441ec664f)
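Only the first GPU is exposed to this session via CUDA_VISIBLE_DEVICES. As a quick sanity check (a sketch, assuming the jax import in the cell below), you can confirm which devices JAX will actually use:

import jax
# with CUDA_VISIBLE_DEVICES=0 we expect a single GPU entry here
print(jax.devices())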
Hide content
# built-in libraries
import os
import sys
from itertools import combinations
from collections import defaultdict
import time
import pprint
import copy
import pickle
from functools import reduce

# external dependencies
import pandas as pd
import seaborn as sns
from scipy.stats import pearsonr
from matplotlib.lines import Line2D
import matplotlib.lines as mlines
import matplotlib.pyplot as plt
from matplotlib.transforms import (
    Bbox, TransformedBbox, blended_transform_factory)
from mpl_toolkits.axes_grid1.inset_locator import (
    BboxPatch, BboxConnector, BboxConnectorPatch)
import matplotlib.patches as patches
import matplotlib.colors as colors
import numpy as np
import scipy
from tqdm.notebook import tqdm
import jax
import jax.numpy as jnp
import shutil
from Bio.PDB.PDBParser import PDBParser
from Bio.PDB.PDBList import PDBList
from Bio.PDB.DSSP import DSSP
from Bio import SeqIO
import multidms
%matplotlib inline

This notebook was run with the following multidms version.

multidms.__version__
'0.2.0'
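To reproduce this environment, pinning the same release should suffice (a sketch, assuming a pip-based install of the published package):

! pip install multidms==0.2.0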

Set some global configurations for plotting.

Hide content
OUTDIR = "results/multidms02"
os.makedirs(OUTDIR, exist_ok=True)

rc_kwargs = {
    'legend.frameon': False,
    "font.size" : 11,
    "font.weight" : "normal"
}

plt.rcParams.update(**rc_kwargs)

Input Data

Load and organize the functional score DMS data.

We begin with 16 individual sets of barcoded variants and their associated pre-computed functional scores. Each set derives from a single DMS experiment using one of Delta, Omicron BA.1, or Omicron BA.2 as the experimental wildtype. First, we parse the filenames to tie experimental attributes to the individual datasets as nested pd.DataFrames.

Hide source
func_score_data = pd.DataFrame()

for homolog in ["Delta", "Omicron_BA1", "Omicron_BA2"]:
    
    # functional scores
    func_sel = (
        pd.read_csv(f"data/{homolog}/functional_selections.csv")
        .assign(
            filename = lambda x: f"data/{homolog}/" + 
            x.library + "_" + 
            x.preselection_sample + 
            "_vs_" + x.postselection_sample + 
            "_func_scores.csv"
        )
        .assign(
            func_sel_scores_df = lambda x: x.filename.apply(
                lambda f: pd.read_csv(f)
            )   
        )
        .assign(
            len_func_sel_scores_df = lambda x: x.func_sel_scores_df.apply(
                lambda x: len(x)
            )
        )
        .assign(homolog = homolog)
    )
    func_score_data = pd.concat([func_score_data, func_sel]).reset_index(drop=True)

# Add a column that gives a unique ID to each homolog/DMS experiment
func_score_data['condition'] = func_score_data.apply(
    lambda row: f"{row['homolog']}-{row['library']}".replace('-Lib',''),
    axis=1
)
func_score_data[['library', 'replicate', 'filename', 'condition']]
library replicate filename condition
0 Lib-1 1 data/Delta/Lib-1_2021-10-28_thaw-1_VSVG_contro... Delta-1
1 Lib-1 2 data/Delta/Lib-1_2021-10-28_thaw-1_VSVG_contro... Delta-1
2 Lib-3 1 data/Delta/Lib-3_2021-10-28_thaw-1_VSVG_contro... Delta-3
3 Lib-3 2 data/Delta/Lib-3_2021-10-28_thaw-1_VSVG_contro... Delta-3
4 Lib-4 1 data/Delta/Lib-4_2021-10-28_thaw-1_VSVG_contro... Delta-4
5 Lib-4 2 data/Delta/Lib-4_2021-10-28_thaw-1_VSVG_contro... Delta-4
6 Lib-2 1 data/Delta/Lib-2_2021-10-28_thaw-1_VSVG_contro... Delta-2
7 Lib-2 2 data/Delta/Lib-2_2021-10-28_thaw-1_VSVG_contro... Delta-2
8 Lib-1 1 data/Omicron_BA1/Lib-1_2022-03-25_thaw-1_VSVG_... Omicron_BA1-1
9 Lib-1 2 data/Omicron_BA1/Lib-1_2022-03-25_thaw-1_VSVG_... Omicron_BA1-1
10 Lib-2 1 data/Omicron_BA1/Lib-2_2022-06-22_thaw-1_VSVG_... Omicron_BA1-2
11 Lib-3 1 data/Omicron_BA1/Lib-3_2022-06-22_thaw-1_VSVG_... Omicron_BA1-3
12 Lib-1 1 data/Omicron_BA2/Lib-1_2022-10-22_thaw-1_VSVG_... Omicron_BA2-1
13 Lib-2 1 data/Omicron_BA2/Lib-2_2022-10-22_thaw-1_VSVG_... Omicron_BA2-2
14 Lib-1 2 data/Omicron_BA2/Lib-1_2022-10-22_thaw-2_VSVG_... Omicron_BA2-1
15 Lib-2 2 data/Omicron_BA2/Lib-2_2022-10-22_thaw-2_VSVG_... Omicron_BA2-2
Hide source
avail_cond_str = '\n- '.join(list(func_score_data.condition.unique()))
print(f"Available conditions for fitting are:\n- {avail_cond_str}")
Available conditions for fitting are:
- Delta-1
- Delta-3
- Delta-4
- Delta-2
- Omicron_BA1-1
- Omicron_BA1-2
- Omicron_BA1-3
- Omicron_BA2-1
- Omicron_BA2-2

Concatenate each of the individual experiments, keeping track of the library and homolog of each. Output notable features for a random sample of 10 variants.

Hide source
func_score_df = pd.DataFrame()
for idx, row in tqdm(func_score_data.iterrows(), total=len(func_score_data)):
    df = row.func_sel_scores_df.assign(
        homolog=row.homolog,
        library = row.library,
        replicate = row.replicate,
        condition=row.condition
    )
    func_score_df = pd.concat([func_score_df, df])

# rename, sort index, and fill na (wildtype values) with empty strings
func_score_df = (func_score_df
    .rename(
        {"aa_substitutions_reference":"aa_substitutions"}, 
        axis=1
    )
    .reset_index(drop=True)
    .fillna("")
    .sort_values(by="condition")
)
func_score_df[["library", "barcode", "aa_substitutions", "func_score", "condition"]].sample(10, random_state=0)
library barcode aa_substitutions func_score condition
721925 Lib-3 CGTAAAGTTCCAACAA G769R D950F R1107M N1192S -2.1765 Omicron_BA1-3
239549 Lib-4 AATAATTTTCCTACAC -2.0202 Delta-4
816259 Lib-1 GATGATACCAAACTAT K814T L1024I E1207K -2.1526 Omicron_BA2-1
612980 Lib-2 AAATATCCTACAAGAA C738Y A890T H1058Y -9.0995 Omicron_BA1-2
368871 Lib-1 TAATACCGAATCCCCC A893V S939D A1078T -4.1550 Omicron_BA1-1
1115330 Lib-2 GTATACATGTATGATG S71L D1163E S1242N 0.2762 Omicron_BA2-2
410949 Lib-1 GCATTACTACAAATAA N960K 0.6777 Omicron_BA1-1
971589 Lib-1 CAATATAGCATAGAGG R78L 0.1378 Omicron_BA2-1
592643 Lib-2 ACAAGCTTTGCAACAA Y200H 1.3313 Omicron_BA1-2
381265 Lib-1 CTAGTCTCCGACAAAA F347S D627G I850L -6.8325 Omicron_BA1-1

Discard all variants with a pre-selection count below 100.

Hide source
n_pre_threshold = len(func_score_df)
func_score_df.query("pre_count >= 100", inplace=True)
print(f"Of {n_pre_threshold} variants, {n_pre_threshold - len(func_score_df)} had fewer than the threshold of counts before selection, and were filtered out")
Of 1135096 variants, 120164 had fewer than the threshold of counts before selection, and were filtered out

We only require the functional score, amino-acid substitutions, and condition columns for instantiating the multidms.Data object; drop the rest.

Hide source
required_cols = ['func_score', 'aa_substitutions', 'condition']
func_score_df.drop([c for c in func_score_df if c not in required_cols], axis=1, inplace=True)

Remove all variants with string-suffixed sites (indels) and stop codon wildtypes.

Hide source
stop_wt_vars = []
non_numeric_sites = []
for idx, row in tqdm(func_score_df.iterrows(), total=len(func_score_df)):
    for sub in row["aa_substitutions"].split():
        if sub[0] == "*":
            stop_wt_vars.append(idx)
        if not sub[-2].isnumeric():
            non_numeric_sites.append(idx)

to_drop = set.union(set(stop_wt_vars), set(non_numeric_sites))
func_score_df.drop(to_drop, inplace=True)
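The same filters can be expressed without an explicit Python loop. A vectorized sketch using pandas string methods, assuming substitutions are whitespace-delimited tokens like 'A123B' and letter-suffixed (indel) sites look like 'T214aA':

subs = func_score_df["aa_substitutions"]
stop_wt = subs.str.contains(r"(?:^|\s)\*")                  # wildtype amino acid is a stop codon
suffixed = subs.str.contains(r"\d[a-z]+[A-Z*](?:\s|$)")     # letter-suffixed (indel) site
func_score_df = func_score_df[~(stop_wt | suffixed)]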

We clip all functional scores at a lower bound of -3.5, and an upper bound of 2.5.

n_below_clip = len(func_score_df.query("func_score < -3.5"))
n_above_clip = len(func_score_df.query("func_score > 2.5"))
print(f"There are {n_below_clip} variants below the clip threshold, and {n_above_clip} above.")
func_score_df['func_score'] = func_score_df['func_score'].clip(-3.5, 2.5)
There are 143177 variants below the clip threshold, and 345 above.

Choose two representative biological replicates for each of the three homologs.

experiment_conditions = ["Delta", "Omicron_BA1", "Omicron_BA2"]
replicate_1_experiments = ["Delta-2", "Omicron_BA1-2", "Omicron_BA2-1"]
replicate_2_experiments = ["Delta-4", "Omicron_BA1-3", "Omicron_BA2-2"]

Organize the two replicates and annotate replicates “1” and “2”. These each represent a distinct training set such that we may train replicate models and compare their results. Output a random sample of 10 variants.

Hide source
func_score_df = pd.concat(
    [
        (
            func_score_df
            .query("condition in @replicate_1_experiments")
            .replace(dict(zip(replicate_1_experiments, experiment_conditions)))
            .assign(replicate=1)
        ),
        (
            func_score_df
            .query("condition in @replicate_2_experiments")
            .replace(dict(zip(replicate_2_experiments, experiment_conditions)))
            .assign(replicate=2)
        )
    ]
)
func_score_df = func_score_df.assign(
    n_subs = [
        len(aa_subs.split()) 
        for aa_subs in func_score_df.aa_substitutions
    ]
)
func_score_df.sample(10)
func_score aa_substitutions condition replicate n_subs
1071040 0.0505 I105T N1023T D1127E Omicron_BA2 2 3
880651 -1.0075 D1084E L1200I Omicron_BA2 2 2
901472 -0.1839 S1037N Omicron_BA2 2 1
801311 -1.1914 T33I D936H Omicron_BA2 1 2
908179 -2.5949 D339P Y655A D1168V Omicron_BA2 2 3
172050 -2.7820 R19K T95* V159I R237K R452M A845V Delta 2 6
882804 0.6012 H954D Omicron_BA2 2 1
178845 -3.5000 L8M G446S S605N L1063V I1081F Delta 2 5
631205 -1.7408 S254G S939P Omicron_BA1 1 2
724691 0.1839 I95T Omicron_BA1 2 1
func_score_df.to_csv("results/training_functional_scores.csv", index=False)

Variant barcode and mutation background stats

In this section we briefly query and visualize characteristics of the replicate training sets.

Get the expected number of substitutions per variant for each condition replicate.

Hide source
for group, group_df in func_score_df.groupby(["condition", "replicate"]):
    print(f"{group[0]} - rep {group[1]} has {round(group_df.n_subs.mean(), 5)} subs per variant, on average")
Delta - rep 1 has 2.18671 subs per variant, on average
Delta - rep 2 has 2.29472 subs per variant, on average
Omicron_BA1 - rep 1 has 1.802 subs per variant, on average
Omicron_BA1 - rep 2 has 1.75802 subs per variant, on average
Omicron_BA2 - rep 1 has 2.31117 subs per variant, on average
Omicron_BA2 - rep 2 has 2.32827 subs per variant, on average

Get the number of unique variants (amino-acid substitution sets) seen in each condition replicate.

Hide source
for group, group_df in func_score_df.groupby(["condition", "replicate"]):
    print(f"{group[0]} - rep {group[1]} has {len(group_df.aa_substitutions.unique())}")  
Delta - rep 1 has 28515
Delta - rep 2 has 29158
Omicron_BA1 - rep 1 has 70597
Omicron_BA1 - rep 2 has 62129
Omicron_BA2 - rep 1 has 60397
Omicron_BA2 - rep 2 has 57719

Visualize the distribution of barcodes per variant, as well as the distribution of unique backgrounds per mutation.

Hide source
saveas = f"raw_data_summary_barcodes_backgrounds_hist"
logscale=False
fig, ax = plt.subplots(2,3, sharex="row", sharey="row", figsize=[6.4, 5.5])

condition_title = {
    "Delta":"Delta",
    "Omicron_BA1" : "BA.1",
    "Omicron_BA2" : "BA.2"
}

row = 0
for col, (condition, condition_df) in enumerate(func_score_df.groupby("condition")):
    iter_ax = ax[row, col]
    
    df = condition_df.query("aa_substitutions != ''")
    df = df.assign(
        num_muts = [
            len(aa_subs.split())
            for aa_subs in df.aa_substitutions
        ]
    )
    
    sns.histplot(df.query("num_muts <= 10"), x="num_muts", ax=iter_ax, hue="replicate", discrete=True)
    for rep, rep_df in df.groupby("replicate"):
        mean = rep_df['num_muts'].mean()
        iter_ax.axvline(mean, linestyle=("-" if rep == 1 else "--"))
    
    if logscale: iter_ax.set_yscale('log')
    if col != 2: 
        iter_ax.get_legend().remove()
    n_rep1 = len(df.query("replicate == 1"))//1000
    n_rep2 = len(df.query("replicate == 2"))//1000
    iter_ax.text(
        0.1, 1.1, 
        f"$N={n_rep1}K, {n_rep2}K$", 
        ha="left", va="top", 
        transform=iter_ax.transAxes
    )
    xscale = "number of amino-acid substitutions per variant" if col == 1 else ""
    iter_ax.set_xlabel(xscale)
    
    ylabel = f"variant counts" if col == 0 else ""
    iter_ax.set_ylabel(ylabel)
    iter_ax.set_xticks(
        [i+1 for i in range(10)],
        labels=[i+1 for i in range(10)], 
        ha="center",
        size=7,
        rotation=0
    )
    sns.despine(ax=iter_ax)
    iter_ax.set_title(condition_title[condition], y=1.15)

row = 1
collapsed_bc_df = func_score_df.groupby(
    ["replicate", "condition", "aa_substitutions"]
).aggregate("mean").reset_index()
for col, (condition, condition_df) in enumerate(collapsed_bc_df.groupby("condition")):
    iter_ax = ax[row, col]
    df = pd.DataFrame()
    for rep, rep_df in condition_df.groupby("replicate"):
        
        times_seen = (
            rep_df["aa_substitutions"].str.split().explode().value_counts()
        )
        if (times_seen == times_seen.astype(int)).all():
            times_seen = times_seen.astype(int)
        times_seen = pd.DataFrame(times_seen)
        times_seen.index.name = f"mutation"
        df = pd.concat([df, times_seen.assign(replicate=rep).reset_index()])

    sns.histplot(
        df.query("count <= 50"), 
        x="count", 
        ax=iter_ax, 
        element='step', 
        hue="replicate", 
        discrete=True
    )
    
    for rep, rep_df in df.groupby("replicate"):
        mean = rep_df['count'].mean()
        iter_ax.axvline(mean, linestyle=("-" if rep == 1 else "--"))
        
    iter_ax.get_legend().remove()
    n_rep1 = len(df.query("replicate == 1"))
    n_rep2 = len(df.query("replicate == 2"))
    iter_ax.text(
        0.1, 1.1, 
        f"$N={n_rep1}, {n_rep2}$", 
        ha="left", va="top", 
        transform=iter_ax.transAxes
    )
    
    xscale = "number of variant backgrounds \nfor a given amino-acid substitution" if col == 1 else ""
    iter_ax.set_xlabel(xscale)
    
    ylabel = f"mutation counts" if col == 0 else ""
    iter_ax.set_ylabel(ylabel)
    
    xticks = [i for i in range(0, 51) if i % 5 == 0]
    iter_ax.set_xticks(
        xticks,
        labels=xticks, 
        ha="center",
        size=7,
        rotation=0
    )
    
    sns.despine(ax=iter_ax)

plt.tight_layout()

ax[0,0].text(
    -0.1, 1.06, 
    f"A", 
    ha="right", va="bottom", 
    size=15,
    weight="bold",
    transform=ax[0,0].transAxes
)
ax[1,0].text(
    -0.1, 1.06, 
    f"B", 
    ha="right", va="bottom", 
    size=15,
    weight="bold",
    transform=ax[1,0].transAxes
)

fig.subplots_adjust(hspace=.6)
fig.savefig(f"{OUTDIR}/{saveas}.pdf")
fig.savefig(f"{OUTDIR}/{saveas}.png")
plt.show()
_images/3c60d7ae9db6ac7aa4012a22c369b3d91d57853a353ea7e31c08af73a32af3ca.png

Plot the correlation of variant functional scores (averaged across barcodes) between replicates in each condition, as well as the full distribution of functional scores.

Hide source
saveas = "replicate_functional_score_correlation_scatter"
pal = sns.color_palette('tab20')

fig, ax = plt.subplots(2,3, sharex="row", sharey="row", figsize=[6.4, 5.3])
collapsed_bc_df = func_score_df.groupby(
    ["replicate", "condition", "aa_substitutions"]
).aggregate("mean").reset_index()
collapsed_bc_df = collapsed_bc_df.assign(
    is_stop=[True if "*" in aasubs else False for aasubs in collapsed_bc_df.aa_substitutions]
)

is_stop_alpha_dict = {
    True : 0.5,
    False : 0.2
}

lim = [-3.8, 2.8]
ticks = np.linspace(-3, 2, 6)
for col, (condition, condition_df) in enumerate(collapsed_bc_df.groupby("condition")):
    
    row = 0
    iter_ax = ax[row, col]
    
    df = reduce(
        lambda left, right: pd.merge(
            left, right, left_index=True, right_index=True, how="inner"
        ),
        [
            rep_df.rename({"func_score":f"rep_{rep}_func_score"}, axis=1).set_index("aa_substitutions")
            for rep, rep_df in condition_df.groupby("replicate") 
        ],
    )
    
    df = df.assign(
        is_stop=[True if "*" in aasubs else False for aasubs in df.index.values]
    )
    df = df.assign(
        n_subs=[len(aasubs.split()) for aasubs in df.index.values]
    )
    
    alpha = [is_stop_alpha_dict[istp] for istp in df.is_stop]
    sns.scatterplot(
        df, 
        x="rep_1_func_score", 
        y="rep_2_func_score", 
        ax =iter_ax,
        alpha=alpha,
        hue="is_stop",
        hue_order=[False, True],
        palette=["darkgrey", "red"]
    )
    
    iter_ax.plot([-3.5, 2.5], [-3.5, 2.5], "--", lw=2, c="royalblue")
    
    iter_ax.set_ylim(lim)
    iter_ax.set_xlim(lim)
    if col == 0:
        iter_ax.set_yticks(ticks, labels=ticks)
    iter_ax.set_xticks(ticks, labels=ticks, rotation=45)
    
    corr = pearsonr(df["rep_1_func_score"], df["rep_2_func_score"])[0]
    iter_ax.annotate(
        f"$r = {corr:.2f}$", 
        (0.1, 0.9), 
        xycoords="axes fraction", 
        fontsize=12
    )
    iter_ax.set_title(condition)
    iter_ax.get_legend().remove()
    sns.despine(ax=iter_ax)
    
    row = 1
    iter_ax = ax[row, col]
    sns.violinplot(
        condition_df,
        x="is_stop",
        y="func_score",
        hue="replicate",
        split=True,
        palette=["0.5", "0.75"],
        ax=iter_ax
    )
    
    sns.despine(ax=iter_ax)
    if col != 2:
        iter_ax.get_legend().remove()
    else:
        iter_ax.legend(bbox_to_anchor = (1.25, 1.05), title="replicate")
    if col == 0:
        iter_ax.set_yticks(ticks, labels=ticks)

ax[0,0].set_xlabel("")
ax[0,0].set_ylabel("replicate 2 \n functional score")

ax[0,1].set_xlabel("replicate 1 functional score")
ax[0,1].set_title("BA.1")
ax[0,2].set_xlabel("")
ax[0,2].set_title("BA.2")

ax[1,0].set_xlabel("")
ax[1,0].set_ylabel("functional score")

ax[1,1].set_xlabel("variants contain stop codon mutations")
ax[1,2].set_xlabel("")
ax[1,2].set_ylabel("")
ax[1,1].set_ylabel("")

ax[0,0].text(
    -0.1, 1.06, 
    f"A", 
    ha="right", va="bottom", 
    size=15,
    weight="bold",
    transform=ax[0,0].transAxes
)
ax[1,0].text(
    -0.1, 1.06, 
    f"B", 
    ha="right", va="bottom", 
    size=15,
    weight="bold",
    transform=ax[1,0].transAxes
)

# fig.suptitle("Variant Functional Score \nReplicate Correlation")
plt.tight_layout()
fig.subplots_adjust(wspace=0.08, hspace = 0.5)
fig.savefig(f"{OUTDIR}/{saveas}.pdf")
fig.savefig(f"{OUTDIR}/{saveas}.png")
plt.show()
_images/5b601ebb6fac1d1e46fb1a88ee7f05bd42c45cb3bbfd8804c3d37ef988ac9183.png

Encode data for fitting

Next, we use the multidms.Data class to prep our data for fitting. To see the class docstring describing the required input and keyword arguments, toggle the output for the line below.

help(multidms.Data)
Hide output
Help on class Data in module multidms.data:

class Data(builtins.object)
 |  Data(variants_df: pandas.core.frame.DataFrame, reference: str, alphabet=('A', 'C', 'D', 'E', 'F', 'G', 'H', 'I', 'K', 'L', 'M', 'N', 'P', 'Q', 'R', 'S', 'T', 'V', 'W', 'Y'), collapse_identical_variants=False, condition_colors=('#0072B2', '#CC79A7', '#009E73', '#17BECF', '#BCDB22'), letter_suffixed_sites=False, assert_site_integrity=False, verbose=False, nb_workers=None)
 |  
 |  Prep and store one-hot encoding of
 |  variant substitutions data.
 |  Individual objects of this type can be shared
 |  by multiple :py:class:`multidms.Model` Objects
 |  for effeciently fitting various models to the same data.
 |  
 |  Note
 |  ----
 |  You can initialize a :class:`Data` object with a :class:`pandas.DataFrame`
 |  with a row for each variant sampled and annotations
 |  provided in the required columns:
 |  
 |  1. `condition` - Experimental condition from
 |      which a sample measurement was obtained.
 |  2. `aa_substitutions` - Defines each variant
 |      :math:`v` as a string of substitutions (e.g., ``'M3A K5G'``).
 |      Note that while conditions may have differing wild types
 |      at a given site, the sites between conditions should reference
 |      the same site when alignment is performed between
 |      condition wild types.
 |  3. `func_score` - The functional score computed from experimental
 |      measurements.
 |  
 |  Parameters
 |  ----------
 |  variants_df : :class:`pandas.DataFrame` or None
 |      The variant level information from all experiments you
 |      wish to analyze. Should have columns named ``'condition'``,
 |      ``'aa_substitutions'``, and ``'func_score'``.
 |      See the class note for descriptions of each of the features.
 |  reference : str
 |      Name of the condition which annotates the reference.
 |      variants. Note that for model fitting this class will convert all
 |      amino acid substitutions for non-reference condition groups
 |      to relative to the reference condition.
 |      For example, if the wild type amino acid at site 30 is an
 |      A in the reference condition, and a G in a non-reference condition,
 |      then a Y30G mutation in the non-reference condition is recorded as an A30G
 |      mutation relative to the reference. This way, each condition informs
 |      the exact same parameters, even at sites that differ in wild type amino acid.
 |      These are encoded in a :class:`binarymap.binarymap.BinaryMap` object for each
 |      condtion,
 |      where all sites that are non-identical to the reference are 1's.
 |      For motivation, see the `Model overview` section in :class:`multidms.Model`
 |      class notes.
 |  collapse_identical_variants : {'mean', 'median', False}
 |      If identical variants in ``variants_df`` (same 'aa_substitutions'),
 |      exist within individual condition groups,
 |      collapse them by taking mean or median of 'func_score', or
 |      (if `False`) do not collapse at all. Collapsing will make fitting faster,
 |      but *not* a good idea if you are doing bootstrapping.
 |  assert_site_integrity : bool
 |      If True, will assert that all sites in the data frame
 |      have the same wild type amino acid, grouped by condition.
 |  alphabet : array-like
 |      Allowed characters in mutation strings.
 |  condition_colors : array-like or dict
 |      Maps each condition to the color used for plotting. Either a dict keyed
 |      by each condition, or an array of colors that are sequentially assigned
 |      to the conditions.
 |  letter_suffixed_sites: bool
 |      True if sites are sequential and integer, False otherwise.
 |  
 |  Example
 |  -------
 |  Simple example with two conditions (``'a'`` and ``'b'``)
 |  
 |  >>> import pandas as pd
 |  >>> import multidms
 |  >>> func_score_data = {
 |  ...     'condition' : ["a","a","a","a", "b","b","b","b","b","b"],
 |  ...     'aa_substitutions' : [
 |  ...         'M1E', 'G3R', 'G3P', 'M1W', 'M1E',
 |  ...         'P3R', 'P3G', 'M1E P3G', 'M1E P3R', 'P2T'
 |  ...     ],
 |  ...     'func_score' : [2, -7, -0.5, 2.3, 1, -5, 0.4, 2.7, -2.7, 0.3],
 |  ... }
 |  >>> func_score_df = pd.DataFrame(func_score_data)
 |  >>> func_score_df  # doctest: +NORMALIZE_WHITESPACE
 |  condition aa_substitutions  func_score
 |  0         a              M1E         2.0
 |  1         a              G3R        -7.0
 |  2         a              G3P        -0.5
 |  3         a              M1W         2.3
 |  4         b              M1E         1.0
 |  5         b              P3R        -5.0
 |  6         b              P3G         0.4
 |  7         b          M1E P3G         2.7
 |  8         b          M1E P3R        -2.7
 |  9         b              P2T         0.3
 |  
 |  Instantiate a ``Data`` Object allowing for stop codon variants
 |  and declaring condition `"a"` as the reference condition.
 |  
 |  >>> data = multidms.Data(
 |  ...     func_score_df,
 |  ...     alphabet = multidms.AAS_WITHSTOP,
 |  ...     reference = "a",
 |  ... )  # doctest: +ELLIPSIS
 |  INFO: Pandarallel will run on ... workers.
 |  ...
 |  
 |  Note this may take some time due to the string
 |  operations that must be performed when converting
 |  amino acid substitutions to be with respect to a
 |  reference wild type sequence.
 |  
 |  After the object has finished being instantiated,
 |  we now have access to a few 'static' properties
 |  of our data. See individual property docstrings
 |  for more information.
 |  
 |  >>> data.reference
 |  'a'
 |  
 |  >>> data.conditions
 |  ('a', 'b')
 |  
 |  >>> data.mutations
 |  ('M1E', 'M1W', 'G3P', 'G3R')
 |  
 |  >>> data.site_map  # doctest: +NORMALIZE_WHITESPACE
 |  a  b
 |  1  M  M
 |  3  G  P
 |  
 |  >>> data.mutations_df  # doctest: +NORMALIZE_WHITESPACE
 |    mutation wts  sites muts  times_seen_a  times_seen_b
 |  0      M1E   M      1    E             1           3.0
 |  1      M1W   M      1    W             1           0.0
 |  2      G3P   G      3    P             1           1.0
 |  3      G3R   G      3    R             1           2.0
 |  
 |  >>> data.variants_df  # doctest: +NORMALIZE_WHITESPACE
 |    condition aa_substitutions  func_score var_wrt_ref
 |  0         a              M1E         2.0         M1E
 |  1         a              G3R        -7.0         G3R
 |  2         a              G3P        -0.5         G3P
 |  3         a              M1W         2.3         M1W
 |  4         b              M1E         1.0     G3P M1E
 |  5         b              P3R        -5.0         G3R
 |  6         b              P3G         0.4
 |  7         b          M1E P3G         2.7         M1E
 |  8         b          M1E P3R        -2.7     G3R M1E
 |  
 |  Methods defined here:
 |  
 |  __init__(self, variants_df: pandas.core.frame.DataFrame, reference: str, alphabet=('A', 'C', 'D', 'E', 'F', 'G', 'H', 'I', 'K', 'L', 'M', 'N', 'P', 'Q', 'R', 'S', 'T', 'V', 'W', 'Y'), collapse_identical_variants=False, condition_colors=('#0072B2', '#CC79A7', '#009E73', '#17BECF', '#BCDB22'), letter_suffixed_sites=False, assert_site_integrity=False, verbose=False, nb_workers=None)
 |      See main class docstring.
 |  
 |  convert_subs_wrt_ref_seq(self, condition, aa_subs)
 |      Covert amino acid substitutions to be with respect to the reference sequence.
 |      
 |      Parameters
 |      ----------
 |      condition : str
 |          The condition from which aa substitutions are relative to.
 |      aa_subs : str
 |          A string of amino acid substitutions, relative to the condition sequence,
 |          to converted
 |      
 |      Returns
 |      -------
 |      str
 |          A string of amino acid substitutions relative to the reference sequence.
 |  
 |  plot_func_score_boxplot(self, saveas=None, show=True, **kwargs)
 |      Plot a boxplot of the functional scores for each condition.
 |  
 |  plot_times_seen_hist(self, saveas=None, show=True, **kwargs)
 |      Plot a histogram of the number of times each mutation was seen.
 |  
 |  ----------------------------------------------------------------------
 |  Readonly properties defined here:
 |  
 |  binarymaps
 |      A dictionary keyed by condition names with values
 |      being a ``BinaryMap`` object for each condition.
 |  
 |  conditions
 |      A tuple of all conditions.
 |  
 |  mut_parser
 |      The mutation parser used to parse mutations.
 |  
 |  mutations
 |      A tuple of all mutations in the order reletive to their index into
 |      the binarymap.
 |  
 |  mutations_df
 |      A dataframe summarizing all single mutations
 |  
 |  non_identical_mutations
 |      A dictionary keyed by condition names with values
 |      being a string of all mutations that differ from the
 |      reference sequence.
 |  
 |  non_identical_sites
 |      A dictionary keyed by condition names with values
 |      being a :class:`pandas.DataFrame` indexed by site,
 |      with columns for the reference
 |      and non-reference amino acid at each site that differs.
 |  
 |  reference
 |      The name of the reference condition.
 |  
 |  reference_sequence_conditions
 |      A list of conditions that have the same wild type
 |      sequence as the reference condition.
 |  
 |  site_map
 |      A dataframe indexed by site, with columns
 |      for all conditions giving the wild type amino acid
 |      at each site.
 |  
 |  split_subs
 |      A function that splits amino acid substitutions into wt, site, and mut
 |      using the mutation parser.
 |  
 |  targets
 |      The functional scores for each variant in the training data.
 |  
 |  training_data
 |      A dictionary with keys 'X' and 'y' for the training data.
 |  
 |  variants_df
 |      A dataframe summarizing all variants in the training data.
 |  
 |  ----------------------------------------------------------------------
 |  Data descriptors defined here:
 |  
 |  __dict__
 |      dictionary for instance variables (if defined)
 |  
 |  __weakref__
 |      list of weak references to the object (if defined)
 |  
 |  ----------------------------------------------------------------------
 |  Data and other attributes defined here:
 |  
 |  __slotnames__ = []

Instantiate an object for each of our two replicate training sets, and append them to a list.

datasets = []
for replicate, fsdf in func_score_df.groupby("replicate"):

    start = time.time()

    # instantiate data object
    data = multidms.Data(
        variants_df=fsdf,
        collapse_identical_variants="mean",       # average variant func scores across barcode replicates
        alphabet=multidms.AAS_WITHSTOP_WITHGAP,   # amino acids, plus stop codons and gaps
        reference="Omicron_BA1",
        assert_site_integrity=False,
        verbose=True,
        nb_workers=8
    )

    end = time.time()
    prep_time = round(end-start)
    print(f"Finished, time: {prep_time}")

    datasets.append(data)
Hide output
inferring site map for Delta
inferring site map for Omicron_BA1
inferring site map for Omicron_BA2
INFO: Pandarallel will run on 8 workers.
INFO: Pandarallel will use Memory file system to transfer data between the main process and workers.
unknown cond wildtype at sites: [144, 143, 69, 145, 70, 211, 25, 26, 24, 157, 158, 898],
dropping: 10983 variantswhich have mutations at those sites.
invalid non-identical-sites: [371], dropping 2041 variants
Converting mutations for Delta
Converting mutations for Omicron_BA1
is reference, skipping
Converting mutations for Omicron_BA2
Finished, time: 49
inferring site map for Delta
inferring site map for Omicron_BA1
inferring site map for Omicron_BA2
INFO: Pandarallel will run on 8 workers.
INFO: Pandarallel will use Memory file system to transfer data between the main process and workers.
unknown cond wildtype at sites: [145, 70, 144, 143, 69, 211, 422, 26, 24, 25, 157, 158],
dropping: 10129 variantswhich have mutations at those sites.
invalid non-identical-sites: [371], dropping 1873 variants
Converting mutations for Delta
Converting mutations for Omicron_BA1
is reference, skipping
Converting mutations for Omicron_BA2
Finished, time: 44
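Before fitting, a few of the documented Data properties are handy for sanity-checking the encoding (a sketch using properties described in the class docstring above):

data_rep1 = datasets[0]
print(data_rep1.reference)               # 'Omicron_BA1'
print(data_rep1.conditions)              # all experimental conditions
print(data_rep1.site_map.head())         # wildtype amino acid per site, per condition
print(data_rep1.mutations_df.head())     # per-mutation summary, including times seen per condition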

Change the datasets' colors, the only mutable attribute of Data objects.

Hide content
# condition colors must be hex-encoded,
palette = [
    '#F97306',
    '#BFBFBF',
    '#9400D3'
]

conditions = [
    "Delta",
    "Omicron_BA1",
    "Omicron_BA2"
]

cc = {con:col for con, col in zip(conditions, palette)}
for data in datasets:
    data.condition_colors = cc

Fit Models (Shrinkage Analysis)

For each replicate training set, initialize and fit a set of models with varying lasso strength coefficients. Here, we use the multidms.utils.fit_wrapper function to handle model instantiation and parameter fitting. Expand the cell output below to view the function docstring.

help(multidms.utils.fit_wrapper)
Hide output
Help on function fit_wrapper in module multidms.utils:

fit_wrapper(dataset, huber_scale_huber=1, scale_coeff_lasso_shift=2e-05, scale_coeff_ridge_beta=0, scale_coeff_ridge_shift=0, scale_coeff_ridge_gamma=0, scale_coeff_ridge_ch=0, data_idx=0, epistatic_model='Identity', output_activation='Identity', lock_beta=False, lock_beta_naught=None, gamma_corrected=True, alpha_d=False, init_beta_naught=0.0, warmup_beta=False, tol=0.001, num_training_steps=10, iterations_per_step=2000, save_model_at=[2000, 10000, 20000], PRNGKey=0)
    Fit a multidms model to a dataset. This is a wrapper around the multidms
    fit method that allows for easy specification of the fit parameters.
    This method is helpful for comparing and organizing multiple fits.
    
    Parameters
    ----------
    dataset : :class:`multidms.Data`
        The dataset to fit to.
    huber_scale_huber : float, optional
        The scale of the huber loss function. The default is 1.
    scale_coeff_lasso_shift : float, optional
        The scale of the lasso penalty on the shift parameter. The default is 2e-5.
    scale_coeff_ridge_beta : float, optional
        The scale of the ridge penalty on the beta parameter. The default is 0.
    scale_coeff_ridge_shift : float, optional
        The scale of the ridge penalty on the shift parameter. The default is 0.
    scale_coeff_ridge_gamma : float, optional
        The scale of the ridge penalty on the gamma parameter. The default is 0.
    scale_coeff_ridge_ch : float, optional
        The scale of the ridge penalty on the ch parameter. The default is 0.
    data_idx : int, optional
        The index of the data to fit to. The default is 0.
    epistatic_model : str, optional
        The epistatic model to use. The default is "Identity".
    output_activation : str, optional
        The output activation function to use. The default is "Identity".
    lock_beta : bool, optional
        Whether to lock the beta parameter. The default is False.
    lock_beta_naught : float or None optional
        The value to lock the beta_naught parameter to. If None,
        the beta_naught parameter is free to vary. The default is None.
    gamma_corrected : bool, optional
        Whether to use the gamma corrected model. The default is True.
    alpha_d : bool, optional
        Whether to use the conditional c model. The default is False.
    init_beta_naught : float, optional
        The initial value of the beta_naught parameter. The default is 0.0.
    warmup_beta : bool, optional
        Whether to warmup the model by fitting beta parameters to the
        reference dataset before fitting the full model. The default is False.
    tol : float, optional
        The tolerance for the fit. The default is 1e-3.
    num_training_steps : int, optional
        The number of training steps to perform. The default is 10.
    iterations_per_step : int, optional
        The number of iterations to perform per training step. The default is 2000.
    save_model_at : list, optional
        The iterations at which to save the model. The default is [2000, 10000, 20000].
    PRNGKey : int, optional
        The PRNGKey to use to initialize model parameters. The default is 0.
    
    Returns
    -------
    fit_series : :class:`pandas.Series`
        A series containing the fit attributes and pickled model objects
        at the specified save_model_at steps.
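For orientation, a single fit looks like the following (a sketch using the parameters documented above, shortened for illustration; the full sweep below is what was actually run):

single_fit = multidms.utils.fit_wrapper(
    datasets[0],
    epistatic_model="Sigmoid",
    scale_coeff_lasso_shift=2e-5,
    num_training_steps=5,        # far fewer steps than the full analysis
    iterations_per_step=1000,
    save_model_at=[5000],
)
# single_fit is a pandas.Series of fit attributes, including pickled
# model objects at the steps listed in save_model_at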
# all models created will be referenced through this dataframe, tying each model to its respective hyperparameters
models = pd.DataFrame()

# the lasso strength coefficients we will test
shrinkage_sweep = [0.0, 1e-06, 1e-05, 2e-05, 5e-05, 0.0001, 0.0005, 0.001]

# choose fitting hyper-parameters
fit_params = {
    'init_beta_naught' : 0.0,
    'epistatic_model' : "Sigmoid",
    'output_activation' : "Identity",
    'warmup_beta':False,
    'gamma_corrected':False,
    'alpha_d': True,
    'scale_coeff_lasso_shift':None,  # placeholder, updated for each fit in the loop below
    'scale_coeff_ridge_beta':0,
    'scale_coeff_ridge_shift':0,
    'scale_coeff_ridge_gamma':1e-3,
    'scale_coeff_ridge_ch':1e-3,
    'tol':1e-4,
    'save_model_at':[30000],
    'num_training_steps': 30,
    'iterations_per_step':1000
}
fit_iter = 0
for lasso in shrinkage_sweep:
    for replicate, dataset in enumerate(datasets):

        # update lasso param
        fit_params['scale_coeff_lasso_shift'] = lasso
        
        start_t = time.time()

        # Create and fit model
        model = multidms.utils.fit_wrapper(dataset, **fit_params)
        model['replicate'] = replicate

        # append model and attributes to the final DataFrame
        models = pd.concat([models, model], ignore_index=True)
        end_t = time.time()
        fit_iter += 1
        print(f"Done with {fit_iter}/{len(shrinkage_sweep)*2}! Total time: {round(end_t - start_t)}")
        print("------------------------------------------")

    # each time we fit both replicates to a lasso, re-write the binary to include the latest models
    pickle.dump(models, open(f"{OUTDIR}/models.pkl", "wb"))
Hide output
running:
{'PRNGKey': 0,
 'alpha_d': True,
 'data_idx': 0,
 'dataset': <multidms.data.Data object at 0x7fa1c7254290>,
 'epistatic_model': 'Sigmoid',
 'gamma_corrected': False,
 'huber_scale_huber': 1,
 'init_beta_naught': 0.0,
 'iterations_per_step': 1000,
 'lock_beta': False,
 'lock_beta_naught': None,
 'num_training_steps': 30,
 'output_activation': 'Identity',
 'save_model_at': [30000],
 'scale_coeff_lasso_shift': 0.0,
 'scale_coeff_ridge_beta': 0,
 'scale_coeff_ridge_ch': 0.001,
 'scale_coeff_ridge_gamma': 0.001,
 'scale_coeff_ridge_shift': 0,
 'step_loss': array([0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0.,
       0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0.]),
 'tol': 0.0001,
 'warmup_beta': False}
training_step 0/30,Loss: 0.9395107515646628, Time: 23 Seconds
training_step 1/30,Loss: 0.7935431678910877, Time: 17 Seconds
training_step 2/30,Loss: 0.7402912563956504, Time: 17 Seconds
training_step 3/30,Loss: 0.7100134660575739, Time: 17 Seconds
training_step 4/30,Loss: 0.6924396753590121, Time: 17 Seconds
training_step 5/30,Loss: 0.6812149717155611, Time: 17 Seconds
training_step 6/30,Loss: 0.6730202561437256, Time: 17 Seconds
training_step 7/30,Loss: 0.6667970782257362, Time: 17 Seconds
training_step 8/30,Loss: 0.6616904748468551, Time: 17 Seconds
training_step 9/30,Loss: 0.6573043138020965, Time: 17 Seconds
training_step 10/30,Loss: 0.6535816971412826, Time: 17 Seconds
training_step 11/30,Loss: 0.6504526680356228, Time: 17 Seconds
training_step 12/30,Loss: 0.6477518124723911, Time: 17 Seconds
training_step 13/30,Loss: 0.6453274955655679, Time: 17 Seconds
training_step 14/30,Loss: 0.6431133403579621, Time: 17 Seconds
training_step 15/30,Loss: 0.6411322630837335, Time: 17 Seconds
training_step 16/30,Loss: 0.6393056306514439, Time: 17 Seconds
training_step 17/30,Loss: 0.6376465911930261, Time: 17 Seconds
training_step 18/30,Loss: 0.6360881408818575, Time: 17 Seconds
training_step 19/30,Loss: 0.6346343974063552, Time: 17 Seconds
training_step 20/30,Loss: 0.6332460323775633, Time: 17 Seconds
training_step 21/30,Loss: 0.6319324470504862, Time: 17 Seconds
training_step 22/30,Loss: 0.6306299843866991, Time: 17 Seconds
training_step 23/30,Loss: 0.6293943866691546, Time: 17 Seconds
training_step 24/30,Loss: 0.6282842235755766, Time: 17 Seconds
training_step 25/30,Loss: 0.6272695410741851, Time: 17 Seconds
training_step 26/30,Loss: 0.626328139571639, Time: 17 Seconds
training_step 27/30,Loss: 0.6254504537584734, Time: 17 Seconds
training_step 28/30,Loss: 0.6246252894843752, Time: 17 Seconds
training_step 29/30,Loss: 0.623845908403332, Time: 17 Seconds
Done with 1/10! Total time: 517
------------------------------------------
running:
{'PRNGKey': 0,
 'alpha_d': True,
 'data_idx': 0,
 'dataset': <multidms.data.Data object at 0x7fa1ca51cbd0>,
 'epistatic_model': 'Sigmoid',
 'gamma_corrected': False,
 'huber_scale_huber': 1,
 'init_beta_naught': 0.0,
 'iterations_per_step': 1000,
 'lock_beta': False,
 'lock_beta_naught': None,
 'num_training_steps': 30,
 'output_activation': 'Identity',
 'save_model_at': [30000],
 'scale_coeff_lasso_shift': 0.0,
 'scale_coeff_ridge_beta': 0,
 'scale_coeff_ridge_ch': 0.001,
 'scale_coeff_ridge_gamma': 0.001,
 'scale_coeff_ridge_shift': 0,
 'step_loss': array([0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0.,
       0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0.]),
 'tol': 0.0001,
 'warmup_beta': False}
training_step 0/30,Loss: 0.9531235360436872, Time: 22 Seconds
training_step 1/30,Loss: 0.7824949636183749, Time: 17 Seconds
training_step 2/30,Loss: 0.7267341420942524, Time: 17 Seconds
training_step 3/30,Loss: 0.6895850138520322, Time: 17 Seconds
training_step 4/30,Loss: 0.6735172462694564, Time: 17 Seconds
training_step 5/30,Loss: 0.6604496693741294, Time: 17 Seconds
training_step 6/30,Loss: 0.6518953769263008, Time: 17 Seconds
training_step 7/30,Loss: 0.6451382625449014, Time: 17 Seconds
training_step 8/30,Loss: 0.639534094515506, Time: 17 Seconds
training_step 9/30,Loss: 0.6347831383738758, Time: 17 Seconds
training_step 10/30,Loss: 0.6308211397433248, Time: 17 Seconds
training_step 11/30,Loss: 0.6274681383273852, Time: 17 Seconds
training_step 12/30,Loss: 0.6246522114974457, Time: 17 Seconds
training_step 13/30,Loss: 0.6221565619438494, Time: 17 Seconds
training_step 14/30,Loss: 0.6199115457783899, Time: 17 Seconds
training_step 15/30,Loss: 0.6179847373197274, Time: 17 Seconds
training_step 16/30,Loss: 0.6162569378271759, Time: 17 Seconds
training_step 17/30,Loss: 0.6146742408544656, Time: 17 Seconds
training_step 18/30,Loss: 0.6132061975274181, Time: 17 Seconds
training_step 19/30,Loss: 0.611838884413034, Time: 17 Seconds
training_step 20/30,Loss: 0.610564407164543, Time: 17 Seconds
training_step 21/30,Loss: 0.6093740895289907, Time: 17 Seconds
training_step 22/30,Loss: 0.6082581772816207, Time: 17 Seconds
training_step 23/30,Loss: 0.607202688240156, Time: 17 Seconds
training_step 24/30,Loss: 0.6061988342400632, Time: 17 Seconds
training_step 25/30,Loss: 0.6052442437932108, Time: 17 Seconds
training_step 26/30,Loss: 0.6043407629779405, Time: 17 Seconds
training_step 27/30,Loss: 0.603481547133716, Time: 17 Seconds
training_step 28/30,Loss: 0.6026304184437654, Time: 17 Seconds
training_step 29/30,Loss: 0.6018356593434895, Time: 17 Seconds
Done with 2/10! Total time: 523
------------------------------------------
running:
{'PRNGKey': 0,
 'alpha_d': True,
 'data_idx': 0,
 'dataset': <multidms.data.Data object at 0x7fa1c7254290>,
 'epistatic_model': 'Sigmoid',
 'gamma_corrected': False,
 'huber_scale_huber': 1,
 'init_beta_naught': 0.0,
 'iterations_per_step': 1000,
 'lock_beta': False,
 'lock_beta_naught': None,
 'num_training_steps': 30,
 'output_activation': 'Identity',
 'save_model_at': [30000],
 'scale_coeff_lasso_shift': 1e-06,
 'scale_coeff_ridge_beta': 0,
 'scale_coeff_ridge_ch': 0.001,
 'scale_coeff_ridge_gamma': 0.001,
 'scale_coeff_ridge_shift': 0,
 'step_loss': array([0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0.,
       0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0.]),
 'tol': 0.0001,
 'warmup_beta': False}
training_step 0/30,Loss: 0.9411085084349053, Time: 20 Seconds
training_step 1/30,Loss: 0.7951353346883032, Time: 17 Seconds
training_step 2/30,Loss: 0.7387577608441988, Time: 17 Seconds
training_step 3/30,Loss: 0.7115377966372558, Time: 17 Seconds
training_step 4/30,Loss: 0.6938689227388375, Time: 17 Seconds
training_step 5/30,Loss: 0.6827670813820705, Time: 17 Seconds
training_step 6/30,Loss: 0.6749049151885564, Time: 17 Seconds
training_step 7/30,Loss: 0.6687281345947783, Time: 17 Seconds
training_step 8/30,Loss: 0.6635773017173755, Time: 17 Seconds
training_step 9/30,Loss: 0.6593272811632758, Time: 17 Seconds
training_step 10/30,Loss: 0.6556644899581849, Time: 17 Seconds
training_step 11/30,Loss: 0.6524189050161004, Time: 17 Seconds
training_step 12/30,Loss: 0.6494638440351522, Time: 17 Seconds
training_step 13/30,Loss: 0.6466752567717697, Time: 17 Seconds
training_step 14/30,Loss: 0.6439298270236455, Time: 17 Seconds
training_step 15/30,Loss: 0.6414866750117889, Time: 17 Seconds
training_step 16/30,Loss: 0.6392525235774877, Time: 17 Seconds
training_step 17/30,Loss: 0.637157669089053, Time: 17 Seconds
training_step 18/30,Loss: 0.6350338677421581, Time: 17 Seconds
training_step 19/30,Loss: 0.6329171420035633, Time: 17 Seconds
training_step 20/30,Loss: 0.6310369894395464, Time: 17 Seconds
training_step 21/30,Loss: 0.6295253825411683, Time: 17 Seconds
training_step 22/30,Loss: 0.6284107914287551, Time: 17 Seconds
training_step 23/30,Loss: 0.6272871189873673, Time: 17 Seconds
training_step 24/30,Loss: 0.6262150079555098, Time: 17 Seconds
training_step 25/30,Loss: 0.625188288805329, Time: 17 Seconds
training_step 26/30,Loss: 0.6241868947096071, Time: 17 Seconds
training_step 27/30,Loss: 0.6232436166873888, Time: 17 Seconds
training_step 28/30,Loss: 0.6223983803525996, Time: 17 Seconds
training_step 29/30,Loss: 0.6216150227756063, Time: 17 Seconds
Done with 3/10! Total time: 512
------------------------------------------
running:
{'PRNGKey': 0,
 'alpha_d': True,
 'data_idx': 0,
 'dataset': <multidms.data.Data object at 0x7fa1ca51cbd0>,
 'epistatic_model': 'Sigmoid',
 'gamma_corrected': False,
 'huber_scale_huber': 1,
 'init_beta_naught': 0.0,
 'iterations_per_step': 1000,
 'lock_beta': False,
 'lock_beta_naught': None,
 'num_training_steps': 30,
 'output_activation': 'Identity',
 'save_model_at': [30000],
 'scale_coeff_lasso_shift': 1e-06,
 'scale_coeff_ridge_beta': 0,
 'scale_coeff_ridge_ch': 0.001,
 'scale_coeff_ridge_gamma': 0.001,
 'scale_coeff_ridge_shift': 0,
 'step_loss': array([0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0.,
       0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0.]),
 'tol': 0.0001,
 'warmup_beta': False}
training_step 0/30,Loss: 0.9539601906094174, Time: 20 Seconds
training_step 1/30,Loss: 0.7860932644613358, Time: 17 Seconds
training_step 2/30,Loss: 0.7235788834201794, Time: 17 Seconds
training_step 3/30,Loss: 0.6932237641487461, Time: 17 Seconds
training_step 4/30,Loss: 0.6744345010420105, Time: 17 Seconds
training_step 5/30,Loss: 0.6631357557438576, Time: 17 Seconds
training_step 6/30,Loss: 0.6546761603134306, Time: 17 Seconds
training_step 7/30,Loss: 0.6481288887887863, Time: 17 Seconds
training_step 8/30,Loss: 0.6428338018664971, Time: 17 Seconds
training_step 9/30,Loss: 0.6383385701591913, Time: 17 Seconds
training_step 10/30,Loss: 0.6343713311198351, Time: 17 Seconds
training_step 11/30,Loss: 0.6309478658647509, Time: 17 Seconds
training_step 12/30,Loss: 0.6280280620196295, Time: 17 Seconds
training_step 13/30,Loss: 0.6254801775343634, Time: 17 Seconds
training_step 14/30,Loss: 0.6232104699195797, Time: 17 Seconds
training_step 15/30,Loss: 0.6211819818987275, Time: 17 Seconds
training_step 16/30,Loss: 0.6193839378836729, Time: 17 Seconds
training_step 17/30,Loss: 0.6177733415110307, Time: 17 Seconds
training_step 18/30,Loss: 0.6163086017814902, Time: 17 Seconds
training_step 19/30,Loss: 0.6149626525647085, Time: 17 Seconds
training_step 20/30,Loss: 0.6137159450484408, Time: 17 Seconds
training_step 21/30,Loss: 0.6125548955261022, Time: 17 Seconds
training_step 22/30,Loss: 0.6114687557491234, Time: 17 Seconds
training_step 23/30,Loss: 0.6104446464301501, Time: 17 Seconds
training_step 24/30,Loss: 0.6094735861124313, Time: 17 Seconds
training_step 25/30,Loss: 0.6085505744442841, Time: 17 Seconds
training_step 26/30,Loss: 0.6076714440896187, Time: 17 Seconds
training_step 27/30,Loss: 0.6068356101564485, Time: 17 Seconds
training_step 28/30,Loss: 0.6060423769532972, Time: 17 Seconds
training_step 29/30,Loss: 0.6052864702251068, Time: 17 Seconds
Done with 4/10! Total time: 521
------------------------------------------
running:
{'PRNGKey': 0,
 'alpha_d': True,
 'data_idx': 0,
 'dataset': <multidms.data.Data object at 0x7fa1c7254290>,
 'epistatic_model': 'Sigmoid',
 'gamma_corrected': False,
 'huber_scale_huber': 1,
 'init_beta_naught': 0.0,
 'iterations_per_step': 1000,
 'lock_beta': False,
 'lock_beta_naught': None,
 'num_training_steps': 30,
 'output_activation': 'Identity',
 'save_model_at': [30000],
 'scale_coeff_lasso_shift': 1e-05,
 'scale_coeff_ridge_beta': 0,
 'scale_coeff_ridge_ch': 0.001,
 'scale_coeff_ridge_gamma': 0.001,
 'scale_coeff_ridge_shift': 0,
 'step_loss': array([0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0.,
       0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0.]),
 'tol': 0.0001,
 'warmup_beta': False}
training_step 0/30,Loss: 0.9529916694713718, Time: 20 Seconds
training_step 1/30,Loss: 0.8109438266595337, Time: 17 Seconds
training_step 2/30,Loss: 0.7490176650887481, Time: 17 Seconds
training_step 3/30,Loss: 0.7258298889131694, Time: 17 Seconds
training_step 4/30,Loss: 0.7116347487025543, Time: 17 Seconds
training_step 5/30,Loss: 0.7012832183051562, Time: 17 Seconds
training_step 6/30,Loss: 0.6905670417646879, Time: 17 Seconds
training_step 7/30,Loss: 0.6811450273143584, Time: 17 Seconds
training_step 8/30,Loss: 0.6759444841155552, Time: 17 Seconds
training_step 9/30,Loss: 0.6717872314394205, Time: 17 Seconds
training_step 10/30,Loss: 0.6678662297444783, Time: 17 Seconds
training_step 11/30,Loss: 0.6649882376822165, Time: 17 Seconds
training_step 12/30,Loss: 0.662531953868567, Time: 17 Seconds
training_step 13/30,Loss: 0.6604426048340516, Time: 17 Seconds
training_step 14/30,Loss: 0.6586480977193018, Time: 17 Seconds
training_step 15/30,Loss: 0.6570881153699908, Time: 17 Seconds
training_step 16/30,Loss: 0.6557269568915551, Time: 17 Seconds
training_step 17/30,Loss: 0.6545379199131502, Time: 17 Seconds
training_step 18/30,Loss: 0.6534961967305373, Time: 17 Seconds
training_step 19/30,Loss: 0.6525535810871477, Time: 17 Seconds
training_step 20/30,Loss: 0.6516855057712272, Time: 17 Seconds
training_step 21/30,Loss: 0.6508554177833095, Time: 17 Seconds
training_step 22/30,Loss: 0.650071950567261, Time: 17 Seconds
training_step 23/30,Loss: 0.6492801199119748, Time: 17 Seconds
training_step 24/30,Loss: 0.6484547847919899, Time: 17 Seconds
training_step 25/30,Loss: 0.6476587538536213, Time: 17 Seconds
training_step 26/30,Loss: 0.6468616450949586, Time: 17 Seconds
training_step 27/30,Loss: 0.645985090458572, Time: 17 Seconds
training_step 28/30,Loss: 0.6452370687393627, Time: 17 Seconds
training_step 29/30,Loss: 0.6445437767621642, Time: 17 Seconds
Done with 5/10! Total time: 512
------------------------------------------
running:
{'PRNGKey': 0,
 'alpha_d': True,
 'data_idx': 0,
 'dataset': <multidms.data.Data object at 0x7fa1ca51cbd0>,
 'epistatic_model': 'Sigmoid',
 'gamma_corrected': False,
 'huber_scale_huber': 1,
 'init_beta_naught': 0.0,
 'iterations_per_step': 1000,
 'lock_beta': False,
 'lock_beta_naught': None,
 'num_training_steps': 30,
 'output_activation': 'Identity',
 'save_model_at': [30000],
 'scale_coeff_lasso_shift': 1e-05,
 'scale_coeff_ridge_beta': 0,
 'scale_coeff_ridge_ch': 0.001,
 'scale_coeff_ridge_gamma': 0.001,
 'scale_coeff_ridge_shift': 0,
 'step_loss': array([0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0.,
       0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0.]),
 'tol': 0.0001,
 'warmup_beta': False}
training_step 0/30,Loss: 0.9665988520714643, Time: 20 Seconds
training_step 1/30,Loss: 0.7995667254358467, Time: 17 Seconds
training_step 2/30,Loss: 0.7337965768933323, Time: 17 Seconds
training_step 3/30,Loss: 0.708499385582115, Time: 17 Seconds
training_step 4/30,Loss: 0.6936101626289363, Time: 17 Seconds
training_step 5/30,Loss: 0.6832169206386218, Time: 17 Seconds
training_step 6/30,Loss: 0.6744424425913891, Time: 17 Seconds
training_step 7/30,Loss: 0.6656927218324825, Time: 17 Seconds
training_step 8/30,Loss: 0.6578313546097615, Time: 17 Seconds
training_step 9/30,Loss: 0.6529273045743456, Time: 17 Seconds
training_step 10/30,Loss: 0.6493759448539036, Time: 17 Seconds
training_step 11/30,Loss: 0.6464656034741934, Time: 17 Seconds
training_step 12/30,Loss: 0.6436324406366413, Time: 17 Seconds
training_step 13/30,Loss: 0.6413797550852132, Time: 17 Seconds
training_step 14/30,Loss: 0.6395537410944931, Time: 17 Seconds
training_step 15/30,Loss: 0.6379542802206452, Time: 17 Seconds
training_step 16/30,Loss: 0.6365301706338777, Time: 17 Seconds
training_step 17/30,Loss: 0.6352842689217071, Time: 17 Seconds
training_step 18/30,Loss: 0.6341719639755358, Time: 17 Seconds
training_step 19/30,Loss: 0.6331896118829146, Time: 17 Seconds
training_step 20/30,Loss: 0.6323171346595687, Time: 17 Seconds
training_step 21/30,Loss: 0.631510000009023, Time: 17 Seconds
training_step 22/30,Loss: 0.6307530441488656, Time: 17 Seconds
training_step 23/30,Loss: 0.6300440249789128, Time: 17 Seconds
training_step 24/30,Loss: 0.6293919625327828, Time: 17 Seconds
training_step 25/30,Loss: 0.6287951949529891, Time: 17 Seconds
training_step 26/30,Loss: 0.6282469621734285, Time: 17 Seconds
training_step 27/30,Loss: 0.6277436493327149, Time: 17 Seconds
training_step 28/30,Loss: 0.6272740365645973, Time: 17 Seconds
training_step 29/30,Loss: 0.6268346913198533, Time: 17 Seconds
Done with 6/10! Total time: 522
------------------------------------------
running:
{'PRNGKey': 0,
 'alpha_d': True,
 'data_idx': 0,
 'dataset': <multidms.data.Data object at 0x7fa1c7254290>,
 'epistatic_model': 'Sigmoid',
 'gamma_corrected': False,
 'huber_scale_huber': 1,
 'init_beta_naught': 0.0,
 'iterations_per_step': 1000,
 'lock_beta': False,
 'lock_beta_naught': None,
 'num_training_steps': 30,
 'output_activation': 'Identity',
 'save_model_at': [30000],
 'scale_coeff_lasso_shift': 2e-05,
 'scale_coeff_ridge_beta': 0,
 'scale_coeff_ridge_ch': 0.001,
 'scale_coeff_ridge_gamma': 0.001,
 'scale_coeff_ridge_shift': 0,
 'step_loss': array([0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0.,
       0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0.]),
 'tol': 0.0001,
 'warmup_beta': False}
training_step 0/30,Loss: 0.9716740948127651, Time: 20 Seconds
training_step 1/30,Loss: 0.8299453454175213, Time: 17 Seconds
training_step 2/30,Loss: 0.7703904390725789, Time: 17 Seconds
training_step 3/30,Loss: 0.7476169439051259, Time: 17 Seconds
training_step 4/30,Loss: 0.7335335125570075, Time: 17 Seconds
training_step 5/30,Loss: 0.7209290532052455, Time: 17 Seconds
training_step 6/30,Loss: 0.7091287282358232, Time: 17 Seconds
training_step 7/30,Loss: 0.7030996492065293, Time: 17 Seconds
training_step 8/30,Loss: 0.6975641602613856, Time: 17 Seconds
training_step 9/30,Loss: 0.69356242424749, Time: 17 Seconds
training_step 10/30,Loss: 0.6903134185615807, Time: 17 Seconds
training_step 11/30,Loss: 0.6875892891096955, Time: 17 Seconds
training_step 12/30,Loss: 0.6851666834840342, Time: 17 Seconds
training_step 13/30,Loss: 0.6830973743308437, Time: 17 Seconds
training_step 14/30,Loss: 0.6811825256209616, Time: 17 Seconds
training_step 15/30,Loss: 0.6792122788609162, Time: 17 Seconds
training_step 16/30,Loss: 0.677198641282623, Time: 17 Seconds
training_step 17/30,Loss: 0.6750113443991779, Time: 17 Seconds
training_step 18/30,Loss: 0.6732722637689561, Time: 17 Seconds
training_step 19/30,Loss: 0.6720459726533167, Time: 17 Seconds
training_step 20/30,Loss: 0.6710517551829874, Time: 17 Seconds
training_step 21/30,Loss: 0.670199400365799, Time: 17 Seconds
training_step 22/30,Loss: 0.6693605656642059, Time: 17 Seconds
training_step 23/30,Loss: 0.6684547409066477, Time: 17 Seconds
training_step 24/30,Loss: 0.6677614999328346, Time: 17 Seconds
training_step 25/30,Loss: 0.6671745495086318, Time: 17 Seconds
training_step 26/30,Loss: 0.6666522591020745, Time: 17 Seconds
training_step 27/30,Loss: 0.6661759311157153, Time: 17 Seconds
training_step 28/30,Loss: 0.6657440518988291, Time: 17 Seconds
training_step 29/30,Loss: 0.6653488835089487, Time: 17 Seconds
Done with 7/10! Total time: 513
------------------------------------------
running:
{'PRNGKey': 0,
 'alpha_d': True,
 'data_idx': 0,
 'dataset': <multidms.data.Data object at 0x7fa1ca51cbd0>,
 'epistatic_model': 'Sigmoid',
 'gamma_corrected': False,
 'huber_scale_huber': 1,
 'init_beta_naught': 0.0,
 'iterations_per_step': 1000,
 'lock_beta': False,
 'lock_beta_naught': None,
 'num_training_steps': 30,
 'output_activation': 'Identity',
 'save_model_at': [30000],
 'scale_coeff_lasso_shift': 2e-05,
 'scale_coeff_ridge_beta': 0,
 'scale_coeff_ridge_ch': 0.001,
 'scale_coeff_ridge_gamma': 0.001,
 'scale_coeff_ridge_shift': 0,
 'step_loss': array([0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0.,
       0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0.]),
 'tol': 0.0001,
 'warmup_beta': False}
training_step 0/30,Loss: 0.9829587868024998, Time: 20 Seconds
training_step 1/30,Loss: 0.8184428455574099, Time: 17 Seconds
training_step 2/30,Loss: 0.7578696294097671, Time: 17 Seconds
training_step 3/30,Loss: 0.7325859226494296, Time: 17 Seconds
training_step 4/30,Loss: 0.7172782588637057, Time: 17 Seconds
training_step 5/30,Loss: 0.7032520570882904, Time: 17 Seconds
training_step 6/30,Loss: 0.6912878218717999, Time: 17 Seconds
training_step 7/30,Loss: 0.6851013296476907, Time: 17 Seconds
training_step 8/30,Loss: 0.6797209091126535, Time: 17 Seconds
training_step 9/30,Loss: 0.6757295731902787, Time: 17 Seconds
training_step 10/30,Loss: 0.67241719234021, Time: 17 Seconds
training_step 11/30,Loss: 0.6697199374476086, Time: 17 Seconds
training_step 12/30,Loss: 0.6674330373922486, Time: 17 Seconds
training_step 13/30,Loss: 0.6654443465930208, Time: 18 Seconds
training_step 14/30,Loss: 0.6637515019575377, Time: 17 Seconds
training_step 15/30,Loss: 0.6622480739018693, Time: 17 Seconds
training_step 16/30,Loss: 0.6608521686111264, Time: 17 Seconds
training_step 17/30,Loss: 0.6594041121629161, Time: 17 Seconds
training_step 18/30,Loss: 0.6580577685335367, Time: 17 Seconds
training_step 19/30,Loss: 0.6565305934720959, Time: 17 Seconds
training_step 20/30,Loss: 0.6551412290042901, Time: 17 Seconds
training_step 21/30,Loss: 0.6541476045375894, Time: 17 Seconds
training_step 22/30,Loss: 0.6533255126084351, Time: 17 Seconds
training_step 23/30,Loss: 0.6526277657451961, Time: 17 Seconds
training_step 24/30,Loss: 0.6520335603210731, Time: 17 Seconds
training_step 25/30,Loss: 0.6514881024581287, Time: 17 Seconds
training_step 26/30,Loss: 0.6509788040923261, Time: 17 Seconds
training_step 27/30,Loss: 0.6504565225855403, Time: 17 Seconds
training_step 28/30,Loss: 0.64993106266573, Time: 17 Seconds
training_step 29/30,Loss: 0.6494950645660391, Time: 17 Seconds
Done with 8/10! Total time: 521
------------------------------------------
running:
{'PRNGKey': 0,
 'alpha_d': True,
 'data_idx': 0,
 'dataset': <multidms.data.Data object at 0x7fa1c7254290>,
 'epistatic_model': 'Sigmoid',
 'gamma_corrected': False,
 'huber_scale_huber': 1,
 'init_beta_naught': 0.0,
 'iterations_per_step': 1000,
 'lock_beta': False,
 'lock_beta_naught': None,
 'num_training_steps': 30,
 'output_activation': 'Identity',
 'save_model_at': [30000],
 'scale_coeff_lasso_shift': 5e-05,
 'scale_coeff_ridge_beta': 0,
 'scale_coeff_ridge_ch': 0.001,
 'scale_coeff_ridge_gamma': 0.001,
 'scale_coeff_ridge_shift': 0,
 'step_loss': array([0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0.,
       0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0.]),
 'tol': 0.0001,
 'warmup_beta': False}
training_step 0/30,Loss: 1.0301489074085346, Time: 20 Seconds
training_step 1/30,Loss: 0.8814828962596548, Time: 17 Seconds
training_step 2/30,Loss: 0.8319113607238466, Time: 17 Seconds
training_step 3/30,Loss: 0.8060958716525601, Time: 17 Seconds
training_step 4/30,Loss: 0.7900223474324285, Time: 17 Seconds
training_step 5/30,Loss: 0.7792984211152172, Time: 17 Seconds
training_step 6/30,Loss: 0.7709705883948414, Time: 17 Seconds
training_step 7/30,Loss: 0.7628029260521945, Time: 17 Seconds
training_step 8/30,Loss: 0.7542031374203901, Time: 17 Seconds
training_step 9/30,Loss: 0.7491451754146574, Time: 17 Seconds
training_step 10/30,Loss: 0.7454933308354424, Time: 17 Seconds
training_step 11/30,Loss: 0.7419740010856419, Time: 17 Seconds
training_step 12/30,Loss: 0.7392133373058141, Time: 17 Seconds
training_step 13/30,Loss: 0.7368275015653778, Time: 17 Seconds
training_step 14/30,Loss: 0.734762479208252, Time: 17 Seconds
training_step 15/30,Loss: 0.7328745898324096, Time: 17 Seconds
training_step 16/30,Loss: 0.7311239958872938, Time: 17 Seconds
training_step 17/30,Loss: 0.7294470505889156, Time: 17 Seconds
training_step 18/30,Loss: 0.7276018039039213, Time: 17 Seconds
training_step 19/30,Loss: 0.7256510317861257, Time: 17 Seconds
training_step 20/30,Loss: 0.7233097131980532, Time: 17 Seconds
training_step 21/30,Loss: 0.7216899401497405, Time: 17 Seconds
training_step 22/30,Loss: 0.7204421380447336, Time: 17 Seconds
training_step 23/30,Loss: 0.7193771509115203, Time: 17 Seconds
training_step 24/30,Loss: 0.7182970589854603, Time: 17 Seconds
training_step 25/30,Loss: 0.717297234531093, Time: 17 Seconds
training_step 26/30,Loss: 0.7163944611681639, Time: 17 Seconds
training_step 27/30,Loss: 0.7155494194900364, Time: 17 Seconds
training_step 28/30,Loss: 0.7148058366514399, Time: 17 Seconds
training_step 29/30,Loss: 0.7141392263232371, Time: 17 Seconds
Done with 9/10! Total time: 512
------------------------------------------
running:
{'PRNGKey': 0,
 'alpha_d': True,
 'data_idx': 0,
 'dataset': <multidms.data.Data object at 0x7fa1ca51cbd0>,
 'epistatic_model': 'Sigmoid',
 'gamma_corrected': False,
 'huber_scale_huber': 1,
 'init_beta_naught': 0.0,
 'iterations_per_step': 1000,
 'lock_beta': False,
 'lock_beta_naught': None,
 'num_training_steps': 30,
 'output_activation': 'Identity',
 'save_model_at': [30000],
 'scale_coeff_lasso_shift': 5e-05,
 'scale_coeff_ridge_beta': 0,
 'scale_coeff_ridge_ch': 0.001,
 'scale_coeff_ridge_gamma': 0.001,
 'scale_coeff_ridge_shift': 0,
 'step_loss': array([0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0.,
       0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0.]),
 'tol': 0.0001,
 'warmup_beta': False}
training_step 0/30,Loss: 1.0488032403331338, Time: 20 Seconds
training_step 1/30,Loss: 0.8753436739858613, Time: 17 Seconds
training_step 2/30,Loss: 0.8101906569927726, Time: 17 Seconds
training_step 3/30,Loss: 0.7856918831145904, Time: 17 Seconds
training_step 4/30,Loss: 0.770878407657035, Time: 17 Seconds
training_step 5/30,Loss: 0.7603754359330324, Time: 17 Seconds
training_step 6/30,Loss: 0.7496893076277338, Time: 17 Seconds
training_step 7/30,Loss: 0.740859616712574, Time: 17 Seconds
training_step 8/30,Loss: 0.7358931868223947, Time: 17 Seconds
training_step 9/30,Loss: 0.7316010001728322, Time: 17 Seconds
training_step 10/30,Loss: 0.7281106525605864, Time: 17 Seconds
training_step 11/30,Loss: 0.72521815106694, Time: 17 Seconds
training_step 12/30,Loss: 0.7227933752107949, Time: 17 Seconds
training_step 13/30,Loss: 0.7206609051498328, Time: 17 Seconds
training_step 14/30,Loss: 0.7187065568157593, Time: 17 Seconds
training_step 15/30,Loss: 0.7169396697163374, Time: 18 Seconds
training_step 16/30,Loss: 0.7151988144199851, Time: 17 Seconds
training_step 17/30,Loss: 0.7134136791885336, Time: 17 Seconds
training_step 18/30,Loss: 0.7115252940791621, Time: 17 Seconds
training_step 19/30,Loss: 0.7095626074905415, Time: 17 Seconds
training_step 20/30,Loss: 0.7081903493347494, Time: 17 Seconds
training_step 21/30,Loss: 0.7070926091993874, Time: 17 Seconds
training_step 22/30,Loss: 0.7061576339283927, Time: 17 Seconds
training_step 23/30,Loss: 0.7052796525610245, Time: 17 Seconds
training_step 24/30,Loss: 0.704343536313657, Time: 17 Seconds
training_step 25/30,Loss: 0.7035551826961794, Time: 17 Seconds
training_step 26/30,Loss: 0.702830591255112, Time: 17 Seconds
training_step 27/30,Loss: 0.70215070857367, Time: 17 Seconds
training_step 28/30,Loss: 0.7015208302350162, Time: 17 Seconds
training_step 29/30,Loss: 0.7009295906103199, Time: 17 Seconds
Done with 10/10! Total time: 520
------------------------------------------
running:
{'PRNGKey': 0,
 'alpha_d': True,
 'data_idx': 0,
 'dataset': <multidms.data.Data object at 0x7fa1c7254290>,
 'epistatic_model': 'Sigmoid',
 'gamma_corrected': False,
 'huber_scale_huber': 1,
 'init_beta_naught': 0.0,
 'iterations_per_step': 1000,
 'lock_beta': False,
 'lock_beta_naught': None,
 'num_training_steps': 30,
 'output_activation': 'Identity',
 'save_model_at': [30000],
 'scale_coeff_lasso_shift': 0.0001,
 'scale_coeff_ridge_beta': 0,
 'scale_coeff_ridge_ch': 0.001,
 'scale_coeff_ridge_gamma': 0.001,
 'scale_coeff_ridge_shift': 0,
 'step_loss': array([0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0.,
       0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0.]),
 'tol': 0.0001,
 'warmup_beta': False}
training_step 0/30,Loss: 1.0710028536597322, Time: 20 Seconds
training_step 1/30,Loss: 0.9272533240685601, Time: 17 Seconds
training_step 2/30,Loss: 0.8871303780658435, Time: 17 Seconds
training_step 3/30,Loss: 0.8659071460006909, Time: 17 Seconds
training_step 4/30,Loss: 0.8519914253205725, Time: 17 Seconds
training_step 5/30,Loss: 0.8418000965016657, Time: 17 Seconds
training_step 6/30,Loss: 0.8339841657490454, Time: 17 Seconds
training_step 7/30,Loss: 0.8278423818432433, Time: 17 Seconds
training_step 8/30,Loss: 0.8230905321559758, Time: 17 Seconds
training_step 9/30,Loss: 0.8194231822390863, Time: 17 Seconds
training_step 10/30,Loss: 0.8165072122537879, Time: 17 Seconds
training_step 11/30,Loss: 0.814271097967542, Time: 17 Seconds
training_step 12/30,Loss: 0.8125973558109776, Time: 17 Seconds
training_step 13/30,Loss: 0.8112157881904055, Time: 17 Seconds
training_step 14/30,Loss: 0.8100531899867289, Time: 17 Seconds
training_step 15/30,Loss: 0.809065425861986, Time: 17 Seconds
training_step 16/30,Loss: 0.8082399342076081, Time: 17 Seconds
training_step 17/30,Loss: 0.8075536728205767, Time: 17 Seconds
training_step 18/30,Loss: 0.8069728869711577, Time: 17 Seconds
training_step 19/30,Loss: 0.8064725921519781, Time: 17 Seconds
training_step 20/30,Loss: 0.8060387770583584, Time: 17 Seconds
training_step 21/30,Loss: 0.8056738117039688, Time: 17 Seconds
training_step 22/30,Loss: 0.8053987019399054, Time: 17 Seconds
training_step 23/30,Loss: 0.8051975120036422, Time: 17 Seconds
training_step 24/30,Loss: 0.8050488842489194, Time: 17 Seconds
training_step 25/30,Loss: 0.8049277325915496, Time: 17 Seconds
training_step 26/30,Loss: 0.8048242241711646, Time: 17 Seconds
training_step 27/30,Loss: 0.8047561444353504, Time: 17 Seconds
training_step 28/30,Loss: 0.8047561469785109, Time: 8 Seconds
training_step 29/30,Loss: 0.8047561416586488, Time: 8 Seconds
Done with 11/10! Total time: 493
------------------------------------------
running:
{'PRNGKey': 0,
 'alpha_d': True,
 'data_idx': 0,
 'dataset': <multidms.data.Data object at 0x7fa1ca51cbd0>,
 'epistatic_model': 'Sigmoid',
 'gamma_corrected': False,
 'huber_scale_huber': 1,
 'init_beta_naught': 0.0,
 'iterations_per_step': 1000,
 'lock_beta': False,
 'lock_beta_naught': None,
 'num_training_steps': 30,
 'output_activation': 'Identity',
 'save_model_at': [30000],
 'scale_coeff_lasso_shift': 0.0001,
 'scale_coeff_ridge_beta': 0,
 'scale_coeff_ridge_ch': 0.001,
 'scale_coeff_ridge_gamma': 0.001,
 'scale_coeff_ridge_shift': 0,
 'step_loss': array([0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0.,
       0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0.]),
 'tol': 0.0001,
 'warmup_beta': False}
training_step 0/30,Loss: 1.105118551072002, Time: 20 Seconds
training_step 1/30,Loss: 0.9214784279548643, Time: 17 Seconds
training_step 2/30,Loss: 0.8777185896716494, Time: 17 Seconds
training_step 3/30,Loss: 0.8488670490839356, Time: 17 Seconds
training_step 4/30,Loss: 0.8253149782540141, Time: 18 Seconds
training_step 5/30,Loss: 0.8144529545935955, Time: 17 Seconds
training_step 6/30,Loss: 0.8074847481290137, Time: 17 Seconds
training_step 7/30,Loss: 0.8019760209215097, Time: 17 Seconds
training_step 8/30,Loss: 0.7978052667868804, Time: 17 Seconds
training_step 9/30,Loss: 0.7944655507812542, Time: 17 Seconds
training_step 10/30,Loss: 0.7918328692552002, Time: 17 Seconds
training_step 11/30,Loss: 0.7896310491639691, Time: 17 Seconds
training_step 12/30,Loss: 0.7878058501271108, Time: 17 Seconds
training_step 13/30,Loss: 0.7862480843200061, Time: 17 Seconds
training_step 14/30,Loss: 0.7848903466979735, Time: 17 Seconds
training_step 15/30,Loss: 0.783706094030978, Time: 17 Seconds
training_step 16/30,Loss: 0.7826961402819297, Time: 17 Seconds
training_step 17/30,Loss: 0.7818464455768005, Time: 17 Seconds
training_step 18/30,Loss: 0.7810956670687091, Time: 17 Seconds
training_step 19/30,Loss: 0.7804276473864248, Time: 17 Seconds
training_step 20/30,Loss: 0.7798464431872437, Time: 17 Seconds
training_step 21/30,Loss: 0.7793456168284726, Time: 17 Seconds
training_step 22/30,Loss: 0.7789229377761374, Time: 17 Seconds
training_step 23/30,Loss: 0.7785592555437937, Time: 17 Seconds
training_step 24/30,Loss: 0.7782403243799398, Time: 17 Seconds
training_step 25/30,Loss: 0.7779732812643934, Time: 17 Seconds
training_step 26/30,Loss: 0.7779732794152913, Time: 8 Seconds
training_step 27/30,Loss: 0.7779732754393336, Time: 8 Seconds
training_step 28/30,Loss: 0.7779732545002062, Time: 8 Seconds
training_step 29/30,Loss: 0.7779732458470086, Time: 8 Seconds
Done with 12/10! Total time: 482
------------------------------------------
running:
{'PRNGKey': 0,
 'alpha_d': True,
 'data_idx': 0,
 'dataset': <multidms.data.Data object at 0x7fa1c7254290>,
 'epistatic_model': 'Sigmoid',
 'gamma_corrected': False,
 'huber_scale_huber': 1,
 'init_beta_naught': 0.0,
 'iterations_per_step': 1000,
 'lock_beta': False,
 'lock_beta_naught': None,
 'num_training_steps': 30,
 'output_activation': 'Identity',
 'save_model_at': [30000],
 'scale_coeff_lasso_shift': 0.0005,
 'scale_coeff_ridge_beta': 0,
 'scale_coeff_ridge_ch': 0.001,
 'scale_coeff_ridge_gamma': 0.001,
 'scale_coeff_ridge_shift': 0,
 'step_loss': array([0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0.,
       0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0.]),
 'tol': 0.0001,
 'warmup_beta': False}
training_step 0/30,Loss: 1.260790670167503, Time: 20 Seconds
training_step 1/30,Loss: 0.9591817432874664, Time: 17 Seconds
training_step 2/30,Loss: 0.9152207706194297, Time: 17 Seconds
training_step 3/30,Loss: 0.8958393731846424, Time: 17 Seconds
training_step 4/30,Loss: 0.8847774219598153, Time: 17 Seconds
training_step 5/30,Loss: 0.8771388675771862, Time: 17 Seconds
training_step 6/30,Loss: 0.871604680598868, Time: 17 Seconds
training_step 7/30,Loss: 0.8671787856953397, Time: 17 Seconds
training_step 8/30,Loss: 0.86295654697443, Time: 17 Seconds
training_step 9/30,Loss: 0.8595400229311561, Time: 17 Seconds
training_step 10/30,Loss: 0.8570449017057934, Time: 17 Seconds
training_step 11/30,Loss: 0.8550095309997527, Time: 17 Seconds
training_step 12/30,Loss: 0.853229635942459, Time: 17 Seconds
training_step 13/30,Loss: 0.8514523649770078, Time: 17 Seconds
training_step 14/30,Loss: 0.8498203000608489, Time: 17 Seconds
training_step 15/30,Loss: 0.8483146387749771, Time: 17 Seconds
training_step 16/30,Loss: 0.8470333702457044, Time: 17 Seconds
training_step 17/30,Loss: 0.8470333423308939, Time: 8 Seconds
training_step 18/30,Loss: 0.8456955114136565, Time: 17 Seconds
training_step 19/30,Loss: 0.8456954910713248, Time: 8 Seconds
training_step 20/30,Loss: 0.8456954762717659, Time: 8 Seconds
training_step 21/30,Loss: 0.8456954521196871, Time: 8 Seconds
training_step 22/30,Loss: 0.8456954412935006, Time: 8 Seconds
training_step 23/30,Loss: 0.8456954159690919, Time: 8 Seconds
training_step 24/30,Loss: 0.8456953983651515, Time: 8 Seconds
training_step 25/30,Loss: 0.8456953874984139, Time: 8 Seconds
training_step 26/30,Loss: 0.8456953621849683, Time: 8 Seconds
training_step 27/30,Loss: 0.845695344612227, Time: 8 Seconds
training_step 28/30,Loss: 0.8456953337040725, Time: 8 Seconds
training_step 29/30,Loss: 0.8456953084018886, Time: 8 Seconds
Done with 13/10! Total time: 401
------------------------------------------
running:
{'PRNGKey': 0,
 'alpha_d': True,
 'data_idx': 0,
 'dataset': <multidms.data.Data object at 0x7fa1ca51cbd0>,
 'epistatic_model': 'Sigmoid',
 'gamma_corrected': False,
 'huber_scale_huber': 1,
 'init_beta_naught': 0.0,
 'iterations_per_step': 1000,
 'lock_beta': False,
 'lock_beta_naught': None,
 'num_training_steps': 30,
 'output_activation': 'Identity',
 'save_model_at': [30000],
 'scale_coeff_lasso_shift': 0.0005,
 'scale_coeff_ridge_beta': 0,
 'scale_coeff_ridge_ch': 0.001,
 'scale_coeff_ridge_gamma': 0.001,
 'scale_coeff_ridge_shift': 0,
 'step_loss': array([0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0.,
       0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0.]),
 'tol': 0.0001,
 'warmup_beta': False}
training_step 0/30,Loss: 1.1129877654379086, Time: 20 Seconds
training_step 1/30,Loss: 0.9576904772259396, Time: 17 Seconds
training_step 2/30,Loss: 0.9154177122077596, Time: 17 Seconds
training_step 3/30,Loss: 0.8954528159100025, Time: 17 Seconds
training_step 4/30,Loss: 0.8826456836761875, Time: 17 Seconds
training_step 5/30,Loss: 0.8738895415497575, Time: 17 Seconds
training_step 6/30,Loss: 0.8678674830848935, Time: 17 Seconds
training_step 7/30,Loss: 0.8631852016297108, Time: 17 Seconds
training_step 8/30,Loss: 0.8592923704190718, Time: 17 Seconds
training_step 9/30,Loss: 0.856068205368557, Time: 17 Seconds
training_step 10/30,Loss: 0.8533649237211887, Time: 17 Seconds
training_step 11/30,Loss: 0.8510324658103617, Time: 17 Seconds
training_step 12/30,Loss: 0.8488985079484156, Time: 17 Seconds
training_step 13/30,Loss: 0.8467727251378079, Time: 17 Seconds
training_step 14/30,Loss: 0.8449412938146786, Time: 17 Seconds
training_step 15/30,Loss: 0.8435677765471077, Time: 17 Seconds
training_step 16/30,Loss: 0.8423967385138377, Time: 17 Seconds
training_step 17/30,Loss: 0.8413400932450386, Time: 17 Seconds
training_step 18/30,Loss: 0.8403120259829661, Time: 17 Seconds
training_step 19/30,Loss: 0.8392535175521084, Time: 17 Seconds
training_step 20/30,Loss: 0.8392535015503193, Time: 8 Seconds
training_step 21/30,Loss: 0.8381735913841794, Time: 17 Seconds
training_step 22/30,Loss: 0.8381735661336341, Time: 8 Seconds
training_step 23/30,Loss: 0.8381735495612981, Time: 8 Seconds
training_step 24/30,Loss: 0.8381735210222258, Time: 8 Seconds
training_step 25/30,Loss: 0.8381735018573719, Time: 8 Seconds
training_step 26/30,Loss: 0.8381734848408026, Time: 8 Seconds
training_step 27/30,Loss: 0.8381734564494945, Time: 8 Seconds
training_step 28/30,Loss: 0.838173437622868, Time: 8 Seconds
training_step 29/30,Loss: 0.8381734201205366, Time: 8 Seconds
Done with 14/10! Total time: 435
------------------------------------------
running:
{'PRNGKey': 0,
 'alpha_d': True,
 'data_idx': 0,
 'dataset': <multidms.data.Data object at 0x7fa1c7254290>,
 'epistatic_model': 'Sigmoid',
 'gamma_corrected': False,
 'huber_scale_huber': 1,
 'init_beta_naught': 0.0,
 'iterations_per_step': 1000,
 'lock_beta': False,
 'lock_beta_naught': None,
 'num_training_steps': 30,
 'output_activation': 'Identity',
 'save_model_at': [30000],
 'scale_coeff_lasso_shift': 0.001,
 'scale_coeff_ridge_beta': 0,
 'scale_coeff_ridge_ch': 0.001,
 'scale_coeff_ridge_gamma': 0.001,
 'scale_coeff_ridge_shift': 0,
 'step_loss': array([0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0.,
       0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0.]),
 'tol': 0.0001,
 'warmup_beta': False}
training_step 0/30,Loss: 1.7003381661760741, Time: 20 Seconds
training_step 1/30,Loss: 0.9936700786708972, Time: 17 Seconds
training_step 2/30,Loss: 0.9145926861906877, Time: 17 Seconds
training_step 3/30,Loss: 0.8899795237469192, Time: 17 Seconds
training_step 4/30,Loss: 0.878741086187927, Time: 17 Seconds
training_step 5/30,Loss: 0.8710912969958748, Time: 17 Seconds
training_step 6/30,Loss: 0.8658846555858415, Time: 17 Seconds
training_step 7/30,Loss: 0.8623780041551086, Time: 17 Seconds
training_step 8/30,Loss: 0.8592528742892954, Time: 17 Seconds
training_step 9/30,Loss: 0.8564333612819667, Time: 17 Seconds
training_step 10/30,Loss: 0.8539668966303096, Time: 17 Seconds
training_step 11/30,Loss: 0.8519398562733687, Time: 17 Seconds
training_step 12/30,Loss: 0.8501624132606327, Time: 17 Seconds
training_step 13/30,Loss: 0.8485726179305347, Time: 17 Seconds
training_step 14/30,Loss: 0.8470032300266048, Time: 17 Seconds
training_step 15/30,Loss: 0.8470032063730433, Time: 8 Seconds
training_step 16/30,Loss: 0.8470031885507137, Time: 8 Seconds
training_step 17/30,Loss: 0.8470031609023305, Time: 8 Seconds
training_step 18/30,Loss: 0.8458326372447792, Time: 17 Seconds
training_step 19/30,Loss: 0.8458326230465535, Time: 8 Seconds
training_step 20/30,Loss: 0.8458326127671693, Time: 8 Seconds
training_step 21/30,Loss: 0.8458325985207821, Time: 8 Seconds
training_step 22/30,Loss: 0.845832588291104, Time: 8 Seconds
training_step 23/30,Loss: 0.8458325739954639, Time: 8 Seconds
training_step 24/30,Loss: 0.8458325638167038, Time: 8 Seconds
training_step 25/30,Loss: 0.845832549470497, Time: 8 Seconds
training_step 26/30,Loss: 0.8458325393439324, Time: 8 Seconds
training_step 27/30,Loss: 0.8458325249459995, Time: 8 Seconds
training_step 28/30,Loss: 0.8458325148728886, Time: 8 Seconds
training_step 29/30,Loss: 0.8447098128067998, Time: 17 Seconds
Done with 15/10! Total time: 391
------------------------------------------
running:
{'PRNGKey': 0,
 'alpha_d': True,
 'data_idx': 0,
 'dataset': <multidms.data.Data object at 0x7fa1ca51cbd0>,
 'epistatic_model': 'Sigmoid',
 'gamma_corrected': False,
 'huber_scale_huber': 1,
 'init_beta_naught': 0.0,
 'iterations_per_step': 1000,
 'lock_beta': False,
 'lock_beta_naught': None,
 'num_training_steps': 30,
 'output_activation': 'Identity',
 'save_model_at': [30000],
 'scale_coeff_lasso_shift': 0.001,
 'scale_coeff_ridge_beta': 0,
 'scale_coeff_ridge_ch': 0.001,
 'scale_coeff_ridge_gamma': 0.001,
 'scale_coeff_ridge_shift': 0,
 'step_loss': array([0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0.,
       0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0.]),
 'tol': 0.0001,
 'warmup_beta': False}
training_step 0/30,Loss: 1.067255084146299, Time: 21 Seconds
training_step 1/30,Loss: 0.9379927931029655, Time: 17 Seconds
training_step 2/30,Loss: 0.9056464083824038, Time: 17 Seconds
training_step 3/30,Loss: 0.8884359761251179, Time: 17 Seconds
training_step 4/30,Loss: 0.8771505390487915, Time: 17 Seconds
training_step 5/30,Loss: 0.869614283858527, Time: 17 Seconds
training_step 6/30,Loss: 0.8640081708018755, Time: 17 Seconds
training_step 7/30,Loss: 0.8595309250936951, Time: 17 Seconds
training_step 8/30,Loss: 0.8559375587440375, Time: 17 Seconds
training_step 9/30,Loss: 0.8530319257173358, Time: 17 Seconds
training_step 10/30,Loss: 0.8502731162399977, Time: 17 Seconds
training_step 11/30,Loss: 0.8480639282152509, Time: 17 Seconds
training_step 12/30,Loss: 0.8463942553823061, Time: 17 Seconds
training_step 13/30,Loss: 0.8448746359178987, Time: 17 Seconds
training_step 14/30,Loss: 0.8433701591319093, Time: 17 Seconds
training_step 15/30,Loss: 0.8419845251136339, Time: 17 Seconds
training_step 16/30,Loss: 0.8406835659018096, Time: 17 Seconds
training_step 17/30,Loss: 0.8394963381633823, Time: 17 Seconds
training_step 18/30,Loss: 0.8384647577470552, Time: 17 Seconds
training_step 19/30,Loss: 0.8384647379280572, Time: 8 Seconds
training_step 20/30,Loss: 0.83846472200163, Time: 8 Seconds
training_step 21/30,Loss: 0.8384646981873964, Time: 8 Seconds
training_step 22/30,Loss: 0.837334955784605, Time: 17 Seconds
training_step 23/30,Loss: 0.8373349332535582, Time: 8 Seconds
training_step 24/30,Loss: 0.8373349174892963, Time: 8 Seconds
training_step 25/30,Loss: 0.8373349018502734, Time: 8 Seconds
training_step 26/30,Loss: 0.8373348817508238, Time: 8 Seconds
training_step 27/30,Loss: 0.8373348705306773, Time: 8 Seconds
training_step 28/30,Loss: 0.83733484878986, Time: 8 Seconds
training_step 29/30,Loss: 0.8373348425661542, Time: 8 Seconds
Done with 16/10! Total time: 425
------------------------------------------

The cell above saves the models and their relevant hyperparameters in a DataFrame, serialized to a pickle binary file. Hence, if it has already been run and you want to execute the code below without re-fitting, the following cell will load that binary file.

Hide content
models = pickle.load(open(f"{OUTDIR}/models.pkl", "rb"))
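
For reference, the save side of that round trip is the standard pickle call; a minimal sketch, assuming models is the DataFrame of fits assembled by the fitting cell above:

import pickle  # already imported above; repeated so the sketch is self-contained

with open(f"{OUTDIR}/models.pkl", "wb") as f:
    pickle.dump(models, f)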

We have now fit 8 models per replicate, each with a different lasso strength. These fits are organized into a dataframe, where the column “model_30000” holds a reference to the respective multidms.Model object.

models[['epistatic_model', 'output_activation', 'step_loss', 'model_30000', 'replicate']]
Hide output
epistatic_model output_activation step_loss model_30000 replicate
0 Sigmoid Identity [0.9395107515646628, 0.7935431678910877, 0.740... <multidms.model.Model object at 0x7fed3020cf10> 0
1 Sigmoid Identity [0.9531235360436874, 0.7824949636183749, 0.726... <multidms.model.Model object at 0x7fed59943c10> 1
2 Sigmoid Identity [0.9411085084349053, 0.7951353346883032, 0.738... <multidms.model.Model object at 0x7fed3020ed90> 0
3 Sigmoid Identity [0.9539601906094174, 0.7860932644613358, 0.723... <multidms.model.Model object at 0x7fed3020fed0> 1
4 Sigmoid Identity [0.9529916694713718, 0.8109438266595337, 0.749... <multidms.model.Model object at 0x7fed30264d50> 0
5 Sigmoid Identity [0.9665988520714643, 0.7995667254358467, 0.733... <multidms.model.Model object at 0x7fed30265990> 1
6 Sigmoid Identity [0.9716740948127651, 0.8299453454175213, 0.770... <multidms.model.Model object at 0x7fed30266f10> 0
7 Sigmoid Identity [0.9829587868024998, 0.8184428455574099, 0.757... <multidms.model.Model object at 0x7fed30268250> 1
8 Sigmoid Identity [1.0301489074085346, 0.8814828962596548, 0.831... <multidms.model.Model object at 0x7fed30269690> 0
9 Sigmoid Identity [1.0488032403331338, 0.8753436739858613, 0.810... <multidms.model.Model object at 0x7fed3026ab10> 1
10 Sigmoid Identity [1.0710028536597322, 0.9272533240685601, 0.887... <multidms.model.Model object at 0x7fed3026c390> 0
11 Sigmoid Identity [1.1051185510720019, 0.9214784279548643, 0.877... <multidms.model.Model object at 0x7fed3026d510> 1
12 Sigmoid Identity [1.260790670167503, 0.9591817432874664, 0.9152... <multidms.model.Model object at 0x7fed3026e410> 0
13 Sigmoid Identity [1.1129877654379086, 0.9576904772259396, 0.915... <multidms.model.Model object at 0x7fed3026f110> 1
14 Sigmoid Identity [1.7003381661760741, 0.9936700786708974, 0.914... <multidms.model.Model object at 0x7fed30274710> 0
15 Sigmoid Identity [1.067255084146299, 0.9379927931029655, 0.9056... <multidms.model.Model object at 0x7fed30275790> 1
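
Individual fits can be pulled straight out of this dataframe. For example, an illustrative sketch that grabs the replicate-0 model at the smallest lasso strength:

example_model = (
    models
    .query("replicate == 0")
    .sort_values("scale_coeff_lasso_shift")
    .iloc[0]["model_30000"]
)
type(example_model)  # multidms.model.Model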
Hide content
# set some global variables
# column holding the fit model objects
model_name = "model_30000"

# the number of times a mutation must be seen in each condition to be included in downstream analysis
times_seen_threshold = 1

# The chosen lasso strength for our final spike model results
chosen_lasso_strength = 5e-5

Plot the model loss over training steps, as provided by fit_wrapper.

Hide source
saveas="convergence_all_lasso_lines"
cmap=plt.get_cmap("Set2")


fig, ax = plt.subplots(1,figsize=[6.4,4.5])
color_idx = -1
for i, (idx, model_row) in enumerate(models.iterrows()):
    # consecutive rows are the two replicates at the same lasso strength,
    # so they share a color
    if i % 2 == 0:
        color_idx += 1

    ax.plot(
        [1000 * (s+1) for s in range(len(model_row.step_loss))],
        model_row.step_loss,
        c=cmap.colors[color_idx],
        lw=3,
        linestyle="-" if model_row.replicate == 0 else "--",
        label=f"rep: {model_row.replicate} scale_coeff: {model_row.scale_coeff_lasso_shift}"
    )

ticks = range(0, 30001, 5000)
labels = [f"{t//1000}K" for t in ticks]
ax.set_xticks(ticks, labels, rotation=0, ha='center')
ax.set_ylabel("Model Loss (w/o L1 penalty)")
ax.set_xlabel("Optimization iterations")

black_line = mlines.Line2D([], [], color='black', linestyle='-',
                          markersize=5, label='rep 1')
black_dashed = mlines.Line2D([], [], color='black',linestyle='--',
                          markersize=5, label='rep 2')
lasso_color_handles = [
    mlines.Line2D(
        [], [], 
        color=color, 
        linestyle='-',
        markersize=5,
        linewidth=3,
        label="$\lambda$: "+str(lasso)
    )
    for lasso, color in zip(models.scale_coeff_lasso_shift.unique(), cmap.colors)
]

elements = [black_line, black_dashed] + lasso_color_handles
ax.legend(handles=elements, bbox_to_anchor = (1, 1), loc='upper left', frameon=False, fontsize=9)
sns.despine(ax=ax)
plt.tight_layout()
fig.savefig(f"{OUTDIR}/{saveas}.pdf",bbox_inches='tight')
fig.savefig(f"{OUTDIR}/{saveas}.png",bbox_inches='tight')
plt.show()
_images/98859fe28b30001a65a77d1af840d635bc95f40c5f0b3074d1c923ad2a58c310.png
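
The final training loss for each fit can also be read off the last entry of its step_loss array; a small illustrative one-liner:

models.assign(
    final_loss=models.step_loss.apply(lambda losses: losses[-1])
)[["replicate", "scale_coeff_lasso_shift", "final_loss"]]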

Next, we’ll wrangle the results of all models into a tall-style dataframe where each row is a single mutation/lasso-strength combination.

Hide source
tall_mut_df = pd.DataFrame()
for replicate, rep_models in models.groupby("replicate"):
    fit_dict = {f"l_{float(row.scale_coeff_lasso_shift)}":row[model_name] for _, row in rep_models.iterrows()}
    
    mut_df = multidms.utils.combine_replicate_muts(
        fit_dict, 
        times_seen_threshold=times_seen_threshold, 
        how="outer"
    )
    mut_df.rename(
        {
            bc : f"{bc}_Omicron_BA1"
            for bc in mut_df.columns
            if "beta" in bc
        }, 
        axis=1, 
        inplace=True
    )

    mut_df = pd.melt(
        mut_df.reset_index(), 
        id_vars=["mutation"],
        value_vars=[
            col for col in mut_df.columns 
            if ("_shift_" in col or "beta" in col) and "avg" not in col
        ],
        value_name="S"
    )
    
    mut_df = mut_df.assign(
        scale_coeff_lasso=[ 
            v.split("_")[1]
            for v in mut_df.variable
        ],
        sense=[
            "stop" if "*" in mut else "nonsynonymous"
            for mut in mut_df.mutation
        ],
        condition=[
            "_".join(v.split("_")[3:])
            for v in mut_df.variable
        ],
        replicate=replicate
    )
    
    mut_df.drop("variable", axis=1, inplace=True)
    tall_mut_df = pd.concat([tall_mut_df, mut_df])
    
tall_mut_df.sample(10, random_state=23)
mutation S scale_coeff_lasso sense condition replicate
53706 G485D 0.598320 0.0 nonsynonymous Delta 0
58872 I233L -0.972775 1e-06 nonsynonymous Delta 0
109322 K1181N -0.516587 0.0 nonsynonymous Omicron_BA2 1
145884 F1075S 0.000000 0.0005 nonsynonymous Omicron_BA2 0
56057 G842C 0.425078 0.0 nonsynonymous Delta 1
93964 G880C -0.000000 0.0005 nonsynonymous Delta 0
152946 D848Y -0.000000 0.001 nonsynonymous Omicron_BA2 1
7504 Y170- -0.249903 1e-06 nonsynonymous Omicron_BA1 1
109112 F86S -0.741166 1e-06 nonsynonymous Omicron_BA2 0
31092 F1109Y 0.225376 5e-05 nonsynonymous Omicron_BA1 0
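
As a quick, illustrative sanity check on the reshaping, the number of rows in each lasso strength/condition/replicate group should equal the number of mutations that pass the times-seen threshold:

tall_mut_df.groupby(["scale_coeff_lasso", "condition", "replicate"]).size().head()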

Compute summary stats of each model at each lasso strength.

Hide content
def loss(model, condition):
    # compute the training-data loss for a single condition,
    # without any regularization penalty
    data = (
        {condition: model.data.training_data["X"][condition]},
        {condition: model.data.training_data["y"][condition]}
    )
    return jax.jit(model._model_components["objective"])(model.params, data)

lasso_sparsity_loss = defaultdict(list)
for lasso, lasso_replicates in models.groupby("scale_coeff_lasso_shift"):
    lasso_sparsity_loss["lasso"].append(str(lasso))
    fit_dict = {}
    for _, row in lasso_replicates.iterrows():
        model = row[model_name]
        replicate = row.replicate
        fit_dict[f"{replicate}"] = model

    muts_df_outer = multidms.utils.combine_replicate_muts(
        fit_dict, 
        times_seen_threshold=times_seen_threshold, 
        how="outer"
    )
    muts_df_outer = muts_df_outer.assign(
        sense = [
            "stop" if "*" in mut else "nonsynonymous"
            for mut in muts_df_outer.index.values
        ]
    )

    muts_df_inner = muts_df_outer.dropna()
    for cond in model.data.conditions:
        if cond == model.data.reference:
            # if either replicate's effects are essentially all zero,
            # the replicate correlation is undefined, so record NaN
            if np.all(muts_df_inner["0_beta"] <= 0.05) or np.all(muts_df_inner["1_beta"] <= 0.05):
                lasso_sparsity_loss[f"{cond}-replicate-correlation"].append(np.nan)
            else:
                corr = pearsonr(muts_df_inner[f"0_beta"], muts_df_inner[f"1_beta"])
                r = corr[0]
                lasso_sparsity_loss[f"{cond}-replicate-correlation"].append(r)

            for rep, marker in zip([0,1], ["o", "D"]):
                rep_cond_shifts = muts_df_outer[[f"{rep}_beta", "sense"]].dropna()
                for sense, color in zip(["stop", "nonsynonymous"], ["red", "blue"]):
                    shifts = rep_cond_shifts.query("sense == @sense")[f"{rep}_beta"]
                    sparsity = (len(shifts[shifts==0]) / len(shifts))*100
                    lasso_sparsity_loss[f"{cond}-{rep}-{sense}-sparsity"].append(sparsity)

        else:
            if np.all(muts_df_inner[f"0_shift_{cond}"] <= 0.05) or np.all(muts_df_inner[f"1_shift_{cond}"] <= 0.05):
                lasso_sparsity_loss[f"{cond}-replicate-correlation"].append(np.nan)
            else:
                corr = pearsonr(muts_df_inner[f"0_shift_{cond}"], muts_df_inner[f"1_shift_{cond}"])
                r = corr[0]
                r = np.nan if np.isclose(r, 1.0) else r
                lasso_sparsity_loss[f"{cond}-replicate-correlation"].append(r)

            for rep, marker in zip([0,1], ["o", "D"]):
                rep_cond_shifts = muts_df_outer[[f"{rep}_shift_{cond}", "sense"]].dropna()
                for sense, color in zip(["stop", "nonsynonymous"], ["red", "blue"]):
                    shifts = rep_cond_shifts.query("sense == @sense")[f"{rep}_shift_{cond}"]
                    sparsity = (len(shifts[shifts==0]) / len(shifts))*100
                    lasso_sparsity_loss[f"{cond}-{rep}-{sense}-sparsity"].append(sparsity)
            
    for _, row in lasso_replicates.iterrows():
        model = row[model_name]
        rep = row.replicate
        for cond in model.data.conditions:
            lasso_sparsity_loss[f"{cond}-{rep}-loss"].append(loss(model, cond))

lasso_sparsity_loss = pd.DataFrame(lasso_sparsity_loss)
lasso_sparsity_loss
lasso Delta-replicate-correlation Delta-0-stop-sparsity Delta-0-nonsynonymous-sparsity Delta-1-stop-sparsity Delta-1-nonsynonymous-sparsity Omicron_BA1-replicate-correlation Omicron_BA1-0-stop-sparsity Omicron_BA1-0-nonsynonymous-sparsity Omicron_BA1-1-stop-sparsity ... Omicron_BA2-0-stop-sparsity Omicron_BA2-0-nonsynonymous-sparsity Omicron_BA2-1-stop-sparsity Omicron_BA2-1-nonsynonymous-sparsity Delta-0-loss Omicron_BA1-0-loss Omicron_BA2-0-loss Delta-1-loss Omicron_BA1-1-loss Omicron_BA2-1-loss
0 0.0 0.546623 0.000000 0.000000 0.000000 0.000000 0.927772 0.0 0.0 0.0 ... 0.000000 0.000000 0.000000 0.000000 0.20909000580973827 0.22839310471511234 0.18636279787848137 0.24426963039574864 0.18240780094215583 0.17515822800558498
1 1e-06 0.559236 1.604278 1.064688 3.791469 1.153107 0.928948 0.0 0.0 0.0 ... 3.743316 2.274560 3.317536 2.274183 0.20917665394134038 0.22596427021589502 0.1864740986183709 0.24575726872194636 0.18343952479741535 0.17608967670574516
2 1e-05 0.657253 43.850267 14.566866 44.075829 17.120436 0.929408 0.0 0.0 0.0 ... 57.754011 22.471366 59.715640 23.398463 0.21720514216870665 0.2338510968877193 0.1934875377057383 0.25515255109651935 0.18879111257186482 0.18289102765146906
3 2e-05 0.687851 80.748663 25.504114 77.725118 28.219090 0.924362 0.0 0.0 0.0 ... 87.700535 39.425714 91.943128 39.509930 0.2246265603265015 0.24097466572774723 0.19974765745469986 0.2642087677227236 0.19616986638389988 0.18911643045941567
4 5e-05 0.697926 97.860963 48.281981 96.682464 50.416400 0.914043 0.0 0.0 0.0 ... 99.465241 64.462010 99.526066 62.796284 0.24265752051095138 0.2598475229546153 0.21163418285767038 0.28439700602671786 0.2162986939022897 0.20023389068131223
5 0.0001 0.672756 100.000000 92.030973 100.000000 87.363869 0.913354 0.0 0.0 0.0 ... 100.000000 93.950637 100.000000 90.486867 0.2782585505749827 0.29664621933526747 0.2298513717483986 0.31238377515420235 0.248831860960126 0.21675760973268016
6 0.0005 NaN 100.000000 100.000000 100.000000 100.000000 0.910521 0.0 0.0 0.0 ... 100.000000 99.983868 100.000000 99.983985 0.29449662651546876 0.30750704920451805 0.24369163268190178 0.33535281397087024 0.26575349005736537 0.23706711609230102
7 0.001 NaN 100.000000 100.000000 100.000000 100.000000 0.908683 0.0 0.0 0.0 ... 100.000000 100.000000 100.000000 100.000000 0.2940877706231391 0.30689114259362715 0.24373089959003352 0.3353392198202194 0.26477229982899697 0.23722332291693782

8 rows × 22 columns
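
For convenience, the summary row at the chosen lasso strength can be pulled out directly. Note the lasso column holds string labels, so we match against str(chosen_lasso_strength) (an illustrative sketch):

chosen_row = lasso_sparsity_loss[
    lasso_sparsity_loss.lasso == str(chosen_lasso_strength)
]
chosen_row.filter(like="correlation")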

Plot the lasso shrinkage analysis figure.

Hide source
saveas="shrinkage_analysis_trace_plots_beta"
fig, ax = plt.subplots(
    4,4, 
    figsize=[6.4,7],
    sharex=True, 
    gridspec_kw={
        'width_ratios': [1,0.35,1,1]
    }
)

lasso_cmap=plt.get_cmap("Set2").colors
lasso_shrinkage = sorted(models.scale_coeff_lasso_shift.unique())
lasso_shrinkage_cmap = dict(zip(lasso_shrinkage, lasso_cmap))

mutations_to_highlight = ["D142L", "A419S", "A570D", "K854N", "T1027I"]
mutations_cmap=plt.get_cmap("Accent").colors
mutations_cmap = dict(zip(mutations_to_highlight, mutations_cmap))

condition_col = {
    "Omicron_BA1" : 0,
    "Delta" : 2,
    "Omicron_BA2" : 3
}

replicate_line_style = {
    0 : "-",
    1 : "--"
}

replicate_marker = {
    0 : "o",
    1 : "D"
}

sense_colors = {
    "nonsynonymous" : "darkgrey",
    "stop" : "red"
}

sense_lw = {
    "nonsynonymous" : 0.5,
    "stop" : 0.1
}

sense_alpha = {
    "nonsynonymous" : 0.1,
    "stop" : 0.5
}

model_choice = chosen_lasso_strength
model_line_kwargs = {
    "linewidth" : 15,
    "color" : "grey",
    "alpha" : 0.1
}

for (condition, replicate), df in tall_mut_df.groupby(["condition", "replicate"]):
    row = 0
    iter_ax = ax[row, condition_col[condition]]
    sns.despine(ax=iter_ax)
    
    # plot nonsynonymous, non validated
    for mut, trace_df in df.query(
        "sense == 'nonsynonymous' & not mutation.isin(@mutations_to_highlight)"
    ).groupby("mutation"):
        iter_ax.plot(
            trace_df.scale_coeff_lasso, 
            trace_df.S,
            linestyle=replicate_line_style[replicate],
            linewidth=sense_lw['nonsynonymous'],
            alpha=sense_alpha['nonsynonymous'],
            color="lightgrey"
        )
    
    # plot stop traces
    for mut, trace_df in df.query("sense == 'stop'").groupby("mutation"):

        iter_ax.plot(
            trace_df.scale_coeff_lasso, 
            trace_df.S,
            linestyle=replicate_line_style[replicate],
            linewidth=sense_lw['stop'],
            alpha=sense_alpha['stop'],
            color=sense_colors['stop']
        )

    # plot highlighted muts
    for mut, trace_df in df.query(
        "mutation.isin(@mutations_to_highlight)"
    ).groupby("mutation"):
        iter_ax.plot(
            trace_df.scale_coeff_lasso, 
            trace_df.S,
            linestyle=replicate_line_style[replicate],
            linewidth=2,
            alpha=1.0,
            color=mutations_cmap[mut]
        )
    iter_ax.axvline(str(model_choice), **model_line_kwargs)
    
        
    # Plot sparsity    
    row = 1
    iter_ax = ax[row, condition_col[condition]]
    sns.despine(ax=iter_ax)
    
    for sense in ["nonsynonymous", "stop"]:
        iter_ax.plot(
            lasso_sparsity_loss["lasso"],
            lasso_sparsity_loss[f"{condition}-{replicate}-{sense}-sparsity"],
            linestyle=replicate_line_style[replicate],
            linewidth=2,
            alpha=0.5,
            color=sense_colors[sense]
        )
        
        iter_ax.scatter(
            lasso_sparsity_loss["lasso"],
            lasso_sparsity_loss[f"{condition}-{replicate}-{sense}-sparsity"],
            marker=replicate_marker[replicate],
            alpha=0.5,
            color="black"
        )
    iter_ax.axvline(str(model_choice), **model_line_kwargs)
        
    # Plot Loss  
    row = 2
    iter_ax = ax[row, condition_col[condition]]
    sns.despine(ax=iter_ax)
    
    
    iter_ax.plot(
        lasso_sparsity_loss["lasso"],
        lasso_sparsity_loss[f"{condition}-{replicate}-loss"],
        linestyle=replicate_line_style[replicate],
        linewidth=2,
        alpha=1.0,
        color="darkgrey"
    )

    iter_ax.scatter(
        lasso_sparsity_loss["lasso"],
        lasso_sparsity_loss[f"{condition}-{replicate}-loss"],
        marker=replicate_marker[replicate],
        alpha=0.7,
        color="black"
    )
    iter_ax.axvline(str(model_choice), **model_line_kwargs)
    
    # Plot Correlation  
    row = 3
    iter_ax = ax[row, condition_col[condition]]
    sns.despine(ax=iter_ax)
    
    iter_ax.plot(
        lasso_sparsity_loss["lasso"],
        lasso_sparsity_loss[f"{condition}-replicate-correlation"],
        linestyle="-",
        linewidth=2,
        alpha=1.0,
        color="darkgrey"
    )
    
    iter_ax.scatter(
        lasso_sparsity_loss["lasso"],
        lasso_sparsity_loss[f"{condition}-replicate-correlation"],
        marker="X",
        alpha=0.7,
        color="black"
    )
    iter_ax.axvline(str(model_choice), **model_line_kwargs)



ax[0, 0].set_title(r"BA.1", size=11)
ax[0, 2].set_title("Delta", size=11)
ax[0, 3].set_title("BA.2", size=11)

ax[0, 1].set_visible(False)
ax[1, 1].set_visible(False)
ax[2, 1].set_visible(False)
ax[3, 1].set_visible(False)

ax[0, 0].set_ylabel(r"mut. effect ($\beta_{m}$)", size=11)
ax[1, 0].set_ylabel("sparsity\n(% $\\beta_{m} = 0$)", size=11)
ax[2, 0].set_ylabel("Huber loss\nw/o penalty", size=11)
ax[3, 0].set_ylabel("replicate\nmut. effect\ncorrelation", size=11)

ax[0, 2].set_ylabel(r"shift $(\Delta_{d,m})$", size=11)
ax[1, 2].set_ylabel("sparsity\n(% $\\Delta_{d,m} = 0$)", size=11)
ax[2, 2].set_ylabel("Huber loss\nw/o penalty", size=11)
ax[3, 2].set_ylabel("replicate shift\ncorrelation", size=11)

# build custom legend handles
black_line = mlines.Line2D([], [], color='black', linestyle='-',
                          markersize=5, label='rep 1')
black_dashed = mlines.Line2D([], [], color='black',linestyle='--',
                          markersize=5, label='rep 2')
lasso_color_handles = [
    mlines.Line2D(
        [], [], 
        color=color, 
        linestyle='-',
        markersize=2,
        linewidth=2,
        label=mut
    )
    for mut, color in mutations_cmap.items()
]

elements = lasso_color_handles + [black_line, black_dashed] 
ax[0,3].legend(handles=elements, bbox_to_anchor = (1, 1), loc='upper left', frameon=False, fontsize=9)

black_circle = mlines.Line2D([], [], color='black', marker='o', linestyle='None',
                          markersize=5, label='rep 1')
black_triangle = mlines.Line2D([], [], color='black', marker='D', linestyle='None',
                          markersize=5, label='rep 2')

red_line = mlines.Line2D([], [], color='red', linewidth=2,linestyle='-',markersize=5, label='stop muts')
grey_line = mlines.Line2D([], [], color='grey',linewidth=2, linestyle='-',markersize=5, label='nonsynonymous\nmuts')

elements = [black_circle, black_triangle, red_line, grey_line] #+lasso_color_handles
ax[1, 3].legend(handles=elements, bbox_to_anchor = (1, 1), loc='upper left', frameon=False, fontsize=9)

ax[3,0].set_xticks(ax[3,0].get_xticks(), ax[3,0].get_xticklabels(), rotation=90, ha='center')
ax[3,2].set_xticks(ax[3,2].get_xticks(), ax[3,2].get_xticklabels(), rotation=90, ha='center')
ax[3,3].set_xticks(ax[3,3].get_xticks(), ax[3,3].get_xticklabels(), rotation=90, ha='center')

ax[3,2].set_xlabel(r"lasso regularization strength ($\lambda$)")
ax[3,2].xaxis.set_label_coords(0.4, -0.6)

ax[0,0].text(
    -0.55, 1.00, 
    f"A", 
    ha="right", va="center", 
    size=15,
    weight="bold",
    transform=ax[0,0].transAxes
)
ax[1,0].text(
    -0.55, 1.00, 
    f"B", 
    ha="right", va="center", 
    size=15,
    weight="bold",
    transform=ax[1,0].transAxes
)
ax[2,0].text(
    -0.55, 1.00, 
    f"C", 
    ha="right", va="center", 
    size=15,
    weight="bold",
    transform=ax[2,0].transAxes
)
ax[3,0].text(
    -0.55, 1.00, 
    f"D", 
    ha="right", va="center", 
    size=15,
    weight="bold",
    transform=ax[3,0].transAxes
)

ax[0, 2].set_yticks([-2.5, 0, 2.5], [-2.5, 0, 2.5])
ax[0, 2].set_ylim([-3.0, 4.5])
ax[0, 3].set_yticks([-2.5, 0, 2.5], [-2.5, 0, 2.5])
ax[0, 3].set_ylim([-3.0, 4.5])
ax[0, 3].yaxis.set_tick_params(labelleft=False)

ax[1, 2].set_yticks([0, 50, 100], [0, 50, 100])
ax[1, 2].set_ylim([-5, 105])
ax[1, 3].set_yticks([0, 50, 100], [0, 50, 100])
ax[1, 3].set_ylim([-5, 105])
ax[1, 3].yaxis.set_tick_params(labelleft=False)

ax[2, 2].set_yticks([.1, .2, .3], [.1, .2, .3])
ax[2, 2].set_ylim([.15, .35])
ax[2, 3].set_yticks([.1, .2, .3], [.1, .2, .3])
ax[2, 3].set_ylim([.15, .35])
ax[2, 3].yaxis.set_tick_params(labelleft=False)

ax[3, 2].set_yticks([.55, .65, .75], [.55, .65, .75])
ax[3, 2].set_ylim([.51, .78])
ax[3, 3].set_yticks([.55, .65, .75], [.55, .65, .75])
ax[3, 3].set_ylim([.51, .78])
ax[3, 3].yaxis.set_tick_params(labelleft=False)

ax[3, 0].set_yticks([.75, .85, .95], [.75, .85, .95])

# plt.tight_layout()
fig.savefig(f"{OUTDIR}/{saveas}.pdf",bbox_inches='tight')
fig.savefig(f"{OUTDIR}/{saveas}.png",bbox_inches='tight')
plt.show()
_images/39e9fdbb9387b592b6a4b278bb557de2bc904a535347bd5f4406421ace8b1882.png

Global epistasis fits

Here, we take a look at the fit of the sigmoidal global epistasis function (at the chosen lasso strength of 5e-5) to the data.

For each replicate at the chosen lasso strength, we get the training data predictions using model.get_variants_df, and pair model.get_condition_params with model.model_components to visualize the global epistasis function under the fitted parameters. See each function’s docstring for the relevant details.
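
Schematically, the sigmoid maps each variant’s latent phenotype to a predicted functional score through a scaled logistic curve whose flat tails capture assay saturation. The sketch below is purely illustrative; the parameter names are hypothetical and not the exact multidms internals.

def sigmoid_global_epistasis(phi, theta_scale, theta_bias):
    # illustrative scaled-logistic global epistasis function;
    # parameter names are hypothetical, see the multidms docs for internals
    return theta_bias + theta_scale / (1 + np.exp(-phi))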

chosen_replicate_models = models.query("scale_coeff_lasso_shift == @chosen_lasso_strength")
replicate_data = {}
for row_idx, replicate_row in chosen_replicate_models.iterrows():
    model = replicate_row[model_name]

    # get the training data variants and their respective phenotype predictions
    df = model.get_variants_df(phenotype_as_effect=False)

    # find the low/high bound of the training data and use those to make
    # global epistasis predictions across the range for plotting
    xlb, xub = [-1, 1] + np.quantile(df.predicted_latent, [0.05, 1.0])
    additive_model_grid = np.linspace(xlb, xub, num=1000)

    # make predictions on hypothetical points between the lower and upper bounds
    current_params = model.get_condition_params(model.data.reference)
    latent_preds = model.model_components["g"](current_params["theta"], additive_model_grid)
    shape = (additive_model_grid, latent_preds)   

    # save and organize the data for plotting
    replicate_data[replicate_row.replicate] = {
        "variants_df" : df,
        "wildtype_df" : model.wildtype_df,
        "epistasis_shape" : shape,
        "condition_colors" : model.data.condition_colors
    }

Plot the observed functional scores of a random sample of variants (20% in the top row, 10% in the bottom) as a function of the predicted latent phenotype (top) and the predicted functional score (bottom).

Hide source
saveas="global_epistasis_and_prediction_correlations"
fig, ax = plt.subplots(2,2, figsize=[6.4,6], sharey='row')    

row=0
for replicate, data in replicate_data.items():

    iter_ax = ax[row, replicate]
    sns.scatterplot(
        data=data["variants_df"].sample(frac=0.2),
        x="predicted_latent",
        y=f"func_score",
        hue="condition",
        palette=model.data.condition_colors,
        ax=iter_ax,
        legend=False,
        size=5,
        alpha=0.1,
        lw=3
    )
    
    for condition, values in data["wildtype_df"].iterrows():
        iter_ax.axvline(
            values.predicted_latent,
            label=condition,
            c=model.data.condition_colors[condition],
            lw=2,
        )
    
    iter_ax.plot(*data["epistasis_shape"], color="k", lw=2)
    
    xlb, xub = [-1, 1] + np.quantile(data["variants_df"].predicted_latent, [0.05, 1.0])
    ylb, yub = [-1, 1] + np.quantile(data["variants_df"].func_score, [0.05, 1.0])
    iter_ax.set_xlim([xlb, xub])
    iter_ax.set_ylim([ylb, yub])
    iter_ax.set_title(f"replicate {replicate+1}")
    iter_ax.set_ylabel("observed\nfunctional score")
    iter_ax.set_xlabel(r"predicted latent phenotype ($\phi$)")

row=1
for replicate, data in replicate_data.items():

    iter_ax = ax[row, replicate]
    sns.scatterplot(
        data=data["variants_df"].sample(frac=0.1),
        x="predicted_func_score",
        y=f"func_score",
        hue="condition",
        palette=model.data.condition_colors,
        ax=iter_ax,
        legend=False,
        size=5,
        alpha=0.1
    )
    
    for condition, values in data["wildtype_df"].iterrows():
        iter_ax.axvline(
            values.predicted_func_score,  # wildtype position on the predicted functional score axis
            label=condition,
            c=model.data.condition_colors[condition],
            lw=2,
        )
    
    iter_ax.set_ylabel("observed\nfunctional score")
    iter_ax.set_xlabel("predicted functional score")

    start_y = 0.9
    for c, cdf in data["variants_df"].groupby("condition"):
        r = pearsonr(
            cdf["predicted_func_score"],
            cdf["func_score"]
        )[0]
        iter_ax.annotate(
            f"$r = {r:.2f}$",
            (0.1, start_y),
            xycoords="axes fraction",
            fontsize=12,
            c=model.data.condition_colors[c],
        )
        start_y += -0.1


elements = [
    mlines.Line2D([], [], color=color, marker='o', linestyle='None',markersize=5, label=condition)
    for condition, color in replicate_data[0]["condition_colors"].items()
]


ax[0, 0].legend(
    handles=elements, 
    bbox_to_anchor = (0., .99), 
    loc='upper left', 
    frameon=False, 
    fontsize=9
)
    
    
plt.tight_layout()
fig.subplots_adjust(wspace=0.05)

ax[0,0].text(
    -0.1, 1.00, 
    f"A", 
    ha="right", va="center", 
    size=15,
    weight="bold",
    transform=ax[0,0].transAxes
)
ax[1,0].text(
    -0.1, 1.00, 
    f"B", 
    ha="right", va="center", 
    size=15,
    weight="bold",
    transform=ax[1,0].transAxes
)


fig.savefig(f"{OUTDIR}/{saveas}.pdf",bbox_inches='tight')
fig.savefig(f"{OUTDIR}/{saveas}.png",bbox_inches='tight')
plt.show()
_images/b5d3180354d80f3fb6f5ee5718f83c271a03b4a6e9ff53a9fd790b8552a743cf.png

Shifted mutations (interactive altair chart)

The easiest way to view shifted mutations is to create an interactive altair chart using multidms.plot.mut_shift_plot. This function can take a single model, or a dictionary of models if you want to visualize the aggregated (mean) results for mutations shared between models. Toggle the drop-down for the cell below to see details on using this function.

help(multidms.plot.mut_shift_plot)
Hide output
Help on function mut_shift_plot in module multidms.plot:

mut_shift_plot(fit_data, biochem_order_aas=True, times_seen_threshold=3, include_beta=True, **kwargs)
    Make plot of mutation escape values for one or more replicate fits.
    You may either pass a single `multidms.Data` object for
    visualizing a single set of parameters or a collection of replicate
    fits in the form of a dictionary where key's are the replicate
    name (i.e. rep1, rep2) and values are the respective model objects.
    
    Parameters
    ----------
    fit_data : multidms.Data or dict
        Either a single `multidms.Data` object or a dictionary of
        replicate fits where the keys are the replicate names and the values
        are the respective model objects.
    biochem_order_aas : bool
        Biochemically order amino-acid alphabet :attr:`PolyclonalCollection.alphabet`
        by passing it through :func:`polyclonal.alphabets.biochem_order_aas`.
    times_seen_threshold : int
        Set a threshold for the number of genetic backgrounds each mutant must be seen
        within each condition in order to be used in the visualization.
    include_beta : bool
        If True, include the beta values as another category in the figure.
        If False, only include beta's in the tooltip.
    **kwargs
        Keyword args for :func:`polyclonal.plot.lineplot_and_heatmap`
    
    Returns
    -------
    altair.Chart
        Interactive heat maps.

Here, we create the interactive chart by feeding the function a dictionary containing the two replicate models, and specifying:

  1. times_seen_threshold = 1, meaning a mutation must be seen in at least one variant background in every condition to be included

  2. include_beta = False, since we only wish to visualize the shift parameters, not the respective effect (beta) parameters (note that the effect values are still shown in the tooltip when hovering over any shift mutation).

To view the chart, toggle the output of the cell below.

chart = multidms.plot.mut_shift_plot(
    {
        f"rep_{row.replicate}":row[model_name] 
        for idx, row in chosen_replicate_models.iterrows()
    },
    include_beta = False,
    times_seen_threshold=1
)
chart.save(f"{OUTDIR}/interactive_chart_wo_beta.html")
chart
Hide output