Python数据科学_35_案例：基于Unet网络的直肠癌肿瘤区域分割——推理过程【计算机视觉】

读取数据

import SimpleITK as sitk
import matplotlib.pyplot as plt
import numpy as np
import cv2

import torch
import torch.nn as nn
import torch.nn.functional as F

from torchvision.transforms import v2 as transforms
from torch.utils.data import DataLoader, Dataset
# Path to the DICOM slice to run inference on
dcm_path = 'data/10009.dcm'

# Read the DICOM file and convert it to a NumPy array; squeeze() drops
# singleton axes (e.g. a leading slice axis of length 1).
# NOTE: the result must be bound to `img_array` -- every later step
# (plotting, ROI cropping, mask overlay) reads `img_array`.
img_array = sitk.GetArrayFromImage(sitk.ReadImage(dcm_path)).squeeze()

# Clamp pixel values to [0, 255] and convert to 8-bit so the slice can be
# displayed and later scaled to [0, 1] by ToDtype(scale=True).
img_array = np.clip(img_array, 0, 255).astype(np.uint8)

# Plot the full slice
plt.imshow(img_array, cmap='gray')
plt.axis('off')
plt.show()

![原始CT切片](output_20240607232601)

区域裁剪（截取感兴趣区域）

# Crop the region of interest: rows 226..433 and columns 183..326,
# giving a 208 x 144 patch (the size the models below are built for).
roi_rows, roi_cols = slice(226, 434), slice(183, 327)
img = img_array[roi_rows, roi_cols].copy()

# Append a trailing channel axis -> H x W x 1, the layout ToImage expects
# for NumPy input.
img = img[..., np.newaxis]

# ToImage: HWC array -> CHW image tensor.
# ToDtype(float32, scale=True): rescales integer input into [0, 1]
# (assumes img_array is uint8, as produced by the loading step).
transform_img = transforms.Compose([
    transforms.ToImage(),
    transforms.ToDtype(torch.float32, scale=True),
])
img = transform_img(img)

# Plot the cropped patch (its single channel)
plt.imshow(img[0], cmap='gray')
plt.axis('off')
plt.show()

![裁剪后的感兴趣区域](output_20240607232602)

直肠癌肿瘤检测

# Model class for the tumor-detection (classification) step
class CNNModel(nn.Module):
    """Small three-stage CNN classifier for single-channel image patches.

    Each stage is: 3x3 conv -> BatchNorm -> 2x2 max-pool -> ReLU.
    ``fc1`` expects 3744 = 8 * 26 * 18 flattened features, which is what a
    1 x 208 x 144 input produces after three halvings of each spatial dim.

    Args:
        label_nums: number of output logits (classes).
    """

    def __init__(self, label_nums=2):
        super().__init__()
        # Attribute names and creation order are kept unchanged so that saved
        # state_dicts load and seeded initialisation stays identical.
        self.conv1 = nn.Conv2d(in_channels=1, out_channels=16,
                               kernel_size=(3, 3), padding=1)
        self.batch_norm1 = nn.BatchNorm2d(16)

        self.pool = nn.MaxPool2d(kernel_size=2)  # shared by all three stages

        self.conv2 = nn.Conv2d(in_channels=16, out_channels=32,
                               kernel_size=(3, 3), padding=1)
        self.batch_norm2 = nn.BatchNorm2d(32)

        self.conv3 = nn.Conv2d(in_channels=32, out_channels=8,
                               kernel_size=(3, 3), padding=1)
        self.batch_norm3 = nn.BatchNorm2d(8)

        self.flatten = nn.Flatten()
        self.relu = nn.ReLU()
        self.fc1 = nn.Linear(3744, 512)
        self.fc2 = nn.Linear(512, label_nums)

        self.dropout = nn.Dropout(0.5)

    def forward(self, x):
        """Return class logits of shape (batch, label_nums)."""
        stages = (
            (self.conv1, self.batch_norm1),
            (self.conv2, self.batch_norm2),
            (self.conv3, self.batch_norm3),
        )
        for conv, norm in stages:
            x = self.relu(self.pool(norm(conv(x))))

        hidden = self.relu(self.fc1(self.flatten(x)))
        return self.fc2(self.dropout(hidden))
# Load the trained classifier weights (falls back to CPU when CUDA is absent)
device = 'cuda' if torch.cuda.is_available() else 'cpu'
cnn_model = CNNModel().to(device)
# weights_only=True: the checkpoint is a plain state_dict, so refuse to
# unpickle arbitrary objects; map_location lets a CUDA-saved file load on CPU.
cnn_weight = torch.load('cnn_model.pth', map_location=device, weights_only=True)
cnn_model.load_state_dict(cnn_weight)  # notebook output: <All keys matched successfully>
# Inference mode: BatchNorm must use its running statistics and Dropout must
# be disabled; otherwise predictions depend on the batch and are random.
cnn_model.eval()

# Predict the class of the cropped patch (batch of one)
with torch.no_grad():
    cnn_pred = cnn_model(img.unsqueeze(0).to(device)).argmax()
print(cnn_pred)  # notebook output: tensor(1, device='cuda:0')

模型预测的类别为 1，说明该CT图像中存在肿瘤。

直肠癌肿瘤分割

# Double-convolution building block used by the U-Net
class DoubleConv(nn.Module):
    """Two stacked (3x3 conv -> BatchNorm -> ReLU) stages.

    Stride 1 with padding 1 keeps the spatial size; channels go
    ``in_ch -> out_ch -> out_ch``.
    """

    def __init__(self, in_ch, out_ch):
        super().__init__()
        # Build the layer list in the same order as before so the Sequential
        # indices (conv.0, conv.1, conv.3, conv.4) match saved checkpoints.
        layers = []
        for src_ch in (in_ch, out_ch):
            layers += [
                nn.Conv2d(src_ch, out_ch, kernel_size=3, stride=1, padding=1, bias=True),
                nn.BatchNorm2d(out_ch),
                nn.ReLU(inplace=True),
            ]
        self.conv = nn.Sequential(*layers)

    def forward(self, x):
        return self.conv(x)
# U-Net segmentation network
class Unet(nn.Module):
    """U-Net with four down-sampling and four up-sampling stages.

    Args:
        in_ch: number of input channels.
        out_ch: number of output channels; a sigmoid is applied to each.

    Input ``(bs, in_ch, H, W)`` -> output ``(bs, out_ch, H, W)`` with values
    in [0, 1]. H and W must each be at least 16. When H or W is not a
    multiple of 16, each up-sampled map is zero-padded to the size of its
    skip connection before concatenation; otherwise ``torch.cat`` would fail
    on mismatched sizes. For sizes that are multiples of 16 the padding is a
    no-op.
    """

    def __init__(self, in_ch, out_ch):
        super().__init__()
        # Attribute names/order unchanged so saved state_dicts still load.
        self.conv1 = DoubleConv(in_ch, 64)
        self.max_pool = nn.MaxPool2d(2)
        self.conv2 = DoubleConv(64, 128)
        self.conv3 = DoubleConv(128, 256)
        self.conv4 = DoubleConv(256, 512)
        self.conv5 = DoubleConv(512, 1024)

        self.up6 = nn.ConvTranspose2d(1024, 512, 2, stride=2)
        self.conv6 = DoubleConv(1024, 512)
        self.up7 = nn.ConvTranspose2d(512, 256, 2, stride=2)
        self.conv7 = DoubleConv(512, 256)
        self.up8 = nn.ConvTranspose2d(256, 128, 2, stride=2)
        self.conv8 = DoubleConv(256, 128)
        self.up9 = nn.ConvTranspose2d(128, 64, 2, stride=2)
        self.conv9 = DoubleConv(128, 64)
        self.conv10 = nn.Conv2d(64, out_ch, 1)

    @staticmethod
    def _merge(up, skip):
        """Zero-pad ``up`` to ``skip``'s spatial size, then concat [up, skip] on channels."""
        # up is 2 * floor(skip / 2) per dim, so these differences are 0 or 1.
        dh = skip.size(-2) - up.size(-2)
        dw = skip.size(-1) - up.size(-1)
        if dh or dw:
            # F.pad order: (left, right) for the last dim, then (top, bottom).
            up = F.pad(up, [dw // 2, dw - dw // 2, dh // 2, dh - dh // 2])
        return torch.cat([up, skip], dim=1)

    def forward(self, x):
        # Shape comments are for a [bs, 1, 208, 144] input.
        c1 = self.conv1(x)                        # [bs, 64, 208, 144]
        c2 = self.conv2(self.max_pool(c1))        # [bs, 128, 104, 72]
        c3 = self.conv3(self.max_pool(c2))        # [bs, 256, 52, 36]
        c4 = self.conv4(self.max_pool(c3))        # [bs, 512, 26, 18]
        c5 = self.conv5(self.max_pool(c4))        # [bs, 1024, 13, 9]

        c6 = self.conv6(self._merge(self.up6(c5), c4))  # [bs, 512, 26, 18]
        c7 = self.conv7(self._merge(self.up7(c6), c3))  # [bs, 256, 52, 36]
        c8 = self.conv8(self._merge(self.up8(c7), c2))  # [bs, 128, 104, 72]
        c9 = self.conv9(self._merge(self.up9(c8), c1))  # [bs, 64, 208, 144]

        return torch.sigmoid(self.conv10(c9))           # [bs, out_ch, 208, 144]
# Load the trained U-Net weights (falls back to CPU when CUDA is absent)
device = 'cuda' if torch.cuda.is_available() else 'cpu'
unet = Unet(1, 2).to(device)
# weights_only=True: the checkpoint is a plain state_dict, so refuse to
# unpickle arbitrary objects; map_location lets a CUDA-saved file load on CPU.
unet_weight = torch.load('unet_model.pth', map_location=device, weights_only=True)
unet.load_state_dict(unet_weight)  # notebook output: <All keys matched successfully>
# Inference mode: the BatchNorm layers must use their running statistics.
unet.eval()

# Segment the cropped patch: per-pixel argmax over the 2 output channels
with torch.no_grad():
    pred_masks = unet(img.unsqueeze(0).to(device)).argmax(dim=1)

# Plot the predicted mask
plt.imshow(pred_masks[0].to('cpu'), cmap='gray')
plt.axis('off')
plt.show()

![预测的分割掩码](output_20240607232603)

尺寸还原

# Paste the predicted ROI mask back into a canvas the size of the full slice.
# The slice bounds must match the ROI crop taken in the region-cropping step.
org_mask = np.zeros(img_array.shape)
org_mask[226:434, 183:327] = pred_masks[0].cpu().numpy()

# Plot the original slice and the restored mask side by side
plt.figure(figsize=(10, 5), dpi=200)
plt.subplot(121)
plt.imshow(img_array, cmap='gray')
plt.axis('off')

plt.subplot(122)
plt.imshow(org_mask, cmap='gray')
plt.axis('off')
plt.show()


![原图与还原后的分割掩码对比](output_20240607232604)

-------------本文结束，感谢您的阅读-------------