"""
encrypt_and_run.py
==================
Workflow:
  1. Compile .py  → .pyc  (bytecode)
  2. Encrypt .pyc → .epyc (Fernet: AES-128-CBC + HMAC-SHA256)
  3. Save the .epyc files + secret.key into <encrypted_dir>
  4. Run any .epyc by decrypting in-memory and exec-ing the code object

Usage
-----
  python encrypt_and_run.py encrypt  <source_dir>    <encrypted_dir>
  python encrypt_and_run.py run      <encrypted_dir> <filename.epyc>
  python encrypt_and_run.py run-all  <encrypted_dir>

Requirements: cryptography  (pip install cryptography)
"""

import sys
import os
import py_compile
import struct
import importlib.util
import marshal
import pathlib
import time

# ── optional dependency guard ────────────────────────────────────────────────
try:
    from cryptography.fernet import Fernet
except ImportError:
    sys.exit("Missing dependency — run:  pip install cryptography")


# ═══════════════════════════════════════════════════════════════════════════════
#  KEY MANAGEMENT
# ═══════════════════════════════════════════════════════════════════════════════

def get_or_create_key(key_path: pathlib.Path) -> bytes:
    """Load existing key or generate a new one and save it.

    Args:
        key_path: Location of the Fernet key file.

    Returns:
        The key bytes exactly as stored on disk.

    A new key file is created atomically (``O_CREAT | O_EXCL``) with mode
    0o600 (owner read/write only). The permissions are applied at creation
    time, so the key is never readable by other users, not even briefly, as
    it would be with write-then-chmod.
    """
    try:
        return key_path.read_bytes()
    except FileNotFoundError:
        pass

    key = Fernet.generate_key()
    flags = os.O_WRONLY | os.O_CREAT | os.O_EXCL | getattr(os, "O_BINARY", 0)
    try:
        fd = os.open(key_path, flags, 0o600)    # owner read/write only
    except FileExistsError:
        # Another process created the key between our read and this open;
        # use theirs so both sides share one key.
        return key_path.read_bytes()
    with os.fdopen(fd, "wb") as fh:
        fh.write(key)
    print(f"[key]  New key saved → {key_path}")
    return key


# ═══════════════════════════════════════════════════════════════════════════════
#  ENCRYPT
# ═══════════════════════════════════════════════════════════════════════════════

def encrypt_folder(source_dir: str, encrypted_dir: str) -> None:
    """Compile and encrypt every ``.py`` file found under *source_dir*.

    Each file becomes an ``.epyc`` inside *encrypted_dir*, mirroring the
    source tree. The Fernet key lives at ``<encrypted_dir>/secret.key`` and
    is created on first use, even if no ``.py`` files are found.

    NOTE: the key sits next to the ciphertext, so anyone who can read
    *encrypted_dir* can also decrypt its contents.
    """
    source_root = pathlib.Path(source_dir).resolve()
    target_root = pathlib.Path(encrypted_dir).resolve()
    target_root.mkdir(parents=True, exist_ok=True)

    key_file = target_root / "secret.key"
    cipher = Fernet(get_or_create_key(key_file))

    sources = [*source_root.rglob("*.py")]
    if not sources:
        print(f"[warn] No .py files found in {source_root}")
        return

    rule = "─" * 55
    banner = (
        f"\n{rule}",
        f"  Encrypting {len(sources)} file(s)",
        f"  Source    : {source_root}",
        f"  Encrypted : {target_root}",
        rule,
    )
    print("\n".join(banner))

    for source_file in sources:
        _encrypt_one(source_file, source_root, target_root, cipher)

    print(f"\n✓ Done.  Key → {key_file.name}\n")


def _encrypt_one(py_path: pathlib.Path, src_root: pathlib.Path,
                 dest_root: pathlib.Path, fernet: Fernet) -> None:
    """Compile one .py to bytecode, encrypt it, write .epyc."""
    # ── compile to .pyc ──────────────────────────────────────────────────────
    cache_dir = py_path.parent / "__pycache__"
    cache_dir.mkdir(exist_ok=True)

    pyc_path = cache_dir / (
        py_path.stem + "."
        + f"cpython-{sys.version_info.major}{sys.version_info.minor}.pyc"
    )

    py_compile.compile(str(py_path), cfile=str(pyc_path),
                       optimize=2, doraise=True)

    # ── read raw bytecode (skip the 16-byte .pyc header) ────────────────────
    raw_pyc = pyc_path.read_bytes()
    bytecode = raw_pyc[16:]         # magic(4) + flags(4) + mtime(4) + size(4)

    # ── encrypt ──────────────────────────────────────────────────────────────
    encrypted_bytes = fernet.encrypt(bytecode)

    # ── mirror folder structure in dest ──────────────────────────────────────
    rel_path   = py_path.relative_to(src_root)
    epyc_path  = (dest_root / rel_path).with_suffix(".epyc")
    epyc_path.parent.mkdir(parents=True, exist_ok=True)
    epyc_path.write_bytes(encrypted_bytes)

    print(f"  [enc]  {py_path.name:30s}  →  {epyc_path.name}")


# ═══════════════════════════════════════════════════════════════════════════════
#  RUN
# ═══════════════════════════════════════════════════════════════════════════════

def run_epyc(encrypted_dir: str, epyc_filename: str) -> None:
    """Decrypt and execute a single ``.epyc`` file.

    *epyc_filename* is resolved relative to *encrypted_dir*, which must also
    hold ``secret.key``. If either file is missing, the process exits through
    ``_check_exists``.
    """
    root = pathlib.Path(encrypted_dir).resolve()
    target, key_file = root / epyc_filename, root / "secret.key"

    for candidate, label in ((target, "Encrypted file"), (key_file, "Key file")):
        _check_exists(candidate, label)

    _execute_epyc(target, Fernet(key_file.read_bytes()))


def run_all_epyc(encrypted_dir: str) -> None:
    """Decrypt and execute every ``.epyc`` under *encrypted_dir*, recursively.

    Files run one after another in this process, in sorted path order, so
    the sequence is the same on every platform. The process exits via
    ``_check_exists`` if ``secret.key`` is missing.

    NOTE: scripts share this process. A script that raises, or calls
    ``sys.exit``, stops the remaining runs.
    """
    dest     = pathlib.Path(encrypted_dir).resolve()
    key_path = dest / "secret.key"
    _check_exists(key_path, "Key file")

    fernet = Fernet(key_path.read_bytes())

    # rglob() yields entries in directory-listing order, which differs
    # between file systems; sort so run-all is reproducible.
    epyc_files = sorted(dest.rglob("*.epyc"))
    if not epyc_files:
        print(f"[warn] No .epyc files found in {dest}")
        return

    for epyc_path in epyc_files:
        print(f"\n{'═'*55}")
        # Relative path (not just .name) so same-named files in different
        # sub-folders can be told apart; identical to .name at top level.
        print(f"  Running: {epyc_path.relative_to(dest)}")
        print(f"{'═'*55}")
        _execute_epyc(epyc_path, fernet)


def _execute_epyc(epyc_path: pathlib.Path, fernet: Fernet) -> None:
    """Decrypt an .epyc in memory and execute its code object."""
    encrypted_bytes = epyc_path.read_bytes()

    # ── decrypt → raw bytecode ───────────────────────────────────────────────
    try:
        bytecode = fernet.decrypt(encrypted_bytes)
    except Exception:
        print(f"[error] Decryption failed for {epyc_path.name} — wrong key?")
        return

    # ── deserialise the code object ──────────────────────────────────────────
    code_obj = marshal.loads(bytecode)

    # ── set up a realistic __main__ namespace ────────────────────────────────
    ns = {
        "__name__":    "__main__",
        "__file__":    str(epyc_path),
        "__loader__":  None,
        "__package__": None,
        "__spec__":    None,
        "__builtins__": __builtins__,
    }

    start = time.perf_counter()
    exec(code_obj, ns)                      # ← runs the decrypted bytecode
    elapsed = (time.perf_counter() - start) * 1000
    print(f"\n  ✓ Finished in {elapsed:.1f} ms")


# ═══════════════════════════════════════════════════════════════════════════════
#  HELPERS
# ═══════════════════════════════════════════════════════════════════════════════

def _check_exists(path: pathlib.Path, label: str) -> None:
    if not path.exists():
        sys.exit(f"[error] {label} not found: {path}")


def _usage() -> None:
    print(__doc__)
    sys.exit(1)


# ═══════════════════════════════════════════════════════════════════════════════
#  ENTRY POINT
# ═══════════════════════════════════════════════════════════════════════════════

def main():
    """Parse ``sys.argv`` and dispatch to the matching sub-command.

    With no command at all, prints the full usage and exits. With a wrong
    argument count, prints the sub-command's usage line and exits with
    status 1. With an unknown command, prints an error and then the full
    usage.
    """
    argv = sys.argv
    if len(argv) < 2:
        _usage()

    command = argv[1].lower()

    # sub-command → (required len(sys.argv), usage line, handler)
    commands = {
        "encrypt": (
            4,
            "Usage: encrypt_and_run.py encrypt <source_dir> <encrypted_dir>",
            encrypt_folder,
        ),
        "run": (
            4,
            "Usage: encrypt_and_run.py run <encrypted_dir> <filename.epyc>",
            run_epyc,
        ),
        "run-all": (
            3,
            "Usage: encrypt_and_run.py run-all <encrypted_dir>",
            run_all_epyc,
        ),
    }

    spec = commands.get(command)
    if spec is None:
        print(f"[error] Unknown command: {command!r}")
        _usage()
        return

    expected_argc, usage_line, handler = spec
    if len(argv) != expected_argc:
        print(usage_line)
        sys.exit(1)

    handler(*argv[2:])


if __name__ == "__main__":
    main()