Skip to content

Commit

Permalink
Allow adding and removing areas
Browse files Browse the repository at this point in the history
  • Loading branch information
wernerhp committed Mar 26, 2024
1 parent 966fa6a commit 0a9acc2
Show file tree
Hide file tree
Showing 5 changed files with 110 additions and 94 deletions.
26 changes: 25 additions & 1 deletion custom_components/load_shedding/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -84,7 +84,7 @@ async def async_setup_entry(hass: HomeAssistant, config_entry: ConfigEntry) -> b
area_coordinator.update_interval = timedelta(
seconds=config_entry.options.get(CONF_SCAN_INTERVAL, DEFAULT_SCAN_INTERVAL)
)
for conf in config_entry.options.get(CONF_AREAS, {}).values():
for conf in config_entry.options.get(CONF_AREAS, []):
area = Area(
id=conf.get(CONF_ID),
name=conf.get(CONF_NAME),
Expand Down Expand Up @@ -147,6 +147,30 @@ async def async_migrate_entry(hass: HomeAssistant, config_entry: ConfigEntry) ->
config_entry, data=new_data, options=new_options
)

if config_entry.version == 4:
old_data = {**config_entry.data}
old_options = {**config_entry.options}
new_data = {}
new_options = {
CONF_API_KEY: old_options.get(CONF_API_KEY),
CONF_AREAS: [],
}
for field in old_options:
if field == CONF_AREAS:
areas = old_options.get(CONF_AREAS, {})
for area_id in areas:
new_options[CONF_AREAS].append(areas[area_id])
continue

value = old_options.get(field)
if value is not None:
new_options[field] = value

config_entry.version = 5
hass.config_entries.async_update_entry(
config_entry, data=new_data, options=new_options
)

_LOGGER.info("Migration to version %s successful", config_entry.version)
return True

Expand Down
158 changes: 73 additions & 85 deletions custom_components/load_shedding/config_flow.py
Original file line number Diff line number Diff line change
Expand Up @@ -6,7 +6,12 @@

import voluptuous as vol
from homeassistant import config_entries
from homeassistant.config_entries import ConfigEntry, ConfigFlow, OptionsFlow
from homeassistant.config_entries import (
ConfigEntry,
ConfigFlow,
OptionsFlow,
OptionsFlowWithConfigEntry,
)
from homeassistant.const import CONF_API_KEY, CONF_DESCRIPTION, CONF_ID, CONF_NAME
from homeassistant.core import callback
from homeassistant.data_entry_flow import FlowResult, FlowHandler
Expand All @@ -17,7 +22,6 @@
from .const import (
CONF_AREA_ID,
CONF_AREAS,
CONF_PROVIDER,
CONF_SEARCH,
DOMAIN,
NAME,
Expand All @@ -36,13 +40,12 @@
class LoadSheddingFlowHandler(ConfigFlow, domain=DOMAIN):
"""Config flow for LoadShedding."""

VERSION = 4
VERSION = 5

def __init__(self):
self.provider: Provider = None
self.api_key: str = ""
self.areas: dict = {}
# self.device_unique_id = f"{DOMAIN}"

@staticmethod
@callback
Expand Down Expand Up @@ -86,7 +89,7 @@ async def async_step_sepush(
# Validate the token by checking the allowance.
sepush = SePush(token=self.api_key)
await self.hass.async_add_executor_job(sepush.check_allowance)
except (SePushError) as err:
except SePushError as err:
status_code = err.__cause__.args[0]
if status_code == 400:
errors["base"] = "sepush_400"
Expand Down Expand Up @@ -212,27 +215,15 @@ async def async_step_select_area(
data = {}
options = {
CONF_API_KEY: self.api_key,
CONF_AREAS: {
area.id: {
CONF_AREAS: [
{
CONF_DESCRIPTION: description,
CONF_NAME: area.name,
CONF_ID: area.id,
},
},
],
}

# entry = await self.async_set_unique_id(DOMAIN)
# if entry:
# try:
# _LOGGER.debug("Entry exists: %s", entry)
# if self.hass.config_entries.async_update_entry(entry, data=data):
# await self.hass.config_entries.async_reload(entry.entry_id)
# except Exception:
# _LOGGER.debug("Unknown error", exc_info=True)
# raise
# else:
# return self.async_abort(reason=FlowResultType.SHOW_PROGRESS_DONE)

return self.async_create_entry(
title=NAME,
data=data,
Expand All @@ -241,14 +232,12 @@ async def async_step_select_area(
)


class LoadSheddingOptionsFlowHandler(OptionsFlow):
class LoadSheddingOptionsFlowHandler(OptionsFlowWithConfigEntry):
"""Load Shedding config flow options handler."""

def __init__(self, config_entry: ConfigEntry) -> None:
"""Initialize options flow."""
# self.config_entry: ConfigEntry = config_entry
self.opts = dict(config_entry.options)

super().__init__(config_entry)
self.provider = Provider.SE_PUSH
self.api_key = config_entry.options.get(CONF_API_KEY)
self.areas = {}
Expand All @@ -260,36 +249,38 @@ async def async_step_init(

CONF_ACTIONS = {
CONF_SETUP_API: "Configure API",
# CONF_ADD_AREA: "Add area",
# CONF_DELETE_AREA: "Remove area",
# CONF_MULTI_STAGE_EVENTS: ""
CONF_ADD_AREA: "Add area",
CONF_DELETE_AREA: "Remove area",
}

if user_input is not None:
if user_input.get(CONF_ACTION) == CONF_SETUP_API:
return await self.async_step_sepush()
if user_input.get(CONF_ACTION) == CONF_ADD_AREA:
return await self.async_step_add_area()
if user_input.get(CONF_ACTION) == CONF_DELETE_AREA:
return await self.async_step_delete_area()
self.options[CONF_MULTI_STAGE_EVENTS] = user_input.get(
CONF_MULTI_STAGE_EVENTS
)
self.options[CONF_MIN_EVENT_DURATION] = user_input.get(
CONF_MIN_EVENT_DURATION
)
return self.async_create_entry(title=NAME, data=self.options)

OPTIONS_SCHEMA = vol.Schema(
{
vol.Optional(CONF_ACTION): vol.In(CONF_ACTIONS),
vol.Optional(
CONF_MULTI_STAGE_EVENTS,
default=self.opts.get(CONF_MULTI_STAGE_EVENTS, False),
default=self.options.get(CONF_MULTI_STAGE_EVENTS, True),
): bool,
vol.Optional(
CONF_MIN_EVENT_DURATION,
default=self.opts.get(CONF_MIN_EVENT_DURATION, 30),
default=self.options.get(CONF_MIN_EVENT_DURATION, 31),
): int,
}
)

if user_input is not None:
if user_input.get(CONF_ACTION) == CONF_SETUP_API:
return await self.async_step_sepush()
if user_input.get(CONF_ACTION) == CONF_ADD_AREA:
return await self.async_step_add_area()
# if user_input.get(CONF_ACTION) == CONF_DELETE_AREA:
# return await self.async_step_delete_area()
self.opts[CONF_MULTI_STAGE_EVENTS] = user_input.get(CONF_MULTI_STAGE_EVENTS)
self.opts[CONF_MIN_EVENT_DURATION] = user_input.get(CONF_MIN_EVENT_DURATION)
return self.async_create_entry(title=NAME, data=self.opts)

return self.async_show_form(
step_id="init",
data_schema=OPTIONS_SCHEMA,
Expand All @@ -312,7 +303,7 @@ async def async_step_sepush(
sepush = SePush(token=api_key)
esp = await self.hass.async_add_executor_job(sepush.check_allowance)
_LOGGER.debug("Validate API Key Response: %s", esp)
except (SePushError) as err:
except SePushError as err:
status_code = err.__cause__.args[0]
if status_code == 400:
errors["base"] = "sepush_400"
Expand All @@ -326,8 +317,8 @@ async def async_step_sepush(
errors["base"] = "provider_error"
else:
self.api_key = api_key
self.opts[CONF_API_KEY] = api_key
return self.async_create_entry(title=NAME, data=self.opts)
self.options[CONF_API_KEY] = api_key
return self.async_create_entry(title=NAME, data=self.options)

data_schema = vol.Schema(
{
Expand All @@ -340,37 +331,6 @@ async def async_step_sepush(
errors=errors,
)

# async def async_step_init(
# self, user_input: dict[str, Any] | None = None
# ) -> FlowResult: # pylint: disable=unused-argument
# """Manage the options."""

# CONF_ACTIONS = {
# CONF_ADD_DEVICE: "Add Area",
# CONF_EDIT_DEVICE: "Remove Area",
# }

# CONFIGURE_SCHEMA = vol.Schema(
# {
# vol.Required(CONF_ACTION, default=CONF_ADD_DEVICE): vol.In(
# CONF_ACTIONS
# ),
# }
# )

# return self.async_show_form(step_id="init", data_schema=vol.Schema(schema))

# schema: dict[vol.Marker, type] = {}
# areas = self.opts.get(CONF_AREAS, {})
# for area_id, area in areas.items():
# schema[vol.Required(area_id, default=True)] = vol.In(
# {area_id: area.get(CONF_DESCRIPTION)}
# )

# return self.async_show_form(step_id="init", data_schema=vol.Schema(schema))

# return await self.async_step_lookup_areas()

async def async_step_add_area(
self, user_input: dict[str, Any] | None = None
) -> FlowResult:
Expand Down Expand Up @@ -469,7 +429,6 @@ async def async_step_select_area(
self, user_input: dict[str, Any] | None = None
) -> FlowResult:
"""Handle the flow step to create a area."""
areas = self.opts.get(CONF_AREAS, {})
area = self.areas.get(user_input.get(CONF_AREA_ID))

description = f"{area.name}"
Expand All @@ -478,16 +437,45 @@ async def async_step_select_area(
if area.province is not Province.UNKNOWN:
description += f", {area.province}"

areas[area.id] = {
CONF_DESCRIPTION: description,
CONF_NAME: area.name,
CONF_ID: area.id,
}

self.opts.update(
self.options[CONF_AREAS].append(
{
CONF_AREAS: areas,
CONF_DESCRIPTION: description,
CONF_NAME: area.name,
CONF_ID: area.id,
}
)
result = self.async_create_entry(title=NAME, data=self.opts)

result = self.async_create_entry(title=NAME, data=self.options)
return result

async def async_step_delete_area(
self, user_input: dict[str, Any] | None = None
) -> FlowResult:
"""Handle the flow step to delete an area."""
errors = None
if user_input is None:
area_idx = {}
for idx, area in enumerate(self.options.get(CONF_AREAS, [])):
area_idx[idx] = area.get(CONF_NAME)

if not errors:
data_schema = vol.Schema(
{
vol.Optional(CONF_AREA_ID): vol.In(area_idx),
}
)

return self.async_show_form(
step_id="delete_area",
data_schema=data_schema,
errors=errors,
)
else:
new_areas = []
for idx, area in enumerate(self.options.get(CONF_AREAS, [])):
if idx == user_input.get(CONF_AREA_ID):
continue
new_areas.append(area)

self.options[CONF_AREAS] = new_areas
return self.async_create_entry(title=NAME, data=self.options)
3 changes: 2 additions & 1 deletion custom_components/load_shedding/const.py
Original file line number Diff line number Diff line change
Expand Up @@ -9,7 +9,7 @@
MAX_FORECAST_DAYS: Final = 7
NAME: Final = "Load Shedding"
MANUFACTURER: Final = "@wernerhp"
VERSION: Final = "1.1.0"
VERSION: Final = "2.0.0"
DEFAULT_SCAN_INTERVAL: Final = 60
AREA_UPDATE_INTERVAL: Final = 86400 # 60sec * 60min * 24h / daily
QUOTA_UPDATE_INTERVAL: Final = 1800 # 60sec * 30min
Expand Down Expand Up @@ -43,6 +43,7 @@

ATTR_AREA: Final = "area"
ATTR_AREAS: Final = "areas"
ATTR_AREA_ID: Final = "area_id"
ATTR_CURRENT: Final = "current"
ATTR_END_IN: Final = "ends_in"
ATTR_END_TIME: Final = "end_time"
Expand Down
2 changes: 1 addition & 1 deletion custom_components/load_shedding/manifest.json
Original file line number Diff line number Diff line change
Expand Up @@ -11,5 +11,5 @@
"dependencies": [],
"codeowners": ["@wernerhp"],
"iot_class": "cloud_polling",
"version": "1.2.0"
"version": "2.0.0"
}
15 changes: 9 additions & 6 deletions custom_components/load_shedding/sensor.py
Original file line number Diff line number Diff line change
Expand Up @@ -24,6 +24,7 @@
from . import LoadSheddingDevice
from .const import (
ATTR_AREA,
ATTR_AREA_ID,
ATTR_END_IN,
ATTR_END_TIME,
ATTR_FORECAST,
Expand Down Expand Up @@ -79,8 +80,8 @@ async def async_setup_entry(
stage_entity = LoadSheddingStageSensorEntity(stage_coordinator, idx)
entities.append(stage_entity)

for area in area_coordinator.areas:
area_entity = LoadSheddingAreaSensorEntity(area_coordinator, area)
for idx, area in enumerate(area_coordinator.areas):
area_entity = LoadSheddingAreaSensorEntity(area_coordinator, area, idx + 1)
entities.append(area_entity)

quota_entity = LoadSheddingQuotaSensorEntity(quota_coordinator)
Expand Down Expand Up @@ -205,7 +206,9 @@ class LoadSheddingAreaSensorEntity(

coordinator: CoordinatorEntity

def __init__(self, coordinator: CoordinatorEntity, area: Area) -> None:
def __init__(
self, coordinator: CoordinatorEntity, area: Area, area_idx: int
) -> None:
"""Initialize."""
super().__init__(coordinator)
self.area = area
Expand All @@ -218,7 +221,7 @@ def __init__(self, coordinator: CoordinatorEntity, area: Area) -> None:
entity_registry_enabled_default=True,
)
self._attr_unique_id = (
f"{self.coordinator.config_entry.entry_id}_sensor_{area.id}"
f"{self.coordinator.config_entry.entry_id}_sensor_area_{area.id}"
)
self.entity_id = f"{DOMAIN}.{DOMAIN}_area_{area.id}"

Expand Down Expand Up @@ -293,10 +296,10 @@ def extra_state_attributes(self) -> dict[str, list, Any]:
forecast = data[ATTR_FORECAST]

attrs = get_sensor_attrs(forecast)
attrs[ATTR_AREA_ID] = self.area.id
attrs[ATTR_FORECAST] = forecast
attrs[ATTR_LAST_UPDATE] = self.coordinator.last_update
attrs = clean(attrs)

self._attr_extra_state_attributes.update(attrs)
return self._attr_extra_state_attributes

Expand Down Expand Up @@ -435,7 +438,7 @@ def get_sensor_attrs(forecast: list, stage: Stage = Stage.NO_LOAD_SHEDDING) -> d

def clean(data: dict) -> dict:
"""Remove default values from dict"""
for (key, value) in CLEAN_DATA.items():
for key, value in CLEAN_DATA.items():
if key not in data:
continue
if data[key] == value:
Expand Down

0 comments on commit 0a9acc2

Please sign in to comment.