-
Notifications
You must be signed in to change notification settings - Fork 1
/
toy_test.py
executable file
·34 lines (28 loc) · 986 Bytes
/
toy_test.py
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
#!/usr/bin/env python3
import numpy as np
from sklearn.datasets import make_s_curve
from sklearn.ensemble import RandomForestRegressor
import matplotlib.pyplot as plt
from generative_rf import FeatureGenerator
# Default seed so both figures are reproducible from run to run; the original
# script drew from unseeded RNG state everywhere.
DEFAULT_SEED = 0


def _make_dataset(seed):
    """Build the toy training set: a standardized 2-D S-curve plus uniform noise.

    Args:
        seed: Seed (or ``None`` for OS entropy) for the S-curve and the noise.

    Returns:
        ``(X, y)``. ``X`` stacks 10000 S-curve points followed by 500 noise
        points drawn uniformly from ``[-2, 2]`` in each coordinate. ``y`` labels
        the S-curve points 0 and the noise points 1.
    """
    rng = np.random.RandomState(seed)
    # Keep columns 0 and 2 of the 3-D S-curve samples to get a 2-D "S" shape.
    s_curve = make_s_curve(n_samples=10000, random_state=rng)[0][:, [0, 2]]
    # Standardize each axis to zero mean / unit variance so the S-curve and the
    # [-2, 2] noise box live on a comparable scale.
    s_curve -= s_curve.mean(axis=0)
    s_curve /= s_curve.std(axis=0)
    noise = rng.uniform(-2, 2, size=(500, 2))
    X = np.concatenate([s_curve, noise], axis=0)
    y = [0] * len(s_curve) + [1] * len(noise)
    return X, y


def _show_scatter(points, colors, title, **scatter_kwargs):
    """Show a tick-less scatter plot of 2-D ``points`` coloured by ``colors``.

    Extra keyword arguments are forwarded to ``plt.scatter`` (e.g. ``s=5``).
    """
    # Open a fresh figure rather than calling clf() on whatever is current
    # after a previous (blocking) plt.show().
    plt.figure()
    plt.scatter(points[:, 0], points[:, 1], c=colors, **scatter_kwargs)
    plt.xticks([], [])
    plt.yticks([], [])
    plt.title(title)
    plt.show()


def main(seed=DEFAULT_SEED):
    """Fit a random forest on the toy data and plot its predictions twice.

    The first plot colours the training points by the forest's prediction. The
    second colours 50000 points sampled by ``FeatureGenerator`` (built from the
    fitted forest and the training data) by the forest's prediction.

    Args:
        seed: Seed used for data generation, the forest, and the global NumPy
            RNG. Pass ``None`` to get the original unseeded behaviour.
    """
    # Also seed the global NumPy RNG in case FeatureGenerator draws from it.
    # NOTE(review): its RNG usage is not visible here; confirm in generative_rf.
    np.random.seed(seed)

    X, labels = _make_dataset(seed)
    rf = RandomForestRegressor(
        n_estimators=100, max_depth=15, random_state=seed
    ).fit(X, labels)
    _show_scatter(X, rf.predict(X), "Prediction with RandomForestRegressor")

    generator = FeatureGenerator().register(rf).reinforce(X).update_moments(X)
    data, _ = generator.generate(50000)
    # Separate name for the predictions on generated data, so the training
    # labels are not shadowed.
    generated_pred = rf.predict(data)
    _show_scatter(
        data, generated_pred, "Data sampled with RandomForestRegressor", s=5
    )


# Guard the entry point. The filename matches pytest's default "*_test.py"
# pattern, so an unguarded script would train a model and open blocking plot
# windows merely by being imported during test collection.
if __name__ == "__main__":
    main()