from flask import Flask, request, jsonify
from flask_cors import CORS  # 允许跨域请求
from transformers import AutoTokenizer, AutoModelForSequenceClassification
import torch
import jieba
import requests
import base64
import re
# Create the Flask application object that the route decorators below attach to.
app = Flask(__name__)
CORS(app)  # enable CORS (cross-origin requests) for the app
# Load the fake-news detection tokenizer and model once, at import time.
tokenizer = AutoTokenizer.from_pretrained("hamzab/roberta-fake-news-classification")
model = AutoModelForSequenceClassification.from_pretrained("hamzab/roberta-fake-news-classification")
# Baidu OCR API configuration.
# NOTE(review): both values are empty here; presumably they must be filled in
# before get_access_token() can obtain a token — confirm with deployment setup.
API_KEY = ""
SECRET_KEY = ""
def get_access_token():
    """Fetch an access token for the Baidu OCR API.

    Sends a client-credentials request built from the module-level
    ``API_KEY`` and ``SECRET_KEY``.

    Returns:
        The ``access_token`` value from the JSON response, or ``None`` when
        the request fails, the response body is not a JSON object, or the
        response carries no ``access_token`` field.
    """
    token_url = "https://aip.baidubce.com/oauth/2.0/token"
    # Let requests URL-encode the credentials instead of splicing them into
    # the URL by hand.
    query = {
        "grant_type": "client_credentials",
        "client_id": API_KEY,
        "client_secret": SECRET_KEY,
    }
    try:
        # A timeout keeps a stalled upstream from blocking the worker forever.
        response = requests.post(token_url, params=query, timeout=10)
        token_data = response.json()
    except (requests.RequestException, ValueError):
        # Network failure or non-JSON body: report "no token" so the caller's
        # falsy check turns it into an error response.
        return None
    if not isinstance(token_data, dict):
        return None
    return token_data.get("access_token")
@app.route("/ocr", methods=["POST"])
def ocr_image():
    """Extract text from an uploaded image via the Baidu general OCR API.

    Expects a multipart upload with the file under the ``image`` field.

    Returns:
        The OCR service's JSON response on success; otherwise a JSON
        ``{"error": ...}`` body with status 400 (missing or empty upload),
        500 (no access token could be obtained) or 502 (the OCR request
        failed or returned a non-JSON body).
    """
    if "image" not in request.files:
        return jsonify({"error": "请上传图片"}), 400
    image_data = request.files["image"].read()
    if not image_data:
        # An empty upload cannot be recognised; reject it before calling out.
        return jsonify({"error": "请上传图片"}), 400
    access_token = get_access_token()
    if not access_token:
        return jsonify({"error": "获取 access_token 失败"}), 500
    ocr_url = "https://aip.baidubce.com/rest/2.0/ocr/v1/general_basic"
    headers = {"Content-Type": "application/x-www-form-urlencoded"}
    # The image travels as a base64 string in a form-encoded body.
    payload = {"image": base64.b64encode(image_data).decode()}
    try:
        response = requests.post(
            ocr_url,
            params={"access_token": access_token},
            headers=headers,
            data=payload,
            timeout=30,  # do not let a stalled upstream hang the request
        )
        result = response.json()
    except (requests.RequestException, ValueError):
        # Network error or non-JSON body from the OCR service.
        return jsonify({"error": "OCR 请求失败"}), 502
    # Use jsonify like the other routes rather than returning a raw dict.
    return jsonify(result)
def predict_fake(title, text):
    """Score a news item as fake or real with the module-level model.

    Args:
        title: Headline string of the news item.
        text: Body string of the news item.

    Returns:
        A dict with keys ``"Fake"`` and ``"Real"`` mapped, in that order, to
        the softmax probabilities of the first row of the model's logits.
    """
    # Title and body are wrapped in these marker strings before tokenization.
    input_str = "<title>" + title + "<content>" + text + "<end>"
    # Inputs are padded/truncated to exactly 512 tokens.
    encoded = tokenizer.encode_plus(
        input_str,
        max_length=512,
        padding="max_length",
        truncation=True,
        return_tensors="pt",
    )
    device = 'cuda' if torch.cuda.is_available() else 'cpu'
    model.to(device)
    with torch.no_grad():
        output = model(
            encoded["input_ids"].to(device),
            attention_mask=encoded["attention_mask"].to(device),
        )
    # Name the class dimension explicitly: Softmax() without `dim` relies on a
    # deprecated implicit choice and emits a warning on every call.
    probabilities = torch.softmax(output.logits, dim=-1)[0].tolist()
    return dict(zip(["Fake", "Real"], probabilities))
@app.route('/predict', methods=['POST'])
def predict():
    """Run fake-news detection on a JSON request.

    Expects a JSON object with string fields ``title`` and ``content``.

    Returns:
        The JSON-encoded result of ``predict_fake``, or a JSON
        ``{"error": ...}`` body with status 400 when the body is not a JSON
        object or either field is missing or not a string.
    """
    # silent=True yields None for a missing or malformed JSON body, so every
    # bad-input case is answered by the 400 responses below.
    data = request.get_json(silent=True)
    if not isinstance(data, dict):
        return jsonify({"error": "Request body must be a JSON object"}), 400
    title = data.get('title')
    content = data.get('content')
    if not isinstance(title, str) or not isinstance(content, str):
        return jsonify({"error": "Both 'title' and 'content' must be strings"}), 400
    result = predict_fake(title, content)
    return jsonify(result)
# Stop-word sets consulted by filter_stopwords().
CHINESE_STOPWORDS = {
    '的', '了', '在', '是', '我', '有', '和', '就',
    '不', '人', '都', '一个', '说', '也', '上', '要',
}
# Membership is exact-match; note that 'I' is the only capitalised entry.
ENGLISH_STOPWORDS = {
    'the', 'and', 'for', 'to', 'of', 'a', 'an',
    'in', 'on', 'at', 'by', 'with', 'about', 'as',
    'this', 'that', 'is', 'we', 'all', 'who', 'these',
    'have', 'our', 'I', 'been', 'his', 'was', 'are',
}
def is_chinese(text):
    """Report whether *text* has at least one character in U+4E00..U+9FFF."""
    cjk_match = re.search('[\u4e00-\u9fff]', text)
    return cjk_match is not None
def filter_stopwords(words, lang='zh'):
    """Drop stop words from an iterable of tokens.

    Args:
        words: Iterable of token strings.
        lang: ``'zh'`` selects ``CHINESE_STOPWORDS``; any other value
            selects ``ENGLISH_STOPWORDS``.

    Returns:
        A list of the tokens that are not stop words, in their original
        order and with their original casing.
    """
    if lang == 'zh':
        return [word for word in words if word not in CHINESE_STOPWORDS]
    # English tokens are matched case-insensitively so capitalised forms such
    # as "The" are filtered too. The stop-word set itself mixes cases (it
    # contains 'I'), so both sides are lowercased before comparing.
    lowered_stopwords = {stopword.lower() for stopword in ENGLISH_STOPWORDS}
    return [word for word in words if word.lower() not in lowered_stopwords]
def generate_wordcloud_data(text):
    """Build word-frequency data for a word cloud.

    Text containing Chinese characters is segmented with jieba; any other
    text is split on whitespace. Stop words and whitespace-only tokens are
    discarded before counting.

    Returns:
        A dict of at most 30 ``token -> count`` entries ordered by
        descending count; tokens with equal counts keep first-seen order.
    """
    lang = 'zh' if is_chinese(text) else 'en'
    tokens = jieba.cut(text) if lang == 'zh' else text.split()
    counts = {}
    for token in filter_stopwords(tokens, lang=lang):
        if not token.strip():
            # Skip tokens that consist only of whitespace.
            continue
        counts[token] = counts.get(token, 0) + 1
    ranked = sorted(counts.items(), key=lambda item: item[1], reverse=True)
    return dict(ranked[:30])
@app.route('/generate_wordcloud', methods=['POST'])
def generate_wordcloud():
    """Return word-frequency data for the text in a JSON request.

    Expects a JSON object with a non-empty string field ``text``.

    Returns:
        JSON ``{"word_freq": ...}`` holding the result of
        ``generate_wordcloud_data``, or JSON ``{"error": ...}`` with status
        400 when the body is not a JSON object or ``text`` is missing,
        empty, or not a string.
    """
    # silent=True yields None for a missing or malformed JSON body, so it
    # falls into the same 400 response as a missing field.
    data = request.get_json(silent=True)
    text = data.get('text', '') if isinstance(data, dict) else ''
    if not isinstance(text, str) or not text:
        return jsonify({"error": "No text provided"}), 400
    word_freq = generate_wordcloud_data(text)
    return jsonify({"word_freq": word_freq})
if __name__ == '__main__':
    import os

    # Debug mode turns on the interactive Werkzeug debugger, which can run
    # arbitrary code from the browser. Keep it opt-in through FLASK_DEBUG
    # instead of hard-coding it on.
    debug_enabled = os.environ.get("FLASK_DEBUG", "").strip().lower() in {"1", "true", "yes", "on"}
    app.run(debug=debug_enabled)