From b6377a9474c316f3866a07c8b8a6f58f13647f3c Mon Sep 17 00:00:00 2001 From: vsedov Date: Sat, 7 Dec 2024 16:08:16 +0000 Subject: [PATCH] day 7 DP solution --- src/aoc/aoc2024/day_07.py | 101 +++++++++++++++++++------------------- src/aoc/aoc_helper.py | 6 +-- 2 files changed, 53 insertions(+), 54 deletions(-) diff --git a/src/aoc/aoc2024/day_07.py b/src/aoc/aoc2024/day_07.py index 8fc5692..dc88bb8 100644 --- a/src/aoc/aoc2024/day_07.py +++ b/src/aoc/aoc2024/day_07.py @@ -1,10 +1,9 @@ -from functools import partial -from itertools import product -from typing import Callable, Set, Tuple +from typing import Tuple import numpy as np from numba import njit, prange -from numba.typed import List +from numba.typed import Dict +from numba.typed import List as NumbaList from src.aoc.aoc2024 import YEAR, get_day from src.aoc.aoc_helper import Aoc @@ -27,70 +26,70 @@ def concat_nums(a: int, b: int) -> int: @njit(cache=True) -def evaluate(nums: np.ndarray, ops: np.ndarray, with_concat: bool = False) -> int: - result = int(nums[0]) - for i in range(len(ops)): - val = int(nums[i + 1]) - if ops[i] == 0: - result += val - elif ops[i] == 1: - result *= val - elif with_concat and ops[i] == 2: - result = concat_nums(result, val) - return result - - -@njit(cache=True, parallel=True) -def check_ops_batch( - nums: np.ndarray, all_ops: np.ndarray, target: int, with_concat: bool -) -> np.ndarray: - results = np.zeros(len(all_ops), dtype=np.bool_) - for i in prange(len(all_ops)): - results[i] = evaluate(nums, all_ops[i], with_concat) == target - return results - - -def solve(txt: str, eval_func: Callable, n_ops: int = 2) -> int: - def parse_line(line: str) -> Tuple[int, np.ndarray]: - target, nums_str = line.split(": ") - return int(target), np.array([int(x) for x in nums_str.split()], dtype=np.int64) - +def dp_line(nums: np.ndarray, target: int, with_concat: bool) -> bool: + prev_reachable = Dict.empty(key_type=np.int64, value_type=np.bool_) + prev_reachable[nums[0]] = True + for i in range(1, len(nums)): + current_num = nums[i] + next_reachable = Dict.empty(key_type=np.int64, value_type=np.bool_) + for val in prev_reachable.keys(): + new_val = val + current_num + if abs(new_val) < 10**12: + next_reachable[new_val] = True + new_val = val * current_num + if abs(new_val) < 10**12: + next_reachable[new_val] = True + if with_concat: + new_val = concat_nums(val, current_num) + if abs(new_val) < 10**12: + next_reachable[new_val] = True + prev_reachable = next_reachable + return target in prev_reachable + + +@njit(cache=True, parallel=True, fastmath=True) +def dp_solve_all( + targets: np.ndarray, all_nums: NumbaList[np.ndarray], with_concat: bool +) -> int: total = 0 - ops_cache = {} + for i in prange(len(targets)): + if dp_line(all_nums[i], targets[i], with_concat): + total += targets[i] + return total - for target, nums in map(parse_line, txt.splitlines()): - n = len(nums) - 1 - if n not in ops_cache: - ops = list(product(range(n_ops), repeat=n)) - ops_cache[n] = np.array(ops, dtype=np.int64) +def solve_with_dp(txt: str, with_concat: bool = False) -> int: + lines = txt.strip().split("\n") + targets_list = [] + nums_list = [] + for line in lines: + target_str, nums_str = line.split(": ") + target = int(target_str) + nums = np.array([int(x) for x in nums_str.split()], dtype=np.int64) + targets_list.append(target) + nums_list.append(nums) - results = check_ops_batch( - nums, ops_cache[n], target, eval_func.keywords["with_concat"] - ) - if np.any(results): - total += target + targets_arr = np.array(targets_list, dtype=np.int64) + typed_nums_list = NumbaList() + for arr in nums_list: + typed_nums_list.append(arr) - return total + return dp_solve_all(targets_arr, typed_nums_list, with_concat) def part_a(txt: str) -> int: - return solve(txt, partial(evaluate, with_concat=False), n_ops=2) + return solve_with_dp(txt, with_concat=False) def part_b(txt: str) -> int: - return solve(txt, partial(evaluate, with_concat=True), n_ops=3) + return solve_with_dp(txt, with_concat=True) def main(txt: str) -> None: - dummy_nums = np.array([1, 2], dtype=np.int64) - dummy_ops = np.array([[0]], dtype=np.int64) - check_ops_batch(dummy_nums, dummy_ops, 3, False) - print("part_a: ", part_a(txt)) print("part_b: ", part_b(txt)) if __name__ == "__main__": aoc = Aoc(day=get_day(), years=YEAR) - aoc.run(main, submit=False, part="both", readme_update=True, profile=False) + aoc.run(main, submit=False, part="both", readme_update=True, profile=True) diff --git a/src/aoc/aoc_helper.py b/src/aoc/aoc_helper.py index e207206..d6368d0 100644 --- a/src/aoc/aoc_helper.py +++ b/src/aoc/aoc_helper.py @@ -177,8 +177,8 @@ def analyze_performance( runs: int = 10, with_profile: bool = False, analyze_complexity: bool = True, - warmups: int = 1, - repeats: int = 5, + warmups: int = 10, + repeats: int = 10, ) -> PerformanceMetrics: """Run comprehensive performance analysis""" times = [] @@ -322,7 +322,7 @@ def run( submit: bool = False, part: Union[None, str] = None, readme_update: bool = False, - profile: bool = True, + profile: bool = False, analyze_complexity: bool = False, warmups: int = 10, repeats: int = 10,