
Python Data Science 40: A Seq2Seq-Based Digit Generation Task

Goal: build a Seq2Seq model that takes a string of digits as input and outputs the same string with a '0' appended.

For example:

  • Input: 15925858456 → Output: 159258584560
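Stated as plain Python, the mapping the model has to learn is just string concatenation. The helper below is only an illustrative ground-truth reference for the task (the name append_zero is not part of the model code):

def append_zero(digits: str) -> str:
    # The rule the Seq2Seq model is trained to reproduce: copy the digits and append '0'
    return digits + '0'

assert append_zero('15925858456') == '159258584560'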
import torch
import numpy as np
import torch.nn as nn
import torch.nn.functional as F

Define the vocabulary

UNK_TAG = "UNK"  # unknown token
PAD_TAG = "PAD"  # padding token
SOS_TAG = "SOS"  # start-of-sequence token
EOS_TAG = "EOS"  # end-of-sequence token
# Define the vocabulary: the ten digits plus the special tokens
word_list = list('0123456789') + [UNK_TAG, PAD_TAG, SOS_TAG, EOS_TAG]
# Define the dictionaries
word2index_dict = {j: i for i, j in enumerate(word_list)}  # token -> index
index2word_dict = {i: j for i, j in enumerate(word_list)}  # index -> token
# Convert a token sequence into a list of indices
def word2index(sequence, word2index_dict=word2index_dict, max_len=None, add_sos=False, add_eos=False):
    sequence = [word2index_dict[i] if i in word2index_dict else word2index_dict[UNK_TAG] for i in sequence]
    # Prepend the start token
    if add_sos:
        sequence = [word2index_dict[SOS_TAG]] + sequence
    # Append the end token
    if add_eos:
        sequence += [word2index_dict[EOS_TAG]]
    # If max_len is given, pad or truncate the sequence to that length
    sequence_length = len(sequence)
    if max_len and sequence_length < max_len:
        # Shorter than max_len: pad with PAD
        sequence_index = [word2index_dict[PAD_TAG]] * max_len
        sequence_index[:sequence_length] = sequence
        sequence = sequence_index
    elif max_len and sequence_length > max_len:
        # Longer than max_len: truncate
        sequence = sequence[:max_len]
        # Truncation may have cut off the end token, so restore it
        if add_eos:
            sequence[-1] = word2index_dict[EOS_TAG]
    return sequence
word2index('256246', word2index_dict, max_len=10, add_sos=True, add_eos=True)
[12, 2, 5, 6, 2, 4, 6, 13, 11, 11]
# Convert a list of indices back into a token string
def index2word(sequence, index2word_dict=index2word_dict):
    return ''.join([index2word_dict[i] for i in sequence])
index2word([2, 5, 6, 2, 6, 5], index2word_dict)
'256265'

Dataset definition

Custom Dataset

import torch
import torch.nn as nn
import torch.nn.functional as F
from torch.utils.data import Dataset, DataLoader
class RandomNumDataset(Dataset):
    def __init__(self, total_data_size):
        super(RandomNumDataset, self).__init__()
        self.total_data_size = total_data_size
        self.total_data = torch.randint(10000, 1000000, size=[total_data_size])

    def __len__(self):
        return self.total_data_size

    def __getitem__(self, idx):
        x = str(self.total_data[idx].item())
        y = x + '0'
        return x, y
random_num_dataset = RandomNumDataset(100000)
random_num_dataset[10660]
('157461', '1574610')

Custom DataLoader

def collate_fn(batch):
    """Post-process the samples that make up a mini-batch."""
    # batch: a list of (x, y) string pairs of length batch_size
    X = []       # encoder input:  digits, padded to length 6
    Y = []       # decoder input:  SOS + digits + '0', padded to length 8
    target = []  # training target: digits + '0' + EOS, padded to length 8
    for x, y in batch:
        X.append(torch.LongTensor(word2index(x, max_len=6)))
        Y.append(torch.LongTensor(word2index(y, max_len=8, add_sos=True)))
        target.append(torch.LongTensor(word2index(y, max_len=8, add_eos=True)))
    return torch.stack(X), torch.stack(Y), torch.stack(target)
batch_size = 256
train_dataloader = DataLoader(dataset=RandomNumDataset(100000), batch_size=batch_size, collate_fn=collate_fn, drop_last=True)
test_dataloader = DataLoader(dataset=RandomNumDataset(50000), batch_size=batch_size, collate_fn=collate_fn, drop_last=True)
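As a quick sanity check (a sketch drawing one batch from the train_dataloader defined above), a mini-batch consists of three LongTensors: the encoder input padded to length 6, the SOS-shifted decoder input of length 8, and the EOS-terminated target of length 8.

# Peek at one mini-batch to confirm the layout produced by collate_fn
X, Y, target = next(iter(train_dataloader))
print(X.shape, Y.shape, target.shape)    # expected: [256, 6], [256, 8], [256, 8]
print(index2word(Y[0].tolist()))         # e.g. 'SOS1574610', possibly followed by 'PAD'
print(index2word(target[0].tolist()))    # e.g. '1574610EOS', possibly followed by 'PAD'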

Model construction

Building the encoder

class Seq2SeqEncoder(nn.Module):
    """LSTM-based encoder; other encoders (CNN, TransformerEncoder, ...) could be used instead."""
    def __init__(self, embedding_dim, hidden_size, source_vocab_size):
        super(Seq2SeqEncoder, self).__init__()
        self.embedding_table = nn.Embedding(source_vocab_size, embedding_dim)
        self.lstm_layer = nn.LSTM(input_size=embedding_dim,
                                  hidden_size=hidden_size,
                                  batch_first=True)

    def forward(self, input_ids):
        input_sequence = self.embedding_table(input_ids)
        output_states, final_state = self.lstm_layer(input_sequence)
        return output_states, final_state

Building the decoder

class Seq2SeqDecoder(nn.Module):
    def __init__(self, embedding_dim, hidden_size, target_vocab_size, start_id, end_id):
        super(Seq2SeqDecoder, self).__init__()
        self.embedding_table = nn.Embedding(target_vocab_size, embedding_dim)
        self.lstm_cell = nn.LSTMCell(embedding_dim, hidden_size)
        self.proj_layer = nn.Linear(hidden_size, target_vocab_size)
        self.target_vocab_size = target_vocab_size

        self.start_id = start_id  # start-of-sequence id
        self.end_id = end_id      # end-of-sequence id

    def forward(self, shifted_target_ids, final_encoder):
        # Called during training
        h_t, c_t = final_encoder
        # The encoder's LSTM returns states of shape [num_layers, bs, hidden_size],
        # while LSTMCell performs a single step and expects states of shape [bs, hidden_size],
        # so the num_layers dimension is removed here.
        h_t = h_t.squeeze(0)
        c_t = c_t.squeeze(0)
        shifted_target = self.embedding_table(shifted_target_ids)

        bs, target_length, embedding_dim = shifted_target.shape

        # Container for the per-step outputs, kept on the same device as the inputs
        logits = torch.zeros(bs, target_length, self.target_vocab_size, device=shifted_target.device)
        for t in range(target_length):
            if t == 0:
                # First step: feed the ground-truth (shifted) target
                decoder_input_t = shifted_target[:, t, :]  # [bs, embedding_dim]
            else:
                # Later steps: feed the previous prediction
                target_id = logits[:, t - 1, :].argmax(axis=1)
                decoder_input_t = self.embedding_table(target_id)

            h_t, c_t = self.lstm_cell(decoder_input_t, (h_t, c_t))
            # h_t: [bs, hidden_size], c_t: [bs, hidden_size]
            logits[:, t, :] = self.proj_layer(h_t)
        return logits

    def inference(self, final_encoder):
        # Called at inference time
        h_t, c_t = final_encoder
        h_t = h_t.squeeze(0)
        c_t = c_t.squeeze(0)
        batch_size = h_t.shape[0]
        target_id = torch.stack([self.start_id] * batch_size, 0).to(h_t.device)
        result = []

        while True:
            decoder_input_t = self.embedding_table(target_id)
            decoder_input_t = decoder_input_t.squeeze(1)
            h_t, c_t = self.lstm_cell(decoder_input_t, (h_t, c_t))
            logits = self.proj_layer(h_t)
            target_id = logits.argmax(axis=1)

            result.append(target_id)
            # Stop once every sequence in the mini-batch has predicted end_id.
            # Note: without an end token the loop would never terminate,
            # so decoding is also capped at 8 steps.
            if torch.all(target_id == self.end_id.to(target_id.device)) or len(result) == 8:
                print('stop decoding!')
                break
        predicted_ids = torch.stack(result, dim=0)
        return predicted_ids

Building the Seq2Seq model

class Seq2SeqModel(nn.Module):
    def __init__(self, embedding_dim, hidden_size, source_vocab_size, target_vocab_size, start_id, end_id):
        super(Seq2SeqModel, self).__init__()
        self.encoder = Seq2SeqEncoder(embedding_dim, hidden_size, source_vocab_size)
        self.decoder = Seq2SeqDecoder(embedding_dim, hidden_size, target_vocab_size, start_id, end_id)

    def forward(self, input_sequence_ids, shifted_target_ids):
        # Training
        encoder_states, final_encoder = self.encoder(input_sequence_ids)
        logits = self.decoder(shifted_target_ids, final_encoder)
        return logits

    def inference(self, input_sequence_ids):
        # Inference
        encoder_states, final_encoder = self.encoder(input_sequence_ids)
        predicted_ids = self.decoder.inference(final_encoder)
        return predicted_ids
# Hyperparameters
embedding_dim = 20  # embedding dimension
hidden_size = 16    # LSTM hidden_size
start_id = torch.tensor([word2index_dict[SOS_TAG], ], dtype=torch.long)  # start-of-sequence id
end_id = torch.tensor([word2index_dict[EOS_TAG], ], dtype=torch.long)    # end-of-sequence id
source_vocab_size = len(word_list)
target_vocab_size = len(word_list)

model = Seq2SeqModel(embedding_dim, hidden_size, source_vocab_size, target_vocab_size, start_id, end_id)
model
Seq2SeqModel(
  (encoder): Seq2SeqEncoder(
    (embedding_table): Embedding(14, 20)
    (lstm_layer): LSTM(20, 16, batch_first=True)
  )
  (decoder): Seq2SeqDecoder(
    (embedding_table): Embedding(14, 20)
    (lstm_cell): LSTMCell(20, 16)
    (proj_layer): Linear(in_features=16, out_features=14, bias=True)
  )
)
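Before training, a dry run of one batch through the untrained model (a sketch, drawing a fresh batch from train_dataloader) confirms that the decoder emits a score over the 14 vocabulary entries at each of the 8 target positions, which is exactly what nn.CrossEntropyLoss expects once flattened.

# Shape check: forward one batch through the untrained Seq2Seq model
X, Y, target = next(iter(train_dataloader))
with torch.no_grad():
    logits = model(X, Y)
print(logits.shape)   # expected: torch.Size([256, 8, 14])
print(target.shape)   # expected: torch.Size([256, 8])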

Initialize the model weights

# Weight initialization, Xavier by default
def init_network(model, method='xavier', exclude='embedding', seed=123):
    for name, w in model.named_parameters():
        # Skip the embedding layers (typically excluded, e.g. when pretrained embeddings are used)
        if exclude not in name:
            if 'weight' in name:
                if method == 'xavier':
                    nn.init.xavier_normal_(w)
                elif method == 'kaiming':
                    nn.init.kaiming_normal_(w)
                else:
                    nn.init.normal_(w)
            elif 'bias' in name:
                nn.init.constant_(w, 0)
            else:
                pass
init_network(model)

Define a training-log class

class PrintLog:
    def __init__(self, log_filename):
        self.log_filename = log_filename

    def print_(self, message, only_file=False, add_time=True, cost_time=None):
        if add_time:
            current_time = time.strftime("%Y-%m-%d %H:%M:%S", time.localtime())
            message = f"Time {current_time} : {message}"

        if cost_time:
            message = f"{message} Cost Time: {cost_time}s"

        with open(self.log_filename, 'a') as f:
            f.write(message + '\n')
        if not only_file:
            print(message)

Model training

import time, math
# Training function
def train(model, device, train_loader, criterion, optimizer, epoch, print_log):
    model.train()  # training mode: parameters will be updated
    total = 0      # number of samples processed so far
    start_time = time.time()
    for batch_idx, (encode_data, decode_data, target_data) in enumerate(train_loader):
        encode_data, decode_data, target_data = encode_data.to(device), decode_data.to(device), target_data.to(device)
        optimizer.zero_grad()
        output = model(encode_data, decode_data)

        # output: [batch_size, 8, vocab] -> flatten to [batch_size*8, vocab] for CrossEntropyLoss
        # (relies on the global batch_size and drop_last=True)
        loss = criterion(output.view(batch_size * 8, -1).to(device), target_data.view(-1).to(device))
        loss.backward()
        optimizer.step()

        total += len(encode_data)
        progress = math.ceil(batch_idx / len(train_loader) * 50)
        print("\rTrain epoch %d: %d/%d, [%-51s] %d%%" %
              (epoch, total, len(train_loader.dataset),
               '-' * progress + '>', progress * 2), end='')
    end_time = time.time()
    epoch_cost_time = end_time - start_time
    print()  # newline after the progress bar
    print_log.print_(f"Train epoch {epoch}, Loss {loss}", cost_time=epoch_cost_time)
# Evaluation function
def test(model, device, test_loader, criterion, print_log):
    model.eval()  # evaluation mode
    test_loss = 0
    correct = 0
    # Disable autograd to speed up evaluation
    with torch.no_grad():
        for encode_data, decode_data, target_data in test_loader:
            encode_data, decode_data, target_data = encode_data.to(device), decode_data.to(device), target_data.to(device)
            output = model(encode_data, decode_data).view(batch_size * 8, -1)
            test_loss += criterion(output.to(device), target_data.view(-1).to(device)).item()
            pred = output.argmax(dim=1, keepdim=True).to(device)
            correct += pred.eq(target_data.view_as(pred).to(device)).sum().item()
    test_loss /= len(test_loader.dataset)

    # Accuracy is counted per token: 8 target positions per sample
    print_log.print_('Test: average loss: {:.4f}, accuracy: {}/{} ({:.0f}%)'.format(
        test_loss, correct, len(test_loader.dataset) * 8,
        100. * correct / (len(test_loader.dataset) * 8)))
    return test_loss  # return the test loss
# Main training loop
def run_train(model, epochs, train_loader, test_loader, device, resume="", model_name=""):
    log_filename = f"{model_name}_training.log"
    print_log = PrintLog(log_filename)
    # Move the model to the selected device
    model.to(device)
    # SGD with momentum
    optimizer = torch.optim.SGD(model.parameters(), lr=0.001, momentum=0.9)
    criterion = nn.CrossEntropyLoss()

    min_loss = torch.inf
    start_epoch = 0
    delta = 1e-4

    # Optionally resume from a checkpoint
    if resume:
        print_log.print_(f'loading from {resume}')
        checkpoint = torch.load(resume, map_location=device)
        model.load_state_dict(checkpoint['model_state_dict'])
        optimizer.load_state_dict(checkpoint['optimizer_state_dict'])
        start_epoch = checkpoint['epoch']
        min_loss = checkpoint['loss']

    for epoch in range(start_epoch + 1, start_epoch + epochs + 1):
        train(model, device, train_loader, criterion, optimizer, epoch, print_log)
        loss = test(model, device, test_loader, criterion, print_log)
        # Save a checkpoint whenever the test loss drops by more than delta (1e-4, vs. the default 1e-5)
        if loss < min_loss and not torch.isclose(torch.tensor([min_loss]), torch.tensor([loss]), rtol=delta):
            print_log.print_(f'Loss Reduce {min_loss} to {loss}')
            min_loss = loss
            save_file = f'{model_name}_checkpoint_best.pt'
            torch.save({
                'epoch': epoch + 1,
                'model_state_dict': model.state_dict(),
                'optimizer_state_dict': optimizer.state_dict(),
                'loss': loss
            }, save_file)
            print_log.print_(f'Save checkpoint to {save_file}')
        print('----------------------------------------')
batch_size = 256
train_dataloader = DataLoader(dataset=RandomNumDataset(100000), batch_size=batch_size, collate_fn=collate_fn, drop_last=True)
test_dataloader = DataLoader(dataset=RandomNumDataset(50000), batch_size=batch_size, collate_fn=collate_fn, drop_last=True)
epochs = 1000
# Use the GPU if it is available, otherwise fall back to the CPU
device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')

run_train(model, epochs, train_dataloader, test_dataloader, device, model_name='Seq2SeqNormal')
Train epoch 1: 99840/100000, [-------------------------------------------------->] 100%
Time 2023-12-18 19-19-53 : Train epoch 1, Loss 2.432358741760254 Cost Time: 4.7405595779418945s
Time 2023-12-18 19-19-54 : Test: average loss: 0.0095, accuracy: 79791/400000 (20%)
Time 2023-12-18 19-19-54 : Loss Reduce inf to 0.009483050050735473
Time 2023-12-18 19-19-54 : Save checkpoint to Seq2SeqNormal_checkpoint_best.pt
----------------------------------------
......
Train epoch 509: 99840/100000, [-------------------------------------------------->] 100%
Time 2023-12-18 20-12-07 : Train epoch 509, Loss 0.3511051535606384 Cost Time: 4.616323471069336s
Time 2023-12-18 20-12-08 : Test: average loss: 0.0013, accuracy: 352755/400000 (88%)
Time 2023-12-18 20-12-08 : Loss Reduce 0.0013051700329780578 to 0.0013038752526044846
Time 2023-12-18 20-12-08 : Save checkpoint to Seq2SeqNormal_checkpoint_best.pt
----------------------------------------
Train epoch 510: 99840/100000, [-------------------------------------------------->] 100%
Time 2023-12-18 20-12-13 : Train epoch 510, Loss 0.3496689796447754 Cost Time: 4.619180202484131s

Model inference

dev_dataloader = DataLoader(dataset=RandomNumDataset(10), batch_size=10, collate_fn=collate_fn, drop_last=True)
encode_data, decode_data, target_data = next(iter(dev_dataloader))
# Hyperparameters
embedding_dim = 20  # embedding dimension
hidden_size = 16    # LSTM hidden_size
start_id = torch.tensor([word2index_dict[SOS_TAG], ], dtype=torch.long)  # start-of-sequence id
end_id = torch.tensor([word2index_dict[EOS_TAG], ], dtype=torch.long)    # end-of-sequence id
source_vocab_size = len(word_list)
target_vocab_size = len(word_list)

model = Seq2SeqModel(embedding_dim, hidden_size, source_vocab_size, target_vocab_size, start_id, end_id)
resume = 'Seq2SeqNormal_checkpoint_best.pt'
checkpoint = torch.load(resume, map_location=device)
model.load_state_dict(checkpoint['model_state_dict'])
<All keys matched successfully>
def tensor2str(tensor):
    # Render a 2-D index tensor as a list of strings
    str_list = tensor.numpy().tolist()
    for i in range(len(str_list)):
        # Drop padding tokens (PAD has index 11)
        str_list[i] = [j for j in str_list[i] if j != word2index_dict[PAD_TAG]]
        str_list[i] = [str(j) for j in str_list[i]]
    return [''.join(i) for i in str_list]
tensor2str(encode_data)  # 验证数据
['466760',
 '826747',
 '324833',
 '507226',
 '435957',
 '766421',
 '611993',
 '190449',
 '334029',
 '301514']
tensor2str(model.inference(encode_data).T)  # 测试数据
stop decoding!

['466660013',
 '826747013',
 '325833013',
 '507226013',
 '435957013',
 '766421013',
 '611993013',
 '190449013',
 '344009013',
 '301514013']

In the output above, 13 is the index of the end-of-sequence token (EOS). Note that several digits are still predicted incorrectly (for example, input '466760' comes back as '466660' before the appended '0' and EOS).
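To turn the raw index predictions into clean digit strings, the inference output can be post-processed at the index level, cutting each row at the first EOS and keeping only digit indices. The helper below is an illustrative sketch (ids_to_digits is not part of the original code):

def ids_to_digits(pred_ids):
    # pred_ids: one row of predicted indices (e.g. a row of model.inference(...).T)
    out = []
    for idx in pred_ids:
        idx = int(idx)
        if idx == word2index_dict[EOS_TAG]:   # stop at the first EOS (index 13)
            break
        if idx < 10:                          # indices 0-9 map to the digits themselves
            out.append(str(idx))
    return ''.join(out)

# Usage: [ids_to_digits(row) for row in model.inference(encode_data).T]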

Improving model training

Teacher forcing: when an RNN decoder is trained by feeding each step's prediction back in as the next step's input, one early mistake can derail every subsequent step. How can the model be made to converge faster?

  • During training, feed the ground-truth token as the next step's input, which avoids the error cascade
  • At the same time, keep occasionally feeding the prediction as the next input, choosing between the two at random

This mechanism is called Teacher forcing: like a tutor correcting the model at every step, it lets the model pick up the underlying pattern after enough training. A minimal sketch of the per-step choice follows.
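The whole mechanism boils down to a per-step random choice between the ground truth and the model's own previous prediction. The toy snippet below is a minimal, self-contained sketch of that choice (the tensors and the 0.5 ratio are illustrative); the decoder defined below applies the same coin flip at every step t.

import torch

# Toy stand-ins for one decoding step ([batch_size, embedding_dim] tensors)
ground_truth_emb = torch.randn(4, 20)   # embedding of the true previous token
prediction_emb = torch.randn(4, 20)     # embedding of the model's previous prediction

teacher_forcing_ratio = 0.5             # probability of feeding the ground truth
use_teacher_forcing = torch.rand(1).item() < teacher_forcing_ratio
decoder_input_t = ground_truth_emb if use_teacher_forcing else prediction_emb
print('using ground truth this step:', use_teacher_forcing)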


Decoder definition

class Seq2SeqDecoder(nn.Module):
    def __init__(self, embedding_dim, hidden_size, target_vocab_size, start_id, end_id):
        super(Seq2SeqDecoder, self).__init__()
        self.embedding_table = nn.Embedding(target_vocab_size, embedding_dim)
        self.lstm_cell = nn.LSTMCell(embedding_dim, hidden_size)
        self.proj_layer = nn.Linear(hidden_size, target_vocab_size)
        self.target_vocab_size = target_vocab_size

        self.start_id = start_id  # start-of-sequence id
        self.end_id = end_id      # end-of-sequence id

    def forward(self, shifted_target_ids, final_encoder):
        # Called during training
        h_t, c_t = final_encoder
        # The encoder's LSTM returns states of shape [num_layers, bs, hidden_size],
        # while LSTMCell performs a single step and expects states of shape [bs, hidden_size],
        # so the num_layers dimension is removed here.
        h_t = h_t.squeeze(0)
        c_t = c_t.squeeze(0)
        shifted_target = self.embedding_table(shifted_target_ids)

        bs, target_length, embedding_dim = shifted_target.shape

        # Container for the per-step outputs, kept on the same device as the inputs
        logits = torch.zeros(bs, target_length, self.target_vocab_size, device=shifted_target.device)
        for t in range(target_length):
            # Teacher forcing: at each step, flip a coin to decide what the next input is
            use_teacher_forcing = torch.rand(1)[0] > 0.5
            if t == 0 or use_teacher_forcing:
                # Feed the ground-truth (shifted) target
                decoder_input_t = shifted_target[:, t, :]  # [bs, embedding_dim]
            else:
                # Feed the previous prediction
                target_id = logits[:, t - 1, :].argmax(axis=1)
                decoder_input_t = self.embedding_table(target_id)

            h_t, c_t = self.lstm_cell(decoder_input_t, (h_t, c_t))
            # h_t: [bs, hidden_size], c_t: [bs, hidden_size]
            logits[:, t, :] = self.proj_layer(h_t)
        return logits

    def inference(self, final_encoder):
        # Called at inference time
        h_t, c_t = final_encoder
        h_t = h_t.squeeze(0)
        c_t = c_t.squeeze(0)
        batch_size = h_t.shape[0]
        target_id = torch.stack([self.start_id] * batch_size, 0).to(h_t.device)
        result = []

        while True:
            decoder_input_t = self.embedding_table(target_id)
            decoder_input_t = decoder_input_t.squeeze(1)
            h_t, c_t = self.lstm_cell(decoder_input_t, (h_t, c_t))
            logits = self.proj_layer(h_t)
            target_id = logits.argmax(axis=1)

            result.append(target_id)
            # Stop once every sequence in the mini-batch has predicted end_id.
            # Note: without an end token the loop would never terminate,
            # so decoding is also capped at 8 steps.
            if torch.all(target_id == self.end_id.to(target_id.device)) or len(result) == 8:
                print('stop decoding!')
                break
        predicted_ids = torch.stack(result, dim=0)
        return predicted_ids

Model definition

# Hyperparameters
embedding_dim = 20  # embedding dimension
hidden_size = 16    # LSTM hidden_size
start_id = torch.tensor([word2index_dict[SOS_TAG], ], dtype=torch.long)  # start-of-sequence id
end_id = torch.tensor([word2index_dict[EOS_TAG], ], dtype=torch.long)    # end-of-sequence id
source_vocab_size = len(word_list)
target_vocab_size = len(word_list)

model = Seq2SeqModel(embedding_dim, hidden_size, source_vocab_size, target_vocab_size, start_id, end_id)

Model initialization

init_network(model)

Model training

epochs = 510
# Use the GPU if it is available, otherwise fall back to the CPU
device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')

run_train(model, epochs, train_dataloader, test_dataloader, device, model_name='Seq2SeqTeacherForcing')
Train epoch 1: 99840/100000, [-------------------------------------------------->] 100%
Time 2023-12-18 21-56-47 : Train epoch 1, Loss 2.4872562885284424 Cost Time: 4.550509929656982s
Time 2023-12-18 21-56-49 : Test: average loss: 0.0097, accuracy: 67954/400000 (17%)
Time 2023-12-18 21-56-49 : Loss Reduce inf to 0.009680559158325195
Time 2023-12-18 21-56-49 : Save checkpoint to Seq2SeqTeacherForcing_checkpoint_best.pt
......
----------------------------------------
Train epoch 509: 99840/100000, [-------------------------------------------------->] 100%
Time 2023-12-18 22-51-40 : Train epoch 509, Loss 0.043971575796604156 Cost Time: 5.30586314201355s
Time 2023-12-18 22-51-41 : Test: average loss: 0.0002, accuracy: 398088/400000 (100%)
----------------------------------------
Train epoch 510: 99840/100000, [-------------------------------------------------->] 100%
Time 2023-12-18 22-51-47 : Train epoch 510, Loss 0.039883047342300415 Cost Time: 5.410919666290283s
Time 2023-12-18 22-51-49 : Test: average loss: 0.0002, accuracy: 398152/400000 (100%)
----------------------------------------

Model inference

# Hyperparameters
embedding_dim = 20  # embedding dimension
hidden_size = 16    # LSTM hidden_size
start_id = torch.tensor([word2index_dict[SOS_TAG], ], dtype=torch.long)  # start-of-sequence id
end_id = torch.tensor([word2index_dict[EOS_TAG], ], dtype=torch.long)    # end-of-sequence id
source_vocab_size = len(word_list)
target_vocab_size = len(word_list)

model = Seq2SeqModel(embedding_dim, hidden_size, source_vocab_size, target_vocab_size, start_id, end_id)
resume = 'Seq2SeqTeacherForcing_checkpoint_best.pt'
checkpoint = torch.load(resume, map_location=device)
model.load_state_dict(checkpoint['model_state_dict'])
<All keys matched successfully>
tensor2str(encode_data)  # 验证数据
['466760',
 '826747',
 '324833',
 '507226',
 '435957',
 '766421',
 '611993',
 '190449',
 '334029',
 '301514']
tensor2str(model.inference(encode_data).T)  # 测试数据
stop decoding!



['466760013',
 '826747013',
 '324833013',
 '507226013',
 '435957013',
 '766421013',
 '611993013',
 '190449013',
 '334029013',
 '301514013']

Again, 13 is the index of the end-of-sequence token. This time every prediction matches the expected output: the input digits followed by '0' and the EOS token.
