如何使用 Tensorflow 配置花卉数据集以提高性能?

pythonserver side programmingprogrammingtensorflow

创建模型时,花卉数据集会给出一定百分比的准确度。如果需要配置模型以提高性能,则定义一个函数,该函数第二次执行缓冲区预取,然后对其进行混洗。在训练数据集上调用此函数以提高模型的性能。

阅读更多: 什么是 TensorFlow,以及 Keras 如何与 TensorFlow 配合使用以创建神经网络?

我们将使用花卉数据集,其中包含数千朵花的图像。它包含 5 个子目录,每个类都有一个子目录。

我们使用 Google Colaboratory 来运行以下代码。Google Colab 或 Colaboratory 有助于在浏览器上运行 Python 代码,并且不需要任何配置,并且可以免费访问 GPU(图形处理单元)。Colaboratory 是在 Jupyter Notebook 之上构建的。

print("A function is defined that configures the dataset for perfromance")
def configure_for_performance(ds):
   ds = ds.cache()
   ds = ds.shuffle(buffer_size=1000)
   ds = ds.batch(batch_size)
   ds = ds.prefetch(buffer_size=AUTOTUNE)
   return ds

print("The function is called on training dataset")
train_ds = configure_for_performance(train_ds)
print("The function is called on validation dataset")
val_ds = configure_for_performance(val_ds)

代码来源:https://www.tensorflow.org/tutorials/load_data/images

输出

A function is defined that configures the dataset for perfromance
The function is called on training dataset
The function is called on validation dataset

解释

  • 需要使用数据集训练模型。
  • 首先对模型进行良好的打乱,然后进行分批,然后提供这些批次。
  • 这些功能是使用"tf.data"API添加的。

相关文章