Skip to content

Commit

Permalink
[core] Add quantum reference product type (#2254)
Browse files Browse the repository at this point in the history
* Start on pure quantum struct usage in kernels

Signed-off-by: Alex McCaskey <[email protected]>

* Update the python bindings with new qstruct restrictions

Signed-off-by: Alex McCaskey <[email protected]>

* Enable default parenthesis constructor

Signed-off-by: Alex McCaskey <[email protected]>

* disallow recursive quantum struct

Signed-off-by: Alex McCaskey <[email protected]>

* Implement error handling for various cases in python

Signed-off-by: Alex McCaskey <[email protected]>

* spell fixes

Signed-off-by: Alex McCaskey <[email protected]>

* forgot to filter out __qpu__ methods on structs, those are allowed

Signed-off-by: Alex McCaskey <[email protected]>

* Add new quantum reference type, !quake.struq, and a couple of new
operations: quake.make_struq and quake.get_member.  These add the
utility of having a product quantum reference type (to logically group
together sets of qubits) but keep the classical and quantum dialects
distinct.

Update the tests, python ast bridge, C++ bridge, add codegen patterns,
etc.

* Whackamole games with the CI.

Add roundtrip test on new type and ops.

Update the python tests. Also change test to eliminate deprecation
warnings.

Add invlid IR checks for new operations.

Add sanity checks. We do not want to allow a quantum struct that holds
anything but non-owning references to qubits or qubit collections.

Remove unused folder pattern.

Workaround for overly assertive compiler warning.

Reenable the hash-and-cache of extract_ref ops in the C++ bridge. This
is a dubious optimization that we may actually want to take out at some
point, but that should be part of a distinct/different PR.

Update test to work around that pytest output can be shuffled.

Add case to python for quake.struq type.  Another python fix.

Add explicit checks to utils.py.

Stab in the dark.

---------

Signed-off-by: Alex McCaskey <[email protected]>
Co-authored-by: Alex McCaskey <[email protected]>
  • Loading branch information
schweitzpgi and amccaskey authored Oct 8, 2024
1 parent f3907a3 commit e21a32b
Show file tree
Hide file tree
Showing 29 changed files with 1,242 additions and 138 deletions.
1 change: 1 addition & 0 deletions include/cudaq/Frontend/nvqpp/ASTBridge.h
Original file line number Diff line number Diff line change
Expand Up @@ -286,6 +286,7 @@ class QuakeBridgeVisitor
DataRecursionQueue *q = nullptr);
bool VisitCXXConstructExpr(clang::CXXConstructExpr *x);
bool VisitCXXOperatorCallExpr(clang::CXXOperatorCallExpr *x);
bool VisitCXXParenListInitExpr(clang::CXXParenListInitExpr *x);
bool WalkUpFromCXXOperatorCallExpr(clang::CXXOperatorCallExpr *x);
bool TraverseDeclRefExpr(clang::DeclRefExpr *x,
DataRecursionQueue *q = nullptr);
Expand Down
48 changes: 48 additions & 0 deletions include/cudaq/Optimizer/Dialect/Quake/QuakeOps.td
Original file line number Diff line number Diff line change
Expand Up @@ -641,6 +641,54 @@ def quake_ReturnWireOp : QuakeOp<"return_wire"> {
let assemblyFormat = "$target `:` type(operands) attr-dict";
}

//===----------------------------------------------------------------------===//
// Struq handling
//===----------------------------------------------------------------------===//

def quake_MakeStruqOp : QuakeOp<"make_struq", [Pure]> {
let summary = "create a quantum struct from a set of quantum references";
let description = [{
Given a list of values of quantum reference type, creates a new quantum
product reference type. This is a logical grouping and does not imply any
new quantum references are created.

This operation can be useful for grouping a number of values of type `veq`
into a logical product type, which may be passed to a pure device kernel
as a single unit, for example. These product types may always be erased into
a vector of the quantum references used to compose them via a make_struq op.
}];

let arguments = (ins Variadic<NonStruqRefType>:$veqs);
let results = (outs StruqType);
let hasVerifier = 1;

let assemblyFormat = [{
$veqs `:` functional-type(operands, results) attr-dict
}];
}

def quake_GetMemberOp : QuakeOp<"get_member", [Pure]> {
let summary = "extract quantum references from a quantum struct";
let description = [{
The get_member operation can be used to extract a set of quantum references
from a quantum struct (product) type. The fields in the quantum struct are
indexed from 0 to $n-1$ where $n$ is the number of fields. An index outside
of this range will produce a verification error.
}];

let arguments = (ins
StruqType:$struq,
I32Attr:$index
);
let results = (outs NonStruqRefType);
let hasCanonicalizer = 1;
let hasVerifier = 1;

let assemblyFormat = [{
$struq `[` $index `]` `:` functional-type(operands, results) attr-dict
}];
}

//===----------------------------------------------------------------------===//
// ToControl, FromControl pair
//===----------------------------------------------------------------------===//
Expand Down
10 changes: 8 additions & 2 deletions include/cudaq/Optimizer/Dialect/Quake/QuakeTypes.h
Original file line number Diff line number Diff line change
Expand Up @@ -25,7 +25,7 @@ namespace quake {
inline bool isQuantumType(mlir::Type ty) {
// NB: this intentionally excludes MeasureType.
return mlir::isa<quake::RefType, quake::VeqType, quake::WireType,
quake::ControlType>(ty);
quake::ControlType, quake::StruqType>(ty);
}

/// \returns true if \p `ty` is a Quake type.
Expand All @@ -34,10 +34,16 @@ inline bool isQuakeType(mlir::Type ty) {
return isQuantumType(ty) || mlir::isa<quake::MeasureType>(ty);
}

inline bool isQuantumReferenceType(mlir::Type ty) {
/// \returns true if \p ty is a quantum reference type, excluding `struq`.
inline bool isNonStruqReferenceType(mlir::Type ty) {
return mlir::isa<quake::RefType, quake::VeqType>(ty);
}

/// \returns true if \p ty is a quantum reference type.
inline bool isQuantumReferenceType(mlir::Type ty) {
return isNonStruqReferenceType(ty) || mlir::isa<quake::StruqType>(ty);
}

/// A quake wire type is a linear type.
inline bool isLinearType(mlir::Type ty) {
return mlir::isa<quake::WireType>(ty);
Expand Down
44 changes: 42 additions & 2 deletions include/cudaq/Optimizer/Dialect/Quake/QuakeTypes.td
Original file line number Diff line number Diff line change
Expand Up @@ -161,6 +161,41 @@ def VeqType : QuakeType<"Veq", "veq"> {
}];
}

//===----------------------------------------------------------------------===//
// StruqType: quantum reference type; product of veq and ref types.
//===----------------------------------------------------------------------===//

def StruqType : QuakeType<"Struq", "struq"> {
let summary = "a product type of quantum references";
let description = [{
This type allows one to group veqs of quantum references together in a
single product type.

To support Python, a struq type can be assigned a name. This allows the
python bridge to perform dictionary based lookups on member field names.
}];

let parameters = (ins
"mlir::StringAttr":$name,
ArrayRefParameter<"mlir::Type">:$members
);
let hasCustomAssemblyFormat = 1;

let extraClassDeclaration = [{
std::size_t getNumMembers() const { return getMembers().size(); }
}];

let builders = [
TypeBuilder<(ins CArg<"llvm::ArrayRef<mlir::Type>">:$members), [{
return $_get($_ctxt, mlir::StringAttr{}, members);
}]>,
TypeBuilder<(ins CArg<"llvm::StringRef">:$name,
CArg<"llvm::ArrayRef<mlir::Type>">:$members), [{
return $_get($_ctxt, mlir::StringAttr::get($_ctxt, name), members);
}]>
];
}

//===----------------------------------------------------------------------===//
// MeasureType: classical data type
//===----------------------------------------------------------------------===//
Expand All @@ -183,14 +218,19 @@ def MeasureType : QuakeType<"Measure", "measure"> {
}

def AnyQTypeLike : TypeConstraint<Or<[WireType.predicate, VeqType.predicate,
ControlType.predicate, RefType.predicate]>, "quake quantum types">;
ControlType.predicate, RefType.predicate, StruqType.predicate]>,
"quake quantum types">;
def AnyQType : Type<AnyQTypeLike.predicate, "quantum type">;
def AnyQTargetTypeLike : TypeConstraint<Or<[WireType.predicate,
VeqType.predicate, RefType.predicate]>, "quake quantum target types">;
def AnyQTargetType : Type<AnyQTargetTypeLike.predicate, "quantum target type">;
def AnyRefTypeLike : TypeConstraint<Or<[VeqType.predicate,
def AnyRefTypeLike : TypeConstraint<Or<[VeqType.predicate, StruqType.predicate,
RefType.predicate]>, "quake quantum reference types">;
def AnyRefType : Type<AnyRefTypeLike.predicate, "quantum reference type">;
def NonStruqRefTypeLike : TypeConstraint<Or<[VeqType.predicate,
RefType.predicate]>, "non-struct quake quantum reference types">;
def NonStruqRefType : Type<NonStruqRefTypeLike.predicate,
"non-struct quantum reference type">;
def AnyQValueTypeLike : TypeConstraint<Or<[WireType.predicate,
ControlType.predicate]>, "quake quantum value types">;
def AnyQValueType : Type<AnyQValueTypeLike.predicate, "quantum value type">;
Expand Down
22 changes: 17 additions & 5 deletions lib/Frontend/nvqpp/ConvertDecl.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -93,8 +93,8 @@ void QuakeBridgeVisitor::addArgumentSymbols(
auto parmTy = entryBlock->getArgument(index).getType();
if (isa<FunctionType, cc::CallableType, cc::IndirectCallableType,
cc::PointerType, cc::SpanLikeType, LLVM::LLVMStructType,
quake::ControlType, quake::RefType, quake::VeqType,
quake::WireType>(parmTy)) {
quake::ControlType, quake::RefType, quake::StruqType,
quake::VeqType, quake::WireType>(parmTy)) {
symbolTable.insert(name, entryBlock->getArgument(index));
} else {
auto stackSlot = builder.create<cc::AllocaOp>(loc, parmTy);
Expand Down Expand Up @@ -658,9 +658,8 @@ bool QuakeBridgeVisitor::VisitVarDecl(clang::VarDecl *x) {
if (auto qType = dyn_cast<quake::RefType>(type)) {
// Variable is of !quake.ref type.
if (x->hasInit() && !valueStack.empty()) {
auto val = popValue();
symbolTable.insert(name, val);
return pushValue(val);
symbolTable.insert(name, peekValue());
return true;
}
auto zero = builder.create<mlir::arith::ConstantIntOp>(
loc, 0, builder.getIntegerType(64));
Expand All @@ -672,6 +671,13 @@ bool QuakeBridgeVisitor::VisitVarDecl(clang::VarDecl *x) {
return pushValue(addressTheQubit);
}

if (isa<quake::StruqType>(type)) {
// A pure quantum struct is just passed along by value. It cannot be stored
// to a variable.
symbolTable.insert(name, peekValue());
return true;
}

// Here we maybe have something like auto var = mz(qreg)
if (auto vecType = dyn_cast<cc::StdvecType>(type)) {
// Variable is of !cc.stdvec type.
Expand Down Expand Up @@ -805,6 +811,12 @@ bool QuakeBridgeVisitor::VisitVarDecl(clang::VarDecl *x) {
return pushValue(cast);
}

// Don't allocate memory for a quantum or value-semantic struct.
if (auto insertValOp = initValue.getDefiningOp<cc::InsertValueOp>()) {
symbolTable.insert(x->getName(), initValue);
return pushValue(initValue);
}

// Initialization expression resulted in a value. Create a variable and save
// that value to the variable's memory address.
Value alloca = builder.create<cc::AllocaOp>(loc, type);
Expand Down
81 changes: 67 additions & 14 deletions lib/Frontend/nvqpp/ConvertExpr.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -1109,14 +1109,23 @@ bool QuakeBridgeVisitor::VisitMemberExpr(clang::MemberExpr *x) {
if (auto *field = dyn_cast<clang::FieldDecl>(x->getMemberDecl())) {
auto loc = toLocation(x->getSourceRange());
auto object = popValue(); // DeclRefExpr
auto ty = popType();
std::int32_t offset = field->getFieldIndex();
if (isa<quake::StruqType>(object.getType())) {
return pushValue(
builder.create<quake::GetMemberOp>(loc, ty, object, offset));
}
if (!isa<cc::PointerType>(object.getType())) {
reportClangError(x, mangler,
"internal error: struct must be an object in memory");
return false;
}
auto eleTy = cast<cc::PointerType>(object.getType()).getElementType();
SmallVector<cc::ComputePtrArg> offsets;
if (auto arrTy = dyn_cast<cc::ArrayType>(eleTy))
if (arrTy.isUnknownSize())
offsets.push_back(0);
std::int32_t offset = field->getFieldIndex();
offsets.push_back(offset);
auto ty = popType();
return pushValue(builder.create<cc::ComputePtrOp>(
loc, cc::PointerType::get(ty), object, offsets));
}
Expand Down Expand Up @@ -2199,24 +2208,42 @@ bool QuakeBridgeVisitor::VisitCXXOperatorCallExpr(
if (isCudaQType(typeName)) {
auto idx_var = popValue();
auto qreg_var = popValue();

auto *arg0 = x->getArg(0);
if (isa<clang::MemberExpr>(arg0)) {
// This is a subscript operator on a data member and the type is a
// quantum type (likely a `qview`). This can only happen in a quantum
// `struct`, which the spec says must be one-level deep at most and must
// only contain references to qubits explicitly allocated in other
// variables. `qreg_var` will be a `quake.get_member`. Do not add this
// extract `Op` to the symbol table, but always generate a new
// `quake.extract_ref` `Op` to get the exact qubit (reference) value.
auto address_qubit =
builder.create<quake::ExtractRefOp>(loc, qreg_var, idx_var);
return replaceTOSValue(address_qubit);
}
// Get name of the qreg, e.g. qr, and use it to construct a name for the
// element, which is intended to be qr%n when n is the index of the
// accessed qubit.
StringRef qregName = getNamedDecl(x->getArg(0))->getName();
if (!isa<clang::DeclRefExpr>(arg0))
reportClangError(x, mangler,
"internal error: expected a variable name");
StringRef qregName = getNamedDecl(arg0)->getName();
auto name = getQubitSymbolTableName(qregName, idx_var);
char *varName = strdup(name.c_str());

// If the name exists in the symbol table, return its stored value.
if (symbolTable.count(name))
return replaceTOSValue(symbolTable.lookup(name));

// Otherwise create an operation to access the qubit, store that value in
// the symbol table, and return the AddressQubit operation's resulting
// value.
// Otherwise create an operation to access the qubit, store that value
// in the symbol table, and return the AddressQubit operation's
// resulting value.
auto address_qubit =
builder.create<quake::ExtractRefOp>(loc, qreg_var, idx_var);

// NB: varName is built from the variable name *and* the index value. This
// front-end optimization is likely unnecessary as the compiler can always
// canonicalize and merge identical quake.extract_ref operations.
symbolTable.insert(StringRef(varName), address_qubit);
return replaceTOSValue(address_qubit);
}
Expand Down Expand Up @@ -2395,7 +2422,10 @@ bool QuakeBridgeVisitor::VisitInitListExpr(clang::InitListExpr *x) {
bool allRef = std::all_of(last.begin(), last.end(), [](auto v) {
return isa<quake::RefType, quake::VeqType>(v.getType());
});
if (allRef) {
if (allRef && isa<quake::StruqType>(initListTy))
return pushValue(builder.create<quake::MakeStruqOp>(loc, initListTy, last));

if (allRef && !isa<cc::StructType>(initListTy)) {
// Initializer list contains all quantum reference types. In this case we
// want to create quake code to concatenate the references into a veq.
if (size > 1) {
Expand Down Expand Up @@ -2466,6 +2496,11 @@ bool QuakeBridgeVisitor::VisitInitListExpr(clang::InitListExpr *x) {
auto globalInit = builder.create<cc::AddressOfOp>(loc, ptrTy, name);
return pushValue(globalInit);
}

// If quantum, use value semantics with cc insert / extract value.
if (isa<quake::StruqType>(eleTy))
return pushValue(builder.create<quake::MakeStruqOp>(loc, eleTy, last));

Value alloca = (numEles > 1)
? builder.create<cc::AllocaOp>(loc, eleTy, arrSize)
: builder.create<cc::AllocaOp>(loc, eleTy);
Expand Down Expand Up @@ -2556,6 +2591,19 @@ static Type getEleTyFromVectorCtor(Type ctorTy) {
return ctorTy;
}

bool QuakeBridgeVisitor::VisitCXXParenListInitExpr(
clang::CXXParenListInitExpr *x) {
auto ty = peekType();
assert(ty && "type must be present");
LLVM_DEBUG(llvm::dbgs() << "paren list type: " << ty << '\n');
auto structTy = dyn_cast<quake::StruqType>(ty);
if (!structTy)
return true;
auto loc = toLocation(x);
auto last = lastValues(structTy.getMembers().size());
return pushValue(builder.create<quake::MakeStruqOp>(loc, structTy, last));
}

bool QuakeBridgeVisitor::VisitCXXConstructExpr(clang::CXXConstructExpr *x) {
auto loc = toLocation(x);
auto *ctor = x->getConstructor();
Expand Down Expand Up @@ -2855,12 +2903,17 @@ bool QuakeBridgeVisitor::VisitCXXConstructExpr(clang::CXXConstructExpr *x) {
return true;
}

if (ctor->isCopyOrMoveConstructor() && parent->isPOD()) {
// Copy or move constructor on a POD struct. The value stack should contain
// the object to load the value from.
auto fromStruct = popValue();
assert(isa<cc::StructType>(ctorTy) && "POD must be a struct type");
return pushValue(builder.create<cc::LoadOp>(loc, fromStruct));
if (ctor->isCopyOrMoveConstructor()) {
// Just walk through copy constructors for quantum struct types.
if (isa<quake::StruqType>(ctorTy))
return true;
if (parent->isPOD()) {
// Copy or move constructor on a POD struct. The value stack should
// contain the object to load the value from.
auto fromStruct = popValue();
assert(isa<cc::StructType>(ctorTy) && "POD must be a struct type");
return pushValue(builder.create<cc::LoadOp>(loc, fromStruct));
}
}

if (ctor->isCopyConstructor() && ctor->isTrivial() &&
Expand Down
Loading

0 comments on commit e21a32b

Please sign in to comment.