简单多线程服务器实现
闲来无事,本来是在学习nio框架的,突然发现自己对最原始的多线程服务器都不是很了解,于是自己写了个简单的例子。
1 package testmutithreadserver.old;
2
3 import java.io.IOException;
4 import java.net.ServerSocket;
5 import java.net.Socket;
6
7 import testmutithreadserver.old.threadpool.ThreadPool;
8
9 /**
10 * 简单阻塞式多线程服务器(线程池处理)
11 *
12 * @author zhangjun
13 *
14 */
15 public class Server {
16
17 private int port;
18
19 private ServerSocket serverSocket;
20
21 private ThreadPool threadPool;
22
23 private PortListenThread listener;
24
25 public Server(int port) {
26 this.port = port;
27 threadPool = new ThreadPool();
28 }
29
30 public void start() {
31 try {
32 serverSocket = new ServerSocket(port);
33 listener = new PortListenThread();
34 listener.start();
35 } catch (IOException e) {
36 e.printStackTrace();
37 }
38 }
39
40 public void shutdown() {
41 threadPool.shutdown();
42 listener.finish();
43 }
44
45 private class PortListenThread extends Thread {
46
47 private Boolean finish = false;
48
49 @Override
50 public void run() {
51 while (!finish) {
52 try {
53 final Socket socket = serverSocket.accept();
54 threadPool.execute(new Runnable() {
55
56 @Override
57 public void run() {
58 new TestMessage(socket).execute();
59 }
60 });
61 } catch (IOException e) {
62 e.printStackTrace();
63 }
64
65 }
66 }
67
68 public void finish() {
69 finish = true;
70 }
71
72 }
73
74 public static void main(String[] args) {
75 int port = 8888;
76 System.out.println("server is listening on port: " + port);
77 new Server(port).start();
78 }
79
80 }
81
2
3 import java.io.IOException;
4 import java.net.ServerSocket;
5 import java.net.Socket;
6
7 import testmutithreadserver.old.threadpool.ThreadPool;
8
9 /**
10 * 简单阻塞式多线程服务器(线程池处理)
11 *
12 * @author zhangjun
13 *
14 */
15 public class Server {
16
17 private int port;
18
19 private ServerSocket serverSocket;
20
21 private ThreadPool threadPool;
22
23 private PortListenThread listener;
24
25 public Server(int port) {
26 this.port = port;
27 threadPool = new ThreadPool();
28 }
29
30 public void start() {
31 try {
32 serverSocket = new ServerSocket(port);
33 listener = new PortListenThread();
34 listener.start();
35 } catch (IOException e) {
36 e.printStackTrace();
37 }
38 }
39
40 public void shutdown() {
41 threadPool.shutdown();
42 listener.finish();
43 }
44
45 private class PortListenThread extends Thread {
46
47 private Boolean finish = false;
48
49 @Override
50 public void run() {
51 while (!finish) {
52 try {
53 final Socket socket = serverSocket.accept();
54 threadPool.execute(new Runnable() {
55
56 @Override
57 public void run() {
58 new TestMessage(socket).execute();
59 }
60 });
61 } catch (IOException e) {
62 e.printStackTrace();
63 }
64
65 }
66 }
67
68 public void finish() {
69 finish = true;
70 }
71
72 }
73
74 public static void main(String[] args) {
75 int port = 8888;
76 System.out.println("server is listening on port: " + port);
77 new Server(port).start();
78 }
79
80 }
81
这个Server调用的是自己实现的一个基于任务队列的简单线程池:
1 package testmutithreadserver.old.threadpool;
2
3 import java.util.LinkedList;
4
/**
 * Simple thread pool backed by a FIFO work queue (synchronized work-queue pool).
 *
 * <p>Worker threads are members of this {@link ThreadGroup}. {@link #shutdown()}
 * rejects new tasks, lets the workers drain the queue, and waits for them to exit.
 *
 * @author zhangjun
 */
public class ThreadPool extends ThreadGroup {
    private final static String THREADPOOL = "thread pool";
    private final static String WORKTHREAD = "work thread ";
    private final static int DEFAULTSIZE = Runtime.getRuntime()
            .availableProcessors() + 1;
    /** Pending tasks, oldest first; guarded by {@code this}. */
    private LinkedList<Runnable> taskQueue;
    /** Set once shutdown begins; guarded by {@code this}. */
    private boolean isPoolClose = false;

    /** Creates a pool with one worker per available processor, plus one. */
    public ThreadPool() {
        this(DEFAULTSIZE);
    }

    /**
     * Creates a pool and starts its workers.
     *
     * @param size number of worker threads
     * @throws IllegalArgumentException if {@code size} is not positive (such a
     *         pool would accept tasks but never run them)
     */
    public ThreadPool(int size) {
        super(THREADPOOL);
        if (size <= 0) {
            throw new IllegalArgumentException("pool size must be positive: " + size);
        }
        setDaemon(true);
        taskQueue = new LinkedList<Runnable>();
        initWorkThread(size);
    }

    /** Starts {@code size} worker threads. */
    private void initWorkThread(int size) {
        // No start-up sleep is needed: a task queued before a worker reaches
        // getTask() is simply found in the queue when the worker gets there.
        for (int i = 0; i < size; i++) {
            new WorkThread(WORKTHREAD + i).start();
        }
    }

    /**
     * Queues a task for execution; {@code null} tasks are ignored.
     *
     * @throws IllegalStateException if the pool has been shut down
     */
    public synchronized void execute(Runnable task) {
        if (isPoolClose) {
            throw new IllegalStateException();
        }
        if (task != null) {
            taskQueue.add(task);
            notify();
        }
    }

    /**
     * Blocks until a task is available.
     *
     * @return the next task, or {@code null} once the pool is closed and drained
     */
    private synchronized Runnable getTask() throws InterruptedException {
        // Must be a loop: after a spurious wakeup, or a notify whose task was
        // already taken by another worker, the queue can still be empty, and
        // returning null here would make this worker exit for good.
        while (taskQueue.isEmpty()) {
            if (isPoolClose) {
                return null;
            }
            wait();
        }
        return taskQueue.removeFirst();
    }

    /** Rejects new tasks, waits for workers to drain the queue and exit, then cleans up. */
    public void shutdown() {
        waitFinish();
        synchronized (this) {
            isPoolClose = true;
            interrupt();
            taskQueue.clear();
        }
    }

    /** Marks the pool closed, wakes idle workers and joins every worker thread. */
    private void waitFinish() {
        synchronized (this) {
            isPoolClose = true;
            notifyAll();
        }
        Thread[] threads = new Thread[activeCount()];
        // enumerate() may fill fewer slots than activeCount() estimated.
        int count = enumerate(threads);
        Thread current = Thread.currentThread();
        try {
            for (int i = 0; i < count; i++) {
                // Never join ourselves: shutdown() called from a task would hang.
                if (threads[i] != current) {
                    threads[i].join();
                }
            }
        } catch (InterruptedException e) {
            // Stop waiting, but keep the interrupt visible to the caller.
            Thread.currentThread().interrupt();
        }
    }

    /** Worker: repeatedly takes a task from the queue and runs it. */
    private class WorkThread extends Thread {

        public WorkThread(String name) {
            super(ThreadPool.this, name);
        }

        @Override
        public void run() {
            while (!isInterrupted()) {
                Runnable task;
                try {
                    task = getTask();
                } catch (InterruptedException e) {
                    // Interrupted while idle: exit, same as receiving no task.
                    return;
                }
                if (task == null) {
                    // Pool closed and queue drained.
                    return;
                }
                try {
                    task.run();
                } catch (Throwable e) {
                    // A failing task must not kill the worker.
                    e.printStackTrace();
                }
            }
        }

    }
}
116
2
3 import java.util.LinkedList;
4
/**
 * Simple thread pool backed by a FIFO work queue (synchronized work-queue pool).
 *
 * <p>Worker threads are members of this {@link ThreadGroup}. {@link #shutdown()}
 * rejects new tasks, lets the workers drain the queue, and waits for them to exit.
 *
 * @author zhangjun
 */
public class ThreadPool extends ThreadGroup {
    private final static String THREADPOOL = "thread pool";
    private final static String WORKTHREAD = "work thread ";
    private final static int DEFAULTSIZE = Runtime.getRuntime()
            .availableProcessors() + 1;
    /** Pending tasks, oldest first; guarded by {@code this}. */
    private LinkedList<Runnable> taskQueue;
    /** Set once shutdown begins; guarded by {@code this}. */
    private boolean isPoolClose = false;

    /** Creates a pool with one worker per available processor, plus one. */
    public ThreadPool() {
        this(DEFAULTSIZE);
    }

    /**
     * Creates a pool and starts its workers.
     *
     * @param size number of worker threads
     * @throws IllegalArgumentException if {@code size} is not positive (such a
     *         pool would accept tasks but never run them)
     */
    public ThreadPool(int size) {
        super(THREADPOOL);
        if (size <= 0) {
            throw new IllegalArgumentException("pool size must be positive: " + size);
        }
        setDaemon(true);
        taskQueue = new LinkedList<Runnable>();
        initWorkThread(size);
    }

    /** Starts {@code size} worker threads. */
    private void initWorkThread(int size) {
        // No start-up sleep is needed: a task queued before a worker reaches
        // getTask() is simply found in the queue when the worker gets there.
        for (int i = 0; i < size; i++) {
            new WorkThread(WORKTHREAD + i).start();
        }
    }

    /**
     * Queues a task for execution; {@code null} tasks are ignored.
     *
     * @throws IllegalStateException if the pool has been shut down
     */
    public synchronized void execute(Runnable task) {
        if (isPoolClose) {
            throw new IllegalStateException();
        }
        if (task != null) {
            taskQueue.add(task);
            notify();
        }
    }

    /**
     * Blocks until a task is available.
     *
     * @return the next task, or {@code null} once the pool is closed and drained
     */
    private synchronized Runnable getTask() throws InterruptedException {
        // Must be a loop: after a spurious wakeup, or a notify whose task was
        // already taken by another worker, the queue can still be empty, and
        // returning null here would make this worker exit for good.
        while (taskQueue.isEmpty()) {
            if (isPoolClose) {
                return null;
            }
            wait();
        }
        return taskQueue.removeFirst();
    }

    /** Rejects new tasks, waits for workers to drain the queue and exit, then cleans up. */
    public void shutdown() {
        waitFinish();
        synchronized (this) {
            isPoolClose = true;
            interrupt();
            taskQueue.clear();
        }
    }

    /** Marks the pool closed, wakes idle workers and joins every worker thread. */
    private void waitFinish() {
        synchronized (this) {
            isPoolClose = true;
            notifyAll();
        }
        Thread[] threads = new Thread[activeCount()];
        // enumerate() may fill fewer slots than activeCount() estimated.
        int count = enumerate(threads);
        Thread current = Thread.currentThread();
        try {
            for (int i = 0; i < count; i++) {
                // Never join ourselves: shutdown() called from a task would hang.
                if (threads[i] != current) {
                    threads[i].join();
                }
            }
        } catch (InterruptedException e) {
            // Stop waiting, but keep the interrupt visible to the caller.
            Thread.currentThread().interrupt();
        }
    }

    /** Worker: repeatedly takes a task from the queue and runs it. */
    private class WorkThread extends Thread {

        public WorkThread(String name) {
            super(ThreadPool.this, name);
        }

        @Override
        public void run() {
            while (!isInterrupted()) {
                Runnable task;
                try {
                    task = getTask();
                } catch (InterruptedException e) {
                    // Interrupted while idle: exit, same as receiving no task.
                    return;
                }
                if (task == null) {
                    // Pool closed and queue drained.
                    return;
                }
                try {
                    task.run();
                } catch (Throwable e) {
                    // A failing task must not kill the worker.
                    e.printStackTrace();
                }
            }
        }

    }
}
116
当然也可以直接使用java.util.concurrent包提供的线程池(ExecutorService),代码几乎不用改变:
1 package testmutithreadserver.concurrent;
2
3 import java.io.IOException;
4 import java.net.ServerSocket;
5 import java.net.Socket;
6 import java.util.concurrent.ExecutorService;
7 import java.util.concurrent.Executors;
8
9 import testmutithreadserver.old.TestMessage;
10
11 /**
12 * 简单阻塞式多线程服务器(线程池处理)
13 *
14 * @author zhangjun
15 *
16 */
17 public class Server {
18
19 private int port;
20
21 private ServerSocket serverSocket;
22
23 private ExecutorService threadPool;
24
25 private PortListenThread listener;
26
27 public Server(int port) {
28 this.port = port;
29 threadPool = Executors.newFixedThreadPool(3);
30 }
31
32 public void start() {
33 try {
34 serverSocket = new ServerSocket(port);
35 listener = new PortListenThread();
36 listener.start();
37 } catch (IOException e) {
38 e.printStackTrace();
39 }
40 }
41
42 public void shutdown() {
43 threadPool.shutdown();
44 listener.finish();
45 }
46
47 private class PortListenThread extends Thread {
48
49 private Boolean finish = false;
50
51 @Override
52 public void run() {
53 while (!finish) {
54 try {
55 final Socket socket = serverSocket.accept();
56 threadPool.execute(new Runnable() {
57
58 @Override
59 public void run() {
60 new TestMessage(socket).execute();
61 }
62 });
63 } catch (IOException e) {
64 e.printStackTrace();
65 }
66
67 }
68 }
69
70 public void finish() {
71 finish = true;
72 }
73
74 }
75
76 public static void main(String[] args) {
77 int port = 8888;
78 System.out.println("server is listening on port: " + port);
79 new Server(port).start();
80 }
81 }
82
2
3 import java.io.IOException;
4 import java.net.ServerSocket;
5 import java.net.Socket;
6 import java.util.concurrent.ExecutorService;
7 import java.util.concurrent.Executors;
8
9 import testmutithreadserver.old.TestMessage;
10
11 /**
12 * 简单阻塞式多线程服务器(线程池处理)
13 *
14 * @author zhangjun
15 *
16 */
17 public class Server {
18
19 private int port;
20
21 private ServerSocket serverSocket;
22
23 private ExecutorService threadPool;
24
25 private PortListenThread listener;
26
27 public Server(int port) {
28 this.port = port;
29 threadPool = Executors.newFixedThreadPool(3);
30 }
31
32 public void start() {
33 try {
34 serverSocket = new ServerSocket(port);
35 listener = new PortListenThread();
36 listener.start();
37 } catch (IOException e) {
38 e.printStackTrace();
39 }
40 }
41
42 public void shutdown() {
43 threadPool.shutdown();
44 listener.finish();
45 }
46
47 private class PortListenThread extends Thread {
48
49 private Boolean finish = false;
50
51 @Override
52 public void run() {
53 while (!finish) {
54 try {
55 final Socket socket = serverSocket.accept();
56 threadPool.execute(new Runnable() {
57
58 @Override
59 public void run() {
60 new TestMessage(socket).execute();
61 }
62 });
63 } catch (IOException e) {
64 e.printStackTrace();
65 }
66
67 }
68 }
69
70 public void finish() {
71 finish = true;
72 }
73
74 }
75
76 public static void main(String[] args) {
77 int port = 8888;
78 System.out.println("server is listening on port: " + port);
79 new Server(port).start();
80 }
81 }
82
里边我构造了一个Message接口:
1 package testmutithreadserver.old;
2
/**
 * General-purpose message contract.
 *
 * <p>A message is a single action; everything it does happens when
 * {@link #execute()} is invoked.
 *
 * @author zhangjun
 */
public interface Message {

    /**
     * Performs this message's action.
     */
    void execute();
}
14
2
/**
 * General-purpose message contract.
 *
 * <p>A message is a single action; everything it does happens when
 * {@link #execute()} is invoked.
 *
 * @author zhangjun
 */
public interface Message {

    /**
     * Performs this message's action.
     */
    void execute();
}
14
以及实现了一个测试消息类:
1 package testmutithreadserver.old;
2
3 import java.io.BufferedReader;
4 import java.io.IOException;
5 import java.io.InputStreamReader;
6 import java.io.PrintWriter;
7 import java.net.Socket;
8
9 /**
10 * 测试消息
11 *
12 * @author zhangjun
13 *
14 */
15 public class TestMessage implements Message {
16
17 private Socket socket;
18
19 public TestMessage(Socket socket) {
20 this.socket = socket;
21 }
22
23 @Override
24 public void execute() {
25 try {
26 BufferedReader in = new BufferedReader(new InputStreamReader(socket
27 .getInputStream()));
28 PrintWriter out = new PrintWriter(socket.getOutputStream(), true);
29 String s;
30 while ((s = in.readLine()) != null) {
31 System.out.println("received message:" + s);
32 if (s.equals("quit")) {
33 break;
34 }
35 out.println("hello " + s);
36 }
37 } catch (IOException e) {
38 e.printStackTrace();
39 } finally {
40 try {
41 if (!socket.isClosed()) {
42 socket.close();
43 }
44 } catch (IOException e) {
45 }
46 }
47 }
48
49 }
50
2
3 import java.io.BufferedReader;
4 import java.io.IOException;
5 import java.io.InputStreamReader;
6 import java.io.PrintWriter;
7 import java.net.Socket;
8
9 /**
10 * 测试消息
11 *
12 * @author zhangjun
13 *
14 */
15 public class TestMessage implements Message {
16
17 private Socket socket;
18
19 public TestMessage(Socket socket) {
20 this.socket = socket;
21 }
22
23 @Override
24 public void execute() {
25 try {
26 BufferedReader in = new BufferedReader(new InputStreamReader(socket
27 .getInputStream()));
28 PrintWriter out = new PrintWriter(socket.getOutputStream(), true);
29 String s;
30 while ((s = in.readLine()) != null) {
31 System.out.println("received message:" + s);
32 if (s.equals("quit")) {
33 break;
34 }
35 out.println("hello " + s);
36 }
37 } catch (IOException e) {
38 e.printStackTrace();
39 } finally {
40 try {
41 if (!socket.isClosed()) {
42 socket.close();
43 }
44 } catch (IOException e) {
45 }
46 }
47 }
48
49 }
50
代码很简单,就不用多解释什么了。下一步打算再用nio自己写个非阻塞的服务器。