diff --git a/.github/workflows/ci.yml b/.github/workflows/ci.yml index f5f3eae..db9ec4a 100644 --- a/.github/workflows/ci.yml +++ b/.github/workflows/ci.yml @@ -22,6 +22,13 @@ jobs: - ubuntu-latest arch: - x64 + include: + - os: macos-latest + arch: arm64 + version: '1' + - os: windows-latest + arch: x64 + version: '1' steps: - uses: actions/checkout@v6 - uses: julia-actions/setup-julia@v3 diff --git a/src/algos.jl b/src/algos.jl index 1ab65a4..f67cdca 100644 --- a/src/algos.jl +++ b/src/algos.jl @@ -2,16 +2,14 @@ Int(d) end -@inline _conj(w::Complex, d::Direction) = ifelse(direction_sign(d) === 1, w, conj(w)) - function fft!( - out::AbstractVector{T}, in::AbstractVector{T}, + out::AbstractVector{T}, in::AbstractVector{<:Number}, start_out::Int, start_in::Int, d::Direction, t::FFTEnum, g::CallGraph{T}, idx::Int - ) where T +) where T if t === COMPOSITE_FFT fft_composite!(out, in, start_out, start_in, d, g, idx) else @@ -19,15 +17,14 @@ function fft!( s_in = root.s_in s_out = root.s_out N = root.sz - w = _conj(root.w, d) if t === DFT - fft_dft!(out, in, N, start_out, s_out, start_in, s_in, w) + fft_dft!(out, in, N, start_out, s_out, start_in, s_in, d) elseif t === POW2RADIX4_FFT - fft_pow2_radix4!(out, in, N, start_out, s_out, start_in, s_in, w) + fft_pow2_radix4!(out, in, N, start_out, s_out, start_in, s_in, d) elseif t === POW3_FFT _m_120 = cispi(T(2) / 3) m_120 = d === FFT_FORWARD ? _m_120 : conj(_m_120) - fft_pow3!(out, in, N, start_out, s_out, start_in, s_in, w, m_120) + fft_pow3!(out, in, N, start_out, s_out, start_in, s_in, m_120, d) elseif t === BLUESTEIN fft_bluestein!(out, in, d, N, start_out, s_out, start_in, s_in) else @@ -51,7 +48,13 @@ Cooley-Tukey composite FFT, with a pre-computed call graph - `idx`: Index of the current transform in the call graph """ -function fft_composite!(out::AbstractVector{T}, in::AbstractVector{U}, start_out::Int, start_in::Int, d::Direction, g::CallGraph{T}, idx::Int) where {T,U} +function fft_composite!( + out::AbstractVector{T}, in::AbstractVector{U}, + start_out::Int, start_in::Int, + d::Direction, + g::CallGraph{T}, + idx::Int +) where {T,U} root = g[idx] left_idx = idx + root.left right_idx = idx + root.right @@ -66,12 +69,8 @@ function fft_composite!(out::AbstractVector{T}, in::AbstractVector{U}, start_out Rt = right.type Lt = left.type - w1 = _conj(root.w, d) Rtype = real(T) - # The composite twiddle at position (j1, k2) is `cispi(dir · 2 j1 k2 / N)`. - # Singleton's recurrence advances `wk2 = cispi(dir · 2 j1 k2 / N)` in k2 - # for fixed j1; (α, β) depend on j1 so we reset them at each outer step. - dir = twiddle_direction(w1) + dir = direction_sign(d) tmp = g.workspace[idx] if Rt === BLUESTEIN @@ -90,11 +89,14 @@ function fft_composite!(out::AbstractVector{T}, in::AbstractVector{U}, start_out end if j1 > 0 - αi, βi = singleton_params(dir * Rtype(2 * j1) / Rtype(N)) - ci, si = one(Rtype), zero(Rtype) + # The composite twiddle at position (j1, k2) is `cispi(dir · 2 j1 k2 / N)`. + # Singleton's recurrence advances `wk2 = cispi(dir · 2 j1 k2 / N)` in k2 + # for fixed j1; (α, β) depend on j1 so we reset them at each outer step. + zj1 = singleton_params(dir * Rtype(j1) / Rtype(N)) + wk2 = one(T) @inbounds for k2 in 1:N2-1 - ci, si = singleton_step(ci, si, αi, βi) - tmp[R_start_out + k2] *= Complex(ci, si) + wk2 = singleton_step(wk2, zj1) + tmp[R_start_out + k2] *= wk2 end end end @@ -127,10 +129,16 @@ Discrete Fourier Transform, O(N^2) algorithm, in place. - `stride_out`: Stride of the output vector - `start_in`: Index of the first element of the input vector - `stride_in`: Stride of the input vector -- `w`: The value `cispi(direction_sign(d) * 2 / N)` +- `d`: Direction of the transform """ -function fft_dft!(out::AbstractVector{T}, in::AbstractVector{T}, N::Int, start_out::Int, stride_out::Int, start_in::Int, stride_in::Int, w::T) where {T} +function fft_dft!( + out::AbstractVector{T}, in::AbstractVector{T}, + N::Int, + start_out::Int, stride_out::Int, + start_in::Int, stride_in::Int, + d::Direction +) where {T<:Complex} tmp = in[start_in] @inbounds for j in 1:N-1 tmp += in[start_in + j*stride_in] @@ -138,20 +146,26 @@ function fft_dft!(out::AbstractVector{T}, in::AbstractVector{T}, N::Int, start_o out[start_out] = tmp Rtype = real(T) - dir = twiddle_direction(w) - @inbounds for d in 1:N-1 - t = in[start_in] - αk, βk = singleton_params(dir * Rtype(2 * d) / Rtype(N)) - ck, sk = one(Rtype), zero(Rtype) + dir = direction_sign(d) + @inbounds for j in 1:N-1 + tmp = in[start_in] + zj = singleton_params(dir * Rtype(j) / Rtype(N)) + wk = one(T) @inbounds for k in 1:N-1 - ck, sk = singleton_step(ck, sk, αk, βk) - t += Complex(ck, sk) * in[start_in + k*stride_in] + wk = singleton_step(wk, zj) + tmp += wk * in[start_in + k*stride_in] end - out[start_out + d*stride_out] = t + out[start_out + j*stride_out] = tmp end end -function fft_dft!(out::AbstractVector{Complex{T}}, in::AbstractVector{T}, N::Int, start_out::Int, stride_out::Int, start_in::Int, stride_in::Int, w::Complex{T}) where {T<:Real} +function fft_dft!( + out::AbstractVector{Complex{T}}, in::AbstractVector{T}, + N::Int, + start_out::Int, stride_out::Int, + start_in::Int, stride_in::Int, + d::Direction +) where {T<:Real} halfN = N÷2 tmp = Complex{T}(in[start_in]) @@ -160,16 +174,17 @@ function fft_dft!(out::AbstractVector{Complex{T}}, in::AbstractVector{T}, N::Int end out[start_out] = tmp - dir = twiddle_direction(w) - @inbounds for d in 1:halfN - t = Complex{T}(in[start_in]) - αk, βk = singleton_params(dir * T(2 * d) / T(N)) - ck, sk = one(T), zero(T) + dir = direction_sign(d) + @inbounds for j in 1:halfN + tmp = Complex{T}(in[start_in]) + zj = singleton_params(dir * T(j) / T(N)) + wk = one(Complex{T}) @inbounds for k in 1:N-1 - ck, sk = singleton_step(ck, sk, αk, βk) - t += Complex{T}(ck, sk) * in[start_in + k*stride_in] + wk = singleton_step(wk, zj) + tmp += wk * in[start_in + k*stride_in] end - out[start_out + d*stride_out] = t + out[start_out + j*stride_out] = tmp + out[start_out + (N-j)*stride_out] = conj(tmp) end end @@ -186,10 +201,16 @@ Radix-4 FFT for powers of 2, in place - `stride_out`: Stride of the output vector - `start_in`: Index of the first element of the input vector - `stride_in`: Stride of the input vector -- `w`: The value `cispi(direction_sign(d) * 2 / N)` +- `d`: Direction of the transform """ -function fft_pow2_radix4!(out::AbstractVector{T}, in::AbstractVector{U}, N::Int, start_out::Int, stride_out::Int, start_in::Int, stride_in::Int, w::T) where {T, U} +function fft_pow2_radix4!( + out::AbstractVector{T}, in::AbstractVector{U}, + N::Int, + start_out::Int, stride_out::Int, + start_in::Int, stride_in::Int, + d::Direction +) where {T<:Complex, U} # If N is 2, compute the size two DFT @inbounds if N == 2 out[start_out] = in[start_in] + in[start_in + stride_in] @@ -197,8 +218,10 @@ function fft_pow2_radix4!(out::AbstractVector{T}, in::AbstractVector{U}, N::Int, return end + dir = direction_sign(d) + # If N is 4, compute an unrolled radix-2 FFT and return - minusi = -sign(imag(w)) * im + minusi = -dir * im @inbounds if N == 4 xee = in[start_in] xoe = in[start_in + stride_in] @@ -218,24 +241,20 @@ function fft_pow2_radix4!(out::AbstractVector{T}, in::AbstractVector{U}, N::Int, # ...othersize split the problem in four and recur m = N ÷ 4 - Rtype = real(T) - dir = twiddle_direction(w) - # Recursive sub-problem step `cispi(dir · 2 / m) = w^4`; use `cispi` - # directly so the sub-tree gets a < 1 ULP starting phase. - w_sub = cispi(dir * Rtype(2) / Rtype(m)) - - fft_pow2_radix4!(out, in, m, start_out , stride_out, start_in , stride_in*4, w_sub) - fft_pow2_radix4!(out, in, m, start_out + m*stride_out, stride_out, start_in + stride_in, stride_in*4, w_sub) - fft_pow2_radix4!(out, in, m, start_out + 2*m*stride_out, stride_out, start_in + 2*stride_in, stride_in*4, w_sub) - fft_pow2_radix4!(out, in, m, start_out + 3*m*stride_out, stride_out, start_in + 3*stride_in, stride_in*4, w_sub) + fft_pow2_radix4!(out, in, m, start_out , stride_out, start_in , stride_in*4, d) + fft_pow2_radix4!(out, in, m, start_out + m*stride_out, stride_out, start_in + stride_in, stride_in*4, d) + fft_pow2_radix4!(out, in, m, start_out + 2*m*stride_out, stride_out, start_in + 2*stride_in, stride_in*4, d) + fft_pow2_radix4!(out, in, m, start_out + 3*m*stride_out, stride_out, start_in + 3*stride_in, stride_in*4, d) + Rtype = real(T) # Singleton recurrence for the three running twiddles `w^k`, `w^2k`, `w^3k`. - α1, β1 = singleton_params(dir * Rtype(2) / Rtype(N)) - α2, β2 = singleton_params(dir * Rtype(4) / Rtype(N)) - α3, β3 = singleton_params(dir * Rtype(6) / Rtype(N)) - c1, s1 = one(Rtype), zero(Rtype) - c2, s2 = one(Rtype), zero(Rtype) - c3, s3 = one(Rtype), zero(Rtype) + z1 = singleton_params(dir * Rtype(1) / Rtype(N)) + z2 = singleton_params(dir * Rtype(2) / Rtype(N)) + z3 = singleton_params(dir * Rtype(3) / Rtype(N)) + + wkoe = one(T) + wkeo = one(T) + wkoo = one(T) @inbounds for k in 0:m-1 kee = start_out + k * stride_out @@ -243,9 +262,9 @@ function fft_pow2_radix4!(out::AbstractVector{T}, in::AbstractVector{U}, N::Int, keo = start_out + (k + 2 * m) * stride_out koo = start_out + (k + 3 * m) * stride_out y_kee, y_koe, y_keo, y_koo = out[kee], out[koe], out[keo], out[koo] - t_keo = y_keo * Complex(c2, s2) - t_koe = y_koe * Complex(c1, s1) - t_koo = y_koo * Complex(c3, s3) + t_koe = y_koe * wkoe + t_keo = y_keo * wkeo + t_koo = y_koo * wkoo y_kee_p_y_keo = y_kee + t_keo y_kee_m_y_keo = y_kee - t_keo t_koe_p_t_koo = t_koe + t_koo @@ -254,9 +273,9 @@ function fft_pow2_radix4!(out::AbstractVector{T}, in::AbstractVector{U}, N::Int, out[koe] = y_kee_m_y_keo + t_koe_m_t_koo out[keo] = y_kee_p_y_keo - t_koe_p_t_koo out[koo] = y_kee_m_y_keo - t_koe_m_t_koo - c1, s1 = singleton_step(c1, s1, α1, β1) - c2, s2 = singleton_step(c2, s2, α2, β2) - c3, s3 = singleton_step(c3, s3, α3, β3) + wkoe = singleton_step(wkoe, z1) + wkeo = singleton_step(wkeo, z2) + wkoo = singleton_step(wkoo, z3) end end @@ -273,12 +292,18 @@ Power of 3 FFT, in place - `stride_out`: Stride of the output vector - `start_in`: Index of the first element of the input vector - `stride_in`: Stride of the input vector -- `w`: The value `cispi(direction_sign(d) * 2 / N)` -- `plus120`: Depending on direction, perform either ±120° rotation - `minus120`: Depending on direction, perform either ∓120° rotation +- `d`: Direction of the transform """ -function fft_pow3!(out::AbstractVector{T}, in::AbstractVector{U}, N::Int, start_out::Int, stride_out::Int, start_in::Int, stride_in::Int, w::T, minus120::T) where {T, U} +function fft_pow3!( + out::AbstractVector{T}, in::AbstractVector{U}, + N::Int, + start_out::Int, stride_out::Int, + start_in::Int, stride_in::Int, + minus120::T, + d::Direction +) where {T, U} plus120 = conj(minus120) if N == 3 @muladd out[start_out + 0] = in[start_in] + in[start_in + stride_in] + in[start_in + 2*stride_in] @@ -290,32 +315,28 @@ function fft_pow3!(out::AbstractVector{T}, in::AbstractVector{U}, N::Int, start_ # Size of subproblem Nprime = N ÷ 3 + # Dividing into subproblems + fft_pow3!(out, in, Nprime, start_out, stride_out, start_in, stride_in*3, minus120, d) + fft_pow3!(out, in, Nprime, start_out + Nprime*stride_out, stride_out, start_in + stride_in, stride_in*3, minus120, d) + fft_pow3!(out, in, Nprime, start_out + 2*Nprime*stride_out, stride_out, start_in + 2*stride_in, stride_in*3, minus120, d) + Rtype = real(T) - dir = twiddle_direction(w) - # Recursive sub-problem step cispi(dir · 2 / Nprime) = w^3. - w_sub = cispi(dir * Rtype(2) / Rtype(Nprime)) + dir = direction_sign(d) - # Dividing into subproblems - fft_pow3!(out, in, Nprime, start_out, stride_out, start_in, stride_in*3, w_sub, minus120) - fft_pow3!(out, in, Nprime, start_out + Nprime*stride_out, stride_out, start_in + stride_in, stride_in*3, w_sub, minus120) - fft_pow3!(out, in, Nprime, start_out + 2*Nprime*stride_out, stride_out, start_in + 2*stride_in, stride_in*3, w_sub, minus120) - - α1, β1 = singleton_params(dir * Rtype(2) / Rtype(N)) - α2, β2 = singleton_params(dir * Rtype(4) / Rtype(N)) - c1, s1 = one(Rtype), zero(Rtype) - c2, s2 = one(Rtype), zero(Rtype) + z1 = singleton_params(dir * Rtype(1) / Rtype(N)) + z2 = singleton_params(dir * Rtype(2) / Rtype(N)) + wk1 = one(T) + wk2 = one(T) for k in 0:Nprime-1 k0 = start_out + stride_out * k k1 = start_out + stride_out * (k + Nprime) k2 = start_out + stride_out * (k + 2 * Nprime) y_k0, y_k1, y_k2 = out[k0], out[k1], out[k2] - wk1 = Complex(c1, s1) - wk2 = Complex(c2, s2) - @muladd out[k0] = y_k0 + y_k1*wk1 + y_k2*wk2 - @muladd out[k1] = y_k0 + y_k1*wk1*plus120 + y_k2*wk2*minus120 - @muladd out[k2] = y_k0 + y_k1*wk1*minus120 + y_k2*wk2*plus120 - c1, s1 = singleton_step(c1, s1, α1, β1) - c2, s2 = singleton_step(c2, s2, α2, β2) + @muladd out[k0] = y_k0 + y_k1 * wk1 + y_k2 * wk2 + @muladd out[k1] = y_k0 + y_k1 * wk1 * plus120 + y_k2 * wk2 * minus120 + @muladd out[k2] = y_k0 + y_k1 * wk1 * minus120 + y_k2 * wk2 * plus120 + wk1 = singleton_step(wk1, z1) + wk2 = singleton_step(wk2, z2) end end @@ -362,17 +383,17 @@ with a power 2 FFT. - `stride_out`: Stride of the output vector - `start_in`: Index of the first element of the input vector - `stride_in`: Stride of the input vector -- `w`: The value `cispi(direction_sign(d) * 2 / N)` +- `scratch` (optional): preallocated scratch space for bluestein """ function fft_bluestein!( - out::AbstractVector{T}, in::AbstractVector{T}, + out::AbstractVector{T}, in::AbstractVector{<:Number}, d::Direction, N::Int, start_out::Int, stride_out::Int, start_in::Int, stride_in::Int, scratch::Tuple{Vector{T},Vector{T},Vector{T},Int}=prealloc_blue(N, d, T) -) where T<:Number +) where T<:Complex (tmp, a_series, b_series, pad_len) = scratch @@ -383,14 +404,13 @@ function fft_bluestein!( a_series[i] = in[start_in+(i-1)*stride_in] * conj(b_series[i]) end - w_pad = cispi(T(2) / pad_len) # leave b_n vector alone for last step - fft_pow2_radix4!(tmp, a_series, pad_len, 1, 1, 1, 1, w_pad) # Fa - fft_pow2_radix4!(a_series, b_series, pad_len, 1, 1, 1, 1, w_pad) # Fb + fft_pow2_radix4!(tmp, a_series, pad_len, 1, 1, 1, 1, FFT_BACKWARD) # Fa + fft_pow2_radix4!(a_series, b_series, pad_len, 1, 1, 1, 1, FFT_BACKWARD) # Fb tmp .*= a_series # convolution theorem ifft - fft_pow2_radix4!(a_series, tmp, pad_len, 1, 1, 1, 1, conj(w_pad)) + fft_pow2_radix4!(a_series, tmp, pad_len, 1, 1, 1, 1, FFT_FORWARD) conv_a_b = a_series Xk = tmp diff --git a/src/callgraph.jl b/src/callgraph.jl index f64d340..f672157 100644 --- a/src/callgraph.jl +++ b/src/callgraph.jl @@ -11,16 +11,17 @@ Node of a call graph - `right`: Offset to the right child node - `type`: Object representing the type of FFT - `sz`: Size of this FFT +- `s_in`: The stride of the input +- `s_out`: The stride of the output """ -struct CallGraphNode{T} +struct CallGraphNode left::Int right::Int type::FFTEnum sz::Int s_in::Int s_out::Int - w::T end """ @@ -35,7 +36,7 @@ Object representing a graph of FFT Calls """ struct CallGraph{T<:Complex} - nodes::Vector{CallGraphNode{T}} + nodes::Vector{CallGraphNode} workspace::Vector{Vector{T}} BLUESTEIN_CUTOFF::Int end @@ -66,34 +67,36 @@ Recursively instantiate a set of `CallGraphNode`s - `nodes`: A vector (which gets expanded) of `CallGraphNode`s - `N`: The size of the FFT - `workspace`: A vector (which gets expanded) of preallocated workspaces +- `BLUESTEIN_CUTOFF`: Minimum prime that will be FFTed with the + Bluestein algorithm, below which the O(N^2) DFT is used. - `s_in`: The stride of the input - `s_out`: The stride of the output """ function CallGraphNode!( - nodes::Vector{CallGraphNode{T}}, + nodes::Vector{CallGraphNode}, N::Int, workspace::Vector{Vector{T}}, BLUESTEIN_CUTOFF::Int, - s_in::Int, s_out::Int)::Int where {T} + s_in::Int, s_out::Int +)::Int where {T} if N <= 0 throw(DimensionMismatch("Array length must be strictly positive")) end - w = cispi(T(2) / N) if iseven(N) && ispow2(N) # _ispow24(N) push!(workspace, T[]) - push!(nodes, CallGraphNode(0, 0, POW2RADIX4_FFT, N, s_in, s_out, w)) + push!(nodes, CallGraphNode(0, 0, POW2RADIX4_FFT, N, s_in, s_out)) return 1 elseif N % 3 == 0 && nextpow(3, N) == N push!(workspace, T[]) - push!(nodes, CallGraphNode(0, 0, POW3_FFT, N, s_in, s_out, w)) + push!(nodes, CallGraphNode(0, 0, POW3_FFT, N, s_in, s_out)) return 1 elseif N == 1 || Primes.isprime(N) push!(workspace, T[]) # use Bluestein's algorithm for big primes LEAF_ALG = N < BLUESTEIN_CUTOFF ? DFT : BLUESTEIN - push!(nodes, CallGraphNode(0, 0, LEAF_ALG, N, s_in, s_out, w)) + push!(nodes, CallGraphNode(0, 0, LEAF_ALG, N, s_in, s_out)) return 1 end fzn = Primes.factor(N) @@ -113,12 +116,12 @@ function CallGraphNode!( end end N2 = N ÷ N1 - push!(nodes, CallGraphNode(0, 0, DFT, N, s_in, s_out, w)) + push!(nodes, CallGraphNode(0, 0, DFT, N, s_in, s_out)) sz = length(nodes) push!(workspace, Vector{T}(undef, N)) left_len = CallGraphNode!(nodes, N1, workspace, BLUESTEIN_CUTOFF, N2 , N2 * s_out) right_len = CallGraphNode!(nodes, N2, workspace, BLUESTEIN_CUTOFF, N1 * s_in, 1) - nodes[sz] = CallGraphNode(1, 1 + left_len, COMPOSITE_FFT, N, s_in, s_out, w) + nodes[sz] = CallGraphNode(1, 1 + left_len, COMPOSITE_FFT, N, s_in, s_out) return 1 + left_len + right_len end @@ -128,7 +131,7 @@ Instantiate a CallGraph from a number `N` """ function CallGraph{T}(N::Int, BLUESTEIN_CUTOFF::Int) where {T} - nodes = CallGraphNode{T}[] + nodes = CallGraphNode[] workspace = Vector{Vector{T}}() CallGraphNode!(nodes, N, workspace, BLUESTEIN_CUTOFF, 1, 1) CallGraph(nodes, workspace, BLUESTEIN_CUTOFF) diff --git a/src/plan.jl b/src/plan.jl index ef8e979..a71cbc8 100644 --- a/src/plan.jl +++ b/src/plan.jl @@ -2,7 +2,7 @@ abstract type FFTAPlan{T,N} <: AbstractFFTs.Plan{T} end -struct FFTAInvPlan{T,N} <: FFTAPlan{T,N} end +struct FFTAInvPlan{_T,_N} <: FFTAPlan{_T,_N} end const RegionTypes{N} = Union{Int,AbstractVector{Int},NTuple{N,Int}} @@ -14,23 +14,23 @@ struct FFTAPlan_cx{T,N,R<:RegionTypes{N}} <: FFTAPlan{T,N} end function FFTAPlan_cx{T,N}( cg::NTuple{N,CallGraph{T}}, r::R, - dir::Direction, pinv::FFTAInvPlan{T,N} + dir::Direction ) where {T,N,R<:RegionTypes{N}} - FFTAPlan_cx{T,N,R}(cg, r, dir, pinv) + FFTAPlan_cx{T,N,R}(cg, r, dir, FFTAInvPlan{T,N}()) end struct FFTAPlan_re{T,N,R<:RegionTypes{N}} <: FFTAPlan{T,N} callgraph::NTuple{N,CallGraph{T}} region::R dir::Direction - pinv::FFTAInvPlan{T,N} flen::Int + pinv::FFTAInvPlan{T,N} end function FFTAPlan_re{T,N}( cg::NTuple{N,CallGraph{T}}, r::R, - dir::Direction, pinv::FFTAInvPlan{T,N}, flen::Int + dir::Direction, flen::Int ) where {T,N,R<:RegionTypes{N}} - FFTAPlan_re{T,N,R}(cg, r, dir, pinv, flen) + FFTAPlan_re{T,N,R}(cg, r, dir, flen, FFTAInvPlan{T,N}()) end function Base.size(p::FFTAPlan{<:Any,N}, i::Int) where N @@ -84,19 +84,17 @@ function _plan_fft( if M == 1 R1 = Int(region[1]) g = CallGraph{T}(size(x, R1), BLUESTEIN_CUTOFF) - pinv = FFTAInvPlan{T,1}() - return FFTAPlan_cx{T,1,Int}((g,), R1, dir, pinv) + return FFTAPlan_cx{T,1}((g,), R1, dir) elseif M == 2 R2 = _sort(region) g1 = CallGraph{T}(size(x, R2[1]), BLUESTEIN_CUTOFF) g2 = CallGraph{T}(size(x, R2[2]), BLUESTEIN_CUTOFF) - pinv = FFTAInvPlan{T,2}() - return FFTAPlan_cx{T,2}((g1, g2), R2, dir, pinv) + return FFTAPlan_cx{T,2}((g1, g2), R2, dir) else RM = _sort(region) return FFTAPlan_cx{T,M}( ntuple(i -> CallGraph{T}(size(x, RM[i]), BLUESTEIN_CUTOFF), Val(M)), - RM, dir, FFTAInvPlan{T,M}() + RM, dir ) end end @@ -115,14 +113,12 @@ function AbstractFFTs.plan_rfft( # problems, we just solve the problem as a single complex nn = iseven(n) ? n >> 1 : n g = CallGraph{Complex{T}}(nn, BLUESTEIN_CUTOFF) - pinv = FFTAInvPlan{Complex{T},1}() - return FFTAPlan_re{Complex{T},1,Int}((g,), R1, FFT_FORWARD, pinv, n) + return FFTAPlan_re{Complex{T},1}((g,), R1, FFT_FORWARD, n) elseif M == 2 R2 = _sort(region) g1 = CallGraph{Complex{T}}(size(x, R2[1]), BLUESTEIN_CUTOFF) g2 = CallGraph{Complex{T}}(size(x, R2[2]), BLUESTEIN_CUTOFF) - pinv = FFTAInvPlan{Complex{T},2}() - return FFTAPlan_re{Complex{T},2}((g1, g2), R2, FFT_FORWARD, pinv, size(x, R2[1])) + return FFTAPlan_re{Complex{T},2}((g1, g2), R2, FFT_FORWARD, size(x, R2[1])) else throw(ArgumentError("only supports 1D and 2D FFTs")) end @@ -142,14 +138,12 @@ function AbstractFFTs.plan_brfft( R1 = Int(region[1]) nn = iseven(len) ? len >> 1 : len g = CallGraph{T}(nn, BLUESTEIN_CUTOFF) - pinv = FFTAInvPlan{T,1}() - return FFTAPlan_re{T,1,Int}((g,), R1, FFT_BACKWARD, pinv, len) + return FFTAPlan_re{T,1}((g,), R1, FFT_BACKWARD, len) elseif M == 2 R2 = _sort(region) g1 = CallGraph{T}(len, BLUESTEIN_CUTOFF) g2 = CallGraph{T}(size(x, R2[2]), BLUESTEIN_CUTOFF) - pinv = FFTAInvPlan{T,2}() - return FFTAPlan_re{T,2}((g1, g2), R2, FFT_BACKWARD, pinv, len) + return FFTAPlan_re{T,2}((g1, g2), R2, FFT_BACKWARD, len) else throw(ArgumentError("only supports 1D and 2D FFTs")) end @@ -400,16 +394,17 @@ function Base.:*(p::FFTAPlan_re{Complex{T},1}, x::AbstractVector{T}) where {T<:R Base.require_one_based_indexing(x) n = p.flen + p_c = complex(p) if iseven(n) # For problems of even size, we solve the rfft problem by splitting the - # problem into the even and odd part and solving the simultanously as + # problem into the even and odd part and solving them simultaneously as # a single (complex) fft of half the size, see equations (6)-(8) of # Sorensen, H. V., D. Jones, Michael Heideman, and C. Burrus. # "Real-valued fast Fourier transform algorithms." # IEEE Transactions on acoustics, speech, and signal processing 35, no. 6 (2003): 849-863. if x isa Vector && isbitstype(T) - # For a vector of bits, we can just reintepret the bits to get the - # approciate representation of even (zero based) elements as the real + # For a vector of bits, we can just reinterpret the bits to get the + # appropriate representation of even (zero based) elements as the real # part and the odd as the complex part x_c = reinterpret(Complex{T}, x) else @@ -421,11 +416,12 @@ function Base.:*(p::FFTAPlan_re{Complex{T},1}, x::AbstractVector{T}) where {T<:R # Allocate complex result vector of half the input size plus one y = similar(x_c, m + 1) # Solve the complex fft of half the size - LinearAlgebra.mul!(view(y, 1:m), complex(p), x_c) + LinearAlgebra.mul!(view(y, 1:m), p_c, x_c) # The w stored in the plan is for m, not n, so probably cheapest to # just recompute it instead of taking a square root - wj = w = cispi(-T(2) / n) + z1 = singleton_params(-one(T) / n) + wj = cispi(-T(2) / n) # Construct the result by first constructing the elements of the # real and imaginary part, followed by the usual radix-2 assembly, @@ -441,17 +437,18 @@ function Base.:*(p::FFTAPlan_re{Complex{T},1}, x::AbstractVector{T}) where {T<:R XY = T(0.5) * (-yj + conj(ymj)) * im y[j] = XX + wj * XY y[m-j+2] = conj(XX - wj * XY) - wj *= w + wj = singleton_step(wj, z1) end return y else # when the problem cannot be split in two equal size chunks we # convert the problem to a complex fft and truncate the redundant # part of the result vector - x_c = similar(x, Complex{T}) - y = similar(x_c) - copyto!(x_c, x) - LinearAlgebra.mul!(y, complex(p), x_c) + if size(p_c) != size(x) + throw(DimensionMismatch("plan and input array axes do not match")) + end + y = similar(x, Complex{T}) + fft!(y, x, 1, 1, p_c.dir, p_c.callgraph[1][1].type, p_c.callgraph[1], 1) return y[1:end÷2+1] end end @@ -464,10 +461,15 @@ function Base.:*(p::FFTAPlan_re{T,1}, x::AbstractVector{T}) where {T<:Complex} Base.require_one_based_indexing(x) n = p.flen + p_c = complex(p) # See explanation of this approach in the method for the FORWARD transform if iseven(n) m = n >> 1 - wj = w = cispi(T(2) / n) + + R = real(T) + z1 = singleton_params(one(R) / n) + wj = cispi(R(2) / n) + x_tmp = similar(x, length(x) - 1) x_tmp[1] = complex( (real(x[1]) + real(x[end])), @@ -478,20 +480,25 @@ function Base.:*(p::FFTAPlan_re{T,1}, x::AbstractVector{T}) where {T<:Complex} XY = wj * (x[j] - conj(x[m-j+2])) x_tmp[j] = XX + im * XY x_tmp[m-j+2] = conj(XX - im * XY) - wj *= w + wj = singleton_step(wj, z1) end - y_c = complex(p) * x_tmp + + y_c = p_c * x_tmp if isbitstype(T) - return copy(reinterpret(real(T), y_c)) + return copy(reinterpret(R, y_c)) else - return mapreduce(t -> [real(t); imag(t)], vcat, y_c) + y_re = similar(y_c, R, 2 * length(y_c)) + for i in eachindex(y_c) + y_re[2i-1], y_re[2i] = reim(y_c[i]) + end + return y_re end else x_tmp = similar(x, n) x_tmp[1:end÷2+1] .= x - x_tmp[end÷2+2:end] .= iseven(n) ? conj.(x[end-1:-1:2]) : conj.(x[end:-1:2]) + x_tmp[end÷2+2:end] .= @views conj.(x[end-iseven(n):-1:2]) y = similar(x_tmp) - LinearAlgebra.mul!(y, complex(p), x_tmp) + LinearAlgebra.mul!(y, p_c, x_tmp) return real(y) end end diff --git a/src/singleton_twiddle.jl b/src/singleton_twiddle.jl index 416ac63..f0f5783 100644 --- a/src/singleton_twiddle.jl +++ b/src/singleton_twiddle.jl @@ -18,31 +18,22 @@ # order of magnitude faster than a fresh `cispi`), and the extra trig # (`sincospi(θ/2)`) happens once per kernel call. -# Direction lives in `sign(imag(w))`; when `w` is real (N = 2 or any -# degenerate case where `imag` rounds to zero) both directions collapse -# to the same twiddle set so we pick +1. -@inline function twiddle_direction(w::Complex{T}) where {T<:Real} - s = imag(w) - s > 0 ? one(T) : (s < 0 ? -one(T) : one(T)) -end - # Recurrence coefficients for stepping by `cispi(freq) = e^(iπ·freq)`. -# Uses `sincospi(freq/2)` so that `α` and `β` are exact-to-ULP even +# Uses `sincospi(hfreq)` so that `α` and `β` are exact-to-ULP even # for very small frequencies — writing `1 - cos(θ)` directly suffers # catastrophic cancellation there. -@inline function singleton_params(freq::T) where {T<:Real} - s_h, c_h = sincospi(freq / 2) +@inline function singleton_params(hfreq::Real) + s_h, c_h = sincospi(-hfreq) α = 2 * s_h * s_h β = 2 * s_h * c_h - (α, β) + Complex(α, β) end # Advance `(c, s) = (cos(kθ), sin(kθ))` to `(cos((k+1)θ), sin((k+1)θ))`. # Computed as `c - (αc + βs)` rather than `(1-α)c - βs` on purpose: # the correction is small so subtracting it from `c` preserves the # high-order bits and the recurrence self-heals. -@inline function singleton_step(c::T, s::T, α::T, β::T) where {T<:Real} - c_new = c - muladd(α, c, β * s) - s_new = s - muladd(α, s, -(β * c)) - (c_new, s_new) +@inline function singleton_step(w::T, z::T) where {T<:Complex} + # muladd only reduces instructions, doesn't help precision much + w - @fastmath(z * w) end diff --git a/test/argument_checking.jl b/test/argument_checking.jl index d8e582e..545dbd7 100644 --- a/test/argument_checking.jl +++ b/test/argument_checking.jl @@ -72,8 +72,8 @@ end yc3 = randn(ComplexF64, 5, 5, 5) pxc3 = plan_fft(xc3) @test_throws DimensionMismatch pxc3 * yc3 - invalid_p = plan_fft(randn(ComplexF64, ntuple(i -> 3, 5)), 3:5) - xc4 = randn(ComplexF64, (1, ntuple(i -> 5, 3)...)) + invalid_p = plan_fft(randn(ComplexF64, ntuple(_ -> 3, 5)), 3:5) + xc4 = randn(ComplexF64, (1, ntuple(_ -> 5, 3)...)) ### plan region out of bounds @@ -93,7 +93,7 @@ end end @testset "$(N)D array" for N in 2:4 - xN = randn(ComplexF64, ntuple(i -> 3, N)) + xN = randn(ComplexF64, ntuple(_ -> 3, N)) yN = similar(xN, size(xN) .+ 1) @testset "1D plan, region=$(region)" for region in 1:N @@ -130,7 +130,7 @@ end @testset "Invalid / mutated dims" verbose=true begin @testset "Extra elements" begin for n in 3:5 - x = rand(ComplexF64, ntuple(i -> 2, n)) + x = rand(ComplexF64, ntuple(_ -> 2, n)) p1 = plan_fft(x, [1:n-1;]) push!(p1.region, n) @test_throws DimensionMismatch("Region is invalid.") p1 * x @@ -138,7 +138,7 @@ end end @testset "Unsorted dims" begin for n in 3:5 - x = rand(ComplexF64, ntuple(i -> 2, n)) + x = rand(ComplexF64, ntuple(_ -> 2, n)) p2 = plan_fft(x, [1:n-1;]) p2.region[1:2] = [2, 1] @test_throws DimensionMismatch("Region is invalid.") p2 * x diff --git a/test/custom_element_types.jl b/test/custom_element_types.jl index 735b4ba..0a519e6 100644 --- a/test/custom_element_types.jl +++ b/test/custom_element_types.jl @@ -1,8 +1,7 @@ using Test, FFTA -x = randn(2*3*4*5) - -@testset "element type: $T" for T in (Float16, BigFloat) +@testset "element type: $T" for T in (Float16, BigFloat) + x = randn(2*3*4*5) Tx = T.(x) @testset "AbstractFFTs believes that single and double precision is everything." begin diff --git a/test/ndim/minimal_complex.jl b/test/ndim/minimal_complex.jl index 21e4c74..7735fe7 100644 --- a/test/ndim/minimal_complex.jl +++ b/test/ndim/minimal_complex.jl @@ -1,7 +1,7 @@ using FFTA, Test @testset "Basic ND checks" begin - for sz in ((3, 5, 7), (4, 14, 9), (103, 5, 13), (26, 33, 35, 4), ntuple(i -> 3, 5)) + for sz in ((3, 5, 7), (4, 14, 9), (103, 5, 13), (26, 33, 35, 4), ntuple(_ -> 3, 5)) x = ones(ComplexF64, sz) @test fft(x, Tuple(1:ndims(x))) ≈ setindex!(zeros(sz), prod(sz), 1) end diff --git a/test/onedim/accuracy.jl b/test/onedim/accuracy.jl index 7a15165..b4f357f 100644 --- a/test/onedim/accuracy.jl +++ b/test/onedim/accuracy.jl @@ -8,8 +8,6 @@ using FFTA, Test, Random, LinearAlgebra # still fail comfortably against the pre-fix naive `w *= step` # recurrence, which ballooned past ~4000 ULP at N = 16384. -Random.seed!(42) - # (N, max eps ratio) across the power-of-2 ladder. Covers both even # powers (= powers of 4, recursion bottoms at N = 4) and odd powers # (recursion bottoms at N = 2), which hit different base cases in @@ -47,12 +45,12 @@ const POWERS_OF_3 = ( function _worst_relerr(N::Int) worst = 0.0 for seed in 1:5 - Random.seed!(seed) - x64 = randn(ComplexF64, N) - x32 = ComplexF32.(x64) + rng = @isdefined(Xoshiro) ? Xoshiro(seed) : MersenneTwister(seed) + x32 = randn(rng, ComplexF32, N) + x64 = ComplexF64.(x32) y32 = fft(x32) y_ref = ComplexF32.(fft(x64)) - relerr = norm(y32 .- y_ref) / norm(y_ref) + relerr = norm(y32 - y_ref) / norm(y_ref) worst = max(worst, relerr / eps(Float32)) end return worst diff --git a/test/onedim/complex_backward.jl b/test/onedim/complex_backward.jl index a9380c1..f4c5785 100644 --- a/test/onedim/complex_backward.jl +++ b/test/onedim/complex_backward.jl @@ -33,5 +33,5 @@ end end @testset "error messages" begin - @test_throws DimensionMismatch bfft(complex.(zeros(0))) + @test_throws DimensionMismatch bfft(ComplexF64[]) end diff --git a/test/onedim/real_forward.jl b/test/onedim/real_forward.jl index 3c0d8c1..00b251b 100644 --- a/test/onedim/real_forward.jl +++ b/test/onedim/real_forward.jl @@ -19,9 +19,9 @@ end end @testset "temporarily test real dft separately until used by rfft" begin - y_dft = similar(y) - FFTA.fft_dft!(y_dft, x, n, 1, 1, 1, 1, cispi(-2/n)) - @test y ≈ y_dft + y_dft = similar(y, n) + FFTA.fft_dft!(y_dft, x, n, 1, 1, 1, 1, FFTA.FFT_FORWARD) + @test y ≈ y_dft[1:length(y)] end @testset "allocation regression" begin diff --git a/test/twodim/complex_backward.jl b/test/twodim/complex_backward.jl index 268207c..85f2f6d 100644 --- a/test/twodim/complex_backward.jl +++ b/test/twodim/complex_backward.jl @@ -3,7 +3,7 @@ using FFTA, Test @testset "backward. N=$N" for N in [8, 11, 15, 16, 27, 100] x = ones(ComplexF64, N, N) y = bfft(x) - y_ref = 0*y + y_ref = zero(y) y_ref[1] = length(x) @test y ≈ y_ref end diff --git a/test/twodim/complex_forward.jl b/test/twodim/complex_forward.jl index 6798d13..8e850d8 100644 --- a/test/twodim/complex_forward.jl +++ b/test/twodim/complex_forward.jl @@ -3,7 +3,7 @@ using FFTA, Test @testset " forward. N=$N" for N in [8, 11, 15, 16, 27, 100] x = ones(ComplexF64, N, N) y = fft(x) - y_ref = 0*y + y_ref = zero(y) y_ref[1] = length(x) @test y ≈ y_ref x = randn(N,N) @@ -14,7 +14,7 @@ end @testset "2D plan, 2D array. Size: $n" for n in 1:64 @testset "size: ($m, $n)" for m in n:(n + 1) - X = complex.(randn(m, n), randn(m, n)) + X = randn(ComplexF64, (m, n)) @testset "against naive implementation" begin @test naive_2d_fourier_transform(X, FFTA.FFT_FORWARD) ≈ fft(X) diff --git a/test/twodim/real_backward.jl b/test/twodim/real_backward.jl index 8913edf..7f87faa 100644 --- a/test/twodim/real_backward.jl +++ b/test/twodim/real_backward.jl @@ -3,7 +3,7 @@ using FFTA, Test @testset "backward. N=$N" for N in [8, 11, 15, 16, 27, 100] x = ones(Complex{Float64}, N, N) y = brfft(x, 2(N-1)) - y_ref = 0*y + y_ref = zero(y) y_ref[1] = N*(2(N-1)) @test y_ref ≈ y atol=1e-10 end diff --git a/test/twodim/real_forward.jl b/test/twodim/real_forward.jl index aded588..35d71c9 100644 --- a/test/twodim/real_forward.jl +++ b/test/twodim/real_forward.jl @@ -3,7 +3,7 @@ using FFTA, Test @testset " forward. N=$N" for N in [8, 11, 15, 16, 27, 100] x = ones(Float64, N, N) y = rfft(x) - y_ref = 0*y + y_ref = zero(y) y_ref[1] = length(x) @test y ≈ y_ref x = randn(N,N)