Commit cc327b1f by nzy

first commit

parents
# Python
__pycache__/
*.pyc
*.pyo
*.egg-info/
dist/
build/
# Virtual environment
.venv/
# Testing
.pytest_cache/
.coverage
htmlcov/
# OS
.DS_Store
Thumbs.db
# IDE
.vscode/
.idea/
This diff is collapsed. Click to expand it.
__version__ = "0.1.0"
"""Allow `python -m data_tool` invocation."""
from .main import main
import sys
sys.exit(main())
"""Document class wrapping raw text with precomputed n-grams for signal computation."""
import re
from typing import Optional, Tuple, Callable
from .utils import normalize, form_ngrams, TextSlice
def _compute_ngrams(seq, n):
return tuple(form_ngrams(iter(seq), n))
def _split_paragraphs(
text: str, normalizer: Callable[[str], str], remove_empty: bool = True
) -> Tuple[TextSlice, ...]:
"""Split text into paragraphs. Adapted from dolma."""
slices = tuple(
TextSlice(normalizer(text[m.start():m.end()]), m.start(), m.end())
for m in re.finditer(r"([^\n]*\n|[^\n]+$)", text)
)
if remove_empty:
slices = tuple(s for s in slices if s.text.strip())
return slices
class Document:
"""Wraps a text string and precomputes normalized forms, lines, words, and n-grams."""
__slots__ = (
"_raw_content",
"_normalized_content",
"_raw_lines",
"_normalized_lines",
"_raw_words",
"_normalized_words",
"_norm_2grams",
"_norm_3grams",
"_norm_4grams",
"_norm_5grams",
"_norm_6grams",
"_norm_7grams",
"_norm_8grams",
"_norm_9grams",
"_norm_10grams",
"source_file",
)
def __init__(self, text: str, source_file: Optional[str] = None):
self._raw_content = text
self._normalized_content = normalize(text)
self.source_file = source_file # filename from JSONL record, used for per-language signals
self._raw_lines = _split_paragraphs(text, normalizer=lambda x: x, remove_empty=False)
self._normalized_lines = _split_paragraphs(text, normalizer=normalize, remove_empty=False)
self._raw_words = tuple(self._raw_content.split())
self._normalized_words = tuple(self._normalized_content.split())
self._norm_2grams = _compute_ngrams(self._normalized_words, 2)
self._norm_3grams = _compute_ngrams(self._normalized_words, 3)
self._norm_4grams = _compute_ngrams(self._normalized_words, 4)
self._norm_5grams = _compute_ngrams(self._normalized_words, 5)
self._norm_6grams = _compute_ngrams(self._normalized_words, 6)
self._norm_7grams = _compute_ngrams(self._normalized_words, 7)
self._norm_8grams = _compute_ngrams(self._normalized_words, 8)
self._norm_9grams = _compute_ngrams(self._normalized_words, 9)
self._norm_10grams = _compute_ngrams(self._normalized_words, 10)
def __len__(self):
return len(self._raw_content)
@property
def raw_content(self):
return self._raw_content
@property
def normalized_content(self):
return self._normalized_content
@property
def raw_lines(self) -> Tuple[TextSlice, ...]:
return self._raw_lines
@property
def normalized_lines(self) -> Tuple[TextSlice, ...]:
return self._normalized_lines
@property
def raw_words(self) -> Tuple[str, ...]:
return self._raw_words
@property
def normalized_words(self) -> Tuple[str, ...]:
return self._normalized_words
@property
def num_raw_words(self):
return len(self._raw_words)
@property
def num_normalized_words(self):
return len(self._normalized_words)
@property
def norm_2grams(self):
return self._norm_2grams
@property
def norm_3grams(self):
return self._norm_3grams
@property
def norm_4grams(self):
return self._norm_4grams
@property
def norm_5grams(self):
return self._norm_5grams
@property
def norm_6grams(self):
return self._norm_6grams
@property
def norm_7grams(self):
return self._norm_7grams
@property
def norm_8grams(self):
return self._norm_8grams
@property
def norm_9grams(self):
return self._norm_9grams
@property
def norm_10grams(self):
return self._norm_10grams
def get_ngrams(self, n: int) -> Tuple[Tuple[str, ...], ...]:
"""Get cached n-grams for the given size (2-10)."""
return getattr(self, f"norm_{n}grams")
"""Threshold-based document filtering using quality signal values."""
import json
import statistics
from typing import Dict, List, Optional, Tuple, Union
DEFAULT_THRESHOLDS: Dict[str, Dict[str, float]] = {
# Content
"rps_doc_ldnoobw_words": {"max": 0},
"rps_doc_lorem_ipsum": {"max": 0.0},
"rps_doc_curly_bracket": {"max": 0.1},
"rps_doc_stop_word_fraction": {"min": 0.01},
# Natural language
"rps_doc_word_count": {"min": 50},
"rps_doc_num_sentences": {"min": 3},
"rps_doc_mean_word_length": {"min": 2.0, "max": 12.0},
"rps_doc_symbol_to_word_ratio": {"max": 0.1},
"rps_doc_frac_lines_end_with_ellipsis": {"max": 0.3},
"rps_doc_frac_no_alph_words": {"max": 0.5},
"rps_doc_frac_unique_words": {"min": 0.1},
"rps_doc_unigram_entropy": {"min": 1.0},
"rps_doc_frac_all_caps_words": {"max": 0.4},
# Repetition
"rps_doc_frac_chars_top_2gram": {"max": 0.3},
"rps_doc_frac_chars_top_3gram": {"max": 0.2},
"rps_doc_frac_chars_top_4gram": {"max": 0.15},
"rps_doc_frac_chars_dupe_5grams": {"max": 0.3},
"rps_doc_frac_chars_dupe_6grams": {"max": 0.25},
"rps_doc_frac_chars_dupe_7grams": {"max": 0.2},
"rps_doc_frac_chars_dupe_8grams": {"max": 0.15},
"rps_doc_frac_chars_dupe_9grams": {"max": 0.1},
"rps_doc_frac_chars_dupe_10grams": {"max": 0.1},
# Line-level (aggregated via mean)
"rps_lines_javascript_counts": {"max": 0},
"rps_lines_ending_with_terminal_punctution_mark": {"min": 0.5},
"rps_lines_num_words": {"min": 5},
"rps_lines_uppercase_letter_fraction": {"max": 0.3},
"rps_lines_numerical_chars_fraction": {"max": 0.3},
"rps_lines_start_with_bulletpoint": {"max": 0.9},
# Code quality
"rps_doc_code_keyword_density": {"max": 0.3},
"rps_doc_html_tag_density": {"max": 0.1},
"rps_doc_code_line_fraction": {"max": 0.3},
# Garbled characters
"rps_doc_garbled_char_frac": {"max": 0.15},
}
# Signals that return per-line lists and need aggregation
_LINE_SIGNALS = frozenset({
"rps_lines_javascript_counts",
"rps_lines_ending_with_terminal_punctution_mark",
"rps_lines_num_words",
"rps_lines_uppercase_letter_fraction",
"rps_lines_numerical_chars_fraction",
"rps_lines_start_with_bulletpoint",
})
def _aggregate_line_signal(values: List[float], agg: str = "mean") -> float:
"""Aggregate a per-line signal into a single document-level value."""
if not values:
return 0
if agg == "mean":
return statistics.mean(values)
if agg == "max":
return max(values)
if agg == "min":
return min(values)
if agg == "sum":
return sum(values)
raise ValueError(f"Unknown aggregation: {agg}")
def check_document(
signal_values: Dict[str, Optional[Union[float, List[float]]]],
thresholds: Optional[Dict[str, Dict[str, float]]] = None,
line_agg: str = "mean",
) -> Tuple[bool, Dict[str, str]]:
"""Check whether a document passes all threshold filters.
Returns (keep, reasons) where reasons maps signal_name -> failure description.
A signal value of None is treated as passing (the signal was not computable).
All failing signals are collected — not just the first one.
"""
if thresholds is None:
thresholds = DEFAULT_THRESHOLDS
reasons: Dict[str, str] = {}
for name, rules in thresholds.items():
value = signal_values.get(name)
if value is None:
continue
if name in _LINE_SIGNALS and isinstance(value, list):
value = _aggregate_line_signal(value, agg=line_agg)
if not isinstance(value, (int, float)):
continue
if "min" in rules and value < rules["min"]:
reasons[name] = f"value={value} < min={rules['min']}"
if "max" in rules and value > rules["max"]:
reasons[name] = f"value={value} > max={rules['max']}"
if reasons:
return False, reasons
return True, {}
def load_thresholds(path: str) -> Dict[str, Dict[str, float]]:
"""Load custom thresholds from a JSON file and merge with defaults."""
with open(path, "r") as f:
custom = json.load(f)
merged = dict(DEFAULT_THRESHOLDS)
merged.update(custom)
return merged
"""Strip common web boilerplate: navigation, cookie notices, footers, social bars, etc."""
import re
from typing import List, Optional, Tuple
# ---------------------------------------------------------------------------
# Boilerplate patterns — each is a (regex, category, severity) tuple.
# Severity: 3 = strong signal (always strip), 2 = medium, 1 = weak (strip
# only at document boundaries).
# ---------------------------------------------------------------------------
_BOILERPLATE_RULES: List[Tuple[re.Pattern, str, int]] = [
# -- Strong signals (severity 3): almost certainly boilerplate --
(re.compile(r"skip\s+to\s+(main\s+)?content", re.I), "accessibility", 3),
(re.compile(r"cookie\s*(?:consent|notice|policy|settings|preferences)", re.I), "cookies", 3),
(re.compile(r"accept\s+(?:all\s+)?cookies", re.I), "cookies", 3),
(re.compile(r"we\s+use\s+cookies", re.I), "cookies", 3),
(re.compile(r"this\s+site\s+uses?\s+cookies", re.I), "cookies", 3),
(re.compile(r"gdpr", re.I), "cookies", 3),
(re.compile(r"ccpa", re.I), "privacy", 3),
(re.compile(r"share\s+(?:on|this|to)\s*(?:facebook|twitter|x\b|linkedin|pinterest|reddit|whatsapp|telegram|email)", re.I), "social", 3),
(re.compile(r"(?:facebook|twitter|x\b|linkedin|pinterest|reddit)\s*share", re.I), "social", 3),
(re.compile(r"follow\s+(?:us|me)\s+on\s*(?:facebook|twitter|x\b|instagram|linkedin|youtube)", re.I), "social", 3),
(re.compile(r"copyright\s*©?\s*\d{4}", re.I), "copyright", 3),
(re.compile(r"all\s+rights?\s+reserved", re.I), "copyright", 3),
(re.compile(r"terms\s+(?:of\s+(?:service|use)|and\s+conditions)", re.I), "legal", 3),
(re.compile(r"privacy\s*(?:policy|notice|statement)", re.I), "legal", 3),
(re.compile(r"advertisement|sponsored\s+(?:content|by|post)|ad\s+choices", re.I), "ads", 3),
(re.compile(r"subscribe\s+(?:to|now|for|today)", re.I), "subscribe", 3),
(re.compile(r"(?:sign|log)\s*(?:in|out|up)\s*(?:/|or|with|to)", re.I), "auth", 3),
(re.compile(r"create\s+(?:an?\s+)?account", re.I), "auth", 3),
(re.compile(r"(?:leave|post)\s+a\s+comment", re.I), "comments", 3),
(re.compile(r"\d+\s+comments?", re.I), "comments", 3),
(re.compile(r"posted\s+(?:by|on)\b", re.I), "comments", 3),
(re.compile(r"print\s+(?:this\s+(?:page|article)|friendly)", re.I), "print", 3),
(re.compile(r"email\s+this\s+(?:page|article|story)", re.I), "share", 3),
# -- Medium signals (severity 2): likely boilerplate --
(re.compile(r"main\s+navigation", re.I), "navigation", 2),
(re.compile(r"breadcrumb", re.I), "navigation", 2),
(re.compile(r"site\s+map", re.I), "navigation", 2),
(re.compile(r"(?:click|tap)\s+here\s+to\s+(?:read|view|learn|see|continue)", re.I), "navigation", 2),
(re.compile(r"read\s+(?:more|full|on)\b", re.I), "navigation", 2),
(re.compile(r"related\s+(?:articles?|posts?|stories?|content|links?)", re.I), "related", 2),
(re.compile(r"you\s+(?:may|might)\s+(?:also\s+)?(?:like|enjoy|be\s+interested)", re.I), "related", 2),
(re.compile(r"recommended\s+(?:for\s+you|articles?|posts?)", re.I), "related", 2),
(re.compile(r"previous\s+(?:article|post|story|page)", re.I), "navigation", 2),
(re.compile(r"next\s+(?:article|post|story|page)", re.I), "navigation", 2),
(re.compile(r"back\s+to\s+top", re.I), "navigation", 2),
(re.compile(r"scroll\s+(?:down|up)\s+(?:for|to)", re.I), "navigation", 2),
(re.compile(r"newsletter(?:\s+sign\s*up)?", re.I), "subscribe", 2),
(re.compile(r"(?:free\s+)?newsletter", re.I), "subscribe", 2),
(re.compile(r"register\s+(?:now|today|free|here)", re.I), "auth", 2),
(re.compile(r"©\s*\d{4}", re.I), "copyright", 2),
(re.compile(r"contact\s+us", re.I), "footer", 2),
(re.compile(r"about\s+(?:us|this\s+(?:site|blog))", re.I), "footer", 2),
(re.compile(r"disclaimer", re.I), "legal", 2),
(re.compile(r"table\s+of\s+contents?", re.I), "navigation", 2),
# -- Weak signals (severity 1): may be boilerplate at document boundaries --
(re.compile(r"^(?:home|about|contact|search|menu|help|faq)\s*$", re.I), "nav-item", 1),
(re.compile(r"^(?:previous|next|back|forward)\s*$", re.I), "nav-item", 1),
(re.compile(r"^(?:page\s*\d+)\s*$", re.I), "pagination", 1),
(re.compile(r"^(?:1|2|3|4|5|6|7|8|9|10|11|12|13|14|15)\s*$", re.I), "pagination", 1),
(re.compile(r"^\s*\[.*\]\s*$", re.I), "nav-item", 1),
]
# Lines shorter than this that match any boilerplate rule are always stripped
_MIN_BOILERPLATE_LINE_LEN = 3
# Maximum fraction of lines that can be stripped before we stop (safety valve)
_MAX_STRIP_FRACTION = 0.30
def _boilerplate_score(line: str) -> Tuple[int, str]:
"""Return (severity, category) for a line, or (0, '') if not boilerplate."""
stripped = line.strip()
if len(stripped) <= _MIN_BOILERPLATE_LINE_LEN:
return 0, ""
for pattern, category, severity in _BOILERPLATE_RULES:
if pattern.search(stripped):
return severity, category
return 0, ""
def clean_headers(text: str) -> str:
"""Remove common web boilerplate headers, footers, and chrome from text.
Strips boilerplate lines aggressively from the top and bottom of the
document, and removes strong boilerplate lines from the middle.
Returns the cleaned text.
"""
if not text:
return text
lines = text.split("\n")
if len(lines) <= 3:
return text
scored = [_boilerplate_score(line) for line in lines]
keep = [True] * len(lines)
# -- Phase 1: strip header boilerplate (top-down until we hit content) --
header_strip_count = 0
for i in range(len(lines)):
severity, _ = scored[i]
if severity >= 2: # medium+ signal at the top is boilerplate
keep[i] = False
header_strip_count += 1
elif severity >= 1 and header_strip_count > 0:
# weak signal after other boilerplate is also stripped
keep[i] = False
header_strip_count += 1
elif len(lines[i].strip()) > 30 and severity == 0:
# found a substantial non-boilerplate line — stop stripping header
break
else:
# short/empty non-boilerplate line — keep it but continue scanning
pass
# -- Phase 2: strip footer boilerplate (bottom-up until we hit content) --
footer_strip_count = 0
for i in range(len(lines) - 1, -1, -1):
if not keep[i]:
continue
severity, _ = scored[i]
if severity >= 2:
keep[i] = False
footer_strip_count += 1
elif severity >= 1 and footer_strip_count > 0:
keep[i] = False
footer_strip_count += 1
elif len(lines[i].strip()) > 30 and severity == 0:
break
else:
pass
# -- Phase 3: strip strong boilerplate lines from the middle --
for i in range(len(lines)):
if not keep[i]:
continue
severity, _ = scored[i]
if severity >= 3:
# Only strip strong signals from the middle if they're short lines
# or clearly standalone boilerplate
if len(lines[i].strip()) < 120:
keep[i] = False
# Safety valve: don't strip too much
stripped_frac = (len(lines) - sum(keep)) / len(lines)
if stripped_frac > _MAX_STRIP_FRACTION:
return text
cleaned = "\n".join(line for line, k in zip(lines, keep) if k)
return cleaned
def is_mostly_boilerplate(text: str) -> bool:
"""Check if a document is mostly boilerplate (should be filtered entirely)."""
if not text:
return True
lines = text.split("\n")
if len(lines) < 3:
return False
scored = [_boilerplate_score(line) for line in lines]
boilerplate_lines = sum(1 for s, _ in scored if s >= 2)
return boilerplate_lines / len(lines) > 0.5
"""LSH-based near-deduplication using MinHash banded signatures and Union-Find."""
import logging
from collections import defaultdict
from typing import Any, Dict, List, Optional, Tuple
from .minhash import MinHash
logger = logging.getLogger(__name__)
class UnionFind:
"""Disjoint-set data structure with path compression and union by rank."""
__slots__ = ("_parent", "_rank")
def __init__(self):
self._parent: Dict[int, int] = {}
self._rank: Dict[int, int] = {}
def add(self, x: int) -> None:
if x not in self._parent:
self._parent[x] = x
self._rank[x] = 0
def find(self, x: int) -> int:
root = x
while self._parent[root] != root:
root = self._parent[root]
while self._parent[x] != root:
nxt = self._parent[x]
self._parent[x] = root
x = nxt
return root
def union(self, x: int, y: int) -> None:
self.add(x)
self.add(y)
rx = self.find(x)
ry = self.find(y)
if rx == ry:
return
if self._rank[rx] < self._rank[ry]:
self._parent[rx] = ry
elif self._rank[rx] > self._rank[ry]:
self._parent[ry] = rx
else:
self._parent[ry] = rx
self._rank[rx] += 1
def components(self) -> Dict[int, List[int]]:
groups: Dict[int, List[int]] = defaultdict(list)
for node in self._parent:
groups[self.find(node)].append(node)
return dict(groups)
def _build_edges(
records: List[Dict[str, Any]],
minhash: MinHash,
text_key: str = "text",
) -> List[Tuple[int, int]]:
"""Build candidate edges from banded MinHash signatures."""
# bucket: (band_idx, band_hash) -> [doc indices]
buckets: Dict[Tuple[int, bytes], List[int]] = defaultdict(list)
for idx, record in enumerate(records):
tokens = record[text_key].split()
sig = minhash.compute_banded_signatures(tuple(tokens))
if sig is None:
continue
for band_idx, band_hash in enumerate(sig):
buckets[(band_idx, band_hash)].append(idx)
edges = []
for doc_indices in buckets.values():
if len(doc_indices) <= 1:
continue
# add edges between all pairs in the bucket
n = len(doc_indices)
for i in range(n):
for j in range(i + 1, n):
edges.append((doc_indices[i], doc_indices[j]))
return edges
def run_lsh(
records: List[Dict[str, Any]],
minhash: MinHash,
text_key: str = "text",
) -> List[Dict[str, Any]]:
"""Run LSH deduplication and return one representative per cluster.
For each connected component (cluster of near-duplicates), the first
document by index is kept.
"""
if len(records) <= 1:
return list(records)
edges = _build_edges(records, minhash, text_key=text_key)
logger.info(f"LSH: built {len(edges)} candidate edges from {len(records)} docs")
if not edges:
logger.info("LSH: no near-duplicate edges found, all documents kept")
return list(records)
uf = UnionFind()
for a, b in edges:
uf.union(a, b)
comps = uf.components()
kept_indices = set()
removed_count = 0
for root, members in comps.items():
# keep the first document (by original index) in each cluster
keeper = min(members)
kept_indices.add(keeper)
removed_count += len(members) - 1
# also keep documents with no edges (not in any component)
for idx in range(len(records)):
if idx not in uf._parent:
kept_indices.add(idx)
logger.info(f"LSH: kept {len(kept_indices)} docs, removed {removed_count} near-duplicates")
return [records[i] for i in sorted(kept_indices)]
"""MinHash banded signature computation for LSH near-deduplication.
Uses datasketch for MinHash computation.
"""
import hashlib
from typing import List, Optional, Tuple
from datasketch import MinHash as DatasketchMinHash
from .utils import form_ngrams
def optimal_param(
threshold: float,
num_perm: int,
false_positive_weight: float = 0.5,
false_negative_weight: float = 0.5,
) -> Tuple[int, int]:
"""Compute optimal (bands, rows) for a given similarity threshold.
Minimizes the weighted sum of false-positive and false-negative probabilities
via numerical integration over the similarity distribution.
"""
from scipy.integrate import quad as integrate
def _false_positive_probability(b, r):
def proba(s):
return 1 - (1 - s ** float(r)) ** float(b)
a, *_ = integrate(proba, 0.0, threshold)
return a
def _false_negative_probability(b, r):
def proba(s):
return 1 - (1 - (1 - s ** float(r)) ** float(b))
a, *_ = integrate(proba, threshold, 1.0)
return a
min_error = float("inf")
opt = (0, 0)
for b in range(1, num_perm + 1):
max_r = int(num_perm / b)
for r in range(1, max_r + 1):
fp = _false_positive_probability(b, r)
fn = _false_negative_probability(b, r)
err = fp * false_positive_weight + fn * false_negative_weight
if err < min_error:
min_error = err
opt = (b, r)
return opt
class MinHash:
"""Computes banded MinHash signatures for LSH near-deduplication."""
def __init__(
self,
ngram_size: int = 5,
num_permutations: int = 128,
similarity_threshold: float = 0.8,
seed: int = 42,
):
self._ngram_size = ngram_size
self._num_permutations = num_permutations
self._similarity_threshold = similarity_threshold
self._seed = seed
self._bands, self._rows = optimal_param(similarity_threshold, num_permutations)
self._hashrange = [(i * self._rows, (i + 1) * self._rows) for i in range(self._bands)]
self._checksum = hashlib.sha256(
f"datasketch-{ngram_size}-{num_permutations}-{seed}".encode()
).hexdigest()
@property
def ngram_size(self):
return self._ngram_size
@property
def num_permutations(self):
return self._num_permutations
@property
def similarity_threshold(self):
return self._similarity_threshold
@property
def checksum(self):
return self._checksum
@property
def num_bands(self):
return self._bands
def compute_banded_signatures(self, tokens: Tuple[str, ...]) -> Optional[List[bytes]]:
"""Compute banded minhash signatures. Returns None if too few tokens."""
if len(tokens) < self._ngram_size:
return None
m = DatasketchMinHash(num_perm=self._num_permutations, seed=self._seed)
for ngram in form_ngrams(iter(tokens), self._ngram_size):
m.update(" ".join(ngram).encode("utf-8"))
hashvalues = m.hashvalues
return [
bytes(hashvalues[start:end].data)
for start, end in self._hashrange
]
"""Utility functions: text normalization, n-grams, JSONL I/O, hashing."""
import gzip
import hashlib
import json
import re
import string
import struct
import unicodedata
from dataclasses import dataclass
from typing import Iterator, Tuple, Generator, Any, Dict
__all__ = [
"normalize",
"form_ngrams",
"TextSlice",
"read_jsonl",
"write_jsonl",
"sha1_hash32",
]
_TRANSLATION_TABLE_PUNCTUATION = str.maketrans("", "", string.punctuation)
def normalize(
text: str,
remove_punct: bool = True,
lowercase: bool = True,
nfd_unicode: bool = True,
white_space: bool = True,
) -> str:
"""Normalize text by lowercasing, removing punctuation, normalizing whitespace."""
if remove_punct:
text = text.translate(_TRANSLATION_TABLE_PUNCTUATION)
if lowercase:
text = text.lower()
if white_space:
text = text.strip()
text = re.sub(r"\s+", " ", text)
if nfd_unicode:
text = unicodedata.normalize("NFD", text)
return text
def form_ngrams(sequence, n):
"""Yield n-grams from an iterator as tuples."""
history = []
while n > 1:
try:
next_item = next(sequence)
except StopIteration:
return
history.append(next_item)
n -= 1
for item in sequence:
history.append(item)
yield tuple(history)
del history[0]
@dataclass
class TextSlice:
text: str
start: int
end: int
def __len__(self):
return len(self.text)
def read_jsonl(filepath: str) -> Generator[Dict[str, Any], None, None]:
"""Read a .jsonl or .jsonl.gz file, yielding parsed dicts line by line."""
open_fn = gzip.open if filepath.endswith(".gz") else open
with open_fn(filepath, "rt", encoding="utf-8") as f:
for line in f:
line = line.strip()
if line:
yield json.loads(line)
def write_jsonl(filepath: str, records) -> None:
"""Write an iterable of dicts to a .jsonl or .jsonl.gz file."""
open_fn = gzip.open if filepath.endswith(".gz") else open
with open_fn(filepath, "wt", encoding="utf-8") as f:
for record in records:
f.write(json.dumps(record, ensure_ascii=False) + "\n")
def sha1_hash32(data: bytes) -> int:
"""Return a 32-bit integer hash from the first 4 bytes of SHA1."""
return struct.unpack("<I", hashlib.sha1(data).digest()[:4])[0]
[project]
name = "data_tool"
version = "0.1.0"
description = "Quality-signal filtering and MinHash LSH near-deduplication for text datasets."
readme = "README.md"
requires-python = ">=3.9"
dependencies = [
"numpy>=1.24.0",
"scipy>=1.10.0",
]
[project.optional-dependencies]
dev = [
"pytest>=7.0",
]
[project.scripts]
data_tool = "data_tool.main:main"
[build-system]
requires = ["setuptools>=61.0"]
build-backend = "setuptools.build_meta"
[tool.setuptools.packages.find]
where = ["."]
[tool.pytest.ini_options]
testpaths = ["tests"]
"""Shared test fixtures."""
import json
import tempfile
from pathlib import Path
from typing import Any, Dict, List
import pytest
from data_tool.document import Document
@pytest.fixture
def good_text() -> str:
return (
"The quick brown fox jumps over the lazy dog. "
"This is a proper sentence with enough words to pass all quality filters. "
"It contains meaningful English text that should be kept in the dataset for training language models. "
"Machine learning is a field of inquiry devoted to understanding and building methods that learn from data. "
"These methods have been proven to be extremely valuable in a wide variety of applications including "
"natural language processing, computer vision, and speech recognition."
)
@pytest.fixture
def lorem_text() -> str:
return "lorem ipsum dolor sit amet lorem ipsum consectetur adipiscing elit"
@pytest.fixture
def code_text() -> str:
return "function foo() { var x = 1; if (x > 0) { return x; } } " * 5
@pytest.fixture
def empty_text() -> str:
return ""
@pytest.fixture
def good_doc(good_text: str) -> Document:
return Document(good_text)
@pytest.fixture
def lorem_doc(lorem_text: str) -> Document:
return Document(lorem_text)
@pytest.fixture
def code_doc(code_text: str) -> Document:
return Document(code_text)
@pytest.fixture
def empty_doc() -> Document:
return Document("")
@pytest.fixture
def tmp_jsonl() -> str:
"""Create a temporary JSONL file path that persists for the test."""
with tempfile.NamedTemporaryFile(mode="w", suffix=".jsonl", delete=False) as f:
pass
path = f.name
yield path
Path(path).unlink(missing_ok=True)
@pytest.fixture
def signal_funcs():
from data_tool.quality_signals import build_signal_functions
return build_signal_functions(language="en")
@pytest.fixture
def near_duplicate_docs() -> List[Dict[str, Any]]:
base = (
"The cat sat on the mat and watched the birds fly by the window. "
"It was a sunny day and the garden was full of flowers blooming in the spring sunshine. "
"The gentle breeze carried the scent of jasmine through the open window as the afternoon wore on. "
)
return [
{"file": "doc_a.txt", "text": base + "The cat purred softly while watching the birds."},
{"file": "doc_b.txt", "text": base + "The cat purred softly while gazing at the sparrows."},
{"file": "doc_c.txt", "text": base + "A gentle hum came from the sleeping feline."},
{"file": "unique.txt", "text": "Completely different text about quantum mechanics and particle physics in the early universe. " * 5},
]
"""End-to-end CLI tests."""
import json
import os
import subprocess
import sys
import tempfile
from pathlib import Path
import pytest
from data_tool.main import main
GOOD_TEXT = (
"The quick brown fox jumps over the lazy dog. "
"This is a proper sentence with enough words to pass all quality filters. "
"It contains meaningful English text that should be kept in the dataset. "
"Machine learning is a field of inquiry devoted to understanding data. "
"Natural language processing has made tremendous progress in recent years. "
"Researchers continue to push the boundaries of what is possible with AI. "
)
def _write_jsonl(path, records):
with open(path, "w") as f:
for r in records:
f.write(json.dumps(r) + "\n")
def _read_jsonl(path):
with open(path) as f:
return [json.loads(line) for line in f if line.strip()]
class TestCLISignals:
def test_basic_signals(self):
with tempfile.TemporaryDirectory() as tmp:
inp = os.path.join(tmp, "in.jsonl")
out = os.path.join(tmp, "out.jsonl")
_write_jsonl(inp, [
{"file": "doc1.txt", "text": GOOD_TEXT},
])
ret = main(["signals", "--input", inp, "--output", out])
assert ret == 0
assert os.path.exists(out)
records = _read_jsonl(out)
assert len(records) == 1
assert "_signals" in records[0]
assert "_filter_pass" in records[0]
def test_signals_filtering(self):
with tempfile.TemporaryDirectory() as tmp:
inp = os.path.join(tmp, "in.jsonl")
out = os.path.join(tmp, "out.jsonl")
_write_jsonl(inp, [
{"file": "good.txt", "text": GOOD_TEXT},
{"file": "bad.txt", "text": "lorem ipsum lorem ipsum lorem ipsum"},
])
ret = main(["signals", "--input", inp, "--output", out])
assert ret == 0
records = _read_jsonl(out)
assert len(records) == 2
# First (good) should pass, second (lorem) should fail
assert records[0]["_filter_pass"] is True
assert records[1]["_filter_pass"] is False
def test_max_docs(self):
with tempfile.TemporaryDirectory() as tmp:
inp = os.path.join(tmp, "in.jsonl")
out = os.path.join(tmp, "out.jsonl")
_write_jsonl(inp, [
{"file": f"doc{i}.txt", "text": GOOD_TEXT}
for i in range(10)
])
ret = main(["signals", "--input", inp, "--output", out, "--max-docs", "3"])
assert ret == 0
records = _read_jsonl(out)
assert len(records) == 3
def test_missing_input(self):
with tempfile.TemporaryDirectory() as tmp:
out = os.path.join(tmp, "out.jsonl")
ret = main(["signals", "--input", "/nonexistent/file.jsonl", "--output", out])
assert ret != 0
def test_code_rules_file(self):
with tempfile.TemporaryDirectory() as tmp:
inp = os.path.join(tmp, "in.jsonl")
out = os.path.join(tmp, "out.jsonl")
rules_file = os.path.join(tmp, "rules.json")
# Custom rule: .cs maps to C# keywords so this file passes code keyword check
with open(rules_file, "w") as f:
json.dump({".cs": ["public", "class", "void"]}, f)
_write_jsonl(inp, [
{"file": "Foo.cs", "text": "public class Foo { public void Bar() { } } " + GOOD_TEXT},
])
ret = main(["signals", "--input", inp, "--output", out, "--code-rules-file", rules_file])
assert ret == 0
records = _read_jsonl(out)
assert len(records) == 1
def test_code_rules_file_suffix_skip(self):
with tempfile.TemporaryDirectory() as tmp:
inp = os.path.join(tmp, "in.jsonl")
out = os.path.join(tmp, "out.jsonl")
rules_file = os.path.join(tmp, "rules.json")
# Map .txt to null = skip code keyword density
with open(rules_file, "w") as f:
json.dump({".txt": None}, f)
_write_jsonl(inp, [
{"file": "doc.txt", "text": GOOD_TEXT},
])
ret = main(["signals", "--input", inp, "--output", out, "--code-rules-file", rules_file])
assert ret == 0
class TestCLIDedup:
def test_basic_dedup(self):
with tempfile.TemporaryDirectory() as tmp:
inp = os.path.join(tmp, "in.jsonl")
out = os.path.join(tmp, "out.jsonl")
_write_jsonl(inp, [
{"file": "a.txt", "text": GOOD_TEXT},
{"file": "b.txt", "text": GOOD_TEXT + " slightly different ending here yes"},
{"file": "c.txt", "text": "Completely unrelated text about quantum physics and the nature of the universe in the early moments after the big bang."},
])
ret = main(["dedup", "--input", inp, "--output", out])
assert ret == 0
records = _read_jsonl(out)
assert len(records) <= 3
assert len(records) >= 1
def test_dedup_preserves_fields(self):
with tempfile.TemporaryDirectory() as tmp:
inp = os.path.join(tmp, "in.jsonl")
out = os.path.join(tmp, "out.jsonl")
_write_jsonl(inp, [
{"file": "a.txt", "text": GOOD_TEXT, "custom": "value"},
])
ret = main(["dedup", "--input", inp, "--output", out])
assert ret == 0
records = _read_jsonl(out)
assert records[0]["custom"] == "value"
class TestCLIPipeline:
def test_full_pipeline(self):
with tempfile.TemporaryDirectory() as tmp:
inp = os.path.join(tmp, "in.jsonl")
out = os.path.join(tmp, "out.jsonl")
_write_jsonl(inp, [
{"file": "good1.txt", "text": GOOD_TEXT},
{"file": "good2.txt", "text": GOOD_TEXT + " more text here"},
{"file": "good3.txt", "text": "Another completely different document about astronomy and space exploration. The James Webb Space Telescope has revolutionized our understanding of distant galaxies."},
{"file": "bad_lorem.txt", "text": "lorem ipsum dolor sit amet lorem ipsum"},
{"file": "bad_code.txt", "text": "function foo() { var x = 1; if (x > 0) { return x; } }"},
])
ret = main([
"pipeline", "--input", inp, "--output", out,
"--similarity", "0.8", "--num-perm", "64",
])
assert ret == 0
records = _read_jsonl(out)
assert len(records) >= 1 # at least one good doc kept
# No lorem or code docs should survive (they get filtered)
for r in records:
assert r["_filter_pass"] is True
def test_pipeline_save_filtered(self):
with tempfile.TemporaryDirectory() as tmp:
inp = os.path.join(tmp, "in.jsonl")
out = os.path.join(tmp, "out.jsonl")
_write_jsonl(inp, [
{"file": "doc1.txt", "text": GOOD_TEXT},
])
ret = main(["pipeline", "--input", inp, "--output", out, "--save-filtered"])
assert ret == 0
assert os.path.exists(out + ".filtered.jsonl")
class TestCLIHelp:
def test_no_args_shows_help(self):
ret = main([])
assert ret == 1
def test_signals_help(self):
with pytest.raises(SystemExit):
main(["signals", "--help"])
"""Tests for document.py."""
from data_tool.document import Document
class TestDocument:
def test_raw_content(self, good_doc, good_text):
assert good_doc.raw_content == good_text
def test_len(self, good_doc, good_text):
assert len(good_doc) == len(good_text)
def test_normalized_content(self, good_doc):
normalized = good_doc.normalized_content
assert normalized == normalized.lower()
assert "," not in normalized
def test_raw_words(self, good_doc):
words = good_doc.raw_words
assert isinstance(words, tuple)
assert len(words) > 0
def test_normalized_words(self, good_doc):
words = good_doc.normalized_words
assert isinstance(words, tuple)
assert len(words) > 0
for w in words:
assert w == w.lower()
def test_raw_lines(self, good_doc):
lines = good_doc.raw_lines
assert isinstance(lines, tuple)
assert len(lines) > 0
assert hasattr(lines[0], "text")
assert hasattr(lines[0], "start")
assert hasattr(lines[0], "end")
def test_num_raw_words(self, good_doc):
assert good_doc.num_raw_words == len(good_doc.raw_words)
def test_num_normalized_words(self, good_doc):
assert good_doc.num_normalized_words == len(good_doc.normalized_words)
def test_ngrams_precomputed(self, good_doc):
for n in range(2, 11):
ngrams = good_doc.get_ngrams(n)
assert isinstance(ngrams, tuple), f"ngram {n} should be tuple"
if len(good_doc.normalized_words) >= n:
assert len(ngrams) > 0, f"ngram {n} should not be empty"
def test_empty_document(self, empty_doc):
assert len(empty_doc) == 0
assert empty_doc.num_raw_words == 0
assert empty_doc.num_normalized_words == 0
def test_get_ngrams_consistency(self, good_doc):
"""N-grams should be the same whether accessed directly or via get_ngrams."""
for n in range(2, 11):
direct = getattr(good_doc, f"norm_{n}grams")
via_method = good_doc.get_ngrams(n)
assert direct is via_method
def test_source_file_none_by_default(self, good_doc):
assert good_doc.source_file is None
def test_source_file_stored(self):
from data_tool.document import Document
doc = Document("hello", source_file="foo.py")
assert doc.source_file == "foo.py"
def test_source_file_none_explicit(self):
from data_tool.document import Document
doc = Document("hello", source_file=None)
assert doc.source_file is None
"""Tests for filter.py."""
import json
import statistics
from data_tool.filter import check_document, load_thresholds, DEFAULT_THRESHOLDS
from data_tool.quality_signals import build_signal_functions, compute_all_signals
class TestCheckDocument:
def test_good_doc_passes(self, good_doc, signal_funcs):
signals = compute_all_signals(good_doc, signal_funcs)
keep, reasons = check_document(signals)
assert keep, f"Good doc should pass, got: {reasons}"
def test_lorem_fails(self, lorem_doc, signal_funcs):
signals = compute_all_signals(lorem_doc, signal_funcs)
keep, reasons = check_document(signals)
assert not keep
def test_code_fails_curly_bracket(self, code_doc, signal_funcs):
signals = compute_all_signals(code_doc, signal_funcs)
keep, reasons = check_document(signals)
# Code text has high curly bracket ratio AND high repetition
assert not keep
def test_empty_doc_fails_word_count(self, empty_doc, signal_funcs):
signals = compute_all_signals(empty_doc, signal_funcs)
keep, reasons = check_document(signals)
assert not keep
def test_none_signal_skipped(self, signal_funcs):
"""Signals that return None should be skipped, not cause failure."""
signals = {"rps_doc_ldnoobw_words": None, "rps_doc_word_count": 100.0}
keep, _ = check_document(signals)
assert keep
def test_min_threshold(self):
signals = {"rps_doc_word_count": 30.0}
thresholds = {"rps_doc_word_count": {"min": 50}}
keep, reasons = check_document(signals, thresholds)
assert not keep
assert "rps_doc_word_count" in reasons
def test_max_threshold(self):
signals = {"rps_doc_curly_bracket": 0.5}
thresholds = {"rps_doc_curly_bracket": {"max": 0.1}}
keep, reasons = check_document(signals, thresholds)
assert not keep
def test_both_min_and_max(self):
signals = {"rps_doc_mean_word_length": 15.0}
thresholds = {"rps_doc_mean_word_length": {"min": 2.0, "max": 12.0}}
keep, reasons = check_document(signals, thresholds)
assert not keep
def test_custom_thresholds(self):
"""Custom threshold overrides should be respected."""
signals = {"rps_doc_word_count": 30.0}
thresholds = {"rps_doc_word_count": {"min": 30}} # exactly at boundary
keep, _ = check_document(signals, thresholds)
assert keep # >= means it passes at 30
def test_line_signal_aggregation_mean(self):
"""Line signals are aggregated via mean before threshold check."""
signals = {"rps_lines_num_words": [2.0, 4.0, 6.0]} # mean = 4
thresholds = {"rps_lines_num_words": {"min": 5}}
keep, _ = check_document(signals, thresholds, line_agg="mean")
assert not keep # mean is 4, below min 5
def test_line_signal_aggregation_max(self):
signals = {"rps_lines_num_words": [2.0, 10.0]}
thresholds = {"rps_lines_num_words": {"min": 5}}
keep, _ = check_document(signals, thresholds, line_agg="max")
assert keep # max is 10, above min 5
def test_no_threshold_for_signal(self):
"""Signals without thresholds are ignored."""
signals = {"some_unknown_signal": 999.0}
keep, _ = check_document(signals)
assert keep
class TestDefaultThresholds:
def test_all_defaults_have_min_or_max(self):
for name, rules in DEFAULT_THRESHOLDS.items():
assert "min" in rules or "max" in rules, f"{name} missing min/max"
assert len(rules) >= 1, f"{name} has no rules"
class TestLoadThresholds:
def test_load_and_merge(self, tmp_jsonl):
custom = {"rps_doc_word_count": {"min": 100, "max": 10000}}
with open(tmp_jsonl, "w") as f:
json.dump(custom, f)
merged = load_thresholds(tmp_jsonl)
assert merged["rps_doc_word_count"] == {"min": 100, "max": 10000}
# Defaults still there
assert "rps_doc_curly_bracket" in merged
"""Tests for header_cleaner.py."""
from data_tool.header_cleaner import clean_headers, is_mostly_boilerplate, _boilerplate_score
class TestBoilerplateScore:
def test_no_match(self):
severity, category = _boilerplate_score("This is normal text content.")
assert severity == 0
assert category == ""
def test_cookie_match(self):
severity, category = _boilerplate_score("This site uses cookies to improve your experience.")
assert severity == 3
assert category == "cookies"
def test_gdpr_match(self):
severity, category = _boilerplate_score("We comply with GDPR regulations.")
assert severity == 3
def test_copyright_match(self):
severity, category = _boilerplate_score("Copyright 2024 All Rights Reserved.")
assert severity == 3
def test_navigation_match(self):
severity, category = _boilerplate_score("Main navigation menu here.")
assert severity == 2
def test_short_line_skipped(self):
severity, category = _boilerplate_score("ok")
assert severity == 0
def test_privacy_policy_match(self):
severity, category = _boilerplate_score("View our privacy policy here.")
assert severity == 3
def test_share_match(self):
severity, category = _boilerplate_score("Share on Facebook and Twitter.")
assert severity == 3
class TestCleanHeaders:
def test_empty_string(self):
assert clean_headers("") == ""
def test_short_text_preserved(self):
assert clean_headers("Hello world") == "Hello world"
def test_normal_text_preserved(self):
text = (
"The quick brown fox jumps over the lazy dog.\n"
"This is a proper sentence with enough words.\n"
"Machine learning is a field of inquiry.\n"
)
cleaned = clean_headers(text)
assert "quick brown fox" in cleaned
assert "Machine learning" in cleaned
def test_cookie_header_removed(self):
text = (
"This site uses cookies to improve your experience.\n"
"Accept all cookies\n"
"\n"
"The quick brown fox jumps over the lazy dog.\n"
"This is the actual article content that should be preserved.\n"
"Machine learning is a field of inquiry devoted to understanding data.\n"
"Natural language processing has made tremendous progress in recent years.\n"
"Researchers continue to push the boundaries of what is possible with AI.\n"
"Deep neural networks have revolutionized computer vision and speech recognition.\n"
"The transformer architecture has become the foundation of modern language models.\n"
"Attention mechanisms allow models to focus on relevant parts of the input.\n"
"Transfer learning has enabled rapid progress across many domains and tasks.\n"
)
cleaned = clean_headers(text)
assert "cookies" not in cleaned.lower()
assert "quick brown fox" in cleaned
def test_footer_removed(self):
text = (
"The quick brown fox jumps over the lazy dog.\n"
"This is the actual article content that should be preserved.\n"
"Machine learning is a field of inquiry devoted to understanding data.\n"
"Natural language processing has made tremendous progress in recent years.\n"
"Researchers continue to push the boundaries of what is possible with AI.\n"
"Deep neural networks have revolutionized computer vision and speech recognition.\n"
"The transformer architecture has become the foundation of modern language models.\n"
"Copyright 2024 All Rights Reserved.\n"
"Contact us | Privacy policy | Terms of service\n"
)
cleaned = clean_headers(text)
assert "Copyright" not in cleaned
assert "Contact us" not in cleaned
assert "quick brown fox" in cleaned
def test_social_share_removed(self):
text = (
"Share on Facebook Twitter LinkedIn\n"
"Follow us on Instagram\n"
"The quick brown fox jumps over the lazy dog.\n"
"This is meaningful content here yes indeed.\n"
"Machine learning is a field of inquiry devoted to understanding data.\n"
"Natural language processing has made tremendous progress in recent years.\n"
"Researchers continue to push the boundaries of what is possible with AI.\n"
"Deep neural networks have revolutionized computer vision and speech recognition.\n"
)
cleaned = clean_headers(text)
assert "Share on Facebook" not in cleaned
assert "quick brown fox" in cleaned
def test_subscribe_newsletter_removed(self):
text = (
"Subscribe to our newsletter for updates!\n"
"The quick brown fox jumps over the lazy dog.\n"
"This is the real content that we want to keep in the dataset.\n"
)
cleaned = clean_headers(text)
assert "newsletter" not in cleaned.lower()
assert "quick brown fox" in cleaned
def test_navigation_breadcrumb_removed(self):
text = (
"Main navigation breadcrumb\n"
"Home > Articles > Science > Today\n"
"The quick brown fox jumps over the lazy dog.\n"
"Here is some more substantial content that is long enough to trigger stop.\n"
)
cleaned = clean_headers(text)
assert "breadcrumb" not in cleaned.lower()
assert "quick brown fox" in cleaned
assert "substantial content" in cleaned
def test_advertisement_removed(self):
text = (
"Advertisement\n"
"Sponsored content brought to you by brand\n"
"\n"
"The quick brown fox jumps over the lazy dog.\n"
"Actual article text with enough length to be substantial and valuable.\n"
"Machine learning is a field of inquiry devoted to understanding data.\n"
"Natural language processing has made tremendous progress in recent years.\n"
"Researchers continue to push the boundaries of what is possible with AI.\n"
"Deep neural networks have revolutionized computer vision and speech recognition.\n"
"The transformer architecture has become the foundation of modern language models.\n"
)
cleaned = clean_headers(text)
assert "Advertisement" not in cleaned
assert "Sponsored" not in cleaned
assert "quick brown fox" in cleaned
def test_safety_valve_too_much_stripped(self):
"""When >30% would be stripped, return original text unchanged."""
bad_lines = [
"Share on Facebook",
"Share on Twitter",
"Subscribe to our newsletter",
"Copyright 2024 All Rights Reserved",
"Contact us",
]
text = "\n".join(bad_lines)
result = clean_headers(text)
# Safety valve: all lines are boilerplate so >30% -> return original
assert result == text
def test_mixed_content_partial_cleaning(self):
text = (
"Skip to main content\n"
"This site uses cookies\n"
"Accept all cookies\n"
"\n"
"The quick brown fox jumps over the lazy dog.\n"
"This is meaningful content with substantial length.\n"
"Another paragraph of real text here that is quite long.\n"
"Machine learning is a field of inquiry devoted to understanding data.\n"
"Natural language processing has made tremendous progress in recent years.\n"
"Researchers continue to push the boundaries of what is possible with AI.\n"
"Deep neural networks have revolutionized computer vision and speech recognition.\n"
"The transformer architecture has become the foundation of modern language models.\n"
"Attention mechanisms allow models to focus on relevant parts of the input.\n"
"Transfer learning has enabled rapid progress across many domains and tasks.\n"
"\n"
"Copyright 2024\n"
"Contact us | About us\n"
)
cleaned = clean_headers(text)
assert "cookies" not in cleaned.lower()
assert "Copyright" not in cleaned
assert "Contact us" not in cleaned
assert "quick brown fox" in cleaned
assert "meaningful content" in cleaned
def test_code_like_content_not_stripped(self):
text = (
"def main():\n"
" return 0\n"
"class Foo:\n"
" pass\n"
)
cleaned = clean_headers(text)
assert "def main" in cleaned
assert "class Foo" in cleaned
class TestIsMostlyBoilerplate:
def test_empty_text(self):
assert is_mostly_boilerplate("") is True
def test_short_text(self):
assert is_mostly_boilerplate("Hello\nWorld") is False
def test_mostly_content(self):
text = (
"The quick brown fox jumps over the lazy dog.\n"
"This is proper content with enough text.\n"
)
assert is_mostly_boilerplate(text) is False
def test_mostly_boilerplate(self):
text = (
"Skip to main content\n"
"This site uses cookies\n"
"Accept all cookies\n"
"Share on Facebook\n"
"Subscribe to newsletter\n"
"Copyright 2024\n"
"Contact us\n"
)
assert is_mostly_boilerplate(text) is True
"""Tests for lsh.py."""
import pytest
from data_tool.lsh import UnionFind, run_lsh
from data_tool.minhash import MinHash
class TestUnionFind:
def test_add_and_find(self):
uf = UnionFind()
uf.add(1)
assert uf.find(1) == 1
def test_union(self):
uf = UnionFind()
uf.union(1, 2)
assert uf.find(1) == uf.find(2)
def test_transitive_union(self):
uf = UnionFind()
uf.union(1, 2)
uf.union(2, 3)
assert uf.find(1) == uf.find(3)
def test_components(self):
uf = UnionFind()
uf.union(1, 2)
uf.union(3, 4)
uf.add(5)
comps = uf.components()
assert len(comps) == 3
sizes = sorted(len(v) for v in comps.values())
assert sizes == [1, 2, 2]
def test_single_element_component(self):
uf = UnionFind()
uf.add(42)
comps = uf.components()
assert len(comps) == 1
assert 42 in list(comps.values())[0]
def test_self_union_noop(self):
uf = UnionFind()
uf.add(1)
root_before = uf.find(1)
uf.union(1, 1)
assert uf.find(1) == root_before
def test_large_set(self):
uf = UnionFind()
for i in range(1000):
uf.union(i, i + 1)
assert uf.find(0) == uf.find(1000)
comps = uf.components()
assert len(comps) == 1
class TestRunLSH:
@pytest.fixture
def minhash(self):
return MinHash(ngram_size=3, num_permutations=64, similarity_threshold=0.8, seed=42)
def test_single_document(self, minhash):
docs = [{"file": "a.txt", "text": "hello world"}]
result = run_lsh(docs, minhash)
assert len(result) == 1
assert result[0]["file"] == "a.txt"
def test_empty_list(self, minhash):
result = run_lsh([], minhash)
assert result == []
def test_near_duplicates_removed(self, minhash, near_duplicate_docs):
result = run_lsh(near_duplicate_docs, minhash)
# 4 near-duplicate docs, 3 very similar + 1 unique
# At 0.8 similarity, the 3 similar docs should be clustered
assert len(result) < len(near_duplicate_docs)
def test_unique_doc_kept(self, minhash, near_duplicate_docs):
result = run_lsh(near_duplicate_docs, minhash)
kept_files = {r["file"] for r in result}
assert "unique.txt" in kept_files, "unique doc must always be kept"
def test_output_preserves_fields(self, minhash):
docs = [{"file": "a.txt", "text": "hello world", "extra_field": 42}]
result = run_lsh(docs, minhash)
assert result[0]["extra_field"] == 42
def test_too_short_docs_not_crashing(self, minhash):
"""Documents shorter than ngram_size should be handled gracefully."""
docs = [
{"file": "short1.txt", "text": "hi"},
{"file": "short2.txt", "text": "hi"}, # same content but short
{"file": "long.txt", "text": "this is a longer document that will actually get a signature because it has enough words"},
]
result = run_lsh(docs, minhash)
# Short docs with identical but too-few-token content: signatures are None,
# so they don't generate edges and are both kept
assert len(result) <= 3
"""Tests for minhash.py."""
import pytest
from data_tool.minhash import MinHash
from data_tool.utils import sha1_hash32
class TestMinHash:
@pytest.fixture
def minhash(self):
return MinHash(ngram_size=3, num_permutations=64, similarity_threshold=0.8, seed=42)
def test_initialization(self, minhash):
assert minhash.ngram_size == 3
assert minhash.num_permutations == 64
assert minhash.similarity_threshold == 0.8
def test_checksum_stable(self, minhash):
cs1 = minhash.checksum
cs2 = minhash.checksum
assert cs1 == cs2
assert len(cs1) == 64 # SHA256 hex is 64 chars
def test_different_seed_different_checksum(self):
mh1 = MinHash(seed=42)
mh2 = MinHash(seed=43)
assert mh1.checksum != mh2.checksum
def test_banded_signatures_structure(self, minhash):
tokens = tuple("the quick brown fox jumps over the lazy dog".split())
sig = minhash.compute_banded_signatures(tokens)
assert isinstance(sig, list)
assert all(isinstance(b, bytes) for b in sig), "all bands should be bytes"
def test_banded_signatures_length(self, minhash):
tokens = tuple("the quick brown fox jumps over the lazy dog".split())
sig = minhash.compute_banded_signatures(tokens)
assert len(sig) == minhash.num_bands
def test_too_few_tokens(self, minhash):
tokens = tuple("hi there".split())
sig = minhash.compute_banded_signatures(tokens)
assert sig is None
def test_deterministic(self, minhash):
tokens = tuple("the quick brown fox jumps over the lazy dog".split())
sig1 = minhash.compute_banded_signatures(tokens)
sig2 = minhash.compute_banded_signatures(tokens)
assert sig1 == sig2
def test_same_text_same_signature(self, minhash):
tokens1 = tuple("the quick brown fox jumps".split())
tokens2 = tuple("the quick brown fox jumps".split())
assert minhash.compute_banded_signatures(tokens1) == minhash.compute_banded_signatures(tokens2)
def test_similar_text_shares_bands(self, minhash):
"""Near-identical texts should share at least some band hashes."""
base = tuple("the cat sat on the mat in the garden by the tree".split())
similar = tuple("the cat sat on the mat in the garden by the fence".split())
sig1 = minhash.compute_banded_signatures(base)
sig2 = minhash.compute_banded_signatures(similar)
shared = sum(1 for b1, b2 in zip(sig1, sig2) if b1 == b2)
assert shared > 0, "near-identical texts should share some bands"
def test_dissimilar_text_few_shared_bands(self, minhash):
"""Very different texts should share few or no band hashes."""
tokens1 = tuple("the quick brown fox jumps over the lazy dog".split())
tokens2 = tuple("quantum mechanics describes behavior of subatomic particles".split())
sig1 = minhash.compute_banded_signatures(tokens1)
sig2 = minhash.compute_banded_signatures(tokens2)
shared = sum(1 for b1, b2 in zip(sig1, sig2) if b1 == b2)
# At high similarity thresholds, very different texts share few bands
assert shared < len(sig1) * 0.5
def test_different_ngram_size(self):
mh3 = MinHash(ngram_size=3)
mh5 = MinHash(ngram_size=5)
tokens = tuple("a b c d e f g h i j k".split())
sig3 = mh3.compute_banded_signatures(tokens)
sig5 = mh5.compute_banded_signatures(tokens)
assert sig3 != sig5
def test_different_similarity_threshold(self):
mh_low = MinHash(similarity_threshold=0.5, num_permutations=64)
mh_high = MinHash(similarity_threshold=0.9, num_permutations=64)
# Different thresholds should produce different band counts
assert mh_low.num_bands != mh_high.num_bands
class TestSHA1Hash32:
"""Re-test from minhash's perspective."""
def test_known_value(self):
h = sha1_hash32(b"test")
assert isinstance(h, int)
assert 0 <= h < 2**32
"""Tests for utils.py."""
from data_tool.utils import normalize, form_ngrams, sha1_hash32, read_jsonl, write_jsonl
from data_tool.utils import TextSlice
class TestNormalize:
def test_lowercase(self):
assert normalize("Hello WORLD") == "hello world"
def test_remove_punctuation(self):
result = normalize("hello, world!")
assert "!" not in result
assert "," not in result
def test_whitespace_collapse(self):
assert normalize("hello world") == "hello world"
def test_strip(self):
assert normalize(" hello world ") == "hello world"
def test_empty(self):
assert normalize("") == ""
class TestFormNgrams:
def test_bigrams(self):
words = iter(["the", "quick", "brown", "fox"])
result = list(form_ngrams(words, 2))
assert result == [("the", "quick"), ("quick", "brown"), ("brown", "fox")]
def test_trigrams(self):
words = iter(["a", "b", "c", "d"])
result = list(form_ngrams(words, 3))
assert result == [("a", "b", "c"), ("b", "c", "d")]
def test_too_few_items(self):
words = iter(["a", "b"])
result = list(form_ngrams(words, 3))
assert result == []
def test_empty(self):
words = iter([])
result = list(form_ngrams(words, 2))
assert result == []
class TestTextSlice:
def test_len(self):
ts = TextSlice("hello", 0, 5)
assert len(ts) == 5
def test_fields(self):
ts = TextSlice("hello", 10, 15)
assert ts.text == "hello"
assert ts.start == 10
assert ts.end == 15
class TestSHA1Hash32:
def test_deterministic(self):
assert sha1_hash32(b"hello") == sha1_hash32(b"hello")
def test_different_inputs(self):
assert sha1_hash32(b"hello") != sha1_hash32(b"world")
def test_returns_int(self):
assert isinstance(sha1_hash32(b"test"), int)
class TestJSONLIO:
def test_read_write_roundtrip(self, tmp_jsonl):
data = [
{"file": "a.txt", "text": "hello world"},
{"file": "b.txt", "text": "goodbye"},
]
write_jsonl(tmp_jsonl, data)
result = list(read_jsonl(tmp_jsonl))
assert result == data
def test_empty_write(self, tmp_jsonl):
write_jsonl(tmp_jsonl, [])
result = list(read_jsonl(tmp_jsonl))
assert result == []
This source diff could not be displayed because it is too large. You can view the blob instead.
Markdown is supported
0% or
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment