319 lines
12 KiB
Python
319 lines
12 KiB
Python
import os
|
||
import re
|
||
import tempfile
|
||
import concurrent.futures
|
||
import queue
|
||
import threading
|
||
from datetime import datetime
|
||
from tts_service import TTSService
|
||
from edge_tts_service import EdgeTTSService
|
||
from a2f_service import A2FService
|
||
from blend_shape_parser import BlendShapeParser
|
||
|
||
class TextToBlendShapesService:
    """Pipeline that converts text into facial blend-shape frames.

    Stages: a TTS backend synthesizes audio, the Audio2Face (A2F) service
    converts the audio into a blend-shape CSV, and the parser turns that CSV
    into per-frame dicts suitable for streaming to a client.
    """

    # Punctuation marks (full-width and half-width) treated as sentence
    # boundaries when splitting input text.
    DEFAULT_SPLIT_PUNCTUATIONS = '。!?;!?;,,'
def __init__(self, lang='zh-CN', a2f_url="192.168.1.39:52000", tts_provider='edge-tts'):
|
||
"""
|
||
初始化服务
|
||
:param lang: 语言
|
||
:param a2f_url: A2F服务地址
|
||
:param tts_provider: TTS提供商 ('pyttsx3' 或 'edge-tts')
|
||
"""
|
||
# 根据选择初始化TTS服务
|
||
if tts_provider == 'edge-tts':
|
||
self.tts = EdgeTTSService(lang=lang)
|
||
else:
|
||
self.tts = TTSService(lang=lang)
|
||
|
||
self.a2f = A2FService(a2f_url=a2f_url)
|
||
self.parser = BlendShapeParser()
|
||
|
||
def text_to_blend_shapes(
|
||
self,
|
||
text: str,
|
||
output_dir: str = None,
|
||
segment: bool = False,
|
||
split_punctuations: str = None,
|
||
max_sentence_length: int = None
|
||
):
|
||
if segment:
|
||
return self._text_to_blend_shapes_segmented(
|
||
text,
|
||
output_dir,
|
||
split_punctuations=split_punctuations,
|
||
max_sentence_length=max_sentence_length
|
||
)
|
||
|
||
output_dir, audio_path = self._prepare_output_paths(output_dir)
|
||
|
||
self.tts.text_to_audio(text, audio_path)
|
||
csv_path = self.a2f.audio_to_csv(audio_path)
|
||
frames = self.parser.csv_to_blend_shapes(csv_path)
|
||
|
||
return {
|
||
'success': True,
|
||
'frames': frames,
|
||
'audio_path': audio_path,
|
||
'csv_path': csv_path
|
||
}
|
||
|
||
def iter_text_to_blend_shapes_stream(
|
||
self,
|
||
text: str,
|
||
output_dir: str = None,
|
||
split_punctuations: str = None,
|
||
max_sentence_length: int = None,
|
||
first_sentence_split_size: int = None
|
||
):
|
||
output_dir = output_dir or tempfile.gettempdir()
|
||
os.makedirs(output_dir, exist_ok=True)
|
||
|
||
sentences = self.split_sentences(
|
||
text,
|
||
split_punctuations=split_punctuations,
|
||
max_sentence_length=max_sentence_length,
|
||
first_sentence_split_size=first_sentence_split_size
|
||
)
|
||
if not sentences:
|
||
yield {'type': 'error', 'message': '文本为空'}
|
||
return
|
||
|
||
# 测试:只处理第一句
|
||
sentences = sentences[:1]
|
||
print(f"[测试模式] 只处理第一句: {sentences[0]}")
|
||
|
||
yield {
|
||
'type': 'status',
|
||
'stage': 'split',
|
||
'sentences': len(sentences),
|
||
'sentence_texts': sentences, # 发送句子文本列表
|
||
'message': f'已拆分为 {len(sentences)} 个句子'
|
||
}
|
||
|
||
# 打印句子列表用于调试
|
||
print(f"[调试] 发送给前端的句子列表:")
|
||
for i, s in enumerate(sentences):
|
||
print(f" [{i}] {s}")
|
||
|
||
# 使用队列来收集处理完成的句子
|
||
result_queue = queue.Queue()
|
||
|
||
def process_and_queue(index, sentence):
|
||
"""处理句子并放入队列"""
|
||
try:
|
||
print(f"[工作线程 {index}] 开始处理: {sentence[:30]}...")
|
||
frames, audio_path, csv_path = self._process_sentence(sentence, output_dir, index)
|
||
result_queue.put((index, 'success', frames, None))
|
||
print(f"[工作线程 {index}] 完成!已生成 {len(frames)} 帧并加入队列")
|
||
except Exception as e:
|
||
print(f"[工作线程 {index}] 失败: {str(e)}")
|
||
import traceback
|
||
traceback.print_exc()
|
||
result_queue.put((index, 'error', None, str(e)))
|
||
|
||
# 提交所有句子到线程池并发处理(增加并发数以加速)
|
||
with concurrent.futures.ThreadPoolExecutor(max_workers=5) as executor:
|
||
for index, sentence in enumerate(sentences):
|
||
executor.submit(process_and_queue, index, sentence)
|
||
|
||
# 按顺序从队列中取出结果并推送
|
||
completed = {}
|
||
next_index = 0
|
||
total_frames = 0
|
||
cumulative_time = 0.0 # 累计时间,用于连续句子
|
||
|
||
while next_index < len(sentences):
|
||
# 如果下一个句子还没完成,等待队列
|
||
if next_index not in completed:
|
||
yield {
|
||
'type': 'status',
|
||
'stage': 'processing',
|
||
'sentence_index': next_index,
|
||
'sentences': len(sentences),
|
||
'message': f'正在处理 {next_index + 1}/{len(sentences)}'
|
||
}
|
||
|
||
# 从队列中获取结果
|
||
while next_index not in completed:
|
||
try:
|
||
index, status, frames, error = result_queue.get(timeout=1)
|
||
completed[index] = (status, frames, error)
|
||
print(f"[主线程] 收到句子 {index} 的处理结果")
|
||
except queue.Empty:
|
||
continue
|
||
|
||
# 推送下一个句子的帧
|
||
status, frames, error = completed[next_index]
|
||
if status == 'error':
|
||
yield {'type': 'error', 'message': f'句子 {next_index} 处理失败: {error}'}
|
||
return
|
||
|
||
# 如果是连续句子,调整时间码使其无缝衔接
|
||
is_continuation = self.is_continuation[next_index] if next_index < len(self.is_continuation) else False
|
||
|
||
print(f"[主线程] 正在推送句子 {next_index} 的 {len(frames)} 帧 {'(连续)' if is_continuation else ''}")
|
||
print(f"[调试] 句子 {next_index} 对应文本: {sentences[next_index] if next_index < len(sentences) else 'N/A'}")
|
||
|
||
# 如果不是连续句子,重置累计时间
|
||
if not is_continuation and next_index > 0:
|
||
cumulative_time = 0.0
|
||
|
||
for frame in frames:
|
||
# 调整时间码:从累计时间开始
|
||
frame['timeCode'] = cumulative_time + frame['timeCode']
|
||
frame['sentenceIndex'] = next_index
|
||
total_frames += 1
|
||
yield {'type': 'frame', 'frame': frame}
|
||
|
||
# 更新累计时间为当前句子的最后一帧时间
|
||
if frames:
|
||
cumulative_time = frames[-1]['timeCode']
|
||
|
||
next_index += 1
|
||
|
||
print(f"[主线程] 流式传输完成,共 {total_frames} 帧")
|
||
yield {
|
||
'type': 'end',
|
||
'frames': total_frames
|
||
}
|
||
|
||
def _process_sentence(self, sentence, output_dir, index):
|
||
"""处理单个句子: TTS -> A2F -> 解析"""
|
||
import time
|
||
start_time = time.time()
|
||
|
||
print(f"[线程 {index}] 开始处理: {sentence[:30]}...")
|
||
print(f"[调试] 线程 {index} 实际处理的完整文本: [{sentence}] (长度: {len(sentence)}字)")
|
||
_, audio_path = self._prepare_output_paths(output_dir, suffix=f's{index:03d}')
|
||
|
||
print(f"[线程 {index}] TTS 开始...")
|
||
tts_start = time.time()
|
||
self.tts.text_to_audio(sentence, audio_path)
|
||
tts_time = time.time() - tts_start
|
||
print(f"[线程 {index}] TTS 完成,耗时 {tts_time:.2f}秒,A2F 开始...")
|
||
|
||
a2f_start = time.time()
|
||
csv_path, temp_dir = self.a2f.audio_to_csv(audio_path) # 接收临时目录路径
|
||
a2f_time = time.time() - a2f_start
|
||
print(f"[线程 {index}] A2F 完成,耗时 {a2f_time:.2f}秒,解析中...")
|
||
|
||
parse_start = time.time()
|
||
frames = list(self.parser.iter_csv_to_blend_shapes(csv_path))
|
||
parse_time = time.time() - parse_start
|
||
|
||
# 解析完成后清理临时目录
|
||
import shutil
|
||
try:
|
||
shutil.rmtree(temp_dir, ignore_errors=True)
|
||
print(f"[线程 {index}] 已清理临时目录: {temp_dir}")
|
||
except Exception as e:
|
||
print(f"[线程 {index}] 清理临时目录失败: {e}")
|
||
|
||
total_time = time.time() - start_time
|
||
print(f"[线程 {index}] 完成!生成了 {len(frames)} 帧 | 总耗时: {total_time:.2f}秒 (TTS: {tts_time:.2f}s, A2F: {a2f_time:.2f}s, 解析: {parse_time:.2f}s)")
|
||
|
||
return frames, audio_path, csv_path
|
||
|
||
def _text_to_blend_shapes_segmented(
|
||
self,
|
||
text: str,
|
||
output_dir: str = None,
|
||
split_punctuations: str = None,
|
||
max_sentence_length: int = None
|
||
):
|
||
frames = []
|
||
audio_paths = []
|
||
csv_paths = []
|
||
|
||
for message in self.iter_text_to_blend_shapes_stream(
|
||
text,
|
||
output_dir,
|
||
split_punctuations=split_punctuations,
|
||
max_sentence_length=max_sentence_length
|
||
):
|
||
if message.get('type') == 'frame':
|
||
frames.append(message['frame'])
|
||
elif message.get('type') == 'error':
|
||
return {
|
||
'success': False,
|
||
'error': message.get('message', 'Unknown error')
|
||
}
|
||
elif message.get('type') == 'end':
|
||
audio_paths = message.get('audio_paths', [])
|
||
csv_paths = message.get('csv_paths', [])
|
||
|
||
return {
|
||
'success': True,
|
||
'frames': frames,
|
||
'audio_paths': audio_paths,
|
||
'csv_paths': csv_paths
|
||
}
|
||
|
||
def split_sentences(self, text: str, split_punctuations: str = None, max_sentence_length: int = None, first_sentence_split_size: int = None):
|
||
"""拆分句子,并对第一句进行特殊处理以加速首帧"""
|
||
if not text:
|
||
return []
|
||
|
||
normalized = re.sub(r'[\r\n]+', '。', text.strip())
|
||
punctuations = split_punctuations or self.DEFAULT_SPLIT_PUNCTUATIONS
|
||
if punctuations:
|
||
escaped = re.escape(punctuations)
|
||
split_re = re.compile(rf'(?<=[{escaped}])')
|
||
chunks = split_re.split(normalized)
|
||
else:
|
||
chunks = [normalized]
|
||
|
||
sentences = [chunk.strip() for chunk in chunks if chunk.strip()]
|
||
|
||
# 记录哪些句子是拆分的(需要连续播放)
|
||
self.is_continuation = [False] * len(sentences)
|
||
|
||
# 可选:拆分第一句以加速首帧(并发处理)
|
||
if first_sentence_split_size and sentences:
|
||
first = sentences[0]
|
||
length = len(first)
|
||
parts = []
|
||
|
||
if length <= 8:
|
||
# 8字以下不拆分
|
||
parts = [first]
|
||
elif length <= 12:
|
||
# 8-12字分两部分
|
||
mid = length // 2
|
||
parts = [first[:mid], first[mid:]]
|
||
else:
|
||
# 12字以上:前6字,再6字,剩下的
|
||
parts = [first[:6], first[6:12], first[12:]]
|
||
|
||
# 替换第一句为多个小句
|
||
sentences = parts + sentences[1:]
|
||
# 标记后续部分为连续播放
|
||
self.is_continuation = [False] + [True] * (len(parts) - 1) + [False] * (len(sentences) - len(parts))
|
||
print(f"[拆分优化] 第一句({length}字)拆分为{len(parts)}部分: {[len(p) for p in parts]} - 连续播放")
|
||
|
||
if not max_sentence_length or max_sentence_length <= 0:
|
||
return sentences
|
||
|
||
limited = []
|
||
for sentence in sentences:
|
||
if len(sentence) <= max_sentence_length:
|
||
limited.append(sentence)
|
||
continue
|
||
|
||
start = 0
|
||
while start < len(sentence):
|
||
limited.append(sentence[start:start + max_sentence_length])
|
||
start += max_sentence_length
|
||
return limited
|
||
|
||
def _prepare_output_paths(self, output_dir: str = None, suffix: str = None):
|
||
if output_dir is None:
|
||
output_dir = tempfile.gettempdir()
|
||
|
||
os.makedirs(output_dir, exist_ok=True)
|
||
timestamp = datetime.now().strftime('%Y%m%d%H%M%S%f')
|
||
suffix_part = f'_{suffix}' if suffix else ''
|
||
audio_path = os.path.join(output_dir, f'tts_{timestamp}{suffix_part}.wav')
|
||
return output_dir, audio_path
|