Java网络编程:从零基础到实战应用
引言
网络编程是许多Java开发者面临的一大挑战:Socket、TCP/UDP、客户端-服务器架构等概念常常让初学者感到困惑。然而,网络编程的核心原理其实很直观——实现两台计算机之间的通信。
这就像打电话一样——你需要对方的号码(IP地址),选择分机(端口),然后建立连接进行交流。本文将通过7个循序渐进的阶段,帮助读者系统地掌握Java网络编程的核心概念和实践技巧。
一、理论基础
网络通信的本质
网络通信就是两台计算机之间的数据交换。就像两个人打电话一样,需要:
(1) 电话号码(IP地址):标识网络中的一台计算机
(2) 分机号(端口号):标识该计算机上的一个应用程序
// Basic ingredients of network communication
String ip = "192.168.1.100"; // address of the target computer (the "phone number")
int port = 8080; // port of the target application on that computer (the "extension")
通信模式:TCP vs UDP
网络通信有两种基本模式:
TCP模式(打电话模式):需要先建立连接,通信可靠但相对较慢
UDP模式(发短信模式):直接发送数据,快速但不保证到达
在本文中,我们主要关注TCP模式,因为它在实际应用中更常见。
Socket:网络编程的"插座"
// Socket is the core concept of network programming: just as an appliance needs a power socket to work, a program needs a Socket to communicate over the network.
// Create a Socket connection
Socket socket = new Socket("192.168.1.100", 8080); // meaning: connect to the program listening on port 8080 of the computer whose IP is 192.168.1.100
// Obtain the input/output streams
InputStream input = socket.getInputStream(); // pipe for receiving data
OutputStream output = socket.getOutputStream(); // pipe for sending data
客户端-服务器模式
这是网络编程中最常见的模式。可以用餐厅点餐来类比:
服务器 = 餐厅:固定位置,等待顾客上门
客户端 = 顾客:主动去餐厅消费
// Server side: open the "restaurant" and wait for customers.
ServerSocket restaurant = new ServerSocket(8080); // open the restaurant at port 8080 (a "door number", not an address)
Socket customer = restaurant.accept(); // block until a customer (client) walks in
// Client side (normally a separate program): go to the restaurant.
// Named differently from the server's `customer` so both halves still compile if pasted into one scope.
Socket visit = new Socket("restaurant-ip", 8080); // go to the restaurant at restaurant-ip, door 8080
二、第一个网络程序
基于前面的理论基础,我们来编写一个真正的网络程序,实现最基本的单次通信:客户端发送一条消息,服务器接收并回复。
import java.io.*;
import java.net.*;
// Server-side code
public class SimpleServer {
    /**
     * Single-shot server: accepts one client on port 8080, reads one line, replies once, then exits.
     *
     * <p>The listening socket and the client socket are opened with try-with-resources, so both
     * (and the streams obtained from the client socket) are closed even if reading or writing throws.
     *
     * @param args ignored
     * @throws IOException if the port cannot be bound or the connection fails
     */
    public static void main(String[] args) throws IOException {
        System.out.println("服务器启动,等待客户端连接...");
        // 1. Create the server socket listening on port 8080.
        // 2. Block in accept() until a client connects.
        try (ServerSocket serverSocket = new ServerSocket(8080);
             Socket clientSocket = serverSocket.accept()) {
            System.out.println("客户端已连接!");
            // 3. Input stream: reads the line the client sent.
            BufferedReader reader = new BufferedReader(
                    new InputStreamReader(clientSocket.getInputStream()));
            // 4. Output stream: autoflush=true so println() is pushed to the client immediately.
            PrintWriter writer = new PrintWriter(clientSocket.getOutputStream(), true);

            String message = reader.readLine();
            if (message == null) {
                // The client closed the connection without sending a line; there is nothing to echo.
                System.out.println("客户端未发送消息即断开连接");
            } else {
                System.out.println("收到消息:" + message);
                writer.println("服务器收到了:" + message);
            }
        }
        // 5. Both sockets are closed at this point.
        System.out.println("服务器关闭");
    }
}
import java.io.*;
import java.net.*;
// Client-side code
public class SimpleClient {
    /**
     * Connects to localhost:8080, sends one line, prints the single-line reply, then exits.
     *
     * <p>The socket is opened with try-with-resources, so it (and its streams) is closed even if
     * sending or receiving throws.
     *
     * @param args ignored
     * @throws IOException if the connection or the exchange fails
     */
    public static void main(String[] args) throws IOException {
        System.out.println("客户端启动,连接服务器...");
        // 1. Connect to the server.
        try (Socket socket = new Socket("localhost", 8080)) {
            System.out.println("已连接到服务器!");

            // 2. Send one line; autoflush=true pushes it out on println().
            PrintWriter writer = new PrintWriter(socket.getOutputStream(), true);
            writer.println("Hello Server!");

            // 3. Wait for the reply; readLine() returns null if the server closed without answering.
            BufferedReader reader = new BufferedReader(
                    new InputStreamReader(socket.getInputStream()));
            String response = reader.readLine();
            System.out.println("服务器回复:" + response);
        }
        // 4. The connection is closed at this point.
        System.out.println("客户端关闭");
    }
}
三、持续通信
单次通信虽然简单,但实际应用中我们往往需要多次交互。下面我们升级程序,实现持续的双向通信。
支持多轮对话的服务器
import java.io.*;
import java.net.*;
public class LoopServer {
    /**
     * Serves a single client on port 8080 for many rounds. Each received line is echoed back with
     * a prefix, until the client sends "bye" (case-insensitive) or disconnects.
     *
     * <p>Both sockets are managed with try-with-resources, so they are closed even when an
     * I/O error ends the conversation early.
     *
     * @param args ignored
     * @throws IOException if the port cannot be bound or the connection fails
     */
    public static void main(String[] args) throws IOException {
        try (ServerSocket serverSocket = new ServerSocket(8080)) {
            System.out.println("服务器启动,等待客户端连接...");
            try (Socket clientSocket = serverSocket.accept()) {
                System.out.println("客户端已连接!");
                BufferedReader reader = new BufferedReader(
                        new InputStreamReader(clientSocket.getInputStream()));
                PrintWriter writer = new PrintWriter(clientSocket.getOutputStream(), true);

                String message;
                // readLine() returns null once the client closes its side of the connection.
                while ((message = reader.readLine()) != null) {
                    System.out.println("收到消息:" + message);
                    // "bye" ends the conversation after a farewell reply.
                    if ("bye".equalsIgnoreCase(message)) {
                        writer.println("再见!");
                        break;
                    }
                    writer.println("回复:" + message);
                }
            }
        }
        System.out.println("服务器关闭");
    }
}
通信协议设计
在实际应用中,我们需要定义通信协议——双方约定的"对话规则"。下面是一个简单的协议示例:
public class ProtocolExample {
    /**
     * Handles one protocol message of the form {@code COMMAND:ARGUMENT} and writes one reply line.
     *
     * <p>Supported commands: {@code TIME}, {@code ECHO:<text>} and {@code CALC:<a>+<b>[+...]}.
     * Only the first ':' separates the command from its argument, so arguments may contain ':'.
     * Previously "ECHO:12:30" echoed only "12".
     *
     * @param message raw line from the client; {@code null} (client disconnected) is ignored
     * @param writer  destination for the reply; {@code null} is ignored
     */
    public static void handleMessage(String message, PrintWriter writer) {
        if (message == null || writer == null) {
            return;
        }
        String[] parts = message.split(":", 2);
        String command = parts[0];
        // "ECHO" and "ECHO:" both mean "no argument", matching the previous behaviour.
        String argument = parts.length > 1 && !parts[1].isEmpty() ? parts[1] : null;

        switch (command) {
            case "TIME" -> writer.println("当前时间:" + java.time.LocalDateTime.now());
            case "ECHO" -> writer.println(argument != null ? "回声:" + argument : "ECHO命令需要参数");
            case "CALC" -> {
                if (argument == null) {
                    writer.println("CALC命令需要表达式");
                } else {
                    try {
                        writer.println("计算结果:" + evaluateExpression(argument));
                    } catch (RuntimeException e) {
                        writer.println("计算错误:" + e.getMessage());
                    }
                }
            }
            default -> writer.println("未知命令:" + command);
        }
    }

    /**
     * Sums a '+'-separated list of integers; for example {@code "1 + 2 + 3"} yields {@code "6"}.
     *
     * @throws IllegalArgumentException if the expression contains no '+'
     * @throws NumberFormatException    if a term is empty or not an integer (e.g. {@code "5+"})
     * @throws ArithmeticException      if the sum overflows {@code int}
     */
    private static String evaluateExpression(String expr) {
        if (!expr.contains("+")) {
            throw new IllegalArgumentException("仅支持加法运算");
        }
        int sum = 0;
        // limit -1 keeps trailing empty terms, so "5+" is rejected instead of being read as "5".
        for (String term : expr.split("\\+", -1)) {
            sum = Math.addExact(sum, Integer.parseInt(term.trim()));
        }
        return String.valueOf(sum);
    }
}
四、并发处理多客户端
多线程服务器架构
import java.io.*;
import java.net.*;
public class MultiClientServer {
    /**
     * Accepts clients on port 8080 forever and serves each one on its own thread via
     * {@code ClientHandler}.
     *
     * <p>The listening socket uses try-with-resources. If {@code accept()} ever throws, the port is
     * released even though handler threads may keep the JVM alive. One thread per client is
     * unbounded; an executor would cap it.
     *
     * @param args ignored
     * @throws IOException if the port cannot be bound or accepting fails
     */
    public static void main(String[] args) throws IOException {
        try (ServerSocket serverSocket = new ServerSocket(8080)) {
            System.out.println("服务器启动,可以接受多个客户端连接");
            // Accept connections in a loop; each client gets its own handler thread.
            while (true) {
                Socket clientSocket = serverSocket.accept();
                System.out.println("新客户端连接:" + clientSocket.getInetAddress());
                Thread clientThread = new Thread(
                        new ClientHandler(clientSocket),
                        "client-" + clientSocket.getRemoteSocketAddress());
                clientThread.start();
            }
        }
    }
}
// Client handler: serves one connected client until it says "bye" or disconnects.
class ClientHandler implements Runnable {
    /** The connection owned by this handler; it is closed when run() finishes. */
    private Socket connection;

    public ClientHandler(Socket socket) {
        this.connection = socket;
    }

    @Override
    public void run() {
        try {
            converse();
        } catch (IOException failure) {
            System.err.println("处理客户端时出错:" + failure.getMessage());
        } finally {
            hangUp();
        }
    }

    /** Echoes every received line with a prefix; stops after replying to "bye" or at end of stream. */
    private void converse() throws IOException {
        System.out.println("开始处理客户端:" + connection.getInetAddress());
        BufferedReader in = new BufferedReader(
                new InputStreamReader(connection.getInputStream()));
        PrintWriter out = new PrintWriter(connection.getOutputStream(), true);

        for (String line = in.readLine(); line != null; line = in.readLine()) {
            System.out.println("客户端 " + connection.getInetAddress() + " 说:" + line);
            if ("bye".equalsIgnoreCase(line)) {
                out.println("再见!");
                return;
            }
            out.println("服务器回复:" + line);
        }
    }

    /** Closes the socket and reports the disconnect; a failure while closing is only printed. */
    private void hangUp() {
        try {
            connection.close();
            System.out.println("客户端 " + connection.getInetAddress() + " 断开连接");
        } catch (IOException closeFailure) {
            closeFailure.printStackTrace();
        }
    }
}
五、增强程序健壮性
生产环境中的网络程序需要处理各种异常情况,包括网络中断、客户端异常退出、超时等。
// 健壮的服务器实现
import java.io.*;
import java.net.*;
import java.util.concurrent.*;
import java.util.logging.*;
public class RobustServer {
    private static final Logger logger = Logger.getLogger(RobustServer.class.getName());

    static {
        // Configure the shared static logger once per class, not once per instance. The old
        // constructor added one more ConsoleHandler for every RobustServer created, multiplying output.
        // Parent handlers are disabled so the root ConsoleHandler does not print each record twice.
        ConsoleHandler handler = new ConsoleHandler();
        handler.setLevel(Level.ALL);
        logger.addHandler(handler);
        logger.setLevel(Level.ALL);
        logger.setUseParentHandlers(false);
    }

    /** How long stop() waits for running client handlers before interrupting them. */
    private static final long SHUTDOWN_GRACE_SECONDS = 5;

    /** Listening socket; volatile because stop() normally runs on a different thread than start(). */
    private volatile ServerSocket serverSocket;
    /** Pool that runs client handlers; at most 10 run concurrently. */
    private final ExecutorService executor;
    /** Accept-loop flag; volatile so a stop() from another thread is observed by start(). */
    private volatile boolean isRunning = false;

    public RobustServer() {
        executor = Executors.newFixedThreadPool(10);
    }

    /**
     * Binds {@code port} and accepts clients until {@link #stop()} is called, handing each one to
     * the pool. This call blocks the calling thread.
     *
     * @param port TCP port to listen on
     * @throws IOException if the port cannot be bound
     */
    public void start(int port) throws IOException {
        serverSocket = new ServerSocket(port);
        isRunning = true;
        logger.info("服务器启动在端口:" + port);
        while (isRunning) {
            try {
                Socket clientSocket = serverSocket.accept();
                logger.info("新客户端连接:" + clientSocket.getInetAddress());
                executor.submit(new RobustClientHandler(clientSocket));
            } catch (IOException e) {
                // stop() closes the socket to unblock accept(); only report errors while running.
                if (isRunning) {
                    logger.severe("接受客户端连接时出错:" + e.getMessage());
                }
            }
        }
    }

    /**
     * Stops accepting clients and shuts down the handler pool. Waits up to
     * {@value #SHUTDOWN_GRACE_SECONDS} seconds for running handlers, then interrupts them.
     *
     * @throws IOException if closing the listening socket fails
     */
    public void stop() throws IOException {
        isRunning = false;
        ServerSocket socket = serverSocket;
        if (socket != null) {
            socket.close(); // unblocks accept() in start()
        }
        executor.shutdown();
        try {
            if (!executor.awaitTermination(SHUTDOWN_GRACE_SECONDS, TimeUnit.SECONDS)) {
                executor.shutdownNow();
            }
        } catch (InterruptedException e) {
            executor.shutdownNow();
            Thread.currentThread().interrupt();
        }
        logger.info("服务器已停止");
    }
}
class RobustClientHandler implements Runnable {
    private static final Logger logger = Logger.getLogger(RobustClientHandler.class.getName());
    /** A client that sends nothing for this long is disconnected. */
    private static final int READ_TIMEOUT_MILLIS = 30_000;

    private final Socket clientSocket;

    public RobustClientHandler(Socket socket) {
        this.clientSocket = socket;
    }

    /**
     * Serves the client until one of these happens:
     * <ul>
     *   <li>it sends "bye";</li>
     *   <li>it disconnects;</li>
     *   <li>it stays idle past the read timeout;</li>
     *   <li>this thread is interrupted.</li>
     * </ul>
     *
     * <p>The socket and both streams are closed by try-with-resources, which closes every resource
     * even if an earlier close throws. The former hand-written closeResources() skipped the socket
     * whenever {@code reader.close()} failed.
     */
    @Override
    public void run() {
        // String.valueOf guards against getInetAddress() returning null for an unconnected socket.
        String clientAddress = String.valueOf(clientSocket.getInetAddress());
        logger.info("开始处理客户端:" + clientAddress);
        try (Socket socket = clientSocket) {
            socket.setSoTimeout(READ_TIMEOUT_MILLIS);
            try (BufferedReader reader = new BufferedReader(
                         new InputStreamReader(socket.getInputStream()));
                 PrintWriter writer = new PrintWriter(socket.getOutputStream(), true)) {
                serve(reader, writer, clientAddress);
            }
        } catch (SocketTimeoutException e) {
            logger.warning("客户端超时:" + clientAddress);
        } catch (IOException e) {
            logger.warning("客户端连接异常 " + clientAddress + ":" + e.getMessage());
        } finally {
            logger.info("客户端连接关闭:" + clientAddress);
        }
    }

    /** Answers each received line; returns after replying to "bye", at end of stream, or on interruption. */
    private void serve(BufferedReader reader, PrintWriter writer, String clientAddress) throws IOException {
        String message;
        // The interrupt check lets ExecutorService.shutdownNow() stop the loop between messages.
        while (!Thread.currentThread().isInterrupted() && (message = reader.readLine()) != null) {
            logger.info("收到消息从 " + clientAddress + ":" + message);
            if ("bye".equalsIgnoreCase(message)) {
                writer.println("再见!");
                return;
            }
            processMessage(message, writer);
        }
    }

    /** Simulates 100 ms of work, then acknowledges the message; restores the interrupt flag if interrupted. */
    private void processMessage(String message, PrintWriter writer) {
        try {
            Thread.sleep(100);
            writer.println("处理完成:" + message);
        } catch (InterruptedException e) {
            Thread.currentThread().interrupt();
        }
    }
}
六、实际应用案例
文件传输服务
import java.io.*;
import java.net.*;
public class FileTransferServer {
    /**
     * Accepts file-transfer clients on port 8080 forever, one thread per client.
     *
     * <p>The listening socket uses try-with-resources. If {@code accept()} throws, the port is
     * released even though transfer threads may keep the JVM alive.
     *
     * @param args ignored
     * @throws IOException if the port cannot be bound or accepting fails
     */
    public static void main(String[] args) throws IOException {
        try (ServerSocket serverSocket = new ServerSocket(8080)) {
            System.out.println("文件传输服务器启动");
            while (true) {
                Socket clientSocket = serverSocket.accept();
                new Thread(new FileTransferHandler(clientSocket)).start();
            }
        }
    }
}
class FileTransferHandler implements Runnable {
    private final Socket clientSocket;
    /** Directory that requested names are resolved against; requests may not escape it. */
    private final java.nio.file.Path baseDir;

    /** Serves files from the current working directory only. */
    public FileTransferHandler(Socket socket) {
        this(socket, java.nio.file.Paths.get(""));
    }

    /**
     * Serves files from {@code baseDir} only.
     *
     * @param socket  connected client socket; closed when the transfer ends
     * @param baseDir root directory for lookups
     */
    public FileTransferHandler(Socket socket, java.nio.file.Path baseDir) {
        this.clientSocket = socket;
        this.baseDir = baseDir.toAbsolutePath().normalize();
    }

    /**
     * Protocol: the client sends one line containing a file name. The server answers in one of
     * two ways, then closes the connection:
     * <ul>
     *   <li>{@code OK:<size>} followed by the raw bytes;</li>
     *   <li>{@code ERROR:<message>}.</li>
     * </ul>
     */
    @Override
    public void run() {
        try {
            BufferedReader reader = new BufferedReader(
                    new InputStreamReader(clientSocket.getInputStream()));
            String fileName = reader.readLine(); // null if the client disconnected without asking
            System.out.println("请求文件:" + fileName);

            // The name is untrusted input. The old code passed it straight to new File(), so
            // "../../etc/passwd" or any absolute path could be downloaded.
            java.util.Optional<java.nio.file.Path> file = resolveRequestedFile(baseDir, fileName)
                    .filter(java.nio.file.Files::isRegularFile);
            if (file.isPresent()) {
                sendFile(file.get().toFile());
            } else {
                // Same reply for "missing" and "outside baseDir", so probing reveals nothing extra.
                sendError("文件不存在:" + fileName);
            }
        } catch (IOException e) {
            System.err.println("文件传输出错:" + e);
        } finally {
            try {
                clientSocket.close();
            } catch (IOException e) {
                System.err.println("关闭连接出错:" + e);
            }
        }
    }

    /**
     * Resolves a client-supplied name inside {@code baseDir}. Returns empty for names that are
     * null, blank or syntactically invalid, and for names that normalise to a path outside
     * {@code baseDir} (via "..", or an absolute path).
     * NOTE(review): symlinks inside baseDir are followed; use toRealPath() for strict containment.
     */
    static java.util.Optional<java.nio.file.Path> resolveRequestedFile(java.nio.file.Path baseDir, String fileName) {
        if (fileName == null || fileName.isBlank()) {
            return java.util.Optional.empty();
        }
        try {
            java.nio.file.Path base = baseDir.toAbsolutePath().normalize();
            java.nio.file.Path candidate = base.resolve(fileName).normalize();
            return candidate.startsWith(base) ? java.util.Optional.of(candidate) : java.util.Optional.empty();
        } catch (java.nio.file.InvalidPathException e) {
            return java.util.Optional.empty();
        }
    }

    /** Sends the "OK:<size>" header line, then streams the file's bytes and closes the output. */
    private void sendFile(File file) throws IOException {
        PrintWriter writer = new PrintWriter(clientSocket.getOutputStream(), true); // autoflush sends the header first
        writer.println("OK:" + file.length());
        try (FileInputStream fis = new FileInputStream(file);
             OutputStream os = clientSocket.getOutputStream()) {
            fis.transferTo(os);
            os.flush();
            System.out.println("文件发送完成:" + file.getName());
        }
    }

    /** Sends a single "ERROR:<message>" line. */
    private void sendError(String error) throws IOException {
        PrintWriter writer = new PrintWriter(clientSocket.getOutputStream(), true);
        writer.println("ERROR:" + error);
    }
}
简单聊天室
import java.io.*;
import java.net.*;
import java.util.*;
public class ChatServer {
    /** Connected handlers. HashSet is not thread-safe, so every access synchronizes on the set itself. */
    private static final Set<ClientHandler> clients = new HashSet<>();

    /**
     * Accepts chat clients on port 8080 forever, one handler thread per client.
     *
     * @param args ignored
     * @throws IOException if the port cannot be bound or accepting fails
     */
    public static void main(String[] args) throws IOException {
        try (ServerSocket serverSocket = new ServerSocket(8080)) {
            System.out.println("聊天室服务器启动");
            while (true) {
                Socket clientSocket = serverSocket.accept();
                ClientHandler handler = new ClientHandler(clientSocket);
                // Previously clients.add() ran without the lock while handler threads iterated the
                // set in broadcast(), risking ConcurrentModificationException or a corrupted HashSet.
                addClient(handler);
                new Thread(handler).start();
            }
        }
    }

    /** Registers a client so it receives broadcasts. */
    public static void addClient(ClientHandler client) {
        synchronized (clients) {
            clients.add(client);
        }
    }

    /**
     * Sends {@code message} to every registered client except {@code sender}.
     * The recipient list is snapshotted under the lock and sent outside it, so a slow client's
     * blocking write does not hold up clients joining or leaving.
     */
    public static void broadcast(String message, ClientHandler sender) {
        List<ClientHandler> recipients;
        synchronized (clients) {
            recipients = new ArrayList<>(clients);
        }
        for (ClientHandler client : recipients) {
            if (client != sender) {
                client.sendMessage(message);
            }
        }
    }

    /** Unregisters a client; it receives no further broadcasts. */
    public static void removeClient(ClientHandler client) {
        synchronized (clients) {
            clients.remove(client);
        }
    }

    /** Serves one chat participant: asks for a nickname, then relays each line to everyone else. */
    static class ClientHandler implements Runnable {
        private final Socket socket;
        /** Written by this handler's thread, read by broadcasting threads: volatile for visibility. */
        private volatile PrintWriter writer;
        /** Accessed only by this handler's own thread. */
        private String nickname;

        public ClientHandler(Socket socket) {
            this.socket = socket;
        }

        /** Delivers a line to this client; silently does nothing before the connection is set up. */
        public void sendMessage(String message) {
            PrintWriter out = writer; // read the volatile field once
            if (out != null) {
                out.println(message);
            }
        }

        @Override
        public void run() {
            try {
                BufferedReader reader = new BufferedReader(
                        new InputStreamReader(socket.getInputStream()));
                PrintWriter out = new PrintWriter(socket.getOutputStream(), true);
                writer = out;

                // Ask for the nickname first.
                out.println("请输入您的昵称:");
                nickname = reader.readLine();
                if (nickname == null) {
                    return; // disconnected before choosing a nickname; the finally block still cleans up
                }
                broadcast(nickname + " 加入了聊天室", this);

                String message;
                while ((message = reader.readLine()) != null) {
                    if ("bye".equalsIgnoreCase(message)) {
                        break;
                    }
                    broadcast(nickname + ": " + message, this);
                }
            } catch (IOException e) {
                System.err.println("聊天连接出错:" + e);
            } finally {
                removeClient(this);
                // Only announce a departure for someone who actually joined (avoids "null 离开了聊天室").
                if (nickname != null) {
                    broadcast(nickname + " 离开了聊天室", this);
                }
                try {
                    socket.close();
                } catch (IOException e) {
                    System.err.println("关闭连接出错:" + e);
                }
            }
        }
    }
}
七、性能优化
连接池优化
import java.io.*;
import java.net.*;
import java.util.concurrent.*;
public class ConnectionPool {
    /** Idle sockets ready for reuse, in FIFO order. Guarded by {@code lock}. */
    private final java.util.Deque<Socket> availableConnections = new java.util.ArrayDeque<>();
    private final String host;
    private final int port;
    /** Maximum number of IDLE sockets kept; borrowed sockets are neither counted nor capped. */
    private final int maxSize;
    private final Object lock = new Object();

    public ConnectionPool(String host, int port, int maxSize) {
        this.host = host;
        this.port = port;
        this.maxSize = maxSize;
    }

    /**
     * Borrows a socket. An idle pooled socket is reused if it still looks usable; otherwise a new
     * connection is opened. Dead idle sockets met on the way are discarded rather than left queued.
     *
     * @throws IOException if a new connection cannot be opened
     */
    public Socket getConnection() throws IOException {
        synchronized (lock) {
            Socket idle;
            while ((idle = availableConnections.pollFirst()) != null) {
                if (isReusable(idle)) {
                    System.out.println("复用连接");
                    return idle;
                }
                try {
                    idle.close();
                } catch (IOException ignored) {
                    // Already unusable; discarding it is all we need.
                }
            }
        }
        // Connect outside the lock so a slow TCP handshake does not block other borrowers/returners.
        Socket socket = new Socket(host, port);
        System.out.println("创建新连接");
        return socket;
    }

    /**
     * Returns a borrowed socket. It is kept while fewer than {@code maxSize} sockets are idle and
     * it is still usable; otherwise it is closed. {@code null} and duplicate returns are ignored.
     */
    public void returnConnection(Socket socket) {
        if (socket == null) {
            return;
        }
        synchronized (lock) {
            if (availableConnections.contains(socket)) {
                return; // already idle: pooling it twice would hand the same socket to two callers
            }
            if (availableConnections.size() < maxSize && isReusable(socket)) {
                availableConnections.offerLast(socket);
                System.out.println("连接返回池中");
                return;
            }
        }
        try {
            socket.close();
            System.out.println("连接已关闭");
        } catch (IOException e) {
            e.printStackTrace();
        }
    }

    /** Local-state check only: a peer that silently closed the connection is not detected here. */
    private static boolean isReusable(Socket socket) {
        return !socket.isClosed() && socket.isConnected()
                && !socket.isInputShutdown() && !socket.isOutputShutdown();
    }
}
缓冲区优化
public class BufferedNetworkIO {
    /** Size of the user-space stream buffers. */
    private static final int STREAM_BUFFER_SIZE = 8192;
    /** Requested kernel send/receive buffer size (a hint; the OS may adjust it). */
    private static final int SOCKET_BUFFER_SIZE = 32 * 1024;

    /**
     * Copies everything received on {@code socket} back to the same socket (echo) until the peer
     * stops sending. The socket is then closed.
     *
     * <p>The socket and both streams are closed by try-with-resources, including when the transfer
     * fails; previously an exception mid-transfer leaked them. Socket options are applied before
     * the streams are created and before any data flows.
     *
     * @param socket a connected socket; ownership passes to this method
     * @throws IOException if configuring the socket or transferring data fails
     */
    public static void efficientDataTransfer(Socket socket) throws IOException {
        try (Socket s = socket) {
            // Buffer sizes above 64 KiB only take full effect when set before connect();
            // 32 KiB is fine on an established connection.
            s.setReceiveBufferSize(SOCKET_BUFFER_SIZE);
            s.setSendBufferSize(SOCKET_BUFFER_SIZE);
            // Disable Nagle's algorithm to cut latency for small writes.
            s.setTcpNoDelay(true);

            try (BufferedInputStream in = new BufferedInputStream(s.getInputStream(), STREAM_BUFFER_SIZE);
                 BufferedOutputStream out = new BufferedOutputStream(s.getOutputStream(), STREAM_BUFFER_SIZE)) {
                in.transferTo(out);
                out.flush();
            }
        }
    }
}
八、进阶
NIO编程:实现更高效的网络通信
网络协议:从HTTP到自定义协议的演进
HTTP/1.1的性能陷阱
HTTP/1.1虽然支持Keep-Alive,但存在队头阻塞问题。在高并发场景下,这个问题尤为严重。
什么是队头阻塞(Head-of-Line Blocking)
队头阻塞是HTTP/1.1的核心性能问题。想象一下单车道的高速公路:
客户端 → 服务端
请求1: GET /api/user/1 ← 正在处理(耗时5秒)
请求2: GET /api/user/2 ← 等待中
请求3: GET /api/user/3 ← 等待中
问题核心:当这些请求共用同一个TCP连接(顺序复用或管线化)时,即使请求2、3可以很快处理完,也必须等请求1的响应返回之后才能拿到自己的响应。
HTTP/1.1队头阻塞的具体表现
// Measures how long a slow and a fast HTTP/1.1 request take when issued concurrently.
public class HTTP1BlockingDemo {
    /**
     * Fires a slow and a fast GET at the same server concurrently and prints how long EACH took.
     *
     * <p>Each request's latency is measured inside its own future. The previous version computed
     * both durations after {@code allOf(...).join()}, which waits for the slow request. Request 2
     * therefore always appeared to take as long as request 1, whether or not any blocking occurred.
     *
     * <p>NOTE: for concurrent requests, {@link HttpURLConnection} uses separate connections, so
     * this program alone does not reproduce HTTP/1.1 head-of-line blocking. That only happens when
     * requests share one connection (sequential reuse or pipelining).
     */
    public static void main(String[] args) throws Exception {
        String serverUrl = "http://localhost:8080";
        // Assumed endpoints: /slow-api takes about 5 s, /fast-api about 100 ms.
        CompletableFuture<Long> request1 = timedGet(serverUrl + "/slow-api");
        CompletableFuture<Long> request2 = timedGet(serverUrl + "/fast-api");

        CompletableFuture.allOf(request1, request2).join();
        System.out.println("请求1完成时间: " + request1.join() + "ms");
        System.out.println("请求2完成时间: " + request2.join() + "ms");
    }

    /**
     * Starts an asynchronous GET and completes with its own wall-clock latency in milliseconds.
     * A failed request still reports its elapsed time, and the failure is printed to stderr.
     */
    private static CompletableFuture<Long> timedGet(String url) {
        return CompletableFuture.supplyAsync(() -> {
            long start = System.nanoTime();
            try {
                httpGet(url);
            } catch (Exception e) {
                System.err.println("request failed: " + url + " -> " + e);
            }
            return java.util.concurrent.TimeUnit.NANOSECONDS.toMillis(System.nanoTime() - start);
        });
    }

    /** Performs a GET and returns the body decoded with the platform charset. */
    private static String httpGet(String url) throws Exception {
        HttpURLConnection connection = (HttpURLConnection) java.net.URI.create(url).toURL().openConnection();
        connection.setRequestMethod("GET");
        // Closing the fully-read stream lets the JDK return the connection to its keep-alive cache.
        try (java.io.InputStream in = connection.getInputStream()) {
            return new String(in.readAllBytes());
        }
    }
}
HTTP/1.1的解决方案与局限性
// HTTP/1.1 "multiple connections" workaround: cap the number of requests in flight at once.
/**
 * Issues GET requests with at most {@code maxConnections} in flight at once, mimicking a browser's
 * per-host connection limit.
 *
 * <p>The previous version pooled {@link HttpURLConnection} objects, which does not work:
 * <ul>
 *   <li>Each object is single-use and bound to the URL it was opened for. A "reused" object
 *       re-read its first response, or returned data for the wrong URL.</li>
 *   <li>{@code activeConnections} was never decremented.</li>
 *   <li>The full-pool path blocked a common-pool thread on a recursive {@code get()}.</li>
 * </ul>
 * A fresh {@code HttpURLConnection} is now opened per request. The JDK's keep-alive cache reuses
 * the underlying TCP connection. Concurrency is capped by a fixed pool of worker threads.
 */
public class HTTP1ConnectionPool {
    /** Browsers typically allow about 6 concurrent connections per host. */
    private final int maxConnections = 6;
    /**
     * At most {@code maxConnections} requests run at once; extra requests wait in the executor's
     * queue instead of busy-retrying. Daemon threads, so an idle pool never keeps the JVM alive.
     */
    private final java.util.concurrent.ExecutorService executor =
            java.util.concurrent.Executors.newFixedThreadPool(maxConnections, runnable -> {
                Thread thread = new Thread(runnable, "http1-pool");
                thread.setDaemon(true);
                return thread;
            });

    public HTTP1ConnectionPool() {
    }

    /**
     * Asynchronously GETs {@code url}.
     *
     * @return a future with the response body (UTF-8), or {@code "error: <message>"} on failure
     */
    public CompletableFuture<String> request(String url) {
        return CompletableFuture.supplyAsync(() -> {
            try {
                return fetch(url);
            } catch (Exception e) {
                return "error: " + e.getMessage();
            }
        }, executor);
    }

    /** Performs one GET on a fresh connection object and reads the whole body. */
    private static String fetch(String url) throws Exception {
        HttpURLConnection connection =
                (HttpURLConnection) java.net.URI.create(url).toURL().openConnection();
        connection.setRequestProperty("Connection", "keep-alive");
        connection.setConnectTimeout(5000);
        connection.setReadTimeout(10000);
        // Closing the fully-read stream hands the TCP connection back to the JDK keep-alive cache.
        try (java.io.InputStream in = connection.getInputStream()) {
            return new String(in.readAllBytes(), java.nio.charset.StandardCharsets.UTF_8);
        }
    }
}
HTTP/2的多路复用实现
HTTP/2通过二进制分帧和多路复用技术,解决了HTTP层面的队头阻塞问题。需要注意的是,所有流仍然共享同一个TCP连接,一旦发生丢包,TCP层的队头阻塞依然存在;HTTP/3基于QUIC才进一步消除了这一问题。
二进制分帧(Binary Framing)
HTTP/2将HTTP消息分解为独立的帧(frames),这是多路复用的基础。
查看代码
// Simplified model of an HTTP/2 frame (RFC 9113 section 4.1)
public class HTTP2Frame {
    /** Frame types defined for HTTP/2, with their on-the-wire type codes. */
    public enum FrameType {
        DATA(0x0),
        HEADERS(0x1),
        PRIORITY(0x2),
        RST_STREAM(0x3),
        SETTINGS(0x4),
        PUSH_PROMISE(0x5),
        PING(0x6),
        GOAWAY(0x7),
        WINDOW_UPDATE(0x8),
        CONTINUATION(0x9);

        private final int value;

        FrameType(int value) {
            this.value = value;
        }

        /** @return the 8-bit type code written in the frame header */
        public int getValue() {
            return value;
        }

        /**
         * Maps a wire type code to its constant by value. The old decoder indexed
         * {@code values()} by code, which threw ArrayIndexOutOfBoundsException for unknown codes.
         *
         * @throws IllegalArgumentException for codes not modelled here; HTTP/2 requires receivers
         *     to ignore unknown frame types, so callers should skip such frames
         */
        public static FrameType fromValue(int value) {
            for (FrameType type : values()) {
                if (type.value == value) {
                    return type;
                }
            }
            throw new IllegalArgumentException("Unknown HTTP/2 frame type: 0x" + Integer.toHexString(value));
        }
    }

    /** Size of the fixed frame header: length(3) + type(1) + flags(1) + stream id(4). */
    public static final int HEADER_LENGTH = 9;
    /** Largest payload the 24-bit length field can describe. */
    public static final int MAX_PAYLOAD_LENGTH = 0xFFFFFF;

    private final int length;        // 24 bits: payload length
    private final FrameType type;    // 8 bits: frame type
    private final byte flags;        // 8 bits: frame flags
    private final int streamId;      // 31 bits: stream identifier (reserved high bit is 0)
    private final byte[] payload;    // frame payload

    /** Creates a frame with no flags set. */
    public HTTP2Frame(FrameType type, int streamId, byte[] payload) {
        this(type, streamId, payload, (byte) 0);
    }

    /**
     * Creates a frame.
     *
     * @param type     frame type, not null
     * @param streamId stream identifier, 0 .. 2^31 - 1
     * @param payload  payload bytes (copied); {@code null} means empty
     * @param flags    flags byte, e.g. 0x1 = END_STREAM on DATA/HEADERS
     * @throws IllegalArgumentException if {@code streamId} is negative or the payload exceeds 2^24 - 1 bytes
     */
    public HTTP2Frame(FrameType type, int streamId, byte[] payload, byte flags) {
        this.type = java.util.Objects.requireNonNull(type, "type");
        if (streamId < 0) {
            throw new IllegalArgumentException("Stream id must be non-negative: " + streamId);
        }
        this.payload = payload != null ? payload.clone() : new byte[0];
        if (this.payload.length > MAX_PAYLOAD_LENGTH) {
            throw new IllegalArgumentException("Payload too large for one frame: " + this.payload.length);
        }
        this.length = this.payload.length;
        this.flags = flags;
        this.streamId = streamId;
    }

    public FrameType getType() {
        return type;
    }

    public byte getFlags() {
        return flags;
    }

    public int getStreamId() {
        return streamId;
    }

    public int getLength() {
        return length;
    }

    /** @return a copy of the payload */
    public byte[] getPayload() {
        return payload.clone();
    }

    /** Encodes this frame as header (9 bytes, big-endian) followed by the payload. */
    public byte[] encode() {
        ByteBuffer buffer = ByteBuffer.allocate(HEADER_LENGTH + payload.length);
        buffer.put((byte) (length >>> 16)); // length, high 8 bits
        buffer.put((byte) (length >>> 8));  // length, middle 8 bits
        buffer.put((byte) length);          // length, low 8 bits
        buffer.put((byte) type.getValue());
        buffer.put(flags);
        buffer.putInt(streamId & 0x7FFFFFFF); // reserved bit must be sent as 0
        buffer.put(payload);
        return buffer.array();
    }

    /**
     * Decodes the frame at the start of {@code data}; any trailing bytes are ignored.
     *
     * @throws IllegalArgumentException if the data is truncated or the type is unknown
     */
    public static HTTP2Frame decode(byte[] data) {
        return decode(ByteBuffer.wrap(data));
    }

    /**
     * Decodes one frame starting at the buffer's position. On success the position advances past
     * the frame; on failure it is left unchanged.
     *
     * <p>Flags are preserved; the previous decoder read them and then discarded them.
     *
     * @throws IllegalArgumentException if the header or payload is incomplete or the type is unknown
     */
    public static HTTP2Frame decode(ByteBuffer buffer) {
        int start = buffer.position();
        if (buffer.remaining() < HEADER_LENGTH) {
            throw new IllegalArgumentException("Incomplete frame header: " + buffer.remaining() + " bytes");
        }
        int length = ((buffer.get(start) & 0xFF) << 16)
                | ((buffer.get(start + 1) & 0xFF) << 8)
                | (buffer.get(start + 2) & 0xFF);
        if (buffer.remaining() < HEADER_LENGTH + length) {
            throw new IllegalArgumentException("Incomplete frame payload: need " + length + " bytes");
        }
        FrameType type = FrameType.fromValue(buffer.get(start + 3) & 0xFF);
        byte flags = buffer.get(start + 4);
        int streamId = buffer.getInt(start + 5) & 0x7FFFFFFF; // clear the reserved bit

        byte[] payload = new byte[length];
        buffer.position(start + HEADER_LENGTH);
        buffer.get(payload);
        return new HTTP2Frame(type, streamId, payload, flags);
    }
}
多路复用(Multiplexing)实现
HTTP/2多路复用连接管理
// HTTP/2 multiplexed connection management (simplified teaching model)
/**
 * Runs many concurrent request/response exchanges over one {@link SocketChannel}, each on its own
 * HTTP/2 stream, so a slow response no longer holds up the others at the HTTP layer.
 *
 * <p>Simplified: there is no connection preface, no SETTINGS exchange, no HPACK, no flow control
 * and no CONTINUATION support, and requests are sent without flags.
 * NOTE(review): {@code HTTP2Request}, {@code HTTP2Response}, {@code encodeHeaders} and
 * {@code decodeHeaders} are not defined in this article. The channel is assumed to be in blocking
 * mode; a non-blocking channel would make the receiver spin on zero-byte reads.
 */
public class HTTP2MultiplexConnection {
    /** Fixed frame header size: length(3) + type(1) + flags(1) + stream id(4). */
    private static final int FRAME_HEADER_LENGTH = 9;
    /** END_STREAM flag bit on DATA and HEADERS frames. */
    private static final int FLAG_END_STREAM = 0x1;
    /** One header plus the default 16 KiB maximum payload; the buffer grows for larger frames. */
    private static final int INITIAL_RECEIVE_BUFFER_SIZE = FRAME_HEADER_LENGTH + 16_384;

    private final SocketChannel socketChannel;
    /** In-flight streams by id; an entry is removed once its response completes or fails. */
    private final Map<Integer, HTTP2Stream> streams = new ConcurrentHashMap<>();
    /** Next client stream id; client-initiated streams use odd ids (1, 3, 5, ...). */
    private final AtomicInteger nextStreamId = new AtomicInteger(1);
    private final BlockingQueue<HTTP2Frame> sendQueue = new LinkedBlockingQueue<>();
    /**
     * Handles received frames off the I/O thread. It is single-threaded on purpose: the former
     * 4-thread pool could apply two DATA frames of one stream out of order and corrupt the body.
     */
    private final ExecutorService frameProcessor =
            Executors.newSingleThreadExecutor(daemonThreadFactory("HTTP2-FrameProcessor"));
    /** Set once the connection fails or is closed; afterwards nothing is sent or handled. */
    private final java.util.concurrent.atomic.AtomicBoolean closed =
            new java.util.concurrent.atomic.AtomicBoolean(false);
    private volatile Thread sender;

    public HTTP2MultiplexConnection(SocketChannel socketChannel) {
        this.socketChannel = socketChannel;
        startFrameProcessor();
    }

    /**
     * Queues a request on a new stream without blocking.
     *
     * @return a future completed with the response, or failed on stream reset or connection loss
     */
    public CompletableFuture<HTTP2Response> sendRequest(HTTP2Request request) {
        int streamId = nextStreamId.getAndAdd(2);
        HTTP2Stream stream = new HTTP2Stream();
        streams.put(streamId, stream);
        if (closed.get()) {
            // The connection is already gone: fail now instead of queueing frames nobody will send.
            streams.remove(streamId);
            stream.fail(new IllegalStateException("HTTP/2 connection is closed"));
            return stream.getResponseFuture();
        }

        // HEADERS frame first.
        sendQueue.offer(new HTTP2Frame(
                HTTP2Frame.FrameType.HEADERS,
                streamId,
                encodeHeaders(request.getHeaders())));
        // DATA frame if there is a request body.
        if (request.getBody() != null) {
            sendQueue.offer(new HTTP2Frame(
                    HTTP2Frame.FrameType.DATA,
                    streamId,
                    request.getBody()));
        }
        return stream.getResponseFuture();
    }

    /** Closes the channel and fails every in-flight request. Safe to call more than once. */
    public void close() {
        failConnection(new IllegalStateException("HTTP/2 connection closed"));
    }

    /** Starts the daemon sender and receiver threads. */
    private void startFrameProcessor() {
        Thread senderThread = new Thread(this::runSender, "HTTP2-Sender");
        senderThread.setDaemon(true);
        sender = senderThread;
        senderThread.start();

        Thread receiverThread = new Thread(this::runReceiver, "HTTP2-Receiver");
        receiverThread.setDaemon(true);
        receiverThread.start();
    }

    /** Writes queued frames in FIFO order until the connection closes or a write fails. */
    private void runSender() {
        try {
            while (!closed.get()) {
                ByteBuffer buffer = ByteBuffer.wrap(sendQueue.take().encode());
                while (buffer.hasRemaining()) {
                    socketChannel.write(buffer);
                }
            }
        } catch (InterruptedException e) {
            Thread.currentThread().interrupt(); // interrupted by failConnection(): just exit
        } catch (java.io.IOException e) {
            failConnection(e);
        }
    }

    /**
     * Reads from the channel and dispatches complete frames, keeping a trailing partial frame
     * for the next read. Stops on EOF or error instead of spinning forever as before.
     */
    private void runReceiver() {
        ByteBuffer buffer = ByteBuffer.allocate(INITIAL_RECEIVE_BUFFER_SIZE);
        try {
            while (!closed.get()) {
                if (socketChannel.read(buffer) < 0) {
                    failConnection(new java.io.EOFException("HTTP/2 connection closed by peer"));
                    return;
                }
                buffer.flip();
                processReceivedData(buffer);
                buffer = compactOrGrow(buffer);
            }
        } catch (java.io.IOException e) {
            failConnection(e);
        }
    }

    /**
     * Consumes every complete frame in {@code buffer} (read mode) and hands it to the frame
     * processor. Stops at a trailing partial frame, leaving its bytes unread. The old loop never
     * advanced the buffer and so spun forever, re-decoding from the start of the backing array.
     */
    private void processReceivedData(ByteBuffer buffer) {
        while (buffer.remaining() >= FRAME_HEADER_LENGTH) {
            int frameLength = FRAME_HEADER_LENGTH + peekPayloadLength(buffer);
            if (buffer.remaining() < frameLength) {
                return; // partial frame: wait for more bytes
            }
            byte[] frameBytes = new byte[frameLength];
            buffer.get(frameBytes); // advances past exactly one frame
            int flags = frameBytes[4] & 0xFF;
            HTTP2Frame frame = tryDecode(frameBytes);
            if (frame == null) {
                continue; // unknown frame types are skipped, as HTTP/2 requires
            }
            try {
                frameProcessor.execute(() -> handleFrame(frame, flags));
            } catch (java.util.concurrent.RejectedExecutionException e) {
                return; // connection is shutting down
            }
        }
    }

    /** Reads the 24-bit payload length at the buffer's position without moving it. */
    private static int peekPayloadLength(ByteBuffer buffer) {
        int p = buffer.position();
        return ((buffer.get(p) & 0xFF) << 16) | ((buffer.get(p + 1) & 0xFF) << 8) | (buffer.get(p + 2) & 0xFF);
    }

    /**
     * Prepares {@code buffer} (read mode, possibly holding a partial frame) for the next read.
     * Moves leftover bytes to the front, or copies them into a larger buffer if the pending frame
     * cannot fit. Returns a buffer in write mode.
     */
    private static ByteBuffer compactOrGrow(ByteBuffer buffer) {
        int needed = buffer.remaining() >= FRAME_HEADER_LENGTH
                ? FRAME_HEADER_LENGTH + peekPayloadLength(buffer)
                : 0;
        if (needed > buffer.capacity()) {
            ByteBuffer bigger = ByteBuffer.allocate(needed);
            bigger.put(buffer);
            return bigger;
        }
        buffer.compact();
        return buffer;
    }

    /** Decodes one frame, or returns null if this model cannot represent it (e.g. unknown type). */
    private static HTTP2Frame tryDecode(byte[] frameBytes) {
        try {
            return HTTP2Frame.decode(frameBytes);
        } catch (RuntimeException e) {
            return null;
        }
    }

    /**
     * Applies one received frame to its stream (runs on the single frame-processor thread).
     * END_STREAM on HEADERS/DATA completes the response even without a Content-Length header.
     */
    private void handleFrame(HTTP2Frame frame, int flags) {
        int streamId = frame.getStreamId();
        HTTP2Stream stream = streams.get(streamId);
        if (stream == null) {
            // Stream 0 (connection control), a server push, or a stream that already finished.
            return;
        }
        try {
            switch (frame.getType()) {
                case HEADERS:
                    stream.setResponseHeaders(decodeHeaders(frame.getPayload()));
                    break;
                case DATA:
                    stream.appendData(frame.getPayload());
                    break;
                case RST_STREAM:
                    stream.reset();
                    break;
                default:
                    // SETTINGS, PING, WINDOW_UPDATE, ... are not modelled here.
                    return;
            }
            if ((flags & FLAG_END_STREAM) != 0) {
                stream.finish();
            }
        } catch (RuntimeException e) {
            stream.fail(e);
        }
        if (stream.isDone()) {
            streams.remove(streamId);
        }
    }

    /**
     * Marks the connection closed (first caller wins) and fails all in-flight streams with
     * {@code cause}. Then stops the worker threads and closes the channel.
     */
    private void failConnection(Throwable cause) {
        if (!closed.compareAndSet(false, true)) {
            return;
        }
        for (Integer id : streams.keySet()) {
            HTTP2Stream stream = streams.remove(id);
            if (stream != null) {
                stream.fail(cause);
            }
        }
        frameProcessor.shutdown();
        Thread senderThread = sender;
        if (senderThread != null) {
            senderThread.interrupt();
        }
        try {
            socketChannel.close(); // also unblocks the receiver's read()
        } catch (java.io.IOException ignored) {
            // Already tearing down; every stream has been failed with the original cause.
        }
    }

    /** Creates named daemon threads so these helpers never keep the JVM alive. */
    private static java.util.concurrent.ThreadFactory daemonThreadFactory(String name) {
        return runnable -> {
            Thread thread = new Thread(runnable, name);
            thread.setDaemon(true);
            return thread;
        };
    }

    /**
     * Accumulates one stream's response. State is mutated only on the frame-processor thread.
     * {@link #fail} may also be called from I/O threads; that is safe because it only touches the
     * thread-safe future.
     */
    private static class HTTP2Stream {
        private final CompletableFuture<HTTP2Response> responseFuture = new CompletableFuture<>();
        private final ByteArrayOutputStream dataBuffer = new ByteArrayOutputStream();
        private Map<String, String> responseHeaders;

        void setResponseHeaders(Map<String, String> headers) {
            this.responseHeaders = headers;
            completeIfContentLengthReached(); // lets "content-length: 0" responses complete here
        }

        void appendData(byte[] data) {
            dataBuffer.writeBytes(data);
            completeIfContentLengthReached();
        }

        /** Completes the response with whatever has been received (END_STREAM seen). */
        void finish() {
            complete();
        }

        void reset() {
            responseFuture.completeExceptionally(new RuntimeException("Stream reset"));
        }

        void fail(Throwable cause) {
            responseFuture.completeExceptionally(cause);
        }

        boolean isDone() {
            return responseFuture.isDone();
        }

        CompletableFuture<HTTP2Response> getResponseFuture() {
            return responseFuture;
        }

        /** Completes early once the body reaches Content-Length; otherwise waits for END_STREAM. */
        private void completeIfContentLengthReached() {
            if (responseHeaders == null) {
                return; // DATA before HEADERS: cannot judge yet (previously a NullPointerException)
            }
            String contentLength = responseHeaders.get("content-length");
            if (contentLength == null) {
                return;
            }
            try {
                if (dataBuffer.size() >= Long.parseLong(contentLength.trim())) {
                    complete();
                }
            } catch (NumberFormatException e) {
                fail(e);
            }
        }

        private void complete() {
            if (responseFuture.isDone()) {
                return;
            }
            HTTP2Response response = new HTTP2Response();
            response.setHeaders(responseHeaders);
            response.setBody(dataBuffer.toByteArray());
            responseFuture.complete(response);
        }
    }
}
查看代码
/**
 * Minimal HTTP/2 client that matches response frames to requests by stream id.
 *
 * <p>Simplified: the returned future completes as soon as the response HEADERS frame arrives, so
 * response bodies (DATA frames) are not collected.
 * NOTE(review): {@code Http2Connection}, {@code Http2Frame}, {@code Http2FrameType},
 * {@code Http2Request}, {@code Http2Response} and {@code encodeHeaders}/{@code decodeHeaders} are
 * not defined in this article, and are named differently from the HTTP2Frame class shown earlier.
 * Presumably they are supplied elsewhere — confirm.
 */
public class Http2Client {
    private final Http2Connection connection;
    /** Requests waiting for their response HEADERS, keyed by stream id; removed once answered. */
    private final ConcurrentHashMap<Integer, CompletableFuture<Http2Response>> pendingRequests =
            new ConcurrentHashMap<>();
    /** Client-initiated streams use odd ids: 1, 3, 5, ... */
    private final java.util.concurrent.atomic.AtomicInteger nextStreamId =
            new java.util.concurrent.atomic.AtomicInteger(1);

    /**
     * Without this constructor the final fields were never initialised, so the class could not
     * compile.
     *
     * @param connection the connection frames are sent on; not null
     */
    public Http2Client(Http2Connection connection) {
        this.connection = java.util.Objects.requireNonNull(connection, "connection");
    }

    /**
     * Sends a request on a new stream.
     *
     * @return a future completed when the response HEADERS arrive, or failed if sending throws
     */
    public CompletableFuture<Http2Response> sendRequest(Http2Request request) {
        int streamId = generateStreamId();
        CompletableFuture<Http2Response> future = new CompletableFuture<>();
        pendingRequests.put(streamId, future);
        try {
            // HEADERS frame
            connection.sendFrame(new Http2Frame(
                    Http2FrameType.HEADERS,
                    streamId,
                    encodeHeaders(request.getHeaders())));
            // DATA frame, only when there is a body
            if (request.getBody() != null) {
                connection.sendFrame(new Http2Frame(
                        Http2FrameType.DATA,
                        streamId,
                        request.getBody()));
            }
        } catch (RuntimeException e) {
            // Do not leave behind a pending entry that can never be answered.
            pendingRequests.remove(streamId);
            future.completeExceptionally(e);
        }
        return future;
    }

    /**
     * Handles a frame received from the server. A HEADERS frame completes and removes the matching
     * request; other frame types are ignored by this simplified client.
     */
    public void handleFrame(Http2Frame frame) {
        if (frame.getType() != Http2FrameType.HEADERS) {
            return;
        }
        // remove(), not get(): answered requests must not accumulate in the map forever.
        CompletableFuture<Http2Response> future = pendingRequests.remove(frame.getStreamId());
        if (future != null) {
            Http2Response response = new Http2Response();
            response.setHeaders(decodeHeaders(frame.getPayload()));
            future.complete(response);
        }
    }

    /** Returns the next odd client stream id. */
    private int generateStreamId() {
        return nextStreamId.getAndAdd(2);
    }
}
高性能自定义协议设计
在微服务内部通信中,自定义协议可以获得更好的性能。
自定义协议的应用场景
- 微服务内部通信:如Dubbo协议(gRPC则是基于HTTP/2 + Protobuf的RPC框架)
- 游戏服务器:实时性要求高
- 物联网设备:带宽有限
- 金融交易系统:低延迟要求
Spring Cloud中的协议实现
Spring Cloud的服务间调用(如OpenFeign)默认基于HTTP协议;下面我们来看一个类似Dubbo的自定义二进制协议实现。
查看代码
// Dubbo-style custom binary RPC protocol (illustrative)
public class MicroserviceProtocol {
    // Wire layout (big-endian, 20-byte header):
    // Magic(4) + Version(1) + Type(1) + Status(1) + Reserved(1) + Length(4) + RequestId(8) + Data(N)
    /** Magic number: a 16-bit value written as a 4-byte int so the header is 20 bytes. */
    private static final int MAGIC = 0xDABB;
    private static final byte VERSION = 1;
    private static final int HEADER_LENGTH = 20;

    /** Message kinds with their wire codes. */
    public enum MessageType {
        REQUEST((byte) 1),
        RESPONSE((byte) 2),
        HEARTBEAT((byte) 3);

        private final byte value;

        MessageType(byte value) {
            this.value = value;
        }

        public byte getValue() {
            return value;
        }

        /**
         * Looks up a type by its wire code. The previous {@code values()[code - 1]} relied on
         * ordinal order and threw ArrayIndexOutOfBoundsException for bad input.
         *
         * @throws IllegalArgumentException for unknown codes
         */
        public static MessageType fromValue(byte value) {
            for (MessageType type : values()) {
                if (type.value == value) {
                    return type;
                }
            }
            throw new IllegalArgumentException("Unknown message type: " + value);
        }
    }

    /** One protocol message; {@code data} may be null, meaning an empty body. */
    public static class RpcMessage {
        private MessageType type;
        private byte status;
        private long requestId;
        private byte[] data;

        public MessageType getType() {
            return type;
        }

        public byte getStatus() {
            return status;
        }

        public long getRequestId() {
            return requestId;
        }

        /** @return a copy of the body, or null if there is none */
        public byte[] getData() {
            return data == null ? null : data.clone();
        }

        /**
         * Serializes this message (header plus body).
         *
         * @throws IllegalStateException if the type has not been set
         */
        public byte[] serialize() {
            if (type == null) {
                throw new IllegalStateException("Message type must be set before serializing");
            }
            int length = data != null ? data.length : 0;
            ByteBuffer buffer = ByteBuffer.allocate(HEADER_LENGTH + length);
            buffer.putInt(MAGIC);
            buffer.put(VERSION);
            buffer.put(type.getValue());
            buffer.put(status);
            buffer.put((byte) 0); // reserved
            buffer.putInt(length);
            buffer.putLong(requestId);
            if (data != null) {
                buffer.put(data);
            }
            return buffer.array();
        }

        /**
         * Parses one message from {@code bytes}.
         *
         * @throws IllegalArgumentException if any of the following holds:
         *     <ul>
         *       <li>the header is incomplete;</li>
         *       <li>the magic or version is wrong;</li>
         *       <li>the type is unknown;</li>
         *       <li>the declared body length is negative or exceeds the available bytes.</li>
         *     </ul>
         *     Previously these surfaced as NegativeArraySize or BufferUnderflow exceptions.
         */
        public static RpcMessage deserialize(byte[] bytes) {
            if (bytes == null || bytes.length < HEADER_LENGTH) {
                throw new IllegalArgumentException("Incomplete header");
            }
            ByteBuffer buffer = ByteBuffer.wrap(bytes);
            int magic = buffer.getInt();
            if (magic != MAGIC) {
                throw new IllegalArgumentException("Invalid magic number");
            }
            byte version = buffer.get();
            if (version != VERSION) {
                throw new IllegalArgumentException("Unsupported protocol version: " + version);
            }
            MessageType type = MessageType.fromValue(buffer.get());
            byte status = buffer.get();
            buffer.get(); // skip reserved
            int length = buffer.getInt();
            long requestId = buffer.getLong();
            if (length < 0 || length > buffer.remaining()) {
                throw new IllegalArgumentException("Invalid body length: " + length);
            }

            RpcMessage message = new RpcMessage();
            message.type = type;
            message.status = status;
            message.requestId = requestId;
            if (length > 0) {
                message.data = new byte[length];
                buffer.get(message.data);
            }
            return message;
        }
    }

    /** Correlates responses with outstanding requests by request id. */
    public static class ProtocolHandler {
        private final Map<Long, CompletableFuture<RpcMessage>> pendingRequests = new ConcurrentHashMap<>();
        private final AtomicLong requestIdGenerator = new AtomicLong(0);

        /**
         * Sends {@code data} as a REQUEST on {@code channel}.
         *
         * @return a future completed by {@link #handleResponse}, or failed if the write throws
         */
        public CompletableFuture<RpcMessage> sendRequest(SocketChannel channel, byte[] data) {
            long requestId = requestIdGenerator.incrementAndGet();
            RpcMessage request = new RpcMessage();
            request.type = MessageType.REQUEST;
            request.requestId = requestId;
            request.data = data;

            CompletableFuture<RpcMessage> future = new CompletableFuture<>();
            pendingRequests.put(requestId, future);
            try {
                ByteBuffer buffer = ByteBuffer.wrap(request.serialize());
                // One writer at a time per channel: on a non-blocking channel a partial write
                // followed by another thread's write would interleave two messages on the wire.
                synchronized (channel) {
                    while (buffer.hasRemaining()) {
                        channel.write(buffer);
                    }
                }
            } catch (Exception e) {
                pendingRequests.remove(requestId);
                future.completeExceptionally(e);
            }
            return future;
        }

        /** Completes the request matching {@code response}'s id; unknown ids and null are ignored. */
        public void handleResponse(RpcMessage response) {
            if (response == null) {
                return;
            }
            CompletableFuture<RpcMessage> future = pendingRequests.remove(response.requestId);
            if (future != null) {
                future.complete(response);
            }
        }
    }
}
协议性能优化技巧
查看代码
// High-performance protocol tuning (illustrative)
/**
 * Batches serialized messages into pooled direct buffers before sending.
 *
 * <p>NOTE(review): {@code ObjectPool} and {@code MpscLinkedQueue} are not JDK classes (they
 * resemble Netty/JCTools utilities). {@code RpcMessage} presumably means
 * {@code MicroserviceProtocol.RpcMessage} — confirm the imports.
 */
public class OptimizedProtocol {
    /** Capacity of each pooled batch buffer. */
    private static final int BATCH_BUFFER_SIZE = 8192;

    // Pool the batch buffers so a direct ByteBuffer is not allocated on every call.
    private static final ObjectPool<ByteBuffer> bufferPool = new ObjectPool<>(
            100,
            () -> ByteBuffer.allocateDirect(BATCH_BUFFER_SIZE),
            ByteBuffer::clear
    );

    // Lock-free multi-producer/single-consumer queue (not used by processBatch below).
    private static final Queue<RpcMessage> messageQueue = new MpscLinkedQueue<>();

    /**
     * Serializes {@code messages} and sends them in as few buffer flushes as possible.
     *
     * <p>A message larger than a whole batch buffer is sent on its own. Previously it caused a
     * BufferOverflowException after a pointless flush of an empty buffer.
     */
    public void processBatch(List<RpcMessage> messages) {
        if (messages == null || messages.isEmpty()) {
            return;
        }
        ByteBuffer batchBuffer = bufferPool.acquire();
        try {
            for (RpcMessage message : messages) {
                byte[] serialized = message.serialize();
                if (batchBuffer.remaining() < serialized.length && batchBuffer.position() > 0) {
                    // Not enough room: send the current batch first.
                    flushBuffer(batchBuffer);
                    batchBuffer.clear();
                }
                if (serialized.length > batchBuffer.remaining()) {
                    // Bigger than an empty batch buffer: send it directly.
                    ByteBuffer oversized = ByteBuffer.wrap(serialized);
                    oversized.position(serialized.length); // flushBuffer() flips, so pass it in write mode
                    flushBuffer(oversized);
                    continue;
                }
                batchBuffer.put(serialized);
            }
            // Send whatever is left.
            if (batchBuffer.position() > 0) {
                flushBuffer(batchBuffer);
            }
        } finally {
            bufferPool.release(batchBuffer);
        }
    }

    /** Flips {@code buffer} (write mode to read mode) and sends its contents. */
    private void flushBuffer(ByteBuffer buffer) {
        buffer.flip();
        // Actual send implementation omitted
        // ...
    }
}
分布式系统:微服务架构和分布式通信
服务发现的性能优化
传统服务发现每次都要查询注册中心,存在以下问题:
- 网络延迟
- 注册中心压力
- 服务列表变更频繁
优化策略详解
查看代码
// 完整的服务发现优化实现
@Component
public class EnhancedServiceDiscovery {
// 多级缓存策略
@Autowired
private RedisTemplate<String, String> redisTemplate;
// 本地缓存
private final ConcurrentHashMap<String, CachedServiceList> localCache = new ConcurrentHashMap<>();
// 健康检查
private final ConcurrentHashMap<String, ServiceInstance> healthyInstances = new ConcurrentHashMap<>();
// 负载均衡器
private final LoadBalancer loadBalancer = new WeightedRoundRobinLoadBalancer();
// 事件总线
private final EventBus eventBus = new AsyncEventBus(Executors.newFixedThreadPool(2));
@PostConstruct
public void init() {
// 启动定时任务
startPeriodicTasks();
// 注册事件监听器
eventBus.register(new ServiceChangeListener());
}
// 获取服务实例(三级缓存)
public ServiceInstance getServiceInstance(String serviceName, String requestKey) {
// 1. 本地缓存
CachedServiceList cached = localCache.get(serviceName);
if (cached != null && !cached.isExpired()) {
return loadBalancer.select(cached.getInstances(), requestKey);
}
// 2. Redis缓存
List<ServiceInstance> instances = getFromRedisCache(serviceName);
if (instances != null) {
updateLocalCache(serviceName, instances);
return loadBalancer.select(instances, requestKey);
}
// 3. 注册中心
instances = getFromRegistry(serviceName);
if (instances != null) {
updateRedisCache(serviceName, instances);
updateLocalCache(serviceName, instances);
return loadBalancer.select(instances, requestKey);
}
return null;
}
// 从Redis缓存获取
private List<ServiceInstance> getFromRedisCache(String serviceName) {
try {
String key = "service:" + serviceName;
String json = redisTemplate.opsForValue().get(key);
if (json != null) {
return JsonUtils.fromJson(json, new TypeReference<List<ServiceInstance>>() {});
}
} catch (Exception e) {
logger.warn("Redis缓存获取失败", e);
}
return null;
}
// 更新Redis缓存
private void updateRedisCache(String serviceName, List<ServiceInstance> instances) {
try {
String key = "service:" + serviceName;
String json = JsonUtils.toJson(instances);
redisTemplate.opsForValue().set(key, json, Duration.ofMinutes(5));
} catch (Exception e) {
logger.warn("Redis缓存更新失败", e);
}
}
// 更新本地缓存
private void updateLocalCache(String serviceName, List<ServiceInstance> instances) {
CachedServiceList cached = new CachedServiceList(instances, System.currentTimeMillis() + 60000);
localCache.put(serviceName, cached);
// 发布缓存更新事件
eventBus.post(new ServiceListUpdateEvent(serviceName, instances));
}
// 从注册中心获取
private List<ServiceInstance> getFromRegistry(String serviceName) {
try {
return consulClient.getHealthyInstances(serviceName);
} catch (Exception e) {
logger.error("注册中心查询失败", e);
return null;
}
}
// 启动定时任务
private void startPeriodicTasks() {
ScheduledExecutorService scheduler = Executors.newScheduledThreadPool(3);
// 定时刷新缓存
scheduler.scheduleAtFixedRate(() -> {
localCache.forEach((serviceName, cached) -> {
if (cached.isExpired()) {
CompletableFuture.runAsync(() -> refreshService(serviceName));
}
});
}, 30, 30, TimeUnit.SECONDS);
// 健康检查
scheduler.scheduleAtFixedRate(this::performHealthCheck, 10, 10, TimeUnit.SECONDS);
// 监听服务变更
scheduler.submit(this::watchServiceChanges);
}
// 异步刷新服务
private void refreshService(String serviceName) {
try {
List<ServiceInstance> instances = getFromRegistry(serviceName);
if (instances != null) {
updateRedisCache(serviceName, instances);
updateLocalCache(serviceName, instances);
}
} catch (Exception e) {
logger.error("服务刷新失败: " + serviceName, e);
}
}
// 健康检查实现
private void performHealthCheck() {
localCache.forEach((serviceName, cached) -> {
cached.getInstances().parallelStream().forEach(instance -> {
CompletableFuture.runAsync(() -> {
boolean healthy = checkInstanceHealth(instance);
if (healthy) {
healthyInstances.put(instance.getInstanceId(), instance);
} else {
healthyInstances.remove(instance.getInstanceId());
// 发布不健康事件
eventBus.post(new InstanceUnhealthyEvent(instance));
}
});
});
});
}
// 检查单个实例健康状态
private boolean checkInstanceHealth(ServiceInstance instance) {
try {
String healthUrl = "http://" + instance.getHost() + ":" + instance.getPort() + "/health";
HttpURLConnection conn = (HttpURLConnection) new URL(healthUrl).openConnection();
conn.setConnectTimeout(3000);
conn.setReadTimeout(3000);
conn.setRequestMethod("GET");
return conn.getResponseCode() == 200;
} catch (Exception e) {
return false;
}
}
// 监听服务变更
private void watchServiceChanges() {
consulClient.watchServices((oldServices, newServices) -> {
newServices.forEach((serviceName, instances) -> {
CachedServiceList cached = localCache.get(serviceName);
if (cached == null || !cached.getInstances().equals(instances)) {
updateLocalCache(serviceName, instances);
logger.info("服务列表更新: {}, 实例数: {}", serviceName, instances.size());
}
});
});
}
// 缓存数据结构
/**
 * Immutable snapshot of a service's instance list together with its absolute expiry time.
 */
private static class CachedServiceList {
    private final List<ServiceInstance> instances;
    private final long expireTime;

    /**
     * @param instances  instances to snapshot; copied, so later changes to the caller's list
     *                   are not visible here
     * @param expireTime absolute expiry as epoch milliseconds (compared with
     *                   {@code System.currentTimeMillis()})
     */
    public CachedServiceList(List<ServiceInstance> instances, long expireTime) {
        // Defensive copy wrapped read-only: the snapshot is shared across threads (refresh,
        // health check, watcher), so callers must not be able to mutate it in place.
        this.instances = java.util.Collections.unmodifiableList(new ArrayList<>(instances));
        this.expireTime = expireTime;
    }

    /** @return {@code true} once the current time is past {@code expireTime}. */
    public boolean isExpired() {
        return System.currentTimeMillis() > expireTime;
    }

    /** @return a read-only view of the cached instances. */
    public List<ServiceInstance> getInstances() {
        return instances;
    }
}
// 事件监听器
private class ServiceChangeListener {
/**
 * Event-bus callback: logs the update and forwards the new instance list to the load balancer.
 *
 * @param event the service-list update event
 */
@Subscribe
public void onServiceListUpdate(ServiceListUpdateEvent event) {
    String serviceName = event.getServiceName();
    logger.info("服务列表更新事件: {}", serviceName);
    // Keep the load balancer's view in sync with the registry
    loadBalancer.updateInstances(serviceName, event.getInstances());
}
分布式限流与熔断
在微服务架构中,限流和熔断是保障系统稳定性的重要手段:限流控制单位时间内允许通过的请求数量,防止突发流量压垮服务;熔断则在下游服务持续失败时直接返回降级结果,避免故障沿调用链级联扩散。
查看代码
/**
 * Fixed-window distributed rate limiter backed by a Redis counter.
 *
 * <p>Each key allows at most {@code limit} acquisitions per {@code windowSeconds}. The window
 * starts with the first acquisition and the key expires when it ends. The check-and-increment
 * runs as a single Lua script, so it is atomic on the Redis side.
 */
@Component
public class DistributedRateLimiter {
    private final RedisTemplate<String, String> redisTemplate;

    /*
     * KEYS[1] = counter key, ARGV[1] = limit, ARGV[2] = window in seconds.
     * Returns the new count (> 0) when allowed, 0 when rejected.
     * SET ... EX creates the counter and its TTL in one command, so a counter can never be
     * left without an expiry.
     */
    private final String luaScript = """
            local key = KEYS[1]
            local limit = tonumber(ARGV[1])
            local window = tonumber(ARGV[2])
            if limit <= 0 then
                return 0
            end
            local current = redis.call('GET', key)
            if current == false then
                redis.call('SET', key, 1, 'EX', window)
                return 1
            end
            if tonumber(current) < limit then
                return redis.call('INCR', key)
            end
            return 0
            """;

    // Built once and reused: DefaultRedisScript caches the script SHA for EVALSHA.
    private final DefaultRedisScript<Long> script = new DefaultRedisScript<>();

    /**
     * @param redisTemplate template used to run the limiter script (injected by Spring)
     */
    public DistributedRateLimiter(RedisTemplate<String, String> redisTemplate) {
        this.redisTemplate = redisTemplate;
        script.setScriptText(luaScript);
        script.setResultType(Long.class);
    }

    /**
     * Tries to take one permit for {@code key}.
     *
     * @param key           rate-limit bucket key
     * @param limit         maximum permits per window; {@code <= 0} always rejects
     * @param windowSeconds window length in seconds
     * @return {@code true} if the permit was granted. If Redis fails, the result comes from
     *         {@link #tryAcquireLocally}.
     */
    public boolean tryAcquire(String key, int limit, int windowSeconds) {
        try {
            Long result = redisTemplate.execute(script,
                    Collections.singletonList(key),
                    String.valueOf(limit),
                    String.valueOf(windowSeconds));
            return result != null && result > 0;
        } catch (Exception e) {
            // Degrade to local limiting when Redis is unavailable
            return tryAcquireLocally(key, limit, windowSeconds);
        }
    }

    /**
     * Local fallback limiter.
     * NOTE(review): placeholder that always returns true, i.e. the limiter fails OPEN when
     * Redis is down; replace with a real in-process limiter before production use.
     */
    private boolean tryAcquireLocally(String key, int limit, int windowSeconds) {
        // Local rate-limit implementation
        // ...
        return true;
    }
}
// 熔断器实现
/**
 * Count-based circuit breaker with the states CLOSED, OPEN and HALF_OPEN.
 *
 * <ul>
 *   <li>CLOSED: calls run normally. {@code failureThreshold} consecutive failures open it.</li>
 *   <li>OPEN: calls go straight to the fallback until {@code timeout} ms have passed since the
 *       last failure. The next call then moves it to HALF_OPEN.</li>
 *   <li>HALF_OPEN: calls are let through as probes. {@code successThreshold} successes close
 *       it, and any single failure re-opens it.</li>
 * </ul>
 *
 * <p>State transitions are plain volatile read-then-write, not CAS. Under concurrency several
 * threads may see the OPEN→HALF_OPEN transition at once and all run probes.
 */
public class CircuitBreaker {
    private enum State { CLOSED, OPEN, HALF_OPEN }

    private volatile State state = State.CLOSED;
    private final int failureThreshold;
    private final int successThreshold;
    private final long timeout;
    private final AtomicInteger failureCount = new AtomicInteger(0);
    private final AtomicInteger successCount = new AtomicInteger(0);
    private volatile long lastFailureTime = 0;

    /**
     * @param failureThreshold consecutive failures (while CLOSED) that open the breaker
     * @param successThreshold successes (while HALF_OPEN) needed to close it again
     * @param timeout          milliseconds to stay OPEN after the last failure
     */
    public CircuitBreaker(int failureThreshold, int successThreshold, long timeout) {
        this.failureThreshold = failureThreshold;
        this.successThreshold = successThreshold;
        this.timeout = timeout;
    }

    /**
     * Runs {@code operation} unless the breaker is open.
     *
     * @param operation the protected call
     * @param fallback  result supplier used when the breaker is open or the call throws
     * @return the operation's result, or the fallback's
     */
    public <T> T execute(Supplier<T> operation, Supplier<T> fallback) {
        if (state == State.OPEN) {
            if (System.currentTimeMillis() - lastFailureTime > timeout) {
                // Cool-down elapsed: let probes through
                state = State.HALF_OPEN;
                successCount.set(0);
            } else {
                return fallback.get();
            }
        }
        try {
            T result = operation.get();
            onSuccess();
            return result;
        } catch (Exception e) {
            onFailure();
            return fallback.get();
        }
    }

    private void onSuccess() {
        failureCount.set(0);
        if (state == State.HALF_OPEN) {
            if (successCount.incrementAndGet() >= successThreshold) {
                state = State.CLOSED;
            }
        }
    }

    private void onFailure() {
        lastFailureTime = System.currentTimeMillis();
        if (state == State.HALF_OPEN) {
            // A failed probe means the dependency is still unhealthy. Re-open immediately,
            // instead of waiting for failureThreshold failures (failureCount was reset by
            // any earlier successful probe).
            state = State.OPEN;
            return;
        }
        if (failureCount.incrementAndGet() >= failureThreshold) {
            state = State.OPEN;
        }
    }
}
安全通信:生产级SSL/TLS实现
双向认证与证书管理
查看代码
// 完整的SSL/TLS服务器实现
/**
 * Blocking TLS echo server with mutual (client-certificate) authentication.
 *
 * <p>Each accepted connection is handled on a fixed pool of 50 threads. The key store, trust
 * store and passwords are hard-coded for demonstration only.
 */
public class ProductionSSLServer {
    private static final Logger logger = LoggerFactory.getLogger(ProductionSSLServer.class);

    /** TLS versions to offer, filtered against what this JVM supports before use. */
    private static final String[] PREFERRED_PROTOCOLS = {"TLSv1.3", "TLSv1.2"};

    /** AEAD-only cipher suites to offer, filtered against what this JVM supports before use. */
    private static final String[] PREFERRED_CIPHER_SUITES = {
            // TLS 1.3 cipher suites
            "TLS_AES_256_GCM_SHA384",
            "TLS_AES_128_GCM_SHA256",
            "TLS_CHACHA20_POLY1305_SHA256",
            // TLS 1.2 cipher suites (ECDHE key exchange)
            "TLS_ECDHE_RSA_WITH_AES_256_GCM_SHA384",
            "TLS_ECDHE_RSA_WITH_AES_128_GCM_SHA256",
            "TLS_ECDHE_ECDSA_WITH_AES_256_GCM_SHA384",
            "TLS_ECDHE_ECDSA_WITH_AES_128_GCM_SHA256"
    };

    private final int port;
    private final SSLContext sslContext;
    private final ExecutorService threadPool;
    private ServerSocket serverSocket;
    private volatile boolean running = false;

    /**
     * @param port TCP port to listen on once {@link #start()} is called
     * @throws Exception if the key store or trust store cannot be loaded
     */
    public ProductionSSLServer(int port) throws Exception {
        this.port = port;
        this.sslContext = createSSLContext();
        this.threadPool = Executors.newFixedThreadPool(50);
    }

    /**
     * Builds the SSL context from {@code server.p12} (own key + certificate) and
     * {@code truststore.jks} (CAs trusted to sign client certificates).
     */
    private SSLContext createSSLContext() throws Exception {
        // 1. Load the server private key and certificate
        KeyStore keyStore = KeyStore.getInstance("PKCS12");
        try (FileInputStream keyStoreStream = new FileInputStream("server.p12")) {
            keyStore.load(keyStoreStream, "server123".toCharArray());
        }
        // 2. Initialise the key manager
        KeyManagerFactory kmf = KeyManagerFactory.getInstance("SunX509");
        kmf.init(keyStore, "server123".toCharArray());
        // 3. Load the trust store (used to verify client certificates)
        KeyStore trustStore = KeyStore.getInstance("JKS");
        try (FileInputStream trustStoreStream = new FileInputStream("truststore.jks")) {
            trustStore.load(trustStoreStream, "trust123".toCharArray());
        }
        // 4. Initialise the trust manager
        TrustManagerFactory tmf = TrustManagerFactory.getInstance("SunX509");
        tmf.init(trustStore);
        // 5. Create the SSL context
        SSLContext context = SSLContext.getInstance("TLSv1.3");
        context.init(kmf.getKeyManagers(), tmf.getTrustManagers(), new SecureRandom());
        return context;
    }

    /**
     * Binds the port and runs the accept loop on the calling thread until {@link #stop()}.
     *
     * @throws IllegalStateException if this JVM supports none of the preferred protocols or
     *                               cipher suites
     */
    public void start() throws Exception {
        SSLServerSocketFactory factory = sslContext.getServerSocketFactory();
        serverSocket = factory.createServerSocket(port);
        // Narrow to SSLServerSocket for TLS-specific configuration
        if (serverSocket instanceof SSLServerSocket) {
            SSLServerSocket sslServerSocket = (SSLServerSocket) serverSocket;
            configureSslSocket(sslServerSocket);
        }
        running = true;
        logger.info("SSL服务器启动,监听端口: {}", port);
        // Main loop: accept clients
        while (running) {
            try {
                Socket clientSocket = serverSocket.accept();
                // Handle each client on the worker pool
                threadPool.submit(() -> handleClient(clientSocket));
            } catch (Exception e) {
                // stop() closes the socket, which makes accept() throw; only log real failures
                if (running) {
                    logger.error("接受客户端连接失败", e);
                }
            }
        }
    }

    /**
     * Requires client certificates and restricts protocols and cipher suites.
     * {@code setEnabledProtocols}/{@code setEnabledCipherSuites} throw IllegalArgumentException
     * for any name the JVM does not support (e.g. TLSv1.3 or ChaCha20 on older JDKs), so only
     * the supported subset is applied.
     */
    private void configureSslSocket(SSLServerSocket sslServerSocket) {
        // 1. Require a client certificate (mutual TLS)
        sslServerSocket.setNeedClientAuth(true);
        // 2. Enabled TLS protocol versions
        sslServerSocket.setEnabledProtocols(
                retainSupported(PREFERRED_PROTOCOLS, sslServerSocket.getSupportedProtocols(), "TLS协议"));
        // 3. Enabled cipher suites
        sslServerSocket.setEnabledCipherSuites(
                retainSupported(PREFERRED_CIPHER_SUITES, sslServerSocket.getSupportedCipherSuites(), "加密套件"));
        logger.info("SSL配置完成 - 协议: {}, 加密套件: {}",
                Arrays.toString(sslServerSocket.getEnabledProtocols()),
                Arrays.toString(sslServerSocket.getEnabledCipherSuites()));
    }

    /**
     * Returns the entries of {@code preferred} (in order) that also appear in {@code supported}.
     *
     * @throws IllegalStateException if none are supported: an empty list would make every
     *                               handshake fail with a much less obvious error
     */
    private static String[] retainSupported(String[] preferred, String[] supported, String what) {
        String[] retained = Arrays.stream(preferred)
                .filter(name -> Arrays.asList(supported).contains(name))
                .toArray(String[]::new);
        if (retained.length == 0) {
            throw new IllegalStateException("JVM不支持任何已配置的" + what + ": " + Arrays.toString(preferred));
        }
        return retained;
    }

    /**
     * Runs the handshake, applies extra certificate checks, then echoes lines until the client
     * sends "bye" or disconnects. The socket is closed on exit.
     */
    private void handleClient(Socket clientSocket) {
        try (SSLSocket sslSocket = (SSLSocket) clientSocket;
             BufferedReader reader = new BufferedReader(
                     new InputStreamReader(sslSocket.getInputStream()));
             PrintWriter writer = new PrintWriter(sslSocket.getOutputStream(), true)) {
            // 1. TLS handshake (the trust manager already rejects untrusted client certificates)
            sslSocket.startHandshake();
            // 2. Additional application-level certificate checks
            if (!validateClientCertificate(sslSocket)) {
                logger.warn("客户端证书验证失败: {}", sslSocket.getRemoteSocketAddress());
                return;
            }
            // 3. Log session details
            logSSLSessionInfo(sslSocket);
            // 4. Business logic: simple echo
            String request;
            while ((request = reader.readLine()) != null) {
                logger.info("收到请求: {}", request);
                String response = "Echo: " + request;
                writer.println(response);
                if ("bye".equalsIgnoreCase(request)) {
                    break;
                }
            }
        } catch (Exception e) {
            logger.error("处理客户端连接失败", e);
        }
    }

    /**
     * Checks the leaf client certificate: present, within its validity period, and (if
     * key-usage is set) allowed for digital signatures. It then delegates to
     * {@link #isAuthorizedClient}.
     *
     * @return {@code false} on any failed check or exception
     */
    private boolean validateClientCertificate(SSLSocket sslSocket) {
        try {
            SSLSession session = sslSocket.getSession();
            Certificate[] peerCerts = session.getPeerCertificates();
            if (peerCerts.length == 0) {
                logger.warn("客户端未提供证书");
                return false;
            }
            // The first entry is the client's own (leaf) certificate
            X509Certificate clientCert = (X509Certificate) peerCerts[0];
            // 1. Validity period (throws if expired / not yet valid)
            clientCert.checkValidity();
            // 2. Subject
            String subject = clientCert.getSubjectDN().getName();
            logger.info("客户端证书主题: {}", subject);
            // 3. Key usage; index 0 is digitalSignature
            boolean[] keyUsage = clientCert.getKeyUsage();
            if (keyUsage != null && keyUsage.length > 0) {
                if (!keyUsage[0]) {
                    logger.warn("客户端证书不支持数字签名");
                    return false;
                }
            }
            // 4. Custom authorisation (serial number, issuer, ...)
            return isAuthorizedClient(clientCert);
        } catch (Exception e) {
            logger.error("验证客户端证书失败", e);
            return false;
        }
    }

    /**
     * Application-specific authorisation hook. Currently it only logs issuer and serial
     * number and accepts every client.
     */
    private boolean isAuthorizedClient(X509Certificate certificate) {
        String issuer = certificate.getIssuerDN().getName();
        logger.info("证书发行者: {}", issuer);
        BigInteger serialNumber = certificate.getSerialNumber();
        logger.info("证书序列号: {}", serialNumber);
        // Add real checks here (database lookup, external service, ...)
        return true; // simplified example: always authorised
    }

    /** Logs protocol, cipher suite, session id/timestamps and both certificate subjects. */
    private void logSSLSessionInfo(SSLSocket sslSocket) {
        SSLSession session = sslSocket.getSession();
        logger.info("SSL会话信息:");
        logger.info(" 协议版本: {}", session.getProtocol());
        logger.info(" 加密套件: {}", session.getCipherSuite());
        logger.info(" 会话ID: {}", bytesToHex(session.getId()));
        logger.info(" 创建时间: {}", new Date(session.getCreationTime()));
        logger.info(" 最后访问时间: {}", new Date(session.getLastAccessedTime()));
        try {
            Certificate[] localCerts = session.getLocalCertificates();
            if (localCerts != null && localCerts.length > 0) {
                X509Certificate serverCert = (X509Certificate) localCerts[0];
                logger.info(" 服务器证书: {}", serverCert.getSubjectDN().getName());
            }
            Certificate[] peerCerts = session.getPeerCertificates();
            if (peerCerts.length > 0) {
                X509Certificate clientCert = (X509Certificate) peerCerts[0];
                logger.info(" 客户端证书: {}", clientCert.getSubjectDN().getName());
            }
        } catch (Exception e) {
            logger.warn("获取证书信息失败", e);
        }
    }

    /** Uppercase hex encoding, two characters per byte. */
    private String bytesToHex(byte[] bytes) {
        StringBuilder result = new StringBuilder();
        for (byte b : bytes) {
            result.append(String.format("%02X", b));
        }
        return result.toString();
    }

    /**
     * Stops accepting connections and waits up to 5 s for in-flight clients before forcing
     * the pool down.
     */
    public void stop() {
        running = false;
        try {
            if (serverSocket != null && !serverSocket.isClosed()) {
                serverSocket.close();
            }
        } catch (Exception e) {
            logger.error("关闭服务器套接字失败", e);
        }
        threadPool.shutdown();
        try {
            if (!threadPool.awaitTermination(5, TimeUnit.SECONDS)) {
                threadPool.shutdownNow();
            }
        } catch (InterruptedException e) {
            threadPool.shutdownNow();
            // Preserve the interrupt so the caller (e.g. a shutdown hook) can observe it
            Thread.currentThread().interrupt();
        }
    }

    /** Starts the server on port 8443 and registers a JVM shutdown hook that calls {@link #stop()}. */
    public static void main(String[] args) {
        try {
            ProductionSSLServer server = new ProductionSSLServer(8443);
            Runtime.getRuntime().addShutdownHook(new Thread(server::stop));
            server.start();
        } catch (Exception e) {
            logger.error("启动SSL服务器失败", e);
        }
    }
}
// SSL客户端实现
/**
 * Interactive mutual-TLS client for {@code ProductionSSLServer}. It reads lines from stdin,
 * sends them, and prints the server's reply.
 */
public class SSLClient {
    private static final Logger logger = LoggerFactory.getLogger(SSLClient.class);
    private final String host;
    private final int port;
    private final SSLContext sslContext;

    /**
     * @throws Exception if {@code client.p12} or {@code truststore.jks} cannot be loaded
     */
    public SSLClient(String host, int port) throws Exception {
        this.host = host;
        this.port = port;
        this.sslContext = createClientSSLContext();
    }

    /**
     * Builds the client SSL context from {@code client.p12} (own key + certificate, presented
     * to the server) and {@code truststore.jks} (CAs trusted to sign the server certificate).
     */
    private SSLContext createClientSSLContext() throws Exception {
        // 1. Load the client certificate and private key
        KeyStore clientKeyStore = KeyStore.getInstance("PKCS12");
        try (FileInputStream keyStoreStream = new FileInputStream("client.p12")) {
            clientKeyStore.load(keyStoreStream, "client123".toCharArray());
        }
        KeyManagerFactory kmf = KeyManagerFactory.getInstance("SunX509");
        kmf.init(clientKeyStore, "client123".toCharArray());
        // 2. Load the trust store (used to verify the server certificate)
        KeyStore trustStore = KeyStore.getInstance("JKS");
        try (FileInputStream trustStoreStream = new FileInputStream("truststore.jks")) {
            trustStore.load(trustStoreStream, "trust123".toCharArray());
        }
        TrustManagerFactory tmf = TrustManagerFactory.getInstance("SunX509");
        tmf.init(trustStore);
        // 3. Create the SSL context
        SSLContext context = SSLContext.getInstance("TLSv1.3");
        context.init(kmf.getKeyManagers(), tmf.getTrustManagers(), new SecureRandom());
        return context;
    }

    /**
     * Connects, performs the handshake and runs the interactive loop. The loop ends when the
     * user sends "bye", stdin reaches EOF, or the server closes the connection.
     *
     * @throws Exception on connection or handshake failure (logged, then rethrown)
     */
    public void connect() throws Exception {
        SSLSocketFactory factory = sslContext.getSocketFactory();
        // Deliberately not closed: closing a Scanner over System.in closes stdin for the JVM.
        Scanner scanner = new Scanner(System.in);
        try (SSLSocket sslSocket = (SSLSocket) factory.createSocket(host, port);
             BufferedReader reader = new BufferedReader(
                     new InputStreamReader(sslSocket.getInputStream()));
             PrintWriter writer = new PrintWriter(sslSocket.getOutputStream(), true)) {
            sslSocket.setEnabledProtocols(new String[]{"TLSv1.3", "TLSv1.2"});
            sslSocket.startHandshake();
            logger.info("SSL连接建立成功");
            logSSLSessionInfo(sslSocket);
            while (true) {
                System.out.print("请输入消息 (输入 'bye' 退出): ");
                if (!scanner.hasNextLine()) {
                    // stdin reached EOF (Ctrl-D / exhausted pipe): nothing more to send
                    break;
                }
                String input = scanner.nextLine();
                writer.println(input);
                String response = reader.readLine();
                if (response == null) {
                    // Server closed the connection; readLine() would return null forever
                    logger.warn("服务器已关闭连接");
                    break;
                }
                System.out.println("服务器响应: " + response);
                if ("bye".equalsIgnoreCase(input)) {
                    break;
                }
            }
        } catch (Exception e) {
            logger.error("SSL连接失败", e);
            throw e;
        }
    }

    /** Logs the negotiated protocol and cipher suite. */
    private void logSSLSessionInfo(SSLSocket sslSocket) {
        SSLSession session = sslSocket.getSession();
        logger.info("SSL会话建立:");
        logger.info(" 协议: {}", session.getProtocol());
        logger.info(" 加密套件: {}", session.getCipherSuite());
    }

    /** Connects to localhost:8443. */
    public static void main(String[] args) {
        try {
            SSLClient client = new SSLClient("localhost", 8443);
            client.connect();
        } catch (Exception e) {
            logger.error("客户端连接失败", e);
        }
    }
}
// 证书生成工具类
/**
 * Prints a bash/OpenSSL/keytool script that builds a self-signed test PKI. The script creates a
 * CA, a server and a client certificate signed by it, PKCS12 bundles for both, and a JKS trust
 * store holding the CA. Production deployments should use CA-issued certificates instead.
 */
public class CertificateGenerator {
    /** The generation script, printed verbatim by {@link #generateCertificates()}. */
    private static final String GENERATION_SCRIPT = """
            #!/bin/bash
            # 1. 生成CA私钥
            openssl genrsa -out ca-key.pem 2048
            # 2. 生成CA证书
            openssl req -new -x509 -key ca-key.pem -out ca-cert.pem -days 365 \\
            -subj "/C=CN/ST=Beijing/L=Beijing/O=TestCA/CN=TestCA"
            # 3. 生成服务器私钥
            openssl genrsa -out server-key.pem 2048
            # 4. 生成服务器证书请求
            openssl req -new -key server-key.pem -out server-req.pem \\
            -subj "/C=CN/ST=Beijing/L=Beijing/O=TestServer/CN=localhost"
            # 5. 用CA签发服务器证书
            openssl x509 -req -in server-req.pem -CA ca-cert.pem -CAkey ca-key.pem \\
            -out server-cert.pem -days 365 -CAcreateserial
            # 6. 生成客户端私钥
            openssl genrsa -out client-key.pem 2048
            # 7. 生成客户端证书请求
            openssl req -new -key client-key.pem -out client-req.pem \\
            -subj "/C=CN/ST=Beijing/L=Beijing/O=TestClient/CN=client"
            # 8. 用CA签发客户端证书
            openssl x509 -req -in client-req.pem -CA ca-cert.pem -CAkey ca-key.pem \\
            -out client-cert.pem -days 365 -CAcreateserial
            # 9. 生成服务器PKCS12格式证书
            openssl pkcs12 -export -in server-cert.pem -inkey server-key.pem \\
            -out server.p12 -name "server" -passout pass:server123
            # 10. 生成客户端PKCS12格式证书
            openssl pkcs12 -export -in client-cert.pem -inkey client-key.pem \\
            -out client.p12 -name "client" -passout pass:client123
            # 11. 创建Java truststore
            keytool -import -file ca-cert.pem -alias ca -keystore truststore.jks \\
            -storepass trust123 -noprompt
            echo "证书生成完成!"
            """;

    /** Writes the generation script to standard output (it does not execute it). */
    public static void generateCertificates() {
        System.out.println(GENERATION_SCRIPT);
    }
}
// 使用示例和测试
/**
 * Prints the steps for running the SSL/TLS demo and the key configuration, then prints the
 * certificate-generation script.
 */
public class SSLDemo {
    public static void main(String[] args) {
        printSection("=== SSL/TLS 使用步骤 ===",
                "1. 生成证书文件(使用 CertificateGenerator)",
                "2. 启动 ProductionSSLServer",
                "3. 运行 SSLClient 连接服务器",
                "4. 进行加密通信");
        printSection("=== 关键配置说明 ===",
                "- 服务器证书: server.p12 (密码: server123)",
                "- 客户端证书: client.p12 (密码: client123)",
                "- 信任库: truststore.jks (密码: trust123)",
                "- 支持协议: TLSv1.3, TLSv1.2",
                "- 认证方式: 双向认证");
        CertificateGenerator.generateCertificates();
    }

    /** Prints a title, its lines, and one blank separator line. */
    private static void printSection(String title, String... lines) {
        System.out.println(title);
        for (String line : lines) {
            System.out.println(line);
        }
        System.out.println();
    }
}
证书热更新机制
在生产环境中,证书需要定期轮换,且轮换过程不能中断正在运行的服务。下面的实现会定期检查证书文件的修改时间,发现变化后重新创建SSL上下文并原子地替换旧的上下文。
查看代码
/**
 * Holds the current {@link SSLContext} and reloads it when the certificate file changes.
 * The certificate file's modification time is checked once per hour.
 *
 * <p>NOTE(review): relies on a {@code logger} field and a {@code createSSLContext(String, String)}
 * helper that are not part of this snippet.
 */
@Component
public class CertificateManager {
    private volatile SSLContext currentSSLContext;
    // Daemon thread, so this background poller never keeps the JVM alive at shutdown.
    private final ScheduledExecutorService scheduler = Executors.newScheduledThreadPool(1, runnable -> {
        Thread thread = new Thread(runnable, "certificate-reloader");
        thread.setDaemon(true);
        return thread;
    });
    private final String certPath;
    private final String keyPath;
    // Modification time of certPath as observed just before the last successful load.
    private volatile long lastModified = 0;

    /**
     * @param certPath path of the certificate file (its mtime drives reloads)
     * @param keyPath  path of the private-key file
     *                 NOTE(review): under Spring, bind both with {@code @Value(...)} or create
     *                 the bean in a {@code @Bean} method; plain Strings are not autowirable.
     */
    public CertificateManager(String certPath, String keyPath) {
        this.certPath = certPath;
        this.keyPath = keyPath;
    }

    /** Loads the initial certificate (fails startup if it cannot) and schedules hourly checks. */
    @PostConstruct
    public void init() throws Exception {
        loadCertificate();
        scheduler.scheduleAtFixedRate(this::checkCertificateUpdate, 1, 1, TimeUnit.HOURS);
    }

    /**
     * Builds a new context and publishes it through the volatile field. Readers see either the
     * old or the new context, never a partially built one.
     */
    private void loadCertificate() throws Exception {
        // Capture the timestamp BEFORE reading the files. If the file is replaced while we are
        // loading, the newer mtime will still be detected on the next check.
        long observedModified = new File(certPath).lastModified();
        try {
            SSLContext newContext = createSSLContext(certPath, keyPath);
            this.currentSSLContext = newContext;
            this.lastModified = observedModified;
            logger.info("证书加载成功");
        } catch (Exception e) {
            logger.error("证书加载失败", e);
            throw e;
        }
    }

    /**
     * Scheduled task: reloads if the certificate file is newer than the last load. Errors are
     * logged, and the timestamp is left unchanged so the next run retries.
     */
    private void checkCertificateUpdate() {
        try {
            File certFile = new File(certPath);
            if (certFile.lastModified() > lastModified) {
                logger.info("检测到证书更新,开始重新加载");
                loadCertificate();
            }
        } catch (Exception e) {
            logger.error("证书更新检查失败", e);
        }
    }

    /** @return the most recently loaded context */
    public SSLContext getSSLContext() {
        return currentSSLContext;
    }
}
高性能网络:Netty实战与优化
基于Netty的RPC框架实现
RPC协议定义
// ======= 1. RPC协议定义 =======
/**
 * RPC wire format: a fixed 17-byte header followed by {@code Length} body bytes.
 * <pre>
 * +---------+-----+------+-------+---------+---------+--------+----------+
 * |  Magic  | Ver | Type | Codec |  ReqId  | Length  | Status | Reserved |
 * |   4B    | 1B  |  1B  |  1B   |   4B    |   4B    |   1B   |    1B    |
 * +---------+-----+------+-------+---------+---------+--------+----------+
 * |                Request/Response Body (Length bytes)                  |
 * +----------------------------------------------------------------------+
 * </pre>
 */
public class RPCProtocol {
    // Protocol constants
    public static final int MAGIC_NUMBER = 0x12345678;
    public static final byte VERSION = 1;
    /**
     * Header size in bytes: magic(4) + version(1) + type(1) + codec(1) + requestId(4)
     * + length(4) + status(1) + reserved(1) = 17. This must match exactly what the encoder
     * writes, or a decoder that waits for HEADER_LENGTH bytes will read past the buffer.
     */
    public static final int HEADER_LENGTH = 17;

    /** Message type carried in the header's Type byte. */
    public enum MessageType {
        REQUEST((byte) 1),
        RESPONSE((byte) 2),
        HEARTBEAT((byte) 3);
        private final byte value;
        MessageType(byte value) { this.value = value; }
        public byte getValue() { return value; }
    }

    /** Body serialization format carried in the header's Codec byte. */
    public enum SerializationType {
        JSON((byte) 1),
        PROTOBUF((byte) 2),
        HESSIAN((byte) 3);
        private final byte value;
        SerializationType(byte value) { this.value = value; }
        public byte getValue() { return value; }
    }

    /** Response status carried in the header's Status byte. */
    public enum ResponseStatus {
        SUCCESS((byte) 0),
        TIMEOUT((byte) 1),
        SERVICE_NOT_FOUND((byte) 2),
        METHOD_NOT_FOUND((byte) 3),
        INTERNAL_ERROR((byte) 4);
        private final byte value;
        ResponseStatus(byte value) { this.value = value; }
        public byte getValue() { return value; }
    }
}
// ======= 2. RPC消息对象 =======
/**
 * RPC request message sent from client to server.
 *
 * <p>Lombok {@code @Data} generates getters, setters, equals/hashCode and toString. The
 * constructor annotations generate a no-arg constructor and an all-args constructor (fields in
 * declaration order).
 *
 * <p>NOTE(review): with a JSON serializer, {@code parameters} is deserialized as a plain
 * {@code Object[]}. Arguments that are not JSON-native (POJOs, long vs int, ...) may therefore
 * not come back as the types listed in {@code parameterTypes}; confirm against the serializer
 * configuration.
 */
@Data
@NoArgsConstructor
@AllArgsConstructor
public class RPCRequest {
    private String requestId; // request ID; the client matches responses to pending calls by it
    private String serviceName; // service name (key into the server's service registry)
    private String methodName; // name of the method to invoke
    private Class<?>[] parameterTypes; // parameter types, used for reflective method lookup
    private Object[] parameters; // argument values
    private long timestamp; // creation timestamp (epoch millis)
    private Map<String, String> attachments = new HashMap<>(); // extra metadata
}
/**
 * RPC response message sent from server to client.
 *
 * <p>Lombok generates accessors and constructors as for {@code RPCRequest}. {@code status}
 * defaults to SUCCESS and {@code attachments} to an empty map when the no-arg constructor is
 * used.
 *
 * <p>NOTE(review): {@code result} is a plain {@code Object} and {@code error} a
 * {@code Throwable}. Through a JSON serializer, neither is guaranteed to keep its concrete
 * runtime type; verify before relying on {@code instanceof} checks on the client.
 */
@Data
@NoArgsConstructor
@AllArgsConstructor
public class RPCResponse {
    private String requestId; // ID of the request this responds to
    private Object result; // return value of the invoked method
    private Throwable error; // failure cause when status is not SUCCESS
    private long timestamp; // creation timestamp (epoch millis)
    private RPCProtocol.ResponseStatus status = RPCProtocol.ResponseStatus.SUCCESS;
    private Map<String, String> attachments = new HashMap<>(); // extra metadata
}
// ======= 3. 消息编解码器 =======
/**
* RPC消息编码器
*/
public class RPCEncoder extends MessageToByteEncoder<Object> {
private final Serializer serializer;
public RPCEncoder(Serializer serializer) {
this.serializer = serializer;
}
@Override
protected void encode(ChannelHandlerContext ctx, Object msg, ByteBuf out) throws Exception {
if (msg instanceof RPCRequest) {
encodeRequest(ctx, (RPCRequest) msg, out);
} else if (msg instanceof RPCResponse) {
encodeResponse(ctx, (RPCResponse) msg, out);
} else {
throw new IllegalArgumentException("不支持的消息类型: " + msg.getClass());
}
}
private void encodeRequest(ChannelHandlerContext ctx, RPCRequest request, ByteBuf out) throws Exception {
// 序列化请求体
byte[] body = serializer.serialize(request);
// 写入协议头
out.writeInt(RPCProtocol.MAGIC_NUMBER); // Magic Number
out.writeByte(RPCProtocol.VERSION); // Version
out.writeByte(RPCProtocol.MessageType.REQUEST.getValue()); // Message Type
out.writeByte(serializer.getType().getValue()); // Serialization Type
out.writeInt(request.getRequestId().hashCode()); // Request ID
out.writeInt(body.length); // Body Length
out.writeByte(RPCProtocol.ResponseStatus.SUCCESS.getValue()); // Status (unused for request)
out.writeByte(0); // Reserved
// 写入请求体
out.writeBytes(body);
}
private void encodeResponse(ChannelHandlerContext ctx, RPCResponse response, ByteBuf out) throws Exception {
// 序列化响应体
byte[] body = serializer.serialize(response);
// 写入协议头
out.writeInt(RPCProtocol.MAGIC_NUMBER); // Magic Number
out.writeByte(RPCProtocol.VERSION); // Version
out.writeByte(RPCProtocol.MessageType.RESPONSE.getValue()); // Message Type
out.writeByte(serializer.getType().getValue()); // Serialization Type
out.writeInt(response.getRequestId().hashCode()); // Request ID
out.writeInt(body.length); // Body Length
out.writeByte(response.getStatus().getValue()); // Status
out.writeByte(0); // Reserved
// 写入响应体
out.writeBytes(body);
}
}
/**
 * Netty decoder for RPC frames. It waits until a full header and body are buffered, validates
 * magic/version/length, then deserializes the body with the codec named in the header.
 */
public class RPCDecoder extends ByteToMessageDecoder {
    /**
     * Header bytes actually consumed below: magic(4) ver(1) type(1) codec(1) reqId(4) len(4)
     * status(1) reserved(1). Must mirror the encoder's layout.
     */
    private static final int FRAME_HEADER_BYTES = 17;
    /** Default upper bound for a single message body: 16 MiB. */
    public static final int DEFAULT_MAX_BODY_LENGTH = 16 * 1024 * 1024;

    private final Map<Byte, Serializer> serializers;
    private final int maxBodyLength;

    /** @param serializers codecs keyed by {@code SerializationType.getValue()} */
    public RPCDecoder(Map<Byte, Serializer> serializers) {
        this(serializers, DEFAULT_MAX_BODY_LENGTH);
    }

    /**
     * @param serializers   codecs keyed by {@code SerializationType.getValue()}
     * @param maxBodyLength largest body (bytes) accepted before the frame is rejected
     */
    public RPCDecoder(Map<Byte, Serializer> serializers, int maxBodyLength) {
        if (maxBodyLength <= 0) {
            throw new IllegalArgumentException("maxBodyLength must be > 0: " + maxBodyLength);
        }
        this.serializers = serializers;
        this.maxBodyLength = maxBodyLength;
    }

    /**
     * @throws IllegalArgumentException on bad magic, unsupported version, invalid body length,
     *                                  or unknown codec/message type
     */
    @Override
    protected void decode(ChannelHandlerContext ctx, ByteBuf in, List<Object> out) throws Exception {
        // Wait until the whole header is available
        if (in.readableBytes() < FRAME_HEADER_BYTES) {
            return;
        }
        // Remember where this frame starts so we can rewind if the body is incomplete
        in.markReaderIndex();
        int magicNumber = in.readInt();
        if (magicNumber != RPCProtocol.MAGIC_NUMBER) {
            in.resetReaderIndex();
            throw new IllegalArgumentException("Invalid magic number: " + magicNumber);
        }
        byte version = in.readByte();
        byte messageType = in.readByte();
        byte serializationType = in.readByte();
        in.skipBytes(Integer.BYTES); // header request id (hash only; the body carries the full id)
        int bodyLength = in.readInt();
        in.skipBytes(2);             // status + reserved (status is also present in the body)
        if (version != RPCProtocol.VERSION) {
            throw new IllegalArgumentException("Unsupported version: " + version);
        }
        // The length comes from the peer: reject negatives (NegativeArraySizeException) and
        // oversized values (which would make us buffer without bound while "waiting").
        if (bodyLength < 0 || bodyLength > maxBodyLength) {
            throw new IllegalArgumentException("Invalid body length: " + bodyLength);
        }
        // Body not fully buffered yet: rewind and wait for more bytes
        if (in.readableBytes() < bodyLength) {
            in.resetReaderIndex();
            return;
        }
        byte[] body = new byte[bodyLength];
        in.readBytes(body);
        Serializer serializer = serializers.get(serializationType);
        if (serializer == null) {
            throw new IllegalArgumentException("Unsupported serialization type: " + serializationType);
        }
        out.add(deserializeMessage(messageType, body, serializer));
    }

    /** Maps the header's message type to the concrete class to deserialize into. */
    private Object deserializeMessage(byte messageType, byte[] body, Serializer serializer) throws Exception {
        if (messageType == RPCProtocol.MessageType.REQUEST.getValue()) {
            return serializer.deserialize(body, RPCRequest.class);
        } else if (messageType == RPCProtocol.MessageType.RESPONSE.getValue()) {
            return serializer.deserialize(body, RPCResponse.class);
        } else {
            throw new IllegalArgumentException("Unknown message type: " + messageType);
        }
    }
}
// ======= 4. 序列化接口 =======
/**
 * Pluggable body codec for RPC messages.
 */
public interface Serializer {
    /** @return the codec identifier written into (and matched from) the frame header */
    RPCProtocol.SerializationType getType();

    /** Encodes {@code obj} to bytes. */
    byte[] serialize(Object obj) throws Exception;

    /** Decodes {@code data} into an instance of {@code clazz}. */
    <T> T deserialize(byte[] data, Class<T> clazz) throws Exception;
}
/**
 * {@link Serializer} backed by a default-configured Jackson {@code ObjectMapper}.
 */
public class JsonSerializer implements Serializer {
    // One mapper per serializer instance; ObjectMapper is reusable across calls.
    private final ObjectMapper mapper = new ObjectMapper();

    @Override
    public RPCProtocol.SerializationType getType() {
        return RPCProtocol.SerializationType.JSON;
    }

    @Override
    public byte[] serialize(Object obj) throws Exception {
        return mapper.writeValueAsBytes(obj);
    }

    @Override
    public <T> T deserialize(byte[] data, Class<T> clazz) throws Exception {
        return mapper.readValue(data, clazz);
    }
}
// ======= 5. RPC服务器实现 =======
/**
 * Netty-based RPC server. Service objects are registered by name, and requests are dispatched
 * on a separate business thread pool sized at twice the CPU count.
 */
public class RPCServer {
    private static final Logger logger = LoggerFactory.getLogger(RPCServer.class);

    private final int port;
    private final Map<String, Object> serviceMap = new ConcurrentHashMap<>();
    private final Map<Byte, Serializer> serializers = new HashMap<>();
    private final ExecutorService businessExecutor;
    private EventLoopGroup bossGroup;
    private EventLoopGroup workerGroup;
    private Channel serverChannel;

    /** @param port TCP port to bind in {@link #start()} */
    public RPCServer(int port) {
        this.port = port;
        int poolSize = Runtime.getRuntime().availableProcessors() * 2;
        this.businessExecutor = Executors.newFixedThreadPool(poolSize);
        registerSerializer(new JsonSerializer());
    }

    /** Indexes a codec by its header byte so the decoder can look it up. */
    private void registerSerializer(Serializer serializer) {
        serializers.put(serializer.getType().getValue(), serializer);
    }

    /**
     * Exposes {@code serviceImpl} under {@code serviceName} (clients use the interface's
     * fully qualified name).
     */
    public void registerService(String serviceName, Object serviceImpl) {
        serviceMap.put(serviceName, serviceImpl);
        logger.info("服务注册成功: {}", serviceName);
    }

    /**
     * Binds the port and blocks until the server channel closes. {@link #shutdown()} always
     * runs on exit.
     */
    public void start() throws Exception {
        bossGroup = new NioEventLoopGroup(1);
        workerGroup = new NioEventLoopGroup();
        try {
            ServerBootstrap bootstrap = new ServerBootstrap()
                    .group(bossGroup, workerGroup)
                    .channel(NioServerSocketChannel.class)
                    .option(ChannelOption.SO_BACKLOG, 128)
                    .childOption(ChannelOption.SO_KEEPALIVE, true)
                    .childOption(ChannelOption.TCP_NODELAY, true)
                    .childHandler(pipelineInitializer());
            serverChannel = bootstrap.bind(port).sync().channel();
            logger.info("RPC服务器启动成功,监听端口: {}", port);
            serverChannel.closeFuture().sync();
        } finally {
            shutdown();
        }
    }

    /**
     * Per-connection pipeline: a 60 s reader-idle detector, then the codec pair, then the
     * request handler.
     */
    private ChannelInitializer<SocketChannel> pipelineInitializer() {
        return new ChannelInitializer<SocketChannel>() {
            @Override
            protected void initChannel(SocketChannel ch) throws Exception {
                ch.pipeline()
                        .addLast(new IdleStateHandler(60, 0, 0, TimeUnit.SECONDS))
                        .addLast(new RPCDecoder(serializers))
                        .addLast(new RPCEncoder(serializers.get(RPCProtocol.SerializationType.JSON.getValue())))
                        .addLast(new RPCServerHandler(serviceMap, businessExecutor));
            }
        };
    }

    /** Closes the server channel, stops both event loops and the business pool. */
    public void shutdown() {
        if (serverChannel != null) {
            serverChannel.close();
        }
        for (EventLoopGroup group : new EventLoopGroup[]{bossGroup, workerGroup}) {
            if (group != null) {
                group.shutdownGracefully();
            }
        }
        if (businessExecutor != null) {
            businessExecutor.shutdown();
        }
        logger.info("RPC服务器已关闭");
    }
}
/**
 * Server-side handler that invokes the requested method by reflection on the business pool
 * and writes back an {@link RPCResponse}.
 */
public class RPCServerHandler extends SimpleChannelInboundHandler<RPCRequest> {
    private static final Logger logger = LoggerFactory.getLogger(RPCServerHandler.class);
    private final Map<String, Object> serviceMap;
    private final ExecutorService businessExecutor;

    public RPCServerHandler(Map<String, Object> serviceMap, ExecutorService businessExecutor) {
        this.serviceMap = serviceMap;
        this.businessExecutor = businessExecutor;
    }

    @Override
    protected void channelRead0(ChannelHandlerContext ctx, RPCRequest request) throws Exception {
        // Run the (possibly slow) business call off the I/O thread
        businessExecutor.submit(() -> ctx.writeAndFlush(handle(request)));
    }

    /**
     * Builds the response for one request. Every outcome (found, missing service, missing
     * method, thrown exception) produces a response, so the client never has to time out.
     */
    private RPCResponse handle(RPCRequest request) {
        RPCResponse response = new RPCResponse();
        response.setRequestId(request.getRequestId());
        response.setTimestamp(System.currentTimeMillis());
        try {
            Object service = serviceMap.get(request.getServiceName());
            if (service == null) {
                response.setStatus(RPCProtocol.ResponseStatus.SERVICE_NOT_FOUND);
                response.setError(new RuntimeException("服务不存在: " + request.getServiceName()));
                return response;
            }
            Method method = service.getClass().getMethod(
                    request.getMethodName(),
                    request.getParameterTypes()
            );
            // getMethod also finds public java.lang.Object methods (wait, notify, getClass, ...).
            // Never let remote callers reach those.
            if (method.getDeclaringClass() == Object.class) {
                throw new NoSuchMethodException(request.getMethodName());
            }
            Object result = method.invoke(service, request.getParameters());
            response.setResult(result);
            response.setStatus(RPCProtocol.ResponseStatus.SUCCESS);
            logger.debug("处理请求成功: {}.{}", request.getServiceName(), request.getMethodName());
        } catch (NoSuchMethodException e) {
            response.setStatus(RPCProtocol.ResponseStatus.METHOD_NOT_FOUND);
            response.setError(new RuntimeException("方法不存在: " + request.getMethodName()));
            logger.error("方法不存在: {}.{}", request.getServiceName(), request.getMethodName());
        } catch (java.lang.reflect.InvocationTargetException e) {
            // Method.invoke wraps whatever the service threw; report the real cause.
            Throwable cause = e.getCause() != null ? e.getCause() : e;
            response.setStatus(RPCProtocol.ResponseStatus.INTERNAL_ERROR);
            response.setError(cause);
            logger.error("处理请求异常: {}.{}", request.getServiceName(), request.getMethodName(), cause);
        } catch (Exception e) {
            response.setStatus(RPCProtocol.ResponseStatus.INTERNAL_ERROR);
            response.setError(e);
            logger.error("处理请求异常: {}.{}", request.getServiceName(), request.getMethodName(), e);
        }
        return response;
    }

    /** Closes connections that have sent nothing within the IdleStateHandler's reader window. */
    @Override
    public void userEventTriggered(ChannelHandlerContext ctx, Object evt) throws Exception {
        if (evt instanceof IdleStateEvent) {
            IdleStateEvent event = (IdleStateEvent) evt;
            if (event.state() == IdleState.READER_IDLE) {
                logger.warn("客户端空闲超时,关闭连接: {}", ctx.channel().remoteAddress());
                ctx.close();
            }
        }
        super.userEventTriggered(ctx, evt);
    }

    @Override
    public void exceptionCaught(ChannelHandlerContext ctx, Throwable cause) throws Exception {
        logger.error("处理请求异常: {}", ctx.channel().remoteAddress(), cause);
        ctx.close();
    }
}
// ======= 6. RPC客户端实现 =======
/**
 * Netty RPC client over a single connection. Responses are correlated to pending futures by
 * request ID, and each request fails with "请求超时" after the configured timeout.
 */
public class RPCClient {
    private static final Logger logger = LoggerFactory.getLogger(RPCClient.class);
    /** Default per-request timeout (30 s). */
    private static final long DEFAULT_REQUEST_TIMEOUT_MILLIS = 30_000L;

    private final String host;
    private final int port;
    private final long requestTimeoutMillis;
    private final Map<Byte, Serializer> serializers = new HashMap<>();
    private final Map<String, CompletableFuture<RPCResponse>> pendingRequests = new ConcurrentHashMap<>();
    private EventLoopGroup workerGroup;
    // volatile: written by connect()/close(), read by arbitrary caller threads in sendRequest()
    private volatile Channel channel;
    private volatile boolean connected = false;

    public RPCClient(String host, int port) {
        this(host, port, DEFAULT_REQUEST_TIMEOUT_MILLIS);
    }

    /**
     * @param requestTimeoutMillis how long a request may stay unanswered before its future
     *                             fails; must be positive
     */
    public RPCClient(String host, int port, long requestTimeoutMillis) {
        if (requestTimeoutMillis <= 0) {
            throw new IllegalArgumentException("requestTimeoutMillis must be > 0: " + requestTimeoutMillis);
        }
        this.host = host;
        this.port = port;
        this.requestTimeoutMillis = requestTimeoutMillis;
        JsonSerializer jsonSerializer = new JsonSerializer();
        serializers.put(jsonSerializer.getType().getValue(), jsonSerializer);
    }

    /**
     * Opens the connection (blocking). If connecting fails, the event loop group is released
     * before rethrowing, so no threads are leaked.
     */
    public void connect() throws Exception {
        workerGroup = new NioEventLoopGroup();
        try {
            Bootstrap bootstrap = new Bootstrap();
            bootstrap.group(workerGroup)
                    .channel(NioSocketChannel.class)
                    .option(ChannelOption.SO_KEEPALIVE, true)
                    .option(ChannelOption.TCP_NODELAY, true)
                    .handler(new ChannelInitializer<SocketChannel>() {
                        @Override
                        protected void initChannel(SocketChannel ch) throws Exception {
                            ChannelPipeline pipeline = ch.pipeline();
                            // Writer-idle events every 30 s (heartbeat hook in RPCClientHandler)
                            pipeline.addLast(new IdleStateHandler(0, 30, 0, TimeUnit.SECONDS));
                            pipeline.addLast(new RPCDecoder(serializers));
                            pipeline.addLast(new RPCEncoder(serializers.get(RPCProtocol.SerializationType.JSON.getValue())));
                            pipeline.addLast(new RPCClientHandler(pendingRequests));
                        }
                    });
            ChannelFuture future = bootstrap.connect(host, port).sync();
            channel = future.channel();
            connected = true;
            logger.info("连接RPC服务器成功: {}:{}", host, port);
        } catch (Exception e) {
            workerGroup.shutdownGracefully();
            throw e;
        }
    }

    /**
     * Sends a request asynchronously.
     *
     * @return a future completed with the response. It completes exceptionally on timeout,
     *         write failure, client close, or when the client is not connected.
     */
    public CompletableFuture<RPCResponse> sendRequest(RPCRequest request) {
        Channel ch = channel;
        if (!connected || ch == null || !ch.isActive()) {
            CompletableFuture<RPCResponse> future = new CompletableFuture<>();
            future.completeExceptionally(new RuntimeException("客户端未连接"));
            return future;
        }
        String requestId = request.getRequestId();
        CompletableFuture<RPCResponse> future = new CompletableFuture<>();
        pendingRequests.put(requestId, future);
        var timeoutTask = ch.eventLoop().schedule(() -> {
            CompletableFuture<RPCResponse> timeoutFuture = pendingRequests.remove(requestId);
            if (timeoutFuture != null) {
                timeoutFuture.completeExceptionally(new RuntimeException("请求超时"));
            }
        }, requestTimeoutMillis, TimeUnit.MILLISECONDS);
        // Cancel the timer as soon as the request finishes. Otherwise every request keeps a
        // scheduled task (and its captured future) alive for the full timeout.
        future.whenComplete((response, error) -> timeoutTask.cancel(false));
        ch.writeAndFlush(request).addListener(channelFuture -> {
            if (!channelFuture.isSuccess()) {
                pendingRequests.remove(requestId);
                future.completeExceptionally(channelFuture.cause());
            }
        });
        return future;
    }

    /**
     * Closes the connection and fails every outstanding request. Shutting down the event loop
     * cancels the scheduled timeout tasks, so without this those futures would never complete.
     */
    public void close() {
        connected = false;
        RuntimeException closed = new RuntimeException("客户端已关闭");
        for (String requestId : pendingRequests.keySet()) {
            CompletableFuture<RPCResponse> pending = pendingRequests.remove(requestId);
            if (pending != null) {
                pending.completeExceptionally(closed);
            }
        }
        if (channel != null) {
            channel.close();
        }
        if (workerGroup != null) {
            workerGroup.shutdownGracefully();
        }
        logger.info("RPC客户端已关闭");
    }
}
/**
 * Client-side handler: completes the pending future that matches each response's request ID,
 * and fails all pending futures when the connection drops.
 */
public class RPCClientHandler extends SimpleChannelInboundHandler<RPCResponse> {
    private static final Logger logger = LoggerFactory.getLogger(RPCClientHandler.class);
    private final Map<String, CompletableFuture<RPCResponse>> pendingRequests;

    /** @param pendingRequests futures keyed by request ID, shared with the owning client */
    public RPCClientHandler(Map<String, CompletableFuture<RPCResponse>> pendingRequests) {
        this.pendingRequests = pendingRequests;
    }

    @Override
    protected void channelRead0(ChannelHandlerContext ctx, RPCResponse response) throws Exception {
        String requestId = response.getRequestId();
        // ConcurrentHashMap rejects null keys, so guard against a malformed response
        CompletableFuture<RPCResponse> future = requestId == null ? null : pendingRequests.remove(requestId);
        if (future != null) {
            future.complete(response);
        } else {
            logger.warn("收到未知请求的响应: {}", requestId);
        }
    }

    /**
     * Fails outstanding requests immediately when the connection closes, instead of leaving
     * each caller to wait out its full timeout.
     */
    @Override
    public void channelInactive(ChannelHandlerContext ctx) throws Exception {
        if (!pendingRequests.isEmpty()) {
            RuntimeException disconnected = new RuntimeException("连接已断开");
            for (String requestId : pendingRequests.keySet()) {
                CompletableFuture<RPCResponse> pending = pendingRequests.remove(requestId);
                if (pending != null) {
                    pending.completeExceptionally(disconnected);
                }
            }
        }
        super.channelInactive(ctx);
    }

    /**
     * Writer-idle hook for heartbeats.
     * NOTE(review): nothing is sent yet. A server whose IdleStateHandler closes reader-idle
     * connections will drop idle clients; the encoder currently only accepts
     * RPCRequest/RPCResponse, so a heartbeat needs encoder support first.
     */
    @Override
    public void userEventTriggered(ChannelHandlerContext ctx, Object evt) throws Exception {
        if (evt instanceof IdleStateEvent) {
            IdleStateEvent event = (IdleStateEvent) evt;
            if (event.state() == IdleState.WRITER_IDLE) {
                logger.debug("发送心跳包");
                // Heartbeat logic goes here
            }
        }
        super.userEventTriggered(ctx, evt);
    }

    @Override
    public void exceptionCaught(ChannelHandlerContext ctx, Throwable cause) throws Exception {
        logger.error("客户端处理异常", cause);
        ctx.close();
    }
}
// ======= 7. 动态代理客户端 =======
/**
 * Creates JDK dynamic proxies whose method calls are sent as RPC requests through a shared
 * {@link RPCClient}. The service name is the interface's fully qualified name.
 */
public class RPCProxyFactory {
    private final RPCClient client;

    public RPCProxyFactory(RPCClient client) {
        this.client = client;
    }

    /**
     * @param interfaceClass the remote service interface
     * @return a proxy implementing {@code interfaceClass}
     */
    public <T> T createProxy(Class<T> interfaceClass) {
        String serviceName = interfaceClass.getName();
        RPCInvocationHandler handler = new RPCInvocationHandler(client, serviceName);
        Object proxy = Proxy.newProxyInstance(
                interfaceClass.getClassLoader(),
                new Class<?>[]{interfaceClass},
                handler);
        // The proxy implements interfaceClass by construction, so this cast cannot fail.
        return interfaceClass.cast(proxy);
    }
}
/**
 * Turns each interface call on the proxy into a blocking RPC round trip.
 */
public class RPCInvocationHandler implements InvocationHandler {
    private static final Logger logger = LoggerFactory.getLogger(RPCInvocationHandler.class);
    private final RPCClient client;
    private final String serviceName;

    public RPCInvocationHandler(RPCClient client, String serviceName) {
        this.client = client;
        this.serviceName = serviceName;
    }

    /**
     * Sends the call and waits for the response.
     *
     * @throws Throwable the cause of a failed future (e.g. timeout or disconnect), or a
     *                   RuntimeException when the server reports a non-SUCCESS status
     */
    @Override
    public Object invoke(Object proxy, Method method, Object[] args) throws Throwable {
        // equals/hashCode/toString are answered locally, never sent to the server
        if (method.getDeclaringClass() == Object.class) {
            return invokeObjectMethod(proxy, method, args);
        }
        RPCRequest request = new RPCRequest();
        request.setRequestId(UUID.randomUUID().toString());
        request.setServiceName(serviceName);
        request.setMethodName(method.getName());
        request.setParameterTypes(method.getParameterTypes());
        request.setParameters(args);
        request.setTimestamp(System.currentTimeMillis());

        CompletableFuture<RPCResponse> future = client.sendRequest(request);
        RPCResponse response;
        try {
            response = future.get();
        } catch (java.util.concurrent.ExecutionException e) {
            // Surface the real failure instead of an ExecutionException, which the proxy would
            // otherwise wrap in UndeclaredThrowableException.
            throw e.getCause() != null ? e.getCause() : e;
        }
        if (response.getStatus() == RPCProtocol.ResponseStatus.SUCCESS) {
            return response.getResult();
        } else {
            throw new RuntimeException("RPC调用失败: " + response.getError());
        }
    }

    /**
     * Implements the Object methods that a proxy dispatches to its handler. Invoking them on
     * the handler itself (as before) broke reflexivity: {@code proxy.equals(proxy)} was false.
     */
    private Object invokeObjectMethod(Object proxy, Method method, Object[] args) {
        switch (method.getName()) {
            case "equals":
                return proxy == args[0];
            case "hashCode":
                return System.identityHashCode(proxy);
            case "toString":
                return "RPCProxy[" + serviceName + "]";
            default:
                throw new UnsupportedOperationException(method.getName());
        }
    }
}
// ======= 8. 使用示例 =======
/**
 * Example service contract exposed over RPC.
 */
public interface UserService {
    /**
     * Looks up a user.
     *
     * @param userId id of the user to fetch
     * @return a textual representation of the user
     */
    String getUserById(String userId);

    /**
     * Renames a user.
     *
     * @param userId id of the user to update
     * @param userName new user name
     * @return whether the update was accepted
     */
    boolean updateUser(String userId, String userName);
}
/**
 * Demo implementation of {@link UserService}: it only logs the call and fabricates a result;
 * nothing is stored.
 */
public class UserServiceImpl implements UserService {
    private static final Logger logger = LoggerFactory.getLogger(UserServiceImpl.class);

    @Override
    public String getUserById(String userId) {
        logger.info("查询用户: {}", userId);
        String fabricated = "User-" + userId;
        return fabricated;
    }

    @Override
    public boolean updateUser(String userId, String userName) {
        logger.info("更新用户: {} -> {}", userId, userName);
        final boolean accepted = true; // no persistence: every update reports success
        return accepted;
    }
}
/**
 * Server demo: exposes {@link UserService} on port 8080 and then runs the server.
 */
public class RPCServerExample {
    private static final int PORT = 8080;

    public static void main(String[] args) throws Exception {
        RPCServer server = new RPCServer(PORT);
        UserService implementation = new UserServiceImpl();
        // Registered under the interface's fully-qualified name, which is what the client proxy sends.
        server.registerService(UserService.class.getName(), implementation);
        server.start();
    }
}
/**
 * Client demo: connects to localhost:8080, calls {@link UserService} through a dynamic proxy
 * and always closes the client afterwards.
 */
public class RPCClientExample {
    public static void main(String[] args) throws Exception {
        RPCClient client = new RPCClient("localhost", 8080);
        client.connect();
        // try/finally: previously a failing remote call skipped client.close() and leaked the connection.
        try {
            RPCProxyFactory proxyFactory = new RPCProxyFactory(client);
            UserService userService = proxyFactory.createProxy(UserService.class);

            String user = userService.getUserById("123");
            System.out.println("查询结果: " + user);

            boolean result = userService.updateUser("123", "张三");
            System.out.println("更新结果: " + result);
        } finally {
            client.close();
        }
    }
}
查看代码
/**
 * Netty-based RPC server: decodes requests, dispatches them to registered service beans via
 * {@link RPCHandler} and encodes the responses.
 */
public class NettyRPCServer {
    // Previously missing: start() logged through an undeclared `logger`.
    private static final Logger logger = LoggerFactory.getLogger(NettyRPCServer.class);

    private final int port;
    /** Service name -> implementation; shared with every channel's RPCHandler. */
    private final Map<String, Object> serviceMap = new ConcurrentHashMap<>();
    // volatile: written by start(), read by shutdown() which may run on another thread.
    private volatile EventLoopGroup bossGroup;
    private volatile EventLoopGroup workerGroup;

    public NettyRPCServer(int port) {
        this.port = port;
    }

    /**
     * Registers a service implementation under {@code serviceName}.
     * (ConcurrentHashMap rejects null keys/values with NullPointerException.)
     */
    public void registerService(String serviceName, Object serviceImpl) {
        serviceMap.put(serviceName, serviceImpl);
    }

    /**
     * Binds the port and blocks until the server channel is closed, then releases the event loops.
     *
     * @throws InterruptedException if interrupted while binding or waiting for close
     */
    public void start() throws InterruptedException {
        bossGroup = new NioEventLoopGroup(1);
        workerGroup = new NioEventLoopGroup();
        try {
            ServerBootstrap bootstrap = new ServerBootstrap();
            bootstrap.group(bossGroup, workerGroup)
                    .channel(NioServerSocketChannel.class)
                    .childHandler(new ChannelInitializer<SocketChannel>() {
                        @Override
                        protected void initChannel(SocketChannel ch) {
                            ChannelPipeline pipeline = ch.pipeline();
                            // Custom protocol codec
                            pipeline.addLast(new RPCDecoder());
                            pipeline.addLast(new RPCEncoder());
                            // Business handler (one instance per channel)
                            pipeline.addLast(new RPCHandler(serviceMap));
                        }
                    })
                    .option(ChannelOption.SO_BACKLOG, 1024)
                    .childOption(ChannelOption.SO_KEEPALIVE, true)
                    .childOption(ChannelOption.TCP_NODELAY, true);
            ChannelFuture future = bootstrap.bind(port).sync();
            logger.info("RPC服务器启动,监听端口: {}", port);
            future.channel().closeFuture().sync();
        } finally {
            shutdown();
        }
    }

    /** Gracefully shuts down both event loop groups, if they were created. */
    public void shutdown() {
        EventLoopGroup boss = bossGroup;
        if (boss != null) {
            boss.shutdownGracefully();
        }
        EventLoopGroup worker = workerGroup;
        if (worker != null) {
            worker.shutdownGracefully();
        }
    }
}
/**
 * Server-side handler that executes decoded {@link RPCRequest}s against registered service beans.
 *
 * <p>Reflection calls run on a business thread pool so slow services do not run on the Netty
 * I/O thread; the response is written back from the pool thread.
 */
public class RPCHandler extends SimpleChannelInboundHandler<RPCRequest> {
    // Previously missing: the class logged through an undeclared `logger`.
    private static final Logger logger = LoggerFactory.getLogger(RPCHandler.class);

    /**
     * Business pool shared by all handler instances. NettyRPCServer creates a new RPCHandler for
     * every accepted channel, so the former per-instance pool started up to 200 threads per
     * connection and never shut them down.
     */
    private static final ExecutorService SHARED_BUSINESS_EXECUTOR = newBusinessExecutor();

    private final Map<String, Object> serviceMap;
    private final ExecutorService businessExecutor;

    /**
     * Creates a handler that uses the process-wide shared business pool.
     *
     * @param serviceMap service name -> implementation bean, looked up on every request
     */
    public RPCHandler(Map<String, Object> serviceMap) {
        this(serviceMap, SHARED_BUSINESS_EXECUTOR);
    }

    /**
     * Creates a handler that runs service calls on {@code businessExecutor} (owned by the caller).
     */
    public RPCHandler(Map<String, Object> serviceMap, ExecutorService businessExecutor) {
        this.serviceMap = serviceMap;
        this.businessExecutor = businessExecutor;
    }

    /** Builds the bounded business pool: 20-200 threads, queue of 1000, named rpc-business-N. */
    private static ExecutorService newBusinessExecutor() {
        java.util.concurrent.atomic.AtomicInteger threadIndex =
                new java.util.concurrent.atomic.AtomicInteger();
        return new ThreadPoolExecutor(
                20, 200, 60L, TimeUnit.SECONDS,
                new LinkedBlockingQueue<>(1000),
                task -> {
                    Thread thread = new Thread(task, "rpc-business-" + threadIndex.getAndIncrement());
                    // Daemon: the shared pool is never shut down explicitly and must not keep the JVM alive.
                    thread.setDaemon(true);
                    return thread;
                },
                // When saturated, the submitting (Netty I/O) thread runs the call itself: this is
                // back-pressure, but it blocks that event loop for the duration.
                new ThreadPoolExecutor.CallerRunsPolicy());
    }

    @Override
    protected void channelRead0(ChannelHandlerContext ctx, RPCRequest request) {
        // Process asynchronously so business logic does not block the I/O thread.
        businessExecutor.execute(() -> {
            RPCResponse response = new RPCResponse();
            response.setRequestId(request.getRequestId());
            try {
                Object result = invokeService(request);
                response.setResult(result);
                response.setSuccess(true);
            } catch (Exception e) {
                // Method.invoke wraps the service's own exception; report the real cause
                // (InvocationTargetException itself has a null message).
                Throwable failure = e instanceof java.lang.reflect.InvocationTargetException
                        && e.getCause() != null ? e.getCause() : e;
                response.setError(failure.getMessage() != null
                        ? failure.getMessage() : failure.getClass().getName());
                response.setSuccess(false);
                logger.error("RPC调用失败", failure);
            }
            // Write the response back
            ctx.writeAndFlush(response);
        });
    }

    /**
     * Resolves and invokes the requested method on the registered bean.
     *
     * @throws RuntimeException if no bean is registered under the service name
     * @throws NoSuchMethodException if the method does not exist or is declared by Object
     */
    private Object invokeService(RPCRequest request) throws Exception {
        String serviceName = request.getServiceName();
        String methodName = request.getMethodName();
        Class<?>[] parameterTypes = request.getParameterTypes();
        Object[] parameters = request.getParameters();
        Object serviceBean = serviceMap.get(serviceName);
        if (serviceBean == null) {
            throw new RuntimeException("Service not found: " + serviceName);
        }
        Method method = serviceBean.getClass().getMethod(methodName, parameterTypes);
        // Requests come from the network: do not let callers reach Object-level methods
        // such as getClass/wait/notify on the service bean.
        if (method.getDeclaringClass() == Object.class) {
            throw new NoSuchMethodException("Method not exposed over RPC: " + methodName);
        }
        return method.invoke(serviceBean, parameters);
    }

    @Override
    public void exceptionCaught(ChannelHandlerContext ctx, Throwable cause) {
        logger.error("RPC处理异常", cause);
        ctx.close();
    }
}
// 高性能对象池实现
public class ObjectPool<T> {
private final Queue<T> pool = new ConcurrentLinkedQueue<>();
private final AtomicInteger currentSize = new AtomicInteger(0);
private final int maxSize;
private final Supplier<T> factory;
private final Consumer<T> resetFunction;
public ObjectPool(int maxSize, Supplier<T> factory, Consumer<T> resetFunction) {
this.maxSize = maxSize;
this.factory = factory;
this.resetFunction = resetFunction;
}
public T acquire() {
T object = pool.poll();
if (object == null) {
object = factory.get();
} else {
currentSize.decrementAndGet();
}
return object;
}
public void release(T object) {
if (object != null && currentSize.get() < maxSize) {
resetFunction.accept(object);
pool.offer(object);
currentSize.incrementAndGet();
}
}
}
内存管理与性能调优
查看代码
/**
 * Examples of tuning Netty's buffer allocation.
 */
public class NettyMemoryOptimizer {
    /**
     * Sets Netty allocator system properties.
     *
     * <p>NOTE(review): Netty appears to read these when its allocator classes are initialized, so
     * call this before any Netty class (e.g. PooledByteBufAllocator) is touched — verify per version.
     */
    public static void configureDirectMemoryAllocator() {
        // Use Netty's pooled (direct) allocator
        System.setProperty("io.netty.allocator.type", "pooled");
        System.setProperty("io.netty.allocator.directMemoryCacheAlignment", "64");
        // chunk size = pageSize << maxOrder: with the default 8 KiB page, maxOrder 9 gives
        // 4 MiB chunks (the old comment's 512KB was wrong).
        System.setProperty("io.netty.allocator.maxOrder", "9");
        // Leak detection off: only appropriate in production; keep it enabled in dev/test.
        System.setProperty("io.netty.leakDetection.level", "DISABLED");
    }

    /**
     * Keeps a bounded cache of 8 KiB direct buffers on top of the pooled allocator.
     *
     * <p>Buffers that cannot be cached are released back to Netty. The previous version handed
     * them to an ObjectPool that silently dropped them when full, which leaked them.
     */
    public static class OptimizedByteBufAllocator {
        private static final int POOLED_CAPACITY = 8192;
        private static final int MAX_CACHED = 1000;

        private final ByteBufAllocator allocator = PooledByteBufAllocator.DEFAULT;
        /** offer() reports when the cache is full, so the caller can release instead of leaking. */
        private final java.util.concurrent.BlockingQueue<ByteBuf> bufferPool =
                new java.util.concurrent.ArrayBlockingQueue<>(MAX_CACHED);

        public OptimizedByteBufAllocator() {
        }

        /**
         * Returns a direct buffer with at least {@code capacity} bytes. Requests of up to 8 KiB are
         * served from the cache (or a new 8 KiB buffer) and always have capacity 8192.
         */
        public ByteBuf allocate(int capacity) {
            if (capacity <= POOLED_CAPACITY) {
                ByteBuf cached = bufferPool.poll();
                return cached != null ? cached : allocator.directBuffer(POOLED_CAPACITY);
            }
            return allocator.directBuffer(capacity);
        }

        /**
         * Caches an exclusively-owned 8 KiB buffer, or releases it otherwise.
         * A buffer whose refCnt is not 1 is shared or already freed (e.g. by writeAndFlush), so it
         * must never be cached. release() then surfaces a double free as Netty's exception.
         */
        public void release(ByteBuf buffer) {
            if (buffer.capacity() == POOLED_CAPACITY && buffer.refCnt() == 1) {
                buffer.clear();
                if (bufferPool.offer(buffer)) {
                    return;
                }
            }
            buffer.release();
        }
    }
}
实战案例:百万连接的消息推送系统
查看代码
/**
 * Push server aimed at very large numbers of long-lived connections.
 *
 * <p>PushServerHandler keeps authenticated channels in {@code channels} (user id -> channel).
 * {@link #pushMessage} fans a message out in batches on a small scheduler pool, and the same
 * pool also runs periodic resource monitoring.
 */
public class MassivePushServer {
    // Previously missing: the class logged through an undeclared `logger`.
    private static final Logger logger = LoggerFactory.getLogger(MassivePushServer.class);
    /** Users handled per scheduler task in pushMessage. */
    private static final int PUSH_BATCH_SIZE = 1000;
    /** Direct-memory usage (bytes) above which monitoring suggests a GC. */
    private static final long DIRECT_MEMORY_GC_THRESHOLD = 1024L * 1024 * 1024; // 1 GiB

    private final int port;
    private final ConcurrentHashMap<String, Channel> channels = new ConcurrentHashMap<>();
    private final ScheduledExecutorService scheduler = Executors.newScheduledThreadPool(4);
    private final AtomicLong connectionCount = new AtomicLong(0);

    public MassivePushServer(int port) {
        this.port = port;
    }

    /**
     * Binds the port and blocks until the server channel closes, then stops the scheduler and
     * event loops.
     */
    public void start() throws InterruptedException {
        // EventLoopGroup sizing for many connections
        EventLoopGroup bossGroup = new NioEventLoopGroup(1);
        EventLoopGroup workerGroup = new NioEventLoopGroup(Runtime.getRuntime().availableProcessors() * 2);
        try {
            ServerBootstrap bootstrap = new ServerBootstrap();
            bootstrap.group(bossGroup, workerGroup)
                    .channel(NioServerSocketChannel.class)
                    .childHandler(new ChannelInitializer<SocketChannel>() {
                        @Override
                        protected void initChannel(SocketChannel ch) {
                            ChannelPipeline pipeline = ch.pipeline();
                            // Idle detection / heartbeat
                            pipeline.addLast(new IdleStateHandler(60, 30, 0));
                            pipeline.addLast(new HeartbeatHandler());
                            // Message handling
                            pipeline.addLast(new MessageDecoder());
                            pipeline.addLast(new MessageEncoder());
                            pipeline.addLast(new PushServerHandler(channels, connectionCount));
                        }
                    })
                    // TCP tuning
                    .option(ChannelOption.SO_BACKLOG, 8192)
                    .option(ChannelOption.SO_REUSEADDR, true)
                    .childOption(ChannelOption.SO_KEEPALIVE, true)
                    .childOption(ChannelOption.TCP_NODELAY, true)
                    .childOption(ChannelOption.SO_RCVBUF, 32768)
                    .childOption(ChannelOption.SO_SNDBUF, 32768)
                    .childOption(ChannelOption.WRITE_BUFFER_WATER_MARK,
                            new WriteBufferWaterMark(32768, 65536));
            ChannelFuture future = bootstrap.bind(port).sync();
            logger.info("推送服务器启动,监听端口: {}", port);
            startMonitoring();
            future.channel().closeFuture().sync();
        } finally {
            // Executors.newScheduledThreadPool threads are non-daemon and would outlive the server.
            scheduler.shutdown();
            bossGroup.shutdownGracefully();
            workerGroup.shutdownGracefully();
        }
    }

    /**
     * Sends {@code message} (UTF-8) to every currently connected, writable user in {@code userIds},
     * in asynchronous batches of {@value #PUSH_BATCH_SIZE}.
     *
     * <p>Each batch takes its own reference to the message bytes before it is queued. The old
     * code released the only reference right after scheduling, so the queued batches later called
     * retainedDuplicate() on a freed buffer.
     */
    public void pushMessage(String message, Set<String> userIds) {
        if (userIds.isEmpty()) {
            return;
        }
        ByteBuf messageBuffer = Unpooled.wrappedBuffer(
                message.getBytes(java.nio.charset.StandardCharsets.UTF_8));
        try {
            List<String> userIdList = new ArrayList<>(userIds);
            for (int i = 0; i < userIdList.size(); i += PUSH_BATCH_SIZE) {
                List<String> batch = userIdList.subList(i, Math.min(i + PUSH_BATCH_SIZE, userIdList.size()));
                ByteBuf batchBuffer = messageBuffer.retainedDuplicate();
                try {
                    scheduler.execute(() -> pushBatch(batchBuffer, batch));
                } catch (java.util.concurrent.RejectedExecutionException e) {
                    batchBuffer.release(); // the task will never run, so drop its reference here
                    throw e;
                }
            }
        } finally {
            // Release only this method's reference; queued batches hold their own.
            messageBuffer.release();
        }
    }

    /** Writes one batch and then releases the batch's reference to the message bytes. */
    private void pushBatch(ByteBuf batchBuffer, List<String> batch) {
        try {
            for (String userId : batch) {
                Channel channel = channels.get(userId);
                if (channel != null && channel.isActive()) {
                    // Skip channels whose outbound buffer is above the high water mark
                    if (channel.isWritable()) {
                        channel.writeAndFlush(batchBuffer.retainedDuplicate());
                    } else {
                        logger.warn("Channel {} is not writable, skipping message", userId);
                    }
                }
            }
        } finally {
            batchBuffer.release();
        }
    }

    /** Logs connection count and memory usage every 10 seconds. */
    private void startMonitoring() {
        scheduler.scheduleAtFixedRate(() -> {
            long currentConnections = connectionCount.get();
            long heapMemory = Runtime.getRuntime().totalMemory() - Runtime.getRuntime().freeMemory();
            // NOTE(review): may report -1 when Netty is not tracking direct memory — confirm.
            long directMemory = PlatformDependent.usedDirectMemory();
            logger.info("连接数: {}, 堆内存: {}MB, 直接内存: {}MB",
                    currentConnections,
                    heapMemory / 1024 / 1024,
                    directMemory / 1024 / 1024);
            // Only a hint: System.gc() can reclaim unreachable direct buffers but not pooled ones.
            if (directMemory > DIRECT_MEMORY_GC_THRESHOLD) {
                System.gc();
            }
        }, 10, 10, TimeUnit.SECONDS);
    }
}
/**
 * Per-connection handler for the push server: counts connections, authenticates users and keeps
 * the shared user-id -> channel map up to date.
 */
public class PushServerHandler extends SimpleChannelInboundHandler<Message> {
    // Previously missing: the class logged through an undeclared `logger`.
    private static final Logger logger = LoggerFactory.getLogger(PushServerHandler.class);
    /** Channel attribute holding the authenticated user id (same "userId" key name as before). */
    private static final AttributeKey<String> USER_ID = AttributeKey.valueOf("userId");

    private final ConcurrentHashMap<String, Channel> channels;
    private final AtomicLong connectionCount;

    public PushServerHandler(ConcurrentHashMap<String, Channel> channels,
                             AtomicLong connectionCount) {
        this.channels = channels;
        this.connectionCount = connectionCount;
    }

    @Override
    public void channelActive(ChannelHandlerContext ctx) {
        connectionCount.incrementAndGet();
        logger.debug("新连接建立: {}", ctx.channel().remoteAddress());
    }

    @Override
    public void channelInactive(ChannelHandlerContext ctx) {
        connectionCount.decrementAndGet();
        String userId = ctx.channel().attr(USER_ID).get();
        if (userId != null) {
            // Remove only if the map still points at THIS channel: if the user has already
            // reconnected, the old remove(userId) evicted the new, live channel.
            channels.remove(userId, ctx.channel());
        }
        logger.debug("连接断开: {}", ctx.channel().remoteAddress());
    }

    @Override
    protected void channelRead0(ChannelHandlerContext ctx, Message msg) {
        if (msg.getType() == MessageType.AUTH) {
            String userId = authenticate(msg);
            if (userId != null) {
                channels.put(userId, ctx.channel());
                ctx.channel().attr(USER_ID).set(userId);
                ctx.writeAndFlush(new Message(MessageType.AUTH_SUCCESS, "认证成功"));
            } else {
                ctx.writeAndFlush(new Message(MessageType.AUTH_FAILED, "认证失败"));
                ctx.close();
            }
        }
    }

    /** Closes connections that have been read-idle; every other event continues down the pipeline. */
    @Override
    public void userEventTriggered(ChannelHandlerContext ctx, Object evt) {
        if (evt instanceof IdleStateEvent
                && ((IdleStateEvent) evt).state() == IdleState.READER_IDLE) {
            logger.warn("连接读超时,关闭连接: {}", ctx.channel().remoteAddress());
            ctx.close();
            return;
        }
        // Previously non-idle events were silently swallowed here.
        ctx.fireUserEventTriggered(evt);
    }

    /** Placeholder authentication: every message maps to the fixed user "user123". */
    private String authenticate(Message msg) {
        return "user123"; // simplified implementation
    }

    private String getUserId(Channel channel) {
        Attribute<String> attr = channel.attr(USER_ID);
        return attr.get();
    }
}
长连接并发测试
对于推送服务等需要维持大量长连接的场景,可以使用下面的工具模拟并发连接,观察服务器在高连接数下的表现:
查看代码
import java.io.*;
import java.net.*;
import java.nio.ByteBuffer;
import java.nio.channels.SelectionKey;
import java.nio.channels.Selector;
import java.nio.channels.SocketChannel;
import java.util.concurrent.*;
import java.util.concurrent.atomic.AtomicInteger;
import java.util.concurrent.atomic.AtomicLong;
/**
 * Concurrent-connection simulator (blocking sockets).
 *
 * <p>Measures how a server behaves under many simultaneous connections in two modes:
 * <ul>
 *   <li>long-lived connections (push/WebSocket style): all open sockets are serviced by one
 *       poller thread. Previously each connection pinned a worker thread, so at most 200
 *       (the pool size) were ever live;</li>
 *   <li>short request/response connections (HTTP style).</li>
 * </ul>
 */
public class ConnectionSimulator {
    /** Connect timeout for long connections, in milliseconds. */
    private static final int CONNECT_TIMEOUT_MS = 5000;
    /** Read timeout for short requests so a stalled server cannot pin a worker past the test. */
    private static final int SHORT_REQUEST_TIMEOUT_MS = 5000;

    private final String serverHost;
    private final int serverPort;
    private final int targetConnections;
    private final AtomicInteger activeConnections = new AtomicInteger(0);
    private final AtomicInteger peakConnections = new AtomicInteger(0);
    private final AtomicInteger failedConnections = new AtomicInteger(0);
    private final AtomicLong totalResponseTime = new AtomicLong(0);
    /** Successful short requests plus messages received on long connections. */
    private final AtomicLong totalRequests = new AtomicLong(0);
    private final AtomicLong connectionSequence = new AtomicLong(0);
    private final ConcurrentLinkedQueue<LongConnection> longConnections = new ConcurrentLinkedQueue<>();
    private final ExecutorService executorService;
    /** Runs the periodic heartbeat/statistics printers; stopped by shutdown(). */
    private final ScheduledExecutorService reporter = Executors.newScheduledThreadPool(1);
    private volatile boolean running = true;

    public ConnectionSimulator(String serverHost, int serverPort, int targetConnections) {
        this.serverHost = serverHost;
        this.serverPort = serverPort;
        this.targetConnections = targetConnections;
        // At least one thread: for targetConnections < 100 the old formula gave 0, which
        // newFixedThreadPool rejects with IllegalArgumentException.
        this.executorService = Executors.newFixedThreadPool(
                Math.max(1, Math.min(targetConnections / 100, 200)));
    }

    /** One open long connection plus its partially-read line. */
    private static final class LongConnection {
        final Socket socket;
        final BufferedReader reader;
        final char[] chunk = new char[1024];
        final StringBuilder partialLine = new StringBuilder();

        LongConnection(Socket socket, BufferedReader reader) {
            this.socket = socket;
            this.reader = reader;
        }

        void closeQuietly() {
            ConnectionSimulator.closeQuietly(socket);
        }
    }

    /**
     * Long-connection scenario (WebSocket, push services, ...): opens {@code targetConnections}
     * sockets in batches of 100 every 10 ms, holds them for one minute, then shuts down.
     */
    public void simulateLongConnections() throws InterruptedException {
        System.out.println("开始模拟 " + targetConnections + " 个长连接...");
        Thread poller = new Thread(this::pollLongConnections, "long-connection-poller");
        poller.setDaemon(true);
        poller.start();
        try {
            // Open connections in batches to avoid a burst
            int batchSize = 100;
            int delay = 10; // ms between batches
            for (int i = 0; i < targetConnections; i += batchSize) {
                int currentBatch = Math.min(batchSize, targetConnections - i);
                for (int j = 0; j < currentBatch; j++) {
                    executorService.submit(this::createLongConnection);
                }
                Thread.sleep(delay);
                if (i % 1000 == 0) {
                    System.out.println("已发起连接: " + i + ", 活跃连接: " + activeConnections.get());
                }
            }
            // Give the remaining connects time to finish
            Thread.sleep(5000);
            System.out.println("连接建立完成 - 活跃连接: " + activeConnections.get() +
                    ", 失败连接: " + failedConnections.get());
            startHeartbeat();
            Thread.sleep(60000); // hold the connections for one minute
        } finally {
            shutdown();
        }
    }

    /** Connects one socket, sends its AUTH line and hands it to the poller. */
    private void createLongConnection() {
        Socket socket = new Socket();
        try {
            socket.setKeepAlive(true);
            socket.setTcpNoDelay(true);
            socket.connect(new InetSocketAddress(serverHost, serverPort), CONNECT_TIMEOUT_MS);
            BufferedReader reader = new BufferedReader(new InputStreamReader(socket.getInputStream()));
            PrintWriter writer = new PrintWriter(socket.getOutputStream(), true);
            // Unique id per connection (pool thread ids repeat across connections).
            writer.println("AUTH:user" + connectionSequence.incrementAndGet());
            peakConnections.accumulateAndGet(activeConnections.incrementAndGet(), Math::max);
            LongConnection connection = new LongConnection(socket, reader);
            longConnections.add(connection);
            if (!running && longConnections.remove(connection)) {
                connection.closeQuietly(); // shutdown already drained the queue
            }
        } catch (IOException e) {
            closeQuietly(socket); // the socket used to leak when connect failed
            failedConnections.incrementAndGet();
            System.err.println("连接失败: " + e.getMessage());
        }
    }

    /** Poller loop: every 100 ms drains readable data from each open long connection. */
    private void pollLongConnections() {
        while (running) {
            for (LongConnection connection : longConnections) {
                if (!readAvailable(connection)) {
                    longConnections.remove(connection);
                    connection.closeQuietly();
                    if (running) {
                        activeConnections.decrementAndGet();
                    }
                }
            }
            try {
                Thread.sleep(100);
            } catch (InterruptedException e) {
                Thread.currentThread().interrupt();
                return;
            }
        }
    }

    /**
     * Reads at most one chunk of already-available data and counts each complete line as a pushed
     * message. Returns false once the connection is unusable.
     */
    private boolean readAvailable(LongConnection connection) {
        try {
            if (connection.reader.ready()) {
                int count = connection.reader.read(connection.chunk);
                if (count < 0) {
                    return false;
                }
                for (int k = 0; k < count; k++) {
                    char c = connection.chunk[k];
                    if (c == '\n') {
                        processMessage(connection.partialLine.toString());
                        connection.partialLine.setLength(0);
                    } else if (c != '\r') {
                        connection.partialLine.append(c);
                    }
                }
            }
            return !connection.socket.isClosed();
        } catch (IOException e) {
            return false;
        }
    }

    /**
     * Short-connection scenario (HTTP-like): worker threads issue one request per connection
     * until {@code duration} seconds elapse, then final statistics are printed and the
     * simulator is shut down.
     *
     * <p>Note: workers beyond the executor's pool size queue up and may start only after the
     * deadline, so the effective concurrency is min(workers, pool size).
     */
    public void simulateShortConnections(int duration) throws InterruptedException {
        System.out.println("开始模拟短连接高并发测试,持续时间: " + duration + "秒");
        long endTime = System.currentTimeMillis() + duration * 1000L;
        int threadCount = Math.max(1, Math.min(targetConnections / 10, 50));
        CountDownLatch latch = new CountDownLatch(threadCount);
        for (int i = 0; i < threadCount; i++) {
            executorService.submit(() -> {
                try {
                    while (System.currentTimeMillis() < endTime && running) {
                        performRequest();
                        Thread.sleep(1); // throttle the request rate
                    }
                } catch (InterruptedException e) {
                    Thread.currentThread().interrupt();
                } finally {
                    latch.countDown();
                }
            });
        }
        startStatistics();
        try {
            latch.await();
            printFinalStatistics(duration);
        } finally {
            // Stop worker and reporter threads; before, they kept the JVM alive after the test.
            shutdown();
        }
    }

    private void performRequest() {
        long requestStart = System.currentTimeMillis();
        try (Socket socket = new Socket(serverHost, serverPort);
             PrintWriter writer = new PrintWriter(socket.getOutputStream(), true);
             BufferedReader reader = new BufferedReader(new InputStreamReader(socket.getInputStream()))) {
            socket.setSoTimeout(SHORT_REQUEST_TIMEOUT_MS);
            writer.println("GET /api/test HTTP/1.1");
            writer.println("Host: " + serverHost);
            writer.println("Connection: close");
            writer.println();
            String response = reader.readLine();
            if (response != null) {
                long responseTime = System.currentTimeMillis() - requestStart;
                totalResponseTime.addAndGet(responseTime);
                totalRequests.incrementAndGet();
            }
        } catch (IOException e) {
            failedConnections.incrementAndGet();
        }
    }

    private void startHeartbeat() {
        reporter.scheduleAtFixedRate(() -> {
            // Heartbeat logic could be added here
            System.out.println("心跳检测 - 活跃连接: " + activeConnections.get());
        }, 10, 10, TimeUnit.SECONDS);
    }

    private void startStatistics() {
        reporter.scheduleAtFixedRate(() -> {
            long currentRequests = totalRequests.get();
            long currentResponseTime = totalResponseTime.get();
            double avgResponseTime = currentRequests > 0 ?
                    (double) currentResponseTime / currentRequests : 0;
            System.out.println(String.format(
                    "实时统计 - 请求数: %d, 失败数: %d, 平均响应时间: %.2fms",
                    currentRequests, failedConnections.get(), avgResponseTime
            ));
        }, 5, 5, TimeUnit.SECONDS);
    }

    /** Counts one message pushed by the server. */
    private void processMessage(String message) {
        totalRequests.incrementAndGet();
    }

    private void printFinalStatistics(int duration) {
        long finalRequests = totalRequests.get(); // successful responses only
        long finalResponseTime = totalResponseTime.get();
        int finalFailures = failedConnections.get();
        long attempts = finalRequests + finalFailures;
        double qps = duration > 0 ? (double) finalRequests / duration : 0;
        double avgResponseTime = finalRequests > 0 ?
                (double) finalResponseTime / finalRequests : 0;
        // successes / attempts (the old formula subtracted failures from successes).
        double successRate = attempts > 0 ? (double) finalRequests / attempts * 100 : 0;
        System.out.println("\n=== 最终测试结果 ===");
        System.out.println("测试时长: " + duration + "秒");
        System.out.println("总请求数: " + finalRequests);
        System.out.println("失败请求数: " + finalFailures);
        System.out.println("QPS: " + String.format("%.2f", qps));
        System.out.println("平均响应时间: " + String.format("%.2f", avgResponseTime) + "ms");
        System.out.println("成功率: " + String.format("%.2f", successRate) + "%");
        System.out.println("峰值并发连接数: " + peakConnections.get());
    }

    /** Stops all activity: reporters, long connections and the worker pool (waits up to 60 s). */
    public void shutdown() {
        running = false;
        reporter.shutdown();
        LongConnection connection;
        while ((connection = longConnections.poll()) != null) {
            connection.closeQuietly();
        }
        executorService.shutdown();
        try {
            if (!executorService.awaitTermination(60, TimeUnit.SECONDS)) {
                executorService.shutdownNow();
            }
        } catch (InterruptedException e) {
            executorService.shutdownNow();
            Thread.currentThread().interrupt();
        }
    }

    private static void closeQuietly(Socket socket) {
        try {
            socket.close();
        } catch (IOException ignored) {
            // best-effort close; nothing useful to do on failure
        }
    }

    public static void main(String[] args) throws InterruptedException {
        if (args.length < 4) {
            System.out.println("用法: java ConnectionSimulator <host> <port> <connections> <test_type>");
            System.out.println("test_type: long (长连接测试) 或 short (短连接测试)");
            return;
        }
        String host = args[0];
        int port = Integer.parseInt(args[1]);
        int connections = Integer.parseInt(args[2]);
        String testType = args[3];
        ConnectionSimulator simulator = new ConnectionSimulator(host, port, connections);
        if ("long".equals(testType)) {
            simulator.simulateLongConnections();
        } else if ("short".equals(testType)) {
            simulator.simulateShortConnections(60); // 60-second test
        } else {
            System.out.println("test_type: long (长连接测试) 或 short (短连接测试)");
            simulator.shutdown();
        }
    }
}
/**
 * Connection tester built on NIO: a single selector thread drives every connection.
 *
 * <p>Each connection sends "TEST_MESSAGE" once connected, then alternates between reading a
 * server reply and sending "PING". A failing connection is dropped on its own instead of
 * aborting the whole test, and all channels are closed when the loop exits.
 */
class NIOConnectionSimulator {
    private final String serverHost;
    private final int serverPort;
    private final int targetConnections;
    private final AtomicInteger activeConnections = new AtomicInteger(0);
    private final AtomicLong totalRequests = new AtomicLong(0);
    private volatile boolean running = true;

    public NIOConnectionSimulator(String serverHost, int serverPort, int targetConnections) {
        this.serverHost = serverHost;
        this.serverPort = serverPort;
        this.targetConnections = targetConnections;
    }

    /** Asks the event loop to exit; it notices after the current select(1000) returns. */
    public void stop() {
        running = false;
    }

    /**
     * Opens {@code targetConnections} non-blocking connections and runs the event loop until
     * {@link #stop()} is called.
     *
     * @throws IOException if the selector or a channel cannot be opened
     */
    public void startNIOTest() throws IOException, InterruptedException {
        Selector selector = Selector.open();
        try {
            for (int i = 0; i < targetConnections; i++) {
                openConnection(selector);
                if (i % 1000 == 0) {
                    System.out.println("已创建连接: " + i);
                }
            }
            while (running) {
                int readyChannels = selector.select(1000);
                if (readyChannels == 0) continue;
                var iterator = selector.selectedKeys().iterator();
                while (iterator.hasNext()) {
                    SelectionKey key = iterator.next();
                    iterator.remove();
                    if (!key.isValid()) continue;
                    try {
                        if (key.isConnectable()) {
                            handleConnect(key);
                        } else if (key.isReadable()) {
                            handleRead(key);
                        } else if (key.isWritable()) {
                            handleWrite(key);
                        }
                    } catch (IOException e) {
                        // One refused/reset connection must not end the whole test.
                        dropConnection((SocketChannel) key.channel());
                    }
                }
            }
        } finally {
            // Closing the selector alone does not close the registered channels.
            for (SelectionKey key : new java.util.ArrayList<>(selector.keys())) {
                closeQuietly(key.channel());
            }
            selector.close();
        }
    }

    /** Starts one non-blocking connect; a failing connect is skipped rather than fatal. */
    private void openConnection(Selector selector) throws IOException {
        SocketChannel channel = SocketChannel.open();
        try {
            channel.configureBlocking(false);
            if (channel.connect(new InetSocketAddress(serverHost, serverPort))) {
                // connect() may complete immediately (e.g. for a local connection); then no
                // OP_CONNECT event is needed.
                onConnected(channel.register(selector, SelectionKey.OP_READ), channel);
            } else {
                channel.register(selector, SelectionKey.OP_CONNECT);
            }
        } catch (IOException e) {
            dropConnection(channel);
        }
    }

    private void handleConnect(SelectionKey key) throws IOException {
        SocketChannel channel = (SocketChannel) key.channel();
        if (channel.finishConnect()) {
            onConnected(key, channel);
        }
    }

    /** Counts the connection, switches it to read interest and sends the initial test message. */
    private void onConnected(SelectionKey key, SocketChannel channel) throws IOException {
        activeConnections.incrementAndGet();
        key.interestOps(SelectionKey.OP_READ);
        ByteBuffer buffer = ByteBuffer.wrap("TEST_MESSAGE".getBytes());
        channel.write(buffer);
    }

    private void handleRead(SelectionKey key) throws IOException {
        SocketChannel channel = (SocketChannel) key.channel();
        ByteBuffer buffer = ByteBuffer.allocate(1024);
        int bytesRead = channel.read(buffer);
        if (bytesRead > 0) {
            totalRequests.incrementAndGet();
            // Send another message next
            key.interestOps(SelectionKey.OP_WRITE);
        } else if (bytesRead == -1) {
            // Peer closed the connection
            activeConnections.decrementAndGet();
            key.cancel();
            channel.close();
        }
    }

    private void handleWrite(SelectionKey key) throws IOException {
        SocketChannel channel = (SocketChannel) key.channel();
        ByteBuffer buffer = ByteBuffer.wrap("PING".getBytes());
        channel.write(buffer);
        key.interestOps(SelectionKey.OP_READ);
    }

    /** Closes a failed channel (which cancels its key), un-counting it if it had connected. */
    private void dropConnection(SocketChannel channel) {
        if (channel.isConnected()) {
            activeConnections.decrementAndGet();
        }
        closeQuietly(channel);
    }

    private static void closeQuietly(java.io.Closeable closeable) {
        try {
            closeable.close();
        } catch (IOException ignored) {
            // best-effort close
        }
    }
}
QPS (每秒查询数) 计算方法
查看代码
// 1. 基本计算公式 QPS = 总请求数 / 测试时间(秒)
// 2. 实时QPS计算
import java.util.concurrent.*;
import java.util.concurrent.atomic.AtomicLong;
import java.util.concurrent.atomic.LongAdder;
/**
 * QPS (queries per second) measurement utilities: a sliding-window counter, a fixed-interval
 * counter, a combined performance monitor and a simple socket-based load tester.
 * Basic formula: QPS = total requests / elapsed seconds.
 */
public class QPSCalculator {

    /**
     * Sliding-window QPS: one counter per second for the last {@code windowSize} seconds, reported
     * as their average.
     *
     * <p>Each slot remembers which second it belongs to and is reset lazily by the first increment
     * of a new second. The previous version cleared slots from a 1 s background timer whose ticks
     * lag the wall-clock boundary, so it wiped increments already recorded in the new second.
     */
    public static class SlidingWindowQPS {
        private final int windowSize; // window length in seconds
        private final java.util.concurrent.atomic.AtomicLongArray counts; // per-slot request counts
        private final java.util.concurrent.atomic.AtomicLongArray slotSeconds; // second each slot holds
        private final java.util.function.LongSupplier clockMillis;

        /** @throws IllegalArgumentException if {@code windowSize <= 0} */
        public SlidingWindowQPS(int windowSize) {
            this(windowSize, System::currentTimeMillis);
        }

        /**
         * @param clockMillis time source in epoch milliseconds (injectable for tests)
         * @throws IllegalArgumentException if {@code windowSize <= 0}
         */
        public SlidingWindowQPS(int windowSize, java.util.function.LongSupplier clockMillis) {
            if (windowSize <= 0) {
                throw new IllegalArgumentException("windowSize must be > 0: " + windowSize);
            }
            this.windowSize = windowSize;
            this.clockMillis = java.util.Objects.requireNonNull(clockMillis, "clockMillis");
            this.counts = new java.util.concurrent.atomic.AtomicLongArray(windowSize);
            this.slotSeconds = new java.util.concurrent.atomic.AtomicLongArray(windowSize);
            for (int i = 0; i < windowSize; i++) {
                slotSeconds.set(i, Long.MIN_VALUE); // "never used": outside every window
            }
        }

        /** Records one request in the current second. */
        public void increment() {
            long second = Math.floorDiv(clockMillis.getAsLong(), 1000L);
            int index = (int) Math.floorMod(second, (long) windowSize);
            if (slotSeconds.get(index) != second) {
                rollSlot(index, second);
            }
            counts.incrementAndGet(index);
        }

        /**
         * Moves a slot forward to {@code second}. Clearing the count before publishing the new
         * second means increments that observe the new second are never wiped.
         */
        private synchronized void rollSlot(int index, long second) {
            if (slotSeconds.get(index) < second) {
                counts.set(index, 0);
                slotSeconds.set(index, second);
            }
        }

        /** Average requests per second over the window (the current partial second included). */
        public long getCurrentQPS() {
            long now = Math.floorDiv(clockMillis.getAsLong(), 1000L);
            long oldestIncluded = now - windowSize + 1;
            long total = 0;
            for (int i = 0; i < windowSize; i++) {
                long stamp = slotSeconds.get(i);
                if (stamp >= oldestIncluded && stamp <= now) {
                    total += counts.get(i);
                }
            }
            return total / windowSize;
        }

        /** Kept for API compatibility; no background thread is used any more. */
        public void shutdown() {
        }
    }

    /**
     * Fixed-interval QPS: rate since the previous measurement, recomputed at most once per
     * {@code intervalMs}. Between recomputations the last value is returned; it used to return 0,
     * which looked like zero traffic.
     */
    public static class IntervalQPS {
        private final LongAdder totalCount = new LongAdder();
        private final int intervalMs; // minimum time between recomputations (ms)
        private final java.util.function.LongSupplier clockMillis;
        // Guarded by this:
        private long lastTimestamp;
        private long lastCount = 0;
        private double lastQps = 0;

        public IntervalQPS(int intervalMs) {
            this(intervalMs, System::currentTimeMillis);
        }

        /** @param clockMillis time source in epoch milliseconds (injectable for tests) */
        public IntervalQPS(int intervalMs, java.util.function.LongSupplier clockMillis) {
            this.intervalMs = intervalMs;
            this.clockMillis = java.util.Objects.requireNonNull(clockMillis, "clockMillis");
            this.lastTimestamp = clockMillis.getAsLong();
        }

        public void increment() {
            totalCount.increment();
        }

        public synchronized double getCurrentQPS() {
            long currentTime = clockMillis.getAsLong();
            long timeDiff = currentTime - lastTimestamp;
            // timeDiff > 0 guards intervalMs == 0, which used to divide by zero (NaN/Infinity).
            if (timeDiff >= intervalMs && timeDiff > 0) {
                long currentCount = totalCount.sum();
                lastQps = (double) (currentCount - lastCount) / timeDiff * 1000; // per second
                lastTimestamp = currentTime;
                lastCount = currentCount;
            }
            return lastQps;
        }
    }

    /**
     * Aggregate performance monitor: totals, error rate, response times and QPS, printed every 5 s
     * once {@link #startMonitoring()} is called.
     */
    public static class PerformanceMonitor {
        private final LongAdder totalRequests = new LongAdder();
        private final LongAdder totalErrors = new LongAdder();
        private final LongAdder totalResponseTime = new LongAdder(); // successful requests only
        private final AtomicLong maxResponseTime = new AtomicLong(0);
        private final AtomicLong minResponseTime = new AtomicLong(Long.MAX_VALUE);
        private final SlidingWindowQPS qpsCalculator = new SlidingWindowQPS(60); // 60 s window
        private final ScheduledExecutorService monitor = Executors.newScheduledThreadPool(1);
        private final long startTime = System.currentTimeMillis();

        public void recordRequest(long responseTime, boolean success) {
            totalRequests.increment();
            qpsCalculator.increment();
            if (success) {
                totalResponseTime.add(responseTime);
                maxResponseTime.accumulateAndGet(responseTime, Math::max);
                minResponseTime.accumulateAndGet(responseTime, Math::min);
            } else {
                totalErrors.increment();
            }
        }

        public void startMonitoring() {
            monitor.scheduleAtFixedRate(this::printStatistics, 5, 5, TimeUnit.SECONDS);
        }

        private void printStatistics() {
            long totalReqs = totalRequests.sum();
            long totalErrs = totalErrors.sum();
            long totalRespTime = totalResponseTime.sum();
            long successes = totalReqs - totalErrs;
            // Response time is only accumulated for successes, so average over successes
            // (dividing by all requests understated it).
            double avgResponseTime = successes > 0 ? (double) totalRespTime / successes : 0;
            double errorRate = totalReqs > 0 ? (double) totalErrs / totalReqs * 100 : 0;
            long currentQPS = qpsCalculator.getCurrentQPS();
            long elapsedTime = (System.currentTimeMillis() - startTime) / 1000;
            double overallQPS = elapsedTime > 0 ? (double) totalReqs / elapsedTime : 0;
            System.out.println("=== 性能统计 ===");
            System.out.println("总请求数: " + totalReqs);
            System.out.println("错误数: " + totalErrs);
            System.out.println("错误率: " + String.format("%.2f%%", errorRate));
            System.out.println("当前QPS: " + currentQPS);
            System.out.println("平均QPS: " + String.format("%.2f", overallQPS));
            System.out.println("平均响应时间: " + String.format("%.2fms", avgResponseTime));
            System.out.println("最大响应时间: " + maxResponseTime.get() + "ms");
            System.out.println("最小响应时间: " + (minResponseTime.get() == Long.MAX_VALUE ? 0 : minResponseTime.get()) + "ms");
            System.out.println("运行时间: " + elapsedTime + "秒");
            System.out.println("==================");
        }

        public void shutdown() {
            qpsCalculator.shutdown();
            monitor.shutdown();
        }
    }

    /**
     * Load tester: {@code threadCount} workers send HTTP-like requests over fresh sockets for
     * {@code testDuration} seconds and record results in a {@link PerformanceMonitor}.
     */
    public static class LoadTester {
        /** Connect and read timeout per request, in milliseconds. */
        private static final int REQUEST_TIMEOUT_MS = 5000;

        private final String serverHost;
        private final int serverPort;
        private final int threadCount;
        private final int testDuration; // seconds
        private final PerformanceMonitor monitor = new PerformanceMonitor();
        private final ExecutorService executor;

        public LoadTester(String serverHost, int serverPort, int threadCount, int testDuration) {
            this.serverHost = serverHost;
            this.serverPort = serverPort;
            this.threadCount = threadCount;
            this.testDuration = testDuration;
            this.executor = Executors.newFixedThreadPool(threadCount);
        }

        public void runTest() throws InterruptedException {
            System.out.println("开始负载测试...");
            System.out.println("目标服务器: " + serverHost + ":" + serverPort);
            System.out.println("线程数: " + threadCount);
            System.out.println("测试时长: " + testDuration + "秒");
            monitor.startMonitoring();
            long endTime = System.currentTimeMillis() + testDuration * 1000L;
            CountDownLatch latch = new CountDownLatch(threadCount);
            for (int i = 0; i < threadCount; i++) {
                executor.submit(() -> {
                    try {
                        while (System.currentTimeMillis() < endTime) {
                            performRequest();
                            Thread.sleep(1); // throttle to avoid runaway load
                        }
                    } catch (InterruptedException e) {
                        Thread.currentThread().interrupt();
                    } finally {
                        latch.countDown();
                    }
                });
            }
            latch.await();
            monitor.shutdown();
            executor.shutdown();
            System.out.println("负载测试完成!");
        }

        /** One request over a fresh socket; success means the status line contains "200". */
        private void performRequest() {
            long startTime = System.currentTimeMillis();
            boolean success = false;
            // try-with-resources: the socket used to leak whenever connect/read threw before close().
            try (java.net.Socket socket = new java.net.Socket()) {
                socket.connect(new java.net.InetSocketAddress(serverHost, serverPort), REQUEST_TIMEOUT_MS);
                socket.setSoTimeout(REQUEST_TIMEOUT_MS);
                java.io.PrintWriter writer = new java.io.PrintWriter(socket.getOutputStream(), true);
                java.io.BufferedReader reader = new java.io.BufferedReader(
                        new java.io.InputStreamReader(socket.getInputStream())
                );
                writer.println("GET /api/test HTTP/1.1");
                writer.println("Host: " + serverHost);
                writer.println("Connection: close");
                writer.println();
                String response = reader.readLine();
                success = response != null && response.contains("200");
            } catch (Exception e) {
                success = false; // worker boundary: any failure counts as an error
            } finally {
                long responseTime = System.currentTimeMillis() - startTime;
                monitor.recordRequest(responseTime, success);
            }
        }
    }

    public static void main(String[] args) throws InterruptedException {
        if (args.length < 4) {
            System.out.println("用法: java QPSCalculator <host> <port> <threads> <duration>");
            return;
        }
        String host = args[0];
        int port = Integer.parseInt(args[1]);
        int threads = Integer.parseInt(args[2]);
        int duration = Integer.parseInt(args[3]);
        LoadTester tester = new LoadTester(host, port, threads, duration);
        tester.runTest();
    }
}

浙公网安备 33010602011771号