From ca6978d953a84de48779573dd61f12f225d7b083 Mon Sep 17 00:00:00 2001 From: Tianrui Guan <502112826@qq.com> Date: Fri, 1 Dec 2023 23:18:39 -0500 Subject: [PATCH] update cwt training --- configs/_base_/datasets/cwt_group8.py | 64 ++++++++++++++++ .../_base_/models/ours_class_att_group8.py | 75 +++++++++++++++++++ configs/ours/ganav_group8_cwt.py | 29 +++++++ mmseg/datasets/__init__.py | 3 +- mmseg/datasets/cwt.py | 25 +++++++ 5 files changed, 195 insertions(+), 1 deletion(-) create mode 100644 configs/_base_/datasets/cwt_group8.py create mode 100644 configs/_base_/models/ours_class_att_group8.py create mode 100644 configs/ours/ganav_group8_cwt.py create mode 100644 mmseg/datasets/cwt.py diff --git a/configs/_base_/datasets/cwt_group8.py b/configs/_base_/datasets/cwt_group8.py new file mode 100644 index 0000000..f9f1d6d --- /dev/null +++ b/configs/_base_/datasets/cwt_group8.py @@ -0,0 +1,64 @@ +# dataset settings +img_size = (1920, 1080) +crop_size = (375, 600) + +dataset_type = 'CWT_Dataset' +data_root = 'data/CWT/' +# img_norm_cfg = dict( +# mean=[123.675, 116.28, 103.53], std=[58.395, 57.12, 57.375], to_rgb=True) +img_norm_cfg = dict( + mean=[127.38, 127.96, 128.21], std=[53.941, 54.258, 54.389], to_rgb=True) + +train_pipeline = [ + dict(type='LoadImageFromFile'), + dict(type='LoadAnnotations'), + dict(type='Resize', img_scale=img_size, ratio_range=(0.5, 2.0)), + dict(type='RandomCrop', crop_size=crop_size), + dict(type='RandomFlip', prob=0.5), + dict(type='PhotoMetricDistortion'), + dict(type='Normalize', **img_norm_cfg), + dict(type='Pad', size=crop_size, pad_val=0, seg_pad_val=255), + dict(type='DefaultFormatBundle'), + dict(type='Collect', keys=['img', 'gt_semantic_seg']), +] +test_pipeline = [ + dict(type='LoadImageFromFile'), + dict( + type='MultiScaleFlipAug', + img_scale=crop_size, + # img_ratios=[0.5, 0.75, 1.0, 1.25, 1.5, 1.75], + # img_ratios=[0.5, 0.75, 1.0, 1.25, 1.5, 1.75], + flip=False, + transforms=[ + dict(type='Resize', keep_ratio=True), + dict(type='RandomFlip'), + dict(type='Normalize', **img_norm_cfg), + dict(type='Pad', size=crop_size, pad_val=0, seg_pad_val=255), + dict(type='ImageToTensor', keys=['img']), + dict(type='Collect', keys=['img']), + ]) +] +data = dict( + samples_per_gpu=6, + workers_per_gpu=6, + train=dict( + type=dataset_type, + data_root=data_root, + img_dir='img', + ann_dir='annotation/grey_mask', + split='train.txt', + pipeline=train_pipeline), + val=dict( + type=dataset_type, + data_root=data_root, + img_dir='img', + ann_dir='annotation/grey_mask', + split='test.txt', + pipeline=test_pipeline), + test=dict( + type=dataset_type, + data_root=data_root, + img_dir='img', + ann_dir='annotation/grey_mask', + split='test.txt', + pipeline=test_pipeline)) diff --git a/configs/_base_/models/ours_class_att_group8.py b/configs/_base_/models/ours_class_att_group8.py new file mode 100644 index 0000000..37ecb01 --- /dev/null +++ b/configs/_base_/models/ours_class_att_group8.py @@ -0,0 +1,75 @@ +# model settings +img_size = (1920, 1080) +crop_size = (375, 600) + +norm_cfg = dict(type='SyncBN', requires_grad=True) +model = dict( + type='EncoderDecoder', + backbone=dict( + type='MixVisionTransformer', + in_channels=3, + embed_dims=32, + num_stages=4, + num_layers=[2, 2, 2, 2], + num_heads=[1, 2, 5, 8], + patch_sizes=[7, 3, 3, 3], + sr_ratios=[8, 4, 2, 1], + out_indices=(0, 1, 2, 3), + mlp_ratio=4, + qkv_bias=True, + drop_rate=0.0, + attn_drop_rate=0.0, + drop_path_rate=0.1), + decode_head=dict( + type='OursHeadClassAttNew', + in_channels=[32, 64, 160, 256], + in_index=[0, 1, 2, 3], + channels=384, + mask_size=(97, 97), + psa_type='bi-direction', + compact=False, + shrink_factor=2, + normalization_factor=1.0, + psa_softmax=True, + dropout_ratio=0.1, + num_classes=8, + input_transform='multiple_select', + norm_cfg=norm_cfg, + align_corners=False, + attn_split=1, + strides=(2,1), + size_index=1, + img_size=crop_size, + loss_decode=dict( + type='CrossEntropyLoss', use_sigmoid=False, loss_weight=1.0, static_weight=False)), + auxiliary_head=[ + dict( + type='FCNHead', + in_channels=160, + channels=32, + num_convs=1, + num_classes=8, + in_index=-2, + norm_cfg=norm_cfg, + concat_input=False, + align_corners=False, + img_size=crop_size, + loss_decode=dict( + type='CrossEntropyLoss', use_sigmoid=False, loss_weight=0.4)), + dict( + type='FCNHead', + in_channels=64, + channels=32, + num_convs=1, + num_classes=8, + in_index=-3, + norm_cfg=norm_cfg, + concat_input=False, + align_corners=False, + img_size=crop_size, + loss_decode=dict( + type='CrossEntropyLoss', use_sigmoid=False, loss_weight=0.4)), + ], + # model training and testing settings + train_cfg=dict(), + test_cfg=dict(mode='whole')) \ No newline at end of file diff --git a/configs/ours/ganav_group8_cwt.py b/configs/ours/ganav_group8_cwt.py new file mode 100644 index 0000000..bff21e7 --- /dev/null +++ b/configs/ours/ganav_group8_cwt.py @@ -0,0 +1,29 @@ +_base_ = [ + '../_base_/models/ours_class_att_group8.py', '../_base_/datasets/cwt_group8.py', + '../_base_/default_runtime.py' +] + +img_size = (1920, 1080) +crop_size = (375, 600) + +optimizer = dict(type='SGD', lr=0.03, momentum=0.9, weight_decay=4e-5) +optimizer_config = dict() +# learning policy +# runtime settings +runner = dict(type='IterBasedRunner', max_iters=240000) +total_iters = 240000 +checkpoint_config = dict(by_epoch=False, interval=16000) +evaluation = dict(interval=240000, metric='mIoU') + +# optimizer +lr_config = dict(policy='poly', power=0.9, min_lr=1e-4, warmup='linear', + warmup_iters=1500, + warmup_ratio=1e-6, + by_epoch=False) + +data = dict( + samples_per_gpu=6, + workers_per_gpu=6) + + + diff --git a/mmseg/datasets/__init__.py b/mmseg/datasets/__init__.py index bdbfa13..52cd17c 100644 --- a/mmseg/datasets/__init__.py +++ b/mmseg/datasets/__init__.py @@ -5,8 +5,9 @@ from .rugd_group6 import RUGDDataset_Group6 from .rugd_group4 import RUGDDataset_Group4 from .rellis_group4 import RELLISDataset_Group4 +from .cwt import CWT_Dataset __all__ = [ 'CustomDataset', 'RUGDDataset', 'RELLISDataset', 'RELLISDataset_Group6', - 'RUGDDataset_Group6', 'RUGDDataset_Group4', 'RELLISDataset_Group4' + 'RUGDDataset_Group6', 'RUGDDataset_Group4', 'RELLISDataset_Group4', 'CWT_Dataset' ] diff --git a/mmseg/datasets/cwt.py b/mmseg/datasets/cwt.py new file mode 100644 index 0000000..ef376e3 --- /dev/null +++ b/mmseg/datasets/cwt.py @@ -0,0 +1,25 @@ +from .builder import DATASETS +from .custom import CustomDataset + + +@DATASETS.register_module() +class CWT_Dataset(CustomDataset): + """cwt dataset. + + """ + + + + CLASSES = ("flat", "bumpy", "water", "rock", "mixed", "excavator", "obstacle") + + PALETTE = [[0, 255, 0], [255, 255, 0], [255, 0, 0], [128, 0, 0], [100, 65, 0], [0, 255, 255], [0, 0, 255]] + + def __init__(self, **kwargs): + super(CWT_Dataset, self).__init__( + img_suffix='.jpg', + seg_map_suffix='.png', + **kwargs) + self.CLASSES = ("flat", "bumpy", "water", "rock", "mixed", "excavator", "obstacle") + self.PALETTE =[[0, 255, 0], [255, 255, 0], [255, 0, 0], [128, 0, 0], [100, 65, 0], [0, 255, 255], [0, 0, 255]] + + # assert osp.exists(self.img_dir)