summaryrefslogtreecommitdiffstats
path: root/fft.shader
diff options
context:
space:
mode:
Diffstat (limited to 'fft.shader')
-rw-r--r--fft.shader148
1 files changed, 94 insertions, 54 deletions
diff --git a/fft.shader b/fft.shader
index 7684853..a6f7d28 100644
--- a/fft.shader
+++ b/fft.shader
@@ -6,9 +6,11 @@ Shader "yum_food/fft"
_N ("N", Int) = 256
_Radix ("Radix", Int) = 16
_Stage ("Stage", Int) = 0
- [Toggle] _PassThrough ("Pass Through", Float) = 0
+ [Toggle] _Passthrough ("Pass Through", Float) = 0
[Toggle] _LDS ("Temporal LDS", Float) = 0
+ [Toggle] _Luminance ("Luminance", Float) = 0
[Toggle] _Inverse ("Inverse FFT", Float) = 0
+ [Toggle] _BitReversal ("Bit Reversal Only", Float) = 0
}
SubShader
{
@@ -23,7 +25,7 @@ Shader "yum_food/fft"
#include "UnityCG.cginc"
#define GPU_FFT_RADIX16
- #define GPU_FFT_RADIX16_SIZE256
+ #define GPU_FFT_RADIX16_N256
#include "fft_twiddle_tables.cginc"
struct appdata
@@ -36,7 +38,7 @@ Shader "yum_food/fft"
{
float2 uv : TEXCOORD0;
float4 vertex : SV_POSITION;
- int num_stages_per_dim : TEXCOORD1;
+ uint num_stages_per_dim : TEXCOORD1;
int span : TEXCOORD2;
int butterfly_size : TEXCOORD3;
int num_stages : TEXCOORD4;
@@ -47,9 +49,11 @@ Shader "yum_food/fft"
int _N;
int _Radix;
int _Stage;
- float _PassThrough;
+ float _Passthrough;
float _LDS;
+ float _Luminance;
float _Inverse;
+ float _BitReversal;
#define PHI 1.618033988749894
@@ -64,6 +68,21 @@ Shader "yum_food/fft"
return result;
}
+ // Generalized digit reversal for any radix
+ uint reverse_digits(uint n, uint num_digits, uint radix)
+ {
+ uint bits_per_digit = (uint)(log2(radix));
+ uint digit_mask = radix - 1;
+ uint reversed = 0;
+
+ for (uint i = 0; i < num_digits; i++)
+ {
+ uint digit = (n >> (bits_per_digit * i)) & digit_mask;
+ reversed |= digit << (bits_per_digit * (num_digits - 1 - i));
+ }
+ return reversed;
+ }
+
v2f vert (appdata v)
{
v2f o;
@@ -71,7 +90,7 @@ Shader "yum_food/fft"
o.uv = v.uv;
// Calculate num_stages_per_dim = log_radix(N)
- o.num_stages_per_dim = (int)(log(_N) / log(_Radix));
+ o.num_stages_per_dim = (uint)(log(_N) / log(_Radix));
// Determine current stage (0-based index within row or column passes)
int current_stage = (_Stage < o.num_stages_per_dim) ? _Stage : (_Stage - o.num_stages_per_dim);
@@ -89,37 +108,51 @@ Shader "yum_food/fft"
fixed4 frag (v2f i) : SV_Target
{
+ // Extract coordinates
int2 pixel_index = (int2)(i.uv * _N);
- float2 uv = (pixel_index + 0.5f) / _N;
+ int x = pixel_index.x;
+ int y = pixel_index.y;
- // If pass through is enabled, just return the input
- if (_PassThrough > 0.5)
+ // Bit reversal mode
+ if (_BitReversal > 0.5)
{
+ uint num_digits = i.num_stages_per_dim;
+ uint rev_x = reverse_digits((uint)x, num_digits, (uint)_Radix);
+ uint rev_y = reverse_digits((uint)y, num_digits, (uint)_Radix);
+
+ float2 rev_uv = float2((rev_x + 0.5) / (float) _N, (rev_y + 0.5) / (float) _N);
+ float4 col = _MainTex.SampleLevel(point_clamp_s, rev_uv, 0);
+ return col;
+ }
+
+ // Pass through mode
+ if (_Passthrough > 0.5)
+ {
+ float2 uv = (pixel_index + 0.5f) / _N;
float3 col = _MainTex.SampleLevel(point_clamp_s, uv, 0).rgb;
if (_LDS > 0.5) {
col += PHI * _Time[0];
col = frac(col);
}
- float lum = luminance(col);
- return float4(lum, lum, lum, 1);
+ if (_Luminance > 0.5) {
+ col = luminance(col);
+ }
+ return float4(col, 1);
}
- // Calculate pixel index from UV coordinates
- int x = pixel_index.x;
- int y = pixel_index.y;
-
- // Determine processing direction (row stage or column stage)
+ // Determine processing direction
bool is_row_stage = (_Stage < i.num_stages_per_dim);
int coord = is_row_stage ? x : y;
- // Calculate butterfly indices
+ // Calculate butterfly indices (simple integer math)
int group = coord / i.butterfly_size;
int idx_in_group = coord % i.butterfly_size;
int wing = idx_in_group / i.span;
int idx_in_wing = idx_in_group % i.span;
// Accumulate DFT sum
- float2 sum = float2(0.0, 0.0);
+ float sum_real = 0.0;
+ float sum_imag = 0.0;
// Main DFT loop
for (int j = 0; j < _Radix; j++)
@@ -127,67 +160,74 @@ Shader "yum_food/fft"
// Calculate input position
int input_pos = group * i.butterfly_size + j * i.span + idx_in_wing;
- // Calculate UV for input texture read
- float2 input_uv;
+ // Read input value
+ float in_real, in_imag;
if (is_row_stage)
{
- input_uv = float2((input_pos + 0.5) / (float)_N, i.uv.y);
+ float2 input_uv = float2((input_pos + 0.5) / (float)_N, i.uv.y);
+ float4 input_tex = _MainTex.SampleLevel(point_clamp_s, input_uv, 0);
+ if (_Stage == 0) {
+ // Assume that input is grayscale and real-valued.
+ in_real = input_tex.x;
+ in_imag = 0;
+ } else {
+ in_real = input_tex.x;
+ in_imag = input_tex.y;
+ }
}
else
{
float xuv = (x + 0.5) / _N;
- input_uv = float2(xuv, (input_pos + 0.5) / (float)_N);
+ float2 input_uv = float2(xuv, (input_pos + 0.5) / (float)_N);
+ float4 input_tex = _MainTex.SampleLevel(point_clamp_s, input_uv, 0);
+ in_real = input_tex.x;
+ in_imag = input_tex.y;
}
- // Read input value
- float4 input_tex = _MainTex.SampleLevel(point_clamp_s, input_uv, 0);
- float2 input_val;
- if (_Stage == 0) {
- input_val.x = luminance(input_tex.xyz);
- input_val.y = 0;
- } else {
- input_val.x = input_tex.x + input_tex.y;
- input_val.y = input_tex.z + input_tex.w;
- }
-
- // Read DFT coefficient from the table (use inverse matrix if _Inverse is set)
+ // Read DFT coefficient
float2 coeff = _Inverse > 0.5 ? IDFT_MATRIX[wing][j] : DFT_MATRIX[wing][j];
+ float coeff_real = coeff.x;
+ float coeff_imag = coeff.y;
// Complex multiply-accumulate
- sum.x += coeff.x * input_val.x - coeff.y * input_val.y;
- sum.y += coeff.x * input_val.y + coeff.y * input_val.x;
+ sum_real += coeff_real * in_real - coeff_imag * in_imag;
+ sum_imag += coeff_real * in_imag + coeff_imag * in_real;
}
// Apply stage twiddle if needed
+ float out_real, out_imag;
if (wing > 0 && idx_in_wing > 0)
{
int twiddle_idx = wing * idx_in_wing;
- float2 tw = _Inverse > 0.5 ? STAGE_TWIDDLES_INV[twiddle_idx] : STAGE_TWIDDLES[twiddle_idx];
+ float2 tw;
+
+ if (_Stage % 2 == 0) {
+ tw = _Inverse > 0.5 ? STAGE0_TWIDDLES_INV[twiddle_idx] : STAGE0_TWIDDLES[twiddle_idx];
+ } else {
+ tw = _Inverse > 0.5 ? STAGE1_TWIDDLES_INV[twiddle_idx] : STAGE1_TWIDDLES[twiddle_idx];
+ }
+
+ float tw_real = tw.x;
+ float tw_imag = tw.y;
// Output = twiddle * sum
- float2 output;
- output.x = tw.x * sum.x - tw.y * sum.y;
- output.y = tw.x * sum.y + tw.y * sum.x;
- sum = output;
+ out_real = tw_real * sum_real - tw_imag * sum_imag;
+ out_imag = tw_real * sum_imag + tw_imag * sum_real;
+ }
+ else
+ {
+ out_real = sum_real;
+ out_imag = sum_imag;
}
- // Pack complex result into RGBA.
- float real_part = sum.x;
- float imag_part = sum.y;
-
+ // Handle final stage of inverse FFT
if (_Inverse > 0.5 && _Stage == i.num_stages_per_dim * 2 - 1) {
- // Last stage of IFFT is just back to the original real-valued signal.
- real_part /= _N * i.num_stages_per_dim;
- return float4(real_part, real_part, real_part, 1);
+ float normalized = out_real / (_N * _N);
+ return float4(normalized, normalized, normalized, 1);
}
- // Split into 2 parts.
- float real_big = floor(real_part);
- float real_small = real_part - real_big;
- float imag_big = floor(imag_part);
- float imag_small = imag_part - imag_big;
-
- return float4(real_big, real_small, imag_big, imag_small);
+ // Pack complex result into RGBA
+ return float4(out_real, out_imag, 0, 1);
}
ENDCG
}