-
Notifications
You must be signed in to change notification settings - Fork 0
/
Copy pathdataset.py
70 lines (56 loc) · 2.47 KB
/
dataset.py
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
import numpy as np
import pandas as pd
import torch
from torch_geometric.data import Data, Dataset
from tqdm import tqdm
class Twibot22(Dataset):
    """Single-graph dataset assembled from preprocessed tensor files.

    Every tensor is read with ``torch.load`` from ``root`` and mapped to
    ``device``. The tensors are bundled into one ``Data`` object, which is
    the only item of the dataset (``len()`` is 1, ``get(0)`` returns it).

    Files read from ``root``:
        labels.pt, num_properties_tensor.pt,
        categorical_properties_tensor.pt,
        user_description_embedding_tensor.pt, user_tweets_tensor.pt,
        train_mask.pt, test_mask.pt, validation_mask.pt,
        plus the files named by ``edge_index_file`` and ``edge_type_file``.
    """

    def __init__(self, root: str = r'src/Data/preprocessed', device='cpu', edge_index_file: str = 'edge_index.pt', edge_type_file: str = 'edge_type.pt'):
        """Load all tensors from ``root`` and build the graph.

        Args:
            root: Directory that contains the preprocessed ``.pt`` files.
            device: Passed to ``torch.load`` as ``map_location``.
            edge_index_file: File name (inside ``root``) of the edge index.
            edge_type_file: File name (inside ``root``) of the edge types;
                stored on the graph as ``edge_attr``.

        Raises:
            ValueError: If the assembled ``Data`` object fails validation.
        """
        self.root = root
        super().__init__(self.root, None, None, None)
        self.device = device

        def load(name):
            # Same "<root>/<name>" path format for every tensor file.
            return torch.load(f"{self.root}/{name}", map_location=self.device)

        # Labels
        labels = load("labels.pt")

        # Node features
        numerical_features = load("num_properties_tensor.pt")
        categorical_features = load("categorical_properties_tensor.pt")
        description_embeddings = load("user_description_embedding_tensor.pt")
        tweet_embeddings = load("user_tweets_tensor.pt")

        # Edge index and edge types
        edge_index = load(edge_index_file)
        edge_type = load(edge_type_file)

        # Split masks
        train_mask = load("train_mask.pt")
        test_mask = load("test_mask.pt")
        val_mask = load("validation_mask.pt")

        # NOTE(review): num_nodes is taken from labels.shape[0]; this assumes
        # the label tensor has one entry per node -- confirm against the
        # preprocessing step.
        self.data = Data(
            edge_index=edge_index,
            edge_attr=edge_type,
            y=labels,
            description_embeddings=description_embeddings,
            tweet_embeddings=tweet_embeddings,
            numerical_features=numerical_features,
            categorical_features=categorical_features,
            train_mask=train_mask,
            test_mask=test_mask,
            val_mask=val_mask,
            num_nodes=labels.shape[0],
        )

        # Validate with a plain call rather than `assert`: assert statements
        # are removed under `python -O`, which would silently skip the check.
        self.data.validate(raise_on_error=True)

    def len(self) -> int:
        """Return the number of graphs in the dataset (always 1)."""
        return 1

    def get(self, idx: int) -> Data:
        """Return the graph stored at ``idx``.

        Raises:
            IndexError: If ``idx`` is not 0; the dataset holds one graph.
        """
        if idx == 0:
            return self.data
        raise IndexError(
            f"Twibot22 contains a single graph; index {idx} is out of range"
        )

    @property
    def num_node_features(self) -> int:
        """Node feature count as reported by the underlying ``Data`` object.

        NOTE(review): no ``x`` attribute is set on the graph, so this likely
        reports 0 -- confirm that consumers read the four feature tensors
        directly instead of relying on this value.
        """
        return self.data.num_node_features

    @property
    def num_edge_features(self) -> int:
        """Edge feature count as reported by the underlying ``Data`` object."""
        return self.data.num_edge_features

    @property
    def num_nodes(self) -> int:
        """Number of nodes in the graph."""
        return self.data.num_nodes

    @property
    def num_edges(self) -> int:
        """Number of edges in the graph."""
        return self.data.num_edges