Skip to content

Commit

Permalink
code cleanup
Browse files Browse the repository at this point in the history
* remove unused imports
* give __init__ argument a cleaner name
* move type check to top of __init__
  • Loading branch information
JostMigenda committed Aug 16, 2024
1 parent fcb1c06 commit 023fdb5
Showing 1 changed file with 15 additions and 27 deletions.
42 changes: 15 additions & 27 deletions python/snewpy/models/extended.py
Original file line number Diff line number Diff line change
@@ -1,45 +1,33 @@
import itertools as it
import os
from abc import ABC, abstractmethod
from warnings import warn

import numpy as np
from astropy import units as u
from astropy.table import Table, join
from astropy.units import UnitTypeError, get_physical_type
from astropy.units.quantity import Quantity
from scipy.special import loggamma
from snewpy import _model_downloader

from snewpy.neutrino import Flavor
from snewpy.flavor_transformation import NoTransformation
from functools import wraps

from snewpy.flux import Flux
from snewpy.models.base import SupernovaModel


class ExtendedModel(SupernovaModel):
"""Class defining a supernova model with a cooling tail extension."""

def __init__(self, *args):
def __init__(self, base_model):
"""Initialize extended supernova model class."""
if isinstance(args[0],SupernovaModel):
self.__dict__ = args[0].__dict__.copy()
for method_name in dir(args[0]):
if callable(getattr(args[0], method_name)) and method_name[0] != '_':
if method_name == 'get_initial_spectra':
self._get_initial_spectra = getattr(args[0], method_name)
else:
self.method_name = getattr(args[0], method_name)
self.t_final = self.time[-1]
self.L_final = {Flavor.NU_E: self.luminosity[Flavor.NU_E][-1],
Flavor.NU_X: self.luminosity[Flavor.NU_X][-1],
Flavor.NU_E_BAR: self.luminosity[Flavor.NU_E_BAR][-1],
Flavor.NU_X_BAR: self.luminosity[Flavor.NU_X_BAR][-1]}
else:
if not isinstance(base_model, SupernovaModel):
raise TypeError("ExtendedModel.__init__ requires a SupernovaModel object")

self.__dict__ = base_model.__dict__.copy()
for method_name in dir(base_model):
if callable(getattr(base_model, method_name)) and method_name[0] != '_':
if method_name == 'get_initial_spectra':
self._get_initial_spectra = getattr(base_model, method_name)
else:
self.method_name = getattr(base_model, method_name)
self.t_final = self.time[-1]
self.L_final = {Flavor.NU_E: self.luminosity[Flavor.NU_E][-1],
Flavor.NU_X: self.luminosity[Flavor.NU_X][-1],
Flavor.NU_E_BAR: self.luminosity[Flavor.NU_E_BAR][-1],
Flavor.NU_X_BAR: self.luminosity[Flavor.NU_X_BAR][-1]}

def get_initial_spectra(self, *args):
"""Get neutrino spectra/luminosity curves before oscillation"""
return self._get_initial_spectra(*args)
Expand Down

0 comments on commit 023fdb5

Please sign in to comment.