diff --git a/src/base64.ts b/src/base64.ts index 8c976c1..b952dd2 100644 --- a/src/base64.ts +++ b/src/base64.ts @@ -1,59 +1,13 @@ -const { slice } = require<{ - slice: (arr: T[], start: number, stop?: number) => T[]; +const { toBinary, getCharAt } = require<{ + toBinary: (int: number) => string; + getCharAt: (str: string, pos: number) => string; }>("./util.lua"); -function stringToBytes(str: string) { - const result = []; - - for (let i = 0; i < str.size(); i++) { - result.push(string.byte(str, i + 1)[0]); - } - - return result; -} - -// Adapted from https://github.com/un-ts/ab64/blob/main/src/ponyfill.ts#L24 -const _atob = (asc: string) => { - const b64CharList = "ABCDEFGHIJKLMNOPQRSTUVWXYZabcdefghijklmnopqrstuvwxyz0123456789+/="; - - const b64Chars = string.split(b64CharList, ""); - - const b64Table = b64Chars.reduce>((acc, char, index) => { - acc[char] = index; - return acc; - }, {}); - - const fromCharCode = string.char; - - asc = string.gsub(asc, "%s+", "")[0]; - asc += string.char(...slice(stringToBytes("=="), 2 - (asc.size() & 3))); - - let u24: number; - let binary = ""; - let r1: number; - let r2: number; - - for (let i = 0; i < asc.size(); i++) { - u24 = - (b64Table[string.byte(asc, i++)[0]] << 18) | - (b64Table[string.byte(asc, i++)[0]] << 12) | - ((r1 = b64Table[string.byte(asc, i++)[0]]) << 6) | - (r2 = b64Table[string.byte(asc, i++)[0]]); - binary += - r1 === 64 - ? fromCharCode((u24 >> 16) & 255) - : r2 === 64 - ? fromCharCode((u24 >> 16) & 255, (u24 >> 8) & 255) - : fromCharCode((u24 >> 16) & 255, (u24 >> 8) & 255, u24 & 255); - } - - return binary; -}; +const BASE64_CHAR = "ABCDEFGHIJKLMNOPQRSTUVWXYZabcdefghijklmnopqrstuvwxyz0123456789+/"; // Adapted from https://gist.github.com/jonleighton/958841 -export function atob(buf: number[]): string { +export function encode(buf: number[]): string { let base64 = ""; - const encodings = "ABCDEFGHIJKLMNOPQRSTUVWXYZabcdefghijklmnopqrstuvwxyz0123456789+/"; const byteLength = buf.size(); const byteRemainder = byteLength % 3; @@ -75,10 +29,10 @@ export function atob(buf: number[]): string { // Convert the raw binary segments to the appropriate ASCII encoding base64 += - string.char(string.byte(encodings, a + 1)[0]) + - string.char(string.byte(encodings, b + 1)[0]) + - string.char(string.byte(encodings, c + 1)[0]) + - string.char(string.byte(encodings, d + 1)[0]); + string.char(string.byte(BASE64_CHAR, a + 1)[0]) + + string.char(string.byte(BASE64_CHAR, b + 1)[0]) + + string.char(string.byte(BASE64_CHAR, c + 1)[0]) + + string.char(string.byte(BASE64_CHAR, d + 1)[0]); } // Deal with the remaining bytes and padding @@ -90,7 +44,7 @@ export function atob(buf: number[]): string { // Set the 4 least significant bits to zero b = (chunk & 3) << 4; - base64 += string.byte(encodings, a)[0] + string.byte(encodings, b)[0] + "=="; + base64 += string.byte(BASE64_CHAR, a)[0] + string.byte(BASE64_CHAR, b)[0] + "=="; } else if (byteRemainder === 2) { chunk = (buf[mainLength] << 8) | buf[mainLength + 1]; @@ -101,11 +55,50 @@ export function atob(buf: number[]): string { c = (chunk & 15) << 2; base64 += - string.char(string.byte(encodings, a + 1)[0]) + - string.char(string.byte(encodings, b + 1)[0]) + - string.char(string.byte(encodings, c + 1)[0]) + + string.char(string.byte(BASE64_CHAR, a + 1)[0]) + + string.char(string.byte(BASE64_CHAR, b + 1)[0]) + + string.char(string.byte(BASE64_CHAR, c + 1)[0]) + "="; } return base64; } + +// FIXME: Ideally, you'd want to use bit math and mask off bytes and stuff, +// but I'm lazy, so this logic uses string manipulation instead +export function decode(base64: string): number[] { + // Strip padding from base64 + base64 = base64.split("=")[0].gsub("%s", "")[0]; + + // Convert base64 chars to lookup table offsets + const chars = []; + for (let i = 1; i <= base64.size(); i++) { + const char = getCharAt(base64, i); + const [pos] = string.find(BASE64_CHAR, char); + + pos !== undefined ? chars.push(pos - 1) : error("invalid base64 data"); + } + + // Convert offsets to 6 bit binary numbers + const bin = chars.map(toBinary); + + // Combine all binary numbers into one + let combinedBin = ""; + bin.forEach((b) => (combinedBin += b)); + + // Split the combined binary number into smaller ones of 8 bits each + const intermediaryBin = []; + while (combinedBin.size() > 0) { + intermediaryBin.push(string.sub(combinedBin, 1, 8)); + combinedBin = string.sub(combinedBin, 9, combinedBin.size()); + } + + // Convert each individual 8 bit binary number to a base 10 integer + const decoded = []; + for (let i = 0; i < intermediaryBin.size() - 1; i++) { + const byte = tonumber(intermediaryBin[i], 2); + decoded.push(byte !== undefined ? byte : error("got invalid byte while decoding base64")); + } + + return decoded; +} diff --git a/src/index.ts b/src/index.ts index 158edc0..2a0457e 100644 --- a/src/index.ts +++ b/src/index.ts @@ -1,8 +1,11 @@ const { generatePrivateKey, generatePublicKey } = require<{ generatePrivateKey: () => number[]; - generatePublicKey: (privateKey: number[]) => number[]; + generatePublicKey: (privateKey: string | number[]) => number[]; }>("./wg.lua"); -const { atob } = require<{ atob: (buf: number[]) => string }>("./base64.lua"); +const base64 = require<{ + encode: (buf: number[]) => string; + decode: (base64: string) => number[]; +}>("./base64.lua"); export interface Keypair { publicKey: string; @@ -11,7 +14,7 @@ export interface Keypair { export interface Wireguard { generateKeypair(): Keypair; - generatePublicKey(privateKey: number[]): string; + generatePublicKey(privateKey: number[] | string): string; } export const wireguard: Wireguard = { @@ -21,14 +24,20 @@ export const wireguard: Wireguard = { ? pcall<[], number[]>(() => generatePublicKey(privateKey)) : error("failed to generate private key"); return { - publicKey: atob(publicKeyOk ? publicKey : error("failed to generate public key")), - privateKey: atob(privateKey as number[]), + publicKey: base64.encode(publicKeyOk ? publicKey : error("failed to generate public key")), + privateKey: base64.encode(privateKey as number[]), }; }, generatePublicKey: function (privateKey) { + if (typeIs(privateKey, "string")) { + privateKey = base64.decode(privateKey); + } + const [publicKeyOk, publicKey] = pcall<[], number[]>(() => generatePublicKey(privateKey)); - return atob(publicKeyOk ? publicKey : error("failed to generate public key")); + return base64.encode( + publicKeyOk ? publicKey : error("failed to generate public key %s".format(publicKey as string)), + ); }, }; diff --git a/src/init.luau b/src/init.luau index 6f109b7..c0c00da 100644 --- a/src/init.luau +++ b/src/init.luau @@ -5,7 +5,7 @@ export type Keypair = { export type Wireguard = { generateKeypair: (self: {}) -> Keypair, - generatePublicKey: (self: {}, privateKey: { number }) -> string, + generatePublicKey: (self: {}, privateKey: { number } | string) -> string, } return { diff --git a/src/util.ts b/src/util.ts index 88ef4e8..ebccc68 100644 --- a/src/util.ts +++ b/src/util.ts @@ -1,21 +1,33 @@ -export function slice(arr: T[], start: number, stop?: number): T[] { - const length = arr.size(); +const OCTAL_LOOKUP = ["000", "001", "010", "011", "100", "101", "110", "111"]; - if (start < 0) { - start = math.max(length + start, 0); - } +export function toBinary(int: number): string { + let bin = string.format("%o", int); + bin = bin.gsub( + ".", + (b: string) => + OCTAL_LOOKUP[ + (() => { + const [ok, val] = pcall<[], number>(() => { + const res = tonumber(b); - if (stop === undefined) { - stop = length; - } else if (stop < 0) { - stop = math.max(length + stop, 0); - } + if (typeIs(res, "nil")) { + error("failed to convert to binary"); + } - const result: T[] = []; + return res; + }); - for (let i = start; i < stop; i++) { - result.push(arr[i]); - } + return ok ? val : error(val); + })() + ], + )[0]; - return result; + // Pad to ensure the binary number is 6 bits + bin = "0".rep(6 - bin.size()) + bin; + + return bin; +} + +export function getCharAt(str: string, pos: number): string { + return string.char(str.byte(pos)[0]); } diff --git a/tests/generatePublicKey.luau b/tests/generatePublicKey.luau index 1043017..fbaffce 100644 --- a/tests/generatePublicKey.luau +++ b/tests/generatePublicKey.luau @@ -1,40 +1,7 @@ local wg = require("../out/").wireguard -local PRIVATE_KEY = { - [1] = 208, - [2] = 109, - [3] = 43, - [4] = 223, - [5] = 41, - [6] = 233, - [7] = 180, - [8] = 88, - [9] = 228, - [10] = 1, - [11] = 132, - [12] = 145, - [13] = 79, - [14] = 164, - [15] = 143, - [16] = 199, - [17] = 134, - [18] = 67, - [19] = 153, - [20] = 226, - [21] = 151, - [22] = 39, - [23] = 198, - [24] = 16, - [25] = 30, - [26] = 109, - [27] = 90, - [28] = 11, - [29] = 22, - [30] = 4, - [31] = 217, - [32] = 105, -} -local PUBLIC_KEY = "mYqWwJuiVXsXqfqXOKOKVTTZRovUXqzPkRtz1DwX1Wc=" +local PRIVATE_KEY = "iIWrphmeEnCLZFjdN17RQfEq8ND1MX+qAdIpRJdRhEA=" +local PUBLIC_KEY = "lYnVoKy9rzIapS0zPoLHskf4B+L3FouFXWwddKhRa3s=" local publicKey = wg:generatePublicKey(PRIVATE_KEY)