我的编程空间,编程开发者的网络收藏夹
学习永远不晚

改进YOLOv5:添加EMA注意力机制

短信预约 -IT技能 免费直播动态提醒
省份

北京

  • 北京
  • 上海
  • 天津
  • 重庆
  • 河北
  • 山东
  • 辽宁
  • 黑龙江
  • 吉林
  • 甘肃
  • 青海
  • 河南
  • 江苏
  • 湖北
  • 湖南
  • 江西
  • 浙江
  • 广东
  • 云南
  • 福建
  • 海南
  • 山西
  • 四川
  • 陕西
  • 贵州
  • 安徽
  • 广西
  • 内蒙
  • 西藏
  • 新疆
  • 宁夏
  • 兵团
手机号立即预约

请填写图片验证码后获取短信验证码

看不清楚,换张图片

免费获取短信验证码

改进YOLOv5:添加EMA注意力机制

提示:文章写完后,目录可以自动生成,如何生成可参考右边的帮助文档

文章目录


前言

本文主要介绍一种在YOLOv5-7.0中添加EMA注意力机制的方法。EMA注意力机制原论文地址,有关EMA注意力机制的解读可参考文章

新建EMA.py文件

在yolov5的models文件中新建一个名为EMA.py文件,将下述代码复制到EMA.py文件中并保存。

import torchfrom torch import nnclass EMA(nn.Module):    def __init__(self, channels, factor=8):        super(EMA, self).__init__()        self.groups = factor        assert channels // self.groups > 0        self.softmax = nn.Softmax(-1)        self.agp = nn.AdaptiveAvgPool2d((1, 1))        self.pool_h = nn.AdaptiveAvgPool2d((None, 1))        self.pool_w = nn.AdaptiveAvgPool2d((1, None))        self.gn = nn.GroupNorm(channels // self.groups, channels // self.groups)        self.conv1x1 = nn.Conv2d(channels // self.groups, channels // self.groups, kernel_size=1, stride=1, padding=0)        self.conv3x3 = nn.Conv2d(channels // self.groups, channels // self.groups, kernel_size=3, stride=1, padding=1)    def forward(self, x):        b, c, h, w = x.size()        group_x = x.reshape(b * self.groups, -1, h, w)  # b*g,c//g,h,w        x_h = self.pool_h(group_x)        x_w = self.pool_w(group_x).permute(0, 1, 3, 2)        hw = self.conv1x1(torch.cat([x_h, x_w], dim=2))        x_h, x_w = torch.split(hw, [h, w], dim=2)        x1 = self.gn(group_x * x_h.sigmoid() * x_w.permute(0, 1, 3, 2).sigmoid())        x2 = self.conv3x3(group_x)        x11 = self.softmax(self.agp(x1).reshape(b * self.groups, -1, 1).permute(0, 2, 1))        x12 = x2.reshape(b * self.groups, c // self.groups, -1)  # b*g, c//g, hw        x21 = self.softmax(self.agp(x2).reshape(b * self.groups, -1, 1).permute(0, 2, 1))        x22 = x1.reshape(b * self.groups, c // self.groups, -1)  # b*g, c//g, hw        weights = (torch.matmul(x11, x12) + torch.matmul(x21, x22)).reshape(b * self.groups, 1, h, w)        return (group_x * weights.sigmoid()).reshape(b, c, h, w)

修改yolo.py文件

1.导入EMA.py

在yolo.py文件开头导入EMA.py,代码如下:

from models.EMA import EMA

代码放在yolo.py位置如下图所示:
在这里插入图片描述

2.修改parse_model

这里主要是添加通道参数,再添加一个elif,把EMA添加进去,代码如下:

 elif m is EMA:               args = [ch[f], *args]

添加上述代码的位置可参考下图:
在这里插入图片描述


修改yaml文件(yolov5s为例)

# YOLOv5 🚀 by Ultralytics, GPL-3.0 license# Parametersnc: 80  # number of classesdepth_multiple: 0.33  # model depth multiplewidth_multiple: 0.50  # layer channel multipleanchors:  - [10,13, 16,30, 33,23]  # P3/8  - [30,61, 62,45, 59,119]  # P4/16  - [116,90, 156,198, 373,326]  # P5/32# YOLOv5 v6.0 backbonebackbone:  # [from, number, module, args]  [[-1, 1, Conv, [64, 6, 2, 2]],  # 0-P1/2   [-1, 1, Conv, [128, 3, 2]],  # 1-P2/4   [-1, 3, C3, [128]],   [-1, 1, Conv, [256, 3, 2]],  # 3-P3/8   [-1, 6, C3, [256]],   [-1, 1, Conv, [512, 3, 2]],  # 5-P4/16   [-1, 9, C3, [512]],   [-1, 1, Conv, [1024, 3, 2]],  # 7-P5/32   [-1, 3, C3, [1024]],   [-1, 1, EMA, [8]],   [-1, 1, SPPF, [1024, 5]],  # 9  ]# YOLOv5 v6.0 headhead:  [[-1, 1, Conv, [512, 1, 1]],   [-1, 1, nn.Upsample, [None, 2, 'nearest']],   [[-1, 6], 1, Concat, [1]],  # cat backbone P4   [-1, 3, C3, [512, False]],  # 13   [-1, 1, Conv, [256, 1, 1]],   [-1, 1, nn.Upsample, [None, 2, 'nearest']],   [[-1, 4], 1, Concat, [1]],  # cat backbone P3   [-1, 3, C3, [256, False]],  # 17 (P3/8-small)   [-1, 1, Conv, [256, 3, 2]],   [[-1, 15], 1, Concat, [1]],  # cat head P4   [-1, 3, C3, [512, False]],  # 20 (P4/16-medium)   [-1, 1, Conv, [512, 3, 2]],   [[-1, 11], 1, Concat, [1]],  # cat head P5   [-1, 3, C3, [1024, False]],  # 23 (P5/32-large)   [[18, 21, 24], 1, Detect, [nc, anchors]],  # Detect(P3, P4, P5)  ]

上述代码将EMA注意力机制模块加在backbone层中最后C3模块后面,SPPF模块前面,仅供参考,具体添加位置要根据个人数据集的不同合理的添加。

[-1, 1, EMA, [8]], #-1代表连接上一层通道数,1是个数,8是EMA所需的参数(factor=8)

说明:因为在yolo.py文件parse_model函数中修改了通道参数,因此在yaml文件中无需添加通道参数,只需添加EMA函数所需的其他参数。在backbone中添加一层注意力机制模块,因此后续的层数都要加一,在head层中做如下改动。

[[-1, 15], 1, Concat, [1]],  #未改动前的第14层,在经过上述改动后改为15[[-1, 11], 1, Concat, [1]],  #未改动前的第10层,在记过上述改动后改为11[[18, 21, 24], 1, Detect, [nc, anchors]],  #17,20,23层改为18,21,24

运行train.py文件可以在输出终端窗口看到上图网络结构,可以看到在第9层已经成功添加EMA注意力机制模块。

                from  n    params  module      arguments                       0                -1  1      3520  models.common.Conv                      [3, 32, 6, 2, 2]                1                -1  1     18560  models.common.Conv                      [32, 64, 3, 2]                  2                -1  1     18816  models.common.C3                        [64, 64, 1]                     3                -1  1     73984  models.common.Conv                      [64, 128, 3, 2]                 4                -1  2    115712  models.common.C3                        [128, 128, 2]                   5                -1  1    295424  models.common.Conv                      [128, 256, 3, 2]                6                -1  3    625152  models.common.C3                        [256, 256, 3]                   7                -1  1   1180672  models.common.Conv                      [256, 512, 3, 2]                8                -1  1   1182720  models.common.C3                        [512, 512, 1]                   9                -1  1     41216  models.EMA.EMA                          [512, 8]                       10                -1  1    656896  models.common.SPPF                      [512, 512, 5]                  11                -1  1    131584  models.common.Conv                      [512, 256, 1, 1]               12                -1  1         0  torch.nn.modules.upsampling.Upsample    [None, 2, 'nearest']           13           [-1, 6]  1         0  models.common.Concat                    [1]14                -1  1    361984  models.common.C3                        [512, 256, 1, False]           15                -1  1     33024  models.common.Conv                      [256, 128, 1, 1]               16                -1  1         0  torch.nn.modules.upsampling.Upsample    [None, 2, 'nearest']           17           [-1, 4]  1         0  models.common.Concat                    [1]18                -1  1     90880  models.common.C3                        [256, 128, 1, False]           19                -1  1    147712  models.common.Conv                      [128, 128, 3, 2]               20          [-1, 15]  1         0  models.common.Concat                    [1]21                -1  1    296448  models.common.C3                        [256, 256, 1, False]           22                -1  1    590336  models.common.Conv                      [256, 256, 3, 2]               23          [-1, 11]  1         0  models.common.Concat                    [1]24                -1  1   1182720  models.common.C3                        [512, 512, 1, False]           25      [18, 21, 24]  1     16182  models.yolo.Detect                      [1, [[10, 13, 16, 30, 33, 23], [30, 61, 62, 45, 59, 119], [116, 90, 156, 198, 373, 326]], [128, 256, 512]]YOLOv5sEMA summary: 222 layers, 7063542 parameters, 7063542 gradients, 16.2 GFLOPs

参考

https://www.bilibili.com/video/BV1s84y1775U/?spm_id_from=333.788&vd_source=f83457e2adc10b543ae4c742fba1e3b2
https://blog.csdn.net/hhhhhhhhhhwwwwwwwwww/article/details/131347981

来源地址:https://blog.csdn.net/qq_43615485/article/details/131470922

免责声明:

① 本站未注明“稿件来源”的信息均来自网络整理。其文字、图片和音视频稿件的所属权归原作者所有。本站收集整理出于非商业性的教育和科研之目的,并不意味着本站赞同其观点或证实其内容的真实性。仅作为临时的测试数据,供内部测试之用。本站并未授权任何人以任何方式主动获取本站任何信息。

② 本站未注明“稿件来源”的临时测试数据将在测试完成后最终做删除处理。有问题或投稿请发送至: 邮箱/279061341@qq.com QQ/279061341

改进YOLOv5:添加EMA注意力机制

下载Word文档到电脑,方便收藏和打印~

下载Word文档

编程热搜

  • Python 学习之路 - Python
    一、安装Python34Windows在Python官网(https://www.python.org/downloads/)下载安装包并安装。Python的默认安装路径是:C:\Python34配置环境变量:【右键计算机】--》【属性】-
    Python 学习之路 - Python
  • chatgpt的中文全称是什么
    chatgpt的中文全称是生成型预训练变换模型。ChatGPT是什么ChatGPT是美国人工智能研究实验室OpenAI开发的一种全新聊天机器人模型,它能够通过学习和理解人类的语言来进行对话,还能根据聊天的上下文进行互动,并协助人类完成一系列
    chatgpt的中文全称是什么
  • C/C++中extern函数使用详解
  • C/C++可变参数的使用
    可变参数的使用方法远远不止以下几种,不过在C,C++中使用可变参数时要小心,在使用printf()等函数时传入的参数个数一定不能比前面的格式化字符串中的’%’符号个数少,否则会产生访问越界,运气不好的话还会导致程序崩溃
    C/C++可变参数的使用
  • css样式文件该放在哪里
  • php中数组下标必须是连续的吗
  • Python 3 教程
    Python 3 教程 Python 的 3.0 版本,常被称为 Python 3000,或简称 Py3k。相对于 Python 的早期版本,这是一个较大的升级。为了不带入过多的累赘,Python 3.0 在设计的时候没有考虑向下兼容。 Python
    Python 3 教程
  • Python pip包管理
    一、前言    在Python中, 安装第三方模块是通过 setuptools 这个工具完成的。 Python有两个封装了 setuptools的包管理工具: easy_install  和  pip , 目前官方推荐使用 pip。    
    Python pip包管理
  • ubuntu如何重新编译内核
  • 改善Java代码之慎用java动态编译

目录