我的编程空间,编程开发者的网络收藏夹
学习永远不晚

Bert的pooler_output是什么?

短信预约 -IT技能 免费直播动态提醒
省份

北京

  • 北京
  • 上海
  • 天津
  • 重庆
  • 河北
  • 山东
  • 辽宁
  • 黑龙江
  • 吉林
  • 甘肃
  • 青海
  • 河南
  • 江苏
  • 湖北
  • 湖南
  • 江西
  • 浙江
  • 广东
  • 云南
  • 福建
  • 海南
  • 山西
  • 四川
  • 陕西
  • 贵州
  • 安徽
  • 广西
  • 内蒙
  • 西藏
  • 新疆
  • 宁夏
  • 兵团
手机号立即预约

请填写图片验证码后获取短信验证码

看不清楚,换张图片

免费获取短信验证码

Bert的pooler_output是什么?

BERT的两个输出

在学习bert的时候,我们知道bert是输出每个token的embeding。但在使用hugging face的bert模型时,发现除了last_hidden_state还多了一个pooler_output输出。

例如:

from transformers import AutoTokenizer, AutoModeltokenizer = AutoTokenizer.from_pretrained("bert-base-uncased")model = AutoModel.from_pretrained("bert-base-uncased")inputs = tokenizer("I'm caixunkun. I like singing, dancing, rap and basketball.", return_tensors="pt")outputs = model(**inputs)print("last_hidden_state shape:", outputs.last_hidden_state.size())print("pooler_output shape:", outputs.pooler_output.size())
last_hidden_state shape: torch.Size([1, 20, 768])pooler_output shape: torch.Size([1, 768])

许多人可能以为pooler_output[CLS]token的embedding,但使用last_hidden_state shape[:, 0]比较后,发现又不是,然后就很奇怪。

Bert的Pooler_output

先说一下结论: pooler_output可以理解成该句子语义的特征向量表示

那它是怎么来的?和[CLS]token的embedding区别在哪?

我们将Bert模型打印一下,会发现最后还有一个BertPooler层,pooler_output就是从这来的。如下所示:

BertModel((embedding): BertEmbeddings(....)(encoder): BertEncoder(... # 12层TransformerEncoder)(pooler): BertPooler(    (dense): Linear(in_features=768, out_features=768, bias=True)    (activation): Tanh()))

其中encoder就是将BERT的所有token经过12个TransformerEncoder进行embedding。pooler就是将[CLS]这个token再过一下全连接层+Tanh激活函数,作为该句子的特征向量

我们可以从Bert源码中验证以上结论。在transformers.models.bert.modeling_bert.BertModel.forward方法中这么一行代码:

# sequence_output就是last_hidden_state# self.pooler就是上面的BertPoolerpooled_output = self.pooler(sequence_output) if self.pooler is not None else None

我们再来看看transformers.models.bert.modeling_bert.BertPooler的源码:

class BertPooler(nn.Module):    def __init__(self, config):        super().__init__()        self.dense = nn.Linear(config.hidden_size, config.hidden_size)        self.activation = nn.Tanh()    def forward(self, hidden_states):# hidden_states的第一个维度是batch_size。所以用[:, 0]取所有句子的[CLS]的embedding        first_token_tensor = hidden_states[:, 0]        pooled_output = self.dense(first_token_tensor)        pooled_output = self.activation(pooled_output)        return pooled_output

从上面的源码可以看出,pooler_output 就是[CLS]embedding又经历了一次全连接层的输出。我们可以通过以下代码进行验证:

print("pooler:", model.pooler)my_pooler_output = model.pooler(outputs.last_hidden_state)print(my_pooler_output[0, :5])print(outputs.pooler_output[0, :5])
pooler: BertPooler(  (dense): Linear(in_features=768, out_features=768, bias=True)  (activation): Tanh())tensor([-0.8129, -0.6216, -0.9810,  0.8090,  0.9032], grad_fn=)tensor([-0.8129, -0.6216, -0.9810,  0.8090,  0.9032], grad_fn=)

Bert的Pooler_output的由来

我们知道,BERT的训练包含两个任务:MLM和NSP任务(Next Sentence Prediction)。 对这两个任务不熟悉的朋友可以参考:BERT源码实现与解读(Pytorch)【论文阅读】BERT 两篇文章。

其中MLM就是挖空,然后让bert预测这个空是什么。做该任务是使用token embedding进行预测。

而Next Sentence Prediction就是预测bert接受的两句话是否为一对。例如:窗前明月光,疑是地上霜 为 True,窗前明月光,李白打开窗为False。

所以,NSP任务需要句子的语义信息来预测,但是我们看下源码是怎么做的。transformers.models.bert.modeling_bert.BertForNextSentencePrediction的部分源码如下:

class BertForNextSentencePrediction(BertPreTrainedModel):    def __init__(self, config):        super().__init__(config)        self.bert = BertModel(config)        self.cls = BertOnlyNSPHead(config)# 这个就是一个 nn.Linear(config.hidden_size, 2)...def forward(...):...outputs = self.bert(...)pooled_output = outputs[1] # 取pooler_outputseq_relationship_scores = self.cls(pooled_output)# 使用pooler_ouput送给后续的全连接层进行预测...

从上面的源码可以看出,在NSP任务训练时,并不是直接使用[CLS]token的embedding作为句子特征传给后续分类头的,而是使用的是pooler_output。个人原因可能是因为直接使用[CLS]的embedding效果不够好。

但在MLM任务时,是直接使用的是last_hidden_state,有兴趣可以看一下transformers.models.bert.modeling_bert.BertForMaskedLM的源码。

来源地址:https://blog.csdn.net/zhaohongfei_358/article/details/127960742

免责声明:

① 本站未注明“稿件来源”的信息均来自网络整理。其文字、图片和音视频稿件的所属权归原作者所有。本站收集整理出于非商业性的教育和科研之目的,并不意味着本站赞同其观点或证实其内容的真实性。仅作为临时的测试数据,供内部测试之用。本站并未授权任何人以任何方式主动获取本站任何信息。

② 本站未注明“稿件来源”的临时测试数据将在测试完成后最终做删除处理。有问题或投稿请发送至: 邮箱/279061341@qq.com QQ/279061341

Bert的pooler_output是什么?

下载Word文档到电脑,方便收藏和打印~

下载Word文档

猜你喜欢

Transformer之Bert预训练语言解析的方法是什么

今天小编给大家分享一下Transformer之Bert预训练语言解析的方法是什么的相关知识点,内容详细,逻辑清晰,相信大部分人都还太了解这方面的知识,所以分享这篇文章给大家参考一下,希望大家阅读完这篇文章后有所收获,下面我们一起来了解一下吧
2023-07-05

Prometheus中的TSDB是什么,它的作用是什么

Prometheus中的TSDB是时间序列数据库(Time Series Database)的缩写,它是用来存储和管理时间序列数据的一种数据库系统。TSDB在Prometheus中的作用是存储监控数据的时间序列信息,包括指标数据、标签信息和
Prometheus中的TSDB是什么,它的作用是什么
2024-03-04

web前端是做什么的?优势是什么?

Web前端是指开发Web页面的技术和工具,也称作前端工程师。随着互联网的迅速发展,Web前端在互联网领域中变得越来越重要。这篇文章将讨论Web前端是做什么的,以及为什么Web前端在今天的互联网领域中如此重要。
2023-05-14

什么是软考?软考的全称是什么

  什么是软考?软考全称是什么?对于软考这项考试有些考生并不是很了解,编程学习网小编就来为大家解读什么是软考、软考全称以及软考考试性质。  软考也叫软件水平考试,软考全称为计算机技术与软件专业技术资格(水平)考试,是由国家人力资源和社会保障部(原人事部)、工业和信息化部(原信息产业部)领导的国家级考试,其目的是,科学、公正
什么是软考?软考的全称是什么
2024-04-19

什么是 ipsec?SDN 是什么?

IPsec是一种协议套件,用于确保IP网络通信的安全,提供保密性、完整性和身份验证。SDN是一种网络架构,将网络控制平面与数据平面分离,集中控制和可编程性。两者的结合可增强网络安全性和可编程性:SDN可动态配置IPsec策略,IPsec增强SDN网络安全性,SDN简化IPsec管理。
什么是 ipsec?SDN 是什么?
2024-04-02

java的sdk是什么

SDK是Software Development Kit的缩写,中文意思是“软件开发工具包”。这是一个覆盖面相当广泛的名词,可以这么说:辅助开发某一类软件的相关文档、范例和工具的集合都可以叫做“SDK”。SDK是一系列文件的组合,它为软件的开发提供一个平台(它
java的sdk是什么
2017-08-31

什么是javascript的alert

这篇文章主要介绍“什么是javascript的alert”,在日常操作中,相信很多人在什么是javascript的alert问题上存在疑惑,小编查阅了各式资料,整理出简单好用的操作方法,希望对大家解答”什么是javascript的alert
2023-06-14

PHPWAMP指的是什么

这篇文章将为大家详细讲解有关PHPWAMP指的是什么,小编觉得挺实用的,因此分享给大家做个参考,希望大家阅读完这篇文章后可以有所收获。使用方式:点击相关设置,直接打开IIS站点管理即可使用,如果你电脑没安装IIS,会自动快速安装(右键新标签
2023-06-15

python32指的是什么

这篇文章将为大家详细讲解有关python32指的是什么,小编觉得挺实用的,因此分享给大家做个参考,希望大家阅读完这篇文章后可以有所收获。Python主要用来做什么Python主要应用于:1、Web开发;2、数据科学研究;3、网络爬虫;4、嵌
2023-06-14

Java的ClassLoader是什么

本文小编为大家详细介绍“Java的ClassLoader是什么”,内容详细,步骤清晰,细节处理妥当,希望这篇“Java的ClassLoader是什么”文章能帮助大家解决疑惑,下面跟着小编的思路慢慢深入,一起来学习新知识吧。ClassLoad
2023-06-16

Android的SplashScreen是什么

本篇内容主要讲解“Android的SplashScreen是什么”,感兴趣的朋友不妨来看看。本文介绍的方法操作简单快捷,实用性强。下面就让小编来带大家学习“Android的SplashScreen是什么”吧!什么是SplashScreenS
2023-06-29

15PHP指的是什么

本文小编为大家详细介绍“15PHP指的是什么”,内容详细,步骤清晰,细节处理妥当,希望这篇“15PHP指的是什么”文章能帮助大家解决疑惑,下面跟着小编的思路慢慢深入,一起来学习新知识吧。15PHP是指15菲律宾比索,这里的“php”是菲律宾
2023-07-04

编程热搜

  • Python 学习之路 - Python
    一、安装Python34Windows在Python官网(https://www.python.org/downloads/)下载安装包并安装。Python的默认安装路径是:C:\Python34配置环境变量:【右键计算机】--》【属性】-
    Python 学习之路 - Python
  • chatgpt的中文全称是什么
    chatgpt的中文全称是生成型预训练变换模型。ChatGPT是什么ChatGPT是美国人工智能研究实验室OpenAI开发的一种全新聊天机器人模型,它能够通过学习和理解人类的语言来进行对话,还能根据聊天的上下文进行互动,并协助人类完成一系列
    chatgpt的中文全称是什么
  • C/C++中extern函数使用详解
  • C/C++可变参数的使用
    可变参数的使用方法远远不止以下几种,不过在C,C++中使用可变参数时要小心,在使用printf()等函数时传入的参数个数一定不能比前面的格式化字符串中的’%’符号个数少,否则会产生访问越界,运气不好的话还会导致程序崩溃
    C/C++可变参数的使用
  • css样式文件该放在哪里
  • php中数组下标必须是连续的吗
  • Python 3 教程
    Python 3 教程 Python 的 3.0 版本,常被称为 Python 3000,或简称 Py3k。相对于 Python 的早期版本,这是一个较大的升级。为了不带入过多的累赘,Python 3.0 在设计的时候没有考虑向下兼容。 Python
    Python 3 教程
  • Python pip包管理
    一、前言    在Python中, 安装第三方模块是通过 setuptools 这个工具完成的。 Python有两个封装了 setuptools的包管理工具: easy_install  和  pip , 目前官方推荐使用 pip。    
    Python pip包管理
  • ubuntu如何重新编译内核
  • 改善Java代码之慎用java动态编译

目录