Skip to content

Commit

Permalink
Improve the pool
Browse files Browse the repository at this point in the history
  • Loading branch information
jart committed Aug 31, 2024
1 parent ea39f26 commit 66a84d8
Showing 1 changed file with 47 additions and 36 deletions.
83 changes: 47 additions & 36 deletions llamafile/pool.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -48,46 +48,48 @@ struct llamafile_thread {
};

// Decremented after a worker's llamafile_thread object is deleted, and polled
// by llamafile_task_shutdown(), which keeps looping until it reads zero.
// Presumably a count of live pool threads; the increment is not visible here.
static atomic_int g_active;
// Top of the CAS-maintained linked stack of idle threads that idle_push() and
// idle_pop() operate on.
// NOTE(review): the two declarations below look like the before/after of one
// changed line in this rendered diff (a typed atomic pointer replaced by a
// uintptr_t word that also carries a tag); only one can exist in the real file.
static _Atomic(llamafile_thread *) g_idle;
static atomic_uintptr_t g_idle;

// Cleanup-handler adapter with the void(void *) shape that
// pthread_cleanup_push() expects: `arg` is the pthread_mutex_t to release.
static void unlock_mutex(void *arg) {
    pthread_mutex_unlock((pthread_mutex_t *)arg);
}
// Helpers for packing a pointer and a small tag into one uintptr_t word.
//
// MASQUE selects bits 4..55, which hold the pointer itself. The remaining
// bits (the top byte, bits 56..63, and the low nibble, bits 0..3) hold a tag,
// so any pointer stored this way must have those twelve bits clear.
#define MASQUE 0x00fffffffffffff0
// PTR(x): the pointer bits of x with the tag bits cleared.
#define PTR(x) ((uintptr_t)(x) & MASQUE)
// TAG(x): the tag bits of x as a 12-bit number. The left rotation by 8 moves
// the top byte into bits 0..7 and the low nibble into bits 8..11.
#define TAG(x) ROL((uintptr_t)(x) & ~MASQUE, 8)
// ABA(p, t): p with tag t packed into the non-pointer bits (the inverse of
// TAG). Bits of t above the low twelve rotate into the MASQUE range and are
// masked off, so the tag wraps modulo 4096. p is OR'ed in unmasked, so it
// must already have its tag bits clear.
// NOTE(review): presumably named after the ABA problem that a changing tag
// guards against in compare-and-swap loops -- confirm with the author.
#define ABA(p, t) ((uintptr_t)(p) | (ROR((uintptr_t)(t), 8) & ~MASQUE))
// Bit rotations written for a 64-bit operand. n must be in 1..63, because
// n == 0 would shift by 64, which is undefined for a 64-bit type. The macros
// above refer to these before their definition, which is fine because macro
// bodies are only expanded at the point of use.
#define ROL(x, n) (((x) << (n)) | ((x) >> (64 - (n))))
#define ROR(x, n) (((x) >> (n)) | ((x) << (64 - (n))))

// Pushes `thread` onto the g_idle stack, linking the previous top through
// thread->next.
// NOTE(review): this span is a rendered diff with its +/- markers stripped, so
// two implementations are interleaved. The `backoff` local and the `while`
// loop that calls pthread_delay_np() appear to be the removed version (g_idle
// as a typed atomic pointer); everything from `uintptr_t tip;` onward appears
// to be the added version (g_idle as a tagged uintptr_t word).
static void idle_push(llamafile_thread *thread) {
int backoff = 0;
thread->next = atomic_load_explicit(&g_idle, memory_order_relaxed);
while (!atomic_compare_exchange_weak_explicit(&g_idle, &thread->next, thread,
memory_order_acq_rel, memory_order_relaxed))
backoff = pthread_delay_np(&g_idle, backoff);
uintptr_t tip;
// The thread's address must have all tag bits clear, otherwise ABA() below
// would merge them into the stored tag.
unassert(!TAG(thread));
tip = atomic_load_explicit(&g_idle, memory_order_relaxed);
// Retry loop: point thread->next at the current top, then try to install
// `thread` with the tag advanced by one. A failed compare-exchange stores the
// freshly observed g_idle value into `tip`, so the next pass relinks to it.
for (;;) {
thread->next = (llamafile_thread *)PTR(tip);
if (atomic_compare_exchange_weak_explicit(&g_idle, &tip, ABA(thread, TAG(tip) + 1),
memory_order_release, memory_order_relaxed))
break;
}
}

// Pops one thread off the g_idle stack and returns it, or returns a null
// pointer when the stack is empty.
// NOTE(review): this span is a rendered diff with its +/- markers stripped, so
// two implementations are interleaved. The `backoff` local and the first
// `for (;;)` loop appear to be the removed version (g_idle as a typed atomic
// pointer); `tip` and the trailing `while` loop appear to be the added version
// (g_idle as a tagged uintptr_t word).
static llamafile_thread *idle_pop(void) {
int backoff = 0;
uintptr_t tip;
llamafile_thread *thread;
for (;;) {
if ((thread = atomic_load_explicit(&g_idle, memory_order_relaxed))) {
if (atomic_compare_exchange_weak_explicit(&g_idle, &thread, thread->next,
memory_order_acq_rel, memory_order_relaxed))
return thread;
backoff = pthread_delay_np(g_idle, backoff);
} else {
return nullptr;
}
}
// PTR() strips the tag to recover the top thread. The compare-exchange swings
// g_idle to thread->next with the tag advanced by one; on failure it stores
// the freshly observed value into `tip`, and the loop condition re-derives
// `thread` from it. A null top ends the loop with `thread` null.
tip = atomic_load_explicit(&g_idle, memory_order_relaxed);
while ((thread = (llamafile_thread *)PTR(tip)))
if (atomic_compare_exchange_weak_explicit(&g_idle, &tip, ABA(thread->next, TAG(tip) + 1),
memory_order_acquire, memory_order_relaxed))
break;
return thread;
}

// Marks `task` as canceled and wakes a waiter. While holding task->mu it sets
// task->res to PTHREAD_CANCELED, clears task->th, and signals task->cv.
// llamafile_task_join() waits on that condition variable until task->th reads
// zero.
// NOTE(review): the two stores to task->th below look like the before/after
// of one changed line in this rendered diff (plain assignment replaced by an
// explicit release store); the real file presumably holds only one of them.
static void cancel_task(llamafile_task *task) {
pthread_mutex_lock(&task->mu);
task->res = PTHREAD_CANCELED;
task->th = 0;
atomic_store_explicit(&task->th, 0, memory_order_release);
pthread_cond_signal(&task->cv);
pthread_mutex_unlock(&task->mu);
}

static void llamafile_thread_canceled(llamafile_thread *thread) {
thread->th = 0;
atomic_store_explicit(&thread->th, 0, memory_order_release);
cancel_task(thread->task);
delete thread;
--g_active;
Expand All @@ -103,16 +105,14 @@ static void *llamafile_thread_worker(void *arg) {
void *res = thread->task->func(thread->task->arg);
pthread_setcancelstate(PTHREAD_CANCEL_MASKED, 0);

for (;;) {
if (thread->th != -1)
if (thread->task->th != -1)
for (;;)
if (atomic_load_explicit(&thread->th, memory_order_acquire) != -1)
if (atomic_load_explicit(&thread->task->th, memory_order_acquire) != -1)
break;
pthread_pause_np();
}

pthread_mutex_lock(&thread->task->mu);
thread->task->res = res;
thread->task->th = 0;
atomic_store_explicit(&thread->task->th, 0, memory_order_release);
pthread_cond_signal(&thread->task->cv);
pthread_mutex_unlock(&thread->task->mu);

Expand All @@ -131,7 +131,7 @@ static void *llamafile_thread_worker(void *arg) {
if (thread->task)
cancel_task(thread->task);

thread->th = 0;
atomic_store_explicit(&thread->th, 0, memory_order_release);
g_key.set(nullptr);
delete thread;
--g_active;
Expand All @@ -151,7 +151,8 @@ static errno_t llamafile_thread_create(llamafile_task *task) {
errno_t err = pthread_create((pthread_t *)&thread->th, &attr, llamafile_thread_worker, thread);
pthread_attr_destroy(&attr);
if (!err) {
task->th = thread->th.load();
atomic_store_explicit(&task->th, atomic_load_explicit(&thread->th, memory_order_relaxed),
memory_order_release);
} else {
delete thread;
}
Expand All @@ -166,8 +167,9 @@ errno_t llamafile_task_create(llamafile_task **out_task, void *(*func)(void *),
llamafile_thread *thread;
if ((thread = idle_pop())) {
pthread_mutex_lock(&thread->mu);
atomic_store_explicit(&task->th, atomic_load_explicit(&thread->th, memory_order_relaxed),
memory_order_release);
thread->task = task;
task->th = thread->th.load();
pthread_cond_signal(&thread->cv);
pthread_mutex_unlock(&thread->mu);
err = 0;
Expand All @@ -182,10 +184,15 @@ errno_t llamafile_task_create(llamafile_task **out_task, void *(*func)(void *),
return err;
}

// Cleanup-handler adapter with the void(void *) shape that
// pthread_cleanup_push() expects: `arg` is the pthread_mutex_t to release.
static void unlock_mutex(void *arg) {
    pthread_mutex_unlock((pthread_mutex_t *)arg);
}

errno_t llamafile_task_join(llamafile_task *task, void **out_res) {
pthread_cleanup_push(unlock_mutex, &task->mu);
pthread_mutex_lock(&task->mu);
while (task->th)
while (atomic_load_explicit(&task->th, memory_order_acquire))
pthread_cond_wait(&task->cv, &task->mu);
pthread_cleanup_pop(true);
if (out_res)
Expand All @@ -195,9 +202,13 @@ errno_t llamafile_task_join(llamafile_task *task, void **out_res) {
}

// Requests cancellation of the thread currently associated with `task`.
// Returns the pthread_cancel() result; in the added version, returns ESRCH
// when task->th reads as zero.
// NOTE(review): this span is a rendered diff with its +/- markers stripped.
// The bare `if (task->th)` with its pthread_cancel(task->th) call appears to
// be the removed version (which left err at 0 when no thread was attached);
// the `th` local and the if/else around the acquire load appear to be the
// added version, which reads task->th once and cancels that snapshot.
errno_t llamafile_task_cancel(llamafile_task *task) {
pthread_t th;
errno_t err = 0;
if (task->th)
err = pthread_cancel(task->th);
if ((th = atomic_load_explicit(&task->th, memory_order_acquire))) {
err = pthread_cancel(th);
} else {
err = ESRCH;
}
return err;
}

Expand All @@ -207,7 +218,7 @@ void llamafile_task_shutdown(void) {
llamafile_thread *thread;
for (;;) {
while ((thread = idle_pop()))
if ((th = thread->th))
if ((th = atomic_load_explicit(&thread->th, memory_order_acquire)))
pthread_cancel(th);
if (!g_active)
break;
Expand Down

0 comments on commit 66a84d8

Please sign in to comment.