diff options
Diffstat (limited to 'generate_twiddle_tables.py')
| -rw-r--r-- | generate_twiddle_tables.py | 58 |
1 files changed, 50 insertions, 8 deletions
diff --git a/generate_twiddle_tables.py b/generate_twiddle_tables.py index 69f9da3..152208c 100644 --- a/generate_twiddle_tables.py +++ b/generate_twiddle_tables.py @@ -8,18 +8,19 @@ Uses static branching with preprocessor directives. import math import os -def twiddle(k, N): - """Compute twiddle factor W_N^k = exp(-2*pi*i*k/N)""" - angle = -2.0 * math.pi * k / N +def twiddle(k, N, inverse=False): + """Compute twiddle factor W_N^k = exp(-2*pi*i*k/N) for forward FFT + or exp(+2*pi*i*k/N) for inverse FFT""" + angle = (2.0 if inverse else -2.0) * math.pi * k / N return complex(math.cos(angle), math.sin(angle)) -def generate_dft_matrix(radix): +def generate_dft_matrix(radix, inverse=False): """Generate DFT matrix for given radix""" matrix = [] for k in range(radix): row = [] for n in range(radix): - row.append(twiddle(k * n, radix)) + row.append(twiddle(k * n, radix, inverse)) matrix.append(row) return matrix @@ -35,7 +36,7 @@ def generate_cginc(radices, max_size, output_file): f.write("#ifndef FFT_TWIDDLE_TABLES_CGINC\n") f.write("#define FFT_TWIDDLE_TABLES_CGINC\n\n") - # Generate DFT matrices for each radix + # Generate DFT matrices for each radix (forward) for radix in radices: f.write(f"#if defined(GPU_FFT_RADIX{radix})\n") f.write(f"static const float2 DFT_MATRIX[{radix}][{radix}] = {{\n") @@ -52,10 +53,25 @@ def generate_cginc(radices, max_size, output_file): f.write(",") f.write("\n") f.write("};\n") + + # Generate inverse DFT matrix + f.write(f"static const float2 IDFT_MATRIX[{radix}][{radix}] = {{\n") + matrix_inv = generate_dft_matrix(radix, inverse=True) + for k in range(radix): + f.write(" { ") + for n in range(radix): + f.write(format_complex(matrix_inv[k][n])) + if n < radix - 1: + f.write(", ") + f.write(" }") + if k < radix - 1: + f.write(",") + f.write("\n") + f.write("};\n") f.write("#endif\n\n") - # Generate stage twiddle factor tables - f.write("// Stage twiddle factors\n") + # Generate stage twiddle factor tables (forward) + f.write("// Stage twiddle factors (forward FFT)\n") for radix in radices: butterfly_size = radix stage_idx = 0 @@ -80,6 +96,32 @@ def generate_cginc(radices, max_size, output_file): butterfly_size *= radix stage_idx += 1 + # Generate stage twiddle factor tables (inverse) + f.write("// Stage twiddle factors (inverse FFT)\n") + for radix in radices: + butterfly_size = radix + stage_idx = 0 + + while butterfly_size <= max_size: + f.write(f"#if defined(GPU_FFT_RADIX{radix}_SIZE{butterfly_size})\n") + f.write(f"static const float2 STAGE_TWIDDLES_INV[{butterfly_size}] = {{\n") + + # Write twiddles in rows of 4 for readability + for i in range(0, butterfly_size, 4): + f.write(" ") + for j in range(4): + if i + j < butterfly_size: + c = twiddle(i + j, butterfly_size, inverse=True) + f.write(format_complex(c)) + if i + j < butterfly_size - 1: + f.write(", ") + f.write("\n") + + f.write("};\n") + f.write("#endif\n\n") + butterfly_size *= radix + stage_idx += 1 + f.write("#endif // FFT_TWIDDLE_TABLES_CGINC\n\n") def main(): |
