transformer中共使用三个mask: src_mask: 用于encoder中当句子长度不一时,需要将所有的句子填充至相同的长度。因此在求Q和K的相关性时,由于Q和K在encoder中相等,所以src_mask最后表现为右边和下边遮挡(填充1)的矩阵; target_mask: 用于使decoder中前面词无法使用到后面词的信息,并且也需要考虑padding。在训练decoder...
那么我们可以这样设置src_mask: 该src_mask实现对输入序列中第2个位置的完全遮挡 按照咱们目前的理解,现在TransformerEncoderLayer输出序列的所有位置都看不到输入序列中第2个位置的信息了。如果咱们用TransformerEncoderLayer输出序列的第3个位置去预测输入序列中被遮挡掉的第2个位置,预测的正确率应该接近1/3(相当于在...
def init_weights(self):# 初始化权重initrange =0.1self.encoder.weight.data.uniform_(-initrange, initrange)self.decoder.bias.data.zero_()self.decoder.weight.data.uniform_(-initrange, initrange) def forward(self, src, src_mask):# 前向传播src = self.enc...
self.src_mask = (src != pad).unsqueeze(-2) if trg is not None: self.trg = trg[:, :-1] self.trg_y = trg[:, 1:] self.trg_mask = \ self.make_std_mask(self.trg, pad) self.ntokens = (self.trg_y != pad).data.sum() @staticmethod def make_std_mask(tgt, pad): "Creat...
x = layer(x, memory, src_mask, tgt_mask) returnself.norm(x) 4.2 解码器层 每个解码器层由三个子层连接结构组成,第一个子层连接结构包括一个多头自注意力子层和规范化层以及一个残差连接,第二个子层连接结构包括一个多头注意力子层和规范化层以及一个残差连接,第三...
tgt_mask = (tgt != 0).unsqueeze(1).unsqueeze(3) seq_length = tgt.size(1) nopeak_mask = (1 - torch.triu(torch.ones(1, seq_length, seq_length), diagonal=1)).bool() tgt_mask = tgt_mask & nopeak_mask return src_mask, tgt_mask ...
| (batch_size, num_heads, src_seq_len, d_k) * (batch_size, num_heads, src_seq_len, d_k).transpose(-2, -1) --> (batch_size, num_heads, src_seq_len, src_seq_len) | - 应用掩码 `src_mask` | (batch_size, num_heads, src_seq_len, src_seq_len) * (batch_size, 1, ...
(x, x, x, tgt_mask))# 这里使用的是 Self-Attention 机制,其实 m 是encoder的输出,x是decoder第一部分的输出,# 因为上面一部分的输出中, 未被预测的单词的 query 其实是 0(padding), 那么在这里可以直接使用 src_maskx = self.sublayer[1](x,lambdax: self.src_attn(x, m, m, src_mask))# ...
classEncoderDecoder(nn.Module):defencode(self, src, src_mask):# 先对输入进行embedding,然后再经过encoderreturnself.encoder(self.src_embed(src), src_mask) 2.2 输入层embedding 原始文本经过embedding层进行向量化,它包括token embedding和position embedding两层。
defforward(self,src,tgt,src_mask,tgt_mask):"Take in and process masked src and target sequences."returnself.decode(self.encode(src,src_mask),src_mask,tgt,tgt_mask)defencode(self,src,src_mask):returnself.encoder(self.src_embed(src),src_mask)defdecode(self,memory,src_mask,tgt,tgt_mask...