1. Model architecture setup and evaluation data flow (ADE20K, 150 classes)
CATSeg setup:

- `backbone`: `D2SwinTransformer` -> `SwinTransformer` -> `BasicLayer` (2) -> `SwinTransformerBlock` -> `WindowAttention`
- `sem_seg_head`: `CATSegHead.from_config` -> `CATSegPredictor`
  - Load the CLIP model -> load the text templates -> `class_embeddings(self.class_texts, prompt_templates, clip_model)` (a sketch of this step follows the setup outline below):
    - For each class: BPE-encode the class name inserted into every template and store the results in `texts` -> (80 (number of templates), 77 (token sequence length)).
    - CLIP-encode `texts`:
      - `texts` go through `token_embedding` (an `nn.Embedding`) -> (80, 77, 768 (hidden_dim))
      - `texts` go through 12 layers of `ResidualAttentionBlock` -> (80, 77, 768)
      - take the features of `texts` at the `eot_token` position -> (80, 768)
    - Do the above for all classes -> (150 (number of test classes), 80, 768)
- `Aggregator` -> 2 layers of `AggregatorLayer`:
  - `swin_block`: `SwinTransformerBlockWrapper`

```python
class SwinTransformerBlockWrapper(nn.Module):
    def __init__(self, dim, appearance_guidance_dim, input_resolution, nheads=4, window_size=5):
        super().__init__()
        self.block_1 = SwinTransformerBlock(dim, appearance_guidance_dim, input_resolution,
                                            num_heads=nheads, head_dim=None, window_size=window_size, shift_size=0)
        self.block_2 = SwinTransformerBlock(dim, appearance_guidance_dim, input_resolution,
                                            num_heads=nheads, head_dim=None, window_size=window_size, shift_size=window_size // 2)
        self.guidance_norm = nn.LayerNorm(appearance_guidance_dim) if appearance_guidance_dim > 0 else None
```
  - `attention`: `ClassTransformerLayer`

```python
class ClassTransformerLayer(nn.Module):
    def __init__(self, hidden_dim=64, guidance_dim=64, nheads=8, attention_type='linear', pooling_size=(4, 4)) -> None:
        super().__init__()
        self.pool = nn.AvgPool2d(pooling_size)
        self.attention = AttentionLayer(hidden_dim, guidance_dim, nheads=nheads, attention_type=attention_type)
        self.MLP = nn.Sequential(
            nn.Linear(hidden_dim, hidden_dim * 4),
            nn.ReLU(),
            nn.Linear(hidden_dim * 4, hidden_dim)
        )
        self.norm1 = nn.LayerNorm(hidden_dim)
        self.norm2 = nn.LayerNorm(hidden_dim)
```
class LinearAttention(nn.Module): def __init__(self, eps=1e-6): super().__init__() self.feature_map = elu_feature_map self.eps = eps def forward(self, queries, keys, values): """ Multi-Head linear attention proposed in "Transformers are RNNs" Args: queries: [N, L, H, D] keys: [N, S, H, D] values: [N, S, H, D] q_mask: [N, L] kv_mask: [N, S] Returns: queried_values: (N, L, H, D) """ Q = self.feature_map(queries) K = self.feature_map(keys) v_length = values.size(1) values = values / v_length # prevent fp16 overflow KV = torch.einsum("nshd,nshv->nhdv", K, values) # (S,D)' @ S,V Z = 1 / (torch.einsum("nlhd,nhd->nlh", Q, K.sum(dim=1)) + self.eps) queried_values = torch.einsum("nlhd,nhdv,nlh->nlhv", Q, KV, Z) * v_length return queried_values.contiguous() class FullAttention(nn.Module): def __init__(self, use_dropout=False, attention_dropout=0.1): super().__init__() self.use_dropout = use_dropout self.dropout = nn.Dropout(attention_dropout) def forward(self, queries, keys, values, q_mask=None, kv_mask=None): """ Multi-head scaled dot-product attention, a.k.a full attention. Args: queries: [N, L, H, D] keys: [N, S, H, D] values: [N, S, H, D] q_mask: [N, L] kv_mask: [N, S] Returns: queried_values: (N, L, H, D) """ # Compute the unnormalized attention and apply the masks QK = torch.einsum("nlhd,nshd->nlsh", queries, keys) if kv_mask is not None: QK.masked_fill_(~(q_mask[:, :, None, None] * kv_mask[:, None, :, None]), float('-inf')) # Compute the attention and the weighted average softmax_temp = 1. / queries.size(3)**.5 # sqrt(D) A = torch.softmax(softmax_temp * QK, dim=2) if self.use_dropout: A = self.dropout(A) queried_values = torch.einsum("nlsh,nshd->nlhd", A, values) return queried_values.contiguous() class AttentionLayer(nn.Module): def __init__(self, hidden_dim, guidance_dim, nheads=8, attention_type='linear'): super().__init__() self.nheads = nheads self.q = nn.Linear(hidden_dim + guidance_dim, hidden_dim) self.k = nn.Linear(hidden_dim + guidance_dim, hidden_dim) self.v = nn.Linear(hidden_dim, hidden_dim) if attention_type == 'linear': self.attention = LinearAttention() elif attention_type == 'full': self.attention = FullAttention() else: raise NotImplementedError
  - The remaining modules of `Aggregator`:

```python
self.guidance_projection = nn.Sequential(
    nn.Conv2d(appearance_guidance_dim, appearance_guidance_proj_dim, kernel_size=3, stride=1, padding=1),
    nn.ReLU(),
) if appearance_guidance_dim > 0 else None
self.text_guidance_projection = nn.Sequential(
    nn.Linear(text_guidance_dim, text_guidance_proj_dim),
    nn.ReLU(),
) if text_guidance_dim > 0 else None
self.decoder_guidance_projection = nn.ModuleList([
    nn.Sequential(
        nn.Conv2d(d, dp, kernel_size=3, stride=1, padding=1),
        nn.ReLU(),
    ) for d, dp in zip(decoder_guidance_dims, decoder_guidance_proj_dims)
]) if decoder_guidance_dims[0] > 0 else None
self.decoder1 = Up(hidden_dim, decoder_dims[0], decoder_guidance_proj_dims[0])
self.decoder2 = Up(decoder_dims[0], decoder_dims[1], decoder_guidance_proj_dims[1])
self.head = nn.Conv2d(decoder_dims[1], 1, kernel_size=3, stride=1, padding=1)
```

  - `Up`:

```python
class Up(nn.Module):
    """Upscaling then double conv"""
    def __init__(self, in_channels, out_channels, guidance_channels):
        super().__init__()
        self.up = nn.ConvTranspose2d(in_channels, in_channels - guidance_channels, kernel_size=2, stride=2)
        self.conv = DoubleConv(in_channels, out_channels)

    def forward(self, x, guidance=None):
        x = self.up(x)
        if guidance is not None:
            T = x.size(0) // guidance.size(0)
            guidance = repeat(guidance, "B C H W -> (B T) C H W", T=T)
            x = torch.cat([x, guidance], dim=1)
        return self.conv(x)
```
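As referenced in the setup outline, here is a minimal sketch of the per-class text-embedding step, assuming the OpenAI `clip` package and the ViT-L/14@336px checkpoint (consistent with the 768-dim text features and 336-px CLIP crops in these notes). The two templates and class names below are placeholders for the 80 templates and 150 ADE20K classes used by CAT-Seg:

```python
import clip
import torch

# Placeholder templates/classes; CAT-Seg uses 80 templates and the 150 ADE20K class names.
prompt_templates = ["a photo of a {}.", "a photo of the {}."]
class_texts = ["wall", "building"]

model, _ = clip.load("ViT-L/14@336px", device="cpu")
with torch.no_grad():
    per_class = []
    for name in class_texts:
        texts = clip.tokenize([t.format(name) for t in prompt_templates])  # (num_templates, 77)
        emb = model.encode_text(texts)  # token_embedding -> 12x ResidualAttentionBlock -> eot pooling: (num_templates, 768)
        per_class.append(emb)
    text_features = torch.stack(per_class)  # (num_classes, num_templates, 768)
```

Note that the per-template embeddings are kept (not averaged), which is what gives the (150, 80, 768) tensor in the notes above.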
CATSeg forward for each image:

- `image` (3, 640, 854) -> `self.inference_sliding_window(batched_inputs)`
- `image = F.interpolate(images[0].unsqueeze(0), size=out_res, mode='bilinear', align_corners=False).squeeze()` -> (3, 640, 640)
- `image = rearrange(unfold(image), "(C H W) L-> L C H W", C=3, H=kernel)` -> (442368 (= 3x384x384, the kernel size), 4 (number of such patches)) -> (4, 3, 384, 384)
- `global_image = F.interpolate(images[0].unsqueeze(0), size=(kernel, kernel), mode='bilinear', align_corners=False)`; `image = torch.cat((image, global_image), dim=0)` -> (5, 3, 384, 384); this batch of 5 is the one the shapes below refer to!
- The three main calls:

  ```python
  features = self.backbone(images)  # features: a dictionary with length of 3
  clip_features = self.sem_seg_head.predictor.clip_model.encode_image(clip_images, dense=True)  # clip_images: (5, 3, 336, 336)
  outputs = self.sem_seg_head(clip_features, features)  # outputs: (5, 150, 96, 96)
  ```

  - `features`: a dict of 3 feature maps: `res2` (5, 128, 96, 96), `res3` (5, 256, 48, 48), `res4` (5, 512, 24, 24)
  - `clip_features`: (5, 577 (= 24x24 + 1), 768)
  - `outputs`: (5, 150 (number of classes), 96, 96)
- After the three steps, `outputs` (5, 150, 96, 96) is stitched back to full resolution (see the unfold/fold sketch after this block):

  ```python
  # outputs: (5,150,96,96) -> (5,150,384,384), back to the 384x384 crop size used above
  outputs = F.interpolate(outputs, size=kernel, mode="bilinear", align_corners=False)
  outputs = outputs.sigmoid()
  global_output = outputs[-1:]
  global_output = F.interpolate(global_output, size=out_res, mode='bilinear', align_corners=False,)
  outputs = outputs[:-1]  # -> (4,150,384,384)
  outputs = fold(outputs.flatten(1).T) / fold(unfold(torch.ones([1] + out_res, device=self.device)))
  # numerator:   (4,22118400) -> (22118400,4) -> (150,640,640)
  # denominator: (1,640,640)
  # This step normalizes the overlap counts introduced by the unfold/fold operations.
  outputs = (outputs + global_output) / 2.  # -> (1,150,640,640)
  height = batched_inputs[0].get("height", out_res[0])
  width = batched_inputs[0].get("width", out_res[1])
  output = sem_seg_postprocess(outputs[0], out_res, height, width)  # -> (150,512,683)
  ```
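A self-contained sketch of the unfold/fold bookkeeping used above. The kernel of 384 and stride of 256 are assumptions consistent with the 640x640 resize yielding 4 crops; only 5 classes are used instead of 150 to keep it light:

```python
import torch
from einops import rearrange

kernel, stride, out_res = 384, 256, (640, 640)
unfold = torch.nn.Unfold(kernel_size=kernel, stride=stride)
fold = torch.nn.Fold(out_res, kernel_size=kernel, stride=stride)

image = torch.rand(1, 3, *out_res)
crops = rearrange(unfold(image), "B (C H W) L -> (B L) C H W", C=3, H=kernel, W=kernel)  # (4, 3, 384, 384)

preds = torch.rand(4, 5, kernel, kernel)               # stand-in per-crop class maps (150 classes in the notes)
stitched = fold(preds.reshape(4, -1).T.unsqueeze(0))   # sum of overlapping crops: (1, 5, 640, 640)
coverage = fold(unfold(torch.ones(1, 1, *out_res)))    # how many crops cover each pixel: (1, 1, 640, 640)
full = stitched / coverage                             # overlap-normalized, as in the numerator/denominator above
print(crops.shape, full.shape)  # torch.Size([4, 3, 384, 384]) torch.Size([1, 5, 640, 640])
```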
The workflow within the three main steps

- The workflow within `features = self.backbone(images)` (`features`: a dictionary with length of 3):
```python
class D2SwinTransformer(SwinTransformer, Backbone):
    def forward(self, x):  # x -> (5,3,384,384)
        """
        Args:
            x: Tensor of shape (N,C,H,W). H, W must be a multiple of ``self.size_divisibility``.
        Returns:
            dict[str->Tensor]: names and the corresponding features
        """
        assert (
            x.dim() == 4
        ), f"SwinTransformer takes an input of shape (N, C, H, W). Got {x.shape} instead!"
        outputs = {}
        y = super().forward(x)
        # y -> a dict of three tensors: {(5,128,96,96), (5,256,48,48), (5,512,24,24)}; outputs has the same shapes
        for k in y.keys():
            if k in self._out_features:
                outputs[k] = y[k]
        return outputs
```
```python
class SwinTransformer(nn.Module):
    def forward(self, x):
        """Forward function."""
        x = self.patch_embed(x)  # (5,3,384,384) -> (5,128,96,96), explained below

        Wh, Ww = x.size(2), x.size(3)
        if self.ape:
            # interpolate the position embedding to the corresponding size
            absolute_pos_embed = F.interpolate(
                self.absolute_pos_embed, size=(Wh, Ww), mode="bicubic"
            )
            x = (x + absolute_pos_embed).flatten(2).transpose(1, 2)  # B Wh*Ww C
        else:
            x = x.flatten(2).transpose(1, 2)  # -> (5,9216(96x96),128)
        x = self.pos_drop(x)  # no change: (5,9216,128)

        outs = {}
        for i in range(self.num_layers):
            layer = self.layers[i]
            x_out, H, W, x, Wh, Ww = layer(x, Wh, Ww)
            # x_out -> (5,9216,128) / (5,2304,256) / (5,576,512)
            if i in self.out_indices:
                norm_layer = getattr(self, f"norm{i}")
                x_out = norm_layer(x_out)  # no change: (5,9216,128) / (5,2304,256) / (5,576,512)
                out = x_out.view(-1, H, W, self.num_features[i]).permute(0, 3, 1, 2).contiguous()
                # out: (5,128,96,96) / (5,256,48,48) / (5,512,24,24)
                outs["res{}".format(i + 2)] = out
        return outs
```
```python
class PatchEmbed(nn.Module):
    def forward(self, x):
        """Forward function."""
        # padding
        _, _, H, W = x.size()
        if W % self.patch_size[1] != 0:
            x = F.pad(x, (0, self.patch_size[1] - W % self.patch_size[1]))
        if H % self.patch_size[0] != 0:
            x = F.pad(x, (0, 0, 0, self.patch_size[0] - H % self.patch_size[0]))

        x = self.proj(x)  # B C Wh Ww: (5,3,384,384) -> (5,128,96,96)
        if self.norm is not None:
            Wh, Ww = x.size(2), x.size(3)
            x = x.flatten(2).transpose(1, 2)
            x = self.norm(x)
            x = x.transpose(1, 2).view(-1, self.embed_dim, Wh, Ww)

        return x  # (5,128,96,96), returned to SwinTransformer.forward above
```
-
self.layers: ModuleList( (0): BasicLayer( (blocks): ModuleList( (0): SwinTransformerBlock( (norm1): LayerNorm((128,), eps=1e-05, elementwise_affine=True) (attn): WindowAttention( (qkv): Linear(in_features=128, out_features=384, bias=True) (attn_drop): Dropout(p=0.0, inplace=False) (proj): Linear(in_features=128, out_features=128, bias=True) (proj_drop): Dropout(p=0.0, inplace=False) (softmax): Softmax(dim=-1) ) (drop_path): Identity() (norm2): LayerNorm((128,), eps=1e-05, elementwise_affine=True) (mlp): Mlp( (fc1): Linear(in_features=128, out_features=512, bias=True) (act): GELU(approximate='none') (fc2): Linear(in_features=512, out_features=128, bias=True) (drop): Dropout(p=0.0, inplace=False) ) ) (1): SwinTransformerBlock( (norm1): LayerNorm((128,), eps=1e-05, elementwise_affine=True) (attn): WindowAttention( (qkv): Linear(in_features=128, out_features=384, bias=True) (attn_drop): Dropout(p=0.0, inplace=False) (proj): Linear(in_features=128, out_features=128, bias=True) (proj_drop): Dropout(p=0.0, inplace=False) (softmax): Softmax(dim=-1) ) (drop_path): DropPath(drop_prob=0.014) (norm2): LayerNorm((128,), eps=1e-05, elementwise_affine=True) (mlp): Mlp( (fc1): Linear(in_features=128, out_features=512, bias=True) (act): GELU(approximate='none') (fc2): Linear(in_features=512, out_features=128, bias=True) (drop): Dropout(p=0.0, inplace=False) ) ) ) (downsample): PatchMerging( (reduction): Linear(in_features=512, out_features=256, bias=False) (norm): LayerNorm((512,), eps=1e-05, elementwise_affine=True) ) ) (1): BasicLayer( (blocks): ModuleList( (0): SwinTransformerBlock( (norm1): LayerNorm((256,), eps=1e-05, elementwise_affine=True) (attn): WindowAttention( (qkv): Linear(in_features=256, out_features=768, bias=True) (attn_drop): Dropout(p=0.0, inplace=False) (proj): Linear(in_features=256, out_features=256, bias=True) (proj_drop): Dropout(p=0.0, inplace=False) (softmax): Softmax(dim=-1) ) (drop_path): DropPath(drop_prob=0.029) (norm2): LayerNorm((256,), eps=1e-05, elementwise_affine=True) (mlp): Mlp( (fc1): Linear(in_features=256, out_features=1024, bias=True) (act): GELU(approximate='none') (fc2): Linear(in_features=1024, out_features=256, bias=True) (drop): Dropout(p=0.0, inplace=False) ) ) (1): SwinTransformerBlock( (norm1): LayerNorm((256,), eps=1e-05, elementwise_affine=True) (attn): WindowAttention( (qkv): Linear(in_features=256, out_features=768, bias=True) (attn_drop): Dropout(p=0.0, inplace=False) (proj): Linear(in_features=256, out_features=256, bias=True) (proj_drop): Dropout(p=0.0, inplace=False) (softmax): Softmax(dim=-1) ) (drop_path): DropPath(drop_prob=0.043) (norm2): LayerNorm((256,), eps=1e-05, elementwise_affine=True) (mlp): Mlp( (fc1): Linear(in_features=256, out_features=1024, bias=True) (act): GELU(approximate='none') (fc2): Linear(in_features=1024, out_features=256, bias=True) (drop): Dropout(p=0.0, inplace=False) ) ) ) (downsample): PatchMerging( (reduction): Linear(in_features=1024, out_features=512, bias=False) (norm): LayerNorm((1024,), eps=1e-05, elementwise_affine=True) ) ) (2): BasicLayer( (blocks): ModuleList( (0): SwinTransformerBlock( (norm1): LayerNorm((512,), eps=1e-05, elementwise_affine=True) (attn): WindowAttention( (qkv): Linear(in_features=512, out_features=1536, bias=True) (attn_drop): Dropout(p=0.0, inplace=False) (proj): Linear(in_features=512, out_features=512, bias=True) (proj_drop): Dropout(p=0.0, inplace=False) (softmax): Softmax(dim=-1) ) (drop_path): DropPath(drop_prob=0.057) (norm2): LayerNorm((512,), eps=1e-05, elementwise_affine=True) 
(mlp): Mlp( (fc1): Linear(in_features=512, out_features=2048, bias=True) (act): GELU(approximate='none') (fc2): Linear(in_features=2048, out_features=512, bias=True) (drop): Dropout(p=0.0, inplace=False) ) ) (1): SwinTransformerBlock( (norm1): LayerNorm((512,), eps=1e-05, elementwise_affine=True) (attn): WindowAttention( (qkv): Linear(in_features=512, out_features=1536, bias=True) (attn_drop): Dropout(p=0.0, inplace=False) (proj): Linear(in_features=512, out_features=512, bias=True) (proj_drop): Dropout(p=0.0, inplace=False) (softmax): Softmax(dim=-1) ) (drop_path): DropPath(drop_prob=0.071) (norm2): LayerNorm((512,), eps=1e-05, elementwise_affine=True) (mlp): Mlp( (fc1): Linear(in_features=512, out_features=2048, bias=True) (act): GELU(approximate='none') (fc2): Linear(in_features=2048, out_features=512, bias=True) (drop): Dropout(p=0.0, inplace=False) ) ) (2): SwinTransformerBlock( (norm1): LayerNorm((512,), eps=1e-05, elementwise_affine=True) (attn): WindowAttention( (qkv): Linear(in_features=512, out_features=1536, bias=True) (attn_drop): Dropout(p=0.0, inplace=False) (proj): Linear(in_features=512, out_features=512, bias=True) (proj_drop): Dropout(p=0.0, inplace=False) (softmax): Softmax(dim=-1) ) (drop_path): DropPath(drop_prob=0.086) (norm2): LayerNorm((512,), eps=1e-05, elementwise_affine=True) (mlp): Mlp( (fc1): Linear(in_features=512, out_features=2048, bias=True) (act): GELU(approximate='none') (fc2): Linear(in_features=2048, out_features=512, bias=True) (drop): Dropout(p=0.0, inplace=False) ) ) (3): SwinTransformerBlock( (norm1): LayerNorm((512,), eps=1e-05, elementwise_affine=True) (attn): WindowAttention( (qkv): Linear(in_features=512, out_features=1536, bias=True) (attn_drop): Dropout(p=0.0, inplace=False) (proj): Linear(in_features=512, out_features=512, bias=True) (proj_drop): Dropout(p=0.0, inplace=False) (softmax): Softmax(dim=-1) ) (drop_path): DropPath(drop_prob=0.100) (norm2): LayerNorm((512,), eps=1e-05, elementwise_affine=True) (mlp): Mlp( (fc1): Linear(in_features=512, out_features=2048, bias=True) (act): GELU(approximate='none') (fc2): Linear(in_features=2048, out_features=512, bias=True) (drop): Dropout(p=0.0, inplace=False) ) ) (4): SwinTransformerBlock( (norm1): LayerNorm((512,), eps=1e-05, elementwise_affine=True) (attn): WindowAttention( (qkv): Linear(in_features=512, out_features=1536, bias=True) (attn_drop): Dropout(p=0.0, inplace=False) (proj): Linear(in_features=512, out_features=512, bias=True) (proj_drop): Dropout(p=0.0, inplace=False) (softmax): Softmax(dim=-1) ) (drop_path): DropPath(drop_prob=0.114) (norm2): LayerNorm((512,), eps=1e-05, elementwise_affine=True) (mlp): Mlp( (fc1): Linear(in_features=512, out_features=2048, bias=True) (act): GELU(approximate='none') (fc2): Linear(in_features=2048, out_features=512, bias=True) (drop): Dropout(p=0.0, inplace=False) ) ) (5): SwinTransformerBlock( (norm1): LayerNorm((512,), eps=1e-05, elementwise_affine=True) (attn): WindowAttention( (qkv): Linear(in_features=512, out_features=1536, bias=True) (attn_drop): Dropout(p=0.0, inplace=False) (proj): Linear(in_features=512, out_features=512, bias=True) (proj_drop): Dropout(p=0.0, inplace=False) (softmax): Softmax(dim=-1) ) (drop_path): DropPath(drop_prob=0.129) (norm2): LayerNorm((512,), eps=1e-05, elementwise_affine=True) (mlp): Mlp( (fc1): Linear(in_features=512, out_features=2048, bias=True) (act): GELU(approximate='none') (fc2): Linear(in_features=2048, out_features=512, bias=True) (drop): Dropout(p=0.0, inplace=False) ) ) (6): SwinTransformerBlock( (norm1): 
LayerNorm((512,), eps=1e-05, elementwise_affine=True) (attn): WindowAttention( (qkv): Linear(in_features=512, out_features=1536, bias=True) (attn_drop): Dropout(p=0.0, inplace=False) (proj): Linear(in_features=512, out_features=512, bias=True) (proj_drop): Dropout(p=0.0, inplace=False) (softmax): Softmax(dim=-1) ) (drop_path): DropPath(drop_prob=0.143) (norm2): LayerNorm((512,), eps=1e-05, elementwise_affine=True) (mlp): Mlp( (fc1): Linear(in_features=512, out_features=2048, bias=True) (act): GELU(approximate='none') (fc2): Linear(in_features=2048, out_features=512, bias=True) (drop): Dropout(p=0.0, inplace=False) ) ) (7): SwinTransformerBlock( (norm1): LayerNorm((512,), eps=1e-05, elementwise_affine=True) (attn): WindowAttention( (qkv): Linear(in_features=512, out_features=1536, bias=True) (attn_drop): Dropout(p=0.0, inplace=False) (proj): Linear(in_features=512, out_features=512, bias=True) (proj_drop): Dropout(p=0.0, inplace=False) (softmax): Softmax(dim=-1) ) (drop_path): DropPath(drop_prob=0.157) (norm2): LayerNorm((512,), eps=1e-05, elementwise_affine=True) (mlp): Mlp( (fc1): Linear(in_features=512, out_features=2048, bias=True) (act): GELU(approximate='none') (fc2): Linear(in_features=2048, out_features=512, bias=True) (drop): Dropout(p=0.0, inplace=False) ) ) (8): SwinTransformerBlock( (norm1): LayerNorm((512,), eps=1e-05, elementwise_affine=True) (attn): WindowAttention( (qkv): Linear(in_features=512, out_features=1536, bias=True) (attn_drop): Dropout(p=0.0, inplace=False) (proj): Linear(in_features=512, out_features=512, bias=True) (proj_drop): Dropout(p=0.0, inplace=False) (softmax): Softmax(dim=-1) ) (drop_path): DropPath(drop_prob=0.171) (norm2): LayerNorm((512,), eps=1e-05, elementwise_affine=True) (mlp): Mlp( (fc1): Linear(in_features=512, out_features=2048, bias=True) (act): GELU(approximate='none') (fc2): Linear(in_features=2048, out_features=512, bias=True) (drop): Dropout(p=0.0, inplace=False) ) ) (9): SwinTransformerBlock( (norm1): LayerNorm((512,), eps=1e-05, elementwise_affine=True) (attn): WindowAttention( (qkv): Linear(in_features=512, out_features=1536, bias=True) (attn_drop): Dropout(p=0.0, inplace=False) (proj): Linear(in_features=512, out_features=512, bias=True) (proj_drop): Dropout(p=0.0, inplace=False) (softmax): Softmax(dim=-1) ) (drop_path): DropPath(drop_prob=0.186) (norm2): LayerNorm((512,), eps=1e-05, elementwise_affine=True) (mlp): Mlp( (fc1): Linear(in_features=512, out_features=2048, bias=True) (act): GELU(approximate='none') (fc2): Linear(in_features=2048, out_features=512, bias=True) (drop): Dropout(p=0.0, inplace=False) ) ) (10): SwinTransformerBlock( (norm1): LayerNorm((512,), eps=1e-05, elementwise_affine=True) (attn): WindowAttention( (qkv): Linear(in_features=512, out_features=1536, bias=True) (attn_drop): Dropout(p=0.0, inplace=False) (proj): Linear(in_features=512, out_features=512, bias=True) (proj_drop): Dropout(p=0.0, inplace=False) (softmax): Softmax(dim=-1) ) (drop_path): DropPath(drop_prob=0.200) (norm2): LayerNorm((512,), eps=1e-05, elementwise_affine=True) (mlp): Mlp( (fc1): Linear(in_features=512, out_features=2048, bias=True) (act): GELU(approximate='none') (fc2): Linear(in_features=2048, out_features=512, bias=True) (drop): Dropout(p=0.0, inplace=False) ) ) (11): SwinTransformerBlock( (norm1): LayerNorm((512,), eps=1e-05, elementwise_affine=True) (attn): WindowAttention( (qkv): Linear(in_features=512, out_features=1536, bias=True) (attn_drop): Dropout(p=0.0, inplace=False) (proj): Linear(in_features=512, out_features=512, 
bias=True) (proj_drop): Dropout(p=0.0, inplace=False) (softmax): Softmax(dim=-1) ) (drop_path): DropPath(drop_prob=0.214) (norm2): LayerNorm((512,), eps=1e-05, elementwise_affine=True) (mlp): Mlp( (fc1): Linear(in_features=512, out_features=2048, bias=True) (act): GELU(approximate='none') (fc2): Linear(in_features=2048, out_features=512, bias=True) (drop): Dropout(p=0.0, inplace=False) ) ) (12): SwinTransformerBlock( (norm1): LayerNorm((512,), eps=1e-05, elementwise_affine=True) (attn): WindowAttention( (qkv): Linear(in_features=512, out_features=1536, bias=True) (attn_drop): Dropout(p=0.0, inplace=False) (proj): Linear(in_features=512, out_features=512, bias=True) (proj_drop): Dropout(p=0.0, inplace=False) (softmax): Softmax(dim=-1) ) (drop_path): DropPath(drop_prob=0.229) (norm2): LayerNorm((512,), eps=1e-05, elementwise_affine=True) (mlp): Mlp( (fc1): Linear(in_features=512, out_features=2048, bias=True) (act): GELU(approximate='none') (fc2): Linear(in_features=2048, out_features=512, bias=True) (drop): Dropout(p=0.0, inplace=False) ) ) (13): SwinTransformerBlock( (norm1): LayerNorm((512,), eps=1e-05, elementwise_affine=True) (attn): WindowAttention( (qkv): Linear(in_features=512, out_features=1536, bias=True) (attn_drop): Dropout(p=0.0, inplace=False) (proj): Linear(in_features=512, out_features=512, bias=True) (proj_drop): Dropout(p=0.0, inplace=False) (softmax): Softmax(dim=-1) ) (drop_path): DropPath(drop_prob=0.243) (norm2): LayerNorm((512,), eps=1e-05, elementwise_affine=True) (mlp): Mlp( (fc1): Linear(in_features=512, out_features=2048, bias=True) (act): GELU(approximate='none') (fc2): Linear(in_features=2048, out_features=512, bias=True) (drop): Dropout(p=0.0, inplace=False) ) ) (14): SwinTransformerBlock( (norm1): LayerNorm((512,), eps=1e-05, elementwise_affine=True) (attn): WindowAttention( (qkv): Linear(in_features=512, out_features=1536, bias=True) (attn_drop): Dropout(p=0.0, inplace=False) (proj): Linear(in_features=512, out_features=512, bias=True) (proj_drop): Dropout(p=0.0, inplace=False) (softmax): Softmax(dim=-1) ) (drop_path): DropPath(drop_prob=0.257) (norm2): LayerNorm((512,), eps=1e-05, elementwise_affine=True) (mlp): Mlp( (fc1): Linear(in_features=512, out_features=2048, bias=True) (act): GELU(approximate='none') (fc2): Linear(in_features=2048, out_features=512, bias=True) (drop): Dropout(p=0.0, inplace=False) ) ) (15): SwinTransformerBlock( (norm1): LayerNorm((512,), eps=1e-05, elementwise_affine=True) (attn): WindowAttention( (qkv): Linear(in_features=512, out_features=1536, bias=True) (attn_drop): Dropout(p=0.0, inplace=False) (proj): Linear(in_features=512, out_features=512, bias=True) (proj_drop): Dropout(p=0.0, inplace=False) (softmax): Softmax(dim=-1) ) (drop_path): DropPath(drop_prob=0.271) (norm2): LayerNorm((512,), eps=1e-05, elementwise_affine=True) (mlp): Mlp( (fc1): Linear(in_features=512, out_features=2048, bias=True) (act): GELU(approximate='none') (fc2): Linear(in_features=2048, out_features=512, bias=True) (drop): Dropout(p=0.0, inplace=False) ) ) (16): SwinTransformerBlock( (norm1): LayerNorm((512,), eps=1e-05, elementwise_affine=True) (attn): WindowAttention( (qkv): Linear(in_features=512, out_features=1536, bias=True) (attn_drop): Dropout(p=0.0, inplace=False) (proj): Linear(in_features=512, out_features=512, bias=True) (proj_drop): Dropout(p=0.0, inplace=False) (softmax): Softmax(dim=-1) ) (drop_path): DropPath(drop_prob=0.286) (norm2): LayerNorm((512,), eps=1e-05, elementwise_affine=True) (mlp): Mlp( (fc1): Linear(in_features=512, 
out_features=2048, bias=True) (act): GELU(approximate='none') (fc2): Linear(in_features=2048, out_features=512, bias=True) (drop): Dropout(p=0.0, inplace=False) ) ) (17): SwinTransformerBlock( (norm1): LayerNorm((512,), eps=1e-05, elementwise_affine=True) (attn): WindowAttention( (qkv): Linear(in_features=512, out_features=1536, bias=True) (attn_drop): Dropout(p=0.0, inplace=False) (proj): Linear(in_features=512, out_features=512, bias=True) (proj_drop): Dropout(p=0.0, inplace=False) (softmax): Softmax(dim=-1) ) (drop_path): DropPath(drop_prob=0.300) (norm2): LayerNorm((512,), eps=1e-05, elementwise_affine=True) (mlp): Mlp( (fc1): Linear(in_features=512, out_features=2048, bias=True) (act): GELU(approximate='none') (fc2): Linear(in_features=2048, out_features=512, bias=True) (drop): Dropout(p=0.0, inplace=False) ) ) ) ) )
-
class BasicLayer(nn.Module): def forward(self, x, H, W): """Forward function. Args: x: Input feature, tensor size (B, H*W, C). H, W: Spatial resolution of the input feature. """ # calculate attention mask for SW-MSA Hp = int(np.ceil(H / self.window_size)) * self.window_size Wp = int(np.ceil(W / self.window_size)) * self.window_size img_mask = torch.zeros((1, Hp, Wp, 1), device=x.device) # 1 Hp Wp 1 h_slices = ( slice(0, -self.window_size), slice(-self.window_size, -self.shift_size), slice(-self.shift_size, None), ) w_slices = ( slice(0, -self.window_size), slice(-self.window_size, -self.shift_size), slice(-self.shift_size, None), ) cnt = 0 for h in h_slices: for w in w_slices: img_mask[:, h, w, :] = cnt cnt += 1 mask_windows = window_partition( img_mask, self.window_size ) # nW, window_size, window_size, 1 mask_windows = mask_windows.view(-1, self.window_size * self.window_size) attn_mask = mask_windows.unsqueeze(1) - mask_windows.unsqueeze(2) attn_mask = attn_mask.masked_fill(attn_mask != 0, float(-100.0)).masked_fill( attn_mask == 0, float(0.0) ) for blk in self.blocks: blk.H, blk.W = H, W if self.use_checkpoint: x = checkpoint.checkpoint(blk, x, attn_mask) else: x = blk(x, attn_mask) # (5,9216,128) -> (5,9216,128) if self.downsample is not None: x_down = self.downsample(x, H, W) Wh, Ww = (H + 1) // 2, (W + 1) // 2 return x, H, W, x_down, Wh, Ww else: return x, H, W, x, H, W
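For reference, a self-contained sketch of how the SW-MSA attention mask above is constructed. `window_partition` is reproduced so the snippet runs on its own; the 14x14 grid with `window_size=7` and `shift_size=3` are illustrative values, not the ones from this config:

```python
import torch

def window_partition(x, window_size):
    # x: (B, H, W, C) -> (num_windows*B, window_size, window_size, C)
    B, H, W, C = x.shape
    x = x.view(B, H // window_size, window_size, W // window_size, window_size, C)
    return x.permute(0, 1, 3, 2, 4, 5).contiguous().view(-1, window_size, window_size, C)

window_size, shift_size, Hp, Wp = 7, 3, 14, 14

# Label each region of the padded grid so tokens from different shifted regions cannot attend to each other.
img_mask = torch.zeros((1, Hp, Wp, 1))
slices = (slice(0, -window_size), slice(-window_size, -shift_size), slice(-shift_size, None))
cnt = 0
for h in slices:
    for w in slices:
        img_mask[:, h, w, :] = cnt
        cnt += 1

mask_windows = window_partition(img_mask, window_size).view(-1, window_size * window_size)
attn_mask = mask_windows.unsqueeze(1) - mask_windows.unsqueeze(2)
attn_mask = attn_mask.masked_fill(attn_mask != 0, -100.0).masked_fill(attn_mask == 0, 0.0)
print(attn_mask.shape)  # (num_windows, window_size*window_size, window_size*window_size) = (4, 49, 49)
```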
- The workflow within `clip_features = self.sem_seg_head.predictor.clip_model.encode_image(clip_images, dense=True)` (`clip_images`: (5, 3, 336, 336), `clip_features`: (5, 577, 768)):
class VisualTransformer(nn.Module): def forward(self, x: torch.Tensor, dense=False): # (5,3,336,336) x = self.conv1(x) # shape = [5, 1024, 24, 24] x = x.reshape(x.shape[0], x.shape[1], -1) # shape = [5, 1024, 576(24x24)] x = x.permute(0, 2, 1) # shape = [5, 576, 1024] x = torch.cat([self.class_embedding.to(x.dtype) + torch.zeros(x.shape[0], 1, x.shape[-1], dtype=x.dtype, device=x.device), x], dim=1) # shape = [5, 576+1, 1024] if dense and (x.shape[1] != self.positional_embedding.shape[0]): x = x + self.resized_pos_embed(self.input_resolution, x.shape[1]).to(x.dtype) else: x = x + self.positional_embedding.to(x.dtype) # shape = [5, 577, 1024] x = self.ln_pre(x) # shape = [5, 577, 1024] x = x.permute(1, 0, 2) # NLD -> LND x = self.transformer(x, dense) x = x.permute(1, 0, 2) # LND -> NLD shape = [5, 577, 1024] if dense: x = self.ln_post(x[:, :, :]) else: x = self.ln_post(x[:, 0, :]) if self.proj is not None: x = x @ self.proj # shape -> [5, 577, 768] return x # shape = [5, 577, 768]
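The `dense=True` branch above keeps all 577 tokens instead of only the CLS token. A tensor-only sketch of that difference, using random stand-ins for the transformer output and the visual projection (`ln_post` omitted for brevity):

```python
import torch

B, num_tokens, width, out_dim = 5, 577, 1024, 768
tokens = torch.rand(B, num_tokens, width)   # output of the ViT blocks: CLS token + 24x24 patch tokens
proj = torch.rand(width, out_dim)           # stand-in for CLIP's learned visual projection

cls_only = tokens[:, 0, :] @ proj           # standard CLIP image embedding: (5, 768)
dense = tokens @ proj                       # dense=True keeps every token:   (5, 577, 768)

# CATSegHead later drops the CLS token and reshapes the 576 patch tokens back to a 24x24 grid:
patch_grid = dense[:, 1:, :].reshape(B, 24, 24, out_dim).permute(0, 3, 1, 2)  # (5, 768, 24, 24)
print(cls_only.shape, dense.shape, patch_grid.shape)
```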
- The workflow within `outputs = self.sem_seg_head(clip_features, features)`:
```python
class CATSegHead(nn.Module):
    def forward(self, features, guidance_features):
        """
        Arguments:
            img_feats: (B, C, HW)
            affinity_features: (B, C, )
        """
        # features: (5,577,768) -> img_feat: (5,768,24,24)
        img_feat = rearrange(features[:, 1:, :], "b (h w) c->b c h w",
                             h=self.feature_resolution[0], w=self.feature_resolution[1])
        return self.predictor(img_feat, guidance_features)
```
```python
class CATSegPredictor(nn.Module):
    # self.transformer -> Aggregator!
    def forward(self, x, vis_guidance):
        vis = [vis_guidance[k] for k in vis_guidance.keys()][::-1]
        # text: (150, 80, 768)
        text = self.text_features if self.training else self.text_features_test
        # text -> (5, 150, 80, 768)
        text = text.repeat(x.shape[0], 1, 1, 1)
        out = self.transformer(x, text, vis)  # the Aggregator, detailed below
        return out
```
`text_feats`: (5, 150, 80, 768), `img_feats`
: (5,768,24,24)class Aggregator(nn.Module): def __init__(self, text_guidance_dim=512, text_guidance_proj_dim=128, appearance_guidance_dim=512, appearance_guidance_proj_dim=128, decoder_dims = (64, 32), decoder_guidance_dims=(256, 128), decoder_guidance_proj_dims=(32, 16), num_layers=4, nheads=4, hidden_dim=128, pooling_size=(6, 6), feature_resolution=(24, 24), window_size=12, attention_type='linear', prompt_channel=80, ) -> None: super().__init__() self.num_layers = num_layers self.hidden_dim = hidden_dim self.layers = nn.ModuleList([ AggregatorLayer( hidden_dim=hidden_dim, text_guidance_dim=text_guidance_proj_dim, appearance_guidance=appearance_guidance_proj_dim, nheads=nheads, input_resolution=feature_resolution, pooling_size=pooling_size, window_size=window_size, attention_type=attention_type ) for _ in range(num_layers) ]) self.conv1 = nn.Conv2d(prompt_channel, hidden_dim, kernel_size=7, stride=1, padding=3) self.guidance_projection = nn.Sequential( nn.Conv2d(appearance_guidance_dim, appearance_guidance_proj_dim, kernel_size=3, stride=1, padding=1), nn.ReLU(), ) if appearance_guidance_dim > 0 else None self.text_guidance_projection = nn.Sequential( nn.Linear(text_guidance_dim, text_guidance_proj_dim), nn.ReLU(), ) if text_guidance_dim > 0 else None self.decoder_guidance_projection = nn.ModuleList([ nn.Sequential( nn.Conv2d(d, dp, kernel_size=3, stride=1, padding=1), nn.ReLU(), ) for d, dp in zip(decoder_guidance_dims, decoder_guidance_proj_dims) ]) if decoder_guidance_dims[0] > 0 else None self.decoder1 = Up(hidden_dim, decoder_dims[0], decoder_guidance_proj_dims[0]) self.decoder2 = Up(decoder_dims[0], decoder_dims[1], decoder_guidance_proj_dims[1]) self.head = nn.Conv2d(decoder_dims[1], 1, kernel_size=3, stride=1, padding=1) #---------------------------------------------------------------------------------# def feature_map(self, img_feats, text_feats): img_feats = F.normalize(img_feats, dim=1) # B C H W img_feats = repeat(img_feats, "B C H W -> B C T H W", T=text_feats.shape[1]) text_feats = F.normalize(text_feats, dim=-1) # B T P C text_feats = text_feats.mean(dim=-2) text_feats = F.normalize(text_feats, dim=-1) # B T C text_feats = repeat(text_feats, "B T C -> B C T H W", H=img_feats.shape[-2], W=img_feats.shape[-1]) return torch.cat((img_feats, text_feats), dim=1) # B 2C T H W def correlation(self, img_feats, text_feats): img_feats = F.normalize(img_feats, dim=1) # (5,768,24,24) text_feats = F.normalize(text_feats, dim=-1) # (5,150,80,768) corr = torch.einsum('bchw, btpc -> bpthw', img_feats, text_feats) return corr # corr: (5,80,150,24,24) def corr_embed(self, x): B = x.shape[0] # x: (5,80,150,24,24) -> (750, 80, 24, 24) corr_embed = rearrange(x, 'B P T H W -> (B T) P H W') # x: (750, 80, 24, 24) -> (750, 128, 24, 24) corr_embed = self.conv1(corr_embed) # x: (750, 128, 24, 24) -> (5, 128, 150, 24, 24) corr_embed = rearrange(corr_embed, '(B T) C H W -> B C T H W', B=B) return corr_embed def corr_projection(self, x, proj): corr_embed = rearrange(x, 'B C T H W -> B T H W C') corr_embed = proj(corr_embed) corr_embed = rearrange(corr_embed, 'B T H W C -> B C T H W') return corr_embed def upsample(self, x): B = x.shape[0] corr_embed = rearrange(x, 'B C T H W -> (B T) C H W') corr_embed = F.interpolate(corr_embed, scale_factor=2, mode='bilinear', align_corners=True) corr_embed = rearrange(corr_embed, '(B T) C H W -> B C T H W', B=B) return corr_embed def conv_decoder(self, x, guidance): B = x.shape[0] corr_embed = rearrange(x, 'B C T H W -> (B T) C H W') corr_embed = 
self.decoder1(corr_embed, guidance[0]) corr_embed = self.decoder2(corr_embed, guidance[1]) corr_embed = self.head(corr_embed) corr_embed = rearrange(corr_embed, '(B T) () H W -> B T H W', B=B) return corr_embed def forward(self, img_feats, text_feats, appearance_guidance): """ Arguments: img_feats: (B, C, H, W) text_feats: (B, T, P, C) apperance_guidance: tuple of (B, C, H, W) """ # text_feats: (5,150,80,768), img_feats: (5,768,24,24) corr = self.correlation(img_feats, text_feats) # corr: (5,80,150,24,24) corr_embed = self.corr_embed(corr) projected_guidance, projected_text_guidance, projected_decoder_guidance = None, None, [None, None] if self.guidance_projection is not None: # projected_guidance: (5,128,24,24) projected_guidance = self.guidance_projection(appearance_guidance[0]) if self.decoder_guidance_projection is not None: # 见下图👇🏻 projected_decoder_guidance = [proj(g) for proj, g in zip(self.decoder_guidance_projection, appearance_guidance[1:])] if self.text_guidance_projection is not None: # (5,150,80,768) -> (5,150,768) text_feats = text_feats.mean(dim=-2) text_feats = text_feats / text_feats.norm(dim=-1, keepdim=True) # (5,150,768) -> (5,150,128) projected_text_guidance = self.text_guidance_projection(text_feats) # corr_embed: (5,80,150,24,24) -> (5,80,150,24,24) -> (5,80,150,24,24) for layer in self.layers: corr_embed = layer(corr_embed, projected_guidance, projected_text_guidance) # corr_embed: (5,80,150,24,24), 见最下面👇🏻 # logit: (5,150,96,96) logit = self.conv_decoder(corr_embed, projected_decoder_guidance) return logit
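The heart of the Aggregator forward above is the image-text cost volume computed in `correlation`. A standalone shape check of that einsum, with random tensors matching the shapes noted in these comments:

```python
import torch
import torch.nn.functional as F

B, C, H, W = 5, 768, 24, 24      # dense CLIP image features
T, P = 150, 80                   # classes, prompt templates

img_feats = F.normalize(torch.rand(B, C, H, W), dim=1)
text_feats = F.normalize(torch.rand(B, T, P, C), dim=-1)

# cosine similarity between every pixel embedding and every (class, template) text embedding
corr = torch.einsum('bchw,btpc->bpthw', img_feats, text_feats)
print(corr.shape)  # torch.Size([5, 80, 150, 24, 24])
```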
(Figure: the shapes of `projected_decoder_guidance`, i.e. the res3/res2 guidance features after `self.decoder_guidance_projection`.)
For `for layer in self.layers: corr_embed = layer(corr_embed, projected_guidance, projected_text_guidance)`, with `self.layers`
’s structure as follow:ModuleList( (0): AggregatorLayer( (swin_block): SwinTransformerBlockWrapper( (block_1): SwinTransformerBlock( (norm1): LayerNorm((128,), eps=1e-05, elementwise_affine=True) (attn): WindowAttention( (q): Linear(in_features=256, out_features=128, bias=True) (k): Linear(in_features=256, out_features=128, bias=True) (v): Linear(in_features=128, out_features=128, bias=True) (attn_drop): Dropout(p=0.0, inplace=False) (proj): Linear(in_features=128, out_features=128, bias=True) (proj_drop): Dropout(p=0.0, inplace=False) (softmax): Softmax(dim=-1) ) (drop_path): Identity() (norm2): LayerNorm((128,), eps=1e-05, elementwise_affine=True) (mlp): Mlp( (fc1): Linear(in_features=128, out_features=512, bias=True) (act): GELU(approximate='none') (drop1): Dropout(p=0.0, inplace=False) (fc2): Linear(in_features=512, out_features=128, bias=True) (drop2): Dropout(p=0.0, inplace=False) ) ) (block_2): SwinTransformerBlock( (norm1): LayerNorm((128,), eps=1e-05, elementwise_affine=True) (attn): WindowAttention( (q): Linear(in_features=256, out_features=128, bias=True) (k): Linear(in_features=256, out_features=128, bias=True) (v): Linear(in_features=128, out_features=128, bias=True) (attn_drop): Dropout(p=0.0, inplace=False) (proj): Linear(in_features=128, out_features=128, bias=True) (proj_drop): Dropout(p=0.0, inplace=False) (softmax): Softmax(dim=-1) ) (drop_path): Identity() (norm2): LayerNorm((128,), eps=1e-05, elementwise_affine=True) (mlp): Mlp( (fc1): Linear(in_features=128, out_features=512, bias=True) (act): GELU(approximate='none') (drop1): Dropout(p=0.0, inplace=False) (fc2): Linear(in_features=512, out_features=128, bias=True) (drop2): Dropout(p=0.0, inplace=False) ) ) (guidance_norm): LayerNorm((128,), eps=1e-05, elementwise_affine=True) ) (attention): ClassTransformerLayer( (pool): AvgPool2d(kernel_size=[1, 1], stride=[1, 1], padding=0) (attention): AttentionLayer( (q): Linear(in_features=256, out_features=128, bias=True) (k): Linear(in_features=256, out_features=128, bias=True) (v): Linear(in_features=128, out_features=128, bias=True) (attention): LinearAttention() ) (MLP): Sequential( (0): Linear(in_features=128, out_features=512, bias=True) (1): ReLU() (2): Linear(in_features=512, out_features=128, bias=True) ) (norm1): LayerNorm((128,), eps=1e-05, elementwise_affine=True) (norm2): LayerNorm((128,), eps=1e-05, elementwise_affine=True) ) ) (1): AggregatorLayer( (swin_block): SwinTransformerBlockWrapper( (block_1): SwinTransformerBlock( (norm1): LayerNorm((128,), eps=1e-05, elementwise_affine=True) (attn): WindowAttention( (q): Linear(in_features=256, out_features=128, bias=True) (k): Linear(in_features=256, out_features=128, bias=True) (v): Linear(in_features=128, out_features=128, bias=True) (attn_drop): Dropout(p=0.0, inplace=False) (proj): Linear(in_features=128, out_features=128, bias=True) (proj_drop): Dropout(p=0.0, inplace=False) (softmax): Softmax(dim=-1) ) (drop_path): Identity() (norm2): LayerNorm((128,), eps=1e-05, elementwise_affine=True) (mlp): Mlp( (fc1): Linear(in_features=128, out_features=512, bias=True) (act): GELU(approximate='none') (drop1): Dropout(p=0.0, inplace=False) (fc2): Linear(in_features=512, out_features=128, bias=True) (drop2): Dropout(p=0.0, inplace=False) ) ) (block_2): SwinTransformerBlock( (norm1): LayerNorm((128,), eps=1e-05, elementwise_affine=True) (attn): WindowAttention( (q): Linear(in_features=256, out_features=128, bias=True) (k): Linear(in_features=256, out_features=128, bias=True) (v): Linear(in_features=128, out_features=128, 
bias=True) (attn_drop): Dropout(p=0.0, inplace=False) (proj): Linear(in_features=128, out_features=128, bias=True) (proj_drop): Dropout(p=0.0, inplace=False) (softmax): Softmax(dim=-1) ) (drop_path): Identity() (norm2): LayerNorm((128,), eps=1e-05, elementwise_affine=True) (mlp): Mlp( (fc1): Linear(in_features=128, out_features=512, bias=True) (act): GELU(approximate='none') (drop1): Dropout(p=0.0, inplace=False) (fc2): Linear(in_features=512, out_features=128, bias=True) (drop2): Dropout(p=0.0, inplace=False) ) ) (guidance_norm): LayerNorm((128,), eps=1e-05, elementwise_affine=True) ) (attention): ClassTransformerLayer( (pool): AvgPool2d(kernel_size=[1, 1], stride=[1, 1], padding=0) (attention): AttentionLayer( (q): Linear(in_features=256, out_features=128, bias=True) (k): Linear(in_features=256, out_features=128, bias=True) (v): Linear(in_features=128, out_features=128, bias=True) (attention): LinearAttention() ) (MLP): Sequential( (0): Linear(in_features=128, out_features=512, bias=True) (1): ReLU() (2): Linear(in_features=512, out_features=128, bias=True) ) (norm1): LayerNorm((128,), eps=1e-05, elementwise_affine=True) (norm2): LayerNorm((128,), eps=1e-05, elementwise_affine=True) ) ) )
and the data flow in each `AggregatorLayer` is as follows (the shape changes are identical for both layers):

```python
for layer in self.layers:
    # corr_embed: (5,128,150,24,24)
    # projected_guidance: (5,128,24,24)
    # projected_text_guidance: (5,150,128)
    corr_embed = layer(corr_embed, projected_guidance, projected_text_guidance)
```

```python
class AggregatorLayer(nn.Module):
    def __init__(self, hidden_dim=64, text_guidance_dim=512, appearance_guidance=512, nheads=4,
                 input_resolution=(20, 20), pooling_size=(5, 5), window_size=(10, 10),
                 attention_type='linear') -> None:
        super().__init__()
        self.swin_block = SwinTransformerBlockWrapper(hidden_dim, appearance_guidance, input_resolution, nheads, window_size)
        self.attention = ClassTransformerLayer(hidden_dim, text_guidance_dim, nheads=nheads,
                                               attention_type=attention_type, pooling_size=pooling_size)

    def forward(self, x, appearance_guidance, text_guidance):
        """
        Arguments:
            x: B C T H W
        """
        # x: (5,128,150,24,24)
        # appearance_guidance: (5,128,24,24)
        # text_guidance: (5,150,128)
        x = self.swin_block(x, appearance_guidance)  # (5,128,150,24,24) -> (5,128,150,24,24)
        x = self.attention(x, text_guidance)         # (5,128,150,24,24) -> (5,128,150,24,24)
        return x
```
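Each `AggregatorLayer` therefore alternates spatial attention (per class) and class attention (per pixel) over the same (B, C, T, H, W) cost-volume embedding. The einops sketch below shows the two rearrangements that make this possible, using the shapes from the notes:

```python
import torch
from einops import rearrange

B, C, T, H, W = 5, 128, 150, 24, 24
x = torch.rand(B, C, T, H, W)

# Spatial attention inside SwinTransformerBlockWrapper: one sequence of H*W tokens per (image, class).
spatial_tokens = rearrange(x, 'B C T H W -> (B T) (H W) C')  # (750, 576, 128)

# Class attention inside ClassTransformerLayer: one sequence of T tokens per (image, pixel).
class_tokens = rearrange(x, 'B C T H W -> (B H W) T C')      # (2880, 150, 128)

print(spatial_tokens.shape, class_tokens.shape)
```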
For `SwinTransformerBlockWrapper`
:class SwinTransformerBlockWrapper(nn.Module): def __init__(self, dim, appearance_guidance_dim, input_resolution, nheads=4, window_size=5): super().__init__() self.block_1 = SwinTransformerBlock(dim, appearance_guidance_dim, input_resolution, num_heads=nheads, head_dim=None, window_size=window_size, shift_size=0) self.block_2 = SwinTransformerBlock(dim, appearance_guidance_dim, input_resolution, num_heads=nheads, head_dim=None, window_size=window_size, shift_size=window_size // 2) self.guidance_norm = nn.LayerNorm(appearance_guidance_dim) if appearance_guidance_dim > 0 else None def forward(self, x, appearance_guidance): """ Arguments: x: B C T H W appearance_guidance: B C H W """ B, C, T, H, W = x.shape # x: (5,128,150,24,24) -> (750,576,128) x = rearrange(x, 'B C T H W -> (B T) (H W) C') if appearance_guidance is not None: # appearance_guidance: (5,128,24,24) -> (750,576,128) -> (750,576,128) appearance_guidance = self.guidance_norm(repeat(appearance_guidance, 'B C H W -> (B T) (H W) C', T=T)) # x: (750,576,128) -> (750,576,128) x = self.block_1(x, appearance_guidance) # x: (750,576,128) -> (750,576,128) x = self.block_2(x, appearance_guidance) # x: (750,576,128) -> (5,128,150,24,24) x = rearrange(x, '(B T) (H W) C -> B C T H W', B=B, T=T, H=H, W=W) return x
In `SwinTransformerBlock`
:class SwinTransformerBlock(nn.Module): def forward(self, x, appearance_guidance): H, W = self.input_resolution B, L, C = x.shape assert L == H * W, "input feature has wrong size" shortcut = x x = self.norm1(x) # x: (750, 576, 128) -> (750, 24, 24, 128) x = x.view(B, H, W, C) if appearance_guidance is not None: # appearance_guidance: (750, 576, 128) -> (750, 24, 24, 128) appearance_guidance = appearance_guidance.view(B, H, W, -1) # x: (750, 24, 24, 128) -> (750, 24, 24, 256) x = torch.cat([x, appearance_guidance], dim=-1) # cyclic shift if self.shift_size > 0: shifted_x = torch.roll(x, shifts=(-self.shift_size, -self.shift_size), dims=(1, 2)) else: shifted_x = x # partition windows x_windows = window_partition(shifted_x, self.window_size) # num_win*B, window_size, window_size, C x_windows = x_windows.view(-1, self.window_size * self.window_size, x_windows.shape[-1]) # num_win*B, window_size*window_size, C # W-MSA/SW-MSA attn_windows = self.attn(x_windows, mask=self.attn_mask) # num_win*B, window_size*window_size, C # merge windows attn_windows = attn_windows.view(-1, self.window_size, self.window_size, C) shifted_x = window_reverse(attn_windows, self.window_size, H, W) # B H' W' C # reverse cyclic shift if self.shift_size > 0: x = torch.roll(shifted_x, shifts=(self.shift_size, self.shift_size), dims=(1, 2)) else: x = shifted_x x = x.view(B, H * W, C) # FFN x = shortcut + self.drop_path(x) x = x + self.drop_path(self.mlp(self.norm2(x))) # x: (750,576,128) return x
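For reference, a self-contained sketch of the cyclic-shift plus window-partition round trip used in this block, with the Aggregator's window_size=12 on the 24x24 grid (`window_partition`/`window_reverse` are the standard Swin helpers, reproduced here so the snippet runs on its own):

```python
import torch

def window_partition(x, window_size):
    B, H, W, C = x.shape
    x = x.view(B, H // window_size, window_size, W // window_size, window_size, C)
    return x.permute(0, 1, 3, 2, 4, 5).contiguous().view(-1, window_size, window_size, C)

def window_reverse(windows, window_size, H, W):
    B = int(windows.shape[0] / (H * W / window_size / window_size))
    x = windows.view(B, H // window_size, W // window_size, window_size, window_size, -1)
    return x.permute(0, 1, 3, 2, 4, 5).contiguous().view(B, H, W, -1)

B, H, W, C, window_size, shift_size = 2, 24, 24, 128, 12, 6
x = torch.rand(B, H, W, C)

shifted = torch.roll(x, shifts=(-shift_size, -shift_size), dims=(1, 2))  # cyclic shift
windows = window_partition(shifted, window_size)                         # (B*4, 12, 12, C)
# ... W-MSA / SW-MSA would run on `windows` here ...
restored = window_reverse(windows, window_size, H, W)
restored = torch.roll(restored, shifts=(shift_size, shift_size), dims=(1, 2))
print(torch.equal(restored, x))  # True: partition + reverse + inverse shift is lossless
```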
For `self.attention = ClassTransformerLayer`:

```python
class ClassTransformerLayer(nn.Module):
    def __init__(self, hidden_dim=64, guidance_dim=64, nheads=8, attention_type='linear', pooling_size=(4, 4)) -> None:
        super().__init__()
        self.pool = nn.AvgPool2d(pooling_size)
        self.attention = AttentionLayer(hidden_dim, guidance_dim, nheads=nheads, attention_type=attention_type)
        self.MLP = nn.Sequential(
            nn.Linear(hidden_dim, hidden_dim * 4),
            nn.ReLU(),
            nn.Linear(hidden_dim * 4, hidden_dim)
        )
        self.norm1 = nn.LayerNorm(hidden_dim)
        self.norm2 = nn.LayerNorm(hidden_dim)

    def pool_features(self, x):
        """
        Intermediate pooling layer for computational efficiency.
        Arguments:
            x: B, C, T, H, W
        """
        B = x.size(0)
        # x: (5,128,150,24,24)
        x = rearrange(x, 'B C T H W -> (B T) C H W')
        x = self.pool(x)
        x = rearrange(x, '(B T) C H W -> B C T H W', B=B)
        return x  # (5,128,150,24,24)

    def forward(self, x, guidance):
        """
        Arguments:
            x: B, C, T, H, W
            guidance: B, T, C
        """
        B, _, _, H, W = x.size()
        # x: (5,128,150,24,24), x_pool: (5,128,150,24,24)
        x_pool = self.pool_features(x)
        *_, H_pool, W_pool = x_pool.size()
        # x_pool: (5,128,150,24,24) -> (2880,150,128)
        x_pool = rearrange(x_pool, 'B C T H W -> (B H W) T C')
        if guidance is not None:
            # guidance: (5,150,128) -> (2880,150,128)
            guidance = repeat(guidance, 'B T C -> (B H W) T C', H=H_pool, W=W_pool)

        x_pool = x_pool + self.attention(self.norm1(x_pool), guidance)  # see below; x_pool: (2880,150,128)
        x_pool = x_pool + self.MLP(self.norm2(x_pool))  # MLP

        # x_pool: (750,128,24,24)
        x_pool = rearrange(x_pool, '(B H W) T C -> (B T) C H W', H=H_pool, W=W_pool)
        x_pool = F.interpolate(x_pool, size=(H, W), mode='bilinear', align_corners=True)
        x_pool = rearrange(x_pool, '(B T) C H W -> B C T H W', B=B)

        x = x + x_pool  # residual; x: (5,128,150,24,24)
        return x
```
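A note on `pool_features`: in the checkpoint inspected above the pool is `AvgPool2d(kernel_size=[1, 1])`, so it is effectively a no-op here, but with a larger `pooling_size` it would shrink the number of per-pixel sequences fed to class attention. A small sketch with a hypothetical `pooling_size=(2, 2)`:

```python
import torch
from einops import rearrange

B, C, T, H, W = 5, 128, 150, 24, 24
x = torch.rand(B, C, T, H, W)

pool = torch.nn.AvgPool2d((2, 2))                      # hypothetical pooling_size
x_pool = rearrange(x, 'B C T H W -> (B T) C H W')
x_pool = pool(x_pool)                                  # (750, 128, 12, 12)
x_pool = rearrange(x_pool, '(B T) C H W -> B C T H W', B=B)

tokens = rearrange(x_pool, 'B C T H W -> (B H W) T C')
print(tokens.shape)  # torch.Size([720, 150, 128]) instead of (2880, 150, 128) without pooling
```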
For `self.attention(self.norm1(x_pool), guidance)`:

```python
class AttentionLayer(nn.Module):
    def __init__(self, hidden_dim, guidance_dim, nheads=8, attention_type='linear'):
        super().__init__()
        self.nheads = nheads
        self.q = nn.Linear(hidden_dim + guidance_dim, hidden_dim)
        self.k = nn.Linear(hidden_dim + guidance_dim, hidden_dim)
        self.v = nn.Linear(hidden_dim, hidden_dim)

        if attention_type == 'linear':
            self.attention = LinearAttention()
        elif attention_type == 'full':
            self.attention = FullAttention()
        else:
            raise NotImplementedError

    def forward(self, x, guidance):
        """
        Arguments:
            x: B, L, C
            guidance: B, L, C
        """
        # q, k, v: (2880,150,128)
        q = self.q(torch.cat([x, guidance], dim=-1)) if guidance is not None else self.q(x)
        k = self.k(torch.cat([x, guidance], dim=-1)) if guidance is not None else self.k(x)
        v = self.v(x)

        # q, k, v: (2880,150,4,32)
        q = rearrange(q, 'B L (H D) -> B L H D', H=self.nheads)
        k = rearrange(k, 'B S (H D) -> B S H D', H=self.nheads)
        v = rearrange(v, 'B S (H D) -> B S H D', H=self.nheads)

        out = self.attention(q, k, v)                 # (2880,150,4,32)
        out = rearrange(out, 'B L H D -> B L (H D)')  # (2880,150,128)
        return out
```
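With `attention_type='linear'`, the call above runs the kernelized linear attention defined in the setup section (the "Transformers are RNNs" formulation): an ELU+1 feature map replaces the softmax so the cost is linear in sequence length. A standalone shape check of that einsum chain, with the batch shrunk from 2880 to 8 to keep it light:

```python
import torch

def elu_feature_map(x):
    return torch.nn.functional.elu(x) + 1

N, L, S, H, D = 8, 150, 150, 4, 32
queries, keys, values = torch.rand(N, L, H, D), torch.rand(N, S, H, D), torch.rand(N, S, H, D)

Q, K = elu_feature_map(queries), elu_feature_map(keys)
v_length = values.size(1)
values = values / v_length                                   # prevent fp16 overflow
KV = torch.einsum("nshd,nshv->nhdv", K, values)              # aggregate values per head: (N, H, D, D)
Z = 1 / (torch.einsum("nlhd,nhd->nlh", Q, K.sum(dim=1)) + 1e-6)
out = torch.einsum("nlhd,nhdv,nlh->nlhv", Q, KV, Z) * v_length
print(out.shape)  # torch.Size([8, 150, 4, 32])
```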
For `self.conv_decoder(corr_embed, projected_decoder_guidance)`:

```python
def conv_decoder(self, x, guidance):
    B = x.shape[0]
    corr_embed = rearrange(x, 'B C T H W -> (B T) C H W')  # (750,128,24,24)
    corr_embed = self.decoder1(corr_embed, guidance[0])    # (750,64,48,48)
    corr_embed = self.decoder2(corr_embed, guidance[1])    # (750,32,96,96)
    corr_embed = self.head(corr_embed)                     # (750,1,96,96)
    corr_embed = rearrange(corr_embed, '(B T) () H W -> B T H W', B=B)  # (5,150,96,96)
    return corr_embed


class Up(nn.Module):
    """Upscaling then double conv"""
    def __init__(self, in_channels, out_channels, guidance_channels):
        super().__init__()
        self.up = nn.ConvTranspose2d(in_channels, in_channels - guidance_channels, kernel_size=2, stride=2)
        self.conv = DoubleConv(in_channels, out_channels)

    def forward(self, x, guidance=None):
        x = self.up(x)  # (750,128,24,24) -> (750,96,48,48)
        if guidance is not None:
            T = x.size(0) // guidance.size(0)
            guidance = repeat(guidance, "B C H W -> (B T) C H W", T=T)  # (5,32,48,48) -> (750,32,48,48)
            x = torch.cat([x, guidance], dim=1)  # (750,96,48,48) -> (750,128,48,48)
        return self.conv(x)  # (750,128,48,48) -> (750,64,48,48)
```
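A runnable shape check of this decoder path, using a simplified stand-in for `DoubleConv` (two 3x3 convs with ReLU; the real block's normalization is omitted) and assuming the default `decoder_guidance_proj_dims=(32, 16)`; only 2 classes are used instead of 150 to keep it light:

```python
import torch
import torch.nn as nn
from einops import repeat

class DoubleConv(nn.Module):
    # Simplified stand-in: two 3x3 convs with ReLU.
    def __init__(self, in_channels, out_channels):
        super().__init__()
        self.net = nn.Sequential(
            nn.Conv2d(in_channels, out_channels, 3, padding=1), nn.ReLU(),
            nn.Conv2d(out_channels, out_channels, 3, padding=1), nn.ReLU(),
        )
    def forward(self, x):
        return self.net(x)

class Up(nn.Module):
    def __init__(self, in_channels, out_channels, guidance_channels):
        super().__init__()
        self.up = nn.ConvTranspose2d(in_channels, in_channels - guidance_channels, kernel_size=2, stride=2)
        self.conv = DoubleConv(in_channels, out_channels)
    def forward(self, x, guidance=None):
        x = self.up(x)
        if guidance is not None:
            T = x.size(0) // guidance.size(0)
            guidance = repeat(guidance, "B C H W -> (B T) C H W", T=T)
            x = torch.cat([x, guidance], dim=1)
        return self.conv(x)

decoder1, decoder2 = Up(128, 64, 32), Up(64, 32, 16)
head = nn.Conv2d(32, 1, kernel_size=3, stride=1, padding=1)

x = torch.rand(10, 128, 24, 24)                                 # (B*T, hidden_dim, 24, 24) with T=2 classes
g0, g1 = torch.rand(5, 32, 48, 48), torch.rand(5, 16, 96, 96)   # projected decoder guidance (res3, res2)
x = decoder1(x, g0)   # (10, 64, 48, 48)
x = decoder2(x, g1)   # (10, 32, 96, 96)
print(head(x).shape)  # torch.Size([10, 1, 96, 96])
```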
2. Open questions

- Loss computation
- GroupNorm