[Feature] Support otter (#1651)
* [Feature] Support Otter

* Update docs
mzr1996 committed Jun 17, 2023
1 parent 9d3fc43 commit e69bace
Showing 13 changed files with 540 additions and 3 deletions.
1 change: 1 addition & 0 deletions README.md
@@ -256,6 +256,7 @@ Results and models are available in the [model zoo](https://mmpretrain.readthedo
<li><a href="configs/flamingo">Flamingo (NeurIPS'2022)</a></li>
<li><a href="configs/chinese_clip">Chinese CLIP (arxiv'2022)</a></li>
<li><a href="configs/minigpt4">MiniGPT-4 (arxiv'2023)</a></li>
<li><a href="configs/otter">Otter (arxiv'2023)</a></li>
</ul>
</td>
<td>
1 change: 1 addition & 0 deletions README_zh-CN.md
@@ -252,6 +252,7 @@ mim install -e ".[multimodal]"
<li><a href="configs/flamingo">Flamingo (NeurIPS'2022)</a></li>
<li><a href="configs/chinese_clip">Chinese CLIP (arxiv'2022)</a></li>
<li><a href="configs/minigpt4">MiniGPT-4 (arxiv'2023)</a></li>
<li><a href="configs/otter">Otter (arxiv'2023)</a></li>
</ul>
</td>
<td>
78 changes: 78 additions & 0 deletions configs/otter/README.md
@@ -0,0 +1,78 @@
# Otter

> [Otter: A Multi-Modal Model with In-Context Instruction Tuning](https://arxiv.org/abs/2305.03726)
<!-- [ALGORITHM] -->

## Abstract

Large language models (LLMs) have demonstrated significant universal capabilities as few/zero-shot learners in various tasks due to their pre-training on vast amounts of text data, as exemplified by GPT-3, which boosted to InstructGPT and ChatGPT, effectively following natural language instructions to accomplish real-world tasks. In this paper, we propose to introduce instruction tuning into multi-modal models, motivated by the Flamingo model's upstream interleaved format pretraining dataset. We adopt a similar approach to construct our MultI-Modal In-Context Instruction Tuning (MIMIC-IT) dataset. We then introduce Otter, a multi-modal model based on OpenFlamingo (open-sourced version of DeepMind's Flamingo), trained on MIMIC-IT and showcasing improved instruction-following ability and in-context learning. We also optimize OpenFlamingo's implementation for researchers, democratizing the required training resources from 1$\times$ A100 GPU to 4$\times$ RTX-3090 GPUs, and integrate both OpenFlamingo and Otter into Huggingface Transformers for more researchers to incorporate the models into their customized training and inference pipelines.

<div align=center>
<img src="https://i.postimg.cc/Tw1Z0BCW/otterv0-2-demo.png" width="80%"/>
</div>

## How to use it?

<!-- [TABS-BEGIN] -->

**Use the model**

```python
from mmpretrain import get_model, inference_model

model = get_model('otter-9b_3rdparty_caption', pretrained=True, device='cuda')
out = inference_model(model, 'demo/cat-dog.png')
print(out)
```
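
The VQA variant works through the same interface. Below is a minimal sketch assuming the `otter-9b_3rdparty_vqa` weights listed in the metafile; passing the question as an extra positional argument follows the convention of the other VQA models in MMPreTrain and may need adjusting for your version.

```python
from mmpretrain import get_model, inference_model

# Sketch only: 'otter-9b_3rdparty_vqa' is the name registered by this commit's
# metafile; the question argument mirrors other VQA models in MMPreTrain.
model = get_model('otter-9b_3rdparty_vqa', pretrained=True, device='cuda')
out = inference_model(model, 'demo/cat-dog.png', 'What animals are in the picture?')
print(out)
```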

**Test Command**

Prepare your dataset according to the [docs](https://mmpretrain.readthedocs.io/en/latest/user_guides/dataset_prepare.html#prepare-dataset).

Test:

```shell
python tools/test.py configs/otter/otter-9b_caption.py https://download.openmmlab.com/mmclassification/v1/otter/otter-9b-adapter_20230613-51c5be8d.pth
```
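
The same evaluation can also be driven from Python with MMEngine's `Runner`. A rough sketch (unofficial; the config and checkpoint URL are the ones used in the command above, and `work_dir` is an arbitrary choice):

```python
from mmengine.config import Config
from mmengine.runner import Runner

# Load the inference-only config added in this commit, point `load_from` at the
# converted checkpoint, and run the test loop.
cfg = Config.fromfile('configs/otter/otter-9b_caption.py')
cfg.load_from = ('https://download.openmmlab.com/mmclassification/v1/otter/'
                 'otter-9b-adapter_20230613-51c5be8d.pth')
cfg.work_dir = 'work_dirs/otter-9b_caption'  # arbitrary output directory

runner = Runner.from_cfg(cfg)
metrics = runner.test()
print(metrics)
```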

<!-- [TABS-END] -->

## Models and results

### Image Caption on COCO

| Model | Pretrain | Params (M) | BLEU-4 | CIDER | Config | Download |
| :---------------------------- | :----------: | :--------: | :------: | :------: | :---------------------------: | :------------------------------------------------------------------------------------------------------: |
| `otter-9b_3rdparty_caption`\* | From scratch | 8220.45 | Upcoming | Upcoming | [config](otter-9b_caption.py) | [model](https://download.openmmlab.com/mmclassification/v1/otter/otter-9b-adapter_20230613-51c5be8d.pth) |

*Models with * are converted from the [official repo](https://github.com/Luodian/Otter/tree/main). The config files of these models are only for inference. We haven't reproduced the training results.*

### Visual Question Answering on VQAv2

| Model | Pretrain | Params (M) | Accuracy | Config | Download |
| :------------------------ | :----------: | :--------: | :------: | :-----------------------: | :------------------------------------------------------------------------------------------------------: |
| `otter-9b_3rdparty_vqa`\* | From scratch | 8220.45 | Upcoming | [config](otter-9b_vqa.py) | [model](https://download.openmmlab.com/mmclassification/v1/otter/otter-9b-adapter_20230613-51c5be8d.pth) |

*Models with * are converted from the [official repo](https://github.com/Luodian/Otter/tree/main). The config files of these models are only for inference. We haven't reproduced the training results.*

## Citation

```bibtex
@article{li2023otter,
title={Otter: A Multi-Modal Model with In-Context Instruction Tuning},
author={Li, Bo and Zhang, Yuanhan and Chen, Liangyu and Wang, Jinghao and Yang, Jingkang and Liu, Ziwei},
journal={arXiv preprint arXiv:2305.03726},
year={2023}
}
@article{li2023mimicit,
title={MIMIC-IT: Multi-Modal In-Context Instruction Tuning},
author={Bo Li and Yuanhan Zhang and Liangyu Chen and Jinghao Wang and Fanyi Pu and Jingkang Yang and Chunyuan Li and Ziwei Liu},
year={2023},
eprint={2306.05425},
archivePrefix={arXiv},
primaryClass={cs.CV}
}
```
43 changes: 43 additions & 0 deletions configs/otter/metafile.yml
@@ -0,0 +1,43 @@
Collections:
  - Name: Otter
    Metadata:
      Architecture:
        - Transformer
        - Gated Cross-Attention Dense
    Paper:
      Title: 'Otter: A Multi-Modal Model with In-Context Instruction Tuning'
      URL: https://arxiv.org/abs/2305.03726
    README: configs/otter/README.md

Models:
  - Name: otter-9b_3rdparty_caption
    Metadata:
      FLOPs: null
      Parameters: 8220452880
    In Collection: Otter
    Results:
      - Task: Image Caption
        Dataset: COCO
        Metrics:
          BLEU-4: null
          CIDER: null
    Weights: https://download.openmmlab.com/mmclassification/v1/otter/otter-9b-adapter_20230613-51c5be8d.pth
    Config: configs/otter/otter-9b_caption.py
    Converted From:
      Weights: https://huggingface.co/luodian/otter-9b-hf
      Code: https://github.com/Luodian/Otter/tree/main
  - Name: otter-9b_3rdparty_vqa
    Metadata:
      FLOPs: null
      Parameters: 8220452880
    In Collection: Otter
    Results:
      - Task: Visual Question Answering
        Dataset: VQAv2
        Metrics:
          Accuracy: null
    Weights: https://download.openmmlab.com/mmclassification/v1/otter/otter-9b-adapter_20230613-51c5be8d.pth
    Config: configs/otter/otter-9b_vqa.py
    Converted From:
      Weights: https://huggingface.co/luodian/otter-9b-hf
      Code: https://github.com/Luodian/Otter/tree/main
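
As a quick sanity check, the model names registered through this metafile can be listed programmatically; a small sketch using `list_models` (the exact pattern-matching behaviour may vary between versions):

```python
from mmpretrain import list_models

# Expected to print the two names defined above, e.g. 'otter-9b_3rdparty_caption'.
for name in list_models('*otter*'):
    print(name)
```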
91 changes: 91 additions & 0 deletions configs/otter/otter-9b_caption.py
@@ -0,0 +1,91 @@
_base_ = [
    '../_base_/default_runtime.py',
]

# model settings
model = dict(
    type='Otter',
    tokenizer=dict(type='LlamaTokenizer', name_or_path='huggyllama/llama-7b'),
    vision_encoder=dict(
        type='VisionTransformer',
        arch='l',
        patch_size=14,
        pre_norm=True,
        norm_cfg=dict(type='LN', eps=1e-5),
        layer_cfgs=dict(act_cfg=dict(type='mmpretrain.QuickGELU')),
        final_norm=False,
        out_type='raw',
        pretrained=(
            'https://download.openmmlab.com/mmclassification/v0/clip/'
            'vit-large-p14_clip-openai-pre_3rdparty_20230517-95e2af0b.pth'),
    ),
    lang_encoder=dict(
        base=dict(
            type='AutoModelForCausalLM',
            name_or_path='huggyllama/llama-7b',
            local_files_only=True),
        adapter=dict(
            type='FlamingoLMAdapter',
            vis_hidden_size=1024,
            cross_attn_every_n_layers=4,
            use_media_placement_augmentation=False,
            only_attend_previous=True,
        ),
    ),
    task='caption',
    final_prompt_tmpl='<image>User:Please describe the image. GPT:<answer>',
    generation_cfg=dict(
        num_beams=3, max_new_tokens=24, no_repeat_ngram_size=3),
)

# data settings
data_preprocessor = dict(
    type='MultiModalDataPreprocessor',
    mean=[122.770938, 116.7460125, 104.09373615],
    std=[68.5005327, 66.6321579, 70.32316305],
    to_rgb=True,
)

test_pipeline = [
    dict(type='LoadImageFromFile'),
    dict(
        type='ResizeEdge',
        scale=224,
        interpolation='bicubic',
        backend='pillow'),
    dict(type='CenterCrop', crop_size=(224, 224)),
    dict(
        type='PackInputs',
        algorithm_keys=['gt_caption'],
        meta_keys=['image_id'],
    ),
]

val_dataloader = dict(
    batch_size=8,
    num_workers=8,
    dataset=dict(
        type='FlamingoEvalCOCOCaption',
        data_root='data/coco',
        ann_file='annotations/captions_train2014.json',
        data_prefix=dict(img_path='train2014'),
        pipeline=test_pipeline,
        num_shots=0,
        num_support_examples=2048,
        num_query_examples=5000,
    ),
    sampler=dict(type='DefaultSampler', shuffle=False),
    persistent_workers=True,
)

val_evaluator = dict(
    type='COCOCaption',
    ann_file='data/coco/annotations/captions_train2014.json')

# For the standard test, please configure the test dataset manually.
test_dataloader = val_dataloader
test_evaluator = val_evaluator

# schedule settings
val_cfg = dict()
test_cfg = dict()
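
For reference, the model section of the config above can also be built directly through the registry. A sketch only; it assumes the `[multimodal]` dependencies are installed and the HuggingFace weights referenced in the config are available locally:

```python
from mmengine.config import Config
from mmpretrain.registry import MODELS

# Build the Otter model exactly as defined in otter-9b_caption.py; the vision
# encoder weights come from the `pretrained` URL, the language model from
# HuggingFace ('huggyllama/llama-7b').
cfg = Config.fromfile('configs/otter/otter-9b_caption.py')
model = MODELS.build(cfg.model).eval()
```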
104 changes: 104 additions & 0 deletions configs/otter/otter-9b_vqa.py
@@ -0,0 +1,104 @@
_base_ = [
    '../_base_/default_runtime.py',
]

# model settings
model = dict(
    type='Otter',
    tokenizer=dict(type='LlamaTokenizer', name_or_path='huggyllama/llama-7b'),
    vision_encoder=dict(
        type='VisionTransformer',
        arch='l',
        patch_size=14,
        pre_norm=True,
        norm_cfg=dict(type='LN', eps=1e-5),
        layer_cfgs=dict(act_cfg=dict(type='QuickGELU')),
        final_norm=False,
        out_type='raw',
        pretrained=(
            'https://download.openmmlab.com/mmclassification/v0/clip/'
            'vit-large-p14_clip-openai-pre_3rdparty_20230517-95e2af0b.pth'),
    ),
    lang_encoder=dict(
        base=dict(
            type='AutoModelForCausalLM',
            name_or_path='huggyllama/llama-7b',
            local_files_only=True),
        adapter=dict(
            type='FlamingoLMAdapter',
            vis_hidden_size=1024,
            cross_attn_every_n_layers=4,
            use_media_placement_augmentation=False,
            only_attend_previous=True,
        ),
    ),
    task='vqa',
    final_prompt_tmpl='<image>User:{question} GPT:<answer>',
    generation_cfg=dict(
        num_beams=3, max_new_tokens=24, no_repeat_ngram_size=3),
)

# data settings
data_preprocessor = dict(
    type='MultiModalDataPreprocessor',
    mean=[122.770938, 116.7460125, 104.09373615],
    std=[68.5005327, 66.6321579, 70.32316305],
    to_rgb=True,
)

test_pipeline = [
    dict(type='LoadImageFromFile'),
    dict(
        type='ResizeEdge',
        scale=224,
        interpolation='bicubic',
        backend='pillow'),
    dict(type='CenterCrop', crop_size=(224, 224)),
    dict(
        type='PackInputs',
        algorithm_keys=['question', 'gt_answer', 'gt_answer_weight', 'shots'],
        meta_keys=['image_id'],
    ),
]

val_dataloader = dict(
    batch_size=8,
    num_workers=8,
    dataset=dict(
        type='FlamingoEvalCOCOVQA',
        data_root='data/coco',
        data_prefix='val2014',
        question_file='annotations/v2_OpenEnded_mscoco_val2014_questions.json',
        ann_file='annotations/v2_mscoco_val2014_annotations.json',
        pipeline=test_pipeline,
        num_shots=0,
        num_support_examples=2048,
        num_query_examples=5000,
    ),
    sampler=dict(type='DefaultSampler', shuffle=False),
    persistent_workers=True,
)
val_evaluator = dict(type='VQAAcc')

test_dataloader = dict(
    batch_size=8,
    num_workers=8,
    dataset=dict(
        type='FlamingoEvalCOCOVQA',
        data_root='data/coco',
        data_prefix='test2015',
        question_file=
        'annotations/v2_OpenEnded_mscoco_test-dev2015_questions.json',
        pipeline=test_pipeline,
        num_shots=0,
        num_support_examples=2048,
        num_query_examples=5000,
    ),
    sampler=dict(type='DefaultSampler', shuffle=False),
    persistent_workers=True,
)
test_evaluator = dict(type='ReportVQA', file_path='vqa_test-dev.json')

# schedule settings
val_cfg = dict()
test_cfg = dict()
1 change: 1 addition & 0 deletions docs/en/api/models.rst
@@ -144,6 +144,7 @@ Multi-Modality Algorithms
Flamingo
OFA
MiniGPT4
Otter

.. module:: mmpretrain.models.backbones

3 changes: 2 additions & 1 deletion mmpretrain/models/multimodal/__init__.py
@@ -8,12 +8,13 @@
from .flamingo import * # noqa: F401, F403
from .minigpt4 import * # noqa: F401, F403
from .ofa import * # noqa: F401, F403
from .otter import * # noqa: F401, F403
else:
from mmpretrain.registry import MODELS
from mmpretrain.utils.dependency import register_multimodal_placeholder

register_multimodal_placeholder([
'Blip2Caption', 'Blip2Retrieval', 'Blip2VQA', 'BlipCaption',
'BlipNLVR', 'BlipRetrieval', 'BlipGrounding', 'BlipVQA', 'Flamingo',
'OFA', 'ChineseCLIP', 'MiniGPT4'
'OFA', 'ChineseCLIP', 'MiniGPT4', 'Otter'
], MODELS)
10 changes: 8 additions & 2 deletions mmpretrain/models/multimodal/flamingo/adapter.py
@@ -19,6 +19,7 @@ def extend_init(
vis_hidden_size: int,
cross_attn_every_n_layers: int,
use_media_placement_augmentation: bool,
only_attend_previous: bool = False,
):
"""Initialize Flamingo by adding a new gated cross attn to the decoder.
@@ -48,6 +49,7 @@ def extend_init(
]))
base.use_media_placement_augmentation = use_media_placement_augmentation # noqa
base.initialized_flamingo = True
base.only_attend_previous = only_attend_previous
return base

def set_decoder_layers_attr_name(self, decoder_layers_attr_name):
@@ -67,8 +69,12 @@ def forward(self, *input, **kwargs):
function."""
input_ids = kwargs['input_ids'] if 'input_ids' in kwargs else input[0]
media_locations = input_ids == self.media_token_id
attend_previous = ((random.random() < 0.5)
if self.use_media_placement_augmentation else False)
if self.only_attend_previous:
attend_previous = True
elif self.use_media_placement_augmentation:
attend_previous = (random.random() < 0.5)
else:
attend_previous = False

for layer in self.get_decoder().layers:
layer.condition_media_locations(media_locations)
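
To make the behaviour of the new `only_attend_previous` flag explicit, here is a standalone toy version of the selection logic from the diff above (illustration only, not the library code; the media token id is hypothetical):

```python
import random

import torch


def select_attend_previous(only_attend_previous: bool,
                           use_media_placement_augmentation: bool) -> bool:
    """Mirror the branch added in this commit: Otter always attends to the
    preceding media token, while Flamingo-style augmentation randomizes it."""
    if only_attend_previous:
        return True
    if use_media_placement_augmentation:
        return random.random() < 0.5
    return False


# Toy example of how media locations are derived from the input ids.
media_token_id = 32001  # hypothetical id of the <image> token
input_ids = torch.tensor([[1, 32001, 5, 6, 32001, 9]])
media_locations = input_ids == media_token_id
print(media_locations)  # tensor([[False,  True, False, False,  True, False]])
print(select_attend_previous(True, False))   # Otter: always True
print(select_attend_previous(False, True))   # Flamingo-style augmentation: random
```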
4 changes: 4 additions & 0 deletions mmpretrain/models/multimodal/otter/__init__.py
@@ -0,0 +1,4 @@
# Copyright (c) OpenMMLab. All rights reserved.
from .otter import Otter

__all__ = ['Otter']