diff --git a/.gitignore b/.gitignore
index 1b0cabf..5741639 100644
--- a/.gitignore
+++ b/.gitignore
@@ -121,6 +121,8 @@ data
 *.sh~
 
 # 추가
+*.csv
 code/pths
 code/predictions
-code/wandb
\ No newline at end of file
+code/wandb
+code/trained_models
diff --git a/code/dataset.py b/code/dataset.py
index 042e861..298242f 100644
--- a/code/dataset.py
+++ b/code/dataset.py
@@ -2,6 +2,9 @@
 import math
 import json
 from PIL import Image
+import pickle
+import random
+import os
 
 import torch
 import numpy as np
@@ -336,15 +339,18 @@ def filter_vertices(vertices, labels, ignore_under=0, drop_under=0):
 
 class SceneTextDataset(Dataset):
     def __init__(self, root_dir,
                  split='train',
+                 json_name=None,
                  image_size=2048,
                  crop_size=1024,
                  ignore_tags=[],
                  ignore_under_threshold=10,
                  drop_under_threshold=1,
+                 custom_transform=None,
                  color_jitter=True,
                  normalize=True):
-        with open(osp.join(root_dir, 'ufo/{}.json'.format(split)), 'r') as f:
-            anno = json.load(f)
+        json_name = json_name if json_name else f'{split}.json'  # fall back to the split name so json_name stays optional
+        with open(osp.join(root_dir, f'ufo/{json_name}'), 'r') as f:
+            anno = json.load(f)
 
         self.anno = anno
         self.image_fnames = sorted(anno['images'].keys())
@@ -357,6 +363,8 @@ def __init__(self, root_dir,
         self.drop_under_threshold = drop_under_threshold
         self.ignore_under_threshold = ignore_under_threshold
+
+        self.custom_transform = custom_transform
 
     def __len__(self):
         return len(self.image_fnames)
@@ -401,6 +409,8 @@ def __getitem__(self, idx):
         funcs = []
         if self.color_jitter:
             funcs.append(A.ColorJitter(0.5, 0.5, 0.5, 0.25))
+        if self.custom_transform:
+            funcs.append(self.custom_transform)
         if self.normalize:
             funcs.append(A.Normalize(mean=(0.5, 0.5, 0.5), std=(0.5, 0.5, 0.5)))
         transform = A.Compose(funcs)
@@ -410,3 +420,25 @@ def __getitem__(self, idx):
         roi_mask = generate_roi_mask(image, vertices, labels)
 
         return image, word_bboxes, roi_mask
+
+class PickleDataset(Dataset):
+    def __init__(self, datadir, to_tensor=True):
+        self.datadir = datadir
+        self.to_tensor = to_tensor
+        self.datalist = [f for f in os.listdir(datadir) if f.endswith('.pkl')]
+
+    def __getitem__(self, idx):
+        with open(file=osp.join(self.datadir, f"{idx}.pkl"), mode="rb") as f:
+            data = pickle.load(f)
+
+        image, score_map, geo_map, roi_mask = data
+        if self.to_tensor:
+            image = torch.Tensor(image)
+            score_map = torch.Tensor(score_map)
+            geo_map = torch.Tensor(geo_map)
+            roi_mask = torch.Tensor(roi_mask)
+
+        return image, score_map, geo_map, roi_mask
+
+    def __len__(self):
+        return len(self.datalist)
\ No newline at end of file
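Note: `PickleDataset` swaps the on-the-fly `SceneTextDataset` → `EASTDataset` pipeline for precomputed `.pkl` samples, so each `__getitem__` is a single file read. A minimal usage sketch (the pickle directory below is a placeholder, not a real run directory):

```python
# Minimal sketch: PickleDataset plugs straight into a DataLoader; the path
# is hypothetical and must point at the output of to_pickle.py.
from torch.utils.data import DataLoader
from dataset import PickleDataset

train_dataset = PickleDataset('../data/medical/pickle/some_run/train')  # hypothetical path
train_loader = DataLoader(train_dataset, batch_size=4, shuffle=True, num_workers=8)

# Each batch already matches EASTDataset's output, so model.train_step() works unchanged.
for image, score_map, geo_map, roi_mask in train_loader:
    break
```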
default="/data/ephemeral/home/level2-cv-datacentric-cv-01/code/trained_models/1000e_adam_cosine_0.001_pickle_is[2048]_cs[1024]_aug['CJ', 'GB', 'N']") + # parser.add_argument('--output_dir', default=os.environ.get('SM_OUTPUT_DATA_DIR', 'predictions')) parser.add_argument('--device', default='cuda' if cuda.is_available() else 'cpu') parser.add_argument('--input_size', type=int, default=2048) @@ -67,10 +68,17 @@ def main(args): model = EAST(pretrained=False).to(args.device) # Get paths to checkpoint files - ckpt_fpath = osp.join(args.model_dir, 'latest.pth') - - if not osp.exists(args.output_dir): - os.makedirs(args.output_dir) + best_checkpoint_fpath = osp.join(args.model_dir, 'best.pth') + if os.path.isfile(best_checkpoint_fpath): + print('best checkpoint found') + ckpt_fpath = best_checkpoint_fpath + else: + print('no best checkpoint found') + ckpt_fpath = osp.join(args.model_dir, 'latest.pth') + + # if not osp.exists(args.output_dir): + # print('no checkpoint found') + # os.makedirs(args.output_dir) print('Inference in progress') @@ -80,7 +88,7 @@ def main(args): ufo_result['images'].update(split_result['images']) output_fname = 'output.csv' - with open(osp.join(args.output_dir, output_fname), 'w') as f: + with open(osp.join(args.model_dir, output_fname), 'w') as f: json.dump(ufo_result, f, indent=4) diff --git a/code/to_pickle.py b/code/to_pickle.py new file mode 100644 index 0000000..92d3e6a --- /dev/null +++ b/code/to_pickle.py @@ -0,0 +1,64 @@ +import pickle +from tqdm import tqdm +import os +import os.path as osp + +from east_dataset import EASTDataset +from dataset import SceneTextDataset + +import albumentations as A + +def main(): + data_dir = '/data/ephemeral/home/level2-cv-datacentric-cv-01/data/medical' + ignore_tags = ['masked', 'excluded-region', 'maintable', 'stamp'] + custom_augmentation_dict = { + 'CJ': A.ColorJitter(brightness=0.2, contrast=0.2, saturation=0.2, hue=0.2, p=0.5), + 'GB': A.GaussianBlur(blur_limit=(3, 7), p=0.5), + 'B': A.Blur(blur_limit=7, p=0.5), + 'GN': A.GaussNoise(p=0.5), + 'HSV': A.HueSaturationValue(hue_shift_limit=20, sat_shift_limit=30, val_shift_limit=20, p=0.5), + 'RBC': A.RandomBrightnessContrast(brightness_limit=0.2, contrast_limit=0.2, p=0.5), + 'N': A.Normalize(mean=(0.7760271717131425, 0.7722186515548635, 0.7670997062399484), + std=(0.17171108542242774, 0.17888224507630185, 0.18678791254805846), p=1.0) + } + + # image_size = [1024] + image_size = [1024, 1536, 2048] + crop_size = [256, 512, 1024, 2048] + aug_select = ['CJ','GB','N'] + + custom_augmentation = [] + for s in aug_select: + custom_augmentation.append(custom_augmentation_dict[s]) + + pkl_dir = f'pickle/{image_size}_cs{crop_size}_aug{aug_select}/train/' + # pkl_dir = f'pickle_is{image_size}_cs{crop_size}_aug{aug_select}/train/' + + # 경로 폴더 생성 + os.makedirs(osp.join(data_dir, pkl_dir), exist_ok=True) + + for i, i_size in enumerate(image_size): + for j, c_size in enumerate(crop_size): + if c_size > i_size: + continue + train_dataset = SceneTextDataset( + root_dir=data_dir, + split='train', + json_name='train_split.json', + image_size=i_size, + crop_size=c_size, + ignore_tags=ignore_tags, + custom_transform=A.Compose(custom_augmentation), + color_jitter=False, + normalize=False + ) + train_dataset = EASTDataset(train_dataset) + + ds = len(train_dataset) + for k in tqdm(range(ds)): + data = train_dataset.__getitem__(k) + with open(file=osp.join(data_dir, pkl_dir, f"{ds*i+ds*j+k}.pkl"), mode="wb") as f: + pickle.dump(data, f) + +if __name__ == '__main__': + main() \ No newline at end of 
diff --git a/code/train.py b/code/train.py
index 03166db..7143436 100644
--- a/code/train.py
+++ b/code/train.py
@@ -2,6 +2,7 @@
 import os.path as osp
 import time
 import math
+import random
 from datetime import timedelta
 from argparse import ArgumentParser
@@ -13,123 +14,268 @@ import wandb
 
 from east_dataset import EASTDataset
-from dataset import SceneTextDataset
+from dataset import SceneTextDataset, PickleDataset
 from model import EAST
+from deteval import calc_deteval_metrics
+from utils import get_gt_bboxes, get_pred_bboxes, seed_everything, AverageMeter
+import albumentations as A
+import numpy as np
 
 
 def parse_args():
     parser = ArgumentParser()
 
-    # Conventional args
-    parser.add_argument('--data_dir', type=str,
-                        default=os.environ.get('SM_CHANNEL_TRAIN', '../data/medical'))
-    parser.add_argument('--model_dir', type=str, default=os.environ.get('SM_MODEL_DIR',
-                                                                        'trained_models'))
-
-    parser.add_argument('--device', default='cuda' if cuda.is_available() else 'cpu')
+    # path to the pickled dataset
+    parser.add_argument('--train_dataset_dir', type=str, default="/data/ephemeral/home/level2-cv-datacentric-cv-01/data/medical/pickle/[2048]_cs[1024]_aug['CJ', 'GB', 'N']/train")
+    parser.add_argument('--data_dir', type=str, default=os.environ.get('SM_CHANNEL_TRAIN', '../data/medical'))
+    parser.add_argument('--model_dir', type=str, default=os.environ.get('SM_MODEL_DIR', 'trained_models'))
+    parser.add_argument('--seed', type=int, default=137)
+    parser.add_argument('--val_interval', type=int, default=5)
+    parser.add_argument('--device', default='cuda:0' if cuda.is_available() else 'cpu')
     parser.add_argument('--num_workers', type=int, default=8)
-
     parser.add_argument('--image_size', type=int, default=2048)
     parser.add_argument('--input_size', type=int, default=1024)
-    parser.add_argument('--batch_size', type=int, default=8)
+    parser.add_argument('--batch_size', type=int, default=4)
     parser.add_argument('--learning_rate', type=float, default=1e-3)
     parser.add_argument('--max_epoch', type=int, default=150)
-    parser.add_argument('--save_interval', type=int, default=5)
+    parser.add_argument('--save_interval', type=int, default=1)
     parser.add_argument('--ignore_tags', type=list, default=['masked', 'excluded-region', 'maintable', 'stamp'])
-
     parser.add_argument('-m', '--mode', type=str, default='on', help='wandb logging mode(on: online, off: disabled)')
     parser.add_argument('-p', '--project', type=str, default='datacentric', help='wandb project name')
-    parser.add_argument('-d', '--data', default='original', type=str, help='description about dataset')
-
+    parser.add_argument('-d', '--data', default='pickle', type=str, help='description about dataset', choices=['original', 'pickle'])
+    parser.add_argument("--optimizer", type=str, default='adam', choices=['adam', 'adamW'])
+    parser.add_argument("--scheduler", type=str, default='cosine', choices=['multistep', 'cosine'])
+    parser.add_argument("--resume", type=str, default=None, choices=[None, 'resume', 'finetune'])
+
     args = parser.parse_args()
 
+    if args.data == 'original':
+        args.data_name = 'original'
+        args.save_dir = os.path.join(args.model_dir, f'{args.max_epoch}e_{args.optimizer}_{args.scheduler}_{args.learning_rate}')
+    elif args.data == 'pickle':
+        args.data_name = args.train_dataset_dir.split('/')[-2]
+        args.save_dir = os.path.join(args.model_dir, f'{args.max_epoch}e_{args.optimizer}_{args.scheduler}_{args.learning_rate}_{args.data_name}')
+    os.makedirs(args.save_dir, exist_ok=True)
+
     if args.input_size % 32 != 0:
         raise ValueError('`input_size` must be a multiple of 32')
 
     return args
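Note: `parse_args` derives the run directory from the hyperparameters and, in pickle mode, appends the second-to-last component of `--train_dataset_dir`. A quick illustration of what the defaults above resolve to (not part of the patch; the truncated path is a stand-in):

```python
# Illustration of the save_dir naming rule with the argparse defaults above.
import os

model_dir = 'trained_models'
train_dataset_dir = "/data/.../pickle/[2048]_cs[1024]_aug['CJ', 'GB', 'N']/train"
data_name = train_dataset_dir.split('/')[-2]   # "[2048]_cs[1024]_aug['CJ', 'GB', 'N']"
save_dir = os.path.join(model_dir, f"{150}e_adam_cosine_{1e-3}_{data_name}")
print(save_dir)  # trained_models/150e_adam_cosine_0.001_[2048]_cs[1024]_aug['CJ', 'GB', 'N']
```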
 
-
-def do_training(args, data_dir, model_dir, device, image_size, input_size, num_workers, batch_size,
-                learning_rate, max_epoch, save_interval, ignore_tags, mode, project, data):
-    dataset = SceneTextDataset(
-        data_dir,
-        split='train',
-        image_size=image_size,
-        crop_size=input_size,
-        ignore_tags=ignore_tags
-    )
-    dataset = EASTDataset(dataset)
-    num_batches = math.ceil(len(dataset) / batch_size)
+def do_training(args):
+
+    ### Train Loader ###
+    if args.data == 'original':
+        train_dataset = SceneTextDataset(
+            args.data_dir,
+            split='train_split',
+            json_name='train_split.json',
+            image_size=args.image_size,
+            crop_size=args.input_size,
+            ignore_tags=args.ignore_tags,
+        )
+        train_dataset = EASTDataset(train_dataset)
+
+    elif args.data == 'pickle':
+        train_dataset = PickleDataset(args.train_dataset_dir)
+
+    train_num_batches = math.ceil(len(train_dataset) / args.batch_size)
     train_loader = DataLoader(
-        dataset,
-        batch_size=batch_size,
+        train_dataset,
+        batch_size=args.batch_size,
         shuffle=True,
-        num_workers=num_workers
+        num_workers=args.num_workers,
+        pin_memory=True
     )
-
-    device = torch.device("cuda:0" if torch.cuda.is_available() else "cpu")
+
+    ### Val Loader ###
+    ''' The code below is for computing the validation loss. '''
+    '''
+    val_dataset = SceneTextDataset(
+        args.data_dir,
+        split='valid_split',
+        json_name='valid_split.json',
+        image_size=args.image_size,
+        crop_size=args.image_size,
+        ignore_tags=args.ignore_tags,
+        color_jitter=False,
+    )
+    val_dataset = EASTDataset(val_dataset)
+    val_num_batches = math.ceil(len(val_dataset) / args.batch_size)
+    val_loader = DataLoader(
+        val_dataset,
+        batch_size=args.batch_size,
+        num_workers=args.num_workers
+    )
+    '''
+
+    device = torch.device(args.device if torch.cuda.is_available() else "cpu")
     model = EAST()
     model.to(device)
-    optimizer = torch.optim.Adam(model.parameters(), lr=learning_rate)
-    scheduler = lr_scheduler.MultiStepLR(optimizer, milestones=[max_epoch // 2], gamma=0.1)
-
-    # WandB
-    if mode == 'on':
+
+    ### Resume or finetune ###
+    if args.resume == "resume":
+        checkpoint = torch.load(osp.join(args.save_dir, "latest.pth"))
+        model.load_state_dict(checkpoint)
+    elif args.resume == "finetune":
+        checkpoint = torch.load(osp.join(args.save_dir, "best.pth"))
+        model.load_state_dict(checkpoint)
+
+    ### Optimizer ###
+    if args.optimizer == "adam":
+        optimizer = torch.optim.Adam(model.parameters(), lr=args.learning_rate)
+    elif args.optimizer == "adamW":
+        optimizer = torch.optim.AdamW(model.parameters(), lr=args.learning_rate)
+
+    ### Scheduler ###
+    if args.scheduler == "multistep":
+        scheduler = lr_scheduler.MultiStepLR(optimizer, milestones=[args.max_epoch // 2, args.max_epoch // 2 * 2], gamma=0.1)
+    elif args.scheduler == "cosine":
+        scheduler = lr_scheduler.CosineAnnealingLR(optimizer, T_max=args.max_epoch, eta_min=0)
+
+    ### WandB ###
+    if args.mode == 'on':
         wandb.init(
-            project=project,
+            project=args.project,
             entity='nae-don-nae-san',
-            group=data,
-            name=f'{max_epoch}e_{learning_rate}'
+            group=args.data,
+            name=f'{args.max_epoch}e_{args.optimizer}_{args.scheduler}_{args.learning_rate}_{args.data_name}'
         )
         wandb.config.update(args)
         wandb.watch(model)
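Note: a standalone sketch of how the two scheduler options move the learning rate over the default 150 epochs (dummy parameter; values taken from the defaults above). Worth noting that the second multistep milestone, `max_epoch // 2 * 2`, equals `max_epoch` for even epoch counts, so it only fires at the very last step:

```python
# Standalone sketch comparing the two scheduler options over 150 epochs.
import torch
from torch.optim import lr_scheduler

def lr_trace(make_scheduler, epochs=150, lr=1e-3):
    optimizer = torch.optim.Adam([torch.nn.Parameter(torch.zeros(1))], lr=lr)
    scheduler = make_scheduler(optimizer)
    trace = []
    for _ in range(epochs):
        trace.append(optimizer.param_groups[0]['lr'])  # LR used during this epoch
        optimizer.step()
        scheduler.step()
    return trace

multistep = lr_trace(lambda opt: lr_scheduler.MultiStepLR(opt, milestones=[75, 150], gamma=0.1))
cosine = lr_trace(lambda opt: lr_scheduler.CosineAnnealingLR(opt, T_max=150, eta_min=0))
print(multistep[0], multistep[100])  # 0.001 0.0001
print(cosine[0], cosine[149])        # 0.001 ~0.0
```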
+
+    ### Train ###
+    best_val_loss = np.inf
+    best_f1_score = 0
+    train_loss = AverageMeter()
+    val_loss = AverageMeter()
+
     model.train()
-    for epoch in range(max_epoch):
-        epoch_loss, epoch_start = 0, time.time()
-        with tqdm(total=num_batches) as pbar:
+    total_start_time = time.time()
+    for epoch in range(args.max_epoch):
+        epoch_start = time.time()
+        train_loss.reset()
+        with tqdm(total=train_num_batches) as pbar:
             for img, gt_score_map, gt_geo_map, roi_mask in train_loader:
-                pbar.set_description('[Epoch {}]'.format(epoch + 1))
+                pbar.set_description(f'[Epoch {epoch + 1}]')
 
                 loss, extra_info = model.train_step(img, gt_score_map, gt_geo_map, roi_mask)
                 optimizer.zero_grad()
                 loss.backward()
                 optimizer.step()
 
-                train_loss = loss.item()
-                epoch_loss += train_loss
+                train_loss.update(loss.item())
 
                 pbar.update(1)
                 train_dict = {
-                    'Cls loss': extra_info['cls_loss'], 'Angle loss': extra_info['angle_loss'],
-                    'IoU loss': extra_info['iou_loss']
+                    'train total loss': train_loss.avg,
+                    'train cls loss': extra_info['cls_loss'],
+                    'train angle loss': extra_info['angle_loss'],
+                    'train iou loss': extra_info['iou_loss']
                 }
                 pbar.set_postfix(train_dict)
-
-                if mode == 'on':
-                    wandb.log({'train_loss': train_loss, 'cls_loss': extra_info['cls_loss'],
-                               'angle_loss': extra_info['angle_loss'], 'iou_loss': extra_info['iou_loss']}, step=epoch)
+                if args.mode == 'on':
+                    wandb.log(train_dict, step=epoch)
 
         scheduler.step()
+        epoch_duration = time.time() - epoch_start
+
+        print('Mean loss: {:.4f} | Elapsed time: {} |'.format(
+            train_loss.avg, timedelta(seconds=epoch_duration)))
+
+        ### Val ###
+        # run validation every val_interval epochs, and on each of the last 5 epochs
+        if (epoch + 1) % args.val_interval == 0 or epoch >= args.max_epoch - 5:
+            with torch.no_grad():
+
+                ''' The code below is for computing the validation loss. '''
+                '''
+                model.eval()
+                epoch_start = time.time()
+                with tqdm(total=val_num_batches) as pbar:
+                    for img, gt_score_map, gt_geo_map, roi_mask in val_loader:
+                        pbar.set_description('Evaluate')
+                        loss, extra_info = model.train_step(
+                            img, gt_score_map, gt_geo_map, roi_mask
+                        )
+                        val_loss.update(loss.item())
+
+                        pbar.update(1)
+                        val_dict = {
+                            "val total loss": val_loss.avg,
+                            "val cls loss": extra_info["cls_loss"],
+                            "val angle loss": extra_info["angle_loss"],
+                            "val iou loss": extra_info["iou_loss"],
+                        }
+                        pbar.set_postfix(val_dict)
+                if args.mode == 'on':
+                    wandb.log(val_dict, step=epoch)
+
+                if val_loss.avg < best_val_loss:
+                    print(f"New best model for val loss : {val_loss.avg}! saving the best model..")
+                    bestpt_fpath = osp.join(args.model_dir, "best.pth")
+                    torch.save(model.state_dict(), bestpt_fpath)
+                    best_val_loss = val_loss.avg
+                '''
+
+                ''' The code below computes the validation F1 score. '''
+                print("Calculating validation results...")
+                valid_images = [f for f in os.listdir(osp.join(args.data_dir, 'img/valid_split/')) if f.endswith('.jpg')]
 
-    print('Mean loss: {:.4f} | Elapsed time: {}'.format(
-        epoch_loss / num_batches, timedelta(seconds=time.time() - epoch_start)))
+                pred_bboxes_dict = get_pred_bboxes(model, args.data_dir, valid_images, args.image_size, args.batch_size, split='valid_split')
+                gt_bboxes_dict = get_gt_bboxes(args.data_dir, json_file='ufo/valid_split.json', valid_images=valid_images)
 
-        if (epoch + 1) % save_interval == 0:
-            if not osp.exists(model_dir):
-                os.makedirs(model_dir)
+                result = calc_deteval_metrics(pred_bboxes_dict, gt_bboxes_dict)
+                total_result = result['total']
+                precision, recall = total_result['precision'], total_result['recall']
+                f1_score = 2*precision*recall/(precision+recall) if precision + recall > 0 else 0  # guard against zero division
+                print(f'Precision: {precision} Recall: {recall} F1 Score: {f1_score}')
+
+                val_dict = {
+                    'val precision': precision,
+                    'val recall': recall,
+                    'val f1_score': f1_score
+                }
+                if args.mode == 'on':
+                    wandb.log(val_dict, step=epoch)
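Note: `calc_deteval_metrics` is fed two dicts keyed by image filename, each mapping to a list of four-point quads; this is the shape `get_pred_bboxes` and `get_gt_bboxes` build from the UFO `points` fields. A toy call with invented coordinates:

```python
# Toy illustration (coordinates invented) of the inputs to calc_deteval_metrics.
from deteval import calc_deteval_metrics

pred_bboxes_dict = {'img_0001.jpg': [[[10, 10], [110, 10], [110, 40], [10, 40]]]}
gt_bboxes_dict = {'img_0001.jpg': [[[12, 11], [108, 11], [108, 42], [12, 42]]]}

result = calc_deteval_metrics(pred_bboxes_dict, gt_bboxes_dict)
total = result['total']
print(total['precision'], total['recall'])
```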
+
+                ### Save Best Model ###
+                if best_f1_score < f1_score:
+                    print(f"New best model for f1 score : {f1_score}! saving the best model..")
+                    # bestpt_fpath = osp.join(args.model_dir, 'best.pth')
+                    bestpt_fpath = osp.join(args.save_dir, 'best.pth')
+                    torch.save(model.state_dict(), bestpt_fpath)
+                    best_f1_score = f1_score
 
-            ckpt_fpath = osp.join(model_dir, 'latest.pth')
+            elapsed_time = time.time() - total_start_time
+            estimated_time_left = elapsed_time / (epoch + 1) * (args.max_epoch - epoch - 1)
+            # format the estimated remaining time as an ETA string
+
+            eta = str(timedelta(seconds=estimated_time_left))
+            print(f'Epoch {epoch + 1} Validation Finished | ETA: {eta}')
+
+        # Save the latest model
+        if (epoch + 1) % args.save_interval == 0:
+            ckpt_fpath = osp.join(args.save_dir, 'latest.pth')
+            # ckpt_fpath = osp.join(args.model_dir, 'latest.pth')
             torch.save(model.state_dict(), ckpt_fpath)
+
+    total_duration = time.time() - total_start_time
+    # print('Mean loss: {:.4f} | Elapsed time: {} |'.format(
+    #     train_loss.avg, timedelta(seconds=total_duration)))
 
-    if mode == 'on':
-        # wandb.run.summary['best_f1'] = best_f1
-        wandb.alert('Training Task Finished', f"TRAIN_LOSS: {train_loss:.4f}")
+    if args.mode == 'on':
+        wandb.alert('Training Task Finished', f"TRAIN_LOSS: {train_loss.avg:.4f}")
         wandb.finish()
 
 
 def main(args):
-    do_training(args, **args.__dict__)
-
+    do_training(args)
 
 if __name__ == '__main__':
     args = parse_args()
+    print(args)
+    seed_everything(args.seed)
+
     main(args)
diff --git a/code/utils.py b/code/utils.py
new file mode 100644
index 0000000..8918fc3
--- /dev/null
+++ b/code/utils.py
@@ -0,0 +1,76 @@
+import os
+import os.path as osp
+import random
+import numpy as np
+import torch
+
+from detect import detect
+from tqdm import tqdm
+import cv2
+import json
+
+class AverageMeter:
+    def __init__(self):
+        self.reset()
+
+    def reset(self):
+        self.val = 0
+        self.avg = 0
+        self.sum = 0
+        self.count = 0
+
+    def update(self, val, n=1):
+        self.val = val
+        self.sum += val * n
+        self.count += n
+        self.avg = self.sum / self.count
+
+def seed_everything(seed):
+    random.seed(seed)
+    np.random.seed(seed)
+    os.environ["PYTHONHASHSEED"] = str(seed)
+    torch.manual_seed(seed)
+    torch.cuda.manual_seed(seed)
+    torch.backends.cudnn.deterministic = True
+    torch.backends.cudnn.benchmark = True  # note: benchmark=True favors speed over exact reproducibility
+
+def get_gt_bboxes(root_dir, json_file, valid_images):
+
+    gt_bboxes = dict()
+    ufo_file_root = osp.join(root_dir, json_file)
+
+    with open(ufo_file_root, 'r') as f:
+        ufo_file = json.load(f)
+
+    ufo_file_images = ufo_file['images']
+    for valid_image in tqdm(valid_images):
+        gt_bboxes[valid_image] = []
+        for idx in ufo_file_images[valid_image]['words'].keys():
+            gt_bboxes[valid_image].append(ufo_file_images[valid_image]['words'][idx]['points'])
+
+    return gt_bboxes
+
+def get_pred_bboxes(model, data_dir, valid_images, input_size, batch_size, split='valid'):
+
+    image_fnames, by_sample_bboxes = [], []
+
+    images = []
+    for valid_image in tqdm(valid_images):
+        image_fpath = osp.join(data_dir, 'img/{}/{}'.format(split, valid_image))
+        image_fnames.append(osp.basename(image_fpath))
+
+        images.append(cv2.imread(image_fpath)[:, :, ::-1])
+        if len(images) == batch_size:
+            by_sample_bboxes.extend(detect(model, images, input_size))
+            images = []
+
+    if len(images):
+        by_sample_bboxes.extend(detect(model, images, input_size))
+
+    pred_bboxes = dict()
+    for idx in range(len(image_fnames)):
+        image_fname = image_fnames[idx]
+        sample_bboxes = by_sample_bboxes[idx]
+        pred_bboxes[image_fname] = sample_bboxes
+
+    return pred_bboxes
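Note: `AverageMeter` keeps a running sum and count, so `train_loss.avg` is the true mean over all batches seen since the last `reset()` rather than the last batch's loss. A quick demo of the semantics `train.py` relies on:

```python
# Quick demo of AverageMeter: .avg is the running mean, .val the last value.
from utils import AverageMeter

meter = AverageMeter()
for batch_loss in [1.0, 2.0, 3.0]:
    meter.update(batch_loss)  # n=1 per batch
print(meter.val, meter.avg, meter.count)  # 3.0 2.0 3
```

One caveat: `seed_everything` sets `cudnn.deterministic = True` but also `cudnn.benchmark = True`, which lets cuDNN autotune its algorithms, so seeded runs are repeatable only up to that nondeterminism.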
diff --git a/tools/normalize.ipynb b/tools/normalize.ipynb
new file mode 100644
index 0000000..f550e3c
--- /dev/null
+++ b/tools/normalize.ipynb
@@ -0,0 +1,110 @@
+{
+ "cells": [
+  {
+   "cell_type": "code",
+   "execution_count": 1,
+   "metadata": {},
+   "outputs": [],
+   "source": [
+    "import os\n",
+    "import numpy as np\n",
+    "import glob\n",
+    "from PIL import Image\n",
+    "from tqdm import tqdm"
+   ]
+  },
+  {
+   "cell_type": "code",
+   "execution_count": 2,
+   "metadata": {},
+   "outputs": [],
+   "source": [
+    "data_dir = '/data/ephemeral/home/level2-cv-datacentric-cv-01/data/medical/img/train'\n",
+    "img_path = glob.glob(os.path.join(data_dir,'*.jpg'))\n",
+    "\n",
+    "img_list = []\n",
+    "for m in img_path:\n",
+    "    img = Image.open(m)\n",
+    "\n",
+    "    assert img.mode == 'RGB'\n",
+    "\n",
+    "    img = np.array(img)\n",
+    "    img_list.append(img)"
+   ]
+  },
+  {
+   "cell_type": "code",
+   "execution_count": 3,
+   "metadata": {},
+   "outputs": [],
+   "source": [
+    "def calculate_norm(img_list):\n",
+    "    # per-image mean over the spatial axes (axis=0, 1)\n",
+    "    mean_ = np.array([np.mean(x, axis=(0, 1)) for x in tqdm(img_list, ascii=True)])\n",
+    "    # per-channel (r, g, b) mean\n",
+    "    mean_r = mean_[..., 0].mean() / 255.0\n",
+    "    mean_g = mean_[..., 1].mean() / 255.0\n",
+    "    mean_b = mean_[..., 2].mean() / 255.0\n",
+    "\n",
+    "    # per-image standard deviation over the spatial axes (axis=0, 1)\n",
+    "    std_ = np.array([np.std(x, axis=(0, 1)) for x in tqdm(img_list, ascii=True)])\n",
+    "    # per-channel (r, g, b) standard deviation\n",
+    "    std_r = std_[..., 0].mean() / 255.0\n",
+    "    std_g = std_[..., 1].mean() / 255.0\n",
+    "    std_b = std_[..., 2].mean() / 255.0\n",
+    "\n",
+    "    return (mean_r, mean_g, mean_b), (std_r, std_g, std_b)"
+   ]
+  },
+  {
+   "cell_type": "code",
+   "execution_count": 4,
+   "metadata": {},
+   "outputs": [
+    {
+     "name": "stderr",
+     "output_type": "stream",
+     "text": [
+      "100%|##########| 100/100 [00:20<00:00, 4.97it/s]\n",
+      "100%|##########| 100/100 [00:52<00:00, 1.90it/s]\n"
+     ]
+    },
+    {
+     "data": {
+      "text/plain": [
+       "((0.7760271717131425, 0.7722186515548635, 0.7670997062399484),\n",
+       " (0.17171108542242774, 0.17888224507630185, 0.18678791254805846))"
+      ]
+     },
+     "execution_count": 4,
+     "metadata": {},
+     "output_type": "execute_result"
+    }
+   ],
+   "source": [
+    "calculate_norm(img_list)"
+   ]
+  }
+ ],
+ "metadata": {
+  "kernelspec": {
+   "display_name": "base",
+   "language": "python",
+   "name": "python3"
+  },
+  "language_info": {
+   "codemirror_mode": {
+    "name": "ipython",
+    "version": 3
+   },
+   "file_extension": ".py",
+   "mimetype": "text/x-python",
+   "name": "python",
+   "nbconvert_exporter": "python",
+   "pygments_lexer": "ipython3",
+   "version": "3.10.13"
+  }
+ },
+ "nbformat": 4,
+ "nbformat_minor": 2
+}
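Note: the statistics printed by the last cell are exactly the values hard-coded into `to_pickle.py`'s `'N'` transform; wiring them up looks like this. One assumption worth keeping in mind: averaging per-image means weights every image equally regardless of its resolution, which only matches the true dataset mean when all images have the same size.

```python
# The channel statistics computed above, plugged into Albumentations' Normalize
# exactly as the 'N' entry in to_pickle.py does.
import albumentations as A

normalize = A.Normalize(
    mean=(0.7760271717131425, 0.7722186515548635, 0.7670997062399484),
    std=(0.17171108542242774, 0.17888224507630185, 0.18678791254805846),
    p=1.0,
)
```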
diff --git a/tools/train_test_split.ipynb b/tools/train_test_split.ipynb
new file mode 100644
index 0000000..e726e35
--- /dev/null
+++ b/tools/train_test_split.ipynb
@@ -0,0 +1,84 @@
+{
+ "cells": [
+  {
+   "cell_type": "code",
+   "execution_count": 9,
+   "metadata": {},
+   "outputs": [],
+   "source": [
+    "import json\n",
+    "import os\n",
+    "import os.path as osp\n",
+    "import shutil\n",
+    "\n",
+    "from sklearn.model_selection import train_test_split"
+   ]
+  },
+  {
+   "cell_type": "code",
+   "execution_count": 10,
+   "metadata": {},
+   "outputs": [],
+   "source": [
+    "seed = 137\n",
+    "\n",
+    "root_dir = '../data/medical/'\n",
+    "json_file = osp.join(root_dir, 'ufo/train.json')  # path to the source annotation file\n",
+    "train_split_json = osp.join(root_dir, 'ufo/train_split.json')\n",
+    "valid_split_json = osp.join(root_dir, 'ufo/valid_split.json')\n",
+    "\n",
+    "with open(json_file, 'r') as f:\n",
+    "    data = json.load(f)\n",
+    "\n",
+    "images = list(data['images'].keys())\n",
+    "\n",
+    "# 8 : 2 split\n",
+    "train, val = train_test_split(images, train_size=0.8, shuffle=True, random_state=seed)\n",
+    "\n",
+    "train_images = {img_id: data['images'][img_id] for img_id in train}\n",
+    "train_anns = {'images': train_images}\n",
+    "valid_images = {img_id: data['images'][img_id] for img_id in val}\n",
+    "valid_anns = {'images': valid_images}\n",
+    "\n",
+    "# create json\n",
+    "with open(train_split_json, 'w', encoding='utf-8') as f:\n",
+    "    json.dump(train_anns, f, indent=4, ensure_ascii=False)\n",
+    "with open(valid_split_json, 'w', encoding='utf-8') as f:\n",
+    "    json.dump(valid_anns, f, indent=4, ensure_ascii=False)\n",
+    "\n",
+    "# split folder\n",
+    "train_split_dir = osp.join(root_dir, 'img/train_split/')\n",
+    "valid_split_dir = osp.join(root_dir, 'img/valid_split/')\n",
+    "\n",
+    "os.makedirs(train_split_dir, exist_ok=True)\n",
+    "os.makedirs(valid_split_dir, exist_ok=True)\n",
+    "\n",
+    "for train_image in train_images:\n",
+    "    shutil.copy(osp.join(root_dir, 'img/train/', train_image), osp.join(train_split_dir, train_image))\n",
+    "for valid_image in valid_images:\n",
+    "    shutil.copy(osp.join(root_dir, 'img/train/', valid_image), osp.join(valid_split_dir, valid_image))\n"
+   ]
+  }
+ ],
+ "metadata": {
+  "kernelspec": {
+   "display_name": "base",
+   "language": "python",
+   "name": "python3"
+  },
+  "language_info": {
+   "codemirror_mode": {
+    "name": "ipython",
+    "version": 3
+   },
+   "file_extension": ".py",
+   "mimetype": "text/x-python",
+   "name": "python",
+   "nbconvert_exporter": "python",
+   "pygments_lexer": "ipython3",
+   "version": "3.10.13"
+  }
+ },
+ "nbformat": 4,
+ "nbformat_minor": 2
+}
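Note: taken together, the assumed end-to-end workflow for this change set is: `tools/train_test_split.ipynb` produces `train_split.json` / `valid_split.json` and the split image folders; `tools/normalize.ipynb` supplies the channel statistics; `code/to_pickle.py` materializes augmented EAST targets as numbered `.pkl` files; `code/train.py` trains from them via `PickleDataset` and writes `best.pth` / `latest.pth` into `save_dir`; and `code/inference.py` prefers `best.pth` (falling back to `latest.pth`) and writes `output.csv` next to the checkpoint.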