| 1 | """Ball tracker. |
| 2 | |
| 3 | Subscribes to capture frames, finds the ball via HSV + circularity, keeps a |
| 4 | rolling window of detections, and emits: |
| 5 | - ball_track events for each accepted detection (the highest-scoring candidate); otherwise ball_miss events, with a reason when a detection is rejected |
| 6 | - pitch_pred events once a fittable trajectory accumulates (plate_x + eta_ms) |
| 7 | |
| 8 | Classical CV only. Tune HSV and radius bounds via configs/runtime.yaml. |
| 9 | """ |
| 10 | from __future__ import annotations |
| 11 | |
| 12 | import argparse |
| 13 | import sys |
| 14 | import time |
| 15 | from collections import deque |
| 16 | from pathlib import Path |
| 17 | |
| 18 | import cv2 |
| 19 | import numpy as np |
| 20 | from rich.console import Console |
| 21 | |
| 22 | sys.path.insert(0, str(Path(__file__).resolve().parents[1])) |
| 23 | from cv._common import ( |
| 24 | event_subscriber, |
| 25 | iter_latest_frames, |
| 26 | load_config, |
| 27 | make_frame_subscriber, |
| 28 | make_pub, |
| 29 | send_event, |
| 30 | ) |
| 31 | |
# Shared rich console used for startup, status, error and summary messages.
console: Console = Console()
| 33 | |
# --- Trajectory bookkeeping -------------------------------------------------
TRAJ_WINDOW_S: float = 0.8   # trail points older than this (s) behind the newest are dropped
MIN_FIT_POINTS: int = 4      # try_fit needs at least this many trail points
FIT_USE_LAST_N: int = 8      # try_fit only fits the most recent N points
PITCH_GAP_MS: int = 300      # a longer gap between detections ends the current pitch
MIN_DOWN_PX: int = 30        # fitted points must span at least this many pixels vertically

# --- Static-blob (e.g. on-screen UI) suppression ----------------------------
STATIC_WINDOW: int = 3          # number of newest trail points checked for "not moving"
STATIC_SPREAD_PX: float = 3.0   # x/y spread below which those points count as static
BAN_ZONE_MS: int = 1_500        # how long a static spot stays banned
BAN_ZONE_PX: int = 6            # half-width of the square ban zone around that spot

# --- Pitch plausibility gates -----------------------------------------------
PITCH_START_MAX_Y: int = 500    # a new trail may only start at y <= this
MAX_STEP_PX: int = 200          # a larger jump between consecutive points resets the trail
| 46 | |
class Detection:
    """A single ball candidate found in one frame.

    ``__slots__`` avoids a per-instance ``__dict__``; one instance is created
    for every frame in which a candidate is found.

    Attributes:
        ts_ns: Frame timestamp in nanoseconds. ``detect_ball`` creates
            detections with 0 here and the caller assigns the frame's stamp.
        x, y: Centre of the candidate's minimum enclosing circle, in pixels.
        r: Radius of that circle, in pixels.
        score: Fraction of the enclosing circle filled by the contour
            (1.0 = perfect disc).
    """

    __slots__ = ("ts_ns", "x", "y", "r", "score")

    def __init__(self, ts_ns: int, x: float, y: float, r: float, score: float):
        self.ts_ns = ts_ns
        self.x = x
        self.y = y
        self.r = r
        self.score = score

    def __repr__(self) -> str:
        # Slotted classes get no useful default repr; spell the fields out for logs/debugging.
        return (
            f"{type(self).__name__}(ts_ns={self.ts_ns!r}, x={self.x!r}, "
            f"y={self.y!r}, r={self.r!r}, score={self.score!r})"
        )
| 55 | |
def detect_ball(frame: np.ndarray, cfg: dict) -> Detection | None:
    """Find the most ball-like blob in a BGR frame.

    Pixels inside the configured HSV range are masked and cleaned with a 3x3
    morphological opening. Each external contour is then scored by how much
    of its minimum enclosing circle it fills (1.0 = perfect disc). A contour
    is ignored if any of the following holds:

    - its area is below half the area of a ``min_ball_radius_px`` disc;
    - its enclosing radius lies outside ``[min_ball_radius_px, max_ball_radius_px]``;
    - it fills less than 60% of its enclosing circle.

    Args:
        frame: Image to search. It is converted with ``cv2.COLOR_BGR2HSV``,
            so it is expected to be a 3-channel BGR image.
        cfg: Runtime config; reads ``cfg["cv"]["pitch_detect"]`` keys
            ``ball_hsv_low``, ``ball_hsv_high``, ``min_ball_radius_px`` and
            ``max_ball_radius_px``.

    Returns:
        The best-scoring ``Detection`` (``ts_ns`` left at 0 for the caller to
        fill in), or ``None`` if the frame is empty or no contour qualifies.
    """
    if frame is None or frame.size == 0:
        return None

    pd = cfg["cv"]["pitch_detect"]
    hsv_low = np.array(pd["ball_hsv_low"], dtype=np.uint8)
    hsv_high = np.array(pd["ball_hsv_high"], dtype=np.uint8)
    r_min = float(pd["min_ball_radius_px"])
    r_max = float(pd["max_ball_radius_px"])
    # Cheap pre-filter: skip blobs smaller than half a minimum-radius disc.
    min_area = np.pi * r_min * r_min * 0.5
    kernel = np.ones((3, 3), np.uint8)

    hsv = cv2.cvtColor(frame, cv2.COLOR_BGR2HSV)
    mask = cv2.inRange(hsv, hsv_low, hsv_high)
    mask = cv2.morphologyEx(mask, cv2.MORPH_OPEN, kernel, iterations=1)

    # findContours returns (image, contours, hierarchy) on OpenCV 3.x and
    # (contours, hierarchy) on 2.4/4.x; the contours are second-to-last in both.
    contours = cv2.findContours(mask, cv2.RETR_EXTERNAL, cv2.CHAIN_APPROX_SIMPLE)[-2]

    best: Detection | None = None
    best_score = 0.0
    for c in contours:
        area = cv2.contourArea(c)
        if area < min_area:
            continue
        (cx, cy), r = cv2.minEnclosingCircle(c)
        if r < r_min or r > r_max:
            continue

        circ_area = np.pi * r * r
        # Fill ratio of the enclosing circle: a round ball scores close to 1.
        score = float(area / circ_area) if circ_area > 0 else 0.0
        if score < 0.6:
            continue
        if score > best_score:
            best_score = score
            best = Detection(ts_ns=0, x=float(cx), y=float(cy), r=float(r), score=score)
    return best
| 86 | |
def try_fit(trail: deque[Detection], plate_y_px: float) -> tuple[float, float] | None:
    """Predict where and when the tracked ball crosses the plate line.

    Fits a quadratic to y(t) and a line to x(t), using the most recent
    ``FIT_USE_LAST_N`` detections with times measured from the oldest of
    them. It then solves y(t) == plate_y_px and takes the earliest crossing
    after the last observation.

    Args:
        trail: Detections in arrival order. Only ``ts_ns``, ``x`` and ``y``
            are read.
        plate_y_px: Image row, in pixels, of the plate line.

    Returns:
        ``(plate_x_px, eta_ms)``, where ``eta_ms`` is measured from the
        current ``time.time_ns()``. Returns ``None`` if any of the following
        holds:

        - the trail is too short or has too few distinct timestamps;
        - it spans less than ``MIN_DOWN_PX`` vertically;
        - its net motion is clearly upward;
        - the fit never reaches the plate line after the last sample;
        - the crossing is not within the next 2 s.
    """
    # Net upward motion (px, first -> last point) beyond which this isn't a pitch.
    max_rise_px = 30
    # |ay| (px/s^2) below which the vertical fit is treated as constant velocity.
    linear_eps = 1e-6
    # Crossings further out than this (ms) are discarded.
    max_eta_ms = 2000

    if len(trail) < MIN_FIT_POINTS:
        return None

    recent = list(trail)[-FIT_USE_LAST_N:]
    ys = np.array([d.y for d in recent], dtype=np.float64)
    if ys.max() - ys.min() < MIN_DOWN_PX:
        return None
    if ys[-1] - ys[0] < -max_rise_px:
        return None

    t0 = recent[0].ts_ns
    ts = np.array([(d.ts_ns - t0) / 1e9 for d in recent], dtype=np.float64)
    # A quadratic needs three distinct sample times; duplicated stamps would
    # make polyfit rank-deficient and its coefficients meaningless.
    if np.unique(ts).size < 3:
        return None
    xs = np.array([d.x for d in recent], dtype=np.float64)

    ay, by, cy = np.polyfit(ts, ys, 2)
    bx, cx = np.polyfit(ts, xs, 1)

    # Solve ay*t^2 + by*t + c0 == 0 for the times the fitted y hits the plate.
    c0 = cy - plate_y_px
    if abs(ay) < linear_eps:
        # Effectively constant vertical velocity: solve by*t + c0 == 0.
        if by == 0:
            return None
        roots = [-c0 / by]
    else:
        disc = by * by - 4.0 * ay * c0
        if disc < 0:
            return None
        # Numerically stable quadratic roots: never subtracts nearly equal terms.
        q = -0.5 * (by + float(np.copysign(np.sqrt(disc), by)))
        roots = [q / ay]
        if q != 0:
            roots.append(c0 / q)

    future = [t for t in roots if t > ts[-1]]
    if not future:
        return None
    t_cross = min(future)

    # NOTE(review): compares frame stamps with time.time_ns(); this assumes the
    # capture side stamps ts_ns on the same wall-clock epoch — confirm.
    now_ns = time.time_ns()
    plate_ns = t0 + int(t_cross * 1e9)
    eta_ms = (plate_ns - now_ns) / 1e6
    if eta_ms < 0 or eta_ms > max_eta_ms:
        return None
    plate_x = bx * t_cross + cx
    return float(plate_x), float(eta_ms)
| 126 | |
def _send_miss(pub, seq: int, ts_ns: int, reason: str | None = None, **extra: object) -> None:
    """Publish a ``ball_miss`` event; ``reason`` and ``extra`` fields are added only when given."""
    event = {"type": "ball_miss", "seq": seq, "ts_ns": ts_ns}
    if reason is not None:
        event["reason"] = reason
    event.update(extra)
    send_event(pub, event)


def main() -> int:
    """Run the ball tracker until interrupted, timed out, or ``--duration`` elapses.

    Frames come from ``cfg["capture"]["publish_endpoint"]``. Events
    (``ball_track``, ``ball_miss``, ``pitch_pred``) are published on
    ``cfg["cv"]["ball_events_endpoint"]``.

    Each frame with a detection goes through these steps:
      1. Reject the detection if it lies inside a recently banned static spot.
      2. End the current pitch (clear the trail, bump ``pitch_id``) after a
         gap longer than ``PITCH_GAP_MS`` or a jump larger than ``MAX_STEP_PX``.
      3. Only allow a new trail to start at ``y <= PITCH_START_MAX_Y``.
      4. Drop trail points older than ``TRAJ_WINDOW_S``.
      5. Ban the spot if the last ``STATIC_WINDOW`` points barely move.
      6. Otherwise emit ``ball_track``, plus ``pitch_pred`` when ``try_fit``
         succeeds.

    Frames without a detection emit a bare ``ball_miss``.

    Returns:
        Process exit code. It is 2 if ``iter_latest_frames`` raised
        ``TimeoutError`` (presumably no frames within ``timeout_ms``), and
        0 otherwise.
    """
    ap = argparse.ArgumentParser(description="Ball tracker + parabolic pitch-prediction.")
    ap.add_argument("--duration", type=float, default=0.0, help="Stop after N seconds (0 = run forever).")
    ap.add_argument("--quiet", action="store_true", help="Suppress per-frame status lines.")
    args = ap.parse_args()

    cfg = load_config()
    capture_ep = cfg["capture"]["publish_endpoint"]
    ball_ep = cfg["cv"]["ball_events_endpoint"]
    plate_y_frac = float(cfg["cv"].get("plate_y_frac", 0.72))

    sub = make_frame_subscriber(capture_ep)
    pub = make_pub(ball_ep)
    console.print(f"[green]ball_tracker[/green] sub={capture_ep} pub={ball_ep}")

    trail: deque[Detection] = deque(maxlen=64)
    # (centre_x, centre_y, expiry_ts_ns) of spots where a static blob was seen.
    banned: deque[tuple[float, float, int]] = deque(maxlen=32)
    pitch_id = 0
    last_det_ns: int | None = None
    t_end = time.perf_counter() + args.duration if args.duration > 0 else None
    frames = hits = preds = 0
    t_report = time.perf_counter()

    try:
        for meta, frame in iter_latest_frames(sub, timeout_ms=3000):
            now = time.perf_counter()
            if t_end is not None and now >= t_end:
                break
            frames += 1
            # Report before any per-frame `continue`, so long runs of rejected
            # frames cannot starve the status line.
            if not args.quiet and now - t_report >= 5.0:
                console.print(f"[dim] ball: {frames} frames, {hits} hits, {preds} preds, trail={len(trail)}[/dim]")
                t_report = now

            seq = int(meta["seq"])
            ts_ns = int(meta["ts_ns"])
            h = int(meta.get("h", frame.shape[0]))
            plate_y_px = h * plate_y_frac

            # Expiries are appended in timestamp order, so expired bans sit at the front.
            while banned and banned[0][2] <= ts_ns:
                banned.popleft()

            det = detect_ball(frame, cfg)
            if det is None:
                _send_miss(pub, seq, ts_ns)
                continue

            if any(
                abs(det.x - bx) < BAN_ZONE_PX and abs(det.y - by) < BAN_ZONE_PX
                for bx, by, _ in banned
            ):
                _send_miss(pub, seq, ts_ns, "banned_zone")
                continue

            det.ts_ns = ts_ns

            # A long silence or an implausible jump ends the current pitch. Only a
            # non-empty trail is ended: after a reset, detections rejected below
            # leave last_det_ns stale, and must not bump pitch_id on every frame.
            if trail:
                last_d = trail[-1]
                dx = det.x - last_d.x
                dy = det.y - last_d.y
                gap_exceeded = last_det_ns is not None and (ts_ns - last_det_ns) > PITCH_GAP_MS * 1e6
                if gap_exceeded or (dx * dx + dy * dy) ** 0.5 > MAX_STEP_PX:
                    trail.clear()
                    pitch_id += 1

            # A new trail may only start at y <= PITCH_START_MAX_Y.
            if not trail and det.y > PITCH_START_MAX_Y:
                _send_miss(pub, seq, ts_ns, "not_pitch_start", det_y=det.y)
                continue

            trail.append(det)
            last_det_ns = ts_ns

            cutoff = ts_ns - int(TRAJ_WINDOW_S * 1e9)
            while trail and trail[0].ts_ns < cutoff:
                trail.popleft()

            # A blob that barely moves over the newest points is treated as static
            # (e.g. a UI element): ban its spot for a while and start over.
            if len(trail) >= STATIC_WINDOW:
                recent = list(trail)[-STATIC_WINDOW:]
                rxs = [d.x for d in recent]
                rys = [d.y for d in recent]
                spread = max(max(rxs) - min(rxs), max(rys) - min(rys))
                if spread < STATIC_SPREAD_PX:
                    cx = sum(rxs) / len(rxs)
                    cy = sum(rys) / len(rys)
                    banned.append((cx, cy, ts_ns + int(BAN_ZONE_MS * 1e6)))
                    trail.clear()
                    last_det_ns = None
                    _send_miss(pub, seq, ts_ns, "static_ui_banned", banned_x=cx, banned_y=cy)
                    continue

            hits += 1
            send_event(pub, {
                "type": "ball_track",
                "seq": seq,
                "ts_ns": ts_ns,
                "pitch_id": pitch_id,
                "x": det.x, "y": det.y, "r": det.r,
                "score": det.score,
            })

            fit = try_fit(trail, plate_y_px)
            if fit is not None:
                plate_x, eta_ms = fit
                preds += 1
                send_event(pub, {
                    "type": "pitch_pred",
                    "seq": seq,
                    "ts_ns": ts_ns,
                    "pitch_id": pitch_id,
                    "plate_x": plate_x,
                    "plate_y": plate_y_px,
                    "eta_ms": eta_ms,
                    "n_points": len(trail),
                })
    except TimeoutError as e:
        console.print(f"[red]ball_tracker: {e}. Is capture/ingest.py running?[/red]")
        return 2
    except KeyboardInterrupt:
        console.print("[yellow]ball_tracker interrupted.[/yellow]")
    finally:
        sub.close(0)
        pub.close(0)
        console.print(f"[bold]ball_tracker summary:[/bold] frames={frames} hits={hits} preds={preds}")
    return 0
| 276 | |
if __name__ == "__main__":
    # main() returns the process exit code; SystemExit hands it to the interpreter.
    raise SystemExit(main())