diff --git a/lib/activerecord_slotted_counters/adapters/base_adapter.rb b/lib/activerecord_slotted_counters/adapters/base_adapter.rb new file mode 100644 index 0000000..cb2d734 --- /dev/null +++ b/lib/activerecord_slotted_counters/adapters/base_adapter.rb @@ -0,0 +1,79 @@ +# frozen_string_literal: true + +module ActiveRecordSlottedCounters + module Adapters + class BaseAdapter + attr_reader :klass, :current_adapter_name + + def initialize(klass, current_adapter_name) + @klass = klass + @current_adapter_name = current_adapter_name + end + + def apply? + raise NoMethodError + end + + def bulk_insert(attributes, on_duplicate: nil, unique_by: nil) + raise NoMethodError + end + + def wrap_column_name(value) + "EXCLUDED.#{value}" + end + + private + + def build_base_sql(attributes) + keys = attributes.first.keys + klass.all_timestamp_attributes_in_model + + current_time = klass.current_time_from_proper_timezone + data = attributes.map { |attr| attr.values + [current_time, current_time] } + + columns = columns_for_attributes(keys) + + fields_str = quote_column_names(columns) + values_str = quote_many_records(columns, data) + + <<~SQL + INSERT INTO #{klass.quoted_table_name} + (#{fields_str}) + VALUES #{values_str} + SQL + end + + def unique_indexes + klass.connection.schema_cache.indexes(klass.table_name).select(&:unique) + end + + def columns_for_attributes(attributes) + attributes.map do |attribute| + klass.column_for_attribute(attribute) + end + end + + def quote_column_names(columns, table_name: false) + columns.map do |column| + column_name = klass.connection.quote_column_name(column.name) + if table_name + "#{klass.quoted_table_name}.#{column_name}" + else + column_name + end + end.join(",") + end + + def quote_record(columns, record_values) + values_str = record_values.each_with_index.map do |value, i| + type = klass.connection.lookup_cast_type_from_column(columns[i]) + klass.connection.quote(type.serialize(value)) + end.join(",") + "(#{values_str})" + end + + def quote_many_records(columns, data) + data.map { |values| quote_record(columns, values) }.join(",") + end + end + end +end diff --git a/lib/activerecord_slotted_counters/adapters/mysql_upsert.rb b/lib/activerecord_slotted_counters/adapters/mysql_upsert.rb new file mode 100644 index 0000000..a5d894c --- /dev/null +++ b/lib/activerecord_slotted_counters/adapters/mysql_upsert.rb @@ -0,0 +1,30 @@ +# frozen_string_literal: true + +module ActiveRecordSlottedCounters + module Adapters + class MysqlUpsert < BaseAdapter + def apply? + return false unless defined?(ActiveRecord::ConnectionAdapters::Mysql2Adapter) + + current_adapter_name == ActiveRecord::ConnectionAdapters::Mysql2Adapter::ADAPTER_NAME + end + + def bulk_insert(attributes, on_duplicate: nil, unique_by: nil) + raise ArgumentError, "Values must not be empty" if attributes.empty? + + sql = build_base_sql(attributes) + + if on_duplicate.present? + sql += " ON DUPLICATE KEY UPDATE #{on_duplicate};" + end + + # insert/update and return amount of updated rows + klass.connection.update(sql) + end + + def wrap_column_name(value) + "VALUES(#{value})" + end + end + end +end diff --git a/lib/activerecord_slotted_counters/adapters/pg_upsert.rb b/lib/activerecord_slotted_counters/adapters/pg_upsert.rb index b16ad8f..65c4317 100644 --- a/lib/activerecord_slotted_counters/adapters/pg_upsert.rb +++ b/lib/activerecord_slotted_counters/adapters/pg_upsert.rb @@ -2,35 +2,18 @@ module ActiveRecordSlottedCounters module Adapters - class PgUpsert - attr_reader :klass - - def initialize(klass) - @klass = klass - end - + class PgUpsert < BaseAdapter def apply? - ActiveRecord::VERSION::MAJOR < 7 && klass.connection.adapter_name == ActiveRecord::ConnectionAdapters::PostgreSQLAdapter::ADAPTER_NAME + return false if ActiveRecord::VERSION::MAJOR >= 7 + return false unless defined?(ActiveRecord::ConnectionAdapters::PostgreSQLAdapter) + + current_adapter_name == ActiveRecord::ConnectionAdapters::PostgreSQLAdapter::ADAPTER_NAME end def bulk_insert(attributes, on_duplicate: nil, unique_by: nil) raise ArgumentError, "Values must not be empty" if attributes.empty? - keys = attributes.first.keys + klass.all_timestamp_attributes_in_model - - current_time = klass.current_time_from_proper_timezone - data = attributes.map { |attr| attr.values + [current_time, current_time] } - - columns = columns_for_attributes(keys) - - fields_str = quote_column_names(columns) - values_str = quote_many_records(columns, data) - - sql = <<~SQL - INSERT INTO #{klass.quoted_table_name} - (#{fields_str}) - VALUES #{values_str} - SQL + sql = build_base_sql(attributes) if unique_by.present? index = unique_indexes.find { |i| i.name.to_sym == unique_by } @@ -46,42 +29,9 @@ def bulk_insert(attributes, on_duplicate: nil, unique_by: nil) sql += " RETURNING \"id\"" - klass.connection.exec_query(sql) - end - - private - - def unique_indexes - klass.connection.schema_cache.indexes(klass.table_name).select(&:unique) - end - - def columns_for_attributes(attributes) - attributes.map do |attribute| - klass.column_for_attribute(attribute) - end - end - - def quote_column_names(columns, table_name: false) - columns.map do |column| - column_name = klass.connection.quote_column_name(column.name) - if table_name - "#{klass.quoted_table_name}.#{column_name}" - else - column_name - end - end.join(",") - end - - def quote_record(columns, record_values) - values_str = record_values.each_with_index.map do |value, i| - type = klass.connection.lookup_cast_type_from_column(columns[i]) - klass.connection.quote(type.serialize(value)) - end.join(",") - "(#{values_str})" - end + result = klass.connection.exec_query(sql) - def quote_many_records(columns, data) - data.map { |values| quote_record(columns, values) }.join(",") + result.rows.count end end end diff --git a/lib/activerecord_slotted_counters/adapters/rails_upsert.rb b/lib/activerecord_slotted_counters/adapters/rails_upsert.rb index 1dd103e..f00d276 100644 --- a/lib/activerecord_slotted_counters/adapters/rails_upsert.rb +++ b/lib/activerecord_slotted_counters/adapters/rails_upsert.rb @@ -2,19 +2,23 @@ module ActiveRecordSlottedCounters module Adapters - class RailsUpsert - attr_reader :klass - - def initialize(klass) - @klass = klass - end - + class RailsUpsert < BaseAdapter def apply? + return false if mysql_connection? + ActiveRecord::VERSION::MAJOR >= 7 end def bulk_insert(attributes, on_duplicate: nil, unique_by: nil) - klass.upsert_all(attributes, on_duplicate: on_duplicate, unique_by: unique_by) + klass.upsert_all(attributes, on_duplicate: on_duplicate, unique_by: unique_by).rows.count + end + + private + + def mysql_connection? + return false unless defined?(ActiveRecord::ConnectionAdapters::Mysql2Adapter) + + current_adapter_name == ActiveRecord::ConnectionAdapters::Mysql2Adapter::ADAPTER_NAME end end end diff --git a/lib/activerecord_slotted_counters/has_slotted_counter.rb b/lib/activerecord_slotted_counters/has_slotted_counter.rb index 259db8c..a31c6e4 100644 --- a/lib/activerecord_slotted_counters/has_slotted_counter.rb +++ b/lib/activerecord_slotted_counters/has_slotted_counter.rb @@ -3,8 +3,10 @@ require "active_support" require "activerecord_slotted_counters/utils" +require "activerecord_slotted_counters/adapters/base_adapter" require "activerecord_slotted_counters/adapters/rails_upsert" require "activerecord_slotted_counters/adapters/pg_upsert" +require "activerecord_slotted_counters/adapters/mysql_upsert" module ActiveRecordSlottedCounters class SlottedCounter < ::ActiveRecord::Base @@ -24,7 +26,8 @@ def message class << self def bulk_insert(attributes) - on_duplicate_clause = "count = slotted_counters.count + excluded.count" + on_duplicate_clause = + "count = slotted_counters.count + #{slotted_counter_db_adapter.wrap_column_name("count")}" slotted_counter_db_adapter.bulk_insert( attributes, @@ -42,14 +45,17 @@ def slotted_counter_db_adapter def set_slotted_counter_db_adapter available_adapters = [ ActiveRecordSlottedCounters::Adapters::RailsUpsert, - ActiveRecordSlottedCounters::Adapters::PgUpsert + ActiveRecordSlottedCounters::Adapters::PgUpsert, + ActiveRecordSlottedCounters::Adapters::MysqlUpsert ] + current_adapter_name = connection.adapter_name + adapter = available_adapters - .map { |adapter| adapter.new(self) } + .map { |adapter| adapter.new(self, current_adapter_name) } .detect { |adapter| adapter.apply? } - raise NotSupportedAdapter.new(connection.adapter_name) if adapter.nil? + raise NotSupportedAdapter.new(current_adapter_name) if adapter.nil? adapter end @@ -205,9 +211,7 @@ def reset_slotted_counters(id, *counters, touch: nil) def insert_counters_records(ids, counters) counters_params = prepare_slotted_counters_params(ids, counters) - result = ActiveRecordSlottedCounters::SlottedCounter.bulk_insert(counters_params) - - result.rows.count + ActiveRecordSlottedCounters::SlottedCounter.bulk_insert(counters_params) end def remove_counters_records(ids, counter_name)