Skip to content

Commit

Permalink
organize java files
Browse files Browse the repository at this point in the history
  • Loading branch information
seanmacavaney committed Jul 31, 2024
1 parent 9869fe3 commit 0d1acdd
Show file tree
Hide file tree
Showing 2 changed files with 179 additions and 117 deletions.
50 changes: 30 additions & 20 deletions pyterrier/java/core.py
Original file line number Diff line number Diff line change
@@ -1,23 +1,16 @@
import os
from pyterrier.java import required_raise, required, before_init, started, mavenresolver, JavaClasses, JavaInitializer, register_config
from typing import Dict, Optional
from typing import Optional
import pyterrier as pt


# Module-level references to the Python stream wrappers handed to Java for stdout/stderr. They are kept here
# so the wrappers are not garbage collected while Java may still call back into them.
stdout_ref = None
stderr_ref = None



# Registers the default settings for the 'pyterrier.java' component. The returned object is the handle used to
# update them (e.g. add_jar() appends to 'jars').
configure = register_config(
    'pyterrier.java',
    dict(
        jars=[],
        options=[],
        mem=None,
        log_level='WARN',
        redirect_io=True,
    ),
)
# Module-level references to the Python stream wrappers handed to Java for stdout/stderr. They are kept here
# so the wrappers are not garbage collected while Java may still call back into them.
_stdout_ref = None
_stderr_ref = None


# ----------------------------------------------------------
# Java Initialization
# ----------------------------------------------------------

class CoreInit(JavaInitializer):
def priority(self) -> int:
Expand Down Expand Up @@ -111,22 +104,22 @@ def writeChar(self, chara):
return self.pystream.write(bytes([chara]))
return self.pystream.write(chr(chara))

# we need to hold lifetime references to stdout_ref/stderr_ref, to ensure
# we need to hold lifetime references to _stdout_ref/_stderr_ref, to ensure
# they arent GCd. This prevents a crash when Java callsback to GCd py obj

global stdout_ref
global stderr_ref
global _stdout_ref
global _stderr_ref
import sys
stdout_ref = MyOut(sys.stdout)
stderr_ref = MyOut(sys.stderr)
_stdout_ref = MyOut(sys.stdout)
_stderr_ref = MyOut(sys.stderr)
jls = autoclass("java.lang.System")
jls.setOut(
autoclass('java.io.PrintStream')(
autoclass('org.terrier.python.ProxyableOutputStream')(stdout_ref),
autoclass('org.terrier.python.ProxyableOutputStream')(_stdout_ref),
signature="(Ljava/io/OutputStream;)V"))
jls.setErr(
autoclass('java.io.PrintStream')(
autoclass('org.terrier.python.ProxyableOutputStream')(stderr_ref),
autoclass('org.terrier.python.ProxyableOutputStream')(_stderr_ref),
signature="(Ljava/io/OutputStream;)V"))


Expand All @@ -137,6 +130,19 @@ def unsign(signed):
return bytearray([ unsign(buffer.get(offset)) for offset in range(buffer.capacity()) ])


# ----------------------------------------------------------
# Configuration
# ----------------------------------------------------------

# Registers the default settings for the 'pyterrier.java' component. The returned object is the handle used to
# update them (e.g. add_jar() appends to 'jars').
configure = register_config(
    'pyterrier.java',
    dict(
        jars=[],
        options=[],
        mem=None,
        log_level='WARN',
        redirect_io=True,
    ),
)


@before_init
def add_jar(jar_path):
configure.append('jars', jar_path)
Expand Down Expand Up @@ -183,6 +189,10 @@ def set_log_level(level):
J.PTUtils.setLogLevel(level, None) # noqa: PT100 handled by started() check above


# ----------------------------------------------------------
# Common classes (accessible via pt.java.J.[ClassName])
# ----------------------------------------------------------

J = JavaClasses({
'ArrayList': 'java.util.ArrayList',
'Properties': 'java.util.Properties',
Expand Down
246 changes: 149 additions & 97 deletions pyterrier/java/utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -5,54 +5,19 @@
from copy import deepcopy
import pyterrier as pt


# Flag reported by started().
_started = False
# Global store of java-related configurations, keyed by the name passed to register_config().
_configs = {}

class JavaInitializer:
    """
    Base class for hooks that manage the initialization of a module that uses java components.

    Subclasses override `pre_init` to configure things before the JVM has started and `post_init` to finish setup
    once it is running. Every method defined here is a default that subclasses may override.
    """

    def priority(self) -> int:
        """
        Ordering key for this initializer: initializers with a lower priority are executed first. Defaults to 0.
        """
        return 0

    def condition(self) -> bool:
        """
        Whether this initializer should be run at all. Defaults to True.
        """
        return True

    def pre_init(self, jnius_config) -> None:
        """
        Hook called before the JVM is started. `jnius_config` is the `jnius_config` module, which can be used to
        configure java, such as by adding jars to the classpath. Does nothing by default.
        """

    def post_init(self, jnius) -> None:
        """
        Hook called after the JVM has started. `jnius` is the `jnius` module, which can be used to interact with
        java. Does nothing by default.
        """

    def message(self) -> Optional[str]:
        """
        Message to display alongside the name of the entry point after the JVM has started. Returning None (the
        default) means only the entry point name is displayed.
        """
        return None


def started() -> bool:
    """
    Reports whether pt.java.init() has been called: True if it has, False otherwise.
    """
    return _started

# ----------------------------------------------------------
# Decorators
# ----------------------------------------------------------
# These functions wrap functions/classes to enforce certain
# behavior regarding Java. For instance @pt.java.required
# automatically starts java before it's invoked (if it's not
# already started).
# ----------------------------------------------------------

@pt.utils.pre_invocation_decorator
def required(fn: Optional[Callable] = None) -> Union[Callable, bool]:
Expand All @@ -62,7 +27,7 @@ def required(fn: Optional[Callable] = None) -> Union[Callable, bool]:
Can be used as either a standalone function or a function/class @decorator. When used as a class decorator, it
is applied to all methods defined by the class.
"""
if not started():
if not _started:
init()


Expand Down Expand Up @@ -97,27 +62,40 @@ def _wrapper(*args, **kwargs):
return fn(*args, **kwargs)
return _wrapper

class JavaClasses:
    """
    Holds a mapping from short attribute names to Java class names, plus a cache of already-loaded classes.
    """

    def __init__(self, mapping: Dict[str, str]):
        """
        Stores `mapping` (attribute name -> Java class name) and starts with an empty cache.
        """
        self._cache = {}
        self._mapping = mapping

    def __dir__(self):
        """
        Lists the attribute names available on this object, i.e. the keys of the mapping.
        """
        return list(self._mapping)
# ----------------------------------------------------------
# Jnius Wrappers
# ----------------------------------------------------------
# These functions wrap jnius to make sure that java is
# running before they're called. Doing it this way allows
# functions to import them before java is loaded.
# ----------------------------------------------------------

@required_raise
def autoclass(*args, **kwargs):
    """
    Forwards all arguments to jnius.autoclass and returns its result.

    Raises an error if called before pt.java.init() is called.
    """
    # jnius is imported at call time, so this wrapper can be imported before java is loaded.
    import jnius # noqa: PT100
    jnius_autoclass = jnius.autoclass # noqa: PT100
    return jnius_autoclass(*args, **kwargs)

@required_raise
def __getattr__(self, key):
    """
    Returns the Java class registered under `key`, loading it on first access and caching it afterwards.

    A mapping value may be a class name or a zero-argument callable that returns one; a callable is invoked the
    first time the attribute is requested.

    Raises AttributeError if `key` is not in the mapping.
    """
    if key not in self._mapping:
        # This must be raised, not returned: returning the exception object would hand callers an AttributeError
        # instance as if it were the requested class, and would defeat hasattr() / getattr(obj, name, default).
        raise AttributeError(f'{self} has no attribute {key!r}')
    if key not in self._cache:
        clz = self._mapping[key]
        if callable(clz):
            clz = clz()
        self._cache[key] = pt.java.autoclass(clz)
    return self._cache[key]

@required_raise
def cast(*args, **kwargs):
    """
    Forwards all arguments to jnius.cast and returns its result.

    Raises an error if called before pt.java.init() is called.
    """
    # jnius is imported at call time, so this wrapper can be imported before java is loaded.
    import jnius # noqa: PT100
    jnius_cast = jnius.cast # noqa: PT100
    return jnius_cast(*args, **kwargs)


# ----------------------------------------------------------
# Init
# ----------------------------------------------------------
# This function (along with legacy_init) loads all modules
# registered via pyterrier.java.init entry points and starts
# the JVM.
# ----------------------------------------------------------

@pt.utils.once()
def init() -> None:
Expand Down Expand Up @@ -162,42 +140,6 @@ def init() -> None:
sys.stderr.write('Java started and loaded:\n' + ''.join(message))


def parallel_init(started: bool, configs: Dict[str, Dict[str, Any]]) -> None:
    """
    Starts java with the given configurations, mirroring the values produced by parallel_init_args().

    `started` says whether java should be started at all; when it is False nothing happens. `configs` replaces
    the global configuration store before init() is called. If java is already running, only a warning is issued.
    """
    global _configs
    if not started:
        return
    # The `started` argument shadows the module-level started() function, so the current state is queried
    # through pt.java.started().
    if pt.java.started():
        warn("Avoiding reinit of PyTerrier")
        return
    warn(f'Starting java parallel with configs {configs}')
    _configs = configs
    init()


def parallel_init_args() -> Tuple[bool, Dict[str, Dict[str, Any]]]:
    """
    Builds the argument pair accepted by parallel_init(): the value of started() and a deep copy of the global
    configuration store.
    """
    is_started = started()
    configs_snapshot = deepcopy(_configs)
    return is_started, configs_snapshot


@required_raise
def autoclass(*args, **kwargs):
    """
    Forwards all arguments to jnius.autoclass and returns its result.

    Raises an error if called before pt.java.init() is called.
    """
    # jnius is imported at call time, so this wrapper can be imported before java is loaded.
    import jnius # noqa: PT100
    jnius_autoclass = jnius.autoclass # noqa: PT100
    return jnius_autoclass(*args, **kwargs)


@required_raise
def cast(*args, **kwargs):
    """
    Forwards all arguments to jnius.cast and returns its result.

    Raises an error if called before pt.java.init() is called.
    """
    # jnius is imported at call time, so this wrapper can be imported before java is loaded.
    import jnius # noqa: PT100
    jnius_cast = jnius.cast # noqa: PT100
    return jnius_cast(*args, **kwargs)


@before_init
def legacy_init(version=None, mem=None, packages=[], jvm_opts=[], redirect_io=True, logging='WARN', home_dir=None, boot_packages=[], tqdm=None, no_download=False,helper_version = None):
"""
Expand Down Expand Up @@ -261,6 +203,88 @@ def legacy_init(version=None, mem=None, packages=[], jvm_opts=[], redirect_io=Tr
pt.terrier.set_property("terrier.mvn.coords", pkgs_string)


def started() -> bool:
    """
    Reports whether pt.java.init() has been called: True if it has, False otherwise.
    """
    return _started


class JavaInitializer:
    """
    Base class for hooks that manage the initialization of a module that uses java components.

    Subclasses override `pre_init` to configure things before the JVM has started and `post_init` to finish setup
    once it is running. Every method defined here is a default that subclasses may override.
    """

    def priority(self) -> int:
        """
        Ordering key for this initializer: initializers with a lower priority are executed first. Defaults to 0.
        """
        return 0

    def condition(self) -> bool:
        """
        Whether this initializer should be run at all. Defaults to True.
        """
        return True

    def pre_init(self, jnius_config) -> None:
        """
        Hook called before the JVM is started. `jnius_config` is the `jnius_config` module, which can be used to
        configure java, such as by adding jars to the classpath. Does nothing by default.
        """

    def post_init(self, jnius) -> None:
        """
        Hook called after the JVM has started. `jnius` is the `jnius` module, which can be used to interact with
        java. Does nothing by default.
        """

    def message(self) -> Optional[str]:
        """
        Message to display alongside the name of the entry point after the JVM has started. Returning None (the
        default) means only the entry point name is displayed.
        """
        return None


# ----------------------------------------------------------
# Parallel
# ----------------------------------------------------------
# These functions are for working in parallel mode, e.g.,
# with multiprocessing. They help restart and configure
# the JVM the same way it was started in the parent
# process.
# ----------------------------------------------------------

def parallel_init(started: bool, configs: Dict[str, Dict[str, Any]]) -> None:
    """
    Starts java with the given configurations, mirroring the values produced by parallel_init_args().

    `started` says whether java should be started at all; when it is False nothing happens. `configs` replaces
    the global configuration store before init() is called. If java is already running, only a warning is issued.
    """
    global _configs
    if not started:
        return
    # The `started` argument shadows the module-level started() function, so the current state is queried
    # through pt.java.started().
    if pt.java.started():
        warn("Avoiding reinit of PyTerrier")
        return
    warn(f'Starting java parallel with configs {configs}')
    _configs = configs
    init()


def parallel_init_args() -> Tuple[bool, Dict[str, Dict[str, Any]]]:
    """
    Builds the argument pair accepted by parallel_init(): the value of started() and a deep copy of the global
    configuration store.
    """
    is_started = started()
    configs_snapshot = deepcopy(_configs)
    return is_started, configs_snapshot


# ----------------------------------------------------------
# Configuration Utils
# ----------------------------------------------------------
# We need a global store of all java-related configurations
# so that when running in parallel, we can set everything back
# up the same way it started. These utils help manage this
# global configuration.
# ----------------------------------------------------------

class Configuration:
def __init__(self, name):
self.name = name
Expand Down Expand Up @@ -298,3 +322,31 @@ def register_config(name, config: Dict[str, Any]):
assert name not in _configs
_configs[name] = deepcopy(config)
return Configuration(name)


# ----------------------------------------------------------
# Java Classes
# ----------------------------------------------------------
# This class enables the lazy loading of java classes. It
# helps avoid needing a ton of autoclass() statements to
# pre-load Java classes.
# ----------------------------------------------------------

class JavaClasses:
    """
    Lazily resolves Java classes by short name.

    Attribute access (e.g. ``J.ArrayList``) looks the name up in the mapping, loads the class with
    ``pt.java.autoclass`` the first time it is requested, and serves it from a cache afterwards.
    """

    def __init__(self, mapping: Dict[str, str]):
        """
        `mapping` maps attribute names to Java class names. A value may also be a zero-argument callable that
        returns the class name; it is invoked the first time the attribute is accessed.
        """
        self._mapping = mapping
        self._cache = {}

    def __dir__(self):
        """
        Lists the attribute names available on this object, i.e. the keys of the mapping.
        """
        return list(self._mapping.keys())

    @required_raise
    def __getattr__(self, key):
        """
        Returns the Java class registered under `key`, loading it on first access and caching it afterwards.

        Raises AttributeError if `key` is not in the mapping.
        """
        if key not in self._mapping:
            # This must be raised, not returned: returning the exception object would hand callers an
            # AttributeError instance as if it were the requested class, and would defeat hasattr() /
            # getattr(obj, name, default).
            raise AttributeError(f'{self} has no attribute {key!r}')
        if key not in self._cache:
            clz = self._mapping[key]
            if callable(clz):
                clz = clz()
            self._cache[key] = pt.java.autoclass(clz)
        return self._cache[key]

0 comments on commit 0d1acdd

Please sign in to comment.