Skip to content

Commit

Permalink
[Update] update the args
Browse files (browse the repository at this point in the history)
  • Loading branch information
tonysy committed Oct 18, 2022
1 parent a4c0c73 commit a562ad4
Show file tree
Hide file tree
Showing 7 changed files with 31 additions and 40 deletions.
5 changes: 3 additions & 2 deletions configs/_base_/models/tinyvit/tinyvit-11m-224.py
Original file line number Diff line number Diff line change
Expand Up @@ -3,8 +3,9 @@
type='ImageClassifier',
backbone=dict(
type='TinyViT',
arch='tinyvit_11m_224',
resolution=(224, 224),
arch='11m',
img_size=(224, 224),
window_size=[7, 7, 14, 7],
out_indices=(3, ),
drop_path_rate=0.1,
gap_before_final_norm=True,
Expand Down
5 changes: 3 additions & 2 deletions configs/_base_/models/tinyvit/tinyvit-21m-224.py
Original file line number Diff line number Diff line change
Expand Up @@ -3,8 +3,9 @@
type='ImageClassifier',
backbone=dict(
type='TinyViT',
arch='tinyvit_21m_224',
resolution=(224, 224),
arch='21m',
img_size=(224, 224),
window_size=[7, 7, 14, 7],
out_indices=(3, ),
drop_path_rate=0.2,
gap_before_final_norm=True,
Expand Down
5 changes: 3 additions & 2 deletions configs/_base_/models/tinyvit/tinyvit-21m-384.py
Original file line number Diff line number Diff line change
Expand Up @@ -3,8 +3,9 @@
type='ImageClassifier',
backbone=dict(
type='TinyViT',
arch='tinyvit_21m_384',
resolution=(384, 384),
arch='21m',
img_size=(384, 384),
window_size=[12, 12, 24, 12],
out_indices=(3, ),
drop_path_rate=0.1,
gap_before_final_norm=True,
Expand Down
5 changes: 3 additions & 2 deletions configs/_base_/models/tinyvit/tinyvit-21m-512.py
Original file line number Diff line number Diff line change
Expand Up @@ -3,8 +3,9 @@
type='ImageClassifier',
backbone=dict(
type='TinyViT',
arch='tinyvit_21m_512',
resolution=(512, 512),
arch='21m',
img_size=(512, 512),
window_size=[16, 16, 32, 16],
out_indices=(3, ),
drop_path_rate=0.1,
gap_before_final_norm=True,
Expand Down
5 changes: 3 additions & 2 deletions configs/_base_/models/tinyvit/tinyvit-5m-224.py
Original file line number Diff line number Diff line change
Expand Up @@ -3,8 +3,9 @@
type='ImageClassifier',
backbone=dict(
type='TinyViT',
arch='tinyvit_5m_224',
resolution=(224, 224),
arch='5m',
img_size=(224, 224),
window_size=[7, 7, 14, 7],
out_indices=(3, ),
drop_path_rate=0.0,
gap_before_final_norm=True,
Expand Down
2 changes: 1 addition & 1 deletion configs/tinyvit/tinyvit-21m-512-distill_8xb256_in1k.py
Original file line number Diff line number Diff line change
Expand Up @@ -16,6 +16,6 @@
dict(type='PackClsInputs'),
]

val_dataloader = dict(dataset=dict(pipeline=test_pipeline))
val_dataloader = dict(batch_size=16, dataset=dict(pipeline=test_pipeline))

test_dataloader = val_dataloader
44 changes: 15 additions & 29 deletions mmcls/models/backbones/tinyvit.py
Original file line number Diff line number Diff line change
Expand Up @@ -157,7 +157,7 @@ def __init__(self,
act_cfg=dict(type='GELU')):
super().__init__()

self.resolution = resolution
self.img_size = resolution

self.act = build_activation_layer(act_cfg)
self.conv1 = ConvBN2d(in_channels, out_channels, kernel_size=1)
Expand All @@ -172,7 +172,7 @@ def __init__(self,

def forward(self, x):
if len(x.shape) == 3:
H, W = self.resolution
H, W = self.img_size
B = x.shape[0]
x = x.view(B, H, W, -1).permute(0, 3, 1, 2)
x = self.conv1(x)
Expand Down Expand Up @@ -391,7 +391,7 @@ def __init__(self,
act_cfg=dict(type='GELU')):
super().__init__()
self.in_channels = in_channels
self.resolution = resolution
self.img_size = resolution
self.num_heads = num_heads
assert window_size > 0, 'window_size must be greater than 0'
self.window_size = window_size
Expand Down Expand Up @@ -428,7 +428,7 @@ def __init__(self,
groups=in_channels)

def forward(self, x):
H, W = self.resolution
H, W = self.img_size
B, L, C = x.shape
assert L == H * W, 'input feature has wrong size'
res_x = x
Expand Down Expand Up @@ -596,41 +596,27 @@ class TinyViT(BaseBackbone):
Default: None.
"""
arch_settings = {
'tinyvit_5m_224': {
'5m': {
'channels': [64, 128, 160, 320],
'num_heads': [2, 4, 5, 10],
'window_sizes': [7, 7, 14, 7],
'depths': [2, 2, 6, 2],
},
'tinyvit_11m_224': {
'11m': {
'channels': [64, 128, 256, 448],
'num_heads': [2, 4, 8, 14],
'window_sizes': [7, 7, 14, 7],
'depths': [2, 2, 6, 2],
},
'tinyvit_21m_224': {
'21m': {
'channels': [96, 192, 384, 576],
'num_heads': [3, 6, 12, 18],
'window_sizes': [7, 7, 14, 7],
'depths': [2, 2, 6, 2],
},
'tinyvit_21m_384': {
'channels': [96, 192, 384, 576],
'num_heads': [3, 6, 12, 18],
'window_sizes': [12, 12, 24, 12],
'depths': [2, 2, 6, 2],
},
'tinyvit_21m_512': {
'channels': [96, 192, 384, 576],
'num_heads': [3, 6, 12, 18],
'window_sizes': [16, 16, 32, 16],
'depths': [2, 2, 6, 2],
}
}

def __init__(self,
arch='tinyvit_5m_224',
resolution=(224, 224),
img_size=(224, 224),
window_size=[7, 7, 14, 7],
in_channels=3,
mlp_ratio=4.,
drop_rate=0.,
Expand All @@ -654,14 +640,14 @@ def __init__(self,
arch = self.arch_settings[arch]
elif isinstance(arch, dict):
assert 'channels' in arch and 'num_heads' in arch and \
'window_sizes' in arch and 'depths' in arch, \
f'Th arch dict must have "channels", "num_heads", ' \
f'"window_sizes" keys, but got {arch.keys()}'
'depths' in arch, 'The arch dict must have' \
f'"channels", "num_heads", "window_sizes" ' \
f'keys, but got {arch.keys()}'

self.channels = arch['channels']
self.num_heads = arch['num_heads']
self.widow_sizes = arch['window_sizes']
self.resolution = resolution
self.widow_sizes = window_size
self.img_size = img_size
self.depths = arch['depths']

self.num_stages = len(self.channels)
Expand All @@ -684,7 +670,7 @@ def __init__(self,
self.patch_embed = PatchEmbed(
in_channels=in_channels,
embed_dim=self.channels[0],
resolution=self.resolution,
resolution=self.img_size,
act_cfg=dict(type='GELU'))
patches_resolution = self.patch_embed.patches_resolution

Expand Down

0 comments on commit a562ad4

Please sign in to comment.