最新要闻

广告

手机

iphone11大小尺寸是多少?苹果iPhone11和iPhone13的区别是什么?

iphone11大小尺寸是多少?苹果iPhone11和iPhone13的区别是什么?

警方通报辅警执法直播中被撞飞:犯罪嫌疑人已投案

警方通报辅警执法直播中被撞飞:犯罪嫌疑人已投案

家电

环球视讯!21、现有网络模型的使用以及修改

来源:博客园


(资料图片仅供参考)

1、网络模型在pytorch里面的torchvision里面torchvision.models,是关于图像类的网络模型

2、简单以一个分类模型为例子: VGG(最常用的是VGG16和VGG19)

pretrained:

如果是true的话,说明在ImageNet数据集上,模型的参数是都训练好的; 如果是False的话,说明模型的参数是初始化的,没有训练好。

vgg16_false=torchvision.models.vgg16(pretrained=False)      #当pretrained为 False的时候只是加载网络模型。是不需要对网络模型的参数进行下载的vgg16_true=torchvision.models.vgg16(pretrained=True)   #pretrained=True时,需要下载网络模型,下载模型里的参数print(vgg16_true)

progress:

如果是True,显示下载进度条; False则不显示

3、ImageNet数据集:

4、修改现有模型

train_data=torchvision.datasets.CIFAR10("../../dataset/CIFAR10",train=False,                                        transform=torchvision.transforms.ToTensor(),download=True)"""如何利用现有的网络模型,去改动它的结构;比如说想让VGG是10分类任务,也就是让输出特征是10;可以有两种"""#1、再添加一个线性层vgg16_true.add_module("add_linear",nn.Linear(in_features=1000,out_features=10))#add_module()里面两个参数,一个是字符串型,给要加的模块起个名字,第二个是要加的模块,可以直接是一层网络,也可以是一个序列
print(vgg16_true)

输出:

# 2、如果想在序列里面添加可以这样网络模型.想要加的位置.add_moudle()vgg16_true.classifier.add_module("add_linear",nn.Linear(in_features=1000,out_features=10))print(vgg16_true)
# 3、不想添加的话,可以进行修改#对模型中的classifier中的第6层进行修改vgg16_false.classifier[6]=nn.Linear(4096,10)

关键词: 网络模型 进行修改 可以进行