Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Hv pv detection #44

Merged
merged 14 commits into from
Oct 29, 2024
2 changes: 1 addition & 1 deletion algorithm_catalog/eurac_pv_farm_detection.json
Original file line number Diff line number Diff line change
Expand Up @@ -82,7 +82,7 @@
"rel": "openeo-process",
"type": "application/json",
"title": "openEO Process Definition",
"href": "https://raw.githubusercontent.com/ESA-APEx/apex_algorithms/refs/heads/hv_pv_detection/openeo_udp/eurac_pv_farm_detection/eurac_pv_farm_detection.json"
"href": "https://raw.githubusercontent.com/ESA-APEx/apex_algorithms/4003046e3b79ec3ab8dace888a231655db389d66/openeo_udp/eurac_pv_farm_detection/udf_eurac_pvfarm_onnx.py"
HansVRP marked this conversation as resolved.
Show resolved Hide resolved
},
{
"rel": "git",
Expand Down
2 changes: 1 addition & 1 deletion benchmark_scenarios/eurac_pv_farm_detection.json
Original file line number Diff line number Diff line change
Expand Up @@ -7,7 +7,7 @@
"process_graph": {
"maxndvi1": {
"process_id": "eurac_pv_farm_detection",
"namespace": "https://raw.githubusercontent.com/ESA-APEx/apex_algorithms/refs/heads/hv_pv_detection/openeo_udp/eurac_pv_farm_detection/eurac_pv_farm_detection.json",
"namespace": "https://raw.githubusercontent.com/ESA-APEx/apex_algorithms/4003046e3b79ec3ab8dace888a231655db389d66/openeo_udp/eurac_pv_farm_detection/udf_eurac_pvfarm_onnx.py",
HansVRP marked this conversation as resolved.
Show resolved Hide resolved
"arguments": {
"bbox": {
"east": 16.414,
Expand Down
22 changes: 18 additions & 4 deletions openeo_udp/eurac_pv_farm_detection/udf_eurac_pvfarm_onnx.py
Original file line number Diff line number Diff line change
@@ -1,4 +1,5 @@
import functools
import logging
import os
import sys
import zipfile
Expand All @@ -7,7 +8,7 @@
import numpy as np
import requests
import xarray as xr
import logging


def _setup_logging():
logging.basicConfig(level=logging.INFO)
HansVRP marked this conversation as resolved.
Show resolved Hide resolved
Expand All @@ -21,6 +22,7 @@ def _setup_logging():
DEPENDENCIES_DIR = "onnx_dependencies"
MODEL_DIR = "model_files"


def download_file(url, path):
"""
Downloads a file from the given URL to the specified path.
Expand All @@ -29,6 +31,7 @@ def download_file(url, path):
with open(path, "wb") as file:
file.write(response.content)


def extract_zip(zip_path, extract_to):
"""
Extracts a zip file from zip_path to the specified extract_to directory.
Expand All @@ -37,13 +40,15 @@ def extract_zip(zip_path, extract_to):
zip_ref.extractall(extract_to)
os.remove(zip_path) # Clean up the zip file after extraction


def add_directory_to_sys_path(directory):
"""
Adds a directory to the Python sys.path if it's not already present.
"""
if directory not in sys.path:
sys.path.append(directory)


def setup_model_and_dependencies(model_url, dependencies_url):
"""
Main function to set up the model and dependencies by downloading, extracting,
Expand Down Expand Up @@ -73,12 +78,16 @@ def setup_model_and_dependencies(model_url, dependencies_url):
download_file(model_url, zip_path)
extract_zip(zip_path, MODEL_DIR)

setup_model_and_dependencies(model_url="https://s3.waw3-1.cloudferro.com/swift/v1/project_dependencies/EURAC_pvfarm_rf_1_median_depth_15.zip",
dependencies_url="https://s3.waw3-1.cloudferro.com/swift/v1/project_dependencies/onnx_dependencies_1.16.3.zip")

setup_model_and_dependencies(
model_url="https://s3.waw3-1.cloudferro.com/swift/v1/project_dependencies/EURAC_pvfarm_rf_1_median_depth_15.zip",
dependencies_url="https://s3.waw3-1.cloudferro.com/swift/v1/project_dependencies/onnx_dependencies_1.16.3.zip",
)

# Add dependencies to the Python path
import onnxruntime as ort # Import after downloading dependencies


@functools.lru_cache(maxsize=5)
def load_onnx_model(model_name: str) -> ort.InferenceSession:
"""
Expand All @@ -90,6 +99,7 @@ def load_onnx_model(model_name: str) -> ort.InferenceSession:
f"{MODEL_DIR}/{model_name}", providers=["CPUExecutionProvider"]
)


def preprocess_input(
input_xr: xr.DataArray, ort_session: ort.InferenceSession
) -> tuple:
Expand All @@ -103,6 +113,7 @@ def preprocess_input(
input_np = input_np.astype(np.float32)
return input_np, input_shape


def run_inference(input_np: np.ndarray, ort_session: ort.InferenceSession) -> tuple:
"""
Run inference using the ONNX runtime session and return predicted labels and probabilities.
Expand All @@ -112,6 +123,7 @@ def run_inference(input_np: np.ndarray, ort_session: ort.InferenceSession) -> tu
predicted_labels = ort_outputs[0]
return predicted_labels


def postprocess_output(predicted_labels: np.ndarray, input_shape: tuple) -> tuple:
"""
Postprocess the output by reshaping the predicted labels and probabilities into the original spatial structure.
Expand All @@ -120,6 +132,7 @@ def postprocess_output(predicted_labels: np.ndarray, input_shape: tuple) -> tupl

return predicted_labels


def create_output_xarray(
predicted_labels: np.ndarray, input_xr: xr.DataArray
) -> xr.DataArray:
Expand All @@ -133,6 +146,7 @@ def create_output_xarray(
coords={"y": input_xr.coords["y"], "x": input_xr.coords["x"]},
)


def apply_model(input_xr: xr.DataArray) -> xr.DataArray:
"""
Run inference on the given input data using the provided ONNX runtime session.
Expand All @@ -143,7 +157,7 @@ def apply_model(input_xr: xr.DataArray) -> xr.DataArray:
# Step 1: Load the ONNX model
logger.info("load onnx model")
ort_session = load_onnx_model("EURAC_pvfarm_rf_1_median_depth_15.onnx")

# Step 2: Preprocess the input
logger.info("preprocess input")
input_np, input_shape = preprocess_input(input_xr, ort_session)
Expand Down
Loading