Pytorch训练网络过程中loss突然变为0的解决方案
短信预约 -IT技能 免费直播动态提醒
问题
// loss 突然变成0
python train.py -b=8
INFO: Using device cpu
INFO: Network:
1 input channels
7 output channels (classes)
Bilinear upscaling
INFO: Creating dataset with 868 examples
INFO: Starting training:
Epochs: 5
Batch size: 8
Learning rate: 0.001
Training size: 782
Validation size: 86
Checkpoints: True
Device: cpu
Images scaling: 1
Epoch 1/5: 10%|██████████████▏ | 80/782 [01:33<13:21, 1.14s/img, loss (batch)=0.886I
NFO: Validation cross entropy: 1.86862473487854
Epoch 1/5: 20%|███████████████████████████▊ | 160/782 [03:34<11:51, 1.14s/img, loss (batch)=2.35e-7I
NFO: Validation cross entropy: 5.887489884504049e-10
Epoch 1/5: 31%|███████████████████████████████████████████▌ | 240/782 [05:41<11:29, 1.27s/img, loss (batch)=0I
NFO: Validation cross entropy: 0.0
Epoch 1/5: 41%|██████████████████████████████████████████████████████████ | 320/782 [07:49<09:16, 1.20s/img, loss (batch)=0I
NFO: Validation cross entropy: 0.0
Epoch 1/5: 51%|████████████████████████████████████████████████████████████████████████▋ | 400/782 [09:55<07:31, 1.18s/img, loss (batch)=0I
NFO: Validation cross entropy: 0.0
Epoch 1/5: 61%|███████████████████████████████████████████████████████████████████████████████████████▏ | 480/782 [12:02<05:58, 1.19s/img, loss (batch)=0I
NFO: Validation cross entropy: 0.0
Epoch 1/5: 72%|█████████████████████████████████████████████████████████████████████████████████████████████████████▋ | 560/782 [14:04<04:16, 1.15s/img, loss (batch)=0I
NFO: Validation cross entropy: 0.0
Epoch 1/5: 82%|████████████████████████████████████████████████████████████████████████████████████████████████████████████████████▏ | 640/782 [16:11<02:49, 1.20s/img, loss (batch)=0I
NFO: Validation cross entropy: 0.0
Epoch 1/5: 92%|██████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████▋ | 720/782 [18:21<01:18, 1.26s/img, loss (batch)=0I
NFO: Validation cross entropy: 0.0
Epoch 1/5: 94%|█████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████▋ | 736/782 [19:17<01:12, 1.57s/img, loss (batch)=0]
Traceback (most recent call last):
File "train.py", line 182, in <module>
val_percent=args.val / 100)
File "train.py", line 66, in train_net
for batch in train_loader:
File "/public/home/lidd/.conda/envs/lgg2/lib/python3.6/site-packages/torch/utils/data/dataloader.py", line 819, in __next__
return self._process_data(data)
File "/public/home/lidd/.conda/envs/lgg2/lib/python3.6/site-packages/torch/utils/data/dataloader.py", line 846, in _process_data
data.reraise()
File "/public/home/lidd/.conda/envs/lgg2/lib/python3.6/site-packages/torch/_utils.py", line 385, in reraise
raise self.exc_type(msg)
RuntimeError: Caught RuntimeError in DataLoader worker process 4.
Original Traceback (most recent call last):
File "/public/home/lidd/.conda/envs/lgg2/lib/python3.6/site-packages/torch/utils/data/_utils/worker.py", line 178, in _worker_loop
data = fetcher.fetch(index)
File "/public/home/lidd/.conda/envs/lgg2/lib/python3.6/site-packages/torch/utils/data/_utils/fetch.py", line 47, in fetch
return self.collate_fn(data)
File "/public/home/lidd/.conda/envs/lgg2/lib/python3.6/site-packages/torch/utils/data/_utils/collate.py", line 74, in default_collate
return {key: default_collate([d[key] for d in batch]) for key in elem}
File "/public/home/lidd/.conda/envs/lgg2/lib/python3.6/site-packages/torch/utils/data/_utils/collate.py", line 74, in <dictcomp>
return {key: default_collate([d[key] for d in batch]) for key in elem}
File "/public/home/lidd/.conda/envs/lgg2/lib/python3.6/site-packages/torch/utils/data/_utils/collate.py", line 55, in default_collate
return torch.stack(batch, 0, out=out)
RuntimeError: Expected object of scalar type Double but got scalar type Byte for sequence element 4 in sequence argument at position #1 'tensors'
交叉熵损失函数是衡量输出与标签之间的损失,通过求导确定梯度下降的方向。
loss突然变为0,有两种可能性。
一是因为预测输出为0,二是因为标签为0。
如果是因为标签为0,那么一开始loss就可能为0.
检查参数初始化
检查前向传播的网络
检查loss的计算格式
检查梯度下降
是否出现梯度消失。
实际上是标签出了错误
补充:pytorch训练出现loss=na
遇到一个很坑的情况,在pytorch训练过程中出现loss=nan的情况
有以下几种可能:
1.学习率太高。
2.loss函数有问题
3.对于回归问题,可能出现了除0 的计算,加一个很小的余项可能可以解决
4.数据本身,是否存在Nan、inf,可以用np.isnan(),np.isinf()检查一下input和target
5.target本身应该是能够被loss函数计算的,比如sigmoid激活函数的target应该大于0,同样的需要检查数据集
以上为个人经验,希望能给大家一个参考,也希望大家多多支持编程网。如有错误或未考虑完全的地方,望不吝赐教。
免责声明:
① 本站未注明“稿件来源”的信息均来自网络整理。其文字、图片和音视频稿件的所属权归原作者所有。本站收集整理出于非商业性的教育和科研之目的,并不意味着本站赞同其观点或证实其内容的真实性。仅作为临时的测试数据,供内部测试之用。本站并未授权任何人以任何方式主动获取本站任何信息。
② 本站未注明“稿件来源”的临时测试数据将在测试完成后最终做删除处理。有问题或投稿请发送至: 邮箱/279061341@qq.com QQ/279061341