Загрузка данных


from pathlib import Path
import sys
import inspect
import numpy as np
import torch
import soundfile as sf

# ============================================================
# PATHS
# ============================================================

COSYVOICE_ROOT = Path(r"C:\Users\Geodezik\image_generation\CosyVoice_official").resolve()
MATCHA_ROOT = COSYVOICE_ROOT / "third_party" / "Matcha-TTS"
OLD_BAD_ROOT = Path(r"C:\Users\Geodezik\image_generation\cosyvoice").resolve()

MODEL_DIR = COSYVOICE_ROOT / "pretrained_models" / "Fun-CosyVoice3-0.5B"
QWEN_DIR = MODEL_DIR / "CosyVoice-BlankEN"
PROMPT_WAV = COSYVOICE_ROOT / "asset" / "zero_shot_prompt.wav"

OUT_DIR = COSYVOICE_ROOT / "_real_sanity_outputs"
OUT_DIR.mkdir(parents=True, exist_ok=True)

# ============================================================
# CLEAN IMPORTS
# ============================================================


def _is_within(path_entry, root):
    """Return True if *path_entry* resolves to *root* or to a directory below it.

    The comparison is per path component and case-insensitive. A plain
    substring test is wrong here: the stale root's last component "cosyvoice"
    is a string prefix of "CosyVoice_official" and of siblings such as
    "cosyvoice_env", so a substring test would drop unrelated sys.path entries.
    """
    try:
        resolved = Path(path_entry).resolve()
    except (OSError, RuntimeError, TypeError):
        # An entry that cannot be resolved cannot point into the stale tree; keep it.
        return False
    root_key = str(root).lower()
    return any(str(p).lower() == root_key for p in (resolved, *resolved.parents))


# Remove every sys.path entry inside the stale checkout. Slice assignment
# mutates the existing list, so other holders of sys.path see the change too.
sys.path[:] = [p for p in sys.path if not _is_within(p, OLD_BAD_ROOT)]

# Drop already-imported cosyvoice / wetext modules (and their submodules) so the
# imports below re-resolve them against the cleaned sys.path.
for name in list(sys.modules):
    if name.split(".", 1)[0] in ("cosyvoice", "wetext"):
        del sys.modules[name]

# Highest priority: the official checkout first, then its vendored Matcha-TTS.
sys.path.insert(0, str(MATCHA_ROOT))
sys.path.insert(0, str(COSYVOICE_ROOT))

# ============================================================
# BASIC INFO
# ============================================================

# Report interpreter, torch/CUDA availability and the configured paths.
_RULE = "=" * 100
print(f"{_RULE}\nENV / PATH\n{_RULE}")
print("python:", sys.executable)
print("torch:", torch.__version__)

_cuda_ok = torch.cuda.is_available()
print("cuda:", _cuda_ok)
if _cuda_ok:
    print("cuda device:", torch.cuda.get_device_name(0))
    print("torch cuda:", torch.version.cuda)

for _label, _path in (
    ("COSYVOICE_ROOT:", COSYVOICE_ROOT),
    ("MODEL_DIR:", MODEL_DIR),
    ("QWEN_DIR:", QWEN_DIR),
    ("PROMPT_WAV:", PROMPT_WAV),
):
    print(_label, _path)

# ============================================================
# NO FFMPEG LOADER
# ============================================================

def resample_np(audio_1d: np.ndarray, sr: int, target_sr: int) -> np.ndarray:
    """Resample a 1-D signal from *sr* to *target_sr*, returning float32.

    Uses ``soxr`` when it is installed and ``scipy.signal.resample_poly``
    otherwise. Only a missing ``soxr`` triggers the fallback. A genuine soxr
    failure (bad input, out of memory) propagates instead of being silently
    retried with a different resampler.

    Args:
        audio_1d: 1-D sample array; cast to float32 (without a copy if already float32).
        sr: Source sample rate in Hz.
        target_sr: Desired sample rate in Hz.

    Returns:
        float32 array at *target_sr*. When the rates match, this is the cast input.
    """
    audio_1d = audio_1d.astype(np.float32, copy=False)

    if sr == target_sr:
        return audio_1d

    try:
        import soxr
    except ImportError:
        soxr = None

    if soxr is not None:
        return soxr.resample(audio_1d, sr, target_sr).astype(np.float32, copy=False)

    from math import gcd
    from scipy.signal import resample_poly

    # resample_poly takes integer up/down factors; reduce them by their gcd.
    g = gcd(sr, target_sr)
    return resample_poly(audio_1d, target_sr // g, sr // g).astype(np.float32, copy=False)


def load_wav_no_torchcodec(wav, target_sr, min_sr=16000):
    """Load audio as a mono float32 tensor of shape (1, T) at *target_sr*.

    Replaces CosyVoice's ``load_wav`` (this script monkey-patches it in). It
    reads through soundfile, so no torchaudio/torchcodec/ffmpeg is needed.
    Channels are averaged to mono and the signal is resampled with
    ``resample_np``.

    Args:
        wav: A path (str or os.PathLike) or an open binary file-like object.
        target_sr: Output sample rate in Hz.
        min_sr: Lowest accepted source sample rate in Hz.

    Returns:
        torch.Tensor of shape (1, T), dtype float32.

    Raises:
        ValueError: If the source sample rate is below *min_sr*.
    """
    # File-like objects go to soundfile unchanged. str() on them would produce
    # their repr, not a path, and the read would fail.
    source = wav if hasattr(wav, "read") else str(wav)
    audio, sr = sf.read(source, dtype="float32", always_2d=True)
    # always_2d=True yields (frames, channels); average channels into mono.
    mono = audio.mean(axis=1).astype(np.float32, copy=False)

    if sr < min_sr:
        raise ValueError(f"sample rate {sr} is lower than min_sr {min_sr}")

    mono = resample_np(mono, sr, target_sr)
    return torch.from_numpy(mono).unsqueeze(0)

# ============================================================
# IMPORT COSYVOICE
# ============================================================

print("\n" + "=" * 100)
print("COSYVOICE IMPORT")
print("=" * 100)

import cosyvoice
import cosyvoice.cli.cosyvoice as cosyvoice_cli
import cosyvoice.cli.frontend as frontend
import cosyvoice.utils.file_utils as file_utils

# Route CosyVoice's wav loading through the soundfile-based loader defined above.
# Both module attributes are patched. Presumably frontend binds load_wav by name
# at import time, so patching file_utils alone would not reach it (TODO confirm).
file_utils.load_wav = load_wav_no_torchcodec
frontend.load_wav = load_wav_no_torchcodec

print("cosyvoice:", cosyvoice.__file__)
print("cosyvoice cli:", cosyvoice_cli.__file__)
print("frontend:", frontend.__file__)
print("file_utils:", file_utils.__file__)

# Fail fast unless the package was imported from exactly COSYVOICE_ROOT/cosyvoice.
# A string-prefix test would also accept siblings such as ".../cosyvoice_old".
if cosyvoice.__file__ is None:
    # A namespace package (no __init__.py) has no __file__; Path(None) would raise TypeError.
    raise RuntimeError(
        f"Wrong cosyvoice import: namespace package at {list(cosyvoice.__path__)}"
    )
expected = str((COSYVOICE_ROOT / "cosyvoice").resolve()).lower()
actual = str(Path(cosyvoice.__file__).resolve().parent).lower()
if actual != expected:
    raise RuntimeError(f"Wrong cosyvoice import: {actual}")

# ============================================================
# CHECK WETEXT
# ============================================================

print("\n" + "=" * 100)
print("WETEXT CHECK")
print("=" * 100)

# Smoke-test the wetext normalizer; the model-loading section later requires
# frontend.text_frontend == "wetext". A single text failing to normalize is
# only printed. Failing to import wetext or to build a Normalizer aborts the script.
try:
    import wetext
    from wetext import Normalizer
    print("wetext:", wetext.__file__)
    print("Normalizer:", Normalizer)

    for lang in ["en", "zh"]:
        try:
            n = Normalizer(lang=lang)
        except TypeError:
            # Fallback for a Normalizer without a ``lang`` argument. Both
            # iterations then use the same default-constructed normalizer.
            n = Normalizer()

        for text in ["Hello, I have 123 apples.", "今天是2026年6月28日。"]:
            try:
                print(f"lang={lang} {text!r} -> {n.normalize(text)!r}")
            except Exception as e:
                print(f"normalize failed lang={lang}:", repr(e))
except Exception as e:
    # Chain explicitly so the traceback reports the wetext error as the direct cause.
    raise RuntimeError(f"wetext failed: {e}") from e

# ============================================================
# CHECK RAW HF TOKENIZER — informational only
# ============================================================

_SEP = "=" * 100
print(f"\n{_SEP}\nRAW HF TOKENIZER CHECK — INFORMATIONAL ONLY\n{_SEP}")

from transformers import AutoTokenizer

# Load the Qwen tokenizer stored next to the model, strictly offline.
_hf_load_kwargs = dict(trust_remote_code=True, local_files_only=True)
hf_tok = AutoTokenizer.from_pretrained(str(QWEN_DIR), **_hf_load_kwargs)

# Show how the plain HF tokenizer treats the <|endofprompt|> marker:
# as a single vocabulary id, and when encoded as ordinary text.
_EOP = "<|endofprompt|>"
print("hf tokenizer:", hf_tok.__class__)
print("hf special_tokens_map:", hf_tok.special_tokens_map)
print("hf <|endofprompt|> id:", hf_tok.convert_tokens_to_ids(_EOP))
_eop_encoding = hf_tok(_EOP, add_special_tokens=False)
print("hf encode <|endofprompt|>:", _eop_encoding["input_ids"])

# Reminder that this result alone is not conclusive; the CosyVoice wrapper below matters.
_hf_note = (
    "\nNOTE: если здесь <|endofprompt|> не special — это ещё НЕ окончательный диагноз. "
    "Главное — что делает CosyVoice tokenizer ниже."
)
print(_hf_note)

# ============================================================
# CHECK COSYVOICE TOKENIZER DIRECTLY
# ============================================================

print("\n" + "=" * 100)
print("COSYVOICE TOKENIZER CHECK")
print("=" * 100)

from cosyvoice.tokenizer.tokenizer import get_qwen_tokenizer

print("get_qwen_tokenizer:", get_qwen_tokenizer)
print("signature:", inspect.signature(get_qwen_tokenizer))

cv_tok = get_qwen_tokenizer(
    token_path=str(QWEN_DIR),
    skip_special_tokens=True,
    version="cosyvoice3",
)

print("cv_tok:", type(cv_tok))
print("cv_tok attrs:")
# getattr(..., "__dict__", {}) rather than vars(): vars() raises TypeError for
# objects that use __slots__ and have no instance dict.
for k, v in sorted(getattr(cv_tok, "__dict__", {}).items()):
    if "token" in k.lower() or "special" in k.lower() or "skip" in k.lower():
        print(" ", k, "=", type(v), repr(v)[:500])

print("\ncv_tok callable methods:")
for name in dir(cv_tok):
    # Filter on the name first so only attributes that will be printed are
    # evaluated. Touching every public attribute can trigger properties that raise.
    if name.startswith("_"):
        continue
    if not any(s in name.lower() for s in ["token", "encode", "decode", "text"]):
        continue
    try:
        obj = getattr(cv_tok, name)
    except Exception as e:
        print(" ", name, f"<getattr failed: {type(e).__name__}>")
        continue
    if callable(obj):
        try:
            sig = inspect.signature(obj)
        except (TypeError, ValueError):
            # inspect.signature raises these for builtins/unsupported callables.
            sig = "<no sig>"
        print(" ", name, sig)

TEST_TEXTS = [
    "<|endofprompt|>",
    "You are a helpful assistant.<|endofprompt|>希望你以后能够做的比我还好呦。",
    "八百标兵奔北坡,北坡炮兵并排跑,炮兵怕把标兵碰,标兵怕碰炮兵炮。",
    "Hello. This is a simple English test.",
]

def try_tokenizer_call(tok, text, target_id=151646):
    """Call each tokenizer entry point on *text* and report the ids it produces.

    Tries ``encode``, ``tokenize`` and ``text_tokenize`` (whichever exist), then
    the object itself if it is callable. For each one it prints the result and,
    for tensor / array / list / tuple / mapping results, the flattened ids and
    whether *target_id* occurs among them. The default 151646 is the id this
    script expects for ``<|endofprompt|>``. A failing entry point is reported
    and does not stop the others.

    Args:
        tok: Tokenizer-like object to probe.
        text: Input string passed to every entry point.
        target_id: Token id to search for in the flattened output.
    """
    from collections.abc import Mapping

    def _collect(obj, out):
        # Recursively gather ids. Numpy integer scalars count (np.int64 is not
        # an ``int``). Mappings use "input_ids" when present (HF BatchEncoding-style).
        if isinstance(obj, torch.Tensor):
            out.extend(obj.detach().cpu().flatten().tolist())
        elif isinstance(obj, np.ndarray):
            out.extend(obj.flatten().tolist())
        elif isinstance(obj, Mapping):
            _collect(obj["input_ids"] if "input_ids" in obj else list(obj.values()), out)
        elif isinstance(obj, (list, tuple)):
            for item in obj:
                _collect(item, out)
        elif isinstance(obj, (int, np.integer)):
            out.append(int(obj))

    # Common APIs of CosyVoice tokenizer wrappers, then the object itself.
    candidates = [
        (method_name, getattr(tok, method_name))
        for method_name in ("encode", "tokenize", "text_tokenize")
        if hasattr(tok, method_name)
    ]
    if callable(tok):
        candidates.append(("__call__", tok))

    for method_name, fn in candidates:
        print(f"\nTrying cv_tok.{method_name}({text!r})")
        try:
            res = fn(text)
            print(" result type:", type(res))
            print(" result repr:", repr(res)[:1000])

            if isinstance(res, torch.Tensor):
                label = " tensor flat first:"
            elif isinstance(res, (list, tuple, np.ndarray, Mapping)):
                label = " flattened first:"
            else:
                continue

            ids = []
            _collect(res, ids)
            print(label, ids[:100])
            print(f" contains {target_id}:", target_id in ids)
        except Exception as e:
            print(" failed:", type(e).__name__, str(e))

for text in TEST_TEXTS:
    try_tokenizer_call(cv_tok, text)

# ============================================================
# LOAD MODEL
# ============================================================

print(f"\n{'=' * 100}\nLOAD MODEL\n{'=' * 100}")

from cosyvoice.cli.cosyvoice import AutoModel

# Full-precision load of the pretrained model directory.
cosyvoice_model = AutoModel(model_dir=str(MODEL_DIR), fp16=False)

print("model:", type(cosyvoice_model))
print("sample_rate:", cosyvoice_model.sample_rate)
_fe = cosyvoice_model.frontend
print("frontend:", type(_fe))
_text_frontend = getattr(_fe, "text_frontend", None)
print("frontend.text_frontend:", _text_frontend)
print("frontend.tokenizer:", type(_fe.tokenizer))

# This sanity run requires the wetext text normalizer to be active.
if _text_frontend != "wetext":
    raise RuntimeError("frontend is not wetext")

# ============================================================
# DIRECT FRONTEND INPUT CHECK
# ============================================================

print("\n" + "=" * 100)
print("FRONTEND_ZERO_SHOT MODEL INPUT CHECK")
print("=" * 100)

TEXT_ZH = "八百标兵奔北坡,北坡炮兵并排跑,炮兵怕把标兵碰,标兵怕碰炮兵炮。"
PROMPT_TEXT_ZH = "You are a helpful assistant.<|endofprompt|>希望你以后能够做的比我还好呦。"

model_input = cosyvoice_model.frontend.frontend_zero_shot(
    TEXT_ZH,
    PROMPT_TEXT_ZH,
    str(PROMPT_WAV),
    cosyvoice_model.sample_rate,
    "",
)

print("model_input keys:", list(model_input.keys()))

# Id expected for <|endofprompt|> as one special token, and (presumably) the ids
# it turns into when tokenized as plain text — the "raw split" sequence.
_ENDOFPROMPT_ID = 151646
_RAW_SPLIT_IDS = [27, 91, 408, 1055, 40581, 91, 29]
# torch.long is an alias of torch.int64, so it is not listed separately.
_INT_DTYPES = (torch.int8, torch.int16, torch.int32, torch.int64)


def _contains_run(seq, run):
    """Return True if *run* occurs in *seq* as a contiguous slice."""
    width = len(run)
    if width == 0:
        return True
    return any(seq[i:i + width] == run for i in range(len(seq) - width + 1))


# Single pass: print every entry and remember which integer tensors hold the id.
found_in = []
for k, v in model_input.items():
    print("\nKEY:", k)
    print(" type:", type(v))

    if isinstance(v, torch.Tensor):
        vv = v.detach().cpu()
        print(" shape:", tuple(v.shape))
        print(" dtype:", v.dtype)
        print(" device:", v.device)
        # Only the printed prefix is converted to a list; feature tensors can be large.
        print(" first 120:", vv.flatten()[:120].tolist())
        if v.dtype in _INT_DTYPES:
            flat = vv.flatten().tolist()
            if flat:  # min()/max() raise on empty tensors
                print(" min/max:", int(vv.min()), int(vv.max()))
            has_eop = _ENDOFPROMPT_ID in flat
            print(" contains 151646:", has_eop)
            # The split ids must appear as one contiguous run; ids 27/91/29 are
            # common punctuation and would match individually almost anywhere.
            print(" contains raw split ids [27,91,408,1055,40581,91,29]:",
                  _contains_run(flat, _RAW_SPLIT_IDS))
            if has_eop:
                found_in.append(k)
    else:
        print(" repr:", repr(v)[:1000])

# Hard check: one of the text/prompt tensors should contain 151646 if the
# CosyVoice tokenizer is correct.
for k in found_in:
    print("\nFOUND 151646 in:", k)

contains_endofprompt = bool(found_in)
if not contains_endofprompt:
    print("\nWARNING: 151646 was NOT found in frontend model_input.")
    print("This means CosyVoice tokenizer path is probably still wrong or code/tokenizer mismatch exists.")
else:
    print("\nOK: frontend model_input contains 151646.")

# ============================================================
# INFERENCE TESTS
# ============================================================

def save_outs(name, outs):
    """Join TTS output chunks and write them as 16-bit PCM to OUT_DIR/<name>.wav.

    Args:
        name: File stem of the output wav.
        outs: Iterable of dicts, each with a ``"tts_speech"`` tensor. Chunks are
            concatenated along dim 1 (presumably shape (1, T) — TODO confirm).

    Returns:
        Path of the written file.

    Raises:
        RuntimeError: If *outs* yields no chunks.
    """
    collected = list(outs)
    if not collected:
        raise RuntimeError("no output chunks")

    speech = torch.cat([item["tts_speech"] for item in collected], dim=1)
    # Clip to [-1, 1] before PCM_16 encoding; the stats below use the unclipped tensor.
    samples = np.clip(speech.detach().float().cpu().squeeze(0).numpy(), -1.0, 1.0)

    rate = cosyvoice_model.sample_rate
    target = OUT_DIR / f"{name}.wav"
    sf.write(str(target), samples, rate, subtype="PCM_16")

    print("\nSAVED:", target)
    print(" shape:", tuple(speech.shape))
    print(" duration:", speech.shape[1] / rate)
    print(" min/max/std:", speech.min().item(), speech.max().item(), speech.std().item())
    return target

def _run_zero_shot(title, out_name, tts_text):
    """Print a test banner, run zero-shot TTS with the Chinese prompt, save the wav."""
    print("\n" + "=" * 100)
    print(title)
    print("=" * 100)
    return save_outs(
        out_name,
        cosyvoice_model.inference_zero_shot(
            tts_text,
            PROMPT_TEXT_ZH,
            str(PROMPT_WAV),
            stream=False,
            speed=1.0,
        ),
    )


out_zh = _run_zero_shot("TEST 1: OFFICIAL ZH ZERO-SHOT", "01_official_zh_zero_shot", TEXT_ZH)

# English target text synthesized with the Chinese prompt transcript/audio.
TEXT_EN = (
    "Hello. This is a simple English test. "
    "I am checking whether the speech is clear, continuous, and understandable."
)

out_en = _run_zero_shot(
    "TEST 2: EN ZERO-SHOT WITH CHINESE PROMPT TRANSCRIPT", "02_en_zero_shot", TEXT_EN
)

print("\nDONE")
print("Listen:")
for _saved_path in (out_zh, out_en):
    print(_saved_path)