diff --git a/algobattle/problem.py b/algobattle/problem.py index 4b340e03..0fd3bd9b 100644 --- a/algobattle/problem.py +++ b/algobattle/problem.py @@ -209,7 +209,9 @@ def default_score( try: return max(0, min(1, sol_score / gen_score)) except ZeroDivisionError: - return float(sol_score < 0) + # if generator scored 0 then the solver will have achieved an equal or better score + # i.e. the Fight's score is simply 1 regardless of its solution score. + return 1 else: return max(0, min(1, solution.score(instance, Role.solver))) diff --git a/tests/test_util.py b/tests/test_util.py index 45b47c00..d6f42692 100644 --- a/tests/test_util.py +++ b/tests/test_util.py @@ -1,7 +1,23 @@ """Tests for all util functions.""" +from math import inf import unittest from algobattle.battle import Battle, Iterated, Averaged +from algobattle.problem import InstanceModel, SolutionModel, default_score +from algobattle.util import Role + + +class DummyInstance(InstanceModel): # noqa: D101 + @property + def size(self) -> int: + return 1 + + +class DummySolution(SolutionModel[DummyInstance]): # noqa: D101 + val: float + + def score(self, instance: DummyInstance, role: Role) -> float: + return self.val class Utiltests(unittest.TestCase): @@ -12,6 +28,35 @@ def test_default_battle_types(self): self.assertEqual(Battle.all()["Iterated"], Iterated) self.assertEqual(Battle.all()["Averaged"], Averaged) + def test_default_fight_score(self): + """Tests the default fight scoring function.""" + instance = DummyInstance() + scores = [ + (0, 0, 1), + (0, 2, 1), + (0, 4, 1), + (0, inf, 1), + (2, 0, 0), + (2, 2, 1), + (2, 4, 1), + (2, inf, 1), + (4, 0, 0), + (4, 2, 0.5), + (4, 4, 1), + (4, inf, 1), + (inf, 0, 0), + (inf, 2, 0), + (inf, 4, 0), + (inf, inf, 1), + ] + for gen, sol, score in scores: + self.assertEqual( + default_score( + instance, generator_solution=DummySolution(val=gen), solver_solution=DummySolution(val=sol) + ), + score, + ) + if __name__ == "__main__": unittest.main()