A Simple Way to Add Attention Mechanisms to a ResNet Backbone in MMDetection
```python
# -*- encoding: utf-8 -*-
'''
@File    : resnet_with_attention.py
@Time    : 2023/03/25 08:55:30
@Author  : RainfyLee
@Version : 1.0
@Contact : 379814385@qq.com
'''
# here put the import lib
import torch
from mmdet.models.backbones import ResNet
from fightingcv_attention.attention.CoordAttention import CoordAtt
from fightingcv_attention.attention.SEAttention import SEAttention
from mmdet.models.builder import BACKBONES


# Base class: ResNet with attention applied to the output feature maps
class ResNetWithAttention(ResNet):

    def __init__(self, **kwargs):
        super(ResNetWithAttention, self).__init__(**kwargs)
        # ResNet outputs four feature maps; for now, attention is added
        # to the last three of them. The channel dims depend on depth.
        if self.depth in (18, 34):
            self.dims = (64, 128, 256, 512)
        elif self.depth in (50, 101, 152):
            self.dims = (256, 512, 1024, 2048)
        else:
            raise ValueError(f'unsupported depth: {self.depth}')
        self.attention1 = self.get_attention_module(self.dims[1])
        self.attention2 = self.get_attention_module(self.dims[2])
        self.attention3 = self.get_attention_module(self.dims[3])

    # Subclasses only need to implement this method
    def get_attention_module(self, dim):
        raise NotImplementedError()

    def forward(self, x):
        # Assumes out_indices=(0, 1, 2, 3), the MMDetection default,
        # i.e. four output feature maps.
        outs = super().forward(x)
        outs = list(outs)
        outs[1] = self.attention1(outs[1])
        outs[2] = self.attention2(outs[2])
        outs[3] = self.attention3(outs[3])
        return tuple(outs)


@BACKBONES.register_module()
class ResNetWithCoordAttention(ResNetWithAttention):

    def __init__(self, **kwargs):
        super(ResNetWithCoordAttention, self).__init__(**kwargs)

    def get_attention_module(self, dim):
        return CoordAtt(inp=dim, oup=dim, reduction=32)


@BACKBONES.register_module()
class ResNetWithSEAttention(ResNetWithAttention):

    def __init__(self, **kwargs):
        super(ResNetWithSEAttention, self).__init__(**kwargs)

    def get_attention_module(self, dim):
        return SEAttention(channel=dim, reduction=16)


if __name__ == "__main__":
    # model = ResNet(depth=18)
    # model = ResNet(depth=34)
    # model = ResNet(depth=50)
    # model = ResNet(depth=101)
    # model = ResNet(depth=152)
    # model = ResNetWithCoordAttention(depth=18)
    model = ResNetWithSEAttention(depth=18)
    x = torch.rand(1, 3, 224, 224)
    outs = model(x)
    for i, out in enumerate(outs):
        print(i, out.shape)
```
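With MMDetection's default `out_indices=(0, 1, 2, 3)`, the self-test at the bottom should print four feature maps for the 224×224 input, along these lines:

```
0 torch.Size([1, 64, 56, 56])
1 torch.Size([1, 128, 28, 28])
2 torch.Size([1, 256, 14, 14])
3 torch.Size([1, 512, 7, 7])
```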
Taking ResNet as the example, I add the attention mechanism at the multi-scale feature-map outputs. To do this I wrote a base class; a subclass only needs to implement the attention module itself, as the sketch below shows.
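For instance, to plug in an attention type that is not in the listing above, subclass the base class and override `get_attention_module`. This is a minimal sketch: `SpatialGate` here is a hypothetical stand-in (not from fightingcv-attention), and `ResNetWithAttention` is assumed to be the base class defined above; any `nn.Module` that maps a `(B, C, H, W)` tensor to the same shape works.

```python
import torch.nn as nn
from mmdet.models.builder import BACKBONES


# Hypothetical attention module for illustration: a 1x1 conv produces a
# per-pixel gate in [0, 1] that is broadcast over the channel dimension.
class SpatialGate(nn.Module):
    def __init__(self, dim):
        super().__init__()
        self.gate = nn.Sequential(nn.Conv2d(dim, 1, kernel_size=1),
                                  nn.Sigmoid())

    def forward(self, x):
        return x * self.gate(x)


@BACKBONES.register_module()
class ResNetWithSpatialGate(ResNetWithAttention):
    def get_attention_module(self, dim):
        return SpatialGate(dim)
```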
The attention implementations themselves follow the open-source fightingcv-attention repository.
You can also simply install it via pip:

```bash
pip install fightingcv-attention
```
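As a quick sanity check of the installation, the attention modules can be run standalone, using the same import path as in the listing above; `SEAttention` takes a `(B, C, H, W)` tensor and returns one of the same shape:

```python
import torch
from fightingcv_attention.attention.SEAttention import SEAttention

# SEAttention recalibrates channels; the output shape matches the input.
se = SEAttention(channel=512, reduction=16)
x = torch.rand(1, 512, 7, 7)
print(se(x).shape)  # expected: torch.Size([1, 512, 7, 7])
```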
Once the model's outputs have been verified, the backbone can be registered with MMDetection.
The simple way is to add the backbone registration decorator (`@BACKBONES.register_module()`, already present in the code above) and to import this file in train.py and test.py, for example:
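A sketch of that import, assuming the file is saved as resnet_with_attention.py somewhere on the Python path:

```python
# At the top of tools/train.py and tools/test.py, so that the
# @BACKBONES.register_module() decorators run before the model is built:
import resnet_with_attention  # noqa: F401
```

Alternatively, recent MMDetection versions can load the module from the config itself via `custom_imports = dict(imports=['resnet_with_attention'], allow_failed_imports=False)`, which avoids touching the entry scripts.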
Then, in the config, change the backbone's type from ResNet to ResNetWithCoordAttention or ResNetWithSEAttention.
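For example, the backbone section of a config might look like the following sketch. The surrounding fields are the usual MMDetection ResNet-50 defaults; note that the pretrained checkpoint only covers the original ResNet weights, so the new attention layers start from random initialization:

```python
model = dict(
    backbone=dict(
        type='ResNetWithSEAttention',  # was: type='ResNet'
        depth=50,
        num_stages=4,
        out_indices=(0, 1, 2, 3),
        frozen_stages=1,
        norm_cfg=dict(type='BN', requires_grad=True),
        norm_eval=True,
        style='pytorch',
        init_cfg=dict(type='Pretrained',
                      checkpoint='torchvision://resnet50')),
    # neck, head, etc. unchanged
)
```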
Source: https://blog.csdn.net/qq_21904447/article/details/129762735