diff --git a/lib/Optimizer/Transforms/DecompositionPatterns.cpp b/lib/Optimizer/Transforms/DecompositionPatterns.cpp index 3849329843..4275473153 100644 --- a/lib/Optimizer/Transforms/DecompositionPatterns.cpp +++ b/lib/Optimizer/Transforms/DecompositionPatterns.cpp @@ -11,6 +11,7 @@ #include "cudaq/Optimizer/Dialect/CC/CCOps.h" #include "cudaq/Optimizer/Dialect/Quake/QuakeDialect.h" #include "cudaq/Optimizer/Dialect/Quake/QuakeOps.h" +#include "mlir/IR/Dominance.h" #include "mlir/IR/PatternMatch.h" #include "mlir/Rewrite/FrozenRewritePatternSet.h" @@ -40,10 +41,12 @@ inline Value createDivF(Location loc, Value numerator, double denominator, return rewriter.create(loc, numerator, denominatorValue); } -/// @brief Returns true if \p op contains any `ControlType` operands. -inline bool containsControlTypes(quake::OperatorInterface op) { - return llvm::any_of(op.getControls(), [](const Value &v) { - return v.getType().isa(); +/// @brief Returns true if \p op contains any `ControlType` operands that are +/// used outside of the block where \p op resides. +inline bool containsControlsUsedOutsideBlock(quake::OperatorInterface op) { + return llvm::any_of(op.getControls(), [op](Value v) { + return v.getType().isa() && + v.isUsedOutsideOfBlock(op->getBlock()); }); } @@ -82,6 +85,37 @@ class QuakeOperatorCreator { quake::WireType::get(rewriter.getContext())); } + /// @brief Promote all `quake.control` types in \p controls to `quake.wire` + /// types. + void promoteControls(Location loc, MutableArrayRef controls) { + origControls.assign(controls.begin(), controls.end()); + for (auto &c : controls) + if (c.getType().isa()) + c = rewriter.create( + loc, quake::WireType::get(rewriter.getContext()), c); + } + + /// @brief Perform necessary conversion of `quake.wire` values to + /// `quake.control` values (complementing `promoteControls`). This also + /// replaces downstream uses of the original control with the new control in + /// case the decomposition modified the control. + void demoteWiresToControlsAndReplaceUses(Operation *op, + MutableArrayRef controls) { + auto ctrlTy = quake::ControlType::get(rewriter.getContext()); + auto loc = op->getLoc(); + DominanceInfo domInfo(op->getParentOfType()); + for (auto &&[oc, c] : llvm::zip_equal(origControls, controls)) + if (oc.getType().isa()) { + c = rewriter.create(loc, ctrlTy, c); + // This is like oc.replaceAllUsesWith(c) except it checks for + // proper dominance before doing the replacement. This is important + // because the rewriter tends to work from the bottom up. + for (auto &use : llvm::make_early_inc_range(oc.getUses())) + if (domInfo.properlyDominates(c, use.getOwner())) + use.set(c); + } + } + /// Pluck out the values from \p newValues whose type is `WireType` and /// replace all the \p op uses with those values. void selectWiresAndReplaceUses(Operation *op, ValueRange newValues) { @@ -224,6 +258,9 @@ class QuakeOperatorCreator { private: PatternRewriter &rewriter; + /// The original control values before some may have been promoted from + /// quake.control to quake.wire. + SmallVector origControls; }; /// Check whether the operation has the correct number of controls. @@ -661,11 +698,6 @@ struct CXToCZ : public OpRewritePattern { PatternRewriter &rewriter) const override { if (failed(checkNumControls(op, 1))) return failure(); - // This decomposition does not support `quake.control` types because the - // input controls are used as targets during this transformation. - if (containsControlTypes(op)) - return failure(); - // Op info Location loc = op->getLoc(); Value target = op.getTarget(); @@ -675,7 +707,15 @@ struct CXToCZ : public OpRewritePattern { if (negatedControls) negControl = (*negatedControls)[0]; + // TODO - Update this pattern to support threading modified controls + // throughout multiple blocks. qRewriter.demoteWiresToControlsAndReplaceUses + // does not currently handle that. + if (negControl && containsControlsUsedOutsideBlock(op)) + return failure(); + QuakeOperatorCreator qRewriter(rewriter); + if (negControl) + qRewriter.promoteControls(loc, controls); qRewriter.create(loc, target); if (negControl) qRewriter.create(loc, controls); @@ -684,6 +724,8 @@ struct CXToCZ : public OpRewritePattern { qRewriter.create(loc, controls); qRewriter.create(loc, target); + if (negControl) + qRewriter.demoteWiresToControlsAndReplaceUses(op, controls); qRewriter.selectWiresAndReplaceUses(op, controls, target); rewriter.eraseOp(op); return success(); @@ -813,11 +855,6 @@ struct CCZToCX : public OpRewritePattern { LogicalResult matchAndRewrite(quake::ZOp op, PatternRewriter &rewriter) const override { - // This decomposition does not support `quake.control` types because the - // input controls are used as targets during this transformation. - if (containsControlTypes(op)) - return failure(); - SmallVector controls(2); if (failed(checkAndExtractControls(op, controls, rewriter))) return failure(); @@ -843,7 +880,14 @@ struct CCZToCX : public OpRewritePattern { } } + // TODO - Update this pattern to support threading modified controls + // throughout multiple blocks. qRewriter.demoteWiresToControlsAndReplaceUses + // does not currently handle that. + if (containsControlsUsedOutsideBlock(op)) + return failure(); + QuakeOperatorCreator qRewriter(rewriter); + qRewriter.promoteControls(loc, controls); qRewriter.create(loc, controls[1], target); qRewriter.create(loc, /*isAdj=*/!negC0, target); qRewriter.create(loc, controls[0], target); @@ -860,6 +904,7 @@ struct CCZToCX : public OpRewritePattern { qRewriter.create(loc, /*isAdj=*/negC1, controls[0]); + qRewriter.demoteWiresToControlsAndReplaceUses(op, controls); qRewriter.selectWiresAndReplaceUses(op, controls, target); rewriter.eraseOp(op); return success(); @@ -878,10 +923,6 @@ struct CZToCX : public OpRewritePattern { LogicalResult matchAndRewrite(quake::ZOp op, PatternRewriter &rewriter) const override { - // This decomposition does not support `quake.control` types because the - // input controls are used as targets during this transformation. - if (containsControlTypes(op)) - return failure(); if (failed(checkNumControls(op, 1))) return failure(); @@ -894,7 +935,15 @@ struct CZToCX : public OpRewritePattern { if (negatedControls) negControl = (*negatedControls)[0]; + // TODO - Update this pattern to support threading modified controls + // throughout multiple blocks. qRewriter.demoteWiresToControlsAndReplaceUses + // does not currently handle that. + if (negControl && containsControlsUsedOutsideBlock(op)) + return failure(); + QuakeOperatorCreator qRewriter(rewriter); + if (negControl) + qRewriter.promoteControls(loc, controls); qRewriter.create(loc, target); if (negControl) qRewriter.create(loc, controls); @@ -903,6 +952,8 @@ struct CZToCX : public OpRewritePattern { qRewriter.create(loc, controls); qRewriter.create(loc, target); + if (negControl) + qRewriter.demoteWiresToControlsAndReplaceUses(op, controls); qRewriter.selectWiresAndReplaceUses(op, controls, target); rewriter.eraseOp(op); return success(); @@ -969,7 +1020,10 @@ struct CR1ToCX : public OpRewritePattern { LogicalResult matchAndRewrite(quake::R1Op op, PatternRewriter &rewriter) const override { - if (containsControlTypes(op)) + // TODO - Update this pattern to support threading modified controls + // throughout multiple blocks. qRewriter.demoteWiresToControlsAndReplaceUses + // does not currently handle that. + if (containsControlsUsedOutsideBlock(op)) return failure(); Value control; @@ -994,6 +1048,7 @@ struct CR1ToCX : public OpRewritePattern { Value negHalfAngle = rewriter.create(loc, halfAngle); QuakeOperatorCreator qRewriter(rewriter); + qRewriter.promoteControls(loc, control); qRewriter.create(loc, /*isAdj*/ negControl, halfAngle, noControls, control); qRewriter.create(loc, control, target); @@ -1002,6 +1057,7 @@ struct CR1ToCX : public OpRewritePattern { qRewriter.create(loc, control, target); qRewriter.create(loc, halfAngle, noControls, target); + qRewriter.demoteWiresToControlsAndReplaceUses(op, control); qRewriter.selectWiresAndReplaceUses(op, ValueRange{control, target}); rewriter.eraseOp(op); return success(); diff --git a/test/Transforms/DecompositionPatterns/CCZToCX.qke b/test/Transforms/DecompositionPatterns/CCZToCX.qke index c560079ba9..6c9ea5bbce 100644 --- a/test/Transforms/DecompositionPatterns/CCZToCX.qke +++ b/test/Transforms/DecompositionPatterns/CCZToCX.qke @@ -9,6 +9,7 @@ // RUN: cudaq-opt -pass-pipeline='builtin.module(decomposition{enable-patterns=CCZToCX})' %s | FileCheck %s // RUN: cudaq-opt -pass-pipeline='builtin.module(decomposition{enable-patterns=CCZToCX})' %s | CircuitCheck %s // RUN: cudaq-opt -pass-pipeline='builtin.module(func.func(expand-control-veqs,memtoreg),decomposition{enable-patterns=CCZToCX})' %s | FileCheck %s +// RUN: cudaq-opt -pass-pipeline='builtin.module(func.func(expand-control-veqs,memtoreg,pruned-ctrl-form),decomposition{enable-patterns=CCZToCX})' %s | FileCheck %s // Test the decomposition pattern with different control types. The FileCheck // part of this test only cares about the sequence of operations. Correcteness diff --git a/test/Transforms/DecompositionPatterns/CR1ToCX.qke b/test/Transforms/DecompositionPatterns/CR1ToCX.qke index c2cff33ee6..58d391c42a 100644 --- a/test/Transforms/DecompositionPatterns/CR1ToCX.qke +++ b/test/Transforms/DecompositionPatterns/CR1ToCX.qke @@ -10,6 +10,7 @@ // RUN: cudaq-opt -pass-pipeline='builtin.module(decomposition{enable-patterns=CR1ToCX})' %s | CircuitCheck %s // RUN: cudaq-opt -pass-pipeline='builtin.module(func.func(expand-control-veqs,memtoreg),decomposition{enable-patterns=CR1ToCX})' %s | FileCheck %s // RUN: cudaq-opt -pass-pipeline='builtin.module(func.func(expand-control-veqs,memtoreg),decomposition{enable-patterns=CR1ToCX})' %s | CircuitCheck %s +// RUN: cudaq-opt -pass-pipeline='builtin.module(func.func(expand-control-veqs,memtoreg,pruned-ctrl-form),decomposition{enable-patterns=CR1ToCX})' %s | FileCheck %s // Test the decomposition pattern with different control types. The FileCheck // part of this test only cares about the sequence of operations. Correcteness diff --git a/test/Transforms/DecompositionPatterns/CXToCZ.qke b/test/Transforms/DecompositionPatterns/CXToCZ.qke index 20b1d2c683..90e96961a4 100644 --- a/test/Transforms/DecompositionPatterns/CXToCZ.qke +++ b/test/Transforms/DecompositionPatterns/CXToCZ.qke @@ -10,6 +10,7 @@ // RUN: cudaq-opt -pass-pipeline='builtin.module(decomposition{enable-patterns=CXToCZ})' %s | CircuitCheck %s // RUN: cudaq-opt -pass-pipeline='builtin.module(func.func(expand-control-veqs,memtoreg),decomposition{enable-patterns=CXToCZ})' %s | FileCheck %s // RUN: cudaq-opt -pass-pipeline='builtin.module(func.func(expand-control-veqs,memtoreg),decomposition{enable-patterns=CXToCZ})' %s | CircuitCheck %s +// RUN: cudaq-opt -pass-pipeline='builtin.module(func.func(expand-control-veqs,memtoreg,pruned-ctrl-form),decomposition{enable-patterns=CXToCZ})' %s | FileCheck %s // Test the decomposition pattern with different control types. The FileCheck // part of this test only cares about the sequence of operations. Correcteness diff --git a/test/Transforms/DecompositionPatterns/CZToCX.qke b/test/Transforms/DecompositionPatterns/CZToCX.qke index 9846c1f798..e0b11259d3 100644 --- a/test/Transforms/DecompositionPatterns/CZToCX.qke +++ b/test/Transforms/DecompositionPatterns/CZToCX.qke @@ -10,6 +10,7 @@ // RUN: cudaq-opt -pass-pipeline='builtin.module(decomposition{enable-patterns=CZToCX})' %s | CircuitCheck %s // RUN: cudaq-opt -pass-pipeline='builtin.module(func.func(expand-control-veqs,memtoreg),decomposition{enable-patterns=CZToCX})' %s | FileCheck %s // RUN: cudaq-opt -pass-pipeline='builtin.module(func.func(expand-control-veqs,memtoreg),decomposition{enable-patterns=CZToCX})' %s | CircuitCheck %s +// RUN: cudaq-opt -pass-pipeline='builtin.module(func.func(expand-control-veqs,memtoreg,pruned-ctrl-form),decomposition{enable-patterns=CZToCX})' %s | FileCheck %s // Test the decomposition pattern with different control types. The FileCheck // part of this test only cares about the sequence of operations. Correcteness