Skip to content

Commit

Permalink
Fixes for speedier Python-Julia interaction (#32)
Browse files Browse the repository at this point in the history
* change: Use multiprocessing Pool rather than ProcessPoolExecutor

* fix: No sigterm screaming

* fix: Force close and join at exit

* fix: Faster

* fix: handle julia error in serializable way

* change: Add benchmark workflow to CI

* change: Point to latest unblocking branch

* fix: linting

* fix: Don't turn off juliapkg in tox

* Update src/braket/simulator_v2/julia_workers.py

Co-authored-by: Ryan Shaffer <3620100+rmshaffer@users.noreply.github.com>

* Update src/braket/simulator_v2/julia_workers.py

Co-authored-by: Ryan Shaffer <3620100+rmshaffer@users.noreply.github.com>

* fix: one-line pip in benchmark

* fix: typo

* fix: Use benchmark-json option correctly

* change: Add initial output.json

* fix: Actually include benchmark script

* fix: Don't deploy benchmark results to gh-pages

* fix: restore GH token

* fix: Remove benchmarks for now

* change: Point to new BraketSimulator 0.0.4

---------

Co-authored-by: Ryan Shaffer <3620100+rmshaffer@users.noreply.github.com>
  • Loading branch information
kshyatt-aws and rmshaffer authored Aug 27, 2024
1 parent 2e8fb1e commit a6e9e80
Show file tree
Hide file tree
Showing 6 changed files with 174 additions and 195 deletions.
2 changes: 1 addition & 1 deletion src/braket/juliapkg.json
Original file line number Diff line number Diff line change
Expand Up @@ -3,7 +3,7 @@
"packages": {
"BraketSimulator": {
"uuid": "76d27892-9a0b-406c-98e4-7c178e9b3dff",
"version": "0.0.3"
"version": "0.0.4"
},
"JSON3": {
"uuid": "0f8b85d8-7281-11e9-16c2-39a750bddbf1",
Expand Down
1 change: 0 additions & 1 deletion src/braket/simulator_v2/__init__.py
Original file line number Diff line number Diff line change
@@ -1,7 +1,6 @@
from braket.simulator_v2.density_matrix_simulator_v2 import ( # noqa: F401
DensityMatrixSimulatorV2,
)
from braket.simulator_v2.julia_import import setup_julia # noqa: F401
from braket.simulator_v2.state_vector_simulator_v2 import ( # noqa: F401
StateVectorSimulatorV2,
)
Expand Down
209 changes: 90 additions & 119 deletions src/braket/simulator_v2/base_simulator_v2.py
Original file line number Diff line number Diff line change
@@ -1,123 +1,97 @@
import sys
import atexit
import json
from collections.abc import Sequence
from concurrent.futures import ProcessPoolExecutor, wait
from typing import List, Optional, Union
from multiprocessing.pool import Pool
from typing import Optional, Union

import numpy as np
from braket.default_simulator.simulator import BaseLocalSimulator
from braket.ir.jaqcd import DensityMatrix, Probability, StateVector
from braket.ir.openqasm import Program as OpenQASMProgram
from braket.task_result import GateModelTaskResult

from braket.simulator_v2.julia_import import setup_julia
from braket.simulator_v2.julia_workers import (
_handle_julia_error,
translate_and_run,
translate_and_run_multiple,
)

__JULIA_POOL__ = None

def _handle_julia_error(error):
# we don't import `JuliaError` explicitly here to avoid
# having to import juliacall on the main thread. we need
# to call *this* function on that thread in case getting
# the result from the submitted Future raises an exception
if type(error).__name__ == "JuliaError":
python_exception = getattr(error.exception, "alternate_type", None)
if python_exception is None:
py_error = error
else:
class_val = getattr(sys.modules["builtins"], str(python_exception))
py_error = class_val(str(error.exception.message))
raise py_error

def setup_julia():
import os
import sys

# don't reimport if we don't have to
if "juliacall" in sys.modules:
os.environ["PYTHON_JULIACALL_HANDLE_SIGNALS"] = "yes"
return sys.modules["juliacall"].Main
else:
raise error


def translate_and_run(
device_id: str, openqasm_ir: OpenQASMProgram, shots: int = 0
) -> str:
jl = setup_julia()
jl_shots = shots
jl_inputs = (
jl.Dict[jl.String, jl.Any](
jl.Pair(jl.convert(jl.String, k), jl.convert(jl.Any, v))
for (k, v) in openqasm_ir.inputs.items()
)
if openqasm_ir.inputs
else jl.Dict[jl.String, jl.Any]()
)
if device_id == "braket_sv_v2":
device = jl.BraketSimulator.StateVectorSimulator(0, 0)
elif device_id == "braket_dm_v2":
device = jl.BraketSimulator.DensityMatrixSimulator(0, 0)

try:
result = jl.BraketSimulator.simulate._jl_call_nogil(
device,
openqasm_ir.source,
jl_inputs,
jl_shots,
for k, default in (
("PYTHON_JULIACALL_HANDLE_SIGNALS", "yes"),
("PYTHON_JULIACALL_THREADS", "auto"),
("PYTHON_JULIACALL_OPTLEVEL", "3"),
# let the user's Conda/Pip handle installing things
("JULIA_CONDAPKG_BACKEND", "Null"),
):
os.environ[k] = os.environ.get(k, default)
# install Julia and any packages as needed
os.environ["PYTHON_JULIAPKG_OFFLINE"] = "yes"
import juliacall

jl = juliacall.Main
jl.seval("using JSON3, BraketSimulator")
sv_stock_oq3 = """
OPENQASM 3.0;
input float theta;
qubit[2] q;
h q[0];
cnot q;
x q[0];
xx(theta) q;
yy(theta) q;
zz(theta) q;
#pragma braket result expectation z(q[0])
"""
dm_stock_oq3 = """
OPENQASM 3.0;
input float theta;
qubit[2] q;
h q[0];
x q[0];
cnot q;
xx(theta) q;
yy(theta) q;
zz(theta) q;
#pragma braket result probability
"""
r = jl.BraketSimulator.simulate(
"braket_sv_v2", sv_stock_oq3, '{"theta": 0.1}', 0
)
py_result = str(result)
return py_result
except Exception as e:
_handle_julia_error(e)


def translate_and_run_multiple(
device_id: str,
programs: Sequence[OpenQASMProgram],
shots: Optional[int] = 0,
inputs: Optional[Union[dict, Sequence[dict]]] = {},
) -> List[str]:
jl = setup_julia()
irs = jl.Vector[jl.String]()
is_single_input = isinstance(inputs, dict) or len(inputs) == 1
py_inputs = {}
if (is_single_input and isinstance(inputs, dict)) or not is_single_input:
py_inputs = [inputs.copy() for _ in range(len(programs))]
elif is_single_input and not isinstance(inputs, dict):
py_inputs = [inputs[0].copy() for _ in range(len(programs))]
else:
py_inputs = inputs
jl_inputs = jl.Vector[jl.Dict[jl.String, jl.Any]]()
for p_ix, program in enumerate(programs):
irs.append(program.source)
if program.inputs:
jl_inputs.append(program.inputs | py_inputs[p_ix])
else:
jl_inputs.append(py_inputs[p_ix])

if device_id == "braket_sv_v2":
device = jl.BraketSimulator.StateVectorSimulator(0, 0)
elif device_id == "braket_dm_v2":
device = jl.BraketSimulator.DensityMatrixSimulator(0, 0)

try:
results = jl.BraketSimulator.simulate._jl_call_nogil(
device,
irs,
jl_inputs,
shots,
jl.JSON3.write(r)
r = jl.BraketSimulator.simulate(
"braket_dm_v2", dm_stock_oq3, '{"theta": 0.1}', 0
)
py_results = [str(result) for result in results]
except Exception as e:
_handle_julia_error(e)
return py_results
jl.JSON3.write(r)
return jl


def setup_pool():
global __JULIA_POOL__
__JULIA_POOL__ = Pool(processes=1)
__JULIA_POOL__.apply(setup_julia)
atexit.register(__JULIA_POOL__.join)
atexit.register(__JULIA_POOL__.close)
return


class BaseLocalSimulatorV2(BaseLocalSimulator):
def __init__(self, device: str):
global __JULIA_POOL__
if __JULIA_POOL__ is None:
setup_pool()
self._device = device
executor = ProcessPoolExecutor(max_workers=1, initializer=setup_julia)

def no_op():
pass

# trigger worker creation here, because workers are created
# on an as-needed basis, *not* when the executor is created
f = executor.submit(no_op)
wait([f])
self._executor = executor

def __del__(self):
self._executor.shutdown(wait=False)

def initialize_simulation(self, **kwargs):
return
Expand All @@ -143,18 +117,17 @@ def run_openqasm(
as a result type when shots=0. Or, if StateVector and Amplitude result types
are requested when shots>0.
"""
f = self._executor.submit(
translate_and_run,
self._device,
openqasm_ir,
shots,
)
global __JULIA_POOL__
try:
jl_result = f.result()
jl_result = __JULIA_POOL__.apply(
translate_and_run,
[self._device, openqasm_ir, shots],
)
except Exception as e:
_handle_julia_error(e)

result = GateModelTaskResult.parse_raw_schema(jl_result)
result = GateModelTaskResult(**json.loads(jl_result))
jl_result = None
result.additionalMetadata.action = openqasm_ir

# attach the result types
Expand Down Expand Up @@ -183,21 +156,19 @@ def run_multiple(
list[GateModelTaskResult]: A list of result objects, with the ith object being
the result of the ith program.
"""
f = self._executor.submit(
translate_and_run_multiple,
self._device,
programs,
shots,
inputs,
)
global __JULIA_POOL__
try:
jl_results = f.result()
jl_results = __JULIA_POOL__.apply(
translate_and_run_multiple,
[self._device, programs, shots, inputs],
)
except Exception as e:
_handle_julia_error(e)

results = [
GateModelTaskResult.parse_raw_schema(jl_result) for jl_result in jl_results
GateModelTaskResult(**json.loads(jl_result)) for jl_result in jl_results
]
jl_results = None
for p_ix, program in enumerate(programs):
results[p_ix].additionalMetadata.action = program

Expand Down
72 changes: 0 additions & 72 deletions src/braket/simulator_v2/julia_import.py

This file was deleted.

Loading

0 comments on commit a6e9e80

Please sign in to comment.