diff options
| author | yum <yum.food.vr@gmail.com> | 2025-10-18 12:14:09 -0700 |
|---|---|---|
| committer | yum <yum.food.vr@gmail.com> | 2025-10-18 12:14:09 -0700 |
| commit | c6581fe78d5abae4289432d9d8fd01729fe17b0e (patch) | |
| tree | 63a8d98f69f3eca19fca1569626e281c09322a5e /Scripts/approximate.py | |
| parent | ddf847e758d9766d33ccc9b1a560a84142e395e0 (diff) | |
add vibe coded fourier approximation tool
Diffstat (limited to 'Scripts/approximate.py')
| -rw-r--r-- | Scripts/approximate.py | 348 |
1 files changed, 348 insertions, 0 deletions
diff --git a/Scripts/approximate.py b/Scripts/approximate.py new file mode 100644 index 0000000..1bfe13f --- /dev/null +++ b/Scripts/approximate.py @@ -0,0 +1,348 @@ +#!/usr/bin/env python3 +""" +Fourier approximation utility. + +Given an analytic expression f(x) and a finite interval [a, b], this script +samples the function uniformly, computes Fourier series coefficients via the +FFT, and reports the L2 error of partial sums. Sine and cosine components are +treated as individual terms and applied in descending order of amplitude to +highlight the strongest contributions first. +""" + +from __future__ import annotations + +import argparse +import numpy as np +import sys +from typing import Callable, Dict, NamedTuple + + +class FourierTerm(NamedTuple): + index: int + kind: str # "sin" or "cos" + coefficient: float + amplitude: float + phase: float + angular_frequency: float + frequency: float + +def frac(x): + """Return the fractional part of x in [0, 1).""" + arr = np.asarray(x, dtype=float) + frac_part = arr - np.floor(arr) + if np.isscalar(x): + return float(frac_part) + return frac_part + +# Functions that are allowed inside the user supplied expression. The subset is +# intentionally small to keep evaluation safe while still being practical. +_ALLOWED_FUNCS: Dict[str, object] = { + "np": np, + "sin": np.sin, + "cos": np.cos, + "tan": np.tan, + "exp": np.exp, + "log": np.log, + "log10": np.log10, + "log2": np.log2, + "sqrt": np.sqrt, + "sinh": np.sinh, + "cosh": np.cosh, + "tanh": np.tanh, + "arcsin": np.arcsin, + "arccos": np.arccos, + "arctan": np.arctan, + "abs": np.abs, + "pi": np.pi, + "e": np.e, + "frac": frac, +} + + +def build_function(expression: str) -> Callable[[np.ndarray], np.ndarray]: + """Create a vectorised callable from the provided expression string.""" + try: + code = compile(expression, "<expression>", "eval") + except SyntaxError as exc: + raise ValueError(f"Invalid function expression: {exc}") from exc + + def func(x: np.ndarray) -> np.ndarray: + local_dict = dict(_ALLOWED_FUNCS) + local_dict["x"] = x + try: + value = eval(code, {"__builtins__": {}}, local_dict) + except Exception as exc: # pragma: no cover - user provided expression + raise ValueError(f"Error while evaluating expression: {exc}") from exc + return np.asarray(value, dtype=float) + + return func + + +def l2_norm(values: np.ndarray, length: float) -> float: + """Compute the L2 norm using a midpoint rule over the sampling grid.""" + squared = np.square(values) + step = length / values.size + integral = np.sum(squared) * step + return float(np.sqrt(integral / length)) + + +def fft_terms( + func: Callable[[np.ndarray], np.ndarray], + interval: tuple[float, float], + term_count: int, + samples: int, +) -> tuple[np.ndarray, np.ndarray, float, list[FourierTerm], float, int]: + """Sample the function and return Fourier terms up to term_count.""" + start, end = interval + if end <= start: + raise ValueError("Interval end must be greater than start.") + if term_count < 1: + raise ValueError("term_count must be at least 1.") + if samples < 2: + raise ValueError("samples must be at least 2.") + + length = end - start + xs = start + (np.arange(samples) * length / samples) + + fx = func(xs) + if fx.shape == (): + fx = np.full_like(xs, float(fx)) + if fx.shape != xs.shape: + raise ValueError( + "Function evaluation did not return values of the expected shape." + ) + + spectrum = np.fft.rfft(fx) / samples + + constant_term = float(spectrum[0].real) + + max_terms = min(term_count, max(len(spectrum) - 1, 0)) + terms: list[FourierTerm] = [] + + if max_terms == 0: + return xs, fx, constant_term, terms, length, max_terms + + tol = 1e-12 + + for n in range(1, max_terms + 1): + coeff = spectrum[n] + an = float(2.0 * coeff.real) + bn = float(-2.0 * coeff.imag) + angular_frequency = float(2.0 * np.pi * n / length) + frequency = float(n / length) + + if abs(an) > tol: + amplitude = abs(an) + phase = 0.0 if an >= 0 else float(np.pi) + terms.append( + FourierTerm( + index=n, + kind="cos", + coefficient=an, + amplitude=amplitude, + phase=phase, + angular_frequency=angular_frequency, + frequency=frequency, + ) + ) + + if abs(bn) > tol: + amplitude = abs(bn) + phase = 0.0 if bn >= 0 else float(np.pi) + terms.append( + FourierTerm( + index=n, + kind="sin", + coefficient=bn, + amplitude=amplitude, + phase=phase, + angular_frequency=angular_frequency, + frequency=frequency, + ) + ) + + return xs, fx, constant_term, terms, length, max_terms + + +def format_partial_expression( + constant_term: float, + terms: list[FourierTerm], + interval_start: float, + length: float, +) -> str: + """Build a copy-pastable expression for the partial trigonometric sum.""" + + tol = 1e-12 + expr_parts: list[str] = [] + + if abs(constant_term) > tol: + expr_parts.append(f"{constant_term:.6e}") + + def append_component(components: list[str], coeff: float, func: str, argument: str) -> None: + if abs(coeff) <= tol: + return + base = f"{abs(coeff):.6e} * {func}({argument})" + if components: + sign = "+" if coeff >= 0 else "-" + components.append(f"{sign} {base}") + else: + components.append(base if coeff >= 0 else f"-{base}") + + for term in terms: + omega = 2.0 * np.pi * term.index / length + if abs(interval_start) <= tol: + argument = f"{omega:.6e}*x" + elif interval_start < 0: + argument = f"{omega:.6e}*(x + {abs(interval_start):.6e})" + else: + argument = f"{omega:.6e}*(x - {interval_start:.6e})" + func_name = "sin" if term.kind == "sin" else "cos" + append_component(expr_parts, term.coefficient, func_name, argument) + + if not expr_parts: + return "0" + + return " ".join(expr_parts) + + +def parse_args(argv: list[str] | None = None) -> argparse.Namespace: + parser = argparse.ArgumentParser( + description=( + "Approximate a real-valued function on [start, end] using a Fourier " + "series derived from FFT samples and report the L2 error of the " + "first N partial sums (sorted by amplitude)." + ) + ) + parser.add_argument( + "expression", + help=( + "Function expression in terms of x. Use numpy-style syntax, e.g. " + "'sin(x) + 0.5*cos(3*x)'." + ), + ) + parser.add_argument( + "start", + type=float, + help="Beginning of the interval for the approximation.", + ) + parser.add_argument( + "end", + type=float, + help="End of the interval for the approximation.", + ) + parser.add_argument( + "--terms", + type=int, + default=10, + help="Number of Fourier harmonics to include (default: 10).", + ) + parser.add_argument( + "--samples", + type=int, + default=2048, + help="Number of uniform samples for the FFT (default: 2048).", + ) + parser.add_argument( + "--relative", + action="store_true", + help="Report the relative L2 error in addition to the absolute error.", + ) + if argv is None: + argv = sys.argv[1:] + if not argv: + parser.print_help(sys.stderr) + parser.exit(1) + return parser.parse_args(argv) + + +def main(argv: list[str] | None = None) -> int: + args = parse_args(argv) + try: + func = build_function(args.expression) + xs, fx, constant_term, terms, length, available_harmonics = fft_terms( + func=func, + interval=(args.start, args.end), + term_count=args.terms, + samples=args.samples, + ) + except ValueError as exc: + print(f"Error: {exc}", file=sys.stderr) + return 1 + + base_norm = l2_norm(fx, length) + + if args.terms > available_harmonics: + print( + f"Warning: Requested {args.terms} harmonics but only {available_harmonics} available with {args.samples} samples.", + file=sys.stderr, + ) + + theta = 2.0 * np.pi * (xs - args.start) / length + sorted_terms = sorted(terms, key=lambda term: term.amplitude, reverse=True) + available_components = len(sorted_terms) + + if args.terms > available_components: + print( + f"Warning: Requested {args.terms} components but only {available_components} available from the sampled harmonics.", + file=sys.stderr, + ) + + sorted_terms = sorted_terms[: args.terms] + + partial = np.full_like(fx, constant_term) + cumulative_terms: list[FourierTerm] = [] + trig_cache: dict[int, tuple[np.ndarray, np.ndarray]] = {} + + header = "Terms".rjust(5) + " " + "L2 error".rjust(14) + if args.relative: + header += " " + "Rel. L2 error".rjust(14) + print(header) + print("-" * len(header)) + + print(f"Constant term: {constant_term:.6e}") + print( + "Terms are sorted by descending amplitude. Each line is a sine or cosine component with params (amplitude, phase, frequency); phase in radians, frequency in cycles per unit." + ) + + for idx, term in enumerate(sorted_terms, start=1): + cos_sin = trig_cache.get(term.index) + if cos_sin is None: + angles = term.index * theta + cos_sin = (np.cos(angles), np.sin(angles)) + trig_cache[term.index] = cos_sin + cos_n, sin_n = cos_sin + if term.kind == "cos": + partial += term.coefficient * cos_n + else: + partial += term.coefficient * sin_n + error = l2_norm(fx - partial, length) + cumulative_terms.append(term) + line = f"{idx:5d} {error:14.6e}" + if args.relative: + if base_norm > 0.0: + rel_error = error / base_norm + else: + rel_error = float("nan") if error > 0 else 0.0 + line += f" {rel_error:14.6e}" + term_info = ( + f"n={term.index} {term.kind} (coeff {term.coefficient:.6e}, amp {term.amplitude:.6e}, phase {term.phase:.6e}, freq {term.frequency:.6e})" + ) + line += f" {term_info}" + print(line) + expression = format_partial_expression( + constant_term, + cumulative_terms, + args.start, + length, + ) + print(f" expr: {expression}") + + print( + f"Interval length: {length:.6g}, samples: {args.samples}, base L2 norm: {base_norm:.6e}" + ) + print("Note: errors use a midpoint-rule approximation on the sampling grid.") + return 0 + + +if __name__ == "__main__": + sys.exit(main()) |
