Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

fix issue 432 #435

Merged
merged 4 commits into from
Apr 12, 2024
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
2 changes: 1 addition & 1 deletion lib/YaoAPI/src/blocks.jl
Original file line number Diff line number Diff line change
Expand Up @@ -521,7 +521,7 @@ chain


julia> expect(op, r)
0.7071067811865474 + 0.0im
0.7071067811865474
```
"""
@interface expect
Expand Down
52 changes: 28 additions & 24 deletions lib/YaoArrayRegister/src/operations.jl
Original file line number Diff line number Diff line change
Expand Up @@ -158,40 +158,44 @@ for op in [:*, :/]
end
end

function Base.:*(bra::AdjointRegister{D,<:ArrayReg}, ket::ArrayReg{D}) where D
if nremain(bra) == nremain(ket)
return dot(relaxedvec(parent(bra)), relaxedvec(ket))
elseif nremain(bra) == 0 # <s|active> |remain>
return ArrayReg{D}(state(bra) * state(ket))
else
error(
"partially contract ⟨bra|ket⟩ is not supported, expect ⟨bra| to be fully actived. nactive(bra)/nqudits(bra)=$(nactive(bra))/$(nqudits(bra))",
"""
$TYPEDSIGNATURES

The overlap between `ket` and `bra`, which is only defined for two fully activated equal sized registers.
It is only slightly different from [`inner_product`](@ref) in that it always returns a complex number.

### Examples
```jldoctest; setup=:(using YaoArrayRegister)
julia> reg1 = ghz_state(3);

julia> reg2 = uniform_state(3);

julia> reg1' * reg2
0.5 + 0.0im
```
"""
function Base.:*(bra::AdjointRegister{D,<:ArrayReg}, ket::ArrayReg{D})::Number where D
# check the register sizes
nqudits(bra) == nqudits(ket) && nremain(bra) == nremain(ket) || error(
Roger-luo marked this conversation as resolved.
Show resolved Hide resolved
"partially contract ⟨bra|ket⟩ is not supported, expect ⟨bra| and |ket⟩ to have the same size. Got nactive(bra)/nqudits(bra)=$(nactive(bra))/$(nqudits(bra)), nactive(ket)/nqudits(ket)=$(nactive(ket))/$(nqudits(ket))",
)
end
return dot(relaxedvec(parent(bra)), relaxedvec(ket))
end

Base.:*(bra::AdjointRegister{D,<:BatchedArrayReg{D}}, ket::BatchedArrayReg{D}) where D = bra .* ket
function Base.:*(
bra::AdjointRegister{D,<:BatchedArrayReg{D, T1, <:Transpose}},
ket::BatchedArrayReg{D,T2,<:Transpose},
) where {D,T1,T2}
if nremain(bra) == nremain(ket) == 0 # all active
A, C = parent(state(parent(bra))), parent(state(ket))
res = zeros(eltype(promote_type(T1, T2)), nbatch(ket))
#return mapreduce((x, y) -> conj(x) * y, +, ; dims=2)
for j = 1:size(A, 2)
for i = 1:size(A, 1)
@inbounds res[i] += conj(A[i, j]) * C[i, j]
end
end
res
elseif nremain(bra) == 0 # <s|active> |remain>
bra .* ket
else
error(
"partially contract ⟨bra|ket⟩ is not supported, expect ⟨bra| to be fully actived. nactive(bra)/nqudits(bra)=$(nactive(bra))/$(nqudits(bra))",
nqudits(bra) == nqudits(ket) && nremain(bra) == nremain(ket) && nbatch(bra) == nbatch(ket) || error(
"partially contract ⟨bra|ket⟩ is not supported, expect ⟨bra| and |ket⟩ to have the same size. Got nactive(bra)/nqudits(bra)/nbatch(bra)=$(nactive(bra))/$(nqudits(bra))/$(nbatch(bra)), nactive(ket)/nqudits(ket)=$(nactive(ket))/$(nqudits(ket))/$(nbatch(ket))",
)
A, C = parent(state(parent(bra))), parent(state(ket))
res = zeros(eltype(promote_type(T1, T2)), nbatch(ket))
for j in 1:size(A, 2), i in 1:size(A, 1)
@inbounds res[i] += conj(A[i, j]) * C[i, j]
end
return res
end

# broadcast
Expand Down
6 changes: 2 additions & 4 deletions lib/YaoArrayRegister/test/operations.jl
Original file line number Diff line number Diff line change
Expand Up @@ -51,9 +51,7 @@ end
ket = arrayreg(bit"100") + 2 * arrayreg(bit"110") + 3 * arrayreg(bit"111")

focus!(ket, 2:3)
t = bra' * ket
relax!(t, 1)
@test state(t) ≈ [1, 0]
@test_throws ErrorException bra' * ket

relax!(ket, 2:3)
focus!(ket, 1)
Expand All @@ -66,7 +64,7 @@ end
reg1 = rand_state(2; nbatch = 10)
reg2 = rand_state(5; nbatch = 10)
focus!(reg2, 2:3)
@test all(reg1' * reg2 .≈ reg1' .* reg2)
@test_throws ErrorException reg1' * reg2
end

@testset "inplace funcs" begin
Expand Down
1 change: 1 addition & 0 deletions lib/YaoBlocks/src/YaoBlocks.jl
Original file line number Diff line number Diff line change
Expand Up @@ -72,6 +72,7 @@ export AbstractBlock,
content,
dispatch!,
dispatch,
sandwich,
expect,
getiparams,
iparams_eltype,
Expand Down
53 changes: 31 additions & 22 deletions lib/YaoBlocks/src/blocktools.jl
Original file line number Diff line number Diff line change
Expand Up @@ -74,9 +74,9 @@ collect_blocks(::Type{T}, x::AbstractBlock) where {T<:AbstractBlock} =
#expect(op::AbstractBlock, dm::DensityMatrix) = mapslices(x->sum(mat(op).*x)[], dm.state, dims=[1,2]) |> vec

"""
expect(op::AbstractBlock, reg) -> Vector
expect(op::AbstractBlock, reg => circuit) -> Vector
expect(op::AbstractBlock, density_matrix) -> Vector
expect(op::AbstractBlock, reg) -> Real
expect(op::AbstractBlock, reg => circuit) -> Real
expect(op::AbstractBlock, density_matrix) -> Real

Get the expectation value of an operator, the second parameter can be a register `reg` or a pair of input register and circuit `reg => circuit`.

Expand All @@ -92,13 +92,21 @@ For register input, the return value is a register.

For batched register, `expect(op, reg=>circuit)` returns a vector of size number of batch as output. However, one can not differentiate over a vector loss, so `expect'(op, reg=>circuit)` accumulates the gradient over batch, rather than returning a batched gradient of parameters.
"""
function expect(op::AbstractBlock, reg::AbstractRegister)
# NOTE: broadcast because the input register can be a batched one
return safe_real.(sandwich(reg, op, reg))
end
function expect(op, plan::Pair{<:AbstractRegister,<:AbstractBlock})
expect(op, copy(plan.first) |> plan.second)
end

function expect(op::AbstractBlock, dm::DensityMatrix)
# NOTE: we use matrix form here because the matrix size is known to be small,
# while applying a circuit on a reduced density matrix might take much more than constructing the matrix.
mop = mat(op)
# TODO: switch to `IterNz`
# sum(x->dm.state[x[2],x[1]]*x[3], IterNz(mop))
return sum(transpose(dm.state) .* mop)
return safe_real(sum(transpose(dm.state) .* mop))
end
function expect(op::AbstractAdd, reg::DensityMatrix)
# NOTE: this is faster in e.g. when the op is Heisenberg
Expand All @@ -108,23 +116,28 @@ function expect(op::Scale, reg::DensityMatrix)
factor(op) * expect(content(op), reg)
end

# NOTE: assume an register has a bra. Can we define it for density matrix?
expect(op::AbstractBlock, reg::AbstractRegister) = reg' * apply!(copy(reg), op)
"""
sandwich(bra::AbstractRegister, op::AbstractBlock, ket::AbstracRegister) -> Complex

Compute the sandwich function ⟨bra|op|ket⟩.
"""
sandwich(bra::AbstractArrayReg, op::AbstractBlock, reg::AbstractArrayReg) = bra' * apply!(copy(reg), op)

function expect(op::AbstractBlock, reg::BatchedArrayReg)
function sandwich(bra::BatchedArrayReg, op::AbstractBlock, reg::BatchedArrayReg)
@assert nbatch(bra) == nbatch(reg)
B = YaoArrayRegister._asint(nbatch(reg))
ket = apply!(copy(reg), op)
if !(reg.state isa Transpose) # not-transposed storage
if !(bra.state isa Transpose) # not-transposed storage
C = reshape(ket.state, :, B)
A = reshape(reg.state, :, B)
A = reshape(bra.state, :, B)
# reduce over the 1st dimension
conjsumprod1(A, C)
elseif size(reg.state, 2) == B # transposed storage, no environment qubits
elseif size(bra.state, 2) == B # transposed storage, no environment qubits
# reduce over the second dimension
conjsumprod2(reg.state.parent, ket.state.parent)
conjsumprod2(bra.state.parent, ket.state.parent)
else
C = reshape(ket.state.parent, :, B, size(reg.state, 1))
A = reshape(reg.state.parent, :, B, size(reg.state, 1))
C = reshape(ket.state.parent, :, B, size(bra.state, 1))
A = reshape(bra.state.parent, :, B, size(bra.state, 1))
# reduce over the 1st and 3rd dimension
conjsumprod13(A, C)
end
Expand Down Expand Up @@ -161,19 +174,15 @@ function conjsumprod13(A::AbstractArray, C::AbstractArray)
res
end

for REG in [:AbstractRegister, :BatchedArrayReg]
@eval function expect(op::AbstractAdd, reg::$REG)
sum(opi -> expect(opi, reg), op)
for REG in [:AbstractArrayReg, :BatchedArrayReg]
@eval function sandwich(bra::$REG, op::AbstractAdd, reg::$REG)
sum(opi -> sandwich(bra, opi, reg), op)
end
@eval function expect(op::Scale, reg::$REG)
factor(op) * expect(content(op), reg)
@eval function sandwich(bra::$REG, op::Scale, reg::$REG)
factor(op) * sandwich(bra, content(op), reg)
end
end

function expect(op, plan::Pair{<:AbstractRegister,<:AbstractBlock})
expect(op, copy(plan.first) |> plan.second)
end

# obtaining Dense Matrix of a block
LinearAlgebra.Matrix(blk::AbstractBlock) = Matrix(mat(blk))

Expand Down
9 changes: 9 additions & 0 deletions lib/YaoBlocks/src/utils.jl
Original file line number Diff line number Diff line change
Expand Up @@ -311,3 +311,12 @@ end

SparseArrays.sparse(et::EntryTable) = SparseVector(et)
Base.vec(et::EntryTable) = Vector(et)

# convert a (maybe complex) number x to real number.
function safe_real(x)
img = imag(x)
if !(iszero(img) || isapprox(x - im*img, x))
Roger-luo marked this conversation as resolved.
Show resolved Hide resolved
error("Can not convert number $x to real due to its large imaginary part.")
Roger-luo marked this conversation as resolved.
Show resolved Hide resolved
end
return real(x)
end
6 changes: 3 additions & 3 deletions lib/YaoBlocks/test/measure_ops.jl
Original file line number Diff line number Diff line change
Expand Up @@ -51,7 +51,7 @@ end
@show op
@test isapprox(
sum(measure(op, reg2; nshots = 100000)) / 100000,
expect(op, reg),
sandwich(reg, op, reg),
rtol = 0.1,
)
@test reg ≈ reg2
Expand All @@ -60,7 +60,7 @@ end
reg2 = copy(reg)
@test isapprox(
dropdims(sum(measure(op, reg2; nshots = 100000), dims = 1), dims = 1) / 100000,
expect(op, reg),
sandwich(reg, op, reg),
rtol = 0.1,
)
@test reg ≈ reg2
Expand Down Expand Up @@ -91,7 +91,7 @@ end
@show op
@test isapprox(
sum(measure(op, reg2, locs; nshots = 100000)) / 100000,
expect(put(Nbit, locs => op), reg),
sandwich(reg, put(Nbit, locs => op), reg),
rtol = 0.2,
)
@test reg ≈ reg2
Expand Down
4 changes: 2 additions & 2 deletions test/easybuild/hadamardtest.jl
Original file line number Diff line number Diff line change
Expand Up @@ -32,8 +32,8 @@ end
U = put(nbit, 2=>Rx(0.2))
reg = rand_state(nbit)

@test hadamard_test(U, reg, 0.0) ≈ real(expect(U, reg))
@test hadamard_test(U, reg, -π/2) ≈ imag(expect(U, reg))
@test hadamard_test(U, reg, 0.0) ≈ real(sandwich(reg, U, reg))
@test hadamard_test(U, reg, -π/2) ≈ imag(sandwich(reg, U, reg))

reg = zero_state(2) |> EasyBuild.singlet_block()
@test single_swap_test(reg, 0) ≈ -1
Expand Down
Loading