diff options
| -rw-r--r-- | .gitignore | 2 | ||||
| -rw-r--r-- | fft.shader | 55 | ||||
| -rw-r--r-- | generate_twiddle_tables.py | 58 | ||||
| -rw-r--r-- | gpu_fft.cc | 130 |
4 files changed, 197 insertions, 48 deletions
@@ -1,3 +1,5 @@ build fft_twiddle_tables.cginc +.*.sw[po] +*.meta @@ -7,6 +7,7 @@ Shader "yum_food/fft" _Radix ("Radix", Int) = 16 _Stage ("Stage", Int) = 0 [Toggle] _PassThrough ("Pass Through", Float) = 0 + [Toggle] _Inverse ("Inverse FFT", Float) = 0 } SubShader { @@ -37,15 +38,16 @@ Shader "yum_food/fft" int num_stages_per_dim : TEXCOORD1; int span : TEXCOORD2; int butterfly_size : TEXCOORD3; + int num_stages : TEXCOORD4; }; texture2D _MainTex; - float4 _MainTex_ST; - SamplerState point_repeat_s; + SamplerState point_clamp_s; int _N; int _Radix; int _Stage; float _PassThrough; + float _Inverse; // Helper function to compute integer power int int_pow(int base, int exp) @@ -65,7 +67,7 @@ Shader "yum_food/fft" o.uv = v.uv; // Calculate num_stages_per_dim = log_radix(N) - o.num_stages_per_dim = (int)(log(_N) / log(_Radix) + 0.5); + o.num_stages_per_dim = (int)(log(_N) / log(_Radix)); // Determine current stage (0-based index within row or column passes) int current_stage = (_Stage < o.num_stages_per_dim) ? _Stage : (_Stage - o.num_stages_per_dim); @@ -83,15 +85,21 @@ Shader "yum_food/fft" fixed4 frag (v2f i) : SV_Target { + int2 pixel_index = (int2)(i.uv * _N); + float2 uv = (pixel_index + 0.5f) / _N; + // If pass through is enabled, just return the input if (_PassThrough > 0.5) { - return _MainTex.SampleLevel(point_repeat_s, i.uv, 0); +#if 0 + float lum = luminance(_MainTex.SampleLevel(point_clamp_s, uv, 0).rgb); +#else + float lum = luminance(_MainTex.SampleLevel(point_clamp_s, i.uv, 0).rgb); +#endif + return float4(lum, lum, lum, 1); } - const float n2 = _N * _N; // Calculate pixel index from UV coordinates - int2 pixel_index = int2(floor(i.uv * _N)); int x = pixel_index.x; int y = pixel_index.y; @@ -122,24 +130,23 @@ Shader "yum_food/fft" } else { - input_uv = float2(i.uv.x, (input_pos + 0.5) / (float)_N); + float xuv = (x + 0.5) / _N; + input_uv = float2(xuv, (input_pos + 0.5) / (float)_N); } // Read input value - float4 input_tex = _MainTex.SampleLevel(point_repeat_s, input_uv, 0); + float4 input_tex = _MainTex.SampleLevel(point_clamp_s, input_uv, 0); float2 input_val; if (_Stage == 0) { input_val.x = luminance(input_tex.xyz); input_val.y = 0; } else { - // Remap onto [-1, 1] - input_tex = input_tex * 2.0f - 1.0f; - input_val.x = input_tex.x * n2 + input_tex.y; - input_val.y = input_tex.z * n2 + input_tex.w; + input_val.x = input_tex.x + input_tex.y; + input_val.y = input_tex.z + input_tex.w; } - // Read DFT coefficient from the table - float2 coeff = DFT_MATRIX[wing][j]; + // Read DFT coefficient from the table (use inverse matrix if _Inverse is set) + float2 coeff = _Inverse > 0.5 ? IDFT_MATRIX[wing][j] : DFT_MATRIX[wing][j]; // Complex multiply-accumulate sum.x += coeff.x * input_val.x - coeff.y * input_val.y; @@ -150,7 +157,7 @@ Shader "yum_food/fft" if (wing > 0 && idx_in_wing > 0) { int twiddle_idx = wing * idx_in_wing; - float2 tw = STAGE_TWIDDLES[twiddle_idx]; + float2 tw = _Inverse > 0.5 ? STAGE_TWIDDLES_INV[twiddle_idx] : STAGE_TWIDDLES[twiddle_idx]; // Output = twiddle * sum float2 output; @@ -163,24 +170,18 @@ Shader "yum_food/fft" float real_part = sum.x; float imag_part = sum.y; + if (_Inverse > 0.5 && _Stage == i.num_stages_per_dim * 2 - 1) { + // Last stage of IFFT is just back to the original real-valued signal. + real_part /= _N * i.num_stages_per_dim; + return float4(real_part, real_part, real_part, 1); + } + // Split into 2 parts. float real_big = floor(real_part); float real_small = real_part - real_big; float imag_big = floor(imag_part); float imag_small = imag_part - imag_big; - // Compress onto [-1,1]. - // For an N*N FFT, the maximum value is N^2 and the min value - // is -N^2 / 2. - real_big /= n2; - imag_big /= n2; - - // Map onto [0, 1]. - real_big = (real_big + 1.0f) * 0.5f; - real_small = (real_small + 1.0f) * 0.5f; - imag_big = (imag_big + 1.0f) * 0.5f; - imag_small = (imag_small + 1.0f) * 0.5f; - return float4(real_big, real_small, imag_big, imag_small); } ENDCG diff --git a/generate_twiddle_tables.py b/generate_twiddle_tables.py index 69f9da3..152208c 100644 --- a/generate_twiddle_tables.py +++ b/generate_twiddle_tables.py @@ -8,18 +8,19 @@ Uses static branching with preprocessor directives. import math import os -def twiddle(k, N): - """Compute twiddle factor W_N^k = exp(-2*pi*i*k/N)""" - angle = -2.0 * math.pi * k / N +def twiddle(k, N, inverse=False): + """Compute twiddle factor W_N^k = exp(-2*pi*i*k/N) for forward FFT + or exp(+2*pi*i*k/N) for inverse FFT""" + angle = (2.0 if inverse else -2.0) * math.pi * k / N return complex(math.cos(angle), math.sin(angle)) -def generate_dft_matrix(radix): +def generate_dft_matrix(radix, inverse=False): """Generate DFT matrix for given radix""" matrix = [] for k in range(radix): row = [] for n in range(radix): - row.append(twiddle(k * n, radix)) + row.append(twiddle(k * n, radix, inverse)) matrix.append(row) return matrix @@ -35,7 +36,7 @@ def generate_cginc(radices, max_size, output_file): f.write("#ifndef FFT_TWIDDLE_TABLES_CGINC\n") f.write("#define FFT_TWIDDLE_TABLES_CGINC\n\n") - # Generate DFT matrices for each radix + # Generate DFT matrices for each radix (forward) for radix in radices: f.write(f"#if defined(GPU_FFT_RADIX{radix})\n") f.write(f"static const float2 DFT_MATRIX[{radix}][{radix}] = {{\n") @@ -52,10 +53,25 @@ def generate_cginc(radices, max_size, output_file): f.write(",") f.write("\n") f.write("};\n") + + # Generate inverse DFT matrix + f.write(f"static const float2 IDFT_MATRIX[{radix}][{radix}] = {{\n") + matrix_inv = generate_dft_matrix(radix, inverse=True) + for k in range(radix): + f.write(" { ") + for n in range(radix): + f.write(format_complex(matrix_inv[k][n])) + if n < radix - 1: + f.write(", ") + f.write(" }") + if k < radix - 1: + f.write(",") + f.write("\n") + f.write("};\n") f.write("#endif\n\n") - # Generate stage twiddle factor tables - f.write("// Stage twiddle factors\n") + # Generate stage twiddle factor tables (forward) + f.write("// Stage twiddle factors (forward FFT)\n") for radix in radices: butterfly_size = radix stage_idx = 0 @@ -80,6 +96,32 @@ def generate_cginc(radices, max_size, output_file): butterfly_size *= radix stage_idx += 1 + # Generate stage twiddle factor tables (inverse) + f.write("// Stage twiddle factors (inverse FFT)\n") + for radix in radices: + butterfly_size = radix + stage_idx = 0 + + while butterfly_size <= max_size: + f.write(f"#if defined(GPU_FFT_RADIX{radix}_SIZE{butterfly_size})\n") + f.write(f"static const float2 STAGE_TWIDDLES_INV[{butterfly_size}] = {{\n") + + # Write twiddles in rows of 4 for readability + for i in range(0, butterfly_size, 4): + f.write(" ") + for j in range(4): + if i + j < butterfly_size: + c = twiddle(i + j, butterfly_size, inverse=True) + f.write(format_complex(c)) + if i + j < butterfly_size - 1: + f.write(", ") + f.write("\n") + + f.write("};\n") + f.write("#endif\n\n") + butterfly_size *= radix + stage_idx += 1 + f.write("#endif // FFT_TWIDDLE_TABLES_CGINC\n\n") def main(): @@ -87,6 +87,7 @@ struct ShaderUniforms { int num_stages_per_dim; int span; int butterfly_size; + bool inverse; // This will be baked into a texture. std::vector<std::vector<gpu_complex>> twiddle_factors; // Precomputed stage twiddle factors @@ -106,9 +107,10 @@ unsigned int reverse_digits(unsigned int n, unsigned int num_digits, unsigned in return reversed; } -// Compute twiddle factor W_N^k = exp(-2*pi*i*k/N) -gpu_complex twiddle_factor(int k, int N) { - float angle = -2.0f * std::numbers::pi * k / N; +// Compute twiddle factor W_N^k = exp(-2*pi*i*k/N) for forward FFT +// or exp(+2*pi*i*k/N) for inverse FFT +gpu_complex twiddle_factor(int k, int N, bool inverse = false) { + float angle = (inverse ? 2.0f : -2.0f) * std::numbers::pi * k / N; return {std::cos(angle), std::sin(angle)}; } @@ -195,12 +197,12 @@ void apply_2d_bit_reversal(const int n, const int radix, const stage_texture& in } // Precompute twiddle factors for a given radix -std::vector<std::vector<gpu_complex>> compute_twiddle_factors(int radix) { +std::vector<std::vector<gpu_complex>> compute_twiddle_factors(int radix, bool inverse = false) { std::vector<std::vector<gpu_complex>> twiddle_factors(radix, std::vector<gpu_complex>(radix)); for (int k = 0; k < radix; ++k) { for (int n = 0; n < radix; ++n) { - twiddle_factors[k][n] = twiddle_factor(k * n, radix); + twiddle_factors[k][n] = twiddle_factor(k * n, radix, inverse); } } return twiddle_factors; @@ -216,10 +218,10 @@ int int_pow(int base, int exp) { } // Precompute stage twiddle factors -std::vector<gpu_complex> compute_stage_twiddles(int butterfly_size) { +std::vector<gpu_complex> compute_stage_twiddles(int butterfly_size, bool inverse = false) { std::vector<gpu_complex> stage_twiddles(butterfly_size); for (int i = 0; i < butterfly_size; ++i) { - stage_twiddles[i] = twiddle_factor(i, butterfly_size); + stage_twiddles[i] = twiddle_factor(i, butterfly_size, inverse); } return stage_twiddles; } @@ -228,9 +230,10 @@ std::vector<gpu_complex> compute_stage_twiddles(int butterfly_size) { void evaluate_stages( const int n, const int radix, + const bool inverse, std::vector<stage_texture>& textures) { const std::vector<std::vector<gpu_complex>> twiddle_factors = - compute_twiddle_factors(radix); + compute_twiddle_factors(radix, inverse); const int num_stages_per_dim = std::log2(n) / std::log2(radix); const int num_stages = num_stages_per_dim * 2; @@ -242,16 +245,27 @@ void evaluate_stages( int butterfly_size = span * radix; // Precompute stage twiddle factors - std::vector<gpu_complex> stage_twiddles = compute_stage_twiddles(butterfly_size); + std::vector<gpu_complex> stage_twiddles = compute_stage_twiddles(butterfly_size, inverse); ShaderUniforms uniforms = {n, radix, stage, num_stages_per_dim, span, butterfly_size, - twiddle_factors, stage_twiddles}; + inverse, twiddle_factors, stage_twiddles}; evaluate_stage(uniforms, textures[stage], textures[stage+1]); } // Apply bit reversal once at the end stage_texture temp = textures[num_stages]; apply_2d_bit_reversal(n, radix, temp, textures[num_stages]); + + // For inverse FFT, normalize by 1/(n*n) + if (inverse) { + float norm_factor = 1.0f / (n * n); + for (int y = 0; y < n; ++y) { + for (int x = 0; x < n; ++x) { + textures[num_stages][y][x].first *= norm_factor; + textures[num_stages][y][x].second *= norm_factor; + } + } + } } bool check_result( @@ -330,7 +344,7 @@ void print_diagnostics( }); } -bool evaluateAlgorithm(const int n, const int radix, std::mt19937& rng) { +bool evaluateAlgorithm(const int n, const int radix, const bool inverse, std::mt19937& rng) { const int NUM_STAGES = (std::log2(n) / std::log2(radix)) * 2; const std::vector<std::vector<gpu_complex>> black_texture(n, @@ -347,7 +361,7 @@ bool evaluateAlgorithm(const int n, const int radix, std::mt19937& rng) { // Evaluate the GPU algorithm. auto start = std::chrono::high_resolution_clock::now(); - evaluate_stages(n, radix, textures); + evaluate_stages(n, radix, inverse, textures); auto end = std::chrono::high_resolution_clock::now(); auto duration = std::chrono::duration_cast<std::chrono::microseconds>(end - start); @@ -383,9 +397,76 @@ bool evaluateAlgorithm(const int n, const int radix, std::mt19937& rng) { return false; } +// Test FFT followed by inverse FFT +bool testFFTInverse(const int n, const int radix, std::mt19937& rng) { + const int NUM_STAGES = (std::log2(n) / std::log2(radix)) * 2; + + const std::vector<std::vector<gpu_complex>> black_texture(n, + std::vector<gpu_complex>(n, {0, 0})); + std::vector<stage_texture> textures_fft(NUM_STAGES + 1, black_texture); + std::vector<stage_texture> textures_ifft(NUM_STAGES + 1, black_texture); + + // Initialize with random data + std::uniform_real_distribution<float> dist(0.0f, 1.0f); + for (int y = 0; y < n; ++y) { + for (int x = 0; x < n; ++x) { + textures_fft[0][y][x] = {dist(rng), dist(rng)}; + } + } + + // Save original input + stage_texture original_input = textures_fft[0]; + + // Perform forward FFT + evaluate_stages(n, radix, false, textures_fft); + + // Copy FFT result as input to inverse FFT + textures_ifft[0] = textures_fft[NUM_STAGES]; + + // Perform inverse FFT + evaluate_stages(n, radix, true, textures_ifft); + + // Check if inverse FFT gives back the original input + float max_error = 0.0f; + for (int y = 0; y < n; ++y) { + for (int x = 0; x < n; ++x) { + float err_real = std::abs(textures_ifft[NUM_STAGES][y][x].first - original_input[y][x].first); + float err_imag = std::abs(textures_ifft[NUM_STAGES][y][x].second - original_input[y][x].second); + max_error = std::max(max_error, std::max(err_real, err_imag)); + } + } + + const float epsilon = 1e-3; // Tolerance for round-trip error + bool success = (max_error < epsilon); + + if (!success) { + std::cout << "FFT->IFFT round-trip test FAILED. Max error: " << max_error << std::endl; + + // Print some diagnostics + std::cout << "\nFirst 4x4 block comparison:" << std::endl; + std::cout << "Original vs Reconstructed (real parts):" << std::endl; + for (int y = 0; y < std::min(4, n); ++y) { + for (int x = 0; x < std::min(4, n); ++x) { + std::cout << std::fixed << std::setprecision(3) + << original_input[y][x].first << " "; + } + std::cout << " | "; + for (int x = 0; x < std::min(4, n); ++x) { + std::cout << std::fixed << std::setprecision(3) + << textures_ifft[NUM_STAGES][y][x].first << " "; + } + std::cout << std::endl; + } + } + + return success; +} + int main() { std::mt19937 rng(std::random_device{}()); + // First run the original forward FFT tests + std::cout << "Testing forward FFT correctness against reference implementation..." << std::endl; for (int log_radix = 1; log_radix < 5; ++log_radix) { int radix = std::pow(2, log_radix); for (int log_n = 1; log_n < 12; ++log_n) { @@ -394,11 +475,34 @@ int main() { break; } std::cout << "Testing radix=" << radix << " n=" << n << std::endl; - if (!evaluateAlgorithm(n, radix, rng)) { + if (!evaluateAlgorithm(n, radix, false, rng)) { + return 1; + } + } + } + + std::cout << "\nAll forward FFT tests passed!" << std::endl; + + // Now run the FFT->IFFT round-trip tests + std::cout << "\nTesting FFT->IFFT round-trip correctness..." << std::endl; + + for (int log_radix = 1; log_radix < 5; ++log_radix) { + int radix = std::pow(2, log_radix); + for (int log_n = 1; log_n < 10; ++log_n) { + int n = std::pow(radix, log_n); + if (n > 512) { + break; + } + std::cout << "Testing radix=" << radix << " n=" << n << " ... "; + if (testFFTInverse(n, radix, rng)) { + std::cout << "PASSED" << std::endl; + } else { + std::cout << "FAILED" << std::endl; return 1; } } } + std::cout << "\nAll FFT->IFFT tests passed!" << std::endl; return 0; } |
