Skip to content

Commit

Permalink
Fixed bug in Migration generator with a nested SELECT expression
Browse files Browse the repository at this point in the history
  • Loading branch information
mpscholten committed Aug 10, 2023
1 parent 29ec434 commit 535d4f5
Show file tree
Hide file tree
Showing 2 changed files with 43 additions and 24 deletions.
54 changes: 30 additions & 24 deletions IHP/IDE/CodeGen/MigrationGenerator.hs
Original file line number Diff line number Diff line change
Expand Up @@ -511,7 +511,7 @@ normalizeExpression (SelectExpression Select { columns, from, whereClause, alias
resolveAlias' = resolveAlias alias (unqualifiedName from)

unqualifiedName :: Expression -> Expression
unqualifiedName (DotExpression (VarExpression "public") name) = VarExpression name
unqualifiedName (DotExpression (VarExpression _) name) = VarExpression name
unqualifiedName name = name
normalizeExpression (DotExpression a b) = DotExpression (normalizeExpression a) b
normalizeExpression (ExistsExpression a) = ExistsExpression (normalizeExpression a)
Expand All @@ -522,30 +522,36 @@ normalizeExpression (ExistsExpression a) = ExistsExpression (normalizeExpression
-- sql "SELECT * FROM servers WHERE is_public"
--
unqualifyExpression :: Text -> Expression -> Expression
unqualifyExpression scope expression = unqualifyExpression expression
unqualifyExpression scope expression = doUnqualify expression
where
unqualifyExpression e@(TextExpression {}) = e
unqualifyExpression e@(VarExpression {}) = e
unqualifyExpression (CallExpression function args) = CallExpression function (map unqualifyExpression args)
unqualifyExpression (NotEqExpression a b) = NotEqExpression (unqualifyExpression a) (unqualifyExpression b)
unqualifyExpression (EqExpression a b) = EqExpression (unqualifyExpression a) (unqualifyExpression b)
unqualifyExpression (AndExpression a b) = AndExpression (unqualifyExpression a) (unqualifyExpression b)
unqualifyExpression (IsExpression a b) = IsExpression (unqualifyExpression a) (unqualifyExpression b)
unqualifyExpression (InExpression a b) = InExpression (unqualifyExpression a) (unqualifyExpression b)
unqualifyExpression (NotExpression a) = NotExpression (unqualifyExpression a)
unqualifyExpression (OrExpression a b) = OrExpression (unqualifyExpression a) (unqualifyExpression b)
unqualifyExpression (LessThanExpression a b) = LessThanExpression (unqualifyExpression a) (unqualifyExpression b)
unqualifyExpression (LessThanOrEqualToExpression a b) = LessThanOrEqualToExpression (unqualifyExpression a) (unqualifyExpression b)
unqualifyExpression (GreaterThanExpression a b) = GreaterThanExpression (unqualifyExpression a) (unqualifyExpression b)
unqualifyExpression (GreaterThanOrEqualToExpression a b) = GreaterThanOrEqualToExpression (unqualifyExpression a) (unqualifyExpression b)
unqualifyExpression e@(DoubleExpression {}) = e
unqualifyExpression e@(IntExpression {}) = e
unqualifyExpression (ConcatenationExpression a b) = ConcatenationExpression (unqualifyExpression a) (unqualifyExpression b)
unqualifyExpression (TypeCastExpression a b) = TypeCastExpression (unqualifyExpression a) b
unqualifyExpression (SelectExpression Select { columns, from, whereClause, alias }) = SelectExpression Select { columns = (unqualifyExpression <$> columns), from = from, whereClause = unqualifyExpression whereClause, alias }
unqualifyExpression (ExistsExpression a) = ExistsExpression (unqualifyExpression a)
unqualifyExpression (DotExpression (VarExpression scope') b) | scope == scope' = VarExpression b
unqualifyExpression (DotExpression a b) = DotExpression (unqualifyExpression a) b
doUnqualify e@(TextExpression {}) = e
doUnqualify e@(VarExpression {}) = e
doUnqualify (CallExpression function args) = CallExpression function (map doUnqualify args)
doUnqualify (NotEqExpression a b) = NotEqExpression (doUnqualify a) (doUnqualify b)
doUnqualify (EqExpression a b) = EqExpression (doUnqualify a) (doUnqualify b)
doUnqualify (AndExpression a b) = AndExpression (doUnqualify a) (doUnqualify b)
doUnqualify (IsExpression a b) = IsExpression (doUnqualify a) (doUnqualify b)
doUnqualify (InExpression a b) = InExpression (doUnqualify a) (doUnqualify b)
doUnqualify (NotExpression a) = NotExpression (doUnqualify a)
doUnqualify (OrExpression a b) = OrExpression (doUnqualify a) (doUnqualify b)
doUnqualify (LessThanExpression a b) = LessThanExpression (doUnqualify a) (doUnqualify b)
doUnqualify (LessThanOrEqualToExpression a b) = LessThanOrEqualToExpression (doUnqualify a) (doUnqualify b)
doUnqualify (GreaterThanExpression a b) = GreaterThanExpression (doUnqualify a) (doUnqualify b)
doUnqualify (GreaterThanOrEqualToExpression a b) = GreaterThanOrEqualToExpression (doUnqualify a) (doUnqualify b)
doUnqualify e@(DoubleExpression {}) = e
doUnqualify e@(IntExpression {}) = e
doUnqualify (ConcatenationExpression a b) = ConcatenationExpression (doUnqualify a) (doUnqualify b)
doUnqualify (TypeCastExpression a b) = TypeCastExpression (doUnqualify a) b
doUnqualify e@(SelectExpression Select { columns, from, whereClause, alias }) =
let recurse = case from of
VarExpression fromName -> unqualifyExpression fromName
DotExpression (VarExpression "public") fromName -> unqualifyExpression fromName
_ -> doUnqualify
in
SelectExpression Select { columns = (recurse <$> columns), from = from, whereClause = recurse whereClause, alias }
doUnqualify (ExistsExpression a) = ExistsExpression (doUnqualify a)
doUnqualify (DotExpression (VarExpression scope') b) | scope == scope' = VarExpression b
doUnqualify (DotExpression a b) = DotExpression (doUnqualify a) b


resolveAlias :: Maybe Text -> Expression -> Expression -> Expression
Expand Down
13 changes: 13 additions & 0 deletions Test/IDE/CodeGeneration/MigrationGenerator.hs
Original file line number Diff line number Diff line change
Expand Up @@ -1377,6 +1377,19 @@ CREATE POLICY "Users can read and edit their own record" ON public.users USING (

diffSchemas targetSchema actualSchema `shouldBe` migration

it "should deal with nested SELECT expressions inside a policy" do
let actualSchema = sql $ cs [plain|
CREATE POLICY "Allow users to see their own company" ON public.companies USING ((id = ( SELECT users.company_id
FROM public.users
WHERE (users.id = public.ihp_user_id())))) WITH CHECK (false);
|]
let targetSchema = sql $ cs [plain|
CREATE POLICY "Allow users to see their own company" ON companies USING (id = (SELECT company_id FROM users WHERE users.id = ihp_user_id())) WITH CHECK (false);
|]
let migration = []

diffSchemas targetSchema actualSchema `shouldBe` migration

sql :: Text -> [Statement]
sql code = case Megaparsec.runParser Parser.parseDDL "" code of
Left parsingFailed -> error (cs $ Megaparsec.errorBundlePretty parsingFailed)
Expand Down

0 comments on commit 535d4f5

Please sign in to comment.