This repository has been archived by the owner on Oct 9, 2023. It is now read-only.

[Feature request] Compatibility with iterable-style datasets #1237

Open

austinmw opened this issue Mar 16, 2022 · 6 comments
Comments


austinmw commented Mar 16, 2022

🚀 Feature

I'd like to be able to train on iterable-style datasets, not just map-style datasets.
(A map-style dataset in PyTorch implements __getitem__ and __len__, whereas an iterable-style dataset only implements __iter__.)
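
For reference, a minimal sketch of the two styles (torch.utils.data only; the sample containers are placeholders):

from torch.utils.data import Dataset, IterableDataset

class MapStyleDataset(Dataset):
    """Random access: a sampler can request any index up front."""
    def __init__(self, samples):
        self.samples = samples  # e.g. a list of (image, target) pairs

    def __getitem__(self, index):
        return self.samples[index]

    def __len__(self):
        return len(self.samples)

class IterableStyleDataset(IterableDataset):
    """Sequential access only: samples arrive as a stream, with no known length."""
    def __init__(self, sample_stream):
        self.sample_stream = sample_stream  # e.g. a generator over shards

    def __iter__(self):
        yield from self.sample_stream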

Motivation

Many image datasets in commercial use cases are very large, and therefore require iterable-style rather than map-style access.
(Users may create custom iterable datasets, or use torchdata, webdataset, DALI, etc.)
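
With webdataset, for instance, streaming from sharded tars is nearly a one-liner (a sketch; the shard pattern and component keys are placeholders):

import webdataset as wds

# Stream decoded (image, annotation) pairs straight from .tar shards.
dataset = (wds.WebDataset('shards/train-{000000..000031}.tar')
           .decode('rgb')
           .to_tuple('jpg', 'json'))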

Pitch

Vision tasks seem to require iterating over the entire dataset and building records prior to training (e.g. ObjectDetectionData). This does not make sense as a required step for large datasets. Say, for example, you want to compare models on a dataset of 10M images: requiring a pass over the whole dataset, potentially taking several hours before training starts, is an unnecessary and costly step. Users should be able to begin training immediately and have each sample from an iterable dataset provide the necessary information.
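
Concretely (a hypothetical sketch; read_shard is a made-up helper), each streamed sample can carry its image and annotations, so no up-front indexing pass is needed:

from torch.utils.data import IterableDataset

class StreamingDetectionDataset(IterableDataset):
    """Each record is self-describing, so training can start on sample one."""
    def __init__(self, shard_urls):
        self.shard_urls = shard_urls

    def __iter__(self):
        for url in self.shard_urls:
            # read_shard is a hypothetical helper that decodes records
            # (image, bboxes, labels) out of a shard one at a time.
            for image, bboxes, labels in read_shard(url):
                yield {'input': image, 'bboxes': bboxes, 'labels': labels}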

In my opinion, the lack of this capability prevents adoption of this library's vision tasks for large-scale image training in commercial settings.

Additional context

The lightning-bolts object detectors already seem to support this style of dataset (see the sketch below).
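
For example, something along these lines already trains with bolts (a sketch, assuming a torchvision-style (image, target) stream; my_sample_source is a placeholder):

import pytorch_lightning as pl
from torch.utils.data import DataLoader, IterableDataset
from pl_bolts.models.detection import FasterRCNN

class DetectionStream(IterableDataset):
    """Hypothetical stream of torchvision-style (image, target) pairs."""
    def __init__(self, sample_source):
        self.sample_source = sample_source

    def __iter__(self):
        yield from self.sample_source

loader = DataLoader(DetectionStream(my_sample_source),  # placeholder source
                    batch_size=16,
                    collate_fn=lambda batch: tuple(zip(*batch)))  # keep ragged targets

model = FasterRCNN(num_classes=91)
pl.Trainer(max_epochs=1).fit(model, loader)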

Links:
https://pytorch.org/blog/efficient-pytorch-io-library-for-large-datasets-many-files-many-gpus/
https://github.com/pytorch/data

@austinmw austinmw added enhancement New feature or request help wanted Extra attention is needed labels Mar 16, 2022

austinmw commented Mar 17, 2022

For a little more context, I'll paste below example code for a custom LightningDataModule. This datamodule uses DALI and the webdataset format. It works fine with pl_bolts object detectors, without modification to the data loading and with minimal modification to training_step. I'd still prefer to use flash detectors over bolts detectors, though, since there's a larger selection.

import os
import glob
import pickle
import numpy as np
import cv2
import torch
import torchvision.transforms as T
import warnings
import matplotlib.pyplot as plt
import matplotlib.gridspec as gridspec
import pytorch_lightning as pl
from pytorch_lightning.core.mixins.hparams_mixin import HyperparametersMixin
from nvidia import dali
from nvidia.dali import pipeline_def, types, fn
from nvidia.dali.plugin.pytorch import DALIGenericIterator, LastBatchPolicy


# Read label map (a dict mapping class index to name, e.g. 1: 'person', 2: 'car')
with open('coco_idx2label', 'rb') as f:
    idx2label = pickle.load(f)

# Get urls (.tar file paths)
train_dali_urls = sorted(glob.glob(os.path.join(os.getcwd(), 'coco_shards_dali', 'train*')))
val_dali_urls = sorted(glob.glob(os.path.join(os.getcwd(), 'coco_shards_dali', 'val*')))
# For example:
# ['/home/ubuntu/data/coco_shards_dali/train-000000.tar',
#  '/home/ubuntu/data/coco_shards_dali/train-000001.tar',
#  ...
#  '/home/ubuntu/data/coco_shards_dali/train-000031.tar']



class DataModuleClass(pl.LightningDataModule):
    def __init__(self, 
                 idx2label, 
                 train_urls,
                 val_urls=None,
                 batch_size=16,
                 num_workers=os.cpu_count() // torch.cuda.device_count(),
                 mean=[103.530, 116.280, 123.675],
                 std=[57.375, 57.120, 58.395],
                 seed=42):
        super().__init__()  # initialize LightningDataModule internals

        # Define required parameters here
        self.idx2label = idx2label
        self.train_urls = train_urls
        self.val_urls = val_urls
        self.batch_size = batch_size
        self.num_workers = num_workers
        self.mean = mean
        self.std = std
        self.seed = seed

        self.prepare_data_per_node = False
        self._log_hyperparams = False
    
    def prepare_data(self):
        # Define steps that should be done
        # on only one GPU, like getting data.
        pass
    
    def setup(self, stage=None):
        # Define steps that should be done on 
        # every GPU, like splitting data, applying
        # transform etc.        

        # Create train and val dataloaders
    
        if hasattr(self.trainer, 'local_rank'):
            device_id = self.trainer.local_rank
            shard_id = self.trainer.global_rank
            num_shards = self.trainer.world_size            
        else:
            warnings.warn('DataModule setup called before trainer init, using default device_id, shard_id, num_shards')
            device_id = 0
            shard_id = 0
            num_shards = 1
        
        train_pipe = self._wds_pipeline(urls=self.train_urls, 
                                        batch_size=self.batch_size,
                                        num_threads=self.num_workers,
                                        device='gpu',
                                        device_id=device_id, 
                                        shard_id=shard_id,
                                        num_shards=num_shards,
                                        random_shuffle=True,
                                        seed=self.seed,
                                        train=True)
                
        class LightningWrapper(DALIGenericIterator):
            """Unpack DALI's list-of-dicts output into a single dict batch."""

            def __next__(self):
                item = super().__next__()
                return {'images': item[0]['images'],
                        'bboxes': item[0]['bboxes'],
                        'labels': item[0]['labels']}
            
        self.train_loader = LightningWrapper(
            train_pipe,
            ['images', 'bboxes', 'labels'],
            reader_name='Reader',
            last_batch_policy=LastBatchPolicy.PARTIAL,
            auto_reset=True)
        
        if self.val_urls:
            val_pipe = self._wds_pipeline(urls=self.val_urls, 
                                            batch_size=self.batch_size,
                                            num_threads=self.num_workers,
                                            device='gpu',
                                            device_id=device_id, 
                                            shard_id=shard_id,
                                            num_shards=num_shards,
                                            random_shuffle=False,
                                            seed=self.seed,
                                            train=False)

            self.val_loader = LightningWrapper(
                val_pipe,
                ['images', 'bboxes', 'labels'],
                reader_name='Reader',
                last_batch_policy=LastBatchPolicy.PARTIAL,
                auto_reset=True)

    def train_dataloader(self):
        # Return DataLoader for Training Data here
        return self.train_loader

    def val_dataloader(self):
        # Return DataLoader for Validation Data here
        if self.val_urls is not None:
            return self.val_loader

    def _decode_augment(self, images, bboxes, labels, device, seed=0, fp16=True, train=True):
        # Boxes arrive as a flat array; reshape to (64, 4)
        bboxes = fn.reshape(bboxes, shape=[64, 4])

        # Adjust boxes due to rounding issues with the xyWH format
        bboxes = dali.math.clamp(bboxes, lo=0.0, hi=1.0)
        xy = bboxes[:, 0:2]
        wh = bboxes[:, 2:4]
        wh -= dali.math.max(0.0, (xy + wh) - 1.0)
        bboxes = fn.cat(xy, wh, axis=1)

        if train:
            aspect_ratio = [0.5, 2.0]
            thresholds = [0, 0.1, 0.3, 0.5, 0.7, 0.9]
            scaling = [0.3, 1.0]
        else:
            aspect_ratio = [1.0, 1.0]
            thresholds = [0.9]
            scaling = [1.0, 1.0]

        #input_shape = fn.slice(fn.cast(fn.peek_image_shape(images), dtype=types.INT32), 0, 2, axes=[0])
        crop_begin, crop_size, bboxes, labels = fn.random_bbox_crop(bboxes, labels,
                                                                    device='cpu',
                                                                    aspect_ratio=aspect_ratio,
                                                                    thresholds=thresholds,
                                                                    scaling=scaling,
                                                                    bbox_layout='xyWH',
                                                                    allow_no_crop=True,
                                                                    num_attempts=50)

        #images = fn.decoders.image(images, device='mixed', output_type=types.RGB)
        images = fn.decoders.image_slice(images, crop_begin, crop_size, 
                                         device='mixed' if device == 'gpu' else 'cpu',
                                         output_type=types.RGB)

        if train:
            flip_coin = fn.random.coin_flip(probability=0.5)
        else:
            flip_coin = fn.random.coin_flip(probability=0.0)

        images = fn.resize(images, resize_x=416, resize_y=416,
                           min_filter=types.DALIInterpType.INTERP_TRIANGULAR)

        if train:
            saturation = fn.random.uniform(range=[0.5, 1.5])
            contrast = fn.random.uniform(range=[0.5, 1.5])
            brightness = fn.random.uniform(range=[0.875, 1.125])
            hue = fn.random.uniform(range=[-0.5, 0.5])            

            # Use float to avoid clipping and quantizing the intermediate result
            images = fn.hsv(images, dtype=types.FLOAT, hue=hue, saturation=saturation)
            images = fn.brightness_contrast(images,
                                            contrast_center=128,  # input is float, but in the 0..255 range
                                            dtype=types.UINT8,
                                            brightness=brightness,
                                            contrast=contrast)

        dtype = types.FLOAT16 if fp16 else types.FLOAT

        bboxes = fn.bb_flip(bboxes, ltrb=False, horizontal=flip_coin)

        images = fn.crop_mirror_normalize(images,
                                          crop=(416, 416),
                                          mean=self.mean,
                                          std=self.std,
                                          mirror=flip_coin,
                                          dtype=dtype,
                                          output_layout='CHW',
                                          pad_output=False)
        # Scale boxes from relative [0, 1] coordinates to 416x416 pixel coordinates
        bboxes *= 416
        
        # Pad
        bboxes = fn.pad(bboxes, fill_value=0.0, axes=(0,), shape=(64,))
        labels = fn.pad(labels, fill_value=0.0, axes=(0,), shape=(64,))

        if device == 'gpu':
            labels = labels.gpu()
            bboxes = bboxes.gpu()
            
        # Cast to int
        bboxes = fn.cast(bboxes, dtype=types.INT64)
        labels = fn.cast(labels, dtype=types.INT64)  

        return images, bboxes, labels

    @pipeline_def
    def _wds_pipeline(self, 
                      urls,
                      device,
                      shard_id=0,
                      num_shards=1,
                      random_shuffle=True,
                      train=True):
        images, bboxes, labels = fn.readers.webdataset(
            paths=urls,
            shard_id=shard_id, 
            num_shards=num_shards, 
            random_shuffle=random_shuffle,
            #device='mixed' if device == 'gpu' else 'cpu',
            ext=['jpg', 'bboxes', 'labels'],
            missing_component_behavior='error',
            dtypes=[types.UINT8, types.FLOAT, types.INT32],
            seed=self.seed,
            name='Reader')

        return self._decode_augment(images, bboxes=bboxes, labels=labels, device=device, seed=self.seed, train=train)


# Instantiate the datamodule
datamodule = DataModuleClass(
    idx2label, 
    train_urls=train_dali_urls,
    val_urls=val_dali_urls, 
    batch_size=16,
)

# If you need information from the dataset to build your model, then run prepare_data() and setup() manually (Lightning ensures the method runs on the correct devices).
datamodule.prepare_data()
datamodule.setup(stage='fit')
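
With that, training against bolts looks roughly like this (a sketch; the dict-unpacking adapter is my own, and batch format details may need adjusting, e.g. box format, and FasterRCNN expecting a list of image tensors):

from pl_bolts.models.detection import FasterRCNN

class DictBatchFasterRCNN(FasterRCNN):
    """Hypothetical adapter: unpack the DALI dict batch into the
    (images, targets) tuple that the bolts detector's training_step expects."""
    def training_step(self, batch, batch_idx):
        images = list(batch['images'])  # split the batched tensor into a list
        targets = [{'boxes': b, 'labels': l}
                   for b, l in zip(batch['bboxes'], batch['labels'])]
        return super().training_step((images, targets), batch_idx)

model = DictBatchFasterRCNN(num_classes=len(idx2label))
trainer = pl.Trainer(gpus=1, max_epochs=1)
trainer.fit(model, datamodule=datamodule)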

@austinmw austinmw changed the title [Documentation/Feature request] Iterable dataset example [Feature request] Compatibility with iterable-style datasets Mar 17, 2022
ethanwharris (Collaborator) commented

Hi @austinmw, thanks for your request! This is a current limitation of certain tasks in Flash: they cannot be used directly with your own datamodule because the model needs to provide the collate function for the data. IceVision models are more complex still, in that they need to provide the dataloader in full. I think it should be possible for us to find a workaround there, as this would be a great use case to support 😃

@ethanwharris ethanwharris added this to the 0.8.0 milestone Mar 21, 2022

stale bot commented Jun 5, 2022

This issue has been automatically marked as stale because it has not had recent activity. It will be closed if no further activity occurs. Thank you for your contributions.

@stale stale bot added the won't fix This will not be worked on label Jun 5, 2022

austinmw commented Jun 6, 2022

Not stale

@stale stale bot removed the won't fix This will not be worked on label Jun 6, 2022
@krshrimali krshrimali self-assigned this Jun 29, 2022

ethanwharris commented Jun 29, 2022

Hey @austinmw, just to give you an update: we have resolved most of the issues in the framework needed to support your use case, and now just need to document it properly and ship it in our upcoming 0.8 release. Can't give an exact timeline, but we're aiming for weeks rather than months. I'll come back here when I can give an updated code snippet to make this work 😃

austinmw (Author) commented

Awesome news, can't wait to see, thanks!

@ethanwharris ethanwharris modified the milestones: 0.8.0, 0.9.0 Sep 1, 2022