Java通过共享内存调用机器学习python代码示例
背景:弱鸡的java不支持丰富的机器学习算法。
需求:python实现了一个bert分类,希望给java代码调用。因此,使用共享内存的方式实现跨进程调用。
在进程间通信(IPC)中,性能是一个重要的考虑因素。以下是几种常见的IPC方式及其性能比较:
1. Socket通信:
- 优点: 跨网络和本地通信都可以使用,灵活性高。
- 缺点: 相对较慢,因为涉及网络协议栈的开销。
- 适用场景: 分布式系统、跨主机通信。
2. 管道(Pipes):
- 优点: 简单、易用,适合父子进程间通信。
- 缺点: 只能用于单向通信,且仅限于同一主机。
- 适用场景: 父子进程间的简单数据传输。
3. 共享内存(Shared Memory):
- 优点: 速度最快,因为数据直接在内存中共享。
- 缺点: 需要同步机制(如信号量)来避免竞争条件,编程复杂度较高。
- 适用场景: 高性能需求的进程间通信。
4. 消息队列(Message Queues):
- 优点: 支持消息的有序传递和优先级。
- 缺点: 相对较慢,适合中小规模数据传输。
- 适用场景: 需要消息排队和优先级的场景。
5. 信号量(Semaphores):
- 优点: 用于进程间的同步和互斥。
- 缺点: 只适用于同步,不适合大数据传输。
- 适用场景: 进程间的同步和资源管理。
性能比较
- 共享内存通常是最快的,因为它避免了数据在内核和用户空间之间的拷贝,直接在内存中共享数据。
- 管道和消息队列的性能次之,因为它们需要在内核和用户空间之间进行数据拷贝。
- Socket通信的性能最慢,尤其是跨网络通信时,因为涉及网络协议栈的开销。
选择建议
- 如果你需要在同一主机上的进程间进行高性能通信,共享内存是最佳选择。
- 如果你需要简单的父子进程间通信,管道是一个不错的选择。
- 如果你需要跨主机通信,Socket是唯一的选择。
Java代码:
import java.io.RandomAccessFile;
import java.nio.MappedByteBuffer;
import java.nio.channels.FileChannel;
import java.nio.charset.StandardCharsets;
public class Main {
private final static String FILE_PATH = "D:\\shared_memory.bin3";
private final static int FILE_SIZE = 1024*1024*1024 + 1; // 1MB for text, 1 byte for lock
public static void main(String[] args) {
try (RandomAccessFile memoryFile = new RandomAccessFile(FILE_PATH, "rw");
FileChannel fileChannel = memoryFile.getChannel()) {
MappedByteBuffer buffer = fileChannel.map(FileChannel.MapMode.READ_WRITE, 0, FILE_SIZE);
String textToClassify = "什么是sql注入?";
// Wait until the buffer is available for writing
while (buffer.get(0) != 0) {
System.out.println("Waiting for buffer to be available for writing...");
Thread.sleep(100);
}
// Write the text to the buffer
buffer.put(0, (byte) 1); // Mark as written
byte[] textBytes = textToClassify.getBytes(StandardCharsets.UTF_8);
buffer.position(1);
buffer.put(textBytes);
// Pad the remaining part with 0x00 (null character in C, '\0' character)
buffer.position(1 + textBytes.length);
buffer.put(new byte[1024 - textBytes.length]);
System.out.println("Written to memory: " + textToClassify);
// Polling for result
byte[] data = new byte[1024];
String classificationResult = "";
while (true) {
if (buffer.get(0) == 0) { // Check if buffer is available for writing
buffer.position(1);
buffer.get(data);
classificationResult = new String(data, StandardCharsets.UTF_8).trim();
System.out.println("Classification result received: " + classificationResult);
break;
}
Thread.sleep(1); // Sleep for a short time before polling again
}
} catch (Exception e) {
e.printStackTrace();
}
}
}
python代码:
# server.py
import json
import mmap
import time
import os
# from bert_classify import predict
# from sec_tool_rag import search_sectool_knowledge_base
def classify_text(text):
return 100
# predicted_label, prediction = predict(text)
# return predicted_label
FILE_PATH = 'D:\\shared_memory.bin3'
FILE_SIZE = 1024*1024*1024 + 1 # 1MB bytes for text, 1 byte for lock
if not os.path.exists(FILE_PATH):
with open(FILE_PATH, 'w+b') as f:
f.write(b'\x00' * FILE_SIZE)
with open(FILE_PATH, 'r+b') as f:
mm = mmap.mmap(f.fileno(), FILE_SIZE)
while True:
mm.seek(0)
lock = mm.read_byte()
if lock == 1: # Data written and ready to be read
mm.seek(1)
message = mm.read(1024).decode('utf-8').rstrip('\x00').strip()
print(f"Received message: {message}")
# tools = search_sectool_knowledge_base(message, topk=3)
# print(f"Sec tools for '{message}': {tools}")
label = classify_text(message)
print(f"Classified label: {label}")
# str2write = json.dumps(tools)
str2write = str(label)
# Write the classification result
mm.seek(1)
mm.write(str2write.encode().ljust(1024, b'\x00'))
mm.seek(0)
mm.write_byte(0) # Mark as available for writing
print("Result written and memory marked as available for writing")
time.sleep(0.001)
启动python程序,再启动java代码:
输出
Received message: 什么是sql注入? Classified label: 100 Result written and memory marked as available for writing ------------------------------------------ Written to memory: 什么是sql注入? Classification result received: 100

浙公网安备 33010602011771号