PyTorch: Tests mit torchvision.datasets.ImageFolder und DataLoader
Ich bin ein Neuling versuchen, diese PyTorch CNN Arbeit mit der Katzen&Hunde-dataset von kaggle. Da es keine Ziele für die test-Bilder, die ich manuell klassifiziert einige der test-Bilder und legen Sie die Klasse im Dateinamen, in der Lage sein, um zu testen (vielleicht sollte nur verwendet einige der Zug-Bilder).
Benutzte ich die torchvision.datasets.ImageFolder-Klasse zu laden, die Zug-und test-Bilder. Das training scheint zu funktionieren.
Aber was muss ich tun, um die test-routine-arbeiten? Ich weiß nicht, wie meine Verbindung test_data_loader mit der test-Schleife an der Unterseite, über test_x und test_y.
Den Code basiert auf diese MNIST Beispiel CNN. Es, so etwas wie dieses verwendet wird, direkt nach dem Lader erstellt werden. Aber ich konnte nicht umschreiben, dass es für mein dataset:
test_x = Variable(torch.unsqueeze(test_data.test_data, dim=1), volatile=True).type(torch.FloatTensor)[:2000]/255. # shape from (2000, 28, 28) to (2000, 1, 28, 28), value in range(0,1)
test_y = test_data.test_labels[:2000]
Code:
import os
import numpy as np
import torch
import torch.nn as nn
import torch.nn.functional as F
from torch.autograd import Variable
import torch.utils.data as data
import torchvision
from torchvision import transforms
EPOCHS = 2
BATCH_SIZE = 10
LEARNING_RATE = 0.003
TRAIN_DATA_PATH = "./train_cl/"
TEST_DATA_PATH = "./test_named_cl/"
TRANSFORM_IMG = transforms.Compose([
transforms.Resize(256),
transforms.CenterCrop(256),
transforms.ToTensor(),
transforms.Normalize(mean=[0.485, 0.456, 0.406],
std=[0.229, 0.224, 0.225] )
])
train_data = torchvision.datasets.ImageFolder(root=TRAIN_DATA_PATH, transform=TRANSFORM_IMG)
train_data_loader = data.DataLoader(train_data, batch_size=BATCH_SIZE, shuffle=True, num_workers=4)
test_data = torchvision.datasets.ImageFolder(root=TEST_DATA_PATH, transform=TRANSFORM_IMG)
test_data_loader = data.DataLoader(test_data, batch_size=BATCH_SIZE, shuffle=True, num_workers=4)
class CNN(nn.Module):
# omitted...
if __name__ == '__main__':
print("Number of train samples: ", len(train_data))
print("Number of test samples: ", len(test_data))
print("Detected Classes are: ", train_data.class_to_idx) # classes are detected by folder structure
model = CNN()
optimizer = torch.optim.Adam(model.parameters(), lr=LEARNING_RATE)
loss_func = nn.CrossEntropyLoss()
# Training and Testing
for epoch in range(EPOCHS):
for step, (x, y) in enumerate(train_data_loader):
b_x = Variable(x) # batch x (image)
b_y = Variable(y) # batch y (target)
output = model(b_x)[0]
loss = loss_func(output, b_y)
optimizer.zero_grad()
loss.backward()
optimizer.step()
# Test -> this is where I have no clue
if step % 50 == 0:
test_x = Variable(test_data_loader)
test_output, last_layer = model(test_x)
pred_y = torch.max(test_output, 1)[1].data.squeeze()
accuracy = sum(pred_y == test_y) / float(test_y.size(0))
print('Epoch: ', epoch, '| train loss: %.4f' % loss.data[0], '| test accuracy: %.2f' % accuracy)
InformationsquelleAutor kett | 2018-03-02
Du musst angemeldet sein, um einen Kommentar abzugeben.
Blick auf die Daten von Kaggle und der code, es scheint, dass es Probleme in Ihrem laden von Daten, sowohl für Zug-und test-set. Zunächst sollten die Daten werden in einen anderen Ordner pro Etikett für die Standard-PyTorch
ImageFolder
zu laden richtig. In Ihrem Fall, da alle Trainingsdaten in den gleichen Ordner, PyTorch ist, laden Sie es als eine Klasse und somit das lernen scheint zu funktionieren. Sie können dies korrigieren, indem Sie eine Ordner-Struktur wie -train/dog
, -train/cat
, -test/dog
, -test/cat
und dann vorbei am Bahnhof und den test-Ordner zum trainieren und testenImageFolder
beziehungsweise. Die Ausbildung der code scheint in Ordnung, ändern Sie einfach die Ordner-Struktur und Sie sollte gut sein. Werfen Sie einen Blick auf die offizielle Dokumentation von ImageFolder, die hat ein ähnliches Beispiel.Ich glaube nicht, dass es irgendwelche direkten Weg zu tun, noch (das weiß ich). Aber Sie können erweitern die dataset-Klasse und das tun. Hier ist ein Beispiel aus einer offiziellen PyTorch tutorial. Sie können es ändern, und entfernen Sie das zusätzliche Zeug und es sollte funktionieren. Lassen Sie mich wissen, ob das funktioniert für Sie.
InformationsquelleAutor Monster
Als pro @Monster ' s Kommentar oben, hier ist meine Ordner-Struktur für ImageFolder
- Und dies, wie lade ich das dataset:
InformationsquelleAutor Mo-Gang