本文使用Keras预训练的模型做图片分类,下载花卉图像数据集的方法参考TensorFlow教程。此数据集共有5种花卉,将花卉图片重新组织为训练集和测试集。
首先利用Keras的ImageDataGenerator函数进行数据增强,然后从文件目录中读取训练集和测试集。
train_aug = ImageDataGenerator(
rescale=1./255,
shear_range=0.2,
zoom_range=0.2,
horizontal_flip=True)
test_aug = ImageDataGenerator(rescale=1./255)
train_generator = train_aug.flow_from_directory(
os.path.join(ROOT_DIR, 'flower_photos/train'),
target_size=(224, 224),
batch_size=32)
validation_generator = test_aug.flow_from_directory(
os.path.join(ROOT_DIR, 'flower_photos/test'),
target_size=(224, 224),
batch_size=32)
使用Keras的VGG16模型进行微调,以便用于新数据集。参考官方文档中的代码。
base_model = VGG16(weights='imagenet', include_top=False)
x = base_model.output
x = GlobalAveragePooling2D()(x)
x = Dense(1024, activation='relu')(x)
# 预测五类
predictions = Dense(5, activation='softmax')(x)
model = Model(inputs=base_model.input, outputs=predictions)
# 冻结原模型的层,仅训练添加后的层
for layer in base_model.layers:
layer.trainable = False
编译并运行模型。
model.compile(optimizer='adam', loss='categorical_crossentropy', metrics=['accuracy'])
model.fit_generator(
train_generator,
steps_per_epoch=len(train_generator),
epochs=10,
validation_data=validation_generator,
validation_steps=len(validation_generator))
参考资料: