summaryrefslogtreecommitdiffstats
path: root/generate_twiddle_tables.py
diff options
context:
space:
mode:
authoryum <yum.food.vr@gmail.com>2025-07-27 23:16:51 -0700
committeryum <yum.food.vr@gmail.com>2025-07-27 23:16:51 -0700
commitf8a7d3fda6abef2923b0fa8f7021b9d0646ed8d5 (patch)
tree017b0f7504064327928d62998cf5e68194be7e70 /generate_twiddle_tables.py
parentd90ab65bfc740ea7657ac8fca201bdfa8ecc7a22 (diff)
Fix twiddle factors in shader
Diffstat (limited to 'generate_twiddle_tables.py')
-rw-r--r--generate_twiddle_tables.py121
1 files changed, 70 insertions, 51 deletions
diff --git a/generate_twiddle_tables.py b/generate_twiddle_tables.py
index 152208c..de6c599 100644
--- a/generate_twiddle_tables.py
+++ b/generate_twiddle_tables.py
@@ -26,7 +26,19 @@ def generate_dft_matrix(radix, inverse=False):
def format_complex(c):
"""Format complex number as float2"""
- return f"float2({c.real:.9f}f, {c.imag:.9f}f)"
+ return f"float2({c.real}f, {c.imag}f)"
+
+def get_butterfly_sizes_for_config(n, radix):
+ """Get the exact butterfly sizes needed for a specific N and radix"""
+ butterfly_sizes = []
+ num_stages = int(math.log(n) / math.log(radix))
+
+ for stage in range(num_stages):
+ span = n // (radix ** (stage + 1))
+ butterfly_size = span * radix
+ butterfly_sizes.append(butterfly_size)
+
+ return butterfly_sizes
def generate_cginc(radices, max_size, output_file):
"""Generate .cginc file with twiddle factor tables"""
@@ -70,57 +82,54 @@ def generate_cginc(radices, max_size, output_file):
f.write("};\n")
f.write("#endif\n\n")
- # Generate stage twiddle factor tables (forward)
- f.write("// Stage twiddle factors (forward FFT)\n")
- for radix in radices:
- butterfly_size = radix
- stage_idx = 0
-
- while butterfly_size <= max_size:
- f.write(f"#if defined(GPU_FFT_RADIX{radix}_SIZE{butterfly_size})\n")
- f.write(f"static const float2 STAGE_TWIDDLES[{butterfly_size}] = {{\n")
-
- # Write twiddles in rows of 4 for readability
- for i in range(0, butterfly_size, 4):
- f.write(" ")
- for j in range(4):
- if i + j < butterfly_size:
- c = twiddle(i + j, butterfly_size)
- f.write(format_complex(c))
- if i + j < butterfly_size - 1:
- f.write(", ")
- f.write("\n")
-
- f.write("};\n")
- f.write("#endif\n\n")
- butterfly_size *= radix
- stage_idx += 1
-
- # Generate stage twiddle factor tables (inverse)
- f.write("// Stage twiddle factors (inverse FFT)\n")
+ # Generate stage twiddle tables for each radix/size combination
+ f.write("// Stage twiddle factors for specific radix/size combinations\n")
+
for radix in radices:
- butterfly_size = radix
- stage_idx = 0
-
- while butterfly_size <= max_size:
- f.write(f"#if defined(GPU_FFT_RADIX{radix}_SIZE{butterfly_size})\n")
- f.write(f"static const float2 STAGE_TWIDDLES_INV[{butterfly_size}] = {{\n")
-
- # Write twiddles in rows of 4 for readability
- for i in range(0, butterfly_size, 4):
- f.write(" ")
- for j in range(4):
- if i + j < butterfly_size:
- c = twiddle(i + j, butterfly_size, inverse=True)
- f.write(format_complex(c))
- if i + j < butterfly_size - 1:
- f.write(", ")
- f.write("\n")
-
- f.write("};\n")
- f.write("#endif\n\n")
- butterfly_size *= radix
- stage_idx += 1
+ n = radix
+ while n <= max_size:
+ # Get the exact butterfly sizes for this configuration
+ butterfly_sizes = get_butterfly_sizes_for_config(n, radix)
+
+ # Generate tables for this specific configuration
+ f.write(f"\n#if defined(GPU_FFT_RADIX{radix}_N{n})\n")
+
+ # Generate a table for each stage
+ for stage_idx, butterfly_size in enumerate(butterfly_sizes):
+ f.write(f"// Stage {stage_idx}: butterfly_size = {butterfly_size}\n")
+ f.write(f"static const float2 STAGE{stage_idx}_TWIDDLES[{butterfly_size}] = {{\n")
+
+ for i in range(0, butterfly_size, 4):
+ f.write(" ")
+ line_items = []
+ for j in range(4):
+ if i + j < butterfly_size:
+ c = twiddle(i + j, butterfly_size)
+ line_items.append(format_complex(c))
+ f.write(", ".join(line_items))
+ if i + 4 < butterfly_size:
+ f.write(",")
+ f.write("\n")
+ f.write("};\n")
+
+ # Inverse version
+ f.write(f"static const float2 STAGE{stage_idx}_TWIDDLES_INV[{butterfly_size}] = {{\n")
+ for i in range(0, butterfly_size, 4):
+ f.write(" ")
+ line_items = []
+ for j in range(4):
+ if i + j < butterfly_size:
+ c = twiddle(i + j, butterfly_size, inverse=True)
+ line_items.append(format_complex(c))
+ f.write(", ".join(line_items))
+ if i + 4 < butterfly_size:
+ f.write(",")
+ f.write("\n")
+ f.write("};\n\n")
+
+ f.write("#endif\n")
+ n *= radix
+
f.write("#endif // FFT_TWIDDLE_TABLES_CGINC\n\n")
@@ -135,6 +144,16 @@ def main():
generate_cginc(radices, max_size, output_file)
+ # Print summary of what was generated
+ print("\nGenerated configurations:")
+ for radix in radices:
+ print(f"\n Radix {radix}:")
+ n = radix
+ while n <= max_size:
+ butterfly_sizes = get_butterfly_sizes_for_config(n, radix)
+ print(f" N={n}: stages use butterfly sizes {butterfly_sizes}")
+ n *= radix
+
if __name__ == "__main__":
main()