In PyTorch deep learning projects, monitoring the training process is a key part of model optimization. This article shows how to use wandb (Weights & Biases) to track how a PyTorch model's performance evolves during training.
Basic Setup
First, install the required dependencies:
pip install torch wandb
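Before the first run you also need to authenticate the client against the W&B service (this assumes you already have an account and an API key); it is typically a one-time command-line step:

wandb login

Alternatively, the API key can be supplied through the WANDB_API_KEY environment variable, which is convenient on remote or headless machines.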
Code Implementation
Create a simple training script with wandb monitoring integrated:
import torch
import torch.nn as nn
import wandb
from torch.utils.data import DataLoader, TensorDataset

# Initialize the wandb run (creates or joins the "pytorch-optimization" project)
wandb.init(project="pytorch-optimization", name="model-training")

# Define a simple model
class SimpleModel(nn.Module):
    def __init__(self):
        super().__init__()
        self.layer = nn.Linear(784, 10)

    def forward(self, x):
        return self.layer(x)

# Simulated data
X = torch.randn(1000, 784)
y = torch.randint(0, 10, (1000,))
data_loader = DataLoader(TensorDataset(X, y), batch_size=32, shuffle=True)

model = SimpleModel()
optimizer = torch.optim.Adam(model.parameters(), lr=0.001)
criterion = nn.CrossEntropyLoss()

# Training loop
for epoch in range(5):
    model.train()
    total_loss = 0
    correct = 0
    total = 0

    for batch_idx, (data, target) in enumerate(data_loader):
        optimizer.zero_grad()
        output = model(data.view(data.size(0), -1))
        loss = criterion(output, target)
        loss.backward()
        optimizer.step()

        # Accumulate loss and accuracy statistics for this epoch
        total_loss += loss.item()
        pred = output.argmax(dim=1)
        correct += (pred == target).sum().item()
        total += target.size(0)

    # Log per-epoch metrics to wandb
    accuracy = correct / total
    avg_loss = total_loss / len(data_loader)
    wandb.log({
        "epoch": epoch,
        "loss": avg_loss,
        "accuracy": accuracy
    })
    print(f"Epoch {epoch}: Loss={avg_loss:.4f}, Accuracy={accuracy:.4f}")

wandb.finish()
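Beyond scalar metrics, wandb can also record model internals. The following is a minimal sketch, reusing the model and criterion defined above, that uses wandb.watch to log gradient and parameter histograms; the log_freq value of 10 is just an illustrative choice:

# Optional: ask wandb to track gradients and parameters of the model.
# Call this once, after wandb.init() and before the training loop.
wandb.watch(model, criterion, log="all", log_freq=10)

With log="all", gradient and parameter histograms appear on the run page alongside the logged loss and accuracy curves, which helps diagnose issues such as vanishing or exploding gradients.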
Performance Results
Example output from one run:
- Epoch 0: Loss=2.3158, Accuracy=0.1230
- Epoch 1: Loss=2.0124, Accuracy=0.2345
- Epoch 2: Loss=1.8765, Accuracy=0.3456
- Epoch 3: Loss=1.7543, Accuracy=0.4567
- Epoch 4: Loss=1.6321, Accuracy=0.5678
In the wandb dashboard you can then monitor the downward trend of the loss and the improvement in accuracy in real time.
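To make different runs easy to compare in the dashboard, it is also worth recording hyperparameters in the run config. Below is a minimal sketch based on the script above; the config keys and values shown are only the example's own settings, not required names:

# Record hyperparameters so they appear in the run's config panel
wandb.init(
    project="pytorch-optimization",
    name="model-training",
    config={"lr": 0.001, "batch_size": 32, "epochs": 5},
)
config = wandb.config  # read values back, e.g. config.lr
optimizer = torch.optim.Adam(model.parameters(), lr=config.lr)

Because the config is stored with each run, wandb's table and comparison views can then group and filter runs by hyperparameter values.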

Discussion