Ray Serve for a LangChain app
https://python.langchain.com/docs/integrations/providers/ray_serve/#example-of-deploying-and-openai-chain-with-custom-prompts
Ray Serve
Ray Serve is a scalable model serving library for building online inference APIs. Serve is particularly well suited for system composition, enabling you to build a complex inference service consisting of multiple chains and business logic all in Python code.
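As a hypothetical illustration of that composition pattern (not part of the original example), one deployment can call another through the handle Serve injects when a bound deployment is passed into .bind(). The names below are illustrative, and the sketch assumes a recent Ray version where calls on deployment handles return awaitable responses.

from ray import serve
from starlette.requests import Request


@serve.deployment
class Preprocessor:
    def transform(self, text: str) -> str:
        # trivial stand-in for real preprocessing logic
        return text.strip().lower()


@serve.deployment
class Pipeline:
    def __init__(self, preprocessor):
        # Serve passes a handle to the bound Preprocessor deployment here
        self._preprocessor = preprocessor

    async def __call__(self, request: Request) -> str:
        text = request.query_params["text"]
        # calling another deployment returns an awaitable response
        cleaned = await self._preprocessor.transform.remote(text)
        return f"processed: {cleaned}"


app = Pipeline.bind(Preprocessor.bind())
# serve.run(app) would start the composed service locally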
Goal of this notebook
This notebook shows a simple example of how to deploy an OpenAI chain into production. You can extend it to deploy your own self-hosted models, where you can easily define the amount of hardware resources (GPUs and CPUs) needed to run your model in production efficiently. Read more about the available options, including autoscaling, in the Ray Serve documentation.
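For self-hosted models, hardware resources and replica counts are declared on the deployment itself. The following is a hedged sketch with illustrative numbers; SelfHostedModel is a placeholder, not a real model server, and the exact autoscaling options are described in the Ray Serve documentation.

from ray import serve


@serve.deployment(
    ray_actor_options={"num_gpus": 1, "num_cpus": 2},            # hardware reserved per replica
    autoscaling_config={"min_replicas": 1, "max_replicas": 4},   # Serve scales replicas within this range
)
class SelfHostedModel:
    def __call__(self, prompt: str) -> str:
        # placeholder for real model inference
        return f"echo: {prompt}"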
Example of deploying an OpenAI chain with custom prompts
Get an OpenAI API key from the OpenAI platform. When you run the following code, you will be asked to provide your API key.
from langchain.chains import LLMChain
from langchain_core.prompts import PromptTemplate
from langchain_openai import OpenAI

from getpass import getpass

OPENAI_API_KEY = getpass()

from ray import serve
from starlette.requests import Request


@serve.deployment
class DeployLLM:
    def __init__(self):
        # We initialize the LLM, template and the chain here
        llm = OpenAI(openai_api_key=OPENAI_API_KEY)
        template = "Question: {question}\n\nAnswer: Let's think step by step."
        prompt = PromptTemplate.from_template(template)
        self.chain = LLMChain(llm=llm, prompt=prompt)

    def _run_chain(self, text: str):
        return self.chain(text)

    async def __call__(self, request: Request):
        # 1. Parse the request
        text = request.query_params["text"]
        # 2. Run the chain
        resp = self._run_chain(text)
        # 3. Return the response
        return resp["text"]

Now we can bind the deployment.
# Bind the model to deployment
deployment = DeployLLM.bind()

We can assign the port number and host when we want to run the deployment.
# Example port number
PORT_NUMBER = 8282
# Run the deployment
serve.api.run(deployment, port=PORT_NUMBER)

Now that the service is deployed on localhost:8282, we can send a POST request to get the results back.

import requests
text = "What NFL team won the Super Bowl in the year Justin Beiber was born?"
response = requests.post(f"http://localhost:{PORT_NUMBER}/?text={text}")
print(response.content.decode())
A more complete example, serving a self-hosted Hugging Face model with Ray Serve, comes from the io.net chat demo:
https://github.com/ionet-official/io-ray-serve-chat-demo/blob/main/chat.py

import torch
from typing import Dict

from ray.serve import Application
from ray import serve
from starlette.requests import Request
from transformers import AutoTokenizer, AutoModelForSeq2SeqLM


@serve.deployment(ray_actor_options={"num_gpus": 0.5})
class Chat:
    def __init__(self, model: str):
        # configure stateful elements of our service such as loading a model
        self._tokenizer = AutoTokenizer.from_pretrained(model)
        print(f"Loading model: {model}")
        self._model = AutoModelForSeq2SeqLM.from_pretrained(model, torch_dtype=torch.float16).to(0)
        self._max_length = self._model.config.max_position_embeddings
        print(f"Model loaded. Max length for the model: {self._max_length}")

    async def __call__(self, request: Request) -> Dict:
        # path to handle HTTP requests
        data = await request.json()
        # after decoding the payload, we delegate to get_response for the logic
        return {"response": self.get_response(data["user_input"], data["history"])}

    def get_response(self, user_input: str, history: list[str]) -> str:
        # this method receives calls directly (from Python) or from __call__ (from HTTP);
        # the history is client-side state and will be a list of raw strings
        # trim the oldest pair of messages from the history until the total input fits the model's max length
        while True:
            input_text = "\n".join(history + [user_input])
            if len(history) == 0 or len(self._tokenizer.encode(input_text)) <= self._max_length:
                break
            history = history[2:]
        inputs = self._tokenizer(
            [input_text], max_length=self._max_length, truncation=True, return_tensors="pt"
        ).to(0)
        reply_ids = self._model.generate(**inputs, max_length=self._max_length)
        response = self._tokenizer.batch_decode(reply_ids, skip_special_tokens=True)
        return response[0].strip()


def app_builder(args: Dict[str, str]) -> Application:
    return Chat.bind(model="facebook/blenderbot-400M-distill")


app = app_builder(None)
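The comment in get_response notes that it can be called directly from Python as well as over HTTP. A minimal sketch of the Python path (not part of the demo repo), assuming a recent Ray version where serve.run returns a DeploymentHandle:

from ray import serve

handle = serve.run(app)                                    # deploy the Chat app and get a handle to it
response = handle.get_response.remote("Hello there!", [])  # history starts empty
print(response.result())                                   # block until the reply is ready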
The matching Serve config file deploys this application behind HTTP port 8778 and reserves half a GPU per Chat replica:

# This file was generated using the `serve build` command on Ray v2.9.3.
proxy_location: EveryNode

http_options:
  host: 0.0.0.0
  port: 8778

grpc_options:
  port: 9000
  grpc_servicer_functions: []

logging_config:
  encoding: TEXT
  log_level: INFO
  logs_dir: null
  enable_access_log: true

applications:
- name: app1
  route_prefix: /
  import_path: chat:app
  runtime_env: {}
  deployments:
  - name: Chat
    ray_actor_options:
      num_cpus: 1.0
      num_gpus: 0.5
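To exercise the deployed app over HTTP, the client keeps the conversation history itself and sends it with every request, matching the payload Chat.__call__ expects. A minimal sketch (not part of the demo repo), assuming the service is reachable on the port from the config above:

import requests

URL = "http://localhost:8778/"  # host and port taken from http_options in the config
history: list[str] = []

for user_input in ["Hello!", "What's your favorite food?"]:
    payload = {"user_input": user_input, "history": history}
    reply = requests.post(URL, json=payload).json()["response"]
    history += [user_input, reply]  # keep both turns so the model sees earlier context
    print(f"You: {user_input}\nBot: {reply}")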
