2017-09-05 7 views
1

ich folgenden Code haben:TensorFlow: Wie für jede Zeile Subarray bekommen in Tensor

import numpy as np 
import tensorflow as tf 

series = tf.placeholder(tf.float32, shape=[None, 5]) 
series_length = tf.placeholder(tf.int32, shape=[None]) 
useful_series = tf.magic_slice_function(series, series_length) 

with tf.Session() as sess: 
    input_x = np.array([[1, 2, 3, 0, 0], 
         [2, 3, 0, 0, 0], 
         [1, 0, 0, 0, 0]]) 
    input_y = np.array([[3], [2], [1]]) 
    print(sess.run(useful_series, feed_dict={series: input_x, series_length: input_y})) 

Erwartete Ausgabe als

folgende

[[1,2,3], [2,3], [1]]

Ich habe mehrere Funktionen, usw. tf.gather, tf.slice ausprobiert. Sie alle funktionieren nicht. Was ist die magic_slice_function?

+1

Wahrscheinlich müssen Sie diese außerhalb Tensorflow, tun, da, was Sie wollen nicht ein Tensor ist zu erhalten. –

Antwort

1

Es ist ein wenig knifflig:

import numpy as np 
import tensorflow as tf 

series = tf.placeholder(tf.float32, shape=[None, 5]) 
series_length = tf.placeholder(tf.int64) 

def magic_slice_function(input_x, input_y): 
    array = [] 
    for i in range(len(input_x)): 
     temp = [input_x[i][j] for j in range(input_y[i])] 
     array.extend(temp) 
    return [array] 

with tf.Session() as sess: 
    input_x = np.array([[1, 2, 3, 0, 0], 
         [2, 3, 0, 0, 0], 
         [1, 0, 0, 0, 0]]) 

    input_y = np.array([3, 2, 1], dtype=np.int64) 

    merged_series = tf.py_func(magic_slice_function, [series, series_length], tf.float32, name='slice_func') 

    out = tf.split(merged_series, input_y) 
    print(sess.run(out, feed_dict={series: input_x, series_length: input_y})) 

Der Ausgang wird sein:

[array([ 1., 2., 3.], dtype=float32), array([ 2., 3.], dtype=float32), array([ 1.], dtype=float32)]