diff --git a/odyssey/models/cehr_mamba/mamba-dev.ipynb b/odyssey/models/cehr_mamba/mamba-dev.ipynb index a729375..a96af6e 100644 --- a/odyssey/models/cehr_mamba/mamba-dev.ipynb +++ b/odyssey/models/cehr_mamba/mamba-dev.ipynb @@ -59,7 +59,7 @@ " get_run_id,\n", " load_config,\n", " load_pretrain_data,\n", - " load_finetune_data\n", + " load_finetune_data,\n", ")\n", "from odyssey.utils.utils import seed_everything\n", "\n", @@ -108,16 +108,12 @@ "\n", "\n", "_, fine_test = load_finetune_data(\n", - " args.data_dir,\n", - " args.sequence_file,\n", - " args.id_file,\n", - " 'few_shot',\n", - " 'all'\n", + " args.data_dir, args.sequence_file, args.id_file, \"few_shot\", \"all\"\n", ")\n", "test_dataset = PretrainDatasetDecoder(\n", - " data=fine_test,\n", - " tokenizer=tokenizer,\n", - " max_len=args.max_len,\n", + " data=fine_test,\n", + " tokenizer=tokenizer,\n", + " max_len=args.max_len,\n", ")" ] }, @@ -216,9 +212,9 @@ ], "source": [ "# Load pretrained model\n", - "checkpoint = torch.load('checkpoints/mamba_pretrain/best.ckpt', map_location=device)\n", - "state_dict = checkpoint['state_dict']\n", - "state_dict = {k.replace('model.', ''): v for k, v in state_dict.items()}\n", + "checkpoint = torch.load(\"checkpoints/mamba_pretrain/best.ckpt\", map_location=device)\n", + "state_dict = checkpoint[\"state_dict\"]\n", + "state_dict = {k.replace(\"model.\", \"\"): v for k, v in state_dict.items()}\n", "model.load_state_dict(state_dict)\n", "model" ] @@ -273,8 +269,8 @@ } ], "source": [ - "input_ids = sample['concept_ids'].squeeze().tolist()\n", - "input_ids = input_ids[:input_ids.index(0)]\n", + "input_ids = sample[\"concept_ids\"].squeeze().tolist()\n", + "input_ids = input_ids[: input_ids.index(0)]\n", "print(tokenizer.decode(input_ids))" ] }, @@ -317,7 +313,7 @@ "source": [ "output = model.generate(\n", " torch.tensor(input_ids[:-10], dtype=torch.int32).unsqueeze(0).to(device),\n", - " max_new_tokens=10\n", + " max_new_tokens=10,\n", ")\n", "\n", "tokenizer.decode(output.squeeze().tolist()[-10:])" @@ -1075,18 +1071,6 @@ "display_name": "Python 3 (ipykernel)", "language": "python", "name": "python3" - }, - "language_info": { - "codemirror_mode": { - "name": "ipython", - "version": 3 - }, - "file_extension": ".py", - "mimetype": "text/x-python", - "name": "python", - "nbconvert_exporter": "python", - "pygments_lexer": "ipython3", - "version": "3.10.9" } }, "nbformat": 4,