Skip to content

Commit

Permalink
Merge pull request #892 from circulon/fix/query_true_should_return_qu…
Browse files Browse the repository at this point in the history
…ery_builder

Tests query=True should return query builder
  • Loading branch information
josephmancuso authored Oct 28, 2024
2 parents 55ff6dd + d7a8895 commit 15894aa
Show file tree
Hide file tree
Showing 6 changed files with 30 additions and 25 deletions.
10 changes: 5 additions & 5 deletions src/masoniteorm/models/Model.py
Original file line number Diff line number Diff line change
Expand Up @@ -427,9 +427,9 @@ def find(cls, record_id, query=False):

if query:
return builder

if isinstance(record_id, (list, tuple)):
return builder.get()
else:
if isinstance(record_id, (list, tuple)):
return builder.get()

return builder.first()

Expand Down Expand Up @@ -562,7 +562,7 @@ def create(
if query:
return cls.builder.create(
dictionary, query=True, cast=cast, **kwargs
).to_sql()
)

return cls.builder.create(dictionary, cast=cast, **kwargs)

Expand Down Expand Up @@ -897,7 +897,7 @@ def save(self, query=False):
if self.is_loaded():
result = builder.update(
self.__dirty_attributes__, dry=query, ignore_mass_assignment=True
).to_sql()
)
else:
result = self.create(self.__dirty_attributes__, query=query)

Expand Down
19 changes: 11 additions & 8 deletions src/masoniteorm/query/QueryBuilder.py
Original file line number Diff line number Diff line change
Expand Up @@ -1710,14 +1710,13 @@ def first(self, fields=None, query=False):
if not fields:
fields = []

if fields:
self.select(fields)
self.select(fields).limit(1)

if query:
return self.limit(1)
return self

result = self.new_connection().query(
self.limit(1).to_qmark(), self._bindings, results=1
self.to_qmark(), self._bindings, results=1
)

return self.prepare_result(result)
Expand Down Expand Up @@ -1778,11 +1777,13 @@ def last(self, column=None, query=False):
dictionary -- Returns a dictionary of results.
"""
_column = column if column else self._model.get_primary_key()
self.limit(1).order_by(_column, direction="DESC")

if query:
return self.limit(1).order_by(_column, direction="DESC")
return self

result = self.new_connection().query(
self.limit(1).order_by(_column, direction="DESC").to_qmark(),
self.to_qmark(),
self._bindings,
results=1,
)
Expand Down Expand Up @@ -1872,7 +1873,7 @@ def first_or_fail(self, query=False):
"""

if query:
return self.limit(1)
return self.first(query=True)

result = self.first()

Expand Down Expand Up @@ -1980,9 +1981,11 @@ def all(self, selects=[], query=False):
Returns:
dictionary -- Returns a dictionary of results.
"""

self.select(*selects)

if query:
return self.to_sql()
return self

result = self.new_connection().query(self.to_qmark(), self._bindings) or []

Expand Down
10 changes: 5 additions & 5 deletions tests/models/test_models.py
Original file line number Diff line number Diff line change
Expand Up @@ -92,13 +92,13 @@ def test_model_creates_when_new(self):
model = ModelTest.hydrate({"id": 1, "username": "joe", "admin": True})

model.name = "Bill"
sql = model.save(query=True)
sql = model.save(query=True).to_sql()
self.assertTrue(sql.startswith("UPDATE"))

model = ModelTest()

model.name = "Bill"
sql = model.save(query=True)
sql = model.save(query=True).to_sql()
self.assertTrue(sql.startswith("INSERT"))

def test_model_can_cast_attributes(self):
Expand Down Expand Up @@ -170,7 +170,7 @@ def test_model_update_without_changes(self):

model.username = "joe"
model.name = "Bill"
sql = model.save(query=True)
sql = model.save(query=True).to_sql()
self.assertTrue(sql.startswith("UPDATE"))
self.assertNotIn("username", sql)

Expand All @@ -181,7 +181,7 @@ def test_force_update_on_model_class(self):

model.username = "joe"
model.name = "Bill"
sql = model.save(query=True)
sql = model.save(query=True).to_sql()
self.assertTrue(sql.startswith("UPDATE"))
self.assertIn("username", sql)
self.assertIn("name", sql)
Expand All @@ -201,7 +201,7 @@ def test_model_update_without_changes_at_all(self):

model.username = "joe"
model.name = "Joe"
sql = model.save(query=True)
sql = model.save(query=True).to_sql()
self.assertFalse(sql.startswith("UPDATE"))

def test_model_using_or_where(self):
Expand Down
10 changes: 5 additions & 5 deletions tests/mysql/model/test_model.py
Original file line number Diff line number Diff line change
Expand Up @@ -81,7 +81,7 @@ class TestModel(unittest.TestCase):
def test_create_can_use_fillable(self):
sql = ProfileFillable.create(
{"name": "Joe", "email": "user@example.com"}, query=True
)
).to_sql()

self.assertEqual(
sql, "INSERT INTO `profiles` (`profiles`.`name`) VALUES ('Joe')"
Expand All @@ -90,7 +90,7 @@ def test_create_can_use_fillable(self):
def test_create_can_use_fillable_asterisk(self):
sql = ProfileFillAsterisk.create(
{"name": "Joe", "email": "user@example.com"}, query=True
)
).to_sql()

self.assertEqual(
sql,
Expand All @@ -100,7 +100,7 @@ def test_create_can_use_fillable_asterisk(self):
def test_create_can_use_guarded(self):
sql = ProfileGuarded.create(
{"name": "Joe", "email": "user@example.com"}, query=True
)
).to_sql()

self.assertEqual(
sql, "INSERT INTO `profiles` (`profiles`.`name`) VALUES ('Joe')"
Expand All @@ -109,7 +109,7 @@ def test_create_can_use_guarded(self):
def test_create_can_use_guarded_asterisk(self):
sql = ProfileGuardedAsterisk.create(
{"name": "Joe", "email": "user@example.com"}, query=True
)
).to_sql()

# An asterisk guarded attribute excludes all fields from mass-assignment.
# This would raise a DB error if there are any required fields.
Expand Down Expand Up @@ -326,7 +326,7 @@ def test_can_find_first(self):
def test_can_touch(self):
profile = ProfileFillTimeStamped.hydrate({"name": "Joe", "id": 1})

sql = profile.touch("now", query=True)
sql = profile.touch("now", query=True).to_sql()

self.assertEqual(
sql,
Expand Down
2 changes: 1 addition & 1 deletion tests/mysql/scopes/test_can_use_global_scopes.py
Original file line number Diff line number Diff line change
Expand Up @@ -35,7 +35,7 @@ def test_can_use_global_scopes_on_select(self):

def test_can_use_global_scopes_on_time(self):
sql = "INSERT INTO `users` (`users`.`name`, `users`.`updated_at`, `users`.`created_at`) VALUES ('Joe'"
self.assertTrue(User.create({"name": "Joe"}, query=True).startswith(sql))
self.assertTrue(User.create({"name": "Joe"}, query=True).to_sql().startswith(sql))

# def test_can_use_global_scopes_on_inherit(self):
# sql = "SELECT * FROM `user_softs` WHERE `user_softs`.`deleted_at` IS NULL"
Expand Down
4 changes: 3 additions & 1 deletion tests/sqlite/models/test_sqlite_model.py
Original file line number Diff line number Diff line change
Expand Up @@ -73,9 +73,11 @@ def test_update_all_records(self):

def test_can_find_list(self):
sql = User.find(1, query=True).to_sql()

self.assertEqual(sql, """SELECT * FROM "users" WHERE "users"."id" = '1'""")

sql = User.find([1, 2, 3], query=True).to_sql()

self.assertEqual(
sql, """SELECT * FROM "users" WHERE "users"."id" IN ('1','2','3')"""
)
Expand Down Expand Up @@ -106,7 +108,7 @@ def test_model_can_use_selects(self):

def test_model_can_use_selects_from_methods(self):
self.assertEqual(
SelectPass.all(["username"], query=True),
SelectPass.all(["username"], query=True).to_sql(),
'SELECT "select_passes"."username" FROM "select_passes"',
)

Expand Down

0 comments on commit 15894aa

Please sign in to comment.