From 6e037658ef751c25d2175e7e5e82af9d18ec6654 Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Jan=20P=C3=B6ppel?= <4759892+jpoeppel@users.noreply.github.com> Date: Mon, 21 Oct 2024 14:41:51 +0200 Subject: [PATCH] Datafiltering fix after V1 Update (#263) MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit Fixed sql translation by handling internal variable creation for 'some' and slight adaption of the rego file Signed-off-by: Dr. Jan Pöppel --- .../data_filter_example/opa.py | 135 ++++++++++++------ .../data_filter_example/server.py | 5 +- .../data_filter_example/sql.py | 10 +- data_filter_example/example.rego | 3 +- data_filter_example/tests/test_opa.py | 3 +- 5 files changed, 107 insertions(+), 49 deletions(-) diff --git a/data_filter_example/data_filter_example/opa.py b/data_filter_example/data_filter_example/opa.py index e5e9bc06..33682f9b 100644 --- a/data_filter_example/data_filter_example/opa.py +++ b/data_filter_example/data_filter_example/opa.py @@ -76,6 +76,7 @@ class TranslationError(Exception): """Raised if an error occurs during the Rego to SQL translation.""" + pass @@ -85,13 +86,14 @@ class Result(object): Attributes: defined (bool): If the query is NEVER defined, defined is False. In - this case, the app can intrepret the result as denying the request. + this case, the app can interpret the result as denying the request. sql (:class:`sql.Union`): If the query is ALWAYS defined, sql is None. - In this case the app can interpet the result as allowing the request + In this case the app can interpret the result as allowing the request unconditionally. If sql is not None, the app should apply the SQL clauses to the query it is about to run. """ + def __init__(self, defined, sql): self.defined = defined self.sql = sql @@ -100,44 +102,51 @@ def __init__(self, defined, sql): def compile_http(query, input, unknowns): """Returns a set of compiled queries.""" response = requests.post( - 'http://localhost:8181/v1/compile', - data=json.dumps({ - 'query': query, - 'input': input, - 'unknowns': unknowns, - })) + "http://localhost:8181/v1/compile", + data=json.dumps( + { + "query": query, + "input": input, + "unknowns": unknowns, + } + ), + ) body = response.json() if response.status_code != 200: - raise Exception('%s: %s' % (body.code, body.message)) - return body.get('result', {}).get('queries', []) + raise Exception("%s: %s" % (body.code, body.message)) + return body.get("result", {}).get("queries", []) def compile_command_line(data_files): """Returns a function that can be called to compile a query using OPA's eval subcommand.""" + def wrapped(query, input, unknowns): - args = ['opa', 'eval', '--partial', '--format', 'json'] + args = ["opa", "eval", "--partial", "--format", "json"] for u in unknowns: - args.extend(['--unknowns', u]) + args.extend(["--unknowns", u]) dirpath = tempfile.mkdtemp() try: - data_dirpath = os.path.join(dirpath, 'data') + data_dirpath = os.path.join(dirpath, "data") os.makedirs(data_dirpath) for filename, content in data_files.items(): - with open(os.path.join(data_dirpath, filename), 'w') as f: + with open(os.path.join(data_dirpath, filename), "w") as f: f.write(content) - args.extend(['--data', data_dirpath]) + args.extend(["--data", data_dirpath]) if input is not None: - input_path = os.path.join(dirpath, 'input.json') - with open(input_path, 'w') as f: + input_path = os.path.join(dirpath, "input.json") + with open(input_path, "w") as f: json.dump(input, f) - args.extend(['--input', input_path]) + args.extend(["--input", input_path]) args.append(query) output = subprocess.check_output(args, stderr=subprocess.STDOUT) except subprocess.CalledProcessError as e: - raise Exception("exit code %d: command: %s: %s" % (e.returncode, e.cmd, e.output)) + raise Exception( + "exit code %d: command: %s: %s" % (e.returncode, e.cmd, e.output) + ) finally: shutil.rmtree(dirpath) - return json.loads(output).get('partial', {}).get('queries', []) + return json.loads(output).get("partial", {}).get("queries", []) + return wrapped @@ -148,7 +157,9 @@ def compile(q, input, unknowns, from_table=None, compile_func=None): if compile_func is None: compile_func = compile_http - queries = compile_func(query=q, input=input, unknowns=['data.' + u for u in unknowns]) + queries = compile_func( + query=q, input=input, unknowns=["data." + u for u in unknowns] + ) # Check if query is never or always defined. if len(queries) == 0: @@ -164,19 +175,19 @@ def compile(q, input, unknowns, from_table=None, compile_func=None): return Result(True, clauses) -def splice(SELECT, FROM, WHERE='', decision=None, sql_kwargs=None): +def splice(SELECT, FROM, WHERE="", decision=None, sql_kwargs=None): """Returns a SQL query as a string constructed from the caller's provided values and the decision returned by compile.""" - sql = 'SELECT ' + SELECT + ' FROM ' + FROM + sql = "SELECT " + SELECT + " FROM " + FROM if decision is not None and decision.sql is not None: queries = [sql] * len(decision.sql.clauses) for i, clause in enumerate(decision.sql.clauses): if sql_kwargs is None: sql_kwargs = {} - queries[i] = queries[i] + ' ' + clause.sql(**sql_kwargs) + queries[i] = queries[i] + " " + clause.sql(**sql_kwargs) if WHERE: - queries[i] = queries[i] + ' AND (' + WHERE + ')' - return ' UNION '.join(queries) + queries[i] = queries[i] + " AND (" + WHERE + ")" + return " UNION ".join(queries) class queryTranslator(object): @@ -185,18 +196,19 @@ class queryTranslator(object): # Maps supported Rego relational operators to SQL relational operators. _sql_relation_operators = { - 'eq': '=', - 'equal': '=', - 'neq': '!=', - 'lt': '<', - 'gt': '>', - 'lte': '<=', - 'gte': '>=', + "internal.member_2": "in", + "eq": "=", + "equal": "=", + "neq": "!=", + "lt": "<", + "gt": ">", + "lte": "<=", + "gte": ">=", } # Maps supported Rego call operators to SQL call operators. _sql_call_operators = { - 'abs': 'abs', + "abs": "abs", } def __init__(self, from_table): @@ -213,8 +225,10 @@ def translate(self, query_set): walk.walk(query_set, self) clauses = [] if len(self._conjunctions) > 0: - clauses = [sql.Where(sql.Disjunction([conj for conj in self._conjunctions]))] - for (tables, conj) in self._joins: + clauses = [ + sql.Where(sql.Disjunction([conj for conj in self._conjunctions])) + ] + for tables, conj in self._joins: pred = sql.InnerJoin(tables, conj) clauses.append(pred) return sql.Union(clauses) @@ -223,7 +237,8 @@ def __call__(self, node): if isinstance(node, ast.Query): self._translate_query(node) elif isinstance(node, ast.Expr): - self._translate_expr(node) + if not node.ignore: + self._translate_expr(node) elif isinstance(node, ast.Term): self._translate_term(node) else: @@ -248,12 +263,14 @@ def _translate_expr(self, node): if not node.is_call(): return if len(node.operands) != 2: - raise TranslationError('invalid expression: too many arguments') + raise TranslationError("invalid expression: too many arguments") try: op = node.op() sql_op = sql.RelationOp(self._sql_relation_operators[op]) except KeyError: - raise TranslationError('invalid expression: operator not supported: %s' % op) + raise TranslationError( + "invalid expression: operator not supported: %s" % op + ) self._operands.append([]) for term in node.operands: walk.walk(term, self) @@ -275,21 +292,27 @@ def _translate_term(self, node): op = v.op() sql_op = self._sql_call_operators[op] except KeyError: - raise TranslationError('invalid call: operator not supported: %s' % op) + raise TranslationError("invalid call: operator not supported: %s" % op) self._operands.append([]) for term in v.operands: walk.walk(term, self) sql_operands = self._operands.pop() self._operands[-1].append(sql.Call(sql_op, sql_operands)) + elif isinstance(v, ast.Array): + self._operands[-1].append( + sql.Array([sql.Constant(t.value.value) for t in v.terms]) + ) else: - raise TranslationError('invalid term: type not supported: %s' % v.__class__.__name__) + raise TranslationError( + "invalid term: type not supported: %s" % v.__class__.__name__ + ) class queryPreprocessor(object): """Implements the visitor pattern to preprocess refs in the Rego query set. Preprocessing the Rego query set simplifies the translation process. - Refs are rewritten to correspond directly to SQL tables aand columns. + Refs are rewritten to correspond directly to SQL tables and columns. Specifically, refs of the form data.foo[var].bar are rewritten as data.foo.bar. Similarly, if var is dereferenced later in the query, e.g., var.baz, that will be rewritten as data.foo.baz.""" @@ -306,10 +329,31 @@ def __call__(self, node): self._table_names.append({}) self._table_vars = {} elif isinstance(node, ast.Expr): + node.ignore = False if node.is_call(): + if node.op() == "eq": + # Filter out temp variable expressions, resulting from multiple `some` statements + # TODO make cleaner + all_variables = True + for o in node.operands: + if not isinstance(o.value, ast.Ref): + all_variables = False + else: + if isinstance(o.value.operand(-1).value, ast.Scalar): + all_variables = False + if ( + isinstance(o.value.operand(-1).value, ast.Var) + and "__local" not in o.value.operand(-1).value.value + ): + all_variables = False + if all_variables: + # Throw out expression + node.ignore = True + return None # Skip the built-in call operator. for o in node.operands: walk.walk(o, self) + self.cur_operator = None return elif isinstance(node, ast.Call): # Skip the call operator. @@ -330,7 +374,9 @@ def __call__(self, node): # Refs must be of the form data.[].. if not isinstance(row_id, ast.Var): raise TranslationError( - 'invalid reference: row identifier type not supported: %s' % row_id.__class__.__name__) + "invalid reference: row identifier type not supported: %s" + % row_id.__class__.__name__ + ) prefix = node.terms[:2] @@ -341,9 +387,10 @@ def __call__(self, node): # Keep track of iterators used for each table. We do not support # self-joins currently. Self-joins require namespacing in the SQL # query. + exist = self._table_names[-1].get(table_name, row_id.value) if exist != row_id.value: - raise TranslationError('invalid reference: self-joins not supported') + raise TranslationError("invalid reference: self-joins not supported") else: self._table_names[-1][table_name] = row_id.value diff --git a/data_filter_example/data_filter_example/server.py b/data_filter_example/data_filter_example/server.py index 67b8d97e..3865cbff 100644 --- a/data_filter_example/data_filter_example/server.py +++ b/data_filter_example/data_filter_example/server.py @@ -99,7 +99,7 @@ def login(): if user in USERS: for c in COOKIES: if c in USERS[user]: - response.set_cookie(c, base64.b64encode(json.dumps(USERS[user][c]))) + response.set_cookie(c, json.dumps(USERS[user][c])) return response @@ -120,7 +120,7 @@ def make_subject(): for c in COOKIES: v = flask.request.cookies.get(c, '') if v: - subject[c] = json.loads(base64.b64decode(v)) + subject[c] = json.loads(v) return subject @@ -164,6 +164,7 @@ def init_db(): def query_db(query, args=(), one=False): + print("Resulting query: ", query) cur = get_db().execute(query, args) rv = cur.fetchall() cur.close() diff --git a/data_filter_example/data_filter_example/sql.py b/data_filter_example/data_filter_example/sql.py index 4c6de3c3..61689d57 100644 --- a/data_filter_example/data_filter_example/sql.py +++ b/data_filter_example/data_filter_example/sql.py @@ -12,7 +12,7 @@ def __init__(self, tables, expr): self.expr = expr def sql(self, **kwargs): - return ' '.join(['INNER JOIN ' + t for t in self.tables]) + ' ON ' + self.expr.sql(**kwargs) + return ' '.join(['INNER JOIN ' + t for t in sorted(self.tables)]) + ' ON ' + self.expr.sql(**kwargs) class Where(object): @@ -80,6 +80,14 @@ def sql(self, **kwargs): if isinstance(self.value, str): return "'" + self.value + "'" return json.dumps(self.value) + + +class Array(object): + def __init__(self, values): + self.values = values + + def sql(self, **kwargs): + return '(' + ', '.join(e.sql(**kwargs) for e in self.values) + ')' class RelationOp(object): diff --git a/data_filter_example/example.rego b/data_filter_example/example.rego index 43201b9e..6fd1bcf1 100644 --- a/data_filter_example/example.rego +++ b/data_filter_example/example.rego @@ -24,7 +24,8 @@ allow if { allow if { input.method == "GET" input.path == ["posts"] - count(allowed) > 0 + some post in data.posts + allowed[post] } allowed contains post if { diff --git a/data_filter_example/tests/test_opa.py b/data_filter_example/tests/test_opa.py index 7e385472..b76f9d80 100644 --- a/data_filter_example/tests/test_opa.py +++ b/data_filter_example/tests/test_opa.py @@ -268,6 +268,7 @@ 'simple join', {}, '''package test + p { data.q[x].a = data.r[y].b }''', @@ -334,7 +335,7 @@ def test_compile_multi_table(note, input, policy, exp_defined, exp_sql): if isinstance(clause, str): clauses.append('WHERE ' + clause) else: - joins = ' '.join('INNER JOIN ' + t for t in clause[0]) + joins = ' '.join('INNER JOIN ' + t for t in sorted(clause[0])) clauses.append(joins + ' ON ' + clause[1]) crunch( 'data.test.p = true',