Skip to content

Commit d8b8993

Browse files
committed
DNSServer code refactoring
get rid of intermediate mem buffers and extra data copies, most of the data could be referenced or copied from the source packet - removed _buffer member - replaced DNSQuestion.QName from uint8_t[] to char* added sanity checks for mem bounds optimize label/packet length calculations other code cleanup
1 parent 9f08964 commit d8b8993

File tree

2 files changed

+59
-62
lines changed

2 files changed

+59
-62
lines changed

libraries/DNSServer/src/DNSServer.cpp

Lines changed: 45 additions & 55 deletions
Original file line numberDiff line numberDiff line change
@@ -9,38 +9,28 @@
99
#define DEBUG_OUTPUT Serial
1010
#endif
1111

12-
DNSServer::DNSServer()
12+
DNSServer::DNSServer() : _port(0), _ttl(htonl(DNS_DEFAULT_TTL)), _errorReplyCode(DNSReplyCode::NonExistentDomain)
1313
{
14-
_ttl = htonl(DNS_DEFAULT_TTL);
15-
_errorReplyCode = DNSReplyCode::NonExistentDomain;
16-
_dnsHeader = (DNSHeader*) malloc( sizeof(DNSHeader) ) ;
17-
_dnsQuestion = (DNSQuestion*) malloc( sizeof(DNSQuestion) ) ;
18-
_buffer = NULL;
19-
_currentPacketSize = 0;
20-
_port = 0;
14+
_dnsHeader = new DNSHeader();
15+
_dnsQuestion = new DNSQuestion();
2116
}
2217

2318
DNSServer::~DNSServer()
2419
{
2520
if (_dnsHeader) {
26-
free(_dnsHeader);
27-
_dnsHeader = NULL;
21+
delete _dnsHeader;
22+
_dnsHeader = nullptr;
2823
}
2924
if (_dnsQuestion) {
30-
free(_dnsQuestion);
31-
_dnsQuestion = NULL;
32-
}
33-
if (_buffer) {
34-
free(_buffer);
35-
_buffer = NULL;
25+
delete _dnsQuestion;
26+
_dnsQuestion = nullptr;
3627
}
3728
}
3829

3930
bool DNSServer::start(const uint16_t &port, const String &domainName,
4031
const IPAddress &resolvedIP)
4132
{
4233
_port = port;
43-
_buffer = NULL;
4434
_domainName = domainName;
4535
_resolvedIP[0] = resolvedIP[0];
4636
_resolvedIP[1] = resolvedIP[1];
@@ -64,8 +54,6 @@ void DNSServer::setTTL(const uint32_t &ttl)
6454
void DNSServer::stop()
6555
{
6656
_udp.close();
67-
free(_buffer);
68-
_buffer = NULL;
6957
}
7058

7159
void DNSServer::downcaseAndRemoveWwwPrefix(String &domainName)
@@ -76,22 +64,30 @@ void DNSServer::downcaseAndRemoveWwwPrefix(String &domainName)
7664

7765
void DNSServer::_handleUDP(AsyncUDPPacket& pkt)
7866
{
79-
_currentPacketSize = pkt.length();
80-
if (!_currentPacketSize) return;
81-
82-
// Allocate buffer for the DNS query
83-
if (_buffer != NULL)
84-
free(_buffer);
85-
_buffer = (unsigned char*)malloc(_currentPacketSize * sizeof(char));
86-
if (_buffer == NULL)
87-
return;
67+
size_t _currentPacketSize = pkt.length();
68+
if (_currentPacketSize < DNS_HEADER_SIZE) return;
69+
70+
// get DNS header (beginning of message)
71+
memcpy( _dnsHeader, pkt.data(), DNS_HEADER_SIZE );
72+
if (_dnsHeader->QR != DNS_QR_QUERY) return; // ignore non-query mesages
8873

89-
// Put the packet received in the buffer and get DNS header (beginning of message)
90-
// and the question
91-
pkt.read(_buffer, _currentPacketSize);
92-
memcpy( _dnsHeader, _buffer, DNS_HEADER_SIZE ) ;
9374
if ( requestIncludesOnlyOneQuestion() )
9475
{
76+
char * enoflbls = strchr((const char*)pkt.data() + DNS_HEADER_SIZE, 0); // find end_of_label marker
77+
++enoflbls; // include null terminator
78+
_dnsQuestion->QName = pkt.data() + DNS_HEADER_SIZE; // we can reference labels from the request
79+
_dnsQuestion->QNameLength = enoflbls - (char*)pkt.data() - DNS_HEADER_SIZE;
80+
/*
81+
check if we aint going out of pkt bounds
82+
proper dns req should have label terminator at least 4 bytes before end of packet
83+
*/
84+
if (_dnsQuestion->QNameLength > _currentPacketSize - sizeof(_dnsQuestion->QType) - sizeof(_dnsQuestion->QClass)) return; // malformed packet
85+
86+
// Copy the QType and QClass
87+
memcpy( &_dnsQuestion->QType, enoflbls, sizeof(_dnsQuestion->QType) );
88+
memcpy( &_dnsQuestion->QClass, enoflbls + sizeof(_dnsQuestion->QType), sizeof(_dnsQuestion->QClass) );
89+
90+
/*
9591
// The QName has a variable length, maximum 255 bytes and is comprised of multiple labels.
9692
// Each label contains a byte to describe its length and the label itself. The list of
9793
// labels terminates with a zero-valued byte. In "github.com", we have two labels "github" & "com"
@@ -108,25 +104,22 @@ void DNSServer::_handleUDP(AsyncUDPPacket& pkt)
108104
// Copy the QType and QClass
109105
memcpy( &_dnsQuestion->QType, (void*) &_buffer[DNS_HEADER_SIZE + _dnsQuestion->QNameLength], sizeof(_dnsQuestion->QType) ) ;
110106
memcpy( &_dnsQuestion->QClass, (void*) &_buffer[DNS_HEADER_SIZE + _dnsQuestion->QNameLength + sizeof(_dnsQuestion->QType)], sizeof(_dnsQuestion->QClass) ) ;
107+
*/
111108
}
112109

113-
114-
if (_dnsHeader->QR == DNS_QR_QUERY &&
115-
_dnsHeader->OPCode == DNS_OPCODE_QUERY &&
110+
// will reply with IP only to "*" or if doman matches without www. subdomain
111+
if (_dnsHeader->OPCode == DNS_OPCODE_QUERY &&
116112
requestIncludesOnlyOneQuestion() &&
117-
(_domainName == "*" || getDomainNameWithoutWwwPrefix() == _domainName)
113+
(_domainName == "*" ||
114+
getDomainNameWithoutWwwPrefix((const char*)pkt.data() + DNS_HEADER_SIZE, _dnsQuestion->QNameLength) == _domainName)
118115
)
119116
{
120117
replyWithIP(pkt);
118+
return;
121119
}
122-
else if (_dnsHeader->QR == DNS_QR_QUERY)
123-
{
124-
replyWithCustomCode(pkt);
125-
}
126-
127-
free(_buffer);
128-
_buffer = NULL;
129120

121+
// otherwise reply with custom code
122+
replyWithCustomCode(pkt);
130123
}
131124

132125
bool DNSServer::requestIncludesOnlyOneQuestion()
@@ -138,25 +131,22 @@ bool DNSServer::requestIncludesOnlyOneQuestion()
138131
}
139132

140133

141-
String DNSServer::getDomainNameWithoutWwwPrefix()
134+
String DNSServer::getDomainNameWithoutWwwPrefix(const char* start, size_t len)
142135
{
143-
// Error checking : if the buffer containing the DNS request is a null pointer, return an empty domain
144136
String parsedDomainName("");
145-
if (_buffer == NULL)
146-
return parsedDomainName;
147137

148-
// Set the start of the domain just after the header (12 bytes). If equal to null character, return an empty domain
149-
unsigned char *start = _buffer + DNS_OFFSET_DOMAIN_NAME;
150138
if (*start == 0)
151139
{
152140
return parsedDomainName;
153141
}
154142

143+
parsedDomainName.reserve(len);
155144
int pos = 0;
156145
while(true)
157146
{
158-
unsigned char labelLength = *(start + pos);
159-
for(int i = 0; i < labelLength; i++)
147+
uint8_t labelLength = *(start + pos);
148+
149+
for(uint8_t i = 0; i < labelLength; i++)
160150
{
161151
pos++;
162152
parsedDomainName += (char)*(start + pos);
@@ -186,8 +176,8 @@ void DNSServer::replyWithIP(AsyncUDPPacket& req)
186176

187177
// Write the question
188178
rpl.write(_dnsQuestion->QName, _dnsQuestion->QNameLength) ;
189-
rpl.write( (unsigned char*) &_dnsQuestion->QType, 2 ) ;
190-
rpl.write( (unsigned char*) &_dnsQuestion->QClass, 2 ) ;
179+
rpl.write( (uint8_t*) &_dnsQuestion->QType, 2 ) ;
180+
rpl.write( (uint8_t*) &_dnsQuestion->QClass, 2 ) ;
191181

192182
// Write the answer
193183
// Use DNS name compression : instead of repeating the name in this RNAME occurence,
@@ -209,14 +199,14 @@ void DNSServer::replyWithIP(AsyncUDPPacket& req)
209199

210200
#ifdef DEBUG_ESP_DNS
211201
DEBUG_OUTPUT.printf("DNS responds: %s for %s\n",
212-
IPAddress(_resolvedIP).toString().c_str(), getDomainNameWithoutWwwPrefix().c_str() );
202+
IPAddress(_resolvedIP).toString().c_str(), getDomainNameWithoutWwwPrefix((const char*)rpl.data() + DNS_HEADER_SIZE, _dnsQuestion->QNameLength).c_str() );
213203
#endif
214204
}
215205

216206
void DNSServer::replyWithCustomCode(AsyncUDPPacket& req)
217207
{
218208
_dnsHeader->QR = DNS_QR_RESPONSE;
219-
_dnsHeader->RCode = (unsigned char)_errorReplyCode;
209+
_dnsHeader->RCode = (uint16_t)_errorReplyCode;
220210
_dnsHeader->QDCount = 0;
221211

222212
AsyncUDPMessage rpl(sizeof(DNSHeader));

libraries/DNSServer/src/DNSServer.h

Lines changed: 14 additions & 7 deletions
Original file line numberDiff line numberDiff line change
@@ -9,7 +9,7 @@
99
#define DNS_OFFSET_DOMAIN_NAME 12 // Offset in bytes to reach the domain name in the DNS message
1010
#define DNS_HEADER_SIZE 12
1111

12-
enum class DNSReplyCode
12+
enum class DNSReplyCode:uint16_t
1313
{
1414
NoError = 0,
1515
FormError = 1,
@@ -66,7 +66,7 @@ struct DNSHeader
6666

6767
struct DNSQuestion
6868
{
69-
uint8_t QName[256] ; //need 1 Byte for zero termination!
69+
const uint8_t* QName;
7070
uint16_t QNameLength ;
7171
uint16_t QType ;
7272
uint16_t QClass ;
@@ -91,18 +91,25 @@ class DNSServer
9191
private:
9292
AsyncUDP _udp;
9393
uint16_t _port;
94+
uint32_t _ttl;
95+
DNSReplyCode _errorReplyCode;
9496
String _domainName;
9597
unsigned char _resolvedIP[4];
96-
int _currentPacketSize;
97-
unsigned char* _buffer;
9898
DNSHeader* _dnsHeader;
99-
uint32_t _ttl;
100-
DNSReplyCode _errorReplyCode;
10199
DNSQuestion* _dnsQuestion ;
102100

103101

104102
void downcaseAndRemoveWwwPrefix(String &domainName);
105-
String getDomainNameWithoutWwwPrefix();
103+
104+
/**
105+
* @brief Get the Domain Name Without Www Prefix object
106+
* scan labels in DNS packet and build a string of a domain name
107+
* truncate any www. label if found
108+
* @param start a pointer to the start of labels records in DNS packet
109+
* @param len labels length
110+
* @return String
111+
*/
112+
String getDomainNameWithoutWwwPrefix(const char* start, size_t len);
106113
bool requestIncludesOnlyOneQuestion();
107114
void replyWithIP(AsyncUDPPacket& req);
108115
void replyWithCustomCode(AsyncUDPPacket& req);

0 commit comments

Comments
 (0)