Ich versuche conditionals mit tensorflow zu verwenden, und ich den Fehler bekommen:Tensorflow bedingte Wurfwertfehler
ValueError: Shapes (1,) and() are not compatible
Unterhalb der Code verwende ich, dass der Fehler zu werfen. Es sagt, der Fehler in der bedingten
import tensorflow as tf
import numpy as np
X = tf.constant([1, 0])
Y = tf.constant([0, 1])
BOTH = tf.constant([1, 1])
WORKING = tf.constant(1)
def create_mult_func(tf, amount, list):
def f1():
return tf.scalar_mul(amount, list)
return f1
def create_no_op_func(tensor):
def f1():
return tensor
return f1
def stretch(tf, points, dim, amount):
"""points is a 2 by ??? tensor, dim is a 1 by 2 tensor, amount is tensor scalor"""
x_list, y_list = tf.split(0, 2, points)
x_stretch, y_stretch = tf.split(1, 2, dim)
is_stretch_X = tf.equal(x_stretch, WORKING, name="is_stretch_x")
is_stretch_Y = tf.equal(y_stretch, WORKING, name="is_stretch_Y")
x_list_stretched = tf.cond(is_stretch_X,
create_mult_func(tf, amount, x_list), create_no_op_func(x_list))
y_list_stretched = tf.cond(is_stretch_Y,
create_mult_func(tf, amount, y_list), create_no_op_func(y_list))
return tf.concat(1, [x_list_stretched, y_list_stretched])
example_points = np.array([[1, 1], [2, 2], [3, 3]], dtype=np.float32)
example_point_list = tf.placeholder(tf.float32)
result = stretch(tf, example_point_list, X, 1)
sess = tf.Session()
with tf.Session() as sess:
result = sess.run(result, feed_dict={example_point_list: example_points})
print(result)
Stapelüberwachung:
File "/path/test2.py", line 36, in <module>
result = stretch(tf, example_point_list, X, 1)
File "/path/test2.py", line 28, in stretch
create_mult_func(tf, amount, x_list), create_no_op_func(x_list))
File "/path/tensorflow/python/ops/control_flow_ops.py", line 1142, in cond
p_2, p_1 = switch(pred, pred)
File "/path/tensorflow/python/ops/control_flow_ops.py", line 203, in switch
return gen_control_flow_ops._switch(data, pred, name=name)
File "/path/tensorflow/python/ops/gen_control_flow_ops.py", line 297, in _switch
return _op_def_lib.apply_op("Switch", data=data, pred=pred, name=name)
File "/path/tensorflow/python/ops/op_def_library.py", line 655, in apply_op
op_def=op_def)
File "/path/tensorflow/python/framework/ops.py", line 2156, in create_op
set_shapes_for_outputs(ret)
File "/path/tensorflow/python/framework/ops.py", line 1612, in set_shapes_for_outputs
shapes = shape_func(op)
File "/path/tensorflow/python/ops/control_flow_ops.py", line 2032, in _SwitchShape
unused_pred_shape = op.inputs[1].get_shape().merge_with(tensor_shape.scalar())
File "/path/tensorflow/python/framework/tensor_shape.py", line 554, in merge_with
(self, other))
ValueError: Shapes (1,) and() are not compatible
ich versucht habe, die WORKING Wechsel ein Array statt eines skalaren zu sein.
Ich glaube, dass das Problem ist, dass tf.equal
eine int32
anstelle des Bool zurückgibt, die es soll entsprechend der Dokumentation zu tf.cond
nicht der Fehler nicht in dem bedingten in der Form der Tensor ist, dass Sie versuchen Zum Vergleich: 'x_stretch' hat nicht die gleiche Form von' WORKING'. Scheint x_stretch ist eine Dimension größer als 'WORKING'. Was ist der Inhalt von "Punkten"? Wenn Sie einen ausführbaren Code bereitstellen, kann ich Ihnen mehr helfen. – fabrizioM
Ich machte ein komplettes ausführbares Beispiel, das Sie einfügen können und sehen, was scheitert – dtracers