-
Notifications
You must be signed in to change notification settings - Fork 0
/
Copy pathtrain_clippings.py
830 lines (609 loc) · 37.9 KB
/
train_clippings.py
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
170
171
172
173
174
175
176
177
178
179
180
181
182
183
184
185
186
187
188
189
190
191
192
193
194
195
196
197
198
199
200
201
202
203
204
205
206
207
208
209
210
211
212
213
214
215
216
217
218
219
220
221
222
223
224
225
226
227
228
229
230
231
232
233
234
235
236
237
238
239
240
241
242
243
244
245
246
247
248
249
250
251
252
253
254
255
256
257
258
259
260
261
262
263
264
265
266
267
268
269
270
271
272
273
274
275
276
277
278
279
280
281
282
283
284
285
286
287
288
289
290
291
292
293
294
295
296
297
298
299
300
301
302
303
304
305
306
307
308
309
310
311
312
313
314
315
316
317
318
319
320
321
322
323
324
325
326
327
328
329
330
331
332
333
334
335
336
337
338
339
340
341
342
343
344
345
346
347
348
349
350
351
352
353
354
355
356
357
358
359
360
361
362
363
364
365
366
367
368
369
370
371
372
373
374
375
376
377
378
379
380
381
382
383
384
385
386
387
388
389
390
391
392
393
394
395
396
397
398
399
400
401
402
403
404
405
406
407
408
409
410
411
412
413
414
415
416
417
418
419
420
421
422
423
424
425
426
427
428
429
430
431
432
433
434
435
436
437
438
439
440
441
442
443
444
445
446
447
448
449
450
451
452
453
454
455
456
457
458
459
460
461
462
463
464
465
466
467
468
469
470
471
472
473
474
475
476
477
478
479
480
481
482
483
484
485
486
487
488
489
490
491
492
493
494
495
496
497
498
499
500
501
502
503
504
505
506
507
508
509
510
511
512
513
514
515
516
517
518
519
520
521
522
523
524
525
526
527
528
529
530
531
532
533
534
535
536
537
538
539
540
541
542
543
544
545
546
547
548
549
550
551
552
553
554
555
556
557
558
559
560
561
562
563
564
565
566
567
568
569
570
571
572
573
574
575
576
577
578
579
580
581
582
583
584
585
586
587
588
589
590
591
592
593
594
595
596
597
598
599
600
601
602
603
604
605
606
607
608
609
610
611
612
613
614
615
616
617
618
619
620
621
622
623
624
625
626
627
628
629
630
631
632
633
634
635
636
637
638
639
640
641
642
643
644
645
646
647
648
649
650
651
652
653
654
655
656
657
658
659
660
661
662
663
664
665
666
667
668
669
670
671
672
673
674
675
676
677
678
679
680
681
682
683
684
685
686
687
688
689
690
691
692
693
694
695
696
697
698
699
700
701
702
703
704
705
706
707
708
709
710
711
712
713
714
715
716
717
718
719
720
721
722
723
724
725
726
727
728
729
730
731
732
733
734
735
736
737
738
739
740
741
742
743
744
745
746
747
748
749
750
751
752
753
754
755
756
757
758
759
760
761
762
763
764
765
766
767
768
769
770
771
772
773
774
775
776
777
778
779
780
781
782
783
784
785
786
787
788
789
790
791
792
793
794
795
796
797
798
799
800
801
802
803
804
805
806
807
808
809
810
811
812
813
814
815
816
817
818
819
820
821
822
823
824
825
####Continue pretraining CLIP
###This script continues pretraining CLIP on the Japanese dataset. Our dataset is a DataFrame of image-text pairs.
import pandas as pd
import numpy as np
import wandb
from utils.datasets_utils import *
from tqdm.autonotebook import trange
import faiss
from tqdm import tqdm
import json
import argparse
from sklearn.model_selection import train_test_split
import torch
import japanese_clip as ja_clip
from pytorch_metric_learning import losses
from torch.optim.lr_scheduler import CosineAnnealingWarmRestarts, ReduceLROnPlateau, CosineAnnealingLR, StepLR
from torch import nn
import wandb
import datasets.clippings_data_loaders as data_loaders
import models.encoders as encoders
def convert_to_text(unicode_string):
    """Return *unicode_string* with every non-ASCII character silently removed.

    Characters outside the 7-bit ASCII range (accents, CJK, etc.) are dropped,
    not transliterated: "café" becomes "caf".
    """
    ascii_only = unicode_string.encode("ascii", errors="ignore")
    return ascii_only.decode("ascii")
def get_tk_universe(tk_data_path='/path/to/data/record_linkage_clean_dataset/ocr_json/effocr_tk_title_dup_68352_clean_path.json'):
    """Get the universe of vertical Japanese text crops + OCR text from the TK dataset.

    Args:
        tk_data_path: JSON file containing a list of records where element 0 is the
            OCR text and element 1 is the image crop path.

    Returns:
        DataFrame with columns ['image_path', 'text']. Records with an empty/falsy
        image path are skipped, duplicate (image_path, text) pairs and rows with any
        missing value are dropped, and the index is reset to 0..n-1.
    """
    # JSON text is UTF-8 by spec; don't depend on the platform's locale encoding
    # (which may not be UTF-8) when the payload contains Japanese text.
    with open(tk_data_path, encoding="utf-8") as f:
        tk_data = json.load(f)
    # Only keep records that actually point at an image crop.
    kept_records = [record for record in tk_data if record[1]]
    tk_data_df = pd.DataFrame({
        'image_path': [record[1] for record in kept_records],
        'text': [record[0] for record in kept_records],
    })
    tk_data_df = tk_data_df.drop_duplicates(subset=['image_path', 'text'])
    tk_data_df = tk_data_df.dropna()
    tk_data_df = tk_data_df.reset_index(drop=True)
    return tk_data_df
def get_pr_title_universe(pr_title_path="/path/to/data/record_linkage_clean_dataset/ocr_json/effocr_pr_title_updated.json"):
    """Get the universe of Japanese text crops + OCR text from the PR title dataset.

    Args:
        pr_title_path: JSON file containing a list of records where element 0 is the
            OCR text and element 1 is the image crop path.

    Returns:
        DataFrame with columns ['image_path', 'text'], de-duplicated on both columns,
        rows with any missing value dropped, and a fresh 0..n-1 index.

    NOTE(review): the text goes through convert_to_text, which removes every
    non-ASCII character. If these OCR strings really are Japanese, this leaves
    little or nothing of them; confirm that this is intended.
    """
    # JSON text is UTF-8 by spec; don't depend on the platform's locale encoding.
    with open(pr_title_path, encoding="utf-8") as f:
        pr_title_data = json.load(f)
    text_list = [convert_to_text(record[0]) for record in pr_title_data]
    image_list = [record[1] for record in pr_title_data]
    pr_title_data_df = pd.DataFrame({'image_path': image_list, 'text': text_list})
    pr_title_data_df = pr_title_data_df.drop_duplicates(subset=['image_path', 'text'])
    pr_title_data_df = pr_title_data_df.dropna()
    pr_title_data_df = pr_title_data_df.reset_index(drop=True)
    return pr_title_data_df
def get_pr_partner_universe(pr_partner_path="/path/to/data/record_linkage_clean_dataset/ocr_json/effocr_pr_partner.json"):
    """Get the universe of horizontal Japanese text crops + OCR text from the PR partner dataset.

    These are the horizontal crops that get matched against the vertical TK crops.

    Args:
        pr_partner_path: JSON file containing a list of dicts with at least the keys
            'partner_text' (OCR text) and 'partner_path' (image crop path).

    Returns:
        DataFrame with columns ['image_path', 'text'], de-duplicated on both columns,
        rows with any missing value dropped, and a fresh 0..n-1 index.
    """
    # JSON text is UTF-8 by spec; don't depend on the platform's locale encoding
    # (which may not be UTF-8) when the payload contains Japanese text.
    with open(pr_partner_path, encoding="utf-8") as f:
        pr_partner_data = json.load(f)
    text_list = [subdict['partner_text'] for subdict in pr_partner_data]
    image_list = [subdict['partner_path'] for subdict in pr_partner_data]
    pr_partner_data_df = pd.DataFrame({'image_path': image_list, 'text': text_list})
    pr_partner_data_df = pr_partner_data_df.drop_duplicates(subset=['image_path', 'text'])
    pr_partner_data_df = pr_partner_data_df.dropna()
    pr_partner_data_df = pr_partner_data_df.reset_index(drop=True)
    return pr_partner_data_df
def get_pr_partner_test_data(test_data_path="/path/to/data/mm_dir/PR_TK_matched_ocr_only_test.csv"):
    """Load the PR partner / TK matched test CSV, de-duplicated on (image_path, text).

    Args:
        test_data_path: Path to the test CSV; it must contain 'image_path' and 'text' columns.

    Returns:
        DataFrame with one row per distinct (image_path, text) pair (first occurrence
        kept) and a fresh 0..n-1 index. Missing values are NOT dropped.
    """
    dedup_keys = ['image_path', 'text']
    return (
        pd.read_csv(test_data_path)
        .drop_duplicates(subset=dedup_keys)
        .reset_index(drop=True)
    )
def eval_clip(val_loader,model,tokenizer):
    """Evaluate the CLIP contrastive loss on a validation set and log it to wandb.

    Each batch is scored as N matched pairs: image i belongs with text i, so the
    target for row i of both logit matrices is i. The dataset labels are only used
    to read the batch size.

    Args:
        val_loader: Dataloader yielding (text, image_data, labels, image_path) batches.
        model: The CLIP model. It is put into eval mode and left there.
        tokenizer: The CLIP tokenizer (Japanese tokenizer from the rinna CLIP repo).

    Returns:
        Mean of the per-batch losses (nan, with a numpy warning, if the loader is empty).

    NOTE(review): this reads the module-level globals `device`, `img_loss` and
    `text_loss` rather than taking them as parameters; they must exist before calling.
    """
    print("Evaluating the model - clip loss")
    model.eval()
    loss_list = []
    with torch.no_grad():
        for batch_idx, (text, image_data, labels, image_path) in enumerate(val_loader):
            # Diagonal targets 0..N-1 built directly on the device; the dataset
            # labels themselves never need to leave the host.
            targets = torch.arange(labels.shape[0], device=device)
            image_data = image_data.to(device)
            # text arrives as a tuple of strings; the tokenizer wants a list
            text_features = ja_clip.tokenize(list(text), tokenizer=tokenizer)
            for key in text_features.keys():
                text_features[key] = text_features[key].to(device)
            model_output = model.forward(input_ids=text_features["input_ids"], pixel_values=image_data,
                                         attention_mask=text_features["attention_mask"], position_ids=text_features["position_ids"])
            logits_per_image, logits_per_text = model_output["logits_per_image"], model_output["logits_per_text"]
            loss = (img_loss(logits_per_image, targets) + text_loss(logits_per_text, targets)) / 2
            loss_list.append(loss.item())
    mean_loss = np.mean(loss_list)
    wandb.log({"val_loss": mean_loss})
    return mean_loss
def prep_labelled_data(data_path="/path/to/data/mm_dir/PR_TK_matched_ocr_only_train.csv",
                       val_data="/path/to/data/mm_dir/PR_TK_matched_ocr_only_val.csv"):
    """Load the labelled PR-TK train/val CSVs (preprocessed outside this script).

    Each CSV is de-duplicated on (image_path, text) independently. Missing values
    are not dropped and indices are not reset.

    Args:
        data_path: Path to the training CSV.
        val_data: Path to the validation CSV (despite the name, this is a path).

    Returns:
        (train_df, val_df, combined_df), where combined_df is train and val stacked
        row-wise with their original index labels kept.
    """
    key_cols = ['image_path', 'text']

    train_df = pd.read_csv(data_path)
    print(f"original size of data: {len(train_df)}")
    train_df = train_df.drop_duplicates(subset=key_cols)

    val_df = pd.read_csv(val_data)
    print(f"original size of val data: {len(val_df)}")
    val_df = val_df.drop_duplicates(subset=key_cols)

    combined_df = pd.concat([train_df, val_df], axis=0)
    return train_df, val_df, combined_df
def prep_synth_data(data_path="/path/to/datamultimodal_data/multimodal_synth_train_data.csv", random_state=None):
    """Load the synthetic image-text data (preprocessed outside this script) and split it by label.

    Rows are de-duplicated on (image_path, text) and rows with any missing value are
    dropped. Labels are then negated so synthetic ids can be told apart from real ones.
    Finally 80% of the distinct labels (rounded down) are drawn for train, and every
    remaining label goes to val, so no label appears in both splits.

    Args:
        data_path: Path to the synthetic CSV; it needs 'image_path', 'text' and a numeric 'label' column.
        random_state: Optional seed for the label split. When None, the global
            ``np.random`` state is used, as before.

    Returns:
        (train_data, val_data) DataFrames, sharing the original row index labels.
    """
    data = pd.read_csv(data_path)
    print("original size of data: {}".format(len(data)))
    data = data.drop_duplicates(subset=['image_path', 'text'])
    data = data.dropna()
    print("post processing size of data: {}".format(len(data)))
    # Negate labels so synthetic ids are distinguishable from real (positive) ones.
    # NOTE(review): a synthetic label of 0 stays 0 and would still collide; confirm
    # synthetic labels start at 1.
    data['label'] = 0 - data['label']
    all_unique_labels = data['label'].unique()
    rng = np.random if random_state is None else np.random.RandomState(random_state)
    train_labels = rng.choice(all_unique_labels, int(len(all_unique_labels) * 0.8), replace=False)
    # Complementary boolean mask: every surviving row's label is in all_unique_labels,
    # so "not a train label" is exactly "a val label". The old list-membership scan
    # against an ndarray was O(n^2).
    is_train = data['label'].isin(train_labels)
    train_data = data[is_train]
    val_data = data[~is_train]
    return train_data, val_data
###For val data, we don't want to use the paths in PR - just keep them for TK!
###i.e. drop rows where image_path does not contain tk_title_img_1025_v2 (NOTE: prep_unlabelled_data below does not apply this filter)
def prep_unlabelled_data(tk_path="/path/to/data/ocr_dataframes_multimodal/EffOCR/tk_image_ocr.csv",
                         pr_title_path='/path/to/data/ocr_dataframes_multimodal/EffOCR/pr_image_ocr.csv',
                         pr_partner_path='/path/to/data/ocr_dataframes_multimodal/EffOCR/partners_image_ocr.csv'):
    """Merge the unlabelled TK / PR title / PR partner OCR CSVs and split them 80/20.

    The CSVs were preprocessed outside this script. They are stacked in the order
    TK, PR title, PR partner (original indices kept), and the combined row count is
    printed. Every row gets a constant label of 1, then the frame is de-duplicated
    on (image_path, text) and rows with missing values are dropped. The result is
    split at random with a fixed seed (42), so the split is reproducible.

    Returns:
        (train_data, val_data) DataFrames, about 80% / 20% of the cleaned rows.
    """
    source_paths = (tk_path, pr_title_path, pr_partner_path)
    merged = pd.concat([pd.read_csv(path) for path in source_paths])
    print(len(merged))
    # Dummy constant label: not needed by the objective, but keeps the frame
    # compatible with the dataloader, which expects a 'label' column.
    merged['label'] = 1
    cleaned = merged.drop_duplicates(subset=['image_path', 'text']).dropna()
    train_data, val_data = train_test_split(cleaned, test_size=0.2, random_state=42)
    return train_data, val_data
def pretrain_clip(train_loader,model,device,img_loss,text_loss,epoch,optimizer,tokenizer,scheduler=None,epochviz=None):
    """Run one epoch of standard CLIP contrastive pretraining. Logging is integrated with wandb.

    Each batch is treated as N matched pairs: image i belongs with text i, so the
    target for row i of both logit matrices is i. The dataset labels only supply
    the batch size.

    Args:
        train_loader: Dataloader yielding (text, image_data, labels, image_path) batches.
        model: CLIP model, called with input_ids / pixel_values / attention_mask /
            position_ids. Its output must contain "logits_per_image" and "logits_per_text".
        device: Device to run the model on.
        img_loss: Loss applied to logits_per_image against the diagonal targets.
        text_loss: Loss applied to logits_per_text against the diagonal targets.
        epoch: Current epoch (used only for logging and file names).
        optimizer: Optimizer for `model`.
        tokenizer: Tokenizer passed to ja_clip.tokenize.
        scheduler: Optional LR scheduler, stepped once per batch with no metric
            (so ReduceLROnPlateau is not supported here).
        epochviz: Optional directory. Every 50 batches, up to 30 de-normalised
            images from the current batch are saved there.

    Returns:
        Mean of the per-batch losses over the epoch (nan if the loader is empty).
    """
    print("Pretraining CLIP")
    model.train()
    loss_list = []
    for batch_idx, (text, image_data, labels, image_path) in tqdm(enumerate(train_loader)):
        image_data = image_data.to(device)
        # text arrives as a tuple of strings; the tokenizer wants a list
        text_features = ja_clip.tokenize(list(text), tokenizer=tokenizer)
        for key in text_features.keys():
            text_features[key] = text_features[key].to(device)
        optimizer.zero_grad()
        model_output = model.forward(input_ids=text_features["input_ids"], pixel_values=image_data,
                                     attention_mask=text_features["attention_mask"], position_ids=text_features["position_ids"])
        logits_per_image, logits_per_text = model_output["logits_per_image"], model_output["logits_per_text"]
        del model_output
        # The CLIP objective only needs "pair i matches pair i": targets are 0..N-1,
        # built directly on the device (the dataset labels are not used).
        targets = torch.arange(labels.shape[0], device=device)
        loss = (img_loss(logits_per_image, targets) + text_loss(logits_per_text, targets)) / 2
        loss.backward()
        optimizer.step()
        if scheduler is not None:
            scheduler.step()
            if batch_idx % 50 == 0:
                # get_last_lr() reports the LR the scheduler actually set. get_lr() is an
                # internal helper that may not reflect the current LR outside step().
                current_lr = scheduler.get_last_lr()[0]
                print("Current LR: {}".format(current_lr))
                wandb.log({"train/lr": current_lr})
        loss_value = loss.item()
        wandb.log({"train/loss": loss_value})
        if batch_idx % 50 == 0:
            print("Epoch {} Iteration {}: Loss = {}".format(
                str(epoch).zfill(3), str(batch_idx).zfill(4), loss))
            if epochviz is not None:
                # Bound by the batch size so a short final batch can't raise IndexError.
                # T, INV_NORMALIZE and os are not defined in this block; presumably they
                # come from the utils.datasets_utils star import (verify).
                for i in range(min(30, image_data.shape[0])):
                    image = T.ToPILImage()(INV_NORMALIZE(image_data[i].cpu()))
                    image.save(os.path.join(epochviz, f"train_sample_{epoch}_{i}.png"))
        loss_list.append(loss_value)
    mean_epoch_loss = np.mean(loss_list)
    return mean_epoch_loss
def train_bienc_clip(train_loader,clip_model,device,loss_func,epoch,clip_optimizer,tokenizer,clip_scheduler=None,epochviz=None,mlp_model=None,mlp_optimizer=None,mlp_scheduler=None,freeze_clip=False):
    """Train pooled CLIP image+text embeddings contrastively for one epoch (the CLIPPINGS model).

    Image and text embeddings are pooled into a single L2-normalised vector, and
    `loss_func(embeddings, labels)` is minimised so that pairs sharing a label land
    close together. Logging is integrated with wandb.

    Args:
        train_loader: Dataloader yielding (text, image_data, labels, image_path) batches.
        clip_model: CLIP model. Its output must contain "image_embeds" and "text_embeds".
        device: Device to run the model on.
        loss_func: Metric-learning loss called as loss_func(embeddings, labels).
        epoch: Current epoch (used only for logging and file names).
        clip_optimizer: Optimizer for CLIP.
        tokenizer: CLIP tokenizer (Japanese tokenizer from the japanese clip repo).
        clip_scheduler: Optional scheduler for CLIP, stepped once per batch.
        epochviz: Optional directory. Every 50 batches, up to 10 de-normalised
            images from the current batch are saved there.
        mlp_model: Optional MLP used when args.pooling_type == "mlp".
        mlp_optimizer: Optional optimizer for the MLP.
        mlp_scheduler: Optional scheduler for the MLP, stepped once per batch.
        freeze_clip: Only honoured when an mlp_model is given. CLIP then runs under
            no_grad and clip_optimizer.step() is skipped, so optimizer-side updates
            (e.g. decoupled weight decay, momentum) cannot move the frozen weights.

    Raises:
        ValueError: if args.pooling_type is unsupported, or if it is "mlp" and no mlp_model was given.

    NOTE(review): pooling is controlled by the module-level global `args`
    (args.pooling_type, args.im_wt), not by a parameter.
    """
    clip_model.train()
    if mlp_model is not None:
        mlp_model.train()
    clip_frozen = freeze_clip and mlp_model is not None
    for batch_idx, (text, image_data, labels, image_path) in tqdm(enumerate(train_loader)):
        labels = labels.to(device)
        image_data = image_data.to(device)
        # text arrives as a tuple of strings; the tokenizer wants a list
        text_features = ja_clip.tokenize(list(text), tokenizer=tokenizer)
        for key in text_features.keys():
            text_features[key] = text_features[key].to(device)
        clip_optimizer.zero_grad()
        # Use the same guard for zero_grad() and step() so they can never disagree.
        if mlp_optimizer is not None:
            mlp_optimizer.zero_grad()
        # set_grad_enabled(False) behaves like no_grad() when CLIP is frozen.
        with torch.set_grad_enabled(not clip_frozen):
            model_output = clip_model.forward(input_ids=text_features["input_ids"], pixel_values=image_data,
                                              attention_mask=text_features["attention_mask"], position_ids=text_features["position_ids"])
        image_embeds, text_embeds = model_output["image_embeds"], model_output["text_embeds"]
        del model_output
        if args.pooling_type == "mean":
            # Weighted average; args.im_wt is the weight given to the image embedding.
            final_embeds = args.im_wt * image_embeds + (1 - args.im_wt) * text_embeds
        elif args.pooling_type == "mlp":
            if mlp_model is None:
                raise ValueError("Pooling type 'mlp' requires an mlp_model")
            # Concatenate image and text embeddings and let the MLP combine them.
            final_embeds = mlp_model.forward(torch.cat([image_embeds, text_embeds], dim=1))
        else:
            raise ValueError("Pooling type not supported")
        final_embeds = torch.nn.functional.normalize(final_embeds, p=2, dim=1)
        loss = loss_func(final_embeds, labels)
        loss.backward()
        if not clip_frozen:
            clip_optimizer.step()
        if mlp_optimizer is not None:
            mlp_optimizer.step()
        # Schedulers are stepped per batch with no metric (ReduceLROnPlateau is not supported).
        if clip_scheduler is not None:
            clip_scheduler.step()
            if batch_idx % 50 == 0:
                # get_last_lr() is the LR actually set; get_lr() may be wrong outside step().
                clip_lr = clip_scheduler.get_last_lr()[0]
                print("Current LR: {}".format(clip_lr))
                wandb.log({"train/clip_lr": clip_lr})
        if mlp_scheduler is not None:
            mlp_scheduler.step()
            if batch_idx % 50 == 0:
                mlp_lr = mlp_scheduler.get_last_lr()[0]
                print("Current LR: {}".format(mlp_lr))
                wandb.log({"train/mlp_lr": mlp_lr})
        wandb.log({"train/loss": loss.item()})
        if batch_idx % 50 == 0:
            print("Epoch {} Iteration {}: Loss = {}".format(
                str(epoch).zfill(3), str(batch_idx).zfill(4), loss))
            if epochviz is not None:
                # Bound by the batch size so a short final batch can't raise IndexError.
                # T, INV_NORMALIZE and os are not defined in this block; presumably they
                # come from the utils.datasets_utils star import (verify).
                for i in range(min(10, image_data.shape[0])):
                    image = T.ToPILImage()(INV_NORMALIZE(image_data[i].cpu()))
                    image.save(os.path.join(epochviz, f"train_sample_{epoch}_{i}.png"))
def get_image_text_embeddings(data_loader,clip_model,mlp_model,device):
    """Compute pooled, L2-normalised image+text embeddings for every pair in a dataloader.

    The models are put into eval mode (and left there), and the whole pass runs
    under torch.no_grad(), so neither CLIP nor the optional MLP builds an autograd graph.

    Args:
        data_loader: Dataloader yielding (text, image_data, labels, image_path) batches.
        clip_model: CLIP model. Its output must contain "image_embeds" and "text_embeds".
        mlp_model: MLP used when args.pooling_type == "mlp" (may be None for "mean").
        device: Device to run on. Labels are moved there too.

    Returns:
        (all_embeddings, all_labels, all_text, all_paths):
        - all_embeddings: the per-batch embeddings concatenated row-wise.
        - all_labels: the per-batch labels concatenated, on `device`.
        - all_text: a single list of all texts, in loader order.
        - all_paths: image paths joined with `+`, so they keep whatever sequence
          type the collate function produced.

    Raises:
        ValueError: if args.pooling_type is unsupported, if it is "mlp" and no
            mlp_model was given, or if data_loader yields no batches.

    NOTE(review): reads the module-level globals `tokenizer` and `args` rather than parameters.
    """
    clip_model.eval()
    if mlp_model is not None:
        mlp_model.eval()
    # Collect per-batch tensors and concatenate once at the end; calling torch.cat
    # on every batch would re-copy the growing result each time (quadratic).
    embedding_chunks = []
    label_chunks = []
    all_text = []
    all_paths = None
    with torch.no_grad():
        for batch_idx, (text, image_data, labels, image_path) in tqdm(enumerate(data_loader)):
            labels = labels.to(device)
            image_data = image_data.to(device)
            # text arrives as a tuple of strings; the tokenizer wants a list
            text = list(text)
            text_features = ja_clip.tokenize(text, tokenizer=tokenizer)
            for key in text_features.keys():
                text_features[key] = text_features[key].to(device)
            model_output = clip_model.forward(input_ids=text_features["input_ids"], pixel_values=image_data,
                                              attention_mask=text_features["attention_mask"], position_ids=text_features["position_ids"])
            image_embeds, text_embeds = model_output["image_embeds"], model_output["text_embeds"]
            if args.pooling_type == "mean":
                # Weighted average; args.im_wt is the weight given to the image embedding.
                final_embeds = args.im_wt * image_embeds + (1 - args.im_wt) * text_embeds
            elif args.pooling_type == "mlp":
                if mlp_model is None:
                    raise ValueError("Pooling type 'mlp' requires an mlp_model")
                # Concatenate image and text embeddings and let the MLP combine them.
                final_embeds = mlp_model.forward(torch.cat([image_embeds, text_embeds], dim=1))
            else:
                raise ValueError("Pooling type not supported")
            # Same normalisation as training; a zero vector yields zeros instead of NaN.
            final_embeds = torch.nn.functional.normalize(final_embeds, p=2, dim=1)
            embedding_chunks.append(final_embeds)
            label_chunks.append(labels)
            all_text.extend(text)
            all_paths = image_path if all_paths is None else all_paths + image_path
    if not embedding_chunks:
        raise ValueError("data_loader yielded no batches")
    all_embeddings = torch.cat(embedding_chunks, dim=0)
    all_labels = torch.cat(label_chunks, dim=0)
    return all_embeddings, all_labels, all_text, all_paths
def tester_bienc_clip(test_loader, ref_loader, clip_model, mlp_model=None, split='val', log=True, *, device=None, num_examples=10):
    """Top-1 nearest-neighbour accuracy of a test set against a reference set (CLIPPINGS record linkage).

    This is the base tester for the CLIPPINGS record-linkage application: horizontal Japanese
    text is matched against vertical Japanese text. Both sets are embedded with
    ``get_image_text_embeddings``. Each test item is matched to its single nearest reference
    item with an exact inner-product faiss index. A hit means the neighbour's label equals the
    test item's label.

    Args:
        test_loader: loader of the queries.
        ref_loader: loader of the reference (gallery) items that are indexed.
        clip_model: CLIP model used for embedding.
        mlp_model: optional MLP pooling head (None for "mean" pooling).
        split: prefix of the wandb metric key (``f"{split}/precision_1"``).
        log: whether to log the accuracy to wandb.
        device: device for embedding; defaults to the module-level ``device``.
        num_examples: how many sample predictions to print (capped at the test-set size).

    Returns:
        Top-1 accuracy as a float in [0, 1].
    """
    if device is None:
        device = globals()["device"]
    print("Testing using pooled embeddings")
    test_embeddings, test_labels, test_text, _ = get_image_text_embeddings(test_loader, clip_model, mlp_model, device)
    print("total test embeddings: ", test_embeddings.shape)
    ref_embeddings, ref_labels, ref_text, _ = get_image_text_embeddings(ref_loader, clip_model, mlp_model, device)
    print("total ref embeddings: ", ref_embeddings.shape)

    # Exact inner-product index over the reference embeddings.
    index = faiss.IndexFlatIP(ref_embeddings.shape[1])
    index.add(ref_embeddings.cpu().numpy())
    # Nearest reference neighbour (k=1) for every test embedding.
    _, I = index.search(test_embeddings.cpu().numpy(), 1)

    # Vectorised label comparison instead of a per-element Python loop over device tensors.
    nn_idx = torch.from_numpy(I[:, 0]).long()
    test_labels_cpu = test_labels.cpu()
    nn_labels = ref_labels.cpu()[nn_idx]
    acc = (test_labels_cpu == nn_labels).sum().item() / len(test_labels_cpu)
    print("CUSTOM ACCURACY: ", acc)
    if log:
        wandb.log({f"{split}/precision_1": acc})

    # Print a few sample predictions (text). Guarded so small test sets don't raise IndexError.
    for i in range(min(num_examples, len(test_text))):
        print(f"Text: {test_text[i]}")
        print(f"Nearest neighbour: {ref_text[nn_idx[i]]}")
        print(f"Nearest neighbour label: {nn_labels[i]}")
        print(f"Test label: {test_labels_cpu[i]}")
        print("")
    print(acc)
    return acc
def _remap_image_paths(df):
    """Replace ``/data01/`` with the mounted volume path in every ``image_path`` value.

    The frame is modified in place and also returned for convenience.
    """
    df['image_path'] = df['image_path'].apply(lambda x: x.replace("/data01/", "/122a7683-fa4b-45dd-9f13-b18cc4f4a187/"))
    return df


def _build_ref_subset(ref_data, n_unlabelled=None):
    """Build a reference set from all labelled rows plus some unlabelled (label == -1) rows.

    The more unlabelled distractors are included, the harder the retrieval task.

    Args:
        ref_data: frame with at least ``image_path``, ``text`` and ``label`` columns.
        n_unlabelled: number of unlabelled rows to sample at random; None keeps all of them.

    Returns:
        The unlabelled rows followed by the labelled rows, de-duplicated on (image_path, text).
    """
    unlabelled = ref_data[ref_data['label'] == -1]
    if n_unlabelled is not None:
        unlabelled = unlabelled.sample(n_unlabelled)
    labelled = ref_data[ref_data['label'] != -1]
    # pd.concat replaces DataFrame.append, which was removed in pandas 2.0.
    return pd.concat([unlabelled, labelled]).drop_duplicates(subset=['image_path', 'text'])


if __name__ == "__main__":
    ## Parse arguments
    parser = argparse.ArgumentParser()
    parser.add_argument("--epochs", type=int, default=200)  # NOTE(review): the epoch loops below use --num_epochs
    parser.add_argument("--clip_lr", type=float, default=5e-7)
    parser.add_argument("--mlp_lr", type=float, default=5e-5)
    parser.add_argument("--clip_weight_decay", type=float, default=0.001)
    parser.add_argument("--mlp_weight_decay", type=float, default=0.001)
    parser.add_argument("--batch_size", type=int, default=153)
    parser.add_argument("--m", type=int, default=1)
    parser.add_argument("--k", type=int, default=3)
    parser.add_argument("--train_data_type", type=str, default="labelled")
    parser.add_argument("--wandb_name", type=str, default="clip_pretrain_labelled_m1")
    parser.add_argument("--training_type", type=str, default="pretrain")
    parser.add_argument("--supcon_temp", type=float, default=0.1)
    parser.add_argument("--im_wt", type=float, default=0.3)
    parser.add_argument("--pooling_type", type=str, default="mean", help="mean or mlp. MLP pooling is still in beta")
    parser.add_argument("--freeze_clip_epochs", type=int, default=20)
    parser.add_argument("--mlp_layers", type=int, default=3)
    parser.add_argument("--augmented_crops", action="store_true")
    parser.add_argument("--train_hardneg", action="store_true")
    parser.add_argument("--checkpoint_path", type=str, default=None)
    parser.add_argument("--num_epochs", type=int, default=200)
    parser.add_argument("--start_epoch", type=int, default=0)
    args = parser.parse_args()

    # Validate the mode combination up front rather than after all data has been loaded.
    if args.training_type not in ("pretrain", "train_bienc"):
        raise ValueError("Training type not recognised")
    if args.train_data_type not in ("labelled", "unlabelled", "synth", "synth_unlabelled"):
        raise ValueError("--train_data_type must be one of: labelled, unlabelled, synth, synth_unlabelled")
    if args.pooling_type not in ("mean", "mlp"):
        raise ValueError("--pooling_type must be either mean or mlp")
    if args.training_type == "train_bienc" and args.train_data_type not in ("labelled", "synth"):
        # The synthetic reference loaders used for validation are only built for synth data.
        raise ValueError("train_bienc is only supported with --train_data_type labelled or synth")

    device = torch.device("cuda") if torch.cuda.is_available() else torch.device("cpu")
    clip_model = ja_clip.clip.CLIPModel.from_pretrained("rinna/japanese-clip-vit-b-16", cache_dir="/tmp/japanese_clip")
    ### Load checkpoint
    if args.checkpoint_path is not None:
        clip_model.load_state_dict(torch.load(args.checkpoint_path, map_location=device))
    clip_model.to(device)
    if args.pooling_type == "mlp":
        # Input is the concatenated (image, text) embeddings, presumably 512-d each.
        mlp_model = encoders.MLP(2 * 512, 1024, 512, args.mlp_layers, 0.1)
        mlp_model.to(device)
    else:
        mlp_model = None
    tokenizer = ja_clip.load_tokenizer()
    ### DataParallel is not supported yet

    ### Training / validation data
    full_labelled_data = None
    if args.train_data_type == "labelled":
        # The data for supervised training
        train_data, val_data, full_labelled_data = prep_labelled_data()
    elif args.train_data_type == "unlabelled":
        # The in-domain data for self-supervised training
        train_data, val_data = prep_unlabelled_data()
    elif args.train_data_type == "synth":
        # The out-of-domain data for self-supervised training
        train_data, val_data = prep_synth_data()
    else:
        # synth_unlabelled: out-of-domain synthetic data + in-domain unlabelled data
        train_data_synth, val_data_synth = prep_synth_data()
        train_data_unlabelled, val_data_unlabelled = prep_unlabelled_data()
        print("train_data_unlabelled.shape: ", train_data_unlabelled.shape)
        print("train_data_synth.shape: ", train_data_synth.shape)
        ## Keep one synthetic example per label
        train_data_synth = train_data_synth.drop_duplicates(subset="label")
        val_data_synth = val_data_synth.drop_duplicates(subset="label")
        train_data = pd.concat([train_data_synth, train_data_unlabelled])
        val_data = pd.concat([val_data_synth, val_data_unlabelled])
        del train_data_synth, val_data_synth, train_data_unlabelled, val_data_unlabelled
    ### The full labelled data is needed for reference labels whatever the training data is
    if full_labelled_data is None:
        _, _, full_labelled_data = prep_labelled_data()

    ### Remove any unnamed (index-artifact) columns
    train_data = train_data.loc[:, ~train_data.columns.str.contains('^Unnamed')]
    val_data = val_data.loc[:, ~val_data.columns.str.contains('^Unnamed')]
    train_data = _remap_image_paths(train_data)
    val_data = _remap_image_paths(val_data)
    full_labelled_data = _remap_image_paths(full_labelled_data)

    ### When pretraining, shuffle (seeded) and then drop duplicate texts from the train data
    if args.training_type == "pretrain":
        print("Lenth of train data before dropping duplicates: ", len(train_data))
        train_data = train_data.sample(frac=1, random_state=42)
        train_data = train_data.drop_duplicates(subset=['text'], keep='first')
        print("Lenth of train data after dropping duplicates: ", len(train_data))

    ### Test data
    test_data = get_pr_partner_test_data()

    ### Reference data: the TK universe, labelled where its image appears in the labelled data
    tk_universe_df = get_tk_universe()
    ref_data = tk_universe_df[['image_path', 'text']]
    ref_data = pd.merge(ref_data, full_labelled_data[['image_path', 'label']], on='image_path', how='left')
    ### Unmatched (unlabelled) rows get label -1
    ref_data['label'] = ref_data['label'].apply(lambda x: -1 if pd.isna(x) else x)
    ### Progressively harder reference sets: all labelled rows plus more and more unlabelled distractors
    small_ref_data = _build_ref_subset(ref_data, 1)
    med_ref_data = _build_ref_subset(ref_data, 1000)
    large_ref_data = _build_ref_subset(ref_data, 7000)
    huge_ref_data = _build_ref_subset(ref_data, 15000)
    all_tk_ref_data = _build_ref_subset(ref_data)

    ### Datasets
    if args.augmented_crops:
        # Augmented crops need a different (random document) transform
        train_image_transform = create_clip_random_doc_transform()
    else:
        train_image_transform = CLIP_BASE_TRANSFORM
    if args.train_hardneg:
        print("Setting up dataset with hardnegatives")
        dedup_train_data = train_data.drop_duplicates(subset=['label'], keep='first')
        print("Total number of unique labels in train data: ", len(dedup_train_data))
        k_hardneg_df = data_loaders.make_hard_neg_df(dedup_train_data, k=args.k, clip_model=clip_model, mlp_model=mlp_model, device=device, tokenizer=tokenizer, pooling_type=args.pooling_type, im_wt=args.im_wt)
        # NOTE(review): batch size 126 is hard-coded here and in the hard-negative loader below.
        train_dataset = data_loaders.TextImageDatasetWithHardNegs(train_data, k_hardneg_df, img_transform=train_image_transform, text_transform=None, batch_size=126, k=args.k, m=args.m)
        print("Done setting up dataset with hardnegatives")
    else:
        train_dataset = data_loaders.TextImageDataset(train_data, img_transform=train_image_transform)
    val_dataset = data_loaders.TextImageDataset(val_data, img_transform=CLIP_BASE_TRANSFORM)
    small_ref_dataset = data_loaders.TextImageDataset(small_ref_data, img_transform=CLIP_BASE_TRANSFORM)
    med_ref_dataset = data_loaders.TextImageDataset(med_ref_data, img_transform=CLIP_BASE_TRANSFORM)
    large_ref_dataset = data_loaders.TextImageDataset(large_ref_data, img_transform=CLIP_BASE_TRANSFORM)
    huge_ref_dataset = data_loaders.TextImageDataset(huge_ref_data, img_transform=CLIP_BASE_TRANSFORM)
    all_tk_ref_dataset = data_loaders.TextImageDataset(all_tk_ref_data, img_transform=CLIP_BASE_TRANSFORM)
    test_dataset = data_loaders.TextImageDataset(test_data, img_transform=CLIP_BASE_TRANSFORM)

    ### Data loaders
    if args.train_hardneg:
        # The hard-negative dataset is given batch_size=126 itself; iterate it in order.
        hardneg_workers = 16 if args.train_data_type == "synth_unlabelled" else 4
        train_loader = torch.utils.data.DataLoader(train_dataset, batch_size=126, shuffle=False, num_workers=hardneg_workers)
    elif args.train_data_type in ("labelled", "synth"):
        train_loader = torch.utils.data.DataLoader(train_dataset, batch_size=args.batch_size, sampler=data_loaders.NoReplacementMPerClassSampler(train_dataset, m=args.m, batch_size=args.batch_size, num_passes=1))
    else:
        train_loader = torch.utils.data.DataLoader(train_dataset, batch_size=args.batch_size, shuffle=True)
    val_loader = torch.utils.data.DataLoader(val_dataset, batch_size=args.batch_size, shuffle=False)
    small_ref_loader = torch.utils.data.DataLoader(small_ref_dataset, batch_size=args.batch_size, shuffle=False)
    med_ref_loader = torch.utils.data.DataLoader(med_ref_dataset, batch_size=args.batch_size, shuffle=False)
    large_ref_loader = torch.utils.data.DataLoader(large_ref_dataset, batch_size=args.batch_size, shuffle=False)
    huge_ref_loader = torch.utils.data.DataLoader(huge_ref_dataset, batch_size=args.batch_size, shuffle=False)
    all_tk_ref_loader = torch.utils.data.DataLoader(all_tk_ref_dataset, batch_size=args.batch_size, shuffle=False)
    test_loader = torch.utils.data.DataLoader(test_dataset, batch_size=args.batch_size, shuffle=False)

    ### With synthetic data, build reference sets (one row per label) from the held-out data
    if args.train_data_type == "synth":
        synth_ref_data = val_data.sample(frac=1).drop_duplicates(subset=['label'])
        synth_ref_dataset = data_loaders.TextImageDataset(synth_ref_data, img_transform=CLIP_BASE_TRANSFORM)
        synth_ref_dataloader = torch.utils.data.DataLoader(synth_ref_dataset, batch_size=args.batch_size, shuffle=False)
        # Larger reference set: up to 20k training rows plus the validation rows, shuffled and
        # de-duplicated on label (previously this was overwritten with synth_ref_data).
        large_synth_ref_data = pd.concat([train_data.sample(min(20000, len(train_data))), val_data])
        large_synth_ref_data = large_synth_ref_data.sample(frac=1).drop_duplicates(subset=['label'])
        large_synth_ref_dataset = data_loaders.TextImageDataset(large_synth_ref_data, img_transform=CLIP_BASE_TRANSFORM)
        large_synth_ref_dataloader = torch.utils.data.DataLoader(large_synth_ref_dataset, batch_size=args.batch_size, shuffle=False)
        val_data = val_data.drop_duplicates(subset=['label'])
        val_dataset = data_loaders.TextImageDataset(val_data, img_transform=CLIP_BASE_TRANSFORM)
        val_loader = torch.utils.data.DataLoader(val_dataset, batch_size=args.batch_size, shuffle=False)

    ### CLIP loss functions
    img_loss = nn.CrossEntropyLoss()
    text_loss = nn.CrossEntropyLoss()
    ### Optimizers and schedulers for CLIP and (optionally) the MLP head
    clip_optimizer = torch.optim.AdamW(clip_model.parameters(), lr=args.clip_lr, weight_decay=args.clip_weight_decay, betas=(0.9, 0.98), eps=1e-06)
    clip_scheduler = CosineAnnealingWarmRestarts(clip_optimizer, 10, 2)
    if args.pooling_type == "mlp":
        mlp_optimizer = torch.optim.AdamW(mlp_model.parameters(), lr=args.mlp_lr, weight_decay=args.mlp_weight_decay, betas=(0.9, 0.98), eps=1e-06)
        mlp_scheduler = CosineAnnealingWarmRestarts(mlp_optimizer, 10, 2)
    else:
        mlp_optimizer = None
        mlp_scheduler = None

    ### WANDB is used for logging
    wandb.init(project="multimodal_record_linkage", name=args.wandb_name)
    model_dir = "/path/to/save/modelmultimodal_record_linkage/models/"
    # NOTE(review): str(args.im_wt)[2] is only the first decimal digit (e.g. "3" for 0.3).
    im_wt_tag = str(args.im_wt)[2]

    if args.training_type == "pretrain":
        best_path = os.path.join(model_dir, args.wandb_name + ".pt")
        # The zero-shot validation loss is the baseline a checkpoint has to beat.
        zero_shot_loss = eval_clip(val_loader, clip_model, tokenizer)
        for epoch in range(args.start_epoch, args.num_epochs + args.start_epoch):
            train_loss = pretrain_clip(train_loader, clip_model, device, img_loss, text_loss, epoch, clip_optimizer, tokenizer, clip_scheduler, epochviz=None)
            val_loss = eval_clip(val_loader, clip_model, tokenizer)
            print("Val loss: {}".format(val_loss))
            print("Train loss: {}".format(train_loss))
            if val_loss < zero_shot_loss:
                zero_shot_loss = val_loss
                torch.save(clip_model.state_dict(), best_path)
                print("Model saved at epoch {}".format(epoch))
                print("Path of the saved model: {}".format(best_path))
            if val_loss < 0.1:
                ### Look at final acc on tk
                final_acc = tester_bienc_clip(test_loader, small_ref_loader, clip_model, mlp_model, split="test", log=True)
                print("Final acc on test set: {}".format(final_acc))
                # Per-epoch snapshots are only written once the validation loss is very low.
                epoch_path = os.path.join(model_dir, "epoch_" + str(epoch) + args.wandb_name + ".pt")
                torch.save(clip_model.state_dict(), epoch_path)
                print("Path of the saved model: {}".format(epoch_path))

    elif args.training_type == "train_bienc" and args.train_data_type == "labelled":
        clip_path = os.path.join(model_dir, "clip_imwt_" + im_wt_tag + args.wandb_name + ".pt")
        mlp_path = os.path.join(model_dir, "mlp_imwt_" + im_wt_tag + args.wandb_name + ".pt")
        best_acc = tester_bienc_clip(val_loader, huge_ref_loader, clip_model, mlp_model, split="val_huge", log=True)
        loss_func = losses.SupConLoss(temperature=args.supcon_temp)
        for epoch in range(args.start_epoch, args.num_epochs + args.start_epoch):
            # The CLIP backbone is only frozen during warm-up when there is an MLP head to train.
            freeze_clip = epoch <= args.freeze_clip_epochs and args.pooling_type == "mlp"
            epoch_loss = train_bienc_clip(train_loader, clip_model, device, loss_func, epoch, clip_optimizer, clip_scheduler=clip_scheduler, epochviz="/path/to/japan_viz/epoch_viz", tokenizer=tokenizer, mlp_model=mlp_model, mlp_optimizer=mlp_optimizer, mlp_scheduler=mlp_scheduler, freeze_clip=freeze_clip)
            # NOTE(review): best_acc compares accuracies against the large (epoch <= 15) and the
            # harder huge (epoch > 15) reference sets, so later epochs may rarely beat it.
            if epoch > 15:
                acc = tester_bienc_clip(val_loader, huge_ref_loader, clip_model, mlp_model, split="val_huge", log=True)
            else:
                acc = tester_bienc_clip(val_loader, large_ref_loader, clip_model, mlp_model, split="val_large", log=True)
            if acc > best_acc:
                best_acc = acc
                torch.save(clip_model.state_dict(), clip_path)
                print("Model saved at epoch {}".format(epoch))
                print("Path of the saved model: {}".format(clip_path))
                if args.pooling_type == "mlp":
                    torch.save(mlp_model.state_dict(), mlp_path)
                    print("Model saved at epoch {}".format(epoch))
                ### Look at final acc on tk
                final_acc = tester_bienc_clip(test_loader, small_ref_loader, clip_model, mlp_model, split="test", log=True)

    else:
        # train_bienc on synthetic data (other combinations were rejected above).
        clip_path = os.path.join(model_dir, args.wandb_name + ".pt")
        mlp_path = os.path.join(model_dir, "mlp_imwt_" + im_wt_tag + args.wandb_name + ".pt")
        best_acc = tester_bienc_clip(val_loader, synth_ref_dataloader, clip_model, mlp_model, split="val_small", log=True)
        loss_func = losses.SupConLoss(temperature=args.supcon_temp)
        for epoch in range(args.start_epoch, args.num_epochs + args.start_epoch):
            # Pass the MLP optimizer/scheduler too (as the labelled path does) so an MLP head is updated.
            epoch_loss = train_bienc_clip(train_loader, clip_model, device, loss_func, epoch, clip_optimizer, clip_scheduler=clip_scheduler, epochviz=None, tokenizer=tokenizer, mlp_model=mlp_model, mlp_optimizer=mlp_optimizer, mlp_scheduler=mlp_scheduler, freeze_clip=False)
            acc = tester_bienc_clip(val_loader, large_synth_ref_dataloader, clip_model, mlp_model, split="val_large", log=True)
            if acc > best_acc:
                best_acc = acc
                torch.save(clip_model.state_dict(), clip_path)
                print("Model saved at epoch {}".format(epoch))
                print("Path of the saved model: {}".format(clip_path))
                if args.pooling_type == "mlp":
                    torch.save(mlp_model.state_dict(), mlp_path)
                    print("Model saved at epoch {}".format(epoch))
                ### Look at final acc on tk
                final_acc = tester_bienc_clip(test_loader, all_tk_ref_loader, clip_model, mlp_model, split="test", log=True)