示例 2 训练

训练函数

async function trainModel(model, inputs, labels, surface) {
  const batchSize = 25;
  const epochs = 100;
  const callbacks = tfvis.show.fitCallbacks(surface, ['loss'], {callbacks:['onEpochEnd']})
  return await model.fit(inputs, labels,
    {batchSize, epochs, shuffle:true, callbacks:callbacks}
  );
}

亲自试一试 »

epochs 定义模型将执行多少次迭代(循环)。

model.fit 是运行循环的函数。

callbacks定义了模型要重绘图形时要调用的回调函数。


测试模型

在训练模型时,测试和评估它很重要。

我们通过检查模型对一系列不同输入的预测来做到这一点。

但是,在我们这样做之前,我们必须对数据进行非标准化:

取消标准化

let unX = tf.linspace(0, 1, 100);
let unY = model.predict(unX.reshape([100, 1]));

const unNormunX = unX.mul(inputMax.sub(inputMin)).add(inputMin);
const unNormunY = unY.mul(labelMax.sub(labelMin)).add(labelMin);

unX = unNormunX.dataSync();
unY = unNormunY.dataSync();

然后我们可以看看结果:

绘制结果

const predicted = Array.from(unX).map((val, i) => {
return {x: val, y: unY[i]}
});

// 绘制结果
tfPlot([values, predicted], surface1)

亲自试一试 »