"""Tests for bin/amc-hook functions.

These are unit tests for the pure functions in the hook script.
Edge cases are prioritized over happy paths.
"""

import json
import os
import sys
import tempfile
import types
import unittest
from pathlib import Path
from unittest.mock import patch

# Import the hook module (no .py extension, so use compile+exec pattern)
hook_path = Path(__file__).parent.parent / "bin" / "amc-hook"
amc_hook = types.ModuleType("amc_hook")
amc_hook.__file__ = str(hook_path)
# Load module code - this is safe, we're loading our own source file
code = compile(hook_path.read_text(), str(hook_path), "exec")
exec(code, amc_hook.__dict__)  # noqa: S102 - loading local module


class TestDetectProseQuestion(unittest.TestCase):
    """Tests for _detect_prose_question edge cases."""

    def test_none_input_returns_none(self):
        self.assertIsNone(amc_hook._detect_prose_question(None))

    def test_empty_string_returns_none(self):
        self.assertIsNone(amc_hook._detect_prose_question(""))

    def test_whitespace_only_returns_none(self):
        self.assertIsNone(amc_hook._detect_prose_question(" \n\t "))

    def test_no_question_mark_returns_none(self):
        self.assertIsNone(amc_hook._detect_prose_question("This is a statement."))

    def test_question_mark_in_middle_not_at_end_returns_none(self):
        # Question mark exists but message doesn't END with one
        self.assertIsNone(amc_hook._detect_prose_question("What? I said hello."))

    def test_trailing_whitespace_after_question_still_detects(self):
        result = amc_hook._detect_prose_question("Is this a question? \n\t")
        self.assertEqual(result, "Is this a question?")

    def test_question_in_last_paragraph_only(self):
        msg = "First paragraph here.\n\nSecond paragraph is the question?"
        result = amc_hook._detect_prose_question(msg)
        self.assertEqual(result, "Second paragraph is the question?")

    def test_multiple_paragraphs_question_not_in_last_returns_none(self):
        # Question in first paragraph, statement in last
        msg = "Is this a question?\n\nNo, this is the last paragraph."
        self.assertIsNone(amc_hook._detect_prose_question(msg))

    def test_truncates_long_question_to_max_length(self):
        # Sized relative to MAX_QUESTION_LEN so the test keeps exercising truncation
        # even if the limit is changed.
        long_question = "x" * (amc_hook.MAX_QUESTION_LEN + 100) + "?"
        result = amc_hook._detect_prose_question(long_question)
        self.assertIsNotNone(result)
        self.assertLessEqual(len(result), amc_hook.MAX_QUESTION_LEN + 1)  # +1 for ?

    def test_long_question_tries_sentence_boundary(self):
        # Build a message longer than MAX_QUESTION_LEN that contains a sentence boundary.
        # The truncation takes the LAST MAX_QUESTION_LEN chars, then finds the FIRST ". "
        # within that window.
        prefix = "a" * amc_hook.MAX_QUESTION_LEN + ". Sentence start. "
        suffix = "Is this the question?"
        msg = prefix + suffix
        self.assertGreater(len(msg), amc_hook.MAX_QUESTION_LEN)
        result = amc_hook._detect_prose_question(msg)
        self.assertIsNotNone(result)
        # Code finds FIRST ". " in truncated portion, so starts at "Sentence start"
        self.assertTrue(
            result.startswith("Sentence start"),
            f"Expected to start with 'Sentence start', got: {result[:50]}",
        )

    def test_long_question_no_sentence_boundary_truncates_from_end(self):
        # No period in the long text
        long_msg = "a" * (amc_hook.MAX_QUESTION_LEN + 100) + "?"
        result = amc_hook._detect_prose_question(long_msg)
        self.assertIsNotNone(result)
        self.assertTrue(result.endswith("?"))
        self.assertLessEqual(len(result), amc_hook.MAX_QUESTION_LEN + 1)

    def test_single_character_question(self):
        result = amc_hook._detect_prose_question("?")
        self.assertEqual(result, "?")

    def test_newlines_within_last_paragraph_preserved(self):
        msg = "Intro.\n\nLine one\nLine two?"
        result = amc_hook._detect_prose_question(msg)
        self.assertIn("\n", result)


class TestExtractQuestions(unittest.TestCase):
    """Tests for _extract_questions edge cases."""

    def test_empty_hook_returns_empty_list(self):
        self.assertEqual(amc_hook._extract_questions({}), [])

    def test_missing_tool_input_returns_empty_list(self):
        self.assertEqual(amc_hook._extract_questions({"other": "data"}), [])

    def test_tool_input_is_none_returns_empty_list(self):
        self.assertEqual(amc_hook._extract_questions({"tool_input": None}), [])

    def test_tool_input_is_list_returns_empty_list(self):
        # tool_input should be dict, not list
        self.assertEqual(amc_hook._extract_questions({"tool_input": []}), [])

    def test_tool_input_is_string_json_parsed(self):
        tool_input = json.dumps({"questions": [{"question": "Test?", "options": []}]})
        result = amc_hook._extract_questions({"tool_input": tool_input})
        self.assertEqual(len(result), 1)
        self.assertEqual(result[0]["question"], "Test?")

    def test_tool_input_invalid_json_string_returns_empty(self):
        result = amc_hook._extract_questions({"tool_input": "not valid json"})
        self.assertEqual(result, [])

    def test_questions_key_is_none_returns_empty(self):
        result = amc_hook._extract_questions({"tool_input": {"questions": None}})
        self.assertEqual(result, [])

    def test_questions_key_missing_returns_empty(self):
        result = amc_hook._extract_questions({"tool_input": {"other": "data"}})
        self.assertEqual(result, [])

    def test_option_without_markdown_excluded_from_output(self):
        hook = {
            "tool_input": {
                "questions": [{
                    "question": "Pick one",
                    "options": [{"label": "A", "description": "Desc A"}],
                }]
            }
        }
        result = amc_hook._extract_questions(hook)
        self.assertNotIn("markdown", result[0]["options"][0])

    def test_option_with_markdown_included(self):
        hook = {
            "tool_input": {
                "questions": [{
                    "question": "Pick one",
                    "options": [
                        {"label": "A", "description": "Desc", "markdown": "```code```"}
                    ],
                }]
            }
        }
        result = amc_hook._extract_questions(hook)
        self.assertEqual(result[0]["options"][0]["markdown"], "```code```")

    def test_missing_question_fields_default_to_empty(self):
        hook = {"tool_input": {"questions": [{}]}}
        result = amc_hook._extract_questions(hook)
        self.assertEqual(result[0]["question"], "")
        self.assertEqual(result[0]["header"], "")
        self.assertEqual(result[0]["options"], [])

    def test_option_missing_fields_default_to_empty(self):
        hook = {"tool_input": {"questions": [{"options": [{}]}]}}
        result = amc_hook._extract_questions(hook)
        self.assertEqual(result[0]["options"][0]["label"], "")
        self.assertEqual(result[0]["options"][0]["description"], "")


class TestAtomicWrite(unittest.TestCase):
    """Tests for _atomic_write edge cases."""

    def test_writes_to_nonexistent_file(self):
        with tempfile.TemporaryDirectory() as tmpdir:
            path = Path(tmpdir) / "new_file.json"
            amc_hook._atomic_write(path, {"key": "value"})
            self.assertTrue(path.exists())
            self.assertEqual(json.loads(path.read_text()), {"key": "value"})

    def test_overwrites_existing_file(self):
        with tempfile.TemporaryDirectory() as tmpdir:
            path = Path(tmpdir) / "existing.json"
            path.write_text('{"old": "data"}')
            amc_hook._atomic_write(path, {"new": "data"})
            self.assertEqual(json.loads(path.read_text()), {"new": "data"})

    def test_missing_parent_directory_raises(self):
        with tempfile.TemporaryDirectory() as tmpdir:
            path = Path(tmpdir) / "subdir" / "file.json"
            # Parent doesn't exist, so creating the temp file next to the target fails.
            with self.assertRaises(FileNotFoundError):
                amc_hook._atomic_write(path, {"data": "test"})

    def test_cleans_up_temp_file_on_replace_failure(self):
        with tempfile.TemporaryDirectory() as tmpdir:
            path = Path(tmpdir) / "file.json"
            # Force the final rename to fail after the temp file has been written.
            with patch("os.replace", side_effect=PermissionError("Simulated failure")):
                with self.assertRaises(PermissionError):
                    amc_hook._atomic_write(path, {"data": "test"})
            # Neither the target nor a leftover temp file should remain.
            self.assertEqual(list(Path(tmpdir).iterdir()), [])

    def test_no_partial_writes_on_failure(self):
        with tempfile.TemporaryDirectory() as tmpdir:
            path = Path(tmpdir) / "file.json"
            path.write_text('{"original": "data"}')
            # Make os.replace fail after the temp file is written
            with patch("os.replace", side_effect=PermissionError("Simulated failure")):
                with self.assertRaises(PermissionError):
                    amc_hook._atomic_write(path, {"new": "data"})
            # Original file should be unchanged
            self.assertEqual(json.loads(path.read_text()), {"original": "data"})


class TestReadSession(unittest.TestCase):
    """Tests for _read_session edge cases."""

    def _read_with_content(self, content):
        """Write *content* to a throwaway file and return _read_session's result for it."""
        with tempfile.TemporaryDirectory() as tmpdir:
            path = Path(tmpdir) / "session.json"
            path.write_text(content)
            return amc_hook._read_session(path)

    def test_nonexistent_file_returns_empty_dict(self):
        # Use a path inside a fresh temp dir so the test never depends on system paths.
        with tempfile.TemporaryDirectory() as tmpdir:
            missing = Path(tmpdir) / "nonexistent" / "file.json"
            self.assertEqual(amc_hook._read_session(missing), {})

    def test_empty_file_returns_empty_dict(self):
        self.assertEqual(self._read_with_content(""), {})

    def test_invalid_json_returns_empty_dict(self):
        self.assertEqual(self._read_with_content("not valid json {{{"), {})

    def test_valid_json_returned(self):
        self.assertEqual(
            self._read_with_content(json.dumps({"session_id": "abc"})),
            {"session_id": "abc"},
        )


class TestAppendEvent(unittest.TestCase):
    """Tests for _append_event edge cases."""

    def test_creates_file_if_missing(self):
        with tempfile.TemporaryDirectory() as tmpdir:
            with patch.object(amc_hook, "EVENTS_DIR", Path(tmpdir)):
                amc_hook._append_event("session123", {"event": "test"})
            event_file = Path(tmpdir) / "session123.jsonl"
            self.assertTrue(event_file.exists())

    def test_appends_to_existing_file(self):
        with tempfile.TemporaryDirectory() as tmpdir:
            event_file = Path(tmpdir) / "session123.jsonl"
            event_file.write_text('{"event": "first"}\n')
            with patch.object(amc_hook, "EVENTS_DIR", Path(tmpdir)):
                amc_hook._append_event("session123", {"event": "second"})
            lines = event_file.read_text().strip().split("\n")
            self.assertEqual(len(lines), 2)
            self.assertEqual(json.loads(lines[1])["event"], "second")

    def test_oserror_silently_ignored(self):
        with tempfile.TemporaryDirectory() as tmpdir:
            # A regular file where a parent directory should be makes creating/opening
            # anything beneath it fail with NotADirectoryError (an OSError), even when
            # running as root -- unlike a bare "/nonexistent" path.
            blocker = Path(tmpdir) / "not_a_directory"
            blocker.write_text("")
            with patch.object(amc_hook, "EVENTS_DIR", blocker / "events"):
                # Should not raise
                amc_hook._append_event("session123", {"event": "test"})


class _MainHookTestBase(unittest.TestCase):
    """Shared fixture: runs amc_hook.main() against throwaway session/event dirs.

    Every test gets fresh ``sessions``/``events`` directories under a temp root, so
    main() can never read or write the real state directories.
    """

    def setUp(self):
        tmp = tempfile.TemporaryDirectory()
        self.addCleanup(tmp.cleanup)
        self.root = Path(tmp.name)
        self.sessions_dir = self.root / "sessions"
        self.events_dir = self.root / "events"
        self.sessions_dir.mkdir()
        self.events_dir.mkdir()

    def _write_session(self, session_id, data):
        """Create ``<sessions_dir>/<session_id>.json`` containing *data*; return its path."""
        session_file = self.sessions_dir / f"{session_id}.json"
        session_file.write_text(json.dumps(data))
        return session_file

    def _run_main(self, payload):
        """Run main() with *payload* on stdin (dicts are JSON-encoded, strings sent raw)."""
        raw = payload if isinstance(payload, str) else json.dumps(payload)
        with patch.object(amc_hook, "SESSIONS_DIR", self.sessions_dir), \
                patch.object(amc_hook, "EVENTS_DIR", self.events_dir), \
                patch("sys.stdin.read", return_value=raw):
            amc_hook.main()


class TestMainHookPathTraversal(_MainHookTestBase):
    """Tests for path traversal protection in main()."""

    def test_session_id_with_path_traversal_sanitized(self):
        # Place a file one level above SESSIONS_DIR -- exactly where "../secret" would
        # land if the session ID were used unsanitized.
        legit_file = self.root / "secret.json"
        legit_file.write_text('{"secret": "data"}')

        self._run_main({
            "hook_event_name": "SessionStart",
            "session_id": "../secret",
            "cwd": "/test/project",
        })

        # The sanitized session ID should be "secret" (basename of "../secret")
        # and should NOT have modified the legit_file in parent dir
        self.assertEqual(json.loads(legit_file.read_text()), {"secret": "data"})


class TestMainHookEmptyInput(_MainHookTestBase):
    """Tests for main() with various empty/invalid inputs.

    The base fixture redirects SESSIONS_DIR/EVENTS_DIR to temp dirs, so a regression
    that lets main() proceed cannot touch real state.
    """

    def test_empty_stdin_returns_silently(self):
        # Should not raise
        self._run_main("")

    def test_whitespace_only_stdin_returns_silently(self):
        self._run_main(" \n\t ")

    def test_invalid_json_stdin_fails_silently(self):
        self._run_main("not json")

    def test_missing_session_id_returns_silently(self):
        self._run_main('{"hook_event_name": "SessionStart"}')

    def test_missing_event_name_returns_silently(self):
        self._run_main('{"session_id": "abc123"}')

    def test_empty_session_id_after_sanitization_returns_silently(self):
        # Edge case: session_id that becomes empty after basename()
        self._run_main('{"hook_event_name": "SessionStart", "session_id": "/"}')


class TestMainSessionEndDeletesFile(_MainHookTestBase):
    """Tests for SessionEnd hook behavior."""

    def test_session_end_deletes_existing_session_file(self):
        session_file = self._write_session("abc123", {"session_id": "abc123"})

        self._run_main({
            "hook_event_name": "SessionEnd",
            "session_id": "abc123",
        })

        self.assertFalse(session_file.exists())

    def test_session_end_missing_file_no_error(self):
        # Should not raise
        self._run_main({
            "hook_event_name": "SessionEnd",
            "session_id": "nonexistent",
        })


class TestMainPreToolUseWithoutExistingSession(_MainHookTestBase):
    """Edge case: PreToolUse arrives but session file doesn't exist."""

    def test_pre_tool_use_no_existing_session_returns_silently(self):
        self._run_main({
            "hook_event_name": "PreToolUse",
            "tool_name": "AskUserQuestion",
            "session_id": "nonexistent",
            "tool_input": {"questions": []},
        })

        # No session file should be created
        self.assertFalse((self.sessions_dir / "nonexistent.json").exists())


class TestMainStopWithProseQuestion(_MainHookTestBase):
    """Tests for Stop hook detecting prose questions."""

    def test_stop_with_prose_question_sets_needs_attention(self):
        # Create existing session
        session_file = self._write_session("abc123", {
            "session_id": "abc123",
            "status": "active",
        })

        self._run_main({
            "hook_event_name": "Stop",
            "session_id": "abc123",
            "last_assistant_message": "What do you think about this approach?",
            "cwd": "/test/project",
        })

        data = json.loads(session_file.read_text())
        self.assertEqual(data["status"], "needs_attention")
        self.assertEqual(len(data["pending_questions"]), 1)
        self.assertIn("approach?", data["pending_questions"][0]["question"])


class TestMainTurnTimingAccumulation(_MainHookTestBase):
    """Tests for turn timing accumulation across pause/resume cycles."""

    def test_post_tool_use_accumulates_paused_time(self):
        # Create session with existing paused state
        session_file = self._write_session("abc123", {
            "session_id": "abc123",
            "status": "needs_attention",
            "turn_paused_at": "2024-01-01T00:00:00+00:00",
            "turn_paused_ms": 5000,  # Already had 5 seconds paused
        })

        self._run_main({
            "hook_event_name": "PostToolUse",
            "tool_name": "AskUserQuestion",
            "session_id": "abc123",
        })

        data = json.loads(session_file.read_text())
        # Should have accumulated more paused time
        self.assertGreater(data["turn_paused_ms"], 5000)
        # turn_paused_at should be removed after resuming
        self.assertNotIn("turn_paused_at", data)


if __name__ == "__main__":
    unittest.main()