diff --git a/IHP/IDE/CodeGen/MigrationGenerator.hs b/IHP/IDE/CodeGen/MigrationGenerator.hs index 3ee25133a..c3e4c5294 100644 --- a/IHP/IDE/CodeGen/MigrationGenerator.hs +++ b/IHP/IDE/CodeGen/MigrationGenerator.hs @@ -511,7 +511,7 @@ normalizeExpression (SelectExpression Select { columns, from, whereClause, alias resolveAlias' = resolveAlias alias (unqualifiedName from) unqualifiedName :: Expression -> Expression - unqualifiedName (DotExpression (VarExpression "public") name) = VarExpression name + unqualifiedName (DotExpression (VarExpression _) name) = VarExpression name unqualifiedName name = name normalizeExpression (DotExpression a b) = DotExpression (normalizeExpression a) b normalizeExpression (ExistsExpression a) = ExistsExpression (normalizeExpression a) @@ -522,30 +522,36 @@ normalizeExpression (ExistsExpression a) = ExistsExpression (normalizeExpression -- sql "SELECT * FROM servers WHERE is_public" -- unqualifyExpression :: Text -> Expression -> Expression -unqualifyExpression scope expression = unqualifyExpression expression +unqualifyExpression scope expression = doUnqualify expression where - unqualifyExpression e@(TextExpression {}) = e - unqualifyExpression e@(VarExpression {}) = e - unqualifyExpression (CallExpression function args) = CallExpression function (map unqualifyExpression args) - unqualifyExpression (NotEqExpression a b) = NotEqExpression (unqualifyExpression a) (unqualifyExpression b) - unqualifyExpression (EqExpression a b) = EqExpression (unqualifyExpression a) (unqualifyExpression b) - unqualifyExpression (AndExpression a b) = AndExpression (unqualifyExpression a) (unqualifyExpression b) - unqualifyExpression (IsExpression a b) = IsExpression (unqualifyExpression a) (unqualifyExpression b) - unqualifyExpression (InExpression a b) = InExpression (unqualifyExpression a) (unqualifyExpression b) - unqualifyExpression (NotExpression a) = NotExpression (unqualifyExpression a) - unqualifyExpression (OrExpression a b) = OrExpression (unqualifyExpression a) (unqualifyExpression b) - unqualifyExpression (LessThanExpression a b) = LessThanExpression (unqualifyExpression a) (unqualifyExpression b) - unqualifyExpression (LessThanOrEqualToExpression a b) = LessThanOrEqualToExpression (unqualifyExpression a) (unqualifyExpression b) - unqualifyExpression (GreaterThanExpression a b) = GreaterThanExpression (unqualifyExpression a) (unqualifyExpression b) - unqualifyExpression (GreaterThanOrEqualToExpression a b) = GreaterThanOrEqualToExpression (unqualifyExpression a) (unqualifyExpression b) - unqualifyExpression e@(DoubleExpression {}) = e - unqualifyExpression e@(IntExpression {}) = e - unqualifyExpression (ConcatenationExpression a b) = ConcatenationExpression (unqualifyExpression a) (unqualifyExpression b) - unqualifyExpression (TypeCastExpression a b) = TypeCastExpression (unqualifyExpression a) b - unqualifyExpression (SelectExpression Select { columns, from, whereClause, alias }) = SelectExpression Select { columns = (unqualifyExpression <$> columns), from = from, whereClause = unqualifyExpression whereClause, alias } - unqualifyExpression (ExistsExpression a) = ExistsExpression (unqualifyExpression a) - unqualifyExpression (DotExpression (VarExpression scope') b) | scope == scope' = VarExpression b - unqualifyExpression (DotExpression a b) = DotExpression (unqualifyExpression a) b + doUnqualify e@(TextExpression {}) = e + doUnqualify e@(VarExpression {}) = e + doUnqualify (CallExpression function args) = CallExpression function (map doUnqualify args) + doUnqualify (NotEqExpression a b) = NotEqExpression (doUnqualify a) (doUnqualify b) + doUnqualify (EqExpression a b) = EqExpression (doUnqualify a) (doUnqualify b) + doUnqualify (AndExpression a b) = AndExpression (doUnqualify a) (doUnqualify b) + doUnqualify (IsExpression a b) = IsExpression (doUnqualify a) (doUnqualify b) + doUnqualify (InExpression a b) = InExpression (doUnqualify a) (doUnqualify b) + doUnqualify (NotExpression a) = NotExpression (doUnqualify a) + doUnqualify (OrExpression a b) = OrExpression (doUnqualify a) (doUnqualify b) + doUnqualify (LessThanExpression a b) = LessThanExpression (doUnqualify a) (doUnqualify b) + doUnqualify (LessThanOrEqualToExpression a b) = LessThanOrEqualToExpression (doUnqualify a) (doUnqualify b) + doUnqualify (GreaterThanExpression a b) = GreaterThanExpression (doUnqualify a) (doUnqualify b) + doUnqualify (GreaterThanOrEqualToExpression a b) = GreaterThanOrEqualToExpression (doUnqualify a) (doUnqualify b) + doUnqualify e@(DoubleExpression {}) = e + doUnqualify e@(IntExpression {}) = e + doUnqualify (ConcatenationExpression a b) = ConcatenationExpression (doUnqualify a) (doUnqualify b) + doUnqualify (TypeCastExpression a b) = TypeCastExpression (doUnqualify a) b + doUnqualify e@(SelectExpression Select { columns, from, whereClause, alias }) = + let recurse = case from of + VarExpression fromName -> unqualifyExpression fromName + DotExpression (VarExpression "public") fromName -> unqualifyExpression fromName + _ -> doUnqualify + in + SelectExpression Select { columns = (recurse <$> columns), from = from, whereClause = recurse whereClause, alias } + doUnqualify (ExistsExpression a) = ExistsExpression (doUnqualify a) + doUnqualify (DotExpression (VarExpression scope') b) | scope == scope' = VarExpression b + doUnqualify (DotExpression a b) = DotExpression (doUnqualify a) b resolveAlias :: Maybe Text -> Expression -> Expression -> Expression diff --git a/Test/IDE/CodeGeneration/MigrationGenerator.hs b/Test/IDE/CodeGeneration/MigrationGenerator.hs index f8690179d..4dcc5c3b6 100644 --- a/Test/IDE/CodeGeneration/MigrationGenerator.hs +++ b/Test/IDE/CodeGeneration/MigrationGenerator.hs @@ -1377,6 +1377,19 @@ CREATE POLICY "Users can read and edit their own record" ON public.users USING ( diffSchemas targetSchema actualSchema `shouldBe` migration + it "should deal with nested SELECT expressions inside a policy" do + let actualSchema = sql $ cs [plain| + CREATE POLICY "Allow users to see their own company" ON public.companies USING ((id = ( SELECT users.company_id + FROM public.users + WHERE (users.id = public.ihp_user_id())))) WITH CHECK (false); + |] + let targetSchema = sql $ cs [plain| + CREATE POLICY "Allow users to see their own company" ON companies USING (id = (SELECT company_id FROM users WHERE users.id = ihp_user_id())) WITH CHECK (false); + |] + let migration = [] + + diffSchemas targetSchema actualSchema `shouldBe` migration + sql :: Text -> [Statement] sql code = case Megaparsec.runParser Parser.parseDDL "" code of Left parsingFailed -> error (cs $ Megaparsec.errorBundlePretty parsingFailed)