gevent协程、select IO多路复用 改造多用户FTP程序例子

python学习网 2018-01-17 12:55:02

原多线程版FTP程序:http://www.cnblogs.com/linzetong/p/8290378.html

只需要在原来的代码基础上稍作修改:

一、gevent协程版本

1、 导入gevent模块

import gevent
from gevent import monkey

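2、给 gevent 打猴子补丁:monkey.patch_all() 会把标准库里会阻塞的模块(socket、select、time 等)替换成 gevent 的协作式实现。这样原有代码基本不用改,也不需要特意去引入 gevent 专门的模块,阻塞调用在等待时就会自动切换到其他协程。官方建议尽早调用它,最好放在导入其他模块之前。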

# Replace blocking standard-library modules (socket, select, time, ...) with
# gevent's cooperative versions, so a blocking call switches to another
# greenlet instead of stalling the whole process.
monkey.patch_all()

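3、 把socket设置为非阻塞(注意:gevent 版本的 run() 是在循环里直接调用 accept() 的。监听 socket 超时为 0 时,没有新连接就会立刻抛出 BlockingIOError;而打过猴子补丁的阻塞 socket 在等待时本来就会切换协程。所以 gevent 版本其实应保持阻塞模式,只有下文的 select 版本才需要这一步。)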

# gevent version: leave the listening socket in blocking mode. run() calls
# self.sock.accept() directly in a loop, and with a timeout of 0 accept()
# raises BlockingIOError as soon as no client is waiting. After
# monkey.patch_all() a blocking socket call already yields to other
# greenlets while it waits, so blocking mode costs no concurrency.
self.sock.setblocking(True)

4、 修改run函数:accept() 得到新连接后,用 gevent.spawn 为它启动一个协程来处理:

 # gevent: handle each connection in its own greenlet (one thread, many concurrent clients).
 # self.request / self.cli_addr must have been set by self.sock.accept() earlier in the loop.
gevent.spawn(TCPHandler.handle, TCPHandler(), self.request, self.cli_addr)
其他不用更改

 

二、select IO多路复用版本

1、 导入select模块

import select

2、 把socket设置为非阻塞

# select version: the listener may be non-blocking because run() only calls
# accept() after select() has reported it readable.
self.sock.setblocking(False)

3、 修改run函数,用select.select()方法接收并监控多个通信socket列表

def run(self):
    """Serve all clients forever from one thread, multiplexed with select().

    The listening socket and every connected client socket are watched for
    readability. A readable listener means a pending connection; a readable
    client means a request, which TCPHandler.handle() processes completely
    before the loop moves on.
    """
    import traceback

    inputs = [self.sock]
    while True:
        readable, _, exceptional = select.select(inputs, [], inputs)
        for r in readable:
            if r is self.sock:
                try:
                    request, _addr = self.sock.accept()
                except (BlockingIOError, ConnectionAbortedError):
                    continue  # the client went away between select() and accept()
                # The handler reads with blocking recv() loops, but an accepted
                # socket can inherit the listener's non-blocking flag on
                # BSD-derived systems, so switch it back explicitly.
                request.setblocking(True)
                inputs.append(request)
                continue
            print('处理request:%s' % id(r))
            try:
                outcome = TCPHandler().handle(r)
            except Exception:
                # One misbehaving client must not take down the whole
                # single-process server: report it and drop that client.
                traceback.print_exc()
                outcome = (False, r)
            # handle() returns (keep_alive, request); it can also return None
            # (e.g. for a command it does not dispatch) -- keep the client then.
            if outcome is not None and not outcome[0]:
                try:
                    peer = r.getpeername()  # must be read before close()
                except OSError:
                    peer = None  # peer already gone
                # Stop watching it, otherwise the next select() fails on a closed socket.
                inputs.remove(r)
                r.close()
                print('client[%s] is disconect' % (peer,))
        for r in exceptional:
            if r is not self.sock and r in inputs:
                inputs.remove(r)
                r.close()
4、完整代码:

server.py

  1 # -*- coding: utf-8 -*-
  2 import socket
  3 import os, json, re, struct, threading, time
  4 import gevent
  5 from gevent import monkey
  6 import select
  7 from lib import commons
  8 from conf import settings
  9 from core import logger
 10 
 11 monkey.patch_all()
 12 
 13 
 14 class Server(object):
 15     def __init__(self):
 16         self.init_dir()
 17         self.sock = socket.socket(family=socket.AF_INET, type=socket.SOCK_STREAM)
 18         self.sock.setblocking(0)  # select实现同时处理请求,需要设置为非阻塞
 19         # self.sock.setsockopt(socket.SOL_SOCKET, socket.SO_REUSEADDR, 1)
 20         self.sock.bind((settings.server_bind_ip, settings.server_bind_port))
 21         self.sock.listen(settings.server_listen)
 22         print("\033[42;1mserver started sucessful!\033[0m")
 23         self.run()
 24 
 25     @staticmethod
 26     def init_dir():
 27         if not os.path.exists(os.path.join(settings.base_path, 'logs')): os.mkdir(
 28             os.path.join(settings.base_path, 'logs'))
 29         if not os.path.exists(os.path.join(settings.base_path, 'db')): os.mkdir(os.path.join(settings.base_path, 'db'))
 30         if not os.path.exists(os.path.join(settings.base_path, 'home')): os.mkdir(
 31             os.path.join(settings.base_path, 'home'))
 32 
 33     def run(self):
 34         while True:  # 链接循环
 35             # select 单进程实现同时处理请求
 36             inputs = [self.sock, ]
 37             outputs = []
 38             while True:
 39                 readable, writeable, exceptional = select.select(inputs, outputs, inputs)
 40                 for r in readable:
 41                     if r is self.sock:
 42                         request, client_address = self.sock.accept()
 43                         inputs.append(request)
 44                     else:
 45                         print('处理request:%s'%id(r))
 46                         return_code, request = TCPHandler().handle(r, )
 47                         if not return_code:
 48                             request.close()
 49                             print('client[%s] is disconect' % ((request.getpeername()),))
 50                 # self.request, self.cli_addr = self.sock.accept()
 51                 # self.request.settimeout(300)
 52                 # 多线程处理请求
 53                 # thread = threading.Thread(target=TCPHandler.handle, args=(TCPHandler(), self.request, self.cli_addr))
 54                 # thread.start()
 55                 # gevent 实现单线程多并发
 56                 # gevent.spawn(TCPHandler.handle, TCPHandler(), self.request, self.cli_addr)
 57 
 58 
 59 class TCPHandler(object):
 60     STATUS_CODE = {
 61         200: 'Passed authentication!',
 62         201: 'Wrong username or password!',
 63         202: 'Username does not exist!',
 64         300: 'cmd successful , the target path be returned in returnPath',
 65         301: 'cmd format error!',
 66         302: 'The path or file could not be found!',
 67         303: 'The dir is exist',
 68         304: 'The file has been downloaded or the size of the file is exceptions',
 69         305: 'Free space is not enough',
 70         401: 'File MD5 inspection failed',
 71         400: 'File MD5 inspection success',
 72     }
 73 
 74     def __init__(self):
 75         self.server_logger = logger.logger('server')
 76         self.server_logger.debug("server TCPHandler started successful!")
 77 
 78     def handle(self, request, address=(None, None)):
 79         self.request = request
 80         self.cli_addr = request.getpeername()
 81         self.server_logger.info('client[%s] is conecting' % ((request.getpeername()),))
 82         print('client[%s] is conecting' % ((request.getpeername()),))
 83         # while True:  # 通讯循环
 84         try:
 85             # 1、接收客户端的ftp命令
 86             print("waiting receive client[%s] ftp command.." % ((request.getpeername()),), id(self), self)
 87             header_dic, req_dic = self.recv_request()
 88             if not header_dic: return False, request
 89             if not header_dic['cmd']: return False, request
 90             print('receive client ftp command:%s' % header_dic['cmd'])
 91             # 2、解析ftp命令,获取相应命令参数(文件名)
 92             cmds = header_dic['cmd'].split()  # ['register',]、['get', 'a.txt']
 93             if hasattr(self, cmds[0]):
 94                 self.server_logger.info('interface:[%s], request:{client:[%s:%s] action:[%s]}' % (
 95                     cmds[0], self.cli_addr[0], self.cli_addr[1], header_dic['cmd']))
 96                 getattr(self, cmds[0])(header_dic, req_dic)
 97                 return True, request
 98         except (ConnectionResetError, ConnectionAbortedError):
 99             return False, request
100         except socket.timeout:
101             print('time out %s' % ((request.getpeername()),))
102             return False, request
103         # self.request.close()
104         # self.server_logger.info('client %s is disconect' % ((self.cli_addr,)))
105         # print('client[%s:%s] is disconect' % (self.cli_addr[0], self.cli_addr[1]))
106 
107     def unpack_header(self):
108         try:
109             pack_obj = self.request.recv(4)
110             header_size = struct.unpack('i', pack_obj)[0]
111             header_bytes = self.request.recv(header_size)
112             header_json = header_bytes.decode('utf-8')
113             header_dic = json.loads(header_json)
114             return header_dic
115         except struct.error:  # 避免客户端发送错误格式的header_size
116             return
117 
118     def unpack_info(self, info_size):
119         recv_size = 0
120         info_bytes = b''
121         while recv_size < info_size:
122             res = self.request.recv(1024)
123             info_bytes += res
124             recv_size += len(res)
125         info_json = info_bytes.decode('utf-8')
126         info_dic = json.loads(info_json)  # {'username':ton, 'password':123}
127         info_md5 = commons.getStrsMd5(info_bytes)
128         return info_dic, info_md5
129 
130     def recv_request(self):
131         header_dic = self.unpack_header()  # {'cmd':'register','info_size':0}
132         if not header_dic: return None, None
133         req_dic, info_md5 = self.unpack_info(header_dic['info_size'])
134         if header_dic.get('md5'):
135             # 校检请求内容md5一致性
136             if info_md5 == header_dic['md5']:
137                 pass
138             # print('\033[42;1m请求内容md5校检结果一致\033[0m')
139             else:
140                 pass
141                 # print('\033[31;1m请求内容md5校检结果不一致\033[0m')
142         return header_dic, req_dic
143 
144     def response(self, **kwargs):
145         rsp_info = kwargs
146         rsp_bytes = commons.getDictBytes(rsp_info)
147         md5 = commons.getStrsMd5(rsp_bytes)
148         header_size_pack, header_bytes = commons.make_header(info_size=len(rsp_bytes), md5=md5)
149         self.request.sendall(header_size_pack)
150         self.request.sendall(header_bytes)
151         self.request.sendall(rsp_bytes)
152 
153     def register(self, header_dic, req_dic):  # {'cmd':'register','info_size':0,'resultCode':0,'resultDesc':None}
154         username = req_dic['user_info']['username']
155         # 更新数据库,并制作响应信息字典
156         if not os.path.isfile(os.path.join(settings.db_file, '%s.json' % username)):
157             # 更新数据库
158             user_info = dict()
159             user_info['username'] = username
160             user_info['password'] = req_dic['user_info']['password']
161             user_info['home'] = os.path.join(settings.user_home_dir, username)
162             user_info['quota'] = settings.user_quota * (1024 * 1024)
163             commons.save_to_file(user_info, os.path.join(settings.db_file, '%s.json' % username))
164             resultCode = 0
165             resultDesc = None
166             # 创建家目录
167             if not os.path.exists(os.path.join(settings.user_home_dir, username)):
168                 os.mkdir(os.path.join(settings.user_home_dir, username))
169             self.server_logger.info('client[%s:%s] 注册用户[%s]成功' % (self.cli_addr[0], self.cli_addr[1], username))
170         else:
171             resultCode = 1
172             resultDesc = '该用户已存在,注册失败'
173             self.server_logger.warning('client[%s:%s] 注册用户[%s]失败:%s' % (self.cli_addr[0], self.cli_addr[1],
174                                                                         username, resultDesc))
175         # 响应客户端注册请求
176         self.response(resultCode=resultCode, resultDesc=resultDesc)
177 
178     @staticmethod
179     def auth(req_dic):
180         # print(req_dic['user_info'])
181         user_info = None
182         status_code = 201
183         try:
184             req_username = req_dic['user_info']['username']
185             db_file = os.path.join(settings.db_file, '%s.json' % req_username)
186             # 验证用户名密码,并制作响应信息字典
187             if not os.path.isfile(db_file):
188                 status_code = 202
189             else:
190                 with open(db_file, 'r') as f:
191                     user_info_db = json.load(f)
192                 if user_info_db['password'] == req_dic['user_info']['password']:
193                     status_code = 200
194                     user_info = user_info_db
195             return status_code, user_info
196         # 捕获  客户端鉴权请求时发送一个空字典或错误的字典  的异常
197         except KeyError:
198             return 201, user_info
199 
200     def login(self, header_dic, req_dic):
201         # 鉴权
202         status_code, user_info = self.auth(req_dic)
203         # 响应客户端登陆请求
204         self.response(user_info=user_info, resultCode=status_code)
205 
206     def query_quota(self, header_dic, req_dic):
207         used_quota = None
208         total_quota = None
209         # 鉴权
210         status_code, user_info = self.auth(req_dic)
211         # 查询配额
212         if status_code == 200:
213             used_quota = commons.getFileSize(user_info['home'])
214             total_quota = user_info['quota']
215         # 响应客户端配额查询请求
216         self.response(resultCode=status_code, total_quota=total_quota, used_quota=used_quota)
217 
218     @staticmethod
219     def parse_file_path(req_path, cur_path):
220         req_path = req_path.replace(r'/', '\\')
221         req_path = req_path.replace(r'//', r'/', )
222         req_path = req_path.replace('\\\\', '\\')
223         req_path = req_path.replace('~\\', '', 1)
224         req_path = req_path.replace(r'~', '', 1)
225         req_paths = re.findall(r'[^\\]+', req_path)
226         cur_paths = re.findall(r'[^\\]+', cur_path)
227         cur_paths.extend(req_paths)
228         cur_paths[0] += '\\'
229         while '.' in cur_paths:
230             cur_paths.remove('.')
231         while '..' in cur_paths:
232             for index, item in enumerate(cur_paths):
233                 if item == '..':
234                     cur_paths.pop(index)
235                     cur_paths.pop(index - 1)
236                     break
237         return cur_paths
238 
239     def cd(self, header_dic, req_dic):
240         cmds = header_dic['cmd'].split()
241         # 鉴权
242         status_code, user_info = self.auth(req_dic)
243         home = os.path.join(settings.user_home_dir, user_info['username'])
244         # 先定义响应信息
245         returnPath = req_dic['user_info']['cur_path']
246         if status_code == 200:
247             if len(cmds) != 1:
248                 # 解析cd的真实路径
249                 cur_path = os.path.join(settings.user_home_dir, req_dic['user_info']['cur_path'])
250                 cd_path = os.path.join('', *self.parse_file_path(cmds[1], cur_path))
251                 print('cd解析后的路径:', cd_path)
252                 if os.path.isdir(cd_path):
253                     if home in cd_path:
254                         resultCode = 300
255                         returnPath = cd_path.replace('%s\\' % settings.user_home_dir, '', 1)
256                     else:
257                         resultCode = 302
258                 else:
259                     resultCode = 302
260             else:
261                 resultCode = 301
262         else:
263             resultCode = 201
264         # 响应客户端的cd命令结果
265         print('cd发送给客户端的路径:', returnPath)
266         self.response(resultCode=resultCode, returnPath=returnPath)
267 
268     def ls(self, header_dic, req_dic):
269         cmds = header_dic['cmd'].split()
270         # 鉴权
271         status_code, user_info = self.auth(req_dic)
272         home = os.path.join(settings.user_home_dir, user_info['username'])
273         # 先定义响应信息
274         returnFilenames = None
275         if status_code == 200:
276             if len(cmds) <= 2:
277                 # 解析ls的真实路径
278                 cur_path = os.path.join(settings.user_home_dir, req_dic['user_info']['cur_path'])
279                 if len(cmds) == 2:
280                     ls_path = os.path.join('', *self.parse_file_path(cmds[1], cur_path))
281                 else:
282                     ls_path = cur_path
283                 print('ls解析后的路径:', ls_path)
284                 if os.path.isdir(ls_path):
285                     if home in ls_path:
286                         returnCode, filenames = commons.getFile(ls_path, home)
287                         resultCode = 300
288                         returnFilenames = filenames
289                     else:
290                         resultCode = 302
291                 else:
292                     resultCode = 302
293             else:
294                 resultCode = 301
295         else:
296             resultCode = 201
297         # 响应客户端的ls命令结果
298         time.sleep(5)
299         self.response(resultCode=resultCode, returnFilenames=returnFilenames)
300 
301     def rm(self, header_dic, req_dic):
302         cmds = header_dic['cmd'].split()
303         # 鉴权
304         status_code, user_info = self.auth(req_dic)
305         home = os.path.join(settings.user_home_dir, user_info['username'])
306         # 先定义响应信息
307         if status_code == 200:
308             if len(cmds) == 2:
309                 # 解析rm的真实路径
310                 cur_path = os.path.join(settings.user_home_dir, req_dic['user_info']['cur_path'])
311                 rm_path = os.path.join('', *self.parse_file_path(os.path.dirname(cmds[1]), cur_path))
312                 rm_file = os.path.join(rm_path, os.path.basename(cmds[1]))
313                 print('rm解析后的文件或文件夹:', rm_file)
314                 if os.path.exists(rm_file):
315                     if home in rm_file:
316                         commons.rmdirs(rm_file)
317                         resultCode = 300
318                     else:
319                         resultCode = 302
320                 else:
321                     resultCode = 302
322             else:
323                 resultCode = 301
324         else:
325             resultCode = 201
326         # 响应客户端的rm命令结果
327         self.response(resultCode=resultCode)
328 
329     def mkdir(self, header_dic, req_dic):
330         cmds = header_dic['cmd'].split()
331         # 鉴权
332         status_code, user_info = self.auth(req_dic)
333         home = os.path.join(settings.user_home_dir, user_info['username'])
334         # 先定义响应信息
335         if status_code == 200:
336             if len(cmds) == 2:
337                 # 解析rm的真实路径
338                 cur_path = os.path.join(settings.user_home_dir, req_dic['user_info']['cur_path'])
339                 mkdir_path = os.path.join('', *self.parse_file_path(cmds[1], cur_path))
340                 print('mkdir解析后的文件夹:', mkdir_path)
341                 if not os.path.isdir(mkdir_path):
342                     if home in mkdir_path:
343                         os.makedirs(mkdir_path)
344                         resultCode = 300
345                     else:
346                         resultCode = 302
347                 else:
348                     resultCode = 303
349             else:
350                 resultCode = 301
351         else:
352             resultCode = 201
353         # 响应客户端的mkdir命令结果
354         self.response(resultCode=resultCode)
355 
356     def get(self, header_dic, req_dic):
357         """客户端下载文件"""
358         cmds = header_dic['cmd'].split()  # ['get', 'a.txt', 'download']
359         get_file = None
360         # 鉴权
361         status_code, user_info = self.auth(req_dic)
362         home = os.path.join(settings.user_home_dir, user_info['username'])
363         # 解析断点续传信息
364         position = 0
365         if req_dic['resume'] and isinstance(req_dic['position'], int):
366             position = req_dic['position']
367         # 先定义响应信息
368         resultCode = 300
369         FileSize = None
370         FileMd5 = None
371         if status_code == 200:
372             if 1 < len(cmds) < 4:
373                 # 解析需要get文件的真实路径
374                 cur_path = os.path.join(settings.user_home_dir, req_dic['user_info']['cur_path'])
375                 get_file = os.path.join('', *self.parse_file_path(cmds[1], cur_path))
376                 print('get解析后的路径:', get_file)
377                 if os.path.isfile(get_file):
378                     if home in get_file:
379                         FileSize = commons.getFileSize(get_file)
380                         if position >= FileSize != 0:
381                             resultCode = 304
382                         else:
383                             resultCode = 300
384                             FileSize = FileSize
385                             FileMd5 = commons.getFileMd5(get_file)
386                     else:
387                         resultCode = 302
388                 else:
389                     resultCode = 302
390             else:
391                 resultCode = 301
392         else:
393             resultCode = 201
394         # 响应客户端的get命令结果
395         self.response(resultCode=resultCode, FileSize=FileSize, FileMd5=FileMd5)
396         if resultCode == 300:
397             # 发送文件数据
398             with open(get_file, 'rb') as f:
399                 f.seek(position)
400                 for line in f:
401                     self.request.send(line)
402 
403     def put(self, header_dic, req_dic):
404         cmds = header_dic['cmd'].split()  # ['put', 'download/a.txt', 'video']
405         put_file = None
406         # 鉴权
407         status_code, user_info = self.auth(req_dic)
408         home = os.path.join(settings.user_home_dir, user_info['username'])
409         # 查询配额
410         used_quota = commons.getFileSize(user_info['home'])
411         total_quota = user_info['quota']
412         # 先定义响应信息
413         if status_code == 200:
414             if 1 < len(cmds) < 4:
415                 # 解析需要put文件的真实路径
416                 cur_path = os.path.join(settings.user_home_dir, req_dic['user_info']['cur_path'])
417                 if len(cmds) == 3:
418                     put_file = os.path.join(os.path.join('', *self.parse_file_path(cmds[2], cur_path)),
419                                             os.path.basename(cmds[1]))
420                 else:
421                     put_file = os.path.join(cur_path, os.path.basename(cmds[1]))
422                 print('put解析后的文件:', put_file)
423                 put_path = os.path.dirname(put_file)
424                 if os.path.isdir(put_path):
425                     if home in put_path:
426                         if (req_dic['FileSize'] + used_quota) <= total_quota:
427                             resultCode = 300
428                         else:
429                             resultCode = 305
430                     else:
431                         resultCode = 302
432                 else:
433                     resultCode = 302
434             else:
435                 resultCode = 301
436         else:
437             resultCode = 201
438         # 响应客户端的put命令结果
439         self.response(resultCode=resultCode)
440         if resultCode == 300:
441             # 接收文件数据,写入文件
442             recv_size = 0
443             with open(put_file, 'wb') as f:
444                 while recv_size < req_dic['FileSize']:
445                     file_data = self.request.recv(1024)
446                     f.write(file_data)
447                     recv_size += len(file_data)
448             # 校检文件md5一致性
449             if commons.getFileMd5(put_file) == req_dic['FileMd5']:
450                 resultCode = 400
451                 print('\033[42;1m文件md5校检结果一致\033[0m')
452                 print('\033[42;1m文件上传成功,大小:%d,文件名:%s\033[0m' % (req_dic['FileSize'], put_file))
453             else:
454                 os.remove(put_file)
455                 resultCode = 401
456                 print('\033[31;1m文件md5校检结果不一致\033[0m')
457                 print('\033[42;1m文件上传失败\033[0m')
458             # 返回上传文件是否成功响应
459             self.response(resultCode=resultCode)
server.py

 





阅读(812) 评论(0)