Skip to content

Remove fabric and sshtunnel #90

New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Merged
merged 2 commits into from
Oct 18, 2023
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension


Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
6 changes: 2 additions & 4 deletions setup.py
Original file line number Diff line number Diff line change
Expand Up @@ -11,9 +11,7 @@
"port-for>=0.4",
"six>=1.9.0",
"psutil",
"packaging",
"fabric",
"sshtunnel"
"packaging"
]

# Add compatibility enum class
Expand All @@ -29,7 +27,7 @@
readme = f.read()

setup(
version='1.9.1',
version='1.9.2',
name='testgres',
packages=['testgres', 'testgres.operations'],
description='Testing utility for PostgreSQL and its extensions',
Expand Down
11 changes: 10 additions & 1 deletion testgres/exceptions.py
Original file line number Diff line number Diff line change
Expand Up @@ -32,7 +32,16 @@ def __str__(self):
if self.out:
msg.append(u'----\n{}'.format(self.out))

return six.text_type('\n').join(msg)
return self.convert_and_join(msg)

@staticmethod
def convert_and_join(msg_list):
# Convert each byte element in the list to str
str_list = [six.text_type(item, 'utf-8') if isinstance(item, bytes) else six.text_type(item) for item in
msg_list]

# Join the list into a single string with the specified delimiter
return six.text_type('\n').join(str_list)


@six.python_2_unicode_compatible
Expand Down
2 changes: 1 addition & 1 deletion testgres/node.py
Original file line number Diff line number Diff line change
Expand Up @@ -1371,7 +1371,7 @@ def pgbench(self,
# should be the last one
_params.append(dbname)

proc = self.os_ops.exec_command(_params, stdout=stdout, stderr=stderr, wait_exit=True, proc=True)
proc = self.os_ops.exec_command(_params, stdout=stdout, stderr=stderr, wait_exit=True, get_process=True)

return proc

Expand Down
13 changes: 9 additions & 4 deletions testgres/operations/local_ops.py
Original file line number Diff line number Diff line change
Expand Up @@ -18,7 +18,6 @@
from distutils.spawn import find_executable
from distutils import rmtree


CMD_TIMEOUT_SEC = 60
error_markers = [b'error', b'Permission denied', b'fatal']

Expand All @@ -37,7 +36,8 @@ def __init__(self, conn_params=None):
# Command execution
def exec_command(self, cmd, wait_exit=False, verbose=False,
expect_error=False, encoding=None, shell=False, text=False,
input=None, stdin=subprocess.PIPE, stdout=subprocess.PIPE, stderr=subprocess.PIPE, proc=None):
input=None, stdin=subprocess.PIPE, stdout=subprocess.PIPE, stderr=subprocess.PIPE,
get_process=None, timeout=None):
"""
Execute a command in a subprocess.

Expand Down Expand Up @@ -69,9 +69,14 @@ def exec_command(self, cmd, wait_exit=False, verbose=False,
stdout=stdout,
stderr=stderr,
)
if proc:
if get_process:
return process
result, error = process.communicate(input)

try:
result, error = process.communicate(input, timeout=timeout)
except subprocess.TimeoutExpired:
process.kill()
raise ExecUtilException("Command timed out after {} seconds.".format(timeout))
exit_status = process.returncode

error_found = exit_status != 0 or any(marker in error for marker in error_markers)
Expand Down
91 changes: 43 additions & 48 deletions testgres/operations/remote_ops.py
Original file line number Diff line number Diff line change
Expand Up @@ -3,17 +3,19 @@
import os
import subprocess
import tempfile
import time

import sshtunnel
# we support both pg8000 and psycopg2
try:
import psycopg2 as pglib
except ImportError:
try:
import pg8000 as pglib
except ImportError:
raise ImportError("You must have psycopg2 or pg8000 modules installed")

from ..exceptions import ExecUtilException

from .os_ops import OsOperations, ConnectionParams
from .os_ops import pglib

sshtunnel.SSH_TIMEOUT = 5.0
sshtunnel.TUNNEL_TIMEOUT = 5.0

ConsoleEncoding = locale.getdefaultlocale()[1]
if not ConsoleEncoding:
Expand Down Expand Up @@ -50,21 +52,28 @@ def __init__(self, conn_params: ConnectionParams):
self.remote = True
self.username = conn_params.username or self.get_user()
self.add_known_host(self.host)
self.tunnel_process = None

def __enter__(self):
return self

def __exit__(self, exc_type, exc_val, exc_tb):
self.close_tunnel()
self.close_ssh_tunnel()

def close_tunnel(self):
if getattr(self, 'tunnel', None):
self.tunnel.stop(force=True)
start_time = time.time()
while self.tunnel.is_active:
if time.time() - start_time > sshtunnel.TUNNEL_TIMEOUT:
break
time.sleep(0.5)
def establish_ssh_tunnel(self, local_port, remote_port):
"""
Establish an SSH tunnel from a local port to a remote PostgreSQL port.
"""
ssh_cmd = ['-N', '-L', f"{local_port}:localhost:{remote_port}"]
self.tunnel_process = self.exec_command(ssh_cmd, get_process=True, timeout=300)

def close_ssh_tunnel(self):
if hasattr(self, 'tunnel_process'):
self.tunnel_process.terminate()
self.tunnel_process.wait()
del self.tunnel_process
else:
print("No active tunnel to close.")

def add_known_host(self, host):
cmd = 'ssh-keyscan -H %s >> /home/%s/.ssh/known_hosts' % (host, os.getlogin())
Expand All @@ -78,21 +87,29 @@ def add_known_host(self, host):
raise ExecUtilException(message="Failed to add %s to known_hosts. Error: %s" % (host, str(e)), command=cmd,
exit_code=e.returncode, out=e.stderr)

def exec_command(self, cmd: str, wait_exit=False, verbose=False, expect_error=False,
def exec_command(self, cmd, wait_exit=False, verbose=False, expect_error=False,
encoding=None, shell=True, text=False, input=None, stdin=None, stdout=None,
stderr=None, proc=None):
stderr=None, get_process=None, timeout=None):
"""
Execute a command in the SSH session.
Args:
- cmd (str): The command to be executed.
"""
ssh_cmd = []
if isinstance(cmd, str):
ssh_cmd = ['ssh', f"{self.username}@{self.host}", '-i', self.ssh_key, cmd]
elif isinstance(cmd, list):
ssh_cmd = ['ssh', f"{self.username}@{self.host}", '-i', self.ssh_key] + cmd
process = subprocess.Popen(ssh_cmd, stdin=subprocess.PIPE, stdout=subprocess.PIPE, stderr=subprocess.PIPE)
if get_process:
return process

try:
result, error = process.communicate(input, timeout=timeout)
except subprocess.TimeoutExpired:
process.kill()
raise ExecUtilException("Command timed out after {} seconds.".format(timeout))

result, error = process.communicate(input)
exit_status = process.returncode

if encoding:
Expand Down Expand Up @@ -372,41 +389,19 @@ def get_process_children(self, pid):
raise ExecUtilException(f"Error in getting process children. Error: {result.stderr}")

# Database control
def db_connect(self, dbname, user, password=None, host="127.0.0.1", port=5432, ssh_key=None):
def db_connect(self, dbname, user, password=None, host="localhost", port=5432):
"""
Connects to a PostgreSQL database on the remote system.
Args:
- dbname (str): The name of the database to connect to.
- user (str): The username for the database connection.
- password (str, optional): The password for the database connection. Defaults to None.
- host (str, optional): The IP address of the remote system. Defaults to "localhost".
- port (int, optional): The port number of the PostgreSQL service. Defaults to 5432.

This function establishes a connection to a PostgreSQL database on the remote system using the specified
parameters. It returns a connection object that can be used to interact with the database.
Established SSH tunnel and Connects to a PostgreSQL
"""
self.close_tunnel()
self.tunnel = sshtunnel.open_tunnel(
(self.host, 22), # Remote server IP and SSH port
ssh_username=self.username,
ssh_pkey=self.ssh_key,
remote_bind_address=(self.host, port), # PostgreSQL server IP and PostgreSQL port
local_bind_address=('localhost', 0)
# Local machine IP and available port (0 means it will pick any available port)
)
self.tunnel.start()

self.establish_ssh_tunnel(local_port=port, remote_port=5432)
try:
# Use localhost and self.tunnel.local_bind_port to connect
conn = pglib.connect(
host='localhost', # Connect to localhost
port=self.tunnel.local_bind_port, # use the local bind port set up by the tunnel
host=host,
port=port,
database=dbname,
user=user or self.username,
password=password
user=user,
password=password,
)

return conn
except Exception as e:
self.tunnel.stop()
raise ExecUtilException("Could not create db tunnel. {}".format(e))
raise Exception(f"Could not connect to the database. Error: {e}")
7 changes: 4 additions & 3 deletions tests/test_simple_remote.py
Original file line number Diff line number Diff line change
Expand Up @@ -735,9 +735,10 @@ def test_pgbench(self):
options=['-q']).pgbench_run(time=2)

# run TPC-B benchmark
out = node.pgbench(stdout=subprocess.PIPE,
stderr=subprocess.STDOUT,
options=['-T3'])
proc = node.pgbench(stdout=subprocess.PIPE,
stderr=subprocess.STDOUT,
options=['-T3'])
out = proc.communicate()[0]
self.assertTrue(b'tps = ' in out)

def test_pg_config(self):
Expand Down