2016-07-24 17 views
1

Ich habe eine Klasse in Tensorflow, die Gewichte und Dokumenteinbettungen hat. Ich werde es für meine Ausbildung und Validierung verwenden. Meine Abfrage ist, dass es in Tensorflow-Sitzung für Validierung möglich ist, nur Gewichte aus meinem Training und nicht die Einbettungen wiederzuverwenden und lassen Sie neue Dokumenteneinbettungen für gültigen Satz lernen. CodeschnipselWie kann man nur einige Variablen im Tensorfluss wiederverwenden?

Class NewModel(Object): 
    def __init__(self, is_training, vocabuary_size, embedding_size): 
    self.X = tf.placeholder("float", [None, 300]) 
    self.doc_int = tf.placeholder(tf.int32, shape=[None]) 

    self.embeddings=tf.get_variable("embedding", [vocabulary_size ,embedding_size],initializer=tf.random_uniform_initializer(-0.1, 0.1)) 
    self.embedval = tf.nn.embedding_lookup(self.embeddings ,self.doc_int) 
    self.weights = tf.get_variable("weights",weight_shapeinitializer=tf.random_normal_initializer()) 
    biases = tf.get_variable("biases", bias_shape,initializer=tf.constant_initializer(0.0)) 
    # Some neural network with optimiser and loss that will train weight and embeddings.. 

with tf.Graph().as_default(), tf.Session() as sess: 

    initializer = tf.random_uniform_initializer() 
    with tf.variable_scope("foo", reuse=None, initializer=initializer): 
    train = NewModel(is_training=True, vocabulary_size=4000,\ 
    embedding_size =50) 
    with tf.variable_scope("foo", reuse=True, initializer=initializer): 
     valid = NewModel(is_training=False, vocabulary_size= 1000, embedding_size = 50) 
# Here is where I am confused. I want to use trained variable of weight but not embeddings and 
want new embeddings to be trained for valid set. 
    tf.initialize_all_variables().run() 
# will call some function to run epochs and stuff 

Vielleicht könnte die Verwendung anderer Scope-Namen hilfreich sein, würde aber dennoch einige Ratschläge zum weiteren Vorgehen benötigen. Oder ist es möglich, nur irgendwo zu erwähnen, welche Variablen wiederverwendet werden sollen.

Antwort

0

Ich würde vielleicht die NewModel-Klasse reorganisieren.

Class NewModel(Object): 
    def __init__(self, vocabuary_size, embedding_size, initializer): 
     self.X = tf.placeholder("float", [None, 300]) 
     self.doc_int = tf.placeholder(tf.int32, shape=[None]) 
     self.vocabuary_size = vocabuary_size 
     self.embedding_size = embedding_size 
     self.initializer = initializer 

    def initialize_embeddings(self): 
     with tf.variable_scope("embed",initializer=initializer) as scope: 
      self.embeddings=tf.get_variable("embedding", [self.vocabulary_size ,self.embedding_size],initializer=self.initializer) 
      self.embedval = tf.nn.embedding_lookup(self.embeddings ,self.doc_int) 
      scope.reuse_variable() 

    def initialize_weights(self, weight_shape, biase_shape, initializer=initializer): 
     with tf.variable_scope("weight", initializer=initializer) as scope: 
      self.weights = tf.get_variable("weights",weight_shapeinitializer=self.initializer) 
      biases = tf.get_variable("biases", bias_shape,initializer=tf.constant_initializer(0.0)) 
      scope.reuse_variable() 

    def train_network(self): 
     # Some neural network with optimiser and loss that will train weight and embeddings.. 

    def validate_network(self): 
     # A function for the validation process 

Auf diese Weise haben Sie die Einbettungsinitialisierung durch die Initialisierung der Gewichte und Vorspannungen getrennt. Die Verwendung dieser neuen Klasse wäre wie ...

with tf.Graph().as_default(), tf.Session() as sess: 

    initializer = tf.random_uniform_initializer() 
    model = NewModel(vocabulary_size=4000, embedding_size =50, initializer=initializer) # construct a model instance 
    model.initialize_weights(weight_shape, biase_shape) # initialize the weights and biases 
    model.initialize_embeddings() # initialize embeddings 
    model.train_network() # train the network 
    # Before start validation process, re-initialize embeddings 
    model.initialize_embeddings() 
    model.validate_network() 
Verwandte Themen