# Source code for anomalib.data

"""Anomalib Datasets."""

# Copyright (C) 2020 Intel Corporation
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing,
# software distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions
# and limitations under the License.

import logging
from typing import Union

from omegaconf import DictConfig, ListConfig
from pytorch_lightning import LightningDataModule

from .btech import BTech
from .folder import Folder
from .inference import InferenceDataset
from .mvtec import MVTec

# Module-level logger; get_datamodule reports progress through it.
logger = logging.getLogger(__name__)


def get_datamodule(config: Union[DictConfig, ListConfig]) -> LightningDataModule:
    """Get Anomaly Datamodule.

    The datamodule class is selected by ``config.dataset.format``, compared
    case-insensitively: ``"mvtec"`` -> ``MVTec``, ``"btech"`` -> ``BTech``,
    ``"folder"`` -> ``Folder``.

    Args:
        config (Union[DictConfig, ListConfig]): Configuration of the anomaly model.
            Constructor arguments are read from ``config.dataset``; the random seed
            is read from ``config.project.seed``.

    Raises:
        ValueError: If ``config.dataset.format`` is not one of the supported formats.

    Returns:
        PyTorch Lightning DataModule
    """
    logger.info("Loading the datamodule")

    # Normalise the format once instead of re-lowering it for every branch.
    dataset_format = config.dataset.format.lower()

    datamodule: LightningDataModule
    if dataset_format in ("mvtec", "btech"):
        # MVTec and BTech are constructed from exactly the same config values,
        # so choose the class and share a single argument list.
        datamodule_class = MVTec if dataset_format == "mvtec" else BTech
        datamodule = datamodule_class(
            # TODO: Remove config values. IAAALD-211
            root=config.dataset.path,
            category=config.dataset.category,
            image_size=(config.dataset.image_size[0], config.dataset.image_size[1]),
            train_batch_size=config.dataset.train_batch_size,
            test_batch_size=config.dataset.test_batch_size,
            num_workers=config.dataset.num_workers,
            seed=config.project.seed,
            task=config.dataset.task,
            transform_config_train=config.dataset.transform_config.train,
            transform_config_val=config.dataset.transform_config.val,
            create_validation_set=config.dataset.create_validation_set,
        )
    elif dataset_format == "folder":
        datamodule = Folder(
            root=config.dataset.path,
            normal_dir=config.dataset.normal_dir,
            abnormal_dir=config.dataset.abnormal_dir,
            task=config.dataset.task,
            normal_test_dir=config.dataset.normal_test_dir,
            mask_dir=config.dataset.mask,
            extensions=config.dataset.extensions,
            split_ratio=config.dataset.split_ratio,
            seed=config.project.seed,
            image_size=(config.dataset.image_size[0], config.dataset.image_size[1]),
            train_batch_size=config.dataset.train_batch_size,
            test_batch_size=config.dataset.test_batch_size,
            num_workers=config.dataset.num_workers,
            transform_config_train=config.dataset.transform_config.train,
            transform_config_val=config.dataset.transform_config.val,
            create_validation_set=config.dataset.create_validation_set,
        )
    else:
        # Name the rejected format so misconfigurations are easy to spot.
        raise ValueError(
            f"Unknown dataset format {config.dataset.format!r}! \n"
            "If you use a custom dataset make sure you initialize it in "
            "`get_datamodule` in `anomalib.data.__init__.py`"
        )

    return datamodule
# Public API of anomalib.data: the datamodule factory plus the dataset
# classes re-exported from its submodules.
__all__ = [
    "get_datamodule",
    "BTech",
    "Folder",
    "InferenceDataset",
    "MVTec",
]