使用Salience-DETR训练自己数据集

使用Salience-DETR训练自己数据集出现的问题

修改参数

注:自用,仅记录第一次训练的时候出现的问题

在main.py中找到配置文件在configs/train_config.py

在train_config.py中可以修改各种参数,轮数,bs数,num_worker数,以及修改数据集路径

下方还有要使用的模型路径,自己修改

# Commonly changed training configurations
num_epochs = 12   # train epochs
batch_size = 2    # total_batch_size = #GPU x batch_size
num_workers = 4   # workers for pytorch DataLoader
pin_memory = True # whether pin_memory for pytorch DataLoader
print_freq = 50   # frequency to print logs
starting_epoch = 0
max_norm = 0.1    # clip gradient norm

output_dir = None  # path to save checkpoints, default for None: checkpoints/{model_name}
find_unused_parameters = False  # useful for debugging distributed training

# define dataset for train
coco_path = "data/coco"  # /PATH/TO/YOUR/COCODIR
train_transform = presets.detr  # see transforms/presets to choose a transform
train_dataset = CocoDetection(
    img_folder=f"{coco_path}/train2017",
    ann_file=f"{coco_path}/annotations/instances_train2017.json",
    transforms=train_transform,
    train=True,
)
test_dataset = CocoDetection(
    img_folder=f"{coco_path}/val2017",
    ann_file=f"{coco_path}/annotations/instances_val2017.json",
    transforms=None,  # the eval_transform is integrated in the model
)

# model config to train
model_path = "configs/salience_detr/salience_detr_resnet50_800_1333.py"

如果你的数据集的种类数大于coco数据集的种类数

记得在model_path中的文件修改一下num_classes的数量,要大于等于你的数据集的种类数+1,懒得话可以直接设置成365,也是适配coco的

num_classes = 103

预训练权重

如果不想使用预训练权重,代码中没有明确关于预训练权重的东西,

首先要修改模型文件中的

backbone = ResNetBackbone(
    "resnet50", pretrained=False, norm_layer=FrozenBatchNorm2d, return_indices=(1, 2, 3), freeze_indices=(0,)
)

添加一个flase

然后在backbone文件resnet.py中__--new--__中修改

    def __new__(
        self,
        arch: str,
        pretrained: bool = False,	#新增
        weights: Dict = None,
        return_indices: Tuple[int] = (0, 1, 2, 3),
        freeze_indices: Tuple = (),
        **kwargs,
    ):
        # get parameters and instantiate backbone
        model_config = self.get_instantiate_config(self, ResNet, arch, kwargs)
        default_weight = model_config.pop("url", None)
        resnet = instantiate(model_config)

        # load state dict
        # 修改
        if pretrained:
            weights = load_checkpoint(default_weight if weights is None else weights)
            if isinstance(weights, Dict):
                weights = weights["model"] if "model" in weights else weights
            self.load_state_dict(resnet, weights)
        else:
            print("Skip loading pretrained")

就是添加一个参数,看你自己使用不使用预训练权重

大概设置就是这些

posted @ 2025-04-11 12:13  anti1hapi  阅读(96)  评论(0)    收藏  举报