Kotlin中的Netty

一、用Kotlin实现Echo Server

1.新建maven项目

（此处原为截图：新建Maven项目）

2.最终项目结构

（此处原为截图：最终项目结构）

pom.xml

<?xml version="1.0" encoding="UTF-8"?>
<project xmlns="http://maven.apache.org/POM/4.0.0"
         xmlns:xsi="http://www.w3.org/2001/XMLSchema-instance"
         xsi:schemaLocation="http://maven.apache.org/POM/4.0.0 http://maven.apache.org/xsd/maven-4.0.0.xsd">
    <modelVersion>4.0.0</modelVersion>

    <groupId>com.kt</groupId>
    <artifactId>kt-netty</artifactId>
    <version>1.0-SNAPSHOT</version>

    <properties>
        <project.build.sourceEncoding>UTF-8</project.build.sourceEncoding>
        <kotlin.code.style>official</kotlin.code.style>
        <kotlin.compiler.jvmTarget>1.8</kotlin.compiler.jvmTarget>
        <kotlin.version>1.9.24</kotlin.version>
    </properties>

    <repositories>
        <repository>
            <id>mavenCentral</id>
            <url>https://repo1.maven.org/maven2/</url>
        </repository>
    </repositories>

    <build>
        <sourceDirectory>src/main/kotlin</sourceDirectory>
        <testSourceDirectory>src/test/kotlin</testSourceDirectory>
        <plugins>
            <plugin>
                <groupId>org.jetbrains.kotlin</groupId>
                <artifactId>kotlin-maven-plugin</artifactId>
                <version>${kotlin.version}</version>
                <executions>
                    <!-- An execution without <goals> runs nothing, so bind the Kotlin compile goals explicitly. -->
                    <execution>
                        <id>compile</id>
                        <phase>compile</phase>
                        <goals>
                            <goal>compile</goal>
                        </goals>
                    </execution>
                    <execution>
                        <id>test-compile</id>
                        <phase>test-compile</phase>
                        <goals>
                            <goal>test-compile</goal>
                        </goals>
                    </execution>
                </executions>
                <configuration>
                    <!-- Single source of truth for the bytecode target (see <properties>). -->
                    <jvmTarget>${kotlin.compiler.jvmTarget}</jvmTarget>
                </configuration>
            </plugin>
            <plugin>
                <artifactId>maven-surefire-plugin</artifactId>
                <version>2.22.2</version>
            </plugin>
            <plugin>
                <artifactId>maven-failsafe-plugin</artifactId>
                <version>2.22.2</version>
            </plugin>
            <plugin>
                <groupId>org.codehaus.mojo</groupId>
                <artifactId>exec-maven-plugin</artifactId>
                <version>1.6.0</version>
                <configuration>
                    <!-- The top-level main() in com/kt/AppMain.kt compiles to the class com.kt.AppMainKt. -->
                    <mainClass>com.kt.AppMainKt</mainClass>
                </configuration>
            </plugin>
        </plugins>
    </build>
    <dependencies>
        <dependency>
            <groupId>org.jetbrains.kotlin</groupId>
            <artifactId>kotlin-test-junit5</artifactId>
            <!-- Keep every Kotlin artifact on the same version as the compiler and stdlib. -->
            <version>${kotlin.version}</version>
            <scope>test</scope>
        </dependency>
        <dependency>
            <groupId>org.junit.jupiter</groupId>
            <artifactId>junit-jupiter</artifactId>
            <version>5.10.0</version>
            <scope>test</scope>
        </dependency>
        <dependency>
            <groupId>org.jetbrains.kotlin</groupId>
            <artifactId>kotlin-stdlib-jdk8</artifactId>
            <version>${kotlin.version}</version>
        </dependency>
        <dependency>
            <groupId>io.netty</groupId>
            <artifactId>netty-all</artifactId>
            <version>4.1.119.Final</version>
        </dependency>
        <dependency>
            <groupId>org.jetbrains.kotlin</groupId>
            <artifactId>kotlin-test</artifactId>
            <version>${kotlin.version}</version>
            <scope>test</scope>
        </dependency>
    </dependencies>

</project>

3.代码

com.kt.AppMain.kt

package com.kt

import com.kt.server.NettyServer
/**
 * Application entry point: starts [NettyServer].
 *
 * @param args optional command-line arguments; when present, `args[0]` is the port to listen
 *   on (default 8081). A non-numeric value fails fast with [IllegalArgumentException].
 */
fun main(args: Array<String>) {
    val port = args.firstOrNull()?.let { raw ->
        requireNotNull(raw.toIntOrNull()) { "Invalid port: $raw" }
    } ?: 8081
    // Start the server
    val server = NettyServer(port)
    server.start()
}

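com.kt.server.NettyServer.kt

（注意：下面这段代码实际上是客户端 com.kt.client.NettyClient.kt 的内容（package com.kt.client），与后文“client端”一节的代码重复；NettyServer.kt 本身的源码在此处缺失，需要另行补充。）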

package com.kt.client

import io.netty.bootstrap.Bootstrap
import io.netty.channel.*
import io.netty.channel.nio.NioEventLoopGroup
import io.netty.channel.socket.nio.NioSocketChannel
import io.netty.handler.codec.string.StringDecoder
import io.netty.handler.codec.string.StringEncoder
import io.netty.handler.timeout.IdleStateEvent
import io.netty.handler.timeout.IdleStateHandler
import io.netty.util.concurrent.DefaultThreadFactory
import java.util.*
import java.util.concurrent.TimeUnit
import java.util.concurrent.atomic.AtomicBoolean

/**
 * Netty TCP client for the example server's text protocol.
 *
 * Login is done by sending `LOGIN <username> <password>`. The server answers with
 * `AUTH_SUCCESS: <token>` or `AUTH_FAILED: <reason>`, and [ClientHandler] turns that answer
 * into [isAuthenticated] / [getAuthToken]. An [IdleStateHandler] sends `ping` after 30 s
 * without writes. A failed [connect] / [connectAsync] is retried up to 5 times, 5 s apart,
 * unless [close] has been called.
 *
 * @param host server host name or IP address
 * @param port server TCP port
 * @param username user name for automatic login; used only when [password] is also set
 * @param password password for automatic login; used only when [username] is also set
 * @param clientId identifier shown in log output (a random UUID by default)
 */
class NettyClient(
    private val host: String,
    private val port: Int,
    val username: String? = null,
    val password: String? = null,
    private val clientId: String = UUID.randomUUID().toString()
) {
    private val workerGroup = NioEventLoopGroup(
        0, // 0 lets Netty choose its default thread count
        DefaultThreadFactory("NettyClient")
    )
    private val bootstrap = Bootstrap()

    // Assigned on event-loop / timer threads and read by callers, hence @Volatile.
    @Volatile
    private var channel: Channel? = null
    private val isConnected = AtomicBoolean(false)
    private val isAuthenticated = AtomicBoolean(false)

    // Set once the server has answered the latest LOGIN (either way) or the connection dropped,
    // so authenticate() stops waiting as soon as there is an answer.
    private val authResponded = AtomicBoolean(false)

    @Volatile
    private var authToken: String? = null

    @Volatile
    private var reconnectAttempts = 0
    private val maxReconnectAttempts = 5
    private var reconnectDelay = 5000L // 5 seconds

    // Set by close(); keeps scheduled reconnects away from the shut-down event loop group.
    @Volatile
    private var closed = false

    init {
        configureBootstrap()
    }

    /**
     * Configures [bootstrap]: NIO transport, SO_KEEPALIVE, TCP_NODELAY, a 5 s connect timeout,
     * and a pipeline of string decoder/encoder, a 30 s writer-idle detector and [ClientHandler].
     */
    private fun configureBootstrap() {
        bootstrap.group(workerGroup)
            .channel(NioSocketChannel::class.java)
            .option(ChannelOption.SO_KEEPALIVE, true)
            .option(ChannelOption.TCP_NODELAY, true)
            .option(ChannelOption.CONNECT_TIMEOUT_MILLIS, 5000)
            .handler(object : ChannelInitializer<NioSocketChannel>() {
                override fun initChannel(ch: NioSocketChannel) {
                    ch.pipeline().apply {
                        addLast("decoder", StringDecoder())
                        addLast("encoder", StringEncoder())
                        // Fire WRITER_IDLE after 30 s without a write; ClientHandler then sends "ping".
                        addLast("idleStateHandler", IdleStateHandler(0, 30, 0, TimeUnit.SECONDS))
                        addLast("clientHandler", ClientHandler())
                    }
                }
            })
    }

    /**
     * Connects synchronously and logs in automatically when credentials were supplied.
     *
     * On failure, a delayed reconnect is scheduled (see [handleReconnect]).
     *
     * @return `true` when connected (and, with credentials, authenticated); otherwise `false`
     */
    fun connect(): Boolean {
        if (closed) {
            println("Client is closed, cannot connect")
            return false
        }
        return try {
            val connectedChannel = bootstrap.connect(host, port).sync().channel()
            channel = connectedChannel
            isConnected.set(connectedChannel.isActive)

            if (isConnected.get()) {
                reconnectAttempts = 0 // a successful connect restores the retry budget
                println("Connected to server $host:$port with client ID: $clientId")

                // Log in automatically when credentials were supplied.
                if (username != null && password != null) {
                    return authenticate(username, password)
                }
            }

            isConnected.get()
        } catch (e: Exception) {
            println("Failed to connect to server: ${e.message}")
            handleReconnect()
            false
        }
    }

    /**
     * Sends `LOGIN <username> <password>` and waits up to about 5 s for the server's answer.
     *
     * When called on the channel's event-loop thread (for example from a connect listener),
     * it only sends the request. Blocking there would prevent the reply from ever being read.
     * The result then arrives later via [isAuthenticated].
     *
     * @return `true` if the client is authenticated when the wait ends
     */
    fun authenticate(username: String, password: String): Boolean {
        if (!isConnected.get()) {
            println("Not connected to server")
            return false
        }

        authResponded.set(false)
        if (!send("LOGIN $username $password")) {
            return false
        }

        if (channel?.eventLoop()?.inEventLoop() == true) {
            return isAuthenticated.get()
        }

        // Poll until authenticated, the server answered (e.g. AUTH_FAILED), or ~5 s passed.
        var attempts = 0
        while (!isAuthenticated.get() && !authResponded.get() && attempts < 10) {
            Thread.sleep(500)
            attempts++
        }

        return isAuthenticated.get()
    }

    /**
     * Connects asynchronously.
     *
     * When credentials were supplied, the LOGIN request is sent as soon as the connection is
     * up, without waiting for the answer.
     *
     * @param callback invoked with `true` once connected, or `false` if the attempt failed
     */
    fun connectAsync(callback: ((Boolean) -> Unit)? = null) {
        if (closed) {
            println("Client is closed, cannot connect")
            callback?.invoke(false)
            return
        }
        val connectFuture = bootstrap.connect(host, port)
        connectFuture.addListener { future ->
            if (future.isSuccess) {
                // Take the channel from the ChannelFuture itself: the listener's Future.get()
                // yields Void (null) and cannot be cast to Channel.
                channel = connectFuture.channel()
                isConnected.set(true)
                reconnectAttempts = 0
                println("Connected to server $host:$port with client ID: $clientId")

                // Netty notifies ChannelFuture listeners on the channel's event loop, so
                // authenticate() only sends the LOGIN request here and does not block.
                if (username != null && password != null) {
                    authenticate(username, password)
                }

                callback?.invoke(true)
            } else {
                println("Failed to connect to server: ${future.cause().message}")
                callback?.invoke(false)
                handleReconnect()
            }
        }
    }

    /**
     * Writes [message] to the server if the channel is active.
     *
     * @return `true` if the write was issued, `false` when not connected
     */
    fun send(message: String): Boolean {
        return if (isConnected.get() && channel?.isActive == true) {
            channel?.writeAndFlush(message)
            true
        } else {
            println("Not connected to server, cannot send message")
            false
        }
    }

    /** Closes the current connection and clears the authentication state; the client can reconnect. */
    fun disconnect() {
        try {
            channel?.close()?.sync()
            isConnected.set(false)
            isAuthenticated.set(false)
            authToken = null
            println("Disconnected from server")
        } catch (e: InterruptedException) {
            Thread.currentThread().interrupt()
        }
    }

    /**
     * Closes the connection, stops pending reconnects, and shuts down the event loop group.
     *
     * The client cannot be reused afterwards.
     */
    fun close() {
        closed = true
        try {
            channel?.close()?.sync()
        } catch (e: InterruptedException) {
            Thread.currentThread().interrupt()
        } finally {
            // Uninterruptible wait: the interrupt flag may have just been restored above, and a
            // plain sync() would then throw and skip the rest of the cleanup.
            workerGroup.shutdownGracefully(2, 5, TimeUnit.SECONDS).syncUninterruptibly()
            isConnected.set(false)
            isAuthenticated.set(false)
            authToken = null
            println("Client closed")
        }
    }

    /** Whether the client currently believes it is connected. */
    fun isConnected(): Boolean = isConnected.get()

    /** Whether the server has accepted a LOGIN on the current connection. */
    fun isAuthenticated(): Boolean = isAuthenticated.get()

    /** The identifier used in this client's log output. */
    fun getClientId(): String = clientId

    /** The token from the last `AUTH_SUCCESS`, or `null` when not authenticated. */
    fun getAuthToken(): String? = authToken

    /**
     * Schedules another [connect] after [reconnectDelay] ms, up to [maxReconnectAttempts] times.
     *
     * Does nothing once [close] was called.
     */
    private fun handleReconnect() {
        if (closed) {
            return
        }
        if (reconnectAttempts < maxReconnectAttempts) {
            reconnectAttempts++
            println("Attempting to reconnect... ($reconnectAttempts/$maxReconnectAttempts)")

            // Daemon timer, cancelled after its single run. A pending retry never keeps the JVM
            // alive, and each retry does not leak a timer thread.
            val timer = Timer("NettyClient-reconnect", true)
            timer.schedule(object : TimerTask() {
                override fun run() {
                    timer.cancel()
                    if (!closed) {
                        connect()
                    }
                }
            }, reconnectDelay)
        } else {
            println("Max reconnect attempts reached. Giving up.")
        }
    }

    /**
     * Inbound handler that interprets server replies and sends heartbeats.
     *
     * It updates the enclosing client's authentication and connection state.
     */
    inner class ClientHandler : ChannelInboundHandlerAdapter() {
        override fun channelRead(ctx: ChannelHandlerContext?, msg: Any?) {
            val chunk = msg?.toString() ?: return
            // The pipeline has no line framing, so one read can carry several newline-terminated
            // replies (e.g. "AUTH_SUCCESS: <token>\n" followed by the welcome line). Handle each
            // non-blank line separately so the token is not polluted by the next line.
            // NOTE(review): a reply split across two reads is still not reassembled.
            chunk.lineSequence()
                .map { it.trim() }
                .filter { it.isNotEmpty() }
                .forEach { line -> handleLine(line) }
        }

        /** Dispatches a single server reply line. */
        private fun handleLine(message: String) {
            println("Received from server: $message")

            when {
                message.startsWith("AUTH_SUCCESS:") -> {
                    authToken = message.substringAfter("AUTH_SUCCESS:").trim()
                    isAuthenticated.set(true)
                    authResponded.set(true)
                    println("Authentication successful! Token: $authToken")
                }
                message.startsWith("AUTH_FAILED:") -> {
                    val errorMsg = message.substringAfter("AUTH_FAILED:").trim()
                    isAuthenticated.set(false)
                    authResponded.set(true)
                    println("Authentication failed: $errorMsg")
                }
                message.startsWith("AUTH_REQUIRED:") -> {
                    println("Authentication required: $message")
                }
                message == "pong" || message == "heartbeat" -> {
                    println("Heartbeat received from server")
                }
                else -> {
                    // Any other server message
                    println("Message: $message")
                }
            }
        }

        override fun channelActive(ctx: ChannelHandlerContext?) {
            println("Connected to server: ${ctx?.channel()?.remoteAddress()}")
        }

        override fun channelInactive(ctx: ChannelHandlerContext?) {
            println("Disconnected from server: ${ctx?.channel()?.remoteAddress()}")
            isConnected.set(false)
            isAuthenticated.set(false)
            // Release a thread waiting in authenticate(): no answer will come on this connection.
            authResponded.set(true)
            authToken = null
        }

        override fun userEventTriggered(ctx: ChannelHandlerContext?, evt: Any?) {
            if (evt is IdleStateEvent) {
                when (evt.state()) {
                    io.netty.handler.timeout.IdleState.WRITER_IDLE -> {
                        println("Writer idle, sending ping to server...")
                        ctx?.writeAndFlush("ping")
                    }
                    else -> {
                        // Other idle states are not configured (reader/all timeouts are 0).
                    }
                }
            }
            super.userEventTriggered(ctx, evt)
        }

        override fun exceptionCaught(ctx: ChannelHandlerContext?, cause: Throwable?) {
            cause?.printStackTrace()
            println("Exception in client: ${cause?.message}")
            ctx?.close()
        }
    }
}

/**
 * 客户端使用示例
 */
/**
 * Interactive client example.
 *
 * Connects to 127.0.0.1:8081 and, if credentials were given, waits up to 10 s for the
 * automatic login. It then forwards console input to the server until `exit` is typed,
 * standard input ends, or the connection drops. The client is always closed on the way out.
 */
fun main() {
    // Option 1: supply credentials when creating the client (automatic login)
//    val client = NettyClient("127.0.0.1", 8081, "user1", "password1", "Client-${System.currentTimeMillis()}")

    // Option 2: authenticate manually
    val client = NettyClient("127.0.0.1", 8081, null, null, "Client-${System.currentTimeMillis()}")

    try {
        if (!client.connect()) {
            println("Failed to connect to server")
            return
        }

        // Wait for the automatic login to finish (only when credentials were supplied).
        if (client.username != null && client.password != null) {
            var waitCount = 0
            while (!client.isAuthenticated() && waitCount < 20) {
                Thread.sleep(500)
                waitCount++
            }

            if (!client.isAuthenticated()) {
                println("Authentication timeout")
                return
            }
        }

        // Read commands from the console
        val scanner = Scanner(System.`in`)
        println("Enter messages to send (type 'exit' to quit):")
        println("Commands:")
        println("  exit - Quit the application")
        println("  ping - Send ping to server")
        println("  sessions - Get session count")
        println("  users - Get authenticated user count")
        println("  AUTH <username> <password> - Authenticate manually")
        println("  SEND <sessionId> <message> - Send private message")

        while (client.isConnected()) {
            print("Enter command: ")
            // End of input (Ctrl-D, closed pipe) ends the session cleanly instead of letting
            // nextLine() throw NoSuchElementException.
            if (!scanner.hasNextLine()) {
                println("\nInput closed. Exiting...")
                break
            }
            val input = scanner.nextLine().trim()

            when {
                input.lowercase() == "exit" -> {
                    println("Exiting...")
                    break
                }
                input.startsWith("AUTH", true) -> {
                    // Split on any whitespace run so "AUTH  user  pass" is accepted as well.
                    val parts = input.split(Regex("\\s+"))
                    if (parts.size == 3) {
                        val username = parts[1]
                        val password = parts[2]
                        client.authenticate(username, password)
                    } else {
                        println("Usage: AUTH <username> <password>")
                    }
                }
                input.isNotEmpty() -> {
                    if (client.isAuthenticated() || input.startsWith("LOGIN", true)) {
                        client.send(input)
                        println("Sent: $input")
                    } else {
                        println("Please authenticate first. Use: AUTH <username> <password>")
                    }
                }
            }

            // Short pause so replies can be printed before the next prompt
            Thread.sleep(100)
        }
    } catch (e: Exception) {
        e.printStackTrace()
    } finally {
        client.close()
    }
}

com.kt.server.handler.AuthHandler.kt

package com.kt.server.handler

import io.netty.channel.ChannelHandlerContext
import io.netty.channel.ChannelInboundHandlerAdapter
import java.util.*
import java.util.concurrent.ConcurrentHashMap

/**
 * Authentication details of a logged-in user.
 *
 * @property userId random identifier generated for this login
 * @property username the name the user logged in with
 * @property token session token issued by [AuthManager.authenticate]
 * @property loginTime login timestamp in epoch milliseconds
 */
data class AuthInfo(
    val userId: String,
    val username: String,
    val token: String,
    val loginTime: Long = System.currentTimeMillis()
)

/**
 * Outcome of a login attempt.
 *
 * @property success whether the credentials were accepted
 * @property message human-readable description of the result
 * @property token the issued token, or `null` on failure
 * @property authInfo details of the new login, or `null` on failure
 */
data class AuthResult(
    val success: Boolean,
    val message: String,
    val token: String?,
    val authInfo: AuthInfo?
)


/**
 * Authentication manager: checks credentials, issues tokens, and tracks pending and
 * authenticated sessions.
 *
 * All mutable state lives in concurrent collections because one instance is called by the
 * per-connection [AuthHandler]s, which may run on different event-loop threads.
 * NOTE(review): assumes NettyServer shares a single AuthManager across connections — confirm.
 */
class AuthManager {
    // Authenticated sessions: token -> AuthInfo
    private val authenticatedSessions = ConcurrentHashMap<String, AuthInfo>()

    // Connections that have not logged in yet: sessionId -> connect time (epoch ms)
    private val pendingSessions = ConcurrentHashMap<String, Long>()

    // Demo credentials; a real application would load these from a database or config file.
    private val validCredentials = mapOf(
        "user1" to "password1",
        "user2" to "password2",
        "admin" to "admin123"
    )

    // Issued tokens (a real application would use JWT or similar). A concurrent set, like the
    // maps above: a plain mutableSetOf() is unsafe to modify from several threads at once.
    private val validTokens: MutableSet<String> = ConcurrentHashMap.newKeySet()

    /**
     * Registers a not-yet-authenticated connection.
     *
     * Also drops pending entries older than 30 s.
     */
    fun addPendingSession(sessionId: String) {
        pendingSessions[sessionId] = System.currentTimeMillis()
        cleanupExpiredPendingSessions()
    }

    /**
     * Checks [username]/[password] against the known credentials.
     *
     * On success a new token is issued and becomes valid for [validateToken]. The session is
     * recorded only when the caller invokes [recordAuthenticatedSession].
     */
    fun authenticate(username: String, password: String): AuthResult {
        return if (validCredentials[username] == password) {
            val token = generateToken()
            val authInfo = AuthInfo(UUID.randomUUID().toString(), username, token)
            validTokens.add(token)
            AuthResult(true, "Authentication successful", token, authInfo)
        } else {
            AuthResult(false, "Invalid username or password", null, null)
        }
    }

    /** Whether [token] was issued and has not been removed yet. */
    fun validateToken(token: String): Boolean = token in validTokens

    /** Records a successfully authenticated session under its [token]. */
    fun recordAuthenticatedSession(token: String, authInfo: AuthInfo) {
        authenticatedSessions[token] = authInfo
    }

    /** Forgets the session for [token] and invalidates the token itself. */
    fun removeAuthenticatedSession(token: String) {
        authenticatedSessions.remove(token)
        validTokens.remove(token)
    }

    /** Returns the recorded session for [token], or `null`. */
    fun getAuthInfo(token: String): AuthInfo? = authenticatedSessions[token]

    /** Removes a pending (unauthenticated) connection, e.g. after login or disconnect. */
    fun removePendingSession(sessionId: String) {
        pendingSessions.remove(sessionId)
    }

    /** Drops pending sessions that have waited longer than [PENDING_TIMEOUT_MS] without logging in. */
    private fun cleanupExpiredPendingSessions() {
        val now = System.currentTimeMillis()
        pendingSessions.entries.removeIf { (sessionId, connectTime) ->
            val expired = now - connectTime > PENDING_TIMEOUT_MS
            if (expired) {
                println("Removed expired pending session: $sessionId")
            }
            expired
        }
    }

    /** Builds a token from a dash-less random UUID plus the last 6 digits of the current time. */
    private fun generateToken(): String {
        return UUID.randomUUID().toString().replace("-", "") +
                System.currentTimeMillis().toString().takeLast(6)
    }

    /** Number of recorded authenticated sessions (one per successful, still-recorded login). */
    fun getAuthenticatedUserCount(): Int = authenticatedSessions.size

    private companion object {
        /** Pending sessions older than this many milliseconds are discarded. */
        const val PENDING_TIMEOUT_MS = 30_000L
    }
}



/**
 * Per-connection login gate.
 *
 * Until the client sends a valid `LOGIN <username> <password>`, every message is answered
 * here and not passed on. After a successful login, messages are forwarded to the next
 * handler. The token issued for the login is released again when the connection goes away.
 *
 * Keeps per-connection state in fields, so one instance is needed per channel.
 * NOTE(review): confirm the server pipeline creates a new AuthHandler for each channel.
 */
class AuthHandler(private val authManager: AuthManager) : ChannelInboundHandlerAdapter() {
    private var sessionId: String? = null
    private var isAuthenticated = false

    // Token issued by a successful LOGIN on this connection; released on disconnect.
    private var authToken: String? = null

    override fun channelActive(ctx: ChannelHandlerContext?) {
        // Create a session id and register it as pending authentication.
        val id = UUID.randomUUID().toString()
        sessionId = id
        authManager.addPendingSession(id)
        println("New connection established, session ID: $id. Waiting for authentication...")

        // Prompt the client to log in.
        ctx?.writeAndFlush("Please authenticate: LOGIN <username> <password>\n")

        super.channelActive(ctx)
    }

    override fun channelRead(ctx: ChannelHandlerContext?, msg: Any?) {
        if (!isAuthenticated) {
            handleAuthentication(ctx, msg)
        } else {
            // Authenticated: hand the message to the next handler.
            ctx?.fireChannelRead(msg)
        }
    }

    /** Handles a message received before login: LOGIN, QUIT, or anything else (rejected). */
    private fun handleAuthentication(ctx: ChannelHandlerContext?, msg: Any?) {
        val message = msg?.toString()?.trim() ?: return

        when {
            message.startsWith("LOGIN", ignoreCase = true) -> handleLogin(ctx, message)
            message == "QUIT" -> {
                ctx?.writeAndFlush("Goodbye!\n")
                ctx?.close()
            }
            else -> {
                ctx?.writeAndFlush("AUTH_REQUIRED: Please authenticate first. Usage: LOGIN <username> <password>\n")
            }
        }
    }

    /** Parses a LOGIN command, checks it with [authManager], and replies with the outcome. */
    private fun handleLogin(ctx: ChannelHandlerContext?, message: String) {
        // Split on any whitespace run so "LOGIN  user  pass" is accepted too.
        val parts = message.split(WHITESPACE)
        if (parts.size != 3) {
            ctx?.writeAndFlush("INVALID_COMMAND: Usage: LOGIN <username> <password>\n")
            return
        }
        val username = parts[1]
        val password = parts[2]

        val authResult = authManager.authenticate(username, password)
        val token = authResult.token
        val info = authResult.authInfo
        if (authResult.success && token != null && info != null) {
            isAuthenticated = true
            authToken = token
            authManager.recordAuthenticatedSession(token, info)
            sessionId?.let { id -> authManager.removePendingSession(id) }

            ctx?.writeAndFlush("AUTH_SUCCESS: $token\n")
            ctx?.writeAndFlush("Welcome ${info.username}! You are now authenticated.\n")
            println("User $username authenticated successfully with session $sessionId")
        } else {
            ctx?.writeAndFlush("AUTH_FAILED: ${authResult.message}\n")
            println("Authentication failed for session $sessionId: ${authResult.message}")
        }
    }

    override fun channelInactive(ctx: ChannelHandlerContext?) {
        releaseSession()
        println("Connection closed for session: $sessionId")
        super.channelInactive(ctx)
    }

    override fun exceptionCaught(ctx: ChannelHandlerContext?, cause: Throwable?) {
        cause?.printStackTrace()
        releaseSession()
        ctx?.close()
    }

    /**
     * Releases this connection's state in [authManager]: the pending entry if the client never
     * logged in, and the authenticated session and token if it did. Safe to call more than once.
     */
    private fun releaseSession() {
        sessionId?.let { id ->
            if (!isAuthenticated) {
                authManager.removePendingSession(id)
            }
        }
        authToken?.let { token ->
            authManager.removeAuthenticatedSession(token)
            authToken = null
        }
    }

    private companion object {
        val WHITESPACE = Regex("\\s+")
    }
}

com.kt.server.handler.ServerHandler.kt

package com.kt.server.handler

import io.netty.channel.Channel
import io.netty.channel.ChannelHandlerContext
import io.netty.channel.ChannelInboundHandlerAdapter
import io.netty.handler.timeout.IdleStateEvent
import java.net.InetSocketAddress
import java.util.*
import java.util.concurrent.ConcurrentHashMap


/**
 * One connected client as seen by the server: wraps its [channel] and tracks activity times.
 *
 * @property id session identifier assigned by the server
 * @property channel the client's channel. It must already be connected, because its remote
 *   address is read (and cast to [InetSocketAddress]) when the session is constructed.
 */
class UserSession(val id: String, val channel: Channel) {
    /** The client's remote endpoint, captured once at construction. */
    val remoteAddress: InetSocketAddress = channel.remoteAddress() as InetSocketAddress

    /** When this session object was created, in epoch milliseconds. */
    val connectTime: Long = System.currentTimeMillis()

    /** Last activity time in epoch milliseconds; refreshed by every successful [sendMessage]. */
    var lastActiveTime: Long = System.currentTimeMillis()

    /**
     * Writes [message] to the client if the channel is active, refreshing [lastActiveTime].
     *
     * @return `true` if the write was issued, `false` if the channel is no longer active
     */
    fun sendMessage(message: String): Boolean {
        if (!channel.isActive) return false
        channel.writeAndFlush(message)
        lastActiveTime = System.currentTimeMillis()
        return true
    }

    /** Whether the underlying channel is still active. */
    fun isActive(): Boolean = channel.isActive

    /** Closes the channel unless it is already closed. */
    fun close() {
        if (!channel.isOpen) return
        channel.close()
    }

    override fun toString(): String =
        "UserSession(id='$id', remoteAddress=$remoteAddress, connectTime=$connectTime)"
}


/**
 * 会话管理器
 */
/**
 * Registry of live [UserSession]s keyed by session id, backed by a [ConcurrentHashMap].
 *
 * The class name's spelling is kept so existing callers keep compiling.
 */
class SessionManger {
    private val sessions = ConcurrentHashMap<String, UserSession>()

    /** Registers [session] under its id, replacing any session with the same id. */
    fun addSession(session: UserSession) {
        sessions[session.id] = session
        println("Session added: ${session.id}, total sessions: ${sessions.size}")
    }

    /** Removes the session with [sessionId]; logs only when something was actually removed. */
    fun removeSession(sessionId: String) {
        sessions.remove(sessionId)?.also {
            println("Session removed: $sessionId, total sessions: ${sessions.size}")
        }
    }

    /** Returns the session with [sessionId], or `null`. */
    fun getSession(sessionId: String): UserSession? = sessions[sessionId]

    /** Returns a live view of all registered sessions. */
    fun getAllSessions(): Collection<UserSession> = sessions.values

    /** Number of registered sessions. */
    fun getSessionCount(): Int = sessions.size

    /** Sends [message] to every registered session; inactive sessions are skipped by [UserSession.sendMessage]. */
    fun broadcast(message: String) {
        sessions.values.forEach { session -> session.sendMessage(message) }
    }

    /**
     * Sends [message] to the session with [sessionId].
     *
     * @return `true` only if the session exists and the write was issued. An unknown session
     *   and a session whose channel is no longer active both yield `false`.
     */
    fun sendToSession(sessionId: String, message: String): Boolean =
        sessions[sessionId]?.sendMessage(message) ?: false
}






/**
 * Per-connection application handler.
 *
 * Registers a [UserSession] with [sessionManager] when the channel becomes active. It answers
 * `ping` and `sessions`, echoes every other message, and reacts to idle events.
 *
 * Stores the connection's session in a field, so one instance is needed per channel.
 * NOTE(review): confirm the server pipeline creates a new ServerHandler for each channel.
 */
class ServerHandler(private val sessionManager: SessionManger) : ChannelInboundHandlerAdapter() {

    private var session: UserSession? = null

    /** Handles one decoded message; messages arriving before [channelActive] are ignored. */
    override fun channelRead(ctx: ChannelHandlerContext?, msg: Any?) {
        val current = session ?: return
        current.lastActiveTime = System.currentTimeMillis()
        println("Received from client[${current.id}]: $msg")
        val reply = when (msg) {
            "ping" -> "pong" // heartbeat reply
            "sessions" -> "Current sessions: ${sessionManager.getSessionCount()}"
            else -> "Echo: $msg" // echo everything else back
        }
        current.sendMessage(reply)
    }

    override fun channelActive(ctx: ChannelHandlerContext?) {
        val channel = checkNotNull(ctx) { "ChannelHandlerContext must not be null" }.channel()
        // Create and register the user session for this connection.
        val sessionId = UUID.randomUUID().toString()
        val newSession = UserSession(sessionId, channel)
        session = newSession
        sessionManager.addSession(newSession)

        println("Client connected: ${newSession.remoteAddress}, session ID: $sessionId")
        newSession.sendMessage("Welcome! Your session ID is: $sessionId")

        super.channelActive(ctx)
    }

    override fun channelInactive(ctx: ChannelHandlerContext?) {
        session?.let {
            println("Client disconnected: ${it.remoteAddress}, session ID: ${it.id}")
            sessionManager.removeSession(it.id)
        }
        super.channelInactive(ctx)
    }

    override fun userEventTriggered(ctx: ChannelHandlerContext?, evt: Any?) {
        val current = session
        if (evt is IdleStateEvent && current != null) {
            println("Idle event triggered for session ${current.id}: ${evt.state()}")
            when (evt.state()) {
                // Reader idle: the client has not sent anything for a while.
                io.netty.handler.timeout.IdleState.READER_IDLE -> {
                    println("Client ${current.id} reader idle, sending ping...")
                    current.sendMessage("ping")
                }
                // Writer idle: the server has not sent anything for a while.
                io.netty.handler.timeout.IdleState.WRITER_IDLE -> {
                    println("Client ${current.id} writer idle, sending heartbeat...")
                    current.sendMessage("heartbeat")
                }
                // All idle: no traffic in either direction.
                io.netty.handler.timeout.IdleState.ALL_IDLE -> {
                    println("Client ${current.id} all idle, closing connection...")
                    // Close only after the notice has been written. Closing right away could
                    // discard it while it is still queued in the outbound buffer.
                    current.channel.writeAndFlush("Connection timeout, closing...")
                        .addListener(io.netty.channel.ChannelFutureListener.CLOSE)
                }
            }
        }
        super.userEventTriggered(ctx, evt)
    }

    override fun exceptionCaught(ctx: ChannelHandlerContext?, cause: Throwable?) {
        cause?.printStackTrace()
        session?.let {
            println("Exception in session ${it.id}: ${cause?.message}")
            sessionManager.removeSession(it.id)
        }
        ctx?.close()
    }
}

4.client端

package com.kt.client

import io.netty.bootstrap.Bootstrap
import io.netty.channel.*
import io.netty.channel.nio.NioEventLoopGroup
import io.netty.channel.socket.nio.NioSocketChannel
import io.netty.handler.codec.string.StringDecoder
import io.netty.handler.codec.string.StringEncoder
import io.netty.handler.timeout.IdleStateEvent
import io.netty.handler.timeout.IdleStateHandler
import io.netty.util.concurrent.DefaultThreadFactory
import java.util.*
import java.util.concurrent.TimeUnit
import java.util.concurrent.atomic.AtomicBoolean

/**
 * Netty-based TCP client that exchanges plain-string messages with the server.
 *
 * Offers a blocking [connect] and a callback-based [connectAsync], LOGIN-based
 * authentication, a writer-idle "ping" heartbeat, and a bounded number of delayed
 * reconnect attempts after a failed connect.
 *
 * @param host server host name or address
 * @param port server TCP port
 * @param username optional user name; together with [password] triggers an automatic LOGIN after connecting
 * @param password optional password used for the automatic LOGIN
 * @param clientId identifier printed in log messages; a random UUID by default
 */
class NettyClient(
    private val host: String,
    private val port: Int,
    val username: String? = null,
    val password: String? = null,
    private val clientId: String = UUID.randomUUID().toString()
) {
    private val workerGroup = NioEventLoopGroup(
        0, // 0 = Netty's default thread count
        DefaultThreadFactory("NettyClient")
    )
    private val bootstrap = Bootstrap()

    // channel / authToken / authLatch / reconnectAttempts are written on the Netty event loop
    // (connect listeners, ClientHandler) and read from the caller's thread, hence @Volatile.
    @Volatile
    private var channel: Channel? = null
    private val isConnected = AtomicBoolean(false)
    private val isAuthenticated = AtomicBoolean(false)

    @Volatile
    private var authToken: String? = null

    // Released by ClientHandler on AUTH_SUCCESS / AUTH_FAILED or when the connection drops,
    // so authenticate() returns as soon as the outcome is known instead of polling.
    @Volatile
    private var authLatch: CountDownLatch? = null

    @Volatile
    private var reconnectAttempts = 0
    private val maxReconnectAttempts = 5
    private val reconnectDelay = 5000L // 5 seconds between reconnect attempts
    private val authTimeoutMillis = 5000L // max time authenticate() waits for the server's reply

    init {
        configureBootstrap()
    }

    /**
     * Configures the bootstrap: NIO transport, TCP keep-alive, Nagle disabled, a 5 s connect
     * timeout, and a pipeline of String codecs, a 30 s writer-idle detector and [ClientHandler].
     */
    private fun configureBootstrap() {
        bootstrap.group(workerGroup)
            .channel(NioSocketChannel::class.java)
            .option(ChannelOption.SO_KEEPALIVE, true)
            .option(ChannelOption.TCP_NODELAY, true)
            .option(ChannelOption.CONNECT_TIMEOUT_MILLIS, 5000)
            .handler(object : ChannelInitializer<NioSocketChannel>() {
                override fun initChannel(ch: NioSocketChannel) {
                    ch.pipeline().apply {
                        // NOTE(review): StringDecoder/StringEncoder use the platform default charset and no
                        // frame delimiter, so TCP may split or merge messages. Moving to UTF-8 and a
                        // LineBasedFrameDecoder must be done together with the server — TODO confirm.
                        addLast("decoder", StringDecoder())
                        addLast("encoder", StringEncoder())
                        // Fires WRITER_IDLE after 30 s without an outbound write (heartbeat trigger)
                        addLast("idleStateHandler", IdleStateHandler(0, 30, 0, TimeUnit.SECONDS))
                        addLast("clientHandler", ClientHandler())
                    }
                }
            })
    }

    /**
     * Connects to the server, blocking until the TCP connection is established.
     *
     * If [username] and [password] were supplied, a LOGIN is performed afterwards and the
     * result of [authenticate] is returned. On a connection failure a delayed reconnect is
     * scheduled. Must not be called from a Netty event-loop thread (it blocks on `sync()`).
     *
     * @return `true` if connected (and, when credentials were supplied, authenticated).
     */
    fun connect(): Boolean {
        return try {
            val connected = bootstrap.connect(host, port).sync().channel()
            channel = connected
            isConnected.set(connected.isActive)

            if (isConnected.get()) {
                reconnectAttempts = 0 // successful connect: reset the retry counter
                println("Connected to server $host:$port with client ID: $clientId")

                // Credentials supplied at construction time: log in automatically.
                if (username != null && password != null) {
                    return authenticate(username, password)
                }
            }

            isConnected.get()
        } catch (e: InterruptedException) {
            // Interrupted while waiting (connect or authentication): restore the flag and do
            // not schedule a reconnect — the connection may well be up.
            Thread.currentThread().interrupt()
            false
        } catch (e: Exception) {
            println("Failed to connect to server: ${e.message}")
            handleReconnect()
            false
        }
    }

    /**
     * Sends `LOGIN <username> <password>` and waits (up to 5 s) for AUTH_SUCCESS / AUTH_FAILED.
     *
     * When invoked on the channel's own event loop (e.g. from [connectAsync]) it only sends the
     * command and returns the current state without waiting, because blocking that thread would
     * prevent the reply from ever being read.
     *
     * @return `true` if the client is authenticated when this returns.
     * @throws InterruptedException if the calling thread is interrupted while waiting.
     */
    fun authenticate(username: String, password: String): Boolean {
        if (!isConnected.get()) {
            println("Not connected to server")
            return false
        }

        // Install the latch before sending so a fast reply cannot be missed.
        val latch = CountDownLatch(1)
        authLatch = latch

        val authCommand = "LOGIN $username $password"
        if (!send(authCommand)) {
            return false
        }

        if (channel?.eventLoop()?.inEventLoop() == true) {
            return isAuthenticated.get()
        }

        latch.await(authTimeoutMillis, TimeUnit.MILLISECONDS)
        return isAuthenticated.get()
    }

    /**
     * Connects without blocking the caller.
     *
     * The result is reported through [callback] (invoked on the Netty event loop). If
     * credentials were supplied, LOGIN is sent without waiting; check [isAuthenticated] later.
     * A failed attempt schedules a delayed reconnect.
     */
    fun connectAsync(callback: ((Boolean) -> Unit)? = null) {
        bootstrap.connect(host, port).addListener(ChannelFutureListener { future ->
            if (future.isSuccess) {
                // A ChannelFuture is a Future<Void>: get() yields null, so the connected channel
                // must be taken from channel().
                channel = future.channel()
                isConnected.set(true)
                reconnectAttempts = 0
                println("Connected to server $host:$port with client ID: $clientId")

                // Runs on the event loop, so authenticate() only sends LOGIN here.
                if (username != null && password != null) {
                    authenticate(username, password)
                }

                callback?.invoke(true)
            } else {
                println("Failed to connect to server: ${future.cause().message}")
                callback?.invoke(false)
                handleReconnect()
            }
        })
    }

    /**
     * Writes [message] to the server if connected.
     *
     * @return `true` if the write was issued (not necessarily delivered), `false` when not connected.
     */
    fun send(message: String): Boolean {
        val current = channel // read once so the check and the write use the same channel
        return if (isConnected.get() && current != null && current.isActive) {
            current.writeAndFlush(message)
            true
        } else {
            println("Not connected to server, cannot send message")
            false
        }
    }

    /**
     * Closes the current connection and clears the connection/authentication state.
     * The event loop group stays alive, so the client can connect again.
     */
    fun disconnect() {
        try {
            channel?.close()?.sync()
            isConnected.set(false)
            isAuthenticated.set(false)
            authToken = null
            println("Disconnected from server")
        } catch (e: InterruptedException) {
            Thread.currentThread().interrupt()
        }
    }

    /**
     * Closes the connection and shuts down the event loop group (which also discards any
     * pending reconnect attempt). The client cannot be reused afterwards.
     */
    fun close() {
        try {
            channel?.close()?.sync()
        } catch (e: InterruptedException) {
            Thread.currentThread().interrupt()
        } finally {
            // awaitUninterruptibly(): a plain sync() would throw at once if the interrupt flag was
            // just restored above, skipping the state reset below.
            workerGroup.shutdownGracefully(2, 5, TimeUnit.SECONDS).awaitUninterruptibly()
            isConnected.set(false)
            isAuthenticated.set(false)
            authToken = null
            println("Client closed")
        }
    }

    /** Returns whether the client currently holds an active connection. */
    fun isConnected(): Boolean = isConnected.get()

    /** Returns whether the server has confirmed authentication (AUTH_SUCCESS). */
    fun isAuthenticated(): Boolean = isAuthenticated.get()

    /** Returns the client identifier used in log messages. */
    fun getClientId(): String = clientId

    /** Returns the token received with AUTH_SUCCESS, or `null` when not authenticated. */
    fun getAuthToken(): String? = authToken

    /**
     * Schedules a delayed [connectAsync] attempt, up to [maxReconnectAttempts] times.
     *
     * Attempts are scheduled on the client's own event loop group instead of a new
     * java.util.Timer per attempt (each Timer owns a non-daemon thread that was never
     * cancelled). connectAsync is used because a blocking connect() is not allowed on an event
     * loop. Nothing is scheduled once close() has started shutting the group down.
     */
    private fun handleReconnect() {
        if (workerGroup.isShuttingDown) {
            return
        }
        if (reconnectAttempts < maxReconnectAttempts) {
            reconnectAttempts++
            println("Attempting to reconnect... ($reconnectAttempts/$maxReconnectAttempts)")

            try {
                workerGroup.schedule(Runnable { connectAsync() }, reconnectDelay, TimeUnit.MILLISECONDS)
            } catch (e: RejectedExecutionException) {
                // The group was shut down between the check above and scheduling: give up quietly.
            }
        } else {
            println("Max reconnect attempts reached. Giving up.")
        }
    }

    /**
     * Inbound handler: interprets authentication replies and heartbeats, tracks connection
     * state, and sends "ping" when the connection has been write-idle.
     */
    inner class ClientHandler : ChannelInboundHandlerAdapter() {
        override fun channelRead(ctx: ChannelHandlerContext?, msg: Any?) {
            val message = msg?.toString() ?: return
            println("Received from server: $message")

            when {
                message.startsWith("AUTH_SUCCESS:") -> {
                    authToken = message.substringAfter("AUTH_SUCCESS:").trim()
                    isAuthenticated.set(true)
                    println("Authentication successful! Token: $authToken")
                    authLatch?.countDown() // wake a waiting authenticate()
                }
                message.startsWith("AUTH_FAILED:") -> {
                    val errorMsg = message.substringAfter("AUTH_FAILED:").trim()
                    isAuthenticated.set(false)
                    println("Authentication failed: $errorMsg")
                    authLatch?.countDown() // wake a waiting authenticate()
                }
                message.startsWith("AUTH_REQUIRED:") -> {
                    println("Authentication required: $message")
                }
                message == "pong" || message == "heartbeat" -> {
                    println("Heartbeat received from server")
                }
                else -> {
                    // Any other server message
                    println("Message: $message")
                }
            }
        }

        override fun channelActive(ctx: ChannelHandlerContext?) {
            println("Connected to server: ${ctx?.channel()?.remoteAddress()}")
        }

        override fun channelInactive(ctx: ChannelHandlerContext?) {
            println("Disconnected from server: ${ctx?.channel()?.remoteAddress()}")
            isConnected.set(false)
            isAuthenticated.set(false)
            authToken = null
            authLatch?.countDown() // no reply will come on a closed connection
        }

        override fun userEventTriggered(ctx: ChannelHandlerContext?, evt: Any?) {
            // Only writer-idle matters here (the IdleStateHandler is configured for writes only).
            if (evt is IdleStateEvent && evt.state() == io.netty.handler.timeout.IdleState.WRITER_IDLE) {
                println("Writer idle, sending ping to server...")
                ctx?.writeAndFlush("ping")
            }
            super.userEventTriggered(ctx, evt)
        }

        override fun exceptionCaught(ctx: ChannelHandlerContext?, cause: Throwable?) {
            cause?.printStackTrace()
            println("Exception in client: ${cause?.message}")
            ctx?.close()
        }
    }
}

/**
 * Interactive console client example.
 *
 * Connects to 127.0.0.1:8081, waits for automatic authentication when credentials were
 * supplied, then forwards console lines to the server until the user types `exit`, standard
 * input is exhausted, or the connection drops. The client is always closed on the way out.
 */
fun main() {
    // Option 1: supply credentials up front so connect() logs in automatically
//    val client = NettyClient("127.0.0.1", 8081, "user1", "password1", "Client-${System.currentTimeMillis()}")

    // Option 2: authenticate manually with the AUTH command
    val client = NettyClient("127.0.0.1", 8081, null, null, "Client-${System.currentTimeMillis()}")

    try {
        if (!client.connect()) {
            println("Failed to connect to server")
            return
        }

        // With automatic authentication, make sure it actually completed before reading input.
        if (client.username != null && client.password != null) {
            var waitCount = 0
            while (!client.isAuthenticated() && waitCount < 20) {
                Thread.sleep(500)
                waitCount++
            }

            if (!client.isAuthenticated()) {
                println("Authentication timeout")
                return
            }
        }

        val scanner = Scanner(System.`in`)
        println("Enter messages to send (type 'exit' to quit):")
        println("Commands:")
        println("  exit - Quit the application")
        println("  ping - Send ping to server")
        println("  sessions - Get session count")
        println("  users - Get authenticated user count")
        println("  AUTH <username> <password> - Authenticate manually")
        println("  SEND <sessionId> <message> - Send private message")

        // Tolerates repeated spaces/tabs between arguments.
        val whitespace = Regex("\\s+")

        while (client.isConnected()) {
            print("Enter command: ")
            // hasNextLine() is false once stdin is closed (Ctrl-D, end of piped input);
            // nextLine() would throw NoSuchElementException instead.
            if (!scanner.hasNextLine()) {
                println()
                break
            }
            val input = scanner.nextLine().trim()
            val tokens = input.split(whitespace)
            // Compare whole first tokens so e.g. "author ..." is not mistaken for AUTH.
            val command = tokens.first()

            when {
                input.isEmpty() -> Unit
                input.equals("exit", ignoreCase = true) -> {
                    println("Exiting...")
                    break
                }
                command.equals("AUTH", ignoreCase = true) -> {
                    if (tokens.size == 3) {
                        client.authenticate(tokens[1], tokens[2])
                    } else {
                        println("Usage: AUTH <username> <password>")
                    }
                }
                client.isAuthenticated() || command.equals("LOGIN", ignoreCase = true) -> {
                    // send() already reports failures; only confirm writes that were issued.
                    if (client.send(input)) {
                        println("Sent: $input")
                    }
                }
                else -> println("Please authenticate first. Use: AUTH <username> <password>")
            }

            // Short pause so server replies are printed before the next prompt
            Thread.sleep(100)
        }
    } catch (e: Exception) {
        e.printStackTrace()
    } finally {
        client.close()
    }
}

posted @ 2025-09-10 17:33  一个小笨蛋  阅读(13)  评论(0)    收藏  举报