流式传输

This commit is contained in:
yinsx
2025-12-25 15:36:35 +08:00
parent e56f47076c
commit 14bfdcbf51
19 changed files with 1191 additions and 65 deletions

View File

@ -55,6 +55,23 @@ python api.py
}
```
### POST /text-to-blendshapes/stream
**说明:** 使用 NDJSON 流式返回,便于边收边播放。
**响应:** 每行一个 JSON 对象,`type` 字段取值如下:
- `status` - 阶段提示
- `frame` - 单帧数据
- `end` - 完成信息
- `error` - 错误信息
**示例:**
```json
{"type":"status","stage":"tts","message":"Generating audio"}
{"type":"frame","frame":{"timeCode":0.0,"blendShapes":{"JawOpen":0.1}}}
{"type":"end","frames":900,"audio_path":"...","csv_path":"..."}
```
## 文件说明
- `tts_service.py` - 文字转音频服务

Binary file not shown.

View File

@ -1,34 +1,73 @@
from flask import Flask, request, jsonify
from flask_cors import CORS
import json
from fastapi import FastAPI
from fastapi.middleware.cors import CORSMiddleware
from fastapi.responses import StreamingResponse
from pydantic import BaseModel
from text_to_blendshapes_service import TextToBlendShapesService
app = Flask(__name__)
CORS(app)
app = FastAPI()
@app.route('/health', methods=['GET'])
def health():
return jsonify({'status': 'ok'})
app.add_middleware(
CORSMiddleware,
allow_origins=["*"],
allow_credentials=True,
allow_methods=["*"],
allow_headers=["*"],
)
@app.route('/text-to-blendshapes', methods=['POST'])
def text_to_blendshapes():
class TextRequest(BaseModel):
    """Request body shared by /text-to-blendshapes and its /stream variant."""
    text: str
    language: str = 'zh-CN'
    segment: bool = False
    # Optional tuning knobs; None means "use the service defaults".
    # Annotated as `X | None` so an explicit JSON null also validates —
    # a plain `str = None` annotation rejects null under pydantic.
    # (Union syntax assumes Python >= 3.10 — TODO confirm project floor.)
    split_punctuations: str | None = None
    max_sentence_length: int | None = None
    first_sentence_split_size: int | None = None
@app.get('/health')
async def health():
    """Liveness probe: always returns a constant OK payload."""
    return {'status': 'ok'}
@app.post('/text-to-blendshapes')
async def text_to_blendshapes(request: TextRequest):
    """Convert text to ARKit blendshape frames in a single JSON response.

    Returns the service result dict on success, or
    ``{'success': False, 'error': ...}`` on failure (HTTP status stays 200,
    matching the existing client contract).
    """
    try:
        service = TextToBlendShapesService(lang=request.language)
        result = service.text_to_blend_shapes(
            request.text,
            segment=request.segment,
            split_punctuations=request.split_punctuations,
            max_sentence_length=request.max_sentence_length
        )
        return result
    except Exception as e:
        # Log the full traceback server-side; report a JSON error to clients.
        import traceback
        traceback.print_exc()
        return {'success': False, 'error': str(e)}
@app.post('/text-to-blendshapes/stream')
async def text_to_blendshapes_stream(request: TextRequest):
    """Stream conversion results as NDJSON (one JSON object per line).

    Event ``type`` values: status / frame / end / error (see README).
    """
    # Deliberately a *sync* generator: the underlying TTS/A2F pipeline is
    # blocking, and StreamingResponse iterates sync generators in a worker
    # thread pool, so the event loop is not blocked while audio is generated.
    def generate():
        service = TextToBlendShapesService(lang=request.language)
        try:
            for message in service.iter_text_to_blend_shapes_stream(
                request.text,
                split_punctuations=request.split_punctuations,
                max_sentence_length=request.max_sentence_length,
                first_sentence_split_size=request.first_sentence_split_size
            ):
                yield json.dumps(message) + "\n"
        except Exception as e:
            yield json.dumps({'type': 'error', 'message': str(e)}) + "\n"

    return StreamingResponse(
        generate(),
        media_type='application/x-ndjson',
        headers={
            'Cache-Control': 'no-cache',
            # Disable nginx proxy buffering so frames flush immediately.
            'X-Accel-Buffering': 'no'
        }
    )
if __name__ == '__main__':
    # Dev entry point. This is an ASGI (FastAPI) app, so it must run under
    # uvicorn — the leftover Flask-style `app.run(...)` line was removed.
    import uvicorn
    print("Text to BlendShapes API: http://localhost:5001")
    uvicorn.run(app, host='0.0.0.0', port=5001)

View File

@ -17,9 +17,11 @@ class BlendShapeParser:
@staticmethod
def csv_to_blend_shapes(csv_path: str):
frames = []
return list(BlendShapeParser.iter_csv_to_blend_shapes(csv_path))
@staticmethod
def iter_csv_to_blend_shapes(csv_path: str):
with open(csv_path, 'r') as f:
reader = csv.DictReader(f)
for row in reader:
frame = {'timeCode': float(row['timeCode']), 'blendShapes': {}}
@ -27,5 +29,4 @@ class BlendShapeParser:
col_name = f'blendShapes.{key}'
if col_name in row:
frame['blendShapes'][key] = float(row[col_name])
frames.append(frame)
return frames
yield frame

282
services/a2f_api/test.html Normal file
View File

@ -0,0 +1,282 @@
<!DOCTYPE html>
<html lang="zh-CN">
<head>
<meta charset="UTF-8">
<meta name="viewport" content="width=device-width, initial-scale=1.0">
<title>Text to BlendShapes 测试</title>
<style>
* { margin: 0; padding: 0; box-sizing: border-box; }
body {
font-family: -apple-system, BlinkMacSystemFont, "Segoe UI", sans-serif;
background: linear-gradient(135deg, #667eea 0%, #764ba2 100%);
min-height: 100vh;
padding: 20px;
}
.container {
max-width: 800px;
margin: 0 auto;
background: white;
border-radius: 12px;
padding: 30px;
box-shadow: 0 20px 60px rgba(0,0,0,0.3);
}
h1 {
color: #333;
margin-bottom: 10px;
font-size: 28px;
}
.subtitle {
color: #666;
margin-bottom: 30px;
font-size: 14px;
}
.input-group {
margin-bottom: 20px;
}
label {
display: block;
margin-bottom: 8px;
color: #555;
font-weight: 500;
}
input, textarea, select {
width: 100%;
padding: 12px;
border: 2px solid #e0e0e0;
border-radius: 6px;
font-size: 14px;
transition: border-color 0.3s;
}
input:focus, textarea:focus, select:focus {
outline: none;
border-color: #667eea;
}
textarea {
resize: vertical;
min-height: 100px;
font-family: inherit;
}
button {
width: 100%;
padding: 14px;
background: linear-gradient(135deg, #667eea 0%, #764ba2 100%);
color: white;
border: none;
border-radius: 6px;
font-size: 16px;
font-weight: 600;
cursor: pointer;
transition: transform 0.2s, box-shadow 0.2s;
}
button:hover {
transform: translateY(-2px);
box-shadow: 0 10px 20px rgba(102, 126, 234, 0.3);
}
button:active {
transform: translateY(0);
}
button:disabled {
opacity: 0.6;
cursor: not-allowed;
transform: none;
}
.loading {
display: none;
text-align: center;
margin: 20px 0;
color: #667eea;
}
.loading.show {
display: block;
}
.result {
margin-top: 30px;
padding: 20px;
background: #f8f9fa;
border-radius: 6px;
display: none;
}
.result.show {
display: block;
}
.result h3 {
color: #333;
margin-bottom: 15px;
}
.stats {
display: grid;
grid-template-columns: repeat(auto-fit, minmax(150px, 1fr));
gap: 15px;
margin-bottom: 20px;
}
.stat-card {
background: white;
padding: 15px;
border-radius: 6px;
text-align: center;
}
.stat-value {
font-size: 24px;
font-weight: bold;
color: #667eea;
}
.stat-label {
font-size: 12px;
color: #666;
margin-top: 5px;
}
.frames-preview {
max-height: 300px;
overflow-y: auto;
background: white;
padding: 15px;
border-radius: 6px;
font-family: monospace;
font-size: 12px;
}
.error {
background: #fee;
color: #c33;
padding: 15px;
border-radius: 6px;
margin-top: 20px;
display: none;
}
.error.show {
display: block;
}
</style>
</head>
<body>
<div class="container">
<h1>Text to BlendShapes</h1>
<p class="subtitle">将文字转换为 52 个 ARKit 形态键</p>
<div class="input-group">
<label for="text">输入文字</label>
<textarea id="text" placeholder="请输入要转换的文字...">你好世界,这是一个测试。</textarea>
</div>
<div class="input-group">
<label for="language">语言</label>
<select id="language">
<option value="zh-CN">中文</option>
<option value="en">English</option>
<option value="ja">日本語</option>
<option value="ko">한국어</option>
</select>
</div>
<div class="input-group">
<label for="apiUrl">API 地址</label>
<input type="text" id="apiUrl" value="http://localhost:5001/text-to-blendshapes">
</div>
<button id="submitBtn" onclick="convert()">转换</button>
<div class="loading" id="loading">
<p>⏳ 处理中,请稍候...</p>
</div>
<div class="error" id="error"></div>
<div class="result" id="result">
<h3>转换结果</h3>
<div class="stats">
<div class="stat-card">
<div class="stat-value" id="frameCount">0</div>
<div class="stat-label">总帧数</div>
</div>
<div class="stat-card">
<div class="stat-value" id="duration">0s</div>
<div class="stat-label">时长</div>
</div>
<div class="stat-card">
<div class="stat-value">52</div>
<div class="stat-label">形态键数量</div>
</div>
</div>
<h4 style="margin-bottom: 10px;">帧数据预览</h4>
<div class="frames-preview" id="framesPreview"></div>
</div>
</div>
<script>
// Read the form, POST the text to the API, and render the returned frames.
async function convert() {
    const $ = (id) => document.getElementById(id);
    const text = $('text').value.trim();
    if (!text) {
        showError('请输入文字');
        return;
    }
    const submitBtn = $('submitBtn');
    submitBtn.disabled = true;
    $('loading').classList.add('show');
    $('result').classList.remove('show');
    $('error').classList.remove('show');
    try {
        const response = await fetch($('apiUrl').value, {
            method: 'POST',
            headers: { 'Content-Type': 'application/json' },
            body: JSON.stringify({
                text: text,
                language: $('language').value
            })
        });
        const data = await response.json();
        if (!response.ok || !data.success) {
            throw new Error(data.error || '请求失败');
        }
        displayResult(data);
    } catch (err) {
        showError(err.message);
    } finally {
        submitBtn.disabled = false;
        $('loading').classList.remove('show');
    }
}
// Fill the stats cards and the raw-frame preview, then reveal the panel.
function displayResult(data) {
    const frames = data.frames || [];
    document.getElementById('frameCount').textContent = frames.length;
    if (frames.length > 0) {
        const last = frames[frames.length - 1];
        document.getElementById('duration').textContent = last.timeCode.toFixed(2) + 's';
    }
    let preview = JSON.stringify(frames.slice(0, 3), null, 2);
    if (frames.length > 3) {
        preview += '\n\n... 共 ' + frames.length + ' 帧';
    }
    document.getElementById('framesPreview').textContent = preview;
    document.getElementById('result').classList.add('show');
}
// Show the error banner with the given message.
function showError(message) {
    const banner = document.getElementById('error');
    banner.textContent = '错误: ' + message;
    banner.classList.add('show');
}
</script>
</body>
</html>

View File

@ -1,23 +1,39 @@
import os
import re
import tempfile
import concurrent.futures
import queue
import threading
from datetime import datetime
from tts_service import TTSService
from a2f_service import A2FService
from blend_shape_parser import BlendShapeParser
class TextToBlendShapesService:
DEFAULT_SPLIT_PUNCTUATIONS = '。!?;!?;,'
def __init__(self, lang='zh-CN', a2f_url="192.168.1.39:52000"):
    # Wire up the pipeline stages: TTS -> Audio2Face -> blendshape CSV parser.
    self.tts = TTSService(lang=lang)
    self.a2f = A2FService(a2f_url=a2f_url)
    self.parser = BlendShapeParser()
def text_to_blend_shapes(self, text: str, output_dir: str = None):
if output_dir is None:
output_dir = tempfile.gettempdir()
def text_to_blend_shapes(
self,
text: str,
output_dir: str = None,
segment: bool = False,
split_punctuations: str = None,
max_sentence_length: int = None
):
if segment:
return self._text_to_blend_shapes_segmented(
text,
output_dir,
split_punctuations=split_punctuations,
max_sentence_length=max_sentence_length
)
os.makedirs(output_dir, exist_ok=True)
timestamp = datetime.now().strftime('%Y%m%d%H%M%S')
audio_path = os.path.join(output_dir, f'tts_{timestamp}.wav')
output_dir, audio_path = self._prepare_output_paths(output_dir)
self.tts.text_to_audio(text, audio_path)
csv_path = self.a2f.audio_to_csv(audio_path)
@ -29,3 +45,235 @@ class TextToBlendShapesService:
'audio_path': audio_path,
'csv_path': csv_path
}
def iter_text_to_blend_shapes_stream(
    self,
    text: str,
    output_dir: str = None,
    split_punctuations: str = None,
    max_sentence_length: int = None,
    first_sentence_split_size: int = None
):
    """Yield NDJSON-ready event dicts: status / frame / end / error.

    Sentences are synthesized concurrently (TTS -> A2F -> parse) but frames
    are pushed strictly in sentence order so playback stays correct.

    Fix: the ``end`` event now carries ``audio_paths``/``csv_paths`` — the
    README example documents path fields on ``end`` and
    ``_text_to_blend_shapes_segmented`` reads them, but they were never
    emitted (so the segmented result always had empty path lists).
    """
    output_dir = output_dir or tempfile.gettempdir()
    os.makedirs(output_dir, exist_ok=True)
    sentences = self.split_sentences(
        text,
        split_punctuations=split_punctuations,
        max_sentence_length=max_sentence_length,
        first_sentence_split_size=first_sentence_split_size
    )
    if not sentences:
        yield {'type': 'error', 'message': '文本为空'}
        return
    yield {'type': 'status', 'stage': 'split', 'sentences': len(sentences), 'message': f'已拆分为 {len(sentences)} 个句子'}
    # Workers drop finished sentences here; the consumer re-orders them.
    result_queue = queue.Queue()

    def process_and_queue(index, sentence):
        """Process one sentence and enqueue the result (or the error)."""
        try:
            print(f"[工作线程 {index}] 开始处理: {sentence[:30]}...")
            frames, audio_path, csv_path = self._process_sentence(sentence, output_dir, index)
            result_queue.put((index, 'success', frames, None, audio_path, csv_path))
            print(f"[工作线程 {index}] 完成!已生成 {len(frames)} 帧并加入队列")
        except Exception as e:
            print(f"[工作线程 {index}] 失败: {str(e)}")
            import traceback
            traceback.print_exc()
            result_queue.put((index, 'error', None, str(e), None, None))

    # Fan out all sentences to the pool; consume results inside the `with`
    # so frames stream out while later sentences are still being processed.
    with concurrent.futures.ThreadPoolExecutor(max_workers=5) as executor:
        for index, sentence in enumerate(sentences):
            executor.submit(process_and_queue, index, sentence)

        completed = {}
        next_index = 0
        total_frames = 0
        cumulative_time = 0.0  # running clock for continuation sentences
        audio_paths = []
        csv_paths = []
        while next_index < len(sentences):
            if next_index not in completed:
                yield {
                    'type': 'status',
                    'stage': 'processing',
                    'sentence_index': next_index,
                    'sentences': len(sentences),
                    'message': f'正在处理 {next_index + 1}/{len(sentences)}'
                }
                # Drain the queue until the sentence we need next arrives.
                while next_index not in completed:
                    try:
                        index, status, frames, error, audio_path, csv_path = result_queue.get(timeout=1)
                        completed[index] = (status, frames, error, audio_path, csv_path)
                        print(f"[主线程] 收到句子 {index} 的处理结果")
                    except queue.Empty:
                        continue
            status, frames, error, audio_path, csv_path = completed[next_index]
            if status == 'error':
                yield {'type': 'error', 'message': f'句子 {next_index} 处理失败: {error}'}
                return
            audio_paths.append(audio_path)
            csv_paths.append(csv_path)
            is_continuation = self.is_continuation[next_index] if next_index < len(self.is_continuation) else False
            print(f"[主线程] 正在推送句子 {next_index}:{len(frames)}帧{'(连续)' if is_continuation else ''}")
            # A fresh (non-continuation) sentence restarts the clock.
            if not is_continuation and next_index > 0:
                cumulative_time = 0.0
            for frame in frames:
                # Shift time codes so continuation chunks join seamlessly.
                frame['timeCode'] = cumulative_time + frame['timeCode']
                frame['sentenceIndex'] = next_index
                frame['isContinuation'] = is_continuation
                total_frames += 1
                yield {'type': 'frame', 'frame': frame}
            if frames:
                cumulative_time = frames[-1]['timeCode']
            next_index += 1

    print(f"[主线程] 流式传输完成,共 {total_frames}帧")
    yield {
        'type': 'end',
        'frames': total_frames,
        'audio_paths': audio_paths,
        'csv_paths': csv_paths
    }
def _process_sentence(self, sentence, output_dir, index):
"""处理单个句子: TTS -> A2F -> 解析"""
import time
start_time = time.time()
print(f"[线程 {index}] 开始处理: {sentence[:30]}...")
_, audio_path = self._prepare_output_paths(output_dir, suffix=f's{index:03d}')
print(f"[线程 {index}] TTS 开始...")
tts_start = time.time()
self.tts.text_to_audio(sentence, audio_path)
tts_time = time.time() - tts_start
print(f"[线程 {index}] TTS 完成,耗时 {tts_time:.2f}A2F 开始...")
a2f_start = time.time()
csv_path = self.a2f.audio_to_csv(audio_path)
a2f_time = time.time() - a2f_start
print(f"[线程 {index}] A2F 完成,耗时 {a2f_time:.2f}秒,解析中...")
parse_start = time.time()
frames = list(self.parser.iter_csv_to_blend_shapes(csv_path))
parse_time = time.time() - parse_start
total_time = time.time() - start_time
print(f"[线程 {index}] 完成!生成了 {len(frames)} 帧 | 总耗时: {total_time:.2f}秒 (TTS: {tts_time:.2f}s, A2F: {a2f_time:.2f}s, 解析: {parse_time:.2f}s)")
return frames, audio_path, csv_path
def _text_to_blend_shapes_segmented(
self,
text: str,
output_dir: str = None,
split_punctuations: str = None,
max_sentence_length: int = None
):
frames = []
audio_paths = []
csv_paths = []
for message in self.iter_text_to_blend_shapes_stream(
text,
output_dir,
split_punctuations=split_punctuations,
max_sentence_length=max_sentence_length
):
if message.get('type') == 'frame':
frames.append(message['frame'])
elif message.get('type') == 'error':
return {
'success': False,
'error': message.get('message', 'Unknown error')
}
elif message.get('type') == 'end':
audio_paths = message.get('audio_paths', [])
csv_paths = message.get('csv_paths', [])
return {
'success': True,
'frames': frames,
'audio_paths': audio_paths,
'csv_paths': csv_paths
}
def split_sentences(self, text: str, split_punctuations: str = None, max_sentence_length: int = None, first_sentence_split_size: int = None):
    """Split *text* into sentences; optionally pre-split the first sentence
    to reduce time-to-first-frame.

    Side effect: stores ``self.is_continuation``, one bool per *returned*
    sentence, marking chunks that must play seamlessly after the previous
    one.

    Fix: when ``max_sentence_length`` re-chunks long sentences, the flag
    list is now rebuilt in lockstep with the returned list — previously the
    flags kept pre-chunk indices, so the streaming consumer read the wrong
    (or missing) flag for every chunk after the first split.
    """
    if not text:
        return []
    # Collapse line breaks; punctuation alone drives the splitting.
    normalized = re.sub(r'[\r\n]+', '', text.strip())
    punctuations = split_punctuations or self.DEFAULT_SPLIT_PUNCTUATIONS
    if punctuations:
        escaped = re.escape(punctuations)
        # Lookbehind split keeps the punctuation attached to its sentence.
        split_re = re.compile(rf'(?<=[{escaped}])')
        chunks = split_re.split(normalized)
    else:
        chunks = [normalized]
    sentences = [chunk.strip() for chunk in chunks if chunk.strip()]
    # Which returned sentences continue the previous one without a time reset.
    self.is_continuation = [False] * len(sentences)
    # Optional: split the first sentence so the first frames arrive sooner.
    if first_sentence_split_size and sentences:
        first = sentences[0]
        length = len(first)
        if length <= 12:
            # Up to 12 chars: split into two halves.
            mid = length // 2
            parts = [first[:mid], first[mid:]]
        else:
            # Longer: first 6 chars, next 6, then the remainder.
            parts = [first[:6], first[6:12], first[12:]]
        sentences = parts + sentences[1:]
        # Trailing parts of the split first sentence play continuously.
        self.is_continuation = [False] + [True] * (len(parts) - 1) + [False] * (len(sentences) - len(parts))
        print(f"[拆分优化] 第一句({length}字)拆分为{len(parts)}部分: {[len(p) for p in parts]} - 连续播放")
    if not max_sentence_length or max_sentence_length <= 0:
        return sentences
    # Re-chunk overlong sentences, keeping the flags aligned: the first
    # chunk inherits the sentence's flag, later chunks continue it.
    limited = []
    limited_flags = []
    for sentence, flag in zip(sentences, self.is_continuation):
        if len(sentence) <= max_sentence_length:
            limited.append(sentence)
            limited_flags.append(flag)
            continue
        start = 0
        first_chunk = True
        while start < len(sentence):
            limited.append(sentence[start:start + max_sentence_length])
            limited_flags.append(flag if first_chunk else True)
            first_chunk = False
            start += max_sentence_length
    self.is_continuation = limited_flags
    return limited
def _prepare_output_paths(self, output_dir: str = None, suffix: str = None):
if output_dir is None:
output_dir = tempfile.gettempdir()
os.makedirs(output_dir, exist_ok=True)
timestamp = datetime.now().strftime('%Y%m%d%H%M%S%f')
suffix_part = f'_{suffix}' if suffix else ''
audio_path = os.path.join(output_dir, f'tts_{timestamp}{suffix_part}.wav')
return output_dir, audio_path

View File

@ -1,20 +1,35 @@
import os
import threading
import pyttsx3
class TTSService:
    """Thin wrapper around pyttsx3 for text -> WAV synthesis.

    Reconstructed from the diff's new version: the engine is created fresh
    per call (under a class-wide lock) instead of once in ``__init__``.
    """

    # pyttsx3 engines are not thread-safe; serialize engine use across threads.
    _lock = threading.Lock()

    def __init__(self, lang='zh-CN'):
        # Language hint; voice selection now happens per-call in text_to_audio.
        self.lang = lang

    def text_to_audio(self, text: str, output_path: str) -> str:
        """Synthesize *text* to a WAV file at *output_path* using pyttsx3.

        Returns *output_path*. A fresh engine is created per call because a
        shared engine misbehaves across repeated/threaded use.
        """
        # Fix: os.path.dirname('') is '' and makedirs('') raises — only
        # create the parent when the path actually has one.
        parent = os.path.dirname(output_path)
        if parent:
            os.makedirs(parent, exist_ok=True)
        with self._lock:
            engine = pyttsx3.init()
            try:
                # Prefer a Chinese voice when one is installed.
                # NOTE(review): this runs regardless of self.lang (the old
                # code gated it on lang == 'zh-CN') — confirm intent.
                voices = engine.getProperty('voices')
                for voice in voices:
                    if 'chinese' in voice.name.lower() or 'zh' in voice.id.lower():
                        engine.setProperty('voice', voice.id)
                        break
                # Speaking rate (pyttsx3 'rate' property).
                engine.setProperty('rate', 150)
                engine.save_to_file(text, output_path)
                engine.runAndWait()
                return output_path
            finally:
                # Always tear the engine down, even on failure.
                engine.stop()
                del engine