diff --git a/evaluation/AttentionVisualization.ipynb b/evaluation/AttentionVisualization.ipynb deleted file mode 100644 index dc0a91c..0000000 --- a/evaluation/AttentionVisualization.ipynb +++ /dev/null @@ -1,3412 +0,0 @@ -{ - "cells": [ - { - "cell_type": "code", - "execution_count": 13, - "id": "initial_id", - "metadata": { - "ExecuteTime": { - "end_time": "2024-03-15T15:39:47.301592Z", - "start_time": "2024-03-15T15:39:47.293763Z" - }, - "collapsed": true - }, - "outputs": [ - { - "data": { - "text/plain": [ - "'\\nFile: AttentionVisualization.ipynb\\n---------------------------------\\nVisualize the attention layers of transformer models for interpretability.\\n'" - ] - }, - "execution_count": 13, - "metadata": {}, - "output_type": "execute_result" - } - ], - "source": [ - "\"\"\"\n", - "File: AttentionVisualization.ipynb\n", - "---------------------------------\n", - "Visualize the attention layers of transformer models for interpretability.\n", - "\"\"\"" - ] - }, - { - "cell_type": "code", - "execution_count": 14, - "id": "e24876a0c6020df2", - "metadata": { - "ExecuteTime": { - "end_time": "2024-03-15T15:49:16.880004Z", - "start_time": "2024-03-15T15:49:16.866092Z" - }, - "collapsed": false - }, - "outputs": [], - "source": [ - "import os\n", - "\n", - "import numpy as np\n", - "import plotly.graph_objects as go\n", - "import torch\n", - "from bertviz import head_view, model_view\n", - "from bertviz.neuron_view import show\n", - "from torch.utils.data import DataLoader, Subset\n", - "from transformers import utils\n", - "\n", - "\n", - "utils.logging.set_verbosity_error() # Suppress standard warnings\n", - "\n", - "\n", - "ROOT = \"/fs01/home/afallah/odyssey/odyssey\"\n", - "os.chdir(ROOT)\n", - "\n", - "from odyssey.data.dataset import FinetuneDataset\n", - "from odyssey.data.tokenizer import ConceptTokenizer\n", - "from odyssey.models.prediction import load_finetuned_model, predict_patient_outcomes\n", - "from odyssey.models.utils import (\n", - " load_finetune_data,\n", - ")" - ] - }, - { - "cell_type": "code", - "execution_count": 15, - "id": "dac09785d1fdc0cc", - "metadata": { - "ExecuteTime": { - "end_time": "2024-03-15T15:39:54.875569Z", - "start_time": "2024-03-15T15:39:54.852027Z" - }, - "collapsed": false - }, - "outputs": [], - "source": [ - "class args:\n", - " \"\"\"Save the configuration arguments.\"\"\"\n", - "\n", - " model_path = (\n", - " \"checkpoints/bigbird_finetune/mortality_1month_20000_patients/best-v3.ckpt\"\n", - " )\n", - " vocab_dir = \"data/vocab\"\n", - " data_dir = \"data/bigbird_data\"\n", - " sequence_file = (\n", - " \"old_data/patient_sequences/patient_sequences_2048_mortality.parquet\"\n", - " )\n", - " id_file = \"old_data/patient_id_dict/dataset_2048_mortality_1month.pkl\"\n", - " valid_scheme = \"few_shot\"\n", - " num_finetune_patients = \"20000\"\n", - " label_name = \"label_mortality_1month\"\n", - "\n", - " max_len = 2048\n", - " batch_size = 1\n", - " device = torch.device(\"cuda\" if torch.cuda.is_available() else \"cpu\")" - ] - }, - { - "cell_type": "code", - "execution_count": 16, - "id": "6b5e21258ced06d2", - "metadata": { - "ExecuteTime": { - "end_time": "2024-03-15T15:39:54.921944Z", - "start_time": "2024-03-15T15:39:54.877586Z" - }, - "collapsed": false - }, - "outputs": [], - "source": [ - "tokenizer = ConceptTokenizer(data_dir=args.vocab_dir)\n", - "tokenizer.fit_on_vocab()" - ] - }, - { - "cell_type": "code", - "execution_count": 17, - "id": "8bf63dc4242a163e", - "metadata": { - "ExecuteTime": { - "end_time": 
"2024-03-15T15:40:00.629410Z", - "start_time": "2024-03-15T15:39:54.929550Z" - }, - "collapsed": false - }, - "outputs": [ - { - "data": { - "text/plain": [ - "BigBirdFinetune(\n", - " (model): BigBirdForSequenceClassification(\n", - " (bert): BigBirdModel(\n", - " (embeddings): BigBirdEmbeddingsForCEHR(\n", - " (word_embeddings): Embedding(20592, 768, padding_idx=0)\n", - " (position_embeddings): Embedding(2048, 768)\n", - " (token_type_embeddings): Embedding(9, 768)\n", - " (time_embeddings): TimeEmbeddingLayer()\n", - " (age_embeddings): TimeEmbeddingLayer()\n", - " (visit_segment_embeddings): VisitEmbedding(\n", - " (embedding): Embedding(3, 768)\n", - " )\n", - " (scale_back_concat_layer): Linear(in_features=832, out_features=768, bias=True)\n", - " (tanh): Tanh()\n", - " (LayerNorm): LayerNorm((768,), eps=1e-12, elementwise_affine=True)\n", - " (dropout): Dropout(p=0.1, inplace=False)\n", - " )\n", - " (encoder): BigBirdEncoder(\n", - " (layer): ModuleList(\n", - " (0-5): 6 x BigBirdLayer(\n", - " (attention): BigBirdAttention(\n", - " (self): BigBirdBlockSparseAttention(\n", - " (query): Linear(in_features=768, out_features=768, bias=True)\n", - " (key): Linear(in_features=768, out_features=768, bias=True)\n", - " (value): Linear(in_features=768, out_features=768, bias=True)\n", - " )\n", - " (output): BigBirdSelfOutput(\n", - " (dense): Linear(in_features=768, out_features=768, bias=True)\n", - " (LayerNorm): LayerNorm((768,), eps=1e-12, elementwise_affine=True)\n", - " (dropout): Dropout(p=0.1, inplace=False)\n", - " )\n", - " )\n", - " (intermediate): BigBirdIntermediate(\n", - " (dense): Linear(in_features=768, out_features=3072, bias=True)\n", - " (intermediate_act_fn): NewGELUActivation()\n", - " )\n", - " (output): BigBirdOutput(\n", - " (dense): Linear(in_features=3072, out_features=768, bias=True)\n", - " (LayerNorm): LayerNorm((768,), eps=1e-12, elementwise_affine=True)\n", - " (dropout): Dropout(p=0.1, inplace=False)\n", - " )\n", - " )\n", - " )\n", - " )\n", - " (pooler): Linear(in_features=768, out_features=768, bias=True)\n", - " (activation): Tanh()\n", - " )\n", - " (classifier): BigBirdClassificationHead(\n", - " (dense): Linear(in_features=768, out_features=768, bias=True)\n", - " (dropout): Dropout(p=0.1, inplace=False)\n", - " (out_proj): Linear(in_features=768, out_features=2, bias=True)\n", - " )\n", - " )\n", - " (pretrained_model): BigBirdPretrain(\n", - " (embeddings): BigBirdEmbeddingsForCEHR(\n", - " (word_embeddings): Embedding(20592, 768, padding_idx=0)\n", - " (position_embeddings): Embedding(2048, 768)\n", - " (token_type_embeddings): Embedding(9, 768)\n", - " (time_embeddings): TimeEmbeddingLayer()\n", - " (age_embeddings): TimeEmbeddingLayer()\n", - " (visit_segment_embeddings): VisitEmbedding(\n", - " (embedding): Embedding(3, 768)\n", - " )\n", - " (scale_back_concat_layer): Linear(in_features=832, out_features=768, bias=True)\n", - " (tanh): Tanh()\n", - " (LayerNorm): LayerNorm((768,), eps=1e-12, elementwise_affine=True)\n", - " (dropout): Dropout(p=0.1, inplace=False)\n", - " )\n", - " (model): BigBirdForMaskedLM(\n", - " (bert): BigBirdModel(\n", - " (embeddings): BigBirdEmbeddingsForCEHR(\n", - " (word_embeddings): Embedding(20592, 768, padding_idx=0)\n", - " (position_embeddings): Embedding(2048, 768)\n", - " (token_type_embeddings): Embedding(9, 768)\n", - " (time_embeddings): TimeEmbeddingLayer()\n", - " (age_embeddings): TimeEmbeddingLayer()\n", - " (visit_segment_embeddings): VisitEmbedding(\n", - " (embedding): Embedding(3, 768)\n", - " 
)\n", - " (scale_back_concat_layer): Linear(in_features=832, out_features=768, bias=True)\n", - " (tanh): Tanh()\n", - " (LayerNorm): LayerNorm((768,), eps=1e-12, elementwise_affine=True)\n", - " (dropout): Dropout(p=0.1, inplace=False)\n", - " )\n", - " (encoder): BigBirdEncoder(\n", - " (layer): ModuleList(\n", - " (0-5): 6 x BigBirdLayer(\n", - " (attention): BigBirdAttention(\n", - " (self): BigBirdBlockSparseAttention(\n", - " (query): Linear(in_features=768, out_features=768, bias=True)\n", - " (key): Linear(in_features=768, out_features=768, bias=True)\n", - " (value): Linear(in_features=768, out_features=768, bias=True)\n", - " )\n", - " (output): BigBirdSelfOutput(\n", - " (dense): Linear(in_features=768, out_features=768, bias=True)\n", - " (LayerNorm): LayerNorm((768,), eps=1e-12, elementwise_affine=True)\n", - " (dropout): Dropout(p=0.1, inplace=False)\n", - " )\n", - " )\n", - " (intermediate): BigBirdIntermediate(\n", - " (dense): Linear(in_features=768, out_features=3072, bias=True)\n", - " (intermediate_act_fn): NewGELUActivation()\n", - " )\n", - " (output): BigBirdOutput(\n", - " (dense): Linear(in_features=3072, out_features=768, bias=True)\n", - " (LayerNorm): LayerNorm((768,), eps=1e-12, elementwise_affine=True)\n", - " (dropout): Dropout(p=0.1, inplace=False)\n", - " )\n", - " )\n", - " )\n", - " )\n", - " (pooler): Linear(in_features=768, out_features=768, bias=True)\n", - " (activation): Tanh()\n", - " )\n", - " (cls): BigBirdOnlyMLMHead(\n", - " (predictions): BigBirdLMPredictionHead(\n", - " (transform): BigBirdPredictionHeadTransform(\n", - " (dense): Linear(in_features=768, out_features=768, bias=True)\n", - " (transform_act_fn): NewGELUActivation()\n", - " (LayerNorm): LayerNorm((768,), eps=1e-12, elementwise_affine=True)\n", - " )\n", - " (decoder): Linear(in_features=768, out_features=20592, bias=True)\n", - " )\n", - " )\n", - " )\n", - " )\n", - ")" - ] - }, - "execution_count": 17, - "metadata": {}, - "output_type": "execute_result" - } - ], - "source": [ - "model = load_finetuned_model(args.model_path, tokenizer)\n", - "model" - ] - }, - { - "cell_type": "code", - "execution_count": 18, - "id": "45bb295710f64bc0", - "metadata": { - "ExecuteTime": { - "end_time": "2024-03-15T15:52:23.262096Z", - "start_time": "2024-03-15T15:52:23.252429Z" - }, - "collapsed": false - }, - "outputs": [], - "source": [ - "fine_tune, fine_test = load_finetune_data(\n", - " args.data_dir,\n", - " args.sequence_file,\n", - " args.id_file,\n", - " args.valid_scheme,\n", - " args.num_finetune_patients,\n", - ")\n", - "\n", - "fine_tune.rename(columns={args.label_name: \"label\"}, inplace=True)\n", - "fine_test.rename(columns={args.label_name: \"label\"}, inplace=True)" - ] - }, - { - "cell_type": "code", - "execution_count": 25, - "id": "6c2f9f1f", - "metadata": {}, - "outputs": [], - "source": [ - "test_dataset = FinetuneDataset(\n", - " data=fine_test,\n", - " tokenizer=tokenizer,\n", - " max_len=args.max_len,\n", - ")\n", - "\n", - "test_loader = DataLoader(\n", - " Subset(test_dataset, [89, 90]), # 85 and 88 are small\n", - " batch_size=args.batch_size,\n", - ")" - ] - }, - { - "cell_type": "code", - "execution_count": 26, - "id": "ead6f3658dda5274", - "metadata": { - "ExecuteTime": { - "end_time": "2024-03-15T15:52:24.058853Z", - "start_time": "2024-03-15T15:52:24.027868Z" - }, - "collapsed": false - }, - "outputs": [ - { - "data": { - "text/plain": [ - "{'concept_ids': tensor([[ 5, 3, 18065, ..., 0, 0, 0]]),\n", - " 'type_ids': tensor([[1, 2, 6, ..., 0, 0, 0]]),\n", - " 
'ages': tensor([[ 0, 67, 67, ..., 0, 0, 0]]),\n", - " 'time_stamps': tensor([[ 0, 5773, 5773, ..., 0, 0, 0]]),\n", - " 'visit_orders': tensor([[0, 1, 1, ..., 0, 0, 0]]),\n", - " 'visit_segments': tensor([[0, 2, 2, ..., 0, 0, 0]]),\n", - " 'labels': tensor([0]),\n", - " 'attention_mask': tensor([[1, 1, 1, ..., 0, 0, 0]])}" - ] - }, - "execution_count": 26, - "metadata": {}, - "output_type": "execute_result" - } - ], - "source": [ - "patient = next(iter(test_loader))\n", - "patient" - ] - }, - { - "cell_type": "code", - "execution_count": 27, - "id": "b025c033efc9baab", - "metadata": { - "ExecuteTime": { - "end_time": "2024-03-15T15:52:25.797704Z", - "start_time": "2024-03-15T15:52:25.476580Z" - }, - "collapsed": false - }, - "outputs": [ - { - "data": { - "text/plain": [ - "SequenceClassifierOutput(loss=tensor(0.1007, device='cuda:0', grad_fn=), logits=tensor([[ 1.1001, -1.1447]], device='cuda:0', grad_fn=), hidden_states=(tensor([[[-0.6737, -0.4033, -1.0631, ..., -0.0571, -0.5688, -2.8107],\n", - " [-0.2583, 1.4081, 0.0247, ..., 0.9723, -0.1994, -0.6298],\n", - " [-0.3349, -0.7147, -0.2382, ..., 0.6685, -2.0221, 0.3603],\n", - " ...,\n", - " [-1.1487, -0.3301, -1.2087, ..., -1.4240, 0.6589, -3.1762],\n", - " [-0.5165, 0.4454, -1.0153, ..., -2.3917, -0.0144, -3.3358],\n", - " [-0.1893, -0.1348, 0.2863, ..., -2.4897, 0.7176, -2.4138]]],\n", - " device='cuda:0', grad_fn=), tensor([[[-9.1614e-01, 6.7605e-01, 6.2693e-03, ..., -3.1620e-01,\n", - " -4.6282e-01, -2.0648e+00],\n", - " [-6.0621e-01, 8.6801e-01, 2.6213e-01, ..., 3.4474e-01,\n", - " 1.3363e-01, -7.1667e-01],\n", - " [-7.6982e-01, 2.3106e-03, -5.9440e-01, ..., -4.7166e-01,\n", - " -1.8369e+00, 1.3552e-01],\n", - " ...,\n", - " [-8.6433e-01, 7.3538e-01, -3.3437e-01, ..., -1.2050e+00,\n", - " 1.3254e+00, -2.4991e+00],\n", - " [-7.3595e-01, 1.2425e+00, 5.3519e-03, ..., -1.8429e+00,\n", - " 9.7859e-01, -2.8092e+00],\n", - " [-5.8273e-01, 8.9714e-01, 8.5189e-01, ..., -1.7958e+00,\n", - " 1.6784e+00, -2.0120e+00]]], device='cuda:0',\n", - " grad_fn=), tensor([[[-0.5963, 1.0668, 0.6960, ..., -0.5670, -1.8491, -0.7775],\n", - " [-0.8661, 0.9432, 1.3912, ..., -0.1141, -0.0171, -0.0536],\n", - " [-0.4142, 0.1496, 0.0763, ..., -1.3690, -2.2675, 0.8359],\n", - " ...,\n", - " [-0.5217, 1.3100, 0.3550, ..., -0.8422, 0.7881, -2.2435],\n", - " [-0.6279, 1.6811, 0.6428, ..., -1.1621, 0.4885, -2.3759],\n", - " [-0.5811, 1.5146, 1.1337, ..., -1.0819, 1.0710, -2.0434]]],\n", - " device='cuda:0', grad_fn=), tensor([[[-0.4564, 0.8573, 0.4777, ..., -1.0131, -1.2389, -0.7232],\n", - " [-0.7503, 0.8020, 2.2105, ..., 0.3817, -0.5328, -1.1600],\n", - " [-0.2785, 0.3193, 0.7508, ..., -0.8207, -2.1549, -0.0479],\n", - " ...,\n", - " [-0.5749, 0.9360, 0.3542, ..., -1.2162, 0.5259, -1.9892],\n", - " [-0.8316, 1.2188, 0.4210, ..., -1.5618, 0.3291, -2.1654],\n", - " [-0.7359, 1.2650, 0.9190, ..., -1.4486, 0.7956, -1.8708]]],\n", - " device='cuda:0', grad_fn=), tensor([[[-0.6359, 1.1296, 0.3645, ..., -1.0155, -1.4793, -0.0216],\n", - " [-0.2409, 0.9339, 1.1703, ..., 1.1403, -1.0845, -0.8589],\n", - " [-0.2531, 0.2159, 0.2502, ..., -0.1171, -2.6449, -0.1982],\n", - " ...,\n", - " [-0.4492, 0.9411, 0.7911, ..., -0.9499, 0.3742, -1.3986],\n", - " [-0.7147, 1.2283, 0.7656, ..., -1.0737, 0.1806, -1.6471],\n", - " [-0.6436, 1.0595, 1.1752, ..., -1.1340, 0.4983, -1.3118]]],\n", - " device='cuda:0', grad_fn=), tensor([[[-0.0605, 0.4757, 1.2138, ..., -1.2107, -0.7777, 0.4597],\n", - " [-0.0597, 0.6793, 0.6923, ..., 0.5595, -0.6465, -0.4138],\n", - " [-0.2956, 0.0119, 0.0849, 
..., -0.4024, -1.7548, 0.2857],\n", - " ...,\n", - " [-0.0730, 1.2001, 0.8387, ..., -1.5022, 0.3392, -0.9774],\n", - " [-0.2355, 1.2640, 0.8785, ..., -1.6132, 0.2851, -1.2461],\n", - " [-0.2264, 1.1501, 1.2094, ..., -1.7025, 0.5599, -0.9631]]],\n", - " device='cuda:0', grad_fn=), tensor([[[-0.4529, 0.7745, 0.9677, ..., -0.7731, 0.3017, 0.6295],\n", - " [-0.2527, 0.7169, 0.3931, ..., 0.3904, -0.2026, 0.2191],\n", - " [-0.7347, 0.0648, 0.1225, ..., -0.4706, -0.9406, 0.7601],\n", - " ...,\n", - " [-0.1587, 1.8421, 0.7164, ..., -1.1803, 0.3151, -1.0949],\n", - " [-0.4175, 1.9637, 0.7340, ..., -1.2465, 0.3226, -1.4348],\n", - " [-0.4056, 1.7997, 1.0610, ..., -1.3635, 0.5409, -1.1068]]],\n", - " device='cuda:0', grad_fn=)), attentions=(tensor([[[[9.2084e-01, 3.4662e-05, 1.6030e-04, ..., 0.0000e+00,\n", - " 0.0000e+00, 0.0000e+00],\n", - " [1.0152e-04, 1.0245e-03, 5.0193e-03, ..., 0.0000e+00,\n", - " 0.0000e+00, 0.0000e+00],\n", - " [1.1340e-05, 3.4795e-04, 1.1249e-01, ..., 0.0000e+00,\n", - " 0.0000e+00, 0.0000e+00],\n", - " ...,\n", - " [3.7594e-01, 6.5887e-04, 2.1206e-03, ..., 0.0000e+00,\n", - " 0.0000e+00, 0.0000e+00],\n", - " [1.4787e-01, 1.3629e-03, 7.0707e-03, ..., 0.0000e+00,\n", - " 0.0000e+00, 0.0000e+00],\n", - " [1.9477e-01, 1.2923e-03, 4.1574e-03, ..., 0.0000e+00,\n", - " 0.0000e+00, 0.0000e+00]],\n", - "\n", - " [[2.2569e-05, 3.1873e-03, 1.4330e-02, ..., 0.0000e+00,\n", - " 0.0000e+00, 0.0000e+00],\n", - " [1.0942e-04, 7.4366e-03, 1.0470e-02, ..., 0.0000e+00,\n", - " 0.0000e+00, 0.0000e+00],\n", - " [4.2275e-07, 8.8712e-03, 2.0909e-02, ..., 0.0000e+00,\n", - " 0.0000e+00, 0.0000e+00],\n", - " ...,\n", - " [3.3820e-04, 5.4235e-03, 3.9131e-03, ..., 0.0000e+00,\n", - " 0.0000e+00, 0.0000e+00],\n", - " [2.1625e-05, 6.0716e-03, 9.2111e-03, ..., 0.0000e+00,\n", - " 0.0000e+00, 0.0000e+00],\n", - " [1.0256e-05, 6.8238e-03, 1.3608e-02, ..., 0.0000e+00,\n", - " 0.0000e+00, 0.0000e+00]],\n", - "\n", - " [[2.1552e-05, 3.2614e-03, 5.1249e-03, ..., 0.0000e+00,\n", - " 0.0000e+00, 0.0000e+00],\n", - " [3.0069e-02, 1.7049e-03, 7.4966e-03, ..., 0.0000e+00,\n", - " 0.0000e+00, 0.0000e+00],\n", - " [1.6985e-04, 7.9826e-05, 1.0739e-02, ..., 0.0000e+00,\n", - " 0.0000e+00, 0.0000e+00],\n", - " ...,\n", - " [7.2891e-05, 2.2444e-03, 8.0701e-03, ..., 0.0000e+00,\n", - " 0.0000e+00, 0.0000e+00],\n", - " [2.6524e-04, 2.5353e-03, 1.1610e-02, ..., 0.0000e+00,\n", - " 0.0000e+00, 0.0000e+00],\n", - " [8.4759e-05, 3.7951e-03, 9.3934e-03, ..., 0.0000e+00,\n", - " 0.0000e+00, 0.0000e+00]],\n", - "\n", - " ...,\n", - "\n", - " [[9.5957e-07, 1.8916e-05, 1.7706e-02, ..., 0.0000e+00,\n", - " 0.0000e+00, 0.0000e+00],\n", - " [1.4017e-03, 6.3007e-04, 8.2404e-03, ..., 0.0000e+00,\n", - " 0.0000e+00, 0.0000e+00],\n", - " [3.7625e-05, 2.4848e-05, 1.9090e-02, ..., 0.0000e+00,\n", - " 0.0000e+00, 0.0000e+00],\n", - " ...,\n", - " [1.4380e-05, 1.2709e-04, 2.4316e-02, ..., 0.0000e+00,\n", - " 0.0000e+00, 0.0000e+00],\n", - " [2.0306e-05, 1.8831e-04, 2.8286e-02, ..., 0.0000e+00,\n", - " 0.0000e+00, 0.0000e+00],\n", - " [1.8297e-05, 1.9117e-04, 2.4082e-02, ..., 0.0000e+00,\n", - " 0.0000e+00, 0.0000e+00]],\n", - "\n", - " [[1.7399e-03, 1.2051e-03, 9.5375e-04, ..., 0.0000e+00,\n", - " 0.0000e+00, 0.0000e+00],\n", - " [2.7488e-03, 2.7973e-03, 2.2183e-03, ..., 0.0000e+00,\n", - " 0.0000e+00, 0.0000e+00],\n", - " [1.6722e-03, 7.9718e-04, 2.1421e-03, ..., 0.0000e+00,\n", - " 0.0000e+00, 0.0000e+00],\n", - " ...,\n", - " [1.6540e-03, 2.6361e-03, 1.2919e-03, ..., 0.0000e+00,\n", - " 0.0000e+00, 0.0000e+00],\n", - " [1.7157e-03, 
2.6202e-03, 1.2555e-03, ..., 0.0000e+00,\n", - " 0.0000e+00, 0.0000e+00],\n", - " [1.6167e-03, 1.7538e-03, 1.1150e-03, ..., 0.0000e+00,\n", - " 0.0000e+00, 0.0000e+00]],\n", - "\n", - " [[1.5929e-03, 3.2053e-03, 2.2594e-03, ..., 0.0000e+00,\n", - " 0.0000e+00, 0.0000e+00],\n", - " [6.3881e-03, 1.7186e-03, 4.7067e-03, ..., 0.0000e+00,\n", - " 0.0000e+00, 0.0000e+00],\n", - " [8.7379e-03, 3.9178e-03, 5.7301e-03, ..., 0.0000e+00,\n", - " 0.0000e+00, 0.0000e+00],\n", - " ...,\n", - " [3.5827e-03, 2.9827e-03, 2.1182e-03, ..., 0.0000e+00,\n", - " 0.0000e+00, 0.0000e+00],\n", - " [3.3835e-03, 3.9488e-03, 2.3409e-03, ..., 0.0000e+00,\n", - " 0.0000e+00, 0.0000e+00],\n", - " [3.8265e-03, 3.3581e-03, 2.8842e-03, ..., 0.0000e+00,\n", - " 0.0000e+00, 0.0000e+00]]]], device='cuda:0', grad_fn=), tensor([[[[3.4325e-03, 3.6487e-03, 2.1270e-03, ..., 0.0000e+00,\n", - " 0.0000e+00, 0.0000e+00],\n", - " [8.6434e-03, 2.1695e-02, 6.8258e-03, ..., 0.0000e+00,\n", - " 0.0000e+00, 0.0000e+00],\n", - " [7.8946e-03, 5.1628e-02, 9.1872e-03, ..., 0.0000e+00,\n", - " 0.0000e+00, 0.0000e+00],\n", - " ...,\n", - " [6.2476e-03, 3.5812e-03, 3.8917e-03, ..., 0.0000e+00,\n", - " 0.0000e+00, 0.0000e+00],\n", - " [3.5847e-03, 5.8187e-03, 4.7901e-03, ..., 0.0000e+00,\n", - " 0.0000e+00, 0.0000e+00],\n", - " [2.7512e-03, 7.3790e-03, 5.0303e-03, ..., 0.0000e+00,\n", - " 0.0000e+00, 0.0000e+00]],\n", - "\n", - " [[3.9891e-05, 4.4143e-06, 7.2098e-05, ..., 0.0000e+00,\n", - " 0.0000e+00, 0.0000e+00],\n", - " [9.8795e-03, 2.2040e-02, 1.3027e-02, ..., 0.0000e+00,\n", - " 0.0000e+00, 0.0000e+00],\n", - " [7.2756e-02, 1.8536e-02, 1.8427e-02, ..., 0.0000e+00,\n", - " 0.0000e+00, 0.0000e+00],\n", - " ...,\n", - " [1.0230e-04, 1.0431e-04, 3.6400e-04, ..., 0.0000e+00,\n", - " 0.0000e+00, 0.0000e+00],\n", - " [1.0183e-04, 1.0241e-04, 4.4847e-04, ..., 0.0000e+00,\n", - " 0.0000e+00, 0.0000e+00],\n", - " [3.2058e-04, 2.5562e-04, 6.5510e-04, ..., 0.0000e+00,\n", - " 0.0000e+00, 0.0000e+00]],\n", - "\n", - " [[1.7265e-03, 6.2247e-06, 3.0965e-03, ..., 0.0000e+00,\n", - " 0.0000e+00, 0.0000e+00],\n", - " [6.5288e-05, 1.7324e-02, 6.0294e-03, ..., 0.0000e+00,\n", - " 0.0000e+00, 0.0000e+00],\n", - " [6.6073e-04, 1.7897e-03, 2.7479e-02, ..., 0.0000e+00,\n", - " 0.0000e+00, 0.0000e+00],\n", - " ...,\n", - " [4.5983e-03, 2.8453e-04, 4.3261e-03, ..., 0.0000e+00,\n", - " 0.0000e+00, 0.0000e+00],\n", - " [3.3822e-03, 2.4970e-04, 7.7571e-03, ..., 0.0000e+00,\n", - " 0.0000e+00, 0.0000e+00],\n", - " [3.9758e-03, 2.7208e-04, 6.9291e-03, ..., 0.0000e+00,\n", - " 0.0000e+00, 0.0000e+00]],\n", - "\n", - " ...,\n", - "\n", - " [[3.2307e-04, 1.6154e-07, 1.1893e-04, ..., 0.0000e+00,\n", - " 0.0000e+00, 0.0000e+00],\n", - " [1.5087e-04, 7.7591e-02, 4.4232e-04, ..., 0.0000e+00,\n", - " 0.0000e+00, 0.0000e+00],\n", - " [9.9669e-04, 7.1647e-03, 1.2542e-02, ..., 0.0000e+00,\n", - " 0.0000e+00, 0.0000e+00],\n", - " ...,\n", - " [1.6157e-03, 1.9115e-04, 8.7496e-04, ..., 0.0000e+00,\n", - " 0.0000e+00, 0.0000e+00],\n", - " [1.5545e-03, 6.9437e-04, 2.2060e-03, ..., 0.0000e+00,\n", - " 0.0000e+00, 0.0000e+00],\n", - " [2.5639e-03, 4.7817e-04, 1.0144e-03, ..., 0.0000e+00,\n", - " 0.0000e+00, 0.0000e+00]],\n", - "\n", - " [[9.0331e-02, 2.0968e-04, 1.5666e-04, ..., 0.0000e+00,\n", - " 0.0000e+00, 0.0000e+00],\n", - " [2.2050e-05, 4.9400e-03, 2.9039e-03, ..., 0.0000e+00,\n", - " 0.0000e+00, 0.0000e+00],\n", - " [2.8098e-04, 7.9459e-03, 1.0302e-02, ..., 0.0000e+00,\n", - " 0.0000e+00, 0.0000e+00],\n", - " ...,\n", - " [8.0900e-02, 5.0021e-03, 3.6893e-03, ..., 0.0000e+00,\n", - 
" 0.0000e+00, 0.0000e+00],\n", - " [8.7062e-02, 5.0122e-03, 5.1272e-03, ..., 0.0000e+00,\n", - " 0.0000e+00, 0.0000e+00],\n", - " [9.2155e-02, 4.1310e-03, 4.8125e-03, ..., 0.0000e+00,\n", - " 0.0000e+00, 0.0000e+00]],\n", - "\n", - " [[1.2098e-04, 4.5407e-08, 1.1781e-05, ..., 0.0000e+00,\n", - " 0.0000e+00, 0.0000e+00],\n", - " [4.4916e-05, 1.8071e-03, 2.9574e-04, ..., 0.0000e+00,\n", - " 0.0000e+00, 0.0000e+00],\n", - " [1.9490e-03, 3.8689e-03, 1.6880e-02, ..., 0.0000e+00,\n", - " 0.0000e+00, 0.0000e+00],\n", - " ...,\n", - " [4.2944e-03, 1.2684e-05, 3.9500e-04, ..., 0.0000e+00,\n", - " 0.0000e+00, 0.0000e+00],\n", - " [1.5202e-03, 9.5840e-06, 3.9704e-04, ..., 0.0000e+00,\n", - " 0.0000e+00, 0.0000e+00],\n", - " [3.4998e-03, 1.0348e-05, 4.0015e-04, ..., 0.0000e+00,\n", - " 0.0000e+00, 0.0000e+00]]]], device='cuda:0', grad_fn=), tensor([[[[9.0947e-03, 1.7104e-04, 7.2082e-04, ..., 0.0000e+00,\n", - " 0.0000e+00, 0.0000e+00],\n", - " [1.0905e-03, 8.2947e-03, 4.7806e-03, ..., 0.0000e+00,\n", - " 0.0000e+00, 0.0000e+00],\n", - " [1.2774e-02, 4.1130e-03, 1.0427e-02, ..., 0.0000e+00,\n", - " 0.0000e+00, 0.0000e+00],\n", - " ...,\n", - " [7.4360e-03, 2.5052e-03, 3.5916e-03, ..., 0.0000e+00,\n", - " 0.0000e+00, 0.0000e+00],\n", - " [9.5451e-03, 2.0411e-03, 4.1614e-03, ..., 0.0000e+00,\n", - " 0.0000e+00, 0.0000e+00],\n", - " [1.0480e-02, 1.5095e-03, 3.9783e-03, ..., 0.0000e+00,\n", - " 0.0000e+00, 0.0000e+00]],\n", - "\n", - " [[1.0953e-04, 4.8485e-07, 1.1093e-04, ..., 0.0000e+00,\n", - " 0.0000e+00, 0.0000e+00],\n", - " [8.7053e-06, 7.0000e-03, 1.3357e-04, ..., 0.0000e+00,\n", - " 0.0000e+00, 0.0000e+00],\n", - " [3.2624e-03, 4.3490e-03, 8.2277e-03, ..., 0.0000e+00,\n", - " 0.0000e+00, 0.0000e+00],\n", - " ...,\n", - " [8.3735e-03, 9.6488e-05, 1.6069e-02, ..., 0.0000e+00,\n", - " 0.0000e+00, 0.0000e+00],\n", - " [4.6108e-03, 6.0472e-05, 1.0010e-02, ..., 0.0000e+00,\n", - " 0.0000e+00, 0.0000e+00],\n", - " [8.4534e-03, 9.8078e-05, 1.4517e-02, ..., 0.0000e+00,\n", - " 0.0000e+00, 0.0000e+00]],\n", - "\n", - " [[3.1851e-03, 3.6355e-09, 4.5110e-06, ..., 0.0000e+00,\n", - " 0.0000e+00, 0.0000e+00],\n", - " [2.1155e-05, 1.3445e-03, 3.1936e-04, ..., 0.0000e+00,\n", - " 0.0000e+00, 0.0000e+00],\n", - " [7.9026e-03, 3.4348e-04, 3.4485e-03, ..., 0.0000e+00,\n", - " 0.0000e+00, 0.0000e+00],\n", - " ...,\n", - " [4.4469e-03, 1.2559e-05, 5.1731e-04, ..., 0.0000e+00,\n", - " 0.0000e+00, 0.0000e+00],\n", - " [5.1810e-03, 6.3667e-06, 5.4800e-04, ..., 0.0000e+00,\n", - " 0.0000e+00, 0.0000e+00],\n", - " [4.0399e-03, 6.1432e-06, 4.1856e-04, ..., 0.0000e+00,\n", - " 0.0000e+00, 0.0000e+00]],\n", - "\n", - " ...,\n", - "\n", - " [[1.9304e-02, 3.4070e-05, 1.5798e-03, ..., 0.0000e+00,\n", - " 0.0000e+00, 0.0000e+00],\n", - " [3.7734e-05, 1.5453e-02, 6.9360e-03, ..., 0.0000e+00,\n", - " 0.0000e+00, 0.0000e+00],\n", - " [4.5957e-04, 3.6940e-03, 5.8250e-03, ..., 0.0000e+00,\n", - " 0.0000e+00, 0.0000e+00],\n", - " ...,\n", - " [1.6903e-03, 1.8801e-03, 4.4015e-03, ..., 0.0000e+00,\n", - " 0.0000e+00, 0.0000e+00],\n", - " [3.2805e-03, 8.8522e-04, 5.1555e-03, ..., 0.0000e+00,\n", - " 0.0000e+00, 0.0000e+00],\n", - " [2.4006e-03, 1.2400e-03, 5.7507e-03, ..., 0.0000e+00,\n", - " 0.0000e+00, 0.0000e+00]],\n", - "\n", - " [[3.1703e-04, 4.1161e-05, 7.9415e-05, ..., 0.0000e+00,\n", - " 0.0000e+00, 0.0000e+00],\n", - " [4.8420e-04, 7.7026e-03, 7.9656e-04, ..., 0.0000e+00,\n", - " 0.0000e+00, 0.0000e+00],\n", - " [2.5139e-03, 2.4419e-03, 2.6276e-03, ..., 0.0000e+00,\n", - " 0.0000e+00, 0.0000e+00],\n", - " ...,\n", - " [2.7707e-03, 
6.9963e-04, 1.6709e-03, ..., 0.0000e+00,\n", - " 0.0000e+00, 0.0000e+00],\n", - " [2.0923e-03, 7.9106e-04, 2.5133e-03, ..., 0.0000e+00,\n", - " 0.0000e+00, 0.0000e+00],\n", - " [1.4068e-03, 2.6501e-04, 1.1517e-03, ..., 0.0000e+00,\n", - " 0.0000e+00, 0.0000e+00]],\n", - "\n", - " [[6.0330e-05, 1.8324e-06, 1.0670e-06, ..., 0.0000e+00,\n", - " 0.0000e+00, 0.0000e+00],\n", - " [7.0295e-04, 6.8225e-03, 4.1048e-03, ..., 0.0000e+00,\n", - " 0.0000e+00, 0.0000e+00],\n", - " [2.6344e-03, 2.1405e-03, 1.9475e-03, ..., 0.0000e+00,\n", - " 0.0000e+00, 0.0000e+00],\n", - " ...,\n", - " [1.0302e-03, 2.2379e-05, 8.3888e-05, ..., 0.0000e+00,\n", - " 0.0000e+00, 0.0000e+00],\n", - " [1.0699e-03, 3.1538e-05, 9.3927e-05, ..., 0.0000e+00,\n", - " 0.0000e+00, 0.0000e+00],\n", - " [1.1386e-03, 4.4208e-05, 1.2655e-04, ..., 0.0000e+00,\n", - " 0.0000e+00, 0.0000e+00]]]], device='cuda:0', grad_fn=), tensor([[[[1.0968e-03, 1.1158e-05, 2.8874e-04, ..., 0.0000e+00,\n", - " 0.0000e+00, 0.0000e+00],\n", - " [6.5882e-04, 1.9889e-02, 2.9087e-03, ..., 0.0000e+00,\n", - " 0.0000e+00, 0.0000e+00],\n", - " [2.0075e-03, 2.4018e-03, 5.6122e-03, ..., 0.0000e+00,\n", - " 0.0000e+00, 0.0000e+00],\n", - " ...,\n", - " [1.2534e-03, 5.0390e-05, 2.8556e-03, ..., 0.0000e+00,\n", - " 0.0000e+00, 0.0000e+00],\n", - " [7.8193e-04, 5.5810e-05, 3.2226e-03, ..., 0.0000e+00,\n", - " 0.0000e+00, 0.0000e+00],\n", - " [1.8834e-03, 1.9344e-04, 8.2757e-03, ..., 0.0000e+00,\n", - " 0.0000e+00, 0.0000e+00]],\n", - "\n", - " [[2.1752e-04, 1.0848e-03, 5.2472e-04, ..., 0.0000e+00,\n", - " 0.0000e+00, 0.0000e+00],\n", - " [6.3445e-05, 3.1222e-02, 7.3892e-04, ..., 0.0000e+00,\n", - " 0.0000e+00, 0.0000e+00],\n", - " [5.9519e-04, 2.3304e-02, 4.0696e-03, ..., 0.0000e+00,\n", - " 0.0000e+00, 0.0000e+00],\n", - " ...,\n", - " [7.4697e-04, 2.2354e-04, 5.6568e-04, ..., 0.0000e+00,\n", - " 0.0000e+00, 0.0000e+00],\n", - " [4.7230e-04, 4.3862e-04, 9.0070e-04, ..., 0.0000e+00,\n", - " 0.0000e+00, 0.0000e+00],\n", - " [4.0728e-04, 3.7051e-04, 8.7622e-04, ..., 0.0000e+00,\n", - " 0.0000e+00, 0.0000e+00]],\n", - "\n", - " [[6.4030e-04, 1.0463e-03, 1.2372e-03, ..., 0.0000e+00,\n", - " 0.0000e+00, 0.0000e+00],\n", - " [3.3409e-04, 2.6108e-02, 1.0480e-03, ..., 0.0000e+00,\n", - " 0.0000e+00, 0.0000e+00],\n", - " [2.3558e-03, 4.1270e-03, 3.1541e-03, ..., 0.0000e+00,\n", - " 0.0000e+00, 0.0000e+00],\n", - " ...,\n", - " [8.7033e-04, 1.3328e-04, 4.4750e-04, ..., 0.0000e+00,\n", - " 0.0000e+00, 0.0000e+00],\n", - " [2.1584e-03, 5.6131e-04, 1.9963e-03, ..., 0.0000e+00,\n", - " 0.0000e+00, 0.0000e+00],\n", - " [1.3833e-03, 1.7700e-04, 1.2080e-03, ..., 0.0000e+00,\n", - " 0.0000e+00, 0.0000e+00]],\n", - "\n", - " ...,\n", - "\n", - " [[3.0260e-04, 1.0307e-03, 5.4334e-04, ..., 0.0000e+00,\n", - " 0.0000e+00, 0.0000e+00],\n", - " [5.5736e-04, 4.4732e-03, 3.1108e-03, ..., 0.0000e+00,\n", - " 0.0000e+00, 0.0000e+00],\n", - " [1.0399e-03, 2.5841e-03, 3.4921e-03, ..., 0.0000e+00,\n", - " 0.0000e+00, 0.0000e+00],\n", - " ...,\n", - " [1.2584e-03, 3.3595e-04, 1.9949e-03, ..., 0.0000e+00,\n", - " 0.0000e+00, 0.0000e+00],\n", - " [1.4056e-03, 7.4162e-04, 3.0422e-03, ..., 0.0000e+00,\n", - " 0.0000e+00, 0.0000e+00],\n", - " [1.4776e-03, 3.4124e-04, 2.3692e-03, ..., 0.0000e+00,\n", - " 0.0000e+00, 0.0000e+00]],\n", - "\n", - " [[3.3945e-05, 2.8038e-06, 7.7244e-05, ..., 0.0000e+00,\n", - " 0.0000e+00, 0.0000e+00],\n", - " [2.1178e-03, 5.9071e-01, 3.8243e-04, ..., 0.0000e+00,\n", - " 0.0000e+00, 0.0000e+00],\n", - " [1.1613e-03, 9.0446e-03, 1.3656e-03, ..., 0.0000e+00,\n", - " 0.0000e+00, 
0.0000e+00],\n", - " ...,\n", - " [6.1137e-06, 8.0224e-07, 2.9675e-06, ..., 0.0000e+00,\n", - " 0.0000e+00, 0.0000e+00],\n", - " [2.1845e-05, 4.3952e-06, 1.5539e-05, ..., 0.0000e+00,\n", - " 0.0000e+00, 0.0000e+00],\n", - " [2.0935e-05, 7.2282e-06, 1.6119e-05, ..., 0.0000e+00,\n", - " 0.0000e+00, 0.0000e+00]],\n", - "\n", - " [[7.7009e-03, 9.1021e-04, 4.4472e-03, ..., 0.0000e+00,\n", - " 0.0000e+00, 0.0000e+00],\n", - " [1.0289e-04, 1.2443e-02, 8.9628e-04, ..., 0.0000e+00,\n", - " 0.0000e+00, 0.0000e+00],\n", - " [7.7645e-04, 5.7296e-03, 3.5397e-03, ..., 0.0000e+00,\n", - " 0.0000e+00, 0.0000e+00],\n", - " ...,\n", - " [2.8769e-04, 7.9107e-06, 6.3012e-05, ..., 0.0000e+00,\n", - " 0.0000e+00, 0.0000e+00],\n", - " [7.2969e-04, 4.7156e-05, 2.6878e-04, ..., 0.0000e+00,\n", - " 0.0000e+00, 0.0000e+00],\n", - " [6.0934e-04, 3.7384e-05, 2.2858e-04, ..., 0.0000e+00,\n", - " 0.0000e+00, 0.0000e+00]]]], device='cuda:0', grad_fn=), tensor([[[[6.7381e-04, 1.1823e-02, 7.4050e-04, ..., 0.0000e+00,\n", - " 0.0000e+00, 0.0000e+00],\n", - " [6.4863e-04, 4.0161e-02, 3.6514e-03, ..., 0.0000e+00,\n", - " 0.0000e+00, 0.0000e+00],\n", - " [4.6004e-03, 1.9051e-02, 5.0602e-03, ..., 0.0000e+00,\n", - " 0.0000e+00, 0.0000e+00],\n", - " ...,\n", - " [3.5047e-03, 1.9051e-03, 1.5601e-03, ..., 0.0000e+00,\n", - " 0.0000e+00, 0.0000e+00],\n", - " [3.6917e-03, 3.4615e-03, 1.6968e-03, ..., 0.0000e+00,\n", - " 0.0000e+00, 0.0000e+00],\n", - " [4.7805e-03, 3.3552e-03, 2.2134e-03, ..., 0.0000e+00,\n", - " 0.0000e+00, 0.0000e+00]],\n", - "\n", - " [[3.1542e-04, 1.2458e-05, 3.5918e-04, ..., 0.0000e+00,\n", - " 0.0000e+00, 0.0000e+00],\n", - " [1.5336e-03, 2.6060e-03, 9.9650e-04, ..., 0.0000e+00,\n", - " 0.0000e+00, 0.0000e+00],\n", - " [2.3570e-03, 4.4454e-04, 2.1211e-03, ..., 0.0000e+00,\n", - " 0.0000e+00, 0.0000e+00],\n", - " ...,\n", - " [3.8644e-05, 1.1354e-06, 2.5000e-05, ..., 0.0000e+00,\n", - " 0.0000e+00, 0.0000e+00],\n", - " [3.0465e-05, 9.6188e-07, 2.3662e-05, ..., 0.0000e+00,\n", - " 0.0000e+00, 0.0000e+00],\n", - " [5.5490e-05, 1.7402e-06, 4.0543e-05, ..., 0.0000e+00,\n", - " 0.0000e+00, 0.0000e+00]],\n", - "\n", - " [[4.4482e-03, 4.3810e-07, 2.4999e-04, ..., 0.0000e+00,\n", - " 0.0000e+00, 0.0000e+00],\n", - " [3.1627e-04, 5.7318e-02, 2.5052e-03, ..., 0.0000e+00,\n", - " 0.0000e+00, 0.0000e+00],\n", - " [4.0123e-03, 1.0834e-03, 3.6955e-03, ..., 0.0000e+00,\n", - " 0.0000e+00, 0.0000e+00],\n", - " ...,\n", - " [1.2926e-03, 2.7957e-07, 6.0144e-05, ..., 0.0000e+00,\n", - " 0.0000e+00, 0.0000e+00],\n", - " [1.3540e-03, 2.3200e-07, 7.3312e-05, ..., 0.0000e+00,\n", - " 0.0000e+00, 0.0000e+00],\n", - " [1.2753e-03, 5.4792e-07, 9.2235e-05, ..., 0.0000e+00,\n", - " 0.0000e+00, 0.0000e+00]],\n", - "\n", - " ...,\n", - "\n", - " [[1.8663e-05, 6.3300e-07, 1.4228e-06, ..., 0.0000e+00,\n", - " 0.0000e+00, 0.0000e+00],\n", - " [1.2953e-03, 1.2528e-02, 3.5547e-03, ..., 0.0000e+00,\n", - " 0.0000e+00, 0.0000e+00],\n", - " [2.9098e-03, 1.1120e-03, 1.3087e-03, ..., 0.0000e+00,\n", - " 0.0000e+00, 0.0000e+00],\n", - " ...,\n", - " [1.1750e-04, 1.4822e-05, 2.3723e-05, ..., 0.0000e+00,\n", - " 0.0000e+00, 0.0000e+00],\n", - " [6.7225e-05, 8.3810e-06, 1.6070e-05, ..., 0.0000e+00,\n", - " 0.0000e+00, 0.0000e+00],\n", - " [1.6516e-04, 2.6572e-05, 4.2221e-05, ..., 0.0000e+00,\n", - " 0.0000e+00, 0.0000e+00]],\n", - "\n", - " [[4.5335e-05, 4.0905e-04, 2.3216e-04, ..., 0.0000e+00,\n", - " 0.0000e+00, 0.0000e+00],\n", - " [2.7303e-04, 5.1438e-02, 3.8481e-03, ..., 0.0000e+00,\n", - " 0.0000e+00, 0.0000e+00],\n", - " [6.3572e-04, 1.6108e-02, 
4.7815e-03, ..., 0.0000e+00,\n", - " 0.0000e+00, 0.0000e+00],\n", - " ...,\n", - " [1.6582e-03, 9.6089e-04, 3.1044e-03, ..., 0.0000e+00,\n", - " 0.0000e+00, 0.0000e+00],\n", - " [1.1655e-03, 1.2231e-03, 3.3471e-03, ..., 0.0000e+00,\n", - " 0.0000e+00, 0.0000e+00],\n", - " [1.5900e-03, 1.6183e-03, 5.6334e-03, ..., 0.0000e+00,\n", - " 0.0000e+00, 0.0000e+00]],\n", - "\n", - " [[1.3592e-03, 9.1240e-05, 1.9648e-03, ..., 0.0000e+00,\n", - " 0.0000e+00, 0.0000e+00],\n", - " [4.9554e-04, 3.4671e-02, 1.2307e-03, ..., 0.0000e+00,\n", - " 0.0000e+00, 0.0000e+00],\n", - " [3.1695e-03, 3.3091e-03, 4.3312e-03, ..., 0.0000e+00,\n", - " 0.0000e+00, 0.0000e+00],\n", - " ...,\n", - " [4.5615e-04, 4.8380e-07, 7.2619e-05, ..., 0.0000e+00,\n", - " 0.0000e+00, 0.0000e+00],\n", - " [1.2036e-03, 3.6028e-06, 3.0705e-04, ..., 0.0000e+00,\n", - " 0.0000e+00, 0.0000e+00],\n", - " [1.2138e-03, 3.5186e-06, 2.6554e-04, ..., 0.0000e+00,\n", - " 0.0000e+00, 0.0000e+00]]]], device='cuda:0', grad_fn=), tensor([[[[4.0090e-03, 6.7263e-04, 7.5136e-04, ..., 0.0000e+00,\n", - " 0.0000e+00, 0.0000e+00],\n", - " [3.9121e-04, 1.2857e-02, 3.9387e-03, ..., 0.0000e+00,\n", - " 0.0000e+00, 0.0000e+00],\n", - " [1.5724e-03, 4.5150e-03, 2.8897e-03, ..., 0.0000e+00,\n", - " 0.0000e+00, 0.0000e+00],\n", - " ...,\n", - " [4.5598e-03, 1.4921e-04, 2.2929e-04, ..., 0.0000e+00,\n", - " 0.0000e+00, 0.0000e+00],\n", - " [5.6446e-03, 1.8594e-04, 2.9366e-04, ..., 0.0000e+00,\n", - " 0.0000e+00, 0.0000e+00],\n", - " [5.2847e-03, 2.3700e-04, 3.6638e-04, ..., 0.0000e+00,\n", - " 0.0000e+00, 0.0000e+00]],\n", - "\n", - " [[1.9915e-02, 4.0394e-05, 5.0094e-04, ..., 0.0000e+00,\n", - " 0.0000e+00, 0.0000e+00],\n", - " [5.3356e-04, 3.0291e-03, 1.2489e-03, ..., 0.0000e+00,\n", - " 0.0000e+00, 0.0000e+00],\n", - " [2.3452e-03, 8.8680e-04, 1.3824e-03, ..., 0.0000e+00,\n", - " 0.0000e+00, 0.0000e+00],\n", - " ...,\n", - " [1.0090e-01, 1.4063e-05, 1.0046e-03, ..., 0.0000e+00,\n", - " 0.0000e+00, 0.0000e+00],\n", - " [1.1156e-01, 1.0029e-05, 8.7870e-04, ..., 0.0000e+00,\n", - " 0.0000e+00, 0.0000e+00],\n", - " [9.5568e-02, 1.6524e-05, 1.0813e-03, ..., 0.0000e+00,\n", - " 0.0000e+00, 0.0000e+00]],\n", - "\n", - " [[1.8186e-02, 5.9815e-06, 5.0354e-04, ..., 0.0000e+00,\n", - " 0.0000e+00, 0.0000e+00],\n", - " [1.1144e-04, 3.3203e-03, 1.0907e-03, ..., 0.0000e+00,\n", - " 0.0000e+00, 0.0000e+00],\n", - " [1.1076e-03, 7.5293e-04, 1.7252e-03, ..., 0.0000e+00,\n", - " 0.0000e+00, 0.0000e+00],\n", - " ...,\n", - " [1.2483e-01, 2.2352e-07, 2.1154e-04, ..., 0.0000e+00,\n", - " 0.0000e+00, 0.0000e+00],\n", - " [1.1819e-01, 1.1773e-07, 1.4315e-04, ..., 0.0000e+00,\n", - " 0.0000e+00, 0.0000e+00],\n", - " [1.0731e-01, 2.1415e-07, 1.9841e-04, ..., 0.0000e+00,\n", - " 0.0000e+00, 0.0000e+00]],\n", - "\n", - " ...,\n", - "\n", - " [[8.9774e-03, 2.6941e-04, 3.7395e-04, ..., 0.0000e+00,\n", - " 0.0000e+00, 0.0000e+00],\n", - " [1.7763e-04, 4.7386e-03, 2.2903e-03, ..., 0.0000e+00,\n", - " 0.0000e+00, 0.0000e+00],\n", - " [1.2463e-03, 2.2257e-03, 1.9620e-03, ..., 0.0000e+00,\n", - " 0.0000e+00, 0.0000e+00],\n", - " ...,\n", - " [1.1848e-02, 8.8093e-05, 1.7620e-04, ..., 0.0000e+00,\n", - " 0.0000e+00, 0.0000e+00],\n", - " [1.3670e-02, 8.8600e-05, 1.7867e-04, ..., 0.0000e+00,\n", - " 0.0000e+00, 0.0000e+00],\n", - " [1.2773e-02, 1.0601e-04, 2.1455e-04, ..., 0.0000e+00,\n", - " 0.0000e+00, 0.0000e+00]],\n", - "\n", - " [[2.7395e-02, 2.0985e-04, 9.8349e-04, ..., 0.0000e+00,\n", - " 0.0000e+00, 0.0000e+00],\n", - " [1.3505e-04, 3.5468e-03, 1.4784e-03, ..., 0.0000e+00,\n", - " 
0.0000e+00, 0.0000e+00],\n", - " [1.9124e-03, 1.9014e-03, 2.2491e-03, ..., 0.0000e+00,\n", - " 0.0000e+00, 0.0000e+00],\n", - " ...,\n", - " [2.7175e-02, 1.4688e-05, 9.3223e-05, ..., 0.0000e+00,\n", - " 0.0000e+00, 0.0000e+00],\n", - " [2.9626e-02, 9.9576e-06, 9.4240e-05, ..., 0.0000e+00,\n", - " 0.0000e+00, 0.0000e+00],\n", - " [2.6579e-02, 1.2380e-05, 9.0343e-05, ..., 0.0000e+00,\n", - " 0.0000e+00, 0.0000e+00]],\n", - "\n", - " [[2.5388e-03, 4.6041e-05, 9.6215e-05, ..., 0.0000e+00,\n", - " 0.0000e+00, 0.0000e+00],\n", - " [2.3881e-05, 1.0036e-02, 3.3659e-03, ..., 0.0000e+00,\n", - " 0.0000e+00, 0.0000e+00],\n", - " [9.2542e-04, 6.3507e-03, 5.8799e-03, ..., 0.0000e+00,\n", - " 0.0000e+00, 0.0000e+00],\n", - " ...,\n", - " [1.5240e-04, 1.0778e-08, 7.3738e-08, ..., 0.0000e+00,\n", - " 0.0000e+00, 0.0000e+00],\n", - " [1.5682e-04, 1.5447e-08, 1.0267e-07, ..., 0.0000e+00,\n", - " 0.0000e+00, 0.0000e+00],\n", - " [2.6646e-04, 3.5608e-08, 2.2134e-07, ..., 0.0000e+00,\n", - " 0.0000e+00, 0.0000e+00]]]], device='cuda:0', grad_fn=)))" - ] - }, - "execution_count": 27, - "metadata": {}, - "output_type": "execute_result" - } - ], - "source": [ - "output = predict_patient_outcomes(patient, model)\n", - "output" - ] - }, - { - "cell_type": "code", - "execution_count": 28, - "id": "e1aa8309b2650820", - "metadata": { - "ExecuteTime": { - "end_time": "2024-03-15T15:52:33.578704Z", - "start_time": "2024-03-15T15:52:33.571290Z" - }, - "collapsed": false - }, - "outputs": [], - "source": [ - "tokens = tokenizer.decode(patient[\"concept_ids\"].squeeze(0).cpu().numpy()).split(\" \")\n", - "truncate_at = patient[\"attention_mask\"].sum().numpy()\n", - "attention_matrix = output[\"attentions\"]\n", - "last_attention_matrix = attention_matrix[-1].detach()\n", - "# batch_size x num_heads x max_len x max_len x num_layers" - ] - }, - { - "cell_type": "code", - "execution_count": 29, - "id": "e9203f59", - "metadata": {}, - "outputs": [], - "source": [ - "truncated_attention_matrix = []\n", - "\n", - "for i in range(len(attention_matrix)):\n", - " truncated_attention_matrix.append(\n", - " attention_matrix[i][:, :, :truncate_at, :truncate_at],\n", - " )\n", - "\n", - "truncated_attention_matrix = tuple(truncated_attention_matrix)\n", - "truncated_tokens = tokens[:truncate_at]" - ] - }, - { - "cell_type": "code", - "execution_count": 30, - "id": "7fb27b941602401d91542211134fc71a", - "metadata": {}, - "outputs": [ - { - "name": "stdout", - "output_type": "stream", - "text": [ - "Token [CLS] with Token 50893_0: Attention Value 0.154\n", - "Token [CLS] with Token 51251_1: Attention Value 0.026\n", - "Token [CLS] with Token 51279_3: Attention Value 0.018\n", - "Token [CLS] with Token 50983_0: Attention Value 0.014\n", - "Token [CLS] with Token 51257_0: Attention Value 0.013\n", - "Token [CLS] with Token 52069_0: Attention Value 0.012\n", - "Token [CLS] with Token 50902_0: Attention Value 0.011\n", - "Token [CLS] with Token 51277_1: Attention Value 0.010\n", - "Token [CLS] with Token 50878_4: Attention Value 0.009\n", - "Token [CLS] with Token 00641607825: Attention Value 0.009\n", - "Token [CLS] with Token 00641607825: Attention Value 0.009\n", - "Token [CLS] with Token 51133_2: Attention Value 0.009\n", - "Token [CLS] with Token [CLS]: Attention Value 0.008\n", - "Token [CLS] with Token 51255_1: Attention Value 0.008\n", - "Token [CLS] with Token 61958040101: Attention Value 0.008\n" - ] - }, - { - "data": { - "application/vnd.plotly.v1+json": { - "config": { - "plotlyServerURL": "https://plot.ly" - }, - "data": [ - { - 
"colorscale": [ - [ - 0, - "rgb(255,255,217)" - ], - [ - 0.125, - "rgb(237,248,177)" - ], - [ - 0.25, - "rgb(199,233,180)" - ], - [ - 0.375, - "rgb(127,205,187)" - ], - [ - 0.5, - "rgb(65,182,196)" - ], - [ - 0.625, - "rgb(29,145,192)" - ], - [ - 0.75, - "rgb(34,94,168)" - ], - [ - 0.875, - "rgb(37,52,148)" - ], - [ - 1, - "rgb(8,29,88)" - ] - ], - "hoverinfo": "text", - "hoverongaps": false, - "text": [ - [ - "Token [CLS] with Token [CLS]:Attention Value 0.008", - "Token [CLS] with Token [VS]:Attention Value 0.000", - "Token [CLS] with Token 00121048910:Attention Value 0.001", - "Token [CLS] with Token 00338004904:Attention Value 0.003", - "Token [CLS] with Token 00121048910:Attention Value 0.001", - "Token [CLS] with Token 00006494300:Attention Value 0.000", - "Token [CLS] with Token 00904272561:Attention Value 0.001", - "Token [CLS] with Token 61958040101:Attention Value 0.004", - "Token [CLS] with Token 00487950101:Attention Value 0.002", - "Token [CLS] with Token 00904198261:Attention Value 0.001", - "Token [CLS] with Token 00173071920:Attention Value 0.007", - "Token [CLS] with Token 00121048910:Attention Value 0.001", - "Token [CLS] with Token 60505011300:Attention Value 0.001", - "Token [CLS] with Token 00338004904:Attention Value 0.004", - "Token [CLS] with Token 61958040101:Attention Value 0.008", - "Token [CLS] with Token 00597008717:Attention Value 0.004", - "Token [CLS] with Token 00338004904:Attention Value 0.004", - "Token [CLS] with Token 42292000120:Attention Value 0.002", - "Token [CLS] with Token 00904639161:Attention Value 0.001", - "Token [CLS] with Token 00338004904:Attention Value 0.005", - "Token [CLS] with Token 00472030680:Attention Value 0.003", - "Token [CLS] with Token 00143989701:Attention Value 0.001", - "Token [CLS] with Token 00641607825:Attention Value 0.009", - "Token [CLS] with Token 00338004904:Attention Value 0.004", - "Token [CLS] with Token 51079096620:Attention Value 0.001", - "Token [CLS] with Token 00409128331:Attention Value 0.002", - "Token [CLS] with Token 63323031461:Attention Value 0.000", - "Token [CLS] with Token 00009513503:Attention Value 0.001", - "Token [CLS] with Token 00056051030:Attention Value 0.004", - "Token [CLS] with Token 66758016013:Attention Value 0.004", - "Token [CLS] with Token 00338004904:Attention Value 0.003", - "Token [CLS] with Token 00338004904:Attention Value 0.003", - "Token [CLS] with Token 00071015892:Attention Value 0.004", - "Token [CLS] with Token 00338004904:Attention Value 0.006", - "Token [CLS] with Token 11523726808:Attention Value 0.001", - "Token [CLS] with Token 00904198861:Attention Value 0.000", - "Token [CLS] with Token 00071015892:Attention Value 0.003", - "Token [CLS] with Token 00487020101:Attention Value 0.002", - "Token [CLS] with Token 63739035810:Attention Value 0.001", - "Token [CLS] with Token 00904635361:Attention Value 0.003", - "Token [CLS] with Token 76439034310:Attention Value 0.005", - "Token [CLS] with Token 57896042101:Attention Value 0.003", - "Token [CLS] with Token 00056051030:Attention Value 0.005", - "Token [CLS] with Token 42292000120:Attention Value 0.002", - "Token [CLS] with Token 51079054720:Attention Value 0.004", - "Token [CLS] with Token 00904224461:Attention Value 0.000", - "Token [CLS] with Token 51079098320:Attention Value 0.002", - "Token [CLS] with Token 61958060101:Attention Value 0.003", - "Token [CLS] with Token 51079098320:Attention Value 0.001", - "Token [CLS] with Token 00641607825:Attention Value 0.009", - "Token [CLS] with Token 49281041688:Attention Value 
0.005", - "Token [CLS] with Token 00904516561:Attention Value 0.002", - "Token [CLS] with Token 00121048910:Attention Value 0.001", - "Token [CLS] with Token 00173068224:Attention Value 0.005", - "Token [CLS] with Token 00121048910:Attention Value 0.001", - "Token [CLS] with Token 63323026201:Attention Value 0.000", - "Token [CLS] with Token 57896045208:Attention Value 0.004", - "Token [CLS] with Token 63323031461:Attention Value 0.001", - "Token [CLS] with Token 00487020101:Attention Value 0.002", - "Token [CLS] with Token 52074_3:Attention Value 0.000", - "Token [CLS] with Token 52073_2:Attention Value 0.000", - "Token [CLS] with Token 52069_2:Attention Value 0.000", - "Token [CLS] with Token 51301_1:Attention Value 0.000", - "Token [CLS] with Token 51279_3:Attention Value 0.001", - "Token [CLS] with Token 51277_1:Attention Value 0.001", - "Token [CLS] with Token 51265_2:Attention Value 0.000", - "Token [CLS] with Token 51256_1:Attention Value 0.001", - "Token [CLS] with Token 51254_4:Attention Value 0.001", - "Token [CLS] with Token 51250_2:Attention Value 0.001", - "Token [CLS] with Token 51248_3:Attention Value 0.001", - "Token [CLS] with Token 51244_3:Attention Value 0.000", - "Token [CLS] with Token 51222_4:Attention Value 0.001", - "Token [CLS] with Token 51221_3:Attention Value 0.002", - "Token [CLS] with Token 51200_2:Attention Value 0.001", - "Token [CLS] with Token 51146_2:Attention Value 0.001", - "Token [CLS] with Token 51133_3:Attention Value 0.001", - "Token [CLS] with Token 52075_2:Attention Value 0.001", - "Token [CLS] with Token 52074_2:Attention Value 0.001", - "Token [CLS] with Token 52073_2:Attention Value 0.001", - "Token [CLS] with Token 52069_1:Attention Value 0.001", - "Token [CLS] with Token 51301_1:Attention Value 0.001", - "Token [CLS] with Token 51279_3:Attention Value 0.002", - "Token [CLS] with Token 51277_1:Attention Value 0.001", - "Token [CLS] with Token 51265_1:Attention Value 0.000", - "Token [CLS] with Token 51256_3:Attention Value 0.001", - "Token [CLS] with Token 51254_3:Attention Value 0.001", - "Token [CLS] with Token 51250_3:Attention Value 0.001", - "Token [CLS] with Token 51248_3:Attention Value 0.001", - "Token [CLS] with Token 51244_1:Attention Value 0.000", - "Token [CLS] with Token 51222_3:Attention Value 0.001", - "Token [CLS] with Token 51221_3:Attention Value 0.001", - "Token [CLS] with Token 51200_2:Attention Value 0.001", - "Token [CLS] with Token 51146_1:Attention Value 0.001", - "Token [CLS] with Token 51133_1:Attention Value 0.001", - "Token [CLS] with Token 51006_0:Attention Value 0.001", - "Token [CLS] with Token 50983_0:Attention Value 0.002", - "Token [CLS] with Token 50971_0:Attention Value 0.001", - "Token [CLS] with Token 50970_1:Attention Value 0.001", - "Token [CLS] with Token 50960_2:Attention Value 0.001", - "Token [CLS] with Token 50931_2:Attention Value 0.001", - "Token [CLS] with Token 50912_3:Attention Value 0.002", - "Token [CLS] with Token 50902_1:Attention Value 0.001", - "Token [CLS] with Token 50893_0:Attention Value 0.003", - "Token [CLS] with Token 50885_1:Attention Value 0.001", - "Token [CLS] with Token 50882_0:Attention Value 0.001", - "Token [CLS] with Token 50878_3:Attention Value 0.001", - "Token [CLS] with Token 50868_4:Attention Value 0.001", - "Token [CLS] with Token 50863_1:Attention Value 0.001", - "Token [CLS] with Token 50861_3:Attention Value 0.001", - "Token [CLS] with Token 51516_1:Attention Value 0.001", - "Token [CLS] with Token 51498_4:Attention Value 0.001", - "Token [CLS] with Token 
51493_2:Attention Value 0.000", - "Token [CLS] with Token 51492_1:Attention Value 0.001", - "Token [CLS] with Token 51491_2:Attention Value 0.001", - "Token [CLS] with Token 51009_0:Attention Value 0.001", - "Token [CLS] with Token 51006_1:Attention Value 0.002", - "Token [CLS] with Token 50983_0:Attention Value 0.004", - "Token [CLS] with Token 50971_0:Attention Value 0.002", - "Token [CLS] with Token 50970_1:Attention Value 0.001", - "Token [CLS] with Token 50960_0:Attention Value 0.001", - "Token [CLS] with Token 50931_2:Attention Value 0.001", - "Token [CLS] with Token 50912_3:Attention Value 0.003", - "Token [CLS] with Token 50902_2:Attention Value 0.002", - "Token [CLS] with Token 50893_0:Attention Value 0.003", - "Token [CLS] with Token 50882_0:Attention Value 0.002", - "Token [CLS] with Token 50868_4:Attention Value 0.001", - "Token [CLS] with Token 51301_2:Attention Value 0.002", - "Token [CLS] with Token 51279_3:Attention Value 0.002", - "Token [CLS] with Token 51277_1:Attention Value 0.002", - "Token [CLS] with Token 51265_1:Attention Value 0.001", - "Token [CLS] with Token 51250_3:Attention Value 0.002", - "Token [CLS] with Token 51248_3:Attention Value 0.001", - "Token [CLS] with Token 51222_3:Attention Value 0.002", - "Token [CLS] with Token 51221_3:Attention Value 0.001", - "Token [CLS] with Token 50813_3:Attention Value 0.002", - "Token [CLS] with Token 51006_1:Attention Value 0.002", - "Token [CLS] with Token 50983_1:Attention Value 0.002", - "Token [CLS] with Token 50971_1:Attention Value 0.002", - "Token [CLS] with Token 50970_3:Attention Value 0.001", - "Token [CLS] with Token 50960_0:Attention Value 0.001", - "Token [CLS] with Token 50931_2:Attention Value 0.001", - "Token [CLS] with Token 50912_3:Attention Value 0.003", - "Token [CLS] with Token 50911_0:Attention Value 0.001", - "Token [CLS] with Token 50902_1:Attention Value 0.001", - "Token [CLS] with Token 50893_0:Attention Value 0.002", - "Token [CLS] with Token 50882_0:Attention Value 0.002", - "Token [CLS] with Token 50868_4:Attention Value 0.002", - "Token [CLS] with Token 51301_3:Attention Value 0.002", - "Token [CLS] with Token 51279_4:Attention Value 0.002", - "Token [CLS] with Token 51277_1:Attention Value 0.002", - "Token [CLS] with Token 51265_2:Attention Value 0.001", - "Token [CLS] with Token 51250_3:Attention Value 0.002", - "Token [CLS] with Token 51248_4:Attention Value 0.001", - "Token [CLS] with Token 51222_4:Attention Value 0.002", - "Token [CLS] with Token 51221_4:Attention Value 0.001", - "Token [CLS] with Token 51516_3:Attention Value 0.001", - "Token [CLS] with Token 51498_4:Attention Value 0.001", - "Token [CLS] with Token 51493_2:Attention Value 0.001", - "Token [CLS] with Token 51491_2:Attention Value 0.001", - "Token [CLS] with Token 51482_2:Attention Value 0.001", - "Token [CLS] with Token 50813_4:Attention Value 0.003", - "Token [CLS] with Token 51301_4:Attention Value 0.002", - "Token [CLS] with Token 51279_4:Attention Value 0.002", - "Token [CLS] with Token 51277_1:Attention Value 0.002", - "Token [CLS] with Token 51265_2:Attention Value 0.001", - "Token [CLS] with Token 51250_3:Attention Value 0.002", - "Token [CLS] with Token 51248_3:Attention Value 0.001", - "Token [CLS] with Token 51222_4:Attention Value 0.002", - "Token [CLS] with Token 51221_4:Attention Value 0.001", - "Token [CLS] with Token 51006_1:Attention Value 0.002", - "Token [CLS] with Token 50983_1:Attention Value 0.003", - "Token [CLS] with Token 50971_1:Attention Value 0.003", - "Token [CLS] with Token 
50970_1:Attention Value 0.002", - "Token [CLS] with Token 50960_0:Attention Value 0.001", - "Token [CLS] with Token 50931_4:Attention Value 0.001", - "Token [CLS] with Token 50912_3:Attention Value 0.005", - "Token [CLS] with Token 50902_2:Attention Value 0.003", - "Token [CLS] with Token 50893_2:Attention Value 0.002", - "Token [CLS] with Token 50882_0:Attention Value 0.002", - "Token [CLS] with Token 50868_4:Attention Value 0.002", - "Token [CLS] with Token 52135_4:Attention Value 0.003", - "Token [CLS] with Token 52135_2:Attention Value 0.002", - "Token [CLS] with Token 52075_4:Attention Value 0.002", - "Token [CLS] with Token 52074_4:Attention Value 0.002", - "Token [CLS] with Token 52073_3:Attention Value 0.001", - "Token [CLS] with Token 52069_3:Attention Value 0.002", - "Token [CLS] with Token 51301_4:Attention Value 0.002", - "Token [CLS] with Token 51279_4:Attention Value 0.003", - "Token [CLS] with Token 51277_0:Attention Value 0.004", - "Token [CLS] with Token 51265_4:Attention Value 0.001", - "Token [CLS] with Token 51256_3:Attention Value 0.003", - "Token [CLS] with Token 51254_3:Attention Value 0.002", - "Token [CLS] with Token 51250_3:Attention Value 0.001", - "Token [CLS] with Token 51248_3:Attention Value 0.001", - "Token [CLS] with Token 51244_1:Attention Value 0.001", - "Token [CLS] with Token 51222_4:Attention Value 0.001", - "Token [CLS] with Token 51221_4:Attention Value 0.001", - "Token [CLS] with Token 51200_2:Attention Value 0.001", - "Token [CLS] with Token 51146_2:Attention Value 0.001", - "Token [CLS] with Token 51133_2:Attention Value 0.001", - "Token [CLS] with Token 51006_1:Attention Value 0.002", - "Token [CLS] with Token 50983_1:Attention Value 0.002", - "Token [CLS] with Token 50971_3:Attention Value 0.002", - "Token [CLS] with Token 50970_3:Attention Value 0.000", - "Token [CLS] with Token 50960_3:Attention Value 0.001", - "Token [CLS] with Token 50931_2:Attention Value 0.001", - "Token [CLS] with Token 50912_2:Attention Value 0.003", - "Token [CLS] with Token 50902_2:Attention Value 0.002", - "Token [CLS] with Token 50893_3:Attention Value 0.001", - "Token [CLS] with Token 50882_0:Attention Value 0.001", - "Token [CLS] with Token 50868_4:Attention Value 0.001", - "Token [CLS] with Token 51301_2:Attention Value 0.001", - "Token [CLS] with Token 51279_4:Attention Value 0.001", - "Token [CLS] with Token 51277_0:Attention Value 0.001", - "Token [CLS] with Token 51265_2:Attention Value 0.000", - "Token [CLS] with Token 51250_3:Attention Value 0.001", - "Token [CLS] with Token 51248_3:Attention Value 0.001", - "Token [CLS] with Token 51222_4:Attention Value 0.001", - "Token [CLS] with Token 51221_4:Attention Value 0.001", - "Token [CLS] with Token 51006_1:Attention Value 0.001", - "Token [CLS] with Token 50983_1:Attention Value 0.001", - "Token [CLS] with Token 50971_1:Attention Value 0.001", - "Token [CLS] with Token 50970_2:Attention Value 0.000", - "Token [CLS] with Token 50960_3:Attention Value 0.001", - "Token [CLS] with Token 50931_2:Attention Value 0.000", - "Token [CLS] with Token 50912_2:Attention Value 0.001", - "Token [CLS] with Token 50902_2:Attention Value 0.001", - "Token [CLS] with Token 50893_1:Attention Value 0.001", - "Token [CLS] with Token 50882_0:Attention Value 0.001", - "Token [CLS] with Token 50868_4:Attention Value 0.001", - "Token [CLS] with Token 51301_2:Attention Value 0.001", - "Token [CLS] with Token 51279_4:Attention Value 0.001", - "Token [CLS] with Token 51277_0:Attention Value 0.001", - "Token [CLS] with Token 
51265_2:Attention Value 0.001", - "Token [CLS] with Token 51221_4:Attention Value 0.001", - "Token [CLS] with Token 51222_4:Attention Value 0.001", - "Token [CLS] with Token 51248_3:Attention Value 0.001", - "Token [CLS] with Token 51250_3:Attention Value 0.001", - "Token [CLS] with Token 51265_2:Attention Value 0.001", - "Token [CLS] with Token 51277_0:Attention Value 0.001", - "Token [CLS] with Token 51279_4:Attention Value 0.001", - "Token [CLS] with Token 51301_2:Attention Value 0.001", - "Token [CLS] with Token 51250_3:Attention Value 0.001", - "Token [CLS] with Token 50861_1:Attention Value 0.001", - "Token [CLS] with Token 50863_1:Attention Value 0.001", - "Token [CLS] with Token 50868_4:Attention Value 0.001", - "Token [CLS] with Token 50878_0:Attention Value 0.001", - "Token [CLS] with Token 50882_0:Attention Value 0.001", - "Token [CLS] with Token 50885_1:Attention Value 0.001", - "Token [CLS] with Token 50902_3:Attention Value 0.003", - "Token [CLS] with Token 50912_2:Attention Value 0.002", - "Token [CLS] with Token 50931_1:Attention Value 0.001", - "Token [CLS] with Token 50971_2:Attention Value 0.002", - "Token [CLS] with Token 50983_3:Attention Value 0.002", - "Token [CLS] with Token 51006_1:Attention Value 0.002", - "Token [CLS] with Token 50868_4:Attention Value 0.002", - "Token [CLS] with Token 50882_0:Attention Value 0.003", - "Token [CLS] with Token 50893_1:Attention Value 0.003", - "Token [CLS] with Token 50902_2:Attention Value 0.005", - "Token [CLS] with Token 50912_2:Attention Value 0.004", - "Token [CLS] with Token 50931_2:Attention Value 0.002", - "Token [CLS] with Token 50960_2:Attention Value 0.002", - "Token [CLS] with Token 50970_2:Attention Value 0.002", - "Token [CLS] with Token 50971_1:Attention Value 0.003", - "Token [CLS] with Token 50983_1:Attention Value 0.004", - "Token [CLS] with Token 51006_1:Attention Value 0.004", - "Token [CLS] with Token 51009_2:Attention Value 0.002", - "Token [CLS] with Token 51221_4:Attention Value 0.003", - "Token [CLS] with Token 51222_4:Attention Value 0.002", - "Token [CLS] with Token 51248_3:Attention Value 0.001", - "Token [CLS] with Token 51476_0:Attention Value 0.002", - "Token [CLS] with Token 52135_4:Attention Value 0.003", - "Token [CLS] with Token 52135_4:Attention Value 0.003", - "Token [CLS] with Token 51492_1:Attention Value 0.002", - "Token [CLS] with Token 51476_0:Attention Value 0.002", - "Token [CLS] with Token 51275_0:Attention Value 0.002", - "Token [CLS] with Token 51274_0:Attention Value 0.002", - "Token [CLS] with Token 51237_1:Attention Value 0.001", - "Token [CLS] with Token 51006_0:Attention Value 0.003", - "Token [CLS] with Token 50983_0:Attention Value 0.005", - "Token [CLS] with Token 50971_1:Attention Value 0.003", - "Token [CLS] with Token 50970_0:Attention Value 0.002", - "Token [CLS] with Token 50960_4:Attention Value 0.002", - "Token [CLS] with Token 50931_2:Attention Value 0.001", - "Token [CLS] with Token 50912_3:Attention Value 0.005", - "Token [CLS] with Token 50902_0:Attention Value 0.003", - "Token [CLS] with Token 50893_2:Attention Value 0.002", - "Token [CLS] with Token 50885_1:Attention Value 0.001", - "Token [CLS] with Token 50882_0:Attention Value 0.004", - "Token [CLS] with Token 50878_4:Attention Value 0.002", - "Token [CLS] with Token 50868_4:Attention Value 0.002", - "Token [CLS] with Token 50863_2:Attention Value 0.001", - "Token [CLS] with Token 50861_4:Attention Value 0.002", - "Token [CLS] with Token 52075_3:Attention Value 0.002", - "Token [CLS] with Token 
52074_3:Attention Value 0.002", - "Token [CLS] with Token 52073_2:Attention Value 0.001", - "Token [CLS] with Token 52069_3:Attention Value 0.002", - "Token [CLS] with Token 51301_4:Attention Value 0.002", - "Token [CLS] with Token 51279_3:Attention Value 0.002", - "Token [CLS] with Token 51277_1:Attention Value 0.002", - "Token [CLS] with Token 51265_3:Attention Value 0.001", - "Token [CLS] with Token 51256_3:Attention Value 0.003", - "Token [CLS] with Token 51254_2:Attention Value 0.001", - "Token [CLS] with Token 51250_3:Attention Value 0.001", - "Token [CLS] with Token 51248_3:Attention Value 0.001", - "Token [CLS] with Token 51244_1:Attention Value 0.001", - "Token [CLS] with Token 51222_4:Attention Value 0.002", - "Token [CLS] with Token 51221_4:Attention Value 0.001", - "Token [CLS] with Token 51200_2:Attention Value 0.002", - "Token [CLS] with Token 51146_2:Attention Value 0.002", - "Token [CLS] with Token 51133_3:Attention Value 0.001", - "Token [CLS] with Token 51275_0:Attention Value 0.002", - "Token [CLS] with Token 51274_1:Attention Value 0.003", - "Token [CLS] with Token 51237_2:Attention Value 0.002", - "Token [CLS] with Token 51006_0:Attention Value 0.003", - "Token [CLS] with Token 50983_0:Attention Value 0.004", - "Token [CLS] with Token 50976_2:Attention Value 0.002", - "Token [CLS] with Token 50971_0:Attention Value 0.002", - "Token [CLS] with Token 50970_2:Attention Value 0.001", - "Token [CLS] with Token 50960_2:Attention Value 0.001", - "Token [CLS] with Token 50951_1:Attention Value 0.001", - "Token [CLS] with Token 50950_2:Attention Value 0.002", - "Token [CLS] with Token 50949_1:Attention Value 0.004", - "Token [CLS] with Token 50931_3:Attention Value 0.001", - "Token [CLS] with Token 50930_3:Attention Value 0.002", - "Token [CLS] with Token 50912_3:Attention Value 0.003", - "Token [CLS] with Token 50902_0:Attention Value 0.002", - "Token [CLS] with Token 50893_0:Attention Value 0.004", - "Token [CLS] with Token 50885_1:Attention Value 0.001", - "Token [CLS] with Token 50882_0:Attention Value 0.002", - "Token [CLS] with Token 50878_4:Attention Value 0.002", - "Token [CLS] with Token 50868_4:Attention Value 0.002", - "Token [CLS] with Token 50863_2:Attention Value 0.001", - "Token [CLS] with Token 50862_3:Attention Value 0.001", - "Token [CLS] with Token 50861_3:Attention Value 0.001", - "Token [CLS] with Token 50853_1:Attention Value 0.006", - "Token [CLS] with Token 52075_2:Attention Value 0.004", - "Token [CLS] with Token 52074_3:Attention Value 0.002", - "Token [CLS] with Token 52073_3:Attention Value 0.001", - "Token [CLS] with Token 52069_0:Attention Value 0.012", - "Token [CLS] with Token 51301_1:Attention Value 0.006", - "Token [CLS] with Token 51279_3:Attention Value 0.018", - "Token [CLS] with Token 51277_1:Attention Value 0.010", - "Token [CLS] with Token 51265_2:Attention Value 0.001", - "Token [CLS] with Token 51257_0:Attention Value 0.013", - "Token [CLS] with Token 51256_1:Attention Value 0.003", - "Token [CLS] with Token 51255_1:Attention Value 0.008", - "Token [CLS] with Token 51254_4:Attention Value 0.003", - "Token [CLS] with Token 51251_1:Attention Value 0.026", - "Token [CLS] with Token 51250_2:Attention Value 0.003", - "Token [CLS] with Token 51248_3:Attention Value 0.001", - "Token [CLS] with Token 51244_2:Attention Value 0.002", - "Token [CLS] with Token 51222_4:Attention Value 0.003", - "Token [CLS] with Token 51221_3:Attention Value 0.002", - "Token [CLS] with Token 51200_3:Attention Value 0.003", - "Token [CLS] with Token 
51146_0:Attention Value 0.003", - "Token [CLS] with Token 51144_2:Attention Value 0.002", - "Token [CLS] with Token 51143_0:Attention Value 0.002", - "Token [CLS] with Token 51133_2:Attention Value 0.009", - "Token [CLS] with Token 51006_0:Attention Value 0.007", - "Token [CLS] with Token 50983_0:Attention Value 0.014", - "Token [CLS] with Token 50971_0:Attention Value 0.006", - "Token [CLS] with Token 50970_1:Attention Value 0.002", - "Token [CLS] with Token 50960_2:Attention Value 0.004", - "Token [CLS] with Token 50950_2:Attention Value 0.005", - "Token [CLS] with Token 50931_2:Attention Value 0.003", - "Token [CLS] with Token 50912_3:Attention Value 0.008", - "Token [CLS] with Token 50902_0:Attention Value 0.011", - "Token [CLS] with Token 50893_0:Attention Value 0.154", - "Token [CLS] with Token 50885_1:Attention Value 0.002", - "Token [CLS] with Token 50882_0:Attention Value 0.004", - "Token [CLS] with Token 50878_4:Attention Value 0.009", - "Token [CLS] with Token 50868_4:Attention Value 0.007", - "Token [CLS] with Token 50863_2:Attention Value 0.005", - "Token [CLS] with Token 50861_3:Attention Value 0.002", - "Token [CLS] with Token 52075_2:Attention Value 0.004", - "Token [CLS] with Token [VE]:Attention Value 0.006", - "Token [CLS] with Token [REG]:Attention Value 0.003" - ] - ], - "type": "heatmap", - "x": [ - "[CLS]", - "[VS]", - "2", - "3", - "4", - "5", - "6", - "7", - "8", - "9", - "10", - "11", - "12", - "13", - "14", - "15", - "16", - "17", - "18", - "19", - "20", - "21", - "22", - "23", - "24", - "25", - "26", - "27", - "28", - "29", - "30", - "31", - "32", - "33", - "34", - "35", - "36", - "37", - "38", - "39", - "40", - "41", - "42", - "43", - "44", - "45", - "46", - "47", - "48", - "49", - "50", - "51", - "52", - "53", - "54", - "55", - "56", - "57", - "58", - "59", - "60", - "61", - "62", - "63", - "64", - "65", - "66", - "67", - "68", - "69", - "70", - "71", - "72", - "73", - "74", - "75", - "76", - "77", - "78", - "79", - "80", - "81", - "82", - "83", - "84", - "85", - "86", - "87", - "88", - "89", - "90", - "91", - "92", - "93", - "94", - "95", - "96", - "97", - "98", - "99", - "100", - "101", - "102", - "103", - "104", - "105", - "106", - "107", - "108", - "109", - "110", - "111", - "112", - "113", - "114", - "115", - "116", - "117", - "118", - "119", - "120", - "121", - "122", - "123", - "124", - "125", - "126", - "127", - "128", - "129", - "130", - "131", - "132", - "133", - "134", - "135", - "136", - "137", - "138", - "139", - "140", - "141", - "142", - "143", - "144", - "145", - "146", - "147", - "148", - "149", - "150", - "151", - "152", - "153", - "154", - "155", - "156", - "157", - "158", - "159", - "160", - "161", - "162", - "163", - "164", - "165", - "166", - "167", - "168", - "169", - "170", - "171", - "172", - "173", - "174", - "175", - "176", - "177", - "178", - "179", - "180", - "181", - "182", - "183", - "184", - "185", - "186", - "187", - "188", - "189", - "190", - "191", - "192", - "193", - "194", - "195", - "196", - "197", - "198", - "199", - "200", - "201", - "202", - "203", - "204", - "205", - "206", - "207", - "208", - "209", - "210", - "211", - "212", - "213", - "214", - "215", - "216", - "217", - "218", - "219", - "220", - "221", - "222", - "223", - "224", - "225", - "226", - "227", - "228", - "229", - "230", - "231", - "232", - "233", - "234", - "235", - "236", - "237", - "238", - "239", - "240", - "241", - "242", - "243", - "244", - "245", - "246", - "247", - "248", - "249", - "250", - "251", - "252", - "253", - "254", - "255", - "256", - 
"257", - "258", - "259", - "260", - "261", - "262", - "263", - "264", - "265", - "266", - "267", - "268", - "269", - "270", - "271", - "272", - "273", - "274", - "275", - "276", - "277", - "278", - "279", - "280", - "281", - "282", - "283", - "284", - "285", - "286", - "287", - "288", - "289", - "290", - "291", - "292", - "293", - "294", - "295", - "296", - "297", - "298", - "299", - "300", - "301", - "302", - "303", - "304", - "305", - "306", - "307", - "308", - "309", - "310", - "311", - "312", - "313", - "314", - "315", - "316", - "317", - "318", - "319", - "320", - "321", - "322", - "323", - "324", - "325", - "326", - "327", - "328", - "329", - "330", - "331", - "332", - "333", - "334", - "335", - "336", - "337", - "338", - "339", - "340", - "341", - "342", - "343", - "344", - "345", - "346", - "347", - "348", - "349", - "350", - "351", - "352", - "353", - "354", - "355", - "356", - "357", - "358", - "359", - "360", - "361", - "362", - "363", - "364", - "365", - "366", - "367", - "368", - "369", - "370", - "371", - "372", - "373", - "374", - "375", - "[VE]", - "[REG]" - ], - "y": [ - "[CLS]" - ], - "z": [ - [ - 0.008451767265796661, - 0.00047753742546774447, - 0.001319802482612431, - 0.002615319797769189, - 0.0005728917894884944, - 0.0004854540165979415, - 0.0007460298365913332, - 0.004369142930954695, - 0.0021512836683541536, - 0.0010651483898982406, - 0.006857313681393862, - 0.0009411114733666182, - 0.001146065187640488, - 0.003502822481095791, - 0.008377399295568466, - 0.00369816436432302, - 0.0036342681851238017, - 0.001818674267269671, - 0.0011760754277929664, - 0.005098414141684771, - 0.003414525417611003, - 0.0005102005670778453, - 0.009316006675362589, - 0.0036000509280711412, - 0.0010369595838710666, - 0.0019885266665369272, - 0.0004615136131178588, - 0.0008582020527683198, - 0.004428019281476736, - 0.0036287240218371153, - 0.003203341970220208, - 0.003022358985617757, - 0.004367131274193525, - 0.005930684972554445, - 0.0006912555545568466, - 0.0003829135384876281, - 0.002646421082317829, - 0.001659494242630899, - 0.0008906491566449404, - 0.0034986960235983133, - 0.004725936334580183, - 0.003188258036971092, - 0.004636996425688267, - 0.0017536324448883531, - 0.003574026981368661, - 0.0004167505830992013, - 0.001605302095413208, - 0.0034259206149727106, - 0.0010352873941883445, - 0.008988478220999241, - 0.004783918149769306, - 0.0020494854543358088, - 0.0008181784651242197, - 0.005045545753091574, - 0.0008356395992450416, - 0.0003578341274987906, - 0.004442144185304642, - 0.0007386466022580862, - 0.0015157737070694566, - 0.0003436319238971919, - 0.00029586421442218125, - 0.0004052982258144766, - 0.0004121463280171156, - 0.0006444680620916188, - 0.0006272507016547024, - 0.0003953258565161377, - 0.0005427289288491011, - 0.000674175564199686, - 0.0005934844375588, - 0.0005067584570497274, - 0.00048668019007891417, - 0.0006641843356192112, - 0.001736009493470192, - 0.0006450659711845219, - 0.0006109003443270922, - 0.0014823578530922532, - 0.0007514643366448581, - 0.0010008058743551371, - 0.0005770500865764916, - 0.0009993098210543394, - 0.0013827277580276132, - 0.0017192121595144272, - 0.0011489095631986856, - 0.0004682957660406828, - 0.0008927697781473398, - 0.000683685124386102, - 0.0006812589126639068, - 0.0005206714267842472, - 0.0004224051663186401, - 0.0008084981818683445, - 0.001108424854464829, - 0.0009732283069752156, - 0.0007592075853608549, - 0.0005552212824113667, - 0.00127204111777246, - 0.002062402432784438, - 0.00112050655297935, - 0.0008525322191417217, - 
0.0008413223549723625, - 0.000753440021071583, - 0.0019989104475826025, - 0.0010396834695711732, - 0.00325021636672318, - 0.0006715958588756621, - 0.0012460827128961682, - 0.000999152078293264, - 0.0009655823814682662, - 0.0005084593431092799, - 0.0005497524398379028, - 0.0006103277555666864, - 0.001034797285683453, - 0.0004598199157044291, - 0.0008794588502496481, - 0.000696350762154907, - 0.001486221794039011, - 0.0017987970495596528, - 0.003589141182601452, - 0.001990983495488763, - 0.0010251256171613932, - 0.001364864525385201, - 0.0008277010056190193, - 0.0025306493043899536, - 0.0021720887161791325, - 0.003414546838030219, - 0.0019456190057098863, - 0.0013443160569295287, - 0.001797817298211157, - 0.0022308507468551397, - 0.0018264207756146789, - 0.0006515602581202984, - 0.0015126297948881984, - 0.0009584605577401816, - 0.0015974263660609722, - 0.0012428780319169164, - 0.0018666594987735152, - 0.0018526790663599968, - 0.001926859957166016, - 0.001790812355466187, - 0.0006914885598234832, - 0.0009147145319730044, - 0.0006848545162938535, - 0.003264268860220909, - 0.001066222321242094, - 0.0014806516701355577, - 0.0019578184001147747, - 0.001972360536456108, - 0.001908112782984972, - 0.0017941169207915664, - 0.002000534441322088, - 0.00232343259267509, - 0.0006864610477350652, - 0.0015883835731074214, - 0.001452704076655209, - 0.0015260959044098854, - 0.001217573182657361, - 0.0007630472537130117, - 0.0009102648473344744, - 0.0010583768598735332, - 0.0010495075257495046, - 0.0013496936298906803, - 0.002790031023323536, - 0.0016460781916975975, - 0.001800714642740786, - 0.001588907209224999, - 0.0006608786643482745, - 0.0015428924234583974, - 0.00130733463447541, - 0.0019498993642628193, - 0.0011935298098251224, - 0.002081812592223286, - 0.0027750423178076744, - 0.002682143822312355, - 0.0017559939296916127, - 0.0014839059440419078, - 0.0011782203800976276, - 0.005421904847025871, - 0.0033844260033220053, - 0.0019421830074861648, - 0.002427612664178014, - 0.0021077985875308514, - 0.002597493352368474, - 0.0019147455459460616, - 0.0022990747820585966, - 0.00196700356900692, - 0.0013677203096449375, - 0.0017343414947390556, - 0.0022338901180773973, - 0.002618071623146534, - 0.003806268097832799, - 0.0009631214197725058, - 0.0029177824035286903, - 0.0015531842363998294, - 0.00118330551777035, - 0.0009476374834775924, - 0.0007008272805251181, - 0.0012339174281805754, - 0.0011241122847422955, - 0.0009876210242509842, - 0.0008716397569514811, - 0.001008612453006208, - 0.0015318739460781217, - 0.0016837789444252849, - 0.0015550237149000168, - 0.00042292437865398824, - 0.0012427183100953698, - 0.0006267700809985399, - 0.0025690405163913965, - 0.002414754591882229, - 0.0008665476925671101, - 0.0007410055841319263, - 0.0011087892344221473, - 0.0014391514705494046, - 0.001120170229114592, - 0.0013088533887639642, - 0.0004123076796531677, - 0.0007661065901629627, - 0.0005386698176153004, - 0.0007872642017900944, - 0.0005083107971586287, - 0.0009654506575316192, - 0.0009281352977268398, - 0.0006618788465857506, - 0.0004174182831775397, - 0.0006613184814341366, - 0.0003955806605517864, - 0.0011703699128702285, - 0.0007652677595615387, - 0.0009616908500902356, - 0.000610732939094305, - 0.0005696555017493665, - 0.000935644842684269, - 0.0005050115869380534, - 0.0008104036678560078, - 0.0005484815337695181, - 0.0005282653728500009, - 0.0006690899026580155, - 0.0006202083895914257, - 0.0011858937796205282, - 0.0006323098205029964, - 0.0012637152103707194, - 0.0010147326393052936, - 0.001434296485967934, 
- 0.0010879833716899157, - 0.0007398917223326862, - 0.0006289978045970201, - 0.0010291151702404022, - 0.0007981405942700803, - 0.0012479289434850216, - 0.000667006301227957, - 0.0027725808322429657, - 0.002342016203328967, - 0.0006513544940389693, - 0.0018907733028754592, - 0.002159562660381198, - 0.0017471732571721077, - 0.0015442147850990295, - 0.003038904396817088, - 0.003148901043459773, - 0.005466174334287643, - 0.004158675670623779, - 0.001521154772490263, - 0.002499047899618745, - 0.0023495368659496307, - 0.0034855001140385866, - 0.004092938732355833, - 0.003764112712815404, - 0.0024812589399516582, - 0.0029548786114901304, - 0.00236957217566669, - 0.0012237809132784605, - 0.001522903912700713, - 0.0026057965587824583, - 0.0028801767621189356, - 0.0020971631165593863, - 0.0017077381489798429, - 0.002488095546141267, - 0.001942558796145022, - 0.0012060959124937654, - 0.0032189686316996813, - 0.004897533915936947, - 0.002546509262174368, - 0.002436422742903233, - 0.0016161665553227067, - 0.0008353215525858104, - 0.004716587718576193, - 0.003007955150678754, - 0.001906886580400169, - 0.0010407179361209271, - 0.0036604618653655057, - 0.0019616391509771347, - 0.001880012801848352, - 0.0010364892659708858, - 0.0020254242699593306, - 0.001567527768202126, - 0.0018129348754882812, - 0.0011344858212396502, - 0.001787062268704176, - 0.002030221512541175, - 0.002426884835585952, - 0.0018259655917063355, - 0.0006956890574656427, - 0.0026017874479293823, - 0.0008315623854286969, - 0.0014645435148850083, - 0.0009511765674687922, - 0.0009218405466526748, - 0.0015906983753666282, - 0.0014132362557575109, - 0.0017343055224046111, - 0.0019996073096990585, - 0.0013999902876093984, - 0.0021138533484190702, - 0.002869958057999611, - 0.0016324784373864532, - 0.003336323192343116, - 0.004306184593588114, - 0.0022455891594290733, - 0.0018202642677351832, - 0.0014412502059713006, - 0.0014982676366344094, - 0.0010419118916615844, - 0.0016191820614039898, - 0.003920895978808403, - 0.0013921178178861735, - 0.0016874541761353612, - 0.003467828966677189, - 0.0024901803117245436, - 0.003774721873924136, - 0.0008989623747766018, - 0.0018224921077489853, - 0.001697454135864973, - 0.002475267043337226, - 0.001093029393814504, - 0.0014057605294510722, - 0.0009937405120581388, - 0.006045807618647814, - 0.003723071655258536, - 0.0024961901362985373, - 0.0010668792529031634, - 0.012105301953852177, - 0.006213821936398745, - 0.018133636564016346, - 0.010110520757734776, - 0.001374133862555027, - 0.013326204381883144, - 0.002680495148524642, - 0.008382327854633331, - 0.0030930601060390472, - 0.025627313181757927, - 0.0031410318333655596, - 0.001220939448103309, - 0.0015257663326337934, - 0.0025776990223675966, - 0.0020443866960704327, - 0.003117623971775174, - 0.0028798922430723906, - 0.001909662154503167, - 0.0018645658856257796, - 0.00894069205969572, - 0.006883753929287195, - 0.014036208391189575, - 0.0055046118795871735, - 0.0024932369124144316, - 0.0035899977665394545, - 0.004715373273938894, - 0.002608843147754669, - 0.007800393272191286, - 0.011004121042788029, - 0.15411949157714844, - 0.0017875274643301964, - 0.0042082518339157104, - 0.009378299117088318, - 0.007050686050206423, - 0.005348358768969774, - 0.0016888940008357167, - 0.004475145135074854, - 0.006452843081206083, - 0.0025523456279188395 - ] - ] - } - ], - "layout": { - "annotations": [ - { - "bgcolor": "red", - "font": { - "color": "black", - "size": 10 - }, - "opacity": 0.8, - "showarrow": false, - "text": "[CLS]", - "textangle": -90, - "x": 0, - 
"xref": "x", - "y": 0.5, - "yref": "paper" - }, - { - "bgcolor": "red", - "font": { - "color": "black", - "size": 10 - }, - "opacity": 0.8, - "showarrow": false, - "text": "[VS]", - "textangle": -90, - "x": 1, - "xref": "x", - "y": 0.5, - "yref": "paper" - }, - { - "bgcolor": "red", - "font": { - "color": "black", - "size": 10 - }, - "opacity": 0.8, - "showarrow": false, - "text": "[VE]", - "textangle": -90, - "x": 376, - "xref": "x", - "y": 0.5, - "yref": "paper" - }, - { - "bgcolor": "red", - "font": { - "color": "black", - "size": 10 - }, - "opacity": 0.8, - "showarrow": false, - "text": "[REG]", - "textangle": -90, - "x": 377, - "xref": "x", - "y": 0.5, - "yref": "paper" - } - ], - "template": { - "data": { - "bar": [ - { - "error_x": { - "color": "#2a3f5f" - }, - "error_y": { - "color": "#2a3f5f" - }, - "marker": { - "line": { - "color": "#E5ECF6", - "width": 0.5 - }, - "pattern": { - "fillmode": "overlay", - "size": 10, - "solidity": 0.2 - } - }, - "type": "bar" - } - ], - "barpolar": [ - { - "marker": { - "line": { - "color": "#E5ECF6", - "width": 0.5 - }, - "pattern": { - "fillmode": "overlay", - "size": 10, - "solidity": 0.2 - } - }, - "type": "barpolar" - } - ], - "carpet": [ - { - "aaxis": { - "endlinecolor": "#2a3f5f", - "gridcolor": "white", - "linecolor": "white", - "minorgridcolor": "white", - "startlinecolor": "#2a3f5f" - }, - "baxis": { - "endlinecolor": "#2a3f5f", - "gridcolor": "white", - "linecolor": "white", - "minorgridcolor": "white", - "startlinecolor": "#2a3f5f" - }, - "type": "carpet" - } - ], - "choropleth": [ - { - "colorbar": { - "outlinewidth": 0, - "ticks": "" - }, - "type": "choropleth" - } - ], - "contour": [ - { - "colorbar": { - "outlinewidth": 0, - "ticks": "" - }, - "colorscale": [ - [ - 0, - "#0d0887" - ], - [ - 0.1111111111111111, - "#46039f" - ], - [ - 0.2222222222222222, - "#7201a8" - ], - [ - 0.3333333333333333, - "#9c179e" - ], - [ - 0.4444444444444444, - "#bd3786" - ], - [ - 0.5555555555555556, - "#d8576b" - ], - [ - 0.6666666666666666, - "#ed7953" - ], - [ - 0.7777777777777778, - "#fb9f3a" - ], - [ - 0.8888888888888888, - "#fdca26" - ], - [ - 1, - "#f0f921" - ] - ], - "type": "contour" - } - ], - "contourcarpet": [ - { - "colorbar": { - "outlinewidth": 0, - "ticks": "" - }, - "type": "contourcarpet" - } - ], - "heatmap": [ - { - "colorbar": { - "outlinewidth": 0, - "ticks": "" - }, - "colorscale": [ - [ - 0, - "#0d0887" - ], - [ - 0.1111111111111111, - "#46039f" - ], - [ - 0.2222222222222222, - "#7201a8" - ], - [ - 0.3333333333333333, - "#9c179e" - ], - [ - 0.4444444444444444, - "#bd3786" - ], - [ - 0.5555555555555556, - "#d8576b" - ], - [ - 0.6666666666666666, - "#ed7953" - ], - [ - 0.7777777777777778, - "#fb9f3a" - ], - [ - 0.8888888888888888, - "#fdca26" - ], - [ - 1, - "#f0f921" - ] - ], - "type": "heatmap" - } - ], - "heatmapgl": [ - { - "colorbar": { - "outlinewidth": 0, - "ticks": "" - }, - "colorscale": [ - [ - 0, - "#0d0887" - ], - [ - 0.1111111111111111, - "#46039f" - ], - [ - 0.2222222222222222, - "#7201a8" - ], - [ - 0.3333333333333333, - "#9c179e" - ], - [ - 0.4444444444444444, - "#bd3786" - ], - [ - 0.5555555555555556, - "#d8576b" - ], - [ - 0.6666666666666666, - "#ed7953" - ], - [ - 0.7777777777777778, - "#fb9f3a" - ], - [ - 0.8888888888888888, - "#fdca26" - ], - [ - 1, - "#f0f921" - ] - ], - "type": "heatmapgl" - } - ], - "histogram": [ - { - "marker": { - "pattern": { - "fillmode": "overlay", - "size": 10, - "solidity": 0.2 - } - }, - "type": "histogram" - } - ], - "histogram2d": [ - { - "colorbar": { - "outlinewidth": 0, - 
"ticks": "" - }, - "colorscale": [ - [ - 0, - "#0d0887" - ], - [ - 0.1111111111111111, - "#46039f" - ], - [ - 0.2222222222222222, - "#7201a8" - ], - [ - 0.3333333333333333, - "#9c179e" - ], - [ - 0.4444444444444444, - "#bd3786" - ], - [ - 0.5555555555555556, - "#d8576b" - ], - [ - 0.6666666666666666, - "#ed7953" - ], - [ - 0.7777777777777778, - "#fb9f3a" - ], - [ - 0.8888888888888888, - "#fdca26" - ], - [ - 1, - "#f0f921" - ] - ], - "type": "histogram2d" - } - ], - "histogram2dcontour": [ - { - "colorbar": { - "outlinewidth": 0, - "ticks": "" - }, - "colorscale": [ - [ - 0, - "#0d0887" - ], - [ - 0.1111111111111111, - "#46039f" - ], - [ - 0.2222222222222222, - "#7201a8" - ], - [ - 0.3333333333333333, - "#9c179e" - ], - [ - 0.4444444444444444, - "#bd3786" - ], - [ - 0.5555555555555556, - "#d8576b" - ], - [ - 0.6666666666666666, - "#ed7953" - ], - [ - 0.7777777777777778, - "#fb9f3a" - ], - [ - 0.8888888888888888, - "#fdca26" - ], - [ - 1, - "#f0f921" - ] - ], - "type": "histogram2dcontour" - } - ], - "mesh3d": [ - { - "colorbar": { - "outlinewidth": 0, - "ticks": "" - }, - "type": "mesh3d" - } - ], - "parcoords": [ - { - "line": { - "colorbar": { - "outlinewidth": 0, - "ticks": "" - } - }, - "type": "parcoords" - } - ], - "pie": [ - { - "automargin": true, - "type": "pie" - } - ], - "scatter": [ - { - "fillpattern": { - "fillmode": "overlay", - "size": 10, - "solidity": 0.2 - }, - "type": "scatter" - } - ], - "scatter3d": [ - { - "line": { - "colorbar": { - "outlinewidth": 0, - "ticks": "" - } - }, - "marker": { - "colorbar": { - "outlinewidth": 0, - "ticks": "" - } - }, - "type": "scatter3d" - } - ], - "scattercarpet": [ - { - "marker": { - "colorbar": { - "outlinewidth": 0, - "ticks": "" - } - }, - "type": "scattercarpet" - } - ], - "scattergeo": [ - { - "marker": { - "colorbar": { - "outlinewidth": 0, - "ticks": "" - } - }, - "type": "scattergeo" - } - ], - "scattergl": [ - { - "marker": { - "colorbar": { - "outlinewidth": 0, - "ticks": "" - } - }, - "type": "scattergl" - } - ], - "scattermapbox": [ - { - "marker": { - "colorbar": { - "outlinewidth": 0, - "ticks": "" - } - }, - "type": "scattermapbox" - } - ], - "scatterpolar": [ - { - "marker": { - "colorbar": { - "outlinewidth": 0, - "ticks": "" - } - }, - "type": "scatterpolar" - } - ], - "scatterpolargl": [ - { - "marker": { - "colorbar": { - "outlinewidth": 0, - "ticks": "" - } - }, - "type": "scatterpolargl" - } - ], - "scatterternary": [ - { - "marker": { - "colorbar": { - "outlinewidth": 0, - "ticks": "" - } - }, - "type": "scatterternary" - } - ], - "surface": [ - { - "colorbar": { - "outlinewidth": 0, - "ticks": "" - }, - "colorscale": [ - [ - 0, - "#0d0887" - ], - [ - 0.1111111111111111, - "#46039f" - ], - [ - 0.2222222222222222, - "#7201a8" - ], - [ - 0.3333333333333333, - "#9c179e" - ], - [ - 0.4444444444444444, - "#bd3786" - ], - [ - 0.5555555555555556, - "#d8576b" - ], - [ - 0.6666666666666666, - "#ed7953" - ], - [ - 0.7777777777777778, - "#fb9f3a" - ], - [ - 0.8888888888888888, - "#fdca26" - ], - [ - 1, - "#f0f921" - ] - ], - "type": "surface" - } - ], - "table": [ - { - "cells": { - "fill": { - "color": "#EBF0F8" - }, - "line": { - "color": "white" - } - }, - "header": { - "fill": { - "color": "#C8D4E3" - }, - "line": { - "color": "white" - } - }, - "type": "table" - } - ] - }, - "layout": { - "annotationdefaults": { - "arrowcolor": "#2a3f5f", - "arrowhead": 0, - "arrowwidth": 1 - }, - "autotypenumbers": "strict", - "coloraxis": { - "colorbar": { - "outlinewidth": 0, - "ticks": "" - } - }, - "colorscale": { - "diverging": 
[ - [ - 0, - "#8e0152" - ], - [ - 0.1, - "#c51b7d" - ], - [ - 0.2, - "#de77ae" - ], - [ - 0.3, - "#f1b6da" - ], - [ - 0.4, - "#fde0ef" - ], - [ - 0.5, - "#f7f7f7" - ], - [ - 0.6, - "#e6f5d0" - ], - [ - 0.7, - "#b8e186" - ], - [ - 0.8, - "#7fbc41" - ], - [ - 0.9, - "#4d9221" - ], - [ - 1, - "#276419" - ] - ], - "sequential": [ - [ - 0, - "#0d0887" - ], - [ - 0.1111111111111111, - "#46039f" - ], - [ - 0.2222222222222222, - "#7201a8" - ], - [ - 0.3333333333333333, - "#9c179e" - ], - [ - 0.4444444444444444, - "#bd3786" - ], - [ - 0.5555555555555556, - "#d8576b" - ], - [ - 0.6666666666666666, - "#ed7953" - ], - [ - 0.7777777777777778, - "#fb9f3a" - ], - [ - 0.8888888888888888, - "#fdca26" - ], - [ - 1, - "#f0f921" - ] - ], - "sequentialminus": [ - [ - 0, - "#0d0887" - ], - [ - 0.1111111111111111, - "#46039f" - ], - [ - 0.2222222222222222, - "#7201a8" - ], - [ - 0.3333333333333333, - "#9c179e" - ], - [ - 0.4444444444444444, - "#bd3786" - ], - [ - 0.5555555555555556, - "#d8576b" - ], - [ - 0.6666666666666666, - "#ed7953" - ], - [ - 0.7777777777777778, - "#fb9f3a" - ], - [ - 0.8888888888888888, - "#fdca26" - ], - [ - 1, - "#f0f921" - ] - ] - }, - "colorway": [ - "#636efa", - "#EF553B", - "#00cc96", - "#ab63fa", - "#FFA15A", - "#19d3f3", - "#FF6692", - "#B6E880", - "#FF97FF", - "#FECB52" - ], - "font": { - "color": "#2a3f5f" - }, - "geo": { - "bgcolor": "white", - "lakecolor": "white", - "landcolor": "#E5ECF6", - "showlakes": true, - "showland": true, - "subunitcolor": "white" - }, - "hoverlabel": { - "align": "left" - }, - "hovermode": "closest", - "mapbox": { - "style": "light" - }, - "paper_bgcolor": "white", - "plot_bgcolor": "#E5ECF6", - "polar": { - "angularaxis": { - "gridcolor": "white", - "linecolor": "white", - "ticks": "" - }, - "bgcolor": "#E5ECF6", - "radialaxis": { - "gridcolor": "white", - "linecolor": "white", - "ticks": "" - } - }, - "scene": { - "xaxis": { - "backgroundcolor": "#E5ECF6", - "gridcolor": "white", - "gridwidth": 2, - "linecolor": "white", - "showbackground": true, - "ticks": "", - "zerolinecolor": "white" - }, - "yaxis": { - "backgroundcolor": "#E5ECF6", - "gridcolor": "white", - "gridwidth": 2, - "linecolor": "white", - "showbackground": true, - "ticks": "", - "zerolinecolor": "white" - }, - "zaxis": { - "backgroundcolor": "#E5ECF6", - "gridcolor": "white", - "gridwidth": 2, - "linecolor": "white", - "showbackground": true, - "ticks": "", - "zerolinecolor": "white" - } - }, - "shapedefaults": { - "line": { - "color": "#2a3f5f" - } - }, - "ternary": { - "aaxis": { - "gridcolor": "white", - "linecolor": "white", - "ticks": "" - }, - "baxis": { - "gridcolor": "white", - "linecolor": "white", - "ticks": "" - }, - "bgcolor": "#E5ECF6", - "caxis": { - "gridcolor": "white", - "linecolor": "white", - "ticks": "" - } - }, - "title": { - "x": 0.05 - }, - "xaxis": { - "automargin": true, - "gridcolor": "white", - "linecolor": "white", - "ticks": "", - "title": { - "standoff": 15 - }, - "zerolinecolor": "white", - "zerolinewidth": 2 - }, - "yaxis": { - "automargin": true, - "gridcolor": "white", - "linecolor": "white", - "ticks": "", - "title": { - "standoff": 15 - }, - "zerolinecolor": "white", - "zerolinewidth": 2 - } - } - }, - "title": { - "text": "Attention Visualization" - }, - "xaxis": { - "nticks": 378, - "tickangle": -90, - "title": { - "text": "Token in Input Sequence" - } - }, - "yaxis": { - "nticks": 1, - "title": { - "text": "Token in Input Sequence" - } - } - } - }, - "text/html": [ - "
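The `visualize_attention` helper defined in the source below consumes a `last_attention_matrix` tensor that is produced by an earlier cell of this notebook, which is not part of this hunk. For orientation only, here is a minimal sketch of how such a tensor can be obtained from a Hugging Face-style classifier; the variable names (`model`, `patient`, `args`) and the assumption that the encoder runs with full (non-block-sparse) attention, so that per-layer attention probabilities are actually returned, are illustrative assumptions rather than the notebook's actual code.

import torch

model.eval()
with torch.no_grad():
    outputs = model(
        input_ids=patient["concept_ids"].to(args.device),
        attention_mask=patient["attention_mask"].to(args.device),
        output_attentions=True,  # request per-layer attention probabilities
    )

# outputs.attentions is a tuple with one tensor per layer,
# each of shape (batch_size, num_heads, seq_len, seq_len).
last_attention_matrix = outputs.attentions[-1].cpu()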
" - ] - }, - "metadata": {}, - "output_type": "display_data" - } - ], - "source": [ - "def visualize_attention(\n", - " attention_weights,\n", - " patient,\n", - " special_tokens,\n", - " tokenizer,\n", - " truncate=False,\n", - " only_cls=False,\n", - " top_k=10,\n", - "):\n", - " # Convert attention tensor to numpy array and squeeze the batch dimension\n", - " concept_ids = patient[\"concept_ids\"].squeeze(0).cpu().numpy()\n", - " attention_weights = attention_weights.squeeze(0).cpu().numpy()\n", - "\n", - " # Truncate attention weights if specified\n", - " if truncate:\n", - " truncate_at = patient[\"attention_mask\"].sum().numpy()\n", - " attention_weights = attention_weights[:, :truncate_at, :truncate_at]\n", - " concept_ids = concept_ids[:truncate_at]\n", - "\n", - " if only_cls:\n", - " attention_weights = attention_weights[:, :1, :]\n", - "\n", - " # Average attention weights across heads\n", - " attention_weights = attention_weights.mean(axis=0)\n", - "\n", - " # Generate token labels, marking special tokens with a special symbol\n", - " x_token_labels = [\n", - " f\"{tokenizer.id_to_token(token)}\"\n", - " if tokenizer.id_to_token(token) in special_tokens\n", - " else str(i)\n", - " for i, token in enumerate(concept_ids)\n", - " ]\n", - " y_token_labels = [\"[CLS]\"]\n", - "\n", - " # Generate hover text\n", - " hover_text = [\n", - " [\n", - " f\"Token {tokenizer.id_to_token(concept_ids[row])} with Token {tokenizer.id_to_token(concept_ids[col])}:\"\n", - " f\"Attention Value {attention_weights[row, col]:.3f}\"\n", - " for col in range(attention_weights.shape[1])\n", - " ]\n", - " for row in range(attention_weights.shape[0])\n", - " ]\n", - "\n", - " # Generate annotations for special tokens\n", - " annotations = []\n", - " for i, token in enumerate(concept_ids):\n", - " if tokenizer.id_to_token(token) in special_tokens:\n", - " annotations.append(\n", - " dict(\n", - " x=i,\n", - " y=0.5,\n", - " xref=\"x\",\n", - " yref=\"paper\", # Use 'paper' coordinates for y\n", - " text=tokenizer.id_to_token(token),\n", - " showarrow=False,\n", - " font=dict(color=\"black\", size=10),\n", - " textangle=-90,\n", - " bgcolor=\"red\",\n", - " opacity=0.8,\n", - " ),\n", - " )\n", - "\n", - " # Plot the attention matrix as a heatmap\n", - " fig = go.Figure(\n", - " data=go.Heatmap(\n", - " z=attention_weights,\n", - " x=x_token_labels,\n", - " y=y_token_labels,\n", - " hoverongaps=False,\n", - " hoverinfo=\"text\",\n", - " text=hover_text,\n", - " colorscale=\"YlGnBu\",\n", - " ),\n", - " )\n", - "\n", - " fig.update_layout(\n", - " title=\"Attention Visualization\",\n", - " xaxis_nticks=len(concept_ids),\n", - " yaxis_nticks=len(y_token_labels),\n", - " xaxis_title=\"Token in Input Sequence\",\n", - " yaxis_title=\"Token in Input Sequence\",\n", - " annotations=annotations,\n", - " xaxis_tickangle=-90,\n", - " )\n", - "\n", - " # Print top k tokens with their attention values\n", - " top_k_indices = np.argsort(-attention_weights, axis=None)[:top_k]\n", - " top_k_values = attention_weights.flatten()[top_k_indices]\n", - " top_k_indices = np.unravel_index(top_k_indices, attention_weights.shape)\n", - "\n", - " for idx in range(len(top_k_indices[0])):\n", - " token1 = top_k_indices[0][idx]\n", - " token2 = top_k_indices[1][idx]\n", - " attention_value = top_k_values[idx]\n", - " print(\n", - " f\"Token {tokenizer.id_to_token(concept_ids[token1])} \"\n", - " f\"with Token {tokenizer.id_to_token(concept_ids[token2])}: \"\n", - " f\"Attention Value {attention_value:.3f}\",\n", - " )\n", - "\n", - " 
fig.show()\n", - "\n", - "\n", - "# Visualize the attention matrix with special tokens\n", - "special_tokens = [\"[CLS]\", \"[VS]\", \"[VE]\", \"[REG]\"]\n", - "visualize_attention(\n", - " last_attention_matrix,\n", - " patient=patient,\n", - " special_tokens=special_tokens,\n", - " tokenizer=tokenizer,\n", - " truncate=True,\n", - " only_cls=True,\n", - " top_k=15,\n", - ")" - ] - }, - { - "cell_type": "code", - "execution_count": null, - "id": "92da688f", - "metadata": {}, - "outputs": [], - "source": [ - "# def visualize_attention(attention_weights, patient, special_tokens, tokenizer, truncate=False, only_cls=False, top_k=10):\n", - "# # Convert attention tensor to numpy array and squeeze the batch dimension\n", - "# concept_ids = patient['concept_ids'].squeeze(0).cpu().numpy()\n", - "# attention_weights = attention_weights.squeeze(0).cpu().numpy()\n", - "\n", - "# # Truncate attention weights if specified\n", - "# if truncate:\n", - "# truncate_at = patient['attention_mask'].sum().numpy()\n", - "# attention_weights = attention_weights[:, :truncate_at, :truncate_at]\n", - "# concept_ids = concept_ids[:truncate_at]\n", - "\n", - "# if only_cls:\n", - "# attention_weights = attention_weights[:, :1, :]\n", - "\n", - "# # Average attention weights across heads\n", - "# attention_weights = attention_weights.mean(axis=0)\n", - "\n", - "# # Generate token labels, replacing special tokens with their names\n", - "# token_labels = [tokenizer.id_to_token(token) if tokenizer.id_to_token(token) in special_tokens else '' for token in concept_ids]\n", - "\n", - "# # Plot the attention matrix as a heatmap\n", - "# sns.set_theme(font_scale=1.2)\n", - "# plt.figure(figsize=(15, 12))\n", - "# ax = sns.heatmap(attention_weights, cmap=\"YlGnBu\", linewidths=.5, annot=False, cbar=True)\n", - "# ax.set_title('Attention Visualization')\n", - "\n", - "# # Set custom tick labels\n", - "# # ax.set_xticks(np.arange(len(token_labels)) + 0.5)\n", - "# # ax.set_xticklabels(token_labels, rotation=45, ha='right', fontsize=10)\n", - "# # ax.set_yticks(np.arange(len(token_labels)) + 0.5)\n", - "# # ax.set_yticklabels(token_labels, rotation=0, ha='right', fontsize=10)\n", - "\n", - "# ax.set_xlabel('Token in Input Sequence')\n", - "# ax.set_ylabel('Token in Input Sequence')\n", - "\n", - "# # Print top k tokens with their attention values\n", - "# top_k_indices = np.argsort(-attention_weights, axis=None)[:top_k]\n", - "# top_k_values = attention_weights.flatten()[top_k_indices]\n", - "# top_k_indices = np.unravel_index(top_k_indices, attention_weights.shape)\n", - "\n", - "# for idx in range(len(top_k_indices[0])):\n", - "# token1 = top_k_indices[0][idx]\n", - "# token2 = top_k_indices[1][idx]\n", - "# attention_value = top_k_values[idx]\n", - "# print(f\"Token {tokenizer.id_to_token(concept_ids[token1])} \"\n", - "# f\"with Token {tokenizer.id_to_token(concept_ids[token2])}: \"\n", - "# f\"Attention Value {attention_value}\")\n", - "\n", - "# plt.show()\n", - "\n", - "\n", - "# # Visualize the attention matrix with special tokens\n", - "# special_tokens = ['[CLS]', '[VS]', '[VE]', '[REG]'] # Update this list with your actual special tokens\n", - "# visualize_attention(last_attention_matrix, patient=patient, special_tokens=special_tokens, tokenizer=tokenizer, truncate=True, only_cls=True, top_k=25)" - ] - }, - { - "cell_type": "code", - "execution_count": null, - "id": "cf327908", - "metadata": {}, - "outputs": [], - "source": [ - "# Model view\n", - "html_model_view = model_view(\n", - " truncated_attention_matrix,\n", 
- " truncated_tokens,\n", - " include_layers=[5],\n", - " include_heads=[0, 1, 2, 3, 4, 5],\n", - " display_mode=\"light\",\n", - " html_action=\"return\",\n", - ")\n", - "\n", - "with open(\"model_view.html\", \"w\") as file:\n", - " file.write(html_model_view.data)" - ] - }, - { - "cell_type": "code", - "execution_count": null, - "id": "acae54e37e7d407bbb7b55eff062a284", - "metadata": {}, - "outputs": [], - "source": [ - "# Head View\n", - "html_head_view = head_view(\n", - " truncated_attention_matrix,\n", - " truncated_tokens,\n", - " # include_layers=[5],\n", - " html_action=\"return\",\n", - ")\n", - "\n", - "with open(\"head_view.html\", \"w\") as file:\n", - " file.write(html_head_view.data)" - ] - }, - { - "cell_type": "code", - "execution_count": null, - "id": "1c13d4bf", - "metadata": {}, - "outputs": [], - "source": [ - "# Neuron View\n", - "model_type = \"bert\"\n", - "\n", - "show(model, model_type, tokenizer, display_mode=\"dark\", layer=5, head=0)" - ] - }, - { - "cell_type": "code", - "execution_count": null, - "id": "2e4ee069c3f083f", - "metadata": { - "ExecuteTime": { - "end_time": "2024-03-15T13:59:33.204805Z", - "start_time": "2024-03-15T13:59:33.193223Z" - }, - "collapsed": false - }, - "outputs": [], - "source": [ - "# Visualize REG token -> Tricky?\n", - "# DONE Why the row vs column attention differs? -> What the matrix actually represents\n", - "# Include one example patient and visualize the attention matrix -> Include the exact concept token\n", - "# Some sort of markers to separate visits and special tokens\n", - "# Libraries used for attention visualization -> Amrit suggestion\n", - "# Visualize the gradients" - ] - } - ], - "metadata": { - "kernelspec": { - "display_name": "Python 3 (ipykernel)", - "language": "python", - "name": "python3" - }, - "language_info": { - "codemirror_mode": { - "name": "ipython", - "version": 3 - }, - "file_extension": ".py", - "mimetype": "text/x-python", - "name": "python", - "nbconvert_exporter": "python", - "pygments_lexer": "ipython3", - "version": "3.10.9" - } - }, - "nbformat": 4, - "nbformat_minor": 5 -} diff --git a/evaluation/CompareAUROC-Poster.ipynb b/evaluation/CompareAUROC-Poster.ipynb deleted file mode 100644 index 9e92592..0000000 --- a/evaluation/CompareAUROC-Poster.ipynb +++ /dev/null @@ -1,158 +0,0 @@ -{ - "cells": [ - { - "cell_type": "code", - "execution_count": null, - "metadata": { - "collapsed": true - }, - "outputs": [], - "source": [ - "\"\"\"\n", - "File: CompareAUROC-Poster.ipynb\n", - "---------------------------------\n", - "Compare performance of XGBoost to BigBird & Bi-LSTM using the AUROC curve for all three models on the same test set\n", - "Used to generate the AUROC curves on the poster showcased in Vector Institute's Research Symposium, on February 9\n", - "\"\"\"" - ] - }, - { - "cell_type": "code", - "execution_count": null, - "metadata": { - "collapsed": false - }, - "outputs": [], - "source": [ - "# Import dependencies and define useful constants\n", - "import os\n", - "\n", - "import matplotlib.pyplot as plt\n", - "import numpy as np\n", - "from scipy.special import expit\n", - "from sklearn.metrics import (\n", - " roc_auc_score,\n", - " roc_curve,\n", - ")\n", - "\n", - "\n", - "plt.style.use(\"seaborn-v0_8\")\n", - "%matplotlib inline\n", - "\n", - "TEST_SIZE = \"512\"\n", - "TEST_GROUP = \"two_weeks\"\n", - "TRANSFORMER_TEST_GROUP = \"week\" if TEST_GROUP == \"two_weeks\" else \"month\"\n", - "\n", - "ROOT = \"/fs01/home/afallah/odyssey/odyssey\"\n", - "DATA_ROOT = 
f\"{ROOT}/data/slurm_data/{TEST_SIZE}/{TEST_GROUP}\"\n", - "os.chdir(ROOT)" - ] - }, - { - "cell_type": "code", - "execution_count": null, - "metadata": { - "collapsed": false - }, - "outputs": [], - "source": [ - "# Load predictions, labels, and probabilities of different models\n", - "y_xgboost_pred = np.load(f\"{ROOT}/xgboost_y_test_pred_{TEST_GROUP}.npy\")\n", - "y_xgboost_labels = np.load(f\"{ROOT}/xgboost_y_test_pred_{TEST_GROUP}_labels.npy\")\n", - "y_xgboost_prob = np.load(f\"{ROOT}/xgboost_y_test_pred_{TEST_GROUP}_prob.npy\")\n", - "y_xgboost_prob = y_xgboost_prob[:, 1]\n", - "\n", - "y_lstm_pred = np.load(f\"{ROOT}/lstm_y_test_pred_{TEST_GROUP}.npy\")\n", - "y_lstm_labels = np.load(f\"{ROOT}/lstm_y_test_pred_{TEST_GROUP}_labels.npy\")\n", - "y_lstm_prob = np.load(f\"{ROOT}/lstm_y_test_pred_{TEST_GROUP}_prob.npy\")\n", - "\n", - "y_transformer_pred = np.load(\n", - " f\"/ssd003/projects/aieng/public/odyssey/results/test_preds_{TRANSFORMER_TEST_GROUP}.npy\",\n", - ")\n", - "y_transformer_labels = np.load(\n", - " f\"/ssd003/projects/aieng/public/odyssey/results/test_labels_{TRANSFORMER_TEST_GROUP}.npy\",\n", - ")\n", - "y_transformer_prob = np.load(\n", - " f\"/ssd003/projects/aieng/public/odyssey/results/test_prob_{TRANSFORMER_TEST_GROUP}.npy\",\n", - ")\n", - "y_transformer_prob = expit(y_transformer_prob[:, 1])" - ] - }, - { - "cell_type": "code", - "execution_count": null, - "metadata": { - "collapsed": false - }, - "outputs": [], - "source": [ - "# Plot ROC Curve for XGBoost, Bi-LSTM, and Transformer\n", - "fpr_xgboost, tpr_xgboost, _ = roc_curve(y_xgboost_labels, y_xgboost_prob)\n", - "fpr_lstm, tpr_lstm, _ = roc_curve(y_lstm_labels, y_lstm_prob)\n", - "fpr_transformer, tpr_transformer, _ = roc_curve(\n", - " y_transformer_labels,\n", - " y_transformer_prob,\n", - ")\n", - "\n", - "# AUROC\n", - "y_xgboost_auroc = roc_auc_score(y_xgboost_labels, y_xgboost_prob)\n", - "y_lstm_auroc = roc_auc_score(y_lstm_labels, y_lstm_prob)\n", - "transformer_auroc = roc_auc_score(y_transformer_labels, y_transformer_prob)" - ] - }, - { - "cell_type": "code", - "execution_count": null, - "metadata": { - "collapsed": false - }, - "outputs": [], - "source": [ - "# Plot Information\n", - "plt.figure(figsize=(8, 10))\n", - "\n", - "plt.plot(\n", - " fpr_transformer,\n", - " tpr_transformer,\n", - " label=f\"BigBird = {transformer_auroc:.2f}\",\n", - " color=\"red\",\n", - ")\n", - "plt.plot(\n", - " fpr_xgboost,\n", - " tpr_xgboost,\n", - " label=f\"XGBoost = {y_xgboost_auroc:.2f}\",\n", - " color=\"green\",\n", - ")\n", - "plt.plot(fpr_lstm, tpr_lstm, label=f\"Bi-LSTM = {y_lstm_auroc:.2f}\", color=\"blue\")\n", - "plt.plot([0, 1], [0, 1], linestyle=\"--\", color=\"gray\", label=\"Random\")\n", - "\n", - "plt.xlabel(\"False Positive Rate\")\n", - "plt.ylabel(\"True Positive Rate\")\n", - "plt.title(\"ROC Curve - Two-Weeks Mortality Prediction\")\n", - "plt.legend(loc=\"lower right\", fontsize=\"large\", facecolor=\"white\", frameon=True)\n", - "plt.show()" - ] - } - ], - "metadata": { - "kernelspec": { - "display_name": "Python 3", - "language": "python", - "name": "python3" - }, - "language_info": { - "codemirror_mode": { - "name": "ipython", - "version": 2 - }, - "file_extension": ".py", - "mimetype": "text/x-python", - "name": "python", - "nbconvert_exporter": "python", - "pygments_lexer": "ipython2", - "version": "2.7.6" - } - }, - "nbformat": 4, - "nbformat_minor": 0 -} diff --git a/evaluation/TestAnalysis.ipynb b/evaluation/TestAnalysis.ipynb deleted file mode 100644 index f39ce8d..0000000 
--- a/evaluation/TestAnalysis.ipynb +++ /dev/null @@ -1,266 +0,0 @@ -{ - "cells": [ - { - "cell_type": "code", - "execution_count": 1, - "metadata": {}, - "outputs": [ - { - "name": "stdout", - "output_type": "stream", - "text": [ - "[2024-04-10 12:13:14,754] [INFO] [real_accelerator.py:191:get_accelerator] Setting ds_accelerator to cuda (auto detect)\n" - ] - } - ], - "source": [ - "import os\n", - "\n", - "import torch\n", - "from sklearn.metrics import (\n", - " auc,\n", - " average_precision_score,\n", - " balanced_accuracy_score,\n", - " f1_score,\n", - " precision_recall_curve,\n", - " precision_score,\n", - " recall_score,\n", - " roc_auc_score,\n", - ")\n", - "from transformers import utils\n", - "\n", - "\n", - "utils.logging.set_verbosity_error() # Suppress standard warnings\n", - "\n", - "\n", - "ROOT = \"/fs01/home/afallah/odyssey/odyssey\"\n", - "os.chdir(ROOT)\n", - "\n", - "from odyssey.data.tokenizer import ConceptTokenizer" - ] - }, - { - "cell_type": "code", - "execution_count": 2, - "metadata": {}, - "outputs": [], - "source": [ - "class config:\n", - " \"\"\"Save the configuration arguments.\"\"\"\n", - "\n", - " model_path = \"test_epoch_end.ckpt\"\n", - " vocab_dir = \"data/vocab\"\n", - " data_dir = \"data/bigbird_data\"\n", - " sequence_file = \"patient_sequences/patient_sequences_2048_mortality.parquet\"\n", - " id_file = \"patient_id_dict/dataset_2048_mortality_1month.pkl\"\n", - " valid_scheme = \"few_shot\"\n", - " num_finetune_patients = \"20000\"\n", - " label_name = \"label_mortality_1month\"\n", - "\n", - " max_len = 2048\n", - " batch_size = 1\n", - " device = torch.device(\"cuda\" if torch.cuda.is_available() else \"cpu\")" - ] - }, - { - "cell_type": "code", - "execution_count": 3, - "metadata": {}, - "outputs": [], - "source": [ - "tokenizer = ConceptTokenizer(data_dir=config.vocab_dir)\n", - "tokenizer.fit_on_vocab()" - ] - }, - { - "cell_type": "code", - "execution_count": 4, - "metadata": {}, - "outputs": [ - { - "data": { - "text/plain": [ - "dict_keys(['epoch', 'global_step', 'pytorch-lightning_version', 'state_dict', 'loops', 'callbacks', 'optimizer_states', 'lr_schedulers', 'MixedPrecision'])" - ] - }, - "execution_count": 4, - "metadata": {}, - "output_type": "execute_result" - } - ], - "source": [ - "model = torch.load(config.model_path, map_location=config.device)\n", - "model.keys()" - ] - }, - { - "cell_type": "code", - "execution_count": 5, - "metadata": {}, - "outputs": [ - { - "data": { - "text/plain": [ - "{'loss': tensor(0.1638, dtype=torch.float64),\n", - " 'preds': tensor([6, 7, 0, ..., 7, 7, 7]),\n", - " 'labels': tensor([[1., 0., 0., ..., 0., 0., 0.],\n", - " [0., 0., 0., ..., 0., 0., 0.],\n", - " [1., 0., 0., ..., 0., 0., 0.],\n", - " ...,\n", - " [0., 0., 0., ..., 0., 0., 0.],\n", - " [0., 0., 0., ..., 0., 0., 0.],\n", - " [0., 0., 0., ..., 0., 0., 0.]], dtype=torch.float64),\n", - " 'logits': tensor([[ 2.3418, -2.0781, 0.2194, ..., -5.3945, -7.1797, -3.4180],\n", - " [-1.6533, -3.4277, -6.8086, ..., -6.8359, -5.2266, -5.6484],\n", - " [ 1.0947, -3.7930, -6.1094, ..., -6.3867, -6.6836, -5.5508],\n", - " ...,\n", - " [-2.8223, -3.7285, -4.6797, ..., -7.9922, -5.6992, -6.7812],\n", - " [-3.7148, -5.6328, -6.7188, ..., -9.4062, -7.6445, -7.6016],\n", - " [-2.5840, -2.2871, -4.6484, ..., -7.8633, -4.7539, -6.4648]],\n", - " dtype=torch.float16)}" - ] - }, - "execution_count": 5, - "metadata": {}, - "output_type": "execute_result" - } - ], - "source": [ - "test_outputs = torch.load(\"test_outputs.pt\")\n", - "test_outputs" - ] - }, - { 
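The `test_outputs` dictionary above stores one column of labels and logits per prediction target (shape [N, num_tasks]). In the metrics cell that follows, `average_precision_score` and `precision_recall_curve` are given the thresholded `preds` rather than the sigmoid probabilities, which collapses the precision-recall curve to a single operating point; ranking metrics are conventionally computed from scores. A minimal probability-based variant is sketched below, reusing the notebook's example target index 10 and assuming `test_outputs` has the structure shown above.

import torch
from sklearn.metrics import auc, average_precision_score, precision_recall_curve, roc_auc_score

target = 10  # same example task index used in the cell below
y_true = test_outputs["labels"][:, target].numpy()
y_prob = torch.sigmoid(test_outputs["logits"][:, target].float()).numpy()

auroc = roc_auc_score(y_true, y_prob)
ap = average_precision_score(y_true, y_prob)  # uses scores, not 0/1 predictions
precision, recall, _ = precision_recall_curve(y_true, y_prob)
pr_auc = auc(recall, precision)
print({"AUROC": auroc, "Average Precision": ap, "AUC-PR": pr_auc})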
- "cell_type": "code", - "execution_count": 67, - "metadata": {}, - "outputs": [ - { - "name": "stdout", - "output_type": "stream", - "text": [ - "{'Balanced Accuracy': 0.5, 'F1 Score': 0.0, 'Precision': 0.0, 'Recall': 0.0, 'AUROC': 0.8100258785715974, 'Average Precision Score': 0.001364147006900979, 'AUC-PR': 0.5006820735034505}\n" - ] - }, - { - "name": "stderr", - "output_type": "stream", - "text": [ - "/fs01/home/afallah/light/lib/python3.10/site-packages/sklearn/metrics/_classification.py:1497: UndefinedMetricWarning: Precision is ill-defined and being set to 0.0 due to no predicted samples. Use `zero_division` parameter to control this behavior.\n", - " _warn_prf(average, modifier, f\"{metric.capitalize()} is\", len(result))\n" - ] - } - ], - "source": [ - "def calculate_metrics(y_true, y_pred, y_prob):\n", - " \"\"\"\n", - " Calculate and return performance metrics.\n", - " \"\"\"\n", - " metrics = {\n", - " \"Balanced Accuracy\": balanced_accuracy_score(y_true, y_pred),\n", - " \"F1 Score\": f1_score(y_true, y_pred),\n", - " \"Precision\": precision_score(y_true, y_pred),\n", - " \"Recall\": recall_score(y_true, y_pred),\n", - " \"AUROC\": roc_auc_score(y_true, y_prob),\n", - " \"Average Precision Score\": average_precision_score(y_true, y_pred),\n", - " }\n", - "\n", - " precision, recall, _ = precision_recall_curve(y_true, y_pred)\n", - " metrics[\"AUC-PR\"] = auc(recall, precision)\n", - "\n", - " return metrics\n", - "\n", - "\n", - "targets = [10]\n", - "\n", - "for i in targets:\n", - " labels = test_outputs[\"labels\"][:, i]\n", - " logits = torch.sigmoid(test_outputs[\"logits\"][:, i])\n", - " preds = (logits >= 0.5).int()\n", - "\n", - " print(calculate_metrics(labels, preds, logits))" - ] - }, - { - "cell_type": "code", - "execution_count": 68, - "metadata": {}, - "outputs": [ - { - "data": { - "text/plain": [ - "tensor(0)" - ] - }, - "execution_count": 68, - "metadata": {}, - "output_type": "execute_result" - } - ], - "source": [ - "preds.sum()" - ] - }, - { - "cell_type": "code", - "execution_count": 69, - "metadata": {}, - "outputs": [ - { - "data": { - "text/plain": [ - "tensor(34., dtype=torch.float64)" - ] - }, - "execution_count": 69, - "metadata": {}, - "output_type": "execute_result" - } - ], - "source": [ - "labels.sum()" - ] - }, - { - "cell_type": "code", - "execution_count": 9, - "metadata": {}, - "outputs": [ - { - "data": { - "text/plain": [ - "tensor([0, 0, 0, ..., 0, 0, 0], dtype=torch.int32)" - ] - }, - "execution_count": 9, - "metadata": {}, - "output_type": "execute_result" - } - ], - "source": [ - "preds" - ] - } - ], - "metadata": { - "kernelspec": { - "display_name": "Python 3 (ipykernel)", - "language": "python", - "name": "python3" - }, - "language_info": { - "codemirror_mode": { - "name": "ipython", - "version": 3 - }, - "file_extension": ".py", - "mimetype": "text/x-python", - "name": "python", - "nbconvert_exporter": "python", - "pygments_lexer": "ipython3", - "version": "3.10.9" - } - }, - "nbformat": 4, - "nbformat_minor": 2 -} diff --git a/finetune.py b/finetune.py index 628c809..ab4e1a3 100644 --- a/finetune.py +++ b/finetune.py @@ -19,15 +19,15 @@ from skmultilearn.model_selection import iterative_train_test_split from torch.utils.data import DataLoader +from odyssey.utils.utils import seed_everything from odyssey.data.dataset import FinetuneDataset, FinetuneMultiDataset from odyssey.data.tokenizer import ConceptTokenizer from odyssey.models.cehr_bert.model import BertFinetune, BertPretrain from odyssey.models.cehr_big_bird.model import 
BigBirdFinetune, BigBirdPretrain -from odyssey.models.utils import ( +from odyssey.models.model_utils import ( get_run_id, load_config, load_finetune_data, - seed_everything, ) diff --git a/odyssey/data/bigbird_data/DataChecker.ipynb b/odyssey/data/bigbird_data/DataChecker.ipynb deleted file mode 100644 index 4944995..0000000 --- a/odyssey/data/bigbird_data/DataChecker.ipynb +++ /dev/null @@ -1,1136 +0,0 @@ -{ - "cells": [ - { - "cell_type": "code", - "execution_count": 2, - "metadata": { - "ExecuteTime": { - "end_time": "2024-03-13T16:14:45.546088300Z", - "start_time": "2024-03-13T16:14:43.587090300Z" - }, - "collapsed": true - }, - "outputs": [], - "source": [ - "import os\n", - "import pickle\n", - "import random\n", - "import sys\n", - "from typing import Any, Dict, List, Optional\n", - "\n", - "import numpy as np\n", - "import pandas as pd\n", - "from sklearn.model_selection import train_test_split\n", - "from skmultilearn.model_selection import iterative_train_test_split\n", - "\n", - "\n", - "sys.path.append(\"/h/afallah/odyssey/odyssey/lib\")\n", - "from utils import save_object_to_disk\n", - "\n", - "\n", - "DATA_ROOT = \"/h/afallah/odyssey/odyssey/data/bigbird_data\"\n", - "DATASET = f\"{DATA_ROOT}/patient_sequences/patient_sequences_2048.parquet\"\n", - "MAX_LEN = 2048\n", - "\n", - "SEED = 23\n", - "os.chdir(DATA_ROOT)\n", - "random.seed(SEED)\n", - "np.random.seed(SEED)" - ] - }, - { - "cell_type": "code", - "execution_count": 3, - "metadata": { - "ExecuteTime": { - "end_time": "2024-03-13T16:15:12.321718600Z", - "start_time": "2024-03-13T16:14:45.553089800Z" - }, - "collapsed": false - }, - "outputs": [ - { - "name": "stdout", - "output_type": "stream", - "text": [ - "Current columns: Index(['patient_id', 'num_visits', 'deceased', 'death_after_start',\n", - " 'death_after_end', 'length', 'token_length', 'event_tokens_2048',\n", - " 'type_tokens_2048', 'age_tokens_2048', 'time_tokens_2048',\n", - " 'visit_tokens_2048', 'position_tokens_2048', 'elapsed_tokens_2048',\n", - " 'common_conditions', 'rare_conditions'],\n", - " dtype='object')\n" - ] - }, - { - "data": { - "text/html": [ - "
\n", - "\n", - "\n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - "
patient_idnum_visitsdeceaseddeath_after_startdeath_after_endlengthtoken_lengthevent_tokens_2048type_tokens_2048age_tokens_2048time_tokens_2048visit_tokens_2048position_tokens_2048elapsed_tokens_2048common_conditionsrare_conditions
035581927-9c95-5ae9-af76-7d74870a349c10NaNNaN5054[[CLS], [VS], 00006473900, 00904516561, 510790...[1, 2, 6, 6, 6, 6, 6, 6, 6, 6, 6, 6, 6, 6, 7, ...[0, 85, 85, 85, 85, 85, 85, 85, 85, 85, 85, 85...[0, 5902, 5902, 5902, 5902, 5902, 5902, 5902, ...[0, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, ...[0, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, ...[-2.0, -1.0, 1.97, 2.02, 2.02, 2.02, 2.02, 2.0...[1, 0, 0, 0, 0, 0, 0, 0, 0, 0][0, 0, 0, 0, 0, 0, 0, 0, 0, 0]
1f5bba8dd-25c0-5336-8d3d-37424c18502620NaNNaN148156[[CLS], [VS], 52135_2, 52075_2, 52074_2, 52073...[1, 2, 5, 5, 5, 5, 5, 5, 5, 5, 5, 5, 5, 5, 5, ...[0, 83, 83, 83, 83, 83, 83, 83, 83, 83, 83, 83...[0, 6594, 6594, 6594, 6594, 6594, 6594, 6594, ...[0, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, ...[0, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, ...[-2.0, -1.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0...[0, 0, 0, 0, 0, 0, 0, 1, 0, 0][0, 0, 0, 0, 0, 0, 0, 0, 0, 0]
2f4938f91-cadb-5133-8541-a52fb0916cea20NaNNaN7886[[CLS], [VS], 0RB30ZZ, 0RG10A0, 00071101441, 0...[1, 2, 7, 7, 6, 6, 6, 6, 6, 6, 6, 6, 6, 6, 6, ...[0, 44, 44, 44, 44, 44, 44, 44, 44, 44, 44, 44...[0, 8150, 8150, 8150, 8150, 8150, 8150, 8150, ...[0, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, ...[0, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, ...[-2.0, -1.0, 0.0, 0.0, 1.08, 1.08, 13.89, 13.8...[0, 0, 0, 0, 0, 0, 0, 0, 0, 0][0, 0, 0, 0, 0, 0, 0, 0, 0, 0]
36fe2371b-a6f0-5436-aade-7795005b0c6620NaNNaN8694[[CLS], [VS], 63739057310, 49281041688, 005970...[1, 2, 6, 6, 6, 6, 6, 6, 6, 6, 6, 6, 6, 6, 6, ...[0, 72, 72, 72, 72, 72, 72, 72, 72, 72, 72, 72...[0, 6093, 6093, 6093, 6093, 6093, 6093, 6093, ...[0, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, ...[0, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, ...[-2.0, -1.0, 0.75, 0.75, 0.75, 0.75, 0.75, 0.7...[1, 0, 0, 0, 0, 0, 0, 1, 0, 0][0, 0, 0, 0, 0, 0, 0, 0, 0, 0]
46f7590ae-f3b9-50e5-9e41-d4bb1000887a10NaNNaN7276[[CLS], [VS], 50813_0, 52135_0, 52075_3, 52074...[1, 2, 5, 5, 5, 5, 5, 5, 5, 5, 5, 5, 5, 5, 5, ...[0, 47, 47, 47, 47, 47, 47, 47, 47, 47, 47, 47...[0, 6379, 6379, 6379, 6379, 6379, 6379, 6379, ...[0, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, ...[0, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, ...[-2.0, -1.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0...[1, 0, 0, 0, 0, 0, 0, 0, 0, 1][0, 0, 0, 0, 0, 0, 0, 0, 0, 0]
\n", - "
" - ], - "text/plain": [ - " patient_id num_visits deceased \\\n", - "0 35581927-9c95-5ae9-af76-7d74870a349c 1 0 \n", - "1 f5bba8dd-25c0-5336-8d3d-37424c185026 2 0 \n", - "2 f4938f91-cadb-5133-8541-a52fb0916cea 2 0 \n", - "3 6fe2371b-a6f0-5436-aade-7795005b0c66 2 0 \n", - "4 6f7590ae-f3b9-50e5-9e41-d4bb1000887a 1 0 \n", - "\n", - " death_after_start death_after_end length token_length \\\n", - "0 NaN NaN 50 54 \n", - "1 NaN NaN 148 156 \n", - "2 NaN NaN 78 86 \n", - "3 NaN NaN 86 94 \n", - "4 NaN NaN 72 76 \n", - "\n", - " event_tokens_2048 \\\n", - "0 [[CLS], [VS], 00006473900, 00904516561, 510790... \n", - "1 [[CLS], [VS], 52135_2, 52075_2, 52074_2, 52073... \n", - "2 [[CLS], [VS], 0RB30ZZ, 0RG10A0, 00071101441, 0... \n", - "3 [[CLS], [VS], 63739057310, 49281041688, 005970... \n", - "4 [[CLS], [VS], 50813_0, 52135_0, 52075_3, 52074... \n", - "\n", - " type_tokens_2048 \\\n", - "0 [1, 2, 6, 6, 6, 6, 6, 6, 6, 6, 6, 6, 6, 6, 7, ... \n", - "1 [1, 2, 5, 5, 5, 5, 5, 5, 5, 5, 5, 5, 5, 5, 5, ... \n", - "2 [1, 2, 7, 7, 6, 6, 6, 6, 6, 6, 6, 6, 6, 6, 6, ... \n", - "3 [1, 2, 6, 6, 6, 6, 6, 6, 6, 6, 6, 6, 6, 6, 6, ... \n", - "4 [1, 2, 5, 5, 5, 5, 5, 5, 5, 5, 5, 5, 5, 5, 5, ... \n", - "\n", - " age_tokens_2048 \\\n", - "0 [0, 85, 85, 85, 85, 85, 85, 85, 85, 85, 85, 85... \n", - "1 [0, 83, 83, 83, 83, 83, 83, 83, 83, 83, 83, 83... \n", - "2 [0, 44, 44, 44, 44, 44, 44, 44, 44, 44, 44, 44... \n", - "3 [0, 72, 72, 72, 72, 72, 72, 72, 72, 72, 72, 72... \n", - "4 [0, 47, 47, 47, 47, 47, 47, 47, 47, 47, 47, 47... \n", - "\n", - " time_tokens_2048 \\\n", - "0 [0, 5902, 5902, 5902, 5902, 5902, 5902, 5902, ... \n", - "1 [0, 6594, 6594, 6594, 6594, 6594, 6594, 6594, ... \n", - "2 [0, 8150, 8150, 8150, 8150, 8150, 8150, 8150, ... \n", - "3 [0, 6093, 6093, 6093, 6093, 6093, 6093, 6093, ... \n", - "4 [0, 6379, 6379, 6379, 6379, 6379, 6379, 6379, ... \n", - "\n", - " visit_tokens_2048 \\\n", - "0 [0, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, ... \n", - "1 [0, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, ... \n", - "2 [0, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, ... \n", - "3 [0, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, ... \n", - "4 [0, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, ... \n", - "\n", - " position_tokens_2048 \\\n", - "0 [0, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, ... \n", - "1 [0, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, ... \n", - "2 [0, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, ... \n", - "3 [0, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, ... \n", - "4 [0, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, ... \n", - "\n", - " elapsed_tokens_2048 \\\n", - "0 [-2.0, -1.0, 1.97, 2.02, 2.02, 2.02, 2.02, 2.0... \n", - "1 [-2.0, -1.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0... \n", - "2 [-2.0, -1.0, 0.0, 0.0, 1.08, 1.08, 13.89, 13.8... \n", - "3 [-2.0, -1.0, 0.75, 0.75, 0.75, 0.75, 0.75, 0.7... \n", - "4 [-2.0, -1.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0... 
\n", - "\n", - " common_conditions rare_conditions \n", - "0 [1, 0, 0, 0, 0, 0, 0, 0, 0, 0] [0, 0, 0, 0, 0, 0, 0, 0, 0, 0] \n", - "1 [0, 0, 0, 0, 0, 0, 0, 1, 0, 0] [0, 0, 0, 0, 0, 0, 0, 0, 0, 0] \n", - "2 [0, 0, 0, 0, 0, 0, 0, 0, 0, 0] [0, 0, 0, 0, 0, 0, 0, 0, 0, 0] \n", - "3 [1, 0, 0, 0, 0, 0, 0, 1, 0, 0] [0, 0, 0, 0, 0, 0, 0, 0, 0, 0] \n", - "4 [1, 0, 0, 0, 0, 0, 0, 0, 0, 1] [0, 0, 0, 0, 0, 0, 0, 0, 0, 0] " - ] - }, - "execution_count": 3, - "metadata": {}, - "output_type": "execute_result" - } - ], - "source": [ - "# Load complete dataset\n", - "dataset_2048 = pd.read_parquet(DATASET)\n", - "\n", - "print(f\"Current columns: {dataset_2048.columns}\")\n", - "dataset_2048.head()" - ] - }, - { - "cell_type": "code", - "execution_count": 4, - "metadata": {}, - "outputs": [], - "source": [ - "def filter_by_num_visit(dataset: pd.DataFrame, minimum_num_visits: int) -> pd.DataFrame:\n", - " \"\"\"Filter the patients based on num_visits threshold.\n", - "\n", - " Args:\n", - " dataset (pd.DataFrame): The input dataset.\n", - " minimum_num_visits (int): The threshold num_visits\n", - "\n", - " Returns\n", - " -------\n", - " pd.DataFrame: The filtered dataset.\n", - " \"\"\"\n", - " filtered_dataset = dataset.loc[dataset[\"num_visits\"] >= minimum_num_visits]\n", - " filtered_dataset.reset_index(drop=True, inplace=True)\n", - " return filtered_dataset\n", - "\n", - "\n", - "def filter_by_length_of_stay(dataset: pd.DataFrame, threshold: int = 1) -> pd.DataFrame:\n", - " \"\"\"Filter the patients based on length of stay threshold.\n", - "\n", - " Args:\n", - " dataset (pd.DataFrame): The input dataset.\n", - " minimum_num_visits (int): The threshold length of stay\n", - "\n", - " Returns\n", - " -------\n", - " pd.DataFrame: The filtered dataset.\n", - " \"\"\"\n", - " filtered_dataset = dataset.loc[dataset[\"length_of_stay\"] >= threshold]\n", - "\n", - " # Only keep the patients that their first event happens within threshold\n", - " # TODO: Check how many patients get removed here?\n", - " filtered_dataset = filtered_dataset[\n", - " filtered_dataset.apply(\n", - " lambda row: row[\"elapsed_tokens_2048\"][row[\"last_VS_index\"] + 1]\n", - " < threshold * 24,\n", - " axis=1,\n", - " )\n", - " ]\n", - "\n", - " filtered_dataset.reset_index(drop=True, inplace=True)\n", - " return filtered_dataset\n", - "\n", - "\n", - "def get_last_occurence_index(seq: List[str], target: str) -> int:\n", - " \"\"\"Return the index of the last occurrence of target in seq.\n", - "\n", - " Args:\n", - " seq (List[str]): The input sequence.\n", - " target (str): The target string to find.\n", - "\n", - " Returns\n", - " -------\n", - " int: The index of the last occurrence of target in seq.\n", - " \"\"\"\n", - " return len(seq) - (seq[::-1].index(target) + 1)\n", - "\n", - "\n", - "def check_readmission_label(row: pd.Series) -> int:\n", - " \"\"\"Check if the label indicates readmission within one month.\n", - "\n", - " Args:\n", - " row (pd.Series): The input row.\n", - "\n", - " Returns\n", - " -------\n", - " bool: True if readmission label is present, False otherwise.\n", - " \"\"\"\n", - " last_vs_index = row[\"last_VS_index\"]\n", - " return int(\n", - " row[\"event_tokens_2048\"][last_vs_index - 1]\n", - " in (\"[W_0]\", \"[W_1]\", \"[W_2]\", \"[W_3]\", \"[M_1]\"),\n", - " )\n", - "\n", - "\n", - "def get_length_of_stay(row: pd.Series) -> pd.Series:\n", - " \"\"\"Determine the length of a given visit.\n", - "\n", - " Args:\n", - " row (pd.Series): The input row.\n", - "\n", - " Returns\n", - " -------\n", - " 
pd.Series: The preprocessed row.\n", - " \"\"\"\n", - " admission_time = row[\"last_VS_index\"] + 1\n", - " discharge_time = row[\"last_VE_index\"] - 1\n", - " return (discharge_time - admission_time) / 24\n", - "\n", - "\n", - "def get_visit_cutoff_at_threshold(row: pd.Series, threshold: int = 24) -> int:\n", - " \"\"\"Get the index of the first event token of last visit that occurs after threshold hours.\n", - "\n", - " Args:\n", - " row (pd.Series): The input row.\n", - " threshold (int): The number of hours to consider.\n", - "\n", - " Returns\n", - " -------\n", - " cutoff_index (int): The corrosponding cutoff index.\n", - " \"\"\"\n", - " last_vs_index = row[\"last_VS_index\"]\n", - " last_ve_index = row[\"last_VE_index\"]\n", - "\n", - " for i in range(last_vs_index + 1, last_ve_index):\n", - " if row[\"elapsed_tokens_2048\"][i] > threshold:\n", - " return i\n", - "\n", - " return len(row[\"event_tokens_2048\"])" - ] - }, - { - "cell_type": "code", - "execution_count": null, - "metadata": {}, - "outputs": [], - "source": [ - "def process_length_of_stay_dataset(\n", - " dataset: pd.DataFrame,\n", - " threshold: int = 7,\n", - ") -> pd.DataFrame:\n", - " \"\"\"Process the length of stay dataset to extract required features.\n", - "\n", - " Args:\n", - " dataset (pd.DataFrame): The input dataset.\n", - " threshold (int): The threshold length of stay.\n", - "\n", - " Returns\n", - " -------\n", - " pd.DataFrame: The processed dataset.\n", - " \"\"\"\n", - " dataset[\"last_VS_index\"] = dataset[\"event_tokens_2048\"].transform(\n", - " lambda seq: get_last_occurence_index(list(seq), \"[VS]\"),\n", - " )\n", - " dataset[\"last_VE_index\"] = dataset[\"event_tokens_2048\"].transform(\n", - " lambda seq: get_last_occurence_index(list(seq), \"[VE]\"),\n", - " )\n", - " dataset[\"length_of_stay\"] = dataset.apply(get_length_of_stay, axis=1)\n", - "\n", - " dataset = filter_by_length_of_stay(dataset, threshold=1)\n", - " dataset[\"label_los_1week\"] = (dataset[\"length_of_stay\"] >= threshold).astype(int)\n", - "\n", - " dataset[\"cutoff_los\"] = dataset.apply(\n", - " lambda row: get_visit_cutoff_at_threshold(row, threshold=24),\n", - " axis=1,\n", - " )\n", - " dataset[\"token_length\"] = dataset[\"event_tokens_2048\"].apply(len)\n", - "\n", - " return dataset\n", - "\n", - "\n", - "# Process the dataset for length of stay prediction above a threshold\n", - "dataset_2048_los = process_length_of_stay_dataset(dataset_2048.copy(), threshold=7)" - ] - }, - { - "cell_type": "code", - "execution_count": null, - "metadata": {}, - "outputs": [], - "source": [ - "def process_condition_dataset(dataset: pd.DataFrame) -> pd.DataFrame:\n", - " \"\"\"Process the condition dataset to extract required features.\n", - "\n", - " Args:\n", - " dataset (pd.DataFrame): The input condition dataset.\n", - "\n", - " Returns\n", - " -------\n", - " pd.DataFrame: The processed condition dataset.\n", - " \"\"\"\n", - " dataset[\"all_conditions\"] = dataset.apply(\n", - " lambda row: np.concatenate(\n", - " [row[\"common_conditions\"], row[\"rare_conditions\"]],\n", - " dtype=np.int64,\n", - " ),\n", - " axis=1,\n", - " )\n", - "\n", - " return dataset\n", - "\n", - "\n", - "# Process the dataset for conditions including rare and common\n", - "dataset_2048_condition = process_condition_dataset(dataset_2048.copy())" - ] - }, - { - "cell_type": "code", - "execution_count": null, - "metadata": { - "ExecuteTime": { - "end_time": "2024-03-13T16:15:16.075719400Z", - "start_time": "2024-03-13T16:15:12.335721100Z" - }, - 
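
The readmission label and the visit cutoffs in the notebook cell above hinge on get_last_occurence_index, which locates the last [VS]/[VE] markers in a patient's token sequence. A small self-contained check of that helper on a toy sequence (toy tokens, not drawn from the data):

def get_last_occurence_index(seq, target):
    # Index of the last occurrence of `target` in `seq` (mirrors the helper above).
    return len(seq) - (seq[::-1].index(target) + 1)

tokens = ["[CLS]", "[VS]", "x", "[VE]", "[VS]", "y", "z", "[VE]"]
print(get_last_occurence_index(tokens, "[VS]"))  # 4 -> start of the last visit
print(get_last_occurence_index(tokens, "[VE]"))  # 7 -> end of the last visit
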
"collapsed": false - }, - "outputs": [], - "source": [ - "def process_mortality_dataset(dataset: pd.DataFrame) -> pd.DataFrame:\n", - " \"\"\"Process the mortality dataset to extract required features.\n", - "\n", - " Args:\n", - " dataset (pd.DataFrame): The input mortality dataset.\n", - "\n", - " Returns\n", - " -------\n", - " pd.DataFrame: The processed mortality dataset.\n", - " \"\"\"\n", - " dataset[\"label_mortality_2weeks\"] = (\n", - " (dataset[\"death_after_start\"] >= 0) & (dataset[\"death_after_end\"] <= 15)\n", - " ).astype(int)\n", - " dataset[\"label_mortality_1month\"] = (\n", - " (dataset[\"death_after_start\"] >= 0) & (dataset[\"death_after_end\"] <= 32)\n", - " ).astype(int)\n", - "\n", - " return dataset\n", - "\n", - "\n", - "# Process the dataset for mortality in two weeks or one month task\n", - "dataset_2048_mortality = process_mortality_dataset(dataset_2048.copy())" - ] - }, - { - "cell_type": "code", - "execution_count": null, - "metadata": { - "ExecuteTime": { - "end_time": "2024-03-13T16:15:47.326996100Z", - "start_time": "2024-03-13T16:15:16.094719300Z" - }, - "collapsed": false - }, - "outputs": [], - "source": [ - "def process_readmission_dataset(dataset: pd.DataFrame) -> pd.DataFrame:\n", - " \"\"\"Process the readmission dataset to extract required features.\n", - "\n", - " Args:\n", - " dataset (pd.DataFrame): The input dataset.\n", - "\n", - " Returns\n", - " -------\n", - " pd.DataFrame: The processed dataset.\n", - " \"\"\"\n", - " dataset[\"last_VS_index\"] = dataset[\"event_tokens_2048\"].transform(\n", - " lambda seq: get_last_occurence_index(list(seq), \"[VS]\"),\n", - " )\n", - " dataset[\"cutoff_readmission\"] = dataset[\"last_VS_index\"] - 1\n", - " dataset[\"label_readmission_1month\"] = dataset.apply(check_readmission_label, axis=1)\n", - "\n", - " dataset[\"num_visits\"] -= 1\n", - " dataset[\"token_length\"] = dataset[\"event_tokens_2048\"].apply(len)\n", - "\n", - " return dataset\n", - "\n", - "\n", - "# Process the dataset for hospital readmission in one month task\n", - "dataset_2048_readmission = filter_by_num_visit(\n", - " dataset_2048.copy(),\n", - " minimum_num_visits=2,\n", - ")\n", - "dataset_2048_readmission = process_readmission_dataset(dataset_2048_readmission)" - ] - }, - { - "cell_type": "code", - "execution_count": null, - "metadata": {}, - "outputs": [], - "source": [ - "def process_multi_dataset(datasets: Dict[str, pd.DataFrame]):\n", - " \"\"\"\n", - " Process the multi-task dataset by merging the original dataset with the other datasets.\n", - "\n", - " Args:\n", - " datasets (Dict): Dictionary mapping each task to its respective dataframe\n", - "\n", - " Returns\n", - " -------\n", - " pd.DataFrame: The processed multi-task dataset\n", - " \"\"\"\n", - " # Merging datasets on 'patient_id'\n", - " multi_dataset = datasets[\"original\"].merge(\n", - " datasets[\"condition\"][[\"patient_id\", \"all_conditions\"]],\n", - " on=\"patient_id\",\n", - " how=\"left\",\n", - " )\n", - " multi_dataset = multi_dataset.merge(\n", - " datasets[\"mortality\"][[\"patient_id\", \"label_mortality_1month\"]],\n", - " on=\"patient_id\",\n", - " how=\"left\",\n", - " )\n", - " multi_dataset = multi_dataset.merge(\n", - " datasets[\"readmission\"][\n", - " [\"patient_id\", \"cutoff_readmission\", \"label_readmission_1month\"]\n", - " ],\n", - " on=\"patient_id\",\n", - " how=\"left\",\n", - " )\n", - " multi_dataset = multi_dataset.merge(\n", - " datasets[\"los\"][[\"patient_id\", \"cutoff_los\", \"label_los_1week\"]],\n", - " 
on=\"patient_id\",\n", - " how=\"left\",\n", - " )\n", - "\n", - " # Selecting the required columns\n", - " multi_dataset = multi_dataset[\n", - " [\n", - " \"patient_id\",\n", - " \"num_visits\",\n", - " \"event_tokens_2048\",\n", - " \"type_tokens_2048\",\n", - " \"age_tokens_2048\",\n", - " \"time_tokens_2048\",\n", - " \"visit_tokens_2048\",\n", - " \"position_tokens_2048\",\n", - " \"elapsed_tokens_2048\",\n", - " \"cutoff_los\",\n", - " \"cutoff_readmission\",\n", - " \"all_conditions\",\n", - " \"label_mortality_1month\",\n", - " \"label_readmission_1month\",\n", - " \"label_los_1week\",\n", - " ]\n", - " ]\n", - "\n", - " # Transform conditions from a vector of numbers to binary classes\n", - " conditions_expanded = multi_dataset[\"all_conditions\"].apply(pd.Series)\n", - " conditions_expanded.columns = [f\"condition{i}\" for i in range(20)]\n", - " multi_dataset = multi_dataset.drop(\"all_conditions\", axis=1)\n", - " multi_dataset = pd.concat([multi_dataset, conditions_expanded], axis=1)\n", - "\n", - " # Standardize important column names\n", - " multi_dataset.rename(\n", - " columns={\n", - " \"cutoff_los\": \"cutoff_los_1week\",\n", - " \"cutoff_readmission\": \"cutoff_readmission_1month\",\n", - " },\n", - " inplace=True,\n", - " )\n", - " condition_columns = {f\"condition{i}\": f\"label_c{i}\" for i in range(20)}\n", - " multi_dataset.rename(columns=condition_columns, inplace=True)\n", - "\n", - " numerical_columns = [\n", - " \"cutoff_los_1week\",\n", - " \"cutoff_readmission_1month\",\n", - " \"label_mortality_1month\",\n", - " \"label_readmission_1month\",\n", - " \"label_los_1week\",\n", - " ] + [f\"label_c{i}\" for i in range(20)]\n", - "\n", - " # Fill NaN values and convert to integers\n", - " for column in numerical_columns:\n", - " multi_dataset[column] = multi_dataset[column].fillna(-1).astype(int)\n", - "\n", - " # Reset dataset index\n", - " multi_dataset.reset_index(drop=True, inplace=True)\n", - "\n", - " return multi_dataset\n", - "\n", - "\n", - "multi_dataset = process_multi_dataset(\n", - " datasets={\n", - " \"original\": dataset_2048,\n", - " \"mortality\": dataset_2048_mortality,\n", - " \"condition\": dataset_2048_condition,\n", - " \"readmission\": dataset_2048_readmission,\n", - " \"los\": dataset_2048_los,\n", - " },\n", - ")" - ] - }, - { - "cell_type": "code", - "execution_count": null, - "metadata": {}, - "outputs": [], - "source": [ - "def stratified_train_test_split(\n", - " dataset: pd.DataFrame,\n", - " target: str,\n", - " test_size: float,\n", - " return_test: Optional[bool] = False,\n", - "):\n", - " \"\"\"\n", - " Split the given dataset into training and testing sets using iterative stratification on given multi-label target.\n", - " \"\"\"\n", - " # Convert all_conditions into a format suitable for multi-label stratification\n", - " Y = np.array(dataset[target].values.tolist())\n", - " X = dataset[\"patient_id\"].to_numpy().reshape(-1, 1)\n", - " is_single_label = type(dataset.iloc[0][target]) == np.int64\n", - "\n", - " # Perform stratified split\n", - " if is_single_label:\n", - " X_train, X_test, y_train, y_test = train_test_split(\n", - " X,\n", - " Y,\n", - " stratify=Y,\n", - " test_size=test_size,\n", - " random_state=SEED,\n", - " )\n", - "\n", - " else:\n", - " X_train, y_train, X_test, y_test = iterative_train_test_split(\n", - " X,\n", - " Y,\n", - " test_size=test_size,\n", - " )\n", - "\n", - " X_train = X_train.flatten().tolist()\n", - " X_test = X_test.flatten().tolist()\n", - "\n", - " if return_test:\n", - " return 
X_test\n", - " else:\n", - " return X_train, X_test\n", - "\n", - "\n", - "def sample_balanced_subset(dataset: pd.DataFrame, target: str, sample_size: int):\n", - " \"\"\"\n", - " Sample a subset of dataset with balanced target labels.\n", - " \"\"\"\n", - " # Sampling positive and negative patients\n", - " pos_patients = dataset[dataset[target] == True].sample(\n", - " n=sample_size // 2,\n", - " random_state=SEED,\n", - " )\n", - " neg_patients = dataset[dataset[target] == False].sample(\n", - " n=sample_size // 2,\n", - " random_state=SEED,\n", - " )\n", - "\n", - " # Combining and shuffling patient IDs\n", - " sample_patients = (\n", - " pos_patients[\"patient_id\"].tolist() + neg_patients[\"patient_id\"].tolist()\n", - " )\n", - " random.shuffle(sample_patients)\n", - "\n", - " return sample_patients\n", - "\n", - "\n", - "def get_pretrain_test_split(\n", - " dataset: pd.DataFrame,\n", - " stratify_target: Optional[str] = None,\n", - " test_size: float = 0.15,\n", - "):\n", - " \"\"\"Split dataset into pretrain and test set. Stratify on a given target column if needed.\"\"\"\n", - " if stratify_target:\n", - " pretrain_ids, test_ids = stratified_train_test_split(\n", - " dataset,\n", - " target=stratify_target,\n", - " test_size=test_size,\n", - " )\n", - "\n", - " else:\n", - " test_patients = dataset.sample(n=test_size, random_state=SEED)\n", - " test_ids = test_patients[\"patient_id\"].tolist()\n", - " pretrain_ids = dataset[~dataset[\"patient_id\"].isin(test_patients)][\n", - " \"patient_id\"\n", - " ].tolist()\n", - "\n", - " random.shuffle(pretrain_ids)\n", - "\n", - " return pretrain_ids, test_ids" - ] - }, - { - "cell_type": "code", - "execution_count": null, - "metadata": {}, - "outputs": [], - "source": [ - "# Split data\n", - "patient_ids_dict = {\n", - " \"pretrain\": [],\n", - " \"finetune\": {\"few_shot\": {}, \"kfold\": {}},\n", - " \"test\": [],\n", - "}\n", - "\n", - "# Get train-test split\n", - "# pretrain_ids, test_ids = get_pretrain_test_split(dataset_2048_readmission, stratify_target='label_readmission_1month', test_size=0.2)\n", - "# pretrain_ids, test_ids = get_pretrain_test_split(process_condition_dataset, stratify_target='all_conditions', test_size=0.15)\n", - "# patient_ids_dict['pretrain'] = pretrain_ids\n", - "# patient_ids_dict['test'] = test_ids\n", - "\n", - "# Load pretrain and test patient IDs\n", - "pid = pickle.load(open(\"patient_id_dict/dataset_2048_multi.pkl\", \"rb\"))\n", - "patient_ids_dict[\"pretrain\"] = pid[\"pretrain\"]\n", - "patient_ids_dict[\"test\"] = pid[\"test\"]\n", - "set(pid[\"pretrain\"] + pid[\"test\"]) == set(dataset_2048[\"patient_id\"])" - ] - }, - { - "cell_type": "code", - "execution_count": null, - "metadata": {}, - "outputs": [], - "source": [ - "class config:\n", - " task_splits = {\n", - " \"mortality\": {\n", - " \"dataset\": dataset_2048_mortality,\n", - " \"label_col\": \"label_mortality_1month\",\n", - " \"finetune_size\": [250, 500, 1000, 5000, 20000],\n", - " \"save_path\": \"patient_id_dict/dataset_2048_mortality.pkl\",\n", - " \"split_mode\": \"single_label_balanced\",\n", - " },\n", - " \"readmission\": {\n", - " \"dataset\": dataset_2048_readmission,\n", - " \"label_col\": \"label_readmission_1month\",\n", - " \"finetune_size\": [250, 1000, 5000, 20000, 60000],\n", - " \"save_path\": \"patient_id_dict/dataset_2048_readmission.pkl\",\n", - " \"split_mode\": \"single_label_stratified\",\n", - " },\n", - " \"length_of_stay\": {\n", - " \"dataset\": dataset_2048_los,\n", - " \"label_col\": \"label_los_1week\",\n", 
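
sample_balanced_subset above draws an equal number of positive and negative patients and shuffles the combined ID list. A minimal sketch of the same idea on synthetic labels (SEED here is illustrative; the notebook defines its own constant elsewhere):

import random

import pandas as pd

SEED = 0  # illustrative only

toy = pd.DataFrame(
    {
        "patient_id": [f"p{i}" for i in range(10)],
        "label": [1, 0, 1, 0, 0, 1, 0, 1, 0, 0],
    }
)
pos = toy[toy["label"] == 1].sample(n=2, random_state=SEED)  # half positives
neg = toy[toy["label"] == 0].sample(n=2, random_state=SEED)  # half negatives
sample_patients = pos["patient_id"].tolist() + neg["patient_id"].tolist()
random.shuffle(sample_patients)
print(sample_patients)  # four IDs, two per class, in shuffled order
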
- " \"finetune_size\": [250, 1000, 5000, 20000, 50000],\n", - " \"save_path\": \"patient_id_dict/dataset_2048_los.pkl\",\n", - " \"split_mode\": \"single_label_balanced\",\n", - " },\n", - " \"condition\": {\n", - " \"dataset\": dataset_2048_condition,\n", - " \"label_col\": \"all_conditions\",\n", - " \"finetune_size\": [50000],\n", - " \"save_path\": \"patient_id_dict/dataset_2048_condition.pkl\",\n", - " \"split_mode\": \"multi_label_stratified\",\n", - " },\n", - " }\n", - "\n", - " all_tasks = list(task_splits.keys())" - ] - }, - { - "cell_type": "code", - "execution_count": null, - "metadata": { - "ExecuteTime": { - "end_time": "2024-03-13T16:15:51.800996200Z", - "start_time": "2024-03-13T16:15:50.494996100Z" - }, - "collapsed": false - }, - "outputs": [], - "source": [ - "def get_finetune_split(\n", - " config: config,\n", - " patient_ids_dict: Dict[str, Any],\n", - ") -> Dict[str, Dict[str, List[str]]]:\n", - " \"\"\"\n", - " Splits the dataset into training and cross-finetuneation sets using k-fold cross-finetuneation\n", - " while ensuring balanced label distribution in each fold. Saves the resulting dictionary to disk.\n", - " \"\"\"\n", - " # Extract task-specific configuration\n", - " task_config = config.task_splits[task]\n", - " dataset = task_config[\"dataset\"]\n", - " label_col = task_config[\"label_col\"]\n", - " finetune_sizes = task_config[\"finetune_size\"]\n", - " save_path = task_config[\"save_path\"]\n", - " split_mode = task_config[\"split_mode\"]\n", - "\n", - " # Get pretrain dataset\n", - " pretrain_ids = patient_ids_dict[\"pretrain\"]\n", - " dataset = dataset[dataset[\"patient_id\"].isin(pretrain_ids)]\n", - "\n", - " # Few-shot finetune patient ids\n", - " for finetune_num in finetune_sizes:\n", - " if split_mode == \"single_label_balanced\":\n", - " finetune_ids = sample_balanced_subset(\n", - " dataset,\n", - " target=label_col,\n", - " sample_size=finetune_num,\n", - " )\n", - "\n", - " elif (\n", - " split_mode == \"single_label_stratified\"\n", - " or split_mode == \"multi_label_stratified\"\n", - " ):\n", - " finetune_ids = stratified_train_test_split(\n", - " dataset,\n", - " target=label_col,\n", - " test_size=finetune_num / len(dataset),\n", - " return_test=True,\n", - " )\n", - "\n", - " patient_ids_dict[\"finetune\"][\"few_shot\"][f\"{finetune_num}\"] = finetune_ids\n", - "\n", - " # Save the dictionary to disk\n", - " save_object_to_disk(patient_ids_dict, save_path)\n", - "\n", - " return patient_ids_dict\n", - "\n", - "\n", - "for task in config.all_tasks:\n", - " patient_ids_dict = get_finetune_split(\n", - " config=config,\n", - " patient_ids_dict=patient_ids_dict,\n", - " )" - ] - }, - { - "cell_type": "code", - "execution_count": null, - "metadata": { - "ExecuteTime": { - "end_time": "2024-03-13T14:14:10.181184300Z", - "start_time": "2024-03-13T14:13:39.154567400Z" - }, - "collapsed": false - }, - "outputs": [], - "source": [ - "dataset_2048_mortality.to_parquet(\n", - " \"patient_sequences/patient_sequences_2048_mortality.parquet\",\n", - ")\n", - "dataset_2048_readmission.to_parquet(\n", - " \"patient_sequences/patient_sequences_2048_readmission.parquet\",\n", - ")\n", - "dataset_2048_los.to_parquet(\"patient_sequences/patient_sequences_2048_los.parquet\")\n", - "dataset_2048_condition.to_parquet(\n", - " \"patient_sequences/patient_sequences_2048_condition.parquet\",\n", - ")\n", - "multi_dataset.to_parquet(\"patient_sequences/patient_sequences_2048_multi.parquet\")" - ] - }, - { - "cell_type": "code", - "execution_count": null, - 
"metadata": {}, - "outputs": [], - "source": [ - "# Load data\n", - "# multi_dataset = pd.read_parquet('patient_sequences/patient_sequences_2048_multi.parquet')\n", - "# pid = pickle.load(open('patient_id_dict/dataset_2048_multi.pkl', 'rb'))\n", - "# multi_dataset = multi_dataset[multi_dataset['patient_id'].isin(pid['finetune']['few_shot']['all'])]\n", - "\n", - "# # Train Tokenizer\n", - "# tokenizer = ConceptTokenizer(data_dir='/h/afallah/odyssey/odyssey/data/vocab')\n", - "# tokenizer.fit_on_vocab()\n", - "\n", - "# # Load datasets\n", - "# tasks = ['mortality_1month', 'los_1week'] + [f'c{i}' for i in range(5)]\n", - "\n", - "# train_dataset = FinetuneMultiDataset(\n", - "# data=multi_dataset,\n", - "# tokenizer=tokenizer,\n", - "# tasks=tasks,\n", - "# balance_guide={'mortality_1month': 0.5, 'los_1week': 0.5},\n", - "# max_len=2048,\n", - "# )" - ] - }, - { - "cell_type": "code", - "execution_count": null, - "metadata": {}, - "outputs": [], - "source": [ - "# dataset_2048_condition = pd.read_parquet('patient_sequences/patient_sequences_2048_condition.parquet')\n", - "# pid = pickle.load(open('patient_id_dict/dataset_2048_condition.pkl', 'rb'))\n", - "# condition_finetune = dataset_2048_condition.loc[dataset_2048_condition['patient_id'].isin(pid['finetune']['few_shot']['50000'])]\n", - "# condition_finetune" - ] - }, - { - "cell_type": "code", - "execution_count": null, - "metadata": {}, - "outputs": [], - "source": [ - "# freq = np.array(condition_finetune['all_conditions'].tolist()).sum(axis=0)\n", - "# weights = np.clip(0, 50, sum(freq) / freq)\n", - "# np.max(np.sqrt(freq)) / np.sqrt(freq)" - ] - }, - { - "cell_type": "code", - "execution_count": null, - "metadata": {}, - "outputs": [], - "source": [ - "# sorted(patient_ids_dict['pretrain']) == sorted(pickle.load(open('new_data/patient_id_dict/sample_pretrain_test_patient_ids_with_conditions.pkl', 'rb'))['pretrain'])" - ] - }, - { - "cell_type": "code", - "execution_count": null, - "metadata": {}, - "outputs": [], - "source": [ - "# merged_df = pd.merge(dataset_2048_mortality, dataset_2048_readmission, how='outer', on='patient_id')\n", - "# final_merged_df = pd.merge(merged_df, dataset_2048_condition, how='outer', on='patient_id')\n", - "# final_merged_df" - ] - }, - { - "cell_type": "code", - "execution_count": null, - "metadata": {}, - "outputs": [], - "source": [ - "# Performing stratified k-fold split\n", - "# skf = StratifiedKFold(n_splits=num_splits, shuffle=True, random_state=SEED)\n", - "\n", - "# for i, (train_index, cv_index) in enumerate(skf.split(dataset, dataset[label_col])):\n", - "\n", - "# dataset_cv = dataset.iloc[cv_index]\n", - "# dataset_finetune = dataset.iloc[train_index]\n", - "\n", - "# # Separate positive and negative labeled patients\n", - "# pos_patients = dataset_cv[dataset_cv[label_col] == True]['patient_id'].tolist()\n", - "# neg_patients = dataset_cv[dataset_cv[label_col] == False]['patient_id'].tolist()\n", - "\n", - "# # Calculate the number of positive and negative patients needed for balanced CV set\n", - "# num_pos_needed = cv_size // 2\n", - "# num_neg_needed = cv_size // 2\n", - "\n", - "# # Select positive and negative patients for CV set ensuring balanced distribution\n", - "# cv_patients = pos_patients[:num_pos_needed] + neg_patients[:num_neg_needed]\n", - "# remaining_finetune_patients = pos_patients[num_pos_needed:] + neg_patients[num_neg_needed:]\n", - "\n", - "# # Extract patient IDs for training set\n", - "# finetune_patients = dataset_finetune['patient_id'].tolist()\n", - "# 
finetune_patients += remaining_finetune_patients\n", - "\n", - "# # Shuffle each list of patients\n", - "# random.shuffle(cv_patients)\n", - "# random.shuffle(finetune_patients)\n", - "\n", - "# patient_ids_dict['finetune']['kfold'][f'group{i+1}'] = {'finetune': finetune_patients, 'cv': cv_patients}" - ] - }, - { - "cell_type": "code", - "execution_count": null, - "metadata": { - "collapsed": false - }, - "outputs": [], - "source": [ - "# Assuming dataset.event_tokens is your DataFrame column\n", - "# dataset.event_tokens.transform(len).plot(kind='hist', bins=100)\n", - "# plt.xlim(1000, 8000) # Limit x-axis to 5000\n", - "# plt.ylim(0, 6000)\n", - "# plt.xlabel('Length of Event Tokens')\n", - "# plt.ylabel('Frequency')\n", - "# plt.title('Histogram of Event Tokens Length')\n", - "# plt.show()" - ] - }, - { - "cell_type": "code", - "execution_count": null, - "metadata": { - "collapsed": false - }, - "outputs": [], - "source": [ - "# len(patient_ids_dict['group3']['cv'])\n", - "\n", - "# dataset_2048.loc[dataset_2048['patient_id'].isin(patient_ids_dict['group1']['cv'])]['label_mortality_1month']\n", - "\n", - "# s = set()\n", - "# for i in range(1, 6):\n", - "# s = s.union(set(patient_ids_dict[f'group{i}']['cv']))\n", - "#\n", - "# len(s)" - ] - }, - { - "cell_type": "code", - "execution_count": null, - "metadata": { - "collapsed": false - }, - "outputs": [], - "source": [ - "##### DEAD ZONE | DO NOT ENTER #####\n", - "\n", - "# patient_ids = pickle.load(open(join(\"/h/afallah/odyssey/odyssey/data/bigbird_data\", 'dataset_2048_mortality_1month.pkl'), 'rb'))\n", - "# patient_ids['finetune']['few_shot'].keys()\n", - "\n", - "# patient_ids2 = pickle.load(open(join(\"/h/afallah/odyssey/odyssey/data/bigbird_data\", 'dataset_2048_mortality_2weeks.pkl'), 'rb'))['pretrain']\n", - "#\n", - "# patient_ids1.sort()\n", - "# patient_ids2.sort()\n", - "#\n", - "# patient_ids1 == patient_ids2\n", - "# # dataset_2048.loc[dataset_2048['patient_id'].isin(patient_ids['pretrain'])]" - ] - }, - { - "cell_type": "code", - "execution_count": null, - "metadata": { - "collapsed": false - }, - "outputs": [], - "source": [ - "# dataset_2048_readmission = dataset_2048.loc[dataset_2048['num_visits'] > 1]\n", - "# dataset_2048_readmission.reset_index(drop=True, inplace=True)\n", - "#\n", - "# dataset_2048_readmission['last_VS_index'] = dataset_2048_readmission['event_tokens_2048'].transform(lambda seq: get_last_occurence_index(list(seq), '[VS]'))\n", - "#\n", - "# dataset_2048_readmission['label_readmission_1month'] = dataset_2048_readmission.apply(\n", - "# lambda row: row['event_tokens_2048'][row['last_VS_index'] - 1] in ('[W_0]', '[W_1]', '[W_2]', '[W_3]', '[M_1]'), axis=1\n", - "# )\n", - "# dataset_2048_readmission['event_tokens_2048'] = dataset_2048_readmission.apply(\n", - "# lambda row: row['event_tokens_2048'][:row['last_VS_index'] - 1], axis=1\n", - "# )\n", - "# dataset_2048_readmission.drop(['deceased', 'death_after_start', 'death_after_end', 'length'], axis=1, inplace=True)\n", - "# dataset_2048_readmission['num_visits'] -= 1\n", - "# dataset_2048_readmission['token_length'] = dataset_2048_readmission['event_tokens_2048'].apply(len)\n", - "# dataset_2048_readmission = dataset_2048_readmission.apply(lambda row: truncate_and_pad(row), axis=1)\n", - "# dataset_2048_readmission['event_tokens_2048'] = dataset_2048_readmission['event_tokens_2048'].transform(\n", - "# lambda token_list: ' '.join(token_list)\n", - "# )\n", - "#\n", - "# dataset_2048_readmission" - ] - } - ], - "metadata": { - "kernelspec": { - 
"display_name": "Python 3 (ipykernel)", - "language": "python", - "name": "python3" - }, - "language_info": { - "codemirror_mode": { - "name": "ipython", - "version": 3 - }, - "file_extension": ".py", - "mimetype": "text/x-python", - "name": "python", - "nbconvert_exporter": "python", - "pygments_lexer": "ipython3", - "version": "3.10.9" - } - }, - "nbformat": 4, - "nbformat_minor": 0 -} diff --git a/odyssey/data/dataset.py b/odyssey/data/dataset.py index 59c4d44..b3663f5 100644 --- a/odyssey/data/dataset.py +++ b/odyssey/data/dataset.py @@ -4,6 +4,7 @@ from typing import Any, Dict, List, Optional, Tuple, Union import pandas as pd + import torch from torch.utils.data import Dataset diff --git a/odyssey/models/baseline/Bi-LSTM.ipynb b/odyssey/models/baseline/Bi-LSTM.ipynb index d67b155..99dea78 100644 --- a/odyssey/models/baseline/Bi-LSTM.ipynb +++ b/odyssey/models/baseline/Bi-LSTM.ipynb @@ -61,19 +61,26 @@ "import os\n", "import sys\n", "\n", - "\n", "ROOT = \"/fs01/home/afallah/odyssey/odyssey\"\n", "os.chdir(ROOT)\n", "\n", "from typing import Any, Dict, Tuple\n", + "from tqdm import tqdm\n", "\n", "import matplotlib.pyplot as plt\n", "import numpy as np\n", "import pandas as pd\n", + "\n", "import torch\n", - "from models.big_bird_cehr.data import FinetuneDataset\n", - "from models.big_bird_cehr.embeddings import Embeddings\n", - "from models.big_bird_cehr.tokenizer import ConceptTokenizer\n", + "from torch import nn, optim\n", + "from torch.nn.functional import sigmoid\n", + "from torch.nn.utils.rnn import pack_padded_sequence\n", + "from torch.optim.lr_scheduler import ExponentialLR\n", + "from torch.utils.data import DataLoader, Dataset\n", + "\n", + "from odyssey.data.dataset import FinetuneDataset\n", + "from odyssey.models.cehr_big_bird.embeddings import Embeddings\n", + "from odyssey.data.tokenizer import ConceptTokenizer\n", "from sklearn.metrics import (\n", " auc,\n", " average_precision_score,\n", @@ -85,13 +92,6 @@ " roc_auc_score,\n", " roc_curve,\n", ")\n", - "from torch import nn, optim\n", - "from torch.nn.functional import sigmoid\n", - "from torch.nn.utils.rnn import pack_padded_sequence\n", - "from torch.optim.lr_scheduler import ExponentialLR\n", - "from torch.utils.data import DataLoader, Dataset\n", - "from tqdm import tqdm\n", - "\n", "\n", "DATA_ROOT = f\"{ROOT}/data/slurm_data/512/one_month\"\n", "DATA_PATH = f\"{DATA_ROOT}/pretrain.parquet\"\n", diff --git a/odyssey/models/baseline/Bi-LSTM.py b/odyssey/models/baseline/Bi-LSTM.py index 193a2aa..a231083 100644 --- a/odyssey/models/baseline/Bi-LSTM.py +++ b/odyssey/models/baseline/Bi-LSTM.py @@ -19,7 +19,7 @@ from odyssey.data.dataset import FinetuneDataset from odyssey.models.cehr_big_bird.embeddings import Embeddings -from odyssey.models.cehr_big_bird.tokenizer import HuggingFaceConceptTokenizer +from odyssey.data.tokenizer import HuggingFaceConceptTokenizer ROOT = "/fs01/home/afallah/odyssey/odyssey" diff --git a/odyssey/models/cehr_bert/embeddings.py b/odyssey/models/cehr_bert/embeddings.py deleted file mode 100644 index ecd096b..0000000 --- a/odyssey/models/cehr_bert/embeddings.py +++ /dev/null @@ -1,175 +0,0 @@ -"""Embedding modules.""" - -import math - -import torch -from torch import nn - - -class TimeEmbeddingLayer(nn.Module): - """Embedding layer for time features.""" - - def __init__(self, embedding_size: int, is_time_delta: bool = False): - super().__init__() - self.embedding_size = embedding_size - self.is_time_delta = is_time_delta - - self.w = nn.Parameter(torch.empty(1, self.embedding_size)) - 
self.phi = nn.Parameter(torch.empty(1, self.embedding_size)) - - nn.init.xavier_uniform_(self.w) - nn.init.xavier_uniform_(self.phi) - - def forward(self, time_stamps: torch.Tensor) -> torch.Tensor: - """Apply time embedding to the input time stamps.""" - if self.is_time_delta: - # If the time_stamps represent time deltas, we calculate the deltas. - # This is equivalent to the difference between consecutive elements. - time_stamps = torch.cat( - (time_stamps[:, 0:1] * 0, time_stamps[:, 1:] - time_stamps[:, :-1]), - dim=-1, - ) - time_stamps = time_stamps.float() - time_stamps_expanded = time_stamps.unsqueeze(-1) - next_input = time_stamps_expanded * self.w + self.phi - - return torch.sin(next_input) - - -class VisitEmbedding(nn.Module): - """Embedding layer for visit segments.""" - - def __init__( - self, - visit_order_size: int, - embedding_size: int, - ): - super().__init__() - self.visit_order_size = visit_order_size - self.embedding_size = embedding_size - self.embedding = nn.Embedding(self.visit_order_size, self.embedding_size) - - def forward(self, visit_segments: torch.Tensor) -> torch.Tensor: - """Apply visit embedding to the input visit segments.""" - return self.embedding(visit_segments) - - -class ConceptEmbedding(nn.Module): - """Embedding layer for event concepts.""" - - def __init__( - self, - num_embeddings: int, - embedding_size: int, - padding_idx: int = None, - ): - super(ConceptEmbedding, self).__init__() - self.embedding = nn.Embedding( - num_embeddings, - embedding_size, - padding_idx=padding_idx, - ) - - def forward(self, inputs: torch.Tensor) -> torch.Tensor: - """Apply concept embedding to the input concepts.""" - return self.embedding(inputs) - - -class PositionalEmbedding(nn.Module): - """Positional embedding layer.""" - - def __init__(self, embedding_size, max_len=512): - super().__init__() - - # Compute the positional encodings once in log space. 
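
TimeEmbeddingLayer above implements a Time2Vec-style mapping: each (optionally delta-encoded) time stamp t becomes sin(t * w + phi) with learned w and phi. A minimal shape check with toy sizes (random tensors stand in for the learned parameters):

import torch

w = torch.randn(1, 4)    # learned (1, embedding_size) weight in the real layer
phi = torch.randn(1, 4)  # learned (1, embedding_size) phase
time_stamps = torch.tensor([[0.0, 1.5, 4.0]])          # (batch, seq_len)
emb = torch.sin(time_stamps.unsqueeze(-1) * w + phi)   # (batch, seq_len, 4)
print(emb.shape)  # torch.Size([1, 3, 4])
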
- pe = torch.zeros(max_len, embedding_size).float() - pe.require_grad = False - - position = torch.arange(0, max_len).float().unsqueeze(1) - div_term = ( - torch.arange(0, embedding_size, 2).float() - * -(math.log(10000.0) / embedding_size) - ).exp() - - pe[:, 0::2] = torch.sin(position * div_term) - pe[:, 1::2] = torch.cos(position * div_term) - - self.register_buffer("pe", pe) - - def forward(self, visit_orders: torch.Tensor) -> torch.Tensor: - """Apply positional embedding to the input visit orders.""" - first_visit_concept_orders = visit_orders[:, 0:1] - normalized_visit_orders = torch.clamp( - visit_orders - first_visit_concept_orders, - 0, - self.pe.size(0) - 1, - ) - return self.pe[normalized_visit_orders] - - -class Embeddings(nn.Module): - """Embeddings for CEHR-BERT.""" - - def __init__( - self, - vocab_size: int, - embedding_size: int = 128, - time_embedding_size: int = 16, - type_vocab_size: int = 8, - visit_order_size: int = 3, - max_len: int = 512, - layer_norm_eps: float = 1e-12, - dropout_prob: float = 0.1, - padding_idx: int = 1, - ): - super().__init__() - self.concept_embedding = ConceptEmbedding( - num_embeddings=vocab_size, - embedding_size=embedding_size, - padding_idx=padding_idx, - ) - self.token_type_embeddings = nn.Embedding(type_vocab_size, embedding_size) - self.time_embedding = TimeEmbeddingLayer( - embedding_size=time_embedding_size, - is_time_delta=True, - ) - self.age_embedding = TimeEmbeddingLayer(embedding_size=time_embedding_size) - self.positional_embedding = PositionalEmbedding( - embedding_size=embedding_size, - max_len=max_len, - ) - self.visit_embedding = VisitEmbedding( - visit_order_size=visit_order_size, - embedding_size=embedding_size, - ) - self.scale_back_concat_layer = nn.Linear( - embedding_size + 2 * time_embedding_size, - embedding_size, - ) # Assuming 4 input features are concatenated - self.tanh = nn.Tanh() - self.LayerNorm = nn.LayerNorm(embedding_size, eps=layer_norm_eps) - self.dropout = nn.Dropout(dropout_prob) - - def forward( - self, - concept_ids: torch.Tensor, - type_ids: torch.Tensor, - time_stamps: torch.Tensor, - ages: torch.Tensor, - visit_orders: torch.Tensor, - visit_segments: torch.Tensor, - ) -> torch.Tensor: - """Apply embeddings to the input features.""" - concept_embed = self.concept_embedding(concept_ids) - type_embed = self.token_type_embeddings(type_ids) - time_embed = self.time_embedding(time_stamps) - age_embed = self.age_embedding(ages) - positional_embed = self.positional_embedding(visit_orders) - visit_segment_embed = self.visit_embedding(visit_segments) - - embeddings = torch.cat((concept_embed, time_embed, age_embed), dim=-1) - embeddings = self.tanh(self.scale_back_concat_layer(embeddings)) - embeddings = embeddings + type_embed + positional_embed + visit_segment_embed - embeddings = self.LayerNorm(embeddings) - - return self.dropout(embeddings) diff --git a/odyssey/models/cehr_bert/model.py b/odyssey/models/cehr_bert/model.py index 9af880d..9e36f59 100644 --- a/odyssey/models/cehr_bert/model.py +++ b/odyssey/models/cehr_bert/model.py @@ -22,7 +22,7 @@ BertPooler, ) -from odyssey.models.cehr_bert.embeddings import Embeddings +from odyssey.models.embeddings import BERTEmbeddingsForCEHR class BertPretrain(pl.LightningModule): @@ -75,7 +75,7 @@ def __init__( ) # BertForMaskedLM ## BertModel - self.embeddings = Embeddings( + self.embeddings = BERTEmbeddingsForCEHR( vocab_size=self.vocab_size, embedding_size=self.embedding_size, time_embedding_size=self.time_embeddings_size, diff --git 
a/odyssey/models/cehr_big_bird/embeddings.py b/odyssey/models/cehr_big_bird/embeddings.py deleted file mode 100644 index da43b44..0000000 --- a/odyssey/models/cehr_big_bird/embeddings.py +++ /dev/null @@ -1,363 +0,0 @@ -"""Embedding layers for the models.""" - -import math -from typing import Any, Optional - -import torch -from torch import nn -from transformers import BigBirdConfig - - -class TimeEmbeddingLayer(nn.Module): - """Embedding layer for time features.""" - - def __init__(self, embedding_size: int, is_time_delta: bool = False): - super().__init__() - self.embedding_size = embedding_size - self.is_time_delta = is_time_delta - - self.w = nn.Parameter(torch.empty(1, self.embedding_size)) - self.phi = nn.Parameter(torch.empty(1, self.embedding_size)) - - nn.init.xavier_uniform_(self.w) - nn.init.xavier_uniform_(self.phi) - - def forward(self, time_stamps: torch.Tensor) -> Any: - """Apply time embedding to the input time stamps.""" - if self.is_time_delta: - # If the time_stamps represent time deltas, we calculate the deltas. - # This is equivalent to the difference between consecutive elements. - time_stamps = torch.cat( - (time_stamps[:, 0:1] * 0, time_stamps[:, 1:] - time_stamps[:, :-1]), - dim=-1, - ) - time_stamps = time_stamps.float() - time_stamps_expanded = time_stamps.unsqueeze(-1) - next_input = time_stamps_expanded * self.w + self.phi - - return torch.sin(next_input) - - -class VisitEmbedding(nn.Module): - """Embedding layer for visit segments.""" - - def __init__( - self, - visit_order_size: int, - embedding_size: int, - ): - super().__init__() - self.visit_order_size = visit_order_size - self.embedding_size = embedding_size - self.embedding = nn.Embedding(self.visit_order_size, self.embedding_size) - - def forward(self, visit_segments: torch.Tensor) -> Any: - """Apply visit embedding to the input visit segments.""" - return self.embedding(visit_segments) - - -class ConceptEmbedding(nn.Module): - """Embedding layer for event concepts.""" - - def __init__( - self, - num_embeddings: int, - embedding_size: int, - padding_idx: Optional[int] = None, - ): - super(ConceptEmbedding, self).__init__() - self.embedding = nn.Embedding( - num_embeddings, - embedding_size, - padding_idx=padding_idx, - ) - - def forward(self, inputs: torch.Tensor) -> Any: - """Apply concept embedding to the input concepts.""" - return self.embedding(inputs) - - -class PositionalEmbedding(nn.Module): - """Positional embedding layer.""" - - def __init__(self, embedding_size: int, max_len: int = 2048): - super().__init__() - - # Compute the positional encodings once in log space. 
- pe = torch.zeros(max_len, embedding_size).float() - pe.require_grad = False - - position = torch.arange(0, max_len).float().unsqueeze(1) - div_term = ( - torch.arange(0, embedding_size, 2).float() - * -(math.log(10000.0) / embedding_size) - ).exp() - - pe[:, 0::2] = torch.sin(position * div_term) - pe[:, 1::2] = torch.cos(position * div_term) - - self.register_buffer("pe", pe) - - def forward(self, visit_orders: torch.Tensor) -> Any: - """Apply positional embedding to the input visit orders.""" - first_visit_concept_orders = visit_orders[:, 0:1] - normalized_visit_orders = torch.clamp( - visit_orders - first_visit_concept_orders, - 0, - self.pe.size(0) - 1, - ) - return self.pe[normalized_visit_orders] - - -class Embeddings(nn.Module): - """Embeddings for CEHR-BERT.""" - - def __init__( - self, - vocab_size: int, - embedding_size: int = 128, - time_embeddings_size: int = 16, - type_vocab_size: int = 8, - visit_order_size: int = 3, - max_len: int = 2048, - layer_norm_eps: float = 1e-12, - dropout_prob: float = 0.1, - padding_idx: int = 1, - ): - super().__init__() - self.concept_embedding = ConceptEmbedding( - num_embeddings=vocab_size, - embedding_size=embedding_size, - padding_idx=padding_idx, - ) - self.token_type_embeddings = nn.Embedding( - type_vocab_size, - embedding_size, - ) - self.time_embedding = TimeEmbeddingLayer( - embedding_size=time_embeddings_size, - is_time_delta=True, - ) - self.age_embedding = TimeEmbeddingLayer( - embedding_size=time_embeddings_size, - ) - self.positional_embedding = PositionalEmbedding( - embedding_size=embedding_size, - max_len=max_len, - ) - self.visit_embedding = VisitEmbedding( - visit_order_size=visit_order_size, - embedding_size=embedding_size, - ) - self.scale_back_concat_layer = nn.Linear( - embedding_size + 2 * time_embeddings_size, - embedding_size, - ) # Assuming 4 input features are concatenated - self.tanh = nn.Tanh() - self.LayerNorm = nn.LayerNorm(embedding_size, eps=layer_norm_eps) - self.dropout = nn.Dropout(dropout_prob) - - def forward( - self, - concept_ids: torch.Tensor, - type_ids: torch.Tensor, - time_stamps: torch.Tensor, - ages: torch.Tensor, - visit_orders: torch.Tensor, - visit_segments: torch.Tensor, - ) -> Any: - """Apply embeddings to the input features.""" - concept_embed = self.concept_embedding(concept_ids) - type_embed = self.token_type_embeddings(type_ids) - time_embed = self.time_embedding(time_stamps) - age_embed = self.age_embedding(ages) - positional_embed = self.positional_embedding(visit_orders) - visit_segment_embed = self.visit_embedding(visit_segments) - - embeddings = torch.cat((concept_embed, time_embed, age_embed), dim=-1) - embeddings = self.tanh(self.scale_back_concat_layer(embeddings)) - embeddings = embeddings + type_embed + positional_embed + visit_segment_embed - embeddings = self.LayerNorm(embeddings) - return self.dropout(embeddings) - - -class BigBirdEmbeddingsForCEHR(nn.Module): - """Construct the embeddings from word, position and token_type embeddings.""" - - # Copied from transformers.models.bert.modeling_bert.BertEmbeddings.__init__ - def __init__( - self, - config: BigBirdConfig, - time_embeddings_size: int = 16, - visit_order_size: int = 3, - ) -> None: - """Initiate wrapper class for embeddings used in BigBird CEHR classes.""" - super().__init__() - - self.word_embeddings = nn.Embedding( - config.vocab_size, - config.hidden_size, - padding_idx=config.pad_token_id, - ) - self.position_embeddings = nn.Embedding( - config.max_position_embeddings, - config.hidden_size, - ) - 
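
The Embeddings module above fuses concept, time, and age embeddings by concatenation, a linear projection back to the hidden size, and tanh, then adds the token-type, positional, and visit-segment embeddings before LayerNorm. A dimensional sketch with toy sizes (shapes only; random tensors stand in for the real embeddings):

import torch
from torch import nn

hidden, t_dim, batch, seq = 8, 2, 1, 5
concept = torch.randn(batch, seq, hidden)
time_e = torch.randn(batch, seq, t_dim)
age_e = torch.randn(batch, seq, t_dim)
type_e = torch.randn(batch, seq, hidden)
pos_e = torch.randn(batch, seq, hidden)
seg_e = torch.randn(batch, seq, hidden)

proj = nn.Linear(hidden + 2 * t_dim, hidden)  # analogue of scale_back_concat_layer
fused = torch.tanh(proj(torch.cat((concept, time_e, age_e), dim=-1)))
out = nn.LayerNorm(hidden)(fused + type_e + pos_e + seg_e)
print(out.shape)  # torch.Size([1, 5, 8])
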
self.token_type_embeddings = nn.Embedding( - config.type_vocab_size, - config.hidden_size, - ) - self.visit_order_embeddings = nn.Embedding( - config.max_position_embeddings, - config.hidden_size, - ) - self.time_embeddings = TimeEmbeddingLayer( - embedding_size=time_embeddings_size, - is_time_delta=True, - ) - self.age_embeddings = TimeEmbeddingLayer( - embedding_size=time_embeddings_size, - ) - self.visit_segment_embeddings = VisitEmbedding( - visit_order_size=visit_order_size, - embedding_size=config.hidden_size, - ) - self.scale_back_concat_layer = nn.Linear( - config.hidden_size + 2 * time_embeddings_size, - config.hidden_size, - ) - - self.time_stamps: Optional[torch.Tensor] = None - self.ages: Optional[torch.Tensor] = None - self.visit_orders: Optional[torch.Tensor] = None - self.visit_segments: Optional[torch.Tensor] = None - - # self.LayerNorm is not snake-cased to stick with TensorFlow model - # variable name and be able to load any TensorFlow checkpoint file. - self.tanh = nn.Tanh() - self.LayerNorm = nn.LayerNorm(config.hidden_size, eps=config.layer_norm_eps) - self.dropout = nn.Dropout(config.hidden_dropout_prob) - - # position_ids (1, len position emb) is contiguous in memory. - self.position_embedding_type = getattr( - config, - "position_embedding_type", - "absolute", - ) - self.register_buffer( - "position_ids", - torch.arange(config.max_position_embeddings).expand((1, -1)), - persistent=False, - ) - self.register_buffer( - "token_type_ids", - torch.zeros(self.position_ids.size(), dtype=torch.long), - persistent=False, - ) - # End copy - - self.rescale_embeddings = config.rescale_embeddings - self.hidden_size = config.hidden_size - - def cache_input( - self, - time_stamps: torch.Tensor, - ages: torch.Tensor, - visit_orders: torch.Tensor, - visit_segments: torch.Tensor, - ) -> None: - """Cache values for time_stamps, ages, visit_orders & visit_segments. - - These values will be used by the forward pass to change the final embedding. - - Parameters - ---------- - time_stamps : torch.Tensor - Time stamps of the input data. - ages : torch.Tensor - Ages of the input data. - visit_orders : torch.Tensor - Visit orders of the input data. - visit_segments : torch.Tensor - Visit segments of the input data. 
- - """ - self.time_stamps = time_stamps - self.ages = ages - self.visit_orders = visit_orders - self.visit_segments = visit_segments - - def clear_cache(self) -> None: - """Delete the tensors cached by cache_input method.""" - del self.time_stamps, self.ages, self.visit_orders, self.visit_segments - - def forward( - self, - input_ids: Optional[torch.Tensor] = None, - token_type_ids: Optional[torch.Tensor] = None, - position_ids: Optional[torch.Tensor] = None, - inputs_embeds: Optional[torch.Tensor] = None, - past_key_values_length: int = 0, - ) -> Any: - """Return the final embeddings of concept ids using input and cached values.""" - if input_ids is not None: - input_shape = input_ids.size() - else: - input_shape = inputs_embeds.size()[:-1] - - seq_length = input_shape[1] - - if position_ids is None: - position_ids = self.position_ids[ - :, - past_key_values_length : seq_length + past_key_values_length, - ] - - # Setting the token_type_ids to the registered buffer in constructor - if token_type_ids is None: - if hasattr(self, "token_type_ids"): - buffered_token_type_ids = self.token_type_ids[:, :seq_length] - buffered_token_type_ids_expanded = buffered_token_type_ids.expand( - input_shape[0], - seq_length, - ) - token_type_ids = buffered_token_type_ids_expanded - else: - token_type_ids = torch.zeros( - input_shape, - dtype=torch.long, - device=self.position_ids.device, - ) - - if inputs_embeds is None: - inputs_embeds = self.word_embeddings(input_ids) - - if self.rescale_embeddings: - inputs_embeds = inputs_embeds * (self.hidden_size**0.5) - - # Using cached values from a prior cache_input call - time_stamps_embeds = self.time_embeddings(self.time_stamps) - ages_embeds = self.age_embeddings(self.ages) - visit_segments_embeds = self.visit_segment_embeddings(self.visit_segments) - visit_order_embeds = self.visit_order_embeddings(self.visit_orders) - - position_embeds = self.position_embeddings(position_ids) - token_type_embeds = self.token_type_embeddings(token_type_ids) - - inputs_embeds = torch.cat( - (inputs_embeds, time_stamps_embeds, ages_embeds), - dim=-1, - ) - inputs_embeds = self.tanh(self.scale_back_concat_layer(inputs_embeds)) - embeddings = inputs_embeds + token_type_embeds - embeddings += position_embeds - embeddings += visit_order_embeds - embeddings += visit_segments_embeds - - embeddings = self.dropout(embeddings) - embeddings = self.LayerNorm(embeddings) - - # Clear the cache for next forward call - self.clear_cache() - - return embeddings diff --git a/odyssey/models/cehr_big_bird/model.py b/odyssey/models/cehr_big_bird/model.py index 5acc07e..78ebd9c 100644 --- a/odyssey/models/cehr_big_bird/model.py +++ b/odyssey/models/cehr_big_bird/model.py @@ -1,10 +1,9 @@ -"""Big Bird transformer model.""" +"""BigBird transformer model.""" from typing import Any, Dict, Optional, Tuple, Union import numpy as np import pytorch_lightning as pl -import torch from sklearn.metrics import ( accuracy_score, f1_score, @@ -12,6 +11,8 @@ recall_score, roc_auc_score, ) + +import torch from torch import nn, optim from torch.cuda.amp import autocast from torch.optim import AdamW @@ -23,7 +24,7 @@ BigBirdForSequenceClassification, ) -from odyssey.models.cehr_big_bird.embeddings import BigBirdEmbeddingsForCEHR +from odyssey.models.embeddings import BigBirdEmbeddingsForCEHR class BigBirdPretrain(pl.LightningModule): diff --git a/odyssey/models/prediction.py b/odyssey/models/prediction.py deleted file mode 100644 index 7b8f92c..0000000 --- a/odyssey/models/prediction.py +++ /dev/null @@ -1,113 +0,0 @@ 
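
BigBirdEmbeddingsForCEHR above cannot change the Hugging Face forward signature, so the extra CEHR features are cached on the embedding module via cache_input, consumed inside forward, and deleted afterwards. A minimal self-contained mock of that pattern (not the real classes; toy dimensions):

import torch
from torch import nn

class CachingEmbedding(nn.Module):
    def __init__(self, vocab_size=10, hidden=4):
        super().__init__()
        self.word = nn.Embedding(vocab_size, hidden)
        self.ages = None  # auxiliary feature slot, filled by cache_input

    def cache_input(self, ages):
        self.ages = ages  # stash the extra feature before the standard forward

    def forward(self, input_ids):
        out = self.word(input_ids) + self.ages.unsqueeze(-1)  # consume cached value
        del self.ages  # clear the cache so stale data is never reused
        return out

emb = CachingEmbedding()
emb.cache_input(torch.tensor([[60.0, 61.0]]))
print(emb(torch.tensor([[1, 2]])).shape)  # torch.Size([1, 2, 4])
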
-"""Prediction module for loading and running BigBird models on patient data.""" - -from typing import Any, Dict, Optional - -import torch - -from odyssey.models.cehr_big_bird.model import BigBirdFinetune, BigBirdPretrain -from odyssey.tokenizer import ConceptTokenizer - - -def load_finetuned_model( - model_path: str, - tokenizer: ConceptTokenizer, - pre_model_config: Optional[Dict[str, Any]] = None, - fine_model_config: Optional[Dict[str, Any]] = None, - device: Optional[torch.device] = None, -) -> torch.nn.Module: - """Load a finetuned model from model_path using tokenizer information. - - Return a loaded finetuned model from model_path, using tokenizer information. - If config arguments are not provided, the default configs built into the - PyTorch classes are used. - - Parameters - ---------- - model_path: str - Path to the finetuned model to load - tokenizer: ConceptTokenizer - Loaded tokenizer object - pre_model_config: Dict[str, Any], optional - Optional config to override default values of a pretrained model - fine_model_config: Dict[str, Any], optional - Optional config to override default values of a finetuned model - device: torch.device, optional - CUDA device. By default, GPU is used - - Returns - ------- - torch.nn.Module - Finetuned model loaded from model_path - - """ - # Load GPU or CPU device - if not device: - device = torch.device("cuda" if torch.cuda.is_available() else "cpu") - - # Create the skeleton of a pretrained and finetuned model - pretrained_model = BigBirdPretrain( - vocab_size=tokenizer.get_vocab_size(), - padding_idx=tokenizer.get_pad_token_id(), - **(pre_model_config or {}), - ) - - model = BigBirdFinetune( - pretrained_model=pretrained_model, - **(fine_model_config or {}), - ) - - # Load the weights using model_path directory - state_dict = torch.load(model_path, map_location=device)["state_dict"] - model.load_state_dict(state_dict) - model.to(device) - model.eval() - - return model - - -def predict_patient_outcomes( - patient: Dict[str, torch.Tensor], - model: torch.nn.Module, - device: Optional[torch.device] = None, -) -> Any: - """Compute model output predictions on given patient data. - - Parameters - ---------- - patient: Dict[str, torch.Tensor] - Patient data as a dictionary of tensors - model: torch.nn.Module - Model to use for prediction - device: torch.device, optional - CUDA device. 
By default, GPU is used - - Returns - ------- - Any - Model output predictions on the given patient data - - """ - # Load GPU or CPU device - if not device: - device = torch.device("cuda" if torch.cuda.is_available() else "cpu") - - # Load patient information as a Tuple - patient_inputs = ( - patient["concept_ids"].to(device), - patient["type_ids"].to(device), - patient["time_stamps"].to(device), - patient["ages"].to(device), - patient["visit_orders"].to(device), - patient["visit_segments"].to(device), - ) - patient_labels = patient["labels"].to(device) - patient_attention_mask = patient["attention_mask"].to(device) - - # Get model output predictions - model.to(device) - - return model( - inputs=patient_inputs, - attention_mask=patient_attention_mask, - labels=patient_labels, - ) diff --git a/odyssey/models/utils.py b/odyssey/models/utils.py deleted file mode 100644 index 72bd126..0000000 --- a/odyssey/models/utils.py +++ /dev/null @@ -1,197 +0,0 @@ -"""Utility functions for the model.""" - -import os -import pickle -import random -import uuid -from os.path import join -from typing import Any - -import numpy as np -import pandas as pd -import pytorch_lightning as pl -import torch -import yaml - - -def load_config(config_dir: str, model_type: str) -> Any: - """Load the model configuration. - - Parameters - ---------- - config_dir: str - Directory containing the model configuration files - - model_type: str - Model type to load configuration for - - Returns - ------- - Any - Model configuration - - """ - config_file = join(config_dir, f"{model_type}.yaml") - with open(config_file, "r") as file: - return yaml.safe_load(file) - - -def seed_everything(seed: int) -> None: - """Seed all components of the model. - - Parameters - ---------- - seed: int - Seed value to use - - """ - random.seed(seed) - torch.manual_seed(seed) - np.random.seed(seed) - torch.cuda.manual_seed_all(seed) - torch.backends.cudnn.deterministic = True - torch.backends.cudnn.benchmark = False - pl.seed_everything(seed) - - -def load_pretrain_data( - data_dir: str, - sequence_file: str, - id_file: str, -) -> pd.DataFrame: - """Load the pretraining data. - - Parameters - ---------- - data_dir: str - Directory containing the data files - sequence_file: str - Sequence file name - id_file: str - ID file name - - Returns - ------- - pd.DataFrame - Pretraining data - - """ - sequence_path = join(data_dir, sequence_file) - id_path = join(data_dir, id_file) - - if not os.path.exists(sequence_path): - raise FileNotFoundError(f"Sequence file not found: {sequence_path}") - - if not os.path.exists(id_path): - raise FileNotFoundError(f"ID file not found: {id_path}") - - data = pd.read_parquet(sequence_path) - with open(id_path, "rb") as file: - patient_ids = pickle.load(file) - - return data.loc[data["patient_id"].isin(patient_ids["pretrain"])] - - -def load_finetune_data( - data_dir: str, - sequence_file: str, - id_file: str, - valid_scheme: str, - num_finetune_patients: str, -) -> pd.DataFrame: - """Load the finetuning data. 
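
load_config above resolves <config_dir>/<model_type>.yaml and parses it with yaml.safe_load. A toy round trip of that convention (hypothetical keys, not the project's actual config schema):

import yaml

example_yaml = """
batch_size: 32
learning_rate: 5.0e-5
max_len: 2048
"""
config = yaml.safe_load(example_yaml)
print(config["learning_rate"])  # 5e-05
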
- - Parameters - ---------- - data_dir: str - Directory containing the data files - sequence_file: str - Sequence file name - id_file: str - ID file name - valid_scheme: str - Validation scheme - num_finetune_patients: str - Number of finetune patients - - Returns - ------- - pd.DataFrame - Finetuning data - - """ - sequence_path = join(data_dir, sequence_file) - id_path = join(data_dir, id_file) - - if not os.path.exists(sequence_path): - raise FileNotFoundError(f"Sequence file not found: {sequence_path}") - - if not os.path.exists(id_path): - raise FileNotFoundError(f"ID file not found: {id_path}") - - data = pd.read_parquet(sequence_path) - with open(id_path, "rb") as file: - patient_ids = pickle.load(file) - - fine_tune = data.loc[ - data["patient_id"].isin( - patient_ids["finetune"][valid_scheme][num_finetune_patients], - ) - ] - fine_test = data.loc[data["patient_id"].isin(patient_ids["test"])] - return fine_tune, fine_test - - -def get_run_id( - checkpoint_dir: str, - retrieve: bool = False, - run_id_file: str = "wandb_run_id.txt", - length: int = 8, -) -> str: - """Fetch the run ID for the current run. - - If the run ID file exists, retrieve the run ID from the file. - Otherwise, generate a new run ID and save it to the file. - - Parameters - ---------- - checkpoint_dir: str - Directory to store the run ID file - retrieve: bool, optional - Retrieve the run ID from the file, by default False - run_id_file: str, optional - Run ID file name, by default "wandb_run_id.txt" - length: int, optional - String length of the run ID, by default 8 - - Returns - ------- - str - Run ID for the current run - - """ - run_id_path = os.path.join(checkpoint_dir, run_id_file) - if retrieve and os.path.exists(run_id_path): - with open(run_id_path, "r") as file: - run_id = file.read().strip() - else: - run_id = str(uuid.uuid4())[:length] - with open(run_id_path, "w") as file: - file.write(run_id) - return run_id - - -def save_object_to_disk(obj: Any, save_path: str) -> None: - """Save an object to disk using pickle. - - Parameters - ---------- - obj: Any - Object to save - save_path: str - Path to save the object - - """ - with open(save_path, "wb") as f: - pickle.dump(obj, f) - print(f"File saved to disk: {save_path}") diff --git a/pretrain.py b/pretrain.py index 4627b34..143f10c 100644 --- a/pretrain.py +++ b/pretrain.py @@ -13,15 +13,15 @@ from sklearn.model_selection import train_test_split from torch.utils.data import DataLoader +from odyssey.utils.utils import seed_everything from odyssey.data.dataset import PretrainDataset from odyssey.data.tokenizer import ConceptTokenizer from odyssey.models.cehr_bert.model import BertPretrain from odyssey.models.cehr_big_bird.model import BigBirdPretrain -from odyssey.models.utils import ( +from odyssey.models.model_utils import ( get_run_id, load_config, load_pretrain_data, - seed_everything, )
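
get_run_id above persists a short identifier so an interrupted Weights & Biases run can be resumed with the same ID. A minimal sketch of that idea (the temporary directory is illustrative):

import os
import uuid

checkpoint_dir = "/tmp/ckpt_demo"          # illustrative path
os.makedirs(checkpoint_dir, exist_ok=True)
run_id_path = os.path.join(checkpoint_dir, "wandb_run_id.txt")

if os.path.exists(run_id_path):            # resume: reuse the stored ID
    with open(run_id_path, "r") as f:
        run_id = f.read().strip()
else:                                      # first run: generate and persist it
    run_id = str(uuid.uuid4())[:8]
    with open(run_id_path, "w") as f:
        f.write(run_id)
print(run_id)
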