[20230608 v0.5.1] Fix compatibility issues
horrible-dong committed Jun 8, 2023
1 parent c2ecb57 commit 3386bd1
Showing 11 changed files with 45 additions and 21 deletions.
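
All eleven files follow the same pattern: timm 0.9 moved its layer utilities from timm.models.layers to timm.layers and re-exports helpers such as named_apply and register_notrace_function directly from timm.models, so every import is updated and the dependency is pinned to a version where the new paths exist. A minimal sketch of the migration against the 0.9.x layout; the try/except fallback is illustrative only and not part of this commit, which pins timm==0.9.2 and uses the new paths unconditionally:

# Sketch of the import migration this commit applies (timm 0.9.x layout).
# The fallback branch is illustrative; the commit itself pins timm==0.9.2.
try:
    from timm.layers import DropPath, Mlp, to_2tuple, trunc_normal_  # timm >= 0.9
    from timm.models import named_apply                              # re-exported helper
except ImportError:
    from timm.models.layers import DropPath, Mlp, to_2tuple, trunc_normal_  # timm < 0.9
    from timm.models.helpers import named_apply
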
2 changes: 1 addition & 1 deletion qtcls/models/cait.py
@@ -7,7 +7,7 @@

import torch
import torch.nn as nn
-from timm.models.layers import PatchEmbed, Mlp, DropPath, trunc_normal_
+from timm.layers import PatchEmbed, Mlp, DropPath, trunc_normal_

__all__ = [
'Cait',
6 changes: 4 additions & 2 deletions qtcls/models/deit.py
@@ -5,8 +5,10 @@
# -------------------------------------------------------------------------------

import torch
-from timm.models.vision_transformer import VisionTransformer, trunc_normal_
-from torch import nn as nn
+import torch.nn as nn
+from timm.layers import trunc_normal_
+
+from .vision_transformer_timm import VisionTransformer

__all__ = [
'VisionTransformerDistilled',
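
Besides the layer-import move, deit.py now takes VisionTransformer from the repo's own vision_transformer_timm copy instead of from timm, insulating the distilled model from upstream changes to that class. A hypothetical sanity check of the resulting hierarchy; the class locations match this commit, but the assertion itself is illustrative:

# Hypothetical check: the distilled DeiT model should now build on the
# local VisionTransformer copy rather than timm's (illustrative only).
from qtcls.models.deit import VisionTransformerDistilled
from qtcls.models.vision_transformer_timm import VisionTransformer

assert issubclass(VisionTransformerDistilled, VisionTransformer)
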
3 changes: 1 addition & 2 deletions qtcls/models/levit.py
@@ -8,8 +8,7 @@

import torch
import torch.nn as nn
-from timm.models.layers import to_ntuple, get_act_layer
-from timm.models.vision_transformer import trunc_normal_
+from timm.layers import to_ntuple, get_act_layer, trunc_normal_

__all__ = [
'Levit',
4 changes: 2 additions & 2 deletions qtcls/models/mlp_mixer.py
@@ -8,8 +8,8 @@

import torch
import torch.nn as nn
-from timm.models.helpers import named_apply
-from timm.models.layers import PatchEmbed, Mlp, GluMlp, GatedMlp, DropPath, lecun_normal_, to_2tuple
+from timm.layers import PatchEmbed, Mlp, GluMlp, GatedMlp, DropPath, lecun_normal_, to_2tuple
+from timm.models import named_apply

__all__ = [
'MlpMixer',
2 changes: 1 addition & 1 deletion qtcls/models/pvt.py
@@ -7,7 +7,7 @@
import torch
import torch.nn as nn
import torch.nn.functional as F
-from timm.models.layers import DropPath, to_2tuple, trunc_normal_
+from timm.layers import DropPath, to_2tuple, trunc_normal_

__all__ = [
'PyramidVisionTransformer',
2 changes: 1 addition & 1 deletion qtcls/models/swin_transformer.py
@@ -8,7 +8,7 @@

import torch
import torch.nn as nn
-from timm.models.layers import DropPath, to_2tuple, trunc_normal_
+from timm.layers import DropPath, to_2tuple, trunc_normal_

__all__ = [
'SwinTransformer',
4 changes: 2 additions & 2 deletions qtcls/models/swin_transformer_v2.py
@@ -10,8 +10,8 @@
import torch
import torch.nn as nn
import torch.nn.functional as F
-from timm.models.fx_features import register_notrace_function
-from timm.models.layers import PatchEmbed, Mlp, DropPath, to_2tuple, trunc_normal_, _assert
+from timm.layers import PatchEmbed, Mlp, DropPath, to_2tuple, trunc_normal_, _assert
+from timm.models import register_notrace_function

__all__ = [
'SwinTransformerV2',
4 changes: 1 addition & 3 deletions qtcls/models/tnt.py
@@ -6,9 +6,7 @@

import torch
import torch.nn as nn
-from timm.models.layers import DropPath, Mlp, trunc_normal_
-from timm.models.layers import _assert
-from timm.models.layers.helpers import to_2tuple
+from timm.layers import _assert, to_2tuple, DropPath, Mlp, trunc_normal_

__all__ = [
'TNT',
31 changes: 29 additions & 2 deletions qtcls/models/twins.py
@@ -9,8 +9,7 @@
import torch
import torch.nn as nn
import torch.nn.functional as F
-from timm.models.layers import DropPath, Mlp, to_2tuple, trunc_normal_
-from timm.models.vision_transformer import Attention
+from timm.layers import DropPath, Mlp, to_2tuple, trunc_normal_

__all__ = [
'Twins',
@@ -25,6 +24,34 @@
Size_ = Tuple[int, int]


+class Attention(nn.Module):
+    def __init__(self, dim, num_heads=8, qkv_bias=False, attn_drop=0., proj_drop=0.):
+        super().__init__()
+        assert dim % num_heads == 0, 'dim should be divisible by num_heads'
+        self.num_heads = num_heads
+        head_dim = dim // num_heads
+        self.scale = head_dim ** -0.5
+
+        self.qkv = nn.Linear(dim, dim * 3, bias=qkv_bias)
+        self.attn_drop = nn.Dropout(attn_drop)
+        self.proj = nn.Linear(dim, dim)
+        self.proj_drop = nn.Dropout(proj_drop)
+
+    def forward(self, x):
+        B, N, C = x.shape
+        qkv = self.qkv(x).reshape(B, N, 3, self.num_heads, C // self.num_heads).permute(2, 0, 3, 1, 4)
+        q, k, v = qkv.unbind(0)  # make torchscript happy (cannot use tensor as tuple)
+
+        attn = (q @ k.transpose(-2, -1)) * self.scale
+        attn = attn.softmax(dim=-1)
+        attn = self.attn_drop(attn)
+
+        x = (attn @ v).transpose(1, 2).reshape(B, N, C)
+        x = self.proj(x)
+        x = self.proj_drop(x)
+        return x


class LocallyGroupedAttn(nn.Module):
""" LSA: self attention within a group
"""
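
twins.py is the one file that gains code rather than just swapping import paths: timm 0.9 no longer exposes timm.models.vision_transformer.Attention, so the commit inlines the standard ViT attention block above. A hypothetical smoke test; the shapes are illustrative, not from the commit:

# Hypothetical smoke test for the vendored Attention block; dim and
# sequence length are illustrative values.
import torch
from qtcls.models.twins import Attention

attn = Attention(dim=64, num_heads=8)
x = torch.randn(2, 196, 64)          # (batch, tokens, channels)
assert attn(x).shape == x.shape      # attention preserves (B, N, C)
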
6 changes: 2 additions & 4 deletions qtcls/models/vision_transformer_timm.py
@@ -8,10 +8,8 @@

import torch
import torch.nn as nn
-from timm.models.helpers import named_apply
-from timm.models.layers import trunc_normal_, lecun_normal_
-from timm.models.layers.helpers import to_2tuple
-from timm.models.layers.trace_utils import _assert
+from timm.layers import _assert, to_2tuple, trunc_normal_, lecun_normal_
+from timm.models import named_apply

__all__ = [
'VisionTransformer',
2 changes: 1 addition & 1 deletion requirements.txt
@@ -5,4 +5,4 @@ Pillow
scikit-learn
scipy
termcolor
-timm
+timm==0.9.2
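
Pinning timm==0.9.2 ensures the timm.layers and timm.models paths used throughout this commit exist at install time. A quick check of the installed package, assuming the pinned requirements:

# Verify the pinned timm exposes the post-0.9 import paths used above.
import timm
from timm.layers import trunc_normal_  # moved here from timm.models.layers
from timm.models import named_apply    # re-exported at the package level

print(timm.__version__)  # 0.9.2 under the pinned requirements
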
