Skip to content

Commit

Permalink
Update
Browse files Browse the repository at this point in the history
[ghstack-poisoned]
  • Loading branch information
vmoens committed Jan 9, 2025
2 parents 1e93045 + 8e42b5e commit 7a74a00
Show file tree
Hide file tree
Showing 3 changed files with 11 additions and 4 deletions.
6 changes: 3 additions & 3 deletions .github/workflows/wheels-windows.yml
Original file line number Diff line number Diff line change
Expand Up @@ -36,12 +36,12 @@ jobs:
python3 -mpip install wheel
TENSORDICT_BUILD_VERSION=0.6.2 python3 setup.py bdist_wheel
- name: Upload wheel for the test-wheel job
uses: actions/upload-artifact@v3
uses: actions/upload-artifact@v4
with:
name: tensordict-win-${{ matrix.python_version[0] }}.whl
path: dist/tensordict-*.whl
- name: Upload wheel for download
uses: actions/upload-artifact@v3
uses: actions/upload-artifact@v4
with:
name: tensordict-batch.whl
path: dist/*.whl
Expand Down Expand Up @@ -72,7 +72,7 @@ jobs:
run: |
python3 -mpip install numpy pytest pytest-cov codecov unittest-xml-reporting pillow>=4.1.1 scipy av networkx expecttest pyyaml
- name: Download built wheels
uses: actions/download-artifact@v3
uses: actions/download-artifact@v4
with:
name: tensordict-win-${{ matrix.python_version }}.whl
path: wheels
Expand Down
2 changes: 1 addition & 1 deletion tensordict/functional.py
Original file line number Diff line number Diff line change
Expand Up @@ -216,7 +216,7 @@ def pad_sequence(
try:
item0 = list_of_dicts[0][key]
if is_non_tensor(item0):
out.set(key, torch.stack([d[key] for d in list_of_dicts]))
out.set(key, TensorDict.lazy_stack([d[key] for d in list_of_dicts]))
continue
tensor_shape = item0.shape
pos_pad_dim = (
Expand Down
7 changes: 7 additions & 0 deletions test/test_tensordict.py
Original file line number Diff line number Diff line change
Expand Up @@ -1920,6 +1920,13 @@ def test_pad_sequence_nontensor(self):
assert (d["a"] == torch.tensor([[1, 1], [2, 0]])).all()
assert d["b"] == ["asd", "efg"]

def test_pad_sequence_single_nontensor(self):
d1 = TensorDict({"a": torch.tensor([1, 1]), "b": "asd"})
d = pad_sequence([d1])
assert (d["a"] == torch.tensor([[1, 1]])).all()
assert d["b"] == ["asd"]
assert isinstance(d.get("b"), NonTensorStack)

def test_pad_sequence_tensorclass_nontensor(self):
@tensorclass
class Sample:
Expand Down

0 comments on commit 7a74a00

Please sign in to comment.