# Загрузка данных (data loading)


from pathlib import Path
import sys
import re
import numpy as np
import torch
import soundfile as sf
from IPython.display import Audio, display

# =========================
# PATHS
# =========================

COSYVOICE_ROOT = Path(r"C:\Users\Geodezik\image_generation\cosyvoice").resolve()
MATCHA_ROOT = COSYVOICE_ROOT / "third_party" / "Matcha-TTS"

# Put the local CosyVoice checkout and its vendored Matcha-TTS at the front of
# sys.path. Any existing entries are removed first, so re-running this cell in a
# notebook does not keep stacking duplicates. The resulting order is the same
# as a single run of the original code: COSYVOICE_ROOT first, then MATCHA_ROOT.
for _root in (MATCHA_ROOT, COSYVOICE_ROOT):
    _entry = str(_root)
    while _entry in sys.path:
        sys.path.remove(_entry)
    sys.path.insert(0, _entry)
del _root, _entry

MODEL_DIR = COSYVOICE_ROOT / "pretrained_models" / "Fun-CosyVoice3-0.5B-2512"

ROOT_DIR = Path(r"C:\Users\Geodezik\image_generation\sounds")
MOOD = "best"  # sub-folder of ROOT_DIR that holds the user prompt audio

TEST_OUT_DIR = Path(r"C:\Users\Geodezik\image_generation\cosyvoice_diag_outputs")
TEST_OUT_DIR.mkdir(parents=True, exist_ok=True)

# Official prompt audio from the CosyVoice repository
OFFICIAL_PROMPT_WAV = COSYVOICE_ROOT / "asset" / "zero_shot_prompt.wav"


# =========================
# NO FFMPEG / NO TORCHCODEC LOADER
# =========================

import torchaudio.functional as AF

def _tensor_stats(t):
    """Format min / max / std of tensor *t* with six decimals each."""
    return f"{t.min().item():.6f} / {t.max().item():.6f} / {t.std().item():.6f}"


def load_wav_no_torchcodec(wav, target_sr, min_sr=16000):
    """Load an audio file as a mono float32 tensor without ffmpeg/torchcodec.

    Used as a replacement for CosyVoice's ``load_wav``. The file is read with
    soundfile, all channels are averaged to mono, and the result is resampled
    with ``torchaudio.functional.resample`` when the file's rate differs from
    ``target_sr``. Diagnostic information is printed along the way.

    Args:
        wav: path to the audio file (anything ``Path()`` accepts).
        target_sr: sample rate of the returned tensor.
        min_sr: lowest source sample rate that is accepted.

    Returns:
        torch.Tensor of shape ``(1, num_samples)`` sampled at ``target_sr``.

    Raises:
        FileNotFoundError: if ``wav`` does not exist.
        ValueError: if the file contains no samples, or its sample rate is
            below ``min_sr``.
    """
    wav = Path(wav)

    print("\n[load_wav_no_torchcodec]")
    print(f"  file: {wav}")
    print(f"  exists: {wav.exists()}")
    # Check before reading. Otherwise a missing file surfaces as an opaque
    # libsndfile error, and the "exists" line above could never show False.
    if not wav.exists():
        raise FileNotFoundError(f"audio file not found: {wav}")

    audio, sr = sf.read(str(wav), dtype="float32", always_2d=True)
    print(f"  original sr: {sr}")
    print(f"  original shape soundfile: {audio.shape}")

    # min()/max() on an empty tensor would raise an unhelpful RuntimeError.
    if audio.shape[0] == 0:
        raise ValueError(f"audio file contains no samples: {wav}")

    # soundfile returns [samples, channels]; torch code expects [channels, samples].
    speech = torch.from_numpy(audio.T)
    # Force mono; keepdim preserves the leading channel axis -> (1, samples).
    speech = speech.mean(dim=0, keepdim=True)

    print(f"  mono shape torch: {tuple(speech.shape)}")
    print(f"  duration original: {audio.shape[0] / sr:.3f} sec")
    print(f"  min/max/std: {_tensor_stats(speech)}")

    if sr < min_sr:
        raise ValueError(f"sample rate {sr} is lower than min_sr {min_sr}")

    if sr != target_sr:
        speech = AF.resample(speech, sr, target_sr)
        print(f"  resampled to: {target_sr}")
        print(f"  resampled shape: {tuple(speech.shape)}")
        print(f"  duration resampled: {speech.shape[1] / target_sr:.3f} sec")
        print(f"  resampled min/max/std: {_tensor_stats(speech)}")

    return speech


# =========================
# PATCH COSYVOICE LOAD_WAV
# =========================

# Swap CosyVoice's audio loader for load_wav_no_torchcodec, so prompt audio is
# read without ffmpeg/torchcodec.
# Statement order matters: file_utils is patched *before* frontend is imported.
# Any module that copies `load_wav` by name while being imported therefore
# presumably gets the replacement. frontend is then patched explicitly as well.
# NOTE(review): this assumes frontend looks up `load_wav` through its own
# module-level name. Confirm against cosyvoice/cli/frontend.py, and check
# whether other CosyVoice modules also hold a reference.
import cosyvoice.utils.file_utils as file_utils
file_utils.load_wav = load_wav_no_torchcodec

import cosyvoice.cli.frontend as frontend
frontend.load_wav = load_wav_no_torchcodec


# =========================
# IMPORT COSYVOICE
# =========================

from cosyvoice.cli.cosyvoice import AutoModel

# Bind the package under a distinct name. The model instance created later is
# stored in the global `cosyvoice`, which would otherwise silently replace the
# package object, and any later `cosyvoice.__file__` lookup would break.
import cosyvoice as cosyvoice_pkg
import cosyvoice.cli.cosyvoice as cosyvoice_cli

print("=" * 80)
print("IMPORT CHECK")
print("=" * 80)
# Show which checkout was actually imported (sys.path ordering matters).
print("COSYVOICE_ROOT:", COSYVOICE_ROOT)
print("MATCHA_ROOT:", MATCHA_ROOT)
print("cosyvoice package:", cosyvoice_pkg.__file__)
print("cosyvoice cli:", cosyvoice_cli.__file__)
print("frontend:", frontend.__file__)
print("file_utils:", file_utils.__file__)


# =========================
# ENV CHECK
# =========================

print("\n" + "=" * 80, "ENV CHECK", "=" * 80, sep="\n")

# Interpreter, torch and CUDA summary.
print("python:", sys.executable)
print("torch:", torch.__version__)
print("cuda available:", torch.cuda.is_available())
if torch.cuda.is_available():
    print("cuda device:", torch.cuda.get_device_name(0))

# Optional packages: print the version, or the import error, without aborting.
try:
    import torchaudio
    print("torchaudio:", torchaudio.__version__)
except Exception as exc:
    print("torchaudio import error:", repr(exc))

try:
    import onnxruntime as ort
    print("onnxruntime:", ort.__version__)
    print("onnxruntime providers:", ort.get_available_providers())
except Exception as exc:
    print("onnxruntime import error:", repr(exc))


# =========================
# MODEL DIR CHECK
# =========================

print("\n" + "=" * 80)
print("MODEL DIR CHECK")
print("=" * 80)

print("MODEL_DIR:", MODEL_DIR)
print("exists:", MODEL_DIR.exists())

if not MODEL_DIR.exists():
    raise FileNotFoundError(f"MODEL_DIR does not exist: {MODEL_DIR}")

print("\nTop-level files:")
for p in sorted(MODEL_DIR.iterdir()):
    if p.is_file():
        size_mb = p.stat().st_size / 1024 / 1024
        print(f"  FILE {p.name:45s} {size_mb:10.2f} MB")
    elif p.is_dir():
        print(f"  DIR  {p.name}")

# A Git LFS pointer file is a tiny text file whose first line starts with this
# header. Such a file means the real weights were never downloaded.
_LFS_POINTER_PREFIX = b"version https://git-lfs.github.com/spec"


def _is_lfs_pointer(path, max_size=1024 * 1024):
    """Return True if *path* starts with the Git LFS pointer header.

    Files larger than ``max_size`` bytes are skipped. Only the first few bytes
    are read, so large weight files are never loaded. Unreadable files count as
    "not a pointer".
    """
    try:
        if path.stat().st_size > max_size:
            return False
        with path.open("rb") as fh:
            head = fh.read(len(_LFS_POINTER_PREFIX))
    except OSError:
        return False
    # Match only at the start of the file. A substring match would also flag
    # e.g. a README that merely mentions the LFS spec URL.
    return head == _LFS_POINTER_PREFIX


print("\nGit LFS pointer check:")
lfs_pointers = [p for p in sorted(MODEL_DIR.rglob("*")) if p.is_file() and _is_lfs_pointer(p)]
for p in lfs_pointers:
    print("  LFS POINTER, NOT REAL FILE:", p)
lfs_found = bool(lfs_pointers)

if not lfs_found:
    print("  no LFS pointers found in small files")


# =========================
# HELPER FUNCTIONS
# =========================

def clean_text(s: str) -> str:
    """Normalize transcript text to a single trimmed line.

    Removes BOM (U+FEFF) and zero-width space (U+200B) characters, collapses
    every whitespace run (including CR/LF) into one space, and trims both ends.
    """
    without_invisible = s.translate({0xFEFF: None, 0x200B: None})
    return re.sub(r"\s+", " ", without_invisible).strip()


def inspect_audio_file(path: "Path | str", name="audio"):
    """Print diagnostics (sample rate, shape, duration, amplitude stats) for an audio file.

    Args:
        path: audio file to inspect; a ``str`` is converted to ``Path``.
        name: label used in the section header.

    Returns:
        None. A missing file is reported as ``exists: False`` and skipped,
        not raised. A file with no samples is reported with a warning.
    """
    path = Path(path)

    print("\n" + "-" * 80)
    print(f"INSPECT {name}")
    print("-" * 80)
    print("path:", path)
    print("exists:", path.exists())

    if not path.exists():
        return

    # always_2d=True -> shape is (frames, channels) even for mono files.
    audio, sr = sf.read(str(path), dtype="float32", always_2d=True)
    print("sr:", sr)
    print("shape:", audio.shape)
    print("duration:", audio.shape[0] / sr, "sec")
    print("dtype:", audio.dtype)

    # numpy min()/max() raise on zero-size arrays; report instead of crashing.
    if audio.size == 0:
        print("WARNING: file contains no samples")
        return

    print("min/max/std:", float(audio.min()), float(audio.max()), float(audio.std()))

    # Average the channels to mono (exact no-op for single-channel input).
    mono = audio.mean(axis=1)
    print("mono min/max/std:", float(mono.min()), float(mono.max()), float(mono.std()))


def save_outs(outs, output_path: Path, sample_rate: int):
    """Concatenate inference output chunks and write them as one audio file.

    Args:
        outs: iterable of dicts, each expected to carry a ``"tts_speech"``
            tensor. Chunks are joined along dim 1, presumably shaped
            ``(1, num_samples)`` (TODO confirm against the CosyVoice API).
        output_path: destination file; its parent directory is created if
            it does not exist.
        sample_rate: sample rate written into the file.

    Returns:
        ``output_path``, exactly as passed in.

    Raises:
        RuntimeError: if ``outs`` yields no chunks.
        KeyError: if any chunk lacks a ``"tts_speech"`` entry.
    """
    outs = list(outs)

    print("\n[save_outs]")
    print("  chunks:", len(outs))

    if not outs:
        raise RuntimeError("No output chunks from inference")

    for i, o in enumerate(outs):
        print(f"  chunk {i} keys:", list(o.keys()))
        if "tts_speech" in o:
            t = o["tts_speech"]
            print(f"    tts_speech shape: {tuple(t.shape)}")
            print(f"    min/max/std: {t.min().item():.6f} / {t.max().item():.6f} / {t.std().item():.6f}")

    # Fail with a message naming the offending chunks, rather than a bare
    # KeyError from inside the concatenation below.
    missing = [i for i, o in enumerate(outs) if "tts_speech" not in o]
    if missing:
        raise KeyError(f"chunks without 'tts_speech': {missing}")

    wav = torch.cat([o["tts_speech"] for o in outs], dim=1)
    print("  concat shape:", tuple(wav.shape))
    print("  sample_rate:", sample_rate)
    print("  duration:", wav.shape[1] / sample_rate, "sec")
    print("  concat min/max/std:", wav.min().item(), wav.max().item(), wav.std().item())

    audio = wav.detach().cpu().squeeze(0).numpy()
    # Keep samples within the nominal [-1, 1] float range before writing.
    audio = np.clip(audio, -1.0, 1.0)

    Path(output_path).parent.mkdir(parents=True, exist_ok=True)
    sf.write(str(output_path), audio, sample_rate)
    print("  saved:", output_path)

    return output_path


def run_and_save(name: str, fn, *, sample_rate=None, out_dir=None):
    """Run one synthesis test, save its audio and display an inline player.

    Args:
        name: test name; also used as the output file stem (``<name>.wav``).
        fn: zero-argument callable returning an iterable of inference chunks
            (dicts with a ``"tts_speech"`` tensor).
        sample_rate: sample rate for the saved file. When None, the
            ``sample_rate`` attribute of the module-level ``cosyvoice`` object
            is used. That global only refers to the model after the
            model-loading step has run; pass this explicitly to avoid
            depending on it.
        out_dir: directory for the output file; defaults to ``TEST_OUT_DIR``.

    Returns:
        Path of the saved audio file.

    Raises:
        Any exception from ``fn`` or from saving, re-raised after a short
        summary is printed.
    """
    print("\n" + "=" * 80)
    print(f"RUN TEST: {name}")
    print("=" * 80)

    if out_dir is None:
        out_dir = TEST_OUT_DIR
    output_path = Path(out_dir) / f"{name}.wav"

    try:
        outs = list(fn())
        if sample_rate is None:
            sample_rate = cosyvoice.sample_rate
        saved = save_outs(outs, output_path, sample_rate)
        display(Audio(str(saved)))
        return saved
    except Exception as e:
        # Diagnostic boundary: print which test failed, then propagate.
        print(f"[ERROR in {name}]")
        print(type(e).__name__ + ":", str(e))
        raise


# =========================
# PROMPT FILE CHECKS
# =========================

inspect_audio_file(OFFICIAL_PROMPT_WAV, "official zero_shot_prompt.wav")


def _files_with_suffix(folder, suffix):
    """Return the regular files in *folder* whose extension equals *suffix* (case-insensitive), sorted."""
    if not folder.is_dir():
        return []
    return sorted(p for p in folder.iterdir() if p.is_file() and p.suffix.lower() == suffix)


def _read_transcript(txt_path):
    """Read a transcript as UTF-8 (a leading BOM is allowed).

    If the file is not valid UTF-8, print a warning and drop the undecodable
    bytes, instead of dropping them silently.
    """
    raw = txt_path.read_bytes()
    try:
        return raw.decode("utf-8-sig")
    except UnicodeDecodeError as err:
        print("  WARNING: transcript is not valid UTF-8, undecodable bytes dropped:", err)
        return raw.decode("utf-8", errors="ignore")


# Prefer .ogg prompts and fall back to .wav. Matching is case-insensitive, so
# "*.OGG" / "*.WAV" are also found on case-sensitive file systems.
user_wavs = _files_with_suffix(ROOT_DIR / MOOD, ".ogg") or _files_with_suffix(ROOT_DIR / MOOD, ".wav")

if not user_wavs:
    raise FileNotFoundError(f"No .ogg or .wav files found in {ROOT_DIR / MOOD}")

USER_PROMPT_WAV = user_wavs[0]
USER_PROMPT_TXT = USER_PROMPT_WAV.with_suffix(".txt")

inspect_audio_file(USER_PROMPT_WAV, "user prompt wav/ogg")

print("\nUser prompt selected:")
print("  wav:", USER_PROMPT_WAV)
print("  txt:", USER_PROMPT_TXT)
print("  txt exists:", USER_PROMPT_TXT.exists())

if USER_PROMPT_TXT.exists():
    USER_PROMPT_TRANSCRIPT = clean_text(_read_transcript(USER_PROMPT_TXT))
else:
    USER_PROMPT_TRANSCRIPT = ""

print("  transcript repr:", repr(USER_PROMPT_TRANSCRIPT))
print("  transcript len:", len(USER_PROMPT_TRANSCRIPT))


# =========================
# LOAD MODEL
# =========================

print("\n" + "=" * 80, "LOAD MODEL", "=" * 80, sep="\n")

cosyvoice = AutoModel(model_dir=str(MODEL_DIR))

print("cosyvoice object:", type(cosyvoice))
print("sample_rate:", cosyvoice.sample_rate)

# Best effort: report the types of internal components when they exist.
for attr_name in ("model_dir", "frontend", "model", "llm", "flow", "hift"):
    if not hasattr(cosyvoice, attr_name):
        continue
    try:
        print(f"{attr_name}:", type(getattr(cosyvoice, attr_name)))
    except Exception as err:
        print(f"{attr_name}: <error {repr(err)}>")


# =========================
# TEST 1: OFFICIAL ZH ZERO-SHOT
# =========================

if not OFFICIAL_PROMPT_WAV.exists():
    print("\nSKIP official zh test: official prompt wav not found:", OFFICIAL_PROMPT_WAV)
else:
    # Reference run: official Chinese prompt audio with its transcript.
    run_and_save(
        "01_official_zh_zero_shot",
        lambda: cosyvoice.inference_zero_shot(
            "八百标兵奔北坡,北坡炮兵并排跑,炮兵怕把标兵碰,标兵怕碰炮兵炮。",
            "You are a helpful assistant.<|endofprompt|>希望你以后能够做的比我还好呦。",
            str(OFFICIAL_PROMPT_WAV),
            stream=False,
            speed=1.0,
        ),
    )


# =========================
# TEST 2: ENGLISH CROSS-LINGUAL WITH OFFICIAL PROMPT
# =========================

if OFFICIAL_PROMPT_WAV.exists():
    run_and_save(
        "02_official_prompt_en_cross_lingual",
        lambda: cosyvoice.inference_cross_lingual(
            "You are a helpful assistant.<|endofprompt|>This is a short English speech synthesis test.",
            str(OFFICIAL_PROMPT_WAV),
            stream=False,
            speed=1.0,
        )
    )
else:
    # Report the skip explicitly, consistent with the official zh zero-shot test.
    print("\nSKIP official en cross-lingual test: official prompt wav not found:", OFFICIAL_PROMPT_WAV)


# =========================
# TEST 3: RUSSIAN CROSS-LINGUAL WITH OFFICIAL PROMPT
# =========================

if OFFICIAL_PROMPT_WAV.exists():
    run_and_save(
        "03_official_prompt_ru_cross_lingual",
        lambda: cosyvoice.inference_cross_lingual(
            "You are a helpful assistant.<|endofprompt|>Это короткий тест синтеза русской речи.",
            str(OFFICIAL_PROMPT_WAV),
            stream=False,
            speed=1.0,
        )
    )
else:
    # Report the skip explicitly, consistent with the official zh zero-shot test.
    print("\nSKIP official ru cross-lingual test: official prompt wav not found:", OFFICIAL_PROMPT_WAV)


# =========================
# TEST 4: USER PROMPT EN CROSS-LINGUAL
# =========================


def _user_prompt_en_cross_lingual():
    """English text voiced with the user's prompt audio (no transcript needed)."""
    return cosyvoice.inference_cross_lingual(
        "You are a helpful assistant.<|endofprompt|>This is a short English speech synthesis test.",
        str(USER_PROMPT_WAV),
        stream=False,
        speed=1.0,
    )


run_and_save("04_user_prompt_en_cross_lingual", _user_prompt_en_cross_lingual)


# =========================
# TEST 5: USER PROMPT RU CROSS-LINGUAL
# =========================


def _user_prompt_ru_cross_lingual():
    """Russian text voiced with the user's prompt audio (no transcript needed)."""
    return cosyvoice.inference_cross_lingual(
        "You are a helpful assistant.<|endofprompt|>Это короткий тест синтеза русской речи.",
        str(USER_PROMPT_WAV),
        stream=False,
        speed=1.0,
    )


run_and_save("05_user_prompt_ru_cross_lingual", _user_prompt_ru_cross_lingual)


# =========================
# TEST 6: USER PROMPT RU ZERO-SHOT WITH TXT
# =========================

if not USER_PROMPT_TRANSCRIPT:
    print("\nSKIP zero-shot user txt test: no transcript")
else:
    # Zero-shot uses the prompt transcript, with the same prefix format as the
    # official zero-shot prompt text.
    USER_PROMPT_TEXT = "You are a helpful assistant.<|endofprompt|>" + USER_PROMPT_TRANSCRIPT

    print("\nZero-shot prompt text:")
    print(repr(USER_PROMPT_TEXT))

    def _user_prompt_ru_zero_shot():
        """Russian text cloned from the user prompt audio plus its transcript."""
        return cosyvoice.inference_zero_shot(
            "Это короткий тест синтеза русской речи.",
            USER_PROMPT_TEXT,
            str(USER_PROMPT_WAV),
            stream=False,
            speed=1.0,
        )

    run_and_save("06_user_prompt_ru_zero_shot_with_txt", _user_prompt_ru_zero_shot)


# =========================
# DONE
# =========================

print("\n" + "=" * 80, "DIAGNOSTICS DONE", "=" * 80, sep="\n")
print("Outputs saved to:", TEST_OUT_DIR)