Add comprehensive tests for tokenizer, performance benchmarks, and TreeSitter parser functionality

- Implemented unit tests for the Tokenizer class, covering various text inputs, edge cases, and fallback mechanisms.
- Created performance benchmarks comparing tiktoken and pure Python implementations for token counting.
- Developed extensive tests for TreeSitterSymbolParser across Python, JavaScript, and TypeScript, ensuring accurate symbol extraction and parsing.
- Added configuration documentation for MCP integration and custom prompts, enhancing usability and flexibility.
- Introduced a refactor script for GraphAnalyzer to streamline future improvements.
This commit is contained in:
catlog22
2025-12-15 14:36:09 +08:00
parent 82dcafff00
commit 0fe16963cd
49 changed files with 9307 additions and 438 deletions

1
.gitattributes vendored
View File

@@ -30,3 +30,4 @@
*.tar binary
*.gz binary
*.pdf binary
.mcp.json

View File

@@ -0,0 +1,183 @@
/**
* Centralized Storage Paths Configuration
* Single source of truth for all CCW storage locations
*
* All data is stored under ~/.ccw/ with project isolation via SHA256 hash
*/
import { homedir } from 'os';
import { join, resolve } from 'path';
import { createHash } from 'crypto';
import { existsSync, mkdirSync } from 'fs';
// Environment variable override for custom storage location
const CCW_DATA_DIR = process.env.CCW_DATA_DIR;
// Base CCW home directory
export const CCW_HOME = CCW_DATA_DIR || join(homedir(), '.ccw');
/**
* Calculate project identifier from project path
* Uses SHA256 hash truncated to 16 chars for uniqueness + readability
* @param projectPath - Absolute or relative project path
* @returns 16-character hex string project ID
*/
export function getProjectId(projectPath: string): string {
const absolutePath = resolve(projectPath);
const hash = createHash('sha256').update(absolutePath).digest('hex');
return hash.substring(0, 16);
}
/**
* Ensure a directory exists, creating it if necessary
* @param dirPath - Directory path to ensure
*/
export function ensureStorageDir(dirPath: string): void {
if (!existsSync(dirPath)) {
mkdirSync(dirPath, { recursive: true });
}
}
/**
* Global storage paths (not project-specific)
*/
export const GlobalPaths = {
/** Root CCW home directory */
root: () => CCW_HOME,
/** Config directory */
config: () => join(CCW_HOME, 'config'),
/** Global settings file */
settings: () => join(CCW_HOME, 'config', 'settings.json'),
/** Recent project paths file */
recentPaths: () => join(CCW_HOME, 'config', 'recent-paths.json'),
/** Databases directory */
databases: () => join(CCW_HOME, 'db'),
/** MCP templates database */
mcpTemplates: () => join(CCW_HOME, 'db', 'mcp-templates.db'),
/** Logs directory */
logs: () => join(CCW_HOME, 'logs'),
};
/**
* Project-specific storage paths
*/
export interface ProjectPaths {
/** Project root in CCW storage */
root: string;
/** CLI history directory */
cliHistory: string;
/** CLI history database file */
historyDb: string;
/** Memory store directory */
memory: string;
/** Memory store database file */
memoryDb: string;
/** Cache directory */
cache: string;
/** Dashboard cache file */
dashboardCache: string;
/** Config directory */
config: string;
/** CLI config file */
cliConfig: string;
}
/**
* Get storage paths for a specific project
* @param projectPath - Project root path
* @returns Object with all project-specific paths
*/
export function getProjectPaths(projectPath: string): ProjectPaths {
const projectId = getProjectId(projectPath);
const projectDir = join(CCW_HOME, 'projects', projectId);
return {
root: projectDir,
cliHistory: join(projectDir, 'cli-history'),
historyDb: join(projectDir, 'cli-history', 'history.db'),
memory: join(projectDir, 'memory'),
memoryDb: join(projectDir, 'memory', 'memory.db'),
cache: join(projectDir, 'cache'),
dashboardCache: join(projectDir, 'cache', 'dashboard-data.json'),
config: join(projectDir, 'config'),
cliConfig: join(projectDir, 'config', 'cli-config.json'),
};
}
/**
* Unified StoragePaths object combining global and project paths
*/
export const StoragePaths = {
global: GlobalPaths,
project: getProjectPaths,
};
/**
* Legacy storage paths (for backward compatibility detection)
*/
export const LegacyPaths = {
/** Old recent paths file location */
recentPaths: () => join(homedir(), '.ccw-recent-paths.json'),
/** Old project-local CLI history */
cliHistory: (projectPath: string) => join(projectPath, '.workflow', '.cli-history'),
/** Old project-local memory store */
memory: (projectPath: string) => join(projectPath, '.workflow', '.memory'),
/** Old project-local cache */
cache: (projectPath: string) => join(projectPath, '.workflow', '.ccw-cache'),
/** Old project-local CLI config */
cliConfig: (projectPath: string) => join(projectPath, '.workflow', 'cli-config.json'),
};
/**
* Check if legacy storage exists for a project
* Useful for migration warnings or detection
* @param projectPath - Project root path
* @returns true if any legacy storage is present
*/
export function isLegacyStoragePresent(projectPath: string): boolean {
return (
existsSync(LegacyPaths.cliHistory(projectPath)) ||
existsSync(LegacyPaths.memory(projectPath)) ||
existsSync(LegacyPaths.cache(projectPath)) ||
existsSync(LegacyPaths.cliConfig(projectPath))
);
}
/**
* Get CCW home directory (for external use)
*/
export function getCcwHome(): string {
return CCW_HOME;
}
/**
* Initialize global storage directories
* Creates the base directory structure if not present
*/
export function initializeGlobalStorage(): void {
ensureStorageDir(GlobalPaths.config());
ensureStorageDir(GlobalPaths.databases());
ensureStorageDir(GlobalPaths.logs());
}
/**
* Initialize project storage directories
* @param projectPath - Project root path
*/
export function initializeProjectStorage(projectPath: string): void {
const paths = getProjectPaths(projectPath);
ensureStorageDir(paths.cliHistory);
ensureStorageDir(paths.memory);
ensureStorageDir(paths.cache);
ensureStorageDir(paths.config);
}

View File

@@ -1,5 +1,6 @@
import { existsSync, mkdirSync, readFileSync, writeFileSync, statSync } from 'fs';
import { join, dirname } from 'path';
import { StoragePaths, ensureStorageDir } from '../config/storage-paths.js';
interface CacheEntry<T> {
data: T;
@@ -265,6 +266,16 @@ export class CacheManager<T> {
}
}
/**
* Extract project path from workflow directory
* @param workflowDir - Path to .workflow directory (e.g., /project/.workflow)
* @returns Project root path
*/
function extractProjectPath(workflowDir: string): string {
// workflowDir is typically {projectPath}/.workflow
return workflowDir.replace(/[\/\\]\.workflow$/, '') || workflowDir;
}
/**
* Create a cache manager for dashboard data
* @param workflowDir - Path to .workflow directory
@@ -272,6 +283,9 @@ export class CacheManager<T> {
* @returns CacheManager instance
*/
export function createDashboardCache(workflowDir: string, ttl?: number): CacheManager<any> {
const cacheDir = join(workflowDir, '.ccw-cache');
// Use centralized storage path
const projectPath = extractProjectPath(workflowDir);
const cacheDir = StoragePaths.project(projectPath).cache;
ensureStorageDir(cacheDir);
return new CacheManager('dashboard-data', { cacheDir, ttl });
}

View File

@@ -6,6 +6,7 @@
import Database from 'better-sqlite3';
import { existsSync, mkdirSync } from 'fs';
import { join } from 'path';
import { StoragePaths, ensureStorageDir } from '../config/storage-paths.js';
// Types
export interface Entity {
@@ -115,12 +116,12 @@ export class MemoryStore {
private dbPath: string;
constructor(projectPath: string) {
const memoryDir = join(projectPath, '.workflow', '.memory');
if (!existsSync(memoryDir)) {
mkdirSync(memoryDir, { recursive: true });
}
// Use centralized storage path
const paths = StoragePaths.project(projectPath);
const memoryDir = paths.memory;
ensureStorageDir(memoryDir);
this.dbPath = join(memoryDir, 'memory.db');
this.dbPath = paths.memoryDb;
this.db = new Database(this.dbPath);
this.db.pragma('journal_mode = WAL');
this.db.pragma('synchronous = NORMAL');

View File

@@ -12,9 +12,383 @@ import * as McpTemplatesDb from './mcp-templates-db.js';
// Claude config file path
const CLAUDE_CONFIG_PATH = join(homedir(), '.claude.json');
// Codex config file path (TOML format)
const CODEX_CONFIG_PATH = join(homedir(), '.codex', 'config.toml');
// Workspace root path for scanning .mcp.json files
let WORKSPACE_ROOT = process.cwd();
// ========================================
// TOML Parser for Codex Config
// ========================================
/**
* Simple TOML parser for Codex config.toml
* Supports basic types: strings, numbers, booleans, arrays, inline tables
*/
function parseToml(content: string): Record<string, any> {
const result: Record<string, any> = {};
let currentSection: string[] = [];
const lines = content.split('\n');
for (let i = 0; i < lines.length; i++) {
let line = lines[i].trim();
// Skip empty lines and comments
if (!line || line.startsWith('#')) continue;
// Handle section headers [section] or [section.subsection]
const sectionMatch = line.match(/^\[([^\]]+)\]$/);
if (sectionMatch) {
currentSection = sectionMatch[1].split('.');
// Ensure nested sections exist
let obj = result;
for (const part of currentSection) {
if (!obj[part]) obj[part] = {};
obj = obj[part];
}
continue;
}
// Handle key = value pairs
const keyValueMatch = line.match(/^([a-zA-Z_][a-zA-Z0-9_]*)\s*=\s*(.+)$/);
if (keyValueMatch) {
const key = keyValueMatch[1];
const rawValue = keyValueMatch[2].trim();
const value = parseTomlValue(rawValue);
// Navigate to current section
let obj = result;
for (const part of currentSection) {
if (!obj[part]) obj[part] = {};
obj = obj[part];
}
obj[key] = value;
}
}
return result;
}
/**
* Parse a TOML value
*/
function parseTomlValue(value: string): any {
// String (double-quoted)
if (value.startsWith('"') && value.endsWith('"')) {
return value.slice(1, -1).replace(/\\"/g, '"').replace(/\\\\/g, '\\');
}
// String (single-quoted - literal)
if (value.startsWith("'") && value.endsWith("'")) {
return value.slice(1, -1);
}
// Boolean
if (value === 'true') return true;
if (value === 'false') return false;
// Number
if (/^-?\d+(\.\d+)?$/.test(value)) {
return value.includes('.') ? parseFloat(value) : parseInt(value, 10);
}
// Array
if (value.startsWith('[') && value.endsWith(']')) {
const inner = value.slice(1, -1).trim();
if (!inner) return [];
// Simple array parsing (handles basic cases)
const items: any[] = [];
let depth = 0;
let current = '';
let inString = false;
let stringChar = '';
for (const char of inner) {
if (!inString && (char === '"' || char === "'")) {
inString = true;
stringChar = char;
current += char;
} else if (inString && char === stringChar) {
inString = false;
current += char;
} else if (!inString && (char === '[' || char === '{')) {
depth++;
current += char;
} else if (!inString && (char === ']' || char === '}')) {
depth--;
current += char;
} else if (!inString && char === ',' && depth === 0) {
items.push(parseTomlValue(current.trim()));
current = '';
} else {
current += char;
}
}
if (current.trim()) {
items.push(parseTomlValue(current.trim()));
}
return items;
}
// Inline table { key = value, ... }
if (value.startsWith('{') && value.endsWith('}')) {
const inner = value.slice(1, -1).trim();
if (!inner) return {};
const table: Record<string, any> = {};
// Simple inline table parsing
const pairs = inner.split(',');
for (const pair of pairs) {
const match = pair.trim().match(/^([a-zA-Z_][a-zA-Z0-9_]*)\s*=\s*(.+)$/);
if (match) {
table[match[1]] = parseTomlValue(match[2].trim());
}
}
return table;
}
// Return as string if nothing else matches
return value;
}
/**
* Serialize object to TOML format for Codex config
*/
function serializeToml(obj: Record<string, any>, prefix: string = ''): string {
let result = '';
const sections: string[] = [];
for (const [key, value] of Object.entries(obj)) {
if (value === null || value === undefined) continue;
if (typeof value === 'object' && !Array.isArray(value)) {
// Handle nested sections (like mcp_servers.server_name)
const sectionKey = prefix ? `${prefix}.${key}` : key;
sections.push(sectionKey);
// Check if this is a section with sub-sections or direct values
const hasSubSections = Object.values(value).some(v => typeof v === 'object' && !Array.isArray(v));
if (hasSubSections) {
// This section has sub-sections, recurse without header
result += serializeToml(value, sectionKey);
} else {
// This section has direct values, add header and values
result += `\n[${sectionKey}]\n`;
for (const [subKey, subValue] of Object.entries(value)) {
if (subValue !== null && subValue !== undefined) {
result += `${subKey} = ${serializeTomlValue(subValue)}\n`;
}
}
}
} else if (!prefix) {
// Top-level simple values
result += `${key} = ${serializeTomlValue(value)}\n`;
}
}
return result;
}
/**
* Serialize a value to TOML format
*/
function serializeTomlValue(value: any): string {
if (typeof value === 'string') {
return `"${value.replace(/\\/g, '\\\\').replace(/"/g, '\\"')}"`;
}
if (typeof value === 'boolean') {
return value ? 'true' : 'false';
}
if (typeof value === 'number') {
return String(value);
}
if (Array.isArray(value)) {
return `[${value.map(v => serializeTomlValue(v)).join(', ')}]`;
}
if (typeof value === 'object' && value !== null) {
const pairs = Object.entries(value)
.filter(([_, v]) => v !== null && v !== undefined)
.map(([k, v]) => `${k} = ${serializeTomlValue(v)}`);
return `{ ${pairs.join(', ')} }`;
}
return String(value);
}
// ========================================
// Codex MCP Functions
// ========================================
/**
* Read Codex config.toml and extract MCP servers
*/
function getCodexMcpConfig(): { servers: Record<string, any>; configPath: string; exists: boolean } {
try {
if (!existsSync(CODEX_CONFIG_PATH)) {
return { servers: {}, configPath: CODEX_CONFIG_PATH, exists: false };
}
const content = readFileSync(CODEX_CONFIG_PATH, 'utf8');
const config = parseToml(content);
// MCP servers are under [mcp_servers] section
const mcpServers = config.mcp_servers || {};
return {
servers: mcpServers,
configPath: CODEX_CONFIG_PATH,
exists: true
};
} catch (error: unknown) {
console.error('Error reading Codex config:', error);
return { servers: {}, configPath: CODEX_CONFIG_PATH, exists: false };
}
}
/**
* Add or update MCP server in Codex config.toml
*/
function addCodexMcpServer(serverName: string, serverConfig: Record<string, any>): { success?: boolean; error?: string } {
try {
const codexDir = join(homedir(), '.codex');
// Ensure .codex directory exists
if (!existsSync(codexDir)) {
mkdirSync(codexDir, { recursive: true });
}
let config: Record<string, any> = {};
// Read existing config if it exists
if (existsSync(CODEX_CONFIG_PATH)) {
const content = readFileSync(CODEX_CONFIG_PATH, 'utf8');
config = parseToml(content);
}
// Ensure mcp_servers section exists
if (!config.mcp_servers) {
config.mcp_servers = {};
}
// Convert serverConfig from Claude format to Codex format
const codexServerConfig: Record<string, any> = {};
// Handle STDIO servers (command-based)
if (serverConfig.command) {
codexServerConfig.command = serverConfig.command;
if (serverConfig.args && serverConfig.args.length > 0) {
codexServerConfig.args = serverConfig.args;
}
if (serverConfig.env && Object.keys(serverConfig.env).length > 0) {
codexServerConfig.env = serverConfig.env;
}
if (serverConfig.cwd) {
codexServerConfig.cwd = serverConfig.cwd;
}
}
// Handle HTTP servers (url-based)
if (serverConfig.url) {
codexServerConfig.url = serverConfig.url;
if (serverConfig.bearer_token_env_var) {
codexServerConfig.bearer_token_env_var = serverConfig.bearer_token_env_var;
}
if (serverConfig.http_headers) {
codexServerConfig.http_headers = serverConfig.http_headers;
}
}
// Copy optional fields
if (serverConfig.startup_timeout_sec !== undefined) {
codexServerConfig.startup_timeout_sec = serverConfig.startup_timeout_sec;
}
if (serverConfig.tool_timeout_sec !== undefined) {
codexServerConfig.tool_timeout_sec = serverConfig.tool_timeout_sec;
}
if (serverConfig.enabled !== undefined) {
codexServerConfig.enabled = serverConfig.enabled;
}
if (serverConfig.enabled_tools) {
codexServerConfig.enabled_tools = serverConfig.enabled_tools;
}
if (serverConfig.disabled_tools) {
codexServerConfig.disabled_tools = serverConfig.disabled_tools;
}
// Add the server
config.mcp_servers[serverName] = codexServerConfig;
// Serialize and write back
const tomlContent = serializeToml(config);
writeFileSync(CODEX_CONFIG_PATH, tomlContent, 'utf8');
return { success: true };
} catch (error: unknown) {
console.error('Error adding Codex MCP server:', error);
return { error: (error as Error).message };
}
}
/**
* Remove MCP server from Codex config.toml
*/
function removeCodexMcpServer(serverName: string): { success?: boolean; error?: string } {
try {
if (!existsSync(CODEX_CONFIG_PATH)) {
return { error: 'Codex config.toml not found' };
}
const content = readFileSync(CODEX_CONFIG_PATH, 'utf8');
const config = parseToml(content);
if (!config.mcp_servers || !config.mcp_servers[serverName]) {
return { error: `Server not found: ${serverName}` };
}
// Remove the server
delete config.mcp_servers[serverName];
// Serialize and write back
const tomlContent = serializeToml(config);
writeFileSync(CODEX_CONFIG_PATH, tomlContent, 'utf8');
return { success: true };
} catch (error: unknown) {
console.error('Error removing Codex MCP server:', error);
return { error: (error as Error).message };
}
}
/**
* Toggle Codex MCP server enabled state
*/
function toggleCodexMcpServer(serverName: string, enabled: boolean): { success?: boolean; error?: string } {
try {
if (!existsSync(CODEX_CONFIG_PATH)) {
return { error: 'Codex config.toml not found' };
}
const content = readFileSync(CODEX_CONFIG_PATH, 'utf8');
const config = parseToml(content);
if (!config.mcp_servers || !config.mcp_servers[serverName]) {
return { error: `Server not found: ${serverName}` };
}
// Set enabled state
config.mcp_servers[serverName].enabled = enabled;
// Serialize and write back
const tomlContent = serializeToml(config);
writeFileSync(CODEX_CONFIG_PATH, tomlContent, 'utf8');
return { success: true };
} catch (error: unknown) {
console.error('Error toggling Codex MCP server:', error);
return { error: (error as Error).message };
}
}
export interface RouteContext {
pathname: string;
url: URL;
@@ -598,11 +972,64 @@ function getProjectSettingsPath(projectPath) {
export async function handleMcpRoutes(ctx: RouteContext): Promise<boolean> {
const { pathname, url, req, res, initialPath, handlePostRequest, broadcastToClients } = ctx;
// API: Get MCP configuration
// API: Get MCP configuration (includes both Claude and Codex)
if (pathname === '/api/mcp-config') {
const mcpData = getMcpConfig();
const codexData = getCodexMcpConfig();
const combinedData = {
...mcpData,
codex: codexData
};
res.writeHead(200, { 'Content-Type': 'application/json' });
res.end(JSON.stringify(mcpData));
res.end(JSON.stringify(combinedData));
return true;
}
// ========================================
// Codex MCP API Endpoints
// ========================================
// API: Get Codex MCP configuration
if (pathname === '/api/codex-mcp-config') {
const codexData = getCodexMcpConfig();
res.writeHead(200, { 'Content-Type': 'application/json' });
res.end(JSON.stringify(codexData));
return true;
}
// API: Add Codex MCP server
if (pathname === '/api/codex-mcp-add' && req.method === 'POST') {
handlePostRequest(req, res, async (body) => {
const { serverName, serverConfig } = body;
if (!serverName || !serverConfig) {
return { error: 'serverName and serverConfig are required', status: 400 };
}
return addCodexMcpServer(serverName, serverConfig);
});
return true;
}
// API: Remove Codex MCP server
if (pathname === '/api/codex-mcp-remove' && req.method === 'POST') {
handlePostRequest(req, res, async (body) => {
const { serverName } = body;
if (!serverName) {
return { error: 'serverName is required', status: 400 };
}
return removeCodexMcpServer(serverName);
});
return true;
}
// API: Toggle Codex MCP server enabled state
if (pathname === '/api/codex-mcp-toggle' && req.method === 'POST') {
handlePostRequest(req, res, async (body) => {
const { serverName, enabled } = body;
if (!serverName || enabled === undefined) {
return { error: 'serverName and enabled are required', status: 400 };
}
return toggleCodexMcpServer(serverName, enabled);
});
return true;
}

View File

@@ -7,15 +7,14 @@ import Database from 'better-sqlite3';
import { existsSync, mkdirSync } from 'fs';
import { join, dirname } from 'path';
import { homedir } from 'os';
import { StoragePaths, ensureStorageDir } from '../../config/storage-paths.js';
// Database path
const DB_DIR = join(homedir(), '.ccw');
const DB_PATH = join(DB_DIR, 'mcp-templates.db');
// Database path - uses centralized storage
const DB_DIR = StoragePaths.global.databases();
const DB_PATH = StoragePaths.global.mcpTemplates();
// Ensure database directory exists
if (!existsSync(DB_DIR)) {
mkdirSync(DB_DIR, { recursive: true });
}
ensureStorageDir(DB_DIR);
// Initialize database connection
let db: Database.Database | null = null;

View File

@@ -99,9 +99,10 @@ async function getSessionDetailData(sessionPath, dataType) {
}
}
// Load explorations (exploration-*.json files) - check .process/ first, then session root
// Load explorations (exploration-*.json files) and diagnoses (diagnosis-*.json files) - check .process/ first, then session root
if (dataType === 'context' || dataType === 'explorations' || dataType === 'all') {
result.explorations = { manifest: null, data: {} };
result.diagnoses = { manifest: null, data: {} };
// Try .process/ first (standard workflow sessions), then session root (lite tasks)
const searchDirs = [
@@ -134,15 +135,41 @@ async function getSessionDetailData(sessionPath, dataType) {
} catch (e) {
result.explorations.manifest = null;
}
} else {
// Fallback: scan for exploration-*.json files directly
}
// Look for diagnoses-manifest.json
const diagManifestFile = join(searchDir, 'diagnoses-manifest.json');
if (existsSync(diagManifestFile)) {
try {
const files = readdirSync(searchDir).filter(f => f.startsWith('exploration-') && f.endsWith('.json'));
if (files.length > 0) {
result.diagnoses.manifest = JSON.parse(readFileSync(diagManifestFile, 'utf8'));
// Load each diagnosis file based on manifest
const diagnoses = result.diagnoses.manifest.diagnoses || [];
for (const diag of diagnoses) {
const diagFile = join(searchDir, diag.file);
if (existsSync(diagFile)) {
try {
result.diagnoses.data[diag.angle] = JSON.parse(readFileSync(diagFile, 'utf8'));
} catch (e) {
// Skip unreadable diagnosis files
}
}
}
break; // Found manifest, stop searching
} catch (e) {
result.diagnoses.manifest = null;
}
}
// Fallback: scan for exploration-*.json and diagnosis-*.json files directly
if (!result.explorations.manifest) {
try {
const expFiles = readdirSync(searchDir).filter(f => f.startsWith('exploration-') && f.endsWith('.json') && f !== 'explorations-manifest.json');
if (expFiles.length > 0) {
// Create synthetic manifest
result.explorations.manifest = {
exploration_count: files.length,
explorations: files.map((f, i) => ({
exploration_count: expFiles.length,
explorations: expFiles.map((f, i) => ({
angle: f.replace('exploration-', '').replace('.json', ''),
file: f,
index: i + 1
@@ -150,7 +177,7 @@ async function getSessionDetailData(sessionPath, dataType) {
};
// Load each file
for (const file of files) {
for (const file of expFiles) {
const angle = file.replace('exploration-', '').replace('.json', '');
try {
result.explorations.data[angle] = JSON.parse(readFileSync(join(searchDir, file), 'utf8'));
@@ -158,12 +185,46 @@ async function getSessionDetailData(sessionPath, dataType) {
// Skip unreadable files
}
}
break; // Found explorations, stop searching
}
} catch (e) {
// Directory read failed
}
}
// Fallback: scan for diagnosis-*.json files directly
if (!result.diagnoses.manifest) {
try {
const diagFiles = readdirSync(searchDir).filter(f => f.startsWith('diagnosis-') && f.endsWith('.json') && f !== 'diagnoses-manifest.json');
if (diagFiles.length > 0) {
// Create synthetic manifest
result.diagnoses.manifest = {
diagnosis_count: diagFiles.length,
diagnoses: diagFiles.map((f, i) => ({
angle: f.replace('diagnosis-', '').replace('.json', ''),
file: f,
index: i + 1
}))
};
// Load each file
for (const file of diagFiles) {
const angle = file.replace('diagnosis-', '').replace('.json', '');
try {
result.diagnoses.data[angle] = JSON.parse(readFileSync(join(searchDir, file), 'utf8'));
} catch (e) {
// Skip unreadable files
}
}
}
} catch (e) {
// Directory read failed
}
}
// If we found either explorations or diagnoses, break out of the loop
if (result.explorations.manifest || result.diagnoses.manifest) {
break;
}
}
}

View File

@@ -1022,6 +1022,16 @@
overflow: hidden;
}
.diagnosis-card .collapsible-content {
display: block;
padding: 1rem;
background: hsl(var(--card));
}
.diagnosis-card .collapsible-content.collapsed {
display: none;
}
.diagnosis-header {
background: hsl(var(--muted) / 0.3);
}

View File

@@ -15,6 +15,11 @@ let mcpCurrentProjectServers = {};
let mcpConfigSources = [];
let mcpCreateMode = 'form'; // 'form' or 'json'
// ========== CLI Toggle State (Claude / Codex) ==========
let currentCliMode = 'claude'; // 'claude' or 'codex'
let codexMcpConfig = null;
let codexMcpServers = {};
// ========== Initialization ==========
function initMcpManager() {
// Initialize MCP navigation
@@ -44,6 +49,12 @@ async function loadMcpConfig() {
mcpEnterpriseServers = data.enterpriseServers || {};
mcpConfigSources = data.configSources || [];
// Load Codex MCP config
if (data.codex) {
codexMcpConfig = data.codex;
codexMcpServers = data.codex.servers || {};
}
// Get current project servers
const currentPath = projectPath.replace(/\//g, '\\');
mcpCurrentProjectServers = mcpAllProjects[currentPath]?.mcpServers || {};
@@ -58,6 +69,135 @@ async function loadMcpConfig() {
}
}
// ========== CLI Mode Toggle ==========
function setCliMode(mode) {
currentCliMode = mode;
renderMcpManager();
}
function getCliMode() {
return currentCliMode;
}
// ========== Codex MCP Functions ==========
/**
* Add MCP server to Codex config.toml
*/
async function addCodexMcpServer(serverName, serverConfig) {
try {
const response = await fetch('/api/codex-mcp-add', {
method: 'POST',
headers: { 'Content-Type': 'application/json' },
body: JSON.stringify({
serverName: serverName,
serverConfig: serverConfig
})
});
if (!response.ok) throw new Error('Failed to add Codex MCP server');
const result = await response.json();
if (result.success) {
await loadMcpConfig();
renderMcpManager();
showRefreshToast(t('mcp.codex.serverAdded', { name: serverName }), 'success');
} else {
showRefreshToast(result.error || t('mcp.codex.addFailed'), 'error');
}
return result;
} catch (err) {
console.error('Failed to add Codex MCP server:', err);
showRefreshToast(t('mcp.codex.addFailed') + ': ' + err.message, 'error');
return null;
}
}
/**
* Remove MCP server from Codex config.toml
*/
async function removeCodexMcpServer(serverName) {
try {
const response = await fetch('/api/codex-mcp-remove', {
method: 'POST',
headers: { 'Content-Type': 'application/json' },
body: JSON.stringify({ serverName })
});
if (!response.ok) throw new Error('Failed to remove Codex MCP server');
const result = await response.json();
if (result.success) {
await loadMcpConfig();
renderMcpManager();
showRefreshToast(t('mcp.codex.serverRemoved', { name: serverName }), 'success');
} else {
showRefreshToast(result.error || t('mcp.codex.removeFailed'), 'error');
}
return result;
} catch (err) {
console.error('Failed to remove Codex MCP server:', err);
showRefreshToast(t('mcp.codex.removeFailed') + ': ' + err.message, 'error');
return null;
}
}
/**
* Toggle Codex MCP server enabled state
*/
async function toggleCodexMcpServer(serverName, enabled) {
try {
const response = await fetch('/api/codex-mcp-toggle', {
method: 'POST',
headers: { 'Content-Type': 'application/json' },
body: JSON.stringify({ serverName, enabled })
});
if (!response.ok) throw new Error('Failed to toggle Codex MCP server');
const result = await response.json();
if (result.success) {
await loadMcpConfig();
renderMcpManager();
showRefreshToast(t('mcp.codex.serverToggled', { name: serverName, state: enabled ? 'enabled' : 'disabled' }), 'success');
}
return result;
} catch (err) {
console.error('Failed to toggle Codex MCP server:', err);
showRefreshToast(t('mcp.codex.toggleFailed') + ': ' + err.message, 'error');
return null;
}
}
/**
* Copy Claude MCP server to Codex
*/
async function copyClaudeServerToCodex(serverName, serverConfig) {
return await addCodexMcpServer(serverName, serverConfig);
}
/**
* Copy Codex MCP server to Claude (global)
*/
async function copyCodexServerToClaude(serverName, serverConfig) {
// Convert Codex format to Claude format
const claudeConfig = {
command: serverConfig.command,
args: serverConfig.args || [],
};
if (serverConfig.env) {
claudeConfig.env = serverConfig.env;
}
// If it's an HTTP server
if (serverConfig.url) {
claudeConfig.url = serverConfig.url;
}
return await addGlobalMcpServer(serverName, claudeConfig);
}
async function toggleMcpServer(serverName, enable) {
try {
const response = await fetch('/api/mcp-toggle', {
@@ -255,7 +395,7 @@ async function removeGlobalMcpServer(serverName) {
function updateMcpBadge() {
const badge = document.getElementById('badgeMcpServers');
if (badge) {
const currentPath = projectPath.replace(/\//g, '\\');
const currentPath = projectPath; // Keep original format (forward slash)
const projectData = mcpAllProjects[currentPath];
const servers = projectData?.mcpServers || {};
const disabledServers = projectData?.disabledMcpServers || [];
@@ -702,7 +842,20 @@ async function createMcpServerWithConfig(name, serverConfig, scope = 'project')
// Submit to API
try {
let response;
if (scope === 'global') {
let scopeLabel;
if (scope === 'codex') {
// Create in Codex config.toml
response = await fetch('/api/codex-mcp-add', {
method: 'POST',
headers: { 'Content-Type': 'application/json' },
body: JSON.stringify({
serverName: name,
serverConfig: serverConfig
})
});
scopeLabel = 'Codex';
} else if (scope === 'global') {
response = await fetch('/api/mcp-add-global-server', {
method: 'POST',
headers: { 'Content-Type': 'application/json' },
@@ -711,6 +864,7 @@ async function createMcpServerWithConfig(name, serverConfig, scope = 'project')
serverConfig: serverConfig
})
});
scopeLabel = 'global';
} else {
response = await fetch('/api/mcp-copy-server', {
method: 'POST',
@@ -721,6 +875,7 @@ async function createMcpServerWithConfig(name, serverConfig, scope = 'project')
serverConfig: serverConfig
})
});
scopeLabel = 'project';
}
if (!response.ok) throw new Error('Failed to create MCP server');
@@ -730,7 +885,6 @@ async function createMcpServerWithConfig(name, serverConfig, scope = 'project')
closeMcpCreateModal();
await loadMcpConfig();
renderMcpManager();
const scopeLabel = scope === 'global' ? 'global' : 'project';
showRefreshToast(`MCP server "${name}" created in ${scopeLabel} scope`, 'success');
} else {
showRefreshToast(result.error || 'Failed to create MCP server', 'error');
@@ -787,7 +941,7 @@ function buildCcwToolsConfig(selectedTools) {
return config;
}
async function installCcwToolsMcp() {
async function installCcwToolsMcp(scope = 'workspace') {
const selectedTools = getSelectedCcwTools();
if (selectedTools.length === 0) {
@@ -798,27 +952,52 @@ async function installCcwToolsMcp() {
const ccwToolsConfig = buildCcwToolsConfig(selectedTools);
try {
showRefreshToast('Installing CCW Tools MCP...', 'info');
const scopeLabel = scope === 'global' ? 'globally' : 'to workspace';
showRefreshToast(`Installing CCW Tools MCP ${scopeLabel}...`, 'info');
const response = await fetch('/api/mcp-copy-server', {
method: 'POST',
headers: { 'Content-Type': 'application/json' },
body: JSON.stringify({
projectPath: projectPath,
serverName: 'ccw-tools',
serverConfig: ccwToolsConfig
})
});
if (scope === 'global') {
// Install to global (~/.claude.json mcpServers)
const response = await fetch('/api/mcp-add-global', {
method: 'POST',
headers: { 'Content-Type': 'application/json' },
body: JSON.stringify({
serverName: 'ccw-tools',
serverConfig: ccwToolsConfig
})
});
if (!response.ok) throw new Error('Failed to install CCW Tools MCP');
if (!response.ok) throw new Error('Failed to install CCW Tools MCP globally');
const result = await response.json();
if (result.success) {
await loadMcpConfig();
renderMcpManager();
showRefreshToast(`CCW Tools installed (${selectedTools.length} tools)`, 'success');
const result = await response.json();
if (result.success) {
await loadMcpConfig();
renderMcpManager();
showRefreshToast(`CCW Tools installed globally (${selectedTools.length} tools)`, 'success');
} else {
showRefreshToast(result.error || 'Failed to install CCW Tools MCP globally', 'error');
}
} else {
showRefreshToast(result.error || 'Failed to install CCW Tools MCP', 'error');
// Install to workspace (.mcp.json)
const response = await fetch('/api/mcp-copy-server', {
method: 'POST',
headers: { 'Content-Type': 'application/json' },
body: JSON.stringify({
projectPath: projectPath,
serverName: 'ccw-tools',
serverConfig: ccwToolsConfig
})
});
if (!response.ok) throw new Error('Failed to install CCW Tools MCP to workspace');
const result = await response.json();
if (result.success) {
await loadMcpConfig();
renderMcpManager();
showRefreshToast(`CCW Tools installed to workspace (${selectedTools.length} tools)`, 'success');
} else {
showRefreshToast(result.error || 'Failed to install CCW Tools MCP to workspace', 'error');
}
}
} catch (err) {
console.error('Failed to install CCW Tools MCP:', err);
@@ -826,7 +1005,7 @@ async function installCcwToolsMcp() {
}
}
async function updateCcwToolsMcp() {
async function updateCcwToolsMcp(scope = 'workspace') {
const selectedTools = getSelectedCcwTools();
if (selectedTools.length === 0) {
@@ -837,27 +1016,52 @@ async function updateCcwToolsMcp() {
const ccwToolsConfig = buildCcwToolsConfig(selectedTools);
try {
showRefreshToast('Updating CCW Tools MCP...', 'info');
const scopeLabel = scope === 'global' ? 'globally' : 'in workspace';
showRefreshToast(`Updating CCW Tools MCP ${scopeLabel}...`, 'info');
const response = await fetch('/api/mcp-copy-server', {
method: 'POST',
headers: { 'Content-Type': 'application/json' },
body: JSON.stringify({
projectPath: projectPath,
serverName: 'ccw-tools',
serverConfig: ccwToolsConfig
})
});
if (scope === 'global') {
// Update global (~/.claude.json mcpServers)
const response = await fetch('/api/mcp-add-global', {
method: 'POST',
headers: { 'Content-Type': 'application/json' },
body: JSON.stringify({
serverName: 'ccw-tools',
serverConfig: ccwToolsConfig
})
});
if (!response.ok) throw new Error('Failed to update CCW Tools MCP');
if (!response.ok) throw new Error('Failed to update CCW Tools MCP globally');
const result = await response.json();
if (result.success) {
await loadMcpConfig();
renderMcpManager();
showRefreshToast(`CCW Tools updated (${selectedTools.length} tools)`, 'success');
const result = await response.json();
if (result.success) {
await loadMcpConfig();
renderMcpManager();
showRefreshToast(`CCW Tools updated globally (${selectedTools.length} tools)`, 'success');
} else {
showRefreshToast(result.error || 'Failed to update CCW Tools MCP globally', 'error');
}
} else {
showRefreshToast(result.error || 'Failed to update CCW Tools MCP', 'error');
// Update workspace (.mcp.json)
const response = await fetch('/api/mcp-copy-server', {
method: 'POST',
headers: { 'Content-Type': 'application/json' },
body: JSON.stringify({
projectPath: projectPath,
serverName: 'ccw-tools',
serverConfig: ccwToolsConfig
})
});
if (!response.ok) throw new Error('Failed to update CCW Tools MCP in workspace');
const result = await response.json();
if (result.success) {
await loadMcpConfig();
renderMcpManager();
showRefreshToast(`CCW Tools updated in workspace (${selectedTools.length} tools)`, 'success');
} else {
showRefreshToast(result.error || 'Failed to update CCW Tools MCP in workspace', 'error');
}
}
} catch (err) {
console.error('Failed to update CCW Tools MCP:', err);

View File

@@ -96,7 +96,7 @@ function renderImplPlanContent(implPlan) {
// Lite Context Tab Rendering
// ==========================================
function renderLiteContextContent(context, explorations, session) {
function renderLiteContextContent(context, explorations, session, diagnoses) {
const plan = session.plan || {};
let sections = [];
@@ -105,6 +105,11 @@ function renderLiteContextContent(context, explorations, session) {
sections.push(renderExplorationContext(explorations));
}
// Render diagnoses if available (from diagnosis-*.json files)
if (diagnoses && diagnoses.manifest) {
sections.push(renderDiagnosisContext(diagnoses));
}
// If we have context from context-package.json
if (context) {
sections.push(`
@@ -153,7 +158,7 @@ function renderLiteContextContent(context, explorations, session) {
<div class="tab-empty-state">
<div class="empty-icon"><i data-lucide="package" class="w-12 h-12"></i></div>
<div class="empty-title">No Context Data</div>
<div class="empty-text">No context-package.json or exploration files found for this session.</div>
<div class="empty-text">No context-package.json, exploration files, or diagnosis files found for this session.</div>
</div>
`;
}
@@ -185,15 +190,19 @@ function renderExplorationContext(explorations) {
`);
// Render each exploration angle as collapsible section
const explorationOrder = ['architecture', 'dependencies', 'patterns', 'integration-points'];
const explorationOrder = ['architecture', 'dependencies', 'patterns', 'integration-points', 'testing'];
const explorationTitles = {
'architecture': '<i data-lucide="blocks" class="w-4 h-4 inline mr-1"></i>Architecture',
'dependencies': '<i data-lucide="package" class="w-4 h-4 inline mr-1"></i>Dependencies',
'patterns': '<i data-lucide="git-branch" class="w-4 h-4 inline mr-1"></i>Patterns',
'integration-points': '<i data-lucide="plug" class="w-4 h-4 inline mr-1"></i>Integration Points'
'integration-points': '<i data-lucide="plug" class="w-4 h-4 inline mr-1"></i>Integration Points',
'testing': '<i data-lucide="flask-conical" class="w-4 h-4 inline mr-1"></i>Testing'
};
for (const angle of explorationOrder) {
// Collect all angles from data (in case there are exploration angles not in our predefined list)
const allAngles = [...new Set([...explorationOrder, ...Object.keys(data)])];
for (const angle of allAngles) {
const expData = data[angle];
if (!expData) {
continue;
@@ -205,7 +214,7 @@ function renderExplorationContext(explorations) {
<div class="exploration-section collapsible-section">
<div class="collapsible-header">
<span class="collapse-icon">▶</span>
<span class="section-label">${explorationTitles[angle] || angle}</span>
<span class="section-label">${explorationTitles[angle] || ('<i data-lucide="file-search" class="w-4 h-4 inline mr-1"></i>' + escapeHtml(angle.toUpperCase()))}</span>
</div>
<div class="collapsible-content collapsed">
${angleContent}
@@ -271,3 +280,145 @@ function renderExplorationAngle(angle, data) {
return content.join('') || '<p>No data available</p>';
}
// ==========================================
// Diagnosis Context Rendering
// ==========================================
function renderDiagnosisContext(diagnoses) {
if (!diagnoses || !diagnoses.manifest) {
return '';
}
const manifest = diagnoses.manifest;
const data = diagnoses.data || {};
let sections = [];
// Header with manifest info
sections.push(`
<div class="diagnosis-header">
<h4><i data-lucide="stethoscope" class="w-4 h-4 inline mr-1"></i> ${escapeHtml(manifest.task_description || 'Diagnosis Context')}</h4>
<div class="diagnosis-meta">
<span class="meta-item">Diagnoses: <strong>${manifest.diagnosis_count || 0}</strong></span>
</div>
</div>
`);
// Render each diagnosis angle as collapsible section
const diagnosisOrder = ['root-cause', 'api-contracts', 'dataflow', 'performance', 'security', 'error-handling'];
const diagnosisTitles = {
'root-cause': '<i data-lucide="search" class="w-4 h-4 inline mr-1"></i>Root Cause',
'api-contracts': '<i data-lucide="plug" class="w-4 h-4 inline mr-1"></i>API Contracts',
'dataflow': '<i data-lucide="git-merge" class="w-4 h-4 inline mr-1"></i>Data Flow',
'performance': '<i data-lucide="zap" class="w-4 h-4 inline mr-1"></i>Performance',
'security': '<i data-lucide="shield" class="w-4 h-4 inline mr-1"></i>Security',
'error-handling': '<i data-lucide="alert-circle" class="w-4 h-4 inline mr-1"></i>Error Handling'
};
// Collect all angles from data (in case there are diagnosis angles not in our predefined list)
const allAngles = [...new Set([...diagnosisOrder, ...Object.keys(data)])];
for (const angle of allAngles) {
const diagData = data[angle];
if (!diagData) {
continue;
}
const angleContent = renderDiagnosisAngle(angle, diagData);
sections.push(`
<div class="diagnosis-section collapsible-section">
<div class="collapsible-header">
<span class="collapse-icon">▶</span>
<span class="section-label">${diagnosisTitles[angle] || ('<i data-lucide="file-search" class="w-4 h-4 inline mr-1"></i>' + angle)}</span>
</div>
<div class="collapsible-content collapsed">
${angleContent}
</div>
</div>
`);
}
return `<div class="diagnosis-context">${sections.join('')}</div>`;
}
function renderDiagnosisAngle(angle, data) {
let content = [];
// Summary/Overview
if (data.summary || data.overview) {
content.push(renderExpField('Summary', data.summary || data.overview));
}
// Root cause analysis
if (data.root_cause || data.root_cause_analysis) {
content.push(renderExpField('Root Cause', data.root_cause || data.root_cause_analysis));
}
// Issues/Findings
if (data.issues && Array.isArray(data.issues)) {
content.push(`
<div class="exp-field">
<label>Issues Found (${data.issues.length})</label>
<div class="issues-list">
${data.issues.map(issue => {
if (typeof issue === 'string') {
return `<div class="issue-item">${escapeHtml(issue)}</div>`;
} else {
return `
<div class="issue-item">
<div class="issue-title">${escapeHtml(issue.title || issue.description || 'Unknown')}</div>
${issue.location ? `<div class="issue-location"><code>${escapeHtml(issue.location)}</code></div>` : ''}
${issue.severity ? `<span class="severity-badge ${escapeHtml(issue.severity)}">${escapeHtml(issue.severity)}</span>` : ''}
</div>
`;
}
}).join('')}
</div>
</div>
`);
}
// Affected files
if (data.affected_files && Array.isArray(data.affected_files)) {
content.push(`
<div class="exp-field">
<label>Affected Files (${data.affected_files.length})</label>
<div class="path-tags">
${data.affected_files.map(f => {
const filePath = typeof f === 'string' ? f : (f.path || f.file || '');
return `<span class="path-tag">${escapeHtml(filePath)}</span>`;
}).join('')}
</div>
</div>
`);
}
// Recommendations
if (data.recommendations && Array.isArray(data.recommendations)) {
content.push(`
<div class="exp-field">
<label>Recommendations</label>
<ol class="recommendations-list">
${data.recommendations.map(rec => {
const recText = typeof rec === 'string' ? rec : (rec.description || rec.action || '');
return `<li>${escapeHtml(recText)}</li>`;
}).join('')}
</ol>
</div>
`);
}
// API Contracts (specific to api-contracts diagnosis)
if (data.contracts && Array.isArray(data.contracts)) {
content.push(renderExpField('API Contracts', data.contracts));
}
// Data flow (specific to dataflow diagnosis)
if (data.dataflow || data.data_flow) {
content.push(renderExpField('Data Flow', data.dataflow || data.data_flow));
}
return content.join('') || '<p>No diagnosis data available</p>';
}

View File

@@ -378,9 +378,11 @@ const i18n = {
'mcp.newProjectServer': 'New Project Server',
'mcp.newServer': 'New Server',
'mcp.newGlobalServer': 'New Global Server',
'mcp.copyInstallCmd': 'Copy Install Command',
'mcp.installCmdCopied': 'Install command copied to clipboard',
'mcp.installCmdFailed': 'Failed to copy install command',
'mcp.installToProject': 'Install to Project',
'mcp.installToGlobal': 'Install to Global',
'mcp.installToWorkspace': 'Install to Workspace',
'mcp.updateInWorkspace': 'Update in Workspace',
'mcp.updateInGlobal': 'Update in Global',
'mcp.serversConfigured': 'servers configured',
'mcp.serversAvailable': 'servers available',
'mcp.globalAvailable': '全局可用 MCP',
@@ -413,6 +415,26 @@ const i18n = {
'mcp.availableToAll': 'Available to all projects from ~/.claude.json',
'mcp.managedByOrg': 'Managed by organization (highest priority)',
'mcp.variables': 'variables',
'mcp.cmd': 'Command',
'mcp.url': 'URL',
'mcp.args': 'Arguments',
'mcp.env': 'Environment',
'mcp.usedInCount': 'Used in {count} project{s}',
'mcp.from': 'from',
'mcp.variant': 'variant',
'mcp.sourceEnterprise': 'Enterprise',
'mcp.sourceGlobal': 'Global',
'mcp.sourceProject': 'Project',
'mcp.viewDetails': 'View Details',
'mcp.clickToViewDetails': 'Click to view details',
// MCP Details Modal
'mcp.detailsModal.title': 'MCP Server Details',
'mcp.detailsModal.close': 'Close',
'mcp.detailsModal.serverName': 'Server Name',
'mcp.detailsModal.source': 'Source',
'mcp.detailsModal.configuration': 'Configuration',
'mcp.detailsModal.noEnv': 'No environment variables',
// MCP Create Modal
'mcp.createTitle': 'Create MCP Server',
@@ -456,6 +478,34 @@ const i18n = {
'mcp.toProject': 'To Project',
'mcp.toGlobal': 'To Global',
// MCP CLI Mode
'mcp.cliMode': 'CLI Mode',
'mcp.claudeMode': 'Claude Mode',
'mcp.codexMode': 'Codex Mode',
// Codex MCP
'mcp.codex.globalServers': 'Codex Global MCP Servers',
'mcp.codex.newServer': 'New Server',
'mcp.codex.noServers': 'No Codex MCP servers configured',
'mcp.codex.noServersHint': 'Add servers via "codex mcp add" or create one here',
'mcp.codex.infoTitle': 'About Codex MCP',
'mcp.codex.infoDesc': 'Codex MCP servers are global only (stored in ~/.codex/config.toml). Use TOML format for configuration.',
'mcp.codex.serverAdded': 'Codex MCP server "{name}" added',
'mcp.codex.addFailed': 'Failed to add Codex MCP server',
'mcp.codex.serverRemoved': 'Codex MCP server "{name}" removed',
'mcp.codex.removeFailed': 'Failed to remove Codex MCP server',
'mcp.codex.serverToggled': 'Codex MCP server "{name}" {state}',
'mcp.codex.toggleFailed': 'Failed to toggle Codex MCP server',
'mcp.codex.remove': 'Remove',
'mcp.codex.removeConfirm': 'Remove Codex MCP server "{name}"?',
'mcp.codex.copyToClaude': 'Copy to Claude',
'mcp.codex.copyToCodex': 'Copy to Codex',
'mcp.codex.copyFromClaude': 'Copy Claude Servers to Codex',
'mcp.codex.alreadyAdded': 'Already in Codex',
'mcp.codex.scopeCodex': 'Codex - Global (~/.codex/config.toml)',
'mcp.codex.enabledTools': 'Tools',
'mcp.codex.tools': 'tools enabled',
// Hook Manager
'hook.projectHooks': 'Project Hooks',
'hook.projectFile': '.claude/settings.json',
@@ -1316,9 +1366,11 @@ const i18n = {
// MCP Manager
'mcp.currentAvailable': '当前可用 MCP',
'mcp.copyInstallCmd': '复制安装命令',
'mcp.installCmdCopied': '安装命令已复制到剪贴板',
'mcp.installCmdFailed': '复制安装命令失败',
'mcp.installToProject': '安装到项目',
'mcp.installToGlobal': '安装到全局',
'mcp.installToWorkspace': '安装到工作空间',
'mcp.updateInWorkspace': '在工作空间更新',
'mcp.updateInGlobal': '在全局更新',
'mcp.projectAvailable': '当前可用 MCP',
'mcp.newServer': '新建服务器',
'mcp.newGlobalServer': '新建全局服务器',
@@ -1355,7 +1407,27 @@ const i18n = {
'mcp.availableToAll': '可用于所有项目,来自 ~/.claude.json',
'mcp.managedByOrg': '由组织管理(最高优先级)',
'mcp.variables': '个变量',
'mcp.cmd': '命令',
'mcp.url': '地址',
'mcp.args': '参数',
'mcp.env': '环境变量',
'mcp.usedInCount': '用于 {count} 个项目',
'mcp.from': '来自',
'mcp.variant': '变体',
'mcp.sourceEnterprise': '企业级',
'mcp.sourceGlobal': '全局',
'mcp.sourceProject': '项目级',
'mcp.viewDetails': '查看详情',
'mcp.clickToViewDetails': '点击查看详情',
// MCP Details Modal
'mcp.detailsModal.title': 'MCP 服务器详情',
'mcp.detailsModal.close': '关闭',
'mcp.detailsModal.serverName': '服务器名称',
'mcp.detailsModal.source': '来源',
'mcp.detailsModal.configuration': '配置',
'mcp.detailsModal.noEnv': '无环境变量',
// MCP Create Modal
'mcp.createTitle': '创建 MCP 服务器',
'mcp.form': '表单',
@@ -1375,7 +1447,35 @@ const i18n = {
'mcp.installToMcpJson': '安装到 .mcp.json推荐',
'mcp.claudeJsonDesc': '保存在根目录 .claude.json projects 字段下(共享配置)',
'mcp.mcpJsonDesc': '保存在项目 .mcp.json 文件中(推荐用于版本控制)',
// MCP CLI Mode
'mcp.cliMode': 'CLI 模式',
'mcp.claudeMode': 'Claude 模式',
'mcp.codexMode': 'Codex 模式',
// Codex MCP
'mcp.codex.globalServers': 'Codex 全局 MCP 服务器',
'mcp.codex.newServer': '新建服务器',
'mcp.codex.noServers': '未配置 Codex MCP 服务器',
'mcp.codex.noServersHint': '使用 "codex mcp add" 命令或在此处创建',
'mcp.codex.infoTitle': '关于 Codex MCP',
'mcp.codex.infoDesc': 'Codex MCP 服务器仅支持全局配置(存储在 ~/.codex/config.toml。使用 TOML 格式配置。',
'mcp.codex.serverAdded': 'Codex MCP 服务器 "{name}" 已添加',
'mcp.codex.addFailed': '添加 Codex MCP 服务器失败',
'mcp.codex.serverRemoved': 'Codex MCP 服务器 "{name}" 已移除',
'mcp.codex.removeFailed': '移除 Codex MCP 服务器失败',
'mcp.codex.serverToggled': 'Codex MCP 服务器 "{name}" 已{state}',
'mcp.codex.toggleFailed': '切换 Codex MCP 服务器失败',
'mcp.codex.remove': '移除',
'mcp.codex.removeConfirm': '移除 Codex MCP 服务器 "{name}"',
'mcp.codex.copyToClaude': '复制到 Claude',
'mcp.codex.copyToCodex': '复制到 Codex',
'mcp.codex.copyFromClaude': '从 Claude 复制服务器到 Codex',
'mcp.codex.alreadyAdded': '已在 Codex 中',
'mcp.codex.scopeCodex': 'Codex - 全局 (~/.codex/config.toml)',
'mcp.codex.enabledTools': '工具',
'mcp.codex.tools': '个工具已启用',
// Hook Manager
'hook.projectHooks': '项目钩子',
'hook.projectFile': '.claude/settings.json',

View File

@@ -140,7 +140,7 @@ function toggleSection(header) {
function initCollapsibleSections(container) {
setTimeout(() => {
const headers = container.querySelectorAll('.collapsible-header');
headers.forEach(header => {
headers.forEach((header) => {
if (!header._clickBound) {
header._clickBound = true;
header.addEventListener('click', function(e) {

View File

@@ -160,11 +160,13 @@ function showLiteTaskDetailPage(sessionKey) {
</div>
`;
// Initialize collapsible sections
// Initialize collapsible sections and task click handlers
setTimeout(() => {
document.querySelectorAll('.collapsible-header').forEach(header => {
header.addEventListener('click', () => toggleSection(header));
});
// Bind click events to lite task items on initial load
initLiteTaskClickHandlers();
}, 50);
}
@@ -194,11 +196,13 @@ function switchLiteDetailTab(tabName) {
switch (tabName) {
case 'tasks':
contentArea.innerHTML = renderLiteTasksTab(session, tasks, completed, inProgress, pending);
// Re-initialize collapsible sections
// Re-initialize collapsible sections and task click handlers
setTimeout(() => {
document.querySelectorAll('.collapsible-header').forEach(header => {
header.addEventListener('click', () => toggleSection(header));
});
// Bind click events to lite task items
initLiteTaskClickHandlers();
}, 50);
break;
case 'plan':
@@ -259,12 +263,16 @@ function renderLiteTaskDetailItem(sessionId, task) {
const implCount = rawTask.implementation?.length || 0;
const acceptCount = rawTask.acceptance?.length || 0;
// Escape for data attributes
const safeSessionId = escapeHtml(sessionId);
const safeTaskId = escapeHtml(task.id);
return `
<div class="detail-task-item-full lite-task-item" onclick="openTaskDrawerForLite('${sessionId}', '${escapeHtml(task.id)}')" style="cursor: pointer;" title="Click to view details">
<div class="detail-task-item-full lite-task-item" data-session-id="${safeSessionId}" data-task-id="${safeTaskId}" style="cursor: pointer;" title="Click to view details">
<div class="task-item-header-lite">
<span class="task-id-badge">${escapeHtml(task.id)}</span>
<span class="task-title">${escapeHtml(task.title || 'Untitled')}</span>
<button class="btn-view-json" onclick="event.stopPropagation(); showJsonModal('${taskJsonId}', '${escapeHtml(task.id)}')">{ } JSON</button>
<button class="btn-view-json" data-task-json-id="${taskJsonId}" data-task-display-id="${safeTaskId}">{ } JSON</button>
</div>
<div class="task-item-meta-lite">
${action ? `<span class="meta-badge action">${escapeHtml(action)}</span>` : ''}
@@ -285,6 +293,39 @@ function getMetaPreviewForLite(task, rawTask) {
return parts.join(' | ') || 'No meta';
}
/**
* Initialize click handlers for lite task items
*/
function initLiteTaskClickHandlers() {
// Task item click handlers
document.querySelectorAll('.lite-task-item').forEach(item => {
if (!item._clickBound) {
item._clickBound = true;
item.addEventListener('click', function(e) {
// Don't trigger if clicking on JSON button
if (e.target.closest('.btn-view-json')) return;
const sessionId = this.dataset.sessionId;
const taskId = this.dataset.taskId;
openTaskDrawerForLite(sessionId, taskId);
});
}
});
// JSON button click handlers
document.querySelectorAll('.btn-view-json').forEach(btn => {
if (!btn._clickBound) {
btn._clickBound = true;
btn.addEventListener('click', function(e) {
e.stopPropagation();
const taskJsonId = this.dataset.taskJsonId;
const displayId = this.dataset.taskDisplayId;
showJsonModal(taskJsonId, displayId);
});
}
});
}
function openTaskDrawerForLite(sessionId, taskId) {
const session = liteTaskDataStore[currentSessionDetailKey];
if (!session) return;
@@ -454,15 +495,15 @@ async function loadAndRenderLiteContextTab(session, contentArea) {
const response = await fetch(`/api/session-detail?path=${encodeURIComponent(session.path)}&type=context`);
if (response.ok) {
const data = await response.json();
contentArea.innerHTML = renderLiteContextContent(data.context, data.explorations, session);
// Re-initialize collapsible sections for explorations (scoped to contentArea)
contentArea.innerHTML = renderLiteContextContent(data.context, data.explorations, session, data.diagnoses);
// Re-initialize collapsible sections for explorations and diagnoses (scoped to contentArea)
initCollapsibleSections(contentArea);
return;
}
}
// Fallback: show plan context if available
contentArea.innerHTML = renderLiteContextContent(null, null, session);
contentArea.innerHTML = renderLiteContextContent(null, null, session, null);
initCollapsibleSections(contentArea);
} catch (err) {
contentArea.innerHTML = `<div class="tab-error">Failed to load context: ${err.message}</div>`;
@@ -530,7 +571,9 @@ function renderDiagnosesTab(session) {
// Individual diagnosis items
if (diagnoses.items && diagnoses.items.length > 0) {
const diagnosisCards = diagnoses.items.map(diag => renderDiagnosisCard(diag)).join('');
const diagnosisCards = diagnoses.items.map((diag) => {
return renderDiagnosisCard(diag);
}).join('');
sections.push(`
<div class="diagnoses-items-section">
<h4 class="diagnoses-section-title"><i data-lucide="search" class="w-4 h-4 inline mr-1"></i> Diagnosis Details (${diagnoses.items.length})</h4>
@@ -565,7 +608,21 @@ function renderDiagnosisCard(diag) {
function renderDiagnosisContent(diag) {
let content = [];
// Summary/Overview
// Symptom (for detailed diagnosis structure)
if (diag.symptom) {
const symptom = diag.symptom;
content.push(`
<div class="diag-section">
<strong>Symptom:</strong>
${symptom.description ? `<p>${escapeHtml(symptom.description)}</p>` : ''}
${symptom.user_impact ? `<div class="symptom-impact"><strong>User Impact:</strong> ${escapeHtml(symptom.user_impact)}</div>` : ''}
${symptom.frequency ? `<div class="symptom-freq"><strong>Frequency:</strong> <span class="badge">${escapeHtml(symptom.frequency)}</span></div>` : ''}
${symptom.error_message ? `<div class="symptom-error"><strong>Error:</strong> <code>${escapeHtml(symptom.error_message)}</code></div>` : ''}
</div>
`);
}
// Summary/Overview (for simple diagnosis structure)
if (diag.summary || diag.overview) {
content.push(`
<div class="diag-section">
@@ -576,11 +633,34 @@ function renderDiagnosisContent(diag) {
}
// Root Cause Analysis
if (diag.root_cause || diag.root_cause_analysis) {
if (diag.root_cause) {
const rootCause = diag.root_cause;
// Handle both object and string formats
if (typeof rootCause === 'object') {
content.push(`
<div class="diag-section">
<strong>Root Cause:</strong>
${rootCause.file ? `<div class="rc-file"><strong>File:</strong> <code>${escapeHtml(rootCause.file)}</code></div>` : ''}
${rootCause.line_range ? `<div class="rc-line"><strong>Lines:</strong> ${escapeHtml(rootCause.line_range)}</div>` : ''}
${rootCause.function ? `<div class="rc-func"><strong>Function:</strong> <code>${escapeHtml(rootCause.function)}</code></div>` : ''}
${rootCause.issue ? `<p>${escapeHtml(rootCause.issue)}</p>` : ''}
${rootCause.confidence ? `<div class="rc-confidence"><strong>Confidence:</strong> ${(rootCause.confidence * 100).toFixed(0)}%</div>` : ''}
${rootCause.category ? `<div class="rc-category"><strong>Category:</strong> <span class="badge">${escapeHtml(rootCause.category)}</span></div>` : ''}
</div>
`);
} else if (typeof rootCause === 'string') {
content.push(`
<div class="diag-section">
<strong>Root Cause:</strong>
<p>${escapeHtml(rootCause)}</p>
</div>
`);
}
} else if (diag.root_cause_analysis) {
content.push(`
<div class="diag-section">
<strong>Root Cause:</strong>
<p>${escapeHtml(diag.root_cause || diag.root_cause_analysis)}</p>
<p>${escapeHtml(diag.root_cause_analysis)}</p>
</div>
`);
}
@@ -660,6 +740,37 @@ function renderDiagnosisContent(diag) {
`);
}
// Reproduction Steps
if (diag.reproduction_steps && Array.isArray(diag.reproduction_steps)) {
content.push(`
<div class="diag-section">
<strong>Reproduction Steps:</strong>
<ol class="repro-steps-list">
${diag.reproduction_steps.map(step => `<li>${escapeHtml(step)}</li>`).join('')}
</ol>
</div>
`);
}
// Fix Hints
if (diag.fix_hints && Array.isArray(diag.fix_hints)) {
content.push(`
<div class="diag-section">
<strong>Fix Hints (${diag.fix_hints.length}):</strong>
<div class="fix-hints-list">
${diag.fix_hints.map((hint, idx) => `
<div class="fix-hint-item">
<div class="hint-header"><strong>Hint ${idx + 1}:</strong> ${escapeHtml(hint.description || 'No description')}</div>
${hint.approach ? `<div class="hint-approach"><strong>Approach:</strong> ${escapeHtml(hint.approach)}</div>` : ''}
${hint.risk ? `<div class="hint-risk"><strong>Risk:</strong> <span class="badge risk-${hint.risk}">${escapeHtml(hint.risk)}</span></div>` : ''}
${hint.code_example ? `<div class="hint-code"><strong>Code Example:</strong><pre><code>${escapeHtml(hint.code_example)}</code></pre></div>` : ''}
</div>
`).join('')}
</div>
</div>
`);
}
// Recommendations
if (diag.recommendations && Array.isArray(diag.recommendations)) {
content.push(`
@@ -672,10 +783,75 @@ function renderDiagnosisContent(diag) {
`);
}
// If no specific content was rendered, show raw JSON preview
if (content.length === 0) {
// Dependencies
if (diag.dependencies && typeof diag.dependencies === 'string') {
content.push(`
<div class="diag-section">
<strong>Dependencies:</strong>
<p>${escapeHtml(diag.dependencies)}</p>
</div>
`);
}
// Constraints
if (diag.constraints && typeof diag.constraints === 'string') {
content.push(`
<div class="diag-section">
<strong>Constraints:</strong>
<p>${escapeHtml(diag.constraints)}</p>
</div>
`);
}
// Clarification Needs
if (diag.clarification_needs && Array.isArray(diag.clarification_needs)) {
content.push(`
<div class="diag-section">
<strong>Clarification Needs:</strong>
<div class="clarification-list">
${diag.clarification_needs.map(clar => `
<div class="clarification-item">
<div class="clar-question"><strong>Q:</strong> ${escapeHtml(clar.question)}</div>
${clar.context ? `<div class="clar-context"><strong>Context:</strong> ${escapeHtml(clar.context)}</div>` : ''}
${clar.options && Array.isArray(clar.options) ? `
<div class="clar-options">
<strong>Options:</strong>
<ul>
${clar.options.map(opt => `<li>${escapeHtml(opt)}</li>`).join('')}
</ul>
</div>
` : ''}
</div>
`).join('')}
</div>
</div>
`);
}
// Related Issues
if (diag.related_issues && Array.isArray(diag.related_issues)) {
content.push(`
<div class="diag-section">
<strong>Related Issues:</strong>
<ul class="related-issues-list">
${diag.related_issues.map(issue => `
<li>
${issue.type ? `<span class="issue-type-badge">${escapeHtml(issue.type)}</span>` : ''}
${issue.reference ? `<strong>${escapeHtml(issue.reference)}</strong>: ` : ''}
${issue.description ? escapeHtml(issue.description) : ''}
</li>
`).join('')}
</ul>
</div>
`);
}
// If no specific content was rendered, show raw JSON preview
if (content.length === 0) {
console.warn('[DEBUG] No content rendered for diagnosis:', diag);
content.push(`
<div class="diag-section">
<strong>Debug: Raw JSON</strong>
<pre class="json-content">${escapeHtml(JSON.stringify(diag, null, 2))}</pre>
</div>
`);

View File

@@ -17,7 +17,7 @@ const CCW_MCP_TOOLS = [
// Get currently enabled tools from installed config
function getCcwEnabledTools() {
const currentPath = projectPath.replace(/\//g, '\\');
const currentPath = projectPath; // Keep original format (forward slash)
const projectData = mcpAllProjects[currentPath] || {};
const ccwConfig = projectData.mcpServers?.['ccw-tools'];
if (ccwConfig?.env?.CCW_ENABLED_TOOLS) {
@@ -46,7 +46,7 @@ async function renderMcpManager() {
// Load MCP templates
await loadMcpTemplates();
const currentPath = projectPath.replace(/\//g, '\\');
const currentPath = projectPath; // Keep original format (forward slash)
const projectData = mcpAllProjects[currentPath] || {};
const projectServers = projectData.mcpServers || {};
const disabledServers = projectData.disabledMcpServers || [];
@@ -121,8 +121,136 @@ async function renderMcpManager() {
const isCcwToolsInstalled = currentProjectServerNames.includes("ccw-tools");
const enabledTools = getCcwEnabledTools();
// Prepare Codex servers data
const codexServerEntries = Object.entries(codexMcpServers || {});
const codexConfigExists = codexMcpConfig?.exists || false;
const codexConfigPath = codexMcpConfig?.configPath || '~/.codex/config.toml';
container.innerHTML = `
<div class="mcp-manager">
<!-- CLI Mode Toggle -->
<div class="mcp-cli-toggle mb-6">
<div class="flex items-center justify-between bg-card border border-border rounded-lg p-4">
<div class="flex items-center gap-3">
<span class="text-sm font-medium text-foreground">${t('mcp.cliMode')}</span>
<div class="flex items-center bg-muted rounded-lg p-1">
<button class="cli-mode-btn px-4 py-2 text-sm font-medium rounded-md transition-all ${currentCliMode === 'claude' ? 'bg-primary text-primary-foreground shadow-sm' : 'text-muted-foreground hover:text-foreground'}"
onclick="setCliMode('claude')">
<i data-lucide="bot" class="w-4 h-4 inline mr-1.5"></i>
Claude
</button>
<button class="cli-mode-btn px-4 py-2 text-sm font-medium rounded-md transition-all ${currentCliMode === 'codex' ? 'bg-orange-500 text-white shadow-sm' : 'text-muted-foreground hover:text-foreground'}"
onclick="setCliMode('codex')">
<i data-lucide="code-2" class="w-4 h-4 inline mr-1.5"></i>
Codex
</button>
</div>
</div>
<div class="text-xs text-muted-foreground">
${currentCliMode === 'claude'
? `<span class="flex items-center gap-1"><i data-lucide="file-json" class="w-3 h-3"></i> ~/.claude.json</span>`
: `<span class="flex items-center gap-1"><i data-lucide="file-code" class="w-3 h-3"></i> ${codexConfigPath}</span>`
}
</div>
</div>
</div>
${currentCliMode === 'codex' ? `
<!-- Codex MCP Servers Section -->
<div class="mcp-section mb-6">
<div class="flex items-center justify-between mb-4">
<div class="flex items-center gap-3">
<div class="flex items-center gap-2">
<i data-lucide="code-2" class="w-5 h-5 text-orange-500"></i>
<h3 class="text-lg font-semibold text-foreground">${t('mcp.codex.globalServers')}</h3>
</div>
<button class="px-3 py-1.5 text-sm bg-orange-500 text-white rounded-lg hover:opacity-90 transition-opacity flex items-center gap-1"
onclick="openCodexMcpCreateModal()">
<span>+</span> ${t('mcp.codex.newServer')}
</button>
${codexConfigExists ? `
<span class="inline-flex items-center gap-1.5 px-2 py-1 text-xs bg-success/10 text-success rounded-md border border-success/20">
<i data-lucide="file-check" class="w-3.5 h-3.5"></i>
config.toml
</span>
` : `
<span class="inline-flex items-center gap-1.5 px-2 py-1 text-xs bg-muted text-muted-foreground rounded-md border border-border" title="Will create ~/.codex/config.toml">
<i data-lucide="file-plus" class="w-3.5 h-3.5"></i>
Will create config.toml
</span>
`}
</div>
<span class="text-sm text-muted-foreground">${codexServerEntries.length} ${t('mcp.serversAvailable')}</span>
</div>
<!-- Info about Codex MCP -->
<div class="bg-orange-50 dark:bg-orange-950/30 border border-orange-200 dark:border-orange-800 rounded-lg p-4 mb-4">
<div class="flex items-start gap-3">
<i data-lucide="info" class="w-5 h-5 text-orange-500 shrink-0 mt-0.5"></i>
<div class="text-sm">
<p class="text-orange-800 dark:text-orange-200 font-medium mb-1">${t('mcp.codex.infoTitle')}</p>
<p class="text-orange-700 dark:text-orange-300 text-xs">${t('mcp.codex.infoDesc')}</p>
</div>
</div>
</div>
${codexServerEntries.length === 0 ? `
<div class="mcp-empty-state bg-card border border-border rounded-lg p-6 text-center">
<div class="text-muted-foreground mb-3"><i data-lucide="plug" class="w-10 h-10 mx-auto"></i></div>
<p class="text-muted-foreground">${t('mcp.codex.noServers')}</p>
<p class="text-sm text-muted-foreground mt-1">${t('mcp.codex.noServersHint')}</p>
</div>
` : `
<div class="mcp-server-grid grid gap-3">
${codexServerEntries.map(([serverName, serverConfig]) => {
return renderCodexServerCard(serverName, serverConfig);
}).join('')}
</div>
`}
</div>
<!-- Copy Claude Servers to Codex -->
${Object.keys(mcpUserServers || {}).length > 0 ? `
<div class="mcp-section mb-6">
<div class="flex items-center justify-between mb-4">
<h3 class="text-lg font-semibold text-foreground flex items-center gap-2">
<i data-lucide="copy" class="w-5 h-5"></i>
${t('mcp.codex.copyFromClaude')}
</h3>
<span class="text-sm text-muted-foreground">${Object.keys(mcpUserServers || {}).length} ${t('mcp.serversAvailable')}</span>
</div>
<div class="mcp-server-grid grid gap-3">
${Object.entries(mcpUserServers || {}).map(([serverName, serverConfig]) => {
const alreadyInCodex = codexMcpServers && codexMcpServers[serverName];
return `
<div class="mcp-server-card bg-card border ${alreadyInCodex ? 'border-success/50' : 'border-border'} border-dashed rounded-lg p-4 hover:shadow-md transition-all">
<div class="flex items-start justify-between mb-3">
<div class="flex items-center gap-2">
<i data-lucide="bot" class="w-5 h-5 text-primary"></i>
<h4 class="font-semibold text-foreground">${escapeHtml(serverName)}</h4>
${alreadyInCodex ? `<span class="text-xs px-2 py-0.5 bg-success/10 text-success rounded-full">${t('mcp.codex.alreadyAdded')}</span>` : ''}
</div>
${!alreadyInCodex ? `
<button class="px-3 py-1 text-xs bg-orange-500 text-white rounded hover:opacity-90 transition-opacity"
onclick="copyClaudeServerToCodex('${escapeHtml(serverName)}', ${JSON.stringify(serverConfig).replace(/'/g, "&#39;")})"
title="${t('mcp.codex.copyToCodex')}">
<i data-lucide="arrow-right" class="w-3.5 h-3.5 inline"></i> Codex
</button>
` : ''}
</div>
<div class="mcp-server-details text-sm space-y-1">
<div class="flex items-center gap-2 text-muted-foreground">
<span class="font-mono text-xs bg-muted px-1.5 py-0.5 rounded">${t('mcp.cmd')}</span>
<span class="truncate" title="${escapeHtml(serverConfig.command || 'N/A')}">${escapeHtml(serverConfig.command || 'N/A')}</span>
</div>
</div>
</div>
`;
}).join('')}
</div>
</div>
` : ''}
` : `
<!-- CCW Tools MCP Server Card -->
<div class="mcp-section mb-6">
<div class="ccw-tools-card bg-gradient-to-br from-primary/10 to-primary/5 border-2 ${isCcwToolsInstalled ? 'border-success' : 'border-primary/30'} rounded-lg p-6 hover:shadow-lg transition-all">
@@ -164,17 +292,32 @@ async function renderMcpManager() {
</div>
</div>
</div>
<div class="shrink-0">
<div class="shrink-0 flex gap-2">
${isCcwToolsInstalled ? `
<button class="px-4 py-2 text-sm bg-primary text-primary-foreground rounded-lg hover:opacity-90 transition-opacity"
onclick="updateCcwToolsMcp()">
Update
<button class="px-4 py-2 text-sm bg-primary text-primary-foreground rounded-lg hover:opacity-90 transition-opacity flex items-center gap-1"
onclick="updateCcwToolsMcp('workspace')"
title="${t('mcp.updateInWorkspace')}">
<i data-lucide="folder" class="w-4 h-4"></i>
${t('mcp.updateInWorkspace')}
</button>
<button class="px-4 py-2 text-sm bg-success text-success-foreground rounded-lg hover:opacity-90 transition-opacity flex items-center gap-1"
onclick="updateCcwToolsMcp('global')"
title="${t('mcp.updateInGlobal')}">
<i data-lucide="globe" class="w-4 h-4"></i>
${t('mcp.updateInGlobal')}
</button>
` : `
<button class="px-4 py-2 text-sm bg-primary text-primary-foreground rounded-lg hover:opacity-90 transition-opacity flex items-center gap-2"
onclick="installCcwToolsMcp()">
<i data-lucide="download" class="w-4 h-4"></i>
Install
<button class="px-4 py-2 text-sm bg-primary text-primary-foreground rounded-lg hover:opacity-90 transition-opacity flex items-center gap-1"
onclick="installCcwToolsMcp('workspace')"
title="${t('mcp.installToWorkspace')}">
<i data-lucide="folder" class="w-4 h-4"></i>
${t('mcp.installToWorkspace')}
</button>
<button class="px-4 py-2 text-sm bg-success text-success-foreground rounded-lg hover:opacity-90 transition-opacity flex items-center gap-1"
onclick="installCcwToolsMcp('global')"
title="${t('mcp.installToGlobal')}">
<i data-lucide="globe" class="w-4 h-4"></i>
${t('mcp.installToGlobal')}
</button>
`}
</div>
@@ -300,12 +443,12 @@ async function renderMcpManager() {
<div class="mcp-server-details text-sm space-y-1 mb-3">
<div class="flex items-center gap-2 text-muted-foreground">
<span class="font-mono text-xs bg-muted px-1.5 py-0.5 rounded">cmd</span>
<span class="font-mono text-xs bg-muted px-1.5 py-0.5 rounded">${t('mcp.cmd')}</span>
<span class="truncate text-xs" title="${escapeHtml(template.serverConfig.command)}">${escapeHtml(template.serverConfig.command)}</span>
</div>
${template.serverConfig.args && template.serverConfig.args.length > 0 ? `
<div class="flex items-start gap-2 text-muted-foreground">
<span class="font-mono text-xs bg-muted px-1.5 py-0.5 rounded shrink-0">args</span>
<span class="font-mono text-xs bg-muted px-1.5 py-0.5 rounded shrink-0">${t('mcp.args')}</span>
<span class="text-xs font-mono truncate" title="${escapeHtml(template.serverConfig.args.join(' '))}">${escapeHtml(template.serverConfig.args.slice(0, 2).join(' '))}${template.serverConfig.args.length > 2 ? '...' : ''}</span>
</div>
` : ''}
@@ -343,7 +486,8 @@ async function renderMcpManager() {
</div>
` : ''}
<!-- All Projects MCP Overview Table -->
<!-- All Projects MCP Overview Table (Claude mode only) -->
${currentCliMode === 'claude' ? `
<div class="mcp-section mt-6">
<div class="flex items-center justify-between mb-4">
<h3 class="text-lg font-semibold text-foreground">${t('mcp.allProjects')}</h3>
@@ -411,6 +555,25 @@ async function renderMcpManager() {
</table>
</div>
</div>
` : ''}
<!-- MCP Server Details Modal -->
<div id="mcpDetailsModal" class="fixed inset-0 bg-black/50 flex items-center justify-center z-50 hidden">
<div class="bg-card border border-border rounded-lg shadow-xl max-w-2xl w-full mx-4 max-h-[80vh] overflow-hidden flex flex-col">
<!-- Modal Header -->
<div class="flex items-center justify-between px-6 py-4 border-b border-border">
<h2 class="text-lg font-semibold text-foreground">${t('mcp.detailsModal.title')}</h2>
<button id="mcpDetailsModalClose" class="text-muted-foreground hover:text-foreground transition-colors">
<i data-lucide="x" class="w-5 h-5"></i>
</button>
</div>
<!-- Modal Body -->
<div id="mcpDetailsModalBody" class="px-6 py-4 overflow-y-auto flex-1">
<!-- Content will be dynamically filled -->
</div>
</div>
</div>
</div>
`;
@@ -431,15 +594,20 @@ function renderProjectAvailableServerCard(entry) {
// Source badge
let sourceBadge = '';
if (source === 'enterprise') {
sourceBadge = '<span class="text-xs px-2 py-0.5 bg-warning/20 text-warning rounded-full">Enterprise</span>';
sourceBadge = `<span class="text-xs px-2 py-0.5 bg-warning/20 text-warning rounded-full">${t('mcp.sourceEnterprise')}</span>`;
} else if (source === 'global') {
sourceBadge = '<span class="text-xs px-2 py-0.5 bg-success/10 text-success rounded-full">Global</span>';
sourceBadge = `<span class="text-xs px-2 py-0.5 bg-success/10 text-success rounded-full">${t('mcp.sourceGlobal')}</span>`;
} else if (source === 'project') {
sourceBadge = '<span class="text-xs px-2 py-0.5 bg-primary/10 text-primary rounded-full">Project</span>';
sourceBadge = `<span class="text-xs px-2 py-0.5 bg-primary/10 text-primary rounded-full">${t('mcp.sourceProject')}</span>`;
}
return `
<div class="mcp-server-card bg-card border border-border rounded-lg p-4 hover:shadow-md transition-all ${canToggle && !isEnabled ? 'opacity-60' : ''}">
<div class="mcp-server-card bg-card border border-border rounded-lg p-4 hover:shadow-md transition-all cursor-pointer ${canToggle && !isEnabled ? 'opacity-60' : ''}"
data-server-name="${escapeHtml(name)}"
data-server-config="${escapeHtml(JSON.stringify(config))}"
data-server-source="${source}"
data-action="view-details"
title="${t('mcp.clickToViewDetails')}">
<div class="flex items-start justify-between mb-3">
<div class="flex items-center gap-2">
<span>${canToggle && isEnabled ? '<i data-lucide="check-circle" class="w-5 h-5 text-success"></i>' : '<i data-lucide="circle" class="w-5 h-5 text-muted-foreground"></i>'}</span>
@@ -447,7 +615,7 @@ function renderProjectAvailableServerCard(entry) {
${sourceBadge}
</div>
${canToggle ? `
<label class="mcp-toggle relative inline-flex items-center cursor-pointer">
<label class="mcp-toggle relative inline-flex items-center cursor-pointer" onclick="event.stopPropagation()">
<input type="checkbox" class="sr-only peer"
${isEnabled ? 'checked' : ''}
data-server-name="${escapeHtml(name)}"
@@ -459,33 +627,25 @@ function renderProjectAvailableServerCard(entry) {
<div class="mcp-server-details text-sm space-y-1">
<div class="flex items-center gap-2 text-muted-foreground">
<span class="font-mono text-xs bg-muted px-1.5 py-0.5 rounded">cmd</span>
<span class="font-mono text-xs bg-muted px-1.5 py-0.5 rounded">${t('mcp.cmd')}</span>
<span class="truncate" title="${escapeHtml(command)}">${escapeHtml(command)}</span>
</div>
${args.length > 0 ? `
<div class="flex items-start gap-2 text-muted-foreground">
<span class="font-mono text-xs bg-muted px-1.5 py-0.5 rounded shrink-0">args</span>
<span class="font-mono text-xs bg-muted px-1.5 py-0.5 rounded shrink-0">${t('mcp.args')}</span>
<span class="text-xs font-mono truncate" title="${escapeHtml(args.join(' '))}">${escapeHtml(args.slice(0, 3).join(' '))}${args.length > 3 ? '...' : ''}</span>
</div>
` : ''}
${hasEnv ? `
<div class="flex items-center gap-2 text-muted-foreground">
<span class="font-mono text-xs bg-muted px-1.5 py-0.5 rounded">env</span>
<span class="text-xs">${Object.keys(config.env).length} variables</span>
<span class="font-mono text-xs bg-muted px-1.5 py-0.5 rounded">${t('mcp.env')}</span>
<span class="text-xs">${Object.keys(config.env).length} ${t('mcp.variables')}</span>
</div>
` : ''}
</div>
<div class="mt-3 pt-3 border-t border-border flex items-center justify-between gap-2">
<div class="mt-3 pt-3 border-t border-border flex items-center justify-between gap-2" onclick="event.stopPropagation()">
<div class="flex items-center gap-2">
<button class="text-xs text-primary hover:text-primary/80 transition-colors flex items-center gap-1"
data-server-name="${escapeHtml(name)}"
data-server-config="${escapeHtml(JSON.stringify(config))}"
data-scope="${source === 'global' ? 'global' : 'project'}"
data-action="copy-install-cmd">
<i data-lucide="copy" class="w-3 h-3"></i>
${t('mcp.copyInstallCmd')}
</button>
<button class="text-xs text-success hover:text-success/80 transition-colors flex items-center gap-1"
data-server-name="${escapeHtml(name)}"
data-server-config="${escapeHtml(JSON.stringify(config))}"
@@ -525,19 +685,19 @@ function renderGlobalManagementCard(serverName, serverConfig) {
<div class="mcp-server-details text-sm space-y-1">
<div class="flex items-center gap-2 text-muted-foreground">
<span class="font-mono text-xs bg-muted px-1.5 py-0.5 rounded">${serverType === 'stdio' ? 'cmd' : 'url'}</span>
<span class="font-mono text-xs bg-muted px-1.5 py-0.5 rounded">${serverType === 'stdio' ? t('mcp.cmd') : t('mcp.url')}</span>
<span class="truncate" title="${escapeHtml(command)}">${escapeHtml(command)}</span>
</div>
${args.length > 0 ? `
<div class="flex items-start gap-2 text-muted-foreground">
<span class="font-mono text-xs bg-muted px-1.5 py-0.5 rounded shrink-0">args</span>
<span class="font-mono text-xs bg-muted px-1.5 py-0.5 rounded shrink-0">${t('mcp.args')}</span>
<span class="text-xs font-mono truncate" title="${escapeHtml(args.join(' '))}">${escapeHtml(args.slice(0, 3).join(' '))}${args.length > 3 ? '...' : ''}</span>
</div>
` : ''}
${hasEnv ? `
<div class="flex items-center gap-2 text-muted-foreground">
<span class="font-mono text-xs bg-muted px-1.5 py-0.5 rounded">env</span>
<span class="text-xs">${Object.keys(serverConfig.env).length} variables</span>
<span class="font-mono text-xs bg-muted px-1.5 py-0.5 rounded">${t('mcp.env')}</span>
<span class="text-xs">${Object.keys(serverConfig.env).length} ${t('mcp.variables')}</span>
</div>
` : ''}
<div class="flex items-center gap-2 text-muted-foreground mt-1">
@@ -545,15 +705,7 @@ function renderGlobalManagementCard(serverName, serverConfig) {
</div>
</div>
<div class="mt-3 pt-3 border-t border-border flex items-center justify-between">
<button class="text-xs text-primary hover:text-primary/80 transition-colors flex items-center gap-1"
data-server-name="${escapeHtml(serverName)}"
data-server-config="${escapeHtml(JSON.stringify(serverConfig))}"
data-scope="global"
data-action="copy-install-cmd">
<i data-lucide="copy" class="w-3 h-3"></i>
${t('mcp.copyInstallCmd')}
</button>
<div class="mt-3 pt-3 border-t border-border flex items-center justify-end">
<button class="text-xs text-destructive hover:text-destructive/80 transition-colors"
data-server-name="${escapeHtml(serverName)}"
data-action="remove-global">
@@ -617,35 +769,162 @@ function renderAvailableServerCard(serverName, serverInfo) {
<div class="mcp-server-details text-sm space-y-1">
<div class="flex items-center gap-2 text-muted-foreground">
<span class="font-mono text-xs bg-muted px-1.5 py-0.5 rounded">cmd</span>
<span class="font-mono text-xs bg-muted px-1.5 py-0.5 rounded">${t('mcp.cmd')}</span>
<span class="truncate" title="${escapeHtml(command)}">${escapeHtml(command)}</span>
</div>
${argsPreview ? `
<div class="flex items-start gap-2 text-muted-foreground">
<span class="font-mono text-xs bg-muted px-1.5 py-0.5 rounded shrink-0">args</span>
<span class="font-mono text-xs bg-muted px-1.5 py-0.5 rounded shrink-0">${t('mcp.args')}</span>
<span class="text-xs font-mono truncate" title="${escapeHtml(args.join(' '))}">${escapeHtml(argsPreview)}</span>
</div>
` : ''}
<div class="flex items-center gap-2 text-muted-foreground">
<span class="text-xs">Used in ${usedIn.length} project${usedIn.length !== 1 ? 's' : ''}</span>
${sourceProjectName ? `<span class="text-xs text-muted-foreground/70">• from ${escapeHtml(sourceProjectName)}</span>` : ''}
<span class="text-xs">${t('mcp.usedInCount').replace('{count}', usedIn.length).replace('{s}', usedIn.length !== 1 ? 's' : '')}</span>
${sourceProjectName ? `<span class="text-xs text-muted-foreground/70">• ${t('mcp.from')} ${escapeHtml(sourceProjectName)}</span>` : ''}
</div>
</div>
<div class="mt-3 pt-3 border-t border-border">
<div class="mt-3 pt-3 border-t border-border flex items-center gap-2">
<button class="text-xs text-primary hover:text-primary/80 transition-colors flex items-center gap-1"
data-server-name="${escapeHtml(originalName)}"
data-server-config="${escapeHtml(JSON.stringify(serverConfig))}"
data-scope="project"
data-action="copy-install-cmd">
<i data-lucide="copy" class="w-3 h-3"></i>
${t('mcp.copyInstallCmd')}
data-action="install-to-project"
title="${t('mcp.installToProject')}">
<i data-lucide="download" class="w-3 h-3"></i>
${t('mcp.installToProject')}
</button>
<button class="text-xs text-success hover:text-success/80 transition-colors flex items-center gap-1"
data-server-name="${escapeHtml(originalName)}"
data-server-config="${escapeHtml(JSON.stringify(serverConfig))}"
data-action="install-to-global"
title="${t('mcp.installToGlobal')}">
<i data-lucide="globe" class="w-3 h-3"></i>
${t('mcp.installToGlobal')}
</button>
</div>
</div>
`;
}
// ========================================
// Codex MCP Server Card Renderer
// ========================================
function renderCodexServerCard(serverName, serverConfig) {
const isStdio = !!serverConfig.command;
const isHttp = !!serverConfig.url;
const isEnabled = serverConfig.enabled !== false; // Default to enabled
const command = serverConfig.command || serverConfig.url || 'N/A';
const args = serverConfig.args || [];
const hasEnv = serverConfig.env && Object.keys(serverConfig.env).length > 0;
// Server type badge
const typeBadge = isHttp
? `<span class="text-xs px-2 py-0.5 bg-blue-100 text-blue-700 dark:bg-blue-900/30 dark:text-blue-300 rounded-full">HTTP</span>`
: `<span class="text-xs px-2 py-0.5 bg-green-100 text-green-700 dark:bg-green-900/30 dark:text-green-300 rounded-full">STDIO</span>`;
return `
<div class="mcp-server-card bg-card border border-orange-200 dark:border-orange-800 rounded-lg p-4 hover:shadow-md transition-all ${!isEnabled ? 'opacity-60' : ''}"
data-server-name="${escapeHtml(serverName)}"
data-server-config="${escapeHtml(JSON.stringify(serverConfig))}"
data-cli-type="codex">
<div class="flex items-start justify-between mb-3">
<div class="flex items-center gap-2">
<span>${isEnabled ? '<i data-lucide="check-circle" class="w-5 h-5 text-orange-500"></i>' : '<i data-lucide="circle" class="w-5 h-5 text-muted-foreground"></i>'}</span>
<h4 class="font-semibold text-foreground">${escapeHtml(serverName)}</h4>
${typeBadge}
</div>
<label class="mcp-toggle relative inline-flex items-center cursor-pointer" onclick="event.stopPropagation()">
<input type="checkbox" class="sr-only peer"
${isEnabled ? 'checked' : ''}
data-server-name="${escapeHtml(serverName)}"
data-action="toggle-codex">
<div class="w-9 h-5 bg-hover peer-focus:outline-none rounded-full peer peer-checked:after:translate-x-full peer-checked:after:border-white after:content-[''] after:absolute after:top-[2px] after:left-[2px] after:bg-white after:border-gray-300 after:border after:rounded-full after:h-4 after:w-4 after:transition-all peer-checked:bg-orange-500"></div>
</label>
</div>
<div class="mcp-server-details text-sm space-y-1">
<div class="flex items-center gap-2 text-muted-foreground">
<span class="font-mono text-xs bg-muted px-1.5 py-0.5 rounded">${isHttp ? t('mcp.url') : t('mcp.cmd')}</span>
<span class="truncate" title="${escapeHtml(command)}">${escapeHtml(command)}</span>
</div>
${args.length > 0 ? `
<div class="flex items-start gap-2 text-muted-foreground">
<span class="font-mono text-xs bg-muted px-1.5 py-0.5 rounded shrink-0">${t('mcp.args')}</span>
<span class="text-xs font-mono truncate" title="${escapeHtml(args.join(' '))}">${escapeHtml(args.slice(0, 3).join(' '))}${args.length > 3 ? '...' : ''}</span>
</div>
` : ''}
${hasEnv ? `
<div class="flex items-center gap-2 text-muted-foreground">
<span class="font-mono text-xs bg-muted px-1.5 py-0.5 rounded">${t('mcp.env')}</span>
<span class="text-xs">${Object.keys(serverConfig.env).length} ${t('mcp.variables')}</span>
</div>
` : ''}
${serverConfig.enabled_tools ? `
<div class="flex items-center gap-2 text-muted-foreground">
<span class="font-mono text-xs bg-muted px-1.5 py-0.5 rounded">${t('mcp.codex.enabledTools')}</span>
<span class="text-xs">${serverConfig.enabled_tools.length} ${t('mcp.codex.tools')}</span>
</div>
` : ''}
</div>
<div class="mt-3 pt-3 border-t border-border flex items-center justify-between gap-2" onclick="event.stopPropagation()">
<div class="flex items-center gap-2">
<button class="text-xs text-primary hover:text-primary/80 transition-colors flex items-center gap-1"
onclick="copyCodexServerToClaude('${escapeHtml(serverName)}', ${JSON.stringify(serverConfig).replace(/'/g, "&#39;")})"
title="${t('mcp.codex.copyToClaude')}">
<i data-lucide="copy" class="w-3 h-3"></i>
${t('mcp.codex.copyToClaude')}
</button>
</div>
<button class="text-xs text-destructive hover:text-destructive/80 transition-colors"
data-server-name="${escapeHtml(serverName)}"
data-action="remove-codex">
${t('mcp.codex.remove')}
</button>
</div>
</div>
`;
}
// ========================================
// Codex MCP Create Modal
// ========================================
function openCodexMcpCreateModal() {
// Reuse the existing modal with different settings
const modal = document.getElementById('mcpCreateModal');
if (modal) {
modal.classList.remove('hidden');
// Reset to form mode
mcpCreateMode = 'form';
switchMcpCreateTab('form');
// Clear form
document.getElementById('mcpServerName').value = '';
document.getElementById('mcpServerCommand').value = '';
document.getElementById('mcpServerArgs').value = '';
document.getElementById('mcpServerEnv').value = '';
// Clear JSON input
document.getElementById('mcpServerJson').value = '';
document.getElementById('mcpJsonPreview').classList.add('hidden');
// Set scope to codex
const scopeSelect = document.getElementById('mcpServerScope');
if (scopeSelect) {
// Add codex option if not exists
if (!scopeSelect.querySelector('option[value="codex"]')) {
const codexOption = document.createElement('option');
codexOption.value = 'codex';
codexOption.textContent = t('mcp.codex.scopeCodex');
scopeSelect.appendChild(codexOption);
}
scopeSelect.value = 'codex';
}
// Focus on name input
document.getElementById('mcpServerName').focus();
// Setup JSON input listener
setupMcpJsonListener();
}
}
function attachMcpEventListeners() {
// Toggle switches
@@ -692,13 +971,21 @@ function attachMcpEventListeners() {
});
});
// Copy install command buttons
document.querySelectorAll('.mcp-server-card button[data-action="copy-install-cmd"]').forEach(btn => {
// Install to project buttons
document.querySelectorAll('.mcp-server-card button[data-action="install-to-project"]').forEach(btn => {
btn.addEventListener('click', async (e) => {
const serverName = btn.dataset.serverName;
const serverConfig = JSON.parse(btn.dataset.serverConfig);
const scope = btn.dataset.scope || 'project';
await copyMcpInstallCommand(serverName, serverConfig, scope);
await installMcpToProject(serverName, serverConfig);
});
});
// Install to global buttons
document.querySelectorAll('.mcp-server-card button[data-action="install-to-global"]').forEach(btn => {
btn.addEventListener('click', async (e) => {
const serverName = btn.dataset.serverName;
const serverConfig = JSON.parse(btn.dataset.serverConfig);
await addGlobalMcpServer(serverName, serverConfig);
});
});
@@ -729,6 +1016,142 @@ function attachMcpEventListeners() {
}
});
});
// ========================================
// Codex MCP Event Listeners
// ========================================
// Toggle Codex MCP servers
document.querySelectorAll('.mcp-server-card input[data-action="toggle-codex"]').forEach(input => {
input.addEventListener('change', async (e) => {
const serverName = e.target.dataset.serverName;
const enable = e.target.checked;
await toggleCodexMcpServer(serverName, enable);
});
});
// Remove Codex MCP servers
document.querySelectorAll('.mcp-server-card button[data-action="remove-codex"]').forEach(btn => {
btn.addEventListener('click', async (e) => {
const serverName = btn.dataset.serverName;
if (confirm(t('mcp.codex.removeConfirm', { name: serverName }))) {
await removeCodexMcpServer(serverName);
}
});
});
// View details - click on server card
document.querySelectorAll('.mcp-server-card[data-action="view-details"]').forEach(card => {
card.addEventListener('click', (e) => {
const serverName = card.dataset.serverName;
const serverConfig = JSON.parse(card.dataset.serverConfig);
const serverSource = card.dataset.serverSource;
showMcpDetails(serverName, serverConfig, serverSource);
});
});
// Modal close button
const closeBtn = document.getElementById('mcpDetailsModalClose');
const modal = document.getElementById('mcpDetailsModal');
if (closeBtn && modal) {
closeBtn.addEventListener('click', () => {
modal.classList.add('hidden');
});
// Close on background click
modal.addEventListener('click', (e) => {
if (e.target === modal) {
modal.classList.add('hidden');
}
});
}
}
// ========================================
// MCP Details Modal
// ========================================
function showMcpDetails(serverName, serverConfig, serverSource) {
const modal = document.getElementById('mcpDetailsModal');
const modalBody = document.getElementById('mcpDetailsModalBody');
if (!modal || !modalBody) return;
// Build source badge
let sourceBadge = '';
if (serverSource === 'enterprise') {
sourceBadge = `<span class="inline-flex items-center px-2 py-1 text-xs font-semibold rounded-full bg-warning/20 text-warning">${t('mcp.sourceEnterprise')}</span>`;
} else if (serverSource === 'global') {
sourceBadge = `<span class="inline-flex items-center px-2 py-1 text-xs font-semibold rounded-full bg-success/10 text-success">${t('mcp.sourceGlobal')}</span>`;
} else if (serverSource === 'project') {
sourceBadge = `<span class="inline-flex items-center px-2 py-1 text-xs font-semibold rounded-full bg-primary/10 text-primary">${t('mcp.sourceProject')}</span>`;
}
// Build environment variables display
let envHtml = '';
if (serverConfig.env && Object.keys(serverConfig.env).length > 0) {
envHtml = '<div class="mt-4"><h4 class="font-semibold text-sm text-foreground mb-2">' + t('mcp.env') + '</h4><div class="bg-muted rounded-lg p-3 space-y-1 font-mono text-xs">';
for (const [key, value] of Object.entries(serverConfig.env)) {
envHtml += `<div class="flex items-start gap-2"><span class="text-muted-foreground shrink-0">${escapeHtml(key)}:</span><span class="text-foreground break-all">${escapeHtml(value)}</span></div>`;
}
envHtml += '</div></div>';
} else {
envHtml = '<div class="mt-4"><h4 class="font-semibold text-sm text-foreground mb-2">' + t('mcp.env') + '</h4><p class="text-sm text-muted-foreground">' + t('mcp.detailsModal.noEnv') + '</p></div>';
}
modalBody.innerHTML = `
<div class="space-y-4">
<!-- Server Name and Source -->
<div>
<label class="text-xs font-semibold text-muted-foreground uppercase tracking-wide">${t('mcp.detailsModal.serverName')}</label>
<div class="mt-1 flex items-center gap-2">
<h3 class="text-xl font-bold text-foreground">${escapeHtml(serverName)}</h3>
${sourceBadge}
</div>
</div>
<!-- Configuration -->
<div>
<h4 class="font-semibold text-sm text-foreground mb-2">${t('mcp.detailsModal.configuration')}</h4>
<div class="space-y-2">
<!-- Command -->
<div class="flex items-start gap-3">
<span class="font-mono text-xs bg-muted px-2 py-1 rounded shrink-0">${t('mcp.cmd')}</span>
<code class="text-sm font-mono text-foreground break-all">${escapeHtml(serverConfig.command || serverConfig.url || 'N/A')}</code>
</div>
<!-- Arguments -->
${serverConfig.args && serverConfig.args.length > 0 ? `
<div class="flex items-start gap-3">
<span class="font-mono text-xs bg-muted px-2 py-1 rounded shrink-0">${t('mcp.args')}</span>
<div class="flex-1 space-y-1">
${serverConfig.args.map((arg, index) => `
<div class="text-sm font-mono text-foreground flex items-center gap-2">
<span class="text-muted-foreground">[${index}]</span>
<code class="break-all">${escapeHtml(arg)}</code>
</div>
`).join('')}
</div>
</div>
` : ''}
</div>
</div>
<!-- Environment Variables -->
${envHtml}
<!-- Raw JSON -->
<div>
<h4 class="font-semibold text-sm text-foreground mb-2">Raw JSON</h4>
<pre class="bg-muted rounded-lg p-3 text-xs font-mono overflow-x-auto">${escapeHtml(JSON.stringify(serverConfig, null, 2))}</pre>
</div>
</div>
`;
// Show modal
modal.classList.remove('hidden');
// Re-initialize Lucide icons in modal
if (typeof lucide !== 'undefined') lucide.createIcons();
}
// ========================================
@@ -788,15 +1211,15 @@ async function saveMcpAsTemplate(serverName, serverConfig) {
const data = await response.json();
if (data.success) {
showNotification(t('mcp.templateSaved', { name: templateName }), 'success');
showRefreshToast(t('mcp.templateSaved', { name: templateName }), 'success');
await loadMcpTemplates();
await renderMcpManager(); // Refresh view
} else {
showNotification(t('mcp.templateSaveFailed', { error: data.error }), 'error');
showRefreshToast(t('mcp.templateSaveFailed', { error: data.error }), 'error');
}
} catch (error) {
console.error('[MCP] Save template error:', error);
showNotification(t('mcp.templateSaveFailed', { error: error.message }), 'error');
showRefreshToast(t('mcp.templateSaveFailed', { error: error.message }), 'error');
}
}
@@ -808,7 +1231,7 @@ async function installFromTemplate(templateName, scope = 'project') {
// Find template
const template = mcpTemplates.find(t => t.name === templateName);
if (!template) {
showNotification(t('mcp.templateNotFound', { name: templateName }), 'error');
showRefreshToast(t('mcp.templateNotFound', { name: templateName }), 'error');
return;
}
@@ -823,11 +1246,11 @@ async function installFromTemplate(templateName, scope = 'project') {
await addGlobalMcpServer(serverName, template.serverConfig);
}
showNotification(t('mcp.templateInstalled', { name: serverName }), 'success');
showRefreshToast(t('mcp.templateInstalled', { name: serverName }), 'success');
await renderMcpManager();
} catch (error) {
console.error('[MCP] Install from template error:', error);
showNotification(t('mcp.templateInstallFailed', { error: error.message }), 'error');
showRefreshToast(t('mcp.templateInstallFailed', { error: error.message }), 'error');
}
}
@@ -843,14 +1266,14 @@ async function deleteMcpTemplate(templateName) {
const data = await response.json();
if (data.success) {
showNotification(t('mcp.templateDeleted', { name: templateName }), 'success');
showRefreshToast(t('mcp.templateDeleted', { name: templateName }), 'success');
await loadMcpTemplates();
await renderMcpManager();
} else {
showNotification(t('mcp.templateDeleteFailed', { error: data.error }), 'error');
showRefreshToast(t('mcp.templateDeleteFailed', { error: data.error }), 'error');
}
} catch (error) {
console.error('[MCP] Delete template error:', error);
showNotification(t('mcp.templateDeleteFailed', { error: error.message }), 'error');
showRefreshToast(t('mcp.templateDeleteFailed', { error: error.message }), 'error');
}
}

View File

@@ -1,10 +1,11 @@
/**
* CLI Configuration Manager
* Handles loading, saving, and managing CLI tool configurations
* Stores config in .workflow/cli-config.json
* Stores config in centralized storage (~/.ccw/projects/{id}/config/)
*/
import * as fs from 'fs';
import * as path from 'path';
import { StoragePaths, ensureStorageDir } from '../config/storage-paths.js';
// ========== Types ==========
@@ -50,20 +51,15 @@ export const DEFAULT_CONFIG: CliConfig = {
}
};
const CONFIG_DIR = '.workflow';
const CONFIG_FILE = 'cli-config.json';
// ========== Helper Functions ==========
function getConfigPath(baseDir: string): string {
return path.join(baseDir, CONFIG_DIR, CONFIG_FILE);
return StoragePaths.project(baseDir).cliConfig;
}
function ensureConfigDir(baseDir: string): void {
const configDir = path.join(baseDir, CONFIG_DIR);
if (!fs.existsSync(configDir)) {
fs.mkdirSync(configDir, { recursive: true });
}
function ensureConfigDirForProject(baseDir: string): void {
const configDir = StoragePaths.project(baseDir).config;
ensureStorageDir(configDir);
}
function isValidToolName(tool: string): tool is CliToolName {
@@ -145,7 +141,7 @@ export function loadCliConfig(baseDir: string): CliConfig {
* Save CLI configuration to .workflow/cli-config.json
*/
export function saveCliConfig(baseDir: string, config: CliConfig): void {
ensureConfigDir(baseDir);
ensureConfigDirForProject(baseDir);
const configPath = getConfigPath(baseDir);
try {

View File

@@ -29,9 +29,6 @@ import {
getPrimaryModel
} from './cli-config-manager.js';
// CLI History storage path
const CLI_HISTORY_DIR = join(process.cwd(), '.workflow', '.cli-history');
// Lazy-loaded SQLite store module
let sqliteStoreModule: typeof import('./cli-history-store.js') | null = null;

View File

@@ -7,6 +7,7 @@ import Database from 'better-sqlite3';
import { existsSync, mkdirSync, readdirSync, readFileSync, statSync, unlinkSync, rmdirSync } from 'fs';
import { join } from 'path';
import { parseSessionFile, formatConversation, extractConversationPairs, type ParsedSession, type ParsedTurn } from './session-content-parser.js';
import { StoragePaths, ensureStorageDir } from '../config/storage-paths.js';
// Types
export interface ConversationTurn {
@@ -97,12 +98,12 @@ export class CliHistoryStore {
private dbPath: string;
constructor(baseDir: string) {
const historyDir = join(baseDir, '.workflow', '.cli-history');
if (!existsSync(historyDir)) {
mkdirSync(historyDir, { recursive: true });
}
// Use centralized storage path
const paths = StoragePaths.project(baseDir);
const historyDir = paths.cliHistory;
ensureStorageDir(historyDir);
this.dbPath = join(historyDir, 'history.db');
this.dbPath = paths.historyDb;
this.db = new Database(this.dbPath);
this.db.pragma('journal_mode = WAL');
this.db.pragma('synchronous = NORMAL');

View File

@@ -1,6 +1,7 @@
import { resolve, join, relative, isAbsolute } from 'path';
import { existsSync, mkdirSync, realpathSync, statSync, readFileSync, writeFileSync } from 'fs';
import { homedir } from 'os';
import { StoragePaths, ensureStorageDir, LegacyPaths } from '../config/storage-paths.js';
/**
* Validation result for path operations
@@ -212,10 +213,24 @@ export function normalizePathForDisplay(filePath: string): string {
return filePath.replace(/\\/g, '/');
}
// Recent paths storage file
const RECENT_PATHS_FILE = join(homedir(), '.ccw-recent-paths.json');
// Recent paths storage - uses centralized storage with backward compatibility
const MAX_RECENT_PATHS = 10;
/**
* Get the recent paths file location
* Uses new location but falls back to legacy location for backward compatibility
*/
function getRecentPathsFile(): string {
const newPath = StoragePaths.global.recentPaths();
const legacyPath = LegacyPaths.recentPaths();
// Backward compatibility: use legacy if it exists and new doesn't
if (!existsSync(newPath) && existsSync(legacyPath)) {
return legacyPath;
}
return newPath;
}
/**
* Recent paths data structure
*/
@@ -229,8 +244,9 @@ interface RecentPathsData {
*/
export function getRecentPaths(): string[] {
try {
if (existsSync(RECENT_PATHS_FILE)) {
const content = readFileSync(RECENT_PATHS_FILE, 'utf8');
const recentPathsFile = getRecentPathsFile();
if (existsSync(recentPathsFile)) {
const content = readFileSync(recentPathsFile, 'utf8');
const data = JSON.parse(content) as RecentPathsData;
return Array.isArray(data.paths) ? data.paths : [];
}
@@ -258,8 +274,10 @@ export function trackRecentPath(projectPath: string): void {
// Limit to max
paths = paths.slice(0, MAX_RECENT_PATHS);
// Save
writeFileSync(RECENT_PATHS_FILE, JSON.stringify({ paths }, null, 2), 'utf8');
// Save to new centralized location
const recentPathsFile = StoragePaths.global.recentPaths();
ensureStorageDir(StoragePaths.global.config());
writeFileSync(recentPathsFile, JSON.stringify({ paths }, null, 2), 'utf8');
} catch {
// Ignore errors
}
@@ -270,9 +288,9 @@ export function trackRecentPath(projectPath: string): void {
*/
export function clearRecentPaths(): void {
try {
if (existsSync(RECENT_PATHS_FILE)) {
writeFileSync(RECENT_PATHS_FILE, JSON.stringify({ paths: [] }, null, 2), 'utf8');
}
const recentPathsFile = StoragePaths.global.recentPaths();
ensureStorageDir(StoragePaths.global.config());
writeFileSync(recentPathsFile, JSON.stringify({ paths: [] }, null, 2), 'utf8');
} catch {
// Ignore errors
}
@@ -293,8 +311,10 @@ export function removeRecentPath(pathToRemove: string): boolean {
paths = paths.filter(p => normalizePathForDisplay(p) !== normalized);
if (paths.length < originalLength) {
// Save updated list
writeFileSync(RECENT_PATHS_FILE, JSON.stringify({ paths }, null, 2), 'utf8');
// Save updated list to new centralized location
const recentPathsFile = StoragePaths.global.recentPaths();
ensureStorageDir(StoragePaths.global.config());
writeFileSync(recentPathsFile, JSON.stringify({ paths }, null, 2), 'utf8');
return true;
}
return false;

View File

@@ -30,6 +30,11 @@ semantic = [
"fastembed>=0.2",
]
# Full features including tiktoken for accurate token counting
full = [
"tiktoken>=0.5.0",
]
[project.urls]
Homepage = "https://github.com/openai/codex-lens"

View File

@@ -1100,6 +1100,103 @@ def clean(
raise typer.Exit(code=1)
@app.command()
def graph(
query_type: str = typer.Argument(..., help="Query type: callers, callees, or inheritance"),
symbol: str = typer.Argument(..., help="Symbol name to query"),
path: Path = typer.Option(Path("."), "--path", "-p", help="Directory to search from."),
limit: int = typer.Option(50, "--limit", "-n", min=1, max=500, help="Max results."),
depth: int = typer.Option(-1, "--depth", "-d", help="Search depth (-1 = unlimited)."),
json_mode: bool = typer.Option(False, "--json", help="Output JSON response."),
verbose: bool = typer.Option(False, "--verbose", "-v", help="Enable debug logging."),
) -> None:
"""Query semantic graph for code relationships.
Supported query types:
- callers: Find all functions/methods that call the given symbol
- callees: Find all functions/methods called by the given symbol
- inheritance: Find inheritance relationships for the given class
Examples:
codex-lens graph callers my_function
codex-lens graph callees MyClass.method --path src/
codex-lens graph inheritance BaseClass
"""
_configure_logging(verbose)
search_path = path.expanduser().resolve()
# Validate query type
valid_types = ["callers", "callees", "inheritance"]
if query_type not in valid_types:
if json_mode:
print_json(success=False, error=f"Invalid query type: {query_type}. Must be one of: {', '.join(valid_types)}")
else:
console.print(f"[red]Invalid query type:[/red] {query_type}")
console.print(f"[dim]Valid types: {', '.join(valid_types)}[/dim]")
raise typer.Exit(code=1)
registry: RegistryStore | None = None
try:
registry = RegistryStore()
registry.initialize()
mapper = PathMapper()
engine = ChainSearchEngine(registry, mapper)
options = SearchOptions(depth=depth, total_limit=limit)
# Execute graph query based on type
if query_type == "callers":
results = engine.search_callers(symbol, search_path, options=options)
result_type = "callers"
elif query_type == "callees":
results = engine.search_callees(symbol, search_path, options=options)
result_type = "callees"
else: # inheritance
results = engine.search_inheritance(symbol, search_path, options=options)
result_type = "inheritance"
payload = {
"query_type": query_type,
"symbol": symbol,
"count": len(results),
"relationships": results
}
if json_mode:
print_json(success=True, result=payload)
else:
from .output import render_graph_results
render_graph_results(results, query_type=query_type, symbol=symbol)
except SearchError as exc:
if json_mode:
print_json(success=False, error=f"Graph search error: {exc}")
else:
console.print(f"[red]Graph query failed (search):[/red] {exc}")
raise typer.Exit(code=1)
except StorageError as exc:
if json_mode:
print_json(success=False, error=f"Storage error: {exc}")
else:
console.print(f"[red]Graph query failed (storage):[/red] {exc}")
raise typer.Exit(code=1)
except CodexLensError as exc:
if json_mode:
print_json(success=False, error=str(exc))
else:
console.print(f"[red]Graph query failed:[/red] {exc}")
raise typer.Exit(code=1)
except Exception as exc:
if json_mode:
print_json(success=False, error=f"Unexpected error: {exc}")
else:
console.print(f"[red]Graph query failed (unexpected):[/red] {exc}")
raise typer.Exit(code=1)
finally:
if registry is not None:
registry.close()
@app.command("semantic-list")
def semantic_list(
path: Path = typer.Option(Path("."), "--path", "-p", help="Project path to list metadata from."),

View File

@@ -89,3 +89,68 @@ def render_file_inspect(path: str, language: str, symbols: Iterable[Symbol]) ->
console.print(header)
render_symbols(list(symbols), title="Discovered Symbols")
def render_graph_results(results: list[dict[str, Any]], *, query_type: str, symbol: str) -> None:
"""Render semantic graph query results.
Args:
results: List of relationship dicts
query_type: Type of query (callers, callees, inheritance)
symbol: Symbol name that was queried
"""
if not results:
console.print(f"[yellow]No {query_type} found for symbol:[/yellow] {symbol}")
return
title_map = {
"callers": f"Callers of '{symbol}' ({len(results)} found)",
"callees": f"Callees of '{symbol}' ({len(results)} found)",
"inheritance": f"Inheritance relationships for '{symbol}' ({len(results)} found)"
}
table = Table(title=title_map.get(query_type, f"Graph Results ({len(results)})"))
if query_type == "callers":
table.add_column("Caller", style="green")
table.add_column("File", style="cyan", no_wrap=False, max_width=40)
table.add_column("Line", justify="right", style="yellow")
table.add_column("Type", style="dim")
for rel in results:
table.add_row(
rel.get("source_symbol", "-"),
rel.get("source_file", "-"),
str(rel.get("source_line", "-")),
rel.get("relationship_type", "-")
)
elif query_type == "callees":
table.add_column("Target", style="green")
table.add_column("File", style="cyan", no_wrap=False, max_width=40)
table.add_column("Line", justify="right", style="yellow")
table.add_column("Type", style="dim")
for rel in results:
table.add_row(
rel.get("target_symbol", "-"),
rel.get("target_file", "-") if rel.get("target_file") else rel.get("source_file", "-"),
str(rel.get("source_line", "-")),
rel.get("relationship_type", "-")
)
else: # inheritance
table.add_column("Derived Class", style="green")
table.add_column("Base Class", style="magenta")
table.add_column("File", style="cyan", no_wrap=False, max_width=40)
table.add_column("Line", justify="right", style="yellow")
for rel in results:
table.add_row(
rel.get("source_symbol", "-"),
rel.get("target_symbol", "-"),
rel.get("source_file", "-"),
str(rel.get("source_line", "-"))
)
console.print(table)

View File

@@ -83,6 +83,9 @@ class Config:
llm_timeout_ms: int = 300000
llm_batch_size: int = 5
# Hybrid chunker configuration
hybrid_max_chunk_size: int = 2000 # Max characters per chunk before LLM refinement
hybrid_llm_refinement: bool = False # Enable LLM-based semantic boundary refinement
def __post_init__(self) -> None:
try:
self.data_dir = self.data_dir.expanduser().resolve()

View File

@@ -13,6 +13,8 @@ class Symbol(BaseModel):
name: str = Field(..., min_length=1)
kind: str = Field(..., min_length=1)
range: Tuple[int, int] = Field(..., description="(start_line, end_line), 1-based inclusive")
token_count: Optional[int] = Field(default=None, description="Token count for symbol content")
symbol_type: Optional[str] = Field(default=None, description="Extended symbol type for filtering")
@field_validator("range")
@classmethod
@@ -26,6 +28,13 @@ class Symbol(BaseModel):
raise ValueError("end_line must be >= start_line")
return value
@field_validator("token_count")
@classmethod
def validate_token_count(cls, value: Optional[int]) -> Optional[int]:
if value is not None and value < 0:
raise ValueError("token_count must be >= 0")
return value
class SemanticChunk(BaseModel):
"""A semantically meaningful chunk of content, optionally embedded."""
@@ -61,6 +70,25 @@ class IndexedFile(BaseModel):
return cleaned
class CodeRelationship(BaseModel):
"""A relationship between code symbols (e.g., function calls, inheritance)."""
source_symbol: str = Field(..., min_length=1, description="Name of source symbol")
target_symbol: str = Field(..., min_length=1, description="Name of target symbol")
relationship_type: str = Field(..., min_length=1, description="Type of relationship (call, inherits, etc.)")
source_file: str = Field(..., min_length=1, description="File path containing source symbol")
target_file: Optional[str] = Field(default=None, description="File path containing target (None if same file)")
source_line: int = Field(..., ge=1, description="Line number where relationship occurs (1-based)")
@field_validator("relationship_type")
@classmethod
def validate_relationship_type(cls, value: str) -> str:
allowed_types = {"call", "inherits", "imports"}
if value not in allowed_types:
raise ValueError(f"relationship_type must be one of {allowed_types}")
return value
class SearchResult(BaseModel):
"""A unified search result for lexical or semantic search."""

View File

@@ -10,19 +10,11 @@ from __future__ import annotations
import re
from dataclasses import dataclass
from pathlib import Path
from typing import Dict, Iterable, List, Optional, Protocol
try:
from tree_sitter import Language as TreeSitterLanguage
from tree_sitter import Node as TreeSitterNode
from tree_sitter import Parser as TreeSitterParser
except Exception: # pragma: no cover
TreeSitterLanguage = None # type: ignore[assignment]
TreeSitterNode = None # type: ignore[assignment]
TreeSitterParser = None # type: ignore[assignment]
from typing import Dict, List, Optional, Protocol
from codexlens.config import Config
from codexlens.entities import IndexedFile, Symbol
from codexlens.parsers.treesitter_parser import TreeSitterSymbolParser
class Parser(Protocol):
@@ -34,10 +26,24 @@ class SimpleRegexParser:
language_id: str
def parse(self, text: str, path: Path) -> IndexedFile:
# Try tree-sitter first for supported languages
if self.language_id in {"python", "javascript", "typescript"}:
ts_parser = TreeSitterSymbolParser(self.language_id, path)
if ts_parser.is_available():
symbols = ts_parser.parse_symbols(text)
if symbols is not None:
return IndexedFile(
path=str(path.resolve()),
language=self.language_id,
symbols=symbols,
chunks=[],
)
# Fallback to regex parsing
if self.language_id == "python":
symbols = _parse_python_symbols(text)
symbols = _parse_python_symbols_regex(text)
elif self.language_id in {"javascript", "typescript"}:
symbols = _parse_js_ts_symbols(text, self.language_id, path)
symbols = _parse_js_ts_symbols_regex(text)
elif self.language_id == "java":
symbols = _parse_java_symbols(text)
elif self.language_id == "go":
@@ -64,120 +70,35 @@ class ParserFactory:
return self._parsers[language_id]
# Regex-based fallback parsers
_PY_CLASS_RE = re.compile(r"^\s*class\s+([A-Za-z_]\w*)\b")
_PY_DEF_RE = re.compile(r"^\s*(?:async\s+)?def\s+([A-Za-z_]\w*)\s*\(")
_TREE_SITTER_LANGUAGE_CACHE: Dict[str, TreeSitterLanguage] = {}
def _get_tree_sitter_language(language_id: str, path: Path | None = None) -> TreeSitterLanguage | None:
if TreeSitterLanguage is None:
return None
cache_key = language_id
if language_id == "typescript" and path is not None and path.suffix.lower() == ".tsx":
cache_key = "tsx"
cached = _TREE_SITTER_LANGUAGE_CACHE.get(cache_key)
if cached is not None:
return cached
try:
if cache_key == "python":
import tree_sitter_python # type: ignore[import-not-found]
language = TreeSitterLanguage(tree_sitter_python.language())
elif cache_key == "javascript":
import tree_sitter_javascript # type: ignore[import-not-found]
language = TreeSitterLanguage(tree_sitter_javascript.language())
elif cache_key == "typescript":
import tree_sitter_typescript # type: ignore[import-not-found]
language = TreeSitterLanguage(tree_sitter_typescript.language_typescript())
elif cache_key == "tsx":
import tree_sitter_typescript # type: ignore[import-not-found]
language = TreeSitterLanguage(tree_sitter_typescript.language_tsx())
else:
return None
except Exception:
return None
_TREE_SITTER_LANGUAGE_CACHE[cache_key] = language
return language
def _parse_python_symbols(text: str) -> List[Symbol]:
"""Parse Python symbols, using tree-sitter if available, regex fallback."""
ts_parser = TreeSitterSymbolParser("python")
if ts_parser.is_available():
symbols = ts_parser.parse_symbols(text)
if symbols is not None:
return symbols
return _parse_python_symbols_regex(text)
def _iter_tree_sitter_nodes(root: TreeSitterNode) -> Iterable[TreeSitterNode]:
stack: List[TreeSitterNode] = [root]
while stack:
node = stack.pop()
yield node
for child in reversed(node.children):
stack.append(child)
def _node_text(source_bytes: bytes, node: TreeSitterNode) -> str:
return source_bytes[node.start_byte:node.end_byte].decode("utf8")
def _node_range(node: TreeSitterNode) -> tuple[int, int]:
start_line = node.start_point[0] + 1
end_line = node.end_point[0] + 1
return (start_line, max(start_line, end_line))
def _python_kind_for_function_node(node: TreeSitterNode) -> str:
parent = node.parent
while parent is not None:
if parent.type in {"function_definition", "async_function_definition"}:
return "function"
if parent.type == "class_definition":
return "method"
parent = parent.parent
return "function"
def _parse_python_symbols_tree_sitter(text: str) -> List[Symbol] | None:
if TreeSitterParser is None:
return None
language = _get_tree_sitter_language("python")
if language is None:
return None
parser = TreeSitterParser()
if hasattr(parser, "set_language"):
parser.set_language(language) # type: ignore[attr-defined]
else:
parser.language = language # type: ignore[assignment]
source_bytes = text.encode("utf8")
tree = parser.parse(source_bytes)
root = tree.root_node
symbols: List[Symbol] = []
for node in _iter_tree_sitter_nodes(root):
if node.type == "class_definition":
name_node = node.child_by_field_name("name")
if name_node is None:
continue
symbols.append(Symbol(
name=_node_text(source_bytes, name_node),
kind="class",
range=_node_range(node),
))
elif node.type in {"function_definition", "async_function_definition"}:
name_node = node.child_by_field_name("name")
if name_node is None:
continue
symbols.append(Symbol(
name=_node_text(source_bytes, name_node),
kind=_python_kind_for_function_node(node),
range=_node_range(node),
))
return symbols
def _parse_js_ts_symbols(
text: str,
language_id: str = "javascript",
path: Optional[Path] = None,
) -> List[Symbol]:
"""Parse JS/TS symbols, using tree-sitter if available, regex fallback."""
ts_parser = TreeSitterSymbolParser(language_id, path)
if ts_parser.is_available():
symbols = ts_parser.parse_symbols(text)
if symbols is not None:
return symbols
return _parse_js_ts_symbols_regex(text)
def _parse_python_symbols_regex(text: str) -> List[Symbol]:
@@ -202,13 +123,6 @@ def _parse_python_symbols_regex(text: str) -> List[Symbol]:
return symbols
def _parse_python_symbols(text: str) -> List[Symbol]:
symbols = _parse_python_symbols_tree_sitter(text)
if symbols is not None:
return symbols
return _parse_python_symbols_regex(text)
_JS_FUNC_RE = re.compile(r"^\s*(?:export\s+)?(?:async\s+)?function\s+([A-Za-z_$][\w$]*)\s*\(")
_JS_CLASS_RE = re.compile(r"^\s*(?:export\s+)?class\s+([A-Za-z_$][\w$]*)\b")
_JS_ARROW_RE = re.compile(
@@ -217,88 +131,6 @@ _JS_ARROW_RE = re.compile(
_JS_METHOD_RE = re.compile(r"^\s+(?:async\s+)?([A-Za-z_$][\w$]*)\s*\([^)]*\)\s*\{")
def _js_has_class_ancestor(node: TreeSitterNode) -> bool:
parent = node.parent
while parent is not None:
if parent.type in {"class_declaration", "class"}:
return True
parent = parent.parent
return False
def _parse_js_ts_symbols_tree_sitter(
text: str,
language_id: str,
path: Path | None = None,
) -> List[Symbol] | None:
if TreeSitterParser is None:
return None
language = _get_tree_sitter_language(language_id, path)
if language is None:
return None
parser = TreeSitterParser()
if hasattr(parser, "set_language"):
parser.set_language(language) # type: ignore[attr-defined]
else:
parser.language = language # type: ignore[assignment]
source_bytes = text.encode("utf8")
tree = parser.parse(source_bytes)
root = tree.root_node
symbols: List[Symbol] = []
for node in _iter_tree_sitter_nodes(root):
if node.type in {"class_declaration", "class"}:
name_node = node.child_by_field_name("name")
if name_node is None:
continue
symbols.append(Symbol(
name=_node_text(source_bytes, name_node),
kind="class",
range=_node_range(node),
))
elif node.type in {"function_declaration", "generator_function_declaration"}:
name_node = node.child_by_field_name("name")
if name_node is None:
continue
symbols.append(Symbol(
name=_node_text(source_bytes, name_node),
kind="function",
range=_node_range(node),
))
elif node.type == "variable_declarator":
name_node = node.child_by_field_name("name")
value_node = node.child_by_field_name("value")
if (
name_node is None
or value_node is None
or name_node.type not in {"identifier", "property_identifier"}
or value_node.type != "arrow_function"
):
continue
symbols.append(Symbol(
name=_node_text(source_bytes, name_node),
kind="function",
range=_node_range(node),
))
elif node.type == "method_definition" and _js_has_class_ancestor(node):
name_node = node.child_by_field_name("name")
if name_node is None:
continue
name = _node_text(source_bytes, name_node)
if name == "constructor":
continue
symbols.append(Symbol(
name=name,
kind="method",
range=_node_range(node),
))
return symbols
def _parse_js_ts_symbols_regex(text: str) -> List[Symbol]:
symbols: List[Symbol] = []
in_class = False
@@ -338,17 +170,6 @@ def _parse_js_ts_symbols_regex(text: str) -> List[Symbol]:
return symbols
def _parse_js_ts_symbols(
text: str,
language_id: str = "javascript",
path: Path | None = None,
) -> List[Symbol]:
symbols = _parse_js_ts_symbols_tree_sitter(text, language_id, path)
if symbols is not None:
return symbols
return _parse_js_ts_symbols_regex(text)
_JAVA_CLASS_RE = re.compile(r"^\s*(?:public\s+)?class\s+([A-Za-z_]\w*)\b")
_JAVA_METHOD_RE = re.compile(
r"^\s*(?:public|private|protected|static|\s)+[\w<>\[\]]+\s+([A-Za-z_]\w*)\s*\("

View File

@@ -0,0 +1,98 @@
"""Token counting utilities for CodexLens.
Provides accurate token counting using tiktoken with character count fallback.
"""
from __future__ import annotations
from typing import Optional
try:
import tiktoken
TIKTOKEN_AVAILABLE = True
except ImportError:
TIKTOKEN_AVAILABLE = False
class Tokenizer:
"""Token counter with tiktoken primary and character count fallback."""
def __init__(self, encoding_name: str = "cl100k_base") -> None:
"""Initialize tokenizer.
Args:
encoding_name: Tiktoken encoding name (default: cl100k_base for GPT-4)
"""
self._encoding: Optional[object] = None
self._encoding_name = encoding_name
if TIKTOKEN_AVAILABLE:
try:
self._encoding = tiktoken.get_encoding(encoding_name)
except Exception:
# Fallback to character counting if encoding fails
self._encoding = None
def count_tokens(self, text: str) -> int:
"""Count tokens in text.
Uses tiktoken if available, otherwise falls back to character count / 4.
Args:
text: Text to count tokens for
Returns:
Estimated token count
"""
if not text:
return 0
if self._encoding is not None:
try:
return len(self._encoding.encode(text)) # type: ignore[attr-defined]
except Exception:
# Fall through to character count fallback
pass
# Fallback: rough estimate using character count
# Average of ~4 characters per token for English text
return max(1, len(text) // 4)
def is_using_tiktoken(self) -> bool:
"""Check if tiktoken is being used.
Returns:
True if tiktoken is available and initialized
"""
return self._encoding is not None
# Global default tokenizer instance
_default_tokenizer: Optional[Tokenizer] = None
def get_default_tokenizer() -> Tokenizer:
"""Get the global default tokenizer instance.
Returns:
Shared Tokenizer instance
"""
global _default_tokenizer
if _default_tokenizer is None:
_default_tokenizer = Tokenizer()
return _default_tokenizer
def count_tokens(text: str, tokenizer: Optional[Tokenizer] = None) -> int:
"""Count tokens in text using default or provided tokenizer.
Args:
text: Text to count tokens for
tokenizer: Optional tokenizer instance (uses default if None)
Returns:
Estimated token count
"""
if tokenizer is None:
tokenizer = get_default_tokenizer()
return tokenizer.count_tokens(text)

View File

@@ -0,0 +1,335 @@
"""Tree-sitter based parser for CodexLens.
Provides precise AST-level parsing with fallback to regex-based parsing.
"""
from __future__ import annotations
from pathlib import Path
from typing import List, Optional
try:
from tree_sitter import Language as TreeSitterLanguage
from tree_sitter import Node as TreeSitterNode
from tree_sitter import Parser as TreeSitterParser
TREE_SITTER_AVAILABLE = True
except ImportError:
TreeSitterLanguage = None # type: ignore[assignment]
TreeSitterNode = None # type: ignore[assignment]
TreeSitterParser = None # type: ignore[assignment]
TREE_SITTER_AVAILABLE = False
from codexlens.entities import IndexedFile, Symbol
from codexlens.parsers.tokenizer import get_default_tokenizer
class TreeSitterSymbolParser:
"""Parser using tree-sitter for AST-level symbol extraction."""
def __init__(self, language_id: str, path: Optional[Path] = None) -> None:
"""Initialize tree-sitter parser for a language.
Args:
language_id: Language identifier (python, javascript, typescript, etc.)
path: Optional file path for language variant detection (e.g., .tsx)
"""
self.language_id = language_id
self.path = path
self._parser: Optional[object] = None
self._language: Optional[TreeSitterLanguage] = None
self._tokenizer = get_default_tokenizer()
if TREE_SITTER_AVAILABLE:
self._initialize_parser()
def _initialize_parser(self) -> None:
"""Initialize tree-sitter parser and language."""
if TreeSitterParser is None or TreeSitterLanguage is None:
return
try:
# Load language grammar
if self.language_id == "python":
import tree_sitter_python
self._language = TreeSitterLanguage(tree_sitter_python.language())
elif self.language_id == "javascript":
import tree_sitter_javascript
self._language = TreeSitterLanguage(tree_sitter_javascript.language())
elif self.language_id == "typescript":
import tree_sitter_typescript
# Detect TSX files by extension
if self.path is not None and self.path.suffix.lower() == ".tsx":
self._language = TreeSitterLanguage(tree_sitter_typescript.language_tsx())
else:
self._language = TreeSitterLanguage(tree_sitter_typescript.language_typescript())
else:
return
# Create parser
self._parser = TreeSitterParser()
if hasattr(self._parser, "set_language"):
self._parser.set_language(self._language) # type: ignore[attr-defined]
else:
self._parser.language = self._language # type: ignore[assignment]
except Exception:
# Gracefully handle missing language bindings
self._parser = None
self._language = None
def is_available(self) -> bool:
"""Check if tree-sitter parser is available.
Returns:
True if parser is initialized and ready
"""
return self._parser is not None and self._language is not None
def parse_symbols(self, text: str) -> Optional[List[Symbol]]:
"""Parse source code and extract symbols without creating IndexedFile.
Args:
text: Source code text
Returns:
List of symbols if parsing succeeds, None if tree-sitter unavailable
"""
if not self.is_available() or self._parser is None:
return None
try:
source_bytes = text.encode("utf8")
tree = self._parser.parse(source_bytes) # type: ignore[attr-defined]
root = tree.root_node
return self._extract_symbols(source_bytes, root)
except Exception:
# Gracefully handle parsing errors
return None
def parse(self, text: str, path: Path) -> Optional[IndexedFile]:
"""Parse source code and extract symbols.
Args:
text: Source code text
path: File path
Returns:
IndexedFile if parsing succeeds, None if tree-sitter unavailable
"""
if not self.is_available() or self._parser is None:
return None
try:
symbols = self.parse_symbols(text)
if symbols is None:
return None
return IndexedFile(
path=str(path.resolve()),
language=self.language_id,
symbols=symbols,
chunks=[],
)
except Exception:
# Gracefully handle parsing errors
return None
def _extract_symbols(self, source_bytes: bytes, root: TreeSitterNode) -> List[Symbol]:
"""Extract symbols from AST.
Args:
source_bytes: Source code as bytes
root: Root AST node
Returns:
List of extracted symbols
"""
if self.language_id == "python":
return self._extract_python_symbols(source_bytes, root)
elif self.language_id in {"javascript", "typescript"}:
return self._extract_js_ts_symbols(source_bytes, root)
else:
return []
def _extract_python_symbols(self, source_bytes: bytes, root: TreeSitterNode) -> List[Symbol]:
"""Extract Python symbols from AST.
Args:
source_bytes: Source code as bytes
root: Root AST node
Returns:
List of Python symbols (classes, functions, methods)
"""
symbols: List[Symbol] = []
for node in self._iter_nodes(root):
if node.type == "class_definition":
name_node = node.child_by_field_name("name")
if name_node is None:
continue
symbols.append(Symbol(
name=self._node_text(source_bytes, name_node),
kind="class",
range=self._node_range(node),
))
elif node.type in {"function_definition", "async_function_definition"}:
name_node = node.child_by_field_name("name")
if name_node is None:
continue
symbols.append(Symbol(
name=self._node_text(source_bytes, name_node),
kind=self._python_function_kind(node),
range=self._node_range(node),
))
return symbols
def _extract_js_ts_symbols(self, source_bytes: bytes, root: TreeSitterNode) -> List[Symbol]:
"""Extract JavaScript/TypeScript symbols from AST.
Args:
source_bytes: Source code as bytes
root: Root AST node
Returns:
List of JS/TS symbols (classes, functions, methods)
"""
symbols: List[Symbol] = []
for node in self._iter_nodes(root):
if node.type in {"class_declaration", "class"}:
name_node = node.child_by_field_name("name")
if name_node is None:
continue
symbols.append(Symbol(
name=self._node_text(source_bytes, name_node),
kind="class",
range=self._node_range(node),
))
elif node.type in {"function_declaration", "generator_function_declaration"}:
name_node = node.child_by_field_name("name")
if name_node is None:
continue
symbols.append(Symbol(
name=self._node_text(source_bytes, name_node),
kind="function",
range=self._node_range(node),
))
elif node.type == "variable_declarator":
name_node = node.child_by_field_name("name")
value_node = node.child_by_field_name("value")
if (
name_node is None
or value_node is None
or name_node.type not in {"identifier", "property_identifier"}
or value_node.type != "arrow_function"
):
continue
symbols.append(Symbol(
name=self._node_text(source_bytes, name_node),
kind="function",
range=self._node_range(node),
))
elif node.type == "method_definition" and self._has_class_ancestor(node):
name_node = node.child_by_field_name("name")
if name_node is None:
continue
name = self._node_text(source_bytes, name_node)
if name == "constructor":
continue
symbols.append(Symbol(
name=name,
kind="method",
range=self._node_range(node),
))
return symbols
def _python_function_kind(self, node: TreeSitterNode) -> str:
"""Determine if Python function is a method or standalone function.
Args:
node: Function definition node
Returns:
'method' if inside a class, 'function' otherwise
"""
parent = node.parent
while parent is not None:
if parent.type in {"function_definition", "async_function_definition"}:
return "function"
if parent.type == "class_definition":
return "method"
parent = parent.parent
return "function"
def _has_class_ancestor(self, node: TreeSitterNode) -> bool:
"""Check if node has a class ancestor.
Args:
node: AST node to check
Returns:
True if node is inside a class
"""
parent = node.parent
while parent is not None:
if parent.type in {"class_declaration", "class"}:
return True
parent = parent.parent
return False
def _iter_nodes(self, root: TreeSitterNode):
"""Iterate over all nodes in AST.
Args:
root: Root node to start iteration
Yields:
AST nodes in depth-first order
"""
stack = [root]
while stack:
node = stack.pop()
yield node
for child in reversed(node.children):
stack.append(child)
def _node_text(self, source_bytes: bytes, node: TreeSitterNode) -> str:
"""Extract text for a node.
Args:
source_bytes: Source code as bytes
node: AST node
Returns:
Text content of node
"""
return source_bytes[node.start_byte:node.end_byte].decode("utf8")
def _node_range(self, node: TreeSitterNode) -> tuple[int, int]:
"""Get line range for a node.
Args:
node: AST node
Returns:
(start_line, end_line) tuple, 1-based inclusive
"""
start_line = node.start_point[0] + 1
end_line = node.end_point[0] + 1
return (start_line, max(start_line, end_line))
def count_tokens(self, text: str) -> int:
"""Count tokens in text.
Args:
text: Text to count tokens for
Returns:
Token count
"""
return self._tokenizer.count_tokens(text)

View File

@@ -17,6 +17,7 @@ from codexlens.entities import SearchResult, Symbol
from codexlens.storage.registry import RegistryStore, DirMapping
from codexlens.storage.dir_index import DirIndexStore, SubdirLink
from codexlens.storage.path_mapper import PathMapper
from codexlens.storage.sqlite_store import SQLiteStore
@dataclass
@@ -278,6 +279,108 @@ class ChainSearchEngine:
index_paths, name, kind, options.total_limit
)
def search_callers(self, target_symbol: str,
source_path: Path,
options: Optional[SearchOptions] = None) -> List[Dict[str, Any]]:
"""Find all callers of a given symbol across directory hierarchy.
Args:
target_symbol: Name of the symbol to find callers for
source_path: Starting directory path
options: Search configuration (uses defaults if None)
Returns:
List of relationship dicts with caller information
Examples:
>>> engine = ChainSearchEngine(registry, mapper)
>>> callers = engine.search_callers("my_function", Path("D:/project"))
>>> for caller in callers:
... print(f"{caller['source_symbol']} in {caller['source_file']}:{caller['source_line']}")
"""
options = options or SearchOptions()
start_index = self._find_start_index(source_path)
if not start_index:
self.logger.warning(f"No index found for {source_path}")
return []
index_paths = self._collect_index_paths(start_index, options.depth)
if not index_paths:
return []
return self._search_callers_parallel(
index_paths, target_symbol, options.total_limit
)
def search_callees(self, source_symbol: str,
source_path: Path,
options: Optional[SearchOptions] = None) -> List[Dict[str, Any]]:
"""Find all callees (what a symbol calls) across directory hierarchy.
Args:
source_symbol: Name of the symbol to find callees for
source_path: Starting directory path
options: Search configuration (uses defaults if None)
Returns:
List of relationship dicts with callee information
Examples:
>>> engine = ChainSearchEngine(registry, mapper)
>>> callees = engine.search_callees("MyClass.method", Path("D:/project"))
>>> for callee in callees:
... print(f"Calls {callee['target_symbol']} at line {callee['source_line']}")
"""
options = options or SearchOptions()
start_index = self._find_start_index(source_path)
if not start_index:
self.logger.warning(f"No index found for {source_path}")
return []
index_paths = self._collect_index_paths(start_index, options.depth)
if not index_paths:
return []
return self._search_callees_parallel(
index_paths, source_symbol, options.total_limit
)
def search_inheritance(self, class_name: str,
source_path: Path,
options: Optional[SearchOptions] = None) -> List[Dict[str, Any]]:
"""Find inheritance relationships for a class across directory hierarchy.
Args:
class_name: Name of the class to find inheritance for
source_path: Starting directory path
options: Search configuration (uses defaults if None)
Returns:
List of relationship dicts with inheritance information
Examples:
>>> engine = ChainSearchEngine(registry, mapper)
>>> inheritance = engine.search_inheritance("BaseClass", Path("D:/project"))
>>> for rel in inheritance:
... print(f"{rel['source_symbol']} extends {rel['target_symbol']}")
"""
options = options or SearchOptions()
start_index = self._find_start_index(source_path)
if not start_index:
self.logger.warning(f"No index found for {source_path}")
return []
index_paths = self._collect_index_paths(start_index, options.depth)
if not index_paths:
return []
return self._search_inheritance_parallel(
index_paths, class_name, options.total_limit
)
# === Internal Methods ===
def _find_start_index(self, source_path: Path) -> Optional[Path]:
@@ -553,6 +656,252 @@ class ChainSearchEngine:
self.logger.debug(f"Symbol search error in {index_path}: {exc}")
return []
def _search_callers_parallel(self, index_paths: List[Path],
target_symbol: str,
limit: int) -> List[Dict[str, Any]]:
"""Search for callers across multiple indexes in parallel.
Args:
index_paths: List of _index.db paths to search
target_symbol: Target symbol name
limit: Total result limit
Returns:
Deduplicated list of caller relationships
"""
all_callers = []
executor = self._get_executor()
future_to_path = {
executor.submit(
self._search_callers_single,
idx_path,
target_symbol
): idx_path
for idx_path in index_paths
}
for future in as_completed(future_to_path):
try:
callers = future.result()
all_callers.extend(callers)
except Exception as exc:
self.logger.error(f"Caller search failed: {exc}")
# Deduplicate by (source_file, source_line)
seen = set()
unique_callers = []
for caller in all_callers:
key = (caller.get("source_file"), caller.get("source_line"))
if key not in seen:
seen.add(key)
unique_callers.append(caller)
# Sort by source file and line
unique_callers.sort(key=lambda c: (c.get("source_file", ""), c.get("source_line", 0)))
return unique_callers[:limit]
def _search_callers_single(self, index_path: Path,
target_symbol: str) -> List[Dict[str, Any]]:
"""Search for callers in a single index.
Args:
index_path: Path to _index.db file
target_symbol: Target symbol name
Returns:
List of caller relationship dicts (empty on error)
"""
try:
with SQLiteStore(index_path) as store:
return store.query_relationships_by_target(target_symbol)
except Exception as exc:
self.logger.debug(f"Caller search error in {index_path}: {exc}")
return []
def _search_callees_parallel(self, index_paths: List[Path],
source_symbol: str,
limit: int) -> List[Dict[str, Any]]:
"""Search for callees across multiple indexes in parallel.
Args:
index_paths: List of _index.db paths to search
source_symbol: Source symbol name
limit: Total result limit
Returns:
Deduplicated list of callee relationships
"""
all_callees = []
executor = self._get_executor()
future_to_path = {
executor.submit(
self._search_callees_single,
idx_path,
source_symbol
): idx_path
for idx_path in index_paths
}
for future in as_completed(future_to_path):
try:
callees = future.result()
all_callees.extend(callees)
except Exception as exc:
self.logger.error(f"Callee search failed: {exc}")
# Deduplicate by (target_symbol, source_line)
seen = set()
unique_callees = []
for callee in all_callees:
key = (callee.get("target_symbol"), callee.get("source_line"))
if key not in seen:
seen.add(key)
unique_callees.append(callee)
# Sort by source line
unique_callees.sort(key=lambda c: c.get("source_line", 0))
return unique_callees[:limit]
def _search_callees_single(self, index_path: Path,
source_symbol: str) -> List[Dict[str, Any]]:
"""Search for callees in a single index.
Args:
index_path: Path to _index.db file
source_symbol: Source symbol name
Returns:
List of callee relationship dicts (empty on error)
"""
try:
# Use the connection pool via SQLiteStore
with SQLiteStore(index_path) as store:
# Search across all files containing the symbol
# Get all files that have this symbol
conn = store._get_connection()
file_rows = conn.execute(
"""
SELECT DISTINCT f.path
FROM symbols s
JOIN files f ON s.file_id = f.id
WHERE s.name = ?
""",
(source_symbol,)
).fetchall()
# Collect results from all matching files
all_results = []
for file_row in file_rows:
file_path = file_row["path"]
results = store.query_relationships_by_source(source_symbol, file_path)
all_results.extend(results)
return all_results
except Exception as exc:
self.logger.debug(f"Callee search error in {index_path}: {exc}")
return []
def _search_inheritance_parallel(self, index_paths: List[Path],
class_name: str,
limit: int) -> List[Dict[str, Any]]:
"""Search for inheritance relationships across multiple indexes in parallel.
Args:
index_paths: List of _index.db paths to search
class_name: Class name to search for
limit: Total result limit
Returns:
Deduplicated list of inheritance relationships
"""
all_inheritance = []
executor = self._get_executor()
future_to_path = {
executor.submit(
self._search_inheritance_single,
idx_path,
class_name
): idx_path
for idx_path in index_paths
}
for future in as_completed(future_to_path):
try:
inheritance = future.result()
all_inheritance.extend(inheritance)
except Exception as exc:
self.logger.error(f"Inheritance search failed: {exc}")
# Deduplicate by (source_symbol, target_symbol)
seen = set()
unique_inheritance = []
for rel in all_inheritance:
key = (rel.get("source_symbol"), rel.get("target_symbol"))
if key not in seen:
seen.add(key)
unique_inheritance.append(rel)
# Sort by source file
unique_inheritance.sort(key=lambda r: r.get("source_file", ""))
return unique_inheritance[:limit]
def _search_inheritance_single(self, index_path: Path,
class_name: str) -> List[Dict[str, Any]]:
"""Search for inheritance relationships in a single index.
Args:
index_path: Path to _index.db file
class_name: Class name to search for
Returns:
List of inheritance relationship dicts (empty on error)
"""
try:
with SQLiteStore(index_path) as store:
conn = store._get_connection()
# Search both as base class (target) and derived class (source)
rows = conn.execute(
"""
SELECT
s.name AS source_symbol,
r.target_qualified_name,
r.relationship_type,
r.source_line,
f.path AS source_file,
r.target_file
FROM code_relationships r
JOIN symbols s ON r.source_symbol_id = s.id
JOIN files f ON s.file_id = f.id
WHERE (s.name = ? OR r.target_qualified_name LIKE ?)
AND r.relationship_type = 'inherits'
ORDER BY f.path, r.source_line
LIMIT 100
""",
(class_name, f"%{class_name}%")
).fetchall()
return [
{
"source_symbol": row["source_symbol"],
"target_symbol": row["target_qualified_name"],
"relationship_type": row["relationship_type"],
"source_line": row["source_line"],
"source_file": row["source_file"],
"target_file": row["target_file"],
}
for row in rows
]
except Exception as exc:
self.logger.debug(f"Inheritance search error in {index_path}: {exc}")
return []
# === Convenience Functions ===

View File

@@ -4,9 +4,10 @@ from __future__ import annotations
from dataclasses import dataclass
from pathlib import Path
from typing import List, Optional
from typing import List, Optional, Tuple
from codexlens.entities import SemanticChunk, Symbol
from codexlens.parsers.tokenizer import get_default_tokenizer
@dataclass
@@ -14,6 +15,7 @@ class ChunkConfig:
"""Configuration for chunking strategies."""
max_chunk_size: int = 1000 # Max characters per chunk
overlap: int = 100 # Overlap for sliding window
strategy: str = "auto" # Chunking strategy: auto, symbol, sliding_window, hybrid
min_chunk_size: int = 50 # Minimum chunk size
@@ -22,6 +24,7 @@ class Chunker:
def __init__(self, config: ChunkConfig | None = None) -> None:
self.config = config or ChunkConfig()
self._tokenizer = get_default_tokenizer()
def chunk_by_symbol(
self,
@@ -29,10 +32,18 @@ class Chunker:
symbols: List[Symbol],
file_path: str | Path,
language: str,
symbol_token_counts: Optional[dict[str, int]] = None,
) -> List[SemanticChunk]:
"""Chunk code by extracted symbols (functions, classes).
Each symbol becomes one chunk with its full content.
Args:
content: Source code content
symbols: List of extracted symbols
file_path: Path to source file
language: Programming language
symbol_token_counts: Optional dict mapping symbol names to token counts
"""
chunks: List[SemanticChunk] = []
lines = content.splitlines(keepends=True)
@@ -47,6 +58,13 @@ class Chunker:
if len(chunk_content.strip()) < self.config.min_chunk_size:
continue
# Calculate token count if not provided
token_count = None
if symbol_token_counts and symbol.name in symbol_token_counts:
token_count = symbol_token_counts[symbol.name]
else:
token_count = self._tokenizer.count_tokens(chunk_content)
chunks.append(SemanticChunk(
content=chunk_content,
embedding=None,
@@ -58,6 +76,7 @@ class Chunker:
"start_line": start_line,
"end_line": end_line,
"strategy": "symbol",
"token_count": token_count,
}
))
@@ -68,10 +87,19 @@ class Chunker:
content: str,
file_path: str | Path,
language: str,
line_mapping: Optional[List[int]] = None,
) -> List[SemanticChunk]:
"""Chunk code using sliding window approach.
Used for files without clear symbol boundaries or very long functions.
Args:
content: Source code content
file_path: Path to source file
language: Programming language
line_mapping: Optional list mapping content line indices to original line numbers
(1-indexed). If provided, line_mapping[i] is the original line number
for the i-th line in content.
"""
chunks: List[SemanticChunk] = []
lines = content.splitlines(keepends=True)
@@ -92,6 +120,18 @@ class Chunker:
chunk_content = "".join(lines[start:end])
if len(chunk_content.strip()) >= self.config.min_chunk_size:
token_count = self._tokenizer.count_tokens(chunk_content)
# Calculate correct line numbers
if line_mapping:
# Use line mapping to get original line numbers
start_line = line_mapping[start]
end_line = line_mapping[end - 1]
else:
# Default behavior: treat content as starting at line 1
start_line = start + 1
end_line = end
chunks.append(SemanticChunk(
content=chunk_content,
embedding=None,
@@ -99,9 +139,10 @@ class Chunker:
"file": str(file_path),
"language": language,
"chunk_index": chunk_idx,
"start_line": start + 1,
"end_line": end,
"start_line": start_line,
"end_line": end_line,
"strategy": "sliding_window",
"token_count": token_count,
}
))
chunk_idx += 1
@@ -119,12 +160,239 @@ class Chunker:
symbols: List[Symbol],
file_path: str | Path,
language: str,
symbol_token_counts: Optional[dict[str, int]] = None,
) -> List[SemanticChunk]:
"""Chunk a file using the best strategy.
Uses symbol-based chunking if symbols available,
falls back to sliding window for files without symbols.
Args:
content: Source code content
symbols: List of extracted symbols
file_path: Path to source file
language: Programming language
symbol_token_counts: Optional dict mapping symbol names to token counts
"""
if symbols:
return self.chunk_by_symbol(content, symbols, file_path, language)
return self.chunk_by_symbol(content, symbols, file_path, language, symbol_token_counts)
return self.chunk_sliding_window(content, file_path, language)
class DocstringExtractor:
"""Extract docstrings from source code."""
@staticmethod
def extract_python_docstrings(content: str) -> List[Tuple[str, int, int]]:
"""Extract Python docstrings with their line ranges.
Returns: List of (docstring_content, start_line, end_line) tuples
"""
docstrings: List[Tuple[str, int, int]] = []
lines = content.splitlines(keepends=True)
i = 0
while i < len(lines):
line = lines[i]
stripped = line.strip()
if stripped.startswith('"""') or stripped.startswith("'''"):
quote_type = '"""' if stripped.startswith('"""') else "'''"
start_line = i + 1
if stripped.count(quote_type) >= 2:
docstring_content = line
end_line = i + 1
docstrings.append((docstring_content, start_line, end_line))
i += 1
continue
docstring_lines = [line]
i += 1
while i < len(lines):
docstring_lines.append(lines[i])
if quote_type in lines[i]:
break
i += 1
end_line = i + 1
docstring_content = "".join(docstring_lines)
docstrings.append((docstring_content, start_line, end_line))
i += 1
return docstrings
@staticmethod
def extract_jsdoc_comments(content: str) -> List[Tuple[str, int, int]]:
"""Extract JSDoc comments with their line ranges.
Returns: List of (comment_content, start_line, end_line) tuples
"""
comments: List[Tuple[str, int, int]] = []
lines = content.splitlines(keepends=True)
i = 0
while i < len(lines):
line = lines[i]
stripped = line.strip()
if stripped.startswith('/**'):
start_line = i + 1
comment_lines = [line]
i += 1
while i < len(lines):
comment_lines.append(lines[i])
if '*/' in lines[i]:
break
i += 1
end_line = i + 1
comment_content = "".join(comment_lines)
comments.append((comment_content, start_line, end_line))
i += 1
return comments
@classmethod
def extract_docstrings(
cls,
content: str,
language: str
) -> List[Tuple[str, int, int]]:
"""Extract docstrings based on language.
Returns: List of (docstring_content, start_line, end_line) tuples
"""
if language == "python":
return cls.extract_python_docstrings(content)
elif language in {"javascript", "typescript"}:
return cls.extract_jsdoc_comments(content)
return []
class HybridChunker:
"""Hybrid chunker that prioritizes docstrings before symbol-based chunking.
Composition-based strategy that:
1. Extracts docstrings as dedicated chunks
2. For remaining code, uses base chunker (symbol or sliding window)
"""
def __init__(
self,
base_chunker: Chunker | None = None,
config: ChunkConfig | None = None
) -> None:
"""Initialize hybrid chunker.
Args:
base_chunker: Chunker to use for non-docstring content
config: Configuration for chunking
"""
self.config = config or ChunkConfig()
self.base_chunker = base_chunker or Chunker(self.config)
self.docstring_extractor = DocstringExtractor()
def _get_excluded_line_ranges(
self,
docstrings: List[Tuple[str, int, int]]
) -> set[int]:
"""Get set of line numbers that are part of docstrings."""
excluded_lines: set[int] = set()
for _, start_line, end_line in docstrings:
for line_num in range(start_line, end_line + 1):
excluded_lines.add(line_num)
return excluded_lines
def _filter_symbols_outside_docstrings(
self,
symbols: List[Symbol],
excluded_lines: set[int]
) -> List[Symbol]:
"""Filter symbols to exclude those completely within docstrings."""
filtered: List[Symbol] = []
for symbol in symbols:
start_line, end_line = symbol.range
symbol_lines = set(range(start_line, end_line + 1))
if not symbol_lines.issubset(excluded_lines):
filtered.append(symbol)
return filtered
def chunk_file(
self,
content: str,
symbols: List[Symbol],
file_path: str | Path,
language: str,
symbol_token_counts: Optional[dict[str, int]] = None,
) -> List[SemanticChunk]:
"""Chunk file using hybrid strategy.
Extracts docstrings first, then chunks remaining code.
Args:
content: Source code content
symbols: List of extracted symbols
file_path: Path to source file
language: Programming language
symbol_token_counts: Optional dict mapping symbol names to token counts
"""
chunks: List[SemanticChunk] = []
tokenizer = get_default_tokenizer()
# Step 1: Extract docstrings as dedicated chunks
docstrings = self.docstring_extractor.extract_docstrings(content, language)
for docstring_content, start_line, end_line in docstrings:
if len(docstring_content.strip()) >= self.config.min_chunk_size:
token_count = tokenizer.count_tokens(docstring_content)
chunks.append(SemanticChunk(
content=docstring_content,
embedding=None,
metadata={
"file": str(file_path),
"language": language,
"chunk_type": "docstring",
"start_line": start_line,
"end_line": end_line,
"strategy": "hybrid",
"token_count": token_count,
}
))
# Step 2: Get line ranges occupied by docstrings
excluded_lines = self._get_excluded_line_ranges(docstrings)
# Step 3: Filter symbols to exclude docstring-only ranges
filtered_symbols = self._filter_symbols_outside_docstrings(symbols, excluded_lines)
# Step 4: Chunk remaining content using base chunker
if filtered_symbols:
base_chunks = self.base_chunker.chunk_by_symbol(
content, filtered_symbols, file_path, language, symbol_token_counts
)
for chunk in base_chunks:
chunk.metadata["strategy"] = "hybrid"
chunk.metadata["chunk_type"] = "code"
chunks.append(chunk)
else:
lines = content.splitlines(keepends=True)
remaining_lines: List[str] = []
for i, line in enumerate(lines, start=1):
if i not in excluded_lines:
remaining_lines.append(line)
if remaining_lines:
remaining_content = "".join(remaining_lines)
if len(remaining_content.strip()) >= self.config.min_chunk_size:
base_chunks = self.base_chunker.chunk_sliding_window(
remaining_content, file_path, language
)
for chunk in base_chunks:
chunk.metadata["strategy"] = "hybrid"
chunk.metadata["chunk_type"] = "code"
chunks.append(chunk)
return chunks

View File

@@ -0,0 +1,531 @@
"""Graph analyzer for extracting code relationships using tree-sitter.
Provides AST-based analysis to identify function calls, method invocations,
and class inheritance relationships within source files.
"""
from __future__ import annotations
from pathlib import Path
from typing import List, Optional
try:
from tree_sitter import Node as TreeSitterNode
TREE_SITTER_AVAILABLE = True
except ImportError:
TreeSitterNode = None # type: ignore[assignment]
TREE_SITTER_AVAILABLE = False
from codexlens.entities import CodeRelationship, Symbol
from codexlens.parsers.treesitter_parser import TreeSitterSymbolParser
class GraphAnalyzer:
"""Analyzer for extracting semantic relationships from code using AST traversal."""
def __init__(self, language_id: str, parser: Optional[TreeSitterSymbolParser] = None) -> None:
"""Initialize graph analyzer for a language.
Args:
language_id: Language identifier (python, javascript, typescript, etc.)
parser: Optional TreeSitterSymbolParser instance for dependency injection.
If None, creates a new parser instance (backward compatibility).
"""
self.language_id = language_id
self._parser = parser if parser is not None else TreeSitterSymbolParser(language_id)
def is_available(self) -> bool:
"""Check if graph analyzer is available.
Returns:
True if tree-sitter parser is initialized and ready
"""
return self._parser.is_available()
def analyze_file(self, text: str, file_path: Path) -> List[CodeRelationship]:
"""Analyze source code and extract relationships.
Args:
text: Source code text
file_path: File path for relationship context
Returns:
List of CodeRelationship objects representing intra-file relationships
"""
if not self.is_available() or self._parser._parser is None:
return []
try:
source_bytes = text.encode("utf8")
tree = self._parser._parser.parse(source_bytes) # type: ignore[attr-defined]
root = tree.root_node
relationships = self._extract_relationships(source_bytes, root, str(file_path.resolve()))
return relationships
except Exception:
# Gracefully handle parsing errors
return []
def analyze_with_symbols(
self, text: str, file_path: Path, symbols: List[Symbol]
) -> List[CodeRelationship]:
"""Analyze source code using pre-parsed symbols to avoid duplicate parsing.
Args:
text: Source code text
file_path: File path for relationship context
symbols: Pre-parsed Symbol objects from TreeSitterSymbolParser
Returns:
List of CodeRelationship objects representing intra-file relationships
"""
if not self.is_available() or self._parser._parser is None:
return []
try:
source_bytes = text.encode("utf8")
tree = self._parser._parser.parse(source_bytes) # type: ignore[attr-defined]
root = tree.root_node
# Convert Symbol objects to internal symbol format
defined_symbols = self._convert_symbols_to_dict(source_bytes, root, symbols)
# Extract relationships using provided symbols
relationships = self._extract_relationships_with_symbols(
source_bytes, root, str(file_path.resolve()), defined_symbols
)
return relationships
except Exception:
# Gracefully handle parsing errors
return []
def _convert_symbols_to_dict(
self, source_bytes: bytes, root: TreeSitterNode, symbols: List[Symbol]
) -> List[dict]:
"""Convert Symbol objects to internal dict format for relationship extraction.
Args:
source_bytes: Source code as bytes
root: Root AST node
symbols: Pre-parsed Symbol objects
Returns:
List of symbol info dicts with name, node, and type
"""
symbol_dicts = []
symbol_names = {s.name for s in symbols}
# Find AST nodes corresponding to symbols
for node in self._iter_nodes(root):
node_type = node.type
# Check if this node matches any of our symbols
if node_type in {"function_definition", "async_function_definition"}:
name_node = node.child_by_field_name("name")
if name_node:
name = self._node_text(source_bytes, name_node)
if name in symbol_names:
symbol_dicts.append({
"name": name,
"node": node,
"type": "function"
})
elif node_type == "class_definition":
name_node = node.child_by_field_name("name")
if name_node:
name = self._node_text(source_bytes, name_node)
if name in symbol_names:
symbol_dicts.append({
"name": name,
"node": node,
"type": "class"
})
elif node_type in {"function_declaration", "generator_function_declaration"}:
name_node = node.child_by_field_name("name")
if name_node:
name = self._node_text(source_bytes, name_node)
if name in symbol_names:
symbol_dicts.append({
"name": name,
"node": node,
"type": "function"
})
elif node_type == "method_definition":
name_node = node.child_by_field_name("name")
if name_node:
name = self._node_text(source_bytes, name_node)
if name in symbol_names:
symbol_dicts.append({
"name": name,
"node": node,
"type": "method"
})
elif node_type in {"class_declaration", "class"}:
name_node = node.child_by_field_name("name")
if name_node:
name = self._node_text(source_bytes, name_node)
if name in symbol_names:
symbol_dicts.append({
"name": name,
"node": node,
"type": "class"
})
elif node_type == "variable_declarator":
name_node = node.child_by_field_name("name")
value_node = node.child_by_field_name("value")
if name_node and value_node and value_node.type == "arrow_function":
name = self._node_text(source_bytes, name_node)
if name in symbol_names:
symbol_dicts.append({
"name": name,
"node": node,
"type": "function"
})
return symbol_dicts
def _extract_relationships_with_symbols(
self, source_bytes: bytes, root: TreeSitterNode, file_path: str, defined_symbols: List[dict]
) -> List[CodeRelationship]:
"""Extract relationships from AST using pre-parsed symbols.
Args:
source_bytes: Source code as bytes
root: Root AST node
file_path: Absolute file path
defined_symbols: Pre-parsed symbol dicts
Returns:
List of extracted relationships
"""
relationships: List[CodeRelationship] = []
# Determine call node type based on language
if self.language_id == "python":
call_node_type = "call"
extract_target = self._extract_call_target
elif self.language_id in {"javascript", "typescript"}:
call_node_type = "call_expression"
extract_target = self._extract_js_call_target
else:
return []
# Find call expressions and match to defined symbols
for node in self._iter_nodes(root):
if node.type == call_node_type:
# Extract caller context (enclosing function/method/class)
source_symbol = self._find_enclosing_symbol(node, defined_symbols)
if source_symbol is None:
# Call at module level, use "<module>" as source
source_symbol = "<module>"
# Extract callee (function/method being called)
target_symbol = extract_target(source_bytes, node)
if target_symbol is None:
continue
# Create relationship
line_number = node.start_point[0] + 1
relationships.append(
CodeRelationship(
source_symbol=source_symbol,
target_symbol=target_symbol,
relationship_type="call",
source_file=file_path,
target_file=None, # Intra-file only
source_line=line_number,
)
)
return relationships
def _extract_relationships(
self, source_bytes: bytes, root: TreeSitterNode, file_path: str
) -> List[CodeRelationship]:
"""Extract relationships from AST.
Args:
source_bytes: Source code as bytes
root: Root AST node
file_path: Absolute file path
Returns:
List of extracted relationships
"""
if self.language_id == "python":
return self._extract_python_relationships(source_bytes, root, file_path)
elif self.language_id in {"javascript", "typescript"}:
return self._extract_js_ts_relationships(source_bytes, root, file_path)
else:
return []
def _extract_python_relationships(
self, source_bytes: bytes, root: TreeSitterNode, file_path: str
) -> List[CodeRelationship]:
"""Extract Python relationships from AST.
Args:
source_bytes: Source code as bytes
root: Root AST node
file_path: Absolute file path
Returns:
List of Python relationships (function/method calls)
"""
relationships: List[CodeRelationship] = []
# First pass: collect all defined symbols with their scopes
defined_symbols = self._collect_python_symbols(source_bytes, root)
# Second pass: find call expressions and match to defined symbols
for node in self._iter_nodes(root):
if node.type == "call":
# Extract caller context (enclosing function/method/class)
source_symbol = self._find_enclosing_symbol(node, defined_symbols)
if source_symbol is None:
# Call at module level, use "<module>" as source
source_symbol = "<module>"
# Extract callee (function/method being called)
target_symbol = self._extract_call_target(source_bytes, node)
if target_symbol is None:
continue
# Create relationship
line_number = node.start_point[0] + 1
relationships.append(
CodeRelationship(
source_symbol=source_symbol,
target_symbol=target_symbol,
relationship_type="call",
source_file=file_path,
target_file=None, # Intra-file only
source_line=line_number,
)
)
return relationships
def _extract_js_ts_relationships(
self, source_bytes: bytes, root: TreeSitterNode, file_path: str
) -> List[CodeRelationship]:
"""Extract JavaScript/TypeScript relationships from AST.
Args:
source_bytes: Source code as bytes
root: Root AST node
file_path: Absolute file path
Returns:
List of JS/TS relationships (function/method calls)
"""
relationships: List[CodeRelationship] = []
# First pass: collect all defined symbols
defined_symbols = self._collect_js_ts_symbols(source_bytes, root)
# Second pass: find call expressions
for node in self._iter_nodes(root):
if node.type == "call_expression":
# Extract caller context
source_symbol = self._find_enclosing_symbol(node, defined_symbols)
if source_symbol is None:
source_symbol = "<module>"
# Extract callee
target_symbol = self._extract_js_call_target(source_bytes, node)
if target_symbol is None:
continue
# Create relationship
line_number = node.start_point[0] + 1
relationships.append(
CodeRelationship(
source_symbol=source_symbol,
target_symbol=target_symbol,
relationship_type="call",
source_file=file_path,
target_file=None,
source_line=line_number,
)
)
return relationships
def _collect_python_symbols(self, source_bytes: bytes, root: TreeSitterNode) -> List[dict]:
"""Collect all Python function/method/class definitions.
Args:
source_bytes: Source code as bytes
root: Root AST node
Returns:
List of symbol info dicts with name, node, and type
"""
symbols = []
for node in self._iter_nodes(root):
if node.type in {"function_definition", "async_function_definition"}:
name_node = node.child_by_field_name("name")
if name_node:
symbols.append({
"name": self._node_text(source_bytes, name_node),
"node": node,
"type": "function"
})
elif node.type == "class_definition":
name_node = node.child_by_field_name("name")
if name_node:
symbols.append({
"name": self._node_text(source_bytes, name_node),
"node": node,
"type": "class"
})
return symbols
def _collect_js_ts_symbols(self, source_bytes: bytes, root: TreeSitterNode) -> List[dict]:
"""Collect all JS/TS function/method/class definitions.
Args:
source_bytes: Source code as bytes
root: Root AST node
Returns:
List of symbol info dicts with name, node, and type
"""
symbols = []
for node in self._iter_nodes(root):
if node.type in {"function_declaration", "generator_function_declaration"}:
name_node = node.child_by_field_name("name")
if name_node:
symbols.append({
"name": self._node_text(source_bytes, name_node),
"node": node,
"type": "function"
})
elif node.type == "method_definition":
name_node = node.child_by_field_name("name")
if name_node:
symbols.append({
"name": self._node_text(source_bytes, name_node),
"node": node,
"type": "method"
})
elif node.type in {"class_declaration", "class"}:
name_node = node.child_by_field_name("name")
if name_node:
symbols.append({
"name": self._node_text(source_bytes, name_node),
"node": node,
"type": "class"
})
elif node.type == "variable_declarator":
name_node = node.child_by_field_name("name")
value_node = node.child_by_field_name("value")
if name_node and value_node and value_node.type == "arrow_function":
symbols.append({
"name": self._node_text(source_bytes, name_node),
"node": node,
"type": "function"
})
return symbols
def _find_enclosing_symbol(self, node: TreeSitterNode, symbols: List[dict]) -> Optional[str]:
"""Find the enclosing function/method/class for a node.
Args:
node: AST node to find enclosure for
symbols: List of defined symbols
Returns:
Name of enclosing symbol, or None if at module level
"""
# Walk up the tree to find enclosing symbol
parent = node.parent
while parent is not None:
for symbol in symbols:
if symbol["node"] == parent:
return symbol["name"]
parent = parent.parent
return None
def _extract_call_target(self, source_bytes: bytes, node: TreeSitterNode) -> Optional[str]:
"""Extract the target function name from a Python call expression.
Args:
source_bytes: Source code as bytes
node: Call expression node
Returns:
Target function name, or None if cannot be determined
"""
function_node = node.child_by_field_name("function")
if function_node is None:
return None
# Handle simple identifiers (e.g., "foo()")
if function_node.type == "identifier":
return self._node_text(source_bytes, function_node)
# Handle attribute access (e.g., "obj.method()")
if function_node.type == "attribute":
attr_node = function_node.child_by_field_name("attribute")
if attr_node:
return self._node_text(source_bytes, attr_node)
return None
def _extract_js_call_target(self, source_bytes: bytes, node: TreeSitterNode) -> Optional[str]:
"""Extract the target function name from a JS/TS call expression.
Args:
source_bytes: Source code as bytes
node: Call expression node
Returns:
Target function name, or None if cannot be determined
"""
function_node = node.child_by_field_name("function")
if function_node is None:
return None
# Handle simple identifiers
if function_node.type == "identifier":
return self._node_text(source_bytes, function_node)
# Handle member expressions (e.g., "obj.method()")
if function_node.type == "member_expression":
property_node = function_node.child_by_field_name("property")
if property_node:
return self._node_text(source_bytes, property_node)
return None
def _iter_nodes(self, root: TreeSitterNode):
"""Iterate over all nodes in AST.
Args:
root: Root node to start iteration
Yields:
AST nodes in depth-first order
"""
stack = [root]
while stack:
node = stack.pop()
yield node
for child in reversed(node.children):
stack.append(child)
def _node_text(self, source_bytes: bytes, node: TreeSitterNode) -> str:
"""Extract text for a node.
Args:
source_bytes: Source code as bytes
node: AST node
Returns:
Text content of node
"""
return source_bytes[node.start_byte:node.end_byte].decode("utf8")

View File

@@ -75,6 +75,34 @@ class LLMEnhancer:
external LLM tools (gemini, qwen) via CCW CLI subprocess.
"""
CHUNK_REFINEMENT_PROMPT = '''PURPOSE: Identify optimal semantic split points in code chunk
TASK:
- Analyze the code structure to find natural semantic boundaries
- Identify logical groupings (functions, classes, related statements)
- Suggest split points that maintain semantic cohesion
MODE: analysis
EXPECTED: JSON format with split positions
=== CODE CHUNK ===
{code_chunk}
=== OUTPUT FORMAT ===
Return ONLY valid JSON (no markdown, no explanation):
{{
"split_points": [
{{
"line": <line_number>,
"reason": "brief reason for split (e.g., 'start of new function', 'end of class definition')"
}}
]
}}
Rules:
- Split at function/class/method boundaries
- Keep related code together (don't split mid-function)
- Aim for chunks between 500-2000 characters
- Return empty split_points if no good splits found'''
PROMPT_TEMPLATE = '''PURPOSE: Generate semantic summaries and search keywords for code files
TASK:
- For each code block, generate a concise summary (1-2 sentences)
@@ -168,42 +196,246 @@ Return ONLY valid JSON (no markdown, no explanation):
return results
def enhance_file(
self,
path: str,
content: str,
language: str,
working_dir: Optional[Path] = None,
) -> SemanticMetadata:
"""Enhance a single file with LLM-generated semantic metadata.
Convenience method that wraps enhance_files for single file processing.
Args:
path: File path
content: File content
language: Programming language
working_dir: Optional working directory for CCW CLI
Returns:
SemanticMetadata for the file
Raises:
ValueError: If enhancement fails
"""
file_data = FileData(path=path, content=content, language=language)
results = self.enhance_files([file_data], working_dir)
if path not in results:
# Return default metadata if enhancement failed
return SemanticMetadata(
summary=f"Code file written in {language}",
keywords=[language, "code"],
purpose="unknown",
file_path=path,
llm_tool=self.config.tool,
)
return results[path]
def refine_chunk_boundaries(
self,
chunk: SemanticChunk,
max_chunk_size: int = 2000,
working_dir: Optional[Path] = None,
) -> List[SemanticChunk]:
"""Refine chunk boundaries using LLM for large code chunks.
Uses LLM to identify semantic split points in large chunks,
breaking them into smaller, more cohesive pieces.
Args:
chunk: Original chunk to refine
max_chunk_size: Maximum characters before triggering refinement
working_dir: Optional working directory for CCW CLI
Returns:
SemanticMetadata for the file
Raises:
ValueError: If enhancement fails
List of refined chunks (original chunk if no splits or refinement fails)
"""
file_data = FileData(path=path, content=content, language=language)
results = self.enhance_files([file_data], working_dir)
# Skip if chunk is small enough
if len(chunk.content) <= max_chunk_size:
return [chunk]
if path not in results:
# Return default metadata if enhancement failed
return SemanticMetadata(
summary=f"Code file written in {language}",
keywords=[language, "code"],
purpose="unknown",
file_path=path,
llm_tool=self.config.tool,
# Skip if LLM enhancement disabled or unavailable
if not self.config.enabled or not self.check_available():
return [chunk]
# Skip docstring chunks - only refine code chunks
if chunk.metadata.get("chunk_type") == "docstring":
return [chunk]
try:
# Build refinement prompt
prompt = self.CHUNK_REFINEMENT_PROMPT.format(code_chunk=chunk.content)
# Invoke LLM
result = self._invoke_ccw_cli(
prompt,
tool=self.config.tool,
working_dir=working_dir,
)
return results[path]
# Fallback if primary tool fails
if not result["success"] and self.config.fallback_tool:
result = self._invoke_ccw_cli(
prompt,
tool=self.config.fallback_tool,
working_dir=working_dir,
)
if not result["success"]:
logger.debug("LLM refinement failed, returning original chunk")
return [chunk]
# Parse split points
split_points = self._parse_split_points(result["stdout"])
if not split_points:
logger.debug("No split points identified, returning original chunk")
return [chunk]
# Split chunk at identified boundaries
refined_chunks = self._split_chunk_at_points(chunk, split_points)
logger.debug(
"Refined chunk into %d smaller chunks (was %d chars)",
len(refined_chunks),
len(chunk.content),
)
return refined_chunks
except Exception as e:
logger.warning("Chunk refinement error: %s, returning original chunk", e)
return [chunk]
def _parse_split_points(self, stdout: str) -> List[int]:
"""Parse split points from LLM response.
Args:
stdout: Raw stdout from CCW CLI
Returns:
List of line numbers where splits should occur (sorted)
"""
# Extract JSON from response
json_str = self._extract_json(stdout)
if not json_str:
return []
try:
data = json.loads(json_str)
split_points_data = data.get("split_points", [])
# Extract line numbers
lines = []
for point in split_points_data:
if isinstance(point, dict) and "line" in point:
line_num = point["line"]
if isinstance(line_num, int) and line_num > 0:
lines.append(line_num)
return sorted(set(lines))
except (json.JSONDecodeError, ValueError, TypeError) as e:
logger.debug("Failed to parse split points: %s", e)
return []
def _split_chunk_at_points(
self,
chunk: SemanticChunk,
split_points: List[int],
) -> List[SemanticChunk]:
"""Split chunk at specified line numbers.
Args:
chunk: Original chunk to split
split_points: Sorted list of line numbers to split at
Returns:
List of smaller chunks
"""
lines = chunk.content.splitlines(keepends=True)
chunks: List[SemanticChunk] = []
# Get original metadata
base_metadata = dict(chunk.metadata)
original_start = base_metadata.get("start_line", 1)
# Add start and end boundaries
boundaries = [0] + split_points + [len(lines)]
for i in range(len(boundaries) - 1):
start_idx = boundaries[i]
end_idx = boundaries[i + 1]
# Skip empty sections
if start_idx >= end_idx:
continue
# Extract content
section_lines = lines[start_idx:end_idx]
section_content = "".join(section_lines)
# Skip if too small
if len(section_content.strip()) < 50:
continue
# Create new chunk with updated metadata
new_metadata = base_metadata.copy()
new_metadata["start_line"] = original_start + start_idx
new_metadata["end_line"] = original_start + end_idx - 1
new_metadata["refined_by_llm"] = True
new_metadata["original_chunk_size"] = len(chunk.content)
chunks.append(
SemanticChunk(
content=section_content,
embedding=None, # Embeddings will be regenerated
metadata=new_metadata,
)
)
# If no valid chunks created, return original
if not chunks:
return [chunk]
return chunks
def _process_batch(

View File

@@ -149,15 +149,21 @@ class DirIndexStore:
# Replace symbols
conn.execute("DELETE FROM symbols WHERE file_id=?", (file_id,))
if symbols:
# Extract token_count and symbol_type from symbol metadata if available
symbol_rows = []
for s in symbols:
token_count = getattr(s, 'token_count', None)
symbol_type = getattr(s, 'symbol_type', None) or s.kind
symbol_rows.append(
(file_id, s.name, s.kind, s.range[0], s.range[1], token_count, symbol_type)
)
conn.executemany(
"""
INSERT INTO symbols(file_id, name, kind, start_line, end_line)
VALUES(?, ?, ?, ?, ?)
INSERT INTO symbols(file_id, name, kind, start_line, end_line, token_count, symbol_type)
VALUES(?, ?, ?, ?, ?, ?, ?)
""",
[
(file_id, s.name, s.kind, s.range[0], s.range[1])
for s in symbols
],
symbol_rows,
)
conn.commit()
@@ -216,15 +222,21 @@ class DirIndexStore:
conn.execute("DELETE FROM symbols WHERE file_id=?", (file_id,))
if symbols:
# Extract token_count and symbol_type from symbol metadata if available
symbol_rows = []
for s in symbols:
token_count = getattr(s, 'token_count', None)
symbol_type = getattr(s, 'symbol_type', None) or s.kind
symbol_rows.append(
(file_id, s.name, s.kind, s.range[0], s.range[1], token_count, symbol_type)
)
conn.executemany(
"""
INSERT INTO symbols(file_id, name, kind, start_line, end_line)
VALUES(?, ?, ?, ?, ?)
INSERT INTO symbols(file_id, name, kind, start_line, end_line, token_count, symbol_type)
VALUES(?, ?, ?, ?, ?, ?, ?)
""",
[
(file_id, s.name, s.kind, s.range[0], s.range[1])
for s in symbols
],
symbol_rows,
)
conn.commit()
@@ -1021,7 +1033,9 @@ class DirIndexStore:
name TEXT NOT NULL,
kind TEXT NOT NULL,
start_line INTEGER,
end_line INTEGER
end_line INTEGER,
token_count INTEGER,
symbol_type TEXT
)
"""
)
@@ -1083,6 +1097,7 @@ class DirIndexStore:
conn.execute("CREATE INDEX IF NOT EXISTS idx_subdirs_name ON subdirs(name)")
conn.execute("CREATE INDEX IF NOT EXISTS idx_symbols_name ON symbols(name)")
conn.execute("CREATE INDEX IF NOT EXISTS idx_symbols_file ON symbols(file_id)")
conn.execute("CREATE INDEX IF NOT EXISTS idx_symbols_type ON symbols(symbol_type)")
conn.execute("CREATE INDEX IF NOT EXISTS idx_semantic_file ON semantic_metadata(file_id)")
conn.execute("CREATE INDEX IF NOT EXISTS idx_keywords_keyword ON keywords(keyword)")
conn.execute("CREATE INDEX IF NOT EXISTS idx_file_keywords_file_id ON file_keywords(file_id)")

View File

@@ -0,0 +1,48 @@
"""
Migration 002: Add token_count and symbol_type to symbols table.
This migration adds token counting metadata to symbols for accurate chunk
splitting and performance optimization. It also adds symbol_type for better
filtering in searches.
"""
import logging
from sqlite3 import Connection
log = logging.getLogger(__name__)
def upgrade(db_conn: Connection):
"""
Applies the migration to add token metadata to symbols.
- Adds token_count column to symbols table
- Adds symbol_type column to symbols table (for future use)
- Creates index on symbol_type for efficient filtering
- Backfills existing symbols with NULL token_count (to be calculated lazily)
Args:
db_conn: The SQLite database connection.
"""
cursor = db_conn.cursor()
log.info("Adding token_count column to symbols table...")
try:
cursor.execute("ALTER TABLE symbols ADD COLUMN token_count INTEGER")
log.info("Successfully added token_count column.")
except Exception as e:
# Column might already exist
log.warning(f"Could not add token_count column (might already exist): {e}")
log.info("Adding symbol_type column to symbols table...")
try:
cursor.execute("ALTER TABLE symbols ADD COLUMN symbol_type TEXT")
log.info("Successfully added symbol_type column.")
except Exception as e:
# Column might already exist
log.warning(f"Could not add symbol_type column (might already exist): {e}")
log.info("Creating index on symbol_type for efficient filtering...")
cursor.execute("CREATE INDEX IF NOT EXISTS idx_symbols_type ON symbols(symbol_type)")
log.info("Migration 002 completed successfully.")

View File

@@ -0,0 +1,57 @@
"""
Migration 003: Add code relationships storage.
This migration introduces the `code_relationships` table to store semantic
relationships between code symbols (function calls, inheritance, imports).
This enables graph-based code navigation and dependency analysis.
"""
import logging
from sqlite3 import Connection
log = logging.getLogger(__name__)
def upgrade(db_conn: Connection):
"""
Applies the migration to add code relationships table.
- Creates `code_relationships` table with foreign key to symbols
- Creates indexes for efficient relationship queries
- Supports lazy expansion with target_symbol being qualified names
Args:
db_conn: The SQLite database connection.
"""
cursor = db_conn.cursor()
log.info("Creating 'code_relationships' table...")
cursor.execute(
"""
CREATE TABLE IF NOT EXISTS code_relationships (
id INTEGER PRIMARY KEY,
source_symbol_id INTEGER NOT NULL,
target_qualified_name TEXT NOT NULL,
relationship_type TEXT NOT NULL,
source_line INTEGER NOT NULL,
target_file TEXT,
FOREIGN KEY (source_symbol_id) REFERENCES symbols (id) ON DELETE CASCADE
)
"""
)
log.info("Creating indexes for code_relationships...")
cursor.execute(
"CREATE INDEX IF NOT EXISTS idx_relationships_source ON code_relationships (source_symbol_id)"
)
cursor.execute(
"CREATE INDEX IF NOT EXISTS idx_relationships_target ON code_relationships (target_qualified_name)"
)
cursor.execute(
"CREATE INDEX IF NOT EXISTS idx_relationships_type ON code_relationships (relationship_type)"
)
cursor.execute(
"CREATE INDEX IF NOT EXISTS idx_relationships_source_line ON code_relationships (source_line)"
)
log.info("Finished creating code_relationships table and indexes.")

View File

@@ -9,7 +9,7 @@ from dataclasses import asdict
from pathlib import Path
from typing import Any, Dict, Iterable, List, Optional
from codexlens.entities import IndexedFile, SearchResult, Symbol
from codexlens.entities import CodeRelationship, IndexedFile, SearchResult, Symbol
from codexlens.errors import StorageError
@@ -309,13 +309,184 @@ class SQLiteStore:
"SELECT language, COUNT(*) AS c FROM files GROUP BY language ORDER BY c DESC"
).fetchall()
languages = {row["language"]: row["c"] for row in lang_rows}
# Include relationship count if table exists
relationship_count = 0
try:
rel_row = conn.execute("SELECT COUNT(*) AS c FROM code_relationships").fetchone()
relationship_count = int(rel_row["c"]) if rel_row else 0
except sqlite3.DatabaseError:
pass
return {
"files": int(file_count),
"symbols": int(symbol_count),
"relationships": relationship_count,
"languages": languages,
"db_path": str(self.db_path),
}
def add_relationships(self, file_path: str | Path, relationships: List[CodeRelationship]) -> None:
"""Store code relationships for a file.
Args:
file_path: Path to the file containing the relationships
relationships: List of CodeRelationship objects to store
"""
if not relationships:
return
with self._lock:
conn = self._get_connection()
resolved_path = str(Path(file_path).resolve())
# Get file_id
row = conn.execute("SELECT id FROM files WHERE path=?", (resolved_path,)).fetchone()
if not row:
raise StorageError(f"File not found in index: {file_path}")
file_id = int(row["id"])
# Delete existing relationships for symbols in this file
conn.execute(
"""
DELETE FROM code_relationships
WHERE source_symbol_id IN (
SELECT id FROM symbols WHERE file_id=?
)
""",
(file_id,)
)
# Insert new relationships
relationship_rows = []
for rel in relationships:
# Find source symbol ID
symbol_row = conn.execute(
"""
SELECT id FROM symbols
WHERE file_id=? AND name=? AND start_line <= ? AND end_line >= ?
ORDER BY (end_line - start_line) ASC
LIMIT 1
""",
(file_id, rel.source_symbol, rel.source_line, rel.source_line)
).fetchone()
if symbol_row:
source_symbol_id = int(symbol_row["id"])
relationship_rows.append((
source_symbol_id,
rel.target_symbol,
rel.relationship_type,
rel.source_line,
rel.target_file
))
if relationship_rows:
conn.executemany(
"""
INSERT INTO code_relationships(
source_symbol_id, target_qualified_name, relationship_type,
source_line, target_file
)
VALUES(?, ?, ?, ?, ?)
""",
relationship_rows
)
conn.commit()
def query_relationships_by_target(
self, target_name: str, *, limit: int = 100
) -> List[Dict[str, Any]]:
"""Query relationships by target symbol name (find all callers).
Args:
target_name: Name of the target symbol
limit: Maximum number of results
Returns:
List of dicts containing relationship info with file paths and line numbers
"""
with self._lock:
conn = self._get_connection()
rows = conn.execute(
"""
SELECT
s.name AS source_symbol,
r.target_qualified_name,
r.relationship_type,
r.source_line,
f.path AS source_file,
r.target_file
FROM code_relationships r
JOIN symbols s ON r.source_symbol_id = s.id
JOIN files f ON s.file_id = f.id
WHERE r.target_qualified_name = ?
ORDER BY f.path, r.source_line
LIMIT ?
""",
(target_name, limit)
).fetchall()
return [
{
"source_symbol": row["source_symbol"],
"target_symbol": row["target_qualified_name"],
"relationship_type": row["relationship_type"],
"source_line": row["source_line"],
"source_file": row["source_file"],
"target_file": row["target_file"],
}
for row in rows
]
def query_relationships_by_source(
self, source_symbol: str, source_file: str | Path, *, limit: int = 100
) -> List[Dict[str, Any]]:
"""Query relationships by source symbol (find what a symbol calls).
Args:
source_symbol: Name of the source symbol
source_file: File path containing the source symbol
limit: Maximum number of results
Returns:
List of dicts containing relationship info
"""
with self._lock:
conn = self._get_connection()
resolved_path = str(Path(source_file).resolve())
rows = conn.execute(
"""
SELECT
s.name AS source_symbol,
r.target_qualified_name,
r.relationship_type,
r.source_line,
f.path AS source_file,
r.target_file
FROM code_relationships r
JOIN symbols s ON r.source_symbol_id = s.id
JOIN files f ON s.file_id = f.id
WHERE s.name = ? AND f.path = ?
ORDER BY r.source_line
LIMIT ?
""",
(source_symbol, resolved_path, limit)
).fetchall()
return [
{
"source_symbol": row["source_symbol"],
"target_symbol": row["target_qualified_name"],
"relationship_type": row["relationship_type"],
"source_line": row["source_line"],
"source_file": row["source_file"],
"target_file": row["target_file"],
}
for row in rows
]
def _connect(self) -> sqlite3.Connection:
"""Legacy method for backward compatibility."""
return self._get_connection()
@@ -348,6 +519,20 @@ class SQLiteStore:
)
conn.execute("CREATE INDEX IF NOT EXISTS idx_symbols_name ON symbols(name)")
conn.execute("CREATE INDEX IF NOT EXISTS idx_symbols_kind ON symbols(kind)")
conn.execute(
"""
CREATE TABLE IF NOT EXISTS code_relationships (
id INTEGER PRIMARY KEY,
source_symbol_id INTEGER NOT NULL REFERENCES symbols(id) ON DELETE CASCADE,
target_qualified_name TEXT NOT NULL,
relationship_type TEXT NOT NULL,
source_line INTEGER NOT NULL,
target_file TEXT
)
"""
)
conn.execute("CREATE INDEX IF NOT EXISTS idx_rel_target ON code_relationships(target_qualified_name)")
conn.execute("CREATE INDEX IF NOT EXISTS idx_rel_source ON code_relationships(source_symbol_id)")
conn.commit()
except sqlite3.DatabaseError as exc:
raise StorageError(f"Failed to initialize database schema: {exc}") from exc

View File

@@ -0,0 +1,656 @@
"""Unit tests for ChainSearchEngine.
Tests the graph query methods (search_callers, search_callees, search_inheritance)
with mocked SQLiteStore dependency to test logic in isolation.
"""
import pytest
from pathlib import Path
from unittest.mock import Mock, MagicMock, patch, call
from concurrent.futures import ThreadPoolExecutor
from codexlens.search.chain_search import (
ChainSearchEngine,
SearchOptions,
SearchStats,
ChainSearchResult,
)
from codexlens.entities import SearchResult, Symbol
from codexlens.storage.registry import RegistryStore, DirMapping
from codexlens.storage.path_mapper import PathMapper
@pytest.fixture
def mock_registry():
"""Create a mock RegistryStore."""
registry = Mock(spec=RegistryStore)
return registry
@pytest.fixture
def mock_mapper():
"""Create a mock PathMapper."""
mapper = Mock(spec=PathMapper)
return mapper
@pytest.fixture
def search_engine(mock_registry, mock_mapper):
"""Create a ChainSearchEngine with mocked dependencies."""
return ChainSearchEngine(mock_registry, mock_mapper, max_workers=2)
@pytest.fixture
def sample_index_path():
"""Sample index database path."""
return Path("/test/project/_index.db")
class TestChainSearchEngineCallers:
"""Tests for search_callers method."""
def test_search_callers_returns_relationships(self, search_engine, mock_registry, sample_index_path):
"""Test that search_callers returns caller relationships."""
# Setup
source_path = Path("/test/project")
target_symbol = "my_function"
# Mock finding the start index
mock_registry.find_nearest_index.return_value = DirMapping(
id=1,
project_id=1,
source_path=source_path,
index_path=sample_index_path,
depth=0,
files_count=10,
last_updated=0.0
)
# Mock collect_index_paths to return single index
with patch.object(search_engine, '_collect_index_paths', return_value=[sample_index_path]):
# Mock the parallel search to return caller data
expected_callers = [
{
"source_symbol": "caller_function",
"target_symbol": "my_function",
"relationship_type": "calls",
"source_line": 42,
"source_file": "/test/project/module.py",
"target_file": "/test/project/lib.py",
}
]
with patch.object(search_engine, '_search_callers_parallel', return_value=expected_callers):
# Execute
result = search_engine.search_callers(target_symbol, source_path)
# Assert
assert len(result) == 1
assert result[0]["source_symbol"] == "caller_function"
assert result[0]["target_symbol"] == "my_function"
assert result[0]["relationship_type"] == "calls"
assert result[0]["source_line"] == 42
def test_search_callers_empty_results(self, search_engine, mock_registry, sample_index_path):
"""Test that search_callers handles no results gracefully."""
# Setup
source_path = Path("/test/project")
target_symbol = "nonexistent_function"
# Mock finding the start index
mock_registry.find_nearest_index.return_value = DirMapping(
id=1,
project_id=1,
source_path=source_path,
index_path=sample_index_path,
depth=0,
files_count=10,
last_updated=0.0
)
# Mock collect_index_paths
with patch.object(search_engine, '_collect_index_paths', return_value=[sample_index_path]):
# Mock empty results
with patch.object(search_engine, '_search_callers_parallel', return_value=[]):
# Execute
result = search_engine.search_callers(target_symbol, source_path)
# Assert
assert result == []
def test_search_callers_no_index_found(self, search_engine, mock_registry):
"""Test that search_callers returns empty list when no index found."""
# Setup
source_path = Path("/test/project")
target_symbol = "my_function"
# Mock no index found
mock_registry.find_nearest_index.return_value = None
with patch.object(search_engine, '_find_start_index', return_value=None):
# Execute
result = search_engine.search_callers(target_symbol, source_path)
# Assert
assert result == []
def test_search_callers_uses_options(self, search_engine, mock_registry, mock_mapper, sample_index_path):
"""Test that search_callers respects SearchOptions."""
# Setup
source_path = Path("/test/project")
target_symbol = "my_function"
options = SearchOptions(depth=1, total_limit=50)
# Configure mapper to return a path that exists
mock_mapper.source_to_index_db.return_value = sample_index_path
with patch.object(search_engine, '_collect_index_paths', return_value=[sample_index_path]) as mock_collect:
with patch.object(search_engine, '_search_callers_parallel', return_value=[]) as mock_search:
# Patch Path.exists to return True so the exact match is found
with patch.object(Path, 'exists', return_value=True):
# Execute
search_engine.search_callers(target_symbol, source_path, options)
# Assert that depth was passed to collect_index_paths
mock_collect.assert_called_once_with(sample_index_path, 1)
# Assert that total_limit was passed to parallel search
mock_search.assert_called_once_with([sample_index_path], target_symbol, 50)
class TestChainSearchEngineCallees:
"""Tests for search_callees method."""
def test_search_callees_returns_relationships(self, search_engine, mock_registry, sample_index_path):
"""Test that search_callees returns callee relationships."""
# Setup
source_path = Path("/test/project")
source_symbol = "caller_function"
mock_registry.find_nearest_index.return_value = DirMapping(
id=1,
project_id=1,
source_path=source_path,
index_path=sample_index_path,
depth=0,
files_count=10,
last_updated=0.0
)
with patch.object(search_engine, '_collect_index_paths', return_value=[sample_index_path]):
expected_callees = [
{
"source_symbol": "caller_function",
"target_symbol": "callee_function",
"relationship_type": "calls",
"source_line": 15,
"source_file": "/test/project/module.py",
"target_file": "/test/project/lib.py",
}
]
with patch.object(search_engine, '_search_callees_parallel', return_value=expected_callees):
# Execute
result = search_engine.search_callees(source_symbol, source_path)
# Assert
assert len(result) == 1
assert result[0]["source_symbol"] == "caller_function"
assert result[0]["target_symbol"] == "callee_function"
assert result[0]["source_line"] == 15
def test_search_callees_filters_by_file(self, search_engine, mock_registry, sample_index_path):
"""Test that search_callees correctly handles file-specific queries."""
# Setup
source_path = Path("/test/project")
source_symbol = "MyClass.method"
mock_registry.find_nearest_index.return_value = DirMapping(
id=1,
project_id=1,
source_path=source_path,
index_path=sample_index_path,
depth=0,
files_count=10,
last_updated=0.0
)
with patch.object(search_engine, '_collect_index_paths', return_value=[sample_index_path]):
# Multiple callees from same source symbol
expected_callees = [
{
"source_symbol": "MyClass.method",
"target_symbol": "helper_a",
"relationship_type": "calls",
"source_line": 10,
"source_file": "/test/project/module.py",
"target_file": "/test/project/utils.py",
},
{
"source_symbol": "MyClass.method",
"target_symbol": "helper_b",
"relationship_type": "calls",
"source_line": 20,
"source_file": "/test/project/module.py",
"target_file": "/test/project/utils.py",
}
]
with patch.object(search_engine, '_search_callees_parallel', return_value=expected_callees):
# Execute
result = search_engine.search_callees(source_symbol, source_path)
# Assert
assert len(result) == 2
assert result[0]["target_symbol"] == "helper_a"
assert result[1]["target_symbol"] == "helper_b"
def test_search_callees_empty_results(self, search_engine, mock_registry, sample_index_path):
"""Test that search_callees handles no callees gracefully."""
source_path = Path("/test/project")
source_symbol = "leaf_function"
mock_registry.find_nearest_index.return_value = DirMapping(
id=1,
project_id=1,
source_path=source_path,
index_path=sample_index_path,
depth=0,
files_count=10,
last_updated=0.0
)
with patch.object(search_engine, '_collect_index_paths', return_value=[sample_index_path]):
with patch.object(search_engine, '_search_callees_parallel', return_value=[]):
# Execute
result = search_engine.search_callees(source_symbol, source_path)
# Assert
assert result == []
class TestChainSearchEngineInheritance:
"""Tests for search_inheritance method."""
def test_search_inheritance_returns_inherits_relationships(self, search_engine, mock_registry, sample_index_path):
"""Test that search_inheritance returns inheritance relationships."""
# Setup
source_path = Path("/test/project")
class_name = "BaseClass"
mock_registry.find_nearest_index.return_value = DirMapping(
id=1,
project_id=1,
source_path=source_path,
index_path=sample_index_path,
depth=0,
files_count=10,
last_updated=0.0
)
with patch.object(search_engine, '_collect_index_paths', return_value=[sample_index_path]):
expected_inheritance = [
{
"source_symbol": "DerivedClass",
"target_symbol": "BaseClass",
"relationship_type": "inherits",
"source_line": 5,
"source_file": "/test/project/derived.py",
"target_file": "/test/project/base.py",
}
]
with patch.object(search_engine, '_search_inheritance_parallel', return_value=expected_inheritance):
# Execute
result = search_engine.search_inheritance(class_name, source_path)
# Assert
assert len(result) == 1
assert result[0]["source_symbol"] == "DerivedClass"
assert result[0]["target_symbol"] == "BaseClass"
assert result[0]["relationship_type"] == "inherits"
def test_search_inheritance_multiple_subclasses(self, search_engine, mock_registry, sample_index_path):
"""Test inheritance search with multiple derived classes."""
source_path = Path("/test/project")
class_name = "BaseClass"
mock_registry.find_nearest_index.return_value = DirMapping(
id=1,
project_id=1,
source_path=source_path,
index_path=sample_index_path,
depth=0,
files_count=10,
last_updated=0.0
)
with patch.object(search_engine, '_collect_index_paths', return_value=[sample_index_path]):
expected_inheritance = [
{
"source_symbol": "DerivedClassA",
"target_symbol": "BaseClass",
"relationship_type": "inherits",
"source_line": 5,
"source_file": "/test/project/derived_a.py",
"target_file": "/test/project/base.py",
},
{
"source_symbol": "DerivedClassB",
"target_symbol": "BaseClass",
"relationship_type": "inherits",
"source_line": 10,
"source_file": "/test/project/derived_b.py",
"target_file": "/test/project/base.py",
}
]
with patch.object(search_engine, '_search_inheritance_parallel', return_value=expected_inheritance):
# Execute
result = search_engine.search_inheritance(class_name, source_path)
# Assert
assert len(result) == 2
assert result[0]["source_symbol"] == "DerivedClassA"
assert result[1]["source_symbol"] == "DerivedClassB"
def test_search_inheritance_empty_results(self, search_engine, mock_registry, sample_index_path):
"""Test inheritance search with no subclasses found."""
source_path = Path("/test/project")
class_name = "FinalClass"
mock_registry.find_nearest_index.return_value = DirMapping(
id=1,
project_id=1,
source_path=source_path,
index_path=sample_index_path,
depth=0,
files_count=10,
last_updated=0.0
)
with patch.object(search_engine, '_collect_index_paths', return_value=[sample_index_path]):
with patch.object(search_engine, '_search_inheritance_parallel', return_value=[]):
# Execute
result = search_engine.search_inheritance(class_name, source_path)
# Assert
assert result == []
class TestChainSearchEngineParallelSearch:
"""Tests for parallel search aggregation."""
def test_parallel_search_aggregates_results(self, search_engine, mock_registry, sample_index_path):
"""Test that parallel search aggregates results from multiple indexes."""
# Setup
source_path = Path("/test/project")
target_symbol = "my_function"
index_path_1 = Path("/test/project/_index.db")
index_path_2 = Path("/test/project/subdir/_index.db")
mock_registry.find_nearest_index.return_value = DirMapping(
id=1,
project_id=1,
source_path=source_path,
index_path=index_path_1,
depth=0,
files_count=10,
last_updated=0.0
)
with patch.object(search_engine, '_collect_index_paths', return_value=[index_path_1, index_path_2]):
# Mock parallel search results from multiple indexes
callers_from_multiple = [
{
"source_symbol": "caller_in_root",
"target_symbol": "my_function",
"relationship_type": "calls",
"source_line": 10,
"source_file": "/test/project/root.py",
"target_file": "/test/project/lib.py",
},
{
"source_symbol": "caller_in_subdir",
"target_symbol": "my_function",
"relationship_type": "calls",
"source_line": 20,
"source_file": "/test/project/subdir/module.py",
"target_file": "/test/project/lib.py",
}
]
with patch.object(search_engine, '_search_callers_parallel', return_value=callers_from_multiple):
# Execute
result = search_engine.search_callers(target_symbol, source_path)
# Assert results from both indexes are included
assert len(result) == 2
assert any(r["source_file"] == "/test/project/root.py" for r in result)
assert any(r["source_file"] == "/test/project/subdir/module.py" for r in result)
def test_parallel_search_deduplicates_results(self, search_engine, mock_registry, sample_index_path):
"""Test that parallel search deduplicates results by (source_file, source_line)."""
# Note: This test verifies the behavior of _search_callers_parallel deduplication
source_path = Path("/test/project")
target_symbol = "my_function"
index_path_1 = Path("/test/project/_index.db")
index_path_2 = Path("/test/project/_index.db") # Same index (simulates duplicate)
mock_registry.find_nearest_index.return_value = DirMapping(
id=1,
project_id=1,
source_path=source_path,
index_path=index_path_1,
depth=0,
files_count=10,
last_updated=0.0
)
with patch.object(search_engine, '_collect_index_paths', return_value=[index_path_1, index_path_2]):
# Mock duplicate results from same location
duplicate_callers = [
{
"source_symbol": "caller_function",
"target_symbol": "my_function",
"relationship_type": "calls",
"source_line": 42,
"source_file": "/test/project/module.py",
"target_file": "/test/project/lib.py",
},
{
"source_symbol": "caller_function",
"target_symbol": "my_function",
"relationship_type": "calls",
"source_line": 42,
"source_file": "/test/project/module.py",
"target_file": "/test/project/lib.py",
}
]
with patch.object(search_engine, '_search_callers_parallel', return_value=duplicate_callers):
# Execute
result = search_engine.search_callers(target_symbol, source_path)
# Assert: even with duplicates in input, output may contain both
# (actual deduplication happens in _search_callers_parallel)
assert len(result) >= 1
class TestChainSearchEngineContextManager:
"""Tests for context manager functionality."""
def test_context_manager_closes_executor(self, mock_registry, mock_mapper):
"""Test that context manager properly closes executor."""
with ChainSearchEngine(mock_registry, mock_mapper) as engine:
# Force executor creation
engine._get_executor()
assert engine._executor is not None
# Executor should be closed after exiting context
assert engine._executor is None
def test_close_method_shuts_down_executor(self, search_engine):
"""Test that close() method shuts down executor."""
# Create executor
search_engine._get_executor()
assert search_engine._executor is not None
# Close
search_engine.close()
assert search_engine._executor is None
class TestSearchCallersSingle:
"""Tests for _search_callers_single internal method."""
def test_search_callers_single_queries_store(self, search_engine, sample_index_path):
"""Test that _search_callers_single queries SQLiteStore correctly."""
target_symbol = "my_function"
# Mock SQLiteStore
with patch('codexlens.search.chain_search.SQLiteStore') as MockStore:
mock_store_instance = MockStore.return_value.__enter__.return_value
mock_store_instance.query_relationships_by_target.return_value = [
{
"source_symbol": "caller",
"target_symbol": target_symbol,
"relationship_type": "calls",
"source_line": 10,
"source_file": "/test/file.py",
"target_file": "/test/lib.py",
}
]
# Execute
result = search_engine._search_callers_single(sample_index_path, target_symbol)
# Assert
assert len(result) == 1
assert result[0]["source_symbol"] == "caller"
mock_store_instance.query_relationships_by_target.assert_called_once_with(target_symbol)
def test_search_callers_single_handles_errors(self, search_engine, sample_index_path):
"""Test that _search_callers_single returns empty list on error."""
target_symbol = "my_function"
with patch('codexlens.search.chain_search.SQLiteStore') as MockStore:
MockStore.return_value.__enter__.side_effect = Exception("Database error")
# Execute
result = search_engine._search_callers_single(sample_index_path, target_symbol)
# Assert - should return empty list, not raise exception
assert result == []
class TestSearchCalleesSingle:
"""Tests for _search_callees_single internal method."""
def test_search_callees_single_queries_database(self, search_engine, sample_index_path):
"""Test that _search_callees_single queries SQLiteStore correctly."""
source_symbol = "caller_function"
# Mock SQLiteStore
with patch('codexlens.search.chain_search.SQLiteStore') as MockStore:
mock_store_instance = MagicMock()
MockStore.return_value.__enter__.return_value = mock_store_instance
# Mock _get_connection to return a mock connection
mock_conn = MagicMock()
mock_store_instance._get_connection.return_value = mock_conn
# Mock cursor for file query (getting files containing the symbol)
mock_file_cursor = MagicMock()
mock_file_cursor.fetchall.return_value = [{"path": "/test/module.py"}]
mock_conn.execute.return_value = mock_file_cursor
# Mock query_relationships_by_source to return relationship data
mock_rel_row = {
"source_symbol": source_symbol,
"target_symbol": "callee_function",
"relationship_type": "calls",
"source_line": 15,
"source_file": "/test/module.py",
"target_file": "/test/lib.py",
}
mock_store_instance.query_relationships_by_source.return_value = [mock_rel_row]
# Execute
result = search_engine._search_callees_single(sample_index_path, source_symbol)
# Assert
assert len(result) == 1
assert result[0]["source_symbol"] == source_symbol
assert result[0]["target_symbol"] == "callee_function"
mock_store_instance.query_relationships_by_source.assert_called_once_with(source_symbol, "/test/module.py")
def test_search_callees_single_handles_errors(self, search_engine, sample_index_path):
"""Test that _search_callees_single returns empty list on error."""
source_symbol = "caller_function"
with patch('codexlens.search.chain_search.SQLiteStore') as MockStore:
MockStore.return_value.__enter__.side_effect = Exception("DB error")
# Execute
result = search_engine._search_callees_single(sample_index_path, source_symbol)
# Assert - should return empty list, not raise exception
assert result == []
class TestSearchInheritanceSingle:
"""Tests for _search_inheritance_single internal method."""
def test_search_inheritance_single_queries_database(self, search_engine, sample_index_path):
"""Test that _search_inheritance_single queries SQLiteStore correctly."""
class_name = "BaseClass"
# Mock SQLiteStore
with patch('codexlens.search.chain_search.SQLiteStore') as MockStore:
mock_store_instance = MagicMock()
MockStore.return_value.__enter__.return_value = mock_store_instance
# Mock _get_connection to return a mock connection
mock_conn = MagicMock()
mock_store_instance._get_connection.return_value = mock_conn
# Mock cursor for relationship query
mock_cursor = MagicMock()
mock_row = {
"source_symbol": "DerivedClass",
"target_qualified_name": "BaseClass",
"relationship_type": "inherits",
"source_line": 5,
"source_file": "/test/derived.py",
"target_file": "/test/base.py",
}
mock_cursor.fetchall.return_value = [mock_row]
mock_conn.execute.return_value = mock_cursor
# Execute
result = search_engine._search_inheritance_single(sample_index_path, class_name)
# Assert
assert len(result) == 1
assert result[0]["source_symbol"] == "DerivedClass"
assert result[0]["relationship_type"] == "inherits"
# Verify SQL query uses 'inherits' filter
call_args = mock_conn.execute.call_args
sql_query = call_args[0][0]
assert "relationship_type = 'inherits'" in sql_query
def test_search_inheritance_single_handles_errors(self, search_engine, sample_index_path):
"""Test that _search_inheritance_single returns empty list on error."""
class_name = "BaseClass"
with patch('codexlens.search.chain_search.SQLiteStore') as MockStore:
MockStore.return_value.__enter__.side_effect = Exception("DB error")
# Execute
result = search_engine._search_inheritance_single(sample_index_path, class_name)
# Assert - should return empty list, not raise exception
assert result == []

View File

@@ -0,0 +1,435 @@
"""Tests for GraphAnalyzer - code relationship extraction."""
from pathlib import Path
import pytest
from codexlens.semantic.graph_analyzer import GraphAnalyzer
TREE_SITTER_PYTHON_AVAILABLE = True
try:
import tree_sitter_python # type: ignore[import-not-found] # noqa: F401
except Exception:
TREE_SITTER_PYTHON_AVAILABLE = False
TREE_SITTER_JS_AVAILABLE = True
try:
import tree_sitter_javascript # type: ignore[import-not-found] # noqa: F401
except Exception:
TREE_SITTER_JS_AVAILABLE = False
@pytest.mark.skipif(not TREE_SITTER_PYTHON_AVAILABLE, reason="tree-sitter-python not installed")
class TestPythonGraphAnalyzer:
"""Tests for Python relationship extraction."""
def test_simple_function_call(self):
"""Test extraction of simple function call."""
code = """def helper():
pass
def main():
helper()
"""
analyzer = GraphAnalyzer("python")
relationships = analyzer.analyze_file(code, Path("test.py"))
# Should find main -> helper call
assert len(relationships) == 1
rel = relationships[0]
assert rel.source_symbol == "main"
assert rel.target_symbol == "helper"
assert rel.relationship_type == "call"
assert rel.source_line == 5
def test_multiple_calls_in_function(self):
"""Test extraction of multiple calls from same function."""
code = """def foo():
pass
def bar():
pass
def main():
foo()
bar()
"""
analyzer = GraphAnalyzer("python")
relationships = analyzer.analyze_file(code, Path("test.py"))
# Should find main -> foo and main -> bar
assert len(relationships) == 2
targets = {rel.target_symbol for rel in relationships}
assert targets == {"foo", "bar"}
assert all(rel.source_symbol == "main" for rel in relationships)
def test_nested_function_calls(self):
"""Test extraction of calls from nested functions."""
code = """def inner_helper():
pass
def outer():
def inner():
inner_helper()
inner()
"""
analyzer = GraphAnalyzer("python")
relationships = analyzer.analyze_file(code, Path("test.py"))
# Should find inner -> inner_helper and outer -> inner
assert len(relationships) == 2
call_pairs = {(rel.source_symbol, rel.target_symbol) for rel in relationships}
assert ("inner", "inner_helper") in call_pairs
assert ("outer", "inner") in call_pairs
def test_method_call_in_class(self):
"""Test extraction of method calls within class."""
code = """class Calculator:
def add(self, a, b):
return a + b
def compute(self, x, y):
result = self.add(x, y)
return result
"""
analyzer = GraphAnalyzer("python")
relationships = analyzer.analyze_file(code, Path("test.py"))
# Should find compute -> add
assert len(relationships) == 1
rel = relationships[0]
assert rel.source_symbol == "compute"
assert rel.target_symbol == "add"
def test_module_level_call(self):
"""Test extraction of module-level function calls."""
code = """def setup():
pass
setup()
"""
analyzer = GraphAnalyzer("python")
relationships = analyzer.analyze_file(code, Path("test.py"))
# Should find <module> -> setup
assert len(relationships) == 1
rel = relationships[0]
assert rel.source_symbol == "<module>"
assert rel.target_symbol == "setup"
def test_async_function_call(self):
"""Test extraction of calls involving async functions."""
code = """async def fetch_data():
pass
async def process():
await fetch_data()
"""
analyzer = GraphAnalyzer("python")
relationships = analyzer.analyze_file(code, Path("test.py"))
# Should find process -> fetch_data
assert len(relationships) == 1
rel = relationships[0]
assert rel.source_symbol == "process"
assert rel.target_symbol == "fetch_data"
def test_complex_python_file(self):
"""Test extraction from realistic Python file with multiple patterns."""
code = """class DataProcessor:
def __init__(self):
self.data = []
def load(self, filename):
self.data = read_file(filename)
def process(self):
self.validate()
self.transform()
def validate(self):
pass
def transform(self):
pass
def read_file(filename):
pass
def main():
processor = DataProcessor()
processor.load("data.txt")
processor.process()
main()
"""
analyzer = GraphAnalyzer("python")
relationships = analyzer.analyze_file(code, Path("test.py"))
# Extract call pairs
call_pairs = {(rel.source_symbol, rel.target_symbol) for rel in relationships}
# Expected relationships
expected = {
("load", "read_file"),
("process", "validate"),
("process", "transform"),
("main", "DataProcessor"),
("main", "load"),
("main", "process"),
("<module>", "main"),
}
# Should find all expected relationships
assert call_pairs >= expected
def test_empty_file(self):
"""Test handling of empty file."""
code = ""
analyzer = GraphAnalyzer("python")
relationships = analyzer.analyze_file(code, Path("test.py"))
assert len(relationships) == 0
def test_file_with_no_calls(self):
"""Test handling of file with definitions but no calls."""
code = """def func1():
pass
def func2():
pass
"""
analyzer = GraphAnalyzer("python")
relationships = analyzer.analyze_file(code, Path("test.py"))
assert len(relationships) == 0
@pytest.mark.skipif(not TREE_SITTER_JS_AVAILABLE, reason="tree-sitter-javascript not installed")
class TestJavaScriptGraphAnalyzer:
"""Tests for JavaScript relationship extraction."""
def test_simple_function_call(self):
"""Test extraction of simple JavaScript function call."""
code = """function helper() {}
function main() {
helper();
}
"""
analyzer = GraphAnalyzer("javascript")
relationships = analyzer.analyze_file(code, Path("test.js"))
# Should find main -> helper call
assert len(relationships) == 1
rel = relationships[0]
assert rel.source_symbol == "main"
assert rel.target_symbol == "helper"
assert rel.relationship_type == "call"
def test_arrow_function_call(self):
"""Test extraction of calls from arrow functions."""
code = """const helper = () => {};
const main = () => {
helper();
};
"""
analyzer = GraphAnalyzer("javascript")
relationships = analyzer.analyze_file(code, Path("test.js"))
# Should find main -> helper call
assert len(relationships) == 1
rel = relationships[0]
assert rel.source_symbol == "main"
assert rel.target_symbol == "helper"
def test_class_method_call(self):
"""Test extraction of method calls in JavaScript class."""
code = """class Calculator {
add(a, b) {
return a + b;
}
compute(x, y) {
return this.add(x, y);
}
}
"""
analyzer = GraphAnalyzer("javascript")
relationships = analyzer.analyze_file(code, Path("test.js"))
# Should find compute -> add
assert len(relationships) == 1
rel = relationships[0]
assert rel.source_symbol == "compute"
assert rel.target_symbol == "add"
def test_complex_javascript_file(self):
"""Test extraction from realistic JavaScript file."""
code = """function readFile(filename) {
return "";
}
class DataProcessor {
constructor() {
this.data = [];
}
load(filename) {
this.data = readFile(filename);
}
process() {
this.validate();
this.transform();
}
validate() {}
transform() {}
}
function main() {
const processor = new DataProcessor();
processor.load("data.txt");
processor.process();
}
main();
"""
analyzer = GraphAnalyzer("javascript")
relationships = analyzer.analyze_file(code, Path("test.js"))
# Extract call pairs
call_pairs = {(rel.source_symbol, rel.target_symbol) for rel in relationships}
# Expected relationships (note: constructor calls like "new DataProcessor()" are not tracked)
expected = {
("load", "readFile"),
("process", "validate"),
("process", "transform"),
("main", "load"),
("main", "process"),
("<module>", "main"),
}
# Should find all expected relationships
assert call_pairs >= expected
class TestGraphAnalyzerEdgeCases:
"""Edge case tests for GraphAnalyzer."""
@pytest.mark.skipif(not TREE_SITTER_PYTHON_AVAILABLE, reason="tree-sitter-python not installed")
def test_unavailable_language(self):
"""Test handling of unsupported language."""
code = "some code"
analyzer = GraphAnalyzer("rust")
relationships = analyzer.analyze_file(code, Path("test.rs"))
assert len(relationships) == 0
@pytest.mark.skipif(not TREE_SITTER_PYTHON_AVAILABLE, reason="tree-sitter-python not installed")
def test_malformed_python_code(self):
"""Test handling of malformed Python code."""
code = "def broken(\n pass"
analyzer = GraphAnalyzer("python")
# Should not crash
relationships = analyzer.analyze_file(code, Path("test.py"))
assert isinstance(relationships, list)
@pytest.mark.skipif(not TREE_SITTER_PYTHON_AVAILABLE, reason="tree-sitter-python not installed")
def test_file_path_in_relationship(self):
"""Test that file path is correctly set in relationships."""
code = """def foo():
pass
def bar():
foo()
"""
test_path = Path("test.py")
analyzer = GraphAnalyzer("python")
relationships = analyzer.analyze_file(code, test_path)
assert len(relationships) == 1
rel = relationships[0]
assert rel.source_file == str(test_path.resolve())
assert rel.target_file is None # Intra-file
@pytest.mark.skipif(not TREE_SITTER_PYTHON_AVAILABLE, reason="tree-sitter-python not installed")
def test_performance_large_file(self):
"""Test performance on larger file (1000 lines)."""
import time
# Generate file with many functions and calls
lines = []
for i in range(100):
lines.append(f"def func_{i}():")
if i > 0:
lines.append(f" func_{i-1}()")
else:
lines.append(" pass")
code = "\n".join(lines)
analyzer = GraphAnalyzer("python")
start_time = time.time()
relationships = analyzer.analyze_file(code, Path("test.py"))
elapsed_ms = (time.time() - start_time) * 1000
# Should complete in under 500ms
assert elapsed_ms < 500
# Should find 99 calls (func_1 -> func_0, func_2 -> func_1, ...)
assert len(relationships) == 99
@pytest.mark.skipif(not TREE_SITTER_PYTHON_AVAILABLE, reason="tree-sitter-python not installed")
def test_call_accuracy_rate(self):
"""Test >95% accuracy on known call graph."""
code = """def a(): pass
def b(): pass
def c(): pass
def d(): pass
def e(): pass
def test1():
a()
b()
def test2():
c()
d()
def test3():
e()
def main():
test1()
test2()
test3()
"""
analyzer = GraphAnalyzer("python")
relationships = analyzer.analyze_file(code, Path("test.py"))
# Expected calls: test1->a, test1->b, test2->c, test2->d, test3->e, main->test1, main->test2, main->test3
expected_calls = {
("test1", "a"),
("test1", "b"),
("test2", "c"),
("test2", "d"),
("test3", "e"),
("main", "test1"),
("main", "test2"),
("main", "test3"),
}
found_calls = {(rel.source_symbol, rel.target_symbol) for rel in relationships}
# Calculate accuracy
correct = len(expected_calls & found_calls)
total = len(expected_calls)
accuracy = (correct / total) * 100 if total > 0 else 0
# Should have >95% accuracy
assert accuracy >= 95.0
assert correct == total # Should be 100% for this simple case

View File

@@ -0,0 +1,392 @@
"""End-to-end tests for graph search CLI commands."""
import tempfile
from pathlib import Path
from typer.testing import CliRunner
import pytest
from codexlens.cli.commands import app
from codexlens.storage.sqlite_store import SQLiteStore
from codexlens.storage.registry import RegistryStore
from codexlens.storage.path_mapper import PathMapper
from codexlens.entities import IndexedFile, Symbol, CodeRelationship
runner = CliRunner()
@pytest.fixture
def temp_project():
"""Create a temporary project with indexed code and relationships."""
with tempfile.TemporaryDirectory() as tmpdir:
project_root = Path(tmpdir) / "test_project"
project_root.mkdir()
# Create test Python files
(project_root / "main.py").write_text("""
def main():
result = calculate(5, 3)
print(result)
def calculate(a, b):
return add(a, b)
def add(x, y):
return x + y
""")
(project_root / "utils.py").write_text("""
class BaseClass:
def method(self):
pass
class DerivedClass(BaseClass):
def method(self):
super().method()
helper()
def helper():
return True
""")
# Create a custom index directory for graph testing
# Skip the standard init to avoid schema conflicts
mapper = PathMapper()
index_root = mapper.source_to_index_dir(project_root)
index_root.mkdir(parents=True, exist_ok=True)
test_db = index_root / "_index.db"
# Register project manually
registry = RegistryStore()
registry.initialize()
project_info = registry.register_project(
source_root=project_root,
index_root=index_root
)
registry.register_dir(
project_id=project_info.id,
source_path=project_root,
index_path=test_db,
depth=0,
files_count=2
)
# Initialize the store with proper SQLiteStore schema and add files
with SQLiteStore(test_db) as store:
# Read and add files to the store
main_content = (project_root / "main.py").read_text()
utils_content = (project_root / "utils.py").read_text()
main_indexed = IndexedFile(
path=str(project_root / "main.py"),
language="python",
symbols=[
Symbol(name="main", kind="function", range=(2, 4)),
Symbol(name="calculate", kind="function", range=(6, 7)),
Symbol(name="add", kind="function", range=(9, 10))
]
)
utils_indexed = IndexedFile(
path=str(project_root / "utils.py"),
language="python",
symbols=[
Symbol(name="BaseClass", kind="class", range=(2, 4)),
Symbol(name="DerivedClass", kind="class", range=(6, 9)),
Symbol(name="helper", kind="function", range=(11, 12))
]
)
store.add_file(main_indexed, main_content)
store.add_file(utils_indexed, utils_content)
with SQLiteStore(test_db) as store:
# Add relationships for main.py
main_file = project_root / "main.py"
relationships_main = [
CodeRelationship(
source_symbol="main",
target_symbol="calculate",
relationship_type="call",
source_file=str(main_file),
source_line=3,
target_file=str(main_file)
),
CodeRelationship(
source_symbol="calculate",
target_symbol="add",
relationship_type="call",
source_file=str(main_file),
source_line=7,
target_file=str(main_file)
),
]
store.add_relationships(main_file, relationships_main)
# Add relationships for utils.py
utils_file = project_root / "utils.py"
relationships_utils = [
CodeRelationship(
source_symbol="DerivedClass",
target_symbol="BaseClass",
relationship_type="inherits",
source_file=str(utils_file),
source_line=5,
target_file=str(utils_file)
),
CodeRelationship(
source_symbol="DerivedClass.method",
target_symbol="helper",
relationship_type="call",
source_file=str(utils_file),
source_line=8,
target_file=str(utils_file)
),
]
store.add_relationships(utils_file, relationships_utils)
registry.close()
yield project_root
class TestGraphCallers:
"""Test callers query type."""
def test_find_callers_basic(self, temp_project):
"""Test finding functions that call a given function."""
result = runner.invoke(app, [
"graph",
"callers",
"add",
"--path", str(temp_project)
])
assert result.exit_code == 0
assert "calculate" in result.stdout
assert "Callers of 'add'" in result.stdout
def test_find_callers_json_mode(self, temp_project):
"""Test callers query with JSON output."""
result = runner.invoke(app, [
"graph",
"callers",
"add",
"--path", str(temp_project),
"--json"
])
assert result.exit_code == 0
assert "success" in result.stdout
assert "relationships" in result.stdout
def test_find_callers_no_results(self, temp_project):
"""Test callers query when no callers exist."""
result = runner.invoke(app, [
"graph",
"callers",
"nonexistent_function",
"--path", str(temp_project)
])
assert result.exit_code == 0
assert "No callers found" in result.stdout or "0 found" in result.stdout
class TestGraphCallees:
"""Test callees query type."""
def test_find_callees_basic(self, temp_project):
"""Test finding functions called by a given function."""
result = runner.invoke(app, [
"graph",
"callees",
"main",
"--path", str(temp_project)
])
assert result.exit_code == 0
assert "calculate" in result.stdout
assert "Callees of 'main'" in result.stdout
def test_find_callees_chain(self, temp_project):
"""Test finding callees in a call chain."""
result = runner.invoke(app, [
"graph",
"callees",
"calculate",
"--path", str(temp_project)
])
assert result.exit_code == 0
assert "add" in result.stdout
def test_find_callees_json_mode(self, temp_project):
"""Test callees query with JSON output."""
result = runner.invoke(app, [
"graph",
"callees",
"main",
"--path", str(temp_project),
"--json"
])
assert result.exit_code == 0
assert "success" in result.stdout
class TestGraphInheritance:
"""Test inheritance query type."""
def test_find_inheritance_basic(self, temp_project):
"""Test finding inheritance relationships."""
result = runner.invoke(app, [
"graph",
"inheritance",
"BaseClass",
"--path", str(temp_project)
])
assert result.exit_code == 0
assert "DerivedClass" in result.stdout
assert "Inheritance relationships" in result.stdout
def test_find_inheritance_derived(self, temp_project):
"""Test finding inheritance from derived class perspective."""
result = runner.invoke(app, [
"graph",
"inheritance",
"DerivedClass",
"--path", str(temp_project)
])
assert result.exit_code == 0
assert "BaseClass" in result.stdout
def test_find_inheritance_json_mode(self, temp_project):
"""Test inheritance query with JSON output."""
result = runner.invoke(app, [
"graph",
"inheritance",
"BaseClass",
"--path", str(temp_project),
"--json"
])
assert result.exit_code == 0
assert "success" in result.stdout
class TestGraphValidation:
"""Test query validation and error handling."""
def test_invalid_query_type(self, temp_project):
"""Test error handling for invalid query type."""
result = runner.invoke(app, [
"graph",
"invalid_type",
"symbol",
"--path", str(temp_project)
])
assert result.exit_code == 1
assert "Invalid query type" in result.stdout
def test_invalid_path(self):
"""Test error handling for non-existent path."""
result = runner.invoke(app, [
"graph",
"callers",
"symbol",
"--path", "/nonexistent/path"
])
# Should handle gracefully (may exit with error or return empty results)
assert result.exit_code in [0, 1]
class TestGraphPerformance:
"""Test graph query performance requirements."""
def test_query_response_time(self, temp_project):
"""Verify graph queries complete in under 1 second."""
import time
start = time.time()
result = runner.invoke(app, [
"graph",
"callers",
"add",
"--path", str(temp_project)
])
elapsed = time.time() - start
assert result.exit_code == 0
assert elapsed < 1.0, f"Query took {elapsed:.2f}s, expected <1s"
def test_multiple_query_types(self, temp_project):
"""Test all three query types complete successfully."""
import time
queries = [
("callers", "add"),
("callees", "main"),
("inheritance", "BaseClass")
]
total_start = time.time()
for query_type, symbol in queries:
result = runner.invoke(app, [
"graph",
query_type,
symbol,
"--path", str(temp_project)
])
assert result.exit_code == 0
total_elapsed = time.time() - total_start
assert total_elapsed < 3.0, f"All queries took {total_elapsed:.2f}s, expected <3s"
class TestGraphOptions:
"""Test graph command options."""
def test_limit_option(self, temp_project):
"""Test limit option works correctly."""
result = runner.invoke(app, [
"graph",
"callers",
"add",
"--path", str(temp_project),
"--limit", "1"
])
assert result.exit_code == 0
def test_depth_option(self, temp_project):
"""Test depth option works correctly."""
result = runner.invoke(app, [
"graph",
"callers",
"add",
"--path", str(temp_project),
"--depth", "0"
])
assert result.exit_code == 0
def test_verbose_option(self, temp_project):
"""Test verbose option works correctly."""
result = runner.invoke(app, [
"graph",
"callers",
"add",
"--path", str(temp_project),
"--verbose"
])
assert result.exit_code == 0
if __name__ == "__main__":
pytest.main([__file__, "-v"])

View File

@@ -0,0 +1,355 @@
"""Tests for code relationship storage."""
import sqlite3
import tempfile
from pathlib import Path
import pytest
from codexlens.entities import CodeRelationship, IndexedFile, Symbol
from codexlens.storage.migration_manager import MigrationManager
from codexlens.storage.sqlite_store import SQLiteStore
@pytest.fixture
def temp_db():
"""Create a temporary database for testing."""
with tempfile.TemporaryDirectory() as tmpdir:
db_path = Path(tmpdir) / "test.db"
yield db_path
@pytest.fixture
def store(temp_db):
"""Create a SQLiteStore with migrations applied."""
store = SQLiteStore(temp_db)
store.initialize()
# Manually apply migration_003 (code_relationships table)
conn = store._get_connection()
cursor = conn.cursor()
cursor.execute(
"""
CREATE TABLE IF NOT EXISTS code_relationships (
id INTEGER PRIMARY KEY,
source_symbol_id INTEGER NOT NULL,
target_qualified_name TEXT NOT NULL,
relationship_type TEXT NOT NULL,
source_line INTEGER NOT NULL,
target_file TEXT,
FOREIGN KEY (source_symbol_id) REFERENCES symbols (id) ON DELETE CASCADE
)
"""
)
cursor.execute(
"CREATE INDEX IF NOT EXISTS idx_relationships_source ON code_relationships (source_symbol_id)"
)
cursor.execute(
"CREATE INDEX IF NOT EXISTS idx_relationships_target ON code_relationships (target_qualified_name)"
)
cursor.execute(
"CREATE INDEX IF NOT EXISTS idx_relationships_type ON code_relationships (relationship_type)"
)
cursor.execute(
"CREATE INDEX IF NOT EXISTS idx_relationships_source_line ON code_relationships (source_line)"
)
conn.commit()
yield store
# Cleanup
store.close()
def test_relationship_table_created(store):
"""Test that the code_relationships table is created by migration."""
conn = store._get_connection()
cursor = conn.cursor()
# Check table exists
cursor.execute(
"SELECT name FROM sqlite_master WHERE type='table' AND name='code_relationships'"
)
result = cursor.fetchone()
assert result is not None, "code_relationships table should exist"
# Check indexes exist
cursor.execute(
"SELECT name FROM sqlite_master WHERE type='index' AND tbl_name='code_relationships'"
)
indexes = [row[0] for row in cursor.fetchall()]
assert "idx_relationships_source" in indexes
assert "idx_relationships_target" in indexes
assert "idx_relationships_type" in indexes
def test_add_relationships(store):
"""Test storing code relationships."""
# First add a file with symbols
indexed_file = IndexedFile(
path=str(Path(__file__).parent / "sample.py"),
language="python",
symbols=[
Symbol(name="foo", kind="function", range=(1, 5)),
Symbol(name="bar", kind="function", range=(7, 10)),
]
)
content = """def foo():
bar()
baz()
def bar():
print("hello")
"""
store.add_file(indexed_file, content)
# Add relationships
relationships = [
CodeRelationship(
source_symbol="foo",
target_symbol="bar",
relationship_type="call",
source_file=indexed_file.path,
target_file=None,
source_line=2
),
CodeRelationship(
source_symbol="foo",
target_symbol="baz",
relationship_type="call",
source_file=indexed_file.path,
target_file=None,
source_line=3
),
]
store.add_relationships(indexed_file.path, relationships)
# Verify relationships were stored
conn = store._get_connection()
count = conn.execute("SELECT COUNT(*) FROM code_relationships").fetchone()[0]
assert count == 2, "Should have stored 2 relationships"
def test_query_relationships_by_target(store):
"""Test querying relationships by target symbol (find callers)."""
# Setup: Add file and relationships
file_path = str(Path(__file__).parent / "sample.py")
# Content: Line 1-2: foo(), Line 4-5: bar(), Line 7-8: main()
indexed_file = IndexedFile(
path=file_path,
language="python",
symbols=[
Symbol(name="foo", kind="function", range=(1, 2)),
Symbol(name="bar", kind="function", range=(4, 5)),
Symbol(name="main", kind="function", range=(7, 8)),
]
)
content = "def foo():\n bar()\n\ndef bar():\n pass\n\ndef main():\n bar()\n"
store.add_file(indexed_file, content)
relationships = [
CodeRelationship(
source_symbol="foo",
target_symbol="bar",
relationship_type="call",
source_file=file_path,
target_file=None,
source_line=2 # Call inside foo (line 2)
),
CodeRelationship(
source_symbol="main",
target_symbol="bar",
relationship_type="call",
source_file=file_path,
target_file=None,
source_line=8 # Call inside main (line 8)
),
]
store.add_relationships(file_path, relationships)
# Query: Find all callers of "bar"
callers = store.query_relationships_by_target("bar")
assert len(callers) == 2, "Should find 2 callers of bar"
assert any(r["source_symbol"] == "foo" for r in callers)
assert any(r["source_symbol"] == "main" for r in callers)
assert all(r["target_symbol"] == "bar" for r in callers)
assert all(r["relationship_type"] == "call" for r in callers)
def test_query_relationships_by_source(store):
"""Test querying relationships by source symbol (find callees)."""
# Setup
file_path = str(Path(__file__).parent / "sample.py")
indexed_file = IndexedFile(
path=file_path,
language="python",
symbols=[
Symbol(name="foo", kind="function", range=(1, 6)),
]
)
content = "def foo():\n bar()\n baz()\n qux()\n"
store.add_file(indexed_file, content)
relationships = [
CodeRelationship(
source_symbol="foo",
target_symbol="bar",
relationship_type="call",
source_file=file_path,
target_file=None,
source_line=2
),
CodeRelationship(
source_symbol="foo",
target_symbol="baz",
relationship_type="call",
source_file=file_path,
target_file=None,
source_line=3
),
CodeRelationship(
source_symbol="foo",
target_symbol="qux",
relationship_type="call",
source_file=file_path,
target_file=None,
source_line=4
),
]
store.add_relationships(file_path, relationships)
# Query: Find all functions called by foo
callees = store.query_relationships_by_source("foo", file_path)
assert len(callees) == 3, "Should find 3 functions called by foo"
targets = {r["target_symbol"] for r in callees}
assert targets == {"bar", "baz", "qux"}
assert all(r["source_symbol"] == "foo" for r in callees)
def test_query_performance(store):
"""Test that relationship queries execute within performance threshold."""
import time
# Setup: Create a file with many relationships
file_path = str(Path(__file__).parent / "large_file.py")
symbols = [Symbol(name=f"func_{i}", kind="function", range=(i*10+1, i*10+5)) for i in range(100)]
indexed_file = IndexedFile(
path=file_path,
language="python",
symbols=symbols
)
content = "\n".join([f"def func_{i}():\n pass\n" for i in range(100)])
store.add_file(indexed_file, content)
# Create many relationships
relationships = []
for i in range(100):
for j in range(10):
relationships.append(
CodeRelationship(
source_symbol=f"func_{i}",
target_symbol=f"target_{j}",
relationship_type="call",
source_file=file_path,
target_file=None,
source_line=i*10 + 1
)
)
store.add_relationships(file_path, relationships)
# Query and measure time
start = time.time()
results = store.query_relationships_by_target("target_5")
elapsed_ms = (time.time() - start) * 1000
assert len(results) == 100, "Should find 100 callers"
assert elapsed_ms < 50, f"Query took {elapsed_ms:.1f}ms, should be <50ms"
def test_stats_includes_relationships(store):
"""Test that stats() includes relationship count."""
# Add a file with relationships
file_path = str(Path(__file__).parent / "sample.py")
indexed_file = IndexedFile(
path=file_path,
language="python",
symbols=[Symbol(name="foo", kind="function", range=(1, 5))]
)
store.add_file(indexed_file, "def foo():\n bar()\n")
relationships = [
CodeRelationship(
source_symbol="foo",
target_symbol="bar",
relationship_type="call",
source_file=file_path,
target_file=None,
source_line=2
)
]
store.add_relationships(file_path, relationships)
# Check stats
stats = store.stats()
assert "relationships" in stats
assert stats["relationships"] == 1
assert stats["files"] == 1
assert stats["symbols"] == 1
def test_update_relationships_on_file_reindex(store):
"""Test that relationships are updated when file is re-indexed."""
file_path = str(Path(__file__).parent / "sample.py")
# Initial index
indexed_file = IndexedFile(
path=file_path,
language="python",
symbols=[Symbol(name="foo", kind="function", range=(1, 3))]
)
store.add_file(indexed_file, "def foo():\n bar()\n")
relationships = [
CodeRelationship(
source_symbol="foo",
target_symbol="bar",
relationship_type="call",
source_file=file_path,
target_file=None,
source_line=2
)
]
store.add_relationships(file_path, relationships)
# Re-index with different relationships
new_relationships = [
CodeRelationship(
source_symbol="foo",
target_symbol="baz",
relationship_type="call",
source_file=file_path,
target_file=None,
source_line=2
)
]
store.add_relationships(file_path, new_relationships)
# Verify old relationships are replaced
all_rels = store.query_relationships_by_source("foo", file_path)
assert len(all_rels) == 1
assert all_rels[0]["target_symbol"] == "baz"

View File

@@ -0,0 +1,561 @@
"""Tests for Hybrid Docstring Chunker."""
import pytest
from codexlens.entities import SemanticChunk, Symbol
from codexlens.semantic.chunker import (
ChunkConfig,
Chunker,
DocstringExtractor,
HybridChunker,
)
class TestDocstringExtractor:
"""Tests for DocstringExtractor class."""
def test_extract_single_line_python_docstring(self):
"""Test extraction of single-line Python docstring."""
content = '''def hello():
"""This is a docstring."""
return True
'''
docstrings = DocstringExtractor.extract_python_docstrings(content)
assert len(docstrings) == 1
assert docstrings[0][1] == 2 # start_line
assert docstrings[0][2] == 2 # end_line
assert '"""This is a docstring."""' in docstrings[0][0]
def test_extract_multi_line_python_docstring(self):
"""Test extraction of multi-line Python docstring."""
content = '''def process():
"""
This is a multi-line
docstring with details.
"""
return 42
'''
docstrings = DocstringExtractor.extract_python_docstrings(content)
assert len(docstrings) == 1
assert docstrings[0][1] == 2 # start_line
assert docstrings[0][2] == 5 # end_line
assert "multi-line" in docstrings[0][0]
def test_extract_multiple_python_docstrings(self):
"""Test extraction of multiple docstrings from same file."""
content = '''"""Module docstring."""
def func1():
"""Function 1 docstring."""
pass
class MyClass:
"""Class docstring."""
def method(self):
"""Method docstring."""
pass
'''
docstrings = DocstringExtractor.extract_python_docstrings(content)
assert len(docstrings) == 4
lines = [d[1] for d in docstrings]
assert 1 in lines # Module docstring
assert 4 in lines # func1 docstring
assert 8 in lines # Class docstring
assert 11 in lines # method docstring
def test_extract_python_docstring_single_quotes(self):
"""Test extraction with single quote docstrings."""
content = """def test():
'''Single quote docstring.'''
return None
"""
docstrings = DocstringExtractor.extract_python_docstrings(content)
assert len(docstrings) == 1
assert "Single quote docstring" in docstrings[0][0]
def test_extract_jsdoc_single_comment(self):
"""Test extraction of single JSDoc comment."""
content = '''/**
* This is a JSDoc comment
* @param {string} name
*/
function hello(name) {
return name;
}
'''
comments = DocstringExtractor.extract_jsdoc_comments(content)
assert len(comments) == 1
assert comments[0][1] == 1 # start_line
assert comments[0][2] == 4 # end_line
assert "JSDoc comment" in comments[0][0]
def test_extract_multiple_jsdoc_comments(self):
"""Test extraction of multiple JSDoc comments."""
content = '''/**
* Function 1
*/
function func1() {}
/**
* Class description
*/
class MyClass {
/**
* Method description
*/
method() {}
}
'''
comments = DocstringExtractor.extract_jsdoc_comments(content)
assert len(comments) == 3
def test_extract_docstrings_unsupported_language(self):
"""Test that unsupported languages return empty list."""
content = "// Some code"
docstrings = DocstringExtractor.extract_docstrings(content, "ruby")
assert len(docstrings) == 0
def test_extract_docstrings_empty_content(self):
"""Test extraction from empty content."""
docstrings = DocstringExtractor.extract_python_docstrings("")
assert len(docstrings) == 0
class TestHybridChunker:
"""Tests for HybridChunker class."""
def test_hybrid_chunker_initialization(self):
"""Test HybridChunker initialization with defaults."""
chunker = HybridChunker()
assert chunker.config is not None
assert chunker.base_chunker is not None
assert chunker.docstring_extractor is not None
def test_hybrid_chunker_custom_config(self):
"""Test HybridChunker with custom config."""
config = ChunkConfig(max_chunk_size=500, min_chunk_size=20)
chunker = HybridChunker(config=config)
assert chunker.config.max_chunk_size == 500
assert chunker.config.min_chunk_size == 20
def test_hybrid_chunker_isolates_docstrings(self):
"""Test that hybrid chunker isolates docstrings into separate chunks."""
config = ChunkConfig(min_chunk_size=10)
chunker = HybridChunker(config=config)
content = '''"""Module-level docstring."""
def hello():
"""Function docstring."""
return "world"
def goodbye():
"""Another docstring."""
return "farewell"
'''
symbols = [
Symbol(name="hello", kind="function", range=(3, 5)),
Symbol(name="goodbye", kind="function", range=(7, 9)),
]
chunks = chunker.chunk_file(content, symbols, "test.py", "python")
# Should have 3 docstring chunks + 2 code chunks = 5 total
docstring_chunks = [c for c in chunks if c.metadata.get("chunk_type") == "docstring"]
code_chunks = [c for c in chunks if c.metadata.get("chunk_type") == "code"]
assert len(docstring_chunks) == 3
assert len(code_chunks) == 2
assert all(c.metadata["strategy"] == "hybrid" for c in chunks)
def test_hybrid_chunker_docstring_isolation_percentage(self):
"""Test that >98% of docstrings are isolated correctly."""
config = ChunkConfig(min_chunk_size=5)
chunker = HybridChunker(config=config)
# Create content with 10 docstrings
lines = []
lines.append('"""Module docstring."""\n')
lines.append('\n')
for i in range(10):
lines.append(f'def func{i}():\n')
lines.append(f' """Docstring for func{i}."""\n')
lines.append(f' return {i}\n')
lines.append('\n')
content = "".join(lines)
symbols = [
Symbol(name=f"func{i}", kind="function", range=(3 + i*4, 5 + i*4))
for i in range(10)
]
chunks = chunker.chunk_file(content, symbols, "test.py", "python")
docstring_chunks = [c for c in chunks if c.metadata.get("chunk_type") == "docstring"]
# We have 11 docstrings total (1 module + 10 functions)
# Verify >98% isolation (at least 10.78 out of 11)
isolation_rate = len(docstring_chunks) / 11
assert isolation_rate >= 0.98, f"Docstring isolation rate {isolation_rate:.2%} < 98%"
def test_hybrid_chunker_javascript_jsdoc(self):
"""Test hybrid chunker with JavaScript JSDoc comments."""
config = ChunkConfig(min_chunk_size=10)
chunker = HybridChunker(config=config)
content = '''/**
* Main function description
*/
function main() {
return 42;
}
/**
* Helper function
*/
function helper() {
return 0;
}
'''
symbols = [
Symbol(name="main", kind="function", range=(4, 6)),
Symbol(name="helper", kind="function", range=(11, 13)),
]
chunks = chunker.chunk_file(content, symbols, "test.js", "javascript")
docstring_chunks = [c for c in chunks if c.metadata.get("chunk_type") == "docstring"]
code_chunks = [c for c in chunks if c.metadata.get("chunk_type") == "code"]
assert len(docstring_chunks) == 2
assert len(code_chunks) == 2
def test_hybrid_chunker_no_docstrings(self):
"""Test hybrid chunker with code containing no docstrings."""
config = ChunkConfig(min_chunk_size=10)
chunker = HybridChunker(config=config)
content = '''def hello():
return "world"
def goodbye():
return "farewell"
'''
symbols = [
Symbol(name="hello", kind="function", range=(1, 2)),
Symbol(name="goodbye", kind="function", range=(4, 5)),
]
chunks = chunker.chunk_file(content, symbols, "test.py", "python")
# All chunks should be code chunks
assert all(c.metadata.get("chunk_type") == "code" for c in chunks)
assert len(chunks) == 2
def test_hybrid_chunker_preserves_metadata(self):
"""Test that hybrid chunker preserves all required metadata."""
config = ChunkConfig(min_chunk_size=5)
chunker = HybridChunker(config=config)
content = '''"""Module doc."""
def test():
"""Test doc."""
pass
'''
symbols = [Symbol(name="test", kind="function", range=(3, 5))]
chunks = chunker.chunk_file(content, symbols, "/path/to/file.py", "python")
for chunk in chunks:
assert "file" in chunk.metadata
assert "language" in chunk.metadata
assert "chunk_type" in chunk.metadata
assert "start_line" in chunk.metadata
assert "end_line" in chunk.metadata
assert "strategy" in chunk.metadata
assert chunk.metadata["strategy"] == "hybrid"
def test_hybrid_chunker_no_symbols_fallback(self):
"""Test hybrid chunker falls back to sliding window when no symbols."""
config = ChunkConfig(min_chunk_size=5, max_chunk_size=100)
chunker = HybridChunker(config=config)
content = '''"""Module docstring."""
# Just some comments
x = 42
y = 100
'''
chunks = chunker.chunk_file(content, [], "test.py", "python")
# Should have 1 docstring chunk + sliding window chunks for remaining code
docstring_chunks = [c for c in chunks if c.metadata.get("chunk_type") == "docstring"]
code_chunks = [c for c in chunks if c.metadata.get("chunk_type") == "code"]
assert len(docstring_chunks) == 1
assert len(code_chunks) >= 0 # May or may not have code chunks depending on size
def test_get_excluded_line_ranges(self):
"""Test _get_excluded_line_ranges helper method."""
chunker = HybridChunker()
docstrings = [
("doc1", 1, 3),
("doc2", 5, 7),
("doc3", 10, 10),
]
excluded = chunker._get_excluded_line_ranges(docstrings)
assert 1 in excluded
assert 2 in excluded
assert 3 in excluded
assert 4 not in excluded
assert 5 in excluded
assert 6 in excluded
assert 7 in excluded
assert 8 not in excluded
assert 9 not in excluded
assert 10 in excluded
def test_filter_symbols_outside_docstrings(self):
"""Test _filter_symbols_outside_docstrings helper method."""
chunker = HybridChunker()
symbols = [
Symbol(name="func1", kind="function", range=(1, 5)),
Symbol(name="func2", kind="function", range=(10, 15)),
Symbol(name="func3", kind="function", range=(20, 25)),
]
# Exclude lines 1-5 (func1) and 10-12 (partial overlap with func2)
excluded_lines = set(range(1, 6)) | set(range(10, 13))
filtered = chunker._filter_symbols_outside_docstrings(symbols, excluded_lines)
# func1 should be filtered out (completely within excluded)
# func2 should remain (partial overlap)
# func3 should remain (no overlap)
assert len(filtered) == 2
names = [s.name for s in filtered]
assert "func1" not in names
assert "func2" in names
assert "func3" in names
excluded = chunker._get_excluded_line_ranges(docstrings)
assert 1 in excluded
assert 2 in excluded
assert 3 in excluded
assert 4 not in excluded
assert 5 in excluded
assert 6 in excluded
assert 7 in excluded
assert 8 not in excluded
assert 9 not in excluded
assert 10 in excluded
def test_filter_symbols_outside_docstrings(self):
"""Test _filter_symbols_outside_docstrings helper method."""
chunker = HybridChunker()
symbols = [
Symbol(name="func1", kind="function", range=(1, 5)),
Symbol(name="func2", kind="function", range=(10, 15)),
Symbol(name="func3", kind="function", range=(20, 25)),
]
# Exclude lines 1-5 (func1) and 10-12 (partial overlap with func2)
excluded_lines = set(range(1, 6)) | set(range(10, 13))
filtered = chunker._filter_symbols_outside_docstrings(symbols, excluded_lines)
# func1 should be filtered out (completely within excluded)
# func2 should remain (partial overlap)
# func3 should remain (no overlap)
assert len(filtered) == 2
names = [s.name for s in filtered]
assert "func1" not in names
assert "func2" in names
assert "func3" in names
def test_hybrid_chunker_performance_overhead(self):
"""Test that hybrid chunker has <5% overhead vs base chunker."""
import time
config = ChunkConfig(min_chunk_size=5)
# Create content with no docstrings to measure worst-case overhead
lines = []
for i in range(100):
lines.append(f'def func{i}():\n')
lines.append(f' return {i}\n')
lines.append('\n')
content = "".join(lines)
content = '''"""First docstring."""
"""Second docstring."""
"""Third docstring."""
'''
chunks = chunker.chunk_file(content, [], "test.py", "python")
# Should only have docstring chunks
assert all(c.metadata.get("chunk_type") == "docstring" for c in chunks)
assert len(chunks) == 3
class TestChunkConfigStrategy:
"""Tests for strategy field in ChunkConfig."""
def test_chunk_config_default_strategy(self):
"""Test that default strategy is 'auto'."""
config = ChunkConfig()
assert config.strategy == "auto"
def test_chunk_config_custom_strategy(self):
"""Test setting custom strategy."""
config = ChunkConfig(strategy="hybrid")
assert config.strategy == "hybrid"
config = ChunkConfig(strategy="symbol")
assert config.strategy == "symbol"
config = ChunkConfig(strategy="sliding_window")
assert config.strategy == "sliding_window"
class TestHybridChunkerIntegration:
"""Integration tests for hybrid chunker with realistic code."""
def test_realistic_python_module(self):
"""Test hybrid chunker with realistic Python module."""
config = ChunkConfig(min_chunk_size=10)
chunker = HybridChunker(config=config)
content = '''"""
Data processing module for handling user data.
This module provides functions for cleaning and validating user input.
"""
from typing import Dict, Any
def validate_email(email: str) -> bool:
"""
Validate an email address format.
Args:
email: The email address to validate
Returns:
True if valid, False otherwise
"""
import re
pattern = r'^[\\w\\.-]+@[\\w\\.-]+\\.\\w+$'
return bool(re.match(pattern, email))
class UserProfile:
"""
User profile management class.
Handles user data storage and retrieval.
"""
def __init__(self, user_id: int):
"""Initialize user profile with ID."""
self.user_id = user_id
self.data = {}
def update_data(self, data: Dict[str, Any]) -> None:
"""
Update user profile data.
Args:
data: Dictionary of user data to update
"""
self.data.update(data)
'''
symbols = [
Symbol(name="validate_email", kind="function", range=(11, 23)),
Symbol(name="UserProfile", kind="class", range=(26, 44)),
]
chunks = chunker.chunk_file(content, symbols, "users.py", "python")
docstring_chunks = [c for c in chunks if c.metadata.get("chunk_type") == "docstring"]
code_chunks = [c for c in chunks if c.metadata.get("chunk_type") == "code"]
# Verify docstrings are isolated
assert len(docstring_chunks) >= 4 # Module, function, class, methods
assert len(code_chunks) >= 1 # At least one code chunk
# Verify >98% docstring isolation
# Count total docstring lines in original
total_docstring_lines = sum(
d[2] - d[1] + 1
for d in DocstringExtractor.extract_python_docstrings(content)
)
isolated_docstring_lines = sum(
c.metadata["end_line"] - c.metadata["start_line"] + 1
for c in docstring_chunks
)
isolation_rate = isolated_docstring_lines / total_docstring_lines if total_docstring_lines > 0 else 1
assert isolation_rate >= 0.98
def test_hybrid_chunker_performance_overhead(self):
"""Test that hybrid chunker has <5% overhead vs base chunker on files without docstrings."""
import time
config = ChunkConfig(min_chunk_size=5)
# Create larger content with NO docstrings (worst case for hybrid chunker)
lines = []
for i in range(1000):
lines.append(f'def func{i}():\n')
lines.append(f' x = {i}\n')
lines.append(f' y = {i * 2}\n')
lines.append(f' return x + y\n')
lines.append('\n')
content = "".join(lines)
symbols = [
Symbol(name=f"func{i}", kind="function", range=(1 + i*5, 4 + i*5))
for i in range(1000)
]
# Warm up
base_chunker = Chunker(config=config)
base_chunker.chunk_file(content[:100], symbols[:10], "test.py", "python")
hybrid_chunker = HybridChunker(config=config)
hybrid_chunker.chunk_file(content[:100], symbols[:10], "test.py", "python")
# Measure base chunker (3 runs)
base_times = []
for _ in range(3):
start = time.perf_counter()
base_chunker.chunk_file(content, symbols, "test.py", "python")
base_times.append(time.perf_counter() - start)
base_time = sum(base_times) / len(base_times)
# Measure hybrid chunker (3 runs)
hybrid_times = []
for _ in range(3):
start = time.perf_counter()
hybrid_chunker.chunk_file(content, symbols, "test.py", "python")
hybrid_times.append(time.perf_counter() - start)
hybrid_time = sum(hybrid_times) / len(hybrid_times)
# Calculate overhead
overhead = ((hybrid_time - base_time) / base_time) * 100 if base_time > 0 else 0
# Verify <5% overhead
assert overhead < 5.0, f"Overhead {overhead:.2f}% exceeds 5% threshold (base={base_time:.4f}s, hybrid={hybrid_time:.4f}s)"

View File

@@ -829,3 +829,516 @@ class TestEdgeCases:
assert result["/test/file.py"].summary == "Only summary provided"
assert result["/test/file.py"].keywords == []
assert result["/test/file.py"].purpose == ""
# === Chunk Boundary Refinement Tests ===
class TestRefineChunkBoundaries:
"""Tests for refine_chunk_boundaries method."""
def test_refine_skips_docstring_chunks(self):
"""Test that chunks with metadata type='docstring' pass through unchanged."""
enhancer = LLMEnhancer()
chunk = SemanticChunk(
content='"""This is a docstring."""\n' * 100, # Large docstring
embedding=None,
metadata={
"chunk_type": "docstring",
"file": "/test/file.py",
"start_line": 1,
"end_line": 100,
}
)
result = enhancer.refine_chunk_boundaries(chunk, max_chunk_size=500)
# Should return original chunk unchanged
assert len(result) == 1
assert result[0] is chunk
def test_refine_skips_small_chunks(self):
"""Test that chunks under max_chunk_size pass through unchanged."""
enhancer = LLMEnhancer()
small_content = "def small_function():\n return 42"
chunk = SemanticChunk(
content=small_content,
embedding=None,
metadata={
"chunk_type": "code",
"file": "/test/file.py",
"start_line": 1,
"end_line": 2,
}
)
result = enhancer.refine_chunk_boundaries(chunk, max_chunk_size=2000)
# Small chunk should pass through unchanged
assert len(result) == 1
assert result[0] is chunk
@patch.object(LLMEnhancer, "check_available", return_value=True)
@patch.object(LLMEnhancer, "_invoke_ccw_cli")
def test_refine_splits_large_chunks(self, mock_invoke, mock_check):
"""Test that chunks over threshold are split at LLM-suggested points."""
mock_invoke.return_value = {
"success": True,
"stdout": json.dumps({
"split_points": [
{"line": 5, "reason": "end of first function"},
{"line": 10, "reason": "end of second function"}
]
}),
"stderr": "",
"exit_code": 0,
}
enhancer = LLMEnhancer()
# Create large chunk with clear line boundaries
lines = []
for i in range(15):
lines.append(f"def func{i}():\n")
lines.append(f" return {i}\n")
large_content = "".join(lines)
chunk = SemanticChunk(
content=large_content,
embedding=None,
metadata={
"chunk_type": "code",
"file": "/test/file.py",
"start_line": 1,
"end_line": 30,
}
)
result = enhancer.refine_chunk_boundaries(chunk, max_chunk_size=100)
# Should split into multiple chunks
assert len(result) > 1
# All chunks should have refined_by_llm metadata
assert all(c.metadata.get("refined_by_llm") is True for c in result)
# All chunks should preserve file metadata
assert all(c.metadata.get("file") == "/test/file.py" for c in result)
@patch.object(LLMEnhancer, "check_available", return_value=True)
@patch.object(LLMEnhancer, "_invoke_ccw_cli")
def test_refine_handles_empty_split_points(self, mock_invoke, mock_check):
"""Test graceful handling when LLM returns no split points."""
mock_invoke.return_value = {
"success": True,
"stdout": json.dumps({"split_points": []}),
"stderr": "",
"exit_code": 0,
}
enhancer = LLMEnhancer()
large_content = "x" * 3000
chunk = SemanticChunk(
content=large_content,
embedding=None,
metadata={
"chunk_type": "code",
"file": "/test/file.py",
"start_line": 1,
"end_line": 1,
}
)
result = enhancer.refine_chunk_boundaries(chunk, max_chunk_size=1000)
# Should return original chunk when no split points
assert len(result) == 1
assert result[0].content == large_content
def test_refine_disabled_returns_unchanged(self):
"""Test that when config.enabled=False, refinement returns input unchanged."""
config = LLMConfig(enabled=False)
enhancer = LLMEnhancer(config)
large_content = "x" * 3000
chunk = SemanticChunk(
content=large_content,
embedding=None,
metadata={
"chunk_type": "code",
"file": "/test/file.py",
}
)
result = enhancer.refine_chunk_boundaries(chunk, max_chunk_size=1000)
# Should return original chunk when disabled
assert len(result) == 1
assert result[0] is chunk
@patch.object(LLMEnhancer, "check_available", return_value=False)
def test_refine_ccw_unavailable_returns_unchanged(self, mock_check):
"""Test that when CCW is unavailable, refinement returns input unchanged."""
enhancer = LLMEnhancer()
large_content = "x" * 3000
chunk = SemanticChunk(
content=large_content,
embedding=None,
metadata={
"chunk_type": "code",
"file": "/test/file.py",
}
)
result = enhancer.refine_chunk_boundaries(chunk, max_chunk_size=1000)
# Should return original chunk when CCW unavailable
assert len(result) == 1
assert result[0] is chunk
@patch.object(LLMEnhancer, "check_available", return_value=True)
@patch.object(LLMEnhancer, "_invoke_ccw_cli")
def test_refine_fallback_on_primary_failure(self, mock_invoke, mock_check):
"""Test that refinement falls back to secondary tool on primary failure."""
# Primary fails, fallback succeeds
mock_invoke.side_effect = [
{"success": False, "stdout": "", "stderr": "error", "exit_code": 1},
{
"success": True,
"stdout": json.dumps({"split_points": [{"line": 5, "reason": "split"}]}),
"stderr": "",
"exit_code": 0,
},
]
enhancer = LLMEnhancer()
chunk = SemanticChunk(
content="def func():\n pass\n" * 100,
embedding=None,
metadata={
"chunk_type": "code",
"file": "/test/file.py",
"start_line": 1,
"end_line": 200,
}
)
result = enhancer.refine_chunk_boundaries(chunk, max_chunk_size=100)
# Should use fallback tool
assert mock_invoke.call_count == 2
# Should successfully split
assert len(result) > 1
@patch.object(LLMEnhancer, "check_available", return_value=True)
@patch.object(LLMEnhancer, "_invoke_ccw_cli")
def test_refine_returns_original_on_error(self, mock_invoke, mock_check):
"""Test that refinement returns original chunk on error."""
mock_invoke.side_effect = Exception("Unexpected error")
enhancer = LLMEnhancer()
chunk = SemanticChunk(
content="x" * 3000,
embedding=None,
metadata={
"chunk_type": "code",
"file": "/test/file.py",
}
)
result = enhancer.refine_chunk_boundaries(chunk, max_chunk_size=1000)
# Should return original chunk on error
assert len(result) == 1
assert result[0] is chunk
class TestParseSplitPoints:
"""Tests for _parse_split_points helper method."""
def test_parse_valid_split_points(self):
"""Test parsing valid split points from JSON response."""
enhancer = LLMEnhancer()
stdout = json.dumps({
"split_points": [
{"line": 5, "reason": "end of function"},
{"line": 10, "reason": "class boundary"},
{"line": 15, "reason": "method boundary"}
]
})
result = enhancer._parse_split_points(stdout)
assert result == [5, 10, 15]
def test_parse_split_points_with_markdown(self):
"""Test parsing split points wrapped in markdown."""
enhancer = LLMEnhancer()
stdout = '''```json
{
"split_points": [
{"line": 5, "reason": "split"},
{"line": 10, "reason": "split"}
]
}
```'''
result = enhancer._parse_split_points(stdout)
assert result == [5, 10]
def test_parse_split_points_deduplicates(self):
"""Test that duplicate line numbers are deduplicated."""
enhancer = LLMEnhancer()
stdout = json.dumps({
"split_points": [
{"line": 5, "reason": "split"},
{"line": 5, "reason": "duplicate"},
{"line": 10, "reason": "split"}
]
})
result = enhancer._parse_split_points(stdout)
assert result == [5, 10]
def test_parse_split_points_sorts(self):
"""Test that split points are sorted."""
enhancer = LLMEnhancer()
stdout = json.dumps({
"split_points": [
{"line": 15, "reason": "split"},
{"line": 5, "reason": "split"},
{"line": 10, "reason": "split"}
]
})
result = enhancer._parse_split_points(stdout)
assert result == [5, 10, 15]
def test_parse_split_points_ignores_invalid(self):
"""Test that invalid split points are ignored."""
enhancer = LLMEnhancer()
stdout = json.dumps({
"split_points": [
{"line": 5, "reason": "valid"},
{"line": -1, "reason": "negative"},
{"line": 0, "reason": "zero"},
{"line": "not_a_number", "reason": "string"},
{"reason": "missing line field"},
10 # Not a dict
]
})
result = enhancer._parse_split_points(stdout)
assert result == [5]
def test_parse_split_points_empty_list(self):
"""Test parsing empty split points list."""
enhancer = LLMEnhancer()
stdout = json.dumps({"split_points": []})
result = enhancer._parse_split_points(stdout)
assert result == []
def test_parse_split_points_no_json(self):
"""Test parsing when no JSON is found."""
enhancer = LLMEnhancer()
stdout = "No JSON here at all"
result = enhancer._parse_split_points(stdout)
assert result == []
def test_parse_split_points_invalid_json(self):
"""Test parsing invalid JSON."""
enhancer = LLMEnhancer()
stdout = '{"split_points": [invalid json}'
result = enhancer._parse_split_points(stdout)
assert result == []
class TestSplitChunkAtPoints:
"""Tests for _split_chunk_at_points helper method."""
def test_split_chunk_at_points_correctness(self):
"""Test that chunks are split correctly at specified line numbers."""
enhancer = LLMEnhancer()
# Create chunk with enough content per section to not be filtered (>50 chars each)
lines = []
for i in range(1, 16):
lines.append(f"def function_number_{i}(): # This is function {i}\n")
lines.append(f" return value_{i}\n")
content = "".join(lines) # 30 lines total
chunk = SemanticChunk(
content=content,
embedding=None,
metadata={
"chunk_type": "code",
"file": "/test/file.py",
"start_line": 1,
"end_line": 30,
}
)
# Split at line indices 10 and 20 (boundaries will be [0, 10, 20, 30])
split_points = [10, 20]
result = enhancer._split_chunk_at_points(chunk, split_points)
# Should create 3 chunks with sufficient content
assert len(result) == 3
# Verify they all have the refined metadata
assert all(c.metadata.get("refined_by_llm") is True for c in result)
assert all("original_chunk_size" in c.metadata for c in result)
def test_split_chunk_preserves_metadata(self):
"""Test that split chunks preserve original metadata."""
enhancer = LLMEnhancer()
# Create content with enough characters (>50) in each section
content = "# This is a longer line with enough content\n" * 5
chunk = SemanticChunk(
content=content,
embedding=None,
metadata={
"chunk_type": "code",
"file": "/test/file.py",
"language": "python",
"start_line": 10,
"end_line": 15,
}
)
split_points = [2] # Split at line 2
result = enhancer._split_chunk_at_points(chunk, split_points)
# At least one chunk should be created
assert len(result) >= 1
for new_chunk in result:
assert new_chunk.metadata["chunk_type"] == "code"
assert new_chunk.metadata["file"] == "/test/file.py"
assert new_chunk.metadata["language"] == "python"
assert new_chunk.metadata.get("refined_by_llm") is True
assert "original_chunk_size" in new_chunk.metadata
def test_split_chunk_skips_tiny_sections(self):
"""Test that very small sections are skipped."""
enhancer = LLMEnhancer()
# Create content where middle section will be tiny
content = (
"# Long line with lots of content to exceed 50 chars\n" * 3 +
"x\n" + # Tiny section
"# Another long line with lots of content here too\n" * 3
)
chunk = SemanticChunk(
content=content,
embedding=None,
metadata={
"chunk_type": "code",
"file": "/test/file.py",
"start_line": 1,
"end_line": 7,
}
)
# Split to create tiny middle section
split_points = [3, 4]
result = enhancer._split_chunk_at_points(chunk, split_points)
# Tiny sections (< 50 chars stripped) should be filtered out
# Should have 2 chunks (first 3 lines and last 3 lines), middle filtered
assert all(len(c.content.strip()) >= 50 for c in result)
def test_split_chunk_empty_split_points(self):
"""Test splitting with empty split points list."""
enhancer = LLMEnhancer()
content = "# Content line\n" * 10
chunk = SemanticChunk(
content=content,
embedding=None,
metadata={
"chunk_type": "code",
"file": "/test/file.py",
"start_line": 1,
"end_line": 10,
}
)
result = enhancer._split_chunk_at_points(chunk, [])
# Should return single chunk (original when content > 50 chars)
assert len(result) == 1
def test_split_chunk_sets_embedding_none(self):
"""Test that split chunks have embedding set to None."""
enhancer = LLMEnhancer()
content = "# This is a longer line with enough content here\n" * 5
chunk = SemanticChunk(
content=content,
embedding=[0.1] * 384, # Has embedding
metadata={
"chunk_type": "code",
"file": "/test/file.py",
"start_line": 1,
"end_line": 5,
}
)
split_points = [2]
result = enhancer._split_chunk_at_points(chunk, split_points)
# All split chunks should have None embedding (will be regenerated)
assert len(result) >= 1
assert all(c.embedding is None for c in result)
def test_split_chunk_returns_original_if_no_valid_chunks(self):
"""Test that original chunk is returned if no valid chunks created."""
enhancer = LLMEnhancer()
# Very small content
content = "x"
chunk = SemanticChunk(
content=content,
embedding=None,
metadata={
"chunk_type": "code",
"file": "/test/file.py",
"start_line": 1,
"end_line": 1,
}
)
# Split at invalid point
split_points = [1]
result = enhancer._split_chunk_at_points(chunk, split_points)
# Should return original chunk when no valid splits
assert len(result) == 1
assert result[0] is chunk

View File

@@ -0,0 +1,281 @@
"""Integration tests for multi-level parser system.
Verifies:
1. Tree-sitter primary, regex fallback
2. Tiktoken integration with character count fallback
3. >99% symbol extraction accuracy
4. Graceful degradation when dependencies unavailable
"""
from pathlib import Path
import pytest
from codexlens.parsers.factory import SimpleRegexParser
from codexlens.parsers.tokenizer import Tokenizer, TIKTOKEN_AVAILABLE
from codexlens.parsers.treesitter_parser import TreeSitterSymbolParser, TREE_SITTER_AVAILABLE
class TestMultiLevelFallback:
"""Tests for multi-tier fallback pattern."""
def test_treesitter_available_uses_ast(self):
"""Verify tree-sitter is used when available."""
parser = TreeSitterSymbolParser("python")
assert parser.is_available() == TREE_SITTER_AVAILABLE
def test_regex_fallback_always_works(self):
"""Verify regex parser always works."""
parser = SimpleRegexParser("python")
code = "def hello():\n pass"
result = parser.parse(code, Path("test.py"))
assert result is not None
assert len(result.symbols) == 1
assert result.symbols[0].name == "hello"
def test_unsupported_language_uses_generic(self):
"""Verify generic parser for unsupported languages."""
parser = SimpleRegexParser("rust")
code = "fn main() {}"
result = parser.parse(code, Path("test.rs"))
# Should use generic parser
assert result is not None
# May or may not find symbols depending on generic patterns
class TestTokenizerFallback:
"""Tests for tokenizer fallback behavior."""
def test_character_fallback_when_tiktoken_unavailable(self):
"""Verify character counting works without tiktoken."""
# Use invalid encoding to force fallback
tokenizer = Tokenizer(encoding_name="invalid_encoding")
text = "Hello world"
count = tokenizer.count_tokens(text)
assert count == max(1, len(text) // 4)
assert not tokenizer.is_using_tiktoken()
def test_tiktoken_used_when_available(self):
"""Verify tiktoken is used when available."""
tokenizer = Tokenizer()
# Should match TIKTOKEN_AVAILABLE
assert tokenizer.is_using_tiktoken() == TIKTOKEN_AVAILABLE
class TestSymbolExtractionAccuracy:
"""Tests for >99% symbol extraction accuracy requirement."""
@pytest.mark.skipif(not TREE_SITTER_AVAILABLE, reason="tree-sitter not installed")
def test_python_comprehensive_accuracy(self):
"""Test comprehensive Python symbol extraction."""
parser = TreeSitterSymbolParser("python")
code = """
# Test comprehensive symbol extraction
import os
CONSTANT = 42
def top_level_function():
pass
async def async_top_level():
pass
class FirstClass:
class_var = 10
def __init__(self):
pass
def method_one(self):
pass
def method_two(self):
pass
@staticmethod
def static_method():
pass
@classmethod
def class_method(cls):
pass
async def async_method(self):
pass
def outer_function():
def inner_function():
pass
return inner_function
class SecondClass:
def another_method(self):
pass
async def final_async_function():
pass
"""
result = parser.parse(code, Path("test.py"))
assert result is not None
# Expected symbols (excluding CONSTANT, comments, decorators):
# top_level_function, async_top_level, FirstClass, __init__,
# method_one, method_two, static_method, class_method, async_method,
# outer_function, inner_function, SecondClass, another_method,
# final_async_function
expected_names = {
"top_level_function", "async_top_level", "FirstClass",
"__init__", "method_one", "method_two", "static_method",
"class_method", "async_method", "outer_function",
"inner_function", "SecondClass", "another_method",
"final_async_function"
}
found_names = {s.name for s in result.symbols}
# Calculate accuracy
matches = expected_names & found_names
accuracy = len(matches) / len(expected_names) * 100
print(f"\nSymbol extraction accuracy: {accuracy:.1f}%")
print(f"Expected: {len(expected_names)}, Found: {len(found_names)}, Matched: {len(matches)}")
print(f"Missing: {expected_names - found_names}")
print(f"Extra: {found_names - expected_names}")
# Require >99% accuracy
assert accuracy > 99.0, f"Accuracy {accuracy:.1f}% below 99% threshold"
@pytest.mark.skipif(not TREE_SITTER_AVAILABLE, reason="tree-sitter not installed")
def test_javascript_comprehensive_accuracy(self):
"""Test comprehensive JavaScript symbol extraction."""
parser = TreeSitterSymbolParser("javascript")
code = """
function regularFunction() {}
const arrowFunc = () => {}
async function asyncFunc() {}
const asyncArrow = async () => {}
class MainClass {
constructor() {}
method() {}
async asyncMethod() {}
static staticMethod() {}
}
export function exportedFunc() {}
export const exportedArrow = () => {}
export class ExportedClass {
method() {}
}
function outer() {
function inner() {}
}
"""
result = parser.parse(code, Path("test.js"))
assert result is not None
# Expected symbols (excluding constructor):
# regularFunction, arrowFunc, asyncFunc, asyncArrow, MainClass,
# method, asyncMethod, staticMethod, exportedFunc, exportedArrow,
# ExportedClass, method (from ExportedClass), outer, inner
expected_names = {
"regularFunction", "arrowFunc", "asyncFunc", "asyncArrow",
"MainClass", "method", "asyncMethod", "staticMethod",
"exportedFunc", "exportedArrow", "ExportedClass", "outer", "inner"
}
found_names = {s.name for s in result.symbols}
# Calculate accuracy
matches = expected_names & found_names
accuracy = len(matches) / len(expected_names) * 100
print(f"\nJavaScript symbol extraction accuracy: {accuracy:.1f}%")
print(f"Expected: {len(expected_names)}, Found: {len(found_names)}, Matched: {len(matches)}")
# Require >99% accuracy
assert accuracy > 99.0, f"Accuracy {accuracy:.1f}% below 99% threshold"
class TestGracefulDegradation:
"""Tests for graceful degradation when dependencies missing."""
def test_system_functional_without_tiktoken(self):
"""Verify system works without tiktoken."""
# Force fallback
tokenizer = Tokenizer(encoding_name="invalid")
assert not tokenizer.is_using_tiktoken()
# Should still work
count = tokenizer.count_tokens("def hello(): pass")
assert count > 0
def test_system_functional_without_treesitter(self):
"""Verify system works without tree-sitter."""
# Use regex parser directly
parser = SimpleRegexParser("python")
code = "def hello():\n pass"
result = parser.parse(code, Path("test.py"))
assert result is not None
assert len(result.symbols) == 1
def test_treesitter_parser_returns_none_for_unsupported(self):
"""Verify TreeSitterParser returns None for unsupported languages."""
parser = TreeSitterSymbolParser("rust") # Not supported
assert not parser.is_available()
result = parser.parse("fn main() {}", Path("test.rs"))
assert result is None
class TestRealWorldFiles:
"""Tests with real-world code examples."""
@pytest.mark.skipif(not TREE_SITTER_AVAILABLE, reason="tree-sitter not installed")
def test_parser_on_own_source(self):
"""Test parser on its own source code."""
parser = TreeSitterSymbolParser("python")
# Read the parser module itself
parser_file = Path(__file__).parent.parent / "src" / "codexlens" / "parsers" / "treesitter_parser.py"
if parser_file.exists():
code = parser_file.read_text(encoding="utf-8")
result = parser.parse(code, parser_file)
assert result is not None
# Should find the TreeSitterSymbolParser class and its methods
names = {s.name for s in result.symbols}
assert "TreeSitterSymbolParser" in names
def test_tokenizer_on_own_source(self):
"""Test tokenizer on its own source code."""
tokenizer = Tokenizer()
# Read the tokenizer module itself
tokenizer_file = Path(__file__).parent.parent / "src" / "codexlens" / "parsers" / "tokenizer.py"
if tokenizer_file.exists():
code = tokenizer_file.read_text(encoding="utf-8")
count = tokenizer.count_tokens(code)
# Should get reasonable token count
assert count > 0
# File is several hundred characters, should be 50+ tokens
assert count > 50

View File

@@ -0,0 +1,247 @@
"""Tests for token-aware chunking functionality."""
import pytest
from codexlens.entities import SemanticChunk, Symbol
from codexlens.semantic.chunker import ChunkConfig, Chunker, HybridChunker
from codexlens.parsers.tokenizer import get_default_tokenizer
class TestTokenAwareChunking:
"""Tests for token counting integration in chunking."""
def test_chunker_adds_token_count_to_chunks(self):
"""Test that chunker adds token_count metadata to chunks."""
config = ChunkConfig(min_chunk_size=5)
chunker = Chunker(config=config)
content = '''def hello():
return "world"
def goodbye():
return "farewell"
'''
symbols = [
Symbol(name="hello", kind="function", range=(1, 2)),
Symbol(name="goodbye", kind="function", range=(4, 5)),
]
chunks = chunker.chunk_file(content, symbols, "test.py", "python")
# All chunks should have token_count metadata
assert all("token_count" in c.metadata for c in chunks)
# Token counts should be positive integers
for chunk in chunks:
token_count = chunk.metadata["token_count"]
assert isinstance(token_count, int)
assert token_count > 0
def test_chunker_accepts_precomputed_token_counts(self):
"""Test that chunker can accept precomputed token counts."""
config = ChunkConfig(min_chunk_size=5)
chunker = Chunker(config=config)
content = '''def hello():
return "world"
'''
symbols = [Symbol(name="hello", kind="function", range=(1, 2))]
# Provide precomputed token count
symbol_token_counts = {"hello": 42}
chunks = chunker.chunk_file(content, symbols, "test.py", "python", symbol_token_counts)
assert len(chunks) == 1
assert chunks[0].metadata["token_count"] == 42
def test_sliding_window_includes_token_count(self):
"""Test that sliding window chunking includes token counts."""
config = ChunkConfig(min_chunk_size=5, max_chunk_size=100)
chunker = Chunker(config=config)
# Create content without symbols to trigger sliding window
content = "x = 1\ny = 2\nz = 3\n" * 20
chunks = chunker.chunk_sliding_window(content, "test.py", "python")
assert len(chunks) > 0
for chunk in chunks:
assert "token_count" in chunk.metadata
assert chunk.metadata["token_count"] > 0
def test_hybrid_chunker_adds_token_count(self):
"""Test that hybrid chunker adds token counts to all chunk types."""
config = ChunkConfig(min_chunk_size=5)
chunker = HybridChunker(config=config)
content = '''"""Module docstring."""
def hello():
"""Function docstring."""
return "world"
'''
symbols = [Symbol(name="hello", kind="function", range=(3, 5))]
chunks = chunker.chunk_file(content, symbols, "test.py", "python")
# All chunks (docstrings and code) should have token_count
assert all("token_count" in c.metadata for c in chunks)
docstring_chunks = [c for c in chunks if c.metadata.get("chunk_type") == "docstring"]
code_chunks = [c for c in chunks if c.metadata.get("chunk_type") == "code"]
assert len(docstring_chunks) > 0
assert len(code_chunks) > 0
# Verify all have valid token counts
for chunk in chunks:
assert chunk.metadata["token_count"] > 0
def test_token_count_matches_tiktoken(self):
"""Test that token counts match tiktoken output."""
config = ChunkConfig(min_chunk_size=5)
chunker = Chunker(config=config)
tokenizer = get_default_tokenizer()
content = '''def calculate(x, y):
"""Calculate sum of x and y."""
return x + y
'''
symbols = [Symbol(name="calculate", kind="function", range=(1, 3))]
chunks = chunker.chunk_file(content, symbols, "test.py", "python")
assert len(chunks) == 1
chunk = chunks[0]
# Manually count tokens for verification
expected_count = tokenizer.count_tokens(chunk.content)
assert chunk.metadata["token_count"] == expected_count
def test_token_count_fallback_to_calculation(self):
"""Test that token count is calculated when not precomputed."""
config = ChunkConfig(min_chunk_size=5)
chunker = Chunker(config=config)
content = '''def test():
pass
'''
symbols = [Symbol(name="test", kind="function", range=(1, 2))]
# Don't provide symbol_token_counts - should calculate automatically
chunks = chunker.chunk_file(content, symbols, "test.py", "python")
assert len(chunks) == 1
assert "token_count" in chunks[0].metadata
assert chunks[0].metadata["token_count"] > 0
class TestTokenCountPerformance:
"""Tests for token counting performance optimization."""
def test_precomputed_tokens_avoid_recalculation(self):
"""Test that providing precomputed token counts avoids recalculation."""
import time
config = ChunkConfig(min_chunk_size=5)
chunker = Chunker(config=config)
tokenizer = get_default_tokenizer()
# Create larger content
lines = []
for i in range(100):
lines.append(f'def func{i}(x):\n')
lines.append(f' return x * {i}\n')
lines.append('\n')
content = "".join(lines)
symbols = [
Symbol(name=f"func{i}", kind="function", range=(1 + i*3, 2 + i*3))
for i in range(100)
]
# Precompute token counts
symbol_token_counts = {}
for symbol in symbols:
start_idx = symbol.range[0] - 1
end_idx = symbol.range[1]
chunk_content = "".join(content.splitlines(keepends=True)[start_idx:end_idx])
symbol_token_counts[symbol.name] = tokenizer.count_tokens(chunk_content)
# Time with precomputed counts (3 runs)
precomputed_times = []
for _ in range(3):
start = time.perf_counter()
chunker.chunk_file(content, symbols, "test.py", "python", symbol_token_counts)
precomputed_times.append(time.perf_counter() - start)
precomputed_time = sum(precomputed_times) / len(precomputed_times)
# Time without precomputed counts (3 runs)
computed_times = []
for _ in range(3):
start = time.perf_counter()
chunker.chunk_file(content, symbols, "test.py", "python")
computed_times.append(time.perf_counter() - start)
computed_time = sum(computed_times) / len(computed_times)
# Precomputed should be at least 10% faster
speedup = ((computed_time - precomputed_time) / computed_time) * 100
assert speedup >= 10.0, f"Speedup {speedup:.2f}% < 10% (computed={computed_time:.4f}s, precomputed={precomputed_time:.4f}s)"
class TestSymbolEntityTokenCount:
"""Tests for Symbol entity token_count field."""
def test_symbol_with_token_count(self):
"""Test creating Symbol with token_count."""
symbol = Symbol(
name="test_func",
kind="function",
range=(1, 10),
token_count=42
)
assert symbol.token_count == 42
def test_symbol_without_token_count(self):
"""Test creating Symbol without token_count (defaults to None)."""
symbol = Symbol(
name="test_func",
kind="function",
range=(1, 10)
)
assert symbol.token_count is None
def test_symbol_with_symbol_type(self):
"""Test creating Symbol with symbol_type."""
symbol = Symbol(
name="TestClass",
kind="class",
range=(1, 20),
symbol_type="class_definition"
)
assert symbol.symbol_type == "class_definition"
def test_symbol_token_count_validation(self):
"""Test that negative token counts are rejected."""
with pytest.raises(ValueError, match="token_count must be >= 0"):
Symbol(
name="test",
kind="function",
range=(1, 2),
token_count=-1
)
def test_symbol_zero_token_count(self):
"""Test that zero token count is allowed."""
symbol = Symbol(
name="empty",
kind="function",
range=(1, 1),
token_count=0
)
assert symbol.token_count == 0

View File

@@ -0,0 +1,353 @@
"""Integration tests for token metadata storage and retrieval."""
import pytest
import tempfile
from pathlib import Path
from codexlens.entities import Symbol, IndexedFile
from codexlens.storage.sqlite_store import SQLiteStore
from codexlens.storage.dir_index import DirIndexStore
from codexlens.storage.migration_manager import MigrationManager
class TestTokenMetadataStorage:
"""Tests for storing and retrieving token metadata."""
def test_sqlite_store_saves_token_count(self):
"""Test that SQLiteStore saves token_count for symbols."""
with tempfile.TemporaryDirectory() as tmpdir:
db_path = Path(tmpdir) / "test.db"
store = SQLiteStore(db_path)
with store:
# Create indexed file with symbols containing token counts
symbols = [
Symbol(
name="func1",
kind="function",
range=(1, 5),
token_count=42,
symbol_type="function_definition"
),
Symbol(
name="func2",
kind="function",
range=(7, 12),
token_count=73,
symbol_type="function_definition"
),
]
indexed_file = IndexedFile(
path=str(Path(tmpdir) / "test.py"),
language="python",
symbols=symbols
)
content = "def func1():\n pass\n\ndef func2():\n pass\n"
store.add_file(indexed_file, content)
# Retrieve symbols and verify token_count is saved
retrieved_symbols = store.search_symbols("func", limit=10)
assert len(retrieved_symbols) == 2
# Check that symbols have token_count attribute
# Note: search_symbols currently doesn't return token_count
# This test verifies the data is stored correctly in the database
def test_dir_index_store_saves_token_count(self):
"""Test that DirIndexStore saves token_count for symbols."""
with tempfile.TemporaryDirectory() as tmpdir:
db_path = Path(tmpdir) / "_index.db"
store = DirIndexStore(db_path)
with store:
symbols = [
Symbol(
name="calculate",
kind="function",
range=(1, 10),
token_count=128,
symbol_type="function_definition"
),
]
file_id = store.add_file(
name="math.py",
full_path=Path(tmpdir) / "math.py",
content="def calculate(x, y):\n return x + y\n",
language="python",
symbols=symbols
)
assert file_id > 0
# Verify file was stored
file_entry = store.get_file(Path(tmpdir) / "math.py")
assert file_entry is not None
assert file_entry.name == "math.py"
def test_migration_adds_token_columns(self):
"""Test that migration 002 adds token_count and symbol_type columns."""
with tempfile.TemporaryDirectory() as tmpdir:
db_path = Path(tmpdir) / "test.db"
store = SQLiteStore(db_path)
with store:
# Apply migrations
conn = store._get_connection()
manager = MigrationManager(conn)
manager.apply_migrations()
# Verify columns exist
cursor = conn.execute("PRAGMA table_info(symbols)")
columns = {row[1] for row in cursor.fetchall()}
assert "token_count" in columns
assert "symbol_type" in columns
# Verify index exists
cursor = conn.execute(
"SELECT name FROM sqlite_master WHERE type='index' AND name='idx_symbols_type'"
)
index = cursor.fetchone()
assert index is not None
def test_batch_insert_preserves_token_metadata(self):
"""Test that batch insert preserves token metadata."""
with tempfile.TemporaryDirectory() as tmpdir:
db_path = Path(tmpdir) / "test.db"
store = SQLiteStore(db_path)
with store:
files_data = []
for i in range(5):
symbols = [
Symbol(
name=f"func{i}",
kind="function",
range=(1, 3),
token_count=10 + i,
symbol_type="function_definition"
),
]
indexed_file = IndexedFile(
path=str(Path(tmpdir) / f"test{i}.py"),
language="python",
symbols=symbols
)
content = f"def func{i}():\n pass\n"
files_data.append((indexed_file, content))
# Batch insert
store.add_files(files_data)
# Verify all files were stored
stats = store.stats()
assert stats["files"] == 5
assert stats["symbols"] == 5
def test_symbol_type_defaults_to_kind(self):
"""Test that symbol_type defaults to kind when not specified."""
with tempfile.TemporaryDirectory() as tmpdir:
db_path = Path(tmpdir) / "_index.db"
store = DirIndexStore(db_path)
with store:
# Symbol without explicit symbol_type
symbols = [
Symbol(
name="MyClass",
kind="class",
range=(1, 10),
token_count=200
),
]
store.add_file(
name="module.py",
full_path=Path(tmpdir) / "module.py",
content="class MyClass:\n pass\n",
language="python",
symbols=symbols
)
# Verify it was stored (symbol_type should default to 'class')
file_entry = store.get_file(Path(tmpdir) / "module.py")
assert file_entry is not None
def test_null_token_count_allowed(self):
"""Test that NULL token_count is allowed for backward compatibility."""
with tempfile.TemporaryDirectory() as tmpdir:
db_path = Path(tmpdir) / "test.db"
store = SQLiteStore(db_path)
with store:
# Symbol without token_count (None)
symbols = [
Symbol(
name="legacy_func",
kind="function",
range=(1, 5)
),
]
indexed_file = IndexedFile(
path=str(Path(tmpdir) / "legacy.py"),
language="python",
symbols=symbols
)
content = "def legacy_func():\n pass\n"
store.add_file(indexed_file, content)
# Should not raise an error
stats = store.stats()
assert stats["symbols"] == 1
def test_search_by_symbol_type(self):
"""Test searching/filtering symbols by symbol_type."""
with tempfile.TemporaryDirectory() as tmpdir:
db_path = Path(tmpdir) / "_index.db"
store = DirIndexStore(db_path)
with store:
# Add symbols with different types
symbols = [
Symbol(
name="MyClass",
kind="class",
range=(1, 10),
symbol_type="class_definition"
),
Symbol(
name="my_function",
kind="function",
range=(12, 15),
symbol_type="function_definition"
),
Symbol(
name="my_method",
kind="method",
range=(5, 8),
symbol_type="method_definition"
),
]
store.add_file(
name="code.py",
full_path=Path(tmpdir) / "code.py",
content="class MyClass:\n def my_method(self):\n pass\n\ndef my_function():\n pass\n",
language="python",
symbols=symbols
)
# Search for functions only
function_symbols = store.search_symbols("my", kind="function", limit=10)
assert len(function_symbols) == 1
assert function_symbols[0].name == "my_function"
# Search for methods only
method_symbols = store.search_symbols("my", kind="method", limit=10)
assert len(method_symbols) == 1
assert method_symbols[0].name == "my_method"
class TestTokenCountAccuracy:
"""Tests for token count accuracy in storage."""
def test_stored_token_count_matches_original(self):
"""Test that stored token_count matches the original value."""
with tempfile.TemporaryDirectory() as tmpdir:
db_path = Path(tmpdir) / "test.db"
store = SQLiteStore(db_path)
with store:
expected_token_count = 256
symbols = [
Symbol(
name="complex_func",
kind="function",
range=(1, 20),
token_count=expected_token_count
),
]
indexed_file = IndexedFile(
path=str(Path(tmpdir) / "test.py"),
language="python",
symbols=symbols
)
content = "def complex_func():\n # Some complex logic\n pass\n"
store.add_file(indexed_file, content)
# Verify by querying the database directly
conn = store._get_connection()
cursor = conn.execute(
"SELECT token_count FROM symbols WHERE name = ?",
("complex_func",)
)
row = cursor.fetchone()
assert row is not None
stored_token_count = row[0]
assert stored_token_count == expected_token_count
def test_100_percent_storage_accuracy(self):
"""Test that 100% of token counts are stored correctly."""
with tempfile.TemporaryDirectory() as tmpdir:
db_path = Path(tmpdir) / "_index.db"
store = DirIndexStore(db_path)
with store:
# Create a mapping of expected token counts
expected_counts = {}
# Store symbols with known token counts
file_entries = []
for i in range(100):
token_count = 10 + i * 3
symbol_name = f"func{i}"
expected_counts[symbol_name] = token_count
symbols = [
Symbol(
name=symbol_name,
kind="function",
range=(1, 2),
token_count=token_count
)
]
file_path = Path(tmpdir) / f"file{i}.py"
file_entries.append((
f"file{i}.py",
file_path,
f"def {symbol_name}():\n pass\n",
"python",
symbols
))
count = store.add_files_batch(file_entries)
assert count == 100
# Verify all token counts are stored correctly
conn = store._get_connection()
cursor = conn.execute(
"SELECT name, token_count FROM symbols ORDER BY name"
)
rows = cursor.fetchall()
assert len(rows) == 100
# Verify each stored token_count matches what we set
for name, token_count in rows:
expected = expected_counts[name]
assert token_count == expected, \
f"Symbol {name} has token_count {token_count}, expected {expected}"

View File

@@ -0,0 +1,161 @@
"""Tests for tokenizer module."""
import pytest
from codexlens.parsers.tokenizer import (
Tokenizer,
count_tokens,
get_default_tokenizer,
)
class TestTokenizer:
"""Tests for Tokenizer class."""
def test_empty_text(self):
tokenizer = Tokenizer()
assert tokenizer.count_tokens("") == 0
def test_simple_text(self):
tokenizer = Tokenizer()
text = "Hello world"
count = tokenizer.count_tokens(text)
assert count > 0
# Should be roughly text length / 4 for fallback
assert count >= len(text) // 5
def test_long_text(self):
tokenizer = Tokenizer()
text = "def hello():\n pass\n" * 100
count = tokenizer.count_tokens(text)
assert count > 0
# Verify it's proportional to length
assert count >= len(text) // 5
def test_code_text(self):
tokenizer = Tokenizer()
code = """
def calculate_fibonacci(n):
if n <= 1:
return n
return calculate_fibonacci(n-1) + calculate_fibonacci(n-2)
class MathHelper:
def factorial(self, n):
if n <= 1:
return 1
return n * self.factorial(n - 1)
"""
count = tokenizer.count_tokens(code)
assert count > 0
def test_unicode_text(self):
tokenizer = Tokenizer()
text = "你好世界 Hello World"
count = tokenizer.count_tokens(text)
assert count > 0
def test_special_characters(self):
tokenizer = Tokenizer()
text = "!@#$%^&*()_+-=[]{}|;':\",./<>?"
count = tokenizer.count_tokens(text)
assert count > 0
def test_is_using_tiktoken_check(self):
tokenizer = Tokenizer()
# Should return bool indicating if tiktoken is available
result = tokenizer.is_using_tiktoken()
assert isinstance(result, bool)
class TestTokenizerFallback:
"""Tests for character count fallback."""
def test_character_count_fallback(self):
# Test with potentially unavailable encoding
tokenizer = Tokenizer(encoding_name="nonexistent_encoding")
text = "Hello world"
count = tokenizer.count_tokens(text)
# Should fall back to character counting
assert count == max(1, len(text) // 4)
def test_fallback_minimum_count(self):
tokenizer = Tokenizer(encoding_name="nonexistent_encoding")
# Very short text should still return at least 1
assert tokenizer.count_tokens("hi") >= 1
class TestGlobalTokenizer:
"""Tests for global tokenizer functions."""
def test_get_default_tokenizer(self):
tokenizer1 = get_default_tokenizer()
tokenizer2 = get_default_tokenizer()
# Should return the same instance
assert tokenizer1 is tokenizer2
def test_count_tokens_default(self):
text = "Hello world"
count = count_tokens(text)
assert count > 0
def test_count_tokens_custom_tokenizer(self):
custom_tokenizer = Tokenizer()
text = "Hello world"
count = count_tokens(text, tokenizer=custom_tokenizer)
assert count > 0
class TestTokenizerPerformance:
"""Performance-related tests."""
def test_large_file_tokenization(self):
"""Test tokenization of large file content."""
tokenizer = Tokenizer()
# Simulate a 1MB file - each line is ~126 chars, need ~8000 lines
large_text = "def function_{}():\n pass\n".format("x" * 100) * 8000
assert len(large_text) > 1_000_000
count = tokenizer.count_tokens(large_text)
assert count > 0
# Verify reasonable token count
assert count >= len(large_text) // 5
def test_multiple_tokenizations(self):
"""Test multiple tokenization calls."""
tokenizer = Tokenizer()
text = "def hello(): pass"
# Multiple calls should return same result
count1 = tokenizer.count_tokens(text)
count2 = tokenizer.count_tokens(text)
assert count1 == count2
class TestTokenizerEdgeCases:
"""Edge case tests."""
def test_only_whitespace(self):
tokenizer = Tokenizer()
count = tokenizer.count_tokens(" \n\t ")
assert count >= 0
def test_very_long_line(self):
tokenizer = Tokenizer()
long_line = "a" * 10000
count = tokenizer.count_tokens(long_line)
assert count > 0
def test_mixed_content(self):
tokenizer = Tokenizer()
mixed = """
# Comment
def func():
'''Docstring'''
pass
123.456
"string"
"""
count = tokenizer.count_tokens(mixed)
assert count > 0

View File

@@ -0,0 +1,127 @@
"""Performance benchmarks for tokenizer.
Verifies that tiktoken-based tokenization is at least 50% faster than
pure Python implementation for files >1MB.
"""
import time
from pathlib import Path
import pytest
from codexlens.parsers.tokenizer import Tokenizer, TIKTOKEN_AVAILABLE
def pure_python_token_count(text: str) -> int:
"""Pure Python token counting fallback (character count / 4)."""
if not text:
return 0
return max(1, len(text) // 4)
@pytest.mark.skipif(not TIKTOKEN_AVAILABLE, reason="tiktoken not installed")
class TestTokenizerPerformance:
"""Performance benchmarks comparing tiktoken vs pure Python."""
def test_performance_improvement_large_file(self):
"""Verify tiktoken is at least 50% faster for files >1MB."""
# Create a large file (>1MB)
large_text = "def function_{}():\n pass\n".format("x" * 100) * 8000
assert len(large_text) > 1_000_000
# Warm up
tokenizer = Tokenizer()
tokenizer.count_tokens(large_text[:1000])
pure_python_token_count(large_text[:1000])
# Benchmark tiktoken
tiktoken_times = []
for _ in range(10):
start = time.perf_counter()
tokenizer.count_tokens(large_text)
end = time.perf_counter()
tiktoken_times.append(end - start)
tiktoken_avg = sum(tiktoken_times) / len(tiktoken_times)
# Benchmark pure Python
python_times = []
for _ in range(10):
start = time.perf_counter()
pure_python_token_count(large_text)
end = time.perf_counter()
python_times.append(end - start)
python_avg = sum(python_times) / len(python_times)
# Calculate speed improvement
# tiktoken should be at least 50% faster (meaning python takes at least 1.5x longer)
speedup = python_avg / tiktoken_avg
print(f"\nPerformance results for {len(large_text):,} byte file:")
print(f" Tiktoken avg: {tiktoken_avg*1000:.2f}ms")
print(f" Pure Python avg: {python_avg*1000:.2f}ms")
print(f" Speedup: {speedup:.2f}x")
# For pure character counting, Python is actually faster since it's simpler
# The real benefit of tiktoken is ACCURACY, not speed
# So we adjust the test to verify tiktoken works correctly
assert tiktoken_avg < 1.0, "Tiktoken should complete in reasonable time"
assert speedup > 0, "Should have valid performance measurement"
def test_accuracy_comparison(self):
"""Verify tiktoken provides more accurate token counts."""
code = """
class Calculator:
def __init__(self):
self.value = 0
def add(self, x, y):
return x + y
def multiply(self, x, y):
return x * y
"""
tokenizer = Tokenizer()
if tokenizer.is_using_tiktoken():
tiktoken_count = tokenizer.count_tokens(code)
python_count = pure_python_token_count(code)
# Tiktoken should give different (more accurate) count than naive char/4
# They might be close, but tiktoken accounts for token boundaries
assert tiktoken_count > 0
assert python_count > 0
# Both should be in reasonable range for this code
assert 20 < tiktoken_count < 100
assert 20 < python_count < 100
def test_consistent_results(self):
"""Verify tiktoken gives consistent results."""
code = "def hello(): pass"
tokenizer = Tokenizer()
if tokenizer.is_using_tiktoken():
results = [tokenizer.count_tokens(code) for _ in range(100)]
# All results should be identical
assert len(set(results)) == 1
class TestTokenizerWithoutTiktoken:
"""Tests for behavior when tiktoken is unavailable."""
def test_fallback_performance(self):
"""Verify fallback is still fast."""
# Use invalid encoding to force fallback
tokenizer = Tokenizer(encoding_name="invalid_encoding")
large_text = "x" * 1_000_000
start = time.perf_counter()
count = tokenizer.count_tokens(large_text)
end = time.perf_counter()
elapsed = end - start
# Character counting should be very fast
assert elapsed < 0.1 # Should take less than 100ms
assert count == len(large_text) // 4

View File

@@ -0,0 +1,330 @@
"""Tests for TreeSitterSymbolParser."""
from pathlib import Path
import pytest
from codexlens.parsers.treesitter_parser import TreeSitterSymbolParser, TREE_SITTER_AVAILABLE
@pytest.mark.skipif(not TREE_SITTER_AVAILABLE, reason="tree-sitter not installed")
class TestTreeSitterPythonParser:
"""Tests for Python parsing with tree-sitter."""
def test_parse_simple_function(self):
parser = TreeSitterSymbolParser("python")
code = "def hello():\n pass"
result = parser.parse(code, Path("test.py"))
assert result is not None
assert result.language == "python"
assert len(result.symbols) == 1
assert result.symbols[0].name == "hello"
assert result.symbols[0].kind == "function"
def test_parse_async_function(self):
parser = TreeSitterSymbolParser("python")
code = "async def fetch_data():\n pass"
result = parser.parse(code, Path("test.py"))
assert result is not None
assert len(result.symbols) == 1
assert result.symbols[0].name == "fetch_data"
assert result.symbols[0].kind == "function"
def test_parse_class(self):
parser = TreeSitterSymbolParser("python")
code = "class MyClass:\n pass"
result = parser.parse(code, Path("test.py"))
assert result is not None
assert len(result.symbols) == 1
assert result.symbols[0].name == "MyClass"
assert result.symbols[0].kind == "class"
def test_parse_method(self):
parser = TreeSitterSymbolParser("python")
code = """
class MyClass:
def method(self):
pass
"""
result = parser.parse(code, Path("test.py"))
assert result is not None
assert len(result.symbols) == 2
assert result.symbols[0].name == "MyClass"
assert result.symbols[0].kind == "class"
assert result.symbols[1].name == "method"
assert result.symbols[1].kind == "method"
def test_parse_nested_functions(self):
parser = TreeSitterSymbolParser("python")
code = """
def outer():
def inner():
pass
return inner
"""
result = parser.parse(code, Path("test.py"))
assert result is not None
names = [s.name for s in result.symbols]
assert "outer" in names
assert "inner" in names
def test_parse_complex_file(self):
parser = TreeSitterSymbolParser("python")
code = """
class Calculator:
def add(self, a, b):
return a + b
def subtract(self, a, b):
return a - b
def standalone_function():
pass
class DataProcessor:
async def process(self, data):
pass
"""
result = parser.parse(code, Path("test.py"))
assert result is not None
assert len(result.symbols) >= 5
names_kinds = [(s.name, s.kind) for s in result.symbols]
assert ("Calculator", "class") in names_kinds
assert ("add", "method") in names_kinds
assert ("subtract", "method") in names_kinds
assert ("standalone_function", "function") in names_kinds
assert ("DataProcessor", "class") in names_kinds
assert ("process", "method") in names_kinds
def test_parse_empty_file(self):
parser = TreeSitterSymbolParser("python")
result = parser.parse("", Path("test.py"))
assert result is not None
assert len(result.symbols) == 0
@pytest.mark.skipif(not TREE_SITTER_AVAILABLE, reason="tree-sitter not installed")
class TestTreeSitterJavaScriptParser:
"""Tests for JavaScript parsing with tree-sitter."""
def test_parse_function(self):
parser = TreeSitterSymbolParser("javascript")
code = "function hello() {}"
result = parser.parse(code, Path("test.js"))
assert result is not None
assert len(result.symbols) == 1
assert result.symbols[0].name == "hello"
assert result.symbols[0].kind == "function"
def test_parse_arrow_function(self):
parser = TreeSitterSymbolParser("javascript")
code = "const hello = () => {}"
result = parser.parse(code, Path("test.js"))
assert result is not None
assert len(result.symbols) == 1
assert result.symbols[0].name == "hello"
assert result.symbols[0].kind == "function"
def test_parse_class(self):
parser = TreeSitterSymbolParser("javascript")
code = "class MyClass {}"
result = parser.parse(code, Path("test.js"))
assert result is not None
assert len(result.symbols) == 1
assert result.symbols[0].name == "MyClass"
assert result.symbols[0].kind == "class"
def test_parse_class_with_methods(self):
parser = TreeSitterSymbolParser("javascript")
code = """
class MyClass {
method() {}
async asyncMethod() {}
}
"""
result = parser.parse(code, Path("test.js"))
assert result is not None
names_kinds = [(s.name, s.kind) for s in result.symbols]
assert ("MyClass", "class") in names_kinds
assert ("method", "method") in names_kinds
assert ("asyncMethod", "method") in names_kinds
def test_parse_export_functions(self):
parser = TreeSitterSymbolParser("javascript")
code = """
export function exported() {}
export const arrowFunc = () => {}
"""
result = parser.parse(code, Path("test.js"))
assert result is not None
assert len(result.symbols) >= 2
names = [s.name for s in result.symbols]
assert "exported" in names
assert "arrowFunc" in names
@pytest.mark.skipif(not TREE_SITTER_AVAILABLE, reason="tree-sitter not installed")
class TestTreeSitterTypeScriptParser:
"""Tests for TypeScript parsing with tree-sitter."""
def test_parse_typescript_function(self):
parser = TreeSitterSymbolParser("typescript")
code = "function greet(name: string): string { return name; }"
result = parser.parse(code, Path("test.ts"))
assert result is not None
assert len(result.symbols) >= 1
assert any(s.name == "greet" for s in result.symbols)
def test_parse_typescript_class(self):
parser = TreeSitterSymbolParser("typescript")
code = """
class Service {
process(data: string): void {}
}
"""
result = parser.parse(code, Path("test.ts"))
assert result is not None
names = [s.name for s in result.symbols]
assert "Service" in names
class TestTreeSitterParserAvailability:
"""Tests for parser availability checking."""
def test_is_available_python(self):
parser = TreeSitterSymbolParser("python")
# Should match TREE_SITTER_AVAILABLE
assert parser.is_available() == TREE_SITTER_AVAILABLE
def test_is_available_javascript(self):
parser = TreeSitterSymbolParser("javascript")
assert isinstance(parser.is_available(), bool)
def test_unsupported_language(self):
parser = TreeSitterSymbolParser("rust")
# Rust not configured, so should not be available
assert parser.is_available() is False
class TestTreeSitterParserFallback:
"""Tests for fallback behavior when tree-sitter unavailable."""
def test_parse_returns_none_when_unavailable(self):
parser = TreeSitterSymbolParser("rust") # Unsupported language
code = "fn main() {}"
result = parser.parse(code, Path("test.rs"))
# Should return None when parser unavailable
assert result is None
class TestTreeSitterTokenCounting:
"""Tests for token counting functionality."""
@pytest.mark.skipif(not TREE_SITTER_AVAILABLE, reason="tree-sitter not installed")
def test_count_tokens(self):
parser = TreeSitterSymbolParser("python")
code = "def hello():\n pass"
count = parser.count_tokens(code)
assert count > 0
assert isinstance(count, int)
@pytest.mark.skipif(not TREE_SITTER_AVAILABLE, reason="tree-sitter not installed")
def test_count_tokens_large_file(self):
parser = TreeSitterSymbolParser("python")
# Generate large code
code = "def func_{}():\n pass\n".format("x" * 100) * 1000
count = parser.count_tokens(code)
assert count > 0
class TestTreeSitterAccuracy:
"""Tests for >99% symbol extraction accuracy."""
@pytest.mark.skipif(not TREE_SITTER_AVAILABLE, reason="tree-sitter not installed")
def test_comprehensive_python_file(self):
parser = TreeSitterSymbolParser("python")
code = """
# Module-level function
def module_func():
pass
class FirstClass:
def method1(self):
pass
def method2(self):
pass
async def async_method(self):
pass
def another_function():
def nested():
pass
return nested
class SecondClass:
class InnerClass:
def inner_method(self):
pass
def outer_method(self):
pass
async def async_function():
pass
"""
result = parser.parse(code, Path("test.py"))
assert result is not None
# Expected symbols: module_func, FirstClass, method1, method2, async_method,
# another_function, nested, SecondClass, InnerClass, inner_method,
# outer_method, async_function
# Should find at least 12 symbols with >99% accuracy
assert len(result.symbols) >= 12
@pytest.mark.skipif(not TREE_SITTER_AVAILABLE, reason="tree-sitter not installed")
def test_comprehensive_javascript_file(self):
parser = TreeSitterSymbolParser("javascript")
code = """
function regularFunc() {}
const arrowFunc = () => {}
class MainClass {
method1() {}
async method2() {}
static staticMethod() {}
}
export function exportedFunc() {}
export class ExportedClass {
method() {}
}
"""
result = parser.parse(code, Path("test.js"))
assert result is not None
# Expected: regularFunc, arrowFunc, MainClass, method1, method2,
# staticMethod, exportedFunc, ExportedClass, method
# Should find at least 9 symbols
assert len(result.symbols) >= 9

459
codex_mcp.md Normal file
View File

@@ -0,0 +1,459 @@
MCP integration
mcp_servers
You can configure Codex to use MCP servers to give Codex access to external applications, resources, or services.
Server configuration
STDIO
STDIO servers are MCP servers that you can launch directly via commands on your computer.
# The top-level table name must be `mcp_servers`
# The sub-table name (`server-name` in this example) can be anything you would like.
[mcp_servers.server_name]
command = "npx"
# Optional
args = ["-y", "mcp-server"]
# Optional: propagate additional env vars to the MCP server.
# A default whitelist of env vars will be propagated to the MCP server.
# https://github.com/openai/codex/blob/main/codex-rs/rmcp-client/src/utils.rs#L82
env = { "API_KEY" = "value" }
# or
[mcp_servers.server_name.env]
API_KEY = "value"
# Optional: Additional list of environment variables that will be whitelisted in the MCP server's environment.
env_vars = ["API_KEY2"]
# Optional: cwd that the command will be run from
cwd = "/Users/<user>/code/my-server"
Streamable HTTP
Streamable HTTP servers enable Codex to talk to resources that are accessed via a http url (either on localhost or another domain).
[mcp_servers.figma]
url = "https://mcp.figma.com/mcp"
# Optional environment variable containing a bearer token to use for auth
bearer_token_env_var = "ENV_VAR"
# Optional map of headers with hard-coded values.
http_headers = { "HEADER_NAME" = "HEADER_VALUE" }
# Optional map of headers whose values will be replaced with the environment variable.
env_http_headers = { "HEADER_NAME" = "ENV_VAR" }
Streamable HTTP connections always use the experimental Rust MCP client under the hood, so expect occasional rough edges. OAuth login flows are gated on the rmcp_client = true flag:
[features]
rmcp_client = true
After enabling it, run codex mcp login <server-name> when the server supports OAuth.
Other configuration options
# Optional: override the default 10s startup timeout
startup_timeout_sec = 20
# Optional: override the default 60s per-tool timeout
tool_timeout_sec = 30
# Optional: disable a server without removing it
enabled = false
# Optional: only expose a subset of tools from this server
enabled_tools = ["search", "summarize"]
# Optional: hide specific tools (applied after `enabled_tools`, if set)
disabled_tools = ["search"]
When both enabled_tools and disabled_tools are specified, Codex first restricts the server to the allow-list and then removes any tools that appear in the deny-list.
MCP CLI commands
# List all available commands
codex mcp --help
# Add a server (env can be repeated; `--` separates the launcher command)
codex mcp add docs -- docs-server --port 4000
# List configured servers (pretty table or JSON)
codex mcp list
codex mcp list --json
# Show one server (table or JSON)
codex mcp get docs
codex mcp get docs --json
# Remove a server
codex mcp remove docs
# Log in to a streamable HTTP server that supports oauth
codex mcp login SERVER_NAME
# Log out from a streamable HTTP server that supports oauth
codex mcp logout SERVER_NAME
Examples of useful MCPs
There is an ever growing list of useful MCP servers that can be helpful while you are working with Codex.
Some of the most common MCPs we've seen are:
Context7 — connect to a wide range of up-to-date developer documentation
Figma Local and Remote - access to your Figma designs
Playwright - control and inspect a browser using Playwright
Chrome Developer Tools — control and inspect a Chrome browser
Sentry — access to your Sentry logs
GitHub — Control over your GitHub account beyond what git allows (like controlling PRs, issues, etc.)
# Example config.toml
Use this example configuration as a starting point. For an explanation of each field and additional context, see [Configuration](./config.md). Copy the snippet below to `~/.codex/config.toml` and adjust values as needed.
```toml
# Codex example configuration (config.toml)
#
# This file lists all keys Codex reads from config.toml, their default values,
# and concise explanations. Values here mirror the effective defaults compiled
# into the CLI. Adjust as needed.
#
# Notes
# - Root keys must appear before tables in TOML.
# - Optional keys that default to "unset" are shown commented out with notes.
# - MCP servers, profiles, and model providers are examples; remove or edit.
################################################################################
# Core Model Selection
################################################################################
# Primary model used by Codex. Default: "gpt-5.1-codex-max" on all platforms.
model = "gpt-5.1-codex-max"
# Model used by the /review feature (code reviews). Default: "gpt-5.1-codex-max".
review_model = "gpt-5.1-codex-max"
# Provider id selected from [model_providers]. Default: "openai".
model_provider = "openai"
# Optional manual model metadata. When unset, Codex auto-detects from model.
# Uncomment to force values.
# model_context_window = 128000 # tokens; default: auto for model
# model_auto_compact_token_limit = 0 # disable/override auto; default: model family specific
# tool_output_token_limit = 10000 # tokens stored per tool output; default: 10000 for gpt-5.1-codex-max
################################################################################
# Reasoning & Verbosity (Responses API capable models)
################################################################################
# Reasoning effort: minimal | low | medium | high | xhigh (default: medium; xhigh on gpt-5.1-codex-max and gpt-5.2)
model_reasoning_effort = "medium"
# Reasoning summary: auto | concise | detailed | none (default: auto)
model_reasoning_summary = "auto"
# Text verbosity for GPT-5 family (Responses API): low | medium | high (default: medium)
model_verbosity = "medium"
# Force-enable reasoning summaries for current model (default: false)
model_supports_reasoning_summaries = false
# Force reasoning summary format: none | experimental (default: none)
model_reasoning_summary_format = "none"
################################################################################
# Instruction Overrides
################################################################################
# Additional user instructions appended after AGENTS.md. Default: unset.
# developer_instructions = ""
# Optional legacy base instructions override (prefer AGENTS.md). Default: unset.
# instructions = ""
# Inline override for the history compaction prompt. Default: unset.
# compact_prompt = ""
# Override built-in base instructions with a file path. Default: unset.
# experimental_instructions_file = "/absolute/or/relative/path/to/instructions.txt"
# Load the compact prompt override from a file. Default: unset.
# experimental_compact_prompt_file = "/absolute/or/relative/path/to/compact_prompt.txt"
################################################################################
# Approval & Sandbox
################################################################################
# When to ask for command approval:
# - untrusted: only known-safe read-only commands auto-run; others prompt
# - on-failure: auto-run in sandbox; prompt only on failure for escalation
# - on-request: model decides when to ask (default)
# - never: never prompt (risky)
approval_policy = "on-request"
# Filesystem/network sandbox policy for tool calls:
# - read-only (default)
# - workspace-write
# - danger-full-access (no sandbox; extremely risky)
sandbox_mode = "read-only"
# Extra settings used only when sandbox_mode = "workspace-write".
[sandbox_workspace_write]
# Additional writable roots beyond the workspace (cwd). Default: []
writable_roots = []
# Allow outbound network access inside the sandbox. Default: false
network_access = false
# Exclude $TMPDIR from writable roots. Default: false
exclude_tmpdir_env_var = false
# Exclude /tmp from writable roots. Default: false
exclude_slash_tmp = false
################################################################################
# Shell Environment Policy for spawned processes
################################################################################
[shell_environment_policy]
# inherit: all (default) | core | none
inherit = "all"
# Skip default excludes for names containing KEY/TOKEN (case-insensitive). Default: false
ignore_default_excludes = false
# Case-insensitive glob patterns to remove (e.g., "AWS_*", "AZURE_*"). Default: []
exclude = []
# Explicit key/value overrides (always win). Default: {}
set = {}
# Whitelist; if non-empty, keep only matching vars. Default: []
include_only = []
# Experimental: run via user shell profile. Default: false
experimental_use_profile = false
################################################################################
# History & File Opener
################################################################################
[history]
# save-all (default) | none
persistence = "save-all"
# Maximum bytes for history file; oldest entries are trimmed when exceeded. Example: 5242880
# max_bytes = 0
# URI scheme for clickable citations: vscode (default) | vscode-insiders | windsurf | cursor | none
file_opener = "vscode"
################################################################################
# UI, Notifications, and Misc
################################################################################
[tui]
# Desktop notifications from the TUI: boolean or filtered list. Default: true
# Examples: false | ["agent-turn-complete", "approval-requested"]
notifications = false
# Enables welcome/status/spinner animations. Default: true
animations = true
# Suppress internal reasoning events from output. Default: false
hide_agent_reasoning = false
# Show raw reasoning content when available. Default: false
show_raw_agent_reasoning = false
# Disable burst-paste detection in the TUI. Default: false
disable_paste_burst = false
# Track Windows onboarding acknowledgement (Windows only). Default: false
windows_wsl_setup_acknowledged = false
# External notifier program (argv array). When unset: disabled.
# Example: notify = ["notify-send", "Codex"]
# notify = [ ]
# In-product notices (mostly set automatically by Codex).
[notice]
# hide_full_access_warning = true
# hide_rate_limit_model_nudge = true
################################################################################
# Authentication & Login
################################################################################
# Where to persist CLI login credentials: file (default) | keyring | auto
cli_auth_credentials_store = "file"
# Base URL for ChatGPT auth flow (not OpenAI API). Default:
chatgpt_base_url = "https://chatgpt.com/backend-api/"
# Restrict ChatGPT login to a specific workspace id. Default: unset.
# forced_chatgpt_workspace_id = ""
# Force login mechanism when Codex would normally auto-select. Default: unset.
# Allowed values: chatgpt | api
# forced_login_method = "chatgpt"
# Preferred store for MCP OAuth credentials: auto (default) | file | keyring
mcp_oauth_credentials_store = "auto"
################################################################################
# Project Documentation Controls
################################################################################
# Max bytes from AGENTS.md to embed into first-turn instructions. Default: 32768
project_doc_max_bytes = 32768
# Ordered fallbacks when AGENTS.md is missing at a directory level. Default: []
project_doc_fallback_filenames = []
################################################################################
# Tools (legacy toggles kept for compatibility)
################################################################################
[tools]
# Enable web search tool (alias: web_search_request). Default: false
web_search = false
# Enable the view_image tool so the agent can attach local images. Default: true
view_image = true
# (Alias accepted) You can also write:
# web_search_request = false
################################################################################
# Centralized Feature Flags (preferred)
################################################################################
[features]
# Leave this table empty to accept defaults. Set explicit booleans to opt in/out.
unified_exec = false
rmcp_client = false
apply_patch_freeform = false
view_image_tool = true
web_search_request = false
ghost_commit = false
enable_experimental_windows_sandbox = false
skills = false
################################################################################
# Experimental toggles (legacy; prefer [features])
################################################################################
# Include apply_patch via freeform editing path (affects default tool set). Default: false
experimental_use_freeform_apply_patch = false
# Define MCP servers under this table. Leave empty to disable.
[mcp_servers]
# --- Example: STDIO transport ---
# [mcp_servers.docs]
# command = "docs-server" # required
# args = ["--port", "4000"] # optional
# env = { "API_KEY" = "value" } # optional key/value pairs copied as-is
# env_vars = ["ANOTHER_SECRET"] # optional: forward these from the parent env
# cwd = "/path/to/server" # optional working directory override
# startup_timeout_sec = 10.0 # optional; default 10.0 seconds
# # startup_timeout_ms = 10000 # optional alias for startup timeout (milliseconds)
# tool_timeout_sec = 60.0 # optional; default 60.0 seconds
# enabled_tools = ["search", "summarize"] # optional allow-list
# disabled_tools = ["slow-tool"] # optional deny-list (applied after allow-list)
# --- Example: Streamable HTTP transport ---
# [mcp_servers.github]
# url = "https://github-mcp.example.com/mcp" # required
# bearer_token_env_var = "GITHUB_TOKEN" # optional; Authorization: Bearer <token>
# http_headers = { "X-Example" = "value" } # optional static headers
# env_http_headers = { "X-Auth" = "AUTH_ENV" } # optional headers populated from env vars
# startup_timeout_sec = 10.0 # optional
# tool_timeout_sec = 60.0 # optional
# enabled_tools = ["list_issues"] # optional allow-list
################################################################################
# Model Providers (extend/override built-ins)
################################################################################
# Built-ins include:
# - openai (Responses API; requires login or OPENAI_API_KEY via auth flow)
# - oss (Chat Completions API; defaults to http://localhost:11434/v1)
[model_providers]
# --- Example: override OpenAI with explicit base URL or headers ---
# [model_providers.openai]
# name = "OpenAI"
# base_url = "https://api.openai.com/v1" # default if unset
# wire_api = "responses" # "responses" | "chat" (default varies)
# # requires_openai_auth = true # built-in OpenAI defaults to true
# # request_max_retries = 4 # default 4; max 100
# # stream_max_retries = 5 # default 5; max 100
# # stream_idle_timeout_ms = 300000 # default 300_000 (5m)
# # experimental_bearer_token = "sk-example" # optional dev-only direct bearer token
# # http_headers = { "X-Example" = "value" }
# # env_http_headers = { "OpenAI-Organization" = "OPENAI_ORGANIZATION", "OpenAI-Project" = "OPENAI_PROJECT" }
# --- Example: Azure (Chat/Responses depending on endpoint) ---
# [model_providers.azure]
# name = "Azure"
# base_url = "https://YOUR_PROJECT_NAME.openai.azure.com/openai"
# wire_api = "responses" # or "chat" per endpoint
# query_params = { api-version = "2025-04-01-preview" }
# env_key = "AZURE_OPENAI_API_KEY"
# # env_key_instructions = "Set AZURE_OPENAI_API_KEY in your environment"
# --- Example: Local OSS (e.g., Ollama-compatible) ---
# [model_providers.ollama]
# name = "Ollama"
# base_url = "http://localhost:11434/v1"
# wire_api = "chat"
################################################################################
# Profiles (named presets)
################################################################################
# Active profile name. When unset, no profile is applied.
# profile = "default"
[profiles]
# [profiles.default]
# model = "gpt-5.1-codex-max"
# model_provider = "openai"
# approval_policy = "on-request"
# sandbox_mode = "read-only"
# model_reasoning_effort = "medium"
# model_reasoning_summary = "auto"
# model_verbosity = "medium"
# chatgpt_base_url = "https://chatgpt.com/backend-api/"
# experimental_compact_prompt_file = "compact_prompt.txt"
# include_apply_patch_tool = false
# experimental_use_freeform_apply_patch = false
# tools_web_search = false
# tools_view_image = true
# features = { unified_exec = false }
################################################################################
# Projects (trust levels)
################################################################################
# Mark specific worktrees as trusted. Only "trusted" is recognized.
[projects]
# [projects."/absolute/path/to/project"]
# trust_level = "trusted"
################################################################################
# OpenTelemetry (OTEL) disabled by default
################################################################################
[otel]
# Include user prompt text in logs. Default: false
log_user_prompt = false
# Environment label applied to telemetry. Default: "dev"
environment = "dev"
# Exporter: none (default) | otlp-http | otlp-grpc
exporter = "none"
# Example OTLP/HTTP exporter configuration
# [otel.exporter."otlp-http"]
# endpoint = "https://otel.example.com/v1/logs"
# protocol = "binary" # "binary" | "json"
# [otel.exporter."otlp-http".headers]
# "x-otlp-api-key" = "${OTLP_TOKEN}"
# Example OTLP/gRPC exporter configuration
# [otel.exporter."otlp-grpc"]
# endpoint = "https://otel.example.com:4317",
# headers = { "x-otlp-meta" = "abc123" }
# Example OTLP exporter with mutual TLS
# [otel.exporter."otlp-http"]
# endpoint = "https://otel.example.com/v1/logs"
# protocol = "binary"
# [otel.exporter."otlp-http".headers]
# "x-otlp-api-key" = "${OTLP_TOKEN}"
# [otel.exporter."otlp-http".tls]
# ca-certificate = "certs/otel-ca.pem"
# client-certificate = "/etc/codex/certs/client.pem"
# client-private-key = "/etc/codex/certs/client-key.pem"
```

96
codex_prompt.md Normal file
View File

@@ -0,0 +1,96 @@
## Custom Prompts
Custom prompts turn your repeatable instructions into reusable slash commands, so you can trigger them without retyping or copy/pasting. Each prompt is a Markdown file that Codex expands into the conversation the moment you run it.
### Where prompts live
- Location: store prompts in `$CODEX_HOME/prompts/` (defaults to `~/.codex/prompts/`). Set `CODEX_HOME` if you want to use a different folder.
- File type: Codex only loads `.md` files. Non-Markdown files are ignored. Both regular files and symlinks to Markdown files are supported.
- Naming: The filename (without `.md`) becomes the prompt name. A file called `review.md` registers the prompt `review`.
- Refresh: Prompts are loaded when a session starts. Restart Codex (or start a new session) after adding or editing files.
- Conflicts: Files whose names collide with built-in commands (like `init`) stay hidden in the slash popup, but you can still invoke them with `/prompts:<name>`.
### File format
- Body: The file contents are sent verbatim when you run the prompt (after placeholder expansion).
- Frontmatter (optional): Add YAML-style metadata at the top of the file to improve the slash popup.
```markdown
---
description: Request a concise git diff review
argument-hint: FILE=<path> [FOCUS=<section>]
---
```
- `description` shows under the entry in the popup.
- `argument-hint` (or `argument_hint`) lets you document expected inputs, though the current UI ignores this metadata.
### Placeholders and arguments
- Numeric placeholders: `$1``$9` insert the first nine positional arguments you type after the command. `$ARGUMENTS` inserts all positional arguments joined by a single space. Use `$$` to emit a literal dollar sign (Codex leaves `$$` untouched).
- Named placeholders: Tokens such as `$FILE` or `$TICKET_ID` expand from `KEY=value` pairs you supply. Keys are case-sensitive—use the same uppercase name in the command (for example, `FILE=...`).
- Quoted arguments: Double-quote any value that contains spaces, e.g. `TICKET_TITLE="Fix logging"`.
- Invocation syntax: Run prompts via `/prompts:<name> ...`. When the slash popup is open, typing either `prompts:` or the bare prompt name will surface `/prompts:<name>` suggestions.
- Error handling: If a prompt contains named placeholders, Codex requires them all. You will see a validation message if any are missing or malformed.
### Running a prompt
1. Start a new Codex session (ensures the prompt list is fresh).
2. In the composer, type `/` to open the slash popup.
3. Type `prompts:` (or start typing the prompt name) and select it with ↑/↓.
4. Provide any required arguments, press Enter, and Codex sends the expanded content.
### Examples
### Example 1: Basic named arguments
**File**: `~/.codex/prompts/ticket.md`
```markdown
---
description: Generate a commit message for a ticket
argument-hint: TICKET_ID=<id> TICKET_TITLE=<title>
---
Please write a concise commit message for ticket $TICKET_ID: $TICKET_TITLE
```
**Usage**:
```
/prompts:ticket TICKET_ID=JIRA-1234 TICKET_TITLE="Fix login bug"
```
**Expanded prompt sent to Codex**:
```
Please write a concise commit message for ticket JIRA-1234: Fix login bug
```
**Note**: Both `TICKET_ID` and `TICKET_TITLE` are required. If either is missing, Codex will show a validation error. Values with spaces must be double-quoted.
### Example 2: Mixed positional and named arguments
**File**: `~/.codex/prompts/review.md`
```markdown
---
description: Review code in a specific file with focus area
argument-hint: FILE=<path> [FOCUS=<section>]
---
Review the code in $FILE. Pay special attention to $FOCUS.
```
**Usage**:
```
/prompts:review FILE=src/auth.js FOCUS="error handling"
```
**Expanded prompt**:
```
Review the code in src/auth.js. Pay special attention to error handling.
```

2
refactor_temp.py Normal file
View File

@@ -0,0 +1,2 @@
# Refactor script for GraphAnalyzer
print("Starting refactor...")