feat: share custom augmentation experiment code (#19)
* feat: k_fold split

#7

* feat: for train and valid

#10

* fix: deteval, train

#10

* feat: train with pickle data

#14

* fix: path issue in to_pickle file

#14

* chore: update train_val to json_name

#10

* chore: change arg name

#14

* feat: normalize, custom_augmentation

#16

* feat: custom augmentation

#16

* feat: modify to_pickle.py, train.py, inference.py codes

#20

* chore: update code

#16

---------

Co-authored-by: Eddie-JUB <bjonguk@gmail.com>
FinalCold and Eddie-JUB authored Jan 29, 2024
1 parent 1d9d218 commit 5f2ccf7
Showing 9 changed files with 596 additions and 73 deletions.
4 changes: 3 additions & 1 deletion .gitignore
@@ -121,6 +121,8 @@ data
 *.sh~
 
 # Added
+*.csv
 code/pths
 code/predictions
-code/wandb
+code/wandb
+code/trained_models
36 changes: 34 additions & 2 deletions code/dataset.py
@@ -2,6 +2,9 @@
 import math
 import json
 from PIL import Image
+import pickle
+import random
+import os
 
 import torch
 import numpy as np
@@ -336,15 +339,18 @@ def filter_vertices(vertices, labels, ignore_under=0, drop_under=0):
 class SceneTextDataset(Dataset):
     def __init__(self, root_dir,
                  split='train',
+                 json_name=None,
                  image_size=2048,
                  crop_size=1024,
                  ignore_tags=[],
                  ignore_under_threshold=10,
                  drop_under_threshold=1,
+                 custom_transform=None,
                  color_jitter=True,
                  normalize=True):
-        with open(osp.join(root_dir, 'ufo/{}.json'.format(split)), 'r') as f:
-            anno = json.load(f)
+        if json_name:
+            with open(osp.join(root_dir, f'ufo/{json_name}'), 'r') as f:
+                anno = json.load(f)
 
         self.anno = anno
         self.image_fnames = sorted(anno['images'].keys())
@@ -357,6 +363,8 @@ def __init__(self, root_dir,
 
         self.drop_under_threshold = drop_under_threshold
         self.ignore_under_threshold = ignore_under_threshold
+
+        self.custom_transform = custom_transform
 
     def __len__(self):
         return len(self.image_fnames)
@@ -401,6 +409,8 @@ def __getitem__(self, idx):
         funcs = []
         if self.color_jitter:
             funcs.append(A.ColorJitter(0.5, 0.5, 0.5, 0.25))
+        if self.custom_transform:
+            funcs.append(self.custom_transform)
         if self.normalize:
             funcs.append(A.Normalize(mean=(0.5, 0.5, 0.5), std=(0.5, 0.5, 0.5)))
         transform = A.Compose(funcs)
@@ -410,3 +420,25 @@ def __getitem__(self, idx):
         roi_mask = generate_roi_mask(image, vertices, labels)
 
         return image, word_bboxes, roi_mask
+
+class PickleDataset(Dataset):
+    def __init__(self, datadir, to_tensor=True):
+        self.datadir = datadir
+        self.to_tensor = to_tensor
+        self.datalist = [f for f in os.listdir(datadir) if f.endswith('.pkl')]
+
+    def __getitem__(self, idx):
+        with open(file=osp.join(self.datadir, f"{idx}.pkl"), mode="rb") as f:
+            data = pickle.load(f)
+
+        image, score_map, geo_map, roi_mask = data
+        if self.to_tensor:
+            image = torch.Tensor(image)
+            score_map = torch.Tensor(score_map)
+            geo_map = torch.Tensor(geo_map)
+            roi_mask = torch.Tensor(roi_mask)
+
+        return image, score_map, geo_map, roi_mask
+
+    def __len__(self):
+        return len(self.datalist)
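Note: a minimal usage sketch of the two dataset paths added above, assuming it is run from the code/ directory. The root directory, annotation file name, pickle directory, and batch size are illustrative assumptions, not values taken from this commit.

import albumentations as A
from torch.utils.data import DataLoader

from dataset import PickleDataset, SceneTextDataset
from east_dataset import EASTDataset

# On-the-fly pipeline: read an assumed UFO split file and stack an extra
# Albumentations transform on top of the built-in ColorJitter / Normalize.
scene_dataset = SceneTextDataset(
    root_dir='../data/medical',                      # assumed dataset root
    split='train',
    json_name='train_split.json',                    # assumed UFO annotation file
    image_size=2048,
    crop_size=1024,
    ignore_tags=['masked', 'excluded-region', 'maintable', 'stamp'],
    custom_transform=A.GaussianBlur(blur_limit=(3, 7), p=0.5),
)
east_dataset = EASTDataset(scene_dataset)            # wraps the dataset to produce EAST score/geo maps

# Pre-pickled pipeline: each idx.pkl already stores (image, score_map, geo_map, roi_mask),
# so batches come straight from disk without re-running the augmentation.
pickle_dataset = PickleDataset('../data/medical/pickle/example_experiment/train/', to_tensor=True)
loader = DataLoader(pickle_dataset, batch_size=8, shuffle=True, num_workers=4)

for image, score_map, geo_map, roi_mask in loader:
    pass  # feed each batch to the EAST training step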
2 changes: 1 addition & 1 deletion code/deteval.py
@@ -203,7 +203,7 @@ def diag(r):
 
         for n in range(len(pointsList)):
             points = pointsList[n]
-            transcription = transcriptionsList[n]
+            transcription = transcriptionsList[n] if transcriptionsList is not None else None
             dontCare = transcription == "###" or len(points) > 4
             gtRect = Rectangle(*points)
             gtRects.append(gtRect)
22 changes: 15 additions & 7 deletions code/inference.py
@@ -21,8 +21,9 @@ def parse_args():
 
     # Conventional args
    parser.add_argument('--data_dir', default=os.environ.get('SM_CHANNEL_EVAL', '../data/medical'))
-    parser.add_argument('--model_dir', default=os.environ.get('SM_CHANNEL_MODEL', 'trained_models'))
-    parser.add_argument('--output_dir', default=os.environ.get('SM_OUTPUT_DATA_DIR', 'predictions'))
+    # parser.add_argument('--model_dir', default=os.environ.get('SM_CHANNEL_MODEL', 'trained_models'))
+    parser.add_argument('--model_dir', type=str, default="/data/ephemeral/home/level2-cv-datacentric-cv-01/code/trained_models/1000e_adam_cosine_0.001_pickle_is[2048]_cs[1024]_aug['CJ', 'GB', 'N']")
+    # parser.add_argument('--output_dir', default=os.environ.get('SM_OUTPUT_DATA_DIR', 'predictions'))
 
     parser.add_argument('--device', default='cuda' if cuda.is_available() else 'cpu')
     parser.add_argument('--input_size', type=int, default=2048)
@@ -67,10 +68,17 @@ def main(args):
     model = EAST(pretrained=False).to(args.device)
 
     # Get paths to checkpoint files
-    ckpt_fpath = osp.join(args.model_dir, 'latest.pth')
-
-    if not osp.exists(args.output_dir):
-        os.makedirs(args.output_dir)
+    best_checkpoint_fpath = osp.join(args.model_dir, 'best.pth')
+    if os.path.isfile(best_checkpoint_fpath):
+        print('best checkpoint found')
+        ckpt_fpath = best_checkpoint_fpath
+    else:
+        print('no best checkpoint found')
+        ckpt_fpath = osp.join(args.model_dir, 'latest.pth')
+
+    # if not osp.exists(args.output_dir):
+    #     print('no checkpoint found')
+    #     os.makedirs(args.output_dir)
 
     print('Inference in progress')
 
@@ -80,7 +88,7 @@
     ufo_result['images'].update(split_result['images'])
 
     output_fname = 'output.csv'
-    with open(osp.join(args.output_dir, output_fname), 'w') as f:
+    with open(osp.join(args.model_dir, output_fname), 'w') as f:
         json.dump(ufo_result, f, indent=4)
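Since inference.py now writes its predictions next to the checkpoint, the result can be inspected directly from model_dir. A small sketch, assuming the standard UFO layout where each image entry carries a words dict; the directory name below is hypothetical.

import json
import os.path as osp

model_dir = 'trained_models/example_experiment'   # hypothetical experiment directory
with open(osp.join(model_dir, 'output.csv'), 'r') as f:
    ufo_result = json.load(f)                     # despite the .csv name, the file holds JSON

for image_fname, image_info in ufo_result['images'].items():
    print(image_fname, len(image_info.get('words', {})))  # predicted word boxes per image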
64 changes: 64 additions & 0 deletions code/to_pickle.py
@@ -0,0 +1,64 @@
import pickle
from tqdm import tqdm
import os
import os.path as osp

from east_dataset import EASTDataset
from dataset import SceneTextDataset

import albumentations as A

def main():
    data_dir = '/data/ephemeral/home/level2-cv-datacentric-cv-01/data/medical'
    ignore_tags = ['masked', 'excluded-region', 'maintable', 'stamp']
    custom_augmentation_dict = {
        'CJ': A.ColorJitter(brightness=0.2, contrast=0.2, saturation=0.2, hue=0.2, p=0.5),
        'GB': A.GaussianBlur(blur_limit=(3, 7), p=0.5),
        'B': A.Blur(blur_limit=7, p=0.5),
        'GN': A.GaussNoise(p=0.5),
        'HSV': A.HueSaturationValue(hue_shift_limit=20, sat_shift_limit=30, val_shift_limit=20, p=0.5),
        'RBC': A.RandomBrightnessContrast(brightness_limit=0.2, contrast_limit=0.2, p=0.5),
        'N': A.Normalize(mean=(0.7760271717131425, 0.7722186515548635, 0.7670997062399484),
                         std=(0.17171108542242774, 0.17888224507630185, 0.18678791254805846), p=1.0)
    }

    # image_size = [1024]
    image_size = [1024, 1536, 2048]
    crop_size = [256, 512, 1024, 2048]
    aug_select = ['CJ','GB','N']

    custom_augmentation = []
    for s in aug_select:
        custom_augmentation.append(custom_augmentation_dict[s])

    pkl_dir = f'pickle/{image_size}_cs{crop_size}_aug{aug_select}/train/'
    # pkl_dir = f'pickle_is{image_size}_cs{crop_size}_aug{aug_select}/train/'

    # Create the output directory for the pickles
    os.makedirs(osp.join(data_dir, pkl_dir), exist_ok=True)

    for i, i_size in enumerate(image_size):
        for j, c_size in enumerate(crop_size):
            if c_size > i_size:
                continue
            train_dataset = SceneTextDataset(
                root_dir=data_dir,
                split='train',
                json_name='train_split.json',
                image_size=i_size,
                crop_size=c_size,
                ignore_tags=ignore_tags,
                custom_transform=A.Compose(custom_augmentation),
                color_jitter=False,
                normalize=False
            )
            train_dataset = EASTDataset(train_dataset)

            ds = len(train_dataset)
            for k in tqdm(range(ds)):
                data = train_dataset.__getitem__(k)
                with open(file=osp.join(data_dir, pkl_dir, f"{ds*i+ds*j+k}.pkl"), mode="wb") as f:
                    pickle.dump(data, f)

if __name__ == '__main__':
    main()
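A quick sanity check on the generated pickles might look like the sketch below; it reuses the directory naming from main() above and assumes at least one 0.pkl has been written.

import os.path as osp
import pickle

data_dir = '/data/ephemeral/home/level2-cv-datacentric-cv-01/data/medical'
image_size = [1024, 1536, 2048]
crop_size = [256, 512, 1024, 2048]
aug_select = ['CJ','GB','N']
pkl_dir = f'pickle/{image_size}_cs{crop_size}_aug{aug_select}/train/'

# Each pickle stores the tuple that PickleDataset later unpacks:
# (image, score_map, geo_map, roi_mask) as produced by EASTDataset.
with open(osp.join(data_dir, pkl_dir, '0.pkl'), 'rb') as f:
    image, score_map, geo_map, roi_mask = pickle.load(f)

print(image.shape)      # network input, e.g. a channels-first crop
print(score_map.shape)  # EAST score map
print(geo_map.shape)    # EAST geometry map
print(roi_mask.shape)   # mask of regions included in the loss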