我自己用来做 2D image 之间的 Attention 的玩具,把 Linear 换成了 Conv2d,免得还得把维度转来转去,其他没变。上一个版本有过任务验证,起码能收敛。
https://github.com/qinjian623/pytorch_toys/blob/master/utils/modules.py
class Conv2dCrossAttention(nn.Module):
    r"""Scaled dot-product attention from a 2D value/key map to a 2D query map.

    Queries, keys and values are produced by 1x1 convolutions, so inputs and
    outputs stay in the usual CNN layout and no channel-last round trip is
    needed. Passing the same tensor as both ``image`` and ``query_space``
    (with ``D_query_space == D_vk_space``) makes this a self-attention layer.

    Shapes:
        VK space (``image``):     (bs, D_vk_space, vk_h, vk_w)
        Q space (``query_space``): (bs, D_query_space, query_h, query_w)
        Output:                   (bs, D_output, query_h, query_w)

    Args:
        D_query_space: Number of channels of the query feature map.
        D_vk_space: Number of channels of the value/key feature map.
        D_emb: Channel width of the projected queries and keys; the scores
            are divided by ``sqrt(D_emb)``.
        D_output: Number of channels of the projected values, and therefore
            of the output.
    """

    def __init__(self,
                 D_query_space: int,
                 D_vk_space: int,
                 D_emb: int,
                 D_output: int) -> None:
        super().__init__()
        # 1x1 convolutions act as per-pixel linear projections.
        self.query = nn.Conv2d(D_query_space, D_emb, 1)
        self.key = nn.Conv2d(D_vk_space, D_emb, 1)
        self.value = nn.Conv2d(D_vk_space, D_output, 1)
        self.scale = sqrt(D_emb)
        self.init()  # Special init.

    def init(self) -> None:
        """Re-initialize the three projection weights with Xavier-uniform.

        Only the weights are touched; the convolution biases keep whatever
        initialization ``nn.Conv2d`` gave them.
        """
        # Author's note: without ReLU, xavier is better than kaiming, maybe.
        for projection in (self.query, self.key, self.value):
            xavier_uniform_(projection.weight)

    def forward(self, image: Tensor, query_space: Tensor) -> Tensor:
        """Attend from every query position to every position of ``image``.

        Args:
            image: Value/key feature map, (bs, D_vk_space, vk_h, vk_w).
            query_space: Query feature map,
                (bs, D_query_space, query_h, query_w).

        Returns:
            Tensor of shape (bs, D_output, query_h, query_w): for each query
            position, the softmax-weighted sum of the projected values.
        """
        # Boring QKV
        query = self.query(query_space)
        key = self.key(image)
        value = self.value(image)

        bs, _, bh, bw = query_space.shape
        _, _, ih, iw = image.shape

        # Eliminate the spatial dims so attention becomes plain batched
        # matrix products:
        #   query: (bs, D_emb, Q)   key: (bs, K, D_emb)   value: (bs, D_output, K)
        # where Q = query_h * query_w and K = vk_h * vk_w.
        query = query.reshape(bs, -1, bh * bw)
        key = key.reshape(bs, -1, ih * iw).permute(0, 2, 1)
        value = value.reshape(bs, -1, ih * iw)

        # scores[b, k, q] is the scaled similarity of key k and query q.
        scores = torch.bmm(key, query) / self.scale
        # Normalize over the key axis (dim=-2) so that the weights of each
        # query position sum to one.
        weights = F.softmax(scores, dim=-2)

        # Weighted sum of value: (bs, D_output, K) x (bs, K, Q).
        outputs = torch.bmm(value, weights)

        # Back to CNN style
        return outputs.reshape(bs, -1, bh, bw)