class AttentionModule(nn.Module):
    """Squeeze-and-excitation block that acts as the attention module in SSPCAB.

    The block averages the input over dims 2 and 3, passes the resulting
    per-channel vector through a two-layer bottleneck (linear -> ReLU ->
    linear -> sigmoid) and rescales every input channel by its gate value.

    Args:
        in_channels (int): Number of input channels.
        reduction_ratio (int): Reduction ratio of the attention module. The
            bottleneck width is ``in_channels // reduction_ratio``, but never
            less than 1.

    Raises:
        ValueError: If ``reduction_ratio`` is not a positive integer.
    """

    def __init__(self, in_channels: int, reduction_ratio: int = 8):
        super().__init__()

        if reduction_ratio < 1:
            raise ValueError(f"reduction_ratio must be a positive integer, got {reduction_ratio}")

        # Floor division gives 0 when in_channels < reduction_ratio, which would
        # create an empty bottleneck whose gate ignores the input entirely.
        # Keep at least one hidden unit so the gate stays input-dependent.
        out_channels = max(1, in_channels // reduction_ratio)
        self.fc1 = nn.Linear(in_channels, out_channels)
        self.fc2 = nn.Linear(out_channels, in_channels)

    def forward(self, inputs: Tensor) -> Tensor:
        """Forward pass through the attention module.

        Args:
            inputs (Tensor): 4-D tensor whose dim 1 holds ``in_channels``
                entries; dims 2 and 3 are averaged out (typically height and
                width).

        Returns:
            Tensor: ``inputs`` scaled channel-wise by a gate in (0, 1); same
            shape as ``inputs``.
        """
        # reduce feature map to 1d vector through global average pooling
        avg_pooled = inputs.mean(dim=(2, 3))

        # squeeze and excite
        act = self.fc1(avg_pooled)
        act = F.relu(act)
        act = self.fc2(act)
        act = torch.sigmoid(act)

        # multiply with input: reshape the gate to (batch, channels, 1, 1) so it
        # broadcasts over the two trailing dims
        se_out = inputs * act.view(act.shape[0], act.shape[1], 1, 1)

        return se_out
class SSPCAB(nn.Module):
    """SSPCAB block.

    Four ``kernel_size`` x ``kernel_size`` convolutions are applied to the four
    corners of a larger receptive field; their outputs are summed and passed
    through a channel attention module. The centre of the receptive field is
    never seen by any of the four convolutions.

    Args:
        in_channels (int): Number of input channels.
        kernel_size (int): Size of the receptive fields of the masked convolution kernel
            (side length of each of the four corner sub-kernels).
        dilation (int): Dilation factor of the masked convolution kernel, i.e. the
            number of pixels between each corner sub-kernel and the masked centre.
        reduction_ratio (int): Reduction ratio of the attention module.
    """

    def __init__(self, in_channels: int, kernel_size: int = 1, dilation: int = 1, reduction_ratio: int = 8):
        super().__init__()

        self.pad, self.crop = self._masked_geometry(kernel_size, dilation)

        self.masked_conv1 = nn.Conv2d(in_channels=in_channels, out_channels=in_channels, kernel_size=kernel_size)
        self.masked_conv2 = nn.Conv2d(in_channels=in_channels, out_channels=in_channels, kernel_size=kernel_size)
        self.masked_conv3 = nn.Conv2d(in_channels=in_channels, out_channels=in_channels, kernel_size=kernel_size)
        self.masked_conv4 = nn.Conv2d(in_channels=in_channels, out_channels=in_channels, kernel_size=kernel_size)

        self.attention_module = AttentionModule(in_channels=in_channels, reduction_ratio=reduction_ratio)

    @staticmethod
    def _masked_geometry(kernel_size: int, dilation: int) -> tuple:
        """Return ``(pad, crop)`` for the four corner convolutions.

        ``pad`` is the border added on every side of the input. ``crop`` is the
        offset between opposite corner sub-kernels. The padded map has size
        ``H + 2 * (kernel_size + dilation)``; removing ``crop`` pixels must
        leave ``H + kernel_size - 1`` so that an unpadded ``kernel_size``
        convolution returns exactly ``H`` again. That requires
        ``crop = kernel_size + 2 * dilation + 1`` (using
        ``2 * (kernel_size + dilation)`` only works for ``kernel_size == 1``).
        """
        pad = kernel_size + dilation
        crop = kernel_size + 2 * dilation + 1
        return pad, crop

    def forward(self, inputs: Tensor) -> Tensor:
        """Forward pass through the SSPCAB block.

        Args:
            inputs (Tensor): 4-D feature map; the last two dims are padded and
                convolved, and dim 1 must match ``in_channels``.

        Returns:
            Tensor: Attention-weighted sum of the four corner convolutions,
            same shape as ``inputs``.
        """
        # compute masked convolution: each slice shifts the padded map so that
        # one corner sub-kernel lines up with the output grid
        padded = F.pad(inputs, (self.pad,) * 4)
        crop = self.crop
        masked_out = torch.zeros_like(inputs)
        masked_out += self.masked_conv1(padded[..., :-crop, :-crop])  # top-left
        masked_out += self.masked_conv2(padded[..., :-crop, crop:])  # top-right
        masked_out += self.masked_conv3(padded[..., crop:, :-crop])  # bottom-left
        masked_out += self.masked_conv4(padded[..., crop:, crop:])  # bottom-right

        # apply channel attention module
        sspcab_out = self.attention_module(masked_out)
        return sspcab_out