本文使用Keras预训练的模型做图片分类,下载花卉图像数据集的方法参考TensorFlow教程。此数据集共有5种花卉,将花卉图片重新组织为训练集和测试集。

首先利用Keras的ImageDataGenerator函数进行数据增强,然后从文件目录中读取训练集和测试集。

train_aug = ImageDataGenerator(
        rescale=1./255,
        shear_range=0.2,
        zoom_range=0.2,
        horizontal_flip=True)

test_aug = ImageDataGenerator(rescale=1./255)

train_generator = train_aug.flow_from_directory(
        os.path.join(ROOT_DIR, 'flower_photos/train'),
        target_size=(224, 224),
        batch_size=32)

validation_generator = test_aug.flow_from_directory(
        os.path.join(ROOT_DIR, 'flower_photos/test'),
        target_size=(224, 224),
        batch_size=32)

使用Keras的VGG16模型进行微调,以便用于新数据集。参考官方文档中的代码。

base_model = VGG16(weights='imagenet', include_top=False)

x = base_model.output
x = GlobalAveragePooling2D()(x)
x = Dense(1024, activation='relu')(x)
# 预测五类
predictions = Dense(5, activation='softmax')(x)

model = Model(inputs=base_model.input, outputs=predictions)
# 冻结原模型的层,仅训练添加后的层
for layer in base_model.layers:
    layer.trainable = False

编译并运行模型。

model.compile(optimizer='adam', loss='categorical_crossentropy', metrics=['accuracy'])
model.fit_generator(
        train_generator,
        steps_per_epoch=len(train_generator),
        epochs=10,
        validation_data=validation_generator,
        validation_steps=len(validation_generator))

参考资料:

如何使用Keras fit和fit_generator(动手教程)