Skip to content

Commit

Permalink
update bge eval.
Browse files (browse the repository at this point in the history)
  • Loading branch information
shibing624 committed Sep 4, 2023
1 parent 221b2b0 commit e39924e
Show file tree
Hide file tree
Showing 6 changed files with 11 additions and 24 deletions.
3 changes: 3 additions & 0 deletions examples/training_bge_model_mydata.py
Original file line number Diff line number Diff line change
Expand Up @@ -60,6 +60,8 @@ def main():
choices=list(EncoderType), help='Encoder type, string name of EncoderType')
parser.add_argument("--bf16", action="store_true", help="Whether to use bfloat16 amp training.")
parser.add_argument("--data_parallel", action="store_true", help="Whether to use multi-gpu data parallel.")
parser.add_argument("--normalize_embeddings", action="store_true",
help="Whether to normalize embeddings. set True if temperature < 1.0")
args = parser.parse_args()
logger.info(args)

Expand All @@ -82,6 +84,7 @@ def main():
data_parallel=args.data_parallel,
train_group_size=args.train_group_size,
temperature=args.temperature,
normalize_embeddings=args.normalize_embeddings,
)
logger.info(f"Model saved to {args.output_dir}")
if args.do_predict:
Expand Down
2 changes: 1 addition & 1 deletion tests/flag_dres_model.py
Original file line number Diff line number Diff line change
Expand Up @@ -27,7 +27,7 @@ def __init__(
pooling_method: str = 'cls',
normalize_embeddings: bool = True,
query_instruction_for_retrieval: str = None,
batch_size: int = 64,
batch_size: int = 128,
**kwargs
) -> None:

Expand Down
5 changes: 2 additions & 3 deletions tests/summarize_results.py
Original file line number Diff line number Diff line change
Expand Up @@ -3,7 +3,7 @@
@author:XuMing(xuming624@qq.com)
@description: Evaluate MTEB benchmark
pip install mteb
pip install C_MTEB
code modified from https://github.com/FlagOpen/FlagEmbedding
"""
Expand Down Expand Up @@ -108,7 +108,6 @@ def output_markdown(tasks_results, model_names, save_file):
write_line += f" {round(sum(cqa_res) / len(cqa_res), 2)} |"
all_res.append(round(sum(cqa_res) / len(cqa_res), 2))

# if len(all_res) == len(type_results.keys()):
if len(all_res) == task_cnt:
write_line += f" {round(sum(all_res) / len(all_res), 2)} |"
task_type_res[t_type][model] = all_res
Expand Down Expand Up @@ -139,7 +138,7 @@ def output_markdown(tasks_results, model_names, save_file):
write_line += f" {round(sum(all_res) / len(all_res), 2)} |"

f.write(write_line + ' \n')

print(f"Save results to {save_file}")

def get_args():
parser = argparse.ArgumentParser()
Expand Down
13 changes: 0 additions & 13 deletions tests/test_model_spearman.py
Original file line number Diff line number Diff line change
Expand Up @@ -434,19 +434,6 @@ def test_bge_large_zh_noinstruct_bge_model(self):
# add sohu avg: 0.4947
pass

def test_bge_large_zh_noinstruct_my_impl_bge_model(self):
# BAAI/bge-large-zh-noinstruct with bge finetuned v3
# STS-B spearman corr: 0.8093
# ATEC spearman corr: 0.45839
# BQ spearman corr: 0.56505
# LCQMC spearman corr: 0.742664
# PAWSX spearman corr: 0.11136
# avg: 0.53736
# V100 QPS: 605
# sohu-dd spearman corr: 0.566741
# sohu-dc spearman corr: 0.2098
# add sohu avg: 0.4947
pass

if __name__ == '__main__':
unittest.main()
2 changes: 1 addition & 1 deletion text2vec/bge_dataset.py
Original file line number Diff line number Diff line change
Expand Up @@ -48,7 +48,7 @@ def __init__(
self,
tokenizer: PreTrainedTokenizer,
data_file_or_name: str,
query_max_len: int = 64,
query_max_len: int = 32,
passage_max_len: int = 128,
train_group_size: int = 8
):
Expand Down
10 changes: 4 additions & 6 deletions text2vec/bge_model.py
Original file line number Diff line number Diff line change
Expand Up @@ -30,7 +30,6 @@ def __init__(
encoder_type: str = "MEAN",
max_seq_length: int = 32,
passage_max_len: int = 128,
num_classes: int = 1,
device: str = None,
):
"""
Expand All @@ -47,7 +46,6 @@ def __init__(
super().__init__(model_name_or_path, encoder_type, max_seq_length, device)
self.query_max_len = max_seq_length
self.passage_max_len = passage_max_len
self.classifier = nn.Linear(self.bert.config.hidden_size * 3, num_classes).to(self.device)

def __str__(self):
return f"<BgeModel: {self.model_name_or_path}, encoder_type: {self.encoder_type}, " \
Expand All @@ -61,7 +59,7 @@ def train_model(
verbose: bool = True,
batch_size: int = 32,
num_epochs: int = 1,
weight_decay: float = 0.01,
weight_decay: float = 0.0,
seed: int = 42,
warmup_ratio: float = 0.05,
lr: float = 1e-5,
Expand All @@ -72,7 +70,7 @@ def train_model(
use_hf_dataset: bool = False,
hf_dataset_name: str = "",
save_model_every_epoch: bool = True,
bf16: bool = True,
bf16: bool = False,
data_parallel: bool = False,
train_group_size: int = 8,
temperature: float = 1.0,
Expand Down Expand Up @@ -175,7 +173,7 @@ def train(
verbose: bool = True,
batch_size: int = 8,
num_epochs: int = 1,
weight_decay: float = 0.01,
weight_decay: float = 0.0,
seed: int = 42,
warmup_ratio: float = 0.05,
lr: float = 1e-5,
Expand All @@ -184,7 +182,7 @@ def train(
max_grad_norm: float = 1.0,
max_steps: int = -1,
save_model_every_epoch: bool = True,
bf16: bool = True,
bf16: bool = False,
data_parallel: bool = False,
temperature: float = 1.0,
normalize_embeddings: bool = False,
Expand Down

0 comments on commit e39924e

Please sign in to comment.