import os
import sys
import wave
import subprocess
import tempfile
import numpy as np

# Required: point to espeak-ng data bundled with Piper
# (setdefault → an already-exported ESPEAK_DATA_PATH wins over this value)
os.environ.setdefault("ESPEAK_DATA_PATH", r"espeak-ng")

# ── Config ────────────────────────────────────────────────────────────────────
# Indian female voices from the L2-ARCTIC corpus (Hindi L1 speakers):
#   SVBI → speaker_id=2  (Indian female)
#   TNI  → speaker_id=9  (Indian female)
MODEL_PATH  = "en_US-l2arctic-medium.onnx"   # Piper ONNX voice model (multi-speaker)
SPEAKER_ID  = 9        # 2 = SVBI (Indian female)  |  9 = TNI (Indian female)
OUTPUT_FILE = "artict_tts_audio.wav"   # NOTE(review): "artict" looks like a typo for "arctic" — confirm before renaming (downstream consumers may expect this name)
PIPER_EXE   = r"piper"                 # piper binary on PATH; used only by the subprocess fallback
TARGET_RATE = 44100   # upsample output to 44.1 kHz

# Text to synthesize.  "U S F D A" is spaced out — presumably so the TTS
# spells the acronym letter by letter; verify against the voice's behavior.
text = (
    "The presentation outlines a global corrective action plan initiated by "
    "Sun Pharmaceutical in response to a U S F D A observation at its Dadra "
    "site, where multiple unaddressed deviations related to missing GMP "
    "documents were found."
)

# ── Synthesis ─────────────────────────────────────────────────────────────────
def synthesize_with_python_api(text, model_path, syn_params, speaker_id=None):
    """Synthesize `text` with the piper-tts Python package (preferred path).

    Parameters
    ----------
    text : str            text to speak
    model_path : str      path to the Piper ONNX voice model
    syn_params : dict     keys: length_scale, noise_scale, noise_w_scale
    speaker_id : int | None   speaker index for multi-speaker models

    Returns
    -------
    (audio, src_rate) : float32 ndarray scaled to [-1, 1), and the model's
    actual sample rate as reported by the synthesized chunks.

    Raises
    ------
    ImportError   when piper-tts is not installed (caller falls back).
    RuntimeError  when synthesis yields no audio chunks (previously this
                  surfaced as a cryptic np.concatenate ValueError).
    """
    from piper.voice import PiperVoice, SynthesisConfig
    voice = PiperVoice.load(model_path)
    cfg = SynthesisConfig(
        speaker_id=speaker_id,
        length_scale=syn_params["length_scale"],
        noise_scale=syn_params["noise_scale"],
        noise_w_scale=syn_params["noise_w_scale"],
    )
    chunks = []
    src_rate = 22050  # Piper default; overwritten by the real chunk rate below
    for chunk in voice.synthesize(text, syn_config=cfg):
        pcm = np.frombuffer(chunk.audio_int16_bytes, dtype=np.int16).astype(np.float32)
        chunks.append(pcm)
        src_rate = chunk.sample_rate
    if not chunks:
        raise RuntimeError("piper produced no audio for the given text")
    audio = np.concatenate(chunks) / 32768.0
    return audio, src_rate


def synthesize_with_subprocess(text, model_path, piper_exe, syn_params, speaker_id=None):
    """Synthesize `text` via the standalone piper binary (fallback path).

    Pipes the text on stdin, lets piper write a temporary WAV, then reads
    it back as float32 in [-1, 1).  The temp file is now deleted in a
    `finally` block — the original leaked it whenever piper exited
    non-zero (the unlink was unreachable after the raise).

    Returns
    -------
    (audio, src_rate) : float32 ndarray and the WAV's frame rate.

    Raises
    ------
    RuntimeError  when the piper process exits with a non-zero status.
    """
    tmp = tempfile.NamedTemporaryFile(suffix=".wav", delete=False)
    tmp.close()  # closed so piper (another process) can open/write it on Windows too
    cmd = [
        piper_exe,
        "--model",        model_path,
        "--length_scale", str(syn_params["length_scale"]),
        "--noise_scale",  str(syn_params["noise_scale"]),
        "--noise_w",      str(syn_params["noise_w_scale"]),
        "--output_file",  tmp.name,
    ]
    if speaker_id is not None:
        cmd += ["--speaker", str(speaker_id)]
    try:
        proc = subprocess.run(cmd, input=text, capture_output=True, text=True)
        if proc.returncode != 0:
            raise RuntimeError(f"piper failed: {proc.stderr}")
        with wave.open(tmp.name, "rb") as wf:
            src_rate = wf.getframerate()
            raw = wf.readframes(wf.getnframes())
    finally:
        os.unlink(tmp.name)
    audio = np.frombuffer(raw, dtype=np.int16).astype(np.float32) / 32768.0
    return audio, src_rate


# Lower noise_scale → crisper consonants; length_scale 1.0 → natural pace
syn_params = dict(length_scale=1.0, noise_scale=0.25, noise_w_scale=0.4)

# Prefer the in-process piper-tts API; fall back to the standalone piper
# binary only when the Python package is missing from this interpreter.
# NOTE(review): only ImportError triggers the fallback — any other failure
# in the Python API (bad model path, etc.) propagates and aborts the script;
# confirm that is intended.
try:
    audio, src_rate = synthesize_with_python_api(text, MODEL_PATH, syn_params, speaker_id=SPEAKER_ID)
    print("Synthesis: piper-tts Python API")
except ImportError:
    audio, src_rate = synthesize_with_subprocess(text, MODEL_PATH, PIPER_EXE, syn_params, speaker_id=SPEAKER_ID)
    print("Synthesis: piper subprocess (piper-tts not found in this Python)")

# ══════════════════════════════════════════════════════════════════════════════
#  BROADCAST-QUALITY AUDIO PROCESSING CHAIN
#  Stage 1 : Upsample          → 22 050 Hz → 44 100 Hz
#  Stage 2 : Pre-processing    → DC remove, high-pass 80 Hz
#  Stage 3 : Noise gate        → suppress inter-word digital artifacts
#  Stage 4 : 5-band EQ         → shape vocal frequency response
#  Stage 5 : Harmonic exciter  → add warmth & natural upper harmonics
#  Stage 6 : De-esser          → tame harsh sibilants (s, sh, ch)
#  Stage 7 : Multi-band comp.  → 3-band musical compression
#  Stage 8 : Bus compression   → final "glue" and punch
#  Stage 9 : LUFS normalise    → −16 LUFS streaming standard
#  Stage 10: True-peak limiter → −0.3 dBTP brick-wall ceiling
# ══════════════════════════════════════════════════════════════════════════════

def _smooth_env(signal_abs, rate, ms):
    """Compute a smoothed RMS envelope over `ms` milliseconds."""
    win = max(1, int(rate * ms / 1000))
    kernel = np.ones(win) / win
    return np.sqrt(np.convolve(signal_abs ** 2, kernel, mode="same"))


def _soft_compress(audio, threshold, ratio, makeup):
    """Sample-by-sample soft-knee downward compressor."""
    abs_a = np.abs(audio)
    gain  = np.ones_like(audio)
    above = abs_a > threshold
    if above.any():
        excess       = abs_a[above] - threshold
        gain[above]  = (threshold + excess / ratio) / abs_a[above]
    return audio * gain * makeup


def noise_gate(audio, rate, threshold=0.018, attack_ms=5, release_ms=80,
               floor_db=-40):
    """
    Attenuate the TTS digital noise floor between words.

    A 10 ms smoothed RMS envelope drives a one-pole attack/release
    smoother, so the gain ramps open and shut without audible clicks.
    Samples below `threshold` are pushed toward `floor_db`, never to zero.
    """
    envelope = _smooth_env(audio, rate, ms=10)
    min_gain = 10 ** (floor_db / 20)

    # One-pole smoothing coefficients derived from the time constants.
    attack_coef  = np.exp(-1.0 / (rate * attack_ms  / 1000))
    release_coef = np.exp(-1.0 / (rate * release_ms / 1000))

    gains = np.empty(len(audio))
    level = 0.0
    for idx, amp in enumerate(envelope):
        goal = 1.0 if amp > threshold else min_gain
        coef = attack_coef if goal > level else release_coef
        level = coef * level + (1 - coef) * goal
        gains[idx] = level

    return audio * gains


def parametric_eq(audio, rate):
    """
    Parametric EQ tuned for broadcast-quality speech.

      Band 1 – High-pass at   80 Hz        (remove rumble / thin the low body)
      Band 2 – Mud cut    300–600 Hz       (~−1.2 dB, remove boxiness)
      Band 3 – Presence   800 Hz – 4 kHz   (~+2 dB, primary intelligibility band)
      Band 4 – Air        7–13 kHz         (~+1.4 dB, openness and breath)
      Final  – Anti-alias low-pass at 15 kHz

    The original "low-body trim" stage has been removed: its filter output
    was multiplied by 0.0, so it spent three sosfilt passes producing no
    audible change.  Band 1's 6th-order high-pass provides the low rolloff.

    NOTE: the 13 kHz band edge requires `rate` comfortably above 26 kHz;
    the pipeline always calls this at 44.1 kHz.
    """
    from scipy import signal as sg
    nyq = rate / 2.0

    # Band 1: sub-bass high-pass
    sos = sg.butter(6, 80.0 / nyq, btype="high", output="sos")
    audio = sg.sosfilt(sos, audio)

    # Band 2: mud cut 300–600 Hz (subtract 12% of the band ≈ −1.2 dB at center)
    sos_mud = sg.butter(2, [300.0 / nyq, 600.0 / nyq], btype="bandpass", output="sos")
    audio  -= sg.sosfilt(sos_mud, audio) * 0.12

    # Band 3: presence / intelligibility boost 800 Hz – 4 kHz (≈ +2 dB)
    sos_pres = sg.butter(3, [800.0 / nyq, 4000.0 / nyq], btype="bandpass", output="sos")
    audio   += sg.sosfilt(sos_pres, audio) * 0.28

    # Band 4: air / brilliance 7–13 kHz (≈ +1.4 dB)
    sos_air = sg.butter(2, [7000.0 / nyq, 13000.0 / nyq], btype="bandpass", output="sos")
    audio  += sg.sosfilt(sos_air, audio) * 0.15

    # Anti-alias low-pass at 15 kHz
    sos_lp = sg.butter(4, 15000.0 / nyq, btype="low", output="sos")
    audio   = sg.sosfilt(sos_lp, audio)

    return audio


def harmonic_exciter(audio, rate, drive=2.2, mix=0.12):
    """
    Gentle harmonic exciter.

    Isolates the 3–7 kHz band, soft-saturates it with tanh (which generates
    2nd/3rd harmonics), keeps only the newly created content above 5 kHz,
    and blends `mix` of it back into the dry signal for warmth and presence
    without raising the noise floor.
    """
    from scipy import signal as sg
    half_rate = rate / 2.0

    # Isolate the band that feeds the saturator.
    bp = sg.butter(2, [3000.0 / half_rate, 7000.0 / half_rate],
                   btype="bandpass", output="sos")
    excite = sg.sosfilt(bp, audio)

    # tanh drive; subtracting the dry band leaves only the added harmonics.
    generated = np.tanh(excite * drive) - excite

    # Keep just the genuinely new content above 5 kHz.
    hp = sg.butter(2, 5000.0 / half_rate, btype="high", output="sos")
    generated = sg.sosfilt(hp, generated)

    return audio + generated * mix


def de_esser(audio, rate, lo=6000, hi=10000, threshold=0.10, ratio=3.5):
    """
    Frequency-selective compressor for sibilants ('s', 'sh', 'ch').

    Splits out the `lo`–`hi` Hz band, compresses it above `threshold` by
    `ratio` using a 3 ms envelope, smooths the gain over 5 ms to avoid
    zipper noise, and recombines with the untouched remainder.
    """
    from scipy import signal as sg
    half = rate / 2.0

    bp  = sg.butter(2, [lo / half, hi / half], btype="bandpass", output="sos")
    sib = sg.sosfilt(bp, audio)

    # Fast (3 ms) envelope of the sibilant band drives the gain reduction.
    level     = _smooth_env(sib, rate, ms=3)
    reduction = np.ones_like(level)
    loud      = level > threshold
    if loud.any():
        reduction[loud] = (threshold + (level[loud] - threshold) / ratio) / level[loud]

    # Smooth the gain curve over 5 ms, and never duck below 0.15.
    smooth_len = max(1, int(rate * 0.005))
    reduction  = np.convolve(reduction, np.ones(smooth_len) / smooth_len, mode="same")
    reduction  = np.clip(reduction, 0.15, 1.0)

    return audio - sib + sib * reduction


def multiband_compress(audio, rate):
    """
    Three-band compressor with crossovers at 600 Hz and 5 kHz.

    The mid band is derived by subtraction (audio − low − high) so the
    three bands always sum back to the original when gains are unity.
    Each band gets its own threshold/ratio/makeup via `_soft_compress`.
    """
    from scipy import signal as sg
    half = rate / 2.0

    low  = sg.sosfilt(sg.butter(4,  600.0 / half, btype="low",  output="sos"), audio)
    high = sg.sosfilt(sg.butter(4, 5000.0 / half, btype="high", output="sos"), audio)
    mid  = audio - low - high

    # Low: gentle 3:1 — tighten bass without killing warmth.
    low  = _soft_compress(low,  threshold=0.22, ratio=3.0, makeup=1.10)
    # Mid: 4:1 — even out vocal dynamics (the most important band).
    mid  = _soft_compress(mid,  threshold=0.18, ratio=4.0, makeup=1.22)
    # High: soft 2:1 — control transient sibilant energy.
    high = _soft_compress(high, threshold=0.20, ratio=2.0, makeup=1.08)

    return low + mid + high


def lufs_normalize(audio, rate, target_lufs=-16.0):
    """
    Simplified EBU R128 LUFS loudness normalisation.

    Applies an approximate K-weighting (high-shelf pre-filter + RLB
    high-pass), measures mean-square energy over 400 ms blocks at a 100 ms
    hop, applies a −10 LU relative gate, and scales the whole signal by a
    single gain toward `target_lufs`.  The gain is capped at 5× (~+14 dB)
    so near-silent input is never blown up.

    Returns `audio` unchanged when it is too short (< one 400 ms block)
    or digitally silent.
    """
    from scipy import signal as sg
    nyq = rate / 2.0

    # K-weighting stage 1: high-shelf pre-filter, approximated by mixing in
    # 26% of a 1st-order high-pass (~+4 dB above 1.5 kHz).
    sos_hs  = sg.butter(1, 1500.0 / nyq, btype="high", output="sos")
    kw      = audio + sg.sosfilt(sos_hs, audio) * 0.26

    # K-weighting stage 2: RLB high-pass at ~38 Hz (cutoff clamped away
    # from zero so very low sample rates don't build a degenerate filter).
    sos_rlb = sg.butter(2, max(38.0 / nyq, 0.001), btype="high", output="sos")
    kw      = sg.sosfilt(sos_rlb, kw)

    # Mean-square energy per 400 ms block, 100 ms hop.  The walrus binding
    # evaluates each block mean once (previously computed twice per block).
    block = int(rate * 0.4)
    hop   = int(rate * 0.1)
    ms_list = [m for s in range(0, len(kw) - block, hop)
               if (m := np.mean(kw[s:s + block] ** 2)) > 0]

    if not ms_list:
        return audio

    ungated_mean = np.mean(ms_list)
    gate_thr     = ungated_mean * (10 ** (-10 / 10))          # −10 LU relative gate
    gated        = [m for m in ms_list if m > gate_thr] or ms_list
    integrated   = np.mean(gated)

    current_lufs = 10 * np.log10(integrated) - 0.691          # EBU K-weighting offset
    gain_db      = target_lufs - current_lufs
    gain         = min(10 ** (gain_db / 20.0), 5.0)           # cap at ~+14 dB
    return audio * gain


def broadcast_enhance(audio, src_rate, target_rate):
    """Full broadcast-quality processing chain.

    Runs stages 1–9 (resample, DC removal, 80 Hz high-pass, noise gate,
    EQ, harmonic exciter, de-esser, multi-band + bus compression, LUFS
    normalise) when scipy is importable; otherwise falls back to a
    numpy-only FFT upsample with DC removal.  Stage 10 (soft saturation
    plus a hard peak ceiling) always runs.

    Parameters:
        audio       — mono float signal, nominally within [-1, 1]
        src_rate    — sample rate of `audio` in Hz
        target_rate — desired output sample rate in Hz

    Returns the processed signal at `target_rate`, peak-scaled to 0.966
    (≈ −0.3 dBFS).
    """
    try:
        from scipy import signal as sg

        # ── Stage 1: Upsample ─────────────────────────────────────────────────
        audio = sg.resample_poly(audio, target_rate, src_rate)
        audio -= audio.mean()                                  # DC remove

        # ── Stage 2: Pre-processing ───────────────────────────────────────────
        sos_hp = sg.butter(6, 80.0 / (target_rate / 2), btype="high", output="sos")
        audio  = sg.sosfilt(sos_hp, audio)

        # ── Stage 3: Noise gate ───────────────────────────────────────────────
        audio = noise_gate(audio, target_rate,
                           threshold=0.018, attack_ms=5, release_ms=80)

        # ── Stage 4: 5-band parametric EQ ────────────────────────────────────
        audio = parametric_eq(audio, target_rate)

        # ── Stage 5: Harmonic exciter ─────────────────────────────────────────
        audio = harmonic_exciter(audio, target_rate, drive=2.2, mix=0.12)

        # ── Stage 6: De-esser ─────────────────────────────────────────────────
        audio = de_esser(audio, target_rate,
                         lo=6000, hi=10000, threshold=0.10, ratio=3.5)

        # ── Stage 7: Multi-band compression ──────────────────────────────────
        audio = multiband_compress(audio, target_rate)

        # ── Stage 8: Bus compression (final glue) ────────────────────────────
        audio = _soft_compress(audio, threshold=0.32, ratio=2.5, makeup=1.25)

        # ── Stage 9: LUFS normalisation (−16 LUFS streaming standard) ────────
        audio = lufs_normalize(audio, target_rate, target_lufs=-16.0)

    except ImportError:
        # scipy not available — numpy-only FFT upsample + basic normalise.
        # NOTE(review): integer-ratio upsample only — assumes target_rate is
        # an exact multiple of src_rate (true for 22050→44100); other ratios
        # would silently resample to the wrong rate. Confirm if inputs vary.
        print("scipy not found — using numpy-only processing")
        factor  = target_rate // src_rate
        n       = len(audio)
        fft     = np.fft.rfft(audio)
        fft_pad = np.zeros(n * factor // 2 + 1, dtype=complex)
        fft_pad[:len(fft)] = fft * factor      # zero-pad spectrum, rescale energy
        audio   = np.fft.irfft(fft_pad, n * factor)
        audio  -= audio.mean()

    # ── Stage 10: True-peak limiter (−0.3 dBTP brick-wall ceiling) ───────────
    # tanh soft-saturates peaks first so the linear rescale below rarely bites.
    audio = np.tanh(audio * 0.92) / 0.92          # soft-saturation pre-limiter
    peak  = np.max(np.abs(audio))
    if peak > 0:
        audio *= 0.966 / peak                      # final −0.3 dB ceiling
    return audio


audio = broadcast_enhance(audio, src_rate, TARGET_RATE)

# ── Save ──────────────────────────────────────────────────────────────────────
# Clip before the int16 conversion: samples outside [-1, 1] would otherwise
# wrap around when cast (the limiter should prevent this, but the clip is
# cheap insurance against full-scale clicks).
audio_int16 = (np.clip(audio, -1.0, 1.0) * 32767).astype(np.int16)
with wave.open(OUTPUT_FILE, "wb") as wf:
    wf.setnchannels(1)              # mono
    wf.setsampwidth(2)              # 16-bit PCM
    wf.setframerate(TARGET_RATE)
    wf.writeframes(audio_int16.tobytes())

duration = len(audio_int16) / TARGET_RATE
print(f"Done: {OUTPUT_FILE}  ({duration:.1f}s  @{TARGET_RATE} Hz)")
