-
Notifications
You must be signed in to change notification settings - Fork 0
/
test_distributed_sigmoid_loss.py
148 lines (117 loc) · 4.62 KB
/
test_distributed_sigmoid_loss.py
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
import os
import sys
import torch
import random
import numpy as np
import torch.nn as nn
import torch.nn.functional as F
import torch.distributed as dist
import torch.multiprocessing as mp
from distributed_sigmoid_loss import DDPSigmoidLoss
def set_seed(s, reproducible=False):
    """Seed the `random`, `torch` and `numpy` generators with `s`.

    Args:
        s: Integer seed applied to every generator.
        reproducible: When true, additionally switch cuDNN to deterministic
            mode and turn off its benchmark autotuner.
    """
    # Each seeding call is best-effort: a library whose name is not bound in
    # this module raises NameError, which is skipped on purpose.
    seeders = (
        lambda: torch.manual_seed(s),
        lambda: torch.cuda.manual_seed_all(s),
        # The seed is folded with modulus 2**32 - 1 before being given to NumPy.
        lambda: np.random.seed(s % (2 ** 32 - 1)),
    )
    for apply_seed in seeders:
        try:
            apply_seed()
        except NameError:
            pass

    random.seed(s)

    if not reproducible:
        return
    torch.backends.cudnn.deterministic = True
    torch.backends.cudnn.benchmark = False
def setup(rank, world_size):
    """Join the default ``gloo`` process group as worker `rank`.

    Args:
        rank: Index of the calling process within the group.
        world_size: Total number of processes taking part.
    """
    if sys.platform != 'win32':
        # Rendezvous through environment variables on a fixed local port.
        os.environ.update(MASTER_ADDR='localhost', MASTER_PORT='12355')
        dist.init_process_group("gloo", rank=rank, world_size=world_size)
        return

    # On Windows the distributed package only covers collective communication
    # with the Gloo backend and a FileStore, so rendezvous goes through a
    # local file, e.g. init_method="file:///f:/libtmp/some_file".
    # NOTE: the value below is a placeholder that has to be replaced with a
    # real local file path before this branch can be used.
    init_method = "file:///{your local file path}"
    dist.init_process_group(
        "gloo",
        init_method=init_method,
        rank=rank,
        world_size=world_size,
    )
def cleanup():
    """Tear down the default process group created for this worker."""
    dist.destroy_process_group()
def get_partition(rank, world_size, gpu_batch_size, emb_dim):
    """Return this rank's rows of the seeded global image/text inputs.

    The full ``(world_size * gpu_batch_size, emb_dim)`` tensors are generated
    after reseeding (42 for images, 40 for texts), so the generated values do
    not depend on `rank`; only the returned row window does.

    Args:
        rank: Index of the calling process; selects which rows are returned.
        world_size: Number of processes the global batch is divided between.
        gpu_batch_size: Number of rows handed to each process.
        emb_dim: Width of every input row.

    Returns:
        Tuple ``(image_inputs, text_inputs)``, each holding `gpu_batch_size`
        rows starting at ``rank * gpu_batch_size``.
    """
    global_shape = (world_size * gpu_batch_size, emb_dim)

    set_seed(42)
    all_images = torch.randn(*global_shape)
    set_seed(40)
    all_texts = torch.randn(*global_shape)

    first_row = rank * gpu_batch_size
    own_rows = slice(first_row, first_row + gpu_batch_size)
    return all_images[own_rows], all_texts[own_rows]
def get_encoders(emb_dim, output_dim=2):
    """Build the bias-free linear image and text encoders.

    The generators are reseeded with 42 immediately before each layer is
    constructed, so both layers are initialised from the same RNG state.

    Args:
        emb_dim: Input feature width of both encoders.
        output_dim: Output feature width of both encoders.

    Returns:
        Tuple ``(image_encoder, text_encoder)`` of ``nn.Linear`` modules.
    """
    def _seeded_linear():
        set_seed(42)
        return nn.Linear(emb_dim, output_dim, bias=False)

    image_encoder = _seeded_linear()
    text_encoder = _seeded_linear()
    return image_encoder, text_encoder
def average_gradients(model):
    """Replace every parameter gradient of `model` by its mean over all ranks.

    Each gradient is summed across the process group with ``all_reduce`` and
    then divided in place by the world size. Every parameter must already
    have a gradient when this is called.

    Args:
        model: Module whose parameter gradients are averaged in place.
    """
    num_ranks = float(dist.get_world_size())
    for grad in (p.grad.data for p in model.parameters()):
        dist.all_reduce(grad, op=dist.ReduceOp.SUM)
        grad.div_(num_ranks)
def toy_forward_backward_pass(rank, world_size, bz, emb_dim=2, return_dict=None):
    """Run one toy forward/backward step on worker `rank` and average the grads.

    The worker joins the process group, takes its rows of the seeded global
    batch, encodes them with two linear encoders, L2-normalises the
    embeddings, evaluates ``DDPSigmoidLoss`` and back-propagates. Gradients of
    both encoders are then averaged over all ranks. Rank 0 publishes the
    averaged weight gradients through `return_dict`.

    Args:
        rank: Index of the calling process (supplied by ``mp.spawn``).
        world_size: Total number of participating processes.
        bz: Global batch size; must be divisible by `world_size`.
        emb_dim: Input feature width of the toy encoders.
        return_dict: Optional mapping that receives ``'img_grad'`` and
            ``'txt_grad'`` from rank 0. Nothing is stored when it is ``None``.

    Raises:
        ValueError: If `bz` is not a multiple of `world_size`.
    """
    # Validate before touching the process group so that a bad configuration
    # fails immediately (and still fails when Python runs with -O).
    if bz % world_size != 0:
        raise ValueError(
            f"bz ({bz}) must be divisible by world_size ({world_size})"
        )
    gpu_batch_size = bz // world_size

    setup(rank, world_size)
    try:
        image_inputs, text_inputs = get_partition(rank, world_size, gpu_batch_size, emb_dim)
        image_encoder, text_encoder = get_encoders(emb_dim)

        # Toy forward: compute embeddings and L2-normalise them.
        image_embeddings = F.normalize(image_encoder(image_inputs))
        text_embeddings = F.normalize(text_encoder(text_inputs))

        # Loss implementation lives in distributed_sigmoid_loss (not shown here).
        loss = DDPSigmoidLoss(gpu_batch_size)(image_embeddings, text_embeddings)

        # Toy backward: local gradients, then the mean over all ranks.
        # Every rank must issue these collectives in the same order.
        loss.backward()
        average_gradients(text_encoder)
        average_gradients(image_encoder)

        # Only rank 0 reports, and only when the caller asked for results.
        if rank == 0 and return_dict is not None:
            return_dict['img_grad'] = image_encoder.weight.grad
            return_dict['txt_grad'] = text_encoder.weight.grad
    finally:
        # Always release the process group that setup() created.
        cleanup()
def _collect_rank0_grads(manager, world_size, batch_size, emb_dim):
    """Spawn `world_size` toy workers and return rank 0's (img, txt) gradients."""
    results = manager.dict()
    mp.spawn(
        toy_forward_backward_pass,
        args=(world_size, batch_size, emb_dim, results),
        nprocs=world_size,
        join=True,
    )
    return results['img_grad'], results['txt_grad']


def test_same_gradient(emb_dim=2, world_size=2, batch_size=4):
    """Check that multi-process gradients match the single-process ones.

    The toy forward/backward pass is run once on `world_size` processes and
    once on a single process with the same global batch; the averaged encoder
    weight gradients reported by rank 0 must agree within ``rtol=1e-3``.

    Args:
        emb_dim: Input feature width of the toy encoders.
        world_size: Number of processes for the distributed run.
        batch_size: Global batch size shared by both runs.

    Raises:
        AssertionError: If the image or text gradients differ between runs.
    """
    # The context manager shuts the manager's server process down on exit;
    # the comparisons stay inside the block while that process is alive.
    with mp.Manager() as manager:
        multi_img, multi_txt = _collect_rank0_grads(
            manager, world_size, batch_size, emb_dim
        )
        single_img, single_txt = _collect_rank0_grads(
            manager, 1, batch_size, emb_dim
        )
        assert torch.allclose(multi_img, single_img, rtol=1e-3)
        assert torch.allclose(multi_txt, single_txt, rtol=1e-3)
if __name__ == "__main__":
    # Configurations are run in this order; unspecified arguments keep the
    # defaults of test_same_gradient.
    for run_config in (
        {"world_size": 3, "batch_size": 3},
        {"world_size": 2, "batch_size": 4},
        {"world_size": 2, "batch_size": 4, "emb_dim": 128},
        {"world_size": 2, "batch_size": 4, "emb_dim": 512},
    ):
        test_same_gradient(**run_config)