2017-06-07 4 views
0

Ich bin neu bei PyTorch, probieren Sie es aus, nachdem Sie für eine Weile ein anderes Toolkit verwendet haben.pytorch benutzerdefinierte Ebene "ist keine Modul-Unterklasse"

Ich möchte verstehen, wie benutzerdefinierte Layer und Funktionen programmieren. Und als einfacher Test, schrieb ich folgendes:

class Testme(nn.Module):   ## it _is_ a sublcass of module ## 
    def __init__(self): 
     super(Testme, self).__init__() 

    def forward(self, x): 
     return x/t_.max(x) 

die dazu bestimmt ist, die Daten zu veranlassen, durchquert auf 1 zu summieren nicht wirklich nützlich, nur bei Test.

Dann steck ich es zum Beispiel Code aus dem PyTorch Spielplatz:

def make_layers(cfg, batch_norm=False): 
    layers = [] 
    in_channels = 3 
    for i, v in enumerate(cfg): 
     if v == 'M': 
      layers += [nn.MaxPool2d(kernel_size=2, stride=2)] 
     else: 
      padding = v[1] if isinstance(v, tuple) else 1 
      out_channels = v[0] if isinstance(v, tuple) else v 
      conv2d = nn.Conv2d(in_channels, out_channels, kernel_size=3, padding=padding) 
      if batch_norm: 
       layers += [conv2d, nn.BatchNorm2d(out_channels, affine=False), nn.ReLU()] 
      else: 
       layers += [conv2d, nn.ReLU()] 
      layers += [Testme]       # here <------------------ 
      in_channels = out_channels 
    return nn.Sequential(*layers) 

Das Ergebnis ist ein Fehler!

TypeError: model.Testme is not a Module subclass 

Vielleicht muss das eine Funktion und kein Modul sein? Auch nicht klar, was der Unterschied zwischen Funktion, Modul ist.

Zum Beispiel, warum benötigt eine Funktion eine backward(), auch wenn sie vollständig aus Standard-Primitive pytorch konstruiert ist, während ein Modul dies nicht benötigt?

Antwort

0

Das ist ein einfacher. Du hast es fast geschafft, aber du hast vergessen, eine Instanz deiner neuen Klasse Testme zu erstellen. Sie müssen dies tun, auch wenn die Erstellung einer Instanz einer bestimmten Klasse keine Parameter annimmt (wie bei Testme). Aber es ist leichter zu vergessen als für eine Faltungsschicht, auf die Sie typischerweise viele Argumente anwenden.

Ändern Sie die angegebene Zeile wie folgt und Ihr Problem ist behoben.

layers += [Testme()] 
Verwandte Themen