2017-01-24 30 views
2

Nach Tensorflow LSTM Regularization Ich versuche, Regulationization Term der Kostenfunktion hinzufügen, wenn Trainingsparameter von LSTM-Zellen.TensorFlow: Hinzufügen von Regularisierung zu LSTM

einige Konstanten Läßt man einmal beiseite ich habe:

def RegularizationCost(trainable_variables): 
    cost = 0 
    for v in trainable_variables: 
     cost += r(tf.reduce_sum(tf.pow(r(v.name),2))) 
    return cost 

... 

regularization_cost = tf.placeholder(tf.float32, shape =()) 
cost = tf.reduce_sum(tf.pow(pred - y, 2)) + regularization_cost 
optimizer = tf.train.AdamOptimizer(learning_rate = 0.01).minimize(cost) 

... 

tv = tf.trainable_variables() 
s = tf.Session() 
r = s.run 

... 

while (...): 
    ... 

    reg_cost = RegularizationCost(tv) 
    r(optimizer, feed_dict = {x: x_b, y: y_b, regularization_cost: reg_cost}) 

Das Problem, das ich habe ist, dass der Regularisierungsterm enorm Hinzufügen des Lernprozess verlangsamt und tatsächlich die Regularisierungsterm reg_cost wird sichtbar bei jeder Iteration zu erhöhen, wenn der Begriff im Zusammenhang mit pred - y ziemlich stagniert, dh die reg_cost scheint nicht berücksichtigt zu werden.

Wie ich vermute, füge ich diesen Begriff auf völlig falsche Weise hinzu. Ich wusste nicht, wie man diesen Begriff in der Kostenfunktion selbst hinzufügt, also benutzte ich einen Workaround mit skalarem tf.placeholder und "manuell" berechnete die Regularisierungskosten. Wie man es richtig macht?

Antwort

2

den L2-Verlust nur einmal berechnen:

tv = tf.trainable_variables() 
regularization_cost = tf.reduce_sum([ tf.nn.l2_loss(v) for v in tv ]) 
cost = tf.reduce_sum(tf.pow(pred - y, 2)) + regularization_cost 
optimizer = tf.train.AdamOptimizer(learning_rate = 0.01).minimize(cost) 

Sie die Variablen entfernen möchten, die bias wie die nicht legalisiert werden sollte.

0

Es verlangsamt, weil Ihr Code neue Knoten in jeder Iteration erstellt. So kodieren Sie nicht mit TF. Zuerst erstellen Sie Ihr gesamtes Diagramm, einschließlich Regularisierungsbedingungen, und führen sie dann in der while-Schleife nur aus. Jede "tf.XXX" -Operation erstellt neue Knoten.

Verwandte Themen