# basic_bigramLM.py
import torch
import torch.nn as nn
import torch.nn.functional as F
from dataclasses import dataclass

torch.manual_seed(1357)

@dataclass
class Config:
    block_size: int = 8   # context length
    batch_size: int = 32  # mini-batch size

with open('./dataset/shakespeare.txt', 'r', encoding='utf-8') as f:
    data = f.read()

class CharacterLevelTokenizer:
    def __init__(self, data):
        self.data = data
        self.vocab = sorted(list(set(self.data)))  # unique characters in the corpus
        self.VOCAB_SIZE = len(self.vocab)
        self.i_s = {i: s for i, s in enumerate(self.vocab)}  # index -> char
        self.s_i = {s: i for i, s in self.i_s.items()}       # char -> index

    def encode(self, s):
        return torch.tensor([self.s_i[c] for c in s], dtype=torch.long)

    def decode(self, s):
        return ''.join([self.i_s[i.item()] for i in s])

tokenizer = CharacterLevelTokenizer(data)
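# Illustrative sanity check (a minimal sketch, assuming every character of the
# sample string appears in the corpus): encode followed by decode should round-trip.
sample = 'hello world'
assert tokenizer.decode(tokenizer.encode(sample)) == sample
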
class ShakespeareDataset:
    def __init__(self, block_size: int, is_test=False) -> None:
        self.tokenizer = CharacterLevelTokenizer(data)
        self.is_test = is_test
        self.full_data = self.tokenizer.encode(self.tokenizer.data)
        # the last 10% of the corpus is held out as the validation split
        if self.is_test:
            self.data = self.full_data[int(0.9 * len(self.full_data)):]
        else:
            self.data = self.full_data[:int(0.9 * len(self.full_data))]
        self.block_size = block_size

    def __len__(self) -> int:
        # stop block_size early so every item spans a full block plus its target
        return len(self.data) - self.block_size

    def get_block_size(self) -> int:
        return self.block_size

    def get_vocab_size(self) -> int:
        return self.tokenizer.VOCAB_SIZE

    def __getitem__(self, idx):
        item = self.data[idx:idx + self.block_size + 1]
        x = item[:-1]  # context
        y = item[1:]   # target: the same block shifted one character ahead
        return x, y

train_ds = ShakespeareDataset(Config.block_size)
val_ds = ShakespeareDataset(Config.block_size, is_test=True)
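# Quick look at one training example (illustrative only): y is x shifted one
# character to the left, so position t in x is trained to predict position t in y.
xb, yb = train_ds[0]
print('x:', repr(tokenizer.decode(xb)), '-> y:', repr(tokenizer.decode(yb)))
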
class BigramLM(nn.Module):
    def __init__(self, vocab_size):
        super().__init__()
        # each token's row in the embedding table holds the logits for the next token
        self.token_embedding_table = nn.Embedding(vocab_size, vocab_size)

    def forward(self, idx, targets=None):
        logits = self.token_embedding_table(idx)  # (B,T,C) with C = vocab_size
        if targets is None:
            loss = None
        else:
            # F.cross_entropy expects the class dimension second, so flatten
            # the (B,T,C) logits to (B*T,C) and the (B,T) targets to (B*T,)
            B, T, C = logits.shape
            logits = logits.view(B * T, C)
            targets = targets.view(B * T)
            loss = F.cross_entropy(logits, targets)
        return logits, loss

    def generate(self, idx, total):
        # idx is a (B,T) tensor of indices in the current context
        for _ in range(total):
            logits, _ = self(idx)
            # the prediction for the next character sits at the last position in T
            logits = logits[:, -1, :]  # (B,T,C) -> (B,C)
            probs = F.softmax(logits, dim=-1)
            idx_next = torch.multinomial(probs, num_samples=1)
            idx = torch.cat([idx, idx_next], dim=1)  # (B, T+1)
        return idx

bglm = BigramLM(tokenizer.VOCAB_SIZE)
optim = torch.optim.AdamW(bglm.parameters(), lr=1e-3)
bglm_dl = torch.utils.data.DataLoader(train_ds, shuffle=False, batch_size=Config.batch_size)

it = iter(bglm_dl)
for steps in range(25_000):
    inputs, targets = next(it)  # sequential (non-shuffled) batches from the training split
    logits, loss = bglm(inputs, targets)
    optim.zero_grad()
    loss.backward()
    optim.step()
    if steps % 2500 == 0:
        print(f'step: {steps} loss: {loss.item()}')
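
# Rough validation check (a minimal sketch): average the loss over a handful of
# batches from the held-out split; the 200-batch cap is an arbitrary choice.
val_dl = torch.utils.data.DataLoader(val_ds, shuffle=False, batch_size=Config.batch_size)
with torch.no_grad():
    val_losses = []
    for i, (vx, vy) in enumerate(val_dl):
        if i >= 200:
            break
        _, val_loss = bglm(vx, vy)
        val_losses.append(val_loss.item())
print(f'val loss (approx): {sum(val_losses) / len(val_losses)}')
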
generated = bglm.generate(
    torch.zeros((1, 1), dtype=torch.long),  # initial context: a single token with index 0
    total=500
)
generated = tokenizer.decode(generated[0])
print('generated (500 tokens) >>>\n', generated)