Source code for anomalib.models.draem.torch_model

"""PyTorch model for the DRAEM model implementation."""

# Original Code
# Copyright (c) 2021 VitjanZ
# https://github.com/VitjanZ/DRAEM.
# SPDX-License-Identifier: MIT
#
# Modified
# Copyright (C) 2022 Intel Corporation
# SPDX-License-Identifier: Apache-2.0

from typing import Tuple, Union

import torch
from torch import Tensor, nn

from anomalib.models.components.layers import SSPCAB


class DraemModel(nn.Module):
    """DRAEM PyTorch model consisting of the reconstructive and discriminative sub networks.

    Args:
        sspcab (bool): Forwarded unchanged to the reconstructive sub network. Defaults to False.
    """

    def __init__(self, sspcab: bool = False):
        super().__init__()
        self.reconstructive_subnetwork = ReconstructiveSubNetwork(sspcab=sspcab)
        # The discriminator consumes the input image concatenated channel-wise with its
        # reconstruction (3 + 3 channels with the default reconstructive sub network) and
        # emits two score maps; the eval branch of ``forward`` returns the softmax of channel 1.
        self.discriminative_subnetwork = DiscriminativeSubNetwork(in_channels=6, out_channels=2)

    def forward(self, batch: Tensor) -> Union[Tensor, Tuple[Tensor, Tensor]]:
        """Compute the reconstruction and anomaly mask from an input image.

        Args:
            batch (Tensor): Batch of input images.

        Returns:
            In training mode, the tuple ``(reconstruction, prediction)`` where ``prediction``
            holds the raw (pre-softmax) outputs of the discriminative sub network.
            Otherwise, the softmax over the channel dimension of the prediction, restricted
            to channel index 1 (the predicted confidence values of the anomaly mask).
        """
        reconstruction = self.reconstructive_subnetwork(batch)
        # Concatenate along the channel dimension so the discriminator sees both images.
        concatenated_inputs = torch.cat([batch, reconstruction], dim=1)
        prediction = self.discriminative_subnetwork(concatenated_inputs)
        if self.training:
            return reconstruction, prediction
        return torch.softmax(prediction, dim=1)[:, 1, ...]
class ReconstructiveSubNetwork(nn.Module):
    """Autoencoder model that encodes and reconstructs the input image.

    Args:
        in_channels (int): Number of input channels.
        out_channels (int): Number of output channels.
        base_width (int): Base dimensionality of the layers of the autoencoder.
        sspcab (bool): Forwarded unchanged to the encoder. Defaults to False.
    """

    def __init__(self, in_channels: int = 3, out_channels: int = 3, base_width: int = 128, sspcab: bool = False):
        super().__init__()
        self.encoder = EncoderReconstructive(in_channels, base_width, sspcab=sspcab)
        self.decoder = DecoderReconstructive(base_width, out_channels=out_channels)

    def forward(self, batch: Tensor) -> Tensor:
        """Encode and reconstruct the input images.

        Args:
            batch (Tensor): Batch of input images.

        Returns:
            Batch of reconstructed images.
        """
        encoded = self.encoder(batch)
        decoded = self.decoder(encoded)
        return decoded
class DiscriminativeSubNetwork(nn.Module):
    """Segmentation network that maps an image/reconstruction pair to per-pixel class scores.

    Args:
        in_channels (int): Number of input channels.
        out_channels (int): Number of output channels.
        base_width (int): Base dimensionality of the layers of the autoencoder.
    """

    def __init__(self, in_channels: int = 3, out_channels: int = 3, base_width: int = 64):
        super().__init__()
        self.encoder_segment = EncoderDiscriminative(in_channels, base_width)
        self.decoder_segment = DecoderDiscriminative(base_width, out_channels=out_channels)

    def forward(self, batch: Tensor) -> Tensor:
        """Generate the predicted anomaly masks for a batch of input images.

        Args:
            batch (Tensor): Batch of inputs consisting of the concatenation of the original
                images and their reconstructions.

        Returns:
            Activations of the output layer corresponding to the normal and anomalous class
            scores on the pixel level.
        """
        # The encoder returns its six block activations in order (act1 .. act6); the decoder
        # takes them positionally in that same order.
        encoder_activations = self.encoder_segment(batch)
        return self.decoder_segment(*encoder_activations)
class EncoderDiscriminative(nn.Module):
    """Encoder part of the discriminator network.

    Six double-convolution blocks, with a 2x2 max pooling layer between consecutive blocks.

    Args:
        in_channels (int): Number of input channels.
        base_width (int): Base dimensionality of the layers of the autoencoder.
    """

    def __init__(self, in_channels: int, base_width: int):
        super().__init__()
        # Blocks and pools are created in the same interleaved order as they are applied, so
        # module registration order (and parameter initialisation order) stays stable.
        self.block1 = self._double_conv(in_channels, base_width)
        self.mp1 = nn.Sequential(nn.MaxPool2d(2))
        self.block2 = self._double_conv(base_width, base_width * 2)
        self.mp2 = nn.Sequential(nn.MaxPool2d(2))
        self.block3 = self._double_conv(base_width * 2, base_width * 4)
        self.mp3 = nn.Sequential(nn.MaxPool2d(2))
        self.block4 = self._double_conv(base_width * 4, base_width * 8)
        self.mp4 = nn.Sequential(nn.MaxPool2d(2))
        self.block5 = self._double_conv(base_width * 8, base_width * 8)
        self.mp5 = nn.Sequential(nn.MaxPool2d(2))
        self.block6 = self._double_conv(base_width * 8, base_width * 8)

    @staticmethod
    def _double_conv(in_channels: int, out_channels: int) -> nn.Sequential:
        """Build two consecutive 3x3 conv -> batch norm -> ReLU stages."""
        return nn.Sequential(
            nn.Conv2d(in_channels, out_channels, kernel_size=3, padding=1),
            nn.BatchNorm2d(out_channels),
            nn.ReLU(inplace=True),
            nn.Conv2d(out_channels, out_channels, kernel_size=3, padding=1),
            nn.BatchNorm2d(out_channels),
            nn.ReLU(inplace=True),
        )

    def forward(self, batch: Tensor) -> Tuple[Tensor, Tensor, Tensor, Tensor, Tensor, Tensor]:
        """Convert the inputs to the salient space by running them through the encoder network.

        Args:
            batch (Tensor): Batch of inputs consisting of the concatenation of the original
                images and their reconstructions.

        Returns:
            Computed feature maps for each of the layers in the encoder sub network, ordered
            from the first block (``act1``) to the last block (``act6``).
        """
        act1 = self.block1(batch)
        mp1 = self.mp1(act1)
        act2 = self.block2(mp1)
        # Each stage goes through its own pooling layer (mp2 here, not mp3).
        mp2 = self.mp2(act2)
        act3 = self.block3(mp2)
        mp3 = self.mp3(act3)
        act4 = self.block4(mp3)
        mp4 = self.mp4(act4)
        act5 = self.block5(mp4)
        mp5 = self.mp5(act5)
        act6 = self.block6(mp5)
        return act1, act2, act3, act4, act5, act6
class DecoderDiscriminative(nn.Module):
    """Decoder part of the discriminator network.

    Five decoding stages followed by a final 3x3 convolution. Each stage upsamples the running
    feature map by a factor of two, concatenates the matching encoder activation along the
    channel dimension and refines the result with a double-convolution block.

    Args:
        base_width (int): Base dimensionality of the layers of the autoencoder.
        out_channels (int): Number of output channels.
    """

    def __init__(self, base_width: int, out_channels: int = 1):
        super().__init__()
        width = base_width

        # The input width of every ``db*`` block is the sum of the upsampled channels and the
        # channels of the encoder activation it is concatenated with.
        self.up_b = self._upsample_conv(width * 8, width * 8)
        self.db_b = self._double_conv(width * (8 + 8), width * 8)

        self.up1 = self._upsample_conv(width * 8, width * 4)
        self.db1 = self._double_conv(width * (4 + 8), width * 4)

        self.up2 = self._upsample_conv(width * 4, width * 2)
        self.db2 = self._double_conv(width * (2 + 4), width * 2)

        self.up3 = self._upsample_conv(width * 2, width)
        self.db3 = self._double_conv(width * (2 + 1), width)

        self.up4 = self._upsample_conv(width, width)
        self.db4 = self._double_conv(width * 2, width)

        self.fin_out = nn.Sequential(nn.Conv2d(width, out_channels, kernel_size=3, padding=1))

    @staticmethod
    def _upsample_conv(in_channels: int, out_channels: int) -> nn.Sequential:
        """Build a 2x bilinear upsampling followed by 3x3 conv -> batch norm -> ReLU."""
        return nn.Sequential(
            nn.Upsample(scale_factor=2, mode="bilinear", align_corners=True),
            nn.Conv2d(in_channels, out_channels, kernel_size=3, padding=1),
            nn.BatchNorm2d(out_channels),
            nn.ReLU(inplace=True),
        )

    @staticmethod
    def _double_conv(in_channels: int, out_channels: int) -> nn.Sequential:
        """Build two consecutive 3x3 conv -> batch norm -> ReLU stages."""
        return nn.Sequential(
            nn.Conv2d(in_channels, out_channels, kernel_size=3, padding=1),
            nn.BatchNorm2d(out_channels),
            nn.ReLU(inplace=True),
            nn.Conv2d(out_channels, out_channels, kernel_size=3, padding=1),
            nn.BatchNorm2d(out_channels),
            nn.ReLU(inplace=True),
        )

    def forward(self, act1: Tensor, act2: Tensor, act3: Tensor, act4: Tensor, act5: Tensor, act6: Tensor) -> Tensor:
        """Computes predicted anomaly class scores from the intermediate outputs of the encoder sub network.

        Args:
            act1 (Tensor): Encoder activations of the first block of convolutional layers.
            act2 (Tensor): Encoder activations of the second block of convolutional layers.
            act3 (Tensor): Encoder activations of the third block of convolutional layers.
            act4 (Tensor): Encoder activations of the fourth block of convolutional layers.
            act5 (Tensor): Encoder activations of the fifth block of convolutional layers.
            act6 (Tensor): Encoder activations of the sixth block of convolutional layers.

        Returns:
            Predicted anomaly class scores per pixel.
        """
        # Walk from the deepest activation back up, pairing each stage with its skip input.
        stages = (
            (self.up_b, self.db_b, act5),
            (self.up1, self.db1, act4),
            (self.up2, self.db2, act3),
            (self.up3, self.db3, act2),
            (self.up4, self.db4, act1),
        )
        features = act6
        for upsample, refine, skip in stages:
            merged = torch.cat((upsample(features), skip), dim=1)
            features = refine(merged)
        return self.fin_out(features)
class EncoderReconstructive(nn.Module):
    """Encoder part of the reconstructive network.

    Four double-convolution blocks, each followed by 2x2 max pooling, then a bottleneck block.

    Args:
        in_channels (int): Number of input channels.
        base_width (int): Base dimensionality of the layers of the autoencoder.
        sspcab (bool): When True, the bottleneck (``block5``) is an ``SSPCAB`` layer instead of
            a double-convolution block. Defaults to False.
    """

    def __init__(self, in_channels: int, base_width: int, sspcab: bool = False):
        super().__init__()
        # Blocks and pools are created in the same interleaved order as they are applied, so
        # module registration order (and parameter initialisation order) stays stable.
        self.block1 = self._double_conv(in_channels, base_width)
        self.mp1 = nn.Sequential(nn.MaxPool2d(2))
        self.block2 = self._double_conv(base_width, base_width * 2)
        self.mp2 = nn.Sequential(nn.MaxPool2d(2))
        self.block3 = self._double_conv(base_width * 2, base_width * 4)
        self.mp3 = nn.Sequential(nn.MaxPool2d(2))
        self.block4 = self._double_conv(base_width * 4, base_width * 8)
        self.mp4 = nn.Sequential(nn.MaxPool2d(2))
        if sspcab:
            self.block5 = SSPCAB(base_width * 8)
        else:
            self.block5 = self._double_conv(base_width * 8, base_width * 8)

    @staticmethod
    def _double_conv(in_channels: int, out_channels: int) -> nn.Sequential:
        """Build two consecutive 3x3 conv -> batch norm -> ReLU stages."""
        return nn.Sequential(
            nn.Conv2d(in_channels, out_channels, kernel_size=3, padding=1),
            nn.BatchNorm2d(out_channels),
            nn.ReLU(inplace=True),
            nn.Conv2d(out_channels, out_channels, kernel_size=3, padding=1),
            nn.BatchNorm2d(out_channels),
            nn.ReLU(inplace=True),
        )

    def forward(self, batch: Tensor) -> Tensor:
        """Encode a batch of input images to the salient space.

        Args:
            batch (Tensor): Batch of input images.

        Returns:
            Feature maps extracted from the bottleneck layer.
        """
        act1 = self.block1(batch)
        mp1 = self.mp1(act1)
        act2 = self.block2(mp1)
        # Each stage goes through its own pooling layer (mp2 here, not mp3).
        mp2 = self.mp2(act2)
        act3 = self.block3(mp2)
        mp3 = self.mp3(act3)
        act4 = self.block4(mp3)
        mp4 = self.mp4(act4)
        act5 = self.block5(mp4)
        return act5
class DecoderReconstructive(nn.Module):
    """Decoder part of the reconstructive network.

    Four (upsample, refine) stage pairs followed by a final 3x3 convolution. No encoder
    activations are concatenated; the decoder only consumes the bottleneck features.

    Args:
        base_width (int): Base dimensionality of the layers of the autoencoder.
        out_channels (int): Number of output channels.
    """

    def __init__(self, base_width: int, out_channels: int = 1):
        super().__init__()
        width = base_width

        # Each ``up*`` keeps the channel count; each ``db*`` (except the last) halves it.
        self.up1 = self._upsample_conv(width * 8)
        self.db1 = self._refine(width * 8, width * 4)

        self.up2 = self._upsample_conv(width * 4)
        self.db2 = self._refine(width * 4, width * 2)

        self.up3 = self._upsample_conv(width * 2)
        self.db3 = self._refine(width * 2, width * 1)

        self.up4 = self._upsample_conv(width)
        self.db4 = self._refine(width * 1, width)

        self.fin_out = nn.Sequential(nn.Conv2d(width, out_channels, kernel_size=3, padding=1))

    @staticmethod
    def _upsample_conv(channels: int) -> nn.Sequential:
        """Build a 2x bilinear upsampling followed by a width-preserving conv -> batch norm -> ReLU."""
        return nn.Sequential(
            nn.Upsample(scale_factor=2, mode="bilinear", align_corners=True),
            nn.Conv2d(channels, channels, kernel_size=3, padding=1),
            nn.BatchNorm2d(channels),
            nn.ReLU(inplace=True),
        )

    @staticmethod
    def _refine(in_channels: int, out_channels: int) -> nn.Sequential:
        """Build a width-preserving conv stage followed by a conv stage mapping to ``out_channels``."""
        return nn.Sequential(
            nn.Conv2d(in_channels, in_channels, kernel_size=3, padding=1),
            nn.BatchNorm2d(in_channels),
            nn.ReLU(inplace=True),
            nn.Conv2d(in_channels, out_channels, kernel_size=3, padding=1),
            nn.BatchNorm2d(out_channels),
            nn.ReLU(inplace=True),
        )

    def forward(self, act5: Tensor) -> Tensor:
        """Reconstruct the image from the activations of the bottleneck layer.

        Args:
            act5 (Tensor): Activations of the bottleneck layer.

        Returns:
            Batch of reconstructed images.
        """
        features = act5
        # Purely sequential pipeline: apply the stages in their declared order.
        for stage in (self.up1, self.db1, self.up2, self.db2, self.up3, self.db3, self.up4, self.db4):
            features = stage(features)
        return self.fin_out(features)