Neo4j + DeepSeek 简单的问答程序

图数据库构建

def init_sample_data():
    """Reset the graph and load a small demo medical knowledge graph.

    Deletes every node and relationship currently in the database, then creates
    Disease, Symptom, Drug and Patient nodes linked by HAS_SYMPTOM,
    TREATED_WITH and DIAGNOSED_WITH relationships.

    Returns:
        True if the sample data was written, False if any database call
        raised (the error is logged).
    """
    try:
        # Wipe existing data so repeated runs start from the same state.
        graph.run("MATCH (n) DETACH DELETE n")

        # Disease nodes
        diabetes = Node("Disease", name="糖尿病", symptoms="多饮、多尿、体重下降", treatment="胰岛素控制")
        hypertension = Node("Disease", name="高血压", symptoms="头痛、眩晕、心悸", treatment="降压药")

        # Symptom nodes
        polyuria = Node("Symptom", name="多尿", severity="中度")
        polydipsia = Node("Symptom", name="多饮", severity="中度")
        headache = Node("Symptom", name="头痛", severity="轻度")

        # Drug nodes
        insulin = Node("Drug", name="胰岛素", dosage="10单位/天", side_effects="低血糖")
        metformin = Node("Drug", name="二甲双胍", dosage="500mg/次", side_effects="胃肠道不适")
        lisinopril = Node("Drug", name="赖诺普利", dosage="10mg/天", side_effects="干咳")

        # Patient nodes
        patient1 = Node("Patient", id="P001", age=56, gender="男")
        patient2 = Node("Patient", id="P002", age=62, gender="女")

        # Union of all relationships. Every node above takes part in at least
        # one relationship, so the union also contains all ten nodes.
        sample_graph = (
            Relationship(diabetes, "HAS_SYMPTOM", polyuria)
            | Relationship(diabetes, "HAS_SYMPTOM", polydipsia)
            | Relationship(hypertension, "HAS_SYMPTOM", headache)
            | Relationship(diabetes, "TREATED_WITH", insulin)
            | Relationship(diabetes, "TREATED_WITH", metformin)
            | Relationship(hypertension, "TREATED_WITH", lisinopril)
            | Relationship(patient1, "DIAGNOSED_WITH", diabetes)
            | Relationship(patient2, "DIAGNOSED_WITH", hypertension)
        )

        # Create everything with one graph.create call (py2neo runs it inside a
        # single transaction). Previously there were 18 separate calls, and a
        # failure part-way left a partially built sample graph behind.
        graph.create(sample_graph)

        logger.info("示例数据初始化完成")
        return True
    except Exception as e:
        logger.error(f"数据初始化失败: {e}")
        return False


接入 DeepSeek

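在 DeepSeek 的 API 平台页面创建自己的 API key,然后参照其接口文档调用 DeepSeek。DeepSeek 的接口与 OpenAI SDK 兼容,下面直接使用 OpenAI 客户端并把 base_url 指向 DeepSeek。注意:API key 不要写进代码或随文章公开,应通过环境变量(如 DEEPSEEK_API_KEY)传入;已经公开的 key 需要立即作废。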

import os

# DeepSeek model used for every chat request.
DEEPSEEK_MODEL = 'deepseek-chat'
# The API key is read from the environment instead of being hard-coded:
# a key that appears in source code or a published article must be
# considered leaked and revoked.
client = OpenAI(api_key=os.getenv("DEEPSEEK_API_KEY", ""), base_url="https://api.deepseek.com")

def check_deepseek_status():
    """Probe the DeepSeek chat API with a tiny request.

    Returns:
        True if the request succeeded. False if it raised for any reason
        (network error, rejected API key, service outage). The failure is
        logged rather than propagated, so callers can fall back to manual
        Cypher mode.
    """
    try:
        # One short message is enough to prove connectivity and that the API
        # key is accepted; the reply itself is not needed.
        client.chat.completions.create(
            model=DEEPSEEK_MODEL,
            messages=[{"role": "user", "content": "Hello"}],
            stream=False,
        )
    except Exception as e:
        logger.warning(f"DeepSeek服务检查失败: {e}")
        return False
    return True


利用 DeepSeek 生成 Cypher 查询语句

# System prompt for generate_cypher(): describes the graph schema and the
# question -> Cypher mapping rules. Every rule shows a complete query ending in
# RETURN, so a rule copied as-is is still valid Cypher.
CYPHER_GENERATION_PROMPT = """
你是一个医疗知识图谱专家,可以将医学问题转换为Neo4j Cypher查询。
数据库Schema包含以下节点和关系:

节点标签:
- Disease (疾病): 属性包括 name, symptoms, treatment
- Symptom (症状): 属性包括 name, severity
- Drug (药物): 属性包括 name, dosage, side_effects
- Patient (患者): 属性包括 id, age, gender

关系类型:
- HAS_SYMPTOM (疾病 -> 症状)
- TREATED_WITH (疾病 -> 药物)
- DIAGNOSED_WITH (患者 -> 疾病)

转换规则:
1. 当询问症状时,使用 MATCH (d:Disease)-[:HAS_SYMPTOM]->(s:Symptom) RETURN s
2. 当询问治疗药物时,使用 MATCH (d:Disease)-[:TREATED_WITH]->(dr:Drug) RETURN dr
3. 当询问某药物治疗什么疾病时,使用 MATCH (d:Disease)-[:TREATED_WITH]->(dr:Drug) RETURN d
4. 患者查询使用 MATCH (p:Patient)-[:DIAGNOSED_WITH]->(d:Disease) RETURN p, d

你需要从用户输入中提取实体,例如糖尿病的症状是什么,你应该生成 MATCH (d:Disease{name:"糖尿病"})-[:HAS_SYMPTOM]->(s:Symptom) RETURN s
只能生成只读查询,不要使用 CREATE、MERGE、SET、DELETE、REMOVE 等修改数据的子句。
请只返回纯Cypher查询语句,不要包含任何解释或额外文本。
"""

def _extract_cypher(text):
    """Return the Cypher statement contained in a model reply.

    If the reply contains Markdown ``` fences, only text inside the fences is
    considered. The first fenced block mentioning MATCH or RETURN is chosen,
    falling back to the first fenced block. A leading language tag such as
    ``cypher`` is removed as a whole word. Text without fences is returned
    stripped.
    """
    text = text.strip()
    if "```" not in text:
        return text
    # After splitting on the fence marker, odd-indexed pieces are the insides of
    # fences; even-indexed pieces are surrounding prose, which must not be
    # picked even if it happens to contain the word "return".
    fenced = text.split("```")[1::2]
    chosen = next(
        (block for block in fenced if "match" in block.lower() or "return" in block.lower()),
        fenced[0],
    ).strip()
    # Drop only a leading tag word (e.g. "cypher\nMATCH ..."). Unlike a global
    # str.replace, this leaves occurrences of "cypher" inside the query alone.
    tag_and_rest = chosen.split(None, 1)
    if len(tag_and_rest) == 2 and tag_and_rest[0].lower() in ("cypher", "sql", "neo4j"):
        chosen = tag_and_rest[1]
    return chosen.strip()


def generate_cypher(user_query):
    """Translate a natural-language medical question into a Cypher query.

    Sends CYPHER_GENERATION_PROMPT as the system message and the question as
    the user message to the DeepSeek chat API. Any Markdown code fence is then
    stripped from the reply.

    Args:
        user_query: The user's question, e.g. "糖尿病有哪些症状?".

    Returns:
        The Cypher query string, or None if the API call failed (the error is
        logged) or the reply contained no query text.
    """
    try:
        response = client.chat.completions.create(
            model=DEEPSEEK_MODEL,
            messages=[
                {"role": "system", "content": CYPHER_GENERATION_PROMPT},
                {"role": "user", "content": user_query},
            ],
            stream=False,
        )
        # message.content may be None; treat that like an empty reply.
        cypher = _extract_cypher(response.choices[0].message.content or "")
        return cypher or None
    except Exception as e:
        logger.error(f"DeepSeek API调用失败: {str(e)}")
        return None

将查询结果交给 DeepSeek,生成便于用户理解的自然语言回答

# System prompt for explain_results(): asks the model to explain the
# knowledge-graph query results to the user in natural language, without
# inventing information absent from the results, in at most 150 characters.
EXPLANATION_PROMPT = """
你是一个专业的医疗助手,请根据知识图谱查询结果,用自然语言解释给用户。
回答要求:
1. 保持医疗专业性,使用准确医学术语
2. 不要编造任何未出现在结果中的信息
3. 如果结果为空,建议用户调整查询方式
4. 回答简洁明了,不超过150字
"""

def explain_results(user_query, results, max_chars=1500):
    """Ask DeepSeek to explain knowledge-graph query results in plain language.

    Args:
        user_query: The user's original question.
        results: Query results (e.g. the list of record dicts from a Cypher
            query); only their ``str()`` form is sent to the model.
        max_chars: Maximum number of characters of ``str(results)`` included
            in the request, to keep the prompt short. Defaults to 1500.

    Returns:
        The model's explanation, or a "解释生成失败: ..." message if the API
        call failed.
    """
    # Truncate the serialized results so large result sets don't bloat the prompt.
    results_text = str(results)[:max_chars]
    try:
        response = client.chat.completions.create(
            model=DEEPSEEK_MODEL,
            messages=[
                {"role": "system", "content": EXPLANATION_PROMPT},
                {"role": "user", "content": f"用户问题: {user_query}\n查询结果: {results_text}\n请生成解释:"},
            ],
            stream=False,
        )
        # message.content may be None; fall back to an empty explanation.
        return (response.choices[0].message.content or "").strip()
    except Exception as e:
        logger.error(f"DeepSeek API调用失败: {e}")
        return f"解释生成失败: {str(e)}"

完整的代码

import logging
import os
import re
import sys
import time

import openai
from dotenv import load_dotenv
from openai import OpenAI
from py2neo import Graph, Node, Relationship
from tabulate import tabulate

# Logging configuration
logging.basicConfig(level=logging.INFO, format='%(asctime)s - %(levelname)s - %(message)s')
logger = logging.getLogger(__name__)

# Load variables from a .env file (if present) into the environment, so the
# connection settings below can be changed without editing this file.
load_dotenv()

# Neo4j connection settings, overridable via environment variables. The
# defaults target a local Neo4j HTTP endpoint with the stock neo4j/neo4j login.
NEO4J_URI = os.getenv("NEO4J_URI", "http://localhost:7474")
NEO4J_USER = os.getenv("NEO4J_USER", "neo4j")
NEO4J_PASSWORD = os.getenv("NEO4J_PASSWORD", "neo4j")

try:
    graph = Graph(NEO4J_URI, auth=(NEO4J_USER, NEO4J_PASSWORD))
    logger.info("成功连接到 Neo4j 数据库")
except Exception as e:
    logger.error(f"Neo4j 连接失败: {str(e)}")
    # sys.exit instead of the site-provided exit(), which is unavailable
    # when Python runs without the site module (e.g. `python -S`).
    sys.exit(1)

# System prompt for generate_cypher(): describes the graph schema and the
# question -> Cypher mapping rules. Every rule shows a complete query ending in
# RETURN, so a rule copied as-is is still valid Cypher.
CYPHER_GENERATION_PROMPT = """
你是一个医疗知识图谱专家,可以将医学问题转换为Neo4j Cypher查询。
数据库Schema包含以下节点和关系:

节点标签:
- Disease (疾病): 属性包括 name, symptoms, treatment
- Symptom (症状): 属性包括 name, severity
- Drug (药物): 属性包括 name, dosage, side_effects
- Patient (患者): 属性包括 id, age, gender

关系类型:
- HAS_SYMPTOM (疾病 -> 症状)
- TREATED_WITH (疾病 -> 药物)
- DIAGNOSED_WITH (患者 -> 疾病)

转换规则:
1. 当询问症状时,使用 MATCH (d:Disease)-[:HAS_SYMPTOM]->(s:Symptom) RETURN s
2. 当询问治疗药物时,使用 MATCH (d:Disease)-[:TREATED_WITH]->(dr:Drug) RETURN dr
3. 当询问某药物治疗什么疾病时,使用 MATCH (d:Disease)-[:TREATED_WITH]->(dr:Drug) RETURN d
4. 患者查询使用 MATCH (p:Patient)-[:DIAGNOSED_WITH]->(d:Disease) RETURN p, d

你需要从用户输入中提取实体,例如糖尿病的症状是什么,你应该生成 MATCH (d:Disease{name:"糖尿病"})-[:HAS_SYMPTOM]->(s:Symptom) RETURN s
只能生成只读查询,不要使用 CREATE、MERGE、SET、DELETE、REMOVE 等修改数据的子句。
请只返回纯Cypher查询语句,不要包含任何解释或额外文本。
"""

# System prompt for explain_results(): asks the model to explain the
# knowledge-graph query results to the user in natural language, without
# inventing information absent from the results, in at most 150 characters.
EXPLANATION_PROMPT = """
你是一个专业的医疗助手,请根据知识图谱查询结果,用自然语言解释给用户。
回答要求:
1. 保持医疗专业性,使用准确医学术语
2. 不要编造任何未出现在结果中的信息
3. 如果结果为空,建议用户调整查询方式
4. 回答简洁明了,不超过150字
"""

# DeepSeek model used for every chat request.
DEEPSEEK_MODEL = 'deepseek-chat'

# The API key is read from the environment (or a .env file) instead of being
# hard-coded: a key that appears in source code or a published article must be
# considered leaked and revoked.
load_dotenv()
DEEPSEEK_API_KEY = os.getenv("DEEPSEEK_API_KEY", "")
if not DEEPSEEK_API_KEY:
    logger.warning("未设置环境变量 DEEPSEEK_API_KEY, DeepSeek 接口调用将会失败")

# DeepSeek is called through the OpenAI SDK pointed at DeepSeek's base URL.
client = OpenAI(api_key=DEEPSEEK_API_KEY, base_url="https://api.deepseek.com")

def check_deepseek_status():
    """Probe the DeepSeek chat API with a tiny request.

    Returns:
        True if the request succeeded. False if it raised for any reason
        (network error, rejected API key, service outage). The failure is
        logged rather than propagated, so callers can fall back to manual
        Cypher mode.
    """
    try:
        # One short message is enough to prove connectivity and that the API
        # key is accepted; the reply itself is not needed.
        client.chat.completions.create(
            model=DEEPSEEK_MODEL,
            messages=[{"role": "user", "content": "Hello"}],
            stream=False,
        )
    except Exception as e:
        logger.warning(f"DeepSeek服务检查失败: {e}")
        return False
    return True


def _extract_cypher(text):
    """Return the Cypher statement contained in a model reply.

    If the reply contains Markdown ``` fences, only text inside the fences is
    considered. The first fenced block mentioning MATCH or RETURN is chosen,
    falling back to the first fenced block. A leading language tag such as
    ``cypher`` is removed as a whole word. Text without fences is returned
    stripped.
    """
    text = text.strip()
    if "```" not in text:
        return text
    # After splitting on the fence marker, odd-indexed pieces are the insides of
    # fences; even-indexed pieces are surrounding prose, which must not be
    # picked even if it happens to contain the word "return".
    fenced = text.split("```")[1::2]
    chosen = next(
        (block for block in fenced if "match" in block.lower() or "return" in block.lower()),
        fenced[0],
    ).strip()
    # Drop only a leading tag word (e.g. "cypher\nMATCH ..."). Unlike a global
    # str.replace, this leaves occurrences of "cypher" inside the query alone.
    tag_and_rest = chosen.split(None, 1)
    if len(tag_and_rest) == 2 and tag_and_rest[0].lower() in ("cypher", "sql", "neo4j"):
        chosen = tag_and_rest[1]
    return chosen.strip()


def generate_cypher(user_query):
    """Translate a natural-language medical question into a Cypher query.

    Sends CYPHER_GENERATION_PROMPT as the system message and the question as
    the user message to the DeepSeek chat API. Any Markdown code fence is then
    stripped from the reply.

    Args:
        user_query: The user's question, e.g. "糖尿病有哪些症状?".

    Returns:
        The Cypher query string, or None if the API call failed (the error is
        logged) or the reply contained no query text.
    """
    try:
        response = client.chat.completions.create(
            model=DEEPSEEK_MODEL,
            messages=[
                {"role": "system", "content": CYPHER_GENERATION_PROMPT},
                {"role": "user", "content": user_query},
            ],
            stream=False,
        )
        # message.content may be None; treat that like an empty reply.
        cypher = _extract_cypher(response.choices[0].message.content or "")
        return cypher or None
    except Exception as e:
        logger.error(f"DeepSeek API调用失败: {str(e)}")
        return None


# Cypher clauses that modify data or schema. They are matched as whole words
# (case-insensitive), so identifiers or strings that merely contain them, such
# as "asset", "created_at" or "offset", do not get a read query rejected.
_WRITE_CLAUSE_PATTERN = re.compile(
    r"\b(create|delete|detach|merge|set|remove|drop|foreach)\b", re.IGNORECASE
)


def execute_cypher(cypher_query):
    """Run a read-only Cypher query through py2neo and return its rows.

    Args:
        cypher_query: The Cypher statement to execute.

    Returns:
        A list with one dict per result record (column name -> value) on
        success. Otherwise ``{"error": message}``, when the query is empty,
        contains a data-modifying clause, or fails to execute.
    """
    if not cypher_query or not cypher_query.strip():
        return {"error": "Cypher查询为空"}

    # Keyword screen only: best-effort protection against LLM-generated writes.
    # It is not a substitute for connecting with a read-only database user.
    if _WRITE_CLAUSE_PATTERN.search(cypher_query):
        return {"error": "只允许执行查询操作,禁止修改数据"}

    try:
        result = graph.run(cypher_query)
        return [dict(record) for record in result]
    except Exception as e:
        return {"error": str(e)}


def explain_results(user_query, results, max_chars=1500):
    """Ask DeepSeek to explain knowledge-graph query results in plain language.

    Args:
        user_query: The user's original question.
        results: Query results (typically the list of record dicts returned by
            execute_cypher); only their ``str()`` form is sent to the model.
        max_chars: Maximum number of characters of ``str(results)`` included
            in the request, to keep the prompt short. Defaults to 1500.

    Returns:
        The model's explanation, or a "解释生成失败: ..." message if the API
        call failed.
    """
    # Truncate the serialized results so large result sets don't bloat the prompt.
    results_text = str(results)[:max_chars]
    try:
        response = client.chat.completions.create(
            model=DEEPSEEK_MODEL,
            messages=[
                {"role": "system", "content": EXPLANATION_PROMPT},
                {"role": "user", "content": f"用户问题: {user_query}\n查询结果: {results_text}\n请生成解释:"},
            ],
            stream=False,
        )
        # message.content may be None; fall back to an empty explanation.
        return (response.choices[0].message.content or "").strip()
    except Exception as e:
        logger.error(f"DeepSeek API调用失败: {e}")
        return f"解释生成失败: {str(e)}"


def init_sample_data():
    """Reset the graph and load a small demo medical knowledge graph.

    Deletes every node and relationship currently in the database, then creates
    Disease, Symptom, Drug and Patient nodes linked by HAS_SYMPTOM,
    TREATED_WITH and DIAGNOSED_WITH relationships.

    Returns:
        True if the sample data was written, False if any database call
        raised (the error is logged).
    """
    try:
        # Wipe existing data so repeated runs start from the same state.
        graph.run("MATCH (n) DETACH DELETE n")

        # Disease nodes
        diabetes = Node("Disease", name="糖尿病", symptoms="多饮、多尿、体重下降", treatment="胰岛素控制")
        hypertension = Node("Disease", name="高血压", symptoms="头痛、眩晕、心悸", treatment="降压药")

        # Symptom nodes
        polyuria = Node("Symptom", name="多尿", severity="中度")
        polydipsia = Node("Symptom", name="多饮", severity="中度")
        headache = Node("Symptom", name="头痛", severity="轻度")

        # Drug nodes
        insulin = Node("Drug", name="胰岛素", dosage="10单位/天", side_effects="低血糖")
        metformin = Node("Drug", name="二甲双胍", dosage="500mg/次", side_effects="胃肠道不适")
        lisinopril = Node("Drug", name="赖诺普利", dosage="10mg/天", side_effects="干咳")

        # Patient nodes
        patient1 = Node("Patient", id="P001", age=56, gender="男")
        patient2 = Node("Patient", id="P002", age=62, gender="女")

        # Union of all relationships. Every node above takes part in at least
        # one relationship, so the union also contains all ten nodes.
        sample_graph = (
            Relationship(diabetes, "HAS_SYMPTOM", polyuria)
            | Relationship(diabetes, "HAS_SYMPTOM", polydipsia)
            | Relationship(hypertension, "HAS_SYMPTOM", headache)
            | Relationship(diabetes, "TREATED_WITH", insulin)
            | Relationship(diabetes, "TREATED_WITH", metformin)
            | Relationship(hypertension, "TREATED_WITH", lisinopril)
            | Relationship(patient1, "DIAGNOSED_WITH", diabetes)
            | Relationship(patient2, "DIAGNOSED_WITH", hypertension)
        )

        # Create everything with one graph.create call (py2neo runs it inside a
        # single transaction). Previously there were 18 separate calls, and a
        # failure part-way left a partially built sample graph behind.
        graph.create(sample_graph)

        logger.info("示例数据初始化完成")
        return True
    except Exception as e:
        logger.error(f"数据初始化失败: {e}")
        return False


def print_query_history(history, limit=3):
    """Print the most recent entries of the query history.

    Args:
        history: List of history dicts, newest first. Each has a "query" key
            and may have "status", "cypher", "explanation" and "error" keys.
        limit: How many entries to show. Defaults to 3.

    An entry is shown as 成功 only when its status is "completed" and it has no
    error. Optional fields are printed only when they hold a value, so an
    explicit ``"error": None`` does not print "错误: None".
    """
    if not history:
        print("暂无查询历史")
        return

    print("\n" + "=" * 50)
    print("最近查询记录:")
    for idx, entry in enumerate(history[:limit], 1):
        error = entry.get("error")
        succeeded = entry.get("status") == "completed" and not error
        print(f"\n查询 {idx}: {entry['query']}")
        print(f"状态: {'成功' if succeeded else '失败'}")
        if entry.get("cypher"):
            print(f"Cypher 查询: {entry['cypher']}")
        if entry.get("explanation"):
            print(f"解释: {entry['explanation']}")
        if error:
            print(f"错误: {error}")
    print("=" * 50 + "\n")


def print_table(results):
    """Print query rows as a grid table.

    Columns are the sorted union of keys across *all* records, but only the
    first 20 records are rendered. Cells are stringified, and a key missing from
    a record shows as ''. Nothing is printed for an empty or None result.
    """
    if not results:
        return

    columns = sorted({key for record in results for key in record.keys()})
    shown = results[:20]  # cap the printed rows at 20
    table_rows = [[str(record.get(column, '')) for column in columns] for record in shown]
    print(tabulate(table_rows, headers=columns, tablefmt="grid", maxcolwidths=30))


def clear_screen():
    """Clear the console: runs `cls` on Windows (os.name == 'nt') and `clear` elsewhere."""
    command = 'cls' if os.name == 'nt' else 'clear'
    os.system(command)


# Maximum number of entries kept in the in-memory query history.
_HISTORY_LIMIT = 10


def _print_help():
    """Print the supported console commands and a few example questions."""
    print("\n可用命令:")
    print("  初始化数据  清空数据库并重新写入示例数据")
    print("  历史        显示最近的查询记录")
    print("  帮助        显示本帮助信息")
    print("  清屏        清除屏幕")
    print("  退出        结束程序(也可输入 exit 或 quit)")
    print("以 'match' 或 'return' 开头的输入将作为 Cypher 查询直接执行")
    print("示例问题: '糖尿病有哪些症状?' '胰岛素治疗什么疾病?'")


def _remember(history, entry):
    """Insert *entry* at the front of *history* and trim it in place to _HISTORY_LIMIT."""
    history.insert(0, entry)
    del history[_HISTORY_LIMIT:]


def _run_manual_cypher(user_input, history):
    """Execute a Cypher query typed by the user, print the outcome and record it."""
    print("\n直接执行Cypher查询...")
    results = execute_cypher(user_input)
    # execute_cypher returns a list of rows on success or {"error": ...} on failure.
    failed = isinstance(results, dict)

    if failed:
        print(f"错误: {results.get('error')}")
    elif not results:
        print("查询成功但没有返回结果")
    else:
        print(f"返回了 {len(results)} 条结果")
        print_table(results)

    entry = {
        "query": user_input,
        "cypher": "手动输入",
        "results": None if failed else results,
        "status": "failed" if failed else "completed",
    }
    if failed:
        entry["error"] = results.get("error")
    _remember(history, entry)


def _answer_question(user_input, history, deepseek_available):
    """Answer a natural-language question: generate Cypher, run it, explain the rows.

    Progress and errors are printed. The history entry is updated in place with
    the generated Cypher, results, explanation and final status.
    """
    current_query = {"query": user_input, "status": "processing"}
    _remember(history, current_query)

    if not deepseek_available:
        print("DeepSeek服务不可用,无法处理自然语言查询")
        print("请输入Cypher查询(以'match'或'return'开头)")
        current_query["error"] = "DeepSeek服务不可用"
        current_query["status"] = "failed"
        return

    # Step 1: natural language -> Cypher
    print("\n[步骤1] 使用DeepSeek模型生成Cypher查询...")
    cypher = generate_cypher(user_input)
    if not cypher:
        print("错误: 未能生成有效的Cypher查询")
        current_query["status"] = "failed"
        current_query["error"] = "无法生成Cypher"
        return

    print(f"生成的Cypher查询: {cypher}")
    current_query["cypher"] = cypher

    # Step 2: run the query against Neo4j
    print("\n[步骤2] 执行Neo4j查询...")
    results = execute_cypher(cypher)
    if isinstance(results, dict):
        error = results.get("error")
        print(f"错误: {error}")
        current_query["status"] = "failed"
        current_query["error"] = error
        return

    if not results:
        print("查询成功但没有返回结果")
    else:
        print(f"返回了 {len(results)} 条结果")
    current_query["results"] = results

    # Step 3: explain non-empty results in natural language
    if results:
        print("\n[步骤3] 使用DeepSeek模型生成解释...")
        explanation = explain_results(user_input, results)
        print("\n知识图谱回答:")
        print("-" * 60)
        print(explanation)
        print("-" * 60)

        print("\n查询结果表格:")
        print_table(results)

        current_query["explanation"] = explanation

    current_query["status"] = "completed"


def main():
    """Run the interactive medical knowledge-graph question-answering console.

    Checks DeepSeek availability and loads the sample graph. It then loops,
    reading one of:
    - a command (退出/exit/quit, 历史, 帮助, 初始化数据, 清屏);
    - a raw Cypher query (input starting with "match " or "return ");
    - a natural-language question.
    The loop ends when the user quits or stdin is closed.
    """
    clear_screen()

    # Banner
    print("=" * 60)
    print("医疗知识图谱问答系统")
    print("=" * 60)

    print("\n检查DeepSeek服务状态...")
    try:
        deepseek_status = check_deepseek_status()
    except Exception as e:
        # The probe makes a network call; a failure means "unavailable",
        # not a reason to abort the program before manual mode is offered.
        logger.error(f"DeepSeek服务检查失败: {e}")
        deepseek_status = False

    if deepseek_status:
        logger.info("DeepSeek服务可用")
    else:
        logger.warning("DeepSeek服务不可用,将无法生成Cypher查询和解释结果")
        logger.info("您仍然可以手动输入Cypher查询或使用内置的示例数据")

    print("\n正在初始化示例数据...")
    if init_sample_data():
        print("示例数据初始化完成")
    else:
        print("示例数据初始化失败,请检查日志")

    query_history = []

    # Usage hints
    print("\n输入医学问题或命令(输入'退出'结束程序)")
    print("示例命令: [初始化数据] [历史] [帮助] [清屏]")
    print("示例问题: '糖尿病有哪些症状?' '胰岛素治疗什么疾病?'")

    while True:
        try:
            user_input = input("\n> 请输入查询: ").strip()
            if not user_input:
                continue

            if user_input.lower() in ['退出', 'exit', 'quit']:
                print("程序结束")
                break

            if user_input == '历史':
                print_query_history(query_history)
            elif user_input == '帮助':
                _print_help()
            elif user_input == '初始化数据':
                print("正在初始化示例数据...")
                if init_sample_data():
                    # Reset the history along with the data.
                    query_history.clear()
                    print("示例数据初始化完成")
                else:
                    print("示例数据初始化失败,请检查日志")
            elif user_input == '清屏':
                clear_screen()
            elif user_input.lower().startswith(("match ", "return ")):
                _run_manual_cypher(user_input, query_history)
            else:
                _answer_question(user_input, query_history, deepseek_status)

        except KeyboardInterrupt:
            print("\n操作已取消,可以输入新查询")
        except EOFError:
            # stdin was closed (Ctrl-D / Ctrl-Z, or piped input ran out). input()
            # would raise again on every iteration, so stop instead of looping.
            print("\n程序结束")
            break
        except Exception as e:
            logger.error(f"处理过程中出错: {str(e)}")
            print("错误发生,请重试或检查日志")


def _run_cli():
    """Start the console app; a Ctrl-C that escapes main() exits with status 0."""
    try:
        main()
    except KeyboardInterrupt:
        print("\n程序已终止")
        sys.exit(0)


if __name__ == "__main__":
    _run_cli()
posted @ 2025-06-06 15:53  狐狸胡兔  阅读(178)  评论(0)    收藏  举报