pytorch怎么加载自己的图片数据集
本文小编为大家详细介绍“pytorch怎么加载自己的图片数据集”,内容详细,步骤清晰,细节处理妥当,希望这篇“pytorch怎么加载自己的图片数据集”文章能帮助大家解决疑惑,下面跟着小编的思路慢慢深入,一起来学习新知识吧。
ImageFolder 适合于分类数据集,并且每一个类别的图片在同一个文件夹, ImageFolder加载的数据集, 训练数据为文件件下的图片, 训练标签是对应的文件夹, 每个文件夹为一个类别
导入ImageFolder()包from torchvision.datasets import ImageFolder
在Flower_Orig_dataset文件夹下有flower_orig 和 sunflower这两个文件夹, 这两个文件夹下放着同一个类别的图片。 使用 ImageFolder 加载的图片, 就会返回图片信息和对应的label信息, 但是label信息是根据文件夹给出的, 如flower_orig就是标签0, sunflower就是标签1。
ImageFolder 加载数据集
导入包和设置transform
import torchfrom torchvision import transforms, datasetsimport torch.nn as nnfrom torch.utils.data import DataLoader transforms = transforms.Compose([ transforms.Resize(256), # 将图片短边缩放至256,长宽比保持不变: transforms.CenterCrop(224), #将图片从中心切剪成3*224*224大小的图片 transforms.ToTensor() #把图片进行归一化,并把数据转换成Tensor类型])
加载数据集: 将分类图片的父目录作为路径传递给ImageFolder(), 并传入transform。这样就有了要加载的数据集, 之后就可以使用DataLoader加载数据, 并构建网络训练。
path = r'D:\数据集\Flower_Orig_dataset'data_train = datasets.ImageFolder(path, transform=transforms)data_loader = DataLoader(data_train, batch_size=64, shuffle=True)for i, data in enumerate(data_loader): images, labels = data print(images.shape) print(labels.shape) break
使用pytorch提供的Dataset类创建自己的数据集。
具体步骤:
首先要有一个txt文件, 这个文件格式是: 图片路径 标签. 这样的格式, 所以使用os库, 遍历自己的图片名, 并把标签和图片路径写入txt文件。
有了这个txt文件, 我们就可以在类里面构造我们的数据集.
1 把图片路径和图片标签分割开, 有两个列表, 一个列表是图片路径名, 一个列表是标签号, 有一点就是第 i 个图片列表和 第 i 个标签是对应的
重写__len__方法 和 __getitem__方法
1 getitem方法中, 获得对应的图片路径,并用PIL库读取文件把图片transfrom后, 在getitem函数中返回读取的图片和标签即可
就可以构建数据集实例和加载数据集.
定义一个用来生成[ 图片路径 标签] 这样的txt文件函数
def make_txt(root, file_name, label): path = os.path.join(root, file_name) data = os.listdir(path) f = open(path+'\\'+'f.txt', 'w') for line in data: f.write(line+' '+str(label)+'\n') f.close()#调用函数生成两个文件夹下的txt文件make_txt(path, file_name='flower_orig', label=0)make_txt(path, file_name='sunflower', label=1)
将连个txt文件合并成一个txt文件,表示数据集所有的图片和标签
def link_txt(file1, file2): txt_list = [] path = r'D:\数据集\Flower_Orig_dataset\data.txt' f = open(path, 'a') f1 = open(file1, 'r') data1 = f1.readlines() for line in data1: txt_list.append(line) f2 = open(file2, 'r') data2 = f2.readlines() for line in data2: txt_list.append(line) for line in txt_list: f.write(line) f.close() f1.close() f2.close() #调用函数, 将两个文件夹下的txt文件合并file1 = r'D:\数据集\Flower_Orig_dataset\flower_orig\f.txt'file2 = r'D:\数据集\Flower_Orig_dataset\sunflower\f.txt'link_txt(file1=file1, file2=file2)
现在我们已经有了我们制作数据集所需要的txt文件, 接下来要做的即使继承Dataset类, 来构建自己的数据集 , 别忘了前面说的 构建数据集步骤, 在__getitem__函数中, 需要拿到图片路径和标签, 并且用PIL库方法读取图片,对图片进行transform转换后,返回图片信息和标签信息
Dataset加载数据集
我们读取图片的根目录, 在根目录下有所有图片的txt文件, 拿到txt文件后, 先读取txt文件, 之后遍历txt文件中的每一行, 首先去除掉尾部的换行符, 在以空格切分,前半部分是图片名称, 后半部分是图片标签, 当图片名称和根目录结合,就得到了我们的图片路径 class MyDataset(Dataset): def __init__(self, img_path, transform=None): super(MyDataset, self).__init__() self.root = img_path self.txt_root = self.root + 'data.txt' f = open(self.txt_root, 'r') data = f.readlines() imgs = [] labels = [] for line in data: line = line.rstrip() word = line.split() imgs.append(os.path.join(self.root, word[1], word[0])) labels.append(word[1]) self.img = imgs self.label = labels self.transform = transform def __len__(self): return len(self.label) def __getitem__(self, item): img = self.img[item] label = self.label[item] img = Image.open(img).convert('RGB') #此时img是PIL.Image类型 label是str类型 if transforms is not None: img = self.transform(img) label = np.array(label).astype(np.int64) label = torch.from_numpy(label) return img, label
加载我们的数据集:
path = r'D:\数据集\Flower_Orig_dataset'dataset = MyDataset(path, transform=transform) data_loader = DataLoader(dataset=dataset, batch_size=64, shuffle=True)
接下来我们就可以构建我们的网络架构:
class Net(nn.Module): def __init__(self): super(Net, self).__init__() self.conv1 = nn.Conv2d(3,16,3) self.maxpool = nn.MaxPool2d(2,2) self.conv2 = nn.Conv2d(16,5,3) self.relu = nn.ReLU() self.fc1 = nn.Linear(55*55*5, 1200) self.fc2 = nn.Linear(1200,64) self.fc3 = nn.Linear(64,2) def forward(self,x): x = self.maxpool(self.relu(self.conv1(x))) #113 x = self.maxpool(self.relu(self.conv2(x))) #55 x = x.view(-1, self.num_flat_features(x)) x = self.relu(self.fc1(x)) x = self.relu(self.fc2(x)) x = self.fc3(x) return x def num_flat_features(self, x): size = x.size()[1:] num_features = 1 for s in size: num_features *= s return num_features
训练我们的网络:
model = Net() criterion = torch.nn.CrossEntropyLoss()optimizer = optim.SGD(model.parameters(), lr=0.01) epochs = 10for epoch in range(epochs): running_loss = 0.0 for i, data in enumerate(data_loader): images, label = data out = model(images) loss = criterion(out, label) optimizer.zero_grad() loss.backward() optimizer.step() running_loss += loss.item() if(i+1)%10 == 0: print('[%d %5d] loss: %.3f'%(epoch+1, i+1, running_loss/100)) running_loss = 0.0 print('finished train')
保存网络模型(这里不止是保存参数,还保存了网络结构)
#保存模型torch.save(net, 'model_name.pth') #保存的是模型, 不止是w和b权重值 # 读取模型model = torch.load('model_name.pth')
读到这里,这篇“pytorch怎么加载自己的图片数据集”文章已经介绍完毕,想要掌握这篇文章的知识点还需要大家自己动手实践使用过才能领会,如果想了解更多相关内容的文章,欢迎关注编程网行业资讯频道。
免责声明:
① 本站未注明“稿件来源”的信息均来自网络整理。其文字、图片和音视频稿件的所属权归原作者所有。本站收集整理出于非商业性的教育和科研之目的,并不意味着本站赞同其观点或证实其内容的真实性。仅作为临时的测试数据,供内部测试之用。本站并未授权任何人以任何方式主动获取本站任何信息。
② 本站未注明“稿件来源”的临时测试数据将在测试完成后最终做删除处理。有问题或投稿请发送至: 邮箱/279061341@qq.com QQ/279061341