From 63c436fe4aa982bfabd4acd01d66c9f99da1911b Mon Sep 17 00:00:00 2001 From: Adrian Muzyka Date: Wed, 15 Nov 2023 06:37:44 -0500 Subject: [PATCH] Add checkpoint abort --- memcr.c | 559 +++++++++++++++++++++++++++++++++++++++++++++----------- 1 file changed, 450 insertions(+), 109 deletions(-) diff --git a/memcr.c b/memcr.c index c292ebd..2d7ffff 100644 --- a/memcr.c +++ b/memcr.c @@ -117,6 +117,8 @@ static int proc_mem; static int rss_file; static int compress; static int checksum; +static int service; + #define BIT(x) (1ULL << x) @@ -177,8 +179,97 @@ int __attribute__((weak)) lib__fini(void); #define CHECKPOINTED_PIDS_LIMIT 16 #define PID_INVALID 0 -static pid_t checkpointed_pids[CHECKPOINTED_PIDS_LIMIT]; -static pid_t checkpoint_workers[CHECKPOINTED_PIDS_LIMIT]; +#define STATE_RESTORED 0 +#define STATE_CHECKPOINTING 1 +#define STATE_CHECKPOINTED 2 + +static pthread_mutex_t checkpoint_service_data_lock = PTHREAD_MUTEX_INITIALIZER; +static struct { + pid_t pid; + pid_t worker; + int state; + int checkpoint_abort; + int checkpoint_cmd_sd; +} checkpoint_service_data[CHECKPOINTED_PIDS_LIMIT]; + +static pthread_mutex_t checkpoint_user_lock = PTHREAD_MUTEX_INITIALIZER; +static int checkpoint_user_abort; + +#define SOCKET_INVALID (-1) +static int checkpoint_service_socket = SOCKET_INVALID; + +#define TRUE 1 +#define FALSE 0 + +#define MAX_CLIENT_CONNECTIONS 8 + +struct service_command_ctx { + struct service_command svc_cmd; + int cd; +}; + +static struct { + pthread_mutex_t lock; + pthread_cond_t cond; + struct service_command_ctx svc_ctxs[MAX_CLIENT_CONNECTIONS]; + int front_idx; + int back_idx; + size_t size; + int interrupt; +} service_cmds_ctx = { .lock = PTHREAD_MUTEX_INITIALIZER, .cond = PTHREAD_COND_INITIALIZER }; + +static int service_cmds_push_back(struct service_command_ctx *ctx) +{ + int ret = 0; + pthread_mutex_lock(&service_cmds_ctx.lock); + if (service_cmds_ctx.size >= MAX_CLIENT_CONNECTIONS) { + fprintf(stderr, "[-] %s: Commands queue full\n", __func__); + ret = 1; + goto err; + } + + service_cmds_ctx.svc_ctxs[service_cmds_ctx.back_idx] = *ctx; + service_cmds_ctx.back_idx++; + service_cmds_ctx.size++; + if(service_cmds_ctx.back_idx >= MAX_CLIENT_CONNECTIONS) + service_cmds_ctx.back_idx = 0; + + pthread_cond_signal(&service_cmds_ctx.cond); + +err: + pthread_mutex_unlock(&service_cmds_ctx.lock); + return ret; +} + +static int service_cmds_wait_and_pop_front(struct service_command_ctx *ctx) +{ + int ret = 0; + pthread_mutex_lock(&service_cmds_ctx.lock); + while (service_cmds_ctx.size == 0 && service_cmds_ctx.interrupt == FALSE && ret == 0) + ret = pthread_cond_wait(&service_cmds_ctx.cond, &service_cmds_ctx.lock); + + if (!ret && service_cmds_ctx.size > 0 && service_cmds_ctx.interrupt == FALSE) { + *ctx = service_cmds_ctx.svc_ctxs[service_cmds_ctx.front_idx]; + service_cmds_ctx.front_idx++; + service_cmds_ctx.size--; + if(service_cmds_ctx.front_idx >= MAX_CLIENT_CONNECTIONS) + service_cmds_ctx.front_idx = 0; + } else if (service_cmds_ctx.interrupt == TRUE) + ret = 1; + else + fprintf(stderr, "[-] %s: pthread_cond_wait(): %m\n", __func__); + + pthread_mutex_unlock(&service_cmds_ctx.lock); + return ret; +} + +static void service_cmds_interrupt(void) +{ + pthread_mutex_lock(&service_cmds_ctx.lock); + service_cmds_ctx.interrupt = TRUE; + pthread_cond_signal(&service_cmds_ctx.cond); + pthread_mutex_unlock(&service_cmds_ctx.lock); +} #ifdef CHECKSUM_MD5 #if OPENSSL_VERSION_NUMBER >= 0x30000000L @@ -338,66 +429,6 @@ static void cleanup_pid(pid_t pid) unlink(path); } -static void cleanup_checkpointed_pids(void) -{ - fprintf(stdout, "[i] Terminating checkpointed processes\n"); - for (int i=0; i 0) @@ -721,6 +759,184 @@ static int dump_write(int fd, const void *buf, size_t count) return ret; } +static void init_pid_checkpoint_data(pid_t pid) +{ + pthread_mutex_lock(&checkpoint_service_data_lock); + for (int i=0; iaddr & (PAGE_SIZE - 1)) { @@ -838,7 +1054,7 @@ static int setup_listen_socket(struct sockaddr *addr, socklen_t addrlen) goto err; } - ret = listen(sd, 8); + ret = listen(sd, MAX_CLIENT_CONNECTIONS); if (ret) { fprintf(stderr, "listen() failed: %m\n"); goto err; @@ -1308,8 +1524,12 @@ static int get_target_pages(int pid, struct vm_area vmas[], int nr_vmas) ret = 0; for (idx = 0; idx < nr_vmas; idx++) { - struct vm_area *vma = &vmas[idx]; + if (is_checkpoint_aborted()) { + fprintf(stdout, "[i] get target pages aborted\n"); + break; + } + struct vm_area *vma = &vmas[idx]; ret = get_vma_pages(pd, md, cd, vma, fd); if (ret) break; @@ -1962,7 +2182,7 @@ static void sigchld_handler_service (int sig, siginfo_t *sip, void *notused) int status; if (sip->si_pid == waitpid(sip->si_pid, &status, WNOHANG)) { fprintf(stdout, "[+] Worker %d exit.\n", sip->si_pid); - clear_pid_on_worker_exit(sip->si_pid); + clear_pid_on_worker_exit_non_blocking(sip->si_pid); } } @@ -2136,10 +2356,13 @@ static int application_worker(pid_t pid, int checkpoint_resp_socket) if (rsd < 0) ret |= rsd; + register_socket_for_checkpoint_service_cmds(checkpoint_resp_socket); + if (0 == ret) { ret |= checkpoint_worker(pid); } ret |= send_response_to_service(checkpoint_resp_socket, ret); // send resp to service + clear_socket_for_checkpoint_service_cmds(); close(checkpoint_resp_socket); if (ret) { @@ -2166,6 +2389,33 @@ static int application_worker(pid_t pid, int checkpoint_resp_socket) return ret; } +static void try_to_abort_checkpoint(pid_t pid) +{ + pthread_mutex_lock(&checkpoint_service_data_lock); + for (int i=0; i 0) { - set_pid_checkpointed(svc_cmd.pid, forkpid); close(checkpoint_resp_sockets[1]); - - checkpoint_procedure_service(checkpoint_resp_sockets[0], cd); + set_pid_checkpointing(svc_ctx.svc_cmd.pid, checkpoint_resp_sockets[0]); + checkpoint_procedure_service(checkpoint_resp_sockets[0], svc_ctx.cd); + set_pid_checkpointed(svc_ctx.svc_cmd.pid, forkpid); + close(checkpoint_resp_sockets[0]); } else { fprintf(stderr, "%s(): Fork error!\n", __func__); } @@ -2284,25 +2529,80 @@ static int handle_connection(int cd) break; } case MEMCR_RESTORE: { - fprintf(stdout, "[+] got MEMCR_RESTORE for %d.\n", svc_cmd.pid); + fprintf(stdout, "[+] handling MEMCR_RESTORE for %d.\n", svc_ctx.svc_cmd.pid); + restore_procedure_service(svc_ctx.cd, svc_ctx.svc_cmd); + clear_pid_checkpoint_data(svc_ctx.svc_cmd.pid); + break; + } + default: + fprintf(stderr, "%s() unexpected command %d\n", __func__, svc_ctx.svc_cmd.cmd); + break; + } - if (!is_pid_checkpointed(svc_cmd.pid)) { - fprintf(stdout, "[i] Process %d is not checkpointed!\n", svc_cmd.pid); - send_response_to_client(cd, MEMCR_INVALID_PID); - break; - } + close(svc_ctx.cd); + fprintf(stdout, "[+] cmd handled for %d. \n", svc_ctx.svc_cmd.pid); + + goto retry; +} - restore_procedure_service(cd, svc_cmd); +static void service_command(struct service_command_ctx *svc_ctx) +{ + int ret = MEMCR_OK; + switch (svc_ctx->svc_cmd.cmd) + { + case MEMCR_CHECKPOINT: + { + fprintf(stdout, "[+] got MEMCR_CHECKPOINT for %d.\n", svc_ctx->svc_cmd.pid); - clear_pid_checkpointed(svc_cmd.pid); + if (!can_checkpoint_pid(svc_ctx->svc_cmd.pid)) + { + fprintf(stdout, "[i] Process %d is already checkpointed or checkpoint is ongoing!\n", svc_ctx->svc_cmd.pid); + send_response_to_client(svc_ctx->cd, MEMCR_INVALID_PID); + close(svc_ctx->cd); break; } - default: - fprintf(stderr, "%s() unexpected command %d\n", __func__, svc_cmd.cmd); - break; + + init_pid_checkpoint_data(svc_ctx->svc_cmd.pid); + ret = service_cmds_push_back(svc_ctx); + if (!ret) + fprintf(stdout, "[+] Checkpoint request scheduled...\n"); + else { + fprintf(stdout, "[+] Checkpoint request schedule error.\n"); + clear_pid_checkpoint_data(svc_ctx->svc_cmd.pid); + send_response_to_client(svc_ctx->cd, MEMCR_ERROR_GENERAL); + close(svc_ctx->cd); + } + break; } + case MEMCR_RESTORE: + { + fprintf(stdout, "[+] got MEMCR_RESTORE for %d.\n", svc_ctx->svc_cmd.pid); - return ret; + if (!can_restore_pid(svc_ctx->svc_cmd.pid)) + { + fprintf(stdout, "[i] Process %d is not checkpointed!\n", svc_ctx->svc_cmd.pid); + send_response_to_client(svc_ctx->cd, MEMCR_INVALID_PID); + close(svc_ctx->cd); + break; + } + + try_to_abort_checkpoint(svc_ctx->svc_cmd.pid); + int ret = service_cmds_push_back(svc_ctx); + if (!ret) + fprintf(stdout, "[+] Restore request scheduled...\n"); + else { + fprintf(stdout, "[+] Restore request schedule error.\n"); + send_response_to_client(svc_ctx->cd, MEMCR_ERROR_GENERAL); + close(svc_ctx->cd); + } + break; + } + default: + fprintf(stderr, "%s() unexpected command %d\n", __func__, svc_ctx->svc_cmd.cmd); + send_response_to_client(svc_ctx->cd, MEMCR_ERROR_GENERAL); + close(svc_ctx->cd); + break; + } } static int service_mode(const char *listen_location) @@ -2314,6 +2614,7 @@ static int service_mode(const char *listen_location) fd_set readfds; struct timeval tv; int errsv; + pthread_t svc_cmd_thread_id; if (listen_port > 0) csd = setup_listen_tcp_socket(listen_port); @@ -2326,6 +2627,12 @@ static int service_mode(const char *listen_location) flags = fcntl(csd, F_GETFL); fcntl(csd, F_SETFL, flags | O_NONBLOCK); + ret = pthread_create(&svc_cmd_thread_id, NULL, service_command_thread, NULL); + if (ret) { + printf("[-] pthread_create() failed: %s\n", strerror(ret)); + goto err; + } + fprintf(stdout, "[x] Waiting for a checkpoint command on a socket\n"); while (!interrupted) { @@ -2348,9 +2655,15 @@ static int service_mode(const char *listen_location) cd = accept(csd, NULL, NULL); if (cd >= 0) { - ret = handle_connection(cd); - close(cd); - fprintf(stdout, "[+] Request handled...\n"); + struct service_command_ctx svc_ctx = { .cd = cd }; + ret = read_command(cd, &svc_ctx.svc_cmd); + if (ret < 0) { + fprintf(stderr, "%s(): Error reading a command!\n", __func__); + close(cd); + continue; + } + + service_command(&svc_ctx); continue; } @@ -2361,6 +2674,11 @@ static int service_mode(const char *listen_location) } } + service_cmds_interrupt(); + pthread_join(svc_cmd_thread_id, NULL); + +err: + close(csd); if (!listen_port) unlink(listen_location); @@ -2370,16 +2688,33 @@ static int service_mode(const char *listen_location) return ret; } +static void* user_abort_thread(void *ptr) +{ + fprintf(stdout, "[x] --> press enter to abort checkpoint <--\n"); + fgetc(stdin); + pthread_mutex_lock(&checkpoint_user_lock); + checkpoint_user_abort = TRUE; + pthread_mutex_unlock(&checkpoint_user_lock); + return NULL; +} + static int user_interactive_mode(pid_t pid) { int ret; + pthread_t user_abort_thread_id; + + ret = pthread_create(&user_abort_thread_id, NULL, user_abort_thread, NULL); + if (ret) { + printf("[-] pthread_create() failed: %s\n", strerror(ret)); + return ret; + } ret = seize_target(pid); if (ret) return ret; ret = execute_parasite_checkpoint(pid); - if (ret) + if(ret) goto out; if (!no_wait) { @@ -2387,9 +2722,14 @@ static int user_interactive_mode(pid_t pid) long h, m, s, ms; struct timespec ts; - fprintf(stdout, "[x] --> press enter to restore process memory and unfreeze <--\n"); + pthread_mutex_lock(&checkpoint_user_lock); + if (!checkpoint_user_abort && !interrupted) + fprintf(stdout, "[x] --> press enter to restore process memory and unfreeze <--\n"); + pthread_mutex_unlock(&checkpoint_user_lock); clock_gettime(CLOCK_MONOTONIC, &ts); - fgetc(stdin); + + pthread_join(user_abort_thread_id, NULL); + dms = diff_ms(&ts); h = dms/1000/60/60; m = (dms/1000/60) % 60; @@ -2486,6 +2826,7 @@ int main(int argc, char *argv[]) break; case 'l': listen_location = optarg; + service = 1; break; case 'n': no_wait = 1;