#!/usr/bin/env python3
"""
Generate .cginc file with hard-coded twiddle factor tables for GPU FFT.
Uses full 32-bit float precision instead of textures.
Uses static branching with preprocessor directives.
"""
import math
import os
def twiddle(k, N, inverse=False):
    """Return the twiddle factor W_N^k as a Python complex number.

    The forward factor is exp(-2*pi*i*k/N); with ``inverse=True`` the
    conjugate exp(+2*pi*i*k/N) is returned instead.

    ``k`` is reduced modulo ``N`` first, so mathematically equal roots of
    unity (e.g. k and k + N) produce identical results, and the trig
    functions are only evaluated for angles within a single turn.
    Multiples of a quarter turn return exact 0.0 / +-1.0 components rather
    than residues such as 6.1e-17 that come from pi not being exactly
    representable.

    Raises:
        ZeroDivisionError: if ``N`` is zero.
    """
    sign = 1.0 if inverse else -1.0
    k = k % N  # W_N^k is periodic in k with period N

    quarter_turns, remainder = divmod(4 * k, N)
    if remainder == 0:
        # Angle is an exact multiple of pi/2: use exact table values.
        quadrant = int(quarter_turns)
        cos_exact = (1.0, 0.0, -1.0, 0.0)[quadrant]
        sin_exact = (0.0, 1.0, 0.0, -1.0)[quadrant]
        # Avoid producing a negative zero for the imaginary part.
        imag = sign * sin_exact if sin_exact else 0.0
        return complex(cos_exact, imag)

    angle = sign * 2.0 * math.pi * k / N
    return complex(math.cos(angle), math.sin(angle))
def generate_dft_matrix(radix, inverse=False):
    """Build the radix x radix DFT matrix as a list of row lists.

    Entry [k][n] is ``twiddle(k * n, radix, inverse)``, so passing
    ``inverse=True`` yields the matrix built from the inverse twiddles.
    """
    return [
        [twiddle(row * col, radix, inverse) for col in range(radix)]
        for row in range(radix)
    ]
def format_complex(c):
    """Render ``c`` as a ``float2(<real>f, <imag>f)`` literal string.

    The real and imaginary parts are inserted using Python's default
    string formatting for the values.
    """
    real_part = c.real
    imag_part = c.imag
    return "float2({}f, {}f)".format(real_part, imag_part)
def get_butterfly_sizes_for_config(n, radix):
    """Return the butterfly size used by each stage of a size-``n`` FFT.

    Stage ``s`` (0-based) has span ``n // radix**(s + 1)`` and butterfly
    size ``span * radix``. There is one stage for every power of ``radix``
    that does not exceed ``n``, so the result is empty when ``n < radix``.

    The stage count is derived with integer arithmetic. Computing it as
    ``int(log(n) / log(radix))`` can truncate one stage away when the
    floating-point quotient lands just below an integer (e.g. n=1000,
    radix=10 evaluates to 2.9999999999999996).

    Raises:
        ValueError: if ``n`` is not positive or ``radix`` is less than 2.
    """
    if n <= 0:
        raise ValueError(f"n must be positive, got {n}")
    if radix < 2:
        # radix 1 would never outgrow n in the loop below.
        raise ValueError(f"radix must be at least 2, got {radix}")

    butterfly_sizes = []
    divisor = radix  # radix ** (stage + 1) for the current stage
    while divisor <= n:
        span = n // divisor
        butterfly_sizes.append(span * radix)
        divisor *= radix
    return butterfly_sizes
def _format_matrix_rows(matrix):
    """Yield one brace-wrapped initializer line per matrix row."""
    last_row = len(matrix) - 1
    for row_idx, row in enumerate(matrix):
        entries = ", ".join(format_complex(value) for value in row)
        # Every row except the last is followed by a comma.
        separator = "," if row_idx < last_row else ""
        yield f" {{ {entries} }}{separator}\n"


def _format_twiddle_rows(size, inverse, per_line=4):
    """Yield initializer lines for the ``size`` stage twiddles, ``per_line`` each."""
    for start in range(0, size, per_line):
        stop = min(start + per_line, size)
        entries = ", ".join(
            format_complex(twiddle(k, size, inverse=inverse))
            for k in range(start, stop)
        )
        # Every line except the last is followed by a comma.
        separator = "," if start + per_line < size else ""
        yield f" {entries}{separator}\n"


def generate_cginc(radices, max_size, output_file):
    """Write the twiddle factor include file to ``output_file``.

    The file is wrapped in an ``FFT_TWIDDLE_TABLES_CGINC`` include guard
    and contains two sections:

    * For each radix, ``DFT_MATRIX`` and ``IDFT_MATRIX`` tables inside an
      ``#if defined(GPU_FFT_RADIX<radix>)`` block.
    * For each radix and each size n = radix, radix**2, ... up to
      ``max_size``, per-stage ``STAGE<i>_TWIDDLES`` and
      ``STAGE<i>_TWIDDLES_INV`` tables inside an
      ``#if defined(GPU_FFT_RADIX<radix>_N<n>)`` block.

    Args:
        radices: radices to generate tables for; iterated twice.
        max_size: largest FFT size for which stage tables are emitted.
        output_file: path of the file to create or overwrite.
    """
    # (table name, inverse flag) for the per-radix DFT matrices.
    matrix_tables = (("DFT_MATRIX", False), ("IDFT_MATRIX", True))
    # (name suffix, inverse flag, closing text) for the per-stage tables;
    # the inverse table is followed by a blank line.
    stage_tables = (("", False, "};\n"), ("_INV", True, "};\n\n"))

    with open(output_file, 'w', encoding='utf-8') as f:
        f.write("// Auto-generated FFT twiddle factor tables\n")
        f.write("#ifndef FFT_TWIDDLE_TABLES_CGINC\n")
        f.write("#define FFT_TWIDDLE_TABLES_CGINC\n\n")

        # Forward and inverse DFT matrices for each radix.
        for radix in radices:
            f.write(f"#if defined(GPU_FFT_RADIX{radix})\n")
            for table_name, inverse in matrix_tables:
                f.write(f"static const float2 {table_name}[{radix}][{radix}] = {{\n")
                matrix = generate_dft_matrix(radix, inverse=inverse)
                f.writelines(_format_matrix_rows(matrix))
                f.write("};\n")
            f.write("#endif\n\n")

        # Stage twiddle tables for each radix/size combination.
        f.write("// Stage twiddle factors for specific radix/size combinations\n")
        for radix in radices:
            n = radix
            while n <= max_size:
                butterfly_sizes = get_butterfly_sizes_for_config(n, radix)
                f.write(f"\n#if defined(GPU_FFT_RADIX{radix}_N{n})\n")
                for stage_idx, butterfly_size in enumerate(butterfly_sizes):
                    f.write(f"// Stage {stage_idx}: butterfly_size = {butterfly_size}\n")
                    for suffix, inverse, closing in stage_tables:
                        f.write(
                            f"static const float2 STAGE{stage_idx}_TWIDDLES{suffix}"
                            f"[{butterfly_size}] = {{\n"
                        )
                        f.writelines(_format_twiddle_rows(butterfly_size, inverse))
                        f.write(closing)
                f.write("#endif\n")
                n *= radix

        f.write("#endif // FFT_TWIDDLE_TABLES_CGINC\n\n")
def main():
    """Generate the table file with the built-in settings and print a summary.

    Uses radices 2, 4, 8 and 16 with a maximum FFT size of 1024, writes
    'fft_twiddle_tables.cginc' via ``generate_cginc``, then lists the
    butterfly sizes of every radix/size configuration.
    """
    radices = [2, 4, 8, 16]
    max_size = 1024
    output_file = 'fft_twiddle_tables.cginc'

    banner = (
        f"Generating twiddle factor tables for radices: {radices}",
        f"Maximum FFT size: {max_size}",
        f"Output file: {output_file}",
    )
    for message in banner:
        print(message)

    generate_cginc(radices, max_size, output_file)

    # Summarize the configurations that were written.
    print("\nGenerated configurations:")
    for radix in radices:
        print(f"\n Radix {radix}:")
        size = radix
        while size <= max_size:
            stage_sizes = get_butterfly_sizes_for_config(size, radix)
            print(f" N={size}: stages use butterfly sizes {stage_sizes}")
            size *= radix
# Run the generator only when this file is executed as a script.
if __name__ == "__main__":
    main()