PyTorch 中的 model.train() 起什么作用?
- 2025-03-18 08:55:00
- admin 原创
- 44
问题描述:
它是否调用forward()
?nn.Module
我认为当我们调用模型时,forward
方法正在被使用。为什么我们需要指定train()?
解决方案 1:
model.train()
告诉您的模型您正在训练模型。这有助于通知 Dropout 和 BatchNorm 等层,这些层在训练和评估期间的行为有所不同。例如,在训练模式下,BatchNorm 会在每个新批次上更新移动平均值;而在评估模式下,这些更新会被冻结。
更多详细信息:model.train()
将模式设置为训练(参见源代码)。您可以调用model.eval()
或model.train(mode=False)
来表明您正在测试。期望train
函数训练模型有点直观,但它并没有这样做。它只是设置模式。
解决方案 2:
代码如下nn.Module.train()
:
def train(self, mode=True):
r"""Sets the module in training mode."""
self.training = mode
for module in self.children():
module.train(mode)
return self
代码如下nn.Module.eval()
:
def eval(self):
r"""Sets the module in evaluation mode."""
return self.train(False)
默认情况下,该self.training
标志设置为True
,即模块默认处于训练模式。当 时self.training
,False
模块处于相反状态,即评估模式。
对最常用的层来说,只Dropout
关心BatchNorm
那个标志。
解决方案 3:
model.train() | model.eval() |
---|---|
将模型设置为训练模式,即 •BatchNorm 层使用每批次统计数据•Dropout 层已激活等 | 将模型设置为评估(推理)模式,即•BatchNorm 各层使用运行统计数据•Dropout 各层停用等 |
相当于model.train(False) 。 |
注意:这两个函数调用都不会运行前向/后向传递。它们会告诉模型运行时如何操作。
这很重要,因为某些模块(层)(例如Dropout
)BatchNorm
在训练和推理期间的行为设计不同,因此如果在错误的模式下运行,模型将产生意外的结果。
解决方案 4:
有两种方法可以让模型知道你的意图,即你是想训练模型还是想用模型来评估。如果model.train()
模型知道它必须学习层,当我们使用model.eval()
它时,它表示模型不需要学习任何新东西,并且该模型用于测试。model.eval()
这也是必要的,因为在 pytorch 中如果我们使用 batchnorm,并且在测试期间如果我们只想传递单个图像,如果model.eval()
没有指定,pytorch 会抛出一个错误。
解决方案 5:
考虑以下模型
import torch
import torch.nn.functional as F
from torch_geometric.nn import GCNConv
class GraphNet(torch.nn.Module):
def __init__(self, num_node_features, num_classes):
super(GraphNet, self).__init__()
self.conv1 = GCNConv(num_node_features, 16)
self.conv2 = GCNConv(16, num_classes)
def forward(self, data):
x, edge_index = data.x, data.edge_index
x = self.conv1(x, edge_index)
x = F.dropout(x, training=self.training) #Look here
x = self.conv2(x, edge_index)
return F.log_softmax(x, dim=1)
这里, 的功能dropout
在不同的操作模式下有所不同。如您所见,它仅在 时起作用self.training==True
。因此,当您输入 时model.train()
,模型的 forward 函数将执行 dropout,否则不会执行(例如当model.eval()
或 时model.train(mode=False)
)。
解决方案 6:
目前的官方文档说明如下:
这仅对某些模块有影响。如果受影响,请参阅特定模块的文档以了解其在训练/评估模式下的行为详情,例如 Dropout、BatchNorm 等。
扫码咨询,免费领取项目管理大礼包!