```python
import copy

import torch.nn as nn


class EncoderDecoder(nn.Module):
    """A standard encoder-decoder architecture: embed, encode, then decode."""

    def __init__(self, encoder, decoder, src_embed, tgt_embed, generator):
        super(EncoderDecoder, self).__init__()
        self.encoder = encoder
        self.decoder = decoder
        self.src_embed = src_embed
        self.tgt_embed = tgt_embed
        self.generator = generator

    def forward(self, src, tgt, src_mask, tgt_mask):
        # Encode the masked source sequence, then condition the decoder on it.
        memory = self.encode(src, src_mask)
        res = self.decode(memory, src_mask, tgt, tgt_mask)
        return res

    def encode(self, src, src_mask):
        src_embedds = self.src_embed(src)
        return self.encoder(src_embedds, src_mask)

    def decode(self, memory, src_mask, tgt, tgt_mask):
        target_embedds = self.tgt_embed(tgt)
        return self.decoder(target_embedds, memory, src_mask, tgt_mask)


def make_model(src_vocab, tgt_vocab, N=6, d_model=512, d_ff=2048, h=8, dropout=0.1):
    """Construct the full model from hyperparameters."""
    c = copy.deepcopy
    attn = MultiHeadedAttention(h, d_model)
    ff = PositionwiseFeedForward(d_model, d_ff, dropout)
    position = PositionalEncoding(d_model, dropout)
    model = EncoderDecoder(
        Encoder(EncoderLayer(d_model, c(attn), c(ff), dropout), N),
        Decoder(DecoderLayer(d_model, c(attn), c(attn), c(ff), dropout), N),
        nn.Sequential(Embeddings(d_model, src_vocab), c(position)),
        nn.Sequential(Embeddings(d_model, tgt_vocab), c(position)),
        Generator(d_model, tgt_vocab))

    # Initialize weight matrices with Xavier/Glorot uniform; biases (dim <= 1)
    # keep their default initialization.
    for p in model.parameters():
        if p.dim() > 1:
            nn.init.xavier_uniform_(p)
    return model
```
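As a quick sanity check, the sketch below instantiates a small model and runs one forward pass. It is a minimal, hypothetical smoke test: the vocabulary sizes, batch shapes, and mask construction are made up for illustration, and the helper classes used by `make_model` (`MultiHeadedAttention`, `Embeddings`, `Generator`, and so on) are assumed to be defined as in the preceding sections.

```python
import torch

# Hypothetical smoke test; shapes and vocab sizes are illustrative only.
model = make_model(src_vocab=11, tgt_vocab=11, N=2)  # small N for a quick check
model.eval()

src = torch.randint(1, 11, (1, 10))          # (batch, src_len) token ids
tgt = torch.randint(1, 11, (1, 9))           # (batch, tgt_len) token ids
src_mask = torch.ones(1, 1, 10)              # attend to every source position
tgt_mask = torch.tril(torch.ones(1, 9, 9))   # causal mask over target positions

out = model(src, tgt, src_mask, tgt_mask)    # decoder hidden states: (1, 9, d_model)
probs = model.generator(out)                 # log-probabilities over the target vocab
print(out.shape, probs.shape)
```

Note that `forward` returns the decoder's hidden states, not token probabilities: the `Generator` projection is kept separate, so a training loop applies `model.generator` to the output before computing the loss.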