2017-02-17 5 views
2

Ein Kollege von mir wies auf die sehr coole Option sample_weight anstelle einer Maskierungsschicht zu verwenden, wenn Sie Eingabe in eine Nicht-RNN in Keras maskieren müssen.Keras: Maskierung zero-padded Eingang für Nicht-RNN

In meinem Fall habe ich 62 Spalten in der Eingabe, mit der 63. ist die Antwort. Über 97% der Nicht-Null-Einträge in den 62 Spalten sind in den ersten 30 Spalten enthalten. Ich versuche nur, dass das funktioniert, also möchte ich die letzten 32 Spalten im Training auf 0 belasten, was im Wesentlichen eine "Maske des armen Mannes" erzeugt.

Dies ist eine Klassifizierungsaufgabe mit 8 Klassen, die ein MLP verwendet. Die Antwortvariable wurde unter Verwendung der to_categorical()-Funktion in Keras transformiert.

Hier ist die Umsetzung:

model = Sequential() 
model.add(Dense(100, input_dim=X.shape[1], init='uniform', activation='relu')) 
model.add(Dense(8, init='uniform', activation='sigmoid')) 
hist = model.fit(X, y, 
       validation_data=(X_test, ytest), 
       nb_epoch=epochs_, 
       batch_size=batch_size_, 
       callbacks=callbacks_list, 
       sample_weight = np.array([X.shape[1]-32, 30])) 

ich diesen Fehler:

in standardize_weights 
assert y.shape[:sample_weight.ndim] == sample_weight.shape 

Wie kann ich meine sample_weight zu 'Maske' fixieren die ersten 32 Spalten der Eingang?

Antwort

2

Probengewicht ist nicht wie das funktioniert:

sample_weight : optional array of the same length as x , containing weights to apply to the model's loss for each sample. In the case of temporal data, you can pass a 2D array with shape (samples, sequence_length) , to apply a different weight to every timestep of every sample. In this case you should make sure to specify sample_weight_mode="temporal" in compile() . source

Mit anderen Worten, legt diese Einstellung unterschiedliche Gewichte an den Proben der Trainingsdaten, nicht auf die Merkmale jeder Probe. Dies wird nur im Trainingsschritt verwendet. Ich denke, Sie sollten Maskierung verwenden, wenn Sie nicht möchten, dass die Ebene diese Features verwendet. Oder entfernen Sie sie einfach aus Ihrem Dataset? Oder, wenn es nicht zu kompliziert ist, lassen Sie das Netzwerk selbst lernen, welche nützlichen Funktionen es gibt.

Hilft das?