Skip to content

Commit

Permalink
added help/usage to apps and new makefile targets. (#143)
Browse files Browse the repository at this point in the history
  • Loading branch information
jkeegan authored Dec 9, 2024
1 parent 2b480e8 commit 0975af8
Show file tree
Hide file tree
Showing 9 changed files with 241 additions and 42 deletions.
13 changes: 10 additions & 3 deletions Makefile
Original file line number Diff line number Diff line change
@@ -1,21 +1,28 @@
CXX = g++
CXXFLAGS = -std=c++11 -Werror -O3 -march=native -mtune=native
CXXFLAGS = -std=c++11 -Werror -O3 -march=native -mtune=native -Wformat -Werror=format-security

# Conditional settings for Windows
ifeq ($(OS),Windows_NT)
LIBS = -lws2_32 # or -lpthreadGC2 if needed
DELETECMD = del /f
else
LIBS = -lpthread
DELETECMD = rm -fv
endif

.PHONY: all apps clean tests

apps: dllama dllama-api socket-benchmark
tests: funcs-test quants-test tokenizer-test commands-test llama2-tasks-test
all: apps tests
clean:
$(DELETECMD) *.o dllama dllama-* socket-benchmark mmap-buffer-* *-test *.exe
utils: src/utils.cpp
$(CXX) $(CXXFLAGS) -c src/utils.cpp -o utils.o
quants: src/quants.cpp
$(CXX) $(CXXFLAGS) -c src/quants.cpp -o quants.o
funcs: src/funcs.cpp
$(CXX) $(CXXFLAGS) -c src/funcs.cpp -o funcs.o
funcs-test: src/funcs-test.cpp funcs
$(CXX) $(CXXFLAGS) src/funcs-test.cpp -o funcs-test funcs.o
commands: src/commands.cpp
$(CXX) $(CXXFLAGS) -c src/commands.cpp -o commands.o
socket: src/socket.cpp
Expand Down
30 changes: 20 additions & 10 deletions src/app.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -12,21 +12,22 @@ FloatType parseFloatType(char* val) {
if (strcmp(val, "f16") == 0) return F16;
if (strcmp(val, "q40") == 0) return Q40;
if (strcmp(val, "q80") == 0) return Q80;
printf("Invalid float type %s\n", val);
exit(EXIT_FAILURE);
std::string errMsg = "Invalid float type '" + std::string(val) + "'";
throw BadArgumentException(errMsg);
}

ChatTemplateType parseChatTemplateType(char* val) {
if (strcmp(val, "llama2") == 0) return TEMPLATE_LLAMA2;
if (strcmp(val, "llama3") == 0) return TEMPLATE_LLAMA3;
if (strcmp(val, "zephyr") == 0) return TEMPLATE_ZEPHYR;
if (strcmp(val, "chatml") == 0) return TEMPLATE_CHATML;
throw std::runtime_error("Invalid chat template type");

std::string errMsg = "Invalid chat template type '" + std::string(val) + "'";
throw BadArgumentException(errMsg);
}

AppArgs AppArgs::parse(int argc, char** argv, bool hasMode) {
AppArgs args;
args.help = false;
args.mode = NULL;
args.nThreads = 4;
args.modelPath = NULL;
Expand All @@ -48,6 +49,15 @@ AppArgs AppArgs::parse(int argc, char** argv, bool hasMode) {
args.mode = argv[1];
i++;
}
// First see if any of the args are asking for help/usage and fail fast
for (int x = 0; x < argc; x++) {
if ((strcmp(argv[x], "--usage") == 0) ||
(strcmp(argv[x], "--help") == 0) ||
(strcmp(argv[x], "-h") == 0)) {
args.help = true;
return args;
}
}
for (; i + 1 < argc; i += 2) {
char* name = argv[i];
char* value = argv[i + 1];
Expand All @@ -74,8 +84,8 @@ AppArgs AppArgs::parse(int argc, char** argv, bool hasMode) {
char* v = argv[i + 1 + s];
char* sep = strstr(v, ":");
if (sep == NULL) {
printf("Invalid address %s\n", v);
exit(EXIT_FAILURE);
std::string errMsg = "Invalid worker address '" + std::string(v) + "'";
throw BadArgumentException(errMsg);
}
int hostLen = sep - v;
args.workerHosts[s] = new char[hostLen + 1];
Expand Down Expand Up @@ -104,8 +114,8 @@ AppArgs AppArgs::parse(int argc, char** argv, bool hasMode) {
} else if (strcmp(name, "--packet-alignment") == 0) {
args.packetAlignment = (size_t)atoi(value);
} else {
printf("Unknown option %s\n", name);
exit(EXIT_FAILURE);
std::string errMsg = "Unknown option '" + std::string(name) + "'";
throw BadArgumentException(errMsg);
}
}
return args;
Expand All @@ -119,10 +129,10 @@ TransformerArch TransformerArchFactory::create(TransformerSpec* spec) {

void App::run(AppArgs* args, void (*program)(Inference* inference, SocketPool* socketPool, Tokenizer* tokenizer, Sampler* sampler, AppArgs* args, TransformerSpec* spec)) {
if (args->modelPath == NULL) {
throw std::runtime_error("Model is required");
throw BadArgumentException("Model is required");
}
if (args->tokenizerPath == NULL) {
throw std::runtime_error("Tokenizer is required");
throw BadArgumentException("Tokenizer is required");
}

SocketPool* socketPool = SocketPool::connect(args->nWorkers, args->workerHosts, args->workerPorts, args->packetAlignment);
Expand Down
8 changes: 8 additions & 0 deletions src/app.hpp
Original file line number Diff line number Diff line change
Expand Up @@ -11,12 +11,15 @@
#include "llama2-tasks.hpp"
#include "mixtral-tasks.hpp"
#include "tokenizer.hpp"
#include <stdexcept>
#include <string>

class AppArgs {
public:
char* mode;
int nThreads;
size_t packetAlignment;
bool help;

// inference
char* modelPath;
Expand Down Expand Up @@ -51,4 +54,9 @@ class App {
static void run(AppArgs* args, void (*program)(Inference* inference, SocketPool* socketPool, Tokenizer* tokenizer, Sampler* sampler, AppArgs* args, TransformerSpec* spec));
};

class BadArgumentException : public std::runtime_error {
public:
explicit BadArgumentException(const std::string& message) : std::runtime_error(message) {}
};

#endif
41 changes: 39 additions & 2 deletions src/apps/dllama-api/dllama-api.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -462,12 +462,49 @@ void server(Inference* inference, SocketPool* socketPool, Tokenizer *tokenizer,
closeServerSocket(serverSocket);
}

#ifdef _WIN32
#define EXECUTABLE_NAME "dllama-api.exe"
#else
#define EXECUTABLE_NAME "dllama-api"
#endif

void usage() {
fprintf(stderr, "Usage: %s {--model <path>} {--tokenizer <path>} [--port <p>]\n", EXECUTABLE_NAME);
fprintf(stderr, " [--buffer-float-type {f32|f16|q40|q80}]\n");
fprintf(stderr, " [--weights-float-type {f32|f16|q40|q80}]\n");
fprintf(stderr, " [--max-seq-len <max>]\n");
fprintf(stderr, " [--nthreads <n>]\n");
fprintf(stderr, " [--workers <ip:port> ...]\n");
fprintf(stderr, " [--packet-alignment <pa>]\n");
fprintf(stderr, " [--temperature <temp>]\n");
fprintf(stderr, " [--topp <t>]\n");
fprintf(stderr, " [--seed <s>]\n");
fprintf(stderr, "Example:\n");
fprintf(stderr, " sudo nice -n -20 ./dllama-api --port 9990 --nthreads 4 \\\n");
fprintf(stderr, " --model dllama_model_llama3_2_3b_instruct_q40.m \\\n");
fprintf(stderr, " --tokenizer dllama_tokenizer_llama3_2_3b_instruct_q40.t \\\n");
fprintf(stderr, " --buffer-float-type q80 --max-seq-len 8192 \\\n");
fprintf(stderr, " --workers 10.0.0.2:9998 10.0.0.3:9998 10.0.0.4:9998\n");
fflush(stderr);
}

int main(int argc, char *argv[]) {
initQuants();
initSockets();

AppArgs args = AppArgs::parse(argc, argv, false);
App::run(&args, server);
try {
AppArgs args = AppArgs::parse(argc, argv, false);
if (args.help) {
usage();
return EXIT_SUCCESS;
}
App::run(&args, server);
} catch (const BadArgumentException& e) {
fprintf(stderr, "%s\n\n", e.what());
usage();
cleanupSockets();
return EXIT_FAILURE;
}

cleanupSockets();
return EXIT_SUCCESS;
Expand Down
169 changes: 151 additions & 18 deletions src/apps/dllama/dllama.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -6,6 +6,7 @@
#include <stdexcept>
#include <sstream>
#include <string>
#include <unordered_map>

#include "../../utils.hpp"
#include "../../socket.hpp"
Expand All @@ -16,7 +17,7 @@

void generate(Inference* inference, SocketPool* socketPool, Tokenizer *tokenizer, Sampler *sampler, AppArgs* args, TransformerSpec* spec) {
if (args->prompt == NULL)
throw std::runtime_error("Prompt is required");
throw BadArgumentException("Prompt is required");

// encode the (string) prompt into tokens sequence
int numPromptTokens = 0;
Expand Down Expand Up @@ -220,35 +221,167 @@ void worker(AppArgs* args) {
delete socketPool;
}

#ifdef _WIN32
#define EXECUTABLE_NAME "dllama.exe"
#else
#define EXECUTABLE_NAME "dllama"
#endif

bool isValidMode(const char *mode) {
if (mode == NULL) {
return false;
} else {
return (strcmp(mode, "generate") == 0) ||
(strcmp(mode, "inference") == 0) ||
(strcmp(mode, "chat") == 0) ||
(strcmp(mode, "worker") == 0);
}
}

std::unordered_map<std::string, std::string> examples = {
{"generate", ""},
{"inference", R"( sudo nice -n -20 ./dllama inference \
--prompt "Super briefly describe the 80s - ten words."\
--steps 32 --seed 12345 \
--model dllama_model_llama3_2_3b_instruct_q40.m \
--tokenizer dllama_tokenizer_llama3_2_3b_instruct_q40.t \
--buffer-float-type q80 --nthreads 4 --max-seq-len 8192 \
--workers 10.0.0.2:9998 10.0.0.3:9998 10.0.0.4:9998
)"},
{"chat", R"( sudo nice -n -20 ./dllama chat \
--model dllama_model_llama3_2_3b_instruct_q40.m \
--tokenizer dllama_tokenizer_llama3_2_3b_instruct_q40.t \
--buffer-float-type q80 --nthreads 4 --max-seq-len 8192 \
--workers 10.0.0.2:9998 10.0.0.3:9998 10.0.0.4:9998
)"},
{"worker", R"( sudo nice -n -20 ./dllama worker --port 9998 --nthreads 4
)"}};

std::string inference_and_generate_usage_string =
R"( {inference|generate} {--model <path>} {--tokenizer <path>}
{--prompt <p>}
[--steps <s>]
[--buffer-float-type {f32|f16|q40|q80}]
[--weights-float-type {f32|f16|q40|q80}]
[--max-seq-len <max>]
[--nthreads <n>]
[--workers <ip:port> ...]
[--packet-alignment <pa>]
[--temperature <temp>]
[--topp <t>]
[--seed <s>]
)";

std::unordered_map<std::string, std::string> usageText = {
{"generate", inference_and_generate_usage_string},
{"inference", inference_and_generate_usage_string},
{"chat", R"( chat {--model <path>} {--tokenizer <path>}
[--buffer-float-type {f32|f16|q40|q80}]
[--weights-float-type {f32|f16|q40|q80}]
[--max-seq-len <max>]
[--nthreads <n>]
[--workers <ip:port> ...]
[--packet-alignment <pa>]
[--temperature <temp>]
[--topp <t>]
[--seed <s>]
[--chat-template {llama2|llama3|zephyr|chatml}]
)"},
{"worker", R"( worker [--nthreads <n>] [--port <p>]
)"}};

#define MULTIPLE_USAGES_PREFIX " "
#define SOLO_USAGE_PREFIX "Usage: "

void usage(const char *mode, bool solo=true) {
if (!isValidMode(mode)) {
fprintf(stderr, "Usage: %s {inference | generate | chat | worker} {ARGS}\n", EXECUTABLE_NAME);
usage("inference", false);
usage("chat", false);
usage("worker", false);
fprintf(stderr, "Examples:\n");
fprintf(stderr, "%s", examples["worker"].c_str());
fprintf(stderr, "%s", examples["chat"].c_str());
fprintf(stderr, "%s", examples["inference"].c_str());
} else {
fprintf(stderr, "%s%s%s",
solo ? SOLO_USAGE_PREFIX : MULTIPLE_USAGES_PREFIX,
EXECUTABLE_NAME,
usageText[mode].c_str());
if (solo && (!examples[mode].empty())) {
fprintf(stderr, "Example:\n");
fprintf(stderr, "%s", examples[mode].c_str());
}
}
fflush(stderr);
}

void usage() {
usage(NULL);
}

int main(int argc, char *argv[]) {
initQuants();
initSockets();

AppArgs args = AppArgs::parse(argc, argv, true);
bool success = false;

if (args.mode != NULL) {
if (strcmp(args.mode, "inference") == 0) {
args.benchmark = true;
App::run(&args, generate);
success = true;
} else if (strcmp(args.mode, "generate") == 0) {
args.benchmark = false;
App::run(&args, generate);
success = true;
} else if (strcmp(args.mode, "chat") == 0) {
App::run(&args, chat);
success = true;
} else if (strcmp(args.mode, "worker") == 0) {
worker(&args);
success = true;
try {
AppArgs args = AppArgs::parse(argc, argv, true);
if (args.help) {
if ((args.mode == NULL) ||
(strcmp(args.mode, "--usage") == 0) ||
(strcmp(args.mode, "--help") == 0) ||
(strcmp(args.mode, "-h") == 0)) {
usage();
} else if (isValidMode(args.mode)) {
usage(args.mode);
} else {
usage();
return EXIT_FAILURE;
}
return EXIT_SUCCESS;
}

if (args.mode != NULL) {
if (strcmp(args.mode, "inference") == 0) {
if (args.prompt == NULL) {
throw BadArgumentException("Prompt is required");
}
args.benchmark = true;
App::run(&args, generate);
success = true;
} else if (strcmp(args.mode, "generate") == 0) {
if (args.prompt == NULL) {
throw BadArgumentException("Prompt is required");
}
args.benchmark = false;
App::run(&args, generate);
success = true;
} else if (strcmp(args.mode, "chat") == 0) {
App::run(&args, chat);
success = true;
} else if (strcmp(args.mode, "worker") == 0) {
worker(&args);
success = true;
}
}
} catch (const BadArgumentException& e) {
fprintf(stderr, "%s\n\n", e.what());
if ((argc > 1) && isValidMode(argv[1])) {
usage(argv[1]);
} else {
usage();
}
cleanupSockets();
return EXIT_FAILURE;
}

cleanupSockets();

if (success)
return EXIT_SUCCESS;
fprintf(stderr, "Invalid usage\n");
fprintf(stderr, "Invalid usage\n\n");
usage();
return EXIT_FAILURE;
}
Loading

0 comments on commit 0975af8

Please sign in to comment.