Skip to content

Commit

Permalink
Add inference notebooks for biomasters & chesapeake
Browse files Browse the repository at this point in the history
  • Loading branch information
srmsoumya committed Jul 1, 2024
1 parent 4defc00 commit b595481
Show file tree
Hide file tree
Showing 2 changed files with 626 additions and 0 deletions.
308 changes: 308 additions & 0 deletions finetune/regression/biomasters_inference.ipynb
Original file line number Diff line number Diff line change
@@ -0,0 +1,308 @@
{
"cells": [
{
"cell_type": "code",
"execution_count": null,
"id": "4e9c5b75-623e-447a-a62f-1bec5c2da0e7",
"metadata": {},
"outputs": [],
"source": [
"import sys\n",
"import warnings\n",
"\n",
"sys.path.append(\"../../\")\n",
"warnings.filterwarnings(\"ignore\")"
]
},
{
"cell_type": "code",
"execution_count": null,
"id": "a6fc54a3-7123-4a29-ada8-0344665fd9d0",
"metadata": {},
"outputs": [],
"source": [
"import torch\n",
"import torch.nn.functional as F\n",
"import matplotlib.pyplot as plt\n",
"import numpy as np\n",
"from einops import rearrange\n",
"from finetune.regression.biomasters_datamodule import BioMastersDataModule\n",
"from finetune.regression.biomasters_model import BioMastersClassifier"
]
},
{
"cell_type": "markdown",
"id": "cd1be1f5-2a55-47b3-8d55-5a87683eb4ba",
"metadata": {},
"source": [
"### Define paths and parameters"
]
},
{
"cell_type": "code",
"execution_count": null,
"id": "a6c65362-dc3d-44c3-a992-fe15ad26d519",
"metadata": {},
"outputs": [],
"source": [
"BIOMASTERS_CHECKPOINT_PATH = \"../../checkpoints/regression/biomasters_epoch-33_val-score-36.606.ckpt\"\n",
"CLAY_CHECKPOINT_PATH = \"../../checkpoints/clay-v1-base.ckpt\"\n",
"METADATA_PATH = \"../../configs/metadata.yaml\"\n",
"\n",
"TRAIN_CHIP_DIR = \"../../data/biomasters/train_cube/\"\n",
"TRAIN_LABEL_DIR = \"../../data/biomasters/train_agbm/\"\n",
"VAL_CHIP_DIR = \"../../data/biomasters/test_cube/\"\n",
"VAL_LABEL_DIR = \"../../data/biomasters/test_agbm/\"\n",
"\n",
"BATCH_SIZE = 32\n",
"NUM_WORKERS = 1"
]
},
{
"cell_type": "markdown",
"id": "18665299-f505-4ae4-8c12-09d6dbce9d9c",
"metadata": {},
"source": [
"### Model Loading"
]
},
{
"cell_type": "code",
"execution_count": null,
"id": "d39eda04-771d-4a51-914a-0aa9eb5f54a8",
"metadata": {},
"outputs": [],
"source": [
"def get_model(biomasters_checkpoint_path, clay_checkpoint_path, metadata_path):\n",
" model = BioMastersClassifier.load_from_checkpoint(\n",
" checkpoint_path=biomasters_checkpoint_path,\n",
" metadata_path=metadata_path,\n",
" ckpt_path=clay_checkpoint_path,\n",
" )\n",
" model.eval()\n",
" return model"
]
},
{
"cell_type": "markdown",
"id": "4cfb9f4f-1765-480d-95dc-1def32459f95",
"metadata": {},
"source": [
"### Data Preparation"
]
},
{
"cell_type": "code",
"execution_count": null,
"id": "a5738b71-ac76-4b3b-ac8f-92fdcfa49d20",
"metadata": {},
"outputs": [],
"source": [
"def get_data(train_chip_dir, train_label_dir, val_chip_dir, val_label_dir, metadata_path, batch_size, num_workers):\n",
" dm = BioMastersDataModule(\n",
" train_chip_dir=train_chip_dir,\n",
" train_label_dir=train_label_dir,\n",
" val_chip_dir=val_chip_dir,\n",
" val_label_dir=val_label_dir,\n",
" metadata_path=metadata_path,\n",
" batch_size=batch_size,\n",
" num_workers=num_workers\n",
" )\n",
" dm.setup(stage=\"fit\")\n",
" val_dl = iter(dm.val_dataloader())\n",
" batch = next(val_dl)\n",
" metadata = dm.metadata\n",
" return batch, metadata"
]
},
{
"cell_type": "markdown",
"id": "adb781e5-077d-401b-9d99-e63cfc92ea1b",
"metadata": {},
"source": [
"### Prediction"
]
},
{
"cell_type": "code",
"execution_count": null,
"id": "daff7782-c808-4639-89fe-4d18e28f9ec8",
"metadata": {},
"outputs": [],
"source": [
"def run_prediction(model, batch):\n",
" with torch.no_grad():\n",
" outputs = model(batch)\n",
" outputs = F.interpolate(outputs, size=(256, 256), mode='bilinear', align_corners=False)\n",
" return outputs"
]
},
{
"cell_type": "markdown",
"id": "a7bf3e2c-9de1-487a-bc2e-438bf06de482",
"metadata": {},
"source": [
"### Post-Processing"
]
},
{
"cell_type": "code",
"execution_count": null,
"id": "77a30bde-c300-4d5d-a588-3d4dcd40e2e2",
"metadata": {},
"outputs": [],
"source": [
"def denormalize_images(normalized_images, means, stds):\n",
" means = np.array(means).reshape(1, -1, 1, 1)\n",
" stds = np.array(stds).reshape(1, -1, 1, 1)\n",
" denormalized_images = normalized_images * stds + means\n",
" return denormalized_images\n",
"\n",
"def post_process(batch, outputs, metadata):\n",
" labels = batch[\"label\"].detach().cpu().numpy()\n",
" pixels = batch[\"pixels\"].detach().cpu().numpy()\n",
" outputs = outputs.detach().cpu().numpy()\n",
"\n",
" means = list(metadata[\"sentinel-2-l2a\"].bands.mean.values())\n",
" stds = list(metadata[\"sentinel-2-l2a\"].bands.std.values())\n",
" norm_pixels = denormalize_images(pixels, means, stds)\n",
" \n",
" images = rearrange(norm_pixels[:, :3, :, :], \"b c h w -> b h w c\")\n",
" \n",
" labels = np.clip(labels.squeeze(axis=1), 0, 400)\n",
" outputs = np.clip(outputs.squeeze(axis=1), 0, 400)\n",
" images = np.clip(images / 2000, 0, 1)\n",
"\n",
" return images, labels, outputs"
]
},
{
"cell_type": "markdown",
"id": "fac21c4c-88c6-43aa-8f1e-f8260ceb213c",
"metadata": {},
"source": [
"### Plotting"
]
},
{
"cell_type": "code",
"execution_count": null,
"id": "140783b8-84ad-4ad3-9a10-1fce8ff0db80",
"metadata": {},
"outputs": [],
"source": [
"def plot_predictions(images, labels, outputs):\n",
" fig, axes = plt.subplots(12, 8, figsize=(12, 18))\n",
"\n",
" # Plot the images\n",
" plot_data(axes, images, row_offset=0, title=\"Image\")\n",
"\n",
" # Plot the actual segmentation maps\n",
" plot_data(axes, labels, row_offset=1, title=\"Actual\")\n",
"\n",
" # Plot the predicted segmentation maps\n",
" plot_data(axes, outputs, row_offset=2, title=\"Pred\")\n",
"\n",
" plt.tight_layout()\n",
" plt.show()\n",
"\n",
"def plot_data(ax, data, row_offset, title=None):\n",
" for i, item in enumerate(data):\n",
" row = row_offset + (i // 8) * 3\n",
" col = i % 8\n",
" ax[row, col].imshow(item, cmap='cividis')\n",
" ax[row, col].axis('off')\n",
" if title and col == 0:\n",
" ax[row, col].set_title(title, rotation=0, fontsize=12)"
]
},
{
"cell_type": "code",
"execution_count": null,
"id": "f06754b6-080e-4be3-b433-cd1cf8b760c1",
"metadata": {},
"outputs": [],
"source": [
"# Load model\n",
"model = get_model(BIOMASTERS_CHECKPOINT_PATH, CLAY_CHECKPOINT_PATH, METADATA_PATH)"
]
},
{
"cell_type": "code",
"execution_count": null,
"id": "dc82cbf6-45cd-46fb-b58e-0959029732be",
"metadata": {},
"outputs": [],
"source": [
"# Get data\n",
"batch, metadata = get_data(\n",
" TRAIN_CHIP_DIR, TRAIN_LABEL_DIR, VAL_CHIP_DIR, VAL_LABEL_DIR, METADATA_PATH, BATCH_SIZE, NUM_WORKERS\n",
")\n",
"# Move batch to GPU\n",
"batch = {k: v.to(\"cuda\") for k, v in batch.items()}"
]
},
{
"cell_type": "code",
"execution_count": null,
"id": "611a2b9e-8181-4033-b63c-80a6018ccbb9",
"metadata": {},
"outputs": [],
"source": [
"# Run prediction\n",
"outputs = run_prediction(model, batch)"
]
},
{
"cell_type": "code",
"execution_count": null,
"id": "da3e1034-3cf7-45ff-afd7-a3f1e748d829",
"metadata": {},
"outputs": [],
"source": [
"# Post-process the results\n",
"images, labels, outputs = post_process(batch, outputs, metadata)"
]
},
{
"cell_type": "code",
"execution_count": null,
"id": "ab0a9497-8769-40c7-b5c6-de3f2e5aac14",
"metadata": {},
"outputs": [],
"source": [
"# Plot the predictions\n",
"plot_predictions(images, labels, outputs)"
]
},
{
"cell_type": "code",
"execution_count": null,
"id": "ffa83c77-9e2e-4e8c-aa30-0966c38ccef9",
"metadata": {},
"outputs": [],
"source": []
}
],
"metadata": {
"kernelspec": {
"display_name": "Python 3 (ipykernel)",
"language": "python",
"name": "python3"
},
"language_info": {
"codemirror_mode": {
"name": "ipython",
"version": 3
},
"file_extension": ".py",
"mimetype": "text/x-python",
"name": "python",
"nbconvert_exporter": "python",
"pygments_lexer": "ipython3",
"version": "3.10.12"
}
},
"nbformat": 4,
"nbformat_minor": 5
}
Loading

0 comments on commit b595481

Please sign in to comment.