趣味のPython・深層学習

中級者のための実装集

PyTorchのメモリ節約小技

PyTorchでメモリ使用量を最小限に抑える

PyTorchを使ってディープラーニングのモデルを構築する際、メモリ使用量を最小限に抑えることが非常に重要です。メモリの効率的な利用によって、モデルの学習がスムーズに進み、また推論時のパフォーマンスも向上します。 今回は、バッチサイズBの2D テンソル (B x H x W) を3チャンネルのイメージテンソル (B x 3 x H x W) に変換する際の、非効率的な方法と効率的な方法を比較してみましょう。

非効率的な方法

x = x.unsqueeze(-1) 
x = torch.cat([x, x, x], dim=3).permute(0, 3, 1, 2)

この方法では、まずxをBxHxWx1の形状に変形し、そのテンソルを3回結合してBxHxWx3のテンソルを作成しています。最後にpermute()を使って、チャンネル次元を移動させています。この一連の操作は、多くのデータ移動を伴うため、非効率的です。

効率的な方法

x = x.unsqueeze(1)
x = x.expand(-1, 3, -1, -1)

一方、この方法ではデータ移動がほとんど発生しません。最初にx.unsqueeze(1)でBxHxWを BxIxHxW に変形し、次にx.expand()を使ってチャンネル次元を3に拡張しています。 expand()は新しいメモリを確保せずに、既存のテンソルを拡張します。そのため、非常に軽量な操作となり、メモリ使用量を最小限に抑えることができます。 まとめると、PyTorchではできる限りデータ移動を避け、expand()などのinplace操作を活用することで、メモリ使用量を最適化することが重要です。特に大規模なモデルを扱う際には、このようなテクニックを意識的に取り入れる必要があります。