Skip to content

Commit b089cd2

Browse files
frazanefloriankrb
andauthored
fix(grib-index): support querying float values (#520)
When indexing GRIB files using the `anemoi-datasets grib-index`, it might occur that the program fails due to duplicated messages even though they are not really duplicates. This occurs for instance with soil variables, where by default values for `level` are interpreted as integers which leads to values such as 0.01, 0.004, etc. to all be decoded as "0". To solve this, one needs to use the `key:type` syntax to decode a key in a specific type, e.g. `level:d`. However, this leads to another error because SQLite does not support column names that contain colons. As a simple fix, the proposed solution is to quote the columns. Extra: since I noticed the "grid definition" support was missing from this source (it's on other GRIB-based sources), I allowed myself to include this change as well, since it's small and should not impact the rest. --- By opening this pull request, I affirm that all authors agree to the [Contributor License Agreement.](https://github.com/ecmwf/codex/blob/main/Legal/contributor_license_agreement.md) --------- Co-authored-by: Florian Pinault <floriankrb@users.noreply.github.com>
1 parent 42117db commit b089cd2

1 file changed

Lines changed: 31 additions & 13 deletions

File tree

src/anemoi/datasets/create/sources/grib_index.py

Lines changed: 31 additions & 13 deletions
Original file line numberDiff line numberDiff line change
@@ -18,7 +18,9 @@
1818

1919
import earthkit.data as ekd
2020
import tqdm
21+
from anemoi.transform.fields import new_field_from_grid
2122
from anemoi.transform.flavour import RuleBasedFlavour
23+
from anemoi.transform.grids import grid_registry
2224
from cachetools import LRUCache
2325
from earthkit.data.indexing.fieldlist import FieldArray
2426

@@ -102,6 +104,10 @@ def __init__(
102104
self.warnings = {}
103105
self.cache = {}
104106

107+
def _quote_column(self, column: str) -> str:
108+
"""Quote a column name for use in SQL queries."""
109+
return f'"{column}"'
110+
105111
def _create_tables(self) -> None:
106112
"""Create the necessary tables in the database."""
107113
assert self.update
@@ -123,7 +129,7 @@ def _create_tables(self) -> None:
123129
_path_id INTEGER not null,
124130
_offset INTEGER not null,
125131
_length INTEGER not null,
126-
{', '.join(f"{key} TEXT not null default ''" for key in columns)},
132+
{', '.join(f"{self._quote_column(key)} TEXT not null default ''" for key in columns)},
127133
FOREIGN KEY(_path_id) REFERENCES paths(id))
128134
""") # ,
129135

@@ -134,13 +140,13 @@ def _create_tables(self) -> None:
134140

135141
self.cursor.execute(f"""
136142
CREATE UNIQUE INDEX IF NOT EXISTS idx_grib_index_all_keys
137-
ON grib_index ({', '.join(columns)})
143+
ON grib_index ({', '.join(self._quote_column(col) for col in columns)})
138144
""")
139145

140146
for key in columns:
141147
self.cursor.execute(f"""
142-
CREATE INDEX IF NOT EXISTS idx_grib_index_{key}
143-
ON grib_index ({key})
148+
CREATE INDEX IF NOT EXISTS idx_grib_index_{key.replace(':', '_')}
149+
ON grib_index ({self._quote_column(key)})
144150
""")
145151

146152
self._commit()
@@ -195,7 +201,7 @@ def _add_grib(self, **kwargs: Any) -> None:
195201

196202
self.cursor.execute(
197203
f"""
198-
INSERT INTO grib_index ({', '.join(kwargs.keys())})
204+
INSERT INTO grib_index ({', '.join(self._quote_column(k) for k in kwargs.keys())})
199205
VALUES ({', '.join('?' for _ in kwargs)})
200206
""",
201207
tuple(kwargs.values()),
@@ -208,7 +214,8 @@ def _add_grib(self, **kwargs: Any) -> None:
208214
for n in ("_path_id", "_offset", "_length"):
209215
kwargs.pop(n)
210216
self.cursor.execute(
211-
"SELECT * FROM grib_index WHERE " + " AND ".join(f"{key} = ?" for key in kwargs.keys()),
217+
"SELECT * FROM grib_index WHERE "
218+
+ " AND ".join(f"{self._quote_column(key)} = ?" for key in kwargs.keys()),
212219
tuple(kwargs.values()),
213220
)
214221
existing_record = self.cursor.fetchone()
@@ -252,20 +259,22 @@ def _ensure_columns(self, columns: list[str]) -> None:
252259
self._columns = None
253260

254261
for column in new_columns:
255-
self.cursor.execute(f"ALTER TABLE grib_index ADD COLUMN {column} TEXT not null default ''")
262+
self.cursor.execute(
263+
f"ALTER TABLE grib_index ADD COLUMN {self._quote_column(column)} TEXT not null default ''"
264+
)
256265

257266
self.cursor.execute("""DROP INDEX IF EXISTS idx_grib_index_all_keys""")
258267
all_columns = self._all_columns()
259268

260269
self.cursor.execute(f"""
261270
CREATE UNIQUE INDEX IF NOT EXISTS idx_grib_index_all_keys
262-
ON grib_index ({', '.join(all_columns)})
271+
ON grib_index ({', '.join(self._quote_column(col) for col in all_columns)})
263272
""")
264273

265274
for key in all_columns:
266275
self.cursor.execute(f"""
267-
CREATE INDEX IF NOT EXISTS idx_grib_index_{key}
268-
ON grib_index ({key})
276+
CREATE INDEX IF NOT EXISTS idx_grib_index_{key.replace(':', '_')}
277+
ON grib_index ({self._quote_column(key)})
269278
""")
270279

271280
def add_grib_file(self, path: str) -> None:
@@ -301,6 +310,8 @@ def add_grib_file(self, path: str) -> None:
301310
self._unknown(path, field, i, param)
302311
self.warnings[param] = True
303312

313+
continue
314+
304315
self._ensure_columns(list(keys.keys()))
305316

306317
self._add_grib(
@@ -536,15 +547,14 @@ def retrieve(self, dates: list[Any], **kwargs: Any) -> Iterator[Any]:
536547
LOG.warning(f"Warning : {k} not in database columns, key discarded")
537548
continue
538549
if isinstance(v, list):
539-
query += f" AND {k} IN ({', '.join('?' for _ in v)})"
550+
query += f" AND {self._quote_column(k)} IN ({', '.join('?' for _ in v)})"
540551
params.extend([str(_) for _ in v])
541552
else:
542-
query += f" AND {k} = ?"
553+
query += f" AND {self._quote_column(k)} = ?"
543554
params.append(str(v))
544555

545556
print("SELECT (query)", query)
546557
print("SELECT (params)", params)
547-
548558
self.cursor.execute(query, params)
549559

550560
fetch = self.cursor.fetchall()
@@ -593,6 +603,11 @@ def _execute(
593603
FieldArray
594604
An array of retrieved GRIB fields.
595605
"""
606+
607+
grid_definition = kwargs.pop("grid_definition", None)
608+
if grid_definition:
609+
grid_definition = grid_registry.from_config(grid_definition)
610+
596611
index = GribIndex(indexdb)
597612

598613
if flavour is not None:
@@ -623,6 +638,9 @@ def _execute(
623638
field = flavour.apply(field)
624639
result.append(field)
625640

641+
if grid_definition is not None:
642+
result = [new_field_from_grid(field, grid_definition) for field in result]
643+
626644
return FieldArray(result)
627645

628646

0 commit comments

Comments
 (0)