2016-08-26 3 views
2

Ich lief das Demo-Programm von Word2Vec, die in TensorFlow enthalten ist, und jetzt versucht, das vortrainierte Modell aus Dateien wiederherzustellen, aber es funktioniert nicht.Problem mit der Wiederherstellung vortrained Modell in Tensorflow

lief ich dieses Skript-Datei: https://github.com/tensorflow/tensorflow/blob/r0.10/tensorflow/models/embedding/word2vec.py

Dann diese Datei Ich habe versucht, laufen:

#!/usr/bin/env python 

import tensorflow as tf 

FILENAME_META = "model.ckpt-70707299.meta" 
FILENAME_CHECKPOINT = "model.ckpt-70707299" 


def main(): 
    with tf.Session() as sess: 
     saver = tf.train.import_meta_graph(FILENAME_META) 
     saver.restore(sess, FILENAME_CHECKPOINT) 


if __name__ == '__main__': 
    main() 

Es ist mit der folgenden Fehlermeldung

Traceback (most recent call last): 
    File "word2vec_restore.py", line 16, in <module> 
    main() 
    File "word2vec_restore.py", line 11, in main 
    saver = tf.train.import_meta_graph(FILENAME_META) 
    File "/home/kato/.pyenv/versions/3.5.1/lib/python3.5/site-packages/tensorflow/python/training/saver.py", line 1431, in import_meta_graph 
    return _import_meta_graph_def(read_meta_graph_file(meta_graph_or_file)) 
    File "/home/kato/.pyenv/versions/3.5.1/lib/python3.5/site-packages/tensorflow/python/training/saver.py", line 1321, in _import_meta_graph_def 
    producer_op_list=producer_op_list) 
    File "/home/kato/.pyenv/versions/3.5.1/lib/python3.5/site-packages/tensorflow/python/framework/importer.py", line 247, in import_graph_def 
    op_def = op_dict[node.op] 
KeyError: 'Skipgram' 

Ich betrachte ich habe verstanden, versagt das API-Dokument von TensorFlow, und ich habe den obigen Code so implementiert, wie er darin geschrieben ist. Verwende ich das Saver-Objekt falsch?

Antwort

0

Versuchen Sie Folgendes:

saver = tf.train.Saver() 
with tf.Session() as sess: 
    checkpoint = tf.train.get_checkpoint_state(checkpoint_dir) 
    if checkpoint and checkpoint.model_checkpoint_path: 
     saver.restore(sess, checkpoint.model_checkpoint_path) 

Wo checkpoint_dir Pfad, enthält Checkpoint-Dateien, nicht vollständigen Pfad zum meta oder Checkpoint-Dateien in Ordner ist. Tensorflow wählt den aktuellsten Checkpoint selbst aus dem angegebenen Ordner aus.

+0

Es schlägt mit "ValueError: Keine Variablen zu speichern" in der ersten Zeile fehl. –

2

Ich löste das alleine. Ich habe mich gefragt, woher der Schlüssel "Skipgram" kommt, und den Quellcode gegraben. Um das Problem zu lösen, fügen Sie einfach die folgenden auf der Oberseite:

from tensorflow.models.embedding import gen_word2vec 

Ich verstehe immer noch nicht genau, was ich tue, aber vielleicht ist dies, weil es notwendig ist, eine ähnliche Bibliothek in C geschrieben laden ++.

Danke.

Verwandte Themen