-
Notifications
You must be signed in to change notification settings - Fork 0
/
_temp_query_strategies.py
60 lines (51 loc) · 1.78 KB
/
_temp_query_strategies.py
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
import os
import numpy as np
from random import randint
from spacy.util import fix_random_seed
fix_random_seed(42)
def query_random(examples, exclude, n_instances):
"""Random sampling strategy."""
n_queried = 0
max_idx = len(examples) - 1
_inner_exclude = set(exclude)
while n_queried < n_instances:
idx = randint(0, max_idx)
if idx not in _inner_exclude:
_inner_exclude.add(idx)
n_queried += 1
yield idx, examples[idx]
def query_least_confidence(nlp, included_components, examples,
exclude, n_instances, spans_key):
"""Least confidence sampling strategy for multilabeled data for spaCy's
SpanCategorizer. Based on mean score of all labels for each span.
"""
def _get_least_confident():
for _idx in indexes_by_score:
include = include_arr[_idx]
if include:
return _idx
disabled_comps = set(nlp.pipe_names) - set(included_components)
texts = [
example.text
for i, example in enumerate(examples)
if i not in exclude
]
n_process = os.cpu_count()
predictions = nlp.pipe(texts, disable=disabled_comps, n_process=n_process)
ex_len = len(examples)
include_arr = np.repeat(True, ex_len)
for ex in exclude:
include_arr[ex] = False
scores = np.zeros(ex_len)
for i, pred in enumerate(predictions):
spans = pred.spans[spans_key]
if spans:
scores[i] = spans.attrs["scores"].mean()
indexes_by_score = np.argsort(scores) # ascending order
n_queried = 0
while n_queried < n_instances:
idx = _get_least_confident()
include_arr[idx] = False
example = examples[idx]
n_queried += 1
yield idx, example