def subnet_conv_func(kernel_size: int, hidden_ratio: float) -> Callable:
    """Subnet Convolutional Function.

    Returns a factory ``f`` that is called as ``f(channels_in, channels_out)`` and
    returns a ``torch.nn.Module`` predicting the coupling coefficients :math:`s, t`.

    Args:
        kernel_size (int): Kernel size used by both convolutions.
        hidden_ratio (float): Multiplier applied to ``in_channels`` to obtain the
            number of hidden channels (truncated with ``int``).

    Returns:
        Callable: Factory building a ``Conv2d -> ReLU -> Conv2d`` ``nn.Sequential``.
    """

    def _build_subnet(in_channels: int, out_channels: int) -> nn.Sequential:
        # Number of channels of the intermediate feature map.
        width = int(in_channels * hidden_ratio)
        # "same" padding keeps the spatial size unchanged through both convolutions.
        layers = [
            nn.Conv2d(in_channels, width, kernel_size, padding="same"),
            nn.ReLU(),
            nn.Conv2d(width, out_channels, kernel_size, padding="same"),
        ]
        return nn.Sequential(*layers)

    return _build_subnet
def create_fast_flow_block(
    input_dimensions: List[int],
    conv3x3_only: bool,
    hidden_ratio: float,
    flow_steps: int,
    clamp: float = 2.0,
) -> SequenceINN:
    """Create NF Fast Flow Block.

    Builds a Normalizing Flow (NF) Fast Flow model block based on Figure 2 and
    Section 3.3 in the paper: a ``SequenceINN`` with ``flow_steps`` appended
    ``AllInOneBlock`` coupling steps.

    Args:
        input_dimensions (List[int]): Input dimensions (Channel, Height, Width).
        conv3x3_only (bool): If True every step uses 3x3 convolutions; otherwise
            odd-indexed steps use 1x1 convolutions and even-indexed steps use 3x3.
        hidden_ratio (float): Ratio for the hidden layer channels.
        flow_steps (int): Number of flow steps to append.
        clamp (float, optional): Affine clamping value passed to each step. Defaults to 2.0.

    Returns:
        SequenceINN: FastFlow Block.
    """
    nodes = SequenceINN(*input_dimensions)
    for step in range(flow_steps):
        # Even-indexed steps always use 3x3; odd ones drop to 1x1 unless conv3x3_only is set.
        kernel_size = 3 if (conv3x3_only or step % 2 == 0) else 1
        nodes.append(
            AllInOneBlock,
            subnet_constructor=subnet_conv_func(kernel_size, hidden_ratio),
            affine_clamping=clamp,
            permute_soft=False,
        )
    return nodes
class FastflowModel(nn.Module):
    """FastFlow.

    Unsupervised Anomaly Detection and Localization via 2D Normalizing Flows.

    Args:
        input_size (Tuple[int, int]): Model input size.
        backbone (str): Backbone network name. One of ``cait_m48_448``,
            ``deit_base_distilled_patch16_384``, ``resnet18``, ``wide_resnet50_2``.
        flow_steps (int): Flow steps.
        conv3x3_only (bool, optional): Use only conv3x3 in fast_flow model. Defaults to False.
        hidden_ratio (float, optional): Ratio to calculate hidden var channels. Defaults to 1.0.

    Raises:
        ValueError: When the backbone is not supported.
    """

    def __init__(
        self,
        input_size: Tuple[int, int],
        backbone: str,
        flow_steps: int,
        conv3x3_only: bool = False,
        hidden_ratio: float = 1.0,
    ) -> None:
        super().__init__()

        self.input_size = input_size

        if backbone in ["cait_m48_448", "deit_base_distilled_patch16_384"]:
            self.feature_extractor = timm.create_model(backbone, pretrained=True)
            # Transformer features are reshaped to (768, H / 16, W / 16) in the _get_*_features helpers.
            channels = [768]
            scales = [16]
        elif backbone in ["resnet18", "wide_resnet50_2"]:
            self.feature_extractor = timm.create_model(
                backbone,
                pretrained=True,
                features_only=True,
                out_indices=[1, 2, 3],
            )
            channels = self.feature_extractor.feature_info.channels()
            scales = self.feature_extractor.feature_info.reduction()

            # for transformers, use their pretrained norm w/o grad
            # for resnets, self.norms are trainable LayerNorm
            self.norms = nn.ModuleList()
            for channel, scale in zip(channels, scales):
                self.norms.append(
                    nn.LayerNorm(
                        [channel, int(input_size[0] / scale), int(input_size[1] / scale)],
                        elementwise_affine=True,
                    )
                )
        else:
            raise ValueError(
                f"Backbone {backbone} is not supported. List of available backbones are "
                "[cait_m48_448, deit_base_distilled_patch16_384, resnet18, wide_resnet50_2]."
            )

        # Freeze the pretrained backbone; only modules created below (and the CNN norms) keep gradients.
        for parameter in self.feature_extractor.parameters():
            parameter.requires_grad = False

        # One normalizing-flow block per feature level, sized to that level's (C, H, W).
        self.fast_flow_blocks = nn.ModuleList()
        for channel, scale in zip(channels, scales):
            self.fast_flow_blocks.append(
                create_fast_flow_block(
                    input_dimensions=[channel, int(input_size[0] / scale), int(input_size[1] / scale)],
                    conv3x3_only=conv3x3_only,
                    hidden_ratio=hidden_ratio,
                    flow_steps=flow_steps,
                )
            )
        self.anomaly_map_generator = AnomalyMapGenerator(input_size=input_size)

    def forward(self, input_tensor: Tensor) -> Union[Tuple[List[Tensor], List[Tensor]], Tensor]:
        """Forward-Pass the input to the FastFlow Model.

        Args:
            input_tensor (Tensor): Input tensor.

        Returns:
            Union[Tuple[List[Tensor], List[Tensor]], Tensor]: During training, return
                (hidden_variables, log-of-the-jacobian-determinants), one list entry per
                fast-flow block. During the validation/test, return the anomaly map.
        """

        return_val: Union[Tuple[List[Tensor], List[Tensor]], Tensor]

        # The frozen backbone is always run in eval mode, even while this module trains.
        self.feature_extractor.eval()
        if isinstance(self.feature_extractor, VisionTransformer):
            features = self._get_vit_features(input_tensor)
        elif isinstance(self.feature_extractor, Cait):
            features = self._get_cait_features(input_tensor)
        else:
            features = self._get_cnn_features(input_tensor)

        # Compute the hidden variable f: X -> Z and log-likelihood of the jacobian
        # (See Section 3.3 in the paper.)
        # NOTE: output variable has z, and jacobian tuple for each fast-flow blocks.
        hidden_variables: List[Tensor] = []
        log_jacobians: List[Tensor] = []
        for fast_flow_block, feature in zip(self.fast_flow_blocks, features):
            hidden_variable, log_jacobian = fast_flow_block(feature)
            hidden_variables.append(hidden_variable)
            log_jacobians.append(log_jacobian)

        return_val = (hidden_variables, log_jacobians)

        if not self.training:
            return_val = self.anomaly_map_generator(hidden_variables)

        return return_val

    def _get_cnn_features(self, input_tensor: Tensor) -> List[Tensor]:
        """Get CNN-based features.

        Args:
            input_tensor (Tensor): Input Tensor.

        Returns:
            List[Tensor]: List of features, each passed through its matching LayerNorm.
        """
        features = self.feature_extractor(input_tensor)
        features = [self.norms[i](feature) for i, feature in enumerate(features)]
        return features

    def _get_cait_features(self, input_tensor: Tensor) -> List[Tensor]:
        """Get Class-Attention-Image-Transformers (CaiT) features.

        Args:
            input_tensor (Tensor): Input Tensor.

        Returns:
            List[Tensor]: Single-element list holding a (B, C, H // 16, W // 16) feature map.
        """
        feature = self.feature_extractor.patch_embed(input_tensor)
        feature = feature + self.feature_extractor.pos_embed
        feature = self.feature_extractor.pos_drop(feature)
        for i in range(41):  # paper Table 6. Block Index = 40
            feature = self.feature_extractor.blocks[i](feature)
        batch_size, _, num_channels = feature.shape
        feature = self.feature_extractor.norm(feature)
        feature = feature.permute(0, 2, 1)
        # Patch size 16 is hard-coded: tokens are laid back out on an (H/16, W/16) grid.
        feature = feature.reshape(batch_size, num_channels, self.input_size[0] // 16, self.input_size[1] // 16)
        features = [feature]
        return features

    def _get_vit_features(self, input_tensor: Tensor) -> List[Tensor]:
        """Get Vision Transformers (ViT) features.

        Args:
            input_tensor (Tensor): Input Tensor.

        Returns:
            List[Tensor]: Single-element list holding a (B, C, H // 16, W // 16) feature map
                built from the patch tokens only.
        """
        feature = self.feature_extractor.patch_embed(input_tensor)
        cls_token = self.feature_extractor.cls_token.expand(feature.shape[0], -1, -1)
        dist_token = getattr(self.feature_extractor, "dist_token", None)
        if dist_token is None:
            feature = torch.cat((cls_token, feature), dim=1)
            num_prefix_tokens = 1
        else:
            feature = torch.cat(
                (
                    cls_token,
                    dist_token.expand(feature.shape[0], -1, -1),
                    feature,
                ),
                dim=1,
            )
            num_prefix_tokens = 2
        feature = self.feature_extractor.pos_drop(feature + self.feature_extractor.pos_embed)
        for i in range(8):  # paper Table 6. Block Index = 7
            feature = self.feature_extractor.blocks[i](feature)
        feature = self.feature_extractor.norm(feature)
        # Drop exactly the prepended class (and, if present, distillation) tokens so only
        # patch tokens remain; a fixed offset of 2 would also discard a patch token when
        # no distillation token was prepended, breaking the reshape below.
        feature = feature[:, num_prefix_tokens:, :]
        batch_size, _, num_channels = feature.shape
        feature = feature.permute(0, 2, 1)
        # Patch size 16 is hard-coded: tokens are laid back out on an (H/16, W/16) grid.
        feature = feature.reshape(batch_size, num_channels, self.input_size[0] // 16, self.input_size[1] // 16)
        features = [feature]
        return features