"""Log LoRA adapter weight and gradient distributions as text histograms.

Uses the `fewlines` library to render per-layer histograms of the low-rank
A and B matrices (and their gradients) for the Q and V projection adapters.
"""
import logging

import torch


def log_lora(lora_layers, log_weights=True, log_grad=True, log_level=logging.INFO):
    """Log histograms of LoRA A/B weights and/or gradients for each layer.

    `lora_layers` is an iterable of dicts with 'q_lora' and 'v_lora' entries,
    each holding a module with low-rank linear layers `A` and `B`. When
    `log_grad` is True, call this after a backward pass so `.grad` is populated.
    """
    if not log_weights and not log_grad:
        return
    try:
        from fewlines import bar
    except ImportError:
        logging.error('Unable to import fewlines. "pip install fewlines" to use distribution logging')
        return

    gradients_a = {}
    gradients_b = {}
    weights_a = {}
    weights_b = {}

    for i, lora in enumerate(lora_layers):
        q = lora['q_lora']
        v = lora['v_lora']
        if log_grad:
            # Flatten each gradient tensor to a float32 list for histogramming.
            gradients_a[f'Q{i}.A'] = q.A.weight.grad.view(-1).to(torch.float32).tolist()
            gradients_b[f'Q{i}.B'] = q.B.weight.grad.view(-1).to(torch.float32).tolist()
            gradients_a[f'V{i}.A'] = v.A.weight.grad.view(-1).to(torch.float32).tolist()
            gradients_b[f'V{i}.B'] = v.B.weight.grad.view(-1).to(torch.float32).tolist()
        if log_weights:
            weights_a[f'Q{i}.A'] = q.A.weight.view(-1).to(torch.float32).tolist()
            weights_b[f'Q{i}.B'] = q.B.weight.view(-1).to(torch.float32).tolist()
            weights_a[f'V{i}.A'] = v.A.weight.view(-1).to(torch.float32).tolist()
            weights_b[f'V{i}.B'] = v.B.weight.view(-1).to(torch.float32).tolist()

    if log_grad:
        logging.log(log_level, '\n=== GRADIENTS A ===')
        for line in bar.bar_histograms(gradients_a):
            logging.log(log_level, line)
        logging.log(log_level, '\n=== GRADIENTS B ===')
        for line in bar.bar_histograms(gradients_b):
            logging.log(log_level, line)
    if log_weights:
        logging.log(log_level, '\n=== WEIGHTS A ===')
        for line in bar.bar_histograms(weights_a):
            logging.log(log_level, line)
        logging.log(log_level, '\n=== WEIGHTS B ===')
        for line in bar.bar_histograms(weights_b):
            logging.log(log_level, line)
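

# A minimal, hypothetical usage sketch (not part of the original file). It builds
# toy adapters with the attribute layout log_lora expects -- each entry in
# lora_layers is a dict with 'q_lora' and 'v_lora' modules exposing low-rank
# linear layers `A` and `B` -- runs one backward pass so `.grad` is populated,
# then logs the histograms. The `LoRALinear` class, shapes, and loss here are
# illustrative assumptions only.
if __name__ == '__main__':
    import torch.nn as nn

    logging.basicConfig(level=logging.INFO, format='%(message)s')

    class LoRALinear(nn.Module):
        # Hypothetical low-rank adapter: y = B(A(x)) with rank-r factors A and B.
        def __init__(self, dim=16, rank=4):
            super().__init__()
            self.A = nn.Linear(dim, rank, bias=False)
            self.B = nn.Linear(rank, dim, bias=False)

        def forward(self, x):
            return self.B(self.A(x))

    lora_layers = [{'q_lora': LoRALinear(), 'v_lora': LoRALinear()} for _ in range(2)]

    # One dummy forward/backward pass so that every weight.grad is populated.
    x = torch.randn(8, 16)
    loss = sum((lora['q_lora'](x) + lora['v_lora'](x)).pow(2).mean() for lora in lora_layers)
    loss.backward()

    log_lora(lora_layers, log_weights=True, log_grad=True)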