diff --git a/src/system/Asense.jl b/src/system/Asense.jl index ca15569..39642c2 100644 --- a/src/system/Asense.jl +++ b/src/system/Asense.jl @@ -19,7 +19,7 @@ and coil sensitivity maps `smaps`. The input `smaps` can either be a `D+1` dimensional array of size `(size(samp)..., ncoil)`, -or a Vector of `ncoil` arrays of size `size(samp)`. +or a Vector (or `Slices`) of `ncoil` arrays of size `size(samp)`. # Input - `samp::AbstractArray{<:Bool}` `D`-dimensional sampling pattern. @@ -37,7 +37,7 @@ Returns a `LinearMapsAA.LinearMapAO` object. """ function Asense( samp::AbstractArray{<:Bool, D}, - smaps::Vector{<:AbstractArray{<:Number}}, + smaps::AbstractVector{<:AbstractArray{<:Number}}, ; dims = 1:D, T::Type{<:Complex{<:AbstractFloat}} = ComplexF32, @@ -48,9 +48,9 @@ function Asense( kwargs... ) where {D, Tw <: Number} - all(in(1:D), dims) || error("dims $dims") - promote_type(Tw, T) == Tw || error("type Tw=$Tw cannot hold T=$T") - axes(work1) == axes(work2) == axes(samp) || error("axes mismatch: samp work") + all(in(1:D), dims) || throw(DimensionMismath("dims $dims")) + promote_type(Tw, T) == Tw || throw("type Tw=$Tw cannot hold T=$T") + axes(work1) == axes(work2) == axes(samp) || throw("axes mismatch: samp work") all(==(axes(samp)), axes.(smaps)) || throw("axes mismatch: samp smaps") sdim = size(samp) @@ -66,12 +66,12 @@ function Asense( function forw!(y, x) for ic in 1:ncoil - @. work1 = x * smaps[ic] + @. work1 = x * smaps[ic] # apply sensitivity map ifftshift!(work2, work1) - mul!(work2, pf, work2) + mul!(work2, pf, work2) # FFT fftshift!(work1, work2) if factor == 1 - y[:,ic] .= work1[samp] + y[:,ic] .= work1[samp] # sampling else @. y[:,ic] = work1[samp] * factor end @@ -111,13 +111,14 @@ function Asense( end +# handle typical array version of `smaps` efficiently function Asense( samp::AbstractArray{<:Bool, D}, smaps::AbstractArray{<:Number}, ; kwargs... ) where D - ndims(smaps) == D+1 || error("dimension mismatch") - smapv = [eachslice(smaps, dims = D+1)...] + ndims(smaps) == D+1 || throw(DimensionMismatch("$(ndims(smaps)) ≠ $(D+1)")) + smapv = eachslice(smaps, dims = D+1) # Slices return Asense(samp, smapv; kwargs...) end