对比tensorflow,从0开始学pytorch(三)--自定义层
上文虽然实现了GMS层的效果,但是前端代码太多,太ugly,也不好复用。今天抽空看了下pytorch中怎么自定义层,很简单,比tensorflow好用。
1. 任意文件夹创建个文件,和所有编程语言一样
2. 一样,集成nn.Module,然后自定义一个形参
这里需要花时间搞明白torch.nn.functional下的函数和torch.nn下的类的区别,一开始有点懵,想着为什么不做高级语言当中的静态函数,想明白了也就简单了。
图中的SPP_Sizes做了类型定义,python中一般情况不需要定义类型,但不定义在后面循环就会报错,看了下pytorch自带conv2d的源码,发现源码中非常严谨,每一个变量都定义了类型。
3. 调用就非常简单了,上一篇笔记中的冗长的代码,就可以一行调用
4. 简化后,代码看过去顺眼多了。附上GMS封装后的源码和训练结果
import torch import torch.nn as nn import torch.nn.functional as F class GMS(nn.Module): def __init__(self, Spp_Sizes:[]): super().__init__() if len(Spp_Sizes) == 0: self.SPP_Sizes = [2, 3, 4] else: self.SPP_Sizes = Spp_Sizes def forward(self, x): x_gap = F.adaptive_avg_pool2d(x, (1, 1)) x_gap = torch.flatten(x_gap, 1) x_gmp = F.adaptive_max_pool2d(x, (1, 1)) x_gmp = torch.flatten(x_gmp, 1) x_gms = torch.cat((x_gap, x_gmp), dim=1) for spp_size in self.SPP_Sizes: x_spp = F.adaptive_max_pool2d(x, (spp_size,spp_size)) x_spp = torch.flatten(x_spp, 1) x_gms = torch.cat((x_gms, x_spp), dim=1) return x_gms