Fork me on GitHub

day 29

day 29 作业

写一个基于TCP协议的套接字程序,服务端能够并发地接收并处理多个客户端的连接。

# 基于之前版本,改动了服务端,支持多个用户同时连接服务端,并同时进行各自的操作
# File-服务端
import socket
import json
import struct
import os

# Storage lives in the directory containing this script: user records in
# USER/ and uploaded files in FILE/.
DB_PATH = os.path.dirname(__file__)
USER_PATH, FILE_PATH = (os.path.join(DB_PATH, folder) for folder in ('USER', 'FILE'))


def register(response_msg, conn):
    """Create a new user account from a client's register request.

    Args:
        response_msg: decoded request dict; 'user_name' selects the record
            file and every field except the 'function' routing key is saved.
        conn: client socket the {'flag', 'msg'} reply is sent on.

    The record is written as JSON to USER_PATH/<user_name>.
    """
    user_name = response_msg.get('user_name')

    # The name becomes a file name, so reject anything that could escape
    # USER_PATH (e.g. '../x') or resolve to a directory ('', '.', '..').
    if (not isinstance(user_name, str) or user_name in ('', '.', '..')
            or os.path.basename(user_name) != user_name):
        send_msg({'flag': False, 'msg': '用户名不合法!'}, conn)
        return

    user_path = os.path.join(USER_PATH, user_name)

    # NOTE(review): the password is stored in plain text, and the stored
    # format is what login() compares against.
    response_msg.pop('function', None)
    try:
        # Mode 'x' creates the file atomically. Each client is served by its
        # own process, and 'x' stops two simultaneous registrations of the
        # same name from both succeeding.
        with open(user_path, 'x', encoding='utf-8') as fw:
            json.dump(response_msg, fw)
    except FileExistsError:
        send_msg({'flag': False, 'msg': '注册失败用户名已存在!'}, conn)
        return

    # Reply only after the file is closed, so that a following login request
    # sees the complete record.
    send_msg({'flag': True, 'msg': '注册成功!'}, conn)


def login(response_msg, conn):
    """Check a client's login request against the stored user record.

    Args:
        response_msg: decoded request dict with 'user_name' and 'user_pwd'.
        conn: client socket the {'flag', 'msg'} reply is sent on.
    """
    user_name = response_msg.get('user_name')

    # Accept only a plain file name. Anything else could point outside
    # USER_PATH (e.g. '../FILE/x', an uploaded file) or at a directory,
    # which open()/json.load() would crash on.
    is_plain_name = (isinstance(user_name, str)
                     and user_name not in ('', '.', '..')
                     and os.path.basename(user_name) == user_name)
    user_path = os.path.join(USER_PATH, user_name) if is_plain_name else None

    if user_path is None or not os.path.isfile(user_path):
        send_msg({'flag': False, 'msg': '用户不存在!'}, conn)
        return

    with open(user_path, 'r', encoding='utf-8') as fr:
        user_msg = json.load(fr)

    if user_msg.get('user_pwd') == response_msg.get('user_pwd'):
        login_msg = {'flag': True, 'msg': '登录成功!'}
    else:
        login_msg = {'flag': False, 'msg': '密码错误!'}

    send_msg(login_msg, conn)


def accept_file(response_msg, conn):
    file_path = os.path.join(FILE_PATH, response_msg.get('file_name'))

    total_data = 0
    with open(file_path, 'wb')as fw:
        while total_data < response_msg.get('file_size'):
            recv_data = conn.recv(1024)
            fw.write(recv_data)
            total_data += len(recv_data)

    upload_msg = {
        'flag': True,
        'msg': '上传成功!'
    }

    send_msg(upload_msg, conn)


def check_file(_, conn):
    if os.listdir(FILE_PATH):
        file_list = os.listdir(FILE_PATH)

        check_msg = {
            'flag': True,
            'file_list': file_list
        }

    else:
        check_msg = {
            'flag': False,
            'msg': '没有文件可以下载!'
        }

    send_msg(check_msg, conn)


def sent_file(response_msg, conn):
    """Send a file from FILE_PATH to the client.

    Args:
        response_msg: decoded request dict whose 'file_name' names the file.
        conn: client socket. On success a {'file_name', 'file_size', 'flag'}
            header is sent, followed by the raw file bytes. Otherwise a
            {'flag': False, 'msg': ...} reply is sent with no body.
    """
    # Strip directory parts, so a request like '../USER/alice' cannot read
    # files (such as stored passwords) outside FILE_PATH.
    file_name = os.path.basename(response_msg.get('file_name') or '')
    file_path = os.path.join(FILE_PATH, file_name)

    # This check also rejects '', '.' and '..', which resolve to directories.
    if not os.path.isfile(file_path):
        send_msg({'flag': False, 'msg': '文件不存在!'}, conn)
        return

    file_msg = {
        'file_name': file_name,
        'file_size': os.path.getsize(file_path),
        'flag': True
    }

    send_msg(file_msg, conn, file_path)


def send_msg(msg, conn, file=None):
    """Send one framed message, optionally followed by a file's raw bytes.

    The frame is a struct 'i' (native int) length header, then ``msg``
    encoded as UTF-8 JSON. If ``file`` is given, its contents are streamed
    immediately afterwards.

    Args:
        msg: JSON-serialisable object to send.
        conn: connected socket.
        file: optional path of a file whose bytes follow the message.
    """
    request_msg = json.dumps(msg).encode('utf-8')

    # sendall() keeps writing until every byte is sent. send() may stop after
    # a partial write and silently drop the rest.
    conn.sendall(struct.pack('i', len(request_msg)) + request_msg)

    if file:
        with open(file, 'rb') as fr:
            # Read fixed-size chunks rather than "lines". Lines are arbitrary
            # in binary data: a file with no b'\n' is one huge line.
            for chunk in iter(lambda: fr.read(64 * 1024), b''):
                conn.sendall(chunk)


# Maps a request's 'function' field to its handler. Each handler is called
# as handler(request_msg, conn).
func_dict = {
    'register': register,
    'login': login,
    # The client's upload request sends 'accept_file'. 'accept' is kept so
    # that any caller still using the old key keeps working.
    'accept_file': accept_file,
    'accept': accept_file,
    'sent_file': sent_file,
    'check_file': check_file
}

# ========改动的地方========

from multiprocessing import Process


def run(conn, addr):
    """Serve a single client connection until it disconnects.

    A separate Process runs this for every accepted connection. Each request
    is a struct 'i' length header followed by a UTF-8 JSON object whose
    'function' field selects a handler in func_dict.

    Args:
        conn: the accepted client socket. It is closed when this returns.
        addr: the client's address, used only in log messages.
    """
    # Several child processes may start at once. exist_ok avoids the
    # exists()/mkdir() race, which could raise FileExistsError here.
    os.makedirs(USER_PATH, exist_ok=True)
    os.makedirs(FILE_PATH, exist_ok=True)

    def recv_exact(size):
        # recv() may return fewer bytes than requested. Keep reading until the
        # whole header/body has arrived or the peer hangs up.
        data = b''
        while len(data) < size:
            chunk = conn.recv(size - len(data))
            if not chunk:
                raise ConnectionError('peer closed the connection')
            data += chunk
        return data

    with conn:
        while True:
            try:
                data_len = struct.unpack('i', recv_exact(4))[0]
                request_msg = json.loads(recv_exact(data_len).decode('utf-8'))
            except OSError:
                print(f'来自{addr}的用户中断了连接!')
                break
            except ValueError:
                # The body is not valid UTF-8 JSON, so the stream cannot be
                # resynchronised.
                print(f'来自{addr}的请求格式错误, 已断开连接!')
                break

            function = request_msg.get('function') if isinstance(request_msg, dict) else None
            handler = func_dict.get(function) if isinstance(function, str) else None
            if handler is None:
                print(f'来自{addr}的请求无法识别, 已断开连接!')
                break

            try:
                handler(request_msg, conn)
            except ConnectionError:
                print(f'来自{addr}的用户中断了连接!')
                break
            except Exception as e:
                # Keep the original "stop serving on any error" behaviour,
                # but report what actually went wrong.
                print(f'处理来自{addr}的请求时出错: {e!r}')
                break


if __name__ == '__main__':

    server = socket.socket()
    # Allow an immediate restart while old connections linger in TIME_WAIT.
    # On Windows SO_REUSEADDR would let other sockets bind the same port, so
    # it is only set on POSIX.
    if os.name == 'posix':
        server.setsockopt(socket.SOL_SOCKET, socket.SO_REUSEADDR, 1)
    server.bind(('127.0.0.1', 9527))
    server.listen(5)
    with server:
        while True:
            conn, addr = server.accept()
            print(f'{addr}已连接上!')

            # Each client is served by its own process, so clients can work
            # concurrently.
            p = Process(target=run, args=(conn, addr))
            p.start()
            # The child now has its own handle on the connection. Close the
            # parent's copy; otherwise it leaks one descriptor per client and
            # keeps the connection open after the child exits.
            conn.close()
# File-客户端(无改动)
import socket
import json
import struct
import os


def register(client):
    """Prompt for a new user name and password and ask the server to register it.

    Args:
        client: socket connected to the file server.
    """
    user_name = input('请输入用户名>>>').strip()
    # An empty name would make the server join it onto its USER directory
    # and report the already-existing directory as a taken user name.
    if not user_name:
        print('用户名不能为空!')
        return

    user_pwd = input('请输入密码>>>').strip()
    re_pwd = input('请确认密码>>>').strip()

    if user_pwd != re_pwd:
        print('两次密码不一致!')
        return

    msg = {
        'user_name': user_name,
        'user_pwd': user_pwd,
        'function': 'register'
    }

    response_msg = send_msg(msg, client)
    # Success and failure are shown the same way: print the server's text.
    print(response_msg.get('msg'))


def login(client):
    """Prompt for credentials and ask the server to verify them.

    Args:
        client: socket connected to the file server.
    """
    user_name = input('请输入用户名>>>').strip()
    # With an empty name the server would try to open its USER directory as
    # a user file, fail, and end the session.
    if not user_name:
        print('用户名不能为空!')
        return

    user_pwd = input('请输入密码>>>').strip()

    msg = {
        'user_name': user_name,
        'user_pwd': user_pwd,
        'function': 'login'
    }

    response_msg = send_msg(msg, client)
    # Success and failure are shown the same way: print the server's text.
    print(response_msg.get('msg'))


def upload(client):
    """Ask for a local file path and upload that file to the server.

    Args:
        client: socket connected to the file server.
    """
    file_path = input('请复制上传文件的绝对路径>>>').strip()

    # Give up on anything that is not an existing regular file.
    if not os.path.isfile(file_path):
        print('请输入正确的文件路径!')
        return

    header = dict(
        function='accept_file',
        file_name=os.path.basename(file_path),
        file_size=os.path.getsize(file_path),
    )

    # The reply's text is printed whether the upload succeeded or not.
    reply = send_msg(header, client, file_path)
    print(reply.get('msg'))


def download(client):
    """List the server's files, let the user pick one, and download it.

    The file is saved into a '下载的文件' folder next to this script.

    Args:
        client: socket connected to the file server.

    Raises:
        ConnectionError: if the server disconnects during the transfer.
    """
    while True:
        response_msg = send_msg({'function': 'check_file'}, client)

        if not response_msg.get('flag'):
            print(response_msg.get('msg'))
            break

        file_list = response_msg.get('file_list')

        print('\n')
        print('编号   文件名 ')
        for index, name in enumerate(file_list, 1):
            print(f'{index}     {name}')
        print('\n')

        file_choice = input('请选择下载的文件编号>>>').strip()

        if not file_choice.isdigit():
            print('输入错误!')
            continue

        file_choice = int(file_choice) - 1

        # Validate against the number of files, not the size of the reply
        # dict (which always has two keys).
        if file_choice not in range(len(file_list)):
            print('输入错误!')
            continue

        request_download = {
            'file_name': file_list[file_choice],
            'function': 'sent_file'
        }

        download_msg = send_msg(request_download, client)

        if not download_msg.get('flag'):
            # Nothing is read after a failed reply. Report it instead of
            # silently re-listing.
            print(download_msg.get('msg') or '下载失败!')
            break

        save_dir = os.path.join(os.path.dirname(__file__), '下载的文件')
        os.makedirs(save_dir, exist_ok=True)

        # Use only the last path component of the server-supplied name, so
        # the file cannot be written outside save_dir.
        save_path = os.path.join(save_dir, os.path.basename(download_msg.get('file_name')))

        file_size = download_msg.get('file_size')
        total_data = 0
        with open(save_path, 'wb') as fw:
            while total_data < file_size:
                # Don't read past this file into whatever the server sends next.
                recv_data = client.recv(min(1024, file_size - total_data))
                if not recv_data:
                    # recv() returns b'' after the server closes. Without this
                    # check the loop would never end.
                    raise ConnectionError('server closed the connection during download')
                fw.write(recv_data)
                total_data += len(recv_data)

        print('文件下载完成!')
        break


def _recv_exactly(client, size):
    """Read exactly ``size`` bytes from ``client``.

    Raises ConnectionError if the server closes the connection first.
    """
    data = b''
    while len(data) < size:
        chunk = client.recv(size - len(data))
        if not chunk:
            raise ConnectionError('server closed the connection')
        data += chunk
    return data


def send_msg(msg, client, file=None):
    """Send one request to the server and return its decoded JSON reply.

    The frame is a struct 'i' (native int) length header, then ``msg`` as
    UTF-8 JSON. If ``file`` is given, its raw bytes follow. The reply is read
    using the same framing.

    Args:
        msg: JSON-serialisable request.
        client: socket connected to the server.
        file: optional path of a file to stream after the request.

    Returns:
        The server's decoded reply.

    Raises:
        ConnectionError: if the server disconnects before a full reply arrives.
    """
    request_msg = json.dumps(msg).encode('utf-8')

    # sendall() keeps writing until everything is sent; send() may not.
    client.sendall(struct.pack('i', len(request_msg)) + request_msg)

    if file:
        with open(file, 'rb') as fr:
            for chunk in iter(lambda: fr.read(64 * 1024), b''):
                client.sendall(chunk)

    # recv(n) can return fewer than n bytes (e.g. for a long file list), so
    # read the header and body in full before decoding.
    data_len = struct.unpack('i', _recv_exactly(client, 4))[0]
    return json.loads(_recv_exactly(client, data_len).decode('utf-8'))


# Menu choice -> action. The keys are the menu numbers '1'..'4', assigned in
# the order the actions are listed here.
func_dict = {
    str(number): action
    for number, action in enumerate((login, register, upload, download), start=1)
}

if __name__ == '__main__':
    menu = '''
        1. 登录
        2. 注册
        3. 上传
        4. 下载
        q. 退出
        '''

    # The with-block closes the socket however the menu loop ends.
    with socket.socket() as client:
        client.connect(('127.0.0.1', 9527))
        while True:
            print(menu)

            func_choice = input('请选择功能>>>').strip()

            if func_choice == 'q':
                break

            if func_choice not in func_dict:
                print('输入错误!')
                continue

            try:
                func_dict.get(func_choice)(client)
            except (ConnectionError, struct.error):
                # The server dropped the connection. struct.error comes from
                # unpacking an empty reply header. The session cannot
                # continue.
                print('与服务器的连接已断开!')
                break
posted @ 2019-10-22 15:50  Yugaliii  阅读(95)  评论(0)    收藏  举报