2017-01-14 2 views
0

Ich möchte eine benutzerdefinierte Ebene schreiben, wo ich eine Variable im Speicher zwischen Läufen behalten kann. Zum BeispielPersistente Variable in Keras Benutzerdefinierte Ebene

class MyLayer(Layer): 
def __init__(self, out_dim = 51, **kwargs): 
    self.out_dim = out_dim 
    super(MyLayer, self).__init__(**kwargs) 

def build(self, input_shape): 
    a = 0.0 
    self.persistent_variable = K.variable(a) 
    self.built = True 

def get_output_shape_for(self, input_shape): 
    return (input_shape[0], 1) 

def call(self, x, mask=None): 
    a = K.eval(self.persistent_variable) + 1 
    K.set_value(self.persistent_variable, a) 
    return self.persistent_variable 

m = Sequential() 
m.add(MyLayer(input_shape=(1,))) 

Wenn ich laufe m.predict erwarte ich, dass die persistent_variable aktualisiert werden, und der erhöhte Wert drucken. Aber es sieht aus wie es immer druckt 0

# Dummy input 
x = np.zeros(1) 

m.predict(x, batch_size=1) 

Meine Frage ist, wie kann ich das persistent_variable Schritt und speichern nach jedem Lauf von m.predict

Danke, Naveen

Antwort

1

Der Trick ist, dass Sie müssen self.add_update(...) in Ihrer Anruffunktion anrufen, um eine Funktion zu registrieren, die jedes Mal aufgerufen wird, wenn Ihr Modell evaluiert wird (ich fand dies, indem ich in den Quellcode der stateful rnns eingrabe). Wenn Sie self.stateful = True tun, wird es Ihre benutzerdefinierte Update-Funktion für jeden Trainings- und Vorhersageanruf aufrufen, andernfalls wird es nur während des Trainings aufgerufen. Zum Beispiel:

import keras.backend as K 
import numpy as np 
from keras.engine.topology import Layer 

class CounterLayer(Layer): 
    def __init__(self, stateful=False,**kwargs): 
    self.stateful = stateful # True means it will increment counter on predict and train, false means it will only increment counter on train 
    super(CounterLayer, self).__init__(**kwargs) 


    def build(self, input_shape): 
    # Define variables in build 
    self.count = K.variable(0, name="count") 
    super(CounterLayer, self).build(input_shape) 

    def call(self, x, mask=None): 
    updates = [] 
    # The format is (variable, value setting to) 
    # So this says 
    # self.pos = self.pos + 1 
    updates.append((self.count, self.count+1)) 

    # You can append more updates to this list or call add_update more 
    # times if you want 

    # Add our custom update 

    # We stick x here so it calls our update function every time our layer 
    # is given a new x 
    self.add_update(updates, x) 

    # This will be an identity layer but keras gets mad for some reason 
    # if you just output x so we'll multiply it by 1 so it thinks it is a 
    # "new variable" 
    return self.count 
    # in newer keras versions you might need to name this compute_output_shape instead 
    def get_output_shape_for(self, input_shape): 
    # We will just return our count as an array ([[count]]) 
    return (1,1) 

    def reset_states(self): 
    self.count.set_value(0) 

Beispiel Nutzung:

from keras.layers import Input 
from keras.models import Model 
from keras.optimizers import RMSprop 
inputLayer = Input(shape=(10,)) 
counter = CounterLayer() # Don't update on predict 
# counter = CounterLayer(stateful=True) # This will update each time you call predict 
counterLayer = counter(inputLayer) 
model = Model(input=inputLayer, output=counterLayer) 
optimizer = RMSprop(lr=0.001) 
model.compile(loss="mse", optimizer=optimizer) 


# See the value of our counter 
print counter.count.get_value() 

# This won't actually train anything but each epoch will update our counter 

# Note that if you say have a batch size of 5, update will be called 5 times per epoch 
model.fit(np.zeros([1, 10]), np.array([0]), batch_size=1, nb_epoch=5) 

# The value of our counter has now changed 
print counter.count.get_value() 

model.predict(np.zeros([1, 10])) 

# If we did stateful=False, this didn't change, otherwise it did 
print counter.count.get_value() 
+0

Hallo Phylliida, Sieht wie die richtige Lösung. Aber es funktioniert manchmal nicht. Ich lief 'a = model.predict (np.random.rand (100, 10), batch_size = 1) drucken (a)' '[0. 1. 2. 3. 5. 6. 6. 7 9. 10. 10. 11. ....] ' Manchmal fehlt es Update. –

+0

Huh, es könnte eine Art Race Condition sein. Ich weiß es wirklich nicht leid, wir können warten, um zu sehen, ob jemand anderes weiß – Phylliida

+1

Sie haben Recht. In Keras könnte es eine Wettlaufsituation geben. Ich habe eine 'RepeatVector' Schicht nach' CounterLayer' hinzugefügt, und es hat funktioniert. –

Verwandte Themen