Keyboard shortcuts

Press or to navigate between chapters

Press S or / to search in the book

Press ? to show this help

Press Esc to hide this help

第六章:核心组件实现

本章深入各核心组件的具体实现细节。

评估器实现

Semantic Similarity 评估器

基于 Embedding 的语义相似度评估:

import numpy as np
from openai import OpenAI

class SemanticSimilarityEvaluator:
    """
    语义相似度评估器
    使用Embedding计算文本语义相似度
    """
    
    def __init__(self, model: str = "text-embedding-3-small"):
        self.client = OpenAI()
        self.model = model
    
    def get_embedding(self, text: str) -> np.ndarray:
        """获取文本Embedding"""
        response = self.client.embeddings.create(
            input=text,
            model=self.model
        )
        return np.array(response.data[0].embedding)
    
    def cosine_similarity(self, vec1: np.ndarray, vec2: np.ndarray) -> float:
        """计算余弦相似度"""
        dot_product = np.dot(vec1, vec2)
        norm = np.linalg.norm(vec1) * np.linalg.norm(vec2)
        return dot_product / norm
    
    def evaluate(self, output: str, reference: str) -> float:
        """
        评估输出与参考的语义相似度
        
        Args:
            output: 模型输出文本
            reference: 参考答案文本
            
        Returns:
            相似度得分 (0-1)
        """
        if not output or not reference:
            return 0.0
        
        emb_output = self.get_embedding(output)
        emb_reference = self.get_embedding(reference)
        
        return self.cosine_similarity(emb_output, emb_reference)

优化建议

  1. 缓存Embedding:相同文本不重复计算
  2. 批量处理:一次API调用处理多个文本
  3. 长文本截断:超过模型限制时分段处理

G-Eval 评估器实现

使用 LLM 作为评估器:

import json
from typing import Dict, Optional

class GEvalEvaluator:
    """
    G-Eval: 使用GPT-4进行多维度评估
    基于论文: https://arxiv.org/abs/2303.16634
    """
    
    EVAL_PROMPT = """
你是一个专业的AI输出质量评估专家。

请评估以下AI回答的质量:

【用户问题】
{input}

【AI回答】
{output}

【参考答案】(如有)
{reference}

请从以下维度评分(每项1-10分):

1. **Relevance (相关性)**: 回答是否直接回应了用户问题?
2. **Accuracy (准确性)**: 信息是否正确、准确?
3. **Coherence (连贯性)**: 逻辑是否清晰、结构是否合理?
4. **Completeness (完整性)**: 是否充分回答了问题?
5. **Fluency (流畅性)**: 语言表达是否自然流畅?

请以JSON格式返回评分:
```json
{
  "relevance": <score>,
  "accuracy": <score>,
  "coherence": <score>,
  "completeness": <score>,
  "fluency": <score>,
  "overall_comment": "<简短评价>"
}

"""

def __init__(self, model: str = "gpt-4-turbo-preview"):
    self.client = OpenAI()
    self.model = model

def evaluate(
    self,
    output: str,
    input: str,
    reference: Optional[str] = None
) -> Dict:
    """
    执行G-Eval评估
    
    Returns:
        包含各维度得分和综合得分的字典
    """
    prompt = self.EVAL_PROMPT.format(
        input=input,
        output=output,
        reference=reference or "(未提供参考答案)"
    )
    
    response = self.client.chat.completions.create(
        model=self.model,
        messages=[{"role": "user", "content": prompt}],
        temperature=0.0,  # 评估需要稳定性
        response_format={"type": "json_object"}
    )
    
    result = json.loads(response.choices[0].message.content)
    
    # 计算加权综合得分
    weights = {
        "relevance": 0.25,
        "accuracy": 0.25,
        "coherence": 0.20,
        "completeness": 0.20,
        "fluency": 0.10
    }
    
    overall = sum(
        result[k] * weights[k] 
        for k in weights.keys()
    ) / 10
    
    result["overall_score"] = round(overall, 3)
    return result

### Task Completion 评估器

用于 Agent 任务完成度评估:

```python
class TaskCompletionEvaluator:
    """
    任务完成度评估器
    用于Agent系统的端到端评估
    """
    
    def __init__(self, task_definition: Dict):
        """
        Args:
            task_definition: 任务定义,包含成功标准
        """
        self.criteria = task_definition.get("success_criteria", [])
        self.required_steps = task_definition.get("required_steps", [])
    
    def evaluate(self, execution_trace: Dict) -> Dict:
        """
        评估任务执行结果
        
        Args:
            execution_trace: 执行轨迹,包含步骤、结果
            
        Returns:
            完成度评估结果
        """
        results = {
            "criteria_met": [],
            "criteria_failed": [],
            "steps_completed": 0,
            "steps_total": len(self.required_steps),
            "score": 0.0
        }
        
        # 评估每个成功标准
        for criterion in self.criteria:
            if self._check_criterion(criterion, execution_trace):
                results["criteria_met"].append(criterion)
            else:
                results["criteria_failed"].append(criterion)
        
        # 评估步骤完成情况
        for step in self.required_steps:
            if self._check_step(step, execution_trace):
                results["steps_completed"] += 1
        
        # 计算综合得分
        criteria_score = len(results["criteria_met"]) / len(self.criteria)
        steps_score = results["steps_completed"] / results["steps_total"]
        results["score"] = (criteria_score * 0.6 + steps_score * 0.4)
        
        return results
    
    def _check_criterion(self, criterion: str, trace: Dict) -> bool:
        """检查成功标准是否满足"""
        # 根据criterion类型进行不同检查
        if "完成" in criterion:
            return trace.get("final_result") is not None
        elif "提供" in criterion:
            return any(criterion in str(step) for step in trace.get("steps", []))
        return False
    
    def _check_step(self, step: Dict, trace: Dict) -> bool:
        """检查步骤是否完成"""
        return any(
            step.get("name") in executed.get("action", "")
            for executed in trace.get("steps", [])
        )

数据管理实现

Golden Set 管理

import json
import hashlib
from pathlib import Path
from datetime import datetime

class GoldenSetManager:
    """
    Golden Set 数据管理器
    支持版本锁定、完整性校验
    """
    
    def __init__(self, data_dir: Path):
        self.data_dir = data_dir
        self.registry_file = data_dir / "registry.json"
        self._load_registry()
    
    def _load_registry(self):
        """加载版本注册表"""
        if self.registry_file.exists():
            self.registry = json.loads(self.registry_file.read_text())
        else:
            self.registry = {"versions": {}}
    
    def create_version(
        self,
        name: str,
        cases: List[Dict],
        metadata: Dict = None
    ) -> str:
        """
        创建新版本
        
        Args:
            name: 数据集名称
            cases: 测试案例列表
            metadata: 元数据
            
        Returns:
            版本ID
        """
        timestamp = datetime.now().strftime("%Y%m%d_%H%M%S")
        version_id = f"{name}_v{timestamp}"
        
        # 保存数据
        data_file = self.data_dir / f"{version_id}.json"
        data = {
            "name": name,
            "version_id": version_id,
            "created_at": timestamp,
            "metadata": metadata or {},
            "cases": cases
        }
        data_file.write_text(json.dumps(data, indent=2, ensure_ascii=False))
        
        # 计算checksum
        checksum = hashlib.sha256(
            json.dumps(cases, sort_keys=True).encode()
        ).hexdigest()
        
        # 注册版本
        self.registry["versions"][version_id] = {
            "file": str(data_file),
            "checksum": checksum,
            "locked": False,
            "case_count": len(cases)
        }
        self._save_registry()
        
        return version_id
    
    def lock_version(self, version_id: str):
        """锁定版本,禁止修改"""
        if version_id not in self.registry["versions"]:
            raise ValueError(f"Version {version_id} not found")
        
        self.registry["versions"][version_id]["locked"] = True
        self.registry["versions"][version_id]["locked_at"] = datetime.now().isoformat()
        self._save_registry()
    
    def load_version(self, version_id: str) -> Dict:
        """加载指定版本"""
        entry = self.registry["versions"].get(version_id)
        if not entry:
            raise ValueError(f"Version {version_id} not found")
        
        data_file = Path(entry["file"])
        data = json.loads(data_file.read_text())
        
        # 校验完整性
        current_checksum = hashlib.sha256(
            json.dumps(data["cases"], sort_keys=True).encode()
        ).hexdigest()
        
        if current_checksum != entry["checksum"]:
            raise IntegrityError(
                f"Checksum mismatch for {version_id}. "
                f"Expected: {entry['checksum']}, Got: {current_checksum}"
            )
        
        return data
    
    def _save_registry(self):
        """保存注册表"""
        self.registry_file.write_text(
            json.dumps(self.registry, indent=2)
        )

案例数据结构

# 标准测试案例结构
@dataclass
class TestCase:
    """测试案例数据结构"""
    
    id: str                          # 案例唯一标识
    input: str                       # 输入内容
    reference: Optional[str] = None  # 参考答案(可选)
    category: Optional[str] = None   # 分类标签
    difficulty: str = "normal"       # 难度: easy/normal/hard
    metadata: Dict = field(default_factory=dict)
    
    # 评估期望
    expected_criteria: List[Dict] = field(default_factory=list)
    
    # 示例:
    # {
    #     "id": "qa_001",
    #     "input": "什么是机器学习?",
    #     "reference": "机器学习是人工智能的一个分支...",
    #     "category": "qa_concept",
    #     "difficulty": "normal",
    #     "metadata": {"domain": "AI基础"},
    #     "expected_criteria": [
    #         {"name": "relevance", "threshold": 0.8},
    #         {"name": "accuracy", "threshold": 0.7}
    #     ]
    # }

执行引擎实现

并行执行器

import asyncio
from typing import List, Callable
from concurrent.futures import ThreadPoolExecutor

class ParallelExecutor:
    """
    并行执行引擎
    支持异步批量处理
    """
    
    def __init__(
        self,
        max_concurrent: int = 10,
        retry_count: int = 3,
        timeout_sec: float = 30.0
    ):
        self.max_concurrent = max_concurrent
        self.retry_count = retry_count
        self.timeout_sec = timeout_sec
        self.semaphore = asyncio.Semaphore(max_concurrent)
    
    async def execute_batch(
        self,
        cases: List[TestCase],
        process_func: Callable,
        progress_callback: Callable = None
    ) -> List[Result]:
        """
        并行批量执行
        
        Args:
            cases: 测试案例列表
            process_func: 处理函数
            progress_callback: 进度回调
            
        Returns:
            结果列表
        """
        results = []
        completed = 0
        
        async def process_with_semaphore(case):
            async with self.semaphore:
                result = await self._execute_with_retry(case, process_func)
                completed += 1
                if progress_callback:
                    progress_callback(completed, len(cases))
                return result
        
        # 并行执行所有案例
        tasks = [process_with_semaphore(case) for case in cases]
        results = await asyncio.gather(*tasks, return_exceptions=True)
        
        # 处理异常结果
        return [
            r if not isinstance(r, Exception) 
            else Result(error=str(r))
            for r in results
        ]
    
    async def _execute_with_retry(
        self,
        case: TestCase,
        process_func: Callable
    ) -> Result:
        """带重试的执行"""
        for attempt in range(self.retry_count):
            try:
                result = await asyncio.wait_for(
                    process_func(case),
                    timeout=self.timeout_sec
                )
                return result
            except asyncio.TimeoutError:
                if attempt == self.retry_count - 1:
                    return Result(
                        case_id=case.id,
                        error="Timeout after {} retries".format(self.retry_count)
                    )
                await asyncio.sleep(1 * (attempt + 1))  # 指数退避
            except Exception as e:
                if attempt == self.retry_count - 1:
                    return Result(case_id=case.id, error=str(e))
                await asyncio.sleep(0.5)

执行上下文管理

class ExecutionContext:
    """
    执行上下文管理
    确保可追溯性
    """
    
    def __init__(self):
        self.run_id = self._generate_run_id()
        self.timestamp = datetime.now()
        self.config_snapshot = None
        self.model_version = None
        self.dataset_version = None
    
    def _generate_run_id(self) -> str:
        """生成唯一执行ID"""
        return f"run_{datetime.now().strftime('%Y%m%d_%H%M%S')}_{uuid.uuid4().hex[:8]}"
    
    def record_config(self, config: Dict):
        """记录配置快照"""
        self.config_snapshot = json.dumps(config, sort_keys=True)
    
    def to_dict(self) -> Dict:
        """导出为字典"""
        return {
            "run_id": self.run_id,
            "timestamp": self.timestamp.isoformat(),
            "config_snapshot": self.config_snapshot,
            "model_version": self.model_version,
            "dataset_version": self.dataset_version
        }

结果处理实现

结果聚合器

import statistics
from typing import List, Dict

class ResultAggregator:
    """
    结果聚合分析器
    """
    
    def aggregate(self, results: List[Result]) -> Dict:
        """
        聚合分析结果
        
        Returns:
            聚合统计报告
        """
        scores = [r.scores.get("overall", 0) for r in results if not r.error]
        categories = self._group_by_category(results)
        
        return {
            "total_cases": len(results),
            "successful": len(scores),
            "failed": len(results) - len(scores),
            
            "statistics": {
                "mean": statistics.mean(scores) if scores else 0,
                "median": statistics.median(scores) if scores else 0,
                "std": statistics.stdev(scores) if len(scores) > 1 else 0,
                "min": min(scores) if scores else 0,
                "max": max(scores) if scores else 0,
            },
            
            "pass_rate": self._calculate_pass_rate(scores),
            
            "by_category": categories,
            
            "failed_cases": [
                {"id": r.case_id, "error": r.error}
                for r in results if r.error
            ]
        }
    
    def _calculate_pass_rate(self, scores: List[float], threshold: float = 0.7) -> float:
        """计算通过率"""
        passed = sum(1 for s in scores if s >= threshold)
        return passed / len(scores) if scores else 0
    
    def _group_by_category(self, results: List[Result]) -> Dict:
        """按类别分组统计"""
        categories = {}
        for r in results:
            cat = r.case.category or "default"
            if cat not in categories:
                categories[cat] = {"scores": [], "count": 0}
            if not r.error:
                categories[cat]["scores"].append(r.scores.get("overall", 0))
            categories[cat]["count"] += 1
        
        # 计算每类统计
        for cat, data in categories.items():
            if data["scores"]:
                data["mean"] = statistics.mean(data["scores"])
                data["pass_rate"] = self._calculate_pass_rate(data["scores"])
        
        return categories

版本对比器

class VersionComparator:
    """
    版本对比分析器
    用于回归评估
    """
    
    def compare(
        self,
        baseline: Dict,
        current: Dict
    ) -> Dict:
        """
        对比两个版本评估结果
        
        Args:
            baseline: 基线版本结果
            current: 当前版本结果
            
        Returns:
            对比报告
        """
        comparison = {
            "baseline_version": baseline.get("version_id"),
            "current_version": current.get("version_id"),
            "overall_change": 0.0,
            "category_changes": {},
            "regression_cases": [],
            "improvement_cases": [],
            "status": "unknown"
        }
        
        # 计算整体变化
        baseline_mean = baseline["statistics"]["mean"]
        current_mean = current["statistics"]["mean"]
        comparison["overall_change"] = current_mean - baseline_mean
        
        # 分类对比
        for cat in baseline["by_category"]:
            if cat in current["by_category"]:
                delta = (
                    current["by_category"][cat]["mean"] 
                    - baseline["by_category"][cat]["mean"]
                )
                comparison["category_changes"][cat] = {
                    "baseline": baseline["by_category"][cat]["mean"],
                    "current": current["by_category"][cat]["mean"],
                    "delta": delta,
                    "status": "improved" if delta > 0.02 else 
                              "regressed" if delta < -0.02 else "stable"
                }
        
        # 案例级对比
        for case_id in self._get_common_cases(baseline, current):
            baseline_score = self._get_case_score(baseline, case_id)
            current_score = self._get_case_score(current, case_id)
            
            if current_score < baseline_score - 0.1:
                comparison["regression_cases"].append({
                    "case_id": case_id,
                    "baseline": baseline_score,
                    "current": current_score,
                    "delta": current_score - baseline_score
                })
            elif current_score > baseline_score + 0.1:
                comparison["improvement_cases"].append({
                    "case_id": case_id,
                    "baseline": baseline_score,
                    "current": current_score,
                    "delta": current_score - baseline_score
                })
        
        # 综合状态判断
        if comparison["overall_change"] > 0.05:
            comparison["status"] = "improved"
        elif comparison["overall_change"] < -0.05:
            comparison["status"] = "regressed"
        else:
            comparison["status"] = "stable"
        
        return comparison

小结

核心组件实现要点:

组件实现要点
评估器Semantic/G-Eval/Task三种类型,各有适用场景
数据管理版本锁定、完整性校验、结构化案例
执行引擎并行处理、重试机制、上下文追溯
结果处理聚合统计、版本对比、分类分析

实现 Checklist

✅ Semantic评估器是否支持Embedding缓存? ✅ G-Eval是否设置temperature=0? ✅ Golden Set是否锁定版本并校验完整性? ✅ 执行是否支持并行和重试? ✅ 结果是否完整追溯上下文? ✅ 是否支持版本对比分析?


下一章,我们将通过实战案例将这些组件组装为完整的 LLM 评估 Harness。