#!/usr/bin/env python3
"""Probe Tranco domains for llms.txt adoption.

The crawler is intentionally conservative:
- downloads the public Tranco daily top-1m CSV;
- probes Tranco Top N domains for /llms.txt and /llms-full.txt;
- enforces one global request rate across HEAD and GET requests;
- writes every successful raw response to disk;
- supports resume by skipping domains already present in the result CSV.
"""

from __future__ import annotations

import argparse
import asyncio
import csv
import json
import re
import time
import zipfile
from collections import Counter
from dataclasses import asdict, dataclass
from datetime import datetime, timezone
from pathlib import Path
from typing import Iterable
from urllib.parse import urlparse
from urllib.request import Request, urlopen

import aiohttp


# Download URL of the zipped Tranco top-1m CSV (fetched by download_tranco).
# NOTE(review): the identifier is spelled with a digit zero ("TRANC0"); it is left
# as-is here because download_tranco refers to it under exactly this name.
TRANC0_ZIP_URL = "https://tranco-list.eu/top-1m.csv.zip"
# User-Agent header value sent with the list download and with every probe request.
USER_AGENT = "Thunderbit-Research/1.0 ([REDACTED_EMAIL])"
# File names probed at the root of each domain.
# NOTE(review): probe_domain spells both names out literally rather than reading
# PATHS, so the two places have to be kept in sync by hand.
PATHS = ("llms.txt", "llms-full.txt")


@dataclass
class Probe:
    """Outcome of probing a single path (e.g. ``llms.txt``) on a single domain."""

    # File name that was probed, without a leading slash.
    path: str
    # URL of the response the probe ended on (redirects are followed); "" when
    # no response was received at all.
    final_url: str = ""
    # How the recorded response was obtained: "HEAD+GET" (HEAD 200, then GET 200),
    # "GET" (HEAD answered 405/501, GET 200), "GET_AFTER_HEAD" (GET answered
    # non-200 on the http attempt), or "" when only a HEAD result or no result
    # was recorded.
    method_hit: str = ""
    # HTTP status code of the recorded response, or "" when none was received.
    status: int | str = ""
    # Content-Type response header, "" when absent.
    content_type: str = ""
    # Content-Length response header exactly as sent, "" when absent.
    content_length_header: str = ""
    # Number of body bytes written to disk; 0 when nothing was saved.
    bytes_saved: int = 0
    # Hex SHA-256 digest of the saved body; "" when nothing was saved.
    sha256: str = ""
    # 1 when validate_body accepted the body, otherwise 0.
    is_valid: int = 0
    # Reason code returned by validate_body when the body was rejected.
    invalid_reason: str = ""
    # Wall-clock duration of the whole probe in milliseconds.
    elapsed_ms: int = 0
    # "ExceptionType: message" of the most recent failed request, "" otherwise.
    error: str = ""


@dataclass
class DomainResult:
    """Combined outcome of both path probes for one domain.

    append_results serialises instances with ``asdict``, so the field order
    below is also the column order of the results CSV.
    """

    # Rank and domain name as read from the Tranco list.
    rank: int
    domain: str
    # --- /llms.txt probe (copied from the corresponding Probe fields) ---
    llms_txt_status: int | str = ""
    llms_txt_url: str = ""
    llms_txt_method: str = ""
    llms_txt_bytes: int = 0
    llms_txt_valid: int = 0
    llms_txt_invalid_reason: str = ""
    llms_txt_content_type: str = ""
    llms_txt_error: str = ""
    # --- /llms-full.txt probe (copied from the corresponding Probe fields) ---
    llms_full_txt_status: int | str = ""
    llms_full_txt_url: str = ""
    llms_full_txt_method: str = ""
    llms_full_txt_bytes: int = 0
    llms_full_txt_valid: int = 0
    llms_full_txt_invalid_reason: str = ""
    llms_full_txt_content_type: str = ""
    llms_full_txt_error: str = ""
    # Wall-clock time in milliseconds for probing both paths of the domain.
    elapsed_ms: int = 0


class RateLimiter:
    """Serialise callers so that consecutive passes are one interval apart.

    Every coroutine that awaits :meth:`wait` is released no sooner than
    ``1 / rate_per_second`` seconds after the previously released one.
    """

    def __init__(self, rate_per_second: float) -> None:
        """Create a limiter that releases ``rate_per_second`` callers per second.

        Args:
            rate_per_second: Target release rate; must be greater than zero.

        Raises:
            ValueError: If ``rate_per_second`` is zero, negative or NaN. Zero
                would otherwise fail with a bare ZeroDivisionError and a
                negative rate would silently disable throttling.
        """
        # "not x > 0" (rather than "x <= 0") also rejects NaN.
        if not rate_per_second > 0:
            raise ValueError(
                f"rate_per_second must be greater than 0, got {rate_per_second!r}"
            )
        self.interval = 1.0 / rate_per_second
        self.lock = asyncio.Lock()
        self.next_at = 0.0

    async def wait(self) -> None:
        """Block until the caller's slot is due, then reserve the next slot."""
        # The lock is held while sleeping on purpose: it queues callers so that
        # each one computes its slot only after the previous one was released.
        async with self.lock:
            now = time.monotonic()
            delay = self.next_at - now
            if delay > 0:
                await asyncio.sleep(delay)
            # "now" is the pre-sleep timestamp; max() therefore yields the
            # moment this caller was released (its slot, or now if none was due).
            self.next_at = max(now, self.next_at) + self.interval


def safe_name(domain: str) -> str:
    """Return ``domain`` rewritten so it can be used as a file name.

    Each run of characters other than ASCII letters, digits, ``_``, ``.`` and
    ``-`` is collapsed into a single underscore.
    """
    unsafe_run = re.compile(r"[^A-Za-z0-9_.-]+")
    return unsafe_run.sub("_", domain)


def ensure_dirs(base: Path) -> dict[str, Path]:
    """Create the output directory layout under ``base`` and return it.

    Args:
        base: Root output directory; missing parents are created as needed.

    Returns:
        Mapping with the keys ``data``, ``raw_llms``, ``raw_full`` and
        ``charts`` pointing at the (now existing) directories.
    """
    raw_root = base / "raw_llms_txt"
    layout = {
        "data": base / "data",
        "raw_llms": raw_root / "llms_txt",
        "raw_full": raw_root / "llms_full_txt",
        "charts": base / "charts",
    }
    for directory in layout.values():
        directory.mkdir(parents=True, exist_ok=True)
    return layout


def download_tranco(data_dir: Path) -> Path:
    """Download the Tranco top-1m zip into ``data_dir`` and record its provenance.

    Besides the archive itself, a ``tranco_download_metadata.json`` file is
    written holding the final URL, the download time (UTC) and the
    ``Last-Modified`` / ``ETag`` response headers.

    Args:
        data_dir: Existing directory that receives both files.

    Returns:
        Path of the saved zip archive.
    """
    target = data_dir / "tranco_top_1m_daily.csv.zip"
    request = Request(TRANC0_ZIP_URL, headers={"User-Agent": USER_AGENT})
    with urlopen(request, timeout=60) as resp:
        response_headers = resp.headers
        metadata = {
            "url": resp.geturl(),
            "downloaded_at_utc": datetime.now(timezone.utc).isoformat(),
            "last_modified": response_headers.get("Last-Modified", ""),
            "etag": response_headers.get("ETag", ""),
        }
        # The whole archive is read into memory before it is written out.
        payload = resp.read()
        target.write_bytes(payload)

    metadata_path = data_dir / "tranco_download_metadata.json"
    metadata_path.write_text(json.dumps(metadata, indent=2), encoding="utf-8")
    return target


def extract_top_domains(zip_path: Path, limit: int, data_dir: Path) -> list[tuple[int, str]]:
    """Read the first ``limit`` ``rank,domain`` rows from the zipped list.

    The rows are also written, with a ``rank,domain`` header, to
    ``data_dir / f"tranco_top_{limit}.csv"``.

    Args:
        zip_path: Zip archive whose first member is the headerless CSV.
        limit: Maximum number of rows to return. Zero or a negative value
            yields no rows (the archive is then not opened at all).
        data_dir: Directory that receives the extracted CSV copy.

    Returns:
        ``(rank, domain)`` tuples in file order, at most ``limit`` of them.
    """
    rows: list[tuple[int, str]] = []
    # Check the limit up front: testing it only after the append would hand
    # back one row for limit <= 0.
    if limit > 0:
        with zipfile.ZipFile(zip_path) as zf:
            name = zf.namelist()[0]
            with zf.open(name) as fp:
                for raw_line in fp:
                    line = raw_line.decode("utf-8").strip()
                    if not line:
                        # Blank (e.g. trailing) lines cannot be split into
                        # rank and domain; skip them instead of crashing.
                        continue
                    rank_s, domain = line.split(",", 1)
                    rows.append((int(rank_s), domain))
                    if len(rows) >= limit:
                        break

    out = data_dir / f"tranco_top_{limit}.csv"
    with out.open("w", newline="", encoding="utf-8") as f:
        writer = csv.writer(f)
        writer.writerow(["rank", "domain"])
        writer.writerows(rows)
    return rows


def load_done(results_path: Path) -> set[str]:
    """Return the domains already recorded in the results CSV.

    Args:
        results_path: Results CSV with a ``domain`` column; may not exist yet.

    Returns:
        Set of values from the ``domain`` column, or an empty set when the
        file does not exist.
    """
    if not results_path.exists():
        return set()

    seen: set[str] = set()
    with results_path.open(newline="", encoding="utf-8") as handle:
        for record in csv.DictReader(handle):
            seen.add(record["domain"])
    return seen


def validate_body(final_url: str, path_name: str, body: bytes) -> tuple[int, str]:
    """Decide whether a fetched body plausibly is the requested text file.

    Args:
        final_url: URL the response was finally served from.
        path_name: File name that was requested, without a leading slash.
        body: Raw response body.

    Returns:
        ``(1, "")`` when the body is accepted, otherwise ``(0, reason)`` with
        ``reason`` being ``"redirected_off_target_path"``, ``"html_document"``
        or ``"empty_body"``.
    """
    # A trailing slash on the final URL is tolerated; any other path is not.
    expected_path = "/" + path_name
    if urlparse(final_url).path.rstrip("/") != expected_path:
        return 0, "redirected_off_target_path"

    # Only the first 2 KiB are inspected for HTML markers.
    head = body[:2048].decode("utf-8", errors="ignore").lstrip().lower()
    looks_like_html = head.startswith(("<!doctype html", "<html")) or (
        "<title>" in head[:500] and "<body" in head[:1000]
    )
    if looks_like_html:
        return 0, "html_document"

    if body.strip():
        return 1, ""
    return 0, "empty_body"


async def request_once(
    session: aiohttp.ClientSession,
    limiter: RateLimiter,
    method: str,
    url: str,
    timeout_seconds: float,
) -> tuple[aiohttp.ClientResponse | None, bytes, str, int]:
    """Issue one rate-limited request and report the outcome without raising.

    Args:
        session: Client session used to send the request.
        limiter: Rate limiter awaited before the request is sent.
        method: HTTP method; the body is read only when this is ``"GET"``.
        url: Target URL; redirects are followed.
        timeout_seconds: Total timeout applied to the request.

    Returns:
        ``(response, body, error, elapsed_ms)``. On success ``error`` is ``""``;
        on any exception ``response`` is ``None``, ``body`` is empty and
        ``error`` is ``"ExceptionType: message"``. ``elapsed_ms`` is measured
        from before the limiter wait, so it includes time spent queued.
    """
    began = time.monotonic()

    def elapsed_ms() -> int:
        return int((time.monotonic() - began) * 1000)

    await limiter.wait()
    try:
        async with session.request(
            method,
            url,
            allow_redirects=True,
            timeout=aiohttp.ClientTimeout(total=timeout_seconds),
        ) as resp:
            payload = await resp.read() if method == "GET" else b""
            took = elapsed_ms()
            # The response is handed back after its context has exited; callers
            # in this file only read its status, url and headers.
            return resp, payload, "", took
    except Exception as exc:
        # Deliberately broad: any failure is reported as a value so that one
        # unreachable host cannot abort the surrounding batch.
        took = elapsed_ms()
        return None, b"", f"{type(exc).__name__}: {exc}", took


async def probe_path(
    session: aiohttp.ClientSession,
    limiter: RateLimiter,
    domain: str,
    path_name: str,
    raw_dir: Path,
    timeout_seconds: float,
) -> Probe:
    """Probe ``/<path_name>`` on ``domain`` and save the body when it is served.

    For each scheme (https first, then http) a HEAD request is sent; a GET
    follows only when HEAD answers 200, or 405/501 (HEAD not supported). A GET
    that answers 200 is hashed, validated and written to
    ``raw_dir / "<safe_name(domain)>.txt"``.

    Args:
        session: Client session used for all requests.
        limiter: Rate limiter shared by every request.
        domain: Host name to probe.
        path_name: File name to request, without a leading slash.
        raw_dir: Directory that receives the raw body of a 200 response.
        timeout_seconds: Total timeout per individual request.

    Returns:
        A Probe describing the response the probe ended on, or one carrying
        only the last seen status/error when no attempt produced a result.
    """
    total_started = time.monotonic()
    last_error = ""
    last_status: int | str = ""

    # https is attempted first; http is reached only when the https attempt
    # failed with an exception or its GET answered something other than 200.
    for scheme in ("https", "http"):
        url = f"{scheme}://{domain}/{path_name}"
        response, _body, error, _elapsed_ms = await request_once(
            session, limiter, "HEAD", url, timeout_seconds
        )
        if error:
            # HEAD failed at transport level: remember why, try the next scheme.
            last_error = error
            continue

        assert response is not None
        last_status = response.status
        # 405/501 are treated as "HEAD not supported", so GET is still tried.
        should_get = response.status == 200 or response.status in {405, 501}
        if not should_get:
            # Any other HEAD status ends the probe right here; the remaining
            # scheme is not tried.
            return Probe(
                path=path_name,
                final_url=str(response.url),
                status=response.status,
                content_type=response.headers.get("Content-Type", ""),
                content_length_header=response.headers.get("Content-Length", ""),
                elapsed_ms=int((time.monotonic() - total_started) * 1000),
                error=last_error,
            )

        get_response, body, get_error, _get_elapsed_ms = await request_once(
            session, limiter, "GET", url, timeout_seconds
        )
        if get_error:
            last_error = get_error
            continue
        assert get_response is not None
        if get_response.status == 200:
            # NOTE(review): hashlib is imported here, at the point of use; it is
            # not among the imports visible at the top of the file.
            import hashlib

            digest = hashlib.sha256(body).hexdigest()
            is_valid, invalid_reason = validate_body(str(get_response.url), path_name, body)
            # The body is written even when validation rejected it; is_valid
            # and invalid_reason record the verdict.
            raw_path = raw_dir / f"{safe_name(domain)}.txt"
            raw_path.write_bytes(body)
            return Probe(
                path=path_name,
                final_url=str(get_response.url),
                # "GET" marks the case where HEAD itself did not answer 200.
                method_hit="GET" if response.status != 200 else "HEAD+GET",
                status=get_response.status,
                content_type=get_response.headers.get("Content-Type", ""),
                content_length_header=get_response.headers.get("Content-Length", ""),
                bytes_saved=len(body),
                sha256=digest,
                is_valid=is_valid,
                invalid_reason=invalid_reason,
                elapsed_ms=int((time.monotonic() - total_started) * 1000),
                error="",
            )

        # GET answered, but not with 200. On https this falls through to the
        # http attempt; on http (the last scheme) the GET result is reported.
        last_status = get_response.status
        if scheme == "http":
            return Probe(
                path=path_name,
                final_url=str(get_response.url),
                method_hit="GET_AFTER_HEAD",
                status=get_response.status,
                content_type=get_response.headers.get("Content-Type", ""),
                content_length_header=get_response.headers.get("Content-Length", ""),
                elapsed_ms=int((time.monotonic() - total_started) * 1000),
                error=last_error,
            )

    # Reached when the http attempt ended in a request error: report whatever
    # status and error were seen last.
    return Probe(
        path=path_name,
        status=last_status,
        elapsed_ms=int((time.monotonic() - total_started) * 1000),
        error=last_error,
    )


async def probe_domain(
    session: aiohttp.ClientSession,
    limiter: RateLimiter,
    rank: int,
    domain: str,
    dirs: dict[str, Path],
    timeout_seconds: float,
) -> DomainResult:
    """Probe both files on ``domain`` concurrently and merge the outcomes.

    Args:
        session: Client session used for all requests.
        limiter: Rate limiter shared by every request.
        rank: Rank of the domain in the Tranco list.
        domain: Host name to probe.
        dirs: Directory mapping; ``raw_llms`` and ``raw_full`` are read.
        timeout_seconds: Total timeout per individual request.

    Returns:
        A DomainResult whose ``llms_txt_*`` and ``llms_full_txt_*`` fields are
        copied from the two probes.
    """
    began = time.monotonic()
    llms_probe, full_probe = await asyncio.gather(
        probe_path(session, limiter, domain, "llms.txt", dirs["raw_llms"], timeout_seconds),
        probe_path(session, limiter, domain, "llms-full.txt", dirs["raw_full"], timeout_seconds),
    )

    def columns(prefix: str, probe: Probe) -> dict[str, object]:
        # Map one probe onto the DomainResult fields that share ``prefix``.
        return {
            f"{prefix}_status": probe.status,
            f"{prefix}_url": probe.final_url,
            f"{prefix}_method": probe.method_hit,
            f"{prefix}_bytes": probe.bytes_saved,
            f"{prefix}_valid": probe.is_valid,
            f"{prefix}_invalid_reason": probe.invalid_reason,
            f"{prefix}_content_type": probe.content_type,
            f"{prefix}_error": probe.error,
        }

    per_file = {
        **columns("llms_txt", llms_probe),
        **columns("llms_full_txt", full_probe),
    }
    return DomainResult(
        rank=rank,
        domain=domain,
        elapsed_ms=int((time.monotonic() - began) * 1000),
        **per_file,
    )


def append_results(path: Path, rows: Iterable[DomainResult]) -> None:
    """Append result rows to the CSV at ``path``, adding a header when needed.

    Args:
        path: Results CSV; created when missing.
        rows: Dataclass instances to serialise. Nothing is written (and no file
            is created) when the iterable is empty.
    """
    rows = list(rows)
    if not rows:
        return
    # A zero-byte file needs a header just like a missing one; otherwise the
    # first data row would later be read back as the header by csv.DictReader.
    needs_header = not path.exists() or path.stat().st_size == 0
    with path.open("a", newline="", encoding="utf-8") as f:
        writer = csv.DictWriter(f, fieldnames=list(asdict(rows[0])))
        if needs_header:
            writer.writeheader()
        writer.writerows(asdict(row) for row in rows)


async def crawl(args: argparse.Namespace) -> None:
    """Run the whole crawl: fetch the list, probe pending domains, summarise.

    Domains already present in the results CSV are skipped, which is what
    makes an interrupted run resumable. Results are appended after every batch
    and a progress line is printed.

    Args:
        args: Parsed command line; reads ``output_dir``, ``refresh_tranco``,
            ``limit``, ``concurrency``, ``rate``, ``batch_size`` and ``timeout``.
    """
    dirs = ensure_dirs(Path(args.output_dir))
    data_dir = dirs["data"]

    zip_path = data_dir / "tranco_top_1m_daily.csv.zip"
    if args.refresh_tranco or not zip_path.exists():
        zip_path = download_tranco(data_dir)
    ranked = extract_top_domains(zip_path, args.limit, data_dir)

    results_path = data_dir / f"llms_probe_results_top_{args.limit}.csv"
    done = load_done(results_path)
    pending = [entry for entry in ranked if entry[1] not in done]

    headers = {"User-Agent": USER_AGENT, "Accept": "text/markdown,text/plain,*/*;q=0.8"}
    # NOTE: ssl=False makes aiohttp skip TLS certificate validation.
    connector = aiohttp.TCPConnector(limit=args.concurrency, ttl_dns_cache=300, ssl=False)
    limiter = RateLimiter(args.rate)
    started = time.monotonic()
    completed = 0
    hits = 0

    async with aiohttp.ClientSession(headers=headers, connector=connector) as session:
        for offset in range(0, len(pending), args.batch_size):
            batch = pending[offset : offset + args.batch_size]
            results = await asyncio.gather(
                *(
                    probe_domain(session, limiter, rank, domain, dirs, args.timeout)
                    for rank, domain in batch
                )
            )
            # Persist each batch right away so a later crash loses at most one.
            append_results(results_path, results)

            batch_hits = sum(
                1 for row in results if row.llms_txt_status == 200 and row.llms_txt_valid
            )
            completed += len(results)
            hits += batch_hits
            total_done = len(done) + completed
            elapsed = time.monotonic() - started
            rate = completed / elapsed if elapsed else 0.0
            print(
                f"progress domains={total_done}/{args.limit} "
                f"batch_hits={batch_hits} "
                f"new_hits={hits} domains_per_sec={rate:.2f}",
                flush=True,
            )

    summarize(results_path, data_dir / f"analysis_top_{args.limit}.json")


def truthy_status(value: str) -> bool:
    """Return True when ``value``, rendered as a string, is exactly ``"200"``."""
    rendered = str(value)
    return rendered == "200"


def valid_hit(row: dict[str, str], prefix: str) -> bool:
    """Return True when the ``prefix`` file was served with 200 and validated.

    Args:
        row: One results-CSV row.
        prefix: Column prefix, e.g. ``"llms_txt"`` or ``"llms_full_txt"``.
            ``<prefix>_status`` must be present; a missing ``<prefix>_valid``
            counts as not valid.
    """
    if not truthy_status(row[f"{prefix}_status"]):
        return False
    valid_flag = row.get(f"{prefix}_valid", "0")
    return str(valid_flag) == "1"


def bucket_for_rank(rank: int) -> str:
    """Return the label of the rank bucket that ``rank`` falls into.

    Ranks up to 100 map to ``"Top 100"``, ranks up to 1000 to
    ``"Top 101-1,000"``, and every larger rank to ``"Top 1,001-10,000"``.

    NOTE(review): ranks above 10,000 also receive the last label; confirm this
    is intended when the crawl is run with a limit larger than 10,000.
    """
    for upper_bound, label in ((100, "Top 100"), (1000, "Top 101-1,000")):
        if rank <= upper_bound:
            return label
    return "Top 1,001-10,000"


def summarize(results_path: Path, out_json: Path) -> None:
    """Aggregate the results CSV into a JSON summary, written and printed.

    The summary holds overall and per-rank-bucket adoption counts, size,
    section-count and link-count distributions of the valid ``llms.txt``
    bodies, and the first 20 adopters in CSV order.

    Args:
        results_path: Results CSV. Saved bodies are looked up relative to it,
            under ``../raw_llms_txt/llms_txt``.
        out_json: Destination of the JSON summary.
    """
    with results_path.open(newline="", encoding="utf-8") as handle:
        rows = list(csv.DictReader(handle))
    total = len(rows)

    def pct(part: int, whole: int) -> int | float:
        # Percentage rounded to four decimals; plain 0 for an empty group.
        return round(part / whole * 100, 4) if whole else 0

    def has_llms(r: dict[str, str]) -> bool:
        return valid_hit(r, "llms_txt")

    def has_full(r: dict[str, str]) -> bool:
        return valid_hit(r, "llms_full_txt")

    def has_both(r: dict[str, str]) -> bool:
        return has_llms(r) and has_full(r)

    llms_http_200 = [r for r in rows if truthy_status(r["llms_txt_status"])]
    full_http_200 = [r for r in rows if truthy_status(r["llms_full_txt_status"])]
    llms_hits = [r for r in rows if has_llms(r)]
    full_hits = [r for r in rows if has_full(r)]
    both_hits = [r for r in rows if has_both(r)]

    buckets: dict[str, dict[str, int | float]] = {}
    for label, members in group_by_bucket(rows).items():
        member_count = len(members)
        llms_count = sum(map(has_llms, members))
        full_count = sum(map(has_full, members))
        both_count = sum(map(has_both, members))
        buckets[label] = {
            "domains": member_count,
            "llms_txt_hits": llms_count,
            "llms_txt_adoption_pct": pct(llms_count, member_count),
            "llms_full_txt_hits": full_count,
            "both_files_hits": both_count,
        }

    size_buckets: Counter[str] = Counter()
    section_buckets: Counter[str] = Counter()
    link_buckets: Counter[str] = Counter()
    raw_dir = results_path.parent.parent / "raw_llms_txt" / "llms_txt"
    for hit in llms_hits:
        size = int(hit["llms_txt_bytes"] or 0)
        if size >= 5000:
            size_label = "substantial_5KB+"
        elif size >= 500:
            size_label = "index_500B-5KB"
        elif size == 0:
            size_label = "empty"
        else:
            size_label = "thin_<500B"
        size_buckets[size_label] += 1

        # The size is counted even when the saved body is gone; only the
        # section and link statistics need the file itself.
        body_path = raw_dir / f"{safe_name(hit['domain'])}.txt"
        try:
            text = body_path.read_text(encoding="utf-8", errors="replace")
        except FileNotFoundError:
            continue

        sections = len(re.findall(r"(?m)^#{1,3}\s+", text))
        links = len(re.findall(r"\[[^\]]+\]\([^\)]+\)|https?://", text))

        if sections > 5:
            section_label = "6+_sections"
        elif sections > 1:
            section_label = "2-5_sections"
        else:
            section_label = "0-1_sections"
        section_buckets[section_label] += 1

        if links > 10:
            link_label = "11+_links"
        elif links > 0:
            link_label = "1-10_links"
        else:
            link_label = "0_links"
        link_buckets[link_label] += 1

    top_llms = [
        {
            "rank": int(r["rank"]),
            "domain": r["domain"],
            "bytes": int(r["llms_txt_bytes"] or 0),
            "url": r["llms_txt_url"],
        }
        for r in llms_hits[:20]
    ]
    top_dual = [
        {
            "rank": int(r["rank"]),
            "domain": r["domain"],
            "llms_txt_bytes": int(r["llms_txt_bytes"] or 0),
            "llms_full_txt_bytes": int(r["llms_full_txt_bytes"] or 0),
        }
        for r in both_hits[:20]
    ]

    summary = {
        "generated_at_utc": datetime.now(timezone.utc).isoformat(),
        "results_csv": str(results_path),
        "sample_size": total,
        "llms_txt_http_200": len(llms_http_200),
        "llms_full_txt_http_200": len(full_http_200),
        "llms_txt_hits": len(llms_hits),
        "llms_txt_adoption_pct": pct(len(llms_hits), total),
        "llms_full_txt_hits": len(full_hits),
        "llms_full_txt_adoption_pct": pct(len(full_hits), total),
        "both_files_hits": len(both_hits),
        "both_files_adoption_pct": pct(len(both_hits), total),
        "rank_buckets": buckets,
        "content_size_buckets": dict(size_buckets),
        "markdown_section_buckets": dict(section_buckets),
        "markdown_link_buckets": dict(link_buckets),
        "top_llms_txt_adopters": top_llms,
        "top_dual_file_adopters": top_dual,
    }
    rendered = json.dumps(summary, indent=2)
    out_json.write_text(rendered, encoding="utf-8")
    print(rendered, flush=True)


def group_by_bucket(rows: list[dict[str, str]]) -> dict[str, list[dict[str, str]]]:
    """Split result rows into rank buckets.

    Args:
        rows: Results-CSV rows; each needs a ``rank`` value parseable as int.

    Returns:
        Mapping from bucket label to its rows. All three labels are always
        present, in ascending rank order, even when a bucket is empty.
    """
    grouped: dict[str, list[dict[str, str]]] = {
        label: [] for label in ("Top 100", "Top 101-1,000", "Top 1,001-10,000")
    }
    for record in rows:
        label = bucket_for_rank(int(record["rank"]))
        grouped[label].append(record)
    return grouped


def parse_args() -> argparse.Namespace:
    """Parse and validate the command line.

    Returns:
        Namespace with ``limit``, ``output_dir``, ``rate``, ``concurrency``,
        ``batch_size``, ``timeout`` and ``refresh_tranco``.

    Raises:
        SystemExit: Via argparse, for unparsable arguments or when ``--limit``,
            ``--rate`` or ``--batch-size`` is not positive.
    """
    parser = argparse.ArgumentParser(
        description="Probe Tranco domains for /llms.txt and /llms-full.txt."
    )
    parser.add_argument(
        "--limit", type=int, default=10000,
        help="number of top-ranked domains to probe (default: %(default)s)",
    )
    parser.add_argument(
        "--output-dir", default="llms_txt_adoption_research",
        help="directory that receives data, raw bodies and charts (default: %(default)s)",
    )
    parser.add_argument(
        "--rate", type=float, default=10.0,
        help="global request rate per second across HEAD and GET (default: %(default)s)",
    )
    parser.add_argument(
        "--concurrency", type=int, default=20,
        help="connection limit of the HTTP connector (default: %(default)s)",
    )
    parser.add_argument(
        "--batch-size", type=int, default=100,
        help="domains probed between two writes of the results CSV (default: %(default)s)",
    )
    parser.add_argument(
        "--timeout", type=float, default=8.0,
        help="total timeout per request in seconds (default: %(default)s)",
    )
    parser.add_argument(
        "--refresh-tranco", action="store_true",
        help="download the Tranco list again even when a copy exists",
    )
    args = parser.parse_args()

    # Reject values that would only fail later and less clearly: rate 0 divides
    # by zero in the rate limiter, batch size 0 is an invalid range() step.
    if args.limit < 1:
        parser.error("--limit must be at least 1")
    if not args.rate > 0:
        parser.error("--rate must be greater than 0")
    if args.batch_size < 1:
        parser.error("--batch-size must be at least 1")
    return args


# Script entry point: parse the command line, then run the crawl coroutine to
# completion on a fresh event loop.
if __name__ == "__main__":
    asyncio.run(crawl(parse_args()))
