generate.py
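"""Mask filling with a JAX implementation of BART-base.

Pipeline: tokenise the masked sentences, build the encoder self-attention
mask, run the encoder once via fwd_encode, then beam-search decode with the
Generator and print the completed sentences.
"""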
import jax
import jax.numpy as np

from lib.Generator import Generator
from lib.model import fwd_embedding, fwd_layer_norm, fwd_transformer_encoder
from lib.param_utils.load_params import load_params
from lib.tokeniser import BartTokenizerWithoutOverflowEOS


def fwd_encode(params: dict, src: np.ndarray, mask_enc: np.ndarray) -> np.ndarray:
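    """Run the BART encoder: embed `src`, add learned positional embeddings,
    layer-normalise, then apply each transformer encoder layer; returns the
    encoder's last hidden state (one vector per source token)."""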
    # unpack pretrained parameters
    embedding: dict = params['embedding']  # token embedding table
    encoder_embed_positions: np.ndarray = params['encoder_embed_positions']  # learned positional embeddings
    encoder_embed_layer_norm: dict = params['encoder_embed_layer_norm']  # layer norm after embedding
    encoder_layers: list = params['encoder_layers']  # transformer encoder layers

    _, width_enc = src.shape
    # fairseq (and hence BART) reserves the first two positional-embedding
    # slots, so position lookups are shifted by an offset of 2
    offset = 2

    # encoder forward pass
    src = fwd_embedding(embedding, src)
    src = src + encoder_embed_positions[offset:width_enc + offset]
    src = fwd_layer_norm(encoder_embed_layer_norm, src)
    for encoder_layer in encoder_layers:
        src = fwd_transformer_encoder(encoder_layer, src, mask_enc)
    return src
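

# --- demo: fill the masked spans in two example sentences ---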
tokenizer = BartTokenizerWithoutOverflowEOS.from_pretrained('facebook/bart-base')
sentences = ['Can you see the beautiful flowers <mask> alongside the track?', 'Upon graduation, <mask> of herself.']
batch = tokenizer(sentences, padding=True, return_tensors='jax')
src = batch.input_ids
mask_enc_1d = batch.attention_mask.astype(np.bool_)
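# outer product of the 1-D padding masks gives a (batch, src_len, src_len) mask
# that is True only where both query and key positions are real tokens; the
# [:, None] below adds a broadcast axis for the attention heads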
mask_enc = np.einsum('bi,bj->bij', mask_enc_1d, mask_enc_1d)[:, None]

params = load_params('params_bart_base_en.dat')
params = jax.tree_util.tree_map(np.asarray, params)  # NumPy -> jax.numpy arrays; jax.tree_util avoids the deprecated jax.tree_map alias
encoder_last_hidden_output = fwd_encode(params, src, mask_enc)
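# optional sanity check against the reference encoder (a hedged sketch: assumes
# the `transformers` Flax port is installed and that params_bart_base_en.dat
# matches facebook/bart-base):
#
#     from transformers import FlaxBartModel
#     model = FlaxBartModel.from_pretrained('facebook/bart-base')
#     ref = model.encode(input_ids=src, attention_mask=batch.attention_mask)
#     assert np.allclose(encoder_last_hidden_output, ref.last_hidden_state, atol=1e-3)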
generator = Generator(params)
generate_ids = generator.generate(encoder_last_hidden_output, mask_enc_1d, num_beams=5)
decoded_sentences = tokenizer.batch_decode(generate_ids, skip_special_tokens=True, clean_up_tokenization_spaces=False)
print(decoded_sentences)