vocabulary.py
# coding=utf-8
"""
Vocabulary helper class
"""
import re

import numpy as np


# contains the data structure
class Vocabulary:
    """Stores the tokens and their conversion to vocabulary indexes."""

    def __init__(self, tokens=None, starting_id=0):
        self._tokens = {}
        self._current_id = starting_id

        if tokens:
            for token, idx in tokens.items():
                self._add(token, idx)
                self._current_id = max(self._current_id, idx + 1)

    def __getitem__(self, token_or_id):
        return self._tokens[token_or_id]

    def add(self, token):
        """Adds a token."""
        if not isinstance(token, str):
            raise TypeError("Token is not a string")
        if token in self:
            return self[token]
        self._add(token, self._current_id)
        self._current_id += 1
        return self._current_id - 1

    def update(self, tokens):
        """Adds many tokens."""
        return [self.add(token) for token in tokens]

    def __delitem__(self, token_or_id):
        other_val = self._tokens[token_or_id]
        del self._tokens[other_val]
        del self._tokens[token_or_id]

    def __contains__(self, token_or_id):
        return token_or_id in self._tokens

    def __eq__(self, other_vocabulary):
        return self._tokens == other_vocabulary._tokens  # pylint: disable=W0212

    def __len__(self):
        return len(self._tokens) // 2

    def encode(self, tokens):
        """Encodes a list of tokens as vocabulary indexes."""
        vocab_index = np.zeros(len(tokens), dtype=np.float32)
        for i, token in enumerate(tokens):
            vocab_index[i] = self._tokens[token]
        return vocab_index

    def decode(self, vocab_index):
        """Decodes a vocabulary index matrix to a list of tokens."""
        tokens = []
        for idx in vocab_index:
            tokens.append(self[idx])
        return tokens

    def _add(self, token, idx):
        if idx not in self._tokens:
            self._tokens[token] = idx
            self._tokens[idx] = token
        else:
            raise ValueError("IDX already present in vocabulary")

    def tokens(self):
        """Returns the tokens from the vocabulary"""
        return [t for t in self._tokens if isinstance(t, str)]
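

# Illustrative note (not part of the original file): Vocabulary stores a single
# dict holding both token -> index and index -> token entries, which is why
# lookups work in either direction and __len__ halves the dict size. A small
# sketch of the expected behaviour:
#
#   vocab = Vocabulary()
#   vocab.update(["C", "O"])   # -> [0, 1]
#   vocab["C"]                 # -> 0
#   vocab[1]                   # -> "O"
#   len(vocab)                 # -> 2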

class SMILESTokenizer:
    """Deals with the tokenization and untokenization of SMILES."""

    REGEXPS = {
        "brackets": re.compile(r"(\[[^\]]*\])"),
        "2_ring_nums": re.compile(r"(%\d{2})"),
        "brcl": re.compile(r"(Br|Cl)")
    }
    REGEXP_ORDER = ["brackets", "2_ring_nums", "brcl"]

    def tokenize(self, data, with_begin_and_end=True):
        """Tokenizes a SMILES string."""
        def split_by(data, regexps):
            if not regexps:
                return list(data)
            regexp = self.REGEXPS[regexps[0]]
            splitted = regexp.split(data)
            tokens = []
            for i, split in enumerate(splitted):
                if i % 2 == 0:
                    tokens += split_by(split, regexps[1:])
                else:
                    tokens.append(split)
            return tokens

        tokens = split_by(data, self.REGEXP_ORDER)
        if with_begin_and_end:
            tokens = ["$"] + tokens + ["^"]
        return tokens

    def untokenize(self, tokens):
        """Untokenizes a SMILES string."""
        smi = ""
        for token in tokens:
            if token == "^":
                break
            if token != "$":
                smi += token
        return smi
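

# Illustrative note (not part of the original file): tokenize() first protects
# the multi-character tokens named in REGEXP_ORDER -- bracketed atoms, "%NN"
# ring closures and the two-letter halogens Br/Cl -- and only then splits the
# remaining text character by character. For example:
#
#   SMILESTokenizer().tokenize("C%12Cl[nH]")
#   # -> ['$', 'C', '%12', 'Cl', '[nH]', '^']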

def create_vocabulary(smiles_list, tokenizer):
    """Creates a vocabulary for the SMILES syntax."""
    tokens = set()
    for smi in smiles_list:
        tokens.update(tokenizer.tokenize(smi, with_begin_and_end=False))

    vocabulary = Vocabulary()
    vocabulary.update(["<pad>", "$", "^"] + sorted(tokens))
    return vocabulary
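

# Usage sketch (illustrative only; the SMILES strings below are arbitrary
# examples, not part of the original module): build a vocabulary from a few
# molecules and round-trip one of them through encode/decode.
if __name__ == "__main__":
    smiles = ["CCO", "c1ccccc1", "CC(=O)Nc1ccc(O)cc1"]

    tokenizer = SMILESTokenizer()
    vocab = create_vocabulary(smiles, tokenizer)

    tokens = tokenizer.tokenize(smiles[0])    # ['$', 'C', 'C', 'O', '^']
    encoded = vocab.encode(tokens)            # float32 array of vocabulary indexes
    decoded = tokenizer.untokenize(vocab.decode(encoded))
    assert decoded == smiles[0]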