Skip to content

Commit

Permalink
De-bug function and tests for grouping H terms
Browse files Browse the repository at this point in the history
  • Loading branch information
chmwzc committed Apr 11, 2024
1 parent 08686f2 commit 1b9f05a
Show file tree
Hide file tree
Showing 2 changed files with 21 additions and 26 deletions.
21 changes: 9 additions & 12 deletions src/qibochem/measurement/optimization.py
Original file line number Diff line number Diff line change
Expand Up @@ -49,32 +49,29 @@ def group_commuting_terms(terms_list, qubitwise):
commute), which this function follows.
Args:
terms_list: List of strings that contain X/Y. The strings should follow the output from
``[factor.name for factor in term.factors]``, where term is a Qibo SymbolicTerm. E.g. ["X0", "Z1"]
terms_list: List of strings. The strings should follow the output from
``" ".join(factor.name for factor in term.factors)``, where term is a Qibo SymbolicTerm. E.g. "X0 Z1".
qubitwise: Determines if the check is for general commutativity, or the stricter qubitwise commutativity
Returns:
list: Containing groups (lists) of Pauli strings that all commute mutually
"""
# Need to convert the lists in _terms_list to strings so that they can be used as graph nodes
terms_str = ["".join(term) for term in terms_list]
# print(terms_str)

G = nx.Graph()
# Complement graph: Add all the terms as nodes first, then add edges between nodes if they DO NOT commute
G.add_nodes_from(terms_str)
G.add_nodes_from(terms_list)
G.add_edges_from(
(term1, term2)
for _i1, term1 in enumerate(terms_str)
for _i2, term2 in enumerate(terms_str)
for _i1, term1 in enumerate(terms_list)
for _i2, term2 in enumerate(terms_list)
if _i2 > _i1 and not check_terms_commutativity(term1, term2, qubitwise)
)

# Solve using Greedy Colouring on NetworkX
term_groups = nx.coloring.greedy_color(G)
group_ids = set(term_groups.values())
sorted_groups = [[group.split() for group, group_id in term_groups.items() if group_id == _id] for _id in group_ids]
print(sorted_groups)
# Sort results so that test results will be replicable
sorted_groups = sorted(
sorted(group for group, group_id in term_groups.items() if group_id == _id) for _id in group_ids
)
return sorted_groups


Expand Down
26 changes: 12 additions & 14 deletions tests/test_measurement_optimisation.py
Original file line number Diff line number Diff line change
Expand Up @@ -33,17 +33,15 @@ def test_check_terms_commutativity(term1, term2, qwc_expected, gc_expected):
assert gc_result == gc_expected


# @pytest.mark.parametrize(
# "term_list,qwc_expected,gc_expected",
# [
# ([["X0"], ["Z0"], ["Z1"], ["Z0", "Z1"]], {[["Z0"], ["Z1"], ["Z0", "Z1"]], ["X0"]}, {[["Z0"], ["Z1"], ["Z0", "Z1"]], ["X0"]}),
# # (["X0"], ["Z1"], True, True),
# # (["X0", "X1"], ["Y0", "Y1"], False, True),
# # (["X0", "Y1"], ["Y0", "Y1"], False, False),
# ],
# )
# def test_group_commuting_terms(term_list, qwc_expected, gc_expected):
# qwc_result = group_commuting_terms(term_list, qubitwise=True)
# assert set(qwc_result) == qwc_expected
# gc_result = group_commuting_terms(term_list, qubitwise=False)
# assert set(gc_result) == gc_expected
@pytest.mark.parametrize(
"term_list,qwc_expected,gc_expected",
[
(["X0 Z1", "X0", "Z0", "Z0 Z1"], [["X0", "X0 Z1"], ["Z0", "Z0 Z1"]], [["X0", "X0 Z1"], ["Z0", "Z0 Z1"]]),
(["X0 Y1 Z2", "X1 X2", "Z1 Y2"], [["X0 Y1 Z2"], ["X1 X2"], ["Z1 Y2"]], [["X0 Y1 Z2", "X1 X2", "Z1 Y2"]]),
],
)
def test_group_commuting_terms(term_list, qwc_expected, gc_expected):
qwc_result = group_commuting_terms(term_list, qubitwise=True)
assert qwc_result == qwc_expected
gc_result = group_commuting_terms(term_list, qubitwise=False)
assert gc_result == gc_expected

0 comments on commit 1b9f05a

Please sign in to comment.