2017-02-02 3 views
2

Ich schreibe einen einfachen Code, um One-Hot-Codierung für eine Liste von Indizes zu berechnen. Bsp: [1,2,3] => [[0,1,0,0], [0,0,1,0], [0,0,0,1]]Wie wird über die Zeilen einer Theano-Matrix mit der Scan-Funktion in theano iteriert?

Ich schreibe a Funktion das gleiche für einen einzelnen Vektor zu tun:

n_val =4 
def encoding(x_t): 
    z = T.zeros((x_t.shape[0], n_val)) 
    one_hot = T.set_subtensor(z[T.arange(x_t.shape[0]), x_t], 1) 
    return one_hot 

um die gleiche Funktion über die Zeilen einer Matrix zu wiederholen, ich die folgenden,

x = T.imatrix() 
[m],_ = theano.scan(fn = encoding, sequences = x) 

Y = T.stacklists(m) 
f= theano.function([x],Y) 

ich mit jeder Scheibe einen 3D-Tensor erwarte entsprechend der Ein-Hot-Codierung der Zeilen der Matrix.

ich die folgende Störung erhalte, während die Funktion kompilieren,

/Library/Python/2.7/site-packages/theano/tensor/var.pyc in __iter__(self) 
594   except TypeError: 
595    # This prevents accidental iteration via builtin.sum(self) 
--> 596    raise TypeError(('TensorType does not support iteration. ' 
    597        'Maybe you are using builtin.sum instead of ' 
598        'theano.tensor.sum? (Maybe .max?)')) 

TypeError: TensorType does not support iteration. Maybe you are using builtin.sum instead of theano.tensor.sum? (Maybe .max?) 

Kann jemand bitte helfen Sie mir zu verstehen, wo ich falsch werde und wie ich den Code ändern zu bekommen, was ich brauche?

Vielen Dank im Voraus.

+0

'fn' Argument' scan' arbeiten, um eine Liste zurückgeben muss – Kh40tiK

Antwort

1

Hier ist der Code, der

# input a matrix, expect scan to work with each row of matrix 
my_matrix = np.asarray([[1,2,3],[1,3,2],[1,1,1]]) 

x = T.imatrix() 

def encoding(idx): 
    z = theano.tensor.zeros((idx.shape[0], 4)) 
    one_hot = theano.tensor.set_subtensor(z[theano.tensor.arange(idx.shape[0]), idx], 1) 
    return one_hot 

m, update = theano.scan(fn=encoding, 
         sequences=x) 


f = theano.function([x], m) 

##########3 
result = f(my_matrix) 
print (result) 
Verwandte Themen