
Commit

feat(dataset): add KGSIterableDataset
jaxvanyang committed Jun 24, 2024
1 parent 59e7d85 commit 4d4f17b
Showing 2 changed files with 100 additions and 32 deletions.
src/mygo/datasets/__init__.py: 4 changes (2 additions & 2 deletions)
@@ -1,4 +1,4 @@
 from .mcts import MCTSDataset
-from .sgf import KGSDataset
+from .sgf import KGSDataset, KGSIterableDataset
 
-__all__ = ["MCTSDataset", "KGSDataset"]
+__all__ = ["MCTSDataset", "KGSDataset", "KGSIterableDataset"]
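With this re-export, the new class is importable from the same place as the existing datasets. A minimal sketch of the resulting import surface, assuming the mygo package is installed:

# Both the map-style and the new iterable-style KGS datasets are exposed
# through mygo.datasets after this change.
from mygo.datasets import KGSDataset, KGSIterableDataset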
src/mygo/datasets/sgf.py: 128 changes (98 additions & 30 deletions)
@@ -2,11 +2,11 @@
 import tarfile
 from pathlib import Path
 from shutil import copyfileobj
-from typing import Any
+from typing import Any, Generator
 from urllib.request import urlopen
 
 import numpy as np
-from torch.utils.data import Dataset
+from torch.utils.data import Dataset, IterableDataset
 
 from mygo.encoder.base import Encoder
 from mygo.encoder.oneplane import OnePlaneEncoder
@@ -15,8 +15,8 @@
 from mygo.pysgf import SGF
 
 
-class KGSDataset(Dataset):
-    """KGS SGF game records."""
+class KGSMixin:
+    """Metadata and helper functions of KGS archives."""
 
     url_prefix = "https://dl.u-go.net/gamerecords"
     train_archives = (
@@ -79,6 +79,46 @@ class KGSDataset(Dataset):
         "KGS-2002-19-3646-.tar.bz2",
     )
 
+    @classmethod
+    def download_and_extract_archives(cls, root: Path, train: bool = True) -> None:
+        """Download and extract SGF archives."""
+
+        if not root.is_dir():
+            root.mkdir(parents=True)
+
+        archives = cls.train_archives if train else cls.test_archives
+        for archive in archives:
+            url = f"{cls.url_prefix}/{archive}"
+            path = root / archive
+
+            if path.is_file():
+                continue
+
+            print(f"Downloading {url} -> {path}")
+            with urlopen(url) as response, open(path, "wb") as f:
+                copyfileobj(response, f)
+
+            # only extract after downloading because checking if they are extracted is
+            # too expensive
+            with tarfile.open(path) as tar:
+                print(f"Extracting {path}")
+                tar.extractall(path=root)
+
+    @classmethod
+    def get_sgf_paths(
+        cls, root: Path, train: bool = True
+    ) -> Generator[str, None, None]:
+        """Return a generator of extracted SGF files' relative paths to root."""
+
+        for archive in cls.train_archives if train else cls.test_archives:
+            with tarfile.open(root / archive) as tar:
+                # ignore 1st name, because it's the parent directory
+                yield from tar.getnames()[1:]
+
+
+class KGSDataset(KGSMixin, Dataset):
+    """Dataset of KGS game records."""
+
     def __init__(
         self,
         root: str | Path,
@@ -102,16 +142,11 @@ def __init__
         self.features, self.labels = [], []
 
         if download:
-            self._download_and_extract_archives()
+            self.download_and_extract_archives(self.root, train=train)
 
-        def get_tarfile_names(archive):
-            with tarfile.open(self.root / archive) as tar:
-                return tar.getnames()[1:]  # ignore 1st name, because it's a directory
-
-        names = itertools.chain(*(get_tarfile_names(a) for a in self.archives))
-        names = itertools.islice(names, game_count)
-        for name in names:
-            sgf_root = SGF.parse_file(self.root / name)
+        paths = itertools.islice(self.get_sgf_paths(self.root, train=train), game_count)
+        for path in paths:
+            sgf_root = SGF.parse_file(self.root / path)
             game = Game.from_pysgf(sgf_root)
             node = sgf_root
 
@@ -141,25 +176,58 @@ def __getitem__(self, key: int) -> tuple[Any, Any]:
 
         return x, y
 
-    def _download_and_extract_archives(self) -> None:
-        """Download and extract SGF archives."""
-
-        if not self.root.is_dir():
-            self.root.mkdir(parents=True)
+class KGSIterableDataset(KGSMixin, IterableDataset):
+    """Iterable dataset of KGS game records."""
 
-        for archive in self.archives:
-            url = f"{self.url_prefix}/{archive}"
-            path = self.root / archive
+    def __init__(
+        self,
+        root: str | Path,
+        train: bool = True,
+        download: bool = True,
+        game_count: int = 100,
+        encoder: Encoder = OnePlaneEncoder(),
+        transform=None,
+        target_transform=None,
+    ) -> None:
+        self.archives = self.train_archives if train else self.test_archives
+        assert game_count <= sum(int(a.split("-")[-2]) for a in self.archives)
 
-            if path.is_file():
-                continue
+        self.root = root if isinstance(root, Path) else Path(root)
+        self.game_count = game_count
+        self.encoder = encoder
+        self.train = train
+        self.download = download
+        self.transform = transform
+        self.target_transform = target_transform
 
-            print(f"Downloading {url} -> {path}")
-            with urlopen(url) as response, open(path, "wb") as f:
-                copyfileobj(response, f)
+        if download:
+            self.download_and_extract_archives(self.root, train=train)
 
-            # only extract after downloading
-            # because checking if it's extracted is too expensive
-            with tarfile.open(path) as tar:
-                print(f"Extracting {path}")
-                tar.extractall(path=self.root)
+        self.sgf_paths = list(
+            itertools.islice(self.get_sgf_paths(self.root, train=train), game_count)
+        )
+
+    def __iter__(self) -> Generator[tuple[Any, Any], None, None]:
+        for path in self.sgf_paths:
+            sgf_root = SGF.parse_file(self.root / path)
+            game = Game.from_pysgf(sgf_root)
+            node = sgf_root
+
+            while node.children:
+                # only select the first child for simplicity
+                node = node.children[0]
+                move = from_pysgf_move(node.move)
+
+                if isinstance(move, PlayMove):
+                    feature = self.encoder.encode(game)
+                    label = move.point.encode(game.board_size)
+
+                    if self.transform:
+                        feature = self.transform(feature)
+                    if self.target_transform:
+                        label = self.target_transform(label)
+
+                    yield feature, label
+
+                game.apply_move(move)

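For illustration, a minimal usage sketch of the two dataset flavors touched by this commit. The root path, game_count, and DataLoader wiring below are hypothetical and not part of the commit; the sketch assumes the mygo package and torch are installed, and that download=True (the default) is acceptable, since it fetches archives from dl.u-go.net.

from torch.utils.data import DataLoader

from mygo.datasets import KGSDataset, KGSIterableDataset

# Map-style dataset: __init__ eagerly parses up to game_count games into
# (feature, label) pairs, so samples can be accessed by index.
train_data = KGSDataset(root="data", train=True, game_count=10)
x, y = train_data[0]

# Iterable-style dataset: __init__ only lists the extracted SGF paths; games
# are parsed lazily in __iter__, so samples stream in game order.
stream = KGSIterableDataset(root="data", train=True, game_count=10)
loader = DataLoader(stream, batch_size=32)  # shuffle is not supported for IterableDataset

for features, labels in loader:
    ...  # training step would go here
    break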