Skip to content

Commit

Permalink
day 7 DP solution
Browse files Browse the repository at this point in the history
  • Loading branch information
vsedov committed Dec 7, 2024
1 parent c4d74d6 commit b6377a9
Show file tree
Hide file tree
Showing 2 changed files with 53 additions and 54 deletions.
101 changes: 50 additions & 51 deletions src/aoc/aoc2024/day_07.py
Original file line number Diff line number Diff line change
@@ -1,10 +1,9 @@
from functools import partial
from itertools import product
from typing import Callable, Set, Tuple
from typing import Tuple

import numpy as np
from numba import njit, prange
from numba.typed import List
from numba.typed import Dict
from numba.typed import List as NumbaList

from src.aoc.aoc2024 import YEAR, get_day
from src.aoc.aoc_helper import Aoc
Expand All @@ -27,70 +26,70 @@ def concat_nums(a: int, b: int) -> int:


@njit(cache=True)
def evaluate(nums: np.ndarray, ops: np.ndarray, with_concat: bool = False) -> int:
result = int(nums[0])
for i in range(len(ops)):
val = int(nums[i + 1])
if ops[i] == 0:
result += val
elif ops[i] == 1:
result *= val
elif with_concat and ops[i] == 2:
result = concat_nums(result, val)
return result


@njit(cache=True, parallel=True)
def check_ops_batch(
nums: np.ndarray, all_ops: np.ndarray, target: int, with_concat: bool
) -> np.ndarray:
results = np.zeros(len(all_ops), dtype=np.bool_)
for i in prange(len(all_ops)):
results[i] = evaluate(nums, all_ops[i], with_concat) == target
return results


def solve(txt: str, eval_func: Callable, n_ops: int = 2) -> int:
def parse_line(line: str) -> Tuple[int, np.ndarray]:
target, nums_str = line.split(": ")
return int(target), np.array([int(x) for x in nums_str.split()], dtype=np.int64)

def dp_line(nums: np.ndarray, target: int, with_concat: bool) -> bool:
prev_reachable = Dict.empty(key_type=np.int64, value_type=np.bool_)
prev_reachable[nums[0]] = True
for i in range(1, len(nums)):
current_num = nums[i]
next_reachable = Dict.empty(key_type=np.int64, value_type=np.bool_)
for val in prev_reachable.keys():
new_val = val + current_num
if abs(new_val) < 10**12:
next_reachable[new_val] = True
new_val = val * current_num
if abs(new_val) < 10**12:
next_reachable[new_val] = True
if with_concat:
new_val = concat_nums(val, current_num)
if abs(new_val) < 10**12:
next_reachable[new_val] = True
prev_reachable = next_reachable
return target in prev_reachable


@njit(cache=True, parallel=True, fastmath=True)
def dp_solve_all(
targets: np.ndarray, all_nums: NumbaList[np.ndarray], with_concat: bool
) -> int:
total = 0
ops_cache = {}
for i in prange(len(targets)):
if dp_line(all_nums[i], targets[i], with_concat):
total += targets[i]
return total

for target, nums in map(parse_line, txt.splitlines()):
n = len(nums) - 1

if n not in ops_cache:
ops = list(product(range(n_ops), repeat=n))
ops_cache[n] = np.array(ops, dtype=np.int64)
def solve_with_dp(txt: str, with_concat: bool = False) -> int:
lines = txt.strip().split("\n")
targets_list = []
nums_list = []
for line in lines:
target_str, nums_str = line.split(": ")
target = int(target_str)
nums = np.array([int(x) for x in nums_str.split()], dtype=np.int64)
targets_list.append(target)
nums_list.append(nums)

results = check_ops_batch(
nums, ops_cache[n], target, eval_func.keywords["with_concat"]
)
if np.any(results):
total += target
targets_arr = np.array(targets_list, dtype=np.int64)
typed_nums_list = NumbaList()
for arr in nums_list:
typed_nums_list.append(arr)

return total
return dp_solve_all(targets_arr, typed_nums_list, with_concat)


def part_a(txt: str) -> int:
return solve(txt, partial(evaluate, with_concat=False), n_ops=2)
return solve_with_dp(txt, with_concat=False)


def part_b(txt: str) -> int:
return solve(txt, partial(evaluate, with_concat=True), n_ops=3)
return solve_with_dp(txt, with_concat=True)


def main(txt: str) -> None:
dummy_nums = np.array([1, 2], dtype=np.int64)
dummy_ops = np.array([[0]], dtype=np.int64)
check_ops_batch(dummy_nums, dummy_ops, 3, False)

print("part_a: ", part_a(txt))
print("part_b: ", part_b(txt))


if __name__ == "__main__":
aoc = Aoc(day=get_day(), years=YEAR)
aoc.run(main, submit=False, part="both", readme_update=True, profile=False)
aoc.run(main, submit=False, part="both", readme_update=True, profile=True)
6 changes: 3 additions & 3 deletions src/aoc/aoc_helper.py
Original file line number Diff line number Diff line change
Expand Up @@ -177,8 +177,8 @@ def analyze_performance(
runs: int = 10,
with_profile: bool = False,
analyze_complexity: bool = True,
warmups: int = 1,
repeats: int = 5,
warmups: int = 10,
repeats: int = 10,
) -> PerformanceMetrics:
"""Run comprehensive performance analysis"""
times = []
Expand Down Expand Up @@ -322,7 +322,7 @@ def run(
submit: bool = False,
part: Union[None, str] = None,
readme_update: bool = False,
profile: bool = True,
profile: bool = False,
analyze_complexity: bool = False,
warmups: int = 10,
repeats: int = 10,
Expand Down

0 comments on commit b6377a9

Please sign in to comment.