Skip to content

zamba.models.efficientnet_models

Classes

TimeDistributedEfficientNet

Bases: ZambaVideoClassificationLightningModule

Source code in `zamba/models/efficientnet_models.py` (lines 13–60):
@register_model
class TimeDistributedEfficientNet(ZambaVideoClassificationLightningModule):
    """Video classifier that applies a frozen EfficientNetV2 backbone frame by frame.

    The backbone is wrapped in ``TimeDistributed(..., tdim=1)``. Its per-frame
    features go through a small MLP head, are flattened across frames, and are
    mapped to ``self.num_classes`` outputs.

    Args:
        num_frames: Number of frames per clip. The head's last ``nn.Linear``
            expects ``64 * num_frames`` inputs after flattening, so input clips
            presumably must contain exactly this many frames.
        finetune_from: Optional path to a checkpoint of this model class. If
            given, the backbone is taken from the checkpoint's ``base.module``.
            Otherwise timm's pretrained ``efficientnetv2_rw_m`` is used.
        **kwargs: Forwarded to ``ZambaVideoClassificationLightningModule``.
    """

    _default_model_name = (
        "time_distributed"  # used to look up default configuration for checkpoints
    )

    def __init__(
        self, num_frames=16, finetune_from: Optional[Union[os.PathLike, str]] = None, **kwargs
    ):
        super().__init__(**kwargs)

        if finetune_from is None:
            efficientnet = timm.create_model("efficientnetv2_rw_m", pretrained=True)
            # Replace timm's classification layer so the model emits features.
            efficientnet.classifier = nn.Identity()
        else:
            # ``load_from_checkpoint`` is a classmethod. Call it on the class:
            # recent Lightning releases raise TypeError when it is called
            # through an instance.
            efficientnet = type(self).load_from_checkpoint(finetune_from).base.module

        # Freeze every backbone parameter; the head created below stays trainable.
        for param in efficientnet.parameters():
            param.requires_grad = False

        num_backbone_final_features = efficientnet.num_features

        # The last stages of the backbone, registered as a separate list. These
        # are the same module objects that live inside ``self.base``, not
        # copies. This is presumably so a fine-tuning callback can unfreeze
        # them; confirm against the training callbacks.
        self.backbone = torch.nn.ModuleList(
            [
                efficientnet.get_submodule("blocks.5"),
                efficientnet.conv_head,
                efficientnet.bn2,
                efficientnet.global_pool,
            ]
        )

        # Applies the backbone along dim 1 (assumed to be the frame axis;
        # TODO confirm against TimeDistributed).
        self.base = TimeDistributed(efficientnet, tdim=1)
        self.classifier = nn.Sequential(
            nn.Linear(num_backbone_final_features, 256),
            nn.Dropout(0.2),
            nn.ReLU(),
            nn.Linear(256, 64),
            # Concatenate the 64-dim per-frame vectors before the final layer.
            nn.Flatten(),
            nn.Linear(64 * num_frames, self.num_classes),
        )

        self.save_hyperparameters("num_frames")

    def forward(self, x):
        """Run the frame-wise backbone on ``x`` and classify the result.

        ``x`` is presumably shaped (batch, frames, channels, height, width);
        TODO confirm.
        """
        # Force the wrapped backbone into eval mode on every call. Its
        # BatchNorm/Dropout layers then behave as at inference time even while
        # this LightningModule is in training mode.
        self.base.eval()
        x = self.base(x)
        return self.classifier(x)

Attributes

backbone = torch.nn.ModuleList([efficientnet.get_submodule('blocks.5'), efficientnet.conv_head, efficientnet.bn2, efficientnet.global_pool]) instance-attribute
base = TimeDistributed(efficientnet, tdim=1) instance-attribute
classifier = nn.Sequential(nn.Linear(num_backbone_final_features, 256), nn.Dropout(0.2), nn.ReLU(), nn.Linear(256, 64), nn.Flatten(), nn.Linear(64 * num_frames, self.num_classes)) instance-attribute

Functions

__init__(num_frames = 16, finetune_from: Optional[Union[os.PathLike, str]] = None, **kwargs)
Source code in `zamba/models/efficientnet_models.py` (lines 19–56):
def __init__(
    self, num_frames=16, finetune_from: Optional[Union[os.PathLike, str]] = None, **kwargs
):
    """Build the frozen frame-wise EfficientNet backbone and the trainable head.

    Args:
        num_frames: Number of frames per clip. The head's last ``nn.Linear``
            expects ``64 * num_frames`` inputs after flattening.
        finetune_from: Optional path to a checkpoint of this model class. If
            given, the backbone is taken from the checkpoint's ``base.module``.
            Otherwise timm's pretrained ``efficientnetv2_rw_m`` is used.
        **kwargs: Forwarded to the parent LightningModule's ``__init__``.
    """
    super().__init__(**kwargs)

    if finetune_from is None:
        efficientnet = timm.create_model("efficientnetv2_rw_m", pretrained=True)
        # Replace timm's classification layer so the model emits features.
        efficientnet.classifier = nn.Identity()
    else:
        # ``load_from_checkpoint`` is a classmethod. Call it on the class:
        # recent Lightning releases raise TypeError when it is called through
        # an instance.
        efficientnet = type(self).load_from_checkpoint(finetune_from).base.module

    # Freeze every backbone parameter; the head created below stays trainable.
    for param in efficientnet.parameters():
        param.requires_grad = False

    num_backbone_final_features = efficientnet.num_features

    # The last stages of the backbone, registered as a separate list. These are
    # the same module objects that live inside ``self.base``, not copies. This
    # is presumably so a fine-tuning callback can unfreeze them; confirm.
    self.backbone = torch.nn.ModuleList(
        [
            efficientnet.get_submodule("blocks.5"),
            efficientnet.conv_head,
            efficientnet.bn2,
            efficientnet.global_pool,
        ]
    )

    # Applies the backbone along dim 1 (assumed frame axis; TODO confirm).
    self.base = TimeDistributed(efficientnet, tdim=1)
    self.classifier = nn.Sequential(
        nn.Linear(num_backbone_final_features, 256),
        nn.Dropout(0.2),
        nn.ReLU(),
        nn.Linear(256, 64),
        # Concatenate the 64-dim per-frame vectors before the final layer.
        nn.Flatten(),
        nn.Linear(64 * num_frames, self.num_classes),
    )

    self.save_hyperparameters("num_frames")
forward(x)
Source code in `zamba/models/efficientnet_models.py` (lines 58–61):
def forward(self, x):
    """Classify ``x`` using the frame-wise backbone followed by the MLP head.

    The wrapped backbone is switched to eval mode on every call. Its
    BatchNorm/Dropout layers therefore act as at inference time even during
    training.
    """
    frame_model = self.base
    frame_model.eval()
    features = frame_model(x)
    return self.classifier(features)

Functions