biophysical
Defines functions for jointly modeling global epistasis biophysical models, as well as the (“private”) objective functions required for parameter optimization in multidms.model.
multidms.Model objects are defined with parameters that take references to functions (such as the ones defined here) as arguments. The object then handles:
parameter initialization and bookkeeping,
composition of the provided functions using functools.partial,
the subsequent jit-compilation, using jax.jit, of the composed model and objective functions,
as well as the optimization of the model parameters.
This allows the user/developer to define their own model components for each of the latent phenotype, global epistasis, and output activation functions.
This module is kept separate from multidms.model primarily for the sake of readability and documentation. This may change in the future.
Note
In order to make use of the jax library, all components must be written in a functional style. This means that all functions must be written to take in all parameters as arguments and return a value with no side effects. See the jax documentation for more details.
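As a minimal illustration of this functional style, and of composition with functools.partial, the sketch below chains three pure functions in plain Python. The names `latent`, `epistasis`, `activation`, and `compose`, as well as the parameter keys, are hypothetical stand-ins and are not part of multidms; the real components operate on jax arrays and are composed and jit-compiled by multidms.Model.

```python
from functools import partial
import math

# Hypothetical stand-ins for the three component types; none of these
# names or dictionary keys come from multidms.
def latent(params, x):
    # Pure function: everything it needs arrives as an argument.
    return params["beta0"] + sum(b * xi for b, xi in zip(params["beta"], x))

def epistasis(params, z):
    return params["scale"] / (1.0 + math.exp(-z)) + params["bias"]

def activation(params, y):
    return y  # identity output activation

def compose(latent_fn, epistasis_fn, activation_fn, params, x):
    # Chain the components; no global state is read or written.
    return activation_fn(params, epistasis_fn(params, latent_fn(params, x)))

# partial fixes the component functions, leaving (params, x) free.
model = partial(compose, latent, epistasis, activation)

params = {"beta0": 0.0, "beta": [1.0, -1.0], "scale": 2.0, "bias": -1.0}
prediction = model(params, [1, 1])  # latent phenotype is 0.0 for this variant
```

Because every function depends only on its arguments, swapping in a different global epistasis or activation function only changes the arguments given to `partial`.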
- multidms.biophysical.additive_model(d_params: dict, X_d: array)
Model for predicting the latent phenotype of a set of binary encoded variants \(v\) for a given condition \(d\), given the corresponding beta ( \(\beta\) ), shift ( \(\Delta_{d, m}\) ), and latent offset ( \(\beta_0, \alpha_d\) ) parameters.
\[\phi_d(v) = \beta_0 + \alpha_d + \sum_{m \in v} (\beta_{m} + \Delta_{d, m})\]
- Parameters:
d_params (dict) – Dictionary of model-defining parameters as jax arrays. Note that the shape of the parameters must be compatible with the input data.
X_d (array-like) – Binary encoded mutations for a given set of variants from condition \(d\).
- Returns:
Predicted latent phenotypes for each row in X_d.
- Return type:
jnp.array
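To make the additive formula concrete, here is a small plain-Python sketch that evaluates \(\phi_d(v)\) for two variants. It passes the parameters as separate arguments rather than through `d_params`, since the dictionary keys are not listed here; the function name and all numeric values are illustrative assumptions, and the library itself works on jax arrays.

```python
def additive_latent_phenotype(beta0, alpha_d, beta, shift_d, X_d):
    """phi_d(v) = beta0 + alpha_d + sum over mutations m in v of (beta_m + shift_{d,m})."""
    # Effective per-mutation effect in condition d.
    effects = [b + s for b, s in zip(beta, shift_d)]
    # Each row of X_d is a 0/1 encoding of which mutations a variant carries.
    return [
        beta0 + alpha_d + sum(e * x for e, x in zip(effects, row))
        for row in X_d
    ]

beta = [0.5, -1.0, 2.0]      # illustrative mutation effects
shift_d = [0.0, 0.25, 0.0]   # illustrative shifts for condition d
X_d = [
    [0, 0, 0],  # wild type: no mutations
    [1, 1, 0],  # carries mutations 1 and 2
]
phi = additive_latent_phenotype(1.0, 0.0, beta, shift_d, X_d)
# phi == [1.0, 0.75]
```

The wild-type row reduces to \(\beta_0 + \alpha_d\), and the second row adds \((0.5 + 0.0) + (-1.0 + 0.25)\) to that offset.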
- multidms.biophysical.sigmoidal_global_epistasis(theta: dict, z_d: array)
A flexible sigmoid function for modeling global epistasis. This function takes a set of latent phenotype values, \(z_d\), and computes the predicted functional scores using the scaling parameters \(\theta_{\text{scale}}\) and \(\theta_{\text{bias}}\) such that:
\[g(z) = \frac{\theta_{\text{scale}}}{1 + e^{-z}} + \theta_{\text{bias}}\]
Note
This function is independent of the experimental condition in which a variant is observed.
- Parameters:
theta (dict) – Dictionary of model defining parameters as jax arrays.
z_d (jnp.array) – Latent phenotype values for a given set of variants
- Returns:
Predicted functional scores for each latent phenotype in z_d.
- Return type:
jnp.array
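The following plain-Python sketch evaluates the sigmoid formula at a few latent values to show that the output ranges between \(\theta_{\text{bias}}\) and \(\theta_{\text{bias}} + \theta_{\text{scale}}\). The function name and parameter values are illustrative assumptions, not the library's implementation.

```python
import math

def sigmoid_ge(theta_scale, theta_bias, z):
    # g(z) = theta_scale / (1 + e^{-z}) + theta_bias
    return theta_scale / (1.0 + math.exp(-z)) + theta_bias

theta_scale, theta_bias = 6.0, -6.0  # illustrative values
scores = [sigmoid_ge(theta_scale, theta_bias, z) for z in (-50.0, 0.0, 50.0)]
# Very negative z approaches theta_bias (-6.0), z = 0 gives the midpoint (-3.0),
# and very positive z approaches theta_bias + theta_scale (0.0).
```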
- multidms.biophysical.softplus_global_epistasis(theta: dict, z_d: array)
A flexible softplus function for modeling global epistasis. This function takes a set of latent phenotype values, \(z_d\), and computes the predicted functional scores such that:
\[g(z) = -\theta_{\text{scale}}\log\left(1+e^{-z}\right) + \theta_{\text{bias}}\]
Note
This function has no natural lower bound; it is therefore recommended that you use this model in conjunction with an output activation such as softplus_activation().
- Parameters:
theta (dict) – Dictionary of model defining parameters as jax arrays.
z_d (jnp.array) – Latent phenotype values for a given set of variants
- Returns:
Predicted functional scores for each latent phenotype in z_d.
- Return type:
jnp.array
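A short plain-Python sketch of the softplus formula illustrates why there is no natural lower bound: the output saturates at \(\theta_{\text{bias}}\) for large \(z\) but decreases roughly linearly as \(z\) becomes very negative. The function name and values are illustrative assumptions.

```python
import math

def softplus_ge(theta_scale, theta_bias, z):
    # g(z) = -theta_scale * log(1 + e^{-z}) + theta_bias
    return -theta_scale * math.log1p(math.exp(-z)) + theta_bias

# With theta_scale = 1 and theta_bias = 0:
low, middle, high = (softplus_ge(1.0, 0.0, z) for z in (-30.0, 0.0, 30.0))
# low is close to -30 (approximately theta_scale * z, unbounded below),
# middle equals -log(2), and high is close to theta_bias = 0 (the upper plateau).
```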
- multidms.biophysical.nn_global_epistasis(theta: dict, z_d: array)
A single-layer neural network for modeling global epistasis. This function takes a set of latent phenotype values, \(z_d\), and computes the predicted functional scores.
For this option, the user defines the number of units in the single hidden layer of the model. For each hidden unit, we introduce three parameters (two weights and a bias) to be inferred. All weights are clipped at zero to maintain the assumption of monotonicity in the shape of the resulting epistasis function. The network applies a sigmoid activation to each hidden unit; a final weighted sum plus a constant then gives the predicted functional score.
More concretely, given latent phenotype \(\phi_d(v) = z\), let
\[g(z) = b^{o} + \sum_{i=1}^{n} \frac{w^{o}_{i}}{1 + e^{-(w^{l}_{i} z + b^{l}_{i})}}\]
Where:
\(n\) is the number of units in the hidden layer.
\(w^{l}_{i}\) and \(w^{o}_{i}\) are free parameters representing the latent and output transformations, respectively, associated with unit \(i\) in the hidden layer of the network.
\(b^{l}_{i}\) is a free parameter, added as a bias term to unit \(i\).
\(b^{o}\) is a single free parameter, a constant output bias.
Note
This is an advanced feature and we advise against its use unless the other options are not sufficiently parameterized for particularly complex experimental conditions.
- Parameters:
theta (dict) – Dictionary of model defining parameters as jax arrays.
z_d (jnp.array) – Latent phenotype values for a given set of variants
- Returns:
Predicted functional scores for each latent phenotype in z_d.
- Return type:
jnp.array
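The sketch below shows, in plain Python, how clipping the weights at zero keeps the network output monotonically non-decreasing in \(z\). It takes "sigmoid activation" to mean the standard \(\sigma(x) = 1/(1 + e^{-x})\) applied to \(w^{l}_{i} z + b^{l}_{i}\), which is an assumption about the implementation; the function name, argument layout, and values are likewise hypothetical.

```python
import math

def nn_ge(w_latent, b_latent, w_out, b_out, z):
    total = b_out
    for wl, bl, wo in zip(w_latent, b_latent, w_out):
        # Clip both weights at zero so each unit can only push g(z) upward.
        wl, wo = max(wl, 0.0), max(wo, 0.0)
        # Standard sigmoid of the unit's pre-activation (assumed form).
        total += wo / (1.0 + math.exp(-(wl * z + bl)))
    return total

# Two hidden units; the negative output weight is clipped to zero,
# so the second unit contributes nothing.
w_latent, b_latent = [1.0, 2.0], [0.0, -1.0]
w_out, b_out = [1.5, -0.7], -2.0

grid = [-4.0, -2.0, 0.0, 2.0, 4.0]
values = [nn_ge(w_latent, b_latent, w_out, b_out, z) for z in grid]
# values never decrease as z increases.
```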
- multidms.biophysical.identity_activation(d_params, act, **kwargs)
Identity function \(f(x)=x\). Mostly a placeholder function that helps compose the model when you don’t want to use any final output activation, e.g. when you don’t have a pre-defined lower bound.
- multidms.biophysical.softplus_activation(d_params, act, lower_bound=-3.5, hinge_scale=0.1, **kwargs)
A modified softplus that hinges at a given lower bound. The rate of change at the hinge is defined by ‘hinge_scale’.
In essence, this is a modified softplus activation (\(\text{softplus}(x)=\log(1 + e^{x})\)) with a lower bound at \(l + \gamma_{h}\), as well as a ramping coefficient, \(\lambda_{\text{sp}}\).
Concretely, if we let \(z' = g(\phi_d(v))\), then the predicted functional score of our model is given by:
\[t(z') = \lambda_{\text{sp}}\log\left(1 + e^{\frac{z' - l}{\lambda_{\text{sp}}}}\right) + l\]
Functionally speaking, this truncates scores below a lower bound, while leaving scores above it (mostly) unaltered. There is a small range of input values where the function smoothly transitions between a flat regime (where data is truncated) and a linear regime (where data is not truncated).
Note
This is derived from https://jax.readthedocs.io/en/latest/_autosummary/jax.nn.softplus.html
- Parameters:
d_params (dict) – Dictionary of model defining parameters as jax arrays.
act (jnp.array) – Activations to apply the softplus function to.
lower_bound (float) – Lower bound to hinge the softplus function at.
hinge_scale (float) – Rate of change at the hinge point.
kwargs (dict) – Additional keyword arguments to pass to the biophysical model function
- Returns:
Transformed activations.
- Return type:
jnp.array
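The truncating behaviour can be checked numerically with a plain-Python sketch of the displayed formula. It assumes that `lower_bound` plays the role of \(l\) and `hinge_scale` the role of \(\lambda_{\text{sp}}\), and it ignores `d_params` entirely, so it covers only the formula as written and not the library function. The name `hinged_softplus` is hypothetical.

```python
import math

def hinged_softplus(z, lower_bound=-3.5, hinge_scale=0.1):
    # t(z) = hinge_scale * log(1 + exp((z - lower_bound) / hinge_scale)) + lower_bound
    x = (z - lower_bound) / hinge_scale
    # Numerically stable softplus: log(1 + e^x) = max(x, 0) + log(1 + e^{-|x|}).
    softplus = max(x, 0.0) + math.log1p(math.exp(-abs(x)))
    return hinge_scale * softplus + lower_bound

far_below = hinged_softplus(-10.0)  # flat regime: close to the lower bound, -3.5
at_hinge = hinged_softplus(-3.5)    # transition: lower_bound + hinge_scale * log(2)
far_above = hinged_softplus(0.0)    # linear regime: close to the input, 0.0
```

A smaller `hinge_scale` narrows the transition region, so the curve approaches a hard clip at the lower bound.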
- multidms.biophysical.proximal_box_constraints(params, hyperparameters, *args, **kwargs)
Proximal operator for box constraints for single condition models.
Note that *args and **kwargs are placeholders for additional arguments that may be passed to this function by the optimizer.
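For background, the proximal operator of a box constraint is a projection onto the box, that is, element-wise clipping of each constrained value to its interval. The plain-Python sketch below illustrates that general idea only; which parameters multidms constrains, and with what bounds, is not stated here, so the function name, parameter names, and bounds are all hypothetical.

```python
def project_onto_box(params, bounds):
    """Clip each bounded parameter to [lower, upper]; leave the others unchanged."""
    projected = dict(params)  # copy, so the input dictionary is not mutated
    for name, (lower, upper) in bounds.items():
        projected[name] = min(max(params[name], lower), upper)
    return projected

# Hypothetical parameters: only "scale" is constrained (to be non-negative).
params = {"scale": -0.3, "bias": 2.0}
bounds = {"scale": (0.0, float("inf"))}
constrained = project_onto_box(params, bounds)
# "scale" is clipped up to 0.0; "bias" has no bounds and is returned as-is.
```

Returning a new dictionary rather than modifying `params` in place keeps the function free of side effects, in line with the functional style required for jax.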
- multidms.biophysical.proximal_objective(Dop, params, hyperparameters, scaling=1.0)
ADMM generalized lasso optimization.
- multidms.biophysical.smooth_objective(f, params, data, scale_coeff_ridge_beta=0.0, scale_coeff_ridge_ge_scale=0.0, scale_coeff_ridge_ge_bias=0.0, huber_scale=1, **kwargs)
Cost (objective) function summed across all conditions.
- Parameters:
f (function) – Biophysical model function
params (dict) – Dictionary of parameters to optimize
data (tuple) – Tuple of (X, y) data, where each element is a dictionary keyed by condition that holds, respectively, the binarymap and the row-associated target functional scores.
huber_scale (float) – Scale parameter for Huber loss function
scale_coeff_ridge_beta (float) – Ridge penalty coefficient for shift parameters
scale_coeff_ridge_ge_scale (float) – Ridge penalty coefficient for global epistasis scale parameter
scale_coeff_ridge_ge_bias (float) – Ridge penalty coefficient for global epistasis bias parameter
kwargs (dict) – Additional keyword arguments to pass to the biophysical model function
- Returns:
loss – Summed loss across all conditions.
- Return type: