From 22dd6fcf9e4671f1799073d6758ee1b0bdf397e2 Mon Sep 17 00:00:00 2001 From: Daniel Wiesmann Date: Fri, 24 May 2024 10:01:40 +0100 Subject: [PATCH 1/2] Add notebook showing how to run v1 --- docs/clay-v1-wall-to-wall.ipynb | 1648 +++++++++++++++++++++++++++++++ 1 file changed, 1648 insertions(+) create mode 100644 docs/clay-v1-wall-to-wall.ipynb diff --git a/docs/clay-v1-wall-to-wall.ipynb b/docs/clay-v1-wall-to-wall.ipynb new file mode 100644 index 00000000..4dcac9b3 --- /dev/null +++ b/docs/clay-v1-wall-to-wall.ipynb @@ -0,0 +1,1648 @@ +{ + "cells": [ + { + "cell_type": "markdown", + "id": "0cc5e729-9116-4ec9-bf1e-8346cbccdf7b", + "metadata": {}, + "source": [ + "## Run Clay v1\n", + "\n", + "This notebook shows how to run Clay v1 wall-to-wall, from downloading imagery\n", + "to training a tiny fine tuning head. This will include the following steps:\n", + "\n", + "1. Set a location and date range of interest\n", + "2. Download Sentinel-2 imagery for this specification\n", + "3. Load the model checkpoint\n", + "4. Prepare data into a format for the model\n", + "5. Run the model on the imagery\n", + "6. Analyse the model embeddings output using PCA\n", + "7. 
Train a Support Vector Machine fine tuning head" + ] + }, + { + "cell_type": "code", + "execution_count": 9, + "id": "add63cd9", + "metadata": {}, + "outputs": [], + "source": [ + "# Add the repo root to the sys path for the model import below\n", + "import sys\n", + "\n", + "sys.path.append(\"..\")" + ] + }, + { + "cell_type": "code", + "execution_count": 10, + "id": "6a17b8a8-a9c6-4053-833e-de97287fae49", + "metadata": {}, + "outputs": [], + "source": [ + "import math\n", + "\n", + "import geopandas as gpd\n", + "import numpy as np\n", + "import pandas as pd\n", + "import pystac_client\n", + "import stackstac\n", + "import torch\n", + "import yaml\n", + "from box import Box\n", + "from matplotlib import pyplot as plt\n", + "from rasterio.enums import Resampling\n", + "from shapely import Point\n", + "from sklearn import decomposition, svm\n", + "from stacchip.processors.prechip import normalize_timestamp\n", + "from torchvision.transforms import v2\n", + "\n", + "from src.model import ClayMAEModule" + ] + }, + { + "cell_type": "markdown", + "id": "beac6394-9762-422b-9f5d-82d226018c0c", + "metadata": {}, + "source": [ + "### Specify location and date of interest\n", + "In this example we will use a location in Portugal where a forest fire happened. We will run the model over the time period of the fire and analyse the model embeddings." + ] + }, + { + "cell_type": "code", + "execution_count": 11, + "id": "08d7787d-1506-4de7-89dc-c1054910acf7", + "metadata": {}, + "outputs": [], + "source": [ + "# Point over Monchique Portugal\n", + "lat, lon = 37.30939, -8.57207\n", + "\n", + "# Dates of a large forest fire\n", + "start = \"2018-07-01\"\n", + "end = \"2018-09-01\"" + ] + }, + { + "cell_type": "markdown", + "id": "2bd226c9-003b-4867-a64a-8ae887e7e20a", + "metadata": {}, + "source": [ + "### Get data from STAC catalog\n", + "\n", + "Based on the location and date we can obtain a stack of imagery using stackstac. 
Let's start with finding the STAC items we want to analyse." + ] + }, + { + "cell_type": "code", + "execution_count": 12, + "id": "2e80743c-7c77-459b-9984-f6c26cdff549", + "metadata": {}, + "outputs": [ + { + "name": "stderr", + "output_type": "stream", + "text": [ + "/home/tam/apps/miniforge3/envs/claymodel/lib/python3.11/site-packages/pystac_client/item_search.py:850: FutureWarning: get_all_items() is deprecated, use item_collection() instead.\n", + " warnings.warn(\n" + ] + }, + { + "name": "stdout", + "output_type": "stream", + "text": [ + "Found 12 items\n" + ] + } + ], + "source": [ + "STAC_API = \"https://earth-search.aws.element84.com/v1\"\n", + "COLLECTION = \"sentinel-2-l2a\"\n", + "\n", + "# Search the catalogue\n", + "catalog = pystac_client.Client.open(STAC_API)\n", + "search = catalog.search(\n", + " collections=[COLLECTION],\n", + " datetime=f\"{start}/{end}\",\n", + " bbox=(lon - 1e-5, lat - 1e-5, lon + 1e-5, lat + 1e-5),\n", + " max_items=100,\n", + " query={\"eo:cloud_cover\": {\"lt\": 80}},\n", + ")\n", + "\n", + "all_items = search.get_all_items()\n", + "\n", + "# Reduce to one per date (there might be some duplicates\n", + "# based on the location)\n", + "items = []\n", + "dates = []\n", + "for item in all_items:\n", + " if item.datetime.date() not in dates:\n", + " items.append(item)\n", + " dates.append(item.datetime.date())\n", + "\n", + "print(f\"Found {len(items)} items\")" + ] + }, + { + "cell_type": "markdown", + "id": "5b7c68ae-7c8a-446a-8bc7-5afba70183c2", + "metadata": {}, + "source": [ + "### Create a bounding box around the point of interest\n", + "\n", + "This is needed in the projection of the data so that we can generate image chips of the right size." 
+ ] + }, + { + "cell_type": "code", + "execution_count": 13, + "id": "0f3573b5-5a00-47d9-a648-5c4d7cd2c996", + "metadata": {}, + "outputs": [], + "source": [ + "# Extract coordinate system from first item\n", + "epsg = items[0].properties[\"proj:epsg\"]\n", + "\n", + "# Convert point of interest into the image projection\n", + "# (assumes all images are in the same projection)\n", + "poidf = gpd.GeoDataFrame(\n", + " pd.DataFrame(),\n", + " crs=\"EPSG:4326\",\n", + " geometry=[Point(lon, lat)],\n", + ").to_crs(epsg)\n", + "\n", + "coords = poidf.iloc[0].geometry.coords[0]\n", + "\n", + "# Create bounds in projection\n", + "size = 256\n", + "gsd = 10\n", + "bounds = (\n", + " coords[0] - (size * gsd) // 2,\n", + " coords[1] - (size * gsd) // 2,\n", + " coords[0] + (size * gsd) // 2,\n", + " coords[1] + (size * gsd) // 2,\n", + ")" + ] + }, + { + "cell_type": "markdown", + "id": "bbbd3f67-5f2c-46dc-9ee1-2ef1f50fa032", + "metadata": {}, + "source": [ + "### Retrieve the imagery data." + ] + }, + { + "cell_type": "code", + "execution_count": 14, + "id": "8b8d3824-e48c-4f9d-9c7b-181c0800f96f", + "metadata": {}, + "outputs": [ + { + "name": "stdout", + "output_type": "stream", + "text": [ + "Working with stack of size (12, 4, 256, 256)\n" + ] + }, + { + "data": { + "text/html": [ + "
\n", + "\n", + "\n", + "\n", + "\n", + "\n", + "\n", + "\n", + "\n", + "\n", + "\n", + "\n", + "\n", + "\n", + "\n", + "
<xarray.DataArray 'stackstac-7cbad7c129d678be53c9b6676bee564b' (time: 12,\n",
+       "                                                                band: 4,\n",
+       "                                                                y: 256, x: 256)> Size: 13MB\n",
+       "array([[[[ 9136.,  9232.,  9544., ...,  1258.,  1120.,   930.],\n",
+       "         [ 9616.,  9768.,  9840., ...,  1230.,  1208.,  1030.],\n",
+       "         [ 9992., 10008., 10000., ...,  1418.,  1336.,  1242.],\n",
+       "         ...,\n",
+       "         [  811.,   655.,   688., ...,   385.,   362.,   461.],\n",
+       "         [  798.,   675.,   727., ...,   394.,   415.,   402.],\n",
+       "         [  888.,   673.,   642., ...,   403.,   454.,   393.]],\n",
+       "\n",
+       "        [[ 8656.,  8656.,  8864., ...,  1500.,  1428.,  1220.],\n",
+       "         [ 9016.,  9160.,  9224., ...,  1546.,  1522.,  1360.],\n",
+       "         [ 9248.,  9328.,  9384., ...,  1620.,  1542.,  1482.],\n",
+       "         ...,\n",
+       "         [ 1010.,   831.,   853., ...,   277.,   276.,   336.],\n",
+       "         [ 1016.,   930.,   927., ...,   276.,   317.,   293.],\n",
+       "         [ 1112.,   885.,   827., ...,   299.,   369.,   293.]],\n",
+       "\n",
+       "        [[ 8416.,  8416.,  8640., ...,  1598.,  1466.,  1138.],\n",
+       "         [ 8744.,  8880.,  8928., ...,  1498.,  1522.,  1284.],\n",
+       "         [ 8952.,  8944.,  8960., ...,  1542.,  1478.,  1448.],\n",
+       "         ...,\n",
+       "...\n",
+       "         [  652.,   640.,   638., ...,   590.,   821.,  1008.],\n",
+       "         [  622.,   676.,   630., ...,   606.,  1092.,   726.],\n",
+       "         [  864.,   786.,   569., ...,   766.,  1068.,   630.]],\n",
+       "\n",
+       "        [[  201.,   213.,   195., ...,  1138.,  1058.,   749.],\n",
+       "         [  196.,   198.,   169., ...,   861.,   784.,   768.],\n",
+       "         [  216.,   178.,   191., ...,   870.,   806.,   820.],\n",
+       "         ...,\n",
+       "         [  857.,   838.,   846., ...,   622.,   800.,  1332.],\n",
+       "         [  922.,   848.,   771., ...,   786.,  1046.,   912.],\n",
+       "         [ 1118.,  1010.,   735., ...,   755.,   977.,   686.]],\n",
+       "\n",
+       "        [[ 3264.,  3352.,  3304., ...,  3160.,  3296.,  3376.],\n",
+       "         [ 3356.,  3300.,  3212., ...,  3188.,  3272.,  3064.],\n",
+       "         [ 3288.,  3372.,  3344., ...,  3136.,  3200.,  2932.],\n",
+       "         ...,\n",
+       "         [ 1320.,  1468.,  1298., ...,  2492.,  2556.,  3018.],\n",
+       "         [ 1630.,  1694.,  1250., ...,  2318.,  2684.,  2894.],\n",
+       "         [ 2190.,  2072.,  1288., ...,  2544.,  2942.,  2928.]]]],\n",
+       "      dtype=float32)\n",
+       "Coordinates: (12/53)\n",
+       "  * time                                     (time) datetime64[ns] 96B 2018-0...\n",
+       "    id                                       (time) <U24 1kB 'S2B_29SNB_20180...\n",
+       "  * band                                     (band) <U5 80B 'blue' ... 'nir'\n",
+       "  * x                                        (x) float64 2kB 5.366e+05 ... 5....\n",
+       "  * y                                        (y) float64 2kB 4.131e+06 ... 4....\n",
+       "    platform                                 (time) <U11 528B 'sentinel-2b' ....\n",
+       "    ...                                       ...\n",
+       "    gsd                                      int64 8B 10\n",
+       "    title                                    (band) <U20 320B 'Blue (band 2) ...\n",
+       "    common_name                              (band) <U5 80B 'blue' ... 'nir'\n",
+       "    center_wavelength                        (band) float64 32B 0.49 ... 0.842\n",
+       "    full_width_half_max                      (band) float64 32B 0.098 ... 0.145\n",
+       "    epsg                                     int64 8B 32629\n",
+       "Attributes:\n",
+       "    spec:        RasterSpec(epsg=32629, bounds=(536640.79691545, 4128000.7407...\n",
+       "    crs:         epsg:32629\n",
+       "    transform:   | 10.00, 0.00, 536640.80|\\n| 0.00,-10.00, 4130560.74|\\n| 0.0...\n",
+       "    resolution:  10
" + ], + "text/plain": [ + " Size: 13MB\n", + "array([[[[ 9136., 9232., 9544., ..., 1258., 1120., 930.],\n", + " [ 9616., 9768., 9840., ..., 1230., 1208., 1030.],\n", + " [ 9992., 10008., 10000., ..., 1418., 1336., 1242.],\n", + " ...,\n", + " [ 811., 655., 688., ..., 385., 362., 461.],\n", + " [ 798., 675., 727., ..., 394., 415., 402.],\n", + " [ 888., 673., 642., ..., 403., 454., 393.]],\n", + "\n", + " [[ 8656., 8656., 8864., ..., 1500., 1428., 1220.],\n", + " [ 9016., 9160., 9224., ..., 1546., 1522., 1360.],\n", + " [ 9248., 9328., 9384., ..., 1620., 1542., 1482.],\n", + " ...,\n", + " [ 1010., 831., 853., ..., 277., 276., 336.],\n", + " [ 1016., 930., 927., ..., 276., 317., 293.],\n", + " [ 1112., 885., 827., ..., 299., 369., 293.]],\n", + "\n", + " [[ 8416., 8416., 8640., ..., 1598., 1466., 1138.],\n", + " [ 8744., 8880., 8928., ..., 1498., 1522., 1284.],\n", + " [ 8952., 8944., 8960., ..., 1542., 1478., 1448.],\n", + " ...,\n", + "...\n", + " [ 652., 640., 638., ..., 590., 821., 1008.],\n", + " [ 622., 676., 630., ..., 606., 1092., 726.],\n", + " [ 864., 786., 569., ..., 766., 1068., 630.]],\n", + "\n", + " [[ 201., 213., 195., ..., 1138., 1058., 749.],\n", + " [ 196., 198., 169., ..., 861., 784., 768.],\n", + " [ 216., 178., 191., ..., 870., 806., 820.],\n", + " ...,\n", + " [ 857., 838., 846., ..., 622., 800., 1332.],\n", + " [ 922., 848., 771., ..., 786., 1046., 912.],\n", + " [ 1118., 1010., 735., ..., 755., 977., 686.]],\n", + "\n", + " [[ 3264., 3352., 3304., ..., 3160., 3296., 3376.],\n", + " [ 3356., 3300., 3212., ..., 3188., 3272., 3064.],\n", + " [ 3288., 3372., 3344., ..., 3136., 3200., 2932.],\n", + " ...,\n", + " [ 1320., 1468., 1298., ..., 2492., 2556., 3018.],\n", + " [ 1630., 1694., 1250., ..., 2318., 2684., 2894.],\n", + " [ 2190., 2072., 1288., ..., 2544., 2942., 2928.]]]],\n", + " dtype=float32)\n", + "Coordinates: (12/53)\n", + " * time (time) datetime64[ns] 96B 2018-0...\n", + " id (time) " + ] + }, + "execution_count": 21, + "metadata": 
{}, + "output_type": "execute_result" + }, + { + "data": { + "image/png": "", + "text/plain": [ + "
" + ] + }, + "metadata": {}, + "output_type": "display_data" + } + ], + "source": [ + "# Run PCA\n", + "pca = decomposition.PCA(n_components=1)\n", + "pca_result = pca.fit_transform(embeddings)\n", + "\n", + "plt.xticks(rotation=-45)\n", + "\n", + "# Plot all points in blue first\n", + "plt.scatter(stack.time, pca_result, color=\"blue\")\n", + "\n", + "# Re-plot cloudy images in green\n", + "plt.scatter(stack.time[0], pca_result[0], color=\"green\")\n", + "plt.scatter(stack.time[2], pca_result[2], color=\"green\")\n", + "\n", + "# Color all images after fire in red\n", + "plt.scatter(stack.time[-5:], pca_result[-5:], color=\"red\")" + ] + }, + { + "cell_type": "markdown", + "id": "b38b70a6-2156-41f8-967e-a490cc8e2778", + "metadata": {}, + "source": [ + "### And finally, some finetuning\n", + "\n", + "We are going to train a classifier head on the embeddings and use it to detect fires." + ] + }, + { + "cell_type": "code", + "execution_count": 22, + "id": "1da07de0-b8f2-46c9-bd2a-58b15ca2224f", + "metadata": {}, + "outputs": [ + { + "name": "stdout", + "output_type": "stream", + "text": [ + "Matched 5 out of 5 correctly\n" + ] + } + ], + "source": [ + "# Label the images we downloaded\n", + "# 0 = Cloud\n", + "# 1 = Forest\n", + "# 2 = Fire\n", + "labels = np.array([0, 1, 0, 1, 1, 1, 1, 2, 2, 2, 2, 2])\n", + "\n", + "# Split into fit and test manually, ensuring we have all 3 classes in both sets\n", + "fit = [0, 1, 3, 4, 7, 8, 9]\n", + "test = [2, 5, 6, 10, 11]\n", + "\n", + "# Train a support vector machine model\n", + "clf = svm.SVC()\n", + "clf.fit(embeddings[fit] + 100, labels[fit])\n", + "\n", + "# Predict classes on test set\n", + "prediction = clf.predict(embeddings[test] + 100)\n", + "\n", + "# Perfect match for SVM\n", + "match = np.sum(labels[test] == prediction)\n", + "print(f\"Matched {match} out of {len(test)} correctly\")" + ] + } + ], + "metadata": { + "kernelspec": { + "display_name": "claymodel", + "language": "python", + "name": "python3" + }, + 
"language_info": { + "codemirror_mode": { + "name": "ipython", + "version": 3 + }, + "file_extension": ".py", + "mimetype": "text/x-python", + "name": "python", + "nbconvert_exporter": "python", + "pygments_lexer": "ipython3", + "version": "3.11.8" + } + }, + "nbformat": 4, + "nbformat_minor": 5 +} From 514215bfc20fa8af3ddcf47a1207489f06641251 Mon Sep 17 00:00:00 2001 From: Daniel Wiesmann Date: Fri, 24 May 2024 10:04:06 +0100 Subject: [PATCH 2/2] Add notebook to toc --- docs/_config.yml | 1 + docs/_toc.yml | 2 ++ 2 files changed, 3 insertions(+) diff --git a/docs/_config.yml b/docs/_config.yml index 10147fa0..2e3abae7 100644 --- a/docs/_config.yml +++ b/docs/_config.yml @@ -15,6 +15,7 @@ execute: - partial-inputs.ipynb - tutorial_digital_earth_pacific_patch_level.ipynb - patch_level_cloud_cover.ipynb + - clay-v1-wall-to-wall.ipynb # Define the name of the latex output file for PDF builds latex: diff --git a/docs/_toc.yml b/docs/_toc.yml index 1704f2ec..f6fb351a 100644 --- a/docs/_toc.yml +++ b/docs/_toc.yml @@ -26,6 +26,8 @@ parts: file: data_sampling - caption: Running the model chapters: + - title: Clay v1 wall-to-wall example + file: clay-v1-wall-to-wall - title: Run over a region file: run_region - title: Generating embeddings