Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Feature/mue #336

Open
wants to merge 11 commits into
base: main
Choose a base branch
from
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
5 changes: 4 additions & 1 deletion .gitignore
Original file line number Diff line number Diff line change
@@ -1,3 +1,6 @@
.idea/*
.DS_Store
.vscode
.vscode
__pycache__
logs/
vqa_logs/
28 changes: 27 additions & 1 deletion README.md
Original file line number Diff line number Diff line change
Expand Up @@ -105,7 +105,7 @@ We list the parameters and pretrained checkpoints of OFAs below. For finetuned c
<br></br>

# Results
Below we demonstrate the results of OFAs on cross-modal understanding and generation.
Below we demonstrate the results of OFAs on cross-modal understanding and generation. You can find more results of MuE model in [MuE](https://arxiv.org/abs/2211.11152)

<table border="1" width="100%">
<tr align="center">
Expand Down Expand Up @@ -254,6 +254,9 @@ We provide procedures to reproduce our results of image captioning on our paper
cd run_scripts/caption
nohup sh train_caption_stage1.sh > train_stage1.out & # stage 1, train with cross-entropy loss
nohup sh train_caption_stage2.sh > train_stage2.out & # stage 2, load the best ckpt of stage1 and train with CIDEr optimization
# If you need to finetune MuE model, please apply the following script
nohup sh train_caption_stage1_base_MuE.sh > train_stage1.out &
# The stage2 uses the same script above
</pre>
</details>
<details>
Expand All @@ -263,6 +266,9 @@ nohup sh train_caption_stage2.sh > train_stage2.out & # stage 2, load the best
</p>
<pre>
cd run_scripts/caption ; sh evaluate_caption.sh # inference & evaluate
# If you want to evaluate your MuE Model
sh evaluate_caption_base_MuE.sh
# You can adjust img_thres, txt_thres, and decoder_thres to achieve better performance and speed trade-off.
</pre>
</details>

Expand Down Expand Up @@ -429,6 +435,8 @@ We provide steps for you to reproduce our results in visual entailment. See the
<pre>
cd run_scripts/snli_ve
nohup sh train_snli_ve.sh > train_snli_ve.out & # finetune for snli_ve
# If you need to finetune MuE model, please apply the following script
nohup sh train_snli_ve_base_MuE.sh > train_snli_ve_MuE.out &
</pre>
</details>
<details>
Expand All @@ -438,6 +446,9 @@ nohup sh train_snli_ve.sh > train_snli_ve.out & # finetune for snli_ve
</p>
<pre>
cd run_scripts/snli_ve ; sh evaluate_snli_ve.sh dev # specify 'dev' or 'test'
# If you want to evaluate your MuE Model
sh evaluate_snli_ve_base_MuE.sh
# You can adjust img_thres, txt_thres, and decoder_thres to achieve better performance and speed trade-off.
</pre>
</details>

Expand Down Expand Up @@ -600,5 +611,20 @@ Please cite our paper if you find it helpful :)
volume = {abs/2202.03052},
year = {2022}
}

@article{tang2022you,
title={You Need Multiple Exiting: Dynamic Early Exiting for Accelerating Unified Vision Language Model},
author={Tang, Shengkun and
Wang, Yaqing and
Kong, Zhenglun and
Zhang, Tianchi and
Li, Yao and
Ding, Caiwen and
Wang, Yanzhi and
Liang, Yi and
Xu, Dongkuan},
journal={arXiv preprint arXiv:2211.11152},
year={2022}
}
```
<br></br>
67 changes: 66 additions & 1 deletion criterions/label_smoothed_cross_entropy.py
Original file line number Diff line number Diff line change
Expand Up @@ -217,9 +217,12 @@ def forward(self, model, sample, update_num=0, reduce=True):
def get_lprobs_and_target(self, model, net_output, sample):
conf = sample['conf'][:, None, None] if 'conf' in sample and sample['conf'] is not None else 1
constraint_masks = None
# some weird bug will occur without this operation and following out-place operation.
# This operation doesn't change logic.
net_output = list(net_output)
if "constraint_masks" in sample and sample["constraint_masks"] is not None:
constraint_masks = sample["constraint_masks"]
net_output[0].masked_fill_(~constraint_masks, -math.inf)
net_output[0] = net_output[0].masked_fill_(~constraint_masks, -math.inf)
if self.constraint_start is not None and self.constraint_end is not None:
net_output[0][:, :, 4:self.constraint_start] = -math.inf
net_output[0][:, :, self.constraint_end:] = -math.inf
Expand Down Expand Up @@ -341,3 +344,65 @@ def logging_outputs_can_be_summed() -> bool:
to True will improves distributed training speed.
"""
return True

@register_criterion(
"MuE_Task_Loss", dataclass=AdjustLabelSmoothedCrossEntropyCriterionConfig
)
class MuE_Task_Loss(AdjustLabelSmoothedCrossEntropyCriterion):
def __init__(
self,
task,
sentence_avg,
label_smoothing,
ignore_prefix_size=0,
ignore_eos=False,
report_accuracy=False,
drop_worst_ratio=0,
drop_worst_after=0,
use_rdrop=False,
reg_alpha=1.0,
sample_patch_num=196,
constraint_range=None
):
super().__init__(
task,
sentence_avg,
label_smoothing,
ignore_prefix_size,
ignore_eos,
report_accuracy,
drop_worst_ratio,
drop_worst_after,
use_rdrop,
reg_alpha,
sample_patch_num,
constraint_range)

def compute_loss(self, model, net_output, sample, update_num, reduce=True):
loss_all = 0.0
nll_loss_all = 0.0
ntokens = 0
print("using MuE Task loss")
for state in net_output[1]["inner_out_states"]:
lprobs, target, constraint_masks = self.get_lprobs_and_target(model, [state], sample)
if constraint_masks is not None:
constraint_masks = constraint_masks[target != self.padding_idx]
lprobs = lprobs[target != self.padding_idx]
target = target[target != self.padding_idx]
loss, nll_loss, ntokens = label_smoothed_nll_loss(
lprobs,
target,
self.eps,
update_num,
reduce=reduce,
drop_worst_ratio=self.drop_worst_ratio,
drop_worst_after=self.drop_worst_after,
use_rdrop=self.use_rdrop,
reg_alpha=self.reg_alpha,
constraint_masks=constraint_masks,
constraint_start=self.constraint_start,
constraint_end=self.constraint_end
)
loss_all += loss
nll_loss_all += nll_loss
return loss_all, nll_loss_all, ntokens
13 changes: 10 additions & 3 deletions evaluate.py
Original file line number Diff line number Diff line change
Expand Up @@ -160,8 +160,8 @@ def main(cfg: DictConfig, **kwargs):
score_sum += sum([s[0] for s in scores])
score_cnt += sum([s[1] for s in scores])
else:
score_sum += sum(scores) if scores is not None else 0
score_cnt += len(scores) if scores is not None else 0
score_sum += sum(scores)
score_cnt += len(scores)

progress.log({"sentences": sample["nsentences"]})

Expand All @@ -173,10 +173,17 @@ def cli_main():
parser.add_argument("--ema-eval", action='store_true', help="Use EMA weights to make evaluation.")
parser.add_argument("--beam-search-vqa-eval", action='store_true', help="Use beam search for vqa evaluation (faster inference speed but sub-optimal result), if not specified, we compute scores for each answer in the candidate set, which is slower but can obtain best result.")
parser.add_argument("--zero-shot", action='store_true')
parser.add_argument('--img_thres', type=float, metavar='D', default=1.0,
help='image theshold for early exiting model')
parser.add_argument('--txt_thres', type=float, metavar='D', default=1.0,
help='text theshold for early exiting model')
parser.add_argument('--decoder_thres', type=float, metavar='D', default=1.0,
help='decoder theshold for early exiting model')
args = options.parse_args_and_arch(parser)
cfg = convert_namespace_to_omegaconf(args)
distributed_utils.call_main(
cfg, main, ema_eval=args.ema_eval, beam_search_vqa_eval=args.beam_search_vqa_eval, zero_shot=args.zero_shot
cfg, main, ema_eval=args.ema_eval, beam_search_vqa_eval=args.beam_search_vqa_eval, zero_shot=args.zero_shot,
img_thres=args.img_thres, txt_thres=args.txt_thres, decoder_thres=args.decoder_thres, is_train=False
)


Expand Down
4 changes: 2 additions & 2 deletions fairseq/fairseq/tasks/fairseq_task.py
Original file line number Diff line number Diff line change
Expand Up @@ -511,11 +511,11 @@ def build_dataset_for_inference(
raise NotImplementedError

def inference_step(
self, generator, models, sample, prefix_tokens=None, constraints=None
self, generator, models, sample, prefix_tokens=None, constraints=None, **kwargs
):
with torch.no_grad():
return generator.generate(
models, sample, prefix_tokens=prefix_tokens, constraints=constraints
models, sample, prefix_tokens=prefix_tokens, constraints=constraints, **kwargs
)

def begin_epoch(self, epoch, model):
Expand Down
15 changes: 9 additions & 6 deletions models/ofa/unify_multihead_attention.py
Original file line number Diff line number Diff line change
Expand Up @@ -265,7 +265,8 @@ def forward(
.view(-1, bsz * self.num_heads, self.head_dim)
.transpose(0, 1)
)


saved_state_new = {}
if saved_state is not None:
# saved states are stored with shape (bsz, num_heads, seq_len, head_dim)
if "prev_key" in saved_state:
Expand Down Expand Up @@ -298,13 +299,15 @@ def forward(
src_len=k.size(1),
static_kv=static_kv,
)

saved_state["prev_key"] = k.view(bsz, self.num_heads, -1, self.head_dim)
saved_state["prev_value"] = v.view(bsz, self.num_heads, -1, self.head_dim)
saved_state["prev_key_padding_mask"] = key_padding_mask
# There are reference bugs if change saved_state directly.
# This causes error during inference in early exiting models (MuE)
# However, this has no influence on original OFA models.
saved_state_new["prev_key"] = k.view(bsz, self.num_heads, -1, self.head_dim)
saved_state_new["prev_value"] = v.view(bsz, self.num_heads, -1, self.head_dim)
saved_state_new["prev_key_padding_mask"] = key_padding_mask
# In this branch incremental_state is never None
assert incremental_state is not None
incremental_state = self._set_input_buffer(incremental_state, saved_state)
incremental_state = self._set_input_buffer(incremental_state, saved_state_new)
assert k is not None
assert k.size(1) == src_len

Expand Down
Loading