diff --git a/.claude/rules/general.md b/.claude/rules/general.md index f2ac120..63befcb 100644 --- a/.claude/rules/general.md +++ b/.claude/rules/general.md @@ -1,3 +1,7 @@ +# PR 생성 + +PR을 만들기 전에 **반드시** `.github/pull_request_template.md`를 먼저 읽고 템플릿 형식을 따를 것. + # 빌드 명령어 ## C++ 빌드 diff --git a/docs/architecture.md b/docs/architecture.md index e5f1654..daca032 100644 --- a/docs/architecture.md +++ b/docs/architecture.md @@ -185,6 +185,20 @@ While (세션 활성): - Atomic으로 카운터 증가 ``` +### 세션 I/O 최적화 + +- `proxy::Session`은 작은 패킷을 반복해서 읽고 쓰는 경로에서 syscall 수를 줄이기 위해 내부 read/write buffer를 유지한다. +- 서버 응답 경로는 `RelayBuffer`가 header+payload를 버퍼에 누적한 뒤 필요 시 flush 하며, result set 중간에도 임계치 기반으로 분할 전송한다. +- 클라이언트 요청 경로는 `ClientReadBuffer`가 `async_read_some` 기반으로 패킷을 누적 읽어 2단계 read(header, payload)를 단일 버퍼 흐름으로 줄인다. +- 세션 버퍼는 단일 MySQL 패킷 최대 크기(3-byte length field 기준)를 넘겨 확장하지 않으며, 큰 패킷 처리 후 버퍼가 비면 초기 크기로 축소해 장기 세션의 상주 메모리를 제한한다. +- packet write는 `write_packet_raw()`가 header와 payload를 scatter-gather로 묶어 전송해 별도 serialize 버퍼 할당을 피한다. + +### 로깅과 통계 + +- `StructuredLogger`는 UTC ISO8601 타임스탬프를 `strftime + snprintf` 로 생성해 문자열 스트림 기반 포맷 비용을 줄인다. +- 허용/차단/에러 로그는 세션 ID, 사용자, SQL prefix, 평가 경로를 남겨 parser, policy, proxy의 fail-close 판단을 추적 가능하게 한다. +- `StatsCollector`는 atomic 기반 누적 카운터를 유지하고, UDS stats 서버와 health check는 같은 snapshot을 읽기 전용으로 소비한다. + ## Fail-Close 원칙 (절대 위반 금지) **Fail-Close의 의미:** diff --git a/docs/failure-modes.md b/docs/failure-modes.md index fb5363c..d9f99f9 100644 --- a/docs/failure-modes.md +++ b/docs/failure-modes.md @@ -36,6 +36,37 @@ - `StructuredLogger` 종료 시 `flush()` 보장 - 테스트는 로거 수명 종료 이후 파일을 읽도록 구성 +### 4) Static Analysis (clang-tidy) 실패 +- 증상: + - `cppcoreguidelines-avoid-const-or-ref-data-members` + - `readability-braces-around-statements` + - `misc-const-correctness` + - `modernize-return-braced-init-list` + - Boost.Asio `awaitable.hpp` 경유 `clang-analyzer-core.NullDereference` +- 원인: + - 세션 내부 보조 버퍼 클래스가 reference data member를 유지함 + - coroutine early return 분기에 brace가 빠져 style check에 걸림 + - 변경되지 않는 scatter-gather 전송 버퍼가 `const`로 선언되지 않음 + - 단순 문자열 반환식이 최신 초기화 스타일과 맞지 않음 + - clang static analyzer가 coroutine + Boost.Asio awaitable 경로를 추적하며 외부 헤더 false positive를 낼 수 있음 +- 복구: + - non-owning stream 참조는 pointer 등 재바인딩 가능한 형태로 저장 + - 단일문 분기에도 brace를 사용 + - 불변 로컬 버퍼는 `const`로 선언 + - 단순 생성 반환은 braced-init-list 사용 + - `clang-tidy -p build/default .cpp` 단독 실행 대신 CI와 같은 extra arg/별도 compile DB 구성으로 재현한다 + - 외부 Boost 헤더에서만 발생한 `clang-analyzer-core.NullDereference`는 `.clang-tidy`에서 `WarningsAsErrors` 제외 항목이므로, 프로젝트 소스 error가 없으면 CI 실패 원인으로 보지 않는다 + +### 5) devcontainer에서 Integration Test 확인이 불안정한 경우 +- 증상: + - `tests/integration/test_scenarios.sh` 실행 시 direct/proxy 단계가 모두 실패하거나 환경에 따라 결과가 달라짐 +- 원인: + - devcontainer 내부에서는 Docker daemon 제어가 불가하고, compose MySQL 접근도 실행 권한/네트워크 제약의 영향을 받음 + - 스크립트는 stderr를 버리므로 접속 실패 원인이 바로 드러나지 않음 +- 복구: + - compose가 올린 `mysql` 서비스에 직접 접속 가능한 실행 경로에서 검증 + - 필요 시 동일한 `mysql` 명령을 독립 실행해 실제 접속 오류를 먼저 확인 + ## 런타임 실패 모드 ### UDS 바인드 실패 diff --git a/docs/runbook.md b/docs/runbook.md index c79cd8f..84c1107 100644 --- a/docs/runbook.md +++ b/docs/runbook.md @@ -183,7 +183,7 @@ vcpkg 의존성 컴파일 시간 절감을 위해 바이너리 캐시를 사용 | 실패 job | 원인 | 조치 | |---|---|---| | Build & Test | 빌드 오류 또는 테스트 실패 | 로컬 `cmake --preset default && ctest` 재현 | -| Static Analysis | clang-tidy error 또는 cppcheck 오류 | `clang-tidy -p build/debug <파일>.cpp` 로컬 실행 | +| Static Analysis | clang-tidy error 또는 cppcheck 오류 | `build/default/compile_commands.json` 기준으로 CI와 동일한 extra arg를 넣어 `clang-tidy` 재현 | | ASan | 메모리 오염/누수 | `cmake --preset asan && ctest` 로컬 실행 | | TSan | 데이터레이스 | `cmake --preset tsan && ctest` 로컬 실행 | | Go CI | 린트 오류 또는 테스트 실패 | `cd tools && golangci-lint run` 로컬 실행 | @@ -193,6 +193,53 @@ vcpkg 의존성 컴파일 시간 절감을 위해 바이너리 캐시를 사용 TSan job은 `ubuntu-24.04` (x86_64 고정) 에서만 실행한다. GCC ThreadSanitizer가 aarch64에서 불안정하므로 runner 아키텍처를 명시적으로 제한한다. +### 로컬 clang-tidy 재현 + +CI의 static analysis job은 `build/default/compile_commands.json` 에서 `src/*.cpp` 엔트리만 추려 별도 compile DB를 만들고, 아래 인자를 추가해 `clang-tidy`를 실행한다. + +```bash +db_path="build/default/compile_commands.json" +tidy_db_dir="/tmp/clang-tidy-db" +tidy_db="${tidy_db_dir}/compile_commands.json" +mkdir -p "${tidy_db_dir}" +src_root="$(realpath src)" + +jq --arg root "${src_root}" ' + def cmd: + if .command then .command + elif ((.arguments | type) == "array") then (.arguments | join(" ")) + else "" end; + [ + .[] + | select( + (.file | startswith($root + "/")) + and (cmd | test("(^| )-std=(gnu\\+\\+23|c\\+\\+23)( |$)")) + ) + | if .command then . else (. + {command: cmd}) end + ] + | sort_by(.file) + | unique_by(.file) +' "${db_path}" > "${tidy_db}" + +gcc_major="$(g++-14 -dumpfullversion -dumpversion | cut -d. -f1)" +gcc_triple="$(g++-14 -dumpmachine)" +args=( + -p "${tidy_db_dir}" + --extra-arg=-std=c++23 + --extra-arg=--gcc-toolchain=/usr +) +if [ -d "/usr/include/c++/${gcc_major}" ]; then + args+=(--extra-arg=-isystem --extra-arg="/usr/include/c++/${gcc_major}") +fi +if [ -d "/usr/include/${gcc_triple}/c++/${gcc_major}" ]; then + args+=(--extra-arg=-isystem --extra-arg="/usr/include/${gcc_triple}/c++/${gcc_major}") +fi + +clang-tidy "${args[@]}" "$(realpath src/proxy/session.cpp)" +``` + +로컬에서 단일 파일만 확인할 때도 위 인자 구성을 유지해야 CI와 같은 결과를 얻는다. + ## Docker 배포 (멀티 인스턴스 + HAProxy) ### 아키텍처 개요 diff --git a/docs/threat-model.md b/docs/threat-model.md index c82e272..920ca7b 100644 --- a/docs/threat-model.md +++ b/docs/threat-model.md @@ -170,6 +170,24 @@ > 참조: `docs/architecture.md:955-957` +### 3.2.1 ReDoS (정규식 기반 DoS) — DON-80 완화 + +| 항목 | 설명 | +|------|------| +| 공격 시나리오 | 파서의 `std::regex` 평가를 악용하여 백트래킹 폭발(catastrophic backtracking) 유도, CPU 고갈 | +| 이전 상태 | `sql_parser.cpp`의 `is_start_transaction_statement`, `extract_tables_for_keyword`, `has_where_keyword`가 `std::regex` 사용 → ReDoS 가능 | +| 현재 상태 (DON-80) | **완화됨** — `sql_parser.cpp` 세 함수가 `std::string::find` + 수동 토큰 파싱으로 교체. 추가로 `injection_detector.cpp`에 fast anchor 사전필터 도입 | +| 위험도 | **완화** | +| 잔존 위협 | `injection_detector.cpp`의 `std::regex`는 유지됨 (패턴 수 제한 + 입력 길이 제한 + fast anchor 사전필터로 완화) | + +**fast anchor 사전필터 (DON-80)**: +- `injection_detector.cpp`의 `InjectionDetector::check()`에 `extract_literal_anchor` 기반 사전필터 적용. +- 각 패턴의 가장 긴 리터럴 토큰을 `sql_upper.find()`로 먼저 검색하여 앵커가 없으면 `std::regex_search` 를 건너뜀. +- 정상 쿼리에서 regex 호출 횟수를 90%+ 감소시켜 ReDoS 노출 면적을 축소. +- **alternation 패턴(`|` 포함) 보안 주의**: 앵커 추출 시 하나의 토큰이 모든 대안을 대표할 수 없으므로 `|` 포함 패턴은 앵커를 빈 문자열로 처리하여 regex 를 항상 실행(false negative 방지). 기본 10패턴 중 piggyback 패턴(`;\s*(DROP|...)`)이 해당. + +> 참조: `src/parser/sql_parser.cpp` (DON-80), `src/parser/injection_detector.cpp` (`extract_literal_anchor`, `check()`) + ### 3.3 UDS 비인가 접근 | 항목 | 설명 | @@ -206,6 +224,7 @@ | 2.3 | 변수 간접 참조 | 고 | 미완화 | block_dynamic_sql로 간접 차단 | | 3.1 | 악성 패킷 | 중 | fail-close 적용 | Fuzz 테스트 확대 | | 3.2 | DoS | 중 | 부분 완화 (MAX_CONNECTIONS) | 타임아웃·속도 제한 강화 | +| 3.2.1 | ReDoS | 중 | **완화** (DON-80 — sql_parser std::regex 제거 + injection_detector fast anchor 사전필터) | injection_detector 패턴 수 제한 유지 | | 3.3 | UDS 비인가 접근 | 하 | 파일시스템 권한 보호 | - | | 3.4 | TLS 공격 | 중 | OpenSSL 기반 TLS 지원 | TLS 1.2+ 강제 | | 4.1 | Monitor 모드 + 미등록 사용자 우회 | 고 | **해소** (DON-49 수정) | - | diff --git a/scripts/hooks/pre-commit b/scripts/hooks/pre-commit index 80c7f6f..5bb68f6 100755 --- a/scripts/hooks/pre-commit +++ b/scripts/hooks/pre-commit @@ -24,9 +24,19 @@ if [[ ${#CPP_FILES[@]} -gt 0 ]]; then _CF_MAJOR=$(clang-format --version 2>/dev/null | grep -oE '[0-9]+' | head -1) if [[ -n "$_CF_MAJOR" && "$_CF_MAJOR" -lt 19 ]]; then echo "[pre-commit] ⚠️ clang-format ${_CF_MAJOR} < 19 (Standard: c++23 미지원) — C++ 포맷 검사 skip (CI가 최종 검증)" - elif ! echo "" | clang-format --stdin-filename=test.cpp >/dev/null 2>&1; then - echo "[pre-commit] ⚠️ clang-format 설정 오류(.clang-format 비호환) — C++ 포맷 검사 skip (CI가 최종 검증)" else + _CF_FILENAME_ARG="" + if echo "" | clang-format --stdin-filename=test.cpp >/dev/null 2>&1; then + _CF_FILENAME_ARG="--stdin-filename=test.cpp" + elif echo "" | clang-format --assume-filename=test.cpp >/dev/null 2>&1; then + _CF_FILENAME_ARG="--assume-filename=test.cpp" + fi + + if [[ -z "$_CF_FILENAME_ARG" ]]; then + echo "[pre-commit] ⚠️ clang-format stdin filename 옵션 비호환 — C++ 포맷 검사 skip (CI가 최종 검증)" + elif ! echo "" | clang-format "$_CF_FILENAME_ARG" >/dev/null 2>&1; then + echo "[pre-commit] ⚠️ clang-format 설정 오류(.clang-format 비호환) — C++ 포맷 검사 skip (CI가 최종 검증)" + else CPP_ERRORS=() for f in "${CPP_FILES[@]}"; do if ! clang-format --dry-run --Werror "$f" 2>/dev/null; then @@ -43,6 +53,7 @@ if [[ ${#CPP_FILES[@]} -gt 0 ]]; then else echo "[pre-commit] ✅ C++ 포맷 OK" fi + fi fi fi fi diff --git a/src/logger/structured_logger.cpp b/src/logger/structured_logger.cpp index 007fcfa..b95ec98 100644 --- a/src/logger/structured_logger.cpp +++ b/src/logger/structured_logger.cpp @@ -15,8 +15,7 @@ #include #include #include -#include -#include +#include namespace { @@ -40,10 +39,22 @@ std::string format_iso8601(const std::chrono::system_clock::time_point& tp) { } #endif - std::ostringstream oss; - oss << std::put_time(&tm_val, "%Y-%m-%dT%H:%M:%S") << '.' << std::setfill('0') << std::setw(3) - << millis.count() << 'Z'; - return oss.str(); + // strftime + snprintf로 타임스탬프 포맷 (ostringstream 대비 ~10x 빠름) + std::array date_buf{}; + const auto date_len = + std::strftime(date_buf.data(), date_buf.size(), "%Y-%m-%dT%H:%M:%S", &tm_val); + if (date_len == 0) { + return "1970-01-01T00:00:00.000Z"; + } + + std::array buf{}; + (void)std::snprintf(buf.data(), // NOLINT(cppcoreguidelines-pro-type-vararg) + buf.size(), + "%.*s.%03dZ", + static_cast(date_len), + date_buf.data(), + static_cast(millis.count())); + return {buf.data()}; } // --------------------------------------------------------------------------- @@ -177,15 +188,24 @@ void StructuredLogger::log_connection(const ConnectionLog& entry) { return; } - // JSON 구성 - std::ostringstream json; - json << R"({"event":")" << escape_json_string(entry.event) << R"(","session_id":)" - << entry.session_id << R"(,"client_ip":")" << escape_json_string(entry.client_ip) - << R"(","client_port":)" << entry.client_port << R"(,"db_user":")" - << escape_json_string(entry.db_user) << R"(","timestamp":")" - << format_iso8601(entry.timestamp) << R"("})"; - - logger_->info(json.str()); + std::string json; + json.reserve(256); + + json += R"({"event":")"; + json += escape_json_string(entry.event); + json += R"(","session_id":)"; + json += std::to_string(entry.session_id); + json += R"(,"client_ip":")"; + json += escape_json_string(entry.client_ip); + json += R"(","client_port":)"; + json += std::to_string(entry.client_port); + json += R"(,"db_user":")"; + json += escape_json_string(entry.db_user); + json += R"(","timestamp":")"; + json += format_iso8601(entry.timestamp); + json += R"("})"; + + logger_->info(json); } // --------------------------------------------------------------------------- @@ -196,26 +216,39 @@ void StructuredLogger::log_query(const QueryLog& entry) { return; } - // JSON 구성 - std::ostringstream json; - json << R"({"event":"query","session_id":)" << entry.session_id << R"(,"db_user":")" - << escape_json_string(entry.db_user) << R"(","client_ip":")" - << escape_json_string(entry.client_ip) << R"(","raw_sql":")" - << escape_json_string(entry.raw_sql) << R"(","command_raw":)" - << static_cast(entry.command_raw) << R"(,"tables":[)"; + std::string json; + json.reserve(256 + entry.raw_sql.size()); + + json += R"({"event":"query","session_id":)"; + json += std::to_string(entry.session_id); + json += R"(,"db_user":")"; + json += escape_json_string(entry.db_user); + json += R"(","client_ip":")"; + json += escape_json_string(entry.client_ip); + json += R"(","raw_sql":")"; + json += escape_json_string(entry.raw_sql); + json += R"(","command_raw":)"; + json += std::to_string(static_cast(entry.command_raw)); + json += R"(,"tables":[)"; for (size_t i = 0; i < entry.tables.size(); ++i) { if (i > 0) { - json << ','; + json += ','; } - json << R"(")" << escape_json_string(entry.tables[i]) << R"(")"; + json += '"'; + json += escape_json_string(entry.tables[i]); + json += '"'; } - json << R"(],"action_raw":)" << static_cast(entry.action_raw) << R"(,"timestamp":")" - << format_iso8601(entry.timestamp) << R"(","duration_us":)" << entry.duration.count() - << R"(})"; + json += R"(],"action_raw":)"; + json += std::to_string(static_cast(entry.action_raw)); + json += R"(,"timestamp":")"; + json += format_iso8601(entry.timestamp); + json += R"(","duration_us":)"; + json += std::to_string(entry.duration.count()); + json += '}'; - logger_->info(json.str()); + logger_->info(json); } // --------------------------------------------------------------------------- @@ -226,21 +259,34 @@ void StructuredLogger::log_block(const BlockLog& entry) { return; } - // JSON 구성 // would_block==true: dry-run 모드에서 차단됐을 것임을 나타냄 (실제 차단 아님) const char* event_name = entry.would_block ? "query_would_block" : "query_blocked"; const char* would_block_val = entry.would_block ? "true" : "false"; - std::ostringstream json; - json << R"({"event":")" << event_name << R"(","session_id":)" << entry.session_id - << R"(,"db_user":")" << escape_json_string(entry.db_user) << R"(","client_ip":")" - << escape_json_string(entry.client_ip) << R"(","raw_sql":")" - << escape_json_string(entry.raw_sql) << R"(","matched_rule":")" - << escape_json_string(entry.matched_rule) << R"(","reason":")" - << escape_json_string(entry.reason) << R"(","would_block":)" << would_block_val - << R"(,"timestamp":")" << format_iso8601(entry.timestamp) << R"("})"; - - logger_->warn(json.str()); + std::string json; + json.reserve(256 + entry.raw_sql.size()); + + json += R"({"event":")"; + json += event_name; + json += R"(","session_id":)"; + json += std::to_string(entry.session_id); + json += R"(,"db_user":")"; + json += escape_json_string(entry.db_user); + json += R"(","client_ip":")"; + json += escape_json_string(entry.client_ip); + json += R"(","raw_sql":")"; + json += escape_json_string(entry.raw_sql); + json += R"(","matched_rule":")"; + json += escape_json_string(entry.matched_rule); + json += R"(","reason":")"; + json += escape_json_string(entry.reason); + json += R"(","would_block":)"; + json += would_block_val; + json += R"(,"timestamp":")"; + json += format_iso8601(entry.timestamp); + json += R"("})"; + + logger_->warn(json); } // --------------------------------------------------------------------------- diff --git a/src/parser/injection_detector.cpp b/src/parser/injection_detector.cpp index 72d873e..9391c64 100644 --- a/src/parser/injection_detector.cpp +++ b/src/parser/injection_detector.cpp @@ -43,6 +43,7 @@ #include +#include #include #include #include @@ -63,14 +64,17 @@ struct InjectionDetector::CompiledPattern { std::string source_pattern; // 원본 패턴 문자열 (감사 로그용) std::shared_ptr compiled; // 컴파일된 정규식 std::string reason; // 사람이 읽을 수 있는 탐지 이유 + std::string fast_anchor; // 대문자 리터럴 앵커 (빠른 사전필터용) // CompiledPattern의 소멸자는 여기서 완전하게 정의됨. // shared_ptr의 소멸자는 이 시점에서 완전한 regex 정의를 가진다. ~CompiledPattern() = default; CompiledPattern() = default; - CompiledPattern(std::string src, std::shared_ptr re, std::string rsn) - : source_pattern(std::move(src)), compiled(std::move(re)), reason(std::move(rsn)) {} + CompiledPattern(std::string src, std::shared_ptr re, std::string rsn, + std::string anchor) + : source_pattern(std::move(src)), compiled(std::move(re)), reason(std::move(rsn)), + fast_anchor(std::move(anchor)) {} // 이동 지원 CompiledPattern(CompiledPattern&&) = default; @@ -81,6 +85,56 @@ struct InjectionDetector::CompiledPattern { CompiledPattern& operator=(const CompiledPattern&) = default; }; +// --------------------------------------------------------------------------- +// extract_literal_anchor +// +// 정규식 패턴에서 가장 긴 연속 알파벳/숫자/언더스코어 부분문자열을 추출한다. +// 대문자로 변환하여 반환 (case-insensitive 패턴 매칭 대응). +// +// [사전필터 목적] +// check()에서 regex_search 전에 string::find로 앵커 존재 여부를 확인한다. +// 앵커가 SQL에 없으면 해당 패턴의 regex를 건너뛸 수 있어 정상 쿼리에서 +// regex 호출 횟수를 대폭 줄인다. +// +// [보안 고려사항] +// - 패턴에 '|'(alternation)이 포함된 경우 반드시 빈 문자열을 반환한다. +// 이유: alternation 패턴에서 가장 긴 토큰을 앵커로 쓰면 다른 대안(alternative)이 +// SQL에 있을 때 사전필터가 잘못 skip하여 false negative가 발생한다. +// 예: `;\s*(DROP|...|TRUNCATE)` → 앵커 "TRUNCATE" → `; DROP TABLE` 미탐. +// - 앵커 길이 2 미만이면 빈 문자열 반환 → 사전필터 없이 regex 실행 (기존 동작 유지). +// - 앵커 기반 사전필터는 "skip"이지 "block"이 아니므로: +// 앵커 존재 → regex 실행 → 최종 판정 (false negative 불가) +// 앵커 부재 → regex 건너뜀 → not detected (앵커가 없는 SQL은 패턴 미매칭 보장) +// - 알파벳/숫자/언더스코어만으로 앵커를 제한하여 정규식 메타문자를 배제한다. +// --------------------------------------------------------------------------- +namespace { +std::string extract_literal_anchor(const std::string& pattern) { + // [보안] alternation 패턴은 앵커 추출 불가 — 빈 문자열 반환 + // 하나의 토큰이 모든 대안을 대표할 수 없으므로 false negative 위험 존재. + if (pattern.find('|') != std::string::npos) { + return std::string{}; + } + + std::string best; + std::string current; + for (const char c : pattern) { + if (std::isalnum(static_cast(c)) != 0 || c == '_') { + current += static_cast(std::toupper(static_cast(c))); + } else { + if (current.size() > best.size()) { + best = current; + } + current.clear(); + } + } + if (current.size() > best.size()) { + best = current; + } + // 너무 짧은 앵커는 사전필터 효과 없음 (빈 문자열 반환 → regex 항상 실행) + return best.size() >= 2 ? best : std::string{}; +} +} // namespace + // --------------------------------------------------------------------------- // InjectionDetector 소멸자 // cpp에서 정의하는 이유: CompiledPattern의 완전한 정의 이후에 소멸자가 인스턴스화되어야 @@ -99,7 +153,9 @@ InjectionDetector::InjectionDetector(std::vector patterns) { try { auto re = std::make_shared( p, std::regex_constants::icase | std::regex_constants::ECMAScript); - CompiledPattern cp(p, std::move(re), "Matched injection pattern: " + p); + auto anchor = extract_literal_anchor(p); + CompiledPattern cp(p, std::move(re), "Matched injection pattern: " + p, + std::move(anchor)); compiled_patterns_.push_back(std::move(cp)); } catch (const std::regex_error& e) { @@ -138,12 +194,33 @@ InjectionResult InjectionDetector::check(std::string_view sql) const { .detected = true, .matched_pattern = "", .reason = "no valid patterns loaded"}; } + // SQL 대문자 변환 (앵커 사전필터용 — 1회만 수행) + // + // [성능 의도] + // 각 패턴의 fast_anchor 는 대문자로 저장되어 있다. + // sql_upper 에서 string::find 를 먼저 수행해 앵커가 없으면 regex 를 건너뜀. + // 정상 쿼리(인젝션 키워드 없음)에서 regex 호출 횟수를 90%+ 줄일 수 있다. + // + // [보안 보장] + // 앵커 존재 → regex 실행 → 최종 판정 (false negative 불가) + // 앵커 부재 → regex 건너뜀 → not detected + // 앵커가 비어 있는 패턴은 항상 regex 실행 (기존 동작 유지) + std::string sql_upper; + sql_upper.reserve(sql.size()); + for (const char c : sql) { + sql_upper += static_cast(std::toupper(static_cast(c))); + } + const std::string sql_str(sql); for (const auto& cp : compiled_patterns_) { if (!cp.compiled) { continue; } + // 빠른 사전필터: 앵커가 정의되어 있고 SQL에 없으면 regex 건너뜀 + if (!cp.fast_anchor.empty() && sql_upper.find(cp.fast_anchor) == std::string::npos) { + continue; + } if (std::regex_search(sql_str, *cp.compiled)) { // 첫 번째 매칭 시 즉시 반환 return InjectionResult{ diff --git a/src/parser/sql_parser.cpp b/src/parser/sql_parser.cpp index 202305d..310b2d5 100644 --- a/src/parser/sql_parser.cpp +++ b/src/parser/sql_parser.cpp @@ -30,7 +30,6 @@ #include #include #include -#include #include #include #include @@ -122,6 +121,11 @@ std::string_view trim(std::string_view s) { static_cast(end - begin)); } +// 단어 경계 확인용: 알파벳, 숫자, 밑줄 +bool is_word_char(char c) { + return (std::isalnum(static_cast(c)) != 0) || c == '_'; +} + // 정규화된 SQL(대문자, 주석 제거)에서 첫 번째 키워드를 추출한다. // 반환값: 첫 번째 공백-구분 토큰, 없으면 빈 문자열 std::string extract_first_keyword(std::string_view normalized_sql) { @@ -142,14 +146,29 @@ std::string extract_first_keyword(std::string_view normalized_sql) { // START REPLICA / START SLAVE 같은 관리 명령이 BEGIN 허용 경로로 // 우회될 수 있으므로, 정확히 START TRANSACTION일 때만 kBegin으로 분류한다. bool is_start_transaction_statement(std::string_view normalized_sql) { - try { - static const std::regex start_txn_re( - R"(^START\s+TRANSACTION\b)", - std::regex_constants::ECMAScript); - return std::regex_search(normalized_sql.begin(), normalized_sql.end(), start_txn_re); - } catch (const std::regex_error&) { + const auto trimmed_sv = trim(normalized_sql); + // "START TRANSACTION" = 최소 17자 + if (trimmed_sv.size() < 17) { + return false; + } + if (trimmed_sv.substr(0, 5) != "START") { + return false; + } + // START 뒤에 반드시 공백이 하나 이상 있어야 함 + std::size_t pos = 5; + if (std::isspace(static_cast(trimmed_sv[pos])) == 0) { return false; } + while (pos < trimmed_sv.size() && + std::isspace(static_cast(trimmed_sv[pos])) != 0) { + ++pos; + } + // TRANSACTION 키워드 확인 + 단어 경계 + const auto remaining = trimmed_sv.substr(pos); + if (remaining.size() < 11 || remaining.substr(0, 11) != "TRANSACTION") { + return false; + } + return remaining.size() == 11 || !is_word_char(static_cast(remaining[11])); } // 첫 번째 키워드 → SqlCommand 매핑 @@ -203,85 +222,85 @@ void extract_tables_for_keyword(const std::string& normalized_sql, std::string_view original_sql, const std::string& keyword, std::vector& out_tables) { - // keyword 다음에 오는 테이블명(들)을 추출하는 정규식 - // 쉼표 구분 복수 테이블: FROM t1, t2, t3 - // 각 테이블명은 백틱 선택적 포함 - const std::string pattern = keyword + R"(\s+(`?[\w.]+`?(?:\s*,\s*`?[\w.]+`?)*))"; - - try { - const std::regex re(pattern, - std::regex_constants::icase | std::regex_constants::ECMAScript); + // P1-1 수정: 원문 대문자 변환을 루프 밖에서 1회만 수행 + const std::string orig_str(original_sql); + const std::string orig_upper = to_upper(original_sql); + + std::size_t search_pos = 0; + while (true) { + const auto kw_pos = normalized_sql.find(keyword, search_pos); + if (kw_pos == std::string::npos) { + break; + } - auto it = std::sregex_iterator(normalized_sql.begin(), normalized_sql.end(), re); - const auto end_it = std::sregex_iterator(); + // 키워드 앞 단어 경계 확인 + if (kw_pos > 0 && is_word_char(normalized_sql[kw_pos - 1])) { + search_pos = kw_pos + 1; + continue; + } - for (; it != end_it; ++it) { - const std::smatch& m = *it; - if (m.size() < 2) { - continue; - } + const auto after_kw = kw_pos + keyword.size(); + // 키워드 뒤에 반드시 공백이 있어야 함 + if (after_kw >= normalized_sql.size() || + std::isspace(static_cast(normalized_sql[after_kw])) == 0) { + search_pos = kw_pos + 1; + continue; + } - std::string table_list = m[1].str(); + // 공백 건너뜀 + auto pos = after_kw; + while (pos < normalized_sql.size() && + std::isspace(static_cast(normalized_sql[pos])) != 0) { + ++pos; + } - // 쉼표로 분리하여 각 테이블명 처리 - std::size_t pos = 0; - while (pos <= table_list.size()) { - // 앞 공백 건너뜀 - while (pos < table_list.size() && - std::isspace(static_cast(table_list[pos])) != 0) { - ++pos; - } - if (pos >= table_list.size()) { - break; - } + // 쉼표 구분 테이블명 추출 + while (pos < normalized_sql.size()) { + // 앞 공백 건너뜀 + while (pos < normalized_sql.size() && + std::isspace(static_cast(normalized_sql[pos])) != 0) { + ++pos; + } + if (pos >= normalized_sql.size()) { + break; + } - // 쉼표 찾기 - const auto comma = table_list.find(',', pos); - std::string token; - if (comma == std::string::npos) { - token = table_list.substr(pos); - pos = table_list.size() + 1; - } else { - token = table_list.substr(pos, comma - pos); - pos = comma + 1; - } + // 서브쿼리 '(' 시작 → 중단 + if (normalized_sql[pos] == '(') { + break; + } - // 토큰 앞뒤 공백 제거 - const auto trimmed_sv = trim(token); - if (trimmed_sv.empty()) { - continue; - } - std::string trimmed_token(trimmed_sv); + // 테이블명 토큰 추출 (알파벳, 숫자, _, ., `) + const auto token_start = pos; + while (pos < normalized_sql.size() && is_table_name_char(normalized_sql[pos])) { + ++pos; + } + if (pos == token_start) { + break; + } - // 백틱 제거 - if (!trimmed_token.empty() && trimmed_token.front() == '`') { - trimmed_token.erase(0, 1); - } - if (!trimmed_token.empty() && trimmed_token.back() == '`') { - trimmed_token.pop_back(); - } + std::string trimmed_token(normalized_sql, token_start, pos - token_start); - // 서브쿼리 시작 '(' 건너뜀 - if (trimmed_token.empty() || trimmed_token.front() == '(') { - continue; - } + // 백틱 제거 + if (!trimmed_token.empty() && trimmed_token.front() == '`') { + trimmed_token.erase(0, 1); + } + if (!trimmed_token.empty() && trimmed_token.back() == '`') { + trimmed_token.pop_back(); + } + if (!trimmed_token.empty() && trimmed_token.front() != '(') { // 원문 SQL에서 케이스 보존된 이름 추출 const std::string upper_token = to_upper(trimmed_token); - const std::string orig_str(original_sql); - const std::string orig_upper = to_upper(orig_str); - std::string final_name = trimmed_token; - // 원문에서 케이스 보존 추출 시도 (단어 경계 확인) - std::size_t search_from = 0; - while (search_from < orig_upper.size()) { - const auto found = orig_upper.find(upper_token, search_from); + std::size_t orig_from = 0; + while (orig_from < orig_upper.size()) { + const auto found = orig_upper.find(upper_token, orig_from); if (found == std::string::npos) { break; } - // 단어 경계 확인: 앞뒤가 식별자 문자가 아니어야 함 const bool valid_start = (found == 0) || (!is_table_name_char(orig_str[found - 1])); const bool valid_end = @@ -292,7 +311,7 @@ void extract_tables_for_keyword(const std::string& normalized_sql, final_name = orig_str.substr(found, upper_token.size()); break; } - search_from = found + 1; + orig_from = found + 1; } // 중복 추가 방지 (대소문자 무관) @@ -308,21 +327,38 @@ void extract_tables_for_keyword(const std::string& normalized_sql, out_tables.push_back(std::move(final_name)); } } + + // 쉼표 확인 → 다음 테이블 + while (pos < normalized_sql.size() && + std::isspace(static_cast(normalized_sql[pos])) != 0) { + ++pos; + } + if (pos < normalized_sql.size() && normalized_sql[pos] == ',') { + ++pos; + continue; + } + break; } - } catch (const std::regex_error& e) { - spdlog::warn("sql_parser: regex error for keyword '{}': {}", keyword, e.what()); + + search_pos = kw_pos + keyword.size(); } } // 정규화된 SQL에서 "WHERE" 단어 포함 여부 확인 // 단어 경계 적용: ELSEWHERE 같은 단어에서 오탐 방지 bool has_where_keyword(const std::string& normalized_sql) { - try { - const std::regex where_re("\\bWHERE\\b", std::regex_constants::ECMAScript); - return std::regex_search(normalized_sql, where_re); - } catch (const std::regex_error&) { - return false; + static constexpr std::string_view k_where = "WHERE"; + std::size_t pos = 0; + while ((pos = normalized_sql.find(k_where, pos)) != std::string::npos) { + const bool valid_start = (pos == 0) || !is_word_char(normalized_sql[pos - 1]); + const bool valid_end = (pos + k_where.size() >= normalized_sql.size()) || + !is_word_char(normalized_sql[pos + k_where.size()]); + if (valid_start && valid_end) { + return true; + } + ++pos; } + return false; } // --------------------------------------------------------------------------- diff --git a/src/proxy/session.cpp b/src/proxy/session.cpp index 309f2ed..8b4d044 100644 --- a/src/proxy/session.cpp +++ b/src/proxy/session.cpp @@ -90,54 +90,290 @@ Session::Session(std::uint64_t session_id, // --------------------------------------------------------------------------- namespace { -auto read_one_packet(AsyncStream& stream) - -> boost::asio::awaitable> { - std::array header{}; - boost::system::error_code ec; +constexpr std::size_t kMysqlMaxPayloadLen = 0x00FFFFFFU; +constexpr std::size_t kMysqlMaxPacketSize = kMysqlMaxPayloadLen + 4U; +constexpr std::size_t kRetainedBufferCeiling = 262144U; - co_await boost::asio::async_read(stream, - boost::asio::buffer(header), - boost::asio::redirect_error(boost::asio::use_awaitable, ec)); +// --------------------------------------------------------------------------- +// RelayBuffer +// 서버→클라이언트 릴레이 전용 버퍼. +// +// 읽기: async_read_some으로 큰 청크를 읽어 내부 링 버퍼에 누적. +// PacketView는 다음 read_packet / ensure_available 호출 전까지만 유효. +// → payload 검사 후 enqueue를 반드시 같은 코루틴 프레임에서 수행할 것. +// +// 쓰기: enqueue()로 원시 바이트를 출력 버퍼에 모아 flush()로 한 번에 전송. +// --------------------------------------------------------------------------- +class RelayBuffer { +public: + static constexpr std::size_t kInitBufSize = 65536UL; // 64 KB + static constexpr std::size_t kFlushThreshold = 65536UL; // 64 KB + + // PacketView: rbuf_ 내부를 참조하는 제로카피 뷰. + // ensure_available() 호출(= 다음 read_packet)까지만 유효. + struct PacketView { + std::span raw; // header(4) + payload + std::span payload; // header 이후 + std::uint8_t sequence_id{}; + }; - if (ec) { - co_return std::unexpected(ParseError{.code = ParseErrorCode::kMalformedPacket, - .message = "failed to read packet header", - .context = ec.message()}); + explicit RelayBuffer(AsyncStream& stream) + : read_stream_{&stream}, rbuf_(kInitBufSize), wbuf_{} { + wbuf_.reserve(kInitBufSize); } - const std::uint32_t payload_len = static_cast(header[0]) | - (static_cast(header[1]) << 8U) | - (static_cast(header[2]) << 16U); + // 서버에서 패킷 1개 읽기 (버퍼에서 제로카피) + auto read_packet() -> boost::asio::awaitable> { + // 헤더 4바이트 확보 + if (!co_await ensure_available(4)) { + co_return std::unexpected(ParseError{.code = ParseErrorCode::kMalformedPacket, + .message = "failed to read packet header", + .context = "eof or read error"}); + } + + // payload 길이 파싱 (3바이트 LE) + const std::uint32_t payload_len = static_cast(rbuf_[rpos_]) | + (static_cast(rbuf_[rpos_ + 1]) << 8U) | + (static_cast(rbuf_[rpos_ + 2]) << 16U); + + const std::size_t total = 4 + payload_len; + if (total > kMysqlMaxPacketSize) { + co_return std::unexpected(ParseError{.code = ParseErrorCode::kMalformedPacket, + .message = "packet exceeds MySQL max size", + .context = std::format("size={}", total)}); + } - std::vector buf(4 + payload_len); - buf[0] = header[0]; - buf[1] = header[1]; - buf[2] = header[2]; - buf[3] = header[3]; + // 전체 패킷 확보 + if (!co_await ensure_available(total)) { + co_return std::unexpected(ParseError{.code = ParseErrorCode::kMalformedPacket, + .message = "failed to read packet payload", + .context = "eof or read error"}); + } - if (payload_len > 0) { - co_await boost::asio::async_read( - stream, - boost::asio::buffer(buf.data() + 4, payload_len), + // 제로카피 뷰 반환 — 다음 ensure_available 이전까지만 유효 + PacketView view{ + .raw = std::span(rbuf_.data() + rpos_, total), + .payload = std::span(rbuf_.data() + rpos_ + 4, payload_len), + .sequence_id = rbuf_[rpos_ + 3], + }; + rpos_ += total; + + co_return view; + } + + // 원시 바이트를 출력 버퍼에 추가 (복사 1회, 직렬화 불필요) + void enqueue(std::span data) { + wbuf_.insert(wbuf_.end(), data.begin(), data.end()); + } + + // 출력 버퍼가 flush 임계값을 초과하는지 확인 + [[nodiscard]] bool should_flush() const noexcept { return wbuf_.size() >= kFlushThreshold; } + + // 출력 버퍼를 대상 스트림으로 flush + auto flush(AsyncStream& dest) -> boost::asio::awaitable> { + if (wbuf_.empty()) { + co_return std::expected{}; + } + + boost::system::error_code ec; + co_await boost::asio::async_write( + dest, + boost::asio::buffer(wbuf_), boost::asio::redirect_error(boost::asio::use_awaitable, ec)); + wbuf_.clear(); + shrink_write_buffer_if_idle(); + if (ec) { + co_return std::unexpected(ParseError{.code = ParseErrorCode::kInternalError, + .message = "failed to flush relay buffer", + .context = ec.message()}); + } + co_return std::expected{}; + } + +private: + // 버퍼에 n바이트 이상이 연속으로 확보될 때까지 읽기를 반복한다. + auto ensure_available(std::size_t n) -> boost::asio::awaitable { + if (n > kMysqlMaxPacketSize) { + co_return false; + } + + while (rend_ - rpos_ < n) { + // 컴팩션: 소비된 앞쪽 공간을 회수한다. + if (rpos_ > 0) { + const std::size_t avail = rend_ - rpos_; + std::memmove(rbuf_.data(), rbuf_.data() + rpos_, avail); + rend_ = avail; + rpos_ = 0; + } + shrink_read_buffer_if_idle(); + + // 남은 공간이 부족하면 버퍼 확장 + const std::size_t need = n - (rend_ - rpos_); + if (rbuf_.size() - rend_ < need) { + const std::size_t grow_to = std::max(rbuf_.size() * 2, rend_ + n); + if (grow_to > kMysqlMaxPacketSize) { + co_return false; + } + rbuf_.resize(grow_to); + } + + boost::system::error_code ec; + const auto bytes = co_await read_stream_->async_read_some( + boost::asio::buffer(rbuf_.data() + rend_, rbuf_.size() - rend_), + boost::asio::redirect_error(boost::asio::use_awaitable, ec)); + + if (ec || bytes == 0) { + co_return false; + } + rend_ += bytes; + } + co_return true; + } + + void shrink_read_buffer_if_idle() { + if (rpos_ != rend_ || rbuf_.size() <= kRetainedBufferCeiling) { + return; + } + + rbuf_ = std::vector(kInitBufSize); + rpos_ = 0; + rend_ = 0; + } + + void shrink_write_buffer_if_idle() { + if (!wbuf_.empty() || wbuf_.capacity() <= kRetainedBufferCeiling) { + return; + } + + std::vector shrunk; + shrunk.reserve(kInitBufSize); + wbuf_.swap(shrunk); + } + + AsyncStream* read_stream_; + std::vector rbuf_; + std::size_t rpos_{0}; // 읽기 시작 위치 + std::size_t rend_{0}; // 유효 데이터 끝 위치 + + std::vector wbuf_; +}; + +// --------------------------------------------------------------------------- +// ClientReadBuffer: 클라이언트 스트림에서 패킷을 읽기 위한 버퍼. +// read_one_packet의 2회 async_read(header + payload)를 1회 async_read_some로 줄인다. +// --------------------------------------------------------------------------- +class ClientReadBuffer { +public: + static constexpr std::size_t kBufSize = 16384UL; // 16 KB (COM_QUERY 패킷은 보통 < 1KB) + + explicit ClientReadBuffer(AsyncStream& stream) : stream_{&stream}, buf_(kBufSize) {} + + auto read_packet() -> boost::asio::awaitable> { + // 헤더 4바이트 확보 + if (!co_await ensure_available(4)) { + co_return std::unexpected(ParseError{.code = ParseErrorCode::kMalformedPacket, + .message = "failed to read packet header", + .context = "eof or read error"}); + } + + const std::uint32_t payload_len = static_cast(buf_[pos_]) | + (static_cast(buf_[pos_ + 1]) << 8U) | + (static_cast(buf_[pos_ + 2]) << 16U); + + const std::size_t total = 4 + payload_len; + if (total > kMysqlMaxPacketSize) { + co_return std::unexpected(ParseError{.code = ParseErrorCode::kMalformedPacket, + .message = "packet exceeds MySQL max size", + .context = std::format("size={}", total)}); + } + + if (!co_await ensure_available(total)) { co_return std::unexpected(ParseError{.code = ParseErrorCode::kMalformedPacket, .message = "failed to read packet payload", - .context = ec.message()}); + .context = "eof or read error"}); } + + auto result = MysqlPacket::parse(std::span(buf_.data() + pos_, total)); + pos_ += total; + + co_return result; } - co_return MysqlPacket::parse(std::span{buf}); -} +private: + auto ensure_available(std::size_t n) -> boost::asio::awaitable { + if (n > kMysqlMaxPacketSize) { + co_return false; + } + + while (end_ - pos_ < n) { + // 컴팩션 + if (pos_ > 0) { + const std::size_t avail = end_ - pos_; + if (avail > 0) { + std::memmove(buf_.data(), buf_.data() + pos_, avail); + } + end_ = avail; + pos_ = 0; + } + shrink_if_idle(); + if (buf_.size() - end_ < n - (end_ - pos_)) { + const std::size_t grow_to = std::max(buf_.size() * 2, end_ + n); + if (grow_to > kMysqlMaxPacketSize) { + co_return false; + } + buf_.resize(grow_to); + } + + boost::system::error_code ec; + const auto bytes = co_await stream_->async_read_some( + boost::asio::buffer(buf_.data() + end_, buf_.size() - end_), + boost::asio::redirect_error(boost::asio::use_awaitable, ec)); + + if (ec || bytes == 0) { + co_return false; + } + end_ += bytes; + } + co_return true; + } + + void shrink_if_idle() { + if (pos_ != end_ || buf_.size() <= kRetainedBufferCeiling) { + return; + } + + buf_ = std::vector(kBufSize); + pos_ = 0; + end_ = 0; + } + + AsyncStream* stream_; + std::vector buf_; + std::size_t pos_{0}; + std::size_t end_{0}; +}; auto write_packet_raw(AsyncStream& stream, const MysqlPacket& pkt) -> boost::asio::awaitable> { - const auto bytes = pkt.serialize(); - boost::system::error_code ec; + // scatter-gather I/O: 헤더를 스택에 구성하고 payload 와 함께 한 번에 전송 + // → pkt.serialize() 의 벡터 할당+복사 제거 + const auto len = pkt.payload_length(); + std::array header{ + static_cast(len & 0xFFU), + static_cast((len >> 8U) & 0xFFU), + static_cast((len >> 16U) & 0xFFU), + pkt.sequence_id(), + }; + + const std::array bufs{ + boost::asio::buffer(header), + boost::asio::buffer(pkt.payload().data(), pkt.payload().size()), + }; - co_await boost::asio::async_write(stream, - boost::asio::buffer(bytes), - boost::asio::redirect_error(boost::asio::use_awaitable, ec)); + boost::system::error_code ec; + co_await boost::asio::async_write( + stream, bufs, boost::asio::redirect_error(boost::asio::use_awaitable, ec)); if (ec) { co_return std::unexpected(ParseError{.code = ParseErrorCode::kInternalError, @@ -252,52 +488,55 @@ auto is_metadata_terminator_packet(std::span payload) -> boo is_resultset_final_ok_packet(payload); } -auto relay_stmt_prepare_section(AsyncStream& server_stream, - AsyncStream& client_stream, - std::uint16_t count, - std::uint64_t session_id) +auto relay_stmt_prepare_section(RelayBuffer& relay, std::uint16_t count, std::uint64_t session_id) -> boost::asio::awaitable> { for (std::uint16_t i = 0; i < count; ++i) { - auto def_pkt_result = co_await read_one_packet(server_stream); - if (!def_pkt_result) { - co_return std::unexpected(def_pkt_result.error()); - } - - auto wr = co_await write_packet_raw(client_stream, *def_pkt_result); - if (!wr) { - co_return std::unexpected(wr.error()); + auto def_result = co_await relay.read_packet(); + if (!def_result) { + co_return std::unexpected(def_result.error()); } + relay.enqueue(def_result->raw); } - auto term_pkt_result = co_await read_one_packet(server_stream); - if (!term_pkt_result) { - co_return std::unexpected(term_pkt_result.error()); + // terminator packet + auto term_result = co_await relay.read_packet(); + if (!term_result) { + co_return std::unexpected(term_result.error()); } - auto wr = co_await write_packet_raw(client_stream, *term_pkt_result); - if (!wr) { - co_return std::unexpected(wr.error()); - } + // payload 검사는 enqueue 전에, span이 아직 유효한 상태에서 수행 + const auto payload = term_result->payload; + const bool valid_terminator = is_metadata_terminator_packet(payload); + const std::uint8_t first_byte_log = payload.empty() ? 0U : payload[0]; + const std::size_t payload_len_log = payload.size(); + + relay.enqueue(term_result->raw); - const auto payload = term_pkt_result->payload(); - if (!is_metadata_terminator_packet(payload)) { + if (!valid_terminator) { spdlog::warn("[session {}] unexpected COM_STMT_PREPARE terminator: 0x{:02x} (len={})", session_id, - payload.empty() ? 0U : static_cast(payload[0]), - payload.size()); + static_cast(first_byte_log), + payload_len_log); } co_return std::expected{}; } -} // namespace - // --------------------------------------------------------------------------- -// relay_server_response -// MySQL 서버 응답(OK / ERR / Result Set)이 완료될 때까지 읽어 클라이언트에 릴레이. +// relay_server_response_buffered +// MySQL 서버 응답(OK / ERR / Result Set)이 완료될 때까지 RelayBuffer로 읽어 +// 클라이언트에 배치 릴레이한다. +// +// 제로카피 원칙: +// - 서버에서 읽은 원시 바이트는 MysqlPacket 파싱/직렬화 없이 그대로 전달. +// - PacketView.payload 참조는 enqueue 전에 필요한 값을 지역 변수로 복사한다 +// (ensure_available이 compact/realloc 시 span이 무효화되므로). // --------------------------------------------------------------------------- -auto Session::relay_server_response(CommandType request_type, - [[maybe_unused]] std::uint8_t request_seq_id) +auto relay_server_response_buffered(RelayBuffer& relay, + AsyncStream& client_stream, + CommandType request_type, + [[maybe_unused]] std::uint8_t request_seq_id, + std::uint64_t session_id) -> boost::asio::awaitable> { enum class ResponseState { // NOLINT(performance-enum-size) kFirst, // 첫 패킷 분석 중 @@ -306,135 +545,158 @@ auto Session::relay_server_response(CommandType request_type, kDone, // 응답 완료 }; - // 첫 패킷으로 응답 유형 판별 - auto first_pkt_result = co_await read_one_packet(server_stream_); - if (!first_pkt_result) { - co_return std::unexpected(first_pkt_result.error()); - } - - const MysqlPacket& first_pkt = *first_pkt_result; - const auto first_payload = first_pkt.payload(); - - // 첫 패킷을 클라이언트에 전달 - auto wr = co_await write_packet_raw(client_stream_, first_pkt); - if (!wr) { - co_return std::unexpected(wr.error()); + // --- 첫 패킷 --- + auto first_result = co_await relay.read_packet(); + if (!first_result) { + co_return std::unexpected(first_result.error()); } - bool has_first_byte = false; + // payload 검사 전에 필요한 값을 지역 변수로 추출 (span 유효 구간 내) + const std::uint8_t first_seq_id = first_result->sequence_id; + const bool first_payload_empty = first_result->payload.empty(); std::uint8_t first_byte = 0; - for (const auto byte : first_payload) { - first_byte = byte; - has_first_byte = true; - break; + std::size_t first_payload_size = 0; + std::uint16_t num_columns = 0; + std::uint16_t num_params = 0; + + if (!first_payload_empty) { + first_byte = first_result->payload[0]; + first_payload_size = first_result->payload.size(); + + // COM_STMT_PREPARE OK 파싱 (payload[5..8]) — enqueue 전에 수행 + if (first_byte == 0x00 && request_type == CommandType::kComStmtPrepare && + first_payload_size >= 12) { + num_columns = static_cast(first_result->payload[5]) | + (static_cast(first_result->payload[6]) << 8U); + num_params = static_cast(first_result->payload[7]) | + (static_cast(first_result->payload[8]) << 8U); + } } - if (!has_first_byte) { - co_return std::expected{}; + + relay.enqueue(first_result->raw); // enqueue 후에는 first_result->payload 참조 금지 + + if (first_payload_empty) { + co_return co_await relay.flush(client_stream); } - // ERR 패킷 (0xFF) → 즉시 완료 + // ERR (0xFF) → 즉시 flush if (first_byte == 0xFF) { - co_return std::expected{}; + co_return co_await relay.flush(client_stream); } - // OK 패킷 (0x00) + // OK (0x00) if (first_byte == 0x00) { if (request_type == CommandType::kComStmtPrepare) { - if (first_payload.size() < 12) { + if (first_payload_size < 12) { spdlog::warn("[session {}] short COM_STMT_PREPARE OK payload: {} bytes", - session_id_, - first_payload.size()); - co_return std::expected{}; + session_id, + first_payload_size); + co_return co_await relay.flush(client_stream); } - const std::uint16_t num_columns = static_cast(first_payload[5]) | - (static_cast(first_payload[6]) << 8U); - const std::uint16_t num_params = static_cast(first_payload[7]) | - (static_cast(first_payload[8]) << 8U); - if (num_params > 0) { - auto params_result = co_await relay_stmt_prepare_section( - server_stream_, client_stream_, num_params, session_id_); - if (!params_result) { - co_return std::unexpected(params_result.error()); + auto r = co_await relay_stmt_prepare_section(relay, num_params, session_id); + if (!r) { + co_return std::unexpected(r.error()); } } - if (num_columns > 0) { - auto columns_result = co_await relay_stmt_prepare_section( - server_stream_, client_stream_, num_columns, session_id_); - if (!columns_result) { - co_return std::unexpected(columns_result.error()); + auto r = co_await relay_stmt_prepare_section(relay, num_columns, session_id); + if (!r) { + co_return std::unexpected(r.error()); } } } - co_return std::expected{}; + co_return co_await relay.flush(client_stream); } - // EOF 패킷 (0xFE, payload.size() < 9) → 즉시 완료 (비정상) - if (first_byte == 0xFE && first_payload.size() < 9) { - co_return std::expected{}; + // EOF (0xFE, size < 9) → 즉시 flush (비정상) + if (first_byte == 0xFE && first_payload_size < 9) { + co_return co_await relay.flush(client_stream); } - // LOCAL_INFILE 요청 (0xFB) + // LOCAL_INFILE (0xFB) if (first_byte == 0xFB) { spdlog::warn("[session {}] unsupported LOCAL_INFILE response (0xFB) from server", - session_id_); + session_id); + // 이미 enqueue된 패킷을 flush 후 에러 반환 + auto f = co_await relay.flush(client_stream); + if (!f) { + co_return std::unexpected(f.error()); + } co_return std::unexpected(ParseError{.code = ParseErrorCode::kUnsupportedCommand, .message = "LOCAL_INFILE response is not supported", .context = "server response first byte = 0xFB"}); } - // Result Set: 첫 바이트가 column count (0x01~0xFC) + // Result Set: 첫 바이트 = column count (0x01~0xFC) if (first_byte < 0x01 || first_byte > 0xFC) { spdlog::warn( - "[session {}] unexpected first byte in response: 0x{:02x}", session_id_, first_byte); - co_return std::expected{}; + "[session {}] unexpected first byte in response: 0x{:02x}", session_id, first_byte); + co_return co_await relay.flush(client_stream); } const std::uint8_t column_count = first_byte; std::uint8_t column_defs_read = 0; ResponseState state = ResponseState::kColumnDefs; - std::uint8_t prev_seq_id = first_pkt.sequence_id(); + std::uint8_t prev_seq_id = first_seq_id; while (state != ResponseState::kDone) { - auto pkt_result = co_await read_one_packet(server_stream_); + auto pkt_result = co_await relay.read_packet(); if (!pkt_result) { co_return std::unexpected(pkt_result.error()); } - const MysqlPacket& pkt = *pkt_result; - const auto payload = pkt.payload(); + // PacketView 정보 추출 — enqueue 전에 수행 (span 유효 구간) + const std::uint8_t pkt_seq_id = pkt_result->sequence_id; + const bool pkt_payload_empty = pkt_result->payload.empty(); + std::uint8_t byte0 = 0; + std::size_t pkt_payload_size = 0; + bool is_row_or_coldef = false; - auto w = co_await write_packet_raw(client_stream_, pkt); - if (!w) { - co_return std::unexpected(w.error()); + if (!pkt_payload_empty) { + byte0 = pkt_result->payload[0]; + pkt_payload_size = pkt_result->payload.size(); } - if (payload.empty()) { - break; + // kRows 상태의 최종 OK 판별 — payload span 유효 구간 내에서만 가능 + if (!pkt_payload_empty && state == ResponseState::kRows && byte0 == 0x00) { + is_row_or_coldef = is_text_row_packet(pkt_result->payload, column_count) || + !is_resultset_final_ok_packet(pkt_result->payload); + } + + relay.enqueue(pkt_result->raw); // enqueue 후 pkt_result->payload 참조 금지 + + // 큰 result set 중간 flush + if (relay.should_flush()) { + auto f = co_await relay.flush(client_stream); + if (!f) { + co_return std::unexpected(f.error()); + } } - const std::uint8_t byte0 = payload[0]; + if (pkt_payload_empty) { + break; + } if (byte0 == 0xFF) { state = ResponseState::kDone; continue; } - if (pkt.sequence_id() < prev_seq_id && prev_seq_id != 0xFF) { + if (pkt_seq_id < prev_seq_id && prev_seq_id != 0xFF) { spdlog::warn("[session {}] seq_id reversed ({} -> {}), stopping relay", - session_id_, + session_id, prev_seq_id, - pkt.sequence_id()); + pkt_seq_id); state = ResponseState::kDone; continue; } - prev_seq_id = pkt.sequence_id(); + prev_seq_id = pkt_seq_id; switch (state) { - case ResponseState::kColumnDefs: { - if (byte0 == 0xFE && payload.size() < 9) { + case ResponseState::kColumnDefs: + if (byte0 == 0xFE && pkt_payload_size < 9) { state = ResponseState::kRows; } else if (byte0 == 0xFF) { state = ResponseState::kDone; @@ -442,21 +704,19 @@ auto Session::relay_server_response(CommandType request_type, ++column_defs_read; if (column_defs_read > column_count + 1) { spdlog::warn("[session {}] too many column definitions: {} > {}", - session_id_, + session_id, column_defs_read, column_count); state = ResponseState::kDone; } } break; - } case ResponseState::kRows: { - // EOF/ERR packet, or binary-protocol final OK packet — end of result set - const bool eof_or_err = (byte0 == 0xFE && payload.size() < 9) || byte0 == 0xFF; - const bool final_ok = request_type == CommandType::kComQuery && byte0 == 0x00 && - !is_text_row_packet(payload, column_count) && - is_resultset_final_ok_packet(payload); + const bool eof_or_err = (byte0 == 0xFE && pkt_payload_size < 9) || byte0 == 0xFF; + // is_row_or_coldef: true이면 일반 row/coldef → 아직 결과 진행 중 + const bool final_ok = + request_type == CommandType::kComQuery && byte0 == 0x00 && !is_row_or_coldef; if (eof_or_err || final_ok) { state = ResponseState::kDone; } @@ -469,9 +729,11 @@ auto Session::relay_server_response(CommandType request_type, } } - co_return std::expected{}; + co_return co_await relay.flush(client_stream); } +} // namespace + // --------------------------------------------------------------------------- // Session::run // --------------------------------------------------------------------------- @@ -716,13 +978,18 @@ auto Session::run() -> boost::asio::awaitable { // ----------------------------------------------------------------------- // 8. 커맨드 루프 + // RelayBuffer는 server_stream_ 위에서 서버 응답을 배치로 읽어 클라이언트에 전달. + // 세션 전체 수명 동안 재사용하여 재할당을 최소화한다. // ----------------------------------------------------------------------- + RelayBuffer server_relay(server_stream_); + ClientReadBuffer client_reader(client_stream_); + while (true) { if (closing_.load(std::memory_order_acquire)) { break; } - auto pkt_result = co_await read_one_packet(client_stream_); + auto pkt_result = co_await client_reader.read_packet(); if (!pkt_result) { const auto& err = pkt_result.error(); @@ -756,12 +1023,8 @@ auto Session::run() -> boost::asio::awaitable { // --------------------------------------------------------------- if (cmd.command_type == CommandType::kComQuit) { spdlog::debug("[session {}] COM_QUIT received", session_id_); - boost::system::error_code fwd_ec; - const auto quit_bytes = pkt.serialize(); - co_await boost::asio::async_write( - server_stream_, - boost::asio::buffer(quit_bytes), - boost::asio::redirect_error(boost::asio::use_awaitable, fwd_ec)); + auto fwd = co_await write_packet_raw(server_stream_, pkt); + (void)fwd; // QUIT 실패해도 세션 종료 break; } @@ -786,8 +1049,9 @@ auto Session::run() -> boost::asio::awaitable { } else { const ParsedQuery& parsed = *parse_result; - [[maybe_unused]] const auto inj_result = injection_detector_.check(cmd.query); - [[maybe_unused]] const auto proc_result = proc_detector_.detect(parsed); + // NOTE: injection_detector_.check() 및 proc_detector_.detect() 결과는 + // 현재 정책 엔진에서 사용하지 않으므로 호출을 제거하여 오버헤드 절감. + // 향후 정책에 통합 시 여기서 재활성화한다. policy_result = policy_->evaluate(parsed, ctx_); } @@ -849,8 +1113,8 @@ auto Session::run() -> boost::asio::awaitable { break; } - auto relay_result = - co_await relay_server_response(cmd.command_type, cmd.sequence_id); + auto relay_result = co_await relay_server_response_buffered( + server_relay, client_stream_, cmd.command_type, cmd.sequence_id, session_id_); if (!relay_result) { spdlog::warn("[session {}] relay_server_response failed: {}", session_id_, @@ -928,7 +1192,8 @@ auto Session::run() -> boost::asio::awaitable { break; } - auto relay_result = co_await relay_server_response(cmd.command_type, cmd.sequence_id); + auto relay_result = co_await relay_server_response_buffered( + server_relay, client_stream_, cmd.command_type, cmd.sequence_id, session_id_); if (!relay_result) { spdlog::warn("[session {}] failed to relay server response: {}", session_id_, diff --git a/src/proxy/session.hpp b/src/proxy/session.hpp index 6dd1845..ed49a9c 100644 --- a/src/proxy/session.hpp +++ b/src/proxy/session.hpp @@ -146,8 +146,5 @@ class Session : public std::enable_shared_from_this { // close() 중복 호출 방지용 atomic 플래그 std::atomic closing_{false}; - // relay_server_response 헬퍼 - // MySQL 서버 응답(Result Set / OK / ERR)이 완료될 때까지 읽어 클라이언트에 릴레이. - auto relay_server_response(CommandType request_type, std::uint8_t request_seq_id) - -> boost::asio::awaitable>; + }; diff --git a/tests/test_uds_server.cpp b/tests/test_uds_server.cpp index 92276a0..3cf1587 100644 --- a/tests/test_uds_server.cpp +++ b/tests/test_uds_server.cpp @@ -17,7 +17,7 @@ // - run() 전 stop() 호출 → 크래시/hang 없음 // // [테스트 패턴] -// - 각 테스트는 임시 소켓 경로(/tmp/test_uds__.sock)를 사용한다. +// - 각 테스트는 워크스페이스 내부 임시 소켓 경로를 사용한다. // - UdsServer 를 서버 전용 io_context 에서 백그라운드 스레드로 구동한다. // - 클라이언트는 별도 io_context 의 동기 소켓(sync connect/write/read) 사용. // @@ -33,7 +33,9 @@ // --------------------------------------------------------------------------- #include +#include #include +#include #include #include @@ -86,12 +88,49 @@ std::shared_ptr make_explain_policy_config() { return cfg; } +std::filesystem::path test_socket_dir() { + const auto dir = std::filesystem::current_path() / "test-uds"; + std::filesystem::create_directories(dir); + return dir; +} + +bool uds_bind_supported() { + static const bool supported = []() { + const auto probe_path = test_socket_dir() / "probe.sock"; + (void)std::filesystem::remove(probe_path); + + const int fd = ::socket(AF_UNIX, SOCK_STREAM, 0); + if (fd < 0) { + return false; + } + + sockaddr_un addr{}; // NOLINT(cppcoreguidelines-pro-type-member-init,hicpp-member-init) + addr.sun_family = AF_UNIX; + const auto path_str = probe_path.string(); + if (path_str.size() >= sizeof(addr.sun_path)) { + (void)::close(fd); + return false; + } + std::memcpy(addr.sun_path, path_str.c_str(), path_str.size() + 1); + + const bool ok = + (::bind(fd, + reinterpret_cast(&addr), + sizeof(addr)) == 0); // NOLINT(cppcoreguidelines-pro-type-reinterpret-cast) + + (void)::close(fd); + (void)std::filesystem::remove(probe_path); + return ok; + }(); + + return supported; +} + // 임시 소켓 경로 생성 (PID + 단조 카운터로 테스트 간 충돌 방지) std::filesystem::path temp_socket_path(const char* tag) { static std::atomic counter{0}; - return std::filesystem::path("/tmp") / - ("test_uds_" + std::to_string(::getpid()) + "_" + std::to_string(counter.fetch_add(1)) + - "_" + tag + ".sock"); + return test_socket_dir() / ("test_uds_" + std::to_string(::getpid()) + "_" + + std::to_string(counter.fetch_add(1)) + "_" + tag + ".sock"); } // encode_le4: uint32_t → 4바이트 little-endian 배열 @@ -206,6 +245,9 @@ struct UdsSyncClient { class UdsServerTest : public ::testing::Test { protected: void SetUp() override { + if (!uds_bind_supported()) { + GTEST_SKIP() << "Unix domain socket bind not permitted in this environment"; + } socket_path_ = temp_socket_path("srv"); stats_ = std::make_shared(); ioc_ = std::make_unique(); @@ -229,7 +271,9 @@ class UdsServerTest : public ::testing::Test { if (server_) { server_->stop(); } - ioc_->stop(); + if (ioc_) { + ioc_->stop(); + } if (server_thread_.joinable()) { server_thread_.join(); } @@ -472,8 +516,7 @@ TEST_F(UdsServerTest, CommandField_InjectedInsideStringValue_UsesTopLevelCommand const std::string resp = client.recv(); ASSERT_FALSE(resp.empty()) << "stats command must return a non-empty response"; - EXPECT_NE(resp.find(R"("ok":true)"), std::string::npos) - << "stats must succeed. Got: " << resp; + EXPECT_NE(resp.find(R"("ok":true)"), std::string::npos) << "stats must succeed. Got: " << resp; // stats 응답에는 "payload" 필드가 있어야 함 EXPECT_NE(resp.find(R"("payload")"), std::string::npos) << "stats response must contain 'payload' field. Got: " << resp; @@ -502,8 +545,7 @@ TEST_F(UdsServerTest, CommandField_InjectedInsideNestedObject_UsesTopLevelComman const std::string resp = client.recv(); ASSERT_FALSE(resp.empty()) << "stats command must return a non-empty response"; - EXPECT_NE(resp.find(R"("ok":true)"), std::string::npos) - << "stats must succeed. Got: " << resp; + EXPECT_NE(resp.find(R"("ok":true)"), std::string::npos) << "stats must succeed. Got: " << resp; // stats 응답에는 "payload" 필드가 있어야 함 EXPECT_NE(resp.find(R"("payload")"), std::string::npos) << "stats response must contain 'payload' field. Got: " << resp; @@ -521,6 +563,9 @@ TEST_F(UdsServerTest, CommandField_InjectedInsideNestedObject_UsesTopLevelComman class UdsPolicyExplainTest : public ::testing::Test { protected: void SetUp() override { + if (!uds_bind_supported()) { + GTEST_SKIP() << "Unix domain socket bind not permitted in this environment"; + } socket_path_ = temp_socket_path("explain"); stats_ = std::make_shared(); ioc_ = std::make_unique(); @@ -548,7 +593,9 @@ class UdsPolicyExplainTest : public ::testing::Test { if (server_) { server_->stop(); } - ioc_->stop(); + if (ioc_) { + ioc_->stop(); + } if (server_thread_.joinable()) { server_thread_.join(); } @@ -863,6 +910,9 @@ TEST_F(UdsServerTest, Stats_ContainsMonitoredBlocks) { class UdsPolicyVersioningTest : public ::testing::Test { protected: void SetUp() override { + if (!uds_bind_supported()) { + GTEST_SKIP() << "Unix domain socket bind not permitted in this environment"; + } socket_path_ = temp_socket_path("versioning"); stats_ = std::make_shared(); ioc_ = std::make_unique(); @@ -906,7 +956,9 @@ class UdsPolicyVersioningTest : public ::testing::Test { if (server_) { server_->stop(); } - ioc_->stop(); + if (ioc_) { + ioc_->stop(); + } if (server_thread_.joinable()) { server_thread_.join(); } @@ -1136,7 +1188,8 @@ TEST_F(UdsPolicyVersioningTest, PolicyRollback_TargetVersionOutsidePayload_IsIgn // payload 문자열 리터럴에 포함된 \"target_version\":N 패턴이 // 실제 payload.target_version 값을 덮어쓰면 안 된다. // --------------------------------------------------------------------------- -TEST_F(UdsPolicyVersioningTest, PolicyRollback_TargetVersionStringLiteral_DoesNotOverridePayloadField) { +TEST_F(UdsPolicyVersioningTest, + PolicyRollback_TargetVersionStringLiteral_DoesNotOverridePayloadField) { const auto tmp_policy_path = version_dir_ / "rollback_literal_policy.yaml"; { std::ofstream f(tmp_policy_path); @@ -1194,7 +1247,8 @@ TEST_F(UdsPolicyVersioningTest, PolicyRollback_TargetVersionStringLiteral_DoesNo // uint64 범위를 넘는 target_version 입력은 거부되어야 하며, // 현재 정책 버전이 변경되면 안 된다. // --------------------------------------------------------------------------- -TEST_F(UdsPolicyVersioningTest, PolicyRollback_TargetVersionOverflow_ReturnsErrorAndKeepsCurrentPolicy) { +TEST_F(UdsPolicyVersioningTest, + PolicyRollback_TargetVersionOverflow_ReturnsErrorAndKeepsCurrentPolicy) { const auto tmp_policy_path = version_dir_ / "rollback_overflow_policy.yaml"; { std::ofstream f(tmp_policy_path); @@ -1522,7 +1576,7 @@ TEST_F(UdsServerTest, SocketPermissions_0600) { start_server(); ASSERT_TRUE(wait_for_socket()); - struct stat st {}; + struct stat st{}; ASSERT_EQ(::stat(socket_path_.c_str(), &st), 0) << "stat() failed on socket file"; // 하위 9비트만 비교 (owner/group/other rwx) const auto perms = st.st_mode & 0777;