summaryrefslogtreecommitdiffstats
path: root/generate_fft_reference.py
blob: a7d58f4fb1fb57e693311d49f1fb34398c4b7c1b (plain)
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
#!/usr/bin/env python3
"""
Generate reference FFT image in OpenEXR format, matching the shader's output format.
"""

import numpy as np
import argparse
from PIL import Image
import OpenEXR
import Imath

def main():
    parser = argparse.ArgumentParser(description='Generate reference FFT image')
    parser.add_argument('input', type=str, help='Input image file')
    parser.add_argument('output', type=str, help='Output EXR file')
    parser.add_argument('--size', type=int, default=256, help='Size of the FFT (NxN)')
    args = parser.parse_args()

    # Load input image and convert to luminance
    img = Image.open(args.input).convert('RGB')
    img = img.resize((args.size, args.size), Image.Resampling.LANCZOS)
    img_array = np.array(img, dtype=np.float32) / 255.0
    luminance = 0.2126 * img_array[:,:,0] + 0.7152 * img_array[:,:,1] + 0.0722 * img_array[:,:,2]

    # Perform 2D FFT (no fftshift, matching GPU implementation)
    fft_result = np.fft.fft2(luminance)

    # Pack complex numbers into RGBA matching shader format:
    # R: real part, G: imaginary part, B: 0, A: 1
    real_part = fft_result.real.astype(np.float32)
    imag_part = fft_result.imag.astype(np.float32)

    # Create EXR
    height, width = args.size, args.size
    header = OpenEXR.Header(width, height)
    half_chan = Imath.Channel(Imath.PixelType(Imath.PixelType.FLOAT))
    header['channels'] = {'R': half_chan, 'G': half_chan, 'B': half_chan, 'A': half_chan}

    # Write EXR
    out = OpenEXR.OutputFile(args.output, header)

    # Create zero and one arrays for B and A channels
    zeros = np.zeros((height, width), dtype=np.float32)
    ones = np.ones((height, width), dtype=np.float32)

    out.writePixels({
        'R': real_part.astype(np.float32).tobytes(),
        'G': imag_part.astype(np.float32).tobytes(),
        'B': zeros.tobytes(),
        'A': ones.tobytes()
    })
    out.close()

    print(f"FFT complete. Output saved to {args.output}")
    print(f"Real range: [{real_part.min():.1f}, {real_part.max():.1f}]")
    print(f"Imag range: [{imag_part.min():.1f}, {imag_part.max():.1f}]")

if __name__ == "__main__":
    main()