

#!/usr/bin/env python3
"""chunk_note.py

Chunk locally exported OneNote/SharePoint text into token-sized pieces suitable for RAG.

Features
- Heading-aware chunking for Markdown (#, ##, ### ...)
- Fallback token-window chunking for plain text
- Uses tiktoken (you have 0.12.0) for accurate token budgeting
- Writes JSONL: one chunk per line with useful metadata

Typical usage
  python chunk_note.py --input ./export --output ./chunks.jsonl

Then next step: embed + index these chunks.
"""

from __future__ import annotations

import argparse
import dataclasses
import hashlib
import json
import os
import re
from dataclasses import dataclass
from pathlib import Path
from typing import Dict, Iterable, Iterator, List, Optional, Tuple

import tiktoken


# -----------------------------
# Config defaults
# -----------------------------
# Name of the tiktoken encoding used when --encoding is not given.
DEFAULT_ENCODING = "cl100k_base"
# Default for --target_tokens: the token budget of one chunk.
DEFAULT_TARGET_TOKENS = 520
# Default for --overlap_tokens: tokens shared by two consecutive chunks.
DEFAULT_OVERLAP_TOKENS = 110
# Default for --min_chunk_tokens: lower bound on the size of a non-tail chunk.
DEFAULT_MIN_CHUNK_TOKENS = 90


# -----------------------------
# Data structures
# -----------------------------
@dataclass
class Document:
    """One loaded source document, ready to be chunked."""

    # SHA-1 hex digest of the resolved source file path (see the loaders).
    doc_id: str
    # Path of the file the text was read from, as a string.
    source_path: str
    # Human-readable title: taken from metadata when available, otherwise
    # derived from the file or folder name.
    title: str
    # Whitespace-normalized document text.
    text: str
    # OneNote web URL read from meta.json ("oneNoteWebUrl"), if any.
    source_url: Optional[str] = None
    # OneNote client URL read from meta.json ("oneNoteClientUrl"), if any.
    source_client_url: Optional[str] = None


@dataclass
class Chunk:
    """One output chunk; serialized as a single JSONL line."""

    # SHA-1 hex digest built from doc_id, chunk_index and a content signature.
    chunk_id: str
    # The fields below are copied unchanged from the parent Document.
    doc_id: str
    source_path: str
    title: str
    source_url: Optional[str]
    source_client_url: Optional[str]
    # Titles of the enclosing Markdown headings, outermost first
    # (empty when the document was not split by headings).
    heading_path: List[str]
    # 0-based position of this chunk within its document.
    chunk_index: int
    # Chunk text (decoded token window, stripped of surrounding whitespace).
    text: str
    # Number of tokens in the window this chunk was decoded from.
    token_count: int
    # Approximate character span of the chunk within the document text.
    char_start: int
    char_end: int


# -----------------------------
# Utilities
# -----------------------------
def stable_hash(s: str) -> str:
    """Return the hexadecimal SHA-1 digest of *s*.

    The string is encoded as UTF-8; characters that cannot be encoded
    (such as lone surrogates) are dropped instead of raising.
    """
    hasher = hashlib.sha1()
    hasher.update(s.encode("utf-8", errors="ignore"))
    return hasher.hexdigest()


def normalize_whitespace(text: str) -> str:
    """Tidy whitespace while keeping the line structure of *text*.

    Steps, in order: convert CRLF / CR line endings to LF, strip trailing
    whitespace from every line, squeeze runs of three or more newlines down
    to two (i.e. at most one blank line in a row), and finally strip the
    whole text.
    """
    unified = re.sub(r"\r\n?", "\n", text)
    trimmed = "\n".join(map(str.rstrip, unified.split("\n")))
    squeezed = re.sub(r"\n{3,}", "\n\n", trimmed)
    return squeezed.strip()


def guess_title_from_path(path: Path) -> str:
    """Derive a display title from a file path.

    The title is the file stem with underscores replaced by spaces and the
    result stripped; if that leaves nothing, the full file name is returned.
    """
    candidate = path.stem.replace("_", " ").strip()
    if candidate:
        return candidate
    return path.name


def is_text_file(path: Path) -> bool:
    """Return True when the file extension is one the loader understands.

    Accepted extensions (case-insensitive): .md, .markdown, .txt and .json.
    """
    supported = {".md", ".markdown", ".txt", ".json"}
    return path.suffix.lower() in supported


def extract_href(value: object) -> Optional[str]:
    """Pull a URL string out of *value*.

    *value* may be the URL itself (a string) or a mapping carrying it under
    the "href" key. The URL is returned stripped; None is returned when it
    is missing, empty after stripping, or not a string.
    """
    # Unwrap the {"href": ...} form so both shapes share one code path.
    candidate = value.get("href") if isinstance(value, dict) else value
    if not isinstance(candidate, str):
        return None
    return candidate.strip() or None


# -----------------------------
# Document loading
# -----------------------------
# Keys probed, in priority order, for the main text of a JSON object.
_JSON_TEXT_KEYS = ("text", "content", "body", "markdown")
# Keys probed, in priority order, for a list of pages/blocks inside a JSON object.
_JSON_LIST_KEYS = ("pages", "items", "blocks", "notes")
# Keys probed, in priority order, for the heading of an item in a top-level JSON list.
_JSON_HEADING_KEYS = ("title", "heading")


def _first_str(mapping: dict, keys: Tuple[str, ...]) -> Optional[str]:
    """Return the value of the first key in *keys* that holds a string, else None."""
    for key in keys:
        value = mapping.get(key)
        if isinstance(value, str):
            return value
    return None


def _flatten_json_items(items: list, with_headings: bool) -> str:
    """Join the text of a JSON list of strings/objects with blank lines.

    With *with_headings*, an object's non-empty "title"/"heading" is emitted
    as a Markdown "# ..." line in front of its text.
    """
    parts: List[str] = []
    for item in items:
        if isinstance(item, str):
            parts.append(item)
        elif isinstance(item, dict):
            if with_headings:
                for key in _JSON_HEADING_KEYS:
                    heading = item.get(key)
                    if isinstance(heading, str) and heading.strip():
                        parts.append(f"# {heading.strip()}")
                        break
            body = _first_str(item, _JSON_TEXT_KEYS)
            if body is not None:
                parts.append(body)
    return "\n\n".join(parts)


def _json_to_text(obj: object) -> Tuple[Optional[str], str]:
    """Return (title or None, text) extracted from a parsed JSON value."""
    if isinstance(obj, dict):
        raw_title = obj.get("title")
        title = raw_title.strip() if isinstance(raw_title, str) and raw_title.strip() else None

        # Prefer an explicit text key; fall back to the first known list key.
        text = _first_str(obj, _JSON_TEXT_KEYS) or ""
        if not text:
            for key in _JSON_LIST_KEYS:
                if isinstance(obj.get(key), list):
                    text = _flatten_json_items(obj[key], with_headings=False)
                    break
        return title, text

    if isinstance(obj, list):
        return None, _flatten_json_items(obj, with_headings=True)

    return None, ""


def load_document(path: Path) -> Optional[Document]:
    """Load a document from a .md/.markdown/.txt/.json export.

    For .json the following shapes are recognised:
      - {"text": "..."} (also "content", "body", "markdown")
      - {"title": "...", "text": "..."} -- the title overrides the file name
      - {"pages": [...]} (also "items", "blocks", "notes") -> concatenated
      - a top-level list of strings/objects -> concatenated, with each
        object's "title"/"heading" turned into a "# ..." line
    Malformed JSON, or JSON without any recognised text, falls back to the
    raw file content.

    Files are decoded as UTF-8 with undecodable bytes dropped. A leading
    byte-order mark is removed: left in place it makes ``json.loads`` reject
    the file and hides a "#" heading on the first line.

    Returns None when the file cannot be read, has an unsupported extension,
    or contains no text after whitespace normalization.
    """
    try:
        raw = path.read_text(encoding="utf-8-sig", errors="ignore")
    except (OSError, ValueError):
        # ValueError covers unusable paths (e.g. an embedded NUL byte).
        return None

    title = guess_title_from_path(path)
    suffix = path.suffix.lower()
    text = ""

    if suffix in {".md", ".markdown", ".txt"}:
        text = raw
    elif suffix == ".json":
        try:
            obj = json.loads(raw)
        except (ValueError, RecursionError):
            # Malformed (or absurdly nested) JSON: treat the file as plain text.
            text = raw
        else:
            json_title, text = _json_to_text(obj)
            if json_title is not None:
                title = json_title
            if not text:
                # Nothing recognisable inside: index the raw JSON string.
                text = raw

    text = normalize_whitespace(text)
    if not text:
        return None

    doc_id = stable_hash(str(path.resolve()))
    return Document(doc_id=doc_id, source_path=str(path), title=title, text=text)


def _read_page_meta(meta_path: Path) -> Tuple[Optional[str], Optional[str], Optional[str]]:
    """Return (title, web_url, client_url) from a page's meta.json.

    Every element is None when the file is missing, unreadable, not a JSON
    object, or simply lacks the value. The file is decoded with "utf-8-sig"
    so a byte-order mark does not make ``json.loads`` reject it.
    """
    try:
        meta = json.loads(meta_path.read_text(encoding="utf-8-sig", errors="ignore"))
    except (OSError, ValueError, RecursionError):
        # Metadata is optional: fall back to defaults rather than fail the page.
        return None, None, None
    if not isinstance(meta, dict):
        return None, None, None

    raw_title = meta.get("title")
    title = raw_title.strip() if isinstance(raw_title, str) and raw_title.strip() else None

    links = meta.get("links")
    if not isinstance(links, dict):
        links = {}
    # URLs under "links" win; same-named top-level keys are the fallback.
    web_url = extract_href(links.get("oneNoteWebUrl")) or extract_href(meta.get("oneNoteWebUrl"))
    client_url = extract_href(links.get("oneNoteClientUrl")) or extract_href(
        meta.get("oneNoteClientUrl")
    )
    return title, web_url, client_url


def iter_documents(input_dir: Path) -> Iterator[Document]:
    """Yield every loadable document below *input_dir*.

    Two layouts are supported:

    * Export root (``manifest.json`` plus ``pages/``): each sub-folder of
      ``pages/`` that contains a ``page.txt`` becomes one document. The title
      and OneNote URLs come from the optional ``meta.json`` next to it; the
      folder name is the fallback title.
    * Any other directory: every .md/.markdown/.txt/.json file found
      recursively is passed to ``load_document``.

    Both layouts are walked in sorted path order so that the output -- and
    whatever ``--max_docs`` cuts off -- is reproducible between runs. Pages
    and files that cannot be read or are empty after whitespace normalization
    are skipped.
    """
    if is_export_root_dir(input_dir):
        pages_dir = input_dir / "pages"
        for page_folder in sorted(pages_dir.iterdir()):
            if not page_folder.is_dir():
                continue

            txt_path = page_folder / "page.txt"
            try:
                # "utf-8-sig" drops a leading BOM that would otherwise stay in the text.
                text = txt_path.read_text(encoding="utf-8-sig", errors="ignore")
            except OSError:
                # Missing or unreadable page.txt: nothing to chunk in this folder.
                continue

            text = normalize_whitespace(text)
            if not text:
                continue

            meta_title, source_url, source_client_url = _read_page_meta(page_folder / "meta.json")

            yield Document(
                doc_id=stable_hash(str(txt_path.resolve())),
                source_path=str(txt_path),
                title=meta_title or page_folder.name,
                text=text,
                source_url=source_url,
                source_client_url=source_client_url,
            )
        return

    for path in sorted(input_dir.rglob("*")):
        if not path.is_file() or not is_text_file(path):
            continue
        doc = load_document(path)
        if doc is not None:
            yield doc


# -----------------------------
# Markdown heading-aware splitting
# -----------------------------
# ATX heading: one to six "#", whitespace, then the title. Matched per line.
HEADING_RE = re.compile(r"^(#{1,6})\s+(.*)\s*$")
# Fence of a fenced code block: at most three spaces of indent followed by
# three or more backticks or tildes.
_FENCE_RE = re.compile(r"^ {0,3}(`{3,}|~{3,})")


def split_markdown_by_headings(text: str) -> List[Tuple[List[str], str, int]]:
    """Split Markdown into one section per heading.

    Returns a list of ``(heading_path, section_text, section_char_start)``:

    * ``heading_path`` -- titles of the enclosing headings, outermost first,
      including the section's own heading. Text before the first heading gets
      an empty path.
    * ``section_text`` -- the section including its heading line, stripped.
      Sections that are empty after stripping are omitted.
    * ``section_char_start`` -- offset of ``section_text`` within *text*, so
      ``text[start:start + len(section_text)] == section_text``.

    Lines inside fenced code blocks (``` or ~~~) are never treated as
    headings; otherwise a shell or Python comment such as "# install deps"
    would open a bogus section and corrupt the heading path. An unclosed
    fence extends to the end of the text.
    """
    sections: List[Tuple[List[str], str, int]] = []

    heading_stack: List[Tuple[int, str]] = []  # (level, title)
    current_lines: List[str] = []
    current_start_char = 0
    char_cursor = 0
    open_fence: Optional[str] = None  # marker that opened the current code block

    def flush() -> None:
        """Emit the pending lines as a section under the current heading path."""
        nonlocal current_lines
        if not current_lines:
            return
        joined = "\n".join(current_lines)
        body = joined.strip()
        if body:
            # Report where the stripped text starts, not where the raw lines start.
            lead = len(joined) - len(joined.lstrip())
            path = [title for (_, title) in heading_stack]
            sections.append((path, body, current_start_char + lead))
        current_lines = []

    for line in text.split("\n"):
        in_code_block = open_fence is not None

        fence = _FENCE_RE.match(line)
        if fence:
            marker = fence.group(1)
            if open_fence is None:
                open_fence = marker
            elif (
                marker[0] == open_fence[0]
                and len(marker) >= len(open_fence)
                and line.strip() == marker
            ):
                # A closing fence uses the same character, is at least as long
                # as the opening one, and carries nothing else on the line.
                open_fence = None

        heading = None if in_code_block else HEADING_RE.match(line)
        if heading:
            # A new heading begins: the previous section is complete.
            flush()

            level = len(heading.group(1))
            title = heading.group(2).strip()

            # Drop siblings and deeper headings so the stack ends at the parent.
            while heading_stack and heading_stack[-1][0] >= level:
                heading_stack.pop()
            heading_stack.append((level, title))

            current_start_char = char_cursor
        elif not current_lines:
            current_start_char = char_cursor
        current_lines.append(line)

        # Advance past the line and the newline that followed it.
        char_cursor += len(line) + 1

    flush()
    return sections


def looks_like_markdown(text: str) -> bool:
    """Heuristically decide whether *text* is Markdown.

    True when the text contains an ATX heading line (one to six "#" followed
    by whitespace), a "- " or "* " bullet marker, or both "[" and "](" (a
    Markdown link).

    The heading test is applied to every line. The pattern mirrors the
    per-line HEADING_RE used by the splitter, but must be multi-line here:
    searching the whole text with a pattern anchored by "^...$" only matches
    when the entire text is a single heading line.
    """
    # [^\S\n] is "whitespace except newline", so a bare "#" line followed by
    # text on the next line does not count as a heading.
    if re.search(r"^#{1,6}[^\S\n]+", text, re.MULTILINE):
        return True
    if "- " in text or "* " in text:
        return True
    return "[" in text and "](" in text


# -----------------------------
# Token chunking
# -----------------------------
def chunk_by_token_window(
    text: str,
    enc: tiktoken.Encoding,
    target_tokens: int,
    overlap_tokens: int,
    min_chunk_tokens: int,
    base_char_start: int = 0,
) -> List[Tuple[str, int, int, int]]:
    """Split *text* into overlapping token windows.

    Args:
        text: Text to split; surrounding whitespace is ignored.
        enc: Tokenizer providing ``encode(str)`` and ``decode(tokens)``.
        target_tokens: Token budget per chunk. Text that fits is returned as
            a single chunk.
        overlap_tokens: Tokens repeated at the start of the next chunk.
            Clamped to ``[0, window - 1]`` so every step moves forward; an
            overlap as large as the window would otherwise never terminate.
        min_chunk_tokens: Lower bound on the window size; it only matters
            when it exceeds *target_tokens*. The final chunk may be smaller.
        base_char_start: Offset of *text* within the enclosing document.

    Returns:
        A list of ``(chunk_text, token_count, char_start, char_end)``.
        ``chunk_text`` is the decoded window, stripped. The character span is
        approximate: it is found by searching for the decoded chunk in the
        text, and it accounts for leading whitespace stripped from *text*.
    """
    stripped = text.strip()
    if not stripped:
        return []

    # Offsets are reported relative to the caller's (unstripped) text.
    base = base_char_start + (len(text) - len(text.lstrip()))

    tokens = enc.encode(stripped)
    n = len(tokens)
    if n <= target_tokens:
        return [(stripped, n, base, base + len(stripped))]

    # Effective window: the target, raised to the minimum chunk size if that
    # is larger, and never below one token.
    window = max(1, target_tokens, min_chunk_tokens)
    overlap = max(0, min(overlap_tokens, window - 1))

    chunks: List[Tuple[str, int, int, int]] = []
    start = 0
    # Search hint for the next chunk; the unrestricted search below is the
    # fallback when overlapping chunks start before the hint.
    search_from = 0

    while start < n:
        end = min(start + window, n)
        tok_slice = tokens[start:end]
        chunk_text = enc.decode(tok_slice).strip()

        # A whitespace-only window yields no chunk, but the text after it
        # must still be processed.
        if chunk_text:
            idx = stripped.find(chunk_text, search_from)
            if idx == -1:
                idx = stripped.find(chunk_text)
            if idx == -1:
                # The decoded window does not occur verbatim (e.g. it was cut
                # inside a multi-byte character): fall back to the hint.
                idx = search_from

            chunks.append((chunk_text, len(tok_slice), base + idx, base + idx + len(chunk_text)))
            search_from = idx + max(1, len(chunk_text) - 10)

        if end >= n:
            break

        # overlap < window, so the next start is strictly greater than this one.
        start = end - overlap

    return chunks


# -----------------------------
# Main chunking logic
# -----------------------------
def chunk_document(
    doc: Document,
    enc: tiktoken.Encoding,
    target_tokens: int,
    overlap_tokens: int,
    min_chunk_tokens: int,
) -> List[Chunk]:
    """Turn one document into an ordered list of chunks.

    Markdown-looking text is first split into heading sections; anything else
    is treated as a single section without a heading path. Each section is
    then cut into token windows, and every window becomes a ``Chunk`` that
    carries the document metadata, the section's heading path and a running
    index over the whole document.
    """
    body = doc.text

    if looks_like_markdown(body):
        sections = split_markdown_by_headings(body)
    else:
        sections = [([], body, 0)]

    collected: List[Chunk] = []

    for heading_path, sec_text, sec_char_start in sections:
        heading_key = "\n".join(heading_path)
        # Small sections need no special case: the window function returns
        # them as a single chunk.
        windows = chunk_by_token_window(
            sec_text,
            enc=enc,
            target_tokens=target_tokens,
            overlap_tokens=overlap_tokens,
            min_chunk_tokens=min_chunk_tokens,
            base_char_start=sec_char_start,
        )

        for piece, n_tokens, span_start, span_end in windows:
            # The index runs over the whole document, not per section.
            position = len(collected)
            # The id combines document, position and content (heading path +
            # text), so it is unique per chunk and changes when the content does.
            signature = stable_hash(heading_key + "\n" + piece)
            collected.append(
                Chunk(
                    chunk_id=stable_hash(f"{doc.doc_id}:{position}:{signature}"),
                    doc_id=doc.doc_id,
                    source_path=doc.source_path,
                    title=doc.title,
                    source_url=doc.source_url,
                    source_client_url=doc.source_client_url,
                    heading_path=heading_path,
                    chunk_index=position,
                    text=piece,
                    token_count=n_tokens,
                    char_start=span_start,
                    char_end=span_end,
                )
            )

    return collected


# -----------------------------
# IO
# -----------------------------
def chunk_to_json_dict(ch: Chunk) -> Dict:
    """Return the JSON-serializable mapping written for one chunk.

    Keys appear in a fixed order, which is also the key order of every
    JSONL line.
    """
    exported = (
        "chunk_id",
        "doc_id",
        "source_path",
        "title",
        "source_url",
        "source_client_url",
        "heading_path",
        "chunk_index",
        "text",
        "token_count",
        "char_start",
        "char_end",
    )
    return {name: getattr(ch, name) for name in exported}


def write_jsonl(chunks: Iterable[Chunk], output_path: Path) -> None:
    """Write *chunks* to *output_path*, one JSON object per line.

    Missing parent directories are created and an existing file is
    overwritten. The file is UTF-8 and non-ASCII characters are written
    as-is rather than escaped.
    """
    output_path.parent.mkdir(parents=True, exist_ok=True)
    with output_path.open("w", encoding="utf-8") as sink:
        sink.writelines(
            json.dumps(chunk_to_json_dict(item), ensure_ascii=False) + "\n"
            for item in chunks
        )

def is_export_root_dir(path: Path) -> bool:
    """Return True when *path* is a directory holding both ``manifest.json`` and ``pages``."""
    if not path.is_dir():
        return False
    return all((path / entry).exists() for entry in ("manifest.json", "pages"))


def find_latest_export_root(search_dir: Path) -> Optional[Path]:
    """Return the most recent export root directly inside *search_dir*.

    Candidates are sub-directories named ``export_*`` that qualify as export
    roots. "Most recent" is decided by the modification time of each
    candidate's ``manifest.json``. Returns None when *search_dir* is not an
    existing directory or holds no candidate.
    """
    if not (search_dir.exists() and search_dir.is_dir()):
        return None

    roots = [
        entry
        for entry in search_dir.iterdir()
        if entry.is_dir() and entry.name.startswith("export_") and is_export_root_dir(entry)
    ]
    if not roots:
        return None

    def manifest_mtime(root: Path) -> float:
        return (root / "manifest.json").stat().st_mtime

    # On equal timestamps the first candidate in iteration order wins.
    return max(roots, key=manifest_mtime)


def resolve_input_and_output(args) -> tuple[Path, Path]:
    """Work out the input directory and the output file from the CLI arguments.

    Input: with ``--auto_latest`` or without ``--input``, the newest
    ``export_*`` export root in the current working directory is used;
    otherwise the given path, user-expanded and resolved.

    Output: ``--output`` when given (user-expanded and resolved); otherwise
    ``<input>/index/chunks.jsonl`` for an export root and ``./chunks.jsonl``
    for any other directory.

    Raises:
        SystemExit: no export root could be auto-detected, or the input path
            is not an existing directory.
    """
    cwd = Path.cwd()
    auto_detect = args.auto_latest or not args.input

    if auto_detect:
        input_dir = find_latest_export_root(cwd)
        if input_dir is None:
            raise SystemExit(
                "Kunne ikke finde en export_* mappe med manifest.json. Angiv --input <mappe>."
            )
    else:
        input_dir = Path(args.input).expanduser().resolve()

    if not (input_dir.exists() and input_dir.is_dir()):
        raise SystemExit(f"Input path is not a directory: {input_dir}")

    if args.output:
        return input_dir, Path(args.output).expanduser().resolve()

    if is_export_root_dir(input_dir):
        default_output = input_dir / "index" / "chunks.jsonl"
    else:
        default_output = cwd / "chunks.jsonl"
    return input_dir, default_output.resolve()
# -----------------------------
# CLI
# -----------------------------
def build_argparser() -> argparse.ArgumentParser:
    """Build the command-line parser.

    Options: --input/-i, --output/-o, --auto_latest, --encoding,
    --target_tokens, --overlap_tokens, --min_chunk_tokens and --max_docs.
    The --output help names the default location that is actually used for
    an export root, ``<export_root>/index/chunks.jsonl``.
    """
    p = argparse.ArgumentParser(
        description=(
            "Chunk exported OneNote text into JSONL for RAG. "
            "If --input is omitted, the script will try to find the latest export_*/ folder "
            "(with manifest.json) in the current working directory."
        )
    )

    p.add_argument(
        "--input",
        "-i",
        default=None,
        help=(
            "Input directory. Can be either: (A) export root folder containing manifest.json, "
            "(B) any folder containing .md/.txt/.json exports. "
            "If omitted, we auto-detect the newest export_* folder in the current directory."
        ),
    )
    p.add_argument(
        "--output",
        "-o",
        default=None,
        help=(
            "Output .jsonl file path. If omitted and input is an export root, defaults to "
            "<export_root>/index/chunks.jsonl. "
            "Otherwise defaults to ./chunks.jsonl."
        ),
    )

    p.add_argument(
        "--auto_latest",
        action="store_true",
        help="Force auto-detection of the newest export_* folder (ignores --input).",
    )

    p.add_argument("--encoding", default=DEFAULT_ENCODING, help=f"tiktoken encoding (default: {DEFAULT_ENCODING})")
    p.add_argument("--target_tokens", type=int, default=DEFAULT_TARGET_TOKENS, help="Target tokens per chunk")
    p.add_argument("--overlap_tokens", type=int, default=DEFAULT_OVERLAP_TOKENS, help="Overlap tokens between chunks")
    p.add_argument("--min_chunk_tokens", type=int, default=DEFAULT_MIN_CHUNK_TOKENS, help="Minimum tokens per chunk")
    p.add_argument("--max_docs", type=int, default=0, help="Optional limit for docs (0 = no limit)")

    return p


def main() -> int:
    """Run the CLI: load documents, chunk them, write JSONL, print a summary.

    Returns:
        0 on success (used as the process exit status).

    Raises:
        SystemExit: the token options cannot produce a terminating chunking
            run, or the input directory could not be resolved.
    """
    args = build_argparser().parse_args()

    # Reject token settings the chunker cannot make progress with. A window is
    # target_tokens, or min_chunk_tokens when that is larger; an overlap that
    # is not smaller than the window never advances, and a negative overlap
    # would silently skip text.
    window = max(args.target_tokens, args.min_chunk_tokens)
    if window <= 0:
        raise SystemExit("--target_tokens must be a positive integer")
    if not 0 <= args.overlap_tokens < window:
        raise SystemExit(
            f"--overlap_tokens must be >= 0 and smaller than the chunk window "
            f"({window} tokens); got {args.overlap_tokens}"
        )

    # resolve_input_and_output already guarantees that input_dir is an
    # existing directory, so no second check is needed here.
    input_dir, output_path = resolve_input_and_output(args)

    enc = tiktoken.get_encoding(args.encoding)

    all_chunks: List[Chunk] = []
    doc_count = 0

    for doc in iter_documents(input_dir):
        doc_count += 1
        all_chunks.extend(
            chunk_document(
                doc,
                enc=enc,
                target_tokens=args.target_tokens,
                overlap_tokens=args.overlap_tokens,
                min_chunk_tokens=args.min_chunk_tokens,
            )
        )

        # --max_docs 0 means "no limit".
        if args.max_docs and doc_count >= args.max_docs:
            break

    write_jsonl(all_chunks, output_path)

    # Small human-friendly summary. Overlapping tokens are counted once per
    # chunk they appear in, hence "approx".
    total_tokens = sum(ch.token_count for ch in all_chunks)
    print(f"Docs: {doc_count}")
    print(f"Chunks: {len(all_chunks)}")
    print(f"Total tokens (approx): {total_tokens}")
    print(f"Wrote: {output_path}")

    return 0


# Run the CLI only when executed as a script (not on import); the value
# returned by main() is passed to SystemExit as the process exit status.
if __name__ == "__main__":
    raise SystemExit(main())
