如何使用 Tensorflow 配置花卉数据集以提高性能?
pythonserver side programmingprogrammingtensorflow
创建模型时,花卉数据集会给出一定百分比的准确度。如果需要配置模型以提高性能,则定义一个函数,该函数第二次执行缓冲区预取,然后对其进行混洗。在训练数据集上调用此函数以提高模型的性能。
阅读更多: 什么是 TensorFlow,以及 Keras 如何与 TensorFlow 配合使用以创建神经网络?
我们将使用花卉数据集,其中包含数千朵花的图像。它包含 5 个子目录,每个类都有一个子目录。
我们使用 Google Colaboratory 来运行以下代码。Google Colab 或 Colaboratory 有助于在浏览器上运行 Python 代码,并且不需要任何配置,并且可以免费访问 GPU(图形处理单元)。Colaboratory 是在 Jupyter Notebook 之上构建的。
print("A function is defined that configures the dataset for perfromance") def configure_for_performance(ds): ds = ds.cache() ds = ds.shuffle(buffer_size=1000) ds = ds.batch(batch_size) ds = ds.prefetch(buffer_size=AUTOTUNE) return ds print("The function is called on training dataset") train_ds = configure_for_performance(train_ds) print("The function is called on validation dataset") val_ds = configure_for_performance(val_ds)
代码来源:https://www.tensorflow.org/tutorials/load_data/images
输出
A function is defined that configures the dataset for perfromance The function is called on training dataset The function is called on validation dataset
解释
- 需要使用数据集训练模型。
- 首先对模型进行良好的打乱,然后进行分批,然后提供这些批次。
- 这些功能是使用"tf.data"API添加的。