diff --git a/doctr/datasets/loader.py b/doctr/datasets/loader.py
index 75799761fe..e55d902890 100644
--- a/doctr/datasets/loader.py
+++ b/doctr/datasets/loader.py
@@ -46,7 +46,7 @@ class DataLoader:
         shuffle: whether the samples should be shuffled before passing it to the iterator
         batch_size: number of elements in each batch
         drop_last: if `True`, drops the last batch if it isn't full
-        workers: number of workers to use for data loading
+        num_workers: number of workers to use for data loading
     """
 
     def __init__(
@@ -55,7 +55,7 @@
         shuffle: bool = True,
         batch_size: int = 1,
         drop_last: bool = False,
-        workers: Optional[int] = None,
+        num_workers: Optional[int] = None,
         collate_fn: Optional[Callable] = None,
     ) -> None:
         self.dataset = dataset
@@ -67,7 +67,7 @@ def __init__(
             self.collate_fn = self.dataset.collate_fn if hasattr(self.dataset, 'collate_fn') else default_collate
         else:
             self.collate_fn = collate_fn
-        self.workers = workers
+        self.num_workers = num_workers
         self.reset()
 
     def __len__(self) -> int:
@@ -90,7 +90,7 @@ def __next__(self):
         idx = self._num_yielded * self.batch_size
         indices = self.indices[idx: min(len(self.dataset), idx + self.batch_size)]
 
-        samples = multithread_exec(self.dataset.__getitem__, indices, threads=self.workers)
+        samples = multithread_exec(self.dataset.__getitem__, indices, threads=self.num_workers)
 
         batch_data = self.collate_fn(samples)
 
diff --git a/references/classification/train_tensorflow.py b/references/classification/train_tensorflow.py
index 8849b6aad4..3848ff2e8d 100644
--- a/references/classification/train_tensorflow.py
+++ b/references/classification/train_tensorflow.py
@@ -102,7 +102,7 @@ def main(args):
         batch_size=args.batch_size,
         shuffle=False,
         drop_last=False,
-        workers=args.workers,
+        num_workers=args.workers,
         collate_fn=collate_fn,
     )
     print(f"Validation set loaded in {time.time() - st:.4}s ({len(val_set)} samples in "
@@ -153,7 +153,7 @@
         batch_size=args.batch_size,
         shuffle=True,
         drop_last=True,
-        workers=args.workers,
+        num_workers=args.workers,
         collate_fn=collate_fn,
     )
     print(f"Train set loaded in {time.time() - st:.4}s ({len(train_set)} samples in "
diff --git a/references/detection/train_tensorflow.py b/references/detection/train_tensorflow.py
index 4cec38a809..32d4094fe0 100644
--- a/references/detection/train_tensorflow.py
+++ b/references/detection/train_tensorflow.py
@@ -87,7 +87,13 @@ def main(args):
         label_path=os.path.join(args.val_path, 'labels.json'),
         sample_transforms=T.Resize((args.input_size, args.input_size)),
     )
-    val_loader = DataLoader(val_set, batch_size=args.batch_size, shuffle=False, drop_last=False, workers=args.workers)
+    val_loader = DataLoader(
+        val_set,
+        batch_size=args.batch_size,
+        shuffle=False,
+        drop_last=False,
+        num_workers=args.workers,
+    )
     print(f"Validation set loaded in {time.time() - st:.4}s ({len(val_set)} samples in "
           f"{val_loader.num_batches} batches)")
     with open(os.path.join(args.val_path, 'labels.json'), 'rb') as f:
@@ -133,7 +139,13 @@
             T.RandomBrightness(.3),
         ]),
     )
-    train_loader = DataLoader(train_set, batch_size=args.batch_size, shuffle=True, drop_last=True, workers=args.workers)
+    train_loader = DataLoader(
+        train_set,
+        batch_size=args.batch_size,
+        shuffle=True,
+        drop_last=True,
+        num_workers=args.workers,
+    )
     print(f"Train set loaded in {time.time() - st:.4}s ({len(train_set)} samples in "
           f"{train_loader.num_batches} batches)")
     with open(os.path.join(args.train_path, 'labels.json'), 'rb') as f:
diff --git a/references/recognition/train_tensorflow.py b/references/recognition/train_tensorflow.py
index b4703ff852..c2355b58d2 100644
--- a/references/recognition/train_tensorflow.py
+++ b/references/recognition/train_tensorflow.py
@@ -90,7 +90,13 @@ def main(args):
         labels_path=os.path.join(args.val_path, 'labels.json'),
         sample_transforms=T.Resize((args.input_size, 4 * args.input_size), preserve_aspect_ratio=True),
     )
-    val_loader = DataLoader(val_set, batch_size=args.batch_size, shuffle=False, drop_last=False, workers=args.workers)
+    val_loader = DataLoader(
+        val_set,
+        batch_size=args.batch_size,
+        shuffle=False,
+        drop_last=False,
+        num_workers=args.workers,
+    )
     print(f"Validation set loaded in {time.time() - st:.4}s ({len(val_set)} samples in "
           f"{val_loader.num_batches} batches)")
     with open(os.path.join(args.val_path, 'labels.json'), 'rb') as f:
@@ -144,7 +150,13 @@
     for subfolder in parts[1:]:
         train_set.merge_dataset(RecognitionDataset(subfolder.joinpath('images'), subfolder.joinpath('labels.json')))
 
-    train_loader = DataLoader(train_set, batch_size=args.batch_size, shuffle=True, drop_last=True, workers=args.workers)
+    train_loader = DataLoader(
+        train_set,
+        batch_size=args.batch_size,
+        shuffle=True,
+        drop_last=True,
+        num_workers=args.workers,
+    )
     print(f"Train set loaded in {time.time() - st:.4}s ({len(train_set)} samples in "
           f"{train_loader.num_batches} batches)")
     with open(parts[0].joinpath('labels.json'), 'rb') as f: