Skip to content

Commit da13d0a

Browse files
fibonacci-matrixJonasVautherinjulianoes
authored
core: Correctly close sockets (#2357)
* core: Correctly close sockets * Update socket_holder.cpp - Fix newline style * Update socket_holder.cpp * Update src/mavsdk/core/tcp_client_connection.cpp Co-authored-by: Jonas Vautherin <[email protected]> * Update socket_holder.cpp - fix code style * Update tcp_client_connection.cpp - fix code style * Update tcp_server_connection.cpp - fix code style * Update socket_holder.h - fix code style * Update src/mavsdk/core/socket_holder.h Co-authored-by: Julian Oes <[email protected]> * Update socket_holder.h - use 64 bit descriptor type on Win64 * Update socket_holder.cpp - use 64 bit descriptor type on Win64 * Update socket_holder.cpp - minor improving of if logic * Remove default move constructor It is currently not in use, and the default implementation is not suitable because it does not change _fd to INVALID for the object being copied from --------- Co-authored-by: Jonas Vautherin <[email protected]> Co-authored-by: Julian Oes <[email protected]>
1 parent da070c5 commit da13d0a

9 files changed

+140
-79
lines changed

src/mavsdk/core/CMakeLists.txt

+1
Original file line numberDiff line numberDiff line change
@@ -52,6 +52,7 @@ target_sources(mavsdk
5252
server_component.cpp
5353
server_component_impl.cpp
5454
server_plugin_impl_base.cpp
55+
socket_holder.cpp
5556
tcp_client_connection.cpp
5657
tcp_server_connection.cpp
5758
timeout_handler.cpp

src/mavsdk/core/socket_holder.cpp

+53
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,53 @@
1+
#include "socket_holder.h"
2+
3+
#ifndef WINDOWS
4+
#include <sys/socket.h>
5+
#include <unistd.h>
6+
#endif
7+
8+
namespace mavsdk {
9+
10+
SocketHolder::SocketHolder(DescriptorType fd) noexcept : _fd{fd} {}
11+
12+
SocketHolder::~SocketHolder() noexcept
13+
{
14+
close();
15+
}
16+
17+
void SocketHolder::reset(DescriptorType fd) noexcept
18+
{
19+
if (_fd != fd) {
20+
close();
21+
_fd = fd;
22+
}
23+
}
24+
25+
void SocketHolder::close() noexcept
26+
{
27+
if (!empty()) {
28+
#if defined(WINDOWS)
29+
shutdown(_fd, SD_BOTH);
30+
closesocket(_fd);
31+
WSACleanup();
32+
#else
33+
// This should interrupt a recv/recvfrom call.
34+
shutdown(_fd, SHUT_RDWR);
35+
36+
// But on Mac, closing is also needed to stop blocking recv/recvfrom.
37+
::close(_fd);
38+
#endif
39+
_fd = invalid_socket_fd;
40+
}
41+
}
42+
43+
bool SocketHolder::empty() const noexcept
44+
{
45+
return _fd == invalid_socket_fd;
46+
}
47+
48+
SocketHolder::DescriptorType SocketHolder::get() const noexcept
49+
{
50+
return _fd;
51+
}
52+
53+
} // namespace mavsdk

src/mavsdk/core/socket_holder.h

+37
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,37 @@
1+
#pragma once
2+
3+
#if defined(WINDOWS)
4+
#include <winsock2.h>
5+
#endif
6+
7+
namespace mavsdk {
8+
9+
class SocketHolder {
10+
public:
11+
#if defined(WINDOWS)
12+
using DescriptorType = SOCKET;
13+
static constexpr DescriptorType invalid_socket_fd = INVALID_SOCKET;
14+
#else
15+
using DescriptorType = int;
16+
static constexpr DescriptorType invalid_socket_fd = -1;
17+
#endif
18+
19+
SocketHolder() noexcept = default;
20+
explicit SocketHolder(DescriptorType socket_fd) noexcept;
21+
22+
~SocketHolder() noexcept;
23+
24+
void reset(DescriptorType fd) noexcept;
25+
void close() noexcept;
26+
27+
bool empty() const noexcept;
28+
DescriptorType get() const noexcept;
29+
30+
private:
31+
SocketHolder(const SocketHolder&) = delete;
32+
SocketHolder& operator=(const SocketHolder&) = delete;
33+
34+
DescriptorType _fd = invalid_socket_fd;
35+
};
36+
37+
} // namespace mavsdk

src/mavsdk/core/tcp_client_connection.cpp

+10-21
Original file line numberDiff line numberDiff line change
@@ -15,7 +15,6 @@
1515
#include <arpa/inet.h>
1616
#include <errno.h>
1717
#include <netdb.h>
18-
#include <unistd.h> // for close()
1918
#endif
2019

2120
#ifndef WINDOWS
@@ -71,9 +70,9 @@ ConnectionResult TcpClientConnection::setup_port()
7170
}
7271
#endif
7372

74-
_socket_fd = socket(AF_INET, SOCK_STREAM, 0);
73+
_socket_fd.reset(socket(AF_INET, SOCK_STREAM, 0));
7574

76-
if (_socket_fd < 0) {
75+
if (_socket_fd.empty()) {
7776
LogErr() << "socket error" << GET_ERROR(errno);
7877
_is_ok = false;
7978
return ConnectionResult::SocketError;
@@ -93,8 +92,10 @@ ConnectionResult TcpClientConnection::setup_port()
9392

9493
memcpy(&remote_addr.sin_addr, hp->h_addr, hp->h_length);
9594

96-
if (connect(_socket_fd, reinterpret_cast<sockaddr*>(&remote_addr), sizeof(struct sockaddr_in)) <
97-
0) {
95+
if (connect(
96+
_socket_fd.get(),
97+
reinterpret_cast<sockaddr*>(&remote_addr),
98+
sizeof(struct sockaddr_in)) < 0) {
9899
LogErr() << "connect error: " << GET_ERROR(errno);
99100
_is_ok = false;
100101
return ConnectionResult::SocketConnectionError;
@@ -113,19 +114,7 @@ ConnectionResult TcpClientConnection::stop()
113114
{
114115
_should_exit = true;
115116

116-
#ifndef WINDOWS
117-
// This should interrupt a recv/recvfrom call.
118-
shutdown(_socket_fd, SHUT_RDWR);
119-
120-
// But on Mac, closing is also needed to stop blocking recv/recvfrom.
121-
close(_socket_fd);
122-
#else
123-
shutdown(_socket_fd, SD_BOTH);
124-
125-
closesocket(_socket_fd);
126-
127-
WSACleanup();
128-
#endif
117+
_socket_fd.close();
129118

130119
if (_recv_thread) {
131120
_recv_thread->join();
@@ -175,7 +164,7 @@ bool TcpClientConnection::send_message(const mavlink_message_t& message)
175164
#endif
176165

177166
const auto send_len = sendto(
178-
_socket_fd,
167+
_socket_fd.get(),
179168
reinterpret_cast<char*>(buffer),
180169
buffer_len,
181170
flags,
@@ -202,7 +191,7 @@ void TcpClientConnection::receive()
202191
setup_port();
203192
}
204193

205-
const auto recv_len = recv(_socket_fd, buffer, sizeof(buffer), 0);
194+
const auto recv_len = recv(_socket_fd.get(), buffer, sizeof(buffer), 0);
206195

207196
if (recv_len == 0) {
208197
// This can happen when shutdown is called on the socket,
@@ -212,7 +201,7 @@ void TcpClientConnection::receive()
212201
}
213202

214203
if (recv_len < 0) {
215-
// This happens on desctruction when close(_socket_fd) is called,
204+
// This happens on destruction when close(_socket_fd.get()) is called,
216205
// therefore be quiet.
217206
// LogErr() << "recvfrom error: " << GET_ERROR(errno);
218207
// Something went wrong, we should try to re-connect in next iteration.

src/mavsdk/core/tcp_client_connection.h

+3-1
Original file line numberDiff line numberDiff line change
@@ -1,5 +1,7 @@
11
#pragma once
22

3+
#include "socket_holder.h"
4+
35
#include <atomic>
46
#include <mutex>
57
#include <memory>
@@ -43,7 +45,7 @@ class TcpClientConnection : public Connection {
4345
int _remote_port_number;
4446

4547
std::mutex _mutex = {};
46-
int _socket_fd = -1;
48+
SocketHolder _socket_fd;
4749

4850
std::unique_ptr<std::thread> _recv_thread{};
4951
std::atomic_bool _should_exit;

src/mavsdk/core/tcp_server_connection.cpp

+23-27
Original file line numberDiff line numberDiff line change
@@ -18,7 +18,6 @@
1818
#include <arpa/inet.h>
1919
#include <errno.h>
2020
#include <netdb.h>
21-
#include <unistd.h> // for close()
2221
#endif
2322

2423
#ifndef WINDOWS
@@ -57,8 +56,8 @@ ConnectionResult TcpServerConnection::start()
5756
}
5857
#endif
5958

60-
_server_socket_fd = socket(AF_INET, SOCK_STREAM, 0);
61-
if (_server_socket_fd < 0) {
59+
_server_socket_fd.reset(socket(AF_INET, SOCK_STREAM, 0));
60+
if (_server_socket_fd.empty()) {
6261
LogErr() << "socket error: " << GET_ERROR(errno);
6362
return ConnectionResult::SocketError;
6463
}
@@ -68,13 +67,15 @@ ConnectionResult TcpServerConnection::start()
6867
server_addr.sin_addr.s_addr = INADDR_ANY;
6968
server_addr.sin_port = htons(_local_port);
7069

71-
if (bind(_server_socket_fd, reinterpret_cast<sockaddr*>(&server_addr), sizeof(server_addr)) <
72-
0) {
70+
if (bind(
71+
_server_socket_fd.get(),
72+
reinterpret_cast<sockaddr*>(&server_addr),
73+
sizeof(server_addr)) < 0) {
7374
LogErr() << "bind error: " << GET_ERROR(errno);
7475
return ConnectionResult::SocketError;
7576
}
7677

77-
if (listen(_server_socket_fd, 3) < 0) {
78+
if (listen(_server_socket_fd.get(), 3) < 0) {
7879
LogErr() << "listen error: " << GET_ERROR(errno);
7980
return ConnectionResult::SocketError;
8081
}
@@ -89,16 +90,8 @@ ConnectionResult TcpServerConnection::stop()
8990
{
9091
_should_exit = true;
9192

92-
#ifndef WINDOWS
93-
shutdown(_client_socket_fd, SHUT_RDWR);
94-
close(_client_socket_fd);
95-
close(_server_socket_fd);
96-
#else
97-
shutdown(_client_socket_fd, SD_BOTH);
98-
closesocket(_client_socket_fd);
99-
closesocket(_server_socket_fd);
100-
WSACleanup();
101-
#endif
93+
_client_socket_fd.close();
94+
_server_socket_fd.close();
10295

10396
if (_accept_receive_thread && _accept_receive_thread->joinable()) {
10497
_accept_receive_thread->join();
@@ -126,7 +119,7 @@ bool TcpServerConnection::send_message(const mavlink_message_t& message)
126119
#endif
127120

128121
const auto send_len =
129-
send(_client_socket_fd, reinterpret_cast<const char*>(buffer), buffer_len, flags);
122+
send(_client_socket_fd.get(), reinterpret_cast<const char*>(buffer), buffer_len, flags);
130123

131124
if (send_len != buffer_len) {
132125
LogErr() << "send failure: " << GET_ERROR(errno);
@@ -140,27 +133,28 @@ void TcpServerConnection::accept_client()
140133
#ifdef WINDOWS
141134
// Set server socket to non-blocking
142135
u_long iMode = 1;
143-
int iResult = ioctlsocket(_server_socket_fd, FIONBIO, &iMode);
136+
int iResult = ioctlsocket(_server_socket_fd.get(), FIONBIO, &iMode);
144137
if (iResult != 0) {
145138
LogErr() << "ioctlsocket failed with error: " << WSAGetLastError();
146139
}
147140
#else
148141
// Set server socket to non-blocking
149-
int flags = fcntl(_server_socket_fd, F_GETFL, 0);
150-
fcntl(_server_socket_fd, F_SETFL, flags | O_NONBLOCK);
142+
int flags = fcntl(_server_socket_fd.get(), F_GETFL, 0);
143+
fcntl(_server_socket_fd.get(), F_SETFL, flags | O_NONBLOCK);
151144
#endif
152145

153146
while (!_should_exit) {
154147
fd_set readfds;
155148
FD_ZERO(&readfds);
156-
FD_SET(_server_socket_fd, &readfds);
149+
FD_SET(_server_socket_fd.get(), &readfds);
157150

158151
// Set timeout to 1 second
159152
timeval timeout;
160153
timeout.tv_sec = 1;
161154
timeout.tv_usec = 0;
162155

163-
const int activity = select(_server_socket_fd + 1, &readfds, nullptr, nullptr, &timeout);
156+
const int activity =
157+
select(_server_socket_fd.get() + 1, &readfds, nullptr, nullptr, &timeout);
164158

165159
if (activity < 0 && errno != EINTR) {
166160
LogErr() << "select error: " << GET_ERROR(errno);
@@ -172,13 +166,15 @@ void TcpServerConnection::accept_client()
172166
continue;
173167
}
174168

175-
if (FD_ISSET(_server_socket_fd, &readfds)) {
169+
if (FD_ISSET(_server_socket_fd.get(), &readfds)) {
176170
sockaddr_in client_addr{};
177171
socklen_t client_addr_len = sizeof(client_addr);
178172

179-
_client_socket_fd = accept(
180-
_server_socket_fd, reinterpret_cast<sockaddr*>(&client_addr), &client_addr_len);
181-
if (_client_socket_fd < 0) {
173+
_client_socket_fd.reset(accept(
174+
_server_socket_fd.get(),
175+
reinterpret_cast<sockaddr*>(&client_addr),
176+
&client_addr_len));
177+
if (_client_socket_fd.empty()) {
182178
if (_should_exit) {
183179
return;
184180
}
@@ -197,7 +193,7 @@ void TcpServerConnection::receive()
197193

198194
bool dataReceived = false;
199195
while (!dataReceived && !_should_exit) {
200-
const auto recv_len = recv(_client_socket_fd, buffer.data(), buffer.size(), 0);
196+
const auto recv_len = recv(_client_socket_fd.get(), buffer.data(), buffer.size(), 0);
201197

202198
#ifdef WINDOWS
203199
if (recv_len == SOCKET_ERROR) {

src/mavsdk/core/tcp_server_connection.h

+3-2
Original file line numberDiff line numberDiff line change
@@ -1,6 +1,7 @@
11
#pragma once
22

33
#include "connection.h"
4+
#include "socket_holder.h"
45

56
#include <atomic>
67
#include <string>
@@ -28,8 +29,8 @@ class TcpServerConnection : public Connection {
2829
Connection::ReceiverCallback _receiver_callback;
2930
std::string _local_ip;
3031
int _local_port;
31-
int _server_socket_fd{-1};
32-
int _client_socket_fd{-1};
32+
SocketHolder _server_socket_fd;
33+
SocketHolder _client_socket_fd;
3334
std::unique_ptr<std::thread> _accept_receive_thread;
3435
std::atomic<bool> _should_exit{false};
3536
};

0 commit comments

Comments
 (0)