"""Text-to-tensor input transforms: Longformer, hierarchical BERT, and TF-IDF."""
import os
import re

import nltk
import torch
from nltk.corpus import stopwords
from sklearn.feature_extraction.text import TfidfVectorizer
from transformers import AutoTokenizer

from data import DATA_DIR


def initialize_transform(transform_name, config):
    """Return the input transform registered under `transform_name`, or None."""
    if transform_name is None:
        return None
    if transform_name == 'hier-bert':
        return initialize_hierbert_transform(config)
    elif transform_name == 'tf-idf':
        return initialize_tfidf(config)
    else:
        raise ValueError(f"{transform_name} not recognized")


def initialize_bert_transform(config):
    """Longformer-style transform: one long sequence plus a global attention mask."""
    assert 'longformer' in config.model
    assert config.max_token_length is not None
    tokenizer = AutoTokenizer.from_pretrained(config.model)

    def transform(text):
        tokens = tokenizer(
            text,
            padding='max_length',
            truncation=True,
            max_length=config.max_token_length,
            return_tensors='pt')
        # Global attention on the [CLS] token ...
        global_attention_mask = torch.zeros_like(tokens['input_ids'])
        global_attention_mask[:, 0] = 1
        # ... and on every [SEP] token.
        global_attention_mask += (tokens['input_ids'] == tokenizer.sep_token_id).int()
        # Pack input ids, attention mask and global attention mask into a single
        # tensor of shape (max_token_length, 3).
        x = torch.stack(
            (tokens['input_ids'],
             tokens['attention_mask'],
             global_attention_mask),
            dim=2)
        x = torch.squeeze(x, dim=0)
        return x

    return transform
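
# A minimal usage sketch for the Longformer transform above, kept as a comment so
# it does not run on import. The config values are hypothetical; the real config
# object comes from the experiment scripts, which are not part of this file.
#
#     config.model = 'allenai/longformer-base-4096'
#     config.max_token_length = 4096
#     transform = initialize_bert_transform(config)
#     x = transform('Some document text ...')
#     # x has shape (max_token_length, 3):
#     # x[:, 0] = input ids, x[:, 1] = attention mask, x[:, 2] = global attention mask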


def initialize_hierbert_transform(config):
    """Hierarchical transform: the document is split into paragraphs on '</s>' and
    each paragraph is encoded as one fixed-length segment."""
    assert 'bert' in config.model or 'minilm' in config.model
    assert config.max_segment_length is not None
    assert config.max_segments is not None
    tokenizer = AutoTokenizer.from_pretrained(config.model)

    def transform(text):
        # Pre-allocate (max_segments, max_segment_length) tensors; unused trailing
        # segments stay zero-padded.
        paragraphs_tokens = {'input_ids': torch.zeros(config.max_segments, config.max_segment_length,
                                                      dtype=torch.int32),
                             'attention_mask': torch.zeros(config.max_segments, config.max_segment_length,
                                                           dtype=torch.int32)}
        # Keep at most the first `max_segments` paragraphs.
        paragraphs = []
        for paragraph in text.split('</s>')[:config.max_segments]:
            paragraphs.append(paragraph)
        tokens = tokenizer(
            paragraphs,
            padding='max_length',
            truncation=True,
            max_length=config.max_segment_length,
            return_tensors='pt')
        paragraphs_tokens['input_ids'][:len(paragraphs)] = tokens['input_ids']
        paragraphs_tokens['attention_mask'][:len(paragraphs)] = tokens['attention_mask']
        # Pack input ids and attention mask into a single tensor of shape
        # (max_segments, max_segment_length, 2).
        x = torch.stack(
            (paragraphs_tokens['input_ids'],
             paragraphs_tokens['attention_mask']),
            dim=2)
        return x.to(torch.long)

    return transform
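
# A minimal usage sketch for the hierarchical transform above (comment only; the
# config values are hypothetical, and the input document is expected to join its
# paragraphs with the '</s>' separator):
#
#     config.model = 'bert-base-uncased'
#     config.max_segments = 8
#     config.max_segment_length = 128
#     transform = initialize_hierbert_transform(config)
#     x = transform('First paragraph.</s>Second paragraph.')
#     # x has shape (max_segments, max_segment_length, 2):
#     # x[..., 0] = input ids, x[..., 1] = attention mask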


def initialize_tfidf(config):
    """TF-IDF transform: word n-gram (1-3) features fitted on a dataset-specific text dump."""

    def preprocess_text(text: str):
        # Replace digit runs with whitespace before vectorization.
        return re.sub('[0-9]+', ' ', text)

    def tokenize(text: str):
        if config.dataset in ['ecthr', 'scotus']:
            return nltk.word_tokenize(text)
        elif config.dataset == 'cail':
            return nltk.word_tokenize(text)
        elif config.dataset == 'fscs':
            return nltk.word_tokenize(text, language='german')

    # Dataset-specific stop-word lists (None means no stop-word filtering).
    if config.dataset in ['ecthr', 'scotus']:
        stop_words = set(stopwords.words('english'))
    elif config.dataset == 'cail':
        stop_words = None
    elif config.dataset == 'fscs':
        stop_words = set(stopwords.words('german') + stopwords.words('french') + stopwords.words('italian'))

    vectorizer = TfidfVectorizer(ngram_range=(1, 3), max_features=5000, stop_words=stop_words,
                                 preprocessor=preprocess_text, lowercase=False, tokenizer=tokenize, min_df=5)
    # Fit the vectorizer on the pre-extracted text dump for this dataset.
    with open(os.path.join(DATA_DIR, 'datasets', f'{config.dataset}_v1.0', f'{config.dataset}_dump.txt')) as file:
        vectorizer.fit(file.readlines())

    def transform(text):
        # Flatten paragraph markers, truncate to max_token_length words, and vectorize.
        text = ' '.join(text.replace('</s>', ' ').split()[:config.max_token_length])
        x = torch.as_tensor(vectorizer.transform([text]).todense()).float()
        x = torch.squeeze(x, dim=0)
        return x

    return transform
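

# Minimal end-to-end sketch (not part of the original module), guarded so it only
# runs when the file is executed directly. The SimpleNamespace stands in for the
# real experiment config, whose attribute names mirror the asserts above; the
# actual config object comes from the training scripts, which are not shown here.
if __name__ == '__main__':
    from types import SimpleNamespace

    demo_config = SimpleNamespace(
        model='bert-base-uncased',   # must contain 'bert' for the hier-bert transform
        dataset='ecthr',
        max_segments=8,
        max_segment_length=128,
        max_token_length=4096,
    )
    transform = initialize_transform('hier-bert', demo_config)
    x = transform('First paragraph.</s>Second paragraph.')
    # Expected shape: (max_segments, max_segment_length, 2)
    print(x.shape)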