trainer使用 torch.utils.data 的 Dataset

Dataset适配

在 Transformers实战——使用Trainer类训练和评估自己的数据和模型 中,使用的数据集类别是from datasets import Dataset,但在常用的实现中,pytorch项目会使用torch.utils.data 中的数据集。

为了让pytorch中的Dataset类适配transformers中的Trainer ,需要对类别中的def __getitem__(self, idx) 方法进行修改,如下:

    def __getitem__(self, idx):
		    # 具体实现,可以自由发挥
        text = self.texts[idx]
        label = self.labels[idx]
        encoding = self.tokenizer.encode_plus(
            text,
            add_special_tokens=True,
            return_token_type_ids=False,
            truncation=True,
            return_attention_mask=True,
            return_tensors='pt',
        )
        
        # 下面的返回要改成字典形式
        return {
            'input_ids': encoding['input_ids'].flatten(),
            'attention_mask': encoding['attention_mask'].flatten(),
            'labels': torch.tensor(label, dtype=torch.long)
        }

简单来说就是把返回改成字典形式,且字典的键(key)需要与模型的forward对应。

DataCollator

上面的改动可以控制每一个样本的返回值,但模型训练时的样本往往是一批一批的。如果每个样本的输入形状(shape)固定不变的话,到上一步模型就能正常运行了。

有些时候每个样本的输入形状不一样,如文本分类(每篇文章字数不同),就会导致训练不能正常进行。

这种情况一般有两种方法(以文本分类为例):

  1. 修改__getitem__ :添加一个最大长度 max_length 把每个文本扩展(0填充)或截取到固定长度(如512)。
  2. 使用transformes中的DataCollator 类,如文本填充常用的DataCollatorWithPadding 。这种方式不用修改Dataset类的源码。

使用DataCollator 的样例如下:

from transformers import DataCollatorWithPadding

# 这里的tokenizer变量可以传入BertTokenizer类
data_collator = DataCollatorWithPadding(tokenizer=tokenizer, return_tensors='pt')

    trainer = Trainer(
        model=model,  # the instantiated 🤗 Transformers model to be trained 需要训练的模型
        args=training_args,  # training arguments, defined above 训练参数
        train_dataset=train_dataset,  # training dataset 训练集
        eval_dataset=dev_dataset,  # evaluation dataset 测试集
        compute_metrics=compute_metrics,  # 计算指标方法
        ########
        data_collator=data_collator, # 传入DataCollator 
        ########
    )

这种方法可以动态地根据每一批次的最大文本长度进行补全,可以一定程度节省内存消耗。

本文来自互联网用户投稿,该文观点仅代表作者本人,不代表本站立场。本站仅提供信息存储空间服务,不拥有所有权,不承担相关法律责任。如若转载,请注明出处:http://www.mfbz.cn/a/766430.html

如若内容造成侵权/违法违规/事实不符,请联系我们进行投诉反馈qq邮箱809451989@qq.com,一经查实,立即删除!

相关文章

【QT】多元素控件

目录 概述 List Widget 核心属性 核心方法 核心信号 QListWidgetItem核心方法 Table Widget 核心方法 QTableWidgetItem 核心信号 QTableWidgetItem 核心方法 使用示例: Tree Widget 核心方法 核心信号 QTreeWidgetItem核心属性 QTreeWidgetItem核…

(九)绘制彩色三角形

前面的学习中并未涉及到颜色&#xff0c;现在打算写一个例子&#xff0c;在顶点着色器和片元着色器中加入颜色&#xff0c;绘制有颜色的三角形。 #include <glad/glad.h>//glad必须在glfw头文件之前包含 #include <GLFW/glfw3.h> #include <iostream>void …

13-4 GPT-5:博士级AI,人工智能的新时代

图片来源&#xff1a;AI Disruptive 人工智能世界正在迅速发展&#xff0c;新的创新和突破层出不穷。在本文中&#xff0c;我们将深入探讨最新的进展&#xff0c;从即将推出的 GPT-5 模型到 Apple 和 Meta 之间可能的合作。 GPT-5&#xff1a;博士级别的人工智能 虽然尚未正…

GL823K USB 2.0 SD/MSPRO读卡器控制芯片

概述 GL823K是一个USB 2.0单轮读卡器控制芯片&#xff0c;可以支持SD/MMC/MSPRO闪存卡。它支持USB 2.0高速传输&#xff0c;它在一个芯片上可以控制读取诸如安全数字卡&#xff08;SD卡&#xff09;&#xff0c;SDHC卡&#xff0c;迷你SD卡&#xff0c;微SD卡&#xff08;T-Fl…

Upload-Labs靶场闯关

文章目录 Pass-01Pass-02Pass-03Pass-04Pass-05Pass-06Pass-07Pass-08Pass-09Pass-10Pass-11Pass-12Pass-13Pass-14Pass-15Pass-16Pass-17Pass-18Pass-19Pass-20 以下是文件上传绕过的各种思路&#xff0c;不过是鄙人做题记下来的一些思路笔记罢了。 GitHub靶场环境下载&#x…

带电池监控功能的恒流直流负载组

EAK的交流和直流工业电池负载组测试仪对于测试和验证关键电力系统的能力至关重要&#xff0c;旨在实现最佳精度。作为一家客户至上的公司&#xff0c;我们继续尽我们所能应对供应链挑战&#xff0c;以提供出色的交货时间&#xff0c;大约是行业其他公司的一半。 交流负载组 我…

嵌入式c语言2——预处理

在c语言中&#xff0c;头部内容&#xff0c;如include与define是不参与编译而直接预先处理的 如include相当于把头文件扩展&#xff0c;define相当于做了替换 c语言大型工程创建时&#xff0c;会有调试版本与发行版本&#xff0c;发行时不希望看到调试部分内容&#xff0c;此时…

如何使用 Builder 设计模式和 DataFaker 库在自动化测试中生成测试数据

在自动化 API/Web 或移动应用程序时&#xff0c;您可能会遇到这样的情况&#xff1a;在注册用户时&#xff0c;您可能正在设置用于在测试自动化的端到端用户旅程中签出产品的地址。 那么&#xff0c;你是怎么做到的呢&#xff1f; 通常&#xff0c;我们在 Java 中创建一个 POJO…

鸿蒙开发设备管理:【@ohos.distributedHardware.deviceManager (设备管理)】

设备管理 本模块提供分布式设备管理能力。 系统应用可调用接口实现如下功能&#xff1a; 注册和解除注册设备上下线变化监听发现周边不可信设备认证和取消认证设备查询可信设备列表查询本地设备信息&#xff0c;包括设备名称&#xff0c;设备类型和设备标识 说明&#xff1a…

检索增强生成RAG系列5--RAG提升之路由(routing)

在系列3和系列4我讲了关于一个基本流程下&#xff0c;RAG的提高准确率的关键点&#xff0c;那么接下来&#xff0c;我们再次讲解2个方面&#xff0c;这2个方面可能与RAG的准确率有关系&#xff0c;但是更多的它们是有其它用途。本期先来讲解RAG路由。 目录 1 基本思想2 Logica…

Linux基础 - LNMP 架构部署动态网站环境

目录 零. 简介 一. 部署 二. 安装 Nginx 三. 安装MySQL 四. 安装PHP 五. 配置网站目录 六. 测试环境 零. 简介 LNMP 是指 Linux Nginx MySQL PHP 这一组合架构&#xff0c;用于部署动态网站环境。 Linux 作为操作系统&#xff0c;提供了稳定、安全和高效的基础平台。…

Swift 中强大的 Key Paths(键路径)机制趣谈(上)

概览 小伙伴们可能不知道&#xff1a;在 Swift 语言中隐藏着大量看似“其貌不扬”实则却让秃头码农们“高世骇俗”&#xff0c;堪称卧虎藏龙的各种秘技。 其中&#xff0c;有一枚“不起眼”的小家伙称之为键路径&#xff08;Key Paths&#xff09;。如若将其善加利用&#xff…

MYSQL函数进阶详解:案例解析(第19天)

系列文章目录 一、MySQL的函数&#xff08;重点&#xff09; 二、MySQL的窗口函数&#xff08;重点&#xff09; 三、MySQL的视图&#xff08;熟悉&#xff09; 四、MySQL的事务&#xff08;熟悉&#xff09; 文章目录 系列文章目录前言一、MySQL的函数1. 聚合函数2. group_c…

Linux基础 - MariaDB 数据库管理系统

目录 零. 简介 一. 安装 二. 基本使用 1. 设置root密码 2. 创建库 3. 创建表 4.添加数据 5. 查看数据 三. 管理表单及数据 四. 数据库的备份及恢复 零. 简介 MariaDB 是一种流行的开源数据库管理系统&#xff0c;它是 MySQL 的一个分支。 MariaDB 保留了与 MySQL 的…

HarmonyOS APP应用开发项目- MCA助手(Day01持续更新中~)

简言&#xff1a; gitee地址&#xff1a;https://gitee.com/whltaoin_admin/money-controller-app.git端云一体化开发在线文档&#xff1a;https://developer.huawei.com/consumer/cn/doc/harmonyos-guides-V5/agc-harmonyos-clouddev-view-0000001700053733-V5 注&#xff1…

激光粒度分析仪校准步骤详解:提升测量精度的秘诀

在材料科学、环境监测、医药研发等众多领域&#xff0c;激光粒度分析仪以其高精度、高效率的测量性能&#xff0c;成为了不可或缺的测试工具。然而&#xff0c;为了保持其测量结果的准确性和可靠性&#xff0c;定期校准是不可或缺的步骤。 接下来&#xff0c;佰德将为您详细介…

可视化低代码平台之:RayData光启元的震撼作品。

RayData家的可视化作品&#xff0c;贝格前端工场是经常碰到&#xff0c;制作十分的精良&#xff0c;业内很有影响力。他们也有自己的低代码平台&#xff0c;分为了桌面版和网页版&#xff0c;本期分享一下他们的作品。

【单片机毕业设计选题24043】-可旋转式电视支架控制系统设计与实现

系统功能: 系统操作说明&#xff1a; 上电后OLED显示 “欢迎使用电视支架系统请稍后”&#xff0c;两秒后进入正常界面显示 第一页面第一行显示 Mode:Key&#xff0c; 第二行显示 TV:Middle 短按B5按键可控制步进电机左转&#xff0c; 第二行显示 TV:Left 后正常显示 TV:…

六、资产安全—信息分级资产管理与隐私保护练习题(CISSP)

六、资产安全—信息分级资产管理与隐私保护(CISSP): 六、资产安全—信息分级资产管理与隐私保护(C

语义检索-BAAI Embedding语义向量模型深度解析:微调Cross-Encoder以提升语义检索精度

语义检索-BAAI Embedding语义向量模型深度解析:微调Cross-Encoder以提升语义检索精度 语义向量模型(Embedding Model)已经被广泛应用于搜索、推荐、数据挖掘等重要领域。在大模型时代,它更是用于解决幻觉问题、知识时效问题、超长文本问题等各种大模型本身制约或不足的必要…
最新文章