Skip to content

Commit

Permalink
[Compilation] Adding compiled_with Program attribute (#447)
Browse files Browse the repository at this point in the history
* Adding compiled_with attribute

* CHANGELOG

* Prog changes

* Testing

* Further tests

* Formatting with black

* f string

* Update strawberryfields/engine.py

Co-authored-by: Josh Izaac <josh146@gmail.com>

* Update strawberryfields/engine.py

Co-authored-by: Josh Izaac <josh146@gmail.com>

* Adding recompile logic

* Update strawberryfields/engine.py

Co-authored-by: Josh Izaac <josh146@gmail.com>

* Update strawberryfields/engine.py

Co-authored-by: Josh Izaac <josh146@gmail.com>

* Making recompile kwarg

* Adding further tests

* Reorganizing if statements; adjusting messages and tests

* Rewording docstring

* Linting formatting

* Updating comment

* Removing if statement for error; tidying tests

* Remove comment

* Update strawberryfields/engine.py

Co-authored-by: Josh Izaac <josh146@gmail.com>

* Adding recompile to the signature of run_async

* Formatting with black

Co-authored-by: Josh Izaac <josh146@gmail.com>
  • Loading branch information
antalszava and josh146 authored Sep 10, 2020
1 parent e46bd12 commit 656866b
Show file tree
Hide file tree
Showing 5 changed files with 205 additions and 6 deletions.
6 changes: 5 additions & 1 deletion .github/CHANGELOG.md
Original file line number Diff line number Diff line change
Expand Up @@ -4,6 +4,10 @@

<h3>Improvements</h3>

* Adds the ability to bypass recompilation of programs if they have been
compiled already to the target device.
[(#447)](https://github.com/XanaduAI/strawberryfields/pull/447)

<h3>Breaking Changes</h3>

* Changes the default compiler for devices that don't specify a default from `"Xcov"` to `"Xunitary"`.
Expand All @@ -23,7 +27,7 @@

This release contains contributions from (in alphabetical order):

Josh Izaac, Nicolás Quesada
Josh Izaac, Nicolás Quesada, Antal Száva

# Release 0.15.0 (current release)

Expand Down
1 change: 1 addition & 0 deletions strawberryfields/api/devicespec.py
Original file line number Diff line number Diff line change
Expand Up @@ -151,6 +151,7 @@ def create_program(self, **parameters):
# evaluate the blackbird template
bb = bb(**parameters)
prog = sf.io.to_program(bb)
prog._compile_info = (self, self.default_compiler)
return prog

def refresh(self):
Expand Down
53 changes: 49 additions & 4 deletions strawberryfields/engine.py
Original file line number Diff line number Diff line change
Expand Up @@ -559,6 +559,7 @@ def run(self, program: Program, *, compile_options=None, **kwargs) -> Optional[R
Args:
program (strawberryfields.Program): the quantum circuit
compile_options (None, Dict[str, Any]): keyword arguments for :meth:`.Program.compile`
Keyword Args:
shots (Optional[int]): The number of shots for which to run the job. If this
Expand Down Expand Up @@ -590,7 +591,9 @@ def run(self, program: Program, *, compile_options=None, **kwargs) -> Optional[R
self._connection.cancel_job(job.id)
raise KeyboardInterrupt("The job has been cancelled.")

def run_async(self, program: Program, *, compile_options=None, **kwargs) -> Job:
def run_async(
self, program: Program, *, compile_options=None, recompile=False, **kwargs
) -> Job:
"""Runs a non-blocking remote job.
In the non-blocking mode, a ``Job`` object is returned immediately, and the user can
Expand All @@ -599,6 +602,8 @@ def run_async(self, program: Program, *, compile_options=None, **kwargs) -> Job:
Args:
program (strawberryfields.Program): the quantum circuit
compile_options (None, Dict[str, Any]): keyword arguments for :meth:`.Program.compile`
recompile (bool): Specifies if ``program`` should be recompiled
using ``compile_options``, or if not provided, the default compilation options.
Keyword Args:
shots (Optional[int]): The number of shots for which to run the job. If this
Expand All @@ -614,10 +619,50 @@ def run_async(self, program: Program, *, compile_options=None, **kwargs) -> Job:
device = self.device_spec

compiler_name = compile_options.get("compiler", device.default_compiler)
msg = f"Compiling program for device {device.target} using compiler {compiler_name}."
self.log.info(msg)

program = program.compile(device=device, **compile_options)
program_is_compiled = program.compile_info is not None

if program_is_compiled and not recompile:
# error handling for program compilation:
# program was compiled but recompilation was not allowed by the
# user

if (
program.compile_info[0].target != device.target
or program.compile_info[0]._spec != device._spec
):
# program was compiled for a different device
raise ValueError(
"Cannot use program compiled with "
f"{program._compile_info[0].target} for target {self.target}. "
'Pass the "recompile=True" keyword argument '
f"to compile with {compiler_name}."
)

if not program_is_compiled:
# program is not compiled
msg = f"Compiling program for device {device.target} using compiler {compiler_name}."
self.log.info(msg)
program = program.compile(device=device, **compile_options)

elif recompile:
# recompiling program
if compile_options:
msg = f"Recompiling program for device {device.target} using the specified compiler options: {compile_options}."
else:
msg = f"Recompiling program for device {device.target} using compiler {compiler_name}."

self.log.info(msg)
program = program.compile(device=device, **compile_options)

else:
# validating program
msg = (
f"Program previously compiled for {device.target} using {program.compile_info[1]}. "
f"Validating program against the Xstrict compiler."
)
self.log.info(msg)
program = program.compile(device=device, compiler="Xstrict")

# update the run options if provided
run_options = {}
Expand Down
17 changes: 17 additions & 0 deletions strawberryfields/program.py
Original file line number Diff line number Diff line change
Expand Up @@ -144,6 +144,9 @@ def __init__(self, num_subsystems, name=None):
self.locked = False
#: str, None: for compiled Programs, the short name of the target Compiler template, otherwise None
self._target = None
#: tuple, None: for compiled Programs, the device spec and the short
# name of Compiler that was used, otherwise None
self._compile_info = None
#: Program, None: for compiled Programs, this is the original, otherwise None
self.source = None
#: dict[str, Parameter]: free circuit parameters owned by this Program
Expand Down Expand Up @@ -559,6 +562,7 @@ def _get_compiler(compiler_or_name):
compiled = self._linked_copy()
compiled.circuit = seq
compiled._target = target
compiled._compile_info = (device, compiler.short_name)

# Get run options of compiled program.
run_options = {k: kwargs[k] for k in ALLOWED_RUN_OPTIONS if k in kwargs}
Expand Down Expand Up @@ -655,6 +659,19 @@ def target(self):
"""
return self._target

@property
def compile_info(self):
"""The device specification and the compiler that was used during
compilation.
If the program has not been compiled, this will return ``None``.
Returns:
tuple or None: device specification and the short name of the
Compiler that was used if compiled, otherwise None
"""
return self._compile_info

def params(self, *args):
"""Create and access free circuit parameters.
Expand Down
134 changes: 133 additions & 1 deletion tests/api/test_remote_engine.py
Original file line number Diff line number Diff line change
Expand Up @@ -20,8 +20,9 @@
import pytest

import strawberryfields as sf
from strawberryfields.api import Connection, Job, JobStatus, Result
from strawberryfields.api import Connection, DeviceSpec, Job, JobStatus, Result
from strawberryfields.engine import RemoteEngine
from strawberryfields.program import Program

from .conftest import mock_return

Expand Down Expand Up @@ -207,3 +208,134 @@ def test_default_compiler(self, prog, monkeypatch, caplog):

assert engine.device_spec.default_compiler == "Xunitary"
assert caplog.records[-1].message == "Compiling program for device X8_01 using compiler Xunitary."

class MockProgram:
"""A mock program for testing"""
def __init__(self):
self.run_options = {}

def test_compile_device_invalid_device_error(self, prog, monkeypatch, caplog):
"""Tests that an error is raised if the program was compiled for
another device and recompilation was not requested."""
caplog.set_level(logging.INFO)
test_device_dict = mock_device_dict.copy()
test_device_dict["compiler"] = []

monkeypatch.setattr(Connection, "create_job", lambda self, target, program, run_options: program)
monkeypatch.setattr(Connection, "_get_device_dict", lambda *args: test_device_dict)
monkeypatch.setattr(Program, "compile", lambda *args, **kwargs: self.MockProgram())

# Setting compile_info with a dummy devicespec and compiler name
X8_spec = DeviceSpec(target="DummyDevice", connection=None, spec=None)
prog._compile_info = (X8_spec, "dummy_compiler")

engine = sf.RemoteEngine("X8")
with pytest.raises(
ValueError, match="Cannot use program compiled"
):
program = engine.run_async(prog, shots=10)

def test_compile(self, prog, monkeypatch, caplog):
"""Tests that compilation happens by default if no compile_info was
specified when call run_async."""
caplog.set_level(logging.INFO)
test_device_dict = mock_device_dict.copy()
test_device_dict["compiler"] = []

monkeypatch.setattr(Connection, "create_job", lambda self, target, program, run_options: program)
monkeypatch.setattr(Connection, "_get_device_dict", lambda *args: test_device_dict)
monkeypatch.setattr(Program, "compile", lambda *args, **kwargs: self.MockProgram())

# Leaving compile_info as None
assert prog.compile_info == None

engine = RemoteEngine("X8")
program = engine.run_async(prog, shots=10)

assert isinstance(program, self.MockProgram)
assert caplog.records[-1].message == "Compiling program for device X8_01 using compiler Xunitary."

def test_recompilation(self, prog, monkeypatch, caplog):
"""Test that recompilation happens when the recompile keyword argument
was set to True."""
compiler = "Xunitary"

caplog.set_level(logging.INFO)
test_device_dict = mock_device_dict.copy()
test_device_dict["compiler"] = compiler

monkeypatch.setattr(Connection, "create_job", lambda self, target, program, run_options: program)
monkeypatch.setattr(Connection, "_get_device_dict", lambda *args: test_device_dict)

compile_options = {"compiler": compiler}

engine = sf.RemoteEngine("X8")

device = engine.device_spec

# Setting compile_info
prog._compile_info = (device, device.compiler)

program = engine.run_async(prog, shots=10, compile_options=compile_options, recompile=True)

# No recompilation, original Program
assert caplog.records[-1].message == (f"Recompiling program for device "
f"{device.target} using the specified compiler options: "
f"{compile_options}.")

def test_recompilation_precompiled(self, prog, monkeypatch, caplog):
"""Test that recompilation happens when:
1. the program was precompiled
2. but the recompile keyword argument was set to True.
The program is considered to be precompiled if program.compile_info was
set (setting it in the test case).
"""
caplog.set_level(logging.INFO)
test_device_dict = mock_device_dict.copy()
test_device_dict["compiler"] = []

monkeypatch.setattr(Connection, "create_job", lambda self, target, program, run_options: program)
monkeypatch.setattr(Connection, "_get_device_dict", lambda *args: test_device_dict)
monkeypatch.setattr(Program, "compile", lambda *args, **kwargs: self.MockProgram())

# Setting compile_info
prog._compile_info = (None, "dummy_compiler")

# Setting compile_info with a dummy devicespec and compiler name
X8_spec = DeviceSpec(target="DummyDevice", connection=None, spec=None)
prog._compile_info = (X8_spec, "dummy_compiler")

engine = sf.RemoteEngine("X8")

compile_options = None

# Setting recompile in keyword arguments
program = engine.run_async(prog, shots=10, compile_options=compile_options, recompile=True)
assert isinstance(program, self.MockProgram)
assert caplog.records[-1].message == "Recompiling program for device X8_01 using compiler Xunitary."

def test_validation(self, prog, monkeypatch, caplog):
"""Test that validation happens (no recompilation) when the target
device and device spec match."""
compiler = "Xunitary"

caplog.set_level(logging.INFO)
test_device_dict = mock_device_dict.copy()
test_device_dict["compiler"] = compiler

monkeypatch.setattr(Connection, "create_job", lambda self, target, program, run_options: program)
monkeypatch.setattr(Connection, "_get_device_dict", lambda *args: test_device_dict)

engine = sf.RemoteEngine("X8_01")

device = engine.device_spec

# Setting compile_info
prog._compile_info = (device, device.compiler)

program = engine.run_async(prog, shots=10)

# No recompilation, original Program
assert caplog.records[-1].message == (f"Program previously compiled for {device.target} using {prog.compile_info[1]}. "
f"Validating program against the Xstrict compiler.")

0 comments on commit 656866b

Please sign in to comment.