Zion Boggan
repos/Pitch Tracker CV/tools/yolo_train_ball.py
zionboggan.com ↗
68 lines · python
History for this file →
1
"""Train and export a one-class MLB ball detector with Ultralytics YOLO."""
2
from __future__ import annotations
3
 
4
import argparse
5
import shutil
6
from pathlib import Path
7
 
8
from rich.console import Console
9
 
10
# Repository root: two levels above this file (the script lives in <root>/tools/).
ROOT: Path = Path(__file__).resolve().parents[1]
# Ultralytics dataset config handed to training.
DATA_YAML: Path = ROOT.joinpath("configs", "ball_yolo.yaml")
# Directory the trained ONNX / .pt models are copied into.
MODEL_DIR: Path = ROOT.joinpath("configs", "models")
# Shared rich console used for coloured status output.
console: Console = Console()
14
 
15
def _batch_arg(text: str) -> int | float:
    """Parse and validate the ``--batch`` command-line value.

    Accepted forms (the ones Ultralytics documents for ``batch``):

    * a positive whole number such as ``16`` or ``16.0`` -> returned as ``int``;
    * a fraction strictly between 0 and 1 such as ``0.7`` -> returned as
      ``float`` (Ultralytics treats it as a target GPU-memory fraction);
    * ``-1`` -> returned as ``int`` (Ultralytics AutoBatch).

    Args:
        text: Raw string argparse received for ``--batch``.

    Returns:
        The normalized batch value to forward to ``model.train``.

    Raises:
        argparse.ArgumentTypeError: For anything else (``0``, other negatives,
            non-whole values >= 1 such as ``16.5``, ``nan``/``inf``, non-numeric
            text). argparse reports it as a usage error, so the value is not
            silently truncated or passed through.
    """
    try:
        value = float(text)
    except ValueError:
        raise argparse.ArgumentTypeError(f"invalid batch value: {text!r}") from None
    if value == -1:
        return -1
    if 0 < value < 1:
        return value
    # is_integer() is False for inf/nan, so those are rejected below.
    if value >= 1 and value.is_integer():
        return int(value)
    raise argparse.ArgumentTypeError(
        f"batch must be a positive integer, a fraction in (0, 1), or -1; got {text!r}"
    )


def _locate_best_weights(model: object, results: object) -> Path | None:
    """Return the expected ``weights/best.pt`` path of the training run that just finished.

    The run directory is read from ``results.save_dir`` when available, otherwise
    from ``model.trainer.save_dir``. The value returned by ``model.train()`` is not
    guaranteed to expose ``save_dir`` (it may be ``None``, and its type differs
    across Ultralytics versions), so neither source is assumed to exist.

    Args:
        model: The YOLO model object ``train()`` was called on.
        results: Whatever ``model.train()`` returned.

    Returns:
        The candidate ``best.pt`` path (existence NOT checked), or ``None`` when
        no run directory could be determined.
    """
    for source in (results, getattr(model, "trainer", None)):
        save_dir = getattr(source, "save_dir", None)
        if save_dir is not None:
            return Path(save_dir) / "weights" / "best.pt"
    return None


def main() -> int:
    """Train the one-class ball detector, export it to ONNX, and publish both models.

    Trains ``--model`` on ``configs/ball_yolo.yaml``. Ultralytics writes the run
    under ``runs/detect``, named after ``--name``. The run's ``weights/best.pt``
    is then exported to ONNX (opset 12, simplified) at ``--imgsz``. Finally the
    ONNX file and the ``best.pt`` weights are copied into ``configs/models`` as
    ``mlb26_ball.onnx`` and ``mlb26_ball.pt``.

    Returns:
        Process exit code: 0 on success. Returns 2 if the dataset config is
        missing, the trained ``best.pt`` cannot be found, or the ONNX export
        does not yield a file.
    """
    ap = argparse.ArgumentParser(description="Train/export MLB ball YOLO model.")
    ap.add_argument("--model", default="yolo11n.pt", help="Base model, e.g. yolo11n.pt or yolov8n.pt.")
    ap.add_argument("--epochs", type=int, default=80)
    ap.add_argument("--imgsz", type=int, default=640)
    ap.add_argument(
        "--batch",
        type=_batch_arg,
        default=None,
        help="Batch size: positive integer, fraction in (0, 1) for GPU-memory-based sizing, "
        "or -1 for AutoBatch. Omit for Ultralytics default.",
    )
    ap.add_argument("--device", default=None, help="cuda device id, cpu, or omit for auto.")
    ap.add_argument("--name", default="mlb26_ball")
    ap.add_argument("--workers", type=int, default=0, help="Dataloader workers. Use 0 on Windows/network paths.")
    args = ap.parse_args()

    # Fail fast, before the base model is loaded, if the dataset config is absent.
    if not DATA_YAML.is_file():
        console.print(f"[red]Dataset config not found: {DATA_YAML}[/red]")
        return 2

    # Imported only after argument parsing, so --help and argument errors exit
    # without loading ultralytics.
    from ultralytics import YOLO

    model = YOLO(args.model)
    train_kwargs = {
        "data": str(DATA_YAML),
        "epochs": args.epochs,
        "imgsz": args.imgsz,
        "name": args.name,
        "project": str(ROOT / "runs" / "detect"),
        "single_cls": True,  # treat every label as one class
        "patience": 25,  # early-stop after 25 epochs without improvement
        "amp": False,  # mixed-precision training disabled
        "workers": args.workers,
        # Conservative augmentation: no mosaic, horizontal flip, random erasing,
        # or hue shift; only mild saturation/value jitter.
        "mosaic": 0.0,
        "fliplr": 0.0,
        "erasing": 0.0,
        "hsv_h": 0.0,
        "hsv_s": 0.15,
        "hsv_v": 0.15,
    }
    if args.batch is not None:
        train_kwargs["batch"] = args.batch  # already validated/normalized by _batch_arg
    if args.device:
        train_kwargs["device"] = args.device

    results = model.train(**train_kwargs)
    best = _locate_best_weights(model, results)
    if best is None:
        console.print("[red]Training finished but its run directory could not be determined[/red]")
        return 2
    if not best.exists():
        console.print(f"[red]Training finished but best.pt was not found at {best}[/red]")
        return 2

    trained = YOLO(str(best))
    exported = trained.export(format="onnx", imgsz=args.imgsz, simplify=True, opset=12)
    onnx_path = Path(exported) if exported else None
    if onnx_path is None or not onnx_path.is_file():
        console.print("[red]ONNX export did not produce a file[/red]")
        return 2

    MODEL_DIR.mkdir(parents=True, exist_ok=True)
    final_onnx = MODEL_DIR / "mlb26_ball.onnx"
    final_pt = MODEL_DIR / "mlb26_ball.pt"
    shutil.copy2(onnx_path, final_onnx)
    shutil.copy2(best, final_pt)
    console.print(f"[bold green]Exported ONNX -> {final_onnx}[/bold green]")
    console.print(f"[green]Saved PyTorch weights -> {final_pt}[/green]")
    return 0
66
 
67
if __name__ == "__main__":
    # Use main()'s integer return value as the process exit status.
    exit_status = main()
    raise SystemExit(exit_status)