Skip to content

Commit

Permalink
Add some type hints in metadata.py
Browse files Browse the repository at this point in the history
  • Loading branch information
mhauru committed Nov 17, 2023
1 parent a1d5b47 commit 95d0389
Showing 1 changed file with 25 additions and 15 deletions.
40 changes: 25 additions & 15 deletions sql/snsql/metadata.py
Original file line number Diff line number Diff line change
@@ -1,7 +1,12 @@
from __future__ import annotations
from typing import TYPE_CHECKING, Iterable, TypeAlias, Union
import yaml
import io
from os import path
import warnings
if TYPE_CHECKING:
from collections.abc import Mapping
from pathlib import Path

from snsql.sql.reader.base import NameCompare

Expand All @@ -10,7 +15,7 @@
class Metadata:
"""Information about a collection of tabular data sources"""

def __init__(self, tables, engine=None, compare=None, dbname=None):
def __init__(self, tables: Iterable[Table], engine=None, compare=None, dbname=None):
"""Instantiate a metadata object with information about tabular data sources
:param tables: A list of Table descriptions
Expand All @@ -25,7 +30,7 @@ def __init__(self, tables, engine=None, compare=None, dbname=None):
self.compare = NameCompare.get_name_compare(engine) if compare is None else compare
self.dbname = dbname if dbname else None

def __getitem__(self, tablename):
def __getitem__(self, tablename: str):
schema_name = ""
dbname = ""
parts = tablename.split(".")
Expand Down Expand Up @@ -72,19 +77,19 @@ def __iter__(self):
return self.tables()

@staticmethod
def from_file(file):
def from_file(file: Union[str, io.IOBase]) -> Metadata:
"""Load the metadata about this collection from a YAML file"""
ys = CollectionYamlLoader(file)
return ys.read_file()

@staticmethod
def from_dict(schema_dict):
def from_dict(schema_dict: dict):
"""Load the metadata from a dict object"""
ys = CollectionYamlLoader("dummy")
return ys._create_metadata_object(schema_dict)

@classmethod
def from_(cls, val):
def from_(cls, val : Union[Metadata, str, io.IOBase, dict]):
if isinstance(val, Metadata):
return val
elif isinstance(val, (str, io.IOBase)):
Expand All @@ -109,9 +114,9 @@ def __init__(
self,
schema,
name,
columns,
columns: Iterable[Column],
*ignore,
rowcount=0,
rowcount:int=0,
rows_exact=None,
row_privacy=False,
max_ids=1,
Expand Down Expand Up @@ -147,8 +152,11 @@ def __init__(

if clamp_columns:
for col in self.m_columns.values():
if col.typename() in ["int", "float"] and (col.lower is None or col.upper is None):
if col.sensitivity is not None:
if (
col.typename() in ["int", "float"]
and (col.lower is None or col.upper is None) # type: ignore
and col.sensitivity is not None # type: ignore
):
raise ValueError(
f"Column {col.name} has sensitivity and no bounds, but table specifies clamp_columns. "
"clamp_columns should be False, or bounds should be provided."
Expand Down Expand Up @@ -355,11 +363,13 @@ def typename(self):
def unbounded(self):
return True

Column: TypeAlias = Union[Boolean, DateTime, Int, Float, String, Unknown]

class CollectionYamlLoader:
def __init__(self, file):
def __init__(self, file: Union[Path, str, io.IOBase]) -> None:
self.file = file

def read_file(self):
def read_file(self) -> Metadata:
if isinstance(self.file, io.IOBase):
try:
c_s = yaml.safe_load(self.file)
Expand All @@ -376,7 +386,7 @@ def read_file(self):
raise
return self._create_metadata_object(c_s)

def _create_metadata_object(self, c_s):
def _create_metadata_object(self, c_s: Mapping) -> Metadata:
if not hasattr(c_s, "keys"):
raise ValueError("Metadata must be a YAML dictionary")
keys = list(c_s.keys())
Expand Down Expand Up @@ -407,7 +417,7 @@ def _create_metadata_object(self, c_s):

return Metadata(tables, engine, dbname=collection)

def load_table(self, schema, table, t):
def load_table(self, schema, table, t) -> Table:
rowcount = int(t["rows"]) if "rows" in t else 0
rows_exact = int(t["rows_exact"]) if "rows_exact" in t else None
row_privacy = bool(t["row_privacy"]) if "row_privacy" in t else False
Expand Down Expand Up @@ -453,7 +463,7 @@ def load_table(self, schema, table, t):
censor_dims=censor_dims,
)

def load_column(self, column, c):
def load_column(self, column, c) -> Column:
lower = float(c["lower"]) if "lower" in c else None
upper = float(c["upper"]) if "upper" in c else None
is_key = False if "private_id" not in c else bool(c["private_id"])
Expand Down Expand Up @@ -492,7 +502,7 @@ def load_column(self, column, c):
else:
raise ValueError("Unknown column type for column {0}: {1}".format(column, c))

def write_file(self, collection_metadata, collection_name):
def write_file(self, collection_metadata, collection_name) -> None:

engine = collection_metadata.engine
schemas = {}
Expand Down

0 comments on commit 95d0389

Please sign in to comment.