diff --git a/src/MedVInT_TD/test.py b/src/MedVInT_TD/test.py index 14ff623..d058c0a 100644 --- a/src/MedVInT_TD/test.py +++ b/src/MedVInT_TD/test.py @@ -18,7 +18,7 @@ @dataclass class ModelArguments: model_path: Optional[str] = field(default="./LLAMA/llama-7b-hf") - ckp: Optional[str] = field(default="./Results/QA_PMC_LLaMA_lora_PMC-CLIP_MLP/choice_training/checkpoint-4146") + ckp: Optional[str] = field(default="./Results/VQA_lora_PMC_LLaMA_PMCCLIP/choice/checkpoint-4000") checkpointing: Optional[bool] = field(default=False) ## Q_former ## N: Optional[int] = field(default=12)