趣味のPython・深層学習

中級者のための実装集

nn.ModuleListを解説してみる

PyTorchのnn.ModuleListとは?

nn.ModuleList は、PyTorchのニューラルネットワークモジュールの一部であり、複数の nn.Module オブジェクトをまとめて保持するためのコンテナです。これを使うことで、モデル内で複数のサブモデルを簡潔に管理できます。

なぜnn.ModuleListを使用するのか?

柔軟性と再利用性: nn.ModuleList を使うと、動的なサブモデルの追加や取り外しが可能になります。これにより、モデルを構築する際により柔軟で再利用可能なコードを書くことができます。 パラメータ管理: nn.ModuleList は、リスト内の各モジュールが持つパラメータを自動的にトラッキングします。これにより、モデル全体のパラメータ管理が簡単になります。

例: nn.ModuleListの使用

以下は、nn.ModuleList を使用していくつかのサブモデルを保持する例です。

import torch.nn as nn

class MyModel(nn.Module):
    def __init__(self):
        super(MyModel, self).__init__()

        # nn.ModuleListでサブモデルを保持
        self.submodules = nn.ModuleList([nn.Linear(10, 20), nn.ReLU(), nn.Linear(20, 5)])

    def forward(self, x):
        # サブモデルの順伝播
        for layer in self.submodules:
            x = layer(x)
        return x

この例では、nn.ModuleList には nn.Linear と nn.ReLU の層が含まれています。これにより、モデル内でこれらの層を管理し、効果的に利用することができます。

下記ではSelf attentionのクラスを実装する例です

# 自己注意機構(Self-Attention)のヘッドを複数組み合わせるためのクラス
class SelfAttention_MultiHeads(nn.Module):

    # 初期化メソッド
    def __init__(self, n_mbed, num_heads, head_size, block_size):
        super().__init__()

        # Self-Attentionのヘッドをリストで保持するModuleListを作成
        self.heads = nn.ModuleList((SelfAttention_Head(n_mbed, head_size, block_size) for _ in range(num_heads)))

    # 順伝播メソッド
    def forward(self, x):
        # 各ヘッドに入力を渡し、結果を横方向(最後の次元)に結合して返す
        return torch.cat([h(x) for h in self.heads], dim=-1)

内包表記と組み合わせています。 内包表記を使用しない例を参考にあげておきます。

class SelfAttention_MultiHeads(nn.Module):

    # 初期化メソッド
    def __init__(self, n_mbed, num_heads, head_size, block_size):
        super().__init__()

        # Self-Attentionのヘッドをリストで保持するModuleListを作成
        self.heads = nn.ModuleList()
        for _ in range(num_heads):
            self.heads.append(SelfAttention_Head(n_mbed, head_size, block_size))

    # 順伝播メソッド
    def forward(self, x):
        # 各ヘッドに入力を渡し、結果を横方向(最後の次元)に結合して返す
        results = []
        for head in self.heads:
            results.append(head(x))
        return torch.cat(results, dim=-1)