-
Notifications
You must be signed in to change notification settings - Fork 0
/
rf_model.py
57 lines (43 loc) · 1.91 KB
/
rf_model.py
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
from sklearn.ensemble import RandomForestClassifier
from sklearn.metrics import fbeta_score
import numpy as np
from dataset import get_data_sets
from features import summary_stats
def random_forest_clf(features, labels, **kwargs):
    '''
    Constructs a RandomForestClassifier, fits it to the given data and returns it.

    "features" is a numpy array of shape [num_samples, num_features]
    "labels" is a numpy array of shape [num_samples, num_categories]
    "kwargs" are forwarded unchanged to the RandomForestClassifier() constructor

    Returns whatever RandomForestClassifier.fit() returns (the fitted estimator).
    '''
    # Construction and training are chained; the result of fit() is the return value.
    return RandomForestClassifier(**kwargs).fit(features, labels)
def reshape_features(*features):
    '''
    Flattens each feature array to one row per sample and concatenates the
    results horizontally.

    Each "feature" in features is a numpy array whose first axis indexes samples,
    typically of shape [num_samples, num_channels, num_metrics]. Arrays of any
    rank >= 1 are accepted: every axis after the first is collapsed into a single
    axis, so an already-flat [num_samples, num_values] array passes through as is
    and a 1-D [num_samples] array becomes a single column.

    Returns a numpy array of shape [num_samples, total_width], where total_width
    is the sum of the flattened widths of all inputs.

    Raises ValueError if no feature arrays are given, or (from np.hstack) if the
    arrays do not all have the same number of samples.
    '''
    if not features:
        raise ValueError('reshape_features() requires at least one feature array')

    flattened = []
    for feature in features:
        shape = feature.shape
        # The row width is computed explicitly instead of passing -1 to reshape,
        # because numpy cannot infer a -1 dimension when there are zero samples.
        # np.prod of an empty tuple is 1, which turns 1-D input into one column.
        width = int(np.prod(shape[1:]))
        flattened.append(feature.reshape((shape[0], width)))
    return np.hstack(flattened)
def _batch_features(dataset, batch_size):
    '''
    Draws one batch of "batch_size" images from "dataset" and converts it to a
    flat feature matrix via summary_stats() and reshape_features().

    Returns a (features, labels) tuple.
    '''
    (images, labels) = dataset.get_image_batch(batch_size)
    (rgbn_hists, power_spectra) = summary_stats(images, labels)
    return reshape_features(rgbn_hists, power_spectra), labels


def predict(data_dir='./data/train-tif-v2', batch_size=100):
    '''
    Trains a random forest on one batch of training images, evaluates it on one
    batch of validation images and reports the F2 score.

    "data_dir" is the directory handed to get_data_sets(); the result is indexed
    with the keys 'train' and 'validation'.
    "batch_size" is the number of images requested for both the training batch
    and the validation batch.

    The defaults reproduce the previously hard-coded directory and batch size.

    Prints progress messages and the score, and returns the F2 score so callers
    can use it programmatically.
    '''
    print('preparing data...')
    data = get_data_sets(data_dir=data_dir)
    train_features, train_labels = _batch_features(data['train'], batch_size)

    print('training classifier...')
    rf_classifier = random_forest_clf(train_features, train_labels, class_weight='balanced_subsample')

    print('making predictions...')
    eval_features, eval_labels = _batch_features(data['validation'], batch_size)
    predicted_labels = rf_classifier.predict(eval_features)

    # beta=2 weights recall more heavily than precision; average='samples' scores
    # each sample's label set separately before averaging (multilabel targets).
    f2 = fbeta_score(eval_labels, predicted_labels, beta=2, average='samples')
    print('F2 score: {}'.format(f2))
    return f2
# Run the train-and-evaluate pipeline only when this file is executed as a
# script, so importing the module has no side effects.
if __name__ == "__main__":
    predict()