-
Notifications
You must be signed in to change notification settings - Fork 765
/
evaluate_pfid.py
40 lines (28 loc) · 1.2 KB
/
evaluate_pfid.py
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
"""
Evaluate P-FID between two batches of point clouds.
The point cloud batches should be saved to two npz files, each containing
an arr_0 array of shape [N x K x 3], where N is the number of clouds and
K is the number of points in each cloud (each point has 3 coordinates).
"""
import argparse
from point_e.evals.feature_extractor import PointNetClassifier, get_torch_devices
from point_e.evals.fid_is import compute_statistics
from point_e.evals.npz_stream import NpzStreamer
def _batch_statistics(clf, path):
    """Return activation statistics for the point clouds stored in one npz file.

    ``clf.features_and_preds`` is run over an ``NpzStreamer`` for ``path``; its
    predictions are discarded and its features are reduced with
    ``compute_statistics``. The raw feature array is held only in this helper's
    local scope. This mirrors the explicit ``del`` in the original per-batch
    code: the caller does not keep one batch's features alive while the next
    batch is processed.
    """
    features, _ = clf.features_and_preds(NpzStreamer(path))
    return compute_statistics(features)


def main():
    """Command-line entry point: print the P-FID between two point-cloud batches.

    Positional arguments ``batch_1`` and ``batch_2`` are paths to npz files read
    via ``NpzStreamer``. ``--cache_dir`` is forwarded to ``PointNetClassifier``.
    The result is printed as ``P-FID: <value>``.
    """
    parser = argparse.ArgumentParser(
        description="Compute P-FID between two batches of point clouds."
    )
    parser.add_argument(
        "--cache_dir",
        type=str,
        default=None,
        help="cache directory passed to PointNetClassifier",
    )
    parser.add_argument(
        "batch_1", type=str, help="npz file with the first batch (arr_0 key)"
    )
    parser.add_argument(
        "batch_2", type=str, help="npz file with the second batch (arr_0 key)"
    )
    args = parser.parse_args()

    print("creating classifier...")
    clf = PointNetClassifier(devices=get_torch_devices(), cache_dir=args.cache_dir)

    # Each batch is reduced to statistics before the next one is processed.
    print("computing first batch activations")
    stats_1 = _batch_statistics(clf, args.batch_1)
    print("computing second batch activations")
    stats_2 = _batch_statistics(clf, args.batch_2)

    print(f"P-FID: {stats_1.frechet_distance(stats_2)}")
# Run the CLI only when this file is executed as a script, not when imported.
if __name__ == "__main__":
    main()