-
Notifications
You must be signed in to change notification settings - Fork 0
/
table_cell_merge_model.py
107 lines (89 loc) · 4.68 KB
/
table_cell_merge_model.py
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
import inspect
import torch
import torch.nn as nn
from model import GPTConfig, LayerNorm, Block
class TableCellMergeModel(nn.Module):
def __init__(self, config: GPTConfig, ckpt_state_dict=None):
super().__init__()
assert config.vocab_size is not None
assert config.block_size is not None
self.config = config
self.flatten = nn.Flatten()
self.transformer = nn.ModuleDict(dict(
wte=nn.Embedding(config.vocab_size, config.n_embd),
wpe=nn.Embedding(config.block_size, config.n_embd),
drop=nn.Dropout(config.dropout),
h=nn.ModuleList([Block(config) for _ in range(config.n_layer)]),
ln_f=LayerNorm(config.n_embd, bias=config.bias),
))
# binary classification
self.class_head = nn.Linear(config.n_embd * config.block_size, 1)
self.sigmoid = nn.Sigmoid()
# initialize model for transfer learning / finetuning
if ckpt_state_dict is not None:
unwanted_prefix = '_orig_mod.'
for k, v in list(ckpt_state_dict.items()):
if k.startswith(unwanted_prefix):
ckpt_state_dict[k[len(unwanted_prefix):]] = ckpt_state_dict.pop(k)
ckpt_state_dict.pop('lm_head.weight', None)
torch.nn.init.normal_(self.class_head.weight, mean=0.0, std=0.02)
torch.nn.init.zeros_(self.class_head.bias)
model_dict = self.state_dict()
model_dict.update(ckpt_state_dict)
self.load_state_dict(model_dict)
# report number of parameters
print("number of parameters: %.2fM" % (self.get_num_params() / 1e6,))
def get_num_params(self, non_embedding=True):
"""
Return the number of parameters in the model.
For non-embedding count (default), the position embeddings get subtracted.
The token embeddings would too, except due to the parameter sharing these
params are actually used as weights in the final layer, so we include them.
"""
n_params = sum(p.numel() for p in self.parameters())
if non_embedding:
n_params -= self.transformer.wpe.weight.numel()
return n_params
def forward(self, idx):
device = idx.device
b, t = idx.size()
assert t <= self.config.block_size, f"Cannot forward sequence of length {t}, block size is only {self.config.block_size}"
pos = torch.arange(0, t, dtype=torch.long, device=device) # shape (t)
# forward the GPT model itself
tok_emb = self.transformer.wte(idx) # token embeddings of shape (b, t, n_embd)
pos_emb = self.transformer.wpe(pos) # position embeddings of shape (t, n_embd)
x = self.transformer.drop(tok_emb + pos_emb)
for block in self.transformer.h:
x = block(x)
x = self.transformer.ln_f(x)
# print(f"forward x: {x.shape}")
x = self.flatten(x)
# print(f"forward flatten x: {x.shape}")
logits = self.class_head(x)
# print(f"logits: {logits.shape}")
y = self.sigmoid(logits)
return y
def configure_optimizers(self, weight_decay, learning_rate, betas, device_type):
# start with all the candidate parameters
param_dict = {pn: p for pn, p in self.named_parameters()}
# filter out those that do not require grad
param_dict = {pn: p for pn, p in param_dict.items() if p.requires_grad}
# create optim groups. Any parameters that is 2D will be weight decayed, otherwise no.
# i.e. all weight tensors in matmuls + embeddings decay, all biases and layernorms don't.
decay_params = [p for n, p in param_dict.items() if p.dim() >= 2]
nodecay_params = [p for n, p in param_dict.items() if p.dim() < 2]
optim_groups = [
{'params': decay_params, 'weight_decay': weight_decay},
{'params': nodecay_params, 'weight_decay': 0.0}
]
num_decay_params = sum(p.numel() for p in decay_params)
num_nodecay_params = sum(p.numel() for p in nodecay_params)
print(f"num decayed parameter tensors: {len(decay_params)}, with {num_decay_params:,} parameters")
print(f"num non-decayed parameter tensors: {len(nodecay_params)}, with {num_nodecay_params:,} parameters")
# Create AdamW optimizer and use the fused version if it is available
fused_available = 'fused' in inspect.signature(torch.optim.AdamW).parameters
use_fused = fused_available and device_type == 'cuda'
extra_args = dict(fused=True) if use_fused else dict()
optimizer = torch.optim.AdamW(optim_groups, lr=learning_rate, betas=betas, **extra_args)
print(f"using fused AdamW: {use_fused}")
return optimizer