2017-06-06 5 views
-1

Ich habe drei Arrays, X, Y und Z. Ich möchte res und Element von X setzen, falls das entsprechende Element von Z wahr ist; Ansonsten werde ich ein Element aus Y setzen.wo() von 1 bis 2 Positionsargumente nimmt, aber 3 wurden gegeben

ich es wie folgt umgesetzt:

X = tf.constant([[1, 2], [3, 4]]) 
Y = tf.constant([[5, 6], [7, 8]]) 
Z = tf.constant([[True, False], [False, True]], tf.bool) 
res = tf.where(Z, X, Y) 
print(res.eval()) 

Allerdings habe ich diese Störung erhalte:

TypeError: where() takes from 1 to 2 positional arguments but 3 were given 

ich an der Definiton von tf.where sah von here und meine Nutzung scheint in Ordnung.

Jede Idee, was könnte das Problem sein?

+0

können Sie versuchen, 'tf.where (Z, x = X, y = Y)' – pramod

+0

Ihr Code funktioniert gut mit TensorFlow 1.0.1, also bin ich neugierig: die TF Version verwendest du? – npf

Antwort

1

Ich vermute, dass Sie eine alte Version von TensorFlow verwenden:

z.B. in r0.10 tf.where verwendet, um nur 2 Argumente zu nehmen.

tf.where(input, name=None)

https://www.tensorflow.org/versions/r0.10/api_docs/python/math_ops/sequence_comparison_and_indexing#where

+0

Ich benutze '0.8.0', wahrscheinlich weil ich es mit' pip' installiert habe. – octavian

+0

Das macht dann Sinn. Sie sollten wahrscheinlich die neueste Version installieren: https://www.tensorflow.org/install/ – npf

Verwandte Themen