2018年6月24日日曜日

SAMSE: Sign-adjusted Mean Squared Error

回帰の損失関数は色々あるのですが、昔からなぜないのだろうと思う計測方法があります。 それがタイトルにも書いたように、符号調整済みの平均二乗誤差。 特定の数値を基準にして、一定の数値以上か以下かで損失の与え方を変える誤差関数です。 まあ平均二乗誤差 (MSE) ではなく、MAPEやMAEベースでも良いのですが、とにかく符号を考慮するとどうなるのかと思ってました。

この損失関数はどこに使えるかというと、 例えばノイズのあるsin波を予測し、0以上かどうかで何らかの処理を判定する時に使えます。 そのような条件の場合、0.1と予測して0.3だった時と、-0.1だった時では意味が異なるのですが、 MSEなどの有名な損失関数ではこれを適切に扱えません。

似たような事は他の手段でできない事もないのですが、 (1)分類誤差と二乗誤差の合計値を損失関数にすると二乗誤差部分は正負を考慮できない問題が残ります。 経験的にも二乗誤差に予測が引きづられ過ぎてしまう印象がある。 (2)クラス分類でサンプルの重みを変えたりすると多少は数値予測っぽさは出せますが、softmaxに縛られ過ぎてしまう問題が残ります。 (3)出力を分割してそれぞれMSEで損失を計測するのは近いと思いますが、これだと全体に損失をかけにくくなる。 そのためクラスの重みをある程度考慮しつつベースは距離関数という損失関数のほうが、 使い勝手が良いケースもあるのではないかと以前から思っていました。

回帰は1を基準にする事のほうが多いと思うので、以下のように1を中心にKerasでカスタム損失関数を実装してみましたが、ones_likeをzeros_likeにすればより符号っぽさが出ます。
def samse(y_true, y_pred):  # sign-adjusted mse
    ones = K.ones_like(y_true)
    p1 = K.cast(K.greater(y_true, ones), K.floatx())
    n1 = K.cast(K.less_equal(y_true, ones), K.floatx())
    p2 = K.cast(K.greater(y_pred, ones), K.floatx())
    n2 = K.cast(K.less_equal(y_pred, ones), K.floatx())
    rate = K.square(y_true - y_pred)
    pp = K.mean(p1 * p2 * rate, axis=-1)
    pn = K.mean(p1 * n2 * rate, axis=-1) * 2
    pn = K.mean(n1 * p2 * rate, axis=-1) * 2
    nn = K.mean(n1 * n2 * rate, axis=-1)
    return pp + pn + pn + nn

重みをどうするかとか、分布が少し歪むとかの問題もあるので一概に使い勝手が良いかはわかりません。 ただpp, nnのように予測が正しい場合には損失を計算しないようにするとか、部分的に絶対誤差にするとか、細かな改良はしやすい関数だと思います。 ちなみに上記の *2 の部分をなくせば普通のMSEと一緒。

まあ損失関数を作った事がないので遊びで作ってみただけですが、軽く確認してみた感じではMSEより正負を気持ち考慮できてるかなあという感じ。 計算上はMSEより損失が大きくなるはずですが、MSEより損失が小さくなるケースが多い。

0 件のコメント: