Skip to content

Commit

Permalink
refactor: add type hints
Browse files Browse the repository at this point in the history
  • Loading branch information
jungs1 committed Aug 7, 2024
1 parent 351b345 commit 021a063
Show file tree
Hide file tree
Showing 10 changed files with 63 additions and 63 deletions.
17 changes: 6 additions & 11 deletions src/mutahunter/core/controller.py
Original file line number Diff line number Diff line change
Expand Up @@ -13,22 +13,17 @@
from mutahunter.core.db import MutationDatabase
from mutahunter.core.entities.config import MutationTestControllerConfig
from mutahunter.core.error_parser import extract_error_message
from mutahunter.core.exceptions import (
CoverageAnalysisError,
MutantKilledError,
MutantSurvivedError,
MutationTestingError,
ReportGenerationError,
UnexpectedTestResultError,
)
from mutahunter.core.exceptions import (CoverageAnalysisError,
MutantKilledError, MutantSurvivedError,
MutationTestingError,
ReportGenerationError,
UnexpectedTestResultError)
from mutahunter.core.git_handler import GitHandler
from mutahunter.core.io import FileOperationHandler
from mutahunter.core.llm_mutation_engine import LLMMutationEngine
from mutahunter.core.logger import logger
from mutahunter.core.prompts.mutant_generator import (
SYSTEM_PROMPT_MUTANT_ANALYSUS,
USER_PROMPT_MUTANT_ANALYSIS,
)
SYSTEM_PROMPT_MUTANT_ANALYSUS, USER_PROMPT_MUTANT_ANALYSIS)
from mutahunter.core.report import MutantReport
from mutahunter.core.router import LLMRouter
from mutahunter.core.runner import MutantTestRunner
Expand Down
2 changes: 1 addition & 1 deletion src/mutahunter/core/coverage_processor.py
Original file line number Diff line number Diff line change
Expand Up @@ -3,7 +3,7 @@


class CoverageProcessor:
def __init__(self, coverage_type, code_coverage_report_path) -> None:
def __init__(self, coverage_type: str, code_coverage_report_path: str) -> None:
"""
Initializes the CoverageProcessor with the given configuration.
Expand Down
12 changes: 6 additions & 6 deletions src/mutahunter/core/db.py
Original file line number Diff line number Diff line change
Expand Up @@ -3,15 +3,15 @@
import sqlite3
from contextlib import contextmanager
from datetime import datetime
from typing import Any, Dict, List, Optional, Tuple
from typing import Any, Dict, Iterator, List, Optional, Tuple


class DatabaseError(Exception):
pass


class MutationDatabase:
def __init__(self, db_path: str = "mutahunter.db"):
def __init__(self, db_path: str = "mutahunter.db") -> None:
self.db_path = db_path
self.conn = None
if os.path.exists(self.db_path):
Expand All @@ -23,7 +23,7 @@ def __init__(self, db_path: str = "mutahunter.db"):
self.create_tables()

@contextmanager
def get_connection(self):
def get_connection(self) -> Iterator[sqlite3.Connection]:
conn = sqlite3.connect(self.db_path)
try:
yield conn
Expand Down Expand Up @@ -75,7 +75,7 @@ def check_schema(self):

return True

def create_tables(self):
def create_tables(self) -> None:
with self.get_connection() as conn:
cursor = conn.cursor()
cursor.executescript(
Expand Down Expand Up @@ -205,7 +205,7 @@ def get_file_version(self, file_path: str) -> Tuple[int, int, bool]:
conn.rollback()
raise DatabaseError(f"Error processing file version: {str(e)}")

def add_mutant(self, run_id: int, file_version_id: int, mutant_data: dict):
def add_mutant(self, run_id: int, file_version_id: int, mutant_data: dict) -> None:
with self.get_connection() as conn:
cursor = conn.cursor()
try:
Expand Down Expand Up @@ -454,7 +454,7 @@ def get_file_mutations(self, file_name: str) -> List[Dict[str, Any]]:
except sqlite3.Error as e:
raise DatabaseError(f"Error fetching file mutations: {str(e)}")

def get_mutant_summary(self, run_id) -> Dict[str, int]:
def get_mutant_summary(self, run_id: int) -> Dict[str, int]:
with self.get_connection() as conn:
cursor = conn.cursor()
try:
Expand Down
9 changes: 5 additions & 4 deletions src/mutahunter/core/llm_mutation_engine.py
Original file line number Diff line number Diff line change
Expand Up @@ -8,6 +8,7 @@
from mutahunter.core.logger import logger
from mutahunter.core.prompts.factory import PromptFactory
from mutahunter.core.repomap import RepoMap
from mutahunter.core.router import LLMRouter

SYSTEM_YAML_FIX = """
Based on the error message, the YAML content provided is not in the correct format. Please ensure the YAML content is in the correct format and try again.
Expand All @@ -34,9 +35,9 @@ class LLMMutationEngine:

def __init__(
self,
model,
router,
):
model: str,
router: LLMRouter,
) -> None:
self.model = model
self.router = router
self.repo_map = RepoMap(model=self.model)
Expand Down Expand Up @@ -135,7 +136,7 @@ def fix_format(self, error: Exception, content: str) -> str:
)
return model_response

def _get_repo_map(self, cov_files) -> Optional[Dict[str, Any]]:
def _get_repo_map(self, cov_files: List[str]) -> Optional[Dict[str, Any]]:
return self.repo_map.get_repo_map(chat_files=[], other_files=cov_files)

def _add_line_numbers(self, src_code: str) -> str:
Expand Down
2 changes: 1 addition & 1 deletion src/mutahunter/core/logger.py
Original file line number Diff line number Diff line change
Expand Up @@ -6,7 +6,7 @@
warnings.filterwarnings("ignore", category=FutureWarning, module="tree_sitter")


def setup_logger(name):
def setup_logger(name: str) -> logging.Logger:
os.makedirs("logs/_latest/mutants", exist_ok=True)
# Create a custom format for your logs
log_format = "%(asctime)s %(levelname)s: %(message)s"
Expand Down
53 changes: 31 additions & 22 deletions src/mutahunter/core/repomap.py
Original file line number Diff line number Diff line change
Expand Up @@ -9,6 +9,7 @@
from collections import Counter, defaultdict, namedtuple
from importlib import resources
from pathlib import Path
from typing import Any, Iterator, List, Optional, Set

import networkx as nx
from grep_ast import TreeContext, filename_to_lang
Expand All @@ -29,11 +30,11 @@ class RepoMap:

def __init__(
self,
model=None,
root=None,
map_tokens=1024,
max_context_window=None,
):
model: Optional[str] = None,
root: None = None,
map_tokens: int = 1024,
max_context_window: None = None,
) -> None:
if not root:
root = os.getcwd()
self.model = model
Expand All @@ -42,16 +43,20 @@ def __init__(
self.max_map_tokens = map_tokens
self.max_context_window = max_context_window

def token_count(self, string):
def token_count(self, string: str) -> int:
if not string:
return 0
return token_counter(
model=self.model, messages=[{"user": "role", "content": "hihi"}]
)

def get_repo_map(
self, chat_files, other_files, mentioned_fnames=None, mentioned_idents=None
):
self,
chat_files: List[Any],
other_files: List[str],
mentioned_fnames: None = None,
mentioned_idents: None = None,
) -> str:
if self.max_map_tokens <= 0:
return
if not other_files:
Expand Down Expand Up @@ -101,20 +106,20 @@ def get_repo_map(

return repo_content

def get_rel_fname(self, fname):
def get_rel_fname(self, fname: str) -> str:
return os.path.relpath(fname, self.root)

def split_path(self, path):
path = os.path.relpath(path, self.root)
return [path + ":"]

def get_mtime(self, fname):
def get_mtime(self, fname: str) -> float:
try:
return os.path.getmtime(fname)
except FileNotFoundError:
pass

def get_tags(self, fname, rel_fname):
def get_tags(self, fname: str, rel_fname: str) -> List[Tag]:
# Check if the file is in the cache and if the modification time has not changed
file_mtime = self.get_mtime(fname)
if file_mtime is None:
Expand All @@ -123,7 +128,7 @@ def get_tags(self, fname, rel_fname):
data = list(self.get_tags_raw(fname, rel_fname))
return data

def get_tags_raw(self, fname, rel_fname):
def get_tags_raw(self, fname: str, rel_fname: str) -> Iterator[Tag]:
lang = filename_to_lang(fname)
if not lang:
return
Expand Down Expand Up @@ -203,8 +208,12 @@ def get_tags_raw(self, fname, rel_fname):
)

def get_ranked_tags(
self, chat_fnames, other_fnames, mentioned_fnames, mentioned_idents
):
self,
chat_fnames: List[Any],
other_fnames: List[str],
mentioned_fnames: Set[Any],
mentioned_idents: Set[Any],
) -> List[Tag]:
defines = defaultdict(set)
references = defaultdict(list)
definitions = defaultdict(set)
Expand Down Expand Up @@ -325,12 +334,12 @@ def get_ranked_tags(

def get_ranked_tags_map(
self,
chat_fnames,
other_fnames=None,
max_map_tokens=None,
mentioned_fnames=None,
mentioned_idents=None,
):
chat_fnames: List[Any],
other_fnames: Optional[List[str]] = None,
max_map_tokens: Optional[int] = None,
mentioned_fnames: Optional[Set[Any]] = None,
mentioned_idents: Optional[Set[Any]] = None,
) -> str:
if not other_fnames:
other_fnames = list()
if not max_map_tokens:
Expand Down Expand Up @@ -372,7 +381,7 @@ def get_ranked_tags_map(

return best_tree

def render_tree(self, abs_fname, rel_fname, lois):
def render_tree(self, abs_fname: str, rel_fname: str, lois: List[int]) -> str:
key = (rel_fname, tuple(sorted(lois)))

with open(abs_fname, "r", encoding="utf-8") as f:
Expand Down Expand Up @@ -400,7 +409,7 @@ def render_tree(self, abs_fname, rel_fname, lois):
res = context.format()
return res

def to_tree(self, tags, chat_rel_fnames):
def to_tree(self, tags: List[Tag], chat_rel_fnames: List[Any]) -> str:
if not tags:
return ""

Expand Down
6 changes: 3 additions & 3 deletions src/mutahunter/core/report.py
Original file line number Diff line number Diff line change
Expand Up @@ -4,7 +4,7 @@

import os
from importlib import resources
from typing import Any, Dict, List
from typing import Any, Dict, List, Union

from jinja2 import (Environment, FileSystemLoader, PackageLoader,
select_autoescape)
Expand Down Expand Up @@ -32,7 +32,7 @@ def __init__(self, db: MutationDatabase) -> None:
assert self.template_env.get_template("report_template.html")
assert self.template_env.get_template("file_detail_template.html")

def generate_report(self, total_cost: float, line_rate: float, run_id) -> None:
def generate_report(self, total_cost: float, line_rate: float, run_id: int) -> None:
"""
Generates a comprehensive mutation testing report.
Expand Down Expand Up @@ -115,7 +115,7 @@ def _write_html_report(self, html_content: str, filename: str) -> None:
logger.info(f"HTML report generated: {filename}")

def _generate_summary_report(
self, data, total_cost: float, line_rate: float
self, data: Dict[str, Union[int, float]], total_cost: float, line_rate: float
) -> None:
"""
Generates a summary mutation testing report.
Expand Down
4 changes: 2 additions & 2 deletions src/mutahunter/core/router.py
Original file line number Diff line number Diff line change
Expand Up @@ -4,7 +4,7 @@


class LLMRouter:
def __init__(self, model: str, api_base: str = ""):
def __init__(self, model: str, api_base: str = "") -> None:
"""
Initialize the LLMRouter with a model and optional API base URL.
"""
Expand Down Expand Up @@ -56,7 +56,7 @@ def generate_response(
print(f"Error during response generation: {e}")
return "", 0, 0

def _validate_prompt(self, prompt: dict):
def _validate_prompt(self, prompt: dict) -> None:
"""
Validate that the prompt contains the required keys.
"""
Expand Down
6 changes: 3 additions & 3 deletions src/mutahunter/core/runner.py
Original file line number Diff line number Diff line change
Expand Up @@ -5,7 +5,7 @@


class MutantTestRunner:
def __init__(self, test_command):
def __init__(self, test_command: str) -> None:
self.test_command = test_command

def dry_run(self) -> None:
Expand Down Expand Up @@ -56,13 +56,13 @@ def run_test(self, params: dict) -> subprocess.CompletedProcess:
self.revert_file(module_path, backup_path)
return result

def replace_file(self, original, replacement, backup):
def replace_file(self, original: str, replacement: str, backup: str) -> None:
"""Backup original file and replace it with the replacement file."""
if not os.path.exists(backup):
shutil.copy2(original, backup)
shutil.copy2(replacement, original)

def revert_file(self, original, backup):
def revert_file(self, original: str, backup: str) -> None:
"""Revert the file to the original using the backup."""
if os.path.exists(backup):
shutil.copy2(backup, original)
Expand Down
15 changes: 5 additions & 10 deletions src/mutahunter/core/unittest_generator.py
Original file line number Diff line number Diff line change
Expand Up @@ -12,20 +12,15 @@
from mutahunter.core.controller import MutationTestController
from mutahunter.core.coverage_processor import CoverageProcessor
from mutahunter.core.db import MutationDatabase
from mutahunter.core.entities.config import (
MutationTestControllerConfig,
UnittestGeneratorConfig,
)
from mutahunter.core.entities.config import (MutationTestControllerConfig,
UnittestGeneratorConfig)
from mutahunter.core.error_parser import extract_error_message
from mutahunter.core.logger import logger
from mutahunter.core.prompts.unittest_generator import (
FAILED_TESTS_TEXT,
LINE_COV_UNITTEST_GENERATOR_USER_PROMPT,
MUTATION_COV_UNITTEST_GENERATOR_USER_PROMPT,
MUTATION_WEAK_TESTS_TEXT,
)
FAILED_TESTS_TEXT, LINE_COV_UNITTEST_GENERATOR_USER_PROMPT,
MUTATION_COV_UNITTEST_GENERATOR_USER_PROMPT, MUTATION_WEAK_TESTS_TEXT)
from mutahunter.core.router import LLMRouter
from mutahunter.core.runner import MutantTestRunner
from mutahunter.core.logger import logger

SYSTEM_YAML_FIX = """
Based on the error message, the YAML content provided is not in the correct format. Please ensure the YAML content is in the correct format and try again.
Expand Down

0 comments on commit 021a063

Please sign in to comment.