Загрузка данных


from pathlib import Path
import sys
import numpy as np
import torch
import soundfile as sf
from IPython.display import Audio, display

sys.path.append("third_party/Matcha-TTS")

from cosyvoice.cli.cosyvoice import AutoModel


# Directory handed to AutoModel(model_dir=...) when the model is loaded below.
MODEL_DIR = "pretrained_models/Fun-CosyVoice3-0.5B"

# Reference clips are looked up as ROOT_DIR/<mood>/*.wav for each entry of
# MOODS, in the order the moods are listed here.
ROOT_DIR = Path("sounds_assistant")
MOODS = ["slow", "sad", "neutral"]

# Text to synthesize; passed as the first argument to the inference call.
TEXT = "Я слишком долго ждал этого момента. И теперь ты наконец здесь."

# USE_INSTRUCT picks the inference method: True -> inference_instruct2 with
# INSTRUCTION as its second argument, False -> inference_zero_shot with
# PROMPT_TEXT as its second argument.
USE_INSTRUCT = True
INSTRUCTION = "Speak in Russian slowly, with long pauses, tired and melancholic, low voice.<|endofprompt|>"

# NOTE(review): zero-shot prompt text presumably should also carry the
# transcript of the prompt audio -- confirm against the CosyVoice docs.
PROMPT_TEXT = "You are a helpful assistant.<|endofprompt|>"

# OUTPUT_WAV receives the synthesized speech and is the file played back at
# the end; TMP_PROMPT_WAV receives the concatenated reference clips and is the
# prompt audio path given to the model.
OUTPUT_WAV = "cosyvoice3_output.wav"
TMP_PROMPT_WAV = "cosyvoice3_prompt.wav"

# load_wav() resamples every reference clip to TARGET_SR (Hz), and the prompt
# file is written at this rate. MAX_REFS_PER_MOOD caps how many clips (first N
# in sorted path order) are taken from each mood directory.
TARGET_SR = 16000
MAX_REFS_PER_MOOD = 8


def load_wav(path):
    """Read an audio file as a mono float32 tensor sampled at ``TARGET_SR``.

    Args:
        path: Location of the audio file; converted with ``str()`` before it
            is handed to ``soundfile``.

    Returns:
        A ``(waveform, sample_rate)`` tuple. ``waveform`` is a channel-first
        torch tensor with a single channel row; ``sample_rate`` is always
        equal to ``TARGET_SR``.
    """
    samples, rate = sf.read(str(path), dtype="float32", always_2d=True)

    # always_2d=True yields a 2-D array; transpose so channels come first.
    waveform = torch.from_numpy(samples.T)

    # Down-mix multi-channel audio by averaging the channel rows.
    channel_count = waveform.shape[0]
    if channel_count > 1:
        waveform = waveform.mean(dim=0, keepdim=True)

    # Already at the target rate: nothing more to do.
    if rate == TARGET_SR:
        return waveform, rate

    # torchaudio is only needed on the resampling path, so import it lazily.
    import torchaudio.functional as F
    return F.resample(waveform, rate, TARGET_SR), TARGET_SR


def save_wav(path, wav, sr):
    """Write a waveform tensor to an audio file through ``soundfile``.

    Args:
        path: Destination file; converted with ``str()`` before writing.
        wav: Torch tensor with the samples. A 2-D tensor whose first
            dimension is 1 is flattened to 1-D; any other 2-D tensor is
            passed to ``sf.write`` unchanged.
        sr: Sample rate forwarded to ``sf.write``.
    """
    # Detach from autograd and move to host memory before converting to numpy.
    samples = wav.detach().cpu()
    if samples.ndim == 2:
        samples = samples.squeeze(0)

    # Keep values inside [-1.0, 1.0] before handing them to the writer.
    clipped = np.clip(samples.numpy(), -1.0, 1.0)
    sf.write(str(path), clipped, sr)


# Collect reference clips: for each mood (in MOODS order) take the first
# MAX_REFS_PER_MOOD *.wav files of ROOT_DIR/<mood>, sorted by path.
REFERENCE_WAVS = []

for mood in MOODS:
    candidates = sorted((ROOT_DIR / mood).glob("*.wav"))
    REFERENCE_WAVS.extend(candidates[:MAX_REFS_PER_MOOD])

# Without at least one reference clip there is no prompt audio to build.
if not REFERENCE_WAVS:
    raise FileNotFoundError("Не найдено .wav файлов")

# Report how many clips were selected, then list them one per line.
print("Refs:", len(REFERENCE_WAVS))
print(*REFERENCE_WAVS, sep="\n")

# Load every selected clip (load_wav returns mono audio at TARGET_SR; its
# second return value is not needed here) and join the clips end-to-end along
# dim 1 into a single prompt recording.
# NOTE(review): CosyVoice may reject prompt audio longer than ~30 s; with many
# reference clips the concatenation could exceed that -- confirm.
parts = [load_wav(ref_path)[0] for ref_path in REFERENCE_WAVS]

prompt_wav = torch.cat(parts, dim=1)
save_wav(TMP_PROMPT_WAV, prompt_wav, TARGET_SR)

# Load the model from MODEL_DIR.
cosyvoice = AutoModel(model_dir=MODEL_DIR)

# Both inference methods are called with the same argument layout; only the
# method and the text passed as second argument depend on USE_INSTRUCT.
if USE_INSTRUCT:
    synthesize = cosyvoice.inference_instruct2
    conditioning_text = INSTRUCTION
else:
    synthesize = cosyvoice.inference_zero_shot
    conditioning_text = PROMPT_TEXT

# Materialize every result the inference call yields.
outs = list(
    synthesize(
        TEXT,
        conditioning_text,
        TMP_PROMPT_WAV,
        stream=False,
    )
)

# Join the "tts_speech" tensors of all results along dim 1, write them to
# OUTPUT_WAV at the model's sample rate, and show an audio player for it.
wav = torch.cat([chunk["tts_speech"] for chunk in outs], dim=1)

save_wav(OUTPUT_WAV, wav, cosyvoice.sample_rate)

display(Audio(OUTPUT_WAV))