-
Notifications
You must be signed in to change notification settings - Fork 9
Commit
This commit does not belong to any branch on this repository, and may belong to a fork outside of the repository.
- Loading branch information
Showing
10 changed files
with
1,213 additions
and
22 deletions.
There are no files selected for viewing
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,35 @@ | ||
# table-transformer baseline | ||
|
||
The original repository: https://github.com/microsoft/table-transformer | ||
|
||
We use the checkpoints and code from HuggingFace: | ||
https://huggingface.co/microsoft/table-transformer-detection | ||
|
||
As table-transformer is simply a DETR pretrained on [PubTables-1M](https://github.com/microsoft/table-transformer) and pretraining on additional document data is not allowed for DocILE, we start the training from a general [DETR checkpoint](https://huggingface.co/facebook/detr-resnet-50). | ||
|
||
## Running training and evaluation | ||
|
||
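We use the [pytorch-lightning](https://pytorch-lightning.readthedocs.io/) CLI to run the experiments. The first command below trains the table detection model and the second one trains the line item detection model: | ||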
|
||
``` | ||
CUDA_VISIBLE_DEVICES=0 poetry run python main.py fit --config table_detection_config.yaml | ||
``` | ||
|
||
``` | ||
CUDA_VISIBLE_DEVICES=0 poetry run python main.py fit --config table_line_item_detection_config.yaml | ||
``` | ||
|
||
To get table bounding box predictions, run e.g. | ||
|
||
``` | ||
CUDA_VISIBLE_DEVICES=0 poetry run python predict.py --checkpoint_path /app/data/baselines/checkpoints/detr_table_20140.ckpt --prediction_type table_detection --dataset_path /app/data/docile/ --split val --output_json_path /app/data/baselines/line_item_detection/table_transformer/predictions/val/detr_table_detection.json | ||
``` | ||
|
||
and e.g. | ||
|
||
``` | ||
CUDA_VISIBLE_DEVICES=0 poetry run python predict.py --checkpoint_path /app/data/baselines/checkpoints/detr_LI_170100.ckpt --prediction_type table_line_item_detection --dataset_path /app/data/docile/ --split val --output_json_path /app/data/baselines/line_item_detection/table_transformer/predictions/val/detr_table_line_item_detection.json --table_detection_predictions_pickle /app/data/baselines/line_item_detection/table_transformer/predictions/val/detr_table_detection.pickle | ||
``` | ||
|
||
to get line item bounding boxes. The table detection predictions and the line item predictions can then be passed to [../NER/docile_inference_NER_multilabel.py](../NER/docile_inference_NER_multilabel.py) as `--crop_bboxes_filename` and `--line_item_bboxes_filename`, respectively. |
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,28 @@ | ||
# pytorch_lightning==1.8.0.post1 | ||
trainer: | ||
default_root_dir: /app/data/baselines/line_item_detection/table_transformer/ | ||
gradient_clip_val: 0.1 | ||
max_epochs: 1000 | ||
log_every_n_steps: 20 | ||
accelerator: gpu | ||
devices: 1 | ||
precision: 16 | ||
callbacks: | ||
- class_path: "pytorch_lightning.callbacks.LearningRateMonitor" | ||
init_args: | ||
logging_interval: step | ||
- class_path: "pytorch_lightning.callbacks.ModelCheckpoint" | ||
init_args: | ||
save_top_k: 1 | ||
save_last: true | ||
monitor: val_loss | ||
every_n_epochs: 1 | ||
model: | ||
lr: 3.0e-05 | ||
lr_backbone: 3.0e-07 | ||
weight_decay: 0.0001 | ||
threshold: 0.5 | ||
batch_size: 32 | ||
train_dataset_name: train | ||
val_dataset_name: val | ||
ckpt_path: null |
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,10 @@ | ||
import table_transformer | ||
from pytorch_lightning.cli import LightningCLI | ||
|
||
|
||
def cli_main():
    """Run the PyTorch Lightning command-line interface for ``TableDetr``.

    Instantiating ``LightningCLI`` is what parses the command line (e.g.
    ``fit --config <file>.yaml``) and runs the requested subcommand, so the
    created object does not need to be kept.
    """
    LightningCLI(table_transformer.TableDetr)


if __name__ == "__main__":
    cli_main()
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,117 @@ | ||
import argparse | ||
import json | ||
import logging | ||
import pickle | ||
from collections import defaultdict | ||
|
||
import table_transformer | ||
import torch | ||
from pytorch_lightning import Trainer | ||
from torch.utils.data import DataLoader | ||
|
||
import docile | ||
|
||
|
||
def pickle_path_for_json(json_path):
    """Return the path of the pickle file saved beside ``json_path``.

    Helper for ``main``. A trailing ``.json`` extension is replaced by
    ``.pickle``; if the path has no such extension, ``.pickle`` is appended.

    ``str.strip(".json")`` must not be used here: it removes any of the
    characters ``.``, ``j``, ``s``, ``o``, ``n`` from both ends of the string,
    so e.g. ``detr_table_detection.json`` would become ``detr_table_detecti``.
    """
    suffix = ".json"
    stem = json_path[: -len(suffix)] if json_path.endswith(suffix) else json_path
    return stem + ".pickle"


def main():
    """Run a trained ``TableDetr`` checkpoint on a DocILE split and save predictions.

    Command-line arguments:
        --checkpoint_path: checkpoint to load the model from.
        --prediction_type: ``table_detection`` or ``table_line_item_detection``.
        --dataset_path, --split: which DocILE dataset/split to predict on.
        --output_json_path: where to write the ``doc_id -> page_n -> predictions``
            dictionary. For ``table_detection`` a pickle with one entry per
            page is additionally written beside it.
        --batch_size, --num_workers: ``DataLoader`` settings.
        --table_detection_predictions_pickle: pickle produced by a previous
            ``table_detection`` run; required for ``table_line_item_detection``.

    Raises:
        ValueError: if ``table_line_item_detection`` is requested without
            ``--table_detection_predictions_pickle``.
    """
    parser = argparse.ArgumentParser()
    parser.add_argument("--checkpoint_path", required=True)
    parser.add_argument(
        "--prediction_type",
        choices=["table_detection", "table_line_item_detection"],
        required=True,
    )
    parser.add_argument("--dataset_path", required=True)
    parser.add_argument("--split", choices=["train", "val", "test"], required=True)
    parser.add_argument(
        "--output_json_path",
        required=True,
        help="A dictionary doc_id->page_n->predictions. With --prediction_type "
        "table_detection, a pickle file with a list of bboxes (one for each "
        "page) will be saved beside the json file.",
    )
    # type=int so that values given on the command line are not passed to the
    # DataLoader as strings (only the defaults were integers before).
    parser.add_argument("--batch_size", type=int, default=16)
    parser.add_argument("--num_workers", type=int, default=16)
    parser.add_argument(
        "--table_detection_predictions_pickle",
        default=None,
        help="Path to a file created by this script for a table detection model.",
    )
    args = parser.parse_args()

    if (
        args.prediction_type == "table_line_item_detection"
        and args.table_detection_predictions_pickle is None
    ):
        raise ValueError(
            "--prediction_type table_line_item_detection requires --table_detection_predictions_pickle"
        )

    if torch.cuda.is_available():
        accelerator = "gpu"
    else:
        logging.warning("CUDA not available, predicting on CPU")
        accelerator = "cpu"

    table_detr = table_transformer.TableDetr.load_from_checkpoint(args.checkpoint_path)

    docile_dataset = docile.dataset.Dataset(args.split, args.dataset_path)

    crop_bboxes = None
    dataset_not_cropped = None
    if args.table_detection_predictions_pickle is not None:
        # NOTE(security): pickle.load can execute arbitrary code. Only pass
        # files created by this script (or from another trusted source).
        with open(args.table_detection_predictions_pickle, "rb") as fin:
            crop_bboxes = pickle.load(fin)
        # The un-cropped dataset is used below to enumerate all pages when
        # writing line item predictions.
        dataset_not_cropped = table_transformer.TableTransformerDataset(
            docile_dataset=docile_dataset, extractor=table_detr.extractor
        )

    dataset = table_transformer.TableTransformerDataset(
        docile_dataset=docile_dataset, extractor=table_detr.extractor, crop_bboxes=crop_bboxes
    )

    evaluator = Trainer(accelerator=accelerator, devices=1)
    batched_predictions = evaluator.predict(
        table_detr,
        DataLoader(
            dataset,
            collate_fn=table_detr.collate_fn,
            batch_size=args.batch_size,
            num_workers=args.num_workers,
        ),
    )
    # Flatten the per-batch outputs into a single list of predictions.
    res_detection = []
    for batch_predictions in batched_predictions:
        res_detection.extend(batch_predictions)

    bbox_dict = defaultdict(lambda: defaultdict(list))

    if args.prediction_type == "table_detection":
        # Keep only the first predicted bbox of each entry, or None when
        # nothing was predicted.
        res_detection = [list(a[0]) if len(a) else None for a in res_detection]
        with open(pickle_path_for_json(args.output_json_path), "wb") as fout:
            pickle.dump(res_detection, fout)
        for ((doc_id, page_n), _), bbox in zip(dataset.coco_annots, res_detection):
            bbox_dict[doc_id][page_n] = [round(x) for x in bbox] if bbox else None

    elif args.prediction_type == "table_line_item_detection":
        assert len(res_detection) == len(dataset.crop_bboxes)
        # Shift the predictions by the left/top corner of the crop bbox
        # (presumably this maps crop-relative coordinates back to page
        # coordinates; verify against TableTransformerDataset).
        for bboxes, crop_bbox in zip(res_detection, dataset.crop_bboxes):
            for bbox in bboxes:
                # left offset
                bbox[0] += crop_bbox[0]
                bbox[2] += crop_bbox[0]
                # top offset
                bbox[1] += crop_bbox[1]
                bbox[3] += crop_bbox[1]

        # output predictions for all pages, not only for those with predictions
        for (doc_id, page_n), _ in dataset_not_cropped.coco_annots:
            bbox_dict[doc_id][page_n] = []

        for ((doc_id, page_n), _), bboxes in zip(dataset.coco_annots, res_detection):
            bbox_dict[doc_id][page_n] = [[round(x) for x in bbox] for bbox in bboxes]

    with open(args.output_json_path, "w") as fout:
        json.dump(bbox_dict, fout)


if __name__ == "__main__":
    main()
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,90 @@ | ||
# pytorch_lightning==1.8.6 | ||
seed_everything: true | ||
trainer: | ||
logger: true | ||
enable_checkpointing: true | ||
callbacks: | ||
- class_path: pytorch_lightning.callbacks.LearningRateMonitor | ||
init_args: | ||
logging_interval: step | ||
log_momentum: false | ||
- class_path: pytorch_lightning.callbacks.ModelCheckpoint | ||
init_args: | ||
dirpath: null | ||
filename: null | ||
monitor: val_loss | ||
verbose: false | ||
save_last: true | ||
save_top_k: 1 | ||
save_weights_only: false | ||
mode: min | ||
auto_insert_metric_name: true | ||
every_n_train_steps: null | ||
train_time_interval: null | ||
every_n_epochs: 1 | ||
save_on_train_epoch_end: null | ||
default_root_dir: /app/data/baselines/line_item_detection/table_transformer/ | ||
gradient_clip_val: 0.1 | ||
gradient_clip_algorithm: null | ||
num_nodes: 1 | ||
num_processes: null | ||
devices: 1 | ||
gpus: null | ||
auto_select_gpus: false | ||
tpu_cores: null | ||
ipus: null | ||
enable_progress_bar: true | ||
overfit_batches: 0.0 | ||
track_grad_norm: -1 | ||
check_val_every_n_epoch: 1 | ||
fast_dev_run: false | ||
accumulate_grad_batches: null | ||
max_epochs: 1000 | ||
min_epochs: null | ||
max_steps: -1 | ||
min_steps: null | ||
max_time: null | ||
limit_train_batches: null | ||
limit_val_batches: null | ||
limit_test_batches: null | ||
limit_predict_batches: null | ||
val_check_interval: null | ||
log_every_n_steps: 20 | ||
accelerator: gpu | ||
strategy: null | ||
sync_batchnorm: false | ||
precision: 16 | ||
enable_model_summary: true | ||
num_sanity_val_steps: 2 | ||
resume_from_checkpoint: null | ||
profiler: null | ||
benchmark: null | ||
deterministic: null | ||
reload_dataloaders_every_n_epochs: 0 | ||
auto_lr_find: false | ||
replace_sampler_ddp: true | ||
detect_anomaly: false | ||
auto_scale_batch_size: false | ||
plugins: null | ||
amp_backend: native | ||
amp_level: null | ||
move_metrics_to_cpu: false | ||
multiple_trainloader_mode: max_size_cycle | ||
inference_mode: true | ||
model: | ||
description: Fine-tuning for table detection (full table bbox) | ||
train_dataset_name: train | ||
val_dataset_name: val | ||
task: table-detection | ||
initial_checkpoint: facebook/detr-resnet-50 | ||
dataset_path: /app/data/docile/ | ||
load_ground_truth_crop_bboxes: false | ||
predictions_root_dir: /app/data/baselines/line_item_detection/table_transformer/predictions/ | ||
crop_bboxes_filename: null | ||
lr: 3.0e-05 | ||
lr_backbone: 3.0e-07 | ||
weight_decay: 0.0001 | ||
batch_size: 32 | ||
num_workers: 16 | ||
threshold: 0.5 | ||
ckpt_path: null |
90 changes: 90 additions & 0 deletions
90
baselines/table-transformer/table_line_item_detection_config.yaml
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,90 @@ | ||
# pytorch_lightning==1.8.6 | ||
seed_everything: true | ||
trainer: | ||
logger: true | ||
enable_checkpointing: true | ||
callbacks: | ||
- class_path: pytorch_lightning.callbacks.LearningRateMonitor | ||
init_args: | ||
logging_interval: step | ||
log_momentum: false | ||
- class_path: pytorch_lightning.callbacks.ModelCheckpoint | ||
init_args: | ||
dirpath: null | ||
filename: null | ||
monitor: val_loss | ||
verbose: false | ||
save_last: true | ||
save_top_k: 1 | ||
save_weights_only: false | ||
mode: min | ||
auto_insert_metric_name: true | ||
every_n_train_steps: null | ||
train_time_interval: null | ||
every_n_epochs: 1 | ||
save_on_train_epoch_end: null | ||
default_root_dir: /app/data/baselines/line_item_detection/table_transformer/ | ||
gradient_clip_val: 0.1 | ||
gradient_clip_algorithm: null | ||
num_nodes: 1 | ||
num_processes: null | ||
devices: 1 | ||
gpus: null | ||
auto_select_gpus: false | ||
tpu_cores: null | ||
ipus: null | ||
enable_progress_bar: true | ||
overfit_batches: 0.0 | ||
track_grad_norm: -1 | ||
check_val_every_n_epoch: 1 | ||
fast_dev_run: false | ||
accumulate_grad_batches: null | ||
max_epochs: 1000 | ||
min_epochs: null | ||
max_steps: -1 | ||
min_steps: null | ||
max_time: null | ||
limit_train_batches: null | ||
limit_val_batches: null | ||
limit_test_batches: null | ||
limit_predict_batches: null | ||
val_check_interval: null | ||
log_every_n_steps: 20 | ||
accelerator: gpu | ||
strategy: null | ||
sync_batchnorm: false | ||
precision: 16 | ||
enable_model_summary: true | ||
num_sanity_val_steps: 2 | ||
resume_from_checkpoint: null | ||
profiler: null | ||
benchmark: null | ||
deterministic: null | ||
reload_dataloaders_every_n_epochs: 0 | ||
auto_lr_find: false | ||
replace_sampler_ddp: true | ||
detect_anomaly: false | ||
auto_scale_batch_size: false | ||
plugins: null | ||
amp_backend: native | ||
amp_level: null | ||
move_metrics_to_cpu: false | ||
multiple_trainloader_mode: max_size_cycle | ||
inference_mode: true | ||
model: | ||
description: Fine-tuning DETR for line item detection on GT cropped tables (full) | ||
train_dataset_name: train | ||
val_dataset_name: val | ||
task: table-line-item-detection | ||
initial_checkpoint: facebook/detr-resnet-50 | ||
dataset_path: /app/data/docile/ | ||
load_ground_truth_crop_bboxes: false | ||
predictions_root_dir: /app/data/baselines/line_item_detection/table_transformer/predictions/ | ||
crop_bboxes_filename: null | ||
lr: 3.0e-05 | ||
lr_backbone: 3.0e-07 | ||
weight_decay: 0.0001 | ||
batch_size: 32 | ||
num_workers: 16 | ||
threshold: 0.5 | ||
ckpt_path: null |
Oops, something went wrong.