-
Notifications
You must be signed in to change notification settings - Fork 5
/
test.py
35 lines (28 loc) · 886 Bytes
/
test.py
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
import pickle
import os
import numpy as np
import pandas as pd
from tqdm import tqdm
import torch
from torch.utils.data import Dataset, DataLoader
from model import Unet
import re
import dataloader
# Run the trained Unet over the test split and pickle the stacked outputs.

# Batch size used for inference over the test split.
bs = 256

# Build the network and restore the trained weights.
# NOTE(review): the tuple is passed straight to model.Unet; its meaning
# (e.g. input shape) is defined there — TODO confirm.
model = Unet((256, 1, 1250)).cuda()
path = 'model/final.pt'
# The checkpoint is a dict whose 'model' entry holds the state dict.
# NOTE(review): torch.load unpickles the file — only load trusted checkpoints.
checkpoint = torch.load(path)
model.load_state_dict(checkpoint['model'])

# Destination for the pickled array of model outputs.
pick_path = 'output.p'

# BPdatasetv2 was referenced bare but never imported; the file's only
# candidate module is `dataloader` — TODO confirm it is defined there.
test = torch.utils.data.DataLoader(
    dataloader.BPdatasetv2(0, train=False, val=False, test=True),
    batch_size=bs,
)

temp1 = []
model.eval()
with torch.no_grad():
    # Labels are yielded by the loader but not needed for inference, so they
    # are not transferred to the GPU. The progress bar stays disabled, as in
    # the original script.
    for inputs, _labels in tqdm(test, total=len(test), disable=True):
        inputs = inputs.cuda()
        outputs_v = model(inputs)
        # Move each batch off the GPU immediately so predictions for the
        # whole test split do not accumulate in device memory.
        temp1.append(outputs_v.cpu())

if not temp1:
    raise RuntimeError(
        f"test loader produced no batches; nothing to write to {pick_path}"
    )
# Concatenating batches along dim 0 yields the same (N, ...) tensor as
# stacking the individual per-sample outputs.
temp1 = torch.cat(temp1, dim=0)
with open(pick_path, 'wb') as f:
    pickle.dump(temp1.numpy(), f)