Skip to content

Commit

Permalink
fix plotting rot
Browse files Browse the repository at this point in the history
  • Loading branch information
GiggleLiu committed Nov 15, 2024
1 parent bbfc13e commit 5cb7a49
Show file tree
Hide file tree
Showing 7 changed files with 18 additions and 5 deletions.
2 changes: 1 addition & 1 deletion lib/YaoBlocks/Project.toml
Original file line number Diff line number Diff line change
Expand Up @@ -27,7 +27,7 @@ CacheServers = "0.2"
ChainRulesCore = "1.11"
DocStringExtensions = "0.8, 0.9"
InteractiveUtils = "1"
KrylovKit = "0.5, 0.6, 0.7"
KrylovKit = "0.5, 0.6, 0.7, 0.8"
LegibleLambdas = "0.2, 0.3"
LinearAlgebra = "1"
LuxurySparse = "0.7"
Expand Down
2 changes: 1 addition & 1 deletion lib/YaoBlocks/src/primitive/rotation_gate.jl
Original file line number Diff line number Diff line change
Expand Up @@ -137,7 +137,7 @@ cache_key(R::RotationGate) = R.theta

iparams_range(::RotationGate{D,T,GT}) where {D,T,GT} = ((zero(T), T(2 * pi)),)

occupied_locs(g::RotationGate) = occupied_locs(g.block)
occupied_locs(g::RotationGate) = (1:nqudits(g)...,)

function unsafe_getindex(::Type{T}, rg::RotationGate{D}, i::Integer, j::Integer) where {D,T}
return (i==j ? cos(T(rg.theta)/2) : zero(T)) - im * sin(T(rg.theta)/2) * unsafe_getindex(T, rg.block, i, j)
Expand Down
5 changes: 4 additions & 1 deletion lib/YaoBlocks/src/treeutils/optimise.jl
Original file line number Diff line number Diff line change
Expand Up @@ -7,7 +7,7 @@ using YaoBlocks: NotImplementedError

include("to_basictypes.jl")

export replace_block
export replace_block, flatten_basic, simplify
"""
replace_block(actor, tree::AbstractBlock) -> AbstractBlock
replace_block(pair::Pair{Type{ST}, TT}, tree::AbstractBlock) -> AbstractBlock
Expand Down Expand Up @@ -208,6 +208,9 @@ export simplify

const __default_simplification_rules__ =
Function[merge_pauli, eliminate_nested, merge_scale, combine_similar]
const __flatten_basic__ = Function[eliminate_nested, to_basictypes]

flatten_basic(ex::AbstractBlock) = simplify(ex; rules = __flatten_basic__)

# Inspired by MasonPotter/Symbolics.jl
"""
Expand Down
1 change: 1 addition & 0 deletions lib/YaoBlocks/src/treeutils/to_basictypes.jl
Original file line number Diff line number Diff line change
Expand Up @@ -12,6 +12,7 @@ convert gates to basic types
function to_basictypes end

to_basictypes(block::PrimitiveBlock) = block
to_basictypes(block::UnitaryChannel) = block
function to_basictypes(block::AbstractBlock)
throw(NotImplementedError(:to_basictypes, typeof(block)))
end
Expand Down
2 changes: 1 addition & 1 deletion lib/YaoBlocks/test/primitive/rotation_gate.jl
Original file line number Diff line number Diff line change
Expand Up @@ -69,7 +69,7 @@ end

@testset "occupied locs" begin
g = rot(put(5, 2 => X), 0.5)
@test occupied_locs(g) == (2,)
@test occupied_locs(g) == (1,2,3,4,5)
end

@testset "instruct_get_element" begin
Expand Down
4 changes: 4 additions & 0 deletions lib/YaoPlots/src/vizcircuit.jl
Original file line number Diff line number Diff line change
Expand Up @@ -278,6 +278,10 @@ function draw!(c::CircuitGrid, p::Daggered{<:PrimitiveBlock}, address, controls)
bts = length(controls)>=1 ? get_cbrush_texts(c, content(p)) : get_brush_texts(c, content(p))
_draw!(c, [controls..., (getindex.(Ref(address), occupied_locs(p)), bts[1], bts[2]*"'")])
end
function draw!(c::CircuitGrid, p::UnitaryChannel, address, controls)
bts = (c.gatestyles.g, "*")
_draw!(c, [controls..., (getindex.(Ref(address), occupied_locs(p)), bts[1], bts[2])])
end

function draw!(c::CircuitGrid, p::Scale, address, controls)
fp = YaoBlocks.factor(p)
Expand Down
7 changes: 6 additions & 1 deletion lib/YaoPlots/test/vizcircuit.jl
Original file line number Diff line number Diff line change
Expand Up @@ -86,4 +86,9 @@ end
)
YaoPlots.CircuitStyles.barrier_for_chain[] = true
@test vizcircuit(circuit) isa Drawing
end
end

@testset "rot igate" begin
@test plot(rot(igate(1), 1.)) isa Drawing
@test plot(rot(put(3, 1=>X), 1.)) isa Drawing
end

0 comments on commit 5cb7a49

Please sign in to comment.