summaryrefslogtreecommitdiffstats
path: root/generate_twiddle_tables.py
diff options
context:
space:
mode:
Diffstat (limited to 'generate_twiddle_tables.py')
-rw-r--r--generate_twiddle_tables.py58
1 files changed, 50 insertions, 8 deletions
diff --git a/generate_twiddle_tables.py b/generate_twiddle_tables.py
index 69f9da3..152208c 100644
--- a/generate_twiddle_tables.py
+++ b/generate_twiddle_tables.py
@@ -8,18 +8,19 @@ Uses static branching with preprocessor directives.
import math
import os
-def twiddle(k, N):
- """Compute twiddle factor W_N^k = exp(-2*pi*i*k/N)"""
- angle = -2.0 * math.pi * k / N
+def twiddle(k, N, inverse=False):
+ """Compute twiddle factor W_N^k = exp(-2*pi*i*k/N) for forward FFT
+ or exp(+2*pi*i*k/N) for inverse FFT"""
+ angle = (2.0 if inverse else -2.0) * math.pi * k / N
return complex(math.cos(angle), math.sin(angle))
-def generate_dft_matrix(radix):
+def generate_dft_matrix(radix, inverse=False):
"""Generate DFT matrix for given radix"""
matrix = []
for k in range(radix):
row = []
for n in range(radix):
- row.append(twiddle(k * n, radix))
+ row.append(twiddle(k * n, radix, inverse))
matrix.append(row)
return matrix
@@ -35,7 +36,7 @@ def generate_cginc(radices, max_size, output_file):
f.write("#ifndef FFT_TWIDDLE_TABLES_CGINC\n")
f.write("#define FFT_TWIDDLE_TABLES_CGINC\n\n")
- # Generate DFT matrices for each radix
+ # Generate DFT matrices for each radix (forward)
for radix in radices:
f.write(f"#if defined(GPU_FFT_RADIX{radix})\n")
f.write(f"static const float2 DFT_MATRIX[{radix}][{radix}] = {{\n")
@@ -52,10 +53,25 @@ def generate_cginc(radices, max_size, output_file):
f.write(",")
f.write("\n")
f.write("};\n")
+
+ # Generate inverse DFT matrix
+ f.write(f"static const float2 IDFT_MATRIX[{radix}][{radix}] = {{\n")
+ matrix_inv = generate_dft_matrix(radix, inverse=True)
+ for k in range(radix):
+ f.write(" { ")
+ for n in range(radix):
+ f.write(format_complex(matrix_inv[k][n]))
+ if n < radix - 1:
+ f.write(", ")
+ f.write(" }")
+ if k < radix - 1:
+ f.write(",")
+ f.write("\n")
+ f.write("};\n")
f.write("#endif\n\n")
- # Generate stage twiddle factor tables
- f.write("// Stage twiddle factors\n")
+ # Generate stage twiddle factor tables (forward)
+ f.write("// Stage twiddle factors (forward FFT)\n")
for radix in radices:
butterfly_size = radix
stage_idx = 0
@@ -80,6 +96,32 @@ def generate_cginc(radices, max_size, output_file):
butterfly_size *= radix
stage_idx += 1
+ # Generate stage twiddle factor tables (inverse)
+ f.write("// Stage twiddle factors (inverse FFT)\n")
+ for radix in radices:
+ butterfly_size = radix
+ stage_idx = 0
+
+ while butterfly_size <= max_size:
+ f.write(f"#if defined(GPU_FFT_RADIX{radix}_SIZE{butterfly_size})\n")
+ f.write(f"static const float2 STAGE_TWIDDLES_INV[{butterfly_size}] = {{\n")
+
+ # Write twiddles in rows of 4 for readability
+ for i in range(0, butterfly_size, 4):
+ f.write(" ")
+ for j in range(4):
+ if i + j < butterfly_size:
+ c = twiddle(i + j, butterfly_size, inverse=True)
+ f.write(format_complex(c))
+ if i + j < butterfly_size - 1:
+ f.write(", ")
+ f.write("\n")
+
+ f.write("};\n")
+ f.write("#endif\n\n")
+ butterfly_size *= radix
+ stage_idx += 1
+
f.write("#endif // FFT_TWIDDLE_TABLES_CGINC\n\n")
def main():