#!/usr/bin/env python3
"""
Fourier approximation utility.

Given an analytic expression f(x) and a finite interval [a, b], this script
samples the function uniformly, computes Fourier series coefficients via the
FFT, and reports the error of partial sums as a normalized (root-mean-square)
L2 norm. Sine and cosine components are treated as individual terms and
applied in descending order of amplitude to highlight the strongest
contributions first.
"""

from __future__ import annotations

import argparse
import sys
from typing import Callable, Dict, NamedTuple

import numpy as np


class FourierTerm(NamedTuple):
    """One real sine or cosine component of a truncated Fourier series."""

    index: int  # harmonic number n (>= 1)
    kind: str  # "sin" or "cos"
    coefficient: float  # signed coefficient: a_n for "cos", b_n for "sin"
    amplitude: float  # absolute value of the coefficient
    phase: float  # 0.0 for a non-negative coefficient, pi for a negative one
    angular_frequency: float  # 2*pi*n / interval length (radians per unit)
    frequency: float  # n / interval length (cycles per unit)


def frac(x):
    """Return the fractional part of x in [0, 1).

    Python/NumPy scalars yield a Python ``float``; anything else yields an
    ndarray. The value is computed as ``x - floor(x)``. For tiny negative
    inputs (e.g. ``-1e-20``) that difference rounds up to exactly 1.0, so the
    result is clamped to the largest double below 1 to honour the half-open
    range. NaN inputs propagate as NaN.
    """
    arr = np.asarray(x, dtype=float)
    # np.minimum (unlike np.fmin) keeps NaN, matching the unclamped behaviour.
    frac_part = np.minimum(arr - np.floor(arr), np.nextafter(1.0, 0.0))
    if np.isscalar(x):
        return float(frac_part)
    return frac_part


# Names available inside the user-supplied expression. The list is kept small,
# but it is NOT a security sandbox: the full ``np`` namespace is exposed, so
# only evaluate expressions from trusted users.
_ALLOWED_FUNCS: Dict[str, object] = {
    # NOTE(review): combined with eval() this is not a sandbox. Emptying
    # "__builtins__" can be bypassed via attribute traversal, and ``np``
    # gives access to every numpy function, including file I/O such as
    # np.load/np.save. Treat the expression as trusted code.
    "np": np,
    "sin": np.sin,
    "cos": np.cos,
    "tan": np.tan,
    "exp": np.exp,
    "log": np.log,
    "log10": np.log10,
    "log2": np.log2,
    "sqrt": np.sqrt,
    "sinh": np.sinh,
    "cosh": np.cosh,
    "tanh": np.tanh,
    "arcsin": np.arcsin,
    "arccos": np.arccos,
    "arctan": np.arctan,
    "abs": np.abs,
    # Step-like helpers for classic Fourier examples (square waves, etc.).
    "floor": np.floor,
    "ceil": np.ceil,
    "sign": np.sign,
    "pi": np.pi,
    "e": np.e,
    "frac": frac,
}


def build_function(expression: str) -> Callable[[np.ndarray], np.ndarray]:
    """Compile *expression* into a vectorised callable of ``x``.

    The expression is compiled once. Each call evaluates it with ``x`` bound
    to the argument and only the names in ``_ALLOWED_FUNCS`` in scope.

    Raises:
        ValueError: if the expression is not valid Python syntax. The returned
            callable also raises ValueError when evaluation fails, when the
            result is complex, or when it cannot be converted to floats.
    """
    try:
        code = compile(expression, "<expression>", "eval")
    except SyntaxError as exc:
        raise ValueError(f"Invalid function expression: {exc}") from exc

    def func(x: np.ndarray) -> np.ndarray:
        local_dict = dict(_ALLOWED_FUNCS)
        local_dict["x"] = x
        try:
            value = eval(code, {"__builtins__": {}}, local_dict)
        except Exception as exc:  # pragma: no cover - user provided expression
            raise ValueError(f"Error while evaluating expression: {exc}") from exc
        # Casting complex data to float would silently drop the imaginary part.
        if np.iscomplexobj(value):
            raise ValueError(
                "Expression produced complex values; a real-valued function is required."
            )
        # Keep conversion failures inside the ValueError contract callers catch.
        try:
            return np.asarray(value, dtype=float)
        except (TypeError, ValueError) as exc:
            raise ValueError(
                f"Expression result cannot be converted to real numbers: {exc}"
            ) from exc

    return func


def l2_norm(values: np.ndarray, length: float) -> float:
    """Return the normalized (root-mean-square) L2 norm of sampled values.

    Approximates ``sqrt((1 / length) * integral |f|^2 dx)`` with a rectangle
    rule: each of the ``values.size`` samples is weighted by
    ``length / values.size``. ``fft_terms`` samples at left endpoints. For a
    periodic extension that rule coincides with the trapezoidal rule. Because
    of the ``1 / length`` normalisation, the result equals the RMS of
    ``values``.

    Raises:
        ValueError: if ``values`` is empty or ``length`` is not positive.
    """
    values = np.asarray(values, dtype=float)
    if values.size == 0:
        raise ValueError("values must not be empty.")
    if not length > 0:
        raise ValueError("length must be positive.")
    squared = np.square(values)
    step = length / values.size
    integral = np.sum(squared) * step
    return float(np.sqrt(integral / length))


def fft_terms(
    func: Callable[[np.ndarray], np.ndarray],
    interval: tuple[float, float],
    term_count: int,
    samples: int,
) -> tuple[np.ndarray, np.ndarray, float, list[FourierTerm], float, int]:
    """Sample ``func`` on ``interval`` and extract real Fourier terms via FFT.

    The function is sampled at the ``samples`` left-endpoint grid points
    ``start + k * length / samples`` for k = 0 .. samples - 1. It is therefore
    treated as periodic with period ``length``.

    N is ``term_count``, capped by the number of harmonics the grid resolves.
    For each harmonic n = 1 .. N, the coefficients of

        f(x) ~ a_0 + sum_n [a_n cos(n w (x - start)) + b_n sin(n w (x - start))],
        w = 2 pi / length,

    are recovered from the real FFT. Components with ``|coefficient| <= 1e-12``
    are omitted.

    Returns:
        ``(xs, fx, constant_term, terms, length, max_terms)``:
        - ``xs``, ``fx``: the sample points and the sampled values.
        - ``constant_term``: a_0.
        - ``terms``: the non-negligible components, in harmonic order.
        - ``length``: the interval length.
        - ``max_terms``: the number of harmonics examined.

    Raises:
        ValueError: if the interval is empty or inverted, ``term_count < 1``,
            ``samples < 2``, the evaluation has the wrong shape, or any sample
            is non-finite.
    """
    start, end = interval
    if end <= start:
        raise ValueError("Interval end must be greater than start.")
    if term_count < 1:
        raise ValueError("term_count must be at least 1.")
    if samples < 2:
        raise ValueError("samples must be at least 2.")

    length = end - start
    xs = start + (np.arange(samples) * length / samples)
    # Accept any callable returning array-like data, not only ndarrays.
    fx = np.asarray(func(xs), dtype=float)
    if fx.shape == ():
        # A constant expression (e.g. "1") evaluates to a scalar.
        fx = np.full_like(xs, float(fx))
    if fx.shape != xs.shape:
        raise ValueError(
            "Function evaluation did not return values of the expected shape."
        )
    finite = np.isfinite(fx)
    if not finite.all():
        # A single inf/NaN sample would poison every FFT coefficient.
        bad = int(fx.size - np.count_nonzero(finite))
        raise ValueError(
            f"Function is not finite at {bad} of {samples} sample points "
            "(e.g. log(0) or division by zero); adjust the interval or expression."
        )

    spectrum = np.fft.rfft(fx) / samples
    constant_term = float(spectrum[0].real)
    max_terms = min(term_count, max(len(spectrum) - 1, 0))
    terms: list[FourierTerm] = []
    if max_terms == 0:
        return xs, fx, constant_term, terms, length, max_terms

    # Bins n and samples - n are complex conjugates that each carry half of
    # a_n/b_n, hence the factor 2. For an even sample count, the Nyquist bin
    # (n = samples // 2) is its own partner, so it must not be doubled.
    nyquist = samples // 2 if samples % 2 == 0 else None
    tol = 1e-12
    for n in range(1, max_terms + 1):
        coeff = spectrum[n]
        scale = 1.0 if n == nyquist else 2.0
        an = float(scale * coeff.real)
        bn = float(-scale * coeff.imag)
        angular_frequency = float(2.0 * np.pi * n / length)
        frequency = float(n / length)

        if abs(an) > tol:
            amplitude = abs(an)
            phase = 0.0 if an >= 0 else float(np.pi)
            terms.append(
                FourierTerm(
                    index=n,
                    kind="cos",
                    coefficient=an,
                    amplitude=amplitude,
                    phase=phase,
                    angular_frequency=angular_frequency,
                    frequency=frequency,
                )
            )

        if abs(bn) > tol:
            amplitude = abs(bn)
            phase = 0.0 if bn >= 0 else float(np.pi)
            terms.append(
                FourierTerm(
                    index=n,
                    kind="sin",
                    coefficient=bn,
                    amplitude=amplitude,
                    phase=phase,
                    angular_frequency=angular_frequency,
                    frequency=frequency,
                )
            )

    return xs, fx, constant_term, terms, length, max_terms


def format_partial_expression(
    constant_term: float,
    terms: list[FourierTerm],
    interval_start: float,
    length: float,
) -> str:
    """Build a copy-pastable expression for the partial trigonometric sum.

    Each component is emitted as ``|c| * sin|cos(omega*(x - start))`` with an
    explicit sign, where ``omega = 2*pi*n / length``. Coefficients with
    magnitude <= 1e-12 are skipped, and ``"0"`` is returned when nothing
    remains. The output only uses names available to ``build_function``.
    """
    tol = 1e-12
    expr_parts: list[str] = []
    if abs(constant_term) > tol:
        expr_parts.append(f"{constant_term:.6e}")

    def append_component(components: list[str], coeff: float, func: str, argument: str) -> None:
        # The first emitted part carries its sign inline; later parts use an
        # infix "+"/"-" so the expression reads naturally.
        if abs(coeff) <= tol:
            return
        base = f"{abs(coeff):.6e} * {func}({argument})"
        if components:
            sign = "+" if coeff >= 0 else "-"
            components.append(f"{sign} {base}")
        else:
            components.append(base if coeff >= 0 else f"-{base}")

    for term in terms:
        omega = 2.0 * np.pi * term.index / length
        # Coefficients are relative to (x - start); simplify the shift text.
        if abs(interval_start) <= tol:
            argument = f"{omega:.6e}*x"
        elif interval_start < 0:
            argument = f"{omega:.6e}*(x + {abs(interval_start):.6e})"
        else:
            argument = f"{omega:.6e}*(x - {interval_start:.6e})"
        func_name = "sin" if term.kind == "sin" else "cos"
        append_component(expr_parts, term.coefficient, func_name, argument)

    if not expr_parts:
        return "0"
    return " ".join(expr_parts)


def parse_args(argv: list[str] | None = None) -> argparse.Namespace:
    """Parse command-line arguments; print help and exit(1) when none are given."""
    parser = argparse.ArgumentParser(
        description=(
            "Approximate a real-valued function on [start, end] using a Fourier "
            "series derived from FFT samples and report the L2 error of the "
            "first N partial sums (sorted by amplitude)."
        )
    )
    parser.add_argument(
        "expression",
        help=(
            "Function expression in terms of x. Use numpy-style syntax, e.g. "
            "'sin(x) + 0.5*cos(3*x)'."
        ),
    )
    parser.add_argument(
        "start",
        type=float,
        help="Beginning of the interval for the approximation.",
    )
    parser.add_argument(
        "end",
        type=float,
        help="End of the interval for the approximation.",
    )
    parser.add_argument(
        "--terms",
        type=int,
        default=10,
        help=(
            "Number of Fourier harmonics to compute; also the maximum number "
            "of sine/cosine components reported (default: 10)."
        ),
    )
    parser.add_argument(
        "--samples",
        type=int,
        default=2048,
        help="Number of uniform samples for the FFT (default: 2048).",
    )
    parser.add_argument(
        "--relative",
        action="store_true",
        help="Report the relative L2 error in addition to the absolute error.",
    )
    if argv is None:
        argv = sys.argv[1:]
    if not argv:
        parser.print_help(sys.stderr)
        parser.exit(1)
    return parser.parse_args(argv)


def _relative_error(error: float, base_norm: float) -> float:
    """Return error / base_norm; 0.0 for 0/0 and NaN for a nonzero error over 0."""
    if base_norm > 0.0:
        return error / base_norm
    return float("nan") if error > 0 else 0.0


def main(argv: list[str] | None = None) -> int:
    """Run the command-line tool and return the process exit status.

    Builds the function and extracts its Fourier components. It then prints
    one row per component, largest amplitude first, with the running error of
    the partial sum and a copy-pastable expression for that sum. Returns 1
    when the expression or parameters are rejected, 0 otherwise.
    """
    args = parse_args(argv)
    try:
        func = build_function(args.expression)
        xs, fx, constant_term, terms, length, available_harmonics = fft_terms(
            func=func,
            interval=(args.start, args.end),
            term_count=args.terms,
            samples=args.samples,
        )
    except ValueError as exc:
        print(f"Error: {exc}", file=sys.stderr)
        return 1

    base_norm = l2_norm(fx, length)
    if args.terms > available_harmonics:
        print(
            f"Warning: Requested {args.terms} harmonics but only "
            f"{available_harmonics} available with {args.samples} samples.",
            file=sys.stderr,
        )

    # Harmonic n at the sample points is cos/sin(n * theta), with theta the
    # phase of the fundamental measured from the interval start.
    theta = 2.0 * np.pi * (xs - args.start) / length
    # Amplitudes are already non-negative, so they serve directly as the key.
    sorted_terms = sorted(terms, key=lambda term: term.amplitude, reverse=True)
    available_components = len(sorted_terms)
    if args.terms > available_components:
        print(
            f"Warning: Requested {args.terms} components but only "
            f"{available_components} available from the sampled harmonics.",
            file=sys.stderr,
        )
    sorted_terms = sorted_terms[: args.terms]

    partial = np.full_like(fx, constant_term)
    cumulative_terms: list[FourierTerm] = []
    # A cos and a sin component may share a harmonic index; compute once.
    trig_cache: dict[int, tuple[np.ndarray, np.ndarray]] = {}

    header = "Terms".rjust(5) + " " + "L2 error".rjust(14)
    if args.relative:
        header += " " + "Rel. L2 error".rjust(14)

    print(f"Constant term: {constant_term:.6e}")
    print(
        "Terms are sorted by descending amplitude. Each line is a sine or cosine component with params (amplitude, phase, frequency); phase in radians, frequency in cycles per unit."
    )
    # The column header goes directly above the rows it labels.
    print(header)
    print("-" * len(header))

    if not sorted_terms:
        # Only the constant term survived; still report its error and formula
        # instead of printing an empty table.
        error = l2_norm(fx - partial, length)
        line = f"{0:5d} {error:14.6e}"
        if args.relative:
            line += f" {_relative_error(error, base_norm):14.6e}"
        print(f"{line} (constant term only)")
        expression = format_partial_expression(constant_term, [], args.start, length)
        print(f" expr: {expression}")

    for idx, term in enumerate(sorted_terms, start=1):
        cos_sin = trig_cache.get(term.index)
        if cos_sin is None:
            angles = term.index * theta
            cos_sin = (np.cos(angles), np.sin(angles))
            trig_cache[term.index] = cos_sin
        cos_n, sin_n = cos_sin

        if term.kind == "cos":
            partial += term.coefficient * cos_n
        else:
            partial += term.coefficient * sin_n

        error = l2_norm(fx - partial, length)
        cumulative_terms.append(term)

        line = f"{idx:5d} {error:14.6e}"
        if args.relative:
            line += f" {_relative_error(error, base_norm):14.6e}"
        term_info = (
            f"n={term.index} {term.kind} (coeff {term.coefficient:.6e}, amp {term.amplitude:.6e}, phase {term.phase:.6e}, freq {term.frequency:.6e})"
        )
        line += f" {term_info}"
        print(line)

        expression = format_partial_expression(
            constant_term,
            cumulative_terms,
            args.start,
            length,
        )
        print(f" expr: {expression}")

    print(
        f"Interval length: {length:.6g}, samples: {args.samples}, base L2 norm: {base_norm:.6e}"
    )
    print(
        "Note: errors are root-mean-square (normalized L2) values computed with "
        "a rectangle rule on the sampling grid."
    )
    return 0


if __name__ == "__main__":
    sys.exit(main())