torchdms.analysis

A wrapper class for training models.

Classes

Analysis


class torchdms.analysis.Analysis(model, model_path, val_data, train_data_list, site_dict=None, batch_size=500, learning_rate=0.005, device='cpu')

A wrapper class for training models.

Todo

Much more documentation is needed.

__init__(model, model_path, val_data, train_data_list, site_dict=None, batch_size=500, learning_rate=0.005, device='cpu')

loss_of_targets_and_prediction(loss_fn, targets, predictions, per_target_loss_decay)

Return loss on the valid predictions, i.e. the ones that are not NaN.
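A minimal NumPy sketch of the masking idea, assuming a prediction counts as valid when its corresponding target is finite (targets are NaN where a measurement is missing). The helpers mse and loss_on_valid are hypothetical; the per_target_loss_decay argument of the real method is left out, and in torchdms the arrays would be torch tensors rather than NumPy arrays.

```python
import numpy as np


def mse(targets, predictions):
    # Mean squared error over whatever entries are passed in.
    return float(np.mean((targets - predictions) ** 2))


def loss_on_valid(loss_fn, targets, predictions):
    # Keep only positions whose target is finite (drops NaN and inf),
    # then hand the surviving pairs to the loss function.
    valid = np.isfinite(targets)
    return loss_fn(targets[valid], predictions[valid])


targets = np.array([1.0, np.nan, 3.0])
predictions = np.array([2.0, 5.0, 3.0])
print(loss_on_valid(mse, targets, predictions))  # 0.5
```

The middle prediction (5.0) is ignored because its target is missing, so only the pairs (1.0, 2.0) and (3.0, 3.0) contribute.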

complete_loss(loss_fn, targets, predictions, loss_decays)

Compute the total loss, summed across targets, plus the regularization penalty.

Here we compute the loss separately for each target and then sum the results. This lets us use samples that are missing values for only a subset of the targets.
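To show why per-target summation helps, here is a NumPy sketch under stated assumptions: targets and predictions have shape (n_samples, n_targets), NaN marks a missing target value, and regularization is a hypothetical scalar standing in for the model's penalty term. The loss_decays argument is omitted, and the function name complete_loss_sketch is illustrative only.

```python
import numpy as np


def mse(targets, predictions):
    return float(np.mean((targets - predictions) ** 2))


def complete_loss_sketch(loss_fn, targets, predictions, regularization=0.0):
    """Sum per-target losses, each over that target's non-NaN rows, plus a penalty.

    `regularization` is a hypothetical stand-in for the model's penalty term.
    """
    total = 0.0
    for j in range(targets.shape[1]):
        valid = np.isfinite(targets[:, j])
        # Sketch choice: skip a target with no observed values in this batch
        # (np.mean of an empty array would otherwise return NaN).
        if valid.any():
            total += loss_fn(targets[valid, j], predictions[valid, j])
    return total + regularization


# Row 0 is missing its second target but still contributes to the first.
targets = np.array([[1.0, np.nan], [2.0, 4.0], [3.0, 6.0]])
predictions = np.array([[1.0, 0.0], [2.0, 5.0], [4.0, 6.0]])
# First target: mean of [0, 0, 1] = 1/3; second target (rows 1 and 2 only): 0.5
print(complete_loss_sketch(mse, targets, predictions))
```

Dropping whole rows with any NaN would have discarded row 0 entirely; masking per target keeps its first-target measurement.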

train(epoch_count, loss_fn, patience=10, min_lr=1e-05, loss_weight_span=None, exp_target=None, beta_rank=None)

Train self.model using all the bells and whistles.
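One possible reading of the patience and min_lr parameters, stated as an assumption rather than taken from the torchdms source: cut the learning rate when the validation loss has not improved for more than patience epochs, and stop once the rate has already been cut to min_lr. This mirrors the reduce-on-plateau idea that PyTorch provides as torch.optim.lr_scheduler.ReduceLROnPlateau. The class PlateauScheduler and its factor parameter are hypothetical, and the sketch counts any decrease as an improvement (no threshold). It does not cover loss_weight_span, exp_target, or beta_rank.

```python
class PlateauScheduler:
    """Hypothetical reduce-on-plateau scheduler with an early-stopping signal."""

    def __init__(self, lr, patience=10, min_lr=1e-5, factor=0.1):
        self.lr = lr
        self.patience = patience
        self.min_lr = min_lr
        self.factor = factor  # hypothetical: multiplier applied on each plateau
        self.best = float("inf")
        self.bad_epochs = 0

    def step(self, val_loss):
        """Record one epoch's validation loss; return True when training should stop."""
        if val_loss < self.best:
            self.best = val_loss
            self.bad_epochs = 0
            return False
        self.bad_epochs += 1
        if self.bad_epochs <= self.patience:
            return False
        # Plateau detected: stop if the rate is already at its floor, else cut it.
        if self.lr <= self.min_lr:
            return True
        self.lr = max(self.lr * self.factor, self.min_lr)
        self.bad_epochs = 0
        return False


# A validation loss that never improves: the rate is cut 0.005 -> 5e-4 -> 5e-5 -> 1e-5,
# and the next plateau triggers the stop.
sched = PlateauScheduler(lr=0.005, patience=2, min_lr=1e-5)
for epoch, val_loss in enumerate([1.0] * 20):
    if sched.step(val_loss):
        print("stopping after epoch", epoch)  # stopping after epoch 12
        break
```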

multi_train(independent_start_count, independent_start_epoch_count, epoch_count, loss_fn, patience=10, min_lr=1e-05, loss_weight_span=None, exp_target=None, beta_rank=None)

Pre-train self.model from independent_start_count independent starts of independent_start_epoch_count epochs each, write the best pre-trained model to model_path, then train that model fully for epoch_count epochs.
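The multi-start pattern can be sketched in plain Python. Everything here is hypothetical: make_model, train_for, and val_loss are stand-in callables, the toy "model" is a dict holding one weight, and the sketch keeps the winner in memory instead of writing it to model_path.

```python
import random


def multi_start_sketch(make_model, train_for, val_loss,
                       start_count, start_epoch_count, epoch_count):
    """Pre-train several fresh models briefly, keep the best, then train it fully.

    Hypothetical callables:
      make_model()         -> a freshly initialized model
      train_for(model, n)  -> trains `model` in place for n epochs
      val_loss(model)      -> validation loss (lower is better)
    """
    best_model, best_loss = None, float("inf")
    for _ in range(start_count):
        model = make_model()
        train_for(model, start_epoch_count)
        loss = val_loss(model)
        if loss < best_loss:
            best_model, best_loss = model, loss
    # multi_train also writes this best pre-trained model to model_path;
    # the sketch just keeps it in memory.
    train_for(best_model, epoch_count)
    return best_model


# Toy usage: a one-parameter "model" fitted to 3.0 by gradient descent on (w - 3)^2.
rng = random.Random(0)


def make_model():
    return {"w": rng.uniform(-10.0, 10.0)}  # random initialization


def train_for(model, n):
    for _ in range(n):
        model["w"] -= 0.1 * 2.0 * (model["w"] - 3.0)  # gradient step, lr = 0.1


def val_loss(model):
    return (model["w"] - 3.0) ** 2


best = multi_start_sketch(make_model, train_for, val_loss,
                          start_count=5, start_epoch_count=2, epoch_count=100)
```

Short independent starts are a cheap way to avoid committing the full epoch budget to an unlucky initialization.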

simple_train(epoch_count, loss_fn)

Bare-bones training of self.model.

This training doesn’t even handle NaNs. If you want that behavior, use self.loss_of_targets_and_prediction rather than calling loss_fn directly.

We also concatenate all of the data together rather than computing gradients on a per-stratum basis. If you don’t want this behavior, use self.train_infinite_loaders rather than the local train_infinite_loaders defined in the body of this method.
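A bare-bones loop in the same spirit can be sketched with NumPy, assuming a toy linear model y ≈ X @ w stands in for self.model and plain mean squared error stands in for loss_fn. The function simple_train_sketch and its lr argument are hypothetical. As described above, the strata are concatenated into one dataset and nothing guards against NaNs, so a single missing target poisons the weights.

```python
import numpy as np


def simple_train_sketch(strata, epoch_count, lr=0.1):
    """Full-batch gradient descent on y ≈ X @ w with plain MSE.

    `strata` is a list of (X, y) pairs. They are concatenated into one
    dataset rather than visited stratum by stratum. No NaN handling:
    a NaN in y propagates into w.
    """
    X = np.concatenate([x for x, _ in strata])
    y = np.concatenate([t for _, t in strata])
    w = np.zeros(X.shape[1])
    for _ in range(epoch_count):
        residual = X @ w - y
        grad = 2.0 * X.T @ residual / len(y)  # gradient of mean squared error
        w -= lr * grad
    return w


strata = [
    (np.array([[1.0, 0.0], [0.0, 1.0]]), np.array([2.0, -1.0])),
    (np.array([[1.0, 1.0]]), np.array([1.0])),
]
w = simple_train_sketch(strata, epoch_count=500)  # approaches [2.0, -1.0]
```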