diff --git a/odyssey/models/cehr_mamba/mamba-dev.ipynb b/odyssey/models/cehr_mamba/mamba-dev.ipynb index 9919aef..70a1352 100644 --- a/odyssey/models/cehr_mamba/mamba-dev.ipynb +++ b/odyssey/models/cehr_mamba/mamba-dev.ipynb @@ -58,7 +58,7 @@ " get_run_id,\n", " load_config,\n", " load_pretrain_data,\n", - " load_finetune_data\n", + " load_finetune_data,\n", ")\n", "from odyssey.utils.utils import seed_everything\n", "\n", @@ -72,10 +72,10 @@ "outputs": [], "source": [ "class args:\n", - " data_dir = 'odyssey/data/bigbird_data'\n", - " sequence_file = 'patient_sequences_2048.parquet'\n", - " id_file = 'dataset_2048_multi.pkl'\n", - " vocab_dir = 'odyssey/data/vocab'\n", + " data_dir = \"odyssey/data/bigbird_data\"\n", + " sequence_file = \"patient_sequences_2048.parquet\"\n", + " id_file = \"dataset_2048_multi.pkl\"\n", + " vocab_dir = \"odyssey/data/vocab\"\n", " max_len = 2048\n", " mask_prob = 0.15" ] @@ -105,16 +105,12 @@ "\n", "\n", "_, fine_test = load_finetune_data(\n", - " args.data_dir,\n", - " args.sequence_file,\n", - " args.id_file,\n", - " 'few_shot',\n", - " 'all'\n", + " args.data_dir, args.sequence_file, args.id_file, \"few_shot\", \"all\"\n", ")\n", "test_dataset = PretrainDatasetDecoder(\n", - " data=fine_test,\n", - " tokenizer=tokenizer,\n", - " max_len=args.max_len,\n", + " data=fine_test,\n", + " tokenizer=tokenizer,\n", + " max_len=args.max_len,\n", ")" ] }, @@ -213,9 +209,9 @@ ], "source": [ "# Load pretrained model\n", - "checkpoint = torch.load('checkpoints/mamba_pretrain/best.ckpt', map_location=device)\n", - "state_dict = checkpoint['state_dict']\n", - "state_dict = {k.replace('model.', ''): v for k, v in state_dict.items()}\n", + "checkpoint = torch.load(\"checkpoints/mamba_pretrain/best.ckpt\", map_location=device)\n", + "state_dict = checkpoint[\"state_dict\"]\n", + "state_dict = {k.replace(\"model.\", \"\"): v for k, v in state_dict.items()}\n", "model.load_state_dict(state_dict)\n", "model" ] @@ -239,13 +235,13 @@ ], "source": [ "train_loader = DataLoader(\n", - " test_dataset, #train_dataset\n", - " batch_size=3,\n", - " shuffle=False,\n", - " )\n", + " test_dataset, # train_dataset\n", + " batch_size=3,\n", + " shuffle=False,\n", + ")\n", "\n", - "sample = test_dataset[97] #train_dataset[0]\n", - "sample = {key:tensor.unsqueeze(0).to(device) for key, tensor in sample.items()}\n", + "sample = test_dataset[97] # train_dataset[0]\n", + "sample = {key: tensor.unsqueeze(0).to(device) for key, tensor in sample.items()}\n", "\n", "# sample = next(iter(train_loader))\n", "# sample = {key:tensor.to(device) for key, tensor in sample.items()}\n", @@ -267,8 +263,8 @@ } ], "source": [ - "input_ids = sample['concept_ids'].squeeze().tolist()\n", - "input_ids = input_ids[:input_ids.index(0)]\n", + "input_ids = sample[\"concept_ids\"].squeeze().tolist()\n", + "input_ids = input_ids[: input_ids.index(0)]\n", "print(tokenizer.decode(input_ids))" ] }, @@ -311,7 +307,7 @@ "source": [ "output = model.generate(\n", " torch.tensor(input_ids[:-10], dtype=torch.int32).unsqueeze(0).to(device),\n", - " max_new_tokens=10\n", + " max_new_tokens=10,\n", ")\n", "\n", "tokenizer.decode(output.squeeze().tolist()[-10:])"