2017-02-10 2 views
1

TL; DR
Ich möchte die Funktionalität von numpy.matmul in theano replizieren. Was ist der beste Weg, dies zu tun?numpy.matmul in Theano

Zu kurz; Ich habe nicht verstanden
Betrachtet man theano.tensor.dot und theano.tensor.tensordot, sehe ich keinen einfachen Weg, um eine einfache Batch-Matrix-Multiplikation zu tun. h., behandeln die letzten zwei Dimensionen von N-dimensionalen Tensoren als Matrizen und multiplizieren sie. Muss ich auf etwas doof Verwendung von theano.tensor.batched_dot zurückgreifen? Oder * schauder * schleife sie selbst ohne Sendung !?

+0

Sieht so aus, als ob sie [noch daran arbeiten] (https://github.com/Theano/Theano/pull/3769). – user2357112

+0

Richtig, ich sah eine Reihe von Pull-Anfragen. Aber das scheint ein ziemlich universelles mathematisches Bedürfnis zu sein? Ich frage mich nur, was die aktuelle "Best Practice" ist; Ich bin mir sicher, dass es jemand tut. – azane

+0

Für jetzt habe ich gerade die 'Matmul'-Funktion aus diesem [PR] (https://github.com/Theano/Theano/pull/3769/files#diff-73defb19c53e8c96044c9d15c8a9d064) geschnappt. – azane

Antwort

1

Die aktuellen Pull-Requests unterstützen kein Broadcasting, daher habe ich das vorerst gemacht. Ich kann es aufräumen, ein wenig mehr Funktionalität hinzufügen und meine eigene PR als temporäre Lösung einreichen. Bis dahin hoffe ich, dass dies jemandem hilft! Ich fügte den Test hinzu, um zu zeigen, dass er numpy.matmul repliziert, vorausgesetzt, dass die Eingabe meinen strengeren (temporären) Behauptungen entspricht.

Auch, .Scan stoppt die Iteration der Sequenzen bei argmin(*sequencelengths) Iterationen. Also glaube ich, dass nicht übereinstimmende Array Shapes keine Ausnahmen auslösen werden.

import theano as th 
import theano.tensor as tt 
import numpy as np 


def matmul(a: tt.TensorType, b: tt.TensorType, _left=False): 
    """Replicates the functionality of numpy.matmul, except that 
    the two tensors must have the same number of dimensions, and their ndim must exceed 1.""" 

    # TODO ensure that broadcastability is maintained if both a and b are broadcastable on a dim. 

    assert a.ndim == b.ndim # TODO support broadcasting for differing ndims. 
    ndim = a.ndim 
    assert ndim >= 2 

    # If we should left multiply, just swap references. 
    if _left: 
     tmp = a 
     a = b 
     b = tmp 

    # If a and b are 2 dimensional, compute their matrix product. 
    if ndim == 2: 
     return tt.dot(a, b) 
    # If they are larger... 
    else: 
     # If a is broadcastable but b is not. 
     if a.broadcastable[0] and not b.broadcastable[0]: 
      # Scan b, but hold a steady. 
      # Because b will be passed in as a, we need to left multiply to maintain 
      # matrix orientation. 
      output, _ = th.scan(matmul, sequences=[b], non_sequences=[a[0], 1]) 
     # If b is broadcastable but a is not. 
     elif b.broadcastable[0] and not a.broadcastable[0]: 
      # Scan a, but hold b steady. 
      output, _ = th.scan(matmul, sequences=[a], non_sequences=[b[0]]) 
     # If neither dimension is broadcastable or they both are. 
     else: 
      # Scan through the sequences, assuming the shape for this dimension is equal. 
      output, _ = th.scan(matmul, sequences=[a, b]) 
     return output 


def matmul_test() -> bool: 
    vlist = [] 
    flist = [] 
    ndlist = [] 
    for i in range(2, 30): 
     dims = int(np.random.random() * 4 + 2) 

     # Create a tuple of tensors with potentially different broadcastability. 
     vs = tuple(
      tt.TensorVariable(
       tt.TensorType('float64', 
           tuple((p < .3) for p in np.random.ranf(dims-2)) 
           # Make full matrices 
           + (False, False) 
       ) 
      ) 
      for _ in range(2) 
     ) 
     vs = tuple(tt.swapaxes(v, -2, -1) if j % 2 == 0 else v for j, v in enumerate(vs)) 

     f = th.function([*vs], [matmul(*vs)]) 

     # Create the default shape for the test ndarrays 
     defshape = tuple(int(np.random.random() * 5 + 1) for _ in range(dims)) 
     # Create a test array matching the broadcastability of each v, for each v. 
     nds = tuple(
      np.random.ranf(
       tuple(s if not v.broadcastable[j] else 1 for j, s in enumerate(defshape)) 
      ) 
      for v in vs 
     ) 
     nds = tuple(np.swapaxes(nd, -2, -1) if j % 2 == 0 else nd for j, nd in enumerate(nds)) 

     ndlist.append(nds) 
     vlist.append(vs) 
     flist.append(f) 

    for i in range(len(ndlist)): 
     assert np.allclose(flist[i](*ndlist[i]), np.matmul(*ndlist[i])) 

    return True 


if __name__ == "__main__": 
    print("matmul_test -> " + str(matmul_test())) 
+0

Auch die zufällige Natur des Tests ist wahrscheinlich nicht ideal. ;) Aber ich war nicht in der Stimmung, Array-Shapes manuell zu hacken. – azane