diff --git a/finetune/regression/biomasters_inference.ipynb b/finetune/regression/biomasters_inference.ipynb new file mode 100644 index 00000000..e921059f --- /dev/null +++ b/finetune/regression/biomasters_inference.ipynb @@ -0,0 +1,308 @@ +{ + "cells": [ + { + "cell_type": "code", + "execution_count": null, + "id": "4e9c5b75-623e-447a-a62f-1bec5c2da0e7", + "metadata": {}, + "outputs": [], + "source": [ + "import sys\n", + "import warnings\n", + "\n", + "sys.path.append(\"../../\")\n", + "warnings.filterwarnings(\"ignore\")" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "id": "a6fc54a3-7123-4a29-ada8-0344665fd9d0", + "metadata": {}, + "outputs": [], + "source": [ + "import torch\n", + "import torch.nn.functional as F\n", + "import matplotlib.pyplot as plt\n", + "import numpy as np\n", + "from einops import rearrange\n", + "from finetune.regression.biomasters_datamodule import BioMastersDataModule\n", + "from finetune.regression.biomasters_model import BioMastersClassifier" + ] + }, + { + "cell_type": "markdown", + "id": "cd1be1f5-2a55-47b3-8d55-5a87683eb4ba", + "metadata": {}, + "source": [ + "### Define paths and parameters" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "id": "a6c65362-dc3d-44c3-a992-fe15ad26d519", + "metadata": {}, + "outputs": [], + "source": [ + "BIOMASTERS_CHECKPOINT_PATH = \"../../checkpoints/regression/biomasters_epoch-33_val-score-36.606.ckpt\"\n", + "CLAY_CHECKPOINT_PATH = \"../../checkpoints/clay-v1-base.ckpt\"\n", + "METADATA_PATH = \"../../configs/metadata.yaml\"\n", + "\n", + "TRAIN_CHIP_DIR = \"../../data/biomasters/train_cube/\"\n", + "TRAIN_LABEL_DIR = \"../../data/biomasters/train_agbm/\"\n", + "VAL_CHIP_DIR = \"../../data/biomasters/test_cube/\"\n", + "VAL_LABEL_DIR = \"../../data/biomasters/test_agbm/\"\n", + "\n", + "BATCH_SIZE = 32\n", + "NUM_WORKERS = 1" + ] + }, + { + "cell_type": "markdown", + "id": "18665299-f505-4ae4-8c12-09d6dbce9d9c", + "metadata": {}, + "source": [ + "### Model Loading" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "id": "d39eda04-771d-4a51-914a-0aa9eb5f54a8", + "metadata": {}, + "outputs": [], + "source": [ + "def get_model(biomasters_checkpoint_path, clay_checkpoint_path, metadata_path):\n", + " model = BioMastersClassifier.load_from_checkpoint(\n", + " checkpoint_path=biomasters_checkpoint_path,\n", + " metadata_path=metadata_path,\n", + " ckpt_path=clay_checkpoint_path,\n", + " )\n", + " model.eval()\n", + " return model" + ] + }, + { + "cell_type": "markdown", + "id": "4cfb9f4f-1765-480d-95dc-1def32459f95", + "metadata": {}, + "source": [ + "### Data Preparation" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "id": "a5738b71-ac76-4b3b-ac8f-92fdcfa49d20", + "metadata": {}, + "outputs": [], + "source": [ + "def get_data(train_chip_dir, train_label_dir, val_chip_dir, val_label_dir, metadata_path, batch_size, num_workers):\n", + " dm = BioMastersDataModule(\n", + " train_chip_dir=train_chip_dir,\n", + " train_label_dir=train_label_dir,\n", + " val_chip_dir=val_chip_dir,\n", + " val_label_dir=val_label_dir,\n", + " metadata_path=metadata_path,\n", + " batch_size=batch_size,\n", + " num_workers=num_workers\n", + " )\n", + " dm.setup(stage=\"fit\")\n", + " val_dl = iter(dm.val_dataloader())\n", + " batch = next(val_dl)\n", + " metadata = dm.metadata\n", + " return batch, metadata" + ] + }, + { + "cell_type": "markdown", + "id": "adb781e5-077d-401b-9d99-e63cfc92ea1b", + "metadata": {}, + "source": [ + "### Prediction" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "id": "daff7782-c808-4639-89fe-4d18e28f9ec8", + "metadata": {}, + "outputs": [], + "source": [ + "def run_prediction(model, batch):\n", + " with torch.no_grad():\n", + " outputs = model(batch)\n", + " outputs = F.interpolate(outputs, size=(256, 256), mode='bilinear', align_corners=False)\n", + " return outputs" + ] + }, + { + "cell_type": "markdown", + "id": "a7bf3e2c-9de1-487a-bc2e-438bf06de482", + "metadata": {}, + "source": [ + "### Post-Processing" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "id": "77a30bde-c300-4d5d-a588-3d4dcd40e2e2", + "metadata": {}, + "outputs": [], + "source": [ + "def denormalize_images(normalized_images, means, stds):\n", + " means = np.array(means).reshape(1, -1, 1, 1)\n", + " stds = np.array(stds).reshape(1, -1, 1, 1)\n", + " denormalized_images = normalized_images * stds + means\n", + " return denormalized_images\n", + "\n", + "def post_process(batch, outputs, metadata):\n", + " labels = batch[\"label\"].detach().cpu().numpy()\n", + " pixels = batch[\"pixels\"].detach().cpu().numpy()\n", + " outputs = outputs.detach().cpu().numpy()\n", + "\n", + " means = list(metadata[\"sentinel-2-l2a\"].bands.mean.values())\n", + " stds = list(metadata[\"sentinel-2-l2a\"].bands.std.values())\n", + " norm_pixels = denormalize_images(pixels, means, stds)\n", + " \n", + " images = rearrange(norm_pixels[:, :3, :, :], \"b c h w -> b h w c\")\n", + " \n", + " labels = np.clip(labels.squeeze(axis=1), 0, 400)\n", + " outputs = np.clip(outputs.squeeze(axis=1), 0, 400)\n", + " images = np.clip(images / 2000, 0, 1)\n", + "\n", + " return images, labels, outputs" + ] + }, + { + "cell_type": "markdown", + "id": "fac21c4c-88c6-43aa-8f1e-f8260ceb213c", + "metadata": {}, + "source": [ + "### Plotting" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "id": "140783b8-84ad-4ad3-9a10-1fce8ff0db80", + "metadata": {}, + "outputs": [], + "source": [ + "def plot_predictions(images, labels, outputs):\n", + " fig, axes = plt.subplots(12, 8, figsize=(12, 18))\n", + "\n", + " # Plot the images\n", + " plot_data(axes, images, row_offset=0, title=\"Image\")\n", + "\n", + " # Plot the actual segmentation maps\n", + " plot_data(axes, labels, row_offset=1, title=\"Actual\")\n", + "\n", + " # Plot the predicted segmentation maps\n", + " plot_data(axes, outputs, row_offset=2, title=\"Pred\")\n", + "\n", + " plt.tight_layout()\n", + " plt.show()\n", + "\n", + "def plot_data(ax, data, row_offset, title=None):\n", + " for i, item in enumerate(data):\n", + " row = row_offset + (i // 8) * 3\n", + " col = i % 8\n", + " ax[row, col].imshow(item, cmap='cividis')\n", + " ax[row, col].axis('off')\n", + " if title and col == 0:\n", + " ax[row, col].set_title(title, rotation=0, fontsize=12)" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "id": "f06754b6-080e-4be3-b433-cd1cf8b760c1", + "metadata": {}, + "outputs": [], + "source": [ + "# Load model\n", + "model = get_model(BIOMASTERS_CHECKPOINT_PATH, CLAY_CHECKPOINT_PATH, METADATA_PATH)" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "id": "dc82cbf6-45cd-46fb-b58e-0959029732be", + "metadata": {}, + "outputs": [], + "source": [ + "# Get data\n", + "batch, metadata = get_data(\n", + " TRAIN_CHIP_DIR, TRAIN_LABEL_DIR, VAL_CHIP_DIR, VAL_LABEL_DIR, METADATA_PATH, BATCH_SIZE, NUM_WORKERS\n", + ")\n", + "# Move batch to GPU\n", + "batch = {k: v.to(\"cuda\") for k, v in batch.items()}" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "id": "611a2b9e-8181-4033-b63c-80a6018ccbb9", + "metadata": {}, + "outputs": [], + "source": [ + "# Run prediction\n", + "outputs = run_prediction(model, batch)" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "id": "da3e1034-3cf7-45ff-afd7-a3f1e748d829", + "metadata": {}, + "outputs": [], + "source": [ + "# Post-process the results\n", + "images, labels, outputs = post_process(batch, outputs, metadata)" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "id": "ab0a9497-8769-40c7-b5c6-de3f2e5aac14", + "metadata": {}, + "outputs": [], + "source": [ + "# Plot the predictions\n", + "plot_predictions(images, labels, outputs)" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "id": "ffa83c77-9e2e-4e8c-aa30-0966c38ccef9", + "metadata": {}, + "outputs": [], + "source": [] + } + ], + "metadata": { + "kernelspec": { + "display_name": "Python 3 (ipykernel)", + "language": "python", + "name": "python3" + }, + "language_info": { + "codemirror_mode": { + "name": "ipython", + "version": 3 + }, + "file_extension": ".py", + "mimetype": "text/x-python", + "name": "python", + "nbconvert_exporter": "python", + "pygments_lexer": "ipython3", + "version": "3.10.12" + } + }, + "nbformat": 4, + "nbformat_minor": 5 +} diff --git a/finetune/segment/chesapeake_inference.ipynb b/finetune/segment/chesapeake_inference.ipynb new file mode 100644 index 00000000..05ffc1b2 --- /dev/null +++ b/finetune/segment/chesapeake_inference.ipynb @@ -0,0 +1,318 @@ +{ + "cells": [ + { + "cell_type": "code", + "execution_count": null, + "id": "25d69a7d-5f0e-453a-8a7d-8ef4b100e72b", + "metadata": {}, + "outputs": [], + "source": [ + "import sys\n", + "import warnings\n", + "\n", + "sys.path.append(\"../../\")\n", + "warnings.filterwarnings(\"ignore\")" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "id": "34608fe0-9c89-4b39-b0b7-59d74efafdbe", + "metadata": {}, + "outputs": [], + "source": [ + "import torch\n", + "import torch.nn.functional as F\n", + "import matplotlib.pyplot as plt\n", + "import numpy as np\n", + "from einops import rearrange\n", + "from matplotlib.colors import ListedColormap\n", + "from finetune.segment.chesapeake_model import ChesapeakeSegmentor\n", + "from finetune.segment.chesapeake_datamodule import ChesapeakeDataModule" + ] + }, + { + "cell_type": "markdown", + "id": "8873272f-89e7-48de-9115-7c9d21b62c1f", + "metadata": {}, + "source": [ + "### Define paths and parameters" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "id": "d1ea85c6-5086-42b2-b032-489890554d84", + "metadata": {}, + "outputs": [], + "source": [ + "CHESAPEAKE_CHECKPOINT_PATH = \"../../checkpoints/segment/chesapeake-7class-segment_epoch-09_val-iou-0.8751.ckpt\"\n", + "CLAY_CHECKPOINT_PATH = \"../../checkpoints/clay-v1-base.ckpt\"\n", + "METADATA_PATH = \"../../configs/metadata.yaml\"\n", + "\n", + "TRAIN_CHIP_DIR = \"../../data/cvpr/ny/train/chips/\"\n", + "TRAIN_LABEL_DIR = \"../../data/cvpr/ny/train/labels/\"\n", + "VAL_CHIP_DIR = \"../../data/cvpr/ny/val/chips/\"\n", + "VAL_LABEL_DIR = \"../../data/cvpr/ny/val/labels/\"\n", + "\n", + "BATCH_SIZE = 32\n", + "NUM_WORKERS = 1\n", + "PLATFORM = \"naip\"" + ] + }, + { + "cell_type": "markdown", + "id": "cc278db5-e241-4763-8f33-bdeb5b0f81fc", + "metadata": {}, + "source": [ + "### Model Loading" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "id": "4b0da577-f3e5-485a-bbc5-a3ff7367e670", + "metadata": {}, + "outputs": [], + "source": [ + "def get_model(chesapeake_checkpoint_path, clay_checkpoint_path, metadata_path):\n", + " model = ChesapeakeSegmentor.load_from_checkpoint(\n", + " checkpoint_path=chesapeake_checkpoint_path,\n", + " metadata_path=metadata_path,\n", + " ckpt_path=clay_checkpoint_path\n", + " )\n", + " model.eval()\n", + " return model" + ] + }, + { + "cell_type": "markdown", + "id": "2d9ba7fc-f1ca-465c-be66-15edca8e8419", + "metadata": {}, + "source": [ + "### Data Preparation" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "id": "3402cf0a-cb9b-47c4-a12a-bb704912edfd", + "metadata": {}, + "outputs": [], + "source": [ + "def get_data(train_chip_dir, train_label_dir, val_chip_dir, val_label_dir, metadata_path, batch_size, num_workers, platform):\n", + " dm = ChesapeakeDataModule(\n", + " train_chip_dir=train_chip_dir,\n", + " train_label_dir=train_label_dir,\n", + " val_chip_dir=val_chip_dir,\n", + " val_label_dir=val_label_dir,\n", + " metadata_path=metadata_path,\n", + " batch_size=batch_size,\n", + " num_workers=num_workers,\n", + " platform=platform\n", + " )\n", + " dm.setup(stage=\"fit\")\n", + " val_dl = iter(dm.val_dataloader())\n", + " batch = next(val_dl)\n", + " metadata = dm.metadata\n", + " return batch, metadata" + ] + }, + { + "cell_type": "markdown", + "id": "ea94afc8-c507-41b8-a3be-dd130ff90c72", + "metadata": {}, + "source": [ + "### Prediction" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "id": "d7d71514-47b0-447b-899b-5aef44c38bc2", + "metadata": {}, + "outputs": [], + "source": [ + "def run_prediction(model, batch):\n", + " with torch.no_grad():\n", + " outputs = model(batch)\n", + " outputs = F.interpolate(outputs, size=(224, 224), mode='bilinear', align_corners=False)\n", + " return outputs" + ] + }, + { + "cell_type": "markdown", + "id": "2a64735f-70b1-4d05-acd9-2a0812545cfa", + "metadata": {}, + "source": [ + "### Post-Processing" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "id": "7d69561e-b7ab-4f4d-b426-2d0cccc949f3", + "metadata": {}, + "outputs": [], + "source": [ + "def denormalize_images(normalized_images, means, stds):\n", + " means = np.array(means).reshape(1, -1, 1, 1)\n", + " stds = np.array(stds).reshape(1, -1, 1, 1)\n", + " denormalized_images = normalized_images * stds + means\n", + " return denormalized_images.astype(np.uint8) # Do for NAIP/LINZ\n", + " \n", + "def post_process(batch, outputs, metadata):\n", + " preds = torch.argmax(outputs, dim=1).detach().cpu().numpy()\n", + " labels = batch[\"label\"].detach().cpu().numpy()\n", + " pixels = batch[\"pixels\"].detach().cpu().numpy()\n", + "\n", + " means = list(metadata[\"naip\"].bands.mean.values())\n", + " stds = list(metadata[\"naip\"].bands.std.values())\n", + " norm_pixels = denormalize_images(pixels, means, stds)\n", + " \n", + " images = rearrange(norm_pixels[:, :3, :, :], \"b c h w -> b h w c\")\n", + " \n", + " return images, labels, preds" + ] + }, + { + "cell_type": "markdown", + "id": "ef86d23c-eca7-458a-99ef-fff4534b927e", + "metadata": {}, + "source": [ + "### Plotting" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "id": "368b1925-be0f-47a5-bbb9-c642c3f04afa", + "metadata": {}, + "outputs": [], + "source": [ + "def plot_predictions(images, labels, preds):\n", + " colors = [\n", + " (0 / 255, 0 / 255, 255 / 255, 1), # Deep Blue for water\n", + " (34 / 255, 139 / 255, 34 / 255, 1), # Forest Green for tree canopy / forest\n", + " (154 / 255, 205 / 255, 50 / 255, 1), # Yellow Green for low vegetation / field\n", + " (210 / 255, 180 / 255, 140 / 255, 1), # Tan for barren land\n", + " (169 / 255, 169 / 255, 169 / 255, 1), # Dark Gray for impervious (other)\n", + " (105 / 255, 105 / 255, 105 / 255, 1), # Dim Gray for impervious (road)\n", + " (255 / 255, 255 / 255, 255 / 255, 1) # White for no data\n", + " ]\n", + " cmap = ListedColormap(colors)\n", + "\n", + " fig, axes = plt.subplots(12, 8, figsize=(12, 18))\n", + "\n", + " # Plot the images\n", + " plot_data(axes, images, row_offset=0, title=\"Image\")\n", + "\n", + " # Plot the actual segmentation maps\n", + " plot_data(axes, labels, row_offset=1, title=\"Actual\", cmap=cmap, vmin=0, vmax=6)\n", + "\n", + " # Plot the predicted segmentation maps\n", + " plot_data(axes, preds, row_offset=2, title=\"Pred\", cmap=cmap, vmin=0, vmax=6)\n", + "\n", + " plt.tight_layout()\n", + " plt.show()\n", + "\n", + "def plot_data(ax, data, row_offset, title=None, cmap=None, vmin=None, vmax=None):\n", + " for i, item in enumerate(data):\n", + " row = row_offset + (i // 8) * 3\n", + " col = i % 8\n", + " ax[row, col].imshow(item, cmap=cmap, vmin=vmin, vmax=vmax)\n", + " ax[row, col].axis('off')\n", + " if title and col == 0:\n", + " ax[row, col].set_title(title, rotation=0, fontsize=12)" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "id": "30d9b66b-ea25-4697-83be-776abb40db9b", + "metadata": {}, + "outputs": [], + "source": [ + "# Load model\n", + "model = get_model(CHESAPEAKE_CHECKPOINT_PATH, CLAY_CHECKPOINT_PATH, METADATA_PATH)" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "id": "ac07a050-c55d-4392-9461-a16afdb65f8f", + "metadata": {}, + "outputs": [], + "source": [ + "# Get data\n", + "batch, metadata = get_data(\n", + " TRAIN_CHIP_DIR, TRAIN_LABEL_DIR, VAL_CHIP_DIR, VAL_LABEL_DIR, METADATA_PATH, BATCH_SIZE, NUM_WORKERS, PLATFORM\n", + ")\n", + "# Move batch to GPU\n", + "batch = {k: v.to(\"cuda\") for k, v in batch.items()}" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "id": "0f41e9f1-9387-4b97-ab6b-9273dd80f317", + "metadata": {}, + "outputs": [], + "source": [ + "# Run prediction\n", + "outputs = run_prediction(model, batch)" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "id": "648f7d5b-5ec1-4f43-8cf9-32d9b79efe00", + "metadata": {}, + "outputs": [], + "source": [ + "# Post-process the results\n", + "images, labels, preds = post_process(batch, outputs, metadata)" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "id": "b110995b-46d9-416a-a42c-d53d2955671e", + "metadata": {}, + "outputs": [], + "source": [ + "# Plot the predictions\n", + "plot_predictions(images, labels, preds)" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "id": "1cfade62-f5ef-4b09-9b8d-f084e6dab075", + "metadata": {}, + "outputs": [], + "source": [] + } + ], + "metadata": { + "kernelspec": { + "display_name": "Python 3 (ipykernel)", + "language": "python", + "name": "python3" + }, + "language_info": { + "codemirror_mode": { + "name": "ipython", + "version": 3 + }, + "file_extension": ".py", + "mimetype": "text/x-python", + "name": "python", + "nbconvert_exporter": "python", + "pygments_lexer": "ipython3", + "version": "3.10.12" + } + }, + "nbformat": 4, + "nbformat_minor": 5 +}