Add missing format method and load method implementation (#197)
* feat: added missing format and load implementations for json and parquet

* fix: moved implementation from base to concrete impl

* fix: resolve review comments
eredzik authored Nov 4, 2024
1 parent d2553fe · commit 8292241
Showing 3 changed files with 41 additions and 1 deletion.
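In user-facing terms, the commit lets the reader resolve the source format from a preceding `format(...)` call instead of requiring a `format=` argument to `load()` or extension-based inference. A minimal usage sketch, assuming sqlframe's DuckDB session (`sqlframe.duckdb.DuckDBSession`) and an illustrative local file `events.json` with one JSON object per line; neither appears in the commit itself:

from sqlframe.duckdb import DuckDBSession

session = DuckDBSession()

# Previously the format had to be passed directly to load() (or inferred
# from the file extension), because the base reader had no format() method.
df_explicit = session.read.load("events.json", format="json")

# With this commit the PySpark-style fluent form also works: format()
# records the choice on the reader and load() picks it up.
df_fluent = session.read.format("json").load("events.json")

df_fluent.show()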
sqlframe/base/mixins/readwriter_mixins.py (1 addition, 1 deletion)
@@ -75,7 +75,7 @@ def load(

         assert path is not None, "path is required"
         assert isinstance(path, str), "path must be a string"
-        format = format or _infer_format(path)
+        format = format or self.state_format_to_read or _infer_format(path)
         kwargs = {k: v for k, v in options.items() if v is not None}
         if format == "json":
             df = pd.read_json(path, lines=True, **kwargs)  # type: ignore
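The one-line change above establishes a precedence chain: an explicit `format=` argument to `load()` wins, then whatever `format()` stored in `state_format_to_read`, and only then inference from the file extension. A standalone sketch of that resolution order, using a hypothetical `_infer_format` stand-in (the real helper is part of sqlframe and not shown in this diff):

import typing as t


def _infer_format(path: str) -> str:
    # Hypothetical stand-in: derive the format from the file extension.
    suffix = path.rsplit(".", 1)[-1].lower()
    if suffix not in {"json", "parquet", "csv"}:
        raise ValueError(f"Cannot infer format from path: {path}")
    return suffix


def resolve_format(
    explicit: t.Optional[str],
    state_format_to_read: t.Optional[str],
    path: str,
) -> str:
    # Mirrors `format = format or self.state_format_to_read or _infer_format(path)`:
    # explicit argument first, then the format() state, then extension inference.
    return explicit or state_format_to_read or _infer_format(path)


assert resolve_format("parquet", "json", "data.json") == "parquet"
assert resolve_format(None, "json", "data.parquet") == "json"
assert resolve_format(None, None, "data.parquet") == "parquet"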
sqlframe/base/readerwriter.py (39 additions, 0 deletions)
@@ -36,6 +36,7 @@
 class _BaseDataFrameReader(t.Generic[SESSION, DF]):
     def __init__(self, spark: SESSION):
         self._session = spark
+        self.state_format_to_read: t.Optional[str] = None

     @property
     def session(self) -> SESSION:
@@ -67,6 +68,44 @@ def _to_casted_columns(self, column_mapping: t.Dict) -> t.List[Column]:
             for k, v in column_mapping.items()
         ]

+    def format(self, source: str) -> "Self":
+        """Specifies the input data source format.
+
+        .. versionadded:: 1.4.0
+
+        .. versionchanged:: 3.4.0
+            Supports Spark Connect.
+
+        Parameters
+        ----------
+        source : str
+            string, name of the data source, e.g. 'json', 'parquet'.
+
+        Examples
+        --------
+        >>> spark.read.format('json')
+        <...readwriter.DataFrameReader object ...>
+
+        Write a DataFrame into a JSON file and read it back.
+
+        >>> import tempfile
+        >>> with tempfile.TemporaryDirectory() as d:
+        ...     # Write a DataFrame into a JSON file
+        ...     spark.createDataFrame(
+        ...         [{"age": 100, "name": "Hyukjin Kwon"}]
+        ...     ).write.mode("overwrite").format("json").save(d)
+        ...
+        ...     # Read the JSON file as a DataFrame.
+        ...     spark.read.format('json').load(d).show()
+        +---+------------+
+        |age|        name|
+        +---+------------+
+        |100|Hyukjin Kwon|
+        +---+------------+
+        """
+        self.state_format_to_read = source
+        return self

     def load(
         self,
         path: t.Optional[PathOrPaths] = None,
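`format()` follows the usual fluent-builder pattern: it records the requested source on the reader instance and returns `self`, so `load()` can be chained directly after it and fall back to the stored value. A stripped-down, self-contained sketch of that pattern; everything except the `state_format_to_read` attribute name is illustrative rather than sqlframe's actual class:

import typing as t


class TinyReader:
    """Illustrative reader mimicking the format()/load() handshake."""

    def __init__(self) -> None:
        self.state_format_to_read: t.Optional[str] = None

    def format(self, source: str) -> "TinyReader":
        # Remember the requested format and return self to allow chaining.
        self.state_format_to_read = source
        return self

    def load(self, path: str, format: t.Optional[str] = None) -> str:
        # An explicit argument wins; otherwise fall back to the stored state.
        resolved = format or self.state_format_to_read
        if resolved is None:
            raise ValueError("No format given and none stored via format()")
        return f"reading {path} as {resolved}"


print(TinyReader().format("parquet").load("table.parquet"))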
sqlframe/duckdb/readwriter.py (1 addition, 0 deletions)
@@ -72,6 +72,7 @@ def load(
         |100|NULL|
         +---+----+
         """
+        format = format or self.state_format_to_read
         if schema:
             column_mapping = ensure_column_mapping(schema)
             select_column_mapping = column_mapping.copy()
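Because the DuckDB reader's `load()` also accepts a `schema`, the stored format can be combined with an explicit column mapping. A short sketch under the same assumptions as the earlier example (the file path and DDL schema string are illustrative):

from sqlframe.duckdb import DuckDBSession

session = DuckDBSession()

# format() supplies the source format; load() applies the schema through
# ensure_column_mapping(), as in the hunk above.
df = session.read.format("parquet").load(
    "/tmp/people.parquet", schema="age INT, name STRING"
)
df.show()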
