juliaで前処理大全 6.生成

(src=https://pixabay.com/photos/food-salad-raw-carrots-1209503/)

juliaで前処理大全その5です。今回は生成をテーマに扱います。 章のテーマとしては不均衡なデータを取り扱う手法について述べています。 手法としては多すぎるデータを削るアンダーサンプリング、そして少ないデータを水増しするオーバーサンプリングの2つですが、この章では主に後者を取り扱います。

準備

今回は準備として製造レコードproduction.csvを読み込みます。

using DataFrames,CSV,Chain,Downloads
production_url = "https://raw.githubusercontent.com/hanafsky/awesomebook/master/data/production.csv"

production_df = @chain production_url Downloads.download CSV.File DataFrame
first(production_df,10) |> println
10×4 DataFrame
 Row │ type     length    thickness  fault_flg
     │ String1  Float64   Float64    Bool
─────┼─────────────────────────────────────────
   1 │ E        274.027    40.2411       false
   2 │ D         86.3193   16.9067       false
   3 │ E        123.94      1.01846      false
   4 │ B        175.555    16.4149       false
   5 │ B        244.935    29.0611       false
   6 │ B        226.427    39.7638       false
   7 │ C        331.638    16.8356       false
   8 │ A        200.865    12.1843       false
   9 │ C        276.387    29.8996       false
  10 │ E        168.441     1.26592      false

アンダーサンプリングによる不均衡データの調整

ここはとくにコードもないので省略させてもらいます。

オーバーサンプリングによる不均衡データの調整

製造レコードのデータセットを用いて、データの分割を行います。

まずは、代表的なオーバーサンプリングの方法であるSMOTEについてアルゴリズムをおさらいします。[1]

graph TD id1(生成元のデータからランダムに1つのデータを選択) id2(1からkの整数値からランダムに選択しnを設定) id3(選択したデータにn番目に近いデータを新たに選択) id4(2つのデータの間のデータを生成) id5(指定したデータ数に達するまで繰り返す) id1-->id2 id2-->id3 id3-->id4 id4-->id5

前処理大全においては、自分でSMOTEの処理を実装するのではなく、適当なライブラリからインポートしてくることを推奨しています。

juliaでオーバーサンプリングを提供しているパッケージとしては、ClassImbalance.jlとMLUtils.jlがありました。前者は長らくメンテされていないため、依存パッケージのバージョンが整合しない可能性もあります。後者はまだドキュメントが整備されていないようで、SMOTEは実装されていないようです。[2] 自分で実装しても良いのですが、ここではpythonのライブラリを呼び出してみたいと思います。 本のpythonコードが古いせいか、キーワード引数名に違いはあります。 また、DataFrame型をそのまま渡すことができないので、一度マトリックスに変換してオーバーサンプリング後にDataFrame型に戻す操作をしています。 この辺りの作業は止むを得ないですが、"Not Awesome"かもしれません。

using PyCall
imblearn = pyimport("imblearn.over_sampling")
sm = imblearn.SMOTE(sampling_strategy="auto", k_neighbors=5, random_state=71)
imb_data = production_df[!,[:length,:thickness]] |> Matrix
imb_target = production_df.fault_flg
balance_data,balance_target = sm.fit_resample(imb_data, imb_target)
new_df= DataFrame(hcat(balance_data,balance_target),["length","thickness","fault_flg"]);

せっかくなので可視化してみることにしましょう。 まず、オーバーサンプリングする前のデータを散布図にしてみます。

using StatsPlots
p = @chain production_df begin
        filter(:fault_flg=>==(false),_)
        @df scatter(:length, :thickness, label="false")
    end
                                                              
@chain production_df begin
      filter(:fault_flg=>==(true),_)
      @df scatter!(:length, :thickness, label="true")
    end
"/home/runner/work/hanafsky.github.io/hanafsky.github.io/__site/assets/tips/preprocess/generation/code/output/5-1.svg"

真の値のデータが偽のデータに比べて非常に少ないことがわかります。 一方で、オーバーサンプリングされたデータはどうでしょうか。

p2 = @chain new_df begin
        filter(:fault_flg=>==(false),_)
        @df scatter(:length, :thickness, label="false")
    end
                                                              
@chain new_df begin
      filter(:fault_flg=>==(true),_)
      @df scatter!(:length, :thickness, label="true")
    end
"/home/runner/work/hanafsky.github.io/hanafsky.github.io/__site/assets/tips/preprocess/generation/code/output/5-2.svg"

もともとのデータに比べて真のデータが水増しされていることがわかります。

[1] Synthetic Minority Over-sampling Technique
[2] oversampleという関数はある。

つづく