Skip to content

Commit

Permalink
Datafiltering fix after V1 Update (#263)
Browse files Browse the repository at this point in the history
Fixed the SQL translation by handling the internal variables that OPA creates for `some` statements, and slightly adapted the Rego file accordingly.

Signed-off-by: Dr. Jan Pöppel <poeppja@imes-solutions.com>
  • Loading branch information
jpoeppel authored Oct 21, 2024
1 parent b791362 commit 6e03765
Show file tree
Hide file tree
Showing 5 changed files with 107 additions and 49 deletions.
135 changes: 91 additions & 44 deletions data_filter_example/data_filter_example/opa.py
Original file line number Diff line number Diff line change
Expand Up @@ -76,6 +76,7 @@

class TranslationError(Exception):
    """Signals that a Rego query could not be translated into SQL."""


Expand All @@ -85,13 +86,14 @@ class Result(object):
Attributes:
defined (bool): If the query is NEVER defined, defined is False. In
this case, the app can intrepret the result as denying the request.
this case, the app can interpret the result as denying the request.
sql (:class:`sql.Union`): If the query is ALWAYS defined, sql is None.
In this case the app can interpet the result as allowing the request
In this case the app can interpret the result as allowing the request
unconditionally. If sql is not None, the app should apply the SQL
clauses to the query it is about to run.
"""

def __init__(self, defined, sql):
self.defined = defined
self.sql = sql
Expand All @@ -100,44 +102,51 @@ def __init__(self, defined, sql):
def compile_http(query, input, unknowns):
    """Partially evaluate a query through OPA's HTTP Compile API.

    Posts ``query``, ``input`` and ``unknowns`` as a JSON document to
    ``http://localhost:8181/v1/compile``.

    Returns:
        The list stored under ``result.queries`` in the response body, or an
        empty list when either key is absent.

    Raises:
        Exception: If the server answers with a status other than 200. The
            message is ``"<code>: <message>"`` built from the JSON error
            body, falling back to the HTTP status code and the raw response
            text when the body is not a JSON object.
    """
    response = requests.post(
        "http://localhost:8181/v1/compile",
        data=json.dumps(
            {
                "query": query,
                "input": input,
                "unknowns": unknowns,
            }
        ),
    )
    if response.status_code != 200:
        # An error response is not guaranteed to carry a JSON body, so decode
        # it defensively instead of letting the decode error hide the status.
        try:
            error = response.json()
        except ValueError:
            error = {}
        if not isinstance(error, dict):
            error = {}
        # The decoded body is a dict: fields must be read by key. Attribute
        # access (body.code) raised AttributeError and masked the real error.
        raise Exception(
            "%s: %s"
            % (
                error.get("code", response.status_code),
                error.get("message", response.text),
            )
        )
    return response.json().get("result", {}).get("queries", [])


def compile_command_line(data_files):
    """Return a compile function backed by OPA's ``eval --partial`` subcommand.

    Args:
        data_files: Mapping of file name to file content. Every entry is
            written into a temporary ``data`` directory that is handed to
            ``opa eval`` through ``--data``.

    Returns:
        A function ``wrapped(query, input, unknowns)`` that runs ``opa eval``
        and returns the list stored under ``partial.queries`` in its JSON
        output (an empty list when either key is absent).
    """

    def wrapped(query, input, unknowns):
        """Run ``opa eval --partial`` for ``query`` and return its queries.

        Raises:
            Exception: If ``opa`` exits with a non-zero status; the message
                carries the exit code, the command line and the captured
                output.
        """
        args = ["opa", "eval", "--partial", "--format", "json"]
        for u in unknowns:
            args.extend(["--unknowns", u])
        # The context manager removes the scratch directory on every exit
        # path, replacing the manual mkdtemp/rmtree pairing.
        with tempfile.TemporaryDirectory() as dirpath:
            data_dirpath = os.path.join(dirpath, "data")
            os.makedirs(data_dirpath)
            for filename, content in data_files.items():
                target = os.path.join(data_dirpath, filename)
                with open(target, "w", encoding="utf-8") as f:
                    f.write(content)
            args.extend(["--data", data_dirpath])
            if input is not None:
                input_path = os.path.join(dirpath, "input.json")
                with open(input_path, "w", encoding="utf-8") as f:
                    json.dump(input, f)
                args.extend(["--input", input_path])
            args.append(query)
            try:
                # Capture stderr separately: merging it into stdout lets any
                # diagnostic line corrupt the JSON document parsed below.
                proc = subprocess.run(
                    args,
                    stdout=subprocess.PIPE,
                    stderr=subprocess.PIPE,
                    check=True,
                )
            except subprocess.CalledProcessError as e:
                captured = (e.output or b"") + (e.stderr or b"")
                raise Exception(
                    "exit code %d: command: %s: %s" % (e.returncode, e.cmd, captured)
                ) from e
        return json.loads(proc.stdout).get("partial", {}).get("queries", [])

    return wrapped


Expand All @@ -148,7 +157,9 @@ def compile(q, input, unknowns, from_table=None, compile_func=None):
if compile_func is None:
compile_func = compile_http

queries = compile_func(query=q, input=input, unknowns=['data.' + u for u in unknowns])
queries = compile_func(
query=q, input=input, unknowns=["data." + u for u in unknowns]
)

# Check if query is never or always defined.
if len(queries) == 0:
Expand All @@ -164,19 +175,19 @@ def compile(q, input, unknowns, from_table=None, compile_func=None):
return Result(True, clauses)


def splice(SELECT, FROM, WHERE='', decision=None, sql_kwargs=None):
def splice(SELECT, FROM, WHERE="", decision=None, sql_kwargs=None):
"""Returns a SQL query as a string constructed from the caller's provided
values and the decision returned by compile."""
sql = 'SELECT ' + SELECT + ' FROM ' + FROM
sql = "SELECT " + SELECT + " FROM " + FROM
if decision is not None and decision.sql is not None:
queries = [sql] * len(decision.sql.clauses)
for i, clause in enumerate(decision.sql.clauses):
if sql_kwargs is None:
sql_kwargs = {}
queries[i] = queries[i] + ' ' + clause.sql(**sql_kwargs)
queries[i] = queries[i] + " " + clause.sql(**sql_kwargs)
if WHERE:
queries[i] = queries[i] + ' AND (' + WHERE + ')'
return ' UNION '.join(queries)
queries[i] = queries[i] + " AND (" + WHERE + ")"
return " UNION ".join(queries)


class queryTranslator(object):
Expand All @@ -185,18 +196,19 @@ class queryTranslator(object):

# Maps supported Rego relational operators to SQL relational operators.
_sql_relation_operators = {
'eq': '=',
'equal': '=',
'neq': '!=',
'lt': '<',
'gt': '>',
'lte': '<=',
'gte': '>=',
"internal.member_2": "in",
"eq": "=",
"equal": "=",
"neq": "!=",
"lt": "<",
"gt": ">",
"lte": "<=",
"gte": ">=",
}

# Maps supported Rego call operators to SQL call operators.
_sql_call_operators = {
'abs': 'abs',
"abs": "abs",
}

def __init__(self, from_table):
Expand All @@ -213,8 +225,10 @@ def translate(self, query_set):
walk.walk(query_set, self)
clauses = []
if len(self._conjunctions) > 0:
clauses = [sql.Where(sql.Disjunction([conj for conj in self._conjunctions]))]
for (tables, conj) in self._joins:
clauses = [
sql.Where(sql.Disjunction([conj for conj in self._conjunctions]))
]
for tables, conj in self._joins:
pred = sql.InnerJoin(tables, conj)
clauses.append(pred)
return sql.Union(clauses)
Expand All @@ -223,7 +237,8 @@ def __call__(self, node):
if isinstance(node, ast.Query):
self._translate_query(node)
elif isinstance(node, ast.Expr):
self._translate_expr(node)
if not node.ignore:
self._translate_expr(node)
elif isinstance(node, ast.Term):
self._translate_term(node)
else:
Expand All @@ -248,12 +263,14 @@ def _translate_expr(self, node):
if not node.is_call():
return
if len(node.operands) != 2:
raise TranslationError('invalid expression: too many arguments')
raise TranslationError("invalid expression: too many arguments")
try:
op = node.op()
sql_op = sql.RelationOp(self._sql_relation_operators[op])
except KeyError:
raise TranslationError('invalid expression: operator not supported: %s' % op)
raise TranslationError(
"invalid expression: operator not supported: %s" % op
)
self._operands.append([])
for term in node.operands:
walk.walk(term, self)
Expand All @@ -275,21 +292,27 @@ def _translate_term(self, node):
op = v.op()
sql_op = self._sql_call_operators[op]
except KeyError:
raise TranslationError('invalid call: operator not supported: %s' % op)
raise TranslationError("invalid call: operator not supported: %s" % op)
self._operands.append([])
for term in v.operands:
walk.walk(term, self)
sql_operands = self._operands.pop()
self._operands[-1].append(sql.Call(sql_op, sql_operands))
elif isinstance(v, ast.Array):
self._operands[-1].append(
sql.Array([sql.Constant(t.value.value) for t in v.terms])
)
else:
raise TranslationError('invalid term: type not supported: %s' % v.__class__.__name__)
raise TranslationError(
"invalid term: type not supported: %s" % v.__class__.__name__
)


class queryPreprocessor(object):
"""Implements the visitor pattern to preprocess refs in the Rego query set.
Preprocessing the Rego query set simplifies the translation process.
Refs are rewritten to correspond directly to SQL tables aand columns.
Refs are rewritten to correspond directly to SQL tables and columns.
Specifically, refs of the form data.foo[var].bar are rewritten as
data.foo.bar. Similarly, if var is dereferenced later in the query, e.g.,
var.baz, that will be rewritten as data.foo.baz."""
Expand All @@ -306,10 +329,31 @@ def __call__(self, node):
self._table_names.append({})
self._table_vars = {}
elif isinstance(node, ast.Expr):
node.ignore = False
if node.is_call():
if node.op() == "eq":
# Filter out temp variable expressions, resulting from multiple `some` statements
# TODO make cleaner
all_variables = True
for o in node.operands:
if not isinstance(o.value, ast.Ref):
all_variables = False
else:
if isinstance(o.value.operand(-1).value, ast.Scalar):
all_variables = False
if (
isinstance(o.value.operand(-1).value, ast.Var)
and "__local" not in o.value.operand(-1).value.value
):
all_variables = False
if all_variables:
# Throw out expression
node.ignore = True
return None
# Skip the built-in call operator.
for o in node.operands:
walk.walk(o, self)
self.cur_operator = None
return
elif isinstance(node, ast.Call):
# Skip the call operator.
Expand All @@ -330,7 +374,9 @@ def __call__(self, node):
# Refs must be of the form data.<table>[<iterator>].<column>.
if not isinstance(row_id, ast.Var):
raise TranslationError(
'invalid reference: row identifier type not supported: %s' % row_id.__class__.__name__)
"invalid reference: row identifier type not supported: %s"
% row_id.__class__.__name__
)

prefix = node.terms[:2]

Expand All @@ -341,9 +387,10 @@ def __call__(self, node):
# Keep track of iterators used for each table. We do not support
# self-joins currently. Self-joins require namespacing in the SQL
# query.

exist = self._table_names[-1].get(table_name, row_id.value)
if exist != row_id.value:
raise TranslationError('invalid reference: self-joins not supported')
raise TranslationError("invalid reference: self-joins not supported")
else:
self._table_names[-1][table_name] = row_id.value

Expand Down
5 changes: 3 additions & 2 deletions data_filter_example/data_filter_example/server.py
Original file line number Diff line number Diff line change
Expand Up @@ -99,7 +99,7 @@ def login():
if user in USERS:
for c in COOKIES:
if c in USERS[user]:
response.set_cookie(c, base64.b64encode(json.dumps(USERS[user][c])))
response.set_cookie(c, json.dumps(USERS[user][c]))
return response


Expand All @@ -120,7 +120,7 @@ def make_subject():
for c in COOKIES:
v = flask.request.cookies.get(c, '')
if v:
subject[c] = json.loads(base64.b64decode(v))
subject[c] = json.loads(v)
return subject


Expand Down Expand Up @@ -164,6 +164,7 @@ def init_db():


def query_db(query, args=(), one=False):
print("Resulting query: ", query)
cur = get_db().execute(query, args)
rv = cur.fetchall()
cur.close()
Expand Down
10 changes: 9 additions & 1 deletion data_filter_example/data_filter_example/sql.py
Original file line number Diff line number Diff line change
Expand Up @@ -12,7 +12,7 @@ def __init__(self, tables, expr):
self.expr = expr

def sql(self, **kwargs):
return ' '.join(['INNER JOIN ' + t for t in self.tables]) + ' ON ' + self.expr.sql(**kwargs)
return ' '.join(['INNER JOIN ' + t for t in sorted(self.tables)]) + ' ON ' + self.expr.sql(**kwargs)


class Where(object):
Expand Down Expand Up @@ -80,6 +80,14 @@ def sql(self, **kwargs):
if isinstance(self.value, str):
return "'" + self.value + "'"
return json.dumps(self.value)


class Array(object):
    """SQL value list rendered as a parenthesised, comma-separated group.

    Each element of ``values`` must expose ``sql(**kwargs)``.
    """

    def __init__(self, values):
        self.values = values

    def sql(self, **kwargs):
        """Render every element with ``kwargs`` and wrap them in parentheses."""
        rendered = [item.sql(**kwargs) for item in self.values]
        return "(%s)" % ", ".join(rendered)


class RelationOp(object):
Expand Down
3 changes: 2 additions & 1 deletion data_filter_example/example.rego
Original file line number Diff line number Diff line change
Expand Up @@ -24,7 +24,8 @@ allow if {
allow if {
input.method == "GET"
input.path == ["posts"]
count(allowed) > 0
some post in data.posts
allowed[post]
}

allowed contains post if {
Expand Down
3 changes: 2 additions & 1 deletion data_filter_example/tests/test_opa.py
Original file line number Diff line number Diff line change
Expand Up @@ -268,6 +268,7 @@
'simple join',
{},
'''package test
p {
data.q[x].a = data.r[y].b
}''',
Expand Down Expand Up @@ -334,7 +335,7 @@ def test_compile_multi_table(note, input, policy, exp_defined, exp_sql):
if isinstance(clause, str):
clauses.append('WHERE ' + clause)
else:
joins = ' '.join('INNER JOIN ' + t for t in clause[0])
joins = ' '.join('INNER JOIN ' + t for t in sorted(clause[0]))
clauses.append(joins + ' ON ' + clause[1])
crunch(
'data.test.p = true',
Expand Down

0 comments on commit 6e03765

Please sign in to comment.