biophysical

Defines the functions used to jointly fit global epistasis biophysical models, as well as the (“private”) objective functions required for parameter optimization in multidms.model.

multidms.Model objects are defined with parameters that take references to functions (such as the ones defined here) as arguments. The object then handles:

  • parameter initialization/bookkeeping

  • composition of the provided functions using functools.partial

  • the subsequent jit-compilation, using jax.jit, of the composed model and objective functions

  • optimization of the model parameters.

This allows the user/developer to define their own model components for each of the latent phenotype, global epistasis, and output activation functions. This module is kept separate from multidms.model primarily for the sake of readability and documentation. This may change in the future.
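As a rough illustration of the composition and compilation steps listed above, the following sketch chains three toy components with functools.partial and compiles the result with jax.jit. The component functions, the `predict` helper, and the dictionary keys (`beta0`, `beta`, `theta`, `scale`, `bias`) are assumptions made for this example; they are not the library's internals.

```python
from functools import partial

import jax
import jax.numpy as jnp


def toy_latent(d_params, X_d):
    # Additive latent phenotype: offset plus summed mutation effects.
    return d_params["beta0"] + X_d @ d_params["beta"]


def toy_epistasis(theta, z_d):
    # Scaled and shifted sigmoid of the latent phenotype.
    return theta["scale"] * jax.nn.sigmoid(z_d) + theta["bias"]


def toy_activation(d_params, act, **kwargs):
    # Identity output activation.
    return act


def predict(d_params, X_d, latent_fn, epistasis_fn, activation_fn):
    # Pure function: everything it needs arrives as an argument.
    z = latent_fn(d_params, X_d)
    g = epistasis_fn(d_params["theta"], z)
    return activation_fn(d_params, g)


# Bind the component functions, then jit-compile the composed model.
model = jax.jit(
    partial(
        predict,
        latent_fn=toy_latent,
        epistasis_fn=toy_epistasis,
        activation_fn=toy_activation,
    )
)

d_params = {
    "beta0": jnp.array(0.0),
    "beta": jnp.array([1.0, -1.0]),
    "theta": {"scale": jnp.array(2.0), "bias": jnp.array(-1.0)},
}
X_d = jnp.array([[0.0, 0.0], [1.0, 0.0], [0.0, 1.0]])
y_pred = model(d_params, X_d)
```

Because the component functions are bound with functools.partial before compilation, swapping in a different latent, epistasis, or activation function only changes the `partial` call, not `predict` itself.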

Note

In order to make use of the jax library, all components must be written in a functional style. This means that all functions must be written to take in all parameters as arguments and return a value with no side effects. See the jax documentation for more details.

multidms.biophysical.additive_model(d_params: dict, X_d: array)

Model for predicting the latent phenotype of a set of binary-encoded variants \(v\) for a given condition \(d\), using the corresponding beta ( \(\beta\) ), shift ( \(\Delta_{d, m}\) ), and latent offset ( \(\beta_0, \alpha_d\) ) parameters.

\[\phi_d(v) = \beta_0 + \alpha_d + \sum_{m \in v} (\beta_{m} + \Delta_{d, m})\]
Parameters:
  • d_params (dict) – Dictionary of model defining parameters as jax arrays. Note that the shape of the parameters must be compatible with the input data.

  • X_d (array-like) – Binary encoded mutations for a given set of variants from condition, \(d\).

Returns:

Predicted latent phenotypes for each row in X_d

Return type:

jnp.array
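A minimal sketch of the latent phenotype formula above, written as a matrix product over the binary encoding. The function name and the dictionary keys (`beta0`, `alpha`, `beta`, `shift`) are hypothetical and chosen only for this example.

```python
import jax.numpy as jnp


def additive_latent(d_params, X_d):
    # Per-mutation effect in condition d: beta_m + Delta_{d,m}.
    effects = d_params["beta"] + d_params["shift"]
    # Summing effects over the mutations present in each variant is a
    # matrix-vector product with the binary encoding.
    return d_params["beta0"] + d_params["alpha"] + X_d @ effects


d_params = {
    "beta0": jnp.array(0.5),               # shared latent offset
    "alpha": jnp.array(-0.1),              # condition-specific offset
    "beta": jnp.array([1.0, -2.0, 0.3]),   # mutation effects
    "shift": jnp.array([0.0, 0.5, 0.0]),   # condition-specific shifts
}
# Rows: wildtype, single mutant (m=0), double mutant (m=0 and m=1).
X_d = jnp.array(
    [[0, 0, 0], [1, 0, 0], [1, 1, 0]], dtype=jnp.float32
)
phi = additive_latent(d_params, X_d)  # approximately [0.4, 1.4, -0.1]
```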

multidms.biophysical.sigmoidal_global_epistasis(theta: dict, z_d: array)

A flexible sigmoid function for modeling global epistasis. This function takes a set of latent phenotype values, \(z_d\), and computes the predicted functional scores using the scaling parameters \(\theta_{\text{scale}}\) and \(\theta_{\text{bias}}\) such that:

\[g(z) = \frac{\theta_{\text{scale}}}{1 + e^{-z}} + \theta_{\text{bias}}\]

Note

This function is independent of the experimental condition in which a variant is observed.

Parameters:
  • theta (dict) – Dictionary of model defining parameters as jax arrays.

  • z_d (jnp.array) – Latent phenotype values for a given set of variants

Returns:

Predicted functional scores for each latent phenotype in z_d.

Return type:

jnp.array
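The sigmoid formula above can be sketched in a few lines. The key names `scale` and `bias` are placeholders for \(\theta_{\text{scale}}\) and \(\theta_{\text{bias}}\) and are assumptions of this example.

```python
import jax
import jax.numpy as jnp


def sigmoid_ge(theta, z_d):
    # g(z) = theta_scale / (1 + exp(-z)) + theta_bias
    return theta["scale"] * jax.nn.sigmoid(z_d) + theta["bias"]


theta = {"scale": jnp.array(5.0), "bias": jnp.array(-5.0)}
z = jnp.array([-50.0, 0.0, 50.0])
g = sigmoid_ge(theta, z)  # approximately [-5.0, -2.5, 0.0]
```

With a positive scale, the output is bounded between \(\theta_{\text{bias}}\) and \(\theta_{\text{bias}} + \theta_{\text{scale}}\), which is what the two extreme inputs in the example show.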

multidms.biophysical.softplus_global_epistasis(theta: dict, z_d: array)

A flexible softplus function for modeling global epistasis. This function takes a set of latent phenotype values, \(z_d\), and computes the predicted functional scores such that:

\[g(z) = -\theta_\text{scale}\log\left(1+e^{-z}\right) + \theta_\text{bias}\]

Note

This function has no natural lower bound. It is therefore recommended that you use this model in conjunction with an output activation such as softplus_activation().

Parameters:
  • theta (dict) – Dictionary of model defining parameters as jax arrays.

  • z_d (jnp.array) – Latent phenotype values for a given set of variants

Returns:

Predicted functional scores for each latent phenotype in z_d.

Return type:

jnp.array
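A small sketch of the softplus formula above, using the identity \(\log(1 + e^{-z}) = \text{softplus}(-z)\). The key names `scale` and `bias` are assumptions of this example.

```python
import jax
import jax.numpy as jnp


def softplus_ge(theta, z_d):
    # g(z) = -theta_scale * log(1 + exp(-z)) + theta_bias
    return -theta["scale"] * jax.nn.softplus(-z_d) + theta["bias"]


theta = {"scale": jnp.array(1.0), "bias": jnp.array(0.0)}
z = jnp.array([-50.0, 0.0, 50.0])
g = softplus_ge(theta, z)  # approximately [-50.0, -0.693, 0.0]
```

The example illustrates the note above: for large positive \(z\) the output saturates at \(\theta_\text{bias}\), while for large negative \(z\) it keeps decreasing roughly linearly.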

multidms.biophysical.nn_global_epistasis(theta: dict, z_d: array)

A single-layer neural network for modeling global epistasis. This function takes a set of latent phenotype values, \(z_d\), and computes the predicted functional scores.

For this option, the user defines the number of units in the single hidden layer of the model. For each hidden unit, we introduce three parameters (two weights and a bias) to be inferred. All weights are clipped at zero to maintain assumptions of monotonicity in the resulting epistasis function shape. The network applies a sigmoid activation to each internal unit; a final transformation and the addition of a constant then give the predicted functional score.

More concretely, given latent phenotype, \(\phi_d(v) = z\), let

\[g(z) = b^{o} + \sum_{i=1}^{n} \frac{w^{o}_{i}}{1 + e^{w^{l}_{i} z + b^{l}_{i}}}\]

Where:

  • \(n\) is the number of units in the hidden layer.

  • \(w^{l}_{i}\) and \(w^{o}_{i}\) are free parameters representing latent and output transformations, respectively, associated with unit \(i\) in the hidden layer of the network.

  • \(b^{l}_{i}\) is a free parameter, added as a bias term to unit \(i\).

  • \(b^{o}\) is a single free parameter acting as a constant offset.

Note

This is an advanced feature and we advise against its use unless the other options are not sufficiently parameterized for particularly complex experimental conditions.

Parameters:
  • theta (dict) – Dictionary of model defining parameters as jax arrays.

  • z_d (jnp.array) – Latent phenotype values for a given set of variants

Returns:

Predicted functional scores for each latent phenotype in z_d.

Return type:

jnp.array
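The sketch below transcribes the formula above term by term for a vector of latent phenotypes. The key names (`w_latent`, `b_latent`, `w_out`, `b_out`) are hypothetical, and reading "clipped at zero" as "constrained to be non-negative" is an assumption of this example.

```python
import jax
import jax.numpy as jnp


def nn_ge(theta, z_d):
    # Assumption: clipping at zero means weights are kept non-negative.
    w_l = jnp.maximum(theta["w_latent"], 0.0)
    w_o = jnp.maximum(theta["w_out"], 0.0)
    # 1 / (1 + exp(x)) equals sigmoid(-x), which is numerically stable.
    # Shape: (n_variants, n_units).
    hidden = jax.nn.sigmoid(-(w_l * z_d[:, None] + theta["b_latent"]))
    # Weighted sum over hidden units plus the constant offset b^o.
    return theta["b_out"] + hidden @ w_o


theta = {
    "w_latent": jnp.array([1.0, 2.0]),
    "b_latent": jnp.array([0.0, 0.0]),
    "w_out": jnp.array([0.5, 1.5]),
    "b_out": jnp.array(-1.0),
}
z = jnp.linspace(-3.0, 3.0, 7)
g = nn_ge(theta, z)
```

With non-negative weights, every hidden-unit term moves in the same direction as \(z\) changes, so the sum is monotone in \(z\).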

multidms.biophysical.identity_activation(d_params, act, **kwargs)

Identity function \(f(x)=x\). Mostly a placeholder function that helps compose the model when you don’t want to use any final output activation, e.g. when you don’t have a pre-defined lower bound.

multidms.biophysical.softplus_activation(d_params, act, lower_bound=-3.5, hinge_scale=0.1, **kwargs)

A modified softplus that hinges at a given lower bound. The rate of change at the hinge is defined by ‘hinge_scale’.

In essence, this is a modified softplus activation, (\(\text{softplus}(x)=\log(1 + e^{x})\)) with a lower bound at \(l + \gamma_{h}\), as well as a ramping coefficient, \(\lambda_{\text{sp}}\).

Concretely, if we let \(z' = g(\phi_d(v))\), then the predicted functional score of our model is given by:

\[t(z') = \lambda_{\text{sp}}\log\left(1 + e^{\frac{z' - l}{\lambda_{\text{sp}}}}\right) + l\]

Functionally speaking, this truncates scores below a lower bound, while leaving scores above (mostly) unaltered. There is a small range of input values where the function smoothly transitions between a flat regime (where data is truncated) and a linear regime (where data is not truncated).

Parameters:
  • d_params (dict) – Dictionary of model defining parameters as jax arrays.

  • act (jnp.array) – Activations to apply the softplus function to.

  • lower_bound (float) – Lower bound to hinge the softplus function at.

  • hinge_scale (float) – Rate of change at the hinge point.

  • kwargs (dict) – Additional keyword arguments to pass to the biophysical model function

Returns:

Transformed activations.

Return type:

jnp.array
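A minimal sketch of the hinged softplus formula above, using the default `lower_bound` and `hinge_scale` values from the signature. The function name is hypothetical, and the `d_params` argument is omitted because the formula does not use it.

```python
import jax
import jax.numpy as jnp


def hinged_softplus(act, lower_bound=-3.5, hinge_scale=0.1):
    # t(z') = lambda * log(1 + exp((z' - l) / lambda)) + l
    scaled = (act - lower_bound) / hinge_scale
    return hinge_scale * jax.nn.softplus(scaled) + lower_bound


scores = jnp.array([-10.0, -3.5, 0.0, 2.0])
t = hinged_softplus(scores)
# Far below the bound the output flattens near -3.5; far above the
# bound the output is close to the input.
```

At the hinge itself (\(z' = l\)) the output is \(l + \lambda_{\text{sp}} \log 2\), so a smaller `hinge_scale` gives a sharper transition between the flat and linear regimes.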

multidms.biophysical.proximal_box_constraints(params, hyperparameters, *args, **kwargs)

Proximal operator for box constraints for single condition models.

Note that *args and **kwargs are placeholders for additional arguments that may be passed to this function by the optimizer.
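For background, the proximal operator of a box constraint is the Euclidean projection onto the box, which reduces to elementwise clipping. The sketch below shows that idea only; the function name, the `bounds` layout, and the choice of which parameters are bounded are assumptions of this example and may differ from what the library constrains.

```python
import jax.numpy as jnp


def project_onto_box(params, bounds, *args, **kwargs):
    # Clip each bounded parameter to [lower, upper]; leave others as is.
    # Extra positional/keyword arguments are accepted and ignored.
    projected = {}
    for name, value in params.items():
        if name in bounds:
            lower, upper = bounds[name]
            projected[name] = jnp.clip(value, lower, upper)
        else:
            projected[name] = value
    return projected


params = {"w": jnp.array([-1.0, 0.5, 3.0]), "b": jnp.array([-2.0])}
bounds = {"w": (0.0, 2.0)}  # hypothetical bounds for illustration
projected = project_onto_box(params, bounds)
```

Returning a new dictionary rather than modifying `params` in place keeps the function free of side effects, in line with the functional style required by jax.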

multidms.biophysical.proximal_objective(Dop, params, hyperparameters, scaling=1.0)

ADMM generalized lasso optimization.

multidms.biophysical.smooth_objective(f, params, data, scale_coeff_ridge_beta=0.0, scale_coeff_ridge_ge_scale=0.0, scale_coeff_ridge_ge_bias=0.0, huber_scale=1, **kwargs)

Cost (objective) function summed across all conditions.

Parameters:
  • f (function) – Biophysical model function

  • params (dict) – Dictionary of parameters to optimize

  • data (tuple) – Tuple of (X, y) data, where each element is a dictionary keyed by condition that returns, respectively, the binarymap and the row-associated target functional scores

  • huber_scale (float) – Scale parameter for Huber loss function

  • scale_coeff_ridge_beta (float) – Ridge penalty coefficient for shift parameters

  • scale_coeff_ridge_ge_scale (float) – Ridge penalty coefficient for global epistasis scale parameter

  • scale_coeff_ridge_ge_bias (float) – Ridge penalty coefficient for global epistasis bias parameter

  • kwargs (dict) – Additional keyword arguments to pass to the biophysical model function

Returns:

loss – Summed loss across all conditions.

Return type:

float
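To make the shape of such an objective concrete, the sketch below sums a Huber loss over conditions and adds a single ridge penalty. Several details are assumptions of this example rather than the library's behavior: the toy model and its call signature, the `beta` and `shift` keys, averaging within each condition, and treating the Huber parameter as the threshold \(\delta\).

```python
import jax
import jax.numpy as jnp


def huber(residual, delta=1.0):
    # Quadratic near zero, linear in the tails.
    abs_r = jnp.abs(residual)
    quadratic = 0.5 * residual**2
    linear = delta * (abs_r - 0.5 * delta)
    return jnp.where(abs_r <= delta, quadratic, linear)


def toy_model(params, X_d, d):
    # Hypothetical model: additive effects with a per-condition shift.
    return X_d @ (params["beta"] + params["shift"][d])


def sketch_objective(f, params, data, ridge_beta=0.0, huber_delta=1.0):
    X, y = data  # dictionaries keyed by condition
    loss = 0.0
    for d in X:
        residual = f(params, X[d], d) - y[d]
        # Assumption: average over variants within a condition.
        loss = loss + huber(residual, huber_delta).mean()
    # Ridge penalty on the mutation effects.
    return loss + ridge_beta * jnp.sum(params["beta"] ** 2)


params = {
    "beta": jnp.array([1.0, 0.0]),
    "shift": {"a": jnp.array([0.0, 0.0]), "b": jnp.array([0.0, 1.0])},
}
data = (
    {"a": jnp.array([[1.0, 0.0]]), "b": jnp.array([[0.0, 1.0]])},
    {"a": jnp.array([1.0]), "b": jnp.array([4.0])},
)
loss = sketch_objective(toy_model, params, data, ridge_beta=0.1)
# Gradients with respect to the parameter dictionary (argument 1).
grads = jax.grad(sketch_objective, argnums=1)(
    toy_model, params, data, ridge_beta=0.1
)
```

In this toy data set, condition "a" is fit exactly and condition "b" has a residual of -3, so the total is 0 + 2.5 from the Huber terms plus 0.1 from the ridge penalty, giving 2.6.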