Skip to content

Commit

Permalink
Merge latest from novelty
Browse files Browse the repository at this point in the history
  • Loading branch information
amrit110 committed Apr 16, 2024
1 parent 772a05c commit 9686a38
Show file tree
Hide file tree
Showing 17 changed files with 2,679 additions and 1,083 deletions.
2 changes: 1 addition & 1 deletion .pre-commit-config.yaml
Original file line number Diff line number Diff line change
Expand Up @@ -16,7 +16,7 @@ repos:
- id: check-toml

- repo: https://github.com/charliermarsh/ruff-pre-commit
rev: 'v0.3.1'
rev: 'v0.3.7'
hooks:
- id: ruff
args: [--fix, --exit-non-zero-on-fix]
Expand Down
45 changes: 10 additions & 35 deletions evaluation/AttentionVisualization.ipynb
Original file line number Diff line number Diff line change
Expand Up @@ -45,35 +45,15 @@
"outputs": [],
"source": [
"import os\n",
"import sys\n",
"from typing import Any, Dict\n",
"\n",
"import numpy as np\n",
"import matplotlib.pyplot as plt\n",
"from matplotlib.patches import Rectangle\n",
"import plotly.figure_factory as ff\n",
"import plotly.graph_objects as go\n",
"import seaborn as sns\n",
"\n",
"import pytorch_lightning as pl\n",
"import torch\n",
"from torch.utils.data import Subset\n",
"from lightning.pytorch.loggers import WandbLogger\n",
"from pytorch_lightning.callbacks import (\n",
" EarlyStopping,\n",
" LearningRateMonitor,\n",
" ModelCheckpoint,\n",
")\n",
"from pytorch_lightning.strategies.ddp import DDPStrategy\n",
"from pytorch_lightning.strategies import DeepSpeedStrategy\n",
"from sklearn.model_selection import train_test_split\n",
"from torch.utils.data import DataLoader\n",
"from torch.utils.data import Subset\n",
"\n",
"from bertviz.transformers_neuron_view import BertModel, BertTokenizer\n",
"from bertviz import head_view, model_view\n",
"from bertviz.neuron_view import show\n",
"from transformers import AutoTokenizer, AutoModel, utils\n",
"from bertviz import model_view, head_view\n",
"from torch.utils.data import DataLoader, Subset\n",
"from transformers import utils\n",
"\n",
"\n",
"utils.logging.set_verbosity_error() # Suppress standard warnings\n",
"\n",
Expand All @@ -82,16 +62,11 @@
"os.chdir(ROOT)\n",
"\n",
"from lib.data import FinetuneDataset\n",
"from lib.prediction import load_finetuned_model, predict_patient_outcomes\n",
"from lib.tokenizer import ConceptTokenizer\n",
"from lib.utils import (\n",
" get_run_id,\n",
" load_config,\n",
" load_finetune_data,\n",
" seed_everything,\n",
")\n",
"from lib.prediction import load_finetuned_model, predict_patient_outcomes\n",
"from models.big_bird_cehr.model import BigBirdFinetune, BigBirdPretrain\n",
"from models.cehr_bert.model import BertFinetune, BertPretrain"
")"
]
},
{
Expand Down Expand Up @@ -996,7 +971,7 @@
"\n",
"for i in range(len(attention_matrix)):\n",
" truncated_attention_matrix.append(\n",
" attention_matrix[i][:, :, :truncate_at, :truncate_at]\n",
" attention_matrix[i][:, :, :truncate_at, :truncate_at],\n",
" )\n",
"\n",
"truncated_attention_matrix = tuple(truncated_attention_matrix)\n",
Expand Down Expand Up @@ -3218,7 +3193,7 @@
" textangle=-90,\n",
" bgcolor=\"red\",\n",
" opacity=0.8,\n",
" )\n",
" ),\n",
" )\n",
"\n",
" # Plot the attention matrix as a heatmap\n",
Expand All @@ -3231,7 +3206,7 @@
" hoverinfo=\"text\",\n",
" text=hover_text,\n",
" colorscale=\"YlGnBu\",\n",
" )\n",
" ),\n",
" )\n",
"\n",
" fig.update_layout(\n",
Expand All @@ -3256,7 +3231,7 @@
" print(\n",
" f\"Token {tokenizer.id_to_token(concept_ids[token1])} \"\n",
" f\"with Token {tokenizer.id_to_token(concept_ids[token2])}: \"\n",
" f\"Attention Value {attention_value:.3f}\"\n",
" f\"Attention Value {attention_value:.3f}\",\n",
" )\n",
"\n",
" fig.show()\n",
Expand Down
Loading

0 comments on commit 9686a38

Please sign in to comment.