2017-08-07 1 views
1

Gegeben ein paar symbolische Variablen zum Abrufen, ich muss wissen, welche Platzhalter Abhängigkeit sind.So ermitteln Sie die Platzhalterabhängigkeit in TensorFlow

In Theano, die wir haben:

import theano as th 
import theano.tensor as T 

x, y, z = T.scalars('xyz') 
u, v = x*y, y*z 
w = u + v 

th.gof.graph.inputs([w]) # gives [x, y, z] 
th.gof.graph.inputs([u]) # gives [x, y] 
th.gof.graph.inputs([v]) # gives [y, z] 
th.gof.graph.inputs([u, v]) # gives [x, y, z] 

Wie das Gleiche in TensorFlow zu tun?

Antwort

1

Es gibt keine integrierte Funktion (die ich kenne), aber es ist einfach zu machen:

# Setup a graph 
import tensorflow as tf 
placeholder0 = tf.placeholder(tf.float32, []) 
placeholder1 = tf.placeholder(tf.float32, []) 
constant0 = tf.constant(2.0) 
sum0 = tf.add(placeholder0, constant0) 
sum1 = tf.add(placeholder1, sum0) 

# Function to get *all* dependencies of a tensor. 
def get_dependencies(tensor): 
    dependencies = set() 
    dependencies.update(tensor.op.inputs) 
    for sub_op in tensor.op.inputs: 
     dependencies.update(get_dependencies(sub_op)) 
    return dependencies 

print(get_dependencies(sum0)) 
print(get_dependencies(sum1)) 
# Filter on type to get placeholders. 
print([tensor for tensor in get_dependencies(sum0) if tensor.op.type == 'Placeholder']) 
print([tensor for tensor in get_dependencies(sum1) if tensor.op.type == 'Placeholder']) 

Natürlich können Sie den Platzhalter in die Funktion als auch Filterung werfen könnte.

+0

Diese Rekursion ist wahrscheinlich auf Fibonacci-ähnlichen Graphen ineffizient, aber in der Tat ein guter Startpunkt. – Kh40tiK

Verwandte Themen