Skip to content

Commit

Permalink
githubrunner: make persistent connection to github and reconnect when…
Browse files Browse the repository at this point in the history
… it closes
  • Loading branch information
Kamilcuk committed Oct 19, 2024
1 parent 79a1db6 commit 27f4634
Showing 1 changed file with 59 additions and 45 deletions.
104 changes: 59 additions & 45 deletions src/nomad_tools/entry_githubrunner/entry_githubrunner.py
Original file line number Diff line number Diff line change
Expand Up @@ -17,6 +17,7 @@
import time
from abc import ABC, abstractmethod
from dataclasses import asdict, dataclass, field
from http.client import RemoteDisconnected
from pathlib import Path
from typing import (
Any,
Expand Down Expand Up @@ -491,14 +492,16 @@ def save(self):
f"Github cache saved {len(self.data)} entires to {github_cachefile()}"
)

def prepare(self, url: str) -> Dict[str, str]:
    """Build the conditional-request headers for a url from the Github cache.

    @param url the url that is about to be requested
    @return a new dict that is empty when the url has no cache entry,
        otherwise holds a single validator header: "if-none-match" when
        the cached entry stores an etag, "if-modified-since" when it
        stores a last-modified value.
    """
    # self.data is only read while self.lock is held.
    with self.lock:
        cached = self.data.get(url)
        if not cached:
            return {}
        header = "if-none-match" if cached.is_etag else "if-modified-since"
        return {header: cached.etag_or_last_modified}

def handle(self, response: requests.Response) -> Optional[Value]:
with self.lock:
Expand All @@ -523,51 +526,60 @@ def handle(self, response: requests.Response) -> Optional[Value]:
return None


class GithubConnection:
    """Persistent HTTP session to Github that retries when the connection drops."""

    def __init__(self, config: Config):
        """Create the session and set the headers sent with every request.

        @param config configuration; config.github.token, when set, is sent
            as a Bearer token in the Authorization header
        """
        # A single Session is reused for all requests made through this object.
        self.s = requests.Session()
        self.s.headers.update({"Accept": "application/vnd.github+json"})
        if config.github.token:
            self.s.headers.update({"Authorization": "Bearer " + config.github.token})

    def get(self, url: str, key: str = "") -> Any:
        """Execute query to github

        @param url the url to request; when the response carries links,
            the "next" link is followed until there is none
        @param key if set, means the output is paginated and the items
            are stored under this key of the returned json object
        @return the decoded json (or data[key]) for a response without
            links, otherwise the list of items collected from all pages
        """
        ret: list[dict] = []
        trynum = 0
        while True:
            # Prepare the request adding headers from Github cache.
            headers = GITHUB_CACHE.prepare(url)
            try:
                response = self.s.get(url, headers=headers)
            except (requests.ConnectionError, RemoteDisconnected):
                # A connection closed by the remote end surfaces here, from
                # the request itself, not from raise_for_status(). Retry the
                # same url, giving up on the third consecutive failure.
                trynum += 1
                if trynum >= 3:
                    raise
                continue
            # The request went through: the next page gets a fresh retry budget.
            trynum = 0
            log.debug(f"{url} {response}")
            try:
                response.raise_for_status()
            except Exception as err:
                # Add the response body to the error, keeping the original cause.
                raise Exception(f"{response.url}\n{response.text}") from err
            # If the result is found in github cache, use it, otherwise extract it.
            cached: Optional[GithubCache.Value] = GITHUB_CACHE.handle(response)
            data: Any = cached.json if cached else response.json()
            links: dict = cached.links if cached else response.links
            # if there are no links, this is not a paginated url.
            if not links:
                return data[key] if key else data
            # Append the data to ret. Use key if given.
            try:
                add = data if isinstance(data, list) else data[key]
            except KeyError:
                log.exception(f"key={key} data={data}")
                raise
            assert isinstance(add, list)
            ret.extend(add)
            # if no next, this is the end.
            next_link = links.get("next")
            if not next_link:
                break
            url = next_link["url"]
        return ret


def gh_get(url: str, key: str = "") -> Any:
    """Execute query to github using the global GH connection.

    @param key if set, means the output is paginated
    """
    return GH.get(url, key)


@functools.lru_cache()
Expand Down Expand Up @@ -1118,6 +1130,8 @@ def cli(args: Args):
CONFIG = Config(tmp)
global PARSEDCONFIG
PARSEDCONFIG = ParsedConfig(CONFIG)
global GH
GH = GithubConnection(CONFIG)
global TEMPLATE
TEMPLATE = jinja2.Environment(
loader=jinja2.BaseLoader(), lstrip_blocks=True, trim_blocks=True
Expand Down

0 comments on commit 27f4634

Please sign in to comment.