37
37
#include "py/stream.h"
38
38
#include "py/objstr.h"
39
39
#include "py/reader.h"
40
+ #include "py/mphal.h"
40
41
#include "py/gc.h"
41
42
#include "extmod/vfs.h"
42
43
47
48
#include "mbedtls/pk.h"
48
49
#include "mbedtls/entropy.h"
49
50
#include "mbedtls/ctr_drbg.h"
51
+ #ifdef MBEDTLS_SSL_PROTO_DTLS
52
+ #include "mbedtls/timing.h"
53
+ #endif
50
54
#include "mbedtls/debug.h"
51
55
#include "mbedtls/error.h"
52
56
#if MBEDTLS_VERSION_NUMBER >= 0x03000000
65
69
66
70
#define MP_STREAM_POLL_RDWR (MP_STREAM_POLL_RD | MP_STREAM_POLL_WR)
67
71
72
+ #define MP_ENDPOINT_IS_SERVER (1 << 0)
73
+ #define MP_TRANSPORT_IS_DTLS (1 << 1)
74
+
75
+ #define MP_PROTOCOL_TLS_CLIENT 0
76
+ #define MP_PROTOCOL_TLS_SERVER MP_ENDPOINT_IS_SERVER
77
+ #define MP_PROTOCOL_DTLS_CLIENT MP_TRANSPORT_IS_DTLS
78
+ #define MP_PROTOCOL_DTLS_SERVER MP_ENDPOINT_IS_SERVER | MP_TRANSPORT_IS_DTLS
79
+
68
80
// This corresponds to an SSLContext object.
69
81
typedef struct _mp_obj_ssl_context_t {
70
82
mp_obj_base_t base ;
@@ -91,6 +103,12 @@ typedef struct _mp_obj_ssl_socket_t {
91
103
92
104
uintptr_t poll_mask ; // Indicates which read or write operations the protocol needs next
93
105
int last_error ; // The last error code, if any
106
+
107
+ #ifdef MBEDTLS_SSL_PROTO_DTLS
108
+ mp_uint_t timer_start_ms ;
109
+ mp_uint_t timer_fin_ms ;
110
+ mp_uint_t timer_int_ms ;
111
+ #endif
94
112
} mp_obj_ssl_socket_t ;
95
113
96
114
static const mp_obj_type_t ssl_context_type ;
@@ -242,7 +260,10 @@ static mp_obj_t ssl_context_make_new(const mp_obj_type_t *type_in, size_t n_args
242
260
mp_arg_check_num (n_args , n_kw , 1 , 1 , false);
243
261
244
262
// This is the "protocol" argument.
245
- mp_int_t endpoint = mp_obj_get_int (args [0 ]);
263
+ mp_int_t protocol = mp_obj_get_int (args [0 ]);
264
+
265
+ int endpoint = (protocol & MP_ENDPOINT_IS_SERVER ) ? MBEDTLS_SSL_IS_SERVER : MBEDTLS_SSL_IS_CLIENT ;
266
+ int transport = (protocol & MP_TRANSPORT_IS_DTLS ) ? MBEDTLS_SSL_TRANSPORT_DATAGRAM : MBEDTLS_SSL_TRANSPORT_STREAM ;
246
267
247
268
// Create SSLContext object.
248
269
#if MICROPY_PY_SSL_FINALISER
@@ -282,7 +303,7 @@ static mp_obj_t ssl_context_make_new(const mp_obj_type_t *type_in, size_t n_args
282
303
}
283
304
284
305
ret = mbedtls_ssl_config_defaults (& self -> conf , endpoint ,
285
- MBEDTLS_SSL_TRANSPORT_STREAM , MBEDTLS_SSL_PRESET_DEFAULT );
306
+ transport , MBEDTLS_SSL_PRESET_DEFAULT );
286
307
if (ret != 0 ) {
287
308
mbedtls_raise_error (ret );
288
309
}
@@ -525,6 +546,39 @@ static int _mbedtls_ssl_recv(void *ctx, byte *buf, size_t len) {
525
546
}
526
547
}
527
548
549
+ #ifdef MBEDTLS_SSL_PROTO_DTLS
550
+ static void _mbedtls_timing_set_delay (void * ctx , uint32_t int_ms , uint32_t fin_ms ) {
551
+ mp_obj_ssl_socket_t * o = (mp_obj_ssl_socket_t * )ctx ;
552
+
553
+ o -> timer_int_ms = int_ms ;
554
+ o -> timer_fin_ms = fin_ms ;
555
+
556
+ if (fin_ms != 0 ) {
557
+ o -> timer_start_ms = mp_hal_ticks_ms ();
558
+ }
559
+ }
560
+
561
+ static int _mbedtls_timing_get_delay (void * ctx ) {
562
+ mp_obj_ssl_socket_t * o = (mp_obj_ssl_socket_t * )ctx ;
563
+
564
+ if (o -> timer_fin_ms == 0 ) {
565
+ return -1 ;
566
+ }
567
+
568
+ mp_uint_t elapsed_ms = mp_hal_ticks_ms () - o -> timer_start_ms ;
569
+
570
+ if (elapsed_ms >= o -> timer_fin_ms ) {
571
+ return 2 ;
572
+ }
573
+
574
+ if (elapsed_ms >= o -> timer_int_ms ) {
575
+ return 1 ;
576
+ }
577
+
578
+ return 0 ;
579
+ }
580
+ #endif
581
+
528
582
static mp_obj_t ssl_socket_make_new (mp_obj_ssl_context_t * ssl_context , mp_obj_t sock ,
529
583
bool server_side , bool do_handshake_on_connect , mp_obj_t server_hostname ) {
530
584
@@ -577,6 +631,10 @@ static mp_obj_t ssl_socket_make_new(mp_obj_ssl_context_t *ssl_context, mp_obj_t
577
631
mp_raise_ValueError (MP_ERROR_TEXT ("CERT_REQUIRED requires server_hostname" ));
578
632
}
579
633
634
+ #ifdef MBEDTLS_SSL_PROTO_DTLS
635
+ mbedtls_ssl_set_timer_cb (& o -> ssl , o , _mbedtls_timing_set_delay , _mbedtls_timing_get_delay );
636
+ #endif
637
+
580
638
mbedtls_ssl_set_bio (& o -> ssl , & o -> sock , _mbedtls_ssl_send , _mbedtls_ssl_recv , NULL );
581
639
582
640
if (do_handshake_on_connect ) {
@@ -788,6 +846,12 @@ static const mp_rom_map_elem_t ssl_socket_locals_dict_table[] = {
788
846
{ MP_ROM_QSTR (MP_QSTR_readinto ), MP_ROM_PTR (& mp_stream_readinto_obj ) },
789
847
{ MP_ROM_QSTR (MP_QSTR_readline ), MP_ROM_PTR (& mp_stream_unbuffered_readline_obj ) },
790
848
{ MP_ROM_QSTR (MP_QSTR_write ), MP_ROM_PTR (& mp_stream_write_obj ) },
849
+ #ifdef MBEDTLS_SSL_PROTO_DTLS
850
+ { MP_ROM_QSTR (MP_QSTR_recv ), MP_ROM_PTR (& mp_stream_read1_obj ) },
851
+ { MP_ROM_QSTR (MP_QSTR_recv_into ), MP_ROM_PTR (& mp_stream_readinto_obj ) },
852
+ { MP_ROM_QSTR (MP_QSTR_send ), MP_ROM_PTR (& mp_stream_write1_obj ) },
853
+ { MP_ROM_QSTR (MP_QSTR_sendall ), MP_ROM_PTR (& mp_stream_write_obj ) },
854
+ #endif
791
855
{ MP_ROM_QSTR (MP_QSTR_setblocking ), MP_ROM_PTR (& socket_setblocking_obj ) },
792
856
{ MP_ROM_QSTR (MP_QSTR_close ), MP_ROM_PTR (& mp_stream_close_obj ) },
793
857
#if MICROPY_PY_SSL_FINALISER
@@ -879,8 +943,12 @@ static const mp_rom_map_elem_t mp_module_tls_globals_table[] = {
879
943
880
944
// Constants.
881
945
{ MP_ROM_QSTR (MP_QSTR_MBEDTLS_VERSION ), MP_ROM_PTR (& mbedtls_version_obj )},
882
- { MP_ROM_QSTR (MP_QSTR_PROTOCOL_TLS_CLIENT ), MP_ROM_INT (MBEDTLS_SSL_IS_CLIENT ) },
883
- { MP_ROM_QSTR (MP_QSTR_PROTOCOL_TLS_SERVER ), MP_ROM_INT (MBEDTLS_SSL_IS_SERVER ) },
946
+ { MP_ROM_QSTR (MP_QSTR_PROTOCOL_TLS_CLIENT ), MP_ROM_INT (MP_PROTOCOL_TLS_CLIENT ) },
947
+ { MP_ROM_QSTR (MP_QSTR_PROTOCOL_TLS_SERVER ), MP_ROM_INT (MP_PROTOCOL_TLS_SERVER ) },
948
+ #ifdef MBEDTLS_SSL_PROTO_DTLS
949
+ { MP_ROM_QSTR (MP_QSTR_PROTOCOL_DTLS_CLIENT ), MP_ROM_INT (MP_PROTOCOL_DTLS_CLIENT ) },
950
+ { MP_ROM_QSTR (MP_QSTR_PROTOCOL_DTLS_SERVER ), MP_ROM_INT (MP_PROTOCOL_DTLS_SERVER ) },
951
+ #endif
884
952
{ MP_ROM_QSTR (MP_QSTR_CERT_NONE ), MP_ROM_INT (MBEDTLS_SSL_VERIFY_NONE ) },
885
953
{ MP_ROM_QSTR (MP_QSTR_CERT_OPTIONAL ), MP_ROM_INT (MBEDTLS_SSL_VERIFY_OPTIONAL ) },
886
954
{ MP_ROM_QSTR (MP_QSTR_CERT_REQUIRED ), MP_ROM_INT (MBEDTLS_SSL_VERIFY_REQUIRED ) },
0 commit comments