DocuSnap

Your AI-powered Personal Document Assistant

View project on GitHub

改进建议

基于对 DocuSnap-Backend 代码库的质量评估,本页面提供具体的改进建议,帮助提升代码质量、可维护性和可扩展性。

代码结构改进

1. 模块化重构

当前代码主要集中在 app.py 文件中,建议将代码拆分为多个模块文件:

DocuSnap-Backend/
├── app.py                  # 主应用入口,只包含 Flask 应用配置和路由定义
├── modules/
│   ├── __init__.py
│   ├── ocr.py              # OCR 处理相关功能
│   ├── llm.py              # LLM 处理相关功能
│   ├── security.py         # 安全加密相关功能
│   ├── cache.py            # 缓存和数据持久化相关功能
│   └── tasks.py            # 任务处理相关功能
├── utils/
│   ├── __init__.py
│   ├── db.py               # 数据库操作工具
│   ├── image.py            # 图像处理工具
│   └── validation.py       # 请求验证工具
├── config/
│   ├── __init__.py
│   └── settings.py         # 配置参数
├── prompts/
│   ├── __init__.py
│   ├── document.py         # 文档处理提示
│   ├── form.py             # 表单处理提示
│   └── form_filling.py     # 表单填充提示
└── api/
    ├── __init__.py
    ├── document.py         # 文档处理 API
    ├── form.py             # 表单处理 API
    └── task.py             # 任务状态 API

这种结构将代码按功能模块划分,提高了代码的组织性和可维护性。

2. 引入层次结构

建议引入更清晰的层次结构,将代码分为以下几层:

  1. API 层:处理 HTTP 请求和响应
  2. 服务层:实现业务逻辑
  3. 数据访问层:处理数据存储和检索
  4. 外部服务层:与外部服务(OCR、LLM)交互

示例实现:

# API 层 (api/document.py)
@app.route('/api/process_document', methods=['POST'])
def process_document():
    # 请求解析和验证
    data = request.get_json()
    decrypted_data, aes_key = security_service.decrypt_request(data)
    
    # 调用服务层
    task_id = document_service.process_document(decrypted_data)
    
    # 响应生成
    response = security_service.encrypt_response({'task_id': task_id}, aes_key)
    return jsonify(response), 202

# 服务层 (modules/document_service.py)
def process_document(data):
    # 业务逻辑
    cache_key = generate_cache_key(data)
    cached_result = cache_repository.get_cache(cache_key)
    
    if cached_result:
        return cached_result
    
    task_id = str(uuid.uuid4())
    task = create_task(task_id, 'document', data, cache_key)
    task_queue.put(task)
    
    return task_id

# 数据访问层 (modules/cache_repository.py)
def get_cache(cache_key):
    conn = get_db_connection()
    cursor = conn.cursor()
    
    cursor.execute(
        "SELECT result FROM cache WHERE key = ? AND expires_at > ?",
        (cache_key, int(time.time()))
    )
    
    result = cursor.fetchone()
    conn.close()
    
    if result:
        return json.loads(result[0])
    else:
        return None

# 外部服务层 (modules/ocr_service.py)
def call_ocr_service(image_data):
    try:
        files = {'image': ('image.png', image_data, 'image/png')}
        response = requests.post(OCR_SERVICE_URL, files=files, timeout=OCR_TIMEOUT)
        
        if response.status_code == 200:
            return response.json()
        else:
            raise Exception(f"OCR 服务返回错误: {response.status_code}")
    except requests.RequestException as e:
        raise Exception(f"OCR 服务调用失败: {str(e)}")

这种层次结构提高了代码的组织性和可维护性,同时也便于单元测试和功能扩展。

代码质量改进

1. 引入面向对象编程

当前代码主要使用函数式编程风格,建议引入更多的面向对象编程,使用类和对象组织代码:

# 任务处理器类
class TaskProcessor:
    def __init__(self, db_connection, ocr_service, llm_service):
        self.db_connection = db_connection
        self.ocr_service = ocr_service
        self.llm_service = llm_service
        self.task_queue = Queue()
        self.workers = []
    
    def start_workers(self, num_workers=4):
        for _ in range(num_workers):
            worker = threading.Thread(target=self._worker_loop, daemon=True)
            worker.start()
            self.workers.append(worker)
    
    def add_task(self, task):
        self.task_queue.put(task)
        return task['id']
    
    def _worker_loop(self):
        while True:
            task = self.task_queue.get()
            try:
                self._process_task(task)
            except Exception as e:
                self._update_task_status(task['id'], 'error', str(e))
            finally:
                self.task_queue.task_done()
    
    def _process_task(self, task):
        # 根据任务类型选择处理策略
        if task['type'] == 'document':
            self._process_document_task(task)
        elif task['type'] == 'form':
            self._process_form_task(task)
        elif task['type'] == 'form_filling':
            self._process_form_filling_task(task)
    
    def _process_document_task(self, task):
        # 文档处理逻辑
        pass
    
    def _update_task_status(self, task_id, status, result=None):
        # 更新任务状态
        pass

这种面向对象的设计提高了代码的组织性和可维护性,同时也便于依赖注入和单元测试。

2. 引入依赖注入

当前代码中的依赖关系是硬编码的,建议引入依赖注入,提高代码的灵活性和可测试性:

# 依赖注入容器
class Container:
    def __init__(self):
        self.services = {}
    
    def register(self, name, instance):
        self.services[name] = instance
    
    def get(self, name):
        if name not in self.services:
            raise Exception(f"Service '{name}' not registered")
        return self.services[name]

# 应用初始化
def init_app():
    container = Container()
    
    # 注册服务
    db_connection = get_db_connection()
    container.register('db_connection', db_connection)
    
    ocr_service = OCRService(OCR_SERVICE_URL, OCR_TIMEOUT)
    container.register('ocr_service', ocr_service)
    
    llm_service = LLMService(LLM_API_KEY, LLM_MODEL)
    container.register('llm_service', llm_service)
    
    task_processor = TaskProcessor(
        db_connection=container.get('db_connection'),
        ocr_service=container.get('ocr_service'),
        llm_service=container.get('llm_service')
    )
    container.register('task_processor', task_processor)
    
    return container

# 使用容器
container = init_app()
task_processor = container.get('task_processor')
task_processor.start_workers()

这种依赖注入的设计提高了代码的灵活性和可测试性,同时也便于替换底层实现。

3. 添加单元测试

当前代码缺乏单元测试,建议添加单元测试,提高代码质量和可靠性:

# tests/test_ocr_service.py
import unittest
from unittest.mock import patch, Mock
from modules.ocr_service import OCRService

class TestOCRService(unittest.TestCase):
    def setUp(self):
        self.ocr_service = OCRService('http://example.com/ocr', 10)
    
    @patch('modules.ocr_service.requests.post')
    def test_call_ocr_service_success(self, mock_post):
        # 模拟成功响应
        mock_response = Mock()
        mock_response.status_code = 200
        mock_response.json.return_value = {'text': 'Hello World'}
        mock_post.return_value = mock_response
        
        # 调用服务
        result = self.ocr_service.call_ocr_service(b'image_data')
        
        # 验证结果
        self.assertEqual(result, {'text': 'Hello World'})
        mock_post.assert_called_once()
    
    @patch('modules.ocr_service.requests.post')
    def test_call_ocr_service_error(self, mock_post):
        # 模拟错误响应
        mock_response = Mock()
        mock_response.status_code = 500
        mock_post.return_value = mock_response
        
        # 验证异常
        with self.assertRaises(Exception) as context:
            self.ocr_service.call_ocr_service(b'image_data')
        
        self.assertIn('OCR 服务返回错误: 500', str(context.exception))
        mock_post.assert_called_once()

# 运行测试
if __name__ == '__main__':
    unittest.main()

单元测试可以验证代码的正确性,发现潜在的问题,并支持重构和优化。

可扩展性改进

1. 引入抽象接口

当前代码直接使用具体实现,建议引入抽象接口,提高系统的灵活性和可扩展性:

# 抽象接口
class OCRService:
    def process_image(self, image_data):
        """处理图像并返回文本结果"""
        raise NotImplementedError("子类必须实现此方法")

# 具体实现
class CnOCRService(OCRService):
    def __init__(self, service_url, timeout):
        self.service_url = service_url
        self.timeout = timeout
    
    def process_image(self, image_data):
        """使用 CnOCR 服务处理图像"""
        try:
            files = {'image': ('image.png', image_data, 'image/png')}
            response = requests.post(self.service_url, files=files, timeout=self.timeout)
            
            if response.status_code == 200:
                ocr_result = response.json()
                return self._extract_text(ocr_result)
            else:
                raise Exception(f"OCR 服务返回错误: {response.status_code}")
        except requests.RequestException as e:
            raise Exception(f"OCR 服务调用失败: {str(e)}")
    
    def _extract_text(self, ocr_result):
        """从 OCR 结果中提取文本"""
        if 'text' not in ocr_result:
            raise Exception("OCR 结果格式错误")
        
        return ocr_result['text'].strip()

# 使用示例
ocr_service = CnOCRService(OCR_SERVICE_URL, OCR_TIMEOUT)
text = ocr_service.process_image(image_data)

这种抽象接口的设计提高了系统的灵活性和可扩展性,可以轻松替换底层实现,如更换 OCR 服务提供商。

2. 引入插件机制

当前系统缺乏插件机制,建议引入插件机制,支持动态扩展功能:

# 插件管理器
class PluginManager:
    def __init__(self):
        self.plugins = {}
    
    def register_plugin(self, name, plugin):
        """注册插件"""
        self.plugins[name] = plugin
    
    def get_plugin(self, name):
        """获取插件"""
        if name not in self.plugins:
            raise Exception(f"Plugin '{name}' not registered")
        return self.plugins[name]
    
    def list_plugins(self):
        """列出所有插件"""
        return list(self.plugins.keys())

# 插件接口
class ProcessorPlugin:
    def process(self, task):
        """处理任务"""
        raise NotImplementedError("子类必须实现此方法")

# 具体插件实现
class DocumentProcessor(ProcessorPlugin):
    def __init__(self, ocr_service, llm_service):
        self.ocr_service = ocr_service
        self.llm_service = llm_service
    
    def process(self, task):
        """处理文档任务"""
        # 文档处理逻辑
        pass

# 使用示例
plugin_manager = PluginManager()
plugin_manager.register_plugin('document', DocumentProcessor(ocr_service, llm_service))
plugin_manager.register_plugin('form', FormProcessor(ocr_service, llm_service))

# 处理任务
processor = plugin_manager.get_plugin(task['type'])
result = processor.process(task)

这种插件机制的设计提高了系统的可扩展性,可以轻松添加新的功能模块,如新的任务类型。

安全性改进

1. 改进密钥管理

当前系统的密钥管理相对简单,建议改进密钥管理机制:

# 密钥管理器
class KeyManager:
    def __init__(self, key_store_path):
        self.key_store_path = key_store_path
        self.keys = {}
        self.load_keys()
    
    def load_keys(self):
        """加载密钥"""
        try:
            with open(self.key_store_path, 'rb') as f:
                encrypted_keys = f.read()
            
            # 使用主密钥解密密钥存储
            master_key = os.environ.get('MASTER_KEY')
            if not master_key:
                raise Exception("Missing MASTER_KEY environment variable")
            
            decrypted_keys = self._decrypt_with_master_key(encrypted_keys, master_key)
            self.keys = json.loads(decrypted_keys)
        except Exception as e:
            raise Exception(f"Failed to load keys: {str(e)}")
    
    def get_private_key(self):
        """获取 RSA 私钥"""
        if 'private_key' not in self.keys:
            raise Exception("Private key not found")
        
        private_key_data = self.keys['private_key']
        return RSA.import_key(private_key_data)
    
    def rotate_keys(self):
        """轮换密钥"""
        # 生成新的密钥对
        key = RSA.generate(2048)
        private_key = key.export_key()
        public_key = key.publickey().export_key()
        
        # 更新密钥存储
        self.keys['private_key'] = private_key.decode('utf-8')
        self.keys['public_key'] = public_key.decode('utf-8')
        
        # 保存密钥
        self._save_keys()
        
        return {
            'private_key': private_key,
            'public_key': public_key
        }
    
    def _save_keys(self):
        """保存密钥"""
        # 使用主密钥加密密钥存储
        master_key = os.environ.get('MASTER_KEY')
        if not master_key:
            raise Exception("Missing MASTER_KEY environment variable")
        
        keys_json = json.dumps(self.keys)
        encrypted_keys = self._encrypt_with_master_key(keys_json, master_key)
        
        with open(self.key_store_path, 'wb') as f:
            f.write(encrypted_keys)
    
    def _encrypt_with_master_key(self, data, master_key):
        """使用主密钥加密数据"""
        # 加密实现
        pass
    
    def _decrypt_with_master_key(self, encrypted_data, master_key):
        """使用主密钥解密数据"""
        # 解密实现
        pass

这种密钥管理的设计提高了系统的安全性,支持密钥轮换和安全存储。

2. 添加访问控制

当前系统缺乏细粒度的访问控制,建议添加访问控制机制:

# 访问控制装饰器
def require_auth(func):
    @wraps(func)
    def wrapper(*args, **kwargs):
        # 获取认证信息
        auth_header = request.headers.get('Authorization')
        if not auth_header:
            return jsonify({"error": "Missing Authorization header"}), 401
        
        try:
            # 验证认证信息
            token = auth_header.split(' ')[1]
            payload = jwt.decode(token, JWT_SECRET_KEY, algorithms=['HS256'])
            
            # 检查权限
            if 'permissions' not in payload:
                return jsonify({"error": "Invalid token"}), 401
            
            # 获取所需权限
            required_permission = getattr(func, '_required_permission', None)
            if required_permission and required_permission not in payload['permissions']:
                return jsonify({"error": "Insufficient permissions"}), 403
            
            # 将用户信息添加到请求上下文
            g.user = payload
            
            return func(*args, **kwargs)
        except jwt.InvalidTokenError:
            return jsonify({"error": "Invalid token"}), 401
    
    return wrapper

# 权限装饰器
def require_permission(permission):
    def decorator(func):
        func._required_permission = permission
        return func
    return decorator

# 使用示例
@app.route('/api/process_document', methods=['POST'])
@require_auth
@require_permission('document:write')
def process_document():
    # 处理请求
    pass

这种访问控制的设计提高了系统的安全性,支持用户认证和授权。

性能优化

1. 优化数据库操作

当前系统的数据库操作相对简单,建议优化数据库操作:

# 数据库连接池
class DBConnectionPool:
    def __init__(self, db_path, max_connections=10):
        self.db_path = db_path
        self.max_connections = max_connections
        self.connections = Queue(maxsize=max_connections)
        self.active_connections = 0
        
        # 初始化连接池
        for _ in range(max_connections):
            self.connections.put(self._create_connection())
    
    def _create_connection(self):
        """创建数据库连接"""
        conn = sqlite3.connect(self.db_path)
        conn.row_factory = sqlite3.Row
        return conn
    
    def get_connection(self):
        """获取数据库连接"""
        if not self.connections.empty():
            return self.connections.get()
        
        if self.active_connections < self.max_connections:
            self.active_connections += 1
            return self._create_connection()
        
        # 等待可用连接
        return self.connections.get(block=True)
    
    def release_connection(self, conn):
        """释放数据库连接"""
        self.connections.put(conn)
    
    def close_all(self):
        """关闭所有连接"""
        while not self.connections.empty():
            conn = self.connections.get()
            conn.close()

# 数据库操作上下文管理器
class DBConnection:
    def __init__(self, pool):
        self.pool = pool
        self.conn = None
    
    def __enter__(self):
        self.conn = self.pool.get_connection()
        return self.conn
    
    def __exit__(self, exc_type, exc_val, exc_tb):
        self.pool.release_connection(self.conn)

# 使用示例
db_pool = DBConnectionPool(DATABASE_PATH)

def get_task_result(task_id):
    with DBConnection(db_pool) as conn:
        cursor = conn.cursor()
        cursor.execute(
            "SELECT status, result FROM tasks WHERE id = ?",
            (task_id,)
        )
        return cursor.fetchone()

这种数据库连接池的设计提高了系统的性能,减少了数据库连接的开销。

2. 优化缓存策略

当前系统的缓存策略相对简单,建议优化缓存策略:

# 缓存管理器
class CacheManager:
    def __init__(self, db_connection, max_size=1000, default_ttl=86400):
        self.db_connection = db_connection
        self.max_size = max_size
        self.default_ttl = default_ttl
        self.init_cache_table()
    
    def init_cache_table(self):
        """初始化缓存表"""
        cursor = self.db_connection.cursor()
        cursor.execute('''
        CREATE TABLE IF NOT EXISTS cache (
            key TEXT PRIMARY KEY,
            result TEXT,
            access_count INTEGER DEFAULT 0,
            last_access INTEGER,
            expires_at INTEGER
        )
        ''')
        self.db_connection.commit()
    
    def get(self, key):
        """获取缓存"""
        cursor = self.db_connection.cursor()
        cursor.execute(
            "SELECT result, expires_at FROM cache WHERE key = ?",
            (key,)
        )
        
        result = cursor.fetchone()
        if not result:
            return None
        
        # 检查是否过期
        if result['expires_at'] < int(time.time()):
            self.delete(key)
            return None
        
        # 更新访问统计
        cursor.execute(
            "UPDATE cache SET access_count = access_count + 1, last_access = ? WHERE key = ?",
            (int(time.time()), key)
        )
        self.db_connection.commit()
        
        return json.loads(result['result'])
    
    def set(self, key, value, ttl=None):
        """设置缓存"""
        if ttl is None:
            ttl = self.default_ttl
        
        # 检查缓存大小
        self._check_cache_size()
        
        cursor = self.db_connection.cursor()
        expires_at = int(time.time()) + ttl
        
        cursor.execute(
            "INSERT OR REPLACE INTO cache (key, result, access_count, last_access, expires_at) VALUES (?, ?, ?, ?, ?)",
            (key, json.dumps(value), 0, int(time.time()), expires_at)
        )
        
        self.db_connection.commit()
    
    def delete(self, key):
        """删除缓存"""
        cursor = self.db_connection.cursor()
        cursor.execute("DELETE FROM cache WHERE key = ?", (key,))
        self.db_connection.commit()
    
    def clear_expired(self):
        """清理过期缓存"""
        cursor = self.db_connection.cursor()
        cursor.execute(
            "DELETE FROM cache WHERE expires_at < ?",
            (int(time.time()),)
        )
        self.db_connection.commit()
    
    def _check_cache_size(self):
        """检查缓存大小,如果超过最大大小,删除最少访问的缓存"""
        cursor = self.db_connection.cursor()
        cursor.execute("SELECT COUNT(*) as count FROM cache")
        result = cursor.fetchone()
        
        if result['count'] >= self.max_size:
            # 删除最少访问的缓存
            cursor.execute(
                "DELETE FROM cache WHERE key IN (SELECT key FROM cache ORDER BY access_count, last_access LIMIT ?)",
                (result['count'] - self.max_size + 10,)  # 多删除一些,避免频繁检查
            )
            self.db_connection.commit()

# 使用示例
cache_manager = CacheManager(db_connection)
result = cache_manager.get(cache_key)

if not result:
    # 处理请求
    result = process_request()
    cache_manager.set(cache_key, result)

return result

这种缓存管理的设计提高了系统的性能,支持基于访问频率的缓存策略和自动清理过期缓存。

总结

通过实施以上改进建议,可以显著提高 DocuSnap-Backend 代码库的质量、可维护性和可扩展性。这些改进包括:

  1. 代码结构改进
    • 模块化重构
    • 引入层次结构
  2. 代码质量改进
    • 引入面向对象编程
    • 引入依赖注入
    • 添加单元测试
  3. 可扩展性改进
    • 引入抽象接口
    • 引入插件机制
  4. 安全性改进
    • 改进密钥管理
    • 添加访问控制
  5. 性能优化
    • 优化数据库操作
    • 优化缓存策略

这些改进可以分阶段实施,先解决最紧急的问题,然后逐步实施其他改进。每次改进后,都应该进行充分的测试,确保系统的稳定性和可靠性。