import os
import sys
import wave
import subprocess
import tempfile
import numpy as np

# Required: point ESPEAK_DATA_PATH at the espeak-ng data bundled with Piper.
# setdefault keeps any value already present in the environment; the fallback
# "espeak-ng" is a relative path, so it is resolved against the current
# working directory.
os.environ.setdefault("ESPEAK_DATA_PATH", "espeak-ng")

# ── Config ────────────────────────────────────────────────────────────────────
# Voice: L2-ARCTIC multi-speaker model. Indian female speakers (Hindi L1):
#   speaker_id 2 → SVBI
#   speaker_id 9 → TNI
MODEL_PATH = "en_US-l2arctic-medium.onnx"
SPEAKER_ID = 9                        # pick 2 (SVBI) or 9 (TNI)
# NOTE(review): "artict" looks like a typo for "arctic"; kept unchanged so
# anything consuming this file name keeps working.
OUTPUT_FILE = "artict_tts_audio.wav"
PIPER_EXE = "piper"                   # CLI fallback when piper-tts can't be imported
TARGET_RATE = 44100                   # final output sample rate in Hz (audio is resampled)

# Text to synthesize. "U S F D A" is letter-spaced, presumably so the acronym
# is read out letter by letter.
text = (
    "The presentation outlines a global corrective action plan initiated by Sun Pharmaceutical in response to a U S F D A  observation at its Dadra site, where multiple unaddressed deviations related to missing GMP documents were found. The core issue was inadequate documentation practices and limited training, which prompted the development of a comprehensive training module on documentation control. This module will be extended to all departments across all US supply sites, covering topics like document handling, error correction, and logbook maintenance. The training will be delivered through instructor-led sessions and online modules, with a completion timeline of 240 days and oversight by Corporate Quality to ensure compliance and consistency."
)

# ── Synthesis ─────────────────────────────────────────────────────────────────
def synthesize_with_python_api(text, model_path, syn_params, speaker_id=None):
    """
    Synthesize ``text`` with the piper-tts Python package (preferred path).

    Args:
        text: Text to speak.
        model_path: Path to the Piper ``.onnx`` voice model.
        syn_params: Mapping with ``length_scale``, ``noise_scale`` and
            ``noise_w_scale`` entries, forwarded to ``SynthesisConfig``.
        speaker_id: Speaker index for multi-speaker models, or None.

    Returns:
        ``(audio, sample_rate)``: ``audio`` is a 1-D float32 array of the
        int16 PCM scaled by 1/32768 (so in [-1, 1)); ``sample_rate`` is the
        rate reported by the last chunk. If the voice yields no chunks
        (e.g. empty text) an empty array is returned with a rate of 22050.

    Raises:
        ImportError: if ``piper.voice`` (with ``PiperVoice`` and
            ``SynthesisConfig``) cannot be imported.
    """
    from piper.voice import PiperVoice, SynthesisConfig

    voice = PiperVoice.load(model_path)
    cfg = SynthesisConfig(
        speaker_id=speaker_id,
        length_scale=syn_params["length_scale"],
        noise_scale=syn_params["noise_scale"],
        noise_w_scale=syn_params["noise_w_scale"],
    )
    chunks = []
    src_rate = 22050  # only used when no chunk is produced
    for chunk in voice.synthesize(text, syn_config=cfg):
        chunks.append(np.frombuffer(chunk.audio_int16_bytes, dtype=np.int16))
        src_rate = chunk.sample_rate

    # np.concatenate([]) raises; return an explicit empty signal instead.
    if not chunks:
        return np.zeros(0, dtype=np.float32), src_rate

    audio = np.concatenate(chunks).astype(np.float32) / 32768.0
    return audio, src_rate


def synthesize_with_subprocess(text, model_path, piper_exe, syn_params, speaker_id=None):
    """
    Fall back to the piper binary: feed ``text`` on stdin, read back a WAV.

    Args:
        text: Text to speak (sent to piper as UTF-8).
        model_path: Path to the Piper ``.onnx`` voice model.
        piper_exe: piper executable (name on PATH or a full path).
        syn_params: Mapping with ``length_scale``, ``noise_scale`` and
            ``noise_w_scale`` entries, passed as CLI flags.
        speaker_id: Speaker index for multi-speaker models, or None.

    Returns:
        ``(audio, sample_rate)`` with ``audio`` a float32 array of the 16-bit
        PCM scaled by 1/32768.

    Raises:
        RuntimeError: if piper exits non-zero (message includes its stderr)
            or writes audio that is not 16-bit.
    """
    tmp = tempfile.NamedTemporaryFile(suffix=".wav", delete=False)
    tmp.close()  # closed so the external piper process can open the path itself
    try:
        cmd = [
            piper_exe,
            "--model",        model_path,
            "--length_scale", str(syn_params["length_scale"]),
            "--noise_scale",  str(syn_params["noise_scale"]),
            "--noise_w",      str(syn_params["noise_w_scale"]),
            "--output_file",  tmp.name,
        ]
        if speaker_id is not None:
            cmd += ["--speaker", str(speaker_id)]
        # Send UTF-8 regardless of the platform's locale encoding.
        proc = subprocess.run(cmd, input=text, capture_output=True,
                              encoding="utf-8", errors="replace")
        if proc.returncode != 0:
            raise RuntimeError(f"piper failed: {proc.stderr}")
        with wave.open(tmp.name, "rb") as wf:
            if wf.getsampwidth() != 2:
                raise RuntimeError(
                    f"piper wrote {8 * wf.getsampwidth()}-bit audio; expected 16-bit"
                )
            src_rate = wf.getframerate()
            raw = wf.readframes(wf.getnframes())
    finally:
        # Always remove the temp file, including when piper or the WAV read fails.
        try:
            os.unlink(tmp.name)
        except OSError:
            pass  # best-effort cleanup; never mask the original error
    # WAV PCM is little-endian: decode explicitly rather than in native order.
    audio = np.frombuffer(raw, dtype="<i2").astype(np.float32) / 32768.0
    return audio, src_rate


# Synthesis parameters shared by both back-ends:
#   length_scale  > 1.0 → slower speaking pace
#   noise_w_scale higher → more variation in speaking rhythm
syn_params = {"length_scale": 1.10, "noise_scale": 0.25, "noise_w_scale": 0.50}

try:
    audio, src_rate = synthesize_with_python_api(
        text, MODEL_PATH, syn_params, speaker_id=SPEAKER_ID
    )
except ImportError:
    # The piper-tts package could not be imported in this interpreter:
    # shell out to the standalone piper executable instead.
    audio, src_rate = synthesize_with_subprocess(
        text, MODEL_PATH, PIPER_EXE, syn_params, speaker_id=SPEAKER_ID
    )
    print("Synthesis: piper subprocess (piper-tts not found in this Python)")
else:
    print("Synthesis: piper-tts Python API")


# ── Audio Enhancement Helpers ─────────────────────────────────────────────────

def spectral_noise_reduction(audio, rate, frame_ms=25, hop_ms=10,
                             noise_percentile=15, strength=1.5):
    """
    Reduce stationary background noise by magnitude spectral subtraction.

    The signal is split into Hann-windowed frames ``frame_ms`` long, one every
    ``hop_ms``. The noise magnitude profile is the mean spectrum of the frames
    whose energy is at or below the ``noise_percentile``-th percentile;
    ``strength`` times that profile is subtracted from every frame's
    magnitude. Each bin keeps at least 5 % of its original *magnitude* (a
    spectral floor that limits 'musical noise'). Frames are rebuilt with their
    original phase and recombined by window-normalised overlap-add.

    Returns:
        An array with the same length as ``audio``.
    """
    win_len = int(rate * frame_ms / 1000)
    step = int(rate * hop_ms / 1000)
    hann = np.hanning(win_len)
    lead = win_len // 2

    # Zero-pad both ends so every input sample is covered by full frames.
    padded = np.pad(audio, (lead, lead + win_len))
    starts = range(0, len(padded) - win_len + 1, step)
    frames = np.array([padded[s:s + win_len] * hann for s in starts])

    spec = np.fft.rfft(frames, axis=1)
    mag, phase = np.abs(spec), np.angle(spec)

    # Noise estimate from the quietest frames (fallback: first 10 %).
    energy = np.sum(mag ** 2, axis=1)
    quiet = mag[energy <= np.percentile(energy, noise_percentile)]
    if quiet.shape[0] == 0:
        quiet = mag[:max(1, len(mag) // 10)]
    noise = np.mean(quiet, axis=0)

    floored = np.maximum(mag - noise * strength, mag * 0.05)
    gain = floored / (mag + 1e-10)
    cleaned = np.fft.irfft(mag * gain * np.exp(1j * phase), n=win_len, axis=1)

    # Weighted overlap-add, normalised by the summed squared window.
    total = len(padded)
    acc = np.zeros(total)
    norm = np.zeros(total)
    hann_sq = hann ** 2
    for frame, s in zip(cleaned, starts):
        acc[s:s + win_len] += frame * hann
        norm[s:s + win_len] += hann_sq
    acc /= np.maximum(norm, 1e-8)

    return acc[lead:lead + len(audio)]


def noise_gate(audio, rate, threshold_db=-48, attack_ms=5, release_ms=80):
    """
    Smooth RMS-based noise gate.

    The level is measured as RMS over consecutive 5 ms chunks. Where a chunk's
    RMS is at or above ``threshold_db`` (dBFS) the gate targets a gain of 1
    (open), otherwise 0 (closed). The applied gain follows that target through
    a one-pole smoother. ``attack_ms`` sets how fast the gate OPENS (gain
    rising) and ``release_ms`` how fast it CLOSES (gain falling). As a result,
    word onsets pass almost immediately and tails fade out instead of
    clicking off.

    Args:
        audio: 1-D float signal.
        rate: Sample rate in Hz.
        threshold_db: Open/close threshold in dBFS.
        attack_ms: Opening time constant in milliseconds.
        release_ms: Closing time constant in milliseconds.

    Returns:
        ``audio`` multiplied sample-by-sample by the smoothed gate gain.
    """
    threshold = 10 ** (threshold_db / 20.0)
    chunk     = max(1, int(rate * 0.005))          # 5 ms analysis chunks
    # One-pole coefficients as plain floats to keep the per-sample loop fast.
    atk_coef  = float(np.exp(-1.0 / max(1, rate * attack_ms  / 1000.0)))
    rel_coef  = float(np.exp(-1.0 / max(1, rate * release_ms / 1000.0)))

    # Per-chunk RMS envelope
    rms_env = np.zeros(len(audio))
    for i in range(0, len(audio), chunk):
        seg = audio[i:i + chunk]
        rms_env[i:i + chunk] = np.sqrt(np.mean(seg ** 2) + 1e-12)
    is_open = (rms_env >= threshold).tolist()

    # Smooth gain curve
    gain    = np.empty(len(audio))
    current = 1.0
    for i, open_now in enumerate(is_open):
        target = 1.0 if open_now else 0.0
        # Rising gain = gate opening → attack; falling gain = closing → release.
        # (The compressor uses the opposite mapping because there a falling
        # gain is the attack.)
        coef    = atk_coef if target > current else rel_coef
        current = coef * current + (1.0 - coef) * target
        gain[i] = current

    return audio * gain


def de_esser(audio, rate, freq_low=5500, freq_high=9000,
             threshold=0.12, ratio=4.0):
    """
    Dynamic de-esser that acts only on the sibilant band.

    The ``[freq_low, freq_high]`` band is extracted with a zero-phase 4th-order
    Butterworth band-pass. Its level is tracked by a 3 ms block-peak envelope
    smoothed with a 5 ms moving average. Wherever that envelope exceeds
    ``threshold``, the band is turned down with compression ``ratio``.
    Content outside the band is passed through unchanged.
    """
    from scipy import signal as sg

    half_rate = rate / 2.0
    band_sos = sg.butter(4, [freq_low / half_rate, freq_high / half_rate],
                         btype="bandpass", output="sos")
    band = sg.sosfiltfilt(band_sos, audio)
    total = len(audio)

    # Peak |band| within each 3 ms block, held for the whole block.
    block = max(1, int(rate * 0.003))
    starts = np.arange(0, total, block)
    block_peaks = np.maximum.reduceat(np.abs(band) + 1e-12, starts)
    envelope = np.repeat(block_peaks, np.diff(np.append(starts, total)))

    # 5 ms moving-average smoothing of the envelope.
    span = max(1, int(rate * 0.005))
    envelope = np.convolve(envelope, np.ones(span) / span, mode='same')

    # Compress the band only where the envelope is over the threshold.
    band_gain = np.ones(total)
    hot = envelope > threshold
    if hot.any():
        level = envelope[hot]
        band_gain[hot] = (threshold + (level - threshold) / ratio) / level

    return audio - band + band * band_gain


def smooth_compress(audio, rate, threshold=0.22, ratio=4.5, makeup=1.45,
                    attack_ms=3, release_ms=60):
    """
    Sample-level downward compressor with attack/release smoothing.

    For every sample whose absolute value exceeds ``threshold``, a target
    gain is computed that scales the excess above the threshold by
    ``1/ratio``. Other samples have a target gain of 1. The target gains are
    smoothed with a one-pole filter: ``attack_ms`` applies while the gain is
    falling and ``release_ms`` while it is rising, which avoids clicks. The
    result is multiplied by the fixed ``makeup`` gain.
    """
    def _coef(ms):
        return np.exp(-1.0 / max(1, rate * ms / 1000.0))

    attack, release = _coef(attack_ms), _coef(release_ms)

    level = np.abs(audio)
    target_gain = np.ones(len(audio))
    loud = level > threshold
    if loud.any():
        over = level[loud]
        target_gain[loud] = (threshold + (over - threshold) / ratio) / over

    # One-pole smoothing: fast while gain drops, slow while it recovers.
    smoothed = np.empty(len(audio))
    g_prev = 1.0
    for idx, g in enumerate(target_gain):
        k = attack if g < g_prev else release
        g_prev = k * g_prev + (1.0 - k) * g
        smoothed[idx] = g_prev

    return audio * smoothed * makeup


# ── Main enhancement pipeline ─────────────────────────────────────────────────

# Band EQ applied by the scipy path of enhance_audio. Each entry adds
# ``amount`` times a zero-phase Butterworth band-pass copy of the signal
# (a negative amount cuts the band). Bands reaching Nyquist are skipped.
_EQ_BANDS = (
    # order  low Hz    high Hz   amount
    (3,      250.0,    500.0,   -0.12),   # cut muddiness / boxiness
    (2,      120.0,    250.0,    0.08),   # warmth / body
    (4,     2000.0,   4000.0,    0.28),   # presence / intelligibility
    (3,     4000.0,   7000.0,    0.15),   # clarity / consonants (t, d, s, p, k)
    (2,    10000.0,  14000.0,    0.08),   # air / openness
)


def _fft_resample(audio, src_rate, target_rate):
    """Resample ``audio`` by zero-padding/truncating its real FFT (numpy-only fallback).

    Works for any rate ratio; the output length is ``round(len * target/src)``.
    """
    n_in = len(audio)
    n_out = int(round(n_in * target_rate / src_rate))
    if n_in == 0 or n_out == 0:
        return np.zeros(n_out)
    spectrum = np.fft.rfft(audio)
    resized = np.zeros(n_out // 2 + 1, dtype=complex)
    keep = min(len(spectrum), len(resized))
    resized[:keep] = spectrum[:keep]
    # irfft normalises by 1/n_out; rescale so amplitudes are preserved.
    return np.fft.irfft(resized, n_out) * (n_out / n_in)


def _restore_with_scipy(audio, src_rate, target_rate, sg):
    """Resample, clean up, denoise, gate, EQ, de-ess and low-pass (``sg`` = scipy.signal)."""
    # ── 1. High-quality polyphase resample (e.g. 22 050 → 44 100 Hz) ───────
    audio = sg.resample_poly(audio, target_rate, src_rate)
    nyq   = target_rate / 2.0

    # ── 2. Remove DC offset ──────────────────────────────────────────────────
    audio -= audio.mean()

    # ── 3. Notch-filter electrical hum: 50 / 60 Hz + 1st harmonics ──────────
    for f0 in (50.0, 60.0, 100.0, 120.0):
        b, a  = sg.iirnotch(f0, Q=40, fs=target_rate)
        audio = sg.filtfilt(b, a, audio)

    # ── 4. Steep high-pass @ 90 Hz — removes sub-bass rumble ────────────────
    sos_hp = sg.butter(8, 90.0 / nyq, btype="high", output="sos")
    audio  = sg.sosfiltfilt(sos_hp, audio)

    # ── 5. Spectral noise reduction ──────────────────────────────────────────
    print("  [1/7] Spectral noise reduction …")
    audio = spectral_noise_reduction(audio, target_rate,
                                     noise_percentile=15, strength=1.5)

    # ── 6. Noise gate — silences background between words ────────────────────
    print("  [2/7] Noise gate …")
    audio = noise_gate(audio, target_rate, threshold_db=-48,
                       attack_ms=5, release_ms=80)

    # ── 7. Band EQ (zero-phase sosfiltfilt throughout) ───────────────────────
    print("  [3/7] EQ shaping …")
    for order, lo, hi, amount in _EQ_BANDS:
        if hi >= nyq:
            continue  # butter() needs 0 < Wn < 1; band not representable here
        sos = sg.butter(order, [lo / nyq, hi / nyq], btype="bandpass", output="sos")
        audio = audio + amount * sg.sosfiltfilt(sos, audio)

    # ── 8. De-esser — tame harsh sibilants without dulling the voice ─────────
    print("  [4/7] De-esser …")
    if 9000 < nyq:
        audio = de_esser(audio, target_rate,
                         freq_low=5500, freq_high=9000,
                         threshold=0.12, ratio=4.0)

    # ── 9. Gentle low-pass @ 16 kHz — only if below Nyquist ─────────────────
    if 16000.0 < nyq:
        sos_lp = sg.butter(6, 16000.0 / nyq, btype="low", output="sos")
        audio  = sg.sosfiltfilt(sos_lp, audio)

    return audio


def enhance_audio(audio, src_rate, target_rate):
    """
    Resample synthesized speech to ``target_rate`` and run the enhancement chain.

    With scipy, the chain is:

    * polyphase resampling, DC removal, 50/60/100/120 Hz notches and a 90 Hz
      high-pass;
    * spectral noise reduction and a noise gate;
    * band EQ, de-essing and a 16 kHz low-pass (any stage whose band reaches
      Nyquist at ``target_rate`` is skipped).

    Without scipy, only FFT resampling (any rate ratio) and DC removal run.
    Both paths then compress, apply a tanh soft limiter and peak-normalise
    to about -0.3 dBFS.

    Args:
        audio: 1-D float samples, nominally in [-1, 1].
        src_rate: Sample rate of ``audio`` in Hz.
        target_rate: Output sample rate in Hz.

    Returns:
        1-D float array at ``target_rate`` (empty if ``audio`` is empty).
    """
    if len(audio) == 0:
        return np.zeros(0)

    # Keep the try body to the import only, so unrelated ImportErrors inside
    # the pipeline cannot silently switch to the fallback path.
    try:
        from scipy import signal as sg
    except ImportError:
        sg = None

    if sg is not None:
        audio = _restore_with_scipy(audio, src_rate, target_rate, sg)
    else:
        print("scipy not found — using numpy-only processing")
        audio = _fft_resample(audio, src_rate, target_rate)
        audio -= audio.mean() if len(audio) else 0.0

    if len(audio) == 0:
        return audio

    # ── 10. Smooth downward compressor ───────────────────────────────────────
    print("  [5/7] Compression …")
    audio = smooth_compress(audio, target_rate,
                            threshold=0.22, ratio=4.5, makeup=1.45,
                            attack_ms=3, release_ms=60)

    # ── 11. Safety tanh soft-limiter — no hard clipping ──────────────────────
    print("  [6/7] Soft limiting …")
    audio = np.tanh(audio * 0.90) / 0.90

    # ── 12. Peak-normalize to −0.3 dBFS ─────────────────────────────────────
    print("  [7/7] Peak normalization …")
    peak = np.max(np.abs(audio))
    if peak > 0:
        audio *= 0.966 / peak      # 10 ** (-0.3 / 20) ≈ 0.966

    return audio


audio = enhance_audio(audio, src_rate, TARGET_RATE)

# ── Save ──────────────────────────────────────────────────────────────────────
# Convert to 16-bit PCM:
#   * round to the nearest step (astype alone truncates toward zero);
#   * clip so any out-of-range sample saturates instead of wrapping around;
#   * force little-endian ("<i2"), the byte order WAV files use.
audio_int16 = np.clip(np.round(audio * 32767), -32768, 32767).astype("<i2")
with wave.open(OUTPUT_FILE, "wb") as wf:
    wf.setnchannels(1)
    wf.setsampwidth(2)
    wf.setframerate(TARGET_RATE)
    wf.writeframes(audio_int16.tobytes())

duration = len(audio_int16) / TARGET_RATE
print(f"Done: {OUTPUT_FILE}  ({duration:.1f}s  @{TARGET_RATE} Hz)")
