趣味のPython・深層学習

中級者のための実装集

深層学習内部の行列の内積を行うクラス(Matmul)の実装

はじめに

行列の乗算(Matrix Multiplication)は、深層学習などで頻繁に利用される基本的な演算です。その中でも、逆伝播(Backpropagation)における勾配計算の一環として行われることがあります。この記事では、NumPyを用いて行列の乗算を行うMatMulクラスに焦点を当て、その役割や逆伝播の概念について解説します。 ※この記事のコードは書籍「ゼロから作るDeep learning2」を参考にしております。

実装内容

import numpy as np

class MatMul:
    def __init__(self, W: np.ndarray) -> None:
        # 重み行列 W を使用して MatMul インスタンスを初期化します。
        self.params: list[np.ndarray] = [W]
        self.grads: list[np.ndarray] = [np.zeros_like(W)]
        self.x: np.ndarray | None = None

    def forward(self, x: np.ndarray) -> np.ndarray:
        # 行列の乗算の順伝播を実行します。
        W, = self.params
        out: np.ndarray = np.dot(x, W)
        self.x = x
        return out

    def backward(self, dout: np.ndarray) -> np.ndarray:
        # 勾配を計算するための逆伝播を実行します。
        W, = self.params
        dx: np.ndarray = np.dot(dout, W.T)
        dW: np.ndarray = np.dot(self.x.T, dout)
        
        # 勾配を更新します。grads[0] は dW を格納しています。
        self.grads[0][...] = dW
        return dx

初期化部分

def __init__(self, W: np.ndarray) -> None:
    # 重み行列 W を使用して MatMul インスタンスを初期化します。
    self.params: list[np.ndarray] = [W]
    self.grads: list[np.ndarray] = [np.zeros_like(W)]
    self.x: np.ndarray | None = None

MatMulクラスは、行列の乗算を担当するためのインスタンスを初期化します。重要な属性には、重み行列 W、その勾配 dW、および順伝播時の入力 x があります。これらは逆伝播で使用されます。

順伝播部分

def forward(self, x: np.ndarray) -> np.ndarray:
    # 行列の乗算の順伝播を実行します。
    W, = self.params
    out: np.ndarray = np.dot(x, W)
    self.x = x
    return out

forward メソッドは、行列の乗算の順伝播を行います。入力行列 x と重み行列 W の積を計算し、その結果を返します。また、self.x に順伝播時の入力を保存しておくことで、後の逆伝播で使用します。

逆伝播部分

def backward(self, dout: np.ndarray) -> np.ndarray:
    # 勾配を計算するための逆伝播を実行します。
    W, = self.params
    dx: np.ndarray = np.dot(dout, W.T)
    dW: np.ndarray = np.dot(self.x.T, dout)
    
    # 勾配を更新します。grads[0] は dW を格納しています。
    self.grads[0][...] = dW
    return dx

backward メソッドは逆伝播を担当し、連鎖律を用いて入力に関する勾配 dx とパラメータに関する勾配 dW を計算します。これらの勾配は後の層への逆伝播やパラメータの更新に使用されます。self.grads[0] にはパラメータ W に関する勾配が格納され、これがモデルの学習において重要な役割を果たします。