Skip to content

Commit

Permalink
fix: explicit 'import awkward' needed to write NumPy strings (#1266)
Browse files Browse the repository at this point in the history
  • Loading branch information
jpivarski authored Aug 15, 2024
1 parent ba3269b commit aa8b94f
Show file tree
Hide file tree
Showing 2 changed files with 32 additions and 5 deletions.
19 changes: 14 additions & 5 deletions src/uproot/writing/_cascadetree.py
Original file line number Diff line number Diff line change
Expand Up @@ -663,11 +663,20 @@ def extend(self, file, sink, data):

if datum["counter"] is None:
if datum["dtype"] == ">U0":
lengths = numpy.asarray(awkward.num(branch_array.layout))
awkward = uproot.extras.awkward()

layout = awkward.to_layout(branch_array)
if isinstance(
layout,
(awkward.contents.ListArray, awkward.contents.RegularArray),
):
layout = layout.to_ListOffsetArray64()

lengths = numpy.asarray(awkward.num(layout))
which_big = lengths >= 255

lengths_extension_offsets = numpy.empty(
len(branch_array.layout) + 1, numpy.int64
len(layout) + 1, numpy.int64
)
lengths_extension_offsets[0] = 0
numpy.cumsum(which_big * 4, out=lengths_extension_offsets[1:])
Expand All @@ -685,16 +694,16 @@ def extend(self, file, sink, data):
[
lengths.reshape(-1, 1).astype("u1"),
lengths_extension,
awkward.without_parameters(branch_array.layout),
awkward.without_parameters(layout),
],
axis=1,
)

big_endian = numpy.asarray(awkward.flatten(leafc_data_awkward))
big_endian_offsets = (
lengths_extension_offsets
+ numpy.asarray(branch_array.layout.offsets)
+ numpy.arange(len(branch_array.layout.offsets))
+ numpy.asarray(layout.offsets)
+ numpy.arange(len(layout.offsets))
).astype(">i4", copy=True)
tofill.append(
(
Expand Down
18 changes: 18 additions & 0 deletions tests/test_1264_write_NumPy_array_of_strings.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,18 @@
# BSD 3-Clause License; see https://github.com/scikit-hep/uproot5/blob/main/LICENSE

import pytest
import uproot
import os
import numpy as np


def test(tmp_path):
newfile = os.path.join(tmp_path, "example.root")

with uproot.recreate(newfile) as f:
f["t"] = {"x": np.array(["A", "B"]), "y": np.array([1, 2])}
f["t"].extend({"x": np.array(["A", "B"]), "y": np.array([1, 2])})

with uproot.open(newfile) as f:
assert f["t"]["x"].array().tolist() == ["A", "B", "A", "B"]
assert f["t"]["y"].array().tolist() == [1, 2, 1, 2]

0 comments on commit aa8b94f

Please sign in to comment.