Skip to content

Commit

Permalink
typing fixes and improvements
Browse files Browse the repository at this point in the history
  • Loading branch information
cmekik committed Nov 2, 2022
1 parent 3de7347 commit 61bf9ec
Show file tree
Hide file tree
Showing 6 changed files with 23 additions and 24 deletions.
14 changes: 8 additions & 6 deletions pyClarion/base/constructs.py
Original file line number Diff line number Diff line change
Expand Up @@ -6,7 +6,7 @@
from types import MappingProxyType
from contextvars import ContextVar
from typing import (Union, Tuple, Callable, Any, Sequence, Iterator, ClassVar,
List, OrderedDict)
List, OrderedDict, Generic, TypeVar)
from functools import partial
import logging

Expand Down Expand Up @@ -81,20 +81,22 @@ def _log_init(self, name: str) -> None:
logging.debug(f"Initializing {tname} '{name}' in '{context}'.")


class Module(Construct):
P = TypeVar("P", bound=Process)

class Module(Construct, Generic[P]):
"""An elementary module."""

_constant: ClassVar[float] = 0.0

_process: Process
_process: P
_inputs: List[Tuple[str, Callable]]
_i_uris: Tuple[str, ...]
_fs_uris: Tuple[str, ...]

def __init__(
self,
name: str,
process: Process,
process: P,
i_uris: Sequence[str] = (),
fs_uris: Sequence[str] = ()
) -> None:
Expand Down Expand Up @@ -124,12 +126,12 @@ def fs_uris(self) -> Tuple[str, ...]:
return self._fs_uris

@property
def process(self) -> Process:
def process(self) -> P:
"""Module process; issues updates and emits activations."""
return self._process

@process.setter
def process(self, process: Process) -> None:
def process(self, process: P) -> None:
self._process = process
process.prefix = uris.split_head(self.path.lstrip(uris.SEP))[1]

Expand Down
12 changes: 8 additions & 4 deletions pyClarion/components/basic.py
Original file line number Diff line number Diff line change
Expand Up @@ -33,7 +33,8 @@ class Receptors(cld.Process):
initial = nd.NumDict()

def __init__(
self, reprs: Union[List[str], Dict[str, List[Union[str, int]]]]
self,
reprs: Union[List[str], Dict[str, List[str]], Dict[str, List[int]]]
) -> None:
self._reprs = reprs
self._data: nd.NumDict[feature] = nd.NumDict()
Expand Down Expand Up @@ -95,7 +96,10 @@ class Actions(cld.Process):
initial = nd.NumDict()
_cmd_pre = "cmd"

def __init__(self, actions: Dict[str, List[Union[str, int]]]) -> None:
def __init__(
self,
actions: Union[Dict[str, List[str]], Dict[str, List[int]]]
) -> None:
self.actions = OrderedDict(actions)

def call(self, c: nd.NumDict[feature]) -> nd.NumDict[feature]:
Expand Down Expand Up @@ -357,11 +361,11 @@ def call(
s_r = (cr
.mul_from(d, kf=cld.first, strict=True)
.div(norm)
.sum_by(kf=cld.first)
.sum_by(kf=cld.second)
.squeeze())
s_c = (rc
.mul_from(s_r, kf=cld.first, strict=True)
.max_by(kf=cld.first)
.max_by(kf=cld.second)
.squeeze())

return s_c, s_r
Expand Down
2 changes: 1 addition & 1 deletion pyClarion/components/ms.py
Original file line number Diff line number Diff line change
Expand Up @@ -31,5 +31,5 @@ def call(
.mul_from(gains, kf=cld.eye)))

@property
def reprs(self) -> Tuple[NumDict[feature], ...]:
def reprs(self) -> Tuple[feature, ...]:
return tuple(feature(cld.prefix(d, self.prefix)) for d in self.dspec)
2 changes: 1 addition & 1 deletion pyClarion/dev.py
Original file line number Diff line number Diff line change
Expand Up @@ -44,7 +44,7 @@ def first(pair: Tuple[T, Any]) -> T:
return pair[0]


def second(pair: Tuple[T, Any]) -> T:
def second(pair: Tuple[Any, T]) -> T:
"""Return the second element in a pair."""
return pair[1]

Expand Down
1 change: 1 addition & 0 deletions pyClarion/numdicts/nn_ops.py
Original file line number Diff line number Diff line change
Expand Up @@ -56,6 +56,7 @@ def boltzmann(d: nd.NumDict, t: nd.NumDict) -> nd.NumDict:
vmax, _t = max(vs), t._c
# v - vmax is a stability trick; softmax(x) = softmax(x + c)
exp_v = [_exp((v - vmax) / _t) for v in vs]
assert len(exp_v)
sum_exp_v = sum(exp_v)
vals = [v / sum_exp_v for v in exp_v]
return nd.NumDict._new(m={k: v for k, v in zip(ks, vals)})
Expand Down
16 changes: 4 additions & 12 deletions pyClarion/numdicts/vec_ops.py
Original file line number Diff line number Diff line change
Expand Up @@ -51,16 +51,8 @@ def _grad_reduce_sum(grads, result, d, *, key):
return (dops.isolate(grads, key=key) * dops.mask(d),)


@overload
def matmul(d1: nd.NumDict[T], d2: nd.NumDict[T]) -> nd.NumDict[Any]:
...

@overload
def matmul(d1: nd.NumDict[T], d2: nd.NumDict[T], *, key: T) -> nd.NumDict[T]:
...

def matmul(d1, d2, *, key=None):
return reduce_sum(d1 * d2, key=key)
return reduce_sum(d1 * d2)


@overload
Expand Down Expand Up @@ -242,7 +234,7 @@ def _grad_max_by(
grads: nd.NumDict[T2], result: nd.NumDict[T2], d: nd.NumDict[T1], *,
kf: Callable[[T1], T2]
) -> Tuple[nd.NumDict]:
return (put(grads, d, kf=kf) * bops.isclose(d, put(d, result, kf=kf)),)
return (put(d, grads, kf=kf) * bops.isclose(d, put(d, result, kf=kf)),)


@gt.GradientTape.op()
Expand All @@ -258,9 +250,9 @@ def min_by(d: nd.NumDict[T1], *, kf: Callable[[T1], T2]) -> nd.NumDict[T2]:
@gt.GradientTape.grad(min_by)
def _grad_min_by(
grads: nd.NumDict[T2], result: nd.NumDict[T2], d: nd.NumDict[T1], *,
kf: Callable[[T2], T1]
kf: Callable[[T1], T2]
) -> Tuple[nd.NumDict]:
return (put(grads, d, kf=kf) * bops.isclose(d, put(d, result, kf=kf)),)
return (put(d, grads, kf=kf) * bops.isclose(d, put(d, result, kf=kf)),)


@gt.GradientTape.op()
Expand Down

0 comments on commit 61bf9ec

Please sign in to comment.