Zion Boggan
repos/Oversight/oversight_core/formats/image.py
zionboggan.com ↗
170 lines · python
History for this file →
1
"""
2
oversight_core.formats.image - image format adapter.
3
 
4
DCT-domain frequency watermarking. Survives:
5
  - JPEG recompression (qualities >= 50)
6
  - Moderate resizing (up to ~50%)
7
  - Minor cropping
8
  - Format conversion (PNG <-> JPEG)
9
 
10
Does NOT survive:
11
  - Heavy compression (quality < 30)
12
  - Aggressive cropping (> 30% removed)
13
  - Rotation without knowing the angle
14
  - Deliberate adversarial watermark-removal attacks (use spread-spectrum
15
    methods for that; out of MVP scope)
16
 
17
Algorithm: Cox et al. additive spread-spectrum in the DCT mid-band.
18
  1. Convert to YCbCr, take Y (luma) channel.
19
  2. Apply 2D DCT to the full Y plane.
20
  3. Pick the first N mid-frequency coefficients in row-major scan order (skip DC and lowest).
21
  4. Embed bit b_i by replacing coefficient c_i with c_i + alpha * |c_i| * x_i,
22
     where x_i is a deterministic +1/-1 sequence derived from mark_id.
23
  5. Inverse DCT -> write back.
24
 
25
Recovery: sign-correlation between the DCT mid-band of the suspect image and
26
the expected bit sequence derived from a candidate mark_id.
27
"""
28
 
29
from __future__ import annotations
30
 
31
import hashlib
32
import io
33
from typing import Optional
34
 
35
import numpy as np
36
from PIL import Image
37
from scipy.fft import dct, idct
38
 
39
 
40
def _mark_to_sequence(mark_id: bytes, length: int) -> np.ndarray:
    """
    Expand mark_id into a deterministic sequence of +1/-1 values.

    The bit stream is SHA-256(mark_id || counter) for counter = 0, 1, 2, ...
    (counter encoded as 4 big-endian bytes). Each digest byte contributes its
    bits least-significant first; a set bit maps to +1, a clear bit to -1.

    Returns an int8 array with exactly `length` entries.
    """
    seq = np.zeros(length, dtype=np.int8)
    filled = 0
    block_no = 0
    while filled < length:
        digest = hashlib.sha256(mark_id + block_no.to_bytes(4, "big")).digest()
        # bitorder="little" yields bit 0 of each byte first, matching a
        # per-byte `(byte >> k) & 1` walk for k = 0..7.
        bits = np.unpackbits(
            np.frombuffer(digest, dtype=np.uint8), bitorder="little"
        )
        take = min(bits.size, length - filled)
        # Map {0, 1} -> {-1, +1} while staying in int8.
        seq[filled:filled + take] = bits[:take].astype(np.int8) * 2 - 1
        filled += take
        block_no += 1
    return seq
55
 
56
 
57
def _dct2(a: np.ndarray) -> np.ndarray:
    """Orthonormal 2-D DCT: transform along axis 0, then along axis 1."""
    first_pass = dct(a, axis=0, norm="ortho")
    return dct(first_pass, axis=1, norm="ortho")
59
 
60
 
61
def _idct2(a: np.ndarray) -> np.ndarray:
    """Orthonormal 2-D inverse DCT: invert along axis 0, then along axis 1."""
    first_pass = idct(a, axis=0, norm="ortho")
    return idct(first_pass, axis=1, norm="ortho")
63
 
64
 
65
def _pick_midband_indices(shape: tuple[int, int], n: int = 1000) -> np.ndarray:
    """
    Pick indices of mid-frequency DCT coefficients.

    A coefficient (i, j) is in the band when lo <= i + j <= hi, where lo and
    hi are 10% and 40% of the shorter image side. This skips the DC and
    lowest frequencies (too visible when perturbed) and the highest
    (destroyed by JPEG).

    The first `n` in-band positions in row-major order are returned. That
    order is part of the contract: embedding and detection must select the
    same coefficients, so it must not change.

    Args:
        shape: (height, width) of the DCT plane.
        n: maximum number of coordinates to return; n <= 0 yields none.

    Returns:
        Integer array of shape (k, 2) with k <= n rows of (row, col). The
        array has shape (0, 2) when the band is empty (very small images).
    """
    H, W = shape
    short_side = min(H, W)
    # Clamp to 1 so the DC term (0, 0) is never selected, even when 10% of
    # the short side truncates to 0 (images under 10 px on a side).
    lo = max(1, int(short_side * 0.10))
    hi = int(short_side * 0.40)

    coords: list[tuple[int, int]] = []
    remaining = max(n, 0)
    # Rows beyond `hi` cannot satisfy i + j <= hi, so they are never visited.
    for i in range(min(H, hi + 1)):
        if remaining == 0:
            break
        # In row i the band is the column range [lo - i, hi - i], clipped to
        # the image and to the number of coordinates still needed.
        j_start = max(0, lo - i)
        j_stop = min(W, hi - i + 1, j_start + remaining)
        if j_stop > j_start:
            coords.extend((i, j) for j in range(j_start, j_stop))
            remaining -= j_stop - j_start

    # reshape keeps a consistent (k, 2) layout even when nothing was picked.
    return np.array(coords, dtype=int).reshape(-1, 2)
80
 
81
 
82
def embed(
    image_bytes: bytes,
    mark_id: bytes,
    alpha: float = 0.10,
    n_coeffs: int = 1500,
) -> bytes:
    """
    Embed mark_id into the DCT mid-band of the image's luma channel.

    Algorithm: for each of up to n_coeffs mid-band coefficients c_i, replace
    it with
       c'_i = c_i + alpha * |c_i| * bit_i
    where bit_i is a deterministic +1/-1 sequence derived from mark_id.

    This additive-scaled-by-magnitude form gives reliable blind detection
    via normalized correlation, unlike pure sign-embedding which is
    destroyed by clipping after iDCT.

    Args:
        image_bytes: encoded image in any format Pillow can open.
        mark_id: identifier that seeds the +1/-1 sequence.
        alpha: embedding strength relative to each coefficient's magnitude.
        n_coeffs: maximum number of mid-band coefficients to modify.

    Returns:
        PNG bytes (lossless, to preserve the watermark for distribution).
        The image is converted to RGB before processing, so the output is
        always RGB. Caller can recompress to JPEG for transmission; per the
        original author's testing the watermark survives JPEG quality >= 60.
    """
    img = Image.open(io.BytesIO(image_bytes)).convert("RGB")
    y, cb, cr = img.convert("YCbCr").split()
    y_arr = np.array(y, dtype=np.float64)

    D = _dct2(y_arr)
    # Normalise to an (N, 2) integer array so that an empty selection can
    # still be used for indexing below.
    coords = np.asarray(
        _pick_midband_indices(D.shape, n=n_coeffs), dtype=np.intp
    ).reshape(-1, 2)
    bits = _mark_to_sequence(mark_id, len(coords))

    # Coordinates are unique, so one fancy-indexed assignment is equivalent
    # to updating each coefficient in turn.
    rows, cols = coords[:, 0], coords[:, 1]
    selected = D[rows, cols]
    D[rows, cols] = selected + alpha * np.abs(selected) * bits

    y_marked = _idct2(D)
    # Round before the integer cast: a bare astype() truncates toward zero,
    # which darkens the image by up to one grey level per pixel and turns
    # round-trip values such as 127.9999 into 127.
    y_marked = np.clip(np.rint(y_marked), 0, 255).astype(np.uint8)
    # A 2-D uint8 array is interpreted as mode "L" without the deprecated
    # `mode` argument.
    y2 = Image.fromarray(y_marked)

    out = Image.merge("YCbCr", (y2, cb, cr)).convert("RGB")
    buf = io.BytesIO()
    out.save(buf, format="PNG")
    return buf.getvalue()
124
 
125
 
126
def verify(
    image_bytes: bytes,
    candidate_mark_id: bytes,
    threshold: float = 0.05,
    n_coeffs: int = 1500,
) -> tuple[bool, float]:
    """
    Blind detection of candidate_mark_id in the image's DCT mid-band.

    Correlation metric (as computed here):
       score = sum(coeffs * expected) / (sum(|coeffs|) + 1e-9)

    where coeffs are the actual mid-band DCT values of the luma channel and
    expected is the +1/-1 sequence for candidate_mark_id. Because expected
    is +/-1, the score lies in [-1, 1]. The small constant keeps the
    division defined when every selected coefficient is zero or none were
    selected, in which case the score is 0.

    Args:
        image_bytes: encoded suspect image in any format Pillow can open.
        candidate_mark_id: identifier to test for.
        threshold: minimum positive score that counts as a match. The
            default is 0.05; calibrate on your own test set.
        n_coeffs: maximum number of mid-band coefficients to correlate;
            must match the value used when embedding.

    Returns:
        (match, score). match is True only when score is positive and at
        least `threshold`.
    """
    img = Image.open(io.BytesIO(image_bytes)).convert("RGB")
    y = img.convert("YCbCr").split()[0]
    y_arr = np.array(y, dtype=np.float64)

    D = _dct2(y_arr)
    # Normalise to an (N, 2) integer array so that an empty selection can
    # still be used for indexing below.
    coords = np.asarray(
        _pick_midband_indices(D.shape, n=n_coeffs), dtype=np.intp
    ).reshape(-1, 2)
    expected = _mark_to_sequence(candidate_mark_id, len(coords)).astype(np.float64)

    vals = D[coords[:, 0], coords[:, 1]]
    score = float(np.sum(vals * expected) / (np.sum(np.abs(vals)) + 1e-9))
    # Only positive correlation counts: a strongly negative score means the
    # image does not carry this mark, whatever the threshold.
    return (score > 0 and score >= threshold), score
161
 
162
 
163
def perceptual_hash(image_bytes: bytes) -> str:
    """
    Compute a perceptual hash (pHash) of the image for fuzzy leak-match lookup.

    The image is decoded, converted to RGB and passed to imagehash.phash;
    the string form of that hash is returned (64-bit, hex-encoded per the
    original note).
    """
    # Imported lazily so the module can be loaded without imagehash installed.
    import imagehash

    rgb_image = Image.open(io.BytesIO(image_bytes)).convert("RGB")
    digest = imagehash.phash(rgb_image)
    return str(digest)