diff options
| author | yum <yum.food.vr@gmail.com> | 2025-07-27 23:16:51 -0700 |
|---|---|---|
| committer | yum <yum.food.vr@gmail.com> | 2025-07-27 23:16:51 -0700 |
| commit | f8a7d3fda6abef2923b0fa8f7021b9d0646ed8d5 (patch) | |
| tree | 017b0f7504064327928d62998cf5e68194be7e70 /generate_twiddle_tables.py | |
| parent | d90ab65bfc740ea7657ac8fca201bdfa8ecc7a22 (diff) | |
Fix twiddle factors in shader
Diffstat (limited to 'generate_twiddle_tables.py')
| -rw-r--r-- | generate_twiddle_tables.py | 121 |
1 files changed, 70 insertions, 51 deletions
diff --git a/generate_twiddle_tables.py b/generate_twiddle_tables.py index 152208c..de6c599 100644 --- a/generate_twiddle_tables.py +++ b/generate_twiddle_tables.py @@ -26,7 +26,19 @@ def generate_dft_matrix(radix, inverse=False): def format_complex(c): """Format complex number as float2""" - return f"float2({c.real:.9f}f, {c.imag:.9f}f)" + return f"float2({c.real}f, {c.imag}f)" + +def get_butterfly_sizes_for_config(n, radix): + """Get the exact butterfly sizes needed for a specific N and radix""" + butterfly_sizes = [] + num_stages = int(math.log(n) / math.log(radix)) + + for stage in range(num_stages): + span = n // (radix ** (stage + 1)) + butterfly_size = span * radix + butterfly_sizes.append(butterfly_size) + + return butterfly_sizes def generate_cginc(radices, max_size, output_file): """Generate .cginc file with twiddle factor tables""" @@ -70,57 +82,54 @@ def generate_cginc(radices, max_size, output_file): f.write("};\n") f.write("#endif\n\n") - # Generate stage twiddle factor tables (forward) - f.write("// Stage twiddle factors (forward FFT)\n") - for radix in radices: - butterfly_size = radix - stage_idx = 0 - - while butterfly_size <= max_size: - f.write(f"#if defined(GPU_FFT_RADIX{radix}_SIZE{butterfly_size})\n") - f.write(f"static const float2 STAGE_TWIDDLES[{butterfly_size}] = {{\n") - - # Write twiddles in rows of 4 for readability - for i in range(0, butterfly_size, 4): - f.write(" ") - for j in range(4): - if i + j < butterfly_size: - c = twiddle(i + j, butterfly_size) - f.write(format_complex(c)) - if i + j < butterfly_size - 1: - f.write(", ") - f.write("\n") - - f.write("};\n") - f.write("#endif\n\n") - butterfly_size *= radix - stage_idx += 1 - - # Generate stage twiddle factor tables (inverse) - f.write("// Stage twiddle factors (inverse FFT)\n") + # Generate stage twiddle tables for each radix/size combination + f.write("// Stage twiddle factors for specific radix/size combinations\n") + for radix in radices: - butterfly_size = radix - stage_idx = 0 - - while butterfly_size <= max_size: - f.write(f"#if defined(GPU_FFT_RADIX{radix}_SIZE{butterfly_size})\n") - f.write(f"static const float2 STAGE_TWIDDLES_INV[{butterfly_size}] = {{\n") - - # Write twiddles in rows of 4 for readability - for i in range(0, butterfly_size, 4): - f.write(" ") - for j in range(4): - if i + j < butterfly_size: - c = twiddle(i + j, butterfly_size, inverse=True) - f.write(format_complex(c)) - if i + j < butterfly_size - 1: - f.write(", ") - f.write("\n") - - f.write("};\n") - f.write("#endif\n\n") - butterfly_size *= radix - stage_idx += 1 + n = radix + while n <= max_size: + # Get the exact butterfly sizes for this configuration + butterfly_sizes = get_butterfly_sizes_for_config(n, radix) + + # Generate tables for this specific configuration + f.write(f"\n#if defined(GPU_FFT_RADIX{radix}_N{n})\n") + + # Generate a table for each stage + for stage_idx, butterfly_size in enumerate(butterfly_sizes): + f.write(f"// Stage {stage_idx}: butterfly_size = {butterfly_size}\n") + f.write(f"static const float2 STAGE{stage_idx}_TWIDDLES[{butterfly_size}] = {{\n") + + for i in range(0, butterfly_size, 4): + f.write(" ") + line_items = [] + for j in range(4): + if i + j < butterfly_size: + c = twiddle(i + j, butterfly_size) + line_items.append(format_complex(c)) + f.write(", ".join(line_items)) + if i + 4 < butterfly_size: + f.write(",") + f.write("\n") + f.write("};\n") + + # Inverse version + f.write(f"static const float2 STAGE{stage_idx}_TWIDDLES_INV[{butterfly_size}] = {{\n") + for i in range(0, butterfly_size, 4): + f.write(" ") + line_items = [] + for j in range(4): + if i + j < butterfly_size: + c = twiddle(i + j, butterfly_size, inverse=True) + line_items.append(format_complex(c)) + f.write(", ".join(line_items)) + if i + 4 < butterfly_size: + f.write(",") + f.write("\n") + f.write("};\n\n") + + f.write("#endif\n") + n *= radix + f.write("#endif // FFT_TWIDDLE_TABLES_CGINC\n\n") @@ -135,6 +144,16 @@ def main(): generate_cginc(radices, max_size, output_file) + # Print summary of what was generated + print("\nGenerated configurations:") + for radix in radices: + print(f"\n Radix {radix}:") + n = radix + while n <= max_size: + butterfly_sizes = get_butterfly_sizes_for_config(n, radix) + print(f" N={n}: stages use butterfly sizes {butterfly_sizes}") + n *= radix + if __name__ == "__main__": main() |
