
zamba.pytorch.finetuning

Classes

BackboneFinetuning (BackboneFinetuning)

Derived from PyTorch Lightning's built-in BackboneFinetuning callback, with one addition: during the backbone freeze phase, you can choose whether to freeze batch norm layers (via pre_train_bn), even if train_bn is True (i.e., even if they are trained during the backbone unfreeze phase).

Finetunes a backbone model based on a user-defined learning rate schedule. When the backbone learning rate reaches the current model learning rate and should_align is set to True, the backbone learning rate is aligned with it for the rest of training.

Parameters:

Name Type Description Default
unfreeze_backbone_at_epoch

Epoch at which the backbone will be unfrozen.

required
lambda_func

Scheduling function for increasing the backbone learning rate.

required
backbone_initial_ratio_lr

Used to scale down the backbone learning rate relative to the rest of the model.

required
backbone_initial_lr

Optional. Initial learning rate for the backbone. By default, current_learning_rate * backbone_initial_ratio_lr is used.

required
should_align

Whether to align with the current learning rate when the backbone learning rate reaches it.

required
initial_denom_lr

When unfreezing the backbone, the initial learning rate will be current_learning_rate / initial_denom_lr.

required
train_bn

Whether to make Batch Normalization layers trainable.

required
verbose

Whether to display the current learning rates for the model and backbone.

required
rounding

Precision (number of decimal places) used when displaying learning rates.

required

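Example (PyTorch Lightning's base class; with zamba's subclass, pass multiplier instead, or set multiplier=None when supplying a custom lambda_func):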

>>> from pytorch_lightning import Trainer
>>> from pytorch_lightning.callbacks import BackboneFinetuning
>>> multiplicative = lambda epoch: 1.5
>>> backbone_finetuning = BackboneFinetuning(200, multiplicative)
>>> trainer = Trainer(callbacks=[backbone_finetuning])

Attributes

state_key: str inherited property readonly

Identifier for the state of the callback.

Used to store and retrieve a callback's state from the checkpoint dictionary by checkpoint["callbacks"][state_key]. Implementations of a callback need to provide a unique state key if 1) the callback has state and 2) it is desired to maintain the state of multiple instances of that callback.

Methods

__init__(self, *args, multiplier: Optional[float] = 1, pre_train_bn: bool = False, **kwargs) special
Source code in zamba/pytorch/finetuning.py
def __init__(
    self, *args, multiplier: Optional[float] = 1, pre_train_bn: bool = False, **kwargs
):
    if multiplier is not None:
        kwargs["lambda_func"] = multiplier_factory(multiplier)
    super().__init__(*args, **kwargs)
    # choose whether to train batch norm layers prior to finetuning phase
    self.pre_train_bn = pre_train_bn
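
zamba's __init__ rewrites kwargs before calling the parent, which has a consequence worth seeing: any non-None multiplier (the default is 1) replaces lambda_func, and passing lambda_func positionally (as in the base-class example above) while multiplier is set makes the parent receive it twice. The sketch below reproduces that logic against a hypothetical two-parameter stand-in for PyTorch Lightning's class; HypotheticalBase and SketchBackboneFinetuning are illustrative names, and the real parent signature has more parameters.

```python
from typing import Callable, Optional


def multiplier_factory(rate: float) -> Callable[..., float]:
    # same behavior as zamba's helper: ignore all arguments, return `rate`
    def multiplier(*args, **kwargs):
        return rate

    return multiplier


class HypotheticalBase:
    # hypothetical stand-in for PTL's BackboneFinetuning (first two parameters only)
    def __init__(self, unfreeze_backbone_at_epoch: int = 10, lambda_func: Optional[Callable] = None):
        self.unfreeze_backbone_at_epoch = unfreeze_backbone_at_epoch
        self.lambda_func = lambda_func


class SketchBackboneFinetuning(HypotheticalBase):
    # same __init__ body as zamba's subclass
    def __init__(self, *args, multiplier: Optional[float] = 1, pre_train_bn: bool = False, **kwargs):
        if multiplier is not None:
            kwargs["lambda_func"] = multiplier_factory(multiplier)
        super().__init__(*args, **kwargs)
        self.pre_train_bn = pre_train_bn


default_cb = SketchBackboneFinetuning(5)
print(default_cb.lambda_func(7))  # 1: backbone lr stays constant after unfreezing

overridden = SketchBackboneFinetuning(5, lambda_func=lambda epoch: 1.5)
print(overridden.lambda_func(7))  # 1: the keyword lambda_func was silently replaced

custom = SketchBackboneFinetuning(5, lambda_func=lambda epoch: 1.5, multiplier=None)
print(custom.lambda_func(7))  # 1.5

try:
    SketchBackboneFinetuning(5, lambda epoch: 1.5)  # positional lambda_func + default multiplier
except TypeError:
    print("TypeError: lambda_func given twice")
```

In short: to use a custom schedule with zamba's class, pass multiplier=None; otherwise a constant multiplier is applied.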
filter_on_optimizer(optimizer: Optimizer, params: Iterable) -> List inherited

This function is used to exclude any parameter which already exists in this optimizer.

Parameters:

Name Type Description Default
optimizer Optimizer

Optimizer used for parameter exclusion

required
params Iterable

Iterable of parameters used to check against the provided optimizer

required

Returns:

Type Description
List

List of parameters not contained in this optimizer's param groups

Source code in zamba/pytorch/finetuning.py
@staticmethod
def filter_on_optimizer(optimizer: Optimizer, params: Iterable) -> List:
    """This function is used to exclude any parameter which already exists in this optimizer.

    Args:
        optimizer: Optimizer used for parameter exclusion
        params: Iterable of parameters used to check against the provided optimizer

    Returns:
        List of parameters not contained in this optimizer param groups
    """
    out_params = []
    removed_params = []
    for param in params:
        if not any(torch.equal(p, param) for group in optimizer.param_groups for p in group["params"]):
            out_params.append(param)
        else:
            removed_params.append(param)

    if removed_params:
        rank_zero_warn(
            "The provided params to be frozen already exist within another group of this optimizer."
            " Those parameters will be skipped.\n"
            "HINT: Did you init your optimizer in `configure_optimizer` as such:\n"
            f" {type(optimizer)}(filter(lambda p: p.requires_grad, self.parameters()), ...) ",
            UserWarning,
        )
    return out_params
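
filter_on_optimizer decides membership with torch.equal, which returns True when two tensors have the same size and elements; it does not check object identity. As a result, a parameter that is a different object but holds the same values as one already in the optimizer (for example, two identically initialized bias vectors) is also skipped. The pure-Python analogue below uses a hypothetical helper, filter_on_groups, with lists standing in for tensors, to contrast value equality with an identity check; it is an illustration, not a patch to the library.

```python
from typing import Callable, Dict, Iterable, List


def filter_on_groups(param_groups: List[Dict], params: Iterable, same: Callable[[object, object], bool]) -> List:
    """Keep only params for which no entry in any param group satisfies `same`."""
    return [
        param
        for param in params
        if not any(same(p, param) for group in param_groups for p in group["params"])
    ]


existing = [0.0, 0.0]  # e.g. a zero-initialized bias already in the optimizer
fresh = [0.0, 0.0]     # a *different* parameter with identical values
groups = [{"params": [existing]}]

by_value = filter_on_groups(groups, [fresh], lambda a, b: a == b)     # value comparison, like torch.equal
by_identity = filter_on_groups(groups, [fresh], lambda a, b: a is b)  # object identity
print(by_value)     # []
print(by_identity)  # [[0.0, 0.0]]
```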
filter_params(modules: Union[torch.nn.modules.module.Module, Iterable[Union[torch.nn.modules.module.Module, Iterable]]], train_bn: bool = True, requires_grad: bool = True) -> Generator inherited

Yields the requires_grad parameters of a given module or list of modules.

Parameters:

Name Type Description Default
modules Union[torch.nn.modules.module.Module, Iterable[Union[torch.nn.modules.module.Module, Iterable]]]

A given module or an iterable of modules

required
train_bn bool

Whether to train BatchNorm module

True
requires_grad bool

Whether to create a generator for trainable or non-trainable parameters.

True

Returns:

Type Description
Generator

Generator

Source code in zamba/pytorch/finetuning.py
@staticmethod
def filter_params(
    modules: Union[Module, Iterable[Union[Module, Iterable]]], train_bn: bool = True, requires_grad: bool = True
) -> Generator:
    """Yields the `requires_grad` parameters of a given module or list of modules.

    Args:
        modules: A given module or an iterable of modules
        train_bn: Whether to train BatchNorm module
        requires_grad: Whether to create a generator for trainable or non-trainable parameters.
    Returns:
        Generator
    """
    modules = BaseFinetuning.flatten_modules(modules)
    for mod in modules:
        if isinstance(mod, _BatchNorm) and not train_bn:
            continue
        # recursion could yield duplicate parameters for parent modules w/ parameters so disabling it
        for param in mod.parameters(recurse=False):
            if param.requires_grad == requires_grad:
                yield param
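
The filtering rules above (skip BatchNorm modules when train_bn is False, then yield only parameters whose requires_grad matches) can be seen with toy objects. In this sketch, ToyParam, ToyModule, ToyBatchNorm, and filter_params_sketch are hypothetical stand-ins for torch.nn.Parameter, torch.nn.Module, _BatchNorm, and the real method; the input list is assumed to be already flattened, as flatten_modules would do.

```python
from typing import Iterator, List


class ToyParam:
    """Hypothetical stand-in for torch.nn.Parameter: a name and a requires_grad flag."""

    def __init__(self, name: str, requires_grad: bool = True):
        self.name = name
        self.requires_grad = requires_grad


class ToyModule:
    def __init__(self, params: List[ToyParam]):
        self._params = params

    def parameters(self, recurse: bool = False) -> Iterator[ToyParam]:
        return iter(self._params)


class ToyBatchNorm(ToyModule):
    """Hypothetical stand-in for torch.nn.modules.batchnorm._BatchNorm."""


def filter_params_sketch(modules: List[ToyModule], train_bn: bool = True, requires_grad: bool = True):
    for mod in modules:
        if isinstance(mod, ToyBatchNorm) and not train_bn:
            continue  # BatchNorm parameters are left out entirely
        for param in mod.parameters(recurse=False):
            if param.requires_grad == requires_grad:
                yield param


layers = [
    ToyModule([ToyParam("conv.weight"), ToyParam("conv.bias", requires_grad=False)]),
    ToyBatchNorm([ToyParam("bn.weight"), ToyParam("bn.bias")]),
]
print([p.name for p in filter_params_sketch(layers, train_bn=False)])      # ['conv.weight']
print([p.name for p in filter_params_sketch(layers, train_bn=True)])       # ['conv.weight', 'bn.weight', 'bn.bias']
print([p.name for p in filter_params_sketch(layers, requires_grad=False)])  # ['conv.bias']
```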
finetune_function(self, pl_module: pl.LightningModule, epoch: int, optimizer: Optimizer, opt_idx: int) -> None inherited

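Adjusts the backbone learning rate at the start of each training epoch. At unfreeze_backbone_at_epoch, it unfreezes pl_module.backbone and adds it to the optimizer as a new param group; in later epochs, it multiplies the previous backbone learning rate by lambda_func(epoch + 1), capping it at the current learning rate when should_align is True.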

Source code in zamba/pytorch/finetuning.py
def finetune_function(
    self, pl_module: "pl.LightningModule", epoch: int, optimizer: Optimizer, opt_idx: int
) -> None:
    """Called when the epoch begins."""
    if epoch == self.unfreeze_backbone_at_epoch:
        current_lr = optimizer.param_groups[0]["lr"]
        initial_backbone_lr = (
            self.backbone_initial_lr
            if self.backbone_initial_lr is not None
            else current_lr * self.backbone_initial_ratio_lr
        )
        self.previous_backbone_lr = initial_backbone_lr
        self.unfreeze_and_add_param_group(
            pl_module.backbone,
            optimizer,
            initial_backbone_lr,
            train_bn=self.train_bn,
            initial_denom_lr=self.initial_denom_lr,
        )
        if self.verbose:
            log.info(
                f"Current lr: {round(current_lr, self.rounding)}, "
                f"Backbone lr: {round(initial_backbone_lr, self.rounding)}"
            )

    elif epoch > self.unfreeze_backbone_at_epoch:
        current_lr = optimizer.param_groups[0]["lr"]
        next_current_backbone_lr = self.lambda_func(epoch + 1) * self.previous_backbone_lr
        next_current_backbone_lr = (
            current_lr
            if (self.should_align and next_current_backbone_lr > current_lr)
            else next_current_backbone_lr
        )
        optimizer.param_groups[-1]["lr"] = next_current_backbone_lr
        self.previous_backbone_lr = next_current_backbone_lr
        if self.verbose:
            log.info(
                f"Current lr: {round(current_lr, self.rounding)}, "
                f"Backbone lr: {round(next_current_backbone_lr, self.rounding)}"
            )
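
The two branches of finetune_function produce a per-epoch backbone learning rate schedule. The pure-Python sketch below replays that arithmetic without an optimizer. It assumes the main learning rate (param_groups[0]["lr"]) stays fixed, whereas the real callback re-reads it every epoch, and the function name backbone_lr_schedule and the default ratio of 0.1 are illustrative assumptions.

```python
from typing import Callable, List, Optional


def backbone_lr_schedule(
    current_lr: float,
    unfreeze_at: int,
    n_epochs: int,
    lambda_func: Callable[[int], float],
    backbone_initial_ratio_lr: float = 0.1,
    backbone_initial_lr: Optional[float] = None,
    should_align: bool = True,
) -> List[Optional[float]]:
    """Backbone lr per epoch (None while the backbone is still frozen)."""
    lrs: List[Optional[float]] = []
    previous = None
    for epoch in range(n_epochs):
        if epoch < unfreeze_at:
            lrs.append(None)  # no backbone param group yet
            continue
        if epoch == unfreeze_at:
            # first branch: pick the initial backbone lr
            previous = (
                backbone_initial_lr
                if backbone_initial_lr is not None
                else current_lr * backbone_initial_ratio_lr
            )
        else:
            # second branch: scale the previous backbone lr
            nxt = lambda_func(epoch + 1) * previous
            if should_align and nxt > current_lr:
                nxt = current_lr  # align with (cap at) the main lr
            previous = nxt
        lrs.append(previous)
    return lrs


aligned = backbone_lr_schedule(0.5, unfreeze_at=2, n_epochs=7, lambda_func=lambda e: 2.0, backbone_initial_ratio_lr=0.125)
print(aligned)  # [None, None, 0.0625, 0.125, 0.25, 0.5, 0.5]

unaligned = backbone_lr_schedule(
    0.5, unfreeze_at=2, n_epochs=7, lambda_func=lambda e: 2.0, backbone_initial_ratio_lr=0.125, should_align=False
)
print(unaligned)  # [None, None, 0.0625, 0.125, 0.25, 0.5, 1.0]
```

With zamba's default multiplier of 1, lambda_func always returns 1, so the backbone learning rate stays at its initial value after unfreezing.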
flatten_modules(modules: Union[torch.nn.modules.module.Module, Iterable[Union[torch.nn.modules.module.Module, Iterable]]]) -> List[torch.nn.modules.module.Module] inherited

This function is used to flatten a module or an iterable of modules into a list of its leaf modules (modules with no children) and parent modules that have parameters directly themselves.

Parameters:

Name Type Description Default
modules Union[torch.nn.modules.module.Module, Iterable[Union[torch.nn.modules.module.Module, Iterable]]]

A given module or an iterable of modules

required

Returns:

Type Description
List[torch.nn.modules.module.Module]

List of modules

Source code in zamba/pytorch/finetuning.py
@staticmethod
def flatten_modules(modules: Union[Module, Iterable[Union[Module, Iterable]]]) -> List[Module]:
    """This function is used to flatten a module or an iterable of modules into a list of its leaf modules
    (modules with no children) and parent modules that have parameters directly themselves.

    Args:
        modules: A given module or an iterable of modules

    Returns:
        List of modules
    """
    if isinstance(modules, ModuleDict):
        modules = modules.values()

    if isinstance(modules, Iterable):
        _modules = []
        for m in modules:
            _modules.extend(BaseFinetuning.flatten_modules(m))

    else:
        _modules = modules.modules()

    # Capture all leaf modules as well as parent modules that have parameters directly themselves
    return [m for m in _modules if not list(m.children()) or m._parameters]
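
The flattening rule (keep leaf modules, plus parent modules that own parameters directly) is easiest to see on a small tree. This sketch uses a hypothetical ToyModule class that mimics only children(), modules(), and _parameters from torch.nn.Module; ModuleDict handling is omitted.

```python
from typing import Dict, Iterable, Iterator, List, Optional, Union


class ToyModule:
    """Hypothetical stand-in for torch.nn.Module: named children plus directly owned parameters."""

    def __init__(self, name: str, children: Iterable["ToyModule"] = (), own_params: Optional[Dict[str, float]] = None):
        self.name = name
        self._children = list(children)
        self._parameters = dict(own_params or {})

    def children(self) -> Iterator["ToyModule"]:
        return iter(self._children)

    def modules(self) -> Iterator["ToyModule"]:
        # pre-order traversal that includes the module itself, like nn.Module.modules()
        yield self
        for child in self._children:
            yield from child.modules()


def flatten_modules_sketch(modules: Union[ToyModule, Iterable]) -> List[ToyModule]:
    if isinstance(modules, ToyModule):
        candidates = list(modules.modules())
    else:
        candidates = []
        for m in modules:
            candidates.extend(flatten_modules_sketch(m))
    # keep leaf modules, plus parents that hold parameters directly
    return [m for m in candidates if not list(m.children()) or m._parameters]


fc = ToyModule("fc", own_params={"weight": 0.0})
head = ToyModule("head", children=[fc], own_params={"scale": 1.0})
backbone = ToyModule("backbone", children=[ToyModule("conv"), ToyModule("bn")])
root = ToyModule("root", children=[backbone, head])

print([m.name for m in flatten_modules_sketch(root)])              # ['conv', 'bn', 'head', 'fc']
print([m.name for m in flatten_modules_sketch([backbone, head])])  # ['conv', 'bn', 'head', 'fc']
```

"root" and "backbone" are dropped because they have children but no parameters of their own, while "head" survives because it owns "scale" directly.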
freeze(modules: Union[torch.nn.modules.module.Module, Iterable[Union[torch.nn.modules.module.Module, Iterable]]], train_bn: bool = True) -> None inherited

Freezes the parameters of the provided modules.

Parameters:

Name Type Description Default
modules Union[torch.nn.modules.module.Module, Iterable[Union[torch.nn.modules.module.Module, Iterable]]]

A given module or an iterable of modules

required
train_bn bool

If True, BatchNorm layers are made trainable instead of frozen

True

Returns:

Type Description
None

None

Source code in zamba/pytorch/finetuning.py
@staticmethod
def freeze(modules: Union[Module, Iterable[Union[Module, Iterable]]], train_bn: bool = True) -> None:
    """Freezes the parameters of the provided modules.

    Args:
        modules: A given module or an iterable of modules
        train_bn: If True, leave the BatchNorm layers in training mode

    Returns:
        None
    """
    modules = BaseFinetuning.flatten_modules(modules)
    for mod in modules:
        if isinstance(mod, _BatchNorm) and train_bn:
            BaseFinetuning.make_trainable(mod)
        else:
            # recursion could yield duplicate parameters for parent modules w/ parameters so disabling it
            for param in mod.parameters(recurse=False):
                param.requires_grad = False
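
The effect of the train_bn flag on freeze can be shown with toy layers. zamba's freeze_before_training calls freeze(pl_module.backbone, train_bn=self.pre_train_bn), so with the default pre_train_bn=False the backbone's BatchNorm parameters are frozen as well. ToyParam, ToyLayer, ToyBatchNorm, freeze_sketch, and trainable_flags are hypothetical names; the input is assumed to be an already-flattened list of modules.

```python
from typing import Iterator, List


class ToyParam:
    def __init__(self, requires_grad: bool = True):
        self.requires_grad = requires_grad


class ToyLayer:
    """Hypothetical stand-in for a leaf torch.nn.Module."""

    def __init__(self, n_params: int = 2):
        self.params = [ToyParam() for _ in range(n_params)]

    def parameters(self, recurse: bool = False) -> Iterator[ToyParam]:
        return iter(self.params)


class ToyBatchNorm(ToyLayer):
    """Hypothetical stand-in for _BatchNorm."""


def freeze_sketch(modules: List[ToyLayer], train_bn: bool = True) -> None:
    for mod in modules:
        if isinstance(mod, ToyBatchNorm) and train_bn:
            for p in mod.parameters(recurse=False):
                p.requires_grad = True  # what make_trainable does
        else:
            for p in mod.parameters(recurse=False):
                p.requires_grad = False


def trainable_flags(modules: List[ToyLayer]) -> List[List[bool]]:
    return [[p.requires_grad for p in m.parameters()] for m in modules]


backbone = [ToyLayer(), ToyBatchNorm()]
freeze_sketch(backbone, train_bn=False)  # like pre_train_bn=False
print(trainable_flags(backbone))  # [[False, False], [False, False]]
freeze_sketch(backbone, train_bn=True)   # like pre_train_bn=True
print(trainable_flags(backbone))  # [[False, False], [True, True]]
```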
freeze_before_training(self, pl_module: pl.LightningModule)

Freezes pl_module.backbone before training. BatchNorm layers in the backbone remain trainable only if pre_train_bn is True.

Source code in zamba/pytorch/finetuning.py
def freeze_before_training(self, pl_module: "pl.LightningModule"):
    self.freeze(pl_module.backbone, train_bn=self.pre_train_bn)
make_trainable(modules: Union[torch.nn.modules.module.Module, Iterable[Union[torch.nn.modules.module.Module, Iterable]]]) -> None inherited

Unfreezes the parameters of the provided modules.

Parameters:

Name Type Description Default
modules Union[torch.nn.modules.module.Module, Iterable[Union[torch.nn.modules.module.Module, Iterable]]]

A given module or an iterable of modules

required
Source code in zamba/pytorch/finetuning.py
@staticmethod
def make_trainable(modules: Union[Module, Iterable[Union[Module, Iterable]]]) -> None:
    """Unfreezes the parameters of the provided modules.

    Args:
        modules: A given module or an iterable of modules
    """
    modules = BaseFinetuning.flatten_modules(modules)
    for module in modules:
        # recursion could yield duplicate parameters for parent modules w/ parameters so disabling it
        for param in module.parameters(recurse=False):
            param.requires_grad = True
on_after_backward(self, trainer: pl.Trainer, pl_module: pl.LightningModule) -> None inherited

Called after loss.backward() and before optimizers are stepped.

Source code in zamba/pytorch/finetuning.py
def on_after_backward(self, trainer: "pl.Trainer", pl_module: "pl.LightningModule") -> None:
    """Called after ``loss.backward()`` and before optimizers are stepped."""
    pass
on_batch_end(self, trainer: pl.Trainer, pl_module: pl.LightningModule) -> None inherited

Called when the training batch ends.

Source code in zamba/pytorch/finetuning.py
def on_batch_end(self, trainer: "pl.Trainer", pl_module: "pl.LightningModule") -> None:
    """Called when the training batch ends."""
    pass
on_batch_start(self, trainer: pl.Trainer, pl_module: pl.LightningModule) -> None inherited

Called when the training batch begins.

Source code in zamba/pytorch/finetuning.py
def on_batch_start(self, trainer: "pl.Trainer", pl_module: "pl.LightningModule") -> None:
    """Called when the training batch begins."""
    pass
on_before_accelerator_backend_setup(self, trainer: pl.Trainer, pl_module: pl.LightningModule) inherited

Called before the accelerator is set up. This callback uses it to freeze the backbone via freeze_before_training.

Source code in zamba/pytorch/finetuning.py
def on_before_accelerator_backend_setup(self, trainer: "pl.Trainer", pl_module: "pl.LightningModule"):
    self.freeze_before_training(pl_module)
on_before_backward(self, trainer: pl.Trainer, pl_module: pl.LightningModule, loss: Tensor) -> None inherited

Called before loss.backward().

Source code in zamba/pytorch/finetuning.py
def on_before_backward(self, trainer: "pl.Trainer", pl_module: "pl.LightningModule", loss: torch.Tensor) -> None:
    """Called before ``loss.backward()``."""
    pass
on_before_optimizer_step(self, trainer: pl.Trainer, pl_module: pl.LightningModule, optimizer: Optimizer, opt_idx: int) -> None inherited

Called before optimizer.step().

Source code in zamba/pytorch/finetuning.py
def on_before_optimizer_step(
    self, trainer: "pl.Trainer", pl_module: "pl.LightningModule", optimizer: Optimizer, opt_idx: int
) -> None:
    """Called before ``optimizer.step()``."""
    pass
on_before_zero_grad(self, trainer: pl.Trainer, pl_module: pl.LightningModule, optimizer: Optimizer) -> None inherited

Called before optimizer.zero_grad().

Source code in zamba/pytorch/finetuning.py
def on_before_zero_grad(self, trainer: "pl.Trainer", pl_module: "pl.LightningModule", optimizer: Optimizer) -> None:
    """Called before ``optimizer.zero_grad()``."""
    pass
on_configure_sharded_model(self, trainer: pl.Trainer, pl_module: pl.LightningModule) -> None inherited

Called before the sharded model is configured.

Source code in zamba/pytorch/finetuning.py
def on_configure_sharded_model(self, trainer: "pl.Trainer", pl_module: "pl.LightningModule") -> None:
    """Called before configure sharded model."""
on_epoch_end(self, trainer: pl.Trainer, pl_module: pl.LightningModule) -> None inherited

Called when a train, validation, or test epoch ends.

Source code in zamba/pytorch/finetuning.py
def on_epoch_end(self, trainer: "pl.Trainer", pl_module: "pl.LightningModule") -> None:
    """Called when either of train/val/test epoch ends."""
    pass
on_epoch_start(self, trainer: pl.Trainer, pl_module: pl.LightningModule) -> None inherited

Called when a train, validation, or test epoch begins.

Source code in zamba/pytorch/finetuning.py
def on_epoch_start(self, trainer: "pl.Trainer", pl_module: "pl.LightningModule") -> None:
    """Called when either of train/val/test epoch begins."""
    pass
on_exception(self, trainer: pl.Trainer, pl_module: pl.LightningModule, exception: BaseException) -> None inherited

Called when any trainer execution is interrupted by an exception.

Source code in zamba/pytorch/finetuning.py
def on_exception(self, trainer: "pl.Trainer", pl_module: "pl.LightningModule", exception: BaseException) -> None:
    """Called when any trainer execution is interrupted by an exception."""
    pass
on_fit_end(self, trainer: pl.Trainer, pl_module: pl.LightningModule) -> None inherited

Called when fit ends.

Source code in zamba/pytorch/finetuning.py
def on_fit_end(self, trainer: "pl.Trainer", pl_module: "pl.LightningModule") -> None:
    """Called when fit ends."""
    pass
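on_fit_start(self, trainer: pl.Trainer, pl_module: pl.LightningModule) -> None inherited

Checks that the LightningModule has a backbone attribute that is an nn.Module and raises MisconfigurationException otherwise.
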
Source code in zamba/pytorch/finetuning.py
def on_fit_start(self, trainer: "pl.Trainer", pl_module: "pl.LightningModule") -> None:
    """
    Raises:
        MisconfigurationException:
            If LightningModule has no nn.Module `backbone` attribute.
    """
    if hasattr(pl_module, "backbone") and isinstance(pl_module.backbone, Module):
        return super().on_fit_start(trainer, pl_module)
    raise MisconfigurationException("The LightningModule should have a nn.Module `backbone` attribute")
on_init_end(self, trainer: pl.Trainer) -> None inherited

Called when trainer initialization ends; the model has not yet been set.

Source code in zamba/pytorch/finetuning.py
def on_init_end(self, trainer: "pl.Trainer") -> None:
    """Called when the trainer initialization ends, model has not yet been set."""
    pass
on_init_start(self, trainer: pl.Trainer) -> None inherited

Called when trainer initialization begins; the model has not yet been set.

Source code in zamba/pytorch/finetuning.py
def on_init_start(self, trainer: "pl.Trainer") -> None:
    """Called when the trainer initialization begins, model has not yet been set."""
    pass
on_keyboard_interrupt(self, trainer: pl.Trainer, pl_module: pl.LightningModule) -> None inherited

Deprecated since v1.5: this callback hook was deprecated in v1.5 in favor of on_exception and will be removed in v1.7.

Called when any trainer execution is interrupted by KeyboardInterrupt.

Source code in zamba/pytorch/finetuning.py
def on_keyboard_interrupt(self, trainer: "pl.Trainer", pl_module: "pl.LightningModule") -> None:
    r"""
    .. deprecated:: v1.5
        This callback hook was deprecated in v1.5 in favor of `on_exception` and will be removed in v1.7.

    Called when any trainer execution is interrupted by KeyboardInterrupt.
    """
    pass
on_load_checkpoint(self, trainer: pl.Trainer, pl_module: pl.LightningModule, callback_state: Dict[int, List[Dict[str, Any]]]) -> None inherited

Called when loading a model checkpoint; use it to reload state.

Parameters:

Name Type Description Default
trainer pl.Trainer

The current pytorch_lightning.trainer.Trainer instance.

required
pl_module pl.LightningModule

The current pytorch_lightning.core.lightning.LightningModule instance.

required
callback_state Dict[int, List[Dict[str, Any]]]

the callback state returned by on_save_checkpoint.

required

Note: on_load_checkpoint won't be called with an undefined state. If your on_load_checkpoint hook doesn't rely on a state, you still need to override on_save_checkpoint to return a dummy state.

Source code in zamba/pytorch/finetuning.py
def on_load_checkpoint(
    self, trainer: "pl.Trainer", pl_module: "pl.LightningModule", callback_state: Dict[int, List[Dict[str, Any]]]
) -> None:
    self.previous_backbone_lr = callback_state["previous_backbone_lr"]
    super().on_load_checkpoint(trainer, pl_module, callback_state["internal_optimizer_metadata"])
on_predict_batch_end(self, trainer: pl.Trainer, pl_module: pl.LightningModule, outputs: Any, batch: Any, batch_idx: int, dataloader_idx: int) -> None inherited

Called when the predict batch ends.

Source code in zamba/pytorch/finetuning.py
def on_predict_batch_end(
    self,
    trainer: "pl.Trainer",
    pl_module: "pl.LightningModule",
    outputs: Any,
    batch: Any,
    batch_idx: int,
    dataloader_idx: int,
) -> None:
    """Called when the predict batch ends."""
    pass
on_predict_batch_start(self, trainer: pl.Trainer, pl_module: pl.LightningModule, batch: Any, batch_idx: int, dataloader_idx: int) -> None inherited

Called when the predict batch begins.

Source code in zamba/pytorch/finetuning.py
def on_predict_batch_start(
    self, trainer: "pl.Trainer", pl_module: "pl.LightningModule", batch: Any, batch_idx: int, dataloader_idx: int
) -> None:
    """Called when the predict batch begins."""
    pass
on_predict_end(self, trainer: pl.Trainer, pl_module: pl.LightningModule) -> None inherited

Called when prediction ends.

Source code in zamba/pytorch/finetuning.py
def on_predict_end(self, trainer: "pl.Trainer", pl_module: "pl.LightningModule") -> None:
    """Called when predict ends."""
    pass
on_predict_epoch_end(self, trainer: pl.Trainer, pl_module: pl.LightningModule, outputs: List[Any]) -> None inherited

Called when the predict epoch ends.

Source code in zamba/pytorch/finetuning.py
def on_predict_epoch_end(self, trainer: "pl.Trainer", pl_module: "pl.LightningModule", outputs: List[Any]) -> None:
    """Called when the predict epoch ends."""
    pass
on_predict_epoch_start(self, trainer: pl.Trainer, pl_module: pl.LightningModule) -> None inherited

Called when the predict epoch begins.

Source code in zamba/pytorch/finetuning.py
def on_predict_epoch_start(self, trainer: "pl.Trainer", pl_module: "pl.LightningModule") -> None:
    """Called when the predict epoch begins."""
    pass
on_predict_start(self, trainer: pl.Trainer, pl_module: pl.LightningModule) -> None inherited

Called when prediction begins.

Source code in zamba/pytorch/finetuning.py
def on_predict_start(self, trainer: "pl.Trainer", pl_module: "pl.LightningModule") -> None:
    """Called when the predict begins."""
    pass
on_pretrain_routine_end(self, trainer: pl.Trainer, pl_module: pl.LightningModule) -> None inherited

Called when the pretrain routine ends.

Source code in zamba/pytorch/finetuning.py
def on_pretrain_routine_end(self, trainer: "pl.Trainer", pl_module: "pl.LightningModule") -> None:
    """Called when the pretrain routine ends."""
    pass
on_pretrain_routine_start(self, trainer: pl.Trainer, pl_module: pl.LightningModule) -> None inherited

Called when the pretrain routine begins.

Source code in zamba/pytorch/finetuning.py
def on_pretrain_routine_start(self, trainer: "pl.Trainer", pl_module: "pl.LightningModule") -> None:
    """Called when the pretrain routine begins."""
    pass
on_sanity_check_end(self, trainer: pl.Trainer, pl_module: pl.LightningModule) -> None inherited

Called when the validation sanity check ends.

Source code in zamba/pytorch/finetuning.py
def on_sanity_check_end(self, trainer: "pl.Trainer", pl_module: "pl.LightningModule") -> None:
    """Called when the validation sanity check ends."""
    pass
on_sanity_check_start(self, trainer: pl.Trainer, pl_module: pl.LightningModule) -> None inherited

Called when the validation sanity check starts.

Source code in zamba/pytorch/finetuning.py
def on_sanity_check_start(self, trainer: "pl.Trainer", pl_module: "pl.LightningModule") -> None:
    """Called when the validation sanity check starts."""
    pass
on_save_checkpoint(self, trainer: pl.Trainer, pl_module: pl.LightningModule, checkpoint: Dict[str, Any]) -> Dict[str, Any] inherited

Called when saving a model checkpoint, use to persist state.

Parameters:

Name Type Description Default
trainer pl.Trainer

The current pytorch_lightning.trainer.Trainer instance.

required
pl_module pl.LightningModule

The current pytorch_lightning.core.lightning.LightningModule instance.

required
checkpoint Dict[str, Any]

the checkpoint dictionary that will be saved.

required

Returns:

Type Description
Dict[str, Any]

The callback state.

Source code in zamba/pytorch/finetuning.py
def on_save_checkpoint(
    self, trainer: "pl.Trainer", pl_module: "pl.LightningModule", checkpoint: Dict[str, Any]
) -> Dict[str, Any]:
    return {
        "internal_optimizer_metadata": self._internal_optimizer_metadata,
        "previous_backbone_lr": self.previous_backbone_lr,
    }
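
zamba extends the base checkpoint state with previous_backbone_lr, which finetune_function needs in every epoch after the unfreeze epoch, and on load passes only the internal_optimizer_metadata entry on to the parent's on_load_checkpoint. The round trip below uses a hypothetical class, CheckpointStateSketch, with simplified hook signatures (no trainer or pl_module); the metadata contents are made-up placeholder values.

```python
from typing import Any, Dict, Optional


class CheckpointStateSketch:
    """Hypothetical minimal class mirroring the save/load hooks of zamba's callback."""

    def __init__(self):
        self.previous_backbone_lr: Optional[float] = None
        self._internal_optimizer_metadata: Dict[int, list] = {}
        self.restored_metadata = None  # what would be handed to the parent hook

    def on_save_checkpoint(self) -> Dict[str, Any]:
        return {
            "internal_optimizer_metadata": self._internal_optimizer_metadata,
            "previous_backbone_lr": self.previous_backbone_lr,
        }

    def on_load_checkpoint(self, callback_state: Dict[str, Any]) -> None:
        self.previous_backbone_lr = callback_state["previous_backbone_lr"]
        # zamba forwards only this sub-dict to BaseFinetuning.on_load_checkpoint
        self.restored_metadata = callback_state["internal_optimizer_metadata"]


before = CheckpointStateSketch()
before.previous_backbone_lr = 2.5e-4
before._internal_optimizer_metadata = {0: [{"lr": 1e-3}, {"lr": 2.5e-4}]}
state = before.on_save_checkpoint()

after = CheckpointStateSketch()
after.on_load_checkpoint(state)
print(after.previous_backbone_lr, after.restored_metadata)
```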
on_test_batch_end(self, trainer: pl.Trainer, pl_module: pl.LightningModule, outputs: Union[torch.Tensor, Dict[str, Any]], batch: Any, batch_idx: int, dataloader_idx: int) -> None inherited

Called when the test batch ends.

Source code in zamba/pytorch/finetuning.py
def on_test_batch_end(
    self,
    trainer: "pl.Trainer",
    pl_module: "pl.LightningModule",
    outputs: Optional[STEP_OUTPUT],
    batch: Any,
    batch_idx: int,
    dataloader_idx: int,
) -> None:
    """Called when the test batch ends."""
    pass
on_test_batch_start(self, trainer: pl.Trainer, pl_module: pl.LightningModule, batch: Any, batch_idx: int, dataloader_idx: int) -> None inherited

Called when the test batch begins.

Source code in zamba/pytorch/finetuning.py
def on_test_batch_start(
    self, trainer: "pl.Trainer", pl_module: "pl.LightningModule", batch: Any, batch_idx: int, dataloader_idx: int
) -> None:
    """Called when the test batch begins."""
    pass
on_test_end(self, trainer: pl.Trainer, pl_module: pl.LightningModule) -> None inherited

Called when the test ends.

Source code in zamba/pytorch/finetuning.py
def on_test_end(self, trainer: "pl.Trainer", pl_module: "pl.LightningModule") -> None:
    """Called when the test ends."""
    pass
on_test_epoch_end(self, trainer: pl.Trainer, pl_module: pl.LightningModule) -> None inherited

Called when the test epoch ends.

Source code in zamba/pytorch/finetuning.py
def on_test_epoch_end(self, trainer: "pl.Trainer", pl_module: "pl.LightningModule") -> None:
    """Called when the test epoch ends."""
    pass
on_test_epoch_start(self, trainer: pl.Trainer, pl_module: pl.LightningModule) -> None inherited

Called when the test epoch begins.

Source code in zamba/pytorch/finetuning.py
def on_test_epoch_start(self, trainer: "pl.Trainer", pl_module: "pl.LightningModule") -> None:
    """Called when the test epoch begins."""
    pass
on_test_start(self, trainer: pl.Trainer, pl_module: pl.LightningModule) -> None inherited

Called when the test begins.

Source code in zamba/pytorch/finetuning.py
def on_test_start(self, trainer: "pl.Trainer", pl_module: "pl.LightningModule") -> None:
    """Called when the test begins."""
    pass
on_train_batch_end(self, trainer: pl.Trainer, pl_module: pl.LightningModule, outputs: Union[torch.Tensor, Dict[str, Any]], batch: Any, batch_idx: int, unused: Optional[int] = 0) -> None inherited

Called when the train batch ends.

Source code in zamba/pytorch/finetuning.py
def on_train_batch_end(
    self,
    trainer: "pl.Trainer",
    pl_module: "pl.LightningModule",
    outputs: STEP_OUTPUT,
    batch: Any,
    batch_idx: int,
    unused: Optional[int] = 0,
) -> None:
    """Called when the train batch ends."""
    pass
on_train_batch_start(self, trainer: pl.Trainer, pl_module: pl.LightningModule, batch: Any, batch_idx: int, unused: Optional[int] = 0) -> None inherited

Called when the train batch begins.

Source code in zamba/pytorch/finetuning.py
def on_train_batch_start(
    self,
    trainer: "pl.Trainer",
    pl_module: "pl.LightningModule",
    batch: Any,
    batch_idx: int,
    unused: Optional[int] = 0,
) -> None:
    """Called when the train batch begins."""
    pass
on_train_end(self, trainer: pl.Trainer, pl_module: pl.LightningModule) -> None inherited

Called when the train ends.

Source code in zamba/pytorch/finetuning.py
def on_train_end(self, trainer: "pl.Trainer", pl_module: "pl.LightningModule") -> None:
    """Called when the train ends."""
    pass
on_train_epoch_end(self, trainer: pl.Trainer, pl_module: pl.LightningModule) -> None inherited

Called when the train epoch ends.

To access all batch outputs at the end of the epoch, either:

  1. Implement training_epoch_end in the LightningModule and access outputs via the module OR
  2. Cache data across train batch hooks inside the callback implementation to post-process in this hook.
Source code in zamba/pytorch/finetuning.py
def on_train_epoch_end(self, trainer: "pl.Trainer", pl_module: "pl.LightningModule") -> None:
    """Called when the train epoch ends.

    To access all batch outputs at the end of the epoch, either:

    1. Implement `training_epoch_end` in the `LightningModule` and access outputs via the module OR
    2. Cache data across train batch hooks inside the callback implementation to post-process in this hook.
    """
    pass
on_train_epoch_start(self, trainer: pl.Trainer, pl_module: pl.LightningModule) -> None inherited

Called when the train epoch begins. Runs finetune_function for each active optimizer, then stores the resulting param groups.

Source code in zamba/pytorch/finetuning.py
def on_train_epoch_start(self, trainer: "pl.Trainer", pl_module: "pl.LightningModule") -> None:
    """Called when the epoch begins."""
    # import is here to avoid circular imports
    from pytorch_lightning.loops.utilities import _get_active_optimizers

    for opt_idx, optimizer in _get_active_optimizers(trainer.optimizers, trainer.optimizer_frequencies):
        num_param_groups = len(optimizer.param_groups)
        self.finetune_function(pl_module, trainer.current_epoch, optimizer, opt_idx)
        current_param_groups = optimizer.param_groups
        self._store(pl_module, opt_idx, num_param_groups, current_param_groups)
on_train_start(self, trainer: pl.Trainer, pl_module: pl.LightningModule) -> None inherited

Called when the train begins.

Source code in zamba/pytorch/finetuning.py
def on_train_start(self, trainer: "pl.Trainer", pl_module: "pl.LightningModule") -> None:
    """Called when the train begins."""
    pass
on_validation_batch_end(self, trainer: pl.Trainer, pl_module: pl.LightningModule, outputs: Union[torch.Tensor, Dict[str, Any]], batch: Any, batch_idx: int, dataloader_idx: int) -> None inherited

Called when the validation batch ends.

Source code in zamba/pytorch/finetuning.py
def on_validation_batch_end(
    self,
    trainer: "pl.Trainer",
    pl_module: "pl.LightningModule",
    outputs: Optional[STEP_OUTPUT],
    batch: Any,
    batch_idx: int,
    dataloader_idx: int,
) -> None:
    """Called when the validation batch ends."""
    pass
on_validation_batch_start(self, trainer: pl.Trainer, pl_module: pl.LightningModule, batch: Any, batch_idx: int, dataloader_idx: int) -> None inherited

Called when the validation batch begins.

Source code in zamba/pytorch/finetuning.py
def on_validation_batch_start(
    self, trainer: "pl.Trainer", pl_module: "pl.LightningModule", batch: Any, batch_idx: int, dataloader_idx: int
) -> None:
    """Called when the validation batch begins."""
    pass
on_validation_end(self, trainer: pl.Trainer, pl_module: pl.LightningModule) -> None inherited

Called when the validation loop ends.

Source code in zamba/pytorch/finetuning.py
def on_validation_end(self, trainer: "pl.Trainer", pl_module: "pl.LightningModule") -> None:
    """Called when the validation loop ends."""
    pass
on_validation_epoch_end(self, trainer: pl.Trainer, pl_module: pl.LightningModule) -> None inherited

Called when the val epoch ends.

Source code in zamba/pytorch/finetuning.py
def on_validation_epoch_end(self, trainer: "pl.Trainer", pl_module: "pl.LightningModule") -> None:
    """Called when the val epoch ends."""
    pass
on_validation_epoch_start(self, trainer: pl.Trainer, pl_module: pl.LightningModule) -> None inherited

Called when the val epoch begins.

Source code in zamba/pytorch/finetuning.py
def on_validation_epoch_start(self, trainer: "pl.Trainer", pl_module: "pl.LightningModule") -> None:
    """Called when the val epoch begins."""
    pass
on_validation_start(self, trainer: pl.Trainer, pl_module: pl.LightningModule) -> None inherited

Called when the validation loop begins.

Source code in zamba/pytorch/finetuning.py
def on_validation_start(self, trainer: "pl.Trainer", pl_module: "pl.LightningModule") -> None:
    """Called when the validation loop begins."""
    pass
setup(self, trainer: pl.Trainer, pl_module: pl.LightningModule, stage: Optional[str] = None) -> None inherited

Called when fit, validate, test, predict, or tune begins.

Source code in zamba/pytorch/finetuning.py
def setup(self, trainer: "pl.Trainer", pl_module: "pl.LightningModule", stage: Optional[str] = None) -> None:
    """Called when fit, validate, test, predict, or tune begins."""
    pass
teardown(self, trainer: pl.Trainer, pl_module: pl.LightningModule, stage: Optional[str] = None) -> None inherited

Called when fit, validate, test, predict, or tune ends.

Source code in zamba/pytorch/finetuning.py
def teardown(self, trainer: "pl.Trainer", pl_module: "pl.LightningModule", stage: Optional[str] = None) -> None:
    """Called when fit, validate, test, predict, or tune ends."""
    pass
unfreeze_and_add_param_group(modules: Union[torch.nn.modules.module.Module, Iterable[Union[torch.nn.modules.module.Module, Iterable]]], optimizer: Optimizer, lr: Optional[float] = None, initial_denom_lr: float = 10.0, train_bn: bool = True) -> None inherited

Unfreezes a module and adds its parameters to an optimizer.

Parameters:

Name Type Description Default
modules Union[torch.nn.modules.module.Module, Iterable[Union[torch.nn.modules.module.Module, Iterable]]]

A module or iterable of modules to unfreeze. Their parameters will be added to an optimizer as a new param group.

required
optimizer Optimizer

The optimizer that will receive the new parameters (added via add_param_group).

required
lr Optional[float]

Learning rate for the new param group.

None
initial_denom_lr float

If no lr is provided, the learning rate of the first param group is divided by initial_denom_lr and used instead.

10.0
train_bn bool

Whether to train the BatchNormalization layers.

True
Source code in zamba/pytorch/finetuning.py
@staticmethod
def unfreeze_and_add_param_group(
    modules: Union[Module, Iterable[Union[Module, Iterable]]],
    optimizer: Optimizer,
    lr: Optional[float] = None,
    initial_denom_lr: float = 10.0,
    train_bn: bool = True,
) -> None:
    """Unfreezes a module and adds its parameters to an optimizer.

    Args:
        modules: A module or iterable of modules to unfreeze.
            Their parameters will be added to an optimizer as a new param group.
        optimizer: The provided optimizer will receive new parameters and will add them to
            `add_param_group`
        lr: Learning rate for the new param group.
        initial_denom_lr: If no lr is provided, the learning from the first param group will be used
            and divided by `initial_denom_lr`.
        train_bn: Whether to train the BatchNormalization layers.
    """
    BaseFinetuning.make_trainable(modules)
    params_lr = optimizer.param_groups[0]["lr"] if lr is None else float(lr)
    denom_lr = initial_denom_lr if lr is None else 1.0
    params = BaseFinetuning.filter_params(modules, train_bn=train_bn, requires_grad=True)
    params = BaseFinetuning.filter_on_optimizer(optimizer, params)
    if params:
        optimizer.add_param_group({"params": params, "lr": params_lr / denom_lr})
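
The learning rate of the new param group comes from two lines above: an explicit lr is used as-is (the denominator becomes 1.0), and only when lr is None is the first group's rate divided by initial_denom_lr. Because zamba's finetune_function always passes a non-None initial_backbone_lr, initial_denom_lr has no effect in that code path. The helper name new_group_lr below is hypothetical; it replays just that arithmetic.

```python
from typing import Optional


def new_group_lr(first_group_lr: float, lr: Optional[float] = None, initial_denom_lr: float = 10.0) -> float:
    """Learning rate assigned to the newly added param group."""
    params_lr = first_group_lr if lr is None else float(lr)
    denom_lr = initial_denom_lr if lr is None else 1.0
    return params_lr / denom_lr


print(new_group_lr(0.5))                                     # 0.05: first group's lr / 10
print(new_group_lr(0.5, lr=0.125))                           # 0.125: explicit lr used as-is
print(new_group_lr(0.5, lr=0.125, initial_denom_lr=100.0))   # 0.125: denominator ignored
```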

Functions

multiplier_factory(rate: float)

Returns a function that returns a constant value for use in computing a constant learning rate multiplier.

Parameters:

Name Type Description Default
rate float

Constant multiplier.

required
Source code in zamba/pytorch/finetuning.py
def multiplier_factory(rate: float):
    """Returns a function that returns a constant value for use in computing a constant learning
    rate multiplier.

    Args:
        rate (float): Constant multiplier.
    """

    def multiplier(*args, **kwargs):
        return rate

    return multiplier