
refactor(ngram): implement with a better way (#33)
shenxiangzhuang committed Apr 28, 2024
1 parent 27ac861 commit c5b85f8
Showing 4 changed files with 27 additions and 30 deletions.
3 changes: 3 additions & 0 deletions CHANGELOG.md
@@ -7,6 +7,9 @@ and this project adheres to [Semantic Versioning](https://semver.org/spec/v2.0.0

## [Unreleased]

+### Changed
+- Implement ngram counts with a better way (#33)
+
## [0.1.1] - 2024-04-26
### Changed
- Upgrade `cached` version to `0.50.0`
14 changes: 6 additions & 8 deletions src/bleu.rs
@@ -41,8 +41,8 @@ pub fn compute_score(
// ngram count
let translation_ngram_counts = get_token_ngram_counter(&translation_tokens, max_order);
let mut merged_ref_ngram_counts = HashMap::new();
-for reference_tokens in references_tokens {
-let reference_ngram_counts = get_token_ngram_counter(&reference_tokens, max_order);
+for reference_tokens in references_tokens.iter() {
+let reference_ngram_counts = get_token_ngram_counter(reference_tokens, max_order);
for (key, value) in reference_ngram_counts {
merged_ref_ngram_counts
.entry(key)
@@ -54,16 +54,14 @@
// overlap count
let mut overlap_counts = HashMap::new();
for (k, v) in translation_ngram_counts {
-let key = k.clone();
-if merged_ref_ngram_counts.contains_key(&key) {
-overlap_counts.insert(k, min(merged_ref_ngram_counts[&key], v));
+if merged_ref_ngram_counts.contains_key(k) {
+overlap_counts.insert(k, min(merged_ref_ngram_counts[k], v));
} else {
continue;
}
}
-for key in overlap_counts.keys() {
-let (_, order) = key;
-matches_by_order[order - 1] += overlap_counts[&key];
+for &key in overlap_counts.keys() {
+matches_by_order[key.len() - 1] += overlap_counts[key];
}

// possible match
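With slice-keyed n-gram counts, the order of an n-gram is just `key.len()`, so the tuple key and the `k.clone()` above are no longer needed; and because the counter now borrows from the token vectors, the reference loop iterates `references_tokens.iter()` instead of moving the vectors out from under the borrowed keys. A minimal standalone sketch of the clipped-overlap step under this scheme (simplified names and toy data, not the crate's exact internals):

```rust
use std::cmp::min;
use std::collections::HashMap;

// Stub with the same shape as the refactored counter: n-gram slices as keys,
// borrowing from the token vector they were cut from.
fn count_ngrams(tokens: &[String], max_order: usize) -> HashMap<&[String], usize> {
    let mut counts: HashMap<&[String], usize> = HashMap::new();
    for order in 1..=max_order {
        for start in 0..tokens.len().saturating_sub(order - 1) {
            *counts.entry(&tokens[start..start + order]).or_insert(0) += 1;
        }
    }
    counts
}

fn main() {
    let translation: Vec<String> = ["the", "cat", "sat"].iter().map(|s| s.to_string()).collect();
    let reference: Vec<String> = ["the", "cat", "sat", "down"].iter().map(|s| s.to_string()).collect();
    let max_order = 2;

    let translation_counts = count_ngrams(&translation, max_order);
    let reference_counts = count_ngrams(&reference, max_order);

    // Clip each translation n-gram count by its reference count, then bucket
    // the matches by n-gram order via key.len(), as the refactored loop does.
    let mut matches_by_order = vec![0usize; max_order];
    for (key, value) in &translation_counts {
        if let Some(&ref_count) = reference_counts.get(key) {
            matches_by_order[key.len() - 1] += min(ref_count, *value);
        }
    }
    assert_eq!(matches_by_order, vec![3, 2]); // 3 unigram matches, 2 bigram matches
}
```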
2 changes: 2 additions & 0 deletions src/lib.rs
@@ -20,6 +20,8 @@ and [sacrebleu](https://github.com/mjpost/sacrebleu)
# Basic usage:
```rust
use bleuscore::compute_score;
// get the references and prediction data:
let references: Vec<Vec<String>> = vec![vec!["Hello, World!".to_string()]];
let predictions: Vec<String> = vec!["Yellow, World!".to_string()];
38 changes: 16 additions & 22 deletions src/ngram.rs
@@ -3,17 +3,13 @@ use std::collections::HashMap;

/// Here the tokens' type is `&[String]` rather than `&Vec<String>`
/// to fix `clippy::not_unsafe_ptr_arg_deref` error.
-pub fn get_token_ngram_counter(
-tokens: &[String],
-max_order: usize,
-) -> HashMap<(String, usize), usize> {
-let mut count_map: HashMap<(String, usize), usize> = HashMap::new();
+pub fn get_token_ngram_counter(tokens: &[String], max_order: usize) -> HashMap<&[String], usize> {
+let mut count_map: HashMap<&[String], usize> = HashMap::new();
for order in 1..=max_order {
for start_index in 0..(tokens.len().saturating_sub(order - 1)) {
-// note: can not join with "", which will make 2-gram ('000', '00') = ('0000', '0')
-let ngram = tokens[start_index..(start_index + order)].join(" ");
+let ngram = &tokens[start_index..(start_index + order)];
count_map
-.entry((ngram, order))
+.entry(ngram)
.and_modify(|counter| *counter += 1)
.or_insert(1);
}
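The old implementation keyed counts by `(joined_string, order)`; as the removed comment notes, joining with the empty string would make the 2-gram `('000', '00')` collide with `('0000', '0')`, so a space separator was needed and every n-gram allocated a fresh `String`. Keying by the token slice itself keeps the token boundaries with no join and no per-n-gram allocation. A small illustration (toy token values, not from the crate) of the collision the comment warned about and of what slice comparison gives for free:

```rust
fn main() {
    // Two different 2-grams...
    let a = vec!["000".to_string(), "00".to_string()];
    let b = vec!["0000".to_string(), "0".to_string()];

    // ...collapse into the same key when joined with "":
    assert_eq!(a.join(""), b.join("")); // both "00000"
    // joining with " " keeps them apart, but allocates a String per n-gram:
    assert_ne!(a.join(" "), b.join(" "));
    // comparing the token slices themselves needs no allocation and no separator:
    assert_ne!(&a[..], &b[..]);
}
```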
@@ -44,9 +40,9 @@ mod test {
fn test_get_token_ngram_short() {
let tokens = vec!["a".to_string(), "b".to_string()];
let counter = get_token_ngram_counter(&tokens, 4);
assert_eq!(counter[&("a".to_string(), 1)], 1);
assert_eq!(counter[&("b".to_string(), 1)], 1);
assert_eq!(counter[&("a b".to_string(), 2)], 1);
assert_eq!(counter[&tokens[0..=0]], 1);
assert_eq!(counter[&tokens[1..=1]], 1);
assert_eq!(counter[&tokens[0..=1]], 1);
}

#[test]
@@ -59,19 +55,17 @@
"c".to_string(),
];
let counter = get_token_ngram_counter(&tokens, 4);
assert_eq!(counter[&("a".to_string(), 1)], 2);
assert_eq!(counter[&("b".to_string(), 1)], 1);
assert_eq!(counter[&("c".to_string(), 1)], 1);
assert_eq!(counter.get(&("d".to_string(), 1)), None);
assert_eq!(counter[&tokens[0..=0]], 2); // 'a': 2
assert_eq!(counter[&tokens[2..=2]], 1); // 'b': 1
assert_eq!(counter[&tokens[3..=3]], 1); // 'c': 1

assert_eq!(counter[&("a a".to_string(), 2)], 1);
assert_eq!(counter[&("a b".to_string(), 2)], 1);
assert_eq!(counter[&("b c".to_string(), 2)], 1);
assert_eq!(counter.get(&("a c".to_string(), 2)), None);
assert_eq!(counter[&tokens[0..=1]], 1); // 'aa': 1
assert_eq!(counter[&tokens[1..=2]], 1); // 'ab': 1
assert_eq!(counter[&tokens[2..=3]], 1); // 'bc': 1

assert_eq!(counter[&("a a b".to_string(), 3)], 1);
assert_eq!(counter[&("a b c".to_string(), 3)], 1);
assert_eq!(counter[&("a a b c".to_string(), 4)], 1);
assert_eq!(counter[&tokens[0..=2]], 1); // 'aab': 1
assert_eq!(counter[&tokens[1..=3]], 1); // 'abc': 1
assert_eq!(counter[&tokens[0..=3]], 1); // 'abcd': 1

assert_eq!(counter.len(), 9);
}
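The rewritten tests index the counter with slices cut from the same `tokens` vector, but nothing about the keys is tied to that particular vector: `&[String]` keys hash and compare by the strings they contain, which is what lets the BLEU computation look up translation n-grams in the merged reference counts. A hypothetical standalone check of that property (names made up for illustration):

```rust
use std::collections::HashMap;

fn main() {
    // Slices cut from different vectors still hash and compare equal as long
    // as the Strings they contain are equal, so cross-sentence lookups work.
    let reference = vec!["a".to_string(), "b".to_string()];
    let translation = vec!["a".to_string(), "b".to_string(), "c".to_string()];

    let mut reference_counts: HashMap<&[String], usize> = HashMap::new();
    reference_counts.insert(&reference[0..2], 1); // the reference 2-gram ["a", "b"]

    // Look the same 2-gram up with a slice borrowed from the translation:
    assert_eq!(reference_counts.get(&translation[0..2]), Some(&1));
}
```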
