summaryrefslogtreecommitdiffstats
diff options
context:
space:
mode:
-rw-r--r--.gitignore2
-rw-r--r--fft.shader55
-rw-r--r--generate_twiddle_tables.py58
-rw-r--r--gpu_fft.cc130
4 files changed, 197 insertions, 48 deletions
diff --git a/.gitignore b/.gitignore
index 2545bfa..1f02462 100644
--- a/.gitignore
+++ b/.gitignore
@@ -1,3 +1,5 @@
build
fft_twiddle_tables.cginc
+.*.sw[po]
+*.meta
diff --git a/fft.shader b/fft.shader
index 3ce9b61..97d5bec 100644
--- a/fft.shader
+++ b/fft.shader
@@ -7,6 +7,7 @@ Shader "yum_food/fft"
_Radix ("Radix", Int) = 16
_Stage ("Stage", Int) = 0
[Toggle] _PassThrough ("Pass Through", Float) = 0
+ [Toggle] _Inverse ("Inverse FFT", Float) = 0
}
SubShader
{
@@ -37,15 +38,16 @@ Shader "yum_food/fft"
int num_stages_per_dim : TEXCOORD1;
int span : TEXCOORD2;
int butterfly_size : TEXCOORD3;
+ int num_stages : TEXCOORD4;
};
texture2D _MainTex;
- float4 _MainTex_ST;
- SamplerState point_repeat_s;
+ SamplerState point_clamp_s;
int _N;
int _Radix;
int _Stage;
float _PassThrough;
+ float _Inverse;
// Helper function to compute integer power
int int_pow(int base, int exp)
@@ -65,7 +67,7 @@ Shader "yum_food/fft"
o.uv = v.uv;
// Calculate num_stages_per_dim = log_radix(N)
- o.num_stages_per_dim = (int)(log(_N) / log(_Radix) + 0.5);
+ o.num_stages_per_dim = (int)(log(_N) / log(_Radix));
// Determine current stage (0-based index within row or column passes)
int current_stage = (_Stage < o.num_stages_per_dim) ? _Stage : (_Stage - o.num_stages_per_dim);
@@ -83,15 +85,21 @@ Shader "yum_food/fft"
fixed4 frag (v2f i) : SV_Target
{
+ int2 pixel_index = (int2)(i.uv * _N);
+ float2 uv = (pixel_index + 0.5f) / _N;
+
// If pass through is enabled, just return the input
if (_PassThrough > 0.5)
{
- return _MainTex.SampleLevel(point_repeat_s, i.uv, 0);
+#if 0
+ float lum = luminance(_MainTex.SampleLevel(point_clamp_s, uv, 0).rgb);
+#else
+ float lum = luminance(_MainTex.SampleLevel(point_clamp_s, i.uv, 0).rgb);
+#endif
+ return float4(lum, lum, lum, 1);
}
- const float n2 = _N * _N;
// Calculate pixel index from UV coordinates
- int2 pixel_index = int2(floor(i.uv * _N));
int x = pixel_index.x;
int y = pixel_index.y;
@@ -122,24 +130,23 @@ Shader "yum_food/fft"
}
else
{
- input_uv = float2(i.uv.x, (input_pos + 0.5) / (float)_N);
+ float xuv = (x + 0.5) / _N;
+ input_uv = float2(xuv, (input_pos + 0.5) / (float)_N);
}
// Read input value
- float4 input_tex = _MainTex.SampleLevel(point_repeat_s, input_uv, 0);
+ float4 input_tex = _MainTex.SampleLevel(point_clamp_s, input_uv, 0);
float2 input_val;
if (_Stage == 0) {
input_val.x = luminance(input_tex.xyz);
input_val.y = 0;
} else {
- // Remap onto [-1, 1]
- input_tex = input_tex * 2.0f - 1.0f;
- input_val.x = input_tex.x * n2 + input_tex.y;
- input_val.y = input_tex.z * n2 + input_tex.w;
+ input_val.x = input_tex.x + input_tex.y;
+ input_val.y = input_tex.z + input_tex.w;
}
- // Read DFT coefficient from the table
- float2 coeff = DFT_MATRIX[wing][j];
+ // Read DFT coefficient from the table (use inverse matrix if _Inverse is set)
+ float2 coeff = _Inverse > 0.5 ? IDFT_MATRIX[wing][j] : DFT_MATRIX[wing][j];
// Complex multiply-accumulate
sum.x += coeff.x * input_val.x - coeff.y * input_val.y;
@@ -150,7 +157,7 @@ Shader "yum_food/fft"
if (wing > 0 && idx_in_wing > 0)
{
int twiddle_idx = wing * idx_in_wing;
- float2 tw = STAGE_TWIDDLES[twiddle_idx];
+ float2 tw = _Inverse > 0.5 ? STAGE_TWIDDLES_INV[twiddle_idx] : STAGE_TWIDDLES[twiddle_idx];
// Output = twiddle * sum
float2 output;
@@ -163,24 +170,18 @@ Shader "yum_food/fft"
float real_part = sum.x;
float imag_part = sum.y;
+ if (_Inverse > 0.5 && _Stage == i.num_stages_per_dim * 2 - 1) {
+ // Last stage of IFFT is just back to the original real-valued signal.
+ real_part /= _N * i.num_stages_per_dim;
+ return float4(real_part, real_part, real_part, 1);
+ }
+
// Split into 2 parts.
float real_big = floor(real_part);
float real_small = real_part - real_big;
float imag_big = floor(imag_part);
float imag_small = imag_part - imag_big;
- // Compress onto [-1,1].
- // For an N*N FFT, the maximum value is N^2 and the min value
- // is -N^2 / 2.
- real_big /= n2;
- imag_big /= n2;
-
- // Map onto [0, 1].
- real_big = (real_big + 1.0f) * 0.5f;
- real_small = (real_small + 1.0f) * 0.5f;
- imag_big = (imag_big + 1.0f) * 0.5f;
- imag_small = (imag_small + 1.0f) * 0.5f;
-
return float4(real_big, real_small, imag_big, imag_small);
}
ENDCG
diff --git a/generate_twiddle_tables.py b/generate_twiddle_tables.py
index 69f9da3..152208c 100644
--- a/generate_twiddle_tables.py
+++ b/generate_twiddle_tables.py
@@ -8,18 +8,19 @@ Uses static branching with preprocessor directives.
import math
import os
-def twiddle(k, N):
- """Compute twiddle factor W_N^k = exp(-2*pi*i*k/N)"""
- angle = -2.0 * math.pi * k / N
+def twiddle(k, N, inverse=False):
+ """Compute twiddle factor W_N^k = exp(-2*pi*i*k/N) for forward FFT
+ or exp(+2*pi*i*k/N) for inverse FFT"""
+ angle = (2.0 if inverse else -2.0) * math.pi * k / N
return complex(math.cos(angle), math.sin(angle))
-def generate_dft_matrix(radix):
+def generate_dft_matrix(radix, inverse=False):
"""Generate DFT matrix for given radix"""
matrix = []
for k in range(radix):
row = []
for n in range(radix):
- row.append(twiddle(k * n, radix))
+ row.append(twiddle(k * n, radix, inverse))
matrix.append(row)
return matrix
@@ -35,7 +36,7 @@ def generate_cginc(radices, max_size, output_file):
f.write("#ifndef FFT_TWIDDLE_TABLES_CGINC\n")
f.write("#define FFT_TWIDDLE_TABLES_CGINC\n\n")
- # Generate DFT matrices for each radix
+ # Generate DFT matrices for each radix (forward)
for radix in radices:
f.write(f"#if defined(GPU_FFT_RADIX{radix})\n")
f.write(f"static const float2 DFT_MATRIX[{radix}][{radix}] = {{\n")
@@ -52,10 +53,25 @@ def generate_cginc(radices, max_size, output_file):
f.write(",")
f.write("\n")
f.write("};\n")
+
+ # Generate inverse DFT matrix
+ f.write(f"static const float2 IDFT_MATRIX[{radix}][{radix}] = {{\n")
+ matrix_inv = generate_dft_matrix(radix, inverse=True)
+ for k in range(radix):
+ f.write(" { ")
+ for n in range(radix):
+ f.write(format_complex(matrix_inv[k][n]))
+ if n < radix - 1:
+ f.write(", ")
+ f.write(" }")
+ if k < radix - 1:
+ f.write(",")
+ f.write("\n")
+ f.write("};\n")
f.write("#endif\n\n")
- # Generate stage twiddle factor tables
- f.write("// Stage twiddle factors\n")
+ # Generate stage twiddle factor tables (forward)
+ f.write("// Stage twiddle factors (forward FFT)\n")
for radix in radices:
butterfly_size = radix
stage_idx = 0
@@ -80,6 +96,32 @@ def generate_cginc(radices, max_size, output_file):
butterfly_size *= radix
stage_idx += 1
+ # Generate stage twiddle factor tables (inverse)
+ f.write("// Stage twiddle factors (inverse FFT)\n")
+ for radix in radices:
+ butterfly_size = radix
+ stage_idx = 0
+
+ while butterfly_size <= max_size:
+ f.write(f"#if defined(GPU_FFT_RADIX{radix}_SIZE{butterfly_size})\n")
+ f.write(f"static const float2 STAGE_TWIDDLES_INV[{butterfly_size}] = {{\n")
+
+ # Write twiddles in rows of 4 for readability
+ for i in range(0, butterfly_size, 4):
+ f.write(" ")
+ for j in range(4):
+ if i + j < butterfly_size:
+ c = twiddle(i + j, butterfly_size, inverse=True)
+ f.write(format_complex(c))
+ if i + j < butterfly_size - 1:
+ f.write(", ")
+ f.write("\n")
+
+ f.write("};\n")
+ f.write("#endif\n\n")
+ butterfly_size *= radix
+ stage_idx += 1
+
f.write("#endif // FFT_TWIDDLE_TABLES_CGINC\n\n")
def main():
diff --git a/gpu_fft.cc b/gpu_fft.cc
index 4a7723e..f5dd97e 100644
--- a/gpu_fft.cc
+++ b/gpu_fft.cc
@@ -87,6 +87,7 @@ struct ShaderUniforms {
int num_stages_per_dim;
int span;
int butterfly_size;
+ bool inverse;
// This will be baked into a texture.
std::vector<std::vector<gpu_complex>> twiddle_factors;
// Precomputed stage twiddle factors
@@ -106,9 +107,10 @@ unsigned int reverse_digits(unsigned int n, unsigned int num_digits, unsigned in
return reversed;
}
-// Compute twiddle factor W_N^k = exp(-2*pi*i*k/N)
-gpu_complex twiddle_factor(int k, int N) {
- float angle = -2.0f * std::numbers::pi * k / N;
+// Compute twiddle factor W_N^k = exp(-2*pi*i*k/N) for forward FFT
+// or exp(+2*pi*i*k/N) for inverse FFT
+gpu_complex twiddle_factor(int k, int N, bool inverse = false) {
+ float angle = (inverse ? 2.0f : -2.0f) * std::numbers::pi * k / N;
return {std::cos(angle), std::sin(angle)};
}
@@ -195,12 +197,12 @@ void apply_2d_bit_reversal(const int n, const int radix, const stage_texture& in
}
// Precompute twiddle factors for a given radix
-std::vector<std::vector<gpu_complex>> compute_twiddle_factors(int radix) {
+std::vector<std::vector<gpu_complex>> compute_twiddle_factors(int radix, bool inverse = false) {
std::vector<std::vector<gpu_complex>> twiddle_factors(radix,
std::vector<gpu_complex>(radix));
for (int k = 0; k < radix; ++k) {
for (int n = 0; n < radix; ++n) {
- twiddle_factors[k][n] = twiddle_factor(k * n, radix);
+ twiddle_factors[k][n] = twiddle_factor(k * n, radix, inverse);
}
}
return twiddle_factors;
@@ -216,10 +218,10 @@ int int_pow(int base, int exp) {
}
// Precompute stage twiddle factors
-std::vector<gpu_complex> compute_stage_twiddles(int butterfly_size) {
+std::vector<gpu_complex> compute_stage_twiddles(int butterfly_size, bool inverse = false) {
std::vector<gpu_complex> stage_twiddles(butterfly_size);
for (int i = 0; i < butterfly_size; ++i) {
- stage_twiddles[i] = twiddle_factor(i, butterfly_size);
+ stage_twiddles[i] = twiddle_factor(i, butterfly_size, inverse);
}
return stage_twiddles;
}
@@ -228,9 +230,10 @@ std::vector<gpu_complex> compute_stage_twiddles(int butterfly_size) {
void evaluate_stages(
const int n,
const int radix,
+ const bool inverse,
std::vector<stage_texture>& textures) {
const std::vector<std::vector<gpu_complex>> twiddle_factors =
- compute_twiddle_factors(radix);
+ compute_twiddle_factors(radix, inverse);
const int num_stages_per_dim = std::log2(n) / std::log2(radix);
const int num_stages = num_stages_per_dim * 2;
@@ -242,16 +245,27 @@ void evaluate_stages(
int butterfly_size = span * radix;
// Precompute stage twiddle factors
- std::vector<gpu_complex> stage_twiddles = compute_stage_twiddles(butterfly_size);
+ std::vector<gpu_complex> stage_twiddles = compute_stage_twiddles(butterfly_size, inverse);
ShaderUniforms uniforms = {n, radix, stage, num_stages_per_dim, span, butterfly_size,
- twiddle_factors, stage_twiddles};
+ inverse, twiddle_factors, stage_twiddles};
evaluate_stage(uniforms, textures[stage], textures[stage+1]);
}
// Apply bit reversal once at the end
stage_texture temp = textures[num_stages];
apply_2d_bit_reversal(n, radix, temp, textures[num_stages]);
+
+ // For inverse FFT, normalize by 1/(n*n)
+ if (inverse) {
+ float norm_factor = 1.0f / (n * n);
+ for (int y = 0; y < n; ++y) {
+ for (int x = 0; x < n; ++x) {
+ textures[num_stages][y][x].first *= norm_factor;
+ textures[num_stages][y][x].second *= norm_factor;
+ }
+ }
+ }
}
bool check_result(
@@ -330,7 +344,7 @@ void print_diagnostics(
});
}
-bool evaluateAlgorithm(const int n, const int radix, std::mt19937& rng) {
+bool evaluateAlgorithm(const int n, const int radix, const bool inverse, std::mt19937& rng) {
const int NUM_STAGES = (std::log2(n) / std::log2(radix)) * 2;
const std::vector<std::vector<gpu_complex>> black_texture(n,
@@ -347,7 +361,7 @@ bool evaluateAlgorithm(const int n, const int radix, std::mt19937& rng) {
// Evaluate the GPU algorithm.
auto start = std::chrono::high_resolution_clock::now();
- evaluate_stages(n, radix, textures);
+ evaluate_stages(n, radix, inverse, textures);
auto end = std::chrono::high_resolution_clock::now();
auto duration = std::chrono::duration_cast<std::chrono::microseconds>(end - start);
@@ -383,9 +397,76 @@ bool evaluateAlgorithm(const int n, const int radix, std::mt19937& rng) {
return false;
}
+// Test FFT followed by inverse FFT
+bool testFFTInverse(const int n, const int radix, std::mt19937& rng) {
+ const int NUM_STAGES = (std::log2(n) / std::log2(radix)) * 2;
+
+ const std::vector<std::vector<gpu_complex>> black_texture(n,
+ std::vector<gpu_complex>(n, {0, 0}));
+ std::vector<stage_texture> textures_fft(NUM_STAGES + 1, black_texture);
+ std::vector<stage_texture> textures_ifft(NUM_STAGES + 1, black_texture);
+
+ // Initialize with random data
+ std::uniform_real_distribution<float> dist(0.0f, 1.0f);
+ for (int y = 0; y < n; ++y) {
+ for (int x = 0; x < n; ++x) {
+ textures_fft[0][y][x] = {dist(rng), dist(rng)};
+ }
+ }
+
+ // Save original input
+ stage_texture original_input = textures_fft[0];
+
+ // Perform forward FFT
+ evaluate_stages(n, radix, false, textures_fft);
+
+ // Copy FFT result as input to inverse FFT
+ textures_ifft[0] = textures_fft[NUM_STAGES];
+
+ // Perform inverse FFT
+ evaluate_stages(n, radix, true, textures_ifft);
+
+ // Check if inverse FFT gives back the original input
+ float max_error = 0.0f;
+ for (int y = 0; y < n; ++y) {
+ for (int x = 0; x < n; ++x) {
+ float err_real = std::abs(textures_ifft[NUM_STAGES][y][x].first - original_input[y][x].first);
+ float err_imag = std::abs(textures_ifft[NUM_STAGES][y][x].second - original_input[y][x].second);
+ max_error = std::max(max_error, std::max(err_real, err_imag));
+ }
+ }
+
+ const float epsilon = 1e-3; // Tolerance for round-trip error
+ bool success = (max_error < epsilon);
+
+ if (!success) {
+ std::cout << "FFT->IFFT round-trip test FAILED. Max error: " << max_error << std::endl;
+
+ // Print some diagnostics
+ std::cout << "\nFirst 4x4 block comparison:" << std::endl;
+ std::cout << "Original vs Reconstructed (real parts):" << std::endl;
+ for (int y = 0; y < std::min(4, n); ++y) {
+ for (int x = 0; x < std::min(4, n); ++x) {
+ std::cout << std::fixed << std::setprecision(3)
+ << original_input[y][x].first << " ";
+ }
+ std::cout << " | ";
+ for (int x = 0; x < std::min(4, n); ++x) {
+ std::cout << std::fixed << std::setprecision(3)
+ << textures_ifft[NUM_STAGES][y][x].first << " ";
+ }
+ std::cout << std::endl;
+ }
+ }
+
+ return success;
+}
+
int main() {
std::mt19937 rng(std::random_device{}());
+ // First run the original forward FFT tests
+ std::cout << "Testing forward FFT correctness against reference implementation..." << std::endl;
for (int log_radix = 1; log_radix < 5; ++log_radix) {
int radix = std::pow(2, log_radix);
for (int log_n = 1; log_n < 12; ++log_n) {
@@ -394,11 +475,34 @@ int main() {
break;
}
std::cout << "Testing radix=" << radix << " n=" << n << std::endl;
- if (!evaluateAlgorithm(n, radix, rng)) {
+ if (!evaluateAlgorithm(n, radix, false, rng)) {
+ return 1;
+ }
+ }
+ }
+
+ std::cout << "\nAll forward FFT tests passed!" << std::endl;
+
+ // Now run the FFT->IFFT round-trip tests
+ std::cout << "\nTesting FFT->IFFT round-trip correctness..." << std::endl;
+
+ for (int log_radix = 1; log_radix < 5; ++log_radix) {
+ int radix = std::pow(2, log_radix);
+ for (int log_n = 1; log_n < 10; ++log_n) {
+ int n = std::pow(radix, log_n);
+ if (n > 512) {
+ break;
+ }
+ std::cout << "Testing radix=" << radix << " n=" << n << " ... ";
+ if (testFFTInverse(n, radix, rng)) {
+ std::cout << "PASSED" << std::endl;
+ } else {
+ std::cout << "FAILED" << std::endl;
return 1;
}
}
}
+ std::cout << "\nAll FFT->IFFT tests passed!" << std::endl;
return 0;
}