Requirement: build a Seq2Seq model that takes a string of digits as input and outputs the same string with a 0 appended. For example:
Input: 15925858456, output: 159258584560
```python
import torch
import numpy as np
import torch.nn as nn
import torch.nn.functional as F
```
Define the vocabulary

```python
UNK_TAG = "UNK"
PAD_TAG = "PAD"
SOS_TAG = "SOS"
EOS_TAG = "EOS"

word_list = list('0123456789') + [UNK_TAG, PAD_TAG, SOS_TAG, EOS_TAG]

word2index_dict = {j: i for i, j in enumerate(word_list)}
index2word_dict = {i: j for i, j in enumerate(word_list)}
```
```python
def word2index(sequence, word2index_dict=word2index_dict, max_len=None, add_sos=False, add_eos=False):
    sequence = [word2index_dict[i] if i in word2index_dict else word2index_dict[UNK_TAG] for i in sequence]
    if add_sos:
        sequence = [word2index_dict[SOS_TAG]] + sequence
    if add_eos:
        sequence += [word2index_dict[EOS_TAG]]
    sequence_length = len(sequence)
    if max_len and sequence_length < max_len:
        sequence_index = [word2index_dict[PAD_TAG]] * max_len
        sequence_index[:sequence_length] = sequence
        sequence = sequence_index
    elif max_len and sequence_length > max_len:
        sequence = sequence[:max_len]
        if add_eos:
            sequence[-1] = word2index_dict[EOS_TAG]
    return sequence
```
```python
word2index('256246', word2index_dict, max_len=10, add_sos=True, add_eos=True)
```

```
[12, 2, 5, 6, 2, 4, 6, 13, 11, 11]
```

Here 12 is the SOS index, 13 the EOS index, and 11 the PAD index.
```python
def index2word(sequence, index2word_dict=index2word_dict):
    return ''.join([index2word_dict[i] for i in sequence])
```

```python
index2word([2, 5, 6, 2, 6, 5], index2word_dict)
```

```
'256265'
```
Dataset definition: a custom Dataset

```python
import torch
import torch.nn as nn
import torch.nn.functional as F

from torch.utils.data import Dataset, DataLoader
```

```python
class RandomNumDataset(Dataset):
    def __init__(self, total_data_size):
        super(RandomNumDataset, self).__init__()
        self.total_data_size = total_data_size
        self.total_data = torch.randint(10000, 1000000, size=[total_data_size])

    def __len__(self):
        return self.total_data_size

    def __getitem__(self, idx):
        x = str(self.total_data[idx].item())
        y = x + '0'
        return x, y
```

```python
random_num_dataset = RandomNumDataset(100000)
random_num_dataset[10660]
```

```
('157461', '1574610')
```
Custom DataLoader

```python
def collate_fn(batch):
    """Post-process the mini-batches produced by the DataLoader."""
    X = []
    Y = []
    target = []
    for i, (x, y) in enumerate(batch):
        X.append(torch.LongTensor(word2index(x, max_len=6)))
        Y.append(torch.LongTensor(word2index(y, max_len=8, add_sos=True)))
        target.append(torch.LongTensor(word2index(y, max_len=8, add_eos=True)))
    return torch.stack(X), torch.stack(Y), torch.stack(target)
```

```python
batch_size = 256  # not defined in the original post; an assumed value consistent with the 99840/100000 progress shown in the training logs
train_dataloader = DataLoader(dataset=RandomNumDataset(100000), batch_size=batch_size, collate_fn=collate_fn, drop_last=True)
test_dataloader = DataLoader(dataset=RandomNumDataset(50000), batch_size=batch_size, collate_fn=collate_fn, drop_last=True)
```
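As a quick sanity check (not in the original post), the shapes produced by `collate_fn` can be inspected directly; the sizes follow from the `max_len` values of 6 and 8 used above:

```python
# X is the encoder input (max_len=6); Y is the SOS-shifted decoder input and
# target holds the EOS-terminated labels (both max_len=8).
X, Y, target = next(iter(train_dataloader))
print(X.shape, Y.shape, target.shape)
# prints torch.Size([batch_size, 6]) torch.Size([batch_size, 8]) torch.Size([batch_size, 8]),
# with the numeric batch size in place of batch_size
```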
Model construction: the encoder

```python
class Seq2SeqEncoder(nn.Module):
    """An LSTM-based encoder; other encoders, such as a CNN or a TransformerEncoder, would also work."""
    def __init__(self, embedding_dim, hidden_size, source_vocab_size):
        super(Seq2SeqEncoder, self).__init__()
        self.embedding_table = nn.Embedding(source_vocab_size, embedding_dim)
        self.lstm_layer = nn.LSTM(input_size=embedding_dim, hidden_size=hidden_size, batch_first=True)

    def forward(self, input_ids):
        input_sequence = self.embedding_table(input_ids)
        output_states, final_state = self.lstm_layer(input_sequence)
        return output_states, final_state
```
The decoder

```python
class Seq2SeqDecoder(nn.Module):
    def __init__(self, embedding_dim, hidden_size, target_vocab_size, start_id, end_id):
        super(Seq2SeqDecoder, self).__init__()
        self.embedding_table = nn.Embedding(target_vocab_size, embedding_dim)
        self.lstm_cell = nn.LSTMCell(embedding_dim, hidden_size)
        self.proj_layer = nn.Linear(hidden_size, target_vocab_size)
        self.target_vocab_size = target_vocab_size
        self.start_id = start_id
        self.end_id = end_id

    def forward(self, shifted_target_ids, final_encoder):
        h_t, c_t = final_encoder
        h_t = h_t.squeeze(0)
        c_t = c_t.squeeze(0)
        shifted_target = self.embedding_table(shifted_target_ids)
        bs, target_length, embedding_dim = shifted_target.shape
        logits = torch.zeros(bs, target_length, self.target_vocab_size)
        for t in range(target_length):
            if t == 0:
                # the first step is fed the SOS embedding from the shifted target
                decoder_input_t = shifted_target[:, t, :]
            else:
                # later steps are fed the previous prediction (free running)
                target_id = logits[:, t - 1, :].argmax(axis=1)
                decoder_input_t = self.embedding_table(target_id.to(device))  # `device` is the global set in the training section below
            h_t, c_t = self.lstm_cell(decoder_input_t, (h_t, c_t))
            logits[:, t, :] = self.proj_layer(h_t)
        return logits

    def inference(self, final_encoder):
        h_t, c_t = final_encoder
        h_t = h_t.squeeze(0)
        c_t = c_t.squeeze(0)
        batch_size = h_t.shape[0]
        target_id = self.start_id
        target_id = torch.stack([target_id] * batch_size, 0)
        result = []
        while True:
            decoder_input_t = self.embedding_table(target_id)
            decoder_input_t = decoder_input_t.squeeze(1)
            h_t, c_t = self.lstm_cell(decoder_input_t, (h_t, c_t))
            logits = self.proj_layer(h_t)
            target_id = logits.argmax(axis=1)
            result.append(target_id)
            if torch.all(target_id == self.end_id) or len(result) == 8:
                print('stop decoding!')
                break
        predicted_ids = torch.stack(result, dim=0)
        return predicted_ids
```
The Seq2Seq model

```python
class Seq2SeqModel(nn.Module):
    def __init__(self, embedding_dim, hidden_size, source_vocab_size, target_vocab_size, start_id, end_id):
        super(Seq2SeqModel, self).__init__()
        self.encoder = Seq2SeqEncoder(embedding_dim, hidden_size, source_vocab_size)
        self.decoder = Seq2SeqDecoder(embedding_dim, hidden_size, target_vocab_size, start_id, end_id)

    def forward(self, input_sequence_ids, shifted_target_ids):
        encoder_states, final_encoder = self.encoder(input_sequence_ids)
        logits = self.decoder(shifted_target_ids, final_encoder)
        return logits

    def inference(self, input_sequence_ids):
        encoder_states, final_encoder = self.encoder(input_sequence_ids)
        predicted_ids = self.decoder.inference(final_encoder)
        return predicted_ids
```

```python
embedding_dim = 20
hidden_size = 16
start_id = torch.tensor([word2index_dict[SOS_TAG]], dtype=torch.long)
end_id = torch.tensor([word2index_dict[EOS_TAG]], dtype=torch.long)
source_vocab_size = len(word_list)
target_vocab_size = len(word_list)

model = Seq2SeqModel(embedding_dim, hidden_size, source_vocab_size, target_vocab_size, start_id, end_id)
```
```
Seq2SeqModel(
  (encoder): Seq2SeqEncoder(
    (embedding_table): Embedding(14, 20)
    (lstm_layer): LSTM(20, 16, batch_first=True)
  )
  (decoder): Seq2SeqDecoder(
    (embedding_table): Embedding(14, 20)
    (lstm_cell): LSTMCell(20, 16)
    (proj_layer): Linear(in_features=16, out_features=14, bias=True)
  )
)
```
Initializing the model weights

```python
def init_network(model, method='xavier', exclude='embedding', seed=123):
    for name, w in model.named_parameters():
        if exclude not in name:
            if 'weight' in name:
                if method == 'xavier':
                    nn.init.xavier_normal_(w)
                elif method == 'kaiming':
                    nn.init.kaiming_normal_(w)
                else:
                    nn.init.normal_(w)
            elif 'bias' in name:
                nn.init.constant_(w, 0)
            else:
                pass
```
Defining the training logger

```python
import time  # used by PrintLog and the training loop below

class PrintLog:
    def __init__(self, log_filename):
        self.log_filename = log_filename

    def print_(self, message, only_file=False, add_time=True, cost_time=None):
        if add_time:
            current_time = time.strftime("%Y-%m-%d %H:%M:%S", time.localtime())
            message = f"Time {current_time} : {message}"
        if cost_time:
            message = f"{message} Cost Time: {cost_time}s"
        with open(self.log_filename, 'a') as f:
            f.write(message + '\n')
        if not only_file:
            print(message)
```
Model training

```python
import math  # used for the progress bar

def train(model, device, train_loader, criterion, optimizer, epoch, print_log):
    model.train()
    total = 0
    start_time = time.time()
    for batch_idx, (encode_data, decode_data, target_data) in enumerate(train_loader):
        encode_data, decode_data, target_data = encode_data.to(device), decode_data.to(device), target_data.to(device)
        optimizer.zero_grad()
        output = model(encode_data, decode_data)
        loss = criterion(output.view(batch_size * 8, -1).to(device), target_data.view(-1).to(device))
        loss.backward()
        optimizer.step()
        total += len(encode_data)
        progress = math.ceil(batch_idx / len(train_loader) * 50)
        print("\rTrain epoch %d: %d/%d, [%-51s] %d%%" %
              (epoch, total, len(train_loader.dataset), '-' * progress + '>', progress * 2), end='')
    end_time = time.time()
    epoch_cost_time = end_time - start_time
    print()
    print_log.print_(f"Train epoch {epoch}, Loss {loss}", cost_time=epoch_cost_time)
```
```python
def test(model, device, test_loader, criterion, print_log):
    model.eval()
    test_loss = 0
    correct = 0
    with torch.no_grad():
        for encode_data, decode_data, target_data in test_loader:
            encode_data, decode_data, target_data = encode_data.to(device), decode_data.to(device), target_data.to(device)
            output = model(encode_data, decode_data).view(batch_size * 8, -1)
            test_loss += criterion(output.to(device), target_data.view(-1).to(device)).item()
            pred = output.argmax(dim=1, keepdim=True).to(device)
            correct += pred.eq(target_data.view_as(pred).to(device)).sum().item()

    test_loss /= len(test_loader.dataset)
    print_log.print_('Test: average loss: {:.4f}, accuracy: {}/{} ({:.0f}%)'.format(
        test_loss, correct, len(test_loader.dataset) * 8,
        100. * correct / (len(test_loader.dataset) * 8)))
    return test_loss
```
```python
def run_train(model, epochs, train_loader, test_loader, device, resume="", model_name=""):
    log_filename = f"{model_name}_training.log"
    print_log = PrintLog(log_filename)

    model.to(device)
    optimizer = torch.optim.SGD(model.parameters(), lr=0.001, momentum=0.9)
    criterion = nn.CrossEntropyLoss()

    min_loss = torch.inf
    start_epoch = 0
    delta = 1e-4

    if resume:
        print_log.print_(f'loading from {resume}')
        checkpoint = torch.load(resume, map_location=torch.device("cuda:0"))
        model.load_state_dict(checkpoint['model_state_dict'])
        optimizer.load_state_dict(checkpoint['optimizer_state_dict'])
        start_epoch = checkpoint['epoch']
        min_loss = checkpoint['loss']

    for epoch in range(start_epoch + 1, start_epoch + epochs + 1):
        train(model, device, train_loader, criterion, optimizer, epoch, print_log)
        loss = test(model, device, test_loader, criterion, print_log)
        # save only when the loss improvement is not negligibly small (delta tolerance)
        if loss < min_loss and not torch.isclose(torch.tensor([min_loss]), torch.tensor([loss]), delta):
            print_log.print_(f'Loss Reduce {min_loss} to {loss}')
            min_loss = loss
            save_file = f'{model_name}_checkpoint_best.pt'
            torch.save({
                'epoch': epoch + 1,
                'model_state_dict': model.state_dict(),
                'optimizer_state_dict': optimizer.state_dict(),
                'loss': loss
            }, save_file)
            print_log.print_(f'Save checkpoint to {save_file}')
        print('----------------------------------------')
```
```python
train_dataloader = DataLoader(dataset=RandomNumDataset(100000), batch_size=batch_size, collate_fn=collate_fn, drop_last=True)
test_dataloader = DataLoader(dataset=RandomNumDataset(50000), batch_size=batch_size, collate_fn=collate_fn, drop_last=True)
```

```python
epochs = 1000
device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')

run_train(model, epochs, train_dataloader, test_dataloader, device, model_name='Seq2SeqNormal')
```
```
Train epoch 1: 99840/100000, [-------------------------------------------------->] 100%
Time 2023-12-18 19-19-53 : Train epoch 1, Loss 2.432358741760254 Cost Time: 4.7405595779418945s
Time 2023-12-18 19-19-54 : Test: average loss: 0.0095, accuracy: 79791/400000 (20%)
Time 2023-12-18 19-19-54 : Loss Reduce inf to 0.009483050050735473
Time 2023-12-18 19-19-54 : Save checkpoint to Seq2SeqNormal_checkpoint_best.pt
----------------------------------------
......
Train epoch 509: 99840/100000, [-------------------------------------------------->] 100%
Time 2023-12-18 20-12-07 : Train epoch 509, Loss 0.3511051535606384 Cost Time: 4.616323471069336s
Time 2023-12-18 20-12-08 : Test: average loss: 0.0013, accuracy: 352755/400000 (88%)
Time 2023-12-18 20-12-08 : Loss Reduce 0.0013051700329780578 to 0.0013038752526044846
Time 2023-12-18 20-12-08 : Save checkpoint to Seq2SeqNormal_checkpoint_best.pt
----------------------------------------
Train epoch 510: 99840/100000, [-------------------------------------------------->] 100%
Time 2023-12-18 20-12-13 : Train epoch 510, Loss 0.3496689796447754 Cost Time: 4.619180202484131s
```
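Training can later be continued from the saved checkpoint via `run_train`'s `resume` argument; a hedged usage sketch (this call does not appear in the original post, and the epoch count is illustrative):

```python
# Resume from the best checkpoint and train for another 100 epochs.
run_train(model, 100, train_dataloader, test_dataloader, device,
          resume='Seq2SeqNormal_checkpoint_best.pt', model_name='Seq2SeqNormal')
```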
Model inference

```python
dev_dataloader = DataLoader(dataset=RandomNumDataset(10), batch_size=10, collate_fn=collate_fn, drop_last=True)
encode_data, decode_data, target_data = next(iter(dev_dataloader))
```

```python
embedding_dim = 20
hidden_size = 16
start_id = torch.tensor([word2index_dict[SOS_TAG]], dtype=torch.long)
end_id = torch.tensor([word2index_dict[EOS_TAG]], dtype=torch.long)
source_vocab_size = len(word_list)
target_vocab_size = len(word_list)

model = Seq2SeqModel(embedding_dim, hidden_size, source_vocab_size, target_vocab_size, start_id, end_id)
```

```python
resume = 'Seq2SeqNormal_checkpoint_best.pt'
checkpoint = torch.load(resume, map_location=torch.device("cuda:0"))
model.load_state_dict(checkpoint['model_state_dict'])
```

```
<All keys matched successfully>
```

```python
def tensor2str(tensor):
    """Convert a batch of index tensors back to digit strings, dropping PAD tokens (index 11)."""
    str_list = tensor.numpy().tolist()
    for i in range(len(str_list)):
        pad_num = str_list[i].count(11)
        if pad_num > 0:
            [str_list[i].remove(11) for _ in range(pad_num)]
        for j in range(len(str_list[i])):
            str_list[i][j] = str(str_list[i][j])
    return [''.join(i) for i in str_list]
```

Decoding the source batch with `tensor2str(encode_data)` gives:
```
['466760',
 '826747',
 '324833',
 '507226',
 '435957',
 '766421',
 '611993',
 '190449',
 '334029',
 '301514']
```
```python
tensor2str(model.inference(encode_data).T)
```

```
stop decoding!
['466660013',
 '826747013',
 '325833013',
 '507226013',
 '435957013',
 '766421013',
 '611993013',
 '190449013',
 '344009013',
 '301514013']
```
Here 13 is the index of the end-of-sequence token, so each prediction is the digit string before the trailing 13. Note that several predictions do not match their sources (e.g. 466760 → 4666600, 324833 → 3258330, 334029 → 3440090), which motivates the training optimization below.
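For convenience, here is a small helper that ties together `word2index`, greedy inference and EOS trimming for a single digit string. It is not part of the original post; `predict_string` is an illustrative name, and it assumes the model sits on the CPU, as in the inference cells above:

```python
def predict_string(model, s, max_len=6):
    """Illustrative helper: encode a digit string, run greedy decoding, cut the output at EOS/PAD."""
    model.eval()
    with torch.no_grad():
        input_ids = torch.LongTensor([word2index(s, max_len=max_len)])  # shape [1, max_len]
        pred = model.inference(input_ids).T[0].tolist()                 # [steps] for a batch of one
    out = []
    for idx in pred:
        if idx in (word2index_dict[EOS_TAG], word2index_dict[PAD_TAG]):
            break
        out.append(index2word_dict[idx])
    return ''.join(out)

# e.g. predict_string(model, '157461') should return '1574610' once the model is well trained
```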
Optimizing model training

Teacher forcing: during RNN training, feeding each step the previous step's prediction means that one early mistake can cascade into errors at every following step. How can we speed up convergence?

One option is to feed the ground-truth token as the input to the next step during training, which avoids this error cascade.

In practice the two are mixed: at each step the decoder randomly uses either the ground-truth token or its own previous prediction as the next input.

This mechanism is called teacher forcing: like a teacher correcting us at every step, it lets the model pick up the underlying pattern after enough training. It is implemented in the decoder below.
Decoder definition

```python
class Seq2SeqDecoder(nn.Module):
    def __init__(self, embedding_dim, hidden_size, target_vocab_size, start_id, end_id):
        super(Seq2SeqDecoder, self).__init__()
        self.embedding_table = nn.Embedding(target_vocab_size, embedding_dim)
        self.lstm_cell = nn.LSTMCell(embedding_dim, hidden_size)
        self.proj_layer = nn.Linear(hidden_size, target_vocab_size)
        self.target_vocab_size = target_vocab_size
        self.start_id = start_id
        self.end_id = end_id

    def forward(self, shifted_target_ids, final_encoder):
        h_t, c_t = final_encoder
        h_t = h_t.squeeze(0)
        c_t = c_t.squeeze(0)
        shifted_target = self.embedding_table(shifted_target_ids)
        bs, target_length, embedding_dim = shifted_target.shape
        logits = torch.zeros(bs, target_length, self.target_vocab_size)
        for t in range(target_length):
            # with probability 0.5 feed the ground-truth token (teacher forcing),
            # otherwise feed the previous prediction
            use_teacher_forcing = torch.rand(1)[0] > 0.5
            if t == 0 or use_teacher_forcing:
                decoder_input_t = shifted_target[:, t, :]
            else:
                target_id = logits[:, t - 1, :].argmax(axis=1)
                decoder_input_t = self.embedding_table(target_id.to(device))
            h_t, c_t = self.lstm_cell(decoder_input_t, (h_t, c_t))
            logits[:, t, :] = self.proj_layer(h_t)
        return logits

    def inference(self, final_encoder):
        h_t, c_t = final_encoder
        h_t = h_t.squeeze(0)
        c_t = c_t.squeeze(0)
        batch_size = h_t.shape[0]
        target_id = self.start_id
        target_id = torch.stack([target_id] * batch_size, 0)
        result = []
        while True:
            decoder_input_t = self.embedding_table(target_id)
            decoder_input_t = decoder_input_t.squeeze(1)
            h_t, c_t = self.lstm_cell(decoder_input_t, (h_t, c_t))
            logits = self.proj_layer(h_t)
            target_id = logits.argmax(axis=1)
            result.append(target_id)
            if torch.all(target_id == self.end_id) or len(result) == 8:
                print('stop decoding!')
                break
        predicted_ids = torch.stack(result, dim=0)
        return predicted_ids
```
Model definition

```python
embedding_dim = 20
hidden_size = 16
start_id = torch.tensor([word2index_dict[SOS_TAG]], dtype=torch.long)
end_id = torch.tensor([word2index_dict[EOS_TAG]], dtype=torch.long)
source_vocab_size = len(word_list)
target_vocab_size = len(word_list)

model = Seq2SeqModel(embedding_dim, hidden_size, source_vocab_size, target_vocab_size, start_id, end_id)
```
Model initialization
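The call itself is not shown in the original post; presumably the `init_network` helper defined earlier is applied to the freshly built model, e.g.:

```python
# Assumed step: apply the Xavier initializer defined earlier (embedding layers are excluded).
init_network(model, method='xavier')
```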
Model training

```python
epochs = 510
device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')

run_train(model, epochs, train_dataloader, test_dataloader, device, model_name='Seq2SeqTeacherForcing')
```
```
Train epoch 1: 99840/100000, [-------------------------------------------------->] 100%
Time 2023-12-18 21-56-47 : Train epoch 1, Loss 2.4872562885284424 Cost Time: 4.550509929656982s
Time 2023-12-18 21-56-49 : Test: average loss: 0.0097, accuracy: 67954/400000 (17%)
Time 2023-12-18 21-56-49 : Loss Reduce inf to 0.009680559158325195
Time 2023-12-18 21-56-49 : Save checkpoint to Seq2SeqTeacherForcing_checkpoint_best.pt
......
----------------------------------------
Train epoch 509: 99840/100000, [-------------------------------------------------->] 100%
Time 2023-12-18 22-51-40 : Train epoch 509, Loss 0.043971575796604156 Cost Time: 5.30586314201355s
Time 2023-12-18 22-51-41 : Test: average loss: 0.0002, accuracy: 398088/400000 (100%)
----------------------------------------
Train epoch 510: 99840/100000, [-------------------------------------------------->] 100%
Time 2023-12-18 22-51-47 : Train epoch 510, Loss 0.039883047342300415 Cost Time: 5.410919666290283s
Time 2023-12-18 22-51-49 : Test: average loss: 0.0002, accuracy: 398152/400000 (100%)
----------------------------------------
```
Model inference

```python
embedding_dim = 20
hidden_size = 16
start_id = torch.tensor([word2index_dict[SOS_TAG]], dtype=torch.long)
end_id = torch.tensor([word2index_dict[EOS_TAG]], dtype=torch.long)
source_vocab_size = len(word_list)
target_vocab_size = len(word_list)

model = Seq2SeqModel(embedding_dim, hidden_size, source_vocab_size, target_vocab_size, start_id, end_id)
```

```python
resume = 'Seq2SeqTeacherForcing_checkpoint_best.pt'
checkpoint = torch.load(resume, map_location=torch.device("cuda:0"))
model.load_state_dict(checkpoint['model_state_dict'])
```

```
<All keys matched successfully>
```
The same source batch as before, decoded with `tensor2str(encode_data)`:

```
['466760',
 '826747',
 '324833',
 '507226',
 '435957',
 '766421',
 '611993',
 '190449',
 '334029',
 '301514']
```
```python
tensor2str(model.inference(encode_data).T)
```

```
stop decoding!
['466760013',
 '826747013',
 '324833013',
 '507226013',
 '435957013',
 '766421013',
 '611993013',
 '190449013',
 '334029013',
 '301514013']
```
Again, 13 is the end-of-sequence token index. This time every prediction matches its source sequence with a 0 appended, confirming that teacher forcing substantially improves training on this task.