Skip to content

Commit 7702ac3

Browse files
committed
Create 2022-09-03-分组卷积.md
1 parent d638416 commit 7702ac3

File tree

1 file changed

+51
-0
lines changed

1 file changed

+51
-0
lines changed

_posts/2022-09-03-分组卷积.md

Lines changed: 51 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,51 @@
1+
---
2+
title: 分组卷积详解
3+
---
4+
5+
在看[ACMix](https://github.com/LeapLabTHU/ACmix/blob/main/Swin-Transformer/models/swin_transformer_acmix.py#L130)这篇文章的时候涉及到了分组卷积,由于网上的图只给了简单的情况,没有给出更复杂的情况,这里特地记录一下
6+
7+
有分组卷积,配置如下
8+
9+
```Python
10+
import torch
11+
import torch.nn as nn
12+
13+
dim = 128
14+
nhead = 8
15+
groups = dim // nhead
16+
17+
group_conv = nn.Conv2d(9*dim//nhead, dim, 3, padding=1, groups=groups, bias=False) # 不失一般性,这里不设置bias
18+
19+
x=torch.randn(1, 9*dim//nhead, 224, 224)
20+
y=group_conv(x)
21+
y.shape
22+
```
23+
24+
输出y的shape为
25+
26+
```
27+
torch.Size([1, 128, 224, 224])
28+
```
29+
30+
下面我们只判断第一组
31+
32+
```Python
33+
groups = dim // nhead
34+
x_g0 = x[:, :9, ...]
35+
kernel_g0 = group_conv.weight.data[:nhead]
36+
y_g0=nn.functional.conv2d(x_g0, kernel_g0, padding=1)
37+
38+
(np.around(y[:, :nhead].detach().numpy(), 5) == np.around(y_g0.detach().numpy(), 5)).all() # 由于精度问题需要round一下
39+
```
40+
41+
输出为
42+
43+
```
44+
True
45+
```
46+
47+
以此类推,可以判断第二组到最后一组的相等情况。
48+
49+
# 图解
50+
51+
![image-20220903133948876](https://raw.githubusercontent.com/294coder/blog_img_bed/main/imgs/image-20220903133948876.png)

0 commit comments

Comments
 (0)