diff --git a/src/codexmcp/server.py b/src/codexmcp/server.py index 93cfc93..aa05efd 100644 --- a/src/codexmcp/server.py +++ b/src/codexmcp/server.py @@ -5,9 +5,9 @@ from __future__ import annotations import json import os import queue -import re import subprocess import threading +import time import uuid from pathlib import Path from typing import Annotated, Any, Dict, Generator, List, Literal, Optional @@ -37,53 +37,73 @@ def run_shell_command(cmd: list[str]) -> Generator[str, None, None]: """ # On Windows, codex is exposed via a *.cmd shim. Use cmd.exe with /s so # user prompts containing quotes/newlines aren't reinterpreted as shell syntax. - popen_cmd = cmd - + popen_cmd = cmd.copy() codex_path = shutil.which('codex') or cmd[0] popen_cmd[0] = codex_path - # if os.name == "nt" and codex_path.lower().endswith((".cmd", ".bat")): - # from subprocess import list2cmdline - # popen_cmd = ["cmd.exe", "/s", "/c", list2cmdline(cmd)] - process = subprocess.Popen( popen_cmd, - shell=False, # Safer: no shell injection - stdin=subprocess.PIPE, # Prevent process from waiting for input + shell=False, + stdin=subprocess.DEVNULL, stdout=subprocess.PIPE, stderr=subprocess.STDOUT, universal_newlines=True, encoding='utf-8', ) - output_queue: queue.Queue[str] = queue.Queue() + output_queue: queue.Queue[str | None] = queue.Queue() + GRACEFUL_SHUTDOWN_DELAY = 0.3 + + def is_turn_completed(line: str) -> bool: + """Check if the line indicates turn completion via JSON parsing.""" + try: + data = json.loads(line) + return data.get("type") == "turn.completed" + except (json.JSONDecodeError, AttributeError, TypeError): + return False def read_output() -> None: """Read process output in a separate thread.""" if process.stdout: for line in iter(process.stdout.readline, ""): - output_queue.put(line.strip()) + stripped = line.strip() + output_queue.put(stripped) + if is_turn_completed(stripped): + time.sleep(GRACEFUL_SHUTDOWN_DELAY) + process.terminate() + break process.stdout.close() + output_queue.put(None) thread = threading.Thread(target=read_output) - thread.daemon = True thread.start() # Yield lines while process is running - while process.poll() is None: + while True: try: - yield output_queue.get(timeout=0.1) + line = output_queue.get(timeout=0.5) + if line is None: + break + yield line except queue.Empty: - continue + if process.poll() is not None and not thread.is_alive(): + break - process.wait() + try: + process.wait(timeout=5) + except subprocess.TimeoutExpired: + process.kill() + process.wait() + thread.join(timeout=5) - # Drain remaining output from queue while not output_queue.empty(): try: - yield output_queue.get_nowait() + line = output_queue.get_nowait() + if line is not None: + yield line except queue.Empty: break + def windows_escape(prompt): """ Windows 风格的字符串转义函数。