2.3.4 自定义神经网络模型

神经网络层通过Cell的子类(SubClass)实现,同样地,神经网络模型也可以采用SubClass方法自定义神经网络模型;构建时需要在__init__方法中将要使用的神经网络组件实例化,在__call__方法中定义神经网络的计算逻辑。同样地,以2.3.1节的卷积神经网络模型为例,定义接口可用伪代码描述,如代码2.14所示。

代码2.14 自定义神经网络模型

对上述卷积模型进行实例化,其执行过程将从__init__方法开始,第一个是Conv2D,Conv2D也是Cell的子类,会进入Conv2D的__init__方法,此时会将第一个Conv2D的卷积参数收集到self._params中,之后回到Conv2D,将第一个Conv2D收集到self._cells;第二个的组件是MaxPool2D,因为其没有训练参数,因此将MaxPool2D收集到self._cells;以此类推,分别收集第二个卷积参数和卷积层,三个全连接层的参数和全连接层。实例化之后可以调用.parameters_and_names方法返回训练参数;调用conv.cells_and_names方法查看神经网络层列表。