from flask import Flask, request, jsonify
from flask_cors import CORS  # allow cross-origin requests
from transformers import AutoTokenizer, AutoModelForSequenceClassification
import torch
import jieba
import requests
import base64
import re

app = Flask(__name__)
CORS(app)  # enable CORS
# Initialize the fake-news detection model
tokenizer = AutoTokenizer.from_pretrained("hamzab/roberta-fake-news-classification")
model = AutoModelForSequenceClassification.from_pretrained("hamzab/roberta-fake-news-classification")

# Baidu OCR API credentials (fill in before running)
API_KEY = ""
SECRET_KEY = ""
def get_access_token():
    """Fetch an access_token for the Baidu OCR API."""
    token_url = f"https://aip.baidubce.com/oauth/2.0/token?grant_type=client_credentials&client_id={API_KEY}&client_secret={SECRET_KEY}"
    response = requests.post(token_url)
    token_data = response.json()
    return token_data.get("access_token")
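# The token endpoint responds with JSON along the lines of
# {"access_token": "24.xxxx...", "expires_in": 2592000, ...} (shape per Baidu's
# OAuth docs; the values here are placeholders), so .get("access_token")
# returns None on failure instead of raising a KeyError.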
@app.route("/ocr", methods=["POST"])
def ocr_image():
    """Extract text from an uploaded image via Baidu OCR."""
    if "image" not in request.files:
        return jsonify({"error": "Please upload an image"}), 400
    image_file = request.files["image"]
    image_data = image_file.read()
    image_base64 = base64.b64encode(image_data).decode()
    access_token = get_access_token()
    if not access_token:
        return jsonify({"error": "Failed to obtain access_token"}), 500
    ocr_url = f"https://aip.baidubce.com/rest/2.0/ocr/v1/general_basic?access_token={access_token}"
    headers = {"Content-Type": "application/x-www-form-urlencoded"}
    payload = {"image": image_base64}
    response = requests.post(ocr_url, headers=headers, data=payload)
    return jsonify(response.json())
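# Example request against a local dev server (Flask's default port 5000
# assumed):
#   curl -X POST -F "image=@sample.jpg" http://127.0.0.1:5000/ocr
# The Baidu response is passed through as-is, typically of the form
#   {"words_result": [{"words": "..."}], "words_result_num": 1, ...}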
def predict_fake(title, text):
    """Classify a (title, text) pair as fake or real news."""
    input_str = "<title>" + title + "<content>" + text + "<end>"
    inputs = tokenizer(input_str, max_length=512, padding="max_length", truncation=True, return_tensors="pt")
    device = 'cuda' if torch.cuda.is_available() else 'cpu'
    model.to(device)
    with torch.no_grad():
        output = model(inputs["input_ids"].to(device), attention_mask=inputs["attention_mask"].to(device))
    # Convert the two logits into class probabilities
    probs = torch.softmax(output.logits, dim=1)[0]
    return dict(zip(["Fake", "Real"], [p.item() for p in probs]))
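# Quick sanity check (values are illustrative only; actual scores depend on
# the model weights and the inputs):
#   predict_fake("NASA confirms water on Mars", "Scientists announced ...")
#   -> {'Fake': 0.02, 'Real': 0.98}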
@app.route('/predict', methods=['POST'])
def predict():
    """Handle fake-news detection requests."""
    data = request.get_json(silent=True) or {}
    title = data.get('title', '')
    content = data.get('content', '')
    if not title and not content:
        return jsonify({"error": "No title or content provided"}), 400
    result = predict_fake(title, content)
    return jsonify(result)
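# Example request (local dev server assumed):
#   curl -X POST -H "Content-Type: application/json" \
#        -d '{"title": "Some headline", "content": "Some body text"}' \
#        http://127.0.0.1:5000/predict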
# Stopword lists (English matching is case-insensitive in filter_stopwords)
CHINESE_STOPWORDS = {'的', '了', '在', '是', '我', '有', '和', '就', '不', '人', '都', '一个', '说', '也', '上', '要'}
ENGLISH_STOPWORDS = {'the', 'and', 'for', 'to', 'of', 'a', 'an', 'in', 'on', 'at', 'by', 'with', 'about', 'as',
                     'this', 'that', 'is', 'we', 'all', 'who', 'these', 'have', 'our', 'i', 'been', 'his', 'was', 'are'}
def is_chinese(text):
    """Return True if the text contains any Chinese characters."""
    return bool(re.search(r'[\u4e00-\u9fff]', text))

def filter_stopwords(words, lang='zh'):
    """Filter out stopwords; English comparison is lowercased."""
    stopwords = CHINESE_STOPWORDS if lang == 'zh' else ENGLISH_STOPWORDS
    return [word for word in words if word.lower() not in stopwords]
def generate_wordcloud_data(text):
    """Build word-frequency data for the word cloud."""
    words = jieba.cut(text) if is_chinese(text) else text.split()
    words = filter_stopwords(words, lang='zh' if is_chinese(text) else 'en')
    word_freq = {}
    for word in words:
        if word.strip():
            word_freq[word] = word_freq.get(word, 0) + 1
    # Keep only the 30 most frequent words
    return dict(sorted(word_freq.items(), key=lambda x: x[1], reverse=True)[:30])
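# Illustrative call (English branch; stopwords removed, counts in descending
# order):
#   generate_wordcloud_data("the cat sat on the mat with the cat")
#   -> {'cat': 2, 'sat': 1, 'mat': 1}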
@app.route('/generate_wordcloud', methods=['POST'])
def generate_wordcloud():
    """Generate word-cloud frequency data."""
    data = request.get_json(silent=True) or {}
    text = data.get('text', '')
    if not text:
        return jsonify({"error": "No text provided"}), 400
    word_freq = generate_wordcloud_data(text)
    return jsonify({"word_freq": word_freq})
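# Example request and illustrative response (local dev server assumed):
#   curl -X POST -H "Content-Type: application/json" \
#        -d '{"text": "the cat sat on the mat with the cat"}' \
#        http://127.0.0.1:5000/generate_wordcloud
#   -> {"word_freq": {"cat": 2, "mat": 1, "sat": 1}}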
if __name__ == '__main__':
    app.run(debug=True)