2017-04-17 4 views
2

Ich implementiere eigene keras Verlustfunktion. Wie kann ich auf Tensorwerte zugreifen?Debug Keras Tensor Werte

Was ich

def loss_fn(y_true, y_pred): 
    print y_true 

versucht haben, Sie druckt

Tensor("target:0", shape=(?, ?), dtype=float32) 

Gibt es irgendwelche keras y_true Werte zugreifen funktionieren?

Antwort

4

Normalerweise y_true Sie im Voraus wissen, - während der Vorbereitung des Zuges Corpora ...

Allerdings gibt es einen Trick, um die Werte innerhalb y_true und/oder y_pred zu sehen. Keras gibt Ihnen die Möglichkeit, entsprechende callback zum Drucken der neuronalen Netzwerkausgabe zu schreiben. Es wird wie folgt aussehen:

def loss_fn(y_true, y_pred): 
    return y_true # or y_pred 
... 
import keras.callbacks as cbks 
class CustomMetrics(cbks.Callback): 

    def on_epoch_end(self, epoch, logs=None): 
     for k in logs: 
      if k.endswith('loss_fn'): 
       print logs[k] 

Hier ist die loss_fn ist Name Ihrer Verlustfunktion, wenn Sie es in die model.compile(...,metrics=[loss_fn],) Funktion während Modells Kompilierung übergeben.

So, endlich, müssen Sie diese CustomMetrics Rückruf als Argument übergeben in die model.fit():

model.fit(x=train_X, y=train_Y, ... , callbacks=[CustomMetrics()]) 

PS: Wenn Sie Theano (oder TensorFlow) wie hier in Keras Sie ein Python-Programm zu schreiben, und dann kompilieren Sie es und führen es aus. Also, in Ihrem Beispiel y_true - ist nur eine Tensor-Variable, die für weitere Kompilierung und Verlustfunktionszählung verwendet wird.

Es bedeutet, dass es keine Möglichkeit gibt, die Werte darin zu sehen. In Theano zum Beispiel können Sie nach der Ausführung der entsprechenden eval() Funktion die einzige sogenannte shared Variable betrachten. Weitere Informationen finden Sie unter this question.

0

Die Werte der symbolischen Tensorvariablen können nicht direkt abgerufen werden. Sie müssen eine theano-Funktion schreiben, um den Wert zu extrahieren. Vergiss nicht, das Theano als Backend von Keras zu wählen.

Überprüfen Sie den Notebook-Link, um einige grundlegende derano Variablen und Funktionen zu erhalten: get tensor value in call function of own layers

Verwandte Themen