ncalab.search

Classes

BasicNCATrainer

Trainer class for any model subclassing BasicNCA.

ParameterSet

ParameterSearch

Package Contents

class ncalab.search.BasicNCATrainer(nca: ncalab.models.basicNCA.AbstractNCAModel, model_path: pathlib.Path | pathlib.PosixPath | None, gradient_clipping: bool = False, lr: float | None = None, lr_gamma: float = 0.99, adam_betas=(0.9, 0.95), batch_repeat: int = 2, max_epochs: int = 200, optimizer_method: str = 'adam', pool: ncalab.training.pool.Pool | None = None, lr_scheduler: torch.optim.lr_scheduler.LRScheduler | None = None)

Trainer class for any model subclassing BasicNCA.

Parameters:
  • nca (ncalab.AbstractNCAModel) – NCA model instance to train.

  • model_path (Path | PosixPath, optional) – Path to saved models. If None, models are not saved, defaults to None.

  • gradient_clipping (bool, optional) – Whether to clip gradients, defaults to False.

  • lr (float, optional) – Initial learning rate, defaults to 16e-4.

  • lr_gamma (float, optional) – Decay factor for exponential learning rate decay, defaults to 0.99.

  • adam_betas (tuple, optional) – Beta values for Adam optimizer, defaults to (0.9, 0.95).

  • batch_repeat (int, optional) – How often each batch is duplicated, defaults to 2.

  • max_epochs (int, optional) – Maximum number of training epochs, defaults to 200.

  • optimizer_method (str, optional) – Optimization method, defaults to ‘adam’.

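  • pool (ncalab.Pool, optional) – Sample pool object, defaults to None.

  • lr_scheduler (torch.optim.lr_scheduler.LRScheduler, optional) – Learning rate scheduler, defaults to None.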
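The lr_gamma and batch_repeat parameters can be illustrated in plain Python. The sketch makes two assumptions. First, the decay is applied once per epoch, as torch.optim.lr_scheduler.ExponentialLR does. Second, duplicating a batch means concatenating copies of it. exponential_lr and repeat_batch are hypothetical helpers, not part of ncalab.

```python
from typing import List, Sequence, TypeVar

T = TypeVar("T")


def exponential_lr(lr0: float, gamma: float, epoch: int) -> float:
    """Learning rate after `epoch` decay steps: lr0 * gamma ** epoch.

    Hypothetical helper; assumes one decay step per epoch.
    """
    return lr0 * gamma ** epoch


def repeat_batch(batch: Sequence[T], batch_repeat: int) -> List[T]:
    """Concatenate `batch_repeat` copies of a batch.

    Hypothetical list-based stand-in for duplicating a tensor batch.
    """
    return list(batch) * batch_repeat


# Learning rate over the first few epochs with lr=16e-4 and lr_gamma=0.99.
for epoch in range(3):
    print(epoch, exponential_lr(16e-4, 0.99, epoch))

print(repeat_batch(["img0", "img1"], 2))  # ['img0', 'img1', 'img0', 'img1']
```

NCA updates are often stochastic. Repeating a batch would therefore let the same samples contribute several different trajectories per step.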

nca
model_path
gradient_clipping = False
lr_gamma = 0.99
adam_betas = (0.9, 0.95)
batch_repeat = 2
max_epochs = 200
optimizer_method = 'adam'
pool = None
lr_scheduler = None
info() str

Returns a markdown-formatted string with the training parameters. Useful for logging to TensorBoard to keep track of parameter changes between runs.

Returns [str]:

Markdown-formatted info string.
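A markdown info string of this kind can be built as a two-column table. The helper format_info and the table layout are assumptions for illustration. The reference does not show the exact format that info() produces.

```python
from typing import Any, Mapping


def format_info(params: Mapping[str, Any]) -> str:
    """Render training parameters as a markdown table (hypothetical layout)."""
    lines = ["| Parameter | Value |", "| --- | --- |"]
    for name, value in params.items():
        lines.append(f"| {name} | {value} |")
    return "\n".join(lines)


info = format_info({"lr_gamma": 0.99, "batch_repeat": 2, "max_epochs": 200})
print(info)
```

A string like this can be passed to SummaryWriter.add_text. TensorBoard renders that text as markdown, so parameter sets from different runs stay easy to compare.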

_train_iteration(x: torch.Tensor, y: torch.Tensor, optimizer: torch.optim.Optimizer, head_optimizer: torch.optim.Optimizer | None, total_batch_iterations: int, summary_writer: torch.utils.tensorboard.SummaryWriter | None = None) Tuple[ncalab.prediction.Prediction, Dict[str, torch.Tensor]]

Run a single training iteration.

Parameters:
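  • x (torch.Tensor) – Input training images.

  • y (torch.Tensor) – Training labels.

  • optimizer (torch.optim.Optimizer) – Optimizer.

  • head_optimizer (torch.optim.Optimizer, optional) – Separate optimizer for the model head, or None.

  • total_batch_iterations (int) – Total number of training batch iterations.

  • summary_writer (SummaryWriter, optional) – TensorBoard SummaryWriter, defaults to None.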

Returns:

A tuple of the Prediction and a dictionary of named tensors.

Return type:

Tuple[Prediction, Dict[str, torch.Tensor]]
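A single training iteration has these stages: forward pass, loss, gradient, optional gradient clipping, and optimizer step. The sketch below shows them on a toy scalar model. It is a stand-in, not ncalab's _train_iteration. It fits y = w * x with mean squared error and replaces the torch optimizer with a plain gradient-descent update. For clipping it clamps the gradient magnitude, which is what norm clipping reduces to for a single scalar. train_iteration is a hypothetical helper.

```python
import math
from typing import List, Optional, Tuple


def train_iteration(
    w: float,
    xs: List[float],
    ys: List[float],
    lr: float,
    clip_norm: Optional[float] = None,
) -> Tuple[float, float]:
    """One toy iteration for y = w * x (hypothetical, not ncalab's method).

    Returns the updated weight and the loss computed before the update.
    """
    n = len(xs)
    # Forward pass and mean squared error loss.
    preds = [w * x for x in xs]
    loss = sum((p - y) ** 2 for p, y in zip(preds, ys)) / n
    # Analytic gradient of the loss with respect to w.
    grad = sum(2 * (p - y) * x for p, y, x in zip(preds, ys, xs)) / n
    # Optional clipping: for one scalar, clipping by norm clamps the magnitude.
    if clip_norm is not None and abs(grad) > clip_norm:
        grad = math.copysign(clip_norm, grad)
    # Plain gradient-descent step standing in for optimizer.step().
    return w - lr * grad, loss


w, loss = train_iteration(0.0, [1.0, 2.0], [2.0, 4.0], lr=0.1)
print(w, loss)
```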

train(dataloader_train: torch.utils.data.DataLoader, dataloader_val: torch.utils.data.DataLoader | None = None, dataloader_test: torch.utils.data.DataLoader | None = None, save_every: int | None = None, summary_writer: torch.utils.tensorboard.SummaryWriter | None = None, plot_function: ncalab.visualization.Visual | None = None, earlystopping: ncalab.training.earlystopping.EarlyStopping | None = None) ncalab.training.traininghistory.TrainingHistory

Execute basic NCA training loop with a single function call.

Parameters:
  • dataloader_train (DataLoader) – Training DataLoader.

  • dataloader_val (DataLoader, optional) – Validation DataLoader, defaults to None.

  • dataloader_test (DataLoader, optional) – Test DataLoader, defaults to None.

  • save_every (int, optional) – How often to save the model state, in epochs. Useful for very small datasets, such as the growing-lizard example. Defaults to None.

  • summary_writer (SummaryWriter, optional) – TensorBoard SummaryWriter, defaults to None.

  • plot_function (Visual, optional) – Plot function override. If None, the model's default is used. Defaults to None.

  • earlystopping (EarlyStopping, optional) – EarlyStopping object, defaults to None.

Returns [TrainingHistory]:

TrainingHistory object.
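The shape of the loop behind train() can be pictured with a framework-free sketch. This is not ncalab's implementation, and it rests on several assumptions:

- train_epoch and validate are hypothetical callables that return a loss.
- Checkpoints are recorded rather than written to model_path.
- The patience rule stands in for whatever criterion the EarlyStopping object actually applies.

```python
from typing import Callable, Dict, List, Optional


def run_training(
    train_epoch: Callable[[int], float],
    validate: Optional[Callable[[int], float]],
    max_epochs: int,
    save_every: Optional[int] = None,
    patience: Optional[int] = None,
) -> Dict[str, List]:
    """Epoch loop with periodic checkpoints and patience-based early stopping.

    Hypothetical sketch; returns a plain dict instead of a TrainingHistory.
    """
    history: Dict[str, List] = {"train_loss": [], "val_loss": [], "saved_at": []}
    best = float("inf")
    epochs_without_improvement = 0
    for epoch in range(max_epochs):
        history["train_loss"].append(train_epoch(epoch))
        # Record a checkpoint every `save_every` epochs (stand-in for saving).
        if save_every is not None and (epoch + 1) % save_every == 0:
            history["saved_at"].append(epoch)
        if validate is None:
            continue
        val_loss = validate(epoch)
        history["val_loss"].append(val_loss)
        if val_loss < best:
            best, epochs_without_improvement = val_loss, 0
        else:
            epochs_without_improvement += 1
            # Stop once validation loss has not improved for `patience` epochs.
            if patience is not None and epochs_without_improvement >= patience:
                break
    return history
```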

class ncalab.search.ParameterSet(**kwargs)
params
mutable
combinations = []
index = 0
is_mutable(key)
info()
next() Dict[str, Any]
num_combinations()
__len__()
__next__()
__iter__()
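ParameterSet is documented only by its members. Its member names (combinations, is_mutable, num_combinations) suggest that it enumerates combinations of parameter values. The sketch below shows that idea as a plain Cartesian product. It assumes that list-valued keyword arguments are the varied ("mutable") parameters. The helper expand_grid and the parameter name hidden_channels are also illustrative assumptions, not ncalab's API.

```python
import itertools
from typing import Any, Dict, List


def expand_grid(**kwargs: Any) -> List[Dict[str, Any]]:
    """Expand list-valued keyword arguments into all combinations.

    Hypothetical helper: list values are treated as the parameters to vary,
    every other value is held fixed in each combination.
    """
    varied = {k: v for k, v in kwargs.items() if isinstance(v, list)}
    fixed = {k: v for k, v in kwargs.items() if not isinstance(v, list)}
    combos = []
    for values in itertools.product(*varied.values()):
        combo = dict(fixed)
        combo.update(zip(varied.keys(), values))
        combos.append(combo)
    return combos


grid = expand_grid(hidden_channels=[16, 32], lr=[1e-3, 1e-4], max_epochs=200)
print(len(grid))  # 4
```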
class ncalab.search.ParameterSearch(device, model_class, model_params: ParameterSet, trainer_params: ParameterSet)
device
model_class
model_params
trainer_params
info() str

Generate information string with a summary of the search to run.

search(dataloader_train: torch.utils.data.DataLoader, dataloader_val: torch.utils.data.DataLoader | None = None)

Run the search.

Parameters:
  • dataloader_train (DataLoader) – Training DataLoader.

  • dataloader_val (DataLoader, optional) – Validation DataLoader, defaults to None.

__call__(*args, **kwargs)

Shorthand for running the search.
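One plausible reading of ParameterSearch is an exhaustive grid search over every pairing of model and trainer parameter combinations. The reference does not describe the internals of search(), so the sketch below rests on assumptions. It treats a lower score as better and has a hypothetical evaluate callable train and validate one configuration. grid_search and evaluate are illustrative names only.

```python
import itertools
from typing import Any, Callable, Dict, List, Optional, Tuple

Params = Dict[str, Any]


def grid_search(
    model_grid: List[Params],
    trainer_grid: List[Params],
    evaluate: Callable[[Params, Params], float],
) -> Tuple[Params, Params, float]:
    """Evaluate every (model, trainer) pair and return the lowest-scoring one.

    Hypothetical sketch; `evaluate` stands in for training plus validation.
    """
    best: Optional[Tuple[Params, Params, float]] = None
    for model_params, trainer_params in itertools.product(model_grid, trainer_grid):
        score = evaluate(model_params, trainer_params)
        if best is None or score < best[2]:
            best = (model_params, trainer_params, score)
    if best is None:
        raise ValueError("both parameter grids must be non-empty")
    return best
```

An exhaustive search is simple and reproducible. Its cost is the product of both grid sizes, so each additional varied parameter multiplies the number of training runs.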