logo
0
0
WeChat Login
代码优化以及readme增强

DRIVE-Unet

这是一个基于PyTorch实现的U-Net医学图像分割项目,专门用于DRIVE(Digital Retinal Images for Vessel Extraction)数据集的血管分割任务。

项目简介

本项目实现了U-Net架构,用于对眼底图像进行血管分割。U-Net是一种经典的卷积神经网络架构,最初设计用于生物医学图像分割任务。本项目包含数据加载、数据增强、模型定义、损失函数、训练和测试等完整流程。

数据增强流程

数据增强是提高模型泛化能力和防止过拟合的重要步骤。本项目的数据增强流程如下:

1. 数据加载

  • 从DRIVE数据集中加载训练图像和对应的血管标注掩码
  • 训练集:20张图像及其对应的手工标注
  • 测试集:20张图像及其对应的手工标注

2. 增强操作

对训练数据应用以下增强技术:

  • 水平翻转:以100%概率进行水平镜像翻转
  • 垂直翻转:以100%概率进行垂直镜像翻转
  • 旋转:在±45°范围内随机旋转图像和掩码

3. 数据处理流程

原始图像 → 水平翻转 → 垂直翻转 → 旋转增强 ↓ ↓ ↓ ↓ 4张图像 4张图像 4张图像 4张图像

4. 输出结果

  • 每张原始图像生成4个增强样本(包括原始图像本身)
  • 所有图像统一调整为512×512像素大小
  • 增强后的数据保存到new_data/目录下
  • 训练数据和测试数据分别存储在对应子目录中

5. 目录结构

new_data/ ├── train/ │ ├── image/ # 增强后的训练图像 │ └── mask/ # 对应的血管掩码 └── test/ ├── image/ # 测试图像(未增强) └── mask/ # 测试掩码

训练完整流程

1. 环境准备

  • 确保安装所有依赖项(参考requirements.txt)
  • 准备DRIVE数据集并完成数据增强预处理

2. 模型配置

超参数设置: - 图像尺寸:512×512 - 批次大小:2 - 学习率:1e-4 - 训练轮数:50 epochs - 损失函数:DiceBCELoss - 优化器:Adam - 学习率调度器:ReduceLROnPlateau(耐心期5)

3. 训练执行

运行命令:

python train.py

4. 训练过程详解

  • 数据加载:使用DriveDataset加载增强后的训练和验证数据
  • 模型初始化:构建U-Net模型并移至GPU(如果可用)
  • 训练循环
    • 每个epoch遍历完整训练集
    • 前向传播计算损失
    • 反向传播更新模型参数
    • 验证集评估模型性能
  • 模型保存:自动保存验证损失最低的模型权重
  • 学习率调整:根据验证损失动态调整学习率

5. 训练监控

  • 实时显示每个epoch的训练损失和验证损失
  • 显示当前学习率
  • 记录每个epoch的训练时间
  • 保存最佳模型到files/checkpoint.pth

测试完整流程

1. 模型准备

  • 确保已训练好的模型权重文件存在(files/checkpoint.pth
  • 准备测试数据集

2. 测试执行

运行命令:

python test.py

3. 测试过程详解

  • 模型加载:从检查点文件加载训练好的模型权重
  • 数据处理
    • 遍历测试数据集中的每张图像
    • 图像预处理:归一化、维度调整
    • 掩码加载和预处理
  • 模型推理
    • 对每张图像进行前向传播预测
    • 应用Sigmoid激活函数
    • 二值化处理(阈值0.5)
  • 结果生成
    • 将原始图像、真实掩码、预测结果横向拼接
    • 保存对比图像到results/目录
    • 计算各项评估指标

4. 输出结果

  • 可视化结果:在results/目录下生成对比图像
  • 性能指标
    • Jaccard Score (IoU):交并比
    • F1 Score:F1得分
    • Recall:召回率
    • Precision:精确率
    • Accuracy:准确率
  • 性能统计:平均处理速度(FPS)

5. 结果文件命名

测试结果按以下格式保存:

results/ ├── 01_test_0.png # 第1张测试图像的结果 ├── 02_test_0.png # 第2张测试图像的结果 └── ... # 其他测试图像结果

文件结构

  • data.py: 定义了DRIVE数据集的PyTorch数据加载器
  • data_aug.py: 包含数据加载和增强功能
  • loss.py: 定义了Dice损失和Dice BCE损失函数
  • model.py: 实现了U-Net模型架构
  • train.py: 训练脚本
  • test.py: 测试脚本,包含评估指标
  • utils.py: 包含工具函数,如随机种子设置、目录创建等

模型架构

U-Net包含以下组件:

  • 编码器块(Encoder blocks)
  • 解码器块(Decoder blocks)
  • 跳跃连接(Skip connections)
  • 瓶颈层(Bottleneck layer)

损失函数

项目使用了两种损失函数:

  • Dice Loss
  • Dice BCE Loss(Dice损失和二元交叉熵损失的组合)

使用方法

  1. 准备DRIVE数据集
  2. 运行data_aug.py进行数据预处理和增强
  3. 运行train.py开始训练模型
  4. 训练完成后,运行test.py评估模型性能

评估指标

测试脚本计算以下评估指标:

  • Jaccard Score (IoU)
  • F1 Score
  • Recall
  • Precision
  • Accuracy

依赖项

项目依赖项详见requirements.txt文件。

许可证

请根据您的需要添加许可证信息。