| 1 | """Train and export a one-class MLB ball detector with Ultralytics YOLO.""" |
| 2 | from __future__ import annotations |
| 3 | |
| 4 | import argparse |
| 5 | import shutil |
| 6 | from pathlib import Path |
| 7 | |
| 8 | from rich.console import Console |
| 9 | |
# Repository root: the parent of the directory that contains this script.
ROOT = Path(__file__).resolve().parents[1]
# Ultralytics dataset description consumed by `model.train(data=...)`.
DATA_YAML = ROOT.joinpath("configs", "ball_yolo.yaml")
# Destination for the exported ONNX model and the copied PyTorch weights.
MODEL_DIR = ROOT.joinpath("configs", "models")
# Shared rich console used for coloured status output.
console = Console()
| 14 | |
def _coerce_batch(value: float) -> int | float:
    """Normalise the ``--batch`` CLI value for Ultralytics.

    Whole numbers (including the AutoBatch sentinel ``-1``) are returned as
    ``int``. Other values below 1 (e.g. a GPU-memory fraction like ``0.7``)
    are returned unchanged. Fractional values >= 1 such as ``1.5`` are
    rejected instead of being silently truncated.

    Raises:
        ValueError: if *value* is a non-integral number >= 1, NaN, or infinite.
    """
    if value.is_integer():
        return int(value)
    # NaN fails this comparison and infinities fail is_integer(), so both end up rejected below.
    if value < 1:
        return value
    raise ValueError(
        f"--batch must be a whole number, -1, or a fraction below 1; got {value}"
    )


def _parse_args(argv: list[str] | None = None) -> argparse.Namespace:
    """Parse and validate command-line options (``argv=None`` reads ``sys.argv``)."""
    ap = argparse.ArgumentParser(description="Train/export MLB ball YOLO model.")
    ap.add_argument("--model", default="yolo11n.pt", help="Base model, e.g. yolo11n.pt or yolov8n.pt.")
    ap.add_argument("--epochs", type=int, default=80)
    ap.add_argument("--imgsz", type=int, default=640)
    ap.add_argument(
        "--batch",
        type=float,
        default=None,
        help="Batch size (whole number), -1 for AutoBatch, or a GPU-memory fraction below 1. "
        "Omit for Ultralytics default.",
    )
    ap.add_argument("--device", default=None, help="cuda device id, cpu, or omit for auto.")
    ap.add_argument("--name", default="mlb26_ball")
    ap.add_argument("--workers", type=int, default=0, help="Dataloader workers. Use 0 on Windows/network paths.")
    args = ap.parse_args(argv)
    if args.batch is not None:
        try:
            args.batch = _coerce_batch(args.batch)
        except ValueError as exc:
            # ap.error prints usage and exits with status 2, like any other bad CLI argument.
            ap.error(str(exc))
    return args


def _find_best_weights(model: object, results: object) -> Path | None:
    """Return the expected ``weights/best.pt`` path of the finished run, or None.

    The run directory is taken from the trainer Ultralytics attaches to the
    model (``model.trainer.save_dir``). ``model.train()`` can return ``None``
    (e.g. for multi-GPU DDP runs, where the launching process collects no
    metrics), so ``results.save_dir`` is only used as a fallback.
    """
    trainer = getattr(model, "trainer", None)
    save_dir = getattr(trainer, "save_dir", None) or getattr(results, "save_dir", None)
    if save_dir is None:
        return None
    return Path(save_dir) / "weights" / "best.pt"


def main() -> int:
    """Train the one-class ball detector, then export ONNX and copy weights to MODEL_DIR.

    Returns:
        Process exit code: 0 on success, 2 when the trained ``best.pt`` cannot
        be located after training.
    """
    args = _parse_args()

    # Deferred import (presumably so --help and argument errors stay fast).
    from ultralytics import YOLO

    model = YOLO(args.model)
    train_kwargs = {
        "data": str(DATA_YAML),
        "epochs": args.epochs,
        "imgsz": args.imgsz,
        "name": args.name,
        "project": str(ROOT / "runs" / "detect"),
        "single_cls": True,
        "patience": 25,
        # NOTE(review): AMP is hard-disabled; the reason is not recorded here — confirm.
        "amp": False,
        "workers": args.workers,
        # Mosaic, horizontal flip, erasing and hue jitter are switched off, and
        # saturation/value jitter is kept small (deliberate dataset-specific tuning).
        "mosaic": 0.0,
        "fliplr": 0.0,
        "erasing": 0.0,
        "hsv_h": 0.0,
        "hsv_s": 0.15,
        "hsv_v": 0.15,
    }
    if args.batch is not None:
        train_kwargs["batch"] = args.batch  # already normalised by _coerce_batch
    if args.device:
        train_kwargs["device"] = args.device

    results = model.train(**train_kwargs)
    best = _find_best_weights(model, results)
    if best is None:
        console.print("[red]Training finished but its run directory could not be determined[/red]")
        return 2
    if not best.exists():
        console.print(f"[red]Training finished but best.pt was not found at {best}[/red]")
        return 2

    trained = YOLO(str(best))
    onnx_path = Path(trained.export(format="onnx", imgsz=args.imgsz, simplify=True, opset=12))
    MODEL_DIR.mkdir(parents=True, exist_ok=True)
    final_onnx = MODEL_DIR / "mlb26_ball.onnx"
    final_pt = MODEL_DIR / "mlb26_ball.pt"
    shutil.copy2(onnx_path, final_onnx)
    shutil.copy2(best, final_pt)
    console.print(f"[bold green]Exported ONNX -> {final_onnx}[/bold green]")
    console.print(f"[green]Saved PyTorch weights -> {final_pt}[/green]")
    return 0
| 66 | |
if __name__ == "__main__":
    # Hand main()'s integer result to the interpreter as the process exit status.
    exit_status = main()
    raise SystemExit(exit_status)