0%

Python数据科学_34_案例:基于Unet网络的直肠癌肿瘤区域分割【计算机视觉】

读取数据

1
import SimpleITK as sitk
1
2
3
import matplotlib.pyplot as plt
import numpy as np
import cv2

读取CT影像照片DCM文件

1
dcm_path = 'data/10009.dcm'
1
2
3
img = sitk.ReadImage(dcm_path)  # 读取DCM文件
img = sitk.GetArrayFromImage(img) # 获取数组对象
img = img.squeeze() # 删除维度为1的维度
1
2
3
img[img < 0] = 0
img[img > 255] = 255
img = np.array(img, dtype=np.uint8)
1
2
3
4
# 绘制图片
plt.imshow(img, cmap='gray')
plt.axis('off')
plt.show()


output_20240607232201

读取CT影像照片对应的掩膜图片

1
mask_path = 'data/10009_mask.png'
1
mask_img = cv2.imread(mask_path)[:, :, 0]  # 因为该掩膜图片3层通道都是相等的,所以取一个通道进行研究即可
1
2
3
4
# 绘制图片
plt.imshow(mask_img, cmap='gray')
plt.axis('off')
plt.show()


output_20240607232202

在该任务中,我们将图像分割任务分为两个子任务。

  1. 首先去建立一个2分类模型,去判断每个CT影像图片中是否有肿瘤。
  2. 对有肿瘤的图片搭建图像分割模型,分割出肿瘤区域。

要实现这些任务,由于原始的数据量太大,并且图片像素很高,所以在进行建模前需要做大量的数据预处理工作。

  1. 要实现2分类任务,首先去要去构建标签,在构建标签前,还需要对原始的数据进行采样,保证正负样本的类别均衡。
  2. 正样本,在掩膜图片中存在肿瘤区域,也就是将掩膜图片中的像素进行求和后不为0;反之,则为负样本。
  3. 现有图片像素过大,但是经过观察发现,图片中的肿瘤区域都是集中在图像中部。

批量读取图片

读取掩膜图片

1
2
from glob import glob
from tqdm import tqdm
1
2
3
# 获取所有图片路径
mask_imgs_path = glob('data/dcmData1/*/*/*.png')
dcm_imgs_path = glob('data/dcmData1/*/*/*.dcm')
1
2
3
4
5
mask_imgs = []
for mask_path in tqdm(mask_imgs_path):
mask_img = cv2.imread(mask_path)[:, :, 0]
mask_imgs.append(mask_img)
mask_imgs = np.stack(mask_imgs, axis=0)
100%|█████████████████████████████████████████████████████████████████████████████| 6248/6248 [00:06<00:00, 908.96it/s]

切分出肿瘤区域

1
2
3
4
5
# 合并所有掩膜图片中的肿瘤区域
sum_img = np.zeros((512, 512))
for mask_img in mask_imgs:
sum_img += mask_img
sum_img[sum_img != 0] = 255
1
2
3
4
# 绘制图片
plt.imshow(sum_img, cmap='gray')
plt.axis('off')
plt.show()


output_20240607232203

1
2
3
4
# 获取肿瘤区域位置下标
row_min, row_max = np.arange(512)[sum_img.sum(axis=1) != 0][[0, -1]]
col_min, col_max = np.arange(512)[sum_img.sum(axis=0) != 0][[0, -1]]
print(f"肿瘤区域位置为:[{row_min}:{row_max}, {col_min}:{col_max}]")
肿瘤区域位置为:[230:430, 189:321]

为了后续模型搭建方便,在此将数据肿瘤区域切分位置定位为208x144大小,必须为16的倍数

区域固定为:[226:434, 183:327]

1
2
# 利用计算出来的肿瘤区域,将所有的掩膜图片中的肿瘤区域切分出来
mask_imgs = np.stack([mask_img[226:434, 183:327] for mask_img in mask_imgs], axis=0)
1
mask_imgs.shape
(6248, 208, 144)

构建正负样本

1
2
3
4
5
labels = []
for mask_img in mask_imgs:
label = 0 if mask_img.sum() == 0 else 1
labels.append(label)
labels = np.array(labels)
1
2
3
# 统计正负样本数量
from collections import Counter
Counter(labels)
Counter({0: 4453, 1: 1795})

可以看到,负样本占了绝大部分,正负样本分布不均衡。

1
2
3
4
5
# 正样本下标
index_1 = np.arange(len(labels))[labels == 1]
# 负样本下标
index_0 = np.random.choice(np.arange(len(labels))[labels == 0], size=len(index_1))
index = np.concatenate([index_0, index_1])

读取所有图片和标签

1
2
3
4
5
6
7
8
9
10
11
labels_sample = labels[index]  # 获取标签
masks_sample = mask_imgs[index]
imgs_sample = [] # 定义存储所有的CT图片
for i in tqdm(index):
dcm_path = dcm_imgs_path[i]
img = sitk.ReadImage(dcm_path) # 读取DCM文件
img_array = sitk.GetArrayFromImage(img) # 获取数组对象
img_array = img_array.squeeze() # 删除通道1维度
img_array = img_array[226:434, 183:327] # 切分出肿瘤区域
imgs_sample.append(img_array)
imgs_sample = np.stack(imgs_sample, axis=0)
100%|█████████████████████████████████████████████████████████████████████████████| 3590/3590 [00:25<00:00, 140.12it/s]
1
2
3
4
# 绘制图片
plt.imshow(imgs_sample[2000], cmap='gray')
plt.axis('off')
plt.show()


output_20240607232204

1
2
3
4
# 绘制图片
plt.imshow(masks_sample[2000], cmap='gray')
plt.axis('off')
plt.show()


output_20240607232205

1
labels_sample[2000]
1
1
np.savez('data/dcmData.npz', imgs_sample=imgs_sample, masks_sample=masks_sample, labels_sample=labels_sample)

切分数据集

1
import numpy as np
1
2
3
4
data = np.load('data/dcmData.npz')
imgs_sample = data['imgs_sample']
masks_sample = data['masks_sample']
labels_sample = data['labels_sample']
1
from sklearn.model_selection import train_test_split
1
train_imgs, test_imgs, train_masks, test_masks, train_labels, test_labels = train_test_split(imgs_sample, masks_sample, labels_sample, test_size=0.2, random_state=2024, shuffle=True)
1
2
print('Train:', train_imgs.shape)
print('Test:', test_imgs.shape)
Train: (2872, 208, 144)
Test: (718, 208, 144)

判断CT图片中是否有肿瘤

1
2
3
4
5
6
7
import torch
import numpy as np
import torch.nn as nn
import torch.nn.functional as F

from torchvision.transforms import v2 as transforms
from torch.utils.data import DataLoader, Dataset

自定义Dataset

1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
class CTDataset1(Dataset):
def __init__(self, imgs, labels, transforms=None):
super(CTDataset1, self).__init__()
self.imgs = np.expand_dims(imgs, -1) # 增加一个通道维度
self.labels = labels
self.transforms = transforms

def __len__(self):
return len(self.imgs)

def __getitem__(self, idx):
img = self.imgs[idx].copy()
img[img < 0] = 0
img[img > 255] = 255
img = np.array(img, dtype=np.uint8)
label = int(self.labels[idx])
if self.transforms:
img = self.transforms(img)
return img, label
1
2
3
4
5
6
7
8
9
10
11
train_dataset = CTDataset1(train_imgs, train_labels,
transforms=transforms.Compose([
transforms.ToImage(),
transforms.ToDtype(torch.float32, scale=True)
]))

test_dataset = CTDataset1(test_imgs, test_labels,
transforms=transforms.Compose([
transforms.ToImage(),
transforms.ToDtype(torch.float32, scale=True)
]))

定义模型

1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
# 定义模型类
class CNNModel(nn.Module):
def __init__(self, label_nums=2):
super(CNNModel, self).__init__()
self.conv1 = nn.Conv2d(in_channels=1,
out_channels=16,
kernel_size=(3, 3),
padding=1)
self.batch_norm1 = nn.BatchNorm2d(16)

self.pool = nn.MaxPool2d(kernel_size=2)

self.conv2 = nn.Conv2d(in_channels=16,
out_channels=32,
kernel_size=(3, 3),
padding=1)
self.batch_norm2 = nn.BatchNorm2d(32)

self.conv3 = nn.Conv2d(in_channels=32,
out_channels=8,
kernel_size=(3, 3),
padding=1)
self.batch_norm3 = nn.BatchNorm2d(8)

self.flatten = nn.Flatten()
self.relu = nn.ReLU()
self.fc1 = nn.Linear(3744, 512)
self.fc2 = nn.Linear(512, label_nums)

self.dropout = nn.Dropout(0.5)

def forward(self, x):
x = self.batch_norm1(self.conv1(x))
x = self.relu(self.pool(x))
x = self.batch_norm2(self.conv2(x))
x = self.relu(self.pool(x))
x = self.batch_norm3(self.conv3(x))
x = self.relu(self.pool(x))
x = self.flatten(x)
x = self.relu(self.fc1(x))
x = self.dropout(x)
x = self.fc2(x)
return x
1
2
3
4
from torchsummary.torchsummary import summary
cnn_model = CNNModel().to('cuda')

summary(cnn_model, input_size=(1, 208, 144))
----------------------------------------------------------------
        Layer (type)               Output Shape         Param #
================================================================
            Conv2d-1         [-1, 16, 208, 144]             160
       BatchNorm2d-2         [-1, 16, 208, 144]              32
         MaxPool2d-3          [-1, 16, 104, 72]               0
              ReLU-4          [-1, 16, 104, 72]               0
            Conv2d-5          [-1, 32, 104, 72]           4,640
       BatchNorm2d-6          [-1, 32, 104, 72]              64
         MaxPool2d-7           [-1, 32, 52, 36]               0
              ReLU-8           [-1, 32, 52, 36]               0
            Conv2d-9            [-1, 8, 52, 36]           2,312
      BatchNorm2d-10            [-1, 8, 52, 36]              16
        MaxPool2d-11            [-1, 8, 26, 18]               0
             ReLU-12            [-1, 8, 26, 18]               0
          Flatten-13                 [-1, 3744]               0
           Linear-14                  [-1, 512]       1,917,440
             ReLU-15                  [-1, 512]               0
          Dropout-16                  [-1, 512]               0
           Linear-17                    [-1, 2]           1,026
================================================================
Total params: 1,925,690
Trainable params: 1,925,690
Non-trainable params: 0
----------------------------------------------------------------
Input size (MB): 0.11
Forward/backward pass size (MB): 14.04
Params size (MB): 7.35
Estimated Total Size (MB): 21.50
----------------------------------------------------------------

定义损失函数

1
criterion = nn.CrossEntropyLoss()

定义优化器

1
optimizer = torch.optim.Adam(cnn_model.parameters())

模型训练

编写训练函数

1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
# model:模型    device:模型训练场所     optimizer:优化器    epoch:模型训练轮次
def train(model, device, train_loader, criterion, optimizer, epoch):
model.train() # 声明训练函数,参数的梯度要更新
total = 0 # 记录已经训练的数据个数
for batch_idx, (data, target) in enumerate(train_loader):
data, target = data.to(device), target.to(device)
optimizer.zero_grad()
output = model(data)
loss = criterion(output, target)
loss.backward()
optimizer.step()

total += len(data)
progress = math.ceil(batch_idx / len(train_loader) * 50)
print("\rTrain epoch %d: %d/%d, [%-51s] %d%%" %
(epoch, total, len(train_loader.dataset),
'-' * progress + '>', progress * 2), end='')

编写测试函数

1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
import math
def test(model, device, test_loader, criterion):
model.eval() # 声明验证函数,禁止所有梯度进行更新
test_loss = 0
correct = 0
# 强制后面的计算不生成计算图,加快测试效率
with torch.no_grad():
for data, target in test_loader:
data, target = data.to(device), target.to(device)
output = model(data)
test_loss += criterion(output, target).item() # 对每个batch的loss进行求和
pred = output.argmax(dim=1, keepdim=True)
correct += pred.eq(target.view_as(pred)).sum().item()
test_loss /= len(test_loader.dataset)

print('\nTest: average loss: {:.4f}, accuracy: {}/{} ({:.0f}%)'.format(
test_loss, correct, len(test_loader.dataset),
100. * correct / len(test_loader.dataset)))

训练主函数

1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
epochs = 10  # 迭代次数
batch_size = 256
torch.manual_seed(2024)

# 查看GPU是否可用,如果可用就用GPU否则用CPU
device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
# 训练集的定义
train_loader = DataLoader(train_dataset, batch_size=batch_size, shuffle=True)
# 测试集的定义
test_loader = DataLoader(test_dataset, batch_size=batch_size, shuffle=True)

for epoch in range(1, epochs+1):
train(cnn_model, device, train_loader, criterion, optimizer, epoch)
test(cnn_model, device, test_loader, criterion)
print('--------------------------')
Train epoch 1: 2872/2872, [---------------------------------------------->    ] 92%
Test: average loss: 0.0027, accuracy: 411/718 (57%)
--------------------------
Train epoch 2: 2872/2872, [---------------------------------------------->    ] 92%
Test: average loss: 0.0025, accuracy: 437/718 (61%)
--------------------------
Train epoch 3: 2872/2872, [---------------------------------------------->    ] 92%
Test: average loss: 0.0021, accuracy: 516/718 (72%)
--------------------------
Train epoch 4: 2872/2872, [---------------------------------------------->    ] 92%
Test: average loss: 0.0016, accuracy: 574/718 (80%)
--------------------------
Train epoch 5: 2872/2872, [---------------------------------------------->    ] 92%
Test: average loss: 0.0013, accuracy: 622/718 (87%)
--------------------------
Train epoch 6: 2872/2872, [---------------------------------------------->    ] 92%
Test: average loss: 0.0011, accuracy: 647/718 (90%)
--------------------------
Train epoch 7: 2872/2872, [---------------------------------------------->    ] 92%
Test: average loss: 0.0009, accuracy: 653/718 (91%)
--------------------------
Train epoch 8: 2872/2872, [---------------------------------------------->    ] 92%
Test: average loss: 0.0009, accuracy: 652/718 (91%)
--------------------------
Train epoch 9: 2872/2872, [---------------------------------------------->    ] 92%
Test: average loss: 0.0009, accuracy: 658/718 (92%)
--------------------------
Train epoch 10: 2872/2872, [---------------------------------------------->    ] 92%
Test: average loss: 0.0008, accuracy: 661/718 (92%)
--------------------------

模型验证

1
2
3
4
5
6
pred_index = 37
pred = cnn_model(test_dataset[pred_index][0].unsqueeze(0).to('cuda')).argmax()
plt.imshow(test_masks[pred_index])
plt.title(f'Pred: {pred.item()}')
plt.axis('off')
plt.show()


output_20240607232206

1
torch.save(cnn_model.state_dict(), 'cnn_model.pth')

在判断有肿瘤区域的前提下,定位肿瘤区域的具体位置

1
2
3
4
5
import numpy as np
data = np.load('data/dcmData.npz')
imgs_sample = data['imgs_sample'] # CT图片数据
masks_sample = data['masks_sample'] # CT图片对应的掩膜图片数据
labels_sample = data['labels_sample'] # CT图片中是否有肿瘤的标签数据
1
2
3
# 筛选出所有包含肿瘤的数据
imgs_sample = imgs_sample[labels_sample == 1]
masks_sample = masks_sample[labels_sample == 1]
1
2
3
4
# 切分数据集
from sklearn.model_selection import train_test_split

train_imgs, test_imgs, train_masks, test_masks = train_test_split(imgs_sample, masks_sample, test_size=0.2, random_state=2024, shuffle=True)
1
2
print('Train:', train_imgs.shape)
print('Test:', test_imgs.shape)
Train: (1436, 208, 144)
Test: (359, 208, 144)

自定义Dataset

1
2
3
4
5
6
7
import torch
import numpy as np
import torch.nn as nn
import torch.nn.functional as F

from torchvision.transforms import v2 as transforms
from torch.utils.data import DataLoader, Dataset
D:\ProgramSoftware\Python\Miniconda3\envs\pytorch\lib\site-packages\transformers\utils\generic.py:441: UserWarning: torch.utils._pytree._register_pytree_node is deprecated. Please use torch.utils._pytree.register_pytree_node instead.
  _torch_pytree._register_pytree_node(
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
class CTDataset2(Dataset):
def __init__(self, imgs, masks, transforms=None):
super(CTDataset2, self).__init__()
self.imgs = np.expand_dims(imgs, -1) # 增加一个通道维度
self.masks = masks
self.transforms = transforms

def __len__(self):
return len(self.imgs)

def __getitem__(self, idx):
img = self.imgs[idx].copy()
img[img < 0] = 0
img[img > 255] = 255
img = np.array(img, dtype=np.uint8)
if self.transforms:
img = self.transforms(img)

mask = self.masks[idx].copy()
mask[mask == 255] = 1
mask = torch.LongTensor(mask)
return img, mask
1
2
3
4
5
6
7
8
9
10
11
train_dataset2 = CTDataset2(train_imgs, train_masks,
transforms=transforms.Compose([
transforms.ToImage(),
transforms.ToDtype(torch.float32, scale=True)
]))

test_dataset2 = CTDataset2(test_imgs, test_masks,
transforms=transforms.Compose([
transforms.ToImage(),
transforms.ToDtype(torch.float32, scale=True)
]))

定义模型

Unet

1
2
3
4
5
6
7
8
9
10
11
12
13
14
# 定义双卷积结构
class DoubleConv(nn.Module):
def __init__(self, in_ch, out_ch):
super(DoubleConv, self).__init__()
self.conv = nn.Sequential(
nn.Conv2d(in_ch, out_ch, kernel_size=3, stride=1, padding=1, bias=True),
nn.BatchNorm2d(out_ch),
nn.ReLU(inplace=True),
nn.Conv2d(out_ch, out_ch, kernel_size=3, stride=1, padding=1, bias=True),
nn.BatchNorm2d(out_ch),
nn.ReLU(inplace=True))

def forward(self, x):
return self.conv(x)
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
# 定义Unet网络
class Unet(nn.Module):
def __init__(self, in_ch, out_ch):
super(Unet, self).__init__()
self.conv1 = DoubleConv(in_ch, 64)
self.max_pool = nn.MaxPool2d(2)
self.conv2 = DoubleConv(64, 128)
self.conv3 = DoubleConv(128, 256)
self.conv4 = DoubleConv(256, 512)
self.conv5 = DoubleConv(512, 1024)

self.up6 = nn.ConvTranspose2d(1024, 512, 2, stride=2)
self.conv6 = DoubleConv(1024, 512)
self.up7 = nn.ConvTranspose2d(512, 256, 2, stride=2)
self.conv7 = DoubleConv(512, 256)
self.up8 = nn.ConvTranspose2d(256, 128, 2, stride=2)
self.conv8 = DoubleConv(256, 128)
self.up9 = nn.ConvTranspose2d(128, 64, 2, stride=2)
self.conv9 = DoubleConv(128, 64)
self.conv10 = nn.Conv2d(64, out_ch, 1)

def forward(self, x): # [bs, 1, 208, 144]
c1=self.conv1(x) # [bs, 64, 208, 144]
p1=self.max_pool(c1) # [bs, 64, 104, 72]
c2=self.conv2(p1) # [bs, 128, 104, 72]
p2=self.max_pool(c2) # [bs, 128, 52, 36]
c3=self.conv3(p2) # [bs, 256, 52, 36]
p3=self.max_pool(c3) # [bs, 256, 26, 18]
c4=self.conv4(p3) # [bs, 512, 26, 18]
p4=self.max_pool(c4) # [bs, 512, 13, 9]

c5=self.conv5(p4) # [bs, 1024, 13, 9]

up_6= self.up6(c5) # [bs, 512, 26, 18]
merge6 = torch.cat([up_6, c4], dim=1) # [bs, 1024, 26, 18]
c6=self.conv6(merge6) # [bs, 512, 26, 18]

up_7=self.up7(c6) # [bs, 256, 52, 36]
merge7 = torch.cat([up_7, c3], dim=1) # [bs, 512, 52, 36]
c7=self.conv7(merge7) # [bs, 256, 52, 36]

up_8=self.up8(c7) # [bs, 128, 104, 72]
merge8 = torch.cat([up_8, c2], dim=1) # [bs, 256, 104, 72]
c8=self.conv8(merge8) # [bs, 128, 104, 72]

up_9=self.up9(c8) # [bs, 64, 208, 144]
merge9=torch.cat([up_9,c1],dim=1) # [bs, 128, 208, 144]
c9=self.conv9(merge9) # [bs, 64, 208, 144]

c10=F.sigmoid(self.conv10(c9)) # [bs, 2, 208, 144]
return c10
1
2
3
from torchsummary.torchsummary import summary
unet = Unet(1, 2).to('cuda')
summary(unet, input_size=(1, 208, 144))
----------------------------------------------------------------
        Layer (type)               Output Shape         Param #
================================================================
            Conv2d-1         [-1, 64, 208, 144]             640
       BatchNorm2d-2         [-1, 64, 208, 144]             128
              ReLU-3         [-1, 64, 208, 144]               0
            Conv2d-4         [-1, 64, 208, 144]          36,928
       BatchNorm2d-5         [-1, 64, 208, 144]             128
              ReLU-6         [-1, 64, 208, 144]               0
        DoubleConv-7         [-1, 64, 208, 144]               0
         MaxPool2d-8          [-1, 64, 104, 72]               0
            Conv2d-9         [-1, 128, 104, 72]          73,856
      BatchNorm2d-10         [-1, 128, 104, 72]             256
             ReLU-11         [-1, 128, 104, 72]               0
           Conv2d-12         [-1, 128, 104, 72]         147,584
      BatchNorm2d-13         [-1, 128, 104, 72]             256
             ReLU-14         [-1, 128, 104, 72]               0
       DoubleConv-15         [-1, 128, 104, 72]               0
        MaxPool2d-16          [-1, 128, 52, 36]               0
           Conv2d-17          [-1, 256, 52, 36]         295,168
      BatchNorm2d-18          [-1, 256, 52, 36]             512
             ReLU-19          [-1, 256, 52, 36]               0
           Conv2d-20          [-1, 256, 52, 36]         590,080
      BatchNorm2d-21          [-1, 256, 52, 36]             512
             ReLU-22          [-1, 256, 52, 36]               0
       DoubleConv-23          [-1, 256, 52, 36]               0
        MaxPool2d-24          [-1, 256, 26, 18]               0
           Conv2d-25          [-1, 512, 26, 18]       1,180,160
      BatchNorm2d-26          [-1, 512, 26, 18]           1,024
             ReLU-27          [-1, 512, 26, 18]               0
           Conv2d-28          [-1, 512, 26, 18]       2,359,808
      BatchNorm2d-29          [-1, 512, 26, 18]           1,024
             ReLU-30          [-1, 512, 26, 18]               0
       DoubleConv-31          [-1, 512, 26, 18]               0
        MaxPool2d-32           [-1, 512, 13, 9]               0
           Conv2d-33          [-1, 1024, 13, 9]       4,719,616
      BatchNorm2d-34          [-1, 1024, 13, 9]           2,048
             ReLU-35          [-1, 1024, 13, 9]               0
           Conv2d-36          [-1, 1024, 13, 9]       9,438,208
      BatchNorm2d-37          [-1, 1024, 13, 9]           2,048
             ReLU-38          [-1, 1024, 13, 9]               0
       DoubleConv-39          [-1, 1024, 13, 9]               0
  ConvTranspose2d-40          [-1, 512, 26, 18]       2,097,664
           Conv2d-41          [-1, 512, 26, 18]       4,719,104
      BatchNorm2d-42          [-1, 512, 26, 18]           1,024
             ReLU-43          [-1, 512, 26, 18]               0
           Conv2d-44          [-1, 512, 26, 18]       2,359,808
      BatchNorm2d-45          [-1, 512, 26, 18]           1,024
             ReLU-46          [-1, 512, 26, 18]               0
       DoubleConv-47          [-1, 512, 26, 18]               0
  ConvTranspose2d-48          [-1, 256, 52, 36]         524,544
           Conv2d-49          [-1, 256, 52, 36]       1,179,904
      BatchNorm2d-50          [-1, 256, 52, 36]             512
             ReLU-51          [-1, 256, 52, 36]               0
           Conv2d-52          [-1, 256, 52, 36]         590,080
      BatchNorm2d-53          [-1, 256, 52, 36]             512
             ReLU-54          [-1, 256, 52, 36]               0
       DoubleConv-55          [-1, 256, 52, 36]               0
  ConvTranspose2d-56         [-1, 128, 104, 72]         131,200
           Conv2d-57         [-1, 128, 104, 72]         295,040
      BatchNorm2d-58         [-1, 128, 104, 72]             256
             ReLU-59         [-1, 128, 104, 72]               0
           Conv2d-60         [-1, 128, 104, 72]         147,584
      BatchNorm2d-61         [-1, 128, 104, 72]             256
             ReLU-62         [-1, 128, 104, 72]               0
       DoubleConv-63         [-1, 128, 104, 72]               0
  ConvTranspose2d-64         [-1, 64, 208, 144]          32,832
           Conv2d-65         [-1, 64, 208, 144]          73,792
      BatchNorm2d-66         [-1, 64, 208, 144]             128
             ReLU-67         [-1, 64, 208, 144]               0
           Conv2d-68         [-1, 64, 208, 144]          36,928
      BatchNorm2d-69         [-1, 64, 208, 144]             128
             ReLU-70         [-1, 64, 208, 144]               0
       DoubleConv-71         [-1, 64, 208, 144]               0
           Conv2d-72          [-1, 2, 208, 144]             130
================================================================
Total params: 31,042,434
Trainable params: 31,042,434
Non-trainable params: 0
----------------------------------------------------------------
Input size (MB): 0.11
Forward/backward pass size (MB): 425.04
Params size (MB): 118.42
Estimated Total Size (MB): 543.57
----------------------------------------------------------------

定义评价函数

1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
def dice_value(pred, target, smooth=1.0):  
"""
计算Dice。
Args:
pred (torch.Tensor): 预测结果,形状为[B, C, H, W],其中C是类别数。
target (torch.Tensor): 真实标签,形状与pred相同。
smooth (float, optional): 平滑因子,用于防止除以零。默认为1.0。
Returns:
torch.Tensor: Dice值。
"""
# 如果存在空标签,则添加平滑项
iflat = (pred * target).sum(dim=[2, 3]) + smooth
tflat = target.sum(dim=[2, 3]) + smooth
pflat = pred.sum(dim=[2, 3]) + smooth

dice = 2. * iflat / (tflat + pflat)
return dice

定义损失函数

1
criterion = nn.CrossEntropyLoss()

定义优化器

1
optimizer = torch.optim.Adam(unet.parameters())

模型训练

编写训练函数

1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
# model:模型    device:模型训练场所     optimizer:优化器    epoch:模型训练轮次
def train(model, device, train_loader, criterion, optimizer, epoch):
model.train() # 声明训练函数,参数的梯度要更新
total = 0 # 记录已经训练的数据个数
for batch_idx, (data, target) in enumerate(train_loader):
data, target = data.to(device), target.to(device)
optimizer.zero_grad()
output = model(data)
loss = criterion(output, target)
loss.backward()
optimizer.step()

total += len(data)
progress = math.ceil(batch_idx / len(train_loader) * 50)
print("\rTrain epoch %d: %d/%d, [%-51s] %d%%" %
(epoch, total, len(train_loader.dataset),
'-' * progress + '>', progress * 2), end='')

编写测试函数

1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
import math
def test(model, device, test_loader, criterion):
model.eval() # 声明验证函数,禁止所有梯度进行更新
test_loss = 0
dice = []
# 强制后面的计算不生成计算图,加快测试效率
with torch.no_grad():
for data, target in test_loader:
data, target = data.to(device), target.to(device)
output = model(data)
test_loss += criterion(output, target).item() # 对每个batch的loss进行求和
pred = output.argmax(dim=1, keepdim=True)
dice.append(dice_value(pred, target.unsqueeze(dim=1)).mean().to('cpu'))
dice = np.mean(dice)
print('\nTest: average loss: {:.4f}, dice: {}'.format(test_loss, dice))
return pred, target

训练主函数

1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
epochs = 50  # 迭代次数
batch_size = 16
torch.manual_seed(2024)

# 查看GPU是否可用,如果可用就用GPU否则用CPU
device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
# 训练集的定义
train_loader = DataLoader(train_dataset2, batch_size=batch_size, shuffle=True)
# 测试集的定义
test_loader = DataLoader(test_dataset2, batch_size=batch_size, shuffle=True)

for epoch in range(1, epochs+1):
train(unet, device, train_loader, criterion, optimizer, epoch)
test(unet, device, test_loader, criterion)
print('--------------------------')
Train epoch 1: 1436/1436, [-------------------------------------------------->] 100%

D:\ProgramSoftware\Python\Miniconda3\envs\pytorch\lib\site-packages\torch\nn\modules\conv.py:456: UserWarning: Plan failed with a cudnnException: CUDNN_BACKEND_EXECUTION_PLAN_DESCRIPTOR: cudnnFinalize Descriptor Failed cudnn_status: CUDNN_STATUS_NOT_SUPPORTED (Triggered internally at C:\actions-runner\_work\pytorch\pytorch\builder\windows\pytorch\aten\src\ATen\native\cudnn\Conv_v8.cpp:919.)
  return F.conv2d(input, weight, bias, self.stride,


​ Test: average loss: 8.5509, dice: 0.7674808502197266
​ —————————————
​ Train epoch 2: 1436/1436, [—————————————————————————>] 100%
​ Test: average loss: 7.9442, dice: 0.7710758447647095
​ —————————————
​ Train epoch 3: 1436/1436, [—————————————————————————>] 100%
​ Test: average loss: 7.6687, dice: 0.8228566646575928
​ —————————————
​ Train epoch 4: 1436/1436, [—————————————————————————>] 100%
​ Test: average loss: 7.5969, dice: 0.8312968611717224
​ —————————————
​ Train epoch 5: 1436/1436, [—————————————————————————>] 100%
​ Test: average loss: 7.5754, dice: 0.8371047377586365
​ —————————————
​ Train epoch 6: 1436/1436, [—————————————————————————>] 100%
​ Test: average loss: 7.5523, dice: 0.7972604632377625
​ —————————————
​ Train epoch 7: 1436/1436, [—————————————————————————>] 100%
​ Test: average loss: 7.5325, dice: 0.8414972424507141
​ —————————————
​ Train epoch 8: 1436/1436, [—————————————————————————>] 100%
​ Test: average loss: 7.5442, dice: 0.809937596321106
​ —————————————
​ Train epoch 9: 1436/1436, [—————————————————————————>] 100%
​ Test: average loss: 7.5278, dice: 0.8313433527946472
​ —————————————
​ Train epoch 10: 1436/1436, [—————————————————————————>] 100%
​ Test: average loss: 7.5080, dice: 0.8392102718353271
​ —————————————
​ Train epoch 11: 1436/1436, [—————————————————————————>] 100%
​ Test: average loss: 7.6097, dice: 0.8199900984764099
​ —————————————
​ Train epoch 12: 1436/1436, [—————————————————————————>] 100%
​ Test: average loss: 7.4869, dice: 0.8532517552375793
​ —————————————
​ Train epoch 13: 1436/1436, [—————————————————————————>] 100%
​ Test: average loss: 7.4850, dice: 0.8600601553916931
​ —————————————
​ Train epoch 14: 1436/1436, [—————————————————————————>] 100%
​ Test: average loss: 7.4724, dice: 0.8592535257339478
​ —————————————
​ Train epoch 15: 1436/1436, [—————————————————————————>] 100%
​ Test: average loss: 7.4830, dice: 0.8472432494163513
​ —————————————
​ Train epoch 16: 1436/1436, [—————————————————————————>] 100%
​ Test: average loss: 7.4782, dice: 0.8639267683029175
​ —————————————
​ Train epoch 17: 1436/1436, [—————————————————————————>] 100%
​ Test: average loss: 7.4710, dice: 0.8511594533920288
​ —————————————
​ Train epoch 18: 1436/1436, [—————————————————————————>] 100%
​ Test: average loss: 7.4698, dice: 0.866731584072113
​ —————————————
​ Train epoch 19: 1436/1436, [—————————————————————————>] 100%
​ Test: average loss: 7.4685, dice: 0.8629027605056763
​ —————————————
​ Train epoch 20: 1436/1436, [—————————————————————————>] 100%
​ Test: average loss: 7.4636, dice: 0.8664995431900024
​ —————————————
​ Train epoch 21: 1436/1436, [—————————————————————————>] 100%
​ Test: average loss: 7.4510, dice: 0.8632227182388306
​ —————————————
​ Train epoch 22: 1436/1436, [—————————————————————————>] 100%
​ Test: average loss: 7.4550, dice: 0.8614387512207031
​ —————————————
​ Train epoch 23: 1436/1436, [—————————————————————————>] 100%
​ Test: average loss: 7.4517, dice: 0.8725950121879578
​ —————————————
​ Train epoch 24: 1436/1436, [—————————————————————————>] 100%
​ Test: average loss: 7.4581, dice: 0.8588936924934387
​ —————————————
​ Train epoch 25: 1436/1436, [—————————————————————————>] 100%
​ Test: average loss: 7.4520, dice: 0.8660483360290527
​ —————————————
​ Train epoch 26: 1436/1436, [—————————————————————————>] 100%
​ Test: average loss: 7.4538, dice: 0.8603050112724304
​ —————————————
​ Train epoch 27: 1436/1436, [—————————————————————————>] 100%
​ Test: average loss: 7.4770, dice: 0.8656834959983826
​ —————————————
​ Train epoch 28: 1436/1436, [—————————————————————————>] 100%
​ Test: average loss: 7.4509, dice: 0.8729587197303772
​ —————————————
​ Train epoch 29: 1436/1436, [—————————————————————————>] 100%
​ Test: average loss: 7.4684, dice: 0.8699115514755249
​ —————————————
​ Train epoch 30: 1436/1436, [—————————————————————————>] 100%
​ Test: average loss: 7.4480, dice: 0.869256317615509
​ —————————————
​ Train epoch 31: 1436/1436, [—————————————————————————>] 100%
​ Test: average loss: 7.4362, dice: 0.8815281987190247
​ —————————————
​ Train epoch 32: 1436/1436, [—————————————————————————>] 100%
​ Test: average loss: 7.4504, dice: 0.8704572916030884
​ —————————————
​ Train epoch 33: 1436/1436, [—————————————————————————>] 100%
​ Test: average loss: 7.4587, dice: 0.862378716468811
​ —————————————
​ Train epoch 34: 1436/1436, [—————————————————————————>] 100%
​ Test: average loss: 7.4359, dice: 0.8830628395080566
​ —————————————
​ Train epoch 35: 1436/1436, [—————————————————————————>] 100%
​ Test: average loss: 7.4416, dice: 0.8735705614089966
​ —————————————
​ Train epoch 36: 1436/1436, [—————————————————————————>] 100%
​ Test: average loss: 7.4302, dice: 0.8837142586708069
​ —————————————
​ Train epoch 37: 1436/1436, [—————————————————————————>] 100%
​ Test: average loss: 7.4434, dice: 0.8692076206207275
​ —————————————
​ Train epoch 38: 1436/1436, [—————————————————————————>] 100%
​ Test: average loss: 7.4257, dice: 0.8872324228286743
​ —————————————
​ Train epoch 39: 1436/1436, [—————————————————————————>] 100%
​ Test: average loss: 7.4548, dice: 0.8668719530105591
​ —————————————
​ Train epoch 40: 1436/1436, [—————————————————————————>] 100%
​ Test: average loss: 7.4417, dice: 0.8776969909667969
​ —————————————
​ Train epoch 41: 1436/1436, [—————————————————————————>] 100%
​ Test: average loss: 7.4431, dice: 0.8810428380966187
​ —————————————
​ Train epoch 42: 1436/1436, [—————————————————————————>] 100%
​ Test: average loss: 7.4483, dice: 0.8780376315116882
​ —————————————
​ Train epoch 43: 1436/1436, [—————————————————————————>] 100%
​ Test: average loss: 7.4395, dice: 0.8790196776390076
​ —————————————
​ Train epoch 44: 1436/1436, [—————————————————————————>] 100%
​ Test: average loss: 7.4278, dice: 0.8867204785346985
​ —————————————
​ Train epoch 45: 1436/1436, [—————————————————————————>] 100%
​ Test: average loss: 7.4349, dice: 0.880047082901001
​ —————————————
​ Train epoch 46: 1436/1436, [—————————————————————————>] 100%
​ Test: average loss: 7.4268, dice: 0.8837651610374451
​ —————————————
​ Train epoch 47: 1436/1436, [—————————————————————————>] 100%
​ Test: average loss: 7.4484, dice: 0.8826776742935181
​ —————————————
​ Train epoch 48: 1436/1436, [—————————————————————————>] 100%
​ Test: average loss: 7.4324, dice: 0.8834880590438843
​ —————————————
​ Train epoch 49: 1436/1436, [—————————————————————————>] 100%
​ Test: average loss: 7.4481, dice: 0.8696610927581787
​ —————————————
​ Train epoch 50: 1436/1436, [—————————————————————————>] 100%
​ Test: average loss: 7.4274, dice: 0.8883213400840759
​ —————————————

1
torch.save(unet.state_dict(), 'unet_model.pth')

模型测试

1
2
3
4
index = 31
ct_img = test_dataset2[index][0].unsqueeze(0).to('cuda')
ct_mask = test_dataset2[index][1]
pred = unet(ct_img).argmax(dim=1)
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
import matplotlib.pyplot as plt

plt.subplot(131)
plt.imshow(ct_img[0][0].to('cpu'), cmap='gray')
plt.title('original')
plt.axis('off')

plt.subplot(132)
plt.imshow(pred[0].to('cpu'), cmap='gray')
plt.title('Pred')
plt.axis('off')

plt.subplot(133)
plt.imshow(ct_mask.to('cpu'), cmap='gray')
plt.title('True')
plt.axis('off')

plt.show()

output_20240607232207

-------------本文结束感谢您的阅读-------------