2017-12-28 6 views
0

Ich versuche eine Gradient-Methode für meine benutzerdefinierte TF-Operation zu definieren. Die meisten Lösungen, die ich online gefunden habe, scheinen auf einem gist von harpone basieren. Ich zögere, diesen Ansatz zu verwenden, da es py_func verwendet, die auf GPU nicht ausgeführt werden. Ich fand eine andere Lösung here, die tf.identity() verwendet, die eleganter aussieht und ich denke, dass auf GPU ausgeführt wird. Allerdings habe ich Probleme beim Zugriff auf die Ops in meiner benutzerdefinierten Gradientenfunktion. Hier ist mein Code:Verwenden von Op-Eingängen beim Definieren benutzerdefinierter Verläufe in TensorFlow

@tf.RegisterGradient('MyCustomGradient') 
def _custom_gradient(op, gradients): 
    x = op.inputs[0] 
    return(x) 

def my_op(w): 
    return tf.pow(w,3) 


var_foo = tf.Variable(5, dtype=tf.float32) 
bar = my_op(var_foo) 


g = tf.get_default_graph() 
with g.gradient_override_map({'Identity': 'MyCustomGradient'}): 
    bar = tf.identity(bar) 
    g = tf.gradients(bar, var_foo) 

with tf.Session() as sess: 

    sess.run(tf.global_variables_initializer()) 
    print(sess.run(g)) 

ich _custom_gradient() erwartet den Eingang des op (5 in diesem Beispiel) zurück, sondern scheint es op output x gradient zurückzukehren. Meine benutzerdefinierte my_op wird nicht unterscheidbare Operationen wie tf.sign haben und ich möchte meinen benutzerdefinierten Gradienten basierend auf den Eingaben definieren. Was mache ich falsch?

+0

Ich denke, was passiert ist, dass der benutzerdefinierte Farbverlauf an die 'identity()' op und nicht die 'my_op()' Funktion, wie ich gehofft hatte, angehängt ist. – Milad

Antwort

2

Es gibt kein Problem mit Ihrem Code ist:

erste Lassen Sie uns tun die Vorwärts-Pass:

var_foo = 5 ->bar = 125 ->tf.identity(bar) = 125

Jetzt backpropagate lasst uns:

Der Gradient tf.identity(bar) in Bezug auf sein Argument bar entspricht (nach Ihrer Definition) zu bar, das heißt 125. Der Gradient von bar in Bezug auf var_foo entspricht 3 mal dem Quadrat von var_foo, was 75 ist. Multiplizieren Sie und erhalten Sie 9375, die tatsächlich die Ausgabe Ihres Codes ist.

op.inputs[0] enthält den Durchlasswert des op. In diesem Fall ist der Vorwärtsdurchlauf von identity op 125.

+0

Danke das macht total Sinn. Ich hatte gehofft, einen benutzerdefinierten Farbverlauf für die 'my_op'-Funktion zu definieren, aber stattdessen habe ich es für die Identität definiert. Ich denke, ich muss auf py_func zurückgreifen. – Milad

Verwandte Themen