20
20
error_markers = [b'error' , b'Permission denied' ]
21
21
22
22
23
+ class PsUtilProcessProxy :
24
+ def __init__ (self , ssh , pid ):
25
+ self .ssh = ssh
26
+ self .pid = pid
27
+
28
+ def kill (self ):
29
+ command = f"kill { self .pid } "
30
+ self .ssh .exec_command (command )
31
+
32
+ def cmdline (self ):
33
+ command = f"ps -p { self .pid } -o cmd --no-headers"
34
+ stdin , stdout , stderr = self .ssh .exec_command (command )
35
+ cmdline = stdout .read ().decode ('utf-8' ).strip ()
36
+ return cmdline .split ()
37
+
38
+
23
39
class RemoteOperations (OsOperations ):
24
40
def __init__ (self , host = "127.0.0.1" , hostname = 'localhost' , port = None , ssh_key = None , username = None ):
25
41
super ().__init__ (username )
@@ -71,7 +87,7 @@ def exec_command(self, cmd: str, wait_exit=False, verbose=False, expect_error=Fa
71
87
self .ssh = self .ssh_connect ()
72
88
73
89
if isinstance (cmd , list ):
74
- cmd = " " .join (cmd )
90
+ cmd = ' ' .join (item . decode ( 'utf-8' ) if isinstance ( item , bytes ) else item for item in cmd )
75
91
if input :
76
92
stdin , stdout , stderr = self .ssh .exec_command (cmd )
77
93
stdin .write (input )
@@ -140,17 +156,6 @@ def is_executable(self, file):
140
156
is_exec = self .exec_command (f"test -x { file } && echo OK" )
141
157
return is_exec == b"OK\n "
142
158
143
- def add_to_path (self , new_path ):
144
- pathsep = self .pathsep
145
- # Check if the directory is already in PATH
146
- path = self .environ ("PATH" )
147
- if new_path not in path .split (pathsep ):
148
- if self .remote :
149
- self .exec_command (f"export PATH={ new_path } { pathsep } { path } " )
150
- else :
151
- os .environ ["PATH" ] = f"{ new_path } { pathsep } { path } "
152
- return pathsep
153
-
154
159
def set_env (self , var_name : str , var_val : str ):
155
160
"""
156
161
Set the value of an environment variable.
@@ -243,9 +248,17 @@ def mkdtemp(self, prefix=None):
243
248
raise ExecUtilException ("Could not create temporary directory." )
244
249
245
250
def mkstemp (self , prefix = None ):
246
- cmd = f"mktemp { prefix } XXXXXX"
247
- filename = self .exec_command (cmd ).strip ()
248
- return filename
251
+ if prefix :
252
+ temp_dir = self .exec_command (f"mktemp { prefix } XXXXX" , encoding = 'utf-8' )
253
+ else :
254
+ temp_dir = self .exec_command ("mktemp" , encoding = 'utf-8' )
255
+
256
+ if temp_dir :
257
+ if not os .path .isabs (temp_dir ):
258
+ temp_dir = os .path .join ('/home' , self .username , temp_dir .strip ())
259
+ return temp_dir
260
+ else :
261
+ raise ExecUtilException ("Could not create temporary directory." )
249
262
250
263
def copytree (self , src , dst ):
251
264
if not os .path .isabs (dst ):
@@ -291,7 +304,7 @@ def write(self, filename, data, truncate=False, binary=False, read_and_write=Fal
291
304
data = data .encode (encoding )
292
305
if isinstance (data , list ):
293
306
# ensure each line ends with a newline
294
- data = [s if s . endswith ( ' \n ' ) else s + '\n ' for s in data ]
307
+ data = [( s if isinstance ( s , str ) else s . decode ( 'utf-8' )). rstrip ( ' \n ' ) + '\n ' for s in data ]
295
308
tmp_file .writelines (data )
296
309
else :
297
310
tmp_file .write (data )
@@ -351,8 +364,8 @@ def isfile(self, remote_file):
351
364
352
365
def isdir (self , dirname ):
353
366
cmd = f"if [ -d { dirname } ]; then echo True; else echo False; fi"
354
- response = self .exec_command (cmd , encoding = 'utf-8' )
355
- return response .strip () == "True"
367
+ response = self .exec_command (cmd )
368
+ return response .strip () == b "True"
356
369
357
370
def remove_file (self , filename ):
358
371
cmd = f"rm { filename } "
@@ -366,16 +379,16 @@ def kill(self, pid, signal):
366
379
367
380
def get_pid (self ):
368
381
# Get current process id
369
- return self .exec_command ("echo $$" )
382
+ return int ( self .exec_command ("echo $$" , encoding = 'utf-8' ) )
370
383
371
384
def get_remote_children (self , pid ):
372
385
command = f"pgrep -P { pid } "
373
386
stdin , stdout , stderr = self .ssh .exec_command (command )
374
387
children = stdout .readlines ()
375
- return [int (child_pid .strip ()) for child_pid in children ]
388
+ return [PsUtilProcessProxy ( self . ssh , int (child_pid .strip () )) for child_pid in children ]
376
389
377
390
# Database control
378
- def db_connect (self , dbname , user , password = None , host = "127.0.0.1" , port = 5432 ):
391
+ def db_connect (self , dbname , user , password = None , host = "127.0.0.1" , port = 5432 , ssh_key = None ):
379
392
"""
380
393
Connects to a PostgreSQL database on the remote system.
381
394
Args:
@@ -389,19 +402,26 @@ def db_connect(self, dbname, user, password=None, host="127.0.0.1", port=5432):
389
402
This function establishes a connection to a PostgreSQL database on the remote system using the specified
390
403
parameters. It returns a connection object that can be used to interact with the database.
391
404
"""
392
- with sshtunnel .open_tunnel (
393
- (host , 22 ), # Remote server IP and SSH port
394
- ssh_username = self .username ,
395
- ssh_pkey = self .ssh_key ,
396
- remote_bind_address = (host , port ), # PostgreSQL server IP and PostgreSQL port
397
- local_bind_address = ('localhost' , port ), # Local machine IP and available port
398
- ):
405
+ tunnel = sshtunnel .open_tunnel (
406
+ (host , 22 ), # Remote server IP and SSH port
407
+ ssh_username = user or self .username ,
408
+ ssh_pkey = ssh_key or self .ssh_key ,
409
+ remote_bind_address = (host , port ), # PostgreSQL server IP and PostgreSQL port
410
+ local_bind_address = ('localhost' , port ) # Local machine IP and available port
411
+ )
412
+
413
+ tunnel .start ()
414
+
415
+ try :
399
416
conn = pglib .connect (
400
- host = host ,
401
- port = port ,
417
+ host = host , # change to 'localhost' because we're connecting through a local ssh tunnel
418
+ port = tunnel . local_bind_port , # use the local bind port set up by the tunnel
402
419
dbname = dbname ,
403
- user = user ,
420
+ user = user or self . username ,
404
421
password = password
405
422
)
406
423
407
- return conn
424
+ return conn
425
+ except Exception as e :
426
+ tunnel .stop ()
427
+ raise e
0 commit comments