Sie bekommen das endgültige dreidimensionale Ergebnis E
ohne das Erzeugen des großen Zwischenarrays unter Verwendung von batched_dot
:
import theano.tensor as tt
A = tt.tensor3('A') # A.shape = (D, N, H)
B = tt.tensor3('B') # B.shape = (D, H, K)
E = tt.batched_dot(A, B) # E.shape = (D, N, K)
Leider erfordert dies, dass Sie die Dimensionen auf Ihren Eingabe- und Ausgabearrays permutieren. Obwohl dies kann mit dimshuffle
in Theano geschehen scheint es batched_dot
nicht willkürlich strided Arrays bewältigen können und so stellt sich die folgenden ein ValueError: Some matrix has no unit stride
wenn E
ausgewertet:
import theano.tensor as tt
A = tt.tensor3('A') # A.shape = (N, H, D)
B = tt.tensor3('B') # B.shape = (K, H, D)
A_perm = A.dimshuffle((2, 0, 1)) # A_perm.shape = (D, N, H)
B_perm = B.dimshuffle((2, 1, 0)) # B_perm.shape = (D, H, K)
E_perm = tt.batched_dot(A_perm, B_perm) # E_perm.shape = (D, N, K)
E = E_perm.dimshuffle((1, 2, 0)) # E.shape = (N, K, D)
batched_dot
verwendet scan
entlang der ersten (Größe D
) Dimension . Da scan
sequentiell durchgeführt wird, könnte dies rechnerisch weniger effizient sein als die parallele Berechnung aller Produkte, wenn diese auf einer GPU laufen.
Sie können zwischen der Speichereffizienz des batched_dot
Ansatzes und der Parallelität im Broadcast-Ansatz unter Verwendung von scan
explizit abwägen. Idee würde das vollständige Produkt C
für Chargen von Größe M
parallel zu berechnen sein (unter der Annahme, M
ist ein exakter Faktor D
), iterieren Chargen mit scan
:
import theano as th
import theano.tensor as tt
A = tt.tensor3('A') # A.shape = (N, H, D)
B = tt.tensor3('B') # B.shape = (K, H, D)
A_batched = A.reshape((N, H, M, D/M))
B_batched = B.reshape((K, H, M, D/M))
E_batched, _ = th.scan(
lambda a, b: (a[:, :, None, :] * b[:, :, :, None]).sum(1),
sequences=[A_batched.T, B_batched.T]
)
E = E_batched.reshape((D, K, N)).T # E.shape = (N, K, D)
Welcher Dimension wollen Sie über zusammenzuzufassen? Die erste, 0? oder 'H', das ist das zweitletzte in den ursprünglichen Arrays? – hpaulj
In 'numpy' könnte dies ausgedrückt werden als 'np.einsum (' nhd, khd-> nkd ', A, B)' – hpaulj
Ich würde es gerne über H machen. Das sollte Summe sein (1) unter der Annahme, dass der Tensor hat vor der Addition die Form (1, H, D). – Theo