#!/usr/bin/env python3
from __future__ import annotations

import argparse
import heapq
import json
import math
import os
import re
from functools import lru_cache
from pathlib import Path
from typing import Any, Dict, List, Optional, Tuple

import numpy as np
import hnswlib

import chunk_note
import build_index
import llm_provider


def _resolve_api_key(*, api_key: Optional[str], api_key_env: Optional[str]) -> Optional[str]:
    """Pick the API key to use for the provider.

    Precedence: a non-blank ``api_key`` (returned stripped) wins. Otherwise the
    environment variable named by ``api_key_env`` is read. If that name is
    blank, ``OPENAI_API_KEY`` is read instead. Returns ``None`` when the chosen
    environment variable is not set.
    """
    explicit_key = (api_key or "").strip()
    if explicit_key:
        return explicit_key
    variable_name = (api_key_env or "").strip() or "OPENAI_API_KEY"
    return os.getenv(variable_name)


def read_jsonl(path: Path) -> List[Dict[str, Any]]:
    """Parse a JSON Lines file into a list of records, in file order.

    Lines that are empty after stripping whitespace are skipped. Every other
    line is decoded with ``json.loads``, and decoding errors propagate.
    """
    with path.open("r", encoding="utf-8") as handle:
        stripped_lines = (raw.strip() for raw in handle)
        return [json.loads(entry) for entry in stripped_lines if entry]


def _extract_href(value: Any) -> str:
    """Return a stripped URL taken from a plain string or a ``{"href": ...}`` dict.

    Anything else, including a dict whose ``href`` is not a string, yields ``""``.
    """
    candidate = value.get("href") if isinstance(value, dict) else value
    return candidate.strip() if isinstance(candidate, str) else ""


@lru_cache(maxsize=2048)
def _source_links_from_meta_for_source_path(source_path: str) -> Tuple[str, str]:
    """Look up ``(web_url, client_url)`` in the ``meta.json`` beside a ``page.txt``.

    This is a fallback for older indexes whose chunk metadata has no
    ``source_url``. Only paths whose file name is ``page.txt`` are considered.
    Each link is read from ``meta["links"]`` first and then from the top level
    of ``meta``. A missing file, unreadable or invalid JSON, or a non-object
    payload all give ``("", "")``.

    Results are memoised per path for the life of the process, so later edits
    to ``meta.json`` are not seen.
    """
    nothing = ("", "")
    if not source_path:
        return nothing

    page_file = Path(source_path)
    meta_file = page_file.with_name("meta.json")
    if page_file.name != "page.txt" or not meta_file.exists():
        return nothing

    try:
        payload = json.loads(meta_file.read_text(encoding="utf-8", errors="ignore"))
    except Exception:
        # Best effort: a broken meta.json simply means "no links".
        return nothing
    if not isinstance(payload, dict):
        return nothing

    link_section = payload.get("links")
    if not isinstance(link_section, dict):
        link_section = {}

    def pick(key: str) -> str:
        # Prefer the nested "links" entry, then the same key at top level.
        return _extract_href(link_section.get(key)) or _extract_href(payload.get(key))

    return pick("oneNoteWebUrl"), pick("oneNoteClientUrl")


def source_links_for_chunk(chunk: Dict[str, Any]) -> Tuple[str, str]:
    """Return ``(web_url, client_url)`` for a chunk; either may be ``""``.

    Links stored on the chunk itself are preferred: ``source_url`` or
    ``oneNoteWebUrl`` for web, and ``source_client_url`` or
    ``oneNoteClientUrl`` for client. Only when neither link is present is the
    ``meta.json`` next to a non-blank ``source_path`` consulted.
    """
    direct = (
        _extract_href(chunk.get("source_url")) or _extract_href(chunk.get("oneNoteWebUrl")),
        _extract_href(chunk.get("source_client_url")) or _extract_href(chunk.get("oneNoteClientUrl")),
    )
    if any(direct):
        return direct

    source_path = chunk.get("source_path")
    if not (isinstance(source_path, str) and source_path.strip()):
        return direct
    return _source_links_from_meta_for_source_path(source_path.strip())


def best_source_url_for_chunk(chunk: Dict[str, Any]) -> str:
    """Return the chunk's web URL, else its client URL, else ``""``."""
    links = source_links_for_chunk(chunk)
    return next((link for link in links if link), "")


def _clean_link_label(label: str) -> str:
    """Normalise a title for use as markdown link text.

    Whitespace runs collapse to a single space and square brackets are removed,
    because they would break ``[text](url)`` syntax. An empty result becomes
    "Kilde".
    """
    collapsed = re.sub(r"\s+", " ", (label or "").strip())
    without_brackets = collapsed.translate(str.maketrans("", "", "[]"))
    return without_brackets or "Kilde"


# Citation tags the chat model is asked to emit, e.g. "[KILDE 1]" or "[KILDE 1, 2]".
# Both patterns are case-insensitive because models often write "[Kilde 1]" or
# "kilde 1". The inline pattern refuses a match right after "[", so rendered
# links such as "[Kilde 1: Title](<url>)" are never rewritten a second time.
_CITATION_BRACKET_RE = re.compile(r"\[([^\]]*?kilde[^\]]*?)\]", re.IGNORECASE)
_CITATION_INLINE_RE = re.compile(r"(?<!\[)\bkilde\s*(\d+)\b", re.IGNORECASE)


def render_answer_with_citation_links(answer: str, hits: List[Dict[str, Any]]) -> str:
    """Turn citation tags in ``answer`` into markdown links to the cited pages.

    Citation number n refers to ``hits[n-1]["chunk"]``. Two passes are made:

    1. Bracketed tags matched by ``_CITATION_BRACKET_RE`` (e.g. "[KILDE 1, 2]")
       become one rendered reference per distinct number, joined by " og ".
       Brackets containing no digits are left untouched.
    2. Remaining bare tags matched by ``_CITATION_INLINE_RE`` (e.g. "kilde 3")
       are replaced one by one.

    A known number renders as ``[Kilde n: title](<url>)``, or as
    ``[Kilde n: title]`` when the chunk has no link. Unknown numbers fall back
    to plain text. An empty ``answer`` is returned unchanged.
    """
    if not answer:
        return answer

    # 1-based lookup table: citation number -> (label, url).
    table: Dict[int, Tuple[str, str]] = {}
    for number, hit in enumerate(hits, start=1):
        chunk = hit.get("chunk", {}) if isinstance(hit, dict) else {}
        if not isinstance(chunk, dict):
            chunk = {}
        label = _clean_link_label(str(chunk.get("title") or f"Kilde {number}"))
        table[number] = (label, best_source_url_for_chunk(chunk))

    def link_for(number: int, fallback: str) -> str:
        entry = table.get(number)
        if entry is None:
            return fallback
        label, url = entry
        text = f"[Kilde {number}: {label}]"
        return f"{text}(<{url}>)" if url else text

    def distinct_numbers(fragment: str) -> List[int]:
        # A dict keeps first-seen order while dropping repeats.
        numbers: Dict[int, None] = {}
        for digits in re.findall(r"\d+", fragment or ""):
            try:
                numbers.setdefault(int(digits), None)
            except Exception:
                continue
        return list(numbers)

    def expand_bracket(match: re.Match) -> str:
        numbers = distinct_numbers(match.group(1) or "")
        if not numbers:
            return match.group(0)
        return " og ".join(link_for(n, fallback=f"Kilde {n}") for n in numbers)

    def expand_inline(match: re.Match) -> str:
        try:
            number = int(match.group(1))
        except Exception:
            return match.group(0)
        return link_for(number, fallback=match.group(0))

    with_brackets = _CITATION_BRACKET_RE.sub(expand_bracket, answer)
    return _CITATION_INLINE_RE.sub(expand_inline, with_brackets)


# Danish function words that are ignored when tokenising questions for
# lexical scoring. Words shorter than three characters are also dropped by
# length in _tokenize_query, so the short entries here are redundant but harmless.
_STOPWORDS_DA = set(
    """
    alle altså at af blive bliver de dem den der deres det dig du eller en er
    et for fra får har have hvad hvem hvilke hvilken hvilket hvor hvordan
    hvorfor hvis i ikke ind kan man med mig min mine mit må når og om os på
    sig siger så til ud vi vores
    """.split()
)


def _tokenize_query(q: str) -> List[str]:
    """Split a question into distinct, lower-cased content tokens.

    Tokens are maximal runs of Unicode letters and digits. Accented or foreign
    letters (é, ü, ö, ...) therefore stay inside their word instead of
    splitting it: "Müller" yields "müller", not the fragment "ller". Tokens
    shorter than three characters, Danish stopwords (``_STOPWORDS_DA``) and
    repeats are dropped. First-occurrence order is preserved.
    """
    q = q.lower()
    # [\W_] means "not a letter or digit"; the underscore also counts as a separator.
    parts = re.split(r"[\W_]+", q)
    out: List[str] = []
    seen = set()
    for p in parts:
        if len(p) < 3 or p in _STOPWORDS_DA or p in seen:
            continue
        seen.add(p)
        out.append(p)
    return out


def _keyword_boost(question: str, chunk: Dict[str, Any]) -> float:
    """Return a small bonus for question tokens found in the chunk's title or headings.

    Each token adds 0.14 if it occurs as a substring of the lower-cased title,
    and 0.08 if it occurs in the lower-cased, space-joined heading path.
    Returns 0.0 when the question has no usable tokens.
    """
    tokens = _tokenize_query(question)
    if not tokens:
        return 0.0

    title_lc = (chunk.get("title") or "").lower()
    headings_lc = " ".join(chunk.get("heading_path") or []).lower()

    total = 0.0
    for token in tokens:
        total += 0.14 if token in title_lc else 0.0
        total += 0.08 if token in headings_lc else 0.0
    return total


def _make_chunk_haystack(chunk: Dict[str, Any]) -> str:
    """Return the lower-cased title, joined heading path and text, one per line."""
    parts = (
        (chunk.get("title") or "").lower(),
        " ".join(chunk.get("heading_path") or []).lower(),
        (chunk.get("text") or "").lower(),
    )
    return "\n".join(parts)


def _token_weights_for_query(question_tokens: List[str], meta: List[Dict[str, Any]]) -> Dict[str, float]:
    """Give each query token an IDF-style weight so rare tokens count for more.

    Document frequency (df) is the number of chunks whose haystack (lower-cased
    title, headings and text) contains the token as a substring. The weight is
    ``1 + ln((N + 1) / (1 + df))`` clamped to [0.8, 3.5], where N is the number
    of chunks (at least 1). Returns ``{}`` for an empty token list.
    """
    if not question_tokens:
        return {}

    doc_count = max(1, len(meta))
    frequency: Dict[str, int] = dict.fromkeys(question_tokens, 0)

    for chunk in meta:
        haystack = _make_chunk_haystack(chunk)
        for token in question_tokens:
            if token in haystack:
                frequency[token] += 1

    def clamp(value: float) -> float:
        # Keep weights in a compact band; common tokens drift toward ~1.0.
        return max(0.8, min(value, 3.5))

    return {
        token: clamp(1.0 + math.log((doc_count + 1.0) / (1.0 + float(frequency[token]))))
        for token in question_tokens
    }


def _query_phrases(question_tokens: List[str]) -> List[str]:
    """Return space-joined bigrams of consecutive tokens, skipping pairs with an empty token."""
    return [
        f"{left} {right}"
        for left, right in zip(question_tokens, question_tokens[1:])
        if left and right
    ]


def _lexical_score(
    question_tokens: List[str],
    chunk: Dict[str, Any],
    token_weights: Dict[str, float],
    phrases: List[str],
) -> float:
    """Score how well a chunk matches the query on exact (substring) terms.

    For each token with weight w (default 1.0), add:

    * 3.4w if it occurs in the title;
    * 1.7w if it occurs in the heading path;
    * 1.1w if it occurs anywhere in title, headings or text.

    Each non-empty bigram phrase then adds a flat bonus for its best location
    only: 4.2 in the title, else 2.8 in the headings, else 2.0 anywhere.

    All comparisons use lower-cased text. Returns 0.0 when there are no tokens.
    """
    if not question_tokens:
        return 0.0

    title = (chunk.get("title") or "").lower()
    headings = " ".join(chunk.get("heading_path") or []).lower()
    body = (chunk.get("text") or "").lower()
    everything = "\n".join((title, headings, body))

    score = 0.0
    for token in question_tokens:
        weight = token_weights.get(token, 1.0)
        for field, factor in ((title, 3.4), (headings, 1.7), (everything, 1.1)):
            if token in field:
                score += factor * weight

    # Exact phrase matches (e.g. "civil disobedience") should jump up the ranking.
    for phrase in filter(None, phrases):
        for field, bonus in ((title, 4.2), (headings, 2.8), (everything, 2.0)):
            if phrase in field:
                score += bonus
                break

    return score


def _top_lexical_candidates(question: str, meta: List[Dict[str, Any]], limit: int) -> Dict[int, float]:
    """Score every chunk lexically and keep the ``limit`` best.

    Returns ``{chunk_index: score}`` ordered best-first. Chunks scoring 0 or
    less are never included, so the result can be shorter than ``limit``. It
    is empty when the question yields no tokens or nothing matches.
    """
    tokens = _tokenize_query(question)
    if not tokens:
        return {}

    weights = _token_weights_for_query(tokens, meta)
    bigrams = _query_phrases(tokens)

    ranked: List[Tuple[float, int]] = []
    for position, chunk in enumerate(meta):
        value = _lexical_score(tokens, chunk, token_weights=weights, phrases=bigrams)
        if value <= 0.0:
            continue  # no lexical evidence for this chunk
        ranked.append((value, position))

    if not ranked:
        return {}

    best = heapq.nlargest(limit, ranked, key=lambda pair: pair[0])
    return {position: value for value, position in best}


def _normalize_dense_distance(distance: float) -> float:
    """Map a "lower is better" vector distance to a similarity in (0, 1].

    Uses ``1 / (1 + d)``. Distances that are not positive are treated as 0,
    which gives a similarity of 1.0.
    """
    clamped = distance if distance > 0.0 else 0.0
    return 1.0 / (1.0 + clamped)


def _combine_scores(dense_score: float, lexical_score: float) -> float:
    """Blend dense and lexical relevance into one ranking score.

    Dense similarity carries the larger weight (0.78). Lexical evidence adds
    0.32 of its value as a precision boost.

    NOTE(review): the two weights sum to 1.10, not 1.0. Confirm this is
    intended, since quality priors are added to this value afterwards.
    """
    dense_part = 0.78 * dense_score
    lexical_part = 0.32 * lexical_score
    return dense_part + lexical_part


def _is_prompt_like_chunk_text(text: str) -> bool:
    """Guess whether a chunk is an exercise prompt rather than real content.

    Returns True for either of:

    * a short text (< 280 chars after stripping) that contains a "?";
    * a text under 600 chars that contains one of a few known assignment phrases.

    Matching is case-insensitive. Empty text is never prompt-like.
    """
    normalized = (text or "").lower().strip()
    if not normalized:
        return False

    length = len(normalized)
    if length < 280 and "?" in normalized:
        return True

    markers = (
        "gruppearbejde",
        "hvad handler jeres tekststykke om",
        "konstruer et utopisk samfund",
    )
    return length < 600 and any(marker in normalized for marker in markers)


def _chunk_quality_score(chunk: Dict[str, Any]) -> float:
    """Return an additive prior that favours substantive chunks over empty prompt pages.

    Empty text scores -0.22 outright. Otherwise the score starts at 0 and is
    adjusted for:

    * length: +0.17, +0.12 or +0.07 at >= 1200, 700 or 350 characters, and
      -0.08 below 180;
    * attachments: +0.26 when an "[attachments]" block contains
      "extracted_text:", or -0.07 when it reports zero attachments with
      extracted text;
    * prompt-like text: -0.25;
    * a short chunk (< 260 chars) that is only an empty "[images]" summary: -0.06.
    """
    raw = (chunk.get("text") or "").strip()
    if not raw:
        return -0.22

    lowered = raw.lower()
    size = len(raw)
    score = 0.0

    # Length tiers: the first threshold reached wins.
    for threshold, bonus in ((1200, 0.17), (700, 0.12), (350, 0.07)):
        if size >= threshold:
            score += bonus
            break
    else:
        if size < 180:
            score -= 0.08

    # Attachment-derived text often contains the actual source content.
    if "[attachments]" in lowered:
        if "extracted_text:" in lowered:
            score += 0.26
        elif "attachments with extracted text: 0" in lowered:
            score -= 0.07

    if _is_prompt_like_chunk_text(raw):
        score -= 0.25

    # Empty image metadata blocks are rarely useful as primary evidence.
    empty_images = "[images]" in lowered and "found images: 0 | images with ocr text: 0" in lowered
    if empty_images and size < 260:
        score -= 0.06

    return score


def _source_key_for_chunk(chunk: Dict[str, Any]) -> str:
    """Identify the page a chunk came from, for per-source diversity capping.

    Prefers the stripped ``source_path``, then ``"title:<lower-cased title>"``,
    and finally ``"unknown"``.
    """
    path = (chunk.get("source_path") or "").strip()
    if path:
        return path
    title = (chunk.get("title") or "").strip()
    return f"title:{title.lower()}" if title else "unknown"


def _apply_source_diversity(hits: List[Dict[str, Any]], top_k: int, per_source_cap: int = 2) -> List[Dict[str, Any]]:
    """Select up to ``top_k`` hits while limiting how many come from one source.

    ``hits`` is expected to be sorted best-first. A first pass keeps at most
    ``per_source_cap`` hits per source key (from ``_source_key_for_chunk``).
    Hits over the cap, and hits whose ``chunk`` is not a dict, are set aside.
    They are used, in their original order, only to backfill when fewer than
    ``top_k`` hits were kept. ``per_source_cap <= 0`` disables the cap and
    simply truncates.

    ``top_k`` is treated as at least 1 on every path, so a non-empty ``hits``
    always yields at least one result.
    """
    limit = max(1, top_k)
    if per_source_cap <= 0:
        return hits[:limit]

    selected: List[Dict[str, Any]] = []
    overflow: List[Dict[str, Any]] = []
    counts: Dict[str, int] = {}

    for h in hits:
        if len(selected) >= limit:
            break
        c = h.get("chunk", {}) if isinstance(h, dict) else {}
        if not isinstance(c, dict):
            overflow.append(h)
            continue
        key = _source_key_for_chunk(c)
        used = counts.get(key, 0)
        if used < per_source_cap:
            selected.append(h)
            counts[key] = used + 1
        else:
            overflow.append(h)

    # Backfill with over-cap hits if the diversity cap left us short.
    for h in overflow:
        if len(selected) >= limit:
            break
        selected.append(h)

    return selected


def _rerank_hybrid(
    *,
    question: str,
    meta: List[Dict[str, Any]],
    dense_distances: Dict[int, float],
    lexical_scores: Dict[int, float],
    top_k: int,
) -> List[Dict[str, Any]]:
    """Fuse dense (vector) and lexical candidates into one ranked, source-diverse hit list.

    Every chunk index present in ``dense_distances`` or ``lexical_scores`` is
    scored as follows:

    * dense: ``_normalize_dense_distance`` of its distance, rescaled so the best
      dense candidate is 1.0 (0.0 if it was not a dense candidate);
    * lexical: its lexical score divided by the best lexical score, plus
      ``_keyword_boost`` for title/heading matches;
    * lexical-only candidates with a positive lexical signal get a dense
      stand-in of ``0.88 * lexical``, so strong exact-term matches are not buried;
    * final: ``_combine_scores(dense, lexical) + _chunk_quality_score(chunk)``.

    Each hit dict carries:

    * ``distance``: the raw dense distance, or 1.0 for lexical-only hits;
    * ``score`` and ``quality_score``;
    * ``retrieval_mode``: "dense", "lexical" or "dense+lexical";
    * ``chunk``: ``meta[idx]``.

    Hits are sorted by score, best first, and then passed through
    ``_apply_source_diversity`` with a cap of 2 per source.
    """
    candidate_ids = set(dense_distances.keys()) | set(lexical_scores.keys())
    if not candidate_ids:
        return []

    dense_norm: Dict[int, float] = {
        idx: _normalize_dense_distance(dist) for idx, dist in dense_distances.items()
    }
    best_dense = max([0.0, *dense_norm.values()])
    if best_dense > 0.0:
        dense_norm = {idx: value / best_dense for idx, value in dense_norm.items()}

    best_lex = max(lexical_scores.values()) if lexical_scores else 0.0

    scored_hits: List[Dict[str, Any]] = []
    for idx in candidate_ids:
        chunk = meta[idx]
        is_dense = idx in dense_distances

        lexical = (lexical_scores.get(idx, 0.0) / best_lex) if best_lex > 0.0 else 0.0
        lexical += _keyword_boost(question, chunk)

        dense = dense_norm.get(idx, 0.0)
        if not is_dense and lexical > 0.0:
            # Lexical rescue: dense search missed this chunk, so a strong
            # exact-term match stands in for the missing dense signal.
            dense = max(dense, 0.88 * lexical)

        quality = _chunk_quality_score(chunk)
        final = _combine_scores(dense, lexical) + quality

        if is_dense:
            distance = float(dense_distances[idx])
            mode = "dense+lexical" if idx in lexical_scores else "dense"
        else:
            # Placeholder so the UI can still format a float distance.
            distance, mode = 1.0, "lexical"

        scored_hits.append(
            {
                "distance": distance,
                "score": float(final),
                "quality_score": float(quality),
                "retrieval_mode": mode,
                "chunk": chunk,
            }
        )

    scored_hits.sort(key=lambda hit: hit.get("score", 0.0), reverse=True)
    return _apply_source_diversity(scored_hits, top_k=top_k, per_source_cap=2)


def chat_completion(
    prompt: str,
    *,
    model: str,
    provider: str,
    host: str,
    api_base: str = "",
    api_key: Optional[str] = None,
) -> str:
    """Send ``prompt`` to the configured chat model and return its reply.

    This is a thin delegation to ``llm_provider.chat_completion`` that always
    passes ``timeout=180``.
    """
    request = dict(
        provider=provider,
        model=model,
        host=host,
        api_base=api_base,
        api_key=api_key,
        timeout=180,
    )
    return llm_provider.chat_completion(prompt, **request)


def ollama_chat(prompt: str, model: str, host: str) -> str:
    """Backward-compatible shortcut for ``chat_completion`` with ``provider="ollama"``."""
    return chat_completion(prompt, provider="ollama", model=model, host=host)


def is_export_root_dir(path: Path) -> bool:
    """Return True if ``path`` is a directory that contains ``manifest.json`` and ``pages``."""
    if not path.is_dir():
        return False
    return all((path / name).exists() for name in ("manifest.json", "pages"))


def find_latest_export_root(search_dir: Path) -> Optional[Path]:
    """Return the newest ``export_*`` export root directly inside ``search_dir``.

    "Newest" means the most recent ``manifest.json`` modification time. When
    several roots share that time, the first one listed by ``iterdir`` wins.
    Returns None if ``search_dir`` is not an existing directory or contains no
    qualifying subdirectory.
    """
    if not search_dir.exists() or not search_dir.is_dir():
        return None

    exports = [
        entry
        for entry in search_dir.iterdir()
        if entry.is_dir() and entry.name.startswith("export_") and is_export_root_dir(entry)
    ]
    if not exports:
        return None

    return max(exports, key=lambda root: (root / "manifest.json").stat().st_mtime)


def _latest_export_inputs_mtime(export_root: Path) -> float:
    """Return the newest mtime among the files that chunking reads from an export.

    Considers ``manifest.json`` plus every ``page.txt`` and ``meta.json`` found
    recursively under ``pages/``. Files whose ``stat`` fails are skipped. Any
    other error (e.g. while walking the tree) stops the scan, and the newest
    mtime seen so far is returned. Returns 0.0 when nothing is found.
    """
    newest = 0.0
    try:
        manifest = export_root / "manifest.json"
        if manifest.exists():
            newest = max(newest, manifest.stat().st_mtime)

        pages_dir = export_root / "pages"
        if not pages_dir.exists():
            return newest
        for pattern in ("page.txt", "meta.json"):
            for candidate in pages_dir.rglob(pattern):
                try:
                    stamp = candidate.stat().st_mtime
                except Exception:
                    continue
                newest = max(newest, stamp)
    except Exception:
        # Best effort: report what we managed to see.
        return newest
    return newest


def ensure_chunks(export_root: Path) -> Path:
    """Return ``<export_root>/chunks.jsonl``, (re)building it with ``chunk_note`` when needed.

    An existing non-empty file is reused if its mtime is not older than the
    newest export input (see ``_latest_export_inputs_mtime``), or if that check
    itself fails. Otherwise every document from ``chunk_note.iter_documents``
    is chunked with ``chunk_note``'s default token settings, and the result is
    written over the file.
    """
    chunks_path = export_root / "chunks.jsonl"

    has_existing = chunks_path.exists() and chunks_path.stat().st_size > 0
    if has_existing:
        try:
            existing_mtime = chunks_path.stat().st_mtime
            inputs_mtime = _latest_export_inputs_mtime(export_root)
            if existing_mtime >= inputs_mtime:
                return chunks_path
            print("Eksisterende chunks.jsonl er ældre end pages/manifest — regenererer.")
        except Exception:
            # Freshness could not be determined: keep the existing file.
            return chunks_path

    print(f"Mangler chunks.jsonl — genererer: {chunks_path}")

    # Build chunks from the export root structure (pages/*/page.txt).
    encoding = chunk_note.tiktoken.get_encoding(chunk_note.DEFAULT_ENCODING)

    produced: List[chunk_note.Chunk] = []
    n_docs = 0
    for document in chunk_note.iter_documents(export_root):
        n_docs += 1
        produced += chunk_note.chunk_document(
            document,
            enc=encoding,
            target_tokens=chunk_note.DEFAULT_TARGET_TOKENS,
            overlap_tokens=chunk_note.DEFAULT_OVERLAP_TOKENS,
            min_chunk_tokens=chunk_note.DEFAULT_MIN_CHUNK_TOKENS,
        )

    chunk_note.write_jsonl(produced, chunks_path)
    print(f"Chunks klar. Docs={n_docs} Chunks={len(produced)}")
    return chunks_path


def ensure_index(
    export_root: Path,
    chunks_path: Path,
    *,
    embed_model: str,
    provider: str = "ollama",
    host: str = "http://localhost:11434",
    api_base: str = "",
    api_key: Optional[str] = None,
) -> Path:
    """Return ``<export_root>/index``, (re)building the vector index when it is stale.

    The existing index (``chunks_hnsw.bin``, ``chunks_meta.jsonl`` and
    ``index_config.json``) is reused only if all of the following hold:

    * all three files exist;
    * none of them is older than ``chunks_path``;
    * ``index_config.json`` records the same ``embed_model`` and ``provider``;
    * the endpoint matches: ``host`` for Ollama, ``api_base`` otherwise, with
      trailing slashes ignored.

    Any failure while validating also triggers a rebuild via
    ``build_index.build_index_from_chunks``.

    Raises:
        RuntimeError: if building the index fails. The original exception is
            chained as ``__cause__`` so its traceback is preserved.
    """
    index_dir = export_root / "index"
    index_path = index_dir / "chunks_hnsw.bin"
    meta_path = index_dir / "chunks_meta.jsonl"
    cfg_path = index_dir / "index_config.json"

    if index_path.exists() and meta_path.exists() and cfg_path.exists():
        try:
            idx_oldest = min(index_path.stat().st_mtime, meta_path.stat().st_mtime, cfg_path.stat().st_mtime)
            chunks_mtime = chunks_path.stat().st_mtime
            cfg = json.loads(cfg_path.read_text(encoding="utf-8"))
            cfg_model = str(cfg.get("embed_model") or "")
            cfg_provider = str(cfg.get("provider") or "ollama").strip().lower()
            cfg_host = str(cfg.get("host") or "").rstrip("/")
            cfg_base = str(cfg.get("api_base") or "").rstrip("/")
            req_provider = str(provider or "ollama").strip().lower()
            req_host = str(host or "").rstrip("/")
            req_base = str(api_base or "").rstrip("/")
            # Ollama is addressed by host; OpenAI-compatible providers by api_base.
            endpoint_ok = (cfg_host == req_host) if req_provider == "ollama" else (cfg_base == req_base)
            if idx_oldest >= chunks_mtime and cfg_model == embed_model and cfg_provider == req_provider and endpoint_ok:
                return index_dir
            print("Eksisterende index matcher ikke data/model/provider-endpoint — genbygger index.")
        except Exception:
            print("Kunne ikke validere eksisterende index — genbygger index.")

    print(f"Mangler index — bygger i: {index_dir}")
    try:
        build_index.build_index_from_chunks(
            chunks_path=chunks_path,
            out_dir=index_dir,
            embed_model=embed_model,
            provider=provider,
            host=host,
            api_base=api_base,
            api_key=api_key,
        )
    except Exception as e:
        # Chain explicitly so the underlying provider/build error stays inspectable.
        raise RuntimeError("Kunne ikke bygge embeddings/index.\n\nOriginal fejl: " + str(e)) from e
    return index_dir


def load_index(index_dir: Path) -> tuple[hnswlib.Index, List[Dict[str, Any]], Dict[str, Any]]:
    """Load the HNSW index, its chunk metadata and its config from ``index_dir``.

    Reads ``index_config.json`` (for the vector dimension ``dim``), then
    ``chunks_meta.jsonl``, then the cosine-space index in ``chunks_hnsw.bin``.
    Returns ``(index, meta, cfg)``.
    """
    config = json.loads((index_dir / "index_config.json").read_text(encoding="utf-8"))
    records = read_jsonl(index_dir / "chunks_meta.jsonl")

    ann_index = hnswlib.Index(space="cosine", dim=int(config["dim"]))
    ann_index.load_index(str(index_dir / "chunks_hnsw.bin"))
    return ann_index, records, config


def retrieve(
    index: hnswlib.Index,
    meta: List[Dict[str, Any]],
    question: str,
    embed_model: str,
    host: str,
    top_k: int,
    *,
    provider: str = "ollama",
    api_base: str = "",
    api_key: Optional[str] = None,
) -> List[Dict[str, Any]]:
    """Return the ``top_k`` best chunks for ``question`` using hybrid retrieval.

    The question is embedded with ``embed_model``. A broad dense candidate
    pool of ``max(20 * top_k, 200)`` neighbours is fetched from the HNSW index,
    capped by both the number of chunks in ``meta`` and the number of vectors
    stored in ``index``. Lexical scores are computed over every chunk, and
    both signals are merged by ``_rerank_hybrid``.

    Dense labels are used as indices into ``meta``. Labels outside
    ``range(len(meta))`` (an index/metadata mismatch) are ignored instead of
    crashing the rerank with an ``IndexError``. When there is nothing to
    search (no chunks or an empty index), the dense query is skipped.
    """
    qvec = llm_provider.embed_one(
        question,
        provider=provider,
        model=embed_model,
        host=host,
        api_base=api_base,
        api_key=api_key,
        timeout=120,
    )
    total = len(meta)
    # Keep the dense pool broad even when top_k is small, but never ask hnswlib
    # for more neighbours than it stores: knn_query errors out when it cannot
    # return k results.
    dense_k = min(total, int(index.get_current_count()), max(top_k * 20, 200))

    dense_distances: Dict[int, float] = {}
    if dense_k > 0:
        labels, distances = index.knn_query(qvec, k=dense_k)
        for lab, dist in zip(labels[0].tolist(), distances[0].tolist()):
            label = int(lab)
            if 0 <= label < total:
                dense_distances[label] = float(dist)

    # Score lexical relevance across all chunks.
    # Limiting by top_k can hide the right chunk (e.g. attachment pages) before rerank.
    lexical_scores = _top_lexical_candidates(question, meta, limit=total)

    return _rerank_hybrid(
        question=question,
        meta=meta,
        dense_distances=dense_distances,
        lexical_scores=lexical_scores,
        top_k=top_k,
    )


def build_context(hits: List[Dict[str, Any]]) -> str:
    """Format retrieved hits as numbered "[KILDE n]" sections for the chat prompt.

    Each section header is ``[KILDE n] <title or source_path>``, followed by
    `` — <heading path>`` when the chunk has one. The chunk text comes on the
    next line. Sections are separated by a ``---`` rule.
    """
    sections: List[str] = []
    for rank, hit in enumerate(hits, start=1):
        chunk = hit["chunk"]
        label = chunk.get("title") or chunk.get("source_path")
        trail = " > ".join(chunk.get("heading_path") or [])
        suffix = f" — {trail}" if trail else ""
        header = f"[KILDE {rank}] {label}" + suffix
        sections.append(header + "\n" + (chunk.get("text") or ""))
    return "\n\n---\n\n".join(sections)


def chat_loop(
    index_dir: Path,
    index: hnswlib.Index,
    meta: List[Dict[str, Any]],
    cfg: Dict[str, Any],
    *,
    chat_model: str,
    top_k: int,
    provider: str,
    host: str,
    api_base: str,
    api_key: Optional[str],
) -> None:
    """Run the interactive question/answer loop until the user quits.

    For each question the loop:

    1. retrieves ``top_k`` chunks and prints them, with links when available;
    2. asks the chat model to answer only from that context, citing with
       "[KILDE n]";
    3. retries once if the reply contains no citation;
    4. prints the answer with citations rendered as links.

    "exit", "quit", ":q", EOF and Ctrl-C end the loop. Queries are embedded
    with ``cfg["embed_model"]``, i.e. the model the index was built with.
    """
    embed_model = cfg.get("embed_model") or ""

    def has_citation(text: str) -> bool:
        # Case-insensitive: models write "[KILDE 1]", "[Kilde 1]" or "kilde 1"
        # interchangeably. A case-sensitive check would miss "Kilde", trigger
        # needless retries, and label properly cited answers as unverified.
        return re.search(r"\bkilde\b", text, re.IGNORECASE) is not None

    print("\nKlar. Skriv dit spørgsmål. Skriv 'exit' for at stoppe.")
    print(f"Index: {index_dir}")

    while True:
        try:
            q = input("\n> ").strip()
        except (EOFError, KeyboardInterrupt):
            print("\nStop.")
            return

        if not q:
            continue
        command = q.lower()
        if command in {"exit", "quit", ":q"}:
            print("Stop.")
            return

        if command in {":sources", ":kilder"}:
            print("Skriv et normalt spørgsmål først, så henter jeg kilder og viser dem her.")
            continue

        if command in {":context", ":kontekst"}:
            print("Skriv et normalt spørgsmål først, så kan du få vist den hentede kontekst.")
            continue

        hits = retrieve(
            index,
            meta,
            q,
            embed_model=embed_model,
            host=host,
            top_k=top_k,
            provider=provider,
            api_base=api_base,
            api_key=api_key,
        )

        # Small transparency: show what we actually retrieved.
        print("\nHentede kilder (top-k):")
        for i, h in enumerate(hits, start=1):
            c = h["chunk"]
            src = c.get("title") or c.get("source_path")
            hp = " > ".join(c.get("heading_path") or [])
            dist = h.get("distance")
            # ":.4f" raises on a missing or non-numeric distance; don't let that kill the loop.
            dist_txt = f"{dist:.4f}" if isinstance(dist, (int, float)) else "n/a"
            print(f"  {i}. dist={dist_txt} | {src}" + (f" — {hp}" if hp else ""))
            link = best_source_url_for_chunk(c)
            if link:
                print(f"     link: {link}")

        context = build_context(hits)

        full_prompt = (
            "Spørgsmål:\n"
            f"{q}\n\n"
            "KONTEKST (uddrag fra dine noter — du må kun bruge dette):\n"
            f"{context}\n\n"
            "OPGAVE:\n"
            "- Svar på dansk og brug KUN konteksten.\n"
            "- Når du nævner en teoretiker/begreb/påstand, sæt kildehenvisning som [KILDE 1] osv.\n"
            "- Hvis noget (fx en bestemt teoretiker) ikke står i konteksten, så sig tydeligt at du ikke kan finde det i noterne.\n"
        )

        answer = chat_completion(
            full_prompt,
            model=chat_model,
            provider=provider,
            host=host,
            api_base=api_base,
            api_key=api_key,
        )

        # Guardrail + retry: the model sometimes forgets to add citations even
        # when it used the context.
        if not has_citation(answer):
            # Retry once with a short correction message.
            retry_prompt = full_prompt + "\n\nDU GLEMTE KILDEHENVISNINGER: Skriv svaret igen og tilføj [KILDE 1], [KILDE 2] osv. ved alle konkrete påstande."
            answer2 = chat_completion(
                retry_prompt,
                model=chat_model,
                provider=provider,
                host=host,
                api_base=api_base,
                api_key=api_key,
            )

            if has_citation(answer2) and answer2.strip():
                answer = answer2
            else:
                # Still no citations: show the raw answer (useful for debugging),
                # clearly labelled as unverified.
                if answer.strip():
                    answer = (
                        "(Uden kildehenvisninger — kan være usikkert)\n" + answer.strip() +
                        "\n\nJeg kan ikke garantere at dette står i de hentede kilder. Prøv at omformulere spørgsmålet eller øg top_k."
                    )
                else:
                    answer = "Jeg kan ikke finde det i dine noter (i de hentede kilder). Prøv at omformulere spørgsmålet eller øg top_k."

        answer = render_answer_with_citation_links(answer.strip(), hits)
        print("\n" + answer.strip())


def main() -> int:
    """Command-line entry point: locate the export, prepare chunks and index, then chat.

    Returns 0 after the chat loop ends. Raises ``SystemExit`` with a Danish
    message when no usable export root can be found.
    """
    parser = argparse.ArgumentParser(description="One-command RAG chat over your latest OneNote export.")
    option_specs = (
        ("--export_dir", dict(default=None, help="Export root folder (export_*). If omitted, auto-detect latest.")),
        ("--provider", dict(default="ollama", choices=["ollama", "openai"], help="Embedding/chat provider")),
        ("--host", dict(default="http://localhost:11434", help="Ollama host")),
        ("--api_base", dict(default="", help="OpenAI-compatible base URL (optional)")),
        ("--api_key", dict(default="", help="API key for provider (optional; otherwise env)")),
        ("--api_key_env", dict(default="OPENAI_API_KEY", help="Env var name to read API key from")),
        ("--embed_model", dict(default="nomic-embed-text", help="Embedding model")),
        ("--chat_model", dict(default="mistral", help="Chat model")),
        ("--top_k", dict(type=int, default=6)),
    )
    for flag, spec in option_specs:
        parser.add_argument(flag, **spec)
    opts = parser.parse_args()

    working_dir = Path.cwd()
    export_root = (
        Path(opts.export_dir).expanduser().resolve()
        if opts.export_dir
        else find_latest_export_root(working_dir)
    )

    if export_root is None or not export_root.exists():
        raise SystemExit(
            "Kunne ikke finde en export_* mappe. Kør fra en mappe der indeholder export_*/, "
            "eller angiv --export_dir /sti/til/export_..."
        )
    if not is_export_root_dir(export_root):
        raise SystemExit(f"Mappen ligner ikke en export root (mangler manifest.json/pages): {export_root}")

    chunks_path = ensure_chunks(export_root)
    api_key = _resolve_api_key(api_key=opts.api_key, api_key_env=opts.api_key_env)
    index_dir = ensure_index(
        export_root,
        chunks_path,
        embed_model=opts.embed_model,
        provider=opts.provider,
        host=opts.host,
        api_base=opts.api_base,
        api_key=api_key,
    )

    index, meta, cfg = load_index(index_dir)

    # Chat through the provider/endpoint recorded with the index; command-line
    # values are only a fallback when the config lacks them.
    provider = str(cfg.get("provider") or opts.provider or "ollama").strip().lower()
    host = str(cfg.get("host") or opts.host or "http://localhost:11434")
    api_base = str(cfg.get("api_base") or opts.api_base or "")

    chat_loop(
        index_dir,
        index,
        meta,
        cfg,
        chat_model=opts.chat_model,
        top_k=opts.top_k,
        provider=provider,
        host=host,
        api_base=api_base,
        api_key=api_key,
    )
    return 0


if __name__ == "__main__":
    raise SystemExit(main())
