diff --git a/.github/workflows/ubuntu.yml b/.github/workflows/ubuntu.yml index f60310d544536d..050d66e5d54e67 100644 --- a/.github/workflows/ubuntu.yml +++ b/.github/workflows/ubuntu.yml @@ -24,6 +24,7 @@ jobs: make: strategy: matrix: + os: [ubuntu-22.04, ubuntu-20.04] test_task: [check] arch: [''] configure: ['cppflags=-DVM_CHECK_MODE'] @@ -31,19 +32,24 @@ jobs: include: - test_task: check arch: i686 + os: ubuntu-22.04 - test_task: check configure: '--disable-yjit' + os: ubuntu-22.04 - test_task: check configure: '--enable-shared --enable-load-relative' + os: ubuntu-22.04 - test_task: test-bundler-parallel + os: ubuntu-22.04 - test_task: test-bundled-gems + os: ubuntu-22.04 fail-fast: false env: GITPULLOPTIONS: --no-tags origin ${{ github.ref }} RUBY_DEBUG: ci - runs-on: ubuntu-20.04 + runs-on: ${{ matrix.os }} if: >- ${{!(false diff --git a/lib/prism/ffi.rb b/lib/prism/ffi.rb index 2aecd4df863f48..0a064a5c941850 100644 --- a/lib/prism/ffi.rb +++ b/lib/prism/ffi.rb @@ -23,15 +23,21 @@ module LibRubyParser # :nodoc: # size_t -> :size_t # void -> :void # - def self.resolve_type(type) + def self.resolve_type(type, callbacks) type = type.strip - type.end_with?("*") ? :pointer : type.delete_prefix("const ").to_sym + + if !type.end_with?("*") + type.delete_prefix("const ").to_sym + else + type = type.delete_suffix("*").rstrip + callbacks.include?(type.to_sym) ? type.to_sym : :pointer + end end # Read through the given header file and find the declaration of each of the # given functions. For each one, define a function with the same name and # signature as the C function. - def self.load_exported_functions_from(header, *functions) + def self.load_exported_functions_from(header, *functions, callbacks) File.foreach(File.expand_path("../../include/#{header}", __dir__)) do |line| # We only want to attempt to load exported functions. 
next unless line.start_with?("PRISM_EXPORTED_FUNCTION ") @@ -55,24 +61,28 @@ def self.load_exported_functions_from(header, *functions) # Resolve the type of the argument by dropping the name of the argument # first if it is present. - arg_types.map! { |type| resolve_type(type.sub(/\w+$/, "")) } + arg_types.map! { |type| resolve_type(type.sub(/\w+$/, ""), callbacks) } # Attach the function using the FFI library. - attach_function name, arg_types, resolve_type(return_type) + attach_function name, arg_types, resolve_type(return_type, []) end # If we didn't find all of the functions, raise an error. raise "Could not find functions #{functions.inspect}" unless functions.empty? end + callback :pm_parse_stream_fgets_t, [:pointer, :int, :pointer], :pointer + load_exported_functions_from( "prism.h", "pm_version", "pm_serialize_parse", + "pm_serialize_parse_stream", "pm_serialize_parse_comments", "pm_serialize_lex", "pm_serialize_parse_lex", - "pm_parse_success_p" + "pm_parse_success_p", + [:pm_parse_stream_fgets_t] ) load_exported_functions_from( @@ -81,7 +91,8 @@ def self.load_exported_functions_from(header, *functions) "pm_buffer_init", "pm_buffer_value", "pm_buffer_length", - "pm_buffer_free" + "pm_buffer_free", + [] ) load_exported_functions_from( @@ -90,7 +101,8 @@ def self.load_exported_functions_from(header, *functions) "pm_string_free", "pm_string_source", "pm_string_length", - "pm_string_sizeof" + "pm_string_sizeof", + [] ) # This object represents a pm_buffer_t. We only use it as an opaque pointer, @@ -215,13 +227,36 @@ def parse(code, **options) end # Mirror the Prism.parse_file API by using the serialization API. This uses - # native strings instead of Ruby strings because it allows us to use mmap when - # it is available. + # native strings instead of Ruby strings because it allows us to use mmap + # when it is available. 
def parse_file(filepath, **options) options[:filepath] = filepath LibRubyParser::PrismString.with_file(filepath) { |string| parse_common(string, string.read, options) } end + # Mirror the Prism.parse_stream API by using the serialization API. + def parse_stream(stream, **options) + LibRubyParser::PrismBuffer.with do |buffer| + source = +"" + callback = -> (string, size, _) { + raise "Expected size to be >= 0, got: #{size}" if size <= 0 + + if !(line = stream.gets(size - 1)).nil? + source << line + string.write_string("#{line}\x00", line.bytesize + 1) + end + } + + # In the pm_serialize_parse_stream function it accepts a pointer to the + # IO object as a void* and then passes it through to the callback as the + # third argument, but it never touches it itself. As such, since we have + # access to the IO object already through the closure of the lambda, we + # can pass a null pointer here and not worry. + LibRubyParser.pm_serialize_parse_stream(buffer.pointer, nil, callback, dump_options(options)) + Prism.load(source, buffer.read) + end + end + # Mirror the Prism.parse_comments API by using the serialization API. def parse_comments(code, **options) LibRubyParser::PrismString.with_string(code) { |string| parse_comments_common(string, code, options) } diff --git a/prism/extension.c b/prism/extension.c index 09ce6a1c0c29f6..292e67891f386b 100644 --- a/prism/extension.c +++ b/prism/extension.c @@ -504,6 +504,24 @@ parser_warnings(pm_parser_t *parser, rb_encoding *encoding, VALUE source) { return warnings; } +/** + * Create a new parse result from the given parser, value, encoding, and source. 
+ */ +static VALUE +parse_result_create(pm_parser_t *parser, VALUE value, rb_encoding *encoding, VALUE source) { + VALUE result_argv[] = { + value, + parser_comments(parser, source), + parser_magic_comments(parser, source), + parser_data_loc(parser, source), + parser_errors(parser, encoding, source), + parser_warnings(parser, encoding, source), + source + }; + + return rb_class_new_instance(7, result_argv, rb_cPrismParseResult); +} + /******************************************************************************/ /* Lexing Ruby code */ /******************************************************************************/ @@ -610,19 +628,11 @@ parse_lex_input(pm_string_t *input, const pm_options_t *options, bool return_nod value = parse_lex_data.tokens; } - VALUE result_argv[] = { - value, - parser_comments(&parser, source), - parser_magic_comments(&parser, source), - parser_data_loc(&parser, source), - parser_errors(&parser, parse_lex_data.encoding, source), - parser_warnings(&parser, parse_lex_data.encoding, source), - source - }; - + VALUE result = parse_result_create(&parser, value, parse_lex_data.encoding, source); pm_node_destroy(&parser, node); pm_parser_free(&parser); - return rb_class_new_instance(7, result_argv, rb_cPrismParseResult); + + return result; } /** @@ -682,17 +692,8 @@ parse_input(pm_string_t *input, const pm_options_t *options) { rb_encoding *encoding = rb_enc_find(parser.encoding->name); VALUE source = pm_source_new(&parser, encoding); - VALUE result_argv[] = { - pm_ast_new(&parser, node, encoding, source), - parser_comments(&parser, source), - parser_magic_comments(&parser, source), - parser_data_loc(&parser, source), - parser_errors(&parser, encoding, source), - parser_warnings(&parser, encoding, source), - source - }; - - VALUE result = rb_class_new_instance(7, result_argv, rb_cPrismParseResult); + VALUE value = pm_ast_new(&parser, node, encoding, source); + VALUE result = parse_result_create(&parser, value, encoding, source) ; 
pm_node_destroy(&parser, node); pm_parser_free(&parser); @@ -751,6 +752,60 @@ parse(int argc, VALUE *argv, VALUE self) { return value; } +/** + * An implementation of fgets that is suitable for use with Ruby IO objects. + */ +static char * +parse_stream_fgets(char *string, int size, void *stream) { + RUBY_ASSERT(size > 0); + + VALUE line = rb_funcall((VALUE) stream, rb_intern("gets"), 1, INT2FIX(size - 1)); + if (NIL_P(line)) { + return NULL; + } + + const char *cstr = StringValueCStr(line); + size_t length = strlen(cstr); + + memcpy(string, cstr, length); + string[length] = '\0'; + + return string; +} + +/** + * call-seq: + * Prism::parse_stream(stream, **options) -> ParseResult + * + * Parse the given object that responds to `gets` and return a ParseResult + * instance. The options that are supported are the same as Prism::parse. + */ +static VALUE +parse_stream(int argc, VALUE *argv, VALUE self) { + VALUE stream; + VALUE keywords; + rb_scan_args(argc, argv, "1:", &stream, &keywords); + + pm_options_t options = { 0 }; + extract_options(&options, Qnil, keywords); + + pm_parser_t parser; + pm_buffer_t buffer; + + pm_node_t *node = pm_parse_stream(&parser, &buffer, (void *) stream, parse_stream_fgets, &options); + rb_encoding *encoding = rb_enc_find(parser.encoding->name); + + VALUE source = pm_source_new(&parser, encoding); + VALUE value = pm_ast_new(&parser, node, encoding, source); + VALUE result = parse_result_create(&parser, value, encoding, source); + + pm_node_destroy(&parser, node); + pm_buffer_free(&buffer); + pm_parser_free(&parser); + + return result; +} + /** * call-seq: * Prism::parse_file(filepath, **options) -> ParseResult @@ -992,26 +1047,16 @@ integer_parse(VALUE self, VALUE source) { pm_integer_t integer = { 0 }; pm_integer_parse(&integer, PM_INTEGER_BASE_UNKNOWN, start, start + length); - VALUE number = UINT2NUM(integer.head.value); - size_t shift = 0; - - for (pm_integer_word_t *node = integer.head.next; node != NULL; node = node->next) { - VALUE 
receiver = rb_funcall(UINT2NUM(node->value), rb_intern("<<"), 1, ULONG2NUM(++shift * 32)); - number = rb_funcall(receiver, rb_intern("|"), 1, number); - } - - if (integer.negative) number = rb_funcall(number, rb_intern("-@"), 0); - pm_buffer_t buffer = { 0 }; pm_integer_string(&buffer, &integer); VALUE string = rb_str_new(pm_buffer_value(&buffer), pm_buffer_length(&buffer)); pm_buffer_free(&buffer); - pm_integer_free(&integer); VALUE result = rb_ary_new_capa(2); - rb_ary_push(result, number); + rb_ary_push(result, pm_integer_new(&integer)); rb_ary_push(result, string); + pm_integer_free(&integer); return result; } @@ -1271,6 +1316,7 @@ Init_prism(void) { rb_define_singleton_method(rb_cPrism, "lex", lex, -1); rb_define_singleton_method(rb_cPrism, "lex_file", lex_file, -1); rb_define_singleton_method(rb_cPrism, "parse", parse, -1); + rb_define_singleton_method(rb_cPrism, "parse_stream", parse_stream, -1); rb_define_singleton_method(rb_cPrism, "parse_file", parse_file, -1); rb_define_singleton_method(rb_cPrism, "parse_comments", parse_comments, -1); rb_define_singleton_method(rb_cPrism, "parse_file_comments", parse_file_comments, -1); diff --git a/prism/extension.h b/prism/extension.h index 6e5a3450122a93..13a9aabde3e5c3 100644 --- a/prism/extension.h +++ b/prism/extension.h @@ -10,6 +10,7 @@ VALUE pm_source_new(const pm_parser_t *parser, rb_encoding *encoding); VALUE pm_token_new(const pm_parser_t *parser, const pm_token_t *token, rb_encoding *encoding, VALUE source); VALUE pm_ast_new(const pm_parser_t *parser, const pm_node_t *node, rb_encoding *encoding, VALUE source); +VALUE pm_integer_new(const pm_integer_t *integer); void Init_prism_api_node(void); void Init_prism_pack(void); diff --git a/prism/parser.h b/prism/parser.h index 80521e4ad943af..02f60192d559c6 100644 --- a/prism/parser.h +++ b/prism/parser.h @@ -234,6 +234,9 @@ typedef struct pm_lex_mode { * a tilde heredoc. 
*/ size_t common_whitespace; + + /** True if the previous token ended with a line continuation. */ + bool line_continuation; } heredoc; } as; diff --git a/prism/prism.c b/prism/prism.c index 6717488882edec..6921feac48fffe 100644 --- a/prism/prism.c +++ b/prism/prism.c @@ -8831,6 +8831,8 @@ pm_token_buffer_escape(pm_parser_t *parser, pm_token_buffer_t *token_buffer) { const uint8_t *end = parser->current.end - 1; pm_buffer_append_bytes(&token_buffer->buffer, start, (size_t) (end - start)); + + token_buffer->cursor = end; } /** @@ -9450,7 +9452,8 @@ parser_lex(pm_parser_t *parser) { .next_start = parser->current.end, .quote = quote, .indent = indent, - .common_whitespace = (size_t) -1 + .common_whitespace = (size_t) -1, + .line_continuation = false } }); @@ -10719,6 +10722,9 @@ parser_lex(pm_parser_t *parser) { // current lex mode. pm_lex_mode_t *lex_mode = parser->lex_modes.current; + bool line_continuation = lex_mode->as.heredoc.line_continuation; + lex_mode->as.heredoc.line_continuation = false; + // We'll check if we're at the end of the file. If we are, then we // will add an error (because we weren't able to find the // terminator) but still continue parsing so that content after the @@ -10736,7 +10742,7 @@ parser_lex(pm_parser_t *parser) { // If we are immediately following a newline and we have hit the // terminator, then we need to return the ending of the heredoc. 
- if (current_token_starts_line(parser)) { + if (!line_continuation && current_token_starts_line(parser)) { const uint8_t *start = parser->current.start; if (start + ident_length <= parser->end) { const uint8_t *newline = next_newline(start, parser->end - start); @@ -10808,7 +10814,7 @@ parser_lex(pm_parser_t *parser) { const uint8_t *breakpoint = pm_strpbrk(parser, parser->current.end, breakpoints, parser->end - parser->current.end, true); pm_token_buffer_t token_buffer = { { 0 }, 0 }; - bool was_escaped_newline = false; + bool was_line_continuation = false; while (breakpoint != NULL) { switch (*breakpoint) { @@ -10831,7 +10837,7 @@ parser_lex(pm_parser_t *parser) { // some leading whitespace. const uint8_t *start = breakpoint + 1; - if (!was_escaped_newline && (start + ident_length <= parser->end)) { + if (!was_line_continuation && (start + ident_length <= parser->end)) { // We want to match the terminator starting from the end of the line in case // there is whitespace in the ident such as <<-' DOC' or <<~' DOC'. const uint8_t *newline = next_newline(start, parser->end - start); @@ -10873,7 +10879,6 @@ parser_lex(pm_parser_t *parser) { // heredoc here as string content. Then, the next time a // token is lexed, it will match again and return the // end of the heredoc. - if (lex_mode->as.heredoc.indent == PM_HEREDOC_INDENT_TILDE) { if ((lex_mode->as.heredoc.common_whitespace > whitespace) && peek_at(parser, start) != '\n') { lex_mode->as.heredoc.common_whitespace = whitespace; @@ -10881,7 +10886,7 @@ parser_lex(pm_parser_t *parser) { parser->current.end = breakpoint + 1; - if (!was_escaped_newline) { + if (!was_line_continuation) { pm_token_buffer_flush(parser, &token_buffer); LEX(PM_TOKEN_STRING_CONTENT); } @@ -10943,7 +10948,26 @@ parser_lex(pm_parser_t *parser) { } /* fallthrough */ case '\n': - was_escaped_newline = true; + // If we are in a tilde here, we should + // break out of the loop and return the + // string content. 
+ if (lex_mode->as.heredoc.indent == PM_HEREDOC_INDENT_TILDE) { + const uint8_t *end = parser->current.end; + pm_newline_list_append(&parser->newline_list, end); + + // Here we want the buffer to only + // include up to the backslash. + parser->current.end = breakpoint; + pm_token_buffer_flush(parser, &token_buffer); + + // Now we can advance the end of the + // token past the newline. + parser->current.end = end + 1; + lex_mode->as.heredoc.line_continuation = true; + LEX(PM_TOKEN_STRING_CONTENT); + } + + was_line_continuation = true; token_buffer.cursor = parser->current.end + 1; breakpoint = parser->current.end; continue; @@ -10980,7 +11004,7 @@ parser_lex(pm_parser_t *parser) { assert(false && "unreachable"); } - was_escaped_newline = false; + was_line_continuation = false; } if (parser->current.end > parser->current.start) { @@ -16626,7 +16650,7 @@ parse_expression_prefix(pm_parser_t *parser, pm_binding_power_t binding_power, b pm_interpolated_symbol_node_append(interpolated, first_string); pm_interpolated_symbol_node_append(interpolated, second_string); - free(current); + xfree(current); current = (pm_node_t *) interpolated; } else { assert(false && "unreachable"); @@ -18681,6 +18705,99 @@ pm_parse(pm_parser_t *parser) { return parse_program(parser); } +/** + * Read from the stream until the fgets callback returns NULL. If the last read + * line from the stream matches an __END__ marker, then halt and return false, + * otherwise return true. + */ +static bool +pm_parse_stream_read(pm_buffer_t *buffer, void *stream, pm_parse_stream_fgets_t *fgets) { +#define LINE_SIZE 4096 + char line[LINE_SIZE]; + + while (fgets(line, LINE_SIZE, stream) != NULL) { + size_t length = strlen(line); + + if (length == LINE_SIZE - 1 && line[length - 1] != '\n') { + // fgets stores at most LINE_SIZE - 1 bytes, so a full chunk with no + // trailing newline is a partial line: append it to the buffer and + // continue reading. 
+ pm_buffer_append_string(buffer, line, length); + continue; + } + + // Append the line to the buffer. + pm_buffer_append_string(buffer, line, length); + + // Check if the line matches the __END__ marker. If it does, then stop + // reading and return false. In most circumstances, this means we should + // stop reading from the stream so that the DATA constant can pick it + // up. + switch (length) { + case 7: + if (strncmp(line, "__END__", 7) == 0) return false; + break; + case 8: + if (strncmp(line, "__END__\n", 8) == 0) return false; + break; + case 9: + if (strncmp(line, "__END__\r\n", 9) == 0) return false; + break; + } + } + + return true; +#undef LINE_SIZE +} + +/** + * Determine if there was an unterminated heredoc at the end of the input, which + * would mean the stream isn't finished and we should keep reading. + * + * For the other lex modes we can check if the lex mode has been closed, but for + * heredocs when we hit EOF we close the lex mode and then go back to parse the + * rest of the line after the heredoc declaration so that we get more of the + * syntax tree. + */ +static bool +pm_parse_stream_unterminated_heredoc_p(pm_parser_t *parser) { + pm_diagnostic_t *diagnostic = (pm_diagnostic_t *) parser->error_list.head; + + for (; diagnostic != NULL; diagnostic = (pm_diagnostic_t *) diagnostic->node.next) { + if (diagnostic->diag_id == PM_ERR_HEREDOC_TERM) { + return true; + } + } + + return false; +} + +/** + * Parse a stream of Ruby source and return the tree. + * + * Prism is designed around having the entire source in memory at once, but you + * can stream stdin in to Ruby so we need to support a streaming API. 
+ */ +PRISM_EXPORTED_FUNCTION pm_node_t * +pm_parse_stream(pm_parser_t *parser, pm_buffer_t *buffer, void *stream, pm_parse_stream_fgets_t *fgets, const pm_options_t *options) { + pm_buffer_init(buffer); + + bool eof = pm_parse_stream_read(buffer, stream, fgets); + pm_parser_init(parser, (const uint8_t *) pm_buffer_value(buffer), pm_buffer_length(buffer), options); + pm_node_t *node = pm_parse(parser); + + while (!eof && parser->error_list.size > 0 && (parser->lex_modes.index > 0 || pm_parse_stream_unterminated_heredoc_p(parser))) { + pm_node_destroy(parser, node); + eof = pm_parse_stream_read(buffer, stream, fgets); + + pm_parser_free(parser); + pm_parser_init(parser, (const uint8_t *) pm_buffer_value(buffer), pm_buffer_length(buffer), options); + node = pm_parse(parser); + } + + return node; +} + static inline void pm_serialize_header(pm_buffer_t *buffer) { pm_buffer_append_string(buffer, "PRISM", 5); @@ -18723,6 +18840,28 @@ pm_serialize_parse(pm_buffer_t *buffer, const uint8_t *source, size_t size, cons pm_options_free(&options); } +/** + * Parse and serialize the AST represented by the source that is read out of the + * given stream into the given buffer. + */ +PRISM_EXPORTED_FUNCTION void +pm_serialize_parse_stream(pm_buffer_t *buffer, void *stream, pm_parse_stream_fgets_t *fgets, const char *data) { + pm_parser_t parser; + pm_options_t options = { 0 }; + pm_options_read(&options, data); + + pm_buffer_t parser_buffer; + pm_node_t *node = pm_parse_stream(&parser, &parser_buffer, stream, fgets, &options); + pm_serialize_header(buffer); + pm_serialize_content(&parser, node, buffer); + pm_buffer_append_byte(buffer, '\0'); + + pm_node_destroy(&parser, node); + pm_buffer_free(&parser_buffer); + pm_parser_free(&parser); + pm_options_free(&options); +} + /** * Parse and serialize the comments in the given source to the given buffer. 
*/ diff --git a/prism/prism.h b/prism/prism.h index 7d9b96fa829e99..5e3919f40b8c94 100644 --- a/prism/prism.h +++ b/prism/prism.h @@ -79,6 +79,36 @@ PRISM_EXPORTED_FUNCTION void pm_parser_free(pm_parser_t *parser); */ PRISM_EXPORTED_FUNCTION pm_node_t * pm_parse(pm_parser_t *parser); +/** + * This function is used in pm_parse_stream to retrieve a line of input from a + * stream. It closely mirrors that of fgets so that fgets can be used as the + * default implementation. + */ +typedef char * (pm_parse_stream_fgets_t)(char *string, int size, void *stream); + +/** + * Parse a stream of Ruby source and return the tree. + * + * @param parser The parser to use. + * @param buffer The buffer to use. + * @param stream The stream to parse. + * @param fgets The function to use to read from the stream. + * @param options The optional options to use when parsing. + * @return The AST representing the source. + */ +PRISM_EXPORTED_FUNCTION pm_node_t * pm_parse_stream(pm_parser_t *parser, pm_buffer_t *buffer, void *stream, pm_parse_stream_fgets_t *fgets, const pm_options_t *options); + +/** + * Parse and serialize the AST represented by the source that is read out of the + * given stream into the given buffer. + * + * @param buffer The buffer to serialize to. + * @param stream The stream to parse. + * @param fgets The function to use to read from the stream. + * @param data The optional data to pass to the parser. + */ +PRISM_EXPORTED_FUNCTION void pm_serialize_parse_stream(pm_buffer_t *buffer, void *stream, pm_parse_stream_fgets_t *fgets, const char *data); + /** * Serialize the given list of comments to the given buffer. * diff --git a/prism/static_literals.c b/prism/static_literals.c index 713721bb73acd5..81231692f6688a 100644 --- a/prism/static_literals.c +++ b/prism/static_literals.c @@ -53,12 +53,11 @@ node_hash(const pm_parser_t *parser, const pm_node_t *node) { case PM_INTEGER_NODE: { // Integers hash their value. 
const pm_integer_t *integer = &((const pm_integer_node_t *) node)->value; - const uint32_t *value = &integer->head.value; - - uint32_t hash = murmur_hash((const uint8_t *) value, sizeof(uint32_t)); - for (const pm_integer_word_t *word = integer->head.next; word != NULL; word = word->next) { - value = &word->value; - hash ^= murmur_hash((const uint8_t *) value, sizeof(uint32_t)); + uint32_t hash; + if (integer->values) { + hash = murmur_hash((const uint8_t *) integer->values, sizeof(uint32_t) * integer->length); + } else { + hash = murmur_hash((const uint8_t *) &integer->value, sizeof(uint32_t)); } if (integer->negative) { @@ -204,9 +203,9 @@ pm_int64_value(const pm_parser_t *parser, const pm_node_t *node) { switch (PM_NODE_TYPE(node)) { case PM_INTEGER_NODE: { const pm_integer_t *integer = &((const pm_integer_node_t *) node)->value; - if (integer->length > 0) return integer->negative ? INT64_MIN : INT64_MAX; + if (integer->values) return integer->negative ? INT64_MIN : INT64_MAX; - int64_t value = (int64_t) integer->head.value; + int64_t value = (int64_t) integer->value; return integer->negative ? 
-value : value; } case PM_SOURCE_LINE_NODE: diff --git a/prism/templates/ext/prism/api_node.c.erb b/prism/templates/ext/prism/api_node.c.erb index 301479b3c5dff3..0e8aaae3226093 100644 --- a/prism/templates/ext/prism/api_node.c.erb +++ b/prism/templates/ext/prism/api_node.c.erb @@ -37,14 +37,26 @@ pm_string_new(const pm_string_t *string, rb_encoding *encoding) { return rb_enc_str_new((const char *) pm_string_source(string), pm_string_length(string), encoding); } -static VALUE +VALUE pm_integer_new(const pm_integer_t *integer) { - VALUE result = UINT2NUM(integer->head.value); - size_t shift = 0; + VALUE result; + if (integer->values == NULL) { + result = UINT2NUM(integer->value); + } else { + VALUE string = rb_str_new(NULL, integer->length * 8); + unsigned char *bytes = (unsigned char *) RSTRING_PTR(string); + + size_t offset = integer->length * 8; + for (size_t value_index = 0; value_index < integer->length; value_index++) { + uint32_t value = integer->values[value_index]; + + for (int index = 0; index < 8; index++) { + int byte = (value >> (4 * index)) & 0xf; + bytes[--offset] = byte < 10 ? byte + '0' : byte - 10 + 'a'; + } + } - for (const pm_integer_word_t *node = integer->head.next; node != NULL; node = node->next) { - VALUE receiver = rb_funcall(UINT2NUM(node->value), rb_intern("<<"), 1, ULONG2NUM(++shift * 32)); - result = rb_funcall(receiver, rb_intern("|"), 1, result); + result = rb_funcall(string, rb_intern("to_i"), 1, UINT2NUM(16)); } if (integer->negative) { diff --git a/prism/templates/src/serialize.c.erb b/prism/templates/src/serialize.c.erb index 0313f43d78a351..27fde37f698601 100644 --- a/prism/templates/src/serialize.c.erb +++ b/prism/templates/src/serialize.c.erb @@ -52,10 +52,14 @@ pm_serialize_string(const pm_parser_t *parser, const pm_string_t *string, pm_buf static void pm_serialize_integer(const pm_integer_t *integer, pm_buffer_t *buffer) { pm_buffer_append_byte(buffer, integer->negative ? 
1 : 0); - pm_buffer_append_varuint(buffer, pm_sizet_to_u32(integer->length + 1)); - - for (const pm_integer_word_t *node = &integer->head; node != NULL; node = node->next) { - pm_buffer_append_varuint(buffer, node->value); + if (integer->values == NULL) { + pm_buffer_append_varuint(buffer, pm_sizet_to_u32(1)); + pm_buffer_append_varuint(buffer, integer->value); + } else { + pm_buffer_append_varuint(buffer, pm_sizet_to_u32(integer->length)); + for (size_t i = 0; i < integer->length; i++) { + pm_buffer_append_varuint(buffer, integer->values[i]); + } } } diff --git a/prism/util/pm_integer.c b/prism/util/pm_integer.c index c03b930ad3f3d7..160a78920c0086 100644 --- a/prism/util/pm_integer.c +++ b/prism/util/pm_integer.c @@ -1,143 +1,452 @@ #include "prism/util/pm_integer.h" /** - * Create a new node for an integer in the linked list. + * Pull out the length and values from the integer, regardless of the form in + * which the length/values are stored. */ -static pm_integer_word_t * -pm_integer_node_create(pm_integer_t *integer, uint32_t value) { - integer->length++; +#define INTEGER_EXTRACT(integer, length_variable, values_variable) \ + if ((integer)->values == NULL) { \ + length_variable = 1; \ + values_variable = &(integer)->value; \ + } else { \ + length_variable = (integer)->length; \ + values_variable = (integer)->values; \ + } - pm_integer_word_t *node = xmalloc(sizeof(pm_integer_word_t)); - if (node == NULL) return NULL; +/** + * Adds two positive pm_integer_t with the given base. + * Return pm_integer_t with values allocated. Not normalized. + */ +static void +big_add(pm_integer_t *destination, pm_integer_t *left, pm_integer_t *right, uint64_t base) { + size_t left_length; + uint32_t *left_values; + INTEGER_EXTRACT(left, left_length, left_values) + + size_t right_length; + uint32_t *right_values; + INTEGER_EXTRACT(right, right_length, right_values) + + size_t length = left_length < right_length ? 
right_length : left_length; + uint32_t *values = (uint32_t *) xmalloc(sizeof(uint32_t) * (length + 1)); + if (values == NULL) return; + + uint64_t carry = 0; + for (size_t index = 0; index < length; index++) { + uint64_t sum = carry + (index < left_length ? left_values[index] : 0) + (index < right_length ? right_values[index] : 0); + values[index] = (uint32_t) (sum % base); + carry = sum / base; + } - *node = (pm_integer_word_t) { .next = NULL, .value = value }; - return node; + if (carry > 0) { + values[length] = (uint32_t) carry; + length++; + } + + *destination = (pm_integer_t) { 0, length, values, false }; } /** - * Copy one integer onto another. + * Internal use for karatsuba_multiply. Calculates `a - b - c` with the given + * base. Assume a, b, c, a - b - c all to be positive. + * Return pm_integer_t with values allocated. Not normalized. */ static void -pm_integer_copy(pm_integer_t *dest, const pm_integer_t *src) { - dest->negative = src->negative; - dest->length = 0; +big_sub2(pm_integer_t *destination, pm_integer_t *a, pm_integer_t *b, pm_integer_t *c, uint64_t base) { + size_t a_length; + uint32_t *a_values; + INTEGER_EXTRACT(a, a_length, a_values) + + size_t b_length; + uint32_t *b_values; + INTEGER_EXTRACT(b, b_length, b_values) + + size_t c_length; + uint32_t *c_values; + INTEGER_EXTRACT(c, c_length, c_values) + + uint32_t *values = (uint32_t*) xmalloc(sizeof(uint32_t) * a_length); + int64_t carry = 0; + + for (size_t index = 0; index < a_length; index++) { + int64_t sub = ( + carry + + a_values[index] - + (index < b_length ? b_values[index] : 0) - + (index < c_length ? 
c_values[index] : 0) + ); + + if (sub >= 0) { + values[index] = (uint32_t) sub; + carry = 0; + } else { + sub += 2 * (int64_t) base; + values[index] = (uint32_t) ((uint64_t) sub % base); + carry = sub / (int64_t) base - 2; + } + } - dest->head.value = src->head.value; - dest->head.next = NULL; + while (a_length > 1 && values[a_length - 1] == 0) a_length--; + *destination = (pm_integer_t) { 0, a_length, values, false }; +} - pm_integer_word_t *dest_current = &dest->head; - const pm_integer_word_t *src_current = src->head.next; +/** + * Multiply two positive integers with the given base using karatsuba algorithm. + * Return pm_integer_t with values allocated. Not normalized. + */ +static void +karatsuba_multiply(pm_integer_t *destination, pm_integer_t *left, pm_integer_t *right, uint64_t base) { + size_t left_length; + uint32_t *left_values; + INTEGER_EXTRACT(left, left_length, left_values) + + size_t right_length; + uint32_t *right_values; + INTEGER_EXTRACT(right, right_length, right_values) + + if (left_length > right_length) { + size_t temporary_length = left_length; + left_length = right_length; + right_length = temporary_length; + + uint32_t *temporary_values = left_values; + left_values = right_values; + right_values = temporary_values; + } - while (src_current != NULL) { - dest_current->next = pm_integer_node_create(dest, src_current->value); - if (dest_current->next == NULL) return; + if (left_length <= 10) { + size_t length = left_length + right_length; + uint32_t *values = (uint32_t *) xcalloc(length, sizeof(uint32_t)); + if (values == NULL) return; + + for (size_t left_index = 0; left_index < left_length; left_index++) { + uint32_t carry = 0; + for (size_t right_index = 0; right_index < right_length; right_index++) { + uint64_t product = (uint64_t) left_values[left_index] * right_values[right_index] + values[left_index + right_index] + carry; + values[left_index + right_index] = (uint32_t) (product % base); + carry = (uint32_t) (product / base); + } + 
values[left_index + right_length] = carry; + } - dest_current = dest_current->next; - src_current = src_current->next; + while (length > 1 && values[length - 1] == 0) length--; + *destination = (pm_integer_t) { 0, length, values, false }; + return; } - dest_current->next = NULL; + if (left_length * 2 <= right_length) { + uint32_t *values = (uint32_t*) xcalloc(left_length + right_length, sizeof(uint32_t)); + + for (size_t start_offset = 0; start_offset < right_length; start_offset += left_length) { + size_t end_offset = start_offset + left_length; + if (end_offset > right_length) end_offset = right_length; + + pm_integer_t sliced_right = { + .value = 0, + .length = end_offset - start_offset, + .values = right_values + start_offset, + .negative = false + }; + + pm_integer_t product; + karatsuba_multiply(&product, left, &sliced_right, base); + + uint32_t carry = 0; + for (size_t index = 0; index < product.length; index++) { + uint64_t sum = (uint64_t) values[start_offset + index] + product.values[index] + carry; + values[start_offset + index] = (uint32_t) (sum % base); + carry = (uint32_t) (sum / base); + } + + if (carry > 0) values[start_offset + product.length] += carry; + pm_integer_free(&product); + } + + *destination = (pm_integer_t) { 0, left_length + right_length, values, false }; + return; + } + + size_t half = left_length / 2; + pm_integer_t x0 = { 0, half, left_values, false }; + pm_integer_t x1 = { 0, left_length - half, left_values + half, false }; + pm_integer_t y0 = { 0, half, right_values, false }; + pm_integer_t y1 = { 0, right_length - half, right_values + half, false }; + + pm_integer_t z0; + karatsuba_multiply(&z0, &x0, &y0, base); + + pm_integer_t z2; + karatsuba_multiply(&z2, &x1, &y1, base); + + // For simplicity to avoid considering negative values, + // use `z1 = (x0 + x1) * (y0 + y1) - z0 - z2` instead of original karatsuba algorithm. 
+ pm_integer_t x01; + big_add(&x01, &x0, &x1, base); + + pm_integer_t y01; + big_add(&y01, &y0, &y1, base); + + pm_integer_t xy; + karatsuba_multiply(&xy, &x01, &y01, base); + + pm_integer_t z1; + big_sub2(&z1, &xy, &z0, &z2, base); + + size_t length = left_length + right_length; + uint32_t *values = (uint32_t*) xcalloc(length, sizeof(uint32_t)); + memcpy(values, z0.values, sizeof(uint32_t) * z0.length); + memcpy(values + 2 * half, z2.values, sizeof(uint32_t) * z2.length); + + uint32_t carry = 0; + for(size_t index = 0; index < z1.length; index++) { + uint64_t sum = (uint64_t) carry + values[index + half] + z1.values[index]; + values[index + half] = (uint32_t) (sum % base); + carry = (uint32_t) (sum / base); + } + + for(size_t index = half + z1.length; carry > 0; index++) { + uint64_t sum = (uint64_t) carry + values[index]; + values[index] = (uint32_t) (sum % base); + carry = (uint32_t) (sum / base); + } + + while (length > 1 && values[length - 1] == 0) length--; + pm_integer_free(&z0); + pm_integer_free(&z1); + pm_integer_free(&z2); + pm_integer_free(&x01); + pm_integer_free(&y01); + pm_integer_free(&xy); + + *destination = (pm_integer_t) { 0, length, values, false }; +} + +/** + * The values of a hexadecimal digit, where the index is the ASCII character. 
+ */ +static const int8_t pm_integer_parse_digit_values[256] = { +// 0 1 2 3 4 5 6 7 8 9 A B C D E F + -1, -1, -1, -1, -1, -1, -1, -1, -1, -1, -1, -1, -1, -1, -1, -1, // 0x + -1, -1, -1, -1, -1, -1, -1, -1, -1, -1, -1, -1, -1, -1, -1, -1, // 1x + -1, -1, -1, -1, -1, -1, -1, -1, -1, -1, -1, -1, -1, -1, -1, -1, // 2x + 0, 1, 2, 3, 4, 5, 6, 7, 8, 9, -1, -1, -1, -1, -1, -1, // 3x + -1, 10, 11, 12, 13, 14, 15, -1, -1, -1, -1, -1, -1, -1, -1, -1, // 4x + -1, -1, -1, -1, -1, -1, -1, -1, -1, -1, -1, -1, -1, -1, -1, -1, // 5x + -1, 10, 11, 12, 13, 14, 15, -1, -1, -1, -1, -1, -1, -1, -1, -1, // 6x + -1, -1, -1, -1, -1, -1, -1, -1, -1, -1, -1, -1, -1, -1, -1, -1, // 7x + -1, -1, -1, -1, -1, -1, -1, -1, -1, -1, -1, -1, -1, -1, -1, -1, // 8x + -1, -1, -1, -1, -1, -1, -1, -1, -1, -1, -1, -1, -1, -1, -1, -1, // 9x + -1, -1, -1, -1, -1, -1, -1, -1, -1, -1, -1, -1, -1, -1, -1, -1, // Ax + -1, -1, -1, -1, -1, -1, -1, -1, -1, -1, -1, -1, -1, -1, -1, -1, // Bx + -1, -1, -1, -1, -1, -1, -1, -1, -1, -1, -1, -1, -1, -1, -1, -1, // Cx + -1, -1, -1, -1, -1, -1, -1, -1, -1, -1, -1, -1, -1, -1, -1, -1, // Dx + -1, -1, -1, -1, -1, -1, -1, -1, -1, -1, -1, -1, -1, -1, -1, -1, // Ex + -1, -1, -1, -1, -1, -1, -1, -1, -1, -1, -1, -1, -1, -1, -1, -1, // Fx +}; + +/** + * Return the value of a hexadecimal digit in a uint8_t. + */ +static uint8_t +pm_integer_parse_digit(const uint8_t character) { + int8_t value = pm_integer_parse_digit_values[character]; + assert(value != -1 && "invalid digit"); + + return (uint8_t) value; } /** - * Add a 32-bit integer to an integer. + * Create a pm_integer_t from uint64_t with the given base. It is assumed that + * the memory for the pm_integer_t pointer has been zeroed. 
*/ static void -pm_integer_add(pm_integer_t *integer, uint32_t addend) { - uint32_t carry = addend; - pm_integer_word_t *current = &integer->head; - - while (carry > 0) { - uint64_t result = (uint64_t) current->value + carry; - carry = (uint32_t) (result >> 32); - current->value = (uint32_t) result; - - if (carry > 0) { - if (current->next == NULL) { - current->next = pm_integer_node_create(integer, carry); - break; - } +pm_integer_from_uint64(pm_integer_t *integer, uint64_t value, uint64_t base) { + if (value < base) { + integer->value = (uint32_t) value; + return; + } - current = current->next; - } + size_t length = 0; + uint64_t length_value = value; + while (length_value > 0) { + length++; + length_value /= base; + } + + uint32_t *values = (uint32_t *) xmalloc(sizeof(uint32_t) * length); + if (values == NULL) return; + + for (size_t value_index = 0; value_index < length; value_index++) { + values[value_index] = (uint32_t) (value % base); + value /= base; } + + integer->length = length; + integer->values = values; } /** - * Multiple an integer by a 32-bit integer. In practice, the multiplier is the - * base of the integer, so this is 2, 8, 10, or 16. + * Normalize pm_integer_t. + * Heading zero values will be removed. If the integer fits into uint32_t, + * values is set to NULL, length is set to 0, and value field will be used. 
*/ static void -pm_integer_multiply(pm_integer_t *integer, uint32_t multiplier) { - uint32_t carry = 0; +pm_integer_normalize(pm_integer_t *integer) { + if (integer->values == NULL) { + return; + } - for (pm_integer_word_t *current = &integer->head; current != NULL; current = current->next) { - uint64_t result = (uint64_t) current->value * multiplier + carry; - carry = (uint32_t) (result >> 32); - current->value = (uint32_t) result; + while (integer->length > 1 && integer->values[integer->length - 1] == 0) { + integer->length--; + } - if (carry > 0 && current->next == NULL) { - current->next = pm_integer_node_create(integer, carry); - break; - } + if (integer->length > 1) { + return; } + + uint32_t value = integer->values[0]; + bool negative = integer->negative && value != 0; + + pm_integer_free(integer); + *integer = (pm_integer_t) { .value = value, .length = 0, .values = NULL, .negative = negative }; } /** - * Divide an individual word by a 32-bit integer. This will recursively divide - * any subsequent nodes in the linked list. + * Convert base of the integer. + * In practice, it converts 10**9 to 1<<32 or 1<<32 to 10**9. 
*/ -static uint32_t -pm_integer_divide_word(pm_integer_t *integer, pm_integer_word_t *word, uint32_t dividend) { - uint32_t remainder = 0; - if (word->next != NULL) { - remainder = pm_integer_divide_word(integer, word->next, dividend); - - if (integer->length > 0 && word->next->value == 0) { - xfree(word->next); - word->next = NULL; - integer->length--; +static void +pm_integer_convert_base(pm_integer_t *destination, const pm_integer_t *source, uint64_t base_from, uint64_t base_to) { + size_t source_length; + const uint32_t *source_values; + INTEGER_EXTRACT(source, source_length, source_values) + + size_t bigints_length = (source_length + 1) / 2; + pm_integer_t *bigints = (pm_integer_t *) xcalloc(bigints_length, sizeof(pm_integer_t)); + if (bigints == NULL) return; + + for (size_t index = 0; index < source_length; index += 2) { + uint64_t value = source_values[index] + base_from * (index + 1 < source_length ? source_values[index + 1] : 0); + pm_integer_from_uint64(&bigints[index / 2], value, base_to); + } + + pm_integer_t base = { 0 }; + pm_integer_from_uint64(&base, base_from, base_to); + + while (bigints_length > 1) { + pm_integer_t next_base; + karatsuba_multiply(&next_base, &base, &base, base_to); + + pm_integer_free(&base); + base = next_base; + + size_t next_length = (bigints_length + 1) / 2; + pm_integer_t *next_bigints = (pm_integer_t *) xmalloc(sizeof(pm_integer_t) * next_length); + + for (size_t bigints_index = 0; bigints_index < bigints_length; bigints_index += 2) { + if (bigints_index + 1 == bigints_length) { + next_bigints[bigints_index / 2] = bigints[bigints_index]; + } else { + pm_integer_t multiplied; + karatsuba_multiply(&multiplied, &base, &bigints[bigints_index + 1], base_to); + + big_add(&next_bigints[bigints_index / 2], &bigints[bigints_index], &multiplied, base_to); + pm_integer_free(&bigints[bigints_index]); + pm_integer_free(&bigints[bigints_index + 1]); + pm_integer_free(&multiplied); + } } + + xfree(bigints); + bigints = next_bigints; + 
bigints_length = next_length; + } + + *destination = bigints[0]; + destination->negative = source->negative; + pm_integer_normalize(destination); + + xfree(bigints); + pm_integer_free(&base); +} + +#undef INTEGER_EXTRACT + +/** + * Convert digits to integer with the given power-of-two base. + */ +static void +pm_integer_parse_powof2(pm_integer_t *integer, uint32_t base, const uint8_t *digits, size_t digits_length) { + size_t bit = 1; + while (base > (uint32_t) (1 << bit)) bit++; + + size_t length = (digits_length * bit + 31) / 32; + uint32_t *values = (uint32_t *) xcalloc(length, sizeof(uint32_t)); + + for (size_t digit_index = 0; digit_index < digits_length; digit_index++) { + size_t bit_position = bit * (digits_length - digit_index - 1); + uint32_t value = digits[digit_index]; + + size_t index = bit_position / 32; + size_t shift = bit_position % 32; + + values[index] |= value << shift; + if (32 - shift < bit) values[index + 1] |= value >> (32 - shift); } - uint64_t value = ((uint64_t) remainder << 32) | word->value; - word->value = (uint32_t) (value / dividend); - return (uint32_t) (value % dividend); + while (length > 1 && values[length - 1] == 0) length--; + *integer = (pm_integer_t) { .value = 0, .length = length, .values = values, .negative = false }; + pm_integer_normalize(integer); } /** - * Divide an integer by a 32-bit integer. In practice, this is only 10 so that - * we can format it as a string. It returns the remainder of the division. + * Convert decimal digits to pm_integer_t. 
*/ -static uint32_t -pm_integer_divide(pm_integer_t *integer, uint32_t dividend) { - return pm_integer_divide_word(integer, &integer->head, dividend); +static void +pm_integer_parse_decimal(pm_integer_t *integer, const uint8_t *digits, size_t digits_length) { + const size_t batch = 9; + size_t length = (digits_length + batch - 1) / batch; + + uint32_t *values = (uint32_t *) xcalloc(length, sizeof(uint32_t)); + uint32_t value = 0; + + for (size_t digits_index = 0; digits_index < digits_length; digits_index++) { + value = value * 10 + digits[digits_index]; + + size_t reverse_index = digits_length - digits_index - 1; + if (reverse_index % batch == 0) { + values[reverse_index / batch] = value; + value = 0; + } + } + + // Convert base from 10**9 to 1<<32. + pm_integer_convert_base(integer, &((pm_integer_t) { .value = 0, .length = length, .values = values, .negative = false }), 1000000000, ((uint64_t) 1 << 32)); + xfree(values); } /** - * Return the value of a digit in a uint32_t. + * Parse a large integer from a string that does not fit into uint32_t. */ -static uint32_t -pm_integer_parse_digit(const uint8_t character) { - switch (character) { - case '0': return 0; - case '1': return 1; - case '2': return 2; - case '3': return 3; - case '4': return 4; - case '5': return 5; - case '6': return 6; - case '7': return 7; - case '8': return 8; - case '9': return 9; - case 'a': case 'A': return 10; - case 'b': case 'B': return 11; - case 'c': case 'C': return 12; - case 'd': case 'D': return 13; - case 'e': case 'E': return 14; - case 'f': case 'F': return 15; - default: assert(false && "unreachable"); return 0; +static void +pm_integer_parse_big(pm_integer_t *integer, uint32_t multiplier, const uint8_t *start, const uint8_t *end) { + // Allocate an array to store digits. 
+ uint8_t *digits = xmalloc(sizeof(uint8_t) * (size_t) (end - start)); + size_t digits_length = 0; + + for (; start < end; start++) { + if (*start == '_') continue; + digits[digits_length++] = pm_integer_parse_digit(*start); + } + + // Construct pm_integer_t from the digits. + if (multiplier == 10) { + pm_integer_parse_decimal(integer, digits, digits_length); + } else { + pm_integer_parse_powof2(integer, multiplier, digits, digits_length); } + + xfree(digits); } /** @@ -189,15 +498,22 @@ pm_integer_parse(pm_integer_t *integer, pm_integer_base_t base, const uint8_t *s // invalid integer. If this is the case, we'll just return 0. if (start >= end) return; - // Add the first digit to the integer. - pm_integer_add(integer, pm_integer_parse_digit(*start++)); + const uint8_t *cursor = start; + uint64_t value = (uint64_t) pm_integer_parse_digit(*cursor++); - // Add the subsequent digits to the integer. - for (; start < end; start++) { - if (*start == '_') continue; - pm_integer_multiply(integer, multiplier); - pm_integer_add(integer, pm_integer_parse_digit(*start)); + for (; cursor < end; cursor++) { + if (*cursor == '_') continue; + value = value * multiplier + (uint64_t) pm_integer_parse_digit(*cursor); + + if (value > UINT32_MAX) { + // If the integer is too large to fit into a single uint32_t, then + // we'll parse it as a big integer. + pm_integer_parse_big(integer, multiplier, start, end); + return; + } } + + integer->value = (uint32_t) value; } /** @@ -205,7 +521,7 @@ pm_integer_parse(pm_integer_t *integer, pm_integer_base_t base, const uint8_t *s */ size_t pm_integer_memsize(const pm_integer_t *integer) { - return sizeof(pm_integer_t) + integer->length * sizeof(pm_integer_word_t); + return sizeof(pm_integer_t) + integer->length * sizeof(uint32_t); } /** @@ -218,16 +534,22 @@ pm_integer_compare(const pm_integer_t *left, const pm_integer_t *right) { if (left->negative != right->negative) return left->negative ? -1 : 1; int negative = left->negative ? 
-1 : 1; - if (left->length < right->length) return -1 * negative; - if (left->length > right->length) return 1 * negative; + if (left->values == NULL && right->values == NULL) { + if (left->value < right->value) return -1 * negative; + if (left->value > right->value) return 1 * negative; + return 0; + } + + if (left->values == NULL || left->length < right->length) return -1 * negative; + if (right->values == NULL || left->length > right->length) return 1 * negative; - for ( - const pm_integer_word_t *left_word = &left->head, *right_word = &right->head; - left_word != NULL && right_word != NULL; - left_word = left_word->next, right_word = right_word->next - ) { - if (left_word->value < right_word->value) return -1 * negative; - if (left_word->value > right_word->value) return 1 * negative; + for (size_t index = 0; index < left->length; index++) { + size_t value_index = left->length - index - 1; + uint32_t left_value = left->values[value_index]; + uint32_t right_value = right->values[value_index]; + + if (left_value < right_value) return -1 * negative; + if (left_value > right_value) return 1 * negative; } return 0; @@ -242,65 +564,62 @@ pm_integer_string(pm_buffer_t *buffer, const pm_integer_t *integer) { pm_buffer_append_byte(buffer, '-'); } - switch (integer->length) { - case 0: { - const uint32_t value = integer->head.value; - pm_buffer_append_format(buffer, "%" PRIu32, value); - return; - } - case 1: { - const uint64_t value = ((uint64_t) integer->head.value) | (((uint64_t) integer->head.next->value) << 32); - pm_buffer_append_format(buffer, "%" PRIu64, value); - return; - } - default: { - // First, allocate a buffer that we'll copy the decimal digits into. - size_t length = (integer->length + 1) * 10; - char *digits = xcalloc(length, sizeof(char)); - if (digits == NULL) return; - - // Next, create a new integer that we'll use to store the result of - // the division and modulo operations. 
- pm_integer_t copy; - pm_integer_copy(©, integer); - - // Then, iterate through the integer, dividing by 10 and storing the - // result in the buffer. - char *ending = digits + length - 1; - char *current = ending; - - while (copy.length > 0 || copy.head.value > 0) { - uint32_t remainder = pm_integer_divide(©, 10); - *current-- = (char) ('0' + remainder); - } + // If the integer fits into a single uint32_t, then we can just append the + // value directly to the buffer. + if (integer->values == NULL) { + pm_buffer_append_format(buffer, "%" PRIu32, integer->value); + return; + } - // Finally, append the string to the buffer and free the digits. - pm_buffer_append_string(buffer, current + 1, (size_t) (ending - current)); - xfree(digits); - return; - } + // If the integer is two uint32_t values, then we can | them together and + // append the result to the buffer. + if (integer->length == 2) { + const uint64_t value = ((uint64_t) integer->values[0]) | ((uint64_t) integer->values[1] << 32); + pm_buffer_append_format(buffer, "%" PRIu64, value); + return; } -} -/** - * Recursively destroy the linked list of an integer. - */ -static void -pm_integer_word_destroy(pm_integer_word_t *integer) { - if (integer->next != NULL) { - pm_integer_word_destroy(integer->next); + // Otherwise, first we'll convert the base from 1<<32 to 10**9. + pm_integer_t converted; + pm_integer_convert_base(&converted, integer, (uint64_t) 1 << 32, 1000000000); + + if (converted.values == NULL) { + pm_buffer_append_format(buffer, "%" PRIu32, converted.value); + pm_integer_free(&converted); + return; + } + + // Allocate a buffer that we'll copy the decimal digits into. + size_t digits_length = converted.length * 9; + char *digits = xcalloc(digits_length, sizeof(char)); + if (digits == NULL) return; + + // Pack bigdecimal to digits. 
+ for (size_t value_index = 0; value_index < converted.length; value_index++) { + uint32_t value = converted.values[value_index]; + + for (size_t digit_index = 0; digit_index < 9; digit_index++) { + digits[digits_length - 9 * value_index - digit_index - 1] = (char) ('0' + value % 10); + value /= 10; + } } - xfree(integer); + size_t start_offset = 0; + while (start_offset < digits_length - 1 && digits[start_offset] == '0') start_offset++; + + // Finally, append the string to the buffer and free the digits. + pm_buffer_append_string(buffer, digits + start_offset, digits_length - start_offset); + xfree(digits); + pm_integer_free(&converted); } /** * Free the internal memory of an integer. This memory will only be allocated if - * the integer exceeds the size of a single node in the linked list. + * the integer exceeds the size of a single uint32_t. */ PRISM_EXPORTED_FUNCTION void pm_integer_free(pm_integer_t *integer) { - if (integer->head.next) { - pm_integer_word_destroy(integer->head.next); + if (integer->values) { + xfree(integer->values); } } diff --git a/prism/util/pm_integer.h b/prism/util/pm_integer.h index 63f560275d47e7..7f172988b3e182 100644 --- a/prism/util/pm_integer.h +++ b/prism/util/pm_integer.h @@ -15,30 +15,25 @@ #include /** - * A node in the linked list of a pm_integer_t. + * A structure represents an arbitrary-sized integer. */ -typedef struct pm_integer_word { - /** A pointer to the next node in the list. */ - struct pm_integer_word *next; - - /** The value of the node. */ +typedef struct { + /** + * Embedded value for small integer. This value is set to 0 if the value + * does not fit into uint32_t. + */ uint32_t value; -} pm_integer_word_t; -/** - * This structure represents an arbitrary-sized integer. It is implemented as a - * linked list of 32-bit integers, with the least significant digit at the head - * of the list. - */ -typedef struct { - /** The number of nodes in the linked list that have been allocated. 
*/ + /** + * The number of allocated values. length is set to 0 if the integer fits + * into uint32_t. + */ size_t length; /** - * The head of the linked list, embedded directly so that allocations do not - * need to be performed for small integers. + * List of 32-bit integers. Set to NULL if the integer fits into uint32_t. */ - pm_integer_word_t head; + uint32_t *values; /** * Whether or not the integer is negative. It is stored this way so that a diff --git a/prism/util/pm_string.c b/prism/util/pm_string.c index 753429a2336e16..8342edc34ef973 100644 --- a/prism/util/pm_string.c +++ b/prism/util/pm_string.c @@ -175,7 +175,7 @@ pm_string_file_init(pm_string_t *string, const char *filepath) { } // Create a buffer to read the file into. - uint8_t *source = malloc(file_size); + uint8_t *source = xmalloc(file_size); if (source == NULL) { CloseHandle(file); return false; @@ -190,7 +190,7 @@ pm_string_file_init(pm_string_t *string, const char *filepath) { // Check the number of bytes read if (bytes_read != file_size) { - free(source); + xfree(source); CloseHandle(file); return false; } @@ -220,7 +220,7 @@ pm_string_file_init(pm_string_t *string, const char *filepath) { } size_t length = (size_t) file_size; - uint8_t *source = malloc(length); + uint8_t *source = xmalloc(length); if (source == NULL) { fclose(file); return false; @@ -231,7 +231,7 @@ pm_string_file_init(pm_string_t *string, const char *filepath) { fclose(file); if (bytes_read != 1) { - free(source); + xfree(source); return false; } diff --git a/prism_compile.c b/prism_compile.c index fff43d6f13e633..8b65ae7a7af6db 100644 --- a/prism_compile.c +++ b/prism_compile.c @@ -125,16 +125,31 @@ static VALUE parse_integer(const pm_integer_node_t *node) { const pm_integer_t *integer = &node->value; + VALUE result; + + if (integer->values == NULL) { + result = UINT2NUM(integer->value); + } else { + VALUE string = rb_str_new(NULL, integer->length * 8); + unsigned char *bytes = (unsigned char *) RSTRING_PTR(string); + + 
size_t offset = integer->length * 8; + for (size_t value_index = 0; value_index < integer->length; value_index++) { + uint32_t value = integer->values[value_index]; - VALUE result = UINT2NUM(integer->head.value); - size_t shift = 0; + for (int index = 0; index < 8; index++) { + int byte = (value >> (4 * index)) & 0xf; + bytes[--offset] = byte < 10 ? byte + '0' : byte - 10 + 'a'; + } + } + + result = rb_funcall(string, rb_intern("to_i"), 1, UINT2NUM(16)); + } - for (pm_integer_word_t *node = integer->head.next; node != NULL; node = node->next) { - VALUE receiver = rb_funcall(UINT2NUM(node->value), rb_intern("<<"), 1, ULONG2NUM(++shift * 32)); - result = rb_funcall(receiver, rb_intern("|"), 1, result); + if (integer->negative) { + result = rb_funcall(result, rb_intern("-@"), 0); } - if (integer->negative) result = rb_funcall(result, rb_intern("-@"), 0); return result; } diff --git a/template/GNUmakefile.in b/template/GNUmakefile.in index ce7d931e0ab8b9..22ff1078dcbbb6 100644 --- a/template/GNUmakefile.in +++ b/template/GNUmakefile.in @@ -6,9 +6,9 @@ endif ifeq ($(filter Makefile,$(MAKEFILE_LIST)),) include Makefile +endif GNUmakefile: $(srcdir)/template/GNUmakefile.in -endif override silence := $(if $(findstring s,$(firstword $(MFLAGS))),yes,no) @@ -29,5 +29,3 @@ override UNICODE_TABLES_DEPENDENTS = \ -include uncommon.mk include $(srcdir)/defs/gmake.mk - -GNUmakefile: $(srcdir)/template/GNUmakefile.in diff --git a/test/prism/integer_parse_test.rb b/test/prism/integer_parse_test.rb index afc3806fe63046..f42e817e7930fb 100644 --- a/test/prism/integer_parse_test.rb +++ b/test/prism/integer_parse_test.rb @@ -26,6 +26,12 @@ def test_integer_parse assert_integer_parse(2**32) assert_integer_parse(2**64 + 2**32) assert_integer_parse(2**128 + 2**64 + 2**32) + + num = 99 ** 99 + assert_integer_parse(num, "0b#{num.to_s(2)}") + assert_integer_parse(num, "0o#{num.to_s(8)}") + assert_integer_parse(num, "0d#{num.to_s(10)}") + assert_integer_parse(num, "0x#{num.to_s(16)}") end 
private diff --git a/test/prism/parse_stream_test.rb b/test/prism/parse_stream_test.rb new file mode 100644 index 00000000000000..9e6347b92b01eb --- /dev/null +++ b/test/prism/parse_stream_test.rb @@ -0,0 +1,74 @@ +# frozen_string_literal: true + +require_relative "test_helper" +require "stringio" + +module Prism + class ParseStreamTest < TestCase + def test_single_line + io = StringIO.new("1 + 2") + result = Prism.parse_stream(io) + + assert result.success? + assert_kind_of Prism::CallNode, result.value.statements.body.first + end + + def test_multi_line + io = StringIO.new("1 + 2\n3 + 4") + result = Prism.parse_stream(io) + + assert result.success? + assert_kind_of Prism::CallNode, result.value.statements.body.first + assert_kind_of Prism::CallNode, result.value.statements.body.last + end + + def test_multi_read + io = StringIO.new("a" * 4096 * 4) + result = Prism.parse_stream(io) + + assert result.success? + assert_kind_of Prism::CallNode, result.value.statements.body.first + end + + def test___END__ + io = StringIO.new("1 + 2\n3 + 4\n__END__\n5 + 6") + result = Prism.parse_stream(io) + + assert result.success? + assert_equal 2, result.value.statements.body.length + assert_equal "5 + 6", io.read + end + + def test_false___END___in_string + io = StringIO.new("1 + 2\n3 + 4\n\"\n__END__\n\"\n5 + 6") + result = Prism.parse_stream(io) + + assert result.success? + assert_equal 4, result.value.statements.body.length + end + + def test_false___END___in_regexp + io = StringIO.new("1 + 2\n3 + 4\n/\n__END__\n/\n5 + 6") + result = Prism.parse_stream(io) + + assert result.success? + assert_equal 4, result.value.statements.body.length + end + + def test_false___END___in_list + io = StringIO.new("1 + 2\n3 + 4\n%w[\n__END__\n]\n5 + 6") + result = Prism.parse_stream(io) + + assert result.success? 
+ assert_equal 4, result.value.statements.body.length + end + + def test_false___END___in_heredoc + io = StringIO.new("1 + 2\n3 + 4\n<<-EOF\n__END__\nEOF\n5 + 6") + result = Prism.parse_stream(io) + + assert result.success? + assert_equal 4, result.value.statements.body.length + end + end +end diff --git a/test/prism/parser_test.rb b/test/prism/parser_test.rb index d3bf52d96ce28c..118b7322fee050 100644 --- a/test/prism/parser_test.rb +++ b/test/prism/parser_test.rb @@ -45,22 +45,22 @@ class ParserTest < TestCase base = File.join(__dir__, "fixtures") # These files are erroring because of the parser gem being wrong. - skip_incorrect = %w[ - embdoc_no_newline_at_end.txt + skip_incorrect = [ + "embdoc_no_newline_at_end.txt" ] # These files are either failing to parse or failing to translate, so we'll # skip them for now. - skip_all = skip_incorrect | %w[ - dash_heredocs.txt - dos_endings.txt - heredocs_with_ignored_newlines.txt - regex.txt - regex_char_width.txt - spanning_heredoc.txt - spanning_heredoc_newlines.txt - tilde_heredocs.txt - unescaping.txt + skip_all = skip_incorrect | [ + "dash_heredocs.txt", + "dos_endings.txt", + "heredocs_with_ignored_newlines.txt", + "regex.txt", + "regex_char_width.txt", + "spanning_heredoc.txt", + "spanning_heredoc_newlines.txt", + "tilde_heredocs.txt", + "unescaping.txt" ] # Not sure why these files are failing on JRuby, but skipping them for now. @@ -70,21 +70,21 @@ class ParserTest < TestCase # These files are failing to translate their lexer output into the lexer # output expected by the parser gem, so we'll skip them for now. 
- skip_tokens = %w[ - comments.txt - constants.txt - endless_range_in_conditional.txt - heredoc_with_comment.txt - heredoc_with_escaped_newline_at_start.txt - heredocs_leading_whitespace.txt - heredocs_nested.txt - heredocs_with_ignored_newlines_and_non_empty.txt - indented_file_end.txt - non_alphanumeric_methods.txt - range_begin_open_inclusive.txt - single_quote_heredocs.txt - strings.txt - xstring.txt + skip_tokens = [ + "comments.txt", + "constants.txt", + "endless_range_in_conditional.txt", + "heredoc_with_comment.txt", + "heredoc_with_escaped_newline_at_start.txt", + "heredocs_leading_whitespace.txt", + "heredocs_nested.txt", + "heredocs_with_ignored_newlines_and_non_empty.txt", + "indented_file_end.txt", + "non_alphanumeric_methods.txt", + "range_begin_open_inclusive.txt", + "single_quote_heredocs.txt", + "strings.txt", + "xstring.txt" ] Dir["*.txt", base: base].each do |name| diff --git a/test/prism/ripper_test.rb b/test/prism/ripper_test.rb index f19e8ddbb6d3ff..07238fc3d544e3 100644 --- a/test/prism/ripper_test.rb +++ b/test/prism/ripper_test.rb @@ -25,30 +25,29 @@ class RipperTest < TestCase # Ripper cannot handle named capture groups in regular expressions. "regex.txt", "regex_char_width.txt", - "whitequark/lvar_injecting_match.txt" + "whitequark/lvar_injecting_match.txt", + + # Ripper fails to understand some structures that span across heredocs. 
+ "spanning_heredoc.txt" ] omitted = [ "dos_endings.txt", "heredocs_with_ignored_newlines.txt", + "seattlerb/block_call_dot_op2_brace_block.txt", + "seattlerb/block_command_operation_colon.txt", + "seattlerb/block_command_operation_dot.txt", "seattlerb/heredoc__backslash_dos_format.txt", "seattlerb/heredoc_backslash_nl.txt", "seattlerb/heredoc_nested.txt", "seattlerb/heredoc_squiggly_blank_line_plus_interpolation.txt", - "seattlerb/heredoc_squiggly_no_indent.txt", - "spanning_heredoc.txt", "tilde_heredocs.txt", "unparser/corpus/semantic/dstr.txt", "whitequark/dedenting_heredoc.txt", - "whitequark/parser_bug_640.txt", "whitequark/parser_drops_truncated_parts_of_squiggly_heredoc.txt", "whitequark/parser_slash_slash_n_escaping_in_literals.txt", - "whitequark/slash_newline_in_heredocs.txt", - - "seattlerb/block_call_dot_op2_brace_block.txt", - "seattlerb/block_command_operation_colon.txt", - "seattlerb/block_command_operation_dot.txt", - "whitequark/send_block_chain_cmd.txt" + "whitequark/send_block_chain_cmd.txt", + "whitequark/slash_newline_in_heredocs.txt" ] relatives.each do |relative| diff --git a/test/prism/ruby_parser_test.rb b/test/prism/ruby_parser_test.rb index 89150b2faac33e..1d22f0e7b8729f 100644 --- a/test/prism/ruby_parser_test.rb +++ b/test/prism/ruby_parser_test.rb @@ -71,6 +71,7 @@ class RubyParserTest < TestCase # https://github.com/seattlerb/ruby_parser/issues/344 failures = crlf | %w[ alias.txt + heredocs_with_ignored_newlines.txt method_calls.txt methods.txt multi_write.txt @@ -94,6 +95,7 @@ class RubyParserTest < TestCase whitequark/lvar_injecting_match.txt whitequark/not.txt whitequark/op_asgn_cmd.txt + whitequark/parser_bug_640.txt whitequark/parser_slash_slash_n_escaping_in_literals.txt whitequark/pattern_matching_single_line_allowed_omission_of_parentheses.txt whitequark/pattern_matching_single_line.txt diff --git a/test/prism/snapshots/heredocs_with_ignored_newlines.txt b/test/prism/snapshots/heredocs_with_ignored_newlines.txt index 
00111b1ca54625..cdc0b4faab9279 100644 --- a/test/prism/snapshots/heredocs_with_ignored_newlines.txt +++ b/test/prism/snapshots/heredocs_with_ignored_newlines.txt @@ -11,7 +11,7 @@ │ └── unescaped: "" └── @ InterpolatedStringNode (location: (4,0)-(4,8)) ├── opening_loc: (4,0)-(4,8) = "<<~THERE" - ├── parts: (length: 8) + ├── parts: (length: 9) │ ├── @ StringNode (location: (5,0)-(6,0)) │ │ ├── flags: ∅ │ │ ├── opening_loc: ∅ @@ -42,12 +42,18 @@ │ │ ├── content_loc: (9,0)-(10,0) = "\n" │ │ ├── closing_loc: ∅ │ │ └── unescaped: "\n" - │ ├── @ StringNode (location: (10,0)-(12,0)) + │ ├── @ StringNode (location: (10,0)-(11,0)) │ │ ├── flags: ∅ │ │ ├── opening_loc: ∅ - │ │ ├── content_loc: (10,0)-(12,0) = " <<~BUT\\\n but\n" + │ │ ├── content_loc: (10,0)-(11,0) = " <<~BUT\\\n" │ │ ├── closing_loc: ∅ - │ │ └── unescaped: "<<~BUT but\n" + │ │ └── unescaped: "<<~BUT" + │ ├── @ StringNode (location: (11,0)-(12,0)) + │ │ ├── flags: ∅ + │ │ ├── opening_loc: ∅ + │ │ ├── content_loc: (11,0)-(12,0) = " but\n" + │ │ ├── closing_loc: ∅ + │ │ └── unescaped: " but\n" │ ├── @ StringNode (location: (12,0)-(13,0)) │ │ ├── flags: ∅ │ │ ├── opening_loc: ∅ diff --git a/test/prism/snapshots/whitequark/parser_bug_640.txt b/test/prism/snapshots/whitequark/parser_bug_640.txt index 0320011e2e295d..a9d3f957e83910 100644 --- a/test/prism/snapshots/whitequark/parser_bug_640.txt +++ b/test/prism/snapshots/whitequark/parser_bug_640.txt @@ -3,9 +3,19 @@ └── statements: @ StatementsNode (location: (1,0)-(1,6)) └── body: (length: 1) - └── @ StringNode (location: (1,0)-(1,6)) - ├── flags: ∅ + └── @ InterpolatedStringNode (location: (1,0)-(1,6)) ├── opening_loc: (1,0)-(1,6) = "<<~FOO" - ├── content_loc: (2,0)-(4,0) = " baz\\\n qux\n" - ├── closing_loc: (4,0)-(5,0) = "FOO\n" - └── unescaped: "baz qux\n" + ├── parts: (length: 2) + │ ├── @ StringNode (location: (2,0)-(3,0)) + │ │ ├── flags: ∅ + │ │ ├── opening_loc: ∅ + │ │ ├── content_loc: (2,0)-(3,0) = " baz\\\n" + │ │ ├── closing_loc: ∅ + │ │ └── 
unescaped: "baz" + │ └── @ StringNode (location: (3,0)-(4,0)) + │ ├── flags: ∅ + │ ├── opening_loc: ∅ + │ ├── content_loc: (3,0)-(4,0) = " qux\n" + │ ├── closing_loc: ∅ + │ └── unescaped: "qux\n" + └── closing_loc: (4,0)-(5,0) = "FOO\n" diff --git a/test/prism/snapshots/whitequark/slash_newline_in_heredocs.txt b/test/prism/snapshots/whitequark/slash_newline_in_heredocs.txt index 58a134dd62b2ac..8d6fce2ba9676b 100644 --- a/test/prism/snapshots/whitequark/slash_newline_in_heredocs.txt +++ b/test/prism/snapshots/whitequark/slash_newline_in_heredocs.txt @@ -11,13 +11,19 @@ │ └── unescaped: " 1 2\n 3\n" └── @ InterpolatedStringNode (location: (8,0)-(8,4)) ├── opening_loc: (8,0)-(8,4) = "<<~E" - ├── parts: (length: 2) - │ ├── @ StringNode (location: (9,0)-(11,0)) + ├── parts: (length: 3) + │ ├── @ StringNode (location: (9,0)-(10,0)) │ │ ├── flags: ∅ │ │ ├── opening_loc: ∅ - │ │ ├── content_loc: (9,0)-(11,0) = " 1 \\\n 2\n" + │ │ ├── content_loc: (9,0)-(10,0) = " 1 \\\n" │ │ ├── closing_loc: ∅ - │ │ └── unescaped: "1 2\n" + │ │ └── unescaped: "1 " + │ ├── @ StringNode (location: (10,0)-(11,0)) + │ │ ├── flags: ∅ + │ │ ├── opening_loc: ∅ + │ │ ├── content_loc: (10,0)-(11,0) = " 2\n" + │ │ ├── closing_loc: ∅ + │ │ └── unescaped: "2\n" │ └── @ StringNode (location: (11,0)-(12,0)) │ ├── flags: ∅ │ ├── opening_loc: ∅ diff --git a/test/prism/unescape_test.rb b/test/prism/unescape_test.rb index 2a352c52347e84..72ad780d8bbd4a 100644 --- a/test/prism/unescape_test.rb +++ b/test/prism/unescape_test.rb @@ -230,6 +230,8 @@ def assert_unescape(context, escape) else assert_equal expected.bytes, actual.bytes, message end + rescue Exception + binding.irb end end end diff --git a/test/rubygems/test_gem_ext_cargo_builder/custom_name/ext/custom_name_lib/Cargo.lock b/test/rubygems/test_gem_ext_cargo_builder/custom_name/ext/custom_name_lib/Cargo.lock index ed0ce757331a50..bc76ee824e577f 100644 --- a/test/rubygems/test_gem_ext_cargo_builder/custom_name/ext/custom_name_lib/Cargo.lock +++ 
b/test/rubygems/test_gem_ext_cargo_builder/custom_name/ext/custom_name_lib/Cargo.lock @@ -152,18 +152,18 @@ dependencies = [ [[package]] name = "rb-sys" -version = "0.9.89" +version = "0.9.90" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "0d197f2c03751ef006f29d593d22aa9068c9c358e04ca503afea0329c366147c" +checksum = "55d933382388cc7a6fdfd54e222eca7994791ac4b9ce5c9e8df280c739d86bbe" dependencies = [ "rb-sys-build", ] [[package]] name = "rb-sys-build" -version = "0.9.89" +version = "0.9.90" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "2b50caf8fd028f12abe00d6debe2ae2adf6202c9ca3caa59487eda710d90fa28" +checksum = "ebc5a7e3a875419baaa0d8cc606cdfb9361b444cb7e5abcf0de4693025887374" dependencies = [ "bindgen", "lazy_static", diff --git a/test/rubygems/test_gem_ext_cargo_builder/custom_name/ext/custom_name_lib/Cargo.toml b/test/rubygems/test_gem_ext_cargo_builder/custom_name/ext/custom_name_lib/Cargo.toml index e2c72480c1b604..424c61e45226cc 100644 --- a/test/rubygems/test_gem_ext_cargo_builder/custom_name/ext/custom_name_lib/Cargo.toml +++ b/test/rubygems/test_gem_ext_cargo_builder/custom_name/ext/custom_name_lib/Cargo.toml @@ -7,4 +7,4 @@ edition = "2021" crate-type = ["cdylib"] [dependencies] -rb-sys = "0.9.89" +rb-sys = "0.9.90"