-
Notifications
You must be signed in to change notification settings - Fork 837
/
dlrm_data_pytorch.py
1284 lines (1127 loc) · 45.3 KB
/
dlrm_data_pytorch.py
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
170
171
172
173
174
175
176
177
178
179
180
181
182
183
184
185
186
187
188
189
190
191
192
193
194
195
196
197
198
199
200
201
202
203
204
205
206
207
208
209
210
211
212
213
214
215
216
217
218
219
220
221
222
223
224
225
226
227
228
229
230
231
232
233
234
235
236
237
238
239
240
241
242
243
244
245
246
247
248
249
250
251
252
253
254
255
256
257
258
259
260
261
262
263
264
265
266
267
268
269
270
271
272
273
274
275
276
277
278
279
280
281
282
283
284
285
286
287
288
289
290
291
292
293
294
295
296
297
298
299
300
301
302
303
304
305
306
307
308
309
310
311
312
313
314
315
316
317
318
319
320
321
322
323
324
325
326
327
328
329
330
331
332
333
334
335
336
337
338
339
340
341
342
343
344
345
346
347
348
349
350
351
352
353
354
355
356
357
358
359
360
361
362
363
364
365
366
367
368
369
370
371
372
373
374
375
376
377
378
379
380
381
382
383
384
385
386
387
388
389
390
391
392
393
394
395
396
397
398
399
400
401
402
403
404
405
406
407
408
409
410
411
412
413
414
415
416
417
418
419
420
421
422
423
424
425
426
427
428
429
430
431
432
433
434
435
436
437
438
439
440
441
442
443
444
445
446
447
448
449
450
451
452
453
454
455
456
457
458
459
460
461
462
463
464
465
466
467
468
469
470
471
472
473
474
475
476
477
478
479
480
481
482
483
484
485
486
487
488
489
490
491
492
493
494
495
496
497
498
499
500
501
502
503
504
505
506
507
508
509
510
511
512
513
514
515
516
517
518
519
520
521
522
523
524
525
526
527
528
529
530
531
532
533
534
535
536
537
538
539
540
541
542
543
544
545
546
547
548
549
550
551
552
553
554
555
556
557
558
559
560
561
562
563
564
565
566
567
568
569
570
571
572
573
574
575
576
577
578
579
580
581
582
583
584
585
586
587
588
589
590
591
592
593
594
595
596
597
598
599
600
601
602
603
604
605
606
607
608
609
610
611
612
613
614
615
616
617
618
619
620
621
622
623
624
625
626
627
628
629
630
631
632
633
634
635
636
637
638
639
640
641
642
643
644
645
646
647
648
649
650
651
652
653
654
655
656
657
658
659
660
661
662
663
664
665
666
667
668
669
670
671
672
673
674
675
676
677
678
679
680
681
682
683
684
685
686
687
688
689
690
691
692
693
694
695
696
697
698
699
700
701
702
703
704
705
706
707
708
709
710
711
712
713
714
715
716
717
718
719
720
721
722
723
724
725
726
727
728
729
730
731
732
733
734
735
736
737
738
739
740
741
742
743
744
745
746
747
748
749
750
751
752
753
754
755
756
757
758
759
760
761
762
763
764
765
766
767
768
769
770
771
772
773
774
775
776
777
778
779
780
781
782
783
784
785
786
787
788
789
790
791
792
793
794
795
796
797
798
799
800
801
802
803
804
805
806
807
808
809
810
811
812
813
814
815
816
817
818
819
820
821
822
823
824
825
826
827
828
829
830
831
832
833
834
835
836
837
838
839
840
841
842
843
844
845
846
847
848
849
850
851
852
853
854
855
856
857
858
859
860
861
862
863
864
865
866
867
868
869
870
871
872
873
874
875
876
877
878
879
880
881
882
883
884
885
886
887
888
889
890
891
892
893
894
895
896
897
898
899
900
901
902
903
904
905
906
907
908
909
910
911
912
913
914
915
916
917
918
919
920
921
922
923
924
925
926
927
928
929
930
931
932
933
934
935
936
937
938
939
940
941
942
943
944
945
946
947
948
949
950
951
952
953
954
955
956
957
958
959
960
961
962
963
964
965
966
967
968
969
970
971
972
973
974
975
976
977
978
979
980
981
982
983
984
985
986
987
988
989
990
991
992
993
994
995
996
997
998
999
1000
# Copyright (c) Meta Platforms, Inc. and affiliates.
#
# This source code is licensed under the MIT license found in the
# LICENSE file in the root directory of this source tree.
# Description: generate inputs and targets for the dlrm benchmark
# The inputs and outputs are generated according to the following three option(s)
# 1) random distribution
# 2) synthetic distribution, based on unique accesses and distances between them
# i) R. Hassan, A. Harris, N. Topham and A. Efthymiou "Synthetic Trace-Driven
# Simulation of Cache Memory", IEEE AINAM'07
# 3) public data set
# i) Criteo Kaggle Display Advertising Challenge Dataset
# https://labs.criteo.com/2014/02/kaggle-display-advertising-challenge-dataset
# ii) Criteo Terabyte Dataset
# https://labs.criteo.com/2013/12/download-terabyte-click-logs
from __future__ import absolute_import, division, print_function, unicode_literals
import bisect
import collections
import sys
from collections import deque
# others
from os import path
import data_loader_terabyte
import data_utils
import mlperf_logger
# numpy
import numpy as np
# pytorch
import torch
from numpy import random as ra
from torch.utils.data import Dataset, RandomSampler
# Kaggle Display Advertising Challenge Dataset
# dataset (str): name of dataset (Kaggle or Terabyte)
# randomize (str): determines randomization scheme
# "none": no randomization
# "day": randomizes each day's data (only works if split = True)
# "total": randomizes total dataset
# split (bool) : to split into train, test, validation data-sets
class CriteoDataset(Dataset):
    """Dataset over pre-processed Criteo "kaggle" or "terabyte" click logs.

    Two storage modes are supported:

    * ``memory_map=False``: the single pre-processed file ``pro_data`` is
      loaded and the requested split is materialized in memory.
    * ``memory_map=True``: per-day ``<name>_<day>_reordered.npz`` files are
      used. The "train"/"none" splits load one day at a time while samples
      are read; "test" and "val" share the last day (the first
      ``ceil(n / 2)`` samples are the test set, the remainder validation).

    When the pre-processed file(s) cannot be found they are generated from
    ``raw_path`` by ``data_utils.getCriteoAdData``.

    Args:
        dataset: "kaggle" (7 days) or "terabyte" (24 days).
        max_ind_range: when > 0, categorical indices are reduced modulo this
            value on access.
        sub_sample_rate: forwarded to ``data_utils.getCriteoAdData``.
        randomize: "none", "day" or "total". Forwarded to the pre-processing
            step; in the in-memory mode it also selects how sample indices
            are permuted ("day": within each training day, "total": across
            all training samples, or the whole data set for split "none").
        split: "none", "train", "val" or "test".
        raw_path: path of the raw data file; its directory and base name are
            used to locate the ``*_day_count.npz`` / ``*_fea_count.npz`` and
            per-day files.
        pro_data: path of the single pre-processed file (in-memory mode).
        memory_map: selects the per-day storage mode described above.
        dataset_multiprocessing: forwarded to ``data_utils.getCriteoAdData``.

    Raises:
        ValueError: if ``dataset`` is neither "kaggle" nor "terabyte".
    """

    def __init__(
        self,
        dataset,
        max_ind_range,
        sub_sample_rate,
        randomize,
        split="train",
        raw_path="",
        pro_data="",
        memory_map=False,
        dataset_multiprocessing=False,
    ):
        # dataset layout: 1 target, 13 dense features, 26 sparse features
        den_fea = 13  # 13 dense features

        if dataset == "kaggle":
            days = 7
            out_file = "kaggleAdDisplayChallenge_processed"
        elif dataset == "terabyte":
            days = 24
            out_file = "terabyte_processed"
        else:
            raise ValueError("Data set option is not supported")
        self.max_ind_range = max_ind_range
        self.memory_map = memory_map
        # remember the split in both storage modes so every instance exposes it
        self.split = split

        # split the datafile into path and filename
        lstr = raw_path.split("/")
        self.d_path = "/".join(lstr[0:-1]) + "/"
        self.d_file = lstr[-1].split(".")[0] if dataset == "kaggle" else lstr[-1]
        self.npzfile = self.d_path + (
            (self.d_file + "_day") if dataset == "kaggle" else self.d_file
        )
        self.trafile = self.d_path + (
            (self.d_file + "_fea") if dataset == "kaggle" else "fea"
        )

        # check if pre-processed data is available
        if memory_map:
            # memory mapping needs one reordered file per day
            data_ready = all(
                path.exists(self.npzfile + "_{0}_reordered.npz".format(i))
                for i in range(days)
            )
        else:
            data_ready = path.exists(str(pro_data))

        # pre-process data if needed
        # WARNING: when memory mapping is used we get a collection of files
        if data_ready:
            print("Reading pre-processed data=%s" % (str(pro_data)))
            file = str(pro_data)
        else:
            print("Reading raw data=%s" % (str(raw_path)))
            file = data_utils.getCriteoAdData(
                raw_path,
                out_file,
                max_ind_range,
                sub_sample_rate,
                days,
                split,
                randomize,
                dataset == "kaggle",
                memory_map,
                dataset_multiprocessing,
            )

        # get a number of samples per day
        total_file = self.d_path + self.d_file + "_day_count.npz"
        with np.load(total_file) as data:
            total_per_file = data["total_per_file"]
        # compute offsets per file (running sum, so entry d is the global
        # index of the first sample of day d)
        self.offset_per_file = np.array([0] + [x for x in total_per_file])
        for i in range(days):
            self.offset_per_file[i + 1] += self.offset_per_file[i]

        # setup data
        if memory_map:
            # setup the training/testing split
            if split == "none" or split == "train":
                self.day = 0
                # "train" leaves the last day out, it is reserved for test/val
                self.max_day_range = days if split == "none" else days - 1
            elif split == "test" or split == "val":
                self.day = days - 1
                num_samples = (
                    self.offset_per_file[days] - self.offset_per_file[days - 1]
                )
                self.test_size = int(np.ceil(num_samples / 2.0))
                self.val_size = num_samples - self.test_size
            else:
                sys.exit("ERROR: dataset split is neither none, nor train or test.")

            # load unique counts
            with np.load(self.d_path + self.d_file + "_fea_count.npz") as data:
                self.counts = data["counts"]
            self.m_den = den_fea  # X_int.shape[1]
            self.n_emb = len(self.counts)
            print("Sparse features= %d, Dense features= %d" % (self.n_emb, self.m_den))

            # Load the test data
            # Only a single day is used for testing
            if self.split == "test" or self.split == "val":
                fi = self.npzfile + "_{0}_reordered.npz".format(self.day)
                with np.load(fi) as data:
                    self.X_int = data["X_int"]  # continuous feature
                    self.X_cat = data["X_cat"]  # categorical feature
                    self.y = data["y"]  # target
        else:
            # load and preprocess data
            with np.load(file) as data:
                X_int = data["X_int"]  # continuous feature
                X_cat = data["X_cat"]  # categorical feature
                y = data["y"]  # target
                self.counts = data["counts"]
            self.m_den = X_int.shape[1]  # den_fea
            self.n_emb = len(self.counts)
            print("Sparse fea = %d, Dense fea = %d" % (self.n_emb, self.m_den))

            # create reordering
            indices = np.arange(len(y))

            if split == "none":
                # randomize all data
                if randomize == "total":
                    indices = np.random.permutation(indices)
                    print("Randomized indices...")

                X_int[indices] = X_int
                X_cat[indices] = X_cat
                y[indices] = y

                # Store the (possibly permuted) arrays on the instance;
                # without this __len__ and __getitem__ have nothing to read
                # for split == "none".
                self.X_int = X_int
                self.X_cat = X_cat
                self.y = y
            else:
                # one index array per day
                indices = np.array_split(indices, self.offset_per_file[1:-1])

                # randomize train data (per day)
                if randomize == "day":  # or randomize == "total":
                    for i in range(len(indices) - 1):
                        indices[i] = np.random.permutation(indices[i])
                    print("Randomized indices per day ...")

                # the last day is shared by the test and validation sets
                train_indices = np.concatenate(indices[:-1])
                test_indices = indices[-1]
                test_indices, val_indices = np.array_split(test_indices, 2)

                print("Defined %s indices..." % (split))

                # randomize train data (across days)
                if randomize == "total":
                    train_indices = np.random.permutation(train_indices)
                    print("Randomized indices across days ...")

                # create training, validation, and test sets
                if split == "train":
                    self.X_int = [X_int[i] for i in train_indices]
                    self.X_cat = [X_cat[i] for i in train_indices]
                    self.y = [y[i] for i in train_indices]
                elif split == "val":
                    self.X_int = [X_int[i] for i in val_indices]
                    self.X_cat = [X_cat[i] for i in val_indices]
                    self.y = [y[i] for i in val_indices]
                elif split == "test":
                    self.X_int = [X_int[i] for i in test_indices]
                    self.X_cat = [X_cat[i] for i in test_indices]
                    self.y = [y[i] for i in test_indices]

            print("Split data according to indices...")

    def __getitem__(self, index):
        """Return ``(X_int, X_cat, y)`` for one sample, or a list for a slice.

        In memory-map mode with split "none"/"train" a day file is loaded
        only when ``index`` equals the first global index of the next day,
        so samples are expected to be requested sequentially from index 0.
        """
        if isinstance(index, slice):
            return [
                self[idx]
                for idx in range(
                    index.start or 0, index.stop or len(self), index.step or 1
                )
            ]

        if self.memory_map:
            if self.split == "none" or self.split == "train":
                # check if need to switch to next day and load data
                if index == self.offset_per_file[self.day]:
                    self.day_boundary = self.offset_per_file[self.day]
                    fi = self.npzfile + "_{0}_reordered.npz".format(self.day)
                    with np.load(fi) as data:
                        self.X_int = data["X_int"]  # continuous feature
                        self.X_cat = data["X_cat"]  # categorical feature
                        self.y = data["y"]  # target
                    # wrap around so that the next epoch starts at day 0 again
                    self.day = (self.day + 1) % self.max_day_range

                # position of the sample inside the currently loaded day
                i = index - self.day_boundary
            elif self.split == "test" or self.split == "val":
                # only a single day is used for testing;
                # validation samples follow the test samples in that day
                i = index + (0 if self.split == "test" else self.test_size)
            else:
                sys.exit("ERROR: dataset split is neither none, nor train or test.")
        else:
            i = index

        if self.max_ind_range > 0:
            return self.X_int[i], self.X_cat[i] % self.max_ind_range, self.y[i]
        else:
            return self.X_int[i], self.X_cat[i], self.y[i]

    def _default_preprocess(self, X_int, X_cat, y):
        """Convert raw arrays to tensors.

        Dense features become ``log(x + 1)`` as float, categorical features
        become long (reduced modulo ``max_ind_range`` when it is > 0) and the
        targets become float32.
        """
        X_int = torch.log(torch.tensor(X_int, dtype=torch.float) + 1)
        if self.max_ind_range > 0:
            X_cat = torch.tensor(X_cat % self.max_ind_range, dtype=torch.long)
        else:
            X_cat = torch.tensor(X_cat, dtype=torch.long)
        y = torch.tensor(y.astype(np.float32))

        return X_int, X_cat, y

    def __len__(self):
        """Return the number of samples in the selected split."""
        if self.memory_map:
            if self.split == "none":
                return self.offset_per_file[-1]
            elif self.split == "train":
                # everything before the last day
                return self.offset_per_file[-2]
            elif self.split == "test":
                return self.test_size
            elif self.split == "val":
                return self.val_size
            else:
                sys.exit("ERROR: dataset split is neither none, nor train nor test.")
        else:
            return len(self.y)
def collate_wrapper_criteo_offset(list_of_tuples):
    """Collate ``(X_int, X_cat, y)`` samples into an offset-style batch.

    Returns ``(dense, offsets, indices, targets)`` where dense is
    ``log(X_int + 1)`` as float, offsets and indices hold one row per
    categorical feature (each offsets row is ``0 .. batch - 1``, i.e. one
    index per lookup) and targets is a float32 column of shape (batch, 1).
    """
    # turn the list of samples into per-field columns
    columns = list(zip(*list_of_tuples))

    dense = torch.tensor(columns[0], dtype=torch.float)
    dense = torch.log(dense + 1)
    sparse = torch.tensor(columns[1], dtype=torch.long)
    targets = torch.tensor(columns[2], dtype=torch.float32).view(-1, 1)

    num_samples = sparse.shape[0]
    num_features = sparse.shape[1]

    # one index vector and one offset vector per categorical feature
    per_feature_indices = list(sparse.unbind(dim=1))
    per_feature_offsets = [torch.arange(num_samples) for _ in range(num_features)]

    return dense, torch.stack(per_feature_offsets), torch.stack(per_feature_indices), targets
def ensure_dataset_preprocessed(args, d_path):
    """Run the pre-processing needed by the binary terabyte loader.

    First instantiates ``CriteoDataset`` for the "train" and "test" splits
    (the objects are discarded; only the construction matters here), then
    converts the reordered per-day files to ``<d_path>_<split>.bin`` for the
    "train", "val" and "test" splits. Days 0-22 feed the train file, day 23
    feeds both the validation and the test file.
    """
    for dataset_split in ("train", "test"):
        CriteoDataset(
            args.data_set,
            args.max_ind_range,
            args.data_sub_sample_rate,
            args.data_randomize,
            dataset_split,
            args.raw_data_file,
            args.processed_data_file,
            args.memory_map,
            args.dataset_multiprocessing,
        )

    for split in ["train", "val", "test"]:
        print("Running preprocessing for split =", split)

        if split == "train":
            input_files = [
                "{}_{}_reordered.npz".format(args.raw_data_file, day)
                for day in range(0, 23)
            ]
        else:
            input_files = [args.raw_data_file + "_23_reordered.npz"]

        data_loader_terabyte.numpy_to_binary(
            input_files=input_files,
            output_file_path=d_path + "_{}.bin".format(split),
            split=split,
        )
# Conversion from offset to length
def offset_to_length_converter(lS_o, lS_i):
    """Turn per-feature offset vectors into per-feature length vectors.

    For every feature the total number of indices (first dimension of the
    matching entry of ``lS_i``) is appended to its offsets and the
    differences of neighbouring entries are taken. The result is a stacked
    int32 tensor with one row per feature.
    """
    lengths = []
    for ind, offsets in enumerate(lS_o):
        end_marker = torch.tensor(lS_i[ind].shape)
        boundaries = torch.cat((offsets, end_marker)).int()
        lengths.append(boundaries[1:] - boundaries[:-1])
    return torch.stack(lengths)
def collate_wrapper_criteo_length(list_of_tuples):
    """Collate ``(X_int, X_cat, y)`` samples into a length-style batch.

    Same as the offset variant, except that the per-feature offsets are
    converted to per-feature lengths with ``offset_to_length_converter``.
    Returns ``(dense, lengths, indices, targets)``.
    """
    # turn the list of samples into per-field columns
    columns = list(zip(*list_of_tuples))

    dense = torch.tensor(columns[0], dtype=torch.float)
    dense = torch.log(dense + 1)
    sparse = torch.tensor(columns[1], dtype=torch.long)
    targets = torch.tensor(columns[2], dtype=torch.float32).view(-1, 1)

    num_samples = sparse.shape[0]
    num_features = sparse.shape[1]

    indices = torch.stack(list(sparse.unbind(dim=1)))
    offsets = torch.stack([torch.arange(num_samples) for _ in range(num_features)])
    lengths = offset_to_length_converter(offsets, indices)

    return dense, lengths, indices, targets
def make_criteo_data_and_loaders(args, offset_to_length_converter=False):
    """Create the Criteo train/test data sets and their loaders.

    Three configurations are handled:

    * MLPerf logging + memory map + "terabyte" with ``args.mlperf_bin_loader``:
      binary ``CriteoBinDataset`` files (generated on demand) wrapped in
      ``torch`` loaders that return pre-batched items.
    * The same without the binary loader: ``CriteoDataset`` objects plus the
      ``data_loader_terabyte.DataLoader`` (days 0-22 train, day 23 test).
    * Everything else: ``CriteoDataset`` objects with regular ``torch``
      loaders; ``offset_to_length_converter`` selects the length-style
      collate function instead of the offset-style one.

    Returns ``(train_data, train_loader, test_data, test_loader)``.
    """

    def build_split(split_name):
        # every CriteoDataset built here differs only in the requested split
        return CriteoDataset(
            args.data_set,
            args.max_ind_range,
            args.data_sub_sample_rate,
            args.data_randomize,
            split_name,
            args.raw_data_file,
            args.processed_data_file,
            args.memory_map,
            args.dataset_multiprocessing,
        )

    mlperf_terabyte = (
        args.mlperf_logging and args.memory_map and args.data_set == "terabyte"
    )

    if not mlperf_terabyte:
        train_data = build_split("train")
        test_data = build_split("test")

        if offset_to_length_converter:
            collate_wrapper_criteo = collate_wrapper_criteo_length
        else:
            collate_wrapper_criteo = collate_wrapper_criteo_offset

        train_loader = torch.utils.data.DataLoader(
            train_data,
            batch_size=args.mini_batch_size,
            shuffle=False,
            num_workers=args.num_workers,
            collate_fn=collate_wrapper_criteo,
            pin_memory=False,
            drop_last=False,  # True
        )
        test_loader = torch.utils.data.DataLoader(
            test_data,
            batch_size=args.test_mini_batch_size,
            shuffle=False,
            num_workers=args.test_num_workers,
            collate_fn=collate_wrapper_criteo,
            pin_memory=False,
            drop_last=False,  # True
        )
        return train_data, train_loader, test_data, test_loader

    # more efficient for larger batches
    data_directory = path.dirname(args.raw_data_file)

    if not args.mlperf_bin_loader:
        data_filename = args.raw_data_file.split("/")[-1]

        train_data = build_split("train")
        test_data = build_split("test")

        train_loader = data_loader_terabyte.DataLoader(
            data_directory=data_directory,
            data_filename=data_filename,
            days=list(range(23)),
            batch_size=args.mini_batch_size,
            max_ind_range=args.max_ind_range,
            split="train",
        )
        test_loader = data_loader_terabyte.DataLoader(
            data_directory=data_directory,
            data_filename=data_filename,
            days=[23],
            batch_size=args.test_mini_batch_size,
            max_ind_range=args.max_ind_range,
            split="test",
        )
        return train_data, train_loader, test_data, test_loader

    # binary loader: file names derive from the processed data file name
    # without its extension
    path_parts = args.processed_data_file.split("/")
    d_path = "/".join(path_parts[0:-1]) + "/" + path_parts[-1].split(".")[0]
    train_file = d_path + "_train.bin"
    test_file = d_path + "_test.bin"
    counts_file = args.raw_data_file + "_fea_count.npz"

    required_files = [train_file, test_file, counts_file]
    if not all(path.exists(p) for p in required_files):
        ensure_dataset_preprocessed(args, d_path)

    train_data = data_loader_terabyte.CriteoBinDataset(
        data_file=train_file,
        counts_file=counts_file,
        batch_size=args.mini_batch_size,
        max_ind_range=args.max_ind_range,
    )
    mlperf_logger.log_event(
        key=mlperf_logger.constants.TRAIN_SAMPLES, value=train_data.num_samples
    )

    # batch_size=None: the data set items are used as batches directly
    train_sampler = RandomSampler(train_data) if args.mlperf_bin_shuffle else None
    train_loader = torch.utils.data.DataLoader(
        train_data,
        batch_size=None,
        batch_sampler=None,
        shuffle=False,
        num_workers=0,
        collate_fn=None,
        pin_memory=False,
        drop_last=False,
        sampler=train_sampler,
    )

    test_data = data_loader_terabyte.CriteoBinDataset(
        data_file=test_file,
        counts_file=counts_file,
        batch_size=args.test_mini_batch_size,
        max_ind_range=args.max_ind_range,
    )
    mlperf_logger.log_event(
        key=mlperf_logger.constants.EVAL_SAMPLES, value=test_data.num_samples
    )
    test_loader = torch.utils.data.DataLoader(
        test_data,
        batch_size=None,
        batch_sampler=None,
        shuffle=False,
        num_workers=0,
        collate_fn=None,
        pin_memory=False,
        drop_last=False,
    )
    return train_data, train_loader, test_data, test_loader
# uniform distribution (input data)
class RandomDataset(Dataset):
    """Dataset that generates one whole random mini-batch per item.

    ``__getitem__(i)`` produces batch ``i`` on the fly (dense features,
    sparse offsets, sparse indices and targets), so ``len()`` is the number
    of batches rather than the number of samples. When ``num_batches`` is
    non-zero it overrides the batch count derived from ``data_size`` and
    ``data_size`` is recomputed accordingly.
    """

    def __init__(
        self,
        m_den,
        ln_emb,
        data_size,
        num_batches,
        mini_batch_size,
        num_indices_per_lookup,
        num_indices_per_lookup_fixed,
        num_targets=1,
        round_targets=False,
        data_generation="random",
        trace_file="",
        enable_padding=False,
        reset_seed_on_access=False,
        rand_data_dist="uniform",
        rand_data_min=1,
        rand_data_max=1,
        rand_data_mu=-1,
        rand_data_sigma=1,
        rand_seed=0,
    ):
        # number of batches needed to cover data_size samples ...
        batch_count = int(np.ceil((data_size * 1.0) / mini_batch_size))
        # ... unless an explicit batch count was requested
        if num_batches != 0:
            batch_count = num_batches
            data_size = batch_count * mini_batch_size

        # sizes (data_size possibly recomputed above)
        self.data_size = data_size
        self.num_batches = batch_count
        self.mini_batch_size = mini_batch_size

        # model shape
        self.m_den = m_den
        self.ln_emb = ln_emb
        self.num_targets = num_targets
        self.round_targets = round_targets

        # sparse lookup generation
        self.num_indices_per_lookup = num_indices_per_lookup
        self.num_indices_per_lookup_fixed = num_indices_per_lookup_fixed
        self.data_generation = data_generation
        self.trace_file = trace_file
        self.enable_padding = enable_padding

        # random number generation
        self.reset_seed_on_access = reset_seed_on_access
        self.rand_seed = rand_seed
        self.rand_data_dist = rand_data_dist
        self.rand_data_min = rand_data_min
        self.rand_data_max = rand_data_max
        self.rand_data_mu = rand_data_mu
        self.rand_data_sigma = rand_data_sigma

    def reset_numpy_seed(self, numpy_rand_seed):
        """Re-seed numpy's global random number generator."""
        np.random.seed(numpy_rand_seed)

    def __getitem__(self, index):
        """Generate batch ``index`` as ``(X, lS_o, lS_i, T)`` (list for a slice)."""
        if isinstance(index, slice):
            first = index.start or 0
            last = index.stop or len(self)
            stride = index.step or 1
            return [self[idx] for idx in range(first, last, stride)]

        # WARNING: reset seed on access to first element
        # (e.g. if same random samples needed across epochs)
        if self.reset_seed_on_access and index == 0:
            self.reset_numpy_seed(self.rand_seed)

        # number of data points in a batch (the last batch may be shorter)
        already_served = index * self.mini_batch_size
        n = min(self.mini_batch_size, self.data_size - already_served)

        # generate a batch of dense and sparse features
        mode = self.data_generation
        if mode == "random":
            features = generate_dist_input_batch(
                self.m_den,
                self.ln_emb,
                n,
                self.num_indices_per_lookup,
                self.num_indices_per_lookup_fixed,
                rand_data_dist=self.rand_data_dist,
                rand_data_min=self.rand_data_min,
                rand_data_max=self.rand_data_max,
                rand_data_mu=self.rand_data_mu,
                rand_data_sigma=self.rand_data_sigma,
            )
        elif mode == "synthetic":
            features = generate_synthetic_input_batch(
                self.m_den,
                self.ln_emb,
                n,
                self.num_indices_per_lookup,
                self.num_indices_per_lookup_fixed,
                self.trace_file,
                self.enable_padding,
            )
        else:
            sys.exit(
                "ERROR: --data-generation=" + self.data_generation + " is not supported"
            )
        (X, lS_o, lS_i) = features

        # generate a batch of target (probability of a click)
        T = generate_random_output_batch(n, self.num_targets, self.round_targets)

        return (X, lS_o, lS_i, T)

    def __len__(self):
        """Return the number of batches (not samples)."""
        # WARNING: __getitem__ produces whole batches, therefore the length
        # is num_batches rather than data_size
        return self.num_batches
def collate_wrapper_random_offset(list_of_tuples):
    """Unwrap the single pre-built batch and stack its offset vectors.

    Only the first tuple ``(X, lS_o, lS_i, T)`` of the list is used.
    """
    batch = list_of_tuples[0]
    X, lS_o, lS_i, T = batch
    stacked_offsets = torch.stack(lS_o)
    return (X, stacked_offsets, lS_i, T)
def collate_wrapper_random_length(list_of_tuples):
    """Unwrap the single pre-built batch and convert its offsets to lengths.

    Only the first tuple ``(X, lS_o, lS_i, T)`` of the list is used.
    """
    batch = list_of_tuples[0]
    X, lS_o, lS_i, T = batch
    lengths = offset_to_length_converter(torch.stack(lS_o), lS_i)
    return (X, lengths, lS_i, T)
def make_random_data_and_loader(
    args,
    ln_emb,
    m_den,
    offset_to_length_converter=False,
):
    """Create random train/test data sets and their loaders.

    Both data sets are built from the same arguments and re-seed numpy when
    their first batch is accessed. The loaders use ``batch_size=1`` because
    every data set item already is a complete mini-batch.

    Returns ``(train_data, train_loader, test_data, test_loader)``.
    """

    def new_dataset():
        # WARNING: generates a batch of lookups at once
        return RandomDataset(
            m_den,
            ln_emb,
            args.data_size,
            args.num_batches,
            args.mini_batch_size,
            args.num_indices_per_lookup,
            args.num_indices_per_lookup_fixed,
            1,  # num_targets
            args.round_targets,
            args.data_generation,
            args.data_trace_file,
            args.data_trace_enable_padding,
            reset_seed_on_access=True,
            rand_data_dist=args.rand_data_dist,
            rand_data_min=args.rand_data_min,
            rand_data_max=args.rand_data_max,
            rand_data_mu=args.rand_data_mu,
            rand_data_sigma=args.rand_data_sigma,
            rand_seed=args.numpy_rand_seed,
        )

    def new_loader(data, collate_fn):
        return torch.utils.data.DataLoader(
            data,
            batch_size=1,
            shuffle=False,
            num_workers=args.num_workers,
            collate_fn=collate_fn,
            pin_memory=False,
            drop_last=False,  # True
        )

    train_data = new_dataset()
    test_data = new_dataset()

    if offset_to_length_converter:
        collate_wrapper_random = collate_wrapper_random_length
    else:
        collate_wrapper_random = collate_wrapper_random_offset

    train_loader = new_loader(train_data, collate_wrapper_random)
    test_loader = new_loader(test_data, collate_wrapper_random)
    return train_data, train_loader, test_data, test_loader
def generate_random_data(
    m_den,
    ln_emb,
    data_size,
    num_batches,
    mini_batch_size,
    num_indices_per_lookup,
    num_indices_per_lookup_fixed,
    num_targets=1,
    round_targets=False,
    data_generation="random",
    trace_file="",
    enable_padding=False,
    length=False,  # length for caffe2 version (except dlrm_s_caffe2)
):
    """Generate all batches of a random data set up front.

    A non-zero ``num_batches`` overrides the batch count derived from
    ``data_size``. Returns ``(nbatches, lX, lS_offsets, lS_indices, lT)``
    where each list has one entry per batch.
    """
    nbatches = int(np.ceil((data_size * 1.0) / mini_batch_size))
    if num_batches != 0:
        nbatches = num_batches
        data_size = nbatches * mini_batch_size

    # per-batch results
    dense_batches = []
    offset_batches = []
    index_batches = []
    target_batches = []

    for batch_id in range(nbatches):
        # number of data points in a batch (the last batch may be shorter)
        n = min(mini_batch_size, data_size - (batch_id * mini_batch_size))

        # generate a batch of dense and sparse features
        if data_generation == "random":
            generated = generate_uniform_input_batch(
                m_den,
                ln_emb,
                n,
                num_indices_per_lookup,
                num_indices_per_lookup_fixed,
                length,
            )
        elif data_generation == "synthetic":
            generated = generate_synthetic_input_batch(
                m_den,
                ln_emb,
                n,
                num_indices_per_lookup,
                num_indices_per_lookup_fixed,
                trace_file,
                enable_padding,
            )
        else:
            sys.exit(
                "ERROR: --data-generation=" + data_generation + " is not supported"
            )
        (dense, emb_offsets, emb_indices) = generated

        # dense feature
        dense_batches.append(dense)
        # sparse feature (sparse indices)
        offset_batches.append(emb_offsets)
        index_batches.append(emb_indices)

        # generate a batch of target (probability of a click)
        target_batches.append(
            generate_random_output_batch(n, num_targets, round_targets)
        )

    return (nbatches, dense_batches, offset_batches, index_batches, target_batches)
def generate_random_output_batch(n, num_targets, round_targets=False):
    """Return an ``(n, num_targets)`` float32 tensor of random targets.

    Values are drawn uniformly from [0, 1) (probability of a click) and are
    rounded to 0/1 when ``round_targets`` is set.
    """
    targets = ra.rand(n, num_targets).astype(np.float32)
    if round_targets:
        targets = np.round(targets).astype(np.float32)
    return torch.tensor(targets)
# uniform distribution (input data)
def generate_uniform_input_batch(
    m_den,
    ln_emb,
    n,
    num_indices_per_lookup,
    num_indices_per_lookup_fixed,
    length,
):
    """Generate one batch of uniformly distributed dense and sparse inputs.

    For every embedding table (one entry of ``ln_emb``) ``n`` lookups are
    generated, each made of unique indices in ``[0, size - 1]``. Depending
    on ``length`` the second result holds per-lookup lengths (caffe2
    version) or per-lookup start offsets.

    Returns ``(Xt, lS_emb_offsets, lS_emb_indices)``: the dense float32
    tensor of shape (n, m_den) and two lists with one tensor per table.
    """
    # dense feature
    dense = torch.tensor(ra.rand(n, m_den).astype(np.float32))

    # sparse feature (sparse indices), one entry per embedding table
    table_offsets = []
    table_indices = []

    for table_rows in ln_emb:
        markers = []
        flat_indices = []
        running = 0
        for _sample in range(n):
            # number of sparse indices requested for this lookup
            if not num_indices_per_lookup_fixed:
                # random between [1,num_indices_per_lookup])
                draw = ra.random(1)
                lookups = np.int64(
                    np.round(max([1.0], draw * min(table_rows, num_indices_per_lookup)))
                )
            else:
                lookups = np.int64(num_indices_per_lookup)

            # sparse indices to be used per embedding
            draw = ra.random(lookups)
            picked = np.unique(np.round(draw * (table_rows - 1)).astype(np.int64))
            # recount, duplicates may have been removed by np.unique
            picked_count = np.int32(picked.size)

            # store lengths (caffe2 version) or offsets, and the indices
            markers.append(picked_count if length else running)
            flat_indices.extend(picked.tolist())
            # update offset for next iteration
            running += picked_count

        table_offsets.append(torch.tensor(markers))
        table_indices.append(torch.tensor(flat_indices))

    return (dense, table_offsets, table_indices)
# random data from uniform or gaussian distribution (input data)
def generate_dist_input_batch(
    m_den,
    ln_emb,
    n,
    num_indices_per_lookup,
    num_indices_per_lookup_fixed,
    rand_data_dist,
    rand_data_min,
    rand_data_max,
    rand_data_mu,
    rand_data_sigma,
):
    """Generate one batch of dense and sparse inputs from a distribution.

    For every embedding table (one entry of ``ln_emb``) ``n`` lookups are
    generated and stored as start offsets plus a flat index list.

    Args:
        m_den: number of dense features.
        ln_emb: iterable with the number of rows of each embedding table.
        n: number of samples in the batch.
        num_indices_per_lookup: (maximum) number of indices per lookup.
        num_indices_per_lookup_fixed: use exactly ``num_indices_per_lookup``
            indices per lookup instead of a random number.
        rand_data_dist: "uniform" (indices in ``[0, size - 1]``) or
            "gaussian" (samples clipped to ``[rand_data_min, rand_data_max]``).
        rand_data_min: lower clipping bound of the gaussian samples.
        rand_data_max: upper clipping bound of the gaussian samples.
        rand_data_mu: mean of the gaussian; -1 selects the midpoint of
            ``rand_data_min`` and ``rand_data_max``.
        rand_data_sigma: standard deviation of the gaussian.

    Returns:
        ``(Xt, lS_emb_offsets, lS_emb_indices)``: the dense float32 tensor of
        shape (n, m_den) and two lists with one tensor per table.

    Raises:
        ValueError: if ``rand_data_dist`` is neither "uniform" nor
            "gaussian" (raised when the first lookup is generated).
    """
    # dense feature
    Xt = torch.tensor(ra.rand(n, m_den).astype(np.float32))

    # resolve the "-1 means midpoint" default once instead of per lookup
    if rand_data_dist == "gaussian" and rand_data_mu == -1:
        rand_data_mu = (rand_data_max + rand_data_min) / 2.0

    # sparse feature (sparse indices)
    lS_emb_offsets = []
    lS_emb_indices = []
    # for each embedding generate a list of n lookups,
    # where each lookup is composed of multiple sparse indices
    for size in ln_emb:
        lS_batch_offsets = []
        lS_batch_indices = []
        offset = 0
        for _ in range(n):
            # num of sparse indices to be used per embedding
            if num_indices_per_lookup_fixed:
                sparse_group_size = np.int64(num_indices_per_lookup)
            else:
                # random between [1,num_indices_per_lookup])
                r = ra.random(1)
                sparse_group_size = np.int64(
                    np.round(max([1.0], r * min(size, num_indices_per_lookup)))
                )
            # sparse indices to be used per embedding
            if rand_data_dist == "gaussian":
                r = ra.normal(rand_data_mu, rand_data_sigma, sparse_group_size)
                sparse_group = np.clip(r, rand_data_min, rand_data_max)
                # NOTE(review): uniqueness is established on the float
                # samples, so distinct floats that truncate to the same
                # integer remain as duplicate indices - confirm this is
                # intended before changing it.
                sparse_group = np.unique(sparse_group).astype(np.int64)
            elif rand_data_dist == "uniform":
                r = ra.random(sparse_group_size)
                sparse_group = np.unique(np.round(r * (size - 1)).astype(np.int64))
            else:
                # raising a tuple is itself a TypeError, so raise a real
                # exception that carries the intended message
                raise ValueError(
                    "%s distribution is not supported. "
                    "please select uniform or gaussian" % (rand_data_dist,)
                )

            # reset sparse_group_size in case some index duplicates were removed
            sparse_group_size = np.int64(sparse_group.size)
            # store offsets and indices
            lS_batch_offsets.append(offset)
            lS_batch_indices.extend(sparse_group.tolist())
            # update offset for next iteration
            offset += sparse_group_size
        lS_emb_offsets.append(torch.tensor(lS_batch_offsets))
        lS_emb_indices.append(torch.tensor(lS_batch_indices))

    return (Xt, lS_emb_offsets, lS_emb_indices)
# synthetic distribution (input data)
def generate_synthetic_input_batch(
m_den,
ln_emb,
n,
num_indices_per_lookup,
num_indices_per_lookup_fixed,
trace_file,
enable_padding=False,
):
# dense feature
Xt = torch.tensor(ra.rand(n, m_den).astype(np.float32))
# sparse feature (sparse indices)
lS_emb_offsets = []
lS_emb_indices = []
# for each embedding generate a list of n lookups,
# where each lookup is composed of multiple sparse indices
for i, size in enumerate(ln_emb):
lS_batch_offsets = []
lS_batch_indices = []
offset = 0
for _ in range(n):
# num of sparse indices to be used per embedding (between
if num_indices_per_lookup_fixed:
sparse_group_size = np.int64(num_indices_per_lookup)
else:
# random between [1,num_indices_per_lookup])
r = ra.random(1)
sparse_group_size = np.int64(
max(1, np.round(r * min(size, num_indices_per_lookup))[0])
)
# sparse indices to be used per embedding
file_path = trace_file
line_accesses, list_sd, cumm_sd = read_dist_from_file(
file_path.replace("j", str(i))
)
# debug prints