Datawhale AI夏令营-task2.2-优化思路探索
🎯 成绩提升总览
成绩对比:从Baseline的172分提升到Baseline3的192分,提升20分(11.6%)
核心优化:API增强 + 混合策略 + 稳定性保障
一、主要技术变化
1.1 技术架构升级
|
对比维度
|
Baseline
|
Baseline3
|
改进效果
|
|
主要方法
|
纯机器学习(TF-IDF + SGD)
|
API+ML混合策略
|
语义理解能力大幅提升
|
|
API调用
|
无
|
讯飞星火WebSocket API
|
处理复杂语义、隐喻、反讽
|
|
连接方式
|
-
|
WebSocket长连接
|
更稳定,避免频繁握手
|
|
容错机制
|
基础异常处理
|
完善的降级策略
|
确保100%处理成功率
|
|
进度管理
|
无断点续跑
|
每50条自动保存
|
支持中断恢复,降低风险
|
1.2 核心模块改进(如果表格看不懂再展开)
A. 商品识别升级
- Baseline:TF-IDF特征提取 + SGD分类器(max_features=50)
- Baseline3:API语义理解 + 规则兜底
- 效果:能理解"多语种沟通神器"→翻译器的语义关联
B. 情感分析革新
- Baseline:纯统计特征,无法处理复杂语义
- Baseline3:混合策略 - 简单情况本地处理,复杂情况API分析
- 效果:节省26.5%的API调用,同时提升复杂情感识别准确率
C. 聚类算法智能化
- Baseline:固定K=2聚类
- Baseline3:自适应选择最优聚类数(5-8范围)+ 轮廓系数评估
- 效果:符合赛题要求,聚类质量显著提升
二、WebSocket API技术优势(感兴趣可以展开)
2.1 连接稳定性对比
|
特性
|
HTTP API
|
WebSocket API
|
|
连接方式
|
短连接,每次请求建立连接
|
长连接,一次建立持续使用
|
|
握手开销
|
每次请求都有握手开销
|
仅初始建立时握手一次
|
|
网络稳定性
|
易受网络波动影响
|
连接保持,更稳定
|
|
错误恢复
|
需要重新请求
|
可检测连接状态自动重连
|
2.2 安全认证
- 使用HMAC-SHA256加密签名
- 时间戳防重放攻击
- 完整的三要素认证(APPID + APIKey + APISecret)
三、工程化保障机制
3.1 多层容错策略
1. API调用容错:内容审核错误自动降级到本地处理
2. 网络超时处理:连接失败自动切换到备用方案
3. 内容审核降级:触发内容限制时使用传统ML方法
3.2 断点续跑机制
- 问题:处理6000+条评论需数小时,中断风险高
- 解决:每50条自动保存进度,支持从断点继续
- 效果:避免重复计算和资源浪费
四、性能与成本优化
4.1 实际运行数据
- 总处理:6,476条评论
- API调用:3,479次(53.7%)
- 本地处理:1,256次(26.5%)
- 成本节省:46.3%
4.2 效率对比
|
处理阶段
|
Baseline
|
Baseline3
|
提升效果
|
|
商品识别
|
TF-IDF训练+预测
|
API直接调用
|
提升准确率
|
|
情感分析
|
批量ML预测
|
智能混合策略
|
提升质量
|
|
聚类分析
|
固定参数
|
自适应优化
|
自动最优
|
|
总体时间
|
约10分钟
|
约5小时
|
质量换时间
|
六、改进方向与核心代码
6.1 讯飞API调用核心代码(按需取用)
A. WebSocket API客户端(最有价值,调用api是一大坑)
import websocket import ssl import json import base64 import hashlib import hmac from urllib.parse import urlparse, urlencode from datetime import datetime from time import mktime from wsgiref.handlers import format_date_time # 配置信息 SPARK_CONFIG = { "appid": "your_appid", "api_secret": "your_api_secret", "api_key": "your_api_key", "domain": "lite", "spark_url": "wss://spark-api.xf-yun.com/v1.1/chat" } class SparkApiClient: def __init__(self, appid, api_key, api_secret, spark_url, domain): self.appid = appid self.api_key = api_key self.api_secret = api_secret self.spark_url = spark_url self.domain = domain self.answer = "" def create_url(self): """生成认证URL""" host = urlparse(self.spark_url).netloc path = urlparse(self.spark_url).path now = datetime.now() date = format_date_time(mktime(now.timetuple())) signature_origin = f"host: {host}\ndate: {date}\nGET {path} HTTP/1.1" signature_sha = hmac.new(self.api_secret.encode('utf-8'), signature_origin.encode('utf-8'), digestmod=hashlib.sha256).digest() signature_sha_base64 = base64.b64encode(signature_sha).decode('utf-8') authorization_origin = f'api_key="{self.api_key}", algorithm="hmac-sha256", headers="host date request-line", signature="{signature_sha_base64}"' authorization = base64.b64encode(authorization_origin.encode('utf-8')).decode('utf-8') v = { "authorization": authorization, "date": date, "host": host } return self.spark_url + '?' + urlencode(v) def call_api(self, question): """调用星火API""" self.answer = "" try: wsUrl = self.create_url() ws = websocket.WebSocketApp(wsUrl, on_message=self.on_message, on_error=self.on_error, on_close=self.on_close, on_open=self.on_open) ws.question = question ws.run_forever(sslopt={"cert_reqs": ssl.CERT_NONE}) return self.answer.strip() except Exception as e: print(f"API调用失败: {e}") return "" def on_message(self, ws, message): """处理消息""" try: data = json.loads(message) code = data['header']['code'] if code != 0: print(f'API错误: {code}') ws.close() return choices = data["payload"]["choices"] status = choices["status"] content = choices['text'][0].get('content', '') self.answer += content if status == 2: ws.close() except Exception as e: print(f"处理消息出错: {e}") def on_error(self, ws, error): print(f"WebSocket错误: {error}") def on_close(self, ws, close_status_code, close_msg): pass def on_open(self, ws): def run(): data = json.dumps(self.gen_params(ws.question)) ws.send(data) import _thread as thread thread.start_new_thread(run, ()) def gen_params(self, question): """生成请求参数""" return { "header": {"app_id": self.appid, "uid": "1234"}, "parameter": { "chat": { "domain": self.domain, "temperature": 0.1, "max_tokens": 1024 } }, "payload": { "message": { "text": [{"role": "user", "content": question}] } } } # 使用示例 spark_client = SparkApiClient(**SPARK_CONFIG) result = spark_client.call_api("分析这条评论的情感")
B. 混合策略实现
1 def analyze_sentiment_hybrid(comment_text): 2 """混合策略情感分析""" 3 # 1. 本地规则快速判断 4 local_result = local_sentiment_analysis(comment_text) 5 6 # 2. 有信心则直接使用,否则调用API 7 if local_result.get('confident', False): 8 return local_result 9 else: 10 # API分析 11 prompt = f"""请分析评论的情感和特征,返回JSON格式: 12 {{ 13 "sentiment": 数字(1-正面, 2-负面, 3-正负混合, 4-中性, 5-不相关), 14 "scenario": 数字(0或1, 是否提及使用场景), 15 "question": 数字(0或1, 是否包含疑问), 16 "suggestion": 数字(0或1, 是否包含建议) 17 }} 18 19 评论:{comment_text}""" 20 21 response = spark_client.call_api(prompt) 22 return parse_api_response(response, local_result) 23 24 def local_sentiment_analysis(text): 25 """本地规则分析""" 26 positive_words = ['好', '棒', '不错', '满意', '推荐', '喜欢'] 27 negative_words = ['差', '坏', '糟糕', '不好', '失望', '垃圾'] 28 29 pos_count = sum(1 for word in positive_words if word in text) 30 neg_count = sum(1 for word in negative_words if word in text) 31 32 # 判断置信度 33 confident = (pos_count > 0 or neg_count > 0 or 34 '?' in text or '建议' in text) 35 36 return { 37 'sentiment_category': 1 if pos_count > neg_count else 2 if neg_count > 0 else 4, 38 'user_scenario': 1 if any(w in text for w in ['出差', '工作', '会议']) else 0, 39 'user_question': 1 if '?' in text or '怎么' in text else 0, 40 'user_suggestion': 1 if '建议' in text or '应该' in text else 0, 41 'confident': confident 42 }
C. 断点续跑与进度备份
1 import pandas as pd 2 import os 3 4 def save_progress(df, filename): 5 """保存进度到本地""" 6 os.makedirs("progress", exist_ok=True) 7 filepath = f"progress/{filename}" 8 df.to_csv(filepath, index=False) 9 print(f"进度已保存: {filepath}") 10 11 def load_progress(filename): 12 """加载已保存的进度""" 13 filepath = f"progress/{filename}" 14 if os.path.exists(filepath): 15 print(f"发现进度文件: {filepath}") 16 return pd.read_csv(filepath) 17 return None 18 19 def process_comments_with_backup(): 20 """带备份的评论处理主流程""" 21 BATCH_SIZE = 50 # 每50条保存一次 22 23 # 1. 尝试加载已有进度 24 progress_data = load_progress("comments_progress.csv") 25 if progress_data is not None: 26 print("继续上次的处理进度...") 27 comments_data = progress_data 28 else: 29 print("开始新的处理...") 30 comments_data = pd.read_csv("origin_comments_data.csv") 31 # 初始化待填充的列 32 for col in ['sentiment_category', 'user_scenario', 'user_question', 'user_suggestion']: 33 if col not in comments_data.columns: 34 comments_data[col] = None 35 36 # 2. 批量处理+定期备份 37 api_calls = 0 38 local_calls = 0 39 40 for i, row in comments_data.iterrows(): 41 # 跳过已处理的数据 42 if not pd.isnull(row.get('sentiment_category')): 43 continue 44 45 comment_text = row['comment_text'] 46 47 # 分析评论 48 try: 49 result = analyze_sentiment_hybrid(comment_text) 50 if result.get('from_api', False): 51 api_calls += 1 52 else: 53 local_calls += 1 54 55 # 更新结果 56 for key, value in result.items(): 57 if key not in ['confident', 'from_api']: 58 comments_data.at[i, key] = value 59 60 except Exception as e: 61 print(f"处理第{i}条评论时出错: {e}") 62 # 使用默认值 63 comments_data.at[i, 'sentiment_category'] = 4 64 comments_data.at[i, 'user_scenario'] = 0 65 comments_data.at[i, 'user_question'] = 0 66 comments_data.at[i, 'user_suggestion'] = 0 67 68 # 定期保存进度 69 if (i + 1) % BATCH_SIZE == 0: 70 save_progress(comments_data, "comments_progress.csv") 71 print(f"已处理: {i+1}/{len(comments_data)} (API:{api_calls}, 本地:{local_calls})") 72 73 # 3. 最终保存 74 save_progress(comments_data, "comments_progress.csv") 75 print(f"处理完成! 总计 API:{api_calls}, 本地:{local_calls}") 76 77 return comments_data 78 79 # 容错版本的API调用 80 def safe_api_call(prompt, fallback_result): 81 """带容错的API调用""" 82 try: 83 response = spark_client.call_api(prompt) 84 if response and len(response.strip()) > 0: 85 # 尝试解析JSON 86 import re 87 json_match = re.search(r'\{.*\}', response, re.DOTALL) 88 if json_match: 89 import json 90 result = json.loads(json_match.group()) 91 result['from_api'] = True 92 return result 93 except Exception as e: 94 print(f"API调用失败,使用本地结果: {e}") 95 96 # API失败时返回本地结果 97 fallback_result['from_api'] = False 98 return fallback_result
6.2 改进方向与提升潜力
短期优化(+5-10分)
- 提示词工程:设计更精确的分析指令,包含上下文信息
- 多模型集成:结合多个API结果进行投票决策
- 错误处理:完善API限制和网络异常的处理机制
中期优化(+10-15分)
- 聚类算法:层次聚类 + LDA主题建模的组合方案
- 特征工程:根据商品类型设计专门的特征提取器
- 模型微调:利用星辰MaaS平台训练领域专用模型
长期突破(+15-30分)
- 多轮对话:与API进行深度交互,逐步细化分析结果
- 知识图谱:构建商品-场景-情感的知识关系网络
- Agent系统:自主选择最佳分析路径的智能代理
浙公网安备 33010602011771号