2017-09-07 2 views
0
import tensorflow as tf 

slim = tf.contrib.slim 


def create_learning_rate(curr_step, lr_config): 
    base_lr = lr_config.get('base_lr', 0.1) 
    decay_steps = lr_config.get('decay_steps', []) 
    decay_rate = lr_config.get('decay_rate', 0.1) 

    scale_rates = [ 
     lambda: tf.constant(decay_rate**i, dtype=tf.float32) 
     for i in range(len(decay_steps) + 1) 
    ] 

    conds = [] 
    prev = -1 
    for decay_step in decay_steps: 
     conds.append(tf.logical_and(curr_step > prev, curr_step <= decay_step)) 
     prev = decay_step 
    conds.append(curr_step > decay_steps[-1]) 

    learning_rate_scale = tf.case(
     list(zip(conds, scale_rates)), lambda: 0.0, exclusive=True) 
    return learning_rate_scale * base_lr 


global_step = slim.create_global_step() 
train_op = tf.assign_add(global_step, 1) 
lr = create_learning_rate(
    global_step, {"base_lr": 0.1, 
       "decay_steps": [10, 20], 
       "decay_rate": 0.1}) 

with tf.Session() as sess: 
    init = tf.global_variables_initializer() 
    sess.run(init) 
    for i in range(30): 
     curr_lr, step, _ = sess.run([lr, global_step, train_op]) 
     print(curr_lr, step) 

Ich möchte die Lernrate zu bestimmten Zeiten zu verfallen. Es ist jedoch immer 0,001. Irgendwelche Ideen? Oder gibt es eine bessere Methode, die Lernrate anzupassen?Lernrate Anpassung in Tensorflow

Danke für Ihre Hilfe.

+0

warum nicht Ihre Lernrate als Platzhalter Satz und bei jeder Iteration Feed jeden Wert, den Sie es wollen? –

+0

Ja, es ist auch eine richtige Lösung. –

Antwort

0

Dies liegt daran, dass die Lambda-Funktion die Variable als Referenz statt als Wert erfasst.

So ist die richtige Art und Weise ist

def create_learning_rate(global_step, lr_config): 
    base_lr = lr_config.get('base_lr', 0.1) 
    decay_steps = lr_config.get('decay_steps', []) 
    decay_rate = lr_config.get('decay_rate', 0.1) 

    prev = -1 
    scale_rate = 1.0 

    cases = [] 
    for decay_step in decay_steps: 
     cases.append((tf.logical_and(global_step > prev, 
            global_step <= decay_step), 
        lambda v=scale_rate: v)) 
     scale_rate *= decay_rate 
     prev = decay_step 
    cases.append((global_step > decay_step, lambda v=scale_rate: v)) 
    learning_rate_scale = tf.case(cases, lambda: 0.0, exclusive=True) 
    return learning_rate_scale * base_lr 
Verwandte Themen