PyBrain - 使用循环网络

循环网络与前馈网络相同,唯一的区别是您需要记住每一步的数据。必须保存每一步的历史记录。

我们将学习如何−

  • 创建循环网络
  • 添加模块和连接

创建循环网络

要创建循环网络,我们将使用 RecurrentNetwork 类,如下所示 −

rn.py

from pybrain.structure import RecurrentNetwork
recurrentn = RecurrentNetwork()
print(recurrentn)

python rn.py

C:\pybrain\pybrain\src>python rn.py
RecurrentNetwork-0
Modules:
[]
Connections:
[]
Recurrent Connections:
[]

我们可以看到循环网络有一个名为"循环连接"的新连接。目前没有可用数据。

现在让我们创建层并添加到模块并创建连接。

添加模块和连接

我们将创建层,即输入、隐藏和输出。这些层将添加到输入和输出模块。接下来,我们将创建输入到隐藏、隐藏到输出的连接以及隐藏到隐藏之间的循环连接。

以下是带有模块和连接的循环网络的代码。

rn.py

from pybrain.structure import RecurrentNetwork
from pybrain.structure import LinearLayer, SigmoidLayer
from pybrain.structure import FullConnection
recurrentn = RecurrentNetwork()

#为输入创建层 => 2 , hidden=> 3 和输出=>1
inputLayer = LinearLayer(2, 'rn_in')
hiddenLayer = SigmoidLayer(3, 'rn_hidden')
outputLayer = LinearLayer(1, 'rn_output')

#将层添加到前馈网络
recurrentn.addInputModule(inputLayer)
recurrentn.addModule(hiddenLayer)
recurrentn.addOutputModule(outputLayer)

#在输入、隐藏和输出之间创建连接
input_to_hidden = FullConnection(inputLayer, hiddenLayer)
hidden_​​to_output = FullConnection(hiddenLayer, outputLayer)
hidden_​​to_hidden = FullConnection(hiddenLayer, hiddenLayer)

#将连接添加到网络
recurrentn.addConnection(input_to_hidden)
recurrentn.addConnection(hidden_​​to_output)
recurrentn.addRecurrentConnection(hidden_​​to_hidden)
recurrentn.sortModules()

print(recurrentn)

python rn.py

C:\pybrain\pybrain\src>python rn.py
RecurrentNetwork-6
Modules:
[<LinearLayer 'rn_in'>, <SigmoidLayer 'rn_hidden'>, 
   <LinearLayer 'rn_output'>]
Connections:
[<FullConnection 'FullConnection-4': 'rn_hidden' -> 'rn_output'>, 
   <FullConnection 'FullConnection-5': 'rn_in' -> 'rn_hidden'>]
Recurrent Connections:
[<FullConnection 'FullConnection-3': 'rn_hidden' -> 'rn_hidden'>]

在上面的输出中,我们可以看到模块、连接和循环连接。

现在让我们使用 activate 方法激活网络,如下所示 −

rn.py

将以下代码添加到之前创建的代码中 −

#使用 activate() 方法激活网络
act1 = recurrentn.activate((2, 2))
print(act1)

act2 = recurrentn.activate((2, 2))
print(act2)

python rn.py

C:\pybrain\pybrain\src>python rn.py
[-1.24317586]
[-0.54117783]