day 29
day 29 作业
写一个基于TCP协议的套接字程序,服务端实现并发接收客户端的连接。
# 基于之前版本,改动了服务端,支持多个用户同时连接服务端,并同时进行各自的操作
# File-服务端
import socket
import json
import struct
import os
# Root of the on-disk storage: the directory that contains this script.
DB_PATH = os.path.dirname(__file__)
# USER holds one JSON record per registered account; FILE holds uploaded files.
USER_PATH, FILE_PATH = (
    os.path.join(DB_PATH, sub_dir) for sub_dir in ('USER', 'FILE')
)
def register(response_msg, conn):
    """Create a new account record from a registration request.

    Args:
        response_msg: Decoded request dict; ``user_name`` selects the record
            file and every key except ``function`` is stored in it.
        conn: Connected socket the reply is sent on (via ``send_msg``).

    The reply is a dict with ``flag`` (success) and ``msg`` (user-facing text).
    """
    user_name = response_msg.get('user_name')
    # The name arrives from the network and becomes a file name: refuse
    # anything that is not a single plain path component, otherwise
    # '../x' or an absolute path would write outside USER_PATH.
    if (
        not isinstance(user_name, str)
        or user_name in ('', '.', '..')
        or os.path.basename(user_name) != user_name
    ):
        send_msg({'flag': False, 'msg': '注册失败用户名不合法!'}, conn)
        return

    user_path = os.path.join(USER_PATH, user_name)
    # Store everything but the routing key, without mutating the caller's dict.
    user_record = {
        key: value for key, value in response_msg.items() if key != 'function'
    }
    try:
        # Mode 'x' creates the file atomically and fails if it already
        # exists, so two concurrent registrations of the same name cannot
        # overwrite each other.
        with open(user_path, 'x', encoding='utf-8') as fw:
            json.dump(user_record, fw)
    except FileExistsError:
        send_msg({'flag': False, 'msg': '注册失败用户名已存在!'}, conn)
        return

    send_msg({'flag': True, 'msg': '注册成功!'}, conn)
def login(response_msg, conn):
    """Check a login request against the stored account record.

    Args:
        response_msg: Decoded request dict carrying ``user_name`` and
            ``user_pwd``.
        conn: Connected socket the reply is sent on (via ``send_msg``).

    The reply is a dict with ``flag`` (success) and ``msg`` (user-facing text).
    """
    user_name = response_msg.get('user_name')
    # Only a single plain path component may be used as a record name;
    # anything else ('', '..', 'a/b', absolute paths) cannot be an account
    # and must not be allowed to address files outside USER_PATH.
    is_plain_name = (
        isinstance(user_name, str)
        and user_name not in ('', '.', '..')
        and os.path.basename(user_name) == user_name
    )
    user_path = os.path.join(USER_PATH, user_name) if is_plain_name else None

    # isfile (not exists): a directory is not an account and open() on it
    # would raise and tear down the client's session.
    if user_path is None or not os.path.isfile(user_path):
        send_msg({'flag': False, 'msg': '用户不存在!'}, conn)
        return

    with open(user_path, 'r', encoding='utf-8') as fr:
        user_msg = json.load(fr)

    if user_msg.get('user_pwd') == response_msg.get('user_pwd'):
        login_msg = {'flag': True, 'msg': '登录成功!'}
    else:
        login_msg = {'flag': False, 'msg': '密码错误!'}
    send_msg(login_msg, conn)
def accept_file(response_msg, conn):
    """Receive an uploaded file body from *conn* and store it under FILE_PATH.

    Args:
        response_msg: Decoded request dict with ``file_name`` and
            ``file_size`` (number of raw bytes that follow the request).
        conn: Connected socket the file bytes are read from and the reply
            is sent on.

    Raises:
        ValueError: If the request has no usable file name or size.
        ConnectionError: If the peer closes the connection before
            ``file_size`` bytes have arrived.
    """
    file_name = response_msg.get('file_name')
    file_size = response_msg.get('file_size')
    # Keep only the last path component so a client cannot write outside
    # FILE_PATH with names such as '../../x' or an absolute path.
    safe_name = os.path.basename(file_name) if isinstance(file_name, str) else ''
    if safe_name in ('', '.', '..') or not isinstance(file_size, int) or file_size < 0:
        # The body cannot be located in the stream reliably, so abort the
        # session instead of answering and desynchronising the protocol.
        raise ValueError('invalid upload request')

    file_path = os.path.join(FILE_PATH, safe_name)
    remaining = file_size
    with open(file_path, 'wb') as fw:
        while remaining > 0:
            # Never ask for more than is left: extra bytes would belong to
            # the client's next request.
            chunk = conn.recv(min(1024, remaining))
            if not chunk:
                # recv() returning b'' means the peer is gone; without this
                # check the loop would spin forever.
                raise ConnectionError('client disconnected during upload')
            fw.write(chunk)
            remaining -= len(chunk)

    send_msg({'flag': True, 'msg': '上传成功!'}, conn)
def check_file(_, conn):
    """Reply with the names of the entries in FILE_PATH.

    Args:
        _: The request dict; unused.
        conn: Connected socket the reply is sent on (via ``send_msg``).

    The reply carries ``flag`` True plus ``file_list`` when the directory has
    entries, otherwise ``flag`` False plus a user-facing ``msg``.
    """
    # List the directory once so the flag and the list always agree.
    file_list = os.listdir(FILE_PATH)
    if file_list:
        check_msg = {
            'flag': True,
            'file_list': file_list,
        }
    else:
        check_msg = {
            'flag': False,
            'msg': '没有文件可以下载!',
        }
    send_msg(check_msg, conn)
def sent_file(response_msg, conn):
    """Send a stored file to the client: a JSON header, then the raw bytes.

    Args:
        response_msg: Decoded request dict; ``file_name`` names the file
            inside FILE_PATH.
        conn: Connected socket the reply (and file body) is sent on.

    When the name is unusable or the file is missing, a ``flag`` False reply
    with a ``msg`` is sent and no file body follows.
    """
    file_name = response_msg.get('file_name')
    # Accept only a single plain path component so a client cannot read
    # files outside FILE_PATH.
    is_plain_name = (
        isinstance(file_name, str)
        and file_name not in ('', '.', '..')
        and os.path.basename(file_name) == file_name
    )
    file_path = os.path.join(FILE_PATH, file_name) if is_plain_name else None

    if file_path is None or not os.path.isfile(file_path):
        # Answer instead of letting getsize() raise and drop the session.
        send_msg({'flag': False, 'msg': '文件不存在!'}, conn)
        return

    file_msg = {
        'file_name': os.path.basename(file_path),
        'file_size': os.path.getsize(file_path),
        'flag': True,
    }
    send_msg(file_msg, conn, file_path)
def send_msg(msg, conn, file=None):
    """Send a length-prefixed JSON message, optionally followed by a file.

    Wire format: a 4-byte ``struct`` 'i' length, the UTF-8 JSON payload, then
    (only when *file* is given) the raw bytes of that file.

    Args:
        msg: JSON-serialisable object to send.
        conn: Connected socket to write to.
        file: Optional path of a file whose contents follow the message.
    """
    request_msg = json.dumps(msg).encode('utf-8')
    headers = struct.pack('i', len(request_msg))
    # sendall() keeps writing until everything is out; plain send() may
    # transmit only part of the buffer and silently truncate the message.
    conn.sendall(headers + request_msg)
    if file:
        with open(file, 'rb') as fr:
            # Fixed-size chunks: iterating a binary file "by line" can yield
            # arbitrarily large pieces when the data has no newline bytes.
            while True:
                chunk = fr.read(64 * 1024)
                if not chunk:
                    break
                conn.sendall(chunk)
# Dispatch table: the 'function' field of a request selects its handler.
func_dict = {
    'register': register,
    'login': login,
    # Legacy key, kept so existing callers that send 'accept' keep working.
    'accept': accept_file,
    # The upload client in this file sends 'accept_file'; without this entry
    # the lookup returns None and every upload drops the connection.
    'accept_file': accept_file,
    'sent_file': sent_file,
    'check_file': check_file,
}
# ========改动的地方========
from multiprocessing import Process
def _recv_exact(conn, size):
    """Read exactly *size* bytes from *conn*; raise ConnectionError on EOF."""
    chunks = []
    remaining = size
    while remaining > 0:
        chunk = conn.recv(remaining)
        if not chunk:
            raise ConnectionError('peer closed the connection')
        chunks.append(chunk)
        remaining -= len(chunk)
    return b''.join(chunks)


def run(conn, addr):
    """Serve one client connection until it disconnects or misbehaves.

    Each request is a 4-byte ``struct`` 'i' length followed by a UTF-8 JSON
    object whose ``function`` field selects a handler from ``func_dict``.

    Args:
        conn: Connected client socket; closed before this function returns.
        addr: Client address, used only in the disconnect message.
    """
    # exist_ok avoids a crash when two workers try to create the
    # directories at the same moment.
    os.makedirs(USER_PATH, exist_ok=True)
    os.makedirs(FILE_PATH, exist_ok=True)
    try:
        while True:
            # recv() may return fewer bytes than requested, so both the
            # header and the body are read with an exact-length helper.
            headers = _recv_exact(conn, 4)
            data_len = struct.unpack('i', headers)[0]
            request_msg = json.loads(_recv_exact(conn, data_len).decode('utf-8'))
            handler = func_dict.get(request_msg.get('function'))
            if handler is None:
                # An unknown request may be followed by payload bytes we
                # cannot delimit, so the session is ended rather than
                # answered.
                raise ValueError('unknown function requested')
            handler(request_msg, conn)
    except Exception:
        # Worker boundary: any failure (disconnect, malformed request,
        # handler error) ends this client's session only.
        print(f'来自{addr}的用户中断了连接!')
    finally:
        conn.close()
if __name__ == '__main__':
    # Accept loop: every client connection is served by its own process.
    with socket.socket() as server:
        server.bind(('127.0.0.1', 9527))
        server.listen(5)
        while True:
            conn, addr = server.accept()
            print(f'{addr}已连接上!')
            p = Process(target=run, args=(conn, addr))
            p.start()
            # The worker now holds its own handle to the connection. Close
            # the parent's copy so descriptors do not pile up and the peer
            # sees EOF as soon as the worker closes its side.
            conn.close()
# File-客户端(无改动)
import socket
import json
import struct
import os
def register(client):
    """Prompt for a user name and password (entered twice) and register them.

    Nothing is sent when the two passwords differ; otherwise the server's
    reply message is printed.
    """
    user_name = input('请输入用户名>>>').strip()
    user_pwd = input('请输入密码>>>').strip()
    re_pwd = input('请确认密码>>>').strip()

    if user_pwd != re_pwd:
        print('两次密码不一致!')
        return

    reply = send_msg(
        {
            'user_name': user_name,
            'user_pwd': user_pwd,
            'function': 'register',
        },
        client,
    )
    # Success and failure are reported the same way: show the server's text.
    print(reply.get('msg'))
def login(client):
    """Prompt for credentials, send a login request and print the reply text."""
    credentials = {
        'user_name': input('请输入用户名>>>').strip(),
        'user_pwd': input('请输入密码>>>').strip(),
        'function': 'login',
    }
    reply = send_msg(credentials, client)
    # The server's message is shown whether or not the login succeeded.
    print(reply.get('msg'))
def upload(client):
    """Prompt for a local file path and upload that file.

    An invalid path prints an error and returns; otherwise the request
    (name and size) plus the file body are sent and the server's reply text
    is printed.
    """
    file_path = input('请复制上传文件的绝对路径>>>').strip()
    if not os.path.isfile(file_path):
        print('请输入正确的文件路径!')
        return

    # NOTE(review): the server must have a handler registered under
    # 'accept_file' for this request to be served.
    request_upload = {
        'function': 'accept_file',
        'file_name': os.path.basename(file_path),
        'file_size': os.path.getsize(file_path),
    }
    reply = send_msg(request_upload, client, file_path)
    print(reply.get('msg'))
def download(client):
    """List the server's files, let the user pick one and download it.

    The chosen file is saved in a '下载的文件' directory next to this script.
    Invalid menu input re-displays the list; any other outcome returns.

    Raises:
        ConnectionError: If the server closes the connection before the
            announced number of bytes has arrived.
    """
    while True:
        response_msg = send_msg({'function': 'check_file'}, client)
        if not response_msg.get('flag'):
            print(response_msg.get('msg'))
            break

        file_list = response_msg.get('file_list')
        print('\n')
        print('编号 文件名 ')
        for (index, name) in enumerate(file_list, 1):
            print(f'{index} {name}')
        print('\n')

        file_choice = input('请选择下载的文件编号>>>').strip()
        # isdecimal() only accepts characters that int() can parse.
        if not file_choice.isdecimal():
            print('输入错误!')
            continue
        file_choice = int(file_choice) - 1
        # Validate against the file list itself, not the reply dict (whose
        # length is just its number of keys).
        if not 0 <= file_choice < len(file_list):
            print('输入错误!')
            continue

        request_download = {
            'file_name': file_list[file_choice],
            'function': 'sent_file',
        }
        download_msg = send_msg(request_download, client)
        if download_msg.get('flag'):
            save_dir = os.path.join(os.path.dirname(__file__), '下载的文件')
            os.makedirs(save_dir, exist_ok=True)
            # basename() keeps a server-supplied name from escaping save_dir.
            save_path = os.path.join(
                save_dir, os.path.basename(download_msg.get('file_name'))
            )
            remaining = download_msg.get('file_size')
            with open(save_path, 'wb') as fw:
                while remaining > 0:
                    # Never read past the file: later bytes belong to the
                    # next reply.
                    chunk = client.recv(min(1024, remaining))
                    if not chunk:
                        # b'' means the server is gone; stop instead of
                        # looping forever.
                        raise ConnectionError('server disconnected during download')
                    fw.write(chunk)
                    remaining -= len(chunk)
            print('文件下载完成!')
        elif download_msg.get('msg'):
            # Show the server's reason instead of failing silently.
            print(download_msg.get('msg'))
        break
def send_msg(msg, client, file=None):
    """Send a request (optionally followed by a file) and return the reply.

    Both directions use the same framing: a 4-byte ``struct`` 'i' length
    followed by a UTF-8 JSON payload.

    Args:
        msg: JSON-serialisable request object.
        client: Connected socket to the server.
        file: Optional path of a file whose raw bytes follow the request.

    Returns:
        The decoded JSON reply.

    Raises:
        ConnectionError: If the server closes the connection mid-reply.
    """
    def recv_exact(size):
        # recv() may return fewer bytes than asked for; keep reading.
        chunks = []
        remaining = size
        while remaining > 0:
            chunk = client.recv(remaining)
            if not chunk:
                raise ConnectionError('server closed the connection')
            chunks.append(chunk)
            remaining -= len(chunk)
        return b''.join(chunks)

    request_msg = json.dumps(msg).encode('utf-8')
    headers = struct.pack('i', len(request_msg))
    # sendall() guarantees the whole buffer is written; send() does not.
    client.sendall(headers + request_msg)
    if file:
        with open(file, 'rb') as fr:
            while True:
                chunk = fr.read(64 * 1024)
                if not chunk:
                    break
                client.sendall(chunk)

    data_len = struct.unpack('i', recv_exact(4))[0]
    response_msg = json.loads(recv_exact(data_len).decode('utf-8'))
    return response_msg
# Menu choice -> client action.
func_dict = {
    '1': login,
    '2': register,
    '3': upload,
    '4': download,
}

if __name__ == '__main__':
    # The context manager closes the socket on 'q' and on any error.
    with socket.socket() as client:
        client.connect(('127.0.0.1', 9527))
        while True:
            print('''
1. 登录
2. 注册
3. 上传
4. 下载
q. 退出
''')
            func_choice = input('请选择功能>>>').strip()
            if func_choice == 'q':
                break
            if func_choice not in func_dict:
                print('输入错误!')
                continue
            func_dict.get(func_choice)(client)


浙公网安备 33010602011771号