import os
import sys
import wave
import subprocess
import tempfile
import numpy as np

# Required: point to espeak-ng data bundled with Piper. setdefault keeps any
# value the caller has already exported.
# NOTE(review): the path is relative, so it is resolved against the current
# working directory, not this script's directory.
os.environ.setdefault("ESPEAK_DATA_PATH", r"espeak-ng")

# ── Config ────────────────────────────────────────────────────────────────────
# Indian female voices from the L2-ARCTIC corpus (Hindi L1 speakers):
#   SVBI → speaker_id=2  (Indian female)
#   TNI  → speaker_id=9  (Indian female)
# NOTE(review): these speaker IDs refer to a multi-speaker L2-ARCTIC voice, but
# MODEL_PATH points at en_US-hfc_female, presumably a single-speaker model —
# confirm SPEAKER_ID is accepted by the model actually loaded.
MODEL_PATH  = "/home/nityansco/public_html/nscorm_app/script/en_US-hfc_female-medium.onnx"
#MODEL_PATH  = "/home/nityansco/public_html/nscorm_app/script/en_US-hfc_male-medium.onnx"
SPEAKER_ID  = 9        # 2 = SVBI (Indian female)  |  9 = TNI (Indian female)
OUTPUT_FILE = "female_artict_tts_audio.wav"
PIPER_EXE   = r"piper"
TARGET_RATE = 44100   # upsample output to 44.1 kHz

text = (
    "The presentation outlines a global corrective action plan initiated by "
    "Sun Pharmaceutical in response to a U S F D A observation at its Dadra "
    "site, where multiple unaddressed deviations related to missing GMP "
    "documents were found."
)

# Optional command-line overrides:  python <script> [TEXT [OUTPUT_FILE]]
# Without arguments the defaults above are used.
if len(sys.argv) > 1:
    text = sys.argv[1]
if len(sys.argv) > 2:
    OUTPUT_FILE = sys.argv[2]
# ── Synthesis ─────────────────────────────────────────────────────────────────
def synthesize_with_python_api(text, model_path, syn_params, speaker_id=None):
    """Use piper-tts Python package directly (preferred).

    Parameters
    ----------
    text : str
        Text to speak.
    model_path : str
        Path to the Piper ``.onnx`` voice model, loaded with ``PiperVoice.load``.
    syn_params : dict
        Must contain ``length_scale``, ``noise_scale`` and ``noise_w_scale``.
    speaker_id : int or None
        Passed unchanged to piper's ``SynthesisConfig``.

    Returns
    -------
    (numpy.ndarray, int)
        float32 samples scaled to [-1, 1) (all chunks concatenated into one
        flat array) and the sample rate reported by the last chunk. If piper
        yields no chunks (e.g. empty text), the array is empty and the rate
        is the 22050 Hz default.

    Raises
    ------
    ImportError
        If ``piper.voice`` or its ``SynthesisConfig`` cannot be imported; the
        caller relies on this to fall back to the piper binary.
    """
    from piper.voice import PiperVoice, SynthesisConfig

    voice = PiperVoice.load(model_path)
    cfg = SynthesisConfig(
        speaker_id=speaker_id,
        length_scale=syn_params["length_scale"],
        noise_scale=syn_params["noise_scale"],
        noise_w_scale=syn_params["noise_w_scale"],
    )
    chunks = []
    src_rate = 22050  # only used if piper yields no chunks
    for chunk in voice.synthesize(text, syn_config=cfg):
        chunks.append(np.frombuffer(chunk.audio_int16_bytes, dtype=np.int16))
        src_rate = chunk.sample_rate
    if not chunks:
        # np.concatenate([]) raises ValueError; return an empty signal instead.
        return np.zeros(0, dtype=np.float32), src_rate
    audio = np.concatenate(chunks).astype(np.float32) / 32768.0
    return audio, src_rate


def synthesize_with_subprocess(text, model_path, piper_exe, syn_params, speaker_id=None):
    """Fall back to piper binary via subprocess.

    *text* is sent to piper on stdin as UTF-8. Piper writes a WAV file to a
    temporary path, which is read back and then always deleted — also when
    piper exits with an error or the WAV cannot be read — so failures leave no
    stray files behind.

    Parameters
    ----------
    text : str
        Text to speak.
    model_path : str
        Path to the Piper ``.onnx`` voice model.
    piper_exe : str
        Name or path of the piper executable.
    syn_params : dict
        Must contain ``length_scale``, ``noise_scale`` and ``noise_w_scale``.
    speaker_id : int or None
        Passed as ``--speaker`` when not None.

    Returns
    -------
    (numpy.ndarray, int)
        float32 samples scaled to [-1, 1) and the WAV file's sample rate.

    Raises
    ------
    RuntimeError
        If piper exits with a non-zero status (message includes its stderr).
    OSError
        If *piper_exe* cannot be launched.
    """
    fd, wav_path = tempfile.mkstemp(suffix=".wav")
    os.close(fd)  # piper opens the path itself; only the name is needed here
    cmd = [
        piper_exe,
        "--model",        model_path,
        "--length_scale", str(syn_params["length_scale"]),
        "--noise_scale",  str(syn_params["noise_scale"]),
        "--noise_w",      str(syn_params["noise_w_scale"]),
        "--output_file",  wav_path,
    ]
    if speaker_id is not None:
        cmd += ["--speaker", str(speaker_id)]
    try:
        # Explicit UTF-8 so non-ASCII text does not depend on the locale's
        # default encoding (e.g. cp1252 on Windows).
        proc = subprocess.run(cmd, input=text, capture_output=True,
                              text=True, encoding="utf-8", errors="replace")
        if proc.returncode != 0:
            raise RuntimeError(f"piper failed: {proc.stderr}")
        with wave.open(wav_path, "rb") as wf:
            src_rate = wf.getframerate()
            raw = wf.readframes(wf.getnframes())
    finally:
        os.unlink(wav_path)
    # WAV PCM samples are little-endian; "<i2" keeps decoding correct on any host.
    audio = np.frombuffer(raw, dtype="<i2").astype(np.float32) / 32768.0
    return audio, src_rate


# Synthesis knobs: a low noise_scale gives crisper consonants; length_scale sets
# the pace (1.0 is the natural pace, larger values are slower, so 1.3 speaks
# more slowly than natural).
syn_params = {"length_scale": 1.3, "noise_scale": 0.25, "noise_w_scale": 0.4}

# Prefer the in-process piper-tts API. An ImportError (package missing, or a
# version without SynthesisConfig) switches to the piper executable instead.
try:
    audio, src_rate = synthesize_with_python_api(
        text, MODEL_PATH, syn_params, speaker_id=SPEAKER_ID
    )
except ImportError:
    audio, src_rate = synthesize_with_subprocess(
        text, MODEL_PATH, PIPER_EXE, syn_params, speaker_id=SPEAKER_ID
    )
    print("Synthesis: piper subprocess (piper-tts not found in this Python)")
else:
    print("Synthesis: piper-tts Python API")

# ── Audio enhancement ─────────────────────────────────────────────────────────
def soft_compress(audio, threshold=0.30, ratio=3.5, makeup=1.35):
    """
    Instantaneous hard-knee downward compressor with make-up gain.

    For every sample whose magnitude exceeds *threshold*, the part above the
    threshold is divided by *ratio*; samples at or below the threshold pass
    unchanged. The whole result is then multiplied by *makeup*, which lifts
    quiet consonants relative to the reduced peaks. Gain is computed per sample
    with no attack/release envelope, so this acts as a static waveshaper rather
    than an envelope-following compressor, and the knee is hard (not soft).

    Parameters
    ----------
    audio : array_like
        Signal samples (the default threshold assumes roughly [-1, 1] full
        scale). Non-floating input (e.g. int16) is converted to float64 first.
    threshold : float
        Magnitude above which gain reduction starts.
    ratio : float
        Compression ratio applied to the part above *threshold*.
    makeup : float
        Linear gain applied to the whole output.

    Returns
    -------
    numpy.ndarray
        Compressed signal; floating-point input keeps its dtype.
    """
    audio = np.asarray(audio)
    if not np.issubdtype(audio.dtype, np.inexact):
        # With an integer array, np.ones_like would yield an integer gain and
        # every reduced gain (< 1) would be truncated to 0.
        audio = audio.astype(np.float64)
    magnitude = np.abs(audio)
    gain = np.ones_like(audio)
    above = magnitude > threshold
    if above.any():
        excess = magnitude[above] - threshold
        gain[above] = (threshold + excess / ratio) / magnitude[above]
    return audio * gain * makeup


def _resample_fft(audio, src_rate, target_rate):
    """Resample *audio* to *target_rate* by resizing its spectrum (numpy-only fallback).

    The output has round(len(audio) * target_rate / src_rate) samples (at least
    one), so non-integer rate ratios and downsampling keep the duration right.
    Spectrum bins above the new Nyquist frequency are dropped. *audio* must be
    non-empty.
    """
    n = len(audio)
    m = max(1, int(round(n * target_rate / src_rate)))
    spectrum = np.fft.rfft(audio)
    resized = np.zeros(m // 2 + 1, dtype=complex)
    keep = min(len(spectrum), len(resized))
    resized[:keep] = spectrum[:keep]
    # irfft normalises by the output length m; rescale by m / n to keep amplitude.
    return np.fft.irfft(resized, m) * (m / n)


def _equalize(audio, sg, rate):
    """Apply the speech-clarity filter chain to *audio* sampled at *rate* Hz.

    *sg* is the ``scipy.signal`` module. A stage is skipped when its highest
    corner frequency is not below Nyquist, because ``butter`` rejects
    normalised cutoffs >= 1 (e.g. the 13 kHz low-pass at a 22.05 kHz rate).
    """
    nyq = rate / 2.0

    # High-pass at 80 Hz — removes low-frequency rumble
    if 80.0 < nyq:
        sos_hp = sg.butter(6, 80.0 / nyq, btype="high", output="sos")
        audio = sg.sosfilt(sos_hp, audio)

    # Cut muddiness (300–600 Hz) — removes boxiness that blurs words
    if 600.0 < nyq:
        sos_mud = sg.butter(2, [300.0 / nyq, 600.0 / nyq], btype="bandpass", output="sos")
        audio = audio - sg.sosfilt(sos_mud, audio) * 0.10

    # Presence boost (1–3.5 kHz) — primary speech intelligibility band;
    # consonants (t, d, s, p, k) live here and define sharpness
    if 3500.0 < nyq:
        sos_pres = sg.butter(3, [1000.0 / nyq, 3500.0 / nyq], btype="bandpass", output="sos")
        audio = audio + sg.sosfilt(sos_pres, audio) * 0.30

    # Brilliance boost (5–9 kHz) — adds air, attack and sharpness to sibilants
    if 9000.0 < nyq:
        sos_brill = sg.butter(2, [5000.0 / nyq, 9000.0 / nyq], btype="bandpass", output="sos")
        audio = audio + sg.sosfilt(sos_brill, audio) * 0.18

    # Gentle low-pass at 13 kHz — trims residual content above the speech band
    if 13000.0 < nyq:
        sos_lp = sg.butter(4, 13000.0 / nyq, btype="low", output="sos")
        audio = sg.sosfilt(sos_lp, audio)

    return audio


def enhance_audio(audio, src_rate, target_rate):
    """Resample synthesized speech to *target_rate* and process it for clarity.

    Pipeline: resample → remove DC offset → EQ chain (scipy only) → hard-knee
    compression → tanh soft-clip → peak-normalise to −0.3 dB.

    Parameters
    ----------
    audio : array_like
        Samples at *src_rate* Hz (presumably mono in roughly [-1, 1]).
    src_rate, target_rate : int
        Input and output sample rates in Hz. Any ratio is supported; scipy's
        polyphase resampler requires integer values.

    Returns
    -------
    numpy.ndarray
        float64 samples at *target_rate* with peak magnitude 0.966 (−0.3 dB).
        An all-zero signal stays zero; empty input gives an empty array.
    """
    audio = np.asarray(audio, dtype=np.float64)
    if audio.size == 0:
        return audio

    # Keep the try limited to the import, so an ImportError raised during
    # processing cannot re-run the fallback on half-processed audio.
    try:
        from scipy import signal as sg
    except ImportError:
        sg = None

    if sg is not None:
        # High-quality polyphase resample (e.g. 22 050 → 44 100 Hz)
        audio = sg.resample_poly(audio, target_rate, src_rate)
    else:
        print("scipy not found — using numpy-only processing")
        audio = _resample_fft(audio, src_rate, target_rate)

    # Remove DC offset
    audio = audio - audio.mean()

    if sg is not None:
        audio = _equalize(audio, sg, target_rate)

    # Compression — evens out dynamics, raises quiet syllables
    audio = soft_compress(audio, threshold=0.30, ratio=3.5, makeup=1.35)

    # Safety soft-clip (tanh) to handle any post-compression peaks cleanly
    audio = np.tanh(audio * 0.95) / 0.95

    # Peak-normalise to −0.3 dB (0.966 ≈ 10 ** (-0.3 / 20))
    peak = np.max(np.abs(audio))
    if peak > 0:
        audio = audio * (0.966 / peak)
    return audio


audio = enhance_audio(audio, src_rate, TARGET_RATE)

# ── Save ──────────────────────────────────────────────────────────────────────
# Round to the nearest int16 step (astype alone truncates toward zero, biasing
# every sample) and clip so no value can wrap around. WAV PCM is little-endian,
# hence the explicit "<i2" dtype.
audio_int16 = np.clip(np.round(audio * 32767), -32768, 32767).astype("<i2")
with wave.open(OUTPUT_FILE, "wb") as wf:
    wf.setnchannels(1)
    wf.setsampwidth(2)
    wf.setframerate(TARGET_RATE)
    wf.writeframes(audio_int16.tobytes())

duration = len(audio_int16) / TARGET_RATE
print(f"Done: {OUTPUT_FILE}  ({duration:.1f}s  @{TARGET_RATE} Hz)")
