diff --git a/.gitignore b/.gitignore index fcfac327..d9803c3a 100644 --- a/.gitignore +++ b/.gitignore @@ -18,6 +18,7 @@ logs/ .DS_Store .codebuddy .vscode +uv.lock node_modules runtime_config.json .idea @@ -51,4 +52,3 @@ src/stock_datasource/modules/auth/auth.db # Frontend build info (regenerated) frontend/tsconfig.tsbuildinfo -.deploy-backup/ diff --git a/data/email.txt b/data/email.txt index ae65c081..6e27c77c 100644 --- a/data/email.txt +++ b/data/email.txt @@ -1 +1,3 @@ 1232@qq.com +feng.yangle@163.com +752282360@qq.com diff --git a/docker-compose.weknora.yml b/docker-compose.weknora.yml index 2ef2cf6a..0fa91ebc 100644 --- a/docker-compose.weknora.yml +++ b/docker-compose.weknora.yml @@ -33,6 +33,7 @@ services: # Redis REDIS_HOST: weknora-redis REDIS_PORT: 6379 + REDIS_ADDR: weknora-redis:6379 REDIS_PASSWORD: ${WEKNORA_REDIS_PASSWORD:-} STREAM_MANAGER_TYPE: redis # Vector store @@ -72,8 +73,9 @@ services: ports: - "${WEKNORA_FRONTEND_PORT:-18881}:80" environment: - BACKEND_HOST: weknora-app - BACKEND_PORT: 8080 + APP_HOST: weknora-app + APP_PORT: 8080 + APP_SCHEME: http networks: - weknora-network depends_on: diff --git a/frontend/src/api/portfolio.ts b/frontend/src/api/portfolio.ts index ea716b19..71acd9e8 100644 --- a/frontend/src/api/portfolio.ts +++ b/frontend/src/api/portfolio.ts @@ -1,11 +1,15 @@ import request from '@/utils/request' -import type { - Position, - PortfolioSummary, +import type { + Position, + PortfolioSummary, CreatePositionRequest, UpdatePositionRequest, AnalysisReport, - AlertCreateRequest + AlertCreateRequest, + Transaction, + CreateTransactionRequest, + TransactionSignal, + KlinePattern } from '@/types/portfolio' export const portfolioApi = { @@ -68,5 +72,28 @@ export const portfolioApi = { // Batch operations batchUpdatePrices() { return request.post('/api/portfolio/batch/update-prices') + }, + + // Transactions + buyTransaction(data: CreateTransactionRequest) { + return request.post('/api/portfolio/transactions/buy', data) + }, + + sellTransaction(data: CreateTransactionRequest) { + return request.post('/api/portfolio/transactions/sell', data) + }, + + getTransactions(params?: { ts_code?: string; start_date?: string; end_date?: string; profile_id?: string }) { + return request.get('/api/portfolio/transactions', { params }) + }, + + // Transaction signals for K-line markers + getTransactionSignals(params: { ts_code: string; start_date?: string; end_date?: string }) { + return request.get('/api/portfolio/transactions/signals', { params }) + }, + + // K-line candlestick patterns + getKlinePatterns(tsCode: string, days: number = 60) { + return request.get(`/api/portfolio/kline-patterns/${tsCode}`, { params: { days } }) } } \ No newline at end of file diff --git a/frontend/src/components/StockDetailDialog.vue b/frontend/src/components/StockDetailDialog.vue index fe21cfcc..69b050bb 100644 --- a/frontend/src/components/StockDetailDialog.vue +++ b/frontend/src/components/StockDetailDialog.vue @@ -5,12 +5,14 @@ import { marketApi } from '@/api/market' import type { KLineResponse } from '@/api/market' import { usePortfolioStore } from '@/stores/portfolio' import { useScreenerStore } from '@/stores/screener' +import { portfolioApi } from '@/api/portfolio' import KLineChart from '@/components/charts/KLineChart.vue' import ChipDistributionChart from '@/components/charts/ChipDistributionChart.vue' import IndicatorPanel from '@/views/market/components/IndicatorPanel.vue' import TrendAnalysis from '@/views/market/components/TrendAnalysis.vue' import type { KLineData, TechnicalSignal, ChipData, ChipStats } from '@/types/common' import type { StockProfile } from '@/api/screener' +import type { KlinePattern } from '@/types/portfolio' interface Props { visible: boolean @@ -32,6 +34,7 @@ const klineData = ref([]) const indicators = ref>({}) const indicatorDates = ref([]) const signals = ref([]) +const transactionSignals = ref([]) const trendAnalysis = ref(null) const loading = ref(false) const chartLoading = ref(false) @@ -61,11 +64,16 @@ const adjustType = ref<'qfq' | 'hfq' | 'none'>('qfq') const selectedIndicators = ref(['MA', 'MACD', 'RSI', 'KDJ']) const chartMode = ref<'daily' | 'minute'>('daily') // 日K / 分钟K切换 -// Add to watchlist form -const addToWatchlistForm = ref({ +// K-line pattern recognition +const klinePatterns = ref([]) +const klinePatternsLoading = ref(false) + +// Transaction form +const transactionType = ref<'buy' | 'sell'>('buy') +const transactionForm = ref({ quantity: 100, - cost_price: 0, - buy_date: new Date().toISOString().split('T')[0], + price: 0, + transaction_date: new Date().toISOString().split('T')[0], notes: '' }) @@ -176,12 +184,18 @@ const fetchStockData = async () => { await triggerBackfill() } - // Set default cost price to latest close price - addToWatchlistForm.value.cost_price = latestPrice.value + // Set default transaction price to latest close price + transactionForm.value.price = latestPrice.value // Fetch indicators await fetchIndicators() + // Fetch transaction signals for B/S markers + fetchTransactionSignals() + + // Fetch K-line patterns + fetchKlinePatterns() + } catch (error) { console.error('Failed to fetch stock data:', error) MessagePlugin.error('获取股票数据失败') @@ -208,6 +222,58 @@ const fetchIndicators = async () => { } } +// Fetch transaction signals (user buy/sell + strategy signals) for B/S markers +const fetchTransactionSignals = async () => { + if (!props.stockCode) return + + try { + const response = await portfolioApi.getTransactionSignals({ + ts_code: props.stockCode + }) + transactionSignals.value = Array.isArray(response) ? response : [] + } catch (error) { + transactionSignals.value = [] + } +} + +// Fetch K-line candlestick patterns +const fetchKlinePatterns = async () => { + if (!props.stockCode) return + + klinePatternsLoading.value = true + try { + const response = await portfolioApi.getKlinePatterns(props.stockCode, period.value) + klinePatterns.value = Array.isArray(response) ? response : [] + } catch (error) { + console.error('Failed to fetch kline patterns:', error) + klinePatterns.value = [] + } finally { + klinePatternsLoading.value = false + } +} + +// Group patterns by date for display +const patternsByDate = computed(() => { + const grouped: Record = {} + for (const p of klinePatterns.value) { + if (!grouped[p.date]) grouped[p.date] = [] + grouped[p.date].push(p) + } + // Sort dates descending (most recent first) + return Object.entries(grouped) + .sort(([a], [b]) => b.localeCompare(a)) + .map(([date, patterns]) => ({ date, patterns })) +}) + +// Combined signals for K-line chart: merge technical + transaction signals +const combinedSignals = computed(() => { + const technical = signals.value.map(s => ({ + ...s, + source: 'strategy' as const + })) + return [...technical, ...transactionSignals.value] +}) + // Fetch realtime minute K-line data const fetchMinuteKline = async () => { if (!props.stockCode) return @@ -387,23 +453,29 @@ const handleAIAnalyze = () => { } } -const handleAddToWatchlist = async () => { +const handleTransaction = async () => { if (!stockInfo.value) return try { - await portfolioStore.addPosition({ + const data = { ts_code: stockInfo.value.code, - quantity: addToWatchlistForm.value.quantity, - cost_price: addToWatchlistForm.value.cost_price, - buy_date: addToWatchlistForm.value.buy_date, - notes: addToWatchlistForm.value.notes || `从智能选股添加 - ${stockInfo.value.name}` - }) + quantity: transactionForm.value.quantity, + price: transactionForm.value.price, + transaction_date: transactionForm.value.transaction_date, + notes: transactionForm.value.notes || `${transactionType.value === 'buy' ? '买入' : '卖出'} - ${stockInfo.value.name}` + } - MessagePlugin.success(`已将 ${stockInfo.value.name} 添加到自选股`) - emit('close') - } catch (error) { - console.error('Failed to add to watchlist:', error) - MessagePlugin.error('添加自选股失败') + if (transactionType.value === 'buy') { + await portfolioStore.buyTransaction(data) + } else { + await portfolioStore.sellTransaction(data) + } + + MessagePlugin.success(`${transactionType.value === 'buy' ? '买入' : '卖出'} ${stockInfo.value.name} 记录成功`) + } catch (error: any) { + console.error('Failed to record transaction:', error) + const msg = error?.response?.data?.detail || '交易记录失败' + MessagePlugin.error(msg) } } @@ -496,6 +568,8 @@ watch(() => props.stockCode, (newCode) => { indicators.value = {} indicatorDates.value = [] signals.value = [] + transactionSignals.value = [] + klinePatterns.value = [] trendAnalysis.value = null chipData.value = [] chipStats.value = null @@ -701,27 +775,38 @@ watch(() => props.visible, (visible) => { :data="chartMode === 'minute' ? minuteKlineData : klineData" :indicators="chartMode === 'minute' ? {} : indicators" :indicator-dates="chartMode === 'minute' ? [] : indicatorDates" - :signals="chartMode === 'minute' ? [] : signals" + :signals="chartMode === 'minute' ? [] : combinedSignals" :loading="chartMode === 'minute' ? minuteKlineLoading : chartLoading" height="100%" /> - - - + + + + + + + 买入 + + + 卖出 + + + + - + props.visible, (visible) => { /> - + @@ -746,12 +831,12 @@ watch(() => props.visible, (visible) => { - - 加入自选 + + {{ transactionType === 'buy' ? '买入' : '卖出' }} @@ -781,6 +866,38 @@ watch(() => props.visible, (visible) => { class="trend-analysis-compact" /> + + + + + + +
+
+
{{ group.date }}
+
+ + {{ pattern.name }} + +
+
+
+ +
@@ -1124,6 +1241,58 @@ watch(() => props.visible, (visible) => { min-height: 0; overflow-y: auto; } + +/* Pattern Card */ +.pattern-card { + background: #fafafa; + flex-shrink: 0; + max-height: 240px; + display: flex; + flex-direction: column; +} + +.pattern-card :deep(.t-card__body) { + flex: 1; + overflow-y: auto; + padding: 8px 12px; +} + +.pattern-card-header { + display: flex; + align-items: center; + justify-content: space-between; + width: 100%; +} + +.pattern-list { + display: flex; + flex-direction: column; + gap: 8px; +} + +.pattern-group { + display: flex; + align-items: flex-start; + gap: 8px; +} + +.pattern-date { + font-size: 11px; + color: #999; + white-space: nowrap; + min-width: 70px; + padding-top: 2px; +} + +.pattern-items { + display: flex; + flex-wrap: wrap; + gap: 4px; +} + +.pattern-tag { + font-size: 11px; +} diff --git a/frontend/src/components/charts/KLineChart.vue b/frontend/src/components/charts/KLineChart.vue index 6e6dd72b..45ec6c10 100644 --- a/frontend/src/components/charts/KLineChart.vue +++ b/frontend/src/components/charts/KLineChart.vue @@ -118,22 +118,47 @@ const buildSignalMarkPoint = (signals: any[] | undefined, dates: string[]) => { const sellPoints: any[] = [] for (const signal of signals) { - const dateStr = signal.date || signal.trade_date - const dateIndex = dates.indexOf(dateStr) + const dateStr = signal.date || signal.trade_date || signal.signal_date + // Try exact match, then YYYY-MM-DD format match + let dateIndex = dates.indexOf(dateStr) + if (dateIndex < 0 && dateStr && dateStr.length >= 10) { + dateIndex = dates.indexOf(dateStr.substring(0, 10)) + } if (dateIndex < 0) continue const signalType = (signal.signal_type || signal.type || '').toLowerCase() + const source = signal.source || 'strategy' + const isUser = source === 'user' + if (signalType.includes('buy') || signalType === 'b' || signalType === 'golden_cross' || signalType === 'oversold') { buyPoints.push({ - coord: [dateStr, signal.price || signal.close], - value: 'B', - itemStyle: { color: '#e74c3c' } + coord: [dates[dateIndex], signal.price || signal.close], + value: isUser ? '买' : 'B', + itemStyle: { color: isUser ? '#f5222d' : '#e74c3c' }, + symbol: isUser ? 'triangle' : 'pin', + symbolSize: isUser ? 20 : 30, + label: { + show: true, + fontSize: isUser ? 9 : 10, + fontWeight: 'bold', + color: '#fff', + formatter: (params: any) => params.value + } }) } else if (signalType.includes('sell') || signalType === 's' || signalType === 'death_cross' || signalType === 'overbought') { sellPoints.push({ - coord: [dateStr, signal.price || signal.close], - value: 'S', - itemStyle: { color: '#2ecc71' } + coord: [dates[dateIndex], signal.price || signal.close], + value: isUser ? '卖' : 'S', + itemStyle: { color: isUser ? '#52c41a' : '#2ecc71' }, + symbol: isUser ? 'triangle' : 'pin', + symbolSize: isUser ? 20 : 30, + label: { + show: true, + fontSize: isUser ? 9 : 10, + fontWeight: 'bold', + color: '#fff', + formatter: (params: any) => params.value + } }) } } diff --git a/frontend/src/stores/portfolio.ts b/frontend/src/stores/portfolio.ts index 22515b1c..9995e250 100644 --- a/frontend/src/stores/portfolio.ts +++ b/frontend/src/stores/portfolio.ts @@ -2,13 +2,14 @@ import { defineStore } from 'pinia' import { ref } from 'vue' import { portfolioApi } from '@/api/portfolio' import { profileApi } from '@/api/profile' -import type { Position, PortfolioSummary, AnalysisReport, CreatePositionRequest } from '@/types/portfolio' +import type { Position, PortfolioSummary, AnalysisReport, CreatePositionRequest, Transaction, CreateTransactionRequest } from '@/types/portfolio' import type { BrokerProfile } from '@/api/profile' export const usePortfolioStore = defineStore('portfolio', () => { const positions = ref([]) const summary = ref(null) const analysis = ref(null) + const transactions = ref([]) const loading = ref(false) // Profile state @@ -134,10 +135,53 @@ export const usePortfolioStore = defineStore('portfolio', () => { await fetchProfiles() } + // ---- Transaction actions ---- + const buyTransaction = async (data: CreateTransactionRequest) => { + loading.value = true + try { + const payload = { ...data } + if (activeProfileId.value && !payload.profile_id) { + payload.profile_id = activeProfileId.value + } + await portfolioApi.buyTransaction(payload) + await Promise.all([fetchPositions(), fetchSummary(), fetchTransactions()]) + } finally { + loading.value = false + } + } + + const sellTransaction = async (data: CreateTransactionRequest) => { + loading.value = true + try { + const payload = { ...data } + if (activeProfileId.value && !payload.profile_id) { + payload.profile_id = activeProfileId.value + } + await portfolioApi.sellTransaction(payload) + await Promise.all([fetchPositions(), fetchSummary(), fetchTransactions()]) + } finally { + loading.value = false + } + } + + const fetchTransactions = async (params?: { ts_code?: string; start_date?: string; end_date?: string; profile_id?: string }) => { + try { + const queryParams: Record = { ...params } as Record + if (activeProfileId.value) { + queryParams.profile_id = activeProfileId.value + } + const response = await portfolioApi.getTransactions(queryParams) + transactions.value = Array.isArray(response) ? response : [] + } catch (e) { + transactions.value = [] + } + } + return { positions, summary, analysis, + transactions, loading, profiles, activeProfileId, @@ -151,6 +195,9 @@ export const usePortfolioStore = defineStore('portfolio', () => { setActiveProfile, createProfile, updateProfile, - deleteProfile + deleteProfile, + buyTransaction, + sellTransaction, + fetchTransactions } }) diff --git a/frontend/src/types/portfolio.ts b/frontend/src/types/portfolio.ts index 8358e686..2214a02a 100644 --- a/frontend/src/types/portfolio.ts +++ b/frontend/src/types/portfolio.ts @@ -96,4 +96,49 @@ export interface ProfitHistoryItem { total_value: number total_cost: number total_profit: number +} + +export interface Transaction { + id: string + user_id: string + ts_code: string + stock_name: string + transaction_type: 'buy' | 'sell' + quantity: number + price: number + transaction_date: string + position_id: string + realized_pl: number | null + notes: string + profile_id: string + created_at: string | null +} + +export interface CreateTransactionRequest { + ts_code: string + quantity: number + price: number + transaction_date: string + notes?: string + profile_id?: string +} + +export interface TransactionSignal { + id: string + ts_code: string + signal_type: 'buy' | 'sell' + source: 'user' | 'strategy' + signal_date: string + price: number + quantity?: number + strategy_name?: string + notes?: string +} + +export interface KlinePattern { + name: string + name_en: string + date: string + type: 'bullish' | 'bearish' | 'neutral' + category: 'single' | 'dual' | 'triple' } \ No newline at end of file diff --git a/frontend/src/views/portfolio/PortfolioView.vue b/frontend/src/views/portfolio/PortfolioView.vue index f86a9851..af5ae4c9 100644 --- a/frontend/src/views/portfolio/PortfolioView.vue +++ b/frontend/src/views/portfolio/PortfolioView.vue @@ -7,6 +7,7 @@ import request from '@/utils/request' import DataEmptyGuide from '@/components/DataEmptyGuide.vue' import StockDetailDialog from '@/components/StockDetailDialog.vue' import type { BrokerProfile } from '@/api/profile' +import type { Transaction } from '@/types/portfolio' import { isTradingTime } from '@/composables/useRealtimePolling' const portfolioStore = usePortfolioStore() @@ -287,6 +288,17 @@ const positionColumns = [ { colKey: 'operation', title: '操作', width: 80 } ] +const transactionColumns = [ + { colKey: 'transaction_date', title: '日期', width: 100 }, + { colKey: 'ts_code', title: '代码', width: 100 }, + { colKey: 'stock_name', title: '名称', width: 90 }, + { colKey: 'transaction_type', title: '类型', width: 70 }, + { colKey: 'quantity', title: '数量', width: 70 }, + { colKey: 'price', title: '价格', width: 85 }, + { colKey: 'realized_pl', title: '已实现盈亏', width: 100 }, + { colKey: 'notes', title: '备注', width: 150 } +] + const watchlistColumns = [ { colKey: 'ts_code', title: '代码', width: 100 }, { colKey: 'stock_name', title: '名称', width: 100 }, @@ -354,6 +366,7 @@ const refreshData = () => { portfolioStore.fetchPositions() portfolioStore.fetchSummary() portfolioStore.fetchAnalysis() + portfolioStore.fetchTransactions() } onMounted(() => { @@ -561,7 +574,34 @@ onUnmounted(() => { - + + + + + + + + + + + diff --git a/src/stock_datasource/cli/server_manager.py b/src/stock_datasource/cli/server_manager.py index 9bacc359..511a88cb 100644 --- a/src/stock_datasource/cli/server_manager.py +++ b/src/stock_datasource/cli/server_manager.py @@ -14,8 +14,10 @@ import os import signal import subprocess +import sys import time from pathlib import Path +from typing import Dict, List, Optional import click @@ -67,13 +69,12 @@ # Helpers # --------------------------------------------------------------------------- - def _ensure_dirs(): _LOG_DIR.mkdir(parents=True, exist_ok=True) _PID_DIR.mkdir(parents=True, exist_ok=True) -def _read_pid(name: str) -> int | None: +def _read_pid(name: str) -> Optional[int]: pid_file = _PID_DIR / _SERVICES[name]["pid"] if not pid_file.exists(): return None @@ -137,7 +138,6 @@ def _wait_for_health(service_name: str, pid: int) -> bool: if health_url: import requests - for i in range(timeout): time.sleep(1) if not _is_running(pid): @@ -152,7 +152,6 @@ def _wait_for_health(service_name: str, pid: int) -> bool: else: # For frontend, check if port is listening import socket - port = svc["port"] for i in range(timeout): time.sleep(1) @@ -171,7 +170,7 @@ def _wait_for_health(service_name: str, pid: int) -> bool: return False -def _get_service_status(name: str) -> dict: +def _get_service_status(name: str) -> Dict: """Get status info for a single service.""" pid = _read_pid(name) svc = _SERVICES[name] @@ -191,7 +190,6 @@ def _get_service_status(name: str) -> dict: # Docker compose helpers # --------------------------------------------------------------------------- - def _has_docker() -> bool: """Check if docker and docker compose are available.""" try: @@ -206,18 +204,16 @@ def _has_docker() -> bool: return False -def _docker_compose_cmd() -> list[str]: +def _docker_compose_cmd() -> List[str]: """Return the appropriate docker compose command.""" try: - subprocess.run( - ["docker", "compose", "version"], capture_output=True, timeout=5, check=True - ) + subprocess.run(["docker", "compose", "version"], capture_output=True, timeout=5, check=True) return ["docker", "compose"] except Exception: return ["docker-compose"] -def _docker_compose_files(include_infra: bool = False) -> list[str]: +def _docker_compose_files(include_infra: bool = False) -> List[str]: """Build the list of -f flags for docker compose.""" files = ["-f", str(_PROJECT_ROOT / "docker-compose.yml")] if include_infra: @@ -231,7 +227,6 @@ def _docker_compose_files(include_infra: bool = False) -> list[str]: # Click commands # --------------------------------------------------------------------------- - @click.group("server") def server(): """Manage application services (start/stop/restart/status). @@ -244,27 +239,12 @@ def server(): @server.command("start") -@click.option( - "--service", - "-s", - multiple=True, - type=click.Choice(ALL_SERVICE_NAMES + ["all"]), - default=["all"], - help="Service(s) to start (default: all)", -) -@click.option( - "--docker", - "use_docker", - is_flag=True, - default=False, - help="Use Docker Compose instead of local processes", -) -@click.option( - "--with-infra", - is_flag=True, - default=False, - help="(Docker mode) Also start infrastructure (ClickHouse, Redis)", -) +@click.option("--service", "-s", multiple=True, type=click.Choice(ALL_SERVICE_NAMES + ["all"]), + default=["all"], help="Service(s) to start (default: all)") +@click.option("--docker", "use_docker", is_flag=True, default=False, + help="Use Docker Compose instead of local processes") +@click.option("--with-infra", is_flag=True, default=False, + help="(Docker mode) Also start infrastructure (ClickHouse, Redis)") def start(service, use_docker, with_infra): """Start application services. @@ -328,16 +308,10 @@ def start(service, use_docker, with_infra): # Wait for health if _wait_for_health(name, proc.pid): actual_port = svc.get("_actual_port", svc["port"]) - click.secho( - f" ✓ {svc['label']} started (http://localhost:{actual_port})", - fg="green", - ) + click.secho(f" ✓ {svc['label']} started (http://localhost:{actual_port})", fg="green") else: if _is_running(proc.pid): - click.secho( - f" ⚠ {svc['label']} started but health check timed out", - fg="yellow", - ) + click.secho(f" ⚠ {svc['label']} started but health check timed out", fg="yellow") click.echo(f" Check logs: tail -f {log_file}") else: click.secho(f" ✗ {svc['label']} process exited unexpectedly", fg="red") @@ -349,21 +323,10 @@ def start(service, use_docker, with_infra): @server.command("stop") -@click.option( - "--service", - "-s", - multiple=True, - type=click.Choice(ALL_SERVICE_NAMES + ["all"]), - default=["all"], - help="Service(s) to stop (default: all)", -) -@click.option( - "--docker", - "use_docker", - is_flag=True, - default=False, - help="Use Docker Compose instead of local processes", -) +@click.option("--service", "-s", multiple=True, type=click.Choice(ALL_SERVICE_NAMES + ["all"]), + default=["all"], help="Service(s) to stop (default: all)") +@click.option("--docker", "use_docker", is_flag=True, default=False, + help="Use Docker Compose instead of local processes") def stop(service, use_docker): """Stop application services. @@ -406,27 +369,12 @@ def stop(service, use_docker): @server.command("restart") -@click.option( - "--service", - "-s", - multiple=True, - type=click.Choice(ALL_SERVICE_NAMES + ["all"]), - default=["all"], - help="Service(s) to restart (default: all)", -) -@click.option( - "--docker", - "use_docker", - is_flag=True, - default=False, - help="Use Docker Compose instead of local processes", -) -@click.option( - "--with-infra", - is_flag=True, - default=False, - help="(Docker mode) Also restart infrastructure", -) +@click.option("--service", "-s", multiple=True, type=click.Choice(ALL_SERVICE_NAMES + ["all"]), + default=["all"], help="Service(s) to restart (default: all)") +@click.option("--docker", "use_docker", is_flag=True, default=False, + help="Use Docker Compose instead of local processes") +@click.option("--with-infra", is_flag=True, default=False, + help="(Docker mode) Also restart infrastructure") def restart(service, use_docker, with_infra): """Restart application services (stop then start). @@ -446,13 +394,8 @@ def restart(service, use_docker, with_infra): @server.command("status") -@click.option( - "--docker", - "use_docker", - is_flag=True, - default=False, - help="Show Docker Compose status instead", -) +@click.option("--docker", "use_docker", is_flag=True, default=False, + help="Show Docker Compose status instead") def status(use_docker): """Show service status. @@ -467,13 +410,9 @@ def status(use_docker): return click.echo("") - click.secho( - "╔══════════════════════════════════════════════════╗", fg="bright_blue" - ) + click.secho("╔══════════════════════════════════════════════════╗", fg="bright_blue") click.secho("║ Service Status ║", fg="bright_blue") - click.secho( - "╚══════════════════════════════════════════════════╝", fg="bright_blue" - ) + click.secho("╚══════════════════════════════════════════════════╝", fg="bright_blue") click.echo("") _print_status_summary() @@ -485,9 +424,7 @@ def status(use_docker): try: result = subprocess.run( ["docker", "inspect", "-f", "{{.State.Status}}", container], - capture_output=True, - text=True, - timeout=5, + capture_output=True, text=True, timeout=5, ) state = result.stdout.strip() if result.returncode == 0 else "not found" if state == "running": @@ -522,18 +459,11 @@ def _print_status_summary(): # Docker helpers # --------------------------------------------------------------------------- - def _docker_start(with_infra: bool): click.echo("") - click.secho( - " Starting services via Docker Compose...", fg="bright_blue", bold=True - ) + click.secho(" Starting services via Docker Compose...", fg="bright_blue", bold=True) click.echo("") - cmd = ( - _docker_compose_cmd() - + _docker_compose_files(include_infra=with_infra) - + ["up", "-d"] - ) + cmd = _docker_compose_cmd() + _docker_compose_files(include_infra=with_infra) + ["up", "-d"] click.echo(f" $ {' '.join(cmd)}") click.echo("") result = subprocess.run(cmd, cwd=str(_PROJECT_ROOT)) @@ -548,11 +478,7 @@ def _docker_stop(): click.echo("") click.secho(" Stopping Docker Compose services...", fg="bright_blue", bold=True) click.echo("") - cmd = _docker_compose_cmd() + [ - "-f", - str(_PROJECT_ROOT / "docker-compose.yml"), - "down", - ] + cmd = _docker_compose_cmd() + ["-f", str(_PROJECT_ROOT / "docker-compose.yml"), "down"] click.echo(f" $ {' '.join(cmd)}") result = subprocess.run(cmd, cwd=str(_PROJECT_ROOT)) if result.returncode == 0: @@ -566,10 +492,6 @@ def _docker_status(): click.echo("") click.secho(" Docker Compose Status:", fg="bright_blue", bold=True) click.echo("") - cmd = _docker_compose_cmd() + [ - "-f", - str(_PROJECT_ROOT / "docker-compose.yml"), - "ps", - ] + cmd = _docker_compose_cmd() + ["-f", str(_PROJECT_ROOT / "docker-compose.yml"), "ps"] subprocess.run(cmd, cwd=str(_PROJECT_ROOT)) click.echo("") diff --git a/src/stock_datasource/models/database.py b/src/stock_datasource/models/database.py index d36fae9b..b315a1b5 100644 --- a/src/stock_datasource/models/database.py +++ b/src/stock_datasource/models/database.py @@ -1,12 +1,11 @@ """Database connection and operations for ClickHouse.""" -import io import logging -import re import threading -from datetime import date, datetime -from typing import Any - +import io +import re +from typing import List, Optional, Dict, Any +from datetime import datetime, date import pandas as pd from clickhouse_driver import Client from tenacity import retry, stop_after_attempt, wait_exponential @@ -45,18 +44,11 @@ def _to_clickhouse_literal(value: Any) -> str: class ClickHouseHttpClient: """ClickHouse client using HTTP interface as fallback when TCP fails.""" - - def __init__( - self, - host: str, - port: int = 8123, - user: str = "default", - password: str = "", - database: str = "default", - name: str = "http", - ): + + def __init__(self, host: str, port: int = 8123, user: str = "default", + password: str = "", database: str = "default", name: str = "http"): """Initialize HTTP client. - + Args: host: ClickHouse host port: HTTP port (default: 8123) @@ -75,18 +67,17 @@ def __init__( self._base_url = f"http://{host}:{port}/" self._auth = (user, password) if password else None # Registry of table schemas for auto-create on UNKNOWN_TABLE - self._table_schemas: dict[str, dict[str, Any]] = {} + self._table_schemas: Dict[str, Dict[str, Any]] = {} # Use a Session with trust_env=False to completely bypass OS-level proxy # env vars (HTTP_PROXY/HTTPS_PROXY) that may be set by plugin discovery. import requests as _requests - self._session = _requests.Session() self._session.trust_env = False if self._auth: self._session.auth = self._auth logger.info(f"Initialized ClickHouse HTTP client [{name}]: {host}:{port}") - - def _request(self, query: str, params: dict | None = None, data: str = None) -> str: + + def _request(self, query: str, params: Optional[Dict] = None, data: str = None) -> str: """Execute HTTP request to ClickHouse.""" # Safely render bound parameters into ClickHouse SQL literals for the # HTTP fallback client. The native TCP client keeps true parameterization. @@ -97,27 +88,18 @@ def _request(self, query: str, params: dict | None = None, data: str = None) -> unreplaced = re.findall(r"%\(([^)]+)\)s", query) if unreplaced: - raise ValueError( - f"Unbound ClickHouse parameters in HTTP query: {unreplaced}" - ) - + raise ValueError(f"Unbound ClickHouse parameters in HTTP query: {unreplaced}") + req_params = {"database": self.database} - + if data: req_params["query"] = query - resp = self._session.post( - self._base_url, params=req_params, data=data, timeout=60 - ) + resp = self._session.post(self._base_url, params=req_params, data=data, timeout=60) else: # Use POST with query in body to avoid URL length limits (e.g., long PIVOT queries with CJK) - resp = self._session.post( - self._base_url, - params=req_params, - data=query.encode("utf-8"), - timeout=60, - headers={"Content-Type": "text/plain; charset=utf-8"}, - ) - + resp = self._session.post(self._base_url, params=req_params, data=query.encode('utf-8'), timeout=60, + headers={'Content-Type': 'text/plain; charset=utf-8'}) + if resp.status_code != 200: logger.error( f"ClickHouse HTTP error [{self.name}]: status={resp.status_code}, " @@ -125,10 +107,10 @@ def _request(self, query: str, params: dict | None = None, data: str = None) -> ) resp.raise_for_status() return resp.text.strip() - - def execute(self, query: str, params: dict | None = None) -> list[tuple]: + + def execute(self, query: str, params: Optional[Dict] = None) -> List[tuple]: """Execute a query and return results as list of tuples. - + On UNKNOWN_TABLE errors, automatically creates the table from registered schema and retries once. """ @@ -136,21 +118,19 @@ def execute(self, query: str, params: dict | None = None) -> list[tuple]: try: # Add FORMAT for SELECT queries (including CTE WITH ... SELECT) query_upper = query.strip().upper() - if ( - query_upper.startswith("SELECT") or query_upper.startswith("WITH") - ) and "FORMAT" not in query_upper: + if (query_upper.startswith("SELECT") or query_upper.startswith("WITH")) and "FORMAT" not in query_upper: query = query.rstrip(";") + " FORMAT TabSeparatedWithNames" - + result = self._request(query, params) - + if not result: return [] - + # Parse TabSeparated result lines = result.split("\n") if len(lines) <= 1: return [] - + # Skip header line, parse data rows = [] for line in lines[1:]: @@ -164,10 +144,7 @@ def execute(self, query: str, params: dict | None = None) -> list[tuple]: # Retry the original query try: query_upper = query.strip().upper() - if ( - query_upper.startswith("SELECT") - or query_upper.startswith("WITH") - ) and "FORMAT" not in query_upper: + if (query_upper.startswith("SELECT") or query_upper.startswith("WITH")) and "FORMAT" not in query_upper: query = query.rstrip(";") + " FORMAT TabSeparatedWithNames" result = self._request(query, params) if not result: @@ -184,58 +161,55 @@ def execute(self, query: str, params: dict | None = None) -> list[tuple]: pass # Retry failed, fall through to original error logger.error(f"HTTP query execution failed [{self.name}]: {e}") raise - - def _extract_unknown_table(exc: Exception) -> str | None: + + def _extract_unknown_table(exc: Exception) -> Optional[str]: """Extract table name from a ClickHouse UNKNOWN_TABLE error. - + The error body is in exc.response.text for HTTPError, or in str(exc) for other exceptions. """ error_text = "" # For requests.HTTPError, the body is in response.text - if hasattr(exc, "response") and exc.response is not None: - error_text = getattr(exc.response, "text", "") or "" + if hasattr(exc, 'response') and exc.response is not None: + error_text = getattr(exc.response, 'text', '') or '' if not error_text: error_text = str(exc) - - if "UNKNOWN_TABLE" not in error_text: + + if 'UNKNOWN_TABLE' not in error_text: return None - + # Pattern: "Unknown table expression identifier 'table_name'" match = re.search(r"Unknown table expression identifier\s+'(\w+)'", error_text) if match: return match.group(1) - + # Fallback: try to extract from FROM clause context match = re.search(r"from\s+`?(\w+)`?", error_text, re.IGNORECASE) if match: return match.group(1) - + return None - - def execute_query(self, query: str, params: dict | None = None) -> pd.DataFrame: + + def execute_query(self, query: str, params: Optional[Dict] = None) -> pd.DataFrame: """Execute query and return results as DataFrame.""" with self._lock: try: query_upper = query.strip().upper() - if ( - query_upper.startswith("SELECT") or query_upper.startswith("WITH") - ) and "FORMAT" not in query_upper: + if (query_upper.startswith("SELECT") or query_upper.startswith("WITH")) and "FORMAT" not in query_upper: query = query.rstrip(";") + " FORMAT TabSeparatedWithNames" - + result = self._request(query, params) - + if not result: return pd.DataFrame() - + return pd.read_csv(io.StringIO(result), sep="\t", na_values=["\\N"]) except Exception as e: logger.error(f"HTTP query execution failed [{self.name}]: {e}") raise - - def insert_dataframe( - self, table_name: str, df: pd.DataFrame, settings: dict | None = None - ) -> None: + + def insert_dataframe(self, table_name: str, df: pd.DataFrame, + settings: Optional[Dict] = None) -> None: """Insert DataFrame into table via HTTP.""" with self._lock: try: @@ -244,136 +218,113 @@ def insert_dataframe( df = df.copy() for col in df.columns: if pd.api.types.is_datetime64_any_dtype(df[col]): - df[col] = df[col].dt.floor("s") + df[col] = df[col].dt.floor('s') elif df[col].dtype == object and len(df) > 0: # Check if column contains Python datetime objects - sample = ( - df[col].dropna().iloc[0] - if not df[col].dropna().empty - else None - ) + sample = df[col].dropna().iloc[0] if not df[col].dropna().empty else None if isinstance(sample, datetime): df[col] = df[col].apply( - lambda x: ( - x.strftime("%Y-%m-%d %H:%M:%S") - if isinstance(x, datetime) - else x - ) + lambda x: x.strftime('%Y-%m-%d %H:%M:%S') if isinstance(x, datetime) else x ) - + # Sanitize string columns: ClickHouse TabSeparated format does not # support embedded newlines or tabs in field values (no quoting). for col in df.columns: if df[col].dtype == object: df[col] = df[col].apply( - lambda v: ( - v.replace("\n", "\\n") - .replace("\r", "\\r") - .replace("\t", "\\t") - if isinstance(v, str) - else v - ) + lambda v: v.replace("\n", "\\n").replace("\r", "\\r").replace("\t", "\\t") + if isinstance(v, str) else v ) # Convert DataFrame to TabSeparated format # Use na_rep='\\N' so NaN/None becomes \N (ClickHouse NULL in TSV) - data = df.to_csv(sep="\t", index=False, header=False, na_rep="\\N") + data = df.to_csv(sep="\t", index=False, header=False, na_rep='\\N') columns = ", ".join(df.columns) query = f"INSERT INTO {table_name} ({columns}) FORMAT TabSeparated" self._request(query, data=data) logger.info(f"Inserted {len(df)} rows into {table_name} [{self.name}]") except Exception as e: - logger.error( - f"Failed to insert data into {table_name} [{self.name}]: {e}" - ) + logger.error(f"Failed to insert data into {table_name} [{self.name}]: {e}") raise - - def register_table_schema(self, table_name: str, schema: dict[str, Any]) -> None: + + def register_table_schema(self, table_name: str, schema: Dict[str, Any]) -> None: """Register a table schema for auto-creation on UNKNOWN_TABLE errors.""" self._table_schemas[table_name] = schema - + def _try_auto_create_table(self, table_name: str) -> bool: """Try to auto-create a table from registered schema. - + Returns True if table was created, False if no schema found or creation failed. """ schema = self._table_schemas.get(table_name) if not schema: return False - + try: create_sql = self._build_create_table_sql(table_name, schema) if create_sql: - logger.info( - f"Auto-creating table {table_name} from registered schema [{self.name}]" - ) + logger.info(f"Auto-creating table {table_name} from registered schema [{self.name}]") self._request(create_sql) - logger.info( - f"Table {table_name} auto-created successfully [{self.name}]" - ) + logger.info(f"Table {table_name} auto-created successfully [{self.name}]") return True except Exception as e: - logger.warning( - f"Failed to auto-create table {table_name} [{self.name}]: {e}" - ) + logger.warning(f"Failed to auto-create table {table_name} [{self.name}]: {e}") return False - + @staticmethod - def _build_create_table_sql(table_name: str, schema: dict[str, Any]) -> str | None: + def _build_create_table_sql(table_name: str, schema: Dict[str, Any]) -> Optional[str]: """Build CREATE TABLE IF NOT EXISTS SQL from schema dict.""" - columns = schema.get("columns", []) + columns = schema.get('columns', []) if not columns: return None - - engine = schema.get("engine", "MergeTree") - engine_params = schema.get("engine_params", []) - partition_by = schema.get("partition_by") - order_by = schema.get("order_by", []) - comment = schema.get("comment", "") - + + engine = schema.get('engine', 'MergeTree') + engine_params = schema.get('engine_params', []) + partition_by = schema.get('partition_by') + order_by = schema.get('order_by', []) + comment = schema.get('comment', '') + col_defs = [] for col in columns: - col_type = col.get("type") or col.get("data_type", "String") + col_type = col.get('type') or col.get('data_type', 'String') col_def = f"`{col['name']}` {col_type}" - if col.get("default"): + if col.get('default'): col_def += f" DEFAULT {col['default']}" - if col.get("comment"): + if col.get('comment'): col_def += f" COMMENT '{col['comment']}'" col_defs.append(col_def) - + create_sql = f"CREATE TABLE IF NOT EXISTS {table_name} (\n" create_sql += ",\n".join(f" {col}" for col in col_defs) - + if engine_params: engine_str = f"{engine}({', '.join(engine_params)})" - elif "(" not in engine: + elif '(' not in engine: engine_str = f"{engine}()" else: engine_str = engine create_sql += f"\n) ENGINE = {engine_str}" - + if partition_by and partition_by not in ("", "tuple()", "None"): create_sql += f"\nPARTITION BY {partition_by}" - + if order_by: if isinstance(order_by, list): create_sql += f"\nORDER BY ({', '.join(order_by)})" else: create_sql += f"\nORDER BY ({order_by})" - + if comment: create_sql += f"\nCOMMENT '{comment}'" - + return create_sql - + def table_exists(self, table_name: str) -> bool: """Check if table exists.""" query = "SELECT count() FROM system.tables WHERE database = %(database)s AND name = %(table_name)s" - result = self._request( - query, params={"database": self.database, "table_name": table_name} - ) + result = self._request(query, params={"database": self.database, "table_name": table_name}) return int(result.strip()) > 0 if result else False - + def close(self): """Close connection (no-op for HTTP).""" logger.info(f"ClickHouse HTTP connection closed [{self.name}]") @@ -381,20 +332,12 @@ def close(self): class ClickHouseClient: """ClickHouse database client with TCP/HTTP fallback support.""" - - def __init__( - self, - host: str = None, - port: int = None, - user: str = None, - password: str = None, - database: str = None, - name: str = "primary", - http_port: int = 8123, - prefer_http: bool = False, - ): + + def __init__(self, host: str = None, port: int = None, user: str = None, + password: str = None, database: str = None, name: str = "primary", + http_port: int = 8123, prefer_http: bool = False): """Initialize ClickHouse client. - + Args: host: ClickHouse host (default: from settings) port: ClickHouse TCP port (default: from settings) @@ -411,23 +354,19 @@ def __init__( self.password = password or settings.CLICKHOUSE_PASSWORD self.database = database or settings.CLICKHOUSE_DATABASE self.name = name - self.http_port = ( - http_port - if http_port != 8123 - else getattr(settings, "CLICKHOUSE_HTTP_PORT", 8123) - ) + self.http_port = http_port if http_port != 8123 else getattr(settings, 'CLICKHOUSE_HTTP_PORT', 8123) self.client = None self._http_client = None self._use_http = prefer_http self._lock = threading.Lock() # Registry of table schemas for auto-create on UNKNOWN_TABLE - self._table_schemas: dict[str, dict[str, Any]] = {} - + self._table_schemas: Dict[str, Dict[str, Any]] = {} + if prefer_http: self._init_http_client() else: self._connect() - + def _init_http_client(self): """Initialize HTTP client as fallback.""" self._http_client = ClickHouseHttpClient( @@ -436,17 +375,17 @@ def _init_http_client(self): user=self.user, password=self.password, database=self.database, - name=f"{self.name}-http", + name=f"{self.name}-http" ) self._use_http = True logger.info(f"Using HTTP fallback for ClickHouse [{self.name}]") - + def _connect(self): """Establish connection to ClickHouse via TCP, fallback to HTTP on failure.""" try: # Register Asia/Beijing as alias for Asia/Shanghai to handle non-standard timezone self._register_timezone_alias() - + self.client = Client( host=self.host, port=self.port, @@ -457,34 +396,29 @@ def _connect(self): send_receive_timeout=60, sync_request_timeout=60, settings={ - "use_numpy": True, - "enable_http_compression": 1, - "session_timezone": "Asia/Shanghai", - "max_memory_usage": 2000000000, - "max_bytes_before_external_group_by": 1000000000, - "max_threads": 4, - }, + 'use_numpy': True, + 'enable_http_compression': 1, + 'session_timezone': 'Asia/Shanghai', + 'max_memory_usage': 2000000000, + 'max_bytes_before_external_group_by': 1000000000, + 'max_threads': 4, + } ) # Test connection self.client.execute("SELECT 1") - logger.info( - f"Connected to ClickHouse via TCP [{self.name}]: {self.host}:{self.port}" - ) + logger.info(f"Connected to ClickHouse via TCP [{self.name}]: {self.host}:{self.port}") self._use_http = False except Exception as e: - logger.warning( - f"TCP connection failed [{self.name}]: {e}, falling back to HTTP" - ) + logger.warning(f"TCP connection failed [{self.name}]: {e}, falling back to HTTP") self._init_http_client() - + @staticmethod def _register_timezone_alias(): """Register Asia/Beijing as alias for Asia/Shanghai.""" try: import pytz - - if "Asia/Beijing" not in pytz.all_timezones_set: - pytz._tzinfo_cache["Asia/Beijing"] = pytz.timezone("Asia/Shanghai") + if 'Asia/Beijing' not in pytz.all_timezones_set: + pytz._tzinfo_cache['Asia/Beijing'] = pytz.timezone('Asia/Shanghai') except Exception: pass @@ -503,7 +437,7 @@ def _should_reconnect(exc: Exception) -> bool: or "Connection broken" in msg or "Connection reset" in msg ) - + def _reconnect(self): """Force reconnect to ClickHouse.""" try: @@ -514,11 +448,9 @@ def _reconnect(self): pass self._connect() except Exception as reconnect_err: - logger.error( - f"Reconnect to ClickHouse failed [{self.name}]: {reconnect_err}" - ) + logger.error(f"Reconnect to ClickHouse failed [{self.name}]: {reconnect_err}") raise - + def _ensure_connected(self): """Ensure we have a valid connection before executing query.""" if self._use_http: @@ -526,58 +458,48 @@ def _ensure_connected(self): self._init_http_client() elif self.client is None: self._connect() - - def register_table_schema(self, table_name: str, schema: dict[str, Any]) -> None: + + def register_table_schema(self, table_name: str, schema: Dict[str, Any]) -> None: """Register a table schema for auto-creation on UNKNOWN_TABLE errors.""" self._table_schemas[table_name] = schema # Also register on HTTP client if available if self._http_client: self._http_client.register_table_schema(table_name, schema) - + def _try_auto_create_table(self, table_name: str) -> bool: """Try to auto-create a table from registered schema. - + Returns True if table was created, False if no schema found or creation failed. """ schema = self._table_schemas.get(table_name) if not schema: return False - + try: - create_sql = ClickHouseHttpClient._build_create_table_sql( - table_name, schema - ) + create_sql = ClickHouseHttpClient._build_create_table_sql(table_name, schema) if create_sql: - logger.info( - f"Auto-creating table {table_name} from registered schema [{self.name}]" - ) + logger.info(f"Auto-creating table {table_name} from registered schema [{self.name}]") self.execute(create_sql) - logger.info( - f"Table {table_name} auto-created successfully [{self.name}]" - ) + logger.info(f"Table {table_name} auto-created successfully [{self.name}]") return True except Exception as e: - logger.warning( - f"Failed to auto-create table {table_name} [{self.name}]: {e}" - ) + logger.warning(f"Failed to auto-create table {table_name} [{self.name}]: {e}") return False - @retry( - stop=stop_after_attempt(2), wait=wait_exponential(multiplier=0.5, min=1, max=3) - ) - def execute(self, query: str, params: dict | None = None) -> Any: + @retry(stop=stop_after_attempt(2), wait=wait_exponential(multiplier=0.5, min=1, max=3)) + def execute(self, query: str, params: Optional[Dict] = None) -> Any: """Execute a query with retry logic and auto-reconnect on transport errors. - + On UNKNOWN_TABLE errors, automatically creates the table from registered schema and retries once. """ with self._lock: self._ensure_connected() - + # Use HTTP client if in HTTP mode (already has UNKNOWN_TABLE handling) if self._use_http and self._http_client: return self._http_client.execute(query, params) - + try: return self.client.execute(query, params) except Exception as e: @@ -594,7 +516,7 @@ def execute(self, query: str, params: dict | None = None) -> Any: return self.client.execute(query, params) except Exception: pass # Retry failed, fall through to original error - + if self._should_reconnect(e): logger.warning(f"Reconnect ClickHouse [{self.name}] due to: {e}") self._reconnect() @@ -604,19 +526,17 @@ def execute(self, query: str, params: dict | None = None) -> Any: return self.client.execute(query, params) logger.error(f"Query execution failed [{self.name}]: {e}") raise - - @retry( - stop=stop_after_attempt(2), wait=wait_exponential(multiplier=0.5, min=1, max=3) - ) - def execute_query(self, query: str, params: dict | None = None) -> pd.DataFrame: + + @retry(stop=stop_after_attempt(2), wait=wait_exponential(multiplier=0.5, min=1, max=3)) + def execute_query(self, query: str, params: Optional[Dict] = None) -> pd.DataFrame: """Execute query and return results as DataFrame with auto-reconnect on transport errors.""" with self._lock: self._ensure_connected() - + # Use HTTP client if in HTTP mode if self._use_http and self._http_client: return self._http_client.execute_query(query, params) - + try: result = self.client.query_dataframe(query, params) return result @@ -627,9 +547,7 @@ def execute_query(self, query: str, params: dict | None = None) -> pd.DataFrame: ) return pd.DataFrame() if self._should_reconnect(e): - logger.warning( - f"Reconnect ClickHouse during query_dataframe [{self.name}] due to: {e}" - ) + logger.warning(f"Reconnect ClickHouse during query_dataframe [{self.name}] due to: {e}") self._reconnect() # After reconnect, check if we switched to HTTP if self._use_http and self._http_client: @@ -637,77 +555,70 @@ def execute_query(self, query: str, params: dict | None = None) -> pd.DataFrame: return self.client.query_dataframe(query, params) logger.error(f"Query execution failed [{self.name}]: {e}") raise - - @retry( - stop=stop_after_attempt(2), wait=wait_exponential(multiplier=0.5, min=1, max=3) - ) - def insert_dataframe( - self, table_name: str, df: pd.DataFrame, settings: dict | None = None - ) -> None: + + @retry(stop=stop_after_attempt(2), wait=wait_exponential(multiplier=0.5, min=1, max=3)) + def insert_dataframe(self, table_name: str, df: pd.DataFrame, + settings: Optional[Dict] = None) -> None: """Insert DataFrame into table.""" with self._lock: self._ensure_connected() - + # Use HTTP client if in HTTP mode if self._use_http and self._http_client: self._http_client.insert_dataframe(table_name, df, settings) return - + try: self.client.insert_dataframe( - f"INSERT INTO {table_name} VALUES", df, settings=settings or {} + f"INSERT INTO {table_name} VALUES", + df, + settings=settings or {} ) logger.info(f"Inserted {len(df)} rows into {table_name} [{self.name}]") except Exception as e: if self._should_reconnect(e): - logger.warning( - f"Reconnect ClickHouse during insert [{self.name}] due to: {e}" - ) + logger.warning(f"Reconnect ClickHouse during insert [{self.name}] due to: {e}") self._reconnect() # After reconnect, check if we switched to HTTP if self._use_http and self._http_client: self._http_client.insert_dataframe(table_name, df, settings) return self.client.insert_dataframe( - f"INSERT INTO {table_name} VALUES", df, settings=settings or {} - ) - logger.info( - f"Inserted {len(df)} rows into {table_name} [{self.name}]" + f"INSERT INTO {table_name} VALUES", + df, + settings=settings or {} ) + logger.info(f"Inserted {len(df)} rows into {table_name} [{self.name}]") else: - logger.error( - f"Failed to insert data into {table_name} [{self.name}]: {e}" - ) + logger.error(f"Failed to insert data into {table_name} [{self.name}]: {e}") raise - + def create_database(self, database_name: str) -> None: """Create database if not exists.""" query = f"CREATE DATABASE IF NOT EXISTS {database_name}" self.execute(query) logger.info(f"Database {database_name} created or already exists [{self.name}]") - + def create_table(self, create_table_sql: str) -> None: """Create table from SQL definition.""" self.execute(create_table_sql) logger.info(f"Table created successfully [{self.name}]") - + def table_exists(self, table_name: str) -> bool: """Check if table exists.""" if self._use_http and self._http_client: return self._http_client.table_exists(table_name) - + query = """ SELECT count() FROM system.tables WHERE database = %(database)s AND name = %(table_name)s """ - result = self.execute( - query, params={"database": self.database, "table_name": table_name} - ) + result = self.execute(query, params={"database": self.database, "table_name": table_name}) return result[0][0] > 0 - - def get_table_schema(self, table_name: str) -> list[dict[str, Any]]: + + def get_table_schema(self, table_name: str) -> List[Dict[str, Any]]: """Get table schema information.""" query = """ SELECT @@ -720,26 +631,22 @@ def get_table_schema(self, table_name: str) -> list[dict[str, Any]]: AND table = %(table_name)s ORDER BY position """ - result = self.execute_query( - query, params={"database": self.database, "table_name": table_name} - ) - return result.to_dict("records") - + result = self.execute_query(query, params={"database": self.database, "table_name": table_name}) + return result.to_dict('records') + def add_column(self, table_name: str, column_def: str) -> None: """Add column to existing table.""" query = f"ALTER TABLE {table_name} ADD COLUMN IF NOT EXISTS {column_def}" self.execute(query) logger.info(f"Added column to {table_name}: {column_def} [{self.name}]") - + def modify_column(self, table_name: str, column_name: str, new_type: str) -> None: """Modify column type.""" query = f"ALTER TABLE {table_name} MODIFY COLUMN {column_name} {new_type}" self.execute(query) - logger.info( - f"Modified column {column_name} in {table_name} to {new_type} [{self.name}]" - ) - - def get_partition_info(self, table_name: str) -> list[dict[str, Any]]: + logger.info(f"Modified column {column_name} in {table_name} to {new_type} [{self.name}]") + + def get_partition_info(self, table_name: str) -> List[Dict[str, Any]]: """Get partition information for table.""" query = """ SELECT @@ -752,11 +659,9 @@ def get_partition_info(self, table_name: str) -> list[dict[str, Any]]: GROUP BY partition ORDER BY partition """ - result = self.execute_query( - query, params={"database": self.database, "table_name": table_name} - ) - return result.to_dict("records") - + result = self.execute_query(query, params={"database": self.database, "table_name": table_name}) + return result.to_dict('records') + def optimize_table(self, table_name: str, final: bool = True) -> None: """Optimize table.""" query = f"OPTIMIZE TABLE {table_name}" @@ -764,7 +669,7 @@ def optimize_table(self, table_name: str, final: bool = True) -> None: query += " FINAL" self.execute(query) logger.info(f"Optimized table {table_name} [{self.name}]") - + def close(self): """Close database connection.""" if self._http_client: @@ -772,7 +677,7 @@ def close(self): if self.client: self.client.disconnect() logger.info(f"ClickHouse connection closed [{self.name}]") - + def is_using_http(self) -> bool: """Check if currently using HTTP fallback.""" return self._use_http @@ -780,12 +685,12 @@ def is_using_http(self) -> bool: class DualWriteClient: """Client that writes to both primary and backup ClickHouse databases.""" - + def __init__(self): """Initialize dual write client with primary and optional backup.""" self.primary = ClickHouseClient(name="primary", prefer_http=True) self.backup = None - + # Initialize backup client if configured if settings.BACKUP_CLICKHOUSE_HOST: try: @@ -795,23 +700,19 @@ def __init__(self): user=settings.BACKUP_CLICKHOUSE_USER, password=settings.BACKUP_CLICKHOUSE_PASSWORD, database=settings.BACKUP_CLICKHOUSE_DATABASE, - name="backup", - ) - logger.info( - f"Dual write enabled: backup at {settings.BACKUP_CLICKHOUSE_HOST}" + name="backup" ) + logger.info(f"Dual write enabled: backup at {settings.BACKUP_CLICKHOUSE_HOST}") except Exception as e: - logger.warning( - f"Failed to connect to backup ClickHouse, dual write disabled: {e}" - ) + logger.warning(f"Failed to connect to backup ClickHouse, dual write disabled: {e}") self.backup = None - + @property def client(self): """For backward compatibility - return primary client's underlying client.""" return self.primary.client - - def execute(self, query: str, params: dict | None = None) -> Any: + + def execute(self, query: str, params: Optional[Dict] = None) -> Any: """Execute query on primary, fallback to backup on failure.""" try: return self.primary.execute(query, params) @@ -821,7 +722,7 @@ def execute(self, query: str, params: dict | None = None) -> Any: return self.backup.execute(query, params) raise - def execute_query(self, query: str, params: dict | None = None) -> pd.DataFrame: + def execute_query(self, query: str, params: Optional[Dict] = None) -> pd.DataFrame: """Execute query on primary, fallback to backup on failure.""" try: return self.primary.execute_query(query, params) @@ -831,17 +732,16 @@ def execute_query(self, query: str, params: dict | None = None) -> pd.DataFrame: return self.backup.execute_query(query, params) raise - def query(self, query: str, params: dict | None = None) -> pd.DataFrame: + def query(self, query: str, params: Optional[Dict] = None) -> pd.DataFrame: """Execute query on primary, fallback to backup (alias for execute_query).""" return self.execute_query(query, params) - - def insert_dataframe( - self, table_name: str, df: pd.DataFrame, settings: dict | None = None - ) -> None: + + def insert_dataframe(self, table_name: str, df: pd.DataFrame, + settings: Optional[Dict] = None) -> None: """Insert DataFrame into both primary and backup databases.""" # Always write to primary self.primary.insert_dataframe(table_name, df, settings) - + # Write to backup if available if self.backup: try: @@ -849,7 +749,7 @@ def insert_dataframe( except Exception as e: logger.error(f"Failed to write to backup database: {e}") # Don't raise - primary write succeeded - + def create_database(self, database_name: str) -> None: """Create database on both primary and backup.""" self.primary.create_database(database_name) @@ -858,7 +758,7 @@ def create_database(self, database_name: str) -> None: self.backup.create_database(database_name) except Exception as e: logger.warning(f"Failed to create database on backup: {e}") - + def create_table(self, create_table_sql: str) -> None: """Create table on both primary and backup.""" self.primary.create_table(create_table_sql) @@ -867,14 +767,14 @@ def create_table(self, create_table_sql: str) -> None: self.backup.create_table(create_table_sql) except Exception as e: logger.warning(f"Failed to create table on backup: {e}") - + def table_exists(self, table_name: str) -> bool: """Check if table exists on primary.""" return self.primary.table_exists(table_name) - - def register_table_schema(self, table_name: str, schema: dict[str, Any]) -> None: + + def register_table_schema(self, table_name: str, schema: Dict[str, Any]) -> None: """Register a table schema on primary (and backup if available). - + When a query hits UNKNOWN_TABLE for this table, the db client will automatically create it from the registered schema instead of raising. """ @@ -884,11 +784,11 @@ def register_table_schema(self, table_name: str, schema: dict[str, Any]) -> None self.backup.register_table_schema(table_name, schema) except Exception: pass - - def get_table_schema(self, table_name: str) -> list[dict[str, Any]]: + + def get_table_schema(self, table_name: str) -> List[Dict[str, Any]]: """Get table schema from primary.""" return self.primary.get_table_schema(table_name) - + def add_column(self, table_name: str, column_def: str) -> None: """Add column on both primary and backup.""" self.primary.add_column(table_name, column_def) @@ -897,7 +797,7 @@ def add_column(self, table_name: str, column_def: str) -> None: self.backup.add_column(table_name, column_def) except Exception as e: logger.warning(f"Failed to add column on backup: {e}") - + def modify_column(self, table_name: str, column_name: str, new_type: str) -> None: """Modify column on both primary and backup.""" self.primary.modify_column(table_name, column_name, new_type) @@ -906,11 +806,11 @@ def modify_column(self, table_name: str, column_name: str, new_type: str) -> Non self.backup.modify_column(table_name, column_name, new_type) except Exception as e: logger.warning(f"Failed to modify column on backup: {e}") - - def get_partition_info(self, table_name: str) -> list[dict[str, Any]]: + + def get_partition_info(self, table_name: str) -> List[Dict[str, Any]]: """Get partition info from primary.""" return self.primary.get_partition_info(table_name) - + def optimize_table(self, table_name: str, final: bool = True) -> None: """Optimize table on both primary and backup.""" self.primary.optimize_table(table_name, final) @@ -919,13 +819,13 @@ def optimize_table(self, table_name: str, final: bool = True) -> None: self.backup.optimize_table(table_name, final) except Exception as e: logger.warning(f"Failed to optimize table on backup: {e}") - + def close(self): """Close both connections.""" self.primary.close() if self.backup: self.backup.close() - + def is_dual_write_enabled(self) -> bool: """Check if dual write is enabled.""" return self.backup is not None diff --git a/src/stock_datasource/modules/portfolio/enhanced_service.py b/src/stock_datasource/modules/portfolio/enhanced_service.py index 677633b7..b6c84122 100644 --- a/src/stock_datasource/modules/portfolio/enhanced_service.py +++ b/src/stock_datasource/modules/portfolio/enhanced_service.py @@ -4,6 +4,7 @@ import uuid from dataclasses import asdict, dataclass from datetime import date, datetime +from enum import Enum from typing import Any import pandas as pd @@ -11,6 +12,32 @@ logger = logging.getLogger(__name__) +class TransactionType(Enum): + """Transaction type enum.""" + + BUY = "buy" + SELL = "sell" + + +@dataclass +class Transaction: + """Transaction record for buy/sell operations.""" + + id: str + user_id: str = "default_user" + ts_code: str = "" + stock_name: str = "" + transaction_type: str = "buy" # 'buy' or 'sell' + quantity: int = 0 + price: float = 0.0 + transaction_date: str = "" + position_id: str = "" + realized_pl: float | None = None + notes: str = "" + profile_id: str = "default" + created_at: datetime | None = None + + @dataclass class Position: """Enhanced position data model.""" @@ -34,6 +61,7 @@ class Position: industry: str | None = None last_price_update: datetime | None = None is_active: bool = True + profile_id: str = "default" created_at: datetime | None = None updated_at: datetime | None = None @@ -84,6 +112,7 @@ def __init__(self): self._db = None # In-memory storage for demo (should be replaced with database) self._positions: dict[str, Position] = {} + self._transactions: dict[str, Transaction] = {} self._alerts: dict[str, PositionAlert] = {} # Add some sample data @@ -272,6 +301,260 @@ async def add_position( logger.info(f"Position {position_id} added: {ts_code}") return position + async def record_buy_transaction( + self, + user_id: str, + ts_code: str, + quantity: int, + price: float, + transaction_date: str, + notes: str | None = None, + profile_id: str = "default", + ) -> Transaction: + """Record a buy transaction and update/create the corresponding position. + + If an active position for this user+ts_code+profile_id exists, the + position's cost_price is updated to the weighted average and quantity + is increased. Otherwise a new Position is created. + """ + # Find existing active position + existing_position = None + for pos in self._positions.values(): + if ( + pos.user_id == user_id + and pos.ts_code == ts_code + and pos.profile_id == profile_id + and pos.is_active + ): + existing_position = pos + break + + # Get stock name + stock_name, sector, industry = await self._get_stock_info(ts_code) + + if existing_position: + # Update existing position with weighted average cost + new_cost = self._calc_weighted_average_cost( + old_quantity=existing_position.quantity, + old_cost=existing_position.cost_price, + new_quantity=quantity, + new_price=price, + ) + existing_position.quantity += quantity + existing_position.cost_price = round(new_cost, 4) + existing_position.updated_at = datetime.now() + position_id = existing_position.id + + # Save to DB if available + await self._save_position(existing_position) + else: + # Create new position + position_id = str(uuid.uuid4()) + new_position = Position( + id=position_id, + user_id=user_id, + ts_code=ts_code, + stock_name=stock_name, + quantity=quantity, + cost_price=price, + buy_date=transaction_date, + notes=notes or "", + sector=sector, + industry=industry, + profile_id=profile_id, + is_active=True, + created_at=datetime.now(), + updated_at=datetime.now(), + ) + await self._update_position_prices(new_position) + await self._save_position(new_position) + + # Create transaction record + txn_id = str(uuid.uuid4()) + transaction = Transaction( + id=txn_id, + user_id=user_id, + ts_code=ts_code, + stock_name=stock_name, + transaction_type="buy", + quantity=quantity, + price=price, + transaction_date=transaction_date, + position_id=position_id, + realized_pl=None, + notes=notes or "", + profile_id=profile_id, + created_at=datetime.now(), + ) + + # Persist transaction + self._transactions[txn_id] = transaction + await self._save_transaction(transaction) + + logger.info(f"Buy transaction {txn_id}: {ts_code} x{quantity} @ {price}") + return transaction + + async def record_sell_transaction( + self, + user_id: str, + ts_code: str, + quantity: int, + price: float, + transaction_date: str, + notes: str | None = None, + profile_id: str = "default", + ) -> Transaction: + """Record a sell transaction and update the corresponding position. + + Validates that the user has enough shares to sell, computes realized + P/L, and reduces the position quantity. If all shares are sold the + position is marked as inactive. + """ + # Find active position + existing_position = None + for pos in self._positions.values(): + if ( + pos.user_id == user_id + and pos.ts_code == ts_code + and pos.profile_id == profile_id + and pos.is_active + ): + existing_position = pos + break + + if not existing_position: + raise ValueError( + f"No active position found for {ts_code} " + f"(user={user_id}, profile={profile_id})" + ) + + # Validate sell quantity + self._validate_sell_quantity(existing_position.quantity, quantity) + + # Calculate realized P/L + realized_pl = self._calc_realized_pl( + quantity=quantity, sell_price=price, cost_price=existing_position.cost_price + ) + + # Update position + existing_position.quantity -= quantity + existing_position.updated_at = datetime.now() + + if existing_position.quantity == 0: + existing_position.is_active = False + + # Save position + await self._save_position(existing_position) + + # Create transaction record + txn_id = str(uuid.uuid4()) + transaction = Transaction( + id=txn_id, + user_id=user_id, + ts_code=ts_code, + stock_name=existing_position.stock_name, + transaction_type="sell", + quantity=quantity, + price=price, + transaction_date=transaction_date, + position_id=existing_position.id, + realized_pl=realized_pl, + notes=notes or "", + profile_id=profile_id, + created_at=datetime.now(), + ) + + # Persist transaction + self._transactions[txn_id] = transaction + await self._save_transaction(transaction) + + logger.info( + f"Sell transaction {txn_id}: {ts_code} x{quantity} @ {price}, " + f"realized_pl={realized_pl}" + ) + return transaction + + async def get_transactions( + self, + user_id: str, + ts_code: str | None = None, + start_date: str | None = None, + end_date: str | None = None, + profile_id: str | None = None, + ) -> list[Transaction]: + """Get transaction history for a user with optional filters. + + Returns transactions ordered by transaction_date DESC. + """ + # Try database first + try: + if self.db is not None: + where_parts = ["user_id = %(user_id)s"] + params: dict[str, Any] = {"user_id": user_id} + + if ts_code: + where_parts.append("ts_code = %(ts_code)s") + params["ts_code"] = ts_code + if start_date: + where_parts.append("transaction_date >= %(start_date)s") + params["start_date"] = start_date + if end_date: + where_parts.append("transaction_date <= %(end_date)s") + params["end_date"] = end_date + if profile_id: + where_parts.append("profile_id = %(profile_id)s") + params["profile_id"] = profile_id + + where_clause = " AND ".join(where_parts) + query = f""" + SELECT id, user_id, ts_code, stock_name, transaction_type, + quantity, price, transaction_date, position_id, + realized_pl, notes, profile_id, created_at + FROM user_transactions + WHERE {where_clause} + ORDER BY transaction_date DESC + """ + df = self.db.execute_query(query, params) + if not df.empty: + transactions = [] + for _, row in df.iterrows(): + txn = Transaction( + id=str(row["id"]), + user_id=str(row["user_id"]), + ts_code=row["ts_code"], + stock_name=row["stock_name"], + transaction_type=str(row["transaction_type"]), + quantity=int(row["quantity"]), + price=float(row["price"]), + transaction_date=str(row["transaction_date"]), + position_id=str(row.get("position_id", "")), + realized_pl=float(row["realized_pl"]) + if pd.notna(row.get("realized_pl")) + else None, + notes=row.get("notes", ""), + profile_id=str(row.get("profile_id", "default")), + created_at=row["created_at"] + if pd.notna(row.get("created_at")) + else None, + ) + transactions.append(txn) + return transactions + except Exception as e: + logger.warning(f"Failed to get transactions from database: {e}") + + # Fallback to in-memory storage + results = [ + t + for t in self._transactions.values() + if t.user_id == user_id + and (ts_code is None or t.ts_code == ts_code) + and (start_date is None or t.transaction_date >= start_date) + and (end_date is None or t.transaction_date <= end_date) + and (profile_id is None or t.profile_id == profile_id) + ] + results.sort(key=lambda t: t.transaction_date, reverse=True) + return results + async def update_position( self, position_id: str, user_id: str, **updates ) -> Position | None: @@ -560,6 +843,79 @@ async def check_alerts(self, user_id: str = "default_user") -> list[PositionAler return triggered_alerts + async def get_kline_patterns( + self, ts_code: str, days: int = 60 + ) -> list[dict[str, str]]: + """Detect candlestick patterns in K-line data for a stock. + + Fetches OHLC data from ClickHouse, converts to Candle objects, + and runs detect_patterns() to find recognized patterns. + + Returns a list of dicts with keys: name, name_en, date, type, category. + """ + from .kline_patterns import Candle, detect_patterns + + candles = await self._fetch_kline_candles(ts_code, days) + if not candles: + return [] + return detect_patterns(candles) + + async def _fetch_kline_candles( + self, ts_code: str, days: int + ) -> list: + """Fetch OHLC data from ClickHouse and convert to Candle objects.""" + from .kline_patterns import Candle + + if self.db is None: + return [] + + try: + # Determine which table to query based on ts_code suffix + if ts_code.endswith(".HK"): + table = "ods_hk_daily" + elif any( + ts_code.startswith(p) + for p in ("51", "15", "56", "59", "16", "50", "52", "58") + ): + table = "ods_etf_fund_daily" + else: + table = "ods_daily" + + query = f""" + SELECT trade_date, open, high, low, close, vol + FROM {table} + WHERE ts_code = %(code)s + ORDER BY trade_date DESC + LIMIT %(limit)s + """ + df = self.db.execute_query(query, {"code": ts_code, "limit": days}) + if df.empty: + return [] + + # Convert to list of Candle objects (oldest first for pattern detection) + candles = [] + for _, row in df[::-1].iterrows(): + trade_date = row["trade_date"] + if hasattr(trade_date, "strftime"): + date_str = trade_date.strftime("%Y-%m-%d") + else: + date_str = str(trade_date) + + candles.append( + Candle( + date=date_str, + open=float(row["open"]), + close=float(row["close"]), + high=float(row["high"]), + low=float(row["low"]), + volume=float(row["vol"]) if pd.notna(row["vol"]) else 0, + ) + ) + return candles + except Exception as e: + logger.warning(f"Failed to fetch kline candles for {ts_code}: {e}") + return [] + # Private helper methods async def _get_stock_info(self, ts_code: str) -> tuple[str, str, str]: """Get stock name, sector and industry. Supports A-shares, ETFs and HK stocks.""" @@ -802,6 +1158,41 @@ def _calc_daily_change(position: Position): position.daily_change = position.current_price - position.prev_close position.daily_pct_chg = position.daily_change / position.prev_close * 100 + @staticmethod + def _calc_weighted_average_cost( + old_quantity: int, old_cost: float, new_quantity: int, new_price: float + ) -> float: + """Calculate weighted average cost after a new buy. + + Formula: (old_qty * old_cost + new_qty * new_price) / (old_qty + new_qty) + If old_quantity is 0, returns new_price directly. + """ + total_quantity = old_quantity + new_quantity + if total_quantity == 0: + return 0.0 + if old_quantity == 0: + return new_price + return (old_quantity * old_cost + new_quantity * new_price) / total_quantity + + @staticmethod + def _calc_realized_pl(quantity: int, sell_price: float, cost_price: float) -> float: + """Calculate realized profit/loss for a sell transaction. + + Formula: quantity * (sell_price - cost_price) + """ + return quantity * (sell_price - cost_price) + + @staticmethod + def _validate_sell_quantity(held_quantity: int, sell_quantity: int) -> None: + """Validate that sell quantity does not exceed held quantity. + + Raises ValueError if sell_quantity > held_quantity or sell_quantity == 0. + """ + if sell_quantity <= 0 or sell_quantity > held_quantity: + raise ValueError( + f"Cannot sell {sell_quantity} shares; only {held_quantity} held" + ) + @staticmethod def _parse_daily_trade_date(trade_date, market_type: str = "a_stock") -> datetime: """将日线 trade_date 转换为带收盘时间的 datetime。""" @@ -838,6 +1229,25 @@ async def _save_position(self, position: Position): # Always save to in-memory storage as backup self._positions[position.id] = position + async def _save_transaction(self, transaction: Transaction): + """Save transaction to database and memory.""" + try: + if self.db is not None: + query = """ + INSERT INTO user_transactions + (id, user_id, ts_code, stock_name, transaction_type, quantity, + price, transaction_date, position_id, realized_pl, notes, + profile_id, created_at) + VALUES (%(id)s, %(user_id)s, %(ts_code)s, %(stock_name)s, + %(transaction_type)s, %(quantity)s, %(price)s, + %(transaction_date)s, %(position_id)s, %(realized_pl)s, + %(notes)s, %(profile_id)s, %(created_at)s) + """ + params = asdict(transaction) + self.db.execute(query, params) + except Exception as e: + logger.warning(f"Failed to save transaction to database: {e}") + async def _record_position_history(self, position: Position, change_type: str): """Record position change in history table.""" try: diff --git a/src/stock_datasource/modules/portfolio/init.py b/src/stock_datasource/modules/portfolio/init.py index 18373281..7962ac7a 100644 --- a/src/stock_datasource/modules/portfolio/init.py +++ b/src/stock_datasource/modules/portfolio/init.py @@ -129,6 +129,28 @@ def ensure_portfolio_tables(): SETTINGS index_granularity = 8192 """) + # Create user_transactions table if not exists + client.execute(""" + CREATE TABLE IF NOT EXISTS user_transactions ( + id String, + user_id String DEFAULT 'default_user', + ts_code String, + stock_name String, + transaction_type Enum8('buy' = 1, 'sell' = 2), + quantity UInt32, + price Decimal(10, 3), + transaction_date Date, + position_id String DEFAULT '', + realized_pl Nullable(Decimal(15, 2)), + notes String DEFAULT '', + profile_id String DEFAULT 'default', + created_at DateTime DEFAULT now() + ) ENGINE = MergeTree() + ORDER BY (user_id, ts_code, transaction_date, id) + PARTITION BY toYYYYMM(transaction_date) + SETTINGS index_granularity = 8192 + """) + logger.info("Portfolio tables ensured successfully") except Exception as e: diff --git a/src/stock_datasource/modules/portfolio/kline_patterns.py b/src/stock_datasource/modules/portfolio/kline_patterns.py new file mode 100644 index 00000000..6b7fe9b6 --- /dev/null +++ b/src/stock_datasource/modules/portfolio/kline_patterns.py @@ -0,0 +1,381 @@ +"""K-line candlestick pattern recognition. + +Implements single, dual, and triple candlestick pattern detection +for technical analysis of stock price data. +""" + +from dataclasses import dataclass +from typing import Optional + + +@dataclass +class Candle: + """Single OHLC candle data.""" + date: str + open: float + close: float + high: float + low: float + volume: float = 0 + + @property + def body(self) -> float: + return abs(self.close - self.open) + + @property + def upper_shadow(self) -> float: + return self.high - max(self.open, self.close) + + @property + def lower_shadow(self) -> float: + return min(self.open, self.close) - self.low + + @property + def range(self) -> float: + return self.high - self.low + + @property + def is_bullish(self) -> bool: + return self.close > self.open + + @property + def is_bearish(self) -> bool: + return self.close < self.open + + +# --------------------------------------------------------------------------- +# Single-Candle Patterns +# --------------------------------------------------------------------------- + + +def is_hammer(c: Candle) -> bool: + """Hammer: small body at top, long lower shadow (>= 2x body), small upper shadow.""" + if c.range == 0: + return False + if c.lower_shadow < c.body * 2: + return False + if c.upper_shadow > c.body: + return False + return True + + +def is_hanging_man(c: Candle) -> bool: + """Hanging Man: same shape as hammer but appears at top of uptrend. + + Note: We detect the shape; trend context is determined by the caller. + """ + return is_hammer(c) + + +def is_inverted_hammer(c: Candle) -> bool: + """Inverted Hammer: small body at bottom, long upper shadow (>= 2x body), small lower shadow.""" + if c.range == 0: + return False + if c.upper_shadow < c.body * 2: + return False + if c.lower_shadow > c.body: + return False + return True + + +def is_shooting_star(c: Candle) -> bool: + """Shooting Star: same shape as inverted hammer but at top of uptrend.""" + if c.range == 0: + return False + if c.upper_shadow < c.body * 2: + return False + if c.lower_shadow > c.body: + return False + return True + + +def is_doji(c: Candle) -> bool: + """Doji: body is very small relative to total range (<= 10% of range).""" + if c.range == 0: + return True # No movement at all + return c.body <= c.range * 0.1 + + +def is_dragonfly_doji(c: Candle) -> bool: + """Dragonfly Doji: doji with long lower shadow and no upper shadow.""" + if not is_doji(c): + return False + if c.upper_shadow > c.body: + return False + if c.lower_shadow < c.range * 0.3: + return False + return True + + +def is_gravestone_doji(c: Candle) -> bool: + """Gravestone Doji: doji with long upper shadow and no lower shadow.""" + if not is_doji(c): + return False + if c.lower_shadow > c.body: + return False + if c.upper_shadow < c.range * 0.3: + return False + return True + + +def is_marubozu(c: Candle) -> Optional[str]: + """Marubozu: no shadows (or very small), large body. + + Returns 'bullish' or 'bearish', or None if not a marubozu. + """ + if c.body == 0: + return None + # Shadows should be <= 5% of body + max_shadow = c.body * 0.05 + if c.upper_shadow > max_shadow or c.lower_shadow > max_shadow: + return None + # Body should be significant (>= 60% of range) + if c.range > 0 and c.body / c.range < 0.6: + return None + return "bullish" if c.is_bullish else "bearish" + + +# --------------------------------------------------------------------------- +# Dual-Candle Patterns +# --------------------------------------------------------------------------- + + +def is_bullish_engulfing(prev: Candle, curr: Candle) -> bool: + """Bullish Engulfing: previous bearish, current bullish, + current body completely engulfs previous body. + """ + if not prev.is_bearish: + return False + if not curr.is_bullish: + return False + # Current body engulfs previous body: + # curr.open < prev.close (opens below prev close) AND curr.close > prev.open (closes above prev open) + if curr.open >= prev.close or curr.close <= prev.open: + return False + return True + + +def is_bearish_engulfing(prev: Candle, curr: Candle) -> bool: + """Bearish Engulfing: previous bullish, current bearish, + current body completely engulfs previous body. + """ + if not prev.is_bullish: + return False + if not curr.is_bearish: + return False + # Current body engulfs previous body: + # curr.open > prev.close (opens above prev close) AND curr.close < prev.open (closes below prev open) + if curr.open <= prev.close or curr.close >= prev.open: + return False + return True + + +# --------------------------------------------------------------------------- +# Triple-Candle Patterns +# --------------------------------------------------------------------------- + + +def is_morning_star(c1: Candle, c2: Candle, c3: Candle) -> bool: + """Morning Star: large bearish, small star, large bullish closing above + midpoint of first candle. + """ + # First candle: large bearish + if not c1.is_bearish: + return False + # Second candle: small body (star) + if c2.body > c1.body * 0.5: + return False + # Third candle: bullish + if not c3.is_bullish: + return False + # Third candle closes above midpoint of first + midpoint = (c1.open + c1.close) / 2 + if c3.close < midpoint: + return False + return True + + +def is_evening_star(c1: Candle, c2: Candle, c3: Candle) -> bool: + """Evening Star: large bullish, small star, large bearish closing below + midpoint of first candle. + """ + # First candle: large bullish + if not c1.is_bullish: + return False + # Second candle: small body (star) + if c2.body > c1.body * 0.5: + return False + # Third candle: bearish + if not c3.is_bearish: + return False + # Third candle closes below midpoint of first + midpoint = (c1.open + c1.close) / 2 + if c3.close > midpoint: + return False + return True + + +def is_three_white_soldiers(c1: Candle, c2: Candle, c3: Candle) -> bool: + """Three White Soldiers: three consecutive bullish candles, + each opening within previous body and closing higher. + """ + if not (c1.is_bullish and c2.is_bullish and c3.is_bullish): + return False + # Each opens within previous body (c2.open between c1.close and c1.open) + if not (min(c1.open, c1.close) <= c2.open <= max(c1.open, c1.close)): + return False + if not (min(c2.open, c2.close) <= c3.open <= max(c2.open, c2.close)): + return False + # Each closes higher than previous + if c2.close <= c1.close or c3.close <= c2.close: + return False + return True + + +def is_three_black_crows(c1: Candle, c2: Candle, c3: Candle) -> bool: + """Three Black Crows: three consecutive bearish candles, + each opening within previous body and closing lower. + """ + if not (c1.is_bearish and c2.is_bearish and c3.is_bearish): + return False + # Each opens within previous body (c2.open between c1.close and c1.open) + if not (min(c1.open, c1.close) <= c2.open <= max(c1.open, c1.close)): + return False + if not (min(c2.open, c2.close) <= c3.open <= max(c2.open, c2.close)): + return False + # Each closes lower than previous + if c2.close >= c1.close or c3.close >= c2.close: + return False + return True + + +# --------------------------------------------------------------------------- +# Combined Pattern Detection +# --------------------------------------------------------------------------- + +# Pattern result structure +PATTERN_RESULT = { + "name": str, # Pattern name (Chinese) + "name_en": str, # Pattern name (English) + "date": str, # Date of the pattern (last candle) + "type": str, # 'bullish' or 'bearish' + "category": str, # 'single', 'dual', or 'triple' +} + + +def detect_patterns(candles: list[Candle]) -> list[dict]: + """Scan candle data and return all recognized patterns. + + Returns a list of dicts with keys: name, name_en, date, type, category. + """ + results: list[dict] = [] + + # Single-candle patterns + for c in candles: + if is_hammer(c): + results.append({ + "name": "锤子线", + "name_en": "Hammer", + "date": c.date, + "type": "bullish", + "category": "single", + }) + if is_shooting_star(c): + results.append({ + "name": "射击之星", + "name_en": "Shooting Star", + "date": c.date, + "type": "bearish", + "category": "single", + }) + if is_doji(c) and not is_dragonfly_doji(c) and not is_gravestone_doji(c): + results.append({ + "name": "十字星", + "name_en": "Doji", + "date": c.date, + "type": "neutral", + "category": "single", + }) + if is_dragonfly_doji(c): + results.append({ + "name": "蜻蜓十字", + "name_en": "Dragonfly Doji", + "date": c.date, + "type": "bullish", + "category": "single", + }) + if is_gravestone_doji(c): + results.append({ + "name": "墓碑十字", + "name_en": "Gravestone Doji", + "date": c.date, + "type": "bearish", + "category": "single", + }) + marubozu_type = is_marubozu(c) + if marubozu_type: + results.append({ + "name": "光头光脚阳线" if marubozu_type == "bullish" else "光头光脚阴线", + "name_en": f"{'Bullish' if marubozu_type == 'bullish' else 'Bearish'} Marubozu", + "date": c.date, + "type": marubozu_type, + "category": "single", + }) + + # Dual-candle patterns + for i in range(1, len(candles)): + prev, curr = candles[i - 1], candles[i] + if is_bullish_engulfing(prev, curr): + results.append({ + "name": "看涨吞没", + "name_en": "Bullish Engulfing", + "date": curr.date, + "type": "bullish", + "category": "dual", + }) + if is_bearish_engulfing(prev, curr): + results.append({ + "name": "看跌吞没", + "name_en": "Bearish Engulfing", + "date": curr.date, + "type": "bearish", + "category": "dual", + }) + + # Triple-candle patterns + for i in range(2, len(candles)): + c1, c2, c3 = candles[i - 2], candles[i - 1], candles[i] + if is_morning_star(c1, c2, c3): + results.append({ + "name": "启明星", + "name_en": "Morning Star", + "date": c3.date, + "type": "bullish", + "category": "triple", + }) + if is_evening_star(c1, c2, c3): + results.append({ + "name": "黄昏星", + "name_en": "Evening Star", + "date": c3.date, + "type": "bearish", + "category": "triple", + }) + if is_three_white_soldiers(c1, c2, c3): + results.append({ + "name": "红三兵", + "name_en": "Three White Soldiers", + "date": c3.date, + "type": "bullish", + "category": "triple", + }) + if is_three_black_crows(c1, c2, c3): + results.append({ + "name": "三只乌鸦", + "name_en": "Three Black Crows", + "date": c3.date, + "type": "bearish", + "category": "triple", + }) + + return results diff --git a/src/stock_datasource/modules/portfolio/router.py b/src/stock_datasource/modules/portfolio/router.py index 2a6f9f10..ce56b0d7 100644 --- a/src/stock_datasource/modules/portfolio/router.py +++ b/src/stock_datasource/modules/portfolio/router.py @@ -72,6 +72,46 @@ class UpdatePositionRequest(BaseModel): notes: str | None = Field(None, description="备注") +class BuyTransactionRequest(BaseModel): + """Request model for buy transaction.""" + + ts_code: str = Field(..., description="股票代码") + quantity: int = Field(..., gt=0, description="买入数量") + price: float = Field(..., gt=0, description="买入价格") + transaction_date: str = Field(..., description="交易日期") + notes: str | None = Field(None, description="备注") + profile_id: str | None = Field(None, description="账户ID") + + +class SellTransactionRequest(BaseModel): + """Request model for sell transaction.""" + + ts_code: str = Field(..., description="股票代码") + quantity: int = Field(..., gt=0, description="卖出数量") + price: float = Field(..., gt=0, description="卖出价格") + transaction_date: str = Field(..., description="交易日期") + notes: str | None = Field(None, description="备注") + profile_id: str | None = Field(None, description="账户ID") + + +class TransactionResponse(BaseModel): + """Response model for transaction.""" + + id: str + user_id: str + ts_code: str + stock_name: str + transaction_type: str + quantity: int + price: float + transaction_date: str + position_id: str + realized_pl: float | None = None + notes: str = "" + profile_id: str = "default" + created_at: str | None = None + + class PortfolioSummary(BaseModel): total_value: float total_cost: float @@ -542,3 +582,298 @@ async def batch_update_prices(current_user: dict = Depends(get_current_user)): "success": False, "error": str(e), } + + +# --------------------------------------------------------------------------- +# Transaction endpoints (buy/sell transaction history) +# --------------------------------------------------------------------------- + + +@router.post("/transactions/buy", response_model=TransactionResponse) +async def buy_transaction( + request: BuyTransactionRequest, current_user: dict = Depends(get_current_user) +): + """Record a buy transaction and update/create position. + + User isolation: Transaction is recorded under the authenticated user's account. + """ + try: + enhanced_service = get_enhanced_portfolio_service() + if not enhanced_service: + raise HTTPException(status_code=503, detail="Enhanced service not available") + + txn = await enhanced_service.record_buy_transaction( + user_id=current_user["id"], + ts_code=request.ts_code, + quantity=request.quantity, + price=request.price, + transaction_date=request.transaction_date, + notes=request.notes, + profile_id=request.profile_id or "default", + ) + + return TransactionResponse( + id=txn.id, + user_id=txn.user_id, + ts_code=txn.ts_code, + stock_name=txn.stock_name, + transaction_type=txn.transaction_type, + quantity=txn.quantity, + price=txn.price, + transaction_date=txn.transaction_date, + position_id=txn.position_id, + realized_pl=txn.realized_pl, + notes=txn.notes, + profile_id=txn.profile_id, + created_at=str(txn.created_at) if txn.created_at else None, + ) + except HTTPException: + raise + except Exception as e: + logger.error(f"Failed to record buy transaction: {e}") + raise HTTPException(status_code=500, detail=str(e)) + + +@router.post("/transactions/sell", response_model=TransactionResponse) +async def sell_transaction( + request: SellTransactionRequest, current_user: dict = Depends(get_current_user) +): + """Record a sell transaction and update position. + + User isolation: Transaction is recorded under the authenticated user's account. + Validates that the user has sufficient shares to sell. + """ + try: + enhanced_service = get_enhanced_portfolio_service() + if not enhanced_service: + raise HTTPException(status_code=503, detail="Enhanced service not available") + + txn = await enhanced_service.record_sell_transaction( + user_id=current_user["id"], + ts_code=request.ts_code, + quantity=request.quantity, + price=request.price, + transaction_date=request.transaction_date, + notes=request.notes, + profile_id=request.profile_id or "default", + ) + + return TransactionResponse( + id=txn.id, + user_id=txn.user_id, + ts_code=txn.ts_code, + stock_name=txn.stock_name, + transaction_type=txn.transaction_type, + quantity=txn.quantity, + price=txn.price, + transaction_date=txn.transaction_date, + position_id=txn.position_id, + realized_pl=txn.realized_pl, + notes=txn.notes, + profile_id=txn.profile_id, + created_at=str(txn.created_at) if txn.created_at else None, + ) + except ValueError as e: + logger.warning(f"Sell validation failed: {e}") + raise HTTPException(status_code=400, detail=str(e)) + except HTTPException: + raise + except Exception as e: + logger.error(f"Failed to record sell transaction: {e}") + raise HTTPException(status_code=500, detail=str(e)) + + +@router.get("/transactions", response_model=list[TransactionResponse]) +async def get_transactions( + ts_code: str | None = Query(None, description="Filter by stock code"), + start_date: str | None = Query(None, description="Start date (YYYY-MM-DD)"), + end_date: str | None = Query(None, description="End date (YYYY-MM-DD)"), + profile_id: str | None = Query(None, description="Filter by profile ID"), + current_user: dict = Depends(get_current_user), +): + """Get transaction history for the authenticated user. + + User isolation: Only returns transactions belonging to the authenticated user. + Supports filtering by stock code and date range. + """ + try: + enhanced_service = get_enhanced_portfolio_service() + if not enhanced_service: + raise HTTPException(status_code=503, detail="Enhanced service not available") + + transactions = await enhanced_service.get_transactions( + user_id=current_user["id"], + ts_code=ts_code, + start_date=start_date, + end_date=end_date, + profile_id=profile_id, + ) + + return [ + TransactionResponse( + id=txn.id, + user_id=txn.user_id, + ts_code=txn.ts_code, + stock_name=txn.stock_name, + transaction_type=txn.transaction_type, + quantity=txn.quantity, + price=txn.price, + transaction_date=txn.transaction_date, + position_id=txn.position_id, + realized_pl=txn.realized_pl, + notes=txn.notes, + profile_id=txn.profile_id, + created_at=str(txn.created_at) if txn.created_at else None, + ) + for txn in transactions + ] + except HTTPException: + raise + except Exception as e: + logger.error(f"Failed to get transactions: {e}") + raise HTTPException(status_code=500, detail=str(e)) + + +class TransactionSignalResponse(BaseModel): + """A buy/sell signal point for K-line chart markers.""" + id: str + ts_code: str + signal_type: str = Field(..., description="'buy' or 'sell'") + source: str = Field(..., description="'user' or 'strategy'") + signal_date: str + price: float + quantity: int | None = None + strategy_name: str | None = None + notes: str | None = None + + +@router.get("/transactions/signals", response_model=list[TransactionSignalResponse]) +async def get_transaction_signals( + ts_code: str = Query(..., description="Stock code (required)"), + start_date: str | None = Query(None, description="Start date (YYYY-MM-DD)"), + end_date: str | None = Query(None, description="End date (YYYY-MM-DD)"), + current_user: dict = Depends(get_current_user), +): + """Get buy/sell signal points for K-line chart markers. + + Combines: + - User transaction signals (actual buy/sell records) + - Technical strategy signals (from indicator analysis) + + Returns a unified list of signals suitable for rendering B/S markers + on K-line charts with different styles per source. + """ + try: + enhanced_service = get_enhanced_portfolio_service() + if not enhanced_service: + raise HTTPException(status_code=503, detail="Enhanced service not available") + + signals: list[TransactionSignalResponse] = [] + + # 1. User transaction signals + transactions = await enhanced_service.get_transactions( + user_id=current_user["id"], + ts_code=ts_code, + start_date=start_date, + end_date=end_date, + ) + + for txn in transactions: + signals.append(TransactionSignalResponse( + id=f"user_{txn.id}", + ts_code=txn.ts_code, + signal_type=txn.transaction_type, + source="user", + signal_date=txn.transaction_date, + price=txn.price, + quantity=txn.quantity, + notes=f"{txn.transaction_type.upper()} {txn.quantity}@{txn.price}", + )) + + # 2. Strategy signals from technical indicators + try: + indicators = await enhanced_service.get_technical_indicators(ts_code, 180) + strategy_signals = _extract_strategy_signals(ts_code, indicators) + signals.extend(strategy_signals) + except Exception as e: + logger.debug(f"Strategy signals not available for {ts_code}: {e}") + + # Sort by signal_date + signals.sort(key=lambda s: s.signal_date) + + return signals + except HTTPException: + raise + except Exception as e: + logger.error(f"Failed to get transaction signals: {e}") + raise HTTPException(status_code=500, detail=str(e)) + + +class KlinePatternResponse(BaseModel): + """A detected candlestick pattern.""" + name: str = Field(..., description="Pattern name (Chinese)") + name_en: str = Field(..., description="Pattern name (English)") + date: str = Field(..., description="Date of the pattern (last candle)") + type: str = Field(..., description="'bullish', 'bearish', or 'neutral'") + category: str = Field(..., description="'single', 'dual', or 'triple'") + + +@router.get("/kline-patterns/{ts_code}", response_model=list[KlinePatternResponse]) +async def get_kline_patterns( + ts_code: str = Path(..., description="Stock code"), + days: int = Query(default=60, description="Number of days to analyze"), + current_user: dict = Depends(get_current_user), +): + """Detect candlestick patterns in K-line data for a stock. + + Fetches OHLC data and runs pattern recognition to identify + single, dual, and triple candlestick patterns. + """ + try: + enhanced_service = get_enhanced_portfolio_service() + if not enhanced_service: + raise HTTPException(status_code=503, detail="Enhanced service not available") + + patterns = await enhanced_service.get_kline_patterns(ts_code, days) + return [ + KlinePatternResponse( + name=p["name"], + name_en=p["name_en"], + date=p["date"], + type=p["type"], + category=p["category"], + ) + for p in patterns + ] + except HTTPException: + raise + except Exception as e: + logger.error(f"Failed to get kline patterns for {ts_code}: {e}") + raise HTTPException(status_code=500, detail=str(e)) + + +def _extract_strategy_signals( + ts_code: str, indicators: dict +) -> list[TransactionSignalResponse]: + """Extract buy/sell signals from technical indicator data. + + Converts technical signals (MACD crossover, RSI overbought/oversold, etc.) + into TransactionSignalResponse items with source='strategy'. + """ + signals: list[TransactionSignalResponse] = [] + signals_data = indicators.get("signals", []) + + for idx, sig in enumerate(signals_data): + signal_type = "buy" if sig.get("type") in ("buy", "golden_cross", "oversold") else "sell" + signals.append(TransactionSignalResponse( + id=f"strategy_{idx}", + ts_code=ts_code, + signal_type=signal_type, + source="strategy", + signal_date=sig.get("date", ""), + price=float(sig.get("price", 0)), + strategy_name=sig.get("name", sig.get("type", "technical")), + notes=sig.get("message", ""), + )) + + return signals diff --git a/src/stock_datasource/modules/portfolio/schema.sql b/src/stock_datasource/modules/portfolio/schema.sql index 9eccf486..206b57c2 100644 --- a/src/stock_datasource/modules/portfolio/schema.sql +++ b/src/stock_datasource/modules/portfolio/schema.sql @@ -113,6 +113,26 @@ ORDER BY (user_id, ts_code, alert_type, id) PARTITION BY toYYYYMM(created_at) SETTINGS index_granularity = 8192; +-- User transactions table (buy/sell transaction history) +CREATE TABLE IF NOT EXISTS user_transactions ( + id String, + user_id String DEFAULT 'default_user', + ts_code String, + stock_name String, + transaction_type Enum8('buy' = 1, 'sell' = 2), + quantity UInt32, + price Decimal(10, 3), + transaction_date Date, + position_id String DEFAULT '', + realized_pl Nullable(Decimal(15, 2)), + notes String DEFAULT '', + profile_id String DEFAULT 'default', + created_at DateTime DEFAULT now() +) ENGINE = MergeTree() +ORDER BY (user_id, ts_code, transaction_date, id) +PARTITION BY toYYYYMM(transaction_date) +SETTINGS index_granularity = 8192; + -- Performance optimization views and materialized views -- Daily portfolio summary materialized view diff --git a/src/stock_datasource/modules/system_logs/log_parser.py b/src/stock_datasource/modules/system_logs/log_parser.py index cad6844b..4bf071b6 100644 --- a/src/stock_datasource/modules/system_logs/log_parser.py +++ b/src/stock_datasource/modules/system_logs/log_parser.py @@ -1,8 +1,10 @@ """Log parser for parsing various log formats.""" -import logging import re -from datetime import datetime, timedelta +import logging +from datetime import datetime +from datetime import timedelta +from typing import List, Optional from pathlib import Path logger = logging.getLogger(__name__) @@ -16,50 +18,50 @@ class LogParser: # NEW Loguru format with request_id, user_id & middleware_trace_id: # 2026-01-26 10:30:45.123 | INFO | abc123 | user1 | mwid123 | name:function:line - message re.compile( - r"^(\d{4}-\d{2}-\d{2}\s+\d{2}:\d{2}:\d{2}(?:\.\d+)?)\s+\|\s*(\w+)\s*\|\s*([^|]+?)\s*\|\s*([^|]+?)\s*\|\s*([^|]+?)\s*\|\s*([^|]+?)\s*-\s*(.*)$" + r'^(\d{4}-\d{2}-\d{2}\s+\d{2}:\d{2}:\d{2}(?:\.\d+)?)\s+\|\s*(\w+)\s*\|\s*([^|]+?)\s*\|\s*([^|]+?)\s*\|\s*([^|]+?)\s*\|\s*([^|]+?)\s*-\s*(.*)$' ), # OLD Loguru format with request_id & user_id (6 groups): # 2026-01-26 10:30:45.123 | INFO | abc123 | user1 | name:function:line - message re.compile( - r"^(\d{4}-\d{2}-\d{2}\s+\d{2}:\d{2}:\d{2}(?:\.\d+)?)\s+\|\s*(\w+)\s*\|\s*([^|]+?)\s*\|\s*([^|]+?)\s*\|\s*([^|]+?)\s*-\s*(.*)$" + r'^(\d{4}-\d{2}-\d{2}\s+\d{2}:\d{2}:\d{2}(?:\.\d+)?)\s+\|\s*(\w+)\s*\|\s*([^|]+?)\s*\|\s*([^|]+?)\s*\|\s*([^|]+?)\s*-\s*(.*)$' ), # OLDER Loguru format: 2026-01-26 10:30:45 | INFO | name:function:line - message re.compile( - r"^(\d{4}-\d{2}-\d{2}\s+\d{2}:\d{2}:\d{2})\s+\|\s*(\w+)\s*\|\s*([^|]+?)\s*-\s*(.*)$" + r'^(\d{4}-\d{2}-\d{2}\s+\d{2}:\d{2}:\d{2})\s+\|\s*(\w+)\s*\|\s*([^|]+?)\s*-\s*(.*)$' ), # Python logging format: 2026-01-26 10:30:45,123 INFO module:message re.compile( - r"^(\d{4}-\d{2}-\d{2}\s+\d{2}:\d{2}:\d{2}(?:,\d{3})?)\s+(\w+)\s+(\w+?):\s*(.*)$" + r'^(\d{4}-\d{2}-\d{2}\s+\d{2}:\d{2}:\d{2}(?:,\d{3})?)\s+(\w+)\s+(\w+?):\s*(.*)$' ), # Standard format: 2026-01-26 10:30:45 [INFO] module - message re.compile( - r"^(\d{4}-\d{2}-\d{2}\s+\d{2}:\d{2}:\d{2})\s+\[(\w+)\]\s+(\w+)\s*-\s*(.*)$" + r'^(\d{4}-\d{2}-\d{2}\s+\d{2}:\d{2}:\d{2})\s+\[(\w+)\]\s+(\w+)\s*-\s*(.*)$' ), # Simple format: [INFO] message - re.compile(r"^\[(\w+)\]\s+(.*)$"), + re.compile(r'^\[(\w+)\]\s+(.*)$'), ] # Timestamp formats to try TIMESTAMP_FORMATS = [ "%Y-%m-%d %H:%M:%S,%f", # 2026-01-26 10:30:45,123 - "%Y-%m-%d %H:%M:%S", # 2026-01-26 10:30:45 + "%Y-%m-%d %H:%M:%S", # 2026-01-26 10:30:45 "%Y-%m-%d %H:%M:%S.%f", # 2026-01-26 10:30:45.123 ] ERROR_SIGNATURE_PATTERNS = [ - re.compile(r"(\w+(?:Error|Exception|Timeout|Refused|Unavailable))"), - re.compile(r"\b(Traceback)\b"), + re.compile(r'(\w+(?:Error|Exception|Timeout|Refused|Unavailable))'), + re.compile(r'\b(Traceback)\b'), ] def __init__(self): self.module_mapping = { - "backend.log": "backend", - "worker.log": "worker", - "server.log": "server", - "application.log": "application", + 'backend.log': 'backend', + 'worker.log': 'worker', + 'server.log': 'server', + 'application.log': 'application', } - def parse_line(self, line: str, filename: str = "unknown") -> dict | None: + def parse_line(self, line: str, filename: str = "unknown") -> Optional[dict]: """Parse a single log line. Args: @@ -78,17 +80,17 @@ def parse_line(self, line: str, filename: str = "unknown") -> dict | None: return None return { - "timestamp": datetime.now(), - "level": "INFO", - "module": self._get_module_from_filename(filename), - "message": line.strip(), - "raw_line": line, - "request_id": "-", - "user_id": "-", - "middleware_trace_id": "-", + 'timestamp': datetime.now(), + 'level': 'INFO', + 'module': self._get_module_from_filename(filename), + 'message': line.strip(), + 'raw_line': line, + 'request_id': '-', + 'user_id': '-', + 'middleware_trace_id': '-', } - def _parse_line_strict(self, line: str, filename: str = "unknown") -> dict | None: + def _parse_line_strict(self, line: str, filename: str = "unknown") -> Optional[dict]: """Parse line only when line matches known log patterns.""" if not line or not line.strip(): return None @@ -102,7 +104,11 @@ def _parse_line_strict(self, line: str, filename: str = "unknown") -> dict | Non return None - def parse_file(self, filepath: str, max_lines: int | None = None) -> list[dict]: + def parse_file( + self, + filepath: str, + max_lines: Optional[int] = None + ) -> List[dict]: """Parse entire log file. Args: @@ -112,7 +118,7 @@ def parse_file(self, filepath: str, max_lines: int | None = None) -> list[dict]: Returns: List of parsed log entries """ - entries: list[dict] = [] + entries: List[dict] = [] path = Path(filepath) if not path.exists(): @@ -120,7 +126,7 @@ def parse_file(self, filepath: str, max_lines: int | None = None) -> list[dict]: return entries try: - with open(path, encoding="utf-8", errors="ignore") as f: + with open(path, 'r', encoding='utf-8', errors='ignore') as f: for i, line in enumerate(f): if max_lines and i >= max_lines: break @@ -131,12 +137,8 @@ def parse_file(self, filepath: str, max_lines: int | None = None) -> list[dict]: continue if self._is_continuation_line(line) and entries: - entries[-1]["message"] = ( - f"{entries[-1]['message']}\n{line.rstrip()}" - ) - entries[-1]["raw_line"] = ( - f"{entries[-1]['raw_line']}\n{line.rstrip()}" - ) + entries[-1]['message'] = f"{entries[-1]['message']}\n{line.rstrip()}" + entries[-1]['raw_line'] = f"{entries[-1]['raw_line']}\n{line.rstrip()}" continue fallback_entry = self.parse_line(line, path.name) @@ -163,126 +165,92 @@ def _extract_fields(self, match: re.Match, line: str, filename: str) -> dict: # NEWEST format with request_id, user_id & middleware_trace_id: 7 groups # timestamp, level, request_id, user_id, middleware_trace_id, location, message - if len(groups) == 7 and groups[1].strip().upper() in ( - "INFO", - "WARNING", - "ERROR", - "DEBUG", - "TRACE", - "SUCCESS", - ): - ( - timestamp_str, - level, - request_id, - user_id, - middleware_trace_id, - location, - message, - ) = groups + if len(groups) == 7 and groups[1].strip().upper() in ('INFO', 'WARNING', 'ERROR', 'DEBUG', 'TRACE', 'SUCCESS'): + timestamp_str, level, request_id, user_id, middleware_trace_id, location, message = groups timestamp = self._parse_timestamp(timestamp_str) module = self._extract_module_from_location(location.strip()) return { - "timestamp": timestamp, - "level": level.strip().upper(), - "module": module, - "message": message, - "raw_line": line, - "request_id": request_id.strip(), - "user_id": user_id.strip(), - "middleware_trace_id": middleware_trace_id.strip(), + 'timestamp': timestamp, + 'level': level.strip().upper(), + 'module': module, + 'message': message, + 'raw_line': line, + 'request_id': request_id.strip(), + 'user_id': user_id.strip(), + 'middleware_trace_id': middleware_trace_id.strip(), } # OLD format with request_id & user_id: 6 groups # timestamp, level, request_id, user_id, location, message - if len(groups) == 6 and groups[1].strip().upper() in ( - "INFO", - "WARNING", - "ERROR", - "DEBUG", - "TRACE", - "SUCCESS", - ): + if len(groups) == 6 and groups[1].strip().upper() in ('INFO', 'WARNING', 'ERROR', 'DEBUG', 'TRACE', 'SUCCESS'): timestamp_str, level, request_id, user_id, location, message = groups timestamp = self._parse_timestamp(timestamp_str) module = self._extract_module_from_location(location.strip()) return { - "timestamp": timestamp, - "level": level.strip().upper(), - "module": module, - "message": message, - "raw_line": line, - "request_id": request_id.strip(), - "user_id": user_id.strip(), - "middleware_trace_id": "-", + 'timestamp': timestamp, + 'level': level.strip().upper(), + 'module': module, + 'message': message, + 'raw_line': line, + 'request_id': request_id.strip(), + 'user_id': user_id.strip(), + 'middleware_trace_id': '-', } # OLD Loguru format: 4 groups — timestamp, level, location, message - if len(groups) == 4 and groups[1].strip().upper() in ( - "INFO", - "WARNING", - "ERROR", - "DEBUG", - "TRACE", - "SUCCESS", - ): + if len(groups) == 4 and groups[1].strip().upper() in ('INFO', 'WARNING', 'ERROR', 'DEBUG', 'TRACE', 'SUCCESS'): timestamp_str, level, location, message = groups timestamp = self._parse_timestamp(timestamp_str) module = self._extract_module_from_location(location.strip()) return { - "timestamp": timestamp, - "level": level.strip().upper(), - "module": module, - "message": message, - "raw_line": line, - "request_id": "-", - "user_id": "-", - "middleware_trace_id": "-", + 'timestamp': timestamp, + 'level': level.strip().upper(), + 'module': module, + 'message': message, + 'raw_line': line, + 'request_id': '-', + 'user_id': '-', + 'middleware_trace_id': '-', } # Pattern with timestamp, level, module, message - if len(groups) == 4 and groups[1].upper() in ( - "INFO", - "WARNING", - "ERROR", - "DEBUG", - ): + if len(groups) == 4 and groups[1].upper() in ('INFO', 'WARNING', 'ERROR', 'DEBUG'): timestamp_str, level, module, message = groups timestamp = self._parse_timestamp(timestamp_str) return { - "timestamp": timestamp, - "level": level.upper(), - "module": module, - "message": message, - "raw_line": line, - "request_id": "-", - "user_id": "-", - "middleware_trace_id": "-", + 'timestamp': timestamp, + 'level': level.upper(), + 'module': module, + 'message': message, + 'raw_line': line, + 'request_id': '-', + 'user_id': '-', + 'middleware_trace_id': '-', } # Pattern with level, message only elif len(groups) == 2: level, message = groups return { - "timestamp": datetime.now(), - "level": level.upper(), - "module": self._get_module_from_filename(filename), - "message": message, - "raw_line": line, - "request_id": "-", - "user_id": "-", - "middleware_trace_id": "-", + 'timestamp': datetime.now(), + 'level': level.upper(), + 'module': self._get_module_from_filename(filename), + 'message': message, + 'raw_line': line, + 'request_id': '-', + 'user_id': '-', + 'middleware_trace_id': '-', } return { - "timestamp": datetime.now(), - "level": "INFO", - "module": self._get_module_from_filename(filename), - "message": line, - "raw_line": line, - "request_id": "-", - "user_id": "-", - "middleware_trace_id": "-", + 'timestamp': datetime.now(), + 'level': 'INFO', + 'module': self._get_module_from_filename(filename), + 'message': line, + 'raw_line': line, + 'request_id': '-', + 'user_id': '-', + 'middleware_trace_id': '-', } def _is_continuation_line(self, line: str) -> bool: @@ -294,11 +262,7 @@ def _is_continuation_line(self, line: str) -> bool: if line.startswith(" ") or line.startswith("\t"): return True - return ( - stripped.startswith("File ") - or stripped.startswith("Traceback") - or stripped.startswith("During handling") - ) + return stripped.startswith("File ") or stripped.startswith("Traceback") or stripped.startswith("During handling") def extract_error_signature(self, message: str) -> str: """Extract stable error signature for clustering.""" @@ -324,14 +288,14 @@ def _extract_module_from_location(self, location: str) -> str: Module name """ # Split by colon to get the name part - parts = location.split(":") + parts = location.split(':') if parts: name = parts[0] # Extract the last part (e.g., "stock_datasource.services.task_queue" -> "task_queue") - module_parts = name.split(".") + module_parts = name.split('.') if module_parts: return module_parts[-1] - return "unknown" + return 'unknown' def _remove_ansi_codes(self, text: str) -> str: """Remove ANSI color codes from text. @@ -343,8 +307,8 @@ def _remove_ansi_codes(self, text: str) -> str: Text without ANSI codes """ # ANSI escape sequences: \x1B[ followed by any characters until m - ansi_escape = re.compile(r"\x1B(?:[@-Z\\-_]|\[[0-?]*[ -/]*[@-~])") - return ansi_escape.sub("", text) + ansi_escape = re.compile(r'\x1B(?:[@-Z\\-_]|\[[0-?]*[ -/]*[@-~])') + return ansi_escape.sub('', text) def _parse_timestamp(self, timestamp_str: str) -> datetime: """Parse timestamp string. @@ -373,7 +337,7 @@ def _get_module_from_filename(self, filename: str) -> str: Returns: Module name or 'unknown' """ - return self.module_mapping.get(filename, "unknown") + return self.module_mapping.get(filename, 'unknown') class LogFileReader: @@ -392,14 +356,14 @@ def __init__(self, log_dir: str = "logs"): def read_logs( self, log_file: str = None, - start_time: datetime | None = None, - end_time: datetime | None = None, - level: str | None = None, - keyword: str | None = None, - request_id: str | None = None, + start_time: Optional[datetime] = None, + end_time: Optional[datetime] = None, + level: Optional[str] = None, + keyword: Optional[str] = None, + request_id: Optional[str] = None, limit: int = 100, - offset: int = 0, - ) -> list[dict]: + offset: int = 0 + ) -> List[dict]: """Read and filter logs. Args: @@ -435,17 +399,17 @@ def read_logs( matched_logs.append(log) # Sort by timestamp descending - matched_logs.sort(key=lambda x: x["timestamp"], reverse=True) + matched_logs.sort(key=lambda x: x['timestamp'], reverse=True) # Apply pagination - return matched_logs[offset : offset + limit] + return matched_logs[offset:offset + limit] def _resolve_log_files( self, - log_file: str | None = None, - start_time: datetime | None = None, - end_time: datetime | None = None, - ) -> list[Path]: + log_file: Optional[str] = None, + start_time: Optional[datetime] = None, + end_time: Optional[datetime] = None, + ) -> List[Path]: """Return candidate log files ordered from newest to oldest.""" if log_file: return [self.log_dir / log_file] @@ -475,39 +439,39 @@ def _resolve_log_files( def _matches_filters( self, log: dict, - start_time: datetime | None = None, - end_time: datetime | None = None, - level: str | None = None, - keyword: str | None = None, - request_id: str | None = None, + start_time: Optional[datetime] = None, + end_time: Optional[datetime] = None, + level: Optional[str] = None, + keyword: Optional[str] = None, + request_id: Optional[str] = None, ) -> bool: """Check whether a single log entry matches the given filters.""" - if start_time and log["timestamp"] < start_time: + if start_time and log['timestamp'] < start_time: return False - if end_time and log["timestamp"] > end_time: + if end_time and log['timestamp'] > end_time: return False - if level and log["level"] != level.upper(): + if level and log['level'] != level.upper(): return False - if keyword and keyword.lower() not in log["message"].lower(): + if keyword and keyword.lower() not in log['message'].lower(): return False - if request_id and str(log.get("request_id", "-")).strip() != request_id.strip(): + if request_id and str(log.get('request_id', '-')).strip() != request_id.strip(): return False return True def _apply_filters( self, - logs: list[dict], - start_time: datetime | None = None, - end_time: datetime | None = None, - level: str | None = None, - keyword: str | None = None, - request_id: str | None = None, - ) -> list[dict]: + logs: List[dict], + start_time: Optional[datetime] = None, + end_time: Optional[datetime] = None, + level: Optional[str] = None, + keyword: Optional[str] = None, + request_id: Optional[str] = None, + ) -> List[dict]: """Apply filters to log list. Args: @@ -523,8 +487,7 @@ def _apply_filters( filtered = logs filtered = [ - log - for log in filtered + log for log in filtered if self._matches_filters( log, start_time=start_time, @@ -537,7 +500,7 @@ def _apply_filters( return filtered - def get_log_files(self) -> list[dict]: + def get_log_files(self) -> List[dict]: """Get list of all log files. Returns: @@ -555,23 +518,21 @@ def get_log_files(self) -> list[dict]: # Estimate line count (rough estimate: average 200 chars per line) estimated_lines = stat.st_size // 200 - files.append( - { - "name": filepath.name, - "size": stat.st_size, - "modified_time": datetime.fromtimestamp(stat.st_mtime), - "line_count": estimated_lines, - } - ) + files.append({ + 'name': filepath.name, + 'size': stat.st_size, + 'modified_time': datetime.fromtimestamp(stat.st_mtime), + 'line_count': estimated_lines + }) except Exception as e: logger.error(f"Error getting file info for {filepath}: {e}") # Sort by modified time descending - files.sort(key=lambda x: x["modified_time"], reverse=True) + files.sort(key=lambda x: x['modified_time'], reverse=True) return files - def get_archive_files(self) -> list[dict]: + def get_archive_files(self) -> List[dict]: """Get list of archived log files. Returns: @@ -587,18 +548,16 @@ def get_archive_files(self) -> list[dict]: if filepath.is_file(): try: stat = filepath.stat() - files.append( - { - "name": filepath.name, - "size": stat.st_size, - "modified_time": datetime.fromtimestamp(stat.st_mtime), - "line_count": 0, # Cannot estimate without decompressing - } - ) + files.append({ + 'name': filepath.name, + 'size': stat.st_size, + 'modified_time': datetime.fromtimestamp(stat.st_mtime), + 'line_count': 0 # Cannot estimate without decompressing + }) except Exception as e: logger.error(f"Error getting archive info for {filepath}: {e}") # Sort by modified time descending - files.sort(key=lambda x: x["modified_time"], reverse=True) + files.sort(key=lambda x: x['modified_time'], reverse=True) return files diff --git a/src/stock_datasource/modules/system_logs/schemas.py b/src/stock_datasource/modules/system_logs/schemas.py index 88e1d35f..3de0a39a 100644 --- a/src/stock_datasource/modules/system_logs/schemas.py +++ b/src/stock_datasource/modules/system_logs/schemas.py @@ -1,7 +1,7 @@ """Schemas for system logs module.""" from datetime import datetime - +from typing import Dict, List, Optional from pydantic import BaseModel, Field @@ -13,53 +13,43 @@ class LogEntry(BaseModel): module: str = Field(..., description="Module name (e.g., backend, worker, server)") message: str = Field(..., description="Log message") raw_line: str = Field(..., description="Original raw log line") - request_id: str | None = Field("-", description="Request ID for log correlation") - user_id: str | None = Field("-", description="User ID for log correlation") - middleware_trace_id: str | None = Field( - "-", description="Middleware trace ID for correlation" - ) + request_id: Optional[str] = Field("-", description="Request ID for log correlation") + user_id: Optional[str] = Field("-", description="User ID for log correlation") + middleware_trace_id: Optional[str] = Field("-", description="Middleware trace ID for correlation") class Config: - json_encoders = {datetime: lambda v: v.isoformat()} + json_encoders = { + datetime: lambda v: v.isoformat() + } class LogFilter(BaseModel): """Filter parameters for log queries.""" - level: str | None = Field(None, description="Filter by log level") - start_time: datetime | None = Field(None, description="Start time filter") - end_time: datetime | None = Field(None, description="End time filter") - keyword: str | None = Field( - None, max_length=200, description="Keyword search in message" - ) - request_id: str | None = Field( - None, max_length=32, description="Filter by request ID" - ) - middleware_trace_id: str | None = Field( - None, max_length=32, description="Filter by middleware trace ID" - ) + level: Optional[str] = Field(None, description="Filter by log level") + start_time: Optional[datetime] = Field(None, description="Start time filter") + end_time: Optional[datetime] = Field(None, description="End time filter") + keyword: Optional[str] = Field(None, max_length=200, description="Keyword search in message") + request_id: Optional[str] = Field(None, max_length=32, description="Filter by request ID") + middleware_trace_id: Optional[str] = Field(None, max_length=32, description="Filter by middleware trace ID") page: int = Field(1, ge=1, description="Page number") page_size: int = Field(50, ge=1, le=1000, description="Page size") class Config: - json_encoders = {datetime: lambda v: v.isoformat()} + json_encoders = { + datetime: lambda v: v.isoformat() + } class LogInsightFilter(BaseModel): """Filter params for stats/clusters/timeline insights.""" - level: str | None = Field(None, description="Filter by log level") - start_time: datetime | None = Field(None, description="Start time filter") - end_time: datetime | None = Field(None, description="End time filter") - keyword: str | None = Field( - None, max_length=200, description="Keyword search in message" - ) - request_id: str | None = Field( - None, max_length=32, description="Filter by request ID" - ) - window_hours: int = Field( - 2, ge=1, le=72, description="Fallback time window when start/end not provided" - ) + level: Optional[str] = Field(None, description="Filter by log level") + start_time: Optional[datetime] = Field(None, description="Start time filter") + end_time: Optional[datetime] = Field(None, description="End time filter") + keyword: Optional[str] = Field(None, max_length=200, description="Keyword search in message") + request_id: Optional[str] = Field(None, max_length=32, description="Filter by request ID") + window_hours: int = Field(2, ge=1, le=72, description="Fallback time window when start/end not provided") limit: int = Field(50, ge=1, le=500, description="Result limit") @@ -72,16 +62,18 @@ class LogFileInfo(BaseModel): line_count: int = Field(..., description="Estimated line count") class Config: - json_encoders = {datetime: lambda v: v.isoformat()} + json_encoders = { + datetime: lambda v: v.isoformat() + } class LogRootCause(BaseModel): """Structured root cause candidate.""" title: str - module: str | None = None - function: str | None = None - evidence: list[str] = Field(default_factory=list) + module: Optional[str] = None + function: Optional[str] = None + evidence: List[str] = Field(default_factory=list) confidence: float = Field(0.0, ge=0.0, le=1.0) @@ -89,58 +81,38 @@ class LogFixSuggestion(BaseModel): """Structured fix suggestion.""" title: str - steps: list[str] = Field(default_factory=list) + steps: List[str] = Field(default_factory=list) priority: str = Field("medium", description="low/medium/high") class LogAnalysisRequest(BaseModel): """Request for AI log analysis.""" - log_entries: list[LogEntry] = Field( - default_factory=list, description="Log entries to analyze" - ) - user_query: str | None = Field( - None, max_length=500, description="User's specific question" - ) - context: str | None = Field( - None, max_length=1000, description="Additional diagnosis context" - ) - start_time: datetime | None = Field(None, description="Start time filter") - end_time: datetime | None = Field(None, description="End time filter") - level: str | None = Field(None, description="Optional level filter") - query: str | None = Field( - None, max_length=200, description="Optional keyword query" - ) + log_entries: List[LogEntry] = Field(default_factory=list, description="Log entries to analyze") + user_query: Optional[str] = Field(None, max_length=500, description="User's specific question") + context: Optional[str] = Field(None, max_length=1000, description="Additional diagnosis context") + start_time: Optional[datetime] = Field(None, description="Start time filter") + end_time: Optional[datetime] = Field(None, description="End time filter") + level: Optional[str] = Field(None, description="Optional level filter") + query: Optional[str] = Field(None, max_length=200, description="Optional keyword query") default_window_hours: int = Field(2, ge=1, le=72) - include_code_context: bool = Field( - True, description="Whether diagnosis should include code hints" - ) - max_entries: int = Field( - 50, ge=5, le=500, description="Maximum log entries used for analysis" - ) + include_code_context: bool = Field(True, description="Whether diagnosis should include code hints") + max_entries: int = Field(50, ge=5, le=500, description="Maximum log entries used for analysis") class LogAnalysisResponse(BaseModel): """Response from AI log analysis.""" error_type: str = Field(..., description="Type of error") - possible_causes: list[str] = Field( - default_factory=list, description="Possible causes" - ) - suggested_fixes: list[str] = Field( - default_factory=list, description="Suggested fixes" - ) + possible_causes: List[str] = Field(default_factory=list, description="Possible causes") + suggested_fixes: List[str] = Field(default_factory=list, description="Suggested fixes") confidence: float = Field(..., ge=0.0, le=1.0, description="Confidence score") - related_logs: list[str] = Field( - default_factory=list, description="Related log entries" - ) + related_logs: List[str] = Field(default_factory=list, description="Related log entries") summary: str = Field(default="", description="Diagnosis summary") - analysis_source: str = Field( - default="rule_based", description="orchestrator/rule_based/hybrid" - ) - root_causes: list[LogRootCause] = Field(default_factory=list) - recent_operations: list[dict] = Field(default_factory=list) - fix_suggestions: list[LogFixSuggestion] = Field(default_factory=list) + analysis_source: str = Field(default="rule_based", description="orchestrator/rule_based/hybrid") + root_causes: List[LogRootCause] = Field(default_factory=list) + recent_operations: List[Dict] = Field(default_factory=list) + fix_suggestions: List[LogFixSuggestion] = Field(default_factory=list) risk_level: str = Field(default="low", description="low/medium/high/critical") impact_scope: str = Field(default="未识别明显影响范围") diagnosis_time: datetime = Field(default_factory=datetime.now) @@ -165,8 +137,8 @@ class LogStatsResponse(BaseModel): warning: int = 0 info: int = 0 debug: int = 0 - by_level: dict[str, int] = Field(default_factory=dict) - trend: list[LogStatsTrendPoint] = Field(default_factory=list) + by_level: Dict[str, int] = Field(default_factory=dict) + trend: List[LogStatsTrendPoint] = Field(default_factory=list) class ErrorClusterItem(BaseModel): @@ -183,7 +155,7 @@ class ErrorClusterItem(BaseModel): class ErrorClusterResponse(BaseModel): """Error cluster response.""" - clusters: list[ErrorClusterItem] = Field(default_factory=list) + clusters: List[ErrorClusterItem] = Field(default_factory=list) class OperationTimelineItem(BaseModel): @@ -194,22 +166,20 @@ class OperationTimelineItem(BaseModel): level: str module: str summary: str - detail: str | None = None - request_id: str | None = Field( - None, description="Request ID for timeline correlation" - ) + detail: Optional[str] = None + request_id: Optional[str] = Field(None, description="Request ID for timeline correlation") class OperationTimelineResponse(BaseModel): """Recent operation timeline response.""" - items: list[OperationTimelineItem] = Field(default_factory=list) + items: List[OperationTimelineItem] = Field(default_factory=list) class LogListResponse(BaseModel): """Response for log list query.""" - logs: list[LogEntry] = Field(..., description="Log entries") + logs: List[LogEntry] = Field(..., description="Log entries") total: int = Field(..., description="Total matching logs") page: int = Field(..., description="Current page number") page_size: int = Field(..., description="Page size") @@ -218,4 +188,4 @@ class LogListResponse(BaseModel): class ArchiveListResponse(BaseModel): """Response for archive list query.""" - archives: list[LogFileInfo] = Field(..., description="Archive files") + archives: List[LogFileInfo] = Field(..., description="Archive files") diff --git a/src/stock_datasource/modules/system_logs/service.py b/src/stock_datasource/modules/system_logs/service.py index 9282de55..82f67ed1 100644 --- a/src/stock_datasource/modules/system_logs/service.py +++ b/src/stock_datasource/modules/system_logs/service.py @@ -5,10 +5,8 @@ from collections import Counter, defaultdict from datetime import datetime, timedelta from pathlib import Path -from typing import Any +from typing import Any, Dict, List, Optional, Tuple -from .ai_diagnosis_service import get_log_ai_diagnosis_service -from .log_parser import LogFileReader from .schemas import ( ErrorClusterItem, ErrorClusterResponse, @@ -26,6 +24,8 @@ OperationTimelineItem, OperationTimelineResponse, ) +from .ai_diagnosis_service import get_log_ai_diagnosis_service +from .log_parser import LogFileReader logger = logging.getLogger(__name__) @@ -48,7 +48,6 @@ def _ch_client(self): """Lazy-access ClickHouse client.""" try: from stock_datasource.models.database import db_client - return db_client except Exception: return None @@ -88,7 +87,7 @@ def get_logs(self, filters: LogFilter) -> LogListResponse: page_size=filters.page_size, ) - def _get_logs_from_clickhouse(self, filters: LogFilter) -> LogListResponse | None: + def _get_logs_from_clickhouse(self, filters: LogFilter) -> Optional[LogListResponse]: """Query logs from ClickHouse system_structured_logs table. Returns None if ClickHouse is unavailable (caller should fall back). @@ -98,7 +97,7 @@ def _get_logs_from_clickhouse(self, filters: LogFilter) -> LogListResponse | Non return None try: conditions = [] - params: dict[str, Any] = {} + params: Dict[str, Any] = {} if filters.start_time: conditions.append("timestamp >= %(start_time)s") @@ -115,10 +114,7 @@ def _get_logs_from_clickhouse(self, filters: LogFilter) -> LogListResponse | Non if filters.request_id and filters.request_id != "-": conditions.append("request_id = %(request_id)s") params["request_id"] = filters.request_id - if ( - getattr(filters, "middleware_trace_id", None) - and filters.middleware_trace_id != "-" - ): + if getattr(filters, 'middleware_trace_id', None) and filters.middleware_trace_id != "-": conditions.append("middleware_trace_id = %(middleware_trace_id)s") params["middleware_trace_id"] = filters.middleware_trace_id @@ -144,18 +140,16 @@ def _get_logs_from_clickhouse(self, filters: LogFilter) -> LogListResponse | Non log_entries = [] for row in rows: ts, level, req_id, uid, mw_id, module, func, line_no, msg, exc = row - log_entries.append( - LogEntry( - timestamp=ts if isinstance(ts, datetime) else datetime.now(), - level=str(level), - module=str(module), - message=str(msg), - raw_line=f"{ts} | {level} | {req_id} | {uid} | {mw_id} | {module}:{func}:{line_no} - {msg}", - request_id=str(req_id), - user_id=str(uid), - middleware_trace_id=str(mw_id), - ) - ) + log_entries.append(LogEntry( + timestamp=ts if isinstance(ts, datetime) else datetime.now(), + level=str(level), + module=str(module), + message=str(msg), + raw_line=f"{ts} | {level} | {req_id} | {uid} | {mw_id} | {module}:{func}:{line_no} - {msg}", + request_id=str(req_id), + user_id=str(uid), + middleware_trace_id=str(mw_id), + )) return LogListResponse( logs=log_entries, @@ -187,15 +181,15 @@ def get_stats(self, filters: LogInsightFilter) -> LogStatsResponse: offset=0, ) - level_counter = Counter(log.get("level", "INFO").upper() for log in logs) - trend_bucket: dict[datetime, Counter] = defaultdict(Counter) + level_counter = Counter(log.get('level', 'INFO').upper() for log in logs) + trend_bucket: Dict[datetime, Counter] = defaultdict(Counter) for log in logs: - timestamp = log.get("timestamp") + timestamp = log.get('timestamp') if not isinstance(timestamp, datetime): continue bucket = timestamp.replace(minute=0, second=0, microsecond=0) - lvl = str(log.get("level", "INFO")).upper() + lvl = str(log.get('level', 'INFO')).upper() trend_bucket[bucket][lvl] += 1 trend = [] @@ -205,26 +199,24 @@ def get_stats(self, filters: LogInsightFilter) -> LogStatsResponse: LogStatsTrendPoint( timestamp=bucket, total=sum(counter.values()), - error=counter.get("ERROR", 0), - warning=counter.get("WARNING", 0), - info=counter.get("INFO", 0), - debug=counter.get("DEBUG", 0), + error=counter.get('ERROR', 0), + warning=counter.get('WARNING', 0), + info=counter.get('INFO', 0), + debug=counter.get('DEBUG', 0), ) ) return LogStatsResponse( total=len(logs), - error=level_counter.get("ERROR", 0), - warning=level_counter.get("WARNING", 0), - info=level_counter.get("INFO", 0), - debug=level_counter.get("DEBUG", 0), + error=level_counter.get('ERROR', 0), + warning=level_counter.get('WARNING', 0), + info=level_counter.get('INFO', 0), + debug=level_counter.get('DEBUG', 0), by_level=dict(level_counter), trend=trend, ) - def _get_stats_from_clickhouse( - self, filters: LogInsightFilter - ) -> LogStatsResponse | None: + def _get_stats_from_clickhouse(self, filters: LogInsightFilter) -> Optional[LogStatsResponse]: """Query stats from ClickHouse. Returns None on failure.""" ch = self._ch_client if ch is None: @@ -235,7 +227,7 @@ def _get_stats_from_clickhouse( "timestamp >= %(start_time)s", "timestamp <= %(end_time)s", ] - params: dict[str, Any] = { + params: Dict[str, Any] = { "start_time": start_time.strftime("%Y-%m-%d %H:%M:%S"), "end_time": end_time.strftime("%Y-%m-%d %H:%M:%S"), } @@ -266,7 +258,7 @@ def _get_stats_from_clickhouse( f"GROUP BY bucket, level ORDER BY bucket" ) trend_rows = ch.execute(trend_sql, params) - trend_bucket: dict[datetime, Counter] = defaultdict(Counter) + trend_bucket: Dict[datetime, Counter] = defaultdict(Counter) for row in trend_rows: bucket_dt = row[0] if isinstance(row[0], datetime) else start_time lvl = str(row[1]) @@ -276,24 +268,22 @@ def _get_stats_from_clickhouse( trend = [] for bucket in sorted(trend_bucket.keys()): counter = trend_bucket[bucket] - trend.append( - LogStatsTrendPoint( - timestamp=bucket, - total=sum(counter.values()), - error=counter.get("ERROR", 0), - warning=counter.get("WARNING", 0), - info=counter.get("INFO", 0), - debug=counter.get("DEBUG", 0), - ) - ) + trend.append(LogStatsTrendPoint( + timestamp=bucket, + total=sum(counter.values()), + error=counter.get('ERROR', 0), + warning=counter.get('WARNING', 0), + info=counter.get('INFO', 0), + debug=counter.get('DEBUG', 0), + )) total = sum(level_counter.values()) return LogStatsResponse( total=total, - error=level_counter.get("ERROR", 0), - warning=level_counter.get("WARNING", 0), - info=level_counter.get("INFO", 0), - debug=level_counter.get("DEBUG", 0), + error=level_counter.get('ERROR', 0), + warning=level_counter.get('WARNING', 0), + info=level_counter.get('INFO', 0), + debug=level_counter.get('DEBUG', 0), by_level=level_counter, trend=trend, ) @@ -306,9 +296,7 @@ def get_error_clusters(self, filters: LogInsightFilter) -> ErrorClusterResponse: start_time, end_time = self._resolve_time_window(filters) # Try ClickHouse path - ch_result = self._get_error_clusters_from_clickhouse( - filters, start_time, end_time - ) + ch_result = self._get_error_clusters_from_clickhouse(filters, start_time, end_time) if ch_result is not None: return ch_result @@ -324,47 +312,42 @@ def get_error_clusters(self, filters: LogInsightFilter) -> ErrorClusterResponse: offset=0, ) - grouped: dict[tuple[str, str], dict[str, Any]] = {} + grouped: Dict[Tuple[str, str], Dict[str, Any]] = {} for log in logs: - level = str(log.get("level", "")).upper() - if level not in ("ERROR", "WARNING"): + level = str(log.get('level', '')).upper() + if level not in ('ERROR', 'WARNING'): continue if filters.level and filters.level.upper() != level: continue - message = str(log.get("message", "")) + message = str(log.get('message', '')) signature = self.reader.parser.extract_error_signature(message) - module = str(log.get("module", "unknown")) + module = str(log.get('module', 'unknown')) key = (signature, module) if key not in grouped: grouped[key] = { - "signature": signature, - "count": 0, - "level": level, - "module": module, - "latest_time": log.get("timestamp") or datetime.now(), - "sample_message": message[:200], + 'signature': signature, + 'count': 0, + 'level': level, + 'module': module, + 'latest_time': log.get('timestamp') or datetime.now(), + 'sample_message': message[:200], } - grouped[key]["count"] += 1 - if (log.get("timestamp") or datetime.min) > grouped[key]["latest_time"]: - grouped[key]["latest_time"] = log.get("timestamp") - grouped[key]["sample_message"] = message[:200] + grouped[key]['count'] += 1 + if (log.get('timestamp') or datetime.min) > grouped[key]['latest_time']: + grouped[key]['latest_time'] = log.get('timestamp') + grouped[key]['sample_message'] = message[:200] - if level == "ERROR": - grouped[key]["level"] = "ERROR" + if level == 'ERROR': + grouped[key]['level'] = 'ERROR' clusters = [ErrorClusterItem(**item) for item in grouped.values()] - clusters.sort( - key=lambda item: (item.level != "ERROR", -item.count, item.latest_time), - reverse=False, - ) - return ErrorClusterResponse(clusters=clusters[: filters.limit]) + clusters.sort(key=lambda item: (item.level != 'ERROR', -item.count, item.latest_time), reverse=False) + return ErrorClusterResponse(clusters=clusters[:filters.limit]) - def _get_error_clusters_from_clickhouse( - self, filters: LogInsightFilter, start_time: datetime, end_time: datetime - ) -> ErrorClusterResponse | None: + def _get_error_clusters_from_clickhouse(self, filters: LogInsightFilter, start_time: datetime, end_time: datetime) -> Optional[ErrorClusterResponse]: """Query error clusters from ClickHouse. Returns None on failure.""" ch = self._ch_client if ch is None: @@ -381,7 +364,7 @@ def _get_error_clusters_from_clickhouse( "timestamp >= %(start_time)s", "timestamp <= %(end_time)s", ] - params: dict[str, Any] = { + params: Dict[str, Any] = { "start_time": start_time.strftime("%Y-%m-%d %H:%M:%S"), "end_time": end_time.strftime("%Y-%m-%d %H:%M:%S"), } @@ -414,26 +397,20 @@ def _get_error_clusters_from_clickhouse( clusters = [] for row in rows: level, module, signature, count, latest, sample = row - clusters.append( - ErrorClusterItem( - signature=str(signature), - count=int(count), - level=str(level), - module=str(module), - latest_time=latest - if isinstance(latest, datetime) - else datetime.now(), - sample_message=str(sample)[:200], - ) - ) + clusters.append(ErrorClusterItem( + signature=str(signature), + count=int(count), + level=str(level), + module=str(module), + latest_time=latest if isinstance(latest, datetime) else datetime.now(), + sample_message=str(sample)[:200], + )) return ErrorClusterResponse(clusters=clusters) except Exception as e: logger.warning(f"ClickHouse error clusters query failed, falling back: {e}") return None - def get_operation_timeline( - self, filters: LogInsightFilter - ) -> OperationTimelineResponse: + def get_operation_timeline(self, filters: LogInsightFilter) -> OperationTimelineResponse: """Build a mixed timeline from logs and schedule execution history.""" start_time, end_time = self._resolve_time_window(filters) @@ -442,12 +419,10 @@ def get_operation_timeline( if ch_items is not None: # Merge with schedule items - schedule_items = self._build_schedule_timeline_items( - start_time=start_time, end_time=end_time - ) + schedule_items = self._build_schedule_timeline_items(start_time=start_time, end_time=end_time) ch_items.extend(schedule_items) ch_items.sort(key=lambda item: item.timestamp, reverse=True) - return OperationTimelineResponse(items=ch_items[: filters.limit]) + return OperationTimelineResponse(items=ch_items[:filters.limit]) # Fallback: file parsing logs = self.reader.read_logs( @@ -461,33 +436,29 @@ def get_operation_timeline( offset=0, ) - items: list[OperationTimelineItem] = [] + items: List[OperationTimelineItem] = [] for log in logs[: max(filters.limit * 4, 120)]: - message = str(log.get("message", "")).strip() + message = str(log.get('message', '')).strip() # Classify event type: middleware.* → 'middleware', otherwise 'log' - event_type = "middleware" if message.startswith("middleware.") else "log" + event_type = 'middleware' if message.startswith('middleware.') else 'log' items.append( OperationTimelineItem( - timestamp=log.get("timestamp") or datetime.now(), + timestamp=log.get('timestamp') or datetime.now(), event_type=event_type, - level=str(log.get("level", "INFO")).upper(), - module=str(log.get("module", "unknown")), + level=str(log.get('level', 'INFO')).upper(), + module=str(log.get('module', 'unknown')), summary=message[:140], detail=message[:500], - request_id=str(log.get("request_id", "-") or "-"), + request_id=str(log.get('request_id', '-') or '-'), ) ) - schedule_items = self._build_schedule_timeline_items( - start_time=start_time, end_time=end_time - ) + schedule_items = self._build_schedule_timeline_items(start_time=start_time, end_time=end_time) items.extend(schedule_items) items.sort(key=lambda item: item.timestamp, reverse=True) - return OperationTimelineResponse(items=items[: filters.limit]) + return OperationTimelineResponse(items=items[:filters.limit]) - def _get_timeline_from_clickhouse( - self, filters: LogInsightFilter, start_time: datetime, end_time: datetime - ) -> list[OperationTimelineItem] | None: + def _get_timeline_from_clickhouse(self, filters: LogInsightFilter, start_time: datetime, end_time: datetime) -> Optional[List[OperationTimelineItem]]: """Query timeline items from ClickHouse. Returns None on failure.""" ch = self._ch_client if ch is None: @@ -497,7 +468,7 @@ def _get_timeline_from_clickhouse( "timestamp >= %(start_time)s", "timestamp <= %(end_time)s", ] - params: dict[str, Any] = { + params: Dict[str, Any] = { "start_time": start_time.strftime("%Y-%m-%d %H:%M:%S"), "end_time": end_time.strftime("%Y-%m-%d %H:%M:%S"), } @@ -525,26 +496,22 @@ def _get_timeline_from_clickhouse( for row in rows: ts, level, module, msg, request_id = row message = str(msg).strip() - event_type = ( - "middleware" if message.startswith("middleware.") else "log" - ) - items.append( - OperationTimelineItem( - timestamp=ts if isinstance(ts, datetime) else datetime.now(), - event_type=event_type, - level=str(level), - module=str(module), - summary=message[:140], - detail=message[:500], - request_id=str(request_id or "-"), - ) - ) + event_type = 'middleware' if message.startswith('middleware.') else 'log' + items.append(OperationTimelineItem( + timestamp=ts if isinstance(ts, datetime) else datetime.now(), + event_type=event_type, + level=str(level), + module=str(module), + summary=message[:140], + detail=message[:500], + request_id=str(request_id or '-'), + )) return items except Exception as e: logger.warning(f"ClickHouse timeline query failed, falling back: {e}") return None - def get_log_files(self) -> list[LogFileInfo]: + def get_log_files(self) -> List[LogFileInfo]: """Get list of all log files. Returns: @@ -554,19 +521,17 @@ def get_log_files(self) -> list[LogFileInfo]: return [ LogFileInfo( - name=f["name"], - size=f["size"], - modified_time=f["modified_time"], - line_count=f["line_count"], + name=f['name'], + size=f['size'], + modified_time=f['modified_time'], + line_count=f['line_count'] ) for f in files ] - async def analyze_logs( - self, request: LogAnalysisRequest, user_id: str | None = None - ) -> LogAnalysisResponse: + async def analyze_logs(self, request: LogAnalysisRequest, user_id: Optional[str] = None) -> LogAnalysisResponse: """Analyze logs with AI first and fallback to rule-based diagnosis.""" - source_entries: list[LogEntry] = request.log_entries + source_entries: List[LogEntry] = request.log_entries if not source_entries: insight_filters = LogInsightFilter( @@ -616,81 +581,62 @@ async def analyze_logs( logger.error(f"AI diagnosis failed, fallback to rule-based analysis: {e}") return self._analyze_rule_based(source_entries) - def _merge_ai_result( - self, source_entries: list[LogEntry], ai_result: dict[str, Any] - ) -> LogAnalysisResponse: + def _merge_ai_result(self, source_entries: List[LogEntry], ai_result: Dict[str, Any]) -> LogAnalysisResponse: """Normalize AI output into response schema.""" base = self._analyze_rule_based(source_entries) - root_causes_raw = ai_result.get("root_causes", []) + root_causes_raw = ai_result.get('root_causes', []) root_causes = [] for item in root_causes_raw: if isinstance(item, dict): root_causes.append( LogRootCause( - title=str(item.get("title", "未命名根因")), - module=item.get("module"), - function=item.get("function"), - evidence=[str(v) for v in item.get("evidence", [])[:6]], - confidence=float(item.get("confidence", 0.5)), + title=str(item.get('title', '未命名根因')), + module=item.get('module'), + function=item.get('function'), + evidence=[str(v) for v in item.get('evidence', [])[:6]], + confidence=float(item.get('confidence', 0.5)), ) ) - fix_suggestions_raw = ai_result.get("fix_suggestions", []) + fix_suggestions_raw = ai_result.get('fix_suggestions', []) fix_suggestions = [] for item in fix_suggestions_raw: if isinstance(item, dict): - priority = str(item.get("priority", "medium")).lower() - if priority not in ("low", "medium", "high"): - priority = "medium" + priority = str(item.get('priority', 'medium')).lower() + if priority not in ('low', 'medium', 'high'): + priority = 'medium' fix_suggestions.append( LogFixSuggestion( - title=str(item.get("title", "修复建议")), - steps=[str(v) for v in item.get("steps", [])[:10]], + title=str(item.get('title', '修复建议')), + steps=[str(v) for v in item.get('steps', [])[:10]], priority=priority, ) ) - risk_level = str(ai_result.get("risk_level", base.risk_level)).lower() - if risk_level not in ("low", "medium", "high", "critical"): + risk_level = str(ai_result.get('risk_level', base.risk_level)).lower() + if risk_level not in ('low', 'medium', 'high', 'critical'): risk_level = base.risk_level return LogAnalysisResponse( - error_type=str(ai_result.get("error_type", base.error_type)), - possible_causes=[ - str(v) - for v in ai_result.get("possible_causes", base.possible_causes)[:10] - ], - suggested_fixes=[ - str(v) - for v in ai_result.get("suggested_fixes", base.suggested_fixes)[:10] - ], - confidence=float(ai_result.get("confidence", base.confidence)), - related_logs=[ - str(v) for v in ai_result.get("related_logs", base.related_logs)[:10] - ] - or base.related_logs, - summary=str(ai_result.get("summary", base.summary)), - analysis_source="hybrid", + error_type=str(ai_result.get('error_type', base.error_type)), + possible_causes=[str(v) for v in ai_result.get('possible_causes', base.possible_causes)[:10]], + suggested_fixes=[str(v) for v in ai_result.get('suggested_fixes', base.suggested_fixes)[:10]], + confidence=float(ai_result.get('confidence', base.confidence)), + related_logs=[str(v) for v in ai_result.get('related_logs', base.related_logs)[:10]] or base.related_logs, + summary=str(ai_result.get('summary', base.summary)), + analysis_source='hybrid', root_causes=root_causes or base.root_causes, - recent_operations=ai_result.get( - "recent_operations", base.recent_operations - ), + recent_operations=ai_result.get('recent_operations', base.recent_operations), fix_suggestions=fix_suggestions or base.fix_suggestions, risk_level=risk_level, - impact_scope=str(ai_result.get("impact_scope", base.impact_scope)), + impact_scope=str(ai_result.get('impact_scope', base.impact_scope)), ) - def _analyze_rule_based( - self, source_entries: list[LogEntry] - ) -> LogAnalysisResponse: + def _analyze_rule_based(self, source_entries: List[LogEntry]) -> LogAnalysisResponse: """Rule-based diagnosis used as primary fallback.""" - error_logs = [ - log for log in source_entries if str(log.level).upper() == "ERROR" - ] - warning_logs = [ - log for log in source_entries if str(log.level).upper() == "WARNING" - ] + error_logs = [log for log in source_entries if str(log.level).upper() == 'ERROR'] + warning_logs = [log for log in source_entries if str(log.level).upper() == 'WARNING'] if not error_logs and not warning_logs: return LogAnalysisResponse( @@ -757,9 +703,7 @@ def _analyze_rule_based( "修复后观察同类错误频次是否下降", ], confidence=0.62, - related_logs=[ - log.message[:120] for log in (error_logs or warning_logs)[:8] - ], + related_logs=[log.message[:120] for log in (error_logs or warning_logs)[:8]], summary=f"最近窗口发现 {len(error_logs)} 条错误、{len(warning_logs)} 条告警,重点关注 {first_error.module} 模块。", analysis_source="rule_based", root_causes=[root_cause], @@ -798,8 +742,8 @@ def archive_logs(self, retention_days: int = 30) -> dict: archive_path = archive_dir / archive_name # Compress file - with open(filepath, "rb") as f_in: - with gzip.open(archive_path, "wb") as f_out: + with open(filepath, 'rb') as f_in: + with gzip.open(archive_path, 'wb') as f_out: f_out.writelines(f_in) # Delete original file @@ -812,12 +756,12 @@ def archive_logs(self, retention_days: int = 30) -> dict: logger.error(f"Error archiving {filepath}: {e}") return { - "status": "success", - "archived_count": len(archived_files), - "archived_files": archived_files, + 'status': 'success', + 'archived_count': len(archived_files), + 'archived_files': archived_files } - def get_archives(self) -> list[LogFileInfo]: + def get_archives(self) -> List[LogFileInfo]: """Get list of archived log files. Returns: @@ -827,15 +771,19 @@ def get_archives(self) -> list[LogFileInfo]: return [ LogFileInfo( - name=f["name"], - size=f["size"], - modified_time=f["modified_time"], - line_count=f["line_count"], + name=f['name'], + size=f['size'], + modified_time=f['modified_time'], + line_count=f['line_count'] ) for f in archives ] - def export_logs(self, filters: LogFilter, format: str = "csv") -> str: + def export_logs( + self, + filters: LogFilter, + format: str = "csv" + ) -> str: """Export filtered logs to file. Args: @@ -852,7 +800,7 @@ def export_logs(self, filters: LogFilter, format: str = "csv") -> str: end_time=filters.end_time, level=filters.level, keyword=filters.keyword, - limit=100000, # Get all matching logs + limit=100000 # Get all matching logs ) # Create export directory @@ -875,7 +823,7 @@ def export_logs(self, filters: LogFilter, format: str = "csv") -> str: logger.info(f"Exported logs to {filepath}") return str(filepath) - def _export_csv(self, logs: list[dict], filepath: Path): + def _export_csv(self, logs: List[dict], filepath: Path): """Export logs to CSV format. Args: @@ -884,21 +832,19 @@ def _export_csv(self, logs: list[dict], filepath: Path): """ import csv - with open(filepath, "w", newline="", encoding="utf-8") as f: + with open(filepath, 'w', newline='', encoding='utf-8') as f: writer = csv.writer(f) - writer.writerow(["timestamp", "level", "module", "message"]) + writer.writerow(['timestamp', 'level', 'module', 'message']) for log in logs: - writer.writerow( - [ - log["timestamp"].isoformat(), - log["level"], - log["module"], - log["message"].replace("\n", " "), # Remove newlines - ] - ) - - def _export_json(self, logs: list[dict], filepath: Path): + writer.writerow([ + log['timestamp'].isoformat(), + log['level'], + log['module'], + log['message'].replace('\n', ' ') # Remove newlines + ]) + + def _export_json(self, logs: List[dict], filepath: Path): """Export logs to JSON format. Args: @@ -907,17 +853,13 @@ def _export_json(self, logs: list[dict], filepath: Path): """ import json - with open(filepath, "w", encoding="utf-8") as f: + with open(filepath, 'w', encoding='utf-8') as f: json.dump(logs, f, indent=2, default=str) - def _resolve_time_window( - self, filters: LogInsightFilter - ) -> tuple[datetime, datetime]: + def _resolve_time_window(self, filters: LogInsightFilter) -> Tuple[datetime, datetime]: """Resolve query window with fallback default window hours.""" end_time = filters.end_time or datetime.now() - start_time = filters.start_time or ( - end_time - timedelta(hours=filters.window_hours) - ) + start_time = filters.start_time or (end_time - timedelta(hours=filters.window_hours)) if start_time > end_time: start_time, end_time = end_time, start_time return start_time, end_time @@ -926,22 +868,20 @@ def _build_schedule_timeline_items( self, start_time: datetime, end_time: datetime, - ) -> list[OperationTimelineItem]: + ) -> List[OperationTimelineItem]: """Build timeline events from schedule execution history.""" try: - from stock_datasource.modules.datamanage.schedule_service import ( - schedule_service, - ) + from stock_datasource.modules.datamanage.schedule_service import schedule_service except Exception as e: logger.warning(f"Failed to import schedule_service for timeline: {e}") return [] total_days = max(1, min(30, (end_time.date() - start_time.date()).days + 1)) history = schedule_service.get_history(days=total_days, limit=200) - timeline: list[OperationTimelineItem] = [] + timeline: List[OperationTimelineItem] = [] for record in history: - started_at = record.get("started_at") + started_at = record.get('started_at') if isinstance(started_at, str): try: started_at = datetime.fromisoformat(started_at) @@ -953,13 +893,13 @@ def _build_schedule_timeline_items( if started_at < start_time or started_at > end_time: continue - status = str(record.get("status", "unknown")).lower() - trigger_type = str(record.get("trigger_type", "scheduled")) - level = "INFO" - if status in ("failed", "interrupted"): - level = "ERROR" - elif status in ("skipped", "stopping"): - level = "WARNING" + status = str(record.get('status', 'unknown')).lower() + trigger_type = str(record.get('trigger_type', 'scheduled')) + level = 'INFO' + if status in ('failed', 'interrupted'): + level = 'ERROR' + elif status in ('skipped', 'stopping'): + level = 'WARNING' summary = f"调度任务 {status}({trigger_type})" detail = ( @@ -970,9 +910,9 @@ def _build_schedule_timeline_items( timeline.append( OperationTimelineItem( timestamp=started_at, - event_type="schedule", + event_type='schedule', level=level, - module="scheduler", + module='scheduler', summary=summary, detail=detail, ) @@ -992,20 +932,20 @@ def _extract_error_type(self, error_message: str) -> str: # Simple heuristic: extract first line or common patterns message_lower = error_message.lower() - if "connection" in message_lower or "timeout" in message_lower: + if 'connection' in message_lower or 'timeout' in message_lower: return "ConnectionError" - elif "permission" in message_lower or "access" in message_lower: + elif 'permission' in message_lower or 'access' in message_lower: return "AccessError" - elif "not found" in message_lower or "missing" in message_lower: + elif 'not found' in message_lower or 'missing' in message_lower: return "NotFoundError" - elif "value" in message_lower or "type" in message_lower: + elif 'value' in message_lower or 'type' in message_lower: return "ValueError" else: return "GeneralError" # Global log service instance -_log_service: LogService | None = None +_log_service: Optional[LogService] = None def get_log_service(log_dir: str = "logs") -> LogService: diff --git a/src/stock_datasource/services/backend_container.py b/src/stock_datasource/services/backend_container.py index 093fbbdb..4600ab81 100644 --- a/src/stock_datasource/services/backend_container.py +++ b/src/stock_datasource/services/backend_container.py @@ -64,9 +64,7 @@ def _handle_signal(signum: int, _frame) -> None: signal.signal(signal.SIGINT, _handle_signal) for command in _build_commands(): - processes.append( - subprocess.Popen(command, cwd=str(Path(__file__).resolve().parents[3])) - ) + processes.append(subprocess.Popen(command, cwd=str(Path(__file__).resolve().parents[3]))) exit_code = 0 try: @@ -95,4 +93,4 @@ def _handle_signal(signum: int, _frame) -> None: if __name__ == "__main__": - raise SystemExit(main()) + raise SystemExit(main()) \ No newline at end of file diff --git a/src/stock_datasource/services/http_server.py b/src/stock_datasource/services/http_server.py index 491f2ba0..fe61ee86 100644 --- a/src/stock_datasource/services/http_server.py +++ b/src/stock_datasource/services/http_server.py @@ -1,17 +1,18 @@ """HTTP server for stock data service.""" +from contextlib import asynccontextmanager +from fastapi import FastAPI +from fastapi.middleware.cors import CORSMiddleware +import logging import importlib import inspect -import logging -import multiprocessing -import os -from contextlib import asynccontextmanager from pathlib import Path +import os +import multiprocessing +import signal # Load environment variables at module import from dotenv import load_dotenv -from fastapi import FastAPI -from fastapi.middleware.cors import CORSMiddleware # 优先从项目根目录 .env 加载,避免因启动目录不同导致读取失败 _PROJECT_ROOT = Path(__file__).resolve().parents[3] @@ -22,14 +23,13 @@ if not os.getenv("TUSHARE_TOKEN") and _DOTENV_PATH.exists(): load_dotenv(dotenv_path=_DOTENV_PATH, override=True) -from stock_datasource.core.base_service import BaseService from stock_datasource.core.service_generator import ServiceGenerator +from stock_datasource.core.base_service import BaseService logger = logging.getLogger(__name__) # Ensure unified logging is initialized on first import from stock_datasource.utils.logger import setup_logging as _setup_logging - _setup_logging() # Global cache for services @@ -50,14 +50,14 @@ def _is_local_dev() -> bool: True if not in Docker container, False otherwise """ # Check if running inside Docker - if os.path.exists("/.dockerenv"): + if os.path.exists('/.dockerenv'): return False # Check for container markers - if os.path.exists("/proc/1/cgroup"): - with open("/proc/1/cgroup") as f: + if os.path.exists('/proc/1/cgroup'): + with open('/proc/1/cgroup', 'r') as f: cgroup_content = f.read() - if "docker" in cgroup_content or "kubepods" in cgroup_content: + if 'docker' in cgroup_content or 'kubepods' in cgroup_content: return False # If not in container, assume local dev @@ -70,11 +70,11 @@ def _start_background_workers(num_workers: int = 10): Args: num_workers: Number of worker processes to start """ + from stock_datasource.services.task_worker import run_worker + import threading import time - from stock_datasource.services.task_worker import run_worker - global _worker_processes, _worker_meta, _supervisor_thread, _shutdown_event logger.info(f"Starting {num_workers} background workers (local dev mode)") @@ -91,7 +91,9 @@ def _start_background_workers(num_workers: int = 10): def _spawn_worker(worker_id: int) -> multiprocessing.Process: worker = multiprocessing.Process( - target=run_worker, args=(worker_id,), daemon=False + target=run_worker, + args=(worker_id,), + daemon=False ) worker.start() return worker @@ -121,9 +123,7 @@ def _supervise(): continue if not restart_enabled: - logger.warning( - f"Worker {worker_id} exited (code={exit_code}), restart disabled" - ) + logger.warning(f"Worker {worker_id} exited (code={exit_code}), restart disabled") continue if meta["restarts"] >= max_restarts: @@ -133,9 +133,7 @@ def _supervise(): continue meta["restarts"] += 1 - delay = min( - backoff_seconds * (2 ** (meta["restarts"] - 1)), backoff_max - ) + delay = min(backoff_seconds * (2 ** (meta["restarts"] - 1)), backoff_max) logger.warning( f"Worker {worker_id} exited (code={exit_code}), restarting in {delay:.1f}s " f"(attempt {meta['restarts']}/{max_restarts})" @@ -175,9 +173,7 @@ def _stop_background_workers(): worker.join(timeout=5) if worker.is_alive(): - logger.warning( - f"Worker (PID {worker.pid}) did not terminate, killing..." - ) + logger.warning(f"Worker (PID {worker.pid}) did not terminate, killing...") worker.kill() worker.join(timeout=5) @@ -200,7 +196,6 @@ async def lifespan(app: FastAPI): # Proxy should only be used in data extraction contexts via proxy_context() try: from stock_datasource.core.proxy import is_proxy_enabled - if is_proxy_enabled(): logger.info("Proxy configured (will be used only for data extraction)") else: @@ -217,9 +212,7 @@ async def lifespan(app: FastAPI): # Only import seed file when whitelist is enabled (avoids surprises for default open registration) if bool(getattr(settings, "AUTH_EMAIL_WHITELIST_ENABLED", False)): - email_file = Path( - getattr(settings, "AUTH_EMAIL_WHITELIST_FILE", "data/email.txt") - ) + email_file = Path(getattr(settings, "AUTH_EMAIL_WHITELIST_FILE", "data/email.txt")) if not email_file.is_absolute(): # Resolve relative path from current working directory (docker: /app) email_file = Path.cwd() / email_file @@ -232,12 +225,8 @@ async def lifespan(app: FastAPI): seed_file = next((p for p in fallback_candidates if p.exists()), None) if seed_file: - imported, skipped = auth_service.import_whitelist_from_file( - str(seed_file) - ) - logger.info( - f"Email whitelist imported: {imported} new, {skipped} existing" - ) + imported, skipped = auth_service.import_whitelist_from_file(str(seed_file)) + logger.info(f"Email whitelist imported: {imported} new, {skipped} existing") else: logger.warning( "Email whitelist enabled but no whitelist file found. " @@ -245,14 +234,10 @@ async def lifespan(app: FastAPI): ) except Exception as e: logger.warning(f"Auth initialization failed: {e}") - + # Initialize portfolio tables try: - from stock_datasource.modules.portfolio.init import ( - ensure_portfolio_tables, - ensure_profile_id_column, - ) - + from stock_datasource.modules.portfolio.init import ensure_portfolio_tables, ensure_profile_id_column ensure_portfolio_tables() ensure_profile_id_column() except Exception as e: @@ -261,35 +246,27 @@ async def lifespan(app: FastAPI): # Initialize profile tables try: from stock_datasource.modules.profile.service import get_profile_service - get_profile_service().ensure_table() except Exception as e: logger.warning(f"Profile table initialization failed: {e}") # Initialize financial analysis tables try: - from stock_datasource.modules.financial_analysis.tables import ( - ensure_financial_analysis_tables, - ) - + from stock_datasource.modules.financial_analysis.tables import ensure_financial_analysis_tables ensure_financial_analysis_tables() except Exception as e: logger.warning(f"Financial analysis table initialization failed: {e}") # Initialize Open API Gateway tables try: - from stock_datasource.modules.open_api.service import ( - _ensure_tables as _ensure_open_api_tables, - ) - + from stock_datasource.modules.open_api.service import _ensure_tables as _ensure_open_api_tables _ensure_open_api_tables() except Exception as e: logger.warning(f"Open API Gateway table initialization failed: {e}") - + # Initialize plugin manager try: from stock_datasource.core.plugin_manager import plugin_manager - plugin_manager.discover_plugins() logger.info(f"Discovered {len(plugin_manager.list_plugins())} plugins") except Exception as e: @@ -299,10 +276,10 @@ async def lifespan(app: FastAPI): # Keep it lightweight and synchronous to avoid long lock contention at runtime. try: from stock_datasource.config.settings import settings - from stock_datasource.core.plugin_manager import plugin_manager from stock_datasource.models.database import db_client from stock_datasource.models.schemas import PREDEFINED_SCHEMAS - from stock_datasource.utils.schema_manager import dict_to_schema, schema_manager + from stock_datasource.utils.schema_manager import schema_manager, dict_to_schema + from stock_datasource.core.plugin_manager import plugin_manager # Ensure database exists db_client.create_database(settings.CLICKHOUSE_DATABASE) @@ -312,22 +289,20 @@ async def lifespan(app: FastAPI): try: db_client.create_table(schema_manager._build_create_table_sql(schema)) except Exception as inner_e: - logger.warning( - f"Failed to ensure predefined table {schema.table_name}: {inner_e}" - ) + logger.warning(f"Failed to ensure predefined table {schema.table_name}: {inner_e}") # Create only configured plugin tables required by frontend pages. # You can override by env: REQUIRED_PLUGIN_TABLES=tushare_daily_basic,tushare_daily,... default_required_plugins = [ - "tushare_ths_daily", # 同花顺行情数据 - "tushare_ths_index", # 同花顺指数 - "tushare_idx_factor_pro", # 指数因子 - "tushare_index_basic", # 指数基础信息 - "tushare_etf_fund_daily", # ETF日线数据 - "tushare_etf_basic", # ETF基础信息 - "tushare_cyq_chips", # 筹码分布数据 - "tushare_stk_surv", # 机构调研数据 - "tushare_report_rc", # 研报覆盖数据 + "tushare_ths_daily", # 同花顺行情数据 + "tushare_ths_index", # 同花顺指数 + "tushare_idx_factor_pro", # 指数因子 + "tushare_index_basic", # 指数基础信息 + "tushare_etf_fund_daily", # ETF日线数据 + "tushare_etf_basic", # ETF基础信息 + "tushare_cyq_chips", # 筹码分布数据 + "tushare_stk_surv", # 机构调研数据 + "tushare_report_rc", # 研报覆盖数据 ] required_plugins_env = os.getenv("REQUIRED_PLUGIN_TABLES", "") required_plugins = [ @@ -349,9 +324,7 @@ async def lifespan(app: FastAPI): schema = dict_to_schema(schema_dict) db_client.create_table(schema_manager._build_create_table_sql(schema)) except Exception as inner_e: - logger.warning( - f"Failed to ensure table for plugin {plugin_name}: {inner_e}" - ) + logger.warning(f"Failed to ensure table for plugin {plugin_name}: {inner_e}") logger.info("ClickHouse table initialization completed") except Exception as e: @@ -359,8 +332,8 @@ async def lifespan(app: FastAPI): # Register ALL plugin schemas into db_client for auto-create on UNKNOWN_TABLE try: - from stock_datasource.core.plugin_manager import plugin_manager from stock_datasource.models.database import db_client + from stock_datasource.core.plugin_manager import plugin_manager registered_count = 0 for plugin_name in plugin_manager.list_plugins(): @@ -370,33 +343,25 @@ async def lifespan(app: FastAPI): continue schema_dict = plugin.get_schema() if schema_dict and schema_dict.get("table_name"): - db_client.register_table_schema( - schema_dict["table_name"], schema_dict - ) + db_client.register_table_schema(schema_dict["table_name"], schema_dict) registered_count += 1 except Exception: pass - logger.info( - f"Registered {registered_count} plugin schemas for auto-create on UNKNOWN_TABLE" - ) + logger.info(f"Registered {registered_count} plugin schemas for auto-create on UNKNOWN_TABLE") except Exception as e: logger.warning(f"Plugin schema registration failed: {e}") # Run ClickHouse migrations (incremental DDL tracked in _migrations table) try: from stock_datasource.utils.db_migrations import run_pending_migrations - run_pending_migrations() except Exception as e: logger.warning(f"ClickHouse migrations failed: {e}") # Start sync task manager(延迟启动,避免与初始化建表并发造成断连) try: - import threading - import time - from stock_datasource.modules.datamanage.service import sync_task_manager - + import threading, time def _delayed_start(): try: time.sleep(8) @@ -404,18 +369,14 @@ def _delayed_start(): logger.info("SyncTaskManager started (delayed)") except Exception as inner_e: logger.warning(f"SyncTaskManager delayed start failed: {inner_e}") - threading.Thread(target=_delayed_start, daemon=True).start() except Exception as e: logger.warning(f"SyncTaskManager start failed: {e}") # Start UnifiedScheduler (delayed to ensure SyncTaskManager is ready) try: - import threading - import time as _time - from stock_datasource.tasks.unified_scheduler import get_unified_scheduler - + import threading, time as _time def _delayed_scheduler_start(): try: _time.sleep(15) @@ -425,24 +386,18 @@ def _delayed_scheduler_start(): # Register realtime minute collection/sync jobs try: - from stock_datasource.modules.realtime_minute.scheduler import ( - register_realtime_jobs, - ) - + from stock_datasource.modules.realtime_minute.scheduler import register_realtime_jobs aps = scheduler.get_apscheduler() if aps is not None: register_realtime_jobs(aps) logger.info("Realtime minute jobs registered") else: - logger.warning( - "Realtime minute jobs registration skipped: scheduler unavailable" - ) + logger.warning("Realtime minute jobs registration skipped: scheduler unavailable") except Exception as rt_e: logger.warning(f"Realtime minute jobs registration failed: {rt_e}") except Exception as inner_e: logger.warning(f"UnifiedScheduler delayed start failed: {inner_e}") - threading.Thread(target=_delayed_scheduler_start, daemon=True).start() except Exception as e: logger.warning(f"UnifiedScheduler start failed: {e}") @@ -452,16 +407,13 @@ def _delayed_scheduler_start(): if _is_local_dev(): _start_background_workers(num_workers=10) else: - logger.info( - "Docker environment detected - workers run in separate container" - ) + logger.info("Docker environment detected - workers run in separate container") except Exception as e: logger.warning(f"Failed to start background workers: {e}") # Warm up Langfuse in a thread (avoid blocking event loop on first LLM call) try: import asyncio as _asyncio - from stock_datasource.llm.client import get_langfuse _asyncio.create_task(_asyncio.to_thread(get_langfuse)) @@ -469,16 +421,15 @@ def _delayed_scheduler_start(): logger.debug(f"Langfuse warmup skipped: {e}") logger.info("Application initialization completed") - + yield # Application runs here - + # Shutdown logger.info("Shutting down application...") # Stop UnifiedScheduler try: from stock_datasource.tasks.unified_scheduler import get_unified_scheduler - scheduler = get_unified_scheduler() scheduler.stop() logger.info("UnifiedScheduler stopped") @@ -495,7 +446,6 @@ def _delayed_scheduler_start(): # Stop sync task manager try: from stock_datasource.modules.datamanage.service import sync_task_manager - sync_task_manager.stop() logger.info("SyncTaskManager stopped") except Exception as e: @@ -504,7 +454,6 @@ def _delayed_scheduler_start(): # Flush Langfuse traces try: from stock_datasource.llm.client import flush_langfuse - flush_langfuse() logger.info("Langfuse traces flushed") except Exception as e: @@ -519,7 +468,7 @@ def create_app() -> FastAPI: version="2.0.0", lifespan=lifespan, ) - + # Add CORS middleware app.add_middleware( CORSMiddleware, @@ -528,10 +477,10 @@ def create_app() -> FastAPI: allow_methods=["*"], allow_headers=["*"], ) - + # Register plugin service routes _register_services(app) - + # Register module routes (8 business modules) _register_module_routes(app) @@ -540,78 +489,61 @@ def create_app() -> FastAPI: # Register strategy routes _register_strategy_routes(app) - + # Register top list routes _register_toplist_routes(app) - + # Register workflow routes _register_workflow_routes(app) - + # Register cache routes _register_cache_routes(app) - + # Register Open API Gateway routes _register_open_api_routes(app) - + # Health check endpoint with cache stats @app.get("/health") async def health_check(): """Health check endpoint with service status.""" response = {"status": "ok"} - + # Check cache service try: from stock_datasource.services.cache_service import get_cache_service - cache_service = get_cache_service() cache_stats = cache_service.get_stats() response["cache"] = cache_stats except Exception as e: response["cache"] = {"available": False, "error": str(e)} - + # Check ClickHouse - use global client to avoid reconnection overhead try: from stock_datasource.models.database import db_client - db_client.execute("SELECT 1") response["clickhouse"] = "connected" except Exception as e: - response["clickhouse"] = f"error: {e!s}" - + response["clickhouse"] = f"error: {str(e)}" + return response - + # Root endpoint @app.get("/") async def root(): return { "name": "AI Stock Platform", "version": "2.0.0", - "modules": [ - "chat", - "market", - "screener", - "report", - "memory", - "datamanage", - "portfolio", - "backtest", - "toplist", - "system_logs", - ], + "modules": ["chat", "market", "screener", "report", "memory", "datamanage", "portfolio", "backtest", "toplist", "system_logs"] } - + # Custom access log middleware with request tracing @app.middleware("http") async def log_requests(request, call_next): """Log HTTP requests with request ID tracing.""" from datetime import datetime - from stock_datasource.utils.logger import logger as loguru_logger from stock_datasource.utils.request_context import ( - generate_request_id, - request_id_var, - reset_request_context, - user_id_var, + generate_request_id, request_id_var, user_id_var, reset_request_context, ) # Generate or accept request ID @@ -626,7 +558,6 @@ async def log_requests(request, call_next): if auth_header.startswith("Bearer "): token = auth_header[7:] from stock_datasource.modules.auth.service import get_auth_service - auth_svc = get_auth_service() payload = auth_svc.decode_token(token) if payload and payload.get("sub"): @@ -653,7 +584,7 @@ async def log_requests(request, call_next): return response finally: reset_request_context() - + return app @@ -676,10 +607,7 @@ def _register_module_routes(app: FastAPI) -> None: def _register_system_logs_routes(app: FastAPI) -> None: """Register system logs routes.""" try: - from stock_datasource.modules.system_logs.router import ( - router as system_logs_router, - ) - + from stock_datasource.modules.system_logs.router import router as system_logs_router app.include_router(system_logs_router) logger.info("Registered system logs routes") except Exception as e: @@ -690,7 +618,6 @@ def _register_strategy_routes(app: FastAPI) -> None: """Register strategy management routes.""" try: from stock_datasource.api.strategy_routes import router as strategy_router - app.include_router(strategy_router) logger.info("Registered strategy routes") except Exception as e: @@ -701,7 +628,6 @@ def _register_toplist_routes(app: FastAPI) -> None: """Register top list (龙虎榜) routes.""" try: from stock_datasource.api.toplist_routes import router as toplist_router - app.include_router(toplist_router) logger.info("Registered top list routes") except Exception as e: @@ -712,7 +638,6 @@ def _register_workflow_routes(app: FastAPI) -> None: """Register workflow management routes.""" try: from stock_datasource.api.workflow_routes import router as workflow_router - app.include_router(workflow_router) logger.info("Registered workflow routes") except Exception as e: @@ -723,7 +648,6 @@ def _register_cache_routes(app: FastAPI) -> None: """Register cache management routes.""" try: from stock_datasource.api.cache_routes import router as cache_router - app.include_router(cache_router) logger.info("Registered cache routes") except Exception as e: @@ -738,7 +662,6 @@ def _register_open_api_routes(app: FastAPI) -> None: """ try: from stock_datasource.modules.open_api.router import router as open_api_router - app.include_router( open_api_router, prefix="/api/open", @@ -750,7 +673,6 @@ def _register_open_api_routes(app: FastAPI) -> None: try: from stock_datasource.modules.open_api.admin_router import admin_router - app.include_router( admin_router, prefix="/api/open-api-admin", @@ -775,58 +697,57 @@ def _get_or_create_service(service_class, service_name: str): def _discover_services() -> list[tuple[str, type]]: """Dynamically discover all service classes from plugins directory. - + Returns: List of (service_name, service_class) tuples """ services = [] plugins_dir = Path(__file__).parent.parent / "plugins" - + if not plugins_dir.exists(): logger.warning(f"Plugins directory not found: {plugins_dir}") return services - + # Iterate through each plugin directory for plugin_dir in plugins_dir.iterdir(): if not plugin_dir.is_dir() or plugin_dir.name.startswith("_"): continue - + service_module_path = plugin_dir / "service.py" if not service_module_path.exists(): continue - + try: # Dynamically import the service module module_name = f"stock_datasource.plugins.{plugin_dir.name}.service" module = importlib.import_module(module_name) - + # Find all BaseService subclasses in the module for name, obj in inspect.getmembers(module): - if ( - inspect.isclass(obj) - and issubclass(obj, BaseService) - and obj is not BaseService - and obj.__module__ == module_name - ): + if (inspect.isclass(obj) and + issubclass(obj, BaseService) and + obj is not BaseService and + obj.__module__ == module_name): + # Use plugin directory name as service prefix service_name = plugin_dir.name services.append((service_name, obj)) logger.info(f"Discovered service: {service_name} -> {obj.__name__}") - + except Exception as e: logger.warning(f"Failed to discover services in {plugin_dir.name}: {e}") - + return services def _register_services(app: FastAPI) -> None: """Register all discovered service routes dynamically.""" service_configs = _discover_services() - + if not service_configs: logger.warning("No services discovered") return - + for prefix, service_class in service_configs: try: # Create service instance @@ -834,7 +755,7 @@ def _register_services(app: FastAPI) -> None: if service is None: logger.warning(f"Skipping service registration: {prefix}") continue - + generator = ServiceGenerator(service) router = generator.generate_http_routes() app.include_router( @@ -853,7 +774,7 @@ def _register_services(app: FastAPI) -> None: if __name__ == "__main__": import uvicorn - + uvicorn.run( app, host="0.0.0.0", diff --git a/src/stock_datasource/utils/log_sink_clickhouse.py b/src/stock_datasource/utils/log_sink_clickhouse.py index da8480e2..0d3b15c5 100644 --- a/src/stock_datasource/utils/log_sink_clickhouse.py +++ b/src/stock_datasource/utils/log_sink_clickhouse.py @@ -10,6 +10,7 @@ import time from datetime import datetime from pathlib import Path +from typing import Dict, List, Optional from stock_datasource.config.settings import settings @@ -20,7 +21,6 @@ def _get_db_client(): """Lazy import to avoid circular imports at module load time.""" try: from stock_datasource.models.database import db_client - return db_client except Exception: return None @@ -39,13 +39,12 @@ def _transform_record(record: dict) -> dict: "line": int(record.get("line", 0)), "message": record.get("message", ""), "exception": record.get("exception") or None, - "middleware_trace_id": record.get("middleware_trace_id") - or extra.get("middleware_trace_id", "-"), + "middleware_trace_id": record.get("middleware_trace_id") or extra.get("middleware_trace_id", "-"), "extra": json.dumps(extra, ensure_ascii=False, default=str), } -def _flush_batch(batch: list[dict]) -> None: +def _flush_batch(batch: List[Dict]) -> None: """Insert a batch of records into ClickHouse.""" if not batch: return @@ -54,7 +53,6 @@ def _flush_batch(batch: list[dict]) -> None: return try: import pandas as pd - df = pd.DataFrame(batch) db.insert_dataframe("system_structured_logs", df) except Exception: @@ -67,12 +65,12 @@ def _import_file(filepath: Path) -> bool: Returns True if file was successfully imported (and deleted), False otherwise. """ - batch: list[dict] = [] + batch: List[Dict] = [] batch_size = settings.LOG_CH_SINK_BATCH_SIZE imported = 0 try: - with open(filepath, encoding="utf-8", errors="ignore") as f: + with open(filepath, "r", encoding="utf-8", errors="ignore") as f: for line in f: line = line.strip() if not line: @@ -104,7 +102,7 @@ def _import_file(filepath: Path) -> bool: return False -def _rotate_active_jsonl(logs_dir: Path) -> Path | None: +def _rotate_active_jsonl(logs_dir: Path) -> Optional[Path]: """Rename the active JSONL file into an importable snapshot. The logger writes one line at a time and reopens the file per write, so an @@ -154,7 +152,6 @@ def start_ch_sink_watcher(logs_dir: Path, interval: float = 30.0) -> threading.T Returns: The started daemon thread. """ - def _watcher(): while True: try: diff --git a/tests/test_kline_patterns.py b/tests/test_kline_patterns.py new file mode 100644 index 00000000..06397ef5 --- /dev/null +++ b/tests/test_kline_patterns.py @@ -0,0 +1,612 @@ +"""Tests for K-line candlestick pattern recognition. + +TDD Cycle 3.1: Single-Candle Patterns +- Hammer (锤子线) +- Hanging Man (上吊线) +- Inverted Hammer (倒锤子线) +- Shooting Star (射击之星) +- Doji (十字星) +- Long-Legged Doji (长腿十字) +- Dragonfly Doji (蜻蜓十字) +- Gravestone Doji (墓碑十字) +- Marubozu (光头光脚) + +TDD Cycle 3.2: Dual-Candle Patterns +- Bullish Engulfing (看涨吞没) +- Bearish Engulfing (看跌吞没) +- Tweezer Top/Bottom (镊子顶/底) + +TDD Cycle 3.3: Triple-Candle Patterns +- Morning Star (启明星) +- Evening Star (黄昏星) +- Three White Soldiers (红三兵) +- Three Black Crows (三只乌鸦) + +TDD Cycle 3.4: Combined Pattern Detection +- detect_patterns() returns all patterns found in given data + +TDD Cycle 3.5: API Endpoint +- GET /api/portfolio/kline-patterns/{ts_code} returns patterns +""" + +import pytest +from dataclasses import dataclass + + +# --------------------------------------------------------------------------- +# Data model for a single candle +# --------------------------------------------------------------------------- + + +@dataclass +class Candle: + """Single OHLC candle data.""" + date: str + open: float + close: float + high: float + low: float + volume: float = 0 + + @property + def body(self) -> float: + """Absolute body size.""" + return abs(self.close - self.open) + + @property + def upper_shadow(self) -> float: + """Upper shadow length.""" + return self.high - max(self.open, self.close) + + @property + def lower_shadow(self) -> float: + """Lower shadow length.""" + return min(self.open, self.close) - self.low + + @property + def range(self) -> float: + """Full range (high - low).""" + return self.high - self.low + + @property + def is_bullish(self) -> bool: + return self.close > self.open + + @property + def is_bearish(self) -> bool: + return self.close < self.open + + +# --------------------------------------------------------------------------- +# TDD Cycle 3.1: Single-Candle Patterns +# --------------------------------------------------------------------------- + + +class TestCandleDataModel: + """Test the Candle dataclass and its computed properties.""" + + def test_candle_body_calculation(self): + c = Candle(date="2026-04-20", open=10.0, close=12.0, high=13.0, low=9.0) + assert c.body == 2.0 + assert c.is_bullish + + def test_candle_bearish_body(self): + c = Candle(date="2026-04-20", open=12.0, close=10.0, high=13.0, low=9.0) + assert c.body == 2.0 + assert c.is_bearish + + def test_candle_shadows(self): + c = Candle(date="2026-04-20", open=10.0, close=12.0, high=14.0, low=8.0) + assert c.upper_shadow == 2.0 # 14 - 12 + assert c.lower_shadow == 2.0 # 10 - 8 + assert c.range == 6.0 # 14 - 8 + + +class TestHammerPattern: + """Hammer: small body at top, long lower shadow (>= 2x body), little/no upper shadow. + + Bullish reversal pattern appearing at bottom of downtrend. + """ + + def test_hammer_detected(self): + """A candle with small body at top, long lower shadow, no upper shadow is a hammer.""" + from stock_datasource.modules.portfolio.kline_patterns import is_hammer + + # Hammer: body at top, long lower shadow, tiny upper shadow + c = Candle(date="2026-04-20", open=10.0, close=10.5, high=10.6, low=8.0) + assert is_hammer(c) is True + + def test_hammer_with_zero_upper_shadow(self): + """Hammer with no upper shadow should be detected.""" + from stock_datasource.modules.portfolio.kline_patterns import is_hammer + + c = Candle(date="2026-04-20", open=10.0, close=10.5, high=10.5, low=8.0) + assert is_hammer(c) is True + + def test_not_hammer_too_much_upper_shadow(self): + """A candle with large upper shadow is not a hammer.""" + from stock_datasource.modules.portfolio.kline_patterns import is_hammer + + c = Candle(date="2026-04-20", open=10.0, close=10.5, high=13.0, low=8.0) + assert is_hammer(c) is False + + def test_not_hammer_short_lower_shadow(self): + """A candle with short lower shadow is not a hammer.""" + from stock_datasource.modules.portfolio.kline_patterns import is_hammer + + c = Candle(date="2026-04-20", open=10.0, close=10.5, high=11.0, low=9.8) + assert is_hammer(c) is False + + +class TestShootingStarPattern: + """Shooting Star: small body at bottom, long upper shadow (>= 2x body), little/no lower shadow. + + Bearish reversal pattern appearing at top of uptrend. + """ + + def test_shooting_star_detected(self): + from stock_datasource.modules.portfolio.kline_patterns import is_shooting_star + + c = Candle(date="2026-04-20", open=10.0, close=10.5, high=13.0, low=10.0) + assert is_shooting_star(c) is True + + def test_not_shooting_star_with_lower_shadow(self): + from stock_datasource.modules.portfolio.kline_patterns import is_shooting_star + + c = Candle(date="2026-04-20", open=10.0, close=10.5, high=13.0, low=8.0) + assert is_shooting_star(c) is False + + +class TestDojiPattern: + """Doji: open and close are nearly equal (body is very small relative to range). + + Signals indecision in the market. + """ + + def test_doji_detected(self): + from stock_datasource.modules.portfolio.kline_patterns import is_doji + + # Open and close very close relative to total range + c = Candle(date="2026-04-20", open=10.0, close=10.05, high=12.0, low=8.0) + assert is_doji(c) is True + + def test_not_doji_large_body(self): + from stock_datasource.modules.portfolio.kline_patterns import is_doji + + c = Candle(date="2026-04-20", open=10.0, close=12.0, high=13.0, low=9.0) + assert is_doji(c) is False + + +class TestMarubozuPattern: + """Marubozu: no or very small shadows, large body. + + Strong directional conviction. + """ + + def test_bullish_marubozu(self): + from stock_datasource.modules.portfolio.kline_patterns import is_marubozu + + # Open = low, Close = high, large body, no shadows + c = Candle(date="2026-04-20", open=10.0, close=15.0, high=15.0, low=10.0) + result = is_marubozu(c) + assert result is not None + assert result == "bullish" + + def test_bearish_marubozu(self): + from stock_datasource.modules.portfolio.kline_patterns import is_marubozu + + # Open = high, Close = low, large body, no shadows + c = Candle(date="2026-04-20", open=15.0, close=10.0, high=15.0, low=10.0) + result = is_marubozu(c) + assert result is not None + assert result == "bearish" + + def test_not_marubozu_with_shadows(self): + from stock_datasource.modules.portfolio.kline_patterns import is_marubozu + + c = Candle(date="2026-04-20", open=10.0, close=15.0, high=16.0, low=9.0) + assert is_marubozu(c) is None + + +# --------------------------------------------------------------------------- +# TDD Cycle 3.2: Dual-Candle Patterns +# --------------------------------------------------------------------------- + + +class TestBullishEngulfingPattern: + """Bullish Engulfing: previous bearish candle, current bullish candle + whose body completely engulfs the previous body. + + Bullish reversal pattern at bottom of downtrend. + """ + + def test_bullish_engulfing_detected(self): + from stock_datasource.modules.portfolio.kline_patterns import is_bullish_engulfing + + prev = Candle(date="2026-04-19", open=12.0, close=10.0, high=12.5, low=9.5) + curr = Candle(date="2026-04-20", open=9.5, close=13.0, high=13.5, low=9.0) + assert is_bullish_engulfing(prev, curr) is True + + def test_not_engulfing_same_direction(self): + from stock_datasource.modules.portfolio.kline_patterns import is_bullish_engulfing + + # Both bullish - not a valid pattern + prev = Candle(date="2026-04-19", open=10.0, close=12.0, high=12.5, low=9.5) + curr = Candle(date="2026-04-20", open=9.5, close=13.0, high=13.5, low=9.0) + assert is_bullish_engulfing(prev, curr) is False + + def test_not_engulfing_does_not_engulf(self): + from stock_datasource.modules.portfolio.kline_patterns import is_bullish_engulfing + + # Current body does not completely engulf previous + prev = Candle(date="2026-04-19", open=12.0, close=10.0, high=12.5, low=9.5) + curr = Candle(date="2026-04-20", open=10.5, close=11.0, high=11.5, low=10.0) + assert is_bullish_engulfing(prev, curr) is False + + +class TestBearishEngulfingPattern: + """Bearish Engulfing: previous bullish candle, current bearish candle + whose body completely engulfs the previous body. + """ + + def test_bearish_engulfing_detected(self): + from stock_datasource.modules.portfolio.kline_patterns import is_bearish_engulfing + + prev = Candle(date="2026-04-19", open=10.0, close=12.0, high=12.5, low=9.5) + curr = Candle(date="2026-04-20", open=13.0, close=9.5, high=13.5, low=9.0) + assert is_bearish_engulfing(prev, curr) is True + + def test_not_bearish_engulfing_wrong_direction(self): + from stock_datasource.modules.portfolio.kline_patterns import is_bearish_engulfing + + prev = Candle(date="2026-04-19", open=12.0, close=10.0, high=12.5, low=9.5) + curr = Candle(date="2026-04-20", open=13.0, close=9.5, high=13.5, low=9.0) + assert is_bearish_engulfing(prev, curr) is False + + +# --------------------------------------------------------------------------- +# TDD Cycle 3.3: Triple-Candle Patterns +# --------------------------------------------------------------------------- + + +class TestMorningStarPattern: + """Morning Star: + 1st candle: large bearish + 2nd candle: small body (star), gaps down + 3rd candle: large bullish, closes above midpoint of 1st candle + + Bullish reversal pattern at bottom of downtrend. + """ + + def test_morning_star_detected(self): + from stock_datasource.modules.portfolio.kline_patterns import is_morning_star + + c1 = Candle(date="2026-04-18", open=15.0, close=12.0, high=15.2, low=11.8) + c2 = Candle(date="2026-04-19", open=11.5, close=11.8, high=12.0, low=11.0) + c3 = Candle(date="2026-04-20", open=12.0, close=14.5, high=14.8, low=11.8) + assert is_morning_star(c1, c2, c3) is True + + def test_not_morning_star_third_candle_weak(self): + from stock_datasource.modules.portfolio.kline_patterns import is_morning_star + + c1 = Candle(date="2026-04-18", open=15.0, close=12.0, high=15.2, low=11.8) + c2 = Candle(date="2026-04-19", open=11.5, close=11.8, high=12.0, low=11.0) + # Third candle does not close above midpoint of first + c3 = Candle(date="2026-04-20", open=12.0, close=12.5, high=12.8, low=11.8) + assert is_morning_star(c1, c2, c3) is False + + +class TestEveningStarPattern: + """Evening Star: + 1st candle: large bullish + 2nd candle: small body (star), gaps up + 3rd candle: large bearish, closes below midpoint of 1st candle + + Bearish reversal pattern at top of uptrend. + """ + + def test_evening_star_detected(self): + from stock_datasource.modules.portfolio.kline_patterns import is_evening_star + + c1 = Candle(date="2026-04-18", open=12.0, close=15.0, high=15.2, low=11.8) + c2 = Candle(date="2026-04-19", open=15.5, close=15.2, high=16.0, low=15.0) + c3 = Candle(date="2026-04-20", open=15.0, close=12.5, high=15.2, low=12.0) + assert is_evening_star(c1, c2, c3) is True + + +class TestThreeWhiteSoldiers: + """Three White Soldiers: three consecutive bullish candles, each opening + within previous body and closing progressively higher. + + Strong bullish reversal. + """ + + def test_three_white_soldiers_detected(self): + from stock_datasource.modules.portfolio.kline_patterns import is_three_white_soldiers + + c1 = Candle(date="2026-04-18", open=10.0, close=12.0, high=12.2, low=9.8) + c2 = Candle(date="2026-04-19", open=11.5, close=13.5, high=13.7, low=11.3) + c3 = Candle(date="2026-04-20", open=13.0, close=15.0, high=15.2, low=12.8) + assert is_three_white_soldiers(c1, c2, c3) is True + + def test_not_soldiers_one_bearish(self): + from stock_datasource.modules.portfolio.kline_patterns import is_three_white_soldiers + + c1 = Candle(date="2026-04-18", open=10.0, close=12.0, high=12.2, low=9.8) + c2 = Candle(date="2026-04-19", open=13.5, close=11.5, high=13.7, low=11.3) # bearish + c3 = Candle(date="2026-04-20", open=13.0, close=15.0, high=15.2, low=12.8) + assert is_three_white_soldiers(c1, c2, c3) is False + + +class TestThreeBlackCrows: + """Three Black Crows: three consecutive bearish candles, each opening + within previous body and closing progressively lower. + + Strong bearish reversal. + """ + + def test_three_black_crows_detected(self): + from stock_datasource.modules.portfolio.kline_patterns import is_three_black_crows + + c1 = Candle(date="2026-04-18", open=15.0, close=13.0, high=15.2, low=12.8) + c2 = Candle(date="2026-04-19", open=13.5, close=11.5, high=13.7, low=11.3) + c3 = Candle(date="2026-04-20", open=12.0, close=10.0, high=12.2, low=9.8) + assert is_three_black_crows(c1, c2, c3) is True + + +# --------------------------------------------------------------------------- +# TDD Cycle 3.4: Combined Pattern Detection +# --------------------------------------------------------------------------- + + +class TestDetectPatterns: + """Test the combined detect_patterns function that scans candle data + and returns all recognized patterns with their dates and types. + """ + + def test_detect_patterns_returns_list(self): + from stock_datasource.modules.portfolio.kline_patterns import detect_patterns + + candles = [ + Candle(date="2026-04-18", open=15.0, close=13.0, high=15.2, low=12.8), + Candle(date="2026-04-19", open=13.5, close=11.5, high=13.7, low=11.3), + Candle(date="2026-04-20", open=12.0, close=10.0, high=12.2, low=9.8), + ] + result = detect_patterns(candles) + assert isinstance(result, list) + + def test_detect_patterns_item_structure(self): + from stock_datasource.modules.portfolio.kline_patterns import detect_patterns + + candles = [ + Candle(date="2026-04-18", open=15.0, close=13.0, high=15.2, low=12.8), + Candle(date="2026-04-19", open=13.5, close=11.5, high=13.7, low=11.3), + Candle(date="2026-04-20", open=12.0, close=10.0, high=12.2, low=9.8), + ] + result = detect_patterns(candles) + if len(result) > 0: + pattern = result[0] + assert "name" in pattern + assert "date" in pattern + assert "type" in pattern # 'bullish' or 'bearish' + assert "category" in pattern # 'single', 'dual', or 'triple' + + def test_detect_patterns_finds_three_black_crows(self): + from stock_datasource.modules.portfolio.kline_patterns import detect_patterns + + candles = [ + Candle(date="2026-04-18", open=15.0, close=13.0, high=15.2, low=12.8), + Candle(date="2026-04-19", open=13.5, close=11.5, high=13.7, low=11.3), + Candle(date="2026-04-20", open=12.0, close=10.0, high=12.2, low=9.8), + ] + result = detect_patterns(candles) + names = [p["name"] for p in result] + assert "三只乌鸦" in names or "Three Black Crows" in names + + def test_detect_patterns_empty_input(self): + from stock_datasource.modules.portfolio.kline_patterns import detect_patterns + + result = detect_patterns([]) + assert result == [] + + def test_detect_patterns_hammer(self): + from stock_datasource.modules.portfolio.kline_patterns import detect_patterns + + candles = [ + Candle(date="2026-04-20", open=10.0, close=10.5, high=10.6, low=8.0), + ] + result = detect_patterns(candles) + names = [p["name"] for p in result] + assert "锤子线" in names or "Hammer" in names + + +# --------------------------------------------------------------------------- +# TDD Cycle 3.5: API Endpoint +# --------------------------------------------------------------------------- + + +class TestKlinePatternsAPI: + """Test GET /api/portfolio/kline-patterns/{ts_code} endpoint. + + The endpoint should: + 1. Fetch K-line OHLC data for the given ts_code + 2. Convert to Candle objects and run detect_patterns() + 3. Return the list of detected patterns + + Since this is a portfolio module endpoint, it requires authentication. + """ + + def test_kline_patterns_endpoint_returns_list(self): + """GET /api/portfolio/kline-patterns/{ts_code} returns a list.""" + from fastapi import FastAPI + from fastapi.testclient import TestClient + from unittest.mock import AsyncMock, patch + + from stock_datasource.modules.auth.dependencies import get_current_user + from stock_datasource.modules.portfolio.router import router + + app = FastAPI() + app.include_router(router, prefix="/api/portfolio") + + async def override_auth(): + return {"id": "test_user", "username": "tester"} + + app.dependency_overrides[get_current_user] = override_auth + + with patch( + "stock_datasource.modules.portfolio.router.get_enhanced_portfolio_service" + ) as mock_svc_fn: + mock_svc = AsyncMock() + # get_kline_patterns returns a list of pattern dicts + mock_svc.get_kline_patterns = AsyncMock(return_value=[]) + mock_svc_fn.return_value = mock_svc + + client = TestClient(app, raise_server_exceptions=False) + response = client.get("/api/portfolio/kline-patterns/600519.SH") + assert response.status_code == 200 + data = response.json() + assert isinstance(data, list) + + def test_kline_patterns_item_has_required_fields(self): + """Each pattern item should have: name, name_en, date, type, category.""" + from fastapi import FastAPI + from fastapi.testclient import TestClient + from unittest.mock import AsyncMock, patch + + from stock_datasource.modules.auth.dependencies import get_current_user + from stock_datasource.modules.portfolio.router import router + + app = FastAPI() + app.include_router(router, prefix="/api/portfolio") + + async def override_auth(): + return {"id": "test_user", "username": "tester"} + + app.dependency_overrides[get_current_user] = override_auth + + mock_patterns = [ + { + "name": "锤子线", + "name_en": "Hammer", + "date": "2026-04-20", + "type": "bullish", + "category": "single", + } + ] + + with patch( + "stock_datasource.modules.portfolio.router.get_enhanced_portfolio_service" + ) as mock_svc_fn: + mock_svc = AsyncMock() + mock_svc.get_kline_patterns = AsyncMock(return_value=mock_patterns) + mock_svc_fn.return_value = mock_svc + + client = TestClient(app, raise_server_exceptions=False) + response = client.get("/api/portfolio/kline-patterns/600519.SH?days=60") + assert response.status_code == 200 + data = response.json() + for item in data: + assert "name" in item + assert "name_en" in item + assert "date" in item + assert "type" in item + assert "category" in item + + def test_kline_patterns_endpoint_requires_auth(self): + """Endpoint should return 401 or 403 without authentication.""" + from fastapi import FastAPI + from fastapi.testclient import TestClient + + from stock_datasource.modules.portfolio.router import router + + app = FastAPI() + app.include_router(router, prefix="/api/portfolio") + + client = TestClient(app, raise_server_exceptions=False) + response = client.get("/api/portfolio/kline-patterns/600519.SH") + # Should be 401, 403, or 422 (unauthenticated) + assert response.status_code in (401, 403, 422, 500) + + def test_kline_patterns_supports_days_param(self): + """Endpoint should accept a 'days' query parameter.""" + from fastapi import FastAPI + from fastapi.testclient import TestClient + from unittest.mock import AsyncMock, patch + + from stock_datasource.modules.auth.dependencies import get_current_user + from stock_datasource.modules.portfolio.router import router + + app = FastAPI() + app.include_router(router, prefix="/api/portfolio") + + async def override_auth(): + return {"id": "test_user", "username": "tester"} + + app.dependency_overrides[get_current_user] = override_auth + + with patch( + "stock_datasource.modules.portfolio.router.get_enhanced_portfolio_service" + ) as mock_svc_fn: + mock_svc = AsyncMock() + mock_svc.get_kline_patterns = AsyncMock(return_value=[]) + mock_svc_fn.return_value = mock_svc + + client = TestClient(app, raise_server_exceptions=False) + response = client.get("/api/portfolio/kline-patterns/600519.SH?days=7") + assert response.status_code == 200 + + def test_kline_patterns_type_values(self): + """Pattern type should be one of: bullish, bearish, neutral.""" + from fastapi import FastAPI + from fastapi.testclient import TestClient + from unittest.mock import AsyncMock, patch + + from stock_datasource.modules.auth.dependencies import get_current_user + from stock_datasource.modules.portfolio.router import router + + app = FastAPI() + app.include_router(router, prefix="/api/portfolio") + + async def override_auth(): + return {"id": "test_user", "username": "tester"} + + app.dependency_overrides[get_current_user] = override_auth + + mock_patterns = [ + { + "name": "锤子线", + "name_en": "Hammer", + "date": "2026-04-20", + "type": "bullish", + "category": "single", + }, + { + "name": "射击之星", + "name_en": "Shooting Star", + "date": "2026-04-18", + "type": "bearish", + "category": "single", + }, + { + "name": "十字星", + "name_en": "Doji", + "date": "2026-04-15", + "type": "neutral", + "category": "single", + }, + ] + + with patch( + "stock_datasource.modules.portfolio.router.get_enhanced_portfolio_service" + ) as mock_svc_fn: + mock_svc = AsyncMock() + mock_svc.get_kline_patterns = AsyncMock(return_value=mock_patterns) + mock_svc_fn.return_value = mock_svc + + client = TestClient(app, raise_server_exceptions=False) + response = client.get("/api/portfolio/kline-patterns/600519.SH?days=60") + assert response.status_code == 200 + data = response.json() + valid_types = {"bullish", "bearish", "neutral"} + for item in data: + assert item["type"] in valid_types diff --git a/tests/test_log_sink_clickhouse.py b/tests/test_log_sink_clickhouse.py index 7940016a..61e1fd77 100644 --- a/tests/test_log_sink_clickhouse.py +++ b/tests/test_log_sink_clickhouse.py @@ -1,13 +1,16 @@ """Tests for log_sink_clickhouse module.""" import json +import tempfile from pathlib import Path -from unittest.mock import MagicMock, patch +from unittest.mock import patch, MagicMock, call + +import pytest from stock_datasource.utils.log_sink_clickhouse import ( + _transform_record, _flush_batch, _import_file, - _transform_record, import_pending_files, ) @@ -80,20 +83,9 @@ def test_inserts_dataframe(self, mock_get_client): mock_db = MagicMock() mock_get_client.return_value = mock_db - batch = [ - { - "timestamp": "2026-01-01", - "level": "INFO", - "request_id": "-", - "user_id": "-", - "module": "test", - "function": "fn", - "line": 1, - "message": "msg", - "exception": None, - "extra": "{}", - } - ] + batch = [{"timestamp": "2026-01-01", "level": "INFO", "request_id": "-", "user_id": "-", + "module": "test", "function": "fn", "line": 1, "message": "msg", + "exception": None, "extra": "{}"}] _flush_batch(batch) mock_db.insert_dataframe.assert_called_once() args = mock_db.insert_dataframe.call_args @@ -121,30 +113,19 @@ class TestImportFile: def test_imports_and_deletes_file(self, tmp_path): jsonl_file = tmp_path / "test.jsonl.2026-04-09_15-30-00_123.jsonl" records = [ - { - "timestamp": "2026-04-09 15:30:00", - "level": "INFO", - "request_id": "-", - "user_id": "-", - "module": "mod", - "function": "fn", - "line": 1, - "message": "test msg", - "exception": None, - }, + {"timestamp": "2026-04-09 15:30:00", "level": "INFO", "request_id": "-", + "user_id": "-", "module": "mod", "function": "fn", "line": 1, + "message": "test msg", "exception": None}, ] with open(jsonl_file, "w") as f: for rec in records: f.write(json.dumps(rec) + "\n") from stock_datasource.config.settings import settings as _settings - original = getattr(_settings, "LOG_CH_SINK_BATCH_SIZE", 5000) try: _settings.LOG_CH_SINK_BATCH_SIZE = 5000 - with patch( - "stock_datasource.utils.log_sink_clickhouse._flush_batch" - ) as mock_flush: + with patch("stock_datasource.utils.log_sink_clickhouse._flush_batch") as mock_flush: result = _import_file(jsonl_file) assert result is True assert not jsonl_file.exists() @@ -162,7 +143,6 @@ def test_malformed_json_lines_skipped(self, tmp_path): f.write('{"level": "INFO", "line": 1}\n') from stock_datasource.config.settings import settings as _settings - original = getattr(_settings, "LOG_CH_SINK_BATCH_SIZE", 5000) try: _settings.LOG_CH_SINK_BATCH_SIZE = 5000 @@ -178,7 +158,6 @@ def test_empty_file_deleted(self, tmp_path): jsonl_file.write_text("") from stock_datasource.config.settings import settings as _settings - original = getattr(_settings, "LOG_CH_SINK_BATCH_SIZE", 5000) try: _settings.LOG_CH_SINK_BATCH_SIZE = 5000 @@ -198,9 +177,7 @@ def test_scans_jsonl_rotated_files(self, tmp_path): with open(rotated, "w") as f: f.write('{"level": "INFO", "line": 1}\n') - with patch( - "stock_datasource.utils.log_sink_clickhouse._import_file", return_value=True - ) as mock_import: + with patch("stock_datasource.utils.log_sink_clickhouse._import_file", return_value=True) as mock_import: count = import_pending_files(tmp_path) assert count == 1 mock_import.assert_called_once() @@ -209,9 +186,7 @@ def test_rotates_and_imports_active_jsonl(self, tmp_path): active = tmp_path / "stock_datasource.jsonl" active.write_text('{"level": "INFO", "line": 1}\n') - with patch( - "stock_datasource.utils.log_sink_clickhouse._import_file", return_value=True - ) as mock_import: + with patch("stock_datasource.utils.log_sink_clickhouse._import_file", return_value=True) as mock_import: count = import_pending_files(tmp_path) assert count == 1 assert not active.exists() @@ -221,9 +196,7 @@ def test_skips_empty_active_jsonl(self, tmp_path): active = tmp_path / "stock_datasource.jsonl" active.write_text("") - with patch( - "stock_datasource.utils.log_sink_clickhouse._import_file" - ) as mock_import: + with patch("stock_datasource.utils.log_sink_clickhouse._import_file") as mock_import: count = import_pending_files(tmp_path) assert count == 0 mock_import.assert_not_called() diff --git a/tests/test_portfolio_transactions.py b/tests/test_portfolio_transactions.py new file mode 100644 index 00000000..8d51e8db --- /dev/null +++ b/tests/test_portfolio_transactions.py @@ -0,0 +1,1221 @@ +"""Tests for portfolio transaction feature (buy/sell transaction history). + +TDD Cycle 1.1: Transaction Data Model & Schema +- TransactionType enum (BUY, SELL) +- Transaction dataclass +- Weighted average cost calculation +- Realized P/L calculation +- Sell validation (cannot sell more than held) + +TDD Cycle 1.2: Buy Transaction Operations +- record_buy_transaction creates Transaction with type=buy +- First buy creates new Position +- Second buy updates Position with weighted average cost +- Buy transaction persisted to user_transactions table +- Position auto-updated after buy transaction +- Multiple buys accumulate correctly + +TDD Cycle 1.3: Sell Transaction Operations +- record_sell_transaction creates Transaction with type=sell +- Partial sell reduces position quantity +- Full sell sets position is_active=False +- Sell more than held raises ValueError +- Realized P/L calculated correctly +- Position cost_price unchanged after sell +- Sell transaction persisted +- Sell on non-existent position raises ValueError + +TDD Cycle 1.4: Transaction History Query +- get_transactions returns list of Transaction for user +- get_transactions filtered by ts_code +- get_transactions ordered by transaction_date DESC +- get_transactions respects date range filter + +TDD Cycle 1.5: API Endpoints +- POST /api/portfolio/transactions/buy returns 200 +- POST /api/portfolio/transactions/sell returns 200 +- GET /api/portfolio/transactions returns list +- GET /api/portfolio/transactions?ts_code=600519.SH filters by stock +- POST sell with quantity > held returns 400 +- All endpoints require auth +""" + +from datetime import datetime +from unittest.mock import MagicMock, Mock, patch + +import pytest + + +# --------------------------------------------------------------------------- +# 1. TransactionType Enum Tests +# --------------------------------------------------------------------------- + + +class TestTransactionType: + """Test TransactionType enum values.""" + + def test_has_buy_and_sell(self): + """TransactionType should have BUY and SELL members.""" + from stock_datasource.modules.portfolio.enhanced_service import TransactionType + + assert hasattr(TransactionType, "BUY") + assert hasattr(TransactionType, "SELL") + + def test_buy_value(self): + """TransactionType.BUY should have value 'buy'.""" + from stock_datasource.modules.portfolio.enhanced_service import TransactionType + + assert TransactionType.BUY.value == "buy" + + def test_sell_value(self): + """TransactionType.SELL should have value 'sell'.""" + from stock_datasource.modules.portfolio.enhanced_service import TransactionType + + assert TransactionType.SELL.value == "sell" + + +# --------------------------------------------------------------------------- +# 2. Transaction Dataclass Tests +# --------------------------------------------------------------------------- + + +class TestTransaction: + """Test Transaction dataclass fields and creation.""" + + def test_transaction_fields_exist(self): + """Transaction should have all required fields.""" + from stock_datasource.modules.portfolio.enhanced_service import Transaction + + txn = Transaction( + id="txn_001", + user_id="user_001", + ts_code="600519.SH", + stock_name="贵州茅台", + transaction_type="buy", + quantity=100, + price=1700.0, + transaction_date="2026-01-15", + position_id="pos_001", + realized_pl=None, + notes="首次建仓", + profile_id="default", + created_at=datetime.now(), + ) + + assert txn.id == "txn_001" + assert txn.user_id == "user_001" + assert txn.ts_code == "600519.SH" + assert txn.stock_name == "贵州茅台" + assert txn.transaction_type == "buy" + assert txn.quantity == 100 + assert txn.price == 1700.0 + assert txn.transaction_date == "2026-01-15" + assert txn.position_id == "pos_001" + assert txn.realized_pl is None + assert txn.notes == "首次建仓" + assert txn.profile_id == "default" + assert txn.created_at is not None + + def test_transaction_optional_fields_default(self): + """Transaction optional fields should have sensible defaults.""" + from stock_datasource.modules.portfolio.enhanced_service import Transaction + + txn = Transaction( + id="txn_002", + user_id="user_001", + ts_code="600519.SH", + stock_name="贵州茅台", + transaction_type="sell", + quantity=50, + price=1800.0, + transaction_date="2026-02-01", + position_id="pos_001", + ) + + assert txn.realized_pl is None + assert txn.notes == "" + assert txn.profile_id == "default" + + +# --------------------------------------------------------------------------- +# 3. Weighted Average Cost Calculation Tests +# --------------------------------------------------------------------------- + + +class TestWeightedAverageCost: + """Test weighted average cost calculation for buy transactions.""" + + def test_two_buys_weighted_average(self): + """Two buys at different prices should compute weighted average cost. + + buy 100 @ 10.00, buy 50 @ 12.00 -> avg = (1000+600)/150 = 10.667 + """ + from stock_datasource.modules.portfolio.enhanced_service import ( + EnhancedPortfolioService, + ) + + result = EnhancedPortfolioService._calc_weighted_average_cost( + old_quantity=100, old_cost=10.0, new_quantity=50, new_price=12.0 + ) + assert round(result, 2) == 10.67 + + def test_single_buy_no_average(self): + """First buy should return the buy price directly. + + old_quantity=0 means no previous position, cost = new_price. + """ + from stock_datasource.modules.portfolio.enhanced_service import ( + EnhancedPortfolioService, + ) + + result = EnhancedPortfolioService._calc_weighted_average_cost( + old_quantity=0, old_cost=0.0, new_quantity=100, new_price=15.0 + ) + assert result == 15.0 + + def test_three_buys_weighted_average(self): + """Three buys: 100@10, 50@12, 200@11. + + avg = (1000+600+2200) / 350 = 10.857 + """ + from stock_datasource.modules.portfolio.enhanced_service import ( + EnhancedPortfolioService, + ) + + # First two buys + avg1 = EnhancedPortfolioService._calc_weighted_average_cost( + old_quantity=100, old_cost=10.0, new_quantity=50, new_price=12.0 + ) + # Third buy uses the previously computed average + avg2 = EnhancedPortfolioService._calc_weighted_average_cost( + old_quantity=150, old_cost=avg1, new_quantity=200, new_price=11.0 + ) + assert round(avg2, 2) == 10.86 + + +# --------------------------------------------------------------------------- +# 4. Realized P/L Calculation Tests +# --------------------------------------------------------------------------- + + +class TestRealizedPL: + """Test realized profit/loss calculation for sell transactions.""" + + def test_sell_profit(self): + """Selling at higher price should give positive realized P/L. + + buy 100 @ 10, sell 50 @ 15 -> realized = 50 * (15-10) = 250 + """ + from stock_datasource.modules.portfolio.enhanced_service import ( + EnhancedPortfolioService, + ) + + result = EnhancedPortfolioService._calc_realized_pl( + quantity=50, sell_price=15.0, cost_price=10.0 + ) + assert result == 250.0 + + def test_sell_loss(self): + """Selling at lower price should give negative realized P/L. + + buy 100 @ 20, sell 30 @ 15 -> realized = 30 * (15-20) = -150 + """ + from stock_datasource.modules.portfolio.enhanced_service import ( + EnhancedPortfolioService, + ) + + result = EnhancedPortfolioService._calc_realized_pl( + quantity=30, sell_price=15.0, cost_price=20.0 + ) + assert result == -150.0 + + def test_sell_at_cost(self): + """Selling at cost price should give zero realized P/L.""" + from stock_datasource.modules.portfolio.enhanced_service import ( + EnhancedPortfolioService, + ) + + result = EnhancedPortfolioService._calc_realized_pl( + quantity=100, sell_price=10.0, cost_price=10.0 + ) + assert result == 0.0 + + +# --------------------------------------------------------------------------- +# 5. Sell Validation Tests +# --------------------------------------------------------------------------- + + +class TestSellValidation: + """Test sell quantity validation.""" + + def test_sell_more_than_held_raises_error(self): + """Selling more shares than held should raise ValueError.""" + from stock_datasource.modules.portfolio.enhanced_service import ( + EnhancedPortfolioService, + ) + + with pytest.raises(ValueError, match="Cannot sell"): + EnhancedPortfolioService._validate_sell_quantity( + held_quantity=50, sell_quantity=100 + ) + + def test_sell_exact_quantity_ok(self): + """Selling exactly the held quantity should not raise.""" + from stock_datasource.modules.portfolio.enhanced_service import ( + EnhancedPortfolioService, + ) + + # Should not raise + EnhancedPortfolioService._validate_sell_quantity( + held_quantity=100, sell_quantity=100 + ) + + def test_sell_partial_quantity_ok(self): + """Selling less than held should not raise.""" + from stock_datasource.modules.portfolio.enhanced_service import ( + EnhancedPortfolioService, + ) + + EnhancedPortfolioService._validate_sell_quantity( + held_quantity=100, sell_quantity=30 + ) + + def test_sell_zero_quantity_raises_error(self): + """Selling zero shares should raise ValueError.""" + from stock_datasource.modules.portfolio.enhanced_service import ( + EnhancedPortfolioService, + ) + + with pytest.raises(ValueError, match="Cannot sell"): + EnhancedPortfolioService._validate_sell_quantity( + held_quantity=100, sell_quantity=0 + ) + + +# --------------------------------------------------------------------------- +# 6. Buy Transaction Service Tests (Cycle 1.2) +# --------------------------------------------------------------------------- + + +class TestBuyTransaction: + """Test record_buy_transaction service method.""" + + @pytest.mark.asyncio + async def test_record_buy_creates_transaction_with_buy_type(self): + """record_buy_transaction should create a Transaction with type=buy.""" + from stock_datasource.modules.portfolio.enhanced_service import ( + EnhancedPortfolioService, + Transaction, + ) + + svc = EnhancedPortfolioService() + svc._db = None # Use in-memory only + + txn = await svc.record_buy_transaction( + user_id="user_001", + ts_code="600519.SH", + quantity=100, + price=1700.0, + transaction_date="2026-01-15", + ) + + assert isinstance(txn, Transaction) + assert txn.transaction_type == "buy" + assert txn.ts_code == "600519.SH" + assert txn.quantity == 100 + assert txn.price == 1700.0 + + @pytest.mark.asyncio + async def test_first_buy_creates_new_position(self): + """First buy for a stock should create a new Position.""" + from stock_datasource.modules.portfolio.enhanced_service import ( + EnhancedPortfolioService, + ) + + svc = EnhancedPortfolioService() + svc._db = None + + txn = await svc.record_buy_transaction( + user_id="user_001", + ts_code="600519.SH", + quantity=100, + price=1700.0, + transaction_date="2026-01-15", + ) + + # Check position was created + assert txn.position_id != "" + position = svc._positions.get(txn.position_id) + assert position is not None + assert position.quantity == 100 + assert position.cost_price == 1700.0 + assert position.ts_code == "600519.SH" + + @pytest.mark.asyncio + async def test_second_buy_updates_weighted_average_cost(self): + """Second buy for same stock should update Position with weighted avg cost. + + buy 100 @ 10.00, buy 50 @ 12.00 -> position.quantity=150, position.cost_price≈10.67 + """ + from stock_datasource.modules.portfolio.enhanced_service import ( + EnhancedPortfolioService, + ) + + svc = EnhancedPortfolioService() + svc._db = None + + # First buy + txn1 = await svc.record_buy_transaction( + user_id="user_001", + ts_code="600519.SH", + quantity=100, + price=10.0, + transaction_date="2026-01-15", + ) + position_id = txn1.position_id + + # Second buy + txn2 = await svc.record_buy_transaction( + user_id="user_001", + ts_code="600519.SH", + quantity=50, + price=12.0, + transaction_date="2026-01-20", + ) + + # Same position updated + assert txn2.position_id == position_id + position = svc._positions[position_id] + assert position.quantity == 150 + assert round(position.cost_price, 2) == 10.67 + + @pytest.mark.asyncio + async def test_buy_transaction_persisted_to_store(self): + """Buy transaction should be stored in the service's _transactions dict.""" + from stock_datasource.modules.portfolio.enhanced_service import ( + EnhancedPortfolioService, + ) + + svc = EnhancedPortfolioService() + svc._db = None + + txn = await svc.record_buy_transaction( + user_id="user_001", + ts_code="600519.SH", + quantity=100, + price=1700.0, + transaction_date="2026-01-15", + ) + + assert txn.id in svc._transactions + assert svc._transactions[txn.id].transaction_type == "buy" + + @pytest.mark.asyncio + async def test_position_auto_updated_after_buy(self): + """Position should be auto-updated (is_active=True, updated_at set) after buy.""" + from stock_datasource.modules.portfolio.enhanced_service import ( + EnhancedPortfolioService, + ) + + svc = EnhancedPortfolioService() + svc._db = None + + txn = await svc.record_buy_transaction( + user_id="user_001", + ts_code="600519.SH", + quantity=100, + price=1700.0, + transaction_date="2026-01-15", + ) + + position = svc._positions[txn.position_id] + assert position.is_active is True + assert position.updated_at is not None + + @pytest.mark.asyncio + async def test_multiple_buys_accumulate_correctly(self): + """Three buys should accumulate quantity and compute correct weighted avg cost.""" + from stock_datasource.modules.portfolio.enhanced_service import ( + EnhancedPortfolioService, + ) + + svc = EnhancedPortfolioService() + svc._db = None + + await svc.record_buy_transaction( + user_id="user_001", + ts_code="600519.SH", + quantity=100, + price=10.0, + transaction_date="2026-01-15", + ) + await svc.record_buy_transaction( + user_id="user_001", + ts_code="600519.SH", + quantity=50, + price=12.0, + transaction_date="2026-01-20", + ) + txn3 = await svc.record_buy_transaction( + user_id="user_001", + ts_code="600519.SH", + quantity=200, + price=11.0, + transaction_date="2026-01-25", + ) + + position = svc._positions[txn3.position_id] + assert position.quantity == 350 + # avg = (100*10 + 50*12 + 200*11) / 350 = 3800/350 ≈ 10.86 + assert round(position.cost_price, 2) == 10.86 + + +# --------------------------------------------------------------------------- +# 7. Sell Transaction Service Tests (Cycle 1.3) +# --------------------------------------------------------------------------- + + +class TestSellTransaction: + """Test record_sell_transaction service method.""" + + @pytest.mark.asyncio + async def test_record_sell_creates_transaction_with_sell_type(self): + """record_sell_transaction should create a Transaction with type=sell.""" + from stock_datasource.modules.portfolio.enhanced_service import ( + EnhancedPortfolioService, + ) + + svc = EnhancedPortfolioService() + svc._db = None + + # First, create a position via buy + await svc.record_buy_transaction( + user_id="user_001", + ts_code="600519.SH", + quantity=100, + price=10.0, + transaction_date="2026-01-15", + ) + + txn = await svc.record_sell_transaction( + user_id="user_001", + ts_code="600519.SH", + quantity=50, + price=15.0, + transaction_date="2026-02-01", + ) + + assert txn.transaction_type == "sell" + assert txn.quantity == 50 + assert txn.price == 15.0 + + @pytest.mark.asyncio + async def test_partial_sell_reduces_position_quantity(self): + """Partial sell should reduce position quantity (100 -> sell 30 -> 70).""" + from stock_datasource.modules.portfolio.enhanced_service import ( + EnhancedPortfolioService, + ) + + svc = EnhancedPortfolioService() + svc._db = None + + buy_txn = await svc.record_buy_transaction( + user_id="user_001", + ts_code="600519.SH", + quantity=100, + price=10.0, + transaction_date="2026-01-15", + ) + + await svc.record_sell_transaction( + user_id="user_001", + ts_code="600519.SH", + quantity=30, + price=12.0, + transaction_date="2026-02-01", + ) + + position = svc._positions[buy_txn.position_id] + assert position.quantity == 70 + + @pytest.mark.asyncio + async def test_full_sell_sets_position_inactive(self): + """Full sell (all shares) should set position is_active=False.""" + from stock_datasource.modules.portfolio.enhanced_service import ( + EnhancedPortfolioService, + ) + + svc = EnhancedPortfolioService() + svc._db = None + + buy_txn = await svc.record_buy_transaction( + user_id="user_001", + ts_code="600519.SH", + quantity=100, + price=10.0, + transaction_date="2026-01-15", + ) + + await svc.record_sell_transaction( + user_id="user_001", + ts_code="600519.SH", + quantity=100, + price=15.0, + transaction_date="2026-02-01", + ) + + position = svc._positions[buy_txn.position_id] + assert position.is_active is False + assert position.quantity == 0 + + @pytest.mark.asyncio + async def test_sell_more_than_held_raises_valueerror(self): + """Selling more shares than held should raise ValueError.""" + from stock_datasource.modules.portfolio.enhanced_service import ( + EnhancedPortfolioService, + ) + + svc = EnhancedPortfolioService() + svc._db = None + + await svc.record_buy_transaction( + user_id="user_001", + ts_code="600519.SH", + quantity=100, + price=10.0, + transaction_date="2026-01-15", + ) + + with pytest.raises(ValueError, match="Cannot sell"): + await svc.record_sell_transaction( + user_id="user_001", + ts_code="600519.SH", + quantity=200, + price=15.0, + transaction_date="2026-02-01", + ) + + @pytest.mark.asyncio + async def test_sell_realized_pl_calculated_correctly(self): + """Realized P/L should be calculated: sell 50 @ 15, cost_price=10 -> 250.""" + from stock_datasource.modules.portfolio.enhanced_service import ( + EnhancedPortfolioService, + ) + + svc = EnhancedPortfolioService() + svc._db = None + + await svc.record_buy_transaction( + user_id="user_001", + ts_code="600519.SH", + quantity=100, + price=10.0, + transaction_date="2026-01-15", + ) + + txn = await svc.record_sell_transaction( + user_id="user_001", + ts_code="600519.SH", + quantity=50, + price=15.0, + transaction_date="2026-02-01", + ) + + assert txn.realized_pl == 250.0 # 50 * (15 - 10) + + @pytest.mark.asyncio + async def test_sell_does_not_change_cost_price(self): + """Position cost_price should remain unchanged after sell (tracks avg buy cost).""" + from stock_datasource.modules.portfolio.enhanced_service import ( + EnhancedPortfolioService, + ) + + svc = EnhancedPortfolioService() + svc._db = None + + buy_txn = await svc.record_buy_transaction( + user_id="user_001", + ts_code="600519.SH", + quantity=100, + price=10.0, + transaction_date="2026-01-15", + ) + + original_cost = svc._positions[buy_txn.position_id].cost_price + + await svc.record_sell_transaction( + user_id="user_001", + ts_code="600519.SH", + quantity=30, + price=15.0, + transaction_date="2026-02-01", + ) + + assert svc._positions[buy_txn.position_id].cost_price == original_cost + + @pytest.mark.asyncio + async def test_sell_transaction_persisted_to_store(self): + """Sell transaction should be stored in the service's _transactions dict.""" + from stock_datasource.modules.portfolio.enhanced_service import ( + EnhancedPortfolioService, + ) + + svc = EnhancedPortfolioService() + svc._db = None + + await svc.record_buy_transaction( + user_id="user_001", + ts_code="600519.SH", + quantity=100, + price=10.0, + transaction_date="2026-01-15", + ) + + txn = await svc.record_sell_transaction( + user_id="user_001", + ts_code="600519.SH", + quantity=50, + price=15.0, + transaction_date="2026-02-01", + ) + + assert txn.id in svc._transactions + assert svc._transactions[txn.id].transaction_type == "sell" + + @pytest.mark.asyncio + async def test_sell_nonexistent_position_raises_valueerror(self): + """Selling a stock with no position should raise ValueError.""" + from stock_datasource.modules.portfolio.enhanced_service import ( + EnhancedPortfolioService, + ) + + svc = EnhancedPortfolioService() + svc._db = None + + with pytest.raises(ValueError, match="No active position"): + await svc.record_sell_transaction( + user_id="user_001", + ts_code="999999.SH", + quantity=10, + price=10.0, + transaction_date="2026-02-01", + ) + + +# --------------------------------------------------------------------------- +# 8. Transaction History Query Tests (Cycle 1.4) +# --------------------------------------------------------------------------- + + +class TestGetTransactions: + """Test get_transactions query method.""" + + @pytest.mark.asyncio + async def test_get_transactions_returns_list_for_user(self): + """get_transactions should return list of Transaction for user.""" + from stock_datasource.modules.portfolio.enhanced_service import ( + EnhancedPortfolioService, + ) + + svc = EnhancedPortfolioService() + svc._db = None + + await svc.record_buy_transaction( + user_id="user_001", + ts_code="600519.SH", + quantity=100, + price=10.0, + transaction_date="2026-01-15", + ) + await svc.record_sell_transaction( + user_id="user_001", + ts_code="600519.SH", + quantity=30, + price=12.0, + transaction_date="2026-02-01", + ) + + txns = await svc.get_transactions(user_id="user_001") + assert len(txns) >= 2 + assert all(t.user_id == "user_001" for t in txns) + + @pytest.mark.asyncio + async def test_get_transactions_filtered_by_ts_code(self): + """get_transactions should support filtering by ts_code.""" + from stock_datasource.modules.portfolio.enhanced_service import ( + EnhancedPortfolioService, + ) + + svc = EnhancedPortfolioService() + svc._db = None + + await svc.record_buy_transaction( + user_id="user_001", + ts_code="600519.SH", + quantity=100, + price=10.0, + transaction_date="2026-01-15", + ) + await svc.record_buy_transaction( + user_id="user_001", + ts_code="000001.SZ", + quantity=200, + price=15.0, + transaction_date="2026-01-16", + ) + + txns = await svc.get_transactions(user_id="user_001", ts_code="600519.SH") + assert len(txns) == 1 + assert txns[0].ts_code == "600519.SH" + + @pytest.mark.asyncio + async def test_get_transactions_ordered_by_date_desc(self): + """get_transactions should return transactions ordered by date DESC.""" + from stock_datasource.modules.portfolio.enhanced_service import ( + EnhancedPortfolioService, + ) + + svc = EnhancedPortfolioService() + svc._db = None + + await svc.record_buy_transaction( + user_id="user_001", + ts_code="600519.SH", + quantity=100, + price=10.0, + transaction_date="2026-01-15", + ) + await svc.record_buy_transaction( + user_id="user_001", + ts_code="600519.SH", + quantity=50, + price=12.0, + transaction_date="2026-02-01", + ) + + txns = await svc.get_transactions(user_id="user_001", ts_code="600519.SH") + assert len(txns) == 2 + # Most recent first + assert txns[0].transaction_date >= txns[1].transaction_date + + @pytest.mark.asyncio + async def test_get_transactions_date_range_filter(self): + """get_transactions should support date range filtering.""" + from stock_datasource.modules.portfolio.enhanced_service import ( + EnhancedPortfolioService, + ) + + svc = EnhancedPortfolioService() + svc._db = None + + await svc.record_buy_transaction( + user_id="user_001", + ts_code="600519.SH", + quantity=100, + price=10.0, + transaction_date="2026-01-15", + ) + await svc.record_buy_transaction( + user_id="user_001", + ts_code="600519.SH", + quantity=50, + price=12.0, + transaction_date="2026-02-01", + ) + await svc.record_buy_transaction( + user_id="user_001", + ts_code="600519.SH", + quantity=30, + price=11.0, + transaction_date="2026-03-01", + ) + + # Filter to only February onwards + txns = await svc.get_transactions( + user_id="user_001", + ts_code="600519.SH", + start_date="2026-02-01", + ) + assert len(txns) == 2 + assert all(t.transaction_date >= "2026-02-01" for t in txns) + + # Filter to only January + txns_jan = await svc.get_transactions( + user_id="user_001", + ts_code="600519.SH", + start_date="2026-01-01", + end_date="2026-01-31", + ) + assert len(txns_jan) == 1 + + +# --------------------------------------------------------------------------- +# 9. Transaction API Endpoint Tests (Cycle 1.5) +# --------------------------------------------------------------------------- + + +def _make_user(user_id: str = "user_001") -> dict: + return { + "id": user_id, + "username": "testuser", + "email": "test@test.com", + "is_admin": False, + } + + +class TestTransactionAPIEndpoints: + """Test the transaction REST API endpoints.""" + + def test_buy_endpoint_returns_200(self): + """POST /api/portfolio/transactions/buy should return 200 with transaction data.""" + from fastapi import FastAPI + from fastapi.testclient import TestClient + + from stock_datasource.modules.auth.dependencies import get_current_user + from stock_datasource.modules.portfolio.router import router + + app = FastAPI() + app.include_router(router, prefix="/api/portfolio") + + async def override_auth(): + return _make_user("user_001") + + app.dependency_overrides[get_current_user] = override_auth + + with patch("stock_datasource.models.database.db_client") as mock_db: + mock_db.execute = Mock() + mock_db.execute_query = Mock( + return_value=MagicMock(empty=True, __bool__=lambda self: False) + ) + + client = TestClient(app, raise_server_exceptions=False) + response = client.post( + "/api/portfolio/transactions/buy", + json={ + "ts_code": "600519.SH", + "quantity": 100, + "price": 1700.0, + "transaction_date": "2026-01-15", + }, + ) + + assert response.status_code == 200 + data = response.json() + assert data["transaction_type"] == "buy" + assert data["ts_code"] == "600519.SH" + assert data["quantity"] == 100 + + def test_sell_endpoint_returns_200(self): + """POST /api/portfolio/transactions/sell should return 200 with transaction data.""" + from fastapi import FastAPI + from fastapi.testclient import TestClient + + from stock_datasource.modules.auth.dependencies import get_current_user + from stock_datasource.modules.portfolio.router import router + from stock_datasource.modules.portfolio.enhanced_service import ( + EnhancedPortfolioService, + ) + + app = FastAPI() + app.include_router(router, prefix="/api/portfolio") + + async def override_auth(): + return _make_user("user_001") + + app.dependency_overrides[get_current_user] = override_auth + + # Use a single service instance so buy's position is visible to sell + shared_svc = EnhancedPortfolioService() + shared_svc._db = None + + with patch( + "stock_datasource.modules.portfolio.router.get_enhanced_portfolio_service", + return_value=shared_svc, + ): + client = TestClient(app, raise_server_exceptions=False) + # First buy to create position + buy_resp = client.post( + "/api/portfolio/transactions/buy", + json={ + "ts_code": "600519.SH", + "quantity": 100, + "price": 1700.0, + "transaction_date": "2026-01-15", + }, + ) + assert buy_resp.status_code == 200, f"Buy failed: {buy_resp.text}" + + # Then sell + response = client.post( + "/api/portfolio/transactions/sell", + json={ + "ts_code": "600519.SH", + "quantity": 30, + "price": 1800.0, + "transaction_date": "2026-02-01", + }, + ) + + assert response.status_code == 200 + data = response.json() + assert data["transaction_type"] == "sell" + assert data["quantity"] == 30 + + def test_get_transactions_returns_list(self): + """GET /api/portfolio/transactions should return list.""" + from fastapi import FastAPI + from fastapi.testclient import TestClient + + from stock_datasource.modules.auth.dependencies import get_current_user + from stock_datasource.modules.portfolio.router import router + + app = FastAPI() + app.include_router(router, prefix="/api/portfolio") + + async def override_auth(): + return _make_user("user_001") + + app.dependency_overrides[get_current_user] = override_auth + + with patch("stock_datasource.models.database.db_client") as mock_db: + mock_db.execute = Mock() + mock_db.execute_query = Mock( + return_value=MagicMock(empty=True, __bool__=lambda self: False) + ) + + client = TestClient(app, raise_server_exceptions=False) + response = client.get("/api/portfolio/transactions") + + assert response.status_code == 200 + data = response.json() + assert isinstance(data, list) + + def test_sell_more_than_held_returns_400(self): + """POST /api/portfolio/transactions/sell with quantity > held should return 400.""" + from fastapi import FastAPI + from fastapi.testclient import TestClient + + from stock_datasource.modules.auth.dependencies import get_current_user + from stock_datasource.modules.portfolio.router import router + + app = FastAPI() + app.include_router(router, prefix="/api/portfolio") + + async def override_auth(): + return _make_user("user_001") + + app.dependency_overrides[get_current_user] = override_auth + + with patch("stock_datasource.models.database.db_client") as mock_db: + mock_db.execute = Mock() + mock_db.execute_query = Mock( + return_value=MagicMock(empty=True, __bool__=lambda self: False) + ) + + client = TestClient(app, raise_server_exceptions=False) + # Buy some shares first + client.post( + "/api/portfolio/transactions/buy", + json={ + "ts_code": "600519.SH", + "quantity": 50, + "price": 1700.0, + "transaction_date": "2026-01-15", + }, + ) + + # Try to sell more than held + response = client.post( + "/api/portfolio/transactions/sell", + json={ + "ts_code": "600519.SH", + "quantity": 200, + "price": 1800.0, + "transaction_date": "2026-02-01", + }, + ) + + assert response.status_code == 400 + + def test_endpoints_require_auth(self): + """Transaction endpoints should require authentication.""" + from fastapi import FastAPI + from fastapi.testclient import TestClient + + from stock_datasource.modules.portfolio.router import router + + app = FastAPI() + app.include_router(router, prefix="/api/portfolio") + + client = TestClient(app, raise_server_exceptions=False) + + # Without auth override, should get 401/403/500 + response = client.get("/api/portfolio/transactions") + assert response.status_code in (401, 403, 500) + + +# --------------------------------------------------------------------------- +# TDD Cycle 2.1: Transaction Signals Endpoint (for K-line B/S markers) +# --------------------------------------------------------------------------- + + +class TestTransactionSignals: + """Tests for the transaction signals endpoint. + + The signals endpoint returns buy/sell markers that combine: + - User transactions (actual buy/sell records) + - Technical strategy signals (from indicators API) + These are used to display B/S point markers on K-line charts. + """ + + def test_get_signals_returns_list(self): + """GET /transactions/signals with ts_code returns a list.""" + from fastapi import FastAPI + from fastapi.testclient import TestClient + + from stock_datasource.modules.auth.dependencies import get_current_user + from stock_datasource.modules.portfolio.router import router + + app = FastAPI() + app.include_router(router, prefix="/api/portfolio") + + def override_auth(): + return {"id": "user_signals_001", "username": "test_signals"} + + app.dependency_overrides[get_current_user] = override_auth + + client = TestClient(app, raise_server_exceptions=False) + response = client.get( + "/api/portfolio/transactions/signals", + params={"ts_code": "600519.SH"}, + ) + + assert response.status_code == 200 + data = response.json() + assert isinstance(data, list) + + def test_signals_require_ts_code_param(self): + """GET /transactions/signals without ts_code should return 422.""" + from fastapi import FastAPI + from fastapi.testclient import TestClient + + from stock_datasource.modules.auth.dependencies import get_current_user + from stock_datasource.modules.portfolio.router import router + + app = FastAPI() + app.include_router(router, prefix="/api/portfolio") + + def override_auth(): + return {"id": "user_signals_002", "username": "test_signals"} + + app.dependency_overrides[get_current_user] = override_auth + + client = TestClient(app, raise_server_exceptions=False) + response = client.get("/api/portfolio/transactions/signals") + + assert response.status_code == 422 + + def test_signal_item_has_required_fields(self): + """Each signal item should have id, ts_code, signal_type, source, signal_date, price.""" + from fastapi import FastAPI + from fastapi.testclient import TestClient + from unittest.mock import patch + + from stock_datasource.modules.auth.dependencies import get_current_user + from stock_datasource.modules.portfolio.router import router + from stock_datasource.modules.portfolio.enhanced_service import ( + EnhancedPortfolioService, + Transaction, + ) + + app = FastAPI() + app.include_router(router, prefix="/api/portfolio") + + def override_auth(): + return {"id": "user_signals_003", "username": "test_signals"} + + app.dependency_overrides[get_current_user] = override_auth + + # Create a service with a recorded buy transaction + svc = EnhancedPortfolioService() + svc._db = None + + with patch( + "stock_datasource.modules.portfolio.router.get_enhanced_portfolio_service", + return_value=svc, + ): + # Record a buy so we have at least one signal + import asyncio + + txn = asyncio.get_event_loop().run_until_complete( + svc.record_buy_transaction( + user_id="user_signals_003", + ts_code="600519.SH", + quantity=100, + price=1800.0, + transaction_date="2026-04-20", + ) + ) + + client = TestClient(app, raise_server_exceptions=False) + response = client.get( + "/api/portfolio/transactions/signals", + params={"ts_code": "600519.SH"}, + ) + + assert response.status_code == 200 + data = response.json() + if len(data) > 0: + signal = data[0] + assert "id" in signal + assert "ts_code" in signal + assert "signal_type" in signal + assert "source" in signal + assert "signal_date" in signal + assert "price" in signal + assert signal["source"] == "user" + assert signal["signal_type"] in ("buy", "sell") + + def test_signals_include_both_user_and_strategy(self): + """Signals should include items with source='user' from transactions.""" + from fastapi import FastAPI + from fastapi.testclient import TestClient + from unittest.mock import patch, AsyncMock + + from stock_datasource.modules.auth.dependencies import get_current_user + from stock_datasource.modules.portfolio.router import router + from stock_datasource.modules.portfolio.enhanced_service import ( + EnhancedPortfolioService, + ) + + app = FastAPI() + app.include_router(router, prefix="/api/portfolio") + + def override_auth(): + return {"id": "user_signals_004", "username": "test_signals"} + + app.dependency_overrides[get_current_user] = override_auth + + svc = EnhancedPortfolioService() + svc._db = None + + with patch( + "stock_datasource.modules.portfolio.router.get_enhanced_portfolio_service", + return_value=svc, + ): + client = TestClient(app, raise_server_exceptions=False) + response = client.get( + "/api/portfolio/transactions/signals", + params={"ts_code": "600519.SH"}, + ) + + assert response.status_code == 200 + data = response.json() + # All signals from user transactions should have source='user' + user_signals = [s for s in data if s.get("source") == "user"] + # Strategy signals may or may not be present (depends on indicators) + # At minimum, the response format is correct + for s in user_signals: + assert s["signal_type"] in ("buy", "sell") + assert s["ts_code"] == "600519.SH" diff --git a/tests/test_system_logs_clickhouse.py b/tests/test_system_logs_clickhouse.py index b094a9c7..dcc6eb5d 100644 --- a/tests/test_system_logs_clickhouse.py +++ b/tests/test_system_logs_clickhouse.py @@ -1,7 +1,9 @@ """Tests for system_logs schemas and log_parser with ClickHouse context fields.""" from datetime import datetime -from unittest.mock import MagicMock, patch +from unittest.mock import patch, MagicMock, PropertyMock + +import pytest from stock_datasource.modules.system_logs.schemas import ( LogEntry, @@ -147,17 +149,12 @@ class TestLogServiceClickHousePaths: def test_ch_client_property_returns_none_on_import_error(self): """_ch_client should return None when db_client is unavailable.""" # Import only the service module with heavy deps mocked out - with patch.dict( - "sys.modules", - { - "stock_datasource.modules.system_logs.ai_diagnosis_service": MagicMock(), - }, - ): + with patch.dict("sys.modules", { + "stock_datasource.modules.system_logs.ai_diagnosis_service": MagicMock(), + }): # Re-import to get fresh module import importlib - import stock_datasource.modules.system_logs.service as svc_mod - importlib.reload(svc_mod) service = svc_mod.LogService(log_dir="/tmp/test_logs") @@ -169,36 +166,28 @@ def test_ch_client_property_returns_none_on_import_error(self): def test_get_logs_fallback_uses_single_reader_call(self): """Fallback path should not rescan files just to compute totals.""" - with patch.dict( - "sys.modules", - { - "stock_datasource.modules.system_logs.ai_diagnosis_service": MagicMock(), - }, - ): + with patch.dict("sys.modules", { + "stock_datasource.modules.system_logs.ai_diagnosis_service": MagicMock(), + }): import importlib - import stock_datasource.modules.system_logs.service as svc_mod importlib.reload(svc_mod) service = svc_mod.LogService(log_dir="/tmp/test_logs") service.reader = MagicMock() - service.reader.read_logs.return_value = [ - { - "timestamp": datetime(2026, 4, 9, 15, 30, 0), - "level": "ERROR", - "module": "test_module", - "message": "boom", - "raw_line": "raw", - "request_id": "req-2", - "user_id": "user2", - } - ] + service.reader.read_logs.return_value = [{ + "timestamp": datetime(2026, 4, 9, 15, 30, 0), + "level": "ERROR", + "module": "test_module", + "message": "boom", + "raw_line": "raw", + "request_id": "req-2", + "user_id": "user2", + }] filters = LogFilter(request_id="req-2", page=1, page_size=50) - with patch.object( - svc_mod.LogService, "_get_logs_from_clickhouse", return_value=None - ): + with patch.object(svc_mod.LogService, "_get_logs_from_clickhouse", return_value=None): result = service.get_logs(filters) assert service.reader.read_logs.call_count == 1