ncalab.search
=============

.. py:module:: ncalab.search


Submodules
----------

.. toctree::
   :maxdepth: 1

   /autoapi/ncalab/search/search/index


Classes
-------

.. autoapisummary::

   ncalab.search.BasicNCATrainer
   ncalab.search.ParameterSet
   ncalab.search.ParameterSearch


Package Contents
----------------

.. py:class:: BasicNCATrainer(nca: ncalab.models.basicNCA.BasicNCAModel, model_path: Optional[pathlib.Path | pathlib.PosixPath], gradient_clipping: bool = False, lr: Optional[float] = None, lr_gamma: float = 0.99, adam_betas=(0.9, 0.95), batch_repeat: int = 2, max_epochs: int = 200, optimizer_method: str = 'adam', pool: Optional[ncalab.training.pool.Pool] = None)

   Trainer class for any model subclassing BasicNCAModel.

   :param nca: NCA model instance to train.
   :type nca: ncalab.BasicNCAModel
   :param model_path: Path to saved models. If None, models are not saved.
   :type model_path: Path | PosixPath, optional
   :param gradient_clipping: Whether to clip gradients, defaults to False.
   :type gradient_clipping: bool, optional
   :param lr: Initial learning rate. Defaults to None in the signature; the docstring gives 16e-4 as the effective default.
   :type lr: float, optional
   :param lr_gamma: Exponential learning rate decay, defaults to 0.99.
   :type lr_gamma: float, optional
   :param adam_betas: Beta values for Adam optimizer, defaults to (0.9, 0.95).
   :type adam_betas: tuple, optional
   :param batch_repeat: How often each batch will be duplicated, defaults to 2.
   :type batch_repeat: int, optional
   :param max_epochs: Maximum number of epochs in training, defaults to 200.
   :type max_epochs: int, optional
   :param optimizer_method: Optimization method, defaults to 'adam'.
   :type optimizer_method: str, optional
   :param pool: Sample pool object, defaults to None.
   :type pool: ncalab.Pool, optional


   .. py:attribute:: nca


   .. py:attribute:: model_path


   .. py:attribute:: gradient_clipping
      :value: False


   .. py:attribute:: lr_gamma
      :value: 0.99


   .. py:attribute:: adam_betas
      :value: (0.9, 0.95)


   .. py:attribute:: batch_repeat
      :value: 2


   .. py:attribute:: max_epochs
      :value: 200


   .. py:attribute:: optimizer_method
      :value: 'adam'


   .. py:attribute:: pool
      :value: None


   .. py:method:: info() -> str

      Returns a Markdown-formatted info string with training parameters.
      Useful for displaying in TensorBoard to keep track of parameter changes.

      :returns: Markdown-formatted info string.
      :rtype: str


   .. py:method:: _train_iteration(x: torch.Tensor, y: torch.Tensor, optimizer: torch.optim.Optimizer, total_batch_iterations: int, summary_writer: Optional[torch.utils.tensorboard.SummaryWriter] = None) -> Tuple[ncalab.prediction.Prediction, Dict[str, torch.Tensor]]

      Run a single training iteration.

      :param x: Input training images.
      :type x: torch.Tensor
      :param y: Input training labels.
      :type y: torch.Tensor
      :param optimizer: Optimizer.
      :type optimizer: torch.optim.Optimizer
      :param total_batch_iterations: Total number of training batch iterations.
      :type total_batch_iterations: int
      :param summary_writer: TensorBoard SummaryWriter, defaults to None.
      :type summary_writer: SummaryWriter, optional

      :returns: The prediction together with a dictionary of named tensors.
      :rtype: Tuple[Prediction, Dict[str, torch.Tensor]]


   .. py:method:: train(dataloader_train: torch.utils.data.DataLoader, dataloader_val: Optional[torch.utils.data.DataLoader] = None, dataloader_test: Optional[torch.utils.data.DataLoader] = None, save_every: Optional[int] = None, summary_writer: Optional[torch.utils.tensorboard.SummaryWriter] = None, plot_function: Optional[ncalab.visualization.Visual] = None, earlystopping: Optional[ncalab.training.earlystopping.EarlyStopping] = None) -> ncalab.training.traininghistory.TrainingHistory

      Execute the basic NCA training loop with a single function call.

      :param dataloader_train: Training DataLoader.
      :type dataloader_train: DataLoader
      :param dataloader_val: Validation DataLoader. Defaults to None.
      :type dataloader_val: DataLoader, optional
      :param dataloader_test: Test DataLoader. Defaults to None.
      :type dataloader_test: DataLoader, optional
      :param save_every: How often to save model state (in epochs). Useful for very small datasets, like growing lizard. Defaults to None.
      :type save_every: int, optional
      :param summary_writer: TensorBoard SummaryWriter. Defaults to None.
      :type summary_writer: SummaryWriter, optional
      :param plot_function: Plot function override. If None, use the model's default. Defaults to None.
      :type plot_function: Visual, optional
      :param earlystopping: EarlyStopping object. Defaults to None.
      :type earlystopping: EarlyStopping, optional

      :returns: TrainingHistory object.
      :rtype: TrainingHistory


.. py:class:: ParameterSet(**kwargs)

   .. py:attribute:: params


   .. py:attribute:: mutable


   .. py:attribute:: combinations
      :value: []


   .. py:attribute:: index
      :value: 0


   .. py:method:: is_mutable(key)


   .. py:method:: info()


   .. py:method:: next() -> Dict[str, Any]


   .. py:method:: num_combinations()


   .. py:method:: __len__()


   .. py:method:: __next__()


   .. py:method:: __iter__()


.. py:class:: ParameterSearch(device, model_class, model_params: ParameterSet, trainer_params: ParameterSet)

   .. py:attribute:: device


   .. py:attribute:: model_class


   .. py:attribute:: model_params


   .. py:attribute:: trainer_params


   .. py:method:: info() -> str

      Generate an information string with a summary of the search to run.


   .. py:method:: search(dataloader_train: torch.utils.data.DataLoader, dataloader_val: Optional[torch.utils.data.DataLoader] = None)

      Run the search.

      :param dataloader_train: Training DataLoader.
      :type dataloader_train: DataLoader
      :param dataloader_val: Validation DataLoader. Defaults to None.
      :type dataloader_val: DataLoader, optional


   .. py:method:: __call__(*args, **kwargs)

      Shorthand for running the search.
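

This reference does not state how ``ParameterSet`` builds its ``combinations``. The following is a minimal sketch of one plausible reading, assuming that list-valued keyword arguments are the "mutable" parameters and that the search enumerates their Cartesian product while all other values stay fixed. ``expand_grid`` is a hypothetical helper written for illustration; it is not part of ncalab and may differ from the actual implementation.

```python
from itertools import product
from typing import Any, Dict, Iterator


def expand_grid(**kwargs: Any) -> Iterator[Dict[str, Any]]:
    """Yield one parameter dict per combination of list-valued kwargs."""
    # Assumption: a parameter counts as "mutable" when its value is a list.
    mutable = {k: v for k, v in kwargs.items() if isinstance(v, list)}
    fixed = {k: v for k, v in kwargs.items() if not isinstance(v, list)}
    keys = list(mutable)
    # product() over zero iterables yields a single empty tuple, so a set
    # without mutable parameters still produces exactly one combination.
    for values in product(*(mutable[k] for k in keys)):
        yield {**fixed, **dict(zip(keys, values))}


# Two learning rates times two batch_repeat values, max_epochs held fixed.
grid = list(expand_grid(lr=[1e-3, 1e-4], batch_repeat=[1, 2], max_epochs=200))
print(len(grid))
```

Under this reading, a search over a model parameter set and a trainer parameter set would run one training per pair of combinations, so the total number of runs grows multiplicatively with every list-valued parameter.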