Skip to content

Commit b438515

Browse files
author
vshepard
committed
Add ability to skip ssl when connect to PostgresNode
1 parent 569923a commit b438515

File tree

8 files changed

+106
-99
lines changed

8 files changed

+106
-99
lines changed

testgres/api.py

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -42,7 +42,7 @@ def get_new_node(name=None, base_dir=None, **kwargs):
4242
return PostgresNode(name=name, base_dir=base_dir, **kwargs)
4343

4444

45-
def get_remote_node(name=None, conn_params=None):
45+
def get_remote_node(name=None):
4646
"""
4747
Simply a wrapper around :class:`.PostgresNode` constructor for remote node.
4848
See :meth:`.PostgresNode.__init__` for details.
@@ -51,4 +51,4 @@ def get_remote_node(name=None, conn_params=None):
5151
ssh_key=None,
5252
username=default_username())
5353
"""
54-
return get_new_node(name=name, conn_params=conn_params)
54+
return get_new_node(name=name)

testgres/node.py

Lines changed: 5 additions & 8 deletions
Original file line numberDiff line numberDiff line change
@@ -126,7 +126,8 @@ def __repr__(self):
126126

127127

128128
class PostgresNode(object):
129-
def __init__(self, name=None, base_dir=None, port=None, conn_params: ConnectionParams = ConnectionParams(), bin_dir=None, prefix=None):
129+
def __init__(self, name=None, base_dir=None, port=None, conn_params: ConnectionParams = ConnectionParams(),
130+
bin_dir=None, prefix=None):
130131
"""
131132
PostgresNode constructor.
132133
@@ -150,13 +151,9 @@ def __init__(self, name=None, base_dir=None, port=None, conn_params: ConnectionP
150151
self.name = name or generate_app_name()
151152
if testgres_config.os_ops:
152153
self.os_ops = testgres_config.os_ops
153-
elif conn_params.ssh_key:
154-
self.os_ops = RemoteOperations(conn_params)
155-
else:
156-
self.os_ops = LocalOperations(conn_params)
157154

158155
self.host = self.os_ops.host
159-
self.port = port or reserve_port()
156+
self.port = port or self.os_ops.port or reserve_port()
160157

161158
self.ssh_key = self.os_ops.ssh_key
162159

@@ -1005,7 +1002,7 @@ def psql(self,
10051002

10061003
# select query source
10071004
if query:
1008-
if self.os_ops.remote:
1005+
if self.os_ops.conn_params.remote:
10091006
psql_params.extend(("-c", '"{}"'.format(query)))
10101007
else:
10111008
psql_params.extend(("-c", query))
@@ -1016,7 +1013,7 @@ def psql(self,
10161013

10171014
# should be the last one
10181015
psql_params.append(dbname)
1019-
if not self.os_ops.remote:
1016+
if not self.os_ops.conn_params.remote:
10201017
# start psql process
10211018
process = subprocess.Popen(psql_params,
10221019
stdin=subprocess.PIPE,

testgres/operations/local_ops.py

Lines changed: 1 addition & 17 deletions
Original file line numberDiff line numberDiff line change
@@ -40,12 +40,7 @@ class LocalOperations(OsOperations):
4040
def __init__(self, conn_params=None):
4141
if conn_params is None:
4242
conn_params = ConnectionParams()
43-
super(LocalOperations, self).__init__(conn_params.username)
44-
self.conn_params = conn_params
45-
self.host = conn_params.host
46-
self.ssh_key = None
47-
self.remote = False
48-
self.username = conn_params.username or getpass.getuser()
43+
super(LocalOperations, self).__init__(conn_params)
4944

5045
@staticmethod
5146
def _raise_exec_exception(message, command, exit_code, output):
@@ -305,14 +300,3 @@ def get_pid(self):
305300

306301
def get_process_children(self, pid):
307302
return psutil.Process(pid).children()
308-
309-
# Database control
310-
def db_connect(self, dbname, user, password=None, host="localhost", port=5432):
311-
conn = pglib.connect(
312-
host=host,
313-
port=port,
314-
database=dbname,
315-
user=user,
316-
password=password,
317-
)
318-
return conn

testgres/operations/os_ops.py

Lines changed: 23 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -12,11 +12,16 @@
1212

1313

1414
class ConnectionParams:
15-
def __init__(self, host='127.0.0.1', port=None, ssh_key=None, username=None):
15+
def __init__(self, host='127.0.0.1', port=None, ssh_key=None, username=None, remote=False, skip_ssl=False):
16+
"""
17+
skip_ssl: if is True, the connection is established without SSL.
18+
"""
19+
self.remote = remote
1620
self.host = host
1721
self.port = port
1822
self.ssh_key = ssh_key
1923
self.username = username
24+
self.skip_ssl = skip_ssl
2025

2126

2227
def get_default_encoding():
@@ -26,9 +31,12 @@ def get_default_encoding():
2631

2732

2833
class OsOperations:
29-
def __init__(self, username=None):
30-
self.ssh_key = None
31-
self.username = username or getpass.getuser()
34+
def __init__(self, conn_params=ConnectionParams()):
35+
self.ssh_key = conn_params.ssh_key
36+
self.username = conn_params.username or getpass.getuser()
37+
self.host = conn_params.host
38+
self.port = conn_params.port
39+
self.conn_params = conn_params
3240

3341
# Command execution
3442
def exec_command(self, cmd, **kwargs):
@@ -115,4 +123,14 @@ def get_process_children(self, pid):
115123

116124
# Database control
117125
def db_connect(self, dbname, user, password=None, host="localhost", port=5432):
118-
raise NotImplementedError()
126+
ssl_options = {"sslmode": "disable"} if self.conn_params.skip_ssl and 'psycopg2' in globals() else {}
127+
conn = pglib.connect(
128+
host=host,
129+
port=port,
130+
database=dbname,
131+
user=user,
132+
password=password,
133+
**({"ssl_context": None} if self.conn_params.skip_ssl and 'pg8000' in globals() else ssl_options)
134+
)
135+
136+
return conn

testgres/operations/remote_ops.py

Lines changed: 2 additions & 19 deletions
Original file line numberDiff line numberDiff line change
@@ -37,23 +37,17 @@ def cmdline(self):
3737

3838
class RemoteOperations(OsOperations):
3939
def __init__(self, conn_params: ConnectionParams):
40-
4140
if not platform.system().lower() == "linux":
4241
raise EnvironmentError("Remote operations are supported only on Linux!")
42+
super().__init__(conn_params)
4343

44-
super().__init__(conn_params.username)
45-
self.conn_params = conn_params
46-
self.host = conn_params.host
47-
self.port = conn_params.port
48-
self.ssh_key = conn_params.ssh_key
4944
self.ssh_args = []
5045
if self.ssh_key:
5146
self.ssh_args += ["-i", self.ssh_key]
5247
if self.port:
5348
self.ssh_args += ["-p", self.port]
54-
self.remote = True
5549
self.username = conn_params.username or getpass.getuser()
56-
self.ssh_dest = f"{self.username}@{self.host}" if conn_params.username else self.host
50+
self.ssh_dest = f"{self.username}@{self.host}" if self.username else self.host
5751

5852
def __enter__(self):
5953
return self
@@ -361,17 +355,6 @@ def get_process_children(self, pid):
361355
else:
362356
raise ExecUtilException(f"Error in getting process children. Error: {result.stderr}")
363357

364-
# Database control
365-
def db_connect(self, dbname, user, password=None, host="localhost", port=5432):
366-
conn = pglib.connect(
367-
host=host,
368-
port=port,
369-
database=dbname,
370-
user=user,
371-
password=password,
372-
)
373-
return conn
374-
375358

376359
def normalize_error(error):
377360
if isinstance(error, bytes):

testgres/utils.py

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -97,7 +97,7 @@ def get_bin_path(filename):
9797
# check if it's already absolute
9898
if os.path.isabs(filename):
9999
return filename
100-
if tconf.os_ops.remote:
100+
if tconf.os_ops.conn_params.remote:
101101
pg_config = os.environ.get("PG_CONFIG_REMOTE") or os.environ.get("PG_CONFIG")
102102
else:
103103
# try PG_CONFIG - get from local machine
@@ -154,7 +154,7 @@ def cache_pg_config_data(cmd):
154154
return _pg_config_data
155155

156156
# try specified pg_config path or PG_CONFIG
157-
if tconf.os_ops.remote:
157+
if tconf.os_ops.conn_params.remote:
158158
pg_config = pg_config_path or os.environ.get("PG_CONFIG_REMOTE") or os.environ.get("PG_CONFIG")
159159
else:
160160
# try PG_CONFIG - get from local machine

tests/test_remote.py

Lines changed: 25 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -2,7 +2,7 @@
22

33
import pytest
44

5-
from testgres import ExecUtilException
5+
from testgres import ExecUtilException, get_remote_node, testgres_config
66
from testgres import RemoteOperations
77
from testgres import ConnectionParams
88

@@ -34,7 +34,7 @@ def test_exec_command_failure(self):
3434
exit_status, result, error = self.operations.exec_command(cmd, verbose=True, wait_exit=True)
3535
except ExecUtilException as e:
3636
error = e.message
37-
assert error == b'Utility exited with non-zero code. Error: bash: line 1: nonexistent_command: command not found\n'
37+
assert error == 'Utility exited with non-zero code. Error: bash: line 1: nonexistent_command: command not found\n'
3838

3939
def test_is_executable_true(self):
4040
"""
@@ -87,7 +87,7 @@ def test_makedirs_and_rmdirs_failure(self):
8787
exit_status, result, error = self.operations.rmdirs(path, verbose=True)
8888
except ExecUtilException as e:
8989
error = e.message
90-
assert error == b"Utility exited with non-zero code. Error: rm: cannot remove '/root/test_dir': Permission denied\n"
90+
assert error == "Utility exited with non-zero code. Error: rm: cannot remove '/root/test_dir': Permission denied\n"
9191

9292
def test_listdir(self):
9393
"""
@@ -192,3 +192,25 @@ def test_isfile_false(self):
192192
response = self.operations.isfile(filename)
193193

194194
assert response is False
195+
196+
def test_skip_ssl(self):
197+
conn_params = ConnectionParams(host=os.getenv('RDBMS_TESTPOOL1_HOST') or '127.0.0.1',
198+
username=os.getenv('USER'),
199+
remote=True,
200+
skip_ssl=True)
201+
os_ops = RemoteOperations(conn_params)
202+
testgres_config.set_os_ops(os_ops=os_ops)
203+
with get_remote_node().init().start() as node:
204+
with node.connect() as con:
205+
con.begin()
206+
con.execute('create table test(val int)')
207+
con.execute('insert into test values (1)')
208+
con.commit()
209+
210+
con.begin()
211+
con.execute('insert into test values (2)')
212+
res = con.execute('select * from test order by val asc')
213+
if isinstance(res, list):
214+
res.sort()
215+
assert res == [(1,), (2,)]
216+

0 commit comments

Comments
 (0)