|
9 | 9 | Note on DNS resolution: |
10 | 10 | Some networks have issues with gRPC's default c-ares DNS resolver. |
11 | 11 | The adapter sets GRPC_DNS_RESOLVER=native if not already set. |
| 12 | +
|
| 13 | +Note on dialect support: |
| 14 | + Spanner supports two SQL dialects: GoogleSQL and PostgreSQL. |
| 15 | + The adapter detects the dialect on connect and adjusts identifier |
| 16 | + quoting accordingly (backticks for GoogleSQL, double quotes for PostgreSQL). |
| 17 | + PostgreSQL dialect support is experimental and untested. |
12 | 18 | """ |
13 | 19 |
|
14 | 20 | from __future__ import annotations |
|
27 | 33 | if TYPE_CHECKING: |
28 | 34 | from sqlit.domains.connections.domain.config import ConnectionConfig |
29 | 35 |
|
| 36 | +# Dialect constants |
| 37 | +DIALECT_GOOGLESQL = "GOOGLE_STANDARD_SQL" |
| 38 | +DIALECT_POSTGRESQL = "POSTGRESQL" |
| 39 | + |
30 | 40 |
|
31 | 41 | class SpannerAdapter(CursorBasedAdapter): |
32 | 42 | """Adapter for Google Cloud Spanner.""" |
@@ -134,8 +144,37 @@ def connect(self, config: ConnectionConfig) -> Any: |
134 | 144 | # Store config for later use |
135 | 145 | conn._sqlit_spanner_database = database |
136 | 146 |
|
| 147 | + # Detect and store the database dialect (GoogleSQL or PostgreSQL) |
| 148 | + conn._sqlit_spanner_dialect = self._detect_dialect(conn) |
| 149 | + |
137 | 150 | return conn |
138 | 151 |
|
| 152 | + def _detect_dialect(self, conn: Any) -> str: |
| 153 | + """Detect the database dialect (GoogleSQL or PostgreSQL). |
| 154 | +
|
| 155 | + Queries INFORMATION_SCHEMA.DATABASE_OPTIONS to determine which SQL |
| 156 | + dialect the database uses. This affects identifier quoting. |
| 157 | + """ |
| 158 | + query = """ |
| 159 | + SELECT OPTION_VALUE |
| 160 | + FROM INFORMATION_SCHEMA.DATABASE_OPTIONS |
| 161 | + WHERE OPTION_NAME = 'database_dialect' |
| 162 | + """ |
| 163 | + rows = self._execute_readonly(conn, query) |
| 164 | + if rows and rows[0]: |
| 165 | + return str(rows[0][0]) |
| 166 | + # If we can't detect, raise an error (no fallback) |
| 167 | + msg = "Could not detect Spanner database dialect" |
| 168 | + raise ValueError(msg) |
| 169 | + |
| 170 | + def _get_dialect(self, conn: Any) -> str: |
| 171 | + """Get the cached dialect for a connection.""" |
| 172 | + dialect = getattr(conn, "_sqlit_spanner_dialect", None) |
| 173 | + if dialect is None: |
| 174 | + msg = "Spanner dialect not detected on connection" |
| 175 | + raise ValueError(msg) |
| 176 | + return dialect |
| 177 | + |
139 | 178 | def get_databases(self, conn: Any) -> list[str]: |
140 | 179 | """Return the connected database (Spanner is single-database per connection).""" |
141 | 180 | database = getattr(conn, "_sqlit_spanner_database", None) |
@@ -292,12 +331,48 @@ def get_sequences(self, conn: Any, database: str | None = None) -> list[Sequence |
292 | 331 | """Spanner doesn't support traditional sequences.""" |
293 | 332 | return [] |
294 | 333 |
|
| 334 | + def _quote_identifier_for_dialect(self, dialect: str, name: str) -> str: |
| 335 | + """Quote an identifier based on dialect. |
| 336 | +
|
| 337 | + - GoogleSQL: `identifier` (backticks) |
| 338 | + - PostgreSQL: "identifier" (double quotes) |
| 339 | + """ |
| 340 | + if dialect == DIALECT_POSTGRESQL: |
| 341 | + # PostgreSQL dialect uses double quotes |
| 342 | + escaped = name.replace('"', '""') |
| 343 | + return f'"{escaped}"' |
| 344 | + # GoogleSQL uses backticks |
| 345 | + escaped = name.replace("`", "\\`") |
| 346 | + return f"`{escaped}`" |
| 347 | + |
| 348 | + def _quote_identifier_for_conn(self, conn: Any, name: str) -> str: |
| 349 | + """Quote an identifier using the connection's dialect.""" |
| 350 | + dialect = self._get_dialect(conn) |
| 351 | + return self._quote_identifier_for_dialect(dialect, name) |
| 352 | + |
295 | 353 | def quote_identifier(self, name: str) -> str: |
296 | | - """Quote an identifier for GoogleSQL (backticks).""" |
297 | | - return f"`{name}`" |
| 354 | + """Quote an identifier for GoogleSQL (backticks). |
| 355 | +
|
| 356 | + Note: This method doesn't have access to the connection, so it always |
| 357 | + uses GoogleSQL syntax. For connection-aware quoting, use |
| 358 | + _quote_identifier_for_conn() instead. |
| 359 | + """ |
| 360 | + return self._quote_identifier_for_dialect(DIALECT_GOOGLESQL, name) |
298 | 361 |
|
299 | 362 | def build_select_query( |
300 | 363 | self, table: str, limit: int, database: str | None = None, schema: str | None = None |
301 | 364 | ) -> str: |
302 | | - """Build SELECT query with LIMIT.""" |
303 | | - return f"SELECT * FROM `{table}` LIMIT {limit}" |
| 365 | + """Build SELECT query with LIMIT. |
| 366 | +
|
| 367 | + Note: This method doesn't have access to the connection, so it always |
| 368 | + uses GoogleSQL syntax for identifier quoting. |
| 369 | + """ |
| 370 | + quoted = self._quote_identifier_for_dialect(DIALECT_GOOGLESQL, table) |
| 371 | + return f"SELECT * FROM {quoted} LIMIT {limit}" |
| 372 | + |
| 373 | + def build_select_query_for_conn( |
| 374 | + self, conn: Any, table: str, limit: int, database: str | None = None, schema: str | None = None |
| 375 | + ) -> str: |
| 376 | + """Build SELECT query with LIMIT using connection-aware quoting.""" |
| 377 | + quoted = self._quote_identifier_for_conn(conn, table) |
| 378 | + return f"SELECT * FROM {quoted} LIMIT {limit}" |
0 commit comments