趣味のPython・深層学習

中級者のための実装集

【1dconv】一次元畳み込みCNNの実装

連続したデータを扱う時、1dCNNを実装したいときがあります。 今回はipynbファイルでデバッグを行いながら、実装する際に便利な関数をつけました。 中間層を取り出すおまけ付きです。

class Conv1d(nn.Module):
    def __init__(self, channel_1, channel_2, channel_3, kernel_size_1, kernel_size_2, kernel_size_3, debug=False):
        super(Conv1d, self).__init__()

        self.debug = debug
        self.conv1 = nn.Conv1d(1, channel_1, kernel_size=kernel_size_1, stride=2)
        self.bn1 = nn.BatchNorm1d(channel_1)
        self.relu = nn.ReLU()
        self.maxpool = nn.MaxPool1d(kernel_size=kernel_size_2, stride=2)
        self.conv2 = nn.Conv1d(channel_1, channel_2, kernel_size=kernel_size_2, stride=2)
        self.bn2 = nn.BatchNorm1d(channel_2)
        self.conv3 = nn.Conv1d(channel_2, channel_3, kernel_size=kernel_size_3, stride=1)
        self.gap = nn.AdaptiveAvgPool1d(15)
        self.fc = nn.Linear(30, 1)

    def forward(self, x):
        self.intermediate_outputs = []

        x = self.conv1(x)
        self.print_debug(x, 'conv1')

        x = self.bn1(x)
        self.print_debug(x, 'bn1')

        x = self.relu(x)
        self.print_debug(x, 'relu')

        x = self.maxpool(x)
        self.print_debug(x, 'maxpool')

        x = self.conv2(x)
        self.print_debug(x, 'conv2')

        x = self.bn2(x)
        self.print_debug(x, 'bn2')

        x = self.relu(x)
        self.print_debug(x, 'relu')

        x = self.maxpool(x)
        self.print_debug(x, 'maxpool')

        x = self.conv3(x)
        self.print_debug(x, 'conv3')

        x = self.gap(x)
        self.print_debug(x, 'gap')

        x = x.view(x.size(0), -1)
        self.print_debug(x, 'view')

        self.intermediate_outputs.append(x.clone().detach())
        self.print_debug(self.intermediate_outputs[0], '中間層')

        x = self.fc(x)
        self.print_debug(x, 'fc')

        return x

    def get_intermediate_outputs(self):
        return self.intermediate_outputs

    def print_debug(self, data, message):
        if self.debug:
            print(message, data.shape

下記は実際のテスト使用例になります。 チャンネル数やストライドなどの引数はタスクに合わせてチューニングして下さい。

# モデルのインスタンス化
model = Conv1d(channel_1=16, channel_2=32, channel_3=64, kernel_size_1=3, kernel_size_2=2, kernel_size_3=3, debug=True)

# ダミーの入力データの生成
dummy_input = torch.randn(1, 1, 128)  # サイズ (バッチサイズ, チャネル数, シーケンス長)

# モデルの概要を表示
summary(model, input_size=(1, 128))

# ダミーの入力データをモデルに渡して出力を得る
output = model(dummy_input)

# 中間層の出力を取得
intermediate_outputs = model.get_intermediate_outputs()

# 結果の表示
print("モデルの出力サイズ:", output.shape)
print("中間層の出力サイズ:", intermediate_outputs[0].shape)