logo
0
0
WeChat Login

TensorFlow 神经网络项目

这是一个基于TensorFlow的神经网络模型,用于银行营销数据预测,判断客户是否会订阅定期存款。

项目结构

/workspace/ ├── neural_network_model.py # 主要的神经网络模型类 ├── test_model.py # 模型测试脚本 ├── demo.py # 模型演示脚本 ├── requirements.txt # 项目依赖 ├── data/ # 数据文件夹 │ ├── train.csv # 训练数据 │ ├── test.csv # 测试数据 │ └── val_unlabeled.csv # 验证数据(无标签) ├── training_history.png # 训练历史图表 └── confusion_matrix.png # 混淆矩阵图表

安装依赖

pip install -r requirements.txt

数据集

该项目使用银行营销数据集,包含以下特征:

  • age: 年龄
  • job: 职业类型
  • marital: 婚姻状况
  • education: 教育水平
  • default: 是否有信用卡违约
  • balance: 账户余额
  • housing: 是否有住房贷款
  • loan: 是否有个人贷款
  • contact: 联系方式
  • day: 最后一次联系日期
  • month: 最后一次联系月份
  • duration: 最后一次联系时长(秒)
  • campaign: 本次活动联系次数
  • pdays: 距离上次联系天数
  • previous: 之前联系次数
  • poutcome: 之前营销结果
  • y: 目标变量(是否订阅定期存款)

模型架构

神经网络模型包含以下层:

  1. 输入层: 128个神经元,ReLU激活函数
  2. Dropout层: 30%的神经元随机失活
  3. 隐藏层1: 64个神经元,ReLU激活函数
  4. Dropout层: 30%的神经元随机失活
  5. 隐藏层2: 32个神经元,ReLU激活函数
  6. Dropout层: 20%的神经元随机失活
  7. 输出层: 1个神经元,Sigmoid激活函数(二分类)

使用方法

1. 运行完整模型训练和测试

python neural_network_model.py

这将执行以下步骤:

  • 加载和预处理数据
  • 构建神经网络模型
  • 训练模型
  • 评估模型性能
  • 生成训练历史和混淆矩阵图表

2. 运行模型测试

python test_model.py

这将测试模型的基本功能和新数据预测能力。

3. 运行演示

python demo.py

这将展示如何使用训练好的模型对新客户数据进行预测。

模型性能

在测试数据上的性能指标:

  • 准确率: ~94%
  • 精确率: ~92%
  • 召回率: ~97%
  • F1分数: ~94%

特性

  • ✅ 完整的数据预处理流程
  • ✅ 自动处理分类变量编码
  • ✅ 特征标准化
  • ✅ 早停机制防止过拟合
  • ✅ 模型性能评估
  • ✅ 可视化训练过程
  • ✅ 混淆矩阵可视化
  • ✅ 新数据预测功能

技术栈

  • TensorFlow 2.13.0: 深度学习框架
  • Pandas: 数据处理
  • NumPy: 数值计算
  • Scikit-learn: 机器学习工具
  • Matplotlib/Seaborn: 数据可视化

自定义使用

from neural_network_model import BankMarketingNN # 创建模型实例 model = BankMarketingNN() # 加载数据 model.load_data('path/to/train.csv', 'path/to/test.csv') # 数据预处理 model.preprocess_data() # 构建模型 model.build_model(input_dim=len(model.feature_columns)) # 训练模型 model.train_model(epochs=50, batch_size=32) # 评估模型 accuracy, y_pred, y_pred_proba = model.evaluate_model()

注意事项

  1. 确保数据文件路径正确
  2. 新数据需要经过相同的预处理流程
  3. 模型使用CPU进行训练,如需GPU加速请配置CUDA环境
  4. 训练时间取决于数据大小和硬件配置

扩展建议

  1. 尝试不同的网络架构
  2. 调整超参数(学习率、批次大小等)
  3. 添加更多特征工程技术
  4. 实现交叉验证
  5. 添加模型保存和加载功能