在 Kotlin 中使用 Netty
一、用 Kotlin 实现 Echo Server
1. 新建 Maven 项目

2.最终项目结构

pom.xml
<?xml version="1.0" encoding="UTF-8"?>
<!--
  Maven build for the Kotlin + Netty echo server / client demo.
  Kotlin sources are compiled from src/main/kotlin and src/test/kotlin by kotlin-maven-plugin.
-->
<project xmlns="http://maven.apache.org/POM/4.0.0"
         xmlns:xsi="http://www.w3.org/2001/XMLSchema-instance"
         xsi:schemaLocation="http://maven.apache.org/POM/4.0.0 http://maven.apache.org/xsd/maven-4.0.0.xsd">
    <modelVersion>4.0.0</modelVersion>
    <groupId>com.kt</groupId>
    <artifactId>kt-netty</artifactId>
    <version>1.0-SNAPSHOT</version>
    <properties>
        <project.build.sourceEncoding>UTF-8</project.build.sourceEncoding>
        <kotlin.code.style>official</kotlin.code.style>
        <kotlin.compiler.jvmTarget>1.8</kotlin.compiler.jvmTarget>
        <!-- Single source of truth for every org.jetbrains.kotlin artifact below. -->
        <kotlin.version>1.9.24</kotlin.version>
    </properties>
    <repositories>
        <repository>
            <id>mavenCentral</id>
            <url>https://repo1.maven.org/maven2/</url>
        </repository>
    </repositories>
    <build>
        <sourceDirectory>src/main/kotlin</sourceDirectory>
        <testSourceDirectory>src/test/kotlin</testSourceDirectory>
        <plugins>
            <plugin>
                <groupId>org.jetbrains.kotlin</groupId>
                <artifactId>kotlin-maven-plugin</artifactId>
                <version>${kotlin.version}</version>
                <executions>
                    <!-- An execution without <goals> binds nothing; the goals are required for Kotlin to compile. -->
                    <execution>
                        <id>compile</id>
                        <phase>compile</phase>
                        <goals>
                            <goal>compile</goal>
                        </goals>
                    </execution>
                    <execution>
                        <id>test-compile</id>
                        <phase>test-compile</phase>
                        <goals>
                            <goal>test-compile</goal>
                        </goals>
                    </execution>
                </executions>
                <configuration>
                    <jvmTarget>1.8</jvmTarget>
                </configuration>
            </plugin>
            <plugin>
                <artifactId>maven-surefire-plugin</artifactId>
                <version>2.22.2</version>
            </plugin>
            <plugin>
                <artifactId>maven-failsafe-plugin</artifactId>
                <version>2.22.2</version>
            </plugin>
            <plugin>
                <groupId>org.codehaus.mojo</groupId>
                <artifactId>exec-maven-plugin</artifactId>
                <version>1.6.0</version>
                <configuration>
                    <!-- The top-level main() in com/kt/AppMain.kt compiles to the JVM class com.kt.AppMainKt. -->
                    <mainClass>com.kt.AppMainKt</mainClass>
                </configuration>
            </plugin>
        </plugins>
    </build>
    <dependencies>
        <dependency>
            <groupId>org.jetbrains.kotlin</groupId>
            <artifactId>kotlin-test-junit5</artifactId>
            <!-- Keep in lockstep with the compiler/stdlib version instead of pinning a different release. -->
            <version>${kotlin.version}</version>
            <scope>test</scope>
        </dependency>
        <dependency>
            <groupId>org.junit.jupiter</groupId>
            <artifactId>junit-jupiter</artifactId>
            <version>5.10.0</version>
            <scope>test</scope>
        </dependency>
        <dependency>
            <groupId>org.jetbrains.kotlin</groupId>
            <artifactId>kotlin-stdlib-jdk8</artifactId>
            <version>${kotlin.version}</version>
        </dependency>
        <dependency>
            <groupId>io.netty</groupId>
            <artifactId>netty-all</artifactId>
            <version>4.1.119.Final</version>
        </dependency>
        <dependency>
            <groupId>org.jetbrains.kotlin</groupId>
            <artifactId>kotlin-test</artifactId>
            <version>${kotlin.version}</version>
            <scope>test</scope>
        </dependency>
    </dependencies>
</project>
3.代码
com.kt.AppMain.kt
package com.kt
import com.kt.server.NettyServer
/**
 * Entry point: starts the echo server.
 *
 * @param args optional; `args[0]` overrides the listening port (default 8081)
 * @throws IllegalStateException if `args[0]` is present but is not a port number in 1..65535
 */
fun main(args: Array<String>) {
    // Fall back to the demo's default port when no argument is given; reject malformed values loudly.
    val port = args.firstOrNull()?.let { arg ->
        arg.toIntOrNull()?.takeIf { it in 1..65535 } ?: error("Invalid port argument: $arg")
    } ?: 8081
    // Start the server
    val server = NettyServer(port)
    server.start()
}
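com.kt.server.NettyServer.kt
（注：下面这段代码的 package 实际是 com.kt.client，内容是客户端 NettyClient 及其 main 函数，与后文客户端部分完全重复；NettyServer 本身的源码（由 AppMain 以端口号构造并调用 start()）在本文中没有贴出。）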
package com.kt.client
import io.netty.bootstrap.Bootstrap
import io.netty.channel.*
import io.netty.channel.nio.NioEventLoopGroup
import io.netty.channel.socket.nio.NioSocketChannel
import io.netty.handler.codec.string.StringDecoder
import io.netty.handler.codec.string.StringEncoder
import io.netty.handler.timeout.IdleStateEvent
import io.netty.handler.timeout.IdleStateHandler
import io.netty.util.concurrent.DefaultThreadFactory
import java.util.*
import java.util.concurrent.TimeUnit
import java.util.concurrent.atomic.AtomicBoolean
/**
 * Netty TCP client for the demo server's text protocol.
 *
 * Outgoing messages are encoded with [StringEncoder]. Replies are decoded with [StringDecoder] and handled by
 * [ClientHandler], which recognises the server's `AUTH_SUCCESS:`, `AUTH_FAILED:`, `AUTH_REQUIRED:` and
 * `INVALID_COMMAND:` prefixes. When nothing has been written for 30 seconds the client sends `ping`.
 * Failed connection attempts are retried up to 5 times, 5 seconds apart, on a short-lived daemon timer thread.
 *
 * @param host server host name or address
 * @param port server TCP port
 * @param username optional user name. When both [username] and [password] are non-null, [connect] logs in
 * automatically and [connectAsync] sends the LOGIN command right after connecting.
 * @param password password used together with [username]
 * @param clientId identifier included in log output; defaults to a random UUID
 */
class NettyClient(
    private val host: String,
    private val port: Int,
    val username: String? = null,
    val password: String? = null,
    private val clientId: String = UUID.randomUUID().toString()
) {
    private val workerGroup = NioEventLoopGroup(
        0, // 0 = let Netty choose its default thread count
        DefaultThreadFactory("NettyClient")
    )
    private val bootstrap = Bootstrap()

    // Assigned by connect()/connectAsync() (the latter on an event-loop thread) and read from callers' threads.
    @Volatile
    private var channel: Channel? = null
    private val isConnected = AtomicBoolean(false)
    private val isAuthenticated = AtomicBoolean(false)

    // Set by close(); stops pending reconnect attempts and refuses new connects.
    private val isClosed = AtomicBoolean(false)

    @Volatile
    private var authToken: String? = null

    // Completed by ClientHandler when the reply to the latest LOGIN arrives, or with false when the channel drops.
    @Volatile
    private var pendingAuth: java.util.concurrent.CompletableFuture<Boolean>? = null

    @Volatile
    private var reconnectAttempts = 0
    private val maxReconnectAttempts = 5
    private val reconnectDelay = 5000L // 5 seconds

    // Upper bound on how long authenticate() blocks waiting for the server's reply.
    private val authTimeoutMillis = 5000L

    init {
        configureBootstrap()
    }

    /**
     * Configures the Bootstrap: NIO socket channel, keep-alive, no Nagle, 5 s connect timeout,
     * and the string codec + idle detection + [ClientHandler] pipeline.
     */
    private fun configureBootstrap() {
        bootstrap.group(workerGroup)
            .channel(NioSocketChannel::class.java)
            .option(ChannelOption.SO_KEEPALIVE, true)
            .option(ChannelOption.TCP_NODELAY, true)
            .option(ChannelOption.CONNECT_TIMEOUT_MILLIS, 5000)
            .handler(object : ChannelInitializer<NioSocketChannel>() {
                override fun initChannel(ch: NioSocketChannel) {
                    ch.pipeline().apply {
                        addLast("decoder", StringDecoder())
                        addLast("encoder", StringEncoder())
                        // Heartbeat: fire WRITER_IDLE after 30 seconds without any write
                        addLast("idleStateHandler", IdleStateHandler(0, 30, 0, TimeUnit.SECONDS))
                        addLast("clientHandler", ClientHandler())
                    }
                }
            })
    }

    /**
     * Connects synchronously. When credentials were supplied, it also logs in and waits (bounded) for the reply.
     *
     * On failure a delayed reconnect is scheduled (see [handleReconnect]).
     *
     * @return the authentication result when credentials were supplied; otherwise whether the channel is active
     */
    fun connect(): Boolean {
        if (isClosed.get()) {
            println("Client is closed, cannot connect")
            return false
        }
        return try {
            val newChannel = bootstrap.connect(host, port).sync().channel()
            channel = newChannel
            isConnected.set(newChannel.isActive)
            if (isConnected.get()) {
                reconnectAttempts = 0 // reset the retry budget after a successful connection
                println("Connected to server $host:$port with client ID: $clientId")
                // Log in automatically when credentials were supplied to the constructor
                if (username != null && password != null) {
                    return authenticate(username, password)
                }
            }
            isConnected.get()
        } catch (e: Exception) {
            println("Failed to connect to server: ${e.message}")
            handleReconnect()
            false
        }
    }

    /**
     * Sends `LOGIN <username> <password>` and waits up to 5 seconds for the server's AUTH_SUCCESS / AUTH_FAILED reply.
     *
     * If called on the channel's own event-loop thread it does not wait, because blocking that thread would stop
     * the reply from ever being read. In that case it returns the current authentication state.
     *
     * @return true if the server accepted the credentials within the timeout
     */
    fun authenticate(username: String, password: String): Boolean {
        if (!isConnected.get()) {
            println("Not connected to server")
            return false
        }
        val reply = java.util.concurrent.CompletableFuture<Boolean>()
        // Register before sending so that a fast reply cannot be missed.
        pendingAuth = reply
        if (!send("LOGIN $username $password")) {
            return false
        }
        if (channel?.eventLoop()?.inEventLoop() == true) {
            return isAuthenticated.get()
        }
        return try {
            reply.get(authTimeoutMillis, TimeUnit.MILLISECONDS)
        } catch (e: java.util.concurrent.TimeoutException) {
            isAuthenticated.get()
        } catch (e: InterruptedException) {
            Thread.currentThread().interrupt()
            isAuthenticated.get()
        }
    }

    /**
     * Connects asynchronously.
     *
     * @param callback invoked from the connect listener with the connection outcome only. When credentials were
     * supplied, LOGIN is sent without waiting; poll [isAuthenticated] for the login result.
     */
    fun connectAsync(callback: ((Boolean) -> Unit)? = null) {
        if (isClosed.get()) {
            callback?.invoke(false)
            return
        }
        bootstrap.connect(host, port).addListener(ChannelFutureListener { future ->
            if (future.isSuccess) {
                // A ChannelFuture's get() yields Void (null); the connected channel comes from channel().
                channel = future.channel()
                isConnected.set(true)
                reconnectAttempts = 0
                println("Connected to server $host:$port with client ID: $clientId")
                // This listener runs on the event loop, so never block here waiting for the LOGIN reply.
                if (username != null && password != null) {
                    send("LOGIN $username $password")
                }
                callback?.invoke(true)
            } else {
                println("Failed to connect to server: ${future.cause()?.message}")
                callback?.invoke(false)
                handleReconnect()
            }
        })
    }

    /**
     * Writes [message] to the server if the channel is active.
     *
     * @return true if the write was issued (not whether it succeeded on the wire)
     */
    fun send(message: String): Boolean {
        val current = channel
        return if (isConnected.get() && current != null && current.isActive) {
            current.writeAndFlush(message)
            true
        } else {
            println("Not connected to server, cannot send message")
            false
        }
    }

    /** Closes the current connection but keeps the event loop group, so [connect] may be called again. */
    fun disconnect() {
        try {
            channel?.close()?.sync()
        } catch (e: InterruptedException) {
            Thread.currentThread().interrupt()
        } finally {
            resetState()
            println("Disconnected from server")
        }
    }

    /** Closes the connection, cancels future reconnects and shuts down the event loop group. Terminal. */
    fun close() {
        isClosed.set(true)
        try {
            channel?.close()?.sync()
        } catch (e: InterruptedException) {
            Thread.currentThread().interrupt()
        } finally {
            // Uninterruptible: the interrupt flag may have just been restored above, and sync() would then throw
            // out of this finally block before the state is reset.
            workerGroup.shutdownGracefully(2, 5, TimeUnit.SECONDS).syncUninterruptibly()
            resetState()
            pendingAuth?.complete(false)
            println("Client closed")
        }
    }

    /** Clears connection and authentication state. */
    private fun resetState() {
        isConnected.set(false)
        isAuthenticated.set(false)
        authToken = null
    }

    /** @return whether the client currently considers itself connected */
    fun isConnected(): Boolean = isConnected.get()

    /** @return whether the server has accepted a LOGIN on the current connection */
    fun isAuthenticated(): Boolean = isAuthenticated.get()

    /** @return the client identifier */
    fun getClientId(): String = clientId

    /** @return the token from the last AUTH_SUCCESS reply, or null */
    fun getAuthToken(): String? = authToken

    /**
     * Schedules a single delayed [connect] attempt, up to [maxReconnectAttempts] times.
     *
     * Each attempt uses its own daemon Timer that is cancelled after running, so no timer thread outlives it
     * or keeps the JVM alive. Nothing is scheduled once [close] has been called.
     */
    private fun handleReconnect() {
        if (isClosed.get()) return
        if (reconnectAttempts < maxReconnectAttempts) {
            reconnectAttempts++
            println("Attempting to reconnect... ($reconnectAttempts/$maxReconnectAttempts)")
            // Delayed reconnect
            val timer = Timer("NettyClient-reconnect", true)
            timer.schedule(object : TimerTask() {
                override fun run() {
                    try {
                        if (!isClosed.get()) connect()
                    } finally {
                        timer.cancel()
                    }
                }
            }, reconnectDelay)
        } else {
            println("Max reconnect attempts reached. Giving up.")
        }
    }

    /**
     * Inbound handler: parses server replies, tracks authentication state and answers WRITER_IDLE with `ping`.
     */
    inner class ClientHandler : ChannelInboundHandlerAdapter() {
        override fun channelRead(ctx: ChannelHandlerContext?, msg: Any?) {
            val chunk = msg?.toString() ?: return
            // No frame decoder is installed, so one read may carry several '\n'-terminated replies
            // (e.g. "AUTH_SUCCESS: <token>\n" plus the welcome line). Handle them one line at a time.
            chunk.lineSequence()
                .map { it.trim() }
                .filter { it.isNotEmpty() }
                .forEach { handleLine(it) }
        }

        /** Dispatches a single server line by its prefix. */
        private fun handleLine(message: String) {
            println("Received from server: $message")
            when {
                message.startsWith("AUTH_SUCCESS:") -> {
                    authToken = message.substringAfter("AUTH_SUCCESS:").trim()
                    isAuthenticated.set(true)
                    pendingAuth?.complete(true)
                    println("Authentication successful! Token: $authToken")
                }
                message.startsWith("AUTH_FAILED:") -> {
                    val errorMsg = message.substringAfter("AUTH_FAILED:").trim()
                    isAuthenticated.set(false)
                    pendingAuth?.complete(false)
                    println("Authentication failed: $errorMsg")
                }
                message.startsWith("INVALID_COMMAND:") -> {
                    // The server rejected a malformed LOGIN; do not keep authenticate() waiting for a reply.
                    pendingAuth?.complete(false)
                    println("Message: $message")
                }
                message.startsWith("AUTH_REQUIRED:") -> {
                    println("Authentication required: $message")
                }
                message == "pong" || message == "heartbeat" -> {
                    println("Heartbeat received from server")
                }
                else -> {
                    // Any other message
                    println("Message: $message")
                }
            }
        }

        override fun channelActive(ctx: ChannelHandlerContext?) {
            println("Connected to server: ${ctx?.channel()?.remoteAddress()}")
        }

        override fun channelInactive(ctx: ChannelHandlerContext?) {
            println("Disconnected from server: ${ctx?.channel()?.remoteAddress()}")
            resetState()
            // Unblock an authenticate() call still waiting for a reply that can no longer arrive.
            pendingAuth?.complete(false)
        }

        override fun userEventTriggered(ctx: ChannelHandlerContext?, evt: Any?) {
            if (evt is IdleStateEvent) {
                when (evt.state()) {
                    io.netty.handler.timeout.IdleState.WRITER_IDLE -> {
                        println("Writer idle, sending ping to server...")
                        ctx?.writeAndFlush("ping")
                    }
                    else -> {
                        // Other idle states are not configured (reader/all idle times are 0)
                    }
                }
            }
            super.userEventTriggered(ctx, evt)
        }

        override fun exceptionCaught(ctx: ChannelHandlerContext?, cause: Throwable?) {
            cause?.printStackTrace()
            println("Exception in client: ${cause?.message}")
            ctx?.close()
        }
    }
}
/**
 * Interactive console demo for [NettyClient].
 *
 * Connects to 127.0.0.1:8081, prints the command list and forwards each entered line to the server until the
 * user types `exit`, standard input ends, or the connection drops. The client is always closed on the way out.
 */
fun main() {
    // Option 1: supply credentials up front; connect() then logs in automatically
    // val client = NettyClient("127.0.0.1", 8081, "user1", "password1", "Client-${System.currentTimeMillis()}")
    // Option 2: authenticate manually with the AUTH command
    val client = NettyClient("127.0.0.1", 8081, null, null, "Client-${System.currentTimeMillis()}")
    val whitespace = Regex("\\s+")
    try {
        // With credentials, connect() itself waits (bounded) for the LOGIN reply. The state checks below
        // distinguish "could not connect" from "connected but not authenticated".
        client.connect()
        if (!client.isConnected()) {
            println("Failed to connect to server")
            return
        }
        if (client.username != null && client.password != null && !client.isAuthenticated()) {
            println("Authentication timeout")
            return
        }

        println("Enter messages to send (type 'exit' to quit):")
        println("Commands:")
        println(" exit - Quit the application")
        println(" ping - Send ping to server")
        println(" sessions - Get session count")
        println(" users - Get authenticated user count")
        println(" AUTH <username> <password> - Authenticate manually")
        println(" SEND <sessionId> <message> - Send private message")

        while (client.isConnected()) {
            print("Enter command: ")
            // readlnOrNull() returns null at end of input (Ctrl+D, closed pipe) instead of throwing
            val input = readlnOrNull()?.trim() ?: break
            val tokens = input.split(whitespace)
            when {
                input.equals("exit", ignoreCase = true) -> {
                    println("Exiting...")
                    break
                }
                // Match the AUTH keyword exactly, not any word that merely starts with "AUTH"
                tokens.first().equals("AUTH", ignoreCase = true) -> {
                    if (tokens.size == 3) {
                        client.authenticate(tokens[1], tokens[2])
                    } else {
                        println("Usage: AUTH <username> <password>")
                    }
                }
                input.isNotEmpty() -> {
                    if (client.isAuthenticated() || input.startsWith("LOGIN", true)) {
                        client.send(input)
                        println("Sent: $input")
                    } else {
                        println("Please authenticate first. Use: AUTH <username> <password>")
                    }
                }
            }
            // Short pause so the server's reply is usually printed before the next prompt
            Thread.sleep(100)
        }
    } catch (e: Exception) {
        e.printStackTrace()
    } finally {
        client.close()
    }
}
com.kt.server.handler.AuthHandler.kt
package com.kt.server.handler
import io.netty.channel.ChannelHandlerContext
import io.netty.channel.ChannelInboundHandlerAdapter
import java.util.*
import java.util.concurrent.ConcurrentHashMap
/**
 * Identity record created for one successful login.
 *
 * @property userId identifier assigned to the logged-in user
 * @property username the name the user logged in with
 * @property token session token issued for this login
 * @property loginTime login timestamp in epoch milliseconds; defaults to the time of construction
 */
data class AuthInfo(
    val userId: String,
    val username: String,
    val token: String,
    val loginTime: Long = System.currentTimeMillis(),
)
/**
 * Outcome of a credential check.
 *
 * @property success whether the credentials were accepted
 * @property message human-readable status text sent back to the client
 * @property token issued session token; null when [success] is false
 * @property authInfo identity record for the login; null when [success] is false
 */
data class AuthResult(
    val success: Boolean,
    val message: String,
    val token: String?,
    val authInfo: AuthInfo?,
)
/**
 * In-memory authentication registry.
 *
 * Tracks connections that have not logged in yet (session id -> connect time) and logged-in sessions
 * (token -> [AuthInfo]). Credentials come from a hard-coded demo table; a real application would load them
 * from a user store and compare password hashes.
 *
 * Every collection is concurrent. One instance is presumably shared by the handlers of all connections —
 * TODO confirm against NettyServer, which is not shown.
 *
 * @param pendingSessionTimeoutMillis how long a connection may stay unauthenticated before its pending entry is
 * dropped by the cleanup that runs on each [addPendingSession]; defaults to 30 seconds
 */
class AuthManager(
    private val pendingSessionTimeoutMillis: Long = 30_000L,
) {
    // Logged-in sessions: token -> AuthInfo
    private val authenticatedSessions = ConcurrentHashMap<String, AuthInfo>()

    // Connections that have not authenticated yet: session id -> connect time (epoch ms)
    private val pendingSessions = ConcurrentHashMap<String, Long>()

    // Demo credentials (a real application would read these from a database or configuration)
    private val validCredentials = mapOf(
        "user1" to "password1",
        "user2" to "password2",
        "admin" to "admin123"
    )

    // Tokens issued by authenticate(). It must be a concurrent set like the maps above: a plain mutableSetOf()
    // (LinkedHashSet) is not thread-safe, and this set is mutated by authenticate()/removeAuthenticatedSession().
    // A real application would use JWT or a similar token mechanism.
    private val validTokens: MutableSet<String> = ConcurrentHashMap.newKeySet<String>()

    /**
     * Registers a freshly connected, not-yet-authenticated session.
     *
     * It also prunes pending entries older than [pendingSessionTimeoutMillis]. Pruning only forgets the entry;
     * it does not close the connection.
     */
    fun addPendingSession(sessionId: String) {
        pendingSessions[sessionId] = System.currentTimeMillis()
        cleanupExpiredPendingSessions()
    }

    /**
     * Checks [username] / [password] against the credential table.
     *
     * @return on success, an [AuthResult] carrying a newly generated token (already accepted by [validateToken])
     * and an [AuthInfo] with a random user id; otherwise a failed result with null token and info
     */
    fun authenticate(username: String, password: String): AuthResult {
        if (validCredentials[username] != password) {
            return AuthResult(false, "Invalid username or password", null, null)
        }
        val token = generateToken()
        val authInfo = AuthInfo(UUID.randomUUID().toString(), username, token)
        validTokens.add(token)
        return AuthResult(true, "Authentication successful", token, authInfo)
    }

    /** @return true if [token] was issued by [authenticate] and has not been removed */
    fun validateToken(token: String): Boolean = validTokens.contains(token)

    /** Records a logged-in session under its token. */
    fun recordAuthenticatedSession(token: String, authInfo: AuthInfo) {
        authenticatedSessions[token] = authInfo
    }

    /** Forgets a logged-in session and invalidates its token. Safe to call more than once. */
    fun removeAuthenticatedSession(token: String) {
        authenticatedSessions.remove(token)
        validTokens.remove(token)
    }

    /** @return the identity recorded for [token], or null if unknown */
    fun getAuthInfo(token: String): AuthInfo? = authenticatedSessions[token]

    /** Removes a pending (unauthenticated) session entry. */
    fun removePendingSession(sessionId: String) {
        pendingSessions.remove(sessionId)
    }

    /** Drops pending entries whose connect time is more than [pendingSessionTimeoutMillis] in the past. */
    private fun cleanupExpiredPendingSessions() {
        val cutoff = System.currentTimeMillis() - pendingSessionTimeoutMillis
        // ConcurrentHashMap iteration tolerates concurrent removal.
        for ((sessionId, connectedAt) in pendingSessions) {
            // Conditional remove: skip entries that were re-registered with a fresh timestamp in the meantime.
            if (connectedAt < cutoff && pendingSessions.remove(sessionId, connectedAt)) {
                println("Removed expired pending session: $sessionId")
            }
        }
    }

    /** Builds a token from a random UUID (dashes stripped) plus the last 6 digits of the current time. */
    private fun generateToken(): String {
        return UUID.randomUUID().toString().replace("-", "") +
            System.currentTimeMillis().toString().takeLast(6)
    }

    /** @return the number of currently recorded logged-in sessions */
    fun getAuthenticatedUserCount(): Int = authenticatedSessions.size
}
/**
 * Gatekeeper handler: until a connection logs in with `LOGIN <username> <password>`, it answers every message
 * itself. After a successful login it forwards all messages to the next inbound handler (presumably
 * ServerHandler — the pipeline setup is not shown).
 *
 * It keeps per-connection state (session id, auth flag, token), so a new instance is needed for every channel.
 * When the connection goes away, the pending entry or the logged-in session (and its token) is removed from
 * [authManager].
 */
class AuthHandler(private val authManager: AuthManager) : ChannelInboundHandlerAdapter() {
    private var sessionId: String? = null
    private var isAuthenticated = false

    // Token issued to this connection; kept so it can be revoked when the connection closes.
    private var authToken: String? = null

    override fun channelActive(ctx: ChannelHandlerContext?) {
        // Create a session id and register it as pending authentication
        val id = UUID.randomUUID().toString()
        sessionId = id
        authManager.addPendingSession(id)
        println("New connection established, session ID: $id. Waiting for authentication...")
        // Prompt the client to log in
        ctx?.writeAndFlush("Please authenticate: LOGIN <username> <password>\n")
        super.channelActive(ctx)
    }

    override fun channelRead(ctx: ChannelHandlerContext?, msg: Any?) {
        if (isAuthenticated) {
            // Authenticated: hand the message to the next handler, which now owns it
            ctx?.fireChannelRead(msg)
        } else {
            try {
                handleAuthentication(ctx, msg)
            } finally {
                // This handler consumes the message; release it in case it is reference counted (no-op for Strings)
                io.netty.util.ReferenceCountUtil.release(msg)
            }
        }
    }

    /** Handles one message from an unauthenticated connection: LOGIN, QUIT, or anything else (rejected). */
    private fun handleAuthentication(ctx: ChannelHandlerContext?, msg: Any?) {
        val message = msg?.toString()?.trim() ?: return
        val parts = message.split(WHITESPACE)
        when {
            // The first word must be exactly LOGIN (case-insensitive), not merely start with it
            parts.first().equals("LOGIN", ignoreCase = true) -> handleLogin(ctx, parts)
            message.equals("QUIT", ignoreCase = true) -> {
                ctx?.writeAndFlush("Goodbye!\n")
                ctx?.close()
            }
            else -> {
                // Any other message before authentication
                ctx?.writeAndFlush("AUTH_REQUIRED: Please authenticate first. Usage: LOGIN <username> <password>\n")
            }
        }
    }

    /** Validates `LOGIN <username> <password>` (already split on whitespace) and replies with the outcome. */
    private fun handleLogin(ctx: ChannelHandlerContext?, parts: List<String>) {
        if (parts.size != 3) {
            ctx?.writeAndFlush("INVALID_COMMAND: Usage: LOGIN <username> <password>\n")
            return
        }
        val username = parts[1]
        val password = parts[2]
        val authResult = authManager.authenticate(username, password)
        val token = authResult.token
        val info = authResult.authInfo
        if (authResult.success && token != null && info != null) {
            isAuthenticated = true
            authToken = token
            authManager.recordAuthenticatedSession(token, info)
            sessionId?.let(authManager::removePendingSession)
            ctx?.writeAndFlush("AUTH_SUCCESS: $token\n")
            ctx?.writeAndFlush("Welcome ${info.username}! You are now authenticated.\n")
            println("User $username authenticated successfully with session $sessionId")
        } else {
            ctx?.writeAndFlush("AUTH_FAILED: ${authResult.message}\n")
            println("Authentication failed for session $sessionId: ${authResult.message}")
        }
    }

    override fun channelInactive(ctx: ChannelHandlerContext?) {
        releaseSession()
        println("Connection closed for session: $sessionId")
        super.channelInactive(ctx)
    }

    override fun exceptionCaught(ctx: ChannelHandlerContext?, cause: Throwable?) {
        cause?.printStackTrace()
        releaseSession()
        ctx?.close()
    }

    /**
     * Removes this connection's pending entry, or its logged-in session and token, from [authManager].
     * Idempotent: exceptionCaught and the channelInactive that follows the close may both call it.
     */
    private fun releaseSession() {
        sessionId?.let { id ->
            if (!isAuthenticated) {
                authManager.removePendingSession(id)
            }
        }
        authToken?.let { token ->
            authManager.removeAuthenticatedSession(token)
            authToken = null
        }
    }

    private companion object {
        // Splits commands on any run of whitespace so "LOGIN  user  pass" still parses.
        val WHITESPACE = Regex("\\s+")
    }
}
com.kt.server.handler.ServerHandler.kt
package com.kt.server.handler
import io.netty.channel.Channel
import io.netty.channel.ChannelHandlerContext
import io.netty.channel.ChannelInboundHandlerAdapter
import io.netty.handler.timeout.IdleStateEvent
import java.net.InetSocketAddress
import java.util.*
import java.util.concurrent.ConcurrentHashMap
/**
 * One client connection tracked by the session manager.
 *
 * @property id session identifier supplied by the creator
 * @property channel the Netty channel backing this session
 */
class UserSession(val id: String, val channel: Channel) {
    /** Peer address, captured at construction. The cast assumes a connected socket channel. */
    val remoteAddress: InetSocketAddress = channel.remoteAddress() as InetSocketAddress

    /** Creation time in epoch milliseconds. */
    val connectTime: Long = System.currentTimeMillis()

    /** Last activity time in epoch milliseconds. [sendMessage] refreshes it; callers may also assign it. */
    var lastActiveTime: Long = System.currentTimeMillis()

    /**
     * Writes [message] if the channel is active and refreshes [lastActiveTime].
     *
     * @return false when the channel is inactive and nothing was written
     */
    fun sendMessage(message: String): Boolean {
        if (!channel.isActive) {
            return false
        }
        channel.writeAndFlush(message)
        lastActiveTime = System.currentTimeMillis()
        return true
    }

    /** @return whether the underlying channel is active */
    fun isActive(): Boolean = channel.isActive

    /** Closes the channel if it is still open. */
    fun close() {
        if (!channel.isOpen) return
        channel.close()
    }

    override fun toString(): String =
        "UserSession(id='$id', remoteAddress=$remoteAddress, connectTime=$connectTime)"
}
/**
 * Registry of [UserSession]s keyed by session id, backed by a [ConcurrentHashMap].
 *
 * The class name's spelling is kept unchanged because other code refers to it.
 */
class SessionManger {
    private val sessions = ConcurrentHashMap<String, UserSession>()

    /** Registers [session] under its id, replacing any session with the same id. */
    fun addSession(session: UserSession) {
        sessions[session.id] = session
        println("Session added: ${session.id}, total sessions: ${sessions.size}")
    }

    /** Removes the session with [sessionId]. Logs only if a session was actually removed. */
    fun removeSession(sessionId: String) {
        sessions.remove(sessionId)?.also {
            println("Session removed: $sessionId, total sessions: ${sessions.size}")
        }
    }

    /** @return the session with [sessionId], or null */
    fun getSession(sessionId: String): UserSession? = sessions[sessionId]

    /**
     * @return a live but read-only view of the current sessions. Callers cannot add or remove entries through it,
     * so the registry can only be changed via [addSession] / [removeSession].
     */
    fun getAllSessions(): Collection<UserSession> = Collections.unmodifiableCollection(sessions.values)

    /** @return the number of registered sessions */
    fun getSessionCount(): Int = sessions.size

    /** Sends [message] to every registered session; inactive sessions are skipped by [UserSession.sendMessage]. */
    fun broadcast(message: String) {
        sessions.values.forEach { session ->
            session.sendMessage(message)
        }
    }

    /**
     * Sends [message] to the session with [sessionId].
     *
     * @return true only if the session exists and its channel was active, so the message was actually written
     */
    fun sendToSession(sessionId: String, message: String): Boolean =
        sessions[sessionId]?.sendMessage(message) ?: false
}
/**
 * Per-connection handler behind authentication. It does the following:
 * - registers a [UserSession] with [sessionManager];
 * - answers the `ping` and `sessions` commands and echoes everything else;
 * - reacts to idle events raised by an upstream IdleStateHandler (presumably installed by NettyServer, not shown).
 *
 * It keeps per-connection state ([session]), so a new instance is needed for every channel.
 */
class ServerHandler(private val sessionManager: SessionManger) : ChannelInboundHandlerAdapter() {
    private var session: UserSession? = null

    override fun channelRead(ctx: ChannelHandlerContext?, msg: Any?) {
        val current = session ?: return
        current.lastActiveTime = System.currentTimeMillis()
        println("Received from client[${current.id}]: $msg")
        // Match commands on trimmed text, as AuthHandler does, so "ping\n" or "ping\r\n" is still recognised
        when (msg?.toString()?.trim()) {
            // Heartbeat reply
            "ping" -> current.sendMessage("pong")
            // Report the current session count
            "sessions" -> current.sendMessage("Current sessions: ${sessionManager.getSessionCount()}")
            // Echo anything else back to the client
            else -> current.sendMessage("Echo: $msg")
        }
    }

    override fun channelActive(ctx: ChannelHandlerContext?) {
        val context = requireNotNull(ctx) { "ChannelHandlerContext must not be null" }
        // Create and register the user session
        val sessionId = UUID.randomUUID().toString()
        val newSession = UserSession(sessionId, context.channel())
        session = newSession
        sessionManager.addSession(newSession)
        println("Client connected: ${newSession.remoteAddress}, session ID: $sessionId")
        newSession.sendMessage("Welcome! Your session ID is: $sessionId")
        super.channelActive(ctx)
    }

    override fun channelInactive(ctx: ChannelHandlerContext?) {
        session?.let {
            println("Client disconnected: ${it.remoteAddress}, session ID: ${it.id}")
            sessionManager.removeSession(it.id)
        }
        super.channelInactive(ctx)
    }

    override fun userEventTriggered(ctx: ChannelHandlerContext?, evt: Any?) {
        if (evt is IdleStateEvent) {
            session?.let {
                println("Idle event triggered for session ${it.id}: ${evt.state()}")
                when (evt.state()) {
                    // Reader idle: the client has not sent anything for a while
                    io.netty.handler.timeout.IdleState.READER_IDLE -> {
                        println("Client ${it.id} reader idle, sending ping...")
                        it.sendMessage("ping")
                    }
                    // Writer idle: the server has not sent anything for a while
                    io.netty.handler.timeout.IdleState.WRITER_IDLE -> {
                        println("Client ${it.id} writer idle, sending heartbeat...")
                        it.sendMessage("heartbeat")
                    }
                    // All idle: no traffic in either direction; drop the connection
                    io.netty.handler.timeout.IdleState.ALL_IDLE -> {
                        println("Client ${it.id} all idle, closing connection...")
                        it.sendMessage("Connection timeout, closing...")
                        it.close()
                    }
                }
            }
        }
        super.userEventTriggered(ctx, evt)
    }

    override fun exceptionCaught(ctx: ChannelHandlerContext?, cause: Throwable?) {
        cause?.printStackTrace()
        session?.let {
            println("Exception in session ${it.id}: ${cause?.message}")
            sessionManager.removeSession(it.id)
        }
        ctx?.close()
    }
}
4. 客户端（package com.kt.client，NettyClient）
package com.kt.client
import io.netty.bootstrap.Bootstrap
import io.netty.channel.*
import io.netty.channel.nio.NioEventLoopGroup
import io.netty.channel.socket.nio.NioSocketChannel
import io.netty.handler.codec.string.StringDecoder
import io.netty.handler.codec.string.StringEncoder
import io.netty.handler.timeout.IdleStateEvent
import io.netty.handler.timeout.IdleStateHandler
import io.netty.util.concurrent.DefaultThreadFactory
import java.util.*
import java.util.concurrent.TimeUnit
import java.util.concurrent.atomic.AtomicBoolean
/**
 * Netty TCP client for the demo server's text protocol.
 *
 * Outgoing messages are encoded with [StringEncoder]. Replies are decoded with [StringDecoder] and handled by
 * [ClientHandler], which recognises the server's `AUTH_SUCCESS:`, `AUTH_FAILED:`, `AUTH_REQUIRED:` and
 * `INVALID_COMMAND:` prefixes. When nothing has been written for 30 seconds the client sends `ping`.
 * Failed connection attempts are retried up to 5 times, 5 seconds apart, on a short-lived daemon timer thread.
 *
 * @param host server host name or address
 * @param port server TCP port
 * @param username optional user name. When both [username] and [password] are non-null, [connect] logs in
 * automatically and [connectAsync] sends the LOGIN command right after connecting.
 * @param password password used together with [username]
 * @param clientId identifier included in log output; defaults to a random UUID
 */
class NettyClient(
    private val host: String,
    private val port: Int,
    val username: String? = null,
    val password: String? = null,
    private val clientId: String = UUID.randomUUID().toString()
) {
    private val workerGroup = NioEventLoopGroup(
        0, // 0 = let Netty choose its default thread count
        DefaultThreadFactory("NettyClient")
    )
    private val bootstrap = Bootstrap()

    // Assigned by connect()/connectAsync() (the latter on an event-loop thread) and read from callers' threads.
    @Volatile
    private var channel: Channel? = null
    private val isConnected = AtomicBoolean(false)
    private val isAuthenticated = AtomicBoolean(false)

    // Set by close(); stops pending reconnect attempts and refuses new connects.
    private val isClosed = AtomicBoolean(false)

    @Volatile
    private var authToken: String? = null

    // Completed by ClientHandler when the reply to the latest LOGIN arrives, or with false when the channel drops.
    @Volatile
    private var pendingAuth: java.util.concurrent.CompletableFuture<Boolean>? = null

    @Volatile
    private var reconnectAttempts = 0
    private val maxReconnectAttempts = 5
    private val reconnectDelay = 5000L // 5 seconds

    // Upper bound on how long authenticate() blocks waiting for the server's reply.
    private val authTimeoutMillis = 5000L

    init {
        configureBootstrap()
    }

    /**
     * Configures the Bootstrap: NIO socket channel, keep-alive, no Nagle, 5 s connect timeout,
     * and the string codec + idle detection + [ClientHandler] pipeline.
     */
    private fun configureBootstrap() {
        bootstrap.group(workerGroup)
            .channel(NioSocketChannel::class.java)
            .option(ChannelOption.SO_KEEPALIVE, true)
            .option(ChannelOption.TCP_NODELAY, true)
            .option(ChannelOption.CONNECT_TIMEOUT_MILLIS, 5000)
            .handler(object : ChannelInitializer<NioSocketChannel>() {
                override fun initChannel(ch: NioSocketChannel) {
                    ch.pipeline().apply {
                        addLast("decoder", StringDecoder())
                        addLast("encoder", StringEncoder())
                        // Heartbeat: fire WRITER_IDLE after 30 seconds without any write
                        addLast("idleStateHandler", IdleStateHandler(0, 30, 0, TimeUnit.SECONDS))
                        addLast("clientHandler", ClientHandler())
                    }
                }
            })
    }

    /**
     * Connects synchronously. When credentials were supplied, it also logs in and waits (bounded) for the reply.
     *
     * On failure a delayed reconnect is scheduled (see [handleReconnect]).
     *
     * @return the authentication result when credentials were supplied; otherwise whether the channel is active
     */
    fun connect(): Boolean {
        if (isClosed.get()) {
            println("Client is closed, cannot connect")
            return false
        }
        return try {
            val newChannel = bootstrap.connect(host, port).sync().channel()
            channel = newChannel
            isConnected.set(newChannel.isActive)
            if (isConnected.get()) {
                reconnectAttempts = 0 // reset the retry budget after a successful connection
                println("Connected to server $host:$port with client ID: $clientId")
                // Log in automatically when credentials were supplied to the constructor
                if (username != null && password != null) {
                    return authenticate(username, password)
                }
            }
            isConnected.get()
        } catch (e: Exception) {
            println("Failed to connect to server: ${e.message}")
            handleReconnect()
            false
        }
    }

    /**
     * Sends `LOGIN <username> <password>` and waits up to 5 seconds for the server's AUTH_SUCCESS / AUTH_FAILED reply.
     *
     * If called on the channel's own event-loop thread it does not wait, because blocking that thread would stop
     * the reply from ever being read. In that case it returns the current authentication state.
     *
     * @return true if the server accepted the credentials within the timeout
     */
    fun authenticate(username: String, password: String): Boolean {
        if (!isConnected.get()) {
            println("Not connected to server")
            return false
        }
        val reply = java.util.concurrent.CompletableFuture<Boolean>()
        // Register before sending so that a fast reply cannot be missed.
        pendingAuth = reply
        if (!send("LOGIN $username $password")) {
            return false
        }
        if (channel?.eventLoop()?.inEventLoop() == true) {
            return isAuthenticated.get()
        }
        return try {
            reply.get(authTimeoutMillis, TimeUnit.MILLISECONDS)
        } catch (e: java.util.concurrent.TimeoutException) {
            isAuthenticated.get()
        } catch (e: InterruptedException) {
            Thread.currentThread().interrupt()
            isAuthenticated.get()
        }
    }

    /**
     * Connects asynchronously.
     *
     * @param callback invoked from the connect listener with the connection outcome only. When credentials were
     * supplied, LOGIN is sent without waiting; poll [isAuthenticated] for the login result.
     */
    fun connectAsync(callback: ((Boolean) -> Unit)? = null) {
        if (isClosed.get()) {
            callback?.invoke(false)
            return
        }
        bootstrap.connect(host, port).addListener(ChannelFutureListener { future ->
            if (future.isSuccess) {
                // A ChannelFuture's get() yields Void (null); the connected channel comes from channel().
                channel = future.channel()
                isConnected.set(true)
                reconnectAttempts = 0
                println("Connected to server $host:$port with client ID: $clientId")
                // This listener runs on the event loop, so never block here waiting for the LOGIN reply.
                if (username != null && password != null) {
                    send("LOGIN $username $password")
                }
                callback?.invoke(true)
            } else {
                println("Failed to connect to server: ${future.cause()?.message}")
                callback?.invoke(false)
                handleReconnect()
            }
        })
    }

    /**
     * Writes [message] to the server if the channel is active.
     *
     * @return true if the write was issued (not whether it succeeded on the wire)
     */
    fun send(message: String): Boolean {
        val current = channel
        return if (isConnected.get() && current != null && current.isActive) {
            current.writeAndFlush(message)
            true
        } else {
            println("Not connected to server, cannot send message")
            false
        }
    }

    /** Closes the current connection but keeps the event loop group, so [connect] may be called again. */
    fun disconnect() {
        try {
            channel?.close()?.sync()
        } catch (e: InterruptedException) {
            Thread.currentThread().interrupt()
        } finally {
            resetState()
            println("Disconnected from server")
        }
    }

    /** Closes the connection, cancels future reconnects and shuts down the event loop group. Terminal. */
    fun close() {
        isClosed.set(true)
        try {
            channel?.close()?.sync()
        } catch (e: InterruptedException) {
            Thread.currentThread().interrupt()
        } finally {
            // Uninterruptible: the interrupt flag may have just been restored above, and sync() would then throw
            // out of this finally block before the state is reset.
            workerGroup.shutdownGracefully(2, 5, TimeUnit.SECONDS).syncUninterruptibly()
            resetState()
            pendingAuth?.complete(false)
            println("Client closed")
        }
    }

    /** Clears connection and authentication state. */
    private fun resetState() {
        isConnected.set(false)
        isAuthenticated.set(false)
        authToken = null
    }

    /** @return whether the client currently considers itself connected */
    fun isConnected(): Boolean = isConnected.get()

    /** @return whether the server has accepted a LOGIN on the current connection */
    fun isAuthenticated(): Boolean = isAuthenticated.get()

    /** @return the client identifier */
    fun getClientId(): String = clientId

    /** @return the token from the last AUTH_SUCCESS reply, or null */
    fun getAuthToken(): String? = authToken

    /**
     * Schedules a single delayed [connect] attempt, up to [maxReconnectAttempts] times.
     *
     * Each attempt uses its own daemon Timer that is cancelled after running, so no timer thread outlives it
     * or keeps the JVM alive. Nothing is scheduled once [close] has been called.
     */
    private fun handleReconnect() {
        if (isClosed.get()) return
        if (reconnectAttempts < maxReconnectAttempts) {
            reconnectAttempts++
            println("Attempting to reconnect... ($reconnectAttempts/$maxReconnectAttempts)")
            // Delayed reconnect
            val timer = Timer("NettyClient-reconnect", true)
            timer.schedule(object : TimerTask() {
                override fun run() {
                    try {
                        if (!isClosed.get()) connect()
                    } finally {
                        timer.cancel()
                    }
                }
            }, reconnectDelay)
        } else {
            println("Max reconnect attempts reached. Giving up.")
        }
    }

    /**
     * Inbound handler: parses server replies, tracks authentication state and answers WRITER_IDLE with `ping`.
     */
    inner class ClientHandler : ChannelInboundHandlerAdapter() {
        override fun channelRead(ctx: ChannelHandlerContext?, msg: Any?) {
            val chunk = msg?.toString() ?: return
            // No frame decoder is installed, so one read may carry several '\n'-terminated replies
            // (e.g. "AUTH_SUCCESS: <token>\n" plus the welcome line). Handle them one line at a time.
            chunk.lineSequence()
                .map { it.trim() }
                .filter { it.isNotEmpty() }
                .forEach { handleLine(it) }
        }

        /** Dispatches a single server line by its prefix. */
        private fun handleLine(message: String) {
            println("Received from server: $message")
            when {
                message.startsWith("AUTH_SUCCESS:") -> {
                    authToken = message.substringAfter("AUTH_SUCCESS:").trim()
                    isAuthenticated.set(true)
                    pendingAuth?.complete(true)
                    println("Authentication successful! Token: $authToken")
                }
                message.startsWith("AUTH_FAILED:") -> {
                    val errorMsg = message.substringAfter("AUTH_FAILED:").trim()
                    isAuthenticated.set(false)
                    pendingAuth?.complete(false)
                    println("Authentication failed: $errorMsg")
                }
                message.startsWith("INVALID_COMMAND:") -> {
                    // The server rejected a malformed LOGIN; do not keep authenticate() waiting for a reply.
                    pendingAuth?.complete(false)
                    println("Message: $message")
                }
                message.startsWith("AUTH_REQUIRED:") -> {
                    println("Authentication required: $message")
                }
                message == "pong" || message == "heartbeat" -> {
                    println("Heartbeat received from server")
                }
                else -> {
                    // Any other message
                    println("Message: $message")
                }
            }
        }

        override fun channelActive(ctx: ChannelHandlerContext?) {
            println("Connected to server: ${ctx?.channel()?.remoteAddress()}")
        }

        override fun channelInactive(ctx: ChannelHandlerContext?) {
            println("Disconnected from server: ${ctx?.channel()?.remoteAddress()}")
            resetState()
            // Unblock an authenticate() call still waiting for a reply that can no longer arrive.
            pendingAuth?.complete(false)
        }

        override fun userEventTriggered(ctx: ChannelHandlerContext?, evt: Any?) {
            if (evt is IdleStateEvent) {
                when (evt.state()) {
                    io.netty.handler.timeout.IdleState.WRITER_IDLE -> {
                        println("Writer idle, sending ping to server...")
                        ctx?.writeAndFlush("ping")
                    }
                    else -> {
                        // Other idle states are not configured (reader/all idle times are 0)
                    }
                }
            }
            super.userEventTriggered(ctx, evt)
        }

        override fun exceptionCaught(ctx: ChannelHandlerContext?, cause: Throwable?) {
            cause?.printStackTrace()
            println("Exception in client: ${cause?.message}")
            ctx?.close()
        }
    }
}
/**
 * Interactive console demo for [NettyClient].
 *
 * Connects to 127.0.0.1:8081, prints the command list and forwards each entered line to the server until the
 * user types `exit`, standard input ends, or the connection drops. The client is always closed on the way out.
 */
fun main() {
    // Option 1: supply credentials up front; connect() then logs in automatically
    // val client = NettyClient("127.0.0.1", 8081, "user1", "password1", "Client-${System.currentTimeMillis()}")
    // Option 2: authenticate manually with the AUTH command
    val client = NettyClient("127.0.0.1", 8081, null, null, "Client-${System.currentTimeMillis()}")
    val whitespace = Regex("\\s+")
    try {
        // With credentials, connect() itself waits (bounded) for the LOGIN reply. The state checks below
        // distinguish "could not connect" from "connected but not authenticated".
        client.connect()
        if (!client.isConnected()) {
            println("Failed to connect to server")
            return
        }
        if (client.username != null && client.password != null && !client.isAuthenticated()) {
            println("Authentication timeout")
            return
        }

        println("Enter messages to send (type 'exit' to quit):")
        println("Commands:")
        println(" exit - Quit the application")
        println(" ping - Send ping to server")
        println(" sessions - Get session count")
        println(" users - Get authenticated user count")
        println(" AUTH <username> <password> - Authenticate manually")
        println(" SEND <sessionId> <message> - Send private message")

        while (client.isConnected()) {
            print("Enter command: ")
            // readlnOrNull() returns null at end of input (Ctrl+D, closed pipe) instead of throwing
            val input = readlnOrNull()?.trim() ?: break
            val tokens = input.split(whitespace)
            when {
                input.equals("exit", ignoreCase = true) -> {
                    println("Exiting...")
                    break
                }
                // Match the AUTH keyword exactly, not any word that merely starts with "AUTH"
                tokens.first().equals("AUTH", ignoreCase = true) -> {
                    if (tokens.size == 3) {
                        client.authenticate(tokens[1], tokens[2])
                    } else {
                        println("Usage: AUTH <username> <password>")
                    }
                }
                input.isNotEmpty() -> {
                    if (client.isAuthenticated() || input.startsWith("LOGIN", true)) {
                        client.send(input)
                        println("Sent: $input")
                    } else {
                        println("Please authenticate first. Use: AUTH <username> <password>")
                    }
                }
            }
            // Short pause so the server's reply is usually printed before the next prompt
            Thread.sleep(100)
        }
    } catch (e: Exception) {
        e.printStackTrace()
    } finally {
        client.close()
    }
}
本文来自博客园,作者:一个小笨蛋,转载请注明原文链接:https://www.cnblogs.com/paylove/p/19084340

浙公网安备 33010602011771号