model_collection
Contains the ModelCollection class, which takes a collection of models and merges the results for comparison and visualization.
- multidms.model_collection.fit_one_model(dataset, epistatic_model='Sigmoid', output_activation='Identity', init_theta_scale=6.5, init_theta_bias=-3.5, init_beta_variance=1.0, n_hidden_units=5, lower_bound=None, PRNGKey=0, verbose=False, **kwargs)
Fit a multidms model to a dataset. This is a wrapper around the multidms fit method that allows for easy specification of the fit parameters. This method is helpful for comparing and organizing multiple fits.
- Parameters:
dataset (multidms.Data) – The dataset to fit to. For bookkeeping and downstream analysis, the name of the dataset (Data.name) is saved in the fit attributes that are returned.
epistatic_model (str, optional) – The epistatic model to use. The default is “Sigmoid”.
output_activation (str, optional) – The output activation function to use. The default is “Identity”.
init_theta_scale (float, optional) – The scale to use for initializing the model parameters. The default is 6.5.
init_theta_bias (float, optional) – The bias to use for initializing the model parameters. The default is -3.5.
init_beta_variance (float, optional) – The variance to use for initializing the model’s beta parameters from a normal distribution. The default is 1.0.
n_hidden_units (int, optional) – The number of hidden units to use in the neural network model. The default is 5.
lower_bound (float, optional) – The lower bound for use with the softplus activation function. The default is None, but must be specified if using the softplus activation.
PRNGKey (int, optional) – The PRNGKey to use to initialize model parameters. The default is 0.
verbose (bool, optional) – Whether to print out information about the fit to stdout. The default is False.
**kwargs (dict) – Additional keyword arguments to pass to the multidms.Model.fit method.
- Returns:
fit_series – A series containing a reference to the fitted multidms.Model object and the associated parameters used for the fit. These consist mostly of the keyword arguments passed to this function, less “verbose”, with the addition of:
1. “model”: the fitted multidms.Model object reference;
2. “dataset_name”: the name associated with the Data object used for training (the multidms.Data object itself is always accessible via the Model.data attribute);
3. “step_loss”: a numpy array of the loss at the end of each training epoch.
- Return type:
pd.Series
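To make the bookkeeping above concrete, here is a minimal plain-Python sketch of how such a record could be assembled. The helper make_fit_record and the placeholder model string are hypothetical; the real function trains a multidms.Model and returns a pd.Series rather than a dict.

```python
def make_fit_record(model, dataset_name, step_loss, **fit_kwargs):
    """Hypothetical stand-in for the record returned by fit_one_model.

    Keeps the fit keyword arguments except "verbose", then adds the
    "model", "dataset_name", and "step_loss" entries.
    """
    record = {key: value for key, value in fit_kwargs.items() if key != "verbose"}
    record["model"] = model
    record["dataset_name"] = dataset_name
    record["step_loss"] = list(step_loss)  # the real record holds a numpy array
    return record


record = make_fit_record(
    model="<fitted Model placeholder>",
    dataset_name="rep1",
    step_loss=[0.9, 0.5, 0.3],
    epistatic_model="Sigmoid",
    output_activation="Identity",
    verbose=True,
)
print(sorted(record))
# ['dataset_name', 'epistatic_model', 'model', 'output_activation', 'step_loss']
```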
- multidms.model_collection.stack_fit_models(fit_models_list)
Given a list of pd.Series objects returned by fit_one_model, stack them into a single pd.DataFrame.
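A minimal pandas sketch of the stacking step, assuming each fit is a pd.Series keyed by attribute name. stack_series is a hypothetical name, and this is not the library's own implementation.

```python
import pandas as pd


def stack_series(fit_models_list):
    """Stack per-fit pd.Series records into a DataFrame with one row per fit."""
    # Each Series becomes a one-row frame; concat aligns columns by label.
    return pd.concat(
        [series.to_frame().T for series in fit_models_list], ignore_index=True
    )


fits = [
    pd.Series({"dataset_name": "rep1", "scale_coeff_lasso_shift": 1e-5}),
    pd.Series({"dataset_name": "rep2", "scale_coeff_lasso_shift": 1e-5}),
]
stacked = stack_series(fits)
print(stacked.shape)  # (2, 2)
```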
- multidms.model_collection.fit_models(params, n_threads=-1, failures='error')
Fit a collection of multidms.model.Model models. Enables fitting of multiple models simultaneously using multiple threads. Most commonly, this function is used to fit a set of models across combinations of replicate training datasets and lasso coefficients for model selection and evaluation. The returned dataframe is meant to be passed into the multidms.model_collection.ModelCollection class for comparison and visualization.
- Parameters:
params (dict) – Dictionary which defines the parameter space of all models you wish to run. Each value in the dictionary must be a list of values, even in the case of singletons. This function will compute all combinations of the parameter space and pass each combination to multidms.model_collection.fit_one_model() to be run in parallel; thus, only keys that match the keyword arguments of that function are allowed. See the docstring of multidms.model_collection.fit_one_model() for a description of the allowed parameters.
n_threads (int) – Number of threads (CPUs, cores) to use for fitting. Set to -1 to use all CPUs available.
failures ({"error", "tolerate"}) – How to handle a model whose fit fails. If “error”, raise an error; if “tolerate”, return None for models that failed optimization.
- Returns:
Number of models that fit successfully, number of models that failed, and a dataframe which contains a row for each of the multidms.Model object references along with the parameters each was fit with. The dataframe is ultimately meant to be passed into the ModelCollection class for comparison and visualization.
- Return type:
(n_fit, n_failed, fit_models)
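The parameter-space expansion and failure handling described above can be sketched with the standard library. This is an illustrative approximation only: expand_grid, run_grid, and toy_fit are hypothetical names, toy_fit stands in for fit_one_model, and results come back as a list (with None for failed fits) rather than a dataframe.

```python
import itertools
import os
from concurrent.futures import ThreadPoolExecutor


def expand_grid(params):
    """Return one kwargs dict per combination of the list-valued params."""
    for key, values in params.items():
        if not isinstance(values, list):
            raise TypeError(f"value for {key!r} must be a list, even for singletons")
    keys = list(params)
    return [dict(zip(keys, combo)) for combo in itertools.product(*params.values())]


def run_grid(fit_fn, params, n_threads=-1, failures="error"):
    """Call fit_fn on every combination; return (n_fit, n_failed, results)."""
    if failures not in ("error", "tolerate"):
        raise ValueError('failures must be "error" or "tolerate"')
    max_workers = os.cpu_count() if n_threads == -1 else n_threads

    def safe_fit(kwargs):
        try:
            return fit_fn(**kwargs)
        except Exception:
            if failures == "error":
                raise
            return None  # "tolerate": record the failure as None

    with ThreadPoolExecutor(max_workers=max_workers) as pool:
        results = list(pool.map(safe_fit, expand_grid(params)))
    n_failed = sum(result is None for result in results)
    return len(results) - n_failed, n_failed, results


def toy_fit(dataset_name, scale_coeff_lasso_shift):
    """Toy stand-in for a model fit; negative penalties 'fail'."""
    if scale_coeff_lasso_shift < 0:
        raise RuntimeError("optimization failed")
    return {"dataset_name": dataset_name, "lasso": scale_coeff_lasso_shift}


params = {"dataset_name": ["rep1", "rep2"], "scale_coeff_lasso_shift": [1e-5, -1.0]}
n_fit, n_failed, fits = run_grid(toy_fit, params, n_threads=2, failures="tolerate")
print(n_fit, n_failed)  # 2 2
```

With failures="error", the first failing combination would instead propagate its exception out of run_grid.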
- class multidms.model_collection.ModelCollection(fit_models)
Bases: object
A class for the comparison and visualization of multiple multidms.Model fits. The respective collection of training datasets for each fit must share the same reference sequence and conditions. Additionally, the inferred site maps must agree upon condition wildtypes for all shared sites.
The utility function multidms.model_collection.fit_models is used to fit the collection of models, and the resulting dataframe is passed to the constructor of this class.
- Parameters:
fit_models (pandas.DataFrame) – A dataframe containing the fit attributes and pickled model objects as returned by multidms.model_collection.fit_models.
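The compatibility requirements above (shared reference, shared conditions, agreeing condition wildtypes at shared sites) can be expressed as a simple check. In this sketch each dataset is a hypothetical plain dict rather than a multidms.Data object, and check_compatible is not part of the library.

```python
def check_compatible(datasets):
    """Raise ValueError unless the datasets can share one collection.

    Each dataset is a hypothetical dict with "reference", "conditions", and
    "site_map" ({site: {condition: wildtype_aa}}) entries.
    """
    first = datasets[0]
    for data in datasets[1:]:
        if data["reference"] != first["reference"]:
            raise ValueError("datasets do not share the same reference")
        if set(data["conditions"]) != set(first["conditions"]):
            raise ValueError("datasets do not share the same conditions")
    # Only sites present in every site map need to agree on wildtypes.
    shared_sites = set.intersection(*(set(data["site_map"]) for data in datasets))
    for site in sorted(shared_sites):
        wildtypes = {tuple(sorted(data["site_map"][site].items())) for data in datasets}
        if len(wildtypes) > 1:
            raise ValueError(f"condition wildtypes disagree at site {site}")
    return True


rep1 = {"reference": "a", "conditions": ["a", "b"],
        "site_map": {1: {"a": "M", "b": "M"}, 2: {"a": "G", "b": "A"}}}
rep2 = {"reference": "a", "conditions": ["b", "a"],
        "site_map": {1: {"a": "M", "b": "M"}, 3: {"a": "K", "b": "K"}}}
print(check_compatible([rep1, rep2]))  # True
```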
- property reference: str
The reference condition (shared by each fitting dataset) used for fitting.
The mutations shared by each fitting dataset.
- split_apply_combine_muts(groupby=('dataset_name', 'scale_coeff_lasso_shift'), aggregate_func='mean', inner_merge_dataset_muts=True, query=None, **kwargs)
A wrapper to split-apply-combine the set of mutational dataframes harbored by each of the fits in the collection.
Here, we group the collection of fits using attributes (columns in ModelCollection.fit_models) specified using the groupby parameter. Each of the individual fits within a group may then be filtered via **kwargs and aggregated via aggregate_func, before the function stacks all the groups back together in a tall-style dataframe. The resulting dataframe will have a multiindex with the mutation and the groupby attributes.
- Parameters:
groupby (str or tuple of str or None, optional) – The attributes to group the fits by. If None, then group by all attributes except for the model, data, and step_loss attributes. The default is (“dataset_name”, “scale_coeff_lasso_shift”).
aggregate_func (str or callable, optional) – The function to aggregate the mutational dataframes within each group. The default is “mean”.
inner_merge_dataset_muts (bool, optional) – Whether to toss mutations which are _not_ shared across all datasets before aggregation of group mutation parameter values. The default is True.
query (str, optional) – The pandas query to apply to the ModelCollection.fit_models dataframe before splitting. The default is None.
**kwargs (dict) – Keyword arguments to pass to the multidms.Model.get_mutations_df() method (“phenotype_as_effect” and “times_seen_threshold”); see the method docstring for details.
- Returns:
A dataframe containing the aggregated mutational parameter values.
- Return type:
pd.DataFrame
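A rough pandas sketch of the split-apply-combine pattern, assuming the per-fit mutation values have already been collected into one tall dataframe. split_apply_combine and the column layout are assumptions for illustration; the groupby=None case and the **kwargs filtering are not reproduced.

```python
import pandas as pd


def split_apply_combine(muts, groupby=("dataset_name", "scale_coeff_lasso_shift"),
                        aggregate_func="mean", inner_merge_dataset_muts=True,
                        value_col="shift"):
    """Aggregate per-fit mutation values within groups of fits (a sketch).

    muts is a hypothetical tall frame with one row per (fit, mutation) that
    carries the fit attributes next to "mutation" and value_col.
    """
    if inner_merge_dataset_muts:
        # Keep only mutations observed in every dataset of the collection.
        n_datasets = muts["dataset_name"].nunique()
        per_mut = muts.groupby("mutation")["dataset_name"].nunique()
        shared = per_mut[per_mut == n_datasets].index
        muts = muts[muts["mutation"].isin(shared)]
    # Split on mutation + fit attributes, aggregate, and combine into tall form.
    return muts.groupby(["mutation", *groupby])[value_col].agg(aggregate_func)


muts = pd.DataFrame({
    "dataset_name": ["rep1", "rep1", "rep1", "rep1", "rep2", "rep2"],
    "scale_coeff_lasso_shift": [1e-5] * 6,
    "PRNGKey": [0, 0, 1, 1, 0, 0],
    "mutation": ["M1A", "G2*", "M1A", "G2*", "M1A", "K3R"],
    "shift": [0.2, 0.0, 0.4, 0.0, 0.6, 1.0],
})
print(split_apply_combine(muts))
```

Here the two rep1 fits (differing only in PRNGKey) are averaged, and G2* and K3R are dropped because neither appears in both datasets.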
- add_validation_loss(test_data, overwrite=False)
Add validation loss to the fit collection dataframe.
- Parameters:
test_data (pd.DataFrame or dict(str, pd.DataFrame)) – The testing dataframe to compute validation loss with respect to; it must have columns “aa_substitutions”, “condition”, and “func_score”. If a dictionary is passed, there should be a key for each unique dataset_name factor in the self.fit_models dataframe, with the value being the respective testing dataframe.
overwrite (bool, optional) – Whether to overwrite the validation_loss column if it already exists. The default is False.
- Returns:
The self.fit_models dataframe with the validation loss added.
- Return type:
pd.DataFrame
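The two accepted forms of test_data can be normalized into one mapping before any loss is computed. This sketch covers only that dispatch and the required-column check; resolve_test_data is a hypothetical helper, and the validation loss computation itself is left to the library.

```python
import pandas as pd

REQUIRED_COLUMNS = {"aa_substitutions", "condition", "func_score"}


def resolve_test_data(test_data, dataset_names):
    """Map each dataset_name to the testing frame used for its validation loss."""
    if isinstance(test_data, dict):
        missing = set(dataset_names) - set(test_data)
        if missing:
            raise KeyError(f"no testing data for datasets: {sorted(missing)}")
        resolved = {name: test_data[name] for name in dataset_names}
    else:
        # A single frame is shared by every fit in the collection.
        resolved = {name: test_data for name in dataset_names}
    for name, frame in resolved.items():
        absent = REQUIRED_COLUMNS - set(frame.columns)
        if absent:
            raise ValueError(f"testing data for {name!r} lacks {sorted(absent)}")
    return resolved


test_df = pd.DataFrame({
    "aa_substitutions": ["M1A", ""],
    "condition": ["a", "a"],
    "func_score": [-0.5, 0.0],
})
resolved = resolve_test_data(test_df, ["rep1", "rep2"])
print(sorted(resolved))  # ['rep1', 'rep2']
```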
- get_conditional_loss_df(query=None)
Return a long-form dataframe with columns “dataset_name”, “scale_coeff_lasso_shift”, “split” (“training” or “validation”), “loss” (actual value), and “condition”.
- Parameters:
query (str, optional) – The query to apply to the fit_models dataframe before formatting the loss dataframe. The default is None.
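To show the intended long-form shape, the sketch below flattens per-condition losses using plain dicts. The input keys training_loss and validation_loss are assumptions about how per-condition losses might be stored; only the output columns come from the description above.

```python
def conditional_loss_long(fits):
    """Flatten per-fit, per-condition losses into long-form rows.

    Each fit is a hypothetical dict holding "dataset_name",
    "scale_coeff_lasso_shift", and "training_loss"/"validation_loss"
    mappings of {condition: loss}.
    """
    rows = []
    for fit in fits:
        for split in ("training", "validation"):
            for condition, loss in fit[f"{split}_loss"].items():
                rows.append({
                    "dataset_name": fit["dataset_name"],
                    "scale_coeff_lasso_shift": fit["scale_coeff_lasso_shift"],
                    "split": split,
                    "loss": loss,
                    "condition": condition,
                })
    return rows


fits = [{
    "dataset_name": "rep1",
    "scale_coeff_lasso_shift": 1e-5,
    "training_loss": {"a": 0.10, "b": 0.12},
    "validation_loss": {"a": 0.15, "b": 0.20},
}]
rows = conditional_loss_long(fits)
print(len(rows))  # 4
```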
- convergence_trajectory_df(query=None, id_vars=('dataset_name', 'scale_coeff_lasso_shift'))
Combine the convergence trajectory dataframes of all fits in the queried collection.
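A similar long-form reshaping applies to convergence trajectories. This sketch assumes each fit carries the step_loss array described under fit_one_model; the trajectory_long helper and its "step" column name are hypothetical.

```python
def trajectory_long(fits, id_vars=("dataset_name", "scale_coeff_lasso_shift")):
    """Stack each fit's step_loss into long-form rows keyed by id_vars."""
    rows = []
    for fit in fits:
        for step, loss in enumerate(fit["step_loss"]):
            row = {var: fit[var] for var in id_vars}
            row.update(step=step, loss=loss)
            rows.append(row)
    return rows


fits = [
    {"dataset_name": "rep1", "scale_coeff_lasso_shift": 1e-5, "step_loss": [1.0, 0.6, 0.4]},
    {"dataset_name": "rep2", "scale_coeff_lasso_shift": 1e-5, "step_loss": [1.1, 0.7]},
]
rows = trajectory_long(fits)
print(len(rows))  # 5
```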
- mut_param_heatmap(query=None, mut_param='shift', aggregate_func='mean', inner_merge_dataset_muts=True, times_seen_threshold=0, phenotype_as_effect=True, **kwargs)
Create a lineplot and heatmap altair chart across replicate datasets. This function optionally applies a given pandas query on the fit_models dataframe that should result in a subset of fits which make sense to aggregate mutational data across, e.g. replicate datasets. It then computes the mean or median mutational parameter value (“beta”, “shift”, or “predicted_func_score”) across the remaining fits and creates an interactive altair chart.
Note that this will throw an error if the queried fits have more than one unique value for any hyperparameter besides “dataset_name”.
- Parameters:
query (str) – The query to apply to the fit_models dataframe. This should be used to subset the fits to only those which make sense to aggregate mutational data across, e.g. replicate datasets. For example, if you have a collection of fits with different epistatic models, you may want to query for only those fits with the same epistatic model, e.g. query="epistatic_model == 'Sigmoid'". For more on the query syntax, see the pandas.DataFrame.query documentation.
mut_param (str, optional) – The mutational parameter to plot. The default is “shift”. Must be one of “shift”, “predicted_func_score”, or “beta”.
aggregate_func (str, optional) – The function to aggregate the mutational parameter values between dataset fits. The default is “mean”.
inner_merge_dataset_muts (bool, optional) – Whether to toss mutations which are _not_ shared across all datasets before aggregation of group mutation parameter values. The default is True.
times_seen_threshold (int, optional) – The minimum number of times a mutation must be seen across all conditions within a single fit to be included in the aggregation. The default is 0.
phenotype_as_effect (bool, optional) – Passed to Model.get_mutations_df(). Only applies if mut_param="predicted_func_score". The default is True.
**kwargs (dict) – Keyword arguments to pass to multidms.plot._lineplot_and_heatmap().
- Returns:
A chart object which can be displayed in a jupyter notebook or saved to a file.
- Return type:
altair.Chart
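The restriction on queried fits can be checked before aggregating. In this sketch the set of bookkeeping columns allowed to vary (NON_HYPERPARAMETERS) is an assumption, loosely following the attributes excluded by split_apply_combine_muts when groupby is None plus the validation_loss column added by add_validation_loss; the list comprehension stands in for a pandas query.

```python
# Bookkeeping columns that are expected to differ between fits (assumed).
NON_HYPERPARAMETERS = {"dataset_name", "model", "data", "step_loss", "validation_loss"}


def check_aggregatable(fits):
    """Raise ValueError if the fits differ in anything but dataset_name."""
    keys = set().union(*(fit.keys() for fit in fits)) - NON_HYPERPARAMETERS
    # repr() lets unhashable values (lists, arrays) be compared via a set.
    varying = sorted(
        key for key in keys if len({repr(fit.get(key)) for fit in fits}) > 1
    )
    if varying:
        raise ValueError(f"queried fits differ in hyperparameters: {varying}")
    return True


fits = [
    {"dataset_name": "rep1", "epistatic_model": "Sigmoid", "scale_coeff_lasso_shift": 1e-5},
    {"dataset_name": "rep2", "epistatic_model": "Sigmoid", "scale_coeff_lasso_shift": 1e-5},
    {"dataset_name": "rep1", "epistatic_model": "Identity", "scale_coeff_lasso_shift": 1e-5},
]
# Plain-Python analogue of query="epistatic_model == 'Sigmoid'".
sigmoid_fits = [fit for fit in fits if fit["epistatic_model"] == "Sigmoid"]
print(check_aggregatable(sigmoid_fits))  # True
```

Passing all three fits instead would raise, since epistatic_model then takes two values.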
- mut_param_traceplot(mutations, mut_param='shift', x='scale_coeff_lasso_shift', width_scalar=100, height_scalar=100, **kwargs)
Visualize mutation parameter values across the lasso penalty weights (by default) for a given subset of the mutations, in the form of an altair.FacetChart. This is useful for confirming that a reported mutational parameter value carries through across the individual fits.
- Returns:
A chart object which can be displayed in a jupyter notebook or saved to a file.
- Return type:
altair.Chart
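The data behind such a trace plot is essentially one (penalty, value) series per mutation and dataset. A plain-Python sketch follows, with trace_table and the per-fit dict layout as hypothetical stand-ins for the library's internals.

```python
from collections import defaultdict


def trace_table(fits, mutations, mut_param="shift", x="scale_coeff_lasso_shift"):
    """Collect (x, value) points per (mutation, dataset_name), sorted by x.

    Each fit is a hypothetical dict with "dataset_name", the x attribute, and
    a mut_param mapping of {mutation: value}.
    """
    traces = defaultdict(list)
    for fit in fits:
        for mutation in mutations:
            if mutation in fit[mut_param]:
                traces[(mutation, fit["dataset_name"])].append(
                    (fit[x], fit[mut_param][mutation])
                )
    return {key: sorted(points) for key, points in traces.items()}


fits = [
    {"dataset_name": "rep1", "scale_coeff_lasso_shift": 1e-4, "shift": {"M1A": 0.0}},
    {"dataset_name": "rep1", "scale_coeff_lasso_shift": 1e-6, "shift": {"M1A": 0.8}},
    {"dataset_name": "rep1", "scale_coeff_lasso_shift": 1e-5, "shift": {"M1A": 0.5}},
]
print(trace_table(fits, ["M1A"]))
```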
- shift_sparsity(x='scale_coeff_lasso_shift', width_scalar=100, height_scalar=100, return_data=False, **kwargs)
Visualize shift parameter set sparsity across the lasso penalty weights (by default) in the form of an altair.FacetChart. Mutations are grouped according to their status as either a “stop” (e.g. A15*) or “nonsynonymous” (e.g. A15G) mutation before the sparsity is calculated. Mutations to stop codons serve, in a sense, as a gauge of the false positive rate: we expect their mutational effect to be equally deleterious in all experiments, and thus their shift parameter value to be zero.
- Returns:
A chart object which can be displayed in a jupyter notebook or saved to a file. If return_data=True, then a tuple containing the chart and the underlying data will be returned.
- Return type:
altair.Chart or Tuple(pd.DataFrame, altair.Chart)
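A minimal sketch of the grouping and sparsity calculation, assuming sparsity means the fraction of shift values that are exactly zero (the section does not define it precisely). mutation_type and shift_sparsity_table are hypothetical helpers, and synonymous mutations are not split out here.

```python
from collections import defaultdict


def mutation_type(mutation):
    """Classify e.g. "A15*" as a stop and "A15G" as a nonsynonymous mutation."""
    return "stop" if mutation.endswith("*") else "nonsynonymous"


def shift_sparsity_table(fits, x="scale_coeff_lasso_shift"):
    """Fraction of exactly-zero shifts per (x value, mutation type)."""
    counts = defaultdict(lambda: [0, 0])  # key -> [n_zero, n_total]
    for fit in fits:
        for mutation, shift in fit["shift"].items():
            key = (fit[x], mutation_type(mutation))
            counts[key][0] += int(shift == 0)
            counts[key][1] += 1
    return {key: n_zero / n_total for key, (n_zero, n_total) in counts.items()}


fits = [
    {"scale_coeff_lasso_shift": 1e-5,
     "shift": {"A15*": 0.0, "G2*": 0.1, "A15G": 0.3, "K3R": 0.0}},
    {"scale_coeff_lasso_shift": 1e-4,
     "shift": {"A15*": 0.0, "G2*": 0.0, "A15G": 0.2, "K3R": 0.0}},
]
print(shift_sparsity_table(fits))
```

In this toy data, raising the penalty drives every stop-codon shift to zero, which is the behavior one would hope to see if the false positive rate is low.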
- mut_param_dataset_correlation(x='scale_coeff_lasso_shift', width_scalar=200, height=200, return_data=False, r=2, **kwargs)
Visualize the correlation between replicate datasets across the lasso penalty weights (by default) in the form of an altair.FacetChart. We compute the correlation of mutation parameters across each pair of datasets in the collection.
- Parameters:
x (str, optional) – The parameter to plot on the x-axis. The default is “scale_coeff_lasso_shift”.
width_scalar (int, optional) – The width of the chart. The default is 200.
height (int, optional) – The height of the chart. The default is 200.
return_data (bool, optional) – Whether to return the underlying data. The default is False.
r (int, optional) – The exponent applied to the reported correlation coefficient. May be either 1 for Pearson's r or 2 for the coefficient of determination (r-squared). The default is 2.
**kwargs (dict) – The keyword arguments to pass to the multidms.model_collection.ModelCollection.split_apply_combine_muts() method. See the method docstring for details.
- Returns:
A chart object which can be displayed in a jupyter notebook or saved to a file. If return_data=True, then a tuple containing the chart and the underlying data will be returned.
- Return type:
altair.Chart or Tuple(altair.Chart, pd.DataFrame)
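The pairwise comparison can be sketched with the standard library. dataset_correlations and its input layout are hypothetical; the sketch compares only mutations shared by both datasets of a pair and raises the Pearson coefficient to the power r, matching the description of the r parameter.

```python
import itertools
import math


def pearson(xs, ys):
    """Pearson correlation coefficient of two equal-length sequences."""
    mean_x, mean_y = sum(xs) / len(xs), sum(ys) / len(ys)
    cov = sum((a - mean_x) * (b - mean_y) for a, b in zip(xs, ys))
    sd_x = math.sqrt(sum((a - mean_x) ** 2 for a in xs))
    sd_y = math.sqrt(sum((b - mean_y) ** 2 for b in ys))
    return cov / (sd_x * sd_y)


def dataset_correlations(values, r=2):
    """Correlate mutation parameters between every pair of datasets.

    values maps dataset_name -> {mutation: parameter value} for one lasso
    weight; only mutations shared by both datasets of a pair are compared.
    """
    if r not in (1, 2):
        raise ValueError("r must be 1 (Pearson r) or 2 (r-squared)")
    result = {}
    for first, second in itertools.combinations(sorted(values), 2):
        shared = sorted(set(values[first]) & set(values[second]))
        xs = [values[first][m] for m in shared]
        ys = [values[second][m] for m in shared]
        result[(first, second)] = pearson(xs, ys) ** r
    return result


values = {
    "rep1": {"M1A": 1.0, "G2*": 2.0, "K3R": 3.0},
    "rep2": {"M1A": 2.0, "G2*": 4.0, "K3R": 6.0, "L4P": 9.0},
    "rep3": {"M1A": 3.0, "G2*": 2.0, "K3R": 1.0},
}
print(dataset_correlations(values, r=1))
```

With r=2, the perfectly anti-correlated pairs above would report 1.0, so r=1 is the setting that preserves the sign of the relationship.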