2016-10-27 2 views
6

Ich versuche, mit TensorFlow in Python zu beginnen, ein einfaches Feed-Forward-NN aufzubauen. Ich habe eine Klasse, die die Netzwerkgewichte enthält (Variablen, die während des Zuges aktualisiert werden und zur Laufzeit konstant bleiben sollen) und ein anderes Skript zum Trainieren des Netzwerks, das die Trainingsdaten erhält, sie in Stapel trennt und das Netzwerk in Stapeln trainiert . Wenn ich versuche, um das Netzwerk zu trainieren, ich erhalte eine Fehlermeldung anzeigt, dass der Daten-Tensor nicht in der gleichen Grafik wie die NN Tensoren ist:TensorFlow: So stellen Sie sicher, Tensoren sind in der gleichen Grafik

ValueError: Tensor("Placeholder:0", shape=(10, 5), dtype=float32) must be from the same graph as Tensor("windows/embedding/Cast:0", shape=(100232, 50), dtype=float32).

die relevanten Teile im Training Skript ist:

def placeholder_inputs(batch_size, ner): 
    windows_placeholder = tf.placeholder(tf.float32, shape=(batch_size, ner.windowsize)) 
    labels_placeholder = tf.placeholder(tf.int32, shape=(batch_size)) 
    return windows_placeholder, labels_placeholder 

with tf.Session() as sess: 
    windows_placeholder, labels_placeholder = placeholder_inputs(batch_size, ner) 
    logits = ner.inference(windows_placeholder) 

Und die relevant im Netz Klasse sind:

class WindowNER(object): 
def __init__(self, wv, windowsize=3, dims=[None, 100,5], reg=0.01): 
    self.reg=reg 
    self.windowsize=windowsize 
    self.vocab_size = wv.shape[0] 
    self.embedding_dim = wv.shape[1] 
    with tf.name_scope("embedding"): 
     self.L = tf.cast(tf.Variable(wv, trainable=True, name="L"), tf.float32) 
    with tf.name_scope('hidden1'): 
     self.W = tf.Variable(tf.truncated_normal([windowsize * self.embedding_dim, dims[1]], 
      stddev=1.0/math.sqrt(float(windowsize*self.embedding_dim))), 
     name='weights') 
     self.b1 = tf.Variable(tf.zeros([dims[1]]), name='biases') 
    with tf.name_scope('output'): 
     self.U = tf.Variable(tf.truncated_normal([dims[1], dims[2]], stddev = 1.0/math.sqrt(float(dims[1]))), name='weights') 
     self.b2 = tf.Variable(tf.zeros(dims[2], name='biases')) 


def inference(self, windows): 
    with tf.name_scope("embedding"): 
     embedded_words = tf.reshape(tf.nn.embedding_lookup(self.L, windows), [windows.get_shape()[0], self.windowsize * self.embedding_dim]) 
    with tf.name_scope("hidden1"): 
     h = tf.nn.tanh(tf.matmul(embedded_words, self.W) + self.b1) 
    with tf.name_scope('output'): 
     t = tf.matmul(h, self.U) + self.b2 

Warum gibt es zwei grafische Darstellungen in erster Linie, und wie kann ich sicherstellen, dass die Datenplatzhaltertensoren in der gleichen Grafik wie die NN sind?

Danke !!

Antwort

5

Sie sollten in der Lage sein, alle Tensoren unter der gleichen Grafik zu erstellen, indem Sie so etwas wie dies zu tun:

g = tf.Graph() 
with g.as_default(): 
    windows_placeholder, labels_placeholder = placeholder_inputs(batch_size, ner) 
    logits = ner.inference(windows_placeholder) 

with tf.Session(graph=g) as sess: 
    # Run a session etc 

Sie können hier mehr über Graphen in TF lesen: https://www.tensorflow.org/versions/r0.8/api_docs/python/framework.html#Graph

+0

Vielen Dank für die schnelle Antwort! Allerdings habe ich diese Änderung vorgenommen (die Sitzung kommentieren, bis ich das Diagramm korrekt erstellt habe) und ich bekomme immer noch den gleichen Fehler - "Tensor (...) muss aus dem gleichen Graphen wie Tensor (...) stammen". – user616254

+0

Es ist schwer zu sagen, ohne in der Lage zu sein, den vollständigen Code zu sehen. Aber es ist wahrscheinlich, dass Sie entweder Code haben, der Operatoren außerhalb des Geltungsbereichs 'with g.as_default()' konstruiert, oder ein Code, den Sie aufrufen, erstellt ein eigenes Diagramm. Könnten Sie mehr vom Code zeigen? (Um ehrlich zu sein, würde ich als nächstes versuchen, den Tensorflow-Python-Code zu instrumentieren, der Operatoren erstellt und die Identität des Graphen ausgibt, dem jeder Operator hinzugefügt wird.) –

0

Manchmal, wenn man eine bekommen Fehler wie dieser, der Fehler (der oft Verwendung einer falschen Variablen aus einem anderen Graphen sein kann) könnte viel früher geschehen sein, und hat sich auf die Operation verbreitet, die schließlich einen Fehler warf. Daher könnten Sie nur diese Linie untersuchen und schlussfolgern, dass die Tensoren aus demselben Graphen stammen sollten, während der Fehler tatsächlich woanders liegt.

Der einfachste Weg zu überprüfen ist, auszudrucken, welcher Graph für jede Variable/Op in der Grafik verwendet wird. Sie können dies einfach tun, indem Sie:

print(variable_name.graph) 
Verwandte Themen