Skip to content

Commit

Permalink
Fixed style merge conflicts.
Browse files Browse the repository at this point in the history
  • Loading branch information
Adibvafa committed May 1, 2024
2 parents 4dab179 + 08af4fb commit bd28037
Showing 1 changed file with 11 additions and 27 deletions.
38 changes: 11 additions & 27 deletions odyssey/models/cehr_mamba/mamba-dev.ipynb
Original file line number Diff line number Diff line change
Expand Up @@ -59,7 +59,7 @@
" get_run_id,\n",
" load_config,\n",
" load_pretrain_data,\n",
" load_finetune_data\n",
" load_finetune_data,\n",
")\n",
"from odyssey.utils.utils import seed_everything\n",
"\n",
Expand Down Expand Up @@ -108,16 +108,12 @@
"\n",
"\n",
"_, fine_test = load_finetune_data(\n",
" args.data_dir,\n",
" args.sequence_file,\n",
" args.id_file,\n",
" 'few_shot',\n",
" 'all'\n",
" args.data_dir, args.sequence_file, args.id_file, \"few_shot\", \"all\"\n",
")\n",
"test_dataset = PretrainDatasetDecoder(\n",
" data=fine_test,\n",
" tokenizer=tokenizer,\n",
" max_len=args.max_len,\n",
" data=fine_test,\n",
" tokenizer=tokenizer,\n",
" max_len=args.max_len,\n",
")"
]
},
Expand Down Expand Up @@ -216,9 +212,9 @@
],
"source": [
"# Load pretrained model\n",
"checkpoint = torch.load('checkpoints/mamba_pretrain/best.ckpt', map_location=device)\n",
"state_dict = checkpoint['state_dict']\n",
"state_dict = {k.replace('model.', ''): v for k, v in state_dict.items()}\n",
"checkpoint = torch.load(\"checkpoints/mamba_pretrain/best.ckpt\", map_location=device)\n",
"state_dict = checkpoint[\"state_dict\"]\n",
"state_dict = {k.replace(\"model.\", \"\"): v for k, v in state_dict.items()}\n",
"model.load_state_dict(state_dict)\n",
"model"
]
Expand Down Expand Up @@ -273,8 +269,8 @@
}
],
"source": [
"input_ids = sample['concept_ids'].squeeze().tolist()\n",
"input_ids = input_ids[:input_ids.index(0)]\n",
"input_ids = sample[\"concept_ids\"].squeeze().tolist()\n",
"input_ids = input_ids[: input_ids.index(0)]\n",
"print(tokenizer.decode(input_ids))"
]
},
Expand Down Expand Up @@ -317,7 +313,7 @@
"source": [
"output = model.generate(\n",
" torch.tensor(input_ids[:-10], dtype=torch.int32).unsqueeze(0).to(device),\n",
" max_new_tokens=10\n",
" max_new_tokens=10,\n",
")\n",
"\n",
"tokenizer.decode(output.squeeze().tolist()[-10:])"
Expand Down Expand Up @@ -1075,18 +1071,6 @@
"display_name": "Python 3 (ipykernel)",
"language": "python",
"name": "python3"
},
"language_info": {
"codemirror_mode": {
"name": "ipython",
"version": 3
},
"file_extension": ".py",
"mimetype": "text/x-python",
"name": "python",
"nbconvert_exporter": "python",
"pygments_lexer": "ipython3",
"version": "3.10.9"
}
},
"nbformat": 4,
Expand Down

0 comments on commit bd28037

Please sign in to comment.