Skip to content

Commit

Permalink
Merge pull request #5 from SNEWS2/publish-to-pypi
Browse files Browse the repository at this point in the history
Lightly refactor message classes
  • Loading branch information
justinvasel authored Aug 16, 2024
2 parents 73a0227 + 745dc88 commit 23cdcd1
Show file tree
Hide file tree
Showing 13 changed files with 227 additions and 463 deletions.
5 changes: 4 additions & 1 deletion snews/__init__.py
Original file line number Diff line number Diff line change
@@ -1,3 +1,6 @@
# -*- coding: utf-8 -*-
from .data import detectors
from .models import messages, timing
from .schema import SNEWSJsonSchema

__all__ = ["data", "models", "schemas"]
__all__ = ["detectors", "messages", "timing", "SNEWSJsonSchema"]
3 changes: 3 additions & 0 deletions snews/__main__.py
Original file line number Diff line number Diff line change
@@ -1,6 +1,7 @@
# -*- coding: utf-8 -*-

# Standard modules
import inspect
import json
import logging
from pathlib import Path
Expand Down Expand Up @@ -32,6 +33,8 @@ def generate_model_schemas(outdir: str = None, models_module: list = models, dry
model_class = getattr(models, model_class_name)
for model_name in model_class.__all__:
model = getattr(model_class, model_name)
if not inspect.isclass(model):
continue

if not issubclass(model, BaseModel):
continue
Expand Down
141 changes: 83 additions & 58 deletions snews/examples/tutorial.ipynb

Large diffs are not rendered by default.

116 changes: 89 additions & 27 deletions snews/models/messages.py
Original file line number Diff line number Diff line change
@@ -1,11 +1,4 @@
# -*- coding: utf-8 -*-
__all__ = [
"HeartBeat",
"Retraction",
"CoincidenceTierMessage",
"SignificanceTierMessage",
"TimingTierMessage"
]

# Standard library modules
from datetime import UTC, datetime, timedelta
Expand All @@ -15,15 +8,24 @@

# Third-party modules
import numpy as np
from pydantic import (UUID4, BaseModel, Field, NonNegativeFloat,
field_validator, model_validator, root_validator,
validator)
from pydantic import (BaseModel, Field, NonNegativeFloat, ValidationError,
field_validator, model_validator)

# Local modules
from ..__version__ import schema_version
from ..data import detectors
from ..models.timing import PrecisionTimestamp

__all__ = [
"HeartbeatMessage",
"RetractionMessage",
"CoincidenceTierMessage",
"SignificanceTierMessage",
"TimingTierMessage",
"compatible_message_types",
"create_messages",
]


# .................................................................................................
def convert_timestamp_to_ns_precision(timestamp: Union[str, datetime, np.datetime64]) -> str:
Expand Down Expand Up @@ -68,10 +70,11 @@ class Config:
description="Textual identifier for the message"
)

uid: UUID4 = Field(
uuid: str = Field(
title="Unique message ID",
default_factory=uuid4,
description="Unique identifier for the message"
description="Unique identifier for the message",
validate_default=True
)

tier: Tier = Field(
Expand All @@ -83,13 +86,15 @@ class Config:
sent_time_utc: Optional[str] = Field(
default=None,
title="Sent time (UTC)",
description="Time the message was sent in ISO 8601-1:2019 format"
description="Time the message was sent in ISO 8601-1:2019 format",
validate_default=True
)

machine_time_utc: Optional[str] = Field(
default=None,
title="Machine time (UTC)",
description="Time of the event at the detector in ISO 8601-1:2019 format"
description="Time of the event at the detector in ISO 8601-1:2019 format",
validate_default=True
)

is_pre_sn: Optional[bool] = Field(
Expand Down Expand Up @@ -123,14 +128,21 @@ class Config:
frozen=True,
)

@validator("sent_time_utc", "machine_time_utc", pre=True, always=True)
@field_validator("sent_time_utc", "machine_time_utc", mode="before")
def _convert_timestamp_to_ns_precision(cls, v):
"""
Convert to nanosecond precision (before running Pydantic validators).
"""
if v is not None:
return convert_timestamp_to_ns_precision(timestamp=v)

@field_validator("uuid", mode="before")
def _cast_uuid_to_string(cls, v):
"""
Cast UUID to string (before running Pydantic validators).
"""
return str(v)

@model_validator(mode="after")
def _format_id(self):
"""
Expand All @@ -143,6 +155,18 @@ def _format_id(self):

return self

def fields(self):
"""
Return a list of fields for the message.
"""
return list(self.model_fields.keys())

def required_fields(self):
"""
Return a list of required fields for the message.
"""
return [k for k, v in self.model_fields.items() if v.is_required()]


# .................................................................................................
class DetectorMessageBase(MessageBase):
Expand Down Expand Up @@ -171,7 +195,7 @@ def _validate_detector_name(self) -> str:


# .................................................................................................
class HeartBeat(DetectorMessageBase):
class HeartbeatMessage(DetectorMessageBase):
"""
Heartbeat detector message.
"""
Expand All @@ -186,7 +210,7 @@ class Config:
examples=["ON", "OFF"]
)

@root_validator(pre=True)
@model_validator(mode="before")
def _set_tier(cls, values):
values['tier'] = Tier.HEART_BEAT
return values
Expand All @@ -204,15 +228,15 @@ def _validate_model(self):


# .................................................................................................
class Retraction(DetectorMessageBase):
class RetractionMessage(DetectorMessageBase):
"""
Retraction detector message.
"""

class Config:
validate_assignment = True

retract_message_uid: Optional[UUID4] = Field(
retract_message_uid: Optional[str] = Field(
default=None,
title="Unique message ID",
description="Unique identifier for the message to retract"
Expand All @@ -230,25 +254,25 @@ class Config:
description="Reason for retraction",
)

@root_validator(pre=True)
@model_validator(mode="before")
def _set_tier(cls, values):
values['tier'] = Tier.RETRACTION
return values

@model_validator(mode="after")
def _validate_model(self):
if self.retract_latest and self.retract_message_uid is not None:
raise ValueError("retract_message_uid cannot be specified when retract_latest=True")
raise ValueError("retract_message_uuid cannot be specified when retract_latest=True")

if not self.retract_latest and self.retract_message_uid is None:
raise ValueError("Must specify either retract_message_uid or retract_latest=True")
raise ValueError("Must specify either retract_message_uuid or retract_latest=True")
return self


# .................................................................................................
class TierMessageBase(DetectorMessageBase):
"""
Tier detector base message
Tier base message
"""

class Config:
Expand Down Expand Up @@ -276,13 +300,13 @@ class TimingTierMessage(TierMessageBase):
class Config:
validate_assignment = True

timing_series: List[str] = Field(
timing_series: List[Union[str, int]] = Field(
...,
title="Timing Series",
description="Timing series of the event",
)

@root_validator(pre=True)
@model_validator(mode="before")
def _set_tier(cls, values):
values['tier'] = Tier.TIMING_TIER
return values
Expand Down Expand Up @@ -322,7 +346,7 @@ class Config:
description="Time bin width of the event",
)

@root_validator(pre=True)
@model_validator(mode="before")
def _set_tier(cls, values):
values['tier'] = Tier.SIGNIFICANCE_TIER
return values
Expand Down Expand Up @@ -358,7 +382,7 @@ class Config:
description="Time of the first neutrino in the event in ISO 8601-1:2019 format"
)

@root_validator(pre=True)
@model_validator(mode="before")
def _set_tier(cls, values):
values['tier'] = Tier.COINCIDENCE_TIER
return values
Expand All @@ -384,3 +408,41 @@ def _validate_neutrino_time(self):
raise ValueError("neutrino_time_utc must be in the past")

return self


# .................................................................................................
def compatible_message_types(**kwargs) -> list:
"""
Return a list of message types that are compatible with the given keyword arguments.
"""

message_types = [
HeartbeatMessage,
RetractionMessage,
CoincidenceTierMessage,
SignificanceTierMessage,
TimingTierMessage,
]

compatible_message_types = []
for message_type in message_types:
try:
message_type(**kwargs)
compatible_message_types.append(message_type)
except ValidationError:
pass

return compatible_message_types


# .................................................................................................
def create_messages(**kwargs) -> list:
"""
Return a list of messages initialized with the given keyword arguments.
"""

messages = []
for message_type in compatible_message_types(**kwargs):
messages.append(message_type(**kwargs))

return messages
7 changes: 3 additions & 4 deletions snews/schema/CoincidenceTierMessage.schema.json
Original file line number Diff line number Diff line change
@@ -1,7 +1,7 @@
{
"$schema": "https://json-schema.org/draft/2020-12/schema",
"schema_author": "Supernova Early Warning System (SNEWS)",
"schema_version": "1a",
"schema_version": "0.1",
"$defs": {
"Tier": {
"enum": [
Expand Down Expand Up @@ -30,9 +30,8 @@
"description": "Textual identifier for the message",
"title": "Human-readable message ID"
},
"uid": {
"uuid": {
"description": "Unique identifier for the message",
"format": "uuid4",
"title": "Unique message ID",
"type": "string"
},
Expand Down Expand Up @@ -132,7 +131,7 @@
"type": "null"
}
],
"default": "1a",
"default": "0.1",
"description": "Schema version of the message",
"title": "Schema Version"
},
Expand Down
2 changes: 1 addition & 1 deletion snews/schema/Detector.schema.json
Original file line number Diff line number Diff line change
@@ -1,7 +1,7 @@
{
"$schema": "https://json-schema.org/draft/2020-12/schema",
"schema_author": "Supernova Early Warning System (SNEWS)",
"schema_version": "1a",
"schema_version": "0.1",
"$defs": {
"DetectorType": {
"enum": [
Expand Down
Original file line number Diff line number Diff line change
@@ -1,7 +1,7 @@
{
"$schema": "https://json-schema.org/draft/2020-12/schema",
"schema_author": "Supernova Early Warning System (SNEWS)",
"schema_version": "1a",
"schema_version": "0.1",
"$defs": {
"Tier": {
"enum": [
Expand Down Expand Up @@ -30,9 +30,8 @@
"description": "Textual identifier for the message",
"title": "Human-readable message ID"
},
"uid": {
"uuid": {
"description": "Unique identifier for the message",
"format": "uuid4",
"title": "Unique message ID",
"type": "string"
},
Expand Down Expand Up @@ -132,7 +131,7 @@
"type": "null"
}
],
"default": "1a",
"default": "0.1",
"description": "Schema version of the message",
"title": "Schema Version"
},
Expand All @@ -156,6 +155,6 @@
"detector_name",
"detector_status"
],
"title": "HeartBeat",
"title": "HeartbeatMessage",
"type": "object"
}
Loading

0 comments on commit 23cdcd1

Please sign in to comment.