Skip to content

Commit

Permalink
update cwt training
Browse files Browse the repository at this point in the history
  • Loading branch information
rayguan97 committed Dec 2, 2023
1 parent 06797a2 commit ca6978d
Show file tree
Hide file tree
Showing 5 changed files with 195 additions and 1 deletion.
64 changes: 64 additions & 0 deletions configs/_base_/datasets/cwt_group8.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,64 @@
# Dataset settings for the CWT dataset.
img_size = (1920, 1080)  # scale handed to the training-time Resize step
crop_size = (375, 600)  # shared by RandomCrop, Pad and the test-time img_scale

dataset_type = 'CWT_Dataset'
data_root = 'data/CWT/'

# Normalisation statistics in use. An alternative set is kept for reference:
#   mean=[123.675, 116.28, 103.53], std=[58.395, 57.12, 57.375], to_rgb=True
img_norm_cfg = dict(
    mean=[127.38, 127.96, 128.21],
    std=[53.941, 54.258, 54.389],
    to_rgb=True,
)

# Training pipeline: load image + annotation, augment, normalise, pad to
# crop_size, then bundle the image and its segmentation map.
train_pipeline = [
    dict(type='LoadImageFromFile'),
    dict(type='LoadAnnotations'),
    dict(type='Resize', img_scale=img_size, ratio_range=(0.5, 2.0)),
    dict(type='RandomCrop', crop_size=crop_size),
    dict(type='RandomFlip', prob=0.5),
    dict(type='PhotoMetricDistortion'),
    dict(type='Normalize', **img_norm_cfg),
    dict(type='Pad', size=crop_size, pad_val=0, seg_pad_val=255),
    dict(type='DefaultFormatBundle'),
    dict(type='Collect', keys=['img', 'gt_semantic_seg']),
]

# Test pipeline: a single scale (crop_size) with flipping switched off.
# Multi-scale ratios are currently disabled:
#   img_ratios=[0.5, 0.75, 1.0, 1.25, 1.5, 1.75]
test_pipeline = [
    dict(type='LoadImageFromFile'),
    dict(
        type='MultiScaleFlipAug',
        img_scale=crop_size,
        flip=False,
        transforms=[
            dict(type='Resize', keep_ratio=True),
            dict(type='RandomFlip'),
            dict(type='Normalize', **img_norm_cfg),
            dict(type='Pad', size=crop_size, pad_val=0, seg_pad_val=255),
            dict(type='ImageToTensor', keys=['img']),
            dict(type='Collect', keys=['img']),
        ],
    ),
]

# Loader settings. The val and test entries both read 'test.txt' and both
# use test_pipeline; only the train entry differs.
data = dict(
    samples_per_gpu=6,
    workers_per_gpu=6,
    train=dict(
        type=dataset_type,
        data_root=data_root,
        img_dir='img',
        ann_dir='annotation/grey_mask',
        split='train.txt',
        pipeline=train_pipeline,
    ),
    val=dict(
        type=dataset_type,
        data_root=data_root,
        img_dir='img',
        ann_dir='annotation/grey_mask',
        split='test.txt',
        pipeline=test_pipeline,
    ),
    test=dict(
        type=dataset_type,
        data_root=data_root,
        img_dir='img',
        ann_dir='annotation/grey_mask',
        split='test.txt',
        pipeline=test_pipeline,
    ),
)
75 changes: 75 additions & 0 deletions configs/_base_/models/ours_class_att_group8.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,75 @@
# Model settings: an EncoderDecoder with a MixVisionTransformer backbone, an
# OursHeadClassAttNew decode head and two auxiliary FCN heads.
img_size = (1920, 1080)
crop_size = (375, 600)

norm_cfg = dict(type='SyncBN', requires_grad=True)
model = dict(
    type='EncoderDecoder',
    backbone=dict(
        type='MixVisionTransformer',
        in_channels=3,
        embed_dims=32,
        num_stages=4,
        num_layers=[2, 2, 2, 2],
        num_heads=[1, 2, 5, 8],
        patch_sizes=[7, 3, 3, 3],
        sr_ratios=[8, 4, 2, 1],
        out_indices=(0, 1, 2, 3),
        mlp_ratio=4,
        qkv_bias=True,
        drop_rate=0.0,
        attn_drop_rate=0.0,
        drop_path_rate=0.1,
    ),
    # NOTE(review): num_classes is 8 here and in both auxiliary heads, while
    # CWT_Dataset.CLASSES lists 7 names -- confirm the extra index is
    # intentional (e.g. a background/void label).
    decode_head=dict(
        type='OursHeadClassAttNew',
        in_channels=[32, 64, 160, 256],
        in_index=[0, 1, 2, 3],
        channels=384,
        mask_size=(97, 97),
        psa_type='bi-direction',
        compact=False,
        shrink_factor=2,
        normalization_factor=1.0,
        psa_softmax=True,
        dropout_ratio=0.1,
        num_classes=8,
        input_transform='multiple_select',
        norm_cfg=norm_cfg,
        align_corners=False,
        attn_split=1,
        strides=(2, 1),
        size_index=1,
        img_size=crop_size,
        loss_decode=dict(
            type='CrossEntropyLoss',
            use_sigmoid=False,
            loss_weight=1.0,
            static_weight=False,
        ),
    ),
    # The two auxiliary heads differ only in the backbone feature they
    # consume: (in_channels, in_index) = (160, -2) and (64, -3).
    auxiliary_head=[
        dict(
            type='FCNHead',
            in_channels=feat_channels,
            channels=32,
            num_convs=1,
            num_classes=8,
            in_index=feat_index,
            norm_cfg=norm_cfg,
            concat_input=False,
            align_corners=False,
            img_size=crop_size,
            loss_decode=dict(
                type='CrossEntropyLoss', use_sigmoid=False, loss_weight=0.4),
        )
        for feat_channels, feat_index in ((160, -2), (64, -3))
    ],
    # model training and testing settings
    train_cfg=dict(),
    test_cfg=dict(mode='whole'),
)
29 changes: 29 additions & 0 deletions configs/ours/ganav_group8_cwt.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,29 @@
# Experiment config: group-8 class-attention model trained on CWT.
_base_ = [
    '../_base_/models/ours_class_att_group8.py',
    '../_base_/datasets/cwt_group8.py',
    '../_base_/default_runtime.py',
]

img_size = (1920, 1080)
crop_size = (375, 600)

# optimizer
optimizer = dict(type='SGD', lr=0.03, momentum=0.9, weight_decay=4e-5)
optimizer_config = dict()

# learning policy
lr_config = dict(
    policy='poly',
    power=0.9,
    min_lr=1e-4,
    warmup='linear',
    warmup_iters=1500,
    warmup_ratio=1e-6,
    by_epoch=False,
)

# runtime settings
total_iters = 240000
runner = dict(type='IterBasedRunner', max_iters=total_iters)
checkpoint_config = dict(by_epoch=False, interval=16000)
# The evaluation interval equals the full schedule length.
evaluation = dict(interval=total_iters, metric='mIoU')

# Loader settings set on top of the dataset base config.
data = dict(samples_per_gpu=6, workers_per_gpu=6)



3 changes: 2 additions & 1 deletion mmseg/datasets/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -5,8 +5,9 @@
from .rugd_group6 import RUGDDataset_Group6
from .rugd_group4 import RUGDDataset_Group4
from .rellis_group4 import RELLISDataset_Group4
from .cwt import CWT_Dataset

# Public names exported by ``mmseg/datasets/__init__.py``.
# Every entry is followed by a comma: two adjacent string literals without
# one are silently concatenated by Python, which would fuse two dataset
# names into a single bogus export.
__all__ = [
    'CustomDataset', 'RUGDDataset', 'RELLISDataset', 'RELLISDataset_Group6',
    'RUGDDataset_Group6', 'RUGDDataset_Group4', 'RELLISDataset_Group4',
    'CWT_Dataset',
]
25 changes: 25 additions & 0 deletions mmseg/datasets/cwt.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,25 @@
from .builder import DATASETS
from .custom import CustomDataset


@DATASETS.register_module()
class CWT_Dataset(CustomDataset):
    """CWT dataset.

    Calls the ``CustomDataset`` constructor with ``img_suffix='.jpg'`` and
    ``seg_map_suffix='.png'``; all other keyword arguments are forwarded
    unchanged.
    """

    # Class names and their RGB colours, index-aligned (7 entries each).
    CLASSES = ("flat", "bumpy", "water", "rock", "mixed", "excavator",
               "obstacle")

    PALETTE = [[0, 255, 0], [255, 255, 0], [255, 0, 0], [128, 0, 0],
               [100, 65, 0], [0, 255, 255], [0, 0, 255]]

    def __init__(self, **kwargs):
        super().__init__(img_suffix='.jpg', seg_map_suffix='.png', **kwargs)
        # Re-assign the full class list and a fresh palette on the instance
        # once the parent constructor has finished, so they override
        # whatever it set.
        self.CLASSES = CWT_Dataset.CLASSES
        self.PALETTE = [list(colour) for colour in CWT_Dataset.PALETTE]

        # Disabled sanity check kept from the original:
        # assert osp.exists(self.img_dir)

0 comments on commit ca6978d

Please sign in to comment.