2017-10-24 1 views
0

Das neurale Netzwerk, an dem ich gerade arbeite, akzeptiert einen spärlichen Tensor als Eingabe. Ich meine Daten von einem TFRecord wie folgt lautet:Äquivalent von tf.SparseFeature in tf.data

_, examples = tf.TFRecordReader(options=options).read_up_to(
    filename_queue, num_records=batch_size) 

features = tf.parse_example(examples, features={ 
      'input_feat': tf.SparseFeature(index_key='input_feat_idx', 
             value_key='input_feat_values', 
             dtype=tf.int64, 
             size=SIZE_FEATURE)}) 

Es funktioniert wie ein Charme, aber ich war auf der Suche auf der tf.data API, die für eine Vielzahl von Aufgaben bequemer aussieht, und ich bin nicht sicher, wie tf.SparseTensor Objekte lesen wie ich mit dem und tf.parse_example(). Irgendeine Idee?

Antwort

3

TensorFlow 1.5 fügt native Unterstützung für tf.SparseTensor in den Kerntransformationen hinzu. (Dies ist derzeit verfügbar, wenn Sie pip install tf-nightly oder von der Quelle auf dem Master-Zweig von TensorFlow erstellen.) Dies bedeutet, dass Sie Ihre Pipeline wie folgt schreiben können: