TransUnet训练自己的数据集
github的源代码https://github.com/Beckschen/TransUNet
一.先对自己的数据集进行处理
1.原始数据集如下图所示分images和labels,若数据集是png/jpg....格式,首先需要将每一张图的image和其对应的label合并转化为一个.npz文件.
# 自建的将图片及对应标签合并为一个npz格式数据的转换代码import cv2import globimport numpy as npdef npz(): # 图像路径 path = r'D:/train/images/*.png' # 项目中存放训练所用的npz文件路径 path2 = r'D:/data/Synapse/train_npz//' for i, img_path in enumerate(glob.glob(path)): # 读入图像 image = cv2.imread(img_path, flags=0) image = cv2.cvtColor(image, cv2.COLOR_BGR2RGB) # 读入标签 label_path = img_path.replace('images', 'labels') label = cv2.imread(label_path, flags=0) #flag=0时为灰度图像 #label = cv2.imread(label_path) #现在为彩色图像 #label = cv2.cvtColor(label, cv2.COLOR_BGR2RGB) # 保存npz np.savez(path2 + str(i), image=image, label=label) print('', i)if __name__ == '__main__': npz()
train文件转化为 train_npz文件;val转化为test_vol_h5文件,如图(注意这里训练集与测试集均转化为.npz格式,需修改读取文件的方法,具体是修改datasets/dataset_synapse.py文件中的Synapse_dataset类,修改__getitem__函数和datasets/dataset_synapse.py文件中的RandomGenerator类,修改__call__函数)
npz文件生成完成之后,找到train.txt和test_vol.txt,清空文件夹中的内容,按照原内容的格式对应train_npz文件与test_vol_h5文件中.npz文件的全部文件名分别写入train.txt/test_vol.txt文件,一个名称一行(参考原内容的格式)。
2.若数据集为.npz格式,则直接加载
二.开始训练
1.修改train.py(根据自己需要修改红线部分)
2.根据自己需要,看是否加载预训练权重
通过修改train.py中 args.is_pretrain = True # False 与 net.load_from(weights=np.load(config_vit.pretrained_path)) 实现
若要加载预训练权重则将权重文件放在model中
则可以开始训练,运行train.py
三.训练中我出现花费较长时间的报错,以及解决方法
在生成train_npz文件与test_vol_h5文件时
报错:libpng warning: iCCP: known incorrect sRGB profile
解决:通过别的方法读取再保存即可解决。
path = r"D:/Desktop/car/all/train/labels/" #为文件的路径fileList = os.listdir(path)for i in tqdm(fileList): image = io.imread(path+i) # image = io.imread(os.path.join(path, i)) image = cv2.cvtColor(image, cv2.COLOR_RGBA2BGRA) cv2.imencode('.png',image)[1].tofile(path+i)
训练过程中
报错:not enough memory:you tried to allocate
因为我用的cpu,应该是内存不够了
解决方法:将batchsize减少,之前我设置的是4,改为2后就可以正常运行了
报错:FileNotFoundError: [Errno 2] No such file or directory
解决方法:
1)检查是否存在此文件。存在没问题
2)检查数据集路径,为绝对路径。路径正确
3) 检查加载数据集方式,正确
4)最后检查出来是train.txt中的训练样本的名称前多打了一个空格(粗心造成浪费了快一天时间),删除空格后正常运行。
来源地址:https://blog.csdn.net/m0_64894570/article/details/128820676
免责声明:
① 本站未注明“稿件来源”的信息均来自网络整理。其文字、图片和音视频稿件的所属权归原作者所有。本站收集整理出于非商业性的教育和科研之目的,并不意味着本站赞同其观点或证实其内容的真实性。仅作为临时的测试数据,供内部测试之用。本站并未授权任何人以任何方式主动获取本站任何信息。
② 本站未注明“稿件来源”的临时测试数据将在测试完成后最终做删除处理。有问题或投稿请发送至: 邮箱/279061341@qq.com QQ/279061341