diff --git a/src/nomad_tools/entry_githubrunner/entry_githubrunner.py b/src/nomad_tools/entry_githubrunner/entry_githubrunner.py index 648f5b7..abe4086 100644 --- a/src/nomad_tools/entry_githubrunner/entry_githubrunner.py +++ b/src/nomad_tools/entry_githubrunner/entry_githubrunner.py @@ -17,6 +17,7 @@ import time from abc import ABC, abstractmethod from dataclasses import asdict, dataclass, field +from http.client import RemoteDisconnected from pathlib import Path from typing import ( Any, @@ -491,7 +492,8 @@ def save(self): f"Github cache saved {len(self.data)} entires to {github_cachefile()}" ) - def prepare(self, url: str, headers: Dict[str, str]): + def prepare(self, url: str) -> Dict[str, str]: + headers = {} with self.lock: cached = self.data.get(url) if cached: @@ -499,6 +501,7 @@ def prepare(self, url: str, headers: Dict[str, str]): headers["if-none-match"] = cached.etag_or_last_modified else: headers["if-modified-since"] = cached.etag_or_last_modified + return headers def handle(self, response: requests.Response) -> Optional[Value]: with self.lock: @@ -523,51 +526,60 @@ def handle(self, response: requests.Response) -> Optional[Value]: return None -def gh_get(url: str, key: str = "") -> Any: - """Execute query to github +class GithubConnection: + def __init__(self, config: Config): + self.s = requests.Session() + self.s.headers.update({"Accept": "application/vnd.github+json"}) + if config.github.token: + self.s.headers.update({"Authorization": "Bearer " + config.github.token}) - @param key if set, means the output is paginated - """ - headers = { - "Accept": "application/vnd.github+json", - **( - {"Authorization": "Bearer " + CONFIG.github.token} - if CONFIG.github.token - else {} - ), - } - ret: list[dict] = [] - while True: - # Prepare the request adding headers from Github cache. - GITHUB_CACHE.prepare(url, headers) - response = requests.get(url, headers=headers) - headerslog = {**headers, "Authorization": "***"} - log.debug(f"{url} {headerslog} {response}") - try: - response.raise_for_status() - except Exception: - raise Exception(f"{response.url}\n{response.text}") - # If the result is found in github cache, use it, otherwise extract it. - cached: Optional[GithubCache.Value] = GITHUB_CACHE.handle(response) - data: Any = cached.json if cached else response.json() - links: dict = cached.links if cached else response.links - # if there are no links, this is not a paginated url. - if not links: - return data[key] if key else data - # Append the data to ret. Use key if given. - add = data if isinstance(data, list) else data[key] - try: - assert isinstance(add, list) - ret.extend(add) - except KeyError: - log.exception(f"key={key} data={data}") - raise - # if no next, this is the end. - next = links.get("next") - if not next: - break - url = next["url"] - return ret + def get(self, url: str, key: str = ""): + """Execute query to github + + @param key if set, means the output is paginated + """ + ret: list[dict] = [] + trynum = 0 + while True: + # Prepare the request adding headers from Github cache. + headers = GITHUB_CACHE.prepare(url) + response = self.s.get(url, headers=headers) + log.debug(f"{url} {response}") + try: + try: + response.raise_for_status() + except RemoteDisconnected: + trynum += 1 + if trynum >= 3: + raise + continue + except Exception: + raise Exception(f"{response.url}\n{response.text}") + # If the result is found in github cache, use it, otherwise extract it. + cached: Optional[GithubCache.Value] = GITHUB_CACHE.handle(response) + data: Any = cached.json if cached else response.json() + links: dict = cached.links if cached else response.links + # if there are no links, this is not a paginated url. + if not links: + return data[key] if key else data + # Append the data to ret. Use key if given. + add = data if isinstance(data, list) else data[key] + try: + assert isinstance(add, list) + ret.extend(add) + except KeyError: + log.exception(f"key={key} data={data}") + raise + # if no next, this is the end. + next = links.get("next") + if not next: + break + url = next["url"] + return ret + + +def gh_get(url: str, key: str = "") -> Any: + return GH.get(url, key) @functools.lru_cache() @@ -1118,6 +1130,8 @@ def cli(args: Args): CONFIG = Config(tmp) global PARSEDCONFIG PARSEDCONFIG = ParsedConfig(CONFIG) + global GH + GH = GithubConnection(CONFIG) global TEMPLATE TEMPLATE = jinja2.Environment( loader=jinja2.BaseLoader(), lstrip_blocks=True, trim_blocks=True