Skip to content

Commit

Permalink
Updated read_database to use column type if present - closes #34
Browse files Browse the repository at this point in the history
  • Loading branch information
ankane committed Oct 30, 2023
1 parent bd73b10 commit c5abf3e
Show file tree
Hide file tree
Showing 5 changed files with 40 additions and 11 deletions.
7 changes: 6 additions & 1 deletion Gemfile
Original file line number Diff line number Diff line change
Expand Up @@ -6,9 +6,14 @@ gem "rake"
gem "rake-compiler"
gem "minitest"
gem "activerecord"
gem "sqlite3"
gem "numo-narray"
gem "vega"

# Load only the database driver for the adapter chosen through the ADAPTER
# environment variable: "pg" when it is "postgresql", otherwise "sqlite3".
if ENV["ADAPTER"] == "postgresql"
gem "pg"
else
gem "sqlite3"
end

# https://github.com/lsegal/yard/issues/1321
gem "yard", require: false
3 changes: 1 addition & 2 deletions lib/polars/data_frame.rb
Original file line number Diff line number Diff line change
Expand Up @@ -20,7 +20,6 @@ class DataFrame
# this does not yield conclusive results, column orientation is used.
def initialize(data = nil, schema: nil, columns: nil, schema_overrides: nil, orient: nil, infer_schema_length: 100, nan_to_null: false)
schema ||= columns
raise Todo if schema_overrides

# TODO deprecate in favor of read_sql
if defined?(ActiveRecord) && (data.is_a?(ActiveRecord::Relation) || data.is_a?(ActiveRecord::Result))
Expand Down Expand Up @@ -4946,7 +4945,7 @@ def self._unpack_schema(schema, schema_overrides: nil, n_expected: nil, lookup_n
end

if schema_overrides && schema_overrides.any?
raise Todo
column_dtypes.merge!(schema_overrides)
end

column_dtypes.each do |col, dtype|
Expand Down
14 changes: 13 additions & 1 deletion lib/polars/io.rb
Original file line number Diff line number Diff line change
Expand Up @@ -622,10 +622,22 @@ def read_database(query)
raise ArgumentError, "Expected ActiveRecord::Relation, ActiveRecord::Result, or String"
end
data = {}
schema_overrides = {}
result.columns.each_with_index do |k, i|
data[k] = result.rows.map { |r| r[i] }
column_type = result.column_types[i]&.type
polars_type =
case column_type
when :datetime
Datetime
when :string
Utf8
when :integer
Int64
end
schema_overrides[k] = polars_type if polars_type
end
DataFrame.new(data)
DataFrame.new(data, schema_overrides: schema_overrides)
end
alias_method :read_sql, :read_database

Expand Down
18 changes: 13 additions & 5 deletions test/database_test.rb
Original file line number Diff line number Diff line change
Expand Up @@ -8,7 +8,7 @@ def setup
def test_relation
users = create_users
df = Polars::DataFrame.new(User.order(:id))
assert_equal ["id", "name", "number"], df.columns
assert_equal ["id", "name", "number", "joined_at"], df.columns
assert_series users.map(&:id), df["id"]
assert_series users.map(&:name), df["name"]
assert_equal Polars::Int64, df["number"].dtype
Expand All @@ -17,7 +17,7 @@ def test_relation
def test_result
users = create_users
df = Polars::DataFrame.new(User.connection.select_all("SELECT * FROM users ORDER BY id"))
assert_equal ["id", "name", "number"], df.columns
assert_equal ["id", "name", "number", "joined_at"], df.columns
assert_series users.map(&:id), df["id"]
assert_series users.map(&:name), df["name"]
assert_equal Polars::Int64, df["number"].dtype
Expand All @@ -26,7 +26,7 @@ def test_result
def test_read_database_relation
users = create_users
df = Polars.read_database(User.order(:id))
assert_equal ["id", "name", "number"], df.columns
assert_equal ["id", "name", "number", "joined_at"], df.columns
assert_series users.map(&:id), df["id"]
assert_series users.map(&:name), df["name"]
assert_equal Polars::Int64, df["number"].dtype
Expand All @@ -35,7 +35,7 @@ def test_read_database_relation
def test_read_database_result
users = create_users
df = Polars.read_database(User.connection.select_all("SELECT * FROM users ORDER BY id"))
assert_equal ["id", "name", "number"], df.columns
assert_equal ["id", "name", "number", "joined_at"], df.columns
assert_series users.map(&:id), df["id"]
assert_series users.map(&:name), df["name"]
assert_equal Polars::Int64, df["number"].dtype
Expand All @@ -44,12 +44,20 @@ def test_read_database_result
def test_read_database_string
users = create_users
df = Polars.read_database("SELECT * FROM users ORDER BY id")
assert_equal ["id", "name", "number"], df.columns
assert_equal ["id", "name", "number", "joined_at"], df.columns
assert_series users.map(&:id), df["id"]
assert_series users.map(&:name), df["name"]
assert_equal Polars::Int64, df["number"].dtype
end

# Checks the dtype read_database reports for each column when the only row
# in users was created without any attributes.
# NOTE(review): presumably name, number and joined_at are all NULL in that
# row, so their dtypes cannot be inferred from the values — confirm.
def test_read_database_null
# Runs only when ADAPTER is "postgresql".
# NOTE(review): presumably the sqlite3 result does not expose column types — confirm.
skip unless ENV["ADAPTER"] == "postgresql"

User.create!
df = Polars.read_database("SELECT * FROM users ORDER BY id")
assert_equal [Polars::Int64, Polars::Utf8, Polars::Int64, Polars::Datetime], df.dtypes
end

def test_read_database_unsupported
error = assert_raises(ArgumentError) do
Polars.read_database(Object.new)
Expand Down
9 changes: 7 additions & 2 deletions test/test_helper.rb
Original file line number Diff line number Diff line change
Expand Up @@ -9,12 +9,17 @@
ActiveRecord::Base.logger = logger
ActiveRecord::Migration.verbose = ENV["VERBOSE"]

ActiveRecord::Base.establish_connection adapter: "sqlite3", database: ":memory:"
# Pick the test database from the ADAPTER environment variable: the
# "polars_ruby_test" PostgreSQL database when it is "postgresql", otherwise
# an in-memory SQLite database.
if ENV["ADAPTER"] == "postgresql"
ActiveRecord::Base.establish_connection adapter: "postgresql", database: "polars_ruby_test"
else
ActiveRecord::Base.establish_connection adapter: "sqlite3", database: ":memory:"
end

ActiveRecord::Schema.define do
create_table :users do |t|
create_table :users, force: true do |t|
t.string :name
t.integer :number
t.datetime :joined_at
end
end

Expand Down

0 comments on commit c5abf3e

Please sign in to comment.