1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 16 17 18 19 20 21 22 23 24 25 26 27 28 29 30 31 32 33 34 35 36 37 38 39 40 41 42 43 44 45 46 47 48 49 50 51 52 53 54 55 56 57 58 59 60 61 62 63 64 65 66 67 68 69 70 71 72 73 74 75 76 77 78 79 80 81 82 83 84 85 86 87 88 89 90 91 92 93 94 95 96 97 98 99 100 101 102 103 104 105 106 107 108 109 110 111 112 113 114 115 116 117 118 119 120 121 122 123 124 125 126 127 128 129 130 131 132 133
| 原始输入的latents.shape = (batch_size, num_channels_latents, height, width) 第一句代码:latents = latents.view(batch_size, num_channels_latents, height // 2, 2, width // 2, 2) 仅仅改变了最后(height,width)这两个维度,为了演示考虑(C,H,W)维度: original latent = [ [[1,2,3,4], [5,6,7,8]],
[[9,10,11,12], [13,14,15,16]] ] after view, latent = [ [ [[[1,2], [3,4]],
[[5,6], [7,8]]] ],
[ [[[9,10], [11,12]],
[[13,14], [15,16]]] ] ]
第二句代码:latents = latents.permute(0, 2, 4, 1, 3, 5)执行后把latents变成 (batch_size, height // 2, width // 2, num_channels_latents, 2, 2),之所以要这样做是为了把, 不同channel对应的同一个2*2的区域放在一起。
第三句代码:latents = latents.reshape(batch_size, (height // 2) * (width // 2), num_channels_latents * 4) 就是最后整合一下,合并一下维度。
用代码可视化这个过程,忽略batch维度:
import torch x = torch.arange(32).reshape(2,4,4) print("x 的步幅:", x.stride()) print(x) y = x.view(2,2,2,2,2) print("y 的步幅:", y.stride()) print(y) y = y.permute(1,3,0,2,4) print("y 的步幅:", y.stride()) print(y) y = y.reshape(4,2,4) print("y 的步幅:", y.stride()) print(y)
x 的步幅: (16, 4, 1) tensor([[[ 0, 1, 2, 3], [ 4, 5, 6, 7], [ 8, 9, 10, 11], [12, 13, 14, 15]],
[[16, 17, 18, 19], [20, 21, 22, 23], [24, 25, 26, 27], [28, 29, 30, 31]]]) y 的步幅: (16, 8, 4, 2, 1) tensor([[[[[ 0, 1], [ 2, 3]],
[[ 4, 5], [ 6, 7]]],
[[[ 8, 9], [10, 11]],
[[12, 13], [14, 15]]]],
[[[[16, 17], [18, 19]],
[[20, 21], [22, 23]]],
[[[24, 25], [26, 27]],
[[28, 29], [30, 31]]]]]) y 的步幅: (8, 2, 16, 4, 1) tensor([[[[[ 0, 1], [ 4, 5]],
[[16, 17], [20, 21]]],
[[[ 2, 3], [ 6, 7]],
[[18, 19], [22, 23]]]],
[[[[ 8, 9], [12, 13]],
[[24, 25], [28, 29]]],
[[[10, 11], [14, 15]],
[[26, 27], [30, 31]]]]]) y 的步幅: (8, 4, 1) tensor([[[ 0, 1, 4, 5], [16, 17, 20, 21]],
[[ 2, 3, 6, 7], [18, 19, 22, 23]],
[[ 8, 9, 12, 13], [24, 25, 28, 29]],
[[10, 11, 14, 15], [26, 27, 30, 31]]])
注意每一步的逻辑,我们是先想把tensor变换成后面的样子,根据这个样子我们能推导出对应的stride。 于是根据新stride和旧stride的变换关系,推导出permute、view、reshape的关系。注意view和reshape是把输入当做从左到右从上到下的一维连续变量进行变换。
PYTHON
|