0%

transformer在cv中应用

基础transformer讲解

Position Encoding

前面是没有考虑位置信息的,举个例子

1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
import torch
import torch.nn as nn


m = nn.MultiheadAttention(embed_dim=2, num_heads=1)
# embed_dim=2:指定了输入的嵌入维度(embedding dimension),即每个输入样本的特征维度为2,说白了就是图像的列数
# num_heads=1:指定了注意力头的数量,即将输入特征分成几份进行注意力计算。在这个例子中,只使用了1个注意力头。
t1 = [[[1., 2.], # q1, k1, v1
[2., 3.], # q2, k2, v2
[3., 4.]]] # q3, k3, v3

t2 = [[[1., 2.], # q1, k1, v1
[3., 4.], # q3, k3, v3
[2., 3.]]] # q2, k2, v2

q, k, v = torch.as_tensor(t1), torch.as_tensor(t1), torch.as_tensor(t1)
print("result1: \n", m(q, k, v))

q, k, v = torch.as_tensor(t2), torch.as_tensor(t2), torch.as_tensor(t2)
print("result2: \n", m(q, k, v))

输出结果

发现对b1没有影响

所以

Vision Transformer模型详解

参考文章一

参考文章二