-
Notifications
You must be signed in to change notification settings - Fork 0
/
Copy patha2_bleu_score.py
145 lines (115 loc) · 4 KB
/
a2_bleu_score.py
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
'''
This code is provided solely for the personal and private use of students
taking the CSC401H/2511H course at the University of Toronto. Copying for
purposes other than this use is expressly prohibited. All forms of
distribution of this code, including but not limited to public repositories on
GitHub, GitLab, Bitbucket, or any other online platform, whether as given or
with any changes, are expressly prohibited.
Authors: Sean Robertson, Jingcheng Niu, Zining Zhu, and Mohamed Abdalla
Updated by: Raeid Saqur <raeidsaqur@cs.toronto.edu>
All of the files in this directory and all subdirectories are:
Copyright (c) 2022 University of Toronto
'''
'''Calculate BLEU score for one reference and one hypothesis
You do not need to import anything more than what is here
'''
from math import exp # exp(x) gives e^x
from typing import List, Sequence, Iterable
def grouper(seq:Sequence[str], n:int) -> List:
    '''Extract all n-grams from a sequence

    An n-gram is a contiguous window of `n` tokens taken from `seq`. Windows
    are produced left to right, one per starting position, and each window is
    rendered as a single string with its tokens separated by one space.

    Parameters
    ----------
    seq : sequence of str
        The tokens of a transcription. Tokens are combined with ``str.join``,
        so every token must already be a string.
    n : int
        The number of tokens in each n-gram.

    Returns
    -------
    ngrams : list of str
        The space-joined n-grams in order of appearance. Empty when `seq`
        holds fewer than `n` tokens.
    '''
    # One window per valid start index; the range is empty if seq is too short.
    window_count = len(seq) - n + 1
    return [" ".join(seq[start: start + n]) for start in range(0, window_count)]
def n_gram_precision(reference:Sequence[str], candidate:Sequence[str], n:int) -> float:
    '''Calculate the precision for a given order of n-gram

    The precision is the fraction of the candidate's n-grams that also occur
    somewhere in the reference. Matches are not clipped: a candidate n-gram
    that is repeated counts every time it appears, as long as the reference
    contains that n-gram at least once.

    Parameters
    ----------
    reference : sequence
        The reference transcription. A sequence of tokens.
    candidate : sequence
        The candidate transcription. A sequence of tokens (of the same kind
        used by `reference`).
    n : int
        The order of n-gram precision to calculate

    Returns
    -------
    p_n : float
        The n-gram precision. `p_n` is 0 when the candidate has length 0 or
        is too short to contain any n-gram of order `n`.
    '''
    if len(candidate) == 0:
        return 0.0

    candidate_ngrams = grouper(candidate, n)
    if not candidate_ngrams:
        # Candidate is shorter than n: there is nothing to score.
        return 0.0

    # Build the lookup once as a set so each membership test is O(1) on
    # average instead of a linear scan over the reference n-gram list.
    reference_ngrams = set(grouper(reference, n))
    matches = sum(1 for ngram in candidate_ngrams if ngram in reference_ngrams)
    return matches / len(candidate_ngrams)
def brevity_penalty(reference:Sequence[str], candidate:Sequence[str]) -> float:
    '''Calculate the brevity penalty between a reference and candidate

    With ``brevity = len(reference) / len(candidate)``, the penalty is 1 when
    ``brevity < 1`` (the candidate is longer than the reference) and
    ``exp(1 - brevity)`` otherwise.

    Parameters
    ----------
    reference : sequence
        The reference transcription. Only its length is used.
    candidate : sequence
        The candidate transcription. Only its length is used.

    Returns
    -------
    BP : float
        The brevity penalty. In the case that the candidate transcription is
        of 0 length, `BP` is 0.
    '''
    candidate_length = len(candidate)
    if not candidate_length > 0:
        # An empty candidate would make the length ratio undefined.
        return 0

    brevity = len(reference) / candidate_length
    return 1 if brevity < 1 else exp(1 - brevity)
def BLEU_score(reference:Sequence[str], candidate:Sequence[str], n) -> float:
    '''Calculate the BLEU score

    The score is the brevity penalty multiplied by the geometric mean of the
    n-gram precisions of orders 1 through `n`, i.e.
    ``BP * (p_1 * p_2 * ... * p_n) ** (1 / n)``.

    Parameters
    ----------
    reference : sequence
        The reference transcription. A sequence of tokens.
    candidate : sequence
        The candidate transcription. A sequence of tokens (of the same kind
        used by `reference`).
    n : int
        The maximum order of n-gram precision to use in the calculations,
        inclusive. For example, ``n = 2`` implies both unigram and bigram
        precision will be accounted for, but not trigram. Must be non-zero,
        since the score takes the ``1 / n`` power of the precision product.

    Returns
    -------
    bleu : float
        The BLEU score
    '''
    # Accumulate the product of the precisions for every order up to n.
    precision_product = 1
    for order in range(1, n + 1):
        precision_product *= n_gram_precision(reference, candidate, order)

    penalty = brevity_penalty(reference, candidate)

    # The n-th root of the product is the geometric mean of the precisions.
    geometric_mean = precision_product ** (1 / n)
    return penalty * geometric_mean