深度学习与PyTorch入门实战 - 32. Train-Val-Test-交叉验证
1.10 MB
13 页
0 评论
语言 | 格式 | 评分 |
---|---|---|
中文(简体) | .pdf | 3 |
摘要 | ||
文档介绍了在深度学习中使用Train-Val-Test数据集划分和K折交叉验证技术的方法,用于减轻Overfitting问题。通过PyTorch框架,展示了如何实现数据加载和模型训练,并提供了相应的代码示例。文档还讨论了训练误差、测试误差以及模型的合适性评估。 | ||
AI总结 | ||
以下是对文档内容的总结:
本文主要介绍了深度学习中_train-val-test_划分和_k-fold cross validation_的概念及其在PyTorch中的实现方法,旨在帮助减轻模型的过拟合问题并提高模型的泛化能力。
### 1. **Train-Val-Test划分**
- **目的**:通过划分训练集(train set)、验证集(val set)和测试集(test set),避免过拟合并评估模型性能。
- **实现**:
- **划分方法**:将数据集随机划分为训练集、验证集和测试集。
- **PyTorch实现**:使用`torch.utils.data.random_split`对数据集进行划分,并通过`DataLoader`创建训练集和验证集的数据加载器。
```python
train_db, val_db = torch.utils.data.random_split(train_db, [train_size, val_size])
train_loader = DataLoader(train_db, batch_size=batch_size, shuffle=True)
val_loader = DataLoader(val_db, batch_size=batch_size, shuffle=False)
```
- **结果输出**:打印训练集和测试集的长度。
```python
print('train:', len(train_db), 'test:', len(test_db))
```
### 2. **K-Fold Cross Validation**
- **概念**:将训练集分为_k_个子集,其中每次使用1个子集作为验证集,其余作为训练集,可以减少验证结果的偏差。
- **实现**:
- 将训练集和验证集合并,然后随机采样1/k作为验证集。
- 通过多次划分和训练模型,取平均结果,提高模型的泛化能力。
### 3. **训练流程**
- **训练循环**:定义训练误差和测试误差,在训练过程中打印训练误差。
```python
for epoch in range(10):
# 训练
for images, labels in train_loader:
optimizer.zero_grad()
outputs = model(images.view(-1, 28*28))
loss = criterion(outputs, labels)
loss.backward()
optimizer.step()
# 验证
model.eval()
test_loss = 0
correct = 0
with torch.no_grad():
for images, labels in test_loader:
outputs = model(images.view(-1, 28*28))
loss = criterion(outputs, labels)
test_loss += loss.item()
_, predicted = torch.max(outputs, 1)
correct += (predicted == labels).sum().item()
accuracy = correct / len(test_loader.dataset)
print('Epoch {}: Test Loss {:.4f}, Accuracy {:.2f}%'.format(epoch+1, test_loss / len(test_loader.dataset), 100. * accuracy))
```
### 4. **数据预处理**
- 使用`transforms.Compose`对数据进行预处理,例如归一化。
```python
transforms.Compose([
transforms.ToTensor(),
transforms.Normalize((channel_mean), (channel_std))
])
```
### 核心观点
- 合理划分训练集、验证集和测试集可以帮助评估模型性能并防止过拟合。
- K折交叉验证是一种有效的方法,可以减少验证结果的波动并提高模型的稳健性。
- PyTorch提供了灵活的工具和接口,使得数据的划分和加载变得简单高效。 |
P1
P2
P3
P4
P5
P6
P7
下载文档到本地,方便使用
- 可预览页数已用完,剩余
6 页请下载阅读 -
文档评分