From a79221a2f94820a1216037250b37bf1ab9c388e9 Mon Sep 17 00:00:00 2001 From: YannisZa Date: Mon, 22 Apr 2024 15:46:55 +0100 Subject: [PATCH] added metric evaluation --- .../Reading outputs (work in progress).ipynb | 273 +++++++++++------- 1 file changed, 161 insertions(+), 112 deletions(-) diff --git a/notebooks/Reading outputs (work in progress).ipynb b/notebooks/Reading outputs (work in progress).ipynb index f314def..bd1db71 100644 --- a/notebooks/Reading outputs (work in progress).ipynb +++ b/notebooks/Reading outputs (work in progress).ipynb @@ -2,19 +2,10 @@ "cells": [ { "cell_type": "code", - "execution_count": 1, + "execution_count": null, "id": "c76b8620", "metadata": {}, - "outputs": [ - { - "name": "stderr", - "output_type": "stream", - "text": [ - "/tmp/ipykernel_282017/2303974116.py:12: DeprecationWarning: Importing display from IPython.core.display is deprecated since IPython 7.14, please import from IPython display\n", - " from IPython.core.display import display, HTML\n" - ] - } - ], + "outputs": [], "source": [ "import os\n", "import glob\n", @@ -41,7 +32,7 @@ }, { "cell_type": "code", - "execution_count": 2, + "execution_count": null, "id": "b718b815", "metadata": {}, "outputs": [], @@ -218,7 +209,7 @@ }, { "cell_type": "code", - "execution_count": 3, + "execution_count": null, "id": "e21669cd", "metadata": {}, "outputs": [], @@ -246,7 +237,7 @@ }, { "cell_type": "code", - "execution_count": 4, + "execution_count": null, "id": "41e02907", "metadata": {}, "outputs": [], @@ -273,59 +264,17 @@ " \"group_by\":[],\n", " \"filename_ending\":\"test\",\n", " \"sample\":[\"intensity\",\"table\"],\n", - " \"validation_data\":{\"test_cells\":\"../data/inputs/DC/train_cells.txt\"},\n", + " \"validation_data\":{\"test_cells\":\"../data/inputs/DC/test_cells.txt\"},\n", " \"force_reload\":False\n", "}" ] }, { "cell_type": "code", - "execution_count": 5, + "execution_count": null, "id": "833a9fad", "metadata": {}, - "outputs": [ - { - "name": "stderr", - "output_type": "stream", - "text": [ - "05:32.509 config INFO ----------------------------------------------------------------------------------\n", - "05:32.517 config INFO Parameter space size: \n", - " --- sigma: ['sigma', 'to_learn'] (3)\n", - "05:32.526 config INFO Total = 3.\n", - "05:32.534 config INFO ----------------------------------------------------------------------------------\n", - "05:32.553 outputs INFO //////////////////////////////////////////////////////////////////////////////////\n", - "05:32.561 outputs INFO Slicing coordinates:\n", - "05:32.570 outputs INFO loss_name == str(['dest_attraction_ts_likelihood_loss'])\n", - "05:32.578 outputs INFO //////////////////////////////////////////////////////////////////////////////////\n", - "05:32.587 outputs INFO Reading samples alpha, beta, log_destination_attraction, table.\n", - "05:54.974 outputs INFO Creating Data Collection for each group. \n", - "Grouping/Initialising Data Collection samples sequentially: 100%|██████████| 12/12 [00:00<00:00, 69615.00it/s]\n", - "Combining Data Collection group elements: 100%|██████████| 3/3 [00:00<00:00, 82782.32it/s]\n", - "Combining Data Collection group elements: 100%|██████████| 3/3 [00:00<00:00, 73584.28it/s]\n", - "Combining Data Collection group elements: 100%|██████████| 3/3 [00:00<00:00, 73584.28it/s]\n", - "Combining Data Collection group elements: 100%|██████████| 3/3 [00:00<00:00, 73156.47it/s]\n", - "Slicing coordinates sequentially: 25%|██▌ | 3/12 [00:00<00:00, 10.59it/s]05:55.284 outputs INFO table: 12 collection ids kept out of 12.\n", - "05:55.292 outputs INFO log_destination_attraction: 12 collection ids kept out of 12.\n", - "05:55.300 outputs INFO beta: 12 collection ids kept out of 12.\n", - "05:55.308 outputs INFO alpha: 12 collection ids kept out of 12.\n", - " " - ] - }, - { - "name": "stdout", - "output_type": "stream", - "text": [ - "3 experiments matched\n" - ] - }, - { - "name": "stderr", - "output_type": "stream", - "text": [ - "\r" - ] - } - ], + "outputs": [], "source": [ "# Initialise outputs\n", "current_sweep_outputs = Outputs(\n", @@ -453,24 +402,10 @@ }, { "cell_type": "code", - "execution_count": 6, + "execution_count": null, "id": "d5eb796d", "metadata": {}, - "outputs": [ - { - "name": "stdout", - "output_type": "stream", - "text": [ - "# Sweeps: 3\n", - "ItemsView(Coordinates:\n", - " * id (id) object MultiIndex\n", - " * iter (id) int32 1 2 3 4 5 6 7 ... 99995 99996 99997 99998 99999 100000\n", - " * sweep (sweep) object MultiIndex\n", - " * sigma (sweep) object 'none'\n", - " * to_learn (sweep) object \"['alpha', 'beta', 'sigma']\")\n" - ] - } - ], + "outputs": [], "source": [ "index = 0\n", "current_data = current_sweep_outputs.get(index)\n", @@ -480,7 +415,7 @@ }, { "cell_type": "code", - "execution_count": 7, + "execution_count": null, "id": "fa8fb841", "metadata": {}, "outputs": [], @@ -489,60 +424,174 @@ " config = current_data.config\n", ")\n", "ins.cast_to_xarray()\n", - "test_cells = current_data.get_sample('test_cells')" + "test_cells = current_data.get_sample('test_cells')\n", + "train_cells = current_data.get_sample('train_cells')" ] }, { "cell_type": "code", - "execution_count": 91, + "execution_count": null, "id": "1a1e911f", "metadata": {}, "outputs": [], "source": [ - "test_error = srmse(\n", - " prediction = current_data.data.table.mean('id'),\n", + "all_table_error = srmse(\n", + " prediction = current_data.data.table.mean('id',dtype='float64'),\n", + " ground_truth = ins.data.ground_truth_table\n", + ")\n", + "train_table_error = srmse(\n", + " prediction = current_data.data.table.mean('id',dtype='float64'),\n", " ground_truth = ins.data.ground_truth_table,\n", - " test_cells = test_cells\n", + " cells = train_cells\n", ")\n", - "all_error = srmse(\n", - " prediction = current_data.data.table.mean('id'),\n", + "test_table_error = srmse(\n", + " prediction = current_data.data.table.mean('id',dtype='float64'),\n", + " ground_truth = ins.data.ground_truth_table,\n", + " cells = test_cells\n", + ")" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "id": "cc905be7", + "metadata": {}, + "outputs": [], + "source": [ + "print(\n", + " all_table_error.values.squeeze().item(),\n", + " train_table_error.values.squeeze().item(),\n", + " test_table_error.values.squeeze().item()\n", + ")" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "id": "c7090ffd", + "metadata": {}, + "outputs": [], + "source": [ + "all_intensity_error = srmse(\n", + " prediction = current_data.get_sample('intensity').mean('id',dtype='float64'),\n", " ground_truth = ins.data.ground_truth_table\n", + ")\n", + "train_intensity_error = srmse(\n", + " prediction = current_data.get_sample('intensity').mean('id',dtype='float64'),\n", + " ground_truth = ins.data.ground_truth_table,\n", + " cells = train_cells\n", + ")\n", + "test_intensity_error = srmse(\n", + " prediction = current_data.get_sample('intensity').mean('id',dtype='float64'),\n", + " ground_truth = ins.data.ground_truth_table,\n", + " cells = test_cells\n", ")" ] }, { "cell_type": "code", - "execution_count": 90, + "execution_count": null, + "id": "dcf96b7f", + "metadata": {}, + "outputs": [], + "source": [ + "print(\n", + " all_intensity_error.values.squeeze().item(),\n", + " train_intensity_error.values.squeeze().item(),\n", + " test_intensity_error.values.squeeze().item()\n", + ")" + ] + }, + { + "cell_type": "code", + "execution_count": null, "id": "ca0eecab", "metadata": {}, - "outputs": [ - { - "ename": "KeyboardInterrupt", - "evalue": "", - "output_type": "error", - "traceback": [ - "\u001b[0;31m---------------------------------------------------------------------------\u001b[0m", - "\u001b[0;31mKeyboardInterrupt\u001b[0m Traceback (most recent call last)", - "Input \u001b[0;32mIn [90]\u001b[0m, in \u001b[0;36m\u001b[0;34m()\u001b[0m\n\u001b[0;32m----> 1\u001b[0m test_cp \u001b[38;5;241m=\u001b[39m \u001b[43mcoverage_probability\u001b[49m\u001b[43m(\u001b[49m\n\u001b[1;32m 2\u001b[0m \u001b[43m \u001b[49m\u001b[43mprediction\u001b[49m\u001b[43m \u001b[49m\u001b[38;5;241;43m=\u001b[39;49m\u001b[43m \u001b[49m\u001b[43mcurrent_data\u001b[49m\u001b[38;5;241;43m.\u001b[39;49m\u001b[43mdata\u001b[49m\u001b[38;5;241;43m.\u001b[39;49m\u001b[43mtable\u001b[49m\u001b[43m,\u001b[49m\n\u001b[1;32m 3\u001b[0m \u001b[43m \u001b[49m\u001b[43mground_truth\u001b[49m\u001b[43m \u001b[49m\u001b[38;5;241;43m=\u001b[39;49m\u001b[43m \u001b[49m\u001b[43mins\u001b[49m\u001b[38;5;241;43m.\u001b[39;49m\u001b[43mdata\u001b[49m\u001b[38;5;241;43m.\u001b[39;49m\u001b[43mground_truth_table\u001b[49m\u001b[43m,\u001b[49m\n\u001b[1;32m 4\u001b[0m \u001b[43m \u001b[49m\u001b[43mregion_mass\u001b[49m\u001b[43m \u001b[49m\u001b[38;5;241;43m=\u001b[39;49m\u001b[43m \u001b[49m\u001b[38;5;241;43m0.95\u001b[39;49m\u001b[43m,\u001b[49m\n\u001b[1;32m 5\u001b[0m \u001b[43m \u001b[49m\u001b[43mtest_cells\u001b[49m\u001b[43m \u001b[49m\u001b[38;5;241;43m=\u001b[39;49m\u001b[43m \u001b[49m\u001b[43mtest_cells\u001b[49m\n\u001b[1;32m 6\u001b[0m \u001b[43m)\u001b[49m\n\u001b[1;32m 7\u001b[0m all_cp \u001b[38;5;241m=\u001b[39m coverage_probability(\n\u001b[1;32m 8\u001b[0m prediction \u001b[38;5;241m=\u001b[39m current_data\u001b[38;5;241m.\u001b[39mdata\u001b[38;5;241m.\u001b[39mtable,\n\u001b[1;32m 9\u001b[0m ground_truth \u001b[38;5;241m=\u001b[39m ins\u001b[38;5;241m.\u001b[39mdata\u001b[38;5;241m.\u001b[39mground_truth_table,\n\u001b[1;32m 10\u001b[0m region_mass \u001b[38;5;241m=\u001b[39m \u001b[38;5;241m0.95\u001b[39m\n\u001b[1;32m 11\u001b[0m )\n", - "File \u001b[0;32m~/GeNSIT/gensit/utils/math_utils.py:323\u001b[0m, in \u001b[0;36mcoverage_probability\u001b[0;34m(prediction, ground_truth, **kwargs)\u001b[0m\n\u001b[1;32m 320\u001b[0m stacked_dims \u001b[38;5;241m=\u001b[39m deepcopy(prediction\u001b[38;5;241m.\u001b[39mdims)\n\u001b[1;32m 322\u001b[0m \u001b[38;5;66;03m# Sort all samples by iteration-seed\u001b[39;00m\n\u001b[0;32m--> 323\u001b[0m prediction[:] \u001b[38;5;241m=\u001b[39m \u001b[43mnp\u001b[49m\u001b[38;5;241;43m.\u001b[39;49m\u001b[43msort\u001b[49m\u001b[43m(\u001b[49m\u001b[43mprediction\u001b[49m\u001b[38;5;241;43m.\u001b[39;49m\u001b[43mvalues\u001b[49m\u001b[43m,\u001b[49m\u001b[43m \u001b[49m\u001b[43maxis\u001b[49m\u001b[43m \u001b[49m\u001b[38;5;241;43m=\u001b[39;49m\u001b[43m \u001b[49m\u001b[43mstacked_dims\u001b[49m\u001b[38;5;241;43m.\u001b[39;49m\u001b[43mindex\u001b[49m\u001b[43m(\u001b[49m\u001b[38;5;124;43m'\u001b[39;49m\u001b[38;5;124;43mid\u001b[39;49m\u001b[38;5;124;43m'\u001b[39;49m\u001b[43m)\u001b[49m\u001b[43m)\u001b[49m\n\u001b[1;32m 325\u001b[0m \u001b[38;5;66;03m# Get lower and upper bound high posterior density regions\u001b[39;00m\n\u001b[1;32m 326\u001b[0m lower_bound_hpdr,upper_bound_hpdr \u001b[38;5;241m=\u001b[39m calculate_min_interval(\n\u001b[1;32m 327\u001b[0m prediction,\n\u001b[1;32m 328\u001b[0m alpha\n\u001b[1;32m 329\u001b[0m )\n", - "File \u001b[0;32m<__array_function__ internals>:180\u001b[0m, in \u001b[0;36msort\u001b[0;34m(*args, **kwargs)\u001b[0m\n", - "File \u001b[0;32m~/miniconda3/envs/gensit/lib/python3.10/site-packages/numpy/core/fromnumeric.py:1003\u001b[0m, in \u001b[0;36msort\u001b[0;34m(a, axis, kind, order)\u001b[0m\n\u001b[1;32m 1001\u001b[0m axis \u001b[38;5;241m=\u001b[39m \u001b[38;5;241m-\u001b[39m\u001b[38;5;241m1\u001b[39m\n\u001b[1;32m 1002\u001b[0m \u001b[38;5;28;01melse\u001b[39;00m:\n\u001b[0;32m-> 1003\u001b[0m a \u001b[38;5;241m=\u001b[39m \u001b[43masanyarray\u001b[49m\u001b[43m(\u001b[49m\u001b[43ma\u001b[49m\u001b[43m)\u001b[49m\u001b[38;5;241;43m.\u001b[39;49m\u001b[43mcopy\u001b[49m\u001b[43m(\u001b[49m\u001b[43morder\u001b[49m\u001b[38;5;241;43m=\u001b[39;49m\u001b[38;5;124;43m\"\u001b[39;49m\u001b[38;5;124;43mK\u001b[39;49m\u001b[38;5;124;43m\"\u001b[39;49m\u001b[43m)\u001b[49m\n\u001b[1;32m 1004\u001b[0m a\u001b[38;5;241m.\u001b[39msort(axis\u001b[38;5;241m=\u001b[39maxis, kind\u001b[38;5;241m=\u001b[39mkind, order\u001b[38;5;241m=\u001b[39morder)\n\u001b[1;32m 1005\u001b[0m \u001b[38;5;28;01mreturn\u001b[39;00m a\n", - "\u001b[0;31mKeyboardInterrupt\u001b[0m: " - ] - } - ], - "source": [ - "# test_cp = coverage_probability(\n", - "# prediction = current_data.data.table,\n", - "# ground_truth = ins.data.ground_truth_table,\n", - "# region_mass = 0.95,\n", - "# test_cells = test_cells\n", - "# )\n", - "# all_cp = coverage_probability(\n", - "# prediction = current_data.data.table,\n", - "# ground_truth = ins.data.ground_truth_table,\n", - "# region_mass = 0.95\n", - "# )" + "outputs": [], + "source": [ + "all_table_cp = coverage_probability(\n", + " prediction = current_data.data.table,\n", + " ground_truth = ins.data.ground_truth_table,\n", + " region_mass = 0.95\n", + ")\n", + "train_table_cp = coverage_probability(\n", + " prediction = current_data.data.table,\n", + " ground_truth = ins.data.ground_truth_table,\n", + " region_mass = 0.95,\n", + " cells = train_cells\n", + ")\n", + "test_table_cp = coverage_probability(\n", + " prediction = current_data.data.table,\n", + " ground_truth = ins.data.ground_truth_table,\n", + " region_mass = 0.95,\n", + " cells = test_cells\n", + ")\n" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "id": "ed196756", + "metadata": {}, + "outputs": [], + "source": [ + "all_cp = all_table_cp\n", + "test_cp = train_table_cp\n", + "test_cp = test_table_cp" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "id": "f6404451", + "metadata": {}, + "outputs": [], + "source": [ + "print(\n", + " all_table_cp.mean(['origin','destination'],skipna=True).values.item(),\n", + " train_table_cp.mean(['origin','destination'],skipna=True).values.item(),\n", + " test_table_cp.mean(['origin','destination'],skipna=True).values.item()\n", + ")" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "id": "3c60450c", + "metadata": {}, + "outputs": [], + "source": [ + "all_intensity_cp = coverage_probability(\n", + " prediction = current_data.get_sample('intensity'),\n", + " ground_truth = ins.data.ground_truth_table,\n", + " region_mass = 0.95\n", + ")\n", + "train_intensity_cp = coverage_probability(\n", + " prediction = current_data.get_sample('intensity'),\n", + " ground_truth = ins.data.ground_truth_table,\n", + " region_mass = 0.95,\n", + " cells = train_cells\n", + ")\n", + "test_intensity_cp = coverage_probability(\n", + " prediction = current_data.get_sample('intensity'),\n", + " ground_truth = ins.data.ground_truth_table,\n", + " region_mass = 0.95,\n", + " cells = test_cells\n", + ")" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "id": "30e55924", + "metadata": {}, + "outputs": [], + "source": [ + "print(\n", + " all_intensity_cp.mean(['origin','destination'],skipna=True).values.item(),\n", + " train_intensity_cp.mean(['origin','destination'],skipna=True).values.item(),\n", + " test_intensity_cp.mean(['origin','destination'],skipna=True).values.item()\n", + ")" ] }, {