2017-12-21 9 views
1

Ich versuche, mehrere torch.utils.data.DataLoader s zu verwenden, um Datensätze zu erstellen, auf die verschiedene Transformationen angewendet werden. Derzeit ist mein Code grobPyTorch DataLoader

d_transforms = [ 
    transforms.RandomHorizontalFlip(), 
    # Some other transforms... 
] 
loaders = [] 
for i in range(len(d_transforms)): 
    dataset = datasets.MNIST('./data', 
      train=train, 
      download=True, 
      transform=d_transforms[i] 
    loaders.append(
     DataLoader(dataset, 
      shuffle=True, 
      pin_memory=True, 
      num_workers=1) 
     ) 

Dies funktioniert, aber es ist extrem langsam. kernprof zeigt, dass fast die ganze Zeit in meinem Code auf den Leitungen wie

x, y = next(iter(train_loaders[i])) 

ausgegeben Ich vermute, dass dies auf die Tatsache zurückzuführen ist, dass ich mehrere Instanzen von DataLoader, jede mit ihren eigenen Arbeitern, bin mit der versucht, um die gleichen Datendateien zu lesen.

Meine Frage ist, was ist ein besserer Weg, dies zu tun? Idealerweise würde ich torch.utils.data.DataSet von der Unterklasse torch.utils.data.DataSet abziehen und die Transformation angeben, die ich anwenden möchte beim Stichprobenverfahren, aber dies scheint nicht möglich zu sein, weil __getitem__ Argumente nicht entgegennehmen kann.

+0

Ja, wenn Sie einen vorschlagen könnten, der großartig wäre. Wie gesagt, ein besserer Weg ist das, wonach ich suche. – Coolness

+0

Ich versuche, mehrere abgeleitete Datensätze von einem zu erstellen und generalisieren über sie. – Coolness

+0

Ich möchte lieber nicht über die Einzelheiten meiner Arbeit sprechen, noch reicht das Kommentarfeld aus, um es vollständig zu erklären. Ich sehe nicht, wie es für die Frage relevant ist. – Coolness

Antwort

0

__getitem__ nimmt ein Argument, das der Index des Inhalts ist, den Sie laden möchten. Für z.

transform = transforms.Compose(
    [transforms.ToTensor(), 
    normalize]) 

class CountDataset(Dataset): 

def __init__(self, file,transform=None): 

    self.transform = transform 
    #self.vocab = vocab 
    with open(file,'rb') as f: 
     self.data = pickle.load(f) 
    self.y = self.data['answers'] 
    self.I = self.data['images'] 


def __len__(self): 
    return len(self.y) 

def __getitem__(self, idx): 
    img_name = self.I[idx] 
    label = self.y[Idx] 
    fname = '/'.join(img_name.split("/")[-2:]) #/train2014/xx.jpg 
    DIR = '/hdd/manoj/VQA/Images/mscoco/' 
    img_full_path = os.path.join(DIR,fname) 
    img = Image.open(img_full_path).convert("RGB") 
    img_tensor = self.transform(img.resize((224,224))) 
    return img_tensor,label 


testset = CountDataset(file = 'testdat.pkl', 
         transform = transform) 


testloader = DataLoader(testset, batch_size=32, 
         shuffle=False, num_workers=4) 

Sie rufen den Datenlader nicht in Schleife auf.