[pre-commit.ci] Add auto fixes from pre-commit.com hooks
for more information, see https://pre-commit.ci
pre-commit-ci[bot] committed May 1, 2024
1 parent cef4afe commit 08af4fb
Showing 1 changed file with 21 additions and 25 deletions.
46 changes: 21 additions & 25 deletions odyssey/models/cehr_mamba/mamba-dev.ipynb
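The notebook diff below consists of mechanical formatting fixes (single quotes rewritten to double quotes, trailing commas added, call arguments and closing parentheses reflowed), the kind of rewrite a Python formatter hook such as black produces when pre-commit.ci runs the repository's configured hooks. The repository's actual pre-commit configuration is not part of this commit; the snippet below is only a minimal sketch of reproducing one such fix locally with black's Python API, using a line taken from the diff:

# Minimal sketch (assumption: black, or a Jupyter-aware wrapper around it,
# is among the hooks configured for this repo; the actual
# .pre-commit-config.yaml is not shown in this commit).
import black

# A source line as it appears before this commit: single-quoted strings.
src = "checkpoint = torch.load('checkpoints/mamba_pretrain/best.ckpt', map_location=device)\n"

# black normalizes the quotes, matching the replacement line in the diff below.
print(black.format_str(src, mode=black.Mode()))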
@@ -58,7 +58,7 @@
 " get_run_id,\n",
 " load_config,\n",
 " load_pretrain_data,\n",
-" load_finetune_data\n",
+" load_finetune_data,\n",
 ")\n",
 "from odyssey.utils.utils import seed_everything\n",
 "\n",
@@ -72,10 +72,10 @@
 "outputs": [],
 "source": [
 "class args:\n",
-" data_dir = 'odyssey/data/bigbird_data'\n",
-" sequence_file = 'patient_sequences_2048.parquet'\n",
-" id_file = 'dataset_2048_multi.pkl'\n",
-" vocab_dir = 'odyssey/data/vocab'\n",
+" data_dir = \"odyssey/data/bigbird_data\"\n",
+" sequence_file = \"patient_sequences_2048.parquet\"\n",
+" id_file = \"dataset_2048_multi.pkl\"\n",
+" vocab_dir = \"odyssey/data/vocab\"\n",
 " max_len = 2048\n",
 " mask_prob = 0.15"
 ]
@@ -105,16 +105,12 @@
 "\n",
 "\n",
 "_, fine_test = load_finetune_data(\n",
-" args.data_dir,\n",
-" args.sequence_file,\n",
-" args.id_file,\n",
-" 'few_shot',\n",
-" 'all'\n",
+" args.data_dir, args.sequence_file, args.id_file, \"few_shot\", \"all\"\n",
 ")\n",
 "test_dataset = PretrainDatasetDecoder(\n",
-" data=fine_test,\n",
-" tokenizer=tokenizer,\n",
-" max_len=args.max_len,\n",
+" data=fine_test,\n",
+" tokenizer=tokenizer,\n",
+" max_len=args.max_len,\n",
 ")"
 ]
 },
@@ -213,9 +209,9 @@
 ],
 "source": [
 "# Load pretrained model\n",
-"checkpoint = torch.load('checkpoints/mamba_pretrain/best.ckpt', map_location=device)\n",
-"state_dict = checkpoint['state_dict']\n",
-"state_dict = {k.replace('model.', ''): v for k, v in state_dict.items()}\n",
+"checkpoint = torch.load(\"checkpoints/mamba_pretrain/best.ckpt\", map_location=device)\n",
+"state_dict = checkpoint[\"state_dict\"]\n",
+"state_dict = {k.replace(\"model.\", \"\"): v for k, v in state_dict.items()}\n",
 "model.load_state_dict(state_dict)\n",
 "model"
 ]
@@ -239,13 +235,13 @@
 ],
 "source": [
 "train_loader = DataLoader(\n",
-" test_dataset, #train_dataset\n",
-" batch_size=3,\n",
-" shuffle=False,\n",
-" )\n",
+" test_dataset, # train_dataset\n",
+" batch_size=3,\n",
+" shuffle=False,\n",
+")\n",
 "\n",
-"sample = test_dataset[97] #train_dataset[0]\n",
-"sample = {key:tensor.unsqueeze(0).to(device) for key, tensor in sample.items()}\n",
+"sample = test_dataset[97] # train_dataset[0]\n",
+"sample = {key: tensor.unsqueeze(0).to(device) for key, tensor in sample.items()}\n",
 "\n",
 "# sample = next(iter(train_loader))\n",
 "# sample = {key:tensor.to(device) for key, tensor in sample.items()}\n",
@@ -267,8 +263,8 @@
 }
 ],
 "source": [
-"input_ids = sample['concept_ids'].squeeze().tolist()\n",
-"input_ids = input_ids[:input_ids.index(0)]\n",
+"input_ids = sample[\"concept_ids\"].squeeze().tolist()\n",
+"input_ids = input_ids[: input_ids.index(0)]\n",
 "print(tokenizer.decode(input_ids))"
 ]
 },
@@ -311,7 +307,7 @@
 "source": [
 "output = model.generate(\n",
 " torch.tensor(input_ids[:-10], dtype=torch.int32).unsqueeze(0).to(device),\n",
-" max_new_tokens=10\n",
+" max_new_tokens=10,\n",
 ")\n",
 "\n",
 "tokenizer.decode(output.squeeze().tolist()[-10:])"
