새소식

딥러닝/자연어 처리

LSTM sequence-to-sequence with attention

  • -

 

Encoder

source 문장을 압축한 context vecotor를 decoder에 넘겨준다.

Encoder 자체만 놓고 보면 non-auto-regressive task이므로 Bi-directional RNN을 사용 가능하다.

 

import torch import torch.nn as nn from torch.nn.utils.rnn import pack_padded_sequecne as pack from torch.nn.utils.rnn import pad_packed_sequence as unpack class Encoder(nn.Module): """ :input: Embedding tensor :return: y, h """ def __init__(self, word_vec_size, hidden_size, n_layers=4, dropout_p=.2): super(Encoder, self).__init__() self.rnn = nn.LSTM( word_vec_size, int(hidden_size / 2), # bidirectional num_layers=n_layers, dropout = dropout_p, bidirectional = True, batch_first = True, ) def forward(self, emb): # |emb| = (batch_size, length, word_vec_size) if isinstance(emb, tuple): x, lengths = emb x = pack(x, lengths.tolist(), batch_first = True) else: x = emb y, h = self.rnn(x) # |y| = (batch_size, length, hidden_size) # |h[0]| = (num_layers*2, batch_size, hidden_size/2) #LSTM이므로 (hidden state, cell state) 튜플이 나오기때문 if isinstance(emb, tuple): y, _ = unpack(y, batch_first=True) return y, h

✅ from torch.nn.utils.rnn import pack_padded_sequence, pad_packed_sequence

더보기

https://pytorch.org/docs/stable/generated/torch.nn.utils.rnn.pad_packed_sequence.html

import torch import torch.nn as nn from torch.nn.utils.rnn import pack_padded_sequence, pad_packed_sequence a = [torch.tensor([1,2,3]), torch.tensor([3,4])] b = torch.nn.utils.rnn.pad_sequence(a, batch_first=True) b >>> tensor([[1, 2, 3], [3, 4, 0]]) torch.nn.utils.rnn.pack_padded_sequence(b, batch_first=True, lengths=[3,2]) >>> PackedSequence(data=tensor([1, 3, 2, 4, 3]), batch_sizes=tensor([2, 2, 1]), sorted_indices=None, unsorted_indices=None)

 


Decoder

Encoder로부터 문장을 압축한 context vector를 바탕으로 문장을 생성한다.

Auto-regressive task이므로 Bi-directional RNN을 사용할 수 없다.

class Decoder(nn.Module): def __init__(self, word_vec_size, hidden_size, n_layers=4, dropout_p=.2): super(Decoder, self).__init__() self.lstm = nn.LSTM( word_vec_size + hidden_size, hidden_size, num_layers = n_layers, dropout = dropout_p, bidirectional = False, batch_first = True, ) def forward(self, emb_t, h_t_1_tilde, h_t_1): # h_t_1 => h_(t-1) # |emb_t| = (bs, 1, word_vec_size) # |h_t_1_tilde| = (bs, 1, hidden_size) # |h_t_1[0]| = (n_layers, bs, hidden_size) batch_size = emb_t.size(0) hidden_size = h_t_1[0].size(-1) if h_t_1_tilde is None: # if it is first time step h_t_1_tilde = emb_t.new(batch_size, 1, hidden_size).zero_() # input feeding x = torch.cat([emb_t, h_t_1_tilde], dim=-1) y, h = self.lstm(x, h_t_1) return y, h

 

Generator

Decoder의 hidden state를 받아 현재 time-step의 출력 token에 대한 확률 분포를 반환한다.

단어를 선택하는 문제이므로 cross entropy loss를 통해 최적화한다.

class Generator(nn.Module): def __init__(self, hidden_size, output_size): super(Generator, self).__init__() self.output = nn.Linear(hidden_size, output_size) self.softmax = nn.LogSoftmax(dim=-1) def forward(self, x): # |x| = (bs, len, hidden_size) y = self.softmax(self.output(x)) # |y| = (bs, len, output_size) return y # log probability

 

 

 


Attention

seq2seq 모델은 하나의 고정된 크기의 벡터(context vector)에 모든 정보를 압축하다보니 정보 손실이 발생한다.

이런 문제를 해결하기 위해 Attention이 등장하였다.

어텐션의 기본 아이디어는 디코더에서 출력 단어를 예측하는 매 시점(time step)마다, 인코더에서의 전체 입력 문장을 다시 한 번 참고한다. 단, 전체 입력 문장을 전부 다 동일한 비율로 참고하는 것이 아니라, 해당 시점에서 예측해야할 단어와 연관이 있는 입력 단어 부분을 좀 더 집중(attention)해서 보게 된다.

 

미분 가능한 Key-Value function이다.

Query와 Key의 유사도에 따라 Value를 반환한다.

현재 시점 decoder의 output으로 다음 output을 예측할 때 현재 시점 output과 인코더들의 output들간의 유사도를 내적을 통해 구한 뒤 해당 계산값을 반영하여 다음 decoder output의 예측의 성능을 높인다.

즉, Decoder의 hidden state의 한계로 인한 부족한 정보를 직접 encoder에 조회하여 예측에 필요한 정보를 얻어오는 과정이다.

 

  • Query: 현재 time-step의 decoder의 output
  • Keys: 모든 time-step 별 encoder output
  • Values: 모든 time-step 별 encoder output

⇒ Attention을 통해 RNN의 hidden state의 한계(context vector에 모든 정보를 담기 어려움)를 극복하고 더 긴 길이의 입출력에 대처할 수 있다.

 

 

class Attention(nn.Module): # Query: 현재 time-step의 decoder의 output # Keys : 모든 time-step별 enocder output # Values : 모든 time-step 별 encoder output def __init__(self, hidden_size): super(Attention, self).__init__() self.linear = nn.Linear(hidden_size, hidden_size, bias=False) self.softmax = nn.Softmax(dim = -1) def forward(self, h_src, h_tgt, mask=None): # |h_src| = (bs, len, hidden_size) # |h_tgt| = (bs, 1, hidden_size) # |mask| = (bs, len) query = self.linear(h_tgt) # |query| = (bs, 1, hidden_size) weight = torch.bmm(query, h_src.transpose(1,2)) # |weight| = (bs, 1, hs) x (bs, hs, len) = (bs, 1, len) if mask is not None: weight.masked_fill_(mask.unsqueeze(1), -float('inf')) weight = self.softmax(weight) context_vector = torch.bmm(weight, h_src) # |context_vector| = (bs, 1, len) X (bs, len, hs) = (bs, 1, hs) return context_vector

 

Masking

Mini-batch내 문장 구성에 따라 <pad>가 동적으로 생성된다.

이 때, <pad>의 hidden state에는 attention weight가 할당되면 안된다.

따라서 Key와 Query의 dot product 이후에 masking을 통해 <pad> 위치의 값을 -∞로 설정

→ softmax 시 <pad>에는 0으로 할당

 

Input Feeding

 

  • Softmax의 결과값은 continuous한 벡터의인데 argmax로 샘플링(원핫벡터)을 취하는 과정에서 정보손실이 발생한다.
    → softmax 이전 값과 word embedding 벡터를 concat해서 정보손실을 최소화
  • Teacher Forcing으로 인한 학습/인퍼런스 사이의 괴리를 최소화
    : Auto regressive task를 feed forward할 때는 보통 이전 time-step의 output이 현재 time-step의 입력이 되나, train 할 때는 teacher forcing을 통해 원래 정답을 넣어주어 학습하기 때문에 학습과 인퍼런스 사이에 괴리가 생기게 된다.

 

 

class Seq2Seq(nn.Module): def __init__( self, input_size, word_vec_size, hidden_size, output_size, n_layers=4, dropout_p=0.2, ): self.input_size = input_size self.word_vec_size = word_vec_size self.hidden_size = hidden_size self.output_size = output_size self.n_layers = n_layers self.dropout_p = dropout_p super(Seq2Seq, self).__init__() self.emb_src = nn.Embedding(input_size, word_vec_size) self.emb_dec = nn.Embedding(output_size, word_vec_size) self.encoder = Encoder( word_vec_size, hidden_size, n_layers=n_layers, dropout_p=dropout_p ) self.decoder = Decoder( word_vec_size, hidden_size, n_layers=n_layers, dropout_p=dropout_p ) self.attention = Attention(hidden_size) self.concat = nn.Linear(hidden_size * 2, hidden_size) self.tanh = nn.Tanh() self.generator = Generator(hidden_size, output_size) def generate_mask(self, x, length): mask = [] max_length = max(length) for l in length: if max_length - l > 0: mask += [ torch.cat( [x.new_ones(1, l).zero_(), x.new_ones(1, (max_length - l))], dim=-1, ) ] else: mask += [x.new_ones(1, l).zero_()] mask = torch.cat(mask, dim=0).bool() return mask def merge_encoder_hiddens(self, encoder_hiddens): # Merge bidirectional to uni-directional # (n_layers * 2, bs, hidden_size /2) => (n_layers, bs, hidden_size) h_0_tgt, c_0_tgt = encoder_hiddens batch_size = h_0_tgt.size(1) h_0_tgt = ( h_0_tgt.transpose(0, 1) .contiguous() .view(batch_size, -1, self.hidden_size) .transpose(0, 1) .contiguous() ) c_0_tgt = ( c_0_tgt.transpose(0, 1) .contiguous() .view(batch_size, -1, self.hidden_size) .transpose(0, 1) .contiguous() ) return h_0_tgt, c_0_tgt def forward(self, src, tgt): # |src| = (bs, n) # |tgt| = (bs, m) batch_size = tgt.size(0) mask = None x_length = None if isinstance(src, tuple): x, x_length = src mask = self.generate_mask(x, x_length) # |mask| = (bs, length) else: x = src emb_src = self.emb_src(x) # |emb_src| = (bs, length, word_vec_size) h_src, h_0_tgt = self.encoder((emb_src, x_length)) # |h_src| = (bs, length, hidden_size) # |h_0_tgt| = (n_layers * 2, batch_size, hidden_size/2) # last hidden state of the encoder would be an initial hidden state of decoder h_0_tgt = self.merge_enocder_hiddens(h_0_tgt) emb_tgt = self.emb_dec(tgt) # |emb_tgt| = (bs, length(m), word_vec_size) h_tilde = [] h_t_tilde = None decoder_hidden = h_0_tgt # Run decoder time by time step for t in range(tgt.size(1)): # Teacher forcing => take each input from training set, not from the last time step's output emb_t = emb_tgt[:, t, :].unsqueeze(1) # |emb_t| = (bs, 1, word_vec_size) decoder_output, decoder_hidden = self.decoder( emb_t, h_t_tilde, decoder_hidden ) # |decoder_output| = (bs, 1, hs) # |decoder_hidden| = (n_layers, 1, hs) context_vector = self.attention(h_src, decoder_output, mask) # |context_vector| = (bs, 1, hs) h_t_tilde = self.tanh( self.concat(torch.cat([decoder_output, context_vector], dim=-1)) ) # |h_t_tilde| = (bs, 1, hs) h_tilde += [h_t_tilde] h_tilde = torch.cat(h_tilde, dim=1) # |h_tilde| = (bs, length, hs) y_hat = self.generator(h_tilde) # |y_hat| = (bs, length, output_size) return y_hat

 

728x90
Contents