Systemd/src/socket-proxy/socket-proxyd.c
Zbigniew Jędrzejewski-Szmek 11a1589223 tree-wide: drop license boilerplate
Files which are installed as-is (any .service and other unit files, .conf
files, .policy files, etc.) are left as is. My assumption is that SPDX
identifiers are not yet that well known, so it's better to retain the
extended header to avoid any doubt.

I also kept any copyright lines. We can probably remove them, but it'd be nice to
obtain explicit acks from all involved authors before doing that.
2018-04-06 18:58:55 +02:00

676 lines
20 KiB
C

/* SPDX-License-Identifier: LGPL-2.1+ */
/***
This file is part of systemd.
Copyright 2013 David Strauss
***/
#include <errno.h>
#include <fcntl.h>
#include <getopt.h>
#include <netdb.h>
#include <stdio.h>
#include <stdlib.h>
#include <string.h>
#include <sys/socket.h>
#include <sys/un.h>
#include <unistd.h>
#include "sd-daemon.h"
#include "sd-event.h"
#include "sd-resolve.h"
#include "alloc-util.h"
#include "fd-util.h"
#include "log.h"
#include "path-util.h"
#include "set.h"
#include "socket-util.h"
#include "string-util.h"
#include "parse-util.h"
#include "util.h"
/* Pipe capacity requested for each splice() buffer; the size actually granted
 * is read back with F_GETPIPE_SZ in connection_create_pipes(). */
#define BUFFER_SIZE (256 * 1024)
/* Upper bound on concurrently proxied connections (-c/--connections-max=). */
static unsigned arg_connections_max = 256;
/* Where to proxy to, taken from the single positional argument: an absolute
 * AF_UNIX socket path, "@name" for an abstract AF_UNIX socket, or
 * "HOST[:PORT]" (port defaults to 80) resolved via sd-resolve. */
static const char *arg_remote_host = NULL;
/* Process-wide state shared by all event callbacks. */
typedef struct Context {
sd_event *event;
sd_resolve *resolve;
Set *listen;      /* sd_event_source* watching each passed-in listening socket */
Set *connections; /* Connection* for each live proxied connection */
} Context;
/* One proxied connection.
 *
 * Naming: "server" is the side we accepted from a listening socket;
 * "client" is the outgoing socket we open to arg_remote_host.
 * Data is moved between them with splice() through two pipes. */
typedef struct Connection {
Context *context;        /* owning context; connection is a member of context->connections */
int server_fd, client_fd; /* -1 once closed (client_fd is -1 until connect starts) */
int server_to_client_buffer[2]; /* a pipe */
int client_to_server_buffer[2]; /* a pipe */
/* *_full: bytes currently sitting in the respective pipe;
 * *_size: that pipe's capacity as reported by F_GETPIPE_SZ. */
size_t server_to_client_buffer_full, client_to_server_buffer_full;
size_t server_to_client_buffer_size, client_to_server_buffer_size;
/* I/O watches on server_fd / client_fd. While connect() is in progress
 * client_event_source temporarily holds the connect_cb() watch instead. */
sd_event_source *server_event_source, *client_event_source;
/* Pending lookup of arg_remote_host; dropped once it has been answered. */
sd_resolve_query *resolve_query;
} Connection;
/* Tear down a connection and release everything it owns.
 *
 * Removes c from its context's connection set, drops both I/O event
 * sources, closes both sockets and both splice pipes, drops any pending
 * resolver query and finally frees c itself. c must not be used afterwards. */
static void connection_free(Connection *c) {
assert(c);
if (c->context)
set_remove(c->context->connections, c);
/* Event sources are released before the fds they watch are closed
 * (presumably so sd-event drops its registrations while the fds are still valid). */
sd_event_source_unref(c->server_event_source);
sd_event_source_unref(c->client_event_source);
safe_close(c->server_fd);
safe_close(c->client_fd);
safe_close_pair(c->server_to_client_buffer);
safe_close_pair(c->client_to_server_buffer);
sd_resolve_query_unref(c->resolve_query);
free(c);
}
/* Release all resources referenced by a Context.
 *
 * Unrefs every listening event source, frees every remaining Connection
 * (via connection_free()), then drops the event loop and resolver.
 * The Context struct itself is not freed (main() keeps it on the stack). */
static void context_free(Context *context) {
assert(context);
set_free_with_destructor(context->listen, sd_event_source_unref);
set_free_with_destructor(context->connections, connection_free);
sd_event_unref(context->event);
sd_resolve_unref(context->resolve);
}
/* Lazily create the non-blocking pipe used as a splice() buffer.
 *
 * If buffer[] already holds a pipe this is a no-op. Otherwise a new pipe
 * is created in buffer[], its capacity is raised towards BUFFER_SIZE on a
 * best-effort basis, and the capacity actually granted is stored in *sz.
 *
 * Returns 0 on success, or a negative errno-style code (already logged).
 * If only the size query fails, the pipe stays open in buffer[] and is
 * closed later by connection_free(). */
static int connection_create_pipes(Connection *c, int buffer[2], size_t *sz) {
        int capacity;

        assert(c);
        assert(buffer);
        assert(sz);

        /* Already allocated by an earlier call. */
        if (buffer[0] >= 0)
                return 0;

        if (pipe2(buffer, O_CLOEXEC|O_NONBLOCK) < 0)
                return log_error_errno(errno, "Failed to allocate pipe buffer: %m");

        /* Growing the pipe may be refused (e.g. by limits); that's fine, we
         * simply work with whatever size the kernel reports below. */
        (void) fcntl(buffer[0], F_SETPIPE_SZ, BUFFER_SIZE);

        capacity = fcntl(buffer[0], F_GETPIPE_SZ);
        if (capacity < 0)
                return log_error_errno(errno, "Failed to get pipe buffer size: %m");

        assert(capacity > 0);
        *sz = (size_t) capacity;

        return 0;
}
/* Move as much data as currently possible from *from to *to via a pipe.
 *
 * Alternates two splice() steps until neither makes progress:
 *   1. *from -> buffer[1], while the pipe has room (*full < *sz) and both
 *      sockets are still open;
 *   2. buffer[0] -> *to, while the pipe holds data (*full > 0) and *to is open.
 * *full tracks the number of bytes currently queued in the pipe.
 *
 * A zero-length splice or EPIPE/ECONNRESET on a side is treated as that side
 * being finished: its event source is dropped and its fd closed (set to -1).
 * EAGAIN/EINTR just end the current step. Any other error is logged and
 * returned as a negative errno-style code; otherwise returns 0. */
static int connection_shovel(
Connection *c,
int *from, int buffer[2], int *to,
size_t *full, size_t *sz,
sd_event_source **from_source, sd_event_source **to_source) {
bool shoveled;
assert(c);
assert(from);
assert(buffer);
assert(buffer[0] >= 0);
assert(buffer[1] >= 0);
assert(to);
assert(full);
assert(sz);
assert(from_source);
assert(to_source);
do {
ssize_t z;
shoveled = false;
/* Step 1: fill the pipe from the source socket. Reading is skipped once
 * *to is closed, since the data could never be delivered. */
if (*full < *sz && *from >= 0 && *to >= 0) {
z = splice(*from, NULL, buffer[1], NULL, *sz - *full, SPLICE_F_MOVE|SPLICE_F_NONBLOCK);
if (z > 0) {
*full += z;
shoveled = true;
} else if (z == 0 || IN_SET(errno, EPIPE, ECONNRESET)) {
/* EOF or peer reset on the reading side: stop watching and close it. */
*from_source = sd_event_source_unref(*from_source);
*from = safe_close(*from);
} else if (!IN_SET(errno, EAGAIN, EINTR))
return log_error_errno(errno, "Failed to splice: %m");
}
/* Step 2: drain the pipe into the destination socket. */
if (*full > 0 && *to >= 0) {
z = splice(buffer[0], NULL, *to, NULL, *full, SPLICE_F_MOVE|SPLICE_F_NONBLOCK);
if (z > 0) {
*full -= z;
shoveled = true;
} else if (z == 0 || IN_SET(errno, EPIPE, ECONNRESET)) {
/* Writing side went away: stop watching and close it. */
*to_source = sd_event_source_unref(*to_source);
*to = safe_close(*to);
} else if (!IN_SET(errno, EAGAIN, EINTR))
return log_error_errno(errno, "Failed to splice: %m");
}
} while (shoveled);
return 0;
}
static int connection_enable_event_sources(Connection *c);

/* I/O callback shared by both sockets of an established connection.
 *
 * Pumps data in both directions, then either re-arms the event sources to
 * match the new buffer state, or frees the connection when it is done (a
 * side closed and its queued data has been flushed) or an error occurred.
 * Errors are never propagated to the event loop. */
static int traffic_cb(sd_event_source *s, int fd, uint32_t revents, void *userdata) {
        Connection *c = userdata;
        bool finished;

        assert(s);
        assert(fd >= 0);
        assert(c);

        /* server -> client */
        if (connection_shovel(c,
                              &c->server_fd, c->server_to_client_buffer, &c->client_fd,
                              &c->server_to_client_buffer_full, &c->server_to_client_buffer_size,
                              &c->server_event_source, &c->client_event_source) < 0)
                goto quit;

        /* client -> server */
        if (connection_shovel(c,
                              &c->client_fd, c->client_to_server_buffer, &c->server_fd,
                              &c->client_to_server_buffer_full, &c->client_to_server_buffer_size,
                              &c->client_event_source, &c->server_event_source) < 0)
                goto quit;

        finished =
                /* Both sides reached EOF. */
                (c->server_fd == -1 && c->client_fd == -1) ||
                /* Server closed and everything it sent reached the client. */
                (c->server_fd == -1 && c->server_to_client_buffer_full == 0) ||
                /* Client closed and everything it sent reached the server. */
                (c->client_fd == -1 && c->client_to_server_buffer_full == 0);
        if (finished)
                goto quit;

        if (connection_enable_event_sources(c) < 0)
                goto quit;

        return 1;

quit:
        connection_free(c);
        return 0; /* ignore errors, continue serving */
}
static int connection_enable_event_sources(Connection *c) {
uint32_t a = 0, b = 0;
int r;
assert(c);
if (c->server_to_client_buffer_full > 0)
b |= EPOLLOUT;
if (c->server_to_client_buffer_full < c->server_to_client_buffer_size)
a |= EPOLLIN;
if (c->client_to_server_buffer_full > 0)
a |= EPOLLOUT;
if (c->client_to_server_buffer_full < c->client_to_server_buffer_size)
b |= EPOLLIN;
if (c->server_event_source)
r = sd_event_source_set_io_events(c->server_event_source, a);
else if (c->server_fd >= 0)
r = sd_event_add_io(c->context->event, &c->server_event_source, c->server_fd, a, traffic_cb, c);
else
r = 0;
if (r < 0)
return log_error_errno(r, "Failed to set up server event source: %m");
if (c->client_event_source)
r = sd_event_source_set_io_events(c->client_event_source, b);
else if (c->client_fd >= 0)
r = sd_event_add_io(c->context->event, &c->client_event_source, c->client_fd, b, traffic_cb, c);
else
r = 0;
if (r < 0)
return log_error_errno(r, "Failed to set up client event source: %m");
return 0;
}
/* Finish setting up a connection whose remote socket is connected:
 * allocate both splice() pipes and arm the I/O event sources.
 *
 * Always returns 0. On failure the error has already been logged and c has
 * been freed; the return value does not tell the caller which happened. */
static int connection_complete(Connection *c) {
        assert(c);

        if (connection_create_pipes(c, c->server_to_client_buffer, &c->server_to_client_buffer_size) >= 0 &&
            connection_create_pipes(c, c->client_to_server_buffer, &c->client_to_server_buffer_size) >= 0 &&
            connection_enable_event_sources(c) >= 0)
                return 0;

        connection_free(c);
        return 0; /* ignore errors, continue serving */
}
/* One-shot EPOLLOUT callback armed by connection_start() while a
 * non-blocking connect() to the remote host is in progress.
 *
 * Reads the connect result via SO_ERROR. On success the temporary watch is
 * dropped and the connection is completed; otherwise the error is logged
 * and the connection freed. Always returns 0 or connection_complete()'s
 * result, so the event loop keeps running. */
static int connect_cb(sd_event_source *s, int fd, uint32_t revents, void *userdata) {
        Connection *c = userdata;
        socklen_t optlen = sizeof(int);
        int so_error = 0;

        assert(s);
        assert(fd >= 0);
        assert(c);

        if (getsockopt(fd, SOL_SOCKET, SO_ERROR, &so_error, &optlen) < 0)
                log_error_errno(errno, "Failed to issue SO_ERROR: %m");
        else if (so_error != 0)
                log_error_errno(so_error, "Failed to connect to remote host: %m");
        else {
                /* Connected: replace the connect watch with regular traffic watches. */
                c->client_event_source = sd_event_source_unref(c->client_event_source);
                return connection_complete(c);
        }

        connection_free(c);
        return 0; /* ignore errors, continue serving */
}
/* Open a non-blocking socket for the remote side of c and connect it to sa.
 *
 * If connect() succeeds immediately the connection is completed right away.
 * If it is in progress, a one-shot EPOLLOUT watch dispatching to connect_cb()
 * is armed to finish it later.
 *
 * Returns 0 in all cases; on failure the error is logged and c is freed. */
static int connection_start(Connection *c, struct sockaddr *sa, socklen_t salen) {
        int r;

        assert(c);
        assert(sa);
        assert(salen);

        c->client_fd = socket(sa->sa_family, SOCK_STREAM|SOCK_NONBLOCK|SOCK_CLOEXEC, 0);
        if (c->client_fd < 0) {
                log_error_errno(errno, "Failed to get remote socket: %m");
                goto fail;
        }

        r = connect(c->client_fd, sa, salen);
        if (r >= 0)
                /* Connected synchronously. connection_complete() frees c itself on
                 * failure, so its result must never be routed through "fail"
                 * below (that would free c a second time). */
                return connection_complete(c);

        if (errno != EINPROGRESS) {
                log_error_errno(errno, "Failed to connect to remote host: %m");
                goto fail;
        }

        r = sd_event_add_io(c->context->event, &c->client_event_source, c->client_fd, EPOLLOUT, connect_cb, c);
        if (r < 0) {
                log_error_errno(r, "Failed to add connection socket: %m");
                goto fail;
        }

        r = sd_event_source_set_enabled(c->client_event_source, SD_EVENT_ONESHOT);
        if (r < 0) {
                log_error_errno(r, "Failed to enable oneshot event source: %m");
                goto fail;
        }

        return 0;

fail:
        connection_free(c);
        return 0; /* ignore errors, continue serving */
}
/* Completion callback for the sd-resolve lookup started by resolve_remote().
 *
 * ret is a getaddrinfo()-style status. On success the query is released
 * and a connection to the first returned address is started; on failure
 * the error is logged and the connection freed. */
static int resolve_cb(sd_resolve_query *q, int ret, const struct addrinfo *ai, void *userdata) {
        Connection *c = userdata;

        assert(q);
        assert(c);

        if (ret == 0) {
                c->resolve_query = sd_resolve_query_unref(c->resolve_query);
                return connection_start(c, ai->ai_addr, ai->ai_addrlen);
        }

        log_error("Failed to resolve host: %s", gai_strerror(ret));
        connection_free(c);
        return 0; /* ignore errors, continue serving */
}
/* Begin connecting c to arg_remote_host.
 *
 * arg_remote_host may be:
 *   - an absolute path: connect directly to that AF_UNIX socket;
 *   - "@name": connect directly to the abstract AF_UNIX socket "name";
 *   - "HOST:PORT" or "HOST" (port 80): resolved asynchronously through
 *     sd-resolve, with the connection started from resolve_cb().
 *
 * Returns 0 in all cases; on failure the error is logged and c is freed. */
static int resolve_remote(Connection *c) {
        static const struct addrinfo hints = {
                .ai_family = AF_UNSPEC,
                .ai_socktype = SOCK_STREAM,
                .ai_flags = AI_ADDRCONFIG
        };

        union sockaddr_union sa = {};
        const char *node, *service;
        int r;

        assert(c);

        if (path_is_absolute(arg_remote_host) || arg_remote_host[0] == '@') {
                size_t l = strlen(arg_remote_host);

                /* Both forms occupy exactly strlen(arg_remote_host) bytes of
                 * sun_path (the leading '@' becomes the abstract namespace's
                 * NUL byte). Refuse names that do not fit rather than silently
                 * truncating them and connecting to a different socket. */
                if (l > sizeof(sa.un.sun_path)) {
                        log_error("Remote socket path too long: %s", arg_remote_host);
                        goto fail;
                }

                sa.un.sun_family = AF_UNIX;
                if (arg_remote_host[0] == '@') {
                        /* sun_path[0] stays 0 (sa is zero-initialized), which
                         * selects the abstract namespace. */
                        memcpy(sa.un.sun_path + 1, arg_remote_host + 1, l - 1);
                } else
                        memcpy(sa.un.sun_path, arg_remote_host, l);

                return connection_start(c, &sa.sa, SOCKADDR_UN_LEN(sa.un));
        }

        /* Split "HOST:PORT" at the last colon; without one, default to port 80. */
        service = strrchr(arg_remote_host, ':');
        if (service) {
                node = strndupa(arg_remote_host, service - arg_remote_host);
                service++;
        } else {
                node = arg_remote_host;
                service = "80";
        }

        log_debug("Looking up address info for %s:%s", node, service);
        r = sd_resolve_getaddrinfo(c->context->resolve, &c->resolve_query, node, service, &hints, resolve_cb, c);
        if (r < 0) {
                log_error_errno(r, "Failed to resolve remote host: %m");
                goto fail;
        }

        return 0;

fail:
        connection_free(c);
        return 0; /* ignore errors, continue serving */
}
/* Take ownership of a freshly accepted socket and start proxying it.
 *
 * Creates a Connection for fd, registers it in context->connections and
 * starts resolving / connecting to arg_remote_host. fd is owned by this
 * function in every case: it is either handed to the new Connection or
 * closed here, so callers must not close it themselves.
 *
 * Returns 0 in all cases; per-connection failures are logged and the
 * connection dropped so the proxy keeps serving others. */
static int add_connection_socket(Context *context, int fd) {
        Connection *c;
        int r;

        assert(context);
        assert(fd >= 0);

        /* Refuse once arg_connections_max connections are already open, so at
         * most arg_connections_max are ever served concurrently. */
        if (set_size(context->connections) >= arg_connections_max) {
                log_warning("Hit connection limit, refusing connection.");
                safe_close(fd);
                return 0;
        }

        r = set_ensure_allocated(&context->connections, NULL);
        if (r < 0) {
                log_oom();
                safe_close(fd);
                return 0;
        }

        c = new0(Connection, 1);
        if (!c) {
                log_oom();
                safe_close(fd);
                return 0;
        }

        c->context = context;
        c->server_fd = fd;
        c->client_fd = -1;
        c->server_to_client_buffer[0] = c->server_to_client_buffer[1] = -1;
        c->client_to_server_buffer[0] = c->client_to_server_buffer[1] = -1;

        r = set_put(context->connections, c);
        if (r < 0) {
                free(c);
                log_oom();
                safe_close(fd);
                return 0;
        }

        /* From here on c owns fd; resolve_remote() frees c itself on failure. */
        return resolve_remote(c);
}
/* Event callback for a listening socket: accept one pending connection.
 *
 * The listener is watched in one-shot mode (see add_listen_socket()), so it
 * is re-armed here after every attempt. A failure to accept or set up a
 * single connection is logged and otherwise ignored; only a failure to
 * re-arm the listener stops the event loop. */
static int accept_cb(sd_event_source *s, int fd, uint32_t revents, void *userdata) {
        _cleanup_free_ char *peer = NULL;
        Context *context = userdata;
        int nfd, r;

        assert(s);
        assert(fd >= 0);
        assert(revents & EPOLLIN);
        assert(context);

        nfd = accept4(fd, NULL, NULL, SOCK_NONBLOCK|SOCK_CLOEXEC);
        if (nfd < 0) {
                /* errno is positive. EAGAIN typically means another process
                 * watching the same listener took the connection first, and
                 * EINTR is a harmless interruption: neither deserves a warning. */
                if (!IN_SET(errno, EAGAIN, EINTR))
                        log_warning_errno(errno, "Failed to accept() socket: %m");
        } else {
                (void) getpeername_pretty(nfd, true, &peer);
                log_debug("New connection from %s", strna(peer));

                r = add_connection_socket(context, nfd);
                if (r < 0) {
                        log_error_errno(r, "Failed to accept connection, ignoring: %m");
                        /* Close the accepted socket, never the listener.
                         * NOTE(review): assumes a negative return means
                         * add_connection_socket() did not take ownership of nfd. */
                        safe_close(nfd);
                }
        }

        r = sd_event_source_set_enabled(s, SD_EVENT_ONESHOT);
        if (r < 0) {
                log_error_errno(r, "Error while re-enabling listener with ONESHOT: %m");
                sd_event_exit(context->event, r);
                return r;
        }

        return 1;
}
/* Register one listening socket received via socket activation.
 *
 * Checks via sd_is_socket() that fd is a listening stream socket, makes it
 * non-blocking, and adds an EPOLLIN watch dispatching to accept_cb(). The
 * watch is tracked in context->listen and switched to one-shot mode.
 *
 * Returns 0 on success or a negative errno-style code on failure. */
static int add_listen_socket(Context *context, int fd) {
        sd_event_source *es;
        int r;

        assert(context);
        assert(fd >= 0);

        r = set_ensure_allocated(&context->listen, NULL);
        if (r < 0) {
                log_oom();
                return r;
        }

        r = sd_is_socket(fd, 0, SOCK_STREAM, 1);
        if (r <= 0) {
                if (r < 0)
                        return log_error_errno(r, "Failed to determine socket type: %m");

                log_error("Passed in socket is not a stream socket.");
                return -EINVAL;
        }

        r = fd_nonblock(fd, true);
        if (r < 0)
                return log_error_errno(r, "Failed to mark file descriptor non-blocking: %m");

        r = sd_event_add_io(context->event, &es, fd, EPOLLIN, accept_cb, context);
        if (r < 0)
                return log_error_errno(r, "Failed to add event source: %m");

        r = set_put(context->listen, es);
        if (r < 0) {
                log_error_errno(r, "Failed to add source to set: %m");
                sd_event_source_unref(es);
                return r;
        }

        /* One-shot, so that other processes watching the same socket get a
         * fair chance to accept(); accept_cb() re-arms it each time. */
        r = sd_event_source_set_enabled(es, SD_EVENT_ONESHOT);
        if (r < 0)
                return log_error_errno(r, "Failed to enable oneshot mode: %m");

        return 0;
}
/* Print command line usage to stdout. */
static void help(void) {
        const char *name = program_invocation_short_name;

        printf("%1$s [HOST:PORT]\n"
               "%1$s [SOCKET]\n\n"
               "Bidirectionally proxy local sockets to another (possibly remote) socket.\n\n"
               " -c --connections-max= Set the maximum number of connections to be accepted\n"
               " -h --help Show this help\n"
               " --version Show package version\n",
               name);
}
/* Parse the command line.
 *
 * Accepts -c/--connections-max=N (N >= 1), -h/--help and --version, plus
 * exactly one positional argument, which is stored in arg_remote_host.
 *
 * Returns 1 to continue running, 0 after printing help, the result of
 * version() for --version, or a negative errno-style code on bad usage. */
static int parse_argv(int argc, char *argv[]) {
        enum {
                ARG_VERSION = 0x100,
        };

        static const struct option options[] = {
                { "connections-max", required_argument, NULL, 'c'         },
                { "help",            no_argument,       NULL, 'h'         },
                { "version",         no_argument,       NULL, ARG_VERSION },
                {}
        };

        int opt, r;

        assert(argc >= 0);
        assert(argv);

        while ((opt = getopt_long(argc, argv, "c:h", options, NULL)) >= 0) {
                switch (opt) {

                case 'h':
                        help();
                        return 0;

                case 'c':
                        r = safe_atou(optarg, &arg_connections_max);
                        if (r < 0) {
                                log_error("Failed to parse --connections-max= argument: %s", optarg);
                                return r;
                        }
                        if (arg_connections_max < 1) {
                                log_error("Connection limit is too low.");
                                return -EINVAL;
                        }
                        break;

                case ARG_VERSION:
                        return version();

                case '?':
                        return -EINVAL;

                default:
                        assert_not_reached("Unhandled option");
                }
        }

        /* Exactly one positional argument: the remote host / socket. */
        if (argc - optind < 1) {
                log_error("Not enough parameters.");
                return -EINVAL;
        }
        if (argc - optind > 1) {
                log_error("Too many parameters.");
                return -EINVAL;
        }

        arg_remote_host = argv[optind];
        return 1;
}
/* Entry point: set up the event loop and resolver, register every listening
 * socket passed in via socket activation, and proxy connections until the
 * event loop exits. Exits with EXIT_FAILURE on any setup or loop error. */
int main(int argc, char *argv[]) {
        Context context = {};
        int r, n, fd;

        log_parse_environment();
        log_open();

        r = parse_argv(argc, argv);
        if (r <= 0)
                goto finish;

        r = sd_event_default(&context.event);
        if (r < 0) {
                log_error_errno(r, "Failed to allocate event loop: %m");
                goto finish;
        }

        r = sd_resolve_default(&context.resolve);
        if (r < 0) {
                log_error_errno(r, "Failed to allocate resolver: %m");
                goto finish;
        }

        r = sd_resolve_attach_event(context.resolve, context.event, 0);
        if (r < 0) {
                log_error_errno(r, "Failed to attach resolver: %m");
                goto finish;
        }

        /* Best effort: service watchdog support is optional. */
        (void) sd_event_set_watchdog(context.event, true);

        n = sd_listen_fds(1);
        if (n < 0) {
                /* Keep the underlying error in the log instead of dropping it. */
                r = log_error_errno(n, "Failed to receive sockets from parent: %m");
                goto finish;
        }
        if (n == 0) {
                log_error("Didn't get any sockets passed in.");
                r = -EINVAL;
                goto finish;
        }

        for (fd = SD_LISTEN_FDS_START; fd < SD_LISTEN_FDS_START + n; fd++) {
                r = add_listen_socket(&context, fd);
                if (r < 0)
                        goto finish;
        }

        r = sd_event_loop(context.event);
        if (r < 0) {
                log_error_errno(r, "Failed to run event loop: %m");
                goto finish;
        }

finish:
        context_free(&context);

        return r < 0 ? EXIT_FAILURE : EXIT_SUCCESS;
}