我的编程空间,编程开发者的网络收藏夹
学习永远不晚

ML工程师一次微调七个模型,击败OpenAI GPT-4

短信预约 -IT技能 免费直播动态提醒
省份

北京

  • 北京
  • 上海
  • 天津
  • 重庆
  • 河北
  • 山东
  • 辽宁
  • 黑龙江
  • 吉林
  • 甘肃
  • 青海
  • 河南
  • 江苏
  • 湖北
  • 湖南
  • 江西
  • 浙江
  • 广东
  • 云南
  • 福建
  • 海南
  • 山西
  • 四川
  • 陕西
  • 贵州
  • 安徽
  • 广西
  • 内蒙
  • 西藏
  • 新疆
  • 宁夏
  • 兵团
手机号立即预约

请填写图片验证码后获取短信验证码

看不清楚,换张图片

免费获取短信验证码

ML工程师一次微调七个模型,击败OpenAI GPT-4

在大模型时代,提示词工程(Prompt Engineering)、模型微调和检索增强生成(RAG)都是非常重要的能力,对于大多数个人使用者来说,掌握提示词工程就够用了,但如果想要在自己的服务中接入大模型,模型微调是必由之路,也因为其对于技能的更高要求,成为了ML工程师高手过招之地。

今日,一位ML工程师的微调模型就登上了HN热榜,「我的微调模型击败了OpenAI的GPT-4」!

工程师Alex Strick van Linschoten在博客中指出,就他的测试数据而言,Mistral、Llama3和Solar LLM的微调模型都要比OpenAI的模型更准确。

文章地址:https://mlops.systems/posts/2024-07-01-full-finetuned-model-evaluation.html#start-date-accuracy

对此,作者戏称道,这篇文章的可以直接改为:微调模型击败OpenAI,但评估过程实在痛苦。

他坦言,代码量很多,运行速度也较慢。在这次微调中,他第一次为微调选择而纠结权衡。如果没有某种系统来处理这个问题,维护这一切的复杂性就会开始增加。

现在,我们来看一下这位工程师是如何做的。(详细代码请参见原文)

加载数据集

所有数据都存储在Hugging Face Hub的一个公共仓库中。

为了进行这些评估,作者选择使用数据集的测试部分,因为模型之前没有接触过这些数据,这样可以更好地评估模型在新数据上的表现。

test_dataset
Dataset({
    features: ['name', 'eventrefnumber', 'text', 'StartDate', 'eventtype', 'province', 'citydistrict', 'village', 'targetgroup', 'commander', 'position', 'minkilled', 'mincaptured', 'capturedcharacterisation', 'killedcharacterisation', 'killq', 'captureq', 'killcaptureraid', 'airstrike', 'noshotsfired', 'dataprocessed', 'flagged', 'glossarymeta', 'minleaderskilled', 'minfacilitatorskilled', 'minleaderscaptured', 'minfacilitatorscaptured', 'leaderq'],
    num_rows: 724
})

首先在DataFrame中添加一个额外的列,然后对数据集中的每一行进行预测。将预测的副本存储到这一列中,以避免重复执行这个计算密集的步骤。

但首先需要将数据组装成Pydantic对象,以便处理数据验证。

[
    IsafEvent(
        name='5',
        text='2013-01-S-025\n\nKABUL, Afghanistan (Jan. 25, 2013)\nDuring a security operation in Andar district, Ghazni province, yesterday, an Afghan and coalition force killed the Taliban leader, Alaudin. Alaudin oversaw a group of insurgents responsible for conducting remote-controlled improvised explosive device and small-arms fire attacks against Afghan and coalition forces. Prior to his death, Alaudin was planning attacks against Afghan National Police in Ghazni province.',
        start_date=datetime.date(2013, 1, 24),
        event_type={'insurgentskilled'},
        province={'ghazni'},
        target_group={'taliban'},
        min_killed=1,
        min_captured=0,
        killq=True,
        captureq=False,
        killcaptureraid=False,
        airstrike=False,
        noshotsfired=False,
        min_leaders_killed=1,
        min_leaders_captured=0,
        predictions={}
    ),
    IsafEvent(
        name='2',
        text='2011-11-S-034\nISAF Joint Command - Afghanistan\nFor Immediate Release\n\nKABUL, Afghanistan (Nov. 20, 2011)\nA coalition security force detained numerous suspected insurgents during an operation in Marjeh district, Helmand province, yesterday.  The force conducted the operation after receiving information that a group of insurgents were at a compound in the area.  After calling for the men inside to come out peacefully, the insurgents emerged and were detained without incident.',
        start_date=datetime.date(2011, 11, 19),
        event_type={'detention'},
        province={'helmand'},
        target_group={''},
        min_killed=0,
        min_captured=4,
        killq=False,
        captureq=True,
        killcaptureraid=True,
        airstrike=False,
        noshotsfired=False,
        min_leaders_killed=0,
        min_leaders_captured=0,
        predictions={}
     )
]

因此,当进行预测时,我们希望从模型中得到一个类似这样的JSON字符串:

json_str = events[0].model_dump_json(exclude={"text", "predictions"})
print(json_str)
{"name":"5","start_date":"2013-01-24","event_type":["insurgentskilled"],"province":["ghazni"],"target_group":["tali
ban"],"min_killed":1,"min_captured":0,"killq":true,"captureq":false,"killcaptureraid":false,"airstrike":false,"nosh
otsfired":false,"min_leaders_killed":1,"min_leaders_captured":0}

从使用GPT模型进行完整评估开始,需要一个更复杂的提示词才能获得理想的结果。

由于GPT模型没有经过训练或微调来响应微调模型的特定提示词,因此我们不能直接使用相同的提示词。

这带来了一个有趣的问题:我们需要花多少精力在设计GPT提示词上,才能达到微调模型的准确度?换句话说,是否真的有办法在接受不同提示词的模型之间进行公平的比较?

尝试OpenAI的GPT-4和GPT-4 Turbo可以看到,为了让GPT模型有机会与微调模型竞争,提示词需要多长。

理想情况下,作者会在上下文中加入更多的示例,但他也不希望增加token的使用量。

from openai import OpenAI
from rich import print
import json
import os


def query_openai(article_text: str, model: str) -> str:
    query = (
        f"The following is a press release issued by ISAF (formerly operating in Afghanistan):\n{article_text}\n\n"
        "## Extraction request\n"
        "Please extract the following information from the press release:\n"
        "- The name of the event (summarising the event / text as a headline)\n"
        "- The start date of the event\n"
        "- The event type(s)\n"
        "- The province(s) in which the event occurred\n"
        "- The target group(s) of the event\n"
        "- The minimum number of people killed during the event\n"
        "- The minimum number of people captured during the event\n"
        "- Whether someone was killed or not during the event\n"
        "- Whether someone was captured or not during the event\n"
        "- Whether the event was a so-called 'kill-capture raid'\n"
        "- Whether an airstrike was used during the event\n"
        "- Whether no shots were fired during the event\n"
        "- The minimum number of leaders killed during the event\n"
        "- The minimum number of leaders captured during the event\n\n"
        "## Annotation notes:\n"
        "- A 'faciliator' is not a leader.\n"
        "- If a press release states that 'insurgents' were detained without further "
        "details, assign a minimum number of two detained. Interpret 'a couple' as "
        "two. Interpret 'several' as at least three, even though it may sometimes "
        "refer to seven or eight. Classify the terms 'a few', 'some', 'a group', 'a "
        "small group', and 'multiple' as denoting at least three, even if they "
        "sometimes refer to larger numbers. Choose the smaller number if no other "
        "information is available in the press release to come up with a minimally "
        "acceptable figure. Interpret 'numerous' and 'a handful' as at least four, "
        "and 'a large number' as at least five.\n\n"
        "## Example:\n"
        "Article text: 'ISAF Joint Command Evening Operational Update Feb. 19, 2011\nISAF Joint Command - "
        "Afghanistan\u20282011-02-S-143\u2028For Immediate Release \u2028\u2028KABUL, Afghanistan (Feb. 19)\u2028\u2028ISAF "
        "service members at a compound in Sangin district, Helmand province observed numerous insurgents north and south of "
        "their position talking on radios today. After gaining positive identification of the insurgent positions, the "
        "coalition troops engaged, killing several insurgents. Later, the ISAF troops observed more insurgents positioning "
        "in the area with weapons. After positive identification, coalition forces continued firing on the various insurgent "
        "positions, resulting in several more insurgents being killed.'\n\n"
        'Output: `{"name":"Several insurgents killed in '
        'Helmand","start_date":"2011-02-18","event_type":["insurgentskilled"],"province":["helmand"],"target_group":[""],"mi'
        'n_killed":6,"min_captured":0,"killq":true,"captureq":false,"killcaptureraid":false,"airstrike":false,"noshotsfired"'
        ':false,"min_leaders_killed":0,"min_leaders_captured":0}`'
    )


    # set up the prediction harness
    client = OpenAI(api_key=os.getenv("OPENAI_API_KEY"))


    response = client.chat.completions.create(
        model=model,
        response_format={"type": "json_object"},
        messages=[
            {
                "role": "system",
                "content": "You are an expert at identifying events in a press release. You are precise "
                "and always make sure you are correct, drawing inference from the text of the "
                "press release.\n\n You always return a JSON string with the following schema: "
                "## JSON Schema details\n"
                "Here is some of the schema for the JSON output string you "
                "should make use of: event_types = ['airstrike', 'detention', "
                "'captureandkill', 'insurgentskilled', 'exchangeoffire', 'civiliancasualty'], "
                "provinces = ['badakhshan', 'badghis', 'baghlan', 'balkh', 'bamyan', "
                "'day_kundi', 'farah', 'faryab', 'ghazni', 'ghor', 'helmand', 'herat', "
                "'jowzjan', 'kabul', 'kandahar', 'kapisa', 'khost', 'kunar', 'kunduz', "
                "'laghman', 'logar', 'nangarhar', 'nimroz', 'nuristan', 'paktya', 'paktika', "
                "'panjshir', 'parwan', 'samangan', 'sar_e_pul', 'takhar', 'uruzgan', "
                "'wardak', 'zabul'], target_groups = ['taliban', 'haqqani', 'criminals', "
                "'aq', 'hig', 'let', 'imu', 'judq', 'iju', 'hik', 'ttp', 'other']\n\n",
            },
            {"role": "user", "content": query},
        ],
        temperature=1,
    )


    return response.choices[0].message.content

可以通过一个简单的示例来验证这个函数是否正常工作:

json_str = query_openai(events[0].text, "gpt-4o")
print(json.loads(json_str))
{
    'name': 'Taliban leader Alaudin killed in Ghazni',
    'start_date': '2013-01-24',
    'event_type': ['insurgentskilled'],
    'province': ['ghazni'],
    'target_group': ['taliban'],
    'min_killed': 1,
    'min_captured': 0,
    'killq': True,
    'captureq': False,
    'killcaptureraid': True,
    'airstrike': False,
    'noshotsfired': False,
    'min_leaders_killed': 1,
    'min_leaders_captured': 0
}

模型正常工作(如预期的那样),并且我们得到了一个JSON字符串。

接下来,构建一个程序来遍历所有测试数据,获取预测结果,并将这些预测结果存储在Pydantic对象中。

对于批量预测,要确保以异步方式进行,因为有大量的事件,所以不希望耗费太多时间。

此外,作者还在函数中添加了一些重试机制,以应对GPT-3.5-turbo模型的速率限制。

正如我们现在所看到的,作者对每个事件都附加了三个预测结果。

print(events[0])
IsafEvent(
    name='5',
    text='2013-01-S-025\n\nKABUL, Afghanistan (Jan. 25, 2013)\nDuring a security operation in Andar district, Ghazni province, yesterday, an Afghan and coalition force killed the Taliban leader, Alaudin. Alaudin oversaw a group of insurgents responsible for conducting remote-controlled improvised explosive device and small-arms fire attacks against Afghan and coalition forces. Prior to his death, Alaudin was planning attacks against Afghan National Police in Ghazni province.',
    start_date=datetime.date(2013, 1, 24),
    event_type={'insurgentskilled'},
    province={'ghazni'},
    target_group={'taliban'},
    min_killed=1,
    min_captured=0,
    killq=True,
    captureq=False,
    killcaptureraid=False,
    airstrike=False,
    noshotsfired=False,
    min_leaders_killed=1,
    min_leaders_captured=0,
    predictinotallow={'gpt-4o': '{\n  "name": "Taliban leader Alaudin killed in Ghazni",\n  "start_date": "2013-01-24",\n  "event_type": ["insurgentskilled", "captureandkill"],\n  "province": ["ghazni"],\n  "target_group": ["taliban"],\n "min_killed": 1,\n  "min_captured": 0,\n  "killq": true,\n  "captureq": false,\n  "killcaptureraid": true,\n  "airstrike": false,\n  "noshotsfired": false,\n  "min_leaders_killed": 1,\n  "min_leaders_captured": 0\n}',
        'gpt-4-turbo': '{\n    "name": "Taliban leader Alaudin killed in Ghazni",\n    "start_date": "2013-01-24",\n    "event_type": ["captureandkill"],\n    "province": ["ghazni"],\n    "target_group": ["taliban"],\n    "min_killed": 1,\n    "min_captured": 0,\n    "killq": true,\n    "captureq": false,\n    "killcaptureraid": true,\n    "airstrike": false,\n    "noshotsfired": false,\n    "min_leaders_killed": 1,\n    "min_leaders_captured": 0\n}',
        'gpt-3.5-turbo': '{\n    "name": "Taliban leader Alaudin killed in Ghazni province",\n    "start_date": "2013-01-24",\n    "event_type": ["captureandkill"],\n    "province": ["ghazni"],\n    "target_group": ["taliban"],\n    "min_killed": 1,\n    "min_captured": 0,\n    "killq": true,\n    "captureq": false,\n    "killcaptureraid": false,\n    "airstrike": false,\n    "noshotsfired": false,\n    "min_leaders_killed": 1,\n    "min_leaders_captured": 0\n}'
    }
)

目前,已经将所有预测结果都存储在内存中,所以现在是时候将它们提交到数据集,并推送到Hugging Face Hub了,以防笔记本崩溃、本地计算机关闭或其他意外情况发生。

作者创建了一个函数来处理这个过程,因为还需要对其他模型重复这个步骤。虽然过程有点冗长,但这样做更好,这样方便我们可以清楚地看到每一步的操作。

一个更简洁和抽象的 convert_to_dataset 函数可能如下所示:

def convert_to_dataset(data: List[BaseModel]) -> Dataset:
    dataset_dict = {}


    for field_name, field_value in data[0].__fields__.items():
        field_type = field_value.outer_type_
        if field_type in [str, int, float, bool, date]:
            dataset_dict[field_name] = [getattr(item, field_name) for item in data]
        elif field_type == set:
            dataset_dict[field_name] = [list(getattr(item, field_name)) for item in data]
        elif issubclass(field_type, BaseModel):
            dataset_dict[field_name] = [getattr(item, field_name).dict() for item in data]
        else:
            dataset_dict[field_name] = [getattr(item, field_name) for item in data]


    dataset = Dataset.from_dict(dataset_dict)
    return dataset

不过现在,要先把数据推送到Hub上。

convert_and_push_dataset(events, "isafpressreleases_with_preds", split_name="test")

添加来自微调模型的预测

在添加完基线OpenAI模型之后,现在再来添加一些之前微调过的模型,包括本地模型以及由一键微调服务商托管的模型。

重新加载预测数据集

先加载数据集,然后再添加一些本地模型的预测结果:

from datasets import load_dataset

preds_test_data = load_dataset("strickvl/isafpressreleases_with_preds")[
    "test"
].to_list()

微调TinyLlama的预测

现在,如果我们检查数据集,就会发现新的模型预测结果已经保存进去了:

from rich import print

print(preds_test_data[0])
{
    'name': '5',
    'text': '2013-01-S-025\n\nKABUL, Afghanistan (Jan. 25, 2013)\nDuring a security operation in Andar district, Ghazni province, yesterday, an Afghan and coalition force killed the Taliban leader, Alaudin. Alaudin oversaw a group of insurgents responsible for conducting remote-controlled improvised explosive device and small-arms fire attacks against Afghan and coalition forces. Prior to his death, Alaudin was planning attacks against Afghan National Police in Ghazni province.',
    'predictions': {'gpt-3.5-turbo': '{\n    "name": "Taliban leader Alaudin killed in Ghazni province",\n    "start_date": "2013-01-24",\n    "event_type": ["captureandkill"],\n    "province": ["ghazni"],\n    "target_group": ["taliban"],\n    "min_killed": 1,\n    "min_captured": 0,\n    "killq": true,\n    "captureq": false,\n    "killcaptureraid": false,\n    "airstrike": false,\n    "noshotsfired": false,\n    "min_leaders_killed": 1,\n    "min_leaders_captured": 0\n}',
        'gpt-4-turbo': '{\n    "name": "Taliban leader Alaudin killed in Ghazni",\n    "start_date": "2013-01-24",\n    "event_type": ["captureandkill"],\n    "province": ["ghazni"],\n    "target_group": ["taliban"],\n    "min_killed": 1,\n    "min_captured": 0,\n    "killq": true,\n    "captureq": false,\n    "killcaptureraid": true,\n    "airstrike": false,\n    "noshotsfired": false,\n    "min_leaders_killed": 1,\n    "min_leaders_captured": 0\n}',
        'gpt-4o': '{\n  "name": "Taliban leader Alaudin killed in Ghazni",\n  "start_date": "2013-01-24",\n  "event_type": ["insurgentskilled", "captureandkill"],\n  "province": ["ghazni"],\n  "target_group": ["taliban"],\n "min_killed": 1,\n  "min_captured": 0,\n  "killq": true,\n  "captureq": false,\n  "killcaptureraid": true,\n  "airstrike": false,\n  "noshotsfired": false,\n  "min_leaders_killed": 1,\n  "min_leaders_captured": 0\n}',
        'tinyllama-templatefree': '\n{"name":"Taliban leader killed in Ghazni","start_date":"2013-01-24","event_type":["insurgentskilled"],"province":["ghazni"],"target_group":["taliban"],"min_killed":1,"min_captured":0,"killq":true,"captureq":false,"killcaptureraid":false,"airstrike":false,"noshotsfired":false,"min_leaders_killed":1,"min_leaders_captured":0}',
        'tinyllama-sharegpt': 
'{"name":"2","start_date":"2013-01-24","event_type":["airstrike"],"province":["ghazni"],"target_group":["taliban"],"min_killed":1,"min_captured":0,"killq":true,"captureq":false,"killcaptureraid":false,"airstrike":true,"noshotsfired":false,"min_leaders_killed":1,"min_leaders_captured":0}'
    },
    'start_date': datetime.date(2013, 1, 24),
    'province': ['ghazni'],
    'target_group': ['taliban'],
    'event_type': ['insurgentskilled'],
    'min_killed': 1,
    'min_captured': 0,
    'killq': True,
    'captureq': False,
    'killcaptureraid': False,
    'airstrike': False,
    'noshotsfired': False,
    'min_leaders_killed': 1,
    'min_leaders_captured': 0
}

微调Mistral的预测

正如作者之前提到的,微调后的Mistral模型无法在本地运行,所以他在Modal上进行推理,在那里他可以使用强大的A100显卡来进行预测。

我们会看到,模型的表现并不理想,几乎所有评估都失败了。在图表中数值为零的就是mistral-lora-templatefree模型。

微调OpenAI的预测

使用OpenAI的一键微调服务对gpt-3.5-turbo-1106模型进行了微调。通过OpenAI SDK 遍历数据集,生成了这个微调模型的预测结果。

微调Mistral模型(通过OpenPipe)

使用OpenPipe微调了Mistral 7B和Mistral 8x7B模型,以便有一个合理的基准来比较其他模型。

微调Solar LLM(通过Predibase)

大约一周前,Predibase宣布了一个新的顶级微调模型,即Upstage的Solar LLM,所以作者决定试试。

这个模型的优势在于它被训练得非常擅长人们常常微调模型的任务,例如结构化数据提取。正如在图表中呈现的,它表现得相当不错!

微调Llama3的预测(通过OpenPipe)

作者本地微调的Llama3模型表现并不好,但在OpenPipe上的输出看起来还可以,所以他使用这些预测进行最终评估。

from rich import print

print(preds_test_data[0])
{
    'name': '5',
    'text': '2013-01-S-025\n\nKABUL, Afghanistan (Jan. 25, 2013)\nDuring a security operation in Andar district, Ghazni province, yesterday, an Afghan and coalition force killed the Taliban leader, Alaudin. Alaudin oversaw a group of insurgents responsible for conducting remote-controlled improvised explosive device and small-arms fire attacks against Afghan and coalition forces. Prior to his death, Alaudin was planning attacks against Afghan National Police in Ghazni province.',
    'predictions': {'finetuned-llama3-7b-32k-openpipe': 
'{"name":"1","start_date":"2013-01-24","event_type":["insurgentskilled"],"province":["ghazni"],"target_group":["taliban"],"min_killed":1,"min_captured":0,"killq":true,"captureq":false,"killcaptureraid":true,"airstrike":false,"noshotsfired":false,"min_leaders_killed":1,"min_leaders_captured":0}',
        'finetuned-mistral-7b-optimised-openpipe': 
'{"name":"1","start_date":"2013-01-24","event_type":["insurgentskilled"],"province":["ghazni"],"target_group":["taliban"],"min_killed":1,"min_captured":0,"killq":true,"captureq":false,"killcaptureraid":true,"airstrike":false,"noshotsfired":false,"min_leaders_killed":1,"min_leaders_captured":0}',
        'finetuned-openai-gpt-3.5-turbo-1106': 
'{"name":"4","start_date":"2013-01-24","event_type":["insurgentskilled"],"province":["ghazni"],"target_group":["taliban"],"min_killed":1,"min_captured":0,"killq":true,"captureq":false,"killcaptureraid":true,"airstrike":false,"noshotsfired":false,"min_leaders_killed":1,"min_leaders_captured":0}',
        'gpt-3.5-turbo': '{\n    "name": "Taliban leader Alaudin killed in Ghazni province",\n    "start_date": "2013-01-24",\n    "event_type": ["captureandkill"],\n    "province": ["ghazni"],\n    "target_group": ["taliban"],\n    "min_killed": 1,\n    "min_captured": 0,\n    "killq": true,\n    "captureq": false,\n    "killcaptureraid": false,\n    "airstrike": false,\n    "noshotsfired": false,\n    "min_leaders_killed": 1,\n    "min_leaders_captured": 0\n}',
        'gpt-4-turbo': '{\n    "name": "Taliban leader Alaudin killed in Ghazni",\n    "start_date": "2013-01-24",\n    "event_type": ["captureandkill"],\n    "province": ["ghazni"],\n    "target_group": ["taliban"],\n    "min_killed": 1,\n    "min_captured": 0,\n    "killq": true,\n    "captureq": false,\n    "killcaptureraid": true,\n    "airstrike": false,\n    "noshotsfired": false,\n    "min_leaders_killed": 1,\n    "min_leaders_captured": 0\n}',
        'gpt-4o': '{\n  "name": "Taliban leader Alaudin killed in Ghazni",\n  "start_date": "2013-01-24",\n  "event_type": ["insurgentskilled", "captureandkill"],\n  "province": ["ghazni"],\n  "target_group": ["taliban"],\n "min_killed": 1,\n  "min_captured": 0,\n  "killq": true,\n  "captureq": false,\n  "killcaptureraid": true,\n  "airstrike": false,\n  "noshotsfired": false,\n  "min_leaders_killed": 1,\n  "min_leaders_captured": 0\n}',
        'mistral-lora-templatefree': '1',
        'tinyllama-sharegpt': 
'{"name":"2","start_date":"2013-01-24","event_type":["airstrike"],"province":["ghazni"],"target_group":["taliban"],"min_killed":1,"min_captured":0,"killq":true,"captureq":false,"killcaptureraid":false,"airstrike":true,"noshotsfired":false,"min_leaders_killed":1,"min_leaders_captured":0}',
        'tinyllama-templatefree': '\n{"name":"Taliban leader killed in Ghazni","start_date":"2013-01-24","event_type":["insurgentskilled"],"province":["ghazni"],"target_group":["taliban"],"min_killed":1,"min_captured":0,"killq":true,"captureq":false,"killcaptureraid":false,"airstrike":false,"noshotsfired":false,"min_leaders_killed":1,"min_leaders_captured":0}',
        'ft-solar-1-mini-chat-240612-predibase': 
'\n\n{"name":"2","start_date":"2013-01-24","event_type":["insurgentskilled"],"province":["ghazni"],"target_group":["taliban"],"min_killed":1,"min_captured":0,"killq":true,"captureq":false,"killcaptureraid":true,"airstrike":false,"noshotsfired":false,"min_leaders_killed":1,"min_leaders_captured":0}'
    },
    'start_date': datetime.date(2013, 1, 24),
    'province': ['ghazni'],
    'target_group': ['taliban'],
    'event_type': ['insurgentskilled'],
    'min_killed': 1,
    'min_captured': 0,
    'killq': True,
    'captureq': False,
    'killcaptureraid': False,
    'airstrike': False,
    'noshotsfired': False,
    'min_leaders_killed': 1,
    'min_leaders_captured': 0
}

现在我们有了来自七个微调模型和三个OpenAI模型的预测结果(用于比较),可以开始进行评估了。

不过在此之前,还需要先进行一个简单的检查,看看这些预测中有多少是有效的JSON格式。

JSON有效性测试

from datasets import load_dataset

dataset_with_preds = load_dataset("strickvl/isafpressreleases_test_predictions")[
    "train"
].to_list()

通过对比templatefree和sharegpt生成有效JSON的能力,我们已经能看到它们在TinyLlama微调中的差异,这非常具有指导意义。

OpenAI模型每次都能生成有效的JSON,微调后的Mistral和Llama3模型也是如此。

在编写评估模型的代码时,作者注意到有些条目是空白的或者根本没有预测结果,所以他接下来对这个问题进行了调查。

# find out how many of the predictions are None values or empty strings
missing_values = {
    "gpt-4o": 0,
    "gpt-4-turbo": 0,
    "gpt-3.5-turbo": 0,
    "tinyllama-templatefree": 0,
    "tinyllama-sharegpt": 0,
    "finetuned-openai-gpt-3.5-turbo-1106": 0,
    "finetuned-llama3-7b-32k-openpipe": 0,
    "mistral-lora-templatefree": 0,
    "finetuned-mistral-7b-optimised-openpipe": 0,
    "ft-solar-1-mini-chat-240612-predibase": 0,
}


for row in dataset_with_preds:
    for model in row["predictions"]:
        if row["predictions"][model] is None or row["predictions"][model] == "":
            missing_values[model] += 1


print(missing_values)
{
    'gpt-4o': 0,
    'gpt-4-turbo': 0,
    'gpt-3.5-turbo': 0,
    'tinyllama-templatefree': 0,
    'tinyllama-sharegpt': 38,
    'finetuned-openai-gpt-3.5-turbo-1106': 0,
    'finetuned-llama3-7b-32k-openpipe': 0,
    'mistral-lora-templatefree': 0,
    'finetuned-mistral-7b-optimised-openpipe': 0,
    'ft-solar-1-mini-chat-240612-predibase': 0
}

如果没有缺失值,tinyllama-sharegpt模型将会有全部724个预测结果,并且都是有效的JSON。

现在我们可以进入我们真正感兴趣的部分:准确性。作者将计算所有有意义的属性的分数,然后展示模型比较的结果。

这些属性包括:

  • start_date
  • province
  • target_group
  • event_type
  • min_killed
  • min_captured
  • killq
  • captureq
  • killcaptureraid
  • airstrike
  • noshotsfired
  • min_leaders_killed
  • min_leaders_captured

重要提示,对于接下来的所有图表,总任务数为724,所以这些数字是从724中得出的。

测试结果

开始日期准确性

{
    'gpt-4o': 527,
    'gpt-4-turbo': 522,
    'gpt-3.5-turbo': 492,
    'tinyllama-templatefree': 231,
    'tinyllama-sharegpt': 479,
    'finetuned-openai-gpt-3.5-turbo-1106': 646,
    'finetuned-llama3-7b-32k-openpipe': 585,
    'mistral-lora-templatefree': 0,
    'finetuned-mistral-7b-optimised-openpipe': 636,
    'ft-solar-1-mini-chat-240612-predibase': 649
}

Solar和微调的GPT-3.5模型在预测事件发生日期方面表现最佳。

省份准确性

{
    'gpt-4o': 649,
    'gpt-4-turbo': 645,
    'gpt-3.5-turbo': 595,
    'tinyllama-templatefree': 335,
    'tinyllama-sharegpt': 660,
    'finetuned-openai-gpt-3.5-turbo-1106': 704,
    'finetuned-llama3-7b-32k-openpipe': 707,
    'mistral-lora-templatefree': 0,
    'finetuned-mistral-7b-optimised-openpipe': 711,
    'ft-solar-1-mini-chat-240612-predibase': 704
}

分析发现,微调后的模型实际上比OpenAI模型表现更好,只犯了少量错误。

目标群体准确性

在这里,可能会提到多个目标群体,因此作者会根据模型预测的群体中有多少是正确的来给出一个满分为1的分数。

微调后的模型在目标群体识别方面明显优于OpenAI。

不过,作者怀疑如果添加一些训练数据中没有的新群体,模型的表现可能会下降。

事件类型准确性

事件类型实际上是最难的类别之一,因为有些类别在语义上存在重叠,有时甚至连人工标注者也难以区分。

再一次,微调后的模型在这方面表现得相当好。

min_killed准确性

在这些数字估计任务中,微调模型和OpenAI模型的表现差距突然缩小了。

虽然Mistral依然表现最佳,但优势并不明显!而且OpenAI模型在这方面的表现非常出色,令人印象深刻。

作者猜测这是因为提示中有一整段解释了用于标注示例的标准:

标注说明:‘facilitator’ 不是领导者。如果新闻稿中提到‘叛乱分子’被拘留而没有进一步细节,则分配至少两名被拘留者。将‘a couple’解释为两人。将‘several’解释为至少三人,尽管有时可能指七或八人。将‘a few’、‘some’、‘a group’、‘a small group’和‘multiple’解释为至少三人,即使有时它们指代更多人数。如果新闻稿中没有其他信息来提供一个最低可接受的数字,请选择较小的数字。将‘numerous’和‘a handful’解释为至少四人,而‘a large number’解释为至少五人。

min_captured准确性

killq准确性

作者期望这些布尔属性的准确性非常高,基本上几乎所有模型都能达到这一点。

不过,微调后的Mistral仍然击败了GPT-4o的最佳成绩。

captureq准确性

killcaptureraid准确性

「kill-capture raid」是一种特定术语,在标注时以特定方式使用。

OpenAI对作者如何进行这些调用一无所知,这也解释了他们在这里表现不佳的原因。

airstrike准确性

noshotsfired准确性

「noshootsfired」属性指的是新闻稿中是否提到在某次突袭/事件中没有开枪。(在某段时间内,新闻稿特别喜欢提到这一点。)

作者不太确定为什么OpenAI模型的表现与预期相反。

对此,可以想到的一些半拟人化解释方式是,比如GPT-4类模型可能「过度思考」了这个标签,但需要更多的调查才能真正理解这一点。

min_leaders_killed准确性

我们经常听说LLM在处理数字时表现不佳,或者会默认使用某些值等等。但在这个任务中,所有模型都得到了如此高的分数。

这应该是大家一直在努力改进的地方,并且成效显著。不过,微调模型仍然表现最好。

min_leaders_captured准确性

最终得分

让我们将这些单项能力分数相加,取平均值,得到我们模型在准确性方面的最终分数。

甚至连作者自己也感到惊讶,微调模型竟然超过了OpenAI的GPT类模型。甚至,连TinyLlama都比GPT-3.5 Turbo表现更好!

其中,表现最好的是Mistral-7B(在OpenPipe上微调),紧随其后的是Solar LLM和Llama3-7B。

仅从分数来看,任何想为结构化数据提取微调模型的人都可以先从Mistral-7B、Solar 7B或Llama3-7B入手。它们在准确性方面可能都差不多,而在模型服务、效率和延迟方面可能有不同的权衡。

作者认为,虽然在提示中再加入一些示例(以及更多的解释和规则),就可以让OpenAI的模型表现得更好,但拥有自己的微调模型也会有其他好处:

  • 数据隐私(不会将你的信息发送给 OpenAI)
  • 更小的模型很可能意味着更好的性能(尽管我仍需测试和证明这一点)
  • 整体上更多的控制
  • 成本改进

在成本方面,现在进行比较或声明有点困难,尤其是考虑到大型云服务供应商可以依赖的规模经济。

但在一个需要长期重复推理构建这个模型的实际用例中,更有可能让成本论点成立。

特别是,因为唯一能让OpenAI推理调用变得更好的方法是填充大量示例和额外解释,这显著增加了每次查询的成本。

话虽如此,微调模型时确实会出现一些真实的权衡。

微调效果显著,但是……

首先,作者非常高兴地发现,常说的「微调你的模型,获得比GPT-4更好的性能」实际上是真的!不仅是真的,而且只需要相对较少的调整和适应。

要知道,上述所有模型都是用现有数据进行的第一次微调。基本上只是使用了所有默认设置,因此它们开箱即用。

对于下一步的工作,作者表示,他将专注于表现最好的Solar、Llama3和Mistral 7B模型。

评估过程很痛苦

作者有一些在本地工作的模型,还有一些部署在不同环境和不同服务中的其他模型。

让模型分布在不同地方确实很麻烦。在理想情况下,你会希望所有模型的推理都有一个标准接口,尤其是当它们用于相同的用例或项目时。很方便的是,作者的微调GPT3.5自动由OpenAI部署和服务,Llama3、Solar和Mistral也是如此,但作者希望有一个地方可以看到它们全部。

当你有多个模型在运行,并且你在微调和更新它们,数据也在不断变化时,你就需要一种管理这一切的方法。

这也是微调大语言模型的主要挑战之一,你必须管理所有这些东西,以确保它们可靠且可重复地工作。即使在项目的早期阶段,也需要一种方法来保持一切井然有序不出错。

但可以了解是否在进步

尽管评估的实施过程有些痛苦,但它们给了作者一个重要的工具,那就是他现在有了一种特定任务的方法,来判断训练数据或模型的任何改进或优化是否在帮助自己前进。没有这个工具,那基本上就是在盲目摸索。

下一步是什么

作者最初想要训练多个模型,让它们在各自的领域都成为超级专家,例如有一个模型非常擅长估算某个特定事件中有多少人被捕。

然而,看到模型的表现,作者不太确定下一步还应不应该这样做,或者说,不太确定自己是否真的能够通过这种方法显著提高准确性。

接下来要做的,是进行一些与准确性无关的测试评估。比如,看看模型在域外数据(即完全虚构的关于完全不同的事情的数据)上的表现如何。

另一件事,是深入研究模型服务的一些细节。作者想拿出他前三名表现最佳的模型,深入研究大语言模型服务是如何完成的。

但在进入下一步之前,还要等他的兴奋劲退却一些——「目前,我只是很高兴微调后的模型击败了GPT-4!」

免责声明:

① 本站未注明“稿件来源”的信息均来自网络整理。其文字、图片和音视频稿件的所属权归原作者所有。本站收集整理出于非商业性的教育和科研之目的,并不意味着本站赞同其观点或证实其内容的真实性。仅作为临时的测试数据,供内部测试之用。本站并未授权任何人以任何方式主动获取本站任何信息。

② 本站未注明“稿件来源”的临时测试数据将在测试完成后最终做删除处理。有问题或投稿请发送至: 邮箱/279061341@qq.com QQ/279061341

ML工程师一次微调七个模型,击败OpenAI GPT-4

下载Word文档到电脑,方便收藏和打印~

下载Word文档

猜你喜欢

ML工程师一次微调七个模型,击败OpenAI GPT-4

「微调你的模型,获得比GPT-4更好的性能」不只是说说而已,而是真的可操作。最近,一位愿意动手的ML工程师就把几个开源LLM调教成了自己想要的样子。
AI训练2024-11-29

微软仅凭「提示工程」让GPT-4成医学专家!超过一众高度微调模型,专业测试准确率首次超90%

在MedQA数据集(美国医师执照考试题)上,Medprompt让GPT-4的准确率首次超过90%,超越BioGPT和Med-PaLM等一众微调方法。
模型数据2024-11-30

编程热搜

  • Python 学习之路 - Python
    一、安装Python34Windows在Python官网(https://www.python.org/downloads/)下载安装包并安装。Python的默认安装路径是:C:\Python34配置环境变量:【右键计算机】--》【属性】-
    Python 学习之路 - Python
  • chatgpt的中文全称是什么
    chatgpt的中文全称是生成型预训练变换模型。ChatGPT是什么ChatGPT是美国人工智能研究实验室OpenAI开发的一种全新聊天机器人模型,它能够通过学习和理解人类的语言来进行对话,还能根据聊天的上下文进行互动,并协助人类完成一系列
    chatgpt的中文全称是什么
  • C/C++中extern函数使用详解
  • C/C++可变参数的使用
    可变参数的使用方法远远不止以下几种,不过在C,C++中使用可变参数时要小心,在使用printf()等函数时传入的参数个数一定不能比前面的格式化字符串中的’%’符号个数少,否则会产生访问越界,运气不好的话还会导致程序崩溃
    C/C++可变参数的使用
  • css样式文件该放在哪里
  • php中数组下标必须是连续的吗
  • Python 3 教程
    Python 3 教程 Python 的 3.0 版本,常被称为 Python 3000,或简称 Py3k。相对于 Python 的早期版本,这是一个较大的升级。为了不带入过多的累赘,Python 3.0 在设计的时候没有考虑向下兼容。 Python
    Python 3 教程
  • Python pip包管理
    一、前言    在Python中, 安装第三方模块是通过 setuptools 这个工具完成的。 Python有两个封装了 setuptools的包管理工具: easy_install  和  pip , 目前官方推荐使用 pip。    
    Python pip包管理
  • ubuntu如何重新编译内核
  • 改善Java代码之慎用java动态编译

目录