我的编程空间,编程开发者的网络收藏夹
学习永远不晚

sentence-transformers(SBert)中文文本相似度预测(附代码)

短信预约 -IT技能 免费直播动态提醒
省份

北京

  • 北京
  • 上海
  • 天津
  • 重庆
  • 河北
  • 山东
  • 辽宁
  • 黑龙江
  • 吉林
  • 甘肃
  • 青海
  • 河南
  • 江苏
  • 湖北
  • 湖南
  • 江西
  • 浙江
  • 广东
  • 云南
  • 福建
  • 海南
  • 山西
  • 四川
  • 陕西
  • 贵州
  • 安徽
  • 广西
  • 内蒙
  • 西藏
  • 新疆
  • 宁夏
  • 兵团
手机号立即预约

请填写图片验证码后获取短信验证码

看不清楚,换张图片

免费获取短信验证码

sentence-transformers(SBert)中文文本相似度预测(附代码)

在这里插入图片描述

前言

训练模型

  1. 创建网络:使用Sbert官方给出的预训练模型sentence_hfl_chinese-roberta-wwm-ext,先载入embedding层进行分词,再载入池化层并传入嵌入后的维度,对模型进行降维压缩,最后载入密集层,选择Than激活函数,输出维度大小为256维。
  2. 获取训练数据:构建出新模型后使用InputExample类存储训练数据,它接受文本对字符串列表和用于指示语义相似性的标签,用标准的Pytorch Dataloader包装train_examples,作用是打乱数据并生成特定大小的批次。
  3. 计算损失函数:对于每个句子对,通过网络传递句子A和句子B,从而产生嵌入u和v,使用余弦相似度计算相似性,并将结果与标准相似度得分进行比较。这样网络就能够进行微调,更好地识别句子的相似性。
  4. 模型调优:通过调用model.fit()来调优模型。向model.fit()中传递train_objective列表(由元组(dataloader, loss_function))组成。也可以传递多个元组,以便在具有不同损失函数的多个数据集上执行多任务学习。在训练过程需要使用sentence_transformers.evaluation评估表现是否有所改善,它包含各种可以传递给fit方法的evaluators。Evaluators会在训练期间定期运行,并且会返回分数,只有得分最高的模型才会存储在磁盘上。

首先运行preprocess.py获取数据,并划分训练集和测试集,之后运行train_sentence_bert.py,使用预训练模型, sbert将数据集用sbert训练相似度任务,得到训练好的模型,最后运行evaluate.py评估训练好的模型,将结果保存在predict.txt中,并输出预测结果。

这部分在详细代码里注释得很全。

后端部分

使用flask编写post接口,接收的数据格式为application/json,将前端传来的两个句子使用训练好的模型对其进行相似度预测,将得到的相似度类型从无法序列化存入json的tensor转成list,并将状态码,信息,数据返回给前端。

from sentence_transformers import SentenceTransformer, util# 后端接口from flask import Flask, jsonify, requestimport re# 用当前脚本名称实例化Flask对象,方便flask从该脚本文件中获取需要的内容app = Flask(__name__)# 使通过jsonify返回的中文显示正常,否则显示为ASCII码app.config["JSON_AS_ASCII"] = Falsemodel_path = 'D:/xxx模型路径/'model = SentenceTransformer(model_path)@app.route("/evaluate",methods=['POST'])def evalute_sentence():    s1 = request.json.get("s1")    s2 = request.json.get("s2")    if s1 and s2:        embedding1 = model.encode(s1, convert_to_tensor=True)        embedding2 = model.encode(s2, convert_to_tensor=True)        similarity = util.cos_sim(embedding1, embedding2).tolist()        return jsonify({"code": 200, "msg": "预测成功", "data": similarity})    else:        return jsonify({"code": 400, "msg": "缺少字段"})if __name__ == '__main__':    app.run(debug=True)

前端部分

框架使用Vue2,UI框架使用elementui。组件校验用户输入的表单(内容为中文,字数限制32个字,两个句子不为空),只有符合规则的字段才能提交表单。将数据通过Axios调用接口传递给后端,再根据后端接口响应状态进行相应的处理,如果返回状态码200,说明接口调用成功,展示返回的预测值,否则调用失败,页面弹出失败消息提示。

<template>  <div class="recommend">    <el-card class="box">      <h2 class="title">中文文本相似度预测</h2>      <el-form :model="evaluateForm" :rules="evaluateRules" ref="evaluateForm" class="form">        <el-form-item prop="s1">          <el-input            placeholder="请输入句子一"            maxlength="32"            show-word-limit            v-model="evaluateForm.s1"            autocomplete="false"            prefix-icon="el-icon-edit-outline"          ></el-input>        </el-form-item>        <el-form-item prop="s2">          <el-input            maxlength="32"            placeholder="请输入句子二"            v-model="evaluateForm.s2"            show-word-limit            autocomplete="false"            prefix-icon="el-icon-edit-outline"          ></el-input>        </el-form-item>        <el-form-item class="btn-container">          <el-button            type="primary"            @click="submitForm('evaluateForm')"            class="btn"            id="queryButton"          >开始预测</el-button>        </el-form-item>      </el-form>      <div v-show="result" style="margin-top: 20px">        <el-progress          :text-inside="true"          :stroke-width="26"          :percentage="result*100 ? result*100 : 0"          class="el-bg-inner-running"        ></el-progress>        <p>预测结果:{{result}}</p>      </div>    </el-card>  </div></template><script>import api from "@/api/index"export default {  data () {    return {      evaluateForm: {        s1: "",        s2: ""      },      evaluateRules: { // 评估表单校验规则        s1: [          { required: true, message: '请输入中文句子', trigger: 'blur', pattern: /^[\u4E00-\u9FA5]+$/ },        ],        s2: [          { required: true, message: '请输入中文句子', trigger: 'blur', pattern: /^[\u4E00-\u9FA5]+$/ },        ],      },      result: undefined,    }  },  methods: {    postEvaluate () { // 调用接口      api.postEvaluate(this.evaluateForm)        .then((res) => {          if (!res) {            return          }          console.log("res", res)          if (res.data.code !== 200) {            this.$message({              message: "请求失败",              type: "error"            })            return          }          let data = res.data.data[0]          this.result = data[0]          console.log("this.result", this.result)          this.$message({            message: "预测成功!",            type: "success"          })        })        .catch((error) => {          this.$message.error('资源获取错误!')        })    },    submitForm (formName) { // 提交表单      this.$refs[formName].validate((valid) => {        if (valid) {          this.postEvaluate()        } else {          this.$message({            message: "请按要求填写",            type: "warning"          })          console.log('error in submit form')          return false        }      })      document.getElementById("queryButton").blur()    },  }}</script><style lang="scss" scoped>.recommend {  width: 100%;  height: 100%;  text-align: center;  display: flex;  text-align: center;  flex-direction: column;  align-items: center;  justify-content: center;  overflow: hidden;  background: #00416a 0 / cover fixed;   background: -webkit-linear-gradient(    to right,    #00416a,    #e4e5e6  );   background: linear-gradient(    to right,    #00416a,    #e4e5e6  );   .box {    width: 48%;    height: 60%;    position: relative;    background: hsla(0, 0%, 100%, 0.3);    z-index: 5;    padding: 10px 20px;    // display: flex;    // flex-direction: column;    // justify-content: center;    box-sizing: border-box;    &::before {      content: '';      position: absolute;      top: 0;      right: 0;      bottom: 0;      left: 0;      filter: blur(20px);    }    .title {      color: #143b54;    }    .btn-container {      margin: 10px auto;      .btn {        width: 100%;        border-radius: 20px;      }    }  }}::v-deep .el-card {  border: 0;  box-shadow: 0 5px 16px 0 rgb(0 0 0 / 30%);}::v-deep .el-progress-bar__outer {  border: 0;  background-color: transparent;  // background-color: #abcbe0;}::v-deep .el-bg-inner-running .el-progress-bar__inner {  background: #9cecfb;   background: -webkit-linear-gradient(    to left,    #0052d4,    #65c7f7,    #9cecfb  );   background: linear-gradient(    to left,    #0052d4,    #65c7f7,    #9cecfb  ); }</style>

预训练模型比较

paraphrase-multilingual-MiniLM-L12-v2
参数设置:epochs=1,batch_size=16
特点:作为sbert官方多语言预训练模型,已带有BERT层和池化层,可直接用数据评估,但未经纯中文文本训练,准确率较低

在这里插入图片描述

chinese-electra-180g-small-discriminator
参数设置:epochs=1, batch_size=16
特点:运行时间快,准确率尚可

在这里插入图片描述

chinese-electra-180g-small-discriminator
参数设置:epochs=20, batch_size=16
特点:20次迭代比1次迭代有效果,但差别不大

在这里插入图片描述

chinese-electra-180g-small-discriminator
参数设置:epochs=1,batch_size=8
特点:比batch_size=16时效果更好

在这里插入图片描述

chinese-roberta-wwm-ext
参数设置:epochs=1,batch_size=8
特点:迭代1次和20次准确率无差别,稳定且效果在所有模型中最好,缺点是体积大运行速度慢

在这里插入图片描述

最后

代码已上传至sbert中文文本相似度预测,欢迎star!

来源地址:https://blog.csdn.net/weixin_54218079/article/details/128687878

免责声明:

① 本站未注明“稿件来源”的信息均来自网络整理。其文字、图片和音视频稿件的所属权归原作者所有。本站收集整理出于非商业性的教育和科研之目的,并不意味着本站赞同其观点或证实其内容的真实性。仅作为临时的测试数据,供内部测试之用。本站并未授权任何人以任何方式主动获取本站任何信息。

② 本站未注明“稿件来源”的临时测试数据将在测试完成后最终做删除处理。有问题或投稿请发送至: 邮箱/279061341@qq.com QQ/279061341

sentence-transformers(SBert)中文文本相似度预测(附代码)

下载Word文档到电脑,方便收藏和打印~

下载Word文档

编程热搜

  • Python 学习之路 - Python
    一、安装Python34Windows在Python官网(https://www.python.org/downloads/)下载安装包并安装。Python的默认安装路径是:C:\Python34配置环境变量:【右键计算机】--》【属性】-
    Python 学习之路 - Python
  • chatgpt的中文全称是什么
    chatgpt的中文全称是生成型预训练变换模型。ChatGPT是什么ChatGPT是美国人工智能研究实验室OpenAI开发的一种全新聊天机器人模型,它能够通过学习和理解人类的语言来进行对话,还能根据聊天的上下文进行互动,并协助人类完成一系列
    chatgpt的中文全称是什么
  • C/C++中extern函数使用详解
  • C/C++可变参数的使用
    可变参数的使用方法远远不止以下几种,不过在C,C++中使用可变参数时要小心,在使用printf()等函数时传入的参数个数一定不能比前面的格式化字符串中的’%’符号个数少,否则会产生访问越界,运气不好的话还会导致程序崩溃
    C/C++可变参数的使用
  • css样式文件该放在哪里
  • php中数组下标必须是连续的吗
  • Python 3 教程
    Python 3 教程 Python 的 3.0 版本,常被称为 Python 3000,或简称 Py3k。相对于 Python 的早期版本,这是一个较大的升级。为了不带入过多的累赘,Python 3.0 在设计的时候没有考虑向下兼容。 Python
    Python 3 教程
  • Python pip包管理
    一、前言    在Python中, 安装第三方模块是通过 setuptools 这个工具完成的。 Python有两个封装了 setuptools的包管理工具: easy_install  和  pip , 目前官方推荐使用 pip。    
    Python pip包管理
  • ubuntu如何重新编译内核
  • 改善Java代码之慎用java动态编译

目录