2016-06-18 5 views
2

Ich versuche, Verschachtelung für Schleifen mit der Scan-Funktion zu simulieren, aber das ist langsam. Gibt es eine bessere Möglichkeit, die Verschachtelung von Schleifen mit Tensorflow zu simulieren? Ich mache diese Berechnung nicht nur mit numpy, damit ich automatisch differenzieren kann.Emuliert verschachtelte für Schleifen mit Scan ist langsam

Insbesondere falte ich über ein Bild mit einem bilateralen Filter alle während der Verwendung von Tensorflow-Steuerelement ops. Um dies zu erreichen, habe ich scan() -Funktionen verschachtelt, aber das lässt mich mit einer bemerkenswert schlechten Leistung - die Filterung eines kleinen Bildes dauert mehr als 5 Minuten.

Gibt es einen besseren Weg als die Verschachtelung von Scan-Funktionen und wie stark benutze ich Tensorflow-Kontrollfluss-Operationen? Ich interessiere mich für allgemeine Antworten mehr als eine spezifische für meinen Code. Hier

ist das Original, schneller Code, wenn Sie es sehen wollen:

def bilateralFilter(image, sigma_space=1, sigma_range=None, win_size=None): 

    if sigma_range is None: 
     sigma_range = sigma_space 
    if win_size is None: win_size = max(5, 2 * int(np.ceil(3*sigma_space)) + 1) 

    win_ext = (win_size - 1)/2 
    height = image.shape[0] 
    width = image.shape[1] 

    # pre-calculate spatial_gaussian 
    spatial_gaussian = [] 
    for i in range(-win_ext, win_ext+1): 
     for j in range(-win_ext, win_ext+1): 
      spatial_gaussian.append(np.exp(-0.5*(i**2+j**2)/sigma_space**2)) 

    padded = np.pad(image, win_ext, mode="edge") 

    out_image = np.zeros(image.shape) 
    weight = np.zeros(image.shape) 

    idx = 0 
    for row in xrange(-win_ext, 1+win_ext): 
     for col in xrange(-win_ext, 1+win_ext): 
      slice = padded[win_ext+row:height+win_ext+row, 
              win_ext+col:width+win_ext+col] 
      value = np.exp(-0.5*((image - slice)/sigma_range)**2) \ 
        * spatial_gaussian[idx] 
      out_image += value*slice 
      weight += value 
      idx += 1 

    out_image /= weight 

    return out_image 

Dies ist die Tensorflow Version:

sess = tf.InteractiveSession() 
with sess.as_default(): 
    def bilateralFilter(image, sigma_space, sigma_range): 
     win_size = max(5., 2 * np.ceil(3 * sigma_space) + 1) 

     win_ext = int((win_size - 1)/2) 
     height = tf.shape(image)[0].eval() 
     width = tf.shape(image)[1].eval() 

     spatial_gaussian = [] 
     for i in range(-win_ext, win_ext + 1): 
      for j in range(-win_ext, win_ext + 1): 
       spatial_gaussian.append(np.exp(-0.5 * (i ** 2 +\ 
       j ** 2)/sigma_space ** 2)) 

     # we use "symmetric" as it best approximates "edge" padding 
     padded = tf.pad(image, [[win_ext, win_ext], [win_ext, win_ext]], 
       mode='SYMMETRIC') 
     out_image = tf.zeros(tf.shape(image)) 
     weight = tf.zeros(tf.shape(image)) 

     spatial_index = tf.constant(0) 
     row = tf.constant(-win_ext) 
     col = tf.constant(-win_ext) 

     def cond(padded, row, col, weight, out_image, spatial_index): 
      return tf.less(row, win_ext + 1) 

     def body(padded, row, col, weight, out_image, spatial_index): 
      sub_image = tf.slice(padded, [win_ext + row, win_ext + col], 
         [height, width]) 
      value = tf.exp(-0.5 * 
        (((image - sub_image)/sigma_range) ** 2)) * 
        spatial_gaussian[spatial_index.eval()] 
      out_image += value * sub_image 
      weight += value 
      spatial_index += 1 
      row, col = tf.cond(tf.not_equal(tf.mod(col, 
           tf.constant(2*win_ext + 1)), 0), 
           lambda: (row + 1, tf.constant(-win_ext)), 
           lambda: (row, col)) 
      return padded, row, col, weight, out_image, spatial_index 

     padded, row, col, weight, out_image, spatial_index = 
     tf.while_loop(cond, body, 
     [padded, row, col, weight, out_image, spatial_index]) 
     out_image /= weight 

     return out_image 

    cat = plt.imread("cat.png") # grayscale 
    cat = tf.reshape(tf.constant(cat), [276, 276]) 
    cat_blurred = bilateralFilter(cat, 2., 0.25) 
    cat_blurred = cat_blurred.eval() 
    plt.figure() 
    plt.gray() 
    plt.imshow(cat_blurred) 
    plt.show() 

Antwort

1

Hier ist ein Problem mit Ihrem Code. cols() hat eine Menge Python-Globals, und Sie schienen zu erwarten, dass sie bei jeder Schleifeniteration aktualisiert werden. Sehen Sie sich das TensorFlow-Tutorial zur Konstruktion und Ausführung von Graphen an. Kurz gesagt, diese Python-Globals und ihr zugehöriger Code werden nur zur Entwurfszeit des Graphen ausgeführt, und sie sind nicht einmal in TensorFlows Ausführungsdiagramm. Eine Operation kann nur in das Ausführungsdiagramm eingeschlossen werden, wenn es sich um einen tf-Operator handelt.

Es scheint auch, dass tf.while_loop besser für Ihren Code als Scan geeignet ist.

+0

Ich habe den Code aktualisiert, aber ich nehme an, dass ich soweit gekommen bin, weil ich mit Tensoren indexiere und Tensorflow das nicht erlaubt. "ValueError: Fetch argument ... wurde als nicht erreichbar markiert." – TFUser

Verwandte Themen