https://www.ichunqiu.com/battalion?t=1&r=70899
+ + + + + + + + + + +
LeNet解析
-
附件flag压缩包中存在若干图像样本,使用 matplotlib
可视化为手写体数字+字母!
-
加载模型参数,并查看网络结构
pt = torch.load(
"./MyLeNet.pt"
, map_location=device)
for
net
in
pt:
(net,pt[net].shape)
# [out]---------------
# conv1.weight torch.Size([6, 1, 5, 5])
# conv1.bias torch.Size([6])
# conv2.weight torch.Size([16, 6, 5, 5])
# conv2.bias torch.Size([16])
# fc1.weight torch.Size([120, 256])
# fc1.bias torch.Size([120])
# fc2.weight torch.Size([84, 120])
# fc2.bias torch.Size([84])
# fc3.weight torch.Size([62, 84])
# fc3.bias torch.Size([62])
网络结构:
-
输入层 28*28 -
卷积层1 -
卷积层2 -
全连接层1——256个神经元(4×4×16) -
全连接层2——120个神经元 -
全连接层3——84个神经元 -
输出层——62个神经元
如果仅经过两次步长为1的卷积,输出的特征图大小为20×20,而该模型中实际却是4×4,存在两种可能:
-
每次的卷积步长为2 / 其中1次卷积步长为4 -
存在池化
按照经典LeNet神经网络的思路,每次卷积过后都会进行2*2步长为2的池化
28×28 --(cov1)--> 24×24 --(pool)--> 12×12 --(cov2)--> 8×8 --(pool)--> 4×4
神经元个数与参数个数一致。
-
池化存在两种方式
暂且按照LeNet神经网络中的最大池化来设计,于是可以写出如下网络结构
-
当采用最大池化时,激活和池化的顺序并不影响计算结果
-
卷积 -> 激活 -> 池化 -
卷积 -> 池化 -> 激活 -
最大池化 -
平均池化
class MyLeNet(nn.Module):
def __init__(self):
super().__init__()
self.conv1 = nn.Conv2d(in_channels=1, out_channels=6, kernel_size=5, stride=1)
self.maxpool1 = nn.MaxPool2d(kernel_size=2, stride=2)
self.conv2 = nn.Conv2d(in_channels=6, out_channels=16, kernel_size=5, stride=1)
self.maxpool2 = nn.MaxPool2d(kernel_size=2, stride=2)
self.fc1 = nn.Linear(16 * 4 * 4, 120)
self.fc2 = nn.Linear(120, 84)
self.fc3 = nn.Linear(84, 62)
def forward(self, x):
x = self.conv1(x)
x = self.maxpool1(x)
# ①
x = self.conv2(x)
x = self.maxpool2(x)
# ②
x = torch.flatten(x, start_dim=1)
x = self.fc1(x)
# ③
x = self.fc2(x)
# ④
x = self.fc3(x)
return
x
-
目前仅需考虑激活函数
最多4处需要使用激活函数
常见的激活函数有如下4+1种
-
Sigmoid函数 -
Tanh/双曲正切函数 -
ReLU函数 -
Softmax函数(一般出现在一个网络的最后一层) -
未使用激活函数
class MyLeNet(nn.Module):
def __init__(self, list_func):
super().__init__()
self.idx = list_func
self.func = [None, nn.Sigmoid(), nn.Tanh(), nn.ReLU(), nn.Softmax()]
self.conv1 = nn.Conv2d(in_channels=1, out_channels=6, kernel_size=5, stride=1)
self.maxpool1 = nn.MaxPool2d(kernel_size=2, stride=2)
self.conv2 = nn.Conv2d(in_channels=6, out_channels=16, kernel_size=5, stride=1)
self.maxpool2 = nn.MaxPool2d(kernel_size=2, stride=2)
self.fc1 = nn.Linear(16 * 4 * 4, 120)
self.fc2 = nn.Linear(120, 84)
self.fc3 = nn.Linear(84, 62)
def forward(self, x):
x = self.conv1(x)
if
self.idx[0] > 0:
x = self.func[self.idx[0]](x)
x = self.maxpool1(x)
x = self.conv2(x)
if
self.idx[1] > 0:
x = self.func[self.idx[1]](x)
x = self.maxpool2(x)
x = torch.flatten(x, start_dim=1)
x = self.fc1(x)
if
self.idx[2] > 0:
x = self.func[self.idx[2]](x)
x = self.fc2(x)
if
self.idx[3] > 0:
x = self.func[self.idx[3]](x)
x = self.fc3(x)
return
x
尝试损失函数
arrange_list = [[ j, k, l, m]
for
j
in
range(4)
for
k
in
range(4)
for
l
in
range(4)
for
m
in
range(5)]
for
idxs
in
arrange_list[:500]:
model = MyLeNet(idxs)
model.load_state_dict(pt)
tmp =
""
for
i
in
range(56):
npy_0 = np.load(
"./flag/"
+str(i)+
".npy"
).reshape((1,1,28,28))
# 调整样本为输入形状
tmp += chars[int(model(torch.tensor(npy_0).to(device)).argmax())]
(tmp[-1],end=
''
)
()
结果种出现若干字符,类似Base64编码,于是尝试使用Base64解码
# [out]:部分输出结果
huunKUuvWRvvWXWiIW4WvxIhWvxuixvvvh4EhivvvIhvzWEEuWvWKEiv
22nIk2282Rre2le2nn222nI8WrW22r4r42822I8v8ini2862n2nrh82r
HnnIknnvWRrhnlvWMWvWWkkWWnWWDv4R4n8WnFnv8invzvCnnnnnhnDr
4UUvKUEE2axU2kUIzW44IxI6Uaxa2xvzv4E62iavEIavEvEaUUEWKi2E
AFFSFCFFFFFFkkFFFkFFSFAFkFFCFFFFFFFFFFSFFSAFFFFFSFFNFFFF
XWKXKWKKWKWWWXWjXWQQWKXKWWXWj75W5KQWKjKWK7KWKWQKWWXWKKjF
lUUYhUr5U2rU21UD1Vg1UY1hl4Y22Y5Y5lgL2j11rYhYYi5UUUYWl525
AHKwKWKKWAVWWXWWWWVjWKPhKVXWWXFwFhKLhFKKK7hVXiWUWWXWKjW5
lnKYhUhVn2VUWlhDTV9gWYl6UVYWWY545lghhiVVV7hVYhgWhWYWliWF
tnAt8CACnACCAtCDtnACnttCUAtCttAtAtAAAAtnAtnntnCCCn89tAtt
hUEmkUkEwRvWWhWUmW44ukkhU4WuU4646h4EhGkvvwmUzWEuWW4WkvUE
4nnSk2rSnnrnn2UDMn442tIrUrr224646h4U468rrMnn8864nn44hr2r
-
Base64解码
for
idxs
in
arrange_list:
model = MyLeNet(idxs)
model.load_state_dict(pt)
tmp =
""
for
i
in
range(56):
npy_0 = np.load(
"./flag/"
+str(i)+
".npy"
).reshape((1,1,28,28))
tmp += chars[int(model(torch.tensor(npy_0).to(device)).argmax())]
try:
flag = base64.b64decode(tmp).decode(
'utf-8'
)
if
"flag"
in
flag:
(flag)
(idxs)
break
except Exception:
pass
-
得到Flag及损失函数
一血队伍解法
作者:LaoGong队伍
+ + + + + + + + + + +
原文始发于微信公众号(春秋伽玛):官方WP | PyTorch手写识别模型训练赛题LeNet解析
- 左青龙
- 微信扫一扫
-
- 右白虎
- 微信扫一扫
-
评论