趣味のPython・深層学習

中級者のための実装集

ONNXによるC++ベース機械学習推論

ONNXを使ってPyTorchモデルの推論を高速化する

機械学習モデルの実運用においては、推論時間の高速化が非常に重要な課題となります。特に大規模モデルを使う場合、推論に時間がかかり過ぎると実用的ではなくなってしまいます。幸いPyTorchには、ONNXを使ってモデルの推論を高速化する機能が用意されています。

ここでは、ONNXを使った推論高速化の手順と、実際のコード例を見ていきましょう。

手順1: モデルの定義

まず通常通り、PyTorchを使ってモデルを定義します。今回はシンプルな線形回帰モデルを例に使用します。

import torch.nn as nn

class LinearRegressionModel(nn.Module):
    def __init__(self):
        super().__init__()
        self.linear = nn.Linear(1, 1)

    def forward(self, x):
        out = self.linear(x)
        return out 

手順2: モデルの学習

次にモデルを学習させます。ここでは詳細は省略します。

手順3: ONNXへの変換

学習済みのモデルをONNX形式に変換します。

import torch.onnx

# 入力データのダミーを作成
dummy_input = torch.randn(1, 1)

# モデルをONNXに変換
torch.onnx.export(model, dummy_input, "linear_model.onnx", opset_version=11)

手順4: ONNXランタイムでの推論

ONNXランタイムを使って、ONNXモデルから推論を行います。

import onnxruntime

# ONNXセッションを作成
ort_session = onnxruntime.InferenceSession("linear_model.onnx")

# 入力データ
X = [[0.5]]

# 推論実行
input_name = ort_session.get_inputs()[0].name 
output_name = ort_session.get_outputs()[0].name
outputs = ort_session.run([output_name], {input_name: X})

# 結果の表示
print(f"Output: {outputs[0]}")

以上のように、ONNXを介すことでPyTorchモデルの推論を大幅に高速化できます。ONNXランタイムは最適化されたC++ベースのエンジンを利用しているため、PyTorch自体のPythonベースの実装より高速に動作します。 特に大規模モデルを使う場合、ONNXを使った推論高速化の恩恵は大きくなります。実運用時の推論パフォーマンスを確保したい場合は、ぜひONNXの活用を検討してみてください