PyBrain - 导入数据集数据
在本章中,我们将学习如何获取数据以使用 Pybrain 数据集。
最常用的数据集是 −
- 使用 sklearn
- 来自 CSV 文件
使用 sklearn
使用 sklearn
以下是包含来自 sklearn 的数据集详细信息的链接:https://scikit-learn.org/stable/datasets/toy_dataset.html
以下是一些如何使用来自 sklearn 的数据集的示例−
示例 1:load_digits()
from sklearn import datasets from pybrain.datasets import ClassificationDataSet digits = datasets.load_digits() X, y = digits.data, digits.target ds = ClassificationDataSet(64, 1, nb_classes=10) for i in range(len(X)): ds.addSample(ravel(X[i]), y[i])
示例 2: load_iris()
from sklearn import datasets from pybrain.datasets import ClassificationDataSet digits = datasets.load_iris() X, y = digits.data, digits.target ds = ClassificationDataSet(4, 1, nb_classes=3) for i in range(len(X)): ds.addSample(X[i], y[i])
来自 CSV 文件
我们也可以使用来自 csv 文件的数据,如下所示 −
这是 xor 真值表的示例数据:datasettest.csv
这是从 .csv 文件中读取数据集数据的工作示例。
示例
from pybrain.tools.shortcuts import buildNetwork from pybrain.structure import TanhLayer from pybrain.datasets import SupervisedDataSet from pybrain.supervised.trainers import BackpropTrainer import pandas as pd print('Read data...') df = pd.read_csv('data/datasettest.csv',header=0).head(1000) data = df.values train_output = data[:,0] train_data = data[:,1:] print(train_output) print(train_data) # 创建一个具有两个输入、三个隐藏和一个输出的网络 nn = buildNetwork(2, 3, 1, bias=True, hiddenclass=TanhLayer) # 创建一个与网络输入和输出大小匹配的数据集: _gate = SupervisedDataSet(2, 1) # 创建一个用于测试的数据集。 nortrain = SupervisedDataSet(2, 1) # 将输入和目标值添加到数据集 # NOR 真值表的值 for i in range(0, len(train_output)) : _gate.addSample(train_data[i], train_output[i]) # 使用数据集 norgate 训练网络。 trainer = BackpropTrainer(nn, _gate) # 将运行循环 1000 次来训练它。 for epoch in range(1000): trainer.train() trainer.testOnData(dataset=_gate, verbose = True)
如示例所示,使用 Panda 从 csv 文件读取数据。
输出
C:\pybrain\pybrain\src>python testcsv.py Read data... [0 1 1 0] [ [0 0] [0 1] [1 0] [1 1] ] Testing on data: ('out: ', '[0.004 ]') ('correct:', '[0 ]') error: 0.00000795 ('out: ', '[0.997 ]') ('correct:', '[1 ]') error: 0.00000380 ('out: ', '[0.996 ]') ('correct:', '[1 ]') error: 0.00000826 ('out: ', '[0.004 ]') ('correct:', '[0 ]') error: 0.00000829 ('All errors:', [7.94733477723902e-06, 3.798267582566822e-06, 8.260969076585322e -06, 8.286246525558165e-06]) ('Average error:', 7.073204490487332e-06) ('Max error:', 8.286246525558165e-06, 'Median error:', 8.260969076585322e-06)