socket-proxy: use hash_ops with destructor for managing Connection

This also renames context_clear() -> context_done(), to follow our
recent coding style.
This commit is contained in:
Yu Watanabe 2025-04-13 01:16:00 +09:00
parent 234b86a444
commit f4a717aa4d
1 changed files with 51 additions and 52 deletions
src/socket-proxy

View File

@ -80,6 +80,24 @@ static Connection* connection_free(Connection *c) {
return mfree(c);
}
DEFINE_TRIVIAL_CLEANUP_FUNC(Connection*, connection_free);
DEFINE_PRIVATE_HASH_OPS_WITH_VALUE_DESTRUCTOR(
connection_hash_ops,
void, trivial_hash_func, trivial_compare_func,
Connection, connection_free);
static void context_done(Context *context) {
assert(context);
set_free_with_destructor(context->listen, sd_event_source_unref);
set_free(context->connections);
sd_event_unref(context->event);
sd_resolve_unref(context->resolve);
sd_event_source_unref(context->idle_time);
}
static int idle_time_cb(sd_event_source *s, uint64_t usec, void *userdata) {
Context *c = userdata;
int r;
@ -119,17 +137,6 @@ static void connection_release(Connection *c) {
context_reset_timer(c->context);
}
static void context_clear(Context *context) {
assert(context);
set_free_with_destructor(context->listen, sd_event_source_unref);
set_free_with_destructor(context->connections, connection_free);
sd_event_unref(context->event);
sd_resolve_unref(context->resolve);
sd_event_source_unref(context->idle_time);
}
static int connection_create_pipes(Connection *c, int buffer[static 2], size_t *sz) {
int r;
@ -456,70 +463,62 @@ fail:
return 0; /* ignore errors, continue serving */
}
static int add_connection_socket(Context *context, int fd) {
Connection *c;
static int context_add_connection(Context *context, int fd) {
int r;
assert(context);
assert(fd >= 0);
if (set_size(context->connections) > arg_connections_max) {
log_warning("Hit connection limit, refusing connection.");
safe_close(fd);
return 0;
_cleanup_close_ int nfd = accept4(fd, NULL, NULL, SOCK_NONBLOCK|SOCK_CLOEXEC);
if (nfd < 0) {
if (!ERRNO_IS_ACCEPT_AGAIN(errno))
log_warning_errno(errno, "Failed to accept() socket, ignoring: %m");
return -errno;
}
if (DEBUG_LOGGING) {
_cleanup_free_ char *peer = NULL;
(void) getpeername_pretty(nfd, true, &peer);
log_debug("New connection from %s", strna(peer));
}
if (set_size(context->connections) > arg_connections_max)
return log_warning_errno(SYNTHETIC_ERRNO(EBUSY), "Hit connection limit, refusing connection.");
r = sd_event_source_set_enabled(context->idle_time, SD_EVENT_OFF);
if (r < 0)
log_warning_errno(r, "Unable to disable idle timer, continuing: %m");
c = new(Connection, 1);
if (!c) {
log_oom();
return 0;
}
_cleanup_(connection_freep) Connection *c = new(Connection, 1);
if (!c)
return log_oom();
*c = (Connection) {
.context = context,
.server_fd = fd,
.client_fd = -EBADF,
.server_to_client_buffer = EBADF_PAIR,
.client_to_server_buffer = EBADF_PAIR,
.server_fd = TAKE_FD(nfd),
.client_fd = -EBADF,
.server_to_client_buffer = EBADF_PAIR,
.client_to_server_buffer = EBADF_PAIR,
};
r = set_ensure_put(&context->connections, NULL, c);
if (r < 0) {
free(c);
log_oom();
return 0;
}
r = set_ensure_put(&context->connections, &connection_hash_ops, c);
if (r < 0)
return log_oom();
return resolve_remote(c);
c->context = context;
return resolve_remote(TAKE_PTR(c));
}
static int accept_cb(sd_event_source *s, int fd, uint32_t revents, void *userdata) {
_cleanup_free_ char *peer = NULL;
Context *context = ASSERT_PTR(userdata);
int nfd = -EBADF, r;
int r;
assert(s);
assert(fd >= 0);
assert(revents & EPOLLIN);
nfd = accept4(fd, NULL, NULL, SOCK_NONBLOCK|SOCK_CLOEXEC);
if (nfd < 0) {
if (!ERRNO_IS_ACCEPT_AGAIN(errno))
log_warning_errno(errno, "Failed to accept() socket: %m");
} else {
(void) getpeername_pretty(nfd, true, &peer);
log_debug("New connection from %s", strna(peer));
r = add_connection_socket(context, nfd);
if (r < 0) {
log_warning_errno(r, "Failed to accept connection, ignoring: %m");
safe_close(nfd);
}
}
if (context_add_connection(context, fd) < 0)
context_reset_timer(context);
r = sd_event_source_set_enabled(s, SD_EVENT_ONESHOT);
if (r < 0)
@ -669,7 +668,7 @@ static int parse_argv(int argc, char *argv[]) {
}
static int run(int argc, char *argv[]) {
_cleanup_(context_clear) Context context = {};
_cleanup_(context_done) Context context = {};
_unused_ _cleanup_(notify_on_cleanup) const char *notify_stop = NULL;
int r, n, fd;