diff --git a/sdb_gui.py b/sdb_gui.py index 36fcf6c..8b57b8b 100644 --- a/sdb_gui.py +++ b/sdb_gui.py @@ -39,6 +39,7 @@ import numpy as np import geopandas as gpd import rasterio as rio +import matplotlib.pyplot as plt from pathlib import Path import sys, os import datetime @@ -66,7 +67,7 @@ ############################################################################### ############################################################################### -SDB_GUI_VERSION = '3.3.2' +SDB_GUI_VERSION = '3.4.1' def resource_path(relative_path): '''Get the absolute path to the resource, works for dev and for PyInstaller''' @@ -943,6 +944,18 @@ def timeCounting(self, time_text): self.completeDialog() + def scatter_plotter(self, x, y, plot_color='royalblue', line_color='r', title='Scatter Plot'): + fig, ax = plt.subplots(figsize=(5, 5)) + ax.scatter(x, y, marker='.', color=plot_color, facecolors='none') + min_val, max_val = round(np.nanmin(x)), round(np.nanmax(x)) + ax.plot([min_val, max_val], [min_val, max_val], color=line_color) + ax.set_xlabel('True Depth') + ax.set_ylabel('Predicted Depth') + ax.set_title(title) + + return fig, ax + + def results(self, result_dict): ''' Recieve processing results and filter the predicted value to depth @@ -1132,6 +1145,9 @@ def saveOptionWindow(self): locLabel = QLabel('Location:') self.savelocList = QTextBrowser() + self.scatterPlotCheckBox = QCheckBox('Save Scatter Plot') + self.scatterPlotCheckBox.setChecked(False) + self.trainTestDataCheckBox = QCheckBox('Save Training and Testing Data in') self.trainTestDataCheckBox.setChecked(False) @@ -1165,14 +1181,16 @@ def saveOptionWindow(self): grid.addWidget(locLabel, 4, 1, 1, 4) grid.addWidget(self.savelocList, 5, 1, 1, 4) - grid.addWidget(self.trainTestDataCheckBox, 6, 1, 1, 2) - grid.addWidget(self.trainTestFormatCB, 6, 3, 1, 1) - grid.addWidget(trainTestLabel, 6, 4, 1, 1) + grid.addWidget(self.scatterPlotCheckBox, 6, 1, 1, 2) + + grid.addWidget(self.trainTestDataCheckBox, 7, 1, 1, 2) + grid.addWidget(self.trainTestFormatCB, 7, 3, 1, 1) + grid.addWidget(trainTestLabel, 7, 4, 1, 1) - grid.addWidget(self.saveDEMCheckBox, 7, 1, 1, 1) - grid.addWidget(self.reportCheckBox, 7, 2, 1, 1) - grid.addWidget(saveButton, 7, 3, 1, 1) - grid.addWidget(cancelButton, 7, 4, 1, 1) + grid.addWidget(self.saveDEMCheckBox, 8, 1, 1, 1) + grid.addWidget(self.reportCheckBox, 8, 2, 1, 1) + grid.addWidget(saveButton, 8, 3, 1, 1) + grid.addWidget(cancelButton, 8, 4, 1, 1) self.saveOptionDialog.setLayout(grid) @@ -1276,8 +1294,29 @@ def saveAction(self): 'Test Data output:\tNot Saved\n' ) + if self.scatterPlotCheckBox.isChecked() == True: + scatter_plot_loc = ( + os.path.splitext(self.savelocList.toPlainText())[0] + + '_scatter_plot.png' + ) + scatter_plot_fig, scatter_plot_ax = self.scatter_plotter( + x=test_data_df.z, + y=test_data_df.z_validate + ) + scatter_plot_fig.savefig(scatter_plot_loc) + + scatter_plot_size = os.path.getsize(scatter_plot_loc) + + print_scatter_plot_info = ( + 'Scatter Plot:\t' + scatter_plot_loc + ' (' + + str(round(scatter_plot_size / 2**10, 2)) + ' KB)\n' + ) + elif self.scatterPlotCheckBox.isChecked() == False: + print_scatter_plot_info = 'Scatter Plot:\tNotSaved\n' + self.resultText.append(print_dem_info) self.resultText.append(print_train_test_info) + self.resultText.append(print_scatter_plot_info) if self.reportCheckBox.isChecked() == True: report_save_loc = ( @@ -1289,7 +1328,8 @@ def saveAction(self): report.write( print_result_info + print_dem_info + - print_train_test_info + print_train_test_info + + print_scatter_plot_info ) except: self.saveOptionDialog.close()