Skip to content

Commit

Permalink
Merge pull request #162 from ImogenBits/graph_util
Browse files Browse the repository at this point in the history
Graph utility methods
  • Loading branch information
Benezivas authored Jan 8, 2024
2 parents 42d2160 + cb16884 commit 0d15f5d
Show file tree
Hide file tree
Showing 2 changed files with 83 additions and 5 deletions.
81 changes: 77 additions & 4 deletions algobattle/types.py
Original file line number Diff line number Diff line change
@@ -1,5 +1,6 @@
"""Utility types used to easily define Problems."""
from dataclasses import dataclass
from functools import cache, cached_property
from sys import float_info
from typing import (
Annotated,
Expand All @@ -24,6 +25,7 @@
SupportsLt,
SupportsMod,
)
from itertools import pairwise

from pydantic import GetCoreSchemaHandler, GetJsonSchemaHandler
from pydantic.json_schema import JsonSchemaValue
Expand Down Expand Up @@ -62,6 +64,7 @@
"DirectedGraph",
"UndirectedGraph",
"Edge",
"Path",
"EdgeLen",
"EdgeWeights",
"VertexWeights",
Expand Down Expand Up @@ -411,6 +414,24 @@ def __get_pydantic_core_schema__(cls, source_type: Any, handler: GetCoreSchemaHa
# * Graph classes


Vertex = SizeIndex
"""Type for vertices, encoded as numbers `0 <= v < instance.num_vertices`."""


Edge = Annotated[int, IndexInto[InstanceRef.edges]]
"""Type for edges, encoded as indices into `instance.edges`."""


def path_in_graph(path: list[Vertex], edge_set: set[tuple[Vertex, Vertex]]):
"""Checks that a path actually exists in the graph."""
for edge in pairwise(path):
if edge not in edge_set:
raise ValueError(f"The edge {edge} does not exist in the graph.")


Path = Annotated[list[Vertex], AttributeReferenceValidator(path_in_graph, InstanceRef.edge_set)]


class DirectedGraph(InstanceModel):
"""Base instance class for problems on directed graphs."""

Expand All @@ -422,6 +443,21 @@ def size(self) -> int:
"""A graph's size is the number of vertices in it."""
return self.num_vertices

@cached_property
def edge_set(self) -> set[tuple[Vertex, Vertex]]:
"""The set of edges in this graph."""
return set(self.edges)

@cache
def neighbors(self, vertex: Vertex, direction: Literal["all", "outgoing", "incoming"] = "all") -> set[Vertex]:
"""The neighbors of a vertex."""
res = set[Vertex]()
if direction in {"all", "outgoing"}:
res |= set(v for (u, v) in self.edges if u == vertex)
if direction in {"all", "incoming"}:
res |= set(v for (v, u) in self.edges if u == vertex)
return res


class UndirectedGraph(DirectedGraph):
"""Base instance class for problems on undirected graphs."""
Expand All @@ -440,13 +476,20 @@ def validate_instance(self):
if any(edge[::-1] in edge_set for edge in self.edges):
raise ValidationError("Undirected graph contains back and forth edges between two vertices.")

@cached_property
def edge_set(self) -> set[tuple[Vertex, Vertex]]:
"""The set of edges in this graph.
Vertex = SizeIndex
"""Type for vertices, encoded as numbers `0 <= v < instance.num_vertices`."""
Normalized to contain every edge in both directions.
"""
return set(self.edges) | set((v, u) for (u, v) in self.edges)

@cache
def neighbors(self, vertex: Vertex, direction: Literal["all", "outgoing", "incoming"] = "all") -> set[Vertex]:
"""The neighbors of a vertex."""
# more efficient specialization

Edge = IndexInto[InstanceRef.edges]
"""Type for edges, encoded as indices into `instance.edges`."""
return set(v for (u, v) in self.edge_set if u == vertex)


class EdgeLen:
Expand Down Expand Up @@ -477,12 +520,42 @@ class EdgeWeights(DirectedGraph, BaseModel, Generic[Weight]):

edge_weights: Annotated[list[Weight], EdgeLen]

@cached_property
def edges_with_weights(self) -> Iterator[tuple[tuple[Vertex, Vertex], Weight]]:
"""Iterate over all edges and their weights."""
return zip(self.edges, self.edge_weights)

@cache
def weight(self, edge: Edge | tuple[Vertex, Vertex]) -> Weight:
"""Returns the weight of an edge.
Raises KeyError if the given edge does not exist.
"""
if isinstance(edge, tuple):
try:
edge = self.edges.index(edge)
except ValueError:
if isinstance(self, UndirectedGraph):
try:
edge = self.edges.index((edge[1], edge[0]))
except ValueError:
raise KeyError
else:
raise KeyError

return self.edge_weights[edge]


class VertexWeights(DirectedGraph, BaseModel, Generic[Weight]):
"""Mixin for graphs with weighted vertices."""

vertex_weights: Annotated[list[Weight], SizeLen]

@cached_property
def vertices_with_weights(self) -> Iterator[tuple[Vertex, Weight]]:
"""Iterate over all edges and their weights."""
return enumerate(self.vertex_weights)


@dataclass(frozen=True, slots=True)
class LaxComp:
Expand Down
7 changes: 6 additions & 1 deletion docs/instructor/problem/advanced.md
Original file line number Diff line number Diff line change
Expand Up @@ -172,7 +172,7 @@ directionless. Both graph's size is the number of vertices in it.

!!! tip "Associated Annotation Types"
As you can see in the example above, we also provide several types that are useful in type annotations of graph
problems such as `Vertex` or `Edge`. These are documented in more detail in the
problems such as `Vertex`, `Edge`, or `Path`. How these function is explained in more detail in the
[advanced annotations](annotations.md) section.

If you want the problem instance to also contain additional information associated with each vertex and/or each edge
Expand Down Expand Up @@ -201,6 +201,11 @@ indexed with the type of the weights you want to use.
...
```

!!! tip
These classes also contain some utility methods to easily perform common graph operations. For example,
`UndirectedGraph.edge_set` contains all edges in both directions, and the `neighbors` methods lets you quickly
access a vertex's neighbours.

## Comparing Floats

!!! abstract
Expand Down

0 comments on commit 0d15f5d

Please sign in to comment.