Skip to content

Commit

Permalink
Add some type hints in sql/reader/base.py
Browse files Browse the repository at this point in the history
  • Loading branch information
mhauru committed Nov 17, 2023
1 parent 95d0389 commit a1c2d15
Showing 1 changed file with 10 additions and 8 deletions.
18 changes: 10 additions & 8 deletions sql/snsql/sql/reader/base.py
Original file line number Diff line number Diff line change
@@ -1,3 +1,5 @@
from __future__ import annotations
from typing import Type
from snsql.reader.base import Reader
from snsql.sql.reader.engine import Engine
import importlib
Expand All @@ -6,7 +8,7 @@

class SqlReader(Reader):
@classmethod
def get_reader_class(cls, engine):
def get_reader_class(cls, engine) -> Type[Reader]:
prefix = ""
for eng in Engine.known_engines:
if str(eng).lower() == engine.lower():
Expand Down Expand Up @@ -72,10 +74,10 @@ def __init__(self, search_path=None):
path. Pass in only the schema part.
"""

def reserved(self):
def reserved(self) -> list[str]:
return ["select", "group", "on"]

def schema_match(self, from_query, from_meta):
def schema_match(self, from_query: str, from_meta: str) -> bool:
if from_query.strip() == "" and from_meta in self.search_path:
return True
if from_meta.strip() == "" and from_query in self.search_path:
Expand All @@ -87,25 +89,25 @@ def schema_match(self, from_query, from_meta):
if identifier used in query matches identifier
of metadata object. Pass in one part at a time.
"""
def identifier_match(self, from_query, from_meta):
def identifier_match(self, from_query: str, from_meta: str) -> bool:
return from_query == from_meta

"""
Removes all escaping characters, keeping identifiers unchanged
"""
def strip_escapes(self, value):
def strip_escapes(self, value: str) -> str:
return value.replace('"', "").replace("`", "").replace("[", "").replace("]", "")

"""
True if any part of identifier is escaped
"""
def is_escaped(self, identifier):
def is_escaped(self, identifier: str) -> bool:
return any([p[0] in ['"', "[", "`"] for p in identifier.split(".") if p != ""])

"""
Converts proprietary escaping to SQL-92. Supports multi-part identifiers
"""
def clean_escape(self, identifier):
def clean_escape(self, identifier: str) -> str:
escaped = []
for p in identifier.split("."):
if self.is_escaped(p):
Expand All @@ -118,7 +120,7 @@ def clean_escape(self, identifier):
Returns true if an identifier should
be escaped. Checks only one part per call.
"""
def should_escape(self, identifier):
def should_escape(self, identifier: str) -> bool:
if self.is_escaped(identifier):
return False
if identifier.lower() in self.reserved():
Expand Down

0 comments on commit a1c2d15

Please sign in to comment.