Model Training: Loading Your Own Dataset from a train.txt of "path,class" Lines

Today I wanted to train a ResNet. I normally load datasets with the folder-per-class approach, but the data my senior labmate gave me splits the training and validation sets with .txt files. I knew I had to pull the paths and labels out of them, struggled with it for quite a while, and eventually found a way, which I'm writing up here.

My data is organized as one folder per class, and the image paths together with their class names are stored in a .txt file.
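
Each line of the txt file is simply an image path and its class name separated by a comma. As a made-up illustration of the format (the folder layout and class names below are hypothetical):

data/麻雀/0001.jpg,麻雀
data/麻雀/0002.jpg,麻雀
data/喜鹊/0001.jpg,喜鹊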


First, deal with the Chinese class names: put them into a separate txt file, one class per line, assign each one a numeric index in order, and then map between the two.

# Read the class file and build the list of class names
with open(class_txt_path, 'r', encoding='utf-8') as f:
    classes = [line.strip() for line in f.readlines()]

# Build the mapping from class name to numeric index
class_to_idx = {class_name: idx for idx, class_name in enumerate(classes)}
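
For instance, with a hypothetical classes.txt containing one class name per line, in the order you want the indices assigned, the dictionary comprehension above maps each name to its line index. The class names here are made up; use whatever is in your own file.

# Assuming a hypothetical classes.txt containing (one per line): 麻雀, 喜鹊, 布谷鸟
# the mapping built above would come out as:
print(class_to_idx)
# {'麻雀': 0, '喜鹊': 1, '布谷鸟': 2}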

To build the DataLoader we only need the image paths and labels, so we subclass Dataset and implement the loading ourselves: pass in the txt path and it produces all the images and labels.

class MyDataset(Dataset):
    def __init__(self, text_path, transform=None):
        super(MyDataset, self).__init__()
        self.root = text_path
        with open(self.root, 'r', encoding='utf-8') as f:
            data = f.readlines()

        imgs = []
        labels = []

        for line in data:
            img_path, class_name = line.strip().split(',')
            # Collect the image path
            imgs.append(os.path.join("./", img_path))
            # Convert the Chinese class name to its numeric index
            labels.append(class_to_idx[class_name])
        self.img = imgs
        self.label = labels
        self.transform = transform

    def __len__(self):
        """
        返回数据集的长度
        """
        return len(self.label)

    def __getitem__(self, item):
        """
        根据索引获取数据集中的样本
        """
        img = self.img[item]
        label = self.label[item]
        img = Image.open(img).convert('RGB')

        if self.transform is not None:
            img = self.transform(img)

        return img, label

Then we can build the DataLoader and start training.

train_dataset = MyDataset(train_txt_path, transform=data_transform["train"])
# Create the training data loader
train_loader = torch.utils.data.DataLoader(train_dataset,
                                            batch_size=batch_size, shuffle=True,
                                            num_workers=nw)
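
Before launching the training loop it is worth pulling one batch to make sure the whole pipeline works end to end; a quick sanity check (just a sketch, using the loader defined above) could be:

# Optional sanity check: fetch one batch and inspect the shapes
images, labels = next(iter(train_loader))
print(images.shape)   # expected: torch.Size([batch_size, 3, 224, 224]) after the train transform
print(labels.shape)   # expected: torch.Size([batch_size]); labels are the numeric class indices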

Full training code

import os
import sys

import torch
import torch.nn as nn
import torch.optim as optim
from torchvision import transforms
from tqdm import tqdm

from model import resnet34
from PIL import Image
from torch.utils.data import Dataset

# Dataset paths
train_txt_path = "./train.txt"
val_txt_path = "./val.txt"
image_path = "./"  # root directory of the images
class_txt_path = "./classes.txt"

# Read the class file and build the list of class names
with open(class_txt_path, 'r', encoding='utf-8') as f:
    classes = [line.strip() for line in f.readlines()]

# Build the mapping from class name to numeric index
class_to_idx = {class_name: idx for idx, class_name in enumerate(classes)}

# Custom dataset: reads "path,class" lines from a txt file
class MyDataset(Dataset):
    def __init__(self, text_path, transform=None):
        super(MyDataset, self).__init__()
        self.root = text_path
        with open(self.root, 'r', encoding='utf-8') as f:
            data = f.readlines()

        imgs = []
        labels = []

        for line in data:
            img_path, class_name = line.strip().split(',')
            # Collect the image path
            imgs.append(os.path.join("./", img_path))
            # Convert the Chinese class name to its numeric index
            labels.append(class_to_idx[class_name])
        self.img = imgs
        self.label = labels
        self.transform = transform

    def __len__(self):
        """
        返回数据集的长度
        """
        return len(self.label)

    def __getitem__(self, item):
        """
        根据索引获取数据集中的样本
        """
        img = self.img[item]
        label = self.label[item]
        img = Image.open(img).convert('RGB')

        if self.transform is not None:
            img = self.transform(img)

        return img, label

def main():
    device = torch.device("cuda:0" if torch.cuda.is_available() else "cpu")
    print("using {} device.".format(device))

    data_transform = {
        "train": transforms.Compose([transforms.RandomResizedCrop(224),
                                     transforms.RandomHorizontalFlip(),
                                     transforms.ToTensor(),
                                     transforms.Normalize([0.485, 0.456, 0.406], [0.229, 0.224, 0.225])]),
        "val": transforms.Compose([transforms.Resize(256),
                                   transforms.CenterCrop(224),
                                   transforms.ToTensor(),
                                   transforms.Normalize([0.485, 0.456, 0.406], [0.229, 0.224, 0.225])])}

    batch_size = 128
    nw = min([os.cpu_count(), batch_size if batch_size > 1 else 0, 8])  # number of workers
    print('Using {} dataloader workers per process'.format(nw))

    train_dataset = MyDataset(train_txt_path, transform=data_transform["train"])
    val_dataset = MyDataset(val_txt_path, transform=data_transform["val"])

    # Training data loader
    train_loader = torch.utils.data.DataLoader(train_dataset,
                                            batch_size=batch_size, shuffle=True,
                                            num_workers=nw)

    # Validation data loader
    validate_loader = torch.utils.data.DataLoader(val_dataset,
                                                batch_size=batch_size, shuffle=False,
                                                num_workers=nw)

    net = resnet34()
    # load pretrain weights
    # download url: https://download.pytorch.org/models/resnet34-333f7ec4.pth
    model_weight_path = "./resnet34-pre.pth"
    assert os.path.exists(model_weight_path), "file {} does not exist.".format(model_weight_path)
    net.load_state_dict(torch.load(model_weight_path, map_location='cpu'))
    # for param in net.parameters():
    #     param.requires_grad = False

    # change fc layer structure
    in_channel = net.fc.in_features
    net.fc = nn.Linear(in_channel, 14)  # 14 output classes here; should match len(classes)
    net.to(device)

    # define loss function
    loss_function = nn.CrossEntropyLoss()

    # construct an optimizer
    params = [p for p in net.parameters() if p.requires_grad]
    optimizer = optim.Adam(params, lr=0.0001)

    epochs = 100
    best_acc = 0.0
    save_path = './resNet34-bird.pth'
    train_steps = len(train_loader)
    for epoch in range(epochs):
        # train
        net.train()
        running_loss = 0.0
        train_bar = tqdm(train_loader, file=sys.stdout)
        for step, data in enumerate(train_bar):
            images, labels = data
            optimizer.zero_grad()
            logits = net(images.to(device))
            loss = loss_function(logits, labels.to(device))
            loss.backward()
            optimizer.step()

            # print statistics
            running_loss += loss.item()

            train_bar.desc = "train epoch[{}/{}] loss:{:.3f}".format(epoch + 1,
                                                                     epochs,
                                                                     loss)

        # validate
        net.eval()
        acc = 0.0  # accumulate accurate number / epoch
        with torch.no_grad():
            val_bar = tqdm(validate_loader, file=sys.stdout)
            for val_data in val_bar:
                val_images, val_labels = val_data
                outputs = net(val_images.to(device))
                # loss = loss_function(outputs, test_labels)
                predict_y = torch.max(outputs, dim=1)[1]
                acc += torch.eq(predict_y, val_labels.to(device)).sum().item()

                val_bar.desc = "valid epoch[{}/{}]".format(epoch + 1,
                                                           epochs)

        val_num = len(validate_loader.dataset)
        val_accurate = acc / val_num
        print('[epoch %d] train_loss: %.3f  val_accuracy: %.3f' %
            (epoch + 1, running_loss / train_steps, val_accurate))

        if val_accurate > best_acc:
            best_acc = val_accurate
            torch.save(net.state_dict(), save_path)

    print('Finished Training')

if __name__ == '__main__':
    main()
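
After training, the best weights saved to resNet34-bird.pth can be loaded back for single-image prediction. The sketch below is one way to do it under the same setup as the training script (the image path is hypothetical, and it reuses the classes list, resnet34 definition, and validation preprocessing from above):

def predict(image_path, weight_path='./resNet34-bird.pth'):
    device = torch.device("cuda:0" if torch.cuda.is_available() else "cpu")

    # Rebuild the network exactly as in training: resnet34 with a 14-class fc layer
    model = resnet34()
    model.fc = nn.Linear(model.fc.in_features, 14)
    model.load_state_dict(torch.load(weight_path, map_location=device))
    model.to(device)
    model.eval()

    # Same preprocessing as the validation transform used during training
    transform = transforms.Compose([transforms.Resize(256),
                                    transforms.CenterCrop(224),
                                    transforms.ToTensor(),
                                    transforms.Normalize([0.485, 0.456, 0.406],
                                                         [0.229, 0.224, 0.225])])
    img = transform(Image.open(image_path).convert('RGB')).unsqueeze(0).to(device)

    with torch.no_grad():
        idx = torch.argmax(model(img), dim=1).item()
    # Map the numeric index back to the class name read from classes.txt
    return classes[idx]

# Example (hypothetical path):
# print(predict("./some_image.jpg"))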