Source code for conex.helpers.data

from torch.utils.data import Dataset


class LocationDataset(Dataset):
    """
    A custom dataset class for data with an (image, location, label) triple nature.

    Wraps an existing dataset whose items unpack into ``(image, label)`` pairs
    and yields ``(image, location, label)`` triples instead.

    Args:
        dataset (Dataset): A dataset instance whose items unpack into
            ``(image, label)``.
        pre_transform (Transform): A transformation to apply on images. If
            given, it must return an ``(image, location)`` tuple; it is the
            only source of ``location``.
        post_transform (Transform): A transformation applied to the image
            after ``pre_transform``. Suitable for encodings.
        location_transform (Transform): A transformation applied to the
            location data. Suitable for encodings. NOTE: it is invoked even
            when no ``pre_transform`` is given, in which case it receives
            ``None``.
        target_transform (Transform): A transformation applied to labels.

    Any transform left as ``None`` is skipped.
    """

    def __init__(
        self,
        dataset,
        pre_transform=None,
        post_transform=None,
        location_transform=None,
        target_transform=None,
    ):
        self.dataset = dataset
        self.pre_transform = pre_transform
        self.post_transform = post_transform
        self.location_transform = location_transform
        self.target_transform = target_transform

    def __len__(self):
        """Return the number of items in the wrapped dataset."""
        return len(self.dataset)

    def __getitem__(self, idx):
        """Return the ``(image, location, label)`` triple at index ``idx``.

        ``location`` is ``None`` unless ``pre_transform`` produces one or
        ``location_transform`` replaces it.
        """
        image, label = self.dataset[idx]

        # Without a pre_transform there is no location information.
        location = None

        # Compare against None rather than relying on truthiness: a callable
        # transform object that defines __len__ or __bool__ may evaluate falsy
        # and would otherwise be silently skipped.
        if self.pre_transform is not None:
            image, location = self.pre_transform(image)

        if self.post_transform is not None:
            image = self.post_transform(image)

        # NOTE(review): called even when location is None (no pre_transform);
        # location_transform must tolerate None in that configuration.
        if self.location_transform is not None:
            location = self.location_transform(location)

        if self.target_transform is not None:
            label = self.target_transform(label)

        return image, location, label