2017-05-15 5 views
0

Ich versuche, eine einfache lineare Regression in Keras mit einer benutzerdefinierten Verlustfunktion zu implementieren. Ich berechne Chi2 unter der Annahme, dass der Fehler 1% des Funktionswerts ist. Ich füge 1% Gaußsches Rauschen zu einem linearen Modell hinzu. Wenn ich die mittlere quadratische Fehlerverlustfunktion ('mse') verwende, kann ich die Funktion custom_loss() als Metrik hinzufügen, und ich sehe, dass sie sehr nahe bei 1 konvergiert (chi2/ndf). Wenn ich den custom_loss() einfach als Verlustfunktion verwende, wie unten im Ausschnitt gezeigt, bewegt sich das Neuronengewicht überhaupt nicht.Keras benutzerdefinierte Verlust (Chi2) lineare Regression

Was mache ich falsch?

from keras.models import Sequential 
from keras.layers import Dense 
from keras import optimizers 
from keras import backend as K 
import numpy as np 

def custom_loss(y_true, y_pred): 
    return K.mean(K.square(y_pred-y_true)/K.clip(K.square(0.01*y_pred), K.epsilon(), None), axis=-1) 

def build_model(): 
    model = Sequential() 
    model.add(Dense(1, input_dim=1, activation='linear')) 

    sgd = optimizers.SGD(lr=0.01, momentum=0.01, nesterov=False) 
    model.compile(optimizer=sgd, loss=custom_loss, metrics=['mape', 'mse', custom_loss]) 
    return model 

if __name__ == '__main__': 

    model = build_model() 

    x_train = np.linspace(0.0, 1.0, 500) 
    y_train = np.array(map(lambda x: x + 0.01*x*np.random.randn(1), x_train)) 

    model.fit(x_train, y_train, shuffle=True, epochs=1000, batch_size=10) 

Antwort

1

Ich bin nicht sicher genau was Sie versuchen, mit diesem Verlust zu berechnen, aber hier ist das, was ich sehe, wenn ich in [-1, 1] y_pred plotten und y_true = 0,5:

Custom Loss

Es hat deutlich 2 Minima, eine sehr steile Steigung um 0 (undefiniert, wenn y_pred = 0) und niedrige Steigung woanders. Beachten Sie, dass es für das Netzwerk sehr einfach ist, hier den falschen "Arm" der Verlustfunktion zu verwenden.

vergleichen diesen Verlust gegen den quadratischen Fehler:

MSE

Ich würde Ihre Verlustfunktion überdenken: Was Sie auf das Netzwerk zu haben versuchen, hier lernen?

+0

Vielen Dank für Ihre Antwort. Ich versuche, das Netzwerk zu trainieren, um Punkte entlang der Linie (y = x) mit einer Genauigkeit von 1% vorherzusagen. Ich habe hier versucht, die Chi2-Funktion zu konstruieren, und "von Hand" hat die Unsicherheit auf 1% des wahren Wertes (des) gesetzt. Ich fügte der Linie auch ein Geräusch von 1% hinzu. Wenn das Netzwerk einen Fehler von 1% voraussagt, ergibt dieser custom_loss ungefähr 1. Der quadratische Fehler funktioniert in diesem Fall gut, aber in meinem echten Code werden die Werte der durch y dargestellten Funktion sehr klein. Das Netzwerk passt also nicht gut in diese Bereiche. Ich versuche, dieses Problem zu vermeiden. Vielen Dank – dmriser