import os
import wave
import subprocess
import tempfile
import numpy as np

# Piper's phonemizer needs espeak-ng's data directory. Default it to the copy
# bundled with Piper, but never override a value already set in the
# environment. Note: this is a relative path, so it presumably resolves
# against the current working directory.
os.environ.setdefault("ESPEAK_DATA_PATH", "espeak-ng")

# ── Config ────────────────────────────────────────────────────────────────────
MODEL_PATH = "en_US-hfc_female-medium.onnx"  # Piper voice model
OUTPUT_FILE = "tts_audio.wav"                # where the final WAV is written
PIPER_EXE = "piper"                          # CLI used if piper-tts can't be imported
TARGET_RATE = 44100                          # output sample rate in Hz (audio is resampled to this)

# Text to speak. "U S F D A" is deliberately spaced out, presumably so the
# voice reads the acronym letter by letter.
_TEXT_PARTS = [
    "The presentation outlines a global corrective action plan initiated by",
    "Sun Pharmaceutical in response to a U S F D A observation at its Dadra",
    "site, where multiple unaddressed deviations related to missing GMP",
    "documents were found.",
]
text = " ".join(_TEXT_PARTS)

# ── Synthesis ─────────────────────────────────────────────────────────────────
def synthesize_with_python_api(text, model_path, syn_params):
    """Synthesize *text* in-process with the piper-tts Python package (preferred).

    Args:
        text: the text to speak.
        model_path: path to the Piper voice model (.onnx).
        syn_params: dict with ``length_scale``, ``noise_scale`` and
            ``noise_w_scale`` entries, forwarded to ``SynthesisConfig``.

    Returns:
        ``(audio, sample_rate)``: mono float32 samples scaled from int16 into
        [-1, 1), and the sample rate reported by the synthesized chunks. If the
        voice yields no chunks at all (presumably e.g. for empty text), an
        empty float32 array and the placeholder rate 22050 are returned.

    Raises:
        ImportError: ``piper.voice`` (or a version of it exposing
            ``SynthesisConfig``) is not importable. The caller catches this to
            fall back to the piper CLI, so the import is kept inside the
            function and runs before anything else.
    """
    from piper.voice import PiperVoice, SynthesisConfig

    voice = PiperVoice.load(model_path)
    cfg = SynthesisConfig(
        length_scale=syn_params["length_scale"],
        noise_scale=syn_params["noise_scale"],
        noise_w_scale=syn_params["noise_w_scale"],
    )

    pieces = []
    src_rate = 22050  # placeholder; replaced by each chunk's reported rate
    for chunk in voice.synthesize(text, syn_config=cfg):
        pieces.append(np.frombuffer(chunk.audio_int16_bytes, dtype=np.int16))
        src_rate = chunk.sample_rate

    if not pieces:
        # np.concatenate([]) raises ValueError, so handle "no audio" explicitly.
        return np.zeros(0, dtype=np.float32), src_rate

    audio = np.concatenate(pieces).astype(np.float32) / np.float32(32768.0)
    return audio, src_rate


def synthesize_with_subprocess(text, model_path, piper_exe, syn_params):
    """Fall back to running the piper executable as a subprocess.

    *text* is fed to piper on stdin as UTF-8. Piper writes its WAV to a
    temporary file, which is always deleted again, even when piper fails or
    the WAV cannot be read.

    Args:
        text: the text to speak.
        model_path: path to the Piper voice model (.onnx).
        piper_exe: the piper executable to run.
        syn_params: dict with ``length_scale``, ``noise_scale`` and
            ``noise_w_scale`` entries, passed as CLI flags.

    Returns:
        ``(audio, sample_rate)``: mono float32 samples scaled from int16 into
        [-1, 1), and the WAV's frame rate.

    Raises:
        RuntimeError: piper exited non-zero, or produced a WAV that is not
            16-bit mono.
    """
    with tempfile.NamedTemporaryFile(suffix=".wav", delete=False) as tmp:
        wav_path = tmp.name  # closed immediately so piper can write to it

    cmd = [
        piper_exe,
        "--model", model_path,
        "--length_scale", str(syn_params["length_scale"]),
        "--noise_scale",  str(syn_params["noise_scale"]),
        "--noise_w",      str(syn_params["noise_w_scale"]),
        "--output_file",  wav_path,
    ]
    try:
        # Explicit UTF-8: text=True would use the locale codec (e.g. cp1252 on
        # Windows) and mangle or reject non-ASCII input text.
        proc = subprocess.run(cmd, input=text, capture_output=True,
                              encoding="utf-8", errors="replace")
        if proc.returncode != 0:
            raise RuntimeError(f"piper.exe failed: {proc.stderr}")
        with wave.open(wav_path, "rb") as wf:
            if wf.getsampwidth() != 2 or wf.getnchannels() != 1:
                raise RuntimeError(
                    f"piper.exe produced unsupported WAV: {wf.getnchannels()} "
                    f"channel(s), {8 * wf.getsampwidth()}-bit samples "
                    "(expected 16-bit mono)"
                )
            src_rate = wf.getframerate()
            raw = wf.readframes(wf.getnframes())
    finally:
        try:
            os.unlink(wav_path)
        except OSError:
            pass  # best-effort cleanup; don't mask the original error

    # WAV PCM is little-endian regardless of host byte order.
    audio = np.frombuffer(raw, dtype="<i2").astype(np.float32) / np.float32(32768.0)
    return audio, src_rate


syn_params = {"length_scale": 1.2, "noise_scale": 0.35, "noise_w_scale": 0.5}

# Prefer the in-process piper-tts API. If it cannot be imported, run the
# piper executable instead.
try:
    audio, src_rate = synthesize_with_python_api(text, MODEL_PATH, syn_params)
except ImportError:
    audio, src_rate = synthesize_with_subprocess(text, MODEL_PATH, PIPER_EXE, syn_params)
    print("Synthesis: piper.exe subprocess (piper-tts not found in this Python)")
else:
    print("Synthesis: piper-tts Python API")

# ── Audio enhancement ─────────────────────────────────────────────────────────
def _fft_resample(audio, src_rate, target_rate):
    """Resample 1-D *audio* by zero-padding/truncating its real FFT (numpy only).

    Works for any rate ratio, not just integer upsampling factors. Nyquist-bin
    handling follows the convention used by ``scipy.signal.resample``.
    """
    n = len(audio)
    m = int(round(n * target_rate / src_rate))
    if n == 0 or m == 0:
        return np.zeros(0)
    if m == n:
        return audio.copy()
    spec = np.fft.rfft(audio)
    out = np.zeros(m // 2 + 1, dtype=complex)
    k = min(n, m)
    out[: k // 2 + 1] = spec[: k // 2 + 1]
    if k % 2 == 0:
        # Bin k/2 carries the shared +/-k/2 component. Split it when
        # upsampling; add the two folded halves when downsampling.
        out[k // 2] *= 0.5 if m > n else 2.0
    # irfft normalises by 1/m; rescale by m/n so amplitudes are preserved.
    return np.fft.irfft(out, m) * (m / n)


def _scipy_chain(sg, audio, src_rate, target_rate):
    """Polyphase-resample with scipy, remove DC, and apply the speech EQ.

    An EQ stage whose cut-off is at or above the output Nyquist frequency is
    skipped, because ``scipy.signal.butter`` rejects normalised cut-offs >= 1
    (e.g. the 9 kHz low-pass at a 16 kHz target rate).
    """
    # 1. High-quality polyphase resample (scipy reduces up/down by their gcd)
    audio = sg.resample_poly(audio, target_rate, src_rate)
    nyq = target_rate / 2.0

    # 2. Remove DC offset
    audio = audio - audio.mean()

    # 3. High-pass at 80 Hz — removes rumble / DC drift
    if nyq > 80.0:
        sos_hp = sg.butter(6, 80.0 / nyq, btype="high", output="sos")
        audio = sg.sosfilt(sos_hp, audio)

    # 4. Presence boost (2–5 kHz, band-passed copy mixed in at 12 %) — clarity
    if nyq > 5000.0:
        sos_bp = sg.butter(2, [2000.0 / nyq, 5000.0 / nyq], btype="bandpass", output="sos")
        audio = audio + sg.sosfilt(sos_bp, audio) * 0.12

    # 5. Gentle low-pass at 9 kHz — smooths harsh TTS artefacts
    if nyq > 9000.0:
        sos_lp = sg.butter(4, 9000.0 / nyq, btype="low", output="sos")
        audio = sg.sosfilt(sos_lp, audio)
    return audio


def enhance_audio(audio, src_rate, target_rate):
    """Resample synthesized speech to *target_rate* and apply light mastering.

    The steps depend on whether scipy is available:

    * With scipy: polyphase resample, DC removal, 80 Hz high-pass, 2–5 kHz
      presence boost, 9 kHz low-pass. Each filter is applied only if its
      cut-off is below the output Nyquist frequency.
    * Without scipy: FFT resample (any rate ratio) and DC removal.

    In both cases the result is then peak-normalised to −0.5 dBFS.

    Args:
        audio: 1-D samples, nominally in [-1, 1]. The caller's array is not
            modified.
        src_rate: sample rate of *audio* in Hz.
        target_rate: desired output sample rate in Hz.

    Returns:
        A 1-D float64 array at *target_rate* whose peak magnitude is 0.944.
        Silent input is returned without scaling. Empty input gives an
        empty array.
    """
    try:
        from scipy import signal as sg
    except ImportError:
        sg = None

    audio = np.asarray(audio, dtype=np.float64)
    if audio.size == 0:
        return np.zeros(0)  # nothing to process; np.max would raise below

    if sg is not None:
        audio = _scipy_chain(sg, audio, src_rate, target_rate)
    else:
        # scipy not available — numpy-only fallback (FFT resample + normalise)
        print("scipy not found — using numpy-only processing")
        audio = _fft_resample(audio, src_rate, target_rate)
        if audio.size == 0:
            return audio
        audio = audio - audio.mean()

    # 6. Peak-normalise to −0.5 dBFS (10 ** (-0.5 / 20) ≈ 0.944)
    peak = np.max(np.abs(audio))
    if peak > 0:
        audio = audio * (0.944 / peak)
    return audio


audio = enhance_audio(audio, src_rate, TARGET_RATE)

# ── Save ──────────────────────────────────────────────────────────────────────
# Convert to 16-bit PCM:
# - round to the nearest code, because a bare astype() truncates toward zero;
# - clip to guard against any sample outside [-1, 1];
# - use explicit little-endian int16, because WAV PCM data is little-endian
#   regardless of host byte order.
audio_int16 = np.clip(np.round(audio * 32767), -32768, 32767).astype("<i2")
with wave.open(OUTPUT_FILE, "wb") as wf:
    wf.setnchannels(1)   # mono
    wf.setsampwidth(2)   # 16-bit samples
    wf.setframerate(TARGET_RATE)
    wf.writeframes(audio_int16.tobytes())

duration = len(audio_int16) / TARGET_RATE
print(f"Done: {OUTPUT_FILE}  ({duration:.1f}s  @{TARGET_RATE} Hz)")
