使用rag解决llm问题

思路:

新建conda环境

# 新建3.11环境
conda create --name rag_geo python=3.11
# 激活conda环境
conda activate rag_geo
# 安装依赖
pip install -r requirements.txt
# 依赖包
aim==3.19.3
aim-ui==3.19.3
aimrecords==0.0.7
aimrocks==0.4.0
aiofiles==23.2.1
aiohttp==3.9.5
aiosignal==1.3.1
alembic==1.13.1
annotated-types==0.6.0
anyio==4.3.0
argon2-cffi==23.1.0
argon2-cffi-bindings==21.2.0
async-timeout==4.0.3
asyncpg==0.29.0
attrs==23.2.0
azure-core==1.30.1
azure-storage-blob==12.19.1
base58==2.0.1
BCEmbedding==0.1.4
beautifulsoup4==4.12.3
blinker==1.7.0
boto3==1.34.89
botocore==1.34.89
cachetools==5.3.3
certifi==2024.2.2
cffi==1.16.0
charset-normalizer==3.3.2
click==8.1.7
coloredlogs==15.0.1
cryptography==42.0.5
dashscope==1.17.1
dataclasses-json==0.6.4
datasets==2.19.0
Deprecated==1.2.14
dill==0.3.8
dirtyjson==1.0.8
distro==1.9.0
docx2txt==0.8
edge-tts==6.1.11
environs==9.5.0
et-xmlfile==1.1.0
fastapi==0.110.2
filelock==3.13.4
Flask==3.0.3
flatbuffers==24.3.25
frozenlist==1.4.1
fsspec==2024.3.1
greenlet==3.0.3
grpcio==1.60.0
h11==0.14.0
httpcore==1.0.5
httpx==0.27.0
huggingface-hub==0.22.2
humanfriendly==10.0
idna==3.7
InstructorEmbedding==1.0.1
isodate==0.6.1
itsdangerous==2.2.0
Jinja2==3.1.3
jmespath==1.0.1
joblib==1.4.0
jsonpatch==1.33
jsonpointer==2.4
jwt==1.3.1
langchain==0.2.1
langchain-core==0.2.3
langchain-milvus==0.1.0
langchain-text-splitters==0.2.0
langsmith==0.1.67
lark==1.1.9
llama-index==0.10.40
llama-index-agent-openai==0.2.2
llama-index-callbacks-aim==0.1.2
llama-index-cli==0.1.12
llama-index-core==0.10.40
llama-index-embeddings-huggingface==0.2.0
llama-index-embeddings-instructor==0.1.3
llama-index-embeddings-openai==0.1.8
llama-index-indices-managed-llama-cloud==0.1.5
llama-index-legacy==0.9.48
llama-index-llms-openai==0.1.16
llama-index-multi-modal-llms-openai==0.1.5
llama-index-program-openai==0.1.5
llama-index-question-gen-openai==0.1.3
llama-index-readers-file==0.1.19
llama-index-readers-llama-parse==0.1.4
llama-index-vector-stores-milvus==0.1.10
llama-index-vector-stores-postgres==0.1.5
llama-parse==0.4.1
llamaindex-py-client==0.1.18
Mako==1.3.3
MarkupSafe==2.1.5
marshmallow==3.21.1
milvus-lite==2.4.6
minijinja==1.0.16
minio==7.2.5
mpmath==1.3.0
multidict==6.0.5
multiprocess==0.70.16
mypy-extensions==1.0.0
mysql-connector==2.2.9
nest-asyncio==1.6.0
networkx==3.3
nltk==3.8.1
numpy==1.26.4
nvidia-cublas-cu12==12.1.3.1
nvidia-cuda-cupti-cu12==12.1.105
nvidia-cuda-nvrtc-cu12==12.1.105
nvidia-cuda-runtime-cu12==12.1.105
nvidia-cudnn-cu12==8.9.2.26
nvidia-cufft-cu12==11.0.2.54
nvidia-curand-cu12==10.3.2.106
nvidia-cusolver-cu12==11.4.5.107
nvidia-cusparse-cu12==12.1.0.106
nvidia-nccl-cu12==2.19.3
nvidia-nvjitlink-cu12==12.4.127
nvidia-nvtx-cu12==12.1.105
onnx==1.16.0
onnxruntime==1.16.3
openai==1.23.2
openpyxl==3.1.2
optimum==1.19.0
orjson==3.10.3
packaging==23.2
pandas==2.2.2
pgvector==0.2.5
pillow==10.3.0
protobuf==5.26.1
psutil==5.9.8
psycopg2-binary==2.9.9
pyarrow==16.0.0
pyarrow-hotfix==0.6
pycparser==2.22
pycryptodome==3.20.0
pydantic==2.7.0
pydantic_core==2.18.1
pydub==0.25.1
PyJWT==2.8.0
pymilvus==2.4.3
pypdf==4.2.0
python-dateutil==2.9.0.post0
python-dotenv==1.0.1
pytz==2024.1
PyYAML==6.0.1
regex==2024.4.16
requests==2.31.0
RestrictedPython==7.1
s3transfer==0.10.1
safetensors==0.4.3
scikit-learn==1.4.2
scipy==1.13.0
sentence-transformers==2.7.0
sentencepiece==0.2.0
six==1.16.0
sniffio==1.3.1
soupsieve==2.5
SQLAlchemy==2.0.29
starlette==0.37.2
striprtf==0.0.26
sympy==1.12
tenacity==8.2.3
threadpoolctl==3.4.0
tiktoken==0.6.0
timm==0.9.16
tokenizers==0.15.2
torch==2.2.2
torchvision==0.17.2
tqdm==4.66.2
transformers==4.36.2
triton==2.2.0
typing-inspect==0.9.0
typing_extensions==4.11.0
tzdata==2024.1
ujson==5.9.0
urllib3==2.2.1
uvicorn==0.29.0
websockets==12.0
Werkzeug==3.0.2
wrapt==1.16.0
xxhash==3.4.1
yarl==1.9.4
zhipuai==2.0.1.20240429

由于版本问题,需要替换文件

/mnt/vdc/anaconda3/envs/gugong_rag/lib/python3.11/site-packages/BCEmbedding/tools/llama_index/bce_rerank.py

'''
@Description: 
@Author: shenlei
@Date: 2024-01-15 14:15:30
@LastEditTime: 2024-03-04 16:00:52
@LastEditors: shenlei
'''
from typing import Any, List, Optional

from pydantic.v1 import Field, PrivateAttr
from llama_index.core.callbacks import CBEventType, EventPayload
from llama_index.core.postprocessor.types import BaseNodePostprocessor
from llama_index.core.schema import MetadataMode, NodeWithScore, QueryBundle
from llama_index.core.utils import infer_torch_device


class BCERerank(BaseNodePostprocessor):
    model: str = Field(ddescription="Sentence transformer model name.")
    top_n: int = Field(description="Number of nodes to return sorted by score.")
    _model: Any = PrivateAttr()

    def __init__(
        self,
        top_n: int = 5,
        model: str = "maidalun1020/bce-reranker-base_v1",
        device: Optional[str] = None,
        **kwargs
    ):
        try:
            from BCEmbedding.models import RerankerModel
        except ImportError:
            raise ImportError(
                "Cannot import `BCEmbedding` package,",
                "please `pip install BCEmbedding>=0.1.2`",
            )
        self._model = RerankerModel(model_name_or_path=model, device=device, **kwargs)
        device = infer_torch_device() if device is None else device
        super().__init__(top_n=top_n, model=model, device=device)

    @classmethod
    def class_name(cls) -> str:
        return "BCERerank"

    def _postprocess_nodes(
        self,
        nodes: List[NodeWithScore],
        query_bundle: Optional[QueryBundle] = None,
    ) -> List[NodeWithScore]:
        if query_bundle is None:
            raise ValueError("Missing query bundle in extra info.")
        if len(nodes) == 0:
            return []
        
        query = query_bundle.query_str
        passages = []
        valid_nodes = []
        invalid_nodes = []
        for node in nodes:
            passage = node.node.get_content(metadata_mode=MetadataMode.EMBED)
            if isinstance(passage, str) and len(passage) > 0:
                passages.append(passage.replace('\n', ' '))
                valid_nodes.append(node)
            else:
                invalid_nodes.append(node)

        with self.callback_manager.event(
                CBEventType.RERANKING,
                payload={
                    EventPayload.NODES: nodes,
                    EventPayload.MODEL_NAME: self.model,
                    EventPayload.QUERY_STR: query_bundle.query_str,
                    EventPayload.TOP_K: self.top_n,
                },
            ) as event:

            rerank_result = self._model.rerank(query, passages)
            new_nodes = []
            for score, nid in zip(rerank_result['rerank_scores'], rerank_result['rerank_ids']):
                node = valid_nodes[nid]
                node.score = score
                new_nodes.append(node)
            for node in invalid_nodes:
                node.score = 0
                new_nodes.append(node)

            assert len(new_nodes) == len(nodes)

            new_nodes = new_nodes[:self.top_n]
            event.on_end(payload={EventPayload.NODES: new_nodes})

        return new_nodes

部署glm4-9b

魔搭社区下载地址

# 新建一个conda环境
conda create --name glm4_9b python=3.11
# 激活conda环境
conda activate glm4_9b
# 安装modelscope 
pip install modelscope 
# 下载模型,指定下载路径
python
>>> from modelscope import snapshot_download
>>> model_dir = snapshot_download('ZhipuAI/glm-4-9b-chat',cache_dir='/mnt/vdc/rm/models/')

# 下载快速部署接口
https://github.com/lm-sys/FastChat
# 进入项目目录
cd FastChat-main/
# 安装依赖
pip install .
# 启动控制器
nohup python3 -m fastchat.serve.controller > log/log_fastchat.serve.controller_2024.5.23.log 2>&1 &

# 启动模型工作线程 vllm 加快推理速度
pip install vllm
pip install accelerate


nohup python3 -m fastchat.serve.vllm_worker  --model-path /mnt/vdc/rm/models/ZhipuAI/glm-4-9b-chat --tensor-parallel-size 4 --trust-remote-code --gpu_memory_utilization 0.9 --dtype float16 --model-names glm-4-9b-chat --num-gpus 4 --max-model-len 4096 >log/log_fastchat.serve.vllm_worker_2024.5.23.log 2>&1 &

# openai服务启动
nohup python3 -m fastchat.serve.openai_api_server --host 0.0.0.0 --port 8007 >log/log_fastchat.openai_api_server_2024.5.23.log 2>&1 &

参数详情

nohup python3 -m fastchat.serve.controller > log/log_fastchat.serve.controller_2024.5.23.log 2>&1 &
  • --host参数指定应用程序绑定的主机名或IP地址。默认情况下,应用程序将绑定在本地回环地址(即localhost或127.0.0.1)上。
  • --port参数指定应用程序监听的端口号。默认情况下,应用程序将监听21001端口。
  • --dispatch-method参数指定请求调度算法。lottery表示抽奖式随机分配请求,shortest_queue表示将请求分配给队列最短的服务器。默认情况下,使用抽奖式随机分配请求。
  • --ssl参数指示应用程序是否使用SSL加密协议。如果指定了此参数,则应用程序将使用HTTPS协议。否则,应用程序将使用HTTP协议。
nohup python3 -m fastchat.serve.vllm_worker  --model-path /mnt/vdc/rm/models/ZhipuAI/glm-4-9b-chat --tensor-parallel-size 4 --trust-remote-code --gpu_memory_utilization 0.9 --dtype float16 --model-names glm-4-9b-chat --num-gpus 4 --max-model-len 4096 >log/log_fastchat.serve.vllm_worker_2024.5.23.log 2>&1 &
  • --host HOST:指定该工作节点的主机名或 IP 地址,默认为 localhost。
  • --port PORT:指定该工作节点监听的端口号,默认为 21002。
  • --worker-address WORKER_ADDRESS:指定该工作节点的地址。如果未指定,则自动从网络配置中获取。
  • --controller-address CONTROLLER_ADDRESS:指定控制节点的地址。如果未指定,则自动从环境变量中获取。如果环境变量也未设置,则默认使用 http://localhost:8001
  • --model-path MODEL_PATH:指定模型文件的路径。如果未指定,则默认使用 models/model.ckpt。
  • --model-names MODEL_NAMES:指定要加载的模型名称。该参数只在多模型情况下才需要使用。
  • --limit-worker-concurrency LIMIT_WORKER_CONCURRENCY:指定最大并发工作进程数。默认为 None,表示不限制。
  • --no-register:禁止在控制节点上注册该工作节点。
  • --num-gpus NUM_GPUS:指定使用的 GPU 数量。默认为 1。
  • --conv-template CONV_TEMPLATE:指定对话生成的模板文件路径。如果未指定,则默认使用 conversation_template.json。
  • --trust_remote_code:启用远程代码信任模式。
  • --gpu_memory_utilization GPU_MEMORY_UTILIZATION:指定 GPU 内存使用率,范围为 [0,1]。默认为 1.0,表示占用全部 GPU 内存。
  • --model MODEL:指定要加载的模型类型。默认为 fastchat.serve.vllm_worker.VLLMModel。
  • --tokenizer TOKENIZER:指定要使用的分词器类型。默认为 huggingface。
  • --revision REVISION:指定加载的模型版本号。默认为 None,表示加载最新版本。
  • --tokenizer-revision TOKENIZER_REVISION:指定加载的分词器版本号。默认为 None,表示加载最新版本。
  • --tokenizer-mode {auto,slow}:指定分词器模式。默认为 auto,表示自动选择最佳模式。
  • --download-dir DOWNLOAD_DIR:指定模型下载目录。默认为 downloads/。
  • --load-format {auto,pt,safetensors,npcache,dummy}:指定模型加载格式。默认为 auto,表示自动选择最佳格式。
  • --dtype {auto,half,float16,bfloat16,float,float32}:指定模型数据类型。默认为 auto,表示自动选择最佳类型。
  • --max-model-len MAX_MODEL_LEN:指定模型的最大长度。默认为 None,表示不限制。
  • --worker-use-ray:启用 Ray 分布式训练模式。
  • --pipeline-parallel-size PIPELINE_PARALLEL_SIZE:指定管道并行的大小。默认为 None,表示不使用管道并行。
  • --tensor-parallel-size TENSOR_PARALLEL_SIZE:指定张量并行的大小。默认为 None,表示不使用张量并行。
  • --max-parallel-loading-workers MAX_PARALLEL_LOADING_WORKERS:指定最大并发加载工作数。默认为 4。
  • --block-size {8,16,32}:指定块大小。默认为 16。
  • --seed SEED:指定随机种子。默认为 None。
  • --swap-space SWAP_SPACE:指定交换空间的大小。默认为 4GB。
  • --max-num-batched-tokens MAX_NUM_BATCHED_TOKENS:指定每个批次的最大令牌数。默认为 2048。
  • --max-num-seqs MAX_NUM_SEQS:指定每个批次的最大序列数。默认为 64。
  • --max-paddings MAX_PADDINGS:指定每个批次的最大填充数。默认为 1024。
  • --disable-log-stats:禁止记录统计信息。
  • --quantization {awq,gptq,squeezellm,None}:指定模型量化类型。默认为 None,表示不进行量化。
  • --enforce-eager:强制启用 Eager Execution 模式。
  • --max-context-len-to-capture MAX_CONTEXT_LEN_TO_CAPTURE:指定要捕获的上下文长度。默认为 1024。
  • --engine-use-ray:在引擎中启用 Ray 分布式训练模式。
  • --disable-log-requests:禁止记录请求信息。
  • --max-log-len MAX_LOG_LEN:指定最大日志长度。默认为 10240。

启动测试

import requests
import json

def call_api(content):
    url = "http://127.0.0.1:8007/v1/chat/completions"
    # API请求的参数
    payload = {
        "model": "glm-4-9b-chat",
        "messages": [
            {
                "role": "user",
                "content": content
            }
        ],
        "stream": False,
        "max_tokens": 100,
        "temperature": 0.8,
        "top_p": 0.8
    }
    # 发送POST请求
    response = requests.post(url, json=payload)
    # 检查请求是否成功
    if response.status_code == 200:
        # 解析并打印响应内容
        response_data = response.json()
        data = json.loads(json.dumps(response_data, indent=4))
        content = data['choices'][0]['message']['content']
        return content
    else:
        print(f"请求失败,状态码:{response.status_code}")
        return "请求失败"

if __name__ == "__main__":
    question = "你是哪个模型?"
    res =call_api(question)
    print(res)
posted @ 2024-06-26 09:43  初夏那片海  阅读(103)  评论(0)    收藏  举报