Skip to content

Commit

Permalink
replaced LheData.from_path with LheData.from_storage
Browse files Browse the repository at this point in the history
  • Loading branch information
jacanchaplais committed Dec 5, 2022
1 parent 65990e1 commit 2c7debf
Show file tree
Hide file tree
Showing 2 changed files with 78 additions and 46 deletions.
4 changes: 1 addition & 3 deletions showerpipe/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -6,9 +6,7 @@
event generation programs for Pythonic usage, and exposing the data via
NumPy arrays.
"""
from . import generator
from . import lhe
from ._version import __version__


__all__ = ["generator", "lhe", "__version__"]
__all__ = ["__version__"]
120 changes: 77 additions & 43 deletions showerpipe/lhe.py
Original file line number Diff line number Diff line change
Expand Up @@ -6,18 +6,20 @@
redistribute and repeat hard events, outputting valid lhe files.
"""

from typing import Union, BinaryIO, Iterator, Optional
from typing import Union, Iterator, Optional, Callable
import io
import re
import gzip
import requests # type: ignore
from contextlib import contextmanager
import shutil
from contextlib import contextmanager, ExitStack
import tempfile
from pathlib import Path
from copy import deepcopy
from itertools import chain
from functools import cached_property
from xml.sax.saxutils import unescape
from urllib.parse import urlparse
from urllib.request import urlopen

from lxml import etree # type: ignore
from lxml.etree import ElementBase
Expand All @@ -35,7 +37,7 @@


@contextmanager
def source_adapter(source: _LHE_STORAGE) -> Iterator[BinaryIO]:
def source_adapter(source: _LHE_STORAGE) -> Iterator[io.BufferedIOBase]:
"""Context manager to provide a consistent adapter interface for LHE
data stored in various formats.
Expand All @@ -49,44 +51,51 @@ def source_adapter(source: _LHE_STORAGE) -> Iterator[BinaryIO]:
Returns
-------
LHE data : io.BytesIO, io.BufferedReader
File-like object containing the Les Houches data. Interface is
io.BytesIO if source is string or bytestring, and
io.BufferedReader if source is filepath.
lhe_file : io.BufferedIOBase
File-like object containing the Les Houches data. If ``source``
is a URL pointing to a non-gzipped file, ``lhe_file`` will not
be seekable.
"""
is_bytes = isinstance(source, bytes)
is_str = isinstance(source, str)
is_path = (is_str or isinstance(source, Path)) and Path(source).exists()
is_path = (is_str or isinstance(source, Path)) and Path(
source
).exists() # type: ignore
is_url = False
out_io: io.BufferedIOBase
if is_str:
is_url = bool(urlparse(source).netloc) # type: ignore
if is_path: # provide a file-object referring to the actual file
if is_path:
path = Path(source) # type: ignore
try:
with open(path, "r") as lhe_filecheck:
with io.open(path, "r") as lhe_filecheck:
lhe_filecheck.read(1)
lhe_file = open(path, "rb")
yield lhe_file
out_io = io.open(path, "rb")
yield out_io
except UnicodeDecodeError:
lhe_file = gzip.open(path, "rb") # type: ignore
lhe_file.read(1)
lhe_file.seek(0)
yield lhe_file # type: ignore
out_io = gzip.open(path, "rb")
out_io.read(1)
out_io.seek(0)
yield out_io
finally:
lhe_file.close() # type: ignore
out_io.close() # type: ignore
elif is_str or is_bytes or is_url: # create a BytesIO file-object
if is_url:
lhe_request = requests.get(source) # type: ignore
lhe_content = lhe_request.content
lhe_response = urlopen(source) # type: ignore
out_io = gzip.GzipFile(fileobj=lhe_response, mode="rb")
try:
xml_bytes = gzip.decompress(lhe_content)
out_io.read(1)
out_io.seek(0)
except gzip.BadGzipFile:
xml_bytes = lhe_content
elif is_str:
xml_bytes = source.encode()
elif is_bytes:
xml_bytes = source
out_io = io.BytesIO(xml_bytes)
lhe_response.close()
out_io.close()
out_io = urlopen(source) # type: ignore
else:
if is_str:
xml_bytes = source.encode() # type: ignore
else:
xml_bytes: bytes = source # type: ignore
out_io = io.BytesIO(xml_bytes)
try:
yield out_io
finally:
Expand Down Expand Up @@ -153,7 +162,7 @@ def count_events(source: _LHE_STORAGE) -> int:
return num_events


def split(source: _LHE_STORAGE, stride: int):
def split(source: _LHE_STORAGE, stride: int) -> Iterator[bytes]:
"""Generator, splitting LHE file content into separate bytestrings
representing LHE files, with maximum number of events per bytestring
equal to stride.
Expand All @@ -176,13 +185,19 @@ def split(source: _LHE_STORAGE, stride: int):
-----
Particularly useful for large LHE files, which cannot fit in memory.
"""
with source_adapter(source) as xml_source:
with ExitStack() as stack:
xml_source = stack.enter_context(source_adapter(source))
if not xml_source.seekable():
temp = stack.enter_context(tempfile.TemporaryFile())
shutil.copyfileobj(xml_source, temp)
temp.seek(0)
xml_source = temp # type: ignore
lhe_root_tagname = "LesHouchesEvents"
lhe_root_parser = etree.iterparse(
source=xml_source,
events=("start",),
tag=(lhe_root_tagname,),
**_parse_kwargs
**_parse_kwargs,
)
_, lhe_root_meta = next(lhe_root_parser)
lhe_root_template = etree.Element(lhe_root_tagname)
Expand Down Expand Up @@ -225,10 +240,10 @@ def split(source: _LHE_STORAGE, stride: int):
yield etree.tostring(lhe_root)


def _root_to_bytes(root):
def _root_to_bytes(root: ElementBase) -> bytes:
content_invalid = etree.tostring(root)

def unescape_bytes(x):
def unescape_bytes(x: bytes) -> bytes:
return unescape(x.decode()).encode()

content = unescape_bytes(content_invalid)
Expand All @@ -255,30 +270,46 @@ class LheData:
Returns bytes content, with additional events by tiling.
"""

def __init__(self, content: bytes):
def __init__(self, content: bytes) -> None:
self._root: ElementBase = etree.fromstring(content)

def __repr__(self) -> str:
return f"LheData(num_events={self.num_events})"

@classmethod
def from_path(cls, path: Union[str, Path]) -> "LheData":
return cls(etree.fromstring(load_lhe(path)))
def from_storage(cls, storage: Union[str, Path]) -> "LheData":
"""Loads the LHE data directly from the given file location.
Parameters
----------
storage : str, Path
File location. Can be string, path, or a URL.
Returns
-------
lhe_data : LheData
Instance of LheData loaded with the data from the given
``storage``.
"""
return cls(load_lhe(storage))

@property
def content(self) -> bytes:
"""The LHE file contents in bytes."""
return _root_to_bytes(self._root)

@content.setter
def content(self, data):
def content(self, data: bytes) -> None:
del self.num_events
self._root = etree.fromstring(data)

@cached_property
def num_events(self):
def num_events(self) -> int:
return len(self._root.findall("event"))

@property
def _event_iter(self):
return self._root.iter("event")
def _event_iter(self) -> Iterator[ElementBase]:
return self._root.iter("event") # type: ignore

def repeat(self, repeats: int, inplace: bool = False) -> Optional[bytes]:
"""Modifies LHE content, repeating each event the number of
Expand Down Expand Up @@ -336,13 +367,13 @@ def tile(self, repeats: int, inplace: bool = False) -> Optional[bytes]:
repeats, inplace, dup_strat=self._tile_order
)

def _tile_order(self, x):
def _tile_order(self, x: Iterator[ElementBase]) -> Iterator[ElementBase]:
return x

def _repeat_order(self, x):
def _repeat_order(self, x: Iterator[ElementBase]) -> Iterator[ElementBase]:
return zip(*x)

def _build_root(self, event_iter):
def _build_root(self, event_iter: Iterator[ElementBase]) -> ElementBase:
root = deepcopy(self._root)
for event in root.findall("event"):
root.remove(event)
Expand All @@ -351,7 +382,10 @@ def _build_root(self, event_iter):
return root

def _event_duplicator(
self, repeats: int, inplace: bool, dup_strat
self,
repeats: int,
inplace: bool,
dup_strat: Callable[[Iterator[ElementBase]], Iterator[ElementBase]],
) -> Optional[bytes]:
tiled_lists = (self._event_iter for _ in range(repeats))
dup_events = chain.from_iterable(dup_strat(tiled_lists))
Expand Down

0 comments on commit 2c7debf

Please sign in to comment.