2017-02-09 2 views
1

Verwendung TensorFlow die HashTable-Lookup-Implementierung i meinen SparseTensor wieder mit dem Standardwert geliefert. Ich möchte, dass reinigen und eine endgültige SparseTensor ohne den Standardwert zu erhalten.Wie nur Nicht-Null-Werte von Sparse Tensor bekommen

Wie kann ich den Standardwert reinigen? Ich habe nichts dagegen, was der Standardwert, um sein wird, dies zu verwirklichen. 0 ist in Ordnung und so ist -1.

Antwort

0

tf.sparse_retain sollte funktionieren:

def sparse_remove(sparse_tensor, remove_value=0.): 
    return tf.sparse_retain(sparse_tensor, tf.not_equal(a.values, remove_value)) 

Als Beispiel:

import tensorflow as tf 

a = tf.SparseTensor(indices=[[1, 2], [2, 2]], values=[0., 1.], shape=[3, 3]) 
with tf.Session() as session: 
    print(session.run([a, sparse_remove(a)])) 

Drucke (ich es leicht umformatiert haben):

[SparseTensorValue(indices=array([[1, 2], [2, 2]]), values=array([ 0., 1.], dtype=float32), shape=array([3, 3])), 
SparseTensorValue(indices=array([[2, 2]]), values=array([ 1.], dtype=float32), shape=array([3, 3]))]