Skip to content

Commit

Permalink
Merge pull request #1659 from Shigoto-dev19/shigoto-base64
Browse files Browse the repository at this point in the history
Add Base64 Encoding/Decoding and Enhanced String Handling for Bytes
  • Loading branch information
mitschabaude authored May 22, 2024
2 parents e7fa35d + 6dafe1f commit f0860e8
Show file tree
Hide file tree
Showing 3 changed files with 319 additions and 0 deletions.
4 changes: 4 additions & 0 deletions CHANGELOG.md
Original file line number Diff line number Diff line change
Expand Up @@ -17,6 +17,10 @@ This project adheres to [Semantic Versioning](https://semver.org/spec/v2.0.0.htm

## [Unreleased](https://github.com/o1-labs/o1js/compare/6a1012162...HEAD)

### Added

- Added `base64Encode()` and `base64Decode(byteLength)` methods to the `Bytes` class.

### Fixes

- Fix type inference for `method.returns(Type)`, to require a matching return signature https://github.com/o1-labs/o1js/pull/1653
Expand Down
208 changes: 208 additions & 0 deletions src/lib/provable/bytes.ts
Original file line number Diff line number Diff line change
Expand Up @@ -5,6 +5,8 @@ import { chunkString } from '../util/arrays.js';
import { Provable } from './provable.js';
import { UInt8 } from './int.js';
import { randomBytes } from '../../bindings/crypto/random.js';
import { Field } from './field.js';
import { Bool } from './bool.js';

// external API
export { Bytes };
Expand Down Expand Up @@ -96,6 +98,102 @@ class Bytes {
.join('');
}

/**
* Base64 encode bytes.
*/
base64Encode(): Bytes {
const uint8Bytes = this.bytes;

// Convert each byte to its 8-bit binary representation and reverse endianness
let plainBits: Bool[] = uint8Bytes
.map((b) => b.value.toBits(8).reverse())
.flat();

// Calculate the bit padding required to make the total bits length a multiple of 6
const bitPadding =
plainBits.length % 6 !== 0 ? 6 - (plainBits.length % 6) : 0;

// Add the required bit padding with 0 bits
plainBits.push(...Array(bitPadding).fill(new Bool(false)));

let encodedBytes: UInt8[] = [];

// Process the bits 6 at a time and encode to Base64
for (let i = 0; i < plainBits.length; i += 6) {
// Slice the next 6 bits and reverse endianness
let byteBits = plainBits.slice(i, i + 6).reverse();

// Convert the 6-bit chunk to a UInt8 value for indexing the Base64 table
const indexTableByte = UInt8.Unsafe.fromField(Field.fromBits(byteBits));

// Use the index to get the corresponding Base64 character and add to the result
encodedBytes.push(base64EncodeLookup(indexTableByte));
}

// Add '=' padding to the encoded output if required
const paddingLength =
uint8Bytes.length % 3 !== 0 ? 3 - (uint8Bytes.length % 3) : 0;
encodedBytes.push(...Array(paddingLength).fill(UInt8.from(61)));

return Bytes.from(encodedBytes);
}

/**
* Decode Base64-encoded bytes.
*
* @param byteLength The length of the output decoded bytes.
* @returns Decoded bytes as {@link Bytes}.
*
* @warning
* Ensure the input Base64 string does not contain '=' characters in the middle,
* as it can cause unexpected decoding results.
*/
base64Decode(byteLength: number): Bytes {
const encodedB64Bytes = this.bytes;

const charLength = encodedB64Bytes.length;
assert(
charLength % 4 === 0,
'Input base64 byte length should be a multiple of 4!'
);

let decodedB64Bytes: UInt8[] = new Array(byteLength).fill(UInt8.from(0));

let bitsIn: Bool[][][] = Array.from({ length: charLength / 4 }, () => []);
let bitsOut: Bool[][][] = Array.from({ length: charLength / 4 }, () =>
Array.from({ length: 4 }, () => [])
);

let idx = 0;
for (let i = 0; i < charLength / 4; i++) {
for (let j = 0; j < 4; j++) {
const translated = base64DecodeLookup(encodedB64Bytes[4 * i + j]);
bitsIn[i][j] = translated.toBits(6);
}

// Convert from four 6-bit words to three 8-bit words, unpacking the base64 encoding
bitsOut[i][0] = [bitsIn[i][1][4], bitsIn[i][1][5], ...bitsIn[i][0]];

for (let j = 0; j < 4; j++) {
bitsOut[i][1][j] = bitsIn[i][2][j + 2];
bitsOut[i][1][j + 4] = bitsIn[i][1][j];
}

bitsOut[i][2] = [...bitsIn[i][3], bitsIn[i][2][0], bitsIn[i][2][1]];

for (let j = 0; j < 3; j++) {
if (idx + j < byteLength) {
decodedB64Bytes[idx + j] = UInt8.Unsafe.fromField(
Field.fromBits(bitsOut[i][j])
);
}
}
idx += 3;
}

return Bytes.from(decodedB64Bytes);
}

// dynamic subclassing infra
static _size?: number;
static _provable?: ProvablePureExtended<
Expand Down Expand Up @@ -132,3 +230,113 @@ function createBytes(size: number): typeof Bytes {
});
};
}

/**
* Decodes a Base64 character to its original value.
* Adapted from the algorithm described in: http://0x80.pl/notesen/2016-01-17-sse-base64-decoding.html#vector-lookup-base
*
* @param input - The Base64 encoded byte to be decoded.
* @returns - The corresponding decoded value as a Field.
*/
function base64DecodeLookup(input: UInt8): Field {
// Initialize a Field to validate if the input byte is a valid Base64 character
let isValidBase64Chars = new Field(0);

// ['A' - 'Z'] range
const le_Z = input.lessThan(91);
const ge_A = input.greaterThan(64);
const range_AZ = le_Z.and(ge_A);
const sum_AZ = range_AZ.toField().mul(input.value.sub(65));
isValidBase64Chars = isValidBase64Chars.add(range_AZ.toField());

// ['a' - 'z'] range
const le_z = input.lessThan(123);
const ge_a = input.greaterThan(96);
const range_az = le_z.and(ge_a);
const sum_az = range_az.toField().mul(input.value.sub(71)).add(sum_AZ);
isValidBase64Chars = isValidBase64Chars.add(range_az.toField());

// ['0' - '9'] range
const le_9 = input.lessThan(58);
const ge_0 = input.greaterThan(47);
const range_09 = le_9.and(ge_0);
const sum_09 = range_09.toField().mul(input.value.add(4)).add(sum_az);
isValidBase64Chars = isValidBase64Chars.add(range_09.toField());

// '+' character
const equal_plus = input.value.equals(43);
const sum_plus = equal_plus.toField().mul(input.value.add(19)).add(sum_09);
isValidBase64Chars = isValidBase64Chars.add(equal_plus.toField());

// '/' character
const equal_slash = input.value.equals(47);
const sum_slash = equal_slash
.toField()
.mul(input.value.add(16))
.add(sum_plus);
isValidBase64Chars = isValidBase64Chars.add(equal_slash.toField());

// '=' character
const equal_eqsign = input.value.equals(61);
isValidBase64Chars = isValidBase64Chars.add(equal_eqsign.toField());

// Validate if input contains only valid Base64 characters
isValidBase64Chars.assertEquals(
1,
'Please provide Base64-encoded bytes containing only alphanumeric characters and +/='
);

return sum_slash;
}

/**
* Encodes a byte into its Base64 character representation.
*
* @param input - The byte to be encoded to Base64.
* @returns - The corresponding Base64 encoded character as a UInt8.
*/
function base64EncodeLookup(input: UInt8): UInt8 {
// Initialize a Field to validate if the input byte is included in the Base64 index table
let isValidBase64Chars = new Field(0);

// ['A', 'Z'] - Note: Remove greater than zero check because a UInt8 byte is always positive
const le_Z = input.lessThanOrEqual(25);
const range_AZ = le_Z;
const sum_AZ = range_AZ.toField().mul(input.value.add(65));
isValidBase64Chars = isValidBase64Chars.add(range_AZ.toField());

// ['a', 'z']
const le_z = input.lessThanOrEqual(51);
const ge_a = input.greaterThanOrEqual(26);
const range_az = le_z.and(ge_a);
const sum_az = range_az.toField().mul(input.value.add(71)).add(sum_AZ);
isValidBase64Chars = isValidBase64Chars.add(range_az.toField());

// ['0', '9']
const le_9 = input.lessThanOrEqual(61);
const ge_0 = input.greaterThanOrEqual(52);
const range_09 = le_9.and(ge_0);
const sum_09 = range_09.toField().mul(input.value.sub(4)).add(sum_az);
isValidBase64Chars = isValidBase64Chars.add(range_09.toField());

// '+'
const equal_plus = input.value.equals(62);
const sum_plus = equal_plus.toField().mul(input.value.sub(19)).add(sum_09);
isValidBase64Chars = isValidBase64Chars.add(equal_plus.toField());

// '/'
const equal_slash = input.value.equals(63);
const sum_slash = equal_slash
.toField()
.mul(input.value.sub(16))
.add(sum_plus);
isValidBase64Chars = isValidBase64Chars.add(equal_slash.toField());

// Validate if input contains only valid base64 characters
isValidBase64Chars.assertEquals(
1,
'Invalid character detected: The input contains a byte that is not present in the BASE64 index table!'
);

return UInt8.Unsafe.fromField(sum_slash);
}
107 changes: 107 additions & 0 deletions src/lib/provable/test/base64.unit-test.ts
Original file line number Diff line number Diff line change
@@ -0,0 +1,107 @@
import { Bytes } from '../wrapped-classes.js';
import { describe, test } from 'node:test';
import { expect } from 'expect';

function calculateB64DecodedBytesLength(base64String: string): number {
// Calculate the length of the base64-encoded string
const base64Length = base64String.length;

// Count the number of padding characters '=' in the base64 string
const padding = (base64String.match(/=/g) || []).length;

// Calculate the length of the decoded bytes
const byteLength = (base64Length * 3) / 4 - padding;

return byteLength;
}

function generateRandomString(
maxLength: number,
encoding?: BufferEncoding
): string {
// Generate a random length between 1 and maxLength
const randomLength = Math.floor(Math.random() * maxLength) + 1;

// Generate random bytes
const randomBytes = Bytes(randomLength).random().toBytes();

// Convert to string given the chosen encoding
const randomString = Buffer.from(randomBytes).toString(encoding);

return randomString;
}

describe('Base64 Decode Tests', () => {
function testBase64Decode(base64String: string) {
// Calculate the expected length of the decoded bytes
const decodedByteLength = calculateB64DecodedBytesLength(base64String);

// Decode the base64 string
const decodedBytes = Bytes.fromString(base64String)
.base64Decode(decodedByteLength)
.toBytes();

// Calculate the expected decoded bytes using JS implementation
const decodedString = atob(base64String);
let expectedDecodedBytes = new Uint8Array(decodedString.length);

// Populate the expected decoded bytes array with character codes
for (let i = 0; i < decodedString.length; i++) {
expectedDecodedBytes[i] = decodedString.charCodeAt(i);
}

expect(decodedBytes).toEqual(expectedDecodedBytes);
}

test('should decode a base64-encoded input', async () => {
const input = '7xQMDuoVVU4m0W0WRVSrVXMeGSIASsnucK9dJsrc+vU=';
testBase64Decode(input);
});

test('should decode a base64-encoded input (1000 iterations)', async () => {
for (let i = 0; i < 1000; i++) {
const randomBase64String = generateRandomString(100, 'base64');
testBase64Decode(randomBase64String);
}
});

test('should reject a base64-encoded input of length not a multiple of 4', async () => {
const input = 'ad/';
const errorMessage = 'Input base64 byte length should be a multiple of 4!';
expect(() => testBase64Decode(input)).toThrowError(errorMessage);
});

test('should reject input containing non-base64 characters', async () => {
const input = 'ad$=';
const errorMessage =
'Please provide Base64-encoded bytes containing only alphanumeric characters and +/=';
expect(() => testBase64Decode(input)).toThrowError(errorMessage);
});
});

describe('Base64 Encode Tests', () => {
function testBase64Encode(input: string) {
const inputBytes = Bytes.fromString(input);

// Base64 Encode the input bytes
const encodedBytes = inputBytes.base64Encode();

// Calculate the expected encoded bytes using JS implementation
const expectedEncodedBytes = Bytes.from(Buffer.from(btoa(input)));

expect(encodedBytes).toEqual(expectedEncodedBytes);
}

test('should Base64 encode an input', async () => {
const input =
'ef140c0eea15554e26d16d164554ab55731e1922004ac9ee70af5d26cadcfaf5';
testBase64Encode(input);
});

test('should Base64 encode different inputs (1000 iterations)', async () => {
for (let i = 0; i < 1000; i++) {
const input = generateRandomString(100, 'base64');
testBase64Encode(input);
}
});
});

0 comments on commit f0860e8

Please sign in to comment.