PyBrain - 导入数据集数据

在本章中,我们将学习如何获取数据以使用 Pybrain 数据集。

最常用的数据集是 −

  • 使用 sklearn
  • 来自 CSV 文件

使用 sklearn

使用 sklearn

以下是包含来自 sklearn 的数据集详细信息的链接:https://scikit-learn.org/stable/datasets/toy_dataset.html

以下是一些如何使用来自 sklearn 的数据集的示例−

示例 1:load_digits()

from sklearn import datasets
from pybrain.datasets import ClassificationDataSet
digits = datasets.load_digits()
X, y = digits.data, digits.target
ds = ClassificationDataSet(64, 1, nb_classes=10)
for i in range(len(X)):
ds.addSample(ravel(X[i]), y[i])

示例 2: load_iris()

from sklearn import datasets
from pybrain.datasets import ClassificationDataSet
digits = datasets.load_iris()
X, y = digits.data, digits.target
ds = ClassificationDataSet(4, 1, nb_classes=3)
for i in range(len(X)):
ds.addSample(X[i], y[i])

来自 CSV 文件

我们也可以使用来自 csv 文件的数据,如下所示 −

这是 xor 真值表的示例数据:datasettest.csv

CSV 文件

这是从 .csv 文件中读取数据集数据的工作示例。

示例

from pybrain.tools.shortcuts import buildNetwork
from pybrain.structure import TanhLayer
from pybrain.datasets import SupervisedDataSet
from pybrain.supervised.trainers import BackpropTrainer
import pandas as pd

print('Read data...')
df = pd.read_csv('data/datasettest.csv',header=0).head(1000)
data = df.values

train_output = data[:,0]
train_data = data[:,1:]

print(train_output)
print(train_data)

# 创建一个具有两个输入、三个隐藏和一个输出的网络
nn = buildNetwork(2, 3, 1, bias=True, hiddenclass=TanhLayer)

# 创建一个与网络输入和输出大小匹配的数据集:
_gate = SupervisedDataSet(2, 1)

# 创建一个用于测试的数据集。
nortrain = SupervisedDataSet(2, 1)

# 将输入和目标值添加到数据集
# NOR 真值表的值
for i in range(0, len(train_output)) :
_gate.addSample(train_data[i], train_output[i])

# 使用数据集 norgate 训练网络。
trainer = BackpropTrainer(nn, _gate)

# 将运行循环 1000 次来训练它。
for epoch in range(1000):
	trainer.train()
trainer.testOnData(dataset=_gate, verbose = True)

如示例所示,使用 Panda 从 csv 文件读取数据。

输出

C:\pybrain\pybrain\src>python testcsv.py
Read data...
[0 1 1 0]
[
   [0 0]
   [0 1]
   [1 0]
   [1 1]
]
Testing on data:
('out: ', '[0.004 ]')
('correct:', '[0 ]')
error: 0.00000795
('out: ', '[0.997 ]')
('correct:', '[1 ]')
error: 0.00000380
('out: ', '[0.996 ]')
('correct:', '[1 ]')
error: 0.00000826
('out: ', '[0.004 ]')
('correct:', '[0 ]')
error: 0.00000829
('All errors:', [7.94733477723902e-06, 3.798267582566822e-06, 8.260969076585322e
-06, 8.286246525558165e-06])
('Average error:', 7.073204490487332e-06)
('Max error:', 8.286246525558165e-06, 'Median error:', 8.260969076585322e-06)