我的编程空间,编程开发者的网络收藏夹
学习永远不晚

numpy中tensordot的用法

短信预约 -IT技能 免费直播动态提醒
省份

北京

  • 北京
  • 上海
  • 天津
  • 重庆
  • 河北
  • 山东
  • 辽宁
  • 黑龙江
  • 吉林
  • 甘肃
  • 青海
  • 河南
  • 江苏
  • 湖北
  • 湖南
  • 江西
  • 浙江
  • 广东
  • 云南
  • 福建
  • 海南
  • 山西
  • 四川
  • 陕西
  • 贵州
  • 安徽
  • 广西
  • 内蒙
  • 西藏
  • 新疆
  • 宁夏
  • 兵团
手机号立即预约

请填写图片验证码后获取短信验证码

看不清楚,换张图片

免费获取短信验证码

numpy中tensordot的用法

楔子

在numpy中有一个tensordot方法,尤其在做机器学习的时候会很有用。估计有人看到这个名字,会想到tensorflow,没错tensorflow里面也有tensordot这个函数。这个函数它的作用就是,可以让两个不同维度的数组进行相乘。我们来举个例子:

import numpy as np

a = np.random.randint(0, 9, (3, 4))
b = np.random.randint(0, 9, (4, 5))
try:
    print(a * b)
except Exception as e:
    print(e)  # operands could not be broadcast together with shapes (3,4) (4,5)

# 很明显,a和b两个数组的维度不一样,没办法相乘
# 但是
print(np.tensordot(a, b, 1))
"""
[[32 32 28 28 52]
 [10 25 40 38 78]
 [56  7 28  0 42]]
"""
# 我们看到使用tensordot是可以的

下面我们来看看这个函数的用法

函数原型

@array_function_dispatch(_tensordot_dispatcher)
def tensordot(a, b, axes=2):

我们看到这个函数接收三个参数,前两个就是numpy中数组,最后一个参数则是用于指定收缩的轴。它可以接收一个整型、列表、列表里面嵌套列表,具体代表什么含义我们下面举例说明。

理解axes

axes为整型

如果axes接收的是一个整型:m,那么表示指定数组a的后n个轴和数组b的前n个轴分别进行内积,就是对应位置元素相乘、再整体求和。

import numpy as np

a = np.arange(60).reshape((3, 4, 5))
b = np.arange(160).reshape((4, 5, 8))

# 显然这两个数组不能直接相乘,但是a和后两个轴和b的前两个轴是可以直接相乘的
# 因为它们都是(4, 5), 最后结果的shape为(3, 8)
print(np.tensordot(a, b, 2).shape)  # (3, 8)

而且这个axes默认为2,所以它一般都是针对三维或者三维以上的数组

但是为了具体理解,后面我们会使用一维、二维数据具体举例说明。现在先看axes取不同的值,会得到什么结果,先理解一下axes的含义。

import numpy as np

a = np.arange(60).reshape((3, 4, 5))
b = np.arange(160).reshape((4, 5, 8))

try:
    print(np.tensordot(a, b, 1).shape)
except Exception as e:
    print(e)  # shape-mismatch for sum
# 结果报错了,很好理解,就是形状不匹配嘛
# axes指定为1,表示a的后一个轴和b的前一个轴进行内积
# 但是一个是5一个是4,元素无法一一对应,所以报错,提示shape-mismatch,形状不匹配

# 这里我们把数组b的shape改一下,这样a的后一个轴和b的前一个轴就匹配了,都是5
a = np.arange(60).reshape((3, 4, 5))
b = np.arange(160).reshape((5, 4, 8))
print(np.tensordot(a, b, 1).shape)  # (3, 4, 4, 8)
"""
这样就能够运算了,我们说指定收缩的轴,进行内积运算得到的是一个值
所以这里的(3, 4, 5)和(5, 4, 8)变成了(3, 4, 4, 8)

而上一个例子是(3, 4, 5)和(4, 5, 8),然后axes=2
因为a的后两个轴和b的前两个轴进行内积变成了一个具体的值,所以最终的维度就是(3, 8)
"""

如果axes为0的话,会有什么结果

import numpy as np

a = np.arange(60).reshape((3, 4, 5))
b = np.arange(160).reshape((4, 5, 8))

print(np.tensordot(a, b, 0).shape)  # (3, 4, 5, 4, 5, 8)
print(np.tensordot(b, a, 0).shape)  # (4, 5, 8, 3, 4, 5)
"""
np.tensordot(a, b, 0)等价于将a中的每一个元素都和b相乘
然后再将原来a中的对应元素替换掉
"""

上面的操作也可以使用爱因斯坦求和来实现

axes=0

import numpy as np

a = np.arange(60).reshape((3, 4, 5))
b = np.arange(160).reshape((4, 5, 8))

c1 = np.tensordot(a, b, 0)
c2 = np.einsum("ijk,xyz->ijkxyz", a, b)
print(c1.shape, c2.shape)  # (3, 4, 5, 4, 5, 8) (3, 4, 5, 4, 5, 8)
print(np.all(c1 == c2))  # True
"""
生成的c1和c2是一样的
"""

c3 = np.tensordot(b, a, 0)
c4 = np.einsum("ijk,xyz->xyzijk", a, b)
print(c3.shape, c4.shape)  # (4, 5, 8, 3, 4, 5) (4, 5, 8, 3, 4, 5)
print(np.all(c3 == c4))  # True
"""
生成的c3和c4是一样的
"""

那么它们的效率之间孰优孰劣呢?我们在jupyter上测试一下

>>> %timeit c1 = np.tensordot(a, b, 0)
50.5 µs ± 206 ns per loop
>>> %timeit c2 = np.einsum("ijk,xyz->ijkxyz", a, b)
7.29 µs ± 242 ns per loop

可以看到爱因斯坦求和快了不少

axes=1

import numpy as np

a = np.arange(60).reshape((3, 4, 5))
b = np.arange(160).reshape((5, 4, 8))

c1 = np.tensordot(a, b, 1)
c2 = np.einsum("ijk,kyz->ijyz", a, b)
print(c1.shape, c2.shape)  # (3, 4, 4, 8) (3, 4, 4, 8)
print(np.all(c1 == c2))  # True

axes=2

import numpy as np

a = np.arange(60).reshape((3, 4, 5))
b = np.arange(160).reshape((4, 5, 8))

c1 = np.tensordot(a, b, 2)
c2 = np.einsum("ijk,jkz->iz", a, b)
print(c1.shape, c2.shape)  # (3, 8) (3, 8)
print(np.all(c1 == c2))  # True

axes为列表

如果axes接收的是一个列表:[m, n],那么表示让a的第m+1个(索引为m)轴和b的第n+1(索引为n)个轴进行内积。使用列表的方法最大的好处就是,可以指定任意位置的轴。

import numpy as np

a = np.arange(60).reshape((3, 4, 5))
b = np.arange(160).reshape((4, 5, 8))
# 我们看到a的第二个维度(或者说轴)和b的第一个维度都是4,所以它们是可以进行内积的
c1 = np.tensordot(a, b, [1, 0])
# 由于内积的结果是一个标量,所以(3, 4, 5)和(4, 5, 8)在tensordot之后的shape是(3, 5, 5, 8)
# 相当于把各自的4给扔掉了(因为变成了标量),然后组合在一起
print(c1.shape)  # (3, 5, 5, 8)

# 同理a的最后一个维度和b的第二个维度也是可以内积的
# 最后一个维度也可以使用-1,等于按照列表的索引来取对应的维度
c2 = np.tensordot(a, b, [-1, 1])
print(c2.shape)  # (3, 4, 4, 8)

上面的操作也可以使用爱因斯坦求和来实现

import numpy as np

a = np.arange(60).reshape((3, 4, 5))
b = np.arange(160).reshape((4, 5, 8))
c1 = np.tensordot(a, b, [1, 0])
c2 = np.einsum("ijk,jyz->ikyz", a, b)
print(c1.shape, c2.shape)  # (3, 5, 5, 8) (3, 5, 5, 8)
print(np.all(c1 == c2))  # True

c3 = np.tensordot(a, b, [-1, 1])
c4 = np.einsum("ijk,akz->ijaz", a, b)
print(c3.shape, c4.shape)  # (3, 4, 4, 8) (3, 4, 4, 8)
print(np.all(c3 == c4))  # True

axes为列表嵌套列表

如果axes接收的是一个嵌套列表的列表:[[m], [n]],等于说可以选多个轴

import numpy as np

a = np.arange(60).reshape((3, 4, 5))
b = np.arange(160).reshape((4, 5, 8))
# 我们想让a的后两个轴和b的前两个轴内积
c1 = np.tensordot(a, b, axes=2)
c2 = np.tensordot(a, b, [[1,2], [0,1]])
print(c1.shape, c2.shape)  # (3, 8) (3, 8)
print(np.all(c1 == c2))  # True

但是使用列表进行筛选还有一个好处,就是可以忽略顺序

import numpy as np

a = np.arange(60).reshape((4, 3, 5))
b = np.arange(160).reshape((4, 5, 8))
# 这个时候就无法给axes传递整型了
c3 = np.tensordot(a, b, [[0, 2], [0, 1]])
print(c3.shape)  # (3, 8)

此外,使用列表筛选还有一个强大的功能,就是可以倒着取值

import numpy as np

a = np.arange(60).reshape((4, 5, 3))
b = np.arange(160).reshape((5, 4, 8))

# 这个时候我们选择前两个轴,但是一个是(4, 5)一个是(5, 4),所以无法相乘
# 因此在选择的时候需要倒着筛选:
# [[0, 1], [1, 0]]-> (4, 5)和(4, 5) 或者 [[1, 0], [0, 1]] -> (5, 4)和(5, 4)
c3 = np.tensordot(a, b, [[0, 1], [1, 0]])
print(c3.shape)  # (3, 8)

最后同样看看如何爱因斯坦求和来实现

import numpy as np

a = np.arange(60).reshape((4, 5, 3))
b = np.arange(160).reshape((4, 5, 8))

c1 = np.tensordot(a, b, [[0, 1], [0, 1]])
c2 = np.einsum("ijk,ijz->kz", a, b)
print(c1.shape, c2.shape)  # (3, 8) (3, 8)
print(np.all(c1 == c2))  # True


a = np.arange(60).reshape((4, 5, 3))
b = np.arange(160).reshape((5, 4, 8))

c1 = np.tensordot(a, b, [[0, 1], [1, 0]])
c2 = np.einsum("ijk,jiz->kz", a, b)
print(c1.shape, c2.shape)  # (3, 8) (3, 8)
print(np.all(c1 == c2))  # True


a = np.arange(60).reshape((4, 3, 5))
b = np.arange(160).reshape((5, 4, 8))

c1 = np.tensordot(a, b, [[0, 2], [1, 0]])
c2 = np.einsum("ijk,kiz->jz", a, b)
print(c1.shape, c2.shape)  # (3, 8) (3, 8)
print(np.all(c1 == c2))  # True

以两个一维数组为例

我们来通过打印具体的数组来看一下tensordot

import numpy as np

a = np.array([1, 2, 3])
b = np.array([2, 3, 4])

print(np.tensordot(a, b, axes=0))
"""
[[ 2  3  4]
 [ 4  6  8]
 [ 6  9 12]]
"""
print(np.einsum("i,j->ij", a, b))
"""
[[ 2  3  4]
 [ 4  6  8]
 [ 6  9 12]]
"""

# 我们axes=0,等于是a的每一个元素和相乘,然后再把原来a对应的元素替换掉
# 所以是a中的1 2 3分别和b相乘,得到[2 3 4] [4 6 8] [6 9 12]、再替换掉1 2 3
# 所以结果是[[2 3 4] [4 6 8] [6 9 12]]

如果axes=1呢?

import numpy as np

a = np.array([1, 2, 3])
b = np.array([2, 3, 4])

print(np.tensordot(a, b, axes=1))  # 20
"""
选取a的前一个轴和b的后一个轴进行内积
而a和b只有一个轴,所以结果是一个标量
"""
print(np.einsum("i,i->", a, b))  # 20

如果axes=2呢?首先我们说axes等于一个整型,表示选取a的后n个轴,b的前n个轴,而一维数组它们只有一个轴

import numpy as np

a = np.array([1, 2, 3])
b = np.array([2, 3, 4])

try:
    print(np.tensordot(a, b, axes=2))  # 20
except Exception as e:
    print(e)  # tuple index out of range

显然索引越界了。

以一个一维数组和一个二维数组为例

我们通过一维数组和二维数组进行tensordot来感受一下

axes=0

import numpy as np

a = np.array([1, 2, 3])
b = np.array([[2, 3, 4]])

print(np.tensordot(a, b, 0))
"""
[[[ 2  3  4]]

 [[ 4  6  8]]
 
 [[ 6  9 12]]]
"""
print(np.einsum("i,jk->ijk", a, b))
"""
[[[ 2  3  4]]

 [[ 4  6  8]]
 
 [[ 6  9 12]]]
"""
# 很好理解,就是1 2 3分别和[[2, 3, 4]]相乘再替换掉 1 2 3
print(np.tensordot(a, b, 0).shape)  # (3, 1, 3)


##########################
print(np.tensordot(b, a, 0))
"""
[[[ 2  4  6]
  [ 3  6  9]
  [ 4  8 12]]]
"""
print(np.einsum("i,jk->jki", a, b))
"""
[[[ 2  4  6]
  [ 3  6  9]
  [ 4  8 12]]]
"""
# 很好理解,就是2 3 4分别和[1 2 3]相乘再替换掉 2 3 4
print(np.tensordot(b, a, 0).shape)  # (1, 3, 3)

axes=1的话呢?

import numpy as np

a = np.array([1, 2, 3])
b = np.array([[2, 3, 4], [4, 5, 6]])
try:
    print(np.tensordot(a, b, 1))
except Exception as e:
    print(e)  # shape-mismatch for sum
# 我们注意到报错了,因为axes=1,表示取a的后一个轴和b的前1个轴
# a的shape是(3, 0),所以它的后一个轴和前一个轴对应的数组长度都是3
# 但是b的前一个轴对应的数组长度是2,不匹配所以报错

print(np.tensordot(b, a, 1))  # [20 32]
# 我们看到这个是可以的,因为这表示b的后一个轴,数组长度为3,是匹配的
# 让后一个轴的[2 3 4]、[4 5 6]分别和[1 2 3]进行内积,最终得到两个标量

try:
    print(np.einsum("i,ij->ij", a, b))
except Exception as e:
    print(e)
    # operands could not be broadcast together with remapped shapes [original->remapped]: (3,)->(3,newaxis) (2,3)->(2,3)

# 同样对于爱因斯坦求和也是无法这么做的,我们需要换个顺序
print(np.einsum("i,ji->j", a, b))  # [20 32]
# 或者
print(np.einsum("j,ij->i", a, b))  # [20 32]

axes=2的话呢?

import numpy as np

a = np.array([1, 2, 3])
b = np.array([[2, 3, 4], [4, 5, 6]])
try:
    print(np.tensordot(a, b, 2))
except Exception as e:
    print(e)  # tuple index out of range
# 我们注意到报错了,因为axes=2,表示取a的后两个轴和b的前两个轴
# 而a总共才1个轴,所以报错了

try:
    print(np.tensordot(b, a, 2))
except Exception as e:
    print(e)  # shape-mismatch for sum
# 我们看到虽然也报错了,但是不是报索引越界。
# 因为上面表示取a的前两个轴,虽然a只有一个,但是此时不会索引越界,只是就取一个。如果是取后两个就会越界了
# 此时b是(2, 3),而a是(3,) 不匹配,可能有人觉得会发生广播,但在这里不会

以两个二维数组为例

我们再通过两个二维数组进行tensordot来感受一下

axes=0

import numpy as np

a = np.array([[1, 2, 3]])
b = np.array([[2, 3, 4], [4, 5, 6]])

# a_shape: (1, 3) b_shape(3, 3)
print(np.tensordot(a, b, 0))
"""
[[[[ 2  3  4]
   [ 4  5  6]]

  [[ 4  6  8]
   [ 8 10 12]]

  [[ 6  9 12]
   [12 15 18]]]]
"""
print(np.einsum("ij,xy->ijxy", a, b))
"""
[[[[ 2  3  4]
   [ 4  5  6]]

  [[ 4  6  8]
   [ 8 10 12]]

  [[ 6  9 12]
   [12 15 18]]]]
"""
print(np.tensordot(a, b, 0).shape)  # (1, 3, 2, 3)

#############
print(np.tensordot(b, a, 0))
"""
[[[[ 2  4  6]]

  [[ 3  6  9]]

  [[ 4  8 12]]]


 [[[ 4  8 12]]

  [[ 5 10 15]]

  [[ 6 12 18]]]]
"""
print(np.einsum("ij,xy->xyij", a, b))
"""
[[[[ 2  4  6]]

  [[ 3  6  9]]

  [[ 4  8 12]]]


 [[[ 4  8 12]]

  [[ 5 10 15]]

  [[ 6 12 18]]]]
"""
print(np.tensordot(b, a, 0).shape)  # (2, 3, 1, 3)

axes=1

import numpy as np

a = np.array([[1, 2], [3, 4]])
b = np.array([[2, 3, 4], [4, 5, 6]])

# a_shape: (2, 2) b_shape(2, 3)
print(np.tensordot(a, b, 1))
"""
[[10 13 16]
 [22 29 36]]
"""
print(np.einsum("ij,jk->ik", a, b))
"""
[[10 13 16]
 [22 29 36]]
"""
# 仔细的你肯定发现了,此时就相当于矩阵的点乘
print(a @ b)
"""
[[10 13 16]
 [22 29 36]]
"""

axes=2

import numpy as np

a = np.array([[1, 2], [3, 4]])
b = np.array([[2, 3, 4], [4, 5, 6]])

# a_shape: (2, 2) b_shape(2, 3)

# 取后两个轴显然不行,因为(2, 2)和(2, 3)不匹配
try:
    print(np.tensordot(a, b, 2))
except Exception as e:
    print(e)  # shape-mismatch for sum
    
a = np.array([[1, 2, 3], [2, 2, 2]])
b = np.array([[2, 3, 4], [4, 5, 6]])
print(np.tensordot(a, b, 2))  # 50
print(np.einsum("ij,ij->", a, b))  # 50    

最后看即个爱因斯坦求和的例子,感受它和主角tensordot的区别,当然如果不熟悉的爱因斯坦求和的话可以不用看

import numpy as np

a = np.random.randint(1, 9, (5, 3, 2, 3))
b = np.random.randint(1, 9, (3, 3, 2))

c1 = a @ b  # 多维数组,默认是对最后两位进行点乘
c2 = np.einsum("ijkm,jmn->ijkn", a, b)
print(np.all(c1 == c2))  # True
print(c2.shape)  # (5, 3, 2, 2)
print(np.einsum("...km,...mn->...kn", a, b).shape)  # (5, 3, 2, 2)

# 但如果是
c3 = np.einsum("ijkm,amn->ijkn", a, b)
print(c3.shape)  # (5, 3, 2, 2)
# 由于符号不一样,所以即使shape一致,但是两个数组不一样
print(np.all(c3 == c1))  # False


a = np.random.randint(1, 9, (5, 3, 3, 2))
b = np.random.randint(1, 9, (1, 3, 2))

print(np.einsum("ijmk,jmn->ijkn", a, b).shape)  # (5, 3, 2, 2)
print(np.einsum("ijkm,jnm->ijkn", a, b).shape)  # (5, 3, 3, 3)

到此这篇关于numpy中tensordot的用法的文章就介绍到这了,更多相关numpy tensordot内容请搜索编程网以前的文章或继续浏览下面的相关文章希望大家以后多多支持编程网!

免责声明:

① 本站未注明“稿件来源”的信息均来自网络整理。其文字、图片和音视频稿件的所属权归原作者所有。本站收集整理出于非商业性的教育和科研之目的,并不意味着本站赞同其观点或证实其内容的真实性。仅作为临时的测试数据,供内部测试之用。本站并未授权任何人以任何方式主动获取本站任何信息。

② 本站未注明“稿件来源”的临时测试数据将在测试完成后最终做删除处理。有问题或投稿请发送至: 邮箱/279061341@qq.com QQ/279061341

numpy中tensordot的用法

下载Word文档到电脑,方便收藏和打印~

下载Word文档

猜你喜欢

numpy中tensordot的用法

本文主要介绍了numpy中tensordot的用法,文中通过示例代码介绍的非常详细,对大家的学习或者工作具有一定的参考学习价值,需要的朋友们下面随着小编来一起学习学习吧
2023-02-20

numpy中的tensordot怎么使用

这篇文章主要讲解了“numpy中的tensordot怎么使用”,文中的讲解内容简单清晰,易于学习与理解,下面请大家跟着小编的思路慢慢深入,一起来研究和学习“numpy中的tensordot怎么使用”吧!楔子在numpy中有一个tensord
2023-07-05

python中numpy的用法

numpy是python中用于科学计算的强大库,它提供了以下功能:多维数组处理矩阵运算快速傅里叶变换(fft)线性代数随机数生成NumPy在Python中的强大功能NumPy是Python中用于科学计算的一个强大且灵活的库。它提供了用于处
python中numpy的用法
2024-05-15

python中numpy用法

numpy是python中处理多维数组和矩阵的库,提供丰富的功能包括数组创建、访问、操作、数据类型、广播、线性代数运算、傅里叶变换、随机数生成、文件输入/输出和自定义函数。NumPy 在 Python 中的用法NumPy 是 Python
python中numpy用法
2024-05-15

numpy中np.c_和np.r_的用法解析

本文主要介绍了numpy中np.c_和np.r_的用法解析,文中通过示例代码介绍的非常详细,对大家的学习或者工作具有一定的参考学习价值,需要的朋友们下面随着小编来一起学习学习吧
2023-03-13

关于numpy中np.nonzero()函数用法的详解

np.nonzero函数是numpy中用于得到数组array中非零元素的位置(数组索引)的函数。一般来说,通过help(np.nonzero)能够查看到该函数的解析与例程。但是,由于例程为英文缩写,阅读起来还是很费劲,因此,本文将其英文解释
2022-06-04

详述numpy中的np.random.random()系列函数用法

本文主要介绍了详述numpy中的np.random.random()系列函数用法,文中通过示例代码介绍的非常详细,对大家的学习或者工作具有一定的参考学习价值,需要的朋友们下面随着小编来一起学习学习吧
2023-03-14

Numpy的基本用法整理

本篇内容主要讲解“Numpy的基本用法整理”,感兴趣的朋友不妨来看看。本文介绍的方法操作简单快捷,实用性强。下面就让小编来带大家学习“Numpy的基本用法整理”吧!前言Numpy是一个开源的Python科学计算库,它是python科学计算库
2023-06-04

编程热搜

  • Python 学习之路 - Python
    一、安装Python34Windows在Python官网(https://www.python.org/downloads/)下载安装包并安装。Python的默认安装路径是:C:\Python34配置环境变量:【右键计算机】--》【属性】-
    Python 学习之路 - Python
  • chatgpt的中文全称是什么
    chatgpt的中文全称是生成型预训练变换模型。ChatGPT是什么ChatGPT是美国人工智能研究实验室OpenAI开发的一种全新聊天机器人模型,它能够通过学习和理解人类的语言来进行对话,还能根据聊天的上下文进行互动,并协助人类完成一系列
    chatgpt的中文全称是什么
  • C/C++中extern函数使用详解
  • C/C++可变参数的使用
    可变参数的使用方法远远不止以下几种,不过在C,C++中使用可变参数时要小心,在使用printf()等函数时传入的参数个数一定不能比前面的格式化字符串中的’%’符号个数少,否则会产生访问越界,运气不好的话还会导致程序崩溃
    C/C++可变参数的使用
  • css样式文件该放在哪里
  • php中数组下标必须是连续的吗
  • Python 3 教程
    Python 3 教程 Python 的 3.0 版本,常被称为 Python 3000,或简称 Py3k。相对于 Python 的早期版本,这是一个较大的升级。为了不带入过多的累赘,Python 3.0 在设计的时候没有考虑向下兼容。 Python
    Python 3 教程
  • Python pip包管理
    一、前言    在Python中, 安装第三方模块是通过 setuptools 这个工具完成的。 Python有两个封装了 setuptools的包管理工具: easy_install  和  pip , 目前官方推荐使用 pip。    
    Python pip包管理
  • ubuntu如何重新编译内核
  • 改善Java代码之慎用java动态编译

目录