Skip to content

Commit 70ed315

Browse files
projectgusdpgeorge
authored andcommitted
py/malloc: Add mutex for tracked allocations.
Fixes thread safety issue that could cause memory corruption on ports with (MICROPY_PY_THREAD && !MICROPY_PY_THREAD_GIL) - currently only rp2 and unix have this configuration. Adds unit test for TLS sockets that exercises this code path. I wasn't able to make this fail on rp2, the race condition window is pretty narrow and may not have a direct impact on a quiet system. This work was funded through GitHub Sponsors. Signed-off-by: Angus Gratton <[email protected]>
1 parent bee1fd5 commit 70ed315

File tree

2 files changed

+89
-1
lines changed

2 files changed

+89
-1
lines changed

py/malloc.c

+32-1
Original file line numberDiff line numberDiff line change
@@ -209,6 +209,31 @@ void m_free(void *ptr)
209209

210210
#if MICROPY_TRACKED_ALLOC
211211

212+
#if MICROPY_PY_THREAD && !MICROPY_PY_THREAD_GIL
213+
// If there's no GIL, use the GC recursive mutex to protect the tracked node linked list
214+
// under m_tracked_head.
215+
//
216+
// (For ports with GIL, the expectation is to only call tracked alloc functions
217+
// while holding the GIL.)
218+
219+
static inline void m_tracked_node_lock(void) {
220+
mp_thread_recursive_mutex_lock(&MP_STATE_MEM(gc_mutex), 1);
221+
}
222+
223+
static inline void m_tracked_node_unlock(void) {
224+
mp_thread_recursive_mutex_unlock(&MP_STATE_MEM(gc_mutex));
225+
}
226+
227+
#else
228+
229+
static inline void m_tracked_node_lock(void) {
230+
}
231+
232+
static inline void m_tracked_node_unlock(void) {
233+
}
234+
235+
#endif
236+
212237
#define MICROPY_TRACKED_ALLOC_STORE_SIZE (!MICROPY_ENABLE_GC)
213238

214239
typedef struct _m_tracked_node_t {
@@ -222,6 +247,7 @@ typedef struct _m_tracked_node_t {
222247

223248
#if MICROPY_DEBUG_VERBOSE
224249
static size_t m_tracked_count_links(size_t *nb) {
250+
m_tracked_node_lock();
225251
m_tracked_node_t *node = MP_STATE_VM(m_tracked_head);
226252
size_t n = 0;
227253
*nb = 0;
@@ -234,6 +260,7 @@ static size_t m_tracked_count_links(size_t *nb) {
234260
#endif
235261
node = node->next;
236262
}
263+
m_tracked_node_unlock();
237264
return n;
238265
}
239266
#endif
@@ -248,12 +275,14 @@ void *m_tracked_calloc(size_t nmemb, size_t size) {
248275
size_t n = m_tracked_count_links(&nb);
249276
DEBUG_printf("m_tracked_calloc(%u, %u) -> (%u;%u) %p\n", (int)nmemb, (int)size, (int)n, (int)nb, node);
250277
#endif
278+
m_tracked_node_lock();
251279
if (MP_STATE_VM(m_tracked_head) != NULL) {
252280
MP_STATE_VM(m_tracked_head)->prev = node;
253281
}
254282
node->prev = NULL;
255283
node->next = MP_STATE_VM(m_tracked_head);
256284
MP_STATE_VM(m_tracked_head) = node;
285+
m_tracked_node_unlock();
257286
#if MICROPY_TRACKED_ALLOC_STORE_SIZE
258287
node->size = nmemb * size;
259288
#endif
@@ -278,7 +307,8 @@ void m_tracked_free(void *ptr_in) {
278307
size_t nb;
279308
size_t n = m_tracked_count_links(&nb);
280309
DEBUG_printf("m_tracked_free(%p, [%p, %p], nbytes=%u, links=%u;%u)\n", node, node->prev, node->next, (int)data_bytes, (int)n, (int)nb);
281-
#endif
310+
#endif // MICROPY_DEBUG_VERBOSE
311+
m_tracked_node_lock();
282312
if (node->next != NULL) {
283313
node->next->prev = node->prev;
284314
}
@@ -287,6 +317,7 @@ void m_tracked_free(void *ptr_in) {
287317
} else {
288318
MP_STATE_VM(m_tracked_head) = node->next;
289319
}
320+
m_tracked_node_unlock();
290321
m_free(node
291322
#if MICROPY_MALLOC_USES_ALLOCATED_SIZE
292323
#if MICROPY_TRACKED_ALLOC_STORE_SIZE

tests/extmod/ssl_threads.py

+57
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,57 @@
1+
# Ensure that SSL sockets can be allocated from multiple
2+
# threads without thread safety issues
3+
import unittest
4+
5+
try:
6+
import _thread
7+
import io
8+
import tls
9+
import time
10+
except ImportError:
11+
print("SKIP")
12+
raise SystemExit
13+
14+
15+
class TestSocket(io.IOBase):
16+
def write(self, buf):
17+
return len(buf)
18+
19+
def readinto(self, buf):
20+
return 0
21+
22+
def ioctl(self, cmd, arg):
23+
return 0
24+
25+
def setblocking(self, value):
26+
pass
27+
28+
29+
ITERS = 256
30+
31+
32+
class TLSThreads(unittest.TestCase):
33+
def test_sslsocket_threaded(self):
34+
self.done = False
35+
# only run in two threads: too much RAM demand otherwise, and rp2 only
36+
# supports two anyhow
37+
_thread.start_new_thread(self._alloc_many_sockets, (True,))
38+
self._alloc_many_sockets(False)
39+
while not self.done:
40+
time.sleep(0.1)
41+
print("done")
42+
43+
def _alloc_many_sockets(self, set_done_flag):
44+
print("start", _thread.get_ident())
45+
ctx = tls.SSLContext(tls.PROTOCOL_TLS_CLIENT)
46+
ctx.verify_mode = tls.CERT_NONE
47+
for n in range(ITERS):
48+
s = TestSocket()
49+
s = ctx.wrap_socket(s, do_handshake_on_connect=False)
50+
s.close() # Free associated resources now from thread, not in a GC pass
51+
print("done", _thread.get_ident())
52+
if set_done_flag:
53+
self.done = True
54+
55+
56+
if __name__ == "__main__":
57+
unittest.main()

0 commit comments

Comments
 (0)