Created
April 7, 2026 13:30
-
-
Save bbelderbos/9f543590707298556a194aff80c85a22 to your computer and use it in GitHub Desktop.
This file contains hidden or bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
"""Finite state machine for a GitHub pull-request workflow.

A PR moves between the states listed in ``TRANSITIONS``; guards may veto a
move and hooks run after a move succeeds.

Full implementation and tests from:
https://belderbos.dev/blog/build-finite-state-machine-python

Run: uvx pytest fsm.py -v
"""

import pytest

# --- Implementation ---
# Allowed moves out of each PR state. A state mapped to an empty set has no
# outgoing transitions (terminal).
TRANSITIONS: dict[str, set[str]] = {
    "draft": {"closed", "open"},
    "open": {"approved", "changes_requested", "closed"},
    "changes_requested": {"closed", "open"},
    "approved": {"closed", "merged", "open"},
    "merged": set(),  # terminal: nothing follows a merge
    "closed": {"open"},  # a closed PR can be reopened
}
class InvalidTransition(Exception):
    """Raised when a requested state change is not allowed.

    Raised either because the target is not listed in the transition table
    for the current state, or because one or more guards returned a falsy
    value.
    """


class StateMachine:
    """A small finite state machine with guards, hooks and shared context.

    Args:
        initial: The state the machine starts in.
        transitions: Maps each state to the set of states it may move to.
            A state missing from the mapping has no outgoing transitions.
        guards: Callables ``guard(machine, to)`` evaluated for every move the
            table allows; all must return a truthy value for the move to go
            ahead. Defaults to no guards.
        hooks: Maps a target state to a list of callables
            ``hook(machine, from_state, to_state)`` that run, in list order,
            after the machine has entered that state. Defaults to no hooks.
        context: Dict shared with guards and hooks. The given dict is stored
            as-is (not copied). Defaults to a new empty dict.
    """

    def __init__(
        self,
        initial: str,
        transitions: dict[str, set[str]],
        *,
        guards=None,
        hooks=None,
        context=None,
    ):
        self._state = initial
        self._transitions = transitions
        self._guards = guards if guards is not None else []
        self._hooks = hooks if hooks is not None else {}
        self._context = context if context is not None else {}
        self._history: list[tuple[str, str]] = []

    def __repr__(self) -> str:
        return f"{type(self).__name__}(current={self._state!r})"

    @property
    def current(self) -> str:
        """The state the machine is currently in."""
        return self._state

    @property
    def context(self) -> dict:
        """The shared context dict (the same object, not a copy)."""
        return self._context

    @property
    def history(self) -> list[tuple[str, str]]:
        """A copy of completed transitions as ``(from_state, to_state)`` pairs."""
        return list(self._history)

    def can_transition(self, to: str) -> bool:
        """Return True if the transition table allows moving to *to*.

        Guards are not evaluated here, so ``True`` does not guarantee that
        :meth:`transition` will succeed.
        """
        return to in self._transitions.get(self._state, set())

    def transition(self, to: str) -> None:
        """Move to state *to*, then run the hooks registered for *to*.

        The table is checked first, then every guard. All guards are evaluated
        so the error can name each failing one. State and history are updated
        before hooks run, so an exception raised by a hook propagates without
        rolling the transition back.

        Raises:
            InvalidTransition: If the table forbids the move or any guard
                returns a falsy value. State and history are left unchanged.
        """
        if not self.can_transition(to):
            raise InvalidTransition(f"Cannot go from {self._state} to {to}")
        errors = [
            self._guard_name(guard) for guard in self._guards if not guard(self, to)
        ]
        if errors:
            raise InvalidTransition(
                f"Cannot go from {self._state} to {to} "
                f"due to failed guards: {', '.join(errors)}"
            )
        previous = self._state
        self._state = to
        self._history.append((previous, to))
        for hook in self._hooks.get(to, []):
            hook(self, previous, to)

    @staticmethod
    def _guard_name(guard) -> str:
        """Return a readable label for *guard* for use in error messages.

        Plain functions expose ``__name__``. Other callables, such as
        ``functools.partial`` objects or instances with ``__call__``, may
        not, so fall back to their repr instead of raising AttributeError.
        """
        return getattr(guard, "__name__", None) or repr(guard)
# --- Tests ---


def test_initial_state():
    """A freshly built machine reports the state it was started in."""
    machine = StateMachine("draft", TRANSITIONS)
    assert machine.current == "draft"
def test_valid_transition():
    """An allowed move updates the current state."""
    pr = StateMachine("draft", TRANSITIONS)
    pr.transition("open")
    assert pr.current == "open"
def test_invalid_transition_raises():
    """Jumping straight from draft to merged is rejected."""
    machine = StateMachine("draft", TRANSITIONS)
    with pytest.raises(InvalidTransition):
        machine.transition("merged")
def test_can_transition_returns_bool():
    """can_transition answers with real booleans, not merely truthy values."""
    machine = StateMachine("draft", TRANSITIONS)
    for target, expected in (("open", True), ("merged", False)):
        assert machine.can_transition(target) is expected
def test_terminal_state_blocks_transitions():
    """Once merged, no further transition is permitted."""
    pr = StateMachine("approved", TRANSITIONS)
    pr.transition("merged")
    assert pr.current == "merged"
    with pytest.raises(InvalidTransition):
        pr.transition("open")
def test_closed_can_reopen():
    """A closed PR may be moved back to open."""
    pr = StateMachine("open", TRANSITIONS)
    for step in ("closed", "open"):
        pr.transition(step)
    assert pr.current == "open"
def test_history_tracks_transitions():
    """History records each (from, to) pair in order and is returned as a copy."""
    sm = StateMachine("draft", TRANSITIONS)
    sm.transition("open")
    sm.transition("approved")
    assert sm.history == [("draft", "open"), ("open", "approved")]
    # Mutating the returned list must not leak into the machine's own record.
    sm.history.append(("approved", "merged"))
    assert len(sm.history) == 2
def test_history_empty_initially():
    """No transitions have happened yet, so history is an empty list."""
    recorded = StateMachine("draft", TRANSITIONS).history
    assert recorded == []
def test_guard_blocks_transition_with_name_in_error():
    """A rejecting guard is named in the InvalidTransition message."""

    def always_fails(_machine, _target):
        return False

    machine = StateMachine("draft", TRANSITIONS, guards=[always_fails])
    with pytest.raises(InvalidTransition, match="always_fails"):
        machine.transition("open")
def test_guard_blocks_transition():
    """A failing guard rejects an otherwise-allowed move and leaves the PR untouched."""

    def needs_approval(sm, to):
        if to == "merged":
            return sm.context.get("approvals", 0) >= 1
        return True

    sm = StateMachine("approved", TRANSITIONS, guards=[needs_approval])
    # approved -> merged is allowed by TRANSITIONS, so only the guard can block it.
    assert sm.can_transition("merged") is True
    with pytest.raises(InvalidTransition, match="needs_approval"):
        sm.transition("merged")
    # A rejected transition must not change state or record history.
    assert sm.current == "approved"
    assert sm.history == []
def test_guard_allows_when_satisfied():
    """With enough approvals in context, the guard lets the merge through."""

    def needs_approval(machine, target):
        return target != "merged" or machine.context.get("approvals", 0) >= 1

    pr = StateMachine(
        "approved",
        TRANSITIONS,
        guards=[needs_approval],
        context={"approvals": 1},
    )
    pr.transition("merged")
    assert pr.current == "merged"
def test_multiple_guards_all_must_pass():
    """Every guard must pass; flipping the one failing flag unblocks the move."""

    def guard_one(machine, target):
        return machine.context.get("one", False)

    def guard_two(machine, target):
        return machine.context.get("two", False)

    pr = StateMachine(
        "approved",
        TRANSITIONS,
        guards=[guard_one, guard_two],
        context={"one": True, "two": False},
    )
    with pytest.raises(InvalidTransition):
        pr.transition("merged")

    pr.context["two"] = True
    pr.transition("merged")
    assert pr.current == "merged"
def test_hook_called_on_transition():
    """A hook registered for a state fires with (from, to) on entering it."""
    seen = []

    def on_open(machine, from_state, to_state):
        seen.append((from_state, to_state))

    pr = StateMachine("draft", TRANSITIONS, hooks={"open": [on_open]})
    pr.transition("open")
    assert seen == [("draft", "open")]
def test_hook_not_called_on_failed_transition():
    """Hooks must not fire when a transition is rejected, whether the
    rejection comes from the transition table or from a guard."""
    called = []

    def on_merged(sm, from_state, to_state):
        called.append("merged")

    # Rejected by the table: draft cannot go straight to merged.
    sm = StateMachine("draft", TRANSITIONS, hooks={"merged": [on_merged]})
    with pytest.raises(InvalidTransition):
        sm.transition("merged")
    assert called == []

    # Rejected by a guard: approved -> merged is in the table, but the guard refuses.
    def never(sm, to):
        return False

    sm = StateMachine(
        "approved", TRANSITIONS, guards=[never], hooks={"merged": [on_merged]}
    )
    with pytest.raises(InvalidTransition):
        sm.transition("merged")
    assert called == []
def test_multiple_hooks_for_same_state():
    """All hooks for a state run, in the order they were registered."""
    order = []

    def make_hook(label):
        def hook(machine, from_state, to_state):
            order.append(label)

        return hook

    pr = StateMachine(
        "draft", TRANSITIONS, hooks={"open": [make_hook("a"), make_hook("b")]}
    )
    pr.transition("open")
    assert order == ["a", "b"]
def test_hook_can_access_context():
    """Hooks receive the machine and can write into its context."""

    def track_reviewer(machine, from_state, to_state):
        machine.context["notified"] = True

    pr = StateMachine("draft", TRANSITIONS, hooks={"open": [track_reviewer]})
    pr.transition("open")
    assert pr.context.get("notified") is True
def test_context_default_empty():
    """Omitting context= yields an empty dict."""
    assert StateMachine("draft", TRANSITIONS).context == {}
def test_context_can_be_initialized():
    """Values passed via context= are readable from the machine."""
    seed = {"author": "alice"}
    pr = StateMachine("draft", TRANSITIONS, context=seed)
    assert pr.context["author"] == "alice"
def test_context_mutable_during_lifecycle():
    """Data stored in context survives a transition."""
    pr = StateMachine("draft", TRANSITIONS)
    pr.context["reviewers"] = ["bob"]
    pr.transition("open")
    assert pr.context["reviewers"] == ["bob"]
def test_full_pr_happy_path():
    """draft -> open -> approved -> merged, with a merge guard and notifying hooks."""
    notifications = []

    def notify(machine, from_state, to_state):
        notifications.append(f"{from_state}->{to_state}")

    def needs_approval(machine, target):
        return target != "merged" or machine.context.get("approvals", 0) >= 1

    pr = StateMachine(
        "draft",
        TRANSITIONS,
        guards=[needs_approval],
        hooks={state: [notify] for state in ("open", "approved", "merged")},
        context={"approvals": 0},
    )
    pr.transition("open")
    pr.transition("approved")
    pr.context["approvals"] = 1
    pr.transition("merged")

    assert pr.current == "merged"
    assert len(pr.history) == 3
    assert notifications == ["draft->open", "open->approved", "approved->merged"]
def test_pr_with_changes_requested():
    """A PR can go through changes_requested and back to open before approval."""
    pr = StateMachine("open", TRANSITIONS)
    for step in ("changes_requested", "open", "approved"):
        pr.transition(step)
    assert pr.current == "approved"
    assert len(pr.history) == 3
Sign up for free
to join this conversation on GitHub.
Already have an account?
Sign in to comment