Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

changing evml imports to mlguess and some formatting updates #37

Merged
merged 1 commit into from
Feb 21, 2024
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
2 changes: 1 addition & 1 deletion ptype/callbacks.py
Original file line number Diff line number Diff line change
Expand Up @@ -5,7 +5,7 @@
CSVLogger,
EarlyStopping,
)
from evml.keras.models import calc_prob_uncertainty
from mlguess.keras.models import calc_prob_uncertainty
from tensorflow.python.keras.callbacks import ReduceLROnPlateau
from sklearn.metrics import precision_recall_fscore_support, roc_auc_score
from hagelslag.evaluation.ProbabilityMetrics import DistributedROC
Expand Down
4 changes: 2 additions & 2 deletions ptype/models.py
Original file line number Diff line number Diff line change
Expand Up @@ -13,8 +13,8 @@

from imblearn.under_sampling import RandomUnderSampler
from imblearn.tensorflow import balanced_batch_generator
from evml.keras.losses import DirichletEvidentialLoss
from evml.keras.callbacks import ReportEpoch
from mlguess.keras.losses import DirichletEvidentialLoss
from mlguess.keras.callbacks import ReportEpoch


logger = logging.getLogger(__name__)
Expand Down
1 change: 0 additions & 1 deletion ptype/qc.py
Original file line number Diff line number Diff line change
Expand Up @@ -8,7 +8,6 @@
import xarray as xr
import metpy.calc
from metpy.units import units
import pandas as pd
import geopandas as gpd
from shapely.geometry import Point
import cartopy.io.shapereader as shpreader
Expand Down
1 change: 0 additions & 1 deletion ptype/reliability.py
Original file line number Diff line number Diff line change
@@ -1,4 +1,3 @@
import os
import numpy as np
import matplotlib.pyplot as plt

Expand Down
7 changes: 4 additions & 3 deletions ptype/seed.py
Original file line number Diff line number Diff line change
Expand Up @@ -4,19 +4,20 @@
import tensorflow as tf
import torch


def seed_everything(seed=1234):
random.seed(seed)
os.environ['PYTHONHASHSEED'] = str(seed)
np.random.seed(seed)
tf.keras.utils.set_random_seed(1)
tf.config.experimental.enable_op_determinism()


def torch_seed_everything(seed=1234):
random.seed(seed)
os.environ['PYTHONHASHSEED'] = str(seed)
np.random.seed(seed)
torch.manual_seed(seed)
torch.cuda.manual_seed(seed)
torch.backends.cudnn.benchmark = True
torch.backends.cudnn.deterministic = True
torch.backends.cudnn.deterministic = True
16 changes: 9 additions & 7 deletions ptype/visualization_utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -16,9 +16,8 @@
from cartopy import crs as ccrs
from cartopy import feature as cfeature
import imageio
from PIL import Image
from pathlib import Path
from datetime import datetime, timedelta
from datetime import datetime
import xarray as xr
from scipy.ndimage.filters import gaussian_filter

Expand All @@ -34,14 +33,17 @@
"okla":[39.0, 31.0, -90.0, -106.0]}

# colors = {0:'lime', 1:'darkturquoise', 2:'red', 3:'black'}
colors = {0:'lime', 1:'dodgerblue', 2:'red', 3:'black'}
colors = {0: 'lime', 1: 'dodgerblue', 2: 'red', 3: 'black'}
datapath = "/glade/p/cisl/aiml/ai2es/winter_ptypes/precip_rap/"

def ptype_map(datatype, starttime, endtime, gifname, imgsavepath="gif_images", gifsavepath="gifs", coords="na", duration=0.5):

def ptype_map(datatype, starttime, endtime, gifname,
imgsavepath="gif_images", gifsavepath="gifs", coords="na",
duration=0.5):
"""
Create and save GIF of P-Type data over specific CONUS region and time range.
:param datatype:

:param datatype:
:param starttime:
:param endtime:
:param gifname:
Expand All @@ -55,7 +57,7 @@ def ptype_map(datatype, starttime, endtime, gifname, imgsavepath="gif_images", g
enddate = datetime.strptime(endtime, "%Y%m%d %H:%M:%S")
time_range = pd.date_range(startdate, enddate, freq="h").strftime("%Y%m%d %H:%M:%S")
coords = coord_dict[coords]

# Account for differences between mPING and ASOS.
if datatype == "mping":
if enddate >= datetime.strptime("20180101", "%Y%m%d"):
Expand Down
Loading