Add comprehensive tests for ast-grep and tree-sitter relationship extraction

- Introduced test suite for AstGrepPythonProcessor covering pattern definitions, parsing, and relationship extraction.
- Added comparison tests between tree-sitter and ast-grep for consistency in relationship extraction.
- Implemented tests for ast-grep binding module to verify functionality and availability.
- Ensured tests cover various scenarios including inheritance, function calls, and imports.
This commit is contained in:
catlog22
2026-02-15 21:14:14 +08:00
parent 126a357aa2
commit 48a6a1f2aa
56 changed files with 10622 additions and 374 deletions

View File

@@ -184,11 +184,11 @@ Execution:
├─ Step 1: Initialize result tracking (previousExecutionResults = []) ├─ Step 1: Initialize result tracking (previousExecutionResults = [])
├─ Step 2: Task grouping & batch creation ├─ Step 2: Task grouping & batch creation
│ ├─ Extract explicit depends_on (no file/keyword inference) │ ├─ Extract explicit depends_on (no file/keyword inference)
│ ├─ Group: independent tasks → single parallel batch (maximize utilization) │ ├─ Group: independent tasks → per-executor parallel batches (one CLI per batch)
│ ├─ Group: dependent tasks → sequential phases (respect dependencies) │ ├─ Group: dependent tasks → sequential phases (respect dependencies)
│ └─ Create TodoWrite list for batches │ └─ Create TodoWrite list for batches
├─ Step 3: Launch execution ├─ Step 3: Launch execution
│ ├─ Phase 1: All independent tasks (⚡ single batch, concurrent) │ ├─ Phase 1: Independent tasks (⚡ per-executor batches, multi-CLI concurrent)
│ └─ Phase 2+: Dependent tasks by dependency order │ └─ Phase 2+: Dependent tasks by dependency order
├─ Step 4: Track progress (TodoWrite updates per batch) ├─ Step 4: Track progress (TodoWrite updates per batch)
└─ Step 5: Code review (if codeReviewTool ≠ "Skip") └─ Step 5: Code review (if codeReviewTool ≠ "Skip")
@@ -241,26 +241,58 @@ function extractDependencies(tasks) {
}) })
} }
// Group into batches: maximize parallel execution // Executor Resolution (used by task grouping below)
// 获取任务的 executor:优先使用 executorAssignments,fallback 到全局 executionMethod
function getTaskExecutor(task) {
const assignments = executionContext?.executorAssignments || {}
if (assignments[task.id]) {
return assignments[task.id].executor // 'gemini' | 'codex' | 'agent'
}
// Fallback: 全局 executionMethod 映射
const method = executionContext?.executionMethod || 'Auto'
if (method === 'Agent') return 'agent'
if (method === 'Codex') return 'codex'
// Auto: 根据复杂度
return planObject.complexity === 'Low' ? 'agent' : 'codex'
}
// 按 executor 分组任务(核心分组组件)
function groupTasksByExecutor(tasks) {
const groups = { gemini: [], codex: [], agent: [] }
tasks.forEach(task => {
const executor = getTaskExecutor(task)
groups[executor].push(task)
})
return groups
}
// Group into batches: per-executor parallel batches (one CLI per batch)
function createExecutionCalls(tasks, executionMethod) { function createExecutionCalls(tasks, executionMethod) {
const tasksWithDeps = extractDependencies(tasks) const tasksWithDeps = extractDependencies(tasks)
const processed = new Set() const processed = new Set()
const calls = [] const calls = []
// Phase 1: All independent tasks → single parallel batch (maximize utilization) // Phase 1: Independent tasks → per-executor batches (multi-CLI concurrent)
const independentTasks = tasksWithDeps.filter(t => t.dependencies.length === 0) const independentTasks = tasksWithDeps.filter(t => t.dependencies.length === 0)
if (independentTasks.length > 0) { if (independentTasks.length > 0) {
independentTasks.forEach(t => processed.add(t.taskIndex)) const executorGroups = groupTasksByExecutor(independentTasks)
calls.push({ let parallelIndex = 1
method: executionMethod,
executionType: "parallel", for (const [executor, tasks] of Object.entries(executorGroups)) {
groupId: "P1", if (tasks.length === 0) continue
taskSummary: independentTasks.map(t => t.title).join(' | '), tasks.forEach(t => processed.add(t.taskIndex))
tasks: independentTasks calls.push({
}) method: executionMethod,
executor: executor, // 明确指定 executor
executionType: "parallel",
groupId: `P${parallelIndex++}`,
taskSummary: tasks.map(t => t.title).join(' | '),
tasks: tasks
})
}
} }
// Phase 2: Dependent tasks → sequential batches (respect dependencies) // Phase 2: Dependent tasks → sequential/parallel batches (respect dependencies)
let sequentialIndex = 1 let sequentialIndex = 1
let remaining = tasksWithDeps.filter(t => !processed.has(t.taskIndex)) let remaining = tasksWithDeps.filter(t => !processed.has(t.taskIndex))
@@ -275,15 +307,33 @@ function createExecutionCalls(tasks, executionMethod) {
ready.push(...remaining) ready.push(...remaining)
} }
// Group ready tasks (can run in parallel within this phase) if (ready.length > 1) {
ready.forEach(t => processed.add(t.taskIndex)) // Multiple ready tasks → per-executor batches (parallel within this phase)
calls.push({ const executorGroups = groupTasksByExecutor(ready)
method: executionMethod, for (const [executor, tasks] of Object.entries(executorGroups)) {
executionType: ready.length > 1 ? "parallel" : "sequential", if (tasks.length === 0) continue
groupId: ready.length > 1 ? `P${calls.length + 1}` : `S${sequentialIndex++}`, tasks.forEach(t => processed.add(t.taskIndex))
taskSummary: ready.map(t => t.title).join(ready.length > 1 ? ' | ' : ' → '), calls.push({
tasks: ready method: executionMethod,
}) executor: executor,
executionType: "parallel",
groupId: `P${calls.length + 1}`,
taskSummary: tasks.map(t => t.title).join(' | '),
tasks: tasks
})
}
} else {
// Single ready task → sequential batch
ready.forEach(t => processed.add(t.taskIndex))
calls.push({
method: executionMethod,
executor: getTaskExecutor(ready[0]),
executionType: "sequential",
groupId: `S${sequentialIndex++}`,
taskSummary: ready[0].title,
tasks: ready
})
}
remaining = remaining.filter(t => !processed.has(t.taskIndex)) remaining = remaining.filter(t => !processed.has(t.taskIndex))
} }
@@ -304,33 +354,40 @@ TodoWrite({
### Step 3: Launch Execution ### Step 3: Launch Execution
**Executor Resolution** (任务级 executor 优先于全局设置): **Executor Resolution**: `getTaskExecutor()` and `groupTasksByExecutor()` defined in Step 2 (Task Grouping).
```javascript
// 获取任务的 executor优先使用 executorAssignmentsfallback 到全局 executionMethod
function getTaskExecutor(task) {
const assignments = executionContext?.executorAssignments || {}
if (assignments[task.id]) {
return assignments[task.id].executor // 'gemini' | 'codex' | 'agent'
}
// Fallback: 全局 executionMethod 映射
const method = executionContext?.executionMethod || 'Auto'
if (method === 'Agent') return 'agent'
if (method === 'Codex') return 'codex'
// Auto: 根据复杂度
return planObject.complexity === 'Low' ? 'agent' : 'codex'
}
// 按 executor 分组任务 **Batch Execution Routing** (根据 batch.executor 字段路由):
function groupTasksByExecutor(tasks) { ```javascript
const groups = { gemini: [], codex: [], agent: [] } // executeBatch 根据 batch 自身的 executor 字段决定调用哪个 CLI
tasks.forEach(task => { function executeBatch(batch) {
const executor = getTaskExecutor(task) const executor = batch.executor || getTaskExecutor(batch.tasks[0])
groups[executor].push(task) const sessionId = executionContext?.session?.id || 'standalone'
}) const fixedId = `${sessionId}-${batch.groupId}`
return groups
if (executor === 'agent') {
// Agent execution (synchronous)
return Task({
subagent_type: "code-developer",
run_in_background: false,
description: batch.taskSummary,
prompt: buildExecutionPrompt(batch)
})
} else if (executor === 'codex') {
// Codex CLI (background)
return Bash(`ccw cli -p "${buildExecutionPrompt(batch)}" --tool codex --mode write --id ${fixedId}`, { run_in_background: true })
} else if (executor === 'gemini') {
// Gemini CLI (background)
return Bash(`ccw cli -p "${buildExecutionPrompt(batch)}" --tool gemini --mode write --id ${fixedId}`, { run_in_background: true })
}
} }
``` ```
**并行执行原则**:
- 每个 batch 对应一个独立的 CLI 实例或 Agent 调用
- 并行 = 多个 Bash(run_in_background=true) 或多个 Task() 同时发出
- 绝不将多个独立任务合并到同一个 CLI prompt 中
- Agent 任务不可后台执行(run_in_background=false),但多个 Agent 任务可通过单条消息中的多个 Task() 调用并发
**Execution Flow**: Parallel batches concurrently → Sequential batches in order **Execution Flow**: Parallel batches concurrently → Sequential batches in order
```javascript ```javascript
const parallel = executionCalls.filter(c => c.executionType === "parallel") const parallel = executionCalls.filter(c => c.executionType === "parallel")
@@ -659,8 +716,8 @@ console.log(`✓ Development index: [${category}] ${entry.title}`)
## Best Practices ## Best Practices
**Input Modes**: In-memory (lite-plan), prompt (standalone), file (JSON/text) **Input Modes**: In-memory (lite-plan), prompt (standalone), file (JSON/text)
**Task Grouping**: Based on explicit depends_on only; independent tasks run in single parallel batch **Task Grouping**: Based on explicit depends_on only; independent tasks split by executor, each batch runs as separate CLI instance
**Execution**: All independent tasks launch concurrently via single Claude message with multiple tool calls **Execution**: Independent task batches launch concurrently via single Claude message with multiple tool calls (one tool call per batch)
## Error Handling ## Error Handling

View File

@@ -191,11 +191,11 @@ Execution:
├─ Step 1: Initialize result tracking (previousExecutionResults = []) ├─ Step 1: Initialize result tracking (previousExecutionResults = [])
├─ Step 2: Task grouping & batch creation ├─ Step 2: Task grouping & batch creation
│ ├─ Extract explicit depends_on (no file/keyword inference) │ ├─ Extract explicit depends_on (no file/keyword inference)
│ ├─ Group: independent tasks → single parallel batch (maximize utilization) │ ├─ Group: independent tasks → per-executor parallel batches (one CLI per batch)
│ ├─ Group: dependent tasks → sequential phases (respect dependencies) │ ├─ Group: dependent tasks → sequential phases (respect dependencies)
│ └─ Create TodoWrite list for batches │ └─ Create TodoWrite list for batches
├─ Step 3: Launch execution ├─ Step 3: Launch execution
│ ├─ Phase 1: All independent tasks (⚡ single batch, concurrent) │ ├─ Phase 1: Independent tasks (⚡ per-executor batches, multi-CLI concurrent)
│ └─ Phase 2+: Dependent tasks by dependency order │ └─ Phase 2+: Dependent tasks by dependency order
├─ Step 4: Track progress (TodoWrite updates per batch) ├─ Step 4: Track progress (TodoWrite updates per batch)
└─ Step 5: Code review (if codeReviewTool ≠ "Skip") └─ Step 5: Code review (if codeReviewTool ≠ "Skip")
@@ -248,26 +248,58 @@ function extractDependencies(tasks) {
}) })
} }
// Group into batches: maximize parallel execution // Executor Resolution (used by task grouping below)
// 获取任务的 executor:优先使用 executorAssignments,fallback 到全局 executionMethod
function getTaskExecutor(task) {
const assignments = executionContext?.executorAssignments || {}
if (assignments[task.id]) {
return assignments[task.id].executor // 'gemini' | 'codex' | 'agent'
}
// Fallback: 全局 executionMethod 映射
const method = executionContext?.executionMethod || 'Auto'
if (method === 'Agent') return 'agent'
if (method === 'Codex') return 'codex'
// Auto: 根据复杂度
return planObject.complexity === 'Low' ? 'agent' : 'codex'
}
// 按 executor 分组任务(核心分组组件)
function groupTasksByExecutor(tasks) {
const groups = { gemini: [], codex: [], agent: [] }
tasks.forEach(task => {
const executor = getTaskExecutor(task)
groups[executor].push(task)
})
return groups
}
// Group into batches: per-executor parallel batches (one CLI per batch)
function createExecutionCalls(tasks, executionMethod) { function createExecutionCalls(tasks, executionMethod) {
const tasksWithDeps = extractDependencies(tasks) const tasksWithDeps = extractDependencies(tasks)
const processed = new Set() const processed = new Set()
const calls = [] const calls = []
// Phase 1: All independent tasks → single parallel batch (maximize utilization) // Phase 1: Independent tasks → per-executor batches (multi-CLI concurrent)
const independentTasks = tasksWithDeps.filter(t => t.dependencies.length === 0) const independentTasks = tasksWithDeps.filter(t => t.dependencies.length === 0)
if (independentTasks.length > 0) { if (independentTasks.length > 0) {
independentTasks.forEach(t => processed.add(t.taskIndex)) const executorGroups = groupTasksByExecutor(independentTasks)
calls.push({ let parallelIndex = 1
method: executionMethod,
executionType: "parallel", for (const [executor, tasks] of Object.entries(executorGroups)) {
groupId: "P1", if (tasks.length === 0) continue
taskSummary: independentTasks.map(t => t.title).join(' | '), tasks.forEach(t => processed.add(t.taskIndex))
tasks: independentTasks calls.push({
}) method: executionMethod,
executor: executor, // 明确指定 executor
executionType: "parallel",
groupId: `P${parallelIndex++}`,
taskSummary: tasks.map(t => t.title).join(' | '),
tasks: tasks
})
}
} }
// Phase 2: Dependent tasks → sequential batches (respect dependencies) // Phase 2: Dependent tasks → sequential/parallel batches (respect dependencies)
let sequentialIndex = 1 let sequentialIndex = 1
let remaining = tasksWithDeps.filter(t => !processed.has(t.taskIndex)) let remaining = tasksWithDeps.filter(t => !processed.has(t.taskIndex))
@@ -282,15 +314,33 @@ function createExecutionCalls(tasks, executionMethod) {
ready.push(...remaining) ready.push(...remaining)
} }
// Group ready tasks (can run in parallel within this phase) if (ready.length > 1) {
ready.forEach(t => processed.add(t.taskIndex)) // Multiple ready tasks → per-executor batches (parallel within this phase)
calls.push({ const executorGroups = groupTasksByExecutor(ready)
method: executionMethod, for (const [executor, tasks] of Object.entries(executorGroups)) {
executionType: ready.length > 1 ? "parallel" : "sequential", if (tasks.length === 0) continue
groupId: ready.length > 1 ? `P${calls.length + 1}` : `S${sequentialIndex++}`, tasks.forEach(t => processed.add(t.taskIndex))
taskSummary: ready.map(t => t.title).join(ready.length > 1 ? ' | ' : ' → '), calls.push({
tasks: ready method: executionMethod,
}) executor: executor,
executionType: "parallel",
groupId: `P${calls.length + 1}`,
taskSummary: tasks.map(t => t.title).join(' | '),
tasks: tasks
})
}
} else {
// Single ready task → sequential batch
ready.forEach(t => processed.add(t.taskIndex))
calls.push({
method: executionMethod,
executor: getTaskExecutor(ready[0]),
executionType: "sequential",
groupId: `S${sequentialIndex++}`,
taskSummary: ready[0].title,
tasks: ready
})
}
remaining = remaining.filter(t => !processed.has(t.taskIndex)) remaining = remaining.filter(t => !processed.has(t.taskIndex))
} }
@@ -311,33 +361,40 @@ TodoWrite({
### Step 3: Launch Execution ### Step 3: Launch Execution
**Executor Resolution** (任务级 executor 优先于全局设置): **Executor Resolution**: `getTaskExecutor()` and `groupTasksByExecutor()` defined in Step 2 (Task Grouping).
```javascript
// 获取任务的 executor优先使用 executorAssignmentsfallback 到全局 executionMethod
function getTaskExecutor(task) {
const assignments = executionContext?.executorAssignments || {}
if (assignments[task.id]) {
return assignments[task.id].executor // 'gemini' | 'codex' | 'agent'
}
// Fallback: 全局 executionMethod 映射
const method = executionContext?.executionMethod || 'Auto'
if (method === 'Agent') return 'agent'
if (method === 'Codex') return 'codex'
// Auto: 根据复杂度
return planObject.complexity === 'Low' ? 'agent' : 'codex'
}
// 按 executor 分组任务 **Batch Execution Routing** (根据 batch.executor 字段路由):
function groupTasksByExecutor(tasks) { ```javascript
const groups = { gemini: [], codex: [], agent: [] } // executeBatch 根据 batch 自身的 executor 字段决定调用哪个 CLI
tasks.forEach(task => { function executeBatch(batch) {
const executor = getTaskExecutor(task) const executor = batch.executor || getTaskExecutor(batch.tasks[0])
groups[executor].push(task) const sessionId = executionContext?.session?.id || 'standalone'
}) const fixedId = `${sessionId}-${batch.groupId}`
return groups
if (executor === 'agent') {
// Agent execution (synchronous)
return Task({
subagent_type: "code-developer",
run_in_background: false,
description: batch.taskSummary,
prompt: buildExecutionPrompt(batch)
})
} else if (executor === 'codex') {
// Codex CLI (background)
return Bash(`ccw cli -p "${buildExecutionPrompt(batch)}" --tool codex --mode write --id ${fixedId}`, { run_in_background: true })
} else if (executor === 'gemini') {
// Gemini CLI (background)
return Bash(`ccw cli -p "${buildExecutionPrompt(batch)}" --tool gemini --mode write --id ${fixedId}`, { run_in_background: true })
}
} }
``` ```
**并行执行原则**:
- 每个 batch 对应一个独立的 CLI 实例或 Agent 调用
- 并行 = 多个 Bash(run_in_background=true) 或多个 Task() 同时发出
- 绝不将多个独立任务合并到同一个 CLI prompt 中
- Agent 任务不可后台执行(run_in_background=false),但多个 Agent 任务可通过单条消息中的多个 Task() 调用并发
**Execution Flow**: Parallel batches concurrently → Sequential batches in order **Execution Flow**: Parallel batches concurrently → Sequential batches in order
```javascript ```javascript
const parallel = executionCalls.filter(c => c.executionType === "parallel") const parallel = executionCalls.filter(c => c.executionType === "parallel")
@@ -666,8 +723,8 @@ console.log(`✓ Development index: [${category}] ${entry.title}`)
## Best Practices ## Best Practices
**Input Modes**: In-memory (lite-plan), prompt (standalone), file (JSON/text) **Input Modes**: In-memory (lite-plan), prompt (standalone), file (JSON/text)
**Task Grouping**: Based on explicit depends_on only; independent tasks run in single parallel batch **Task Grouping**: Based on explicit depends_on only; independent tasks split by executor, each batch runs as separate CLI instance
**Execution**: All independent tasks launch concurrently via single Claude message with multiple tool calls **Execution**: Independent task batches launch concurrently via single Claude message with multiple tool calls (one tool call per batch)
## Error Handling ## Error Handling

View File

@@ -704,7 +704,7 @@ function WorkflowTaskWidgetComponent({ className }: WorkflowTaskWidgetProps) {
const isLastOdd = currentSession.tasks!.length % 2 === 1 && index === currentSession.tasks!.length - 1; const isLastOdd = currentSession.tasks!.length % 2 === 1 && index === currentSession.tasks!.length - 1;
return ( return (
<div <div
key={task.task_id} key={`${currentSession.session_id}-${task.task_id}`}
className={cn( className={cn(
'flex items-center gap-2 p-2 rounded hover:bg-background/50 transition-colors', 'flex items-center gap-2 p-2 rounded hover:bg-background/50 transition-colors',
isLastOdd && 'col-span-2' isLastOdd && 'col-span-2'

View File

@@ -0,0 +1,396 @@
// ========================================
// Platform Configuration Cards
// ========================================
// Individual configuration cards for each notification platform
import { useState } from 'react';
import { useIntl } from 'react-intl';
import {
MessageCircle,
Send,
Link,
Check,
X,
ChevronDown,
ChevronUp,
TestTube,
Eye,
EyeOff,
} from 'lucide-react';
import { Card } from '@/components/ui/Card';
import { Button } from '@/components/ui/Button';
import { Input } from '@/components/ui/Input';
import { Badge } from '@/components/ui/Badge';
import { cn } from '@/lib/utils';
import type {
RemoteNotificationConfig,
NotificationPlatform,
DiscordConfig,
TelegramConfig,
WebhookConfig,
} from '@/types/remote-notification';
import { PLATFORM_INFO } from '@/types/remote-notification';
/**
 * Props for PlatformConfigCards. The component is fully controlled:
 * all state (config values, which card is expanded, which platform is
 * being tested) lives in the parent and flows down through these props.
 */
interface PlatformConfigCardsProps {
  // Current remote-notification configuration (all platforms + events).
  config: RemoteNotificationConfig;
  // Which platform card is currently expanded, or null if all collapsed.
  expandedPlatform: NotificationPlatform | null;
  // Platform whose "test connection" request is in flight, or null.
  testing: NotificationPlatform | null;
  // Expand the given platform's card, or collapse (null).
  onToggleExpand: (platform: NotificationPlatform | null) => void;
  // Merge a partial update into one platform's config.
  onUpdateConfig: (
    platform: NotificationPlatform,
    updates: Partial<DiscordConfig | TelegramConfig | WebhookConfig>
  ) => void;
  // Send a test notification using the given platform config.
  onTest: (
    platform: NotificationPlatform,
    config: DiscordConfig | TelegramConfig | WebhookConfig
  ) => void;
  // Persist the whole configuration.
  onSave: () => void;
  // True while a save request is in flight (disables the Save button).
  saving: boolean;
}
/**
 * Grid of per-platform configuration cards (Discord / Telegram / Webhook).
 *
 * Each card shows an icon, a "configured" badge once required credentials
 * are present, an enable/disable toggle, and — when expanded — the
 * platform-specific form plus Test/Save actions. Fully controlled via props.
 */
export function PlatformConfigCards({
  config,
  expandedPlatform,
  testing,
  onToggleExpand,
  onUpdateConfig,
  onTest,
  onSave,
  saving,
}: PlatformConfigCardsProps) {
  const { formatMessage } = useIntl();

  // Fixed render order of the cards.
  const platforms: NotificationPlatform[] = ['discord', 'telegram', 'webhook'];

  // Header icon for each platform.
  const getPlatformIcon = (platform: NotificationPlatform) => {
    switch (platform) {
      case 'discord':
        return <MessageCircle className="w-4 h-4" />;
      case 'telegram':
        return <Send className="w-4 h-4" />;
      case 'webhook':
        return <Link className="w-4 h-4" />;
    }
  };

  // Current per-platform config, falling back to a disabled empty default so
  // the forms below always receive a well-formed object.
  const getPlatformConfig = (
    platform: NotificationPlatform
  ): DiscordConfig | TelegramConfig | WebhookConfig => {
    switch (platform) {
      case 'discord':
        return config.platforms.discord || { enabled: false, webhookUrl: '' };
      case 'telegram':
        return config.platforms.telegram || { enabled: false, botToken: '', chatId: '' };
      case 'webhook':
        return config.platforms.webhook || { enabled: false, url: '', method: 'POST' };
    }
  };

  // A platform counts as "configured" once its required credentials are set
  // (Telegram needs both the bot token and the chat id).
  const isConfigured = (platform: NotificationPlatform): boolean => {
    const platformConfig = getPlatformConfig(platform);
    switch (platform) {
      case 'discord':
        return !!(platformConfig as DiscordConfig).webhookUrl;
      case 'telegram':
        return !!(platformConfig as TelegramConfig).botToken && !!(platformConfig as TelegramConfig).chatId;
      case 'webhook':
        return !!(platformConfig as WebhookConfig).url;
    }
  };

  return (
    <div className="grid gap-3">
      {platforms.map((platform) => {
        const info = PLATFORM_INFO[platform];
        const platformConfig = getPlatformConfig(platform);
        const configured = isConfigured(platform);
        const expanded = expandedPlatform === platform;
        return (
          <Card key={platform} className="overflow-hidden">
            {/* Header: icon, name, configured badge, enable toggle, chevron */}
            <div
              className="p-4 cursor-pointer hover:bg-muted/50 transition-colors"
              onClick={() => onToggleExpand(expanded ? null : platform)}
            >
              <div className="flex items-center justify-between">
                <div className="flex items-center gap-3">
                  <div className={cn(
                    'p-2 rounded-lg',
                    platformConfig.enabled && configured
                      ? 'bg-primary/10 text-primary'
                      : 'bg-muted text-muted-foreground'
                  )}>
                    {getPlatformIcon(platform)}
                  </div>
                  <div>
                    <div className="flex items-center gap-2">
                      <span className="text-sm font-medium">{info.name}</span>
                      {configured && (
                        <Badge variant="outline" className="text-xs text-green-600 border-green-500/30">
                          <Check className="w-3 h-3 mr-1" />
                          {formatMessage({ id: 'settings.remoteNotifications.configured' })}
                        </Badge>
                      )}
                    </div>
                    <p className="text-xs text-muted-foreground mt-0.5">{info.description}</p>
                  </div>
                </div>
                <div className="flex items-center gap-2">
                  {/* Enable/disable toggle; stopPropagation keeps the click
                      from also expanding/collapsing the card. */}
                  <Button
                    variant={platformConfig.enabled ? 'default' : 'outline'}
                    size="sm"
                    className="h-7"
                    onClick={(e) => {
                      e.stopPropagation();
                      onUpdateConfig(platform, { enabled: !platformConfig.enabled });
                    }}
                  >
                    {platformConfig.enabled ? (
                      <Check className="w-3.5 h-3.5" />
                    ) : (
                      <X className="w-3.5 h-3.5" />
                    )}
                  </Button>
                  {expanded ? (
                    <ChevronUp className="w-4 h-4 text-muted-foreground" />
                  ) : (
                    <ChevronDown className="w-4 h-4 text-muted-foreground" />
                  )}
                </div>
              </div>
            </div>
            {/* Expanded Content: platform-specific form + Test/Save actions */}
            {expanded && (
              <div className="border-t border-border p-4 space-y-4 bg-muted/30">
                {platform === 'discord' && (
                  <DiscordConfigForm
                    config={platformConfig as DiscordConfig}
                    onUpdate={(updates) => onUpdateConfig('discord', updates)}
                  />
                )}
                {platform === 'telegram' && (
                  <TelegramConfigForm
                    config={platformConfig as TelegramConfig}
                    onUpdate={(updates) => onUpdateConfig('telegram', updates)}
                  />
                )}
                {platform === 'webhook' && (
                  <WebhookConfigForm
                    config={platformConfig as WebhookConfig}
                    onUpdate={(updates) => onUpdateConfig('webhook', updates)}
                  />
                )}
                {/* Action Buttons — Test is disabled until required fields are set */}
                <div className="flex items-center gap-2 pt-2">
                  <Button
                    variant="outline"
                    size="sm"
                    onClick={() => onTest(platform, platformConfig)}
                    disabled={testing === platform || !configured}
                  >
                    <TestTube className={cn('w-3.5 h-3.5 mr-1', testing === platform && 'animate-pulse')} />
                    {formatMessage({ id: 'settings.remoteNotifications.testConnection' })}
                  </Button>
                  <Button
                    variant="default"
                    size="sm"
                    onClick={onSave}
                    disabled={saving}
                  >
                    {formatMessage({ id: 'settings.remoteNotifications.save' })}
                  </Button>
                </div>
              </div>
            )}
          </Card>
        );
      })}
    </div>
  );
}
// ========== Discord Config Form ==========
// Fields for a Discord incoming-webhook target: the webhook URL (masked by
// default, with an eye toggle) and an optional bot display name.
function DiscordConfigForm({
  config,
  onUpdate,
}: {
  config: DiscordConfig;
  onUpdate: (updates: Partial<DiscordConfig>) => void;
}) {
  const { formatMessage } = useIntl();
  const [urlVisible, setUrlVisible] = useState(false);

  // Flip between masked (password) and plain-text display of the URL.
  const toggleUrlVisibility = () => setUrlVisible((prev) => !prev);

  return (
    <div className="space-y-3">
      <div>
        <label className="text-sm font-medium text-foreground">
          {formatMessage({ id: 'settings.remoteNotifications.discord.webhookUrl' })}
        </label>
        <div className="flex gap-2 mt-1">
          <Input
            type={urlVisible ? 'text' : 'password'}
            value={config.webhookUrl || ''}
            onChange={(e) => onUpdate({ webhookUrl: e.target.value })}
            placeholder="https://discord.com/api/webhooks/..."
            className="flex-1"
          />
          <Button
            variant="outline"
            size="sm"
            className="shrink-0"
            onClick={toggleUrlVisibility}
          >
            {urlVisible ? <EyeOff className="w-4 h-4" /> : <Eye className="w-4 h-4" />}
          </Button>
        </div>
        <p className="text-xs text-muted-foreground mt-1">
          {formatMessage({ id: 'settings.remoteNotifications.discord.webhookUrlHint' })}
        </p>
      </div>
      <div>
        <label className="text-sm font-medium text-foreground">
          {formatMessage({ id: 'settings.remoteNotifications.discord.username' })}
        </label>
        <Input
          value={config.username || ''}
          onChange={(e) => onUpdate({ username: e.target.value })}
          placeholder="CCW Notification"
          className="mt-1"
        />
      </div>
    </div>
  );
}
// ========== Telegram Config Form ==========
// Fields for a Telegram bot target: the bot token (masked by default, with
// an eye toggle) and the destination chat id.
function TelegramConfigForm({
  config,
  onUpdate,
}: {
  config: TelegramConfig;
  onUpdate: (updates: Partial<TelegramConfig>) => void;
}) {
  const { formatMessage } = useIntl();
  const [tokenVisible, setTokenVisible] = useState(false);

  // Flip between masked (password) and plain-text display of the token.
  const toggleTokenVisibility = () => setTokenVisible((prev) => !prev);

  return (
    <div className="space-y-3">
      <div>
        <label className="text-sm font-medium text-foreground">
          {formatMessage({ id: 'settings.remoteNotifications.telegram.botToken' })}
        </label>
        <div className="flex gap-2 mt-1">
          <Input
            type={tokenVisible ? 'text' : 'password'}
            value={config.botToken || ''}
            onChange={(e) => onUpdate({ botToken: e.target.value })}
            placeholder="1234567890:ABCdefGHIjklMNOpqrsTUVwxyz"
            className="flex-1"
          />
          <Button
            variant="outline"
            size="sm"
            className="shrink-0"
            onClick={toggleTokenVisibility}
          >
            {tokenVisible ? <EyeOff className="w-4 h-4" /> : <Eye className="w-4 h-4" />}
          </Button>
        </div>
        <p className="text-xs text-muted-foreground mt-1">
          {formatMessage({ id: 'settings.remoteNotifications.telegram.botTokenHint' })}
        </p>
      </div>
      <div>
        <label className="text-sm font-medium text-foreground">
          {formatMessage({ id: 'settings.remoteNotifications.telegram.chatId' })}
        </label>
        <Input
          value={config.chatId || ''}
          onChange={(e) => onUpdate({ chatId: e.target.value })}
          placeholder="-1001234567890"
          className="mt-1"
        />
        <p className="text-xs text-muted-foreground mt-1">
          {formatMessage({ id: 'settings.remoteNotifications.telegram.chatIdHint' })}
        </p>
      </div>
    </div>
  );
}
// ========== Webhook Config Form ==========
// Fields for a generic webhook target: URL, HTTP method (POST/PUT), and
// optional JSON headers.
//
// Bug fix: the headers input used to be controlled directly by
// JSON.stringify(config.headers), while onChange only propagated values that
// parsed as JSON. Any intermediate keystroke (e.g. a lone "{") failed to
// parse, was never pushed into config, and the controlled value snapped back
// to the previous text — making the field effectively impossible to type in.
// We now keep the raw text in local draft state and only sync parsed headers
// upward once the JSON is valid.
function WebhookConfigForm({
  config,
  onUpdate,
}: {
  config: WebhookConfig;
  onUpdate: (updates: Partial<WebhookConfig>) => void;
}) {
  const { formatMessage } = useIntl();
  // Raw text of the headers field. Initialized from config once; subsequent
  // external changes to config.headers are not re-synced into the draft.
  const [headersText, setHeadersText] = useState(
    config.headers ? JSON.stringify(config.headers) : ''
  );

  return (
    <div className="space-y-3">
      <div>
        <label className="text-sm font-medium text-foreground">
          {formatMessage({ id: 'settings.remoteNotifications.webhook.url' })}
        </label>
        <Input
          value={config.url || ''}
          onChange={(e) => onUpdate({ url: e.target.value })}
          placeholder="https://your-server.com/webhook"
          className="mt-1"
        />
      </div>
      <div>
        <label className="text-sm font-medium text-foreground">
          {formatMessage({ id: 'settings.remoteNotifications.webhook.method' })}
        </label>
        <div className="flex gap-2 mt-1">
          <Button
            variant={config.method === 'POST' ? 'default' : 'outline'}
            size="sm"
            onClick={() => onUpdate({ method: 'POST' })}
          >
            POST
          </Button>
          <Button
            variant={config.method === 'PUT' ? 'default' : 'outline'}
            size="sm"
            onClick={() => onUpdate({ method: 'PUT' })}
          >
            PUT
          </Button>
        </div>
      </div>
      <div>
        <label className="text-sm font-medium text-foreground">
          {formatMessage({ id: 'settings.remoteNotifications.webhook.headers' })}
        </label>
        <Input
          value={headersText}
          onChange={(e) => {
            const text = e.target.value;
            setHeadersText(text);
            try {
              // Empty text clears the headers; otherwise require valid JSON.
              const headers = text ? JSON.parse(text) : undefined;
              onUpdate({ headers });
            } catch {
              // Incomplete/invalid JSON while typing — keep the draft text,
              // don't push a broken value into config.
            }
          }}
          placeholder='{"Authorization": "Bearer token"}'
          className="mt-1 font-mono text-xs"
        />
        <p className="text-xs text-muted-foreground mt-1">
          {formatMessage({ id: 'settings.remoteNotifications.webhook.headersHint' })}
        </p>
      </div>
    </div>
  );
}

export default PlatformConfigCards;

View File

@@ -0,0 +1,347 @@
// ========================================
// Remote Notification Settings Section
// ========================================
// Configuration UI for remote notification platforms
import { useState, useEffect, useCallback } from 'react';
import { useIntl } from 'react-intl';
import {
Bell,
BellOff,
RefreshCw,
Check,
X,
ChevronDown,
ChevronUp,
TestTube,
Save,
AlertTriangle,
} from 'lucide-react';
import { Card } from '@/components/ui/Card';
import { Button } from '@/components/ui/Button';
import { Input } from '@/components/ui/Input';
import { Badge } from '@/components/ui/Badge';
import { cn } from '@/lib/utils';
import { toast } from 'sonner';
import type {
RemoteNotificationConfig,
NotificationPlatform,
EventConfig,
DiscordConfig,
TelegramConfig,
WebhookConfig,
} from '@/types/remote-notification';
import { PLATFORM_INFO, EVENT_INFO, getDefaultConfig } from '@/types/remote-notification';
import { PlatformConfigCards } from './PlatformConfigCards';
// Props for RemoteNotificationSection.
interface RemoteNotificationSectionProps {
  // Optional extra classes merged into the root Card.
  className?: string;
}
export function RemoteNotificationSection({ className }: RemoteNotificationSectionProps) {
const { formatMessage } = useIntl();
const [config, setConfig] = useState<RemoteNotificationConfig | null>(null);
const [loading, setLoading] = useState(true);
const [saving, setSaving] = useState(false);
const [testing, setTesting] = useState<NotificationPlatform | null>(null);
const [expandedPlatform, setExpandedPlatform] = useState<NotificationPlatform | null>(null);
// Load configuration
const loadConfig = useCallback(async () => {
setLoading(true);
try {
const response = await fetch('/api/notifications/remote/config');
if (response.ok) {
const data = await response.json();
setConfig(data);
} else {
// Use default config if not found
setConfig(getDefaultConfig());
}
} catch (error) {
console.error('Failed to load remote notification config:', error);
setConfig(getDefaultConfig());
} finally {
setLoading(false);
}
}, []);
useEffect(() => {
loadConfig();
}, [loadConfig]);
// Save configuration
const saveConfig = useCallback(async (newConfig: RemoteNotificationConfig) => {
setSaving(true);
try {
const response = await fetch('/api/notifications/remote/config', {
method: 'POST',
headers: { 'Content-Type': 'application/json' },
body: JSON.stringify(newConfig),
});
if (response.ok) {
const data = await response.json();
setConfig(data.config);
toast.success(formatMessage({ id: 'settings.remoteNotifications.saved' }));
} else {
throw new Error(`HTTP ${response.status}`);
}
} catch (error) {
toast.error(formatMessage({ id: 'settings.remoteNotifications.saveError' }));
} finally {
setSaving(false);
}
}, [formatMessage]);
// Test platform
const testPlatform = useCallback(async (
platform: NotificationPlatform,
platformConfig: DiscordConfig | TelegramConfig | WebhookConfig
) => {
setTesting(platform);
try {
const response = await fetch('/api/notifications/remote/test', {
method: 'POST',
headers: { 'Content-Type': 'application/json' },
body: JSON.stringify({ platform, config: platformConfig }),
});
const result = await response.json();
if (result.success) {
toast.success(
formatMessage({ id: 'settings.remoteNotifications.testSuccess' }),
{ description: `${result.responseTime}ms` }
);
} else {
toast.error(
formatMessage({ id: 'settings.remoteNotifications.testFailed' }),
{ description: result.error }
);
}
} catch (error) {
toast.error(formatMessage({ id: 'settings.remoteNotifications.testError' }));
} finally {
setTesting(null);
}
}, [formatMessage]);
// Toggle master switch
const toggleEnabled = () => {
if (!config) return;
saveConfig({ ...config, enabled: !config.enabled });
};
// Update platform config
const updatePlatformConfig = (
platform: NotificationPlatform,
updates: Partial<DiscordConfig | TelegramConfig | WebhookConfig>
) => {
if (!config) return;
const newConfig = {
...config,
platforms: {
...config.platforms,
[platform]: {
...config.platforms[platform as keyof typeof config.platforms],
...updates,
},
},
};
setConfig(newConfig);
};
// Update event config
const updateEventConfig = (eventIndex: number, updates: Partial<EventConfig>) => {
if (!config) return;
const newEvents = [...config.events];
newEvents[eventIndex] = { ...newEvents[eventIndex], ...updates };
setConfig({ ...config, events: newEvents });
};
// Reset to defaults
const resetConfig = async () => {
if (!confirm(formatMessage({ id: 'settings.remoteNotifications.resetConfirm' }))) {
return;
}
try {
const response = await fetch('/api/notifications/remote/reset', {
method: 'POST',
});
if (response.ok) {
const data = await response.json();
setConfig(data.config);
toast.success(formatMessage({ id: 'settings.remoteNotifications.resetSuccess' }));
}
} catch {
toast.error(formatMessage({ id: 'settings.remoteNotifications.resetError' }));
}
};
if (loading) {
return (
<Card className={cn('p-6', className)}>
<div className="flex items-center justify-center py-8">
<RefreshCw className="w-5 h-5 animate-spin text-muted-foreground" />
</div>
</Card>
);
}
if (!config) {
return null;
}
return (
<Card className={cn('p-6', className)}>
{/* Header */}
<div className="flex items-center justify-between mb-6">
<h2 className="text-lg font-semibold text-foreground flex items-center gap-2">
{config.enabled ? (
<Bell className="w-5 h-5 text-primary" />
) : (
<BellOff className="w-5 h-5 text-muted-foreground" />
)}
{formatMessage({ id: 'settings.remoteNotifications.title' })}
</h2>
<div className="flex items-center gap-2">
<Button
variant="outline"
size="sm"
onClick={() => loadConfig()}
disabled={loading}
>
<RefreshCw className={cn('w-3.5 h-3.5', loading && 'animate-spin')} />
</Button>
<Button
variant={config.enabled ? 'default' : 'outline'}
size="sm"
onClick={toggleEnabled}
>
{config.enabled ? (
<>
<Check className="w-4 h-4 mr-1" />
{formatMessage({ id: 'settings.remoteNotifications.enabled' })}
</>
) : (
<>
<X className="w-4 h-4 mr-1" />
{formatMessage({ id: 'settings.remoteNotifications.disabled' })}
</>
)}
</Button>
</div>
</div>
{/* Description */}
<p className="text-sm text-muted-foreground mb-6">
{formatMessage({ id: 'settings.remoteNotifications.description' })}
</p>
{config.enabled && (
<>
{/* Platform Configuration */}
<div className="space-y-4 mb-6">
<h3 className="text-sm font-medium text-foreground">
{formatMessage({ id: 'settings.remoteNotifications.platforms' })}
</h3>
<PlatformConfigCards
config={config}
expandedPlatform={expandedPlatform}
testing={testing}
onToggleExpand={setExpandedPlatform}
onUpdateConfig={updatePlatformConfig}
onTest={testPlatform}
onSave={() => saveConfig(config)}
saving={saving}
/>
</div>
{/* Event Configuration */}
<div className="space-y-4">
<h3 className="text-sm font-medium text-foreground">
{formatMessage({ id: 'settings.remoteNotifications.events' })}
</h3>
<div className="grid gap-3">
{config.events.map((eventConfig, index) => {
const info = EVENT_INFO[eventConfig.event];
return (
<div
key={eventConfig.event}
className="flex items-center justify-between p-3 rounded-lg border border-border bg-muted/30"
>
<div className="flex items-center gap-3">
<div className={cn(
'p-2 rounded-lg',
eventConfig.enabled ? 'bg-primary/10 text-primary' : 'bg-muted text-muted-foreground'
)}>
<span className="text-sm">{info.icon}</span>
</div>
<div>
<p className="text-sm font-medium">{info.name}</p>
<p className="text-xs text-muted-foreground">{info.description}</p>
</div>
</div>
<div className="flex items-center gap-2">
{/* Platform badges */}
<div className="flex gap-1">
{eventConfig.platforms.map((platform) => (
<Badge key={platform} variant="secondary" className="text-xs">
{PLATFORM_INFO[platform].name}
</Badge>
))}
{eventConfig.platforms.length === 0 && (
<Badge variant="outline" className="text-xs text-muted-foreground">
{formatMessage({ id: 'settings.remoteNotifications.noPlatforms' })}
</Badge>
)}
</div>
{/* Toggle */}
<Button
variant={eventConfig.enabled ? 'default' : 'outline'}
size="sm"
className="h-7"
onClick={() => updateEventConfig(index, { enabled: !eventConfig.enabled })}
>
{eventConfig.enabled ? (
<Check className="w-3.5 h-3.5" />
) : (
<X className="w-3.5 h-3.5" />
)}
</Button>
</div>
</div>
);
})}
</div>
</div>
{/* Action Buttons */}
<div className="flex items-center justify-between mt-6 pt-4 border-t border-border">
<Button
variant="outline"
size="sm"
onClick={resetConfig}
>
{formatMessage({ id: 'settings.remoteNotifications.reset' })}
</Button>
<Button
variant="default"
size="sm"
onClick={() => saveConfig(config)}
disabled={saving}
>
<Save className="w-4 h-4 mr-1" />
{saving
? formatMessage({ id: 'settings.remoteNotifications.saving' })
: formatMessage({ id: 'settings.remoteNotifications.save' })}
</Button>
</div>
</>
)}
</Card>
);
}
export default RemoteNotificationSection;

View File

@@ -0,0 +1,270 @@
// ========================================
// CliConfigModal Component
// ========================================
// Config modal for creating a new CLI session in Terminal Dashboard.
import * as React from 'react';
import { useIntl } from 'react-intl';
import { FolderOpen } from 'lucide-react';
import { cn } from '@/lib/utils';
import { Button } from '@/components/ui/Button';
import { Input } from '@/components/ui/Input';
import { Label } from '@/components/ui/Label';
import {
Dialog,
DialogContent,
DialogHeader,
DialogTitle,
DialogDescription,
DialogFooter,
} from '@/components/ui/Dialog';
import {
Select,
SelectContent,
SelectItem,
SelectTrigger,
SelectValue,
} from '@/components/ui/Select';
import { RadioGroup, RadioGroupItem } from '@/components/ui/RadioGroup';
// Supported CLI tools selectable in the modal.
export type CliTool = 'claude' | 'gemini' | 'qwen' | 'codex' | 'opencode';
// 'default' keeps each tool's normal approval flow; 'yolo' is the
// auto-approve launch mode.
export type LaunchMode = 'default' | 'yolo';
// Shell that will host the CLI session.
export type ShellKind = 'bash' | 'pwsh';
/** Full session configuration collected by the modal and handed to the creator. */
export interface CliSessionConfig {
  tool: CliTool;
  model?: string; // undefined = "Auto": let the tool choose its own default
  launchMode: LaunchMode;
  preferredShell: ShellKind;
  workingDir: string;
}
/** Props for CliConfigModal. */
export interface CliConfigModalProps {
  isOpen: boolean;
  onClose: () => void;
  defaultWorkingDir?: string | null; // pre-fills the working-directory input each time the modal opens
  onCreateSession: (config: CliSessionConfig) => Promise<void>;
}
// Display order for the tool selector.
const CLI_TOOLS: CliTool[] = ['claude', 'gemini', 'qwen', 'codex', 'opencode'];
// Model choices per tool; the first entry is the initial selection for that tool.
const MODEL_OPTIONS: Record<CliTool, string[]> = {
  claude: ['sonnet', 'haiku'],
  gemini: ['gemini-2.5-pro', 'gemini-2.5-flash'],
  qwen: ['coder-model'],
  codex: ['gpt-5.2'],
  opencode: ['opencode/glm-4.7-free'],
};
// Sentinel <Select> value standing in for "no explicit model" (model = undefined),
// since Select values must be non-empty strings.
const AUTO_MODEL_VALUE = '__auto__';
/**
 * Modal form for configuring and creating a new CLI session: tool, model,
 * launch mode, shell, and working directory. Delegates actual session
 * creation to the onCreateSession prop and closes itself on success.
 */
export function CliConfigModal({
  isOpen,
  onClose,
  defaultWorkingDir,
  onCreateSession,
}: CliConfigModalProps) {
  const { formatMessage } = useIntl();
  // Form state. Initial defaults: gemini, its first model, yolo mode, bash.
  const [tool, setTool] = React.useState<CliTool>('gemini');
  const [model, setModel] = React.useState<string | undefined>(MODEL_OPTIONS.gemini[0]);
  const [launchMode, setLaunchMode] = React.useState<LaunchMode>('yolo');
  const [preferredShell, setPreferredShell] = React.useState<ShellKind>('bash');
  const [workingDir, setWorkingDir] = React.useState<string>(defaultWorkingDir ?? '');
  const [isSubmitting, setIsSubmitting] = React.useState(false);
  const [error, setError] = React.useState<string | null>(null);
  // Models offered for the currently selected tool.
  const modelOptions = React.useMemo(() => MODEL_OPTIONS[tool] ?? [], [tool]);
  React.useEffect(() => {
    if (!isOpen) return;
    // Reset to a safe default each time the modal is opened.
    // Note: only workingDir and error reset here; tool/model/mode persist
    // across opens.
    const nextWorkingDir = defaultWorkingDir ?? '';
    setWorkingDir(nextWorkingDir);
    setError(null);
  }, [isOpen, defaultWorkingDir]);
  // Keep the current model when the new tool also offers it; otherwise fall
  // back to the new tool's first model.
  const handleToolChange = (nextTool: string) => {
    const next = nextTool as CliTool;
    setTool(next);
    const nextModels = MODEL_OPTIONS[next] ?? [];
    if (!model || !nextModels.includes(model)) {
      setModel(nextModels[0]);
    }
  };
  const handleBrowse = () => {
    // Reserved for future file-picker integration
    console.log('[CliConfigModal] browse working directory - not implemented');
  };
  // Validate, then hand the assembled config to the caller. The modal only
  // closes on success; failures keep it open with an inline error.
  const handleCreate = async () => {
    const dir = workingDir.trim();
    if (!dir) {
      setError(formatMessage({ id: 'terminalDashboard.cliConfig.errors.workingDirRequired' }));
      return;
    }
    setIsSubmitting(true);
    setError(null);
    try {
      await onCreateSession({
        tool,
        model,
        launchMode,
        preferredShell,
        workingDir: dir,
      });
      onClose();
    } catch (err) {
      console.error('[CliConfigModal] create session failed:', err);
      setError(formatMessage({ id: 'terminalDashboard.cliConfig.errors.createFailed' }));
    } finally {
      setIsSubmitting(false);
    }
  };
  return (
    <Dialog open={isOpen} onOpenChange={(open) => !open && onClose()}>
      <DialogContent className="sm:max-w-[720px] max-h-[90vh] overflow-y-auto">
        <DialogHeader>
          <DialogTitle>{formatMessage({ id: 'terminalDashboard.cliConfig.title' })}</DialogTitle>
          <DialogDescription>
            {formatMessage({ id: 'terminalDashboard.cliConfig.description' })}
          </DialogDescription>
        </DialogHeader>
        <div className="space-y-4 py-4">
          <div className="grid grid-cols-1 sm:grid-cols-2 gap-4">
            {/* Tool */}
            <div className="space-y-2">
              <Label htmlFor="cli-config-tool">
                {formatMessage({ id: 'terminalDashboard.cliConfig.tool' })}
              </Label>
              <Select value={tool} onValueChange={handleToolChange} disabled={isSubmitting}>
                <SelectTrigger id="cli-config-tool">
                  <SelectValue />
                </SelectTrigger>
                <SelectContent>
                  {CLI_TOOLS.map((t) => (
                    <SelectItem key={t} value={t}>
                      {t}
                    </SelectItem>
                  ))}
                </SelectContent>
              </Select>
            </div>
            {/* Model */}
            <div className="space-y-2">
              <Label htmlFor="cli-config-model">
                {formatMessage({ id: 'terminalDashboard.cliConfig.model' })}
              </Label>
              <Select
                value={model ?? AUTO_MODEL_VALUE}
                onValueChange={(v) => setModel(v === AUTO_MODEL_VALUE ? undefined : v)}
                disabled={isSubmitting}
              >
                <SelectTrigger id="cli-config-model">
                  <SelectValue />
                </SelectTrigger>
                <SelectContent>
                  <SelectItem value={AUTO_MODEL_VALUE}>
                    {formatMessage({ id: 'terminalDashboard.cliConfig.modelAuto' })}
                  </SelectItem>
                  {modelOptions.map((m) => (
                    <SelectItem key={m} value={m}>
                      {m}
                    </SelectItem>
                  ))}
                </SelectContent>
              </Select>
            </div>
          </div>
          {/* Mode */}
          <div className="space-y-2">
            <Label>{formatMessage({ id: 'terminalDashboard.cliConfig.mode' })}</Label>
            <RadioGroup
              value={launchMode}
              onValueChange={(v) => setLaunchMode(v as LaunchMode)}
              className="flex items-center gap-4"
            >
              <label className="flex items-center gap-2 text-sm cursor-pointer">
                <RadioGroupItem value="default" />
                {formatMessage({ id: 'terminalDashboard.cliConfig.modeDefault' })}
              </label>
              <label className="flex items-center gap-2 text-sm cursor-pointer">
                <RadioGroupItem value="yolo" />
                {formatMessage({ id: 'terminalDashboard.cliConfig.modeYolo' })}
              </label>
            </RadioGroup>
          </div>
          {/* Shell */}
          <div className="space-y-2">
            <Label htmlFor="cli-config-shell">
              {formatMessage({ id: 'terminalDashboard.cliConfig.shell' })}
            </Label>
            <Select
              value={preferredShell}
              onValueChange={(v) => setPreferredShell(v as ShellKind)}
              disabled={isSubmitting}
            >
              <SelectTrigger id="cli-config-shell">
                <SelectValue />
              </SelectTrigger>
              <SelectContent>
                <SelectItem value="bash">bash</SelectItem>
                <SelectItem value="pwsh">pwsh</SelectItem>
              </SelectContent>
            </Select>
          </div>
          {/* Working Directory */}
          <div className="space-y-2">
            <Label htmlFor="cli-config-workingDir">
              {formatMessage({ id: 'terminalDashboard.cliConfig.workingDir' })}
            </Label>
            <div className="flex gap-2">
              <Input
                id="cli-config-workingDir"
                value={workingDir}
                onChange={(e) => {
                  setWorkingDir(e.target.value);
                  if (error) setError(null);
                }}
                placeholder={formatMessage({ id: 'terminalDashboard.cliConfig.workingDirPlaceholder' })}
                disabled={isSubmitting}
                className={cn(error && 'border-destructive')}
              />
              <Button
                type="button"
                variant="outline"
                onClick={handleBrowse}
                disabled={isSubmitting}
                title={formatMessage({ id: 'terminalDashboard.cliConfig.browse' })}
              >
                <FolderOpen className="w-4 h-4" />
              </Button>
            </div>
            {error && <p className="text-xs text-destructive">{error}</p>}
          </div>
        </div>
        <DialogFooter>
          <Button variant="outline" onClick={onClose} disabled={isSubmitting}>
            {formatMessage({ id: 'common.actions.cancel' })}
          </Button>
          <Button onClick={handleCreate} disabled={isSubmitting}>
            {formatMessage({ id: 'common.actions.create' })}
          </Button>
        </DialogFooter>
      </DialogContent>
    </Dialog>
  );
}
export default CliConfigModal;

View File

@@ -27,6 +27,11 @@ import {
DropdownMenu, DropdownMenu,
DropdownMenuContent, DropdownMenuContent,
DropdownMenuItem, DropdownMenuItem,
DropdownMenuRadioGroup,
DropdownMenuRadioItem,
DropdownMenuSub,
DropdownMenuSubContent,
DropdownMenuSubTrigger,
DropdownMenuTrigger, DropdownMenuTrigger,
DropdownMenuSeparator, DropdownMenuSeparator,
} from '@/components/ui/Dropdown'; } from '@/components/ui/Dropdown';
@@ -37,6 +42,8 @@ import {
import { useIssues, useIssueQueue } from '@/hooks/useIssues'; import { useIssues, useIssueQueue } from '@/hooks/useIssues';
import { useTerminalGridStore, selectTerminalGridFocusedPaneId } from '@/stores/terminalGridStore'; import { useTerminalGridStore, selectTerminalGridFocusedPaneId } from '@/stores/terminalGridStore';
import { useWorkflowStore, selectProjectPath } from '@/stores/workflowStore'; import { useWorkflowStore, selectProjectPath } from '@/stores/workflowStore';
import { sendCliSessionText } from '@/lib/api';
import { CliConfigModal, type CliSessionConfig } from './CliConfigModal';
// ========== Types ========== // ========== Types ==========
@@ -56,6 +63,19 @@ const LAYOUT_PRESETS = [
{ id: 'grid-2x2' as const, icon: LayoutGrid, labelId: 'terminalDashboard.toolbar.layoutGrid' }, { id: 'grid-2x2' as const, icon: LayoutGrid, labelId: 'terminalDashboard.toolbar.layoutGrid' },
]; ];
type LaunchMode = 'default' | 'yolo';
const CLI_TOOLS = ['claude', 'gemini', 'qwen', 'codex', 'opencode'] as const;
type CliTool = (typeof CLI_TOOLS)[number];
const LAUNCH_COMMANDS: Record<CliTool, Record<LaunchMode, string>> = {
claude: { default: 'claude', yolo: 'claude --permission-mode bypassPermissions' },
gemini: { default: 'gemini', yolo: 'gemini --approval-mode yolo' },
qwen: { default: 'qwen', yolo: 'qwen --approval-mode yolo' },
codex: { default: 'codex', yolo: 'codex --full-auto' },
opencode: { default: 'opencode', yolo: 'opencode' },
};
// ========== Component ========== // ========== Component ==========
export function DashboardToolbar({ activePanel, onTogglePanel }: DashboardToolbarProps) { export function DashboardToolbar({ activePanel, onTogglePanel }: DashboardToolbarProps) {
@@ -94,117 +114,216 @@ export function DashboardToolbar({ activePanel, onTogglePanel }: DashboardToolba
const focusedPaneId = useTerminalGridStore(selectTerminalGridFocusedPaneId); const focusedPaneId = useTerminalGridStore(selectTerminalGridFocusedPaneId);
const createSessionAndAssign = useTerminalGridStore((s) => s.createSessionAndAssign); const createSessionAndAssign = useTerminalGridStore((s) => s.createSessionAndAssign);
const [isCreating, setIsCreating] = useState(false); const [isCreating, setIsCreating] = useState(false);
const [selectedTool, setSelectedTool] = useState<CliTool>('gemini');
const [launchMode, setLaunchMode] = useState<LaunchMode>('yolo');
const [isConfigOpen, setIsConfigOpen] = useState(false);
const handleQuickCreate = useCallback(async () => { const handleQuickCreate = useCallback(async () => {
if (!focusedPaneId || !projectPath) return; if (!focusedPaneId || !projectPath) return;
setIsCreating(true); setIsCreating(true);
try { try {
await createSessionAndAssign(focusedPaneId, { const created = await createSessionAndAssign(focusedPaneId, {
workingDir: projectPath, workingDir: projectPath,
preferredShell: 'bash', preferredShell: 'bash',
tool: selectedTool,
}, projectPath); }, projectPath);
if (created?.session?.sessionKey) {
const command = LAUNCH_COMMANDS[selectedTool]?.[launchMode] ?? selectedTool;
setTimeout(() => {
sendCliSessionText(
created.session.sessionKey,
{ text: command, appendNewline: true },
projectPath
).catch((err) => console.error('[DashboardToolbar] auto-launch failed:', err));
}, 300);
}
} finally {
setIsCreating(false);
}
}, [focusedPaneId, projectPath, createSessionAndAssign, selectedTool, launchMode]);
const handleConfigure = useCallback(() => {
setIsConfigOpen(true);
}, []);
const handleCreateConfiguredSession = useCallback(async (config: CliSessionConfig) => {
if (!focusedPaneId || !projectPath) throw new Error('No focused pane or project path');
setIsCreating(true);
try {
const created = await createSessionAndAssign(
focusedPaneId,
{
workingDir: config.workingDir || projectPath,
preferredShell: config.preferredShell,
tool: config.tool,
model: config.model,
},
projectPath
);
if (!created?.session?.sessionKey) throw new Error('createSessionAndAssign failed');
const tool = config.tool as CliTool;
const mode = config.launchMode as LaunchMode;
const command = LAUNCH_COMMANDS[tool]?.[mode] ?? tool;
setTimeout(() => {
sendCliSessionText(
created.session.sessionKey,
{ text: command, appendNewline: true },
projectPath
).catch((err) => console.error('[DashboardToolbar] auto-launch failed:', err));
}, 300);
} finally { } finally {
setIsCreating(false); setIsCreating(false);
} }
}, [focusedPaneId, projectPath, createSessionAndAssign]); }, [focusedPaneId, projectPath, createSessionAndAssign]);
const handleConfigure = useCallback(() => {
// TODO: Open configuration modal (future implementation)
console.log('Configure CLI session - modal to be implemented');
}, []);
return ( return (
<div className="flex items-center gap-1 px-2 h-[40px] border-b border-border bg-muted/30 shrink-0"> <>
{/* Launch CLI dropdown */} <div className="flex items-center gap-1 px-2 h-[40px] border-b border-border bg-muted/30 shrink-0">
<DropdownMenu> {/* Launch CLI dropdown */}
<DropdownMenuTrigger asChild> <DropdownMenu>
<DropdownMenuTrigger asChild>
<button
className={cn(
'flex items-center gap-1.5 px-2.5 py-1.5 rounded-md text-xs transition-colors',
'text-muted-foreground hover:text-foreground hover:bg-muted',
isCreating && 'opacity-50 cursor-wait'
)}
disabled={isCreating || !projectPath}
>
{isCreating ? (
<Loader2 className="w-3.5 h-3.5 animate-spin" />
) : (
<Terminal className="w-3.5 h-3.5" />
)}
<span>{formatMessage({ id: 'terminalDashboard.toolbar.launchCli' })}</span>
<ChevronDown className="w-3 h-3" />
</button>
</DropdownMenuTrigger>
<DropdownMenuContent align="start" sideOffset={4}>
<DropdownMenuSub>
<DropdownMenuSubTrigger className="gap-2">
<span>{formatMessage({ id: 'terminalDashboard.toolbar.tool' })}</span>
<span className="text-xs text-muted-foreground">({selectedTool})</span>
</DropdownMenuSubTrigger>
<DropdownMenuSubContent>
<DropdownMenuRadioGroup
value={selectedTool}
onValueChange={(v) => setSelectedTool(v as CliTool)}
>
{CLI_TOOLS.map((tool) => (
<DropdownMenuRadioItem key={tool} value={tool}>
{tool}
</DropdownMenuRadioItem>
))}
</DropdownMenuRadioGroup>
</DropdownMenuSubContent>
</DropdownMenuSub>
<DropdownMenuSub>
<DropdownMenuSubTrigger className="gap-2">
<span>{formatMessage({ id: 'terminalDashboard.toolbar.mode' })}</span>
<span className="text-xs text-muted-foreground">
{launchMode === 'default'
? formatMessage({ id: 'terminalDashboard.toolbar.modeDefault' })
: formatMessage({ id: 'terminalDashboard.toolbar.modeYolo' })}
</span>
</DropdownMenuSubTrigger>
<DropdownMenuSubContent>
<DropdownMenuRadioGroup
value={launchMode}
onValueChange={(v) => setLaunchMode(v as LaunchMode)}
>
<DropdownMenuRadioItem value="default">
{formatMessage({ id: 'terminalDashboard.toolbar.modeDefault' })}
</DropdownMenuRadioItem>
<DropdownMenuRadioItem value="yolo">
{formatMessage({ id: 'terminalDashboard.toolbar.modeYolo' })}
</DropdownMenuRadioItem>
</DropdownMenuRadioGroup>
</DropdownMenuSubContent>
</DropdownMenuSub>
<DropdownMenuSeparator />
<DropdownMenuItem
onClick={handleQuickCreate}
disabled={isCreating || !projectPath || !focusedPaneId}
className="gap-2"
>
<Zap className="w-4 h-4" />
<span>{formatMessage({ id: 'terminalDashboard.toolbar.quickCreate' })}</span>
</DropdownMenuItem>
<DropdownMenuSeparator />
<DropdownMenuItem
onClick={handleConfigure}
disabled={isCreating || !projectPath || !focusedPaneId}
className="gap-2"
>
<Settings className="w-4 h-4" />
<span>{formatMessage({ id: 'terminalDashboard.toolbar.configure' })}</span>
</DropdownMenuItem>
</DropdownMenuContent>
</DropdownMenu>
{/* Separator */}
<div className="w-px h-5 bg-border mx-1" />
{/* Panel toggle buttons */}
<ToolbarButton
icon={AlertCircle}
label={formatMessage({ id: 'terminalDashboard.toolbar.issues' })}
isActive={activePanel === 'issues'}
onClick={() => onTogglePanel('issues')}
badge={openCount > 0 ? openCount : undefined}
/>
<ToolbarButton
icon={ListChecks}
label={formatMessage({ id: 'terminalDashboard.toolbar.queue' })}
isActive={activePanel === 'queue'}
onClick={() => onTogglePanel('queue')}
badge={queueCount > 0 ? queueCount : undefined}
/>
<ToolbarButton
icon={Info}
label={formatMessage({ id: 'terminalDashboard.toolbar.inspector' })}
isActive={activePanel === 'inspector'}
onClick={() => onTogglePanel('inspector')}
dot={hasChain}
/>
{/* Separator */}
<div className="w-px h-5 bg-border mx-1" />
{/* Layout presets */}
{LAYOUT_PRESETS.map((preset) => (
<button <button
key={preset.id}
onClick={() => handlePreset(preset.id)}
className={cn( className={cn(
'flex items-center gap-1.5 px-2.5 py-1.5 rounded-md text-xs transition-colors', 'p-1.5 rounded transition-colors',
'text-muted-foreground hover:text-foreground hover:bg-muted', 'text-muted-foreground hover:text-foreground hover:bg-muted'
isCreating && 'opacity-50 cursor-wait'
)} )}
disabled={isCreating || !projectPath} title={formatMessage({ id: preset.labelId })}
> >
{isCreating ? ( <preset.icon className="w-3.5 h-3.5" />
<Loader2 className="w-3.5 h-3.5 animate-spin" />
) : (
<Terminal className="w-3.5 h-3.5" />
)}
<span>{formatMessage({ id: 'terminalDashboard.toolbar.launchCli' })}</span>
<ChevronDown className="w-3 h-3" />
</button> </button>
</DropdownMenuTrigger> ))}
<DropdownMenuContent align="start" sideOffset={4}>
<DropdownMenuItem
onClick={handleQuickCreate}
disabled={isCreating || !projectPath || !focusedPaneId}
className="gap-2"
>
<Zap className="w-4 h-4" />
<span>{formatMessage({ id: 'terminalDashboard.toolbar.quickCreate' })}</span>
</DropdownMenuItem>
<DropdownMenuSeparator />
<DropdownMenuItem
onClick={handleConfigure}
disabled={isCreating}
className="gap-2"
>
<Settings className="w-4 h-4" />
<span>{formatMessage({ id: 'terminalDashboard.toolbar.configure' })}</span>
</DropdownMenuItem>
</DropdownMenuContent>
</DropdownMenu>
{/* Separator */} {/* Right-aligned title */}
<div className="w-px h-5 bg-border mx-1" /> <span className="ml-auto text-xs text-muted-foreground font-medium">
{formatMessage({ id: 'terminalDashboard.page.title' })}
</span>
</div>
{/* Panel toggle buttons */} <CliConfigModal
<ToolbarButton isOpen={isConfigOpen}
icon={AlertCircle} onClose={() => setIsConfigOpen(false)}
label={formatMessage({ id: 'terminalDashboard.toolbar.issues' })} defaultWorkingDir={projectPath}
isActive={activePanel === 'issues'} onCreateSession={handleCreateConfiguredSession}
onClick={() => onTogglePanel('issues')}
badge={openCount > 0 ? openCount : undefined}
/> />
<ToolbarButton </>
icon={ListChecks}
label={formatMessage({ id: 'terminalDashboard.toolbar.queue' })}
isActive={activePanel === 'queue'}
onClick={() => onTogglePanel('queue')}
badge={queueCount > 0 ? queueCount : undefined}
/>
<ToolbarButton
icon={Info}
label={formatMessage({ id: 'terminalDashboard.toolbar.inspector' })}
isActive={activePanel === 'inspector'}
onClick={() => onTogglePanel('inspector')}
dot={hasChain}
/>
{/* Separator */}
<div className="w-px h-5 bg-border mx-1" />
{/* Layout presets */}
{LAYOUT_PRESETS.map((preset) => (
<button
key={preset.id}
onClick={() => handlePreset(preset.id)}
className={cn(
'p-1.5 rounded transition-colors',
'text-muted-foreground hover:text-foreground hover:bg-muted'
)}
title={formatMessage({ id: preset.labelId })}
>
<preset.icon className="w-3.5 h-3.5" />
</button>
))}
{/* Right-aligned title */}
<span className="ml-auto text-xs text-muted-foreground font-medium">
{formatMessage({ id: 'terminalDashboard.page.title' })}
</span>
</div>
); );
} }

View File

@@ -145,6 +145,22 @@ export type {
UseDeleteMemoryReturn, UseDeleteMemoryReturn,
} from './useMemory'; } from './useMemory';
// ========== Unified Memory ==========
export {
useUnifiedSearch,
useUnifiedStats,
useRecommendations,
useReindex,
} from './useUnifiedSearch';
export type {
UseUnifiedSearchOptions,
UseUnifiedSearchReturn,
UseUnifiedStatsReturn,
UseRecommendationsOptions,
UseRecommendationsReturn,
UseReindexReturn,
} from './useUnifiedSearch';
// ========== MCP Servers ========== // ========== MCP Servers ==========
export { export {
useMcpServers, useMcpServers,

View File

@@ -15,6 +15,7 @@ import {
} from '../lib/api'; } from '../lib/api';
import { useWorkflowStore, selectProjectPath } from '@/stores/workflowStore'; import { useWorkflowStore, selectProjectPath } from '@/stores/workflowStore';
import { workspaceQueryKeys } from '@/lib/queryKeys'; import { workspaceQueryKeys } from '@/lib/queryKeys';
import { parseMemoryMetadata } from '@/lib/utils';
// Query key factory // Query key factory
export const memoryKeys = { export const memoryKeys = {
@@ -99,13 +100,8 @@ export function useMemory(options: UseMemoryOptions = {}): UseMemoryReturn {
// Filter by favorite status (from metadata) // Filter by favorite status (from metadata)
if (filter?.favorite === true) { if (filter?.favorite === true) {
memories = memories.filter((m) => { memories = memories.filter((m) => {
if (!m.metadata) return false; const metadata = parseMemoryMetadata(m.metadata);
try { return metadata.favorite === true;
const metadata = typeof m.metadata === 'string' ? JSON.parse(m.metadata) : m.metadata;
return metadata.favorite === true;
} catch {
return false;
}
}); });
} }

View File

@@ -0,0 +1,199 @@
// ========================================
// useUnifiedSearch Hook
// ========================================
// TanStack Query hooks for unified memory search, stats, and recommendations
import { useQuery, useMutation, useQueryClient } from '@tanstack/react-query';
import {
fetchUnifiedSearch,
fetchUnifiedStats,
fetchRecommendations,
triggerReindex,
type UnifiedSearchResult,
type UnifiedMemoryStats,
type RecommendationResult,
type ReindexResponse,
} from '../lib/api';
import { useWorkflowStore, selectProjectPath } from '@/stores/workflowStore';
import { workspaceQueryKeys } from '@/lib/queryKeys';
// Default stale time: 1 minute
const STALE_TIME = 60 * 1000;
// ========== Unified Search ==========
export interface UseUnifiedSearchOptions {
  query: string;
  categories?: string; // category filter string, passed through as the API's `category` param
  topK?: number;
  minScore?: number;
  enabled?: boolean; // external gate for the query (default: true)
  staleTime?: number; // override for the default 1-minute stale time
}
export interface UseUnifiedSearchReturn {
  results: UnifiedSearchResult[];
  total: number;
  isLoading: boolean;
  isFetching: boolean;
  error: Error | null;
  refetch: () => Promise<void>;
}
/**
 * Unified vector + FTS5 search across all memory categories.
 *
 * The request only fires once a project path is known and the query string is
 * non-blank; until then the hook returns empty results without fetching.
 */
export function useUnifiedSearch(options: UseUnifiedSearchOptions): UseUnifiedSearchReturn {
  const { query, categories, topK, minScore, enabled = true, staleTime = STALE_TIME } = options;
  const projectPath = useWorkflowStore(selectProjectPath);
  const hasQuery = query.trim().length > 0;
  const searchQuery = useQuery({
    queryKey: workspaceQueryKeys.unifiedSearch(projectPath || '', query, categories),
    queryFn: () =>
      fetchUnifiedSearch(query, { topK, minScore, category: categories }, projectPath || undefined),
    staleTime,
    enabled: enabled && !!projectPath && hasQuery,
    retry: 1,
  });
  return {
    results: searchQuery.data?.results ?? [],
    total: searchQuery.data?.total ?? 0,
    isLoading: searchQuery.isLoading,
    isFetching: searchQuery.isFetching,
    error: searchQuery.error,
    refetch: async () => {
      await searchQuery.refetch();
    },
  };
}
// ========== Unified Stats ==========
export interface UseUnifiedStatsReturn {
  stats: UnifiedMemoryStats | null; // null until the first successful fetch
  isLoading: boolean;
  isFetching: boolean;
  error: Error | null;
  refetch: () => Promise<void>;
}
/**
 * Unified memory statistics for the current project.
 * Disabled until a project path is available.
 */
export function useUnifiedStats(): UseUnifiedStatsReturn {
  const projectPath = useWorkflowStore(selectProjectPath);
  const statsQuery = useQuery({
    queryKey: workspaceQueryKeys.unifiedStats(projectPath || ''),
    queryFn: () => fetchUnifiedStats(projectPath || undefined),
    staleTime: STALE_TIME,
    enabled: !!projectPath,
    retry: 2,
  });
  return {
    stats: statsQuery.data?.stats ?? null,
    isLoading: statsQuery.isLoading,
    isFetching: statsQuery.isFetching,
    error: statsQuery.error,
    refetch: async () => {
      await statsQuery.refetch();
    },
  };
}
// ========== Recommendations ==========
export interface UseRecommendationsOptions {
  memoryId: string; // source memory to recommend from (CMEM-* id per the API docs below)
  limit?: number; // max recommendations (default: 5)
  enabled?: boolean; // external gate for the query (default: true)
}
export interface UseRecommendationsReturn {
  recommendations: RecommendationResult[];
  total: number;
  isLoading: boolean;
  isFetching: boolean;
  error: Error | null;
}
/**
 * KNN-based related-memory recommendations for a given memory id.
 * Disabled until both a project path and a memory id are available.
 */
export function useRecommendations(options: UseRecommendationsOptions): UseRecommendationsReturn {
  const { memoryId, limit = 5, enabled = true } = options;
  const projectPath = useWorkflowStore(selectProjectPath);
  const recQuery = useQuery({
    queryKey: workspaceQueryKeys.unifiedRecommendations(projectPath || '', memoryId),
    queryFn: () => fetchRecommendations(memoryId, limit, projectPath || undefined),
    staleTime: STALE_TIME,
    enabled: enabled && !!projectPath && !!memoryId,
    retry: 1,
  });
  return {
    recommendations: recQuery.data?.recommendations ?? [],
    total: recQuery.data?.total ?? 0,
    isLoading: recQuery.isLoading,
    isFetching: recQuery.isFetching,
    error: recQuery.error,
  };
}
// ========== Reindex Mutation ==========
export interface UseReindexReturn {
  reindex: () => Promise<ReindexResponse>; // triggers the rebuild; rejects on failure
  isReindexing: boolean;
  error: Error | null;
}
/**
 * Mutation hook that triggers a rebuild of the vector index.
 */
export function useReindex(): UseReindexReturn {
  const queryClient = useQueryClient();
  const projectPath = useWorkflowStore(selectProjectPath);
  const mutation = useMutation({
    mutationFn: () => triggerReindex(projectPath || undefined),
    onSuccess: () => {
      // A rebuilt index changes search results, so drop the cached
      // unified-memory queries for this project.
      if (!projectPath) return;
      queryClient.invalidateQueries({
        queryKey: workspaceQueryKeys.unifiedMemory(projectPath),
      });
    },
  });
  return {
    reindex: mutation.mutateAsync,
    isReindexing: mutation.isPending,
    error: mutation.error,
  };
}

View File

@@ -6360,3 +6360,149 @@ export async function fetchCliSessionAudit(
withPath(`/api/audit/cli-sessions${queryString ? `?${queryString}` : ''}`, options?.projectPath) withPath(`/api/audit/cli-sessions${queryString ? `?${queryString}` : ''}`, options?.projectPath)
); );
} }
// ========== Unified Memory API ==========
/** One hit returned by the unified (vector + FTS5 RRF-fused) search endpoint. */
export interface UnifiedSearchResult {
  source_id: string;
  source_type: string;
  score: number; // fused relevance score
  content: string;
  category: string;
  // Per-ranker provenance for the fused score; a missing field means that
  // ranker did not contribute to this result.
  rank_sources: {
    vector_rank?: number;
    vector_score?: number;
    fts_rank?: number;
    heat_score?: number;
  };
}
/** Envelope for /api/unified-memory/search. */
export interface UnifiedSearchResponse {
  success: boolean;
  query: string;
  total: number;
  results: UnifiedSearchResult[];
}
/** Aggregate counts returned by /api/unified-memory/stats. */
export interface UnifiedMemoryStats {
  core_memories: {
    total: number;
    archived: number;
  };
  stage1_outputs: number;
  entities: number;
  prompts: number;
  conversations: number;
  // Vector-index health; hnsw_* fields describe the ANN index specifically.
  vector_index: {
    available: boolean;
    total_chunks: number;
    hnsw_available: boolean;
    hnsw_count: number;
    dimension: number;
    categories?: Record<string, number>; // chunk counts per category, when reported
  };
}
/** One KNN recommendation from the recommendations endpoint. */
export interface RecommendationResult {
  source_id: string;
  source_type: string;
  score: number;
  content: string;
  category: string;
}
/** Result of triggering a vector-index rebuild. */
export interface ReindexResponse {
  success: boolean;
  hnsw_count?: number; // vectors indexed after the rebuild
  elapsed_time?: number;
  error?: string; // populated when success is false
}
/**
 * Search unified memory using vector + FTS5 fusion (RRF)
 * @param query - Search query text
 * @param options - Search options (topK, minScore, category)
 * @param projectPath - Optional project path for workspace isolation
 * @throws Error when the server reports `success: false`
 */
export async function fetchUnifiedSearch(
  query: string,
  options?: {
    topK?: number;
    minScore?: number;
    category?: string;
  },
  projectPath?: string
): Promise<UnifiedSearchResponse> {
  const params = new URLSearchParams();
  params.set('q', query);
  if (options?.topK) params.set('topK', String(options.topK));
  // Explicit undefined check: minScore = 0 ("no threshold") is a valid value
  // that the previous truthiness check silently dropped.
  if (options?.minScore !== undefined) params.set('minScore', String(options.minScore));
  if (options?.category) params.set('category', options.category);
  const data = await fetchApi<UnifiedSearchResponse & { error?: string }>(
    withPath(`/api/unified-memory/search?${params.toString()}`, projectPath)
  );
  if (data.success === false) {
    throw new Error(data.error || 'Search failed');
  }
  return data;
}
/**
 * Fetch unified memory statistics (core memories, entities, vectors, etc.)
 * @param projectPath - Optional project path for workspace isolation
 * @throws Error when the server reports `success: false`
 */
export async function fetchUnifiedStats(
  projectPath?: string
): Promise<{ success: boolean; stats: UnifiedMemoryStats }> {
  const endpoint = withPath('/api/unified-memory/stats', projectPath);
  const payload = await fetchApi<{ success: boolean; stats: UnifiedMemoryStats; error?: string }>(endpoint);

  // Surface server-side failures as exceptions for react-query to catch.
  if (payload.success !== false) {
    return payload;
  }
  throw new Error(payload.error || 'Failed to load unified stats');
}
/**
 * Get KNN-based recommendations for a specific memory
 * @param memoryId - Core memory ID (CMEM-*)
 * @param limit - Number of recommendations (default: 5)
 * @param projectPath - Optional project path for workspace isolation
 * @throws Error when the server reports `success: false`
 */
export async function fetchRecommendations(
  memoryId: string,
  limit?: number,
  projectPath?: string
): Promise<{ success: boolean; memory_id: string; total: number; recommendations: RecommendationResult[] }> {
  const search = new URLSearchParams();
  if (limit) search.set('limit', String(limit));
  const suffix = search.toString();

  // Memory ids are caller-supplied, so encode them for safe path embedding.
  const endpoint = `/api/unified-memory/recommendations/${encodeURIComponent(memoryId)}${suffix ? `?${suffix}` : ''}`;
  const response = await fetchApi<{ success: boolean; memory_id: string; total: number; recommendations: RecommendationResult[]; error?: string }>(
    withPath(endpoint, projectPath)
  );

  if (response.success === false) {
    throw new Error(response.error || 'Failed to load recommendations');
  }
  return response;
}
/**
 * Trigger vector index rebuild
 * @param projectPath - Optional project path for workspace isolation
 * @returns Server response describing the rebuild outcome
 */
export async function triggerReindex(
  projectPath?: string
): Promise<ReindexResponse> {
  // POST the workspace path; the server scopes the rebuild to it.
  const requestInit = {
    method: 'POST',
    body: JSON.stringify({ path: projectPath }),
  };
  return fetchApi<ReindexResponse>('/api/unified-memory/reindex', requestInit);
}

View File

@@ -130,6 +130,15 @@ export const workspaceQueryKeys = {
offset?: number; offset?: number;
} }
) => [...workspaceQueryKeys.audit(projectPath), 'cliSessions', options] as const, ) => [...workspaceQueryKeys.audit(projectPath), 'cliSessions', options] as const,
// ========== Unified Memory ==========
unifiedMemory: (projectPath: string) => [...workspaceQueryKeys.all(projectPath), 'unifiedMemory'] as const,
unifiedSearch: (projectPath: string, query: string, categories?: string) =>
[...workspaceQueryKeys.unifiedMemory(projectPath), 'search', query, categories] as const,
unifiedStats: (projectPath: string) =>
[...workspaceQueryKeys.unifiedMemory(projectPath), 'stats'] as const,
unifiedRecommendations: (projectPath: string, memoryId: string) =>
[...workspaceQueryKeys.unifiedMemory(projectPath), 'recommendations', memoryId] as const,
}; };
// ========== API Settings Keys ========== // ========== API Settings Keys ==========

View File

@@ -14,4 +14,20 @@ export function cn(...inputs: ClassValue[]): string {
return twMerge(clsx(inputs)); return twMerge(clsx(inputs));
} }
/**
 * Safely parse memory metadata from string, object, or undefined.
 * Returns an empty object on parse failure, missing input, or when the
 * parsed JSON is not a plain object (primitives/arrays/null), so the
 * declared `Record<string, any>` return type always holds.
 */
export function parseMemoryMetadata(
  metadata: string | Record<string, any> | undefined | null
): Record<string, any> {
  if (!metadata) return {};
  if (typeof metadata === 'object') return metadata;
  try {
    const parsed = JSON.parse(metadata);
    // JSON.parse accepts any JSON value; only keep plain objects so callers
    // can safely read properties like `parsed.favorite`.
    if (parsed && typeof parsed === 'object' && !Array.isArray(parsed)) {
      return parsed;
    }
    return {};
  } catch {
    return {};
  }
}
export type { ClassValue }; export type { ClassValue };

View File

@@ -22,19 +22,31 @@
"tabs": { "tabs": {
"memories": "Memories", "memories": "Memories",
"favorites": "Favorites", "favorites": "Favorites",
"archived": "Archived" "archived": "Archived",
"unifiedSearch": "Unified Search"
}, },
"stats": { "stats": {
"totalSize": "Total Size", "totalSize": "Total Size",
"count": "Count", "count": "Count",
"claudeMdCount": "CLAUDE.md Files", "claudeMdCount": "CLAUDE.md Files",
"totalEntries": "Total Entries" "totalEntries": "Total Entries",
"vectorChunks": "Vector Chunks",
"hnswStatus": "HNSW Index",
"entities": "Entities",
"prompts": "Prompts"
}, },
"filters": { "filters": {
"search": "Search memories...", "search": "Search memories...",
"searchUnified": "Semantic search across all memory types...",
"tags": "Tags", "tags": "Tags",
"clear": "Clear", "clear": "Clear",
"all": "All" "all": "All",
"categoryAll": "All Categories",
"categoryCoreMemory": "Core Memory",
"categoryCliHistory": "CLI History",
"categoryWorkflow": "Workflow",
"categoryEntity": "Entity",
"categoryPattern": "Pattern"
}, },
"card": { "card": {
"id": "ID", "id": "ID",
@@ -82,5 +94,20 @@
"coreMemory": "Core Memory", "coreMemory": "Core Memory",
"workflow": "Workflow", "workflow": "Workflow",
"cliHistory": "CLI History" "cliHistory": "CLI History"
},
"unified": {
"score": "Score",
"noResults": "No results found. Try a different search query.",
"searching": "Searching...",
"resultCount": "{count} results",
"recommendations": "Related",
"noRecommendations": "No recommendations available",
"reindex": "Rebuild Index",
"reindexing": "Rebuilding...",
"reindexSuccess": "Index rebuilt successfully",
"reindexError": "Failed to rebuild index",
"vectorRank": "Vector #{rank}",
"ftsRank": "FTS #{rank}",
"heatScore": "Heat: {score}"
} }
} }

View File

@@ -114,6 +114,45 @@
"on": "On", "on": "On",
"off": "Off" "off": "Off"
}, },
"remoteNotifications": {
"title": "Remote Notifications",
"description": "Send notifications to external platforms like Discord, Telegram, or custom webhooks when events occur.",
"enabled": "Enabled",
"disabled": "Disabled",
"platforms": "Platform Configuration",
"events": "Event Triggers",
"noPlatforms": "No platforms",
"configured": "Configured",
"save": "Save",
"saving": "Saving...",
"saved": "Configuration saved",
"saveError": "Failed to save configuration",
"reset": "Reset to Defaults",
"resetConfirm": "Reset all remote notification settings to defaults?",
"resetSuccess": "Settings reset to defaults",
"resetError": "Failed to reset settings",
"testConnection": "Test Connection",
"testSuccess": "Test notification sent successfully",
"testFailed": "Test notification failed",
"testError": "Failed to send test notification",
"discord": {
"webhookUrl": "Webhook URL",
"webhookUrlHint": "Create a webhook in your Discord channel settings",
"username": "Custom Username (optional)"
},
"telegram": {
"botToken": "Bot Token",
"botTokenHint": "Get from @BotFather on Telegram",
"chatId": "Chat ID",
"chatIdHint": "User or group chat ID (use @userinfobot to find it)"
},
"webhook": {
"url": "Webhook URL",
"method": "HTTP Method",
"headers": "Custom Headers (JSON)",
"headersHint": "Optional JSON object with custom headers"
}
},
"versionCheck": { "versionCheck": {
"title": "Version Update", "title": "Version Update",
"currentVersion": "Current Version", "currentVersion": "Current Version",

View File

@@ -22,19 +22,31 @@
"tabs": { "tabs": {
"memories": "记忆", "memories": "记忆",
"favorites": "收藏", "favorites": "收藏",
"archived": "归档" "archived": "归档",
"unifiedSearch": "统一搜索"
}, },
"stats": { "stats": {
"totalSize": "总大小", "totalSize": "总大小",
"count": "数量", "count": "数量",
"claudeMdCount": "CLAUDE.md 文件", "claudeMdCount": "CLAUDE.md 文件",
"totalEntries": "总条目" "totalEntries": "总条目",
"vectorChunks": "向量块",
"hnswStatus": "HNSW 索引",
"entities": "实体",
"prompts": "提示"
}, },
"filters": { "filters": {
"search": "搜索记忆...", "search": "搜索记忆...",
"searchUnified": "跨所有记忆类型语义搜索...",
"tags": "标签", "tags": "标签",
"clear": "清除", "clear": "清除",
"all": "全部" "all": "全部",
"categoryAll": "所有类别",
"categoryCoreMemory": "核心记忆",
"categoryCliHistory": "CLI 历史",
"categoryWorkflow": "工作流",
"categoryEntity": "实体",
"categoryPattern": "模式"
}, },
"card": { "card": {
"id": "ID", "id": "ID",
@@ -82,5 +94,20 @@
"coreMemory": "核心记忆", "coreMemory": "核心记忆",
"workflow": "工作流", "workflow": "工作流",
"cliHistory": "CLI 历史" "cliHistory": "CLI 历史"
},
"unified": {
"score": "分数",
"noResults": "未找到结果。请尝试不同的搜索查询。",
"searching": "搜索中...",
"resultCount": "{count} 条结果",
"recommendations": "相关",
"noRecommendations": "暂无推荐",
"reindex": "重建索引",
"reindexing": "重建中...",
"reindexSuccess": "索引重建成功",
"reindexError": "索引重建失败",
"vectorRank": "向量 #{rank}",
"ftsRank": "全文 #{rank}",
"heatScore": "热度: {score}"
} }
} }

View File

@@ -114,6 +114,45 @@
"on": "开启", "on": "开启",
"off": "关闭" "off": "关闭"
}, },
"remoteNotifications": {
"title": "远程通知",
"description": "当事件发生时,发送通知到 Discord、Telegram 或自定义 Webhook 等外部平台。",
"enabled": "已启用",
"disabled": "已禁用",
"platforms": "平台配置",
"events": "事件触发器",
"noPlatforms": "无平台",
"configured": "已配置",
"save": "保存",
"saving": "保存中...",
"saved": "配置已保存",
"saveError": "保存配置失败",
"reset": "重置为默认值",
"resetConfirm": "确定要将所有远程通知设置重置为默认值吗?",
"resetSuccess": "设置已重置为默认值",
"resetError": "重置设置失败",
"testConnection": "测试连接",
"testSuccess": "测试通知发送成功",
"testFailed": "测试通知发送失败",
"testError": "发送测试通知失败",
"discord": {
"webhookUrl": "Webhook URL",
"webhookUrlHint": "在 Discord 频道设置中创建 Webhook",
"username": "自定义用户名(可选)"
},
"telegram": {
"botToken": "Bot Token",
"botTokenHint": "从 Telegram 的 @BotFather 获取",
"chatId": "Chat ID",
"chatIdHint": "用户或群组 Chat ID使用 @userinfobot 查找)"
},
"webhook": {
"url": "Webhook URL",
"method": "HTTP 方法",
"headers": "自定义请求头JSON",
"headersHint": "可选的 JSON 对象,包含自定义请求头"
}
},
"versionCheck": { "versionCheck": {
"title": "版本更新", "title": "版本更新",
"currentVersion": "当前版本", "currentVersion": "当前版本",

View File

@@ -2,6 +2,7 @@
// Memory Page // Memory Page
// ======================================== // ========================================
// View and manage core memory and context with CRUD operations // View and manage core memory and context with CRUD operations
// Includes unified vector search across all memory categories
import { useState, useEffect } from 'react'; import { useState, useEffect } from 'react';
import { useIntl } from 'react-intl'; import { useIntl } from 'react-intl';
@@ -22,6 +23,11 @@ import {
Archive, Archive,
ArchiveRestore, ArchiveRestore,
AlertCircle, AlertCircle,
Layers,
Zap,
Terminal,
GitBranch,
Hash,
} from 'lucide-react'; } from 'lucide-react';
import { Card } from '@/components/ui/Card'; import { Card } from '@/components/ui/Card';
import { Button } from '@/components/ui/Button'; import { Button } from '@/components/ui/Button';
@@ -30,9 +36,39 @@ import { Badge } from '@/components/ui/Badge';
import { TabsNavigation } from '@/components/ui/TabsNavigation'; import { TabsNavigation } from '@/components/ui/TabsNavigation';
import { Dialog, DialogContent, DialogHeader, DialogTitle } from '@/components/ui/Dialog'; import { Dialog, DialogContent, DialogHeader, DialogTitle } from '@/components/ui/Dialog';
import { Checkbox } from '@/components/ui/Checkbox'; import { Checkbox } from '@/components/ui/Checkbox';
import { useMemory, useMemoryMutations } from '@/hooks'; import { useMemory, useMemoryMutations, useUnifiedSearch, useUnifiedStats, useRecommendations, useReindex } from '@/hooks';
import type { CoreMemory } from '@/lib/api'; import type { CoreMemory, UnifiedSearchResult } from '@/lib/api';
import { cn } from '@/lib/utils'; import { cn, parseMemoryMetadata } from '@/lib/utils';
// ========== Source Type Helpers ==========

// Tailwind badge classes per memory source type (light + dark variants).
// Unknown types fall back to gray inside SourceTypeBadge.
const SOURCE_TYPE_COLORS: Record<string, string> = {
  core_memory: 'bg-blue-100 text-blue-800 dark:bg-blue-900/30 dark:text-blue-300',
  cli_history: 'bg-green-100 text-green-800 dark:bg-green-900/30 dark:text-green-300',
  workflow: 'bg-purple-100 text-purple-800 dark:bg-purple-900/30 dark:text-purple-300',
  entity: 'bg-orange-100 text-orange-800 dark:bg-orange-900/30 dark:text-orange-300',
  pattern: 'bg-pink-100 text-pink-800 dark:bg-pink-900/30 dark:text-pink-300',
};

// Small (w-3 h-3) lucide icon rendered inside the source-type badge.
const SOURCE_TYPE_ICONS: Record<string, React.ReactNode> = {
  core_memory: <Brain className="w-3 h-3" />,
  cli_history: <Terminal className="w-3 h-3" />,
  workflow: <GitBranch className="w-3 h-3" />,
  entity: <Hash className="w-3 h-3" />,
  pattern: <Layers className="w-3 h-3" />,
};
/** Pill badge showing a memory source type with its icon and color scheme. */
function SourceTypeBadge({ sourceType }: { sourceType: string }) {
  // Unknown source types degrade gracefully to a neutral gray badge.
  const badgeColor = SOURCE_TYPE_COLORS[sourceType] || 'bg-gray-100 text-gray-800 dark:bg-gray-900/30 dark:text-gray-300';
  const badgeIcon = SOURCE_TYPE_ICONS[sourceType] || <Database className="w-3 h-3" />;

  return (
    <span className={cn('inline-flex items-center gap-1 px-2 py-0.5 rounded-full text-xs font-medium', badgeColor)}>
      {badgeIcon}
      {sourceType}
    </span>
  );
}
// ========== Memory Card Component ========== // ========== Memory Card Component ==========
@@ -51,7 +87,7 @@ function MemoryCard({ memory, onView, onEdit, onDelete, onCopy, onToggleFavorite
const formattedDate = new Date(memory.createdAt).toLocaleDateString(); const formattedDate = new Date(memory.createdAt).toLocaleDateString();
// Parse metadata from memory // Parse metadata from memory
const metadata = memory.metadata ? (typeof memory.metadata === 'string' ? JSON.parse(memory.metadata) : memory.metadata) : {}; const metadata = parseMemoryMetadata(memory.metadata);
const isFavorite = metadata.favorite === true; const isFavorite = metadata.favorite === true;
const priority = metadata.priority || 'medium'; const priority = metadata.priority || 'medium';
const isArchived = memory.archived || false; const isArchived = memory.archived || false;
@@ -197,6 +233,138 @@ function MemoryCard({ memory, onView, onEdit, onDelete, onCopy, onToggleFavorite
); );
} }
// ========== Unified Search Result Card ==========

interface UnifiedResultCardProps {
  // Single hit from the unified search endpoint.
  result: UnifiedSearchResult;
  // Copies the hit's content to the clipboard (owned by the page).
  onCopy: (content: string) => void;
}

/**
 * Card for one unified-search hit: source badge, id, fused score (as a
 * percentage), per-ranker provenance, a copy button, and a clamped content
 * preview.
 */
function UnifiedResultCard({ result, onCopy }: UnifiedResultCardProps) {
  const { formatMessage } = useIntl();
  // Fused score rendered as a percentage with one decimal place.
  const scorePercent = (result.score * 100).toFixed(1);

  return (
    <Card className="overflow-hidden">
      <div className="p-4">
        <div className="flex items-start justify-between gap-2">
          <div className="flex items-start gap-3 min-w-0 flex-1">
            <SourceTypeBadge sourceType={result.source_type} />
            <div className="min-w-0 flex-1">
              <div className="flex items-center gap-2 flex-wrap">
                <span className="text-sm font-medium text-foreground truncate">
                  {result.source_id}
                </span>
                <Badge variant="outline" className="text-xs shrink-0">
                  {formatMessage({ id: 'memory.unified.score' })}: {scorePercent}%
                </Badge>
              </div>
              {/* Rank sources: each label is shown only when that ranker
                  actually contributed (fields are optional). `!= null`
                  deliberately keeps rank 0 visible. */}
              <div className="flex items-center gap-2 mt-1">
                {result.rank_sources.vector_rank != null && (
                  <span className="text-xs text-muted-foreground">
                    {formatMessage({ id: 'memory.unified.vectorRank' }, { rank: result.rank_sources.vector_rank })}
                  </span>
                )}
                {result.rank_sources.fts_rank != null && (
                  <span className="text-xs text-muted-foreground">
                    {formatMessage({ id: 'memory.unified.ftsRank' }, { rank: result.rank_sources.fts_rank })}
                  </span>
                )}
                {result.rank_sources.heat_score != null && (
                  <span className="text-xs text-muted-foreground">
                    {formatMessage({ id: 'memory.unified.heatScore' }, { score: result.rank_sources.heat_score.toFixed(2) })}
                  </span>
                )}
              </div>
            </div>
          </div>
          <Button
            variant="ghost"
            size="sm"
            className="h-8 w-8 p-0 shrink-0"
            onClick={() => onCopy(result.content)}
          >
            <Copy className="w-4 h-4" />
          </Button>
        </div>

        {/* Content preview, clamped to three lines */}
        <p className="text-sm text-muted-foreground mt-2 line-clamp-3">
          {result.content}
        </p>
      </div>
    </Card>
  );
}
// ========== Recommendations Panel ==========

interface RecommendationsPanelProps {
  // Core memory ID (CMEM-*) to fetch related items for.
  memoryId: string;
  // Copies a recommendation's content to the clipboard (owned by the page).
  onCopy: (content: string) => void;
}

/**
 * Lists up to five KNN-based recommendations related to the given memory.
 * Shows a spinner while loading and an empty-state message when there are
 * no recommendations.
 */
function RecommendationsPanel({ memoryId, onCopy }: RecommendationsPanelProps) {
  const { formatMessage } = useIntl();
  // Query is disabled until a memory id is available.
  const { recommendations, isLoading } = useRecommendations({
    memoryId,
    limit: 5,
    enabled: !!memoryId,
  });

  // Loading state: small inline spinner with a label.
  if (isLoading) {
    return (
      <div className="flex items-center gap-2 text-muted-foreground py-2">
        <Loader2 className="w-4 h-4 animate-spin" />
        <span className="text-sm">{formatMessage({ id: 'memory.unified.searching' })}</span>
      </div>
    );
  }

  // Empty state.
  if (recommendations.length === 0) {
    return (
      <p className="text-sm text-muted-foreground py-2">
        {formatMessage({ id: 'memory.unified.noRecommendations' })}
      </p>
    );
  }

  return (
    <div className="space-y-2">
      {recommendations.map((rec) => (
        <div
          key={rec.source_id}
          className="flex items-start gap-2 p-2 rounded-md bg-muted/30 hover:bg-muted/50 transition-colors"
        >
          <SourceTypeBadge sourceType={rec.source_type} />
          <div className="flex-1 min-w-0">
            <div className="flex items-center gap-2">
              <span className="text-xs font-medium text-foreground truncate">
                {rec.source_id}
              </span>
              {/* Similarity score as a whole-number percentage */}
              <span className="text-xs text-muted-foreground shrink-0">
                {(rec.score * 100).toFixed(0)}%
              </span>
            </div>
            <p className="text-xs text-muted-foreground line-clamp-2 mt-0.5">
              {rec.content}
            </p>
          </div>
          <Button
            variant="ghost"
            size="sm"
            className="h-6 w-6 p-0 shrink-0"
            onClick={() => onCopy(rec.content)}
          >
            <Copy className="w-3 h-3" />
          </Button>
        </div>
      ))}
    </div>
  );
}
// ========== View Memory Dialog ========== // ========== View Memory Dialog ==========
interface ViewMemoryDialogProps { interface ViewMemoryDialogProps {
@@ -211,7 +379,7 @@ function ViewMemoryDialog({ memory, open, onOpenChange, onEdit, onCopy }: ViewMe
const { formatMessage } = useIntl(); const { formatMessage } = useIntl();
if (!memory) return null; if (!memory) return null;
const metadata = memory.metadata ? (typeof memory.metadata === 'string' ? JSON.parse(memory.metadata) : memory.metadata) : {}; const metadata = parseMemoryMetadata(memory.metadata);
const priority = metadata.priority || 'medium'; const priority = metadata.priority || 'medium';
const formattedDate = new Date(memory.createdAt).toLocaleDateString(); const formattedDate = new Date(memory.createdAt).toLocaleDateString();
const formattedSize = memory.size const formattedSize = memory.size
@@ -264,6 +432,15 @@ function ViewMemoryDialog({ memory, open, onOpenChange, onEdit, onCopy }: ViewMe
</pre> </pre>
</div> </div>
{/* Recommendations */}
<div className="pt-2 border-t border-border">
<h4 className="text-sm font-medium text-foreground flex items-center gap-1.5 mb-2">
<Zap className="w-4 h-4 text-primary" />
{formatMessage({ id: 'memory.unified.recommendations' })}
</h4>
<RecommendationsPanel memoryId={memory.id} onCopy={onCopy} />
</div>
{/* Actions */} {/* Actions */}
<div className="flex justify-end gap-2 pt-2 border-t border-border"> <div className="flex justify-end gap-2 pt-2 border-t border-border">
<Button variant="outline" size="sm" onClick={() => onCopy(memory.content)}> <Button variant="outline" size="sm" onClick={() => onCopy(memory.content)}>
@@ -311,21 +488,9 @@ function NewMemoryDialog({
setTagsInput(editingMemory.tags?.join(', ') || ''); setTagsInput(editingMemory.tags?.join(', ') || '');
// Sync metadata // Sync metadata
if (editingMemory.metadata) { const metadata = parseMemoryMetadata(editingMemory.metadata);
try { setIsFavorite(metadata.favorite === true);
const metadata = typeof editingMemory.metadata === 'string' setPriority(metadata.priority || 'medium');
? JSON.parse(editingMemory.metadata)
: editingMemory.metadata;
setIsFavorite(metadata.favorite === true);
setPriority(metadata.priority || 'medium');
} catch {
setIsFavorite(false);
setPriority('medium');
}
} else {
setIsFavorite(false);
setPriority('medium');
}
} else { } else {
// New mode: reset all state // New mode: reset all state
setContent(''); setContent('');
@@ -436,6 +601,17 @@ function NewMemoryDialog({
); );
} }
// ========== Category Filter ==========

// Options for the unified-search category dropdown. `value` is the
// category filter sent to the server ('' = no filter, i.e. all categories);
// `labelId` is the i18n message key for the option's display label.
const CATEGORY_OPTIONS = [
  { value: '', labelId: 'memory.filters.categoryAll' },
  { value: 'core_memory', labelId: 'memory.filters.categoryCoreMemory' },
  { value: 'cli_history', labelId: 'memory.filters.categoryCliHistory' },
  { value: 'workflow', labelId: 'memory.filters.categoryWorkflow' },
  { value: 'entity', labelId: 'memory.filters.categoryEntity' },
  { value: 'pattern', labelId: 'memory.filters.categoryPattern' },
];
// ========== Main Page Component ========== // ========== Main Page Component ==========
export function MemoryPage() { export function MemoryPage() {
@@ -445,9 +621,13 @@ export function MemoryPage() {
const [isNewMemoryOpen, setIsNewMemoryOpen] = useState(false); const [isNewMemoryOpen, setIsNewMemoryOpen] = useState(false);
const [editingMemory, setEditingMemory] = useState<CoreMemory | null>(null); const [editingMemory, setEditingMemory] = useState<CoreMemory | null>(null);
const [viewingMemory, setViewingMemory] = useState<CoreMemory | null>(null); const [viewingMemory, setViewingMemory] = useState<CoreMemory | null>(null);
const [currentTab, setCurrentTab] = useState<'memories' | 'favorites' | 'archived'>('memories'); const [currentTab, setCurrentTab] = useState<'memories' | 'favorites' | 'archived' | 'unifiedSearch'>('memories');
const [unifiedQuery, setUnifiedQuery] = useState('');
const [selectedCategory, setSelectedCategory] = useState('');
// Build filter based on current tab const isUnifiedTab = currentTab === 'unifiedSearch';
// Build filter based on current tab (for non-unified tabs)
const favoriteFilter = currentTab === 'favorites' ? { favorite: true } : undefined; const favoriteFilter = currentTab === 'favorites' ? { favorite: true } : undefined;
const archivedFilter = currentTab === 'archived' ? { archived: true } : { archived: false }; const archivedFilter = currentTab === 'archived' ? { archived: true } : { archived: false };
@@ -467,8 +647,34 @@ export function MemoryPage() {
...favoriteFilter, ...favoriteFilter,
...archivedFilter, ...archivedFilter,
}, },
enabled: !isUnifiedTab,
}); });
// Unified search
const {
results: unifiedResults,
total: unifiedTotal,
isLoading: unifiedLoading,
isFetching: unifiedFetching,
error: unifiedError,
refetch: refetchUnified,
} = useUnifiedSearch({
query: unifiedQuery,
categories: selectedCategory || undefined,
topK: 20,
enabled: isUnifiedTab && unifiedQuery.trim().length > 0,
});
// Unified stats
const {
stats: unifiedStats,
isLoading: statsLoading,
refetch: refetchStats,
} = useUnifiedStats();
// Reindex mutation
const { reindex, isReindexing } = useReindex();
const { createMemory, updateMemory, deleteMemory, archiveMemory, unarchiveMemory, isCreating, isUpdating } = const { createMemory, updateMemory, deleteMemory, archiveMemory, unarchiveMemory, isCreating, isUpdating } =
useMemoryMutations(); useMemoryMutations();
@@ -495,9 +701,7 @@ export function MemoryPage() {
const handleToggleFavorite = async (memory: CoreMemory) => { const handleToggleFavorite = async (memory: CoreMemory) => {
try { try {
const currentMetadata = memory.metadata const currentMetadata = parseMemoryMetadata(memory.metadata);
? (typeof memory.metadata === 'string' ? JSON.parse(memory.metadata) : memory.metadata)
: {};
const newFavorite = !(currentMetadata.favorite === true); const newFavorite = !(currentMetadata.favorite === true);
await updateMemory(memory.id, { await updateMemory(memory.id, {
content: memory.content, content: memory.content,
@@ -544,6 +748,17 @@ export function MemoryPage() {
} }
}; };
const handleReindex = async () => {
try {
await reindex();
toast.success(formatMessage({ id: 'memory.unified.reindexSuccess' }));
refetchStats();
} catch (err) {
console.error('Failed to reindex:', err);
toast.error(formatMessage({ id: 'memory.unified.reindexError' }));
}
};
const toggleTag = (tag: string) => { const toggleTag = (tag: string) => {
setSelectedTags((prev) => setSelectedTags((prev) =>
prev.includes(tag) ? prev.filter((t) => t !== tag) : [...prev, tag] prev.includes(tag) ? prev.filter((t) => t !== tag) : [...prev, tag]
@@ -556,6 +771,18 @@ export function MemoryPage() {
? `${(totalSize / 1024).toFixed(1)} KB` ? `${(totalSize / 1024).toFixed(1)} KB`
: `${(totalSize / (1024 * 1024)).toFixed(1)} MB`; : `${(totalSize / (1024 * 1024)).toFixed(1)} MB`;
const handleRefresh = () => {
if (isUnifiedTab) {
refetchUnified();
refetchStats();
} else {
refetch();
}
};
const isRefreshing = isUnifiedTab ? unifiedFetching : isFetching;
const activeError = isUnifiedTab ? unifiedError : error;
return ( return (
<div className="space-y-6"> <div className="space-y-6">
{/* Page Header */} {/* Page Header */}
@@ -570,21 +797,37 @@ export function MemoryPage() {
</p> </p>
</div> </div>
<div className="flex gap-2"> <div className="flex gap-2">
<Button variant="outline" onClick={() => refetch()} disabled={isFetching}> {isUnifiedTab && (
<RefreshCw className={cn('w-4 h-4 mr-2', isFetching && 'animate-spin')} /> <Button
variant="outline"
onClick={handleReindex}
disabled={isReindexing}
>
{isReindexing ? (
<Loader2 className="w-4 h-4 mr-2 animate-spin" />
) : (
<Zap className="w-4 h-4 mr-2" />
)}
{formatMessage({ id: isReindexing ? 'memory.unified.reindexing' : 'memory.unified.reindex' })}
</Button>
)}
<Button variant="outline" onClick={handleRefresh} disabled={isRefreshing}>
<RefreshCw className={cn('w-4 h-4 mr-2', isRefreshing && 'animate-spin')} />
{formatMessage({ id: 'common.actions.refresh' })} {formatMessage({ id: 'common.actions.refresh' })}
</Button> </Button>
<Button onClick={() => { setEditingMemory(null); setIsNewMemoryOpen(true); }}> {!isUnifiedTab && (
<Plus className="w-4 h-4 mr-2" /> <Button onClick={() => { setEditingMemory(null); setIsNewMemoryOpen(true); }}>
{formatMessage({ id: 'memory.actions.add' })} <Plus className="w-4 h-4 mr-2" />
</Button> {formatMessage({ id: 'memory.actions.add' })}
</Button>
)}
</div> </div>
</div> </div>
{/* Tab Navigation - styled like LiteTasksPage */} {/* Tab Navigation */}
<TabsNavigation <TabsNavigation
value={currentTab} value={currentTab}
onValueChange={(v) => setCurrentTab(v as 'memories' | 'favorites' | 'archived')} onValueChange={(v) => setCurrentTab(v as typeof currentTab)}
tabs={[ tabs={[
{ {
value: 'memories', value: 'memories',
@@ -601,141 +844,285 @@ export function MemoryPage() {
label: formatMessage({ id: 'memory.tabs.archived' }), label: formatMessage({ id: 'memory.tabs.archived' }),
icon: <Archive className="h-4 w-4" />, icon: <Archive className="h-4 w-4" />,
}, },
{
value: 'unifiedSearch',
label: formatMessage({ id: 'memory.tabs.unifiedSearch' }),
icon: <Search className="h-4 w-4" />,
},
]} ]}
/> />
{/* Error alert */} {/* Error alert */}
{error && ( {activeError && (
<div className="flex items-center gap-2 p-4 rounded-lg bg-destructive/10 border border-destructive/30 text-destructive"> <div className="flex items-center gap-2 p-4 rounded-lg bg-destructive/10 border border-destructive/30 text-destructive">
<AlertCircle className="h-5 w-5 flex-shrink-0" /> <AlertCircle className="h-5 w-5 flex-shrink-0" />
<div className="flex-1"> <div className="flex-1">
<p className="text-sm font-medium">{formatMessage({ id: 'common.errors.loadFailed' })}</p> <p className="text-sm font-medium">{formatMessage({ id: 'common.errors.loadFailed' })}</p>
<p className="text-xs mt-0.5">{error.message}</p> <p className="text-xs mt-0.5">{activeError.message}</p>
</div> </div>
<Button variant="outline" size="sm" onClick={() => refetch()}> <Button variant="outline" size="sm" onClick={handleRefresh}>
{formatMessage({ id: 'home.errors.retry' })} {formatMessage({ id: 'home.errors.retry' })}
</Button> </Button>
</div> </div>
)} )}
{/* Stats Cards */} {/* Stats Cards */}
<div className="grid grid-cols-1 md:grid-cols-3 gap-4"> {isUnifiedTab ? (
<Card className="p-4"> /* Unified Stats Cards */
<div className="flex items-center gap-3"> <div className="grid grid-cols-1 md:grid-cols-4 gap-4">
<div className="p-2 rounded-lg bg-primary/10"> <Card className="p-4">
<Database className="w-5 h-5 text-primary" /> <div className="flex items-center gap-3">
<div className="p-2 rounded-lg bg-primary/10">
<Database className="w-5 h-5 text-primary" />
</div>
<div>
<div className="text-2xl font-bold text-foreground">
{statsLoading ? '-' : (unifiedStats?.core_memories.total ?? 0)}
</div>
<p className="text-sm text-muted-foreground">{formatMessage({ id: 'memory.stats.count' })}</p>
</div>
</div> </div>
<div> </Card>
<div className="text-2xl font-bold text-foreground">{memories.length}</div> <Card className="p-4">
<p className="text-sm text-muted-foreground">{formatMessage({ id: 'memory.stats.count' })}</p> <div className="flex items-center gap-3">
<div className="p-2 rounded-lg bg-orange-500/10">
<Hash className="w-5 h-5 text-orange-500" />
</div>
<div>
<div className="text-2xl font-bold text-foreground">
{statsLoading ? '-' : (unifiedStats?.entities ?? 0)}
</div>
<p className="text-sm text-muted-foreground">{formatMessage({ id: 'memory.stats.entities' })}</p>
</div>
</div> </div>
</div> </Card>
</Card> <Card className="p-4">
<Card className="p-4"> <div className="flex items-center gap-3">
<div className="flex items-center gap-3"> <div className="p-2 rounded-lg bg-blue-500/10">
<div className="p-2 rounded-lg bg-info/10"> <Layers className="w-5 h-5 text-blue-500" />
<FileText className="w-5 h-5 text-info" /> </div>
<div>
<div className="text-2xl font-bold text-foreground">
{statsLoading ? '-' : (unifiedStats?.vector_index.total_chunks ?? 0)}
</div>
<p className="text-sm text-muted-foreground">{formatMessage({ id: 'memory.stats.vectorChunks' })}</p>
</div>
</div> </div>
<div> </Card>
<div className="text-2xl font-bold text-foreground">{claudeMdCount}</div> <Card className="p-4">
<p className="text-sm text-muted-foreground">{formatMessage({ id: 'memory.stats.claudeMdCount' })}</p> <div className="flex items-center gap-3">
<div className={cn(
"p-2 rounded-lg",
unifiedStats?.vector_index.hnsw_available ? "bg-green-500/10" : "bg-muted"
)}>
<Zap className={cn(
"w-5 h-5",
unifiedStats?.vector_index.hnsw_available ? "text-green-500" : "text-muted-foreground"
)} />
</div>
<div>
<div className="text-2xl font-bold text-foreground">
{statsLoading ? '-' : (unifiedStats?.vector_index.hnsw_available ? unifiedStats.vector_index.hnsw_count : 'N/A')}
</div>
<p className="text-sm text-muted-foreground">{formatMessage({ id: 'memory.stats.hnswStatus' })}</p>
</div>
</div> </div>
</div> </Card>
</Card> </div>
<Card className="p-4"> ) : (
<div className="flex items-center gap-3"> /* Standard Stats Cards */
<div className="p-2 rounded-lg bg-success/10"> <div className="grid grid-cols-1 md:grid-cols-3 gap-4">
<Brain className="w-5 h-5 text-success" /> <Card className="p-4">
<div className="flex items-center gap-3">
<div className="p-2 rounded-lg bg-primary/10">
<Database className="w-5 h-5 text-primary" />
</div>
<div>
<div className="text-2xl font-bold text-foreground">{memories.length}</div>
<p className="text-sm text-muted-foreground">{formatMessage({ id: 'memory.stats.count' })}</p>
</div>
</div> </div>
<div> </Card>
<div className="text-2xl font-bold text-foreground">{formattedTotalSize}</div> <Card className="p-4">
<p className="text-sm text-muted-foreground">{formatMessage({ id: 'memory.stats.totalSize' })}</p> <div className="flex items-center gap-3">
<div className="p-2 rounded-lg bg-info/10">
<FileText className="w-5 h-5 text-info" />
</div>
<div>
<div className="text-2xl font-bold text-foreground">{claudeMdCount}</div>
<p className="text-sm text-muted-foreground">{formatMessage({ id: 'memory.stats.claudeMdCount' })}</p>
</div>
</div> </div>
</div> </Card>
</Card> <Card className="p-4">
</div> <div className="flex items-center gap-3">
<div className="p-2 rounded-lg bg-success/10">
<Brain className="w-5 h-5 text-success" />
</div>
<div>
<div className="text-2xl font-bold text-foreground">{formattedTotalSize}</div>
<p className="text-sm text-muted-foreground">{formatMessage({ id: 'memory.stats.totalSize' })}</p>
</div>
</div>
</Card>
</div>
)}
{/* Search and Filters */} {/* Search and Filters */}
<div className="space-y-3"> {isUnifiedTab ? (
<div className="relative"> /* Unified Search Input + Category Filter */
<Search className="absolute left-3 top-1/2 -translate-y-1/2 w-4 h-4 text-muted-foreground" /> <div className="space-y-3">
<Input <div className="flex gap-3">
placeholder={formatMessage({ id: 'memory.filters.search' })} <div className="relative flex-1">
value={searchQuery} <Search className="absolute left-3 top-1/2 -translate-y-1/2 w-4 h-4 text-muted-foreground" />
onChange={(e) => setSearchQuery(e.target.value)} <Input
className="pl-9" placeholder={formatMessage({ id: 'memory.filters.searchUnified' })}
/> value={unifiedQuery}
</div> onChange={(e) => setUnifiedQuery(e.target.value)}
className="pl-9"
{/* Tags Filter */} />
{allTags.length > 0 && ( </div>
<div className="flex flex-wrap gap-2"> <select
<span className="text-sm text-muted-foreground py-1">{formatMessage({ id: 'memory.card.tags' })}:</span> value={selectedCategory}
{allTags.map((tag) => ( onChange={(e) => setSelectedCategory(e.target.value)}
<Button className="px-3 py-2 bg-background border border-input rounded-md text-sm min-w-[160px]"
key={tag} >
variant={selectedTags.includes(tag) ? 'default' : 'outline'} {CATEGORY_OPTIONS.map((opt) => (
size="sm" <option key={opt.value} value={opt.value}>
className="h-7" {formatMessage({ id: opt.labelId })}
onClick={() => toggleTag(tag)} </option>
> ))}
<Tag className="w-3 h-3 mr-1" /> </select>
{tag}
</Button>
))}
{selectedTags.length > 0 && (
<Button
variant="ghost"
size="sm"
className="h-7"
onClick={() => setSelectedTags([])}
>
{formatMessage({ id: 'memory.filters.clear' })}
</Button>
)}
</div> </div>
)} {unifiedQuery.trim().length > 0 && !unifiedLoading && (
</div> <p className="text-sm text-muted-foreground">
{formatMessage({ id: 'memory.unified.resultCount' }, { count: unifiedTotal })}
{/* Memory List */} </p>
{isLoading ? ( )}
<div className="grid grid-cols-1 md:grid-cols-2 lg:grid-cols-3 gap-4">
{[1, 2, 3, 4, 5, 6].map((i) => (
<div key={i} className="h-64 bg-muted animate-pulse rounded-lg" />
))}
</div> </div>
) : memories.length === 0 ? (
<Card className="p-8 text-center">
<Brain className="w-12 h-12 mx-auto text-muted-foreground/50" />
<h3 className="mt-4 text-lg font-medium text-foreground">
{formatMessage({ id: 'memory.emptyState.title' })}
</h3>
<p className="mt-2 text-muted-foreground">
{formatMessage({ id: 'memory.emptyState.message' })}
</p>
<Button className="mt-4" onClick={() => { setEditingMemory(null); setIsNewMemoryOpen(true); }}>
<Plus className="w-4 h-4 mr-2" />
{formatMessage({ id: 'memory.emptyState.createFirst' })}
</Button>
</Card>
) : ( ) : (
<div className="grid grid-cols-1 md:grid-cols-2 lg:grid-cols-3 gap-4"> /* Standard Search + Tag Filters */
{memories.map((memory) => ( <div className="space-y-3">
<MemoryCard <div className="relative">
key={memory.id} <Search className="absolute left-3 top-1/2 -translate-y-1/2 w-4 h-4 text-muted-foreground" />
memory={memory} <Input
onView={setViewingMemory} placeholder={formatMessage({ id: 'memory.filters.search' })}
onEdit={handleEdit} value={searchQuery}
onDelete={handleDelete} onChange={(e) => setSearchQuery(e.target.value)}
onCopy={copyToClipboard} className="pl-9"
onToggleFavorite={handleToggleFavorite}
onArchive={handleArchive}
onUnarchive={handleUnarchive}
/> />
))} </div>
{/* Tags Filter */}
{allTags.length > 0 && (
<div className="flex flex-wrap gap-2">
<span className="text-sm text-muted-foreground py-1">{formatMessage({ id: 'memory.card.tags' })}:</span>
{allTags.map((tag) => (
<Button
key={tag}
variant={selectedTags.includes(tag) ? 'default' : 'outline'}
size="sm"
className="h-7"
onClick={() => toggleTag(tag)}
>
<Tag className="w-3 h-3 mr-1" />
{tag}
</Button>
))}
{selectedTags.length > 0 && (
<Button
variant="ghost"
size="sm"
className="h-7"
onClick={() => setSelectedTags([])}
>
{formatMessage({ id: 'memory.filters.clear' })}
</Button>
)}
</div>
)}
</div> </div>
)} )}
{/* Content Area */}
{isUnifiedTab ? (
/* Unified Search Results */
unifiedLoading ? (
<div className="flex items-center justify-center py-12">
<Loader2 className="w-6 h-6 animate-spin text-primary mr-2" />
<span className="text-muted-foreground">
{formatMessage({ id: 'memory.unified.searching' })}
</span>
</div>
) : unifiedQuery.trim().length === 0 ? (
<Card className="p-8 text-center">
<Search className="w-12 h-12 mx-auto text-muted-foreground/50" />
<h3 className="mt-4 text-lg font-medium text-foreground">
{formatMessage({ id: 'memory.tabs.unifiedSearch' })}
</h3>
<p className="mt-2 text-muted-foreground">
{formatMessage({ id: 'memory.filters.searchUnified' })}
</p>
</Card>
) : unifiedResults.length === 0 ? (
<Card className="p-8 text-center">
<Search className="w-12 h-12 mx-auto text-muted-foreground/50" />
<h3 className="mt-4 text-lg font-medium text-foreground">
{formatMessage({ id: 'memory.unified.noResults' })}
</h3>
</Card>
) : (
<div className="grid grid-cols-1 md:grid-cols-2 lg:grid-cols-3 gap-4">
{unifiedResults.map((result) => (
<UnifiedResultCard
key={`${result.source_type}-${result.source_id}`}
result={result}
onCopy={copyToClipboard}
/>
))}
</div>
)
) : (
/* Standard Memory List */
isLoading ? (
<div className="grid grid-cols-1 md:grid-cols-2 lg:grid-cols-3 gap-4">
{[1, 2, 3, 4, 5, 6].map((i) => (
<div key={i} className="h-64 bg-muted animate-pulse rounded-lg" />
))}
</div>
) : memories.length === 0 ? (
<Card className="p-8 text-center">
<Brain className="w-12 h-12 mx-auto text-muted-foreground/50" />
<h3 className="mt-4 text-lg font-medium text-foreground">
{formatMessage({ id: 'memory.emptyState.title' })}
</h3>
<p className="mt-2 text-muted-foreground">
{formatMessage({ id: 'memory.emptyState.message' })}
</p>
<Button className="mt-4" onClick={() => { setEditingMemory(null); setIsNewMemoryOpen(true); }}>
<Plus className="w-4 h-4 mr-2" />
{formatMessage({ id: 'memory.emptyState.createFirst' })}
</Button>
</Card>
) : (
<div className="grid grid-cols-1 md:grid-cols-2 lg:grid-cols-3 gap-4">
{memories.map((memory) => (
<MemoryCard
key={memory.id}
memory={memory}
onView={setViewingMemory}
onEdit={handleEdit}
onDelete={handleDelete}
onCopy={copyToClipboard}
onToggleFavorite={handleToggleFavorite}
onArchive={handleArchive}
onUnarchive={handleUnarchive}
/>
))}
</div>
)
)}
{/* View Memory Dialog */} {/* View Memory Dialog */}
<ViewMemoryDialog <ViewMemoryDialog
memory={viewingMemory} memory={viewingMemory}

View File

@@ -55,6 +55,7 @@ import {
useCcwInstallations, useCcwInstallations,
useUpgradeCcwInstallation, useUpgradeCcwInstallation,
} from '@/hooks/useSystemSettings'; } from '@/hooks/useSystemSettings';
import { RemoteNotificationSection } from '@/components/settings/RemoteNotificationSection';
// ========== File Path Input with Native File Picker ========== // ========== File Path Input with Native File Picker ==========
@@ -1299,6 +1300,9 @@ export function SettingsPage() {
</div> </div>
</Card> </Card>
{/* Remote Notifications */}
<RemoteNotificationSection />
{/* Reset Settings */} {/* Reset Settings */}
<Card className="p-6 border-destructive/50"> <Card className="p-6 border-destructive/50">
<h2 className="text-lg font-semibold text-foreground flex items-center gap-2 mb-4"> <h2 className="text-lg font-semibold text-foreground flex items-center gap-2 mb-4">

View File

@@ -0,0 +1,193 @@
// ========================================
// Remote Notification Types (Frontend)
// ========================================
// Type definitions for remote notification system UI
// Mirrors backend types with UI-specific additions
/**
 * Supported notification platforms
 */
export type NotificationPlatform = 'discord' | 'telegram' | 'webhook';
/**
 * Event types that can trigger notifications
 */
export type NotificationEventType =
  | 'ask-user-question'
  | 'session-start'
  | 'session-end'
  | 'task-completed'
  | 'task-failed';
/**
 * Discord platform configuration
 */
export interface DiscordConfig {
  enabled: boolean;
  // Incoming-webhook URL of the target channel
  webhookUrl: string;
  // Optional display-name override for the webhook bot
  username?: string;
  avatarUrl?: string;
}
/**
 * Telegram platform configuration
 */
export interface TelegramConfig {
  enabled: boolean;
  botToken: string;
  chatId: string;
  // Telegram message formatting mode; omitted = plain text
  parseMode?: 'HTML' | 'Markdown' | 'MarkdownV2';
}
/**
 * Generic Webhook platform configuration
 */
export interface WebhookConfig {
  enabled: boolean;
  url: string;
  method: 'POST' | 'PUT';
  headers?: Record<string, string>;
  // Per-request timeout — presumably milliseconds; confirm against backend
  timeout?: number;
}
/**
 * Event configuration
 *
 * Binds one event type to the set of platforms that should receive it.
 */
export interface EventConfig {
  event: NotificationEventType;
  platforms: NotificationPlatform[];
  enabled: boolean;
}
/**
 * Full remote notification configuration
 */
export interface RemoteNotificationConfig {
  // Master switch; when false no notifications are sent regardless of events
  enabled: boolean;
  platforms: {
    discord?: DiscordConfig;
    telegram?: TelegramConfig;
    webhook?: WebhookConfig;
  };
  events: EventConfig[];
  // Global send timeout — presumably milliseconds (default 10000 in
  // getDefaultConfig); confirm against backend
  timeout: number;
}
/**
 * Test notification request
 */
export interface TestNotificationRequest {
  platform: NotificationPlatform;
  config: DiscordConfig | TelegramConfig | WebhookConfig;
}
/**
 * Test notification result
 */
export interface TestNotificationResult {
  success: boolean;
  error?: string;
  // Round-trip time of the test request — presumably milliseconds
  responseTime?: number;
}
/**
 * Platform display info
 */
export interface PlatformInfo {
  id: NotificationPlatform;
  name: string;
  icon: string;
  description: string;
  // Config keys the UI must require before enabling the platform
  requiredFields: string[];
}
/**
 * Event display info
 */
export interface EventInfo {
  id: NotificationEventType;
  name: string;
  description: string;
  icon: string;
}
/**
 * Predefined platform information
 *
 * Static display metadata for each supported platform, keyed by
 * NotificationPlatform id; drives platform cards and required-field
 * validation in the settings UI.
 */
export const PLATFORM_INFO: Record<NotificationPlatform, PlatformInfo> = {
  discord: {
    id: 'discord',
    name: 'Discord',
    // icon values look like lucide-style identifiers — confirm in UI layer
    icon: 'message-circle',
    description: 'Send notifications to Discord channels via webhook',
    requiredFields: ['webhookUrl'],
  },
  telegram: {
    id: 'telegram',
    name: 'Telegram',
    icon: 'send',
    description: 'Send notifications to Telegram chats via bot',
    requiredFields: ['botToken', 'chatId'],
  },
  webhook: {
    id: 'webhook',
    name: 'Custom Webhook',
    icon: 'link',
    description: 'Send notifications to a custom HTTP endpoint',
    requiredFields: ['url'],
  },
};
/**
 * Predefined event information
 *
 * Static display metadata for each notification event type, keyed by
 * NotificationEventType id; used to render the event toggle list.
 */
export const EVENT_INFO: Record<NotificationEventType, EventInfo> = {
  'ask-user-question': {
    id: 'ask-user-question',
    name: 'Ask User Question',
    description: 'Notification when Claude asks a question via AskUserQuestion',
    // icon values look like lucide-style identifiers — confirm in UI layer
    icon: 'help-circle',
  },
  'session-start': {
    id: 'session-start',
    name: 'Session Start',
    description: 'Notification when a CLI session starts',
    icon: 'play',
  },
  'session-end': {
    id: 'session-end',
    name: 'Session End',
    description: 'Notification when a CLI session ends',
    icon: 'square',
  },
  'task-completed': {
    id: 'task-completed',
    name: 'Task Completed',
    description: 'Notification when a task completes successfully',
    icon: 'check-circle',
  },
  'task-failed': {
    id: 'task-failed',
    name: 'Task Failed',
    description: 'Notification when a task fails',
    icon: 'alert-circle',
  },
};
/**
* Default configuration for UI initialization
*/
export function getDefaultConfig(): RemoteNotificationConfig {
return {
enabled: false,
platforms: {},
events: [
{ event: 'ask-user-question', platforms: ['discord', 'telegram'], enabled: true },
{ event: 'session-start', platforms: [], enabled: false },
{ event: 'session-end', platforms: [], enabled: false },
{ event: 'task-completed', platforms: [], enabled: false },
{ event: 'task-failed', platforms: ['discord', 'telegram'], enabled: true },
],
timeout: 10000,
};
}

View File

@@ -0,0 +1,473 @@
#!/usr/bin/env python3
"""
Unified Memory Embedder - Bridge CCW to CodexLens VectorStore (HNSW)
Uses CodexLens VectorStore for HNSW-indexed vector storage and search,
replacing full-table-scan cosine similarity with sub-10ms approximate
nearest neighbor lookups.
Protocol: JSON via stdin/stdout
Operations: embed, search, search_by_vector, status, reindex
Usage:
echo '{"operation":"embed","store_path":"...","chunks":[...]}' | python unified_memory_embedder.py
echo '{"operation":"search","store_path":"...","query":"..."}' | python unified_memory_embedder.py
echo '{"operation":"status","store_path":"..."}' | python unified_memory_embedder.py
echo '{"operation":"reindex","store_path":"..."}' | python unified_memory_embedder.py
"""
import json
import sys
import time
from pathlib import Path
from typing import List, Dict, Any, Optional
try:
import numpy as np
except ImportError:
print(json.dumps({
"success": False,
"error": "numpy is required. Install with: pip install numpy"
}))
sys.exit(1)
try:
from codexlens.semantic.factory import get_embedder, clear_embedder_cache
from codexlens.semantic.vector_store import VectorStore
from codexlens.entities import SemanticChunk
except ImportError:
print(json.dumps({
"success": False,
"error": "CodexLens not found. Install with: pip install codex-lens[semantic]"
}))
sys.exit(1)
# Valid category values for filtering
VALID_CATEGORIES = {"core_memory", "cli_history", "workflow", "entity", "pattern"}
class UnifiedMemoryEmbedder:
    """Unified embedder backed by CodexLens VectorStore (HNSW).

    Wraps a CodexLens VectorStore for HNSW-indexed vector storage plus a
    lazily created fastembed embedder, and implements the five operations
    the JSON stdin/stdout bridge speaks: embed, search, search_by_vector,
    status, reindex.
    """

    def __init__(self, store_path: str):
        """
        Initialize with path to VectorStore database directory.

        Args:
            store_path: Directory containing vectors.db and vectors.hnsw
        """
        self.store_path = Path(store_path)
        self.store_path.mkdir(parents=True, exist_ok=True)
        db_path = str(self.store_path / "vectors.db")
        self.store = VectorStore(db_path)
        # Lazy-load embedder to avoid ~0.8s model loading for status command
        self._embedder = None

    @property
    def embedder(self):
        """Lazy-load the embedder on first access."""
        if self._embedder is None:
            self._embedder = get_embedder(
                backend="fastembed",
                profile="code",
                use_gpu=True
            )
        return self._embedder

    @staticmethod
    def _result_to_match(result) -> Dict[str, Any]:
        """Convert one VectorStore search result into a match dict.

        Shared by search() and search_by_vector(), which previously
        duplicated this formatting logic. Metadata may be stored as a JSON
        string; decode defensively and fall back to an empty dict on any
        parse failure.
        """
        meta = result.metadata if result.metadata else {}
        if isinstance(meta, str):
            try:
                meta = json.loads(meta)
            except (json.JSONDecodeError, TypeError):
                meta = {}
        return {
            "content": result.content or result.excerpt or "",
            "score": round(float(result.score), 4),
            "source_id": meta.get("source_id", result.path or ""),
            "source_type": meta.get("source_type", ""),
            "chunk_index": meta.get("chunk_index", 0),
            # NOTE(review): embed() passes category to add_chunk as a column,
            # not inside metadata, so this is typically "" — confirm whether
            # the store surfaces category on the result object instead.
            "category": meta.get("category", ""),
            "metadata": meta
        }

    def embed(self, chunks: List[Dict[str, Any]], batch_size: int = 8) -> Dict[str, Any]:
        """
        Embed chunks and insert into VectorStore.

        Each chunk dict must contain:
        - content: str
        - source_id: str
        - source_type: str (e.g. "core_memory", "workflow", "cli_history")
        - category: str (e.g. "core_memory", "cli_history", "workflow", "entity", "pattern")

        Optional fields:
        - chunk_index: int (default 0)
        - metadata: dict (additional metadata)

        Args:
            chunks: List of chunk dicts to embed
            batch_size: Number of chunks to embed per batch

        Returns:
            Result dict with success, chunks_processed, chunks_failed, elapsed_time
        """
        start_time = time.time()
        chunks_processed = 0
        chunks_failed = 0

        if not chunks:
            return {
                "success": True,
                "chunks_processed": 0,
                "chunks_failed": 0,
                "elapsed_time": 0.0
            }

        # Process in batches so embedder memory stays bounded
        for i in range(0, len(chunks), batch_size):
            batch = chunks[i:i + batch_size]
            texts = [c["content"] for c in batch]

            try:
                # Batch embed
                embeddings = self.embedder.embed_to_numpy(texts)

                # Build SemanticChunks for the whole batch first
                semantic_chunks = []
                for j, chunk_data in enumerate(batch):
                    # category falls back to source_type, then a safe default
                    category = chunk_data.get("category", chunk_data.get("source_type", "core_memory"))
                    source_id = chunk_data.get("source_id", "")
                    chunk_index = chunk_data.get("chunk_index", 0)
                    extra_meta = chunk_data.get("metadata", {})

                    # Build metadata dict for VectorStore
                    metadata = {
                        "source_id": source_id,
                        "source_type": chunk_data.get("source_type", ""),
                        "chunk_index": chunk_index,
                        **extra_meta
                    }

                    sc = SemanticChunk(
                        content=chunk_data["content"],
                        embedding=embeddings[j].tolist(),
                        metadata=metadata
                    )
                    semantic_chunks.append((sc, source_id, category))

                # Insert into VectorStore; count failures per chunk so one
                # bad row does not fail the rest of the batch
                for sc, file_path, category in semantic_chunks:
                    try:
                        self.store.add_chunk(sc, file_path=file_path, category=category)
                        chunks_processed += 1
                    except Exception as e:
                        print(f"Error inserting chunk: {e}", file=sys.stderr)
                        chunks_failed += 1

            except Exception as e:
                # Embedding failure loses the whole batch
                print(f"Error embedding batch starting at {i}: {e}", file=sys.stderr)
                chunks_failed += len(batch)

        elapsed_time = time.time() - start_time
        return {
            "success": chunks_failed == 0,
            "chunks_processed": chunks_processed,
            "chunks_failed": chunks_failed,
            "elapsed_time": round(elapsed_time, 3)
        }

    def search(
        self,
        query: str,
        top_k: int = 10,
        min_score: float = 0.3,
        category: Optional[str] = None
    ) -> Dict[str, Any]:
        """
        Search VectorStore using HNSW index.

        Args:
            query: Search query text
            top_k: Number of results
            min_score: Minimum similarity threshold
            category: Optional category filter

        Returns:
            Result dict with success and matches list
        """
        try:
            start_time = time.time()

            # Generate query embedding (embed_to_numpy accepts single string)
            query_emb = self.embedder.embed_to_numpy(query)[0].tolist()

            # Search via VectorStore HNSW
            results = self.store.search_similar(
                query_emb,
                top_k=top_k,
                min_score=min_score,
                category=category
            )

            # Reported elapsed time intentionally includes embedding cost
            elapsed_time = time.time() - start_time

            return {
                "success": True,
                "matches": [self._result_to_match(r) for r in results],
                "elapsed_time": round(elapsed_time, 3),
                "total_searched": len(results)
            }

        except Exception as e:
            return {
                "success": False,
                "matches": [],
                "error": str(e)
            }

    def search_by_vector(
        self,
        vector: List[float],
        top_k: int = 10,
        min_score: float = 0.3,
        category: Optional[str] = None
    ) -> Dict[str, Any]:
        """
        Search VectorStore using a pre-computed embedding vector (no re-embedding).

        Args:
            vector: Pre-computed embedding vector (list of floats)
            top_k: Number of results
            min_score: Minimum similarity threshold
            category: Optional category filter

        Returns:
            Result dict with success and matches list
        """
        try:
            start_time = time.time()

            # Search via VectorStore HNSW directly with provided vector
            results = self.store.search_similar(
                vector,
                top_k=top_k,
                min_score=min_score,
                category=category
            )

            elapsed_time = time.time() - start_time

            return {
                "success": True,
                "matches": [self._result_to_match(r) for r in results],
                "elapsed_time": round(elapsed_time, 3),
                "total_searched": len(results)
            }

        except Exception as e:
            return {
                "success": False,
                "matches": [],
                "error": str(e)
            }

    def status(self) -> Dict[str, Any]:
        """
        Get VectorStore index status.

        Returns:
            Status dict with total_chunks, hnsw_available, dimension, etc.
        """
        try:
            total_chunks = self.store.count_chunks()
            hnsw_available = self.store.ann_available
            hnsw_count = self.store.ann_count
            dimension = self.store.dimension or 768

            # Count per category from SQLite; best-effort only — a failure
            # here must not fail the whole status call
            categories = {}
            try:
                import sqlite3
                db_path = str(self.store_path / "vectors.db")
                with sqlite3.connect(db_path) as conn:
                    rows = conn.execute(
                        "SELECT category, COUNT(*) FROM semantic_chunks GROUP BY category"
                    ).fetchall()
                    for row in rows:
                        categories[row[0] or "unknown"] = row[1]
            except Exception:
                pass

            return {
                "success": True,
                "total_chunks": total_chunks,
                "hnsw_available": hnsw_available,
                "hnsw_count": hnsw_count,
                "dimension": dimension,
                "categories": categories,
                "model_config": {
                    "backend": "fastembed",
                    "profile": "code",
                    "dimension": 768,
                    "max_tokens": 8192
                }
            }

        except Exception as e:
            return {
                "success": False,
                "total_chunks": 0,
                "hnsw_available": False,
                "hnsw_count": 0,
                "dimension": 0,
                "error": str(e)
            }

    def reindex(self) -> Dict[str, Any]:
        """
        Rebuild HNSW index from scratch.

        Returns:
            Result dict with success and timing
        """
        try:
            start_time = time.time()
            self.store.rebuild_ann_index()
            elapsed_time = time.time() - start_time

            return {
                "success": True,
                "hnsw_count": self.store.ann_count,
                "elapsed_time": round(elapsed_time, 3)
            }
        except Exception as e:
            return {
                "success": False,
                "error": str(e)
            }
def main():
    """Entry point for the JSON bridge.

    Reads one JSON request from stdin, dispatches to UnifiedMemoryEmbedder,
    and writes one JSON response to stdout. Exits non-zero on any failure.
    """
    def bail(message: str) -> None:
        # Emit a failure envelope on stdout and terminate with status 1.
        print(json.dumps({
            "success": False,
            "error": message
        }))
        sys.exit(1)

    try:
        payload = sys.stdin.read()
        if not payload.strip():
            bail("No input provided. Send JSON via stdin.")
        request = json.loads(payload)
    except json.JSONDecodeError as e:
        bail(f"Invalid JSON input: {e}")

    operation = request.get("operation")
    store_path = request.get("store_path")

    if not operation:
        bail("Missing required field: operation")
    if not store_path:
        bail("Missing required field: store_path")

    try:
        embedder = UnifiedMemoryEmbedder(store_path)

        if operation == "embed":
            result = embedder.embed(
                request.get("chunks", []),
                batch_size=request.get("batch_size", 8),
            )
        elif operation == "search":
            query = request.get("query", "")
            if query:
                result = embedder.search(
                    query,
                    top_k=request.get("top_k", 10),
                    min_score=request.get("min_score", 0.3),
                    category=request.get("category"),
                )
            else:
                result = {"success": False, "error": "Missing required field: query", "matches": []}
        elif operation == "search_by_vector":
            vector = request.get("vector", [])
            if vector:
                result = embedder.search_by_vector(
                    vector,
                    top_k=request.get("top_k", 10),
                    min_score=request.get("min_score", 0.3),
                    category=request.get("category"),
                )
            else:
                result = {"success": False, "error": "Missing required field: vector", "matches": []}
        elif operation == "status":
            result = embedder.status()
        elif operation == "reindex":
            result = embedder.reindex()
        else:
            result = {
                "success": False,
                "error": f"Unknown operation: {operation}. Valid: embed, search, search_by_vector, status, reindex"
            }

        print(json.dumps(result))

        # Clean up ONNX resources to ensure process can exit cleanly
        clear_embedder_cache()

    except Exception as e:
        try:
            clear_embedder_cache()
        except Exception:
            pass
        print(json.dumps({
            "success": False,
            "error": str(e)
        }))
        sys.exit(1)
if __name__ == "__main__":
main()

View File

@@ -263,6 +263,10 @@ export function run(argv: string[]): void {
.option('--output <file>', 'Output file path for export') .option('--output <file>', 'Output file path for export')
.option('--overwrite', 'Overwrite existing memories when importing') .option('--overwrite', 'Overwrite existing memories when importing')
.option('--prefix <prefix>', 'Add prefix to imported memory IDs') .option('--prefix <prefix>', 'Add prefix to imported memory IDs')
.option('--unified', 'Use unified vector+FTS search (for search subcommand)')
.option('--topK <n>', 'Max results for unified search', '20')
.option('--minScore <n>', 'Min relevance score for unified search', '0')
.option('--category <cat>', 'Filter by category for unified search')
.action((subcommand, args, options) => coreMemoryCommand(subcommand, args, options)); .action((subcommand, args, options) => coreMemoryCommand(subcommand, args, options));
// Hook command - CLI endpoint for Claude Code hooks // Hook command - CLI endpoint for Claude Code hooks

View File

@@ -35,6 +35,10 @@ interface CommandOptions {
delete?: boolean; delete?: boolean;
merge?: string; merge?: string;
dedup?: boolean; dedup?: boolean;
unified?: boolean;
topK?: string;
minScore?: string;
category?: string;
} }
/** /**
@@ -844,6 +848,114 @@ async function jobsAction(options: CommandOptions): Promise<void> {
} }
} }
/**
 * Unified vector+FTS search across all memory stores
 *
 * Lazily loads UnifiedMemoryService, runs a combined query, and renders
 * results either as JSON (--json) or a formatted console listing.
 * Exits with code 1 on a missing query or a service error.
 */
async function unifiedSearchAction(keyword: string, options: CommandOptions): Promise<void> {
  if (!keyword || keyword.trim() === '') {
    console.error(chalk.red('Error: Query is required'));
    console.error(chalk.gray('Usage: ccw core-memory search --unified <query> [--topK 20] [--minScore 0] [--category <cat>]'));
    process.exit(1);
  }

  try {
    const { UnifiedMemoryService } = await import('../core/unified-memory-service.js');
    const service = new UnifiedMemoryService(getProjectPath());

    const topK = parseInt(options.topK || '20', 10);
    const minScore = parseFloat(options.minScore || '0');
    const category = options.category || undefined;

    // FIX: keep stdout clean in --json mode — the decorative header was
    // previously printed unconditionally, corrupting the JSON payload for
    // machine consumers.
    if (!options.json) {
      console.log(chalk.cyan(`\n Unified search: "${keyword}" (topK=${topK}, minScore=${minScore})\n`));
    }

    const results = await service.search(keyword, {
      limit: topK,
      minScore,
      // NOTE(review): `as any` bypasses the category union type — tighten
      // once the service exposes a shared category type.
      category: category as any,
    });

    // JSON output first so an empty result set still yields valid JSON
    // (previously --json with zero results printed a plain-text message).
    if (options.json) {
      console.log(JSON.stringify({ query: keyword, total: results.length, results }, null, 2));
      return;
    }

    if (results.length === 0) {
      console.log(chalk.yellow(' No results found.\n'));
      return;
    }

    console.log(chalk.gray(' -----------------------------------------------------------------------'));
    for (const result of results) {
      // Show which ranking signals contributed to this hit
      const sources: string[] = [];
      if (result.rank_sources.vector_rank) sources.push(`vec:#${result.rank_sources.vector_rank}`);
      if (result.rank_sources.fts_rank) sources.push(`fts:#${result.rank_sources.fts_rank}`);
      if (result.rank_sources.heat_score) sources.push(`heat:${result.rank_sources.heat_score.toFixed(1)}`);

      const snippet = result.content.substring(0, 120).replace(/\n/g, ' ');

      console.log(
        chalk.cyan(` ${result.source_id}`) +
        chalk.gray(` [${result.source_type}/${result.category}]`) +
        chalk.white(` score=${result.score.toFixed(4)}`)
      );
      console.log(chalk.gray(` Sources: ${sources.join(' | ')}`));
      console.log(chalk.white(` ${snippet}${result.content.length > 120 ? '...' : ''}`));
      console.log(chalk.gray(' -----------------------------------------------------------------------'));
    }
    console.log(chalk.gray(`\n Total: ${results.length}\n`));
  } catch (error) {
    console.error(chalk.red(`Error: ${(error as Error).message}`));
    process.exit(1);
  }
}
/**
 * Rebuild the unified HNSW vector index from scratch
 *
 * Verifies the Python embedder bridge is available, then rebuilds the
 * index via UnifiedVectorIndex.reindexAll(). Exits with code 1 when the
 * embedder is unavailable or the rebuild fails.
 */
async function reindexAction(options: CommandOptions): Promise<void> {
  try {
    const { UnifiedVectorIndex, isUnifiedEmbedderAvailable } = await import('../core/unified-vector-index.js');

    if (!isUnifiedEmbedderAvailable()) {
      console.error(chalk.red('Error: Unified embedder is not available.'));
      console.error(chalk.gray('Ensure Python venv and embedder script are set up.'));
      process.exit(1);
    }

    const index = new UnifiedVectorIndex(getProjectPath());

    // FIX: keep stdout clean in --json mode — the progress banner was
    // previously printed unconditionally ahead of the JSON payload.
    if (!options.json) {
      console.log(chalk.cyan('\n Rebuilding unified vector index...\n'));
    }

    const result = await index.reindexAll();

    if (!result.success) {
      console.error(chalk.red(` Reindex failed: ${result.error}\n`));
      process.exit(1);
    }

    if (options.json) {
      console.log(JSON.stringify(result, null, 2));
      return;
    }

    console.log(chalk.green(' Reindex complete.'));
    if (result.hnsw_count !== undefined) {
      console.log(chalk.white(` HNSW vectors: ${result.hnsw_count}`));
    }
    if (result.elapsed_time !== undefined) {
      console.log(chalk.white(` Elapsed: ${result.elapsed_time.toFixed(2)}s`));
    }
    console.log();
  } catch (error) {
    console.error(chalk.red(`Error: ${(error as Error).message}`));
    process.exit(1);
  }
}
/** /**
* Core Memory command entry point * Core Memory command entry point
*/ */
@@ -889,7 +1001,11 @@ export async function coreMemoryCommand(
break; break;
case 'search': case 'search':
await searchAction(textArg, options); if (options.unified) {
await unifiedSearchAction(textArg, options);
} else {
await searchAction(textArg, options);
}
break; break;
case 'projects': case 'projects':
@@ -921,6 +1037,10 @@ export async function coreMemoryCommand(
await jobsAction(options); await jobsAction(options);
break; break;
case 'reindex':
await reindexAction(options);
break;
default: default:
console.log(chalk.bold.cyan('\n CCW Core Memory\n')); console.log(chalk.bold.cyan('\n CCW Core Memory\n'));
console.log(' Manage core memory entries and session clusters.\n'); console.log(' Manage core memory entries and session clusters.\n');
@@ -945,12 +1065,14 @@ export async function coreMemoryCommand(
console.log(chalk.white(' context ') + chalk.gray('Get progressive index')); console.log(chalk.white(' context ') + chalk.gray('Get progressive index'));
console.log(chalk.white(' load-cluster <id> ') + chalk.gray('Load cluster context')); console.log(chalk.white(' load-cluster <id> ') + chalk.gray('Load cluster context'));
console.log(chalk.white(' search <keyword> ') + chalk.gray('Search sessions')); console.log(chalk.white(' search <keyword> ') + chalk.gray('Search sessions'));
console.log(chalk.white(' search --unified <query> ') + chalk.gray('Unified vector+FTS search'));
console.log(); console.log();
console.log(chalk.bold(' Memory V2 Pipeline:')); console.log(chalk.bold(' Memory V2 Pipeline:'));
console.log(chalk.white(' extract ') + chalk.gray('Run batch memory extraction')); console.log(chalk.white(' extract ') + chalk.gray('Run batch memory extraction'));
console.log(chalk.white(' extract-status ') + chalk.gray('Show extraction pipeline status')); console.log(chalk.white(' extract-status ') + chalk.gray('Show extraction pipeline status'));
console.log(chalk.white(' consolidate ') + chalk.gray('Run memory consolidation')); console.log(chalk.white(' consolidate ') + chalk.gray('Run memory consolidation'));
console.log(chalk.white(' jobs ') + chalk.gray('List all pipeline jobs')); console.log(chalk.white(' jobs ') + chalk.gray('List all pipeline jobs'));
console.log(chalk.white(' reindex ') + chalk.gray('Rebuild unified vector index'));
console.log(); console.log();
console.log(chalk.bold(' Options:')); console.log(chalk.bold(' Options:'));
console.log(chalk.gray(' --id <id> Memory ID (for export/summary)')); console.log(chalk.gray(' --id <id> Memory ID (for export/summary)'));

View File

@@ -12,7 +12,7 @@ interface HookOptions {
stdin?: boolean; stdin?: boolean;
sessionId?: string; sessionId?: string;
prompt?: string; prompt?: string;
type?: 'session-start' | 'context'; type?: 'session-start' | 'context' | 'session-end';
path?: string; path?: string;
} }
@@ -95,10 +95,32 @@ function getProjectPath(hookCwd?: string): string {
return hookCwd || process.cwd(); return hookCwd || process.cwd();
} }
/**
* Check if UnifiedContextBuilder is available (embedder dependencies present).
* Returns the builder instance or null if not available.
*/
async function tryCreateContextBuilder(projectPath: string): Promise<any | null> {
try {
const { isUnifiedEmbedderAvailable } = await import('../core/unified-vector-index.js');
if (!isUnifiedEmbedderAvailable()) {
return null;
}
const { UnifiedContextBuilder } = await import('../core/unified-context-builder.js');
return new UnifiedContextBuilder(projectPath);
} catch {
return null;
}
}
/** /**
* Session context action - provides progressive context loading * Session context action - provides progressive context loading
* First prompt: returns session overview with clusters *
* Subsequent prompts: returns intent-matched sessions * Uses UnifiedContextBuilder when available (embedder present):
* - session-start: MEMORY.md summary + clusters + hot entities + patterns
* - per-prompt: vector search across all memory categories
*
* Falls back to SessionClusteringService.getProgressiveIndex() when
* the embedder is unavailable, preserving backward compatibility.
*/ */
async function sessionContextAction(options: HookOptions): Promise<void> { async function sessionContextAction(options: HookOptions): Promise<void> {
let { stdin, sessionId, prompt } = options; let { stdin, sessionId, prompt } = options;
@@ -154,29 +176,43 @@ async function sessionContextAction(options: HookOptions): Promise<void> {
let contextType: 'session-start' | 'context'; let contextType: 'session-start' | 'context';
let content = ''; let content = '';
// Dynamic import to avoid circular dependencies // Try UnifiedContextBuilder first; fall back to getProgressiveIndex
const { SessionClusteringService } = await import('../core/session-clustering-service.js'); const contextBuilder = await tryCreateContextBuilder(projectPath);
const clusteringService = new SessionClusteringService(projectPath);
if (isFirstPrompt) { if (contextBuilder) {
// First prompt: return session overview with clusters // Use UnifiedContextBuilder
contextType = 'session-start'; if (isFirstPrompt) {
content = await clusteringService.getProgressiveIndex({ contextType = 'session-start';
type: 'session-start', content = await contextBuilder.buildSessionStartContext();
sessionId } else if (prompt && prompt.trim().length > 0) {
}); contextType = 'context';
} else if (prompt && prompt.trim().length > 0) { content = await contextBuilder.buildPromptContext(prompt);
// Subsequent prompts with content: return intent-matched sessions } else {
contextType = 'context'; contextType = 'context';
content = await clusteringService.getProgressiveIndex({ content = '';
type: 'context', }
sessionId,
prompt
});
} else { } else {
// Subsequent prompts without content: return minimal context // Fallback: use legacy SessionClusteringService.getProgressiveIndex()
contextType = 'context'; const { SessionClusteringService } = await import('../core/session-clustering-service.js');
content = ''; // No context needed for empty prompts const clusteringService = new SessionClusteringService(projectPath);
if (isFirstPrompt) {
contextType = 'session-start';
content = await clusteringService.getProgressiveIndex({
type: 'session-start',
sessionId
});
} else if (prompt && prompt.trim().length > 0) {
contextType = 'context';
content = await clusteringService.getProgressiveIndex({
type: 'context',
sessionId,
prompt
});
} else {
contextType = 'context';
content = '';
}
} }
if (stdin) { if (stdin) {
@@ -194,6 +230,7 @@ async function sessionContextAction(options: HookOptions): Promise<void> {
console.log(chalk.cyan('Type:'), contextType); console.log(chalk.cyan('Type:'), contextType);
console.log(chalk.cyan('First Prompt:'), isFirstPrompt ? 'Yes' : 'No'); console.log(chalk.cyan('First Prompt:'), isFirstPrompt ? 'Yes' : 'No');
console.log(chalk.cyan('Load Count:'), newState.loadCount); console.log(chalk.cyan('Load Count:'), newState.loadCount);
console.log(chalk.cyan('Builder:'), contextBuilder ? 'UnifiedContextBuilder' : 'Legacy (getProgressiveIndex)');
console.log(chalk.gray('─'.repeat(40))); console.log(chalk.gray('─'.repeat(40)));
if (content) { if (content) {
console.log(content); console.log(content);
@@ -210,6 +247,81 @@ async function sessionContextAction(options: HookOptions): Promise<void> {
} }
} }
/**
* Session end action - triggers async background tasks for memory maintenance.
*
* Tasks executed:
* 1. Incremental vector embedding (index new/updated content)
* 2. Incremental clustering (cluster unclustered sessions)
* 3. Heat score updates (recalculate entity heat scores)
*
* All tasks run best-effort; failures are logged but do not affect exit code.
*/
async function sessionEndAction(options: HookOptions): Promise<void> {
let { stdin, sessionId } = options;
let hookCwd: string | undefined;
if (stdin) {
try {
const stdinData = await readStdin();
if (stdinData) {
const hookData = JSON.parse(stdinData) as HookData;
sessionId = hookData.session_id || sessionId;
hookCwd = hookData.cwd;
}
} catch {
// Silently continue if stdin parsing fails
}
}
if (!sessionId) {
if (!stdin) {
console.error(chalk.red('Error: --session-id is required'));
}
process.exit(stdin ? 0 : 1);
}
try {
const projectPath = getProjectPath(hookCwd);
const contextBuilder = await tryCreateContextBuilder(projectPath);
if (!contextBuilder) {
// UnifiedContextBuilder not available - skip session-end tasks
if (!stdin) {
console.log(chalk.gray('(UnifiedContextBuilder not available, skipping session-end tasks)'));
}
process.exit(0);
}
const tasks: Array<{ name: string; execute: () => Promise<void> }> = contextBuilder.buildSessionEndTasks(sessionId);
if (!stdin) {
console.log(chalk.green(`Session End: executing ${tasks.length} background tasks...`));
}
// Execute all tasks concurrently (best-effort)
const results = await Promise.allSettled(
tasks.map((task: { name: string; execute: () => Promise<void> }) => task.execute())
);
if (!stdin) {
for (let i = 0; i < tasks.length; i++) {
const status = results[i].status === 'fulfilled' ? 'OK' : 'FAIL';
const color = status === 'OK' ? chalk.green : chalk.yellow;
console.log(color(` [${status}] ${tasks[i].name}`));
}
}
process.exit(0);
} catch (error) {
if (stdin) {
process.exit(0);
}
console.error(chalk.red(`Error: ${(error as Error).message}`));
process.exit(1);
}
}
/** /**
* Parse CCW status.json and output formatted status * Parse CCW status.json and output formatted status
*/ */
@@ -311,6 +423,7 @@ ${chalk.bold('USAGE')}
${chalk.bold('SUBCOMMANDS')} ${chalk.bold('SUBCOMMANDS')}
parse-status Parse CCW status.json and display current/next command parse-status Parse CCW status.json and display current/next command
session-context Progressive session context loading (replaces curl/bash hook) session-context Progressive session context loading (replaces curl/bash hook)
session-end Trigger background memory maintenance tasks
notify Send notification to ccw view dashboard notify Send notification to ccw view dashboard
${chalk.bold('OPTIONS')} ${chalk.bold('OPTIONS')}
@@ -363,6 +476,9 @@ export async function hookCommand(
case 'context': case 'context':
await sessionContextAction(options); await sessionContextAction(options);
break; break;
case 'session-end':
await sessionEndAction(options);
break;
case 'notify': case 'notify':
await notifyAction(options); await notifyAction(options);
break; break;

View File

@@ -0,0 +1,154 @@
// ========================================
// Remote Notification Configuration Manager
// ========================================
// Manages persistent storage of remote notification settings
// Storage: ~/.ccw/config/remote-notification.json
import { existsSync, mkdirSync, readFileSync, writeFileSync } from 'fs';
import { join } from 'path';
import { getCCWHome, ensureStorageDir } from './storage-paths.js';
import type {
RemoteNotificationConfig,
DEFAULT_REMOTE_NOTIFICATION_CONFIG,
} from '../types/remote-notification.js';
import { DeepPartial, deepMerge } from '../types/util.js';
/**
 * Absolute path of the remote-notification settings file.
 * Resolves to <ccw-home>/config/remote-notification.json.
 */
function getConfigFilePath(): string {
  const configDir = join(getCCWHome(), 'config');
  return join(configDir, 'remote-notification.json');
}
/**
 * Make sure the <ccw-home>/config directory exists before any write.
 */
function ensureConfigDir(): void {
  ensureStorageDir(join(getCCWHome(), 'config'));
}
/**
 * Default configuration factory.
 *
 * Returns a fresh object on every call, so callers may mutate or merge
 * into the result without affecting other consumers. Notifications are
 * globally disabled by default; only the ask-user-question and
 * task-failed events come pre-wired to Discord/Telegram.
 */
export function getDefaultConfig(): RemoteNotificationConfig {
  return {
    enabled: false, // master switch: off until the user opts in
    platforms: {}, // no platform credentials configured by default
    events: [
      { event: 'ask-user-question', platforms: ['discord', 'telegram'], enabled: true },
      { event: 'session-start', platforms: [], enabled: false },
      { event: 'session-end', platforms: [], enabled: false },
      { event: 'task-completed', platforms: [], enabled: false },
      { event: 'task-failed', platforms: ['discord', 'telegram'], enabled: true },
    ],
    timeout: 10000, // delivery timeout (presumably ms) — confirm against senders
  };
}
/**
 * Load the remote notification configuration from disk.
 *
 * Missing file -> defaults. Corrupt or unreadable file -> error logged,
 * defaults returned. Parsed content is deep-merged over the defaults so
 * fields added in newer versions always have a value.
 */
export function loadConfig(): RemoteNotificationConfig {
  const configPath = getConfigFilePath();
  if (!existsSync(configPath)) {
    return getDefaultConfig();
  }
  try {
    const parsed = JSON.parse(readFileSync(configPath, 'utf-8'));
    // Merge with defaults to ensure all fields exist
    return deepMerge(getDefaultConfig(), parsed);
  } catch (error) {
    console.error('[RemoteNotificationConfig] Failed to load config:', error);
    return getDefaultConfig();
  }
}
/**
 * Persist the configuration to the settings file.
 *
 * Creates the config directory if needed. Write failures are logged and
 * rethrown so the caller can surface them.
 */
export function saveConfig(config: RemoteNotificationConfig): void {
  ensureConfigDir();
  const target = getConfigFilePath();
  try {
    writeFileSync(target, JSON.stringify(config, null, 2), 'utf-8');
  } catch (error) {
    console.error('[RemoteNotificationConfig] Failed to save config:', error);
    throw error;
  }
}
/**
 * Apply a partial update on top of the stored configuration.
 * The merged result is persisted to disk and returned.
 */
export function updateConfig(
  updates: DeepPartial<RemoteNotificationConfig>
): RemoteNotificationConfig {
  const merged = deepMerge(loadConfig(), updates);
  saveConfig(merged);
  return merged;
}
/**
* Reset configuration to defaults
*/
export function resetConfig(): RemoteNotificationConfig {
const defaultConfig = getDefaultConfig();
saveConfig(defaultConfig);
return defaultConfig;
}
/**
 * Check whether notifications are globally enabled AND at least one
 * platform is both enabled and fully configured (credentials present).
 *
 * The result is coerced with Boolean(): without it, the optional-chained
 * `discord?.enabled && ...` terms evaluate to `undefined` when a platform
 * entry is absent (e.g. `platforms: {}`), and the `||` chain would return
 * `undefined` instead of the declared `boolean`.
 *
 * @param config - Configuration to inspect (not mutated).
 * @returns true when delivery is possible on at least one platform.
 */
export function hasEnabledPlatform(config: RemoteNotificationConfig): boolean {
  if (!config.enabled) return false;
  const { discord, telegram, webhook } = config.platforms;
  return Boolean(
    (discord?.enabled && discord.webhookUrl) ||
    (telegram?.enabled && telegram.botToken && telegram.chatId) ||
    (webhook?.enabled && webhook.url)
  );
}
/**
 * Resolve which platforms should receive a given event.
 *
 * Returns an empty list when notifications are globally disabled, the
 * event is unknown, or the event itself is disabled. Otherwise filters
 * the event's platform list down to platforms that are enabled and have
 * complete credentials.
 *
 * @param config - Current notification configuration.
 * @param eventType - Event identifier (e.g. 'task-failed').
 * @returns Platform names eligible for delivery, in the event's order.
 */
export function getEnabledPlatformsForEvent(
  config: RemoteNotificationConfig,
  eventType: string
): string[] {
  if (!config.enabled) return [];
  const eventConfig = config.events.find((e) => e.event === eventType);
  if (!eventConfig || !eventConfig.enabled) return [];

  // A platform is "ready" when it is enabled and its required credentials exist.
  const isReady = (platform: string): boolean => {
    const platformConfig = config.platforms[platform as keyof typeof config.platforms];
    if (!platformConfig) return false;
    if (platform === 'discord') {
      const p = platformConfig as { enabled: boolean; webhookUrl?: string };
      return p.enabled && !!p.webhookUrl;
    }
    if (platform === 'telegram') {
      const p = platformConfig as { enabled: boolean; botToken?: string; chatId?: string };
      return p.enabled && !!p.botToken && !!p.chatId;
    }
    if (platform === 'webhook') {
      const p = platformConfig as { enabled: boolean; url?: string };
      return p.enabled && !!p.url;
    }
    return false;
  };

  return eventConfig.platforms.filter(isReady);
}

View File

@@ -388,6 +388,15 @@ export interface ProjectPaths {
/** Skills directory */ /** Skills directory */
skills: string; skills: string;
}; };
/** Unified vector index paths (HNSW-backed) */
unifiedVectors: {
/** Root: <projectRoot>/unified-vectors/ */
root: string;
/** SQLite database for vector metadata */
vectorsDb: string;
/** HNSW index file */
hnswIndex: string;
};
} }
/** /**
@@ -454,6 +463,11 @@ export function getProjectPaths(projectPath: string): ProjectPaths {
memoryMd: join(projectDir, 'core-memory', 'v2', 'MEMORY.md'), memoryMd: join(projectDir, 'core-memory', 'v2', 'MEMORY.md'),
skills: join(projectDir, 'core-memory', 'v2', 'skills'), skills: join(projectDir, 'core-memory', 'v2', 'skills'),
}, },
unifiedVectors: {
root: join(projectDir, 'unified-vectors'),
vectorsDb: join(projectDir, 'unified-vectors', 'vectors.db'),
hnswIndex: join(projectDir, 'unified-vectors', 'vectors.hnsw'),
},
}; };
} }
@@ -483,6 +497,11 @@ export function getProjectPathsById(projectId: string): ProjectPaths {
memoryMd: join(projectDir, 'core-memory', 'v2', 'MEMORY.md'), memoryMd: join(projectDir, 'core-memory', 'v2', 'MEMORY.md'),
skills: join(projectDir, 'core-memory', 'v2', 'skills'), skills: join(projectDir, 'core-memory', 'v2', 'skills'),
}, },
unifiedVectors: {
root: join(projectDir, 'unified-vectors'),
vectorsDb: join(projectDir, 'unified-vectors', 'vectors.db'),
hnswIndex: join(projectDir, 'unified-vectors', 'vectors.hnsw'),
},
}; };
} }

View File

@@ -7,6 +7,8 @@ import Database from 'better-sqlite3';
import { existsSync, mkdirSync } from 'fs'; import { existsSync, mkdirSync } from 'fs';
import { join } from 'path'; import { join } from 'path';
import { StoragePaths, ensureStorageDir } from '../config/storage-paths.js'; import { StoragePaths, ensureStorageDir } from '../config/storage-paths.js';
import { UnifiedVectorIndex, isUnifiedEmbedderAvailable } from './unified-vector-index.js';
import type { ChunkMetadata } from './unified-vector-index.js';
// Types // Types
export interface CoreMemory { export interface CoreMemory {
@@ -101,6 +103,7 @@ export class CoreMemoryStore {
private db: Database.Database; private db: Database.Database;
private dbPath: string; private dbPath: string;
private projectPath: string; private projectPath: string;
private vectorIndex: UnifiedVectorIndex | null = null;
constructor(projectPath: string) { constructor(projectPath: string) {
this.projectPath = projectPath; this.projectPath = projectPath;
@@ -328,6 +331,38 @@ export class CoreMemoryStore {
return this.db; return this.db;
} }
/**
* Get or create the UnifiedVectorIndex instance (lazy initialization).
* Returns null if the embedder is not available.
*/
private getVectorIndex(): UnifiedVectorIndex | null {
if (this.vectorIndex) return this.vectorIndex;
if (!isUnifiedEmbedderAvailable()) return null;
this.vectorIndex = new UnifiedVectorIndex(this.projectPath);
return this.vectorIndex;
}
/**
* Fire-and-forget: sync content to the vector index.
* Logs errors but never throws, to avoid disrupting the synchronous write path.
*/
private syncToVectorIndex(content: string, sourceId: string): void {
const idx = this.getVectorIndex();
if (!idx) return;
const metadata: ChunkMetadata = {
source_id: sourceId,
source_type: 'core_memory',
category: 'core_memory',
};
idx.indexContent(content, metadata).catch((err) => {
if (process.env.DEBUG) {
console.error(`[CoreMemoryStore] Vector index sync failed for ${sourceId}:`, (err as Error).message);
}
});
}
/** /**
* Generate timestamp-based ID for core memory * Generate timestamp-based ID for core memory
*/ */
@@ -387,6 +422,9 @@ export class CoreMemoryStore {
id id
); );
// Sync updated content to vector index
this.syncToVectorIndex(memory.content, id);
return this.getMemory(id)!; return this.getMemory(id)!;
} else { } else {
// Insert new memory // Insert new memory
@@ -406,6 +444,9 @@ export class CoreMemoryStore {
memory.metadata || null memory.metadata || null
); );
// Sync new content to vector index
this.syncToVectorIndex(memory.content, id);
return this.getMemory(id)!; return this.getMemory(id)!;
} }
} }

View File

@@ -13,6 +13,10 @@ import type { ConversationRecord } from '../tools/cli-history-store.js';
import { getHistoryStore } from '../tools/cli-history-store.js'; import { getHistoryStore } from '../tools/cli-history-store.js';
import { getCoreMemoryStore, type Stage1Output } from './core-memory-store.js'; import { getCoreMemoryStore, type Stage1Output } from './core-memory-store.js';
import { MemoryJobScheduler } from './memory-job-scheduler.js'; import { MemoryJobScheduler } from './memory-job-scheduler.js';
import { UnifiedVectorIndex, isUnifiedEmbedderAvailable } from './unified-vector-index.js';
import type { ChunkMetadata } from './unified-vector-index.js';
import { SessionClusteringService } from './session-clustering-service.js';
import { PatternDetector } from './pattern-detector.js';
import { import {
MAX_SESSION_AGE_DAYS, MAX_SESSION_AGE_DAYS,
MIN_IDLE_HOURS, MIN_IDLE_HOURS,
@@ -384,9 +388,38 @@ export class MemoryExtractionPipeline {
const store = getCoreMemoryStore(this.projectPath); const store = getCoreMemoryStore(this.projectPath);
store.upsertStage1Output(output); store.upsertStage1Output(output);
// Sync extracted content to vector index (fire-and-forget)
this.syncExtractionToVectorIndex(output);
return output; return output;
} }
/**
* Sync extraction output to the vector index.
* Indexes both raw_memory and rollout_summary with category='cli_history'.
* Fire-and-forget: errors are logged but never thrown.
*/
private syncExtractionToVectorIndex(output: Stage1Output): void {
if (!isUnifiedEmbedderAvailable()) return;
const vectorIndex = new UnifiedVectorIndex(this.projectPath);
const combinedContent = `${output.raw_memory}\n\n---\n\n${output.rollout_summary}`;
const metadata: ChunkMetadata = {
source_id: output.thread_id,
source_type: 'cli_history',
category: 'cli_history',
};
vectorIndex.indexContent(combinedContent, metadata).catch((err) => {
if (process.env.DEBUG) {
console.error(
`[MemoryExtractionPipeline] Vector index sync failed for ${output.thread_id}:`,
(err as Error).message
);
}
});
}
// ======================================================================== // ========================================================================
// Batch orchestration // Batch orchestration
// ======================================================================== // ========================================================================
@@ -461,6 +494,76 @@ export class MemoryExtractionPipeline {
await Promise.all(promises); await Promise.all(promises);
} }
// Post-extraction: trigger incremental clustering and pattern detection
// These are fire-and-forget to avoid blocking the main extraction flow.
if (result.succeeded > 0) {
this.triggerPostExtractionHooks(
eligibleSessions.filter((_, i) => i < result.processed).map(s => s.id)
);
}
return result; return result;
} }
/**
* Fire-and-forget: trigger incremental clustering and pattern detection
* after Phase 1 extraction completes.
*
* - incrementalCluster: processes each newly extracted session
* - detectPatterns: runs pattern detection across all chunks
*
* Errors are logged but never thrown, to avoid disrupting the caller.
*/
private triggerPostExtractionHooks(extractedSessionIds: string[]): void {
const clusteringService = new SessionClusteringService(this.projectPath);
const patternDetector = new PatternDetector(this.projectPath);
// Incremental clustering for each extracted session (fire-and-forget)
(async () => {
try {
// Check frequency control before running clustering
const shouldCluster = await clusteringService.shouldRunClustering();
if (!shouldCluster) {
if (process.env.DEBUG) {
console.log('[PostExtraction] Clustering skipped: frequency control not met');
}
return;
}
for (const sessionId of extractedSessionIds) {
try {
await clusteringService.incrementalCluster(sessionId);
} catch (err) {
if (process.env.DEBUG) {
console.warn(
`[PostExtraction] Incremental clustering failed for ${sessionId}:`,
(err as Error).message
);
}
}
}
} catch (err) {
if (process.env.DEBUG) {
console.warn('[PostExtraction] Clustering hook failed:', (err as Error).message);
}
}
})();
// Pattern detection (fire-and-forget)
(async () => {
try {
const result = await patternDetector.detectPatterns();
if (result.patterns.length > 0) {
console.log(
`[PostExtraction] Pattern detection: ${result.patterns.length} patterns found, ` +
`${result.solidified.length} solidified (${result.elapsedMs}ms)`
);
}
} catch (err) {
if (process.env.DEBUG) {
console.warn('[PostExtraction] Pattern detection failed:', (err as Error).message);
}
}
})();
}
} }

View File

@@ -0,0 +1,485 @@
/**
* Pattern Detector - Detects recurring content patterns across sessions
*
* Uses vector clustering (cosine similarity > 0.85) to group semantically similar
* chunks into patterns. Patterns appearing in N>=3 distinct sessions are flagged
* as candidates. High-confidence patterns (>=0.8) are solidified into CoreMemory
* and skills/*.md files.
*/
import { CoreMemoryStore, getCoreMemoryStore } from './core-memory-store.js';
import { UnifiedVectorIndex, isUnifiedEmbedderAvailable } from './unified-vector-index.js';
import type { VectorSearchMatch } from './unified-vector-index.js';
import { existsSync, mkdirSync, writeFileSync } from 'fs';
import { join } from 'path';
// -- Constants --

/** Minimum cosine similarity to group chunks into the same pattern cluster */
const PATTERN_SIMILARITY_THRESHOLD = 0.85;

/** Minimum number of distinct sessions a pattern must appear in to be a candidate */
const MIN_SESSION_FREQUENCY = 3;

/** Confidence threshold (0-1) at or above which a pattern is auto-solidified */
const SOLIDIFY_CONFIDENCE_THRESHOLD = 0.8;

/** Maximum number of chunks to analyze per detection run (caps cost) */
const MAX_CHUNKS_TO_ANALYZE = 200;

/** Top-K neighbors to search per chunk during clustering */
const NEIGHBOR_TOP_K = 15;

// -- Types --

export interface DetectedPattern {
  /** Unique pattern identifier (PAT-<timestamp>-<random suffix>) */
  id: string;
  /** Human-readable pattern name derived from content keywords */
  name: string;
  /** Representative content snippet (truncated to 500 chars) */
  representative: string;
  /** Source IDs (sessions) where this pattern appears */
  sourceIds: string[];
  /** Number of distinct sessions (length of sourceIds) */
  sessionCount: number;
  /** Average similarity score within the pattern group, rounded to 3 decimals */
  avgSimilarity: number;
  /** Confidence score (0-1), weighted blend of avg similarity and frequency */
  confidence: number;
  /** Category of the chunks in this pattern ('unknown' when absent) */
  category: string;
}

export interface PatternDetectionResult {
  /** All detected patterns, sorted by confidence descending */
  patterns: DetectedPattern[];
  /** Number of chunks analyzed in this run */
  chunksAnalyzed: number;
  /** IDs of patterns that were solidified (written to CoreMemory + skills) */
  solidified: string[];
  /** Elapsed time in ms */
  elapsedMs: number;
}

export interface SolidifyResult {
  // ID of the CoreMemory entry created for the pattern
  memoryId: string;
  // Path of the written skills/*.md file, or null if none was written
  skillPath: string | null;
}
// -- PatternDetector --
export class PatternDetector {
private projectPath: string;
private coreMemoryStore: CoreMemoryStore;
private vectorIndex: UnifiedVectorIndex | null = null;
constructor(projectPath: string) {
this.projectPath = projectPath;
this.coreMemoryStore = getCoreMemoryStore(projectPath);
if (isUnifiedEmbedderAvailable()) {
this.vectorIndex = new UnifiedVectorIndex(projectPath);
}
}
  /**
   * Detect recurring patterns across sessions by vector clustering.
   *
   * Algorithm:
   * 1. Get representative chunks from VectorStore (via search with broad queries)
   * 2. For each chunk, search HNSW for nearest neighbors (cosine > PATTERN_SIMILARITY_THRESHOLD)
   * 3. Group chunks with high mutual similarity into pattern clusters
   * 4. Count distinct source_ids per cluster (session frequency)
   * 5. Patterns with sessionCount >= MIN_SESSION_FREQUENCY become candidates
   *
   * Candidates at or above SOLIDIFY_CONFIDENCE_THRESHOLD are solidified
   * before returning; solidification failures are logged and skipped.
   *
   * @returns Detection result with candidate patterns, sorted by
   *          confidence descending, plus bookkeeping counters.
   */
  async detectPatterns(): Promise<PatternDetectionResult> {
    const startTime = Date.now();
    const result: PatternDetectionResult = {
      patterns: [],
      chunksAnalyzed: 0,
      solidified: [],
      elapsedMs: 0,
    };
    // No embedder -> no vector index -> nothing to analyze.
    if (!this.vectorIndex) {
      result.elapsedMs = Date.now() - startTime;
      return result;
    }
    // Step 1: Gather chunks from the vector store via broad category searches
    const allChunks = await this.gatherChunksForAnalysis();
    result.chunksAnalyzed = allChunks.length;
    // Fewer chunks than the minimum session frequency can never form a candidate.
    if (allChunks.length < MIN_SESSION_FREQUENCY) {
      result.elapsedMs = Date.now() - startTime;
      return result;
    }
    // Step 2: Cluster chunks by vector similarity
    const patternGroups = await this.clusterChunksByVector(allChunks);
    // Step 3: Filter by session frequency and build DetectedPattern objects
    for (const group of patternGroups) {
      const uniqueSources = new Set(group.map(c => c.source_id));
      if (uniqueSources.size < MIN_SESSION_FREQUENCY) continue;
      const avgSim = group.reduce((sum, c) => sum + c.score, 0) / group.length;
      // Confidence: combines frequency (normalized, saturating at 10 sessions)
      // and avg similarity, weighted 40/60.
      const frequencyScore = Math.min(uniqueSources.size / 10, 1.0);
      const confidence = avgSim * 0.6 + frequencyScore * 0.4;
      const representative = group[0]; // Highest scoring chunk (group seed)
      const patternName = this.derivePatternName(group);
      const patternId = `PAT-${Date.now()}-${Math.random().toString(36).substring(2, 6)}`;
      result.patterns.push({
        id: patternId,
        name: patternName,
        representative: representative.content.substring(0, 500),
        sourceIds: Array.from(uniqueSources),
        sessionCount: uniqueSources.size,
        avgSimilarity: Math.round(avgSim * 1000) / 1000, // 3 decimal places
        confidence: Math.round(confidence * 1000) / 1000, // 3 decimal places
        category: representative.category || 'unknown',
      });
    }
    // Sort by confidence descending
    result.patterns.sort((a, b) => b.confidence - a.confidence);
    // Step 4: Auto-solidify high-confidence patterns. Each solidification is
    // awaited sequentially (not fire-and-forget); individual failures are
    // logged and do not abort the run.
    for (const pattern of result.patterns) {
      if (pattern.confidence >= SOLIDIFY_CONFIDENCE_THRESHOLD) {
        try {
          await this.solidifyPattern(pattern);
          result.solidified.push(pattern.id);
        } catch (err) {
          console.warn(
            `[PatternDetector] Failed to solidify pattern ${pattern.id}:`,
            (err as Error).message
          );
        }
      }
    }
    result.elapsedMs = Date.now() - startTime;
    return result;
  }
  /**
   * Gather a representative set of chunks for pattern analysis.
   *
   * Runs each broad query against each category, deduplicating by the
   * first 100 chars of content, until MAX_CHUNKS_TO_ANALYZE is reached.
   * Individual search failures are swallowed (best-effort collection).
   *
   * @returns Deduplicated chunks (at most MAX_CHUNKS_TO_ANALYZE); empty
   *          when the vector index is unavailable.
   */
  private async gatherChunksForAnalysis(): Promise<VectorSearchMatch[]> {
    if (!this.vectorIndex) return [];
    const allChunks: VectorSearchMatch[] = [];
    const seenContent = new Set<string>();
    // Search across common categories with broad queries
    const broadQueries = [
      'implementation pattern',
      'configuration setup',
      'error handling',
      'testing approach',
      'workflow process',
    ];
    const categories = ['core_memory', 'cli_history', 'workflow'] as const;
    for (const category of categories) {
      for (const query of broadQueries) {
        // NOTE: this break only exits the inner query loop; remaining
        // categories are still visited but break immediately once full.
        if (allChunks.length >= MAX_CHUNKS_TO_ANALYZE) break;
        try {
          // topK split evenly across all query/category combinations
          const result = await this.vectorIndex.search(query, {
            topK: Math.ceil(MAX_CHUNKS_TO_ANALYZE / (broadQueries.length * categories.length)),
            minScore: 0.1,
            category,
          });
          if (result.success) {
            for (const match of result.matches) {
              // Deduplicate by content prefix (first 100 chars)
              const contentKey = match.content.substring(0, 100);
              if (!seenContent.has(contentKey)) {
                seenContent.add(contentKey);
                allChunks.push(match);
              }
              if (allChunks.length >= MAX_CHUNKS_TO_ANALYZE) break;
            }
          }
        } catch {
          // Search failed for this query/category, continue
        }
      }
    }
    return allChunks;
  }
  /**
   * Cluster chunks by vector similarity using HNSW neighbor search.
   *
   * Greedy single-pass grouping: each unprocessed chunk seeds a group,
   * its HNSW neighbors above PATTERN_SIMILARITY_THRESHOLD are pulled in
   * and marked processed so they cannot seed or join another group.
   * (Visited-set tracking, not a full union-find.) Neighbors returned by
   * the index that are not in `chunks` are also added to the group when
   * they come from a different source.
   *
   * @param chunks - Candidate chunks from gatherChunksForAnalysis().
   * @returns Groups containing chunks from at least 2 distinct sources.
   */
  private async clusterChunksByVector(
    chunks: VectorSearchMatch[]
  ): Promise<VectorSearchMatch[][]> {
    if (!this.vectorIndex) return [];
    const groups: VectorSearchMatch[][] = [];
    const processed = new Set<number>();
    for (let i = 0; i < chunks.length; i++) {
      if (processed.has(i)) continue;
      const seedChunk = chunks[i];
      const group: VectorSearchMatch[] = [seedChunk];
      processed.add(i);
      // Search for neighbors of this chunk's content
      try {
        const neighbors = await this.vectorIndex.search(seedChunk.content, {
          topK: NEIGHBOR_TOP_K,
          minScore: PATTERN_SIMILARITY_THRESHOLD,
        });
        if (neighbors.success) {
          for (const neighbor of neighbors.matches) {
            // Skip self-matches (identical content to the seed)
            if (neighbor.content === seedChunk.content) continue;
            // Find this neighbor in our chunk list; adopt the neighbor's
            // similarity score so avgSimilarity reflects closeness to the seed
            for (let j = 0; j < chunks.length; j++) {
              if (processed.has(j)) continue;
              if (
                chunks[j].source_id === neighbor.source_id &&
                chunks[j].chunk_index === neighbor.chunk_index
              ) {
                group.push({ ...chunks[j], score: neighbor.score });
                processed.add(j);
                break;
              }
            }
            // Also include neighbors not in our original list
            if (neighbor.source_id && neighbor.source_id !== seedChunk.source_id) {
              // Check if already in group by source_id + chunk_index
              const alreadyInGroup = group.some(
                g => g.source_id === neighbor.source_id && g.chunk_index === neighbor.chunk_index
              );
              if (!alreadyInGroup) {
                group.push(neighbor);
              }
            }
          }
        }
      } catch {
        // HNSW search failed, skip this chunk's neighborhood
      }
      // Only keep groups with chunks from multiple sources
      const uniqueSources = new Set(group.map(c => c.source_id));
      if (uniqueSources.size >= 2) {
        groups.push(group);
      }
    }
    return groups;
  }
/**
 * Derive a human-readable pattern name from a group of similar chunks.
 *
 * Strategy: count significant words (deduplicated per chunk) and word bigrams
 * (per occurrence) across the group, prefer a bigram that recurs in at least
 * two places, optionally extended by one more frequent word, and fall back to
 * the top 1-3 single words. Names are capped at 50 characters.
 *
 * @param group - Chunks belonging to one detected cluster
 * @returns A lowercase, hyphen-joined name, or 'unnamed-pattern' if nothing
 *          significant was found
 */
private derivePatternName(group: VectorSearchMatch[]): string {
  // Extended stopwords including generic tech terms
  const stopwords = new Set([
    'the', 'and', 'for', 'that', 'this', 'with', 'from', 'have', 'will',
    'are', 'was', 'were', 'been', 'what', 'when', 'where', 'which',
    'there', 'their', 'they', 'them', 'then', 'than', 'into', 'some',
    'code', 'file', 'function', 'class', 'import', 'export', 'const',
    'async', 'await', 'return', 'type', 'interface', 'string', 'number',
    'true', 'false', 'null', 'undefined', 'object', 'array', 'value',
    'data', 'result', 'error', 'name', 'path', 'index', 'item', 'list',
    'should', 'would', 'could', 'does', 'make', 'like', 'just', 'also',
    'used', 'using', 'each', 'other', 'more', 'only', 'need', 'very',
  ]);
  const significant = (word: string): boolean => word.length >= 4 && !stopwords.has(word);
  const bump = (counts: Map<string, number>, key: string): void => {
    counts.set(key, (counts.get(key) ?? 0) + 1);
  };
  // Word counts are per-chunk-unique; bigram counts are per occurrence.
  const wordCounts = new Map<string, number>();
  const pairCounts = new Map<string, number>();
  for (const { content } of group) {
    const tokens = content.toLowerCase().split(/[\s\W]+/).filter(significant);
    for (const token of new Set(tokens)) {
      bump(wordCounts, token);
    }
    tokens.forEach((token, idx) => {
      if (idx + 1 < tokens.length) {
        bump(pairCounts, `${token}-${tokens[idx + 1]}`);
      }
    });
  }
  const byCountDesc = (a: [string, number], b: [string, number]) => b[1] - a[1];
  // Prefer a bigram that shows up at least twice across the group.
  const recurringPairs = [...pairCounts.entries()]
    .filter(([, count]) => count >= 2)
    .sort(byCountDesc);
  if (recurringPairs.length > 0) {
    const base = recurringPairs[0][0];
    const coveredWords = new Set(base.split('-'));
    // Optionally append one more frequent word not already in the bigram.
    const extras = [...wordCounts.entries()]
      .filter(([word, count]) => count >= 2 && !coveredWords.has(word))
      .sort(byCountDesc);
    if (extras.length > 0) {
      const extended = `${base}-${extras[0][0]}`;
      return extended.length <= 50 ? extended : base;
    }
    return base;
  }
  // Fallback: join the most frequent single words.
  const leadingWords = [...wordCounts.entries()]
    .sort(byCountDesc)
    .slice(0, 3)
    .map(([word]) => word);
  if (leadingWords.length >= 2) {
    const joined = leadingWords.join('-');
    return joined.length <= 50 ? joined : leadingWords.slice(0, 2).join('-');
  }
  if (leadingWords.length === 1) {
    return leadingWords[0];
  }
  return 'unnamed-pattern';
}
/**
 * Solidify a detected pattern by writing it to CoreMemory and skills/*.md.
 *
 * Creates:
 * 1. A CoreMemory entry with the pattern content and metadata
 * 2. A skills/{pattern_slug}.md file with the pattern documentation
 *
 * This method is fire-and-forget - errors are logged but not propagated.
 * The memory upsert always happens first; a failure writing the skill file
 * does not undo it (skillPath is simply reported as null).
 *
 * @param pattern - The detected pattern to solidify
 * @returns Result with memory ID and skill file path (null if the file write failed)
 */
async solidifyPattern(pattern: DetectedPattern): Promise<SolidifyResult> {
  // 1. Create CoreMemory entry
  const memoryContent = this.buildPatternMemoryContent(pattern);
  const memory = this.coreMemoryStore.upsertMemory({
    content: memoryContent,
    summary: `Detected pattern: ${pattern.name} (${pattern.sessionCount} sessions, confidence: ${pattern.confidence})`,
    metadata: JSON.stringify({
      type: 'detected_pattern',
      pattern_id: pattern.id,
      pattern_name: pattern.name,
      session_count: pattern.sessionCount,
      confidence: pattern.confidence,
      source_ids: pattern.sourceIds,
      detected_at: new Date().toISOString(),
    }),
  });
  // 2. Write skills file
  let skillPath: string | null = null;
  try {
    // Slug: lowercase, non-alphanumerics collapsed to '-', edge dashes trimmed,
    // capped at 50 chars to keep file names filesystem-friendly.
    const slug = pattern.name
      .toLowerCase()
      .replace(/[^a-z0-9]+/g, '-')
      .replace(/^-|-$/g, '')
      .substring(0, 50);
    const skillsDir = join(this.projectPath, '.claude', 'skills');
    if (!existsSync(skillsDir)) {
      mkdirSync(skillsDir, { recursive: true });
    }
    skillPath = join(skillsDir, `${slug}.md`);
    const skillContent = this.buildSkillContent(pattern);
    writeFileSync(skillPath, skillContent, 'utf-8');
  } catch (err) {
    // Best-effort: the skill file is a convenience artifact, so log and continue.
    console.warn(
      `[PatternDetector] Failed to write skill file for ${pattern.name}:`,
      (err as Error).message
    );
    skillPath = null;
  }
  console.log(
    `[PatternDetector] Solidified pattern '${pattern.name}' -> memory=${memory.id}, skill=${skillPath || 'none'}`
  );
  return { memoryId: memory.id, skillPath };
}
/**
 * Build CoreMemory content (markdown) for a detected pattern.
 *
 * @param pattern - The pattern to render
 * @returns Markdown document with header metadata, representative content,
 *          and a short usage note
 */
private buildPatternMemoryContent(pattern: DetectedPattern): string {
  const sessionList = pattern.sourceIds.join(', ');
  return (
    `# Detected Pattern: ${pattern.name}\n\n` +
    `**Confidence**: ${pattern.confidence}\n` +
    `**Sessions**: ${pattern.sessionCount} (${sessionList})\n` +
    `**Category**: ${pattern.category}\n` +
    `**Avg Similarity**: ${pattern.avgSimilarity}\n\n` +
    '## Representative Content\n\n' +
    `${pattern.representative}\n\n` +
    '## Usage\n\n' +
    'This pattern was automatically detected across multiple sessions.\n' +
    'It represents a recurring approach or concept in this project.'
  );
}
/**
 * Build skills/*.md file content (markdown) for a detected pattern.
 *
 * @param pattern - The pattern to render
 * @returns Markdown document with description, originating sessions, and an
 *          auto-generation footer
 */
private buildSkillContent(pattern: DetectedPattern): string {
  const out: string[] = [];
  out.push(`# ${pattern.name}`, '');
  out.push(
    `> Auto-detected pattern (confidence: ${pattern.confidence}, sessions: ${pattern.sessionCount})`,
    ''
  );
  out.push('## Description', '', pattern.representative, '');
  out.push('## Context', '', `This pattern was detected across ${pattern.sessionCount} sessions:`);
  for (const sourceId of pattern.sourceIds) {
    out.push(`- ${sourceId}`);
  }
  out.push('', '## When to Apply', '');
  out.push(
    'Apply this pattern when working on similar tasks or encountering related concepts.',
    ''
  );
  out.push('---', `*Auto-generated by PatternDetector on ${new Date().toISOString()}*`);
  return out.join('\n');
}
}

View File

@@ -0,0 +1,357 @@
// ========================================
// Remote Notification Routes
// ========================================
// API endpoints for remote notification configuration
import type { IncomingMessage, ServerResponse } from 'http';
import { URL } from 'url';
import {
loadConfig,
saveConfig,
resetConfig,
} from '../../config/remote-notification-config.js';
import {
remoteNotificationService,
} from '../../services/remote-notification-service.js';
import {
maskSensitiveConfig,
type RemoteNotificationConfig,
type TestNotificationRequest,
type NotificationPlatform,
type DiscordConfig,
type TelegramConfig,
type WebhookConfig,
} from '../../types/remote-notification.js';
import { deepMerge } from '../../types/util.js';
// ========== Input Validation ==========
/**
 * Validate URL format (must parse as a WHATWG URL with an http or https scheme).
 *
 * @param url - Candidate URL string
 * @returns true when the string is a well-formed http(s) URL
 */
function isValidUrl(url: string): boolean {
  try {
    const { protocol } = new URL(url);
    return protocol === 'http:' || protocol === 'https:';
  } catch {
    // new URL throws on malformed input.
    return false;
  }
}
/**
 * Validate Discord webhook URL format.
 *
 * Accepts http(s) URLs on discord.com or discordapp.com whose path begins
 * with /api/webhooks/ (i.e. discord.com/api/webhooks/{id}/{token}).
 *
 * @param url - Candidate webhook URL
 * @returns true when the URL looks like a Discord webhook endpoint
 */
function isValidDiscordWebhookUrl(url: string): boolean {
  let parsed: URL;
  try {
    parsed = new URL(url);
  } catch {
    return false;
  }
  if (parsed.protocol !== 'http:' && parsed.protocol !== 'https:') {
    return false;
  }
  const isDiscordHost =
    parsed.hostname === 'discord.com' || parsed.hostname === 'discordapp.com';
  return isDiscordHost && parsed.pathname.startsWith('/api/webhooks/');
}
/**
 * Validate Telegram bot token format.
 *
 * Telegram tokens look like {bot_id}:{secret}, where the bot id is 8-15
 * digits and the secret is 30-50 characters of [A-Za-z0-9_-].
 *
 * @param token - Candidate bot token
 * @returns true when the token matches the expected shape
 */
function isValidTelegramBotToken(token: string): boolean {
  const TOKEN_PATTERN = /^\d{8,15}:[A-Za-z0-9_-]{30,50}$/;
  return TOKEN_PATTERN.test(token);
}
/**
 * Validate Telegram chat ID format.
 *
 * Chat IDs are 1-20 digits, with an optional leading '-' (group chats are
 * negative).
 *
 * @param chatId - Candidate chat ID string
 * @returns true when the string is a plausible chat ID
 */
function isValidTelegramChatId(chatId: string): boolean {
  const CHAT_ID_PATTERN = /^-?\d{1,20}$/;
  return CHAT_ID_PATTERN.test(chatId);
}
/**
 * Validate webhook headers.
 *
 * Accepts undefined/null (the field is optional). Otherwise the value must be
 * a plain object of non-empty string keys to string values, at most 10KB when
 * serialized, and must not set connection-level headers the client controls
 * (host, content-length, connection).
 *
 * @param headers - Raw value supplied by the client
 * @returns { valid: true } or { valid: false, error } with a user-facing message
 */
function isValidHeaders(headers: unknown): { valid: boolean; error?: string } {
  // Optional field: absent is fine.
  if (headers == null) {
    return { valid: true };
  }
  if (typeof headers !== 'object' || Array.isArray(headers)) {
    return { valid: false, error: 'Headers must be an object' };
  }
  // Check for reasonable size limit (10KB).
  if (JSON.stringify(headers).length > 10240) {
    return { valid: false, error: 'Headers too large (max 10KB)' };
  }
  const blocked = ['host', 'content-length', 'connection'];
  for (const [name, value] of Object.entries(headers as Record<string, unknown>)) {
    if (typeof name !== 'string' || name.length === 0) {
      return { valid: false, error: 'Header keys must be non-empty strings' };
    }
    if (typeof value !== 'string') {
      return { valid: false, error: `Header '${name}' value must be a string` };
    }
    // Block potentially dangerous headers.
    if (blocked.includes(name.toLowerCase())) {
      return { valid: false, error: `Header '${name}' is not allowed` };
    }
  }
  return { valid: true };
}
/**
 * Validate configuration updates
 *
 * Checks run in order, so the first failing check determines the error:
 * Discord -> Telegram -> Webhook -> top-level timeout. Empty strings are
 * treated as "unset" and skipped for URL/token fields, allowing clients to
 * clear a value without tripping format validation.
 *
 * @param updates - Partial config supplied by the client (already JSON-parsed)
 * @returns { valid: true } or { valid: false, error } with a user-facing message
 */
function validateConfigUpdates(updates: Partial<RemoteNotificationConfig>): { valid: boolean; error?: string } {
  // Validate platforms if present
  if (updates.platforms) {
    const { discord, telegram, webhook } = updates.platforms;
    // Validate Discord config
    if (discord) {
      if (discord.webhookUrl !== undefined && discord.webhookUrl !== '') {
        if (!isValidUrl(discord.webhookUrl)) {
          return { valid: false, error: 'Invalid Discord webhook URL format' };
        }
        // Warning: we allow non-Discord URLs for flexibility, but log it
        if (!isValidDiscordWebhookUrl(discord.webhookUrl)) {
          console.warn('[RemoteNotification] Webhook URL does not match Discord format');
        }
      }
      // 80 chars is Discord's own username limit.
      if (discord.username !== undefined && discord.username.length > 80) {
        return { valid: false, error: 'Discord username too long (max 80 chars)' };
      }
    }
    // Validate Telegram config
    if (telegram) {
      if (telegram.botToken !== undefined && telegram.botToken !== '') {
        if (!isValidTelegramBotToken(telegram.botToken)) {
          return { valid: false, error: 'Invalid Telegram bot token format' };
        }
      }
      if (telegram.chatId !== undefined && telegram.chatId !== '') {
        if (!isValidTelegramChatId(telegram.chatId)) {
          return { valid: false, error: 'Invalid Telegram chat ID format' };
        }
      }
    }
    // Validate Webhook config
    if (webhook) {
      if (webhook.url !== undefined && webhook.url !== '') {
        if (!isValidUrl(webhook.url)) {
          return { valid: false, error: 'Invalid webhook URL format' };
        }
      }
      if (webhook.headers !== undefined) {
        const headerValidation = isValidHeaders(webhook.headers);
        if (!headerValidation.valid) {
          return { valid: false, error: headerValidation.error };
        }
      }
      // Per-platform timeout bounds match the top-level bounds below.
      if (webhook.timeout !== undefined && (webhook.timeout < 1000 || webhook.timeout > 60000)) {
        return { valid: false, error: 'Webhook timeout must be between 1000ms and 60000ms' };
      }
    }
  }
  // Validate timeout
  if (updates.timeout !== undefined && (updates.timeout < 1000 || updates.timeout > 60000)) {
    return { valid: false, error: 'Timeout must be between 1000ms and 60000ms' };
  }
  return { valid: true };
}
/**
 * Validate a test-notification request.
 *
 * Verifies the platform name, presence of a config payload, and the
 * platform-specific required fields (webhook URL / bot token + chat ID /
 * target URL and optional headers).
 *
 * @param request - Parsed request body
 * @returns { valid: true } or { valid: false, error } with a user-facing message
 */
function validateTestRequest(request: TestNotificationRequest): { valid: boolean; error?: string } {
  if (!request.platform) {
    return { valid: false, error: 'Missing platform' };
  }
  const knownPlatforms: NotificationPlatform[] = ['discord', 'telegram', 'webhook'];
  if (!knownPlatforms.includes(request.platform as NotificationPlatform)) {
    return { valid: false, error: `Invalid platform: ${request.platform}` };
  }
  if (!request.config) {
    return { valid: false, error: 'Missing config' };
  }
  // Platform-specific validation: presence checks first, then format checks.
  if (request.platform === 'discord') {
    const discord = request.config as Partial<DiscordConfig>;
    if (!discord.webhookUrl) {
      return { valid: false, error: 'Discord webhook URL is required' };
    }
    if (!isValidUrl(discord.webhookUrl)) {
      return { valid: false, error: 'Invalid Discord webhook URL format' };
    }
  } else if (request.platform === 'telegram') {
    const telegram = request.config as Partial<TelegramConfig>;
    if (!telegram.botToken) {
      return { valid: false, error: 'Telegram bot token is required' };
    }
    if (!telegram.chatId) {
      return { valid: false, error: 'Telegram chat ID is required' };
    }
    if (!isValidTelegramBotToken(telegram.botToken)) {
      return { valid: false, error: 'Invalid Telegram bot token format' };
    }
    if (!isValidTelegramChatId(telegram.chatId)) {
      return { valid: false, error: 'Invalid Telegram chat ID format' };
    }
  } else if (request.platform === 'webhook') {
    const webhook = request.config as Partial<WebhookConfig>;
    if (!webhook.url) {
      return { valid: false, error: 'Webhook URL is required' };
    }
    if (!isValidUrl(webhook.url)) {
      return { valid: false, error: 'Invalid webhook URL format' };
    }
    if (webhook.headers) {
      const headerCheck = isValidHeaders(webhook.headers);
      if (!headerCheck.valid) {
        return { valid: false, error: headerCheck.error };
      }
    }
  }
  return { valid: true };
}
/**
 * Handle remote notification routes.
 *
 * GET  /api/notifications/remote/config - Get current config (secrets masked)
 * POST /api/notifications/remote/config - Validate, merge, and save updates
 * POST /api/notifications/remote/test   - Send a test notification
 * POST /api/notifications/remote/reset  - Reset config to defaults
 *
 * @param req - Incoming HTTP request
 * @param res - HTTP response
 * @param pathname - Already-parsed request path
 * @returns true when the request matched one of these routes, false otherwise
 */
export async function handleNotificationRoutes(
  req: IncomingMessage,
  res: ServerResponse,
  pathname: string
): Promise<boolean> {
  // Serialize a payload as a JSON response with the given status.
  const respond = (status: number, payload: unknown): void => {
    res.writeHead(status, { 'Content-Type': 'application/json' });
    res.end(JSON.stringify(payload));
  };
  const method = req.method;

  // GET /api/notifications/remote/config
  if (pathname === '/api/notifications/remote/config' && method === 'GET') {
    respond(200, maskSensitiveConfig(loadConfig()));
    return true;
  }

  // POST /api/notifications/remote/config
  if (pathname === '/api/notifications/remote/config' && method === 'POST') {
    const raw = await readBody(req);
    try {
      const updates = JSON.parse(raw) as Partial<RemoteNotificationConfig>;
      const check = validateConfigUpdates(updates);
      if (!check.valid) {
        respond(400, { error: check.error });
        return true;
      }
      const merged = deepMerge(loadConfig(), updates);
      saveConfig(merged);
      // Pick up the new config immediately.
      remoteNotificationService.reloadConfig();
      respond(200, { success: true, config: maskSensitiveConfig(merged) });
    } catch (error) {
      respond(400, {
        error: error instanceof Error ? error.message : 'Invalid configuration',
      });
    }
    return true;
  }

  // POST /api/notifications/remote/test
  if (pathname === '/api/notifications/remote/test' && method === 'POST') {
    const raw = await readBody(req);
    try {
      const request = JSON.parse(raw) as TestNotificationRequest;
      const check = validateTestRequest(request);
      if (!check.valid) {
        respond(400, { success: false, error: check.error });
        return true;
      }
      const outcome = await remoteNotificationService.testPlatform(
        request.platform as NotificationPlatform,
        request.config
      );
      respond(200, outcome);
    } catch (error) {
      respond(400, {
        success: false,
        error: error instanceof Error ? error.message : 'Invalid request',
      });
    }
    return true;
  }

  // POST /api/notifications/remote/reset
  if (pathname === '/api/notifications/remote/reset' && method === 'POST') {
    const defaults = resetConfig();
    remoteNotificationService.reloadConfig();
    respond(200, { success: true, config: maskSensitiveConfig(defaults) });
    return true;
  }

  return false;
}
/**
 * Read request body as a UTF-8 string.
 *
 * Collects raw Buffer chunks and decodes them once at the end. The previous
 * implementation did `body += chunk`, which stringifies each chunk
 * independently and corrupts multibyte UTF-8 characters that happen to be
 * split across chunk boundaries.
 *
 * @param req - Incoming HTTP request
 * @returns Promise resolving to the full decoded body
 */
async function readBody(req: IncomingMessage): Promise<string> {
  return new Promise((resolve, reject) => {
    const chunks: Buffer[] = [];
    req.on('data', (chunk: Buffer | string) => {
      // Normalize string chunks (possible if an encoding was set upstream).
      chunks.push(typeof chunk === 'string' ? Buffer.from(chunk) : chunk);
    });
    req.on('end', () => resolve(Buffer.concat(chunks).toString('utf-8')));
    req.on('error', reject);
  });
}

View File

@@ -0,0 +1,151 @@
/**
* Unified Memory API Routes
*
* Provides HTTP endpoints for the unified memory system:
* - GET /api/unified-memory/search - RRF fusion search (vector + FTS5)
* - GET /api/unified-memory/stats - Aggregated statistics
* - POST /api/unified-memory/reindex - Rebuild HNSW vector index
* - GET /api/unified-memory/recommendations/:id - KNN recommendations
*/
import type { RouteContext } from './types.js';
/**
 * Handle Unified Memory API routes.
 *
 * Endpoints:
 * - GET  /api/unified-memory/search               - RRF fusion search (vector + FTS5)
 * - GET  /api/unified-memory/stats                - aggregated statistics
 * - POST /api/unified-memory/reindex              - rebuild HNSW vector index
 * - GET  /api/unified-memory/recommendations/:id  - KNN recommendations
 *
 * Robustness fix: malformed numeric query params (topK, minScore, limit) used
 * to propagate NaN into the service layer; they now fall back to their
 * documented defaults (20, 0, 5 respectively).
 *
 * @param ctx - Route context (parsed URL, request/response, project path, helpers)
 * @returns true if route was handled, false otherwise
 */
export async function handleUnifiedMemoryRoutes(ctx: RouteContext): Promise<boolean> {
  const { pathname, url, req, res, initialPath, handlePostRequest } = ctx;

  // Parse a positive-integer query param, falling back on missing/invalid input.
  const intParam = (name: string, fallback: number): number => {
    const parsed = parseInt(url.searchParams.get(name) || String(fallback), 10);
    return Number.isFinite(parsed) && parsed > 0 ? parsed : fallback;
  };

  // =========================================================================
  // GET /api/unified-memory/search
  // Query params: q (required), category, topK, minScore, path
  // =========================================================================
  if (pathname === '/api/unified-memory/search' && req.method === 'GET') {
    const query = url.searchParams.get('q');
    const projectPath = url.searchParams.get('path') || initialPath;
    const topK = intParam('topK', 20);
    const rawMinScore = parseFloat(url.searchParams.get('minScore') || '0');
    const minScore = Number.isFinite(rawMinScore) ? rawMinScore : 0;
    const category = url.searchParams.get('category') || undefined;
    if (!query) {
      res.writeHead(400, { 'Content-Type': 'application/json' });
      res.end(JSON.stringify({ error: 'Query parameter q is required' }));
      return true;
    }
    try {
      const { UnifiedMemoryService } = await import('../unified-memory-service.js');
      const service = new UnifiedMemoryService(projectPath);
      const results = await service.search(query, {
        limit: topK,
        minScore,
        category: category as any,
      });
      res.writeHead(200, { 'Content-Type': 'application/json' });
      res.end(JSON.stringify({
        success: true,
        query,
        total: results.length,
        results,
      }));
    } catch (error: unknown) {
      res.writeHead(500, { 'Content-Type': 'application/json' });
      res.end(JSON.stringify({ error: (error as Error).message }));
    }
    return true;
  }

  // =========================================================================
  // GET /api/unified-memory/stats
  // =========================================================================
  if (pathname === '/api/unified-memory/stats' && req.method === 'GET') {
    const projectPath = url.searchParams.get('path') || initialPath;
    try {
      const { UnifiedMemoryService } = await import('../unified-memory-service.js');
      const service = new UnifiedMemoryService(projectPath);
      const stats = await service.getStats();
      res.writeHead(200, { 'Content-Type': 'application/json' });
      res.end(JSON.stringify({ success: true, stats }));
    } catch (error: unknown) {
      res.writeHead(500, { 'Content-Type': 'application/json' });
      res.end(JSON.stringify({ error: (error as Error).message }));
    }
    return true;
  }

  // =========================================================================
  // POST /api/unified-memory/reindex
  // Body (optional): { path: string }
  // =========================================================================
  if (pathname === '/api/unified-memory/reindex' && req.method === 'POST') {
    handlePostRequest(req, res, async (body: any) => {
      const { path: projectPath } = body || {};
      const basePath = projectPath || initialPath;
      try {
        const { UnifiedVectorIndex, isUnifiedEmbedderAvailable } = await import('../unified-vector-index.js');
        // Reindexing requires the Python embedder; report 503 when absent.
        if (!isUnifiedEmbedderAvailable()) {
          return {
            error: 'Unified embedder is not available. Ensure Python venv and embedder script are set up.',
            status: 503,
          };
        }
        const index = new UnifiedVectorIndex(basePath);
        const result = await index.reindexAll();
        return {
          success: result.success,
          hnsw_count: result.hnsw_count,
          elapsed_time: result.elapsed_time,
          error: result.error,
        };
      } catch (error: unknown) {
        return { error: (error as Error).message, status: 500 };
      }
    });
    return true;
  }

  // =========================================================================
  // GET /api/unified-memory/recommendations/:id
  // Query params: limit (optional, default 5), path
  // =========================================================================
  if (pathname.startsWith('/api/unified-memory/recommendations/') && req.method === 'GET') {
    const memoryId = pathname.replace('/api/unified-memory/recommendations/', '');
    const projectPath = url.searchParams.get('path') || initialPath;
    const limit = intParam('limit', 5);
    if (!memoryId) {
      res.writeHead(400, { 'Content-Type': 'application/json' });
      res.end(JSON.stringify({ error: 'Memory ID is required' }));
      return true;
    }
    try {
      const { UnifiedMemoryService } = await import('../unified-memory-service.js');
      const service = new UnifiedMemoryService(projectPath);
      const recommendations = await service.getRecommendations(memoryId, limit);
      res.writeHead(200, { 'Content-Type': 'application/json' });
      res.end(JSON.stringify({
        success: true,
        memory_id: memoryId,
        total: recommendations.length,
        recommendations,
      }));
    } catch (error: unknown) {
      res.writeHead(500, { 'Content-Type': 'application/json' });
      res.end(JSON.stringify({ error: (error as Error).message }));
    }
    return true;
  }

  return false;
}

View File

@@ -9,6 +9,7 @@ import { handleAuditRoutes } from './routes/audit-routes.js';
import { handleProviderRoutes } from './routes/provider-routes.js'; import { handleProviderRoutes } from './routes/provider-routes.js';
import { handleMemoryRoutes } from './routes/memory-routes.js'; import { handleMemoryRoutes } from './routes/memory-routes.js';
import { handleCoreMemoryRoutes } from './routes/core-memory-routes.js'; import { handleCoreMemoryRoutes } from './routes/core-memory-routes.js';
import { handleUnifiedMemoryRoutes } from './routes/unified-memory-routes.js';
import { handleMcpRoutes } from './routes/mcp-routes.js'; import { handleMcpRoutes } from './routes/mcp-routes.js';
import { handleHooksRoutes } from './routes/hooks-routes.js'; import { handleHooksRoutes } from './routes/hooks-routes.js';
import { handleUnsplashRoutes, handleBackgroundRoutes } from './routes/unsplash-routes.js'; import { handleUnsplashRoutes, handleBackgroundRoutes } from './routes/unsplash-routes.js';
@@ -37,6 +38,7 @@ import { handleDashboardRoutes } from './routes/dashboard-routes.js';
import { handleOrchestratorRoutes } from './routes/orchestrator-routes.js'; import { handleOrchestratorRoutes } from './routes/orchestrator-routes.js';
import { handleConfigRoutes } from './routes/config-routes.js'; import { handleConfigRoutes } from './routes/config-routes.js';
import { handleTeamRoutes } from './routes/team-routes.js'; import { handleTeamRoutes } from './routes/team-routes.js';
import { handleNotificationRoutes } from './routes/notification-routes.js';
// Import WebSocket handling // Import WebSocket handling
import { handleWebSocketUpgrade, broadcastToClients, extractSessionIdFromPath } from './websocket.js'; import { handleWebSocketUpgrade, broadcastToClients, extractSessionIdFromPath } from './websocket.js';
@@ -462,6 +464,11 @@ export async function startServer(options: ServerOptions = {}): Promise<http.Ser
if (await handleCoreMemoryRoutes(routeContext)) return; if (await handleCoreMemoryRoutes(routeContext)) return;
} }
// Unified Memory routes (/api/unified-memory/*)
if (pathname.startsWith('/api/unified-memory/')) {
if (await handleUnifiedMemoryRoutes(routeContext)) return;
}
// MCP routes (/api/mcp*, /api/codex-mcp*) // MCP routes (/api/mcp*, /api/codex-mcp*)
if (pathname.startsWith('/api/mcp') || pathname.startsWith('/api/codex-mcp')) { if (pathname.startsWith('/api/mcp') || pathname.startsWith('/api/codex-mcp')) {
@@ -533,6 +540,11 @@ export async function startServer(options: ServerOptions = {}): Promise<http.Ser
if (await handleTeamRoutes(routeContext)) return; if (await handleTeamRoutes(routeContext)) return;
} }
// Remote notification routes (/api/notifications/remote/*)
if (pathname.startsWith('/api/notifications/remote')) {
if (await handleNotificationRoutes(req, res, pathname)) return;
}
// Task routes (/api/tasks) // Task routes (/api/tasks)
if (pathname.startsWith('/api/tasks')) { if (pathname.startsWith('/api/tasks')) {
if (await handleTaskRoutes(routeContext)) return; if (await handleTaskRoutes(routeContext)) return;

View File

@@ -0,0 +1,592 @@
// ========================================
// Remote Notification Service
// ========================================
// Core service for dispatching notifications to external platforms
// Non-blocking, best-effort delivery with parallel dispatch
import http from 'http';
import https from 'https';
import { URL } from 'url';
import type {
RemoteNotificationConfig,
NotificationContext,
NotificationDispatchResult,
PlatformNotificationResult,
NotificationPlatform,
DiscordConfig,
TelegramConfig,
WebhookConfig,
} from '../../types/remote-notification.js';
import {
loadConfig,
getEnabledPlatformsForEvent,
hasEnabledPlatform,
} from '../../config/remote-notification-config.js';
/**
* Remote Notification Service
* Handles dispatching notifications to configured platforms
*/
class RemoteNotificationService {
private config: RemoteNotificationConfig | null = null;
private configLoadedAt: number = 0;
private readonly CONFIG_TTL = 30000; // Reload config every 30 seconds
/**
 * Get the active configuration, transparently reloading from disk once the
 * cached copy is older than CONFIG_TTL (or not yet loaded).
 *
 * @returns The current remote notification configuration
 */
private getConfig(): RemoteNotificationConfig {
  const now = Date.now();
  if (this.config === null || now - this.configLoadedAt > this.CONFIG_TTL) {
    this.config = loadConfig();
    this.configLoadedAt = now;
  }
  return this.config;
}
/**
 * Force an immediate configuration reload, bypassing the TTL cache.
 */
reloadConfig(): void {
  this.configLoadedAt = Date.now();
  this.config = loadConfig();
}
/**
 * Check whether notifications are enabled for a given event: the global
 * toggle must be on AND at least one platform must subscribe to the event.
 *
 * @param eventType - Event identifier (e.g. 'task-completed')
 * @returns true when at least one platform would receive this event
 */
shouldNotify(eventType: string): boolean {
  const cfg = this.getConfig();
  if (!cfg.enabled) return false;
  return getEnabledPlatformsForEvent(cfg, eventType).length > 0;
}
/**
 * Send a notification to all configured platforms for an event.
 * Non-blocking: returns immediately; the actual dispatch happens
 * asynchronously and failures are only logged, never thrown.
 *
 * @param eventType - Event identifier
 * @param context - Event payload (eventType/timestamp are filled in here)
 */
sendNotification(
  eventType: string,
  context: Omit<NotificationContext, 'eventType' | 'timestamp'>
): void {
  const cfg = this.getConfig();
  // Cheap pre-checks before kicking off any async work.
  if (!cfg.enabled) return;
  const targets = getEnabledPlatformsForEvent(cfg, eventType);
  if (targets.length === 0) return;
  const fullContext: NotificationContext = {
    ...context,
    eventType: eventType as NotificationContext['eventType'],
    timestamp: new Date().toISOString(),
  };
  // Fire-and-forget: deliberately not awaited.
  void this.dispatchToPlatforms(targets, fullContext, cfg).catch((error) => {
    console.error('[RemoteNotification] Dispatch failed:', error);
  });
}
/**
 * Send a notification and wait for per-platform results (used by tests and
 * the /test endpoint).
 *
 * @param eventType - Event identifier
 * @param context - Event payload (eventType/timestamp are filled in here)
 * @returns Dispatch result; success is true if ANY platform succeeded
 */
async sendNotificationAsync(
  eventType: string,
  context: Omit<NotificationContext, 'eventType' | 'timestamp'>
): Promise<NotificationDispatchResult> {
  const cfg = this.getConfig();
  const startedAt = Date.now();
  if (!cfg.enabled) {
    return { success: false, results: [], totalTime: 0 };
  }
  const targets = getEnabledPlatformsForEvent(cfg, eventType);
  if (targets.length === 0) {
    return { success: false, results: [], totalTime: Date.now() - startedAt };
  }
  const fullContext: NotificationContext = {
    ...context,
    eventType: eventType as NotificationContext['eventType'],
    timestamp: new Date().toISOString(),
  };
  const outcomes = await this.dispatchToPlatforms(targets, fullContext, cfg);
  return {
    success: outcomes.some((outcome) => outcome.success),
    results: outcomes,
    totalTime: Date.now() - startedAt,
  };
}
/**
 * Dispatch a notification to multiple platforms in parallel.
 * Each platform produces a result even on failure (allSettled), so one
 * broken platform never hides the others' outcomes.
 *
 * @param platforms - Platform names to deliver to
 * @param context - Fully-populated notification context
 * @param config - Active configuration
 * @returns One result per platform, in the same order as `platforms`
 */
private async dispatchToPlatforms(
  platforms: string[],
  context: NotificationContext,
  config: RemoteNotificationConfig
): Promise<PlatformNotificationResult[]> {
  const settled = await Promise.allSettled(
    platforms.map((platform) =>
      this.dispatchToPlatform(platform as NotificationPlatform, context, config)
    )
  );
  return settled.map((outcome, idx) =>
    outcome.status === 'fulfilled'
      ? outcome.value
      : {
          platform: platforms[idx] as NotificationPlatform,
          success: false,
          error: outcome.reason?.message || 'Unknown error',
        }
  );
}
/**
* Dispatch to a single platform
*/
private async dispatchToPlatform(
platform: NotificationPlatform,
context: NotificationContext,
config: RemoteNotificationConfig
): Promise<PlatformNotificationResult> {
const startTime = Date.now();
try {
switch (platform) {
case 'discord':
return await this.sendDiscord(context, config.platforms.discord!, config.timeout);
case 'telegram':
return await this.sendTelegram(context, config.platforms.telegram!, config.timeout);
case 'webhook':
return await this.sendWebhook(context, config.platforms.webhook!, config.timeout);
default:
return {
platform,
success: false,
error: `Unknown platform: ${platform}`,
};
}
} catch (error) {
return {
platform,
success: false,
error: error instanceof Error ? error.message : String(error),
responseTime: Date.now() - startTime,
};
}
}
/**
 * Send a Discord notification via webhook (embed payload).
 *
 * @param context - Fully-populated notification context
 * @param config - Discord platform configuration
 * @param timeout - HTTP timeout in milliseconds
 * @returns Delivery result with response time on both success and failure
 */
private async sendDiscord(
  context: NotificationContext,
  config: DiscordConfig,
  timeout: number
): Promise<PlatformNotificationResult> {
  const startedAt = Date.now();
  if (!config.webhookUrl) {
    return { platform: 'discord', success: false, error: 'Webhook URL not configured' };
  }
  const payload = {
    username: config.username || 'CCW Notification',
    avatar_url: config.avatarUrl,
    embeds: [this.buildDiscordEmbed(context)],
  };
  try {
    await this.httpRequest(config.webhookUrl, payload, timeout);
    return {
      platform: 'discord',
      success: true,
      responseTime: Date.now() - startedAt,
    };
  } catch (error) {
    return {
      platform: 'discord',
      success: false,
      error: error instanceof Error ? error.message : String(error),
      responseTime: Date.now() - startedAt,
    };
  }
}
/**
 * Build a Discord embed object from the notification context.
 * Long free-text fields are truncated to 200 characters; unknown event
 * types fall back to a generic emoji/color.
 *
 * @param context - Fully-populated notification context
 * @returns Discord embed payload (title, color, fields, timestamp, footer)
 */
private buildDiscordEmbed(context: NotificationContext): Record<string, unknown> {
  // Truncate long text to keep embeds within Discord's field limits.
  const clip = (text: string): string =>
    text.length > 200 ? text.slice(0, 200) + '...' : text;
  const emojiByEvent: Record<string, string> = {
    'ask-user-question': '❓',
    'session-start': '▶️',
    'session-end': '⏹️',
    'task-completed': '✅',
    'task-failed': '❌',
  };
  const colorByEvent: Record<string, number> = {
    'ask-user-question': 0x3498db, // Blue
    'session-start': 0x2ecc71, // Green
    'session-end': 0x95a5a6, // Gray
    'task-completed': 0x27ae60, // Dark Green
    'task-failed': 0xe74c3c, // Red
  };
  const fields: Array<{ name: string; value: string; inline?: boolean }> = [];
  if (context.sessionId) {
    fields.push({ name: 'Session', value: context.sessionId.slice(0, 16) + '...', inline: true });
  }
  if (context.questionText) {
    fields.push({ name: 'Question', value: clip(context.questionText), inline: false });
  }
  if (context.taskDescription) {
    fields.push({ name: 'Task', value: clip(context.taskDescription), inline: false });
  }
  if (context.errorMessage) {
    fields.push({ name: 'Error', value: clip(context.errorMessage), inline: false });
  }
  return {
    title: `${emojiByEvent[context.eventType] || '📢'} ${this.formatEventName(context.eventType)}`,
    color: colorByEvent[context.eventType] || 0x9b59b6,
    fields,
    timestamp: context.timestamp,
    footer: { text: 'CCW Remote Notification' },
  };
}
/**
 * Send a Telegram notification via the Bot API sendMessage endpoint.
 *
 * @param context - Fully-populated notification context
 * @param config - Telegram platform configuration (botToken, chatId, parseMode)
 * @param timeout - HTTP timeout in milliseconds
 * @returns Delivery result with response time on both success and failure
 */
private async sendTelegram(
  context: NotificationContext,
  config: TelegramConfig,
  timeout: number
): Promise<PlatformNotificationResult> {
  const startedAt = Date.now();
  if (!config.botToken || !config.chatId) {
    return { platform: 'telegram', success: false, error: 'Bot token or chat ID not configured' };
  }
  const endpoint = `https://api.telegram.org/bot${config.botToken}/sendMessage`;
  const payload = {
    chat_id: config.chatId,
    text: this.buildTelegramMessage(context),
    parse_mode: config.parseMode || 'HTML',
  };
  try {
    await this.httpRequest(endpoint, payload, timeout);
    return {
      platform: 'telegram',
      success: true,
      responseTime: Date.now() - startedAt,
    };
  } catch (error) {
    return {
      platform: 'telegram',
      success: false,
      error: error instanceof Error ? error.message : String(error),
      responseTime: Date.now() - startedAt,
    };
  }
}
/**
 * Build an HTML-formatted Telegram message from the notification context.
 *
 * User-supplied text (question/task/error) is clipped to 300 characters and
 * HTML-escaped before being embedded in the markup.
 */
private buildTelegramMessage(context: NotificationContext): string {
  const emojiByEvent: Record<string, string> = {
    'ask-user-question': '❓',
    'session-start': '▶️',
    'session-end': '⏹️',
    'task-completed': '✅',
    'task-failed': '❌',
  };
  // Clip long free-text values so a single field cannot dominate the message.
  const clip = (value: string): string =>
    value.length > 300 ? value.slice(0, 300) + '...' : value;
  const parts: string[] = [
    `<b>${emojiByEvent[context.eventType] || '📢'} ${this.formatEventName(context.eventType)}</b>`,
    '',
  ];
  if (context.sessionId) {
    parts.push(`<b>Session:</b> <code>${context.sessionId.slice(0, 16)}...</code>`);
  }
  if (context.questionText) {
    parts.push(`<b>Question:</b> ${this.escapeHtml(clip(context.questionText))}`);
  }
  if (context.taskDescription) {
    parts.push(`<b>Task:</b> ${this.escapeHtml(clip(context.taskDescription))}`);
  }
  if (context.errorMessage) {
    parts.push(`<b>Error:</b> <code>${this.escapeHtml(clip(context.errorMessage))}</code>`);
  }
  parts.push('', `<i>📅 ${new Date(context.timestamp).toLocaleString()}</i>`);
  return parts.join('\n');
}
/**
 * Send the raw notification context to a generic webhook endpoint.
 *
 * The full context (event, timestamp, session, question, task, error,
 * metadata) is posted as JSON; per-webhook method/headers/timeout overrides
 * from the config are honored.
 */
private async sendWebhook(
  context: NotificationContext,
  config: WebhookConfig,
  timeout: number
): Promise<PlatformNotificationResult> {
  const startedAt = Date.now();
  if (!config.url) {
    return { platform: 'webhook', success: false, error: 'Webhook URL not configured' };
  }
  const payload = {
    event: context.eventType,
    timestamp: context.timestamp,
    sessionId: context.sessionId,
    questionText: context.questionText,
    taskDescription: context.taskDescription,
    errorMessage: context.errorMessage,
    metadata: context.metadata,
  };
  try {
    // Per-webhook timeout takes precedence over the service-level default.
    await this.httpRequest(config.url, payload, config.timeout || timeout, config.method, config.headers);
    return { platform: 'webhook', success: true, responseTime: Date.now() - startedAt };
  } catch (error) {
    return {
      platform: 'webhook',
      success: false,
      error: error instanceof Error ? error.message : String(error),
      responseTime: Date.now() - startedAt,
    };
  }
}
/**
 * Check if a URL is safe from SSRF attacks.
 * Blocks private IP ranges, loopback, link-local addresses, and cloud
 * metadata endpoints; only http/https protocols are allowed.
 *
 * Fix: WHATWG `URL.hostname` serializes IPv6 literals WITH brackets
 * (e.g. `new URL('http://[::1]/').hostname === '[::1]'`), so the previous
 * comparisons against bare `::1` / `fc…` / `fd…` could never match and the
 * IPv6 guard was dead code. Brackets are now stripped before those checks.
 *
 * @param urlString - Absolute URL to validate
 * @returns `{ safe: true }` or `{ safe: false, error }` with the reason
 */
private isUrlSafe(urlString: string): { safe: boolean; error?: string } {
  try {
    const parsedUrl = new URL(urlString);
    // Only allow http and https protocols
    if (!['http:', 'https:'].includes(parsedUrl.protocol)) {
      return { safe: false, error: 'Only http and https protocols are allowed' };
    }
    let hostname = parsedUrl.hostname.toLowerCase();
    // WHATWG URL keeps square brackets around IPv6 hosts ("[::1]");
    // strip them so the IPv6 comparisons below can match.
    if (hostname.startsWith('[') && hostname.endsWith(']')) {
      hostname = hostname.slice(1, -1);
    }
    // Block localhost variants
    if (hostname === 'localhost' || hostname === 'localhost.localdomain' || hostname === '0.0.0.0') {
      return { safe: false, error: 'Localhost addresses are not allowed' };
    }
    // Block IPv4 loopback (127.0.0.0/8)
    if (/^127\.\d{1,3}\.\d{1,3}\.\d{1,3}$/.test(hostname)) {
      return { safe: false, error: 'Loopback addresses are not allowed' };
    }
    // Block IPv4 private ranges
    // 10.0.0.0/8
    if (/^10\.\d{1,3}\.\d{1,3}\.\d{1,3}$/.test(hostname)) {
      return { safe: false, error: 'Private IP addresses are not allowed' };
    }
    // 172.16.0.0/12
    if (/^172\.(1[6-9]|2\d|3[01])\.\d{1,3}\.\d{1,3}$/.test(hostname)) {
      return { safe: false, error: 'Private IP addresses are not allowed' };
    }
    // 192.168.0.0/16
    if (/^192\.168\.\d{1,3}\.\d{1,3}$/.test(hostname)) {
      return { safe: false, error: 'Private IP addresses are not allowed' };
    }
    // Block link-local addresses (169.254.0.0/16)
    if (/^169\.254\.\d{1,3}\.\d{1,3}$/.test(hostname)) {
      return { safe: false, error: 'Link-local addresses are not allowed' };
    }
    // Block IPv6 loopback, unspecified, and unique-local (fc00::/7) addresses
    if (hostname === '::1' || hostname.startsWith('fc') || hostname.startsWith('fd') || hostname === '::') {
      return { safe: false, error: 'IPv6 private/loopback addresses are not allowed' };
    }
    // Block hostnames that look like IP addresses in various formats
    // (e.g., 0x7f.0.0.1, 2130706433, etc.)
    if (/^0x[0-9a-f]+/i.test(hostname) || /^\d{8,}$/.test(hostname)) {
      return { safe: false, error: 'Suspicious hostname format' };
    }
    // Block cloud metadata endpoints
    if (hostname === '169.254.169.254' || hostname === 'metadata.google.internal' || hostname === 'metadata.azure.internal') {
      return { safe: false, error: 'Cloud metadata endpoints are not allowed' };
    }
    return { safe: true };
  } catch (error) {
    return { safe: false, error: 'Invalid URL format' };
  }
}
/**
 * Generic HTTP request helper.
 *
 * Sends `body` as JSON to `url` and resolves on any 2xx response; rejects on
 * non-2xx status (with a truncated response body in the message), transport
 * error, socket timeout, or when the URL fails SSRF validation (isUrlSafe).
 *
 * @param url - Absolute http/https URL to call
 * @param body - Payload; serialized with JSON.stringify as the request body
 * @param timeout - Socket timeout in milliseconds
 * @param method - HTTP verb, 'POST' (default) or 'PUT'
 * @param headers - Extra headers merged over the default Content-Type
 */
private httpRequest(
  url: string,
  body: unknown,
  timeout: number,
  method: 'POST' | 'PUT' = 'POST',
  headers: Record<string, string> = {}
): Promise<void> {
  return new Promise((resolve, reject) => {
    // SSRF protection: validate URL before making request
    const urlSafety = this.isUrlSafe(url);
    if (!urlSafety.safe) {
      reject(new Error(`URL validation failed: ${urlSafety.error}`));
      return;
    }
    const parsedUrl = new URL(url);
    const isHttps = parsedUrl.protocol === 'https:';
    // Pick the Node client module matching the URL scheme.
    const client = isHttps ? https : http;
    const requestOptions: http.RequestOptions = {
      hostname: parsedUrl.hostname,
      port: parsedUrl.port || (isHttps ? 443 : 80),
      path: parsedUrl.pathname + parsedUrl.search,
      method,
      headers: {
        'Content-Type': 'application/json',
        ...headers, // caller-supplied headers override the default
      },
      timeout,
    };
    const req = client.request(requestOptions, (res) => {
      let data = '';
      res.on('data', (chunk) => { data += chunk; });
      res.on('end', () => {
        // Any 2xx counts as success; otherwise surface status + body excerpt.
        if (res.statusCode && res.statusCode >= 200 && res.statusCode < 300) {
          resolve();
        } else {
          reject(new Error(`HTTP ${res.statusCode}: ${data.slice(0, 200)}`));
        }
      });
    });
    req.on('error', reject);
    req.on('timeout', () => {
      // Abort the stalled socket so the request cannot hang indefinitely.
      req.destroy();
      reject(new Error('Request timeout'));
    });
    req.write(JSON.stringify(body));
    req.end();
  });
}
/**
 * Convert a kebab-case event type into a human-readable title,
 * e.g. "task-completed" -> "Task Completed".
 */
private formatEventName(eventType: string): string {
  const pretty: string[] = [];
  for (const word of eventType.split('-')) {
    pretty.push(word.charAt(0).toUpperCase() + word.slice(1));
  }
  return pretty.join(' ');
}
/**
 * Escape the HTML-significant characters (&, <, >) for Telegram messages.
 */
private escapeHtml(text: string): string {
  const entities: Record<string, string> = {
    '&': '&amp;',
    '<': '&lt;',
    '>': '&gt;',
  };
  return text.replace(/[&<>]/g, (ch) => entities[ch]);
}
/**
* Test a platform configuration
*/
async testPlatform(
platform: NotificationPlatform,
config: DiscordConfig | TelegramConfig | WebhookConfig
): Promise<{ success: boolean; error?: string; responseTime?: number }> {
const testContext: NotificationContext = {
eventType: 'task-completed',
sessionId: 'test-session',
taskDescription: 'This is a test notification from CCW',
timestamp: new Date().toISOString(),
};
const startTime = Date.now();
try {
switch (platform) {
case 'discord':
return await this.sendDiscord(testContext, config as DiscordConfig, 10000);
case 'telegram':
return await this.sendTelegram(testContext, config as TelegramConfig, 10000);
case 'webhook':
return await this.sendWebhook(testContext, config as WebhookConfig, 10000);
default:
return { success: false, error: `Unknown platform: ${platform}` };
}
} catch (error) {
return {
success: false,
error: error instanceof Error ? error.message : String(error),
responseTime: Date.now() - startTime,
};
}
}
}
// Singleton instance
export const remoteNotificationService = new RemoteNotificationService();

View File

@@ -5,6 +5,7 @@
import { CoreMemoryStore, SessionCluster, ClusterMember, SessionMetadataCache } from './core-memory-store.js'; import { CoreMemoryStore, SessionCluster, ClusterMember, SessionMetadataCache } from './core-memory-store.js';
import { CliHistoryStore } from '../tools/cli-history-store.js'; import { CliHistoryStore } from '../tools/cli-history-store.js';
import { UnifiedVectorIndex, isUnifiedEmbedderAvailable } from './unified-vector-index.js';
import { StoragePaths } from '../config/storage-paths.js'; import { StoragePaths } from '../config/storage-paths.js';
import { readdirSync, readFileSync, statSync, existsSync } from 'fs'; import { readdirSync, readFileSync, statSync, existsSync } from 'fs';
import { join } from 'path'; import { join } from 'path';
@@ -21,6 +22,10 @@ const WEIGHTS = {
// Clustering threshold (0.4 = moderate similarity required) // Clustering threshold (0.4 = moderate similarity required)
const CLUSTER_THRESHOLD = 0.4; const CLUSTER_THRESHOLD = 0.4;
// Incremental clustering frequency control
const MIN_CLUSTER_INTERVAL_HOURS = 6;
const MIN_NEW_SESSIONS_FOR_CLUSTER = 5;
export interface ClusteringOptions { export interface ClusteringOptions {
scope?: 'all' | 'recent' | 'unclustered'; scope?: 'all' | 'recent' | 'unclustered';
timeRange?: { start: string; end: string }; timeRange?: { start: string; end: string };
@@ -33,15 +38,29 @@ export interface ClusteringResult {
sessionsClustered: number; sessionsClustered: number;
} }
export interface IncrementalClusterResult {
sessionId: string;
clusterId: string | null;
action: 'joined_existing' | 'created_new' | 'skipped';
}
export class SessionClusteringService { export class SessionClusteringService {
private coreMemoryStore: CoreMemoryStore; private coreMemoryStore: CoreMemoryStore;
private cliHistoryStore: CliHistoryStore; private cliHistoryStore: CliHistoryStore;
private projectPath: string; private projectPath: string;
private vectorIndex: UnifiedVectorIndex | null = null;
/** Cache: sessionId -> list of nearby session source_ids from HNSW search */
private vectorNeighborCache: Map<string, Map<string, number>> = new Map();
constructor(projectPath: string) { constructor(projectPath: string) {
this.projectPath = projectPath; this.projectPath = projectPath;
this.coreMemoryStore = new CoreMemoryStore(projectPath); this.coreMemoryStore = new CoreMemoryStore(projectPath);
this.cliHistoryStore = new CliHistoryStore(projectPath); this.cliHistoryStore = new CliHistoryStore(projectPath);
// Initialize vector index if available
if (isUnifiedEmbedderAvailable()) {
this.vectorIndex = new UnifiedVectorIndex(projectPath);
}
} }
/** /**
@@ -331,14 +350,36 @@ export class SessionClusteringService {
} }
/** /**
* Calculate vector similarity using pre-computed embeddings from memory_chunks * Calculate vector similarity using HNSW index when available.
* Returns average cosine similarity of chunk embeddings * Falls back to direct cosine similarity on pre-computed embeddings from memory_chunks.
*
* HNSW path: Uses cached neighbor lookup from vectorNeighborCache (populated by
* preloadVectorNeighbors). This replaces the O(N) full-table scan with O(1) cache lookup.
*
* Fallback path: Averages chunk embeddings from SQLite and computes cosine similarity directly.
*/ */
private calculateVectorSimilarity(s1: SessionMetadataCache, s2: SessionMetadataCache): number { private calculateVectorSimilarity(s1: SessionMetadataCache, s2: SessionMetadataCache): number {
// HNSW path: check if we have pre-loaded neighbor scores
const neighbors1 = this.vectorNeighborCache.get(s1.session_id);
if (neighbors1) {
const score = neighbors1.get(s2.session_id);
if (score !== undefined) return score;
// s2 is not a neighbor of s1 via HNSW - low similarity
return 0;
}
// Also check reverse direction
const neighbors2 = this.vectorNeighborCache.get(s2.session_id);
if (neighbors2) {
const score = neighbors2.get(s1.session_id);
if (score !== undefined) return score;
return 0;
}
// Fallback: direct cosine similarity on chunk embeddings
const embedding1 = this.getSessionEmbedding(s1.session_id); const embedding1 = this.getSessionEmbedding(s1.session_id);
const embedding2 = this.getSessionEmbedding(s2.session_id); const embedding2 = this.getSessionEmbedding(s2.session_id);
// Graceful fallback if no embeddings available
if (!embedding1 || !embedding2) { if (!embedding1 || !embedding2) {
return 0; return 0;
} }
@@ -346,6 +387,55 @@ export class SessionClusteringService {
return this.cosineSimilarity(embedding1, embedding2); return this.cosineSimilarity(embedding1, embedding2);
} }
/**
* Preload vector neighbors for a set of sessions using HNSW search.
* For each session, gets its average embedding and searches for nearby chunks,
* then aggregates scores by source_id to get session-level similarity scores.
*
* This replaces the O(N^2) full-table scan with O(N * topK) HNSW lookups.
*/
async preloadVectorNeighbors(sessionIds: string[], topK: number = 20): Promise<void> {
if (!this.vectorIndex) return;
this.vectorNeighborCache.clear();
for (const sessionId of sessionIds) {
const avgEmbedding = this.getSessionEmbedding(sessionId);
if (!avgEmbedding) continue;
try {
const result = await this.vectorIndex.searchByVector(avgEmbedding, {
topK,
minScore: 0.1,
});
if (!result.success || !result.matches.length) continue;
// Aggregate scores by source_id (session-level similarity)
const neighborScores = new Map<string, number[]>();
for (const match of result.matches) {
const sourceId = match.source_id;
if (sourceId === sessionId) continue; // skip self
if (!neighborScores.has(sourceId)) {
neighborScores.set(sourceId, []);
}
neighborScores.get(sourceId)!.push(match.score);
}
// Average scores per neighbor session
const avgScores = new Map<string, number>();
for (const [neighborId, scores] of neighborScores) {
const avg = scores.reduce((sum, s) => sum + s, 0) / scores.length;
avgScores.set(neighborId, avg);
}
this.vectorNeighborCache.set(sessionId, avgScores);
} catch {
// HNSW search failed for this session, skip
}
}
}
/** /**
* Get session embedding by averaging all chunk embeddings * Get session embedding by averaging all chunk embeddings
*/ */
@@ -494,11 +584,16 @@ export class SessionClusteringService {
this.coreMemoryStore.upsertSessionMetadata(session); this.coreMemoryStore.upsertSessionMetadata(session);
} }
// 4. Calculate relevance matrix // 4. Preload HNSW vector neighbors for efficient similarity calculation
const n = sessions.length; if (this.vectorIndex) {
const relevanceMatrix: number[][] = Array(n).fill(0).map(() => Array(n).fill(0)); const sessionIds = sessions.map(s => s.session_id);
await this.preloadVectorNeighbors(sessionIds);
console.log(`[Clustering] Preloaded HNSW vector neighbors for ${sessionIds.length} sessions`);
}
let maxScore = 0; // 5. Calculate relevance matrix
const n = sessions.length;
const relevanceMatrix: number[][] = Array(n).fill(0).map(() => Array(n).fill(0)); let maxScore = 0;
let avgScore = 0; let avgScore = 0;
let pairCount = 0; let pairCount = 0;
@@ -519,7 +614,7 @@ export class SessionClusteringService {
console.log(`[Clustering] Relevance stats: max=${maxScore.toFixed(3)}, avg=${avgScore.toFixed(3)}, pairs=${pairCount}, threshold=${CLUSTER_THRESHOLD}`); console.log(`[Clustering] Relevance stats: max=${maxScore.toFixed(3)}, avg=${avgScore.toFixed(3)}, pairs=${pairCount}, threshold=${CLUSTER_THRESHOLD}`);
} }
// 5. Agglomerative clustering // 6. Agglomerative clustering
const minClusterSize = options?.minClusterSize || 2; const minClusterSize = options?.minClusterSize || 2;
// Early return if not enough sessions // Early return if not enough sessions
@@ -531,7 +626,7 @@ export class SessionClusteringService {
const newPotentialClusters = this.agglomerativeClustering(sessions, relevanceMatrix, CLUSTER_THRESHOLD); const newPotentialClusters = this.agglomerativeClustering(sessions, relevanceMatrix, CLUSTER_THRESHOLD);
console.log(`[Clustering] Generated ${newPotentialClusters.length} potential clusters`); console.log(`[Clustering] Generated ${newPotentialClusters.length} potential clusters`);
// 6. Process clusters: create new or merge with existing // 7. Process clusters: create new or merge with existing
let clustersCreated = 0; let clustersCreated = 0;
let clustersMerged = 0; let clustersMerged = 0;
let sessionsClustered = 0; let sessionsClustered = 0;
@@ -716,6 +811,145 @@ export class SessionClusteringService {
return { merged, deleted, remaining }; return { merged, deleted, remaining };
} }
/**
* Check whether clustering should run based on frequency control.
* Conditions: last clustering > MIN_CLUSTER_INTERVAL_HOURS ago AND
* new unclustered sessions >= MIN_NEW_SESSIONS_FOR_CLUSTER.
*
* Stores last_cluster_time in session_clusters metadata.
*/
async shouldRunClustering(): Promise<boolean> {
// Check last cluster time from cluster metadata
const clusters = this.coreMemoryStore.listClusters('active');
let lastClusterTime = 0;
for (const cluster of clusters) {
const createdMs = new Date(cluster.created_at).getTime();
if (createdMs > lastClusterTime) {
lastClusterTime = createdMs;
}
const updatedMs = new Date(cluster.updated_at).getTime();
if (updatedMs > lastClusterTime) {
lastClusterTime = updatedMs;
}
}
// Check time interval
const now = Date.now();
const hoursSinceLastCluster = (now - lastClusterTime) / (1000 * 60 * 60);
if (lastClusterTime > 0 && hoursSinceLastCluster < MIN_CLUSTER_INTERVAL_HOURS) {
return false;
}
// Check number of unclustered sessions
const allSessions = await this.collectSessions({ scope: 'recent' });
const unclusteredCount = allSessions.filter(s => {
const sessionClusters = this.coreMemoryStore.getSessionClusters(s.session_id);
return sessionClusters.length === 0;
}).length;
return unclusteredCount >= MIN_NEW_SESSIONS_FOR_CLUSTER;
}
/**
* Incremental clustering: process only a single new session.
*
* Computes the new session's similarity against existing cluster centroids
* using HNSW search. If similarity >= CLUSTER_THRESHOLD, joins the best
* matching cluster. Otherwise, remains unclustered until enough sessions
* accumulate for a new cluster.
*
* @param sessionId - The session to incrementally cluster
* @returns Result indicating what action was taken
*/
async incrementalCluster(sessionId: string): Promise<IncrementalClusterResult> {
// Get or create session metadata
let sessionMeta = this.coreMemoryStore.getSessionMetadata(sessionId);
if (!sessionMeta) {
// Try to build metadata from available sources
const allSessions = await this.collectSessions({ scope: 'all' });
sessionMeta = allSessions.find(s => s.session_id === sessionId) || null;
if (!sessionMeta) {
return { sessionId, clusterId: null, action: 'skipped' };
}
this.coreMemoryStore.upsertSessionMetadata(sessionMeta);
}
// Check if already clustered
const existingClusters = this.coreMemoryStore.getSessionClusters(sessionId);
if (existingClusters.length > 0) {
return { sessionId, clusterId: existingClusters[0].id, action: 'skipped' };
}
// Get all active clusters and their representative sessions
const activeClusters = this.coreMemoryStore.listClusters('active');
if (activeClusters.length === 0) {
return { sessionId, clusterId: null, action: 'skipped' };
}
// Use HNSW to find nearest neighbors for the new session
if (this.vectorIndex) {
await this.preloadVectorNeighbors([sessionId]);
}
// Calculate similarity against each cluster's member sessions
let bestCluster: SessionCluster | null = null;
let bestScore = 0;
for (const cluster of activeClusters) {
const members = this.coreMemoryStore.getClusterMembers(cluster.id);
if (members.length === 0) continue;
// Calculate average relevance against cluster members (sample up to 5)
const sampleMembers = members.slice(0, 5);
let totalScore = 0;
let validCount = 0;
for (const member of sampleMembers) {
const memberMeta = this.coreMemoryStore.getSessionMetadata(member.session_id);
if (!memberMeta) continue;
const score = this.calculateRelevance(sessionMeta, memberMeta);
totalScore += score;
validCount++;
}
if (validCount === 0) continue;
const avgScore = totalScore / validCount;
if (avgScore > bestScore) {
bestScore = avgScore;
bestCluster = cluster;
}
}
// Join best cluster if above threshold
if (bestCluster && bestScore >= CLUSTER_THRESHOLD) {
const existingMembers = this.coreMemoryStore.getClusterMembers(bestCluster.id);
this.coreMemoryStore.addClusterMember({
cluster_id: bestCluster.id,
session_id: sessionId,
session_type: sessionMeta.session_type as 'core_memory' | 'workflow' | 'cli_history' | 'native',
sequence_order: existingMembers.length + 1,
relevance_score: bestScore,
});
// Update cluster description
this.coreMemoryStore.updateCluster(bestCluster.id, {
description: `Auto-generated cluster with ${existingMembers.length + 1} sessions`
});
console.log(`[Clustering] Session ${sessionId} joined cluster '${bestCluster.name}' (score: ${bestScore.toFixed(3)})`);
return { sessionId, clusterId: bestCluster.id, action: 'joined_existing' };
}
// Not similar enough to any existing cluster
return { sessionId, clusterId: null, action: 'skipped' };
}
/** /**
* Agglomerative clustering algorithm * Agglomerative clustering algorithm
* Returns array of clusters (each cluster is array of sessions) * Returns array of clusters (each cluster is array of sessions)

View File

@@ -0,0 +1,410 @@
/**
* UnifiedContextBuilder - Assembles context for Claude Code hooks
*
* Provides componentized context assembly for:
* - session-start: MEMORY.md summary + cluster overview + hot entities + solidified patterns
* - per-prompt: vector search + intent matching across all categories
* - session-end: incremental embedding + clustering + heat score update tasks
*
* Character limits:
* - session-start: <= 1000 chars
* - per-prompt: <= 500 chars
*/
import { existsSync, readdirSync } from 'fs';
import { join, basename } from 'path';
import { getProjectPaths } from '../config/storage-paths.js';
import { getMemoryMdContent } from './memory-consolidation-pipeline.js';
import { getMemoryStore } from './memory-store.js';
import type { HotEntity } from './memory-store.js';
import {
UnifiedVectorIndex,
isUnifiedEmbedderAvailable,
} from './unified-vector-index.js';
import type { VectorSearchMatch } from './unified-vector-index.js';
import { SessionClusteringService } from './session-clustering-service.js';
// =============================================================================
// Constants
// =============================================================================
/** Maximum character count for session-start context */
const SESSION_START_LIMIT = 1000;
/** Maximum character count for per-prompt context */
const PER_PROMPT_LIMIT = 500;
/** Maximum characters for the MEMORY.md summary component */
const MEMORY_SUMMARY_LIMIT = 500;
/** Number of top clusters to show in overview */
const TOP_CLUSTERS = 3;
/** Number of top hot entities to show */
const TOP_HOT_ENTITIES = 5;
/** Days to look back for hot entities */
const HOT_ENTITY_DAYS = 7;
/** Number of vector search results for per-prompt */
const VECTOR_TOP_K = 8;
/** Minimum vector similarity score */
const VECTOR_MIN_SCORE = 0.3;
// =============================================================================
// Types
// =============================================================================
/**
 * A unit of deferred work to run asynchronously at session-end.
 * Tasks produced by UnifiedContextBuilder are best-effort: each catches its
 * own errors internally rather than propagating them to the hook.
 */
export interface SessionEndTask {
  /** Descriptive name of the task (used for identification/logging) */
  name: string;
  /** Async function that performs the work */
  execute: () => Promise<void>;
}
// =============================================================================
// UnifiedContextBuilder
// =============================================================================
export class UnifiedContextBuilder {
  private projectPath: string;
  private paths: ReturnType<typeof getProjectPaths>;
  constructor(projectPath: string) {
    this.projectPath = projectPath;
    this.paths = getProjectPaths(projectPath);
  }
  // ---------------------------------------------------------------------------
  // Public: session-start context
  // ---------------------------------------------------------------------------
  /**
   * Build context for session-start hook injection.
   *
   * Components (assembled in order, truncated to <= 1000 chars total):
   * 1. MEMORY.md summary (up to 500 chars)
   * 2. Cluster overview (top 3 active clusters)
   * 3. Hot entities (top 5 within last 7 days)
   * 4. Solidified patterns (skills/*.md file list)
   */
  async buildSessionStartContext(): Promise<string> {
    const sections: string[] = [];
    // Component 1: MEMORY.md summary
    const memorySummary = this.buildMemorySummary();
    if (memorySummary) {
      sections.push(memorySummary);
    }
    // Component 2: Cluster overview
    const clusterOverview = await this.buildClusterOverview();
    if (clusterOverview) {
      sections.push(clusterOverview);
    }
    // Component 3: Hot entities
    const hotEntities = this.buildHotEntities();
    if (hotEntities) {
      sections.push(hotEntities);
    }
    // Component 4: Solidified patterns
    const patterns = this.buildSolidifiedPatterns();
    if (patterns) {
      sections.push(patterns);
    }
    if (sections.length === 0) {
      return '';
    }
    // Assemble and truncate
    let content = '<ccw-memory-context>\n' + sections.join('\n') + '\n</ccw-memory-context>';
    if (content.length > SESSION_START_LIMIT) {
      // Reserve exactly enough room for the closing tag so the final string
      // never exceeds SESSION_START_LIMIT. (A fixed "-20" reservation was
      // 2 chars short of the 22-char closing tag and could overflow the limit.)
      const closingTag = '\n</ccw-memory-context>';
      content = content.substring(0, SESSION_START_LIMIT - closingTag.length) + closingTag;
    }
    return content;
  }
  // ---------------------------------------------------------------------------
  // Public: per-prompt context
  // ---------------------------------------------------------------------------
  /**
   * Build context for per-prompt hook injection.
   *
   * Uses vector search across all categories to find relevant memories
   * matching the current prompt. Results are ranked by similarity score.
   *
   * @param prompt - Current user prompt text
   * @returns Context string (<= 500 chars) or empty string
   */
  async buildPromptContext(prompt: string): Promise<string> {
    if (!prompt || !prompt.trim()) {
      return '';
    }
    if (!isUnifiedEmbedderAvailable()) {
      return '';
    }
    try {
      const vectorIndex = new UnifiedVectorIndex(this.projectPath);
      const result = await vectorIndex.search(prompt, {
        topK: VECTOR_TOP_K,
        minScore: VECTOR_MIN_SCORE,
      });
      if (!result.success || result.matches.length === 0) {
        return '';
      }
      return this.formatPromptMatches(result.matches);
    } catch {
      // Context injection is best-effort; failures degrade to no context.
      return '';
    }
  }
  // ---------------------------------------------------------------------------
  // Public: session-end tasks
  // ---------------------------------------------------------------------------
  /**
   * Build a list of async tasks to run at session-end.
   *
   * Tasks:
   * 1. Incremental vector embedding (index new/updated content)
   * 2. Incremental clustering (cluster unclustered sessions)
   * 3. Heat score updates (recalculate entity heat scores)
   *
   * @param sessionId - Current session ID for context
   * @returns Array of tasks with name and execute function
   */
  buildSessionEndTasks(sessionId: string): SessionEndTask[] {
    const tasks: SessionEndTask[] = [];
    // Task 1: Incremental vector embedding
    if (isUnifiedEmbedderAvailable()) {
      tasks.push({
        name: 'incremental-embedding',
        execute: async () => {
          try {
            const vectorIndex = new UnifiedVectorIndex(this.projectPath);
            // Re-index the MEMORY.md content if available
            const memoryContent = getMemoryMdContent(this.projectPath);
            if (memoryContent) {
              await vectorIndex.indexContent(memoryContent, {
                source_id: 'MEMORY_MD',
                source_type: 'core_memory',
                category: 'core_memory',
              });
            }
          } catch (err) {
            // Log but don't throw - session-end tasks are best-effort
            if (process.env.DEBUG) {
              console.error('[UnifiedContextBuilder] Embedding task failed:', (err as Error).message);
            }
          }
        },
      });
    }
    // Task 2: Incremental clustering
    tasks.push({
      name: 'incremental-clustering',
      execute: async () => {
        try {
          const clusteringService = new SessionClusteringService(this.projectPath);
          await clusteringService.autocluster({ scope: 'unclustered' });
        } catch (err) {
          if (process.env.DEBUG) {
            console.error('[UnifiedContextBuilder] Clustering task failed:', (err as Error).message);
          }
        }
      },
    });
    // Task 3: Heat score updates
    tasks.push({
      name: 'heat-score-update',
      execute: async () => {
        try {
          const memoryStore = getMemoryStore(this.projectPath);
          const hotEntities = memoryStore.getHotEntities(50);
          for (const entity of hotEntities) {
            if (entity.id != null) {
              memoryStore.calculateHeatScore(entity.id);
            }
          }
        } catch (err) {
          if (process.env.DEBUG) {
            console.error('[UnifiedContextBuilder] Heat score update failed:', (err as Error).message);
          }
        }
      },
    });
    return tasks;
  }
  // ---------------------------------------------------------------------------
  // Private: Component builders
  // ---------------------------------------------------------------------------
  /**
   * Build MEMORY.md summary component.
   * Reads MEMORY.md and returns up to MEMORY_SUMMARY_LIMIT characters,
   * preferring to truncate at a newline boundary.
   */
  private buildMemorySummary(): string {
    const content = getMemoryMdContent(this.projectPath);
    if (!content) {
      return '';
    }
    let summary = content.trim();
    if (summary.length > MEMORY_SUMMARY_LIMIT) {
      // Truncate at a newline boundary if possible
      const truncated = summary.substring(0, MEMORY_SUMMARY_LIMIT);
      const lastNewline = truncated.lastIndexOf('\n');
      summary = lastNewline > MEMORY_SUMMARY_LIMIT * 0.6
        ? truncated.substring(0, lastNewline) + '...'
        : truncated + '...';
    }
    return `## Memory Summary\n${summary}\n`;
  }
  /**
   * Build cluster overview component.
   * Shows top N active clusters (by member count) from the core memory store.
   */
  private async buildClusterOverview(): Promise<string> {
    try {
      const { CoreMemoryStore } = await import('./core-memory-store.js');
      const store = new CoreMemoryStore(this.projectPath);
      const clusters = store.listClusters('active');
      if (clusters.length === 0) {
        return '';
      }
      // Rank clusters by member count, keep the top TOP_CLUSTERS
      const sorted = clusters
        .map(c => {
          const members = store.getClusterMembers(c.id);
          return { cluster: c, memberCount: members.length };
        })
        .sort((a, b) => b.memberCount - a.memberCount)
        .slice(0, TOP_CLUSTERS);
      let output = '## Active Clusters\n';
      for (const { cluster, memberCount } of sorted) {
        const intent = cluster.intent ? ` - ${cluster.intent}` : '';
        output += `- **${cluster.name}** (${memberCount})${intent}\n`;
      }
      return output;
    } catch {
      return '';
    }
  }
  /**
   * Build hot entities component.
   * Shows top N entities by heat_score that were active within last 7 days.
   */
  private buildHotEntities(): string {
    try {
      const memoryStore = getMemoryStore(this.projectPath);
      // Over-fetch so the recency filter below still leaves enough candidates
      const allHot = memoryStore.getHotEntities(TOP_HOT_ENTITIES * 3);
      if (allHot.length === 0) {
        return '';
      }
      // Filter to entities seen within the last HOT_ENTITY_DAYS days
      const cutoff = new Date();
      cutoff.setDate(cutoff.getDate() - HOT_ENTITY_DAYS);
      const cutoffStr = cutoff.toISOString();
      const recentHot = allHot
        .filter(e => (e.last_seen_at || '') >= cutoffStr)
        .slice(0, TOP_HOT_ENTITIES);
      if (recentHot.length === 0) {
        return '';
      }
      let output = '## Hot Entities (7d)\n';
      for (const entity of recentHot) {
        const heat = Math.round(entity.stats.heat_score);
        output += `- ${entity.type}:${entity.value} (heat:${heat})\n`;
      }
      return output;
    } catch {
      return '';
    }
  }
  /**
   * Build solidified patterns component.
   * Scans skills/*.md files and lists up to 5 of their base names.
   */
  private buildSolidifiedPatterns(): string {
    try {
      const skillsDir = this.paths.memoryV2.skills;
      if (!existsSync(skillsDir)) {
        return '';
      }
      const files = readdirSync(skillsDir).filter(f => f.endsWith('.md'));
      if (files.length === 0) {
        return '';
      }
      let output = '## Patterns\n';
      for (const file of files.slice(0, 5)) {
        const name = basename(file, '.md');
        output += `- ${name}\n`;
      }
      return output;
    } catch {
      return '';
    }
  }
  // ---------------------------------------------------------------------------
  // Private: Formatting helpers
  // ---------------------------------------------------------------------------
  /**
   * Format vector search matches for per-prompt context.
   * Builds a compact Markdown snippet within PER_PROMPT_LIMIT chars.
   */
  private formatPromptMatches(matches: VectorSearchMatch[]): string {
    let output = '<ccw-related-memory>\n';
    for (const match of matches) {
      const score = Math.round(match.score * 100);
      const snippet = match.content.substring(0, 80).replace(/\n/g, ' ').trim();
      const line = `- [${match.category}] ${snippet} (${score}%)\n`;
      // Check if adding this line would exceed limit (25 reserves closing tag)
      if (output.length + line.length + 25 > PER_PROMPT_LIMIT) {
        break;
      }
      output += line;
    }
    output += '</ccw-related-memory>';
    return output;
  }
}

View File

@@ -0,0 +1,488 @@
/**
* Unified Memory Service - Cross-store search with RRF fusion
*
* Provides a single search() interface that combines:
* - Vector search (HNSW via UnifiedVectorIndex)
* - Full-text search (FTS5 via MemoryStore.searchPrompts)
* - Heat-based scoring (entity heat from MemoryStore)
*
* Fusion: Reciprocal Rank Fusion (RRF)
* score = sum(1 / (k + rank_i) * weight_i)
* k = 60, weights = { vector: 0.6, fts: 0.3, heat: 0.1 }
*/
import { UnifiedVectorIndex, isUnifiedEmbedderAvailable } from './unified-vector-index.js';
import type {
VectorCategory,
VectorSearchMatch,
VectorIndexStatus,
} from './unified-vector-index.js';
import { CoreMemoryStore, getCoreMemoryStore } from './core-memory-store.js';
import type { CoreMemory } from './core-memory-store.js';
import { MemoryStore, getMemoryStore } from './memory-store.js';
import type { PromptHistory, HotEntity } from './memory-store.js';
// =============================================================================
// Types
// =============================================================================
/** Options for unified search */
export interface UnifiedSearchOptions {
/** Maximum number of results to return (default: 20) */
limit?: number;
/** Minimum relevance score threshold (default: 0.0) */
minScore?: number;
/** Filter by category */
category?: VectorCategory;
/** Vector search top-k (default: 30, fetched internally for fusion) */
vectorTopK?: number;
/** FTS search limit (default: 30, fetched internally for fusion) */
ftsLimit?: number;
}
/** A unified search result item */
export interface UnifiedSearchResult {
  /** Unique identifier for the source item */
  source_id: string;
  /** Source type: core_memory, cli_history, workflow, entity, pattern */
  source_type: string;
  /**
   * Fused RRF relevance score (higher is better). Raw RRF sums are small —
   * bounded by sum(weights)/(RRF_K+1), well below 1 — not a 0..1 range.
   */
  score: number;
  /** Text content (snippet or full) */
  content: string;
  /** Category of the result */
  category: string;
  /** Which ranking sources contributed to this result */
  rank_sources: {
    vector_rank?: number;
    vector_score?: number;
    fts_rank?: number;
    heat_score?: number;
  };
}
/** Aggregated statistics from all stores + vector index */
export interface UnifiedMemoryStats {
core_memories: {
total: number;
archived: number;
};
stage1_outputs: number;
entities: number;
prompts: number;
conversations: number;
vector_index: {
available: boolean;
total_chunks: number;
hnsw_available: boolean;
hnsw_count: number;
dimension: number;
categories?: Record<string, number>;
};
}
/** KNN recommendation result */
export interface RecommendationResult {
source_id: string;
source_type: string;
score: number;
content: string;
category: string;
}
// =============================================================================
// RRF Constants
// =============================================================================
/** RRF smoothing constant (standard value from the original RRF paper) */
const RRF_K = 60;
/** Fusion weights */
const WEIGHT_VECTOR = 0.6;
const WEIGHT_FTS = 0.3;
const WEIGHT_HEAT = 0.1;
// =============================================================================
// UnifiedMemoryService
// =============================================================================
/**
* Unified Memory Service providing cross-store search and recommendations.
*
* Combines vector similarity, full-text search, and entity heat scores
* using Reciprocal Rank Fusion (RRF) for result ranking.
*/
export class UnifiedMemoryService {
private projectPath: string;
private vectorIndex: UnifiedVectorIndex | null = null;
private coreMemoryStore: CoreMemoryStore;
private memoryStore: MemoryStore;
constructor(projectPath: string) {
this.projectPath = projectPath;
this.coreMemoryStore = getCoreMemoryStore(projectPath);
this.memoryStore = getMemoryStore(projectPath);
if (isUnifiedEmbedderAvailable()) {
this.vectorIndex = new UnifiedVectorIndex(projectPath);
}
}
// ==========================================================================
// Search
// ==========================================================================
/**
* Unified search across all memory stores.
*
* Pipeline:
* 1. Vector search via UnifiedVectorIndex (semantic similarity)
* 2. FTS5 search via MemoryStore.searchPrompts (keyword matching)
* 3. Heat boost via entity heat scores
* 4. RRF fusion to combine ranked lists
*
* @param query - Natural language search query
* @param options - Search options
* @returns Fused search results sorted by relevance
*/
  async search(
    query: string,
    options: UnifiedSearchOptions = {}
  ): Promise<UnifiedSearchResult[]> {
    const {
      limit = 20,
      minScore = 0.0,
      category,
      vectorTopK = 30,
      ftsLimit = 30,
    } = options;
    // Run vector search and FTS search in parallel (heat fetch is sync but
    // wrapped async; all three are best-effort and return [] on failure)
    const [vectorResults, ftsResults, hotEntities] = await Promise.all([
      this.runVectorSearch(query, vectorTopK, category),
      this.runFtsSearch(query, ftsLimit),
      this.getHeatScores(),
    ]);
    // Build heat score lookup
    const heatMap = new Map<string, number>();
    for (const entity of hotEntities) {
      // Use normalized_value as key for heat lookup
      heatMap.set(entity.normalized_value, entity.stats.heat_score);
    }
    // Collect all unique source_ids from both result sets
    const allSourceIds = new Set<string>();
    const vectorRankMap = new Map<string, { rank: number; score: number; match: VectorSearchMatch }>();
    const ftsRankMap = new Map<string, { rank: number; item: PromptHistory }>();
    // Build vector rank map (rank = 1-based position in the result list)
    for (let i = 0; i < vectorResults.length; i++) {
      const match = vectorResults[i];
      const id = match.source_id;
      allSourceIds.add(id);
      vectorRankMap.set(id, { rank: i + 1, score: match.score, match });
    }
    // Build FTS rank map
    // NOTE(review): FTS entries are keyed by session_id while vector entries
    // use source_id; fusion assumes these identifiers share one namespace —
    // confirm against how cli_history chunks are indexed.
    for (let i = 0; i < ftsResults.length; i++) {
      const item = ftsResults[i];
      const id = item.session_id;
      allSourceIds.add(id);
      ftsRankMap.set(id, { rank: i + 1, item });
    }
    // Calculate RRF score for each unique source_id
    const results: UnifiedSearchResult[] = [];
    for (const sourceId of allSourceIds) {
      const vectorEntry = vectorRankMap.get(sourceId);
      const ftsEntry = ftsRankMap.get(sourceId);
      // RRF: score = sum(weight_i / (k + rank_i))
      let rrfScore = 0;
      const rankSources: UnifiedSearchResult['rank_sources'] = {};
      // Vector component
      if (vectorEntry) {
        rrfScore += WEIGHT_VECTOR / (RRF_K + vectorEntry.rank);
        rankSources.vector_rank = vectorEntry.rank;
        rankSources.vector_score = vectorEntry.score;
      }
      // FTS component
      if (ftsEntry) {
        rrfScore += WEIGHT_FTS / (RRF_K + ftsEntry.rank);
        rankSources.fts_rank = ftsEntry.rank;
      }
      // Heat component (boost based on entity heat)
      const heatScore = this.lookupHeatScore(sourceId, heatMap);
      if (heatScore > 0) {
        // Normalize heat score to a rank-like value (1 = hottest)
        // Use inverse: higher heat = lower rank number = higher contribution
        const heatRank = Math.max(1, Math.ceil(100 / (1 + heatScore)));
        rrfScore += WEIGHT_HEAT / (RRF_K + heatRank);
        rankSources.heat_score = heatScore;
      }
      // RRF sums are tiny (max ~ (0.6+0.3+0.1)/(RRF_K+1) ≈ 0.016); any
      // nonzero minScore must be chosen on that scale, not 0..1.
      if (rrfScore < minScore) continue;
      // Build result entry; prefer vector match content, fall back to FTS
      let content = '';
      let sourceType = '';
      let resultCategory = '';
      if (vectorEntry) {
        content = vectorEntry.match.content;
        sourceType = vectorEntry.match.source_type;
        resultCategory = vectorEntry.match.category;
      } else if (ftsEntry) {
        content = ftsEntry.item.prompt_text || ftsEntry.item.context_summary || '';
        sourceType = 'cli_history';
        resultCategory = 'cli_history';
      }
      results.push({
        source_id: sourceId,
        source_type: sourceType,
        score: rrfScore,
        content,
        category: resultCategory,
        rank_sources: rankSources,
      });
    }
    // Sort by RRF score descending, take top `limit`
    results.sort((a, b) => b.score - a.score);
    return results.slice(0, limit);
  }
// ==========================================================================
// Recommendations
// ==========================================================================
/**
* Get recommendations based on a memory's vector neighbors (KNN).
*
* Fetches the content of the given memory, then runs a vector search
* to find similar content across all stores.
*
* @param memoryId - Core memory ID (CMEM-*)
* @param limit - Number of recommendations (default: 5)
* @returns Recommended items sorted by similarity
*/
async getRecommendations(
memoryId: string,
limit: number = 5
): Promise<RecommendationResult[]> {
// Get the memory content
const memory = this.coreMemoryStore.getMemory(memoryId);
if (!memory) {
return [];
}
if (!this.vectorIndex) {
return [];
}
// Use memory content as query for KNN search
// Request extra results so we can filter out self
const searchResult = await this.vectorIndex.search(memory.content, {
topK: limit + 5,
minScore: 0.3,
});
if (!searchResult.success) {
return [];
}
// Filter out self and map to recommendations
const recommendations: RecommendationResult[] = [];
for (const match of searchResult.matches) {
// Skip the source memory itself
if (match.source_id === memoryId) continue;
recommendations.push({
source_id: match.source_id,
source_type: match.source_type,
score: match.score,
content: match.content,
category: match.category,
});
if (recommendations.length >= limit) break;
}
return recommendations;
}
// ==========================================================================
// Statistics
// ==========================================================================
  /**
   * Get aggregated statistics from all stores and the vector index.
   *
   * @returns Unified stats across core memories, V2 outputs, entities, prompts, and vectors
   */
  async getStats(): Promise<UnifiedMemoryStats> {
    // Get core memory stats
    const allMemories = this.coreMemoryStore.getMemories({ limit: 100000 });
    const archivedMemories = allMemories.filter(m => m.archived);
    const stage1Count = this.coreMemoryStore.countStage1Outputs();
    // Get memory store stats (entities, prompts, conversations)
    // NOTE(review): reaches into MemoryStore's private SQLite handle via
    // `as any`; prefer exposing count methods on MemoryStore itself.
    const db = (this.memoryStore as any).db;
    let entityCount = 0;
    let promptCount = 0;
    let conversationCount = 0;
    try {
      entityCount = (db.prepare('SELECT COUNT(*) as count FROM entities').get() as { count: number }).count;
    } catch { /* table may not exist */ }
    try {
      promptCount = (db.prepare('SELECT COUNT(*) as count FROM prompt_history').get() as { count: number }).count;
    } catch { /* table may not exist */ }
    try {
      conversationCount = (db.prepare('SELECT COUNT(*) as count FROM conversations').get() as { count: number }).count;
    } catch { /* table may not exist */ }
    // Get vector index status (zeroed defaults when index is unavailable)
    let vectorStatus: VectorIndexStatus = {
      success: false,
      total_chunks: 0,
      hnsw_available: false,
      hnsw_count: 0,
      dimension: 0,
    };
    if (this.vectorIndex) {
      try {
        vectorStatus = await this.vectorIndex.getStatus();
      } catch {
        // Vector index not available
      }
    }
    return {
      core_memories: {
        total: allMemories.length,
        archived: archivedMemories.length,
      },
      stage1_outputs: stage1Count,
      entities: entityCount,
      prompts: promptCount,
      conversations: conversationCount,
      vector_index: {
        available: vectorStatus.success,
        total_chunks: vectorStatus.total_chunks,
        hnsw_available: vectorStatus.hnsw_available,
        hnsw_count: vectorStatus.hnsw_count,
        dimension: vectorStatus.dimension,
        categories: vectorStatus.categories,
      },
    };
  }
// ==========================================================================
// Internal helpers
// ==========================================================================
/**
* Run vector search via UnifiedVectorIndex.
* Returns empty array if vector index is not available.
*/
private async runVectorSearch(
query: string,
topK: number,
category?: VectorCategory
): Promise<VectorSearchMatch[]> {
if (!this.vectorIndex) {
return [];
}
try {
const result = await this.vectorIndex.search(query, {
topK,
minScore: 0.1,
category,
});
if (!result.success) {
return [];
}
return result.matches;
} catch {
return [];
}
}
/**
* Run FTS5 full-text search via MemoryStore.searchPrompts.
* Returns empty array on error.
*/
private async runFtsSearch(
query: string,
limit: number
): Promise<PromptHistory[]> {
try {
// FTS5 requires sanitized query (no special characters)
const sanitized = this.sanitizeFtsQuery(query);
if (!sanitized) return [];
return this.memoryStore.searchPrompts(sanitized, limit);
} catch {
return [];
}
}
/**
* Get hot entities for heat-based scoring.
*/
private async getHeatScores(): Promise<HotEntity[]> {
try {
return this.memoryStore.getHotEntities(50);
} catch {
return [];
}
}
/**
* Look up heat score for a source ID.
* Checks if any entity's normalized_value matches the source_id.
*/
private lookupHeatScore(
sourceId: string,
heatMap: Map<string, number>
): number {
// Direct match
if (heatMap.has(sourceId)) {
return heatMap.get(sourceId)!;
}
// Check if source_id is a substring of any entity value (file paths)
for (const [key, score] of heatMap) {
if (sourceId.includes(key) || key.includes(sourceId)) {
return score;
}
}
return 0;
}
/**
* Sanitize a query string for FTS5 MATCH syntax.
* Removes special characters that would cause FTS5 parse errors.
*/
private sanitizeFtsQuery(query: string): string {
// Remove FTS5 special operators and punctuation
return query
.replace(/[*":(){}[\]^~\\/<>!@#$%&=+|;,.'`]/g, ' ')
.replace(/\s+/g, ' ')
.trim();
}
}

View File

@@ -0,0 +1,474 @@
/**
* Unified Vector Index - TypeScript bridge to unified_memory_embedder.py
*
* Provides HNSW-backed vector indexing and search for all memory content
* (core_memory, cli_history, workflow, entity, pattern) via CodexLens VectorStore.
*
* Features:
* - JSON stdin/stdout protocol to Python embedder
* - Content chunking (paragraph -> sentence splitting, CHUNK_SIZE=1500, OVERLAP=200)
* - Batch embedding via CodexLens EmbedderFactory
* - HNSW approximate nearest neighbor search (sub-10ms for 1000 chunks)
* - Category-based filtering
*/
import { spawn } from 'child_process';
import { join, dirname } from 'path';
import { existsSync } from 'fs';
import { fileURLToPath } from 'url';
import { getCodexLensPython } from '../utils/codexlens-path.js';
import { StoragePaths, ensureStorageDir } from '../config/storage-paths.js';
// Get directory of this module
const __filename = fileURLToPath(import.meta.url);
const __dirname = dirname(__filename);
// Venv python path (reuse CodexLens venv)
const VENV_PYTHON = getCodexLensPython();
// Script path
const EMBEDDER_SCRIPT = join(__dirname, '..', '..', 'scripts', 'unified_memory_embedder.py');
// Chunking constants (match existing core-memory-store.ts)
const CHUNK_SIZE = 1500;
const OVERLAP = 200;
// =============================================================================
// Types
// =============================================================================
/** Valid source types for vector content */
export type SourceType = 'core_memory' | 'workflow' | 'cli_history';
/** Valid category values for vector filtering */
export type VectorCategory = 'core_memory' | 'cli_history' | 'workflow' | 'entity' | 'pattern';
/** Metadata attached to each chunk in the vector store */
export interface ChunkMetadata {
/** Source identifier (e.g., memory ID, session ID) */
source_id: string;
/** Source type */
source_type: SourceType;
/** Category for filtering */
category: VectorCategory;
/** Chunk index within the source */
chunk_index?: number;
/** Additional metadata */
[key: string]: unknown;
}
/** A chunk to be embedded and indexed */
export interface VectorChunk {
/** Text content */
content: string;
/** Source identifier */
source_id: string;
/** Source type */
source_type: SourceType;
/** Category for filtering */
category: VectorCategory;
/** Chunk index */
chunk_index: number;
/** Additional metadata */
metadata?: Record<string, unknown>;
}
/** Result of an embed operation */
export interface EmbedResult {
success: boolean;
chunks_processed: number;
chunks_failed: number;
elapsed_time: number;
error?: string;
}
/** A single search match */
export interface VectorSearchMatch {
content: string;
score: number;
source_id: string;
source_type: string;
chunk_index: number;
category: string;
metadata: Record<string, unknown>;
}
/** Result of a search operation */
export interface VectorSearchResult {
success: boolean;
matches: VectorSearchMatch[];
elapsed_time?: number;
total_searched?: number;
error?: string;
}
/** Search options */
export interface VectorSearchOptions {
topK?: number;
minScore?: number;
category?: VectorCategory;
}
/** Index status information */
export interface VectorIndexStatus {
success: boolean;
total_chunks: number;
hnsw_available: boolean;
hnsw_count: number;
dimension: number;
categories?: Record<string, number>;
model_config?: {
backend: string;
profile: string;
dimension: number;
max_tokens: number;
};
error?: string;
}
/** Reindex result */
export interface ReindexResult {
success: boolean;
hnsw_count?: number;
elapsed_time?: number;
error?: string;
}
// =============================================================================
// Python Bridge
// =============================================================================
/**
 * Check if the unified embedder is available.
 * Requires both the shared CodexLens venv python and the embedder script
 * to exist on disk.
 */
export function isUnifiedEmbedderAvailable(): boolean {
  return existsSync(VENV_PYTHON) && existsSync(EMBEDDER_SCRIPT);
}
/**
 * Run the embedder Python script with a JSON stdin/stdout protocol.
 *
 * Spawns VENV_PYTHON on EMBEDDER_SCRIPT, writes `request` as one JSON
 * document to stdin, collects stdout/stderr, and resolves with the parsed
 * JSON from stdout on a zero exit code.
 *
 * @param request - JSON request object to send via stdin
 * @param timeout - Timeout in milliseconds (default: 5 minutes)
 * @returns Parsed JSON response
 * @throws Error when the embedder is unavailable, the process exits
 *   non-zero or with empty stdout, stdout is not valid JSON, or spawning fails
 */
function runPython<T>(request: Record<string, unknown>, timeout: number = 300000): Promise<T> {
  return new Promise((resolve, reject) => {
    if (!isUnifiedEmbedderAvailable()) {
      reject(
        new Error(
          'Unified embedder not available. Ensure CodexLens venv exists at ~/.codexlens/venv'
        )
      );
      return;
    }
    const child = spawn(VENV_PYTHON, [EMBEDDER_SCRIPT], {
      stdio: ['pipe', 'pipe', 'pipe'],
      timeout,
    });
    let stdout = '';
    let stderr = '';
    child.stdout.on('data', (data) => {
      stdout += data.toString();
    });
    child.stderr.on('data', (data) => {
      stderr += data.toString();
    });
    child.on('close', (code) => {
      // Success requires a zero exit code AND non-empty stdout.
      if (code === 0 && stdout.trim()) {
        try {
          resolve(JSON.parse(stdout.trim()) as T);
        } catch {
          reject(new Error(`Failed to parse Python output: ${stdout.substring(0, 500)}`));
        }
      } else {
        reject(new Error(`Python script failed (exit code ${code}): ${stderr || stdout}`));
      }
    });
    child.on('error', (err) => {
      // NOTE(review): with spawn's `timeout` option Node kills the child and
      // 'close' fires with a signal; it is unclear whether this ETIMEDOUT
      // branch is ever reached — confirm against the Node child_process docs.
      if ((err as NodeJS.ErrnoException).code === 'ETIMEDOUT') {
        reject(new Error('Python script timed out'));
      } else {
        reject(new Error(`Failed to spawn Python: ${err.message}`));
      }
    });
    // Write JSON request to stdin and close
    const jsonInput = JSON.stringify(request);
    child.stdin.write(jsonInput);
    child.stdin.end();
  });
}
// =============================================================================
// Content Chunking
// =============================================================================
/**
 * Split text into embedding-sized chunks.
 *
 * Strategy (mirrors core-memory-store.ts): accumulate paragraphs (split on
 * blank lines) up to CHUNK_SIZE characters, carrying an OVERLAP-character
 * tail between chunks; any chunk still over CHUNK_SIZE is re-split on
 * sentence boundaries ('. ') with the same overlap. Falls back to the whole
 * input when nothing was produced.
 *
 * @param content - Text content to chunk
 * @returns Array of chunk strings
 */
export function chunkContent(content: string): string[] {
  // Pass 1: paragraph-level accumulation with overlap carry-over.
  const paragraphChunks: string[] = [];
  let acc = '';
  for (const para of content.split(/\n\n+/)) {
    const wouldOverflow = acc.length + para.length > CHUNK_SIZE && acc.length > 0;
    if (wouldOverflow) {
      paragraphChunks.push(acc.trim());
      // Seed the next chunk with the tail of the previous one.
      acc = acc.slice(-OVERLAP) + '\n\n' + para;
    } else {
      acc = acc ? `${acc}\n\n${para}` : para;
    }
  }
  if (acc.trim()) {
    paragraphChunks.push(acc.trim());
  }
  // Pass 2: sentence-level re-split for any chunk still over the limit.
  const result: string[] = [];
  for (const chunk of paragraphChunks) {
    if (chunk.length <= CHUNK_SIZE) {
      result.push(chunk);
      continue;
    }
    let sentAcc = '';
    for (const sentence of chunk.split(/\. +/)) {
      const piece = sentence + '. ';
      if (sentAcc.length + piece.length > CHUNK_SIZE && sentAcc.length > 0) {
        result.push(sentAcc.trim());
        sentAcc = sentAcc.slice(-OVERLAP) + piece;
      } else {
        sentAcc += piece;
      }
    }
    if (sentAcc.trim()) {
      result.push(sentAcc.trim());
    }
  }
  return result.length > 0 ? result : [content];
}
// =============================================================================
// UnifiedVectorIndex Class
// =============================================================================
/**
* Unified vector index backed by CodexLens VectorStore (HNSW).
*
* Provides content chunking, embedding, storage, and search for all
* memory content types through a single interface.
*/
export class UnifiedVectorIndex {
private storePath: string;
/**
* Create a UnifiedVectorIndex for a project.
*
* @param projectPath - Project root path (used to resolve storage location)
*/
constructor(projectPath: string) {
const paths = StoragePaths.project(projectPath);
this.storePath = paths.unifiedVectors.root;
ensureStorageDir(this.storePath);
}
/**
* Index content by chunking, embedding, and storing in VectorStore.
*
* @param content - Text content to index
* @param metadata - Metadata for all chunks (source_id, source_type, category)
* @returns Embed result
*/
async indexContent(
content: string,
metadata: ChunkMetadata
): Promise<EmbedResult> {
if (!content.trim()) {
return {
success: true,
chunks_processed: 0,
chunks_failed: 0,
elapsed_time: 0,
};
}
// Chunk content
const textChunks = chunkContent(content);
// Build chunk objects for Python
const chunks: VectorChunk[] = textChunks.map((text, index) => ({
content: text,
source_id: metadata.source_id,
source_type: metadata.source_type,
category: metadata.category,
chunk_index: metadata.chunk_index != null ? metadata.chunk_index + index : index,
metadata: { ...metadata },
}));
try {
const result = await runPython<EmbedResult>({
operation: 'embed',
store_path: this.storePath,
chunks,
batch_size: 8,
});
return result;
} catch (err) {
return {
success: false,
chunks_processed: 0,
chunks_failed: textChunks.length,
elapsed_time: 0,
error: (err as Error).message,
};
}
}
/**
* Search the vector index using semantic similarity.
*
* @param query - Natural language search query
* @param options - Search options (topK, minScore, category)
* @returns Search results sorted by relevance
*/
async search(
query: string,
options: VectorSearchOptions = {}
): Promise<VectorSearchResult> {
const { topK = 10, minScore = 0.3, category } = options;
try {
const result = await runPython<VectorSearchResult>({
operation: 'search',
store_path: this.storePath,
query,
top_k: topK,
min_score: minScore,
category: category || null,
});
return result;
} catch (err) {
return {
success: false,
matches: [],
error: (err as Error).message,
};
}
}
/**
* Search the vector index using a pre-computed embedding vector.
* Bypasses text embedding, directly querying HNSW with a raw vector.
*
* @param vector - Pre-computed embedding vector (array of floats)
* @param options - Search options (topK, minScore, category)
* @returns Search results sorted by relevance
*/
async searchByVector(
vector: number[],
options: VectorSearchOptions = {}
): Promise<VectorSearchResult> {
const { topK = 10, minScore = 0.3, category } = options;
try {
const result = await runPython<VectorSearchResult>({
operation: 'search_by_vector',
store_path: this.storePath,
vector,
top_k: topK,
min_score: minScore,
category: category || null,
});
return result;
} catch (err) {
return {
success: false,
matches: [],
error: (err as Error).message,
};
}
}
/**
* Rebuild the HNSW index from scratch.
*
* @returns Reindex result
*/
async reindexAll(): Promise<ReindexResult> {
try {
const result = await runPython<ReindexResult>({
operation: 'reindex',
store_path: this.storePath,
});
return result;
} catch (err) {
return {
success: false,
error: (err as Error).message,
};
}
}
/**
* Get the current status of the vector index.
*
* @returns Index status including chunk counts, HNSW availability, dimension
*/
async getStatus(): Promise<VectorIndexStatus> {
try {
const result = await runPython<VectorIndexStatus>({
operation: 'status',
store_path: this.storePath,
});
return result;
} catch (err) {
return {
success: false,
total_chunks: 0,
hnsw_available: false,
hnsw_count: 0,
dimension: 0,
error: (err as Error).message,
};
}
}
}

View File

@@ -17,6 +17,7 @@ import type {
} from '../core/a2ui/A2UITypes.js'; } from '../core/a2ui/A2UITypes.js';
import http from 'http'; import http from 'http';
import { a2uiWebSocketHandler } from '../core/a2ui/A2UIWebSocketHandler.js'; import { a2uiWebSocketHandler } from '../core/a2ui/A2UIWebSocketHandler.js';
import { remoteNotificationService } from '../core/services/remote-notification-service.js';
const DASHBOARD_PORT = Number(process.env.CCW_PORT || 3456); const DASHBOARD_PORT = Number(process.env.CCW_PORT || 3456);
const POLL_INTERVAL_MS = 1000; const POLL_INTERVAL_MS = 1000;
@@ -466,6 +467,14 @@ export async function execute(params: AskQuestionParams): Promise<ToolResult<Ask
const a2uiSurface = generateQuestionSurface(question, surfaceId); const a2uiSurface = generateQuestionSurface(question, surfaceId);
const sentCount = a2uiWebSocketHandler.sendSurface(a2uiSurface.surfaceUpdate); const sentCount = a2uiWebSocketHandler.sendSurface(a2uiSurface.surfaceUpdate);
// Trigger remote notification for ask-user-question event (if enabled)
if (remoteNotificationService.shouldNotify('ask-user-question')) {
remoteNotificationService.sendNotification('ask-user-question', {
sessionId: surfaceId,
questionText: question.title,
});
}
// If no local WS clients, start HTTP polling for answer from Dashboard // If no local WS clients, start HTTP polling for answer from Dashboard
if (sentCount === 0) { if (sentCount === 0) {
startAnswerPolling(question.id); startAnswerPolling(question.id);
@@ -1064,6 +1073,15 @@ async function executeSimpleFormat(
// Send the surface // Send the surface
const sentCount = a2uiWebSocketHandler.sendSurface(surfaceUpdate); const sentCount = a2uiWebSocketHandler.sendSurface(surfaceUpdate);
// Trigger remote notification for ask-user-question event (if enabled)
if (remoteNotificationService.shouldNotify('ask-user-question')) {
const questionTexts = questions.map(q => q.question).join('\n');
remoteNotificationService.sendNotification('ask-user-question', {
sessionId: compositeId,
questionText: questionTexts,
});
}
// If no local WS clients, start HTTP polling for answer from Dashboard // If no local WS clients, start HTTP polling for answer from Dashboard
if (sentCount === 0) { if (sentCount === 0) {
startAnswerPolling(compositeId, true); startAnswerPolling(compositeId, true);

View File

@@ -0,0 +1,227 @@
// ========================================
// Remote Notification Types
// ========================================
// Type definitions for remote notification system
// Supports Discord, Telegram, and Generic Webhook platforms
/**
* Supported notification platforms
*/
export type NotificationPlatform = 'discord' | 'telegram' | 'webhook';
/**
* Event types that can trigger notifications
*/
export type NotificationEventType =
| 'ask-user-question' // AskUserQuestion triggered
| 'session-start' // CLI session started
| 'session-end' // CLI session ended
| 'task-completed' // Task completed successfully
| 'task-failed'; // Task failed
/**
* Discord platform configuration
*/
export interface DiscordConfig {
/** Whether Discord notifications are enabled */
enabled: boolean;
/** Discord webhook URL */
webhookUrl: string;
/** Optional custom username for the webhook */
username?: string;
/** Optional avatar URL for the webhook */
avatarUrl?: string;
}
/**
* Telegram platform configuration
*/
export interface TelegramConfig {
/** Whether Telegram notifications are enabled */
enabled: boolean;
/** Telegram bot token */
botToken: string;
/** Telegram chat ID (user or group) */
chatId: string;
/** Optional parse mode (HTML, Markdown, MarkdownV2) */
parseMode?: 'HTML' | 'Markdown' | 'MarkdownV2';
}
/**
* Generic Webhook platform configuration
*/
export interface WebhookConfig {
/** Whether webhook notifications are enabled */
enabled: boolean;
/** Webhook URL */
url: string;
/** HTTP method (POST or PUT) */
method: 'POST' | 'PUT';
/** Custom headers */
headers?: Record<string, string>;
/** Request timeout in milliseconds */
timeout?: number;
}
/**
* Event configuration - maps events to platforms
*/
export interface EventConfig {
/** Event type */
event: NotificationEventType;
/** Platforms to notify for this event */
platforms: NotificationPlatform[];
/** Whether this event's notifications are enabled */
enabled: boolean;
}
/**
* Full remote notification configuration
*/
export interface RemoteNotificationConfig {
/** Master switch for all remote notifications */
enabled: boolean;
/** Platform-specific configurations */
platforms: {
discord?: DiscordConfig;
telegram?: TelegramConfig;
webhook?: WebhookConfig;
};
/** Event-to-platform mappings */
events: EventConfig[];
/** Global timeout for all notification requests (ms) */
timeout: number;
}
/**
* Context passed when sending a notification
*/
export interface NotificationContext {
/** Event type that triggered the notification */
eventType: NotificationEventType;
/** Session ID if applicable */
sessionId?: string;
/** Question text for ask-user-question events */
questionText?: string;
/** Task description for task events */
taskDescription?: string;
/** Error message for task-failed events */
errorMessage?: string;
/** Timestamp of the event */
timestamp: string;
/** Additional metadata */
metadata?: Record<string, unknown>;
}
/**
* Result of a single platform notification attempt
*/
export interface PlatformNotificationResult {
/** Platform that was notified */
platform: NotificationPlatform;
/** Whether the notification succeeded */
success: boolean;
/** Error message if failed */
error?: string;
/** Response time in milliseconds */
responseTime?: number;
}
/**
* Result of sending notifications to all configured platforms
*/
export interface NotificationDispatchResult {
/** Whether at least one notification succeeded */
success: boolean;
/** Results for each platform */
results: PlatformNotificationResult[];
/** Total dispatch time in milliseconds */
totalTime: number;
}
/**
* Test notification request
*/
export interface TestNotificationRequest {
/** Platform to test */
platform: NotificationPlatform;
/** Platform configuration to test (temporary, not saved) */
config: DiscordConfig | TelegramConfig | WebhookConfig;
}
/**
* Test notification result
*/
export interface TestNotificationResult {
/** Whether the test succeeded */
success: boolean;
/** Error message if failed */
error?: string;
/** Response time in milliseconds */
responseTime?: number;
}
/**
* Default configuration values
*/
export const DEFAULT_REMOTE_NOTIFICATION_CONFIG: RemoteNotificationConfig = {
enabled: false,
platforms: {},
events: [
{ event: 'ask-user-question', platforms: ['discord', 'telegram'], enabled: true },
{ event: 'session-start', platforms: [], enabled: false },
{ event: 'session-end', platforms: [], enabled: false },
{ event: 'task-completed', platforms: [], enabled: false },
{ event: 'task-failed', platforms: ['discord', 'telegram'], enabled: true },
],
timeout: 10000, // 10 seconds
};
/**
 * Mask sensitive fields in config for API responses.
 * Discord webhook URLs and Telegram bot tokens are redacted; the generic
 * webhook URL is intentionally left readable for display.
 * NOTE(review): webhook.headers may carry credentials (e.g. Authorization)
 * and is currently returned unmasked — confirm this is acceptable.
 */
export function maskSensitiveConfig(config: RemoteNotificationConfig): RemoteNotificationConfig {
  const { discord, telegram, webhook } = config.platforms;
  return {
    ...config,
    platforms: {
      discord: discord
        ? { ...discord, webhookUrl: maskWebhookUrl(discord.webhookUrl) }
        : undefined,
      telegram: telegram
        ? { ...telegram, botToken: maskToken(telegram.botToken) }
        : undefined,
      webhook: webhook ? { ...webhook } : undefined,
    },
  };
}
/**
 * Mask webhook URL for display (show only domain and last part).
 *
 * Keeps the origin plus the first 4 characters of the final path segment
 * when that segment is long enough to stay unidentifiable; otherwise the
 * whole path collapses to "****". Unparseable URLs become "****".
 */
function maskWebhookUrl(url: string): string {
  if (!url) return '';
  try {
    const { origin, pathname } = new URL(url);
    const segments = pathname.split('/');
    const tail = segments[segments.length - 1];
    if (tail && tail.length > 8) {
      return `${origin}/.../${tail.slice(0, 4)}****`;
    }
    return `${origin}/****`;
  } catch {
    return '****';
  }
}
/**
 * Mask bot token for display.
 *
 * Shows the first 6 and last 4 characters with the middle replaced by
 * "****". Tokens of 10 characters or fewer are fully masked: with the
 * previous `length < 10` check, a 10-character token had every character
 * revealed (6 + 4 === 10), leaking the entire secret.
 */
function maskToken(token: string): string {
  if (!token || token.length <= 10) return '****';
  return `${token.slice(0, 6)}****${token.slice(-4)}`;
}

75
ccw/src/types/util.ts Normal file
View File

@@ -0,0 +1,75 @@
// ========================================
// Utility Types
// ========================================
// Common utility type definitions
/**
 * Deep partial type - makes all nested properties optional
 */
export type DeepPartial<T> = T extends object
  ? {
      [P in keyof T]?: DeepPartial<T[P]>;
    }
  : T;
/**
 * Make specific keys optional
 */
export type PartialBy<T, K extends keyof T> = Omit<T, K> & Partial<Pick<T, K>>;
/**
 * Make specific keys required
 */
export type RequiredBy<T, K extends keyof T> = Omit<T, K> & Required<Pick<T, K>>;
/**
 * Extract function parameter types
 *
 * NOTE(review): shadows TypeScript's built-in `Parameters<T>` utility type
 * inside this module — confirm the shadowing is intentional.
 */
export type Parameters<T> = T extends (...args: infer P) => unknown ? P : never;
/**
 * Extract function return type
 *
 * NOTE(review): shadows TypeScript's built-in `ReturnType<T>` utility type
 * inside this module — confirm the shadowing is intentional.
 */
export type ReturnType<T> = T extends (...args: unknown[]) => infer R ? R : never;
// ========================================
// Utility Functions
// ========================================
/**
 * Deep merge utility for configuration updates.
 * Recursively merges source into target, preserving nested objects.
 *
 * Behavior:
 * - Plain nested objects are merged recursively.
 * - Arrays and non-object values from source replace the target value.
 * - `undefined` values in source are ignored (target value is kept).
 * - `null` in source overwrites the target value.
 *
 * Security: keys that could mutate the prototype chain (`__proto__`,
 * `constructor`, `prototype`) are skipped so that merging untrusted
 * config input cannot cause prototype pollution.
 */
export function deepMerge<T extends Record<string, unknown>>(
  target: T,
  source: DeepPartial<T>
): T {
  const result = { ...target } as T;
  for (const key in source) {
    if (!Object.prototype.hasOwnProperty.call(source, key)) continue;
    // Guard against prototype pollution from externally supplied objects.
    if (key === '__proto__' || key === 'constructor' || key === 'prototype') continue;
    const sourceValue = source[key];
    const targetValue = target[key];
    const bothMergeable =
      sourceValue !== undefined &&
      sourceValue !== null &&
      typeof sourceValue === 'object' &&
      !Array.isArray(sourceValue) &&
      targetValue !== undefined &&
      targetValue !== null &&
      typeof targetValue === 'object' &&
      !Array.isArray(targetValue);
    if (bothMergeable) {
      (result as Record<string, unknown>)[key] = deepMerge(
        targetValue as Record<string, unknown>,
        sourceValue as DeepPartial<Record<string, unknown>>
      );
    } else if (sourceValue !== undefined) {
      (result as Record<string, unknown>)[key] = sourceValue;
    }
  }
  return result;
}

View File

@@ -22,6 +22,9 @@ dependencies = [
"tree-sitter-typescript>=0.23", "tree-sitter-typescript>=0.23",
"pathspec>=0.11", "pathspec>=0.11",
"watchdog>=3.0", "watchdog>=3.0",
# ast-grep for pattern-based AST matching (PyO3 bindings)
# Note: May have compatibility issues with Python 3.13
"ast-grep-py>=0.3.0; python_version < '3.13'",
] ]
[project.optional-dependencies] [project.optional-dependencies]

View File

@@ -189,6 +189,9 @@ class Config:
api_batch_size_max: int = 2048 # Absolute upper limit for batch size api_batch_size_max: int = 2048 # Absolute upper limit for batch size
chars_per_token_estimate: int = 4 # Characters per token estimation ratio chars_per_token_estimate: int = 4 # Characters per token estimation ratio
# Parser configuration
use_astgrep: bool = False # Use ast-grep for Python relationship extraction (tree-sitter is default)
def __post_init__(self) -> None: def __post_init__(self) -> None:
try: try:
self.data_dir = self.data_dir.expanduser().resolve() self.data_dir = self.data_dir.expanduser().resolve()

View File

@@ -3,6 +3,12 @@
from __future__ import annotations from __future__ import annotations
from .factory import ParserFactory from .factory import ParserFactory
from .astgrep_binding import AstGrepBinding, is_astgrep_available, get_supported_languages
__all__ = ["ParserFactory"] __all__ = [
"ParserFactory",
"AstGrepBinding",
"is_astgrep_available",
"get_supported_languages",
]

View File

@@ -0,0 +1,320 @@
"""ast-grep based parser binding for CodexLens.
Provides AST-level pattern matching via ast-grep-py (PyO3 bindings).
Note: This module wraps the official ast-grep Python bindings for pattern-based
code analysis. If ast-grep-py is unavailable, the parser returns None gracefully.
Callers should use tree-sitter or regex-based fallbacks.
"""
from __future__ import annotations
import logging
from pathlib import Path
from typing import Dict, List, Optional, Tuple
# Import patterns from centralized definition (avoid duplication)
from codexlens.parsers.patterns.python import get_pattern, PYTHON_PATTERNS
# Graceful import pattern following treesitter_parser.py convention
try:
from ast_grep_py import SgNode, SgRoot
ASTGREP_AVAILABLE = True
except ImportError:
SgNode = None # type: ignore[assignment,misc]
SgRoot = None # type: ignore[assignment,misc]
ASTGREP_AVAILABLE = False
log = logging.getLogger(__name__)
class AstGrepBinding:
    """Wrapper for ast-grep-py bindings with CodexLens integration.

    Provides pattern-based AST matching for code relationship extraction.
    Uses declarative patterns with metavariables ($A, $$ARGS) for matching.

    Every query method degrades gracefully: when ast-grep-py is not
    installed or the language is unsupported, results are empty so callers
    can fall back to tree-sitter or regex parsing without special-casing.
    """

    # Language ID mapping to ast-grep language names
    LANGUAGE_MAP = {
        "python": "python",
        "javascript": "javascript",
        "typescript": "typescript",
        "tsx": "tsx",
    }

    def __init__(self, language_id: str, path: Optional[Path] = None) -> None:
        """Initialize ast-grep binding for a language.

        Args:
            language_id: Language identifier (python, javascript, typescript, tsx)
            path: Optional file path for language variant detection
        """
        self.language_id = language_id
        self.path = path
        self._language: Optional[str] = None
        # Parsed tree root; populated by parse(), reset to None on failure.
        self._root: Optional[SgRoot] = None  # type: ignore[valid-type]
        if ASTGREP_AVAILABLE:
            self._initialize_language()

    def _initialize_language(self) -> None:
        """Initialize ast-grep language setting."""
        # Detect TSX from file extension: a .tsx file must be parsed with the
        # tsx grammar, not plain typescript.
        if self.language_id == "typescript" and self.path is not None:
            if self.path.suffix.lower() == ".tsx":
                self._language = "tsx"
                return
        self._language = self.LANGUAGE_MAP.get(self.language_id)

    def is_available(self) -> bool:
        """Check if ast-grep binding is available and ready.

        Returns:
            True if ast-grep-py is installed and language is supported
        """
        return ASTGREP_AVAILABLE and self._language is not None

    def parse(self, source_code: str) -> bool:
        """Parse source code into ast-grep syntax tree.

        Args:
            source_code: Source code text to parse

        Returns:
            True if parsing succeeds, False otherwise
        """
        if not self.is_available() or SgRoot is None:
            return False
        try:
            self._root = SgRoot(source_code, self._language)  # type: ignore[misc]
            return True
        except (ValueError, TypeError, RuntimeError) as e:
            log.debug(f"ast-grep parse error: {e}")
            self._root = None
            return False

    def find_all(self, pattern: str) -> List[SgNode]:  # type: ignore[valid-type]
        """Find all matches for a pattern in the parsed source.

        Requires a prior successful parse() call.

        Args:
            pattern: ast-grep pattern string (e.g., "class $NAME($$$BASES) $$$BODY")

        Returns:
            List of matching SgNode objects, empty if no matches or not parsed
        """
        if not self.is_available() or self._root is None:
            return []
        try:
            root_node = self._root.root()
            # ast-grep-py 0.40+ requires dict config format
            config = {"rule": {"pattern": pattern}}
            return list(root_node.find_all(config))
        except (ValueError, TypeError, AttributeError) as e:
            log.debug(f"ast-grep find_all error: {e}")
            return []

    def find_inheritance(self) -> List[Dict[str, str]]:
        """Find all class inheritance declarations.

        Only implemented for Python; other languages return [].

        Returns:
            List of dicts with 'class_name' and 'bases' keys.
            NOTE(review): 'bases' relies on a $$$ metavariable capture, which
            newer ast-grep-py releases may leave empty — confirm against the
            installed binding version.
        """
        if self.language_id != "python":
            return []
        matches = self.find_all(get_pattern("class_with_bases"))
        results: List[Dict[str, str]] = []
        for node in matches:
            class_name = self._get_match(node, "NAME")
            if class_name:
                results.append({
                    "class_name": class_name,
                    "bases": self._get_match(node, "BASES"),  # Base classes text
                })
        return results

    def find_calls(self) -> List[Dict[str, str]]:
        """Find all function/method calls.

        Only implemented for Python; other languages return [].
        Calls whose base is `self` or `cls` are skipped as internal.

        Returns:
            List of dicts with 'function' and 'line' keys
        """
        if self.language_id != "python":
            return []
        matches = self.find_all(get_pattern("call"))
        results: List[Dict[str, str]] = []
        for node in matches:
            func_name = self._get_match(node, "FUNC")
            if func_name:
                # Skip self. and cls. prefixed calls
                base = func_name.split(".", 1)[0]
                if base not in {"self", "cls"}:
                    results.append({
                        "function": func_name,
                        "line": str(self._get_line_number(node)),
                    })
        return results

    def find_imports(self) -> List[Dict[str, str]]:
        """Find all import statements.

        Only implemented for Python; other languages return [].

        Returns:
            List of dicts with 'module', 'type', and 'line' keys; from-import
            entries additionally carry a 'names' key with the imported symbols.
        """
        if self.language_id != "python":
            return []
        results: List[Dict[str, str]] = []
        # Find 'import X' statements
        import_matches = self.find_all(get_pattern("import_stmt"))
        for node in import_matches:
            module = self._get_match(node, "MODULE")
            if module:
                results.append({
                    "module": module,
                    "type": "import",
                    "line": str(self._get_line_number(node)),
                })
        # Find 'from X import Y' statements
        from_matches = self.find_all(get_pattern("import_from"))
        for node in from_matches:
            module = self._get_match(node, "MODULE")
            names = self._get_match(node, "NAMES")
            if module:
                results.append({
                    "module": module,
                    "names": names or "",
                    "type": "from_import",
                    "line": str(self._get_line_number(node)),
                })
        return results

    def _get_match(self, node: SgNode, metavar: str) -> str:  # type: ignore[valid-type]
        """Extract matched metavariable value from node.

        Args:
            node: SgNode with match
            metavar: Metavariable name (without $ prefix)

        Returns:
            Matched text or empty string
        """
        if node is None:
            return ""
        try:
            match = node.get_match(metavar)
            if match is not None:
                return match.text()
        except (ValueError, AttributeError, KeyError) as e:
            log.debug(f"ast-grep get_match error for {metavar}: {e}")
        return ""

    def _get_node_text(self, node: SgNode) -> str:  # type: ignore[valid-type]
        """Get full text of a node.

        Args:
            node: SgNode to extract text from

        Returns:
            Node's text content, or empty string on failure
        """
        if node is None:
            return ""
        try:
            return node.text()
        except (ValueError, AttributeError) as e:
            log.debug(f"ast-grep get_node_text error: {e}")
        return ""

    def _get_line_number(self, node: SgNode) -> int:  # type: ignore[valid-type]
        """Get starting line number of a node.

        Args:
            node: SgNode to get line number for

        Returns:
            1-based line number, or 0 when unavailable
        """
        if node is None:
            return 0
        try:
            range_info = node.range()
            # ast-grep-py 0.40+ returns Range object with .start.line attribute
            if hasattr(range_info, 'start') and hasattr(range_info.start, 'line'):
                return range_info.start.line + 1  # Convert to 1-based
            # Fallback for string format "(0,0)-(1,8)"
            if isinstance(range_info, str) and range_info:
                start_part = range_info.split('-')[0].strip('()')
                start_line = int(start_part.split(',')[0])
                return start_line + 1
        except (ValueError, AttributeError, TypeError, IndexError) as e:
            log.debug(f"ast-grep get_line_number error: {e}")
        return 0

    def _get_line_range(self, node: SgNode) -> Tuple[int, int]:  # type: ignore[valid-type]
        """Get line range (start, end) of a node.

        Args:
            node: SgNode to get line range for

        Returns:
            Tuple of (start_line, end_line), both 1-based inclusive;
            (0, 0) when unavailable
        """
        if node is None:
            return (0, 0)
        try:
            range_info = node.range()
            # ast-grep-py 0.40+ returns Range object with .start.line and .end.line
            if hasattr(range_info, 'start') and hasattr(range_info, 'end'):
                start_line = getattr(range_info.start, 'line', 0)
                end_line = getattr(range_info.end, 'line', 0)
                return (start_line + 1, end_line + 1)  # Convert to 1-based
            # Fallback for string format "(0,0)-(1,8)"
            if isinstance(range_info, str) and range_info:
                parts = range_info.split('-')
                start_part = parts[0].strip('()')
                end_part = parts[1].strip('()')
                start_line = int(start_part.split(',')[0])
                end_line = int(end_part.split(',')[0])
                return (start_line + 1, end_line + 1)
        except (ValueError, AttributeError, TypeError, IndexError) as e:
            log.debug(f"ast-grep get_line_range error: {e}")
        return (0, 0)

    def get_language(self) -> Optional[str]:
        """Get the configured ast-grep language.

        Returns:
            Language string or None if not configured
        """
        return self._language
def is_astgrep_available() -> bool:
    """Report whether the optional ast-grep-py dependency is importable.

    Returns:
        True when the ast-grep bindings were imported successfully.
    """
    return bool(ASTGREP_AVAILABLE)
def get_supported_languages() -> List[str]:
    """List the language identifiers that ast-grep matching supports.

    Returns:
        Language identifiers, in the order declared in
        ``AstGrepBinding.LANGUAGE_MAP``.
    """
    return [lang for lang in AstGrepBinding.LANGUAGE_MAP]

View File

@@ -0,0 +1,931 @@
"""Ast-grep based processor for Python relationship extraction.
Provides pattern-based AST matching for extracting code relationships
(inheritance, calls, imports) from Python source code.
This processor wraps the ast-grep-py bindings and provides a higher-level
interface for relationship extraction, similar to TreeSitterSymbolParser.
Design Pattern:
- Follows TreeSitterSymbolParser class structure for consistency
- Uses declarative patterns defined in patterns/python/__init__.py
- Provides scope-aware relationship extraction with alias resolution
"""
from __future__ import annotations
from abc import ABC, abstractmethod
from pathlib import Path
from typing import Any, Dict, List, Optional, Tuple
from codexlens.entities import CodeRelationship, IndexedFile, RelationshipType, Symbol
# Import patterns module
from codexlens.parsers.patterns.python import (
PYTHON_PATTERNS,
get_pattern,
get_metavar,
)
# Graceful import pattern following existing convention
try:
from ast_grep_py import SgNode, SgRoot
from codexlens.parsers.astgrep_binding import AstGrepBinding, ASTGREP_AVAILABLE
except ImportError:
SgNode = None # type: ignore[assignment,misc]
SgRoot = None # type: ignore[assignment,misc]
AstGrepBinding = None # type: ignore[assignment,misc]
ASTGREP_AVAILABLE = False
class BaseAstGrepProcessor(ABC):
    """Abstract base class for ast-grep based processors.

    Provides common infrastructure for pattern-based AST processing.
    Subclasses implement language-specific pattern processing logic.
    """

    def __init__(self, language_id: str, path: Optional[Path] = None) -> None:
        """Initialize processor for a language.

        Args:
            language_id: Language identifier (python, javascript, typescript)
            path: Optional file path for language variant detection
        """
        self.language_id = language_id
        self.path = path
        # Stays None when ast-grep-py is unavailable; every public method
        # guards with is_available() before touching it.
        self._binding: Optional[AstGrepBinding] = None
        if ASTGREP_AVAILABLE and AstGrepBinding is not None:
            self._binding = AstGrepBinding(language_id, path)

    def is_available(self) -> bool:
        """Check if ast-grep processor is available.

        Returns:
            True if ast-grep binding is ready
        """
        return self._binding is not None and self._binding.is_available()

    def run_ast_grep(self, source_code: str, pattern: str) -> List[SgNode]:  # type: ignore[valid-type]
        """Execute ast-grep pattern matching on source code.

        Note: re-parses source_code on every call, so running many patterns
        against the same text repeats the parse cost each time.

        Args:
            source_code: Source code text to analyze
            pattern: ast-grep pattern string

        Returns:
            List of matching SgNode objects, empty if no matches or unavailable
        """
        if not self.is_available() or self._binding is None:
            return []
        if not self._binding.parse(source_code):
            return []
        return self._binding.find_all(pattern)

    @abstractmethod
    def process_matches(
        self,
        matches: List[SgNode],  # type: ignore[valid-type]
        source_code: str,
        path: Path,
    ) -> List[CodeRelationship]:
        """Process ast-grep matches into code relationships.

        Args:
            matches: List of matched SgNode objects
            source_code: Original source code
            path: File path being processed

        Returns:
            List of extracted code relationships
        """
        pass

    @abstractmethod
    def parse(self, text: str, path: Path) -> Optional[IndexedFile]:
        """Parse source code and extract relationships.

        Args:
            text: Source code text
            path: File path

        Returns:
            IndexedFile with symbols and relationships, None if unavailable
        """
        pass
class AstGrepPythonProcessor(BaseAstGrepProcessor):
"""Python-specific ast-grep processor for relationship extraction.
Extracts INHERITS, CALLS, and IMPORTS relationships from Python code
using declarative ast-grep patterns with scope-aware processing.
"""
    def __init__(self, path: Optional[Path] = None) -> None:
        """Initialize Python processor.

        Args:
            path: Optional file path (kept for interface parity with the base
                class; the base only uses it for TypeScript/TSX detection)
        """
        super().__init__("python", path)
    def parse(self, text: str, path: Path) -> Optional[IndexedFile]:
        """Parse Python source code and extract relationships.

        Args:
            text: Python source code text
            path: File path

        Returns:
            IndexedFile with symbols and relationships (chunks left empty),
            or None when ast-grep is unavailable or extraction fails
        """
        if not self.is_available():
            return None
        try:
            symbols = self._extract_symbols(text)
            relationships = self._extract_relationships(text, path)
            return IndexedFile(
                path=str(path.resolve()),
                language="python",
                symbols=symbols,
                chunks=[],
                relationships=relationships,
            )
        except (ValueError, TypeError, AttributeError) as e:
            # Log specific parsing errors for debugging; imported locally to
            # keep the module import-light on the happy path.
            import logging
            logging.getLogger(__name__).debug(f"ast-grep parsing error: {e}")
            return None
    def _extract_symbols(self, source_code: str) -> List[Symbol]:
        """Extract Python symbols (classes, functions, methods).

        Collects class/function/async-function definitions via ast-grep
        patterns, then replays them in line order with a scope stack so that
        a function directly nested in a class is classified as a "method".

        Args:
            source_code: Python source code

        Returns:
            List of Symbol objects
        """
        symbols: List[Symbol] = []
        # Collect all scope definitions with line ranges for proper method detection
        # Format: (start_line, end_line, kind, name)
        scope_defs: List[Tuple[int, int, str, str]] = []
        # Track async function positions to avoid duplicates: the plain
        # func_def pattern also matches the def inside "async def".
        async_positions: set = set()
        # Extract class definitions
        class_matches = self.run_ast_grep(source_code, get_pattern("class_def"))
        for node in class_matches:
            name = self._get_match(node, "NAME")
            if name:
                start_line, end_line = self._get_line_range(node)
                scope_defs.append((start_line, end_line, "class", name))
        # Extract async function definitions FIRST (before regular functions)
        async_matches = self.run_ast_grep(source_code, get_pattern("async_func_def"))
        for node in async_matches:
            name = self._get_match(node, "NAME")
            if name:
                start_line, end_line = self._get_line_range(node)
                scope_defs.append((start_line, end_line, "function", name))
                async_positions.add(start_line)  # Mark this position as async
        # Extract function definitions (skip those already captured as async)
        func_matches = self.run_ast_grep(source_code, get_pattern("func_def"))
        for node in func_matches:
            name = self._get_match(node, "NAME")
            if name:
                start_line, end_line = self._get_line_range(node)
                # Skip if already captured as async function (same position)
                if start_line not in async_positions:
                    scope_defs.append((start_line, end_line, "function", name))
        # Sort by start line for scope-aware processing
        scope_defs.sort(key=lambda x: x[0])
        # Process with scope tracking to determine method vs function
        scope_stack: List[Tuple[str, int, str]] = []  # (name, end_line, kind)
        for start_line, end_line, kind, name in scope_defs:
            # Pop scopes that have ended before this definition starts
            while scope_stack and scope_stack[-1][1] < start_line:
                scope_stack.pop()
            if kind == "class":
                symbols.append(Symbol(
                    name=name,
                    kind="class",
                    range=(start_line, end_line),
                ))
                scope_stack.append((name, end_line, "class"))
            else:  # function
                # Determine if it's a method (immediately inside a class) or function
                is_method = bool(scope_stack) and scope_stack[-1][2] == "class"
                symbols.append(Symbol(
                    name=name,
                    kind="method" if is_method else "function",
                    range=(start_line, end_line),
                ))
                scope_stack.append((name, end_line, "function"))
        return symbols
    def _extract_relationships(self, source_code: str, path: Path) -> List[CodeRelationship]:
        """Extract code relationships with scope and alias resolution.

        Gathers class/function definitions, imports, and calls as flat
        (line, type, symbol) records, sorts them into source order, and
        hands them to _process_scope_and_aliases for scope-aware resolution.

        Args:
            source_code: Python source code
            path: File path

        Returns:
            List of CodeRelationship objects
        """
        if not self.is_available() or self._binding is None:
            return []
        source_file = str(path.resolve())
        # Collect all matches with line numbers and end lines for scope processing
        # Format: (start_line, end_line, match_type, symbol, node)
        all_matches: List[Tuple[int, int, str, str, Any]] = []
        # Get class definitions (with and without bases) for scope tracking
        class_with_bases = self.run_ast_grep(source_code, get_pattern("class_with_bases"))
        for node in class_with_bases:
            class_name = self._get_match(node, "NAME")
            start_line, end_line = self._get_line_range(node)
            if class_name:
                # Record class scope and inheritance
                all_matches.append((start_line, end_line, "class_def", class_name, node))
                # Extract bases from node text (ast-grep-py 0.40+ doesn't capture $$$)
                node_text = self._binding._get_node_text(node) if self._binding else ""
                bases_text = self._extract_bases_from_class_text(node_text)
                if bases_text:
                    # Also record inheritance relationship
                    all_matches.append((start_line, end_line, "inherits", bases_text, node))
        # Get classes without bases for scope tracking
        class_no_bases = self.run_ast_grep(source_code, get_pattern("class_def"))
        for node in class_no_bases:
            class_name = self._get_match(node, "NAME")
            start_line, end_line = self._get_line_range(node)
            if class_name:
                # Check if not already recorded (avoid duplicates from class_with_bases)
                existing = [m for m in all_matches if m[2] == "class_def" and m[3] == class_name and m[0] == start_line]
                if not existing:
                    all_matches.append((start_line, end_line, "class_def", class_name, node))
        # Get function definitions for scope tracking
        func_matches = self.run_ast_grep(source_code, get_pattern("func_def"))
        for node in func_matches:
            func_name = self._get_match(node, "NAME")
            start_line, end_line = self._get_line_range(node)
            if func_name:
                all_matches.append((start_line, end_line, "func_def", func_name, node))
        # Get async function definitions for scope tracking
        async_func_matches = self.run_ast_grep(source_code, get_pattern("async_func_def"))
        for node in async_func_matches:
            func_name = self._get_match(node, "NAME")
            start_line, end_line = self._get_line_range(node)
            if func_name:
                all_matches.append((start_line, end_line, "func_def", func_name, node))
        # Get import matches
        import_matches = self.run_ast_grep(source_code, get_pattern("import_stmt"))
        for node in import_matches:
            module = self._get_match(node, "MODULE")
            start_line, end_line = self._get_line_range(node)
            if module:
                all_matches.append((start_line, end_line, "import", module, node))
        from_matches = self.run_ast_grep(source_code, get_pattern("import_from"))
        for node in from_matches:
            module = self._get_match(node, "MODULE")
            names = self._get_match(node, "NAMES")
            start_line, end_line = self._get_line_range(node)
            if module:
                # Module and imported names are packed into one string;
                # _process_scope_and_aliases splits on the first ":".
                all_matches.append((start_line, end_line, "from_import", f"{module}:{names}", node))
        # Get call matches
        call_matches = self.run_ast_grep(source_code, get_pattern("call"))
        for node in call_matches:
            func = self._get_match(node, "FUNC")
            start_line, end_line = self._get_line_range(node)
            if func:
                # Skip self. and cls. prefixed calls
                base = func.split(".", 1)[0]
                if base not in {"self", "cls"}:
                    all_matches.append((start_line, end_line, "call", func, node))
        # Sort by start line number for scope processing
        all_matches.sort(key=lambda x: (x[0], x[2] == "call"))  # Process scope defs before calls on same line
        # Process with scope tracking
        relationships = self._process_scope_and_aliases(all_matches, source_file)
        return relationships
    def _process_scope_and_aliases(
        self,
        matches: List[Tuple[int, int, str, str, Any]],
        source_file: str,
    ) -> List[CodeRelationship]:
        """Process matches with scope and alias resolution.

        Implements proper scope tracking similar to treesitter_parser.py:
        - Maintains scope_stack for tracking current scope (class/function names)
        - Maintains alias_stack with per-scope alias mappings (inherited from parent)
        - Pops scopes when current line passes their end line
        - Resolves call targets using current scope's alias map

        Args:
            matches: Sorted list of (start_line, end_line, type, symbol, node) tuples
            source_file: Source file path

        Returns:
            List of resolved CodeRelationship objects
        """
        relationships: List[CodeRelationship] = []
        # Scope stack: list of (name, end_line) tuples.
        # The module sentinel uses float("inf") so it is never popped, even
        # though the annotation says Tuple[str, int].
        scope_stack: List[Tuple[str, int]] = [("<module>", float("inf"))]
        # Alias stack: list of alias dicts, one per scope level
        # Each new scope inherits parent's aliases (copy on write)
        alias_stack: List[Dict[str, str]] = [{}]

        def get_current_scope() -> str:
            """Get the name of the current (innermost) scope."""
            return scope_stack[-1][0]

        def pop_scopes_before(line: int) -> None:
            """Pop all scopes that have ended before the given line."""
            while len(scope_stack) > 1 and scope_stack[-1][1] < line:
                scope_stack.pop()
                alias_stack.pop()

        def push_scope(name: str, end_line: int) -> None:
            """Push a new scope onto the stack."""
            scope_stack.append((name, end_line))
            # Copy parent scope's aliases for inheritance
            alias_stack.append(dict(alias_stack[-1]))

        def update_aliases(updates: Dict[str, str]) -> None:
            """Update current scope's alias map."""
            alias_stack[-1].update(updates)

        def resolve_alias(symbol: str) -> str:
            """Resolve a symbol using current scope's alias map."""
            if "." not in symbol:
                # Simple name - check if it's an alias
                return alias_stack[-1].get(symbol, symbol)
            # Dotted name - resolve the base
            parts = symbol.split(".", 1)
            base = parts[0]
            rest = parts[1]
            if base in alias_stack[-1]:
                return f"{alias_stack[-1][base]}.{rest}"
            return symbol

        for start_line, end_line, match_type, symbol, node in matches:
            # Pop any scopes that have ended
            pop_scopes_before(start_line)
            if match_type == "class_def":
                # Push class scope
                push_scope(symbol, end_line)
            elif match_type == "func_def":
                # Push function scope
                push_scope(symbol, end_line)
            elif match_type == "inherits":
                # Record inheritance relationship
                # Parse base classes from the bases text
                base_classes = self._parse_base_classes(symbol)
                for base_class in base_classes:
                    base_class = base_class.strip()
                    if base_class:
                        # Resolve alias for base class
                        resolved_base = resolve_alias(base_class)
                        relationships.append(CodeRelationship(
                            source_symbol=get_current_scope(),
                            target_symbol=resolved_base,
                            relationship_type=RelationshipType.INHERITS,
                            source_file=source_file,
                            target_file=None,
                            source_line=start_line,
                        ))
            elif match_type == "import":
                # Process import statement
                module = symbol
                # Simple import: add base name to alias map
                base_name = module.split(".", 1)[0]
                update_aliases({base_name: module})
                relationships.append(CodeRelationship(
                    source_symbol=get_current_scope(),
                    target_symbol=module,
                    relationship_type=RelationshipType.IMPORTS,
                    source_file=source_file,
                    target_file=None,
                    source_line=start_line,
                ))
            elif match_type == "from_import":
                # Process from-import statement; symbol is "module:names"
                parts = symbol.split(":", 1)
                module = parts[0]
                names = parts[1] if len(parts) > 1 else ""
                # Record the import relationship
                relationships.append(CodeRelationship(
                    source_symbol=get_current_scope(),
                    target_symbol=module,
                    relationship_type=RelationshipType.IMPORTS,
                    source_file=source_file,
                    target_file=None,
                    source_line=start_line,
                ))
                # Add aliases for imported names (star imports carry no names)
                if names and names != "*":
                    for name in names.split(","):
                        name = name.strip()
                        # Handle "name as alias" syntax
                        if " as " in name:
                            as_parts = name.split(" as ")
                            original = as_parts[0].strip()
                            alias = as_parts[1].strip()
                            if alias:
                                update_aliases({alias: f"{module}.{original}"})
                        elif name:
                            update_aliases({name: f"{module}.{name}"})
            elif match_type == "call":
                # Resolve alias for call target
                resolved = resolve_alias(symbol)
                relationships.append(CodeRelationship(
                    source_symbol=get_current_scope(),
                    target_symbol=resolved,
                    relationship_type=RelationshipType.CALL,
                    source_file=source_file,
                    target_file=None,
                    source_line=start_line,
                ))
        return relationships
    def process_matches(
        self,
        matches: List[SgNode],  # type: ignore[valid-type]
        source_code: str,
        path: Path,
    ) -> List[CodeRelationship]:
        """Process ast-grep matches into code relationships.

        This is a simplified interface for direct match processing: every
        match is treated as a call made from module scope, with no alias or
        scope resolution. For full relationship extraction with scope
        tracking, use parse().

        Args:
            matches: List of matched SgNode objects
            source_code: Original source code (unused here; kept for the
                abstract interface)
            path: File path being processed

        Returns:
            List of extracted code relationships
        """
        if not self.is_available() or self._binding is None:
            return []
        source_file = str(path.resolve())
        relationships: List[CodeRelationship] = []
        for node in matches:
            # Default to call relationship for generic matches
            func = self._get_match(node, "FUNC")
            line = self._get_line_number(node)
            if func:
                base = func.split(".", 1)[0]
                if base not in {"self", "cls"}:
                    relationships.append(CodeRelationship(
                        source_symbol="<module>",
                        target_symbol=func,
                        relationship_type=RelationshipType.CALL,
                        source_file=source_file,
                        target_file=None,
                        source_line=line,
                    ))
        return relationships
def _get_match(self, node: SgNode, metavar: str) -> str: # type: ignore[valid-type]
"""Extract matched metavariable value from node.
Args:
node: SgNode with match
metavar: Metavariable name (without $ prefix)
Returns:
Matched text or empty string
"""
if self._binding is None or node is None:
return ""
return self._binding._get_match(node, metavar)
def _get_line_number(self, node: SgNode) -> int: # type: ignore[valid-type]
"""Get starting line number of a node.
Args:
node: SgNode to get line number for
Returns:
1-based line number
"""
if self._binding is None or node is None:
return 0
return self._binding._get_line_number(node)
def _get_line_range(self, node: SgNode) -> Tuple[int, int]: # type: ignore[valid-type]
"""Get line range for a node.
Args:
node: SgNode to get range for
Returns:
(start_line, end_line) tuple, 1-based inclusive
"""
if self._binding is None or node is None:
return (0, 0)
return self._binding._get_line_range(node)
# =========================================================================
# Dedicated extraction methods for INHERITS, CALL, IMPORTS relationships
# =========================================================================
    def extract_inherits(
        self,
        source_code: str,
        source_file: str,
        source_symbol: str = "<module>",
    ) -> List[CodeRelationship]:
        """Extract INHERITS relationships from Python code.

        Identifies class inheritance patterns including:
        - Single inheritance: class Child(Parent):
        - Multiple inheritance: class Child(A, B, C):

        No alias resolution is performed here; base names are recorded as
        written in the source.

        Args:
            source_code: Python source code to analyze
            source_file: Path to the source file
            source_symbol: The containing scope (unused; the class name is
                used as the relationship source)

        Returns:
            List of CodeRelationship objects with INHERITS type
        """
        if not self.is_available():
            return []
        relationships: List[CodeRelationship] = []
        # Use class_with_bases pattern to find classes with inheritance
        matches = self.run_ast_grep(source_code, get_pattern("class_with_bases"))
        for node in matches:
            class_name = self._get_match(node, "NAME")
            line = self._get_line_number(node)
            if class_name:
                # Extract bases from the node text (first line: "class ClassName(Base1, Base2):")
                # ast-grep-py 0.40+ doesn't capture $$$ multi-matches, so parse from text
                node_text = self._binding._get_node_text(node) if self._binding else ""
                bases_text = self._extract_bases_from_class_text(node_text)
                if bases_text:
                    # Parse individual base classes from the bases text
                    base_classes = self._parse_base_classes(bases_text)
                    for base_class in base_classes:
                        base_class = base_class.strip()
                        if base_class:
                            relationships.append(CodeRelationship(
                                source_symbol=class_name,
                                target_symbol=base_class,
                                relationship_type=RelationshipType.INHERITS,
                                source_file=source_file,
                                target_file=None,
                                source_line=line,
                            ))
        return relationships
def _extract_bases_from_class_text(self, class_text: str) -> str:
"""Extract base classes text from class definition.
Args:
class_text: Full text of class definition (e.g., "class Dog(Animal):\\n pass")
Returns:
Text inside parentheses (e.g., "Animal") or empty string
"""
import re
# Match "class Name(BASES):" - extract BASES
match = re.search(r'class\s+\w+\s*\(([^)]*)\)\s*:', class_text)
if match:
return match.group(1).strip()
return ""
def extract_calls(
    self,
    source_code: str,
    source_file: str,
    source_symbol: str = "<module>",
    alias_map: Optional[Dict[str, str]] = None,
) -> List[CodeRelationship]:
    """Extract CALL relationships from Python source.

    Covers simple calls (``func()``), calls with arguments, method calls
    (``obj.method()``) and chained calls. Calls rooted at ``self``,
    ``cls`` or ``super`` are treated as internal and skipped.

    Args:
        source_code: Python source code to analyze.
        source_file: Path to the source file.
        source_symbol: The containing scope (class or module).
        alias_map: Optional mapping used to resolve imported names.

    Returns:
        List of CodeRelationship objects with CALL type.
    """
    if not self.is_available():
        return []

    aliases = alias_map or {}
    calls: List[CodeRelationship] = []
    for node in self.run_ast_grep(source_code, get_pattern("call")):
        func = self._get_match(node, "FUNC")
        line = self._get_line_number(node)
        if not func:
            continue
        # Internal method calls are not cross-symbol relationships.
        if func.split(".", 1)[0] in {"self", "cls", "super"}:
            continue
        calls.append(CodeRelationship(
            source_symbol=source_symbol,
            target_symbol=self._resolve_call_alias(func, aliases),
            relationship_type=RelationshipType.CALL,
            source_file=source_file,
            target_file=None,
            source_line=line,
        ))
    return calls
def extract_imports(
    self,
    source_code: str,
    source_file: str,
    source_symbol: str = "<module>",
) -> Tuple[List[CodeRelationship], Dict[str, str]]:
    """Extract IMPORTS relationships from Python source.

    Handles:
        - Simple import: ``import os``
        - Import with alias: ``import numpy as np``
        - From import: ``from typing import List``
        - From import with alias: ``from collections import defaultdict as dd``
        - Relative import: ``from .module import func``
        - Star import: ``from module import *``

    Args:
        source_code: Python source code to analyze.
        source_file: Path to the source file.
        source_symbol: The containing scope (class or module).

    Returns:
        Tuple of:
            - List of CodeRelationship objects with IMPORTS type
            - Dict mapping local names to fully qualified names (alias map)
    """
    if not self.is_available():
        return [], {}

    relationships: List[CodeRelationship] = []
    alias_map: Dict[str, str] = {}

    def add(target: str, line) -> None:
        # All IMPORTS edges share everything except target symbol and line.
        relationships.append(CodeRelationship(
            source_symbol=source_symbol,
            target_symbol=target,
            relationship_type=RelationshipType.IMPORTS,
            source_file=source_file,
            target_file=None,
            source_line=line,
        ))

    # Simple imports: import X
    # Note: "import a.b" binds only the top-level name "a" in the local
    # namespace, so the alias map records the base name mapped to itself
    # (mapping it to the full dotted path would mis-resolve "a.attr"
    # to "a.b.attr").
    for node in self.run_ast_grep(source_code, get_pattern("import_stmt")):
        module = self._get_match(node, "MODULE")
        line = self._get_line_number(node)
        if module:
            base_name = module.split(".", 1)[0]
            alias_map[base_name] = base_name
            add(module, line)

    # Import with alias: import X as Y
    for node in self.run_ast_grep(source_code, get_pattern("import_with_alias")):
        module = self._get_match(node, "MODULE")
        alias = self._get_match(node, "ALIAS")
        line = self._get_line_number(node)
        if module and alias:
            alias_map[alias] = module
            add(module, line)

    # From imports: from X import Y[, Z as W, ...]
    for node in self.run_ast_grep(source_code, get_pattern("import_from")):
        module = self._get_match(node, "MODULE")
        names = self._get_match(node, "NAMES")
        line = self._get_line_number(node)
        if module:
            add(module, line)
            self._register_import_aliases(module, names, alias_map)

    # Star imports: from X import *
    for node in self.run_ast_grep(source_code, get_pattern("from_import_star")):
        module = self._get_match(node, "MODULE")
        line = self._get_line_number(node)
        if module:
            add(f"{module}.*", line)

    # Relative imports: from .X import Y
    # Previously the imported names were extracted but never used; they
    # are now registered in the alias map just like absolute from-imports.
    for node in self.run_ast_grep(source_code, get_pattern("relative_import")):
        module = self._get_match(node, "MODULE")
        names = self._get_match(node, "NAMES")
        line = self._get_line_number(node)
        rel_module = f".{module}" if module else "."
        add(rel_module, line)
        self._register_import_aliases(rel_module, names, alias_map)

    return relationships, alias_map

def _register_import_aliases(
    self,
    module: str,
    names: Optional[str],
    alias_map: Dict[str, str],
) -> None:
    """Record local-name -> qualified-name bindings for a from-import.

    Args:
        module: Module path the names are imported from (may be relative,
            e.g. ".utils").
        names: Comma-separated import list, possibly containing
            "name as alias" items; None or "*" means nothing to bind.
        alias_map: Mutated in place with the new bindings.
    """
    if not names or names == "*":
        return
    for name in names.split(","):
        name = name.strip()
        if " as " in name:
            original, _, alias = name.partition(" as ")
            alias_map[alias.strip()] = f"{module}.{original.strip()}"
        elif name:
            alias_map[name] = f"{module}.{name}"
# =========================================================================
# Helper methods for pattern processing
# =========================================================================
def _parse_base_classes(self, bases_text: str) -> List[str]:
"""Parse base class names from inheritance text.
Handles single and multiple inheritance with proper comma splitting.
Accounts for nested parentheses and complex type annotations.
Args:
bases_text: Text inside the parentheses of class definition
Returns:
List of base class names
"""
if not bases_text:
return []
# Simple comma split (may not handle all edge cases)
bases = []
depth = 0
current = []
for char in bases_text:
if char == "(":
depth += 1
current.append(char)
elif char == ")":
depth -= 1
current.append(char)
elif char == "," and depth == 0:
base = "".join(current).strip()
if base:
bases.append(base)
current = []
else:
current.append(char)
# Add the last base class
if current:
base = "".join(current).strip()
if base:
bases.append(base)
return bases
def _resolve_call_alias(self, func_name: str, alias_map: Dict[str, str]) -> str:
"""Resolve a function call name using import aliases.
Args:
func_name: The function/method name as it appears in code
alias_map: Mapping of local names to fully qualified names
Returns:
Resolved function name (fully qualified if possible)
"""
if "." not in func_name:
# Simple function call - check if it's an alias
return alias_map.get(func_name, func_name)
# Method call or qualified name - resolve the base
parts = func_name.split(".", 1)
base = parts[0]
rest = parts[1]
if base in alias_map:
return f"{alias_map[base]}.{rest}"
return func_name
def is_astgrep_processor_available() -> bool:
    """Check if ast-grep processor is available.

    Returns:
        True if ast-grep-py is installed and processor can be used
    """
    # ASTGREP_AVAILABLE is a module-level flag set at import time based on
    # whether the ast-grep-py binding imported successfully.
    return ASTGREP_AVAILABLE
# Public API of this module.
__all__ = [
    "BaseAstGrepProcessor",
    "AstGrepPythonProcessor",
    "is_astgrep_processor_available",
]

View File

@@ -0,0 +1,5 @@
"""ast-grep pattern definitions for various languages.
This package contains language-specific pattern definitions for
extracting code relationships using ast-grep declarative patterns.
"""

View File

@@ -0,0 +1,204 @@
"""Python ast-grep patterns for relationship extraction.
This module defines declarative patterns for extracting code relationships
(inheritance, calls, imports) from Python source code using ast-grep.
Pattern Syntax (ast-grep-py 0.40+):
$VAR - Single metavariable (matches one AST node)
$$$VAR - Multiple metavariable (matches zero or more nodes)
Example:
"class $CLASS_NAME($$$BASES) $$$BODY" matches:
class MyClass(BaseClass):
pass
with $CLASS_NAME = "MyClass", $$$BASES = "BaseClass", $$$BODY = "pass"
YAML Pattern Files:
inherits.yaml - INHERITS relationship patterns (single/multiple inheritance)
imports.yaml - IMPORTS relationship patterns (import, from...import, as)
call.yaml - CALL relationship patterns (function/method calls)
"""
from __future__ import annotations
from pathlib import Path
from typing import Dict, List, Optional
# Directory containing YAML pattern files (this package's own directory).
PATTERNS_DIR = Path(__file__).parent

# Python ast-grep patterns organized by relationship type.
# Note: ast-grep-py 0.40+ uses $$$ for zero-or-more multi-match.
# NOTE(review): the class patterns omit the ":" before $$$BODY while
# func_def includes it -- presumably ast-grep's lenient pattern parsing
# accepts both forms; confirm against the ast-grep pattern docs.
PYTHON_PATTERNS: Dict[str, str] = {
    # Class definitions with inheritance
    "class_def": "class $NAME $$$BODY",
    "class_with_bases": "class $NAME($$$BASES) $$$BODY",
    # Single inheritance: class Child(Parent):
    "single_inheritance": "class $CLASS_NAME($BASE) $$$BODY",
    # Multiple inheritance: class Child(A, B, C):
    "multiple_inheritance": "class $CLASS_NAME($BASE, $$$MORE_BASES) $$$BODY",
    # Function definitions (use $$$ for zero-or-more params)
    "func_def": "def $NAME($$$PARAMS): $$$BODY",
    "async_func_def": "async def $NAME($$$PARAMS): $$$BODY",
    # Import statements - basic forms
    "import_stmt": "import $MODULE",
    "import_from": "from $MODULE import $NAMES",
    # Import statements - extended forms
    "import_with_alias": "import $MODULE as $ALIAS",
    "import_multiple": "import $FIRST, $$$REST",
    "from_import_single": "from $MODULE import $NAME",
    "from_import_with_alias": "from $MODULE import $NAME as $ALIAS",
    "from_import_multiple": "from $MODULE import $FIRST, $$$REST",
    "from_import_star": "from $MODULE import *",
    "relative_import": "from .$$$MODULE import $NAMES",
    # Function/method calls - basic form (use $$$ for zero-or-more args)
    "call": "$FUNC($$$ARGS)",
    "method_call": "$OBJ.$METHOD($$$ARGS)",
    # Function/method calls - specific forms.
    # Note: "call_with_args" is intentionally the same string as "call";
    # both match calls with zero or more arguments.
    "simple_call": "$FUNC()",
    "call_with_args": "$FUNC($$$ARGS)",
    "chained_call": "$OBJ.$METHOD($$$ARGS).$$$CHAIN",
    "constructor_call": "$CLASS($$$ARGS)",
}
# Metavariable names for extracting match data.
# Keys are friendly identifiers (resolved via get_metavar()); values are the
# bare metavariable names -- without the "$" prefix -- used in the patterns.
METAVARS = {
    # Class patterns
    "class_name": "NAME",
    "class_bases": "BASES",
    "class_body": "BODY",
    "inherit_class": "CLASS_NAME",
    "inherit_base": "BASE",
    "inherit_more_bases": "MORE_BASES",
    # Function patterns
    "func_name": "NAME",
    "func_params": "PARAMS",
    "func_body": "BODY",
    # Import patterns
    "import_module": "MODULE",
    "import_names": "NAMES",
    "import_alias": "ALIAS",
    "import_first": "FIRST",
    "import_rest": "REST",
    # Call patterns
    "call_func": "FUNC",
    "call_obj": "OBJ",
    "call_method": "METHOD",
    "call_args": "ARGS",
    "call_class": "CLASS",
    "call_chain": "CHAIN",
}
# Relationship pattern mapping: relationship type -> the PYTHON_PATTERNS
# keys capable of extracting it (queried via get_patterns_for_relationship).
RELATIONSHIP_PATTERNS: Dict[str, List[str]] = {
    "inheritance": ["class_with_bases", "single_inheritance", "multiple_inheritance"],
    "imports": [
        "import_stmt", "import_from",
        "import_with_alias", "import_multiple",
        "from_import_single", "from_import_with_alias",
        "from_import_multiple", "from_import_star",
        "relative_import",
    ],
    "calls": ["call", "method_call", "simple_call", "call_with_args", "constructor_call"],
}

# YAML pattern file mapping: relationship type -> file name inside
# PATTERNS_DIR (resolved by get_yaml_pattern_path / list_yaml_pattern_files).
YAML_PATTERN_FILES = {
    "inheritance": "inherits.yaml",
    "imports": "imports.yaml",
    "calls": "call.yaml",
}
def get_pattern(pattern_name: str) -> str:
    """Look up an ast-grep pattern string by its registry key.

    Args:
        pattern_name: Key from the PYTHON_PATTERNS registry.

    Returns:
        The pattern string.

    Raises:
        KeyError: If the pattern name is not registered.
    """
    try:
        return PYTHON_PATTERNS[pattern_name]
    except KeyError:
        raise KeyError(
            f"Unknown pattern: {pattern_name}. "
            f"Available: {list(PYTHON_PATTERNS.keys())}"
        ) from None
def get_patterns_for_relationship(rel_type: str) -> List[str]:
    """Return the pattern names capable of extracting a relationship type.

    Args:
        rel_type: Relationship type ("inheritance", "imports", "calls").

    Returns:
        List of pattern names; empty for unknown types.
    """
    if rel_type in RELATIONSHIP_PATTERNS:
        return RELATIONSHIP_PATTERNS[rel_type]
    return []
def get_metavar(name: str) -> str:
    """Map a friendly key to its metavariable name, without the "$" prefix.

    Args:
        name: Key from the METAVARS registry.

    Returns:
        The metavariable name (e.g. "NAME", not "$NAME"); unknown keys
        fall back to their uppercased form.
    """
    try:
        return METAVARS[name]
    except KeyError:
        return name.upper()
def get_yaml_pattern_path(rel_type: str) -> Optional[Path]:
    """Locate the YAML pattern file for a relationship type.

    Args:
        rel_type: Relationship type ("inheritance", "imports", "calls").

    Returns:
        Path to the YAML file, or None for unknown types. Existence on
        disk is not checked here.
    """
    filename = YAML_PATTERN_FILES.get(rel_type)
    return PATTERNS_DIR / filename if filename else None
def list_yaml_pattern_files() -> Dict[str, Path]:
    """Enumerate the YAML pattern files that exist on disk.

    Returns:
        Mapping of relationship type to YAML file path; entries whose
        file is missing are omitted.
    """
    return {
        rel_type: PATTERNS_DIR / filename
        for rel_type, filename in YAML_PATTERN_FILES.items()
        if (PATTERNS_DIR / filename).exists()
    }
# Public API of the patterns module.
__all__ = [
    "PYTHON_PATTERNS",
    "METAVARS",
    "RELATIONSHIP_PATTERNS",
    "YAML_PATTERN_FILES",
    "PATTERNS_DIR",
    "get_pattern",
    "get_patterns_for_relationship",
    "get_metavar",
    "get_yaml_pattern_path",
    "list_yaml_pattern_files",
]

View File

@@ -0,0 +1,87 @@
# Python CALL patterns for ast-grep
# Extracts function and method call expressions

# Pattern metadata
id: python-call
language: python
description: Extract function and method calls from Python code

patterns:
  # Simple function call
  # Matches: func()
  - id: simple_call
    pattern: "$FUNC()"
    message: "Found simple function call"
    severity: hint

  # Function call with arguments
  # Matches: func(arg1, arg2)
  - id: call_with_args
    pattern: "$FUNC($$$ARGS)"
    message: "Found function call with arguments"
    severity: hint

  # Method call
  # Matches: obj.method()
  - id: method_call
    pattern: "$OBJ.$METHOD($$$ARGS)"
    message: "Found method call"
    severity: hint

  # Chained method call
  # Matches: obj.method1().method2()
  - id: chained_call
    pattern: "$OBJ.$METHOD($$$ARGS).$$$CHAIN"
    message: "Found chained method call"
    severity: hint

  # Call with keyword arguments
  # Matches: func(arg=value)
  # FIX: $VALUE and $$$MORE were fused as "$VALUE$$$MORE", which cannot
  # tokenize as two separate metavariables; they must be comma-separated.
  - id: call_with_kwargs
    pattern: "$FUNC($$$ARGS, $KWARG=$VALUE, $$$MORE)"
    message: "Found call with keyword argument"
    severity: hint

  # Constructor call
  # Matches: ClassName()
  - id: constructor_call
    pattern: "$CLASS($$$ARGS)"
    message: "Found constructor call"
    severity: hint

  # Subscript access (not a real call, but often confused)
  # This pattern helps exclude indexing from calls
  - id: subscript_access
    pattern: "$OBJ[$INDEX]"
    message: "Found subscript access"
    severity: hint

# Metavariables used:
#   $FUNC    - Function name being called
#   $OBJ     - Object receiving the method call
#   $METHOD  - Method name being called
#   $ARGS    - Positional arguments
#   $KWARG   - Keyword argument name
#   $VALUE   - Keyword argument value
#   $CLASS   - Class name for constructor calls
#   $INDEX   - Index for subscript access
#   $$$MORE  - Additional arguments
#   $$$CHAIN - Additional method chains

# Note: The generic call pattern "$FUNC($$$ARGS)" will match all function
# calls including method calls and constructor calls. More specific
# patterns help categorize the type of call.

# Examples matched:
#   print("hello")     -> call_with_args
#   len(items)         -> call_with_args
#   obj.process()      -> method_call
#   obj.get().save()   -> chained_call
#   func(name=value)   -> call_with_kwargs
#   MyClass()          -> constructor_call
#   items[0]           -> subscript_access (not a call)

# Filtering notes:
# - self.method() calls are typically filtered during processing
# - cls.method() calls are typically filtered during processing
# - super().method() calls may be handled specially

View File

@@ -0,0 +1,82 @@
# Python IMPORTS patterns for ast-grep
# Extracts import statements (import, from...import, as aliases)

# Pattern metadata
id: python-imports
language: python
description: Extract import statements from Python code

patterns:
  # Simple import
  # Matches: import os
  - id: simple_import
    pattern: "import $MODULE"
    message: "Found simple import"
    severity: hint

  # Import with alias
  # Matches: import numpy as np
  - id: import_with_alias
    pattern: "import $MODULE as $ALIAS"
    message: "Found import with alias"
    severity: hint

  # Multiple imports
  # Matches: import os, sys
  - id: multiple_imports
    pattern: "import $FIRST, $$$REST"
    message: "Found multiple imports"
    severity: hint

  # From import (single name)
  # Matches: from os import path
  - id: from_import_single
    pattern: "from $MODULE import $NAME"
    message: "Found from-import single"
    severity: hint

  # From import with alias
  # Matches: from collections import defaultdict as dd
  - id: from_import_with_alias
    pattern: "from $MODULE import $NAME as $ALIAS"
    message: "Found from-import with alias"
    severity: hint

  # From import multiple names
  # Matches: from typing import List, Dict, Optional
  - id: from_import_multiple
    pattern: "from $MODULE import $FIRST, $$$REST"
    message: "Found from-import multiple"
    severity: hint

  # From import star
  # Matches: from module import *
  # Severity "warning": star imports are flagged as a code smell.
  - id: from_import_star
    pattern: "from $MODULE import *"
    message: "Found star import"
    severity: warning

  # Relative import
  # Matches: from .module import func
  # NOTE(review): $$$MODULE (multi-metavar) is used for the dotted module
  # part; confirm against ast-grep docs that this captures multi-segment
  # relative paths like ".pkg.mod".
  - id: relative_import
    pattern: "from .$$$MODULE import $NAMES"
    message: "Found relative import"
    severity: hint

# Metavariables used:
#   $MODULE - The module being imported
#   $ALIAS  - The alias for the import
#   $NAME   - The specific name being imported
#   $FIRST  - First item in a multi-item import
#   $$$REST - Remaining items in a multi-item import
#   $NAMES  - Names being imported in from-import

# Examples matched:
#   import os                           -> simple_import
#   import numpy as np                  -> import_with_alias
#   import os, sys, pathlib             -> multiple_imports
#   from os import path                 -> from_import_single
#   from typing import List, Dict, Set  -> from_import_multiple
#   from collections import defaultdict -> from_import_single
#   from .helpers import utils          -> relative_import
#   from module import *                -> from_import_star

View File

@@ -0,0 +1,42 @@
# Python INHERITS patterns for ast-grep
# Extracts class inheritance relationships (single and multiple inheritance)

# Pattern metadata
id: python-inherits
language: python
description: Extract class inheritance relationships from Python code

# Single inheritance pattern
# Matches: class Child(Parent):
patterns:
  - id: single_inheritance
    pattern: "class $CLASS_NAME($BASE) $$$BODY"
    message: "Found single inheritance"
    severity: hint

  # Multiple inheritance pattern
  # Matches: class Child(Parent1, Parent2, Parent3):
  - id: multiple_inheritance
    pattern: "class $CLASS_NAME($BASE, $$$MORE_BASES) $$$BODY"
    message: "Found multiple inheritance"
    severity: hint

  # Generic inheritance with any number of bases
  # Matches: class Child(...): with any number of parent classes
  # ($$$BASES subsumes both patterns above)
  - id: class_with_bases
    pattern: "class $NAME($$$BASES) $$$BODY"
    message: "Found class with base classes"
    severity: hint

# Metavariables used:
#   $CLASS_NAME  - The name of the child class
#   $BASE        - First base class (for single inheritance)
#   $BASES       - All base classes combined
#   $MORE_BASES  - Additional base classes after the first (multiple inheritance)
#   $$$BODY      - Class body (statements, can be multiple)

# Examples matched:
#   class Dog(Animal):               -> single_inheritance
#   class C(A, B):                   -> multiple_inheritance
#   class D(BaseMixin, logging.Log): -> class_with_bases
#   class E(A, B, C, D):             -> multiple_inheritance

View File

@@ -11,7 +11,7 @@ return `None`; callers should use a regex-based fallback such as
from __future__ import annotations from __future__ import annotations
from pathlib import Path from pathlib import Path
from typing import Dict, List, Optional from typing import Dict, List, Optional, TYPE_CHECKING
try: try:
from tree_sitter import Language as TreeSitterLanguage from tree_sitter import Language as TreeSitterLanguage
@@ -27,26 +27,45 @@ except ImportError:
from codexlens.entities import CodeRelationship, IndexedFile, RelationshipType, Symbol from codexlens.entities import CodeRelationship, IndexedFile, RelationshipType, Symbol
from codexlens.parsers.tokenizer import get_default_tokenizer from codexlens.parsers.tokenizer import get_default_tokenizer
if TYPE_CHECKING:
from codexlens.config import Config
class TreeSitterSymbolParser: class TreeSitterSymbolParser:
"""Parser using tree-sitter for AST-level symbol extraction.""" """Parser using tree-sitter for AST-level symbol extraction.
def __init__(self, language_id: str, path: Optional[Path] = None) -> None: Supports optional ast-grep integration for Python relationship extraction
when config.use_astgrep is True and ast-grep-py is available.
"""
def __init__(
self,
language_id: str,
path: Optional[Path] = None,
config: Optional["Config"] = None,
) -> None:
"""Initialize tree-sitter parser for a language. """Initialize tree-sitter parser for a language.
Args: Args:
language_id: Language identifier (python, javascript, typescript, etc.) language_id: Language identifier (python, javascript, typescript, etc.)
path: Optional file path for language variant detection (e.g., .tsx) path: Optional file path for language variant detection (e.g., .tsx)
config: Optional Config instance for parser feature toggles
""" """
self.language_id = language_id self.language_id = language_id
self.path = path self.path = path
self._config = config
self._parser: Optional[object] = None self._parser: Optional[object] = None
self._language: Optional[TreeSitterLanguage] = None self._language: Optional[TreeSitterLanguage] = None
self._tokenizer = get_default_tokenizer() self._tokenizer = get_default_tokenizer()
self._astgrep_processor = None
if TREE_SITTER_AVAILABLE: if TREE_SITTER_AVAILABLE:
self._initialize_parser() self._initialize_parser()
# Initialize ast-grep processor for Python if config enables it
if self._should_use_astgrep():
self._initialize_astgrep_processor()
def _initialize_parser(self) -> None: def _initialize_parser(self) -> None:
"""Initialize tree-sitter parser and language.""" """Initialize tree-sitter parser and language."""
if TreeSitterParser is None or TreeSitterLanguage is None: if TreeSitterParser is None or TreeSitterLanguage is None:
@@ -82,6 +101,31 @@ class TreeSitterSymbolParser:
self._parser = None self._parser = None
self._language = None self._language = None
def _should_use_astgrep(self) -> bool:
"""Check if ast-grep should be used for relationship extraction.
Returns:
True if config.use_astgrep is True and language is Python
"""
if self._config is None:
return False
if not getattr(self._config, "use_astgrep", False):
return False
return self.language_id == "python"
def _initialize_astgrep_processor(self) -> None:
"""Initialize ast-grep processor for Python relationship extraction."""
try:
from codexlens.parsers.astgrep_processor import (
AstGrepPythonProcessor,
is_astgrep_processor_available,
)
if is_astgrep_processor_available():
self._astgrep_processor = AstGrepPythonProcessor(self.path)
except ImportError:
self._astgrep_processor = None
def is_available(self) -> bool: def is_available(self) -> bool:
"""Check if tree-sitter parser is available. """Check if tree-sitter parser is available.
@@ -138,7 +182,10 @@ class TreeSitterSymbolParser:
source_bytes, root = parsed source_bytes, root = parsed
try: try:
symbols = self._extract_symbols(source_bytes, root) symbols = self._extract_symbols(source_bytes, root)
relationships = self._extract_relationships(source_bytes, root, path) # Pass source_code for ast-grep integration
relationships = self._extract_relationships(
source_bytes, root, path, source_code=text
)
return IndexedFile( return IndexedFile(
path=str(path.resolve()), path=str(path.resolve()),
@@ -173,13 +220,68 @@ class TreeSitterSymbolParser:
source_bytes: bytes, source_bytes: bytes,
root: TreeSitterNode, root: TreeSitterNode,
path: Path, path: Path,
source_code: Optional[str] = None,
) -> List[CodeRelationship]: ) -> List[CodeRelationship]:
"""Extract relationships, optionally using ast-grep for Python.
When config.use_astgrep is True and ast-grep is available for Python,
uses ast-grep for relationship extraction. Otherwise, uses tree-sitter.
Args:
source_bytes: Source code as bytes
root: Root AST node from tree-sitter
path: File path
source_code: Optional source code string (required for ast-grep)
Returns:
List of extracted relationships
"""
if self.language_id == "python": if self.language_id == "python":
# Try ast-grep first if configured and available
if self._astgrep_processor is not None and source_code is not None:
try:
astgrep_rels = self._extract_python_relationships_astgrep(
source_code, path
)
if astgrep_rels is not None:
return astgrep_rels
except Exception:
# Fall back to tree-sitter on ast-grep failure
pass
return self._extract_python_relationships(source_bytes, root, path) return self._extract_python_relationships(source_bytes, root, path)
if self.language_id in {"javascript", "typescript"}: if self.language_id in {"javascript", "typescript"}:
return self._extract_js_ts_relationships(source_bytes, root, path) return self._extract_js_ts_relationships(source_bytes, root, path)
return [] return []
def _extract_python_relationships_astgrep(
self,
source_code: str,
path: Path,
) -> Optional[List[CodeRelationship]]:
"""Extract Python relationships using ast-grep processor.
Args:
source_code: Python source code text
path: File path
Returns:
List of relationships, or None if ast-grep unavailable
"""
if self._astgrep_processor is None:
return None
if not self._astgrep_processor.is_available():
return None
try:
indexed = self._astgrep_processor.parse(source_code, path)
if indexed is not None:
return indexed.relationships
except Exception:
pass
return None
def _extract_python_relationships( def _extract_python_relationships(
self, self,
source_bytes: bytes, source_bytes: bytes,

View File

@@ -0,0 +1 @@
"""Tests for codexlens.parsers modules."""

View File

@@ -0,0 +1,444 @@
"""Tests for dedicated extraction methods: extract_inherits, extract_calls, extract_imports.
Tests pattern-based relationship extraction from Python source code
using ast-grep-py bindings for INHERITS, CALL, and IMPORTS relationships.
"""
from pathlib import Path
import pytest
from codexlens.parsers.astgrep_processor import (
AstGrepPythonProcessor,
is_astgrep_processor_available,
)
from codexlens.entities import RelationshipType
# Check if ast-grep is available for conditional test skipping
ASTGREP_AVAILABLE = is_astgrep_processor_available()
@pytest.mark.skipif(not ASTGREP_AVAILABLE, reason="ast-grep-py not installed")
class TestExtractInherits:
    """INHERITS relationship extraction via extract_inherits."""

    def test_single_inheritance(self):
        """A single base class yields exactly one Dog -> Animal edge."""
        proc = AstGrepPythonProcessor()
        code = """
class Animal:
    pass
class Dog(Animal):
    pass
"""
        rels = proc.extract_inherits(code, "test.py")
        assert len(rels) == 1
        edge = rels[0]
        assert edge.source_symbol == "Dog"
        assert edge.target_symbol == "Animal"
        assert edge.relationship_type == RelationshipType.INHERITS

    def test_multiple_inheritance(self):
        """Each base of C(A, B) produces its own INHERITS edge."""
        proc = AstGrepPythonProcessor()
        code = """
class A:
    pass
class B:
    pass
class C(A, B):
    pass
"""
        rels = proc.extract_inherits(code, "test.py")
        assert len(rels) == 2
        assert {r.target_symbol for r in rels} == {"A", "B"}
        assert all(r.source_symbol == "C" for r in rels)

    def test_no_inheritance(self):
        """A class with no bases contributes no relationships."""
        proc = AstGrepPythonProcessor()
        code = """
class Standalone:
    pass
"""
        assert proc.extract_inherits(code, "test.py") == []

    def test_nested_class_inheritance(self):
        """Inheritance inside a nested class is still detected."""
        proc = AstGrepPythonProcessor()
        code = """
class Outer:
    class Inner(Base):
        pass
"""
        rels = proc.extract_inherits(code, "test.py")
        assert len(rels) == 1
        assert (rels[0].source_symbol, rels[0].target_symbol) == ("Inner", "Base")

    def test_inheritance_with_complex_bases(self):
        """Dotted/attribute base classes are preserved verbatim."""
        proc = AstGrepPythonProcessor()
        code = """
class Service(BaseService, mixins.Loggable):
    pass
"""
        rels = proc.extract_inherits(code, "test.py")
        assert len(rels) == 2
        assert {r.target_symbol for r in rels} == {"BaseService", "mixins.Loggable"}
@pytest.mark.skipif(not ASTGREP_AVAILABLE, reason="ast-grep-py not installed")
class TestExtractCalls:
    """CALL relationship extraction via extract_calls."""

    def test_simple_function_call(self):
        """Plain builtin calls are reported as CALL targets."""
        proc = AstGrepPythonProcessor()
        code = """
def main():
    print("hello")
    len([1, 2, 3])
"""
        found = {r.target_symbol for r in proc.extract_calls(code, "test.py", "main")}
        assert "print" in found
        assert "len" in found

    def test_method_call(self):
        """Dotted method calls keep their receiver prefix."""
        proc = AstGrepPythonProcessor()
        code = """
def process():
    obj.method()
    items.append(1)
"""
        found = {r.target_symbol for r in proc.extract_calls(code, "test.py", "process")}
        assert "obj.method" in found
        assert "items.append" in found

    def test_skips_self_calls(self):
        """self.method() is an internal call and must not be reported."""
        proc = AstGrepPythonProcessor()
        code = """
class Service:
    def process(self):
        self.internal()
        external_func()
"""
        found = {r.target_symbol for r in proc.extract_calls(code, "test.py", "Service")}
        assert "self.internal" not in found
        assert "internal" not in found
        assert "external_func" in found

    def test_skips_cls_calls(self):
        """cls.method() is an internal call and must not be reported."""
        proc = AstGrepPythonProcessor()
        code = """
class Factory:
    @classmethod
    def create(cls):
        cls.helper()
        other_func()
"""
        found = {r.target_symbol for r in proc.extract_calls(code, "test.py", "Factory")}
        assert "cls.helper" not in found
        assert "other_func" in found

    def test_alias_resolution(self):
        """An import alias map rewrites call targets to qualified names."""
        proc = AstGrepPythonProcessor()
        code = """
def main():
    np.array([1, 2, 3])
"""
        rels = proc.extract_calls(code, "test.py", "main", {"np": "numpy"})
        assert len(rels) >= 1
        assert any("numpy.array" in r.target_symbol for r in rels)

    def test_no_calls(self):
        """Call-free code yields no relationships."""
        proc = AstGrepPythonProcessor()
        code = """
x = 1
y = x + 2
"""
        assert proc.extract_calls(code, "test.py") == []
@pytest.mark.skipif(not ASTGREP_AVAILABLE, reason="ast-grep-py not installed")
class TestExtractImports:
    """IMPORTS relationship extraction via extract_imports."""

    def test_simple_import(self):
        """`import os` yields one IMPORTS edge and a self-alias for `os`."""
        proc = AstGrepPythonProcessor()
        rels, aliases = proc.extract_imports("import os", "test.py")
        assert len(rels) == 1
        assert rels[0].target_symbol == "os"
        assert rels[0].relationship_type == RelationshipType.IMPORTS
        assert aliases.get("os") == "os"

    def test_import_with_alias(self):
        """`import numpy as np` records the np -> numpy alias."""
        proc = AstGrepPythonProcessor()
        rels, aliases = proc.extract_imports("import numpy as np", "test.py")
        assert len(rels) == 1
        assert rels[0].target_symbol == "numpy"
        assert aliases.get("np") == "numpy"

    def test_from_import(self):
        """from-imports alias each imported name to module.name."""
        proc = AstGrepPythonProcessor()
        rels, aliases = proc.extract_imports("from typing import List, Dict", "test.py")
        assert len(rels) == 1
        assert rels[0].target_symbol == "typing"
        assert aliases.get("List") == "typing.List"
        assert aliases.get("Dict") == "typing.Dict"

    def test_from_import_with_alias(self):
        """`as` aliases in from-imports map to the original name."""
        proc = AstGrepPythonProcessor()
        rels, aliases = proc.extract_imports(
            "from collections import defaultdict as dd", "test.py"
        )
        assert len(rels) == 1
        # dd should resolve to collections.defaultdict
        assert "dd" in aliases
        assert "defaultdict" in aliases.get("dd", "")

    def test_star_import(self):
        """Star imports are recorded with a `.*` target."""
        proc = AstGrepPythonProcessor()
        rels, _ = proc.extract_imports("from module import *", "test.py")
        assert len(rels) >= 1
        assert any("*" in r.target_symbol for r in rels)

    def test_relative_import(self):
        """Relative imports keep their leading-dot module path."""
        proc = AstGrepPythonProcessor()
        rels, _ = proc.extract_imports("from .utils import helper", "test.py")
        assert len(rels) >= 1
        assert any(r.target_symbol.startswith(".") for r in rels)

    def test_multiple_imports(self):
        """Mixed import styles in one file are all captured."""
        proc = AstGrepPythonProcessor()
        code = """
import os
import sys
from typing import List
from collections import defaultdict as dd
"""
        rels, _ = proc.extract_imports(code, "test.py")
        assert len(rels) >= 4
        found = {r.target_symbol for r in rels}
        assert {"os", "sys", "typing", "collections"} <= found

    def test_no_imports(self):
        """Import-free code yields no edges and an empty alias map."""
        proc = AstGrepPythonProcessor()
        code = """
x = 1
def foo():
    pass
"""
        rels, aliases = proc.extract_imports(code, "test.py")
        assert rels == []
        assert aliases == {}
@pytest.mark.skipif(not ASTGREP_AVAILABLE, reason="ast-grep-py not installed")
class TestExtractMethodsIntegration:
    """Integration tests combining multiple extraction methods."""

    def test_full_file_extraction(self):
        """All three extractors cooperate on a small but complete module."""
        processor = AstGrepPythonProcessor()
        source_file = "test.py"
        code = """
import os
from typing import List, Optional
class Base:
    pass
class Service(Base):
    def __init__(self):
        self.data = []
    def process(self):
        result = os.path.join("a", "b")
        items = List([1, 2, 3])
        return result
def main():
    svc = Service()
    svc.process()
"""
        # Run every extractor, feeding the import aliases into call resolution.
        imports, alias_map = processor.extract_imports(code, source_file)
        inherits = processor.extract_inherits(code, source_file)
        calls = processor.extract_calls(code, source_file, alias_map=alias_map)

        assert len(imports) >= 2   # os and typing
        assert len(inherits) == 1  # Service -> Base
        assert len(calls) >= 2     # os.path.join and others

        # The single inheritance edge must link Service to Base.
        edges = {(r.source_symbol, r.target_symbol) for r in inherits}
        assert ("Service", "Base") in edges

    def test_alias_propagation(self):
        """Import aliases (np -> numpy) must flow through to call targets."""
        processor = AstGrepPythonProcessor()
        source_file = "test.py"
        code = """
import numpy as np
def compute():
    arr = np.array([1, 2, 3])
    return np.sum(arr)
"""
        imports, alias_map = processor.extract_imports(code, source_file)
        calls = processor.extract_calls(code, source_file, alias_map=alias_map)

        # Alias map should have np -> numpy.
        assert alias_map.get("np") == "numpy"
        # Both np.array and np.sum should be captured, resolved or not.
        resolved = {r.target_symbol for r in calls}
        matching = [t for t in resolved if "np" in t or "numpy" in t]
        assert len(matching) >= 2
class TestExtractMethodFallback:
    """Tests for fallback behavior when ast-grep unavailable."""

    def test_extract_inherits_empty_when_unavailable(self):
        """extract_inherits degrades to an empty list without ast-grep."""
        processor = AstGrepPythonProcessor()
        if processor.is_available():
            return  # fallback path only exists when the binding is missing
        assert processor.extract_inherits("class Dog(Animal): pass", "test.py") == []

    def test_extract_calls_empty_when_unavailable(self):
        """extract_calls degrades to an empty list without ast-grep."""
        processor = AstGrepPythonProcessor()
        if processor.is_available():
            return
        assert processor.extract_calls("print('hello')", "test.py") == []

    def test_extract_imports_empty_when_unavailable(self):
        """extract_imports degrades to ([], {}) without ast-grep."""
        processor = AstGrepPythonProcessor()
        if processor.is_available():
            return
        relationships, alias_map = processor.extract_imports("import os", "test.py")
        assert relationships == []
        assert alias_map == {}
class TestHelperMethods:
    """Tests for internal helper methods."""

    def test_parse_base_classes_single(self):
        """A lone base class parses to a one-element list."""
        assert AstGrepPythonProcessor()._parse_base_classes("BaseClass") == ["BaseClass"]

    def test_parse_base_classes_multiple(self):
        """Comma-separated bases are split into individual names."""
        bases = AstGrepPythonProcessor()._parse_base_classes("A, B, C")
        assert bases == ["A", "B", "C"]

    def test_parse_base_classes_with_generics(self):
        """Subscripted generic bases survive the comma split intact."""
        bases = AstGrepPythonProcessor()._parse_base_classes("Generic[T], Mixin")
        assert "Generic[T]" in bases
        assert "Mixin" in bases

    def test_resolve_call_alias_simple(self):
        """A bare alias resolves to its real module name."""
        resolved = AstGrepPythonProcessor()._resolve_call_alias("np", {"np": "numpy"})
        assert resolved == "numpy"

    def test_resolve_call_alias_qualified(self):
        """The leading segment of a dotted call is rewritten via the alias map."""
        resolved = AstGrepPythonProcessor()._resolve_call_alias(
            "np.array", {"np": "numpy"}
        )
        assert resolved == "numpy.array"

    def test_resolve_call_alias_no_match(self):
        """Names with no alias entry pass through unchanged."""
        assert AstGrepPythonProcessor()._resolve_call_alias("myfunc", {}) == "myfunc"

View File

@@ -0,0 +1,402 @@
"""Tests for AstGrepPythonProcessor.
Tests pattern-based relationship extraction from Python source code
using ast-grep-py bindings.
"""
from pathlib import Path
import pytest
from codexlens.parsers.astgrep_processor import (
AstGrepPythonProcessor,
BaseAstGrepProcessor,
is_astgrep_processor_available,
)
from codexlens.parsers.patterns.python import (
PYTHON_PATTERNS,
METAVARS,
RELATIONSHIP_PATTERNS,
get_pattern,
get_patterns_for_relationship,
get_metavar,
)
# Check if ast-grep is available for conditional test skipping
ASTGREP_AVAILABLE = is_astgrep_processor_available()
class TestPatternDefinitions:
    """Tests for Python pattern definitions."""

    def test_python_patterns_exist(self):
        """Every pattern name the processor relies on must be registered."""
        required = (
            "class_def",
            "class_with_bases",
            "func_def",
            "async_func_def",
            "import_stmt",
            "import_from",
            "call",
            "method_call",
        )
        for name in required:
            assert name in PYTHON_PATTERNS, f"Missing pattern: {name}"

    def test_get_pattern_returns_correct_pattern(self):
        """get_pattern must return the exact registered pattern strings."""
        # Note: ast-grep-py 0.40+ uses $$$ for zero-or-more multi-match
        expected = {
            "class_def": "class $NAME $$$BODY",
            "func_def": "def $NAME($$$PARAMS): $$$BODY",
            "import_stmt": "import $MODULE",
        }
        for name, pattern in expected.items():
            assert get_pattern(name) == pattern

    def test_get_pattern_raises_for_unknown(self):
        """Unknown pattern names raise KeyError."""
        with pytest.raises(KeyError):
            get_pattern("nonexistent_pattern")

    def test_metavars_defined(self):
        """Verify metavariable mappings are defined."""
        for var in ("class_name", "func_name", "import_module", "call_func"):
            assert var in METAVARS, f"Missing metavar: {var}"

    def test_get_metavar(self):
        """get_metavar maps logical names onto their $-metavariable names."""
        assert get_metavar("class_name") == "NAME"
        assert get_metavar("func_name") == "NAME"
        assert get_metavar("import_module") == "MODULE"

    def test_relationship_patterns_mapping(self):
        """Each relationship kind maps onto its extraction pattern(s)."""
        assert "class_with_bases" in get_patterns_for_relationship("inheritance")
        assert "import_stmt" in get_patterns_for_relationship("imports")
        assert "import_from" in get_patterns_for_relationship("imports")
        assert "call" in get_patterns_for_relationship("calls")
class TestAstGrepPythonProcessorAvailability:
    """Tests for processor availability."""

    def test_is_available_returns_bool(self):
        """is_available must return a plain bool."""
        assert isinstance(AstGrepPythonProcessor().is_available(), bool)

    def test_is_available_matches_global_check(self):
        """Instance-level and module-level availability checks agree."""
        instance_says = AstGrepPythonProcessor().is_available()
        assert instance_says == is_astgrep_processor_available()

    def test_module_level_check(self):
        """Module-level availability function also returns a bool."""
        assert isinstance(is_astgrep_processor_available(), bool)
@pytest.mark.skipif(not ASTGREP_AVAILABLE, reason="ast-grep-py not installed")
class TestAstGrepPythonProcessorParsing:
    """Tests for Python parsing with ast-grep.

    parse() receives the source text directly; the Path argument is only
    used as file metadata (no file is read from disk).
    """

    def test_parse_simple_function(self):
        """Test parsing a simple function definition."""
        processor = AstGrepPythonProcessor()
        code = "def hello():\n pass"
        result = processor.parse(code, Path("test.py"))
        assert result is not None
        assert result.language == "python"
        assert len(result.symbols) == 1
        assert result.symbols[0].name == "hello"
        assert result.symbols[0].kind == "function"

    def test_parse_class(self):
        """Test parsing a class definition."""
        processor = AstGrepPythonProcessor()
        code = "class MyClass:\n pass"
        result = processor.parse(code, Path("test.py"))
        assert result is not None
        assert len(result.symbols) == 1
        assert result.symbols[0].name == "MyClass"
        assert result.symbols[0].kind == "class"

    def test_parse_async_function(self):
        """Test parsing an async function definition."""
        processor = AstGrepPythonProcessor()
        code = "async def fetch_data():\n pass"
        result = processor.parse(code, Path("test.py"))
        assert result is not None
        assert len(result.symbols) == 1
        assert result.symbols[0].name == "fetch_data"

    def test_parse_class_with_inheritance(self):
        """Test parsing class with inheritance."""
        processor = AstGrepPythonProcessor()
        code = """
class Base:
    pass
class Child(Base):
    pass
"""
        result = processor.parse(code, Path("test.py"))
        assert result is not None
        names = [s.name for s in result.symbols]
        assert "Base" in names
        assert "Child" in names
        # Check inheritance relationship
        inherits = [
            r for r in result.relationships
            if r.relationship_type.value == "inherits"
        ]
        assert any(r.source_symbol == "Child" for r in inherits)

    def test_parse_imports(self):
        """Test parsing import statements."""
        processor = AstGrepPythonProcessor()
        code = """
import os
from sys import path
"""
        result = processor.parse(code, Path("test.py"))
        assert result is not None
        imports = [
            r for r in result.relationships
            if r.relationship_type.value == "imports"
        ]
        assert len(imports) >= 1
        targets = {r.target_symbol for r in imports}
        assert "os" in targets

    def test_parse_function_calls(self):
        """Test parsing function calls."""
        processor = AstGrepPythonProcessor()
        code = """
def main():
    print("hello")
    len([1, 2, 3])
"""
        result = processor.parse(code, Path("test.py"))
        assert result is not None
        calls = [
            r for r in result.relationships
            if r.relationship_type.value == "calls"
        ]
        targets = {r.target_symbol for r in calls}
        assert "print" in targets
        assert "len" in targets

    def test_parse_empty_file(self):
        """Test parsing an empty file."""
        processor = AstGrepPythonProcessor()
        result = processor.parse("", Path("test.py"))
        # An empty file still yields a result object, just with no symbols.
        assert result is not None
        assert len(result.symbols) == 0

    def test_parse_returns_indexed_file(self):
        """Test that parse returns proper IndexedFile structure."""
        processor = AstGrepPythonProcessor()
        code = "def test():\n pass"
        result = processor.parse(code, Path("test.py"))
        assert result is not None
        assert result.path.endswith("test.py")
        assert result.language == "python"
        assert isinstance(result.symbols, list)
        assert isinstance(result.chunks, list)
        assert isinstance(result.relationships, list)
@pytest.mark.skipif(not ASTGREP_AVAILABLE, reason="ast-grep-py not installed")
class TestAstGrepPythonProcessorRelationships:
    """Tests for relationship extraction via the high-level parse() entry point."""

    def test_inheritance_extraction(self):
        """Test extraction of inheritance relationships."""
        processor = AstGrepPythonProcessor()
        code = """
class Animal:
    pass
class Dog(Animal):
    pass
class Cat(Animal):
    pass
"""
        result = processor.parse(code, Path("test.py"))
        assert result is not None
        inherits = [
            r for r in result.relationships
            if r.relationship_type.value == "inherits"
        ]
        # Should have 2 inheritance relationships (Dog->Animal, Cat->Animal)
        assert len(inherits) >= 2
        sources = {r.source_symbol for r in inherits}
        assert "Dog" in sources
        assert "Cat" in sources

    def test_call_extraction_skips_self(self):
        """Test that self.method() calls are filtered."""
        processor = AstGrepPythonProcessor()
        code = """
class Service:
    def process(self):
        self.internal()
        external_call()
def external_call():
    pass
"""
        result = processor.parse(code, Path("test.py"))
        assert result is not None
        calls = [
            r for r in result.relationships
            if r.relationship_type.value == "calls"
        ]
        targets = {r.target_symbol for r in calls}
        # self.internal should be filtered; the cross-symbol call survives.
        assert "self.internal" not in targets
        assert "external_call" in targets

    def test_import_with_alias_resolution(self):
        """Test import alias resolution in calls."""
        processor = AstGrepPythonProcessor()
        code = """
import os.path as osp
def main():
    osp.join("a", "b")
"""
        result = processor.parse(code, Path("test.py"))
        assert result is not None
        calls = [
            r for r in result.relationships
            if r.relationship_type.value == "calls"
        ]
        targets = {r.target_symbol for r in calls}
        # Should resolve osp to os.path
        assert any("os.path" in t for t in targets)
@pytest.mark.skipif(not ASTGREP_AVAILABLE, reason="ast-grep-py not installed")
class TestAstGrepPythonProcessorRunAstGrep:
    """Tests for run_ast_grep method.

    NOTE(review): these patterns omit the ':' that PYTHON_PATTERNS['func_def']
    uses ("def $NAME($$$PARAMS): $$$BODY") - confirm both spellings match
    under ast-grep's pattern matcher.
    """

    def test_run_ast_grep_returns_list(self):
        """run_ast_grep always returns a list, even for a single match."""
        processor = AstGrepPythonProcessor()
        code = "def hello():\n pass"
        # Fixed: the original pre-parsed via the private _binding attribute in a
        # conditional expression whose result was discarded - a no-op that only
        # coupled the test to an internal; run_ast_grep parses on its own.
        matches = processor.run_ast_grep(code, "def $NAME($$$PARAMS) $$$BODY")
        assert isinstance(matches, list)

    def test_run_ast_grep_finds_matches(self):
        """A function definition matches the func pattern at least once."""
        processor = AstGrepPythonProcessor()
        code = "def hello():\n pass"
        matches = processor.run_ast_grep(code, "def $NAME($$$PARAMS) $$$BODY")
        assert len(matches) >= 1

    def test_run_ast_grep_empty_code(self):
        """Empty source yields no matches."""
        processor = AstGrepPythonProcessor()
        matches = processor.run_ast_grep("", "def $NAME($$$PARAMS) $$$BODY")
        assert matches == []

    def test_run_ast_grep_no_matches(self):
        """A non-matching pattern yields an empty list, not None."""
        processor = AstGrepPythonProcessor()
        code = "x = 1"
        matches = processor.run_ast_grep(code, "class $NAME $$$BODY")
        assert matches == []
class TestAstGrepPythonProcessorFallback:
    """Tests for fallback behavior when ast-grep unavailable.

    These run regardless of availability; when the binding is present the
    bodies are skipped, so they only exercise the degraded code paths.
    """

    def test_parse_returns_none_when_unavailable(self):
        """parse gracefully yields None when the binding is missing."""
        processor = AstGrepPythonProcessor()
        if processor.is_available():
            return  # nothing to verify when ast-grep-py is installed
        assert processor.parse("def test():\n pass", Path("test.py")) is None

    def test_run_ast_grep_empty_when_unavailable(self):
        """run_ast_grep degrades to an empty match list."""
        processor = AstGrepPythonProcessor()
        if processor.is_available():
            return
        assert processor.run_ast_grep("code", "pattern") == []
class TestBaseAstGrepProcessor:
    """Tests for abstract base class."""

    def test_cannot_instantiate_base_class(self):
        """Instantiating the ABC directly must raise TypeError."""
        with pytest.raises(TypeError):
            BaseAstGrepProcessor("python")  # type: ignore[abstract]

    def test_subclass_implements_abstract_methods(self):
        """AstGrepPythonProcessor supplies callable overrides for the ABC."""
        processor = AstGrepPythonProcessor()
        for method_name in ("process_matches", "parse"):
            method = getattr(processor, method_name, None)
            assert method is not None
            assert callable(method)
class TestPatternIntegration:
    """Tests for pattern module integration with processor."""

    def test_processor_uses_pattern_module(self):
        """The processor module re-exports pattern access from patterns/python/."""
        from codexlens.parsers.astgrep_processor import get_pattern
        assert get_pattern("class_def") is not None
        assert get_pattern("func_def") is not None

    def test_pattern_consistency(self):
        """Every pattern name the processor needs resolves to a non-empty string."""
        needed = (
            "class_def",
            "class_with_bases",
            "func_def",
            "async_func_def",
            "import_stmt",
            "import_from",
            "call",
        )
        for pattern_name in needed:
            pattern = get_pattern(pattern_name)  # must not raise KeyError
            assert pattern is not None
            assert len(pattern) > 0

View File

@@ -0,0 +1,526 @@
"""Comparison tests for tree-sitter vs ast-grep Python relationship extraction.
Validates that both parsers produce consistent output for Python relationship
extraction (INHERITS, CALL, IMPORTS).
"""
from __future__ import annotations
from pathlib import Path
from typing import List, Set, Tuple
import pytest
from codexlens.config import Config
from codexlens.entities import CodeRelationship, RelationshipType
from codexlens.parsers.treesitter_parser import TreeSitterSymbolParser
# Sample Python code for testing relationship extraction
SAMPLE_PYTHON_CODE = '''
"""Module docstring."""
import os
import sys
from typing import List, Dict, Optional
from collections import defaultdict as dd
from pathlib import Path as PPath
class BaseClass:
"""Base class."""
def base_method(self):
pass
def another_method(self):
return self.base_method()
class Mixin:
"""Mixin class."""
def mixin_func(self):
return "mixin"
class ChildClass(BaseClass, Mixin):
"""Child class with multiple inheritance."""
def __init__(self):
super().__init__()
self.data = dd(list)
def process(self, items: List[str]) -> Dict[str, int]:
result = {}
for item in items:
result[item] = len(item)
return result
def call_external(self, path: str) -> Optional[str]:
p = PPath(path)
if p.exists():
return str(p.read_text())
return None
def standalone_function():
"""Standalone function."""
data = [1, 2, 3]
return sum(data)
async def async_function():
"""Async function."""
import asyncio
await asyncio.sleep(1)
'''
def relationship_to_tuple(rel: CodeRelationship) -> Tuple[str, str, str, int]:
    """Flatten a relationship into a hashable, comparable 4-tuple.

    Returns:
        (source_symbol, target_symbol, relationship_type, source_line)
    """
    rel_type = rel.relationship_type.value
    return (rel.source_symbol, rel.target_symbol, rel_type, rel.source_line)
def extract_relationship_tuples(
    relationships: List[CodeRelationship],
) -> Set[Tuple[str, str, str]]:
    """Project relationships onto line-number-free triples for set comparison.

    Returns:
        Set of (source_symbol, target_symbol, relationship_type) tuples
    """
    triples = set()
    for rel in relationships:
        triples.add(
            (rel.source_symbol, rel.target_symbol, rel.relationship_type.value)
        )
    return triples
def filter_by_type(
    relationships: List[CodeRelationship],
    rel_type: RelationshipType,
) -> List[CodeRelationship]:
    """Return only the relationships whose type equals rel_type, order preserved."""
    matching = []
    for rel in relationships:
        if rel.relationship_type == rel_type:
            matching.append(rel)
    return matching
class TestTreeSitterVsAstGrepComparison:
    """Compare tree-sitter and ast-grep Python relationship extraction.

    Each test parses SAMPLE_PYTHON_CODE twice - once with a default
    (tree-sitter-only) parser and once with use_astgrep=True - and compares
    the extracted relationships for consistency.
    """

    @pytest.fixture
    def sample_path(self, tmp_path: Path) -> Path:
        """Create a temporary Python file with sample code."""
        py_file = tmp_path / "sample.py"
        py_file.write_text(SAMPLE_PYTHON_CODE)
        return py_file

    @pytest.fixture
    def ts_parser_default(self) -> TreeSitterSymbolParser:
        """Create tree-sitter parser with default config (use_astgrep=False)."""
        config = Config()
        # Guard: the comparison is meaningless if the default ever flips.
        assert config.use_astgrep is False
        return TreeSitterSymbolParser("python", config=config)

    @pytest.fixture
    def ts_parser_astgrep(self) -> TreeSitterSymbolParser:
        """Create tree-sitter parser with ast-grep enabled."""
        config = Config()
        config.use_astgrep = True
        return TreeSitterSymbolParser("python", config=config)

    def test_parser_availability(self, ts_parser_default: TreeSitterSymbolParser) -> None:
        """Test that tree-sitter parser is available."""
        assert ts_parser_default.is_available()

    def test_astgrep_processor_initialization(
        self, ts_parser_astgrep: TreeSitterSymbolParser
    ) -> None:
        """Test that ast-grep processor is initialized when config enables it."""
        # The processor should be initialized (may be None if ast-grep-py not installed)
        # This test just verifies the initialization path works
        # NOTE(review): reaches into the private _config attribute - confirm this
        # stays stable or expose a public accessor.
        assert ts_parser_astgrep._config is not None
        assert ts_parser_astgrep._config.use_astgrep is True

    def _skip_if_astgrep_unavailable(
        self, ts_parser_astgrep: TreeSitterSymbolParser
    ) -> None:
        """Skip test if ast-grep is not available."""
        if ts_parser_astgrep._astgrep_processor is None:
            pytest.skip("ast-grep-py not installed")

    def test_parse_returns_valid_result(
        self,
        ts_parser_default: TreeSitterSymbolParser,
        sample_path: Path,
    ) -> None:
        """Test that parsing returns a valid IndexedFile."""
        source_code = sample_path.read_text()
        result = ts_parser_default.parse(source_code, sample_path)
        assert result is not None
        assert result.language == "python"
        assert len(result.symbols) > 0
        assert len(result.relationships) > 0

    def test_extracted_symbols_match(
        self,
        ts_parser_default: TreeSitterSymbolParser,
        ts_parser_astgrep: TreeSitterSymbolParser,
        sample_path: Path,
    ) -> None:
        """Test that both parsers extract similar symbols."""
        self._skip_if_astgrep_unavailable(ts_parser_astgrep)
        source_code = sample_path.read_text()
        result_ts = ts_parser_default.parse(source_code, sample_path)
        result_astgrep = ts_parser_astgrep.parse(source_code, sample_path)
        assert result_ts is not None
        assert result_astgrep is not None
        # Compare symbol names
        ts_symbols = {s.name for s in result_ts.symbols}
        astgrep_symbols = {s.name for s in result_astgrep.symbols}
        # Should have the same symbols (classes, functions, methods)
        assert ts_symbols == astgrep_symbols

    def test_inheritance_relationships(
        self,
        ts_parser_default: TreeSitterSymbolParser,
        ts_parser_astgrep: TreeSitterSymbolParser,
        sample_path: Path,
    ) -> None:
        """Test INHERITS relationship extraction consistency."""
        self._skip_if_astgrep_unavailable(ts_parser_astgrep)
        source_code = sample_path.read_text()
        result_ts = ts_parser_default.parse(source_code, sample_path)
        result_astgrep = ts_parser_astgrep.parse(source_code, sample_path)
        assert result_ts is not None
        assert result_astgrep is not None
        # Extract inheritance relationships
        ts_inherits = filter_by_type(result_ts.relationships, RelationshipType.INHERITS)
        astgrep_inherits = filter_by_type(
            result_astgrep.relationships, RelationshipType.INHERITS
        )
        # Line numbers are dropped so only the edges themselves are compared.
        ts_tuples = extract_relationship_tuples(ts_inherits)
        astgrep_tuples = extract_relationship_tuples(astgrep_inherits)
        # Both should detect ChildClass(BaseClass, Mixin)
        assert ts_tuples == astgrep_tuples
        # Verify specific inheritance relationships
        expected_inherits = {
            ("ChildClass", "BaseClass", "inherits"),
            ("ChildClass", "Mixin", "inherits"),
        }
        assert ts_tuples == expected_inherits

    def test_import_relationships(
        self,
        ts_parser_default: TreeSitterSymbolParser,
        ts_parser_astgrep: TreeSitterSymbolParser,
        sample_path: Path,
    ) -> None:
        """Test IMPORTS relationship extraction consistency."""
        self._skip_if_astgrep_unavailable(ts_parser_astgrep)
        source_code = sample_path.read_text()
        result_ts = ts_parser_default.parse(source_code, sample_path)
        result_astgrep = ts_parser_astgrep.parse(source_code, sample_path)
        assert result_ts is not None
        assert result_astgrep is not None
        # Extract import relationships
        ts_imports = filter_by_type(result_ts.relationships, RelationshipType.IMPORTS)
        astgrep_imports = filter_by_type(
            result_astgrep.relationships, RelationshipType.IMPORTS
        )
        ts_tuples = extract_relationship_tuples(ts_imports)
        astgrep_tuples = extract_relationship_tuples(astgrep_imports)
        # Compare - should be similar (may differ in exact module representation)
        # At minimum, both should detect the top-level imports
        # Only the top-level package is compared (e.g. "os" from "os.path").
        ts_modules = {t[1].split(".")[0] for t in ts_tuples}
        astgrep_modules = {t[1].split(".")[0] for t in astgrep_tuples}
        # Should have imports from: os, sys, typing, collections, pathlib
        expected_modules = {"os", "sys", "typing", "collections", "pathlib", "asyncio"}
        # Lenient check: passes when EITHER parser covers the expected set.
        assert ts_modules >= expected_modules or astgrep_modules >= expected_modules

    def test_call_relationships(
        self,
        ts_parser_default: TreeSitterSymbolParser,
        ts_parser_astgrep: TreeSitterSymbolParser,
        sample_path: Path,
    ) -> None:
        """Test CALL relationship extraction consistency."""
        self._skip_if_astgrep_unavailable(ts_parser_astgrep)
        source_code = sample_path.read_text()
        result_ts = ts_parser_default.parse(source_code, sample_path)
        result_astgrep = ts_parser_astgrep.parse(source_code, sample_path)
        assert result_ts is not None
        assert result_astgrep is not None
        # Extract call relationships
        ts_calls = filter_by_type(result_ts.relationships, RelationshipType.CALL)
        astgrep_calls = filter_by_type(
            result_astgrep.relationships, RelationshipType.CALL
        )
        # Calls may differ due to scope tracking differences
        # Just verify both parsers find call relationships
        assert len(ts_calls) > 0
        assert len(astgrep_calls) > 0
        # Verify specific calls that should be detected
        ts_call_targets = {r.target_symbol for r in ts_calls}
        astgrep_call_targets = {r.target_symbol for r in astgrep_calls}
        # Both should detect at least some common calls
        # (exact match not required due to scope tracking differences)
        common_targets = ts_call_targets & astgrep_call_targets
        assert len(common_targets) > 0

    def test_relationship_count_similarity(
        self,
        ts_parser_default: TreeSitterSymbolParser,
        ts_parser_astgrep: TreeSitterSymbolParser,
        sample_path: Path,
    ) -> None:
        """Test that relationship counts are similar (>95% consistency)."""
        self._skip_if_astgrep_unavailable(ts_parser_astgrep)
        source_code = sample_path.read_text()
        result_ts = ts_parser_default.parse(source_code, sample_path)
        result_astgrep = ts_parser_astgrep.parse(source_code, sample_path)
        assert result_ts is not None
        assert result_astgrep is not None
        ts_count = len(result_ts.relationships)
        astgrep_count = len(result_astgrep.relationships)
        # Calculate consistency percentage: min/max ratio of the two totals
        # (100% when both parsers report zero relationships).
        if max(ts_count, astgrep_count) == 0:
            consistency = 100.0
        else:
            consistency = (
                min(ts_count, astgrep_count) / max(ts_count, astgrep_count) * 100
            )
        # Require >95% consistency
        assert consistency >= 95.0, (
            f"Relationship consistency {consistency:.1f}% below 95% threshold "
            f"(tree-sitter: {ts_count}, ast-grep: {astgrep_count})"
        )

    def test_config_switch_affects_parser(
        self, sample_path: Path
    ) -> None:
        """Test that config.use_astgrep affects which parser is used."""
        config_default = Config()
        config_astgrep = Config()
        config_astgrep.use_astgrep = True
        parser_default = TreeSitterSymbolParser("python", config=config_default)
        parser_astgrep = TreeSitterSymbolParser("python", config=config_astgrep)
        # Default parser should not have ast-grep processor
        assert parser_default._astgrep_processor is None
        # Ast-grep parser may have processor if ast-grep-py is installed
        # (could be None if not installed, which is fine)
        if parser_astgrep._astgrep_processor is not None:
            # If available, verify it's the right type
            from codexlens.parsers.astgrep_processor import AstGrepPythonProcessor
            assert isinstance(
                parser_astgrep._astgrep_processor, AstGrepPythonProcessor
            )

    def test_fallback_to_treesitter_on_astgrep_failure(
        self,
        ts_parser_astgrep: TreeSitterSymbolParser,
        sample_path: Path,
    ) -> None:
        """Test that parser falls back to tree-sitter if ast-grep fails."""
        source_code = sample_path.read_text()
        # Even with use_astgrep=True, should get valid results
        result = ts_parser_astgrep.parse(source_code, sample_path)
        # Should always return a valid result (either from ast-grep or tree-sitter fallback)
        assert result is not None
        assert result.language == "python"
        assert len(result.relationships) > 0
class TestSimpleCodeSamples:
    """Test with simple code samples for precise comparison.

    Unlike the fixture-driven suite above, these tests build both parsers
    inline and pass the source text directly; Path("test.py") is only used
    as metadata - nothing is written to disk.
    """

    def test_simple_inheritance(self) -> None:
        """Test simple single inheritance."""
        code = """
class Parent:
    pass
class Child(Parent):
    pass
"""
        self._compare_parsers(code, expected_inherits={("Child", "Parent")})

    def test_multiple_inheritance(self) -> None:
        """Test multiple inheritance."""
        code = """
class A:
    pass
class B:
    pass
class C(A, B):
    pass
"""
        self._compare_parsers(
            code, expected_inherits={("C", "A"), ("C", "B")}
        )

    def test_simple_imports(self) -> None:
        """Test simple import statements."""
        code = """
import os
import sys
"""
        config_ts = Config()
        config_ag = Config()
        config_ag.use_astgrep = True
        parser_ts = TreeSitterSymbolParser("python", config=config_ts)
        parser_ag = TreeSitterSymbolParser("python", config=config_ag)
        tmp_path = Path("test.py")
        result_ts = parser_ts.parse(code, tmp_path)
        result_ag = parser_ag.parse(code, tmp_path)
        assert result_ts is not None
        # ast-grep result may be None if not installed
        if result_ag is not None:
            ts_imports = {
                r.target_symbol
                for r in result_ts.relationships
                if r.relationship_type == RelationshipType.IMPORTS
            }
            ag_imports = {
                r.target_symbol
                for r in result_ag.relationships
                if r.relationship_type == RelationshipType.IMPORTS
            }
            # Exact agreement is required for this trivial sample.
            assert ts_imports == ag_imports

    def test_imports_inside_function(self) -> None:
        """Test simple import inside a function scope is recorded.

        Note: tree-sitter parser requires a scope to record imports.
        Module-level imports without any function/class are not recorded
        because scope_stack is empty at module level.
        """
        code = """
def my_function():
    import collections
    return collections
"""
        config_ts = Config()
        config_ag = Config()
        config_ag.use_astgrep = True
        parser_ts = TreeSitterSymbolParser("python", config=config_ts)
        parser_ag = TreeSitterSymbolParser("python", config=config_ag)
        tmp_path = Path("test.py")
        result_ts = parser_ts.parse(code, tmp_path)
        result_ag = parser_ag.parse(code, tmp_path)
        assert result_ts is not None
        # Get import relationship targets
        ts_imports = [
            r.target_symbol
            for r in result_ts.relationships
            if r.relationship_type == RelationshipType.IMPORTS
        ]
        # Should have collections (substring match tolerates qualification)
        ts_has_collections = any("collections" in t for t in ts_imports)
        assert ts_has_collections, f"Expected collections import, got: {ts_imports}"
        # If ast-grep is available, verify it also finds the imports
        if result_ag is not None:
            ag_imports = [
                r.target_symbol
                for r in result_ag.relationships
                if r.relationship_type == RelationshipType.IMPORTS
            ]
            ag_has_collections = any("collections" in t for t in ag_imports)
            assert ag_has_collections, f"Expected collections import in ast-grep, got: {ag_imports}"

    def _compare_parsers(
        self,
        code: str,
        expected_inherits: Set[Tuple[str, str]],
    ) -> None:
        """Helper to compare parser outputs for inheritance.

        Asserts tree-sitter output exactly matches expected_inherits; the
        ast-grep comparison is only performed when that parser is available.
        """
        config_ts = Config()
        config_ag = Config()
        config_ag.use_astgrep = True
        parser_ts = TreeSitterSymbolParser("python", config=config_ts)
        parser_ag = TreeSitterSymbolParser("python", config=config_ag)
        tmp_path = Path("test.py")
        result_ts = parser_ts.parse(code, tmp_path)
        assert result_ts is not None
        # Verify tree-sitter finds expected inheritance
        ts_inherits = {
            (r.source_symbol, r.target_symbol)
            for r in result_ts.relationships
            if r.relationship_type == RelationshipType.INHERITS
        }
        assert ts_inherits == expected_inherits
        # If ast-grep is available, verify it matches
        result_ag = parser_ag.parse(code, tmp_path)
        if result_ag is not None:
            ag_inherits = {
                (r.source_symbol, r.target_symbol)
                for r in result_ag.relationships
                if r.relationship_type == RelationshipType.INHERITS
            }
            assert ag_inherits == expected_inherits
# Allow running this test module directly (without pytest discovery):
# `python <this_file>.py`
if __name__ == "__main__":
    pytest.main([__file__, "-v"])

View File

@@ -0,0 +1,191 @@
"""Tests for ast-grep binding module.
Verifies basic import and functionality of AstGrepBinding.
Run with: python -m pytest tests/test_astgrep_binding.py -v
"""
from __future__ import annotations
import pytest
from pathlib import Path
class TestAstGrepBindingAvailability:
    """Test availability checks."""

    def test_is_astgrep_available_function(self):
        """is_astgrep_available returns a boolean."""
        from codexlens.parsers.astgrep_binding import is_astgrep_available

        assert isinstance(is_astgrep_available(), bool)

    def test_get_supported_languages(self):
        """get_supported_languages lists the core languages."""
        from codexlens.parsers.astgrep_binding import get_supported_languages

        languages = get_supported_languages()
        assert isinstance(languages, list)
        # The three languages the binding is primarily used with.
        for expected in ("python", "javascript", "typescript"):
            assert expected in languages
class TestAstGrepBindingInit:
    """Test AstGrepBinding initialization."""

    def test_init_python(self):
        """Initialization with Python stores the language id."""
        from codexlens.parsers.astgrep_binding import AstGrepBinding

        assert AstGrepBinding("python").language_id == "python"

    def test_init_typescript_with_tsx(self):
        """A .tsx path still maps to the typescript language id."""
        from codexlens.parsers.astgrep_binding import AstGrepBinding

        binding = AstGrepBinding("typescript", Path("component.tsx"))
        assert binding.language_id == "typescript"

    def test_is_available_returns_boolean(self):
        """is_available reports availability as a boolean."""
        from codexlens.parsers.astgrep_binding import AstGrepBinding

        availability = AstGrepBinding("python").is_available()
        assert isinstance(availability, bool)
def _is_astgrep_installed():
"""Check if ast-grep-py is installed."""
try:
import ast_grep_py # noqa: F401
return True
except ImportError:
return False
@pytest.mark.skipif(
    not _is_astgrep_installed(),
    reason="ast-grep-py not installed"
)
class TestAstGrepBindingWithAstGrep:
    """Tests that require ast-grep-py to be installed.

    Each test also double-checks ``binding.is_available()`` at runtime so a
    half-installed ast-grep-py (import succeeds, language load fails) skips
    cleanly instead of erroring.
    """

    def test_parse_simple_python(self):
        """Parsing a trivial Python snippet succeeds."""
        from codexlens.parsers.astgrep_binding import AstGrepBinding
        binding = AstGrepBinding("python")
        if not binding.is_available():
            pytest.skip("ast-grep not available")
        source = "x = 1"
        result = binding.parse(source)
        assert result is True

    def test_find_inheritance(self):
        """Finding class inheritance returns a list of matches."""
        from codexlens.parsers.astgrep_binding import AstGrepBinding
        binding = AstGrepBinding("python")
        if not binding.is_available():
            pytest.skip("ast-grep not available")
        source = """
class MyClass(BaseClass):
    pass
"""
        binding.parse(source)
        results = binding.find_inheritance()
        # Fixed: the previous `len(results) >= 0` could never fail (a
        # tautology).  Assert the contract the other finder tests assert:
        # the result is a list.  Whether the pattern matches this snippet
        # may depend on the grammar version, so we don't require len >= 1.
        assert isinstance(results, list)

    def test_find_calls(self):
        """Finding function calls returns a list of matches."""
        from codexlens.parsers.astgrep_binding import AstGrepBinding
        binding = AstGrepBinding("python")
        if not binding.is_available():
            pytest.skip("ast-grep not available")
        source = """
def foo():
    bar()
    baz.qux()
"""
        binding.parse(source)
        results = binding.find_calls()
        assert isinstance(results, list)

    def test_find_imports(self):
        """Finding import statements returns a list of matches."""
        from codexlens.parsers.astgrep_binding import AstGrepBinding
        binding = AstGrepBinding("python")
        if not binding.is_available():
            pytest.skip("ast-grep not available")
        source = """
import os
from typing import List
"""
        binding.parse(source)
        results = binding.find_imports()
        assert isinstance(results, list)
def test_basic_import():
    """Test that the module can be imported."""
    # A failed import is converted into an explicit test failure with the
    # underlying reason, instead of a collection-time error.
    try:
        from codexlens.parsers.astgrep_binding import (  # noqa: F401
            ASTGREP_AVAILABLE,
            AstGrepBinding,
            get_supported_languages,
            is_astgrep_available,
        )
    except ImportError as e:
        pytest.fail(f"Failed to import astgrep_binding: {e}")
def test_availability_flag():
    """Test ASTGREP_AVAILABLE flag is defined."""
    from codexlens.parsers.astgrep_binding import ASTGREP_AVAILABLE

    flag = ASTGREP_AVAILABLE
    assert isinstance(flag, bool)
if __name__ == "__main__":
    # Run basic verification
    # Manual smoke test: exercises the binding end-to-end and prints what it
    # finds, without requiring pytest to be invoked.
    print("Testing astgrep_binding module...")
    from codexlens.parsers.astgrep_binding import (
        AstGrepBinding,
        is_astgrep_available,
        get_supported_languages,
    )
    print(f"ast-grep available: {is_astgrep_available()}")
    print(f"Supported languages: {get_supported_languages()}")
    binding = AstGrepBinding("python")
    print(f"Python binding available: {binding.is_available()}")
    if binding.is_available():
        # Sample covering imports, inheritance, method calls, and a free
        # function, so every finder below has something to report.
        test_code = """
import os
from typing import List
class MyClass(BaseClass):
    def method(self):
        self.helper()
        external_func()
def helper():
    pass
"""
        binding.parse(test_code)
        print(f"Inheritance found: {binding.find_inheritance()}")
        print(f"Calls found: {binding.find_calls()}")
        print(f"Imports found: {binding.find_imports()}")
    else:
        # ast-grep-py is an optional dependency; explain how to get it.
        print("Note: ast-grep-py not installed. To install:")
        print("  pip install ast-grep-py")
        print("  Note: May have compatibility issues with Python 3.13")
    print("Basic verification complete!")