diff --git a/src/geminimcp/server.py b/src/geminimcp/server.py index e9db523..b2c7123 100644 --- a/src/geminimcp/server.py +++ b/src/geminimcp/server.py @@ -18,13 +18,6 @@ import shutil mcp = FastMCP("Gemini MCP Server-from guda.studio") -def _empty_str_to_none(value: str | None) -> str | None: - """Convert empty strings to None for optional UUID parameters.""" - if isinstance(value, str) and not value.strip(): - return None - return value - - def run_shell_command(cmd: list[str]) -> Generator[str, None, None]: """Execute a command and stream its output line-by-line. @@ -45,40 +38,64 @@ def run_shell_command(cmd: list[str]) -> Generator[str, None, None]: process = subprocess.Popen( popen_cmd, - shell=False, # Safer: no shell injection - stdin=subprocess.PIPE, # Prevent process from waiting for input + shell=False, + stdin=subprocess.DEVNULL, stdout=subprocess.PIPE, stderr=subprocess.STDOUT, universal_newlines=True, - encoding="utf-8", + encoding='utf-8', ) - output_queue: queue.Queue[str] = queue.Queue() + output_queue: queue.Queue[str | None] = queue.Queue() + GRACEFUL_SHUTDOWN_DELAY = 0.3 + + def is_turn_completed(line: str) -> bool: + """Check if the line indicates turn completion via JSON parsing.""" + try: + data = json.loads(line) + return data.get("type") == "turn.completed" + except (json.JSONDecodeError, AttributeError, TypeError): + return False def read_output() -> None: """Read process output in a separate thread.""" if process.stdout: for line in iter(process.stdout.readline, ""): - output_queue.put(line.strip()) + stripped = line.strip() + output_queue.put(stripped) + if is_turn_completed(stripped): + time.sleep(GRACEFUL_SHUTDOWN_DELAY) + process.terminate() + break process.stdout.close() + output_queue.put(None) thread = threading.Thread(target=read_output) - thread.daemon = True thread.start() # Yield lines while process is running - while process.poll() is None: + while True: try: - yield output_queue.get(timeout=0.1) + line = output_queue.get(timeout=0.5) + if line is None: + break + yield line except queue.Empty: - continue + if process.poll() is not None and not thread.is_alive(): + break - process.wait() + try: + process.wait(timeout=5) + except subprocess.TimeoutExpired: + process.kill() + process.wait() + thread.join(timeout=5) - # Drain remaining output from queue while not output_queue.empty(): try: - yield output_queue.get_nowait() + line = output_queue.get_nowait() + if line is not None: + yield line except queue.Empty: break @@ -194,9 +211,11 @@ async def gemini( # err_message = "gemini error: " + line_dict.get("message", "") except json.JSONDecodeError as error: # Improved error handling: include problematic line - err_message = line + err_message += "\n\n[json decode error] " + line + continue except Exception as error: - err_message = f"Unexpected error: {error}. Line: {line!r}" + err_message += "\n\n[unexpected error] " + f"Unexpected error: {error}. Line: {line!r}" + break if thread_id is None: @@ -205,10 +224,10 @@ async def gemini( "Failed to get `SESSION_ID` from the gemini session. \n\n" + err_message ) - if len(agent_messages) == 0: + if success and len(agent_messages) == 0: success = False err_message = ( - "Failed to get `agent_messages` from the gemini session. \n\n You can try to set `return_all_messages` to `True` to get the full information. \n\n " + "Failed to retrieve `agent_messages` data from the Gemini session. This might be due to Gemini performing a tool call. You can continue using the `SESSION_ID` to proceed with the conversation. \n\n " + err_message ) @@ -219,10 +238,11 @@ async def gemini( "agent_messages": agent_messages, # "PROMPT": PROMPT, } - if return_all_messages: - result["all_messages"] = all_messages else: result = {"success": False, "error": err_message} + + if return_all_messages: + result["all_messages"] = all_messages return result