summaryrefslogtreecommitdiffstats
diff options
context:
space:
mode:
authoryum <yum.food.vr@gmail.com>2025-07-27 23:16:51 -0700
committeryum <yum.food.vr@gmail.com>2025-07-27 23:16:51 -0700
commitf8a7d3fda6abef2923b0fa8f7021b9d0646ed8d5 (patch)
tree017b0f7504064327928d62998cf5e68194be7e70
parentd90ab65bfc740ea7657ac8fca201bdfa8ecc7a22 (diff)
Fix twiddle factors in shader
-rw-r--r--fft.shader148
-rw-r--r--generate_fft_reference.py59
-rw-r--r--generate_twiddle_tables.py121
-rw-r--r--gpu_fft.cc245
-rw-r--r--mandrill_256x256.exrbin0 -> 532830 bytes
-rw-r--r--mandrill_256x256.exr.meta127
-rw-r--r--mandrill_256x256.pngbin0 -> 626099 bytes
-rw-r--r--mandrill_256x256.png.meta127
8 files changed, 676 insertions, 151 deletions
diff --git a/fft.shader b/fft.shader
index 7684853..a6f7d28 100644
--- a/fft.shader
+++ b/fft.shader
@@ -6,9 +6,11 @@ Shader "yum_food/fft"
_N ("N", Int) = 256
_Radix ("Radix", Int) = 16
_Stage ("Stage", Int) = 0
- [Toggle] _PassThrough ("Pass Through", Float) = 0
+ [Toggle] _Passthrough ("Pass Through", Float) = 0
[Toggle] _LDS ("Temporal LDS", Float) = 0
+ [Toggle] _Luminance ("Luminance", Float) = 0
[Toggle] _Inverse ("Inverse FFT", Float) = 0
+ [Toggle] _BitReversal ("Bit Reversal Only", Float) = 0
}
SubShader
{
@@ -23,7 +25,7 @@ Shader "yum_food/fft"
#include "UnityCG.cginc"
#define GPU_FFT_RADIX16
- #define GPU_FFT_RADIX16_SIZE256
+ #define GPU_FFT_RADIX16_N256
#include "fft_twiddle_tables.cginc"
struct appdata
@@ -36,7 +38,7 @@ Shader "yum_food/fft"
{
float2 uv : TEXCOORD0;
float4 vertex : SV_POSITION;
- int num_stages_per_dim : TEXCOORD1;
+ uint num_stages_per_dim : TEXCOORD1;
int span : TEXCOORD2;
int butterfly_size : TEXCOORD3;
int num_stages : TEXCOORD4;
@@ -47,9 +49,11 @@ Shader "yum_food/fft"
int _N;
int _Radix;
int _Stage;
- float _PassThrough;
+ float _Passthrough;
float _LDS;
+ float _Luminance;
float _Inverse;
+ float _BitReversal;
#define PHI 1.618033988749894
@@ -64,6 +68,21 @@ Shader "yum_food/fft"
return result;
}
+ // Generalized digit reversal for any radix
+ uint reverse_digits(uint n, uint num_digits, uint radix)
+ {
+ uint bits_per_digit = (uint)(log2(radix));
+ uint digit_mask = radix - 1;
+ uint reversed = 0;
+
+ for (uint i = 0; i < num_digits; i++)
+ {
+ uint digit = (n >> (bits_per_digit * i)) & digit_mask;
+ reversed |= digit << (bits_per_digit * (num_digits - 1 - i));
+ }
+ return reversed;
+ }
+
v2f vert (appdata v)
{
v2f o;
@@ -71,7 +90,7 @@ Shader "yum_food/fft"
o.uv = v.uv;
// Calculate num_stages_per_dim = log_radix(N)
- o.num_stages_per_dim = (int)(log(_N) / log(_Radix));
+ o.num_stages_per_dim = (uint)(log(_N) / log(_Radix));
// Determine current stage (0-based index within row or column passes)
int current_stage = (_Stage < o.num_stages_per_dim) ? _Stage : (_Stage - o.num_stages_per_dim);
@@ -89,37 +108,51 @@ Shader "yum_food/fft"
fixed4 frag (v2f i) : SV_Target
{
+ // Extract coordinates
int2 pixel_index = (int2)(i.uv * _N);
- float2 uv = (pixel_index + 0.5f) / _N;
+ int x = pixel_index.x;
+ int y = pixel_index.y;
- // If pass through is enabled, just return the input
- if (_PassThrough > 0.5)
+ // Bit reversal mode
+ if (_BitReversal > 0.5)
{
+ uint num_digits = i.num_stages_per_dim;
+ uint rev_x = reverse_digits((uint)x, num_digits, (uint)_Radix);
+ uint rev_y = reverse_digits((uint)y, num_digits, (uint)_Radix);
+
+ float2 rev_uv = float2((rev_x + 0.5) / (float) _N, (rev_y + 0.5) / (float) _N);
+ float4 col = _MainTex.SampleLevel(point_clamp_s, rev_uv, 0);
+ return col;
+ }
+
+ // Pass through mode
+ if (_Passthrough > 0.5)
+ {
+ float2 uv = (pixel_index + 0.5f) / _N;
float3 col = _MainTex.SampleLevel(point_clamp_s, uv, 0).rgb;
if (_LDS > 0.5) {
col += PHI * _Time[0];
col = frac(col);
}
- float lum = luminance(col);
- return float4(lum, lum, lum, 1);
+ if (_Luminance > 0.5) {
+ col = luminance(col);
+ }
+ return float4(col, 1);
}
- // Calculate pixel index from UV coordinates
- int x = pixel_index.x;
- int y = pixel_index.y;
-
- // Determine processing direction (row stage or column stage)
+ // Determine processing direction
bool is_row_stage = (_Stage < i.num_stages_per_dim);
int coord = is_row_stage ? x : y;
- // Calculate butterfly indices
+ // Calculate butterfly indices (simple integer math)
int group = coord / i.butterfly_size;
int idx_in_group = coord % i.butterfly_size;
int wing = idx_in_group / i.span;
int idx_in_wing = idx_in_group % i.span;
// Accumulate DFT sum
- float2 sum = float2(0.0, 0.0);
+ float sum_real = 0.0;
+ float sum_imag = 0.0;
// Main DFT loop
for (int j = 0; j < _Radix; j++)
@@ -127,67 +160,74 @@ Shader "yum_food/fft"
// Calculate input position
int input_pos = group * i.butterfly_size + j * i.span + idx_in_wing;
- // Calculate UV for input texture read
- float2 input_uv;
+ // Read input value
+ float in_real, in_imag;
if (is_row_stage)
{
- input_uv = float2((input_pos + 0.5) / (float)_N, i.uv.y);
+ float2 input_uv = float2((input_pos + 0.5) / (float)_N, i.uv.y);
+ float4 input_tex = _MainTex.SampleLevel(point_clamp_s, input_uv, 0);
+ if (_Stage == 0) {
+ // Assume that input is grayscale and real-valued.
+ in_real = input_tex.x;
+ in_imag = 0;
+ } else {
+ in_real = input_tex.x;
+ in_imag = input_tex.y;
+ }
}
else
{
float xuv = (x + 0.5) / _N;
- input_uv = float2(xuv, (input_pos + 0.5) / (float)_N);
+ float2 input_uv = float2(xuv, (input_pos + 0.5) / (float)_N);
+ float4 input_tex = _MainTex.SampleLevel(point_clamp_s, input_uv, 0);
+ in_real = input_tex.x;
+ in_imag = input_tex.y;
}
- // Read input value
- float4 input_tex = _MainTex.SampleLevel(point_clamp_s, input_uv, 0);
- float2 input_val;
- if (_Stage == 0) {
- input_val.x = luminance(input_tex.xyz);
- input_val.y = 0;
- } else {
- input_val.x = input_tex.x + input_tex.y;
- input_val.y = input_tex.z + input_tex.w;
- }
-
- // Read DFT coefficient from the table (use inverse matrix if _Inverse is set)
+ // Read DFT coefficient
float2 coeff = _Inverse > 0.5 ? IDFT_MATRIX[wing][j] : DFT_MATRIX[wing][j];
+ float coeff_real = coeff.x;
+ float coeff_imag = coeff.y;
// Complex multiply-accumulate
- sum.x += coeff.x * input_val.x - coeff.y * input_val.y;
- sum.y += coeff.x * input_val.y + coeff.y * input_val.x;
+ sum_real += coeff_real * in_real - coeff_imag * in_imag;
+ sum_imag += coeff_real * in_imag + coeff_imag * in_real;
}
// Apply stage twiddle if needed
+ float out_real, out_imag;
if (wing > 0 && idx_in_wing > 0)
{
int twiddle_idx = wing * idx_in_wing;
- float2 tw = _Inverse > 0.5 ? STAGE_TWIDDLES_INV[twiddle_idx] : STAGE_TWIDDLES[twiddle_idx];
+ float2 tw;
+
+ if (_Stage % 2 == 0) {
+ tw = _Inverse > 0.5 ? STAGE0_TWIDDLES_INV[twiddle_idx] : STAGE0_TWIDDLES[twiddle_idx];
+ } else {
+ tw = _Inverse > 0.5 ? STAGE1_TWIDDLES_INV[twiddle_idx] : STAGE1_TWIDDLES[twiddle_idx];
+ }
+
+ float tw_real = tw.x;
+ float tw_imag = tw.y;
// Output = twiddle * sum
- float2 output;
- output.x = tw.x * sum.x - tw.y * sum.y;
- output.y = tw.x * sum.y + tw.y * sum.x;
- sum = output;
+ out_real = tw_real * sum_real - tw_imag * sum_imag;
+ out_imag = tw_real * sum_imag + tw_imag * sum_real;
+ }
+ else
+ {
+ out_real = sum_real;
+ out_imag = sum_imag;
}
- // Pack complex result into RGBA.
- float real_part = sum.x;
- float imag_part = sum.y;
-
+ // Handle final stage of inverse FFT
if (_Inverse > 0.5 && _Stage == i.num_stages_per_dim * 2 - 1) {
- // Last stage of IFFT is just back to the original real-valued signal.
- real_part /= _N * i.num_stages_per_dim;
- return float4(real_part, real_part, real_part, 1);
+ float normalized = out_real / (_N * _N);
+ return float4(normalized, normalized, normalized, 1);
}
- // Split into 2 parts.
- float real_big = floor(real_part);
- float real_small = real_part - real_big;
- float imag_big = floor(imag_part);
- float imag_small = imag_part - imag_big;
-
- return float4(real_big, real_small, imag_big, imag_small);
+ // Pack complex result into RGBA
+ return float4(out_real, out_imag, 0, 1);
}
ENDCG
}
diff --git a/generate_fft_reference.py b/generate_fft_reference.py
new file mode 100644
index 0000000..a7d58f4
--- /dev/null
+++ b/generate_fft_reference.py
@@ -0,0 +1,59 @@
+#!/usr/bin/env python3
+"""
+Generate reference FFT image in OpenEXR format, matching the shader's output format.
+"""
+
+import numpy as np
+import argparse
+from PIL import Image
+import OpenEXR
+import Imath
+
+def main():
+ parser = argparse.ArgumentParser(description='Generate reference FFT image')
+ parser.add_argument('input', type=str, help='Input image file')
+ parser.add_argument('output', type=str, help='Output EXR file')
+ parser.add_argument('--size', type=int, default=256, help='Size of the FFT (NxN)')
+ args = parser.parse_args()
+
+ # Load input image and convert to luminance
+ img = Image.open(args.input).convert('RGB')
+ img = img.resize((args.size, args.size), Image.Resampling.LANCZOS)
+ img_array = np.array(img, dtype=np.float32) / 255.0
+ luminance = 0.2126 * img_array[:,:,0] + 0.7152 * img_array[:,:,1] + 0.0722 * img_array[:,:,2]
+
+ # Perform 2D FFT (no fftshift, matching GPU implementation)
+ fft_result = np.fft.fft2(luminance)
+
+ # Pack complex numbers into RGBA matching shader format:
+ # R: real part, G: imaginary part, B: 0, A: 1
+ real_part = fft_result.real.astype(np.float32)
+ imag_part = fft_result.imag.astype(np.float32)
+
+ # Create EXR
+ height, width = args.size, args.size
+ header = OpenEXR.Header(width, height)
+ half_chan = Imath.Channel(Imath.PixelType(Imath.PixelType.FLOAT))
+ header['channels'] = {'R': half_chan, 'G': half_chan, 'B': half_chan, 'A': half_chan}
+
+ # Write EXR
+ out = OpenEXR.OutputFile(args.output, header)
+
+ # Create zero and one arrays for B and A channels
+ zeros = np.zeros((height, width), dtype=np.float32)
+ ones = np.ones((height, width), dtype=np.float32)
+
+ out.writePixels({
+ 'R': real_part.astype(np.float32).tobytes(),
+ 'G': imag_part.astype(np.float32).tobytes(),
+ 'B': zeros.tobytes(),
+ 'A': ones.tobytes()
+ })
+ out.close()
+
+ print(f"FFT complete. Output saved to {args.output}")
+ print(f"Real range: [{real_part.min():.1f}, {real_part.max():.1f}]")
+ print(f"Imag range: [{imag_part.min():.1f}, {imag_part.max():.1f}]")
+
+if __name__ == "__main__":
+ main()
diff --git a/generate_twiddle_tables.py b/generate_twiddle_tables.py
index 152208c..de6c599 100644
--- a/generate_twiddle_tables.py
+++ b/generate_twiddle_tables.py
@@ -26,7 +26,19 @@ def generate_dft_matrix(radix, inverse=False):
def format_complex(c):
"""Format complex number as float2"""
- return f"float2({c.real:.9f}f, {c.imag:.9f}f)"
+ return f"float2({c.real}f, {c.imag}f)"
+
+def get_butterfly_sizes_for_config(n, radix):
+ """Get the exact butterfly sizes needed for a specific N and radix"""
+ butterfly_sizes = []
+ num_stages = int(math.log(n) / math.log(radix))
+
+ for stage in range(num_stages):
+ span = n // (radix ** (stage + 1))
+ butterfly_size = span * radix
+ butterfly_sizes.append(butterfly_size)
+
+ return butterfly_sizes
def generate_cginc(radices, max_size, output_file):
"""Generate .cginc file with twiddle factor tables"""
@@ -70,57 +82,54 @@ def generate_cginc(radices, max_size, output_file):
f.write("};\n")
f.write("#endif\n\n")
- # Generate stage twiddle factor tables (forward)
- f.write("// Stage twiddle factors (forward FFT)\n")
- for radix in radices:
- butterfly_size = radix
- stage_idx = 0
-
- while butterfly_size <= max_size:
- f.write(f"#if defined(GPU_FFT_RADIX{radix}_SIZE{butterfly_size})\n")
- f.write(f"static const float2 STAGE_TWIDDLES[{butterfly_size}] = {{\n")
-
- # Write twiddles in rows of 4 for readability
- for i in range(0, butterfly_size, 4):
- f.write(" ")
- for j in range(4):
- if i + j < butterfly_size:
- c = twiddle(i + j, butterfly_size)
- f.write(format_complex(c))
- if i + j < butterfly_size - 1:
- f.write(", ")
- f.write("\n")
-
- f.write("};\n")
- f.write("#endif\n\n")
- butterfly_size *= radix
- stage_idx += 1
-
- # Generate stage twiddle factor tables (inverse)
- f.write("// Stage twiddle factors (inverse FFT)\n")
+ # Generate stage twiddle tables for each radix/size combination
+ f.write("// Stage twiddle factors for specific radix/size combinations\n")
+
for radix in radices:
- butterfly_size = radix
- stage_idx = 0
-
- while butterfly_size <= max_size:
- f.write(f"#if defined(GPU_FFT_RADIX{radix}_SIZE{butterfly_size})\n")
- f.write(f"static const float2 STAGE_TWIDDLES_INV[{butterfly_size}] = {{\n")
-
- # Write twiddles in rows of 4 for readability
- for i in range(0, butterfly_size, 4):
- f.write(" ")
- for j in range(4):
- if i + j < butterfly_size:
- c = twiddle(i + j, butterfly_size, inverse=True)
- f.write(format_complex(c))
- if i + j < butterfly_size - 1:
- f.write(", ")
- f.write("\n")
-
- f.write("};\n")
- f.write("#endif\n\n")
- butterfly_size *= radix
- stage_idx += 1
+ n = radix
+ while n <= max_size:
+ # Get the exact butterfly sizes for this configuration
+ butterfly_sizes = get_butterfly_sizes_for_config(n, radix)
+
+ # Generate tables for this specific configuration
+ f.write(f"\n#if defined(GPU_FFT_RADIX{radix}_N{n})\n")
+
+ # Generate a table for each stage
+ for stage_idx, butterfly_size in enumerate(butterfly_sizes):
+ f.write(f"// Stage {stage_idx}: butterfly_size = {butterfly_size}\n")
+ f.write(f"static const float2 STAGE{stage_idx}_TWIDDLES[{butterfly_size}] = {{\n")
+
+ for i in range(0, butterfly_size, 4):
+ f.write(" ")
+ line_items = []
+ for j in range(4):
+ if i + j < butterfly_size:
+ c = twiddle(i + j, butterfly_size)
+ line_items.append(format_complex(c))
+ f.write(", ".join(line_items))
+ if i + 4 < butterfly_size:
+ f.write(",")
+ f.write("\n")
+ f.write("};\n")
+
+ # Inverse version
+ f.write(f"static const float2 STAGE{stage_idx}_TWIDDLES_INV[{butterfly_size}] = {{\n")
+ for i in range(0, butterfly_size, 4):
+ f.write(" ")
+ line_items = []
+ for j in range(4):
+ if i + j < butterfly_size:
+ c = twiddle(i + j, butterfly_size, inverse=True)
+ line_items.append(format_complex(c))
+ f.write(", ".join(line_items))
+ if i + 4 < butterfly_size:
+ f.write(",")
+ f.write("\n")
+ f.write("};\n\n")
+
+ f.write("#endif\n")
+ n *= radix
+
f.write("#endif // FFT_TWIDDLE_TABLES_CGINC\n\n")
@@ -135,6 +144,16 @@ def main():
generate_cginc(radices, max_size, output_file)
+ # Print summary of what was generated
+ print("\nGenerated configurations:")
+ for radix in radices:
+ print(f"\n Radix {radix}:")
+ n = radix
+ while n <= max_size:
+ butterfly_sizes = get_butterfly_sizes_for_config(n, radix)
+ print(f" N={n}: stages use butterfly sizes {butterfly_sizes}")
+ n *= radix
+
if __name__ == "__main__":
main()
diff --git a/gpu_fft.cc b/gpu_fft.cc
index f5dd97e..5b3f4ab 100644
--- a/gpu_fft.cc
+++ b/gpu_fft.cc
@@ -15,6 +15,13 @@
#include <random>
#include <vector>
+struct float2 {
+ float x, y;
+};
+#define GPU_FFT_RADIX16
+#define GPU_FFT_RADIX16_N256
+#include "fft_twiddle_tables.cginc"
+
// This is a reference Cooley-Tukey FFT implementation. It's just here to check
// for correctness.
std::vector<std::complex<float>> fft1d_naive(
@@ -110,8 +117,32 @@ unsigned int reverse_digits(unsigned int n, unsigned int num_digits, unsigned in
// Compute twiddle factor W_N^k = exp(-2*pi*i*k/N) for forward FFT
// or exp(+2*pi*i*k/N) for inverse FFT
gpu_complex twiddle_factor(int k, int N, bool inverse = false) {
- float angle = (inverse ? 2.0f : -2.0f) * std::numbers::pi * k / N;
- return {std::cos(angle), std::sin(angle)};
+ // Use double precision for angle computation, then cast to float
+ // This matches how the precomputed tables were generated
+ double angle = (inverse ? 2.0 : -2.0) * std::numbers::pi * k / N;
+ return {(float)std::cos(angle), (float)std::sin(angle)};
+}
+
+// Helper to compute radix^power using integer arithmetic
+int int_pow(int base, int exp) {
+ int result = 1;
+ for (int i = 0; i < exp; ++i) {
+ result *= base;
+ }
+ return result;
+}
+
+// Apply 2D bit reversal to the output
+void apply_2d_bit_reversal(const int n, const int radix, const stage_texture& in, stage_texture& out) {
+ const int num_digits = std::log2(n) / std::log2(radix);
+
+ for (int y = 0; y < n; ++y) {
+ for (int x = 0; x < n; ++x) {
+ const int rev_x = reverse_digits(x, num_digits, radix);
+ const int rev_y = reverse_digits(y, num_digits, radix);
+ out[y][x] = in[rev_y][rev_x];
+ }
+ }
}
// Main shader function - simplified for GPU/HLSL conversion
@@ -183,15 +214,42 @@ void evaluate_stage(
}
}
-// Apply 2D bit reversal to the output
-void apply_2d_bit_reversal(const int n, const int radix, const stage_texture& in, stage_texture& out) {
- const int num_digits = std::log2(n) / std::log2(radix);
+// Evaluate all stages - unified function
+void evaluate_stages(
+ const int n,
+ const int radix,
+ const bool inverse,
+ std::vector<stage_texture>& textures,
+ const std::vector<std::vector<gpu_complex>>& dft_matrix,
+ const std::vector<std::vector<gpu_complex>>& stage_twiddles_array) {
- for (int y = 0; y < n; ++y) {
- for (int x = 0; x < n; ++x) {
- const int rev_x = reverse_digits(x, num_digits, radix);
- const int rev_y = reverse_digits(y, num_digits, radix);
- out[y][x] = in[rev_y][rev_x];
+ const int num_stages_per_dim = std::log2(n) / std::log2(radix);
+ const int num_stages = num_stages_per_dim * 2;
+
+ for (int stage = 0; stage < num_stages; ++stage) {
+ int current_stage = (stage < num_stages_per_dim) ? stage : (stage - num_stages_per_dim);
+ int span = n / int_pow(radix, current_stage + 1);
+ int butterfly_size = span * radix;
+
+ const std::vector<gpu_complex>& stage_twiddles = stage_twiddles_array[current_stage];
+
+ ShaderUniforms uniforms = {n, radix, stage, num_stages_per_dim, span, butterfly_size,
+ inverse, dft_matrix, stage_twiddles};
+ evaluate_stage(uniforms, textures[stage], textures[stage+1]);
+ }
+
+ // Apply bit reversal once at the end
+ stage_texture temp = textures[num_stages];
+ apply_2d_bit_reversal(n, radix, temp, textures[num_stages]);
+
+ // For inverse FFT, normalize by 1/(n*n)
+ if (inverse) {
+ float norm_factor = 1.0f / (n * n);
+ for (int y = 0; y < n; ++y) {
+ for (int x = 0; x < n; ++x) {
+ textures[num_stages][y][x].first *= norm_factor;
+ textures[num_stages][y][x].second *= norm_factor;
+ }
}
}
}
@@ -208,15 +266,6 @@ std::vector<std::vector<gpu_complex>> compute_twiddle_factors(int radix, bool in
return twiddle_factors;
}
-// Helper to compute radix^power using integer arithmetic
-int int_pow(int base, int exp) {
- int result = 1;
- for (int i = 0; i < exp; ++i) {
- result *= base;
- }
- return result;
-}
-
// Precompute stage twiddle factors
std::vector<gpu_complex> compute_stage_twiddles(int butterfly_size, bool inverse = false) {
std::vector<gpu_complex> stage_twiddles(butterfly_size);
@@ -226,46 +275,142 @@ std::vector<gpu_complex> compute_stage_twiddles(int butterfly_size, bool inverse
return stage_twiddles;
}
-// Evaluate all stages.
-void evaluate_stages(
+// Convert float2 arrays to gpu_complex vectors
+std::vector<std::vector<gpu_complex>> convert_dft_matrix(bool inverse = false) {
+ std::vector<std::vector<gpu_complex>> result(16, std::vector<gpu_complex>(16));
+ for (int i = 0; i < 16; ++i) {
+ for (int j = 0; j < 16; ++j) {
+ if (inverse) {
+ result[i][j] = {IDFT_MATRIX[i][j].x, IDFT_MATRIX[i][j].y};
+ } else {
+ result[i][j] = {DFT_MATRIX[i][j].x, DFT_MATRIX[i][j].y};
+ }
+ }
+ }
+ return result;
+}
+
+std::vector<gpu_complex> convert_stage_twiddles(int stage, bool inverse = false) {
+ if (stage == 0) { // butterfly_size = 256
+ std::vector<gpu_complex> result(256);
+ for (int i = 0; i < 256; ++i) {
+ if (inverse) {
+ result[i] = {STAGE0_TWIDDLES_INV[i].x, STAGE0_TWIDDLES_INV[i].y};
+ } else {
+ result[i] = {STAGE0_TWIDDLES[i].x, STAGE0_TWIDDLES[i].y};
+ }
+ }
+ return result;
+ } else { // stage == 1, butterfly_size = 16
+ std::vector<gpu_complex> result(16);
+ for (int i = 0; i < 16; ++i) {
+ if (inverse) {
+ result[i] = {STAGE1_TWIDDLES_INV[i].x, STAGE1_TWIDDLES_INV[i].y};
+ } else {
+ result[i] = {STAGE1_TWIDDLES[i].x, STAGE1_TWIDDLES[i].y};
+ }
+ }
+ return result;
+ }
+}
+
+// Wrapper for computed twiddles
+void evaluate_stages_computed(
const int n,
const int radix,
const bool inverse,
std::vector<stage_texture>& textures) {
- const std::vector<std::vector<gpu_complex>> twiddle_factors =
+
+ const std::vector<std::vector<gpu_complex>> dft_matrix =
compute_twiddle_factors(radix, inverse);
+ // Precompute all stage twiddles
const int num_stages_per_dim = std::log2(n) / std::log2(radix);
- const int num_stages = num_stages_per_dim * 2;
+ std::vector<std::vector<gpu_complex>> stage_twiddles_array(num_stages_per_dim);
- for (int stage = 0; stage < num_stages; ++stage) {
- // Compute span and butterfly_size for this stage
- int current_stage = (stage < num_stages_per_dim) ? stage : (stage - num_stages_per_dim);
- int span = n / int_pow(radix, current_stage + 1);
+ for (int stage = 0; stage < num_stages_per_dim; ++stage) {
+ int span = n / int_pow(radix, stage + 1);
int butterfly_size = span * radix;
+ stage_twiddles_array[stage] = compute_stage_twiddles(butterfly_size, inverse);
+ }
- // Precompute stage twiddle factors
- std::vector<gpu_complex> stage_twiddles = compute_stage_twiddles(butterfly_size, inverse);
+ evaluate_stages(n, radix, inverse, textures, dft_matrix, stage_twiddles_array);
+}
- ShaderUniforms uniforms = {n, radix, stage, num_stages_per_dim, span, butterfly_size,
- inverse, twiddle_factors, stage_twiddles};
- evaluate_stage(uniforms, textures[stage], textures[stage+1]);
+// Wrapper for precomputed twiddles
+void evaluate_stages_precomputed(
+ const int n,
+ const int radix,
+ const bool inverse,
+ std::vector<stage_texture>& textures,
+ const std::vector<std::vector<gpu_complex>>& precomputed_dft_matrix) {
+
+ // Convert precomputed stage twiddles
+ const int num_stages_per_dim = std::log2(n) / std::log2(radix);
+ std::vector<std::vector<gpu_complex>> stage_twiddles_array(num_stages_per_dim);
+
+ for (int stage = 0; stage < num_stages_per_dim; ++stage) {
+ stage_twiddles_array[stage] = convert_stage_twiddles(stage, inverse);
}
- // Apply bit reversal once at the end
- stage_texture temp = textures[num_stages];
- apply_2d_bit_reversal(n, radix, temp, textures[num_stages]);
+ evaluate_stages(n, radix, inverse, textures, precomputed_dft_matrix, stage_twiddles_array);
+}
- // For inverse FFT, normalize by 1/(n*n)
- if (inverse) {
- float norm_factor = 1.0f / (n * n);
- for (int y = 0; y < n; ++y) {
- for (int x = 0; x < n; ++x) {
- textures[num_stages][y][x].first *= norm_factor;
- textures[num_stages][y][x].second *= norm_factor;
+// Verify FFT results match between computed and precomputed tables
+bool verify_fft_with_tables(std::mt19937& rng) {
+ const int n = 256;
+ const int radix = 16;
+ const int NUM_STAGES = (std::log2(n) / std::log2(radix)) * 2;
+
+ // Initialize test data
+ const std::vector<std::vector<gpu_complex>> black_texture(n,
+ std::vector<gpu_complex>(n, {0, 0}));
+ std::vector<stage_texture> textures_computed(NUM_STAGES + 1, black_texture);
+ std::vector<stage_texture> textures_precomputed(NUM_STAGES + 1, black_texture);
+
+ // Fill with identical random data
+ std::uniform_real_distribution<float> dist(0.0f, 1.0f);
+ for (int y = 0; y < n; ++y) {
+ for (int x = 0; x < n; ++x) {
+ float real = dist(rng);
+ float imag = dist(rng);
+ textures_computed[0][y][x] = {real, imag};
+ textures_precomputed[0][y][x] = {real, imag};
+ }
+ }
+
+ // Run FFT with computed tables
+ evaluate_stages_computed(n, radix, false, textures_computed);
+
+ // Run FFT with precomputed tables from cginc
+ auto dft_matrix = convert_dft_matrix(false);
+ evaluate_stages_precomputed(n, radix, false, textures_precomputed, dft_matrix);
+
+ // Compare results
+ float max_error = 0.0f;
+ int mismatch_count = 0;
+ for (int y = 0; y < n; ++y) {
+ for (int x = 0; x < n; ++x) {
+ float err_real = std::abs(textures_computed[NUM_STAGES][y][x].first -
+ textures_precomputed[NUM_STAGES][y][x].first);
+ float err_imag = std::abs(textures_computed[NUM_STAGES][y][x].second -
+ textures_precomputed[NUM_STAGES][y][x].second);
+ max_error = std::max(max_error, std::max(err_real, err_imag));
+ if (err_real > 1e-6f || err_imag > 1e-6f) {
+ mismatch_count++;
}
}
}
+
+ std::cout << "FFT max error between computed and precomputed tables: "
+ << std::scientific << max_error << std::fixed << std::endl;
+
+ if (mismatch_count > 0) {
+ std::cout << "ERROR: " << mismatch_count << " pixels have error > 1e-6" << std::endl;
+ return false;
+ }
+
+ return true;
}
bool check_result(
@@ -361,7 +506,7 @@ bool evaluateAlgorithm(const int n, const int radix, const bool inverse, std::mt
// Evaluate the GPU algorithm.
auto start = std::chrono::high_resolution_clock::now();
- evaluate_stages(n, radix, inverse, textures);
+ evaluate_stages_computed(n, radix, inverse, textures);
auto end = std::chrono::high_resolution_clock::now();
auto duration = std::chrono::duration_cast<std::chrono::microseconds>(end - start);
@@ -418,13 +563,13 @@ bool testFFTInverse(const int n, const int radix, std::mt19937& rng) {
stage_texture original_input = textures_fft[0];
// Perform forward FFT
- evaluate_stages(n, radix, false, textures_fft);
+ evaluate_stages_computed(n, radix, false, textures_fft);
// Copy FFT result as input to inverse FFT
textures_ifft[0] = textures_fft[NUM_STAGES];
// Perform inverse FFT
- evaluate_stages(n, radix, true, textures_ifft);
+ evaluate_stages_computed(n, radix, true, textures_ifft);
// Check if inverse FFT gives back the original input
float max_error = 0.0f;
@@ -436,7 +581,7 @@ bool testFFTInverse(const int n, const int radix, std::mt19937& rng) {
}
}
- const float epsilon = 1e-3; // Tolerance for round-trip error
+ const float epsilon = 1e-5; // Tolerance for round-trip error
bool success = (max_error < epsilon);
if (!success) {
@@ -465,6 +610,14 @@ bool testFFTInverse(const int n, const int radix, std::mt19937& rng) {
int main() {
std::mt19937 rng(std::random_device{}());
+ // Verify FFT results match with precomputed tables
+ std::cout << "Verifying FFT with precomputed tables from fft_twiddle_tables.cginc..." << std::endl;
+ if (!verify_fft_with_tables(rng)) {
+ std::cout << "ERROR: FFT results do not match between computed and precomputed tables!" << std::endl;
+ return 1;
+ }
+ std::cout << "FFT verification passed!\n" << std::endl;
+
// First run the original forward FFT tests
std::cout << "Testing forward FFT correctness against reference implementation..." << std::endl;
for (int log_radix = 1; log_radix < 5; ++log_radix) {
diff --git a/mandrill_256x256.exr b/mandrill_256x256.exr
new file mode 100644
index 0000000..01c4a3a
--- /dev/null
+++ b/mandrill_256x256.exr
Binary files differ
diff --git a/mandrill_256x256.exr.meta b/mandrill_256x256.exr.meta
new file mode 100644
index 0000000..05bf3cf
--- /dev/null
+++ b/mandrill_256x256.exr.meta
@@ -0,0 +1,127 @@
+fileFormatVersion: 2
+guid: d3b73aebac11d354bb702321288d9099
+TextureImporter:
+ internalIDToNameTable: []
+ externalObjects: {}
+ serializedVersion: 12
+ mipmaps:
+ mipMapMode: 0
+ enableMipMap: 1
+ sRGBTexture: 1
+ linearTexture: 0
+ fadeOut: 0
+ borderMipMap: 0
+ mipMapsPreserveCoverage: 0
+ alphaTestReferenceValue: 0.5
+ mipMapFadeDistanceStart: 1
+ mipMapFadeDistanceEnd: 3
+ bumpmap:
+ convertToNormalMap: 0
+ externalNormalMap: 0
+ heightScale: 0.25
+ normalMapFilter: 0
+ flipGreenChannel: 0
+ isReadable: 0
+ streamingMipmaps: 0
+ streamingMipmapsPriority: 0
+ vTOnly: 0
+ ignoreMipmapLimit: 0
+ grayScaleToAlpha: 0
+ generateCubemap: 6
+ cubemapConvolution: 0
+ seamlessCubemap: 0
+ textureFormat: 1
+ maxTextureSize: 2048
+ textureSettings:
+ serializedVersion: 2
+ filterMode: 1
+ aniso: 1
+ mipBias: 0
+ wrapU: 0
+ wrapV: 0
+ wrapW: 0
+ nPOTScale: 1
+ lightmap: 0
+ compressionQuality: 50
+ spriteMode: 0
+ spriteExtrude: 1
+ spriteMeshType: 1
+ alignment: 0
+ spritePivot: {x: 0.5, y: 0.5}
+ spritePixelsToUnits: 100
+ spriteBorder: {x: 0, y: 0, z: 0, w: 0}
+ spriteGenerateFallbackPhysicsShape: 1
+ alphaUsage: 1
+ alphaIsTransparency: 0
+ spriteTessellationDetail: -1
+ textureType: 0
+ textureShape: 1
+ singleChannelComponent: 0
+ flipbookRows: 1
+ flipbookColumns: 1
+ maxTextureSizeSet: 0
+ compressionQualitySet: 0
+ textureFormatSet: 0
+ ignorePngGamma: 0
+ applyGammaDecoding: 0
+ swizzle: 50462976
+ cookieLightType: 0
+ platformSettings:
+ - serializedVersion: 3
+ buildTarget: DefaultTexturePlatform
+ maxTextureSize: 2048
+ resizeAlgorithm: 0
+ textureFormat: -1
+ textureCompression: 1
+ compressionQuality: 50
+ crunchedCompression: 0
+ allowsAlphaSplitting: 0
+ overridden: 0
+ ignorePlatformSupport: 0
+ androidETC2FallbackOverride: 0
+ forceMaximumCompressionQuality_BC6H_BC7: 0
+ - serializedVersion: 3
+ buildTarget: Standalone
+ maxTextureSize: 2048
+ resizeAlgorithm: 0
+ textureFormat: -1
+ textureCompression: 1
+ compressionQuality: 50
+ crunchedCompression: 0
+ allowsAlphaSplitting: 0
+ overridden: 0
+ ignorePlatformSupport: 0
+ androidETC2FallbackOverride: 0
+ forceMaximumCompressionQuality_BC6H_BC7: 0
+ - serializedVersion: 3
+ buildTarget: Android
+ maxTextureSize: 2048
+ resizeAlgorithm: 0
+ textureFormat: -1
+ textureCompression: 1
+ compressionQuality: 50
+ crunchedCompression: 0
+ allowsAlphaSplitting: 0
+ overridden: 0
+ ignorePlatformSupport: 0
+ androidETC2FallbackOverride: 0
+ forceMaximumCompressionQuality_BC6H_BC7: 0
+ spriteSheet:
+ serializedVersion: 2
+ sprites: []
+ outline: []
+ physicsShape: []
+ bones: []
+ spriteID:
+ internalID: 0
+ vertices: []
+ indices:
+ edges: []
+ weights: []
+ secondaryTextures: []
+ nameFileIdTable: {}
+ mipmapLimitGroupName:
+ pSDRemoveMatte: 0
+ userData:
+ assetBundleName:
+ assetBundleVariant:
diff --git a/mandrill_256x256.png b/mandrill_256x256.png
new file mode 100644
index 0000000..a9a8624
--- /dev/null
+++ b/mandrill_256x256.png
Binary files differ
diff --git a/mandrill_256x256.png.meta b/mandrill_256x256.png.meta
new file mode 100644
index 0000000..7219e1c
--- /dev/null
+++ b/mandrill_256x256.png.meta
@@ -0,0 +1,127 @@
+fileFormatVersion: 2
+guid: d0dba03849af9df4fae990e250f66246
+TextureImporter:
+ internalIDToNameTable: []
+ externalObjects: {}
+ serializedVersion: 12
+ mipmaps:
+ mipMapMode: 0
+ enableMipMap: 1
+ sRGBTexture: 1
+ linearTexture: 0
+ fadeOut: 0
+ borderMipMap: 0
+ mipMapsPreserveCoverage: 0
+ alphaTestReferenceValue: 0.5
+ mipMapFadeDistanceStart: 1
+ mipMapFadeDistanceEnd: 3
+ bumpmap:
+ convertToNormalMap: 0
+ externalNormalMap: 0
+ heightScale: 0.25
+ normalMapFilter: 0
+ flipGreenChannel: 0
+ isReadable: 0
+ streamingMipmaps: 0
+ streamingMipmapsPriority: 0
+ vTOnly: 0
+ ignoreMipmapLimit: 0
+ grayScaleToAlpha: 0
+ generateCubemap: 6
+ cubemapConvolution: 0
+ seamlessCubemap: 0
+ textureFormat: 1
+ maxTextureSize: 2048
+ textureSettings:
+ serializedVersion: 2
+ filterMode: 1
+ aniso: 1
+ mipBias: 0
+ wrapU: 0
+ wrapV: 0
+ wrapW: 0
+ nPOTScale: 1
+ lightmap: 0
+ compressionQuality: 50
+ spriteMode: 0
+ spriteExtrude: 1
+ spriteMeshType: 1
+ alignment: 0
+ spritePivot: {x: 0.5, y: 0.5}
+ spritePixelsToUnits: 100
+ spriteBorder: {x: 0, y: 0, z: 0, w: 0}
+ spriteGenerateFallbackPhysicsShape: 1
+ alphaUsage: 1
+ alphaIsTransparency: 0
+ spriteTessellationDetail: -1
+ textureType: 0
+ textureShape: 1
+ singleChannelComponent: 0
+ flipbookRows: 1
+ flipbookColumns: 1
+ maxTextureSizeSet: 0
+ compressionQualitySet: 0
+ textureFormatSet: 0
+ ignorePngGamma: 0
+ applyGammaDecoding: 0
+ swizzle: 50462976
+ cookieLightType: 0
+ platformSettings:
+ - serializedVersion: 3
+ buildTarget: DefaultTexturePlatform
+ maxTextureSize: 256
+ resizeAlgorithm: 0
+ textureFormat: -1
+ textureCompression: 2
+ compressionQuality: 50
+ crunchedCompression: 0
+ allowsAlphaSplitting: 0
+ overridden: 0
+ ignorePlatformSupport: 0
+ androidETC2FallbackOverride: 0
+ forceMaximumCompressionQuality_BC6H_BC7: 0
+ - serializedVersion: 3
+ buildTarget: Standalone
+ maxTextureSize: 256
+ resizeAlgorithm: 0
+ textureFormat: 3
+ textureCompression: 1
+ compressionQuality: 50
+ crunchedCompression: 0
+ allowsAlphaSplitting: 0
+ overridden: 1
+ ignorePlatformSupport: 0
+ androidETC2FallbackOverride: 0
+ forceMaximumCompressionQuality_BC6H_BC7: 0
+ - serializedVersion: 3
+ buildTarget: Android
+ maxTextureSize: 2048
+ resizeAlgorithm: 0
+ textureFormat: -1
+ textureCompression: 1
+ compressionQuality: 50
+ crunchedCompression: 0
+ allowsAlphaSplitting: 0
+ overridden: 0
+ ignorePlatformSupport: 0
+ androidETC2FallbackOverride: 0
+ forceMaximumCompressionQuality_BC6H_BC7: 0
+ spriteSheet:
+ serializedVersion: 2
+ sprites: []
+ outline: []
+ physicsShape: []
+ bones: []
+ spriteID:
+ internalID: 0
+ vertices: []
+ indices:
+ edges: []
+ weights: []
+ secondaryTextures: []
+ nameFileIdTable: {}
+ mipmapLimitGroupName:
+ pSDRemoveMatte: 0
+ userData:
+ assetBundleName:
+ assetBundleVariant: