ncalab.training.distillation_trainer
Classes
Module Contents
- class ncalab.training.distillation_trainer.DistillationNCATrainer(nca: ncalab.models.basicNCA.BasicNCAModel, teacher: torch.nn.Module, model_path: pathlib.Path | pathlib.PosixPath, gradient_clipping: bool = False, steps_range: tuple = (90, 110), steps_validation: int = 100, lr: float | None = None, lr_gamma: float = 0.99, adam_betas=(0.4, 0.95), batch_repeat: int = 2, max_epochs: int = 200, optimizer_method: str = 'adamw', pool: ncalab.training.pool.Pool | None = None)
- Parameters:
nca (ncalab.BasicNCAModel) – NCA model instance to train.
teacher (torch.nn.Module) – Teacher model that the NCA is distilled from.
model_path (Path | PosixPath) – Path to saved models. The signature gives no default, so it must be passed explicitly.
gradient_clipping (bool, optional) – Whether to clip gradients, defaults to False.
steps_range (tuple, optional) – Inclusive range of NCA time steps, randomized in each forward pass, defaults to (90, 110).
steps_validation (int, optional) – Number of steps to use during validation, defaults to 100.
lr (float, optional) – Initial learning rate, defaults to None (documented as 16e-4).
lr_gamma (float, optional) – Multiplicative factor for exponential learning rate decay, defaults to 0.99.
adam_betas (tuple, optional) – Beta values for the Adam optimizer, defaults to (0.4, 0.95).
batch_repeat (int, optional) – How often each batch is duplicated, defaults to 2.
max_epochs (int, optional) – Maximum number of training epochs, defaults to 200.
optimizer_method (str, optional) – Optimization method, defaults to ‘adamw’.
pool (ncalab.Pool, optional) – Sample pool object, defaults to None.
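The parameters `steps_range`, `lr_gamma` and `batch_repeat` each describe a small piece of arithmetic. The plain-Python sketch below illustrates those semantics under stated assumptions: the helper names are hypothetical, the decay is assumed to be applied once per epoch, and the duplication order of `repeat_batch` is a guess; none of this is the trainer's actual code.

```python
import random

def sample_steps(steps_range: tuple, rng: random.Random) -> int:
    """Draw an NCA step count from an inclusive (low, high) range."""
    low, high = steps_range
    # random.randint includes both endpoints, matching an inclusive range
    return rng.randint(low, high)

def lr_at_epoch(lr0: float, lr_gamma: float, epoch: int) -> float:
    """Learning rate after `epoch` decay steps, assuming one multiplicative step per epoch."""
    return lr0 * lr_gamma ** epoch

def repeat_batch(batch: list, batch_repeat: int) -> list:
    """Duplicate a batch `batch_repeat` times (hypothetical helper; ordering is an assumption)."""
    return list(batch) * batch_repeat

rng = random.Random(0)
steps = [sample_steps((90, 110), rng) for _ in range(1000)]
print(min(steps) >= 90 and max(steps) <= 110)  # True
print(lr_at_epoch(16e-4, 0.99, 100))           # ~5.86e-4 after 100 epochs
print(repeat_batch(["a", "b"], 2))             # ['a', 'b', 'a', 'b']
```

Randomizing the step count per forward pass keeps the NCA from overfitting to one exact number of updates, while `steps_validation` fixes it for reproducible evaluation.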
- nca
- teacher
- model_path
- gradient_clipping = False
- steps_range = (90, 110)
- steps_validation = 100
- lr_gamma = 0.99
- adam_betas = (0.4, 0.95)
- batch_repeat = 2
- max_epochs = 200
- optimizer_method = 'adamw'
- pool = None
- info() → str
Returns a Markdown-formatted string listing the training parameters. Useful for logging to TensorBoard to keep track of parameter changes between runs.
- Returns [str]:
Markdown-formatted info string.
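The exact layout of the string returned by `info()` is not documented here. The sketch below shows one plausible way to render parameters as a Markdown table; `format_info` and the table layout are hypothetical.

```python
def format_info(params: dict) -> str:
    """Render training parameters as a two-column Markdown table (hypothetical helper)."""
    lines = ["| Parameter | Value |", "| --- | --- |"]
    for name, value in params.items():
        # str() keeps tuples and strings readable, e.g. (90, 110) and adamw
        lines.append(f"| {name} | {value} |")
    return "\n".join(lines)

# Constructor defaults from the signature above, used as sample input
defaults = {
    "steps_range": (90, 110),
    "lr_gamma": 0.99,
    "adam_betas": (0.4, 0.95),
    "batch_repeat": 2,
    "max_epochs": 200,
    "optimizer_method": "adamw",
}
print(format_info(defaults))
```

A string like this can be logged with TensorBoard's `SummaryWriter.add_text(tag, text_string)`, which renders Markdown in the Text tab, e.g. `summary_writer.add_text("info", trainer.info())`.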
- _train_iteration(x: torch.Tensor, y: torch.Tensor, steps: int, optimizer: torch.optim.Optimizer, total_batch_iterations: int, summary_writer: torch.utils.tensorboard.SummaryWriter | None = None) → Tuple[ncalab.prediction.Prediction, Dict[str, torch.Tensor]]
Run a single training iteration.
- Parameters:
x – Input training images.
y – Input training labels.
steps – Number of NCA inference time steps.
optimizer – Optimizer.
total_batch_iterations (int) – Total number of training batch iterations.
summary_writer (SummaryWriter, optional) – TensorBoard SummaryWriter, defaults to None.
- Returns:
Tuple of the Prediction object and a dictionary mapping names to tensors.
- Return type:
Tuple[Prediction, Dict[str, torch.Tensor]]
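As a conceptual illustration of one distillation iteration, the toy below trains a scalar linear "student" to reproduce a fixed "teacher" function by minimizing the mean squared error between their outputs, with optional gradient clipping by global L2 norm (the same rule as `torch.nn.utils.clip_grad_norm_`). Everything here is an assumption for illustration: `teacher`, `distillation_step`, the linear student, the MSE loss and `clip_norm` are hypothetical, and whether the real method also uses the labels `y` is not stated in this reference.

```python
import math

def teacher(x: float) -> float:
    # Stand-in teacher: a fixed function the student should imitate
    return 2.0 * x + 1.0

def distillation_step(params, xs, lr, clip_norm=None):
    """One gradient step pulling the student w*x + b toward the teacher's outputs."""
    w, b = params
    # Difference between student and teacher outputs on the same inputs
    residuals = [(w * x + b) - teacher(x) for x in xs]
    n = len(xs)
    loss = sum(r * r for r in residuals) / n  # MSE, measured before the update
    # Analytic gradients of the MSE with respect to w and b
    grad_w = 2.0 / n * sum(r * x for r, x in zip(residuals, xs))
    grad_b = 2.0 / n * sum(residuals)
    if clip_norm is not None:
        norm = math.hypot(grad_w, grad_b)
        if norm > clip_norm:
            # Rescale the whole gradient vector so its L2 norm equals clip_norm
            scale = clip_norm / norm
            grad_w, grad_b = grad_w * scale, grad_b * scale
    return (w - lr * grad_w, b - lr * grad_b), loss

params = (0.0, 0.0)
xs = [-1.0, 0.0, 1.0, 2.0]
for _ in range(500):
    params, loss = distillation_step(params, xs, lr=0.1, clip_norm=1.0)
print(params)  # close to (2.0, 1.0), the teacher's slope and intercept
```

Clipping caps each early update at `lr * clip_norm` while the student is far from the teacher; once the gradient norm falls below the threshold, the steps become ordinary gradient descent.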
- train(dataloader_train: torch.utils.data.DataLoader, dataloader_val: torch.utils.data.DataLoader | None = None, dataloader_test: torch.utils.data.DataLoader | None = None, save_every: int | None = None, summary_writer: torch.utils.tensorboard.SummaryWriter | None = None, plot_function: ncalab.visualization.Visual | None = None, earlystopping: ncalab.training.earlystopping.EarlyStopping | None = None) → ncalab.training.traininghistory.TrainingHistory
Execute basic NCA training loop with a single function call.
- Parameters:
dataloader_train (DataLoader) – Training DataLoader.
dataloader_val (DataLoader, optional) – Validation DataLoader, defaults to None.
dataloader_test (DataLoader, optional) – Test DataLoader, defaults to None.
save_every (int, optional) – How often to save model state (in epochs). Useful for very small datasets, like growing lizard. Defaults to None.
summary_writer (SummaryWriter, optional) – TensorBoard SummaryWriter, defaults to None.
plot_function (Visual, optional) – Plot function override. If None, the model’s default is used. Defaults to None.
earlystopping (EarlyStopping, optional) – EarlyStopping object, defaults to None.
- Returns [TrainingHistory]:
TrainingHistory object.
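The interplay of `max_epochs`, `save_every` and early stopping can be sketched as a generic epoch loop. This is not the body of `train()`: `History`, `train_loop`, `train_epoch`, `validate` and the `patience` rule are hypothetical stand-ins for `TrainingHistory`, the per-epoch work and the `EarlyStopping` object, and checkpointing is reduced to recording the epoch number.

```python
from dataclasses import dataclass, field

@dataclass
class History:
    """Hypothetical stand-in for TrainingHistory."""
    train_losses: list = field(default_factory=list)
    val_losses: list = field(default_factory=list)
    saved_epochs: list = field(default_factory=list)
    stopped_early: bool = False

def train_loop(train_epoch, validate, max_epochs, save_every=None, patience=None):
    history = History()
    best = float("inf")
    bad_epochs = 0
    for epoch in range(max_epochs):
        history.train_losses.append(train_epoch(epoch))
        val_loss = validate(epoch)
        history.val_losses.append(val_loss)
        # Periodic checkpoint; a real trainer would write model state to model_path here
        if save_every is not None and (epoch + 1) % save_every == 0:
            history.saved_epochs.append(epoch)
        # Stop once validation loss has not improved for `patience` consecutive epochs
        if patience is not None:
            if val_loss < best:
                best, bad_epochs = val_loss, 0
            else:
                bad_epochs += 1
                if bad_epochs >= patience:
                    history.stopped_early = True
                    break
    return history

val_curve = [5.0, 4.0, 3.0, 3.5, 3.6, 3.7, 3.8, 3.9, 4.0, 4.1]
history = train_loop(
    train_epoch=lambda e: 10 - e,
    validate=lambda e: val_curve[e],
    max_epochs=10,
    save_every=2,
    patience=3,
)
print(history.saved_epochs, history.stopped_early)  # [1, 3, 5] True
```

In this sketch the periodic save runs before the early-stopping check, so the epoch that triggers stopping can still be checkpointed; the real trainer's ordering may differ.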