model
Defines Model objects.
- class multidms.model.Model(data: multidms.data.Data, epistatic_model=<function sigmoidal_global_epistasis>, output_activation=<function identity_activation>, PRNGKey=0, lower_bound=None, n_hidden_units=5, init_theta_scale=5.0, init_theta_bias=-5.0, init_beta_variance=0.0, name=None)
Bases: object
Represents one or more DMS experiments to obtain tuned parameters that provide insight into individual mutational effects and the conditional shifts of those effects in all non-reference conditions. For more, see the biophysical model documentation.
- Parameters:
data (multidms.Data) – A reference to the dataset which will define the parameters of the model to be fit.
epistatic_model (function) – A function that transforms the latent effects of mutations into a functional score. See the biophysical model documentation for more.
output_activation (function) – A function that transforms the output of the global epistasis function. Defaults to the identity function (no activation). See the biophysical model documentation for more.
conditional_shifts (bool) – If True, initialize and fit the shift parameters for each non-reference condition. See the Model Description section for more. Defaults to True.
alpha_d (bool) – If True, introduce a latent offset parameter for each condition. See the biophysical docs section for more. Defaults to True.
gamma_corrected (bool) – If True, introduce the ‘gamma’ parameter for each non-reference condition to account for differences between wildtype behavior and that of its variants. This is essentially a bias added to the functional scores during fitting. See the Model Description section for more. Defaults to False.
PRNGKey (int) – The initial seed key for random parameters assigned to Betas and any other randomly initialized parameters.
init_beta_naught (float) – Initialize the latent offset parameter applied to all conditions. See the biophysical docs section for more.
init_theta_scale (float) – Initialize the scaling parameter \(\theta_{\text{scale}}\) of a two-parameter epistatic model (Sigmoid or Softplus).
init_theta_bias (float) – Initialize the bias parameter \(\theta_{\text{bias}}\) of a two-parameter epistatic model (Sigmoid or Softplus).
init_beta_variance (float) – Beta parameters are initialized by sampling from a normal distribution. This parameter specifies the variance of the distribution being sampled.
n_hidden_units (int or None) – If using multidms.biophysical.nn_global_epistasis() as the epistatic model, this is the number of hidden units used in the transform.
lower_bound (float or None) – If using multidms.biophysical.softplus_activation() as the output activation, this is the lower bound of the softplus function.
name (str or None) – Name of the Model object. If None, will be assigned a unique name based upon the number of data objects instantiated.
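To make the initialization parameters above more concrete, here is a small pure-Python sketch. It assumes that beta parameters are drawn from a zero-mean normal distribution with variance init_beta_variance, and that the two-parameter sigmoid has the form \(\theta_{\text{scale}} \cdot \text{sigmoid}(z) + \theta_{\text{bias}}\). Both forms, and the helper names init_betas and sigmoid_ge, are illustrative assumptions rather than the multidms implementation.

```python
import math
import random


def init_betas(n_mutations, variance, seed=0):
    """Draw one latent effect per mutation from Normal(0, variance)."""
    rng = random.Random(seed)  # seeded, in the spirit of PRNGKey
    sd = math.sqrt(variance)   # random.gauss expects a standard deviation
    return [rng.gauss(0.0, sd) for _ in range(n_mutations)]


def sigmoid_ge(z, theta_scale=5.0, theta_bias=-5.0):
    """Assumed two-parameter sigmoid: theta_scale * sigmoid(z) + theta_bias."""
    return theta_scale / (1.0 + math.exp(-z)) + theta_bias


print(init_betas(4, variance=0.0))  # all zeros with the default variance
print(sigmoid_ge(0.0))              # -2.5
```

Under this assumed form, the defaults init_theta_scale=5.0 and init_theta_bias=-5.0 confine the output to the interval (-5, 0), and the default init_beta_variance=0.0 starts every beta at zero, which is consistent with the zero-valued betas in the example below.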
Example
To create a Model object, all you need is the respective Data object for parameter fitting.
>>> import multidms
>>> from tests.test_data import data
>>> model = multidms.Model(data)
Upon initialization, you will now have access to the underlying data and parameters.
>>> model.data.mutations
('M1E', 'M1W', 'G3P', 'G3R')
>>> model.data.conditions
('a', 'b')
>>> model.data.reference
'a'
>>> model.data.condition_colors
{'a': '#0072B2', 'b': '#CC79A7'}
The mutations_df and variants_df may of course also be accessed. First, we set pandas to display all rows and columns.
>>> import pandas as pd
>>> pd.set_option('display.max_rows', None)
>>> pd.set_option('display.max_columns', None)
>>> model.data.mutations_df
  mutation wts  sites muts  times_seen_a  times_seen_b
0      M1E   M      1    E             1             3
1      M1W   M      1    W             1             0
2      G3P   G      3    P             1             4
3      G3R   G      3    R             1             2
If accessed directly through the Model object, you get the same information along with model- and parameter-specific columns. These are recomputed from the model's current state each time you call the method.
>>> model.get_mutations_df()
         wts  sites muts  times_seen_a  times_seen_b  beta_a  beta_b  shift_b  \
mutation
M1E        M      1    E             1             3     0.0     0.0      0.0
M1W        M      1    W             1             0     0.0    -0.0      0.0
G3P        G      3    P             1             4    -0.0    -0.0     -0.0
G3R        G      3    R             1             2    -0.0     0.0     -0.0

          predicted_func_score_a  predicted_func_score_b
mutation
M1E                          0.0                     0.0
M1W                          0.0                     0.0
G3P                          0.0                     0.0
G3R                          0.0                     0.0
Notice that the single mutation effects (beta), conditional shifts (shift_d), and predicted functional scores (F_d) of each mutation in the model are now easily accessible. Similarly, we can take a look at the variants_df for the model:
>>> model.get_variants_df()
  condition aa_substitutions  func_score var_wrt_ref  predicted_latent  \
0         a              M1E         2.0         M1E               0.0
1         a              G3R        -7.0         G3R               0.0
2         a              G3P        -0.5         G3P               0.0
3         a              M1W         2.3         M1W               0.0
4         b              M1E         1.0     G3P M1E               0.0
5         b              P3R        -5.0         G3R               0.0
6         b              P3G         0.4                           0.0
7         b          M1E P3G         2.7         M1E               0.0
8         b          M1E P3R        -2.7     G3R M1E               0.0

   predicted_func_score
0                   0.0
1                   0.0
2                   0.0
3                   0.0
4                   0.0
5                   0.0
6                   0.0
7                   0.0
8                   0.0
We now have access to the predicted (and gamma-corrected) functional scores, as predicted by the model's current parameters.
So far, neither these parameters nor the predictions resulting from them have been tuned to the dataset. Let's take a look at the loss on the training dataset given our initialized parameters:
>>> model.loss
2.9370000000000003
Next, we fit the model with some chosen hyperparameters.
>>> model.fit(maxiter=10, lasso_shift=1e-5, warn_unconverged=False)
>>> model.loss
0.3483478119356665
The model tunes its parameters in place, and the subsequent call to retrieve the loss reflects the model's loss given its updated parameters.
- property model_components: frozendict
A frozendict that holds the individual components of the model as well as the objective and forward functions.
- property convergence_trajectory_df
The state.error through each training iteration. Currently, this is reset each time the fit() method is called.
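The bookkeeping described above, a trajectory recorded at a fixed resolution and rebuilt on every call to fit(), can be sketched with a toy gradient-descent loop. The quadratic objective, the helper run_with_trajectory, and the step/loss record layout are hypothetical; per the description above, multidms records the solver's state.error rather than a toy loss, and its exact sampling rule may differ.

```python
def run_with_trajectory(x0, step, maxiter, resolution):
    """Minimize the toy loss (x - 3)^2 by gradient descent.

    Every `resolution` iterations the current loss is appended to a fresh
    trajectory list, so each call starts a new record.
    """
    trajectory = []
    x = x0
    for it in range(maxiter):
        x -= step * 2.0 * (x - 3.0)  # the gradient of (x - 3)^2 is 2(x - 3)
        if it % resolution == 0:
            trajectory.append({"step": it, "loss": (x - 3.0) ** 2})
    return x, trajectory


x, traj = run_with_trajectory(0.0, step=0.1, maxiter=50, resolution=10)
print([row["step"] for row in traj])  # [0, 10, 20, 30, 40]
```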
- property loss: float
Compute un-penalized model loss on all experimental training data without ridge or lasso penalties included.
- property wildtype_df
A dataframe indexed by condition, containing the wildtype prediction features for each condition.
- get_variants_df(phenotype_as_effect=True)
Training data with model predictions for the latent and functional score phenotypes.
- Parameters:
phenotype_as_effect (bool) – If True, phenotypes (both latent and func_score) are calculated as the difference between the predicted phenotype of a given variant and the respective experimental wildtype prediction. Otherwise, report the unmodified model prediction.
- Returns:
A copy of the training data, self.data.variants_df, with the phenotypes added. Phenotypes are predicted based on the current state of the model.
- Return type:
pandas.DataFrame
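The phenotype_as_effect option amounts to subtracting, within each condition, the wildtype prediction from every variant's prediction. A minimal sketch, using a hypothetical as_effects helper and made-up predictions in place of real model output:

```python
def as_effects(rows, wildtype_by_condition):
    """Subtract each condition's wildtype prediction from its variants' predictions."""
    return [pred - wildtype_by_condition[cond] for cond, pred in rows]


# Hypothetical (condition, raw prediction) pairs and wildtype predictions.
rows = [("a", -2.5), ("a", -1.0), ("b", -4.0)]
wildtype = {"a": -2.5, "b": -3.0}
print(as_effects(rows, wildtype))  # [0.0, 1.5, -1.0]
```

Because every beta and shift starts at zero in the example above, each variant's prediction equals its condition's wildtype prediction at initialization, which is consistent with the all-zero predicted columns shown there.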
- get_mutations_df(times_seen_threshold=0, phenotype_as_effect=True, return_split=True)
Mutation attributes and phenotypic effects based on the current state of the model.
- Parameters:
times_seen_threshold (int, optional) – Only report mutations that have been seen at least this many times in each condition. Defaults to 0.
phenotype_as_effect (bool, optional) – if True, phenotypes are reported as the difference from the conditional wildtype prediction. Otherwise, report the unmodified model prediction.
return_split (bool, optional) – If True, return the split mutations as separate columns: ‘wts’, ‘sites’, and ‘muts’. Defaults to True.
- Returns:
A copy of the mutations data, self.data.mutations_df, with the mutations column set as the index, and with columns for the mutational attributes (e.g. betas, shifts) and conditional functional score effects (e.g. predicted_func_score_a) added.
The columns are ordered as follows:
- beta_a, beta_b, …: the latent effect of the mutation
- shift_b, shift_c, …: the conditional shift of the mutation
- predicted_func_score_a, predicted_func_score_b, …: the predicted functional score of the mutation.
- Return type:
pandas.DataFrame
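The times_seen_threshold filter keeps a mutation only if it was observed at least that many times in every condition. The sketch below applies that rule to the times_seen counts from the example mutations_df, using plain Python lists in place of pandas; filter_times_seen is a hypothetical helper, not part of multidms.

```python
# (mutation, times_seen_a, times_seen_b) from the example mutations_df
mutations = [
    ("M1E", 1, 3),
    ("M1W", 1, 0),
    ("G3P", 1, 4),
    ("G3R", 1, 2),
]


def filter_times_seen(rows, threshold):
    """Keep mutations seen at least `threshold` times in every condition."""
    return [mut for mut, *counts in rows if all(c >= threshold for c in counts)]


print(filter_times_seen(mutations, 0))  # all four mutations
print(filter_times_seen(mutations, 1))  # ['M1E', 'G3P', 'G3R']
```

With a threshold of 1, M1W is dropped because it was never seen in condition b.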
- get_df_loss(df, error_if_unknown=False, verbose=False, conditional=False)
Get the loss of the model on a given data frame.
- Parameters:
df (pandas.DataFrame) – Data frame containing variants. Requirements are the same as those used to initialize the multidms.Data object - except the indices must be unique.
error_if_unknown (bool) – If some of the substitutions in a variant are not present in the model (not in AbstractEpistasis.binarymap), then by default those variants are not included in the loss calculation. If True, raise an error instead.
verbose (bool) – If True, print the number of valid and invalid variants.
conditional (bool) – If True, return the loss for each condition as a dictionary. If False, return the total loss.
- Returns:
The loss of the model on the given data frame.
- Return type:
float or dict
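With conditional=True the loss is reported per condition instead of as a single number. The sketch assumes a sum-of-squared-errors loss and a total equal to the sum of the per-condition losses; multidms's actual loss function and how it combines conditions may differ, and df_loss is a hypothetical helper operating on (condition, observed, predicted) tuples.

```python
from collections import defaultdict


def df_loss(rows, conditional=False):
    """Sum of squared errors, optionally split by condition (assumed loss)."""
    per_condition = defaultdict(float)
    for condition, observed, predicted in rows:
        per_condition[condition] += (observed - predicted) ** 2
    if conditional:
        return dict(per_condition)
    return sum(per_condition.values())


rows = [("a", 2.0, 1.0), ("a", -1.0, -1.0), ("b", 0.5, 0.0)]
print(df_loss(rows))                    # 1.25
print(df_loss(rows, conditional=True))  # {'a': 1.0, 'b': 0.25}
```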
- add_phenotypes_to_df(df, substitutions_col='aa_substitutions', condition_col='condition', latent_phenotype_col='predicted_latent', observed_phenotype_col='predicted_func_score', converted_substitutions_col='aa_subs_wrt_ref', overwrite_cols=False, unknown_as_nan=False, phenotype_as_effect=True)
Add predicted phenotypes to data frame of variants.
- Parameters:
df (pandas.DataFrame) – Data frame containing variants. Requirements are the same as those used to initialize the multidms.Data object - except the indices must be unique.
substitutions_col (str) – Column in df giving variants as substitution strings with respect to a given variant's condition. These will be converted to be with respect to the reference sequence prior to prediction. Defaults to ‘aa_substitutions’.
condition_col (str) – Column in df giving the condition in which a variant was observed. Values must exist in self.data.conditions, and an error will be raised otherwise. Defaults to ‘condition’.
latent_phenotype_col (str) – Column added to df containing predicted latent phenotypes.
observed_phenotype_col (str) – Column added to df containing predicted observed phenotypes.
converted_substitutions_col (str or None) – Column added to df containing converted substitution strings for non-reference conditions if they do not share a wildtype seq.
overwrite_cols (bool) – If the specified latent or observed phenotype columns already exist in df, overwrite them if True. If False, raise an error.
unknown_as_nan (bool) – If some of the substitutions in a variant are not present in the model (not in AbstractEpistasis.binarymap), set the phenotypes to NaN (not a number) if True. If False, raise an error.
phenotype_as_effect (bool) – If True, phenotypes (both latent and func_score) are calculated as the difference between the predicted phenotype of a given variant and the respective experimental wildtype prediction. Otherwise, report the unmodified model prediction.
- Returns:
A copy of df with the phenotypes added. Phenotypes are predicted based on the current state of the model.
- Return type:
pandas.DataFrame
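The unknown_as_nan switch decides what happens to variants containing substitutions that the model has no parameters for. A minimal sketch of that decision, assuming the substitutions are already expressed relative to the reference; predict_or_nan and the toy additive predictor are hypothetical stand-ins for the model.

```python
import math

# Mutations present in the example model.
KNOWN_MUTATIONS = {"M1E", "M1W", "G3P", "G3R"}


def predict_or_nan(aa_subs, predict, unknown_as_nan=False):
    """Predict one variant, handling substitutions the model has never seen."""
    muts = aa_subs.split()
    unknown = [m for m in muts if m not in KNOWN_MUTATIONS]
    if unknown:
        if unknown_as_nan:
            return math.nan  # mark the phenotype as missing
        raise ValueError(f"substitutions not in the model: {unknown}")
    return predict(muts)


# Toy additive predictor standing in for the model (hypothetical effects).
EFFECTS = {"M1E": 0.5, "M1W": -1.0, "G3P": -0.25, "G3R": 0.25}


def toy_predict(muts):
    return sum(EFFECTS[m] for m in muts)


print(predict_or_nan("M1E G3R", toy_predict))                   # 0.75
print(predict_or_nan("M1A", toy_predict, unknown_as_nan=True))  # nan
```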
- mutation_site_summary_df(agg_func='mean', **kwargs)
Get all single mutational attributes from self._data, updated with all model-specific attributes, then aggregate all numerical columns by “sites”.
- Parameters:
agg_func (str) – Aggregation function to use on the numerical columns. Defaults to “mean”.
**kwargs – Additional keyword arguments to pass to get_mutations_df.
- Returns:
A summary of the mutation attributes aggregated by site.
- Return type:
pandas.DataFrame
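mutation_site_summary_df groups the per-mutation table by site and aggregates the numeric columns. The sketch below mimics that with the standard library on the times_seen counts from the example mutations_df; summarize_by_site is hypothetical, and it takes an aggregation function (e.g. statistics.mean) where multidms takes a string such as 'mean'.

```python
from collections import defaultdict
from statistics import mean

# (site, times_seen_a, times_seen_b) for each mutation in the example mutations_df
rows = [(1, 1, 3), (1, 1, 0), (3, 1, 4), (3, 1, 2)]


def summarize_by_site(rows, agg=mean):
    """Group numeric columns by site and aggregate each column with `agg`."""
    by_site = defaultdict(list)
    for site, *values in rows:
        by_site[site].append(values)
    return {
        site: [agg(column) for column in zip(*values)]
        for site, values in sorted(by_site.items())
    }


print(summarize_by_site(rows))       # site 1 -> [1, 1.5], site 3 -> [1, 3]
print(summarize_by_site(rows, max))  # site 1 -> [1, 3], site 3 -> [1, 4]
```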
- get_condition_params(condition=None)
Get the relevant parameters for a model prediction.
- phenotype_fromsubs(aa_subs, condition=None)
Take a single string of substitutions that have not yet been converted with respect to the reference, convert them, then make a functional score prediction and return the result.
- latent_fromsubs(aa_subs, condition=None)
Take a single string of substitutions that have not yet been converted with respect to the reference, convert them, then make a latent prediction and return the result.
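Before predicting, phenotype_fromsubs and latent_fromsubs convert substitutions written relative to a condition's own wildtype into substitutions relative to the reference. The sketch below reproduces the var_wrt_ref column of the example variants table; the wildtype dictionaries are inferred from that example, and convert_wrt_reference is a hypothetical helper, not the routine multidms uses internally.

```python
# Wildtype identities inferred from the example: the reference condition 'a'
# has M at site 1 and G at site 3, while condition 'b' carries P at site 3.
REFERENCE_WT = {1: "M", 3: "G"}
NONIDENTICAL = {"a": {}, "b": {3: "P"}}


def convert_wrt_reference(aa_subs, condition):
    """Re-express a condition's substitutions relative to the reference wildtype."""
    # Start from the condition's own wildtype differences (insertion order kept).
    genotype = dict(NONIDENTICAL[condition])
    for sub in aa_subs.split():
        # 'P3R' -> site 3, mutant 'R'; the leading wildtype letter is not checked.
        site, mutant = int(sub[1:-1]), sub[-1]
        genotype[site] = mutant
    # Emit only sites that still differ from the reference.
    return " ".join(
        f"{REFERENCE_WT[site]}{site}{aa}"
        for site, aa in genotype.items()
        if aa != REFERENCE_WT[site]
    )


for condition, subs in [("a", "M1E"), ("b", "M1E"), ("b", "P3R"),
                        ("b", "P3G"), ("b", "M1E P3R")]:
    print(condition, repr(subs), "->", repr(convert_wrt_reference(subs, condition)))
```

For example, P3G in condition b reverts site 3 to the reference residue, so the converted string is empty, matching row 6 of the example table.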
- phenotype_frombinary(X, condition=None)
Condition specific functional score prediction on X using the biophysical model given current model parameters.
- Parameters:
X (jnp.array) – Binary encoded variants to make predictions on.
condition (str) – Condition to make predictions for. If None, use the reference condition.
- latent_frombinary(X, condition=None)
Condition specific latent phenotype prediction on X using the biophysical model given current model parameters.
- Parameters:
X (jnp.array) – Binary encoded variants to make predictions on.
condition (str) – Condition to make predictions for. If None, use the reference condition.
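The binary-input methods above take one-hot encoded variants. The sketch below assumes an additive latent phenotype: a global offset (beta_naught) plus an optional per-condition offset (alpha) plus the sum of (beta + shift) over the mutations a variant carries, followed by an assumed sigmoid of the form scale · sigmoid(z) + bias. The functions sketch_latent and sketch_phenotype are hypothetical, use plain lists instead of jnp arrays, and are not the multidms implementation.

```python
import math


def sketch_latent(X, beta_naught, beta, shift=None, alpha=0.0):
    """Latent phenotype: offsets plus (beta + shift) summed over present mutations."""
    if shift is None:
        shift = [0.0] * len(beta)  # the reference condition has no shifts
    return [
        beta_naught + alpha + sum(x * (b + s) for x, b, s in zip(row, beta, shift))
        for row in X
    ]


def sketch_phenotype(X, beta_naught, beta, shift=None, alpha=0.0,
                     scale=5.0, bias=-5.0):
    """Pass the latent phenotype through an assumed scale * sigmoid(z) + bias."""
    return [
        scale / (1.0 + math.exp(-z)) + bias
        for z in sketch_latent(X, beta_naught, beta, shift, alpha)
    ]


# Columns are one-hot indicators for M1E, M1W, G3P, G3R (the example mutations).
X = [[0, 0, 0, 0],   # no mutations relative to the reference
     [1, 0, 0, 1]]   # variant carrying M1E and G3R
beta = [0.5, -1.0, -0.25, 0.25]  # hypothetical latent effects
shift_b = [1.0, 0.0, 0.0, 0.0]   # hypothetical shifts for condition 'b'

print(sketch_latent(X, 0.0, beta))           # [0.0, 0.75]
print(sketch_latent(X, 0.0, beta, shift_b))  # [0.0, 1.75]
print(sketch_phenotype(X, 0.0, beta)[0])     # -2.5
```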
- fit(scale_coeff_lasso_shift=1e-05, tol=0.0001, maxiter=1000, maxls=15, acceleration=True, lock_params={}, admm_niter=50, admm_tau=1.0, warn_unconverged=True, upper_bound_ge_scale='infer', convergence_trajectory_resolution=10, **kwargs)
Use jaxopt.ProximalGradient to optimize the model’s free parameters.
- Parameters:
scale_coeff_lasso_shift (float) – L1 penalty coefficient applied to the “shift” in beta_d parameters. Defaults to 1e-5. This parameter is used to regularize the shift parameters in the model if there is more than one condition.
tol (float) – Tolerance for the optimization convergence criteria. Defaults to 1e-4.
maxiter (int) – Maximum number of iterations for the optimization. Defaults to 1000.
maxls (int) – Maximum number of iterations to perform during line search. Defaults to 15.
acceleration (bool) – If True, use FISTA acceleration. Defaults to True.
lock_params (dict) – Dictionary mapping parameter names to the values at which they should be held fixed during optimization. By default, no parameters are locked.
admm_niter (int) – Number of iterations to perform during the ADMM optimization. Defaults to 50. Note that for single-condition models this is set to zero, as the generalized lasso ADMM optimization is not used.
admm_tau (float) – ADMM step size. Defaults to 1.0.
warn_unconverged (bool) – If True, raise a warning if the optimization does not converge. Convergence is defined by whether the model tolerance (tol) threshold was passed during the optimization process. Defaults to True.
upper_bound_ge_scale (float, None, or 'infer') – The positive upper bound of the theta scale parameter; negative values are not allowed. Passing None allows the scale of the sigmoid to be unconstrained. Passing the string literal ‘infer’ results in the bound being set to double the range of the training data. Defaults to ‘infer’.
convergence_trajectory_resolution (int) – The resolution of the loss and error trajectory recorded during optimization. Defaults to 10.
**kwargs (dict) – Additional keyword arguments passed to the objective function. See the multidms.biophysical.smooth_objective docstring for details on the other hyperparameters that may be supplied to regularize and otherwise modify the objective function being optimized.
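Two pieces of fit() lend themselves to a short sketch. The first computes the 'infer' upper bound as twice the range of the training functional scores, assuming "range" means max minus min of the func_score column. The second is the textbook soft-thresholding step, the proximal operator of an L1 penalty such as the one scale_coeff_lasso_shift controls. multidms handles its shift penalty with its own ADMM-based step (see admm_niter and admm_tau), so soft_threshold only illustrates the building block; both helper names are hypothetical.

```python
def infer_upper_bound(func_scores):
    """Assumed 'infer' rule: twice the range (max - min) of the training scores."""
    return 2.0 * (max(func_scores) - min(func_scores))


def soft_threshold(value, threshold):
    """Proximal operator of threshold * |x| (threshold = step size * L1 coefficient).

    Values inside [-threshold, threshold] become exactly zero; larger values
    are shrunk toward zero by `threshold`.
    """
    if value > threshold:
        return value - threshold
    if value < -threshold:
        return value + threshold
    return 0.0


scores = [2.0, -7.0, -0.5, 2.3, 1.0, -5.0, 0.4, 2.7, -2.7]  # example func_score column
print(infer_upper_bound(scores))  # approximately 19.4
print([soft_threshold(s, 0.1) for s in (0.5, -0.05, -0.3)])
```

The exact zeros produced by soft-thresholding are what make an L1 penalty useful for shifts: small conditional shifts are driven to zero rather than merely shrunk.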
- plot_pred_accuracy(hue=True, show=True, saveas=None, annotate_corr=True, ax=None, r=2, **kwargs)
Create a figure that visualizes the correlation between the model-predicted functional scores of all variants in the training set and the ground-truth measurements.
- plot_epistasis(hue=True, show=True, saveas=None, ax=None, sample=1.0, **kwargs)
Plot latent predictions against gamma corrected ground truth measurements of all samples in the training set.
- plot_param_hist(param, show=True, saveas=False, times_seen_threshold=0, ax=None, **kwargs)
Plot the histogram of a parameter.
- plot_param_heatmap(param, show=True, saveas=False, times_seen_threshold=0, ax=None, **kwargs)
Plot a heatmap of a parameter's values associated with specific sites and substitutions.
- plot_shifts_by_site(condition, show=True, saveas=False, times_seen_threshold=0, agg_func='mean', ax=None, **kwargs)
Summarize shift parameter values by associated sites and conditions.
- mut_param_heatmap(mut_param='shift', times_seen_threshold=0, phenotype_as_effect=True, **line_and_heat_kwargs)
Wrapper method for visualizing the shift plot. See multidms.plot.mut_shift_plot() for more.