juliaで前処理大全 6.生成
juliaで前処理大全その5です。今回は生成をテーマに扱います。 章のテーマとしては不均衡なデータを取り扱う手法について述べています。 手法としては多すぎるデータを削るアンダーサンプリング、そして少ないデータを水増しするオーバーサンプリングの2つですが、この章では主に後者を取り扱います。
準備
今回は準備として製造レコードproduction.csv
を読み込みます。
using DataFrames,CSV,Chain,Downloads
production_url = "https://raw.githubusercontent.com/hanafsky/awesomebook/master/data/production.csv"
production_df = @chain production_url Downloads.download CSV.File DataFrame
first(production_df,10) |> println
10×4 DataFrame
Row │ type length thickness fault_flg
│ String1 Float64 Float64 Bool
─────┼─────────────────────────────────────────
1 │ E 274.027 40.2411 false
2 │ D 86.3193 16.9067 false
3 │ E 123.94 1.01846 false
4 │ B 175.555 16.4149 false
5 │ B 244.935 29.0611 false
6 │ B 226.427 39.7638 false
7 │ C 331.638 16.8356 false
8 │ A 200.865 12.1843 false
9 │ C 276.387 29.8996 false
10 │ E 168.441 1.26592 false
アンダーサンプリングによる不均衡データの調整
ここはとくにコードもないので省略させてもらいます。
オーバーサンプリングによる不均衡データの調整
製造レコードのデータセットを用いて、データの分割を行います。
まずは、代表的なオーバーサンプリングの方法であるSMOTEについてアルゴリズムをおさらいします。[1]
前処理大全においては、自分でSMOTEの処理を実装するのではなく、適当なライブラリからインポートしてくることを推奨しています。
juliaでオーバーサンプリングを提供しているパッケージとしては、ClassImbalance.jlとMLUtils.jlがありました。前者は長らくメンテされていないため、依存パッケージのバージョンが整合しない可能性もあります。後者はまだドキュメントが整備されていないようで、SMOTEは実装されていないようです。[2] 自分で実装しても良いのですが、ここではpythonのライブラリを呼び出してみたいと思います。 本のpythonコードが古いせいか、キーワード引数名に違いはあります。 また、DataFrame型をそのまま渡すことができないので、一度マトリックスに変換してオーバーサンプリング後にDataFrame型に戻す操作をしています。 この辺りの作業は止むを得ないですが、"Not Awesome"かもしれません。
using PyCall
imblearn = pyimport("imblearn.over_sampling")
sm = imblearn.SMOTE(sampling_strategy="auto", k_neighbors=5, random_state=71)
imb_data = production_df[!,[:length,:thickness]] |> Matrix
imb_target = production_df.fault_flg
balance_data,balance_target = sm.fit_resample(imb_data, imb_target)
new_df= DataFrame(hcat(balance_data,balance_target),["length","thickness","fault_flg"]);
せっかくなので可視化してみることにしましょう。 まず、オーバーサンプリングする前のデータを散布図にしてみます。
using StatsPlots
p = @chain production_df begin
filter(:fault_flg=>==(false),_)
@df scatter(:length, :thickness, label="false")
end
@chain production_df begin
filter(:fault_flg=>==(true),_)
@df scatter!(:length, :thickness, label="true")
end
"/home/runner/work/hanafsky.github.io/hanafsky.github.io/__site/assets/tips/preprocess/generation/code/output/5-1.svg"
真の値のデータが偽のデータに比べて非常に少ないことがわかります。 一方で、オーバーサンプリングされたデータはどうでしょうか。
p2 = @chain new_df begin
filter(:fault_flg=>==(false),_)
@df scatter(:length, :thickness, label="false")
end
@chain new_df begin
filter(:fault_flg=>==(true),_)
@df scatter!(:length, :thickness, label="true")
end
"/home/runner/work/hanafsky.github.io/hanafsky.github.io/__site/assets/tips/preprocess/generation/code/output/5-2.svg"
もともとのデータに比べて真のデータが水増しされていることがわかります。
[1] | Synthetic Minority Over-sampling Technique |
[2] | oversampleという関数はある。 |
つづく