class Bottleneck(nn.Module):
    """Residual bottleneck block: three conv/norm stages plus a shortcut connection.

    Main branch: conv1 -> bn1 -> relu -> conv2 -> bn2 -> relu -> conv3 -> bn3.
    The shortcut (the input itself, or ``downsample(input)`` when given) is added
    in place to the main branch, and a final ReLU is applied.

    ``conv1x1`` and ``conv3x3`` are helpers defined elsewhere in this file
    (presumably 1x1 and 3x3 ``nn.Conv2d`` factories -- NOTE(review): confirm).

    Args:
        inplanes: channel count fed into ``conv1``.
        planes: base channel count; the block emits ``planes * expansion`` channels.
        stride: passed to ``conv2``; also stored on ``self.stride`` (unused in forward).
        downsample: optional module applied to the block input to build the
            shortcut; its output must be addable in place to the main branch.
        groups: grouped-convolution count for ``conv2``; also scales the inner width.
        base_width: per-group width; inner width is
            ``int(planes * (base_width / 64.)) * groups``.
        dilation: dilation for ``conv2``.
        norm_layer: callable taking a channel count and returning a norm module;
            defaults to ``nn.BatchNorm2d``.
    """

    # Output channels are planes * expansion.
    expansion = 4
    # Presumably marks `downsample` as a TorchScript constant -- kept unchanged.
    __constants__ = ['downsample']

    def __init__(self, inplanes, planes, stride=1, downsample=None, groups=1,
                 base_width=64, dilation=1, norm_layer=None):
        super().__init__()
        if norm_layer is None:
            norm_layer = nn.BatchNorm2d
        inner = int(planes * (base_width / 64.)) * groups
        out_channels = planes * self.expansion
        # Submodules are built in the original order so random initialisation
        # and registration order (parameters(), state_dict) are unchanged.
        self.conv1 = conv1x1(inplanes, inner)
        self.bn1 = norm_layer(inner)
        # Both conv2 and downsample reduce spatial size when stride != 1.
        self.conv2 = conv3x3(inner, inner, stride, groups, dilation)
        self.bn2 = norm_layer(inner)
        self.conv3 = conv1x1(inner, out_channels)
        self.bn3 = norm_layer(out_channels)
        self.relu = nn.ReLU(inplace=True)
        self.downsample = downsample
        self.stride = stride

    def forward(self, x):
        """Apply the bottleneck branch to ``x`` and add the (possibly downsampled) shortcut."""
        shortcut = x
        y = self.relu(self.bn1(self.conv1(x)))
        y = self.relu(self.bn2(self.conv2(y)))
        y = self.bn3(self.conv3(y))
        if self.downsample is not None:
            shortcut = self.downsample(x)
        y += shortcut
        return self.relu(y)
# BasicBlock
class BasicBlock(nn.Module):
    """Residual basic block: two conv/norm stages plus a shortcut connection.

    Main branch: conv1 -> bn1 -> relu -> conv2 -> bn2. The shortcut (the input
    itself, or ``downsample(input)`` when given) is added in place to the main
    branch, and a final ReLU is applied.

    ``conv3x3`` is a helper defined elsewhere in this file (presumably a 3x3
    ``nn.Conv2d`` factory -- NOTE(review): confirm).

    Args:
        inplanes: channel count fed into ``conv1``.
        planes: channel count produced by both stages (``expansion`` is 1).
        stride: passed to ``conv1``; also stored on ``self.stride`` (unused in forward).
        downsample: optional module applied to the block input to build the
            shortcut; its output must be addable in place to the main branch.
        groups: must be 1.
        base_width: must be 64.
        dilation: must be <= 1.
        norm_layer: callable taking a channel count and returning a norm module;
            defaults to ``nn.BatchNorm2d``.

    Raises:
        ValueError: if ``groups != 1`` or ``base_width != 64``.
        NotImplementedError: if ``dilation > 1``.
    """

    # Output channels equal planes.
    expansion = 1
    # Presumably marks `downsample` as a TorchScript constant -- kept unchanged.
    __constants__ = ['downsample']

    def __init__(self, inplanes, planes, stride=1, downsample=None, groups=1,
                 base_width=64, dilation=1, norm_layer=None):
        super().__init__()
        # Reject unsupported configurations up front (groups/width before dilation).
        if groups != 1 or base_width != 64:
            raise ValueError('BasicBlock only supports groups=1 and base_width=64')
        if dilation > 1:
            raise NotImplementedError("Dilation > 1 not supported in BasicBlock")
        if norm_layer is None:
            norm_layer = nn.BatchNorm2d
        # Both conv1 and downsample reduce spatial size when stride != 1.
        self.conv1 = conv3x3(inplanes, planes, stride)
        self.bn1 = norm_layer(planes)
        self.relu = nn.ReLU(inplace=True)
        self.conv2 = conv3x3(planes, planes)
        self.bn2 = norm_layer(planes)
        self.downsample = downsample
        self.stride = stride

    def forward(self, x):
        """Apply the two-stage branch to ``x`` and add the (possibly downsampled) shortcut."""
        y = self.relu(self.bn1(self.conv1(x)))
        y = self.bn2(self.conv2(y))
        shortcut = x
        if self.downsample is not None:
            shortcut = self.downsample(x)
        y += shortcut
        return self.relu(y)
# 2.3 Transfer learning with ResNet modules (使用ResNet模块进行迁移学习)
import torchvision.models as models
import torch.nn as nn
class _CaptchaResNet(nn.Module):
    """A torchvision ResNet backbone whose ``fc`` head is resized for captcha labels.

    The replacement head has ``settings.MAX_CAPTCHA * settings.ALL_CHAR_SET_LEN``
    outputs (presumably one score per character position and candidate symbol --
    NOTE(review): confirm against the loss/decoder).

    Args:
        arch: a torchvision ResNet builder, e.g. ``models.resnet18``.
        pretrained: forwarded to ``arch``. ``False`` (the default) keeps the previous
            behaviour of a randomly initialised backbone; ``True`` loads the builder's
            pretrained weights for transfer learning.
    """

    def __init__(self, arch, pretrained=False):
        super().__init__()
        # NOTE(review): `settings` is not imported in the visible code; it is
        # assumed to be a project config module brought into scope elsewhere.
        self.num_cls = settings.MAX_CAPTCHA * settings.ALL_CHAR_SET_LEN
        self.base = arch(pretrained=pretrained)
        # Swap the backbone's default classifier for one sized to the captcha
        # label space, keeping the backbone's feature width.
        self.base.fc = nn.Linear(self.base.fc.in_features, self.num_cls)

    def forward(self, x):
        """Return the output of the backbone with its resized ``fc`` head."""
        return self.base(x)


# The builders are looked up through `models` (bound by `import torchvision.models
# as models`): that import does NOT bind the name `torchvision`, so the previous
# `torchvision.models.resnetXX(...)` raised NameError on construction.

class RES18(_CaptchaResNet):
    """ResNet-18 backbone with a captcha-sized output head."""

    def __init__(self, pretrained=False):
        super().__init__(models.resnet18, pretrained=pretrained)


class RES34(_CaptchaResNet):
    """ResNet-34 backbone with a captcha-sized output head."""

    def __init__(self, pretrained=False):
        super().__init__(models.resnet34, pretrained=pretrained)


class RES50(_CaptchaResNet):
    """ResNet-50 backbone with a captcha-sized output head."""

    def __init__(self, pretrained=False):
        super().__init__(models.resnet50, pretrained=pretrained)


class RES101(_CaptchaResNet):
    """ResNet-101 backbone with a captcha-sized output head."""

    def __init__(self, pretrained=False):
        super().__init__(models.resnet101, pretrained=pretrained)


class RES152(_CaptchaResNet):
    """ResNet-152 backbone with a captcha-sized output head."""

    def __init__(self, pretrained=False):
        super().__init__(models.resnet152, pretrained=pretrained)