2017-10-19 3 views
0

Ist es möglich, dies elegant zu tun?Schreiben und Lesen SparseTensor zu und von einer tfrecord Datei

Momentan kann ich nur daran denken, die Indizes (tf.int64), values ​​(tf.float32) und shape (tf.int64) des SparseTensors in 3 separaten Features zu speichern (die ersten beiden sind VarLenFeature) und die letzte ist FixedLenFeature). Das scheint wirklich beschwerlich zu sein.

Jede Beratung ist willkommen!

Update 1

unter Meine Antwort für den Aufbau einer Berechnung Graph nicht geeignet ist (b/c den Inhalt des spärlichen Tensor über sess.run werden müssen extrahiert(), die viel Zeit kostet, wenn genannt mit tf.deserialize_many_sparse wiederholt.)

Inspiriert von mrry's answer, ich denke, vielleicht können wir die Bytes von tf.serialize_sparse erzeugt erhalten so später können wir die SparseTensor erholen. Aber tf.serialize_sparse ist nicht in reinem Python implementiert (es ruft die externe Funktion SerializeSparse), was bedeutet, dass wir immer noch sess.run() verwenden müssen, um die Bytes zu erhalten. Wie kann ich eine reine Python-Version von SerializeSparse bekommen? Vielen Dank.

Antwort

1

Da Tensorflow derzeit nur drei Arten in tfrecord unterstützt: Float, Int64 und Bytes und ein SparseTensor hat in der Regel mehr als 1 Art, meine Lösung ist, die SparseTensor zu Bytes mit Pickle zu konvertieren.

Hier ist ein Beispielcode:

import tensorflow as tf 
import pickle 
import numpy as np 
from scipy.sparse import csr_matrix 

#---------------------------------# 
# Write to a tfrecord file 

# create two sparse matrices (simulate the values from .eval() of SparseTensor) 
a = csr_matrix(np.arange(12).reshape((4,3))) 
b = csr_matrix(np.random.rand(20).reshape((5,4))) 

# convert them to pickle bytes 
p_a = pickle.dumps(a) 
p_b = pickle.dumps(b) 

# put the bytes in context_list and feature_list 
## save p_a in context_lists 
context_lists = tf.train.Features(feature={ 
    'context_a': tf.train.Feature(bytes_list=tf.train.BytesList(value=[p_a])) 
    }) 
## save p_b as a one element sequence in feature_lists 
p_b_features = [tf.train.Feature(bytes_list=tf.train.BytesList(value=[p_b]))] 
feature_lists = tf.train.FeatureLists(feature_list={ 
    'features_b': tf.train.FeatureList(feature=p_b_features) 
    }) 

# create the SequenceExample 
SeqEx = tf.train.SequenceExample(
    context = context_lists, 
    feature_lists = feature_lists 
    ) 
SeqEx_serialized = SeqEx.SerializeToString() 

# write to a tfrecord file 
tf_FWN = 'test_pickle1.tfrecord' 
tf_writer1 = tf.python_io.TFRecordWriter(tf_FWN) 
tf_writer1.write(SeqEx_serialized) 
tf_writer1.close() 

#---------------------------------# 
# Read from the tfrecord file 

# first, define the parse function 
def _parse_SE_test_pickle1(in_example_proto): 
    context_features = { 
     'context_a': tf.FixedLenFeature([], dtype=tf.string) 
     } 
    sequence_features = { 
     'features_b': tf.FixedLenSequenceFeature([1], dtype=tf.string) 
     } 
    context, sequence = tf.parse_single_sequence_example(
     in_example_proto, 
     context_features=context_features, 
     sequence_features=sequence_features 
    ) 
    p_a_tf = context['context_a'] 
    p_b_tf = sequence['features_b'] 

    return tf.tuple([p_a_tf, p_b_tf]) 

# use the Dataset API to read 
dataset = tf.data.TFRecordDataset(tf_FWN) 
dataset = dataset.map(_parse_SE_test_pickle1) 
dataset = dataset.batch(1) 
iterator = dataset.make_initializable_iterator() 
next_element = iterator.get_next() 

sess = tf.InteractiveSession() 
sess.run(tf.global_variables_initializer()) 
sess.run(iterator.initializer) 

[p_a_bat, p_b_bat] = sess.run(next_element) 

# 1st index refers to batch, 2nd and 3rd indices refers to the sequence position (only for b) 
rec_a = pickle.loads(p_a_bat[0]) 
rec_b = pickle.loads(p_b_bat[0][0][0]) 

# check whether the recovered the same as the original ones. 
assert((rec_a - a).nnz == 0) 
assert((rec_b - b).nnz == 0) 

# print the contents 
print("\n------ a -------") 
print(a.todense()) 
print("\n------ rec_a -------") 
print(rec_a.todense()) 
print("\n------ b -------") 
print(b.todense()) 
print("\n------ rec_b -------") 
print(rec_b.todense()) 

Hier ist, was ich habe:

------ a ------- 
[[ 0 1 2] 
[ 3 4 5] 
[ 6 7 8] 
[ 9 10 11]] 

------ rec_a ------- 
[[ 0 1 2] 
[ 3 4 5] 
[ 6 7 8] 
[ 9 10 11]] 

------ b ------- 
[[ 0.88612402 0.51438017 0.20077887 0.20969243] 
[ 0.41762425 0.47394715 0.35596051 0.96074408] 
[ 0.35491739 0.0761953 0.86217511 0.45796474] 
[ 0.81253723 0.57032448 0.94959189 0.10139615] 
[ 0.92177499 0.83519464 0.96679833 0.41397829]] 

------ rec_b ------- 
[[ 0.88612402 0.51438017 0.20077887 0.20969243] 
[ 0.41762425 0.47394715 0.35596051 0.96074408] 
[ 0.35491739 0.0761953 0.86217511 0.45796474] 
[ 0.81253723 0.57032448 0.94959189 0.10139615] 
[ 0.92177499 0.83519464 0.96679833 0.41397829]]