refactor: Renamed DataLoader arg "workers" into "num_workers" (#737)
* refactor: Renamed DataLoader arg

* refactor: Reflected changes
fg-mindee committed Dec 22, 2021
1 parent 808081f commit dcaf0a3
Showing 4 changed files with 34 additions and 10 deletions.
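For callers, this is a breaking keyword rename: DataLoader no longer accepts workers=. A minimal before/after sketch (the ToyDataset stub is illustrative, assuming the usual doctr.datasets import path):

```python
from doctr.datasets import DataLoader

class ToyDataset:
    """Stub dataset: anything exposing __len__ and __getitem__ works."""
    def __len__(self):
        return 32

    def __getitem__(self, idx):
        return idx  # a real dataset would return (image, target)

# Before this commit, the thread count was passed as `workers`:
# loader = DataLoader(ToyDataset(), batch_size=8, shuffle=True, workers=4)

# After this commit, the same option is spelled `num_workers`:
loader = DataLoader(ToyDataset(), batch_size=8, shuffle=True, num_workers=4)
```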
8 changes: 4 additions & 4 deletions doctr/datasets/loader.py
@@ -46,7 +46,7 @@ class DataLoader:
         shuffle: whether the samples should be shuffled before passing it to the iterator
         batch_size: number of elements in each batch
         drop_last: if `True`, drops the last batch if it isn't full
-        workers: number of workers to use for data loading
+        num_workers: number of workers to use for data loading
         """

     def __init__(
@@ -55,7 +55,7 @@ def __init__(
         shuffle: bool = True,
         batch_size: int = 1,
         drop_last: bool = False,
-        workers: Optional[int] = None,
+        num_workers: Optional[int] = None,
         collate_fn: Optional[Callable] = None,
     ) -> None:
         self.dataset = dataset
@@ -67,7 +67,7 @@ def __init__(
             self.collate_fn = self.dataset.collate_fn if hasattr(self.dataset, 'collate_fn') else default_collate
         else:
             self.collate_fn = collate_fn
-        self.workers = workers
+        self.num_workers = num_workers
         self.reset()

     def __len__(self) -> int:
@@ -90,7 +90,7 @@ def __next__(self):
             idx = self._num_yielded * self.batch_size
             indices = self.indices[idx: min(len(self.dataset), idx + self.batch_size)]

-            samples = multithread_exec(self.dataset.__getitem__, indices, threads=self.workers)
+            samples = multithread_exec(self.dataset.__getitem__, indices, threads=self.num_workers)

             batch_data = self.collate_fn(samples)

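Note that the loader delegates its parallelism to multithread_exec, so num_workers controls a thread count rather than worker processes. A rough sketch of the behavior that call relies on (an assumption inferred from the usage above, not the actual doctr.utils implementation):

```python
from concurrent.futures import ThreadPoolExecutor
from typing import Any, Callable, Iterable, List, Optional

def multithread_exec_sketch(
    func: Callable[[Any], Any],
    seq: Iterable[Any],
    threads: Optional[int] = None,
) -> List[Any]:
    """Map `func` over `seq`, using a thread pool when `threads` is given."""
    if threads is None or threads < 2:
        # Sequential fallback: no pool overhead for a single worker
        return [func(x) for x in seq]
    with ThreadPoolExecutor(max_workers=threads) as pool:
        return list(pool.map(func, seq))
```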
4 changes: 2 additions & 2 deletions references/classification/train_tensorflow.py
@@ -102,7 +102,7 @@ def main(args):
         batch_size=args.batch_size,
         shuffle=False,
         drop_last=False,
-        workers=args.workers,
+        num_workers=args.workers,
         collate_fn=collate_fn,
     )
     print(f"Validation set loaded in {time.time() - st:.4}s ({len(val_set)} samples in "
@@ -153,7 +153,7 @@ def main(args):
         batch_size=args.batch_size,
         shuffle=True,
         drop_last=True,
-        workers=args.workers,
+        num_workers=args.workers,
         collate_fn=collate_fn,
     )
     print(f"Train set loaded in {time.time() - st:.4}s ({len(train_set)} samples in "
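The CLI surface of the training scripts is unchanged: the flag behind args.workers keeps its name, and only the keyword forwarded to DataLoader differs. A sketch of the presumed argparse wiring (names inferred from args.workers, not copied from the repository):

```python
import argparse

parser = argparse.ArgumentParser(description="doctr training script (sketch)")
# The CLI flag can stay `--workers`; only the DataLoader keyword changed.
parser.add_argument("--workers", type=int, default=None,
                    help="number of threads used by the DataLoader")
args = parser.parse_args(["--workers", "4"])

# Call sites now forward the value as `num_workers`:
# DataLoader(train_set, batch_size=args.batch_size, num_workers=args.workers)
```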
16 changes: 14 additions & 2 deletions references/detection/train_tensorflow.py
@@ -87,7 +87,13 @@ def main(args):
         label_path=os.path.join(args.val_path, 'labels.json'),
         sample_transforms=T.Resize((args.input_size, args.input_size)),
     )
-    val_loader = DataLoader(val_set, batch_size=args.batch_size, shuffle=False, drop_last=False, workers=args.workers)
+    val_loader = DataLoader(
+        val_set,
+        batch_size=args.batch_size,
+        shuffle=False,
+        drop_last=False,
+        num_workers=args.workers,
+    )
     print(f"Validation set loaded in {time.time() - st:.4}s ({len(val_set)} samples in "
           f"{val_loader.num_batches} batches)")
     with open(os.path.join(args.val_path, 'labels.json'), 'rb') as f:
@@ -133,7 +139,13 @@ def main(args):
             T.RandomBrightness(.3),
         ]),
     )
-    train_loader = DataLoader(train_set, batch_size=args.batch_size, shuffle=True, drop_last=True, workers=args.workers)
+    train_loader = DataLoader(
+        train_set,
+        batch_size=args.batch_size,
+        shuffle=True,
+        drop_last=True,
+        num_workers=args.workers,
+    )
     print(f"Train set loaded in {time.time() - st:.4}s ({len(train_set)} samples in "
           f"{train_loader.num_batches} batches)")
     with open(os.path.join(args.train_path, 'labels.json'), 'rb') as f:
16 changes: 14 additions & 2 deletions references/recognition/train_tensorflow.py
@@ -90,7 +90,13 @@ def main(args):
         labels_path=os.path.join(args.val_path, 'labels.json'),
         sample_transforms=T.Resize((args.input_size, 4 * args.input_size), preserve_aspect_ratio=True),
     )
-    val_loader = DataLoader(val_set, batch_size=args.batch_size, shuffle=False, drop_last=False, workers=args.workers)
+    val_loader = DataLoader(
+        val_set,
+        batch_size=args.batch_size,
+        shuffle=False,
+        drop_last=False,
+        num_workers=args.workers,
+    )
     print(f"Validation set loaded in {time.time() - st:.4}s ({len(val_set)} samples in "
           f"{val_loader.num_batches} batches)")
     with open(os.path.join(args.val_path, 'labels.json'), 'rb') as f:
@@ -144,7 +150,13 @@ def main(args):
         for subfolder in parts[1:]:
             train_set.merge_dataset(RecognitionDataset(subfolder.joinpath('images'), subfolder.joinpath('labels.json')))

-    train_loader = DataLoader(train_set, batch_size=args.batch_size, shuffle=True, drop_last=True, workers=args.workers)
+    train_loader = DataLoader(
+        train_set,
+        batch_size=args.batch_size,
+        shuffle=True,
+        drop_last=True,
+        num_workers=args.workers,
+    )
     print(f"Train set loaded in {time.time() - st:.4}s ({len(train_set)} samples in "
           f"{train_loader.num_batches} batches)")
     with open(parts[0].joinpath('labels.json'), 'rb') as f:
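This is a hard rename: any external code still passing workers= will fail with a TypeError after upgrading. For comparison, a softer migration pattern (not what this commit does) would accept both spellings for one release:

```python
import warnings
from typing import Optional

class DataLoaderShim:
    """Hypothetical transitional API: accepts the deprecated `workers`
    alias alongside the new `num_workers` keyword."""

    def __init__(
        self,
        num_workers: Optional[int] = None,
        workers: Optional[int] = None,  # deprecated alias
    ) -> None:
        if workers is not None:
            warnings.warn(
                "`workers` is deprecated, use `num_workers` instead",
                DeprecationWarning,
            )
            if num_workers is None:
                num_workers = workers
        self.num_workers = num_workers
```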
