2017-11-22 1 views
0

Ich habe ein Netzwerk, das ich auf einigen Datensätzen trainieren möchte (als Beispiel, sagen CIFAR10). Ich kann Data Loader Objekt überUntergruppen eines Pytorch-Datensatzes nehmen

trainset = torchvision.datasets.CIFAR10(root='./data', train=True, 
             download=True, transform=transform) 
trainloader = torch.utils.data.DataLoader(trainset, batch_size=4, 
              shuffle=True, num_workers=2) 

erstellen Meine Frage lautet wie folgt: Angenommen, ich möchte mehrere verschiedene Trainings-Iterationen machen. Nehmen wir an, ich möchte zuerst das Netzwerk auf allen Bildern in ungeraden Positionen, dann auf allen Bildern in geraden Positionen und so weiter trainieren. Um das zu tun, muss ich auf diese Bilder zugreifen können. Leider scheint trainset solchen Zugriff nicht zuzulassen. Das heißt, der Versuch, trainset[:1000] oder allgemeiner trainset[mask] zu tun, wird einen Fehler auslösen.

ich tun könnte, statt

trainset.train_data=trainset.train_data[mask] 
trainset.train_labels=trainset.train_labels[mask] 

und dann

trainloader = torch.utils.data.DataLoader(trainset, batch_size=4, 
               shuffle=True, num_workers=2) 

jedoch, dass mich zwingen, wird eine neue Kopie des vollständigen Datensatzes in jeder Iteration zu erzeugen (wie ich schon trainset.train_data so änderte ich muss neu definiert werden trainset). Gibt es einen Weg, es zu vermeiden?

Im Idealfall würde Ich mag etwas „gleichwertig“

trainloader = torch.utils.data.DataLoader(trainset[mask], batch_size=4, 
               shuffle=True, num_workers=2) 

Antwort

1

Sie haben können einen benutzerdefinierten Sampler für den Datensatz loader definieren Vermeidung des Datensatzes neu zu erstellen (nur einen neuen Lader für jede unterschiedliche Sampling erstellen).

class YourSampler(Sampler): 
    def __init__(self, mask): 
     self.mask = mask 

    def __iter__(self): 
     return (self.indices[i] for i in torch.nonzero(self.mask)) 

    def __len__(self): 
     return len(self.mask) 

trainset = torchvision.datasets.CIFAR10(root='./data', train=True, 
             download=True, transform=transform) 

sampler1 = YourSampler(your_mask) 
sampler2 = YourSampler(your_other_mask) 
trainloader_sampler1 = torch.utils.data.DataLoader(trainset, batch_size=4, 
              sampler = sampler1, shuffle=True, num_workers=2) 
trainloader_sampler2 = torch.utils.data.DataLoader(trainset, batch_size=4, 
              sampler = sampler2, shuffle=True, num_workers=2) 

PS: Ich habe den Code nicht überprüft.

PS2: Sie können hier weitere Informationen finden: http://pytorch.org/docs/master/_modules/torch/utils/data/sampler.html#Sampler

+1

Dank! Eine kleine Bemerkung: Scheinbar Sampler ist nicht kompatibel mit Shuffle, also um das gleiche Ergebnis zu erreichen, kann man tun: fackel.utils.data.DataLoader (trainset, batch_size = 4, sampler = SubsetRandomSampler (np.where (Maske) 0]), shuffle = Falsch, num_workers = 2) –