From 023fdb5e13fbdb4dd6cee34b3c7fb8cd35398940 Mon Sep 17 00:00:00 2001 From: Jost Migenda Date: Fri, 16 Aug 2024 14:55:22 -0500 Subject: [PATCH] code cleanup * remove unused imports * give __init__ argument a cleaner name * move type check to top of __init__ --- python/snewpy/models/extended.py | 42 ++++++++++++-------------------- 1 file changed, 15 insertions(+), 27 deletions(-) diff --git a/python/snewpy/models/extended.py b/python/snewpy/models/extended.py index b8dc622e..59237189 100644 --- a/python/snewpy/models/extended.py +++ b/python/snewpy/models/extended.py @@ -1,45 +1,33 @@ -import itertools as it -import os -from abc import ABC, abstractmethod from warnings import warn import numpy as np from astropy import units as u -from astropy.table import Table, join -from astropy.units import UnitTypeError, get_physical_type -from astropy.units.quantity import Quantity -from scipy.special import loggamma -from snewpy import _model_downloader from snewpy.neutrino import Flavor -from snewpy.flavor_transformation import NoTransformation -from functools import wraps - -from snewpy.flux import Flux from snewpy.models.base import SupernovaModel class ExtendedModel(SupernovaModel): """Class defining a supernova model with a cooling tail extension.""" - def __init__(self, *args): + def __init__(self, base_model): """Initialize extended supernova model class.""" - if isinstance(args[0],SupernovaModel): - self.__dict__ = args[0].__dict__.copy() - for method_name in dir(args[0]): - if callable(getattr(args[0], method_name)) and method_name[0] != '_': - if method_name == 'get_initial_spectra': - self._get_initial_spectra = getattr(args[0], method_name) - else: - self.method_name = getattr(args[0], method_name) - self.t_final = self.time[-1] - self.L_final = {Flavor.NU_E: self.luminosity[Flavor.NU_E][-1], - Flavor.NU_X: self.luminosity[Flavor.NU_X][-1], - Flavor.NU_E_BAR: self.luminosity[Flavor.NU_E_BAR][-1], - Flavor.NU_X_BAR: self.luminosity[Flavor.NU_X_BAR][-1]} - else: + if not isinstance(base_model, SupernovaModel): raise TypeError("ExtendedModel.__init__ requires a SupernovaModel object") + self.__dict__ = base_model.__dict__.copy() + for method_name in dir(base_model): + if callable(getattr(base_model, method_name)) and method_name[0] != '_': + if method_name == 'get_initial_spectra': + self._get_initial_spectra = getattr(base_model, method_name) + else: + self.method_name = getattr(base_model, method_name) + self.t_final = self.time[-1] + self.L_final = {Flavor.NU_E: self.luminosity[Flavor.NU_E][-1], + Flavor.NU_X: self.luminosity[Flavor.NU_X][-1], + Flavor.NU_E_BAR: self.luminosity[Flavor.NU_E_BAR][-1], + Flavor.NU_X_BAR: self.luminosity[Flavor.NU_X_BAR][-1]} + def get_initial_spectra(self, *args): """Get neutrino spectra/luminosity curves before oscillation""" return self._get_initial_spectra(*args)