Skip to content

Commit

Permalink
add abillity to work with mysql2 adapter
Browse files Browse the repository at this point in the history
  • Loading branch information
prog-supdex committed Oct 20, 2023
1 parent 7fd17b0 commit 452edc4
Show file tree
Hide file tree
Showing 5 changed files with 140 additions and 73 deletions.
79 changes: 79 additions & 0 deletions lib/activerecord_slotted_counters/adapters/base_adapter.rb
Original file line number Diff line number Diff line change
@@ -0,0 +1,79 @@
# frozen_string_literal: true

module ActiveRecordSlottedCounters
module Adapters
class BaseAdapter
attr_reader :klass, :current_adapter_name

def initialize(klass, current_adapter_name)
@klass = klass
@current_adapter_name = current_adapter_name
end

def apply?
raise NoMethodError
end

def bulk_insert(attributes, on_duplicate: nil, unique_by: nil)
raise NoMethodError
end

def wrap_column_name(value)
"EXCLUDED.#{value}"
end

private

def build_base_sql(attributes)
keys = attributes.first.keys + klass.all_timestamp_attributes_in_model

current_time = klass.current_time_from_proper_timezone
data = attributes.map { |attr| attr.values + [current_time, current_time] }

columns = columns_for_attributes(keys)

fields_str = quote_column_names(columns)
values_str = quote_many_records(columns, data)

<<~SQL
INSERT INTO #{klass.quoted_table_name}
(#{fields_str})
VALUES #{values_str}
SQL
end

def unique_indexes
klass.connection.schema_cache.indexes(klass.table_name).select(&:unique)
end

def columns_for_attributes(attributes)
attributes.map do |attribute|
klass.column_for_attribute(attribute)
end
end

def quote_column_names(columns, table_name: false)
columns.map do |column|
column_name = klass.connection.quote_column_name(column.name)
if table_name
"#{klass.quoted_table_name}.#{column_name}"
else
column_name
end
end.join(",")
end

def quote_record(columns, record_values)
values_str = record_values.each_with_index.map do |value, i|
type = klass.connection.lookup_cast_type_from_column(columns[i])
klass.connection.quote(type.serialize(value))
end.join(",")
"(#{values_str})"
end

def quote_many_records(columns, data)
data.map { |values| quote_record(columns, values) }.join(",")
end
end
end
end
30 changes: 30 additions & 0 deletions lib/activerecord_slotted_counters/adapters/mysql_upsert.rb
Original file line number Diff line number Diff line change
@@ -0,0 +1,30 @@
# frozen_string_literal: true

module ActiveRecordSlottedCounters
module Adapters
class MysqlUpsert < BaseAdapter
def apply?
return false unless defined?(ActiveRecord::ConnectionAdapters::Mysql2Adapter)

current_adapter_name == ActiveRecord::ConnectionAdapters::Mysql2Adapter::ADAPTER_NAME
end

def bulk_insert(attributes, on_duplicate: nil, unique_by: nil)
raise ArgumentError, "Values must not be empty" if attributes.empty?

sql = build_base_sql(attributes)

if on_duplicate.present?
sql += " ON DUPLICATE KEY UPDATE #{on_duplicate};"
end

# insert/update and return amount of updated rows
klass.connection.update(sql)
end

def wrap_column_name(value)
"VALUES(#{value})"
end
end
end
end
66 changes: 8 additions & 58 deletions lib/activerecord_slotted_counters/adapters/pg_upsert.rb
Original file line number Diff line number Diff line change
Expand Up @@ -2,35 +2,18 @@

module ActiveRecordSlottedCounters
module Adapters
class PgUpsert
attr_reader :klass

def initialize(klass)
@klass = klass
end

class PgUpsert < BaseAdapter
def apply?
ActiveRecord::VERSION::MAJOR < 7 && klass.connection.adapter_name == ActiveRecord::ConnectionAdapters::PostgreSQLAdapter::ADAPTER_NAME
return false if ActiveRecord::VERSION::MAJOR >= 7
return false unless defined?(ActiveRecord::ConnectionAdapters::PostgreSQLAdapter)

current_adapter_name == ActiveRecord::ConnectionAdapters::PostgreSQLAdapter::ADAPTER_NAME
end

def bulk_insert(attributes, on_duplicate: nil, unique_by: nil)
raise ArgumentError, "Values must not be empty" if attributes.empty?

keys = attributes.first.keys + klass.all_timestamp_attributes_in_model

current_time = klass.current_time_from_proper_timezone
data = attributes.map { |attr| attr.values + [current_time, current_time] }

columns = columns_for_attributes(keys)

fields_str = quote_column_names(columns)
values_str = quote_many_records(columns, data)

sql = <<~SQL
INSERT INTO #{klass.quoted_table_name}
(#{fields_str})
VALUES #{values_str}
SQL
sql = build_base_sql(attributes)

if unique_by.present?
index = unique_indexes.find { |i| i.name.to_sym == unique_by }
Expand All @@ -46,42 +29,9 @@ def bulk_insert(attributes, on_duplicate: nil, unique_by: nil)

sql += " RETURNING \"id\""

klass.connection.exec_query(sql)
end

private

def unique_indexes
klass.connection.schema_cache.indexes(klass.table_name).select(&:unique)
end

def columns_for_attributes(attributes)
attributes.map do |attribute|
klass.column_for_attribute(attribute)
end
end

def quote_column_names(columns, table_name: false)
columns.map do |column|
column_name = klass.connection.quote_column_name(column.name)
if table_name
"#{klass.quoted_table_name}.#{column_name}"
else
column_name
end
end.join(",")
end

def quote_record(columns, record_values)
values_str = record_values.each_with_index.map do |value, i|
type = klass.connection.lookup_cast_type_from_column(columns[i])
klass.connection.quote(type.serialize(value))
end.join(",")
"(#{values_str})"
end
result = klass.connection.exec_query(sql)

def quote_many_records(columns, data)
data.map { |values| quote_record(columns, values) }.join(",")
result.rows.count
end
end
end
Expand Down
20 changes: 12 additions & 8 deletions lib/activerecord_slotted_counters/adapters/rails_upsert.rb
Original file line number Diff line number Diff line change
Expand Up @@ -2,19 +2,23 @@

module ActiveRecordSlottedCounters
module Adapters
class RailsUpsert
attr_reader :klass

def initialize(klass)
@klass = klass
end

class RailsUpsert < BaseAdapter
def apply?
return false if mysql_connection?

ActiveRecord::VERSION::MAJOR >= 7
end

def bulk_insert(attributes, on_duplicate: nil, unique_by: nil)
klass.upsert_all(attributes, on_duplicate: on_duplicate, unique_by: unique_by)
klass.upsert_all(attributes, on_duplicate: on_duplicate, unique_by: unique_by).rows.count
end

private

def mysql_connection?
return false unless defined?(ActiveRecord::ConnectionAdapters::Mysql2Adapter)

current_adapter_name == ActiveRecord::ConnectionAdapters::Mysql2Adapter::ADAPTER_NAME
end
end
end
Expand Down
18 changes: 11 additions & 7 deletions lib/activerecord_slotted_counters/has_slotted_counter.rb
Original file line number Diff line number Diff line change
Expand Up @@ -3,8 +3,10 @@
require "active_support"
require "activerecord_slotted_counters/utils"

require "activerecord_slotted_counters/adapters/base_adapter"
require "activerecord_slotted_counters/adapters/rails_upsert"
require "activerecord_slotted_counters/adapters/pg_upsert"
require "activerecord_slotted_counters/adapters/mysql_upsert"

module ActiveRecordSlottedCounters
class SlottedCounter < ::ActiveRecord::Base
Expand All @@ -24,7 +26,8 @@ def message

class << self
def bulk_insert(attributes)
on_duplicate_clause = "count = slotted_counters.count + excluded.count"
on_duplicate_clause =
"count = slotted_counters.count + #{slotted_counter_db_adapter.wrap_column_name("count")}"

slotted_counter_db_adapter.bulk_insert(
attributes,
Expand All @@ -42,14 +45,17 @@ def slotted_counter_db_adapter
def set_slotted_counter_db_adapter
available_adapters = [
ActiveRecordSlottedCounters::Adapters::RailsUpsert,
ActiveRecordSlottedCounters::Adapters::PgUpsert
ActiveRecordSlottedCounters::Adapters::PgUpsert,
ActiveRecordSlottedCounters::Adapters::MysqlUpsert
]

current_adapter_name = connection.adapter_name

adapter = available_adapters
.map { |adapter| adapter.new(self) }
.map { |adapter| adapter.new(self, current_adapter_name) }
.detect { |adapter| adapter.apply? }

raise NotSupportedAdapter.new(connection.adapter_name) if adapter.nil?
raise NotSupportedAdapter.new(current_adapter_name) if adapter.nil?

adapter
end
Expand Down Expand Up @@ -205,9 +211,7 @@ def reset_slotted_counters(id, *counters, touch: nil)
def insert_counters_records(ids, counters)
counters_params = prepare_slotted_counters_params(ids, counters)

result = ActiveRecordSlottedCounters::SlottedCounter.bulk_insert(counters_params)

result.rows.count
ActiveRecordSlottedCounters::SlottedCounter.bulk_insert(counters_params)
end

def remove_counters_records(ids, counter_name)
Expand Down

0 comments on commit 452edc4

Please sign in to comment.