keras 模型训练 + 将 .h5 模型转换为可部署的 .pb 文件

核对环境

- tensorflow 版本: 2.0.1
- keras 版本: 2.3.1（注：下文代码实际使用的是 tensorflow 自带的 tf.keras）

数据集文件夹结构
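（原文此处为目录结构截图，图片已缺失。根据下文代码，数据集目录结构应为：）

    ~/dataset/my_dataset/
    ├── train/
    │   ├── rose/
    │   └── sunflowers/
    └── validation/
        ├── rose/
        └── sunflowers/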
模型训练

1# -*- coding: utf-8 -*- 2""" 3Created on Sat Jun 6 14:28:50 2020 4 5@author: USER 6""" 7from __future__ import absolute_import,division,print_function,unicode_literals 8import tensorflow as tf 9from tensorflow.keras.preprocessing.image import ImageDataGenerator 10from tensorflow.keras import layers 11import os 12import matplotlib.pyplot as plt 13import pathlib 14 15 16data_dir = '~/dataset/my_dataset/' 17PATH = pathlib.Path(data_dir) 18 19train_dir = os.path.join(PATH,'train') 20validation_dir = os.path.join(PATH,'validation') 21train_rose_dir = os.path.join(train_dir,'rose') 22train_sunflowers_dir = os.path.join(train_dir,'sunflowers') 23validation_rose_dir = os.path.join(validation_dir,'rose') 24validation_sunflowers_dir = os.path.join(validation_dir,'sunflowers') 25 26num_rose_tr = len(os.listdir(train_rose_dir)) 27num_sunflowers_tr = len(os.listdir(train_sunflowers_dir)) 28num_rose_val = len(os.listdir(validation_rose_dir)) 29num_sunflowers_val = len(os.listdir(validation_sunflowers_dir)) 30total_train = num_rose_tr + num_sunflowers_tr 31total_val = num_rose_val + num_sunflowers_val 32 33batch_size = 128 #batch数量 34epochs = 5 #训练次数 35IMG_HEIGHT = 150 #图片高 36IMG_WIDTH = 150 #图片宽 37 38train_image_generator = ImageDataGenerator(rescale=1./255, # 归一化 39 horizontal_flip=True, # 图片翻转 40 width_shift_range=.15, # 宽变化 41 height_shift_range=.15, # 高变化 42 rotation_range=45, # 旋转45度 43 zoom_range=0.5 # 缩放0.5倍 44 ) 45valiadation_image_generator = ImageDataGenerator(rescale=1./255) # 验证集不进行augmentation处理 46 47train_data_gen = train_image_generator.flow_from_directory(batch_size=batch_size, 48 directory=train_dir, # 训练集路径 49 shuffle=True, # 打乱图片顺序 50 target_size=(IMG_HEIGHT,IMG_WIDTH),# 修改图片尺寸 51 class_mode='binary') 52val_data_gen = valiadation_image_generator.flow_from_directory(batch_size=batch_size, 53 directory=validation_dir, # 验证集路径 54 target_size=(IMG_HEIGHT,IMG_WIDTH), # 修改图片尺寸 55 class_mode='binary') 56 57def plotImages(images_arr): 58 fig,axes = 
plt.subplots(1,5,figsize=(20,20)) 59 axes = axes.flatten() 60 for img,ax in zip(images_arr,axes): 61 ax.imshow(img) 62 ax.axis('off') 63 plt.tight_layout() 64 plt.show() 65 66augemted_images=[train_data_gen[0][0][0] for i in range(5)] 67plotImages(augemted_images) 68 69model = tf.keras.models.Sequential([ 70 layers.Conv2D(16,3,padding='same',activation='relu',input_shape =(IMG_HEIGHT,IMG_WIDTH,3)), #16为filter个数 3为kernel_size 71 layers.MaxPooling2D(), 72 layers.Dropout(0.2), # 防过拟合 73 layers.Conv2D(32,3,padding='same',activation='relu'), 74 layers.MaxPooling2D(), 75 layers.Conv2D(64,3,padding='same',activation='relu'), 76 layers.MaxPooling2D(), 77 layers.Dropout(0.2), 78 layers.Flatten(), 79 layers.Dense(512,activation='relu'), 80 layers.Dense(1,activation='sigmoid') # 2分类,因此1个神经元 81]) 82 83model.compile(optimizer='adam',loss='binary_crossentropy',metrics=['accuracy']) 84model.summary() 85history = model.fit_generator( 86 train_data_gen, # 训练集 87 steps_per_epoch=3, # 每个epoch训练batchsize的个数 88 epochs=epochs, 89 validation_data=val_data_gen,# 验证集 90 validation_steps=3, 91) 92 93def show_train_history(train_history,train,validation): 94 plt.plot(train_history.history[train]) # 绘制训练数据的执行结果 95 plt.plot(train_history.history[validation]) # 绘制验证数据的执行结果 96 plt.title('Train History') # 图标题 97 plt.xlabel('epoch') # x轴标签 98 plt.ylabel(train) # y轴标签 99 plt.legend(['train','validation'],loc='upper left') # 添加左上角图例 100 101model.evaluate(train_data_gen) 102 103model.save("F:/spyder_project/lidongdong/lab/model/my_model.h5") 104 105

将模型文件 .h5 转换为 .pb

import os

import tensorflow as tf
from tensorflow.python.framework.convert_to_constants import convert_variables_to_constants_v2


def convert_h5to_pb(h5_path="F:/spyder_project/lidongdong/lab/model/my_model.h5",
                    logdir="~/model",
                    pb_name="mnist.pb"):
    """Freeze a Keras .h5 model into a single binary GraphDef (.pb) file.

    Args:
        h5_path: path of the Keras model saved with ``model.save(...)``.
        logdir: output directory; a leading "~" is expanded to the user's home.
        pb_name: output file name. The default "mnist.pb" is kept for
            compatibility; NOTE(review): it looks like a leftover from an MNIST
            example and may deserve a more descriptive name.

    Returns:
        Path of the written .pb file, as returned by ``tf.io.write_graph``.
    """
    # compile=False: only the forward graph is needed, not optimizer state.
    model = tf.keras.models.load_model(h5_path, compile=False)
    model.summary()

    # Trace the model once with its own input signature to get a
    # ConcreteFunction. The lambda argument is deliberately called "Input":
    # it becomes the name of the frozen graph's input placeholder, which
    # deployment code may reference by name.
    full_model = tf.function(lambda Input: model(Input))
    full_model = full_model.get_concrete_function(
        tf.TensorSpec(model.inputs[0].shape, model.inputs[0].dtype))

    # Inline variable values as constants so the graph is self-contained.
    frozen_func = convert_variables_to_constants_v2(full_model)
    graph_def = frozen_func.graph.as_graph_def()

    print("-" * 50)
    print("Frozen model layers: ")
    for op in frozen_func.graph.get_operations():
        print(op.name)

    print("-" * 50)
    print("Frozen model inputs: ")
    print(frozen_func.inputs)
    print("Frozen model outputs: ")
    print(frozen_func.outputs)

    # tf.io.write_graph does not expand "~"; without expanduser() it would
    # create a directory literally named "~" in the working directory.
    pb_path = tf.io.write_graph(graph_or_graph_def=graph_def,
                                logdir=os.path.expanduser(logdir),
                                name=pb_name,
                                as_text=False)
    print("Frozen graph written to:", pb_path)
    return pb_path


if __name__ == "__main__":
    convert_h5to_pb()

代码交流 2021