Source code for torchwebio.models.image.imageclassificationmodel

import urllib.request
from typing import List, Optional, Tuple

import timm
from PIL import Image
from timm.data import create_transform, resolve_data_config
from torch import Tensor, nn, no_grad, topk
from torchvision.transforms import Compose


class ImageClassificationModel:
    """
    An adapter class to expose all Image classification functionality for
    models that follow the interface of pytorch-image-models (timm).

    In a nutshell, this class exposes the following functionality:
    pre-process, forward and post-process. The ``process_img`` method does all
    three at a go, and is pretty much the only method that needs to be called
    externally.

    Additionally, there are some helper static methods that are grouped here
    only for encapsulation purposes.
    """

    def __init__(
        self,
        model_name: Optional[str] = "",
        number_of_results: Optional[int] = 5,
        model: Optional[nn.Module] = None,
        category_list: Optional[str] = None,
    ):
        """Initialisation of an image classification model

        Parameters
        ----------
        model_name : Optional[str], optional
            timm model string to load. Only used (and then required) when the
            ``model`` parameter is None, by default ""
        number_of_results : Optional[int], optional
            Number of results to return. K in topK if this is treated as a
            retrieval problem. Values larger than the number of classes are
            clamped to the number of classes; None returns every class.
            By default 5
        model : nn.Module, optional
            Pytorch image classification model. If this is passed, model_name
            is ignored and the model is switched to evaluation mode in place,
            by default None
        category_list : str, optional
            URL to pull down the category list from, by default None (which
            uses the ImageNet class list)
        """
        self._number_of_results = number_of_results
        # Compare against None explicitly: container modules such as
        # nn.Sequential define __len__, so an empty one would be falsy.
        if model is None:
            self._model = self._get_model(model_name)
        else:
            # This adapter only performs inference, so dropout / batch-norm
            # must not run in training mode. Note: mutates the caller's module.
            model.eval()
            self._model = model
        self._transformation = self._prepare_transformation()
        if category_list:
            self._categories = self._get_categories(category_list)
        else:
            self._categories = self._get_categories()

    @staticmethod
    def _get_model(model_name: str) -> nn.Module:
        """Pull down a pretrained model from TIMM's model hub, in eval mode

        Parameters
        ----------
        model_name : str
            model_name as accepted by `timm.create_model(..)`

        Returns
        -------
        nn.Module
            Pytorch model in evaluation mode
        """
        model = timm.create_model(model_name, pretrained=True)
        model.eval()
        return model

    def _prepare_transformation(self) -> Compose:
        """Prepare transformations for the corresponding model

        Also records the number of input channels the model expects, so that
        ``pre_process`` can convert images to a matching PIL mode.

        Returns
        -------
        Compose
            Transformations to apply to the image, as a torchvision
            ``Compose`` object
        """
        config = resolve_data_config({}, model=self._model)
        # timm reports input_size as (channels, height, width); default to 3.
        self._input_channels = config.get("input_size", (3,))[0]
        return create_transform(**config)

    @staticmethod
    def _get_categories(
        category_list="https://raw.githubusercontent.com/"
        "pytorch/hub/master/imagenet_classes.txt",
    ):
        """Get category lists from a URL.

        The content is read into memory instead of being downloaded to a
        temporary file, so nothing is left behind on disk.

        Parameters
        ----------
        category_list : str, optional
            URL of a newline-separated, UTF-8 encoded category list (any scheme
            accepted by ``urllib.request.urlopen``, including ``file://``).
            By default the ImageNet class list from the pytorch/hub repository.

        Returns
        -------
        List[str]
            Ordered list of categories where the index is the category index.
        """
        with urllib.request.urlopen(category_list) as response:
            # utf-8-sig also drops a leading byte-order mark, which would
            # otherwise be glued onto the first category name.
            text = response.read().decode("utf-8-sig")
        return [line.strip() for line in text.splitlines()]

    def pre_process(self, img: Image.Image) -> Tensor:
        """Preprocess the image.

        Converts the image to the PIL mode matching the model's channel count
        (e.g. RGBA / palette / grayscale uploads to RGB), then runs it through
        the model's transformations. Assumes batch size = 1.

        Parameters
        ----------
        img : PIL.Image.Image
            Image to pre-process.

        Returns
        -------
        Tensor
            pre-processed image as a tensor with a leading batch dimension of 1
        """
        # getattr: instances built without __init__, or subclasses overriding
        # _prepare_transformation, may not have recorded the channel count.
        channels = getattr(self, "_input_channels", 3)
        target_mode = {1: "L", 3: "RGB"}.get(channels)
        if target_mode is not None and getattr(img, "mode", target_mode) != target_mode:
            img = img.convert(target_mode)
        return self._transformation(img).unsqueeze(0)

    def post_process(self, probabilities: Tensor) -> List[Tuple[str, float]]:
        """Post process the soft-max output of the model.

        Extracts top-K results from the last layer, and maps them to the
        corresponding categories. K is ``number_of_results`` clamped to the
        number of classes, or every class when it is None.

        Parameters
        ----------
        probabilities : Tensor
            1-D tensor of probabilities (output of ``forward``)

        Returns
        -------
        List[Tuple[str, float]]
            List of pairs of category label and category score, highest first
        """
        num_classes = probabilities.size(-1)
        if self._number_of_results is None:
            k = num_classes
        else:
            # topk raises when k exceeds the number of available elements.
            k = min(self._number_of_results, num_classes)
        scores, indices = topk(probabilities, k)
        return [
            (self._categories[index], score)
            for score, index in zip(scores.tolist(), indices.tolist())
        ]

    def forward(self, img_tensor: Tensor) -> Tensor:
        """Forward pass of the pytorch model

        Runs without gradient tracking and applies a softmax to the output of
        the first (and assumed only) item in the batch.

        Parameters
        ----------
        img_tensor : Tensor
            input image after pre-processing.

        Returns
        -------
        Tensor
            Output probabilities from the network for the first batch item
        """
        with no_grad():
            out = self._model(img_tensor)
        return nn.functional.softmax(out[0], dim=0)

    def process_img(self, img: Image.Image) -> List[Tuple[str, float]]:
        """Processes a single image.

        Calls pre_process, forward and post_process in succession

        Parameters
        ----------
        img : PIL.Image.Image
            Input image as a PIL object

        Returns
        -------
        List[Tuple[str, float]]
            Output scores and categories
        """
        img_tensor = self.pre_process(img)
        probabilities = self.forward(img_tensor)
        return self.post_process(probabilities)
# TODO: Verify whether this is required.
# Display name -> timm model identifier, presumably used to populate a model
# picker. NOTE(review): some of these are legacy timm names that newer timm
# releases may only accept as deprecated aliases — confirm against the
# installed timm version.
MODEL_NAMES = [
    {"name": display_name, "model_name": timm_id}
    for display_name, timm_id in (
        ("Adversarial Inception v3", "adv_inception_v3"),
        ("AdvProp (EfficientNet)", "tf_efficientnet_b0_ap"),
        ("Big Transfer (BiT)", "resnetv2_101x1_bitm"),
        ("CSP-DarkNet", "cspdarknet53"),
        ("CSP-ResNet", "cspresnet50"),
        ("CSP-ResNeXt", "cspresnext50"),
        ("DenseNet", "densenet121"),
    )
]