Skip to content

Commit

Permalink
feat: generatePublicKey from base64 string
Browse files Browse the repository at this point in the history
`Wireguard:generatePublicKey` now accepts the privateKey as a base64 encoded string, which gets decoded internally to produce the raw bytes.
  • Loading branch information
CompeyDev committed Mar 31, 2024
1 parent a50adb2 commit 666a5f7
Show file tree
Hide file tree
Showing 5 changed files with 97 additions and 116 deletions.
111 changes: 52 additions & 59 deletions src/base64.ts
Original file line number Diff line number Diff line change
@@ -1,59 +1,13 @@
const { slice } = require<{
slice: <T extends defined>(arr: T[], start: number, stop?: number) => T[];
const { toBinary, getCharAt } = require<{
toBinary: (int: number) => string;
getCharAt: (str: string, pos: number) => string;
}>("./util.lua");

function stringToBytes(str: string) {
const result = [];

for (let i = 0; i < str.size(); i++) {
result.push(string.byte(str, i + 1)[0]);
}

return result;
}

// Adapted from https://github.com/un-ts/ab64/blob/main/src/ponyfill.ts#L24
const _atob = (asc: string) => {
const b64CharList = "ABCDEFGHIJKLMNOPQRSTUVWXYZabcdefghijklmnopqrstuvwxyz0123456789+/=";

const b64Chars = string.split(b64CharList, "");

const b64Table = b64Chars.reduce<Record<string, number>>((acc, char, index) => {
acc[char] = index;
return acc;
}, {});

const fromCharCode = string.char;

asc = string.gsub(asc, "%s+", "")[0];
asc += string.char(...slice(stringToBytes("=="), 2 - (asc.size() & 3)));

let u24: number;
let binary = "";
let r1: number;
let r2: number;

for (let i = 0; i < asc.size(); i++) {
u24 =
(b64Table[string.byte(asc, i++)[0]] << 18) |
(b64Table[string.byte(asc, i++)[0]] << 12) |
((r1 = b64Table[string.byte(asc, i++)[0]]) << 6) |
(r2 = b64Table[string.byte(asc, i++)[0]]);
binary +=
r1 === 64
? fromCharCode((u24 >> 16) & 255)
: r2 === 64
? fromCharCode((u24 >> 16) & 255, (u24 >> 8) & 255)
: fromCharCode((u24 >> 16) & 255, (u24 >> 8) & 255, u24 & 255);
}

return binary;
};
const BASE64_CHAR = "ABCDEFGHIJKLMNOPQRSTUVWXYZabcdefghijklmnopqrstuvwxyz0123456789+/";

// Adapted from https://gist.github.com/jonleighton/958841
export function atob(buf: number[]): string {
export function encode(buf: number[]): string {
let base64 = "";
const encodings = "ABCDEFGHIJKLMNOPQRSTUVWXYZabcdefghijklmnopqrstuvwxyz0123456789+/";

const byteLength = buf.size();
const byteRemainder = byteLength % 3;
Expand All @@ -75,10 +29,10 @@ export function atob(buf: number[]): string {

// Convert the raw binary segments to the appropriate ASCII encoding
base64 +=
string.char(string.byte(encodings, a + 1)[0]) +
string.char(string.byte(encodings, b + 1)[0]) +
string.char(string.byte(encodings, c + 1)[0]) +
string.char(string.byte(encodings, d + 1)[0]);
string.char(string.byte(BASE64_CHAR, a + 1)[0]) +
string.char(string.byte(BASE64_CHAR, b + 1)[0]) +
string.char(string.byte(BASE64_CHAR, c + 1)[0]) +
string.char(string.byte(BASE64_CHAR, d + 1)[0]);
}

// Deal with the remaining bytes and padding
Expand All @@ -90,7 +44,7 @@ export function atob(buf: number[]): string {
// Set the 4 least significant bits to zero
b = (chunk & 3) << 4;

base64 += string.byte(encodings, a)[0] + string.byte(encodings, b)[0] + "==";
base64 += string.byte(BASE64_CHAR, a)[0] + string.byte(BASE64_CHAR, b)[0] + "==";
} else if (byteRemainder === 2) {
chunk = (buf[mainLength] << 8) | buf[mainLength + 1];

Expand All @@ -101,11 +55,50 @@ export function atob(buf: number[]): string {
c = (chunk & 15) << 2;

base64 +=
string.char(string.byte(encodings, a + 1)[0]) +
string.char(string.byte(encodings, b + 1)[0]) +
string.char(string.byte(encodings, c + 1)[0]) +
string.char(string.byte(BASE64_CHAR, a + 1)[0]) +
string.char(string.byte(BASE64_CHAR, b + 1)[0]) +
string.char(string.byte(BASE64_CHAR, c + 1)[0]) +
"=";
}

return base64;
}

// FIXME: Ideally, you'd want to use bit math and mask off bytes and stuff,
// but I'm lazy, so this logic uses string manipulation instead
export function decode(base64: string): number[] {
// Strip padding from base64
base64 = base64.split("=")[0].gsub("%s", "")[0];

// Convert base64 chars to lookup table offsets
const chars = [];
for (let i = 1; i <= base64.size(); i++) {
const char = getCharAt(base64, i);
const [pos] = string.find(BASE64_CHAR, char);

pos !== undefined ? chars.push(pos - 1) : error("invalid base64 data");
}

// Convert offsets to 6 bit binary numbers
const bin = chars.map(toBinary);

// Combine all binary numbers into one
let combinedBin = "";
bin.forEach((b) => (combinedBin += b));

// Split the combined binary number into smaller ones of 8 bits each
const intermediaryBin = [];
while (combinedBin.size() > 0) {
intermediaryBin.push(string.sub(combinedBin, 1, 8));
combinedBin = string.sub(combinedBin, 9, combinedBin.size());
}

// Convert each individual 8 bit binary number to a base 10 integer
const decoded = [];
for (let i = 0; i < intermediaryBin.size() - 1; i++) {
const byte = tonumber(intermediaryBin[i], 2);
decoded.push(byte !== undefined ? byte : error("got invalid byte while decoding base64"));
}

return decoded;
}
21 changes: 15 additions & 6 deletions src/index.ts
Original file line number Diff line number Diff line change
@@ -1,8 +1,11 @@
const { generatePrivateKey, generatePublicKey } = require<{
generatePrivateKey: () => number[];
generatePublicKey: (privateKey: number[]) => number[];
generatePublicKey: (privateKey: string | number[]) => number[];
}>("./wg.lua");
const { atob } = require<{ atob: (buf: number[]) => string }>("./base64.lua");
const base64 = require<{
encode: (buf: number[]) => string;
decode: (base64: string) => number[];
}>("./base64.lua");

export interface Keypair {
publicKey: string;
Expand All @@ -11,7 +14,7 @@ export interface Keypair {

export interface Wireguard {
generateKeypair(): Keypair;
generatePublicKey(privateKey: number[]): string;
generatePublicKey(privateKey: number[] | string): string;
}

export const wireguard: Wireguard = {
Expand All @@ -21,14 +24,20 @@ export const wireguard: Wireguard = {
? pcall<[], number[]>(() => generatePublicKey(privateKey))
: error("failed to generate private key");
return {
publicKey: atob(publicKeyOk ? publicKey : error("failed to generate public key")),
privateKey: atob(privateKey as number[]),
publicKey: base64.encode(publicKeyOk ? publicKey : error("failed to generate public key")),
privateKey: base64.encode(privateKey as number[]),
};
},

generatePublicKey: function (privateKey) {
if (typeIs(privateKey, "string")) {
privateKey = base64.decode(privateKey);
}

const [publicKeyOk, publicKey] = pcall<[], number[]>(() => generatePublicKey(privateKey));

return atob(publicKeyOk ? publicKey : error("failed to generate public key"));
return base64.encode(
publicKeyOk ? publicKey : error("failed to generate public key %s".format(publicKey as string)),
);
},
};
2 changes: 1 addition & 1 deletion src/init.luau
Original file line number Diff line number Diff line change
Expand Up @@ -5,7 +5,7 @@ export type Keypair = {

export type Wireguard = {
generateKeypair: (self: {}) -> Keypair,
generatePublicKey: (self: {}, privateKey: { number }) -> string,
generatePublicKey: (self: {}, privateKey: { number } | string) -> string,
}

return {
Expand Down
42 changes: 27 additions & 15 deletions src/util.ts
Original file line number Diff line number Diff line change
@@ -1,21 +1,33 @@
export function slice<T extends defined>(arr: T[], start: number, stop?: number): T[] {
const length = arr.size();
const OCTAL_LOOKUP = ["000", "001", "010", "011", "100", "101", "110", "111"];

if (start < 0) {
start = math.max(length + start, 0);
}
export function toBinary(int: number): string {
let bin = string.format("%o", int);
bin = bin.gsub(
".",
(b: string) =>
OCTAL_LOOKUP[
(() => {
const [ok, val] = pcall<[], number>(() => {
const res = tonumber(b);

if (stop === undefined) {
stop = length;
} else if (stop < 0) {
stop = math.max(length + stop, 0);
}
if (typeIs(res, "nil")) {
error("failed to convert to binary");
}

const result: T[] = [];
return res;
});

for (let i = start; i < stop; i++) {
result.push(arr[i]);
}
return ok ? val : error(val);
})()
],
)[0];

return result;
// Pad to ensure the binary number is 6 bits
bin = "0".rep(6 - bin.size()) + bin;

return bin;
}

export function getCharAt(str: string, pos: number): string {
return string.char(str.byte(pos)[0]);
}
37 changes: 2 additions & 35 deletions tests/generatePublicKey.luau
Original file line number Diff line number Diff line change
@@ -1,40 +1,7 @@
local wg = require("../out/").wireguard

local PRIVATE_KEY = {
[1] = 208,
[2] = 109,
[3] = 43,
[4] = 223,
[5] = 41,
[6] = 233,
[7] = 180,
[8] = 88,
[9] = 228,
[10] = 1,
[11] = 132,
[12] = 145,
[13] = 79,
[14] = 164,
[15] = 143,
[16] = 199,
[17] = 134,
[18] = 67,
[19] = 153,
[20] = 226,
[21] = 151,
[22] = 39,
[23] = 198,
[24] = 16,
[25] = 30,
[26] = 109,
[27] = 90,
[28] = 11,
[29] = 22,
[30] = 4,
[31] = 217,
[32] = 105,
}
local PUBLIC_KEY = "mYqWwJuiVXsXqfqXOKOKVTTZRovUXqzPkRtz1DwX1Wc="
local PRIVATE_KEY = "iIWrphmeEnCLZFjdN17RQfEq8ND1MX+qAdIpRJdRhEA="
local PUBLIC_KEY = "lYnVoKy9rzIapS0zPoLHskf4B+L3FouFXWwddKhRa3s="

local publicKey = wg:generatePublicKey(PRIVATE_KEY)

Expand Down

0 comments on commit 666a5f7

Please sign in to comment.