Skip to content

Commit

Permalink
Added some type hinting.
Browse files Browse the repository at this point in the history
Added iterative sampling of "missing" seafloor ages at slabs.

Ignored warnings for some plotting functions.
  • Loading branch information
thomas-schouten committed Nov 20, 2024
1 parent 5a8abaa commit efe7a03
Show file tree
Hide file tree
Showing 12 changed files with 435 additions and 300 deletions.
Binary file modified plato/__pycache__/optimisation.cpython-310.pyc
Binary file not shown.
Binary file modified plato/__pycache__/plates.cpython-310.pyc
Binary file not shown.
Binary file modified plato/__pycache__/plot.cpython-310.pyc
Binary file not shown.
Binary file modified plato/__pycache__/points.cpython-310.pyc
Binary file not shown.
Binary file modified plato/__pycache__/slabs.cpython-310.pyc
Binary file not shown.
1 change: 1 addition & 0 deletions plato/grids.py
Original file line number Diff line number Diff line change
Expand Up @@ -242,6 +242,7 @@ def generate_velocity_grid(
cases: Optional[str],
point_data: Dict[str, _numpy.ndarray],
components: Union[str, List[str]] = None,
PROGRESS_BAR: bool = True,
):
"""
Function to generate a velocity grid.
Expand Down
301 changes: 157 additions & 144 deletions plato/optimisation.py

Large diffs are not rendered by default.

319 changes: 175 additions & 144 deletions plato/plate_torques.py

Large diffs are not rendered by default.

6 changes: 5 additions & 1 deletion plato/plates.py
Original file line number Diff line number Diff line change
Expand Up @@ -128,7 +128,7 @@ def __init__(
# Check if any DataFrames were loaded
if len(available_cases) > 0:
# Copy all DataFrames from the available case
for entries in entry:
for entry in entries:
if entry not in available_cases:
self.resolved_geometries[_age][entry] = self.resolved_geometries[_age][available_cases[0]].copy()
else:
Expand Down Expand Up @@ -286,6 +286,10 @@ def calculate_rms_velocity(
# Select points belonging to plate
mask = points.data[_age][key].plateID == _plateID

if mask.sum() == 0:
logging.warning(f"No points found for plate {_plateID} for case {key} at {_age} Ma")
continue

# Calculate RMS velocity for plate
rms_velocity = utils_calc.compute_rms_velocity(
points.data[_age][key].segment_length_lat.values[mask],
Expand Down
14 changes: 6 additions & 8 deletions plato/plot.py
Original file line number Diff line number Diff line change
Expand Up @@ -1199,14 +1199,12 @@ def plot_vectors(
"""
# Normalise vectors, if necessary
if normalise_vectors and vector_mag is not None:
# Normalise by dividing by the magnitude of the vectors
# Multiply by 10 to make the vectors more visible
vector_lon = vector_lon / vector_mag * 10
vector_lat = vector_lat / vector_mag * 10

print(_numpy.nanmean(_numpy.sqrt(vector_lon**2 + vector_lat**2)))

print("Normalised vectors!")
with warnings.catch_warnings():
warnings.simplefilter("ignore")
# Normalise by dividing by the magnitude of the vectors
# Multiply by 10 to make the vectors more visible
vector_lon = vector_lon / vector_mag * 10
vector_lat = vector_lat / vector_mag * 10

# Plot vectors
# Ignore annoying warnings
Expand Down
12 changes: 12 additions & 0 deletions plato/points.py
Original file line number Diff line number Diff line change
Expand Up @@ -411,6 +411,10 @@ def calculate_gpe_force(
# Select points
if plateIDs is not None:
_data = _data[_data.plateID.isin(_plateIDs)]

if _data.empty:
logging.info(f"No valid points found for case {key} at age {_age} Ma.")
continue

# Calculate GPE force
computed_data = utils_calc.compute_GPE_force(
Expand Down Expand Up @@ -482,6 +486,10 @@ def calculate_mantle_drag_force(
# Select points
if plateIDs is not None:
_data = _data[_data.plateID.isin(_plateIDs)]

if _data.empty:
logging.info(f"No valid points found for case {key} at age {_age} Ma.")
continue

# Calculate GPE force
_data = utils_calc.compute_mantle_drag_force(
Expand Down Expand Up @@ -554,6 +562,10 @@ def calculate_residual_force(

# Make mask for plate
mask = self.data[_age][_case]["plateID"] == _plateID

if mask.sum() == 0:
logging.info(f"No valid points found for age {_age}, case {_case}, and plateID {_plateID}.")
continue

# Compute velocities
forces = utils_calc.compute_residual_force(
Expand Down
82 changes: 79 additions & 3 deletions plato/slabs.py
Original file line number Diff line number Diff line change
Expand Up @@ -288,6 +288,7 @@ def sample_slab_seafloor_ages(
plate = "lower",
vars = ["seafloor_age"],
cols = ["slab_seafloor_age"],
ITERATIONS = True,
PROGRESS_BAR = PROGRESS_BAR,
)

Expand Down Expand Up @@ -457,6 +458,7 @@ def sample_grid(
plate: Optional[str] = "lower",
vars: Optional[Union[str, List[str]]] = ["seafloor_age"],
cols = ["slab_seafloor_age"],
ITERATIONS: bool = False,
PROGRESS_BAR: bool = True,
):
"""
Expand Down Expand Up @@ -496,7 +498,7 @@ def sample_grid(
# Select points
_data = self.data[_age][key]
if plateIDs is not None:
_data = _data[_data.plateID.isin(_plateIDs)]
_data = _data[_data[f"{plate}_plateID"].isin(_plateIDs)]

# Determine the appropriate grid
_grid = None
Expand Down Expand Up @@ -560,6 +562,50 @@ def sample_grid(
)
accumulated_data += sampled_data

# This is to iteratively check if the sampling distance is to be adjusted
# This is especially a problem on the western active margin of South America in the Earthbyte reconstructions, where the continental mask does not account for the motion of the trench
if ITERATIONS:
if plate == "lower":
current_sampling_distance = -30
iter_num = 20
if plate == "upper":
current_sampling_distance = +100
iter_num = 4

for i in range(iter_num):
# Mask data
mask = _numpy.isnan(accumulated_data)

# Set masked data to zero to avoid errors
accumulated_data[mask] = 0

# Calculate new sampling points
sampling_lat, sampling_lon = utils_calc.project_points(
_data.loc[_data.index[mask], f"{type}_sampling_lat"],
_data.loc[_data.index[mask], f"{type}_sampling_lat"],
_data.loc[_data.index[mask], "trench_normal_azimuth"],
current_sampling_distance,
)

# Sample grid at points for each variable
for _var in _vars:
sampled_data[mask] = utils_calc.sample_grid(
sampling_lat,
sampling_lon,
_grid[_var],
)
accumulated_data[mask] += sampled_data[mask]

# Define new sampling distance
if plate == "lower":
if i <= 1:
current_sampling_distance -= 30
elif i % 2 == 0:
current_sampling_distance -= 30 * (2 ** (i // 2))

if plate == "upper":
current_sampling_distance += 100

# Enter sampled data back into the DataFrame
self.data[_age][key].loc[_data.index, _col] = accumulated_data

Expand Down Expand Up @@ -687,6 +733,10 @@ def calculate_slab_pull_force(
# Select points
if plateIDs is not None:
_data = _data[_data.lower_plateID.isin(_plateIDs)]

if _data.empty:
logging.warning(f"No valid points found for case {key} Ma.")
continue

# Calculate slab pull force
computed_data1 = utils_calc.compute_slab_pull_force(
Expand Down Expand Up @@ -762,6 +812,10 @@ def calculate_slab_bend_force(
# Select points
if plateIDs is not None:
_data = _data[_data.lower_plateID.isin(_plateIDs)]

if _data.empty:
logging.warning(f"No valid points found for case {key} Ma.")
continue

# Calculate slab pull force
_data = utils_calc.compute_slab_bend_force(
Expand Down Expand Up @@ -832,6 +886,10 @@ def calculate_residual_force(

# Make mask for plate
mask = self.data[_age][_case]["lower_plateID"] == _plateID

if mask.sum() == 0:
logging.info(f"No valid points found for age {_age}, case {_case}, and plateID {_plateID}.")
continue

# Compute velocities
forces = utils_calc.compute_residual_force(
Expand Down Expand Up @@ -951,8 +1009,17 @@ def save(
):
# Loop through cases
for _case in _cases:
# Select data
_data = self.data[_age][_case]

# Subselect data, if plateIDs are provided
if plateIDs is not None:
_plateIDs = utils_data.select_plateIDs(plateIDs, _data.lower_plateID.unique())
_data = _data[_data.lower_plateID.isin(_plateIDs)]

# Save data
utils_data.DataFrame_to_parquet(
self.data[_age][_case],
_data,
"Slabs",
self.settings.name,
_age,
Expand Down Expand Up @@ -993,8 +1060,17 @@ def export(
):
# Loop through cases
for _case in _cases:
# Select data
_data = self.data[_age][_case]

# Subselect data, if plateIDs are provided
if plateIDs is not None:
_plateIDs = utils_data.select_plateIDs(plateIDs, _data.lower_plateID.unique())
_data = _data[_data.lower_plateID.isin(_plateIDs)]

# Export data
utils_data.DataFrame_to_csv(
self.data[_age][_case],
_data,
"Slabs",
self.settings.name,
_age,
Expand Down

0 comments on commit efe7a03

Please sign in to comment.