Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension


Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
7 changes: 7 additions & 0 deletions .github/workflows/ci.yml
Original file line number Diff line number Diff line change
Expand Up @@ -22,6 +22,13 @@ jobs:
- ubuntu-latest
arch:
- x64
include:
- os: macos-latest
arch: arm64
version: '1'
- os: windows-latest
arch: x64
version: '1'
steps:
- uses: actions/checkout@v6
- uses: julia-actions/setup-julia@v3
Expand Down
202 changes: 111 additions & 91 deletions src/algos.jl
Original file line number Diff line number Diff line change
Expand Up @@ -2,32 +2,29 @@
Int(d)
end

@inline _conj(w::Complex, d::Direction) = ifelse(direction_sign(d) === 1, w, conj(w))

function fft!(
out::AbstractVector{T}, in::AbstractVector{T},
out::AbstractVector{T}, in::AbstractVector{<:Number},
start_out::Int, start_in::Int,
d::Direction,
t::FFTEnum,
g::CallGraph{T},
idx::Int
) where T
) where T
if t === COMPOSITE_FFT
fft_composite!(out, in, start_out, start_in, d, g, idx)
else
root = g[idx]
s_in = root.s_in
s_out = root.s_out
N = root.sz
w = _conj(root.w, d)
if t === DFT
fft_dft!(out, in, N, start_out, s_out, start_in, s_in, w)
fft_dft!(out, in, N, start_out, s_out, start_in, s_in, d)
elseif t === POW2RADIX4_FFT
fft_pow2_radix4!(out, in, N, start_out, s_out, start_in, s_in, w)
fft_pow2_radix4!(out, in, N, start_out, s_out, start_in, s_in, d)
elseif t === POW3_FFT
_m_120 = cispi(T(2) / 3)
m_120 = d === FFT_FORWARD ? _m_120 : conj(_m_120)
fft_pow3!(out, in, N, start_out, s_out, start_in, s_in, w, m_120)
fft_pow3!(out, in, N, start_out, s_out, start_in, s_in, m_120, d)
elseif t === BLUESTEIN
fft_bluestein!(out, in, d, N, start_out, s_out, start_in, s_in)
else
Expand All @@ -51,7 +48,13 @@ Cooley-Tukey composite FFT, with a pre-computed call graph
- `idx`: Index of the current transform in the call graph

"""
function fft_composite!(out::AbstractVector{T}, in::AbstractVector{U}, start_out::Int, start_in::Int, d::Direction, g::CallGraph{T}, idx::Int) where {T,U}
function fft_composite!(
out::AbstractVector{T}, in::AbstractVector{U},
start_out::Int, start_in::Int,
d::Direction,
g::CallGraph{T},
idx::Int
) where {T,U}
root = g[idx]
left_idx = idx + root.left
right_idx = idx + root.right
Expand All @@ -66,12 +69,8 @@ function fft_composite!(out::AbstractVector{T}, in::AbstractVector{U}, start_out
Rt = right.type
Lt = left.type

w1 = _conj(root.w, d)
Rtype = real(T)
# The composite twiddle at position (j1, k2) is `cispi(dir · 2 j1 k2 / N)`.
# Singleton's recurrence advances `wk2 = cispi(dir · 2 j1 k2 / N)` in k2
# for fixed j1; (α, β) depend on j1 so we reset them at each outer step.
dir = twiddle_direction(w1)
dir = direction_sign(d)
tmp = g.workspace[idx]

if Rt === BLUESTEIN
Expand All @@ -90,11 +89,14 @@ function fft_composite!(out::AbstractVector{T}, in::AbstractVector{U}, start_out
end

if j1 > 0
αi, βi = singleton_params(dir * Rtype(2 * j1) / Rtype(N))
ci, si = one(Rtype), zero(Rtype)
# The composite twiddle at position (j1, k2) is `cispi(dir · 2 j1 k2 / N)`.
# Singleton's recurrence advances `wk2 = cispi(dir · 2 j1 k2 / N)` in k2
# for fixed j1; (α, β) depend on j1 so we reset them at each outer step.
zj1 = singleton_params(dir * Rtype(j1) / Rtype(N))
wk2 = one(T)
@inbounds for k2 in 1:N2-1
ci, si = singleton_step(ci, si, αi, βi)
tmp[R_start_out + k2] *= Complex(ci, si)
wk2 = singleton_step(wk2, zj1)
tmp[R_start_out + k2] *= wk2
end
end
end
Expand Down Expand Up @@ -127,31 +129,43 @@ Discrete Fourier Transform, O(N^2) algorithm, in place.
- `stride_out`: Stride of the output vector
- `start_in`: Index of the first element of the input vector
- `stride_in`: Stride of the input vector
- `w`: The value `cispi(direction_sign(d) * 2 / N)`
- `d`: Direction of the transform

"""
function fft_dft!(out::AbstractVector{T}, in::AbstractVector{T}, N::Int, start_out::Int, stride_out::Int, start_in::Int, stride_in::Int, w::T) where {T}
function fft_dft!(
out::AbstractVector{T}, in::AbstractVector{T},
N::Int,
start_out::Int, stride_out::Int,
start_in::Int, stride_in::Int,
d::Direction
) where {T<:Complex}
tmp = in[start_in]
@inbounds for j in 1:N-1
tmp += in[start_in + j*stride_in]
end
out[start_out] = tmp

Rtype = real(T)
dir = twiddle_direction(w)
@inbounds for d in 1:N-1
t = in[start_in]
αk, βk = singleton_params(dir * Rtype(2 * d) / Rtype(N))
ck, sk = one(Rtype), zero(Rtype)
dir = direction_sign(d)
@inbounds for j in 1:N-1
tmp = in[start_in]
zj = singleton_params(dir * Rtype(j) / Rtype(N))
wk = one(T)
@inbounds for k in 1:N-1
ck, sk = singleton_step(ck, sk, αk, βk)
t += Complex(ck, sk) * in[start_in + k*stride_in]
wk = singleton_step(wk, zj)
tmp += wk * in[start_in + k*stride_in]
end
out[start_out + d*stride_out] = t
out[start_out + j*stride_out] = tmp
end
end

function fft_dft!(out::AbstractVector{Complex{T}}, in::AbstractVector{T}, N::Int, start_out::Int, stride_out::Int, start_in::Int, stride_in::Int, w::Complex{T}) where {T<:Real}
function fft_dft!(
out::AbstractVector{Complex{T}}, in::AbstractVector{T},
N::Int,
start_out::Int, stride_out::Int,
start_in::Int, stride_in::Int,
d::Direction
) where {T<:Real}
halfN = N÷2

tmp = Complex{T}(in[start_in])
Expand All @@ -160,16 +174,17 @@ function fft_dft!(out::AbstractVector{Complex{T}}, in::AbstractVector{T}, N::Int
end
out[start_out] = tmp

dir = twiddle_direction(w)
@inbounds for d in 1:halfN
t = Complex{T}(in[start_in])
αk, βk = singleton_params(dir * T(2 * d) / T(N))
ck, sk = one(T), zero(T)
dir = direction_sign(d)
@inbounds for j in 1:halfN
tmp = Complex{T}(in[start_in])
zj = singleton_params(dir * T(j) / T(N))
wk = one(Complex{T})
@inbounds for k in 1:N-1
ck, sk = singleton_step(ck, sk, αk, βk)
t += Complex{T}(ck, sk) * in[start_in + k*stride_in]
wk = singleton_step(wk, zj)
tmp += wk * in[start_in + k*stride_in]
end
out[start_out + d*stride_out] = t
out[start_out + j*stride_out] = tmp
out[start_out + (N-j)*stride_out] = conj(tmp)
end
end

Expand All @@ -186,19 +201,27 @@ Radix-4 FFT for powers of 2, in place
- `stride_out`: Stride of the output vector
- `start_in`: Index of the first element of the input vector
- `stride_in`: Stride of the input vector
- `w`: The value `cispi(direction_sign(d) * 2 / N)`
- `d`: Direction of the transform

"""
function fft_pow2_radix4!(out::AbstractVector{T}, in::AbstractVector{U}, N::Int, start_out::Int, stride_out::Int, start_in::Int, stride_in::Int, w::T) where {T, U}
function fft_pow2_radix4!(
out::AbstractVector{T}, in::AbstractVector{U},
N::Int,
start_out::Int, stride_out::Int,
start_in::Int, stride_in::Int,
d::Direction
) where {T<:Complex, U}
# If N is 2, compute the size two DFT
@inbounds if N == 2
out[start_out] = in[start_in] + in[start_in + stride_in]
out[start_out + stride_out] = in[start_in] - in[start_in + stride_in]
return
end

dir = direction_sign(d)

# If N is 4, compute an unrolled radix-2 FFT and return
minusi = -sign(imag(w)) * im
minusi = -dir * im
@inbounds if N == 4
xee = in[start_in]
xoe = in[start_in + stride_in]
Expand All @@ -218,34 +241,30 @@ function fft_pow2_radix4!(out::AbstractVector{T}, in::AbstractVector{U}, N::Int,
# ...othersize split the problem in four and recur
m = N ÷ 4

Rtype = real(T)
dir = twiddle_direction(w)
# Recursive sub-problem step `cispi(dir · 2 / m) = w^4`; use `cispi`
# directly so the sub-tree gets a < 1 ULP starting phase.
w_sub = cispi(dir * Rtype(2) / Rtype(m))

fft_pow2_radix4!(out, in, m, start_out , stride_out, start_in , stride_in*4, w_sub)
fft_pow2_radix4!(out, in, m, start_out + m*stride_out, stride_out, start_in + stride_in, stride_in*4, w_sub)
fft_pow2_radix4!(out, in, m, start_out + 2*m*stride_out, stride_out, start_in + 2*stride_in, stride_in*4, w_sub)
fft_pow2_radix4!(out, in, m, start_out + 3*m*stride_out, stride_out, start_in + 3*stride_in, stride_in*4, w_sub)
fft_pow2_radix4!(out, in, m, start_out , stride_out, start_in , stride_in*4, d)
fft_pow2_radix4!(out, in, m, start_out + m*stride_out, stride_out, start_in + stride_in, stride_in*4, d)
fft_pow2_radix4!(out, in, m, start_out + 2*m*stride_out, stride_out, start_in + 2*stride_in, stride_in*4, d)
fft_pow2_radix4!(out, in, m, start_out + 3*m*stride_out, stride_out, start_in + 3*stride_in, stride_in*4, d)

Rtype = real(T)
# Singleton recurrence for the three running twiddles `w^k`, `w^2k`, `w^3k`.
α1, β1 = singleton_params(dir * Rtype(2) / Rtype(N))
α2, β2 = singleton_params(dir * Rtype(4) / Rtype(N))
α3, β3 = singleton_params(dir * Rtype(6) / Rtype(N))
c1, s1 = one(Rtype), zero(Rtype)
c2, s2 = one(Rtype), zero(Rtype)
c3, s3 = one(Rtype), zero(Rtype)
z1 = singleton_params(dir * Rtype(1) / Rtype(N))
z2 = singleton_params(dir * Rtype(2) / Rtype(N))
z3 = singleton_params(dir * Rtype(3) / Rtype(N))

wkoe = one(T)
wkeo = one(T)
wkoo = one(T)

@inbounds for k in 0:m-1
kee = start_out + k * stride_out
koe = start_out + (k + m) * stride_out
keo = start_out + (k + 2 * m) * stride_out
koo = start_out + (k + 3 * m) * stride_out
y_kee, y_koe, y_keo, y_koo = out[kee], out[koe], out[keo], out[koo]
t_keo = y_keo * Complex(c2, s2)
t_koe = y_koe * Complex(c1, s1)
t_koo = y_koo * Complex(c3, s3)
t_koe = y_koe * wkoe
t_keo = y_keo * wkeo
t_koo = y_koo * wkoo
y_kee_p_y_keo = y_kee + t_keo
y_kee_m_y_keo = y_kee - t_keo
t_koe_p_t_koo = t_koe + t_koo
Expand All @@ -254,9 +273,9 @@ function fft_pow2_radix4!(out::AbstractVector{T}, in::AbstractVector{U}, N::Int,
out[koe] = y_kee_m_y_keo + t_koe_m_t_koo
out[keo] = y_kee_p_y_keo - t_koe_p_t_koo
out[koo] = y_kee_m_y_keo - t_koe_m_t_koo
c1, s1 = singleton_step(c1, s1, α1, β1)
c2, s2 = singleton_step(c2, s2, α2, β2)
c3, s3 = singleton_step(c3, s3, α3, β3)
wkoe = singleton_step(wkoe, z1)
wkeo = singleton_step(wkeo, z2)
wkoo = singleton_step(wkoo, z3)
end
end

Expand All @@ -273,12 +292,18 @@ Power of 3 FFT, in place
- `stride_out`: Stride of the output vector
- `start_in`: Index of the first element of the input vector
- `stride_in`: Stride of the input vector
- `w`: The value `cispi(direction_sign(d) * 2 / N)`
- `plus120`: Depending on direction, perform either ±120° rotation
- `minus120`: Depending on direction, perform either ∓120° rotation
- `d`: Direction of the transform

"""
function fft_pow3!(out::AbstractVector{T}, in::AbstractVector{U}, N::Int, start_out::Int, stride_out::Int, start_in::Int, stride_in::Int, w::T, minus120::T) where {T, U}
function fft_pow3!(
out::AbstractVector{T}, in::AbstractVector{U},
N::Int,
start_out::Int, stride_out::Int,
start_in::Int, stride_in::Int,
minus120::T,
d::Direction
) where {T, U}
plus120 = conj(minus120)
if N == 3
@muladd out[start_out + 0] = in[start_in] + in[start_in + stride_in] + in[start_in + 2*stride_in]
Expand All @@ -290,32 +315,28 @@ function fft_pow3!(out::AbstractVector{T}, in::AbstractVector{U}, N::Int, start_
# Size of subproblem
Nprime = N ÷ 3

# Dividing into subproblems
fft_pow3!(out, in, Nprime, start_out, stride_out, start_in, stride_in*3, minus120, d)
fft_pow3!(out, in, Nprime, start_out + Nprime*stride_out, stride_out, start_in + stride_in, stride_in*3, minus120, d)
fft_pow3!(out, in, Nprime, start_out + 2*Nprime*stride_out, stride_out, start_in + 2*stride_in, stride_in*3, minus120, d)

Rtype = real(T)
dir = twiddle_direction(w)
# Recursive sub-problem step cispi(dir · 2 / Nprime) = w^3.
w_sub = cispi(dir * Rtype(2) / Rtype(Nprime))
dir = direction_sign(d)

# Dividing into subproblems
fft_pow3!(out, in, Nprime, start_out, stride_out, start_in, stride_in*3, w_sub, minus120)
fft_pow3!(out, in, Nprime, start_out + Nprime*stride_out, stride_out, start_in + stride_in, stride_in*3, w_sub, minus120)
fft_pow3!(out, in, Nprime, start_out + 2*Nprime*stride_out, stride_out, start_in + 2*stride_in, stride_in*3, w_sub, minus120)

α1, β1 = singleton_params(dir * Rtype(2) / Rtype(N))
α2, β2 = singleton_params(dir * Rtype(4) / Rtype(N))
c1, s1 = one(Rtype), zero(Rtype)
c2, s2 = one(Rtype), zero(Rtype)
z1 = singleton_params(dir * Rtype(1) / Rtype(N))
z2 = singleton_params(dir * Rtype(2) / Rtype(N))
wk1 = one(T)
wk2 = one(T)
for k in 0:Nprime-1
k0 = start_out + stride_out * k
k1 = start_out + stride_out * (k + Nprime)
k2 = start_out + stride_out * (k + 2 * Nprime)
y_k0, y_k1, y_k2 = out[k0], out[k1], out[k2]
wk1 = Complex(c1, s1)
wk2 = Complex(c2, s2)
@muladd out[k0] = y_k0 + y_k1*wk1 + y_k2*wk2
@muladd out[k1] = y_k0 + y_k1*wk1*plus120 + y_k2*wk2*minus120
@muladd out[k2] = y_k0 + y_k1*wk1*minus120 + y_k2*wk2*plus120
c1, s1 = singleton_step(c1, s1, α1, β1)
c2, s2 = singleton_step(c2, s2, α2, β2)
@muladd out[k0] = y_k0 + y_k1 * wk1 + y_k2 * wk2
@muladd out[k1] = y_k0 + y_k1 * wk1 * plus120 + y_k2 * wk2 * minus120
@muladd out[k2] = y_k0 + y_k1 * wk1 * minus120 + y_k2 * wk2 * plus120
wk1 = singleton_step(wk1, z1)
wk2 = singleton_step(wk2, z2)
end
end

Expand Down Expand Up @@ -362,17 +383,17 @@ with a power 2 FFT.
- `stride_out`: Stride of the output vector
- `start_in`: Index of the first element of the input vector
- `stride_in`: Stride of the input vector
- `w`: The value `cispi(direction_sign(d) * 2 / N)`
- `scratch` (optional): preallocated scratch space for bluestein

"""
function fft_bluestein!(
out::AbstractVector{T}, in::AbstractVector{T},
out::AbstractVector{T}, in::AbstractVector{<:Number},
d::Direction,
N::Int,
start_out::Int, stride_out::Int,
start_in::Int, stride_in::Int,
scratch::Tuple{Vector{T},Vector{T},Vector{T},Int}=prealloc_blue(N, d, T)
) where T<:Number
) where T<:Complex

(tmp, a_series, b_series, pad_len) = scratch

Expand All @@ -383,14 +404,13 @@ function fft_bluestein!(
a_series[i] = in[start_in+(i-1)*stride_in] * conj(b_series[i])
end

w_pad = cispi(T(2) / pad_len)
# leave b_n vector alone for last step
fft_pow2_radix4!(tmp, a_series, pad_len, 1, 1, 1, 1, w_pad) # Fa
fft_pow2_radix4!(a_series, b_series, pad_len, 1, 1, 1, 1, w_pad) # Fb
fft_pow2_radix4!(tmp, a_series, pad_len, 1, 1, 1, 1, FFT_BACKWARD) # Fa
fft_pow2_radix4!(a_series, b_series, pad_len, 1, 1, 1, 1, FFT_BACKWARD) # Fb

tmp .*= a_series
# convolution theorem ifft
fft_pow2_radix4!(a_series, tmp, pad_len, 1, 1, 1, 1, conj(w_pad))
fft_pow2_radix4!(a_series, tmp, pad_len, 1, 1, 1, 1, FFT_FORWARD)
conv_a_b = a_series

Xk = tmp
Expand Down
Loading
Loading