Python网络编程之基于socketserver实现并发

socketserver是一个支持IO多路复用、多线程及多进程的模块。

基于tcp的套接字,关键就是两个循环,一个链接循环,一个通信循环。socketserver模块中分两大类。

server类(解决链接问题):

链接相关类:BaseServer,TCPServer,UDPServer,UnixStreamServer,UnixDatagramServer

基于多线程实现并发相关类:ThreadingMixIn,ThreadingTCPServer,ThreadingUDPServer

基于多进程实现并发相关类:ForkingMixIn,ForkingTCPServer,ForkingUDPServer

request类(解决通信问题):

通信相关类:BaseRequestHandler,StreamRequestHandler,DatagramRequestHandler

它们之间的继承关系如下:

源码分析总结:

一般在socketserver服务店中都会这样一句:server = socketserver.ThreadingTCPServer(settings.IP_PORT, MyServer),

ThreadingTCPServer这个类是一个支持多线程和TCP协议的socketserver,它的继承关系是这样的:

class ThreadingTCPServer(ThreadingMixIn, TCPServer): pass

右边的TCPServer实际上是它主要的功能父类,而左边的ThreadingMixIn则是实现了多线程的类,它自己本身则没有任何代码。
MixIn在python的类命名中,很常见,一般被称为“混入”,戏称“乱入”,通常为了某种重要功能被子类继承。

class ThreadingMixIn:
    """Mix-in class to handle each request in a new thread."""

    # Decides how threads will act upon termination of the
    # main process
    daemon_threads = False

    def process_request_thread(self, request, client_address):
        """Same as in BaseServer but as a thread.

        In addition, exception handling is done here.

        """
        try:
            self.finish_request(request, client_address)
            self.shutdown_request(request)
        except:
            self.handle_error(request, client_address)
            self.shutdown_request(request)

    def process_request(self, request, client_address):
        """Start a new thread to process the request."""
        t = threading.Thread(target = self.process_request_thread,
                             args = (request, client_address))
        t.daemon = self.daemon_threads
        t.start()
ThreadingMixIn

在ThreadingMixIn类中,其实就定义了一个属性,两个方法。在process_request方法中实际调用的正是python内置的多线程模块threading。

这个模块是python中所有多线程的基础,socketserver本质上也是利用了这个模块。

基于tcp的socketserver我们自己定义的类中的

1、self.server即套接字对象:<socketserver.ThreadingTCPServer object at 0x0000017A8FD2DB00>

2、self.request即一个链接:<socket.socket fd=324, family=AddressFamily.AF_INET, type=SocketKind.SOCK_STREAM, proto=0, laddr=('127.0.0.1', 8080), raddr=('127.0.0.1', 55216)

3、self.client_address即客户端地址:('127.0.0.1', 55216)

4、self 新建类对象:<__main__.FTPserver object at 0x0000017A8FF1AE80>

基于udp的socketserver我们自己定义的类中的

1、self.request是一个元组(第一个元素是客户端发来的数据,第二部分是服务端的udp套接字对象),如(b'adsf', <socket.socket fd=200, family=AddressFamily.AF_INET, type=SocketKind.SOCK_DGRAM, proto=0, laddr=('127.0.0.1', 8080)>)

2、self.client_address即客户端地址

#! /usr/bin/python
# -*- coding:utf-8 -*-

import socketserver,os,struct,json,hashlib,sys,datetime

bash_dir = os.path.dirname(os.path.dirname(os.path.abspath(__file__)))
sys.path.append(bash_dir)

import conf.client_set as args_set

class MyFtpServer(socketserver.BaseRequestHandler):
    """主要实现用户登录信息接受比对、用户注册、上传文件、下载文件及查询家目录文件等功能"""
    buf_size = args_set.buf_size
    coding = args_set.coding
    bash_dir = os.path.dirname(os.path.dirname(os.path.abspath(__file__)))
    connect_list = []

    def handle(self):
        while True:
            try:
                data=self.request.recv(4)                    #接受客户端报头信息并进行处理
                data_len=struct.unpack('i',data)[0]
                head_json=self.request.recv(data_len).decode(self.coding)
                head_dic=json.loads(head_json)
                cmd=head_dic['cmd']
                if hasattr(self,cmd):
                    func=getattr(self,cmd)
                    func(head_dic)
            except Exception:
                break

    def loggin(self,args):
        """用户登录"""
        loggin_status = False                               #初始化登陆状态
        user = args["user"]
        password = args["password"]
        user_file_path = os.path.normpath(os.path.join(self.bash_dir, "db", "user_file"))
        with open(user_file_path, encoding=self.coding) as f_read:
            for user_dict in f_read:
                user_dict = json.loads(user_dict)
                if user_dict["user"] == user and user_dict["password"] == password:
                    self.request.send("can".encode(self.coding))        #向客户端发送登陆成功消息
                    loggin_status = True
                    break
            if loggin_status == False:
                self.request.send("not".encode(self.coding))            #向客户端发送登陆失败消息

    def regist(self,args):
        """用户注册"""
        user = args["user"]
        password = args["password"]
        #不同用户等级,磁盘空间大小不同
        user_level_dict = {
                "1": 1024*1024*50,
                "2": 1024*1024*500,
                "3": 1024*1024*1024*2,
                "4": 1024*1024*1024*5,}
        disk_space = user_level_dict[args["user_level"]]
        user_file_path = os.path.normpath(os.path.join(self.bash_dir,"db","user_file"))
        home_dir = os.path.normpath(os.path.join(self.bash_dir,"db",user))      #指定用户家目录
        if os.path.exists(home_dir):                             #如果存在该用户,则说明该账号已注册
            self.request.send("not".encode(self.coding))
        else:
            os.mkdir(home_dir)                                   #建立家目录
            self.request.send("can".encode(self.coding))
            with open(user_file_path, "a") as f_write:  # 将客户注册信息写到用户文件中
                user_dict = {"user": user, "password": password, "disk_space": disk_space}
                user_json = json.dumps(user_dict)
                f_write.write(user_json)
                f_write.write("\n")

    def upload(self,args):
        """上传文件。主要完成空间大小比对、哈希值校验及断点续传功能"""
        file_path = os.path.normpath(os.path.join(self.bash_dir,"db",args["user"],args['filename']))  #文件上传目标位置
        file_size = args['filesize']                           #文件大小
        if os.path.exists(file_path):                          #判断是否存在上传文件
            exist_file_size = os.path.getsize(file_path)       #已存在文件大小,相等则说明已存在该文件,不等则说明上传一半,与续传有关
            if file_size == exist_file_size:
                self.request.send("have".encode(self.coding))
            else:
                self.request.send("half".encode(self.coding))
        else:
            self.request.send("noth".encode(self.coding))
        recv_data = self.request.recv(1).decode(self.coding)   #接收用户指令,进行下一步操作
        if recv_data == "n" or recv_data == "_":
            recv_size = 0
            continue_flag = False
        elif recv_data == "y":
            recv_size = exist_file_size
            continue_flag = True
        recv_size_struct = struct.pack("i", recv_size)         #将已存文件大小发送到客户端
        self.request.send(recv_size_struct)
        current_dir_size = os.path.getsize(os.path.dirname(file_path))         #获取客户家目录当前空间大小
        user_file_path = os.path.normpath(os.path.join(self.bash_dir, "db", "user_file"))   #用户信息文件
        with open(user_file_path, encoding=self.coding) as f_read:    # 打开用户信息文件,获取用户磁盘配额
            for user_dict in f_read:
                user_dict = json.loads(user_dict)
                if user_dict["user"] == args["user"]:
                    disk_space = user_dict["disk_space"]
                    if int(disk_space)-int(current_dir_size) > int(file_size)-int(recv_size): #判断剩余磁盘空间是否大于文件大小
                        self.request.send("can".encode(self.coding))
                        print(datetime.datetime.now(), self.client_address,"文件上传中...")
                        if continue_flag:
                            with open(file_path, 'ab') as f_write:               #续传模式
                                while recv_size < file_size:
                                    recv_data1 = self.request.recv(self.buf_size)
                                    f_write.write(recv_data1)
                                    recv_size += len(recv_data1)
                        else:
                            with open(file_path, 'wb') as f_write:               #从零开始上传模式
                                while recv_size < file_size:
                                    recv_data1 = self.request.recv(self.buf_size)
                                    f_write.write(recv_data1)
                                    recv_size += len(recv_data1)
                        with open(file_path, 'rb') as f_read:        #获取接收文件的哈希值,并与原始值比较,判断是否上传成功
                            md5obj = hashlib.md5()
                            md5obj.update(f_read.read())
                            hash_value = md5obj.hexdigest()
                            if hash_value == args["hash_value"]:     #判断哈希值是否相等
                                self.request.send("can".encode(self.coding))
                                print(datetime.datetime.now(),self.client_address,"上传成功!!!")
                                break
                            else:
                                self.request.send("not".encode(self.coding))
                                print(datetime.datetime.now(),self.client_address,"上传失败!!!")
                                break
                    else:
                        self.request.send("not".encode(self.coding))
                        break

    def download(self,args):
        """下载文件,断点续传等"""
        home_path = os.path.normpath(os.path.join(self.bash_dir, "db", args["user"]))   #用户家目录
        file_path = os.path.normpath(os.path.join(home_path,args["choice_file"]))       #文件路径
        file_size = os.path.getsize(file_path)                                          #获取文件大小
        head_dic = {"filesize": file_size,}                            #向客户端发送文件大小
        self.send_msg(head_dic)
        recv_data = self.request.recv(4)                               #接收已存文件大小
        start_seek = struct.unpack("i", recv_data)[0]
        recv_data1 = self.request.recv(1).decode(self.coding)          #接收用户指令,执行下一步操作
        if recv_data1 == "n" or recv_data1 == "_":
            send_size = 0
        elif recv_data1 == "y":
            send_size = start_seek
        elif recv_data1 == "r":
            return
        print(datetime.datetime.now(),self.client_address,"文件下载中...")
        with open(file_path, 'rb') as f_read:      # 服务端下载文件
            f_read.seek(send_size)
            for line in f_read:
                self.request.sendall(line)
        recv_hash_value = self.request.recv(32).decode(self.coding)     #接收客户端已接收文件的哈希值
        with open(file_path, "rb") as f_read:                           #获取上传文件的哈希值
            md5obj = hashlib.md5()
            md5obj.update(f_read.read())
            hash_value = md5obj.hexdigest()
        if recv_hash_value == hash_value:                    #判断原文件和下载文件的哈希值是否一致,如一致则发送成功
            self.request.send("can".encode(self.coding))
            print(datetime.datetime.now(),self.client_address,"下载成功!!!")
        else:
            print(recv_hash_value, hash_value)
            print(datetime.datetime.now(),self.client_address,"下载失败!!!")

    def query_file(self,args):
        """查询家目录文件"""
        file_path = os.path.normpath(os.path.join(self.bash_dir, "db", args["user"],))
        file_list = []
        for name in os.listdir(file_path):
            file_list.append(name)
        head_dic = {"file_list": file_list}
        self.send_msg(head_dic)

    def send_msg(self,head_dic):
        """格式化发送信息"""
        head_json = json.dumps(head_dic)
        head_json_bytes = head_json.encode(self.coding)
        head_struct = struct.pack("i", len(head_json_bytes))
        self.request.send(head_struct)
        self.request.send(head_json_bytes)

if __name__ == '__main__':
    ip_port = ("127.0.0.1", 8888)
    ftp_server = socketserver.ThreadingTCPServer(ip_port,MyFtpServer)
    ftp_server.serve_forever()
简易FTP实现-服务端
#!/usr/bin/python
#-*- coding:utf-8 -*-

import socket,struct,json,os,sys,hashlib

bash_dir = os.path.dirname(os.path.dirname(os.path.abspath(__file__)))
sys.path.append(bash_dir)

import conf.client_set as args_set

class MyFtpClient:
    """主要实现用户登录、注册、上传文件、下载文件及查询家目录文件等功能"""
    address_family = args_set.address_family
    socket_type = args_set.socket_type
    buf_size = args_set.buf_size
    coding = args_set.coding

    def __init__(self, server_address, connect=True,):
        self.server_address=server_address
        self.socket = socket.socket(self.address_family,
                                    self.socket_type)
        self.loggin_flag=False                                     #初始化登录状态
        if connect:
            try:
                self.client_connect()
            except:
                self.client_close()
                raise

    def client_connect(self):
        """连接服务器"""
        self.socket.connect(self.server_address)

    def client_close(self):
        """关闭服务器"""
        self.socket.close()

    def run(self):
        """客户端程序的开始"""
        judge_msg_dict = {
                    "1": "upload",
                    "2": "download",
                    "3": "query_file",
                    "4": exit,
                    "5": "loggin",
                    "6": "regist",
                            }
        while True:
            if not self.loggin_flag:
                judge_msg = ("""
                ******欢迎您来到FTP系统******
                    你可以进行的操作如下:
                       1、上传文件
                       2、下载文件
                       3、查看家目录文件信息
                       4、退出
                       5、登录
                       6、注册
                            """)
            elif self.loggin_flag:
                user = loggin_user
                judge_msg = ("""
                    你可以进行的操作如下:           登陆账号:%s
                       1、上传文件
                       2、下载文件
                       3、查看家目录文件信息
                       4、退出
                            """ % user)
                judge_msg_dict = {
                        "1": "upload",
                        "2": "download",
                        "3": "query_file",
                        "4": exit,}
            print(judge_msg)
            judge_chioce = input("请输入您需要进行的操作序号:>>").strip()
            if judge_chioce == "4":
                print("你已退出系统,欢迎下次再见!!!")
                break
            elif judge_chioce in judge_msg_dict:
                cmd = judge_msg_dict[judge_chioce]
                if hasattr(self,cmd):
                    func = getattr(self,cmd)
                    func(cmd)
            else:
                print("您的输入有误,请重新输入!!!")
                continue

    def loggin(self,cmd):
        """用户登录"""
        while True:
            user = input("请输入您的用户名:>>").strip()
            pwd = input("请输入你的账户密码:>>").strip()
            head_dic={"cmd":cmd,"user":user,"password":pwd}
            self.send_msg(head_dic)
            recv_data = self.socket.recv(3)
            if recv_data.decode(self.coding) == "can":
                print("登录成功,欢迎您,%s" %user)
                self.loggin_flag = True
                global loggin_user
                loggin_user = user
                break
            if recv_data.decode(self.coding) == "not":
                choice = input("用户名或密码错误,请按任意键重新登录、确定是否注册或输入“q”返回主界面:>>")
                if choice == "q" or choice == "Q":
                    break
                else:
                    continue

    def regist(self,cmd):
        """用户注册"""
        regist_flag = True
        while regist_flag:
            user = input("请输入您的用户名:>>").strip()
            pwd1 = input("请输入你的账户密码:>>").strip()
            pwd2 = input("请重复输入你的账户密码:>>").strip()
            while True:
                #模拟实现不同用户磁盘空间大小不同
                user_level_msg = """
                可供选择的会员类型如下:
                   1、普通会员(磁盘空间:50M,无费用)
                   2、中级会员(磁盘空间:500M,2元/月)
                   3、高级会员(磁盘空间:2G,6元/月)
                   4、超级会员(磁盘空间:5G,10元/月)"""
                print(user_level_msg)
                user_level = input("请选择您需要的等级:>>")
                if user_level.isdigit():
                    if int(user_level) in range(1,5):
                        if pwd1 == pwd2:
                            head_dic={"cmd":cmd,"user":user,"password":pwd1,"user_level":user_level}
                            self.send_msg(head_dic)
                            recv_data = self.socket.recv(3)
                            if recv_data.decode(self.coding) == "can":
                                print("注册成功!!!")
                            else:
                                print("该账号已被注册,请更换其他账号!!!")
                            regist_flag = False
                            break
                        else:
                            choice = input("两次输入密码不一致,请按任意键重新注册或输入“q”返回主界面!!!")
                            if choice == "q" or choice == "Q":
                                regist_flag = False
                                break
                            else:
                                break
                    else:
                        print("选择超出范围!!!")
                else:
                    print("会员类型选择错误,请重新选择!!!")

    def upload(self,cmd):
        """上传文件"""
        if self.loggin_flag:                                              #判断是否进入登陆状态
            user = loggin_user
            file_dir=input("请输入需要上传的本地文件路径或输入“q”返回主菜单:>>")
            if file_dir == "q" or file_dir == "Q":
                return
            elif not os.path.isfile(file_dir):                            #判断需要上传的文件是否存在
                print("%s 文件路径不存在!!!" %file_dir)
                return
            else:
                with open(file_dir, "rb") as f_read:                      #获取上传文件的哈希值
                    md5obj = hashlib.md5()
                    md5obj.update(f_read.read())
                    hash_value = md5obj.hexdigest()
                filesize=os.path.getsize(file_dir)
            head_dic={"cmd":cmd,"filename":os.path.basename(file_dir),"filesize":filesize,
                      "user":user,"hash_value":hash_value}                #制定报头字典
            self.send_msg(head_dic)
            recv_data = self.socket.recv(4).decode(self.coding)           #接收数据,断点续传会用到
            if recv_data == "have":
                print("该文件已存在!!!")
                return
            elif recv_data == "half":
                while True:
                    upload_choice = input("发现该文件未上传完成,是否继续(y/n):>>")
                    if upload_choice == "y" or upload_choice == "Y":
                        self.socket.send("y".encode(self.coding))
                        break
                    elif upload_choice == "n" or upload_choice == "N":
                        self.socket.send("n".encode(self.coding))
                        break
                    else:
                        print("输入错误,请重新输入!!!")
                        continue
            elif recv_data == "noth":
                self.socket.send("_".encode(self.coding))
            recv_data1 = self.socket.recv(4)
            start_seek = struct.unpack("i",recv_data1)[0]                #获取文件发送起始位置
            recv_data2 = self.socket.recv(3).decode(self.coding)         #获取是否可以进行上传信号
            if recv_data2 == "can":                                      #接收磁盘空间大小是否OK
                send_size=start_seek
                with open(file_dir,"rb") as f_read:                      #发送文件
                    f_read.seek(start_seek)
                    for line in f_read:
                        self.socket.sendall(line)
                        send_size += len(line)
                        self.progress_bar(send_size, filesize)
                recv_data3 = self.socket.recv(3).decode(self.coding)    #获取哈希校验是否OK
                if recv_data3 == "can":
                    print("\n上传成功!!!")
                elif recv_data3 == "not":
                    print("\n上传失败,请重新进行操作!!!")
            else:
                print("您的磁盘空间不足!!!")
        else:
            print("请进行登录操作!!!")
            self.loggin("loggin")

    def download(self,cmd):
        """下载文件"""
        if self.loggin_flag:
            user = loggin_user
            file_list = self.query_file("query_file")                      #获取家目录文件清单
            if not file_list:
                print("家目录无可供选择文件!!!")
                return
            choice_file = input("请输入您需要下载的文件(文件名.扩展名)或输入“q”返回主菜单:>>")
            if choice_file == "q" or choice_file == "Q":
                return
            elif choice_file in file_list:                                   #判断输入文件的正确性
                while True:
                    file_dir = input("请输入本地目标路径(例如:D:\\test):>>")
                    if os.path.isdir(file_dir):                            #判断本地路径是否存在
                        break
                    else:
                        print("%s 文件路径不存在!!!" % file_dir)
                        choice = input("是否重新输入(y/n):>>")
                        if choice == "y" or choice == "Y":
                            continue
                        elif choice == "n" or choice == "N":
                            return
                        else:
                            print("输入错误!!!")
                            continue
                head_dic = {"cmd": cmd, "user": user,"choice_file":choice_file,}
                self.send_msg(head_dic)
                data = self.socket.recv(4)                                #接收服务端报头文件
                data_len = struct.unpack('i', data)[0]
                head_json = self.socket.recv(data_len).decode(self.coding)
                head_dic = json.loads(head_json)
                file_size = head_dic["filesize"]                          #得到文件大小
                file_path = os.path.normpath(os.path.join(file_dir,choice_file))  # 文件下载目标位置
                if os.path.exists(file_path):                     # 判断是否存在下载文件
                    exist_file_size = os.path.getsize(file_path)  # 已存在文件大小,相等则说明已存在该文件,不等则说明上传一半
                    recv_size_struct = struct.pack("i", exist_file_size)  # 将已存文件大小发送到客户端
                    self.socket.send(recv_size_struct)
                    if file_size == exist_file_size:
                        print("该文件已存在!!!")
                        self.socket.send("r".encode(self.coding))
                        return
                    else:
                        download_choice = input("发现该文件未下载完成,是否继续(y/n):>>")
                        if download_choice == "y" or download_choice == "Y":
                            self.socket.send("y".encode(self.coding))
                            self.file_download(file_path, file_size, 'ab', exist_file_size)
                        elif download_choice == "n" or download_choice == "N":
                            self.socket.send("n".encode(self.coding))
                            self.file_download(file_path, file_size, 'wb', 0)
                else:
                    exist_file_size = 0
                    recv_size_struct = struct.pack("i", exist_file_size)  # 将已存文件大小发送到客户端
                    self.socket.send(recv_size_struct)
                    self.socket.send("_".encode(self.coding))
                    self.file_download(file_path, file_size, 'wb', 0)
                with open(file_path, 'rb') as f_read:
                    md5obj = hashlib.md5()
                    md5obj.update(f_read.read())
                    hash_value = md5obj.hexdigest()
                self.socket.send(hash_value.encode(self.coding))
                recv_data1 = self.socket.recv(3).decode(self.coding)
                if recv_data1 == "can":
                    print("\n下载成功!!!")
                elif recv_data1 == "not":
                    print("\n下载失败,请重新进行操作!!!")
            else:
                print("文件输入错误!!!")
        else:
            print("请进行登录操作!!!")
            self.loggin("loggin")

    def file_download(self,file_path, file_size, mode, recv_size):
        """该方法用于文件从不同位置下载的模式不同的情况"""
        print("文件下载中...")
        with open(file_path, mode) as f_write:
            while recv_size < file_size:
                recv_data = self.socket.recv(self.buf_size)
                f_write.write(recv_data)
                recv_size += len(recv_data)
                self.progress_bar(recv_size, file_size)

    def query_file(self,cmd):
        """查询家目录文件信息"""
        if self.loggin_flag:
            user = loggin_user
            head_dic = {"cmd": cmd, "user": user}
            self.send_msg(head_dic)
            data = self.socket.recv(4)
            data_len = struct.unpack('i', data)[0]
            head_json = self.socket.recv(data_len).decode(self.coding)
            head_dic = json.loads(head_json)
            print("您的家目录文件清单如下:")
            if head_dic["file_list"]:
                for file in head_dic["file_list"]:
                    print(file)
            else:
                print("当前无文件!!!")
            return head_dic["file_list"]
        else:
            print("请进行登录操作!!!")
            self.loggin("loggin")

    def send_msg(self,head_dic):
        """格式化发送信息"""
        head_json = json.dumps(head_dic)
        head_json_bytes = head_json.encode(self.coding)
        head_struct = struct.pack("i", len(head_json_bytes))
        self.socket.send(head_struct)
        self.socket.send(head_json_bytes)

    def progress_bar(self,num,total):
        """动态进度条显示"""
        rate = num / total
        rate_num = int(rate * 100)
        r = '\r[%s%s]%d%%' % ("=" * rate_num, "" * (100 - rate_num), rate_num,)
        sys.stdout.write(r)
        sys.stdout.flush()

if __name__ == '__main__':
    ip_port = args_set.ip_port
    ftp_client=MyFtpClient(ip_port)
    ftp_client.run()
简易FTP实现-客户端

 

参考:http://www.cnblogs.com/linhaifeng/articles/6129246.html#_label1

posted @ 2018-08-15 18:01  Joe1991  阅读(200)  评论(0)    收藏  举报