ncalab.training
Submodules
Classes
- EarlyStopping – Early stopping helper class.
- BasicNCATrainer – Trainer class for any model subclassing BasicNCA.
- TrainingHistory – Stores data about the training progress. Populated during training with ncalab.training.BasicNCATrainer.
- TrainValRecord – Helper class, storing a training / validation data split to generate respective DataLoader objects.
- SplitDefinition – Stores a k-fold cross-validation split.
- BasicNCAModel – Abstract base class for NCA models.
- Prediction – Stores the result of an NCA prediction, including the number of steps it took.
- Visual – Base class for tensorboard visuals.
- Pool – Sample pool that retains previous predictions. Also applies damaging patterns to images to increase the robustness of the trained NCA.
- TrainingStatus – Encodes last status of the training.
Functions
- intepret_range_parameter – Interpret a range parameter that is passed for NCA timesteps.
- pad_input – Pads the BCWH input tensor along its channel dimension to match the expected number of channels required by the NCA model.
- unwrap – Panics if x is None, otherwise returns x.
Package Contents
- class ncalab.training.EarlyStopping(patience: int, min_delta: float = 1e-06)
Early stopping helper class. Helps to stop the training once the validation accuracy stops improving.
- Parameters:
patience (int) – Number of steps (epochs) without improvement to wait before stopping the training.
min_delta (float, optional) – Minimum change in accuracy required to reset the counter, defaults to 1e-6.
- patience
- min_delta = 1e-06
- best_accuracy = 0.0
- counter = 0
- done() bool
Checks whether the training can be stopped.
Needs to be queried in the training loop, once per epoch.
- Returns:
Whether to stop the training or not.
- Return type:
bool
- step(accuracy: float)
Increases the internal counter if the accuracy doesn’t improve, otherwise resets the counter.
Needs to be called in the training loop, once per epoch.
- Parameters:
accuracy (float) – Validation accuracy.
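The documented behaviour of step() and done() can be mirrored by a small stand-alone class. This is a minimal sketch, not ncalab's source: it assumes that an improvement means exceeding the best accuracy by more than min_delta, and that done() returns True once the counter reaches patience. The class name EarlyStoppingSketch is hypothetical.

```python
class EarlyStoppingSketch:
    """Stand-alone sketch of the documented EarlyStopping semantics (not ncalab's code)."""

    def __init__(self, patience: int, min_delta: float = 1e-6) -> None:
        self.patience = patience
        self.min_delta = min_delta
        self.best_accuracy = 0.0
        self.counter = 0

    def step(self, accuracy: float) -> None:
        # Assumption: an improvement must exceed the best accuracy by more than min_delta.
        if accuracy > self.best_accuracy + self.min_delta:
            self.best_accuracy = accuracy
            self.counter = 0
        else:
            self.counter += 1

    def done(self) -> bool:
        # Assumption: stop once `patience` consecutive epochs brought no improvement.
        return self.counter >= self.patience


# Usage: call step() and done() once per epoch.
stopper = EarlyStoppingSketch(patience=2)
stopped_at = None
for epoch, acc in enumerate([0.5, 0.6, 0.6, 0.59, 0.7]):
    stopper.step(acc)
    if stopper.done():
        stopped_at = epoch
        break
```

With patience=2, the two epochs without improvement after the 0.6 peak trigger the stop at epoch 3, before the later 0.7 is ever seen.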
- class ncalab.training.BasicNCATrainer(nca: ncalab.models.basicNCA.BasicNCAModel, model_path: pathlib.Path | pathlib.PosixPath | None, gradient_clipping: bool = False, lr: float | None = None, lr_gamma: float = 0.99, adam_betas=(0.9, 0.95), batch_repeat: int = 2, max_epochs: int = 200, optimizer_method: str = 'adam', pool: ncalab.training.pool.Pool | None = None)
Trainer class for any model subclassing BasicNCA.
- Parameters:
nca (ncalab.BasicNCAModel) – NCA model instance to train.
model_path (Path | PosixPath | None) – Path to save models to. If None, models are not saved.
gradient_clipping (bool, optional) – Whether to clip gradients, defaults to False.
lr (float, optional) – Initial learning rate. If None, 16e-4 is used, defaults to None.
lr_gamma (float, optional) – Exponential learning rate decay factor, defaults to 0.99.
adam_betas (tuple, optional) – Beta values for the Adam optimizer, defaults to (0.9, 0.95).
batch_repeat (int, optional) – How often each batch is duplicated, defaults to 2.
max_epochs (int, optional) – Maximum number of training epochs, defaults to 200.
optimizer_method (str, optional) – Optimization method, defaults to ‘adam’.
pool (ncalab.Pool, optional) – Sample pool object, defaults to None.
- nca
- model_path
- gradient_clipping = False
- lr_gamma = 0.99
- adam_betas = (0.9, 0.95)
- batch_repeat = 2
- max_epochs = 200
- optimizer_method = 'adam'
- pool = None
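Exponential decay multiplies the learning rate by lr_gamma at every scheduler step, so after k steps it equals lr * lr_gamma**k (the rule implemented by torch.optim.lr_scheduler.ExponentialLR). The arithmetic below assumes a starting rate of 16e-4 and one decay step per epoch; whether BasicNCATrainer steps its scheduler per epoch or per batch is not stated here, so treat the numbers as illustrative.

```python
def decayed_lr(lr0: float, gamma: float, steps: int) -> float:
    """Learning rate after `steps` exponential-decay steps: lr0 * gamma**steps."""
    return lr0 * gamma ** steps


# Assumed setting: lr = 16e-4, lr_gamma = 0.99, one decay step per epoch, 200 epochs.
final_lr = decayed_lr(16e-4, 0.99, 200)
print(f"{final_lr:.2e}")  # 2.14e-04
```

Under these assumptions the learning rate shrinks to roughly 13% of its initial value by the end of max_epochs.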
- info() str
Returns a Markdown-formatted info string with the training parameters. Useful for showing on tensorboard to keep track of parameter changes.
- Returns:
Markdown-formatted info string.
- Return type:
str
- _train_iteration(x: torch.Tensor, y: torch.Tensor, optimizer: torch.optim.Optimizer, total_batch_iterations: int, summary_writer: torch.utils.tensorboard.SummaryWriter | None = None) Tuple[ncalab.prediction.Prediction, Dict[str, torch.Tensor]]
Run a single training iteration.
- Parameters:
x – Input training images.
y – Input training labels.
optimizer – Optimizer.
total_batch_iterations (int) – Total training batch iterations.
summary_writer (SummaryWriter, optional) – Tensorboard SummaryWriter, defaults to None.
- Returns:
Prediction and dictionary of computed losses.
- Return type:
Tuple[Prediction, Dict[str, torch.Tensor]]
- train(dataloader_train: torch.utils.data.DataLoader, dataloader_val: torch.utils.data.DataLoader | None = None, dataloader_test: torch.utils.data.DataLoader | None = None, save_every: int | None = None, summary_writer: torch.utils.tensorboard.SummaryWriter | None = None, plot_function: ncalab.visualization.Visual | None = None, earlystopping: ncalab.training.earlystopping.EarlyStopping | None = None) ncalab.training.traininghistory.TrainingHistory
Execute basic NCA training loop with a single function call.
- Parameters:
dataloader_train (DataLoader) – Training DataLoader.
dataloader_val (DataLoader, optional) – Validation DataLoader, defaults to None.
dataloader_test (DataLoader, optional) – Test DataLoader, defaults to None.
save_every (int, optional) – How often to save the model state (in epochs). Useful for very small datasets, like growing lizard, defaults to None.
summary_writer (SummaryWriter, optional) – Tensorboard SummaryWriter, defaults to None.
plot_function (Visual, optional) – Plot function override. If None, use the model’s default, defaults to None.
earlystopping (EarlyStopping, optional) – EarlyStopping object, defaults to None.
- Returns:
TrainingHistory object.
- Return type:
TrainingHistory
- class ncalab.training.TrainingHistory(path: pathlib.Path | pathlib.PosixPath | None, metrics: Dict[str, float], current_epoch: int, current_model: ncalab.models.BasicNCAModel, best_accuracy: float = 0, best_epoch: int = 0, best_model: ncalab.models.BasicNCAModel | None = None, verbose: bool = True)
Stores data about the training progress. Populated during training with ncalab.training.BasicNCATrainer.
- Parameters:
path (Optional[Path | PosixPath]) – Save and load path.
metrics (Dict[str, float]) – Dict of validation metrics
current_epoch (int) – Current training epoch.
current_model (BasicNCAModel) – Currently trained model.
best_accuracy (float, optional) – Best validation accuracy, defaults to 0
best_epoch (int, optional) – Epoch of best validation accuracy, defaults to 0
best_model (Optional[BasicNCAModel], optional) – Model with best validation accuracy, defaults to None
verbose (bool, optional) – Whether to print updates of validation accuracy, defaults to True
- path
- metrics
- current_epoch
- current_model
- best_accuracy = 0
- best_epoch = 0
- best_model = None
- verbose = True
- created_timestamp
- modified_timestamp
- update(epoch: int, model: ncalab.models.BasicNCAModel, accuracy: float, overwrite: bool = False)
Populates history with current iteration’s values.
Automatically recognizes changes in accuracy.
- Parameters:
epoch (int) – Current epoch
model (BasicNCAModel) – Current model
accuracy (float) – Current accuracy, based on model’s validation metric
overwrite (bool, optional) – Whether to overwrite best accuracy even with no improvement, defaults to False
- save()
Saves history and model checkpoint.
- to_dict() Dict
Return dict of recorded values
- Returns:
Dict of recorded values
- Return type:
Dict
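update() keeps the best-so-far values alongside the current ones. A minimal sketch of that bookkeeping follows, assuming an improvement means a strictly higher accuracy and that the best model is stored as a deep copy. The field names follow the attributes above, but HistorySketch and its boolean return value are hypothetical and not ncalab's code.

```python
import copy
from dataclasses import dataclass
from typing import Any, Optional


@dataclass
class HistorySketch:
    """Illustrative stand-in for TrainingHistory's best-model tracking."""
    current_epoch: int = 0
    current_model: Any = None
    best_accuracy: float = 0.0
    best_epoch: int = 0
    best_model: Optional[Any] = None

    def update(self, epoch: int, model: Any, accuracy: float, overwrite: bool = False) -> bool:
        """Record the current epoch; return True if the best values were replaced."""
        self.current_epoch = epoch
        self.current_model = model
        if accuracy > self.best_accuracy or overwrite:
            self.best_accuracy = accuracy
            self.best_epoch = epoch
            # Assumption: snapshot the model so later training steps do not alter it.
            self.best_model = copy.deepcopy(model)
            return True
        return False


history = HistorySketch()
weights = {"w": 1.0}  # hypothetical stand-in for a model
history.update(0, weights, 0.7)
weights["w"] = 2.0
history.update(1, weights, 0.6)  # no improvement: best stays at epoch 0
```

The deep copy is what keeps best_model frozen at its epoch-0 state even though the live model keeps changing.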
- class ncalab.training.TrainValRecord(train: List[str], val: List[str])
Helper class, storing a training / validation data split to generate respective DataLoader objects.
- Parameters:
train (List[str]) – List of training image file paths
val (List[str]) – List of validation image file paths
- train
- val
- dataloaders(DatasetType: Type, path: pathlib.Path | pathlib.PosixPath, transform=None, batch_sizes=None)
Generate a pair of training and validation DataLoader objects, based on a given Dataset subtype.
- class ncalab.training.SplitDefinition
Stores a k-fold cross-validation split.
- folds = []
- dataloader_test = None
- static read(path: pathlib.PosixPath) SplitDefinition
Reads json files with split definitions, similar to those created by nnUNet.
The expected format is:
[ { "train": [ "filename0", "filename1",... ], "val": [ "filename2", "filename3",... ] }, { ... } ]
- Parameters:
path – Path to JSON file containing split definition.
- Returns:
SplitDefinition object
- Return type:
SplitDefinition
- __len__() int
- __getitem__(idx) TrainValRecord
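Parsing this format needs only the standard json module. The sketch below turns an nnUNet-style list into per-fold (train, val) records. Fold and read_split are hypothetical helpers, not part of ncalab (which wraps each fold in a TrainValRecord), and the input is read from a string for brevity; Path.read_text() would supply it from a file.

```python
import json
from dataclasses import dataclass
from typing import List


@dataclass
class Fold:
    """Hypothetical stand-in for TrainValRecord: one fold's file lists."""
    train: List[str]
    val: List[str]


def read_split(text: str) -> List[Fold]:
    """Parse an nnUNet-style split: a JSON list of {"train": [...], "val": [...]} objects."""
    data = json.loads(text)
    if not isinstance(data, list):
        raise ValueError("expected a JSON list of folds")
    return [Fold(train=list(entry["train"]), val=list(entry["val"])) for entry in data]


example = """
[
  {"train": ["filename0", "filename1"], "val": ["filename2"]},
  {"train": ["filename1", "filename2"], "val": ["filename0"]}
]
"""
folds = read_split(example)
print(len(folds), folds[0].val)  # 2 ['filename2']
```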
- class ncalab.training.KFoldCrossValidationTrainer(trainer: ncalab.training.trainer.BasicNCATrainer, split: SplitDefinition)
Trains one model per fold of a k-fold cross-validation split, using a BasicNCATrainer for each fold.
- Parameters:
trainer (BasicNCATrainer) – BasicNCATrainer, to train each individual fold.
split (SplitDefinition) – Definition of the split used for k-fold cross-validation.
- trainer
- model_prototype
- model_name
- split
- train(DatasetType: Type, datapath: pathlib.Path | pathlib.PosixPath, transform, batch_sizes: None | Dict = None, save_every: int | None = None) List[ncalab.training.traininghistory.TrainingHistory]
Run training loop with a single function call.
- Parameters:
DatasetType (Type) – Type of dataset class to use.
datapath (Path | PosixPath) – Path to the dataset.
transform – Data transform, e.g. initialized via Albumentations.
batch_sizes (Dict, optional) – Dict of batch sizes per set, e.g. {“train”: 8, “val”: 16}, defaults to None.
save_every (int, optional) – How often to save the model state (in epochs), defaults to None.
- Returns [List[TrainingHistory]]:
List of TrainingHistory objects, one per fold.
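Conceptually, k-fold training repeats one training run per fold, each starting from an untouched copy of the model so that folds do not share weights. The sketch below shows that loop with a hypothetical train_fold callback and a deep-copied prototype. Reading the model_prototype attribute this way is an assumption about its purpose, not ncalab's implementation.

```python
import copy
from typing import Any, Callable, List, Sequence, Tuple


def train_k_folds(
    model_prototype: Any,
    folds: Sequence[Tuple[List[str], List[str]]],
    train_fold: Callable[[Any, List[str], List[str]], float],
) -> List[float]:
    """Train a fresh copy of the prototype on every (train, val) fold; one result per fold."""
    results = []
    for train_files, val_files in folds:
        model = copy.deepcopy(model_prototype)  # each fold starts from the same initial weights
        results.append(train_fold(model, train_files, val_files))
    return results


# Hypothetical training callback: mutates its model and "scores" the fold by validation size.
def fake_train(model: dict, train_files: List[str], val_files: List[str]) -> float:
    model["epochs"] = model.get("epochs", 0) + 1
    return float(len(val_files))


prototype = {"epochs": 0}
scores = train_k_folds(prototype, [(["a", "b"], ["c"]), (["a", "c"], ["b", "d"])], fake_train)
print(scores, prototype)  # [1.0, 2.0] {'epochs': 0}
```

The prototype is never modified, which is what makes the per-fold results comparable.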
- class ncalab.training.BasicNCAModel(device: torch.device, num_image_channels: int, num_hidden_channels: int, num_output_channels: int, plot_function: ncalab.visualization.Visual | None = None, validation_metric: str | None = None, fire_rate: float = 0.5, hidden_size: int = 128, use_alive_mask: bool = False, immutable_image_channels: bool = True, num_learned_filters: int = 2, filter_padding: str = 'reflect', use_laplace: bool = False, kernel_size: int = 3, pad_noise: bool = False, use_temporal_encoding: bool = False, rule_type: type[ncalab.models.basicNCA.basicNCArule.BasicNCARule] = BasicNCARule, training_timesteps: int | Tuple[int, int] = 100, inference_timesteps: int | Tuple[int, int] = 100)
Bases:
torch.nn.Module
Abstract base class for NCA models.
BasicNCAModel is a composition of an NCA backbone model (called “rule”), and an (optional) head module for downstream tasks.
- Parameters:
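device – PyTorch device descriptor.
num_image_channels – Number of channels reserved for the input image.
num_hidden_channels – Number of hidden channels (communication channels).
num_output_channels – Number of output channels.
plot_function – Visual used for tensorboard plots. Defaults to None.
validation_metric – Name of the metric (as returned by metrics()) that serves as validation accuracy. Defaults to None.
fire_rate – Fire rate for the stochastic cell update. Defaults to 0.5.
hidden_size – Number of neurons in the hidden layer. Defaults to 128.
use_alive_mask – Whether to use alive masking (channel 3) during training. Defaults to False.
immutable_image_channels – Whether the image channels are kept fixed during inference, which is the case for most segmentation or classification problems. Defaults to True.
num_learned_filters – Number of learned filters. If zero, two Sobel filters are used instead. Defaults to 2.
filter_padding – Padding type to use. Might affect reliance on spatial cues. Defaults to “reflect”.
use_laplace – Whether to use a Laplace filter (only if num_learned_filters == 0). Defaults to False.
kernel_size – Filter kernel size (only for learned filters). Defaults to 3.
pad_noise – Whether to pad the input image tensor with noise in the hidden / output channels. Defaults to False.
use_temporal_encoding – Whether to use a temporal encoding of the current timestep. Defaults to False.
rule_type – Class of the NCA backbone (“rule”). Defaults to BasicNCARule.
training_timesteps – Number of NCA time steps during training, either fixed or a (min, max) range. Defaults to 100.
inference_timesteps – Number of NCA time steps during inference, either fixed or a (min, max) range. Defaults to 100.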
- device
- num_image_channels
- num_hidden_channels
- num_output_channels
- num_channels
- fire_rate = 0.5
- hidden_size = 128
- use_alive_mask = False
- immutable_image_channels = True
- num_learned_filters = 2
- use_laplace = False
- kernel_size = 3
- filter_padding = 'reflect'
- pad_noise = False
- use_temporal_encoding = False
- plot_function = None
- validation_metric = None
- training_timesteps = 100
- inference_timesteps = 100
- perception
- input_vector_size
- rule_type
- rule
- head: ncalab.models.basicNCA.basicNCAhead.BasicNCAHead | None = None
- _define_rule() ncalab.models.basicNCA.basicNCArule.BasicNCARule
- prepare_input(x: torch.Tensor) torch.Tensor
Preprocess input. Intended to be overridden by subclasses if preprocessing is necessary.
- Parameters:
x (torch.Tensor) – Input tensor to preprocess.
- Returns:
Processed tensor.
- _alive(x)
- _update(x: torch.Tensor, step: int) torch.Tensor
Compute residual cell update.
- Parameters:
x (torch.Tensor) – Input tensor, BCWH.
step (int) – Current timestep, required for computing the temporal encoding.
- Returns:
Residual cell update, BCWH.
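In the common NCA formulation, the residual update is applied only to a random subset of cells, each cell firing with probability fire_rate, and the image channels are restored afterwards when immutable_image_channels is set. The NumPy sketch below illustrates one such step under those assumptions. update_fn stands in for the learned rule, the alive mask is omitted, and stochastic_step is a hypothetical helper, not ncalab's _forward_step.

```python
import numpy as np


def stochastic_step(x, update_fn, fire_rate, num_image_channels, immutable_image_channels, rng):
    """One stochastic residual NCA step on a BCWH array (illustrative sketch)."""
    dx = update_fn(x)  # residual update, same BCWH shape as x
    b, _, w, h = x.shape
    # One Bernoulli(fire_rate) draw per cell, shared across all channels of that cell.
    fire_mask = (rng.random((b, 1, w, h)) < fire_rate).astype(x.dtype)
    x_next = x + dx * fire_mask
    if immutable_image_channels:
        # Keep the input image channels fixed, as for segmentation or classification.
        x_next[:, :num_image_channels] = x[:, :num_image_channels]
    return x_next


rng = np.random.default_rng(0)
x = np.zeros((2, 5, 4, 4), dtype=np.float32)  # B=2, 3 image + 2 other channels, 4x4 cells
x[:, :3] = 0.5
step = stochastic_step(x, lambda t: np.ones_like(t), fire_rate=1.0,
                       num_image_channels=3, immutable_image_channels=True, rng=rng)
print(step[0, :, 0, 0].tolist())  # [0.5, 0.5, 0.5, 1.0, 1.0]
```

With fire_rate=1.0 every cell updates, so only the non-image channels change; a lower fire rate leaves a random subset of cells untouched at each step.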
- _forward_step(x: torch.Tensor, step: int)
- forward(x: torch.Tensor, steps: int = 1) ncalab.prediction.Prediction
- Parameters:
x (torch.Tensor) – Input image, padded along the channel dimension, BCWH.
steps (int, optional) – Time steps in the forward pass, defaults to 1.
- Returns:
Prediction object.
- Return type:
Prediction
- loss(pred: ncalab.prediction.Prediction, label: torch.Tensor) Dict[str, torch.Tensor]
Compute loss. Needs to be overridden by any subclass. Note that the returned dict must hold a “total” key storing the total loss, which is typically a weighted sum of the other losses. The total loss is backpropagated, whereas the other losses are sent to tensorboard.
- Parameters:
pred (Prediction) – Model prediction.
label (torch.Tensor) – Ground truth, BCWH.
- Returns:
Dictionary of identifiers mapped to computed losses.
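The shape of the dict that loss() must return can be sketched with NumPy arrays standing in for tensors. The two component losses and their weights are hypothetical choices; a real override would compute torch tensors so that the “total” entry can be backpropagated.

```python
import numpy as np


def loss_sketch(prediction: np.ndarray, label: np.ndarray,
                w_mse: float = 1.0, w_l1: float = 0.5) -> dict:
    """Return component losses plus their weighted sum under the required "total" key."""
    mse = float(np.mean((prediction - label) ** 2))
    l1 = float(np.mean(np.abs(prediction - label)))
    # "total" is the entry that gets backpropagated; the other entries are only logged.
    return {"mse": mse, "l1": l1, "total": w_mse * mse + w_l1 * l1}


losses = loss_sketch(np.array([0.0, 1.0]), np.array([1.0, 1.0]))
print(losses)  # {'mse': 0.5, 'l1': 0.5, 'total': 0.75}
```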
- finetune(freeze_head: bool = False)
Prepare the model for fine-tuning by freezing everything except the final layer, and setting it to “train” mode.
- Parameters:
freeze_head (bool, optional) – Whether to freeze the head module as well, defaults to False.
- metrics(pred: ncalab.prediction.Prediction, label: torch.Tensor) Dict[str, float]
Return dict of standard evaluation metrics, given a prediction and corresponding ground truth label.
- Parameters:
pred (Prediction) – Prediction.
label (torch.Tensor) – Ground truth label.
- Returns:
Dict of metrics, mapped by their names.
- Return type:
Dict[str, float]
- predict(image: torch.Tensor, steps: int = 100) ncalab.prediction.Prediction
Make an NCA prediction, performing multiple forward passes to yield a final result.
- Parameters:
image (torch.Tensor) – Input image, BCWH.
steps (int, optional) – Number of time steps, defaults to 100.
- Returns:
Prediction object.
- Return type:
Prediction
- record(image: torch.Tensor, steps: int | None = None) List[ncalab.prediction.Prediction]
Record predictions for all time steps and return the resulting sequence of predictions.
- Parameters:
image (torch.Tensor) – Input image, BCWH.
steps (int, optional) – Number of time steps to record, defaults to None.
- Returns:
List of Prediction objects.
- Return type:
List[Prediction]
- validate(image: torch.Tensor, label: torch.Tensor, steps: int | None = None) Tuple[Dict[str, float], ncalab.prediction.Prediction] | None
Make a prediction on an image of the validation set and return metrics computed with respect to a labelled validation image.
- Parameters:
image (torch.Tensor) – Input image, BCWH.
label (torch.Tensor) – Ground truth label.
steps (int, optional) – Inference steps, defaults to None.
- Returns:
Dict of validation metrics and the predicted image (BCWH), or None.
- Return type:
Tuple[Dict[str, float], Prediction] | None
- _to_dict() Dict[str, Any]
- to_dict() Dict[str, Any]
- num_trainable_parameters() int
Returns the number of trainable model parameters.
- Returns:
Number of trainable parameters.
- Return type:
int
- class ncalab.training.Prediction(model, steps: int, output_image: torch.Tensor, head_prediction: torch.Tensor | None = None)
Stores the result of an NCA prediction, including the number of steps it took.
Sequences are typically produced by BasicNCAModel’s “record” function, which returns them as a list of Prediction objects.
The constructor is typically not called explicitly. Rather, the forward pass of BasicNCAModel (and its subclasses) is responsible for filling its attributes.
- Parameters:
model (ncalab.BasicNCAModel) – Reference to model used for prediction.
steps (int) – Number of steps taken for the prediction.
output_image (torch.Tensor) – Output image tensor.
head_prediction (torch.Tensor, optional) – Output of the optional head module, defaults to None.
- model
- steps
- output_image
- _output_array: numpy.ndarray | None = None
- head_prediction = None
- _head_prediction_array: numpy.ndarray | None = None
- property image_channels: torch.Tensor
Convenience property to access the image channels as a Tensor.
- Returns:
BCWH Tensor
- Return type:
torch.Tensor
Convenience property to access the hidden channels as a Tensor.
- Returns:
BCWH Tensor
- Return type:
torch.Tensor
- property output_channels: torch.Tensor
Convenience property to access the output channels as a Tensor.
- Returns:
BCWH Tensor
- Return type:
torch.Tensor
- property output_array: numpy.ndarray
Convenience property to access the whole output image in the format of a numpy array. Brings the entire tensor to CPU on demand, and only at the first call.
- Returns:
Numpy array in BCWH format
- Return type:
np.ndarray
- property image_channels_np: numpy.ndarray
Convenience property to access the output image channels in the format of a numpy array. Brings the entire tensor to CPU on demand, and only at the first call.
- Returns:
Numpy array in BCWH format
- Return type:
np.ndarray
Convenience property to access the hidden image channels in the format of a numpy array. Brings the entire tensor to CPU on demand, and only at the first call.
- Returns:
Numpy array in BCWH format
- Return type:
np.ndarray
- property output_channels_np: numpy.ndarray
Convenience property to access the image’s output channels in the format of a numpy array. Brings the entire tensor to CPU on demand, and only at the first call.
- Returns:
Numpy array in BCWH format
- Return type:
np.ndarray
- property head_prediction_array: numpy.ndarray | None
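The *_array and *_np properties convert the tensor only once and reuse the result afterwards. That lazy-caching pattern can be sketched as follows. to_numpy is a hypothetical stand-in for moving a tensor to the CPU and converting it (with PyTorch this would typically be tensor.detach().cpu().numpy()), and LazyArray is not ncalab's Prediction class.

```python
from typing import Callable, List, Optional


class LazyArray:
    """Caches an expensive conversion in a private attribute on first access."""

    def __init__(self, data: List[float], to_numpy: Callable[[List[float]], list]) -> None:
        self._data = data
        self._to_numpy = to_numpy
        self._output_array: Optional[list] = None  # filled on first access

    @property
    def output_array(self) -> list:
        if self._output_array is None:
            self._output_array = self._to_numpy(self._data)
        return self._output_array


calls = []


def convert(data: List[float]) -> list:
    calls.append(1)  # count conversions
    return list(data)


pred = LazyArray([1.0, 2.0], convert)
first = pred.output_array
second = pred.output_array
print(len(calls), first is second)  # 1 True
```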
- ncalab.training.intepret_range_parameter(x: int | Tuple[int, int]) int
Interpret a range parameter that is passed for NCA timesteps.
If the parameter is a single int, it is returned as is. If the parameter is a two-valued tuple, it is interpreted as a half-open range [min, max) from which a value is sampled randomly.
- Parameters:
x (int | Tuple[int, int]) – Fixed number of timesteps, or a (min, max) range to sample from.
- Raises:
TypeError – If anything other than an int or a tuple is passed.
- Returns:
Number of timesteps.
- Return type:
int
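This behaviour can be sketched with the standard random module. random.Random.randrange excludes its upper bound, which matches the [min, max) interval. The name interpret_range_sketch is hypothetical, and checking the tuple length this way is an assumption about how ncalab validates its input.

```python
import random
from typing import Optional, Tuple, Union


def interpret_range_sketch(x: Union[int, Tuple[int, int]],
                           rng: Optional[random.Random] = None) -> int:
    """Return x if it is an int; sample uniformly from [min, max) if it is a (min, max) tuple."""
    if rng is None:
        rng = random.Random()
    if isinstance(x, int):
        return x
    if isinstance(x, tuple) and len(x) == 2:
        low, high = x
        return rng.randrange(low, high)  # upper bound is excluded: [low, high)
    raise TypeError(f"expected an int or a (min, max) tuple, got {type(x).__name__}")


steps = interpret_range_sketch((64, 96), random.Random(0))
print(64 <= steps < 96)  # True
```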
- ncalab.training.pad_input(x: torch.Tensor, nca: ncalab.models.BasicNCAModel, noise: bool = True, mean: float = 0.5, std: float = 0.225) torch.Tensor
Pads the BCWH input tensor along its channel dimension to match the expected number of channels required by the NCA model. Pads with either Gaussian noise (parameterized by mean and std) or zeros, depending on the “noise” parameter.
- Parameters:
x (torch.Tensor) – Input image tensor, BCWH.
nca (ncalab.BasicNCAModel) – NCA model definition.
noise (bool, optional) – Whether to pad with noise. Otherwise zeros, defaults to True.
mean (float, optional) – Mean (mu) of Gaussian noise distribution, defaults to 0.5.
std (float, optional) – Standard deviation (sigma) of Gaussian noise distribution, defaults to 0.225.
- Returns:
Input tensor, BCWH, padded along the channel dimension.
- Return type:
torch.Tensor
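The padding step can be illustrated with NumPy. This sketch assumes that the model's total channel count (num_channels) is the target size and that the original image channels come first. pad_channels and its num_channels argument are hypothetical stand-ins for pad_input, which takes the model itself.

```python
import numpy as np


def pad_channels(x, num_channels, noise=True, mean=0.5, std=0.225, rng=None):
    """Pad a BCWH array along axis 1 up to num_channels with Gaussian noise or zeros."""
    b, c, w, h = x.shape
    if c > num_channels:
        raise ValueError("input already has more channels than the model expects")
    shape = (b, num_channels - c, w, h)
    if noise:
        rng = rng if rng is not None else np.random.default_rng()
        pad = rng.normal(loc=mean, scale=std, size=shape).astype(x.dtype)
    else:
        pad = np.zeros(shape, dtype=x.dtype)
    # Image channels first, padding channels after them.
    return np.concatenate([x, pad], axis=1)


image = np.ones((1, 3, 8, 8), dtype=np.float32)
padded = pad_channels(image, num_channels=16, noise=False)
print(padded.shape)  # (1, 16, 8, 8)
```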
- ncalab.training.unwrap(x: Any)
Panics if x is None, otherwise returns x.
This is a useful shorthand for cases such as
x = unwrap(some_object).do_something()
in which we are 99% certain that some_object is not None and want to avoid a mypy complaint.
- Parameters:
x (Any) – Any kind of object.
- Raises:
RuntimeError – If x is None.
- Returns:
Just passes through the input x if it is not None.
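When typed with a TypeVar, such a helper lets mypy narrow Optional[T] to T at the call site. A minimal sketch under that assumption follows; the name unwrap_sketch and the error message are hypothetical.

```python
from typing import Optional, TypeVar

T = TypeVar("T")


def unwrap_sketch(x: Optional[T]) -> T:
    """Return x unchanged; raise RuntimeError if it is None."""
    if x is None:
        raise RuntimeError("unwrap called on None")
    return x


name: Optional[str] = "nca"
print(unwrap_sketch(name).upper())  # NCA
```

Because the check is `x is None` rather than a truthiness test, falsy values such as 0 or an empty string pass through unchanged.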
- class ncalab.training.Visual
Base class for tensorboard visuals.
- show(model, image: numpy.ndarray, prediction: ncalab.prediction.Prediction, label: numpy.ndarray) matplotlib.figure.Figure
- class ncalab.training.Pool(n_seed: int = 1, damage: bool = False, p_damage: float = 0.2)
Sample pool that retains previous predictions. Also applies damaging patterns to images to increase the robustness of the trained NCA.
- Parameters:
n_seed (int, optional) – How many seed images to retain, defaults to 1
damage (bool, optional) – Whether to apply damaging patterns, defaults to False
p_damage (float, optional) – Probability at which a damaging pattern is applied, defaults to 0.2
- n_seed = 1
- damage = False
- batch: torch.Tensor | None = None
- p_damage = 0.2
- update(batch: torch.Tensor)
- Parameters:
batch – BCWH
- sample(seed: torch.Tensor) torch.Tensor
- Parameters:
seed – BCWH
- Returns:
BCWH
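Damaging patterns are commonly implemented by zeroing a random region of some samples, so that the NCA learns to regenerate it. The sketch below applies a circular damage mask to each BCWH sample with probability p_damage. The circle shape, the radius range, the per-sample draw and the function name are all assumptions; ncalab's Pool may use different patterns.

```python
import numpy as np


def damage_batch(batch, p_damage=0.2, rng=None):
    """Return a copy of a BCWH batch in which each sample is damaged with probability p_damage."""
    rng = rng if rng is not None else np.random.default_rng()
    out = batch.copy()
    b, _, w, h = batch.shape
    yy, xx = np.meshgrid(np.arange(w), np.arange(h), indexing="ij")  # cell coordinates, shape (w, h)
    for i in range(b):
        if rng.random() < p_damage:
            cy, cx = rng.integers(0, w), rng.integers(0, h)
            radius = rng.uniform(0.1, 0.4) * min(w, h)
            mask = (yy - cy) ** 2 + (xx - cx) ** 2 < radius ** 2
            out[i][:, mask] = 0.0  # zero all channels inside the circle
    return out


batch = np.ones((4, 3, 16, 16))
damaged = damage_batch(batch, p_damage=1.0, rng=np.random.default_rng(0))
```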
- class ncalab.training.TrainingStatus(*args, **kwds)
Bases:
enum.Enum
Encodes the last status of the training.
- STATUS_NONE = 0
- STATUS_RUNNING = 1
- STATUS_DONE = 2