diff --git a/partner_company_default/models/__init__.py b/partner_company_default/models/__init__.py index 91fed54d404..b067a53de2f 100644 --- a/partner_company_default/models/__init__.py +++ b/partner_company_default/models/__init__.py @@ -1 +1,2 @@ +from . import res_company from . import res_partner diff --git a/partner_company_default/models/res_company.py b/partner_company_default/models/res_company.py new file mode 100644 index 00000000000..0a9c1ee816c --- /dev/null +++ b/partner_company_default/models/res_company.py @@ -0,0 +1,13 @@ +# Copyright 2023 Quartile Limited +# License AGPL-3.0 or later (https://www.gnu.org/licenses/agpl). + +from odoo import api, models + + +class ResCompany(models.Model): + _inherit = "res.company" + + @api.model + def create(self, vals): + self = self.with_context(creating_from_company=True) + return super().create(vals) diff --git a/partner_company_default/models/res_partner.py b/partner_company_default/models/res_partner.py index d44ecd5538a..7223c669710 100644 --- a/partner_company_default/models/res_partner.py +++ b/partner_company_default/models/res_partner.py @@ -1,10 +1,17 @@ # Copyright 2023 Quartile Limited # License AGPL-3.0 or later (https://www.gnu.org/licenses/agpl). -from odoo import fields, models +from odoo import api, fields, models class ResPartner(models.Model): _inherit = "res.partner" company_id = fields.Many2one(default=lambda self: self.env.company) + + @api.model + def create(self, vals): + # The context value is set in the create method of res.company + if self.env.context.get("creating_from_company"): + vals["company_id"] = False + return super(ResPartner, self).create(vals) diff --git a/partner_company_default/tests/__init__.py b/partner_company_default/tests/__init__.py new file mode 100644 index 00000000000..27e48b4c23f --- /dev/null +++ b/partner_company_default/tests/__init__.py @@ -0,0 +1 @@ +from . import test_partner_company_default diff --git a/partner_company_default/tests/test_partner_company_default.py b/partner_company_default/tests/test_partner_company_default.py new file mode 100644 index 00000000000..7703de4e66c --- /dev/null +++ b/partner_company_default/tests/test_partner_company_default.py @@ -0,0 +1,44 @@ +# Copyright 2023 Quartile Limited +# License AGPL-3.0 or later (https://www.gnu.org/licenses/agpl). + +import odoo.tests.common as common + + +class TestPartnerCompanyDefault(common.TransactionCase): + @classmethod + def setUpClass(cls): + super().setUpClass() + cls.user = cls.env.ref("base.user_admin") + + def test_partner_company_default(self): + # Check company of newly created partner + partner = ( + self.env["res.partner"] + .with_user(self.user.id) + .create({"name": "Test Partner 1"}) + ) + self.assertEqual(partner.company_id, self.user.company_id) + + # Check company of the partner of newly created company + company_fr = ( + self.env["res.company"] + .with_user(self.user.id) + .create( + { + "name": "French company", + "currency_id": self.env.ref("base.EUR").id, + "country_id": self.env.ref("base.fr").id, + } + ) + ) + self.assertFalse(company_fr.partner_id.company_id) + + # Switch user's company and create a partner + self.user.company_ids = [(4, company_fr.id)] + self.user.company_id = company_fr.id + partner = ( + self.env["res.partner"] + .with_user(self.user.id) + .create({"name": "Test Partner 2"}) + ) + self.assertEqual(partner.company_id, company_fr)