import os
import sys
import re
import wave
import subprocess
import tempfile
import numpy as np

# Required: point to espeak-ng data bundled with Piper.
# setdefault() leaves an ESPEAK_DATA_PATH already present in the environment
# untouched.  The value is a relative path — presumably resolved against the
# current working directory; TODO confirm.
os.environ.setdefault("ESPEAK_DATA_PATH", r"espeak-ng")

# ── Config ────────────────────────────────────────────────────────────────────
# Indian female voices from the L2-ARCTIC corpus (Hindi L1 speakers):
#   SVBI → speaker_id=2  (Indian female)
#   TNI  → speaker_id=9  (Indian female)
# NOTE(review): MODEL_PATH names a "hi_IN-rohan" model while the speaker table
# above describes L2-ARCTIC speakers — confirm SPEAKER_ID is valid for this
# model file.
MODEL_PATH  = "hi_IN-rohan-medium.onnx"
SPEAKER_ID  = 9        # 2 = SVBI (Indian female)  |  9 = TNI (Indian female)
OUTPUT_FILE = "rohan_artict_tts_audio.wav"
PIPER_EXE   = r"piper"
TARGET_RATE = 44100   # upsample output to 44.1 kHz

# Narration script.  It is split into segments on whitespace that follows
# . , ; ! or ? by split_sentences() before synthesis.
text = (
    "The presentation outlines a global corrective action plan initiated by Sun Pharmaceutical in response to a U S F D A  observation at its Dadra site, where multiple unaddressed deviations related to missing GMP documents were found. The core issue was inadequate documentation practices and limited training, which prompted the development of a comprehensive training module on documentation control. This module will be extended to all departments across all US supply sites, covering topics like document handling, error correction, and logbook maintenance. The training will be delivered through instructor-led sessions and online modules, with a completion timeline of 240 days and oversight by Corporate Quality to ensure compliance and consistency."
)

# ── Text → sentence segments with pause durations ─────────────────────────────
def split_sentences(text):
    """
    Break `text` into a list of (segment, pause_ms) tuples.

    The text is cut at whitespace that follows one of . , ; ! ? and each
    segment keeps its trailing punctuation.  The pause that follows a
    segment depends on its final character:

      . ! ?            → 480 ms (full breath)
      , ;              → 250 ms (short breath)
      anything else    → 150 ms (natural continuation)

    The final piece of the split always gets a pause of 0 ms.
    """
    pause_by_ending = {
        '.': 480, '!': 480, '?': 480,
        ',': 250, ';': 250,
    }

    # Lookbehind keeps the punctuation attached to the preceding segment.
    pieces = re.split(r'(?<=[.,;!?])\s+', text.strip())
    final_index = len(pieces) - 1

    segments = []
    for index, raw in enumerate(pieces):
        segment = raw.strip()
        if segment:
            if index == final_index:
                pause_ms = 0                      # nothing follows the last piece
            else:
                pause_ms = pause_by_ending.get(segment[-1], 150)
            segments.append((segment, pause_ms))
    return segments


def _silence(rate, ms):
    """Return a float32 zeros array of given duration."""
    return np.zeros(int(rate * ms / 1000), dtype=np.float32)


# ── Synthesis ─────────────────────────────────────────────────────────────────
def synthesize_with_python_api(text, model_path, syn_params, speaker_id=None):
    """
    Synthesize `text` in-process through the piper Python package.

    The text is cut into segments by split_sentences(); every segment is
    synthesized separately and followed by its pause (silence) so the pauses
    are controlled here rather than by the model.

    Parameters
      text        : string to speak.
      model_path  : path handed to PiperVoice.load().
      syn_params  : mapping with keys "length_scale", "noise_scale",
                    "noise_w_scale".
      speaker_id  : forwarded to SynthesisConfig (None by default).

    Returns (audio, sample_rate): samples scaled by 1/32768 and the sample
    rate reported by the last synthesized chunk (22050 if no chunk was
    produced).  When nothing at all was produced, audio is np.zeros(1).

    The piper import happens inside the function so that a missing package
    surfaces as ImportError at call time.
    """
    from piper.voice import PiperVoice, SynthesisConfig

    voice = PiperVoice.load(model_path)
    config = SynthesisConfig(
        speaker_id=speaker_id,
        length_scale=syn_params["length_scale"],
        noise_scale=syn_params["noise_scale"],
        noise_w_scale=syn_params["noise_w_scale"],
    )

    sample_rate = 22050
    pieces = []
    for sentence, pause_ms in split_sentences(text):
        spoken = []
        for chunk in voice.synthesize(sentence, syn_config=config):
            samples = np.frombuffer(chunk.audio_int16_bytes, dtype=np.int16)
            spoken.append(samples.astype(np.float32))
            sample_rate = chunk.sample_rate
        if spoken:
            pieces.append(np.concatenate(spoken))
        if pause_ms > 0:
            # Everything in `pieces` is on the int16 scale until the final
            # division, so the pause is put on that scale as well.
            pieces.append(_silence(sample_rate, pause_ms) * 32768.0)

    if not pieces:
        return np.zeros(1), sample_rate
    return np.concatenate(pieces) / 32768.0, sample_rate


def synthesize_with_subprocess(text, model_path, piper_exe, syn_params,
                                speaker_id=None):
    """
    Fall back to the piper binary, invoked once per text segment.

    Each segment from split_sentences() is written to piper's stdin; piper
    renders it to a temporary WAV file which is read back, scaled by 1/32768
    and followed by the segment's pause.

    Parameters
      text        : string to speak.
      model_path  : value passed to piper's --model option.
      piper_exe   : executable name or path of the piper binary.
      syn_params  : mapping with keys "length_scale", "noise_scale",
                    "noise_w_scale".
      speaker_id  : when not None, passed as --speaker.

    Returns (audio, sample_rate): float32 samples and the frame rate of the
    last WAV file read (22050 if no segment was synthesized).  When there are
    no segments, audio is np.zeros(1).

    Raises RuntimeError if piper exits with a non-zero status.

    The temporary WAV is removed in a `finally` block, so it no longer leaks
    when piper fails or the file cannot be parsed.
    """
    segments  = split_sentences(text)
    all_audio = []
    src_rate  = 22050

    # Loop-invariant parts of the command line.
    options = [
        "--model",        model_path,
        "--length_scale", str(syn_params["length_scale"]),
        "--noise_scale",  str(syn_params["noise_scale"]),
        "--noise_w",      str(syn_params["noise_w_scale"]),
    ]
    speaker_args = ["--speaker", str(speaker_id)] if speaker_id is not None else []

    for sentence, pause_ms in segments:
        # delete=False + close(): piper needs to open the path itself.
        tmp = tempfile.NamedTemporaryFile(suffix=".wav", delete=False)
        tmp.close()
        try:
            cmd = [piper_exe, *options, "--output_file", tmp.name, *speaker_args]
            proc = subprocess.run(cmd, input=sentence, capture_output=True,
                                  text=True)
            if proc.returncode != 0:
                raise RuntimeError(f"piper failed: {proc.stderr}")
            with wave.open(tmp.name, "rb") as wf:
                src_rate = wf.getframerate()
                raw      = wf.readframes(wf.getnframes())
        finally:
            os.unlink(tmp.name)

        # Frames are decoded as 16-bit PCM — assumes piper writes 16-bit mono;
        # TODO confirm if other models/builds are used.
        seg = np.frombuffer(raw, dtype=np.int16).astype(np.float32) / 32768.0
        all_audio.append(seg)
        if pause_ms > 0:
            all_audio.append(_silence(src_rate, pause_ms))

    audio = np.concatenate(all_audio) if all_audio else np.zeros(1)
    return audio, src_rate


# Synthesis parameters shared by both back-ends.
# length_scale=1.35 → 35% slower speech; noise_scale=0.22 → clean, natural phonemes
syn_params = dict(length_scale=1.35, noise_scale=0.22, noise_w_scale=0.4)

# Prefer the in-process Python API; fall back to the piper executable only
# when importing the piper package fails.  Any other exception raised by the
# Python API path is not caught here and aborts the script.
try:
    audio, src_rate = synthesize_with_python_api(text, MODEL_PATH, syn_params, speaker_id=SPEAKER_ID)
    print("Synthesis: piper-tts Python API")
except ImportError:
    audio, src_rate = synthesize_with_subprocess(text, MODEL_PATH, PIPER_EXE, syn_params, speaker_id=SPEAKER_ID)
    print("Synthesis: piper subprocess (piper-tts not found in this Python)")

# ══════════════════════════════════════════════════════════════════════════════
#  BROADCAST-QUALITY AUDIO PROCESSING CHAIN
#  Stage 1 : Upsample          → 22 050 Hz → 44 100 Hz
#  Stage 2 : Pre-processing    → DC remove, high-pass 80 Hz
#  Stage 3 : Noise gate        → suppress inter-word digital artifacts
#  Stage 4 : Parametric EQ     → shape vocal frequency response
#  Stage 5 : Harmonic exciter  → add warmth & natural upper harmonics
#  Stage 6 : De-esser          → tame harsh sibilants (s, sh, ch)
#  Stage 7 : Multi-band comp.  → 3-band musical compression
#  Stage 8 : Bus compression   → final "glue" and punch
#  Stage 9 : LUFS normalise    → −16 LUFS streaming standard
#  Stage 10: True-peak limiter → −0.3 dBTP brick-wall ceiling
# ══════════════════════════════════════════════════════════════════════════════

def _smooth_env(signal_abs, rate, ms):
    """Compute a smoothed RMS envelope over `ms` milliseconds."""
    win = max(1, int(rate * ms / 1000))
    kernel = np.ones(win) / win
    return np.sqrt(np.convolve(signal_abs ** 2, kernel, mode="same"))


def _soft_compress(audio, threshold, ratio, makeup):
    """Sample-by-sample soft-knee downward compressor."""
    abs_a = np.abs(audio)
    gain  = np.ones_like(audio)
    above = abs_a > threshold
    if above.any():
        excess       = abs_a[above] - threshold
        gain[above]  = (threshold + excess / ratio) / abs_a[above]
    return audio * gain * makeup


def noise_gate(audio, rate, threshold=0.018, attack_ms=5, release_ms=80,
               floor_db=-40):
    """
    Attenuate low-level passages (the TTS noise floor between words).

    A 10 ms RMS envelope decides, per sample, whether the gate should be
    open (gain 1.0) or closed (gain 10**(floor_db/20)).  The applied gain
    follows that target through a one-pole smoother that uses the attack
    coefficient while rising and the release coefficient while falling, so
    the gate never switches abruptly.  The gain starts at 0.0.

    Returns `audio` multiplied by the per-sample gate gain.
    """
    envelope = _smooth_env(audio, rate, ms=10)    # 10 ms RMS window
    closed_gain = 10 ** (floor_db / 20)

    # One-pole smoothing coefficients for the two time constants.
    attack = np.exp(-1.0 / (rate * attack_ms / 1000))
    release = np.exp(-1.0 / (rate * release_ms / 1000))

    # Per-sample target gain: fully open above threshold, floor otherwise.
    targets = np.where(envelope > threshold, 1.0, closed_gain)

    gains = np.zeros(len(audio))
    level = 0.0
    for idx, goal in enumerate(targets):
        coef = attack if goal > level else release
        level = coef * level + (1 - coef) * goal
        gains[idx] = level

    return audio * gains


def parametric_eq(audio, rate):
    """
    Apply a fixed EQ curve by scaling the signal's spectrum.

    The whole signal is transformed with a real FFT, every bin is multiplied
    by a real-valued gain, and the result is transformed back.  Because the
    gains are real, the phase of each bin is left untouched.

    The gain curve is defined by the (Hz, dB) anchor table below, linearly
    interpolated onto the FFT bin frequencies and then smoothed with a
    Gaussian roughly 50 Hz wide (never fewer than 2 bins) to round off the
    corners between anchors.  Bins above the last anchor take its gain.

    In summary the curve blocks everything below 80 Hz, cuts the 150–700 Hz
    region deeply (strongest at 260 Hz, −17 dB), lifts 1.8–15 kHz (strongest
    at 4 kHz, +9 dB) and rolls off towards 22 050 Hz.

    Returns a float32 array with the same number of samples as `audio`.
    """
    from scipy.ndimage import gaussian_filter1d

    n_samples = len(audio)
    bin_hz = np.fft.rfftfreq(n_samples, 1.0 / rate)
    spectrum = np.fft.rfft(audio.astype(np.float64))

    # (frequency in Hz, gain in dB)
    curve = (
        (0,      -30),   # DC / rumble block
        (80,     -30),   # sub-bass block
        (110,     -2),   # enter voice range
        (150,    -10),   # low-body cut begins
        (200,    -16),   # deep cut
        (260,    -17),   # deepest point of the cut
        (350,    -14),   # cut tapering
        (500,    -12),   # mud
        (700,     -8),   # mud taper
        (1200,    -2),   # approaching flat
        (1800,    +2),   # presence rise
        (2500,    +5),   # presence
        (4000,    +9),   # clarity peak
        (6000,    +7),   # upper clarity
        (9000,    +8),   # air boost
        (13000,   +5),   # air taper
        (15000,   +2),   # brilliance rolloff
        (22050,  -20),   # Nyquist ceiling (for a 44.1 kHz rate)
    )
    anchor_hz, anchor_db = (np.array(column, dtype=np.float64)
                            for column in zip(*curve))

    curve_db = np.interp(bin_hz, anchor_hz, anchor_db)

    # Gaussian width expressed in FFT bins (~50 Hz, at least 2 bins).
    bin_width_hz = rate / float(n_samples)
    sigma_bins = max(2, int(50.0 / bin_width_hz))
    curve_db = gaussian_filter1d(curve_db, sigma=sigma_bins)

    linear_gain = 10.0 ** (curve_db / 20.0)
    shaped = spectrum * linear_gain
    return np.fft.irfft(shaped, n_samples).astype(np.float32)


def harmonic_exciter(audio, rate, drive=2.2, mix=0.12):
    """
    Blend a small amount of synthesized high-frequency content into `audio`.

    Steps:
      1. Isolate the 3–7 kHz band with a 2nd-order Butterworth band-pass.
      2. Push that band through tanh(band * drive) and subtract the band
         itself, leaving only what the waveshaper added or changed.
      3. High-pass the difference at 5 kHz (2nd-order Butterworth).
      4. Add the result to the dry signal scaled by `mix`.

    `drive` sets how hard the band is pushed into the tanh curve; `mix` is
    the linear gain of the generated content in the output.
    """
    from scipy import signal as sg

    nyquist = rate / 2.0

    # 1. Excitation band
    band_sos = sg.butter(2, [3000.0 / nyquist, 7000.0 / nyquist],
                         btype="bandpass", output="sos")
    source_band = sg.sosfilt(band_sos, audio)

    # 2. Waveshaper output minus its own input
    added = np.tanh(source_band * drive) - source_band

    # 3. Keep only the part above 5 kHz
    highpass_sos = sg.butter(2, 5000.0 / nyquist, btype="high", output="sos")
    added = sg.sosfilt(highpass_sos, added)

    # 4. Mix under the dry signal
    return audio + added * mix


def de_esser(audio, rate, lo=6000, hi=10000, threshold=0.10, ratio=3.5):
    """
    Compress only the `lo`–`hi` Hz band (6–10 kHz by default) of `audio`.

    The band is isolated with a 2nd-order Butterworth band-pass and its 3 ms
    RMS envelope is measured.  Wherever the envelope exceeds `threshold`,
    the band's gain is lowered so the excess is divided by `ratio`.  The
    gain curve is averaged over 5 ms and limited to the range 0.15–1.0
    before the original band is swapped for the gain-scaled band.
    """
    from scipy import signal as sg

    nyquist = rate / 2.0
    sib_sos = sg.butter(2, [lo / nyquist, hi / nyquist],
                        btype="bandpass", output="sos")
    sibilance = sg.sosfilt(sib_sos, audio)

    # Level detector: 3 ms envelope of the isolated band
    level = _smooth_env(sibilance, rate, ms=3)

    # Static gain computer (unity wherever the level is at or below threshold)
    reduction = np.ones_like(level)
    hot = level > threshold
    reduction[hot] = (threshold + (level[hot] - threshold) / ratio) / level[hot]

    # 5 ms moving average so the gain does not jump between samples
    taps = max(1, int(rate * 0.005))
    reduction = np.convolve(reduction, np.ones(taps) / taps, mode="same")
    reduction = np.clip(reduction, 0.15, 1.0)

    # Remove the band, then add it back with the reduced gain
    return audio - sibilance + sibilance * reduction


def multiband_compress(audio, rate):
    """
    Compress `audio` in three bands with separate settings and re-sum them.

    Band edges are 600 Hz and 5 kHz.  The low band comes from a 4th-order
    Butterworth low-pass, the high band from a 4th-order Butterworth
    high-pass, and the mid band is whatever remains after subtracting both
    from the input.

      band   threshold   ratio   makeup
      low      0.22       3:1     0.95
      mid      0.18       4:1     1.25
      high     0.20       2:1     1.12
    """
    from scipy import signal as sg

    nyquist = rate / 2.0
    low_sos = sg.butter(4, 600.0 / nyquist, btype="low", output="sos")
    high_sos = sg.butter(4, 5000.0 / nyquist, btype="high", output="sos")

    lows = sg.sosfilt(low_sos, audio)
    highs = sg.sosfilt(high_sos, audio)
    mids = audio - lows - highs                   # residual between the two

    # Low: makeup below 1.0, so this band is not brought back up afterwards.
    lows = _soft_compress(lows, threshold=0.22, ratio=3.0, makeup=0.95)
    # Mid: strongest ratio and largest makeup of the three bands.
    mids = _soft_compress(mids, threshold=0.18, ratio=4.0, makeup=1.25)
    # High: gentlest ratio.
    highs = _soft_compress(highs, threshold=0.20, ratio=2.0, makeup=1.12)

    total = lows + mids
    return total + highs


def lufs_normalize(audio, rate, target_lufs=-16.0):
    """
    Scale `audio` towards `target_lufs` using a simplified loudness measure.

    This is an approximation of EBU R128, not a compliant meter:
      * K-weighting is approximated by adding 0.26 × a 1st-order 1.5 kHz
        high-pass to the signal, followed by a 2nd-order ~38 Hz high-pass.
      * Mean-square power is measured over 400 ms blocks with a 100 ms hop.
      * Only a relative gate (−10 LU below the ungated mean) is applied.

    Parameters
      audio       : 1-D array of samples.
      rate        : sample rate in Hz.
      target_lufs : loudness to aim for.

    Returns `audio` multiplied by a single gain factor, capped at 5.0
    (about +14 dB).  Empty or completely silent input is returned unchanged.

    Block selection covers every full 400 ms block, including the one that
    ends exactly at the last sample, and audio no longer than one block is
    measured as a single block instead of being returned un-normalised.
    """
    from scipy import signal as sg

    if len(audio) == 0:
        return audio

    nyq = rate / 2.0

    # K-weighting stage 1: high-shelf pre-filter (+4 dB above 1.5 kHz)
    sos_hs  = sg.butter(1, 1500.0 / nyq, btype="high", output="sos")
    kw      = audio + sg.sosfilt(sos_hs, audio) * 0.26

    # K-weighting stage 2: RLB high-pass at ~38 Hz
    sos_rlb = sg.butter(2, max(38.0 / nyq, 0.001), btype="high", output="sos")
    kw      = sg.sosfilt(sos_rlb, kw)

    # Integrated loudness over 400 ms / 100 ms hop blocks
    block = max(1, int(rate * 0.4))
    hop   = max(1, int(rate * 0.1))
    if len(kw) <= block:
        starts = [0]                               # short clip: one block
    else:
        starts = range(0, len(kw) - block + 1, hop)   # +1 keeps the last full block

    block_power = (float(np.mean(kw[s:s + block] ** 2)) for s in starts)
    ms_list = [p for p in block_power if p > 0]    # drop digitally silent blocks

    if not ms_list:
        return audio

    ungated_mean = np.mean(ms_list)
    gate_thr     = ungated_mean * (10 ** (-10 / 10))          # −10 LU relative gate
    gated        = [m for m in ms_list if m > gate_thr] or ms_list
    integrated   = np.mean(gated)

    current_lufs = 10 * np.log10(integrated) - 0.691          # EBU offset
    gain_db      = target_lufs - current_lufs
    gain         = min(10 ** (gain_db / 20.0), 5.0)           # cap at +14 dB
    return audio * gain


def _fft_resample(samples, src_rate, target_rate):
    """Resample a 1-D signal by zero-padding or truncating its real FFT."""
    n_in = len(samples)
    n_out = max(1, int(round(n_in * target_rate / src_rate)))
    spectrum = np.fft.rfft(samples)
    resized = np.zeros(n_out // 2 + 1, dtype=complex)
    keep = min(len(spectrum), len(resized))
    # Scale by the length ratio so the time-domain amplitude is preserved.
    resized[:keep] = spectrum[:keep] * (n_out / n_in)
    return np.fft.irfft(resized, n_out)


def broadcast_enhance(audio, src_rate, target_rate):
    """
    Run the full post-processing chain on synthesized speech.

    Parameters
      audio       : 1-D array of samples at `src_rate`.
      src_rate    : sample rate of `audio` in Hz.
      target_rate : sample rate of the returned audio in Hz.

    Returns the processed audio at `target_rate`, peak-scaled to 0.966.

    With scipy available, stages 1–9 run in order (resample, DC removal and
    80 Hz high-pass, noise gate, EQ, exciter, de-esser, multi-band and bus
    compression, loudness normalisation).  If an ImportError is raised, the
    chain falls back to a numpy-only FFT resample plus DC removal.  The
    fallback always starts again from the untouched input, and it handles
    non-integer rate ratios and downsampling.  Stage 10 runs in both cases.
    """
    raw = audio                                    # pristine input for the fallback
    try:
        from scipy import signal as sg

        # ── Stage 1: Upsample ─────────────────────────────────────────────────
        audio = sg.resample_poly(audio, target_rate, src_rate)
        audio -= audio.mean()                                  # DC remove

        # ── Stage 2: Pre-processing ───────────────────────────────────────────
        sos_hp = sg.butter(6, 80.0 / (target_rate / 2), btype="high", output="sos")
        audio  = sg.sosfilt(sos_hp, audio)

        # ── Stage 3: Noise gate ───────────────────────────────────────────────
        audio = noise_gate(audio, target_rate,
                           threshold=0.018, attack_ms=5, release_ms=80)

        # ── Stage 4: Parametric EQ ────────────────────────────────────────────
        audio = parametric_eq(audio, target_rate)

        # ── Stage 5: Harmonic exciter ─────────────────────────────────────────
        # drive=1.6 overrides the function default; runs after the EQ so the
        # exciter works on the already-shaped spectrum.
        audio = harmonic_exciter(audio, target_rate, drive=1.6, mix=0.12)

        # ── Stage 6: De-esser ─────────────────────────────────────────────────
        audio = de_esser(audio, target_rate,
                         lo=6000, hi=10000, threshold=0.10, ratio=3.5)

        # ── Stage 7: Multi-band compression ──────────────────────────────────
        audio = multiband_compress(audio, target_rate)

        # ── Stage 8: Bus compression (final glue) ────────────────────────────
        audio = _soft_compress(audio, threshold=0.32, ratio=2.5, makeup=1.25)

        # ── Stage 9: LUFS normalisation (−16 LUFS streaming standard) ────────
        audio = lufs_normalize(audio, target_rate, target_lufs=-16.0)

    except ImportError:
        # scipy not available — numpy-only FFT resample + DC removal.
        # Restart from `raw` so audio that was already resampled by a
        # partially completed chain is never resampled a second time.
        print("scipy not found — using numpy-only processing")
        audio  = _fft_resample(raw, src_rate, target_rate)
        audio -= audio.mean()

    # ── Stage 10: True-peak limiter (−0.3 dBTP brick-wall ceiling) ───────────
    # NOTE(review): the peak scaling below is unconditional, so the output
    # peak is always 0.966 and the absolute level set in Stage 9 is overridden
    # (it only affects how hard the tanh stage saturates).  Confirm this is
    # intended rather than "attenuate only when the peak exceeds the ceiling".
    audio = np.tanh(audio * 0.92) / 0.92          # soft-saturation pre-limiter
    peak  = np.max(np.abs(audio))
    if peak > 0:
        audio *= 0.966 / peak                      # final −0.3 dB ceiling
    return audio


# Post-process the synthesized speech; the result is at TARGET_RATE.
audio = broadcast_enhance(audio, src_rate, TARGET_RATE)

# ── Save ──────────────────────────────────────────────────────────────────────
# Convert to 16-bit PCM.  broadcast_enhance() scales the peak to 0.966, so the
# product stays inside the int16 range.
audio_int16 = (audio * 32767).astype(np.int16)
with wave.open(OUTPUT_FILE, "wb") as wf:
    wf.setnchannels(1)            # mono
    wf.setsampwidth(2)            # 2 bytes per sample = 16-bit
    wf.setframerate(TARGET_RATE)
    wf.writeframes(audio_int16.tobytes())

# Report the output path, duration in seconds and sample rate.
duration = len(audio_int16) / TARGET_RATE
print(f"Done: {OUTPUT_FILE}  ({duration:.1f}s  @{TARGET_RATE} Hz)")
