# -*- coding:utf-8 -*-
import sys
from multiprocessing import Process, Value as PValue, current_process
from os.path import getsize
from threading import Thread
from paramiko import SSHClient, AutoAddPolicy
from hypernets.utils import logging
from hypernets.utils.common import Counter
# Module-level logger, obtained through the hypernets logging wrapper imported above.
logger = logging.get_logger(__name__)
class DumpFileThread(Thread):
    """Thread that copies everything readable from one file-like object to another.

    The input handle is read in chunks of ``buf_size`` until ``read()`` returns
    an empty/falsy value; every chunk is passed to the output handle's
    ``write()``.  Neither handle is closed by this class.

    Parameters
    ----------
    in_file_handle : file-like
        Object providing ``read(size)``.
    out_file_handle : file-like
        Object providing ``write(data)``.
    buf_size : int, default 16
        Size argument passed to every ``read()`` call.
    """

    # Shared by all instances; its value becomes the last part of each thread name.
    counter = Counter()

    def __init__(self, in_file_handle, out_file_handle, buf_size=16):
        super(DumpFileThread, self).__init__()
        # Sanity check only: it is skipped when Python runs with -O.
        assert in_file_handle and out_file_handle

        # The owning process id is part of the name so that threads started by
        # different processes can be told apart.
        self.name = f'{self.__class__.__name__}-{current_process().pid}-{self.counter()}'
        self.in_file_handle = in_file_handle
        self.out_file_handle = out_file_handle
        self.buf_size = buf_size

    def run(self):
        """Copy the input handle to the output handle until the input is exhausted.

        If the input yields ``bytes`` while the output is a text stream (for
        example ``sys.stdout``), the bytes are decoded before being written,
        because a text stream rejects ``bytes`` with ``TypeError``.  Decoding is
        incremental so that a multi-byte character split across two chunks is
        still decoded correctly; undecodable bytes are replaced rather than
        raising.
        """
        # Local imports keep this class self-contained.
        import codecs
        import io

        source = self.in_file_handle
        sink = self.out_file_handle
        sink_is_text = isinstance(sink, io.TextIOBase)
        decoder = None  # created lazily, only if bytes meet a text sink

        chunk = source.read(self.buf_size)
        while chunk:
            if sink_is_text and isinstance(chunk, bytes):
                if decoder is None:
                    encoding = getattr(sink, 'encoding', None) or 'utf-8'
                    decoder = codecs.getincrementaldecoder(encoding)(errors='replace')
                sink.write(decoder.decode(chunk))
            else:
                sink.write(chunk)
            chunk = source.read(self.buf_size)

        if decoder is not None:
            # Flush whatever the decoder still holds (e.g. a truncated character).
            tail = decoder.decode(b'', final=True)
            if tail:
                sink.write(tail)
class SshProcess(Process):
    """Process that executes one command on a remote host over SSH.

    The command's exit status is stored in a shared integer so that it can be
    read through :attr:`exitcode` from the process that created this object.

    Parameters
    ----------
    ssh_host : str
        Host passed to ``SSHClient.connect``.
    ssh_port : int
        Port passed to ``SSHClient.connect``.
    cmd : str
        Command passed to ``SSHClient.exec_command``.
    in_file : str or None
        Path of a local file whose content is sent to the command's stdin.
        Skipped when falsy or when the file is empty.
    out_file, err_file : str or None
        Local paths receiving the command's stdout / stderr.  If either one is
        falsy, both streams go to ``sys.stdout`` / ``sys.stderr`` instead.
    environment : dict or None
        Passed through to ``exec_command`` as ``environment``.
    """

    def __init__(self, ssh_host, ssh_port, cmd, in_file, out_file, err_file, environment=None):
        super(SshProcess, self).__init__()
        self.ssh_host = ssh_host
        self.ssh_port = ssh_port
        self.cmd = cmd
        self.in_file = in_file
        self.out_file = out_file
        self.err_file = err_file
        self.environment = environment
        # Shared between parent and child; -1 means "no exit status recorded".
        self._exit_code = PValue('i', -1)

    def run(self, verbose=False):
        """Run the remote command and record its exit status.

        ``Process.start()`` calls this method without arguments, so ``verbose``
        only takes effect when ``run`` is called directly.

        A ``KeyboardInterrupt`` is recorded as exit status 137.  Any other
        exception is recorded as exit status 1 and re-raised; without this the
        shared value would stay at -1 and :attr:`exitcode` would keep reporting
        ``None`` after the failure.
        """
        if verbose and logger.is_info_enabled():
            logger.info(f'[{self.name}] [SSH {self.ssh_host}]: {self.cmd}')

        try:
            code = self.ssh_run(self.ssh_host, self.ssh_port,
                                self.cmd,
                                self.in_file,
                                self.out_file,
                                self.err_file,
                                self.environment)
        except KeyboardInterrupt:
            code = 137
        except Exception:
            # 1 is what multiprocessing itself reports for an uncaught exception.
            self._exit_code.value = 1
            raise

        if verbose and logger.is_info_enabled():
            logger.info(f'[{self.name}] [SSH {self.ssh_host}] {self.cmd} done with {code}')
        self._exit_code.value = code

    @staticmethod
    def ssh_run(ssh_host, ssh_port, cmd, in_file, out_file, err_file, environment):
        """Connect to the host, execute ``cmd`` and return its exit status.

        Unknown host keys are accepted automatically (``AutoAddPolicy``).
        The call returns after both remote output streams have been fully
        copied and the exit status has been received.

        Returns
        -------
        int
            The value of ``channel.recv_exit_status()``.
        """
        with SSHClient() as ssh:
            ssh.set_missing_host_key_policy(AutoAddPolicy())
            ssh.connect(ssh_host, ssh_port)
            stdin, stdout, stderr = ssh.exec_command(cmd, bufsize=10, environment=environment)

            if in_file and getsize(in_file) > 0:
                with open(in_file, 'rb') as f:
                    payload = f.read()
                stdin.write(payload)
                stdin.flush()
                # Signal end-of-input; otherwise a command that reads stdin
                # until EOF would never terminate.
                stdin.channel.shutdown_write()

            channel = stdout.channel

            def pump(out_handle, err_handle):
                # Copy both remote streams concurrently and wait for both.
                workers = [DumpFileThread(stdout, out_handle),
                           DumpFileThread(stderr, err_handle)]
                for worker in workers:
                    worker.start()
                for worker in workers:
                    worker.join()

            if out_file and err_file:
                # buffering=0 makes each chunk reach the file immediately.
                with open(out_file, 'wb', buffering=0) as o, \
                        open(err_file, 'wb', buffering=0) as e:
                    pump(o, e)
            else:
                pump(sys.stdout, sys.stderr)

            # recv_exit_status() waits for the status itself, so no separate
            # readiness assertion is needed (one could fail spuriously when
            # EOF on the streams is seen before the status arrives).
            return channel.recv_exit_status()

    @property
    def exitcode(self):
        """Recorded exit status of the remote command, or ``None``.

        ``None`` is returned while the stored value is negative, i.e. before
        any status has been recorded, or when the recorded status is itself
        negative.
        """
        code = self._exit_code.value
        return code if code >= 0 else None