#!/usr/bin/env python3
"""
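XSA Ghost CVE Audit

Crawls every XSA advisory and extracts the CVEs it lists and the credited
reporter (stored as "signer"). It then checks each CVE against the CIRCL
vulnerability-lookup API -- queried over SSH on serena, which holds the
VulnMCP API key -- to see whether it has been published.

Lookups are spaced 4 seconds apart (~15 per minute) to stay under the API's
limit of 20 lookups per minute.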
"""

import json
import re
import sys
import time
import urllib.request

# Advisory index page; scraped for "XSA-<n>" references.
XSA_LIST_URL = "https://xenbits.xen.org/xsa/"
# Plain-text advisory; format with the advisory number.
XSA_TXT_URL = "https://xenbits.xen.org/xsa/advisory-{}.txt"
# NOTE(review): not referenced by the visible code -- CVE lookups go through
# check_cve's SSH + curl call instead. Confirm nothing else uses it before removing.
VULNMCP_URL = "http://10.20.20.245:9200"

def get_xsa_numbers(html=None):
    """Return every XSA number referenced on the advisory index page.

    Args:
        html: Optional pre-fetched index page HTML. When omitted, the page at
            ``XSA_LIST_URL`` is downloaded with a 30 second timeout.

    Returns:
        Unique advisory numbers (ints) found as ``XSA-<n>`` in the page,
        sorted highest first.
    """
    if html is None:
        req = urllib.request.Request(XSA_LIST_URL)
        with urllib.request.urlopen(req, timeout=30) as resp:
            # Lenient decoding (as used for individual advisories): a stray
            # non-UTF-8 byte must not abort the whole crawl.
            html = resp.read().decode(errors='replace')
    return sorted({int(m) for m in re.findall(r'XSA-(\d+)', html)}, reverse=True)

def get_xsa_details(num, text=None):
    """Fetch one XSA advisory and extract its CVEs, title and credited reporter.

    Args:
        num: Advisory number, substituted into ``XSA_TXT_URL``.
        text: Optional pre-fetched advisory text. When given, no network
            request is made.

    Returns:
        A dict with keys:
          - ``cves``: sorted, de-duplicated IDs matching ``CVE-\\d{4}-\\d+``
            found anywhere in the text;
          - ``signer``: text after the first "reported by", "discovered by"
            or "found by" phrase (patterns tried in that order, up to the
            next comma or newline), else ``"unknown"``. Despite the name
            this is the credited reporter, not the PGP signer;
          - ``title``: first stripped line that passes the boilerplate
            filter below, else ``""``;
          - ``has_embargo_note``: text contains "LACK OF EMBARGO" or
            "NOT CREDITED";
          - ``has_patches``: loose heuristic, see below.
        If the download fails, the dict additionally carries an ``error``
        message and every other field holds its empty/default value.
    """
    if text is None:
        url = XSA_TXT_URL.format(num)
        try:
            req = urllib.request.Request(url)
            with urllib.request.urlopen(req, timeout=15) as resp:
                text = resp.read().decode(errors='replace')
        except Exception as e:
            # Best effort: one unreachable advisory must not abort the crawl.
            # Return the same keys as the success path so callers can read
            # any field without special-casing errors.
            return {
                "error": str(e),
                "cves": [],
                "signer": "unknown",
                "title": "",
                "has_embargo_note": False,
                "has_patches": False,
            }

    cves = sorted(set(re.findall(r'CVE-\d{4}-\d+', text)))

    # Title: first line that does not start with a PGP-armour / header
    # boilerplate prefix and is between 11 and 199 characters long.
    skip_prefixes = ('-', 'Xen', 'Hash', 'CVE', 'Advisory', 'Public',
                     'Updated', 'Version', 'ISSUE', 'UPDATES')
    title = ""
    for line in text.split('\n'):
        line = line.strip()
        if line and not line.startswith(skip_prefixes):
            if 10 < len(line) < 200 and 'BEGIN PGP' not in line:
                title = line
                break

    # The PGP signature is from the Xen Project Security team on every
    # advisory, so "signer" records whoever the advisory text credits.
    signer = "unknown"
    for pattern in (r'(?i)reported[- ]by[:\s]+([^\n,]+)',
                    r'(?i)discovered[- ]by[:\s]+([^\n,]+)',
                    r'(?i)found[- ]by[:\s]+([^\n,]+)'):
        m = re.search(pattern, text)
        if m:
            signer = m.group(1).strip()
            break

    # Editorial notes about embargo handling or missing credit.
    has_embargo_note = "LACK OF EMBARGO" in text or "NOT CREDITED" in text

    # Loose heuristic: a case-insensitive "patch" (which already covers
    # "PATCHES") or a RESOLUTION section counts as "patches exist".
    has_patches = "RESOLUTION" in text or "patch" in text.lower()

    return {
        "cves": cves,
        "signer": signer,
        "title": title,
        "has_embargo_note": has_embargo_note,
        "has_patches": has_patches,
    }

def check_cve(cve_id, max_retries=3):
    """Check if a CVE is published via the CIRCL API, queried through serena.

    The lookup runs ``curl`` on host ``serena`` over SSH because that host
    holds the API key (sourced from ``~/.vulnmcp.env``).

    Args:
        cve_id: CVE identifier of the form ``CVE-YYYY-N``. Anything else is
            rejected before any remote command runs.
        max_retries: Maximum number of attempts while the API keeps
            reporting rate limiting.

    Returns:
        ``"PUBLISHED"`` if the response carries a record for the ID,
        ``"GHOST"`` if it is a JSON object without one, ``"RATELIMITED"`` if
        every attempt was rate limited, ``"ERROR"`` on an SSH timeout or a
        non-JSON reply, and ``"ERROR:<detail>"`` for anything else.
    """
    import subprocess

    # cve_id is spliced into a command string that the *remote* shell parses;
    # inside double quotes it would still expand $(...) and backticks, so only
    # the strict ASCII CVE shape is let through.
    if not isinstance(cve_id, str) or not re.fullmatch(r'CVE-\d{4}-\d+', cve_id, re.ASCII):
        return "ERROR:invalid CVE id"

    remote_cmd = (
        'source ~/.vulnmcp.env && curl -s -H "X-API-KEY: $VULNMCP_API_KEY" '
        f'"https://vulnerability.circl.lu/api/vulnerability/{cve_id}"'
    )

    for attempt in range(max_retries):
        try:
            result = subprocess.run(
                ["ssh", "serena", remote_cmd],
                capture_output=True, text=True, timeout=15,
            )
            body = result.stdout
            if "20 per 1 minute" in body or "Too Many" in body:
                # Back off only if another attempt follows; sleeping after
                # the final attempt would merely delay "RATELIMITED".
                if attempt + 1 < max_retries:
                    wait = 65 * (attempt + 1)
                    print(f"    rate limited, waiting {wait}s...", file=sys.stderr, flush=True)
                    time.sleep(wait)
                continue
            data = json.loads(body)
            if not isinstance(data, dict):
                # NOTE(review): how the API answers for unknown IDs is not
                # visible here; a non-object body is reported, not guessed
                # to be a GHOST.
                return "ERROR:unexpected response type"
            if (data.get("cveMetadata") or {}).get("cveId") == cve_id:
                return "PUBLISHED"
            if data.get("id") and data["id"] != "Unknown":
                return "PUBLISHED"
            return "GHOST"
        except (json.JSONDecodeError, subprocess.TimeoutExpired):
            return "ERROR"
        except Exception as e:
            return f"ERROR:{str(e)[:40]}"
    return "RATELIMITED"

def main():
    """Audit every XSA's CVEs and print a JSON report to stdout.

    Progress and a final summary go to stderr. Each checked CVE is also
    appended, one JSON object per line, to
    ``XSA/audit-results-incremental.jsonl`` so partial results survive an
    interrupted run (the file is opened in append mode, so it accumulates
    across runs). Advisories that could not be fetched are listed under
    ``failed_xsas`` in the report instead of being silently dropped.
    """
    import os
    from collections import Counter

    incremental_path = "XSA/audit-results-incremental.jsonl"
    # 4 s between lookups is ~15/min, under the API's 20-per-minute limit.
    lookup_interval_s = 4

    print("XSA Ghost CVE Audit", file=sys.stderr)
    print("=" * 60, file=sys.stderr)

    xsa_numbers = get_xsa_numbers()
    print(f"Found {len(xsa_numbers)} XSA advisories", file=sys.stderr)

    # Create the output directory up front; otherwise open() raises
    # FileNotFoundError right after the first (slow) CVE lookup.
    os.makedirs(os.path.dirname(incremental_path), exist_ok=True)

    results = []
    failed_xsas = []
    cve_count = 0

    with open(incremental_path, "a", encoding="utf-8") as inc:
        for i, num in enumerate(xsa_numbers, start=1):
            print(f"[{i}/{len(xsa_numbers)}] XSA-{num}...", file=sys.stderr, end=" ", flush=True)

            details = get_xsa_details(num)
            if "error" in details:
                print(f"ERROR: {details['error']}", file=sys.stderr)
                failed_xsas.append({"xsa": f"XSA-{num}", "error": details["error"]})
                continue

            cves = details["cves"]
            print(f"{len(cves)} CVEs, signer: {details['signer'][:30]}", file=sys.stderr, flush=True)

            for cve in cves:
                if cve_count > 0:
                    time.sleep(lookup_interval_s)

                status = check_cve(cve)
                cve_count += 1

                result = {
                    "xsa": f"XSA-{num}",
                    "cve": cve,
                    "status": status,
                    "signer": details["signer"],
                    "title": details["title"],
                    "has_embargo_note": details.get("has_embargo_note", False),
                    "has_patches": details.get("has_patches", False),
                }
                results.append(result)

                marker = "***GHOST***" if status == "GHOST" else status
                print(f"  {cve}: {marker}", file=sys.stderr, flush=True)

                # Flush per CVE so the line reaches disk even if the run dies later.
                inc.write(json.dumps(result) + "\n")
                inc.flush()

    # Count statuses once and reuse for both the report and the summary.
    status_counts = Counter(r["status"] for r in results)

    output = {
        "audit_date": time.strftime("%Y-%m-%d %H:%M UTC", time.gmtime()),
        "total_xsas": len(xsa_numbers),
        "total_cves": len(results),
        "ghosts": status_counts["GHOST"],
        "published": status_counts["PUBLISHED"],
        "failed_xsas": failed_xsas,
        "results": results,
    }

    print(json.dumps(output, indent=2))

    print(f"\n{'=' * 60}", file=sys.stderr)
    print(f"Total CVEs: {len(results)}", file=sys.stderr)
    print(f"Published:  {status_counts['PUBLISHED']}", file=sys.stderr)
    print(f"Ghost:      {status_counts['GHOST']}", file=sys.stderr)
    if failed_xsas:
        print(f"Failed XSAs: {len(failed_xsas)}", file=sys.stderr)

# Run the audit only when executed as a script, not when imported.
if __name__ == "__main__":
    main()
