import torch
import torch.nn as nn
torch.manual_seed(2021)
class RelativePosition(nn.Module):
    # Learned relative position embeddings, as in Shaw et al. (2018):
    # one embedding row per clipped relative distance in [-max, max].
    def __init__(self, num_units, max_relative_position):
        super().__init__()
        self.num_units = num_units
        self.max_relative_position = max_relative_position
        self.embeddings_table = nn.Parameter(torch.Tensor(max_relative_position * 2 + 1, num_units))
        nn.init.xavier_uniform_(self.embeddings_table)

    def forward(self, length_q, length_k):
        # Build indices on the same device as the table so the module also works on GPU.
        device = self.embeddings_table.device
        range_vec_q = torch.arange(length_q, device=device)
        range_vec_k = torch.arange(length_k, device=device)
        # distance_mat[q, k] = k - q, the signed relative distance.
        distance_mat = range_vec_k[None, :] - range_vec_q[:, None]
        distance_mat_clipped = torch.clamp(distance_mat, -self.max_relative_position, self.max_relative_position)
        # Shift into [0, 2 * max_relative_position] so the distances index the table.
        # (torch.arange already yields int64, so the deprecated torch.LongTensor(...) cast is unnecessary.)
        final_mat = distance_mat_clipped + self.max_relative_position
        embeddings = self.embeddings_table[final_mat]
        # embeddings = [length_q, length_k, num_units]
        return embeddings
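# Sanity-check sketch (not from the original post): with length_q = length_k = 4
# and max_relative_position = 2, final_mat in forward() comes out as
#   [[2, 3, 4, 4],
#    [1, 2, 3, 4],
#    [0, 1, 2, 3],
#    [0, 0, 1, 2]]
# i.e. (k - q) clamped to [-2, 2] and shifted into [0, 4], so each query/key
# pair selects one of the 2 * 2 + 1 = 5 learned embedding rows.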
class MultiHeadAttentionLayer(nn.Module):
    def __init__(self, hid_dim, n_heads, dropout, device):
        super().__init__()

        assert hid_dim % n_heads == 0

        self.hid_dim = hid_dim
        self.n_heads = n_heads
        self.head_dim = hid_dim // n_heads
        self.max_relative_position = 2

        self.relative_position_k = RelativePosition(self.head_dim, self.max_relative_position)
        self.relative_position_v = RelativePosition(self.head_dim, self.max_relative_position)

        self.fc_q = nn.Linear(hid_dim, hid_dim)
        self.fc_k = nn.Linear(hid_dim, hid_dim)
        self.fc_v = nn.Linear(hid_dim, hid_dim)

        self.fc_o = nn.Linear(hid_dim, hid_dim)

        self.dropout = nn.Dropout(dropout)

        self.scale = torch.sqrt(torch.FloatTensor([self.head_dim])).to(device)

    def forward(self, query, key, value, mask=None):
        # query = [batch size, query len, hid dim]
        # key   = [batch size, key len,   hid dim]
        # value = [batch size, value len, hid dim]
        batch_size = query.shape[0]
        len_k = key.shape[1]
        len_q = query.shape[1]
        len_v = value.shape[1]

        query = self.fc_q(query)
        key = self.fc_k(key)
        value = self.fc_v(value)

        # Content-content term: attention of q to the content of k.
        r_q1 = query.view(batch_size, -1, self.n_heads, self.head_dim).permute(0, 2, 1, 3)
        r_k1 = key.view(batch_size, -1, self.n_heads, self.head_dim).permute(0, 2, 1, 3)
        attn1 = torch.matmul(r_q1, r_k1.permute(0, 1, 3, 2))

        # Content-position term: attention of q to the position of k.
        # r_k2 = [len_q, len_k, head_dim]
        r_k2 = self.relative_position_k(len_q, len_k)
        attn2 = torch.einsum('bhqe,qke->bhqk', r_q1, r_k2)

        attn = (attn1 + attn2) / self.scale

        if mask is not None:
            attn = attn.masked_fill(mask == 0, -1e10)

        attn = self.dropout(torch.softmax(attn, dim=-1))
        # attn = [batch size, n heads, query len, key len]

        # Weighted sum over the content of v.
        r_v1 = value.view(batch_size, -1, self.n_heads, self.head_dim).permute(0, 2, 1, 3)
        weight1 = torch.matmul(attn, r_v1)

        # Weighted sum over the position embeddings of v.
        r_v2 = self.relative_position_v(len_q, len_v)
        weight2 = torch.einsum('bhqv,qve->bhqe', attn, r_v2)

        x = weight1 + weight2
        # x = [batch size, n heads, query len, head dim]

        x = x.permute(0, 2, 1, 3).contiguous()
        x = x.view(batch_size, -1, self.hid_dim)
        x = self.fc_o(x)
        # x = [batch size, query len, hid dim]
        return x
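# Note (not in the original code): the 'bhqe,qke->bhqk' einsum above is
# equivalent to an explicit reshape + batched matmul, which some
# implementations use instead:
#
#   r_q2 = query.view(batch_size, len_q, n_heads, head_dim) \
#               .permute(1, 0, 2, 3).reshape(len_q, batch_size * n_heads, head_dim)
#   attn2 = torch.matmul(r_q2, r_k2.transpose(1, 2))  # [len_q, b*h, len_k]
#   attn2 = attn2.permute(1, 0, 2).reshape(batch_size, n_heads, len_q, len_k)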
if __name__ == '__main__':
    multiHeadAttentionLayer = MultiHeadAttentionLayer(128, 8, 0.5, 'cpu')
    x = torch.randn(4, 43, 128)
    result = multiHeadAttentionLayer(x, x, x)
    print(result)  # result = [4, 43, 128]

    # Shape check for the 'bhqe,qke->bhqk' einsum used in forward():
    # x = torch.randn(64, 8, 43, 16)
    # y = torch.randn(43, 43, 16)
    # print(torch.einsum('bhqe,qke->bhqk', x, y).shape)  # [64, 8, 43, 43]
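    # Example call with a padding mask (an illustrative sketch, not in the
    # original post; `mask` here is a hypothetical [batch, 1, 1, key len]
    # boolean tensor that broadcasts over heads and query positions):
    # mask = torch.ones(4, 1, 1, 43, dtype=torch.bool)
    # out = multiHeadAttentionLayer(x, x, x, mask=mask)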