ncalab.training.trainer
Classes
BasicNCATrainer – Trainer class for any model subclassing BasicNCA.
Module Contents
- class ncalab.training.trainer.BasicNCATrainer(nca: ncalab.models.basicNCA.AbstractNCAModel, model_path: pathlib.Path | pathlib.PosixPath | None, gradient_clipping: bool = False, lr: float | None = None, lr_gamma: float = 0.99, adam_betas=(0.9, 0.95), batch_repeat: int = 2, max_epochs: int = 200, optimizer_method: str = 'adam', pool: ncalab.training.pool.Pool | None = None, lr_scheduler: torch.optim.lr_scheduler.LRScheduler | None = None)
Trainer class for any model subclassing BasicNCA.
- Parameters:
nca (ncalab.AbstractNCAModel) – NCA model instance to train.
model_path (Path | PosixPath | None) – Path to saved models. If None, models are not saved.
gradient_clipping (bool, optional) – Whether to clip gradients, defaults to False.
lr (float, optional) – Initial learning rate. The signature default is None; the documented default value is 16e-4.
lr_gamma (float, optional) – Decay factor for exponential learning rate decay, defaults to 0.99.
adam_betas (tuple, optional) – Beta values for Adam optimizer, defaults to (0.9, 0.95).
batch_repeat (int, optional) – How often each batch is duplicated, defaults to 2.
max_epochs (int, optional) – Maximum number of training epochs, defaults to 200.
optimizer_method (str, optional) – Optimization method, defaults to ‘adam’.
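pool (ncalab.Pool, optional) – Sample pool object, defaults to None.
lr_scheduler (LRScheduler, optional) – Learning rate scheduler, defaults to None.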
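The lr and lr_gamma parameters together describe an exponentially decaying learning rate. The sketch below is a rough illustration under two assumptions. First, it assumes one multiplicative decay step per epoch, which is how torch.optim.lr_scheduler.ExponentialLR behaves when stepped once per epoch. Second, it assumes an initial rate of 16e-4. How often ncalab actually applies the decay is not stated here, and lr_at_epoch is a hypothetical helper, not part of ncalab.

```python
def lr_at_epoch(lr0: float, gamma: float, epoch: int) -> float:
    """Learning rate after `epoch` decay steps: lr0 * gamma ** epoch.

    Hypothetical helper, not an ncalab function. It assumes exactly one
    multiplicative decay step per epoch.
    """
    if epoch < 0:
        raise ValueError("epoch must be non-negative")
    return lr0 * gamma ** epoch


lr0 = 16e-4  # initial learning rate quoted in the parameter docs
for gamma in (0.99, 0.9999):
    # Learning rate remaining after the default max_epochs = 200
    print(f"gamma={gamma}: lr after 200 epochs = {lr_at_epoch(lr0, gamma, 200):.3e}")
```

Under this assumption, lr_gamma = 0.99 leaves about 13% of the initial rate (roughly 2.1e-4) after 200 epochs. A gamma of 0.9999 would leave about 98%, so the choice of gamma changes the schedule substantially.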
- nca
- model_path
- gradient_clipping = False
- lr_gamma = 0.99
- adam_betas = (0.9, 0.95)
- batch_repeat = 2
- max_epochs = 200
- optimizer_method = 'adam'
- pool = None
- lr_scheduler = None
- info() str
Returns a Markdown-formatted string with the training parameters. Useful for logging to TensorBoard to keep track of parameter changes.
- Returns [str]:
Markdown-formatted info string.
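The exact layout of the string returned by info() is not documented here. The sketch below assumes a simple two-column Markdown table built from a dictionary of parameter values. format_params_markdown is a hypothetical stand-in, not ncalab's implementation. With PyTorch, a string like this could be logged to TensorBoard's text dashboard via SummaryWriter.add_text(tag, text).

```python
from typing import Any, Mapping


def format_params_markdown(params: Mapping[str, Any]) -> str:
    """Render training parameters as a two-column Markdown table.

    Hypothetical stand-in for BasicNCATrainer.info(); ncalab's real
    layout may differ.
    """
    lines = ["| Parameter | Value |", "| --- | --- |"]
    for name, value in params.items():
        # One table row per parameter, using the value's str() form
        lines.append(f"| {name} | {value} |")
    return "\n".join(lines)


# Values taken from the documented BasicNCATrainer defaults
print(format_params_markdown({
    "lr_gamma": 0.99,
    "adam_betas": (0.9, 0.95),
    "batch_repeat": 2,
    "max_epochs": 200,
    "optimizer_method": "adam",
}))
```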
- _train_iteration(x: torch.Tensor, y: torch.Tensor, optimizer: torch.optim.Optimizer, head_optimizer: torch.optim.Optimizer | None, total_batch_iterations: int, summary_writer: torch.utils.tensorboard.SummaryWriter | None = None) Tuple[ncalab.prediction.Prediction, Dict[str, torch.Tensor]]
Run a single training iteration.
- Parameters:
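x (torch.Tensor) – Input training images.
y (torch.Tensor) – Training labels.
optimizer (torch.optim.Optimizer) – Optimizer.
head_optimizer (torch.optim.Optimizer, optional) – Separate optimizer for the model head, or None.
total_batch_iterations (int) – Total number of training batch iterations.
summary_writer (SummaryWriter, optional) – TensorBoard SummaryWriter, defaults to None.
- Returns:
Tuple of the prediction and a dictionary of named tensors.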
- Return type:
Tuple[Prediction, Dict[str, torch.Tensor]]
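A single training iteration usually follows a fixed pattern: forward pass, loss, gradients, optional gradient clipping, then an optimizer step. The dependency-free sketch below shows that pattern on a toy one-parameter linear model with hand-computed gradients. It is not ncalab's implementation. The clipping rule is an assumption: gradients are rescaled when their L2 norm exceeds max_norm, as torch.nn.utils.clip_grad_norm_ does. The names train_iteration_sketch and clip_by_norm are hypothetical. Like _train_iteration, the sketch returns the prediction together with a dictionary of named values. It also returns the updated parameters.

```python
import math
from typing import Dict, List, Tuple


def clip_by_norm(grads: List[float], max_norm: float) -> List[float]:
    """Rescale grads so that their L2 norm is at most max_norm."""
    norm = math.sqrt(sum(g * g for g in grads))
    if norm > max_norm:
        scale = max_norm / norm
        return [g * scale for g in grads]
    return grads


def train_iteration_sketch(
    w: float,
    b: float,
    x: List[float],
    y: List[float],
    lr: float,
    gradient_clipping: bool = False,
    max_norm: float = 1.0,
) -> Tuple[List[float], Dict[str, float], float, float]:
    """One gradient-descent step on y_hat = w * x + b with MSE loss."""
    n = len(x)
    # Forward pass and loss
    pred = [w * xi + b for xi in x]
    residuals = [p - yi for p, yi in zip(pred, y)]
    loss = sum(r * r for r in residuals) / n
    # Gradients of the MSE loss with respect to w and b
    grad_w = 2.0 * sum(r * xi for r, xi in zip(residuals, x)) / n
    grad_b = 2.0 * sum(residuals) / n
    grads = [grad_w, grad_b]
    if gradient_clipping:
        grads = clip_by_norm(grads, max_norm)
    # Optimizer step: plain gradient descent here, not Adam
    w -= lr * grads[0]
    b -= lr * grads[1]
    return pred, {"loss": loss}, w, b


# Fit y = 2x + 1 by repeating the iteration
w, b = 0.0, 0.0
xs, ys = [0.0, 1.0, 2.0], [1.0, 3.0, 5.0]
for _ in range(2000):
    pred, metrics, w, b = train_iteration_sketch(w, b, xs, ys, lr=0.05)
print(round(w, 3), round(b, 3), metrics["loss"])
```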
- train(dataloader_train: torch.utils.data.DataLoader, dataloader_val: torch.utils.data.DataLoader | None = None, dataloader_test: torch.utils.data.DataLoader | None = None, save_every: int | None = None, summary_writer: torch.utils.tensorboard.SummaryWriter | None = None, plot_function: ncalab.visualization.Visual | None = None, earlystopping: ncalab.training.earlystopping.EarlyStopping | None = None) ncalab.training.traininghistory.TrainingHistory
Execute basic NCA training loop with a single function call.
- Parameters:
dataloader_train (DataLoader) – Training DataLoader.
dataloader_val (DataLoader, optional) – Validation DataLoader, defaults to None.
dataloader_test (DataLoader, optional) – Test DataLoader, defaults to None.
save_every (int, optional) – How often to save the model state (in epochs). Useful for very small datasets, such as the growing lizard. Defaults to None.
summary_writer (SummaryWriter, optional) – TensorBoard SummaryWriter, defaults to None.
plot_function (Visual, optional) – Plot function override. If None, the model's default is used. Defaults to None.
earlystopping (EarlyStopping, optional) – EarlyStopping object, defaults to None.
- Returns [TrainingHistory]:
TrainingHistory object.
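train() wraps the per-iteration step in an epoch loop. The plain-Python sketch below shows one way max_epochs, save_every and early stopping can interact in such a loop. Several parts are hypothetical illustrations: the callbacks run_epoch, validate and save, the patience rule, and the list-of-dicts history. The interfaces of ncalab's EarlyStopping and TrainingHistory are not documented in this section.

```python
from typing import Callable, List, Optional


def training_loop_sketch(
    run_epoch: Callable[[int], float],
    validate: Callable[[int], float],
    save: Callable[[int], None],
    max_epochs: int = 200,
    save_every: Optional[int] = None,
    patience: Optional[int] = None,
) -> List[dict]:
    """Run up to max_epochs epochs and return one history record per epoch.

    Hypothetical sketch, not ncalab's train(); all callbacks are stand-ins.
    """
    history: List[dict] = []
    best_val = float("inf")
    epochs_without_improvement = 0
    for epoch in range(max_epochs):
        train_loss = run_epoch(epoch)
        val_loss = validate(epoch)
        history.append({"epoch": epoch, "train_loss": train_loss, "val_loss": val_loss})
        # Periodic checkpoint every save_every epochs
        if save_every is not None and (epoch + 1) % save_every == 0:
            save(epoch)
        # Patience-based early stopping on the validation loss
        if val_loss < best_val:
            best_val = val_loss
            epochs_without_improvement = 0
        else:
            epochs_without_improvement += 1
        if patience is not None and epochs_without_improvement >= patience:
            break
    return history


val_losses = [1.0, 0.8, 0.7, 0.75, 0.72, 0.71, 0.9]
saved: List[int] = []
history = training_loop_sketch(
    run_epoch=lambda e: 1.0 / (e + 1),
    validate=lambda e: val_losses[e],
    save=saved.append,
    max_epochs=len(val_losses),
    save_every=2,
    patience=3,
)
print(len(history), saved)
```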