我的编程空间,编程开发者的网络收藏夹
学习永远不晚

Pytorch中torch.cat()函数解析

短信预约 -IT技能 免费直播动态提醒
省份

北京

  • 北京
  • 上海
  • 天津
  • 重庆
  • 河北
  • 山东
  • 辽宁
  • 黑龙江
  • 吉林
  • 甘肃
  • 青海
  • 河南
  • 江苏
  • 湖北
  • 湖南
  • 江西
  • 浙江
  • 广东
  • 云南
  • 福建
  • 海南
  • 山西
  • 四川
  • 陕西
  • 贵州
  • 安徽
  • 广西
  • 内蒙
  • 西藏
  • 新疆
  • 宁夏
  • 兵团
手机号立即预约

请填写图片验证码后获取短信验证码

看不清楚,换张图片

免费获取短信验证码

Pytorch中torch.cat()函数解析

一. torch.cat()函数解析

1. 函数说明

1.1 官网torch.cat(),函数定义及参数说明如下图所示:
函数定义及参数说明
1.2 函数功能
函数将两个张量(tensor)按指定维度拼接在一起,注意:除拼接维数dim数值可不同外其余维数数值需相同,方能对齐,如下面例子所示。torch.cat()函数不会新增维度,而torch.stack()函数会新增一个维度,相同的是两个都是对张量进行拼接

2. 代码举例

2.1 输入两个二维张量(dim=0):dim=0对行进行拼接

a = torch.randn(2,3)b =  torch.randn(3,3)c = torch.cat((a,b),dim=0)a,b,c
输出结果如下:(tensor([[-0.90, -0.37,  1.96],         [-2.65, -0.60,  0.05]]), tensor([[ 1.30,  0.24,  0.27],         [-1.99, -1.09,  1.67],         [-1.62,  1.54, -0.14]]), tensor([[-0.90, -0.37,  1.96],         [-2.65, -0.60,  0.05],         [ 1.30,  0.24,  0.27],         [-1.99, -1.09,  1.67],         [-1.62,  1.54, -0.14]]))

2.2 输入两个二维张量(dim=1): dim=1对列进行拼接

a = torch.randn(2,3)b =  torch.randn(2,4)c = torch.cat((a,b),dim=1)a,b,c
输出结果如下:(tensor([[-0.55, -0.84, -1.60],         [ 0.39, -0.96,  1.02]]), tensor([[-0.83, -0.09,  0.05,  0.17],         [ 0.28, -0.74, -0.27, -0.85]]), tensor([[-0.55, -0.84, -1.60, -0.83, -0.09,  0.05,  0.17],         [ 0.39, -0.96,  1.02,  0.28, -0.74, -0.27, -0.85]]))

2.3 输入两个三维张量:dim=0 对通道进行拼接

a = torch.randn(2,3,4)b =  torch.randn(1,3,4)c = torch.cat((a,b),dim=0)a,b,c
输出结果如下:(tensor([[[ 0.51, -0.72, -0.02,  0.76],          [ 0.72,  1.01,  0.39, -0.13],          [ 0.37, -0.63, -2.69,  0.74]],          [[ 0.72, -0.31, -0.27,  0.10],          [ 1.66, -0.06,  1.91, -0.66],          [ 0.34, -0.23, -0.18, -1.22]]]), tensor([[[ 0.94,  0.77, -0.41, -1.20],          [-0.23, -1.03, -0.25,  1.67],          [-1.00, -0.68, -0.35, -0.50]]]), tensor([[[ 0.51, -0.72, -0.02,  0.76],          [ 0.72,  1.01,  0.39, -0.13],          [ 0.37, -0.63, -2.69,  0.74]],          [[ 0.72, -0.31, -0.27,  0.10],          [ 1.66, -0.06,  1.91, -0.66],          [ 0.34, -0.23, -0.18, -1.22]],          [[ 0.94,  0.77, -0.41, -1.20],          [-0.23, -1.03, -0.25,  1.67],          [-1.00, -0.68, -0.35, -0.50]]]))

2.4 输入两个三维张量:dim=1对行进行拼接

a = torch.randn(2,3,4)b =  torch.randn(2,4,4)c = torch.cat((a,b),dim=1)a,b,c
输出结果如下:(tensor([[[-0.86,  0.00, -1.26,  1.20],          [-0.46, -1.08, -0.82,  2.03],          [-0.89,  0.43,  1.92,  0.49]],          [[ 0.24, -0.02,  0.32,  0.97],          [ 0.33, -1.34,  0.76, -1.55],          [ 0.38,  1.45,  0.27, -0.64]]]), tensor([[[ 0.82,  0.85, -0.30, -0.58],          [-0.09,  0.40,  0.02,  0.75],          [-0.70,  0.67, -0.88, -0.50],          [-0.62, -1.65, -1.10, -1.39]],          [[-0.85, -1.61, -0.35, -0.56],          [ 0.00,  1.40,  0.41,  0.39],          [-0.01,  0.04,  0.80,  0.41],          [-1.21, -0.64,  1.14,  1.64]]]), tensor([[[-0.86,  0.00, -1.26,  1.20],          [-0.46, -1.08, -0.82,  2.03],          [-0.89,  0.43,  1.92,  0.49],          [ 0.82,  0.85, -0.30, -0.58],          [-0.09,  0.40,  0.02,  0.75],          [-0.70,  0.67, -0.88, -0.50],          [-0.62, -1.65, -1.10, -1.39]],          [[ 0.24, -0.02,  0.32,  0.97],          [ 0.33, -1.34,  0.76, -1.55],          [ 0.38,  1.45,  0.27, -0.64],          [-0.85, -1.61, -0.35, -0.56],          [ 0.00,  1.40,  0.41,  0.39],          [-0.01,  0.04,  0.80,  0.41],          [-1.21, -0.64,  1.14,  1.64]]]))

2.5 输入两个三维张量:dim=2对列进行拼接

a = torch.randn(2,3,4)b =  torch.randn(2,3,5)c = torch.cat((a,b),dim=2)a,b,c
输出结果如下:(tensor([[[ 0.13, -0.02,  0.13, -0.25],          [ 1.42, -0.22, -0.87,  0.27],          [-0.07,  1.04, -0.06,  0.91]],          [[ 0.88, -1.46,  0.04,  0.35],          [ 1.36,  0.64,  0.75,  0.39],          [ 0.36,  1.13,  0.83,  0.56]]]), tensor([[[-0.47, -2.30, -0.49, -1.02,  1.74],          [ 0.71,  0.89,  0.80, -0.05, -1.35],          [-0.40,  0.26, -0.78, -1.50, -0.92]],          [[-0.77, -0.01,  1.23,  0.70, -0.66],          [ 0.28, -0.18, -0.91,  2.23,  1.14],          [-1.93, -0.17,  0.15,  0.40,  0.32]]]), tensor([[[ 0.13, -0.02,  0.13, -0.25, -0.47, -2.30, -0.49, -1.02,  1.74],          [ 1.42, -0.22, -0.87,  0.27,  0.71,  0.89,  0.80, -0.05, -1.35],          [-0.07,  1.04, -0.06,  0.91, -0.40,  0.26, -0.78, -1.50, -0.92]],          [[ 0.88, -1.46,  0.04,  0.35, -0.77, -0.01,  1.23,  0.70, -0.66],          [ 1.36,  0.64,  0.75,  0.39,  0.28, -0.18, -0.91,  2.23,  1.14],          [ 0.36,  1.13,  0.83,  0.56, -1.93, -0.17,  0.15,  0.40,  0.32]]]))

来源地址:https://blog.csdn.net/flyingluohaipeng/article/details/125038212

免责声明:

① 本站未注明“稿件来源”的信息均来自网络整理。其文字、图片和音视频稿件的所属权归原作者所有。本站收集整理出于非商业性的教育和科研之目的,并不意味着本站赞同其观点或证实其内容的真实性。仅作为临时的测试数据,供内部测试之用。本站并未授权任何人以任何方式主动获取本站任何信息。

② 本站未注明“稿件来源”的临时测试数据将在测试完成后最终做删除处理。有问题或投稿请发送至: 邮箱/279061341@qq.com QQ/279061341

Pytorch中torch.cat()函数解析

下载Word文档到电脑,方便收藏和打印~

下载Word文档

猜你喜欢

Pytorch中torch.cat()函数举例解析

一般torch.cat()是为了把多个tensor进行拼接而存在的,下面这篇文章主要给大家介绍了关于Pytorch中torch.cat()函数举例解析的相关资料,文中通过示例代码介绍的非常详细,需要的朋友可以参考下
2022-12-22

Pytorch中torch.cat()函数的使用及说明

这篇文章主要介绍了Pytorch中torch.cat()函数的使用及说明,具有很好的参考价值,希望对大家有所帮助。如有错误或未考虑完全的地方,望不吝赐教
2023-01-03

PyTorch常用函数torch.cat()中dim参数使用说明

这篇文章主要为大家介绍了PyTorch常用函数torch.cat()中dim参数使用说明,有需要的朋友可以借鉴参考下,希望能够有所帮助,祝大家多多进步,早日升职加薪
2023-05-17

Pytorch中torch.unsqueeze()与torch.squeeze()函数详细解析

torch.squeeze()这个函数主要对数据的维度进行压缩,去掉维数为1的的维度,下面这篇文章主要给大家介绍了关于Pytorch中torch.unsqueeze()与torch.squeeze()函数详细的相关资料,需要的朋友可以参考下
2023-02-14

PyTorch中Torch.arange函数详解

PyTorch是由Facebook开发的开源机器学习库,它用于深度神经网络和自然语言处理,下面这篇文章主要给大家介绍了关于PyTorch中Torch.arange函数详解的相关资料,需要的朋友可以参考下
2023-02-03

pytorch中nn.Flatten()函数详解及示例

nn.Flatten是一个类,而torch.flatten()则是一个函数,下面这篇文章主要给大家介绍了关于pytorch中nn.Flatten()函数详解及示例的相关资料,需要的朋友可以参考下
2023-01-06

pytorch中函数tensor.numpy()的数据类型实例分析

这篇文章主要讲解了“pytorch中函数tensor.numpy()的数据类型实例分析”,文中的讲解内容简单清晰,易于学习与理解,下面请大家跟着小编的思路慢慢深入,一起来研究和学习“pytorch中函数tensor.numpy()的数据类型
2023-07-02

编程热搜

  • Python 学习之路 - Python
    一、安装Python34Windows在Python官网(https://www.python.org/downloads/)下载安装包并安装。Python的默认安装路径是:C:\Python34配置环境变量:【右键计算机】--》【属性】-
    Python 学习之路 - Python
  • chatgpt的中文全称是什么
    chatgpt的中文全称是生成型预训练变换模型。ChatGPT是什么ChatGPT是美国人工智能研究实验室OpenAI开发的一种全新聊天机器人模型,它能够通过学习和理解人类的语言来进行对话,还能根据聊天的上下文进行互动,并协助人类完成一系列
    chatgpt的中文全称是什么
  • C/C++中extern函数使用详解
  • C/C++可变参数的使用
    可变参数的使用方法远远不止以下几种,不过在C,C++中使用可变参数时要小心,在使用printf()等函数时传入的参数个数一定不能比前面的格式化字符串中的’%’符号个数少,否则会产生访问越界,运气不好的话还会导致程序崩溃
    C/C++可变参数的使用
  • css样式文件该放在哪里
  • php中数组下标必须是连续的吗
  • Python 3 教程
    Python 3 教程 Python 的 3.0 版本,常被称为 Python 3000,或简称 Py3k。相对于 Python 的早期版本,这是一个较大的升级。为了不带入过多的累赘,Python 3.0 在设计的时候没有考虑向下兼容。 Python
    Python 3 教程
  • Python pip包管理
    一、前言    在Python中, 安装第三方模块是通过 setuptools 这个工具完成的。 Python有两个封装了 setuptools的包管理工具: easy_install  和  pip , 目前官方推荐使用 pip。    
    Python pip包管理
  • ubuntu如何重新编译内核
  • 改善Java代码之慎用java动态编译

目录