Skip to content

Commit

Permalink
Support printing DISCRETE sample sets
Browse files Browse the repository at this point in the history
  • Loading branch information
arcondello committed Sep 25, 2020
1 parent ded8458 commit 33d4b7f
Show file tree
Hide file tree
Showing 2 changed files with 43 additions and 4 deletions.
21 changes: 17 additions & 4 deletions dimod/serialization/format.py
Original file line number Diff line number Diff line change
Expand Up @@ -122,17 +122,20 @@ def f(datum):

self.append(header, f)

def append_sample(self, v, vartype, _left=False):
def append_sample(self, v, vartype, _left=False, width=2):
"""Add a sample column"""
vstr = str(v).rjust(2) # the variable will be len 0, or 1
vstr = str(v).rjust(width)
length = len(vstr)

if vartype is dimod.SPIN:
def f(datum):
return _spinstr(datum.sample[v], rjust=length)
else:
elif vartype is dimod.BINARY:
def f(datum):
return _binarystr(datum.sample[v], rjust=length)
else:
def f(datum):
return str(datum.sample[v]).rjust(length)

self.append(vstr, f, _left=_left)

Expand Down Expand Up @@ -315,7 +318,17 @@ def _print_samples(self, sampleset, stream, width, depth, sorted_by):
table.rotate(-1) # move the index to the end
num_added = 0
for v in sampleset.variables:
table.append_sample(v, sampleset.vartype)
if sampleset.vartype is dimod.DISCRETE:
# in the discrete case we also need the width. There are faster
# ways to get it
valwidth = max(
(len(str(c)) for c in sampleset.samples()[:depth, v]),
default=0)
valwidth = max(len(str(v)), valwidth)

table.append_sample(v, sampleset.vartype, width=valwidth)
else:
table.append_sample(v, sampleset.vartype)
num_added += 1

if table.width > width:
Expand Down
26 changes: 26 additions & 0 deletions tests/test_serialization_format.py
Original file line number Diff line number Diff line change
Expand Up @@ -180,3 +180,29 @@ def test_additional_fields(self):
"['BINARY', 2 rows, 2 samples, 5 variables]"])

self.assertEqual(target, s)

def test_discrete(self):
ss = dimod.SampleSet.from_samples(([[0, 17, 236], [3, 321, 1]], 'abc'),
'DISCRETE', energy=[1, 2])
s = Formatter(width=79, depth=None).format(ss)
target = '\n'.join([" a b c energy num_oc.",
"0 0 17 236 1 1",
"1 3 321 1 2 1",
"['DISCRETE', 2 rows, 2 samples, 3 variables]"])

self.assertEqual(target, s)

def test_depth(self):
ss = dimod.SampleSet.from_samples(([[0, 17, 236],
[3, 321, 1],
[4444444444, 312, 1],
[4, 3, 3]], 'abc'),
'DISCRETE', energy=[1, 2, 3, 4])
s = Formatter(width=79, depth=2).format(ss)
target = '\n'.join([" a b c energy num_oc.",
"0 0 17 236 1 1",
"...",
"3 4 3 3 4 1",
"['DISCRETE', 4 rows, 4 samples, 3 variables]"])

self.assertEqual(target, s)

0 comments on commit 33d4b7f

Please sign in to comment.