Spike Analysis

This notebook contains the complete empirical analysis of the three spike homologs presented in our manuscript Jointly modeling deep mutational scans identifies shifted mutational effects among SARS-CoV-2 spike homologs.

Computational platform environment

This section records the attributes of the machine that ran this notebook and imports the necessary dependencies.

Operating system

Hide content
! grep -E '^(VERSION|NAME)=' /etc/os-release
NAME="Ubuntu"
VERSION="18.04.6 LTS (Bionic Beaver)"

Hardware (Processors and RAM)

Hide content
! lshw -class memory -class processor
WARNING: you should run this program as super-user.
PCI (sysfs)  
  *-memory                  
       description: System memory
       physical id: 0
       size: 996GiB
  *-cpu
       product: AMD EPYC 75F3 32-Core Processor
       vendor: Advanced Micro Devices [AMD]
       physical id: 1
       bus info: cpu@0
       size: 3855MHz
       width: 64 bits
       capabilities: fpu fpu_exception wp vme de pse tsc msr pae mce cx8 apic sep mtrr pge mca cmov pat pse36 clflush mmx fxsr sse sse2 ht syscall nx mmxext fxsr_opt pdpe1gb rdtscp x86-64 constant_tsc rep_good nopl xtopology nonstop_tsc cpuid extd_apicid aperfmperf pni pclmulqdq monitor ssse3 fma cx16 pcid sse4_1 sse4_2 movbe popcnt aes xsave avx f16c rdrand lahf_lm cmp_legacy svm extapic cr8_legacy abm sse4a misalignsse 3dnowprefetch osvw ibs skinit wdt tce topoext perfctr_core perfctr_nb bpext perfctr_llc mwaitx cpb cat_l3 cdp_l3 invpcid_single hw_pstate ssbd ibrs ibpb stibp vmmcall fsgsbase bmi1 avx2 smep bmi2 invpcid cqm rdt_a rdseed adx smap clflushopt clwb sha_ni xsaveopt xsavec xgetbv1 xsaves cqm_llc cqm_occup_llc cqm_mbm_total cqm_mbm_local clzero irperf xsaveerptr arat npt lbrv svm_lock nrip_save tsc_scale vmcb_clean flushbyasid decodeassists pausefilter pfthreshold v_vmsave_vmload vgif umip pku ospke vaes vpclmulqdq rdpid overflow_recov succor smca cpufreq
WARNING: output may be incomplete or inaccurate, you should run this program as super-user.

GPUs

Hide content
%env CUDA_VISIBLE_DEVICES=0
! nvidia-smi -L
env: CUDA_VISIBLE_DEVICES=0
GPU 0: NVIDIA A100 80GB PCIe (UUID: GPU-414cb1bd-372a-4926-b140-b734687c927f)
GPU 1: NVIDIA A100 80GB PCIe (UUID: GPU-e54c2054-5be3-ebd0-e22e-b98441ec664f)
Hide content
# built-in libraries
import os
import sys
from itertools import combinations
from collections import defaultdict
import time
import pprint
import copy
import pickle
from functools import reduce
import shutil

# external dependencies
import pandas as pd
import seaborn as sns
from scipy.stats import pearsonr
from matplotlib.lines import Line2D
import matplotlib.lines as mlines
import matplotlib.pyplot as plt
from matplotlib.transforms import (
    Bbox, TransformedBbox, blended_transform_factory)
from mpl_toolkits.axes_grid1.inset_locator import (
    BboxPatch, BboxConnector, BboxConnectorPatch)
import matplotlib.patches as patches
import matplotlib.colors as colors
import numpy as np
import scipy
from tqdm.notebook import tqdm
import jax
import jax.numpy as jnp
from Bio.PDB.PDBParser import PDBParser
from Bio.PDB.PDBList import PDBList
from Bio.PDB.DSSP import DSSP
from Bio import SeqIO
import multidms
%matplotlib inline

This notebook was run with the following multidms version.

multidms.__version__
'0.4.0'
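
If re-running this notebook, you may want to pin the same version (assuming a standard PyPI install):

! pip install multidms==0.4.0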

Set the papermill parameters that define notebook behavior.

output_dir = 'results/spike_analysis'
# define the fitting parameters
scale_coeff_lasso_shift = [0.0, 5.00e-6, 1.00e-05, 2.00e-05, 4.00e-05, 8.00e-05, 1.60e-04, 3.20e-04, 6.40e-04] # the sweep of lasso coefficient params
alpha_d=True
scale_coeff_ridge_alpha_d=1e-3
num_training_steps = 30
iterations_per_step = 1000 # default 20000
# init_beta_naught = 5.0 # We've found that we need to start with a higher beta_naught to get the model to converge correctly,
scale_coeff_ridge_beta = 0.0 # ridge coefficient applied to the beta (mutation effect) params
train_frac = 0.8 # fraction of data to use for cross validation training.
lasso_choice = 4.00e-05 # the lasso coefficient to use for the final model
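
When executing this notebook with papermill, the defaults above can be overridden at run time, e.g. (hypothetical notebook filenames):

! papermill spike_analysis.ipynb spike_analysis_out.ipynb -p lasso_choice 8.0e-05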

Set some global configurations for plotting.

Hide content
os.makedirs(output_dir, exist_ok=True)

rc_kwargs = {
    'legend.frameon': False,
    "font.size" : 11,
    "font.weight" : "normal"
}

plt.rcParams.update(**rc_kwargs)

Input Data

Load and organize the functional score DMS data.

We begin with 16 individual sets of barcoded variants and their associated pre-computed functional scores. Each set derives from a single DMS experiment using Delta, Omicron BA.1, or Omicron BA.2 as the experimental wildtype. First, we parse the filenames to extract the experimental attributes of each dataset, storing the individual datasets as nested pd.DataFrames.

Hide source
func_score_data = pd.DataFrame()

for homolog in ["Delta", "Omicron_BA1", "Omicron_BA2"]:
    
    # functional scores
    func_sel = (
        pd.read_csv(f"data/{homolog}/functional_selections.csv")
        .assign(
            filename = lambda x: f"data/{homolog}/" + 
            x.library + "_" + 
            x.preselection_sample + 
            "_vs_" + x.postselection_sample + 
            "_func_scores.csv"
        )
        .assign(
            func_sel_scores_df = lambda x: x.filename.apply(
                lambda f: pd.read_csv(f)
            )   
        )
        .assign(
            len_func_sel_scores_df = lambda x: x.func_sel_scores_df.apply(
                lambda x: len(x)
            )
        )
        .assign(homolog = homolog)
    )
    func_score_data = pd.concat([func_score_data, func_sel]).reset_index(drop=True)

# Add a column that gives a unique ID to each homolog/DMS experiment
func_score_data['condition'] = func_score_data.apply(
    lambda row: f"{row['homolog']}-{row['library']}".replace('-Lib',''),
    axis=1
)
func_score_data[['library', 'replicate', 'filename', 'condition']]
library replicate filename condition
0 Lib-1 1 data/Delta/Lib-1_2021-10-28_thaw-1_VSVG_contro... Delta-1
1 Lib-1 2 data/Delta/Lib-1_2021-10-28_thaw-1_VSVG_contro... Delta-1
2 Lib-3 1 data/Delta/Lib-3_2021-10-28_thaw-1_VSVG_contro... Delta-3
3 Lib-3 2 data/Delta/Lib-3_2021-10-28_thaw-1_VSVG_contro... Delta-3
4 Lib-4 1 data/Delta/Lib-4_2021-10-28_thaw-1_VSVG_contro... Delta-4
5 Lib-4 2 data/Delta/Lib-4_2021-10-28_thaw-1_VSVG_contro... Delta-4
6 Lib-2 1 data/Delta/Lib-2_2021-10-28_thaw-1_VSVG_contro... Delta-2
7 Lib-2 2 data/Delta/Lib-2_2021-10-28_thaw-1_VSVG_contro... Delta-2
8 Lib-1 1 data/Omicron_BA1/Lib-1_2022-03-25_thaw-1_VSVG_... Omicron_BA1-1
9 Lib-1 2 data/Omicron_BA1/Lib-1_2022-03-25_thaw-1_VSVG_... Omicron_BA1-1
10 Lib-2 1 data/Omicron_BA1/Lib-2_2022-06-22_thaw-1_VSVG_... Omicron_BA1-2
11 Lib-3 1 data/Omicron_BA1/Lib-3_2022-06-22_thaw-1_VSVG_... Omicron_BA1-3
12 Lib-1 1 data/Omicron_BA2/Lib-1_2022-10-22_thaw-1_VSVG_... Omicron_BA2-1
13 Lib-2 1 data/Omicron_BA2/Lib-2_2022-10-22_thaw-1_VSVG_... Omicron_BA2-2
14 Lib-1 2 data/Omicron_BA2/Lib-1_2022-10-22_thaw-2_VSVG_... Omicron_BA2-1
15 Lib-2 2 data/Omicron_BA2/Lib-2_2022-10-22_thaw-2_VSVG_... Omicron_BA2-2
Hide source
avail_cond_str = '\n- '.join(list(func_score_data.condition.unique()))
print(f"Available conditions for fitting are:\n- {avail_cond_str}")
Available conditions for fitting are:
- Delta-1
- Delta-3
- Delta-4
- Delta-2
- Omicron_BA1-1
- Omicron_BA1-2
- Omicron_BA1-3
- Omicron_BA2-1
- Omicron_BA2-2

Concatenate each of the individual experiments, keeping track of the library and homolog of each. Output the notable features for a random sample of 10 variants.

Hide source
func_score_df = pd.DataFrame()
for idx, row in tqdm(func_score_data.iterrows(), total=len(func_score_data)):
    mut_df_replicates = row.func_sel_scores_df.assign(
        homolog=row.homolog,
        library = row.library,
        replicate = row.replicate,
        condition=row.condition
    )
    func_score_df = pd.concat([func_score_df, mut_df_replicates])

# rename, sort index, and fill na (wildtype values) with empty strings
func_score_df = (func_score_df
    .rename(
        {"aa_substitutions_reference":"aa_substitutions"}, 
        axis=1
    )
    .reset_index(drop=True)
    .fillna("")
    .sort_values(by="condition")
)
func_score_df[["library", "barcode", "aa_substitutions", "func_score", "condition"]].sample(10, random_state=0)
library barcode aa_substitutions func_score condition
721925 Lib-3 CGTAAAGTTCCAACAA G769R D950F R1107M N1192S -2.1765 Omicron_BA1-3
239549 Lib-4 AATAATTTTCCTACAC -2.0202 Delta-4
816259 Lib-1 GATGATACCAAACTAT K814T L1024I E1207K -2.1526 Omicron_BA2-1
612980 Lib-2 AAATATCCTACAAGAA C738Y A890T H1058Y -9.0995 Omicron_BA1-2
368871 Lib-1 TAATACCGAATCCCCC A893V S939D A1078T -4.1550 Omicron_BA1-1
1115330 Lib-2 GTATACATGTATGATG S71L D1163E S1242N 0.2762 Omicron_BA2-2
410949 Lib-1 GCATTACTACAAATAA N960K 0.6777 Omicron_BA1-1
971589 Lib-1 CAATATAGCATAGAGG R78L 0.1378 Omicron_BA2-1
592643 Lib-2 ACAAGCTTTGCAACAA Y200H 1.3313 Omicron_BA1-2
381265 Lib-1 CTAGTCTCCGACAAAA F347S D627G I850L -6.8325 Omicron_BA1-1

Discard all variants with a pre-selection count below 100.

Hide source
n_pre_threshold = len(func_score_df)
func_score_df.query("pre_count >= 100", inplace=True)
print(f"Of {n_pre_threshold} variants, {n_pre_threshold - len(func_score_df)} had fewer than the threshold of counts before selection, and were filtered out")
Of 1135096 variants, 120164 had fewer than the threshold of counts before selection, and were filtered out

We only require the functional score, amino-acid substitutions, and condition columns for instantiating the multidms.Data object; drop the rest.

Hide source
required_cols = ['func_score', 'aa_substitutions', 'condition']
func_score_df.drop([c for c in func_score_df if c not in required_cols], axis=1, inplace=True)

Remove all variants with string-suffixed sites (indels) and stop codon wildtypes.

Hide source
stop_wt_vars = []
non_numeric_sites = []
for idx, row in tqdm(func_score_df.iterrows(), total=len(func_score_df)):
    for sub in row["aa_substitutions"].split():
        if sub[0] == "*":
            stop_wt_vars.append(idx)
        if not sub[-2].isnumeric():
            non_numeric_sites.append(idx)

to_drop = set.union(set(stop_wt_vars), set(non_numeric_sites))
func_score_df.drop(to_drop, inplace=True)

We clip all functional scores at a lower bound of -3.5, and an upper bound of 2.5.

n_below_clip = len(func_score_df.query(f"func_score < -3.5"))
n_above_clip = len(func_score_df.query(f"func_score > 2.5"))
print(f"There are {n_below_clip} variants below the clip theshold, and {n_above_clip} above.")
func_score_df = func_score_df.assign(
    func_score = func_score_df.func_score.clip(-3.5, 2.5)
)
There are 143177 variants below the clip threshold, and 345 above.

Choose two representative biological replicates for each of the three homologs.

experiment_conditions = ["Delta", "Omicron_BA1", "Omicron_BA2"]
replicate_1_experiments = ["Delta-2", "Omicron_BA1-2", "Omicron_BA2-1"]
replicate_2_experiments = ["Delta-4", "Omicron_BA1-3", "Omicron_BA2-2"]

Organize the two replicates and annotate replicates “1” and “2”. These each represent a distinct training set such that we may train replicate models and compare their results. Output a random sample of 10 variants.

Hide source
func_score_df = pd.concat(
    [
        (
            func_score_df
            .query("condition in @replicate_1_experiments")
            .replace(dict(zip(replicate_1_experiments, experiment_conditions)))
            .assign(replicate=1)
        ),
        (
            func_score_df
            .query("condition in @replicate_2_experiments")
            .replace(dict(zip(replicate_2_experiments, experiment_conditions)))
            .assign(replicate=2)
        )
    ]
)
func_score_df = func_score_df.assign(
    n_subs = [
        len(aa_subs.split()) 
        for aa_subs in func_score_df.aa_substitutions
    ]
)
func_score_df.sample(10)
func_score aa_substitutions condition replicate n_subs
623435 -2.5084 K764D Omicron_BA1 1 1
1064675 -3.5000 E324K F375R V1033A Omicron_BA2 2 3
1073643 -3.5000 I584P G891R V1068F Omicron_BA2 2 3
211227 -1.1056 D936G Delta 2 1
808624 -0.0815 Omicron_BA2 1 0
967641 -1.5393 T76K T284I V1176F Omicron_BA2 1 3
675886 -1.1126 I105V W152C D178H K764V S803T Omicron_BA1 2 5
853582 -3.5000 T76A R158S A771P Omicron_BA2 1 3
662082 0.3180 Omicron_BA1 2 0
764857 -3.5000 Y265F T307N T678F K764N Y873S Omicron_BA1 2 5
func_score_df.to_csv(f"{output_dir}/training_functional_scores.csv", index=False)

Variant barcode and mutation background stats

In this section we briefly query and visualize characteristics of the replicate training sets.

Get the expected number of substitutions per variant for each condition and replicate.

Hide source
for group, group_df in func_score_df.groupby(["condition", "replicate"]):
    print(f"{group[0]} - rep {group[1]} has {round(group_df.n_subs.mean(), 5)} subs per variant, on average")
Delta - rep 1 has 2.18671 subs per variant, on average
Delta - rep 2 has 2.29472 subs per variant, on average
Omicron_BA1 - rep 1 has 1.802 subs per variant, on average
Omicron_BA1 - rep 2 has 1.75802 subs per variant, on average
Omicron_BA2 - rep 1 has 2.31117 subs per variant, on average
Omicron_BA2 - rep 2 has 2.32827 subs per variant, on average

Get the number of unique variants (distinct amino-acid substitution sets) seen in each condition and replicate.

Hide source
for group, group_df in func_score_df.groupby(["condition", "replicate"]):
    print(f"{group[0]} - rep {group[1]} has {len(group_df.aa_substitutions.unique())} unique variants")
Delta - rep 1 has 28515 unique variants
Delta - rep 2 has 29158 unique variants
Omicron_BA1 - rep 1 has 70597 unique variants
Omicron_BA1 - rep 2 has 62129 unique variants
Omicron_BA2 - rep 1 has 60397 unique variants
Omicron_BA2 - rep 2 has 57719 unique variants

Visualize the distribution of amino-acid substitutions per variant, as well as the distribution of unique variant backgrounds per mutation.

import warnings
warnings.simplefilter("ignore")
Hide source
saveas = f"raw_data_summary_barcodes_backgrounds_hist"
logscale=False
fig, ax = plt.subplots(2,3, sharex="row", sharey="row", figsize=[6.4, 5.5])

condition_title = {
    "Delta":"Delta",
    "Omicron_BA1" : "BA.1",
    "Omicron_BA2" : "BA.2"
}

row = 0
for col, (condition, condition_df) in enumerate(func_score_df.groupby("condition")):
    iter_ax = ax[row, col]
    
    mut_df_replicates = condition_df.query("aa_substitutions != ''")
    mut_df_replicates = mut_df_replicates.assign(
        num_muts = [
            len(aa_subs.split())
            for aa_subs in mut_df_replicates.aa_substitutions
        ]
    )
    
    sns.histplot(mut_df_replicates.query("num_muts <= 10"), x="num_muts", ax=iter_ax, hue="replicate", discrete=True)
    for rep, rep_df in mut_df_replicates.groupby("replicate"):
        mean = rep_df['num_muts'].mean()
        iter_ax.axvline(mean, linestyle=("-" if rep == 1 else "--"))
    
    if logscale: iter_ax.set_yscale('log')
    if col != 2: 
        iter_ax.get_legend().remove()
    n_rep1 = len(mut_df_replicates.query("replicate == 1"))//1000
    n_rep2 = len(mut_df_replicates.query("replicate == 2"))//1000
    iter_ax.text(
        0.1, 1.1, 
        f"$N={n_rep1}K, {n_rep2}K$", 
        ha="left", va="top", 
        transform=iter_ax.transAxes
    )
    xscale = "number of amino-acid substitutions per variant" if col == 1 else ""
    iter_ax.set_xlabel(xscale)
    
    ylabel = f"variant counts" if col == 0 else ""
    iter_ax.set_ylabel(ylabel)
    iter_ax.set_xticks(
        [i+1 for i in range(10)],
        labels=[i+1 for i in range(10)], 
        ha="center",
        size=7,
        rotation=0
    )
    sns.despine(ax=iter_ax)
    iter_ax.set_title(condition_title[condition], y=1.15)

row = 1
collapsed_bc_df = func_score_df.groupby(
    ["replicate", "condition", "aa_substitutions"]
).aggregate("mean").reset_index()
for col, (condition, condition_df) in enumerate(collapsed_bc_df.groupby("condition")):
    iter_ax = ax[row, col]
    mut_df_replicates = pd.DataFrame()
    for rep, rep_df in condition_df.groupby("replicate"):
        
        times_seen = (
            rep_df["aa_substitutions"].str.split().explode().value_counts()
        )
        if (times_seen == times_seen.astype(int)).all():
            times_seen = times_seen.astype(int)
        times_seen = pd.DataFrame(times_seen)
        times_seen.index.name = f"mutation"
        mut_df_replicates = pd.concat([mut_df_replicates, times_seen.assign(replicate=rep).reset_index()])

    sns.histplot(
        mut_df_replicates.query("count <= 50"), 
        x="count", 
        ax=iter_ax, 
        element='step', 
        hue="replicate", 
        discrete=True
    )
    
    for rep, rep_df in mut_df_replicates.groupby("replicate"):
        mean = rep_df['count'].mean()
        iter_ax.axvline(mean, linestyle=("-" if rep == 1 else "--"))
        
    iter_ax.get_legend().remove()
    n_rep1 = len(mut_df_replicates.query("replicate == 1"))
    n_rep2 = len(mut_df_replicates.query("replicate == 2"))
    iter_ax.text(
        0.1, 1.1, 
        f"$N={n_rep1}, {n_rep2}$", 
        ha="left", va="top", 
        transform=iter_ax.transAxes
    )
    
    xscale = "number of variant backgrounds \nfor a given amino-acid substitution" if col == 1 else ""
    iter_ax.set_xlabel(xscale)
    
    ylabel = f"mutation counts" if col == 0 else ""
    iter_ax.set_ylabel(ylabel)
    
    xticks = [i for i in range(0, 51) if i % 5 == 0]
    iter_ax.set_xticks(
        xticks,
        labels=xticks, 
        ha="center",
        size=7,
        rotation=0
    )
    
    sns.despine(ax=iter_ax)

plt.tight_layout()

ax[0,0].text(
    -0.1, 1.06, 
    f"A", 
    ha="right", va="bottom", 
    size=15,
    weight="bold",
    transform=ax[0,0].transAxes
)
ax[1,0].text(
    -0.1, 1.06, 
    f"B", 
    ha="right", va="bottom", 
    size=15,
    weight="bold",
    transform=ax[1,0].transAxes
)

fig.subplots_adjust(hspace=.6)
fig.savefig(f"{output_dir}/{saveas}.pdf")
fig.savefig(f"{output_dir}/{saveas}.png")
plt.show()
_images/68f87b1b53b8a656f20add61e71f54ce43b1dfbdb0a6d0a554a96539457bcb35.png

Plot the correlation of variant functional scores (averaged across barcodes) between replicates in each condition, as well as the full distribution of functional scores.

Hide source
saveas = "replicate_functional_score_correlation_scatter"
pal = sns.color_palette('tab20')

fig, ax = plt.subplots(2,3, sharex="row", sharey="row", figsize=[6.4, 5.3])
collapsed_bc_df = func_score_df.groupby(
    ["replicate", "condition", "aa_substitutions"]
).aggregate("mean").reset_index()
collapsed_bc_df = collapsed_bc_df.assign(
    is_stop=[True if "*" in aasubs else False for aasubs in collapsed_bc_df.aa_substitutions]
)

is_stop_alpha_dict = {
    True : 0.5,
    False : 0.2
}

lim = [-3.8, 2.8]
ticks = np.linspace(-3, 2, 6)
for col, (condition, condition_df) in enumerate(collapsed_bc_df.groupby("condition")):
    
    row = 0
    iter_ax = ax[row, col]
    
    mut_df_replicates = reduce(
        lambda left, right: pd.merge(
            left, right, left_index=True, right_index=True, how="inner"
        ),
        [
            rep_df.rename({"func_score":f"rep_{rep}_func_score"}, axis=1).set_index("aa_substitutions")
            for rep, rep_df in condition_df.groupby("replicate") 
        ],
    )
    
    mut_df_replicates = mut_df_replicates.assign(
        is_stop=[True if "*" in aasubs else False for aasubs in mut_df_replicates.index.values]
    )
    mut_df_replicates = mut_df_replicates.assign(
        n_subs=[len(aasubs.split()) for aasubs in mut_df_replicates.index.values]
    )
    
    # alpha = [is_stop_alpha_dict[istp] for istp in mut_df_replicates.is_stop]
    for istp, color in zip([False, True], ["darkgrey", "red"]):
        sns.scatterplot(
            mut_df_replicates.query("is_stop == @istp"), 
            x="rep_1_func_score", 
            y="rep_2_func_score", 
            ax =iter_ax,
            c=color,
            alpha=is_stop_alpha_dict[istp],
            legend=False
        )
    
    iter_ax.plot([-3.5, 2.5], [-3.5, 2.5], "--", lw=2, c="royalblue")
    
    iter_ax.set_ylim(lim)
    iter_ax.set_xlim(lim)
    if col == 0:
        iter_ax.set_yticks(ticks, labels=ticks)
    iter_ax.set_xticks(ticks, labels=ticks, rotation=90)
    
    corr = pearsonr(mut_df_replicates["rep_1_func_score"], mut_df_replicates["rep_2_func_score"])[0]
    iter_ax.annotate(
        f"$r = {corr:.2f}$", 
        (0.1, 0.9), 
        xycoords="axes fraction", 
        fontsize=12
    )
    iter_ax.set_title(condition)
    # iter_ax.get_legend().remove()
    sns.despine(ax=iter_ax)
    
    row = 1
    iter_ax = ax[row, col]
    sns.violinplot(
        condition_df,
        x="is_stop",
        y="func_score",
        hue="replicate",
        split=True,
        gap=.1, inner="quart",
        palette=["0.5", "0.75"],
        ax=iter_ax
    )
    
    sns.despine(ax=iter_ax)
    if col != 2:
        iter_ax.get_legend().remove()
    else:
        iter_ax.legend(bbox_to_anchor = (1.25, 1.05), title="replicate")
    if col == 0:
        iter_ax.set_yticks(ticks, labels=ticks)

ax[0,0].set_xlabel("")
ax[0,0].set_ylabel("replicate 2 \n functional score")

ax[0,1].set_xlabel("replicate 1 functional score")
ax[0,1].set_title("BA.1")
ax[0,2].set_xlabel("")
ax[0,2].set_title("BA.2")

ax[1,0].set_xlabel("")
ax[1,0].set_ylabel("functional score")

ax[1,1].set_xlabel("variants contain stop codon mutations")
ax[1,2].set_xlabel("")
ax[1,2].set_ylabel("")
ax[1,1].set_ylabel("")

ax[0,0].text(
    -0.1, 1.06, 
    f"A", 
    ha="right", va="bottom", 
    size=15,
    weight="bold",
    transform=ax[0,0].transAxes
)
ax[1,0].text(
    -0.1, 1.06, 
    f"B", 
    ha="right", va="bottom", 
    size=15,
    weight="bold",
    transform=ax[1,0].transAxes
)

# fig.suptitle("Variant Functional Score \nReplicate Correlation")
plt.tight_layout()
fig.subplots_adjust(wspace=0.08, hspace = 0.5)
fig.savefig(f"{output_dir}/{saveas}.pdf")
fig.savefig(f"{output_dir}/{saveas}.png")
plt.show()
_images/5696350b9975066deb25bd8895aa3d3b08202a2158c14d19e745e73b44b459d3.png

Encode data for fitting

Next, we use the multidms.Data class to prep our data for fitting.

Instantiate an object for each of our two replicate training sets, and append them to a list

datasets = []
for res, fsdf in func_score_df.groupby("replicate"):       

    start = time.time()

    # instantiate data object
    data = multidms.Data(
        fsdf,
        collapse_identical_variants="mean",       # take the average variant func score across barcode replicates
        alphabet=multidms.AAS_WITHSTOP_WITHGAP,   # include stop codons and gaps in the alphabet
        reference="Omicron_BA1",
        assert_site_integrity=False,
        verbose=True,
        nb_workers=8,
        name=f"rep-{res}" 
    )

    end = time.time()
    prep_time = round(end-start)
    print(f"Finished, time: {prep_time}")

    datasets.append(data)
Hide output
inferring site map for Delta
inferring site map for Omicron_BA1
inferring site map for Omicron_BA2
INFO: Pandarallel will run on 8 workers.
INFO: Pandarallel will use Memory file system to transfer data between the main process and workers.
unknown cond wildtype at sites: [144, 143, 69, 145, 70, 211, 25, 26, 24, 157, 158, 898],
dropping: 10983 variantswhich have mutations at those sites.
invalid non-identical-sites: [371], dropping 2041 variants
Converting mutations for Delta
Converting mutations for Omicron_BA1
is reference, skipping
Converting mutations for Omicron_BA2
Finished, time: 37
inferring site map for Delta
inferring site map for Omicron_BA1
inferring site map for Omicron_BA2
INFO: Pandarallel will run on 8 workers.
INFO: Pandarallel will use Memory file system to transfer data between the main process and workers.
unknown cond wildtype at sites: [145, 70, 144, 143, 69, 211, 422, 26, 24, 25, 157, 158],
dropping: 10129 variantswhich have mutations at those sites.
invalid non-identical-sites: [371], dropping 1873 variants
Converting mutations for Delta
Converting mutations for Omicron_BA1
is reference, skipping
Converting mutations for Omicron_BA2
Finished, time: 35

Change the dataset colors - the only mutable attribute of Data objects.

Hide content
# condition colors must be hex-encoded,
palette = [
    '#F97306',
    '#BFBFBF',
    '#9400D3'
]

conditions = [
    "Delta",
    "Omicron_BA1",
    "Omicron_BA2"
]

cc = {con:col for con, col in zip(conditions, palette)}
for data in datasets:
    data.condition_colors = cc

Fit Models (Shrinkage Analysis)

For each replicate training set, initialize and fit a set of models with variable lasso strength coefficients. Here, we use the multidms.fit_models function to handle the model instantiation and parameter fitting.

fit_params = {
    'scale_coeff_lasso_shift': scale_coeff_lasso_shift,
    'alpha_d': [alpha_d],
    'scale_coeff_ridge_alpha_d':[scale_coeff_ridge_alpha_d],
    'scale_coeff_ridge_beta': [scale_coeff_ridge_beta],
    'num_training_steps': [num_training_steps],
    'iterations_per_step':[iterations_per_step],
}

fit_params["dataset"] = datasets
_, _, models = multidms.fit_models(fit_params, n_threads=-1)
pickle.dump(models, open(f"{output_dir}/full_models.pkl", "wb"))

The cell above saves the models and their relevant hyper-parameters in a DataFrame to a pickle binary file. Hence, if it's already been run and you want to execute the code below without re-fitting, you can simply re-load that binary file.
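
As a minimal sketch (assuming the pickle written above exists), re-loading looks like:

import pickle

# re-load the DataFrame of fit models written by the cell above
with open(f"{output_dir}/full_models.pkl", "rb") as f:
    models = pickle.load(f)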

models["replicate"] = models.dataset_name.str.split("-").str[-1].astype(int)
Hide content
# set some global variables
# name of the column containing the fit model objects
model_name = "model"

# the number of times a mutation must be seen in each condition to be included in downstream analysis
times_seen_threshold = 1

# The chosen lasso strength for our final spike model results
chosen_lasso_strength = 4e-5

Plot the model loss over training steps, as recorded in each fit's step_loss.

Hide source
saveas="convergence_all_lasso_lines"
cmap=plt.get_cmap("tab20")


fig, ax = plt.subplots(1,figsize=[6.4,4.5])
color_idx = -1
for i, (model, model_row) in enumerate(models.iterrows()):
    if i%2 == 0: color_idx += 1

    ax.plot(
        [(1000 * s) for s in range(len(model_row.step_loss))],
        model_row.step_loss,
        c=cmap.colors[color_idx],
        lw=3,
        linestyle="-" if model_row.replicate == 0 else "--",
        label=f"rep: {model_row.replicate} scale_coeff: {model_row.scale_coeff_lasso_shift}"
    )

ticks = range(0, 30001, 5000)
labels = [f"{t//1000}K" for t in ticks]
ax.set_xticks(ticks, labels, rotation=0, ha='center')
ax.set_ylabel("Model Loss (w/o L1 penalty)")
ax.set_xlabel("Optimization iterations")

black_line = mlines.Line2D([], [], color='black', linestyle='-',
                          markersize=5, label='rep 1')
black_dashed = mlines.Line2D([], [], color='black',linestyle='--',
                          markersize=5, label='rep 2')
lasso_color_handles = [
    mlines.Line2D(
        [], [], 
        color=color, 
        linestyle='-',
        markersize=5,
        linewidth=3,
        label="$\lambda$: "+str(lasso)
    )
    for lasso, color in zip(models.scale_coeff_lasso_shift.unique(), cmap.colors)
]

elements = [black_line, black_dashed] + lasso_color_handles
ax.legend(handles=elements, bbox_to_anchor = (1, 1), loc='upper left', frameon=False, fontsize=9)
sns.despine(ax=ax)
plt.tight_layout()
fig.savefig(f"{output_dir}/{saveas}.pdf",bbox_inches='tight')
fig.savefig(f"{output_dir}/{saveas}.png",bbox_inches='tight')
plt.show()
_images/356cc9e9065bbbbbd02417d0710e27ee4166d1f88863b6ec52bc6682c5bf02d5.png

Model Evaluation and Selection

model_collection = multidms.ModelCollection(models)
chart, sparsity_df = model_collection.shift_sparsity(return_data=True, height_scalar=100) # TODO raise issue to fix height scalar
print(sparsity_df.head())
chart
cache miss - this could take a moment
  dataset_name  scale_coeff_lasso_shift       mut_type    mut_param  sparsity
0        rep-1                 0.000000  nonsynonymous  shift_Delta  0.234082
1        rep-1                 0.000000           stop  shift_Delta  0.247678
2        rep-1                 0.000005  nonsynonymous  shift_Delta  0.312857
3        rep-1                 0.000005           stop  shift_Delta  0.396285
4        rep-1                 0.000010  nonsynonymous  shift_Delta  0.381122
chart, corr_df = model_collection.mut_param_dataset_correlation(width_scalar=200, return_data=True)
print(corr_df.head())
chart
      datasets mut_param  correlation  scale_coeff_lasso_shift
0  rep-1,rep-2      beta     0.837209                 0.000000
0  rep-1,rep-2      beta     0.839758                 0.000005
0  rep-1,rep-2      beta     0.838824                 0.000010
0  rep-1,rep-2      beta     0.838579                 0.000020
0  rep-1,rep-2      beta     0.837910                 0.000040

Cross Validation
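
To select the lasso strength, we hold out a random 20% of each replicate's variants as a validation set, re-fit models across the full lasso sweep on the remaining 80%, and compare training and validation loss.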

train_frac=0.8
seed=4
train, test = [], {}
for replicate, fs_df in func_score_df.groupby("replicate"):

    dataset = fs_df.sample(frac=1, random_state=seed)
    train_split, test_split = dataset.iloc[:int(len(dataset) * train_frac)], dataset.iloc[int(len(dataset) * train_frac):]
    name = f"rep-{replicate}"
    
    train.append(
        multidms.Data(
            train_split,
            collapse_identical_variants="mean",
            reference="Omicron_BA1", 
            alphabet = multidms.AAS_WITHSTOP_WITHGAP, 
            verbose=True,
            name = name
        )
    )

    test[name] = test_split
inferring site map for Delta
inferring site map for Omicron_BA1
inferring site map for Omicron_BA2
INFO: Pandarallel will run on 4 workers.
INFO: Pandarallel will use Memory file system to transfer data between the main process and workers.
unknown cond wildtype at sites: [144, 143, 69, 145, 70, 211, 25, 26, 24, 157, 158, 898],
dropping: 9793 variantswhich have mutations at those sites.
invalid non-identical-sites: [371], dropping 1732 variants
Converting mutations for Delta
Converting mutations for Omicron_BA1
is reference, skipping
Converting mutations for Omicron_BA2
inferring site map for Delta
inferring site map for Omicron_BA1
inferring site map for Omicron_BA2
INFO: Pandarallel will run on 4 workers.
INFO: Pandarallel will use Memory file system to transfer data between the main process and workers.
unknown cond wildtype at sites: [145, 70, 144, 143, 69, 211, 422, 26, 24, 25, 157, 158],
dropping: 9059 variantswhich have mutations at those sites.
invalid non-identical-sites: [371], dropping 1643 variants
Converting mutations for Delta
Converting mutations for Omicron_BA1
is reference, skipping
Converting mutations for Omicron_BA2
fit_params["dataset"] = train 
_, _, models_cv = multidms.model_collection.fit_models(fit_params, n_threads = -1)
mc = multidms.model_collection.ModelCollection(models_cv)
mc.add_validation_loss(test, overwrite=True)
cross_validation_df = mc.get_conditional_loss_df()
cross_validation_df.head()
dataset_name scale_coeff_lasso_shift condition loss split
0 rep-1 0.0 Delta 0.217511 training
1 rep-1 0.000005 Delta 0.222574 training
2 rep-1 0.00001 Delta 0.227321 training
3 rep-1 0.00002 Delta 0.235090 training
4 rep-1 0.00004 Delta 0.249183 training
saveas="shrinkage_analysis_trace_plots_beta"

fig, ax = plt.subplots(3, figsize=[4.5, 7.5], sharex=True)

# replicate correlation
iter_ax = ax[0]
sns.lineplot(
    data=(
        corr_df
        .query("mut_param.str.contains('shift')")
        .rename({"mut_param":"shift params"}, axis=1)
        # .replace({"Data-1":"rep-1", "Data-2":"rep-2"})
        .replace({"shift_Delta":"Delta", "shift_Omicron_BA2":"BA.2"})
        .assign(
            scale_coeff_lasso_shift = [
                f"{l:.1e}" 
                for l in corr_df.query("mut_param.str.contains('shift')").scale_coeff_lasso_shift
            ],
            correlation = lambda x: x.correlation**2
        )
        .reset_index(drop=True)
    ),
    x="scale_coeff_lasso_shift",
    y="correlation",
    style="shift params",
    markers=True,
    ax=iter_ax,
    linewidth=3,
    color="black"
)
iter_ax.set_ylabel("rep1 v. rep2\nshift $(R^2)$")
# move legend outside of plot
iter_ax.legend(
    bbox_to_anchor = (1, 1), 
    loc='upper left', 
    frameon=False
)



# plot loss
iter_ax = ax[1]
sns.lineplot(
    data = (
        cross_validation_df.query("condition=='total'")
        .assign(
            # lasso_strength = [f"{l:.1e}" for l in sparsity_df.scale_coeff_lasso_shift]
            # lasso_strength = lambda x: f"{x.scale_coeff_lasso_shift:.1e}"
            lasso_strength = lambda x: x['scale_coeff_lasso_shift'].apply(lambda y: f'{y:.1e}')
        )
    ),
    x="lasso_strength",
    y="loss",
    ax=iter_ax,
    hue="split",
    style="dataset_name",
    palette={"training":"slategrey", "validation":"#2CA02C"},
    markers=True,
    linewidth=3
)
# move legend outside of plot
iter_ax.legend(
    bbox_to_anchor = (1, 1), 
    loc='upper left', 
    frameon=False
)


# plot sparsity
iter_ax = ax[2]
sns.lineplot(
    data=(
        sparsity_df
        .rename({"dataset_name":"replicate"}, axis=1)
        .rename({"mut_param":"shift params", "mut_type":"mutation type"}, axis=1)
        # .replace({"Data-0":"rep-1", "Data-1":"rep-2"})
        .replace({"nonsynonymous":"nonsynonymous", "stop":"stop"})
        .replace({"shift_Delta":"Delta", "shift_Omicron_BA2":"BA.2"})
        .assign(
            scale_coeff_lasso_shift = [f"{l:.1e}" for l in sparsity_df.scale_coeff_lasso_shift],
            sparsity_percent = lambda x: x.sparsity * 100,
        )
    ),
    x="scale_coeff_lasso_shift",
    y="sparsity_percent",
    hue="mutation type",
    style="replicate",
    palette={"nonsynonymous":"grey", "stop":"#E377C2"},
    markers=True,
    legend=True,
    ax=iter_ax,
    linewidth=3
)
# move legend outside of plot
iter_ax.legend(
    bbox_to_anchor = (1, 1), 
    loc='upper left', 
    frameon=False
)
# rotate x labels
iter_ax.set_xticklabels(
    iter_ax.get_xticklabels(), 
    rotation=90, 
    ha='center'
)
iter_ax.set_ylabel("sparsity\n$(\%\Delta=0)$")
iter_ax.set_xlabel(f"lasso regularization strength ($\lambda$)")

for axes in ax:
    axes.axvline(
        f"{chosen_lasso_strength:.1e}", 
        color="grey",
        linewidth=10,
        alpha=0.35
    )

sns.despine(fig)
plt.tight_layout()
# plt.tight_layout()
fig.savefig(f"{output_dir}/{saveas}.pdf",bbox_inches='tight')
fig.savefig(f"{output_dir}/{saveas}.png",bbox_inches='tight')
# plt.show()

plt.show()
_images/9d405aa82c390b0898bb7f80fdf786a3c9417e11cdea3086019dae423de833c8.png

Global epistasis fits

Here, we take a look at the fit of the sigmoidal global epistasis function (at the chosen lasso coefficient of 4e-5) to the data.

For each replicate at the chosen lasso strength, we get the training data predictions using model.get_variants_df, and use model.get_condition_params paired with model.model_components for visualizing the global epistasis function with the current model parameters. See the function docstrings for the relevant details of each.
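
For intuition, a generic two-parameter sigmoid of the kind used for global epistasis (a sketch, not necessarily multidms's exact parameterization) maps a latent phenotype onto the functional-score scale:

import jax.numpy as jnp

def sigmoid_global_epistasis(phi, scale, bias):
    # squash the latent phenotype phi into a bounded functional-score range
    return scale / (1.0 + jnp.exp(-phi)) + bias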

chosen_replicate_models = models.query("scale_coeff_lasso_shift == @chosen_lasso_strength")
replicate_data = {}
for row_idx, replicate_row in chosen_replicate_models.iterrows():
    model = replicate_row[model_name]

    # get training data variants and their respective 
    mut_df_replicates = model.get_variants_df(phenotype_as_effect=False)

    # find the low/high bound of the training data and use those to make
    # global epistasis predictions across the range for plotting
    xlb, xub = [-1, 1] + np.quantile(mut_df_replicates.predicted_latent, [0.05, 1.0])
    additive_model_grid = np.linspace(xlb, xub, num=1000)

    # make predictions on hypothetical data points between lower, and upper bound
    current_params = model.get_condition_params(model.data.reference)
    latent_preds = model.model_components["g"](current_params["theta"], additive_model_grid)
    shape = (additive_model_grid, latent_preds)   

    # save and organize the data for plotting
    replicate_data[replicate_row.replicate] = {
        "variants_df" : mut_df_replicates,
        "wildtype_df" : model.wildtype_df,
        "epistasis_shape" : shape,
        "condition_colors" : model.data.condition_colors
    }
replicate_data.keys()
dict_keys([1, 2])

Plot the observed functional scores of a random sample of variants, as a function of both the latent phenotype prediction (top), and the functional score prediction (bottom).

Hide source
saveas="global_epistasis_and_prediction_correlations"
fig, ax = plt.subplots(2,2, figsize=[6.4,6], sharey='row')    

row=0
for replicate, data in replicate_data.items():

    iter_ax = ax[row, replicate-1]
    sns.scatterplot(
        data=data["variants_df"].sample(frac=0.2),
        x="predicted_latent",
        y=f"func_score",
        hue="condition",
        palette=model.data.condition_colors,
        ax=iter_ax,
        legend=False,
        size=5,
        alpha=0.2,
        # lw=3
    )
    
    for condition, values in data["wildtype_df"].iterrows():
        iter_ax.axvline(
            values.predicted_latent,
            label=condition,
            c=model.data.condition_colors[condition],
            lw=2,
        )
    
    iter_ax.plot(*data["epistasis_shape"], color="k", lw=2)
    
    xlb, xub = [-1, 1] + np.quantile(data["variants_df"].predicted_latent, [0.05, 1.0])
    ylb, yub = [-1, 1] + np.quantile(data["variants_df"].func_score, [0.05, 1.0])
    iter_ax.set_xlim([xlb, xub])
    iter_ax.set_ylim([ylb, yub])
    iter_ax.set_title(f"replicate {replicate}")
    iter_ax.set_ylabel("observed\nfunctional score")
    iter_ax.set_xlabel("predicted latent phenotype ($\phi$)")

row=1
for replicate, data in replicate_data.items():

    iter_ax = ax[row, replicate-1]
    sns.scatterplot(
        data=data["variants_df"].sample(frac=0.1),
        x="predicted_func_score",
        y=f"func_score",
        hue="condition",
        palette=model.data.condition_colors,
        ax=iter_ax,
        legend=False,
        size=5,
        alpha=0.2
    )
    
    for condition, values in data["wildtype_df"].iterrows():
        iter_ax.axvline(
            values.predicted_func_score,  # wildtype prediction on the functional-score axis
            label=condition,
            c=model.data.condition_colors[condition],
            lw=2,
        )
    
    iter_ax.set_ylabel("observed\nfunctional score")
    iter_ax.set_xlabel("predicted functional score")

    start_y = 0.9
    for c, cdf in data["variants_df"].groupby("condition"):
        r = pearsonr(
            cdf["predicted_func_score"],
            cdf["func_score"]
        )[0]
        iter_ax.annotate(
            f"$r = {r:.2f}$",
            (0.1, start_y),
            xycoords="axes fraction",
            fontsize=12,
            c=model.data.condition_colors[c],
        )
        start_y += -0.1


elements = [
    mlines.Line2D([], [], color=color, marker='o', linestyle='None',markersize=5, label=condition)
    for condition, color in replicate_data[1]["condition_colors"].items()
]


ax[0, 0].legend(
    handles=elements, 
    bbox_to_anchor = (0., .99), 
    loc='upper left', 
    frameon=False, 
    fontsize=9
)
    
    
plt.tight_layout()
fig.subplots_adjust(wspace=0.05)

ax[0,0].text(
    -0.1, 1.00, 
    f"A", 
    ha="right", va="center", 
    size=15,
    weight="bold",
    transform=ax[0,0].transAxes
)
ax[1,0].text(
    -0.1, 1.00, 
    f"B", 
    ha="right", va="center", 
    size=15,
    weight="bold",
    transform=ax[1,0].transAxes
)


fig.savefig(f"{output_dir}/{saveas}.pdf",bbox_inches='tight')
fig.savefig(f"{output_dir}/{saveas}.png",bbox_inches='tight')
plt.show()
_images/77f87217e073e1c47149400799d7915b3d2127045267e1ce0381e3824a3444fb.png

Shifted mutations (interactive altair chart)

The easiest way to view shifted mutations is to create an interactive altair chart using multidms.ModelCollection.mut_param_heatmap. This method aggregates the (by default, mean) mutation parameter values shared across the models in a collection, e.g. replicate dataset fits. Toggle the drop down for the cell below to see details on using this function.

help(multidms.ModelCollection.mut_param_heatmap)
Hide output
Help on function mut_param_heatmap in module multidms.model_collection:

mut_param_heatmap(self, query=None, mut_param='shift', aggregate_func='mean', inner_merge_dataset_muts=True, times_seen_threshold=0, phenotype_as_effect=True, **kwargs)
    Create lineplot and heatmap altair chart
    across replicate datasets.
    This function optionally applies a given `pandas.query`
    on the fit_models dataframe that should result in a subset of
    fit's which make sense to aggregate mutational data across, e.g.
    replicate datasets.
    It then computes the mean or median mutational parameter value
    ("beta", "shift", or "predicted_func_score")
    between the remaining fits. and creates an interactive altair chart.
    
    
    Note that this will throw an error if the queried fits have more
    than one unique hyper-parameter besides "dataset_name".
    
    
    Parameters
    ----------
    query : str
        The query to apply to the fit_models dataframe. This should be
        used to subset the fits to only those which make sense to aggregate
        mutational data across, e.g. replicate datasets.
        For example, if you have a collection of
        fits with different epistatic models, you may want to query
        for only those fits with the same epistatic model. e.g.
        `query="epistatic_model == 'Sigmoid'"`. For more on the query
        syntax, see the
        `pandas.query <https://pandas.pydata.org/docs/reference/api/pandas.DataFrame.query.html>`_
        documentation.
    mut_param : str, optional
        The mutational parameter to plot. The default is "shift".
        Must be one of "shift", "predicted_func_score", or "beta".
    aggregate_func : str, optional
        The function to aggregate the mutational parameter values
        between dataset fits. The default is "mean".
    inner_merge_dataset_muts : bool, optional
        Whether to toss mutations which are _not_ shared across all datasets
        before aggregation of group mutation parameter values.
        The default is True.
    times_seen_threshold : int, optional
        The minimum number of times a mutation must be seen across
        all conditions within a single fit to be included in the
        aggregation. The default is 0.
    phenotype_as_effect : bool, optional
        Passed to `Model.get_mutations_df()`,
        Only applies if `mut_param="predicted_func_score"`.
    **kwargs : dict
        Keyword arguments to pass to
        :func:`multidms.plot._lineplot_and_heatmap`.
    
    Returns
    -------
    altair.Chart
        A chart object which can be displayed in a jupyter notebook
        or saved to a file.

Here, we create the interactive chart by building a ModelCollection from the two replicate model fits, and specifying:

  1. times_seen_threshold = 1, meaning that for a mutation to be included, it must be seen at least once in every condition

  2. the default mut_param = "shift", as we only wish to see the shift parameters visualized, not the respective effect (beta) parameters. (note that the respective effect values are added as a tooltip when hovering over any shift mutation).

To view the chart, toggle the output of the cell below.

mc = multidms.ModelCollection(models.drop(columns="replicate"))
chart = mc.mut_param_heatmap(query=f"scale_coeff_lasso_shift == {chosen_lasso_strength}", times_seen_threshold=times_seen_threshold)
chart.save(f"{output_dir}/interactive_shift_chart.html")
chart
Hide output
cache miss - this could take a moment