Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Support qml.sample() without specifying the observable #266

Open
wants to merge 5 commits into
base: main
Choose a base branch
from
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
22 changes: 18 additions & 4 deletions src/braket/pennylane_plugin/translation.py
Original file line number Diff line number Diff line change
Expand Up @@ -530,7 +530,7 @@ def get_adjoint_gradient_result_type(
return AdjointGradient(observable=braket_observable, target=targets, parameters=parameters)


def translate_result_type(
def translate_result_type( # noqa: C901
measurement: MeasurementProcess, targets: list[int], supported_result_types: frozenset[str]
) -> Union[ResultType, tuple[ResultType, ...]]:
"""Translates a PennyLane ``MeasurementProcess`` into the corresponding Braket ``ResultType``.
Expand All @@ -547,6 +547,7 @@ def translate_result_type(
then this will return a result type for each term.
"""
return_type = measurement.return_type
observable = measurement.obs

if return_type is ObservableReturnTypes.Probability:
return Probability(targets)
Expand All @@ -558,14 +559,19 @@ def translate_result_type(
return DensityMatrix(targets)
raise NotImplementedError(f"Unsupported return type: {return_type}")

if isinstance(measurement.obs, (Hamiltonian, qml.Hamiltonian)):
if isinstance(observable, (Hamiltonian, qml.Hamiltonian)):
if return_type is ObservableReturnTypes.Expectation:
return tuple(
Expectation(_translate_observable(term), term.wires) for term in measurement.obs.ops
Expectation(_translate_observable(term), term.wires) for term in observable.ops
)
raise NotImplementedError(f"Return type {return_type} unsupported for Hamiltonian")

braket_observable = _translate_observable(measurement.obs)
if observable is None:
if return_type is ObservableReturnTypes.Sample:
return tuple(Sample(observables.Z(), target) for target in targets or measurement.wires)
raise NotImplementedError(f"Unsupported return type: {return_type}")

braket_observable = _translate_observable(observable)
if return_type is ObservableReturnTypes.Expectation:
return Expectation(braket_observable, targets)
elif return_type is ObservableReturnTypes.Variance:
Expand Down Expand Up @@ -698,6 +704,14 @@ def translate_result(
ag_result.value["gradient"][f"p_{i}"]
for i in sorted(key_indices)
]

if observable is None:
if measurement.return_type is ObservableReturnTypes.Sample:
if targets:
return [m[targets] for m in braket_result.measurements]
return braket_result.measurements
raise NotImplementedError(f"Unsupported measurement type: {type(measurement)}")

translated = translate_result_type(measurement, targets, supported_result_types)
if isinstance(observable, (Hamiltonian, qml.Hamiltonian)):
coeffs, _ = observable.terms()
Expand Down
80 changes: 80 additions & 0 deletions test/integ_tests/test_sample.py
Original file line number Diff line number Diff line change
Expand Up @@ -24,6 +24,86 @@
class TestSample:
"""Tests for the sample return type"""

def test_sample_default(self, device, shots, tol):
"""Tests if the samples returned by sample have
the correct values when no observable is specified
"""
dev = device(2)

@qml.qnode(dev)
def circuit():
qml.RX(np.pi / 4, wires=0)
qml.CNOT(wires=[0, 1])
return qml.sample()

shot_vector = circuit()

# The sample should only contain 1 and 0
assert shot_vector.dtype == np.dtype("int")
assert np.all((shot_vector == 0) | (shot_vector == 1))
assert shot_vector.shape == (shots, 2)

def test_sample_wires(self, device, shots, tol):
"""Tests if the samples returned by sample have
the correct values when only wires are specified
"""
dev = device(3)

@qml.qnode(dev)
def circuit():
qml.RX(np.pi / 4, wires=0)
qml.CNOT(wires=[0, 1])
qml.CNOT(wires=[1, 2])
return qml.sample(wires=[0, 2])

shot_vector = circuit()

# The sample should only contain 1 and 0
assert shot_vector.dtype == np.dtype("int")
assert np.all((shot_vector == 0) | (shot_vector == 1))
assert shot_vector.shape == (shots, 2)

def test_sample_batch_default(self, device, shots, tol):
"""Tests if the samples returned by sample have
the correct values when no observable is specified and
the batch dimension is returned
"""
dev = device(3)

@qml.qnode(dev)
def circuit(a):
qml.RX(a, wires=0)
qml.CNOT(wires=[0, 1])
qml.CNOT(wires=[1, 2])
return qml.sample()

shot_vector = circuit([np.pi / 4, np.pi / 3])

# The sample should only contain 1 and 0
assert shot_vector.dtype == np.dtype("int")
assert np.all((shot_vector == 0) | (shot_vector == 1))
assert shot_vector.shape == (2, shots, 3)

def test_sample_batch_wires(self, device, shots, tol):
"""Tests if the samples returned by sample have
the correct values when only wires are specified
"""
dev = device(4)

@qml.qnode(dev)
def circuit(a):
qml.RX(a, wires=0)
qml.CNOT(wires=[0, 1])
qml.CNOT(wires=[1, 2])
return qml.sample(wires=[0, 2, 3])

shot_vector = circuit([np.pi / 4, np.pi / 3])

# The sample should only contain 1 and 0
assert shot_vector.dtype == np.dtype("int")
assert np.all((shot_vector == 0) | (shot_vector == 1))
assert shot_vector.shape == (2, shots, 3)

def test_sample_values(self, device, shots, tol):
"""Tests if the samples returned by sample have
the correct values
Expand Down
Loading