Zion Boggan
repos/Pitch Tracker CV/tools/yolo_label_ball.py
zionboggan.com ↗
178 lines · python
History for this file →
1
"""Tiny OpenCV YOLO labeler for class 0 = ball.

Controls:
  drag left mouse: draw ball box, save, and go next (with --manual: draw only)
  s, Space, or Enter: save the current box (if any) and go next
  n: next without saving/changing
  b: previous
  c: clear current label (also deletes its label file)
  e: save empty/no-ball label and go next
  q or Esc: quit
"""
12
from __future__ import annotations
13
 
14
import argparse
15
from pathlib import Path
16
 
17
import cv2
18
from rich.console import Console
19
 
20
# Two levels up from this file: the directory that contains this script's folder.
ROOT = Path(__file__).resolve().parents[1]
# Default dataset layout: <ROOT>/datasets/mlb26_ball_yolo/{images,labels}/all
DATA = ROOT.joinpath("datasets", "mlb26_ball_yolo")
IMG_DIR = DATA.joinpath("images", "all")
LBL_DIR = DATA.joinpath("labels", "all")
console = Console()
25
 
26
class LabelState:
    """Per-image labeling state shared by the mouse callback and the draw loop.

    ``box`` is kept in original-image pixel coordinates as ``(x1, y1, x2, y2)``,
    while ``drag_start`` / ``drag_now`` are in display (scaled) coordinates.
    """

    def __init__(self, image, scale: float):
        self.image = image  # full-resolution image; only .shape[:2] (h, w) is read here
        self.scale = scale  # display size / original size
        self.box = None  # accepted box in image pixels, or None
        self.drag_start = None  # display-space (x, y) where the current drag began
        self.drag_now = None  # display-space (x, y) of the pointer during a drag
        self.auto_advance = False  # set by the mouse handler to request save + next

    def set_box_from_display(self, x1, y1, x2, y2) -> bool:
        """Convert a display-space drag into an image-space box and store it.

        The corners may be given in any order. The box is clamped to the image
        bounds and accepted only if it is at least 2 px wide and 2 px tall;
        otherwise ``self.box`` is left unchanged.

        Returns:
            True if a new box was stored, False if the drag was rejected.
        """
        h, w = self.image.shape[:2]
        left = max(0, int(min(x1, x2) / self.scale))
        top = max(0, int(min(y1, y2) / self.scale))
        right = min(w - 1, int(max(x1, x2) / self.scale))
        bottom = min(h - 1, int(max(y1, y2) / self.scale))
        # Reject clicks and sliver drags; keep whatever box was there before.
        if right - left < 2 or bottom - top < 2:
            return False
        self.box = (left, top, right, bottom)
        return True
45
 
46
def yolo_line(box, w, h) -> str:
    """Format one pixel-space box as a YOLO label line for class 0.

    *box* is ``(x1, y1, x2, y2)`` in pixels and *w* / *h* are the image size.
    The line is the class id 0 followed by center x, center y, width and
    height, each normalized by the image size and printed with six decimals,
    terminated by a newline. Values are not clamped to [0, 1].
    """
    left, top, right, bottom = box
    center_x = (left + right) / 2.0 / w
    center_y = (top + bottom) / 2.0 / h
    box_w = (right - left) / w
    box_h = (bottom - top) / h
    coords = " ".join(f"{value:.6f}" for value in (center_x, center_y, box_w, box_h))
    return f"0 {coords}\n"
53
 
54
def load_label(path: Path, w: int, h: int):
    """Read the first YOLO box from *path* and convert it to image pixels.

    Args:
        path: label file whose first line is ``cls xc yc bw bh`` with
            coordinates normalized to the image size (the format ``yolo_line``
            writes). The class id is ignored and only the first box is used.
        w: image width in pixels.
        h: image height in pixels.

    Returns:
        ``(x1, y1, x2, y2)`` in pixels, or None when the file is missing,
        empty, or has fewer than five whitespace-separated fields.

    Raises:
        ValueError: if one of the four coordinate fields is not a number.
    """
    try:
        text = path.read_text(encoding="utf-8")
    except FileNotFoundError:
        return None
    parts = text.split()
    if len(parts) < 5:
        return None
    xc, yc, bw, bh = map(float, parts[1:5])
    # Round instead of truncating: float error can put an exact pixel edge at
    # e.g. 1.9999999999999998, which int() turns into 1. The box would then
    # drift by a pixel every time it is loaded and re-saved.
    x1 = round((xc - bw / 2.0) * w)
    y1 = round((yc - bh / 2.0) * h)
    x2 = round((xc + bw / 2.0) * w)
    y2 = round((yc + bh / 2.0) * h)
    return x1, y1, x2, y2
67
 
68
def draw(state: LabelState, title: str):
    """Render the frame shown in the labeling window.

    The stored box (image coordinates) is drawn on a copy of the full-size
    image before it is downscaled by ``state.scale``. Then the in-progress drag
    rectangle (display coordinates), the outlined title, and the help bar
    are drawn on the scaled frame. ``state.image`` itself is not modified.
    Returns the display-sized frame.
    """
    canvas = state.image.copy()
    if state.box:
        left, top, right, bottom = state.box
        cv2.rectangle(canvas, (left, top), (right, bottom), (0, 255, 255), 2)
    frame = cv2.resize(canvas, None, fx=state.scale, fy=state.scale, interpolation=cv2.INTER_AREA)
    if state.drag_start and state.drag_now:
        cv2.rectangle(frame, state.drag_start, state.drag_now, (0, 255, 0), 2)

    font = cv2.FONT_HERSHEY_SIMPLEX
    # Thick black pass first, thin white pass on top, so the title stays legible.
    for color, thickness in (((0, 0, 0), 4), ((255, 255, 255), 2)):
        cv2.putText(frame, title, (12, 28), font, 0.8, color, thickness, cv2.LINE_AA)

    frame_h, frame_w = frame.shape[:2]
    help_text = "Draw tight box around baseball. Mouse-up saves + next. b=back c=clear e=empty q=quit"
    # Filled black strip along the bottom edge as a background for the help text.
    cv2.rectangle(frame, (8, frame_h - 42), (min(frame_w - 8, 980), frame_h - 8), (0, 0, 0), -1)
    cv2.putText(frame, help_text, (16, frame_h - 18), font, 0.65, (255, 255, 255), 2, cv2.LINE_AA)
    return frame
82
 
83
def main() -> int:
    """Run the interactive ball labeler and return a process exit code.

    Walks the .jpg/.jpeg/.png files in ``--images`` in sorted order and shows
    each one with its existing label (if any). YOLO labels are written to
    ``--labels`` as ``<image stem>.txt``.

    Returns:
        0 when the user quits or every image has been visited; 2 on an invalid
        ``--scale``, a missing or empty image directory, or an unknown
        ``--start-name``.
    """
    ap = argparse.ArgumentParser(description="Label YOLO ball boxes.")
    ap.add_argument("--images", type=Path, default=IMG_DIR)
    ap.add_argument("--labels", type=Path, default=LBL_DIR)
    ap.add_argument("--scale", type=float, default=0.65)
    ap.add_argument("--from-start", action="store_true", help="start at first image instead of first unlabeled image")
    ap.add_argument("--start-index", type=int, default=0, help="1-based index to start from")
    ap.add_argument("--start-name", default="", help="exact image filename to start from")
    ap.add_argument("--manual", action="store_true", help="do not auto-save/advance on mouse release")
    args = ap.parse_args()

    # Display coordinates are divided by the scale, and cv2.resize needs a
    # positive factor, so reject bad values before opening any window.
    if args.scale <= 0:
        console.print(f"[red]--scale must be positive, got {args.scale}[/red]")
        return 2
    # iterdir() on a missing directory would end in a bare traceback.
    if not args.images.is_dir():
        console.print(f"[red]Image directory not found: {args.images}[/red]")
        return 2

    args.labels.mkdir(parents=True, exist_ok=True)
    images = sorted(p for p in args.images.iterdir() if p.suffix.lower() in {".jpg", ".jpeg", ".png"})
    if not images:
        console.print(f"[red]No images found in {args.images}[/red]")
        return 2

    # Starting image: explicit name > explicit 1-based index > first image
    # without a label file (unless --from-start).
    idx = 0
    if args.start_name:
        found = next((i for i, p in enumerate(images) if p.name == args.start_name), None)
        if found is None:
            console.print(f"[red]--start-name not found: {args.start_name}[/red]")
            return 2
        idx = found
    elif args.start_index > 0:
        idx = min(len(images) - 1, max(0, args.start_index - 1))
    elif not args.from_start:
        for i, image_path in enumerate(images):
            if not (args.labels / f"{image_path.stem}.txt").exists():
                idx = i
                break
        else:
            console.print("[yellow]Every image already has a label file; starting at the first image.[/yellow]")
    win = "label ball"
    # Direction of the last navigation (+1 forward, -1 back). Unreadable images
    # are skipped in this direction so "b" can step back past them.
    step = 1

    while 0 <= idx < len(images):
        path = images[idx]
        img = cv2.imread(str(path))
        if img is None:
            console.print(f"[yellow]Skipping unreadable image: {path.name}[/yellow]")
            idx += step
            if idx < 0:
                # images[0] is unreadable and we were walking backwards:
                # turn around instead of leaving the loop as if finished.
                idx, step = 1, 1
            continue
        h, w = img.shape[:2]
        label_path = args.labels / f"{path.stem}.txt"
        state = LabelState(img, args.scale)
        state.box = load_label(label_path, w, h)

        def on_mouse(event, x, y, flags, param):
            if event == cv2.EVENT_LBUTTONDOWN:
                state.drag_start = (x, y)
                state.drag_now = (x, y)
            elif event == cv2.EVENT_MOUSEMOVE and state.drag_start:
                state.drag_now = (x, y)
            elif event == cv2.EVENT_LBUTTONUP and state.drag_start:
                previous_box = state.box
                state.set_box_from_display(state.drag_start[0], state.drag_start[1], x, y)
                state.drag_start = None
                state.drag_now = None
                # set_box_from_display stores a fresh tuple only when it accepts
                # the drag. A plain click or sliver drag leaves the previous
                # (possibly loaded) box in place and must not save + advance.
                if state.box is not previous_box and not args.manual:
                    state.auto_advance = True

        cv2.namedWindow(win, cv2.WINDOW_NORMAL)
        cv2.setMouseCallback(win, on_mouse)
        # idx and path stay fixed until we break out, so build the title once.
        title = f"{idx + 1}/{len(images)} {path.name}  auto-save on mouse-up | s save | b back | c clear | q quit"
        while True:
            cv2.imshow(win, draw(state, title))
            key = cv2.waitKey(16) & 0xFF
            if state.auto_advance:
                label_path.write_text(yolo_line(state.box, w, h), encoding="utf-8")
                idx, step = idx + 1, 1
                break
            if key in (ord("q"), 27):  # 27 = Esc
                cv2.destroyAllWindows()
                return 0
            if key == ord("c"):
                state.box = None
                if label_path.exists():
                    label_path.unlink()
            if key in (ord("s"), 13, 32):  # 13 = Enter, 32 = Space
                if state.box:
                    label_path.write_text(yolo_line(state.box, w, h), encoding="utf-8")
                idx, step = idx + 1, 1
                break
            if key == ord("e"):
                label_path.write_text("", encoding="utf-8")
                idx, step = idx + 1, 1
                break
            if key == ord("n"):
                idx, step = idx + 1, 1
                break
            if key == ord("b"):
                idx, step = max(0, idx - 1), -1
                break

    cv2.destroyAllWindows()
    console.print("[green]Labeling complete.[/green]")
    return 0
176
 
177
if __name__ == "__main__":
    # Use main()'s integer return value (0 or 2) as the process exit status.
    raise SystemExit(main())