In Double DQN (implementiert in CNTK), ich versuche, den Wert der nächsten Staaten (post_state_var) mithilfe des Online-Modells zu berechnen. Um meine Lösung zu vektorisieren, habe ich die one_hot-Operation verwendet. Allerdings bekomme ich den folgenden Fehler, wenn ich versuche zu trainieren:OneHot op von backpropagation ausschließen
Der Knoten "OneHot" kann im Training verwendet werden, aber es nicht in der Gradientenausbreitung teilnehmen.
ich mein Modell und Eingänge wie definiert haben:
state_var = cntk.input_variable(state_shape, name='state')
action_var = cntk.input_variable(1, name='action')
reward_var = cntk.input_variable(1, name='reward')
post_state_var = cntk.input_variable(state_shape, name='post_state')
terminal_var = cntk.input_variable(1, name='terminal')
with cntk.default_options(activation=relu):
model_fn = Sequential([
Dense(32, name='h1'),
Dense(32, name='h2'),
Dense(action_shape, name='action')
])
model = model_fn(state_var)
target_model = model.clone(cntk.CloneMethod.freeze)
ich die Zielwerte dann berechnen und den Verlust wie folgt definieren:
# Value of action selected at state t
state_value = cntk.reduce_sum(model * one_hot(action_var, num_classes=action_shape), axis=1)
# Double Q learning - Value of action selected at state t+1
online_post_state_model = model_fn(post_state_var)
online_post_state_best_action = cntk.argmax(online_post_state_model)
post_state_best_value = cntk.reduce_sum(target_model *
one_hot(online_post_state_best_action, num_classes=action_shape))
gamma = 0.99
target = reward_var + (1.0 - terminal_var) * gamma * post_state_best_value
# MSE for simplicity
td_error = state_value - cntk.stop_gradient(target)
loss = cntk.reduce_mean(cntk.square(td_error))
Wenn ich
online_post_state_model = model_fn(post_state_var)
ersetzen
mit
dann ist der Fehler weg, aber das ist falsch, da es ein altes eingefrorenes Modell verwendet, um das Ziel zu berechnen. Wie kann ich model_fn
mit post_state_var
auswerten und die Ausgabe von der Rückpropagation ausschließen? Benütze ich stop_gradient
nicht richtig?