From 04a88a27f42cd7c1a69ac7d9f8c8e3abc0e755d3 Mon Sep 17 00:00:00 2001 From: Coull Date: Fri, 11 Oct 2024 09:46:54 -0700 Subject: [PATCH] fix: correct typing for task results methods --- src/braket/aws/aws_quantum_task_batch.py | 20 ++++++++++++++++---- 1 file changed, 16 insertions(+), 4 deletions(-) diff --git a/src/braket/aws/aws_quantum_task_batch.py b/src/braket/aws/aws_quantum_task_batch.py index 300963a6f..f4f1d328a 100644 --- a/src/braket/aws/aws_quantum_task_batch.py +++ b/src/braket/aws/aws_quantum_task_batch.py @@ -16,7 +16,7 @@ import time from concurrent.futures.thread import ThreadPoolExecutor from itertools import repeat -from typing import Any, Union +from typing import TYPE_CHECKING, Any, Union from braket.ahs.analog_hamiltonian_simulation import AnalogHamiltonianSimulation from braket.annealing import Problem @@ -30,6 +30,11 @@ from braket.registers.qubit_set import QubitSet from braket.tasks.quantum_task_batch import QuantumTaskBatch +if TYPE_CHECKING: + from braket.tasks.annealing_quantum_task_result import AnnealingQuantumTaskResult + from braket.tasks.gate_model_quantum_task_result import GateModelQuantumTaskResult + from braket.tasks.photonic_model_quantum_task_result import PhotonicModelQuantumTaskResult + class AwsQuantumTaskBatch(QuantumTaskBatch): """Executes a batch of quantum tasks in parallel. @@ -331,7 +336,9 @@ def results( fail_unsuccessful: bool = False, max_retries: int = MAX_RETRIES, use_cached_value: bool = True, - ) -> list[AwsQuantumTask]: + ) -> list[ + GateModelQuantumTaskResult | AnnealingQuantumTaskResult | PhotonicModelQuantumTaskResult + ]: """Retrieves the result of every quantum task in the batch. Polling for results happens in parallel; this method returns when all quantum tasks @@ -348,7 +355,8 @@ def results( even when results have already been cached. Default: `True`. Returns: - list[AwsQuantumTask]: The results of all of the quantum tasks in the batch. + list[GateModelQuantumTaskResult | AnnealingQuantumTaskResult | PhotonicModelQuantumTaskResult]: The # noqa: E501 + results of all of the quantum tasks in the batch. `FAILED`, `CANCELLED`, or timed out quantum tasks will have a result of None """ if not self._results or not use_cached_value: @@ -369,7 +377,11 @@ def results( return self._results @staticmethod - def _retrieve_results(tasks: list[AwsQuantumTask], max_workers: int) -> list[AwsQuantumTask]: + def _retrieve_results( + tasks: list[AwsQuantumTask], max_workers: int + ) -> list[ + GateModelQuantumTaskResult | AnnealingQuantumTaskResult | PhotonicModelQuantumTaskResult + ]: with ThreadPoolExecutor(max_workers=max_workers) as executor: result_futures = [executor.submit(task.result) for task in tasks] return [future.result() for future in result_futures]