import os
import sys
import wave
import subprocess
import tempfile
import numpy as np

# Required: point to espeak-ng data bundled with Piper
# (Piper phonemizes via espeak-ng; setdefault lets an existing env var win.)
os.environ.setdefault("ESPEAK_DATA_PATH", r"espeak-ng")

# ── Config ────────────────────────────────────────────────────────────────────
# Indian female voices from the L2-ARCTIC corpus (Hindi L1 speakers):
#   SVBI → speaker_id=2  (Indian female)
#   TNI  → speaker_id=9  (Indian female)
MODEL_PATH  = "en_US-l2arctic-medium.onnx"  # multi-speaker Piper ONNX voice
SPEAKER_ID  = 9        # 2 = SVBI (Indian female)  |  9 = TNI (Indian female)
OUTPUT_FILE = "artict_tts_audio.wav"  # NOTE(review): "artict" looks like a typo for "arctic" — confirm before renaming
PIPER_EXE   = r"piper"  # piper binary (used only by the subprocess fallback)
TARGET_RATE = 44100   # upsample output to 44.1 kHz

# Input passage to synthesize (single paragraph; "U S F D A" is spelled out
# so the voice reads the initialism letter by letter).
text = (
    "The presentation outlines a global corrective action plan initiated by Sun Pharmaceutical in response to a U S F D A  observation at its Dadra site, where multiple unaddressed deviations related to missing GMP documents were found. The core issue was inadequate documentation practices and limited training, which prompted the development of a comprehensive training module on documentation control. This module will be extended to all departments across all US supply sites, covering topics like document handling, error correction, and logbook maintenance. The training will be delivered through instructor-led sessions and online modules, with a completion timeline of 240 days and oversight by Corporate Quality to ensure compliance and consistency."
)

# ── Synthesis ─────────────────────────────────────────────────────────────────
def synthesize_with_python_api(text, model_path, syn_params, speaker_id=None):
    """Synthesize *text* with the piper-tts Python package (preferred path).

    Args:
        text: Text to speak.
        model_path: Path to the Piper ONNX voice model.
        syn_params: Dict with "length_scale", "noise_scale", "noise_w_scale".
        speaker_id: Optional speaker index for multi-speaker models.

    Returns:
        (audio, src_rate): float32 mono samples scaled to [-1, 1) and the
        model's native sample rate in Hz.

    Raises:
        ImportError: if piper-tts is not installed (caller falls back to the
        subprocess path).
    """
    from piper.voice import PiperVoice, SynthesisConfig

    voice = PiperVoice.load(model_path)
    cfg = SynthesisConfig(
        speaker_id=speaker_id,
        length_scale=syn_params["length_scale"],
        noise_scale=syn_params["noise_scale"],
        noise_w_scale=syn_params["noise_w_scale"],
    )
    chunks = []
    src_rate = 22050  # fallback default; overwritten by each chunk below
    for chunk in voice.synthesize(text, syn_config=cfg):
        pcm = np.frombuffer(chunk.audio_int16_bytes, dtype=np.int16).astype(np.float32)
        chunks.append(pcm)
        src_rate = chunk.sample_rate
    if not chunks:
        # Empty input text yields no chunks; np.concatenate([]) would raise
        # ValueError, so return silence instead.
        return np.zeros(0, dtype=np.float32), src_rate
    audio = np.concatenate(chunks) / 32768.0
    return audio, src_rate


def synthesize_with_subprocess(text, model_path, piper_exe, syn_params, speaker_id=None):
    """Synthesize *text* by piping it to the standalone piper binary.

    Fallback for interpreters without the piper-tts package. Writes piper's
    output to a temp WAV, reads it back, and always deletes the temp file
    (the previous version leaked it whenever piper failed).

    Args:
        text: Text to speak (sent to piper's stdin).
        model_path: Path to the Piper ONNX voice model.
        piper_exe: Path or name of the piper executable.
        syn_params: Dict with "length_scale", "noise_scale", "noise_w_scale".
        speaker_id: Optional speaker index for multi-speaker models.

    Returns:
        (audio, src_rate): float32 mono samples in [-1, 1) and the WAV's rate.

    Raises:
        RuntimeError: if piper exits non-zero (stderr included in message).
    """
    tmp = tempfile.NamedTemporaryFile(suffix=".wav", delete=False)
    tmp.close()
    cmd = [
        piper_exe,
        "--model",        model_path,
        "--length_scale", str(syn_params["length_scale"]),
        "--noise_scale",  str(syn_params["noise_scale"]),
        "--noise_w",      str(syn_params["noise_w_scale"]),
        "--output_file",  tmp.name,
    ]
    if speaker_id is not None:
        cmd += ["--speaker", str(speaker_id)]
    try:
        # List argv + shell=False: text goes to stdin, no quoting pitfalls.
        proc = subprocess.run(cmd, input=text, capture_output=True, text=True)
        if proc.returncode != 0:
            raise RuntimeError(f"piper failed: {proc.stderr}")
        with wave.open(tmp.name, "rb") as wf:
            src_rate = wf.getframerate()
            raw = wf.readframes(wf.getnframes())
    finally:
        # Clean up the temp WAV on every path, including piper failure.
        try:
            os.unlink(tmp.name)
        except OSError:
            pass
    audio = np.frombuffer(raw, dtype=np.int16).astype(np.float32) / 32768.0
    return audio, src_rate


# Lower noise_scale → crisper consonants; length_scale 1.0 → natural pace.
# noise_w_scale presumably controls duration/prosody jitter — confirm in
# the Piper SynthesisConfig docs.
syn_params = dict(length_scale=1.0, noise_scale=0.25, noise_w_scale=0.4)

# Prefer the in-process Python API; if piper-tts is not importable in this
# interpreter, fall back to shelling out to the piper binary.
try:
    audio, src_rate = synthesize_with_python_api(text, MODEL_PATH, syn_params, speaker_id=SPEAKER_ID)
    print("Synthesis: piper-tts Python API")
except ImportError:
    audio, src_rate = synthesize_with_subprocess(text, MODEL_PATH, PIPER_EXE, syn_params, speaker_id=SPEAKER_ID)
    print("Synthesis: piper subprocess (piper-tts not found in this Python)")

# ── Audio enhancement ─────────────────────────────────────────────────────────
def soft_compress(audio, threshold=0.30, ratio=3.5, makeup=1.35):
    """
    Soft-knee downward compressor.

    Samples whose magnitude exceeds *threshold* have the excess divided by
    *ratio*, then everything is scaled by *makeup* gain. Raising quiet
    syllables relative to peaks closes the peak-to-average gap — the single
    biggest clarity booster for non-native TTS voices.
    """
    magnitude = np.abs(audio)
    over = magnitude > threshold
    # Target magnitude per sample: identity below the knee, compressed above.
    target_mag = np.where(over, threshold + (magnitude - threshold) / ratio, magnitude)
    # Turn target magnitudes into per-sample gains; substitute 1.0 for silent
    # samples so the division never sees 0/0 (their gain is 1 regardless).
    safe_mag = np.where(magnitude > 0.0, magnitude, 1.0)
    per_sample_gain = np.where(over, target_mag / safe_mag, 1.0)
    return audio * per_sample_gain * makeup


def enhance_audio(audio, src_rate, target_rate):
    """Resample float mono *audio* from src_rate to target_rate and apply a
    speech-clarity chain: EQ (scipy path only), soft-knee compression,
    tanh soft-clip, and peak normalization to ~-0.3 dBFS.

    Args:
        audio: 1-D float array in roughly [-1, 1].
        src_rate: Input sample rate in Hz.
        target_rate: Desired output sample rate in Hz.

    Returns:
        Processed 1-D float array at target_rate.
    """
    if len(audio) == 0:
        # Nothing to process — mean()/max() on empty arrays would raise/warn.
        return np.zeros(0, dtype=np.float32)
    try:
        from scipy import signal as sg

        # 1. High-quality polyphase resample (e.g. 22 050 → 44 100 Hz)
        audio = sg.resample_poly(audio, target_rate, src_rate)
        nyq   = target_rate / 2.0

        # 2. Remove DC offset
        audio -= audio.mean()

        # 3. High-pass at 80 Hz — removes low-frequency rumble
        sos_hp = sg.butter(6, 80.0 / nyq, btype="high", output="sos")
        audio  = sg.sosfilt(sos_hp, audio)

        # 4. Cut muddiness (300–600 Hz) — removes boxiness that blurs words
        sos_mud = sg.butter(2, [300.0 / nyq, 600.0 / nyq], btype="bandpass", output="sos")
        audio  -= sg.sosfilt(sos_mud, audio) * 0.10

        # 5. Strong presence boost (1–3.5 kHz) — primary speech intelligibility band;
        #    consonants (t, d, s, p, k) live here and define sharpness
        sos_pres = sg.butter(3, [1000.0 / nyq, 3500.0 / nyq], btype="bandpass", output="sos")
        audio   += sg.sosfilt(sos_pres, audio) * 0.30

        # 6. Brilliance boost (5–9 kHz) — adds air, attack and sharpness to sibilants
        sos_brill = sg.butter(2, [5000.0 / nyq, 9000.0 / nyq], btype="bandpass", output="sos")
        audio    += sg.sosfilt(sos_brill, audio) * 0.18

        # 7. Gentle low-pass at 13 kHz — removes aliasing without killing brilliance
        sos_lp = sg.butter(4, 13000.0 / nyq, btype="low", output="sos")
        audio  = sg.sosfilt(sos_lp, audio)

    except ImportError:
        # scipy not available — numpy-only FFT resample (no EQ in this path).
        # The previous integer-factor version (target_rate // src_rate)
        # silently produced the wrong rate for non-integer ratios such as
        # 16 000 → 44 100; computing the exact output length handles any ratio.
        print("scipy not found — using numpy-only processing")
        n_in    = len(audio)
        n_out   = int(round(n_in * target_rate / src_rate))
        fft     = np.fft.rfft(audio)
        n_bins  = n_out // 2 + 1
        fft_pad = np.zeros(n_bins, dtype=complex)
        keep    = min(len(fft), n_bins)
        # Scale by the length ratio so irfft's 1/N normalization preserves
        # the original amplitude.
        fft_pad[:keep] = fft[:keep] * (n_out / n_in)
        audio   = np.fft.irfft(fft_pad, n_out)
        audio  -= audio.mean()

    # 8. Soft-knee compression — evens out dynamics, raises quiet syllables
    audio = soft_compress(audio, threshold=0.30, ratio=3.5, makeup=1.35)

    # 9. Safety soft-clip (tanh) to handle any post-compression peaks cleanly
    audio = np.tanh(audio * 0.95) / 0.95

    # 10. Peak-normalise to −0.3 dB for maximum loudness without clipping
    peak = np.max(np.abs(audio))
    if peak > 0:
        audio *= 0.966 / peak
    return audio


# Run the full enhancement/resampling chain at the configured output rate.
audio = enhance_audio(audio, src_rate, TARGET_RATE)

# ── Save ──────────────────────────────────────────────────────────────────────
# Back to 16-bit PCM; enhance_audio peak-normalised to ~0.966, so *32767
# stays within int16 range.
audio_int16 = (audio * 32767).astype(np.int16)
with wave.open(OUTPUT_FILE, "wb") as wf:
    wf.setnchannels(1)  # mono
    wf.setsampwidth(2)  # 16-bit samples
    wf.setframerate(TARGET_RATE)
    wf.writeframes(audio_int16.tobytes())

duration = len(audio_int16) / TARGET_RATE
print(f"Done: {OUTPUT_FILE}  ({duration:.1f}s  @{TARGET_RATE} Hz)")
