04_残差网络
2026/6/8 22:24:23 网站建设 项目流程

描述

残差网络是现代卷积神经网络的一种,有效的抑制了深层神经网络的梯度弥散和梯度爆炸现象,使得深度网络训练不那么困难。

下面以cifar-10-batches-py数据集,实现一个ResNet18的残差网络,通过继承nn.Module实现残差块(Residual Block),网络模型类。

定义Block

ResNetBlock派生至nn.Module,需要自己实现forward函数。

torch.nn.Module是nn中十分重要的类,包含网络各层的定义及forward方法,可以从这个类派生自己的模型类。

nn.Module重要的函数:

  • forward(self,*input):forward函数为前向传播函数,需要自己重写,它用来实现模型的功能,并实现各个层的连接关系;
  • __call__(self, *input, **kwargs): __call__()的作用是使class实例能够像函数一样被调用,以“对象名()”的形式使用;
  • __repr__(self):__repr__函数为Python的一个内置函数,它能把一个对象用字符串的形式表达出来;
  • __init__(self):构造函数,自定义模型的网络层对象一般在这个函数中定义。
classResNetBlock(nn.Module):def__init__(self,input_channels,num_channels,stride=1):''' 构造函数:定义网络层 '''super().__init__()self.conv1=nn.Conv2d(input_channels,num_channels,kernel_size=3,padding=1,stride=stride)self.btn1=nn.BatchNorm2d(num_channels)self.conv2=nn.Conv2d(num_channels,num_channels,kernel_size=3,padding=1,stride=1)self.btn2=nn.BatchNorm2d(num_channels)ifstride!=1:self.downsample=nn.Conv2d(input_channels,num_channels,kernel_size=1,stride=stride)else:self.downsample=lambdax:xdefforward(self,X):''' 实现反向传播 '''Y=self.btn1(self.conv1(X))Y=nn.functional.relu(Y)Y=self.btn2(self.conv2(Y))Y+=self.downsample(X)returnnn.functional.relu(Y)

定义模型

ResNet同样派生于nn.Module,与ResNetBlock类似,需要实现forward。

torch.nn.Sequential是PyTorch 中一个用于构建顺序神经网络模型的容器类,它将多个神经网络层或模块按顺序组合在一起,简化模型搭建过程。‌Sequential器会严格按照添加的顺序执行内部的子模块,前向传播时自动传递数据,适用于简单神经网络的构建。

classResNet(nn.Module):def__init__(self,layer_dism,num_class=10):''' 构造函数:定义预处理model;构建block层 '''super(ResNet,self).__init__()# 预处理self.stem=nn.Sequential(nn.Conv2d(3,64,3,1),# 3x30x30nn.BatchNorm2d(64),nn.ReLU(),nn.MaxPool2d(2,2)# 64x15x15)self.layer1=self.build_resblock(64,64,layer_dism[0])self.layer2=self.build_resblock(64,128,layer_dism[1],2)self.layer3=self.build_resblock(128,256,layer_dism[2],2)self.layer4=self.build_resblock(256,512,layer_dism[3],2)self.avgpool=nn.AvgPool2d(1,1)self.btn=nn.Flatten()self.fc=nn.Linear(512,num_class)defbuild_resblock(self,input_channels,num_channels,block,stride=1):res_block=nn.Sequential()res_block.append(ResNetBlock(input_channels,num_channels,stride))for_inrange(1,block):res_block.append(ResNetBlock(num_channels,num_channels,stride))returnres_blockdefforward(self,X):out=self.stem(X)out=self.layer1(out)out=self.layer2(out)out=self.layer3(out)out=self.layer4(out)out=self.avgpool(out)returnself.fc(self.btn(out))

模型训练

加载数据

使用torchvision.datasets加载本地数据,如果本地没有数据,可以设置download=True自动下载。

# 定义数据转换transform=transforms.Compose([transforms.ToTensor(),# 将PIL图像转换为Tensortransforms.Normalize((0.5,0.5,0.5),(0.5,0.5,0.5))# 归一化])# 加载CIFAR-10训练集trainset=torchvision.datasets.CIFAR10(root=r'D:\dwload',train=True,download=False,transform=transform)trainloader=th.utils.data.DataLoader(trainset,batch_size=16,shuffle=False,num_workers=2)# 加载CIFAR-10测试集testset=torchvision.datasets.CIFAR10(root=r'D:\dwload',train=False,download=False,transform=transform)testloader=th.utils.data.DataLoader(testset,batch_size=16,shuffle=False,num_workers=2)

模型初始化

模型初始化是确保网络能够有效学习的关键步骤,一个好的初始值,会使模型收敛速度提高,使模型准确率更精确。

torch.nn.init模块提供了一系列的权重初始化函数:

  • torch.nn.init.uniform_ :均匀分布
  • torch.nn.init.normal_ :正态分布
  • torch.nn.init.constant_:初始化为指定常数
  • torch.nn.init.kaiming_uniform_:凯明均匀分布
  • torch.nn.init.kaiming_normal_:凯明正态分布
  • torch.nn.init.xavier_uniform_:Xavier均匀分布
  • torch.nn.init.xavier_normal_:Xavier正态分布

在初始化时,最好不要将模型的参数初始化为0,因为这样会导致梯度消失,进而影响训练效果。可以将模型初始化为一个很小的值,如0.01,0.001等。

definitialize_weight(m):ifisinstance(m,nn.Conv2d)orisinstance(m,nn.Linear):nn.init.kaiming_normal_(m.weight,mode='fan_out',nonlinearity='relu')# mode:权重方差计算方式,可选 'fan_in' 或 'fan_out'(输入、输出神经元数量)# nonlinearity:激活函数类型,用于调整计算公式 ,一般是relu、leaky_reluifm.biasisnotNone:nn.init.constant_(m.bias,0)

[2,2,2,2] 参数分别代表四个block的中的残差块数量(可以仔细看一下build_resblock函数)

resnet_18=ResNet([2,2,2,2])resnet_18.apply(initialize_weight)# 初始化模型loss_cross=nn.CrossEntropyLoss()trainer=th.optim.SGD(resnet_18.parameters())

训练

训练过程比较漫长,这里训练只有20轮,测试精度0.51。如果有N卡加持的话,可以适当调高epoch,精度能进一步提高。

forepochinrange(0,20):running_loss=0.0forinputs,labelsintrainloader:trainer.zero_grad()outputs=resnet_18(inputs)loss=loss_cross(outputs,labels)loss.backward()trainer.step()running_loss+=loss.item()print(f'[{epoch+1}] ev loss:{running_loss/3125}')running_loss=0.0

需要专业的网站建设服务?

联系我们获取免费的网站建设咨询和方案报价,让我们帮助您实现业务目标

立即咨询