Datawhale AI夏令营-task2.2-优化思路探索

 

🎯 成绩提升总览

 
成绩对比:从Baseline的172分提升到Baseline3的192分,提升20分(11.6%)
 
核心优化:API增强 + 混合策略 + 稳定性保障
 

 

一、主要技术变化

 

1.1 技术架构升级

 
对比维度
Baseline
Baseline3
改进效果
主要方法
纯机器学习(TF-IDF + SGD)
API+ML混合策略
语义理解能力大幅提升
API调用
讯飞星火WebSocket API
处理复杂语义、隐喻、反讽
连接方式
-
WebSocket长连接
更稳定,避免频繁握手
容错机制
基础异常处理
完善的降级策略
确保100%处理成功率
进度管理
无断点续跑
每50条自动保存
支持中断恢复,降低风险
 

1.2 核心模块改进(如果表格看不懂再展开)

 

A. 商品识别升级

- Baseline:TF-IDF特征提取 + SGD分类器(max_features=50)
- Baseline3:API语义理解 + 规则兜底
- 效果:能理解"多语种沟通神器"→翻译器的语义关联
 

B. 情感分析革新

- Baseline:纯统计特征,无法处理复杂语义
- Baseline3:混合策略 - 简单情况本地处理,复杂情况API分析
- 效果:节省26.5%的API调用,同时提升复杂情感识别准确率
 

C. 聚类算法智能化

- Baseline:固定K=2聚类
- Baseline3:自适应选择最优聚类数(5-8范围)+ 轮廓系数评估
- 效果:符合赛题要求,聚类质量显著提升
 

 
 

二、WebSocket API技术优势(感兴趣可以展开)

 

2.1 连接稳定性对比

 
特性
HTTP API
WebSocket API
连接方式
短连接,每次请求建立连接
长连接,一次建立持续使用
握手开销
每次请求都有握手开销
仅初始建立时握手一次
网络稳定性
易受网络波动影响
连接保持,更稳定
错误恢复
需要重新请求
可检测连接状态自动重连
 

2.2 安全认证

  • 使用HMAC-SHA256加密签名
  • 时间戳防重放攻击
  • 完整的三要素认证(APPID + APIKey + APISecret)

 

三、工程化保障机制

 

3.1 多层容错策略

1. API调用容错:内容审核错误自动降级到本地处理
2. 网络超时处理:连接失败自动切换到备用方案
3. 内容审核降级:触发内容限制时使用传统ML方法
 

3.2 断点续跑机制

- 问题:处理6000+条评论需数小时,中断风险高
- 解决:每50条自动保存进度,支持从断点继续
- 效果:避免重复计算和资源浪费
 

 

四、性能与成本优化

 

4.1 实际运行数据

- 总处理:6,476条评论
- API调用:3,479次(53.7%)
- 本地处理:1,256次(26.5%)
- 成本节省:46.3%
 

4.2 效率对比

处理阶段
Baseline
Baseline3
提升效果
商品识别
TF-IDF训练+预测
API直接调用
提升准确率
情感分析
批量ML预测
智能混合策略
提升质量
聚类分析
固定参数
自适应优化
自动最优
总体时间
约10分钟
约5小时
质量换时间
 

 

六、改进方向与核心代码

 

6.1 讯飞API调用核心代码(按需取用)

 

A. WebSocket API客户端(最有价值,调用api是一大坑)

import websocket
import ssl
import json
import base64
import hashlib
import hmac
from urllib.parse import urlparse, urlencode
from datetime import datetime
from time import mktime
from wsgiref.handlers import format_date_time

# 配置信息
SPARK_CONFIG = {
    "appid": "your_appid",
    "api_secret": "your_api_secret", 
    "api_key": "your_api_key",
    "domain": "lite",
    "spark_url": "wss://spark-api.xf-yun.com/v1.1/chat"
}

class SparkApiClient:
    def __init__(self, appid, api_key, api_secret, spark_url, domain):
        self.appid = appid
        self.api_key = api_key
        self.api_secret = api_secret
        self.spark_url = spark_url
        self.domain = domain
        self.answer = ""
        
    def create_url(self):
        """生成认证URL"""
        host = urlparse(self.spark_url).netloc
        path = urlparse(self.spark_url).path
        
        now = datetime.now()
        date = format_date_time(mktime(now.timetuple()))
        
        signature_origin = f"host: {host}\ndate: {date}\nGET {path} HTTP/1.1"
        signature_sha = hmac.new(self.api_secret.encode('utf-8'), 
                                signature_origin.encode('utf-8'),
                                digestmod=hashlib.sha256).digest()
        signature_sha_base64 = base64.b64encode(signature_sha).decode('utf-8')
        
        authorization_origin = f'api_key="{self.api_key}", algorithm="hmac-sha256", headers="host date request-line", signature="{signature_sha_base64}"'
        authorization = base64.b64encode(authorization_origin.encode('utf-8')).decode('utf-8')
        
        v = {
            "authorization": authorization,
            "date": date,
            "host": host
        }
        
        return self.spark_url + '?' + urlencode(v)
        
    def call_api(self, question):
        """调用星火API"""
        self.answer = ""
        try:
            wsUrl = self.create_url()
            ws = websocket.WebSocketApp(wsUrl, 
                                      on_message=self.on_message, 
                                      on_error=self.on_error, 
                                      on_close=self.on_close, 
                                      on_open=self.on_open)
            ws.question = question
            ws.run_forever(sslopt={"cert_reqs": ssl.CERT_NONE})
            return self.answer.strip()
        except Exception as e:
            print(f"API调用失败: {e}")
            return ""
    
    def on_message(self, ws, message):
        """处理消息"""
        try:
            data = json.loads(message)
            code = data['header']['code']
            if code != 0:
                print(f'API错误: {code}')
                ws.close()
                return
            choices = data["payload"]["choices"]
            status = choices["status"]
            content = choices['text'][0].get('content', '')
            self.answer += content
            if status == 2:
                ws.close()
        except Exception as e:
            print(f"处理消息出错: {e}")
            
    def on_error(self, ws, error):
        print(f"WebSocket错误: {error}")
        
    def on_close(self, ws, close_status_code, close_msg):
        pass
        
    def on_open(self, ws):
        def run():
            data = json.dumps(self.gen_params(ws.question))
            ws.send(data)
        import _thread as thread
        thread.start_new_thread(run, ())
        
    def gen_params(self, question):
        """生成请求参数"""
        return {
            "header": {"app_id": self.appid, "uid": "1234"},
            "parameter": {
                "chat": {
                    "domain": self.domain,
                    "temperature": 0.1,
                    "max_tokens": 1024
                }
            },
            "payload": {
                "message": {
                    "text": [{"role": "user", "content": question}]
                }
            }
        }

# 使用示例
spark_client = SparkApiClient(**SPARK_CONFIG)
result = spark_client.call_api("分析这条评论的情感")
View Code

 

 

B. 混合策略实现

 1 def analyze_sentiment_hybrid(comment_text):
 2     """混合策略情感分析"""
 3     # 1. 本地规则快速判断
 4     local_result = local_sentiment_analysis(comment_text)
 5     
 6     # 2. 有信心则直接使用,否则调用API
 7     if local_result.get('confident', False):
 8         return local_result
 9     else:
10         # API分析
11         prompt = f"""请分析评论的情感和特征,返回JSON格式:
12 {{
13   "sentiment": 数字(1-正面, 2-负面, 3-正负混合, 4-中性, 5-不相关),
14   "scenario": 数字(0或1, 是否提及使用场景),
15   "question": 数字(0或1, 是否包含疑问),
16   "suggestion": 数字(0或1, 是否包含建议)
17 }}
18 
19 评论:{comment_text}"""
20         
21         response = spark_client.call_api(prompt)
22         return parse_api_response(response, local_result)
23 
24 def local_sentiment_analysis(text):
25     """本地规则分析"""
26     positive_words = ['', '', '不错', '满意', '推荐', '喜欢']
27     negative_words = ['', '', '糟糕', '不好', '失望', '垃圾']
28     
29     pos_count = sum(1 for word in positive_words if word in text)
30     neg_count = sum(1 for word in negative_words if word in text)
31     
32     # 判断置信度
33     confident = (pos_count > 0 or neg_count > 0 or 
34                 '' in text or '建议' in text)
35     
36     return {
37         'sentiment_category': 1 if pos_count > neg_count else 2 if neg_count > 0 else 4,
38         'user_scenario': 1 if any(w in text for w in ['出差', '工作', '会议']) else 0,
39         'user_question': 1 if '' in text or '怎么' in text else 0,
40         'user_suggestion': 1 if '建议' in text or '应该' in text else 0,
41         'confident': confident
42     }

 

 

C. 断点续跑与进度备份

 1 import pandas as pd
 2 import os
 3 
 4 def save_progress(df, filename):
 5     """保存进度到本地"""
 6     os.makedirs("progress", exist_ok=True)
 7     filepath = f"progress/{filename}"
 8     df.to_csv(filepath, index=False)
 9     print(f"进度已保存: {filepath}")
10 
11 def load_progress(filename):
12     """加载已保存的进度"""
13     filepath = f"progress/{filename}"
14     if os.path.exists(filepath):
15         print(f"发现进度文件: {filepath}")
16         return pd.read_csv(filepath)
17     return None
18 
19 def process_comments_with_backup():
20     """带备份的评论处理主流程"""
21     BATCH_SIZE = 50  # 每50条保存一次
22     
23     # 1. 尝试加载已有进度
24     progress_data = load_progress("comments_progress.csv")
25     if progress_data is not None:
26         print("继续上次的处理进度...")
27         comments_data = progress_data
28     else:
29         print("开始新的处理...")
30         comments_data = pd.read_csv("origin_comments_data.csv")
31         # 初始化待填充的列
32         for col in ['sentiment_category', 'user_scenario', 'user_question', 'user_suggestion']:
33             if col not in comments_data.columns:
34                 comments_data[col] = None
35     
36     # 2. 批量处理+定期备份
37     api_calls = 0
38     local_calls = 0
39     
40     for i, row in comments_data.iterrows():
41         # 跳过已处理的数据
42         if not pd.isnull(row.get('sentiment_category')):
43             continue
44             
45         comment_text = row['comment_text']
46         
47         # 分析评论
48         try:
49             result = analyze_sentiment_hybrid(comment_text)
50             if result.get('from_api', False):
51                 api_calls += 1
52             else:
53                 local_calls += 1
54             
55             # 更新结果
56             for key, value in result.items():
57                 if key not in ['confident', 'from_api']:
58                     comments_data.at[i, key] = value
59                     
60         except Exception as e:
61             print(f"处理第{i}条评论时出错: {e}")
62             # 使用默认值
63             comments_data.at[i, 'sentiment_category'] = 4
64             comments_data.at[i, 'user_scenario'] = 0
65             comments_data.at[i, 'user_question'] = 0
66             comments_data.at[i, 'user_suggestion'] = 0
67         
68         # 定期保存进度
69         if (i + 1) % BATCH_SIZE == 0:
70             save_progress(comments_data, "comments_progress.csv")
71             print(f"已处理: {i+1}/{len(comments_data)} (API:{api_calls}, 本地:{local_calls})")
72     
73     # 3. 最终保存
74     save_progress(comments_data, "comments_progress.csv")
75     print(f"处理完成! 总计 API:{api_calls}, 本地:{local_calls}")
76     
77     return comments_data
78 
79 # 容错版本的API调用
80 def safe_api_call(prompt, fallback_result):
81     """带容错的API调用"""
82     try:
83         response = spark_client.call_api(prompt)
84         if response and len(response.strip()) > 0:
85             # 尝试解析JSON
86             import re
87             json_match = re.search(r'\{.*\}', response, re.DOTALL)
88             if json_match:
89                 import json
90                 result = json.loads(json_match.group())
91                 result['from_api'] = True
92                 return result
93     except Exception as e:
94         print(f"API调用失败,使用本地结果: {e}")
95     
96     # API失败时返回本地结果
97     fallback_result['from_api'] = False
98     return fallback_result
View Code

 

 

6.2 改进方向与提升潜力

 

短期优化(+5-10分)

- 提示词工程:设计更精确的分析指令,包含上下文信息
- 多模型集成:结合多个API结果进行投票决策
- 错误处理:完善API限制和网络异常的处理机制
 

中期优化(+10-15分)

- 聚类算法:层次聚类 + LDA主题建模的组合方案
- 特征工程:根据商品类型设计专门的特征提取器
- 模型微调:利用星辰MaaS平台训练领域专用模型
 

长期突破(+15-30分)

- 多轮对话:与API进行深度交互,逐步细化分析结果
- 知识图谱:构建商品-场景-情感的知识关系网络
- Agent系统:自主选择最佳分析路径的智能代理
 
posted @ 2025-07-10 08:54  windiest  阅读(18)  评论(0)    收藏  举报