# -*- coding: utf-8 -*-
# !/usr/bin/env python
# Software: PyCharm
# __author__ == "YU HAIPENG"
# fileName: sshproxy.py
# Month: 九月
# time: 2020/9/26 20:43
# noqa
"""
# 对于更多限制命令,需要在系统中设置
/etc/sudoers

Defaults    requiretty
Defaults:cmdb    !requiretty

"""
import os
import platform
import paramiko
from re import compile

# Names exported by ``from <this module> import *``. The path helpers
# (main_path and the underscore-prefixed functions) are not exported.
__all__ = ["SSHProxy", "get_sys", "get_home"]

# Matches either path separator ("/" or "\") so Windows- and POSIX-style
# paths are split the same way. There is deliberately no "|" in the class:
# inside [...] it would match a literal pipe, and "|" is a legal character
# in POSIX file names.
split_reg = compile(r'[\\/]')


class SSHProxy(object):
    """
    Small SSH helper on top of a paramiko ``Transport``.

    Runs remote commands (:meth:`command`) and transfers files over SFTP
    (:meth:`upload` / :meth:`download`). Use it as a context manager, or call
    :meth:`open` and :meth:`close` yourself.
    """

    def __init__(
            self,
            hostname,
            username="root",
            port=22,
            private_key_path=None,
            password=None):
        """
        @param hostname: remote host name or IP address
        @param username: login user name
        @param port: SSH port
        @param private_key_path: optional path to an RSA private key; it is
                                 normalized with ``main_path``
        @param password: optional login password; also used as a fallback when
                         key authentication is rejected
        """
        self.hostname = hostname
        self.port = port
        self.username = username
        self.private_key_path = main_path(private_key_path)
        self.password = password
        self.transport = None
        # SSHClient wrapper around self.transport, created lazily by command()
        self.__ssh = None

    def open(self):
        """
        Connect to the server and authenticate.

        Authentication is tried in this order:
          1. the RSA key at ``private_key_path`` (if set); if the server
             rejects it and ``password`` is set, password auth is tried next;
          2. ``password`` (if set);
          3. ``~/.ssh/id_rsa`` (if that file exists).

        If anything fails after the transport is created, the transport is
        closed again before the exception propagates, so the socket and its
        background thread are not leaked.

        @raise paramiko.ssh_exception.AuthenticationException: when
               authentication fails or no credentials are available
        """
        self.transport = paramiko.Transport((self.hostname, self.port))
        try:
            self._authenticate()
        except BaseException:
            # __exit__ is not run when __enter__ raises, so release the
            # transport here before re-raising.
            self.close()
            raise

    def _authenticate(self):
        """Run the authentication chain described in :meth:`open`."""
        # NOTE(review): Transport.connect() is called without ``hostkey``, so
        # the server's host key is never verified.
        if self.private_key_path:
            private_key = paramiko.RSAKey.from_private_key_file(
                self.private_key_path)
            try:
                self.transport.connect(username=self.username, pkey=private_key)
            except paramiko.ssh_exception.AuthenticationException:
                if not self.password:
                    raise
                # connect() has already started the session; retry on the
                # same transport with the password.
                self.transport.auth_password(username=self.username, password=self.password)
            return
        if self.password:
            self.transport.connect(username=self.username, password=self.password)
            return
        id_rsa = os.path.join(get_home(), ".ssh", "id_rsa")
        if os.path.isfile(id_rsa):
            private_key = paramiko.RSAKey.from_private_key_file(id_rsa)
            self.transport.connect(username=self.username, pkey=private_key)
            return
        raise paramiko.ssh_exception.AuthenticationException("authentication failed")

    def close(self):
        """Close the transport. Safe to call more than once or before open()."""
        if self.transport is not None:
            self.transport.close()
        self.transport = None
        self.__ssh = None

    def command(self,
                cmd,
                buf_size=-1,
                timeout=None,
                get_pty=False,
                environment=None,
                exec_source=False,
                ):
        """
        Run ``cmd`` on the remote host and wait for its output.

        @param cmd: command line to execute remotely
        @param buf_size: passed to ``exec_command`` as ``bufsize``
        @param timeout: channel timeout in seconds, or None to block
        @param get_pty: request a pseudo-terminal for the command
        @param environment: dict of environment variables to send with the
                            command (the server may refuse to apply them)
        @param exec_source: prefix the command with ``source /etc/profile;``
        @return: ``{"stdout": str, "stderr": str}``
        """
        if self.__ssh is None:
            ssh = paramiko.SSHClient()
            # Missing-host-key policies (translated from the original notes):
            #   1. AutoAddPolicy: automatically add unknown host names and keys
            #      to the local HostKeys object, independent of
            #      load_system_host_keys; no interactive yes/no confirmation.
            #   2. WarningPolicy: log a Python warning for an unknown host key
            #      and accept it; like AutoAddPolicy, but reports new hosts.
            #   3. RejectPolicy: reject unknown host names and keys, relying on
            #      load_system_host_keys; this is paramiko's default.
            # NOTE(review): the policy is consulted by SSHClient.connect(),
            # which is never called here (the transport is injected below), so
            # this setting likely has no effect. Confirm before relying on it.
            ssh.set_missing_host_key_policy(paramiko.AutoAddPolicy())
            # Reuse the already-authenticated transport. _transport is a
            # private paramiko attribute that exec_command() opens channels on.
            ssh._transport = self.transport
            self.__ssh = ssh
        if exec_source:
            # Load /etc/profile first, e.g. for PATH settings that a
            # non-login exec session may otherwise lack.
            cmd = f"source /etc/profile;{cmd}"
        stdin, stdout, stderr = self.__ssh.exec_command(
            cmd, bufsize=buf_size,
            timeout=timeout,
            get_pty=get_pty,
            environment=environment)
        # NOTE(review): stdout is drained completely before stderr is read; a
        # command that writes a large amount to stderr first could stall.
        stdout_str = self.bytes_to_string(stdout.read())
        stderr_str = self.bytes_to_string(stderr.read())
        # The cached SSHClient stays open for later commands; close() drops it
        # together with the transport.
        return {"stdout": stdout_str, "stderr": stderr_str}

    @staticmethod
    def bytes_to_string(bytes_info, errors="replace"):
        """
        Decode ``bytes`` as UTF-8; any other value is returned unchanged.

        @param bytes_info: value to convert
        @param errors: codec error handler. The default ``"replace"`` turns
                       invalid byte sequences into U+FFFD instead of raising,
                       so non-UTF-8 command output cannot abort command().
        @return: the decoded string, or ``bytes_info`` itself if not bytes
        """
        if isinstance(bytes_info, bytes):
            return str(bytes_info, 'utf8', errors)
        return bytes_info

    def upload(self, local_path, remote_path, callback=None, confirm=True):
        """
        Upload a local file over SFTP.

        @param local_path: local source file, normalized with ``main_path``
        @param remote_path: destination path on the remote host
        @param callback: optional callback of the form ``func(int, int)``,
                         receiving the bytes transferred so far and the
                         total bytes to transfer
        @param confirm: whether to stat() the remote file afterwards to
                        confirm the upload
        @return: None
        """
        sftp = paramiko.SFTPClient.from_transport(self.transport)
        try:
            sftp.put(main_path(local_path), remote_path, callback, confirm)
        finally:
            sftp.close()

    def download(self, remote_path, local_path, callback=None):
        """
        Download a remote file over SFTP.

        @param remote_path: source path on the remote host
        @param local_path: local destination, normalized with ``main_path``
        @param callback: optional callback of the form ``func(int, int)``,
                         receiving the bytes transferred so far and the
                         total bytes to transfer
        @return: None
        """
        sftp = paramiko.SFTPClient.from_transport(self.transport)
        try:
            sftp.get(remote_path, main_path(local_path), callback)
        finally:
            sftp.close()

    def __enter__(self):
        self.open()
        return self

    def __exit__(self, exc_type, exc_val, exc_tb):
        self.close()


def main_path(path: str):
    """
    Normalize a user-supplied local path.

    - ``None`` is returned unchanged.
    - A leading ``~`` / ``~user`` is expanded to the home directory.
    - ``..``-prefixed paths are resolved against the current working directory.
    - ``.`` and ``./...`` (or ``.\\...``) are joined under the current
      working directory.
    - Anything else, including absolute paths, plain relative paths and
      hidden names such as ``.ssh/id_rsa``, has its separators normalized.

    @param path: path string, or None
    @return: the normalized path, or None
    @raise ValueError: if ``path`` is an empty string
    """
    if path is None:
        return path
    # Without this, "~/.ssh/id_rsa" would be passed on literally and never
    # found on disk.
    path = os.path.expanduser(path)
    current_path = os.getcwd()
    if path.startswith('..'):
        path = _wne_path(_parse(path, current_path))
    elif path == '.' or path[:2] in ('./', '.\\'):
        # Only a bare "." or "./" prefix means "current directory"; a hidden
        # name like ".ssh" must keep its leading dot.
        path = _wne_path(path[1:], current_path)
    elif path:
        path = _wne_path(path)
    else:
        raise ValueError('文件路径错误')
    return path


def _parse(path: str, current_path):
    """
    Resolve the leading ``..`` components of ``path`` against ``current_path``.

    Empty components (from doubled or trailing separators) are dropped
    first. Each leading ``..`` then moves ``current_path`` up one level via
    ``os.path.dirname``; the remaining components are appended unchanged, so
    a ``..`` that follows a regular component is left in place.

    :param path: relative path that starts with ``..``
    :param current_path: directory to resolve against
    :return: the joined path
    """
    parts = [part for part in _get_path_params(path) if part]
    skip = 0
    while skip < len(parts) and parts[skip] == '..':
        current_path = os.path.dirname(current_path)
        skip += 1
    return os.path.join(current_path, *parts[skip:])


def _wne_path(new_path: str, current_path=None):
    """
    Rebuild ``new_path`` using the native path separator.

    @param new_path: path whose components may be separated by ``/`` or ``\\``
    @param current_path: when truthy, the components are joined under it
    @return: the rebuilt path; on a platform other than Windows/Linux/macOS
             the input is returned unchanged
    """
    new_path_args = _get_path_params(new_path)
    if current_path:
        return os.path.join(current_path, *new_path_args)
    sys_str = get_sys()
    if sys_str == "Windows":
        head = new_path_args[0]
        if ':' in head:
            # "C:" alone is drive-relative to os.path.join; make it "C:\".
            new_path_args[0] = head + os.sep
        elif head == '':
            # Leading separator (\foo or /foo): keep the path rooted rather
            # than silently turning it into the relative "foo".
            # NOTE(review): UNC paths (\\server\share) are still not preserved.
            new_path_args[0] = os.sep
        return os.path.join(*new_path_args)
    if sys_str in ["Linux", "Mac", "Darwin"]:
        if new_path_args[0] == '':
            # Absolute POSIX path: the split left an empty first component.
            new_path_args[0] = os.sep
        return os.path.join(*new_path_args)
    return new_path


def get_sys():
    """
    Name of the operating system this interpreter runs on.

    @return: the value of ``platform.system()``, e.g. "Linux", "Windows" or
             "Darwin"
    """
    return platform.system()


def get_home(path='~'):
    """
    Expand a leading ``~`` or ``~user`` in ``path``.

    Called with no argument, this returns the current user's home directory.

    @param path: path that may start with ``~``
    @return: the expanded path (unchanged when there is nothing to expand)
    """
    expanded = os.path.expanduser(path)
    return expanded


def _get_path_params(path):
    """
    Split a path string into its components.

    The split uses the module-level ``split_reg`` pattern. A leading or
    doubled separator yields empty-string components; for example, an
    absolute POSIX path starts with ``''``, which the callers check for.

    @param path: path string
    @return: list of components
    """
    components = split_reg.split(path)
    return components


if __name__ == '__main__':
    # Manual smoke test: run a command on a host and print its stdout.
    # Usage: python sshproxy.py [hostname] [command]
    # With no password or key path given, SSHProxy.open() falls back to
    # ~/.ssh/id_rsa. Never commit passwords or private key paths here.
    import sys

    _host = sys.argv[1] if len(sys.argv) > 1 else "192.168.1.242"
    _cmd = sys.argv[2] if len(sys.argv) > 2 else "bash cat_nginx_acclog.sh"
    with SSHProxy(hostname=_host) as ssh:
        res = ssh.command(_cmd)
        print(res['stdout'])