2017-08-25 1 views
1

Ich muss eine Variable epsilon_n erstellen, die die Definition (und den Wert) basierend auf dem aktuellen step ändert. Da ich mehr als zwei Fälle habe, scheint es, dass ich tf.cond nicht verwenden kann. Ich versuche, tf.case wie folgt zu verwenden:Tesnorflow: Kann tf.case nicht mit Eingabeargument verwenden

import tensorflow as tf 

#### 
EPSILON_DELTA_PHASE1 = 33e-4 
EPSILON_DELTA_PHASE2 = 2.5 
#### 
step = tf.placeholder(dtype=tf.float32, shape=None) 


def fn1(step): 
    return tf.constant([1.]) 

def fn2(step): 
    return tf.constant([1.+step*EPSILON_DELTA_PHASE1]) 

def fn3(step): 
    return tf.constant([1.+step*EPSILON_DELTA_PHASE2]) 

epsilon_n = tf.case(
     pred_fn_pairs=[ 
      (tf.less(step, 3e4), lambda step: fn1(step)), 
      (tf.less(step, 6e4), lambda step: fn2(step)), 
      (tf.less(step, 1e5), lambda step: fn3(step))], 
      default=lambda: tf.constant([1e5]), 
     exclusive=False) 

Aber ich erhalte immer diese Fehlermeldung:

TypeError: <lambda>() missing 1 required positional argument: 'step' 

Ich habe versucht, die folgenden:

epsilon_n = tf.case(
     pred_fn_pairs=[ 
      (tf.less(step, 3e4), fn1), 
      (tf.less(step, 6e4), fn2), 
      (tf.less(step, 1e5), fn3)], 
      default=lambda: tf.constant([1e5]), 
     exclusive=False) 

Noch würde ich die gleichen Fehler . Die Beispiele in der Tensorflow-Dokumentation berücksichtigen Fälle, in denen kein Eingabeargument an die aufrufbaren Funktionen übergeben wird. Ich konnte nicht genug Informationen über tf.case im Internet finden! Bitte irgendeine Hilfe?

Antwort

2

Hier sind einige Änderungen, die Sie vornehmen müssen. Aus Gründen der Konsistenz können Sie alle Rückgabewerte als Variable festlegen.

# Since step is a scalar, scalar shape [() or [], not None] much be provided 
step = tf.placeholder(dtype=tf.float32, shape=()) 


def fn1(step): 
    return tf.constant([1.]) 

# Here you need to use Variable not constant, since you are modifying the value using placeholder 
def fn2(step): 
    return tf.Variable([1.+step*EPSILON_DELTA_PHASE1]) 

def fn3(step): 
    return tf.Variable([1.+step*EPSILON_DELTA_PHASE2]) 

epsilon_n = tf.case(
    pred_fn_pairs=[ 
     (tf.less(step, 3e4), lambda : fn1(step)), 
     (tf.less(step, 6e4), lambda : fn2(step)), 
     (tf.less(step, 1e5), lambda : fn3(step))], 
     default=lambda: tf.constant([1e5]), 
    exclusive=False) 
+0

kleinere Schreibfehler behoben –

Verwandte Themen