0

In Double DQN (implementiert in CNTK), ich versuche, den Wert der nächsten Staaten (post_state_var) mithilfe des Online-Modells zu berechnen. Um meine Lösung zu vektorisieren, habe ich die one_hot-Operation verwendet. Allerdings bekomme ich den folgenden Fehler, wenn ich versuche zu trainieren:OneHot op von backpropagation ausschließen

Der Knoten "OneHot" kann im Training verwendet werden, aber es nicht in der Gradientenausbreitung teilnehmen.

ich mein Modell und Eingänge wie definiert haben:

state_var = cntk.input_variable(state_shape, name='state') 
action_var = cntk.input_variable(1, name='action') 
reward_var = cntk.input_variable(1, name='reward') 
post_state_var = cntk.input_variable(state_shape, name='post_state') 
terminal_var = cntk.input_variable(1, name='terminal') 

with cntk.default_options(activation=relu): 
    model_fn = Sequential([ 
     Dense(32, name='h1'), 
     Dense(32, name='h2'), 
     Dense(action_shape, name='action') 
    ]) 

model = model_fn(state_var) 
target_model = model.clone(cntk.CloneMethod.freeze) 

ich die Zielwerte dann berechnen und den Verlust wie folgt definieren:

# Value of action selected at state t 
state_value = cntk.reduce_sum(model * one_hot(action_var, num_classes=action_shape), axis=1) 

# Double Q learning - Value of action selected at state t+1 
online_post_state_model = model_fn(post_state_var) 
online_post_state_best_action = cntk.argmax(online_post_state_model) 
post_state_best_value = cntk.reduce_sum(target_model * 
             one_hot(online_post_state_best_action, num_classes=action_shape)) 

gamma = 0.99 
target = reward_var + (1.0 - terminal_var) * gamma * post_state_best_value 

# MSE for simplicity 
td_error = state_value - cntk.stop_gradient(target) 
loss = cntk.reduce_mean(cntk.square(td_error)) 

Wenn ich

online_post_state_model = model_fn(post_state_var) 
ersetzen

mit

dann ist der Fehler weg, aber das ist falsch, da es ein altes eingefrorenes Modell verwendet, um das Ziel zu berechnen. Wie kann ich model_fn mit post_state_var auswerten und die Ausgabe von der Rückpropagation ausschließen? Benütze ich stop_gradient nicht richtig?

Antwort

0

Die typische Verwendung von one_hot ist für Eingabedaten, für die Sie normalerweise keine Backpropagate benötigen.

Eine Abhilfe wäre, die Aktionen als ein Hot-Vektoren in Ihrem Diagramm zu halten. Sie können dies tun, indem Sie hardmax anstelle von argmax verwenden.

Verwandte Themen