import os
import sys
import re
import wave
import subprocess
import tempfile
import numpy as np

# Piper's bundled espeak-ng data directory. The path is relative, so it is
# resolved against the current working directory; a value already present in
# the environment takes precedence.
os.environ.setdefault("ESPEAK_DATA_PATH", "espeak-ng")

# ── Config ────────────────────────────────────────────────────────────────────
# NOTE(review): the speaker notes below refer to L2-ARCTIC speakers
# (SVBI = 2, TNI = 9, both Indian female), but MODEL_PATH names the hi_IN
# "priyamvada" voice. Confirm that this model is multi-speaker and that these
# ids apply to it.
MODEL_PATH = "hi_IN-priyamvada-medium.onnx"
SPEAKER_ID = 9  # 2 = SVBI, 9 = TNI (L2-ARCTIC speaker ids)
# NOTE(review): "artict" looks like a typo for "arctic"; kept unchanged in
# case other tooling expects this exact file name.
OUTPUT_FILE = "artict_tts_audio_in.wav"
PIPER_EXE = "piper"  # piper executable; only used by the subprocess fallback
TARGET_RATE = 44_100  # final output sample rate in Hz (audio is resampled to it)

# Text to synthesise.
text = (
    "The presentation outlines a global corrective action plan initiated by "
    "Sun Pharmaceutical in response to a USFDA observation at its Dadra site, "
    "where multiple unaddressed deviations related to missing GMP documents "
    "were found."
)

# ── Text → sentence segments with pause durations ─────────────────────────────
# Segment boundary: whitespace that directly follows . , ; ! or ?
# (the punctuation itself stays attached to the preceding segment).
_SEGMENT_BREAK = re.compile(r'(?<=[.,;!?])\s+')


def split_sentences(text):
    """
    Break ``text`` into ``(segment, pause_ms)`` pairs for synthesis.

    The text is split after sentence and clause punctuation. The pause that
    follows each segment depends on the segment's final character:

        . ! ?      → 480 ms (full breath)
        , ;        → 250 ms (short breath)
        other      → 150 ms (with the current split pattern every non-final
                              segment ends in punctuation, so this is only a
                              defensive default)

    The final segment always gets 0 ms. Empty or whitespace-only text
    yields an empty list.
    """
    pieces = _SEGMENT_BREAK.split(text.strip())
    last_index = len(pieces) - 1
    segments = []
    for idx, piece in enumerate(pieces):
        piece = piece.strip()
        if not piece:
            continue
        if idx == last_index:
            pause_ms = 0
        elif piece.endswith(('.', '!', '?')):
            pause_ms = 480
        elif piece.endswith((',', ';')):
            pause_ms = 250
        else:
            pause_ms = 150
        segments.append((piece, pause_ms))
    return segments


def _silence(rate, ms):
    """
    Build a block of digital silence.

    Returns a float32 array of zeros lasting ``ms`` milliseconds at ``rate``
    Hz; the sample count is truncated toward zero.
    """
    sample_count = int(rate * ms / 1000)
    return np.zeros(sample_count, dtype=np.float32)


# ── Synthesis ─────────────────────────────────────────────────────────────────
def synthesize_with_python_api(text, model_path, syn_params, speaker_id=None):
    """
    Synthesise ``text`` in-process with the piper-tts Python package.

    The text is split with ``split_sentences``; each segment is synthesised
    with one shared ``SynthesisConfig`` and followed by the requested pause.

    Parameters
    ----------
    text : str
        Text to speak.
    model_path : str
        Path to the ``.onnx`` voice model.
    syn_params : dict
        Must contain ``length_scale``, ``noise_scale`` and ``noise_w_scale``.
    speaker_id : int or None
        Forwarded to ``SynthesisConfig(speaker_id=...)``.

    Returns
    -------
    (numpy.ndarray, int)
        Float samples scaled from int16 to [-1, 1), and the sample rate. A
        single zero sample is returned when nothing was synthesised.

    Raises
    ------
    ImportError
        If ``piper.voice`` (with ``PiperVoice`` and ``SynthesisConfig``)
        cannot be imported. The caller relies on this to fall back to the
        piper executable.
    """
    from piper.voice import PiperVoice, SynthesisConfig

    voice = PiperVoice.load(model_path)
    cfg = SynthesisConfig(
        speaker_id=speaker_id,
        length_scale=syn_params["length_scale"],
        noise_scale=syn_params["noise_scale"],
        noise_w_scale=syn_params["noise_w_scale"],
    )

    # Start from the model's own sample rate, so a pause inserted before any
    # audio chunk has been seen (e.g. a segment that yields no audio) still
    # has the right duration. 22050 Hz is only a last-resort default.
    src_rate = getattr(getattr(voice, "config", None), "sample_rate", 22050)

    pieces = []
    for sentence, pause_ms in split_sentences(text):
        for chunk in voice.synthesize(sentence, syn_config=cfg):
            src_rate = chunk.sample_rate
            pcm = np.frombuffer(chunk.audio_int16_bytes, dtype=np.int16)
            # Scale to [-1, 1) per chunk, matching the subprocess back-end.
            pieces.append(pcm.astype(np.float32) / 32768.0)
        if pause_ms > 0:
            pieces.append(_silence(src_rate, pause_ms))

    audio = np.concatenate(pieces) if pieces else np.zeros(1)
    return audio, src_rate


def synthesize_with_subprocess(text, model_path, piper_exe, syn_params,
                               speaker_id=None):
    """
    Synthesise ``text`` by running the piper executable once per segment.

    The text is split with ``split_sentences``. Each segment is written to
    piper's stdin as UTF-8, and piper writes a WAV file to a temporary path.
    That file is read back and always deleted, even on failure. Segment
    audio and the requested pauses are concatenated.

    Parameters
    ----------
    text : str
        Text to speak.
    model_path : str
        Path to the ``.onnx`` voice model (``--model``).
    piper_exe : str
        Name or path of the piper executable.
    syn_params : dict
        Must contain ``length_scale``, ``noise_scale`` and ``noise_w_scale``;
        the last is passed as ``--noise_w``.
    speaker_id : int or None
        Passed as ``--speaker`` when not None.

    Returns
    -------
    (numpy.ndarray, int)
        Float samples scaled from int16 to [-1, 1), and the sample rate read
        from the last WAV piper produced (22050 if none was produced). A
        single zero sample is returned when nothing was synthesised.

    Raises
    ------
    RuntimeError
        If piper exits non-zero, or writes a WAV that is not 16-bit mono.
    """
    segments = split_sentences(text)
    all_audio = []
    src_rate = 22050

    for sentence, pause_ms in segments:
        fd, wav_path = tempfile.mkstemp(suffix=".wav")
        os.close(fd)  # piper opens the path itself; we only need the name
        try:
            cmd = [
                piper_exe,
                "--model",        model_path,
                "--length_scale", str(syn_params["length_scale"]),
                "--noise_scale",  str(syn_params["noise_scale"]),
                "--noise_w",      str(syn_params["noise_w_scale"]),
                "--output_file",  wav_path,
            ]
            if speaker_id is not None:
                cmd += ["--speaker", str(speaker_id)]
            # Encode explicitly as UTF-8. text=True alone uses the locale
            # encoding (e.g. cp1252 on Windows), which cannot represent many
            # non-ASCII scripts such as Devanagari.
            proc = subprocess.run(
                cmd,
                input=sentence,
                stdout=subprocess.PIPE,
                stderr=subprocess.PIPE,
                encoding="utf-8",
                errors="replace",
            )
            if proc.returncode != 0:
                raise RuntimeError(f"piper failed: {proc.stderr}")
            with wave.open(wav_path, "rb") as wf:
                # The int16 decoding below is only valid for 16-bit mono.
                if wf.getsampwidth() != 2 or wf.getnchannels() != 1:
                    raise RuntimeError(
                        "piper wrote unexpected WAV format: "
                        f"{wf.getnchannels()} channel(s), "
                        f"{8 * wf.getsampwidth()}-bit"
                    )
                src_rate = wf.getframerate()
                raw = wf.readframes(wf.getnframes())
        finally:
            # Best-effort cleanup that must not mask the original error.
            try:
                os.unlink(wav_path)
            except OSError:
                pass

        seg = np.frombuffer(raw, dtype=np.int16).astype(np.float32) / 32768.0
        all_audio.append(seg)
        if pause_ms > 0:
            all_audio.append(_silence(src_rate, pause_ms))

    audio = np.concatenate(all_audio) if all_audio else np.zeros(1)
    return audio, src_rate


# Synthesis settings shared by both back-ends:
#   length_scale  1.35 → phoneme durations stretched by 35 % (slower delivery)
#   noise_scale   0.22 → low generator noise for cleaner phonemes
#   noise_w_scale 0.40 → phoneme-width noise (--noise_w for the executable)
syn_params = {"length_scale": 1.35, "noise_scale": 0.22, "noise_w_scale": 0.4}

# Prefer the in-process piper-tts API. If it cannot be imported, shell out to
# the piper executable instead.
try:
    audio, src_rate = synthesize_with_python_api(
        text, MODEL_PATH, syn_params, speaker_id=SPEAKER_ID
    )
except ImportError:
    audio, src_rate = synthesize_with_subprocess(
        text, MODEL_PATH, PIPER_EXE, syn_params, speaker_id=SPEAKER_ID
    )
    print("Synthesis: piper subprocess (piper-tts not found in this Python)")
else:
    print("Synthesis: piper-tts Python API")

# ══════════════════════════════════════════════════════════════════════════════
#  BROADCAST-QUALITY AUDIO PROCESSING CHAIN
#  Stage 1 : Resample          → source rate (typically 22 050 Hz) → 44 100 Hz
#  Stage 2 : Pre-processing    → DC removal, 80 Hz high-pass
#  Stage 3 : Noise gate        → suppress low-level noise between words
#  Stage 4 : EQ curve          → FFT gain curve shaping the vocal response
#  Stage 5 : Harmonic exciter  → add upper harmonics for presence
#  Stage 6 : De-esser          → tame harsh sibilants (s, sh, ch)
#  Stage 7 : Multi-band comp.  → 3-band compression (600 Hz / 5 kHz splits)
#  Stage 8 : Bus compression   → final "glue" and punch
#  Stage 9 : LUFS normalise    → −16 LUFS (a common streaming/podcast target)
#  Stage 10: Peak stage        → soft saturation + −0.3 dBFS sample-peak ceiling
#  Stages 2–9 require SciPy; without it only resampling and Stage 10 run.
# ══════════════════════════════════════════════════════════════════════════════

def _smooth_env(signal_abs, rate, ms):
    """
    Compute a moving-RMS envelope of a signal.

    Parameters
    ----------
    signal_abs : array_like
        Input samples. They need not be rectified: they are squared before
        averaging, so their sign is irrelevant.
    rate : int or float
        Sample rate in Hz.
    ms : float
        Length of the rectangular averaging window in milliseconds.

    Returns
    -------
    numpy.ndarray
        ``sqrt(moving_mean(x**2))`` with exactly one value per input sample.
        The window is centred (``np.convolve(..., mode="same")``).
    """
    # float64 so integer input cannot overflow when squared.
    x = np.asarray(signal_abs, dtype=np.float64)
    n = len(x)
    if n == 0:
        return np.zeros(0)
    # np.convolve(mode="same") returns max(len(x), len(kernel)) samples, so a
    # window longer than the signal would yield a longer, misaligned
    # envelope. Cap the window at the signal length.
    win = min(max(1, int(rate * ms / 1000)), n)
    kernel = np.ones(win) / win
    return np.sqrt(np.convolve(x ** 2, kernel, mode="same"))


def _soft_compress(audio, threshold, ratio, makeup):
    """
    Static, per-sample downward compressor with a hard knee.

    For samples whose magnitude exceeds ``threshold``, the excess is divided
    by ``ratio``: ``|y| = threshold + (|x| - threshold) / ratio``. Sign is
    preserved, and samples at or below the threshold pass unchanged. There
    is no envelope or attack/release, so the gain follows the waveform
    itself and shapes individual peaks. Everything is finally multiplied by
    ``makeup``.
    """
    magnitude = np.abs(audio)
    over = magnitude > threshold
    gain = np.ones_like(audio)
    gain[over] = (threshold + (magnitude[over] - threshold) / ratio) / magnitude[over]
    return audio * gain * makeup


def noise_gate(audio, rate, threshold=0.018, attack_ms=5, release_ms=80,
               floor_db=-40):
    """
    Attenuate passages whose short-term level is low.

    A 10 ms moving-RMS envelope of ``audio`` (from ``_smooth_env``) is
    compared with ``threshold``. Above it the gate aims for unity gain;
    below it the gate aims for ``floor_db`` (the default −40 dB is ×0.01).
    The gain moves toward its target through a one-pole smoother. It uses
    the ``attack_ms`` time constant while rising (opening) and
    ``release_ms`` while falling (closing), which avoids clicks from abrupt
    gain changes.

    Parameters
    ----------
    audio : numpy.ndarray
        1-D array of samples.
    rate : int
        Sample rate in Hz.
    threshold : float
        Linear RMS level that opens the gate.
    attack_ms, release_ms : float
        Smoother time constants in milliseconds. A value <= 0 makes that
        direction instantaneous.
    floor_db : float
        Gain applied when the gate is fully closed, in dB.

    Returns
    -------
    numpy.ndarray
        ``audio`` multiplied sample-by-sample by the gate gain (same length).
    """
    n = len(audio)
    if n == 0:
        return np.zeros(0)

    # Slice to n: np.convolve(mode="same") inside the envelope helper can
    # return more samples than the input when the window exceeds it.
    # .tolist() makes the per-sample Python loop much cheaper than indexing
    # a NumPy array one element at a time.
    levels = _smooth_env(audio, rate, ms=10)[:n].tolist()
    floor = 10 ** (floor_db / 20)

    def _coef(time_ms):
        # One-pole smoothing coefficient for a time constant of time_ms ms.
        if time_ms <= 0:
            return 0.0
        return np.exp(-1.0 / (rate * time_ms / 1000))

    a_coef = _coef(attack_ms)
    r_coef = _coef(release_ms)

    gains = []
    state = 0.0
    for level in levels:
        target = 1.0 if level > threshold else floor
        coef = a_coef if target > state else r_coef   # attack vs release
        state = coef * state + (1 - coef) * target
        gains.append(state)

    return audio * np.asarray(gains, dtype=np.float64)


def parametric_eq(audio, rate):
    """
    Apply a fixed vocal EQ curve in the frequency domain.

    How the curve is built and applied:
    - It is defined by (Hz, dB) anchor points, linearly interpolated between
      anchors. Frequencies above the last anchor hold its value.
    - It is Gaussian-smoothed over ~50 Hz to round the corners.
    - It is applied by multiplying the signal's real FFT. The gain is
      real-valued, so the filter is zero-phase.

    Multiplying the spectrum realises the exact target gain at every bin,
    unlike the wide transition bands of low-order IIR (Butterworth) filters.

    The signal is zero-padded (by ~100 ms) before the FFT. Multiplying an
    unpadded FFT performs *circular* convolution: the filter's response to
    the end of the clip would wrap onto its start, and the acausal
    (pre-ringing) half of the zero-phase response to the start would wrap
    onto its end. The padding absorbs both and is then discarded.

    Curve anchors were tuned by the author against a spectral analysis of
    earlier output. That analysis showed a dominant 150–300 Hz band
    responsible for a hoarse tone, and a weak 9–13 kHz air band:

        0–80 Hz    −30 dB   rumble / DC block
        110 Hz      −2 dB   enter voice range
        150 Hz     −10 dB   low-body cut begins
        200 Hz     −16 dB   deep cut (hoarseness peak)
        260 Hz     −17 dB   deepest cut
        350 Hz     −14 dB   cut tapering
        500 Hz     −12 dB   mud
        700 Hz      −8 dB   mud taper
        1.2 kHz     −2 dB   approaching flat
        1.8 kHz     +2 dB   presence rise
        2.5 kHz     +5 dB   presence
        4 kHz       +9 dB   clarity peak (consonants)
        6 kHz       +7 dB   upper clarity
        9 kHz       +8 dB   air boost
        13 kHz      +5 dB   air taper
        15 kHz      +2 dB   brilliance roll-off
        22.05 kHz  −20 dB   top-end roll-off (Nyquist at 44.1 kHz)

    Parameters
    ----------
    audio : numpy.ndarray
        1-D signal.
    rate : int
        Sample rate in Hz.

    Returns
    -------
    numpy.ndarray
        Filtered float32 signal with the same length as ``audio``.
    """
    from scipy.fft import next_fast_len
    from scipy.ndimage import gaussian_filter1d

    n = len(audio)
    if n == 0:
        return np.zeros(0, dtype=np.float32)

    # 100 ms of padding is far longer than the smoothed curve's impulse
    # response, which the ~50 Hz Gaussian smoothing confines to roughly
    # ±10 ms.
    n_fft = next_fast_len(n + max(1, int(rate * 0.1)), real=True)
    freqs = np.fft.rfftfreq(n_fft, 1.0 / rate)
    spec = np.fft.rfft(np.asarray(audio, dtype=np.float64), n_fft)

    # Anchor points  [Hz, dB]
    anchors = np.array([
        [0,       -30],   # DC / rumble block
        [80,      -30],   # sub-bass block
        [110,      -2],   # enter voice range
        [150,     -10],   # low-body cut begins
        [200,     -16],   # deep cut — hoarseness peak
        [260,     -17],   # deepest cut
        [350,     -14],   # taper begins
        [500,     -12],   # mud
        [700,      -8],   # mid mud taper
        [1200,     -2],   # approaching flat
        [1800,     +2],   # presence rise
        [2500,     +5],   # presence
        [4000,     +9],   # clarity peak
        [6000,     +7],   # upper clarity
        [9000,     +8],   # air boost
        [13000,    +5],   # air taper
        [15000,    +2],   # brilliance roll-off
        [22050,   -20],   # top-end roll-off
    ], dtype=np.float64)

    gain_db = np.interp(freqs, anchors[:, 0], anchors[:, 1])

    # Smooth over ~50 Hz: narrow enough to keep the steep 150–300 Hz cut,
    # wide enough to round the interpolation's hard corners.
    hz_per_bin = rate / float(n_fft)
    sigma = max(2, int(50.0 / hz_per_bin))
    gain_db = gaussian_filter1d(gain_db, sigma=sigma)

    gain_lin = 10.0 ** (gain_db / 20.0)
    filtered = np.fft.irfft(spec * gain_lin, n_fft)[:n]
    return filtered.astype(np.float32)


def harmonic_exciter(audio, rate, drive=2.2, mix=0.12):
    """
    Add subtle upper harmonics generated by saturating the 3–7 kHz band.

    The 3–7 kHz band is isolated with a Butterworth band-pass and saturated
    with ``tanh(drive * x) / drive``. Dividing by ``drive`` gives the
    saturator unity small-signal gain. Subtracting the clean band therefore
    leaves only the *nonlinear* residual, i.e. the generated distortion
    products. Without that normalisation the residual would also contain a
    linear copy of the band scaled by ``(drive - 1)``, which is just an EQ
    boost of the band (noise included) rather than new harmonics.

    tanh is an odd function, so the products are odd-order harmonics (3rd,
    5th, …) plus intermodulation between components in the band. The
    residual is high-passed at 5 kHz and mixed back in at ``mix``.

    Parameters
    ----------
    audio : numpy.ndarray
        1-D signal.
    rate : int
        Sample rate in Hz (must exceed 14 kHz so 7 kHz is below Nyquist).
    drive : float
        Saturation drive (> 0). Larger values distort more at a given level.
    mix : float
        Amount of filtered residual added to the input.

    Returns
    -------
    numpy.ndarray
        ``audio + mix * highpass_5k(residual)``.

    Raises
    ------
    ValueError
        If ``drive`` is not positive.
    """
    from scipy import signal as sg

    if drive <= 0:
        raise ValueError("drive must be positive")
    nyq = rate / 2.0

    # Band that feeds the saturator.
    sos_in = sg.butter(2, [3000.0 / nyq, 7000.0 / nyq], btype="bandpass", output="sos")
    band = sg.sosfilt(sos_in, audio)

    # Unity-slope saturation: the linear part cancels exactly on subtraction.
    saturated = np.tanh(band * drive) / drive
    new_harmonics = saturated - band

    # Keep only content above 5 kHz.
    sos_hi = sg.butter(2, 5000.0 / nyq, btype="high", output="sos")
    new_harmonics = sg.sosfilt(sos_hi, new_harmonics)

    return audio + new_harmonics * mix


def de_esser(audio, rate, lo=6000, hi=10000, threshold=0.10, ratio=3.5):
    """
    Reduce sibilance by compressing only the ``lo``–``hi`` band.

    Processing steps:
    1. Isolate the band with a Butterworth band-pass.
    2. Compute a 3 ms RMS envelope of the band.
    3. Where the envelope exceeds ``threshold``, compute a static gain that
       divides the excess by ``ratio``.
    4. Smooth the gain with a 5 ms moving average to avoid zipper noise,
       then clip it to [0.15, 1.0] (at most ≈ −16.5 dB of reduction).
    5. Replace the original band with the processed band:
       ``audio - band + band * gain``.

    Parameters
    ----------
    audio : numpy.ndarray
        1-D signal.
    rate : int
        Sample rate in Hz (``hi`` must be below Nyquist).
    lo, hi : float
        Band edges in Hz.
    threshold : float
        Linear RMS level of the band above which reduction starts.
    ratio : float
        Compression ratio applied to the envelope excess.

    Returns
    -------
    numpy.ndarray
        De-essed signal with the same length as ``audio``.
    """
    from scipy import signal as sg

    n = len(audio)
    if n == 0:
        return np.zeros(0)
    nyq = rate / 2.0

    sos_ess = sg.butter(2, [lo / nyq, hi / nyq], btype="bandpass", output="sos")
    ess_band = sg.sosfilt(sos_ess, audio)

    # 3 ms RMS envelope of the sibilant band. Trim to n in case the helper
    # returns a longer array for inputs shorter than its window.
    env = _smooth_env(ess_band, rate, ms=3)[:n]
    gain = np.ones_like(env)
    above = env > threshold
    gain[above] = (threshold + (env[above] - threshold) / ratio) / env[above]

    # 5 ms gain smoothing. The window is capped at n because
    # np.convolve(mode="same") returns max(len(a), len(v)) samples.
    win = min(max(1, int(rate * 0.005)), n)
    gain = np.convolve(gain, np.ones(win) / win, mode="same")
    gain = np.clip(gain, 0.15, 1.0)

    return audio - ess_band + ess_band * gain


def multiband_compress(audio, rate):
    """
    Compress low, mid and high bands independently, then recombine them.

    Crossovers:
    - Low band: 4th-order Butterworth low-pass at 600 Hz.
    - High band: 4th-order Butterworth high-pass at 5 kHz.
    - Mid band: the remainder ``audio - low - high``, so the three bands
      sum exactly to the input before compression.

    Each band goes through ``_soft_compress`` with its own settings:

        band   threshold  ratio  makeup
        low      0.22      3:1    0.95  (slight cut, so the low-body energy
                                         removed by the EQ is not restored)
        mid      0.18      4:1    1.25  (main vocal band; evens dynamics)
        high     0.20      2:1    1.12  (gentle control of sibilant peaks)

    Returns the sum of the three processed bands.
    """
    from scipy import signal as sg

    half_rate = rate / 2.0
    lowpass = sg.butter(4, 600.0 / half_rate, btype="low", output="sos")
    highpass = sg.butter(4, 5000.0 / half_rate, btype="high", output="sos")

    low = sg.sosfilt(lowpass, audio)
    high = sg.sosfilt(highpass, audio)
    mid = audio - low - high

    low = _soft_compress(low, threshold=0.22, ratio=3.0, makeup=0.95)
    mid = _soft_compress(mid, threshold=0.18, ratio=4.0, makeup=1.25)
    high = _soft_compress(high, threshold=0.20, ratio=2.0, makeup=1.12)

    return low + mid + high


def lufs_normalize(audio, rate, target_lufs=-16.0):
    """
    Scale ``audio`` so its integrated loudness is ``target_lufs``.

    Loudness is measured as in ITU-R BS.1770-4 / EBU R128 for one channel:

    1. K-weighting: a high shelf (≈ +4 dB, centred near 1.68 kHz) followed
       by the RLB high-pass (≈ 38 Hz). Both biquads are computed for
       ``rate`` with the closed-form derivation used by libebur128, which
       reproduces the BS.1770 reference coefficients at 48 kHz.
    2. Mean square of every 400 ms block, with a 100 ms hop (75 % overlap).
    3. Absolute gate at −70 LUFS, then a relative gate 10 LU below the mean
       of the blocks that passed the absolute gate.
    4. Integrated loudness = −0.691 + 10·log10(mean of the gated blocks).

    The linear gain is capped at 5 (≈ +14 dB) so near-silent material is
    not blown up. ``audio`` is returned unchanged when it is shorter than
    one block, or when no block passes the absolute gate.

    Parameters
    ----------
    audio : numpy.ndarray
        1-D (mono) signal.
    rate : int
        Sample rate in Hz.
    target_lufs : float
        Desired integrated loudness in LUFS.

    Returns
    -------
    numpy.ndarray
        ``audio`` multiplied by a single scalar gain.
    """
    from scipy import signal as sg

    samples = np.asarray(audio, dtype=np.float64)
    block = int(round(rate * 0.4))
    hop = int(round(rate * 0.1))
    n = len(samples)
    if block < 1 or hop < 1 or n < block:
        return audio

    # K-weighting stage 1: high shelf.
    k = np.tan(np.pi * 1681.974450955533 / rate)
    q = 0.7071752369554196
    vh = 10.0 ** (3.999843853973347 / 20.0)
    vb = vh ** 0.4996667741545416
    norm = 1.0 + k / q + k * k
    shelf = [
        (vh + vb * k / q + k * k) / norm,
        2.0 * (k * k - vh) / norm,
        (vh - vb * k / q + k * k) / norm,
        1.0,
        2.0 * (k * k - 1.0) / norm,
        (1.0 - k / q + k * k) / norm,
    ]

    # K-weighting stage 2: RLB high-pass (numerator fixed at [1, -2, 1]).
    k = np.tan(np.pi * 38.13547087602444 / rate)
    q = 0.5003270373238773
    norm = 1.0 + k / q + k * k
    rlb = [1.0, -2.0, 1.0,
           1.0, 2.0 * (k * k - 1.0) / norm, (1.0 - k / q + k * k) / norm]

    weighted = sg.sosfilt(np.array([shelf, rlb]), samples)
    power = weighted ** 2

    # Every 400 ms block that lies fully inside the signal. The "+ 1"
    # includes a final block that ends exactly on the last sample.
    block_ms = np.array([power[s:s + block].mean()
                         for s in range(0, n - block + 1, hop)])

    abs_gate = 10.0 ** ((-70.0 + 0.691) / 10.0)          # −70 LUFS
    gated = block_ms[block_ms > abs_gate]
    if gated.size == 0:
        return audio
    rel_gate = gated.mean() * 10.0 ** (-10.0 / 10.0)     # −10 LU relative
    gated = gated[gated > rel_gate]

    current_lufs = -0.691 + 10.0 * np.log10(gated.mean())
    gain = min(10.0 ** ((target_lufs - current_lufs) / 20.0), 5.0)  # ≤ +14 dB
    return audio * gain


def _fft_resample(audio, src_rate, target_rate):
    """Resample with NumPy only, by zero-padding or truncating the real FFT.

    Works for any rate ratio. The output has round(len * target / src)
    samples and its mean removed.
    """
    x = np.asarray(audio, dtype=np.float64)
    n = len(x)
    if n == 0:
        return x
    n_out = max(1, int(round(n * target_rate / src_rate)))
    spectrum = np.fft.rfft(x)
    resampled = np.zeros(n_out // 2 + 1, dtype=complex)
    keep = min(len(spectrum), len(resampled))
    # Scale by n_out / n so amplitudes survive the change in irfft length.
    resampled[:keep] = spectrum[:keep] * (n_out / n)
    out = np.fft.irfft(resampled, n_out)
    return out - out.mean()


def broadcast_enhance(audio, src_rate, target_rate):
    """
    Run the full post-processing chain and return audio at ``target_rate``.

    With SciPy available the chain is:
    polyphase resample → DC removal → 80 Hz high-pass → noise gate →
    EQ curve → harmonic exciter → de-esser → 3-band compression →
    bus compression → LUFS normalisation (−16 LUFS) → peak stage.
    Without SciPy it falls back to an FFT resample plus the peak stage.

    Peak stage: ``tanh(0.92·x)/0.92`` soft saturation (unity gain for small
    signals), then a sample-peak ceiling of 0.966 (≈ −0.3 dBFS). After LUFS
    normalisation the ceiling only *attenuates* when a peak exceeds it, so
    the loudness target is preserved. On the numpy-only path, where no
    loudness normalisation ran, the audio is scaled so its peak sits at the
    ceiling. This is a sample-peak check, not an oversampled true-peak
    measurement.

    Parameters
    ----------
    audio : numpy.ndarray
        1-D float signal at ``src_rate``.
    src_rate, target_rate : int
        Input and output sample rates in Hz.

    Returns
    -------
    numpy.ndarray
        Processed float signal at ``target_rate`` with |peak| <= 0.966.
    """
    # Resolve SciPy up front, so the fallback only triggers before any
    # processing. Catching ImportError around the whole chain could
    # otherwise re-resample audio that was already partly processed.
    try:
        from scipy import ndimage as _ndimage  # noqa: F401  (parametric_eq)
        from scipy import signal as sg
    except ImportError:
        sg = None

    if sg is not None:
        # ── Stage 1: Resample + DC removal ────────────────────────────────────
        audio = sg.resample_poly(audio, target_rate, src_rate)
        audio -= audio.mean()

        # ── Stage 2: 80 Hz high-pass ──────────────────────────────────────────
        sos_hp = sg.butter(6, 80.0 / (target_rate / 2), btype="high", output="sos")
        audio = sg.sosfilt(sos_hp, audio)

        # ── Stage 3: Noise gate ───────────────────────────────────────────────
        audio = noise_gate(audio, target_rate,
                           threshold=0.018, attack_ms=5, release_ms=80)

        # ── Stage 4: FFT EQ curve ─────────────────────────────────────────────
        audio = parametric_eq(audio, target_rate)

        # ── Stage 5: Harmonic exciter (drive=1.6, mix=0.12) ──────────────────
        audio = harmonic_exciter(audio, target_rate, drive=1.6, mix=0.12)

        # ── Stage 6: De-esser ─────────────────────────────────────────────────
        audio = de_esser(audio, target_rate,
                         lo=6000, hi=10000, threshold=0.10, ratio=3.5)

        # ── Stage 7: Multi-band compression ──────────────────────────────────
        audio = multiband_compress(audio, target_rate)

        # ── Stage 8: Bus compression (final glue) ────────────────────────────
        audio = _soft_compress(audio, threshold=0.32, ratio=2.5, makeup=1.25)

        # ── Stage 9: LUFS normalisation (−16 LUFS) ───────────────────────────
        audio = lufs_normalize(audio, target_rate, target_lufs=-16.0)
        loudness_normalised = True
    else:
        print("scipy not found — using numpy-only processing")
        audio = _fft_resample(audio, src_rate, target_rate)
        loudness_normalised = False

    # ── Stage 10: Soft saturation + −0.3 dBFS sample-peak ceiling ────────────
    ceiling = 0.966
    audio = np.tanh(audio * 0.92) / 0.92
    peak = float(np.max(np.abs(audio))) if audio.size else 0.0
    # Only attenuate after loudness normalisation, so Stage 9 is not undone.
    # Without it (numpy path), normalise the peak to the ceiling.
    if peak > 0 and (peak > ceiling or not loudness_normalised):
        audio = audio * (ceiling / peak)
    return audio


audio = broadcast_enhance(audio, src_rate, TARGET_RATE)

# ── Save ──────────────────────────────────────────────────────────────────────
# Convert float samples to 16-bit PCM:
# - round to nearest (astype alone truncates toward zero);
# - clip first, since casting an out-of-range float to int16 does not
#   saturate;
# - force little-endian ("<i2"), which is the WAV byte order, regardless of
#   the host.
audio_int16 = np.clip(np.round(audio * 32767), -32768, 32767).astype("<i2")
with wave.open(OUTPUT_FILE, "wb") as wf:
    wf.setnchannels(1)           # mono
    wf.setsampwidth(2)           # 16-bit samples
    wf.setframerate(TARGET_RATE)
    wf.writeframes(audio_int16.tobytes())

duration = len(audio_int16) / TARGET_RATE
print(f"Done: {OUTPUT_FILE}  ({duration:.1f}s  @{TARGET_RATE} Hz)")
