Zion Boggan
repos/Oversight/oversight_core/watermark.py
zionboggan.com ↗
398 lines · python
History for this file →
1
"""
2
oversight_core.watermark
3
========================
4
 
5
Per-recipient watermarking. The point is attribution after plaintext escape:
6
if a sealed file is decrypted and leaked, the recovered plaintext still contains
7
marks that identify WHICH recipient's copy it was.
8
 
9
This MVP ships three mark layers. Each is independently keyed, so an attacker
10
stripping one doesn't defeat the others. The `mark_id` is a random per-recipient
11
tag registered in the manifest - matching it in leaked content proves the source.
12
 
13
Layers:
14
  L1 (zero-width unicode stego):
15
      Embeds mark_id bits as ZWSP (bit 0) / ZWNJ (bit 1), framed by ZWJ
      delimiters, in text content. Survives copy-paste and most format
      conversions. Defeated by "normalize/strip invisibles" passes.
17
 
18
  L2 (whitespace pattern):
19
      Encodes bits as trailing space vs tab at line endings. Survives more aggressive
20
      cleaning than L1 because linters often don't touch trailing whitespace in
21
      content-bearing fields.
22
 
23
  L3 (synonym rotation + punctuation):
24
      Semantic watermarking via synonym-class rotation (151 classes in v2) and
25
      punctuation-style fingerprinting (Oxford comma, em dash, curly quotes).
26
      Survives format conversion, invisible-char stripping, and whitespace
27
      normalization because the marks are in the words and punctuation chosen.
28
      Implementation in oversight_core.semantic; wired in here via apply_all.
29
 
30
Future (not in MVP):
31
  - Visual DCT-domain watermarks for images (robust to recompression + screenshot)
32
  - Layout perturbation for PDFs (micro-kerning, line-spacing)
33
  - Structural marks for code files (whitespace + comment ordering)
34
 
35
All mark IDs are random per-recipient. Decoder returns the first matching ID
36
from the registry - that's your attribution.
37
"""
38
 
39
from __future__ import annotations
40
 
41
import secrets
42
from typing import Iterable, Optional
43
 
44
 
45
ZW_SPACE = "\u200b"
46
ZW_NONJOIN = "\u200c"
47
ZW_JOIN = "\u200d"
48
ZW_ALL = (ZW_SPACE, ZW_NONJOIN, ZW_JOIN)
49
 
50
 
51
def _bits_of(data: bytes) -> list[int]:
52
    out = []
53
    for byte in data:
54
        for i in range(8):
55
            out.append((byte >> (7 - i)) & 1)
56
    return out
57
 
58
 
59
def _bytes_from_bits(bits: Iterable[int]) -> bytes:
60
    bits = list(bits)
61
    n = (len(bits) // 8) * 8
62
    bits = bits[:n]
63
    out = bytearray()
64
    for i in range(0, n, 8):
65
        b = 0
66
        for j in range(8):
67
            b = (b << 1) | (bits[i + j] & 1)
68
        out.append(b)
69
    return bytes(out)
70
 
71
 
72
def new_mark_id(n_bytes: int = 8) -> bytes:
    """Generate a fresh random per-recipient mark ID.

    The ID is ``n_bytes`` long and comes from ``secrets.token_bytes``. The
    default of 8 bytes gives 64 bits, which is plenty for attribution.
    """
    mark_id = secrets.token_bytes(n_bytes)
    return mark_id
75
 
76
 
77
 
78
def embed_zw(text: str, mark_id: bytes, density: int = 40) -> str:
    """
    Embed mark_id into text as zero-width unicode characters.

    Encoding: a frame of [ZW_JOIN] [bits of mark_id as ZWSP/ZWNJ] [ZW_JOIN],
    where ZWSP carries a 0 bit and ZWNJ a 1 bit. Multiple redundant frames
    are scattered through the text.

    Args:
        text: Content to mark.
        mark_id: Bytes to embed; each byte becomes eight zero-width chars.
        density: Number of characters between frame insertions. A frame is
            placed after the character at every index that is a positive
            multiple of ``density`` (a 1000-char doc at the default of 40
            gets 24 frames).

    Returns:
        The text with at least one frame embedded. Text no longer than
        ``density`` gets a single frame appended at the end.

    Raises:
        ValueError: If ``density`` is less than 1.
    """
    if density < 1:
        # density == 0 used to surface as a ZeroDivisionError from the modulo.
        raise ValueError("density must be a positive integer")

    bits = _bits_of(mark_id)
    frame = ZW_JOIN + "".join(ZW_NONJOIN if b else ZW_SPACE for b in bits) + ZW_JOIN

    # `<=`, not `<`: when len(text) == density the largest index is
    # density - 1, so the loop below would never reach an insertion point and
    # the text would be returned with no mark at all.
    if len(text) <= density:
        return text + frame

    out = []
    for i, ch in enumerate(text):
        out.append(ch)
        if i > 0 and i % density == 0:
            out.append(frame)
    return "".join(out)
98
 
99
 
100
def extract_zw(text: str, mark_len_bytes: int = 8) -> list[bytes]:
    """
    Recover all candidate mark_ids from zero-width marks in text.

    A frame is a ZW_JOIN, then a run of ZWSP (bit 0) / ZWNJ (bit 1)
    characters, then a closing ZW_JOIN. Only frames whose run is exactly
    ``mark_len_bytes * 8`` bits long are accepted.

    Returns a list (may have repeats if multiple frames survived).
    """
    marks = []
    expected_bits = mark_len_bytes * 8
    n = len(text)
    i = 0
    while i < n:
        if text[i] != ZW_JOIN:
            i += 1
            continue

        # Collect the run of bit characters following the opening delimiter.
        bits = []
        j = i + 1
        while j < n and text[j] in (ZW_SPACE, ZW_NONJOIN):
            bits.append(0 if text[j] == ZW_SPACE else 1)
            j += 1

        if j < n and text[j] == ZW_JOIN and len(bits) == expected_bits:
            marks.append(_bytes_from_bits(bits))
            i = j + 1  # consume the closing delimiter of the accepted frame
        else:
            # Malformed or truncated frame. Resume AT j rather than past it:
            # if text[j] is a ZW_JOIN it may be the opening delimiter of the
            # next frame (e.g. a stray ZWJ from the source text sitting right
            # before a real frame), and skipping it would lose that frame.
            # j > i always holds, so the scan still makes progress.
            i = j
    return marks
121
 
122
 
123
 
124
def embed_ws(text: str, mark_id: bytes) -> str:
    """
    Encode bits as trailing space (bit 0) vs trailing tab (bit 1) on the first
    N eligible lines, where N is the number of bits in mark_id.

    Eligible lines are:
      - lines with no trailing whitespace: the bit character is appended;
      - lines that already end in a space or tab: extract_ws reads the last
        character of such a line as a bit no matter what, so the line must
        carry the correct one. It is left untouched when its last character
        already matches the bit, otherwise the bit character is appended.

    Lines ending in any other whitespace (e.g. "\\r") are passed through
    unchanged and carry no bit. Existing characters are never removed.
    """
    bits = _bits_of(mark_id)
    out_lines = []
    bi = 0
    for line in text.split("\n"):
        if bi >= len(bits):
            out_lines.append(line)
            continue

        suffix = " " if bits[bi] == 0 else "\t"
        if line.endswith((" ", "\t")):
            # Skipping this line would desynchronise the decoder, which
            # would count its existing trailing space/tab as a mark bit.
            out_lines.append(line if line.endswith(suffix) else line + suffix)
            bi += 1
        elif line.rstrip() == line:
            out_lines.append(line + suffix)
            bi += 1
        else:
            out_lines.append(line)
    return "\n".join(out_lines)
141
 
142
 
143
def extract_ws(text: str, mark_len_bytes: int = 8) -> Optional[bytes]:
    """
    Read the whitespace mark back out. Returns None if incomplete.

    Each line ending in a space contributes a 0 bit and each line ending in
    a tab a 1 bit; other lines are ignored. Reading stops once
    ``mark_len_bytes * 8`` bits have been collected.

    Zero-width characters at the end of a line are ignored before the check.
    apply_all runs embed_zw after embed_ws, so a zero-width frame can be
    inserted directly after a trailing space/tab; without this the bit would
    be missed and every later bit would shift position.
    """
    needed = mark_len_bytes * 8
    zero_width = "".join(ZW_ALL)
    bits: list[int] = []
    for line in text.split("\n"):
        visible = line.rstrip(zero_width)
        if visible.endswith(" "):
            bits.append(0)
        elif visible.endswith("\t"):
            bits.append(1)
        if len(bits) >= needed:
            break
    if len(bits) < needed:
        return None
    return _bytes_from_bits(bits[:needed])
157
 
158
 
159
 
160
# The L3 (semantic) layer is optional: record whether the sibling `semantic`
# module could be imported so the functions below can check _L3_AVAILABLE
# before touching `_semantic`.
try:
    from . import semantic as _semantic
    _L3_AVAILABLE = True
except ImportError:
    _L3_AVAILABLE = False
165
 
166
 
167
 
168
def extract_ws_partial(
    text: str, mark_len_bytes: int = 8
) -> tuple[Optional[bytes], float, int, int]:
    """
    Like extract_ws but returns partial results with confidence.

    Returns:
      (best_candidate, confidence, bits_recovered, bits_needed)

    If all bits are recovered, confidence = 1.0 and best_candidate is exact.
    If partial, best_candidate has recovered bits filled in and unknown bits
    set to 0, confidence = bits_recovered / bits_needed.
    If no bits are recovered, best_candidate is None and confidence is 0.0.

    Zero-width characters at the end of a line are ignored before the
    trailing space/tab check. apply_all runs embed_zw after embed_ws, so a
    zero-width frame can be inserted directly after a trailing space/tab;
    without this the bit would be missed and every later bit would shift.
    """
    needed = mark_len_bytes * 8
    zero_width = "".join(ZW_ALL)
    bits: list[int] = []
    for line in text.split("\n"):
        visible = line.rstrip(zero_width)
        if visible.endswith(" "):
            bits.append(0)
        elif visible.endswith("\t"):
            bits.append(1)
        if len(bits) >= needed:
            break

    recovered = len(bits)
    if recovered == 0:
        return None, 0.0, 0, needed

    # Zero-fill the unknown tail so the candidate always has the full length.
    padded = bits[:needed] + [0] * max(0, needed - recovered)
    candidate = _bytes_from_bits(padded[:needed])
    confidence = min(recovered, needed) / needed
    return candidate, confidence, min(recovered, needed), needed
199
 
200
 
201
 
202
def apply_all(
    text: str,
    mark_id: bytes,
    *,
    include_l3: bool = False,
    l3_mode: str = "full",
) -> str:
    """
    Run every available watermark layer over text and return the result.

    Ordering is deliberate. L3 (synonym rotation) goes first since it
    rewrites words; L2 (trailing whitespace) follows; L1 (zero-width
    unicode) goes last, because its invisible characters could otherwise
    split the words L3 needs to recognise.

    L3 is applied only when include_l3 is true AND the semantic module was
    importable; otherwise it is skipped without error. l3_mode is forwarded
    to l3_policy.apply_l3_safe as its ``mode`` argument.
    """
    marked = text
    if include_l3 and _L3_AVAILABLE:
        from . import l3_policy
        marked = l3_policy.apply_l3_safe(marked, mark_id, mode=l3_mode)
    with_whitespace = embed_ws(marked, mark_id)
    return embed_zw(with_whitespace, mark_id)
225
 
226
 
227
def recover_marks(text: str, mark_len_bytes: int = 8) -> dict:
    """
    Run the L1 and L2 extractors and collect their candidate mark bytes.

    The result maps a layer name to a list of candidates for the registry to
    compare against known recipient IDs. "L3_synonyms" is always an empty
    list here.
    """
    zero_width_marks = extract_zw(text, mark_len_bytes)
    whitespace_mark = extract_ws(text, mark_len_bytes)
    return {
        "L1_zero_width": zero_width_marks,
        "L2_whitespace": [whitespace_mark] if whitespace_mark else [],
        "L3_synonyms": [],
    }
237
 
238
 
239
def verify_l3(
    text: str,
    candidate_mark_ids: list[bytes],
    threshold: float = 0.70,
) -> list[tuple[bytes, float, dict]]:
    """
    Check each candidate mark_id against the semantic marks in text.

    Every candidate is passed to ``_semantic.verify_semantic``; those whose
    detail dict has a truthy "overall_match" are kept as
    (mark_id, detail["synonyms_score"], detail) tuples, sorted by score,
    highest first. Returns an empty list when the semantic module is not
    available.

    NOTE(review): ``threshold`` is accepted but never referenced in this
    function; the accept/reject decision comes entirely from
    "overall_match". Confirm whether it should be forwarded to
    verify_semantic.
    """
    if not _L3_AVAILABLE:
        return []

    checked = (
        (mid, _semantic.verify_semantic(text, mid)) for mid in candidate_mark_ids
    )
    matched = [
        (mid, detail["synonyms_score"], detail)
        for mid, detail in checked
        if detail["overall_match"]
    ]
    return sorted(matched, key=lambda hit: hit[1], reverse=True)
260
 
261
 
262
def recover_marks_v2(
    text: str,
    candidate_mark_ids: list[bytes] | None = None,
    mark_len_bytes: int = 8,
) -> dict:
    """
    Enhanced recovery with partial L2, L3 verification, and per-layer diagnostics.

    Args:
        text: Possibly-leaked content to inspect.
        candidate_mark_ids: Mark IDs to test against the L3 layer. L3 is
            only checked when this is non-empty and the semantic module is
            available.
        mark_len_bytes: Length of the mark IDs being looked for.

    Returns a dict with:
      - layers: per-layer results with confidence
      - candidates: fused candidate list
      - diagnostics: human-readable per-layer status strings

    Unique L1 marks are listed in the order they first appear in the text,
    so the output is the same from run to run.
    """
    l1_marks = extract_zw(text, mark_len_bytes)
    # dict.fromkeys de-duplicates while keeping first-seen order. A set would
    # make the order of marks, diagnostics and score ties depend on the
    # per-process hash seed.
    l1_unique = list(dict.fromkeys(l1_marks))

    l2_candidate, l2_confidence, l2_bits, l2_needed = extract_ws_partial(
        text, mark_len_bytes
    )
    # Only report an L2 mark when at least half of its bits were recovered.
    l2_marks = [l2_candidate] if l2_candidate and l2_confidence >= 0.5 else []

    l3_hits: list[tuple[bytes, float, dict]] = []
    if candidate_mark_ids and _L3_AVAILABLE:
        l3_hits = verify_l3(text, candidate_mark_ids)

    diagnostics = []
    if l1_unique:
        diagnostics.append(
            f"L1: {len(l1_marks)} frames found, "
            f"{len(l1_unique)} unique mark(s): "
            + ", ".join(m.hex() for m in l1_unique)
        )
    else:
        diagnostics.append(
            "L1: 0 zero-width frames found (invisible chars stripped?)"
        )

    if l2_confidence >= 1.0:
        diagnostics.append(
            f"L2: {l2_bits}/{l2_needed} bits recovered (100%), "
            f"mark: {l2_candidate.hex()}"
        )
    elif l2_confidence > 0:
        diagnostics.append(
            f"L2: {l2_bits}/{l2_needed} bits recovered "
            f"({l2_confidence:.0%} confidence), "
            f"partial candidate: {l2_candidate.hex()}"
        )
    else:
        diagnostics.append(
            "L2: 0 trailing whitespace marks found (whitespace stripped?)"
        )

    if not _L3_AVAILABLE:
        diagnostics.append("L3: semantic module not available")
    elif not candidate_mark_ids:
        diagnostics.append(
            "L3: no candidate mark_ids provided (query registry first)"
        )
    elif l3_hits:
        for mid, score, detail in l3_hits:
            diagnostics.append(
                f"L3: mark {mid.hex()} matched with score "
                f"{score:.2f} (synonyms) / "
                f"{detail['punctuation_hits']} (punctuation), "
                f"dict={detail['dict_version']}"
            )
    else:
        diagnostics.append(
            f"L3: {len(candidate_mark_ids)} candidate(s) tested, "
            "none matched above threshold"
        )

    all_candidates = _fuse_candidates(
        l1_unique, l2_candidate, l2_confidence, l3_hits
    )

    return {
        "layers": {
            "L1_zero_width": l1_unique,
            "L2_whitespace": l2_marks,
            "L2_confidence": l2_confidence,
            "L3_semantic": [(m, s) for m, s, _ in l3_hits],
        },
        "candidates": all_candidates,
        "diagnostics": diagnostics,
    }
349
 
350
 
351
def _fuse_candidates(
352
    l1_marks: list[bytes],
353
    l2_candidate: bytes | None,
354
    l2_confidence: float,
355
    l3_hits: list[tuple[bytes, float, dict]],
356
) -> list[tuple[bytes, float, str]]:
357
    """
358
    Multi-layer Bayesian fusion: combine evidence from all layers into a
359
    single ranked candidate list.
360
 
361
    Returns list of (mark_id, combined_score, evidence_summary).
362
 
363
    Scoring:
364
      - L1 exact match: 0.95 (high, but not 1.0 because of frame corruption)
365
      - L2 exact match: 0.90 (slightly lower, whitespace is fragile)
366
      - L2 partial: l2_confidence * 0.60 (scaled down for uncertainty)
367
      - L3 match: l3_score * 0.85 (probabilistic, weighted by synonym score)
368
 
369
    When multiple layers agree on the same mark_id, scores combine:
370
      combined = 1 - (1-s1)(1-s2)...(1-sN)
371
    This is a standard independence-assumption combination.
372
    """
373
    evidence: dict[bytes, list[tuple[float, str]]] = {}
374
 
375
    for m in l1_marks:
376
        evidence.setdefault(m, []).append((0.95, "L1"))
377
 
378
    if l2_candidate and l2_confidence >= 0.5:
379
        l2_score = min(l2_confidence, 1.0) * 0.90
380
        evidence.setdefault(l2_candidate, []).append((l2_score, "L2"))
381
 
382
    for m, s, _ in l3_hits:
383
        evidence.setdefault(m, []).append((s * 0.85, "L3"))
384
 
385
    results = []
386
    for mark_id, scores in evidence.items():
387
        if len(scores) == 1:
388
            combined = scores[0][0]
389
        else:
390
            combined = 1.0
391
            for s, _ in scores:
392
                combined *= (1.0 - s)
393
            combined = 1.0 - combined
394
        layers_hit = "+".join(lbl for _, lbl in scores)
395
        results.append((mark_id, combined, layers_hit))
396
 
397
    results.sort(key=lambda x: x[1], reverse=True)
398
    return results