diff --git a/README.md b/README.md index f74852a..06e4c79 100644 --- a/README.md +++ b/README.md @@ -51,7 +51,7 @@ memcr -p ``` For the list of available options, check memcr help: ``` -memcr [-h] [-p PID] [-d DIR] [-S DIR] [-l PORT|PATH] [-n] [-m] [-f] [-z] [-c] [-e] +memcr [-h] [-p PID] [-d DIR] [-S DIR] [-l PORT|PATH] [-n] [-m] [-f] [-z] [-c] [-e] [-a] options: -h --help help -p --pid target processs pid @@ -67,6 +67,7 @@ options: -z --compress compress memory dump -c --checksum enable md5 checksum for memory dump -e --encrypt enable encryption of memory dump + -a --abort-checkpoint allow checkpoint to be aborted ``` memcr also supports client / server scenario where memcr runs as a deamon and listens for commands from a client process. The main reason for supporting this is that memcr needs rather high privileges to hijack target process and it's a good idea to keep it separate from memcr-client that can run in a container with low privileges. diff --git a/memcr.c b/memcr.c index c292ebd..18ca9c8 100644 --- a/memcr.c +++ b/memcr.c @@ -117,6 +117,8 @@ static int proc_mem; static int rss_file; static int compress; static int checksum; +static int abort_checkpoint; + #define BIT(x) (1ULL << x) @@ -177,8 +179,94 @@ int __attribute__((weak)) lib__fini(void); #define CHECKPOINTED_PIDS_LIMIT 16 #define PID_INVALID 0 -static pid_t checkpointed_pids[CHECKPOINTED_PIDS_LIMIT]; -static pid_t checkpoint_workers[CHECKPOINTED_PIDS_LIMIT]; +#define STATE_RESTORED 0 +#define STATE_CHECKPOINTING 1 +#define STATE_CHECKPOINTED 2 + +pthread_mutex_t checkpoint_service_data_lock = PTHREAD_MUTEX_INITIALIZER; +static struct { + pid_t pid; + pid_t worker; + int state; + int checkpoint_interrupt; + int checkpoint_cmd_sd; +} checkpoint_service_data[CHECKPOINTED_PIDS_LIMIT]; + +#define SOCKET_INVALID (-1) +static int checkpoint_service_socket = SOCKET_INVALID; + +#define TRUE 1 +#define FALSE 0 + +#define MAX_CLIENT_CONNECTIONS 8 + +struct service_command_ctx { + struct service_command svc_cmd; + int cd; +}; + +static struct { + pthread_mutex_t lock; + pthread_cond_t cond; + struct service_command_ctx svc_ctxs[MAX_CLIENT_CONNECTIONS]; + int front_idx; + int back_idx; + size_t size; + int interrupt; +} service_cmds_ctx = { .lock = PTHREAD_MUTEX_INITIALIZER, .cond = PTHREAD_COND_INITIALIZER }; + +static int service_cmds_push_back(struct service_command_ctx *ctx) +{ + int ret = 0; + pthread_mutex_lock(&service_cmds_ctx.lock); + if (service_cmds_ctx.size >= MAX_CLIENT_CONNECTIONS) { + fprintf(stderr, "[-] %s: Commands queue full\n", __func__); + ret = 1; + goto err; + } + + service_cmds_ctx.svc_ctxs[service_cmds_ctx.back_idx] = *ctx; + service_cmds_ctx.back_idx++; + service_cmds_ctx.size++; + if(service_cmds_ctx.back_idx >= MAX_CLIENT_CONNECTIONS) + service_cmds_ctx.back_idx = 0; + + pthread_cond_signal(&service_cmds_ctx.cond); + +err: + pthread_mutex_unlock(&service_cmds_ctx.lock); + return ret; +} + +static int service_cmds_wait_and_pop_front(struct service_command_ctx *ctx) +{ + int ret = 0; + pthread_mutex_lock(&service_cmds_ctx.lock); + while (service_cmds_ctx.size == 0 && service_cmds_ctx.interrupt == FALSE && ret == 0) + ret = pthread_cond_wait(&service_cmds_ctx.cond, &service_cmds_ctx.lock); + + if (!ret && service_cmds_ctx.size > 0 && service_cmds_ctx.interrupt == FALSE) { + *ctx = service_cmds_ctx.svc_ctxs[service_cmds_ctx.front_idx]; + service_cmds_ctx.front_idx++; + service_cmds_ctx.size--; + if(service_cmds_ctx.front_idx >= MAX_CLIENT_CONNECTIONS) + service_cmds_ctx.front_idx = 0; + } else if (service_cmds_ctx.interrupt == TRUE) + ret = 1; + else + fprintf(stderr, "[-] %s: pthread_cond_wait(): %m\n", __func__); + + pthread_mutex_unlock(&service_cmds_ctx.lock); + return ret; +} + +static void service_cmds_interrupt() +{ + pthread_mutex_lock(&service_cmds_ctx.lock); + service_cmds_ctx.interrupt = TRUE; + pthread_cond_signal(&service_cmds_ctx.cond); + pthread_mutex_unlock(&service_cmds_ctx.lock); +} #ifdef CHECKSUM_MD5 #if OPENSSL_VERSION_NUMBER >= 0x30000000L @@ -338,66 +426,6 @@ static void cleanup_pid(pid_t pid) unlink(path); } -static void cleanup_checkpointed_pids(void) -{ - fprintf(stdout, "[i] Terminating checkpointed processes\n"); - for (int i=0; i 0) @@ -721,6 +756,174 @@ static int dump_write(int fd, const void *buf, size_t count) return ret; } +static void init_pid_checkpoint_data(pid_t pid) +{ + pthread_mutex_lock(&checkpoint_service_data_lock); + for (int i=0; iaddr & (PAGE_SIZE - 1)) { @@ -838,7 +1041,7 @@ static int setup_listen_socket(struct sockaddr *addr, socklen_t addrlen) goto err; } - ret = listen(sd, 8); + ret = listen(sd, MAX_CLIENT_CONNECTIONS); if (ret) { fprintf(stderr, "listen() failed: %m\n"); goto err; @@ -1265,7 +1468,7 @@ static int get_vma_pages(int pd, int md, int cd, struct vm_area *vma, int fd) return 0; } -static int get_target_pages(int pid, struct vm_area vmas[], int nr_vmas) +static int get_target_pages(int pid, struct vm_area vmas[], int nr_vmas, int (*is_aborted)(void)) { int ret = -1; char path[PATH_MAX]; @@ -1308,8 +1511,12 @@ static int get_target_pages(int pid, struct vm_area vmas[], int nr_vmas) ret = 0; for (idx = 0; idx < nr_vmas; idx++) { - struct vm_area *vma = &vmas[idx]; + if (is_aborted && is_aborted()) { + fprintf(stdout, "[i] get target pages aborted\n"); + break; + } + struct vm_area *vma = &vmas[idx]; ret = get_vma_pages(pd, md, cd, vma, fd); if (ret) break; @@ -1471,7 +1678,7 @@ static long diff_ms(struct timespec *ts) return (tsn.tv_sec*1000 + tsn.tv_nsec/1000000) - (ts->tv_sec*1000 + ts->tv_nsec/1000000); } -static int cmd_checkpoint(pid_t pid) +static int cmd_checkpoint(pid_t pid, int (*is_aborted)(void)) { int ret; struct vm_stats vms_a, vms_b; @@ -1502,7 +1709,7 @@ static int cmd_checkpoint(pid_t pid) fprintf(stdout, "[+] downloading pages\n"); clock_gettime(CLOCK_MONOTONIC, &ts); - ret = get_target_pages(pid, vmas, nr_vmas); + ret = get_target_pages(pid, vmas, nr_vmas, is_aborted); #ifdef CHECKSUM_MD5 if (checksum) @@ -1887,7 +2094,7 @@ static int ctx_restore(pid_t pid) return 0; } -static int execute_parasite_checkpoint(pid_t pid) +static int execute_parasite_checkpoint(pid_t pid, int (*is_aborted)(void)) { unsigned long ret; @@ -1920,7 +2127,7 @@ static int execute_parasite_checkpoint(pid_t pid) parasite_watch(parasite_pid); - ret = cmd_checkpoint(pid); + ret = cmd_checkpoint(pid, is_aborted); return ret; } @@ -1962,7 +2169,7 @@ static void sigchld_handler_service (int sig, siginfo_t *sip, void *notused) int status; if (sip->si_pid == waitpid(sip->si_pid, &status, WNOHANG)) { fprintf(stdout, "[+] Worker %d exit.\n", sip->si_pid); - clear_pid_on_worker_exit(sip->si_pid); + clear_pid_on_worker_exit_non_blocking(sip->si_pid); } } @@ -2084,7 +2291,7 @@ static int setup_restore_socket_service(pid_t pid) return rd; } -static int checkpoint_worker(pid_t pid) +static int checkpoint_worker(pid_t pid, int (*is_aborted)(void)) { int ret; @@ -2092,7 +2299,7 @@ static int checkpoint_worker(pid_t pid) if (ret) return ret; - ret = execute_parasite_checkpoint(pid); + ret = execute_parasite_checkpoint(pid, is_aborted); if (ret) { fprintf(stderr, "[%d] Parasite checkpoint failed! Killing the target app...\n", getpid()); kill(pid, SIGKILL); @@ -2136,10 +2343,13 @@ static int application_worker(pid_t pid, int checkpoint_resp_socket) if (rsd < 0) ret |= rsd; + register_socket_for_checkpoint_service_cmds(checkpoint_resp_socket); + if (0 == ret) { - ret |= checkpoint_worker(pid); + ret |= checkpoint_worker(pid, is_checkpoint_aborted_srv_mode); } ret |= send_response_to_service(checkpoint_resp_socket, ret); // send resp to service + clear_socket_for_checkpoint_service_cmds(); close(checkpoint_resp_socket); if (ret) { @@ -2166,6 +2376,33 @@ static int application_worker(pid_t pid, int checkpoint_resp_socket) return ret; } +static void try_to_interrupt_checkpoint(pid_t pid) +{ + pthread_mutex_lock(&checkpoint_service_data_lock); + for (int i=0; i 0) { - set_pid_checkpointed(svc_cmd.pid, forkpid); close(checkpoint_resp_sockets[1]); - - checkpoint_procedure_service(checkpoint_resp_sockets[0], cd); + set_pid_checkpointing(svc_ctx.svc_cmd.pid, checkpoint_resp_sockets[0]); + checkpoint_procedure_service(checkpoint_resp_sockets[0], svc_ctx.cd); + set_pid_checkpointed(svc_ctx.svc_cmd.pid, forkpid); + close(checkpoint_resp_sockets[0]); } else { fprintf(stderr, "%s(): Fork error!\n", __func__); } @@ -2284,25 +2516,81 @@ static int handle_connection(int cd) break; } case MEMCR_RESTORE: { - fprintf(stdout, "[+] got MEMCR_RESTORE for %d.\n", svc_cmd.pid); + fprintf(stdout, "[+] handling MEMCR_RESTORE for %d.\n", svc_ctx.svc_cmd.pid); + restore_procedure_service(svc_ctx.cd, svc_ctx.svc_cmd); + clear_pid_checkpoint_data(svc_ctx.svc_cmd.pid); + break; + } + default: + fprintf(stderr, "%s() unexpected command %d\n", __func__, svc_ctx.svc_cmd.cmd); + break; + } - if (!is_pid_checkpointed(svc_cmd.pid)) { - fprintf(stdout, "[i] Process %d is not checkpointed!\n", svc_cmd.pid); - send_response_to_client(cd, MEMCR_INVALID_PID); - break; - } + close(svc_ctx.cd); + fprintf(stdout, "[+] cmd handled for %d. \n", svc_ctx.svc_cmd.pid); - restore_procedure_service(cd, svc_cmd); + goto retry; +} - clear_pid_checkpointed(svc_cmd.pid); +static void service_command(struct service_command_ctx *svc_ctx) +{ + int ret = MEMCR_OK; + switch (svc_ctx->svc_cmd.cmd) + { + case MEMCR_CHECKPOINT: + { + fprintf(stdout, "[+] got MEMCR_CHECKPOINT for %d.\n", svc_ctx->svc_cmd.pid); + + if (!can_checkpoint_pid(svc_ctx->svc_cmd.pid)) + { + fprintf(stdout, "[i] Process %d is already checkpointed or checkpoint is ongoing!\n", svc_ctx->svc_cmd.pid); + send_response_to_client(svc_ctx->cd, MEMCR_INVALID_PID); + close(svc_ctx->cd); break; } - default: - fprintf(stderr, "%s() unexpected command %d\n", __func__, svc_cmd.cmd); - break; + + init_pid_checkpoint_data(svc_ctx->svc_cmd.pid); + ret = service_cmds_push_back(svc_ctx); + if (!ret) + fprintf(stdout, "[+] Checkpoint request scheduled...\n"); + else { + fprintf(stdout, "[+] Checkpoint request schedule error.\n"); + clear_pid_checkpoint_data(svc_ctx->svc_cmd.pid); + send_response_to_client(svc_ctx->cd, MEMCR_ERROR_GENERAL); + close(svc_ctx->cd); + } + break; } + case MEMCR_RESTORE: + { + fprintf(stdout, "[+] got MEMCR_RESTORE for %d.\n", svc_ctx->svc_cmd.pid); - return ret; + if (!can_restore_pid(svc_ctx->svc_cmd.pid)) + { + fprintf(stdout, "[i] Process %d is not checkpointed or restore is ongoing!\n", svc_ctx->svc_cmd.pid); + send_response_to_client(svc_ctx->cd, MEMCR_INVALID_PID); + close(svc_ctx->cd); + break; + } + + if (abort_checkpoint) + try_to_interrupt_checkpoint(svc_ctx->svc_cmd.pid); + int ret = service_cmds_push_back(svc_ctx); + if (!ret) + fprintf(stdout, "[+] Restore request scheduled...\n"); + else { + fprintf(stdout, "[+] Restore request schedule error.\n"); + send_response_to_client(svc_ctx->cd, MEMCR_ERROR_GENERAL); + close(svc_ctx->cd); + } + break; + } + default: + fprintf(stderr, "%s() unexpected command %d\n", __func__, svc_ctx->svc_cmd.cmd); + send_response_to_client(svc_ctx->cd, MEMCR_ERROR_GENERAL); + close(svc_ctx->cd); + break; + } } static int service_mode(const char *listen_location) @@ -2314,6 +2602,7 @@ static int service_mode(const char *listen_location) fd_set readfds; struct timeval tv; int errsv; + pthread_t svc_cmd_thread_id; if (listen_port > 0) csd = setup_listen_tcp_socket(listen_port); @@ -2326,6 +2615,12 @@ static int service_mode(const char *listen_location) flags = fcntl(csd, F_GETFL); fcntl(csd, F_SETFL, flags | O_NONBLOCK); + ret = pthread_create(&svc_cmd_thread_id, NULL, service_command_thread, NULL); + if (ret) { + printf("[-] pthread_create() failed: %s\n", strerror(ret)); + goto err; + } + fprintf(stdout, "[x] Waiting for a checkpoint command on a socket\n"); while (!interrupted) { @@ -2348,9 +2643,15 @@ static int service_mode(const char *listen_location) cd = accept(csd, NULL, NULL); if (cd >= 0) { - ret = handle_connection(cd); - close(cd); - fprintf(stdout, "[+] Request handled...\n"); + struct service_command_ctx svc_ctx = { .cd = cd }; + ret = read_command(cd, &svc_ctx.svc_cmd); + if (ret < 0) { + fprintf(stderr, "%s(): Error reading a command!\n", __func__); + close(cd); + continue; + } + + service_command(&svc_ctx); continue; } @@ -2361,6 +2662,11 @@ static int service_mode(const char *listen_location) } } + service_cmds_interrupt(); + pthread_join(svc_cmd_thread_id, NULL); + +err: + close(csd); if (!listen_port) unlink(listen_location); @@ -2370,16 +2676,38 @@ static int service_mode(const char *listen_location) return ret; } +static int is_checkpoint_aborted_ui_mode() +{ + return interrupted; +} + +static void* user_abort_thread(void *ptr) +{ + fprintf(stdout, "[x] --> press enter to abort checkpoint <--\n"); + fgetc(stdin); + interrupted = TRUE; + return NULL; +} + static int user_interactive_mode(pid_t pid) { int ret; + pthread_t user_abort_thread_id; + + if (abort_checkpoint) { + ret = pthread_create(&user_abort_thread_id, NULL, user_abort_thread, NULL); + if (ret) { + printf("[-] pthread_create() failed: %s\n", strerror(ret)); + return ret; + } + } ret = seize_target(pid); if (ret) return ret; - ret = execute_parasite_checkpoint(pid); - if (ret) + ret = execute_parasite_checkpoint(pid, is_checkpoint_aborted_ui_mode); + if(ret) goto out; if (!no_wait) { @@ -2387,9 +2715,15 @@ static int user_interactive_mode(pid_t pid) long h, m, s, ms; struct timespec ts; - fprintf(stdout, "[x] --> press enter to restore process memory and unfreeze <--\n"); + if (!interrupted) + fprintf(stdout, "[x] --> press enter to restore process memory and unfreeze <--\n"); clock_gettime(CLOCK_MONOTONIC, &ts); - fgetc(stdin); + + if (abort_checkpoint) + pthread_join(user_abort_thread_id, NULL); + else + fgetc(stdin); + dms = diff_ms(&ts); h = dms/1000/60/60; m = (dms/1000/60) % 60; @@ -2425,7 +2759,8 @@ static void usage(const char *name, int status) " -f --rss-file include file mapped memory\n" \ " -z --compress compress memory dump\n" \ " -c --checksum enable md5 checksum for memory dump\n" \ - " -e --encrypt enable encryption of memory dump\n", + " -e --encrypt enable encryption of memory dump\n" \ + " -a --abort-checkpoint allow checkpoint to be aborted\n", name); exit(status); @@ -2463,6 +2798,7 @@ int main(int argc, char *argv[]) { "rss-file", 0, NULL, 'f'}, { "compress", 0, NULL, 'z'}, { "checksum", 0, NULL, 'c'}, + { "abort-checkpoint", 0, NULL, 'a'}, { "encrypt", 2, 0, 'e'}, { NULL, 0, NULL, 0} }; @@ -2470,7 +2806,7 @@ int main(int argc, char *argv[]) dump_dir = "/tmp"; parasite_socket_dir = NULL; - while ((opt = getopt_long(argc, argv, "hp:d:S:l:nmfzce::", long_options, &option_index)) != -1) { + while ((opt = getopt_long(argc, argv, "hp:d:S:l:nmfzcae::", long_options, &option_index)) != -1) { switch (opt) { case 'h': usage(argv[0], 0); @@ -2515,6 +2851,9 @@ int main(int argc, char *argv[]) else if (optind < argc && argv[optind][0] != '-') encrypt_arg = argv[optind++]; break; + case 'a': + abort_checkpoint = 1; + break; default: /* '?' */ usage(argv[0], 1); }