Skip to content

改进建议

基于对 DocuSnap-Backend 代码库的质量评估,本页面提供具体的改进建议,帮助提升代码质量、可维护性和可扩展性。

代码结构改进

1. 模块化重构

当前代码主要集中在 app.py 文件中,建议将代码拆分为多个模块文件:

DocuSnap-Backend/
├── app.py                  # 主应用入口,只包含 Flask 应用配置和路由定义
├── modules/
│   ├── __init__.py
│   ├── ocr.py              # OCR 处理相关功能
│   ├── llm.py              # LLM 处理相关功能
│   ├── security.py         # 安全加密相关功能
│   ├── cache.py            # 缓存和数据持久化相关功能
│   └── tasks.py            # 任务处理相关功能
├── utils/
│   ├── __init__.py
│   ├── db.py               # 数据库操作工具
│   ├── image.py            # 图像处理工具
│   └── validation.py       # 请求验证工具
├── config/
│   ├── __init__.py
│   └── settings.py         # 配置参数
├── prompts/
│   ├── __init__.py
│   ├── document.py         # 文档处理提示
│   ├── form.py             # 表单处理提示
│   └── form_filling.py     # 表单填充提示
└── api/
    ├── __init__.py
    ├── document.py         # 文档处理 API
    ├── form.py             # 表单处理 API
    └── task.py             # 任务状态 API

这种结构将代码按功能模块划分,提高了代码的组织性和可维护性。

2. 引入层次结构

建议引入更清晰的层次结构,将代码分为以下几层:

  1. API 层:处理 HTTP 请求和响应
  2. 服务层:实现业务逻辑
  3. 数据访问层:处理数据存储和检索
  4. 外部服务层:与外部服务(OCR、LLM)交互

示例实现:

# API 层 (api/document.py)
@app.route('/api/process_document', methods=['POST'])
def process_document():
    # 请求解析和验证
    data = request.get_json()
    decrypted_data, aes_key = security_service.decrypt_request(data)

    # 调用服务层
    task_id = document_service.process_document(decrypted_data)

    # 响应生成
    response = security_service.encrypt_response({'task_id': task_id}, aes_key)
    return jsonify(response), 202

# 服务层 (modules/document_service.py)
def process_document(data):
    # 业务逻辑
    cache_key = generate_cache_key(data)
    cached_result = cache_repository.get_cache(cache_key)

    if cached_result:
        return cached_result

    task_id = str(uuid.uuid4())
    task = create_task(task_id, 'document', data, cache_key)
    task_queue.put(task)

    return task_id

# 数据访问层 (modules/cache_repository.py)
def get_cache(cache_key):
    conn = get_db_connection()
    cursor = conn.cursor()

    cursor.execute(
        "SELECT result FROM cache WHERE key = ? AND expires_at > ?",
        (cache_key, int(time.time()))
    )

    result = cursor.fetchone()
    conn.close()

    if result:
        return json.loads(result[0])
    else:
        return None

# 外部服务层 (modules/ocr_service.py)
def call_ocr_service(image_data):
    try:
        files = {'image': ('image.png', image_data, 'image/png')}
        response = requests.post(OCR_SERVICE_URL, files=files, timeout=OCR_TIMEOUT)

        if response.status_code == 200:
            return response.json()
        else:
            raise Exception(f"OCR 服务返回错误: {response.status_code}")
    except requests.RequestException as e:
        raise Exception(f"OCR 服务调用失败: {str(e)}")

这种层次结构提高了代码的组织性和可维护性,同时也便于单元测试和功能扩展。

代码质量改进

1. 引入面向对象编程

当前代码主要使用函数式编程风格,建议引入更多的面向对象编程,使用类和对象组织代码:

# 任务处理器类
class TaskProcessor:
    def __init__(self, db_connection, ocr_service, llm_service):
        self.db_connection = db_connection
        self.ocr_service = ocr_service
        self.llm_service = llm_service
        self.task_queue = Queue()
        self.workers = []

    def start_workers(self, num_workers=4):
        for _ in range(num_workers):
            worker = threading.Thread(target=self._worker_loop, daemon=True)
            worker.start()
            self.workers.append(worker)

    def add_task(self, task):
        self.task_queue.put(task)
        return task['id']

    def _worker_loop(self):
        while True:
            task = self.task_queue.get()
            try:
                self._process_task(task)
            except Exception as e:
                self._update_task_status(task['id'], 'error', str(e))
            finally:
                self.task_queue.task_done()

    def _process_task(self, task):
        # 根据任务类型选择处理策略
        if task['type'] == 'document':
            self._process_document_task(task)
        elif task['type'] == 'form':
            self._process_form_task(task)
        elif task['type'] == 'form_filling':
            self._process_form_filling_task(task)

    def _process_document_task(self, task):
        # 文档处理逻辑
        pass

    def _update_task_status(self, task_id, status, result=None):
        # 更新任务状态
        pass

这种面向对象的设计提高了代码的组织性和可维护性,同时也便于依赖注入和单元测试。

2. 引入依赖注入

当前代码中的依赖关系是硬编码的,建议引入依赖注入,提高代码的灵活性和可测试性:

# 依赖注入容器
class Container:
    def __init__(self):
        self.services = {}

    def register(self, name, instance):
        self.services[name] = instance

    def get(self, name):
        if name not in self.services:
            raise Exception(f"Service '{name}' not registered")
        return self.services[name]

# 应用初始化
def init_app():
    container = Container()

    # 注册服务
    db_connection = get_db_connection()
    container.register('db_connection', db_connection)

    ocr_service = OCRService(OCR_SERVICE_URL, OCR_TIMEOUT)
    container.register('ocr_service', ocr_service)

    llm_service = LLMService(LLM_API_KEY, LLM_MODEL)
    container.register('llm_service', llm_service)

    task_processor = TaskProcessor(
        db_connection=container.get('db_connection'),
        ocr_service=container.get('ocr_service'),
        llm_service=container.get('llm_service')
    )
    container.register('task_processor', task_processor)

    return container

# 使用容器
container = init_app()
task_processor = container.get('task_processor')
task_processor.start_workers()

这种依赖注入的设计提高了代码的灵活性和可测试性,同时也便于替换底层实现。

3. 添加单元测试

当前代码缺乏单元测试,建议添加单元测试,提高代码质量和可靠性:

# tests/test_ocr_service.py
import unittest
from unittest.mock import patch, Mock
from modules.ocr_service import OCRService

class TestOCRService(unittest.TestCase):
    def setUp(self):
        self.ocr_service = OCRService('http://example.com/ocr', 10)

    @patch('modules.ocr_service.requests.post')
    def test_call_ocr_service_success(self, mock_post):
        # 模拟成功响应
        mock_response = Mock()
        mock_response.status_code = 200
        mock_response.json.return_value = {'text': 'Hello World'}
        mock_post.return_value = mock_response

        # 调用服务
        result = self.ocr_service.call_ocr_service(b'image_data')

        # 验证结果
        self.assertEqual(result, {'text': 'Hello World'})
        mock_post.assert_called_once()

    @patch('modules.ocr_service.requests.post')
    def test_call_ocr_service_error(self, mock_post):
        # 模拟错误响应
        mock_response = Mock()
        mock_response.status_code = 500
        mock_post.return_value = mock_response

        # 验证异常
        with self.assertRaises(Exception) as context:
            self.ocr_service.call_ocr_service(b'image_data')

        self.assertIn('OCR 服务返回错误: 500', str(context.exception))
        mock_post.assert_called_once()

# 运行测试
if __name__ == '__main__':
    unittest.main()

单元测试可以验证代码的正确性,发现潜在的问题,并支持重构和优化。

可扩展性改进

1. 引入抽象接口

当前代码直接使用具体实现,建议引入抽象接口,提高系统的灵活性和可扩展性:

# 抽象接口
class OCRService:
    def process_image(self, image_data):
        """处理图像并返回文本结果"""
        raise NotImplementedError("子类必须实现此方法")

# 具体实现
class CnOCRService(OCRService):
    def __init__(self, service_url, timeout):
        self.service_url = service_url
        self.timeout = timeout

    def process_image(self, image_data):
        """使用 CnOCR 服务处理图像"""
        try:
            files = {'image': ('image.png', image_data, 'image/png')}
            response = requests.post(self.service_url, files=files, timeout=self.timeout)

            if response.status_code == 200:
                ocr_result = response.json()
                return self._extract_text(ocr_result)
            else:
                raise Exception(f"OCR 服务返回错误: {response.status_code}")
        except requests.RequestException as e:
            raise Exception(f"OCR 服务调用失败: {str(e)}")

    def _extract_text(self, ocr_result):
        """从 OCR 结果中提取文本"""
        if 'text' not in ocr_result:
            raise Exception("OCR 结果格式错误")

        return ocr_result['text'].strip()

# 使用示例
ocr_service = CnOCRService(OCR_SERVICE_URL, OCR_TIMEOUT)
text = ocr_service.process_image(image_data)

这种抽象接口的设计提高了系统的灵活性和可扩展性,可以轻松替换底层实现,如更换 OCR 服务提供商。

2. 引入插件机制

当前系统缺乏插件机制,建议引入插件机制,支持动态扩展功能:

# 插件管理器
class PluginManager:
    def __init__(self):
        self.plugins = {}

    def register_plugin(self, name, plugin):
        """注册插件"""
        self.plugins[name] = plugin

    def get_plugin(self, name):
        """获取插件"""
        if name not in self.plugins:
            raise Exception(f"Plugin '{name}' not registered")
        return self.plugins[name]

    def list_plugins(self):
        """列出所有插件"""
        return list(self.plugins.keys())

# 插件接口
class ProcessorPlugin:
    def process(self, task):
        """处理任务"""
        raise NotImplementedError("子类必须实现此方法")

# 具体插件实现
class DocumentProcessor(ProcessorPlugin):
    def __init__(self, ocr_service, llm_service):
        self.ocr_service = ocr_service
        self.llm_service = llm_service

    def process(self, task):
        """处理文档任务"""
        # 文档处理逻辑
        pass

# 使用示例
plugin_manager = PluginManager()
plugin_manager.register_plugin('document', DocumentProcessor(ocr_service, llm_service))
plugin_manager.register_plugin('form', FormProcessor(ocr_service, llm_service))

# 处理任务
processor = plugin_manager.get_plugin(task['type'])
result = processor.process(task)

这种插件机制的设计提高了系统的可扩展性,可以轻松添加新的功能模块,如新的任务类型。

安全性改进

1. 改进密钥管理

当前系统的密钥管理相对简单,建议改进密钥管理机制:

# 密钥管理器
class KeyManager:
    def __init__(self, key_store_path):
        self.key_store_path = key_store_path
        self.keys = {}
        self.load_keys()

    def load_keys(self):
        """加载密钥"""
        try:
            with open(self.key_store_path, 'rb') as f:
                encrypted_keys = f.read()

            # 使用主密钥解密密钥存储
            master_key = os.environ.get('MASTER_KEY')
            if not master_key:
                raise Exception("Missing MASTER_KEY environment variable")

            decrypted_keys = self._decrypt_with_master_key(encrypted_keys, master_key)
            self.keys = json.loads(decrypted_keys)
        except Exception as e:
            raise Exception(f"Failed to load keys: {str(e)}")

    def get_private_key(self):
        """获取 RSA 私钥"""
        if 'private_key' not in self.keys:
            raise Exception("Private key not found")

        private_key_data = self.keys['private_key']
        return RSA.import_key(private_key_data)

    def rotate_keys(self):
        """轮换密钥"""
        # 生成新的密钥对
        key = RSA.generate(2048)
        private_key = key.export_key()
        public_key = key.publickey().export_key()

        # 更新密钥存储
        self.keys['private_key'] = private_key.decode('utf-8')
        self.keys['public_key'] = public_key.decode('utf-8')

        # 保存密钥
        self._save_keys()

        return {
            'private_key': private_key,
            'public_key': public_key
        }

    def _save_keys(self):
        """保存密钥"""
        # 使用主密钥加密密钥存储
        master_key = os.environ.get('MASTER_KEY')
        if not master_key:
            raise Exception("Missing MASTER_KEY environment variable")

        keys_json = json.dumps(self.keys)
        encrypted_keys = self._encrypt_with_master_key(keys_json, master_key)

        with open(self.key_store_path, 'wb') as f:
            f.write(encrypted_keys)

    def _encrypt_with_master_key(self, data, master_key):
        """使用主密钥加密数据"""
        # 加密实现
        pass

    def _decrypt_with_master_key(self, encrypted_data, master_key):
        """使用主密钥解密数据"""
        # 解密实现
        pass

这种密钥管理的设计提高了系统的安全性,支持密钥轮换和安全存储。

2. 添加访问控制

当前系统缺乏细粒度的访问控制,建议添加访问控制机制:

# 访问控制装饰器
def require_auth(func):
    @wraps(func)
    def wrapper(*args, **kwargs):
        # 获取认证信息
        auth_header = request.headers.get('Authorization')
        if not auth_header:
            return jsonify({"error": "Missing Authorization header"}), 401

        try:
            # 验证认证信息
            token = auth_header.split(' ')[1]
            payload = jwt.decode(token, JWT_SECRET_KEY, algorithms=['HS256'])

            # 检查权限
            if 'permissions' not in payload:
                return jsonify({"error": "Invalid token"}), 401

            # 获取所需权限
            required_permission = getattr(func, '_required_permission', None)
            if required_permission and required_permission not in payload['permissions']:
                return jsonify({"error": "Insufficient permissions"}), 403

            # 将用户信息添加到请求上下文
            g.user = payload

            return func(*args, **kwargs)
        except jwt.InvalidTokenError:
            return jsonify({"error": "Invalid token"}), 401

    return wrapper

# 权限装饰器
def require_permission(permission):
    def decorator(func):
        func._required_permission = permission
        return func
    return decorator

# 使用示例
@app.route('/api/process_document', methods=['POST'])
@require_auth
@require_permission('document:write')
def process_document():
    # 处理请求
    pass

这种访问控制的设计提高了系统的安全性,支持用户认证和授权。

性能优化

1. 优化数据库操作

当前系统的数据库操作相对简单,建议优化数据库操作:

# 数据库连接池
class DBConnectionPool:
    def __init__(self, db_path, max_connections=10):
        self.db_path = db_path
        self.max_connections = max_connections
        self.connections = Queue(maxsize=max_connections)
        self.active_connections = 0

        # 初始化连接池
        for _ in range(max_connections):
            self.connections.put(self._create_connection())

    def _create_connection(self):
        """创建数据库连接"""
        conn = sqlite3.connect(self.db_path)
        conn.row_factory = sqlite3.Row
        return conn

    def get_connection(self):
        """获取数据库连接"""
        if not self.connections.empty():
            return self.connections.get()

        if self.active_connections < self.max_connections:
            self.active_connections += 1
            return self._create_connection()

        # 等待可用连接
        return self.connections.get(block=True)

    def release_connection(self, conn):
        """释放数据库连接"""
        self.connections.put(conn)

    def close_all(self):
        """关闭所有连接"""
        while not self.connections.empty():
            conn = self.connections.get()
            conn.close()

# 数据库操作上下文管理器
class DBConnection:
    def __init__(self, pool):
        self.pool = pool
        self.conn = None

    def __enter__(self):
        self.conn = self.pool.get_connection()
        return self.conn

    def __exit__(self, exc_type, exc_val, exc_tb):
        self.pool.release_connection(self.conn)

# 使用示例
db_pool = DBConnectionPool(DATABASE_PATH)

def get_task_result(task_id):
    with DBConnection(db_pool) as conn:
        cursor = conn.cursor()
        cursor.execute(
            "SELECT status, result FROM tasks WHERE id = ?",
            (task_id,)
        )
        return cursor.fetchone()

这种数据库连接池的设计提高了系统的性能,减少了数据库连接的开销。

2. 优化缓存策略

当前系统的缓存策略相对简单,建议优化缓存策略:

# 缓存管理器
class CacheManager:
    def __init__(self, db_connection, max_size=1000, default_ttl=86400):
        self.db_connection = db_connection
        self.max_size = max_size
        self.default_ttl = default_ttl
        self.init_cache_table()

    def init_cache_table(self):
        """初始化缓存表"""
        cursor = self.db_connection.cursor()
        cursor.execute('''
        CREATE TABLE IF NOT EXISTS cache (
            key TEXT PRIMARY KEY,
            result TEXT,
            access_count INTEGER DEFAULT 0,
            last_access INTEGER,
            expires_at INTEGER
        )
        ''')
        self.db_connection.commit()

    def get(self, key):
        """获取缓存"""
        cursor = self.db_connection.cursor()
        cursor.execute(
            "SELECT result, expires_at FROM cache WHERE key = ?",
            (key,)
        )

        result = cursor.fetchone()
        if not result:
            return None

        # 检查是否过期
        if result['expires_at'] < int(time.time()):
            self.delete(key)
            return None

        # 更新访问统计
        cursor.execute(
            "UPDATE cache SET access_count = access_count + 1, last_access = ? WHERE key = ?",
            (int(time.time()), key)
        )
        self.db_connection.commit()

        return json.loads(result['result'])

    def set(self, key, value, ttl=None):
        """设置缓存"""
        if ttl is None:
            ttl = self.default_ttl

        # 检查缓存大小
        self._check_cache_size()

        cursor = self.db_connection.cursor()
        expires_at = int(time.time()) + ttl

        cursor.execute(
            "INSERT OR REPLACE INTO cache (key, result, access_count, last_access, expires_at) VALUES (?, ?, ?, ?, ?)",
            (key, json.dumps(value), 0, int(time.time()), expires_at)
        )

        self.db_connection.commit()

    def delete(self, key):
        """删除缓存"""
        cursor = self.db_connection.cursor()
        cursor.execute("DELETE FROM cache WHERE key = ?", (key,))
        self.db_connection.commit()

    def clear_expired(self):
        """清理过期缓存"""
        cursor = self.db_connection.cursor()
        cursor.execute(
            "DELETE FROM cache WHERE expires_at < ?",
            (int(time.time()),)
        )
        self.db_connection.commit()

    def _check_cache_size(self):
        """检查缓存大小,如果超过最大大小,删除最少访问的缓存"""
        cursor = self.db_connection.cursor()
        cursor.execute("SELECT COUNT(*) as count FROM cache")
        result = cursor.fetchone()

        if result['count'] >= self.max_size:
            # 删除最少访问的缓存
            cursor.execute(
                "DELETE FROM cache WHERE key IN (SELECT key FROM cache ORDER BY access_count, last_access LIMIT ?)",
                (result['count'] - self.max_size + 10,)  # 多删除一些,避免频繁检查
            )
            self.db_connection.commit()

# 使用示例
cache_manager = CacheManager(db_connection)
result = cache_manager.get(cache_key)

if not result:
    # 处理请求
    result = process_request()
    cache_manager.set(cache_key, result)

return result

这种缓存管理的设计提高了系统的性能,支持基于访问频率的缓存策略和自动清理过期缓存。

总结

通过实施以上改进建议,可以显著提高 DocuSnap-Backend 代码库的质量、可维护性和可扩展性。这些改进包括:

  1. 代码结构改进
  2. 模块化重构
  3. 引入层次结构

  4. 代码质量改进

  5. 引入面向对象编程
  6. 引入依赖注入
  7. 添加单元测试

  8. 可扩展性改进

  9. 引入抽象接口
  10. 引入插件机制

  11. 安全性改进

  12. 改进密钥管理
  13. 添加访问控制

  14. 性能优化

  15. 优化数据库操作
  16. 优化缓存策略

这些改进可以分阶段实施,先解决最紧急的问题,然后逐步实施其他改进。每次改进后,都应该进行充分的测试,确保系统的稳定性和可靠性。