diff --git a/data_filter_example/data_filter_example/opa.py b/data_filter_example/data_filter_example/opa.py index e5e9bc06..33682f9b 100644 --- a/data_filter_example/data_filter_example/opa.py +++ b/data_filter_example/data_filter_example/opa.py @@ -76,6 +76,7 @@ class TranslationError(Exception): """Raised if an error occurs during the Rego to SQL translation.""" + pass @@ -85,13 +86,14 @@ class Result(object): Attributes: defined (bool): If the query is NEVER defined, defined is False. In - this case, the app can intrepret the result as denying the request. + this case, the app can interpret the result as denying the request. sql (:class:`sql.Union`): If the query is ALWAYS defined, sql is None. - In this case the app can interpet the result as allowing the request + In this case the app can interpret the result as allowing the request unconditionally. If sql is not None, the app should apply the SQL clauses to the query it is about to run. """ + def __init__(self, defined, sql): self.defined = defined self.sql = sql @@ -100,44 +102,51 @@ def __init__(self, defined, sql): def compile_http(query, input, unknowns): """Returns a set of compiled queries.""" response = requests.post( - 'http://localhost:8181/v1/compile', - data=json.dumps({ - 'query': query, - 'input': input, - 'unknowns': unknowns, - })) + "http://localhost:8181/v1/compile", + data=json.dumps( + { + "query": query, + "input": input, + "unknowns": unknowns, + } + ), + ) body = response.json() if response.status_code != 200: - raise Exception('%s: %s' % (body.code, body.message)) - return body.get('result', {}).get('queries', []) + raise Exception("%s: %s" % (body.code, body.message)) + return body.get("result", {}).get("queries", []) def compile_command_line(data_files): """Returns a function that can be called to compile a query using OPA's eval subcommand.""" + def wrapped(query, input, unknowns): - args = ['opa', 'eval', '--partial', '--format', 'json'] + args = ["opa", "eval", "--partial", "--format", "json"] for u in unknowns: - args.extend(['--unknowns', u]) + args.extend(["--unknowns", u]) dirpath = tempfile.mkdtemp() try: - data_dirpath = os.path.join(dirpath, 'data') + data_dirpath = os.path.join(dirpath, "data") os.makedirs(data_dirpath) for filename, content in data_files.items(): - with open(os.path.join(data_dirpath, filename), 'w') as f: + with open(os.path.join(data_dirpath, filename), "w") as f: f.write(content) - args.extend(['--data', data_dirpath]) + args.extend(["--data", data_dirpath]) if input is not None: - input_path = os.path.join(dirpath, 'input.json') - with open(input_path, 'w') as f: + input_path = os.path.join(dirpath, "input.json") + with open(input_path, "w") as f: json.dump(input, f) - args.extend(['--input', input_path]) + args.extend(["--input", input_path]) args.append(query) output = subprocess.check_output(args, stderr=subprocess.STDOUT) except subprocess.CalledProcessError as e: - raise Exception("exit code %d: command: %s: %s" % (e.returncode, e.cmd, e.output)) + raise Exception( + "exit code %d: command: %s: %s" % (e.returncode, e.cmd, e.output) + ) finally: shutil.rmtree(dirpath) - return json.loads(output).get('partial', {}).get('queries', []) + return json.loads(output).get("partial", {}).get("queries", []) + return wrapped @@ -148,7 +157,9 @@ def compile(q, input, unknowns, from_table=None, compile_func=None): if compile_func is None: compile_func = compile_http - queries = compile_func(query=q, input=input, unknowns=['data.' + u for u in unknowns]) + queries = compile_func( + query=q, input=input, unknowns=["data." + u for u in unknowns] + ) # Check if query is never or always defined. if len(queries) == 0: @@ -164,19 +175,19 @@ def compile(q, input, unknowns, from_table=None, compile_func=None): return Result(True, clauses) -def splice(SELECT, FROM, WHERE='', decision=None, sql_kwargs=None): +def splice(SELECT, FROM, WHERE="", decision=None, sql_kwargs=None): """Returns a SQL query as a string constructed from the caller's provided values and the decision returned by compile.""" - sql = 'SELECT ' + SELECT + ' FROM ' + FROM + sql = "SELECT " + SELECT + " FROM " + FROM if decision is not None and decision.sql is not None: queries = [sql] * len(decision.sql.clauses) for i, clause in enumerate(decision.sql.clauses): if sql_kwargs is None: sql_kwargs = {} - queries[i] = queries[i] + ' ' + clause.sql(**sql_kwargs) + queries[i] = queries[i] + " " + clause.sql(**sql_kwargs) if WHERE: - queries[i] = queries[i] + ' AND (' + WHERE + ')' - return ' UNION '.join(queries) + queries[i] = queries[i] + " AND (" + WHERE + ")" + return " UNION ".join(queries) class queryTranslator(object): @@ -185,18 +196,19 @@ class queryTranslator(object): # Maps supported Rego relational operators to SQL relational operators. _sql_relation_operators = { - 'eq': '=', - 'equal': '=', - 'neq': '!=', - 'lt': '<', - 'gt': '>', - 'lte': '<=', - 'gte': '>=', + "internal.member_2": "in", + "eq": "=", + "equal": "=", + "neq": "!=", + "lt": "<", + "gt": ">", + "lte": "<=", + "gte": ">=", } # Maps supported Rego call operators to SQL call operators. _sql_call_operators = { - 'abs': 'abs', + "abs": "abs", } def __init__(self, from_table): @@ -213,8 +225,10 @@ def translate(self, query_set): walk.walk(query_set, self) clauses = [] if len(self._conjunctions) > 0: - clauses = [sql.Where(sql.Disjunction([conj for conj in self._conjunctions]))] - for (tables, conj) in self._joins: + clauses = [ + sql.Where(sql.Disjunction([conj for conj in self._conjunctions])) + ] + for tables, conj in self._joins: pred = sql.InnerJoin(tables, conj) clauses.append(pred) return sql.Union(clauses) @@ -223,7 +237,8 @@ def __call__(self, node): if isinstance(node, ast.Query): self._translate_query(node) elif isinstance(node, ast.Expr): - self._translate_expr(node) + if not node.ignore: + self._translate_expr(node) elif isinstance(node, ast.Term): self._translate_term(node) else: @@ -248,12 +263,14 @@ def _translate_expr(self, node): if not node.is_call(): return if len(node.operands) != 2: - raise TranslationError('invalid expression: too many arguments') + raise TranslationError("invalid expression: too many arguments") try: op = node.op() sql_op = sql.RelationOp(self._sql_relation_operators[op]) except KeyError: - raise TranslationError('invalid expression: operator not supported: %s' % op) + raise TranslationError( + "invalid expression: operator not supported: %s" % op + ) self._operands.append([]) for term in node.operands: walk.walk(term, self) @@ -275,21 +292,27 @@ def _translate_term(self, node): op = v.op() sql_op = self._sql_call_operators[op] except KeyError: - raise TranslationError('invalid call: operator not supported: %s' % op) + raise TranslationError("invalid call: operator not supported: %s" % op) self._operands.append([]) for term in v.operands: walk.walk(term, self) sql_operands = self._operands.pop() self._operands[-1].append(sql.Call(sql_op, sql_operands)) + elif isinstance(v, ast.Array): + self._operands[-1].append( + sql.Array([sql.Constant(t.value.value) for t in v.terms]) + ) else: - raise TranslationError('invalid term: type not supported: %s' % v.__class__.__name__) + raise TranslationError( + "invalid term: type not supported: %s" % v.__class__.__name__ + ) class queryPreprocessor(object): """Implements the visitor pattern to preprocess refs in the Rego query set. Preprocessing the Rego query set simplifies the translation process. - Refs are rewritten to correspond directly to SQL tables aand columns. + Refs are rewritten to correspond directly to SQL tables and columns. Specifically, refs of the form data.foo[var].bar are rewritten as data.foo.bar. Similarly, if var is dereferenced later in the query, e.g., var.baz, that will be rewritten as data.foo.baz.""" @@ -306,10 +329,31 @@ def __call__(self, node): self._table_names.append({}) self._table_vars = {} elif isinstance(node, ast.Expr): + node.ignore = False if node.is_call(): + if node.op() == "eq": + # Filter out temp variable expressions, resulting from multiple `some` statements + # TODO make cleaner + all_variables = True + for o in node.operands: + if not isinstance(o.value, ast.Ref): + all_variables = False + else: + if isinstance(o.value.operand(-1).value, ast.Scalar): + all_variables = False + if ( + isinstance(o.value.operand(-1).value, ast.Var) + and "__local" not in o.value.operand(-1).value.value + ): + all_variables = False + if all_variables: + # Throw out expression + node.ignore = True + return None # Skip the built-in call operator. for o in node.operands: walk.walk(o, self) + self.cur_operator = None return elif isinstance(node, ast.Call): # Skip the call operator. @@ -330,7 +374,9 @@ def __call__(self, node): # Refs must be of the form data.