forked from sorenbouma/keras-oneshot
-
Notifications
You must be signed in to change notification settings - Fork 0
/
util.py
59 lines (54 loc) · 2.58 KB
/
util.py
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
from keras.layers import Input, Convolution2D, Lambda, merge, Dense, Flatten,MaxPooling2D
from keras.models import Model, Sequential
from keras import backend as K
from keras.optimizers import SGD
import numpy.random as rng
import numpy as np
import dill as pickle
import matplotlib.pyplot as plt
class Siamese_Loader:
    """Sample training batches and N-way one-shot tasks for a siamese net.

    Both arrays are indexed as ``[class, example, row, col]``; the four
    dimensions are unpacked from ``.shape`` in ``__init__``.  The image size
    (``w``, ``h``) is taken from ``Xtrain`` and reused when reshaping images
    from ``Xval``, so both arrays are assumed to hold images of the same size.
    """

    def __init__(self, Xtrain, Xval):
        """Store the two image arrays and cache their dimensions."""
        self.Xval = Xval
        self.Xtrain = Xtrain
        self.n_classes, self.n_examples, self.w, self.h = Xtrain.shape
        self.n_val, self.n_ex_val, _, _ = Xval.shape

    def get_batch(self, n):
        """Create a batch of ``n`` training pairs.

        The first ``n // 2`` pairs hold images of two different classes
        (target 0); the remaining pairs hold two images of the same class
        (target 1).

        Returns:
            ``(pairs, targets)`` where ``pairs`` is a list of two arrays of
            shape ``(n, w, h, 1)`` and ``targets`` has shape ``(n,)``.

        Raises:
            ValueError: propagated from ``numpy.random`` when ``n`` exceeds
                the number of training classes (the left-hand classes are
                drawn without replacement), or when a different-class pair
                is requested with only one training class.
        """
        half = n // 2
        categories = rng.choice(self.n_classes, size=(n,), replace=False)
        # Buffers use (w, h) order so they match the reshape below even for
        # non-square images.
        pairs = [np.zeros((n, self.w, self.h, 1)) for _ in range(2)]
        targets = np.zeros((n,))
        targets[half:] = 1
        for i, category in enumerate(categories):
            idx_1 = rng.randint(0, self.n_examples)
            pairs[0][i, :, :, :] = self.Xtrain[category, idx_1].reshape(
                self.w, self.h, 1)
            idx_2 = rng.randint(0, self.n_examples)
            if i >= half:
                category_2 = category
            else:
                # A non-zero offset modulo n_classes always lands on a
                # class other than `category`.
                offset = rng.randint(1, self.n_classes)
                category_2 = (category + offset) % self.n_classes
            pairs[1][i, :, :, :] = self.Xtrain[category_2, idx_2].reshape(
                self.w, self.h, 1)
        return pairs, targets

    def make_oneshot_task(self, N, shuffle=False):
        """Create one N-way one-shot task from the validation data.

        The test image is repeated ``N`` times and paired with a support set
        of ``N`` images from ``N`` distinct validation classes.  Exactly one
        support image shares the test image's class, and it is a different
        example from the test image.

        Args:
            N: number of classes in the support set.
            shuffle: when false (the default) the matching support image is
                at index 0.  When true the support set and targets are
                permuted together so the match sits at a random index.

        Returns:
            ``(pairs, targets)`` where ``pairs`` is
            ``[test_image, support_set]``, both of shape ``(N, w, h, 1)``,
            and ``targets`` has shape ``(N,)`` with a single 1 marking the
            matching support image.

        Raises:
            ValueError: propagated from ``numpy.random`` when ``N`` exceeds
                the number of validation classes or when there are fewer
                than two examples per validation class.
        """
        categories = rng.choice(self.n_val, size=(N,), replace=False)
        indices = rng.randint(0, self.n_ex_val, size=(N,))
        true_category = categories[0]
        # Two distinct example indices into Xval, so the range must be the
        # validation example count rather than the training one.
        ex1, ex2 = rng.choice(self.n_ex_val, replace=False, size=(2,))
        test_image = np.asarray(
            [self.Xval[true_category, ex1, :, :]] * N
        ).reshape(N, self.w, self.h, 1)
        # Indexing with two index arrays yields a copy, so the assignment
        # below does not modify self.Xval.
        support_set = self.Xval[categories, indices, :, :]
        support_set[0, :, :] = self.Xval[true_category, ex2]
        support_set = support_set.reshape(N, self.w, self.h, 1)
        targets = np.zeros((N,))
        targets[0] = 1
        if shuffle:
            # Every row of test_image is identical, so only the support set
            # and its targets need reordering.
            order = rng.permutation(N)
            support_set = support_set[order]
            targets = targets[order]
        pairs = [test_image, support_set]
        return pairs, targets

    def test_oneshot(self, model, N, k, verbose=0):
        """Return the percentage of ``k`` random N-way tasks answered correctly.

        A task counts as correct when the highest score returned by
        ``model.predict`` is at the position of the matching support image.
        Tasks are generated shuffled: ``np.argmax`` returns the first maximum,
        so a fixed match position at index 0 would credit a model that emits
        identical scores for every pair.

        Args:
            model: object whose ``predict(pairs)`` returns one score per pair.
            N: number of classes per task.
            k: number of tasks to evaluate; must be at least 1.
            verbose: when truthy, print a line before and after evaluation.

        Returns:
            Accuracy as a float percentage in ``[0, 100]``.

        Raises:
            ValueError: if ``k`` is less than 1.
        """
        if k < 1:
            raise ValueError(
                "k must be at least 1 one-shot task, got {}".format(k))
        if verbose:
            print("Evaluating model on {} random {} way one-shot learning "
                  "tasks ...".format(k, N))
        n_correct = 0
        for _ in range(k):
            inputs, targets = self.make_oneshot_task(N, shuffle=True)
            probs = model.predict(inputs)
            if np.argmax(probs) == np.argmax(targets):
                n_correct += 1
        percent_correct = 100.0 * n_correct / k
        if verbose:
            print("Got an average of {}% {} way one-shot learning "
                  "accuracy".format(percent_correct, N))
        return percent_correct