趣味のPython・深層学習

中級者のための実装集

pandas Dataframeの重み付きアンサンブル

重み付きアンサンブルとは?

重み付きアンサンブルは、複数のモデルの予測結果に対して異なる重みを割り当て、それらを組み合わせる手法です。各モデルに与えられる重みは、そのモデルの性能や信頼性に基づいて決定されます。この方法は、アンサンブル全体の性能を向上させるのに寄与します。

サンプルコード

ここでは、kaggleでnotebookを提出するときに便利なデータフレームでのアンサンブルサンプルコードを提供します。

import pandas as pd

def calculate_weighted_average(*dfs, weights=None):
    # dfsが空でないことを確認
    if not dfs:
        raise ValueError("No DataFrames provided.")
    
    # カラムリストを取得
    columns = dfs[0].columns
    
    # 重みが提供されているか確認
    if weights is None:
        weights = [1] * len(dfs)
    elif len(weights) != len(dfs):
        raise ValueError("Number of weights must match the number of DataFrames.")
    
    # それぞれのカラムごとに重みつき平均を計算
    weighted_avg_df = pd.DataFrame({'ID': dfs[0]['ID']})
    
    for column in columns:
        weighted_avg_df[column] = sum(df[column] * weight for df, weight in zip(dfs, weights)) / sum(weights)
    
    return weighted_avg_df

# 三つのダミーデータを作成
df1 = pd.DataFrame({'ID': [1, 2, 3],
                    'seizure_vote': [10, 20, 30],
                    'lpd_vote': [15, 25, 35],
                    'gpd_vote': [12, 22, 32],
                    'lrda_vote': [18, 28, 38],
                    'grda_vote': [14, 24, 34],
                    'other_vote': [16, 26, 36]})

df2 = pd.DataFrame({'ID': [1, 2, 3],
                    'seizure_vote': [12, 22, 32],
                    'lpd_vote': [16, 26, 36],
                    'gpd_vote': [14, 24, 34],
                    'lrda_vote': [20, 30, 40],
                    'grda_vote': [18, 28, 38],
                    'other_vote': [22, 32, 42]})

df3 = pd.DataFrame({'ID': [1, 2, 3],
                    'seizure_vote': [14, 24, 34],
                    'lpd_vote': [18, 28, 38],
                    'gpd_vote': [16, 26, 36],
                    'lrda_vote': [22, 32, 42],
                    'grda_vote': [20, 30, 40],
                    'other_vote': [24, 34, 44]})

# 重みのリストを作成
weights = [0.3, 0.5, 0.2]

# 関数を呼び出して重み付き平均を計算
result_df = calculate_weighted_average(df1, df2, df3, weights=weights)

# 結果を表示
print(result_df)