summaryrefslogtreecommitdiffstats
path: root/generate_twiddle_tables.py
blob: de6c599daf5175e3f9a75b4391860f00a4323e24 (plain)
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
#!/usr/bin/env python3
"""
Generate .cginc file with hard-coded twiddle factor tables for GPU FFT.
Uses full 32-bit float precision instead of textures.
Uses static branching with preprocessor directives.
"""

import math
import os

def twiddle(k, N, inverse=False):
    """Compute the twiddle factor W_N^k as a Python complex number.

    The forward FFT uses exp(-2*pi*i*k/N); the inverse FFT uses
    exp(+2*pi*i*k/N), which is the complex conjugate of the forward factor.

    The index is reduced modulo N first (the factor is periodic in k), and
    exact quarter turns return exact values (1, -i, -1, +i for the forward
    direction) rather than cos/sin round-off such as 6.1e-17.

    Args:
        k: Twiddle index; reduced modulo N before use.
        N: Transform length. Must be non-zero.
        inverse: If True, return the inverse-transform factor.

    Returns:
        complex: The unit-magnitude twiddle factor.

    Raises:
        ZeroDivisionError: If N is zero.
    """
    # W_N^(k + N) == W_N^k, so work with the smallest equivalent index.
    k %= N

    if (4 * k) % N == 0:
        # Exact quarter turn: look up exp(+2*pi*i*k/N) instead of evaluating
        # cos/sin, which would leave tiny non-zero residue in place of 0.
        quarter_turns = (
            (1.0, 0.0),
            (0.0, 1.0),
            (-1.0, 0.0),
            (0.0, -1.0),
        )
        real, imag = quarter_turns[int((4 * k) // N)]
    else:
        angle = 2.0 * math.pi * k / N
        real, imag = math.cos(angle), math.sin(angle)

    # The forward factor is the conjugate of exp(+i*angle). An exact zero is
    # left alone so the result never carries a negative zero.
    if not inverse and imag != 0.0:
        imag = -imag
    return complex(real, imag)

def generate_dft_matrix(radix, inverse=False):
    """Build the radix x radix DFT matrix as a list of row lists.

    Entry [k][n] is twiddle(k * n, radix, inverse); `inverse` is passed
    straight through to twiddle().
    """
    indices = range(radix)
    return [
        [twiddle(row * col, radix, inverse) for col in indices]
        for row in indices
    ]

def format_complex(c):
    """Render a complex value as a shader `float2(<real>f, <imag>f)` literal.

    Each component is converted with str() and suffixed with `f`.
    """
    components = ", ".join(str(part) + "f" for part in (c.real, c.imag))
    return "float2(" + components + ")"

def get_butterfly_sizes_for_config(n, radix):
    """Get the exact butterfly sizes needed for a specific N and radix.

    Stage s (0-based) has span = n // radix**(s + 1) and butterfly size
    span * radix. The number of stages is the largest s with radix**s <= n.
    It is counted with integer arithmetic, because truncating
    math.log(n) / math.log(radix) can drop a stage when the floating-point
    quotient lands just below an integer (e.g. n=243, radix=3).

    Args:
        n: FFT size. Must be positive.
        radix: FFT radix. Must be greater than 1.

    Returns:
        list: Butterfly sizes, one per stage, first stage first. Empty when
        n < radix.

    Raises:
        ValueError: If n is not positive or radix is not greater than 1.
    """
    if radix <= 1:
        raise ValueError(f"radix must be greater than 1, got {radix}")
    if n <= 0:
        raise ValueError(f"n must be positive, got {n}")

    butterfly_sizes = []
    divisor = radix  # radix ** (stage + 1) for the current stage
    while divisor <= n:
        span = n // divisor
        butterfly_sizes.append(span * radix)
        divisor *= radix

    return butterfly_sizes

def _write_dft_matrix(f, name, radix, inverse):
    """Write one radix x radix `static const float2` matrix named `name` to f."""
    matrix = generate_dft_matrix(radix, inverse=inverse)
    last_row = radix - 1
    f.write(f"static const float2 {name}[{radix}][{radix}] = {{\n")
    for k, row in enumerate(matrix):
        cells = ", ".join(format_complex(c) for c in row)
        # Every row except the last is followed by a comma.
        separator = "," if k < last_row else ""
        f.write(f"    {{ {cells} }}{separator}\n")
    f.write("};\n")


def _write_twiddle_table(f, name, size, inverse):
    """Write one `static const float2` array of `size` twiddles named `name` to f."""
    per_line = 4  # entries emitted on each source line
    f.write(f"static const float2 {name}[{size}] = {{\n")
    for start in range(0, size, per_line):
        stop = min(start + per_line, size)
        cells = ", ".join(
            format_complex(twiddle(k, size, inverse)) for k in range(start, stop)
        )
        # Every line except the last ends with a comma.
        separator = "," if stop < size else ""
        f.write(f"    {cells}{separator}\n")
    f.write("};\n")


def generate_cginc(radices, max_size, output_file):
    """Generate .cginc file with twiddle factor tables.

    The file is wrapped in an FFT_TWIDDLE_TABLES_CGINC include guard and
    contains two groups of tables:

    * For each radix, DFT_MATRIX and IDFT_MATRIX inside
      `#if defined(GPU_FFT_RADIX<radix>)`.
    * For each radix and each size n = radix, radix**2, ... <= max_size,
      per-stage STAGE<i>_TWIDDLES and STAGE<i>_TWIDDLES_INV arrays inside
      `#if defined(GPU_FFT_RADIX<radix>_N<n>)`.

    Args:
        radices: Re-iterable collection of radices; each must be >= 2.
        max_size: Largest FFT size to emit stage tables for.
        output_file: Path of the file to create or overwrite.

    Raises:
        ValueError: If any radix is below 2. Raised before the output file
            is opened, so an invalid call never leaves a truncated file.
    """
    invalid = [radix for radix in radices if radix < 2]
    if invalid:
        raise ValueError(f"radices must all be >= 2, got {invalid}")

    # Everything written below is plain ASCII; the encoding is pinned so the
    # result does not depend on the platform default.
    with open(output_file, 'w', encoding='utf-8') as f:
        f.write("// Auto-generated FFT twiddle factor tables\n")
        f.write("#ifndef FFT_TWIDDLE_TABLES_CGINC\n")
        f.write("#define FFT_TWIDDLE_TABLES_CGINC\n\n")

        # Forward and inverse DFT matrices for each radix
        for radix in radices:
            f.write(f"#if defined(GPU_FFT_RADIX{radix})\n")
            _write_dft_matrix(f, "DFT_MATRIX", radix, inverse=False)
            _write_dft_matrix(f, "IDFT_MATRIX", radix, inverse=True)
            f.write("#endif\n\n")

        # Stage twiddle tables for each radix/size combination
        f.write("// Stage twiddle factors for specific radix/size combinations\n")

        for radix in radices:
            n = radix
            while n <= max_size:
                # Get the exact butterfly sizes for this configuration
                butterfly_sizes = get_butterfly_sizes_for_config(n, radix)

                f.write(f"\n#if defined(GPU_FFT_RADIX{radix}_N{n})\n")

                # One forward and one inverse table per stage
                for stage_idx, butterfly_size in enumerate(butterfly_sizes):
                    f.write(f"// Stage {stage_idx}: butterfly_size = {butterfly_size}\n")
                    _write_twiddle_table(
                        f, f"STAGE{stage_idx}_TWIDDLES", butterfly_size, False
                    )
                    _write_twiddle_table(
                        f, f"STAGE{stage_idx}_TWIDDLES_INV", butterfly_size, True
                    )
                    # Blank line between stages
                    f.write("\n")

                f.write("#endif\n")
                n *= radix

        f.write("#endif // FFT_TWIDDLE_TABLES_CGINC\n\n")

def main():
    """Write the default twiddle table file and print a summary.

    Uses radices 2, 4, 8 and 16 with FFT sizes up to 1024, writes
    'fft_twiddle_tables.cginc' via generate_cginc(), then lists the
    butterfly sizes of every generated (radix, N) configuration.
    """
    output_file = 'fft_twiddle_tables.cginc'
    max_size = 1024
    radices = [2, 4, 8, 16]

    for banner in (
        f"Generating twiddle factor tables for radices: {radices}",
        f"Maximum FFT size: {max_size}",
        f"Output file: {output_file}",
    ):
        print(banner)

    generate_cginc(radices, max_size, output_file)

    # Echo every (radix, N) configuration that was written.
    print("\nGenerated configurations:")
    for radix in radices:
        print(f"\n  Radix {radix}:")
        fft_size = radix
        while fft_size <= max_size:
            stage_sizes = get_butterfly_sizes_for_config(fft_size, radix)
            print(f"    N={fft_size}: stages use butterfly sizes {stage_sizes}")
            fft_size *= radix

# Run the generator only when this file is executed as a script, so that
# importing it has no side effects.
if __name__ == "__main__":
    main()