summaryrefslogtreecommitdiffstats
path: root/generate_fft_reference.py
blob: db2289d87ba6528a319f9a436a96161d48665ee3 (plain)
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
#!/usr/bin/env python3
"""
Generate reference FFT image in OpenEXR format, matching the shader's output format.
"""

import numpy as np
import argparse
from PIL import Image
import OpenEXR
import Imath

def main():
    parser = argparse.ArgumentParser(description='Generate reference FFT image')
    parser.add_argument('input', type=str, help='Input image file')
    parser.add_argument('output', type=str, help='Output EXR file')
    parser.add_argument('--size', type=int, default=256, help='Size of the FFT (NxN)')
    args = parser.parse_args()

    # Load input image and convert to luminance
    img = Image.open(args.input).convert('RGB')
    img = img.resize((args.size, args.size), Image.Resampling.NEAREST)
    img_array = np.array(img, dtype=np.float32) / 255.0
    luminance = 0.2126 * img_array[:,:,0] + 0.7152 * img_array[:,:,1] + 0.0722 * img_array[:,:,2]

    # Unity's UV origin is bottom-left, NumPy/PIL's is top-left; flip vertically to align
    luminance = np.flipud(luminance)

    # Perform 2D FFT (no fftshift, matching GPU implementation)
    fft_result = np.fft.fft2(luminance)

    # Pack complex numbers into RGBA matching shader format:
    # R: real part, G: imaginary part, B: 0, A: 1
    real_part = fft_result.real.astype(np.float32)
    imag_part = fft_result.imag.astype(np.float32)

    # Unity's output texture also uses bottom-left origin, so flip output too
    real_part = np.flipud(real_part)
    imag_part = np.flipud(imag_part)

    # Create EXR
    height, width = args.size, args.size
    header = OpenEXR.Header(width, height)
    half_chan = Imath.Channel(Imath.PixelType(Imath.PixelType.FLOAT))
    header['channels'] = {'R': half_chan, 'G': half_chan, 'B': half_chan, 'A': half_chan}

    # Write EXR
    out = OpenEXR.OutputFile(args.output, header)

    # Create zero and one arrays for B and A channels
    zeros = np.zeros((height, width), dtype=np.float32)
    ones = np.ones((height, width), dtype=np.float32)

    out.writePixels({
        'R': real_part.astype(np.float32).tobytes(),
        'G': imag_part.astype(np.float32).tobytes(),
        'B': zeros.tobytes(),
        'A': ones.tobytes()
    })
    out.close()

    print(f"FFT complete. Output saved to {args.output}")
    print(f"Real range: [{real_part.min():.1f}, {real_part.max():.1f}]")
    print(f"Imag range: [{imag_part.min():.1f}, {imag_part.max():.1f}]")

if __name__ == "__main__":
    main()