Preface:

                                                   

 

 

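Most companies split their network into an office network and a production network, mainly to keep operations on production safe.

External users can reach only specific open ports on the production network. So what controls their access to it?

Firewall filtering.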

 

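How, then, can internal staff securely access production systems for routine operations and code deployment, and how can the actions of engineers be monitored and recorded?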

 

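Bastion host strategy:

1. Revoke all usernames and passwords used for remote login to the Linux hosts.

2. Place a bastion host in between, which stores the usernames and passwords of all production Linux hosts.

3. All engineers obtain credentials through the bastion host before connecting to a production system, and their operations are logged.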

 

Advantages of the bastion host strategy:

1. User operations are recorded.

2. Remote access permissions are managed in one place.

 

I. Bastion host table design

 

from django.db import models
from django.contrib.auth.models import  User
# Create your models here.


class IDC(models.Model):
    name = models.CharField(max_length=64,unique=True)
    def __str__(self):
        return self.name

class Host(models.Model):
    """Stores information about every host"""
    hostname = models.CharField(max_length=64,unique=True)
    ip_addr = models.GenericIPAddressField(unique=True)
    port = models.IntegerField(default=22)
    idc = models.ForeignKey("IDC",on_delete=models.CASCADE)  # on_delete is required from Django 2.0 on
    #host_groups = models.ManyToManyField("HostGroup")
    #host_users = models.ManyToManyField("HostUser")
    enabled = models.BooleanField(default=True)

    def __str__(self):
        return "%s-%s" %(self.hostname,self.ip_addr)

class HostGroup(models.Model):
    """Host group"""
    name = models.CharField(max_length=64,unique=True)
    host_user_binds  = models.ManyToManyField("HostUserBind")
    def __str__(self):
        return self.name


class HostUser(models.Model):
    """Stores the login users of remote hosts, for example:
    root 123
    root abc
    root sfsfs
    """
    auth_type_choices = ((0,'ssh-password'),(1,'ssh-key'))
    auth_type = models.SmallIntegerField(choices=auth_type_choices)
    username = models.CharField(max_length=32)
    password = models.CharField(blank=True,null=True,max_length=128)

    def __str__(self):
        # The password is left out so that it is not shown in menus or logs
        return "%s-%s" %(self.get_auth_type_display(),self.username)

    class Meta:
        unique_together = ('username','password')


class HostUserBind(models.Model):
    """Binds a host to a remote user"""
    host = models.ForeignKey("Host",on_delete=models.CASCADE)
    host_user = models.ForeignKey("HostUser",on_delete=models.CASCADE)

    def __str__(self):
        return "%s-%s" %(self.host,self.host_user)

    class Meta:
        unique_together = ('host','host_user')


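class SessionLog(models.Model):
    ''' Records each login session; the ID is passed to the shell script and used to name the log file '''
    account=models.ForeignKey('Account',on_delete=models.CASCADE)
    host_user_bind=models.ForeignKey('HostUserBind',on_delete=models.CASCADE)
    # DateTimeField keeps the time of day, which a DateField would drop
    start_date=models.DateTimeField(auto_now_add=True)
    end_date=models.DateTimeField(blank=True,null=True)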

    def __str__(self):
        return '%s-%s'%(self.account,self.host_user_bind)

class AuditLog(models.Model):
    """Audit log"""


class Account(models.Model):
    """Bastion host account
    1. extension (one-to-one link to User)
    2. inheritance
    user.account.host_user_binds
    """

    user = models.OneToOneField(User,on_delete=models.CASCADE)
    name = models.CharField(max_length=64)

    host_user_binds = models.ManyToManyField("HostUserBind",blank=True)
    host_groups = models.ManyToManyField("HostGroup",blank=True)
models.py
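
The permission model in these tables can be illustrated without Django. The following is a minimal sketch in plain Python, assuming that an account may use every bind assigned to it directly plus every bind in its host groups, which is how the interactive menu later reads the data. The class names `Bind`, `Group` and `BastionAccount` and all sample hosts are hypothetical stand-ins for `HostUserBind`, `HostGroup` and `Account`.

```python
from dataclasses import dataclass, field


@dataclass(frozen=True)
class Bind:
    """One (host, remote user) pair, mirroring HostUserBind."""
    host: str
    remote_user: str


@dataclass
class Group:
    """A named set of binds, mirroring HostGroup."""
    name: str
    binds: set = field(default_factory=set)


@dataclass
class BastionAccount:
    """A bastion account, mirroring Account."""
    name: str
    binds: set = field(default_factory=set)     # directly assigned binds
    groups: list = field(default_factory=list)  # assigned host groups


def accessible_binds(account):
    """Return every bind the account may use: direct binds plus group binds."""
    result = set(account.binds)
    for group in account.groups:
        result |= group.binds
    return result


# Hypothetical sample data
web = Group("web", {Bind("web01", "root"), Bind("web02", "deploy")})
alice = BastionAccount("alice", {Bind("db01", "dba")}, [web])
for bind in sorted(accessible_binds(alice), key=lambda b: b.host):
    print(bind.host, bind.remote_user)
```

Because permissions attach to the (host, remote user) pair rather than to the host, the same machine can be offered to one account as root and to another as a restricted user.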

 

 

 II. Logging in to remote Linux hosts through the bastion host

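There are two ways to log in through the bastion host:

Logging in from the command line:

Method 1: patch the OpenSSH source to add a -Z option so that each ssh process can be uniquely identified, then trace that process with the Linux strace command to generate a log file.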

 

0. The user runs audit_shell and gets an interactive prompt asking for a host group and a host.

import sys,os,django
os.environ.setdefault("DJANGO_SETTINGS_MODULE","zhanggen_audit.settings")
django.setup()  # load the settings and app registry so Django can be used outside a view
from audit.backend import user_interactive


if __name__ == '__main__':
    shell_obj=user_interactive.UserShell(sys.argv)
    shell_obj.start()
audit_shell.py
import getpass

from django.contrib.auth import authenticate

class UserShell(object):
    '''Custom shell started when a user logs in to the bastion host'''
    def __init__(self,sys_argv):
        self.sys_argv=sys_argv
        self.user=None

    def auth(self):
        count=0
        while count < 3:
            username=input('username:').strip()
            password=getpass.getpass('password:').strip()  # do not echo the password
            user=authenticate(username=username,password=password)
            # None means authentication failed; a user object means success
            if not user:
                count+=1
                print('Invalid username or password!')
            else:
                self.user=user
                return True
        else:
            print('More than 3 failed attempts!')


    def start(self):
        """Start the interactive program"""

        if self.auth():
            # print(self.user.account.host_user_binds.all()) #select_related()
            while True:
                host_groups = self.user.account.host_groups.all()
                for index, group in enumerate(host_groups):
                    print("%s.\t%s[%s]" % (index, group, group.host_user_binds.count()))
                print("%s.\tUngrouped hosts[%s]" % (len(host_groups), self.user.account.host_user_binds.count()))

                choice = input("select group>:").strip()
                if choice.isdigit():
                    choice = int(choice)
                    host_bind_list = None
                    if choice >= 0 and choice < len(host_groups):
                        selected_group = host_groups[choice]
                        host_bind_list = selected_group.host_user_binds.all()
                    elif choice == len(host_groups):  # the "ungrouped hosts" entry was selected
                        # selected_group = self.user.account.host_user_binds.all()
                        host_bind_list = self.user.account.host_user_binds.all()

                    if host_bind_list:
                        while True:
                            for index, host in enumerate(host_bind_list):
                                print("%s.\t%s" % (index, host,))
                            choice2 = input("select host>:").strip()
                            if choice2.isdigit():
                                choice2 = int(choice2)
                                if choice2 >= 0 and choice2 < len(host_bind_list):
                                    selected_host = host_bind_list[choice2]
                                    print("selected host", selected_host)
                            elif choice2 == 'b':
                                break
user_interactive.py

 

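Key point:

To use Django outside a view, set the DJANGO_SETTINGS_MODULE environment variable and call django.setup() first (remember to put the script in the same directory as Django's manage.py).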

import sys,os,django
os.environ.setdefault("DJANGO_SETTINGS_MODULE","Sensors_Data.settings")
django.setup()  # load the settings and app registry so Django can be used outside a view
from app01 import models
objs=models.AlarmInfo.objects.all()
for row in objs:
    print(row.comment)

 

 

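Note: Django automatically loads certain files at startup, such as admin.py in each app. Do not set the environment variable and call django.setup() inside those files: loading is already under way at that point, and breaking this rule makes Django fail to start.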

 

 

 

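1. Tracing the commands a user types over ssh

1.0 Patch the OpenSSH source to add a unique identifier option, ssh -Z. (Every remote ssh login can then carry a unique identifier, which makes each ssh session process distinguishable.)

In ssh.c of the OpenSSH source, modify lines 608 and 609 and add the case at line 935: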
    while ((opt = getopt(ac, av, "1246ab:c:e:fgi:kl:m:no:p:qstvxz:"
        "ACD:E:F:GI:J:KL:MNO:PQ:R:S:TVw:W:XYyZ:")) != -1) {

        case 'Z':
            break;
ssh.c 

 

Key point:

OpenSSH is a free, open-source implementation of the SSH (Secure Shell) protocol.

 

 

1.1 After patching OpenSSH, compile and install it

chmod 755 configure
./configure --prefix=/usr/local/openssh
make
chmod 755 mkinstalldirs
make install
sshpass -p xxxxxx123 /usr/local/openssh/bin/ssh root@172.17.10.112 -Z s1123ssssd212

 

 

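1.2 Now that each ssh session process can be uniquely identified, a shell script on the bastion host looks for the session process, attaches strace to it, and produces a log file.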

#!/usr/bin/bash

for i in $(seq 1 30);do
    echo $i $1
    process_id=`ps -ef | grep $1 | grep -v 'ession_check.sh' | grep -v grep | grep -v sshpass | awk '{print $2}'`

    echo "process: $process_id"

    if [ ! -z "$process_id" ];then
        echo 'start run strace.....'
        strace -fp $process_id -t -o $2.log;
        break;
    fi

    sleep 5

done;
SSH session detection script
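
The process lookup in the script can also be expressed in Python. This is a minimal sketch, assuming the usual eight-column `ps -ef` layout (UID PID PPID C STIME TTY TIME CMD); the function name `find_session_pids` and the PIDs in the sample are hypothetical. Unlike the shell pipeline, which greps whole lines, it matches the tag against the command column only.

```python
def find_session_pids(ps_output, tag, exclude=("grep", "sshpass", "session_check.sh")):
    """Return PIDs from `ps -ef` output whose command line contains the tag."""
    pids = []
    for line in ps_output.splitlines():
        fields = line.split(None, 7)  # UID PID PPID C STIME TTY TIME CMD
        if len(fields) < 8 or not fields[1].isdigit():
            continue  # header or malformed line
        command = fields[7]
        if tag in command and not any(word in command for word in exclude):
            pids.append(int(fields[1]))
    return pids


# Hypothetical ps output: the sshpass wrapper, the real ssh client, and a grep
sample = """\
UID        PID  PPID  C STIME TTY          TIME CMD
root      4101  4000  0 16:39 pts/0    00:00:00 sshpass -p xxxxxx /usr/local/openssh/bin/ssh root@172.17.10.112 -Z s1123ssssd212
root      4102  4101  0 16:39 pts/1    00:00:00 /usr/local/openssh/bin/ssh root@172.17.10.112 -Z s1123ssssd212
root      4110  4000  0 16:39 pts/0    00:00:00 grep s1123ssssd212
"""
print(find_session_pids(sample, "s1123ssssd212"))  # [4102]
```

Only the ssh client itself is kept: the sshpass wrapper carries the same tag on its command line, which is why both the script and this sketch filter it out.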

 

Key point:

strace traces a process's I/O system calls, which exposes the characters the user types into the shell.

 strace -fp 60864 -o /ssh.log 
 cat /ssh.log |grep 'write(8'
 rz -E # upload a file from Xshell
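
To make the log format concrete, here is a minimal sketch that pulls the payload out of a read or write line. It assumes the line layout produced by `strace -f -t -o` (PID, clock time, then the call), which is the layout the log parser later in this article unpacks; the name `parse_io_line` and the sample lines are hypothetical.

```python
import re

# Matches lines such as: 60864 16:39:53 read(4, "l", 1) = 1
STRACE_IO = re.compile(
    r'^(?P<pid>\d+)\s+(?P<time>\d{2}:\d{2}:\d{2})\s+'
    r'(?P<call>read|write)\((?P<fd>\d+),\s+"(?P<data>(?:[^"\\]|\\.)*)"'
)


def parse_io_line(line):
    """Return (pid, time, call, fd, data) for a read/write line, else None."""
    m = STRACE_IO.match(line)
    if m is None:
        return None
    return int(m["pid"]), m["time"], m["call"], int(m["fd"]), m["data"]


print(parse_io_line(r'60864 16:39:53 read(4, "l", 1) = 1'))
print(parse_io_line(r'60864 16:39:54 select(9, [4], NULL, NULL, NULL) = 1'))
```

The data field is returned exactly as strace printed it, so an Enter keypress shows up as the two characters `\r` rather than as a carriage return.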

 

 Logging in with sshpass without a password prompt

[root@localhost sshpass-1.06]# sshpass -p wsnb ssh root@172.16.22.1  -o StrictHostKeyChecking=no 
Last login: Tue Jul 10 16:39:53 2018 from 192.168.113.84
[root@ecdb ~]# 

 

Generating a unique identifier in Python

import random,string

s=string.ascii_lowercase+string.digits
random_tag=''.join(random.sample(s,10))
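
The `random` module is not designed for security-sensitive values. If the tag should be hard to guess, the standard-library `secrets` module is one alternative; the following is a sketch of that substitution, not what the original project uses, and `make_session_tag` is a hypothetical name. Note one difference: `random.sample` never repeats a character, while this version may.

```python
import secrets
import string

TAG_ALPHABET = string.ascii_lowercase + string.digits


def make_session_tag(length=10):
    """Return a random tag of lowercase letters and digits."""
    return "".join(secrets.choice(TAG_ALPHABET) for _ in range(length))


print(make_session_tag())
```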

 

 

Letting ordinary users run the strace command:

Option 1: set the setuid bit on the executable

chmod u+s `which strace`

 

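Option 2: edit the sudo configuration so that ordinary users can run sudo without entering a password.

Be sure to switch to the root user before editing the sudo configuration, to avoid mistakes while changing it.


%<ordinary-user-group>  ALL=(ALL)       NOPASSWD: ALL

wq! # save and quit
vim /etc/sudoers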

 

 

 

 

 

#!/usr/bin/python3
# -*- coding: utf-8 -*
from django.contrib.auth import authenticate
import getpass,subprocess,string,random
from audit import models
from django.conf import settings
class UserShell(object):
    '''Custom shell started when a user logs in to the bastion host'''
    def __init__(self,sys_argv):
        self.sys_argv=sys_argv
        self.user=None

    def auth(self):
        count=0
        while count < 3:
            username=input('username:').strip()
            password=getpass.getpass('password:').strip()  # do not echo the password
            user=authenticate(username=username,password=password)
            # None means authentication failed; a user object means success
            if not user:
                count+=1
                print('Invalid username or password!')
            else:
                self.user=user
                return True
        else:
            print('More than 3 failed attempts!')


    def start(self):
        """Start the interactive program"""

        if self.auth():
            # print(self.user.account.host_user_binds.all()) #select_related()
            while True:
                host_groups = self.user.account.host_groups.all()
                for index, group in enumerate(host_groups):
                    print("%s.\t%s[%s]" % (index, group, group.host_user_binds.count()))
                print("%s.\tUngrouped hosts[%s]" % (len(host_groups), self.user.account.host_user_binds.count()))

                choice = input("select group>:").strip()
                if choice.isdigit():
                    choice = int(choice)
                    host_bind_list = None
                    if choice >= 0 and choice < len(host_groups):
                        selected_group = host_groups[choice]
                        host_bind_list = selected_group.host_user_binds.all()
                    elif choice == len(host_groups):  # the "ungrouped hosts" entry was selected
                        # selected_group = self.user.account.host_user_binds.all()
                        host_bind_list = self.user.account.host_user_binds.all()

                    if host_bind_list:
                        while True:
                            for index, host in enumerate(host_bind_list):
                                print("%s.\t%s" % (index, host,))
                            choice2 = input("select host>:").strip()
                            if choice2.isdigit():
                                choice2 = int(choice2)
                                if choice2 >= 0 and choice2 < len(host_bind_list):
                                    selected_host = host_bind_list[choice2]
                                    s = string.ascii_lowercase + string.digits
                                    random_tag = ''.join(random.sample(s, 10))
                                    session_obj=models.SessionLog.objects.create(account=self.user.account,host_user_bind=selected_host)

                                    session_tracker_scipt='/bin/sh %s %s %s'%(settings.SESSION_TRACKER_SCRIPT,random_tag,session_obj.pk)

                                    session_tracker_process=subprocess.Popen(session_tracker_scipt,shell=True,stdout=subprocess.PIPE,stderr=subprocess.PIPE)
                                    cmd='sshpass -p %s /usr/local/openssh/bin/ssh %s@%s -p %s -o stricthostkeychecking=no -Z %s' % (selected_host.host_user.password,
                                                                                                             selected_host.host_user.username,
                                                                                                             selected_host.host.ip_addr,
                                                                                                             selected_host.host.port,random_tag)
                                    subprocess.run(cmd,shell=True)  # run ssh as a child process for the interactive session
                                    print(session_tracker_process.stdout.readlines(),
                                          session_tracker_process.stderr.readlines())


                            elif choice2 == 'b':
                                break
Combined version

 

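2. Parsing and analysing the log file produced for the remote login session

For commands completed with Tab, the completed text has to be found in the write(5 calls. The script was written by trying keys one at a time and handling each case with its own condition inside the loop.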

import re


class AuditLogHandler(object):
    '''Parses the audit log'''

    def __init__(self,log_file):
        self.log_file_obj = self._get_file(log_file)


    def _get_file(self,log_file):

        return open(log_file)

    def parse(self):
        cmd_list = []
        cmd_str = ''
        catch_write5_flag = False #for tab completion
        for line in self.log_file_obj:
            #print(line.split())
            line = line.split()
            try:
                pid,time_clock,io_call,char = line[0:4]
                if io_call.startswith('read(4'):
                    if char == '"\\177",':  # backspace
                        char = '[1<-del]'
                    if char == '"\\33OB",':  # down arrow in vim
                        char = '[down 1]'
                    if char == '"\\33OA",':  # up arrow in vim
                        char = '[up 1]'
                    if char == '"\\33OC",':  # move right in vim
                        char = '[->1]'
                    if char == '"\\33OD",':  # move left in vim
                        char = '[1<-]'
                    if char == '"\\33[2;2R",':  # entering vim mode (backslash escaped so it matches the log text)
                        continue
                    if char == '"\\33[>1;95;0c",':  # entering vim mode
                        char = '[----enter vim mode-----]'


                    if char == '"\\33[A",':  # up arrow at the command line
                        char = '[up 1]'
                        catch_write5_flag = True  # capture the history command recalled by the up arrow
                    if char == '"\\33[B",':  # down arrow at the command line
                        char = '[down 1]'
                        catch_write5_flag = True  # capture the history command recalled by the down arrow
                    if char == '"\\33[C",':  # move right one position at the command line
                        char = '[->1]'
                    if char == '"\\33[D",':  # move left one position at the command line
                        char = '[1<-]'

                    cmd_str += char.strip('"",')
                    if char == '"\\t",':
                        catch_write5_flag = True
                        continue
                    if char == '"\\r",':
                        cmd_list.append([time_clock,cmd_str])
                        cmd_str = ''  # reset
                    if char == '"':  # a space: split() leaves only the opening quote
                        cmd_str += ' '

                if catch_write5_flag: #to catch tab completion
                    if io_call.startswith('write(5'):
                        if char == '"\\7",':  # bell character, sent when nothing more can be deleted; not a space
                            pass
                        else:
                            cmd_str += char.strip('"",')
                        catch_write5_flag = False
            except ValueError as e:
                print("\033[031;1mSession log record err,please contact your IT admin,\033[0m",e)

        #print(cmd_list)
        for cmd in cmd_list:
            print(cmd)
        return cmd_list
if __name__ == "__main__":
    parser = AuditLogHandler(r'D:\zhanggen_audit\log\6.log')
    parser.parse()
Log analysis
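
The chain of `if` tests on key codes can also be written as a lookup table. This is a minimal sketch of that alternative, using the same tokens and labels as the parser above; `KEY_LABELS` and `decode_key` are hypothetical names, and the sketch does not handle the Tab and history capture.

```python
# Tokens exactly as they appear after line.split() on an strace log line.
KEY_LABELS = {
    r'"\177",': "[1<-del]",   # backspace
    r'"\33OA",': "[up 1]",    # arrow keys inside vim
    r'"\33OB",': "[down 1]",
    r'"\33OC",': "[->1]",
    r'"\33OD",': "[1<-]",
    r'"\33[A",': "[up 1]",    # arrow keys at the shell prompt
    r'"\33[B",': "[down 1]",
    r'"\33[C",': "[->1]",
    r'"\33[D",': "[1<-]",
}


def decode_key(token):
    """Map a raw strace token to a readable label, or strip its quoting."""
    label = KEY_LABELS.get(token)
    if label is not None:
        return label
    return token.strip('",')


print(decode_key(r'"\33[A",'))  # [up 1]
print(decode_key('"l",'))       # l
```

Raw strings keep each backslash literal, which matches the text strace writes to the log; a non-raw `'\33'` would instead be a single escape character and never match.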

 

3. Edit the .bashrc file to restrict what users can do after login

alias rm='rm -i'
alias cp='cp -i'
alias mv='mv -i'

# Source global definitions
if [ -f /etc/bashrc ]; then
        . /etc/bashrc
fi



echo '-----------------------welcome  to  zhanggen  audit  --------------------------'

python3 /root/zhanggen_audit/audit_shell.py

echo 'bye'

logout
vim ~/.bashrc

 

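Drawback:

Although this restricts interactive shell logins, it cannot stop users from uploading malicious files with a program (for example paramiko).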

 

 

 

Method 2: take the demo files from the paramiko source and modify them to support interactive sessions

from django.db import models
from django.contrib.auth.models import  User
# Create your models here.


class IDC(models.Model):
    name = models.CharField(max_length=64,unique=True)
    def __str__(self):
        return self.name

class Host(models.Model):
    """Stores information about every host"""
    hostname = models.CharField(max_length=64,unique=True)
    ip_addr = models.GenericIPAddressField(unique=True)
    port = models.IntegerField(default=22)
    idc = models.ForeignKey("IDC",on_delete=models.CASCADE)  # on_delete is required from Django 2.0 on
    #host_groups = models.ManyToManyField("HostGroup")
    #host_users = models.ManyToManyField("HostUser")
    enabled = models.BooleanField(default=True)

    def __str__(self):
        return "%s-%s" %(self.hostname,self.ip_addr)

class HostGroup(models.Model):
    """Host group"""
    name = models.CharField(max_length=64,unique=True)
    host_user_binds  = models.ManyToManyField("HostUserBind")
    def __str__(self):
        return self.name


class HostUser(models.Model):
    """Stores the login users of remote hosts, for example:
    root 123
    root abc
    root sfsfs
    """
    auth_type_choices = ((0,'ssh-password'),(1,'ssh-key'))
    auth_type = models.SmallIntegerField(choices=auth_type_choices)
    username = models.CharField(max_length=32)
    password = models.CharField(blank=True,null=True,max_length=128)

    def __str__(self):
        # The password is left out so that it is not shown in menus or logs
        return "%s-%s" %(self.get_auth_type_display(),self.username)

    class Meta:
        unique_together = ('username','password')


class HostUserBind(models.Model):
    """Binds a host to a remote user"""
    host = models.ForeignKey("Host",on_delete=models.CASCADE)
    host_user = models.ForeignKey("HostUser",on_delete=models.CASCADE)

    def __str__(self):
        return "%s-%s" %(self.host,self.host_user)

    class Meta:
        unique_together = ('host','host_user')


class AuditLog(models.Model):
    """Audit log"""
    session = models.ForeignKey("SessionLog",on_delete=models.CASCADE)
    cmd = models.TextField()
    date = models.DateTimeField(auto_now_add=True)
    def __str__(self):
        return "%s-%s" %(self.session,self.cmd)


class SessionLog(models.Model):
    account = models.ForeignKey("Account",on_delete=models.CASCADE)
    host_user_bind = models.ForeignKey("HostUserBind",on_delete=models.CASCADE)
    start_date = models.DateTimeField(auto_now_add=True)
    end_date = models.DateTimeField(blank=True,null=True)

    def __str__(self):
        return "%s-%s" %(self.account,self.host_user_bind)


class Account(models.Model):
    """Bastion host account
    1. extension (one-to-one link to User)
    2. inheritance
    user.account.host_user_binds
    """

    user = models.OneToOneField(User,on_delete=models.CASCADE)
    name = models.CharField(max_length=64)

    host_user_binds = models.ManyToManyField("HostUserBind",blank=True)
    host_groups = models.ManyToManyField("HostGroup",blank=True)
model.py

 

__author__ = 'Administrator'
import getpass,subprocess,random,string
from django.contrib.auth import authenticate
from django.conf import settings 
from audit import models
from audit.backend import ssh_interactive 

class UserShell(object):
    """Shell presented to users after they log in to the bastion host"""

    def __init__(self,sys_argv):
        self.sys_argv = sys_argv
        self.user = None

    def auth(self):

        count = 0
        while count < 3:
            username = input("username:").strip()
            password = getpass.getpass("password:").strip()  # do not echo the password
            user = authenticate(username=username,password=password)
            # None means authentication failed;
            # otherwise the authenticated user object is returned
            if not user:
                count += 1
                print("Invalid username or password!")
            else:
                self.user = user
                return  True
        else:
            print("too many attempts.")

    def start(self):
        """Start the interactive program"""

        if self.auth():
            #print(self.user.account.host_user_binds.all()) #select_related()
            while True:
                host_groups = self.user.account.host_groups.all()
                for index,group in enumerate(host_groups):
                    print("%s.\t%s[%s]"%(index,group,group.host_user_binds.count()))
                print("%s.\tUngrouped hosts[%s]"%(len(host_groups),self.user.account.host_user_binds.count()))
                try:
                    choice = input("select group>:").strip()
                    if choice.isdigit():
                        choice = int(choice)
                        host_bind_list = None
                        if choice >=0 and choice < len(host_groups):
                            selected_group = host_groups[choice]
                            host_bind_list = selected_group.host_user_binds.all()
                        elif choice == len(host_groups): # the "ungrouped hosts" entry was selected
                            #selected_group = self.user.account.host_user_binds.all()
                            host_bind_list = self.user.account.host_user_binds.all()
                        if host_bind_list:
                            while True:
                                for index,host in enumerate(host_bind_list):
                                    print("%s.\t%s"%(index,host,))
                                choice2 = input("select host>:").strip()
                                if choice2.isdigit():
                                    choice2 = int(choice2)
                                    if choice2 >=0 and choice2 < len(host_bind_list):
                                        selected_host = host_bind_list[choice2]

                                        ssh_interactive.ssh_session(selected_host,self.user)


                                        # s = string.ascii_lowercase +string.digits
                                        # random_tag = ''.join(random.sample(s,10))
                                        # session_obj = models.SessionLog.objects.create(account=self.user.account,host_user_bind=selected_host)
                                        #
                                        # cmd = "sshpass -p %s /usr/local/openssh/bin/ssh %s@%s -p %s -o StrictHostKeyChecking=no -Z %s" %(selected_host.host_user.password,selected_host.host_user.username,selected_host.host.ip_addr,selected_host.host.port ,random_tag)
                                        # #start strace ,and sleep 1 random_tag, session_obj.id
                                        # session_tracker_script = "/bin/sh %s %s %s " %(settings.SESSION_TRACKER_SCRIPT,random_tag,session_obj.id)
                                        #
                                        # session_tracker_obj =subprocess.Popen(session_tracker_script, shell=True,stdout=subprocess.PIPE,stderr=subprocess.PIPE)
                                        #
                                        # ssh_channel = subprocess.run(cmd,shell=True)
                                        # print(session_tracker_obj.stdout.read(), session_tracker_obj.stderr.read())
                                        #
                                elif choice2 == 'b':
                                    break

                except KeyboardInterrupt as e :
                    pass
user_interactive.py
#!/usr/bin/env python

# Copyright (C) 2003-2007  Robey Pointer <robeypointer@gmail.com>
#
# This file is part of paramiko.
#
# Paramiko is free software; you can redistribute it and/or modify it under the
# terms of the GNU Lesser General Public License as published by the Free
# Software Foundation; either version 2.1 of the License, or (at your option)
# any later version.
#
# Paramiko is distributed in the hope that it will be useful, but WITHOUT ANY
# WARRANTY; without even the implied warranty of MERCHANTABILITY or FITNESS FOR
# A PARTICULAR PURPOSE.  See the GNU Lesser General Public License for more
# details.
#
# You should have received a copy of the GNU Lesser General Public License
# along with Paramiko; if not, write to the Free Software Foundation, Inc.,
# 59 Temple Place, Suite 330, Boston, MA  02111-1307  USA.


import base64
from binascii import hexlify
import getpass
import os
import select
import socket
import sys
import time
import traceback
from paramiko.py3compat import input
from audit import models
import paramiko

try:
    import interactive
except ImportError:
    from . import interactive


def manual_auth(t, username, password):
    # default_auth = 'p'
    # auth = input('Auth by (p)assword, (r)sa key, or (d)ss key? [%s] ' % default_auth)
    # if len(auth) == 0:
    #     auth = default_auth
    #
    # if auth == 'r':
    #     default_path = os.path.join(os.environ['HOME'], '.ssh', 'id_rsa')
    #     path = input('RSA key [%s]: ' % default_path)
    #     if len(path) == 0:
    #         path = default_path
    #     try:
    #         key = paramiko.RSAKey.from_private_key_file(path)
    #     except paramiko.PasswordRequiredException:
    #         password = getpass.getpass('RSA key password: ')
    #         key = paramiko.RSAKey.from_private_key_file(path, password)
    #     t.auth_publickey(username, key)
    # elif auth == 'd':
    #     default_path = os.path.join(os.environ['HOME'], '.ssh', 'id_dsa')
    #     path = input('DSS key [%s]: ' % default_path)
    #     if len(path) == 0:
    #         path = default_path
    #     try:
    #         key = paramiko.DSSKey.from_private_key_file(path)
    #     except paramiko.PasswordRequiredException:
    #         password = getpass.getpass('DSS key password: ')
    #         key = paramiko.DSSKey.from_private_key_file(path, password)
    #     t.auth_publickey(username, key)
    # else:
    # pw = getpass.getpass('Password for %s@%s: ' % (username, hostname))
    t.auth_password(username, password)


def ssh_session(bind_host_user, user_obj):
    # now connect
    hostname = bind_host_user.host.ip_addr  # host address, filled in automatically
    port = bind_host_user.host.port         # port
    username = bind_host_user.host_user.username
    password = bind_host_user.host_user.password

    try:
        sock = socket.socket(socket.AF_INET, socket.SOCK_STREAM)  # create the TCP socket
        sock.connect((hostname, port))
    except Exception as e:
        print('*** Connect failed: ' + str(e))
        traceback.print_exc()
        sys.exit(1)

    try:
        t = paramiko.Transport(sock)  # run the SSH transport over the socket with paramiko
        try:
            t.start_client()
        except paramiko.SSHException:
            print('*** SSH negotiation failed.')
            sys.exit(1)

        try:
            keys = paramiko.util.load_host_keys(os.path.expanduser('~/.ssh/known_hosts'))
        except IOError:
            try:
                keys = paramiko.util.load_host_keys(os.path.expanduser('~/ssh/known_hosts'))
            except IOError:
                print('*** Unable to open host keys file')
                keys = {}

        # check server's host key -- this is important.
        key = t.get_remote_server_key()
        if hostname not in keys:
            print('*** WARNING: Unknown host key!')
        elif key.get_name() not in keys[hostname]:
            print('*** WARNING: Unknown host key!')
        elif keys[hostname][key.get_name()] != key:
            print('*** WARNING: Host key has changed!!!')
            sys.exit(1)
        else:
            print('*** Host key OK.')

        if not t.is_authenticated():
            manual_auth(t, username, password)  # password authentication
        if not t.is_authenticated():
            print('*** Authentication failed. :(')
            t.close()
            sys.exit(1)

        chan = t.open_session()
        chan.get_pty()  # terminal
        chan.invoke_shell()
        print('*** Here we go!\n')

        session_obj = models.SessionLog.objects.create(account=user_obj.account,
                                                       host_user_bind=bind_host_user)
        interactive.interactive_shell(chan, session_obj)  # enter interactive mode
        chan.close()
        t.close()

    except Exception as e:
        print('*** Caught exception: ' + str(e.__class__) + ': ' + str(e))
        traceback.print_exc()
        try:
            t.close()
        except:
            pass
        sys.exit(1)
ssh_interactive.py
# Copyright (C) 2003-2007  Robey Pointer <robeypointer@gmail.com>
#
# This file is part of paramiko.
#
# Paramiko is free software; you can redistribute it and/or modify it under the
# terms of the GNU Lesser General Public License as published by the Free
# Software Foundation; either version 2.1 of the License, or (at your option)
# any later version.
#
# Paramiko is distributed in the hope that it will be useful, but WITHOUT ANY
# WARRANTY; without even the implied warranty of MERCHANTABILITY or FITNESS FOR
# A PARTICULAR PURPOSE.  See the GNU Lesser General Public License for more
# details.
#
# You should have received a copy of the GNU Lesser General Public License
# along with Paramiko; if not, write to the Free Software Foundation, Inc.,
# 59 Temple Place, Suite 330, Boston, MA  02111-1307  USA.


import socket
import sys
from paramiko.py3compat import u
from audit import models
# windows does not have termios...
try:
    import termios
    import tty
    has_termios = True
except ImportError:
    has_termios = False


def interactive_shell(chan,session_obj):
    if has_termios:  # termios is available on POSIX (Unix-like) systems
        posix_shell(chan,session_obj)
    else:
        windows_shell(chan)


def posix_shell(chan,session_obj):
    import select
    
    oldtty = termios.tcgetattr(sys.stdin)
    try:
        tty.setraw(sys.stdin.fileno())
        tty.setcbreak(sys.stdin.fileno())
        chan.settimeout(0.0)
        flag = False
        cmd = ''
        while True:  # main input loop
            r, w, e = select.select([chan, sys.stdin], [], [])  # block until the channel or stdin becomes readable

            if chan in r:  # output returned by the remote host
                try:
                    x = u(chan.recv(1024))
                    if len(x) == 0:
                        sys.stdout.write('\r\n*** EOF\r\n')
                        break
                    if flag:  # the user pressed Tab; this is the completion returned by the server
                        cmd += x
                        flag = False
                    sys.stdout.write(x)
                    sys.stdout.flush()
                except socket.timeout:
                    pass


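            if sys.stdin in r:  # local input
                x = sys.stdin.read(1)  # read one character at a time and forward it
                if len(x) == 0:
                    break
                if x == '\r':  # Enter: the command is complete, store it
                    models.AuditLog.objects.create(session=session_obj,cmd=cmd)
                    cmd = ''
                elif x == '\t':  # Tab: command = characters typed locally + completion returned remotely
                    flag = True
                else:
                    cmd += x
                chan.send(x)  # send the local input to the remote server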

    finally:
        termios.tcsetattr(sys.stdin, termios.TCSADRAIN, oldtty)

    
# thanks to Mike Looijmans for this code
def windows_shell(chan):
    import threading

    sys.stdout.write("Line-buffered terminal emulation. Press F6 or ^Z to send EOF.\r\n\r\n")
        
    def writeall(sock):
        while True:
            data = sock.recv(256)
            if not data:
                sys.stdout.write('\r\n*** EOF ***\r\n\r\n')
                sys.stdout.flush()
                break
            sys.stdout.write(data)
            sys.stdout.flush()
        
    writer = threading.Thread(target=writeall, args=(chan,))
    writer.start()
        
    try:
        while True:
            d = sys.stdin.read(1)
            if not d:
                break
            chan.send(d)
    except EOFError:
        # user hit ^Z or F6
        pass
interactive.py
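
The keystroke buffering in posix_shell can be separated from the terminal I/O so that it is easier to test. The sketch below shows that idea under stated assumptions: `CommandRecorder` is a hypothetical class, and its backspace handling is an addition that the original loop does not have (there, a deleted character stays in the recorded command).

```python
class CommandRecorder:
    """Collect keystrokes into commands; Enter completes a command."""

    BACKSPACE = ("\x7f", "\x08")

    def __init__(self):
        self.buffer = ""
        self.commands = []

    def feed(self, char):
        """Process one typed character; return the finished command, if any."""
        if char == "\r":
            finished, self.buffer = self.buffer, ""
            self.commands.append(finished)
            return finished
        if char in self.BACKSPACE:
            self.buffer = self.buffer[:-1]  # drop the last typed character
        else:
            self.buffer += char
        return None


recorder = CommandRecorder()
for ch in "lss\x7f -l\r":  # types "lss", deletes one "s", then " -l" and Enter
    recorder.feed(ch)
print(recorder.commands)  # ['ls -l']
```

In the real loop, the value returned by `feed` on Enter would be what gets written to `AuditLog`.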

 

 

 Program flow: user menu ----------> ssh logs in automatically with the stored username and password ----------> interactive shell session

 

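Key point:

One-to-one:      one row maps to exactly one row   (one woman marries one man, and life slowly settles down)

One-to-many:     one row maps to N rows   (behind her husband's back the woman has affairs with N men; the husband finds out and is furious)

Many-to-many:    both sides are one-to-many (he in turn takes N mistresses, some of whom are the wives of the men his wife had affairs with. End of story.)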

 

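Reflection:

The story is a mess! How would you design the tables for the men and the women? When designing database relationships, rather than agonising over which relationship two tables should have, it is often better to add a few binding (association) tables.

The decision comes down entirely to what information your program needs to show users while it runs.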

 

 

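Using the bastion host from a web page:

Web development models

1. MTV/MVC with front end and back end mixed together (suited to internal company OA systems)

Advantages: simple; one person can do the full stack.

Drawbacks: tight coupling between front end and back end, lower performance, and load concentrated on a single point.

 

2. Front end and back end separated (suited to public-facing products)

Advantages: front-end and back-end developers agree on the interfaces and data formats and then work in parallel, which is efficient; the back end no longer carries the load of rendering templates alone.

Drawback: hiring front-end developers costs money.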

 

3. hostlist: displaying host groups and hosts

  <div class="panel col-lg-3">
            <div class="panel-heading">
                <h3 class="panel-title">Host groups</h3>
            </div>
            <div class="panel-body">
                <ul class="list-group">
                {% for group in  request.user.account.host_groups.all %}

                    <li class="list-group-item " onclick="GetHostlist({{ group.id }},this)"><span class="badge badge-success">{{ group.host_user_binds.count }}</span>{{ group.name }}</li>
                {% endfor %}
                    <li class="list-group-item " onclick="GetHostlist(-1,this)"> <span class="badge badge-success">{{ request.user.account.host_user_binds.count }}</span>Ungrouped hosts</li>

                </ul>
            </div>
        </div>
Binding the event on the tag
<script>

function GetHostlist(gid,self) {

    $.get("{% url 'get_host_list' %}",{'gid':gid},function(callback){

        var data  = JSON.parse(callback);
        console.log(data)
        var trs = ''
        $.each(data,function (index,i) {
            var tr = "<tr><td>" + i.host__hostname + "</td><td>" + i.host__ip_addr +"</td><td>" + i.host__idc__name
                    +"</td><td>" + i.host__port  + "</td><td>" + i.host_user__username+ "</td><td>Login</td></tr>";
            trs += tr

        })
        $("#hostlist").html(trs);



    });//end get
    $(self).addClass("active").siblings().removeClass('active');

}

</script>
Requesting data from the back end via AJAX

 

Key points:

If an event bound to a tag needs to take arguments, you can bind it directly on the tag itself.

url(r'^get_token$', views.get_token, name="get_token"),
Django route alias
function GetToken(self,bind_host_id) {
    $.post(
        '{% url "get_token" %}',     // render the URL from its route alias
        {'bind_host_id':bind_host_id,'csrfmiddlewaretoken':"{{ csrf_token }}"},// parameters sent with the request
        function (callback) {          // callback function
            console.log(callback)
        }

        )
}
Django template language
@login_required
def get_token(request):
    bind_host_id=request.POST.get('bind_host_id')
    time_obj = datetime.datetime.now() - datetime.timedelta(seconds=300)  # 5mins ago
    exist_token_objs = models.Token.objects.filter(account_id=request.user.account.id,
                                                   host_user_bind_id=bind_host_id,
                                                   date__gt=time_obj)
    if exist_token_objs:  # has token already
        token_data = {'token': exist_token_objs[0].val}
    else:
        token_val=''.join(random.sample(string.ascii_lowercase+string.digits,8))

        token_obj=models.Token.objects.create(
            host_user_bind_id=bind_host_id,
            account=request.user.account,
            val=token_val)
        token_data={"token":token_val}

    return HttpResponse(json.dumps(token_data))
Generating a token that stays valid for 5 minutes
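The reuse-or-create logic of the view above can be sketched without Django as follows. This is a minimal sketch under stated assumptions: the in-memory `store` dict stands in for the `Token` table, `get_or_create_token` and `new_token_value` are hypothetical names, and `secrets.choice` is swapped in for `random.sample` because the `secrets` module is intended for security-sensitive values (unlike `random.sample`, it may repeat a character within a token).

```python
import secrets
import string
from datetime import datetime, timedelta, timezone

TOKEN_ALPHABET = string.ascii_lowercase + string.digits
TOKEN_LENGTH = 8
TOKEN_TTL = timedelta(seconds=300)  # 5 minutes, as in the view


def new_token_value():
    """Build an 8-character token from lowercase letters and digits."""
    return "".join(secrets.choice(TOKEN_ALPHABET) for _ in range(TOKEN_LENGTH))


def get_or_create_token(store, account_id, bind_id, now=None):
    """Return the still-valid token for (account, binding), or issue a new one.

    store maps (account_id, bind_id) -> (token value, issue time).
    """
    now = now or datetime.now(timezone.utc)
    existing = store.get((account_id, bind_id))
    if existing is not None and now - existing[1] < TOKEN_TTL:
        return existing[0]  # a token issued within the last 5 minutes is reused
    value = new_token_value()
    store[(account_id, bind_id)] = (value, now)
    return value
```

The sketch uses timezone-aware datetimes throughout; in a Django project with `USE_TZ = True`, `django.utils.timezone.now()` plays the same role.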

 

4. Click a host to log in; the Shellinabox plugin provides the remote login to the Linux host as a web page.

 4.0 Installing shellinabox

yum install git openssl-devel pam-devel zlib-devel autoconf automake libtool

git clone https://github.com/shellinabox/shellinabox.git && cd shellinabox

autoreconf -i

./configure && make

make install

shellinaboxd -b -t  # -b starts it in the background; -t starts it without HTTPS. By default it runs as the nobody user and listens on TCP port 4200

netstat -ntpl |grep shell
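The same check can be done from Python, which is handy if the Django side wants to confirm that shellinabox is reachable before showing a login link. A minimal sketch; `is_port_open` is a hypothetical helper, and the host and port you pass are whatever your own deployment uses (4200 is the shellinaboxd default mentioned above).

```python
import socket


def is_port_open(host, port, timeout=1.0):
    """Return True if a TCP connection to host:port can be established."""
    try:
        with socket.create_connection((host, port), timeout=timeout):
            return True
    except OSError:
        # covers refused connections, timeouts and unreachable hosts
        return False
```

For example, `is_port_open("127.0.0.1", 4200)` would be the Python counterpart of the `netstat` line above.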

 

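5. Integrating Django with shellinabox

 

5.1: On Django's hostlist page the user clicks to generate a token (bound to an account and a HostUserBind), which is recorded in the database.

5.2: On the hostlist page the user clicks login and is redirected to shellinabox. Because bashrc has been modified, the Python user-interaction program runs right after the redirect and prompts the user for the token.

5.3: After the user enters the token, the Python user-interaction program looks it up in the database and from it finds the IP, username and password of the HostUserBind; it then calls paramiko's demo.py to supply the IP, username and password automatically and enter the interactive shell.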
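Step 5.3 boils down to "token in, connection details out". The sketch below shows that lookup with plain dictionaries standing in for the `Token` and `HostUserBind` tables; `resolve_token` and the dictionary layouts are assumptions made for illustration, not code from the project.

```python
from datetime import datetime, timedelta, timezone

TOKEN_TTL = timedelta(seconds=300)  # same 5-minute window used when issuing tokens


def resolve_token(tokens, binds, user_input, now):
    """Return login details for a valid, unexpired token, else None.

    tokens: token value -> {"account": ..., "bind_id": ..., "date": datetime}
    binds:  bind_id -> {"ip": ..., "port": ..., "username": ..., "password": ...}
    """
    value = user_input.strip()
    if len(value) != 8:
        return None  # tokens are always 8 characters long
    record = tokens.get(value)
    if record is None:
        return None  # unknown token
    if now - record["date"] >= TOKEN_TTL:
        return None  # token has expired
    details = dict(binds[record["bind_id"]])
    details["account"] = record["account"]
    return details
```

The returned dictionary carries everything the SSH step needs, plus the account, so the session can be attributed to the right bastion user in the audit log.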

from django.db import models
from django.contrib.auth.models import  User
# Create your models here.


class IDC(models.Model):
    name = models.CharField(max_length=64,unique=True)
    def __str__(self):
        return self.name

class Host(models.Model):
    """Stores information about every host"""
    hostname = models.CharField(max_length=64,unique=True)
    ip_addr = models.GenericIPAddressField(unique=True)
    port = models.IntegerField(default=22)
    idc = models.ForeignKey("IDC")
    #host_groups = models.ManyToManyField("HostGroup")
    #host_users = models.ManyToManyField("HostUser")
    enabled = models.BooleanField(default=True)

    def __str__(self):
        return "%s-%s" %(self.hostname,self.ip_addr)

class HostGroup(models.Model):
    """Host group"""
    name = models.CharField(max_length=64,unique=True)
    host_user_binds  = models.ManyToManyField("HostUserBind")
    def __str__(self):
        return self.name


class HostUser(models.Model):
    """Stores login user information for remote hosts
    root 123
    root abc
    root sfsfs
    """
    auth_type_choices = ((0,'ssh-password'),(1,'ssh-key'))
    auth_type = models.SmallIntegerField(choices=auth_type_choices)
    username = models.CharField(max_length=32)
    password = models.CharField(blank=True,null=True,max_length=128)

    def __str__(self):
        return "%s-%s-%s" %(self.get_auth_type_display(),self.username,self.password)

    class Meta:
        unique_together = ('username','password')


class HostUserBind(models.Model):
    """Binds a host to a host user"""
    host = models.ForeignKey("Host")
    host_user = models.ForeignKey("HostUser")

    def __str__(self):
        return "%s-%s" %(self.host,self.host_user)

    class Meta:
        unique_together = ('host','host_user')


class AuditLog(models.Model):
    """Audit log"""
    session = models.ForeignKey("SessionLog")
    cmd = models.TextField()
    date = models.DateTimeField(auto_now_add=True)
    def __str__(self):
        return "%s-%s" %(self.session,self.cmd)


class SessionLog(models.Model):
    account = models.ForeignKey("Account")
    host_user_bind = models.ForeignKey("HostUserBind")
    start_date = models.DateTimeField(auto_now_add=True)
    end_date = models.DateTimeField(blank=True,null=True)

    def __str__(self):
        return "%s-%s" %(self.account,self.host_user_bind)


class Account(models.Model):
    """Bastion host account
    1. Extend
    2. Inherit
    user.account.host_user_bind
    """

    user = models.OneToOneField(User)
    name = models.CharField(max_length=64)

    host_user_binds = models.ManyToManyField("HostUserBind",blank=True)
    host_groups = models.ManyToManyField("HostGroup",blank=True)



class Token(models.Model):
    host_user_bind = models.ForeignKey("HostUserBind")
    val = models.CharField(max_length=128,unique=True)
    account = models.ForeignKey("Account")
    expire = models.IntegerField("Timeout (s)",default=300)
    date = models.DateTimeField(auto_now_add=True)
    def __str__(self):
        return "%s-%s" %(self.host_user_bind,self.val)
models.py
@login_required
def get_token(request):
    bind_host_id=request.POST.get('bind_host_id')
    time_obj = datetime.datetime.now() - datetime.timedelta(seconds=300)  # 5mins ago
    exist_token_objs = models.Token.objects.filter(account_id=request.user.account.id,
                                                   host_user_bind_id=bind_host_id,
                                                   date__gt=time_obj)
    if exist_token_objs:  # has token already
        token_data = {'token': exist_token_objs[0].val}
    else:
        token_val=''.join(random.sample(string.ascii_lowercase+string.digits,8))

        token_obj=models.Token.objects.create(
            host_user_bind_id=bind_host_id,
            account=request.user.account,
            val=token_val)
        token_data={"token":token_val}

    return HttpResponse(json.dumps(token_data))
views.py
{% extends 'index.html' %}



{% block content-container %}
    <div id="page-title">
        <h1 class="page-header text-overflow">Host list</h1>

        <!--Searchbox-->
        <div class="searchbox">
            <div class="input-group custom-search-form">
                <input type="text" class="form-control" placeholder="Search..">
                <span class="input-group-btn">
                    <button class="text-muted" type="button"><i class="pli-magnifi-glass"></i></button>
                </span>
            </div>
        </div>
    </div>
    <!--~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~-->
    <!--End page title-->
        <!--Breadcrumb-->
    <!--~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~-->
    <ol class="breadcrumb">
        <li><a href="#">Home</a></li>
        <li><a href="#">Library</a></li>
        <li class="active">Host list</li>
    </ol>
    <!--~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~-->
    <!--End breadcrumb-->

    <div id="page-content">

        <div class="panel col-lg-3">
            <div class="panel-heading">
                <h3 class="panel-title">Host groups</h3>
            </div>
            <div class="panel-body">
                <ul class="list-group">
                {% for group in  request.user.account.host_groups.all %}

                    <li class="list-group-item " onclick="GetHostlist({{ group.id }},this)"><span class="badge badge-success">{{ group.host_user_binds.count }}</span>{{ group.name }}</li>
                {% endfor %}
                    <li class="list-group-item " onclick="GetHostlist(-1,this)"> <span class="badge badge-success">{{ request.user.account.host_user_binds.count }}</span>Ungrouped hosts</li>

                </ul>
            </div>
        </div>
        <div class="panel col-lg-9">
            <div class="panel-heading">
                <h3 class="panel-title">Host list</h3>
            </div>
            <div class="panel-body">

                <div class="table-responsive">
                    <table class="table table-striped">
                        <thead>
                            <tr>
                                <th>Hostname</th>
                                <th>IP</th>
                                <th>IDC</th>
                                <th>Port</th>
                                <th>Username</th>
                                <th>Login</th>
                                <th>Token</th>
                            </tr>
                        </thead>
                        <tbody id="hostlist">
{#                            <tr>#}
{#                                <td><a href="#fakelink" class="btn-link">Order #53451</a></td>#}
{#                                <td>Scott S. Calabrese</td>#}
{#                                <td>$24.98</td>#}
{#                            </tr>#}

                        </tbody>
                    </table>
                </div>

            </div>
        </div>

    </div>



<script>

function GetHostlist(gid,self) {

    $.get("{% url 'get_host_list' %}",{'gid':gid},function(callback){

        var data  = JSON.parse(callback);
        console.log(data);
        var trs = '';
        $.each(data,function (index,i) {
            var tr = "<tr><td>" + i.host__hostname + "</td><td>" + i.host__ip_addr +"</td><td>" + i.host__idc__name
                    +"</td><td>" + i.host__port  + "</td><td>" + i.host_user__username+ "</td><td><a class='btn btn-sm btn-info' onclick=\"GetToken(this," + i.id + ")\">Token</a><a href='http://192.168.226.135:4200/' class='btn btn-sm btn-info'>login</a></td><td ></td></tr>";
            trs += tr

        });
        $("#hostlist").html(trs);



    });//end get
    $(self).addClass("active").siblings().removeClass('active');

}

function GetToken(self,bind_host_id) {
    $.post(
        '{% url "get_token" %}',     // render the URL from its route alias
        {'bind_host_id':bind_host_id,'csrfmiddlewaretoken':"{{ csrf_token }}"},// parameters sent with the request
        function (callback) {          // callback function
            console.log(callback);
            var data = JSON.parse(callback); // data returned by Django
            $(self).parent().next().text(data.token);
        }

        )
}



</script>
{% endblock %}
hostlist.html
import subprocess,random,string,datetime
from django.contrib.auth import authenticate
from django.conf import settings
from audit import models
from audit.backend import ssh_interactive

class UserShell(object):
    """Shell presented to the user after logging in to the bastion host"""

    def __init__(self,sys_argv):
        self.sys_argv = sys_argv
        self.user = None

    def auth(self):

        count = 0
        while count < 3:
            username = input("username:").strip()
            password = input("password:").strip()
            user = authenticate(username=username,password=password)
            #None means authentication failed
            #user object: the authenticated user, e.g. user.name
            if not user:
                count += 1
                print("Invalid username or password!")
            else:
                self.user = user
                return  True
        else:
            print("too many attempts.")

    def token_auth(self):
        count = 0
        while count < 3:
            user_input = input("Enter token:").strip()
            if len(user_input) == 0:
                return
            if len(user_input) != 8:
                print("token length is 8")
            else:
                time_obj = datetime.datetime.now() - datetime.timedelta(seconds=300)  # 5mins ago
                token_obj = models.Token.objects.filter(val=user_input, date__gt=time_obj).first()
                if token_obj:
                    if token_obj.val == user_input:  # the token matches
                        self.user = token_obj.account.user # the interactive shell needs an authenticated user
                        return token_obj
            count+=1
    def start(self):
        """Start the interactive program"""
        token_obj = self.token_auth()
        if token_obj:
            ssh_interactive.ssh_session(token_obj.host_user_bind, self.user)
            exit()
        if self.auth():
            #print(self.user.account.host_user_binds.all()) #select_related()
            while True:
                host_groups = self.user.account.host_groups.all()
                for index,group in enumerate(host_groups):
                    print("%s.\t%s[%s]"%(index,group,group.host_user_binds.count()))
                print("%s.\tUngrouped hosts[%s]"%(len(host_groups),self.user.account.host_user_binds.count()))
                try:
                    choice = input("select group>:").strip()
                    if choice.isdigit():
                        choice = int(choice)
                        host_bind_list = None
                        if choice >=0 and choice < len(host_groups):
                            selected_group = host_groups[choice]
                            host_bind_list = selected_group.host_user_binds.all()
                        elif choice == len(host_groups): # the ungrouped hosts were selected
                            #selected_group = self.user.account.host_user_binds.all()
                            host_bind_list = self.user.account.host_user_binds.all()
                        if host_bind_list:
                            while True:
                                for index,host in enumerate(host_bind_list):
                                    print("%s.\t%s"%(index,host,))
                                choice2 = input("select host>:").strip()
                                if choice2.isdigit():
                                    choice2 = int(choice2)
                                    if choice2 >=0 and choice2 < len(host_bind_list):
                                        selected_host = host_bind_list[choice2]

                                        ssh_interactive.ssh_session(selected_host,self.user)


                                        # s = string.ascii_lowercase +string.digits
                                        # random_tag = ''.join(random.sample(s,10))
                                        # session_obj = models.SessionLog.objects.create(account=self.user.account,host_user_bind=selected_host)
                                        #
                                        # cmd = "sshpass -p %s /usr/local/openssh/bin/ssh %s@%s -p %s -o StrictHostKeyChecking=no -Z %s" %(selected_host.host_user.password,selected_host.host_user.username,selected_host.host.ip_addr,selected_host.host.port ,random_tag)
                                        # #start strace ,and sleep 1 random_tag, session_obj.id
                                        # session_tracker_script = "/bin/sh %s %s %s " %(settings.SESSION_TRACKER_SCRIPT,random_tag,session_obj.id)
                                        #
                                        # session_tracker_obj =subprocess.Popen(session_tracker_script, shell=True,stdout=subprocess.PIPE,stderr=subprocess.PIPE)
                                        #
                                        # ssh_channel = subprocess.run(cmd,shell=True)
                                        # print(session_tracker_obj.stdout.read(), session_tracker_obj.stderr.read())
                                        #
                                elif choice2 == 'b':
                                    break

                except KeyboardInterrupt as e :
                    pass
user_interactive.py

 

 

 

 

 

III. Running Linux commands in batch through the bastion host

 

 

 

1. Front-end page for batch command execution

{% extends 'index.html' %}



{% block content-container %}
{#    {% csrf_token %}#}
    <div id="page-title">
        <h1 class="page-header text-overflow">Host list</h1>

        <!--Searchbox-->
        <div class="searchbox">
            <div class="input-group custom-search-form">
                <input type="text" class="form-control" placeholder="Search..">
                <span class="input-group-btn">
                    <button class="text-muted" type="button"><i class="pli-magnifi-glass"></i></button>
                </span>
            </div>
        </div>
    </div>
    <!--~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~-->
    <!--End page title-->
        <!--Breadcrumb-->
    <!--~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~-->
    <ol class="breadcrumb">
        <li><a href="#">Home</a></li>
        <li><a href="#">Library</a></li>
        <li class="active">Host list</li>
    </ol>
    <!--~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~-->
    <!--End breadcrumb-->

    <div id="page-content">
        <div class="panel col-lg-3">
            <div class="panel-heading">
                <h3 class="panel-title">Host groups <span id="selected_hosts"></span></h3>
            </div>
            <div class="panel-body">

                <ul class="list-group" id="host_groups">
                {% for group in  request.user.account.host_groups.all %}

                    <li class="list-group-item " ><span class="badge badge-success">{{ group.host_user_binds.count }}</span>
                        <input type="checkbox" onclick="CheckAll(this)">
                        <a onclick="DisplayHostList(this)">{{ group.name }}</a>  <!-- clicking the group name shows/hides the host list under it via toggleClass -->
                        <ul class="hide">
                            {% for bind_host in group.host_user_binds.all %}
                                <li><input onclick="ShowCheckedHostCount()" type="checkbox" value="{{ bind_host.id }}">{{ bind_host.host.ip_addr }}</li>
                            {% endfor %}
                        </ul>
                    </li>

                {% endfor %}
                    <li class="list-group-item " > <span class="badge badge-success">{{ request.user.account.host_user_binds.count }}</span>
                       <input type="checkbox" onclick="CheckAll(this)">
                        <a onclick="DisplayHostList(this)">Ungrouped hosts</a>
                        <ul class="hide">
                            {% for bind_host in request.user.account.host_user_binds.all %}
                                <li><input onclick="ShowCheckedHostCount()" type="checkbox" value="{{ bind_host.id }}">{{ bind_host.host.ip_addr }}</li>
                            {% endfor %}
                        </ul>
                    </li>

                </ul>



            </div>
        </div>

        <div class="col-lg-9">
            <div class="panel">
                <div class="panel-heading">
                    <h3 class="panel-title">Command</h3>
                </div>
                <div class="panel-body">
                    <textarea class="form-control" id="cmd"></textarea>
                    <button onclick="PostTask('cmd')" class="btn btn-info pull-right">Run</button>
                    <button  class="btn btn-danger ">Stop</button>

                </div>
            </div>
            <div class="panel">
                <div class="panel-heading">
                    <h3 class="panel-title">Task results</h3>
                </div>
                <div class="panel-body">

                    <div id="task_result">
                </div>
            </div>
        </div>

        </div>
    </div>


<script>
    function  DisplayHostList(self) {
        $(self).next().toggleClass("hide");
    }

    function CheckAll(self){
        console.log($(self).prop('checked'));
        $(self).parent().find("ul :checkbox").prop('checked',$(self).prop('checked'));

        ShowCheckedHostCount()
    }

    function ShowCheckedHostCount(){
        var selected_host_count = $("#host_groups ul").find(":checked").length;
        console.log(selected_host_count);
        $("#selected_hosts").text(selected_host_count);
        return selected_host_count
    }


    function GetTaskResult(task_id) {
        $.getJSON("{% url 'get_task_result' %}",{'task_id':task_id},function(callback){

            console.log(callback);

            var result_ele = '';
            $.each(callback,function (index,i) {
                var p_ele = "<p>" + i.host_user_bind__host__hostname + "(" +i.host_user_bind__host__ip_addr +") ------" +
                    i.status + "</p>";
                var res_ele = "<pre>" + i.result +"</pre>";

                var single_result = p_ele + res_ele;
                result_ele += single_result
            });

            $("#task_result").html(result_ele)


        });//end getJSON

    }


    function  PostTask(task_type) {
        //1. Check that hosts are selected and a command has been entered
        //2. Submit the task to the back end
        var selected_host_ids = [];
        var selected_host_eles = $("#host_groups ul").find(":checked");
        $.each(selected_host_eles,function (index,ele) {
            selected_host_ids.push($(ele).val())
        });
        console.log(selected_host_ids);
        if ( selected_host_ids.length == 0){
            alert("No host selected!");
            return false;
        }
        var cmd_text = $.trim($("#cmd").val());
        if ( cmd_text.length == 0){
            alert("No command entered!");
            return false;

        }


        var task_data = {
            'task_type':task_type,
            'selected_host_ids': selected_host_ids,
            'cmd': cmd_text
        };

        $.post("{% url 'multitask' %}",{'csrfmiddlewaretoken':"{{ csrf_token }}",'task_data':JSON.stringify(task_data)},
            function(callback){
                    console.log(callback) ;// task id
                    var callback = JSON.parse(callback);

                    GetTaskResult(callback.task_id);
                    var result_timer = setInterval(function () {
                        GetTaskResult(callback.task_id)
                    },2000)


            } );//end post

    }
</script>
{% endblock %}
multi_cmd.html

 

2. The front end collects the hosts selected for batch execution and sends them to the back end via AJAX

@login_required
def multitask(request):
    task_obj = task_handler.Task(request)
    respose=HttpResponse(json.dumps(task_obj.errors))
    if task_obj.is_valid():      # validation succeeded
        result = task_obj.run()  # run() selects the task type and dispatches to it via getattr()
        respose=HttpResponse(json.dumps({'task_id':result})) # return the database pk as task_id

    return respose
views.py

 

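3. The back end validates the submitted data with the is_valid method.

 

4. If validation fails, self.errors is returned to the front end; if it succeeds, run() is called to select the task type.

 

5. Once the task type (cmd/files_transfer) is selected, the database is initialized (the Task and TaskLog tables are updated).

 

6. The cmd/files_transfer method starts a new process (multitask_execute.py), and that new process opens a process pool to run the batch commands.

 

7. The front end uses a timer to keep polling the back end for results.

 

8. Button for aborting the run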
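Steps 3 to 5 follow a validate-then-dispatch pattern: collect errors in `is_valid`, then let `run()` pick the handler by name with `getattr()`. The sketch below shows the pattern on its own, without Django or the database. The `TaskSketch` class, its error messages and the tuples returned by the handlers are assumptions made for illustration; the project's own `Task` class in task_handler.py may differ.

```python
import json


class TaskSketch:
    """Validate submitted task data, then dispatch on its task_type."""

    # only these names may be looked up with getattr()
    ALLOWED_TYPES = ("cmd", "files_transfer")

    def __init__(self, raw_task_data):
        self.raw = raw_task_data
        self.errors = []
        self.task_data = None

    def is_valid(self):
        try:
            data = json.loads(self.raw or "")
        except ValueError:
            self.errors.append({"invalid_data": "task_data is not valid JSON"})
            return False
        if not isinstance(data, dict):
            self.errors.append({"invalid_data": "task_data must be an object"})
            return False
        if data.get("task_type") not in self.ALLOWED_TYPES:
            self.errors.append({"invalid_type": "unknown task_type"})
        if not data.get("selected_host_ids"):
            self.errors.append({"invalid_hosts": "no hosts selected"})
        if data.get("task_type") == "cmd" and not data.get("cmd"):
            self.errors.append({"invalid_cmd": "no command given"})
        if self.errors:
            return False
        self.task_data = data
        return True

    def run(self):
        # task_type was checked against ALLOWED_TYPES, so this lookup is safe
        handler = getattr(self, self.task_data["task_type"])
        return handler()

    def cmd(self):
        return ("cmd", self.task_data["selected_host_ids"], self.task_data["cmd"])

    def files_transfer(self):
        return ("files_transfer", self.task_data["selected_host_ids"])
```

Checking `task_type` against a fixed whitelist before calling `getattr()` matters: without it, a crafted request could name any attribute of the object.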

"""
Django settings for zhanggen_audit project.

Generated by 'django-admin startproject' using Django 1.11.4.

For more information on this file, see
https://docs.djangoproject.com/en/1.11/topics/settings/

For the full list of settings and their values, see
https://docs.djangoproject.com/en/1.11/ref/settings/
"""

import os

# Build paths inside the project like this: os.path.join(BASE_DIR, ...)
BASE_DIR = os.path.dirname(os.path.dirname(os.path.abspath(__file__)))


# Quick-start development settings - unsuitable for production
# See https://docs.djangoproject.com/en/1.11/howto/deployment/checklist/

# SECURITY WARNING: keep the secret key used in production secret!
SECRET_KEY = '5ivlngau4a@_3y4vizrcxnnj(&vz2en#edpq%i&jr%99-xxv)&'

# SECURITY WARNING: don't run with debug turned on in production!
DEBUG = True

ALLOWED_HOSTS = ['*']


# Application definition

INSTALLED_APPS = [
    'django.contrib.admin',
    'django.contrib.auth',
    'django.contrib.contenttypes',
    'django.contrib.sessions',
    'django.contrib.messages',
    'django.contrib.staticfiles',
    'audit.apps.AuditConfig',
]

MIDDLEWARE = [
    'django.middleware.security.SecurityMiddleware',
    'django.contrib.sessions.middleware.SessionMiddleware',
    'django.middleware.common.CommonMiddleware',
    'django.middleware.csrf.CsrfViewMiddleware',
    'django.contrib.auth.middleware.AuthenticationMiddleware',
    'django.contrib.messages.middleware.MessageMiddleware',
    'django.middleware.clickjacking.XFrameOptionsMiddleware',
]

ROOT_URLCONF = 'zhanggen_audit.urls'

TEMPLATES = [
    {
        'BACKEND': 'django.template.backends.django.DjangoTemplates',
        'DIRS': [os.path.join(BASE_DIR,  'templates'),],
        'APP_DIRS': True,
        'OPTIONS': {
            'context_processors': [
                'django.template.context_processors.debug',
                'django.template.context_processors.request',
                'django.contrib.auth.context_processors.auth',
                'django.contrib.messages.context_processors.messages',
            ],
        },
    },
]

WSGI_APPLICATION = 'zhanggen_audit.wsgi.application'


# Database
# https://docs.djangoproject.com/en/1.11/ref/settings/#databases

DATABASES = {
    'default': {
        'ENGINE': 'django.db.backends.sqlite3',
        'NAME': os.path.join(BASE_DIR, 'db.sqlite3'),
    }
}


# Password validation
# https://docs.djangoproject.com/en/1.11/ref/settings/#auth-password-validators

AUTH_PASSWORD_VALIDATORS = [
    {
        'NAME': 'django.contrib.auth.password_validation.UserAttributeSimilarityValidator',
    },
    {
        'NAME': 'django.contrib.auth.password_validation.MinimumLengthValidator',
    },
    {
        'NAME': 'django.contrib.auth.password_validation.CommonPasswordValidator',
    },
    {
        'NAME': 'django.contrib.auth.password_validation.NumericPasswordValidator',
    },
]


# Internationalization
# https://docs.djangoproject.com/en/1.11/topics/i18n/

LANGUAGE_CODE = 'en-us'
TIME_ZONE = 'Asia/Shanghai'

USE_I18N = True

USE_L10N = True

USE_TZ = True


# Static files (CSS, JavaScript, Images)
# https://docs.djangoproject.com/en/1.11/howto/static-files/


STATIC_URL = '/static/'
STATICFILES_DIRS=(
os.path.join(BASE_DIR,'static'),
)


SESSION_TRACKER_SCRIPT=os.path.join(BASE_DIR,'audit','backend','session_check.sh')
SESSION_TRACKER_SCRIPT_LOG_PATH=os.path.join(BASE_DIR,'log') # log directory
MULTI_TASK_SCRIPT = os.path.join(BASE_DIR,'multitask_execute.py') # path to the script
 
CURRENT_PGID=None # pgid of the process
settings.py
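`MULTI_TASK_SCRIPT` points at multitask_execute.py, the script that fans one command out to many hosts through a pool. The fan-out itself can be sketched as below. Note the substitutions: this sketch uses a thread pool from `concurrent.futures` rather than the process pool the post describes, and it runs a local command per "host" instead of opening an SSH connection, purely to stay self-contained. `run_one`, `run_batch` and the result keys are hypothetical names.

```python
import subprocess
import sys
from concurrent.futures import ThreadPoolExecutor


def run_one(job):
    """Run one command and report its status and output.

    job is a (host label, argv list) pair; a real worker would connect
    to the host over SSH instead of running argv locally.
    """
    host, argv = job
    try:
        done = subprocess.run(argv, stdout=subprocess.PIPE,
                              stderr=subprocess.STDOUT, timeout=10)
        status = "ok" if done.returncode == 0 else "failed"
        output = done.stdout.decode(errors="replace")
    except subprocess.TimeoutExpired:
        status, output = "timeout", ""
    return {"host": host, "status": status, "result": output}


def run_batch(jobs, max_workers=4):
    """Run all jobs concurrently; results come back in the order of jobs."""
    with ThreadPoolExecutor(max_workers=max_workers) as pool:
        return list(pool.map(run_one, jobs))


if __name__ == "__main__":
    demo_jobs = [
        ("host-a", [sys.executable, "-c", "print('hello from a')"]),
        ("host-b", [sys.executable, "-c", "print('hello from b')"]),
    ]
    for row in run_batch(demo_jobs):
        print(row["host"], row["status"], row["result"].strip())
```

Each result dictionary has the shape a per-host log row needs (host, status, output), which is what the front-end timer polls for.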
"""zhanggen_audit URL Configuration

The `urlpatterns` list routes URLs to views. For more information please see:
    https://docs.djangoproject.com/en/1.11/topics/http/urls/
Examples:
Function views
    1. Add an import:  from my_app import views
    2. Add a URL to urlpatterns:  url(r'^$', views.home, name='home')
Class-based views
    1. Add an import:  from other_app.views import Home
    2. Add a URL to urlpatterns:  url(r'^$', Home.as_view(), name='home')
Including another URLconf
    1. Import the include() function: from django.conf.urls import url, include
    2. Add a URL to urlpatterns:  url(r'^blog/', include('blog.urls'))
"""
from django.conf.urls import url
from django.contrib import admin
from audit import views

urlpatterns = [
    url(r'^admin/', admin.site.urls),
    url(r'^$', views.index ),
    url(r'^login/$', views.acc_login ),
    url(r'^logout/$', views.acc_logout ),
    url(r'^hostlist/$', views.host_list ,name="host_list"),
    url(r'^multitask/$', views.multitask ,name="multitask"),
    url(r'^multitask/result/$', views.multitask_result ,name="get_task_result"),
    url(r'^multitask/cmd/$', views.multi_cmd ,name="multi_cmd"),
    url(r'^multitask/file_transfer/$', views.multi_file_transfer ,name="multi_file_transfer"),
    url(r'^api/hostlist/$', views.get_host_list ,name="get_host_list"),
    url(r'^api/token/$', views.get_token ,name="get_token"),
    url(r'^api/task/file_upload/$', views.task_file_upload ,name="task_file_upload"),
    url(r'^api/task/file_download/$', views.task_file_download ,name="task_file_download"),
    url(r'^end_cmd/$', views.end_cmd,name="end_cmd"),

]
urls.py
from django.shortcuts import render,redirect,HttpResponse
from django.contrib.auth import authenticate,login,logout
from django.contrib.auth.decorators import login_required
from django.views.decorators.csrf import csrf_exempt
from django.conf import settings
import signal

import json,os
from audit import models
import random,string
import datetime
from audit import task_handler
from django import conf
import zipfile
from wsgiref.util import FileWrapper #from django.core.servers.basehttp import FileWrapper

@login_required
def index(request):
    return render(request,'index.html')



def acc_login(request):
    error = ''
    if request.method == "POST":
        username = request.POST.get('username')
        password = request.POST.get('password')
        user = authenticate(username=username,password=password)
        if user:
            login(request, user)
            return  redirect(request.GET.get('next') or  '/')
        else:
            error = "Wrong username or password!"
    return render(request,'login.html',{'error':error })


@login_required
def acc_logout(request):
    logout(request)

    return  redirect('/login/')

@login_required
def host_list(request):

    return render(request,'hostlist.html')


@login_required
def get_host_list(request):
    gid = request.GET.get('gid')
    if gid:
        if gid == '-1':#未分组
            host_list = request.user.account.host_user_binds.all()
        else:
            group_obj = request.user.account.host_groups.get(id=gid)
            host_list = group_obj.host_user_binds.all()

        data = json.dumps(list(host_list.values('id','host__hostname','host__ip_addr','host__idc__name','host__port',
                                'host_user__username')))
        return HttpResponse(data)

@login_required
def get_token(request):
    bind_host_id=request.POST.get('bind_host_id')
    time_obj = datetime.datetime.now() - datetime.timedelta(seconds=300)  # 5mins ago
    exist_token_objs = models.Token.objects.filter(account_id=request.user.account.id,
                                                   host_user_bind_id=bind_host_id,
                                                   date__gt=time_obj)
    if exist_token_objs:  # has token already
        token_data = {'token': exist_token_objs[0].val}
    else:
        token_val=''.join(random.sample(string.ascii_lowercase+string.digits,8))

        token_obj=models.Token.objects.create(
            host_user_bind_id=bind_host_id,
            account=request.user.account,
            val=token_val)
        token_data={"token":token_val}

    return HttpResponse(json.dumps(token_data))



@login_required
def multi_cmd(request):
    """Page for running a command on multiple hosts"""
    return render(request,'multi_cmd.html')


@login_required
def multitask(request):
    task_obj = task_handler.Task(request)
    respose=HttpResponse(json.dumps(task_obj.errors))
    if task_obj.is_valid():      # validation succeeded
        task_obj = task_obj.run()  # run() selects the task type and dispatches to it via getattr()
        respose=HttpResponse(json.dumps({'task_id':task_obj.id,'timeout':task_obj.timeout})) # return the database pk as task_id

    return respose


@login_required
def multitask_result(request):
    """Results of a multi-host task"""
    task_id = request.GET.get('task_id')
    # [ {
    #     'task_log_id':23.
    #     'hostname':
    #     'ipaddr'
    #     'username'
    #     'status'
    # } ]


    task_obj = models.Task.objects.get(id=task_id)

    results = list(task_obj.tasklog_set.values('id','status',
                                'host_user_bind__host__hostname',
                                'host_user_bind__host__ip_addr',
                                'result'
                                ))

    return HttpResponse(json.dumps(results))





@login_required
def multi_file_transfer(request):
    random_str = ''.join(random.sample(string.ascii_lowercase + string.digits, 8))
    #return render(request,'multi_file_transfer.html',{'random_str':random_str})
    return render(request,'multi_file_transfer.html',locals())

@login_required
@csrf_exempt
def task_file_upload(request):
    random_str = request.GET.get('random_str')
    # files are stored under <FILE_UPLOADS>/<account id>/<random_str>/
    upload_to = "%s/%s/%s" %(conf.settings.FILE_UPLOADS,request.user.account.id,random_str)
    os.makedirs(upload_to,exist_ok=True)  # exist_ok makes a separate isdir() check unnecessary

    file_obj = request.FILES.get('file')
    # write chunk by chunk; the with block closes the file even if a write fails
    with open("%s/%s"%(upload_to,file_obj.name),'wb') as f:
        for chunk in file_obj.chunks():
            f.write(chunk)
    print(file_obj)

    return HttpResponse(json.dumps({'status':0}))




def send_zipfile(request,task_id,file_path):
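    """
    Pack every file under file_path into a ZIP archive on disk and return it
    as a download. FileWrapper reads the archive in chunks of 8KB, but a plain
    HttpResponse joins those chunks into one body before sending it; Django's
    StreamingHttpResponse is the class to use if the archive has to be streamed.
    """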
    zip_file_name = 'task_id_%s_files' % task_id
    archive = zipfile.ZipFile(zip_file_name , 'w', zipfile.ZIP_DEFLATED)
    file_list = os.listdir(file_path)
    for filename in file_list:
        archive.write('%s/%s' %(file_path,filename),arcname=filename)
    archive.close()


    wrapper = FileWrapper(open(zip_file_name,'rb'))
    response = HttpResponse(wrapper, content_type='application/zip')
    response['Content-Disposition'] = 'attachment; filename=%s.zip' % zip_file_name
    response['Content-Length'] = os.path.getsize(zip_file_name)
    #temp.seek(0)
    return response

@login_required
def task_file_download(request):
    task_id = request.GET.get('task_id')
    print(task_id)
    task_file_path = "%s/%s"%( conf.settings.FILE_DOWNLOADS,task_id)
    return send_zipfile(request,task_id,task_file_path)


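@login_required
def end_cmd(request):
    """Kill the process group recorded for the running batch task."""
    current_task_pgid=settings.CURRENT_PGID
    if current_task_pgid is None:  # no batch task has recorded a process group yet
        return HttpResponse(json.dumps(None))
    # the worker has to run in its own process group; if it shares the
    # server's group, killpg() takes the Django process down as well
    os.killpg(current_task_pgid,signal.SIGKILL)
    return HttpResponse(current_task_pgid)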
views.py
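The end_cmd view kills a whole process group, which is only safe when the batch script does not share a group with the web server. The following is a minimal standalone sketch of that idea (POSIX only). `start_worker` and `stop_worker` are hypothetical helper names used for illustration, not functions from the project.

```python
import os
import signal
import subprocess
import sys


def start_worker(argv):
    """Start argv as a child that leads its own session and process group."""
    return subprocess.Popen(
        argv,
        stdout=subprocess.DEVNULL,   # nothing reads the output, so do not use a PIPE
        stderr=subprocess.DEVNULL,
        start_new_session=True,      # the child calls setsid() before exec
    )


def stop_worker(pgid):
    """Send SIGKILL to every process in the group pgid."""
    os.killpg(pgid, signal.SIGKILL)
```

A caller would keep `os.getpgid(worker.pid)` after `start_worker(...)` and hand it to `stop_worker` later. Since the child is a session leader, that group id equals the child's pid and differs from the caller's own group.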
import json,subprocess,os,signal
from audit import models
from django.conf import settings
from django.db.transaction import atomic
class Task(object):
    '''Validate a batch task posted by the page and dispatch it by task type.'''
    def __init__(self,request):
        self.request=request
        self.errors=[]
        self.task_data=None

    def is_valid(self):
        '''Return True when the posted task_data is usable; otherwise collect the reasons in self.errors.'''
        task_data=self.request.POST.get('task_data')  # e.g. {"task_type":"cmd","selected_host_ids":["1","2"],"cmd":"DF"}
        if not task_data:
            self.errors.append({'invalid_data': 'task_data is missing!'})
            return False

        self.task_data=json.loads(task_data)
        self.task_type=self.task_data.get('task_type')
        if self.task_type == 'cmd':
            # a command task needs at least one host and a non-empty command
            if self.task_data.get('selected_host_ids') and self.task_data.get('cmd'):
                return True
            self.errors.append({'invalid_argument': 'No command or no host was given'})

        elif self.task_type == 'files_transfer':
            selected_host_ids =self.task_data.get('selected_host_ids')
            pass
            # TODO: validate the file path

        else:
            self.errors.append({'invalid_argument': 'Unsupported task type!'})
        return False

    def run(self):
        task_func = getattr(self, self.task_data.get('task_type'))  # look up the method named after the task type
        task_obj = task_func()  # call it; it returns the new Task row
        print(task_obj.pk)  # the auto-increment task id, e.g. 100
        return task_obj


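    @atomic  # one transaction: the Task row and its TaskLog rows are created together or not at all
    def cmd(self):
        task_obj=models.Task.objects.create(
            task_type=0,
            account=self.request.user.account,
            content=self.task_data.get('cmd'),
        )  # 1. create the batch task row; create() returns the saved object


        tasklog_objs=[]  # 2. create one sub-task row per host (status 3 = not finished yet)
        host_ids = set(self.task_data.get("selected_host_ids"))  # selected ids, de-duplicated with a set
        for host_id in host_ids:
            tasklog_objs.append(models.TaskLog(task_id=task_obj.id,
                               host_user_bind_id=host_id,
                               status = 3))
        models.TaskLog.objects.bulk_create(tasklog_objs,100)  # insert in batches of 100 rows

        task_id=task_obj.pk
        cmd_str = "python3 %s %s" % (settings.MULTI_TASK_SCRIPT,task_id)  # command line that runs multitask_execute.py
        print('------------------>',cmd_str)
        # Note: the worker starts before this transaction commits, so it may not see the
        # new rows yet; django.db.transaction.on_commit() is one way to defer the launch.
        multitask_obj = subprocess.Popen(cmd_str,shell=True,
                                         stdout=subprocess.DEVNULL,  # nothing reads the output; an unread PIPE can fill up and block the worker
                                         stderr=subprocess.DEVNULL,
                                         start_new_session=True)  # own process group, so killpg() in end_cmd does not hit the Django server
        settings.CURRENT_PGID=os.getpgid(multitask_obj.pid)  # remembered so end_cmd can kill the whole group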

        # os.killpg(pgid=pgid,sig=signal.SIGKILL)

        # print(multitask_obj.stderr.read().decode('utf-8') or multitask_obj.stdout.read().decode('utf-8'))
        #print("task result :",multitask_obj.stdout.read().decode('utf-8'),multitask_obj.stderr.read().decode('utf-8'))
        # print(multitask_obj.stdout.read())

        # for host_id in self.task_data.get('selected_host_ids'):
        #     t=Thread(target=self.run_cmd,args=(host_id,self.task_data.get('cmd')))
        #     t.start()

        return task_obj

    def run_cmd(self,host_id,cmd):
        pass

    def files_transfer(self):
        pass
task_handler.py
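The validation rules of is_valid() can also be written as a plain function that is easy to test without Django. The sketch below is one possible shape, not the project's code. `validate_task_data` and `SUPPORTED_TASK_TYPES` are hypothetical names, and the two task types are assumed from the page's `PostTask('cmd')` and `PostTask('file_transfer')` calls.

```python
import json

# assumption: the only task types the page ever posts
SUPPORTED_TASK_TYPES = ("cmd", "file_transfer")


def validate_task_data(raw):
    """Return (task_data, errors); task_data is None when validation fails."""
    if not raw:
        return None, [{"invalid_data": "task_data is missing"}]
    try:
        task_data = json.loads(raw)
    except ValueError:  # json.JSONDecodeError is a subclass of ValueError
        return None, [{"invalid_data": "task_data is not valid JSON"}]
    if not isinstance(task_data, dict):
        return None, [{"invalid_data": "task_data must be a JSON object"}]

    errors = []
    task_type = task_data.get("task_type")
    if task_type not in SUPPORTED_TASK_TYPES:
        errors.append({"invalid_argument": "unsupported task type"})
    if not task_data.get("selected_host_ids"):
        errors.append({"invalid_argument": "no host selected"})
    if task_type == "cmd" and not task_data.get("cmd"):
        errors.append({"invalid_argument": "no command given"})

    return (None, errors) if errors else (task_data, [])
```

Returning the errors together with the data avoids the situation where a caller serializes the error list before validation has run.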
{% extends 'index.html' %}



{% block content-container %}
    {#    {% csrf_token %}#}
    <div id="page-title">
        <h1 class="page-header text-overflow">Host list</h1>

        <!--Searchbox-->
        <div class="searchbox">
            <div class="input-group custom-search-form">
                <input type="text" class="form-control" placeholder="Search..">
                <span class="input-group-btn">
                    <button class="text-muted" type="button"><i class="pli-magnifi-glass"></i></button>
                </span>
            </div>
        </div>
    </div>
    <!--~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~-->
    <!--End page title-->
    <!--Breadcrumb-->
    <!--~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~-->
    <ol class="breadcrumb">
        <li><a href="#">Home</a></li>
        <li><a href="#">Library</a></li>
        <li class="active">Host list</li>
    </ol>
    <!--~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~-->
    <!--End breadcrumb-->

    <div id="page-content">
        <div class="panel col-lg-3">
            <div class="panel-heading">
                <h3 class="panel-title">Host groups <span id="selected_hosts"></span></h3>
            </div>
            <div class="panel-body">

                <ul class="list-group" id="host_groups">
                    {% for group in  request.user.account.host_groups.all %}

                        <li class="list-group-item "><span
                                class="badge badge-success">{{ group.host_user_binds.count }}</span>
                            <input type="checkbox" onclick="CheckAll(this)">
                            <a onclick="DisplayHostList(this)">{{ group.name }}</a>
                            <!-- clicking a group name shows/hides the host list under it via toggleClass -->
                            <ul class="hide">
                                {% for bind_host in group.host_user_binds.all %}
                                    <li><input onclick="ShowCheckedHostCount()" type="checkbox"
                                               value="{{ bind_host.id }}">{{ bind_host.host.ip_addr }}</li>
                                {% endfor %}
                            </ul>
                        </li>

                    {% endfor %}
                    <li class="list-group-item "><span
                            class="badge badge-success">{{ request.user.account.host_user_binds.count }}</span>
                        <input type="checkbox" onclick="CheckAll(this)">
                        <a onclick="DisplayHostList(this)">Ungrouped hosts</a>
                        <ul class="hide">
                            {% for bind_host in request.user.account.host_user_binds.all %}
                                <li><input onclick="ShowCheckedHostCount()" type="checkbox"
                                           value="{{ bind_host.id }}">{{ bind_host.host.ip_addr }}</li>
                            {% endfor %}
                        </ul>
                    </li>

                </ul>


            </div>
        </div>

        <div class="col-lg-9">
            <div class="panel">
                <div class="panel-heading">
                    <h3 class="panel-title">Command</h3>
                </div>
                <div class="panel-body">
                    <textarea class="form-control" id="cmd"></textarea>
                    <button onclick="PostTask('cmd')" class="btn btn-info pull-right">Run</button>
                    <button class="btn btn-danger" onclick="End()">Stop</button>

                </div>

            </div>

            <div id="task_result_panel" class="panel">
                <div class="panel-heading">
                    <h3 class="panel-title">Task results</h3>
                </div>
                <div class="panel-body">
                    <div class="progress">
                        <div id='task_progress' style="width: 0%;" class="progress-bar progress-bar-info"></div>
                    </div>
                    <div id="task_result"></div>

                </div>
            </div>

        </div>


        <script>
            function DisplayHostList(self) {
                $(self).next().toggleClass("hide");
            }

            function CheckAll(self) {
                console.log($(self).prop('checked'));
                $(self).parent().find("ul :checkbox").prop('checked', $(self).prop('checked'));

                ShowCheckedHostCount()
            }

            function ShowCheckedHostCount() {
                var selected_host_count = $("#host_groups ul").find(":checked").length;
                console.log(selected_host_count);
                $("#selected_hosts").text(selected_host_count);
                return selected_host_count
            }


            function GetTaskResult(task_id, task_timeout) {
                $.getJSON("{% url 'get_task_result' %}", {'task_id': task_id}, function (callback) {
                        console.log(callback);
                        var result_ele = '';
                        var all_task_finished = true;   // flag: every sub-task has finished
                        var finished_task_count = 0;   // number of finished sub-tasks
                        $.each(callback, function (index, i) {
                            var p_ele = "<p>" + i.host_user_bind__host__hostname + "(" + i.host_user_bind__host__ip_addr + ") ------" +
                                i.status + "</p>";
                            var res_ele = "<pre>" + i.result + "</pre>"; // <pre> keeps the line breaks of the backend output

                            var single_result = p_ele + res_ele;
                            result_ele += single_result;

                            if (i.status == 3) {
                                // status 3: this sub-task has not finished yet
                                all_task_finished = false;
                            } else {
                                finished_task_count += 1;

                            }

                        });

                        if (task_timeout_counter < task_timeout) {
                            task_timeout_counter += 2;
                        }
                        else {
                            all_task_finished = true
                        }
                        if (all_task_finished) {   // done, or the global timeout was reached

                            clearInterval(result_timer);
                             var unexecuted =callback.length-finished_task_count;
                            $.niftyNoty({   // summary shown when polling stops
                                type: 'danger',
                                container: '#task_result_panel',
                                html: '<h4 id="Prompt">'+'Run: '+callback.length +'  '+ 'Finished: '+finished_task_count+'  '+'Failed: '+ unexecuted +'</h4>',
                                closeBtn: false
                            });
                            console.log("timer canceled....")
                        }
                        $("#task_result").html(result_ele);

                        var total_finished_percent = parseInt(finished_task_count / callback.length * 100);
                        $("#task_progress").text(total_finished_percent + "%");
                        $("#task_progress").css("width", total_finished_percent + "%");


                    }
                )
                ;//end getJSON

            }


            function PostTask(task_type) {
                //1. check that hosts are selected and a command is entered
                //2. submit the task to the backend
                $('.alert').remove();
                var selected_host_ids = [];
                var selected_host_eles = $("#host_groups ul").find(":checked");
                $.each(selected_host_eles, function (index, ele) {
                    selected_host_ids.push($(ele).val())
                });
                console.log(selected_host_ids);
                if (selected_host_ids.length == 0) {
                    alert("No host selected!");
                    return false;
                }
                var cmd_text = $.trim($("#cmd").val());
                if (cmd_text.length == 0) {
                    alert("No command entered!");
                    return false;

                }


                var task_data = {
                    'task_type': task_type,
                    'selected_host_ids': selected_host_ids,
                    'cmd': cmd_text
                };

                $.post("{% url 'multitask' %}", {
                        'csrfmiddlewaretoken': "{{ csrf_token }}",
                        'task_data': JSON.stringify(task_data)
                    },
                    function (callback) {
                        console.log(callback);// task id
                        var callback = JSON.parse(callback);

                        task_timeout_counter = 0;// add 2 during each call of GetTaskResult

                        GetTaskResult(callback.task_id, callback.timeout); // use the batch task id to fetch the sub-task progress; the timeout is the polling limit

                        result_timer = setInterval(function () {
                            GetTaskResult(callback.task_id, callback.timeout)
                        }, 2000);

                        //display download file btn
                        $("#file-download-btn").removeClass("hide").attr('href', "{% url 'task_file_download' %}?task_id=" + callback.task_id);


                    });//end post


            }

            function End(){
                 $.getJSON("{% url 'end_cmd' %}", function (callback) {
                     console.log(callback)
                 })
            }
        </script>
    </div>
{% endblock %}
multi_cmd.html
import time
import sys,os
import multiprocessing
import paramiko

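def cmd_run(tasklog_id,cmd_str):
    """Run cmd_str over SSH on the host of one sub-task and save the output to its TaskLog row."""
    try:
        import django
        django.setup()
        from audit import models
        tasklog_obj = models.TaskLog.objects.get(id=tasklog_id)
        print(tasklog_obj, cmd_str)
        ssh = paramiko.SSHClient()
        ssh.set_missing_host_key_policy(paramiko.AutoAddPolicy())
        ssh.connect(tasklog_obj.host_user_bind.host.ip_addr,
                    tasklog_obj.host_user_bind.host.port,
                    tasklog_obj.host_user_bind.host_user.username,
                    tasklog_obj.host_user_bind.host_user.password,
                    timeout=15)  # connection timeout of 15 seconds
        stdin, stdout, stderr = ssh.exec_command(cmd_str)
        # read() returns bytes; decode before storing the text
        result = (stdout.read() + stderr.read()).decode('utf-8', errors='replace')
        print('---------%s--------' % tasklog_obj.host_user_bind)
        print(result)
        ssh.close()
        tasklog_obj.result = result or 'cmd has no result output.'  # placeholder when the command printed nothing
        tasklog_obj.status = 0
        tasklog_obj.save()
    except Exception as e:
        # the sub-task keeps status 3, so the page counts it as failed once the timeout is reached
        print(e)

def file_transfer(tasklog_id,task_content):
    # not implemented yet; takes the same two arguments that apply_async passes to cmd_run
    pass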


if __name__ == '__main__':
    BASE_DIR = os.path.dirname(os.path.dirname(os.path.dirname(os.path.abspath(__file__))))
    sys.path.append(BASE_DIR)
    os.environ.setdefault("DJANGO_SETTINGS_MODULE", "zhanggen_audit.settings")
    import django
    django.setup()

    from audit import models
    task_id=int(sys.argv[1])
    # 1. load the Task object by its id
    # 2. get every sub-task (one per host) attached to it
    # 3. pick the worker function for the task type and run it in a process pool
    # 4. each worker writes its own result to the TaskLog table when it finishes
    task_obj = models.Task.objects.get(id=task_id)
    pool=multiprocessing.Pool(processes=10)  # a pool of 10 worker processes


    if task_obj.task_type == 0:
        task_func=cmd_run
    else:
        task_func =file_transfer

    for task_log in task_obj.tasklog_set.all():  # one pool job per sub-task
        pool.apply_async(task_func,args=(task_log.pk,task_obj.content))  # pass the sub-task pk and the batch task's content

    pool.close()
    pool.join()
multitask_execute.py
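pool.apply_async() only raises a worker's exception when get() is called on the AsyncResult it returns. The script never calls get(), so a failing host stays silent apart from cmd_run's own print. The sketch below shows one way to collect an outcome per host instead. It deliberately swaps the process pool for a thread pool from concurrent.futures so that it stays self-contained. `run_on_hosts` and `fake_run` are hypothetical names, and `fake_run` only stands in for the paramiko call.

```python
from concurrent.futures import ThreadPoolExecutor, as_completed


def run_on_hosts(hosts, run_one, max_workers=10):
    """Run run_one(host) for every host concurrently; return {host: (ok, output)}."""
    results = {}
    with ThreadPoolExecutor(max_workers=max_workers) as pool:
        futures = {pool.submit(run_one, host): host for host in hosts}
        for future in as_completed(futures):
            host = futures[future]
            try:
                results[host] = (True, future.result())
            except Exception as exc:  # one failing host must not hide the others
                results[host] = (False, str(exc))
    return results


def fake_run(host):
    """Hypothetical stand-in for the SSH call: fails for one address."""
    if host == "10.0.0.2":
        raise ConnectionError("connection timed out")
    return "ok from %s" % host
```

With results in this shape, a failed host could be written to its TaskLog row right away instead of being inferred from the page's timeout.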

 

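 IV. Batch file upload and download through the bastion host

 1. Uploading local files to multiple servers (batch upload)

 

Each visit to the batch upload page gets a fresh unique string (random_str) from the view.

The filedropzone component provides the batch upload UI, limits the file size and the number of files, and sends the unique string along with every file it submits to the backend.

The backend builds the path /<upload root>/<user ID>/<unique string>/<file name> and writes the file there (filedropzone uploads a file automatically as soon as it is dragged in).

When the user clicks Run, the backend checks that this user's upload directory exists on the bastion host, then starts multiple processes that each send the files to the given path on a remote server through paramiko.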
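Both random_str and the file name arrive from the browser, so the upload path described above is worth validating before anything is written to disk. This is a minimal sketch, assuming the unique string is always 8 lowercase letters or digits, as the view generates it. `build_upload_dir` and `safe_file_path` are hypothetical helpers, not functions from the project.

```python
import os
import re


def build_upload_dir(upload_root, account_id, random_str):
    """Return <upload_root>/<account_id>/<random_str>; reject anything but an 8-char token."""
    if not re.fullmatch(r"[a-z0-9]{8}", random_str or ""):
        raise ValueError("invalid random_str: %r" % (random_str,))
    return os.path.join(upload_root, str(int(account_id)), random_str)


def safe_file_path(upload_dir, file_name):
    """Join file_name onto upload_dir, keeping only its last path component."""
    base_name = os.path.basename(file_name.replace("\\", "/"))
    if base_name in ("", ".", ".."):
        raise ValueError("invalid file name: %r" % (file_name,))
    return os.path.join(upload_dir, base_name)
```

A request carrying `random_str=../../etc` is rejected by the first helper, and a file name such as `../../passwd` is reduced to `passwd` by the second.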

 

"""
Django settings for zhanggen_audit project.

Generated by 'django-admin startproject' using Django 1.11.4.

For more information on this file, see
https://docs.djangoproject.com/en/1.11/topics/settings/

For the full list of settings and their values, see
https://docs.djangoproject.com/en/1.11/ref/settings/
"""

import os

# Build paths inside the project like this: os.path.join(BASE_DIR, ...)
BASE_DIR = os.path.dirname(os.path.dirname(os.path.abspath(__file__)))


# Quick-start development settings - unsuitable for production
# See https://docs.djangoproject.com/en/1.11/howto/deployment/checklist/

# SECURITY WARNING: keep the secret key used in production secret!
SECRET_KEY = '5ivlngau4a@_3y4vizrcxnnj(&vz2en#edpq%i&jr%99-xxv)&'

# SECURITY WARNING: don't run with debug turned on in production!
DEBUG = True

ALLOWED_HOSTS = ['*']


# Application definition

INSTALLED_APPS = [
    'django.contrib.admin',
    'django.contrib.auth',
    'django.contrib.contenttypes',
    'django.contrib.sessions',
    'django.contrib.messages',
    'django.contrib.staticfiles',
    'audit.apps.AuditConfig',
]

MIDDLEWARE = [
    'django.middleware.security.SecurityMiddleware',
    'django.contrib.sessions.middleware.SessionMiddleware',
    'django.middleware.common.CommonMiddleware',
    'django.middleware.csrf.CsrfViewMiddleware',
    'django.contrib.auth.middleware.AuthenticationMiddleware',
    'django.contrib.messages.middleware.MessageMiddleware',
    'django.middleware.clickjacking.XFrameOptionsMiddleware',
]

ROOT_URLCONF = 'zhanggen_audit.urls'

TEMPLATES = [
    {
        'BACKEND': 'django.template.backends.django.DjangoTemplates',
        'DIRS': [os.path.join(BASE_DIR,  'templates'),],
        'APP_DIRS': True,
        'OPTIONS': {
            'context_processors': [
                'django.template.context_processors.debug',
                'django.template.context_processors.request',
                'django.contrib.auth.context_processors.auth',
                'django.contrib.messages.context_processors.messages',
            ],
        },
    },
]

WSGI_APPLICATION = 'zhanggen_audit.wsgi.application'


# Database
# https://docs.djangoproject.com/en/1.11/ref/settings/#databases

DATABASES = {
    'default': {
        'ENGINE': 'django.db.backends.sqlite3',
        'NAME': os.path.join(BASE_DIR, 'db.sqlite3'),
    }
}


# Password validation
# https://docs.djangoproject.com/en/1.11/ref/settings/#auth-password-validators

AUTH_PASSWORD_VALIDATORS = [
    {
        'NAME': 'django.contrib.auth.password_validation.UserAttributeSimilarityValidator',
    },
    {
        'NAME': 'django.contrib.auth.password_validation.MinimumLengthValidator',
    },
    {
        'NAME': 'django.contrib.auth.password_validation.CommonPasswordValidator',
    },
    {
        'NAME': 'django.contrib.auth.password_validation.NumericPasswordValidator',
    },
]


# Internationalization
# https://docs.djangoproject.com/en/1.11/topics/i18n/

LANGUAGE_CODE = 'en-us'
TIME_ZONE = 'Asia/Shanghai'

USE_I18N = True

USE_L10N = True

USE_TZ = True


# Static files (CSS, JavaScript, Images)
# https://docs.djangoproject.com/en/1.11/howto/static-files/


STATIC_URL = '/static/'
STATICFILES_DIRS=(
os.path.join(BASE_DIR,'static'),
)


SESSION_TRACKER_SCRIPT=os.path.join(BASE_DIR,'audit','backend','session_check.sh')
SESSION_TRACKER_SCRIPT_LOG_PATH=os.path.join(BASE_DIR,'log')  # log directory
MULTI_TASK_SCRIPT = os.path.join(BASE_DIR,'multitask_execute.py')  # path of the batch task script

CURRENT_PGID=None  # process group id of the running batch task
FILE_UPLOADS = os.path.join(BASE_DIR,'uploads')     # where uploaded files are kept on the bastion host
FILE_DOWNLOADS = os.path.join(BASE_DIR,'downloads') # where downloaded files are kept on the bastion host
Configuring the upload and download paths on the bastion host
<script>
    function  DisplayHostList(self) {
        $(self).next().toggleClass("hide");
    }

    function CheckAll(self){
        console.log($(self).prop('checked'));
        $(self).parent().find("ul :checkbox").prop('checked',$(self).prop('checked'));

        ShowCheckedHostCount()
    }

    function ShowCheckedHostCount(){
        var selected_host_count = $("#host_groups ul").find(":checked").length
        console.log(selected_host_count);
        $("#selected_hosts").text(selected_host_count);
        return selected_host_count
    }


    function GetTaskResult(task_id,task_timeout) {
        $.getJSON("{% url 'get_task_result' %}",{'task_id':task_id},function(callback){

            console.log(callback)

            var result_ele = ''
            var all_task_finished = true
            var finished_task_count = 0 ;
            $.each(callback,function (index,i) {
                var p_ele = "<p>" + i.host_user_bind__host__hostname + "(" +i.host_user_bind__host__ip_addr +") ------" +
                    i.status + "</p>";
                var res_ele = "<pre>" + i.result +"</pre>";

                var single_result = p_ele + res_ele;
                result_ele += single_result;

                //check if this sub task is finished.
                if ( i.status == 3){
                    //status 3: task not finished yet
                    all_task_finished = false;
                }else {
                    finished_task_count += 1;
                }

            });//end each
            //check if the task_timer_count < task_timeout, otherwise it means the task is timedout, setInterval function need to be cancelled
            if (task_timeout_counter < task_timeout){
                // not timed out yet
                task_timeout_counter += 2;

            }else {
                all_task_finished = true; // set all task to be finished ,because it 's already reached the global timeout

                $.niftyNoty({
                    type: 'danger',
                    container : '#task_result_panel',
                    html : '<h4 class="alert-title">Task timed out!</h4><p class="alert-message">The task has timed out!</p><div class="mar-top"><button type="button" class="btn btn-info" data-dismiss="noty">Close this notification</button></div>',
                    closeBtn : false
                });
            }

            if ( all_task_finished){
                clearInterval(result_timer);
                console.log("timer canceled....")
            }


            $("#task_result").html(result_ele);
            // set progress bar
            var total_finished_percent = parseInt(finished_task_count / callback.length * 100 );
            $("#task_progress").text(total_finished_percent+"%");
            $("#task_progress").css("width",total_finished_percent +"%");
        });//end getJSON

    }


    function  PostTask(task_type) {
        //1. check that hosts are selected and a command (or remote path) is entered
        //2. submit the task to the backend
        var selected_host_ids = [];
        var selected_host_eles = $("#host_groups ul").find(":checked")
        $.each(selected_host_eles,function (index,ele) {
            selected_host_ids.push($(ele).val())
        });
        console.log(selected_host_ids)
        if ( selected_host_ids.length == 0){
            alert("No host selected!")
            return false
        }

        if ( task_type == 'cmd'){
            var cmd_text = $.trim($("#cmd").val())
            if ( cmd_text.length == 0){
                alert("No command entered!")
                return false

            }
        }else {
            //file_transfer
            var remote_path = $("#remote_path").val();
            if ($.trim(remote_path).length == 0){
                alert("A remote path is required")
                return false
            }
        }



        var task_data = {
            'task_type':task_type,
            'selected_host_ids': selected_host_ids,
            //'cmd': cmd_text
        };
        if ( task_type == 'cmd'){
            task_data['cmd'] =  cmd_text

        }else {

            var file_transfer_type = $("select[name='transfer-type']").val();
            task_data['file_transfer_type'] = file_transfer_type;
            task_data['random_str'] = "{{ random_str }}";
            task_data['remote_path'] = $("#remote_path").val();


        }


        $.post("{% url 'multitask' %}",{'csrfmiddlewaretoken':"{{ csrf_token }}",'task_data':JSON.stringify(task_data)},
            function(callback){
                    console.log(callback) ;// task id
                    var callback = JSON.parse(callback);

                    GetTaskResult(callback.task_id,callback.timeout);
                    task_timeout_counter = 0; // add 2 during each call of GetTaskResult
                    result_timer = setInterval(function () {
                        GetTaskResult(callback.task_id,callback.timeout)
                    },2000);

                    //display download file btn
                    $("#file-download-btn").removeClass("hide").attr('href', "{% url 'task_file_download' %}?task_id="+callback.task_id);


            } );//end post

    }
</script>
components/multitask_js.html
{% extends 'index.html' %}
{% block extra-css %}
    <link href="/static/plugins/dropzone/dropzone.css" rel="stylesheet">
    <script src="/static/plugins/dropzone/dropzone.js"></script>
{% endblock %}


{% block content-container %}
    <div id="page-title">
        <h1 class="page-header text-overflow">Host list</h1>

        <!--Searchbox-->
        <div class="searchbox">
            <div class="input-group custom-search-form">
                <input type="text" class="form-control" placeholder="Search..">
                <span class="input-group-btn">
                    <button class="text-muted" type="button"><i class="pli-magnifi-glass"></i></button>
                </span>
            </div>
        </div>
    </div>
    <!--~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~-->
    <!--End page title-->
    <!--Breadcrumb-->
    <!--~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~-->
    <ol class="breadcrumb">
        <li><a href="#">Home</a></li>
        <li><a href="#">Library</a></li>
        <li class="active">Host list</li>
    </ol>
    <!--~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~-->
    <!--End breadcrumb-->
    <div id="page-content">
        {% include 'components/hostgroups.html' %}
        <div class="col-lg-9">
            <div class="panel">
                <div class="panel-heading">
                    <h3 class="panel-title">File transfer</h3>
                </div>
                <div class="panel-body">
                    <select name="transfer-type" onchange="ToggleUploadEle(this)">
                        <option value="send">Send files to remote hosts</option>
                        <option value="get">Download files from remote hosts</option>
                    </select>


                    <form id="filedropzone" class="dropzone">

                    </form>
                    {#                    <input type="hidden" value="{{ random_str }}" name="random_str">#}
                    <input id="remote_path" class="form-control" type="text" placeholder="Remote path">

                    <button id="file_count" onclick="PostTask('file_transfer')" class="btn btn-info pull-right">Run</button>
                    <button class="btn btn-danger ">Stop</button>
                    <a id="file-download-btn" class="btn btn-info hide" href="">Download task files to the local machine</a>


                </div>
            </div>
            {% include 'components/taskresult.html' %}
        </div>

    </div>
    </div>

    {% include 'components/multitask_js.html' %}
    <script>

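        Dropzone.autoDiscover = false;  // set before the form is initialised so Dropzone does not attach to it a second time
        $('#filedropzone').dropzone({
            url: "{% url 'task_file_upload' %}?random_str={{ random_str }}", // required
            method: "post",  // "put" also works
            maxFiles: 10, // maximum number of files per upload
            maxFilesize: 2, //MB
            //acceptedFiles: ".jpg,.gif,.png" // restrict the accepted file types
            dictMaxFilesExceeded: "You can upload at most 10 files!",
            dictFileTooBig: "File is too large; the maximum size is 2 MB."
            /*
            init: function () {
                this.on("success", function (file) { // fired after a file has uploaded successfully
                    $('#file_count').attr('file_count')
                });
            }
            */

        });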


        function ToggleUploadEle(self) {

            console.log($(self).val());
            if ($(self).val() == 'get') {
                $(self).next().addClass("hide")
            } else {
                $(self).next().removeClass('hide')
            }

        }

    </script>

{% endblock %}
multi_file_transfer.html
from django.conf.urls import url
from django.contrib import admin
from audit import views

urlpatterns = [
    url(r'^admin/', admin.site.urls),
    url(r'^$', views.index ),
    url(r'^login/$', views.acc_login ),
    url(r'^logout/$', views.acc_logout ),
    url(r'^hostlist/$', views.host_list ,name="host_list"),
    url(r'^multitask/$', views.multitask ,name="multitask"),
    url(r'^multitask/result/$', views.multitask_result ,name="get_task_result"),
    url(r'^multitask/cmd/$', views.multi_cmd ,name="multi_cmd"),
    url(r'^api/hostlist/$', views.get_host_list ,name="get_host_list"),
    url(r'^api/token/$', views.get_token ,name="get_token"),
    url(r'^multitask/file_transfer/$', views.multi_file_transfer, name="multi_file_transfer"),
    url(r'^api/task/file_upload/$', views.task_file_upload ,name="task_file_upload"),
    url(r'^api/task/file_download/$', views.task_file_download ,name="task_file_download"),
    url(r'^end_cmd/$', views.end_cmd,name="end_cmd"),

]
urls.py
import json,subprocess,os,signal
from audit import models
from django.conf import settings
from django.db.transaction import atomic
class Task(object):
    '''Validate a batch task posted by the page and dispatch it by task type.'''
    def __init__(self,request):
        self.request=request
        self.errors=[]
        self.task_data=None

    def is_valid(self):
        '''Return True when the posted task_data is usable; otherwise collect the reasons in self.errors.'''
        task_data=self.request.POST.get('task_data')  # e.g. {"task_type":"cmd","selected_host_ids":["1","2"],"cmd":"DF"}
        if not task_data:
            self.errors.append({'invalid_data': 'task_data is missing!'})
            return False

        self.task_data=json.loads(task_data)
        self.task_type=self.task_data.get('task_type')
        selected_host_ids=self.task_data.get('selected_host_ids')
        if self.task_type == 'cmd':
            # a command task needs at least one host and a non-empty command
            if selected_host_ids and self.task_data.get('cmd'):
                return True
            self.errors.append({'invalid_argument': 'No command or no host was given'})

        elif self.task_type == 'file_transfer':
            # the files must already be in <FILE_UPLOADS>/<account id>/<random_str>
            user_id=models.Account.objects.filter(user=self.request.user).first().pk
            random_str=self.task_data.get('random_str')
            file_path=os.path.join(settings.FILE_UPLOADS,str(user_id),str(random_str))
            if not os.path.isdir(file_path):
                self.errors.append({'invalid_argument': 'Upload directory not found, please upload the files again'})
            if not selected_host_ids:
                self.errors.append({'invalid_argument': 'No remote host selected'})
            return not self.errors

        else:
            self.errors.append({'invalid_argument': 'Unsupported task type!'})
        return False

    def run(self):
        task_func = getattr(self, self.task_data.get('task_type'))  # look up the method named after the task type
        task_obj = task_func()  # call it; it returns the new Task row
        #print(task_obj.pk)  # the auto-increment task id, e.g. 100
        return task_obj


    @atomic  # one transaction: the Task row and its TaskLog rows are created together or not at all
    def cmd(self):
        task_obj=models.Task.objects.create(
            task_type=0,
            account=self.request.user.account,
            content=self.task_data.get('cmd'),
        )  # 1. create the batch task row; create() returns the saved object


        tasklog_objs=[]  # 2. create one sub-task row per host (status 3 = not finished yet)
        host_ids = set(self.task_data.get("selected_host_ids"))  # selected ids, de-duplicated with a set
        for host_id in host_ids:
            tasklog_objs.append(models.TaskLog(task_id=task_obj.id,
                               host_user_bind_id=host_id,
                               status = 3))
        models.TaskLog.objects.bulk_create(tasklog_objs,100)  # insert in batches of 100 rows

        task_id=task_obj.pk
        cmd_str = "python %s %s" % (settings.MULTI_TASK_SCRIPT,task_id)  # command line that runs multitask_execute.py
        print('------------------>',cmd_str)
        # start the worker as a new process; DEVNULL because nothing reads its output (an unread PIPE can fill up and block it)
        multitask_obj = subprocess.Popen(cmd_str,shell=True,stdout=subprocess.DEVNULL,stderr=subprocess.DEVNULL)
        #settings.CURRENT_PGID=os.getpgid(multitask_obj.pid) #os.getpgid(multitask_obj.pid)

        # os.killpg(pgid=pgid,sig=signal.SIGKILL)

        # print(multitask_obj.stderr.read().decode('utf-8') or multitask_obj.stdout.read().decode('utf-8'))
        #print("task result :",multitask_obj.stdout.read().decode('utf-8'),multitask_obj.stderr.read().decode('utf-8'))
        # print(multitask_obj.stdout.read())

        # for host_id in self.task_data.get('selected_host_ids'):
        #     t=Thread(target=self.run_cmd,args=(host_id,self.task_data.get('cmd')))
        #     t.start()

        return task_obj

    @atomic  # Transaction: the task record and its sub-task records must be created together
    def file_transfer(self):
        print(self.task_data) #{'task_type': 'file_transfer', 'selected_host_ids': ['3'], 'file_transfer_type': 'send', 'random_str': 'iuon9bhm', 'remote_path': '/'}
        # 1. Create the batch task record; its pk identifies the task
        task_obj = models.Task.objects.create(
            task_type=1,
            account=self.request.user.account,
            content=json.dumps(self.task_data),
        )

        # 2. Create one sub-task (TaskLog) row per selected host
        tasklog_objs = []
        host_ids = set(self.task_data.get("selected_host_ids"))  # selected host ids, de-duplicated with a set
        for host_id in host_ids:
            tasklog_objs.append(models.TaskLog(task_id=task_obj.id,
                                               host_user_bind_id=host_id,
                                               status=3))
        models.TaskLog.objects.bulk_create(tasklog_objs, 100)  # insert in batches of 100 rows

        task_id = task_obj.pk
        cmd_str = "python %s %s" % (settings.MULTI_TASK_SCRIPT, task_id)  # command line that runs the multitask script
        print('------------------>', cmd_str)
        multitask_obj = subprocess.Popen(cmd_str, shell=True, stdout=subprocess.PIPE,
                                         stderr=subprocess.PIPE)  # start a new process

        return task_obj
task_handler.py
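
The `cmd` and `file_transfer` methods above start the worker script through a shell string and open pipes that are never read. A minimal sketch of an alternative launch, assuming the worker only needs the task id as its single argument: it passes an argument list (no shell) and discards the child's output so the child cannot block on a full pipe. The names `build_worker_command` and `launch_worker` are hypothetical and not part of the project.

```python
import subprocess
import sys


def build_worker_command(script_path, task_id):
    """Return the argument list that runs the worker script for one task."""
    # sys.executable is the interpreter running the current process,
    # which avoids depending on whichever "python" comes first on PATH.
    return [sys.executable, str(script_path), str(task_id)]


def launch_worker(script_path, task_id):
    """Start the worker in a separate process and return the Popen handle."""
    return subprocess.Popen(
        build_worker_command(script_path, task_id),
        stdout=subprocess.DEVNULL,  # output is discarded, so the child never blocks on a full pipe
        stderr=subprocess.DEVNULL,
    )
```

In the handler this could be called as `launch_worker(settings.MULTI_TASK_SCRIPT, task_obj.pk)`. Because the worker reads its rows from the database, it may also be worth starting it only after the surrounding transaction has committed.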
import time,json
import sys,os
import multiprocessing
import paramiko

def cmd_run(tasklog_id,task_obj_id,cmd_str,):
    try:
        import django
        django.setup()
        from audit import models
        tasklog_obj = models.TaskLog.objects.get(id=tasklog_id)
        print(tasklog_obj, cmd_str)
        ssh = paramiko.SSHClient()
        ssh.set_missing_host_key_policy(paramiko.AutoAddPolicy())
        ssh.connect(tasklog_obj.host_user_bind.host.ip_addr,
                    tasklog_obj.host_user_bind.host.port,
                    tasklog_obj.host_user_bind.host_user.username,
                    tasklog_obj.host_user_bind.host_user.password,
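                    timeout=15)  # connection timeout of 15 seconds
        stdin, stdout, stderr = ssh.exec_command(cmd_str)
        # paramiko returns bytes; decode so that text is stored in the database
        result = (stdout.read() + stderr.read()).decode('utf-8', errors='replace')
        print('---------%s--------' % tasklog_obj.host_user_bind)
        print(result)
        ssh.close()
        # Write the sub-task result back to the database
        tasklog_obj.result = result or 'cmd has no result output .'  # fallback when the command printed nothing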
        tasklog_obj.status = 0
        tasklog_obj.save()
    except Exception as e:
        print(e)

def file_transfer(tasklog_id,task_id,task_content):
    import django
    django.setup()
    from django.conf import settings
    from audit import models
    tasklog_obj = models.TaskLog.objects.get(id=tasklog_id)
    try:
        print('task content:', tasklog_obj)
        task_data = json.loads(tasklog_obj.task.content)
        t = paramiko.Transport((tasklog_obj.host_user_bind.host.ip_addr, tasklog_obj.host_user_bind.host.port))
        t.connect(username=tasklog_obj.host_user_bind.host_user.username, password=tasklog_obj.host_user_bind.host_user.password,)
        sftp = paramiko.SFTPClient.from_transport(t)

        if task_data.get('file_transfer_type') =='send':
            local_path = "%s/%s/%s" %( settings.FILE_UPLOADS,
                                       tasklog_obj.task.account.id,
                                       task_data.get('random_str'))
            print("local path",local_path)
            for file_name in os.listdir(local_path):
                sftp.put('%s/%s' %(local_path,file_name), '%s/%s'%(task_data.get('remote_path'), file_name))
            tasklog_obj.result = "send all files done..."

        else:
            # Download the file at remote_path from this host into the task's download directory
            download_dir = "{download_base_dir}/{task_id}".format(download_base_dir=settings.FILE_DOWNLOADS,
                                                                  task_id=task_id)
            os.makedirs(download_dir, exist_ok=True)

            remote_filename = os.path.basename(task_data.get('remote_path'))
            local_path = "%s/%s.%s" %(download_dir,tasklog_obj.host_user_bind.host.ip_addr,remote_filename)
            sftp.get(task_data.get('remote_path'),local_path )
            # e.g. remote path /tmp/test.py
            tasklog_obj.result = 'get remote file [%s] to local done' %(task_data.get('remote_path'))
        t.close()

        tasklog_obj.status = 0
        tasklog_obj.save()
        # ssh = paramiko.SSHClient()
        # ssh.set_missing_host_key_policy(paramiko.AutoAddPolicy())

    except Exception as e:
        print("error :",e )
        tasklog_obj.result = str(e)
        tasklog_obj.save()




if __name__ == '__main__':
    BASE_DIR = os.path.dirname(os.path.dirname(os.path.dirname(os.path.abspath(__file__))))
    sys.path.append(BASE_DIR)
    os.environ.setdefault("DJANGO_SETTINGS_MODULE", "zhanggen_audit.settings")
    import django
    django.setup()

    from audit import models
    task_id = int(sys.argv[1])
    # 1. Load the task object by task id
    # 2. Get every host associated with the task
    # 3. Dispatch to the function that matches the task type, using a process pool
    # 4. Each sub-task writes its own result to the TaskLog table when it finishes
    task_obj = models.Task.objects.get(id=task_id)

    pool = multiprocessing.Pool(processes=10)  # process pool with 10 worker processes

    if task_obj.task_type == 0:
        task_func = cmd_run
    else:
        task_func = file_transfer

    for task_log in task_obj.tasklog_set.all():  # submit every sub-task of this task for execution
        pool.apply_async(task_func, args=(task_log.id, task_obj.id, task_obj.content))  # pass the sub-task pk, the task pk and the task content to a worker process

    pool.close()
    pool.join()
multitask_execute.py
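
The dispatch loop above discards the `AsyncResult` returned by `apply_async`, so an exception raised inside a worker is not visible to the parent process. A minimal sketch of keeping the handles and collecting each outcome follows. It uses `multiprocessing.pool.ThreadPool`, which offers the same `apply_async`/`close`/`join` interface as `multiprocessing.Pool`, only so that the example runs without Django or SSH. `fake_subtask` and `run_all` are hypothetical stand-ins, not project code.

```python
from multiprocessing.pool import ThreadPool


def fake_subtask(tasklog_id):
    """Stand-in for cmd_run/file_transfer: fails for one id to show error capture."""
    if tasklog_id == 2:
        raise RuntimeError("connection refused")
    return "done %d" % tasklog_id


def run_all(tasklog_ids, workers=10):
    """Run every sub-task in a pool and return {tasklog_id: (ok, detail)}."""
    pool = ThreadPool(processes=workers)
    # Keep one AsyncResult per sub-task instead of discarding it.
    handles = {tid: pool.apply_async(fake_subtask, args=(tid,)) for tid in tasklog_ids}
    pool.close()
    pool.join()
    outcomes = {}
    for tid, handle in handles.items():
        try:
            outcomes[tid] = (True, handle.get())  # get() re-raises the worker's exception
        except Exception as exc:
            outcomes[tid] = (False, str(exc))
    return outcomes


if __name__ == "__main__":
    print(run_all([1, 2, 3]))
```

With the outcomes in hand, the parent could mark a sub-task as failed in `TaskLog` even when the worker died before writing anything itself.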

 

 2. Getting files from multiple servers to the local machine (batch download)

 

 

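The user enters the path of the file on the remote servers, and the bastion host generates a local download path (/<download directory>/<task_id>/<ip>.<remote file name>).

Multiple processes are started, and each one uses paramiko to download the file from a remote host into the bastion host's download path.

When the task finishes, the front end shows a "download files to local" button whose link carries the batch task ID as a query parameter (?task_id=...).

When the user clicks that link (an a tag), the back end reads the batch task ID, packs the files downloaded by that task into a zip archive, and returns it to the user's browser.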
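
The download path layout described above can be sketched as a small helper. This is only an illustration: `build_download_path` is a hypothetical name, and the base directory, task id and IP address in the usage line are made-up values.

```python
import posixpath


def build_download_path(download_base_dir, task_id, host_ip, remote_path):
    """Return <download_base_dir>/<task_id>/<host_ip>.<remote file name>."""
    remote_filename = posixpath.basename(remote_path)
    # Prefixing the file name with the host IP keeps files from different hosts apart.
    return posixpath.join(download_base_dir, str(task_id), "%s.%s" % (host_ip, remote_filename))


print(build_download_path("/data/downloads", 100, "10.0.0.5", "/tmp/test.py"))
# prints /data/downloads/100/10.0.0.5.test.py
```

Because every host writes into the same per-task directory, the IP prefix is what prevents one host's copy from overwriting another's.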

 

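# Imports assumed for this snippet (the original listing omits them)
import os
import zipfile
from wsgiref.util import FileWrapper

from django import conf
from django.contrib.auth.decorators import login_required
from django.http import HttpResponse


def send_zipfile(request,task_id,file_path):

    zip_file_name = 'task_id_%s_files' % task_id
    archive = zipfile.ZipFile(zip_file_name , 'w', zipfile.ZIP_DEFLATED)  # create a zip archive (written to the current working directory)

    file_list = os.listdir(file_path)  # list every file in the task's download directory on the bastion host

    for filename in file_list:  # add every file to the zip archive
        archive.write('%s/%s' %(file_path,filename),arcname=filename)
    archive.close()
    #--------------------------------------------------------------  # the archive is complete

    wrapper = FileWrapper(open(zip_file_name,'rb'))  # open the finished archive from disk for streaming

    response = HttpResponse(wrapper, content_type='application/zip')  # set the response content type to zip
    response['Content-Disposition'] = 'attachment; filename=%s.zip' % zip_file_name  # tell the browser to download it as an attachment
    response['Content-Length'] = os.path.getsize(zip_file_name)  # file size in bytes
    #temp.seek(0)
    return response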






@login_required
def task_file_download(request):  # send the task's downloaded files to the user's machine
    task_id = request.GET.get('task_id')
    print(task_id)
    task_file_path = "%s/%s"%( conf.settings.FILE_DOWNLOADS,task_id)
    download_files=os.listdir(task_file_path)
    print(download_files)
    return send_zipfile(request,task_id,task_file_path)  # call the zip helper
Returning a zip archive in a Django response
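
The listing above writes the archive to the current working directory under a fixed name and leaves it there. A minimal sketch of one alternative, assuming the archives are small enough to hold in memory: build the zip in a `BytesIO` buffer and return its bytes. `zip_directory_to_bytes` is a hypothetical helper, not part of the project.

```python
import io
import os
import zipfile


def zip_directory_to_bytes(dir_path):
    """Pack every regular file directly inside dir_path into a zip held in memory."""
    buffer = io.BytesIO()
    with zipfile.ZipFile(buffer, "w", zipfile.ZIP_DEFLATED) as archive:
        for filename in sorted(os.listdir(dir_path)):
            full_path = os.path.join(dir_path, filename)
            if os.path.isfile(full_path):  # skip sub-directories
                archive.write(full_path, arcname=filename)
    # The archive is finalized when the with-block exits, so the bytes are complete here.
    return buffer.getvalue()
```

The returned bytes could then be passed to `HttpResponse(data, content_type='application/zip')` in place of the `FileWrapper`, which avoids leaving a temporary file on the bastion host.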

 

Calling the Salt API from Python

from public import RecordLoggre
import requests

# Suppress the warning that requests emits for unverified HTTPS requests
from requests.packages.urllib3.exceptions import InsecureRequestWarning

requests.packages.urllib3.disable_warnings(InsecureRequestWarning)


class SaltApi:
    """
    Wrapper class for the salt api.
    A token is obtained during initialization.
    """

    def __init__(self):
        self.url = 'https://10.65.0.46:8000'
        self.username = 'salt'
        self.password = 'bFmtHdxSjr2Pm6nd'

        self.headers = {
            "User-Agent": "Mozilla/5.0 (Windows NT 10.0; WOW64) AppleWebKit/537.36 (KHTML, like Gecko) "
                          "Chrome/50.0.2661.102 Safari/537.36",
            "Content-type": "application/json"
        }
        self.record_log = RecordLoggre()
        self.get_token()

    def get_token(self):
        login_url = f'{self.url}/login'
        login_params = {'username': self.username, 'password': self.password, 'eauth': 'pam'}
        response = self.get_data(login_url, login_params)
        token = response.get('return', [dict()])[0].get('token')
        if token:
            self.headers['X-Auth-Token'] = token
        else:
            raise Exception('Failed to obtain the salt token')

    def get_data(self, url, params):
        """POST params to url and return the parsed JSON response (an empty dict on failure)"""
        try:
            request = requests.post(url, json=params, headers=self.headers, verify=False)
            response = request.json()
        except Exception as Error:
            self.record_log.error(str(Error), exc_info=True)
            response = dict()
        return response

    def salt_command(self, tgt, method, **kwargs):
        """Run a command remotely, equivalent to: salt 'client1' cmd.run 'free -m'"""
        params = {'client': 'local', 'fun': method, 'tgt': tgt}
        params.update(kwargs)
        result = self.get_data(self.url, params)
        return result

if __name__ == '__main__':
    # arg = [f'runas={self.user_name} ', command]
    # result = self.salt_api.salt_command(self.host_name, 'cmd.run', arg=arg)
    obj=SaltApi()
    args=['runas=work ', 'ls /']
    res=obj.salt_command(tgt="vm-qa-chatgpt-service001.tx-ap-singapore.apus.com", method='cmd.run', arg=args)
    print(res)
SaltApi.py
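
The class above hardcodes the API address and credentials in `__init__`. A minimal sketch of reading them from the environment instead, together with a standalone version of the request body that `salt_command` builds. The environment variable names and both function names are assumptions made for this illustration.

```python
import os


def load_salt_config(environ=os.environ):
    """Read salt-api connection settings from environment variables (hypothetical names)."""
    keys = ("SALT_API_URL", "SALT_API_USER", "SALT_API_PASSWORD")
    missing = [key for key in keys if key not in environ]
    if missing:
        raise KeyError("missing environment variables: %s" % ", ".join(missing))
    return {
        "url": environ["SALT_API_URL"],
        "username": environ["SALT_API_USER"],
        "password": environ["SALT_API_PASSWORD"],
    }


def build_local_payload(tgt, method, **kwargs):
    """Build the same request body that salt_command sends."""
    params = {"client": "local", "fun": method, "tgt": tgt}
    params.update(kwargs)
    return params
```

`SaltApi.__init__` could then take the dictionary returned by `load_salt_config()` rather than literal values, which keeps the password out of the source file.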

 

3. Architecture overview

 

 

 

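Weakness of the current architecture: multitask starts its worker processes on the bastion host itself, so the number of processes grows as the number of users grows.

Future plan: put a queue between Django and multitask so that many concurrent users can be served.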
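
The queue idea can be illustrated with a small in-process sketch: task ids go into a queue, and a fixed number of workers drain it, so concurrency stays bounded however many tasks are submitted. This is only a stand-in built on the standard library's `queue` and `threading` modules, not the project's design, and `run_batch` is a hypothetical name.

```python
import queue
import threading


def run_batch(task_ids, handler, worker_count=2):
    """Process task_ids with a fixed number of workers; return [(task_id, error message)]."""
    task_queue = queue.Queue()
    errors = []

    def worker():
        while True:
            task_id = task_queue.get()
            if task_id is None:  # sentinel: stop this worker
                task_queue.task_done()
                return
            try:
                handler(task_id)
            except Exception as exc:
                errors.append((task_id, str(exc)))  # record the failure and keep the worker alive
            finally:
                task_queue.task_done()

    threads = [threading.Thread(target=worker, daemon=True) for _ in range(worker_count)]
    for thread in threads:
        thread.start()
    for task_id in task_ids:
        task_queue.put(task_id)
    for _ in threads:
        task_queue.put(None)  # one sentinel per worker
    task_queue.join()
    for thread in threads:
        thread.join()
    return errors
```

The point of the design is that `worker_count` caps the load on the bastion host, while the queue absorbs bursts of submitted tasks.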

 

 

 

GitHub: https://github.com/zhanggen3714/zhanggen_audit

GateOne installation

                                         

 

posted on 2018-07-08 15:23  Martin8866