改进建议¶
基于对 DocuSnap-Backend 代码库的质量评估,本页面提供具体的改进建议,帮助提升代码质量、可维护性和可扩展性。
代码结构改进¶
1. 模块化重构¶
当前代码主要集中在 app.py
文件中,建议将代码拆分为多个模块文件:
DocuSnap-Backend/
├── app.py # 主应用入口,只包含 Flask 应用配置和路由定义
├── modules/
│ ├── __init__.py
│ ├── ocr.py # OCR 处理相关功能
│ ├── llm.py # LLM 处理相关功能
│ ├── security.py # 安全加密相关功能
│ ├── cache.py # 缓存和数据持久化相关功能
│ └── tasks.py # 任务处理相关功能
├── utils/
│ ├── __init__.py
│ ├── db.py # 数据库操作工具
│ ├── image.py # 图像处理工具
│ └── validation.py # 请求验证工具
├── config/
│ ├── __init__.py
│ └── settings.py # 配置参数
├── prompts/
│ ├── __init__.py
│ ├── document.py # 文档处理提示
│ ├── form.py # 表单处理提示
│ └── form_filling.py # 表单填充提示
└── api/
├── __init__.py
├── document.py # 文档处理 API
├── form.py # 表单处理 API
└── task.py # 任务状态 API
这种结构将代码按功能模块划分,提高了代码的组织性和可维护性。
2. 引入层次结构¶
建议引入更清晰的层次结构,将代码分为以下几层:
- API 层:处理 HTTP 请求和响应
- 服务层:实现业务逻辑
- 数据访问层:处理数据存储和检索
- 外部服务层:与外部服务(OCR、LLM)交互
示例实现:
# API 层 (api/document.py)
@app.route('/api/process_document', methods=['POST'])
def process_document():
# 请求解析和验证
data = request.get_json()
decrypted_data, aes_key = security_service.decrypt_request(data)
# 调用服务层
task_id = document_service.process_document(decrypted_data)
# 响应生成
response = security_service.encrypt_response({'task_id': task_id}, aes_key)
return jsonify(response), 202
# 服务层 (modules/document_service.py)
def process_document(data):
# 业务逻辑
cache_key = generate_cache_key(data)
cached_result = cache_repository.get_cache(cache_key)
if cached_result:
return cached_result
task_id = str(uuid.uuid4())
task = create_task(task_id, 'document', data, cache_key)
task_queue.put(task)
return task_id
# 数据访问层 (modules/cache_repository.py)
def get_cache(cache_key):
conn = get_db_connection()
cursor = conn.cursor()
cursor.execute(
"SELECT result FROM cache WHERE key = ? AND expires_at > ?",
(cache_key, int(time.time()))
)
result = cursor.fetchone()
conn.close()
if result:
return json.loads(result[0])
else:
return None
# 外部服务层 (modules/ocr_service.py)
def call_ocr_service(image_data):
try:
files = {'image': ('image.png', image_data, 'image/png')}
response = requests.post(OCR_SERVICE_URL, files=files, timeout=OCR_TIMEOUT)
if response.status_code == 200:
return response.json()
else:
raise Exception(f"OCR 服务返回错误: {response.status_code}")
except requests.RequestException as e:
raise Exception(f"OCR 服务调用失败: {str(e)}")
这种层次结构提高了代码的组织性和可维护性,同时也便于单元测试和功能扩展。
代码质量改进¶
1. 引入面向对象编程¶
当前代码主要使用函数式编程风格,建议引入更多的面向对象编程,使用类和对象组织代码:
# 任务处理器类
class TaskProcessor:
def __init__(self, db_connection, ocr_service, llm_service):
self.db_connection = db_connection
self.ocr_service = ocr_service
self.llm_service = llm_service
self.task_queue = Queue()
self.workers = []
def start_workers(self, num_workers=4):
for _ in range(num_workers):
worker = threading.Thread(target=self._worker_loop, daemon=True)
worker.start()
self.workers.append(worker)
def add_task(self, task):
self.task_queue.put(task)
return task['id']
def _worker_loop(self):
while True:
task = self.task_queue.get()
try:
self._process_task(task)
except Exception as e:
self._update_task_status(task['id'], 'error', str(e))
finally:
self.task_queue.task_done()
def _process_task(self, task):
# 根据任务类型选择处理策略
if task['type'] == 'document':
self._process_document_task(task)
elif task['type'] == 'form':
self._process_form_task(task)
elif task['type'] == 'form_filling':
self._process_form_filling_task(task)
def _process_document_task(self, task):
# 文档处理逻辑
pass
def _update_task_status(self, task_id, status, result=None):
# 更新任务状态
pass
这种面向对象的设计提高了代码的组织性和可维护性,同时也便于依赖注入和单元测试。
2. 引入依赖注入¶
当前代码中的依赖关系是硬编码的,建议引入依赖注入,提高代码的灵活性和可测试性:
# 依赖注入容器
class Container:
def __init__(self):
self.services = {}
def register(self, name, instance):
self.services[name] = instance
def get(self, name):
if name not in self.services:
raise Exception(f"Service '{name}' not registered")
return self.services[name]
# 应用初始化
def init_app():
container = Container()
# 注册服务
db_connection = get_db_connection()
container.register('db_connection', db_connection)
ocr_service = OCRService(OCR_SERVICE_URL, OCR_TIMEOUT)
container.register('ocr_service', ocr_service)
llm_service = LLMService(LLM_API_KEY, LLM_MODEL)
container.register('llm_service', llm_service)
task_processor = TaskProcessor(
db_connection=container.get('db_connection'),
ocr_service=container.get('ocr_service'),
llm_service=container.get('llm_service')
)
container.register('task_processor', task_processor)
return container
# 使用容器
container = init_app()
task_processor = container.get('task_processor')
task_processor.start_workers()
这种依赖注入的设计提高了代码的灵活性和可测试性,同时也便于替换底层实现。
3. 添加单元测试¶
当前代码缺乏单元测试,建议添加单元测试,提高代码质量和可靠性:
# tests/test_ocr_service.py
import unittest
from unittest.mock import patch, Mock
from modules.ocr_service import OCRService
class TestOCRService(unittest.TestCase):
def setUp(self):
self.ocr_service = OCRService('http://example.com/ocr', 10)
@patch('modules.ocr_service.requests.post')
def test_call_ocr_service_success(self, mock_post):
# 模拟成功响应
mock_response = Mock()
mock_response.status_code = 200
mock_response.json.return_value = {'text': 'Hello World'}
mock_post.return_value = mock_response
# 调用服务
result = self.ocr_service.call_ocr_service(b'image_data')
# 验证结果
self.assertEqual(result, {'text': 'Hello World'})
mock_post.assert_called_once()
@patch('modules.ocr_service.requests.post')
def test_call_ocr_service_error(self, mock_post):
# 模拟错误响应
mock_response = Mock()
mock_response.status_code = 500
mock_post.return_value = mock_response
# 验证异常
with self.assertRaises(Exception) as context:
self.ocr_service.call_ocr_service(b'image_data')
self.assertIn('OCR 服务返回错误: 500', str(context.exception))
mock_post.assert_called_once()
# 运行测试
if __name__ == '__main__':
unittest.main()
单元测试可以验证代码的正确性,发现潜在的问题,并支持重构和优化。
可扩展性改进¶
1. 引入抽象接口¶
当前代码直接使用具体实现,建议引入抽象接口,提高系统的灵活性和可扩展性:
# 抽象接口
class OCRService:
def process_image(self, image_data):
"""处理图像并返回文本结果"""
raise NotImplementedError("子类必须实现此方法")
# 具体实现
class CnOCRService(OCRService):
def __init__(self, service_url, timeout):
self.service_url = service_url
self.timeout = timeout
def process_image(self, image_data):
"""使用 CnOCR 服务处理图像"""
try:
files = {'image': ('image.png', image_data, 'image/png')}
response = requests.post(self.service_url, files=files, timeout=self.timeout)
if response.status_code == 200:
ocr_result = response.json()
return self._extract_text(ocr_result)
else:
raise Exception(f"OCR 服务返回错误: {response.status_code}")
except requests.RequestException as e:
raise Exception(f"OCR 服务调用失败: {str(e)}")
def _extract_text(self, ocr_result):
"""从 OCR 结果中提取文本"""
if 'text' not in ocr_result:
raise Exception("OCR 结果格式错误")
return ocr_result['text'].strip()
# 使用示例
ocr_service = CnOCRService(OCR_SERVICE_URL, OCR_TIMEOUT)
text = ocr_service.process_image(image_data)
这种抽象接口的设计提高了系统的灵活性和可扩展性,可以轻松替换底层实现,如更换 OCR 服务提供商。
2. 引入插件机制¶
当前系统缺乏插件机制,建议引入插件机制,支持动态扩展功能:
# 插件管理器
class PluginManager:
def __init__(self):
self.plugins = {}
def register_plugin(self, name, plugin):
"""注册插件"""
self.plugins[name] = plugin
def get_plugin(self, name):
"""获取插件"""
if name not in self.plugins:
raise Exception(f"Plugin '{name}' not registered")
return self.plugins[name]
def list_plugins(self):
"""列出所有插件"""
return list(self.plugins.keys())
# 插件接口
class ProcessorPlugin:
def process(self, task):
"""处理任务"""
raise NotImplementedError("子类必须实现此方法")
# 具体插件实现
class DocumentProcessor(ProcessorPlugin):
def __init__(self, ocr_service, llm_service):
self.ocr_service = ocr_service
self.llm_service = llm_service
def process(self, task):
"""处理文档任务"""
# 文档处理逻辑
pass
# 使用示例
plugin_manager = PluginManager()
plugin_manager.register_plugin('document', DocumentProcessor(ocr_service, llm_service))
plugin_manager.register_plugin('form', FormProcessor(ocr_service, llm_service))
# 处理任务
processor = plugin_manager.get_plugin(task['type'])
result = processor.process(task)
这种插件机制的设计提高了系统的可扩展性,可以轻松添加新的功能模块,如新的任务类型。
安全性改进¶
1. 改进密钥管理¶
当前系统的密钥管理相对简单,建议改进密钥管理机制:
# 密钥管理器
class KeyManager:
def __init__(self, key_store_path):
self.key_store_path = key_store_path
self.keys = {}
self.load_keys()
def load_keys(self):
"""加载密钥"""
try:
with open(self.key_store_path, 'rb') as f:
encrypted_keys = f.read()
# 使用主密钥解密密钥存储
master_key = os.environ.get('MASTER_KEY')
if not master_key:
raise Exception("Missing MASTER_KEY environment variable")
decrypted_keys = self._decrypt_with_master_key(encrypted_keys, master_key)
self.keys = json.loads(decrypted_keys)
except Exception as e:
raise Exception(f"Failed to load keys: {str(e)}")
def get_private_key(self):
"""获取 RSA 私钥"""
if 'private_key' not in self.keys:
raise Exception("Private key not found")
private_key_data = self.keys['private_key']
return RSA.import_key(private_key_data)
def rotate_keys(self):
"""轮换密钥"""
# 生成新的密钥对
key = RSA.generate(2048)
private_key = key.export_key()
public_key = key.publickey().export_key()
# 更新密钥存储
self.keys['private_key'] = private_key.decode('utf-8')
self.keys['public_key'] = public_key.decode('utf-8')
# 保存密钥
self._save_keys()
return {
'private_key': private_key,
'public_key': public_key
}
def _save_keys(self):
"""保存密钥"""
# 使用主密钥加密密钥存储
master_key = os.environ.get('MASTER_KEY')
if not master_key:
raise Exception("Missing MASTER_KEY environment variable")
keys_json = json.dumps(self.keys)
encrypted_keys = self._encrypt_with_master_key(keys_json, master_key)
with open(self.key_store_path, 'wb') as f:
f.write(encrypted_keys)
def _encrypt_with_master_key(self, data, master_key):
"""使用主密钥加密数据"""
# 加密实现
pass
def _decrypt_with_master_key(self, encrypted_data, master_key):
"""使用主密钥解密数据"""
# 解密实现
pass
这种密钥管理的设计提高了系统的安全性,支持密钥轮换和安全存储。
2. 添加访问控制¶
当前系统缺乏细粒度的访问控制,建议添加访问控制机制:
# 访问控制装饰器
def require_auth(func):
@wraps(func)
def wrapper(*args, **kwargs):
# 获取认证信息
auth_header = request.headers.get('Authorization')
if not auth_header:
return jsonify({"error": "Missing Authorization header"}), 401
try:
# 验证认证信息
token = auth_header.split(' ')[1]
payload = jwt.decode(token, JWT_SECRET_KEY, algorithms=['HS256'])
# 检查权限
if 'permissions' not in payload:
return jsonify({"error": "Invalid token"}), 401
# 获取所需权限
required_permission = getattr(func, '_required_permission', None)
if required_permission and required_permission not in payload['permissions']:
return jsonify({"error": "Insufficient permissions"}), 403
# 将用户信息添加到请求上下文
g.user = payload
return func(*args, **kwargs)
except jwt.InvalidTokenError:
return jsonify({"error": "Invalid token"}), 401
return wrapper
# 权限装饰器
def require_permission(permission):
def decorator(func):
func._required_permission = permission
return func
return decorator
# 使用示例
@app.route('/api/process_document', methods=['POST'])
@require_auth
@require_permission('document:write')
def process_document():
# 处理请求
pass
这种访问控制的设计提高了系统的安全性,支持用户认证和授权。
性能优化¶
1. 优化数据库操作¶
当前系统的数据库操作相对简单,建议优化数据库操作:
# 数据库连接池
class DBConnectionPool:
def __init__(self, db_path, max_connections=10):
self.db_path = db_path
self.max_connections = max_connections
self.connections = Queue(maxsize=max_connections)
self.active_connections = 0
# 初始化连接池
for _ in range(max_connections):
self.connections.put(self._create_connection())
def _create_connection(self):
"""创建数据库连接"""
conn = sqlite3.connect(self.db_path)
conn.row_factory = sqlite3.Row
return conn
def get_connection(self):
"""获取数据库连接"""
if not self.connections.empty():
return self.connections.get()
if self.active_connections < self.max_connections:
self.active_connections += 1
return self._create_connection()
# 等待可用连接
return self.connections.get(block=True)
def release_connection(self, conn):
"""释放数据库连接"""
self.connections.put(conn)
def close_all(self):
"""关闭所有连接"""
while not self.connections.empty():
conn = self.connections.get()
conn.close()
# 数据库操作上下文管理器
class DBConnection:
def __init__(self, pool):
self.pool = pool
self.conn = None
def __enter__(self):
self.conn = self.pool.get_connection()
return self.conn
def __exit__(self, exc_type, exc_val, exc_tb):
self.pool.release_connection(self.conn)
# 使用示例
db_pool = DBConnectionPool(DATABASE_PATH)
def get_task_result(task_id):
with DBConnection(db_pool) as conn:
cursor = conn.cursor()
cursor.execute(
"SELECT status, result FROM tasks WHERE id = ?",
(task_id,)
)
return cursor.fetchone()
这种数据库连接池的设计提高了系统的性能,减少了数据库连接的开销。
2. 优化缓存策略¶
当前系统的缓存策略相对简单,建议优化缓存策略:
# 缓存管理器
class CacheManager:
def __init__(self, db_connection, max_size=1000, default_ttl=86400):
self.db_connection = db_connection
self.max_size = max_size
self.default_ttl = default_ttl
self.init_cache_table()
def init_cache_table(self):
"""初始化缓存表"""
cursor = self.db_connection.cursor()
cursor.execute('''
CREATE TABLE IF NOT EXISTS cache (
key TEXT PRIMARY KEY,
result TEXT,
access_count INTEGER DEFAULT 0,
last_access INTEGER,
expires_at INTEGER
)
''')
self.db_connection.commit()
def get(self, key):
"""获取缓存"""
cursor = self.db_connection.cursor()
cursor.execute(
"SELECT result, expires_at FROM cache WHERE key = ?",
(key,)
)
result = cursor.fetchone()
if not result:
return None
# 检查是否过期
if result['expires_at'] < int(time.time()):
self.delete(key)
return None
# 更新访问统计
cursor.execute(
"UPDATE cache SET access_count = access_count + 1, last_access = ? WHERE key = ?",
(int(time.time()), key)
)
self.db_connection.commit()
return json.loads(result['result'])
def set(self, key, value, ttl=None):
"""设置缓存"""
if ttl is None:
ttl = self.default_ttl
# 检查缓存大小
self._check_cache_size()
cursor = self.db_connection.cursor()
expires_at = int(time.time()) + ttl
cursor.execute(
"INSERT OR REPLACE INTO cache (key, result, access_count, last_access, expires_at) VALUES (?, ?, ?, ?, ?)",
(key, json.dumps(value), 0, int(time.time()), expires_at)
)
self.db_connection.commit()
def delete(self, key):
"""删除缓存"""
cursor = self.db_connection.cursor()
cursor.execute("DELETE FROM cache WHERE key = ?", (key,))
self.db_connection.commit()
def clear_expired(self):
"""清理过期缓存"""
cursor = self.db_connection.cursor()
cursor.execute(
"DELETE FROM cache WHERE expires_at < ?",
(int(time.time()),)
)
self.db_connection.commit()
def _check_cache_size(self):
"""检查缓存大小,如果超过最大大小,删除最少访问的缓存"""
cursor = self.db_connection.cursor()
cursor.execute("SELECT COUNT(*) as count FROM cache")
result = cursor.fetchone()
if result['count'] >= self.max_size:
# 删除最少访问的缓存
cursor.execute(
"DELETE FROM cache WHERE key IN (SELECT key FROM cache ORDER BY access_count, last_access LIMIT ?)",
(result['count'] - self.max_size + 10,) # 多删除一些,避免频繁检查
)
self.db_connection.commit()
# 使用示例
cache_manager = CacheManager(db_connection)
result = cache_manager.get(cache_key)
if not result:
# 处理请求
result = process_request()
cache_manager.set(cache_key, result)
return result
这种缓存管理的设计提高了系统的性能,支持基于访问频率的缓存策略和自动清理过期缓存。
总结¶
通过实施以上改进建议,可以显著提高 DocuSnap-Backend 代码库的质量、可维护性和可扩展性。这些改进包括:
- 代码结构改进:
- 模块化重构
-
引入层次结构
-
代码质量改进:
- 引入面向对象编程
- 引入依赖注入
-
添加单元测试
-
可扩展性改进:
- 引入抽象接口
-
引入插件机制
-
安全性改进:
- 改进密钥管理
-
添加访问控制
-
性能优化:
- 优化数据库操作
- 优化缓存策略
这些改进可以分阶段实施,先解决最紧急的问题,然后逐步实施其他改进。每次改进后,都应该进行充分的测试,确保系统的稳定性和可靠性。