diff --git a/lib/activerecord_slotted_counters/adapters/mysql_upsert.rb b/lib/activerecord_slotted_counters/adapters/mysql_upsert.rb new file mode 100644 index 0000000..b5cb007 --- /dev/null +++ b/lib/activerecord_slotted_counters/adapters/mysql_upsert.rb @@ -0,0 +1,82 @@ +# frozen_string_literal: true + +module ActiveRecordSlottedCounters + module Adapters + class MysqlUpsert + attr_reader :klass, :current_adapter_name + + def initialize(klass, current_adapter_name) + @klass = klass + @current_adapter_name = current_adapter_name + end + + def apply? + return false unless defined?(ActiveRecord::ConnectionAdapters::Mysql2Adapter) + + current_adapter_name == ActiveRecord::ConnectionAdapters::Mysql2Adapter::ADAPTER_NAME + end + + def bulk_insert(attributes, on_duplicate: nil, unique_by: nil) + raise ArgumentError, "Values must not be empty" if attributes.empty? + + keys = attributes.first.keys + klass.all_timestamp_attributes_in_model + + current_time = klass.current_time_from_proper_timezone + data = attributes.map { |attr| attr.values + [current_time, current_time] } + + columns = columns_for_attributes(keys) + + fields_str = quote_column_names(columns) + values_str = quote_many_records(columns, data) + + sql = <<~SQL + INSERT INTO #{klass.quoted_table_name} + (#{fields_str}) + VALUES #{values_str} + SQL + + if on_duplicate.present? + sql += " ON DUPLICATE KEY UPDATE #{on_duplicate};" + end + + # insert/update and return amount of updated rows + klass.connection.update(sql) + end + + def wrap_column_name(value) + "VALUES(#{value})" + end + + private + + def columns_for_attributes(attributes) + attributes.map do |attribute| + klass.column_for_attribute(attribute) + end + end + + def quote_column_names(columns, table_name: false) + columns.map do |column| + column_name = klass.connection.quote_column_name(column.name) + if table_name + "#{klass.quoted_table_name}.#{column_name}" + else + column_name + end + end.join(",") + end + + def quote_record(columns, record_values) + values_str = record_values.each_with_index.map do |value, i| + type = klass.connection.lookup_cast_type_from_column(columns[i]) + klass.connection.quote(type.serialize(value)) + end.join(",") + "(#{values_str})" + end + + def quote_many_records(columns, data) + data.map { |values| quote_record(columns, values) }.join(",") + end + end + end +end diff --git a/lib/activerecord_slotted_counters/adapters/pg_upsert.rb b/lib/activerecord_slotted_counters/adapters/pg_upsert.rb index b16ad8f..dce9cd3 100644 --- a/lib/activerecord_slotted_counters/adapters/pg_upsert.rb +++ b/lib/activerecord_slotted_counters/adapters/pg_upsert.rb @@ -3,14 +3,18 @@ module ActiveRecordSlottedCounters module Adapters class PgUpsert - attr_reader :klass + attr_reader :klass, :current_adapter_name - def initialize(klass) + def initialize(klass, current_adapter_name) @klass = klass + @current_adapter_name = current_adapter_name end def apply? - ActiveRecord::VERSION::MAJOR < 7 && klass.connection.adapter_name == ActiveRecord::ConnectionAdapters::PostgreSQLAdapter::ADAPTER_NAME + return false if ActiveRecord::VERSION::MAJOR >= 7 + return false unless defined?(ActiveRecord::ConnectionAdapters::PostgreSQLAdapter) + + current_adapter_name == ActiveRecord::ConnectionAdapters::PostgreSQLAdapter::ADAPTER_NAME end def bulk_insert(attributes, on_duplicate: nil, unique_by: nil) diff --git a/lib/activerecord_slotted_counters/adapters/rails_upsert.rb b/lib/activerecord_slotted_counters/adapters/rails_upsert.rb index 1dd103e..93be24f 100644 --- a/lib/activerecord_slotted_counters/adapters/rails_upsert.rb +++ b/lib/activerecord_slotted_counters/adapters/rails_upsert.rb @@ -3,10 +3,11 @@ module ActiveRecordSlottedCounters module Adapters class RailsUpsert - attr_reader :klass + attr_reader :klass, :current_adapter_name - def initialize(klass) + def initialize(klass, current_adapter_name) @klass = klass + @current_adapter_name = current_adapter_name end def apply? @@ -14,7 +15,7 @@ def apply? end def bulk_insert(attributes, on_duplicate: nil, unique_by: nil) - klass.upsert_all(attributes, on_duplicate: on_duplicate, unique_by: unique_by) + klass.upsert_all(attributes, on_duplicate: on_duplicate, unique_by: unique_by).rows.count end end end diff --git a/lib/activerecord_slotted_counters/has_slotted_counter.rb b/lib/activerecord_slotted_counters/has_slotted_counter.rb index 259db8c..b52c423 100644 --- a/lib/activerecord_slotted_counters/has_slotted_counter.rb +++ b/lib/activerecord_slotted_counters/has_slotted_counter.rb @@ -5,6 +5,7 @@ require "activerecord_slotted_counters/adapters/rails_upsert" require "activerecord_slotted_counters/adapters/pg_upsert" +require "activerecord_slotted_counters/adapters/mysql_upsert" module ActiveRecordSlottedCounters class SlottedCounter < ::ActiveRecord::Base @@ -24,7 +25,8 @@ def message class << self def bulk_insert(attributes) - on_duplicate_clause = "count = slotted_counters.count + excluded.count" + on_duplicate_clause = + "count = slotted_counters.count + #{slotted_counter_db_adapter.wrap_column_name("count")}" slotted_counter_db_adapter.bulk_insert( attributes, @@ -41,15 +43,18 @@ def slotted_counter_db_adapter def set_slotted_counter_db_adapter available_adapters = [ - ActiveRecordSlottedCounters::Adapters::RailsUpsert, - ActiveRecordSlottedCounters::Adapters::PgUpsert + ActiveRecordSlottedCounters::Adapters::MysqlUpsert, + ActiveRecordSlottedCounters::Adapters::PgUpsert, + ActiveRecordSlottedCounters::Adapters::RailsUpsert ] + current_adapter_name = connection.adapter_name + adapter = available_adapters - .map { |adapter| adapter.new(self) } + .map { |adapter| adapter.new(self, current_adapter_name) } .detect { |adapter| adapter.apply? } - raise NotSupportedAdapter.new(connection.adapter_name) if adapter.nil? + raise NotSupportedAdapter.new(current_adapter_name) if adapter.nil? adapter end @@ -205,9 +210,7 @@ def reset_slotted_counters(id, *counters, touch: nil) def insert_counters_records(ids, counters) counters_params = prepare_slotted_counters_params(ids, counters) - result = ActiveRecordSlottedCounters::SlottedCounter.bulk_insert(counters_params) - - result.rows.count + ActiveRecordSlottedCounters::SlottedCounter.bulk_insert(counters_params) end def remove_counters_records(ids, counter_name)