summaryrefslogtreecommitdiffstats
path: root/gpu_fft.cc
diff options
context:
space:
mode:
authoryum <yum.food.vr@gmail.com>2025-07-27 19:24:44 -0700
committeryum <yum.food.vr@gmail.com>2025-07-27 19:24:44 -0700
commitbe1184abe8fee3afe89286dba0e4924270ee29cc (patch)
treed507638e9f49190a745ccdd3f149ca9badb05cec /gpu_fft.cc
parent1b918bf7f834735669a309d537e0f8d0794c05eb (diff)
Add inverse fft code
Diffstat (limited to 'gpu_fft.cc')
-rw-r--r--gpu_fft.cc130
1 files changed, 117 insertions, 13 deletions
diff --git a/gpu_fft.cc b/gpu_fft.cc
index 4a7723e..f5dd97e 100644
--- a/gpu_fft.cc
+++ b/gpu_fft.cc
@@ -87,6 +87,7 @@ struct ShaderUniforms {
int num_stages_per_dim;
int span;
int butterfly_size;
+ bool inverse;
// This will be baked into a texture.
std::vector<std::vector<gpu_complex>> twiddle_factors;
// Precomputed stage twiddle factors
@@ -106,9 +107,10 @@ unsigned int reverse_digits(unsigned int n, unsigned int num_digits, unsigned in
return reversed;
}
-// Compute twiddle factor W_N^k = exp(-2*pi*i*k/N)
-gpu_complex twiddle_factor(int k, int N) {
- float angle = -2.0f * std::numbers::pi * k / N;
+// Compute twiddle factor W_N^k = exp(-2*pi*i*k/N) for forward FFT
+// or exp(+2*pi*i*k/N) for inverse FFT
+gpu_complex twiddle_factor(int k, int N, bool inverse = false) {
+ float angle = (inverse ? 2.0f : -2.0f) * std::numbers::pi * k / N;
return {std::cos(angle), std::sin(angle)};
}
@@ -195,12 +197,12 @@ void apply_2d_bit_reversal(const int n, const int radix, const stage_texture& in
}
// Precompute twiddle factors for a given radix
-std::vector<std::vector<gpu_complex>> compute_twiddle_factors(int radix) {
+std::vector<std::vector<gpu_complex>> compute_twiddle_factors(int radix, bool inverse = false) {
std::vector<std::vector<gpu_complex>> twiddle_factors(radix,
std::vector<gpu_complex>(radix));
for (int k = 0; k < radix; ++k) {
for (int n = 0; n < radix; ++n) {
- twiddle_factors[k][n] = twiddle_factor(k * n, radix);
+ twiddle_factors[k][n] = twiddle_factor(k * n, radix, inverse);
}
}
return twiddle_factors;
@@ -216,10 +218,10 @@ int int_pow(int base, int exp) {
}
// Precompute stage twiddle factors
-std::vector<gpu_complex> compute_stage_twiddles(int butterfly_size) {
+std::vector<gpu_complex> compute_stage_twiddles(int butterfly_size, bool inverse = false) {
std::vector<gpu_complex> stage_twiddles(butterfly_size);
for (int i = 0; i < butterfly_size; ++i) {
- stage_twiddles[i] = twiddle_factor(i, butterfly_size);
+ stage_twiddles[i] = twiddle_factor(i, butterfly_size, inverse);
}
return stage_twiddles;
}
@@ -228,9 +230,10 @@ std::vector<gpu_complex> compute_stage_twiddles(int butterfly_size) {
void evaluate_stages(
const int n,
const int radix,
+ const bool inverse,
std::vector<stage_texture>& textures) {
const std::vector<std::vector<gpu_complex>> twiddle_factors =
- compute_twiddle_factors(radix);
+ compute_twiddle_factors(radix, inverse);
const int num_stages_per_dim = std::log2(n) / std::log2(radix);
const int num_stages = num_stages_per_dim * 2;
@@ -242,16 +245,27 @@ void evaluate_stages(
int butterfly_size = span * radix;
// Precompute stage twiddle factors
- std::vector<gpu_complex> stage_twiddles = compute_stage_twiddles(butterfly_size);
+ std::vector<gpu_complex> stage_twiddles = compute_stage_twiddles(butterfly_size, inverse);
ShaderUniforms uniforms = {n, radix, stage, num_stages_per_dim, span, butterfly_size,
- twiddle_factors, stage_twiddles};
+ inverse, twiddle_factors, stage_twiddles};
evaluate_stage(uniforms, textures[stage], textures[stage+1]);
}
// Apply bit reversal once at the end
stage_texture temp = textures[num_stages];
apply_2d_bit_reversal(n, radix, temp, textures[num_stages]);
+
+ // For inverse FFT, normalize by 1/(n*n)
+ if (inverse) {
+ float norm_factor = 1.0f / (n * n);
+ for (int y = 0; y < n; ++y) {
+ for (int x = 0; x < n; ++x) {
+ textures[num_stages][y][x].first *= norm_factor;
+ textures[num_stages][y][x].second *= norm_factor;
+ }
+ }
+ }
}
bool check_result(
@@ -330,7 +344,7 @@ void print_diagnostics(
});
}
-bool evaluateAlgorithm(const int n, const int radix, std::mt19937& rng) {
+bool evaluateAlgorithm(const int n, const int radix, const bool inverse, std::mt19937& rng) {
const int NUM_STAGES = (std::log2(n) / std::log2(radix)) * 2;
const std::vector<std::vector<gpu_complex>> black_texture(n,
@@ -347,7 +361,7 @@ bool evaluateAlgorithm(const int n, const int radix, std::mt19937& rng) {
// Evaluate the GPU algorithm.
auto start = std::chrono::high_resolution_clock::now();
- evaluate_stages(n, radix, textures);
+ evaluate_stages(n, radix, inverse, textures);
auto end = std::chrono::high_resolution_clock::now();
auto duration = std::chrono::duration_cast<std::chrono::microseconds>(end - start);
@@ -383,9 +397,76 @@ bool evaluateAlgorithm(const int n, const int radix, std::mt19937& rng) {
return false;
}
+// Test FFT followed by inverse FFT
+bool testFFTInverse(const int n, const int radix, std::mt19937& rng) {
+ const int NUM_STAGES = (std::log2(n) / std::log2(radix)) * 2;
+
+ const std::vector<std::vector<gpu_complex>> black_texture(n,
+ std::vector<gpu_complex>(n, {0, 0}));
+ std::vector<stage_texture> textures_fft(NUM_STAGES + 1, black_texture);
+ std::vector<stage_texture> textures_ifft(NUM_STAGES + 1, black_texture);
+
+ // Initialize with random data
+ std::uniform_real_distribution<float> dist(0.0f, 1.0f);
+ for (int y = 0; y < n; ++y) {
+ for (int x = 0; x < n; ++x) {
+ textures_fft[0][y][x] = {dist(rng), dist(rng)};
+ }
+ }
+
+ // Save original input
+ stage_texture original_input = textures_fft[0];
+
+ // Perform forward FFT
+ evaluate_stages(n, radix, false, textures_fft);
+
+ // Copy FFT result as input to inverse FFT
+ textures_ifft[0] = textures_fft[NUM_STAGES];
+
+ // Perform inverse FFT
+ evaluate_stages(n, radix, true, textures_ifft);
+
+ // Check if inverse FFT gives back the original input
+ float max_error = 0.0f;
+ for (int y = 0; y < n; ++y) {
+ for (int x = 0; x < n; ++x) {
+ float err_real = std::abs(textures_ifft[NUM_STAGES][y][x].first - original_input[y][x].first);
+ float err_imag = std::abs(textures_ifft[NUM_STAGES][y][x].second - original_input[y][x].second);
+ max_error = std::max(max_error, std::max(err_real, err_imag));
+ }
+ }
+
+ const float epsilon = 1e-3; // Tolerance for round-trip error
+ bool success = (max_error < epsilon);
+
+ if (!success) {
+ std::cout << "FFT->IFFT round-trip test FAILED. Max error: " << max_error << std::endl;
+
+ // Print some diagnostics
+ std::cout << "\nFirst 4x4 block comparison:" << std::endl;
+ std::cout << "Original vs Reconstructed (real parts):" << std::endl;
+ for (int y = 0; y < std::min(4, n); ++y) {
+ for (int x = 0; x < std::min(4, n); ++x) {
+ std::cout << std::fixed << std::setprecision(3)
+ << original_input[y][x].first << " ";
+ }
+ std::cout << " | ";
+ for (int x = 0; x < std::min(4, n); ++x) {
+ std::cout << std::fixed << std::setprecision(3)
+ << textures_ifft[NUM_STAGES][y][x].first << " ";
+ }
+ std::cout << std::endl;
+ }
+ }
+
+ return success;
+}
+
int main() {
std::mt19937 rng(std::random_device{}());
+ // First run the original forward FFT tests
+ std::cout << "Testing forward FFT correctness against reference implementation..." << std::endl;
for (int log_radix = 1; log_radix < 5; ++log_radix) {
int radix = std::pow(2, log_radix);
for (int log_n = 1; log_n < 12; ++log_n) {
@@ -394,11 +475,34 @@ int main() {
break;
}
std::cout << "Testing radix=" << radix << " n=" << n << std::endl;
- if (!evaluateAlgorithm(n, radix, rng)) {
+ if (!evaluateAlgorithm(n, radix, false, rng)) {
+ return 1;
+ }
+ }
+ }
+
+ std::cout << "\nAll forward FFT tests passed!" << std::endl;
+
+ // Now run the FFT->IFFT round-trip tests
+ std::cout << "\nTesting FFT->IFFT round-trip correctness..." << std::endl;
+
+ for (int log_radix = 1; log_radix < 5; ++log_radix) {
+ int radix = std::pow(2, log_radix);
+ for (int log_n = 1; log_n < 10; ++log_n) {
+ int n = std::pow(radix, log_n);
+ if (n > 512) {
+ break;
+ }
+ std::cout << "Testing radix=" << radix << " n=" << n << " ... ";
+ if (testFFTInverse(n, radix, rng)) {
+ std::cout << "PASSED" << std::endl;
+ } else {
+ std::cout << "FAILED" << std::endl;
return 1;
}
}
}
+ std::cout << "\nAll FFT->IFFT tests passed!" << std::endl;
return 0;
}