-
Notifications
You must be signed in to change notification settings - Fork 4
/
npbcl_not.py
36 lines (30 loc) · 862 Bytes
/
npbcl_not.py
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
import os
# os.environ["CUDA_VISIBLE_DEVICES"]="2"
try:
os.mkdir('./saves')
except:
pass
import numpy as np
import matplotlib
matplotlib.use('Agg')
import math
from data_generators import PermutedMnistGenerator, SplitMnistGenerator, NotMnistGenerator, FashionMnistGenerator
from ibpbcl import IBP_BCL
import torch
torch.manual_seed(8)
np.random.seed(10)
hidden_size = [200]
alpha = [30]
no_epochs = 5
no_tasks = 5
coreset_size = 0#200
coreset_method = "kcen"
single_head = False
batch_size = 256
# data_gen = PermutedMnistGenerator(no_tasks)
# data_gen = SplitMnistGenerator()
data_gen = NotMnistGenerator()
# data_gen = FashionMnistGenerator()
model = IBP_BCL(hidden_size, alpha, no_epochs, data_gen, coreset_method, coreset_size, single_head, grow = False)
accs, _ = model.batch_train(batch_size)
np.save('./saves/notmnist_accuracies.npy', accs)