我的编程空间,编程开发者的网络收藏夹
学习永远不晚

python之tensorflow手把手实例讲解猫狗识别实现

短信预约 -IT技能 免费直播动态提醒
省份

北京

  • 北京
  • 上海
  • 天津
  • 重庆
  • 河北
  • 山东
  • 辽宁
  • 黑龙江
  • 吉林
  • 甘肃
  • 青海
  • 河南
  • 江苏
  • 湖北
  • 湖南
  • 江西
  • 浙江
  • 广东
  • 云南
  • 福建
  • 海南
  • 山西
  • 四川
  • 陕西
  • 贵州
  • 安徽
  • 广西
  • 内蒙
  • 西藏
  • 新疆
  • 宁夏
  • 兵团
手机号立即预约

请填写图片验证码后获取短信验证码

看不清楚,换张图片

免费获取短信验证码

python之tensorflow手把手实例讲解猫狗识别实现

作为tensorflow初学的大三学生,本次课程作业的使用猫狗数据集做一个二分类模型。

一,猫狗数据集数目构成

train cats:1000 ,dogs:1000
test cats: 500,dogs:500
validation cats:500,dogs:500

二,数据导入


train_dir = 'Data/train'
test_dir = 'Data/test'
validation_dir = 'Data/validation'
train_datagen = ImageDataGenerator(rescale=1/255,
                                   rotation_range=10,
                                   width_shift_range=0.2,  #图片水平偏移的角度
                                   height_shift_range=0.2,  #图片数值偏移的角度
                                   shear_range=0.2,  #剪切强度 
                                   zoom_range=0.2,   #随机缩放的幅度
                                   horizontal_flip=True,   #是否进行随机水平翻转
#                                    fill_mode='nearest'
                                  )
train_generator = train_datagen.flow_from_directory(train_dir,
                 (224,224),batch_size=1,class_mode='binary',shuffle=False)
test_datagen = ImageDataGenerator(rescale=1/255)
test_generator = test_datagen.flow_from_directory(test_dir,
                 (224,224),batch_size=1,class_mode='binary',shuffle=True)
validation_datagen = ImageDataGenerator(rescale=1/255)
validation_generator = validation_datagen.flow_from_directory(
                validation_dir,(224,224),batch_size=1,class_mode='binary')
print(train_datagen)
print(test_datagen)
print(train_datagen)

三,数据集构建

我这里是将ImageDataGenerator类里的数据提取出来,将数据与标签分别存放在两个列表,后面在转为np.array,也可以使用model.fit_generator,我将数据放在内存为了后续调参数时模型训练能更快读取到数据,不用每次训练一整轮都去读一次数据(应该是这样的…我是这样理解…)
注意我这里的数据集构建后,三种数据都是存放在内存中的,我电脑内存是16g的可以存放下。


train_data=[]
train_labels=[]
a=0
for data_train, labels_train in train_generator:
    train_data.append(data_train)
    train_labels.append(labels_train)
    a=a+1
    if a>1999:
        break
x_train=np.array(train_data)
y_train=np.array(train_labels)
x_train=x_train.reshape(2000,224,224,3)

test_data=[]
test_labels=[]
a=0
for data_test, labels_test in test_generator:
    test_data.append(data_test)
    test_labels.append(labels_test)
    a=a+1
    if a>999:
        break
x_test=np.array(test_data)
y_test=np.array(test_labels)
x_test=x_test.reshape(1000,224,224,3)

validation_data=[]
validation_labels=[]
a=0
for data_validation, labels_validation in validation_generator:
    validation_data.append(data_validation)
    validation_labels.append(labels_validation)
    a=a+1
    if a>999:
        break
x_validation=np.array(validation_data)
y_validation=np.array(validation_labels)
x_validation=x_validation.reshape(1000,224,224,3)

四,模型搭建


model1 = tf.keras.models.Sequential([
    # 第一层卷积,卷积核为,共16个,输入为150*150*1
    tf.keras.layers.Conv2D(16,(3,3),activation='relu',padding='same',input_shape=(224,224,3)),
    tf.keras.layers.MaxPooling2D((2,2)),
    
    # 第二层卷积,卷积核为3*3,共32个,
    tf.keras.layers.Conv2D(32,(3,3),activation='relu',padding='same'),
    tf.keras.layers.MaxPooling2D((2,2)),
    
    # 第三层卷积,卷积核为3*3,共64个,
    tf.keras.layers.Conv2D(64,(3,3),activation='relu',padding='same'),
    tf.keras.layers.MaxPooling2D((2,2)),
    
    # 数据铺平
    tf.keras.layers.Flatten(),
    tf.keras.layers.Dense(64,activation='relu'),
    tf.keras.layers.Dropout(0.5),
    tf.keras.layers.Dense(1,activation='sigmoid')
])
print(model1.summary())

模型summary:

在这里插入图片描述

五,模型训练


model1.compile(optimize=tf.keras.optimizers.SGD(0.00001),
             loss=tf.keras.losses.binary_crossentropy,
             metrics=['acc'])
history1=model1.fit(x_train,y_train,
# 					validation_split=(0~1)   选择一定的比例用于验证集,可被validation_data覆盖
                  validation_data=(x_validation,y_validation),
                  batch_size=10,
                  shuffle=True,
                  epochs=10)
model1.save('cats_and_dogs_plain1.h5')
print(history1)

在这里插入图片描述


plt.plot(history1.epoch,history1.history.get('acc'),label='acc')
plt.plot(history1.epoch,history1.history.get('val_acc'),label='val_acc')
plt.title('正确率')
plt.legend()

在这里插入图片描述

可以看到我们的模型泛化能力还是有点差,测试集的acc能达到0.85以上,验证集却在0.65~0.70之前跳动。

六,模型测试


model1.evaluate(x_validation,y_validation)

在这里插入图片描述

最后我们的模型在测试集上的正确率为0.67,可以说还不够好,有点过拟合,可能是训练数据不够多,后续可以数据增广或者从验证集、测试集中调取一部分数据用于训练模型,可能效果好一些。

到此这篇关于python之tensorflow手把手实例讲解猫狗识别实现的文章就介绍到这了,更多相关python tensorflow 猫狗识别内容请搜索编程网以前的文章或继续浏览下面的相关文章希望大家以后多多支持编程网!

免责声明:

① 本站未注明“稿件来源”的信息均来自网络整理。其文字、图片和音视频稿件的所属权归原作者所有。本站收集整理出于非商业性的教育和科研之目的,并不意味着本站赞同其观点或证实其内容的真实性。仅作为临时的测试数据,供内部测试之用。本站并未授权任何人以任何方式主动获取本站任何信息。

② 本站未注明“稿件来源”的临时测试数据将在测试完成后最终做删除处理。有问题或投稿请发送至: 邮箱/279061341@qq.com QQ/279061341

python之tensorflow手把手实例讲解猫狗识别实现

下载Word文档到电脑,方便收藏和打印~

下载Word文档

猜你喜欢

Python实现OCR识别之pytesseract案例详解

Python实现OCR识别:pytesseract Python常用pytesseract进行图片上的文字识别,即OCR识别,完整的代码比较简单,只要下面一行即可,但是实际使用时环境配置上容易出错。from PIL import Image
2022-06-02

Java实现BP神经网络MNIST手写数字识别的示例详解

这篇文章主要为大家详细介绍了Java实现BP神经网络MNIST手写数字识别的相关方法,文中的示例代码讲解详细,感兴趣的小伙伴可以跟随小编一起了解一下
2023-01-31

编程热搜

  • Python 学习之路 - Python
    一、安装Python34Windows在Python官网(https://www.python.org/downloads/)下载安装包并安装。Python的默认安装路径是:C:\Python34配置环境变量:【右键计算机】--》【属性】-
    Python 学习之路 - Python
  • chatgpt的中文全称是什么
    chatgpt的中文全称是生成型预训练变换模型。ChatGPT是什么ChatGPT是美国人工智能研究实验室OpenAI开发的一种全新聊天机器人模型,它能够通过学习和理解人类的语言来进行对话,还能根据聊天的上下文进行互动,并协助人类完成一系列
    chatgpt的中文全称是什么
  • C/C++中extern函数使用详解
  • C/C++可变参数的使用
    可变参数的使用方法远远不止以下几种,不过在C,C++中使用可变参数时要小心,在使用printf()等函数时传入的参数个数一定不能比前面的格式化字符串中的’%’符号个数少,否则会产生访问越界,运气不好的话还会导致程序崩溃
    C/C++可变参数的使用
  • css样式文件该放在哪里
  • php中数组下标必须是连续的吗
  • Python 3 教程
    Python 3 教程 Python 的 3.0 版本,常被称为 Python 3000,或简称 Py3k。相对于 Python 的早期版本,这是一个较大的升级。为了不带入过多的累赘,Python 3.0 在设计的时候没有考虑向下兼容。 Python
    Python 3 教程
  • Python pip包管理
    一、前言    在Python中, 安装第三方模块是通过 setuptools 这个工具完成的。 Python有两个封装了 setuptools的包管理工具: easy_install  和  pip , 目前官方推荐使用 pip。    
    Python pip包管理
  • ubuntu如何重新编译内核
  • 改善Java代码之慎用java动态编译

目录