第六章:核心组件实现
本章深入各核心组件的具体实现细节。
评估器实现
Semantic Similarity 评估器
基于 Embedding 的语义相似度评估:
import numpy as np
from openai import OpenAI
class SemanticSimilarityEvaluator:
    """
    Semantic similarity evaluator.

    Embeds the model output and the reference answer with an OpenAI embedding
    model and scores them by cosine similarity. Embeddings are memoised per
    instance (keyed by model name and text), so a reference shared by many
    test cases is only embedded once.
    """

    def __init__(self, model: str = "text-embedding-3-small", use_cache: bool = True):
        """
        Args:
            model: Name of the OpenAI embedding model.
            use_cache: Whether to memoise embeddings by (model, text).
        """
        self.client = OpenAI()
        self.model = model
        self.use_cache = use_cache
        # (model, text) -> embedding vector. Unbounded; lives as long as the instance.
        self._embedding_cache = {}

    def get_embedding(self, text: str) -> np.ndarray:
        """Return the embedding vector of ``text``, served from the cache when possible."""
        key = (self.model, text)
        if self.use_cache and key in self._embedding_cache:
            return self._embedding_cache[key]
        response = self.client.embeddings.create(
            input=text,
            model=self.model
        )
        embedding = np.array(response.data[0].embedding)
        if self.use_cache:
            self._embedding_cache[key] = embedding
        return embedding

    def cosine_similarity(self, vec1: np.ndarray, vec2: np.ndarray) -> float:
        """
        Cosine similarity of two vectors, as a plain Python float in [-1, 1].

        Returns 0.0 when either vector has zero norm (no direction), instead
        of the NaN a bare division would produce.
        """
        norm = float(np.linalg.norm(vec1) * np.linalg.norm(vec2))
        if norm == 0.0:
            return 0.0
        return float(np.dot(vec1, vec2) / norm)

    def evaluate(self, output: str, reference: str) -> float:
        """
        Score the semantic similarity between an output and a reference.

        Args:
            output: Text produced by the model.
            reference: Reference answer text.

        Returns:
            Similarity score in [0, 1]; 0.0 if either text is empty/None.
        """
        if not output or not reference:
            return 0.0
        emb_output = self.get_embedding(output)
        emb_reference = self.get_embedding(reference)
        similarity = self.cosine_similarity(emb_output, emb_reference)
        # Cosine similarity lies in [-1, 1]; clamp to the documented 0-1 range
        # (this also absorbs tiny floating-point overshoot above 1.0).
        return min(1.0, max(0.0, similarity))
G-Eval 评估器实现
使用 LLM 作为评估器:
import json
from typing import Dict, Optional
class GEvalEvaluator:
    """
    G-Eval: multi-dimension quality evaluation with a GPT-4 family model as judge.

    Based on the paper: https://arxiv.org/abs/2303.16634

    The judge scores five dimensions on a 1-10 scale; ``evaluate`` folds them
    into a weighted ``overall_score`` normalised to 0-1.
    """

    # Template filled via ``str.format``. Literal braces in the JSON example
    # MUST be doubled (``{{`` / ``}}``): single braces are parsed as
    # replacement fields and ``format`` would raise KeyError on every call.
    EVAL_PROMPT = """
你是一个专业的AI输出质量评估专家。
请评估以下AI回答的质量:
【用户问题】
{input}
【AI回答】
{output}
【参考答案】(如有)
{reference}
请从以下维度评分(每项1-10分):
1. **Relevance (相关性)**: 回答是否直接回应了用户问题?
2. **Accuracy (准确性)**: 信息是否正确、准确?
3. **Coherence (连贯性)**: 逻辑是否清晰、结构是否合理?
4. **Completeness (完整性)**: 是否充分回答了问题?
5. **Fluency (流畅性)**: 语言表达是否自然流畅?
请以JSON格式返回评分:
```json
{{
"relevance": <score>,
"accuracy": <score>,
"coherence": <score>,
"completeness": <score>,
"fluency": <score>,
"overall_comment": "<简短评价>"
}}
```
"""

    # Per-dimension weights for the overall score (they sum to 1.0).
    WEIGHTS = {
        "relevance": 0.25,
        "accuracy": 0.25,
        "coherence": 0.20,
        "completeness": 0.20,
        "fluency": 0.10,
    }
    # Upper bound of the judge's per-dimension scale; used to normalise to 0-1.
    MAX_SCORE = 10

    def __init__(self, model: str = "gpt-4-turbo-preview"):
        self.client = OpenAI()
        self.model = model

    def evaluate(
        self,
        output: str,
        input: str,
        reference: Optional[str] = None
    ) -> Dict:
        """
        Run a G-Eval judgement.

        Args:
            output: Model answer to be judged.
            input: The user question that produced ``output``.
            reference: Optional reference answer.

        Returns:
            The judge's parsed JSON (per-dimension scores, ``overall_comment``)
            plus ``overall_score``: the weighted mean normalised to 0-1,
            rounded to 3 decimals.

        Raises:
            ValueError: if the judge returns empty content, malformed JSON
                (``json.JSONDecodeError`` is a ``ValueError``), or is missing /
                has non-numeric dimension scores.
        """
        prompt = self.EVAL_PROMPT.format(
            input=input,
            output=output,
            reference=reference or "(未提供参考答案)"
        )
        response = self.client.chat.completions.create(
            model=self.model,
            messages=[{"role": "user", "content": prompt}],
            temperature=0.0,  # deterministic judging for stable scores
            response_format={"type": "json_object"}
        )
        content = response.choices[0].message.content
        if not content:
            raise ValueError("G-Eval judge returned an empty response")
        result = json.loads(content)
        result["overall_score"] = self._weighted_overall(result)
        return result

    def _weighted_overall(self, result: Dict) -> float:
        """Weighted mean of the dimension scores in ``result``, normalised to 0-1 (3 decimals)."""
        missing = [k for k in self.WEIGHTS if k not in result]
        if missing:
            raise ValueError(f"G-Eval response is missing score fields: {missing}")
        try:
            total = sum(float(result[k]) * w for k, w in self.WEIGHTS.items())
        except (TypeError, ValueError) as err:
            raise ValueError(f"G-Eval response contains a non-numeric score: {result}") from err
        return round(total / self.MAX_SCORE, 3)
Task Completion 评估器
用于 Agent 任务完成度评估:
class TaskCompletionEvaluator:
    """
    Task-completion evaluator for end-to-end evaluation of agent systems.

    The score is a weighted mix of the fraction of success criteria met
    (weight ``CRITERIA_WEIGHT``) and the fraction of required steps completed
    (weight ``STEPS_WEIGHT``). A component with nothing defined is left out and
    the remaining weight is renormalised; with neither defined the score is 0.0.
    """

    CRITERIA_WEIGHT = 0.6
    STEPS_WEIGHT = 0.4

    def __init__(self, task_definition: Dict):
        """
        Args:
            task_definition: Task definition; reads ``success_criteria`` (list of
                str) and ``required_steps`` (list of dicts with a ``name`` key,
                or plain step-name strings).
        """
        self.criteria = task_definition.get("success_criteria", [])
        self.required_steps = task_definition.get("required_steps", [])

    def evaluate(self, execution_trace: Dict) -> Dict:
        """
        Evaluate an execution trace against the task definition.

        Args:
            execution_trace: Trace dict; reads ``final_result`` and ``steps``
                (list of step dicts with an ``action`` key).

        Returns:
            Dict with ``criteria_met``, ``criteria_failed``, ``steps_completed``,
            ``steps_total`` and ``score`` (0-1).
        """
        results = {
            "criteria_met": [],
            "criteria_failed": [],
            "steps_completed": 0,
            "steps_total": len(self.required_steps),
            "score": 0.0
        }
        for criterion in self.criteria:
            bucket = "criteria_met" if self._check_criterion(criterion, execution_trace) else "criteria_failed"
            results[bucket].append(criterion)

        results["steps_completed"] = sum(
            1 for step in self.required_steps if self._check_step(step, execution_trace)
        )

        # Only components that actually define requirements contribute; this
        # avoids dividing by zero when a task has no criteria or no steps.
        components = []
        if self.criteria:
            components.append((len(results["criteria_met"]) / len(self.criteria), self.CRITERIA_WEIGHT))
        if self.required_steps:
            components.append((results["steps_completed"] / results["steps_total"], self.STEPS_WEIGHT))
        total_weight = sum(weight for _, weight in components)
        if total_weight:
            results["score"] = sum(score * weight for score, weight in components) / total_weight
        return results

    def _check_criterion(self, criterion: str, trace: Dict) -> bool:
        """Check a success criterion with simple keyword heuristics."""
        if "完成" in criterion:
            # "completed"-style criteria: the task produced a final result.
            return trace.get("final_result") is not None
        elif "提供" in criterion:
            # "provided"-style criteria: the criterion text appears in some step.
            return any(criterion in str(step) for step in trace.get("steps") or [])
        return False

    def _check_step(self, step, trace: Dict) -> bool:
        """True if any executed step's ``action`` contains the required step's name."""
        name = step.get("name") if isinstance(step, dict) else step
        if not name:
            # A step without a name can never be matched.
            return False
        name = str(name)
        return any(
            name in str(executed.get("action") or "")
            for executed in trace.get("steps") or []
            if isinstance(executed, dict)
        )
数据管理实现
Golden Set 管理
import json
import hashlib
from pathlib import Path
from datetime import datetime
from typing import Dict, List, Optional


class IntegrityError(ValueError):
    """Raised when a stored Golden Set file no longer matches its registered checksum."""


class GoldenSetManager:
    """
    Golden Set data manager.

    Stores each dataset version as ``<data_dir>/<version_id>.json`` and keeps a
    ``registry.json`` recording each version's file, SHA-256 checksum of its
    cases, lock state and case count. Loading verifies the checksum.
    """

    def __init__(self, data_dir: Path):
        """
        Args:
            data_dir: Directory holding version files and the registry
                (a ``str`` is accepted too). Created on first write.
        """
        self.data_dir = Path(data_dir)
        self.registry_file = self.data_dir / "registry.json"
        self._load_registry()

    def _load_registry(self):
        """Load the version registry, or start an empty one if none exists."""
        if self.registry_file.exists():
            self.registry = json.loads(self.registry_file.read_text(encoding="utf-8"))
        else:
            self.registry = {"versions": {}}

    @staticmethod
    def _checksum(cases: List[Dict]) -> str:
        """SHA-256 hex digest of the canonical (key-sorted) JSON form of ``cases``."""
        return hashlib.sha256(
            json.dumps(cases, sort_keys=True).encode()
        ).hexdigest()

    def create_version(
        self,
        name: str,
        cases: List[Dict],
        metadata: Optional[Dict] = None
    ) -> str:
        """
        Create and register a new dataset version.

        Args:
            name: Dataset name.
            cases: List of test-case dicts.
            metadata: Optional metadata stored alongside the cases.

        Returns:
            The new version ID, ``<name>_v<YYYYmmdd_HHMMSS>`` with a ``_<n>``
            suffix if that ID is already taken.
        """
        timestamp = datetime.now().strftime("%Y%m%d_%H%M%S")
        version_id = f"{name}_v{timestamp}"
        # Second-resolution timestamps collide when two versions are created in
        # the same second; without this, the newer one would silently overwrite
        # the older file and registry entry (even a locked one).
        suffix = 2
        while (version_id in self.registry["versions"]
               or (self.data_dir / f"{version_id}.json").exists()):
            version_id = f"{name}_v{timestamp}_{suffix}"
            suffix += 1

        self.data_dir.mkdir(parents=True, exist_ok=True)
        data_file = self.data_dir / f"{version_id}.json"
        data = {
            "name": name,
            "version_id": version_id,
            "created_at": timestamp,
            "metadata": metadata or {},
            "cases": cases
        }
        # ensure_ascii=False writes raw non-ASCII, so the encoding must be explicit.
        data_file.write_text(json.dumps(data, indent=2, ensure_ascii=False), encoding="utf-8")

        self.registry["versions"][version_id] = {
            "file": str(data_file),
            "checksum": self._checksum(cases),
            "locked": False,
            "case_count": len(cases)
        }
        self._save_registry()
        return version_id

    def lock_version(self, version_id: str):
        """Mark a version as locked (frozen) and record when it was locked."""
        if version_id not in self.registry["versions"]:
            raise ValueError(f"Version {version_id} not found")
        self.registry["versions"][version_id]["locked"] = True
        self.registry["versions"][version_id]["locked_at"] = datetime.now().isoformat()
        self._save_registry()

    def load_version(self, version_id: str) -> Dict:
        """
        Load a version and verify its integrity.

        Raises:
            ValueError: if the version is not registered.
            IntegrityError: if the stored cases do not match the registered checksum.
        """
        entry = self.registry["versions"].get(version_id)
        if not entry:
            raise ValueError(f"Version {version_id} not found")
        data_file = Path(entry["file"])
        if not data_file.exists():
            # The registry may hold a path relative to a different working
            # directory; fall back to the file name inside data_dir.
            data_file = self.data_dir / data_file.name
        data = json.loads(data_file.read_text(encoding="utf-8"))

        current_checksum = self._checksum(data["cases"])
        if current_checksum != entry["checksum"]:
            raise IntegrityError(
                f"Checksum mismatch for {version_id}. "
                f"Expected: {entry['checksum']}, Got: {current_checksum}"
            )
        return data

    def _save_registry(self):
        """Persist the registry atomically (write a temp file, then replace)."""
        self.data_dir.mkdir(parents=True, exist_ok=True)
        tmp_file = self.registry_file.with_name(self.registry_file.name + ".tmp")
        tmp_file.write_text(json.dumps(self.registry, indent=2), encoding="utf-8")
        tmp_file.replace(self.registry_file)
案例数据结构
# Standard test-case structure
from dataclasses import dataclass, field
from typing import Dict, List, Optional


@dataclass
class TestCase:
    """
    One evaluation test case.

    ``metadata`` and ``expected_criteria`` use ``default_factory`` so each
    instance gets its own fresh container rather than a shared mutable default.
    """

    id: str  # unique case identifier
    input: str  # input content
    reference: Optional[str] = None  # reference answer (optional)
    category: Optional[str] = None  # category label
    difficulty: str = "normal"  # difficulty: easy / normal / hard
    metadata: Dict = field(default_factory=dict)
    # Evaluation expectations, e.g. [{"name": "relevance", "threshold": 0.8}]
    expected_criteria: List[Dict] = field(default_factory=list)


# Example:
# {
#     "id": "qa_001",
#     "input": "什么是机器学习?",
#     "reference": "机器学习是人工智能的一个分支...",
#     "category": "qa_concept",
#     "difficulty": "normal",
#     "metadata": {"domain": "AI基础"},
#     "expected_criteria": [
#         {"name": "relevance", "threshold": 0.8},
#         {"name": "accuracy", "threshold": 0.7}
#     ]
# }
执行引擎实现
并行执行器
import asyncio
from typing import List, Callable
from concurrent.futures import ThreadPoolExecutor
class ParallelExecutor:
    """
    Parallel execution engine.

    Runs an async ``process_func`` over many test cases concurrently, capped
    at ``max_concurrent`` in-flight cases, with a per-attempt timeout and
    retries.
    """

    def __init__(
        self,
        max_concurrent: int = 10,
        retry_count: int = 3,
        timeout_sec: float = 30.0
    ):
        """
        Args:
            max_concurrent: Maximum number of cases processed at once.
            retry_count: Total attempts per case (values < 1 are treated as 1).
            timeout_sec: Timeout for each individual attempt, in seconds.
        """
        self.max_concurrent = max_concurrent
        self.retry_count = retry_count
        self.timeout_sec = timeout_sec
        self.semaphore = asyncio.Semaphore(max_concurrent)

    # Annotations are quoted because TestCase/Result are project types that may
    # not be in scope when this class body is executed.
    async def execute_batch(
        self,
        cases: "List[TestCase]",
        process_func: Callable,
        progress_callback: Optional[Callable] = None
    ) -> "List[Result]":
        """
        Process all cases concurrently.

        Args:
            cases: Test cases; each needs an ``id`` attribute.
            process_func: Async callable taking one case and returning its result.
            progress_callback: Optional ``callback(completed, total)``, called
                after each case finishes.

        Returns:
            Results in the same order as ``cases``. A case whose processing
            raised becomes ``Result(case_id=..., error=...)``.
        """
        completed = 0
        total = len(cases)

        async def process_with_semaphore(case):
            # Without `nonlocal`, `completed += 1` makes `completed` local to
            # this closure and raises UnboundLocalError for every case.
            nonlocal completed
            async with self.semaphore:
                result = await self._execute_with_retry(case, process_func)
                completed += 1
                if progress_callback:
                    progress_callback(completed, total)
                return result

        outcomes = await asyncio.gather(
            *(process_with_semaphore(case) for case in cases),
            return_exceptions=True
        )
        # gather preserves input order, so zip pairs each outcome with its case.
        return [
            Result(case_id=case.id, error=str(outcome))
            if isinstance(outcome, BaseException)
            else outcome
            for case, outcome in zip(cases, outcomes)
        ]

    async def _execute_with_retry(
        self,
        case: "TestCase",
        process_func: Callable
    ) -> "Result":
        """Run ``process_func(case)`` with timeout, retrying on timeout or error."""
        attempts = max(1, self.retry_count)
        for attempt in range(attempts):
            is_last = attempt == attempts - 1
            try:
                return await asyncio.wait_for(
                    process_func(case),
                    timeout=self.timeout_sec
                )
            except asyncio.TimeoutError:
                if is_last:
                    return Result(
                        case_id=case.id,
                        error="Timeout after {} retries".format(self.retry_count)
                    )
                await asyncio.sleep(2 ** attempt)  # exponential backoff: 1s, 2s, 4s, ...
            except Exception as e:
                if is_last:
                    return Result(case_id=case.id, error=str(e))
                await asyncio.sleep(0.5)
执行上下文管理
class ExecutionContext:
    """
    Execution context for traceability.

    Captures a unique run ID, the start timestamp, and slots for a config
    snapshot plus model and dataset versions, so each evaluation run can be
    reproduced.
    """

    def __init__(self):
        # Take the timestamp once and derive run_id from it, so the two can
        # never disagree (two separate now() calls may straddle a second).
        self.timestamp = datetime.now()
        self.run_id = self._generate_run_id(self.timestamp)
        self.config_snapshot = None
        self.model_version = None
        self.dataset_version = None

    def _generate_run_id(self, when: Optional[datetime] = None) -> str:
        """Build ``run_<YYYYmmdd_HHMMSS>_<8 hex chars>`` from ``when`` (default: now)."""
        import uuid

        when = when or datetime.now()
        return f"run_{when.strftime('%Y%m%d_%H%M%S')}_{uuid.uuid4().hex[:8]}"

    def record_config(self, config: Dict):
        """Store the config as key-sorted JSON so identical configs compare equal."""
        self.config_snapshot = json.dumps(config, sort_keys=True)

    def to_dict(self) -> Dict:
        """Export the context as a plain dict (timestamp in ISO-8601)."""
        return {
            "run_id": self.run_id,
            "timestamp": self.timestamp.isoformat(),
            "config_snapshot": self.config_snapshot,
            "model_version": self.model_version,
            "dataset_version": self.dataset_version
        }
结果处理实现
结果聚合器
import statistics
from typing import List, Dict
class ResultAggregator:
    """
    Aggregates per-case results into a summary report: overall statistics,
    pass rate, per-category breakdown and the list of failed cases.
    """

    # NOTE(review): scores are read from r.scores["overall"]; GEvalEvaluator
    # emits "overall_score" — confirm which key Result.scores actually uses.
    # Annotations are quoted because Result is a project type that may not be
    # in scope when this class body runs.
    def aggregate(self, results: "List[Result]") -> Dict:
        """
        Aggregate results.

        Returns:
            Report dict with ``total_cases``, ``successful``, ``failed``,
            ``statistics`` (mean/median/std/min/max, 0 when empty),
            ``pass_rate``, ``by_category`` and ``failed_cases``.
        """
        scores = [r.scores.get("overall", 0) for r in results if not r.error]
        categories = self._group_by_category(results)

        return {
            "total_cases": len(results),
            "successful": len(scores),
            "failed": len(results) - len(scores),
            "statistics": {
                "mean": statistics.mean(scores) if scores else 0,
                "median": statistics.median(scores) if scores else 0,
                "std": statistics.stdev(scores) if len(scores) > 1 else 0,
                "min": min(scores) if scores else 0,
                "max": max(scores) if scores else 0,
            },
            "pass_rate": self._calculate_pass_rate(scores),
            "by_category": categories,
            "failed_cases": [
                {"id": r.case_id, "error": r.error}
                for r in results if r.error
            ]
        }

    def _calculate_pass_rate(self, scores: List[float], threshold: float = 0.7) -> float:
        """Fraction of scores >= ``threshold``; 0 for an empty list."""
        passed = sum(1 for s in scores if s >= threshold)
        return passed / len(scores) if scores else 0

    def _group_by_category(self, results: "List[Result]") -> Dict:
        """
        Group results by ``r.case.category`` (``"default"`` when absent).

        Every category always gets ``scores``, ``count``, ``mean`` and
        ``pass_rate``; a category whose cases all errored reports 0 for
        ``mean`` and ``pass_rate`` (matching the overall statistics), so
        consumers can index those keys unconditionally.
        """
        categories: Dict[str, Dict] = {}
        for r in results:
            # Error results may be built with only case_id/error (no `case`).
            case = getattr(r, "case", None)
            cat = getattr(case, "category", None) or "default"
            bucket = categories.setdefault(cat, {"scores": [], "count": 0})
            if not r.error:
                bucket["scores"].append(r.scores.get("overall", 0))
            bucket["count"] += 1

        for bucket in categories.values():
            cat_scores = bucket["scores"]
            bucket["mean"] = statistics.mean(cat_scores) if cat_scores else 0
            bucket["pass_rate"] = self._calculate_pass_rate(cat_scores)
        return categories
版本对比器
class VersionComparator:
    """
    Version comparison analyser for regression evaluation.

    Compares two aggregate reports (dicts with ``statistics.mean``,
    ``by_category`` and optionally ``version_id``) at overall, category and
    case level.

    Case-level comparison needs each report to carry a ``case_scores`` mapping
    of ``case_id -> score``; when either report lacks it, no case-level
    regressions/improvements are reported.
    """

    # Minimum |delta| of a category mean to count as improved/regressed.
    CATEGORY_TOLERANCE = 0.02
    # Minimum |delta| of a single case score to count as improved/regressed.
    CASE_TOLERANCE = 0.1
    # Minimum |delta| of the overall mean for the run to count as improved/regressed.
    OVERALL_TOLERANCE = 0.05

    def compare(
        self,
        baseline: Dict,
        current: Dict
    ) -> Dict:
        """
        Compare two evaluation reports.

        Args:
            baseline: Baseline version report.
            current: Current version report.

        Returns:
            Comparison report with ``overall_change``, ``category_changes``,
            ``regression_cases``, ``improvement_cases`` and ``status``
            (``improved`` / ``regressed`` / ``stable``).
        """
        comparison = {
            "baseline_version": baseline.get("version_id"),
            "current_version": current.get("version_id"),
            "overall_change": 0.0,
            "category_changes": {},
            "regression_cases": [],
            "improvement_cases": [],
            "status": "unknown"
        }

        baseline_mean = baseline["statistics"]["mean"]
        current_mean = current["statistics"]["mean"]
        comparison["overall_change"] = current_mean - baseline_mean

        # Only categories present in both reports are compared.
        current_cats = current.get("by_category", {})
        for cat, base_stats in baseline.get("by_category", {}).items():
            if cat not in current_cats:
                continue
            # A category with no successful cases may have no "mean"; treat it as 0.
            base_value = base_stats.get("mean", 0)
            cur_value = current_cats[cat].get("mean", 0)
            delta = cur_value - base_value
            comparison["category_changes"][cat] = {
                "baseline": base_value,
                "current": cur_value,
                "delta": delta,
                "status": self._classify(delta, self.CATEGORY_TOLERANCE)
            }

        for case_id in self._get_common_cases(baseline, current):
            baseline_score = self._get_case_score(baseline, case_id)
            current_score = self._get_case_score(current, case_id)
            entry = {
                "case_id": case_id,
                "baseline": baseline_score,
                "current": current_score,
                "delta": current_score - baseline_score
            }
            if current_score < baseline_score - self.CASE_TOLERANCE:
                comparison["regression_cases"].append(entry)
            elif current_score > baseline_score + self.CASE_TOLERANCE:
                comparison["improvement_cases"].append(entry)

        comparison["status"] = self._classify(comparison["overall_change"], self.OVERALL_TOLERANCE)
        return comparison

    @staticmethod
    def _classify(delta: float, tolerance: float) -> str:
        """Map a delta to improved / regressed / stable using a symmetric tolerance."""
        if delta > tolerance:
            return "improved"
        if delta < -tolerance:
            return "regressed"
        return "stable"

    @staticmethod
    def _case_scores(report: Dict) -> Dict:
        """The report's ``case_id -> score`` mapping, or an empty dict if absent."""
        return report.get("case_scores") or {}

    def _get_common_cases(self, baseline: Dict, current: Dict) -> List:
        """Case IDs scored in both reports, in baseline order."""
        current_scores = self._case_scores(current)
        return [cid for cid in self._case_scores(baseline) if cid in current_scores]

    def _get_case_score(self, report: Dict, case_id) -> float:
        """Score of ``case_id`` in ``report`` (0.0 if missing)."""
        return self._case_scores(report).get(case_id, 0.0)
小结
核心组件实现要点:
| 组件 | 实现要点 |
|---|---|
| 评估器 | Semantic/G-Eval/Task三种类型,各有适用场景 |
| 数据管理 | 版本锁定、完整性校验、结构化案例 |
| 执行引擎 | 并行处理、重试机制、上下文追溯 |
| 结果处理 | 聚合统计、版本对比、分类分析 |
自查清单:
- ✅ Semantic 评估器是否支持 Embedding 缓存?
- ✅ G-Eval 是否设置 temperature=0?
- ✅ Golden Set 是否锁定版本并校验完整性?
- ✅ 执行是否支持并行和重试?
- ✅ 结果是否完整追溯上下文?
- ✅ 是否支持版本对比分析?
下一章,我们将通过实战案例将这些组件组装为完整的 LLM 评估 Harness。