diff --git a/crates/accelerate/src/commutation_checker.rs b/crates/accelerate/src/commutation_checker.rs index fe242c73422..52d4900efa5 100644 --- a/crates/accelerate/src/commutation_checker.rs +++ b/crates/accelerate/src/commutation_checker.rs @@ -28,14 +28,21 @@ use qiskit_circuit::circuit_instruction::{ExtraInstructionAttributes, OperationF use qiskit_circuit::dag_node::DAGOpNode; use qiskit_circuit::imports::QI_OPERATOR; use qiskit_circuit::operations::OperationRef::{Gate as PyGateType, Operation as PyOperationType}; -use qiskit_circuit::operations::{Operation, OperationRef, Param, StandardGate}; +use qiskit_circuit::operations::{ + get_standard_gate_names, Operation, OperationRef, Param, StandardGate, +}; use qiskit_circuit::{BitType, Clbit, Qubit}; use crate::unitary_compose; use crate::QiskitError; +const TWOPI: f64 = 2.0 * std::f64::consts::PI; + +// These gates do not commute with other gates, we do not check them. static SKIPPED_NAMES: [&str; 4] = ["measure", "reset", "delay", "initialize"]; -static NO_CACHE_NAMES: [&str; 2] = ["annotated", "linear_function"]; + +// We keep a hash-set of operations eligible for commutation checking. This is because checking +// eligibility is not for free. static SUPPORTED_OP: Lazy> = Lazy::new(|| { HashSet::from([ "rxx", "ryy", "rzz", "rzx", "h", "x", "y", "z", "sx", "sxdg", "t", "tdg", "s", "sdg", "cx", @@ -43,9 +50,7 @@ static SUPPORTED_OP: Lazy> = Lazy::new(|| { ]) }); -const TWOPI: f64 = 2.0 * std::f64::consts::PI; - -// map rotation gates to their generators, or to ``None`` if we cannot currently efficiently +// Map rotation gates to their generators, or to ``None`` if we cannot currently efficiently // represent the generator in Rust and store the commutation relation in the commutation dictionary static SUPPORTED_ROTATIONS: Lazy>> = Lazy::new(|| { HashMap::from([ @@ -322,15 +327,17 @@ impl CommutationChecker { (qargs1, qargs2) }; - let skip_cache: bool = NO_CACHE_NAMES.contains(&first_op.name()) || - NO_CACHE_NAMES.contains(&second_op.name()) || - // Skip params that do not evaluate to floats for caching and commutation library - first_params.iter().any(|p| !matches!(p, Param::Float(_))) || - second_params.iter().any(|p| !matches!(p, Param::Float(_))) - && !SUPPORTED_OP.contains(op1.name()) - && !SUPPORTED_OP.contains(op2.name()); - - if skip_cache { + // For our cache to work correctly, we require the gate's definition to only depend on the + // ``params`` attribute. This cannot be guaranteed for custom gates, so we only check + // the cache for our standard gates, which we know are defined by the ``params`` AND + // that the ``params`` are float-only at this point. + let whitelist = get_standard_gate_names(); + let check_cache = whitelist.contains(&first_op.name()) + && whitelist.contains(&second_op.name()) + && first_params.iter().all(|p| matches!(p, Param::Float(_))) + && second_params.iter().all(|p| matches!(p, Param::Float(_))); + + if !check_cache { return self.commute_matmul( py, first_op, @@ -630,21 +637,24 @@ fn map_rotation<'a>( ) -> (&'a OperationRef<'a>, &'a [Param], bool) { let name = op.name(); if let Some(generator) = SUPPORTED_ROTATIONS.get(name) { - // if the rotation angle is below the tolerance, the gate is assumed to + // If the rotation angle is below the tolerance, the gate is assumed to // commute with everything, and we simply return the operation with the flag that - // it commutes trivially + // it commutes trivially. if let Param::Float(angle) = params[0] { if (angle % TWOPI).abs() < tol { return (op, params, true); }; }; - // otherwise, we check if a generator is given -- if not, we'll just return the operation - // itself (e.g. RXX does not have a generator and is just stored in the commutations - // dictionary) + // Otherwise we need to cover two cases -- either a generator is given, in which case + // we return it, or we don't have a generator yet, but we know we have the operation + // stored in the commutation library. For example, RXX does not have a generator in Rust + // yet (PauliGate is not in Rust currently), but it is stored in the library, so we + // can strip the parameters and just return the gate. if let Some(gate) = generator { return (gate, &[], false); }; + return (op, &[], false); } (op, params, false) } diff --git a/crates/circuit/src/operations.rs b/crates/circuit/src/operations.rs index 59adfd9e0e8..444d178c686 100644 --- a/crates/circuit/src/operations.rs +++ b/crates/circuit/src/operations.rs @@ -431,6 +431,11 @@ static STANDARD_GATE_NAME: [&str; STANDARD_GATE_SIZE] = [ "rcccx", // 51 ("rc3x") ]; +/// Get a slice of all standard gate names. +pub fn get_standard_gate_names() -> &'static [&'static str] { + &STANDARD_GATE_NAME +} + impl StandardGate { pub fn create_py_op( &self, diff --git a/test/benchmarks/releasenotes/notes/conservative-commutation-checking-e1ca8501416119a1.yaml b/test/benchmarks/releasenotes/notes/conservative-commutation-checking-e1ca8501416119a1.yaml new file mode 100644 index 00000000000..58a47c1f859 --- /dev/null +++ b/test/benchmarks/releasenotes/notes/conservative-commutation-checking-e1ca8501416119a1.yaml @@ -0,0 +1,17 @@ +--- +fixes: + - | + Commutation relations of :class:`~.circuit.Instruction`\ s with float-only ``params`` + were eagerly cached by the :class:`.CommutationChecker`, using the ``params`` as key to + query the relation. This could lead to faulty results, if the instruction's definition + depended on additional information that just the :attr:`~.circuit.Instruction.params` + attribute, such as e.g. the case for :class:`.PauliEvolutionGate`. + This behavior is now fixed, and the commutation checker only conservatively caches + commutations for Qiskit-native standard gates. This can incur a performance cost if you were + relying on your custom gates being cached, however, we cannot guarantee safe caching for + custom gates, as they might rely on information beyond :attr:`~.circuit.Instruction.params`. + - | + Fixed a bug in the :class:`.CommmutationChecker`, where checking commutation of instruction + with non-numeric values in the :attr:`~.circuit.Instruction.params` attribute (such as the + :class:`.PauliGate`) could raise an error. + Fixed `#13570 `__. diff --git a/test/python/circuit/test_commutation_checker.py b/test/python/circuit/test_commutation_checker.py index 9759b5bffd1..62f7054bef8 100644 --- a/test/python/circuit/test_commutation_checker.py +++ b/test/python/circuit/test_commutation_checker.py @@ -27,6 +27,7 @@ Parameter, QuantumRegister, Qubit, + QuantumCircuit, ) from qiskit.circuit.commutation_library import SessionCommutationChecker as scc from qiskit.circuit.library import ( @@ -37,9 +38,11 @@ CRYGate, CRZGate, CXGate, + CUGate, LinearFunction, MCXGate, Measure, + PauliGate, PhaseGate, Reset, RXGate, @@ -82,6 +85,22 @@ def to_matrix(self): return np.array([[1, 0, 0, 0], [0, 0, 0, 1], [0, 0, 1, 0], [0, 1, 0, 0]], dtype=complex) +class MyEvilRXGate(Gate): + """A RX gate designed to annoy the caching mechanism (but a realistic gate nevertheless).""" + + def __init__(self, evil_input_not_in_param: float): + """ + Args: + evil_input_not_in_param: The RX rotation angle. + """ + self.value = evil_input_not_in_param + super().__init__("", 1, []) + + def _define(self): + self.definition = QuantumCircuit(1) + self.definition.rx(self.value, 0) + + @ddt class TestCommutationChecker(QiskitTestCase): """Test CommutationChecker class.""" @@ -137,7 +156,7 @@ def test_standard_gates_commutations(self): def test_caching_positive_results(self): """Check that hashing positive results in commutativity checker works as expected.""" scc.clear_cached_commutations() - self.assertTrue(scc.commute(ZGate(), [0], [], NewGateCX(), [0, 1], [])) + self.assertTrue(scc.commute(ZGate(), [0], [], CUGate(1, 2, 3, 0), [0, 1], [])) self.assertGreater(scc.num_cached_entries(), 0) def test_caching_lookup_with_non_overlapping_qubits(self): @@ -150,16 +169,17 @@ def test_caching_lookup_with_non_overlapping_qubits(self): def test_caching_store_and_lookup_with_non_overlapping_qubits(self): """Check that commutations storing and lookup with non-overlapping qubits works as expected.""" scc_lenm = scc.num_cached_entries() - self.assertTrue(scc.commute(NewGateCX(), [0, 2], [], CXGate(), [0, 1], [])) - self.assertFalse(scc.commute(NewGateCX(), [0, 1], [], CXGate(), [1, 2], [])) - self.assertTrue(scc.commute(NewGateCX(), [1, 4], [], CXGate(), [1, 6], [])) - self.assertFalse(scc.commute(NewGateCX(), [5, 3], [], CXGate(), [3, 1], [])) + cx_like = CUGate(np.pi, 0, np.pi, 0) + self.assertTrue(scc.commute(cx_like, [0, 2], [], CXGate(), [0, 1], [])) + self.assertFalse(scc.commute(cx_like, [0, 1], [], CXGate(), [1, 2], [])) + self.assertTrue(scc.commute(cx_like, [1, 4], [], CXGate(), [1, 6], [])) + self.assertFalse(scc.commute(cx_like, [5, 3], [], CXGate(), [3, 1], [])) self.assertEqual(scc.num_cached_entries(), scc_lenm + 2) def test_caching_negative_results(self): """Check that hashing negative results in commutativity checker works as expected.""" scc.clear_cached_commutations() - self.assertFalse(scc.commute(XGate(), [0], [], NewGateCX(), [0, 1], [])) + self.assertFalse(scc.commute(XGate(), [0], [], CUGate(1, 2, 3, 0), [0, 1], [])) self.assertGreater(scc.num_cached_entries(), 0) def test_caching_different_qubit_sets(self): @@ -167,10 +187,11 @@ def test_caching_different_qubit_sets(self): scc.clear_cached_commutations() # All the following should be cached in the same way # though each relation gets cached twice: (A, B) and (B, A) - scc.commute(XGate(), [0], [], NewGateCX(), [0, 1], []) - scc.commute(XGate(), [10], [], NewGateCX(), [10, 20], []) - scc.commute(XGate(), [10], [], NewGateCX(), [10, 5], []) - scc.commute(XGate(), [5], [], NewGateCX(), [5, 7], []) + cx_like = CUGate(np.pi, 0, np.pi, 0) + scc.commute(XGate(), [0], [], cx_like, [0, 1], []) + scc.commute(XGate(), [10], [], cx_like, [10, 20], []) + scc.commute(XGate(), [10], [], cx_like, [10, 5], []) + scc.commute(XGate(), [5], [], cx_like, [5, 7], []) self.assertEqual(scc.num_cached_entries(), 1) def test_zero_rotations(self): @@ -377,12 +398,14 @@ def test_serialization(self): """Test that the commutation checker is correctly serialized""" import pickle + cx_like = CUGate(np.pi, 0, np.pi, 0) + scc.clear_cached_commutations() - self.assertTrue(scc.commute(ZGate(), [0], [], NewGateCX(), [0, 1], [])) + self.assertTrue(scc.commute(ZGate(), [0], [], cx_like, [0, 1], [])) cc2 = pickle.loads(pickle.dumps(scc)) self.assertEqual(cc2.num_cached_entries(), 1) dop1 = DAGOpNode(ZGate(), qargs=[0], cargs=[]) - dop2 = DAGOpNode(NewGateCX(), qargs=[0, 1], cargs=[]) + dop2 = DAGOpNode(cx_like, qargs=[0, 1], cargs=[]) cc2.commute_nodes(dop1, dop2) dop1 = DAGOpNode(ZGate(), qargs=[0], cargs=[]) dop2 = DAGOpNode(CXGate(), qargs=[0, 1], cargs=[]) @@ -430,6 +453,25 @@ def test_rotation_mod_2pi(self, gate_cls): scc.commute(generic_gate, [0], [], gate, list(range(gate.num_qubits)), []) ) + def test_custom_gate_caching(self): + """Test a custom gate is correctly handled on consecutive runs.""" + + all_commuter = MyEvilRXGate(0) # this will with anything + some_rx = MyEvilRXGate(1.6192) # this should not commute with H + + # the order here is important: we're testing whether the gate that commutes with + # everything is used after the first commutation check, regardless of the internal + # gate parameters + self.assertTrue(scc.commute(all_commuter, [0], [], HGate(), [0], [])) + self.assertFalse(scc.commute(some_rx, [0], [], HGate(), [0], [])) + + def test_nonfloat_param(self): + """Test commutation-checking on a gate that has non-float ``params``.""" + pauli_gate = PauliGate("XX") + rx_gate_theta = RXGate(Parameter("Theta")) + self.assertTrue(scc.commute(pauli_gate, [0, 1], [], rx_gate_theta, [0], [])) + self.assertTrue(scc.commute(rx_gate_theta, [0], [], pauli_gate, [0, 1], [])) + if __name__ == "__main__": unittest.main()