2017-02-03 11 views
1

Ich habe eine (n, m) Tensor X, wo ich alle Werte kleiner als ein Schwellenwert t Null setzen möchte. Das heißt,Entfernen Sie kleine Werte aus der Matrix

X = X * tf.cast(tf.greater(X, t), X.dtype) 

Ich frage mich, ist es eine effizientere Art und Weise, dies zu tun? Weil X in meinem Setup ist riesig und wie ich es verstehe, baut die tf.cast(tf.greater(X, t), X.dtype) einen anderen Tensor, der so viel Speicher wie X benötigt.

+0

Haben Sie versucht, die 'eval' Methode in Tensorflow, um Ihren Tensor in ein numpy Array zu transformieren und dann eine der folgenden Antworten zu verwenden? –

+0

Ich bin hier nur brainstorming, "tf.add" und "tfsubtract" unterstützen Broadcasting, also sollten sie speichereffizient sein. Kannst du vielleicht versuchen, 't' zu subtrahieren, dann' tf.clip_by_value (...) 'und dann' t' zurück? Intuitiv führt es zu mehr Operationen, aber weniger Speicherverbrauch? –

+0

@RobertLacok Danke für die Eingabe, aber das ist nicht das Gleiche, denke ich. durch Hinzufügen von "t" ersetzen Sie Nullen durch "t" – fabian789

Antwort

0

Was mit der guten alten

for i in range(n): 
    for j in range(m): 
     if X[n][m] < t: X[n][m] = 0 
+1

Es ist in Tensorflow – fabian789

0
falsch ist

Wenn X Ihre Matrix (a numpy Array Ich gehe davon aus) Sie können versuchen:

x[x<small_value]=0 

wenn der boolean-Array erstellen Speicher zu viel nimmt Sie können dies durch eine Schleife nach einzelnen Spalten versuchen.

+1

Es ist in Tensorflow. Hätte das aus der Frage deutlicher machen sollen, aber der Tag war da :) – fabian789

+0

Oh Entschuldigung, keine Ahnung, ich habe sie nie benutzt – Numlet

1

Ich bin nicht sicher, ob dies effiziente

x = tf.constant([1, 2, 3, 4, 5, 6, 7]) 
y = tf.where(tf.greater(x, tf.constant(5)), 
      x, # if ture 
      tf.zeros_like(x)) # if false 

with tf.Session() as sess: 
    a = sess.run(y) 
    # a is [0, 0, 0, 0, 0, 6, 7] 
0
foo = tf.constant([1., 2., 3., 4., 5., 6., 7., 8., 9., 10.]) 
threshold_map = tf.greater(foo, tf.constant(5.)) 
threshold_map_index = tf.reshape(tf.where(threshold_map), [-1]) 
foo_threshold = tf.gather(foo, threshold_map_index) 
# foo_threshold = [6., 7., 8., 9., 10.] 

(dies wird nicht Arbeit mit mehr als einer Dimension)

Verwandte Themen