张量在内存中的存储(reshape/permute操作理解)

在阅读flux代码的时候,看到这段处理latent的代码有些懵逼,追根溯源就是自己对于pytorch Tensor数据组织的方式理解不透彻,因此写下这篇博客开云破雾:

1
2
3
4
5
6
7
@staticmethod
def _pack_latents(latents, batch_size, num_channels_latents, height, width):
latents = latents.view(batch_size, num_channels_latents, height // 2, 2, width // 2, 2)
latents = latents.permute(0, 2, 4, 1, 3, 5)
latents = latents.reshape(batch_size, (height // 2) * (width // 2), num_channels_latents * 4)

return latents
PYTHON

在底层实现中,PyTorch 中的张量(Tensor)实际上是以一维的连续内存块存储的,只是通过不同的 stride 来控制数据在内存中的访问顺序。我习惯从右往左去看待张量,因此我说的层的顺序是从右往左的

1
2
3
4
5
6
7
8
9
10
11
Example.1:Tensor A.shape: (2, 3, 4)
视图:
[[[0, 1, 2, 3], // 第一个 batch, 第 1 行
[4, 5, 6, 7], // 第一个 batch, 第 2 行
[8, 9, 10, 11]], // 第一个 batch, 第 3 行
[[12, 13, 14, 15], // 第二个 batch, 第 1 行
[16, 17, 18, 19], // 第二个 batch, 第 2 行
[20, 21, 22, 23]]] // 第二个 batch, 第 3 行

内存中的数据(按行主序排列):
[0,1,2,3,4,5,6,7,8,9,10,11,12,13,14,15,16,17,18,19,20,21,22,23]
PYTHOH
  • stride[2] = 1(最里层的stride都是1)代表最里层每个元素是紧邻的,例如0马上就接着1然后是2:
    0,1,2,3,…
  • stride[1] = A.shape[2] = 4,代表每相邻4个元素是一组 ,例如0,1,2,3四个元素是一组,此时4,5,6,7又是新的一组:
    [0,1,2,3], [4,5,6,7], …
  • stride[0] = A.shape[2]*A.shape[1] = 3*4,代表每相邻12个元素是一个更大的组,因此0-11这12个元素会形成最外层的大组:
    [ [0,1,2,3], [4,5,6,7], [8,9,10,11] ], …

由此总结stride的规律

  • 对于有N个维度的Tensor A,
    我们可以把print出来的Tensor从上到下,从左往右展平成一维向量, 他的张量视图就是根据上述得到的。

理解了Tensor数据的组织方法,最重要的作用就是可以推导出reshape(),view()这俩操作后新的数据视图(注意没有permute)。

1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
Example.2: 对于1中的Tensor A (2,3,4),如果执行A.view(2,2,2,3)会是什么样子?

按照上述分析,首先把A展平成一维:
[0,1,2,3,4,5,6,7,8,9,10,11,12,13,14,15,16,17,18,19,20,21,22,23]

引用stride规则:
stride[3] = 1
stride[2] = 3,得到[0,1,2], [3,4,5],...
stride[1] = 3*2 = 6,得到[[0,1,2],[3,4,5]], [[6,7,8],[9,10,11]],...
stride[0] = 3*2*2 = 12,得到[[[0,1,2],[3,4,5]],[[6,7,8],[9,10,11]]], [[[12,13,14],[15,16,17]],[[18,19,20],[21,22,23]]]

于是变换后的最终结果为:
[[[[0,1,2],
[3,4,5]],
[[6,7,8],
[9,10,11]]],
[[[12,13,14],
[15,16,17]],
[[18,19,20],
[21,22,23]]]]

在pytorch中代码验证,结果相同:
import torch
A = torch.arange(24).reshape(2,3,4)
print(A)
A = A.reshape(2,2,2,3)
print(A)
tensor([[[[ 0, 1, 2],
[ 3, 4, 5]],

[[ 6, 7, 8],
[ 9, 10, 11]]],


[[[12, 13, 14],
[15, 16, 17]],

[[18, 19, 20],
[21, 22, 23]]]])
PYTHON

上述的讨论其实就是view()和reshape()操作的原理,因此对于只含view和reshape的操作,不管中间维度怎么变换,只要输入相同,且输出的维度相同最后的结果就是一样的

对于permute() 操作,他的原理和前两个不同,举一个例子:

1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
有一个5维Tensor A,A.shape = (5,6,7,8,9),假设原来A中的元素A[3][4][5][6][7] = b,
进行A.permute(0,2,4,1,3)后,A[3][4][5][6][7]会被映射到A[3][5][7][4][6],
所以permute后A[3][5][7][4][6] = b。说到底就是每个元素的索引按着permute的方式对应变换位置。

更重要的理解方式就是stride的变换,具体来说原始A的stride也会按照变换到permute后的A的stride上:

import torch
x = torch.arange(24).reshape(2, 3, 4)
print("x 的步幅:", x.stride()) # 输出: (12, 4, 1)
print(x)

y = x.permute(2, 0, 1) # 新形状为 (4, 2, 3)
print("Permute 后的张量 y 的形状:", y.shape)
print("y 的步幅:", y.stride()) # 输出: (1, 12, 4)
print(y)

输出:
x 的步幅: (12, 4, 1)
tensor([[[ 0, 1, 2, 3],
[ 4, 5, 6, 7],
[ 8, 9, 10, 11]],

[[12, 13, 14, 15],
[16, 17, 18, 19],
[20, 21, 22, 23]]])
Permute 后的张量 y 的形状: torch.Size([4, 2, 3])
y 的步幅: (1, 12, 4)
tensor([[[ 0, 4, 8],
[12, 16, 20]],

[[ 1, 5, 9],
[13, 17, 21]],

[[ 2, 6, 10],
[14, 18, 22]],

[[ 3, 7, 11],
[15, 19, 23]]])

进一步的对于上面这个例子permute后的Tensor按照最开始讲的,展开成一维不就是[0,4,8,12,...,23]吗,那我能不能用view改变一下形状呢?
例如y.view(2,12)不就返回[[0,4,8,12,16,20,1,5,9,13,17,21], [...]]了吗?
实际上由于y = x.permute()返回的只是原来x的新视图,x在内存中的物理存储没有改变还是[1,2,3,...],想要把y展平后view会报错:

Traceback (most recent call last):
File "/home/ganzhaoxing/RAG-Diffusion/test_reshape.py", line 10, in <module>
y.view(2,12)
RuntimeError: view size is not compatible with input tensors size and stride (at least one dimension spans across two contiguous subspaces). Use .reshape(...) instead.

从这个报错信息可以猜测,pytorch实现permute应该采用的是stride变换的观点,例如shape = (2, 3, 4)的stride本来是(12,4,1),
permute(2,0,1)后,相对于x的物理存储,stride变为 (1,12,4) != (12,4,1)因此判断不连续。

解决方案:需要用contiguous()函数把y对应的视图在实际物理存储上变得连续才能用view,或者直接使用reshape函数也可以(相当于contiguous + view)。
PYTHON

现在我们彻底理解了Tensor数据组织和形状变换的原理,让我们重新回到开头flux的源码部分进行解读:

1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
原始输入的latents.shape = (batch_size, num_channels_latents, height, width)
第一句代码:latents = latents.view(batch_size, num_channels_latents, height // 2, 2, width // 2, 2)
仅仅改变了最后(height,width)这两个维度,为了演示考虑(C,H,W)维度:
original latent =
[#shape = (2,2,4)
[[1,2,3,4],
[5,6,7,8]],

[[9,10,11,12],
[13,14,15,16]]
]
after view, latent =
[#shape = (2,1,2,2,2)
[
[[[1,2],
[3,4]],

[[5,6],
[7,8]]]
],

[
[[[9,10],
[11,12]],

[[13,14],
[15,16]]]
]
]

第二句代码:latents = latents.permute(0, 2, 4, 1, 3, 5)执行后把latents变成
(batch_size, height // 2, width // 2, num_channels_latents, 2, 2),之所以要这样做是为了把,
不同channel对应的同一个2*2的区域放在一起。

第三句代码:latents = latents.reshape(batch_size, (height // 2) * (width // 2), num_channels_latents * 4) 就是最后整合一下,合并一下维度。

用代码可视化这个过程,忽略batch维度:

import torch
x = torch.arange(32).reshape(2,4,4)
print("x 的步幅:", x.stride())
print(x)
y = x.view(2,2,2,2,2)
print("y 的步幅:", y.stride())
print(y)
y = y.permute(1,3,0,2,4)
print("y 的步幅:", y.stride())
print(y)
y = y.reshape(4,2,4)
print("y 的步幅:", y.stride())
print(y)

x 的步幅: (16, 4, 1)
tensor([[[ 0, 1, 2, 3],
[ 4, 5, 6, 7],
[ 8, 9, 10, 11],
[12, 13, 14, 15]],

[[16, 17, 18, 19],
[20, 21, 22, 23],
[24, 25, 26, 27],
[28, 29, 30, 31]]])
y 的步幅: (16, 8, 4, 2, 1)
tensor([[[[[ 0, 1],
[ 2, 3]],

[[ 4, 5],
[ 6, 7]]],


[[[ 8, 9],
[10, 11]],

[[12, 13],
[14, 15]]]],



[[[[16, 17],
[18, 19]],

[[20, 21],
[22, 23]]],


[[[24, 25],
[26, 27]],

[[28, 29],
[30, 31]]]]])
y 的步幅: (8, 2, 16, 4, 1)
tensor([[[[[ 0, 1],
[ 4, 5]],

[[16, 17],
[20, 21]]],


[[[ 2, 3],
[ 6, 7]],

[[18, 19],
[22, 23]]]],



[[[[ 8, 9],
[12, 13]],

[[24, 25],
[28, 29]]],


[[[10, 11],
[14, 15]],

[[26, 27],
[30, 31]]]]])
y 的步幅: (8, 4, 1)
tensor([[[ 0, 1, 4, 5],
[16, 17, 20, 21]],

[[ 2, 3, 6, 7],
[18, 19, 22, 23]],

[[ 8, 9, 12, 13],
[24, 25, 28, 29]],

[[10, 11, 14, 15],
[26, 27, 30, 31]]])

注意每一步的逻辑,我们是先想把tensor变换成后面的样子,根据这个样子我们能推导出对应的stride。
于是根据新stride和旧stride的变换关系,推导出permute、view、reshape的关系。注意view和reshape是把输入当做从左到右从上到下的一维连续变量进行变换。
PYTHON

相关博客:https://blog.csdn.net/qq_43391414/article/details/120798955


张量在内存中的存储(reshape/permute操作理解)
https://deadsmither5.github.io/2024/12/28/张量在内存中的存储/
作者
zhaoxing
发布于
2024年12月28日
许可协议