diff --git a/.github/workflows/build.yml b/.github/workflows/build.yml new file mode 100644 index 0000000..2196085 --- /dev/null +++ b/.github/workflows/build.yml @@ -0,0 +1,40 @@ +name: Build and release + +on: + push: + tags: + - v* + +jobs: + build: + runs-on: ubuntu-latest + steps: + - name: Store version number + run: | + VERSION=${GITHUB_REF_NAME#v} + echo "VERSION=$VERSION" >> $GITHUB_ENV + - name: Checkout + uses: actions/checkout@v3 + - name: Setup python + uses: actions/setup-python@v4 + with: + python-version: 3.11 + - name: Build and run checks + run: | + ./build.sh build + - name: Upload build artifacts + uses: actions/upload-artifact@v3 + with: + name: latest-release + path: dist/ + - name: Upload New Release. + uses: softprops/action-gh-release@v1 + with: + name: v${{ env.VERSION }} + tag_name: v${{ env.VERSION }} + body: pyra2yr wheel and source distribution. + files: | + dist/pyra2yr-${{ env.VERSION }}-py3-none-any.whl + dist/pyra2yr-${{ env.VERSION }}.tar.gz + env: + GITHUB_TOKEN: ${{ secrets.GITHUB_TOKEN }} diff --git a/.gitignore b/.gitignore new file mode 100644 index 0000000..e422201 --- /dev/null +++ b/.gitignore @@ -0,0 +1,17 @@ +** +!.github +!.github/** +!.github/workflows/build.yml +!.gitignore +!.pylintrc +!/pyra2yr +!/pyra2yr/** +!/ra2yrproto/__init__.py +!LICENSE +!Makefile +!README.md +!pyproject.toml +!requirements.txt +!setup.py +!build.sh +*__pycache__* diff --git a/.pylintrc b/.pylintrc new file mode 100644 index 0000000..36b641d --- /dev/null +++ b/.pylintrc @@ -0,0 +1,638 @@ +[MASTER] +load-plugins=pylint_protobuf + +[MAIN] + +# Analyse import fallback blocks. This can be used to support both Python 2 and +# 3 compatible code, which means that the block might have code that exists +# only in one or another interpreter, leading to false positives when analysed. +analyse-fallback-blocks=no + +# Clear in-memory caches upon conclusion of linting. Useful if running pylint +# in a server-like mode. +clear-cache-post-run=no + +# Load and enable all available extensions. Use --list-extensions to see a list +# all available extensions. +#enable-all-extensions= + +# In error mode, messages with a category besides ERROR or FATAL are +# suppressed, and no reports are done by default. Error mode is compatible with +# disabling specific errors. +#errors-only= + +# Always return a 0 (non-error) status code, even if lint errors are found. +# This is primarily useful in continuous integration scripts. +#exit-zero= + +# A comma-separated list of package or module names from where C extensions may +# be loaded. Extensions are loading into the active Python interpreter and may +# run arbitrary code. +extension-pkg-allow-list= + +# A comma-separated list of package or module names from where C extensions may +# be loaded. Extensions are loading into the active Python interpreter and may +# run arbitrary code. (This is an alternative name to extension-pkg-allow-list +# for backward compatibility.) +extension-pkg-whitelist= + +# Return non-zero exit code if any of these messages/categories are detected, +# even if score is above --fail-under value. Syntax same as enable. Messages +# specified are enabled, while categories only check already-enabled messages. +fail-on= + +# Specify a score threshold under which the program will exit with error. +fail-under=10 + +# Interpret the stdin as a python script, whose filename needs to be passed as +# the module_or_package argument. +#from-stdin= + +# Files or directories to be skipped. They should be base names, not paths. +ignore=CVS + +# Add files or directories matching the regular expressions patterns to the +# ignore-list. The regex matches against paths and can be in Posix or Windows +# format. Because '\\' represents the directory delimiter on Windows systems, +# it can't be used as an escape character. +ignore-paths= + +# Files or directories matching the regular expression patterns are skipped. +# The regex matches against base names, not paths. The default value ignores +# Emacs file locks +ignore-patterns=.+_pb2\.py(i)? + +# List of module names for which member attributes should not be checked +# (useful for modules/projects where namespaces are manipulated during runtime +# and thus existing member attributes cannot be deduced by static analysis). It +# supports qualified module names, as well as Unix pattern matching. +ignored-modules=ra2yrproto + +# Python code to execute, usually for sys.path manipulation such as +# pygtk.require(). +#init-hook= + +# Use multiple processes to speed up Pylint. Specifying 0 will auto-detect the +# number of processors available to use, and will cap the count on Windows to +# avoid hangs. +jobs=1 + +# Control the amount of potential inferred values when inferring a single +# object. This can help the performance when dealing with large functions or +# complex, nested conditions. +limit-inference-results=100 + +# List of plugins (as comma separated values of python module names) to load, +# usually to register additional checkers. +# load-plugins=pylint_protobuf + +# Pickle collected data for later comparisons. +persistent=yes + +# Minimum Python version to use for version dependent checks. Will default to +# the version used to run pylint. +py-version=3.10 + +# Discover python modules and packages in the file system subtree. +recursive=no + +# When enabled, pylint would attempt to guess common misconfiguration and emit +# user-friendly hints instead of false-positive error messages. +suggestion-mode=yes + +# Allow loading of arbitrary C extensions. Extensions are imported into the +# active Python interpreter and may run arbitrary code. +unsafe-load-any-extension=no + +# In verbose mode, extra non-checker-related info will be displayed. +#verbose= + + +[BASIC] + +# Naming style matching correct argument names. +argument-naming-style=snake_case + +# Regular expression matching correct argument names. Overrides argument- +# naming-style. If left empty, argument names will be checked with the set +# naming style. +#argument-rgx= + +# Naming style matching correct attribute names. +attr-naming-style=snake_case + +# Regular expression matching correct attribute names. Overrides attr-naming- +# style. If left empty, attribute names will be checked with the set naming +# style. +#attr-rgx= + +# Bad variable names which should always be refused, separated by a comma. +bad-names=foo, + bar, + baz, + toto, + tutu, + tata + +# Bad variable names regexes, separated by a comma. If names match any regex, +# they will always be refused +bad-names-rgxs= + +# Naming style matching correct class attribute names. +class-attribute-naming-style=any + +# Regular expression matching correct class attribute names. Overrides class- +# attribute-naming-style. If left empty, class attribute names will be checked +# with the set naming style. +#class-attribute-rgx= + +# Naming style matching correct class constant names. +class-const-naming-style=UPPER_CASE + +# Regular expression matching correct class constant names. Overrides class- +# const-naming-style. If left empty, class constant names will be checked with +# the set naming style. +#class-const-rgx= + +# Naming style matching correct class names. +class-naming-style=PascalCase + +# Regular expression matching correct class names. Overrides class-naming- +# style. If left empty, class names will be checked with the set naming style. +#class-rgx= + +# Naming style matching correct constant names. +const-naming-style=UPPER_CASE + +# Regular expression matching correct constant names. Overrides const-naming- +# style. If left empty, constant names will be checked with the set naming +# style. +#const-rgx= + +# Minimum line length for functions/classes that require docstrings, shorter +# ones are exempt. +docstring-min-length=-1 + +# Naming style matching correct function names. +function-naming-style=snake_case + +# Regular expression matching correct function names. Overrides function- +# naming-style. If left empty, function names will be checked with the set +# naming style. +#function-rgx= + +# Good variable names which should always be accepted, separated by a comma. +good-names=i, + j, + k, + ex, + Run, + _ + +# Good variable names regexes, separated by a comma. If names match any regex, +# they will always be accepted +good-names-rgxs= + +# Include a hint for the correct naming format with invalid-name. +include-naming-hint=no + +# Naming style matching correct inline iteration names. +inlinevar-naming-style=any + +# Regular expression matching correct inline iteration names. Overrides +# inlinevar-naming-style. If left empty, inline iteration names will be checked +# with the set naming style. +#inlinevar-rgx= + +# Naming style matching correct method names. +method-naming-style=snake_case + +# Regular expression matching correct method names. Overrides method-naming- +# style. If left empty, method names will be checked with the set naming style. +#method-rgx= + +# Naming style matching correct module names. +module-naming-style=snake_case + +# Regular expression matching correct module names. Overrides module-naming- +# style. If left empty, module names will be checked with the set naming style. +#module-rgx= + +# Colon-delimited sets of names that determine each other's naming style when +# the name regexes allow several styles. +name-group= + +# Regular expression which should only match function or class names that do +# not require a docstring. +no-docstring-rgx=^_ + +# List of decorators that produce properties, such as abc.abstractproperty. Add +# to this list to register other decorators that produce valid properties. +# These decorators are taken in consideration only for invalid-name. +property-classes=abc.abstractproperty + +# Regular expression matching correct type variable names. If left empty, type +# variable names will be checked with the set naming style. +#typevar-rgx= + +# Naming style matching correct variable names. +variable-naming-style=snake_case + +# Regular expression matching correct variable names. Overrides variable- +# naming-style. If left empty, variable names will be checked with the set +# naming style. +#variable-rgx= + + +[CLASSES] + +# Warn about protected attribute access inside special methods +check-protected-access-in-special-methods=no + +# List of method names used to declare (i.e. assign) instance attributes. +defining-attr-methods=__init__, + __new__, + setUp, + __post_init__ + +# List of member names, which should be excluded from the protected access +# warning. +exclude-protected=_asdict, + _fields, + _replace, + _source, + _make + +# List of valid names for the first argument in a class method. +valid-classmethod-first-arg=cls + +# List of valid names for the first argument in a metaclass class method. +valid-metaclass-classmethod-first-arg=mcs + + +[DESIGN] + +# List of regular expressions of class ancestor names to ignore when counting +# public methods (see R0903) +exclude-too-few-public-methods= + +# List of qualified class names to ignore when counting class parents (see +# R0901) +ignored-parents= + +# Maximum number of arguments for function / method. +max-args=5 + +# Maximum number of attributes for a class (see R0902). +max-attributes=7 + +# Maximum number of boolean expressions in an if statement (see R0916). +max-bool-expr=5 + +# Maximum number of branch for function / method body. +max-branches=12 + +# Maximum number of locals for function / method body. +max-locals=15 + +# Maximum number of parents for a class (see R0901). +max-parents=7 + +# Maximum number of public methods for a class (see R0904). +max-public-methods=20 + +# Maximum number of return / yield for function / method body. +max-returns=6 + +# Maximum number of statements in function / method body. +max-statements=50 + +# Minimum number of public methods for a class (see R0903). +min-public-methods=2 + + +[EXCEPTIONS] + +# Exceptions that will emit a warning when caught. +overgeneral-exceptions=builtins.BaseException,builtins.Exception + + +[FORMAT] + +# Expected format of line ending, e.g. empty (any line ending), LF or CRLF. +expected-line-ending-format= + +# Regexp for a line that is allowed to be longer than the limit. +ignore-long-lines=^\s*(# )??$ + +# Number of spaces of indent required inside a hanging or continued line. +indent-after-paren=4 + +# String used as indentation unit. This is usually " " (4 spaces) or "\t" (1 +# tab). +indent-string=' ' + +# Maximum number of characters on a single line. +max-line-length=88 + +# Maximum number of lines in a module. +max-module-lines=1000 + +# Allow the body of a class to be on the same line as the declaration if body +# contains single statement. +single-line-class-stmt=no + +# Allow the body of an if to be on the same line as the test if there is no +# else. +single-line-if-stmt=no + + +[IMPORTS] + +# List of modules that can be imported at any level, not just the top level +# one. +allow-any-import-level= + +# Allow explicit reexports by alias from a package __init__. +allow-reexport-from-package=no + +# Allow wildcard imports from modules that define __all__. +allow-wildcard-with-all=no + +# Deprecated modules which should not be used, separated by a comma. +deprecated-modules= + +# Output a graph (.gv or any supported image format) of external dependencies +# to the given file (report RP0402 must not be disabled). +ext-import-graph= + +# Output a graph (.gv or any supported image format) of all (i.e. internal and +# external) dependencies to the given file (report RP0402 must not be +# disabled). +import-graph= + +# Output a graph (.gv or any supported image format) of internal dependencies +# to the given file (report RP0402 must not be disabled). +int-import-graph= + +# Force import order to recognize a module as part of the standard +# compatibility libraries. +known-standard-library= + +# Force import order to recognize a module as part of a third party library. +known-third-party=enchant + +# Couples of modules and preferred modules, separated by a comma. +preferred-modules= + + +[LOGGING] + +# The type of string formatting that logging methods do. `old` means using % +# formatting, `new` is for `{}` formatting. +logging-format-style=old + +# Logging modules to check that the string format arguments are in logging +# function parameter format. +logging-modules=logging + + +[MESSAGES CONTROL] + +# Only show warnings with the listed confidence levels. Leave empty to show +# all. Valid levels: HIGH, CONTROL_FLOW, INFERENCE, INFERENCE_FAILURE, +# UNDEFINED. +confidence=HIGH, + CONTROL_FLOW, + INFERENCE, + INFERENCE_FAILURE, + UNDEFINED + +# Disable the message, report, category or checker with the given id(s). You +# can either give multiple identifiers separated by comma (,) or put this +# option multiple times (only on the command line, not in the configuration +# file where it should appear only once). You can also use "--disable=all" to +# disable everything first and then re-enable specific checks. For example, if +# you want to run only the similarities checker, you can use "--disable=all +# --enable=similarities". If you want to run only the classes checker, but have +# no Warning level messages displayed, use "--disable=all --enable=classes +# --disable=W". +disable=attribute-defined-outside-init, + broad-exception-caught, + fixme, + raw-checker-failed, + too-many-boolean-expressions, + too-many-instance-attributes, + too-many-public-methods, + unbalanced-tuple-unpacking, + bad-inline-option, + deprecated-pragma, + file-ignored, + invalid-name, + locally-disabled, + missing-class-docstring, + missing-function-docstring, + missing-module-docstring, + suppressed-message, + use-symbolic-message-instead, + useless-suppression, + +# Enable the message, report, category or checker with the given id(s). You can +# either give multiple identifier separated by comma (,) or put this option +# multiple time (only on the command line, not in the configuration file where +# it should appear only once). See also the "--disable" option for examples. +enable=c-extension-no-member + + +[METHOD_ARGS] + +# List of qualified names (i.e., library.method) which require a timeout +# parameter e.g. 'requests.api.get,requests.api.post' +timeout-methods=requests.api.delete,requests.api.get,requests.api.head,requests.api.options,requests.api.patch,requests.api.post,requests.api.put,requests.api.request + + +[MISCELLANEOUS] + +# List of note tags to take in consideration, separated by a comma. +notes=FIXME, + XXX, + TODO + +# Regular expression of note tags to take in consideration. +notes-rgx= + + +[REFACTORING] + +# Maximum number of nested blocks for function / method body +max-nested-blocks=5 + +# Complete name of functions that never returns. When checking for +# inconsistent-return-statements if a never returning function is called then +# it will be considered as an explicit return statement and no message will be +# printed. +never-returning-functions=sys.exit,argparse.parse_error + + +[REPORTS] + +# Python expression which should return a score less than or equal to 10. You +# have access to the variables 'fatal', 'error', 'warning', 'refactor', +# 'convention', and 'info' which contain the number of messages in each +# category, as well as 'statement' which is the total number of statements +# analyzed. This score is used by the global evaluation report (RP0004). +evaluation=max(0, 0 if fatal else 10.0 - ((float(5 * error + warning + refactor + convention) / statement) * 10)) + +# Template used to display messages. This is a python new-style format string +# used to format the message information. See doc for all details. +msg-template= + +# Set the output format. Available formats are text, parseable, colorized, json +# and msvs (visual studio). You can also give a reporter class, e.g. +# mypackage.mymodule.MyReporterClass. +#output-format= + +# Tells whether to display a full report or only the messages. +reports=no + +# Activate the evaluation score. +score=yes + + +[SIMILARITIES] + +# Comments are removed from the similarity computation +ignore-comments=yes + +# Docstrings are removed from the similarity computation +ignore-docstrings=yes + +# Imports are removed from the similarity computation +ignore-imports=yes + +# Signatures are removed from the similarity computation +ignore-signatures=yes + +# Minimum lines number of a similarity. +min-similarity-lines=4 + + +[SPELLING] + +# Limits count of emitted suggestions for spelling mistakes. +max-spelling-suggestions=4 + +# Spelling dictionary name. Available dictionaries: none. To make it work, +# install the 'python-enchant' package. +spelling-dict= + +# List of comma separated words that should be considered directives if they +# appear at the beginning of a comment and should not be checked. +spelling-ignore-comment-directives=fmt: on,fmt: off,noqa:,noqa,nosec,isort:skip,mypy: + +# List of comma separated words that should not be checked. +spelling-ignore-words= + +# A path to a file that contains the private dictionary; one word per line. +spelling-private-dict-file= + +# Tells whether to store unknown words to the private dictionary (see the +# --spelling-private-dict-file option) instead of raising a message. +spelling-store-unknown-words=no + + +[STRING] + +# This flag controls whether inconsistent-quotes generates a warning when the +# character used as a quote delimiter is used inconsistently within a module. +check-quote-consistency=no + +# This flag controls whether the implicit-str-concat should generate a warning +# on implicit string concatenation in sequences defined over several lines. +check-str-concat-over-line-jumps=no + + +[TYPECHECK] + +# List of decorators that produce context managers, such as +# contextlib.contextmanager. Add to this list to register other decorators that +# produce valid context managers. +contextmanager-decorators=contextlib.contextmanager + +# List of members which are set dynamically and missed by pylint inference +# system, and so shouldn't trigger E1101 when accessed. Python regular +# expressions are accepted. +generated-members= + +# Tells whether to warn about missing members when the owner of the attribute +# is inferred to be None. +ignore-none=yes + +# This flag controls whether pylint should warn about no-member and similar +# checks whenever an opaque object is returned when inferring. The inference +# can return multiple potential results while evaluating a Python object, but +# some branches might not be evaluated, which results in partial inference. In +# that case, it might be useful to still emit no-member and other checks for +# the rest of the inferred objects. +ignore-on-opaque-inference=yes + +# List of symbolic message names to ignore for Mixin members. +ignored-checks-for-mixins=no-member, + not-async-context-manager, + not-context-manager, + attribute-defined-outside-init + +# List of class names for which member attributes should not be checked (useful +# for classes with dynamically set attributes). This supports the use of +# qualified names. +ignored-classes=optparse.Values,thread._local,_thread._local,argparse.Namespace + +# Show a hint with possible names when a member name was not found. The aspect +# of finding the hint is based on edit distance. +missing-member-hint=yes + +# The minimum edit distance a name should have in order to be considered a +# similar match for a missing member name. +missing-member-hint-distance=1 + +# The total number of similar names that should be taken in consideration when +# showing a hint for a missing member. +missing-member-max-choices=1 + +# Regex pattern to define which classes are considered mixins. +mixin-class-rgx=.*[Mm]ixin + +# List of decorators that change the signature of a decorated function. +signature-mutators= + + +[VARIABLES] + +# List of additional names supposed to be defined in builtins. Remember that +# you should avoid defining new builtins when possible. +additional-builtins= + +# Tells whether unused global variables should be treated as a violation. +allow-global-unused-variables=yes + +# List of names allowed to shadow builtins +allowed-redefined-builtins= + +# List of strings which can identify a callback function by name. A callback +# name must start or end with one of those strings. +callbacks=cb_, + _cb + +# A regular expression matching the name of dummy variables (i.e. expected to +# not be used). +dummy-variables-rgx=_+$|(_[a-zA-Z0-9_]*[a-zA-Z0-9]+?$)|dummy|^ignored_|^unused_ + +# Argument names that match this expression will be ignored. +ignored-argument-names=_.*|^ignored_|^unused_ + +# Tells whether we should check for unused import in __init__ files. +init-import=no + +# List of qualified module names which can have objects that can redefine +# builtins. +redefining-builtins-modules=six.moves,past.builtins,future.builtins,builtins,io diff --git a/README.md b/README.md new file mode 100644 index 0000000..aef2f6d --- /dev/null +++ b/README.md @@ -0,0 +1,14 @@ +# pyra2yr + +Python interface for ra2yrcpp. + + +## Setup + +```bash +poetry install +``` + +## Usage + +See `pyra2yr/test_*.py` files for example usage. diff --git a/build.sh b/build.sh new file mode 100755 index 0000000..9232262 --- /dev/null +++ b/build.sh @@ -0,0 +1,38 @@ +#!/bin/bash + +set -o nounset + +if [ ! -d .venv ]; then + set -e + python3 -m venv .venv + .venv/bin/pip install -U pip setuptools + .venv/bin/pip install poetry poetry-dynamic-versioning + .venv/bin/poetry install + # TODO(shmocz): Workaround for poetry adding empty lines to this + git checkout pyproject.toml + set +e +fi + +function lint() { + .venv/bin/poetry run pylint pyra2yr +} + +function format() { + .venv/bin/poetry run black pyra2yr +} + +function check-format() { + format + d="$(git diff)" + [ ! -z "$d" ] && { echo "$d"; exit 1; } +} + +function build() { + check-format + set -e + lint + .venv/bin/poetry build + set +e +} + +$1 diff --git a/pyproject.toml b/pyproject.toml new file mode 100644 index 0000000..a7784e3 --- /dev/null +++ b/pyproject.toml @@ -0,0 +1,36 @@ +[tool.poetry] +name = "pyra2yr" +version = "0.0.0" +description = "Python interface for ra2yrproto" +authors = ["shmocz <112764837+shmocz@users.noreply.github.com>"] +license = "GPL3" +readme = "README.md" + +[tool.poetry.dependencies] +python = "^3.11" +numpy = "^1.26.2" +aiohttp = "^3.9.1" +protobuf = "^4.25.1" + +[tool.poetry.group.release.dependencies] +ra2yrproto = { url = "https://github.com/shmocz/ra2yrproto/releases/download/v1/ra2yrproto-1-py3-none-any.whl" } + +[tool.black] +line-length = 88 + +[tool.isort] +profile = "black" + +[tool.poetry.group.dev.dependencies] +black = "^23.12.0" +pylint = "^3.0.3" +pylint-protobuf = "^0.22.0" +isort = "^5.13.2" +docformatter = "^1.7.5" + +[tool.poetry-dynamic-versioning] +enable = true + +[build-system] +requires = ["poetry-core>=1.0.0", "poetry-dynamic-versioning>=1.0.0,<2.0.0"] +build-backend = "poetry_dynamic_versioning.backend" diff --git a/pyra2yr/__init__.py b/pyra2yr/__init__.py new file mode 100644 index 0000000..e69de29 diff --git a/pyra2yr/async_container.py b/pyra2yr/async_container.py new file mode 100644 index 0000000..8907de0 --- /dev/null +++ b/pyra2yr/async_container.py @@ -0,0 +1,24 @@ +import asyncio + + +class AsyncDict: + def __init__(self): + self._data = {} + self._cond = asyncio.Condition() + + async def get_item(self, key, timeout: float = None, remove: bool = False): + async with self._cond: + await asyncio.wait_for( + self._cond.wait_for(lambda: key in self._data), timeout + ) + item = self._data[key] + if remove: + self._data.pop(key) + return item + + async def put_item(self, key, value): + async with self._cond: + if key in self._data: + raise RuntimeError(f"key {key} exists") + self._data[key] = value + self._cond.notify_all() diff --git a/pyra2yr/main.py b/pyra2yr/main.py new file mode 100644 index 0000000..abc7964 --- /dev/null +++ b/pyra2yr/main.py @@ -0,0 +1,34 @@ +import argparse +import gzip + +from google.protobuf.json_format import MessageToJson + +from pyra2yr.util import read_protobuf_messages + + +def parse_args(): + a = argparse.ArgumentParser( + description="pyra2yr main utility", + formatter_class=argparse.ArgumentDefaultsHelpFormatter, + ) + a.add_argument("-d", "--dump-replay", action="store_true") + a.add_argument("-i", "--input-path", help="input path if applicable", type=str) + return a.parse_args() + + +def dump_replay(path: str): + with gzip.open(path, "rb") as f: + m = read_protobuf_messages(f) + for _, m0 in enumerate(m): + print(MessageToJson(m0)) + + +def main(): + # pylint: disable=unused-variable + args = parse_args() + if args.dump_replay: + dump_replay(args.input_path) + + +if __name__ == "__main__": + main() diff --git a/pyra2yr/manager.py b/pyra2yr/manager.py new file mode 100644 index 0000000..6fe070e --- /dev/null +++ b/pyra2yr/manager.py @@ -0,0 +1,453 @@ +import asyncio +import logging as lg +import traceback +from collections.abc import Iterable +from datetime import datetime as dt +from enum import Enum +from typing import Any + +from ra2yrproto import commands_game, commands_yr, core, ra2yr + +from pyra2yr.network import DualClient, logged_task +from pyra2yr.state_manager import StateManager +from pyra2yr.util import Clock + + +class PlaceStrategy(Enum): + RANDOM = 0 + FARTHEST = 1 + ABSOLUTE = 2 + + +class Manager: + """Manages connections and state updates for an active game process.""" + + def __init__( + self, + address: str = "0.0.0.0", + port: int = 14521, + poll_frequency=20, + fetch_state_timeout=5.0, + ): + """ + Parameters + ---------- + address : str, optional + WebSocket API endpoint, by default "0.0.0.0" + port : int, optional + Destination server port, by default 14525 + poll_frequency : int, optional + Frequency for polling the game state in Hz, by default 20 + fetch_state_timeout : float, optional + Timeout (seconds) for state fetching (default: 5.0) + """ + self.address = address + self.port = port + self.poll_frequency = min(max(1, poll_frequency), 60) + self.fetch_state_timeout = fetch_state_timeout + self.state = StateManager() + self.client: DualClient = DualClient(self.address, self.port) + self.t = Clock() + self.iters = 0 + self.show_stats_every = 30 + self.delta = 0 + self.M = ManagerUtil(self) + self._stop = asyncio.Event() + self._main_task = None + + def start(self): + self._main_task = logged_task(self.mainloop()) + self.client.connect() + + async def stop(self): + self._stop.set() + await self._main_task + await self.client.stop() + + async def step(self, s: ra2yr.GameState): + pass + + async def update_initials(self): + res_istate = await self.M.read_value(initial_game_state=ra2yr.GameState()) + state = res_istate.data.initial_game_state + self.state.sc.set_initials(state.object_types, state.prerequisite_groups) + + async def on_state_update(self, s: ra2yr.GameState): + if self.iters % self.show_stats_every == 0: + delta = self.t.toc() + lg.debug( + "step=%d interval=%d avg_duration=%f avg_fps=%f", + self.iters, + self.show_stats_every, + delta / self.show_stats_every, + self.show_stats_every / delta, + ) + self.t.tic() + if s.current_frame > 0: + if not self.state.sc.has_initials(): + await self.update_initials() + try: + fn = await self.step(s) + if fn: + # await asyncio.create_task(fn) + await fn() + except AssertionError: + raise + except Exception: + lg.error("exception on step: %s", traceback.format_exc()) + self.iters += 1 + + async def get_state(self) -> ra2yr.GameState: + cmd = commands_yr.GetGameState() + state = await self.client.exec_command(cmd, timeout=5.0) + if not state.result.Unpack(cmd): + raise RuntimeError(f"failed to unpack state: {state}") + return cmd.state + + async def run_command(self, c: Any) -> core.CommandResult: + return await self.client.exec_command(c) + + async def run(self, c: Any = None, **kwargs) -> Any: + for k, v in kwargs.items(): + if isinstance(v, list): + getattr(c, k).extend(v) + else: + try: + setattr(c, k, v) + except Exception: # FIXME: more explicit check + getattr(c, k).CopyFrom(v) + cmd_name = c.__class__.__name__ + res = await self.run_command(c) + if res.result_code == core.ResponseCode.ERROR: + lg.error("Failed to run command %s: %s", cmd_name, res.error_message) + res_o = type(c)() + res.result.Unpack(res_o) + return res_o + + # TODO: dont run async code in same thread as Manager due to performance reasons + async def mainloop(self): + d = 1 / self.poll_frequency + deadline = dt.now().timestamp() + while not self._stop.is_set(): + try: + await asyncio.sleep(min(d, max(deadline - dt.now().timestamp(), 0.0))) + deadline = dt.now().timestamp() + d + s = await self.get_state() + if not self.state.should_update(s): + continue + self.state.sc.set_state(s) + await self.on_state_update(s) + await self.state.state_updated() + except asyncio.exceptions.TimeoutError: + lg.error("Couldn't fetch result") + + async def wait_state(self, cond, timeout=30, err=None): + await self.state.wait_state(lambda x: cond(), timeout=timeout, err=err) + + +class CommandBuilder: + @classmethod + def make_command(cls, c: Any, **kwargs): + for k, v in kwargs.items(): + if v is None: + continue + if isinstance(v, list): + getattr(c, k).extend(v) + else: + try: + setattr(c, k, v) + except Exception: # FIXME: more explicit check + getattr(c, k).CopyFrom(v) + return c + + @classmethod + def add_event( + cls, + event_type: ra2yr.NetworkEvent = None, + house_index: int = 0, + frame_delay=0, + spoof=False, + **kwargs, + ): + return cls.make_command( + commands_yr.AddEvent(), + event=ra2yr.Event(event_type=event_type, house_index=house_index, **kwargs), + frame_delay=frame_delay, + spoof=spoof, + ) + + @classmethod + def make_produce(cls, rtti_id: int = 0, heap_id: int = 0, is_naval: bool = False): + return cls.add_event( + event_type=ra2yr.NETWORK_EVENT_Produce, + production=ra2yr.Event.Production( + rtti_id=rtti_id, heap_id=heap_id, is_naval=is_naval + ), + ) + + +class ManagerUtil: + def __init__(self, manager: Manager): + self.manager = manager + self.C = CommandBuilder + + # TODO(shmocz): low level stuff put elsewhere + async def unit_command( + self, + object_addresses: list[int] = None, + action: ra2yr.UnitAction = None, + ): + return await self.manager.run( + self.C.make_command( + commands_yr.UnitCommand(), + object_addresses=object_addresses, + action=action, + ) + ) + + async def unit_order( + self, + objects: list[ra2yr.Object] | ra2yr.Object = None, + action: ra2yr.UnitAction = None, + target_object: ra2yr.Object = None, + coordinates: ra2yr.Coordinates = None, + ): + """Perform UnitOrder. Depending on action type, target_object or coordinates + may be optional and will be ignored. + + Parameters + ---------- + objects : list[ra2yr.Object] | ra2yr.Object + Source object/objects. + action : ra2yr.UnitAction + Action to perform + target_object : ra2yr.Object, optional + Target object, if applicable + coordinates : ra2yr.Coordinates, optional + Target coordinates, if applicable + + Returns + ------- + + Raises + ------ + RuntimeError + If command execution failed. + + """ + p_target = None + if target_object: + p_target = target_object.pointer_self + if not isinstance(objects, list): + if isinstance(objects, Iterable): + objects = list(objects) + elif not objects is None: + objects = [objects] + else: + objects = [] + r = await self.manager.run_command( + commands_game.UnitOrder( + object_addresses=[o.pointer_self for o in objects], + action=action, + target_object=p_target, + coordinates=coordinates, + ) + ) + if r.result_code == core.ERROR: + raise RuntimeError(f"UnitOrder failed: {r.error_message}") + return r + + async def select( + self, + objects: list[ra2yr.Object] | ra2yr.Object, + ): + return await self.unit_order(objects=objects, action=ra2yr.UNIT_ACTION_SELECT) + + async def attack( + self, objects: list[ra2yr.Object] | ra2yr.Object, target_object=None + ): + return await self.unit_order( + objects=objects, + action=ra2yr.UNIT_ACTION_ATTACK, + target_object=target_object, + ) + + async def attack_move( + self, objects: list[ra2yr.Object] | ra2yr.Object, coordinates=None + ): + return await self.unit_order( + objects=objects, + action=ra2yr.UNIT_ACTION_ATTACK_MOVE, + coordinates=coordinates, + ) + + async def move(self, objects: list[ra2yr.Object] | ra2yr.Object, coordinates=None): + return await self.unit_order( + objects=objects, + action=ra2yr.UNIT_ACTION_MOVE, + coordinates=coordinates, + ) + + async def capture( + self, + objects: list[ra2yr.Object] | ra2yr.Object = None, + target: ra2yr.Object = None, + ): + return await self.unit_order( + objects=objects, + target_object=target, + action=ra2yr.UNIT_ACTION_CAPTURE, + ) + + # TODO(shmocz): ambiguous wrt. building/unit + async def repair( + self, + obj: ra2yr.Object, + target: ra2yr.Object, + ): + return await self.unit_order( + objects=obj, + target_object=target, + action=ra2yr.UNIT_ACTION_REPAIR, + ) + + # TODO(shmocz): wait for proper value + # FIXME: rename var + async def deploy(self, obj: ra2yr.Object): + return await self.unit_order( + objects=[obj], + action=ra2yr.UNIT_ACTION_DEPLOY, + ) + + async def place_query( + self, type_class: int = None, house_class: int = None, coordinates=None + ) -> commands_yr.PlaceQuery: + return await self.manager.run( + self.C.make_command( + commands_yr.PlaceQuery(), + type_class=type_class, + house_class=house_class, + coordinates=coordinates, + ) + ) + + async def place_building( + self, + building: ra2yr.Object = None, + coordinates: ra2yr.Coordinates = None, + ) -> core.CommandResult: + return await self.run_command( + commands_game.PlaceBuilding(building=building, coordinates=coordinates), + ) + + async def click_event( + self, object_addresses=None, event: ra2yr.NetworkEvent = None + ): + return await self.manager.run( + self.C.make_command( + commands_yr.ClickEvent(), + object_addresses=object_addresses, + event=event, + ) + ) + + async def sell_buildings(self, objects: list[ra2yr.Object] | ra2yr.Object): + return await self.unit_order(objects=objects, action=ra2yr.UNIT_ACTION_SELL) + + async def stop(self, objects: list[ra2yr.Object] | ra2yr.Object): + return await self.unit_order(objects=objects, action=ra2yr.UNIT_ACTION_STOP) + + async def produce_order( + self, + object_type: ra2yr.ObjectTypeClass, + action: ra2yr.ProduceAction = None, + ) -> core.CommandResult: + return await self.run_command( + commands_game.ProduceOrder(object_type=object_type, action=action) + ) + + async def sell(self, objects: list[ra2yr.Object]): + return await self.unit_order(objects=objects, action=ra2yr.UNIT_ACTION_SELL) + + async def sell_walls(self, coordinates: ra2yr.Coordinates): + return await self.unit_order( + action=ra2yr.UNIT_ACTION_SELL_CELL, coordinates=coordinates + ) + + async def produce( + self, + rtti_id: int = 0, + heap_id: int = 0, + is_naval: bool = False, + ): + return await self.manager.run( + self.C.make_produce( + rtti_id=rtti_id, + heap_id=heap_id, + is_naval=is_naval, + ) + ) + + # TODO(shmocz): autodetect is_naval in the library + async def produce_building(self, heap_id: int = 0, is_naval: bool = False): + return await self.manager.run( + self.C.make_produce( + rtti_id=ra2yr.ABSTRACT_TYPE_BUILDINGTYPE, + heap_id=heap_id, + is_naval=is_naval, + ) + ) + + async def run_command(self, c: Any): + return await self.manager.run_command(c) + + async def start_production( + self, object_type: ra2yr.ObjectTypeClass + ) -> core.CommandResult: + return await self.produce_order( + object_type=object_type, + action=ra2yr.PRODUCE_ACTION_BEGIN, + ) + + async def add_message( + self, + message: str = None, + duration_frames: int = None, + color: ra2yr.ColorScheme = None, + ): + return await self.manager.run( + self.C.make_command( + commands_yr.AddMessage(), + message=message, + duration_frames=duration_frames, + color=color, + ) + ) + + async def read_value(self, **kwargs): + return await self.manager.run( + commands_yr.ReadValue(), + data=ra2yr.StorageValue(**kwargs), + ) + + async def map_data(self) -> ra2yr.MapData: + res = await self.read_value(map_data=ra2yr.MapData()) + return res.data.map_data + + async def inspect_configuration(self, config: commands_yr.Configuration = None): + return await self.manager.run( + self.C.make_command(commands_yr.InspectConfiguration(), config=config) + ) + + async def wait_game_to_begin(self, timeout=60): + await self.manager.wait_state( + lambda: self.manager.state.s.stage == ra2yr.STAGE_INGAME + and self.manager.state.s.current_frame > 1, + timeout=timeout, + ) + + async def wait_game_to_exit(self, timeout=60): + await self.manager.wait_state( + lambda: self.manager.state.s.stage == ra2yr.STAGE_EXIT_GAME, + timeout=timeout, + ) diff --git a/pyra2yr/network.py b/pyra2yr/network.py new file mode 100644 index 0000000..1f04a55 --- /dev/null +++ b/pyra2yr/network.py @@ -0,0 +1,224 @@ +import asyncio +import datetime +import logging +import struct +import traceback +from typing import Any, Dict + +import aiohttp +from ra2yrproto import core + +from .async_container import AsyncDict + +debug = logging.debug + + +async def async_log_exceptions(coro): + try: + return await coro + except Exception: + logging.error("%s", traceback.format_exc()) + + +def logged_task(coro): + return asyncio.create_task(async_log_exceptions(coro)) + + +class TCPClient: + def __init__(self, host: str, port: int): + self.host = host + self.port = port + self.reader = None + self.writer = None + self.__fmt = " bytes: + # read message length + data = await self.reader.read(struct.calcsize(self.__fmt)) + message_length = struct.unpack(self.__fmt, data)[0] + + # read the actual message + r = message_length + res = bytearray() + while r > 0: + chunk_bytes = await self.reader.read(message_length) + res.extend(chunk_bytes) + r -= len(chunk_bytes) + return bytes(res) + + async def send_message(self, m: str | bytes): + # write length + data = m + if not isinstance(m, bytes): + data = m.encode() + self.writer.write(self.pack_length(len(data))) + # write actual message + self.writer.write(data) + await self.writer.drain() + + +class WebSocketClient: + def __init__(self, uri: str, timeout=5.0): + self.uri = uri + self.in_queue = asyncio.Queue() + self.out_queue = asyncio.Queue() + self.timeout = timeout + self.task = None + self._tries = 15 + self._connect_delay = 1.0 + self._lock = asyncio.Lock() + + def open(self): + self.task = asyncio.create_task(async_log_exceptions(self.main())) + + async def close(self): + await self.in_queue.put(None) + await self.task + + async def send_message(self, m: str) -> aiohttp.WSMessage: + async with self._lock: + await self.in_queue.put(m) + return await self.out_queue.get() + + async def main(self): + # send the initial message + msg = await self.in_queue.get() + for i in range(self._tries): + try: + debug("connect, try %d %d", i, self._tries) + await self._main_session(msg) + break + except asyncio.exceptions.CancelledError: + break + except Exception: + logging.warning("connect failed (try %d/%d)", i + 1, self._tries) + if i + 1 == self._tries: + raise + await asyncio.sleep(self._connect_delay) + + async def _main_session(self, msg): + async with aiohttp.ClientSession() as session: + debug("connecting to %s %s msg %s", self.uri, session, msg) + async with session.ws_connect(self.uri, autoclose=False) as ws: + debug("connected to %s", self.uri) + await ws.send_bytes(msg) + + async for msg in ws: + await self.out_queue.put(msg) + in_msg = await self.in_queue.get() + if in_msg is None: + await ws.close() + break + await ws.send_bytes(in_msg) + self.in_queue = None + self.out_queue = None + debug("close _main_session") + + +class DualClient: + def __init__(self, host: str, port: int, timeout: float = 5.0): + self.host = host + self.port = port + self.conns: Dict[str, WebSocketClient] = {} + self.uri = f"http://{host}:{port}" + self.queue_id = -1 + self.timeout = timeout + self.results = AsyncDict() + self.in_queue = asyncio.Queue() + self._poll_task = None + self._stop = asyncio.Event() + # FIXME: ugly + self._queue_set = asyncio.Event() + + def connect(self): + for k in ["command", "poll"]: + self.conns[k] = WebSocketClient(self.uri, self.timeout) + self.conns[k].open() + debug("opened %s", k) + self._poll_task = asyncio.create_task(async_log_exceptions(self._poll_loop())) + + def make_command(self, msg=None, command_type=None) -> core.Command: + c = core.Command() + c.command_type = command_type + c.command.Pack(msg) + return c + + def make_poll_blocking(self, queue_id, timeout) -> core.Command: + c = core.PollResults() + c.args.queue_id = queue_id + c.args.timeout = timeout + return self.make_command(c, core.POLL_BLOCKING) + + def parse_response(self, msg: str) -> core.Response: + res = core.Response() + res.ParseFromString(msg) + return res + + async def run_client_command(self, c: Any) -> core.RunCommandAck: + msg = await self.conns["command"].send_message( + self.make_command(c, core.CLIENT_COMMAND).SerializeToString() + ) + + res = self.parse_response(msg.data) + ack = core.RunCommandAck() + if not res.body.Unpack(ack): + raise RuntimeError(f"failed to unpack ack: {res}") + return ack + + # TODO: could wrap this into task and cancel at exit + async def exec_command(self, c: Any, timeout=None): + msg = await self.run_client_command(c) + if self.queue_id < 0: + self.queue_id = msg.queue_id + self._queue_set.set() + # wait until results polled + return await self.results.get_item(msg.id, timeout=timeout, remove=True) + + async def _poll_loop(self): + await self._queue_set.wait() + while not self._stop.is_set(): + msg = await self.conns["poll"].send_message( + self.make_poll_blocking( + self.queue_id, int(self.timeout * 1000) + ).SerializeToString() + ) + res = self.parse_response(msg.data) + cc = core.PollResults() + if not res.body.Unpack(cc): + raise RuntimeError(f"failed to unpack poll results {cc}") + for x in cc.result.results: + await self.results.put_item(x.command_id, x) + + async def stop(self): + self._stop.set() + await self._poll_task + await self.conns["command"].close() + await self.conns["poll"].close() diff --git a/pyra2yr/state_container.py b/pyra2yr/state_container.py new file mode 100644 index 0000000..4c025e3 --- /dev/null +++ b/pyra2yr/state_container.py @@ -0,0 +1,92 @@ +from functools import cached_property + +from ra2yrproto import ra2yr + + +class StateContainer: + def __init__(self, s: ra2yr.GameState = None): + if not s: + s = ra2yr.GameState() + self.s = s + self._types: list[ra2yr.ObjectTypeClass] = [] + self._prerequisite_groups: ra2yr.PrerequisiteGroups = [] + + def get_factory(self, o: ra2yr.Factory) -> ra2yr.Factory: + """Get factory that's producing a particular object. + + Parameters + ---------- + o : ra2yr.Factory + Factory to query with + + Returns + ------- + ra2yr.Factory + Corresponding factory in the state. + + Raises + ------ + StopIteration if factory wasn't found. + """ + return next(x for x in self.s.factories if x.object == o.object) + + def get_object(self, o: ra2yr.Object | int) -> ra2yr.Object: + """Get object from current state by pointer value. + + Parameters + ---------- + o : ra2yr.Object | int | None + Object or address to be queried. If None, return all objects + + Returns + ------- + ra2yr.Object + _description_ + + Raises + ------ + StopIteration if object wasn't found + """ + if isinstance(o, ra2yr.Object): + o = o.pointer_self + return next(x for x in self.s.objects if x.pointer_self == o) + + def set_initials(self, t: list[ra2yr.ObjectTypeClass], p: ra2yr.PrerequisiteGroups): + self._types = t + self._prerequisite_groups = p + + def has_initials(self) -> bool: + return self._prerequisite_groups and self._types + + def set_state(self, s: ra2yr.GameState): + self.s.CopyFrom(s) + if any(o.pointer_technotypeclass == 0 for o in self.s.objects): + raise RuntimeError( + f"zero TC, frame={self.s.current_frame}, objs={self.s.objects}" + ) + + def types(self) -> list[ra2yr.ObjectTypeClass]: + return self._types + + @cached_property + def ttc_map(self) -> dict[int, ra2yr.ObjectTypeClass]: + """Map pointer to type class for fast look ups. + + Returns + ------- + dict[int, ra2yr.ObjectTypeClass] + the mapping + """ + return {x.pointer_self: x for x in self._types} + + @cached_property + def prerequisite_map(self) -> dict[int, set[int]]: + items = { + "proc": -6, + "tech": -5, + "radar": -4, + "barracks": -3, + "factory": -2, + "power": -1, + } + return {v: set(getattr(self._prerequisite_groups, k)) for k, v in items.items()} diff --git a/pyra2yr/state_manager.py b/pyra2yr/state_manager.py new file mode 100644 index 0000000..305c8b0 --- /dev/null +++ b/pyra2yr/state_manager.py @@ -0,0 +1,95 @@ +import asyncio +import logging as lg +import re +from typing import Iterator + +from google.protobuf import message as _message +from ra2yrproto import ra2yr + +from pyra2yr.state_container import StateContainer +from pyra2yr.state_objects import FactoryEntry, ObjectEntry + + +class StateManager: + def __init__(self, s: ra2yr.GameState = None): + self.sc = StateContainer(s) + self._cond_state_update = asyncio.Condition() + + @property + def s(self) -> ra2yr.GameState: + return self.sc.s + + def should_update(self, s: ra2yr.GameState) -> bool: + return s.current_frame != self.s.current_frame or s.stage != self.s.stage + + async def wait_state(self, cond, timeout=30, err=None): + async with self._cond_state_update: + try: + await asyncio.wait_for( + self._cond_state_update.wait_for(lambda: cond(self)), + timeout, + ) + except TimeoutError: + if err: + lg.error("wait failed: %s", err) + raise + + async def state_updated(self): + async with self._cond_state_update: + self._cond_state_update.notify_all() + + def query_type_class( + self, p: str, abstract_type=None + ) -> Iterator[ra2yr.ObjectTypeClass]: + """Query type classes by pattern and abstract type + + Parameters + ---------- + p : str + Regex to be searched from type class name + abstract_type : _type_, optional + Abstract type of the type class + + Yields + ------ + Iterator[ra2yr.ObjectTypeClass] + Matching type classes + """ + for x in self.sc.types(): + if (not re.search(p, x.name)) or ( + abstract_type and x.type != abstract_type + ): + continue + yield x + + def query_objects( + self, + t: ra2yr.ObjectTypeClass = None, + h: ra2yr.House = None, + a: ra2yr.AbstractType = None, + p: str = None, + ) -> Iterator[ObjectEntry]: + for _, x in enumerate(self.s.objects): + o = ObjectEntry(self.sc, x) + if ( + (h and o.get().pointer_house != h.self) + or (t and t.pointer_self != o.get().pointer_technotypeclass) + or (a and o.get().object_type != a) + or (p and not re.search(p, o.tc().name)) + ): + continue + yield o + + def query_factories( + self, t: ra2yr.ObjectTypeClass = None, h: ra2yr.House = None + ) -> Iterator[FactoryEntry]: + for x in self.s.factories: + f = FactoryEntry(self.sc, x) + if (h and f.o.owner != h.self) or ( + t and t.pointer_self != f.object.get().pointer_technotypeclass + ): + continue + yield f + + def current_player(self) -> ra2yr.House: + return next(p for p in self.s.houses if p.current_player) diff --git a/pyra2yr/state_objects.py b/pyra2yr/state_objects.py new file mode 100644 index 0000000..5aef088 --- /dev/null +++ b/pyra2yr/state_objects.py @@ -0,0 +1,122 @@ +import numpy as np +from ra2yrproto import ra2yr + +from pyra2yr.state_container import StateContainer +from pyra2yr.util import coord2array + + +class ViewObject: + """Provides an up to date view to an object of the underlying game state. + State change is directly reflected in the view object. + """ + + def __init__(self, m: StateContainer): + self.m = m + self._invalid = False + self.latest_frame = -1 + + def invalid(self): + self.update() + return self._invalid + + def fetch_next(self): + raise NotImplementedError() + + def update(self): + if self._invalid: + return + if self.latest_frame != self.m.s.current_frame: + try: + self.fetch_next() + self.latest_frame = self.m.s.current_frame + # TODO: general exception + except StopIteration: + self._invalid = True + + +class ObjectEntry(ViewObject): + def __init__(self, m: StateContainer, o: ra2yr.Object): + super().__init__(m) + self.o: ra2yr.Object = o + + def fetch_next(self): + self.o = self.m.get_object(self.o) + + def get(self) -> ra2yr.Object: + """Get reference to most recent Object entry. + + Returns + ------- + ra2yr.Object + """ + self.update() + return self.o + + def tc(self) -> ra2yr.ObjectTypeClass: + """The type class of the object""" + return self.m.ttc_map[self.o.pointer_technotypeclass] + + @property + def coordinates(self): + return coord2array(self.get().coordinates) + + @property + def health(self) -> float: + """Health in percentage + + Returns + ------- + float + Health in percentage + """ + return self.get().health / self.tc().strength + + def __repr__(self): + return str(self.o) + + +class FactoryEntry(ViewObject): + """State object entry. + + If the object is no longer in state, it's marked as invalid. + """ + + def __init__(self, m: StateContainer, o: ra2yr.Factory): + super().__init__(m) + self.o: ra2yr.Factory = o + self.object = ObjectEntry(m, m.get_object(o.object)) + + def fetch_next(self): + self.o = self.m.get_factory(self.o) + + def get(self) -> ra2yr.Factory: + """Get reference to most recent Factory entry. + + Returns + ------- + ra2yr.Factory + """ + self.update() + return self.o + + +class MapData: + def __init__(self, m: ra2yr.MapData): + self.m = m + + def ind2sub(self, I): + yy, xx = np.unravel_index(I, (self.m.width, self.m.height)) + return np.c_[xx, yy] + + @classmethod + def bbox(cls, x): + m_min = np.min(x, axis=0) + m_max = np.max(x, axis=0) + return np.array( + [ + [m_max[0] + 1, m_max[1] + 1], # BL + [m_max[0] + 1, m_min[1] - 1], # BR + [m_min[0] - 1, m_max[1] + 1], # TL + [m_min[0] - 1, m_min[1] - 1], # TR + ] + ) diff --git a/pyra2yr/test_multi_client.py b/pyra2yr/test_multi_client.py new file mode 100644 index 0000000..d92d467 --- /dev/null +++ b/pyra2yr/test_multi_client.py @@ -0,0 +1,214 @@ +import asyncio +import logging as lg +import random +import unittest + +import numpy as np +from ra2yrproto import ra2yr + +from pyra2yr.manager import PlaceStrategy +from pyra2yr.state_manager import ObjectEntry +from pyra2yr.test_util import MyManager +from pyra2yr.util import array2coord, pdist, setup_logging + +setup_logging(level=lg.DEBUG) + + +PT_CONNIE = r"Conscript" +PT_CONYARD = r"Yard" +PT_MCV = r"Construction\s+Vehicle" +PT_SENTRY = r"Sentry\s+Gun" +PT_SOV_BARRACKS = r"Soviet\s+Barracks" +PT_TESLA_REACTOR = r"Tesla\s+Reactor" +PT_TESLA_TROOPER = r"Shock\s+Trooper" +PT_ENGI = r"Soviet\s+Engineer" +PT_OIL = r"Oil\s+Derrick" +PT_SOV_WALL = r"Soviet\s+Wall" + + +# TODO(shmocz): get rid of task groups +class MultiTest(unittest.IsolatedAsyncioTestCase): + async def asyncSetUp(self): + self.poll_frequency = 30 + self.fetch_state_timeout = 10.0 + self.managers: list[MyManager] = [] + for i in range(2): + M = MyManager( + port=14521 + i, + poll_frequency=self.poll_frequency, + fetch_state_timeout=self.fetch_state_timeout, + ) + M.start() + self.managers.append(M) + + async def deploy_mcvs(self): + for M in self.managers: + o = next(M.state.query_objects(p=PT_MCV, h=M.state.current_player())) + await M.M.deploy(o.o) + # wait until deployed + await self.managers[0].wait_state( + lambda: len(list(self.managers[0].state.query_objects(p=PT_CONYARD))) + == len(self.managers) + ) + await self.managers[0].wait_state( + lambda: all( + o.get().current_mission == ra2yr.Mission_Guard + for o in self.managers[0].state.query_objects(p=PT_CONYARD) + ) + ) + + async def check_defense_structure_attack(self, M: MyManager): + o_mcv = next(M.state.query_objects(p=PT_CONYARD, h=M.state.current_player())) + # Build conscript and sentry + async with asyncio.TaskGroup() as tg: + t = tg.create_task(M.produce_unit(PT_CONNIE)) + o_sentry = await M.produce_and_place( + M.get_unique_tc(PT_SENTRY), o_mcv.coordinates + ) + o_con = await t + await M.M.attack(objects=o_sentry.o, target_object=o_con.o) + + async def check_engi_capture( + self, M: MyManager, engi: ObjectEntry, target: ObjectEntry + ): + self.assertNotEqual(target.get().pointer_house, engi.get().pointer_house) + # Capture + await M.M.capture(objects=engi.get(), target=target.get()) + # Wait until captured + await M.wait_state( + lambda: target.get().pointer_house == engi.get().pointer_house + ) + self.assertEqual(target.get().pointer_house, engi.get().pointer_house) + + async def check_capture(self, M: MyManager): + # Build engineer + o_engi = await M.produce_unit(PT_ENGI) + + # Get nearest oil + oils = list(M.state.query_objects(p=PT_OIL)) + dists = pdist( + o_engi.coordinates, np.array([o.coordinates for o in oils]), axis=1 + ) + ix = np.argsort(dists)[0] + await self.check_engi_capture(M, o_engi, oils[ix]) + + async def check_repair_building( + self, M: MyManager, src: ObjectEntry, dst: ObjectEntry + ): + self.assertLess(dst.health, 1.0) + await M.M.repair(obj=src.get(), target=dst.get()) + await M.wait_state(lambda: dst.health == 1.0) + self.assertEqual(dst.health, 1.0) + + async def do_check_repair_building(self, M: MyManager, num_attackers=5): + attackers = [] + th = 0.7 + for _ in range(num_attackers): + obj = await M.produce_unit(PT_CONNIE) + attackers.append(obj.get()) + engi = await M.produce_unit(PT_ENGI) + + o_tesla = next( + M.state.query_objects(p=PT_TESLA_REACTOR, h=M.state.current_player()) + ) + await M.M.attack(objects=attackers, target_object=o_tesla.get()) + await M.wait_state(lambda: o_tesla.health < th) + + self.assertLess(o_tesla.health, th) + await M.M.stop(objects=attackers) + await self.check_repair_building(M, engi, o_tesla) + + async def sell_all_buildings(self, M: MyManager): + for o in M.state.query_objects( + h=M.state.current_player(), a=ra2yr.ABSTRACT_TYPE_BUILDING + ): + await M.M.sell(objects=o.get()) + + async def do_build_stuff(self): + for bname in [PT_TESLA_REACTOR, PT_SOV_BARRACKS]: + async with asyncio.TaskGroup() as tg: + for m in self.managers: + o_mcv = next( + m.state.query_objects(p=PT_CONYARD, h=m.state.current_player()) + ) + tg.create_task( + m.produce_and_place(m.get_unique_tc(bname), o_mcv.coordinates) + ) + + async def do_build_sell_walls(self, M: MyManager): + o_mcv = next(M.state.query_objects(p=PT_CONYARD, h=M.state.current_player())) + # Get corner coordinates for walls + D = await M.get_map_data() + I = D.ind2sub( + np.array( + [ + c.index + for c in D.m.cells + if o_mcv.get().pointer_self in [q.pointer_self for q in c.objects] + ] + ) + ) + # Get corners + B = D.bbox(I) * 256 + + tc_wall = next( + M.state.query_type_class( + p=PT_SOV_WALL, abstract_type=ra2yr.ABSTRACT_TYPE_BUILDINGTYPE + ) + ) + # Build walls to corners + for c in B: + await M.produce_and_place(tc_wall, c, strategy=PlaceStrategy.ABSOLUTE) + + # Get wall cells + tc_wall_ol = next( + M.state.query_type_class( + p=PT_SOV_WALL, abstract_type=ra2yr.ABSTRACT_TYPE_OVERLAYTYPE + ) + ) + D = await M.get_map_data() + I = ( + D.ind2sub( + np.array( + [ + c.index + for c in D.m.cells + if c.overlay_type_index == tc_wall_ol.array_index + ] + ) + ) + * 256 + ) + I = np.c_[I, np.ones((I.shape[0], 1)) * o_mcv.coordinates[2]] + + # Sell walls + # FIXME: check if z-coordintate fucks this up + for c in I: + await M.M.sell_walls(coordinates=array2coord(c)) + await asyncio.sleep(0.1) + lg.info("OK") + + # Check that no walls exist + + async def test_multi_client(self): + random.seed(1234) + + for M in self.managers: + await M.M.wait_game_to_begin() + + await self.deploy_mcvs() + await self.do_build_stuff() + + M = self.managers[0] + await self.do_build_sell_walls(M) + await self.do_check_repair_building(M) + await self.check_defense_structure_attack(M) + await self.check_capture(M) + await self.sell_all_buildings(M) + await M.M.wait_game_to_exit() + + # TODO(shmocz): check record file + + +if __name__ == "__main__": + unittest.main() diff --git a/pyra2yr/test_util.py b/pyra2yr/test_util.py new file mode 100644 index 0000000..1f07e24 --- /dev/null +++ b/pyra2yr/test_util.py @@ -0,0 +1,174 @@ +import unittest +from typing import Type + +import numpy as np +from ra2yrproto import commands_yr, core, ra2yr + +from pyra2yr.manager import Manager, ManagerUtil, PlaceStrategy +from pyra2yr.state_objects import FactoryEntry, MapData, ObjectEntry +from pyra2yr.util import array2coord, coord2array + + +async def check_config(U: ManagerUtil = None): + # Get config + cmd_1 = await U.inspect_configuration() + cfg1 = cmd_1.config + cfg_ex = commands_yr.Configuration() + cfg_ex.CopyFrom(cfg1) + cfg_ex.debug_log = True + cfg_ex.parse_map_data_interval = 1 + assert cfg_ex == cfg1 + + # Try changing some settings + cfg_diff = commands_yr.Configuration(parse_map_data_interval=4) + cfg2_ex = commands_yr.Configuration() + cfg2_ex.CopyFrom(cfg1) + cfg2_ex.MergeFrom(cfg_diff) + cmd_2 = await U.inspect_configuration(config=cfg_diff) + cfg2 = cmd_2.config + assert cfg2 == cfg2_ex + + +class MyManager(Manager): + async def get_place_locations( + self, coords: np.array, o: ra2yr.Object, rx: int, ry: int + ) -> np.array: + """Return coordinates where a building can be placed. + + Parameters + ---------- + coords : np.array + Center point + o : ra2yr.Object + The ready building + rx : int + Query x radius + ry : int + Query y radius + + Returns + ------- + np.array + Result coordinates. + """ + xx = np.arange(rx) - int(rx / 2) + yy = np.arange(ry) - int(ry / 2) + if coords.size < 3: + coords = np.append(coords, 0) + grid = np.transpose([np.tile(xx, yy.shape), np.repeat(yy, xx.shape)]) * 256 + grid = np.c_[grid, np.zeros((grid.shape[0], 1))] + coords + res = await self.M.place_query( + type_class=o.pointer_technotypeclass, + house_class=o.pointer_house, + coordinates=[array2coord(x) for x in grid], + ) + return np.array([coord2array(x) for x in res.coordinates]) + + def get_unique_tc(self, pattern) -> ra2yr.ObjectTypeClass: + tc = list(self.state.query_type_class(p=pattern)) + if len(tc) != 1: + raise RuntimeError(f"Non unique TypeClass: {pattern}: {tc}") + return tc[0] + + async def begin_production(self, t: ra2yr.ObjectTypeClass) -> FactoryEntry: + # TODO: Check for error + await self.M.start_production(t) + frame = self.state.s.current_frame + + # Wait until corresponding type is being produced + await self.wait_state( + lambda: any( + f + for f in self.state.query_factories(h=self.state.current_player(), t=t) + ) + or self.state.s.current_frame - frame >= 300 + ) + try: + # Get the object + return next(self.state.query_factories(h=self.state.current_player(), t=t)) + except StopIteration as e: + raise RuntimeError(f"Failed to start production of {t}") from e + + async def produce_unit(self, t: ra2yr.ObjectTypeClass | str) -> ObjectEntry: + if isinstance(t, str): + t = self.get_unique_tc(t) + + fac = await self.begin_production(t) + + # Wait until done + # TODO: check for cancellation + await self.wait_state( + lambda: fac.object.get().current_mission == ra2yr.Mission_Guard + ) + return fac.object + + async def produce_and_place( + self, + t: ra2yr.ObjectTypeClass, + coords, + strategy: PlaceStrategy = PlaceStrategy.FARTHEST, + ) -> ObjectEntry: + U = self.M + fac = await self.begin_production(t) + obj = fac.object + + # wait until done + await self.wait_state(lambda: fac.get().completed) + + if strategy == PlaceStrategy.FARTHEST: + place_locations = await self.get_place_locations( + coords, + obj.get(), + 15, + 15, + ) + + # get cell closest away and place + dists = np.sqrt(np.sum((place_locations - coords) ** 2, axis=1)) + coords = place_locations[np.argsort(dists)[-1]] + r = await U.place_building(building=obj.get(), coordinates=array2coord(coords)) + if r.result_code != core.ResponseCode.OK: + raise RuntimeError(f"place failed: {r.error_message}") + # wait until building has been placed + await self.wait_state( + lambda: obj.invalid() + or fac.invalid() + and obj.get().current_mission + not in (ra2yr.Mission_None, ra2yr.Mission_Construction) + ) + return obj + + async def get_map_data(self) -> MapData: + W = await self.M.map_data() + return MapData(W) + + +class BaseGameTest(unittest.IsolatedAsyncioTestCase): + async def asyncSetUp(self): + self.poll_frequency = 30 + self.fetch_state_timeout = 10.0 + self.M = self.get_manager() + self.sm = self.M.state + self.M.start() + + async def asyncTearDown(self): + await self.M.stop() + + async def check_record_output_defined(self): + # Get config + cmd_1 = await self.M.M.inspect_configuration() + cfg1 = cmd_1.config + self.assertNotEqual( + cfg1.record_filename, + "", + "No record output. Is RA2YRCPP_RECORD_PATH environment variable is set.", + ) + + def get_manager_class(self) -> Type[Manager]: + return Manager + + def get_manager(self) -> Manager: + return self.get_manager_class()( + poll_frequency=self.poll_frequency, + fetch_state_timeout=self.fetch_state_timeout, + ) diff --git a/pyra2yr/util.py b/pyra2yr/util.py new file mode 100644 index 0000000..f3e0090 --- /dev/null +++ b/pyra2yr/util.py @@ -0,0 +1,201 @@ +import datetime +import io +import logging +import pstats +import re +from enum import Enum, IntFlag, auto +from typing import Iterator, List, Tuple + +from google.protobuf import text_format +from google.protobuf.internal.decoder import _DecodeVarint32 +import numpy as np +from ra2yrproto import ra2yr + + +def tuple2coord(x) -> ra2yr.Coordinates: + return ra2yr.Coordinates(**{k: x[i] for i, k in enumerate("xyz")}) + + +def unixpath(p: str): + return re.sub(r"^\w:", "", re.sub(r"\\+", "/", p)) + + +def coord2cell(x): + if isinstance(x, tuple): + return tuple(coord2cell(v) for v in x) + return int(x / 256) + + +def cell2coord(x) -> Tuple[int]: + if isinstance(x, tuple): + return tuple(cell2coord(v) for v in x) + return int(x * 256) + + +def coord2tuple(x: ra2yr.Coordinates) -> Tuple[int]: + return tuple(getattr(x, k) for k in "xyz") + + +def array2coord(x: np.array) -> ra2yr.Coordinates: + return ra2yr.Coordinates(**{k: int(v) for k, v in zip("xyz", x)}) + + +def coord2array(x: ra2yr.Coordinates) -> np.array: + return np.array(coord2tuple(x)) + + +def pdist(x1, x2, axis=0): + return np.sqrt(np.sum((x1 - x2) ** 2, axis=axis)) + + +def read_protobuf_messages(f) -> Iterator[ra2yr.GameState]: + buf = f.read(10) # Maximum length of length prefix + while buf: + msg_len, new_pos = _DecodeVarint32(buf, 0) + buf = buf[new_pos:] + buf += f.read(msg_len - len(buf)) + s = ra2yr.GameState() + s.ParseFromString(buf) + yield s + buf = buf[msg_len:] + buf += f.read(10 - len(buf)) + + +def msg_oneline(m): + return text_format.MessageToString(m, as_one_line=True) + + +class EventListID(Enum): + out_list = 1 + do_list = 2 + megamission_list = 3 + + +class EventHistory: + def __init__(self, lists: ra2yr.EventListsSnapshot): + self.lists = lists + + def unique_events(self): + """Get unique events across all of the event queues for the entire timespan""" + all_events = [] + for l in self.lists: + for k in ["out_list", "do_list", "megamission_list"]: + all_events.extend((k, x.frame, x.timing, x) for x in getattr(l, k)) + keys = [(x[0], x[1], x[2]) for x in all_events] + res = [] + for u in set(keys): + res.append(all_events[keys.index(u)]) + return res + + def get_indexed(self) -> List[List[Tuple[EventListID, ra2yr.Event]]]: + """For each frame, flattens events from each event list into a single list. + + This is useful for determining the point when an event was consumed. + """ + res = [] + for l in self.lists.lists: + flattened = [] + for k in EventListID: + events = getattr(l, k.name) + flattened.extend((k, x) for x in events) + res.append(flattened) + return res + + +def setup_logging(level=logging.INFO): + FORMAT = ( + "[%(levelname)s] %(asctime)s %(module)s.%(filename)s:%(lineno)d: %(message)s" + ) + logging.basicConfig(level=level, format=FORMAT) + start_time = datetime.datetime.now().isoformat() + level_name = logging.getLevelName(level) + logging.info("Logging started at: %s, level=%s", start_time, level_name) + + +class Clock: + def __init__(self): + self.t = self.tic() + + def tic(self): + self.t = datetime.datetime.now() + return self.t + + def toc(self): + return (datetime.datetime.now() - self.t).total_seconds() + + +def finish_profiler(p): + p.disable() + s = io.StringIO() + ps = pstats.Stats(p, stream=s).sort_stats(pstats.SortKey.CUMULATIVE) + ps.print_stats() + return s.getvalue() + + +class PrerequisiteCheck(IntFlag): + OK = auto() + EMPTY_PREREQUISITES = auto() + NOT_IN_MYIDS = auto() + INVALID_PREQ_GROUP = auto() + NOT_PREQ_GROUP_IN_MYIDS = auto() + BAD_TECH_LEVEL = auto() + NO_STOLEN_TECH = auto() + FORBIDDEN_HOUSE = auto() + BUILD_LIMIT_REACHED = auto() + + +def check_stolen_tech(ttc: ra2yr.ObjectTypeClass, h: ra2yr.House): + return ( + (ttc.requires_stolen_allied_tech and not h.allied_infiltrated) + or (ttc.requires_stolen_soviet_tech and not h.soviet_infiltrated) + or (ttc.requires_stolen_third_tech and not h.third_infiltrated) + ) + + +def check_preq_list(ttc: ra2yr.ObjectTypeClass, myids, prerequisite_map): + res = PrerequisiteCheck.OK + if not ttc.prerequisites: + res |= PrerequisiteCheck.EMPTY_PREREQUISITES + for p in ttc.prerequisites: + if p >= 0 and p not in myids: + res |= PrerequisiteCheck.NOT_IN_MYIDS + if p < 0: + if p not in prerequisite_map: + res |= PrerequisiteCheck.INVALID_PREQ_GROUP + if not myids.intersection(prerequisite_map[p]): + res |= PrerequisiteCheck.NOT_PREQ_GROUP_IN_MYIDS + return res + + +def cell_grid(coords, rx: int, ry: int) -> List[ra2yr.Coordinates]: + """Get a rectangular grid centered on given coordinates. + + Parameters + ---------- + coords : _type_ + The center coordinates for the grid. + rx : int + x-axis radius + ry : int + y-axis radius + + Returns + ------- + List[ra2yr.Coordinates] + List of cell coordinates lying on the given bounds. + """ + # TODO(shmocz): generalize + if coords.size < 3: + coords = np.append(coords, 0) + place_query_grid = [] + for i in range(0, rx): + for j in range(0, ry): + place_query_grid.append( + tuple2coord( + tuple( + x + y * 256 + for x, y in zip(coords, (i - int(rx / 2), j - int(ry / 2), 0)) + ) + ) + ) + return place_query_grid