socket: add support to control no. of connections from one source (#3607)

Introduce MaxConnectionsPerSource= that is number of concurrent
connections allowed per IP.

RFE: 1939
This commit is contained in:
Susant Sahani 2016-08-02 23:18:23 +05:30 committed by Zbigniew Jędrzejewski-Szmek
parent 87edd2b116
commit 9d56542764
7 changed files with 213 additions and 0 deletions

View file

@ -443,6 +443,14 @@
</varlistentry>
<varlistentry>
<term><varname>MaxConnectionsPerSource=</varname></term>
<listitem><para>The maximum number of connections for a service per source IP address.
This is is very similar to the <varname>MaxConnections=</varname> directive
above. Disabled by default.</para>
</listitem>
</varlistentry>
<varlistentry>
<term><varname>KeepAlive=</varname></term>
<listitem><para>Takes a boolean argument. If true, the TCP/IP
stack will send a keep alive message after 2h (depending on

View file

@ -137,6 +137,7 @@ const sd_bus_vtable bus_socket_vtable[] = {
SD_BUS_PROPERTY("Symlinks", "as", NULL, offsetof(Socket, symlinks), SD_BUS_VTABLE_PROPERTY_CONST),
SD_BUS_PROPERTY("Mark", "i", bus_property_get_int, offsetof(Socket, mark), SD_BUS_VTABLE_PROPERTY_CONST),
SD_BUS_PROPERTY("MaxConnections", "u", bus_property_get_unsigned, offsetof(Socket, max_connections), SD_BUS_VTABLE_PROPERTY_CONST),
SD_BUS_PROPERTY("MaxConnectionsPerSource", "u", bus_property_get_unsigned, offsetof(Socket, max_connections_per_source), SD_BUS_VTABLE_PROPERTY_CONST),
SD_BUS_PROPERTY("MessageQueueMaxMessages", "x", bus_property_get_long, offsetof(Socket, mq_maxmsg), SD_BUS_VTABLE_PROPERTY_CONST),
SD_BUS_PROPERTY("MessageQueueMessageSize", "x", bus_property_get_long, offsetof(Socket, mq_msgsize), SD_BUS_VTABLE_PROPERTY_CONST),
SD_BUS_PROPERTY("ReusePort", "b", bus_property_get_bool, offsetof(Socket, reuse_port), SD_BUS_VTABLE_PROPERTY_CONST),

View file

@ -293,6 +293,7 @@ Socket.DirectoryMode, config_parse_mode, 0,
Socket.Accept, config_parse_bool, 0, offsetof(Socket, accept)
Socket.Writable, config_parse_bool, 0, offsetof(Socket, writable)
Socket.MaxConnections, config_parse_unsigned, 0, offsetof(Socket, max_connections)
Socket.MaxConnectionsPerSource, config_parse_unsigned, 0, offsetof(Socket, max_connections_per_source)
Socket.KeepAlive, config_parse_bool, 0, offsetof(Socket, keep_alive)
Socket.KeepAliveTimeSec, config_parse_sec, 0, offsetof(Socket, keep_alive_time)
Socket.KeepAliveIntervalSec, config_parse_sec, 0, offsetof(Socket, keep_alive_interval)

View file

@ -342,6 +342,7 @@ static void service_done(Unit *u) {
s->bus_name_owner = mfree(s->bus_name_owner);
service_close_socket_fd(s);
s->peer = socket_peer_unref(s->peer);
unit_ref_unset(&s->accept_socket);

View file

@ -152,6 +152,7 @@ struct Service {
pid_t main_pid, control_pid;
int socket_fd;
SocketPeer *peer;
bool socket_fd_selinux_context_net;
bool permissions_start_only;

View file

@ -57,6 +57,7 @@
#include "unit-printf.h"
#include "unit.h"
#include "user-util.h"
#include "in-addr-util.h"
static const UnitActiveState state_translation_table[_SOCKET_STATE_MAX] = {
[SOCKET_DEAD] = UNIT_INACTIVE,
@ -77,6 +78,9 @@ static const UnitActiveState state_translation_table[_SOCKET_STATE_MAX] = {
static int socket_dispatch_io(sd_event_source *source, int fd, uint32_t revents, void *userdata);
static int socket_dispatch_timer(sd_event_source *source, usec_t usec, void *userdata);
SocketPeer *socket_peer_new(void);
int socket_find_peer(Socket *s, int fd, SocketPeer **p);
static void socket_init(Unit *u) {
Socket *s = SOCKET(u);
@ -141,11 +145,17 @@ void socket_free_ports(Socket *s) {
static void socket_done(Unit *u) {
Socket *s = SOCKET(u);
SocketPeer *p;
assert(s);
socket_free_ports(s);
while ((p = hashmap_steal_first(s->peers_by_address)))
p->socket = NULL;
s->peers_by_address = hashmap_free(s->peers_by_address);
s->exec_runtime = exec_runtime_unref(s->exec_runtime);
exec_command_free_array(s->exec_command, _SOCKET_EXEC_COMMAND_MAX);
s->control_command = NULL;
@ -468,6 +478,40 @@ static int socket_verify(Socket *s) {
return 0;
}
static void peer_address_hash_func(const void *p, struct siphash *state) {
const SocketPeer *s = p;
assert(s);
if (s->peer.sa.sa_family == AF_INET)
siphash24_compress(&s->peer.in.sin_addr, sizeof(s->peer.in.sin_addr), state);
else if (s->peer.sa.sa_family == AF_INET6)
siphash24_compress(&s->peer.in6.sin6_addr, sizeof(s->peer.in6.sin6_addr), state);
}
static int peer_address_compare_func(const void *a, const void *b) {
const SocketPeer *x = a, *y = b;
if (x->peer.sa.sa_family < y->peer.sa.sa_family)
return -1;
if (x->peer.sa.sa_family > y->peer.sa.sa_family)
return 1;
switch(x->peer.sa.sa_family) {
case AF_INET:
return memcmp(&x->peer.in.sin_addr, &y->peer.in.sin_addr, sizeof(x->peer.in.sin_addr));
case AF_INET6:
return memcmp(&x->peer.in6.sin6_addr, &y->peer.in6.sin6_addr, sizeof(x->peer.in6.sin6_addr));
}
return -1;
}
const struct hash_ops peer_address_hash_ops = {
.hash = peer_address_hash_func,
.compare = peer_address_compare_func
};
static int socket_load(Unit *u) {
Socket *s = SOCKET(u);
int r;
@ -475,6 +519,10 @@ static int socket_load(Unit *u) {
assert(u);
assert(u->load_state == UNIT_STUB);
r = hashmap_ensure_allocated(&s->peers_by_address, &peer_address_hash_ops);
if (r < 0)
return r;
r = unit_load_fragment_and_dropin(u);
if (r < 0)
return r;
@ -2050,6 +2098,7 @@ static void socket_enter_running(Socket *s, int cfd) {
socket_set_state(s, SOCKET_RUNNING);
} else {
_cleanup_free_ char *prefix = NULL, *instance = NULL, *name = NULL;
_cleanup_(socket_peer_unrefp) SocketPeer *p = NULL;
Service *service;
if (s->n_connections >= s->max_connections) {
@ -2058,6 +2107,21 @@ static void socket_enter_running(Socket *s, int cfd) {
return;
}
if (s->max_connections_per_source > 0) {
r = socket_find_peer(s, cfd, &p);
if (r < 0) {
safe_close(cfd);
return;
}
if (p->n_ref > s->max_connections_per_source) {
log_unit_warning(UNIT(s), "Too many incoming connections (%u) from source, refusing connection attempt.", p->n_ref);
safe_close(cfd);
p = NULL;
return;
}
}
r = socket_instantiate_service(s);
if (r < 0)
goto fail;
@ -2099,6 +2163,11 @@ static void socket_enter_running(Socket *s, int cfd) {
cfd = -1; /* We passed ownership of the fd to the service now. Forget it here. */
s->n_connections++;
if (s->max_connections_per_source > 0) {
service->peer = socket_peer_ref(p);
p = NULL;
}
r = manager_add_job(UNIT(s)->manager, JOB_START, UNIT(service), JOB_REPLACE, &error, NULL);
if (r < 0) {
/* We failed to activate the new service, but it still exists. Let's make sure the service
@ -2244,7 +2313,9 @@ static int socket_stop(Unit *u) {
static int socket_serialize(Unit *u, FILE *f, FDSet *fds) {
Socket *s = SOCKET(u);
SocketPeer *k;
SocketPort *p;
Iterator i;
int r;
assert(u);
@ -2295,6 +2366,16 @@ static int socket_serialize(Unit *u, FILE *f, FDSet *fds) {
}
}
HASHMAP_FOREACH(k, s->peers_by_address, i) {
_cleanup_free_ char *t = NULL;
r = sockaddr_pretty(&k->peer.sa, FAMILY_ADDRESS_SIZE(k->peer.sa.sa_family), true, true, &t);
if (r < 0)
return r;
unit_serialize_item_format(u, f, "peer", "%u %s", k->n_ref, t);
}
return 0;
}
@ -2458,6 +2539,33 @@ static int socket_deserialize_item(Unit *u, const char *key, const char *value,
}
}
} else if (streq(key, "peer")) {
_cleanup_(socket_peer_unrefp) SocketPeer *p;
int n_ref, skip = 0;
SocketAddress a;
int r;
if (sscanf(value, "%u %n", &n_ref, &skip) < 1 || n_ref < 1)
log_unit_debug(u, "Failed to parse socket peer value: %s", value);
else {
r = socket_address_parse(&a, value+skip);
if (r < 0)
return r;
p = socket_peer_new();
if (!p)
return log_oom();
p->n_ref = n_ref;
memcpy(&p->peer, &a.sockaddr, sizeof(a.sockaddr));
p->socket = s;
r = hashmap_put(s->peers_by_address, p, p);
if (r < 0)
return r;
p = NULL;
}
} else
log_unit_debug(UNIT(s), "Unknown serialization key: %s", key);
@ -2554,6 +2662,83 @@ _pure_ static bool socket_check_gc(Unit *u) {
return s->n_connections > 0;
}
SocketPeer *socket_peer_new(void) {
SocketPeer *p;
p = new0(SocketPeer, 1);
if (!p)
return NULL;
p->n_ref = 1;
return p;
}
SocketPeer *socket_peer_ref(SocketPeer *p) {
if (!p)
return NULL;
assert(p->n_ref > 0);
p->n_ref++;
return p;
}
SocketPeer *socket_peer_unref(SocketPeer *p) {
if (!p)
return NULL;
assert(p->n_ref > 0);
p->n_ref--;
if (p->n_ref > 0)
return NULL;
if (p->socket)
(void) hashmap_remove(p->socket->peers_by_address, p);
free(p);
return NULL;
}
int socket_find_peer(Socket *s, int fd, SocketPeer **p) {
_cleanup_free_ SocketPeer *remote = NULL;
SocketPeer sa, *i;
socklen_t salen = sizeof(sa.peer);
int r;
assert(fd >= 0);
assert(s);
r = getpeername(fd, &sa.peer.sa, &salen);
if (r < 0)
return log_error_errno(errno, "getpeername failed: %m");
i = hashmap_get(s->peers_by_address, &sa);
if (i) {
*p = i;
return 1;
}
remote = socket_peer_new();
if (!remote)
return log_oom();
memcpy(&remote->peer, &sa.peer, sizeof(union sockaddr_union));
remote->socket = s;
r = hashmap_put(s->peers_by_address, remote, remote);
if (r < 0)
return r;
*p = remote;
remote = NULL;
return 0;
}
static int socket_dispatch_io(sd_event_source *source, int fd, uint32_t revents, void *userdata) {
SocketPort *p = userdata;
int cfd = -1;

View file

@ -20,6 +20,7 @@
***/
typedef struct Socket Socket;
typedef struct SocketPeer SocketPeer;
#include "mount.h"
#include "service.h"
@ -79,9 +80,12 @@ struct Socket {
LIST_HEAD(SocketPort, ports);
Hashmap *peers_by_address;
unsigned n_accepted;
unsigned n_connections;
unsigned max_connections;
unsigned max_connections_per_source;
unsigned backlog;
unsigned keep_alive_cnt;
@ -164,6 +168,18 @@ struct Socket {
RateLimit trigger_limit;
};
struct SocketPeer {
unsigned n_ref;
Socket *socket;
union sockaddr_union peer;
};
SocketPeer *socket_peer_ref(SocketPeer *p);
SocketPeer *socket_peer_unref(SocketPeer *p);
DEFINE_TRIVIAL_CLEANUP_FUNC(SocketPeer*, socket_peer_unref);
/* Called from the service code when collecting fds */
int socket_collect_fds(Socket *s, int **fds);