2016-05-05 4 views
3

Ich versuche conditionals mit tensorflow zu verwenden, und ich den Fehler bekommen:Tensorflow bedingte Wurfwertfehler

ValueError: Shapes (1,) and() are not compatible 

Unterhalb der Code verwende ich, dass der Fehler zu werfen. Es sagt, der Fehler in der bedingten

import tensorflow as tf 
import numpy as np 

X = tf.constant([1, 0]) 
Y = tf.constant([0, 1]) 
BOTH = tf.constant([1, 1]) 
WORKING = tf.constant(1) 

def create_mult_func(tf, amount, list): 
    def f1(): 
     return tf.scalar_mul(amount, list) 
    return f1 

def create_no_op_func(tensor): 
    def f1(): 
     return tensor 
    return f1 

def stretch(tf, points, dim, amount): 
    """points is a 2 by ??? tensor, dim is a 1 by 2 tensor, amount is tensor scalor""" 
    x_list, y_list = tf.split(0, 2, points) 
    x_stretch, y_stretch = tf.split(1, 2, dim) 
    is_stretch_X = tf.equal(x_stretch, WORKING, name="is_stretch_x") 
    is_stretch_Y = tf.equal(y_stretch, WORKING, name="is_stretch_Y") 
    x_list_stretched = tf.cond(is_stretch_X, 
           create_mult_func(tf, amount, x_list), create_no_op_func(x_list)) 
    y_list_stretched = tf.cond(is_stretch_Y, 
           create_mult_func(tf, amount, y_list), create_no_op_func(y_list)) 
    return tf.concat(1, [x_list_stretched, y_list_stretched]) 

example_points = np.array([[1, 1], [2, 2], [3, 3]], dtype=np.float32) 
example_point_list = tf.placeholder(tf.float32) 

result = stretch(tf, example_point_list, X, 1) 
sess = tf.Session() 

with tf.Session() as sess: 
    result = sess.run(result, feed_dict={example_point_list: example_points}) 
    print(result) 

Stapelüberwachung:

File "/path/test2.py", line 36, in <module> 
    result = stretch(tf, example_point_list, X, 1) 
    File "/path/test2.py", line 28, in stretch 
    create_mult_func(tf, amount, x_list), create_no_op_func(x_list)) 
    File "/path/tensorflow/python/ops/control_flow_ops.py", line 1142, in cond 
    p_2, p_1 = switch(pred, pred) 
    File "/path/tensorflow/python/ops/control_flow_ops.py", line 203, in switch 
    return gen_control_flow_ops._switch(data, pred, name=name) 
    File "/path/tensorflow/python/ops/gen_control_flow_ops.py", line 297, in _switch 
    return _op_def_lib.apply_op("Switch", data=data, pred=pred, name=name) 
    File "/path/tensorflow/python/ops/op_def_library.py", line 655, in apply_op 
    op_def=op_def) 
    File "/path/tensorflow/python/framework/ops.py", line 2156, in create_op 
    set_shapes_for_outputs(ret) 
    File "/path/tensorflow/python/framework/ops.py", line 1612, in set_shapes_for_outputs 
    shapes = shape_func(op) 
    File "/path/tensorflow/python/ops/control_flow_ops.py", line 2032, in _SwitchShape 
    unused_pred_shape = op.inputs[1].get_shape().merge_with(tensor_shape.scalar()) 
    File "/path/tensorflow/python/framework/tensor_shape.py", line 554, in merge_with 
    (self, other)) 
ValueError: Shapes (1,) and() are not compatible 

ich versucht habe, die WORKING Wechsel ein Array statt eines skalaren zu sein.

Ich glaube, dass das Problem ist, dass tf.equal eine int32 anstelle des Bool zurückgibt, die es soll entsprechend der Dokumentation zu tf.cond

+0

nicht der Fehler nicht in dem bedingten in der Form der Tensor ist, dass Sie versuchen Zum Vergleich: 'x_stretch' hat nicht die gleiche Form von' WORKING'. Scheint x_stretch ist eine Dimension größer als 'WORKING'. Was ist der Inhalt von "Punkten"? Wenn Sie einen ausführbaren Code bereitstellen, kann ich Ihnen mehr helfen. – fabrizioM

+0

Ich machte ein komplettes ausführbares Beispiel, das Sie einfügen können und sehen, was scheitert – dtracers

Antwort

7

Das Problem ist im ersten Argument zurück. Aus der Dokumentation here, über die Art des ersten Arguments zu tf.cond:

pred: A scalar determining whether to return the result of fn1 or fn2. 

Beachten Sie, dass es sich um einen Skalar sein muss. Sie verwenden das Ergebnis des Vergleichens eines Tensors und eines Tensors, was Ihnen einen Skalar gibt, der einen (1,)Tensor, NOT gibt. Sie können es auf einen Skalar wandeln den tf.reshape Operator wie folgt:

t = tf.equal(x_stretch, WORKING, name="is_stretch_x") 
x_list_stretched = tf.cond(tf.reshape(t, []), 
          create_mult_func(tf, amount, x_list), create_no_op_func(x_list)) 

komplettes Arbeitsprogramm:

import tensorflow as tf 
import numpy as np 

X = tf.constant([1, 0]) 
Y = tf.constant([0, 1]) 
BOTH = tf.constant([1, 1]) 
WORKING = tf.constant(1) 

def create_mult_func(tf, amount, list): 
    def f1(): 
     return tf.scalar_mul(amount, list) 
    return f1 

def create_no_op_func(tensor): 
    def f1(): 
     return tensor 
    return f1 

def stretch(tf, points, dim, amount): 
    """points is a 2 by ??? tensor, dim is a 1 by 2 tensor, amount is tensor scalor""" 
    x_list, y_list = tf.split(0, 2, points) 
    x_stretch, y_stretch = tf.split(0, 2, dim) 
    is_stretch_X = tf.equal(x_stretch, WORKING, name="is_stretch_x") 
    is_stretch_Y = tf.equal(y_stretch, WORKING, name="is_stretch_Y") 
    x_list_stretched = tf.cond(tf.reshape(is_stretch_X, []), 
           create_mult_func(tf, amount, x_list), create_no_op_func(x_list)) 
    y_list_stretched = tf.cond(tf.reshape(is_stretch_Y, []), 
           create_mult_func(tf, amount, y_list), create_no_op_func(y_list)) 
    return tf.pack([x_list_stretched, y_list_stretched]) 

example_points = np.array([[1, 1], [2, 2]], dtype=np.float32) 
example_point_list = tf.placeholder(tf.float32) 

result = stretch(tf, example_point_list, X, 1) 
sess = tf.Session() 

with tf.Session() as sess: 
    result = sess.run(result, feed_dict={example_point_list: example_points}) 
    print(result) 
+0

Dies funktioniert nicht. Wirft einen Fehler: "Verwenden eines' tf.Tensor' als Python 'bool' ist nicht erlaubt. Auch ich machte einen vollständigen ausführbaren Code Beispiel. Auch ihre Beispielcode verwendet einen Tensor als erstes Argument als eine Bedingung: https://www.tensorflow.org/versions/r0.8/api_docs/python/control_flow_ops.html#cond – dtracers

+0

Dies könnte einen Fehler in ihrem Code bedeuten? – dtracers

+0

Nicht wirklich. Die Eingabe in 'tf.pack' ist eine Liste. Versuchen Sie 'return tf.pack ([x_list_stretched, y_list_stretched])' – keveman