我的编程空间,编程开发者的网络收藏夹
学习永远不晚

pytorch实战---IMDB情感分析

短信预约 -IT技能 免费直播动态提醒
省份

北京

  • 北京
  • 上海
  • 天津
  • 重庆
  • 河北
  • 山东
  • 辽宁
  • 黑龙江
  • 吉林
  • 甘肃
  • 青海
  • 河南
  • 江苏
  • 湖北
  • 湖南
  • 江西
  • 浙江
  • 广东
  • 云南
  • 福建
  • 海南
  • 山西
  • 四川
  • 陕西
  • 贵州
  • 安徽
  • 广西
  • 内蒙
  • 西藏
  • 新疆
  • 宁夏
  • 兵团
手机号立即预约

请填写图片验证码后获取短信验证码

看不清楚,换张图片

免费获取短信验证码

pytorch实战---IMDB情感分析

💗💗💗欢迎来到我的博客,你将找到有关如何使用技术解决问题的文章,也会找到某个技术的学习路线。无论你是何种职业,我都希望我的博客对你有所帮助。最后不要忘记订阅我的博客以获取最新文章,也欢迎在文章下方留下你的评论和反馈。我期待着与你分享知识、互相学习和建立一个积极的社区。谢谢你的光临,让我们一起踏上这个知识之旅!
请添加图片描述

文章目录

🥦引言

本文使用IMDB数据集,结合pytorch进行情感分析

🥦完整代码

import torchimport torch.nn as nnimport torch.nn.functional as Ffrom sklearn.metrics import precision_score, recall_score, f1_score, accuracy_scorefrom torch import utilsimport torchtextfrom tqdm import tqdmfrom torchtext.datasets import IMDBfrom torchtext.datasets.imdb import NUM_LINESfrom torchtext.data import get_tokenizerfrom torchtext.vocab import build_vocab_from_iteratorfrom torchtext.data.functional import to_map_style_datasetimport osimport sysimport loggingimport logginglogging.basicConfig(    level=logging.WARN, stream=sys.stdout, format = "%(asctime)s (%(module)s:%(lineno)d) %(levelname)s: %(message)s")VOCAB_SIZE = 15000# step1 编写GCNN模型代码,门(Gate)卷积网络class GCNN(nn.Module):    def __init__(self, vocab_size=VOCAB_SIZE, embedding_dim=64, num_class=2):        super(GCNN, self).__init__()        self.embedding_table = nn.Embedding(vocab_size, embedding_dim)        nn.init.xavier_uniform_(self.embedding_table.weight)        # 都是1维卷积        self.conv_A_1 = nn.Conv1d(embedding_dim, 64, 15, stride=7)        self.conv_B_1 = nn.Conv1d(embedding_dim, 64, 15, stride=7)        self.conv_A_2 = nn.Conv1d(64, 64, 15, stride=7)        self.conv_B_2 = nn.Conv1d(64, 64, 15, stride=7)        self.output_linear1 = nn.Linear(64, 128)        self.output_linear2 = nn.Linear(128, num_class)    def forward(self, word_index):        """        定义GCN网络的算子操作流程,基于句子单词ID输入得到分类logits输出        """        # 1. 通过word_index得到word_embedding        # word_index shape: [bs, max_seq_len]        word_embedding = self.embedding_table(word_index)  # [bs, max_seq_len, embedding_dim]        # 2. 编写第一层1D门卷积模块,通道数在第2维        word_embedding = word_embedding.transpose(1, 2)  # [bs, embedding_dim, max_seq_len]        A = self.conv_A_1(word_embedding)        B = self.conv_B_1(word_embedding)        H = A * torch.sigmoid(B)  # [bs, 64, max_seq_len]        A = self.conv_A_2(H)        B = self.conv_B_2(H)        H = A * torch.sigmoid(B)  # [bs, 64, max_seq_len]        # 3. 池化并经过全连接层        pool_output = torch.mean(H, dim=-1)  # 平均池化,得到[bs, 4096]        linear1_output = self.output_linear1(pool_output)        # 最后一层需要设置为隐含层数目        logits = self.output_linear2(linear1_output)  # [bs, 2]        return logits# PyTorch官网的简单模型class TextClassificationModel(nn.Module):    """    简单版embedding.DNN模型    """    def __init__(self, vocab_size=VOCAB_SIZE, embed_dim=64, num_class=2):        super(TextClassificationModel, self).__init__()        self.embedding = nn.EmbeddingBag(vocab_size, embed_dim, sparse=False)        self.fc = nn.Linear(embed_dim, num_class)    def forward(self, token_index):        # 词袋        embedded = self.embedding(token_index)  # shape: [bs, embedding_dim]        return self.fc(embedded)# step2 构建IMDB DataloaderBATCH_SIZE = 64def yeild_tokens(train_data_iter, tokenizer):    for i, sample in enumerate(train_data_iter):        label, comment = sample        yield tokenizer(comment)  # 字符串转换为token索引的列表train_data_iter = IMDB(root="./data", split="train")  # Dataset类型的对象tokenizer = get_tokenizer("basic_english")# 只使用出现次数大约20的tokenvocab = build_vocab_from_iterator(yeild_tokens(train_data_iter, tokenizer), min_freq=20, specials=[""])vocab.set_default_index(0)  # 特殊索引设置为0print(f'单词表大小: len(vocab)')# 校对函数, batch是dataset返回值,主要是处理batch一组数据def collate_fn(batch):    """    对DataLoader所生成的mini-batch进行后处理    """    target = []    token_index = []    max_length = 0  # 最大的token长度    for i, (label, comment) in enumerate(batch):        tokens = tokenizer(comment)        token_index.append(vocab(tokens))  # 字符列表转换为索引列表        # 确定最大的句子长度        if len(tokens) > max_length:            max_length = len(tokens)        if label == "pos":            target.append(0)        else:            target.append(1)    token_index = [index + [0] * (max_length - len(index)) for index in token_index]    # one-hot接收长整形的数据,所以要转换为int64    return (torch.tensor(target).to(torch.int64), torch.tensor(token_index).to(torch.int32))# step3 编写训练代码def train(train_data_loader, eval_data_loader, model, optimizer, num_epoch, log_step_interval, save_step_interval,  eval_step_interval, save_path, resume=""):    """    此处data_loader是map-style dataset    """    start_epoch = 0    start_step = 0    if resume != "":        # 加载之前训练过的模型的参数文件        logging.warning(f"loading from resume")        checkpoint = torch.load(resume)        model.load_state_dict(checkpoint['model_state_dict'])        optimizer.load_state_dict(checkpoint['optimizer_state_dict'])        start_epoch = checkpoint['epoch']        start_step = checkpoint['step']    for epoch_index in tqdm(range(start_epoch, num_epoch), desc="epoch"):        ema_loss = 0        total_acc_account = 0        total_account = 0        true_labels = []        predicted_labels = []        num_batches = len(train_data_loader)        for batch_index, (target, token_index) in enumerate(train_data_loader):            optimizer.zero_grad()            step = num_batches * (epoch_index) + batch_index + 1            logits = model(token_index)            # one-hot需要转换float32才可以训练            bce_loss = F.binary_cross_entropy(torch.sigmoid(logits), F.one_hot(target, num_classes=2).to(torch.float32))            ema_loss = 0.9 * ema_loss + 0.1 * bce_loss  # 指数平均loss            bce_loss.backward()            nn.utils.clip_grad_norm_(model.parameters(), 0.1)  # 梯度的正则进行截断,保证训练稳定            optimizer.step()  # 更新参数            true_labels.extend(target.tolist())            predicted_labels.extend(torch.argmax(logits, dim=-1).tolist())            if step % log_step_interval == 0:                logging.warning(f"epoch_index: {epoch_index}, batch_index: {batch_index}, ema_loss: {ema_loss}")            if step % save_step_interval == 0:                os.makedirs(save_path, exist_ok=True)                save_file = os.path.join(save_path, f"step_{step}.pt")                torch.save({                    "epoch": epoch_index,                    "step": step,                    "model_state_dict": model.state_dict(),                    'optimizer_state_dict': optimizer.state_dict(),                    'loss': bce_loss                }, save_file)                logging.warning(f"checkpoint has been saved in {save_file}")            if step % save_step_interval == 0:                os.makedirs(save_path, exist_ok=True)                save_file = os.path.join(save_path, f"step_{step}.pt")                torch.save({                    "epoch": epoch_index,                    "step": step,                    "model_state_dict": model.state_dict(),                    'optimizer_state_dict': optimizer.state_dict(),                    'loss': bce_loss,                    'accuracy': accuracy,                    'precision': precision,                    'recall': recall,                    'f1': f1                }, save_file)                logging.warning(f"checkpoint has been saved in {save_file}")            if step % eval_step_interval == 0:                logging.warning("start to do evaluation...")                model.eval()                ema_eval_loss = 0                total_acc_account = 0                total_account = 0                true_labels = []                predicted_labels = []                for eval_batch_index, (eval_target, eval_token_index) in enumerate(eval_data_loader):                    total_account += eval_target.shape[0]                    eval_logits = model(eval_token_index)                    total_acc_account += (torch.argmax(eval_logits, dim=-1) == eval_target).sum().item()                    eval_bce_loss = F.binary_cross_entropy(torch.sigmoid(eval_logits),   F.one_hot(eval_target, num_classes=2).to(torch.float32))                    ema_eval_loss = 0.9 * ema_eval_loss + 0.1 * eval_bce_loss                    true_labels.extend(eval_target.tolist())                    predicted_labels.extend(torch.argmax(eval_logits, dim=-1).tolist())                accuracy = accuracy_score(true_labels, predicted_labels)                precision = precision_score(true_labels, predicted_labels)                recall = recall_score(true_labels, predicted_labels)                f1 = f1_score(true_labels, predicted_labels)                logging.warning(f"ema_eval_loss: {ema_eval_loss}, eval_acc: {total_acc_account / total_account}")                logging.warning(f"Precision: {precision}, Recall: {recall}, F1: {f1}, Accuracy: {accuracy}")                model.train()model = GCNN()# model = TextClassificationModel()print("模型总参数:", sum(p.numel() for p in model.parameters()))optimizer = torch.optim.Adam(model.parameters(), lr=0.001)train_data_iter = IMDB(root="data", split="train")  # Dataset类型的对象train_data_loader = torch.utils.data.DataLoader(    to_map_style_dataset(train_data_iter), batch_size=BATCH_SIZE, collate_fn=collate_fn, shuffle=True)eval_data_iter = IMDB(root="data", split="test")  # Dataset类型的对象# collate校对eval_data_loader = utils.data.DataLoader(    to_map_style_dataset(eval_data_iter), batch_size=8, collate_fn=collate_fn)# resume = "./data/step_500.pt"resume = ""train(train_data_loader, eval_data_loader, model, optimizer, num_epoch=10, log_step_interval=20, save_step_interval = 500, eval_step_interval = 300, save_path = "./log_imdb_text_classification2", resume = resume)

🥦代码分析

🥦导库

首先导入需要的库

import torchimport torch.nn as nnimport torch.nn.functional as Ffrom sklearn.metrics import precision_score, recall_score, f1_score, accuracy_scorefrom torch import utilsimport torchtextfrom tqdm import tqdmfrom torchtext.datasets import IMDB
  • torch (PyTorch):
    PyTorch 是一个用于机器学习和深度学习的开源深度学习框架。它提供了张量计算、自动微分、神经网络层和优化器等功能,使用户能够构建和训练深度学习模型。

  • torch.nn:
    torch.nn 模块包含了PyTorch中用于构建神经网络模型的类和函数。它包括各种神经网络层、损失函数和优化器等。

  • torch.nn.functional:
    torch.nn.functional 模块提供了一组函数,用于构建神经网络的非参数化操作,如激活函数、池化和卷积等。这些函数通常与torch.nn一起使用。

  • sklearn.metrics (scikit-learn):
    scikit-learn是一个用于机器学习的Python库,其中包含了一系列用于评估模型性能的度量工具。导入的precision_score、recall_score、f1_score 和 accuracy_score 用于计算分类模型的精确度、召回率、F1分数和准确性。

  • torch.utils:
    torch.utils 包含了一些实用工具和数据加载相关的函数。在这段代码中,它用于构建数据加载器。

  • torchtext:
    torchtext 是一个PyTorch的自然语言处理库,用于文本数据的处理和加载。它提供了用于文本数据预处理和构建数据集的功能。

  • tqdm:
    tqdm 是一个Python库,用于创建进度条,可用于监视循环迭代的进度。在代码中,它用于显示训练和评估的进度。

  • torchtext.datasets.IMDB:
    torchtext.datasets.IMDB 是TorchText库中的一个数据集,包含了IMDb电影评论的数据。这些评论用于情感分析任务,其中评论被标记为积极或消极。

🥦设置日志

logging.basicConfig(    level=logging.WARN, stream=sys.stdout, format="%(asctime)s (%(module)s:%(lineno)d) %(levelname)s: %(message)s")

在代码中设置日志的作用是记录程序的运行状态、调试信息和重要事件,以便在开发和生产环境中更轻松地诊断问题和了解程序的行为。设置日志有以下作用:

  • 问题诊断:当程序出现错误或异常时,日志记录可以提供有关错误发生的位置、原因和上下文的信息。这有助于开发人员快速定位和修复问题。

  • 性能分析:通过记录程序的运行时间和关键操作的时间戳,日志可以用于性能分析,帮助开发人员识别潜在的性能瓶颈。

  • 跟踪进度:在长时间运行的任务中,例如训练深度学习模型,日志记录可以帮助跟踪任务的进度,以便了解训练状态、完成的步骤和剩余时间。

  • 监控和警报:日志可以与监控系统集成,以便在发生关键事件或异常情况时触发警报。这对于及时响应问题非常重要。

  • 审计和合规:在某些应用中,日志记录是合规性的一部分,用于追踪系统的操作和用户的活动。日志可以用于审计和调查。

在上述代码中,设置日志的目的是跟踪训练进度、记录训练损失以及保存检查点。它允许开发人员监视模型训练的进展并在需要时查看详细信息,例如损失值和评估指标。此外,日志还可以用于调试和查看模型性能。

🥦模型定义

代码定义了两个模型:

GCNN:用于文本分类的门控卷积神经网络。TextClassificationModel:使用嵌入和线性层的简单文本分类模型。

🥦GCNN

class GCNN(nn.Module):    def __init__(self, vocab_size=VOCAB_SIZE, embedding_dim=64, num_class=2):        super(GCNN, self).__init__()        self.embedding_table = nn.Embedding(vocab_size, embedding_dim)        nn.init.xavier_uniform_(self.embedding_table.weight)        # 都是1维卷积        self.conv_A_1 = nn.Conv1d(embedding_dim, 64, 15, stride=7)        self.conv_B_1 = nn.Conv1d(embedding_dim, 64, 15, stride=7)        self.conv_A_2 = nn.Conv1d(64, 64, 15, stride=7)        self.conv_B_2 = nn.Conv1d(64, 64, 15, stride=7)        self.output_linear1 = nn.Linear(64, 128)        self.output_linear2 = nn.Linear(128, num_class)    def forward(self, word_index):        """        定义GCN网络的算子操作流程,基于句子单词ID输入得到分类logits输出        """        # 1. 通过word_index得到word_embedding        # word_index shape: [bs, max_seq_len]        word_embedding = self.embedding_table(word_index)  # [bs, max_seq_len, embedding_dim]        # 2. 编写第一层1D门卷积模块,通道数在第2维        word_embedding = word_embedding.transpose(1, 2)  # [bs, embedding_dim, max_seq_len]        A = self.conv_A_1(word_embedding)        B = self.conv_B_1(word_embedding)        H = A * torch.sigmoid(B)  # [bs, 64, max_seq_len]        A = self.conv_A_2(H)        B = self.conv_B_2(H)        H = A * torch.sigmoid(B)  # [bs, 64, max_seq_len]        # 3. 池化并经过全连接层        pool_output = torch.mean(H, dim=-1)  # 平均池化,得到[bs, 4096]        linear1_output = self.output_linear1(pool_output)        # 最后一层需要设置为隐含层数目        logits = self.output_linear2(linear1_output)  # [bs, 2]        return logits

🥦TextClassificationModel

class TextClassificationModel(nn.Module):    """    简单版embedding.DNN模型    """    def __init__(self, vocab_size=VOCAB_SIZE, embed_dim=64, num_class=2):        super(TextClassificationModel, self).__init__()        self.embedding = nn.EmbeddingBag(vocab_size, embed_dim, sparse=False)        self.fc = nn.Linear(embed_dim, num_class)    def forward(self, token_index):        # 词袋        embedded = self.embedding(token_index)  # shape: [bs, embedding_dim]        return self.fc(embedded)

🥦准备IMDb数据集

这行代码使用TorchText的IMDB数据集对象,导入IMDb数据集的训练集部分。

# 数据集导入train_data_iter = IMDB(root="./data", split="train")

这行代码创建了一个用于将文本分词为单词的分词器。

# 数据预处理tokenizer = get_tokenizer("basic_english")

这里,build_vocab_from_iterator 函数根据文本数据创建了一个词汇表,只包括出现频率大于等于20次的单词。特殊标记用于处理未知单词。然后,set_default_index将特殊标记的索引设置为0。

# 构建词汇表vocab = build_vocab_from_iterator(yeild_tokens(train_data_iter, tokenizer), min_freq=20, specials=[""])vocab.set_default_index(0)

这是一个自定义的校对函数,用于处理DataLoader返回的批次数据,将文本转换为可以输入模型的张量形式。

def collate_fn(batch):    """    对DataLoader所生成的mini-batch进行后处理    """    target = []    token_index = []    max_length = 0  # 最大的token长度    for i, (label, comment) in enumerate(batch):        tokens = tokenizer(comment)        token_index.append(vocab(tokens))  # 字符列表转换为索引列表        # 确定最大的句子长度        if len(tokens) > max_length:            max_length = len(tokens)        if label == "pos":            target.append(0)        else:            target.append(1)    token_index = [index + [0] * (max_length - len(index)) for index in token_index]    # one-hot接收长整形的数据,所以要转换为int64    return (torch.tensor(target).to(torch.int64), torch.tensor(token_index).to(torch.int32))

这行代码将IMDb训练数据集加载到DataLoader对象中,以便进行模型训练。collate_fn函数用于处理数据的批处理。

train_data_loader = torch.utils.data.DataLoader(    to_map_style_dataset(train_data_iter), batch_size=BATCH_SIZE, collate_fn=collate_fn, shuffle=True)

上述代码块执行了IMDb数据集的准备工作,包括导入数据、分词、构建词汇表和设置数据加载器。这些步骤是为了使数据集可用于训练文本分类模型。

🥦整理函数

这个 collate_fn 函数用于对 DataLoader 批次中的数据进行处理,确保每个批次中的文本序列具有相同的长度,并将标签转换为适用于模型输入的张量形式。它的工作包括以下几个方面:

提取标签和评论文本。使用分词器将评论文本分词为单词。确定批次中最长评论的长度。根据最长评论的长度,将所有评论的单词索引序列填充到相同的长度。将标签转换为适当的张量形式(这里是将标签转换为长整数型)。返回处理后的批次数据,其中包括标签和填充后的单词索引序列。

这个整理函数确保了模型在训练期间能够处理不同长度的文本序列,并将它们转换为模型可接受的张量输入。

🥦训练函数

def train(train_data_loader, eval_data_loader, model, optimizer, num_epoch, log_step_interval, save_step_interval,  eval_step_interval, save_path, resume=""):    """    此处data_loader是map-style dataset    """    start_epoch = 0    start_step = 0    if resume != "":        # 加载之前训练过的模型的参数文件        logging.warning(f"loading from resume")        checkpoint = torch.load(resume)        model.load_state_dict(checkpoint['model_state_dict'])        optimizer.load_state_dict(checkpoint['optimizer_state_dict'])        start_epoch = checkpoint['epoch']        start_step = checkpoint['step']    for epoch_index in tqdm(range(start_epoch, num_epoch), desc="epoch"):        ema_loss = 0        total_acc_account = 0        total_account = 0        true_labels = []        predicted_labels = []        num_batches = len(train_data_loader)        for batch_index, (target, token_index) in enumerate(train_data_loader):            optimizer.zero_grad()            step = num_batches * (epoch_index) + batch_index + 1            logits = model(token_index)            # one-hot需要转换float32才可以训练            bce_loss = F.binary_cross_entropy(torch.sigmoid(logits), F.one_hot(target, num_classes=2).to(torch.float32))            ema_loss = 0.9 * ema_loss + 0.1 * bce_loss  # 指数平均loss            bce_loss.backward()            nn.utils.clip_grad_norm_(model.parameters(), 0.1)  # 梯度的正则进行截断,保证训练稳定            optimizer.step()  # 更新参数            true_labels.extend(target.tolist())            predicted_labels.extend(torch.argmax(logits, dim=-1).tolist())            if step % log_step_interval == 0:                logging.warning(f"epoch_index: {epoch_index}, batch_index: {batch_index}, ema_loss: {ema_loss}")            if step % save_step_interval == 0:                os.makedirs(save_path, exist_ok=True)                save_file = os.path.join(save_path, f"step_{step}.pt")                torch.save({                    "epoch": epoch_index,                    "step": step,                    "model_state_dict": model.state_dict(),                    'optimizer_state_dict': optimizer.state_dict(),                    'loss': bce_loss                }, save_file)                logging.warning(f"checkpoint has been saved in {save_file}")            if step % save_step_interval == 0:                os.makedirs(save_path, exist_ok=True)                save_file = os.path.join(save_path, f"step_{step}.pt")                torch.save({                    "epoch": epoch_index,                    "step": step,                    "model_state_dict": model.state_dict(),                    'optimizer_state_dict': optimizer.state_dict(),                    'loss': bce_loss,                    'accuracy': accuracy,                    'precision': precision,                    'recall': recall,                    'f1': f1                }, save_file)                logging.warning(f"checkpoint has been saved in {save_file}")            if step % eval_step_interval == 0:                logging.warning("start to do evaluation...")                model.eval()                ema_eval_loss = 0                total_acc_account = 0                total_account = 0                true_labels = []                predicted_labels = []                for eval_batch_index, (eval_target, eval_token_index) in enumerate(eval_data_loader):                    total_account += eval_target.shape[0]                    eval_logits = model(eval_token_index)                    total_acc_account += (torch.argmax(eval_logits, dim=-1) == eval_target).sum().item()                    eval_bce_loss = F.binary_cross_entropy(torch.sigmoid(eval_logits),   F.one_hot(eval_target, num_classes=2).to(torch.float32))                    ema_eval_loss = 0.9 * ema_eval_loss + 0.1 * eval_bce_loss                    true_labels.extend(eval_target.tolist())                    predicted_labels.extend(torch.argmax(eval_logits, dim=-1).tolist())                accuracy = accuracy_score(true_labels, predicted_labels)                precision = precision_score(true_labels, predicted_labels)                recall = recall_score(true_labels, predicted_labels)                f1 = f1_score(true_labels, predicted_labels)                logging.warning(f"ema_eval_loss: {ema_eval_loss}, eval_acc: {total_acc_account / total_account}")                logging.warning(f"Precision: {precision}, Recall: {recall}, F1: {f1}, Accuracy: {accuracy}")                model.train()

这段代码定义了一个名为 train 的函数,用于执行训练过程。下面是该函数的详细说明:

train 函数接受以下参数:    train_data_loader: 训练数据的 DataLoader,用于迭代训练数据。    eval_data_loader: 用于评估的 DataLoader,用于评估模型性能。    model: 要训练的神经网络模型。    optimizer: 用于更新模型参数的优化器。    num_epoch: 训练的总周期数。    log_step_interval: 记录日志的间隔步数。    save_step_interval: 保存模型检查点的间隔步数。    eval_step_interval: 执行评估的间隔步数。    save_path: 保存模型检查点的目录。    resume: 可选的,用于恢复训练的检查点文件路径。训练函数的主要工作如下:    它首先检查是否有恢复训练的检查点文件。如果有,它会加载之前训练的模型参数和优化器状态,以便继续训练。    然后,它开始进行一系列的训练周期(epochs),每个周期内包含多个训练步(batches)。    在每个训练步中,它执行以下操作:        零化梯度,以准备更新模型参数。        计算模型的预测输出(logits)。        计算二进制交叉熵损失(binary cross-entropy loss)。        使用反向传播(backpropagation)计算梯度并更新模型参数。        记录损失、真实标签和预测标签。        如果步数达到了 log_step_interval,则记录损失。        如果步数达到了 save_step_interval,则保存模型检查点。        如果步数达到了 eval_step_interval,则执行评估:    将模型切换到评估模式(model.eval())。    对评估数据集中的每个批次执行以下操作:        计算模型的预测输出。        计算二进制交叉熵损失。        计算准确性、精确度、召回率和F1分数。        记录评估损失和评估指标。    将模型切换回训练模式(model.train())。最后,训练函数返回经过训练的模型。

这个训练函数执行了完整的训练过程,包括了模型的前向传播、损失计算、梯度更新、日志记录、模型检查点的保存和评估。通过调用这个函数,你可以训练模型并监视其性能。

🥦模型初始化和优化器

model = GCNN()# model = TextClassificationModel()optimizer = torch.optim.Adam(model.parameters(), lr=0.001)

🥦加载用于训练和评估的数据

在提供的代码中,加载用于训练和评估的数据的部分如下:

train_data_iter = IMDB(root="data", split="train")

这一行代码使用 TorchText 的 IMDB 数据集对象,导入 IMDB 数据集的训练集部分。这部分数据将用于模型的训练。

eval_data_iter = IMDB(root="data", split="test")

这一行代码使用 TorchText 的 IMDB 数据集对象,导入 IMDB 数据集的测试集部分。这部分数据将用于评估模型的性能。


之后,这些数据集通过以下代码转化为 DataLoader 对象,以便用于模型训练和评估:

# 训练数据 DataLoadertrain_data_loader = torch.utils.data.DataLoader(    to_map_style_dataset(train_data_iter), batch_size=BATCH_SIZE, collate_fn=collate_fn, shuffle=True)
# 评估数据 DataLoadereval_data_loader = utils.data.DataLoader(     to_map_style_dataset(eval_data_iter), batch_size=8, collate_fn=collate_fn)

这些 DataLoader 对象将数据加载到内存中,以便训练和评估使用。collate_fn 函数用于处理数据的批次,确保它们具有适当的格式,以便输入到模型中。

这些部分负责加载和准备用于训练和评估的数据,是机器学习模型训练和评估的重要准备步骤。训练数据用于训练模型,而评估数据用于评估模型的性能。

🥦恢复训练

start_epoch = 0start_step = 0if resume != "":    # 加载之前训练过的模型的参数文件    logging.warning(f"loading from resume")    checkpoint = torch.load(resume)    model.load_state_dict(checkpoint['model_state_dict'])    optimizer.load_state_dict(checkpoint['optimizer_state_dict'])    start_epoch = checkpoint['epoch']    start_step = checkpoint['step']

上述代码段位于训练函数中的开头部分,主要用于检查是否有已经训练过的模型的检查点文件,以便继续训练。具体解释如下:

如果 resume 变量不为空(即存在要恢复的检查点文件路径),则执行以下操作:
通过 torch.load 加载之前训练过的模型的检查点文件。
使用 load_state_dict 方法将已保存的模型参数加载到当前的模型中,以便继续训练。
同样,使用 load_state_dict 方法将已保存的优化器状态加载到当前的优化器中,以确保继续从之前的状态开始训练。
获取之前训练的轮数和步数,以便从恢复的状态继续训练。

这部分代码的目的是允许从之前保存的模型检查点继续训练,而不是从头开始。这对于长时间运行的训练任务非常有用,可以在中途中断训练并在之后恢复,而不会丢失之前的训练进度。

🥦调用训练

train(train_data_loader, eval_data_loader, model, optimizer, num_epoch=10, log_step_interval=20, save_step_interval=500, eval_step_interval=300, save_path="./log_imdb_text_classification2", resume=resume)

🥦保存文件的读取

import torch# 指定已存在的 .pt 文件路径file_path = "./log_imdb_text_classification/step_3500.pt"  # 替换为实际的文件路径# 使用 torch.load() 加载文件checkpoint = torch.load(file_path)# 查看准确率、精确率、召回率和F1分数accuracy = checkpoint["accuracy"]precision = checkpoint["precision"]recall = checkpoint["recall"]f1 = checkpoint["f1"]print("Accuracy:", accuracy)print("Precision:", precision)print("Recall:", recall)print("F1 Score:", f1)

在这里插入图片描述

🥦扩展 LSTM、GRU

本文原作者使用的是卷积神经网络,但是卷积神经网络的优化模型GCNN,但是这个模型对于图更好,由此我接下来引入两个循环神经网络LSTM和GRU

class LSTMModel(nn.Module):    def __init__(self, vocab_size=VOCAB_SIZE, embedding_dim=64, hidden_dim=64, num_class=2):        super(LSTMModel, self).__init__()        self.embedding_table = nn.Embedding(vocab_size, embedding_dim)        self.lstm = nn.LSTM(embedding_dim, hidden_dim, num_layers=1, batch_first=True)        self.output_linear = nn.Linear(hidden_dim, num_class)    def forward(self, word_index):        word_embedding = self.embedding_table(word_index)        lstm_out, _ = self.lstm(word_embedding)        lstm_out = lstm_out[:, -1, :]  # 取最后一个时间步的输出        logits = self.output_linear(lstm_out)        return logitsclass GRUModel(nn.Module):    def __init__(self, vocab_size=VOCAB_SIZE, embedding_dim=64, hidden_dim=64, num_class=2):        super(GRUModel, self).__init__()        self.embedding_table = nn.Embedding(vocab_size, embedding_dim)        self.gru = nn.GRU(embedding_dim, hidden_dim, num_layers=1, batch_first=True)        self.output_linear = nn.Linear(hidden_dim, num_class)    def forward(self, word_index):        word_embedding = self.embedding_table(word_index)        gru_out, _ = self.gru(word_embedding)        gru_out = gru_out[:, -1, :]  # 取最后一个时间步的输出        logits = self.output_linear(gru_out)        return logits
# 创建LSTM模型lstm_model = LSTMModel()print("模型总参数:", sum(p.numel() for p in lstm_model.parameters()))lstm_optimizer = torch.optim.Adam(lstm_model.parameters(), lr=0.001)# 创建GRU模型# gru_model = GRUModel()# print("模型总参数:", sum(p.numel() for p in gru_model.parameters()))# gru_optimizer = torch.optim.Adam(gru_model.parameters(), lr=0.001)
# 训练LSTM模型train(train_data_loader, eval_data_loader, lstm_model, lstm_optimizer, num_epoch=10, log_step_interval=20, save_step_interval=500, eval_step_interval=300, save_path="./log_imdb_lstm", resume="")# 训练GRU模型# train(train_data_loader, eval_data_loader, gru_model, gru_optimizer, num_epoch=10, log_step_interval=20, save_step_interval=500, eval_step_interval=300, save_path="./log_imdb_gru", resume="")

感兴趣的小伙伴可以试试,对比一下

🥦总结

本文代码来自网络仅供学习,原文地址

请添加图片描述

挑战与创造都是很痛苦的,但是很充实。

来源地址:https://blog.csdn.net/null18/article/details/133984043

免责声明:

① 本站未注明“稿件来源”的信息均来自网络整理。其文字、图片和音视频稿件的所属权归原作者所有。本站收集整理出于非商业性的教育和科研之目的,并不意味着本站赞同其观点或证实其内容的真实性。仅作为临时的测试数据,供内部测试之用。本站并未授权任何人以任何方式主动获取本站任何信息。

② 本站未注明“稿件来源”的临时测试数据将在测试完成后最终做删除处理。有问题或投稿请发送至: 邮箱/279061341@qq.com QQ/279061341

pytorch实战---IMDB情感分析

下载Word文档到电脑,方便收藏和打印~

下载Word文档

猜你喜欢

【Python NLTK】实战案例:情感分析,洞察用户情绪

情感分析是自然语言处理的重要分支,旨在理解和识别文本中的情绪和情感。本文将使用Python NLTK库来实现情感分析,演示如何洞察用户的情绪,并提供演示代码。
【Python NLTK】实战案例:情感分析,洞察用户情绪
2024-02-24

pytorch 实现情感分类问题小结

本文主要介绍了pytorch 实现情感分类问题,文中通过示例代码介绍的非常详细,对大家的学习或者工作具有一定的参考学习价值,需要的朋友们下面随着小编来一起学习学习吧
2023-02-14

什么是情感分析?

情感分析是一种人工智能技术,通过分析文本或语音,识别、提取和量化其中的情绪。它使用机器学习算法将文本内容分类或评级为积极、消极或中立等情绪。该技术广泛应用于市场研究、社交媒体监测、客服、医疗保健和金融领域,可以通过自动化分析、客观性以及提供深入见解来帮助企业做出明智的决策。
什么是情感分析?
2024-04-02

Python基于jieba分词实现snownlp情感分析

情感分析(sentimentanalysis)是2018年公布的计算机科学技术名词,它可以根据文本内容判断出所代表的含义是积极的还是负面的等。本文将通过jieba分词实现snownlp情感分析,感兴趣的可以了解一下
2023-01-30

Tensorflow2.1实现文本中情感分类实现解析

这篇文章主要为大家介绍了Tensorflow2.1实现文本中情感分类实现解析,有需要的朋友可以借鉴参考下,希望能够有所帮助,祝大家多多进步,早日升职加薪
2022-11-21

AI核心难点之一:情感分析的常见类型与挑战

编程学习网:情感分析或情感人工智能,在商业应用中通常被称为意见挖掘,是自然语言处理(NLP)的一个非常流行的应用。文本处理是该技术最大的分支,但并不是唯一的分支。情绪AI有三种类型及其组合。
AI核心难点之一:情感分析的常见类型与挑战
2024-04-23

怎样用Python代码做情感分析

本篇文章为大家展示了怎样用Python代码做情感分析,内容简明扼要并且容易理解,绝对能使你眼前一亮,通过这篇文章的详细介绍希望你能有所收获。一台可以上网的电脑基本的python代码阅读能力,用于修改几个模型参数对百度中文NLP最新成果的浓烈
2023-06-16

如何利用python实现简单的情感分析

今天小编给大家分享一下如何利用python实现简单的情感分析的相关知识点,内容详细,逻辑清晰,相信大部分人都还太了解这方面的知识,所以分享这篇文章给大家参考一下,希望大家阅读完这篇文章后有所收获,下面我们一起来了解一下吧。1 数据导入及预处
2023-07-02

Python底层技术揭秘:如何实现情感分析

Python底层技术揭秘:如何实现情感分析,需要具体代码示例引言:随着社交媒体的普及和大数据时代的到来,情感分析成为了一个被广泛关注和应用的领域。情感分析可以帮助我们理解和分析用户的情感和意见,从而对产品、服务或市场做出更合理的决策。Pyt
Python底层技术揭秘:如何实现情感分析
2023-11-08

如何利用ChatGPT和Python实现情感分析功能

如何利用ChatGPT和Python实现情感分析功能介绍ChatGPTChatGPT是OpenAI于2021年发布的一种基于强化学习的生成式预训练模型,它采用了强大的语言模型来生成连贯的对话。ChatGPT可以用于各种任务,包括情感分析。导
2023-10-24

PHP 开发中 Elasticsearch 实现文本挖掘与情感分析

近年来,随着互联网的快速发展,海量的文本数据被不断产生。这些文本数据蕴含着丰富的信息,对于企业来说,通过对文本数据的挖掘与分析,可以获取用户需求、产品意见、市场趋势等有价值的信息。而Elasticsearch作为一种分布式搜索引擎,具有擅长
2023-10-21

计算机竞赛 基于GRU的 电影评论情感分析 - python 深度学习 情感分类

文章目录 1 前言1.1 项目介绍 2 情感分类介绍3 数据集4 实现4.1 数据预处理4.2 构建网络4.3 训练模型4.4 模型评估4.5 模型预测 5 最后 1 前言 🔥 优质竞赛项目系列,今天要分
2023-08-30

如何利用ChatGPT和Python实现对话情感分析功能

如何利用ChatGPT和Python实现对话情感分析功能引言:随着人工智能和自然语言处理的快速发展,对话情感分析成为了一个备受关注的研究领域。ChatGPT作为一个先进的生成式对话模型,为我们提供了一个很好的工具来实现对话情感分析。本文将介
2023-10-24

Golang在舆情监测与分析中的实战应用

在舆情监测与分析中,golang 的应用主要体现在:数据采集:从多种线上来源收集数据。数据清洗:去除冗余和错误数据。分析:采用机器学习算法进行情绪分析和主题提取,识别关键影响者。可视化:创建图表和仪表板展示分析结果。优势包括并发处理、高效率
Golang在舆情监测与分析中的实战应用
2024-05-12

自然语言处理NLPTextRNN实现情感分类

这篇文章主要为大家介绍了自然语言处理NLPTextRNN实现情感分类示例解析,有需要的朋友可以借鉴参考下,希望能够有所帮助,祝大家多多进步,早日升职加薪
2023-05-17

编程热搜

  • Python 学习之路 - Python
    一、安装Python34Windows在Python官网(https://www.python.org/downloads/)下载安装包并安装。Python的默认安装路径是:C:\Python34配置环境变量:【右键计算机】--》【属性】-
    Python 学习之路 - Python
  • chatgpt的中文全称是什么
    chatgpt的中文全称是生成型预训练变换模型。ChatGPT是什么ChatGPT是美国人工智能研究实验室OpenAI开发的一种全新聊天机器人模型,它能够通过学习和理解人类的语言来进行对话,还能根据聊天的上下文进行互动,并协助人类完成一系列
    chatgpt的中文全称是什么
  • C/C++中extern函数使用详解
  • C/C++可变参数的使用
    可变参数的使用方法远远不止以下几种,不过在C,C++中使用可变参数时要小心,在使用printf()等函数时传入的参数个数一定不能比前面的格式化字符串中的’%’符号个数少,否则会产生访问越界,运气不好的话还会导致程序崩溃
    C/C++可变参数的使用
  • css样式文件该放在哪里
  • php中数组下标必须是连续的吗
  • Python 3 教程
    Python 3 教程 Python 的 3.0 版本,常被称为 Python 3000,或简称 Py3k。相对于 Python 的早期版本,这是一个较大的升级。为了不带入过多的累赘,Python 3.0 在设计的时候没有考虑向下兼容。 Python
    Python 3 教程
  • Python pip包管理
    一、前言    在Python中, 安装第三方模块是通过 setuptools 这个工具完成的。 Python有两个封装了 setuptools的包管理工具: easy_install  和  pip , 目前官方推荐使用 pip。    
    Python pip包管理
  • ubuntu如何重新编译内核
  • 改善Java代码之慎用java动态编译

目录