Skip to content

Commit

Permalink
Rename the TTA model to AverageClsScoreTTA.
Browse files Browse the repository at this point in the history
  • Loading branch information
mzr1996 committed Dec 30, 2022
1 parent 2390c9f commit bd31382
Show file tree
Hide file tree
Showing 4 changed files with 22 additions and 14 deletions.
4 changes: 2 additions & 2 deletions mmcls/models/tta/__init__.py
Original file line number Diff line number Diff line change
@@ -1,4 +1,4 @@
# Copyright (c) OpenMMLab. All rights reserved.
from .score_tta import AverageScoreTTAModel
from .score_tta import AverageClsScoreTTA

__all__ = ['AverageScoreTTAModel']
__all__ = ['AverageClsScoreTTA']
18 changes: 13 additions & 5 deletions mmcls/models/tta/score_tta.py
Original file line number Diff line number Diff line change
Expand Up @@ -2,25 +2,33 @@
from typing import List

from mmengine.model import BaseTTAModel
from mmengine.structures import BaseDataElement

from mmcls.registry import MODELS
from mmcls.structures import ClsDataSample


@MODELS.register_module()
class AverageScoreTTAModel(BaseTTAModel):
class AverageClsScoreTTA(BaseTTAModel):

def merge_preds(
self,
data_samples_list: List[List[ClsDataSample]],
) -> List[BaseDataElement]:
) -> List[ClsDataSample]:
"""Merge predictions of enhanced data to one prediction.
Args:
data_samples_list (List[List[ClsDataSample]]): List of predictions
of all enhanced data.
Returns:
List[ClsDataSample]: Merged prediction.
"""
merged_data_samples = []
for data_samples in data_samples_list:
merged_data_samples.append(self.merge_single_sample(data_samples))
merged_data_samples.append(self._merge_single_sample(data_samples))
return merged_data_samples

def merge_single_sample(self, data_samples):
def _merge_single_sample(self, data_samples):
merged_data_sample: ClsDataSample = data_samples[0].new()
merged_score = sum(data_sample.pred_label.score
for data_sample in data_samples) / len(data_samples)
Expand Down
12 changes: 6 additions & 6 deletions tests/test_models/test_tta.py
Original file line number Diff line number Diff line change
Expand Up @@ -5,17 +5,17 @@
import torch
from mmengine import ConfigDict

from mmcls.models import AverageScoreTTAModel, ImageClassifier
from mmcls.models import AverageClsScoreTTA, ImageClassifier
from mmcls.registry import MODELS
from mmcls.structures import ClsDataSample
from mmcls.utils import register_all_modules

register_all_modules()


class TestAverageScoreTTAModel(TestCase):
class TestAverageClsScoreTTA(TestCase):
DEFAULT_ARGS = dict(
type='AverageScoreTTAModel',
type='AverageClsScoreTTA',
module=dict(
type='ImageClassifier',
backbone=dict(type='ResNet', depth=18),
Expand All @@ -27,12 +27,12 @@ class TestAverageScoreTTAModel(TestCase):
loss=dict(type='CrossEntropyLoss'))))

def test_initialize(self):
model: AverageScoreTTAModel = MODELS.build(self.DEFAULT_ARGS)
model: AverageClsScoreTTA = MODELS.build(self.DEFAULT_ARGS)
self.assertIsInstance(model.module, ImageClassifier)

def test_forward(self):
inputs = torch.rand(1, 3, 224, 224)
model: AverageScoreTTAModel = MODELS.build(self.DEFAULT_ARGS)
model: AverageClsScoreTTA = MODELS.build(self.DEFAULT_ARGS)

# The forward of TTA model should not be called.
with self.assertRaisesRegex(NotImplementedError, 'will not be called'):
Expand All @@ -42,7 +42,7 @@ def test_test_step(self):
cfg = ConfigDict(deepcopy(self.DEFAULT_ARGS))
cfg.module.data_preprocessor = dict(
mean=[127.5, 127.5, 127.5], std=[127.5, 127.5, 127.5])
model: AverageScoreTTAModel = MODELS.build(cfg)
model: AverageClsScoreTTA = MODELS.build(cfg)

img1 = torch.randint(0, 256, (1, 3, 224, 224))
img2 = torch.randint(0, 256, (1, 3, 224, 224))
Expand Down
2 changes: 1 addition & 1 deletion tools/test.py
Original file line number Diff line number Diff line change
Expand Up @@ -115,7 +115,7 @@ def merge_args(cfg, args):
# -------------------- TTA related args --------------------
if args.tta:
if 'tta_model' not in cfg:
cfg.tta_model = dict(type='mmcls.AverageScoreTTAModel')
cfg.tta_model = dict(type='mmcls.AverageClsScoreTTA')
if 'tta_pipeline' not in cfg:
test_pipeline = cfg.test_dataloader.dataset.pipeline
cfg.tta_pipeline = deepcopy(test_pipeline)
Expand Down

0 comments on commit bd31382

Please sign in to comment.