Commit: Add table-transformer baseline

Matěj Kocián authored and simsa-st committed Mar 15, 2023
1 parent fdd647a commit 9da2379

Showing 10 changed files with 1,213 additions and 22 deletions.
2 changes: 1 addition & 1 deletion baselines/README.md
@@ -31,7 +31,7 @@ We provide code to reproduce all results that can be found in the [paper](../REA
The code is structured into three subfolders:
* [NER](NER/) contains most of the baseline code, including training code for RoBERTa and LayoutLMv3, code for RoBERTa pretraining, and the inference code.
* [layoutlmv3_pretrain](layoutlmv3_pretrain/) contains code for LayoutLMv3 pretraining.
- * [table-transformer](table-transformer/) [coming soon] contains code for DETR used for table and Line Item detection.
+ * [table-transformer](table-transformer/) contains code for DETR used for table and Line Item detection.

## Results of the provided baselines

35 changes: 35 additions & 0 deletions baselines/table-transformer/README.md
@@ -0,0 +1,35 @@
# table-transformer baseline

The original repository: https://github.com/microsoft/table-transformer

We use the checkpoints and code from HuggingFace:
https://huggingface.co/microsoft/table-transformer-detection

As table-transformer is simply a DETR pretrained on [PubTables-1M](https://github.com/microsoft/table-transformer), and pretraining on additional document data is not allowed for DocILE, we start the training from a general [DETR checkpoint](https://huggingface.co/facebook/detr-resnet-50).
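For illustration, this checkpoint can be loaded with the standard HuggingFace `transformers` API; the following is a minimal sketch (not code from this repository), with the label count as a task-specific assumption:

```
from transformers import DetrForObjectDetection, DetrImageProcessor

# Load the general DETR checkpoint that training starts from.
processor = DetrImageProcessor.from_pretrained("facebook/detr-resnet-50")
model = DetrForObjectDetection.from_pretrained(
    "facebook/detr-resnet-50",
    num_labels=1,  # e.g. a single "table" class; the actual label set is task-specific
    ignore_mismatched_sizes=True,  # re-initialize the classification head for the new labels
)
```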

## Running training and evaluation

We use [pytorch-lightning](https://pytorch-lightning.readthedocs.io/) CLI for running the experiments:

```
CUDA_VISIBLE_DEVICES=0 poetry run python main.py fit --config table_detection_config.yaml
```

```
CUDA_VISIBLE_DEVICES=0 poetry run python main.py fit --config table_line_item_detection_config.yaml
```
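Individual values from the YAML config can also be overridden directly on the command line (standard LightningCLI behavior), e.g.:

```
CUDA_VISIBLE_DEVICES=0 poetry run python main.py fit --config table_detection_config.yaml --trainer.max_epochs 100
```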

To get table bounding box predictions, run e.g.:

```
CUDA_VISIBLE_DEVICES=0 poetry run python predict.py --checkpoint_path /app/data/baselines/checkpoints/detr_table_20140.ckpt --prediction_type table_detection --dataset_path /app/data/docile/ --split val --output_json_path /app/data/baselines/line_item_detection/table_transformer/predictions/val/detr_table_detection.json
```

To get line item bounding boxes, run e.g.:

```
CUDA_VISIBLE_DEVICES=0 poetry run python predict.py --checkpoint_path /app/data/baselines/checkpoints/detr_LI_170100.ckpt --prediction_type table_line_item_detection --dataset_path /app/data/docile/ --split val --output_json_path /app/data/baselines/line_item_detection/table_transformer/predictions/val/detr_table_line_item_detection.json --table_detection_predictions_pickle /app/data/baselines/line_item_detection/table_transformer/predictions/val/detr_table_detection.pickle
```

The resulting table and line item predictions can then be passed to [../NER/docile_inference_NER_multilabel.py](../NER/docile_inference_NER_multilabel.py) as `--crop_bboxes_filename` and `--line_item_bboxes_filename`, respectively.
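For example (a hypothetical invocation; see the NER script's `--help` for its full argument list):

```
poetry run python ../NER/docile_inference_NER_multilabel.py \
    --crop_bboxes_filename detr_table_detection.json \
    --line_item_bboxes_filename detr_table_line_item_detection.json
```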
28 changes: 28 additions & 0 deletions baselines/table-transformer/base_config.yaml
@@ -0,0 +1,28 @@
# pytorch_lightning==1.8.0.post1
trainer:
  default_root_dir: /app/data/baselines/line_item_detection/table_transformer/
  gradient_clip_val: 0.1
  max_epochs: 1000
  log_every_n_steps: 20
  accelerator: gpu
  devices: 1
  precision: 16
  callbacks:
    - class_path: "pytorch_lightning.callbacks.LearningRateMonitor"
      init_args:
        logging_interval: step
    - class_path: "pytorch_lightning.callbacks.ModelCheckpoint"
      init_args:
        save_top_k: 1
        save_last: true
        monitor: val_loss
        every_n_epochs: 1
model:
  lr: 3.0e-05
  lr_backbone: 3.0e-07
  weight_decay: 0.0001
  threshold: 0.5
  batch_size: 32
  train_dataset_name: train
  val_dataset_name: val
ckpt_path: null
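LightningCLI accepts multiple `--config` flags, with later files overriding earlier ones, so (assuming this layering is how the base config is meant to be used) it can be combined with a task config like:

```
poetry run python main.py fit --config base_config.yaml --config table_detection_config.yaml
```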
10 changes: 10 additions & 0 deletions baselines/table-transformer/main.py
@@ -0,0 +1,10 @@
import table_transformer
from pytorch_lightning.cli import LightningCLI


def cli_main():
    # LightningCLI builds the full command-line interface (subcommands,
    # --config support) from the TableDetr LightningModule.
    cli = LightningCLI(table_transformer.TableDetr)  # noqa: F841


if __name__ == "__main__":
    cli_main()
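LightningCLI exposes the standard subcommands (`fit`, `validate`, `test`, `predict`), so the full set of configurable options can be inspected with:

```
poetry run python main.py fit --help
```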
117 changes: 117 additions & 0 deletions baselines/table-transformer/predict.py
@@ -0,0 +1,117 @@
import argparse
import json
import logging
import pickle
from collections import defaultdict

import table_transformer
import torch
from pytorch_lightning import Trainer
from torch.utils.data import DataLoader

import docile


def main():
    parser = argparse.ArgumentParser()
    parser.add_argument("--checkpoint_path", required=True)
    parser.add_argument(
        "--prediction_type",
        choices=["table_detection", "table_line_item_detection"],
        required=True,
    )
    parser.add_argument("--dataset_path", required=True)
    parser.add_argument("--split", choices=["train", "val", "test"], required=True)
    parser.add_argument(
        "--output_json_path",
        required=True,
        help="A dictionary doc_id->page_n->predictions. With --prediction_type "
        "table_detection, a pickle file with a list of bboxes (one for each "
        "page) will be saved beside the json file.",
    )
    # type=int is needed so that values passed on the command line are parsed
    # as integers rather than strings.
    parser.add_argument("--batch_size", type=int, default=16)
    parser.add_argument("--num_workers", type=int, default=16)
    parser.add_argument(
        "--table_detection_predictions_pickle",
        default=None,
        help="Path to a file created by this script for a table detection model.",
    )
    args = parser.parse_args()

    if (
        args.prediction_type == "table_line_item_detection"
        and args.table_detection_predictions_pickle is None
    ):
        raise ValueError(
            "--prediction_type table_line_item_detection requires --table_detection_predictions_pickle"
        )

    if torch.cuda.is_available():
        accelerator = "gpu"
    else:
        logging.warning("CUDA not available, predicting on CPU")
        accelerator = "cpu"

    table_detr = table_transformer.TableDetr.load_from_checkpoint(args.checkpoint_path)

    docile_dataset = docile.dataset.Dataset(args.split, args.dataset_path)

    crop_bboxes = None
    if args.table_detection_predictions_pickle is not None:
        with open(args.table_detection_predictions_pickle, "rb") as fin:
            crop_bboxes = pickle.load(fin)
    dataset_not_cropped = table_transformer.TableTransformerDataset(
        docile_dataset=docile_dataset, extractor=table_detr.extractor
    )

    dataset = table_transformer.TableTransformerDataset(
        docile_dataset=docile_dataset, extractor=table_detr.extractor, crop_bboxes=crop_bboxes
    )

    evaluator = Trainer(accelerator=accelerator, devices=1)
    _res = evaluator.predict(
        table_detr,
        DataLoader(
            dataset,
            collate_fn=table_detr.collate_fn,
            batch_size=args.batch_size,
            num_workers=args.num_workers,
        ),
    )
    res_detection = []
    for r in _res:
        res_detection.extend(r)

    bbox_dict = defaultdict(lambda: defaultdict(list))

    if args.prediction_type == "table_detection":
        res_detection = [list(a[0]) if len(a) else None for a in res_detection]
        # str.strip(".json") would strip characters, not the suffix; use removesuffix.
        with open(args.output_json_path.removesuffix(".json") + ".pickle", "wb") as fout:
            pickle.dump(res_detection, fout)
        for ((doc_id, page_n), _), bbox in zip(dataset.coco_annots, res_detection):
            bbox_dict[doc_id][page_n] = [round(x) for x in bbox] if bbox else None

    elif args.prediction_type == "table_line_item_detection":
        assert len(res_detection) == len(dataset.crop_bboxes)
        for bboxes, crop_bbox in zip(res_detection, dataset.crop_bboxes):
            for bbox in bboxes:
                # left offset
                bbox[0] += crop_bbox[0]
                bbox[2] += crop_bbox[0]
                # top offset
                bbox[1] += crop_bbox[1]
                bbox[3] += crop_bbox[1]

        # output predictions for all pages, not only for those with predictions
        for (doc_id, page_n), _ in dataset_not_cropped.coco_annots:
            bbox_dict[doc_id][page_n] = []

        for ((doc_id, page_n), _), bboxes in zip(dataset.coco_annots, res_detection):
            bbox_dict[doc_id][page_n] = [[round(x) for x in bbox] for bbox in bboxes]

    with open(args.output_json_path, "w") as fout:
        json.dump(bbox_dict, fout)


if __name__ == "__main__":
    main()
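Schematically, the output JSON maps `doc_id -> page_n -> predictions`: a single bbox (or `null`) per page for table detection, and a list of bboxes for line item detection. The coordinate order `[left, top, right, bottom]` is inferred from the offset arithmetic above; the values below are purely illustrative:

```
{
  "some_doc_id": {
    "0": [[110, 455, 980, 620], [110, 630, 980, 790]],
    "1": []
  }
}
```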
90 changes: 90 additions & 0 deletions baselines/table-transformer/table_detection_config.yaml
@@ -0,0 +1,90 @@
# pytorch_lightning==1.8.6
seed_everything: true
trainer:
  logger: true
  enable_checkpointing: true
  callbacks:
    - class_path: pytorch_lightning.callbacks.LearningRateMonitor
      init_args:
        logging_interval: step
        log_momentum: false
    - class_path: pytorch_lightning.callbacks.ModelCheckpoint
      init_args:
        dirpath: null
        filename: null
        monitor: val_loss
        verbose: false
        save_last: true
        save_top_k: 1
        save_weights_only: false
        mode: min
        auto_insert_metric_name: true
        every_n_train_steps: null
        train_time_interval: null
        every_n_epochs: 1
        save_on_train_epoch_end: null
  default_root_dir: /app/data/baselines/line_item_detection/table_transformer/
  gradient_clip_val: 0.1
  gradient_clip_algorithm: null
  num_nodes: 1
  num_processes: null
  devices: 1
  gpus: null
  auto_select_gpus: false
  tpu_cores: null
  ipus: null
  enable_progress_bar: true
  overfit_batches: 0.0
  track_grad_norm: -1
  check_val_every_n_epoch: 1
  fast_dev_run: false
  accumulate_grad_batches: null
  max_epochs: 1000
  min_epochs: null
  max_steps: -1
  min_steps: null
  max_time: null
  limit_train_batches: null
  limit_val_batches: null
  limit_test_batches: null
  limit_predict_batches: null
  val_check_interval: null
  log_every_n_steps: 20
  accelerator: gpu
  strategy: null
  sync_batchnorm: false
  precision: 16
  enable_model_summary: true
  num_sanity_val_steps: 2
  resume_from_checkpoint: null
  profiler: null
  benchmark: null
  deterministic: null
  reload_dataloaders_every_n_epochs: 0
  auto_lr_find: false
  replace_sampler_ddp: true
  detect_anomaly: false
  auto_scale_batch_size: false
  plugins: null
  amp_backend: native
  amp_level: null
  move_metrics_to_cpu: false
  multiple_trainloader_mode: max_size_cycle
  inference_mode: true
model:
  description: Fine-tuning for table detection (full table bbox)
  train_dataset_name: train
  val_dataset_name: val
  task: table-detection
  initial_checkpoint: facebook/detr-resnet-50
  dataset_path: /app/data/docile/
  load_ground_truth_crop_bboxes: false
  predictions_root_dir: /app/data/baselines/line_item_detection/table_transformer/predictions/
  crop_bboxes_filename: null
  lr: 3.0e-05
  lr_backbone: 3.0e-07
  weight_decay: 0.0001
  batch_size: 32
  num_workers: 16
  threshold: 0.5
ckpt_path: null
90 changes: 90 additions & 0 deletions baselines/table-transformer/table_line_item_detection_config.yaml
@@ -0,0 +1,90 @@
# pytorch_lightning==1.8.6
seed_everything: true
trainer:
  logger: true
  enable_checkpointing: true
  callbacks:
    - class_path: pytorch_lightning.callbacks.LearningRateMonitor
      init_args:
        logging_interval: step
        log_momentum: false
    - class_path: pytorch_lightning.callbacks.ModelCheckpoint
      init_args:
        dirpath: null
        filename: null
        monitor: val_loss
        verbose: false
        save_last: true
        save_top_k: 1
        save_weights_only: false
        mode: min
        auto_insert_metric_name: true
        every_n_train_steps: null
        train_time_interval: null
        every_n_epochs: 1
        save_on_train_epoch_end: null
  default_root_dir: /app/data/baselines/line_item_detection/table_transformer/
  gradient_clip_val: 0.1
  gradient_clip_algorithm: null
  num_nodes: 1
  num_processes: null
  devices: 1
  gpus: null
  auto_select_gpus: false
  tpu_cores: null
  ipus: null
  enable_progress_bar: true
  overfit_batches: 0.0
  track_grad_norm: -1
  check_val_every_n_epoch: 1
  fast_dev_run: false
  accumulate_grad_batches: null
  max_epochs: 1000
  min_epochs: null
  max_steps: -1
  min_steps: null
  max_time: null
  limit_train_batches: null
  limit_val_batches: null
  limit_test_batches: null
  limit_predict_batches: null
  val_check_interval: null
  log_every_n_steps: 20
  accelerator: gpu
  strategy: null
  sync_batchnorm: false
  precision: 16
  enable_model_summary: true
  num_sanity_val_steps: 2
  resume_from_checkpoint: null
  profiler: null
  benchmark: null
  deterministic: null
  reload_dataloaders_every_n_epochs: 0
  auto_lr_find: false
  replace_sampler_ddp: true
  detect_anomaly: false
  auto_scale_batch_size: false
  plugins: null
  amp_backend: native
  amp_level: null
  move_metrics_to_cpu: false
  multiple_trainloader_mode: max_size_cycle
  inference_mode: true
model:
  description: Fine-tuning DETR for line item detection on GT cropped tables (full)
  train_dataset_name: train
  val_dataset_name: val
  task: table-line-item-detection
  initial_checkpoint: facebook/detr-resnet-50
  dataset_path: /app/data/docile/
  load_ground_truth_crop_bboxes: false
  predictions_root_dir: /app/data/baselines/line_item_detection/table_transformer/predictions/
  crop_bboxes_filename: null
  lr: 3.0e-05
  lr_backbone: 3.0e-07
  weight_decay: 0.0001
  batch_size: 32
  num_workers: 16
  threshold: 0.5
ckpt_path: null