Files
a2f-service/services/a2f_api/text_to_blendshapes_service.py
2025-12-29 11:22:51 +08:00

319 lines
12 KiB
Python
Raw Blame History

This file contains ambiguous Unicode characters

This file contains Unicode characters that might be confused with other characters. If you think that this is intentional, you can safely ignore this warning. Use the Escape button to reveal them.

import os
import re
import tempfile
import concurrent.futures
import queue
import threading
from datetime import datetime
from tts_service import TTSService
from edge_tts_service import EdgeTTSService
from a2f_service import A2FService
from blend_shape_parser import BlendShapeParser
class TextToBlendShapesService:
    """Convert text into blend-shape animation frames.

    The pipeline is: text -> TTS audio file -> A2F service CSV -> parsed
    frames. Text can be processed in one shot (``text_to_blend_shapes``) or
    sentence by sentence with concurrent workers and ordered streaming of
    the results (``iter_text_to_blend_shapes_stream``).
    """

    # Sentence delimiters: full-width CJK forms followed by their ASCII
    # counterparts. The full-width characters are written as escapes so they
    # cannot be silently replaced by look-alike ASCII characters.
    #   \u3002 = ideographic full stop, \uff01 = full-width '!',
    #   \uff1f = full-width '?', \uff1b = full-width ';', \uff0c = full-width ','
    DEFAULT_SPLIT_PUNCTUATIONS = '\u3002\uff01\uff1f\uff1b\uff0c!?;,'

    def __init__(self, lang='zh-CN', a2f_url="192.168.1.39:52000", tts_provider='edge-tts'):
        """
        Create the TTS, A2F and parser collaborators.

        :param lang: language passed to the TTS backend
        :param a2f_url: address of the A2F service
        :param tts_provider: 'edge-tts' selects EdgeTTSService; any other
            value selects TTSService
        """
        # Pick the TTS backend according to the requested provider.
        if tts_provider == 'edge-tts':
            self.tts = EdgeTTSService(lang=lang)
        else:
            self.tts = TTSService(lang=lang)
        self.a2f = A2FService(a2f_url=a2f_url)
        self.parser = BlendShapeParser()
        # Continuation flags of the most recent split_sentences() call.
        self.is_continuation = []

    def text_to_blend_shapes(
        self,
        text: str,
        output_dir: str = None,
        segment: bool = False,
        split_punctuations: str = None,
        max_sentence_length: int = None
    ):
        """
        Convert text to blend-shape frames in a single call.

        :param text: text to synthesize
        :param output_dir: directory for the generated audio file; the system
            temp directory is used when omitted
        :param segment: when True, split the text into sentences and process
            them through the streaming pipeline
        :param split_punctuations: sentence delimiters (segmented mode only)
        :param max_sentence_length: maximum sentence length (segmented mode only)
        :return: dict with 'success' and 'frames'; the non-segmented path adds
            'audio_path' and 'csv_path', the segmented path adds 'audio_paths'
            and 'csv_paths' (or 'error' on failure)
        """
        if segment:
            return self._text_to_blend_shapes_segmented(
                text,
                output_dir,
                split_punctuations=split_punctuations,
                max_sentence_length=max_sentence_length
            )
        output_dir, audio_path = self._prepare_output_paths(output_dir)
        self.tts.text_to_audio(text, audio_path)
        # The temp directory is intentionally not removed here because
        # csv_path is handed back to the caller.
        csv_path, _temp_dir = self._audio_to_csv(audio_path)
        frames = self.parser.csv_to_blend_shapes(csv_path)
        return {
            'success': True,
            'frames': frames,
            'audio_path': audio_path,
            'csv_path': csv_path
        }

    def iter_text_to_blend_shapes_stream(
        self,
        text: str,
        output_dir: str = None,
        split_punctuations: str = None,
        max_sentence_length: int = None,
        first_sentence_split_size: int = None,
        max_sentences: int = None
    ):
        """
        Process text sentence by sentence and yield progress/frame messages.

        Sentences are processed concurrently by a thread pool, but their
        frames are yielded strictly in sentence order.

        :param text: text to synthesize
        :param output_dir: directory for generated audio; the system temp
            directory is used when omitted
        :param split_punctuations: sentence delimiters
        :param max_sentence_length: maximum sentence length
        :param first_sentence_split_size: when truthy, the first sentence is
            cut into smaller parts to reduce time-to-first-frame
        :param max_sentences: optional debugging limit on the number of
            sentences processed; None (default) processes every sentence
        :return: generator of dicts whose 'type' is 'status', 'frame',
            'error' or 'end'
        """
        output_dir = output_dir or tempfile.gettempdir()
        os.makedirs(output_dir, exist_ok=True)
        # Keep the continuation flags in a local list so concurrent calls on
        # the same service instance cannot overwrite each other's flags.
        sentences, continuation = self._split_with_flags(
            text,
            split_punctuations=split_punctuations,
            max_sentence_length=max_sentence_length,
            first_sentence_split_size=first_sentence_split_size
        )
        if not sentences:
            yield {'type': 'error', 'message': '文本为空'}
            return
        # Optional debugging limit (replaces a hard-coded "first sentence
        # only" test mode).
        if max_sentences is not None and max_sentences > 0:
            sentences = sentences[:max_sentences]
            continuation = continuation[:max_sentences]
            print(f"[测试模式] 只处理前 {len(sentences)} 句")
        # Still published on the instance for code that reads the attribute.
        self.is_continuation = list(continuation)
        yield {
            'type': 'status',
            'stage': 'split',
            'sentences': len(sentences),
            'sentence_texts': sentences,  # list of sentence texts
            'message': f'已拆分为 {len(sentences)} 个句子'
        }
        # Print the sentence list for debugging.
        print(f"[调试] 发送给前端的句子列表:")
        for i, s in enumerate(sentences):
            print(f"  [{i}] {s}")
        # Workers push finished sentences into this queue.
        result_queue = queue.Queue()

        def process_and_queue(index, sentence):
            """Process one sentence and put the outcome on the queue."""
            try:
                print(f"[工作线程 {index}] 开始处理: {sentence[:30]}...")
                frames, _audio_path, _csv_path = self._process_sentence(sentence, output_dir, index)
                result_queue.put((index, 'success', frames, None))
                print(f"[工作线程 {index}] 完成!已生成 {len(frames)} 帧并加入队列")
            except Exception as e:
                # Worker boundary: report the failure to the consumer instead
                # of losing it inside the pool.
                print(f"[工作线程 {index}] 失败: {str(e)}")
                import traceback
                traceback.print_exc()
                result_queue.put((index, 'error', None, str(e)))

        # Submit every sentence to the pool. Results are consumed inside the
        # ``with`` block: leaving it waits for all workers, so consuming
        # afterwards would delay the first frame until everything is done.
        with concurrent.futures.ThreadPoolExecutor(max_workers=5) as executor:
            for index, sentence in enumerate(sentences):
                executor.submit(process_and_queue, index, sentence)
            # Results may arrive out of order; park them until it is their turn.
            completed = {}
            next_index = 0
            total_frames = 0
            cumulative_time = 0.0  # time offset applied to continuation sentences
            while next_index < len(sentences):
                # If the next sentence is not finished yet, report and wait.
                if next_index not in completed:
                    yield {
                        'type': 'status',
                        'stage': 'processing',
                        'sentence_index': next_index,
                        'sentences': len(sentences),
                        'message': f'正在处理 {next_index + 1}/{len(sentences)}'
                    }
                # Drain the queue until the next sentence is available.
                while next_index not in completed:
                    try:
                        index, status, frames, error = result_queue.get(timeout=1)
                    except queue.Empty:
                        continue
                    completed[index] = (status, frames, error)
                    print(f"[主线程] 收到句子 {index} 的处理结果")
                # Emit the frames of the next sentence.
                status, frames, error = completed[next_index]
                if status == 'error':
                    yield {'type': 'error', 'message': f'句子 {next_index} 处理失败: {error}'}
                    return
                # Continuation sentences have their time codes shifted so they
                # follow the previous part seamlessly.
                is_continuation = continuation[next_index]
                print(f"[主线程] 正在推送句子 {next_index},共 {len(frames)} 帧{'(连续)' if is_continuation else ''}")
                print(f"[调试] 句子 {next_index} 对应文本: {sentences[next_index]}")
                # A non-continuation sentence restarts the timeline.
                if not is_continuation and next_index > 0:
                    cumulative_time = 0.0
                for frame in frames:
                    # Shift the time code by the accumulated offset.
                    frame['timeCode'] = cumulative_time + frame['timeCode']
                    frame['sentenceIndex'] = next_index
                    total_frames += 1
                    yield {'type': 'frame', 'frame': frame}
                # The next continuation starts at this sentence's last frame.
                if frames:
                    cumulative_time = frames[-1]['timeCode']
                next_index += 1
        print(f"[主线程] 流式传输完成,共 {total_frames}")
        yield {
            'type': 'end',
            'frames': total_frames
        }

    def _audio_to_csv(self, audio_path):
        """Run A2F on an audio file and return ``(csv_path, temp_dir)``.

        Accepts both a ``(csv_path, temp_dir)`` pair and a bare CSV path from
        the A2F service; ``temp_dir`` is None in the latter case.
        """
        result = self.a2f.audio_to_csv(audio_path)
        if isinstance(result, (tuple, list)):
            csv_path = result[0]
            temp_dir = result[1] if len(result) > 1 else None
            return csv_path, temp_dir
        return result, None

    def _process_sentence(self, sentence, output_dir, index):
        """Process a single sentence: TTS -> A2F -> parse.

        :param sentence: sentence text
        :param output_dir: directory for the generated audio file
        :param index: sentence index, used in the file suffix and log lines
        :return: tuple ``(frames, audio_path, csv_path)``; the CSV lives in
            the A2F temp directory, which is removed before returning
        """
        import shutil
        import time
        start_time = time.time()
        print(f"[线程 {index}] 开始处理: {sentence[:30]}...")
        print(f"[调试] 线程 {index} 实际处理的完整文本: [{sentence}] (长度: {len(sentence)}字)")
        _, audio_path = self._prepare_output_paths(output_dir, suffix=f's{index:03d}')
        print(f"[线程 {index}] TTS 开始...")
        tts_start = time.time()
        self.tts.text_to_audio(sentence, audio_path)
        tts_time = time.time() - tts_start
        print(f"[线程 {index}] TTS 完成,耗时 {tts_time:.2f}秒,A2F 开始...")
        a2f_start = time.time()
        csv_path, temp_dir = self._audio_to_csv(audio_path)
        a2f_time = time.time() - a2f_start
        print(f"[线程 {index}] A2F 完成,耗时 {a2f_time:.2f}秒,解析中...")
        parse_start = time.time()
        try:
            frames = list(self.parser.iter_csv_to_blend_shapes(csv_path))
            parse_time = time.time() - parse_start
        finally:
            # Remove the A2F temp directory even when parsing fails.
            if temp_dir:
                shutil.rmtree(temp_dir, ignore_errors=True)
                print(f"[线程 {index}] 已清理临时目录: {temp_dir}")
        total_time = time.time() - start_time
        print(f"[线程 {index}] 完成!生成了 {len(frames)} 帧 | 总耗时: {total_time:.2f}秒 (TTS: {tts_time:.2f}s, A2F: {a2f_time:.2f}s, 解析: {parse_time:.2f}s)")
        return frames, audio_path, csv_path

    def _text_to_blend_shapes_segmented(
        self,
        text: str,
        output_dir: str = None,
        split_punctuations: str = None,
        max_sentence_length: int = None
    ):
        """Collect the streaming pipeline's output into a single result dict.

        :return: ``{'success': True, 'frames', 'audio_paths', 'csv_paths'}``,
            or ``{'success': False, 'error'}`` on the first error message
        """
        frames = []
        audio_paths = []
        csv_paths = []
        for message in self.iter_text_to_blend_shapes_stream(
            text,
            output_dir,
            split_punctuations=split_punctuations,
            max_sentence_length=max_sentence_length
        ):
            kind = message.get('type')
            if kind == 'frame':
                frames.append(message['frame'])
            elif kind == 'error':
                return {
                    'success': False,
                    'error': message.get('message', 'Unknown error')
                }
            elif kind == 'end':
                # NOTE(review): the stream's 'end' message carries only
                # 'frames', so both lists stay empty at present.
                audio_paths = message.get('audio_paths', [])
                csv_paths = message.get('csv_paths', [])
        return {
            'success': True,
            'frames': frames,
            'audio_paths': audio_paths,
            'csv_paths': csv_paths
        }

    def split_sentences(self, text: str, split_punctuations: str = None, max_sentence_length: int = None, first_sentence_split_size: int = None):
        """Split text into sentences, optionally cutting the first one up.

        Also stores ``self.is_continuation``: one boolean per returned
        sentence, True when the sentence continues the previous one.

        :param text: text to split
        :param split_punctuations: characters after which to split; defaults
            to DEFAULT_SPLIT_PUNCTUATIONS
        :param max_sentence_length: when positive, longer sentences are cut
            into pieces of at most this many characters
        :param first_sentence_split_size: when truthy, the first sentence is
            cut into up to three parts to reduce time-to-first-frame
        :return: list of sentence strings
        """
        sentences, flags = self._split_with_flags(
            text,
            split_punctuations=split_punctuations,
            max_sentence_length=max_sentence_length,
            first_sentence_split_size=first_sentence_split_size
        )
        self.is_continuation = flags
        return sentences

    def _split_with_flags(self, text, split_punctuations=None, max_sentence_length=None, first_sentence_split_size=None):
        """Return ``(sentences, continuation_flags)`` as index-aligned lists."""
        if not text:
            return [], []
        normalized = re.sub(r'[\r\n]+', '', text.strip())
        punctuations = split_punctuations or self.DEFAULT_SPLIT_PUNCTUATIONS
        if punctuations:
            # Zero-width split right after any delimiter so it stays attached
            # to the sentence it ends.
            boundary = re.compile(rf'(?<=[{re.escape(punctuations)}])')
            chunks = boundary.split(normalized)
        else:
            chunks = [normalized]
        sentences = [chunk.strip() for chunk in chunks if chunk.strip()]
        flags = [False] * len(sentences)
        # Optional: cut the first sentence up so its parts can be processed
        # concurrently and the first frames arrive sooner.
        if first_sentence_split_size and sentences:
            first = sentences[0]
            length = len(first)
            if length <= 8:
                # Up to 8 characters: keep as is.
                parts = [first]
            elif length <= 12:
                # 9-12 characters: two halves.
                mid = length // 2
                parts = [first[:mid], first[mid:]]
            else:
                # More than 12 characters: 6 + 6 + remainder.
                parts = [first[:6], first[6:12], first[12:]]
            rest = sentences[1:]
            sentences = parts + rest
            # Every part after the first continues the preceding one.
            flags = [False] + [True] * (len(parts) - 1) + [False] * len(rest)
            print(f"[拆分优化] 第一句({length}字)拆分为{len(parts)}部分: {[len(p) for p in parts]} - 连续播放")
        if not max_sentence_length or max_sentence_length <= 0:
            return sentences, flags
        limited = []
        limited_flags = []
        for sentence, flag in zip(sentences, flags):
            pieces = [
                sentence[start:start + max_sentence_length]
                for start in range(0, len(sentence), max_sentence_length)
            ]
            limited.extend(pieces)
            # Each piece inherits its parent's flag so both lists stay
            # index-aligned.
            limited_flags.extend([flag] * len(pieces))
        return limited, limited_flags

    def _prepare_output_paths(self, output_dir: str = None, suffix: str = None):
        """Ensure the output directory exists and build a unique audio path.

        :param output_dir: target directory; the system temp directory when None
        :param suffix: optional tag appended to the file name
        :return: tuple ``(output_dir, audio_path)``
        """
        if output_dir is None:
            output_dir = tempfile.gettempdir()
        os.makedirs(output_dir, exist_ok=True)
        # Microsecond-resolution timestamp keeps file names distinct.
        timestamp = datetime.now().strftime('%Y%m%d%H%M%S%f')
        suffix_part = f'_{suffix}' if suffix else ''
        audio_path = os.path.join(output_dir, f'tts_{timestamp}{suffix_part}.wav')
        return output_dir, audio_path