
Commit 851c6f5

Fix for_table to use caching
1 parent 50dcbf2 commit 851c6f5

4 files changed: +126 −62 lines changed

bigquery_etl/dryrun.py

Lines changed: 89 additions & 41 deletions

@@ -18,6 +18,7 @@
 import pickle
 import random
 import re
+import shutil
 import sys
 import tempfile
 import time
@@ -198,6 +199,17 @@ def skipped_files(sql_dir=ConfigLoader.get("default", "sql_dir")) -> Set[str]:

         return skip_files

+    @staticmethod
+    def clear_cache():
+        """Clear dry run cache directory."""
+        cache_dir = Path(tempfile.gettempdir()) / "bigquery_etl_dryrun_cache"
+        if cache_dir.exists():
+            try:
+                shutil.rmtree(cache_dir)
+                print(f"Cleared dry run cache at {cache_dir}")
+            except OSError as e:
+                print(f"Warning: Failed to clear dry run cache: {e}")
+
     def skip(self):
         """Determine if dry run should be skipped."""
         return self.respect_skip and self.sqlfile in self.skipped_files(
@@ -241,41 +253,52 @@ def _get_cached_result(self, cache_key, ttl_seconds=None):
         if ttl_seconds is None:
             ttl_seconds = ConfigLoader.get("dry_run", "cache_ttl_seconds", fallback=900)

-        cache_dir = os.path.join(tempfile.gettempdir(), "bigquery_etl_dryrun_cache")
-        os.makedirs(cache_dir, exist_ok=True)
-        cache_file = os.path.join(cache_dir, f"dryrun_{cache_key}.pkl")
+        cache_dir = Path(tempfile.gettempdir()) / "bigquery_etl_dryrun_cache"
+        cache_dir.mkdir(parents=True, exist_ok=True)
+        cache_file = cache_dir / f"dryrun_{cache_key}.pkl"
+
+        try:
+            if not cache_file.exists():
+                return None

-        if os.path.exists(cache_file):
             # check if cache is expired
-            file_age = time.time() - os.path.getmtime(cache_file)
+            file_age = time.time() - cache_file.stat().st_mtime
             if file_age > ttl_seconds:
                 try:
-                    os.remove(cache_file)
+                    cache_file.unlink()
                 except OSError:
                     pass
                 return None

+            cached_data = pickle.loads(cache_file.read_bytes())
+            cache_age = time.time() - cache_file.stat().st_mtime
+            print(f"[DRYRUN CACHE HIT] {self.sqlfile} (age: {cache_age:.0f}s)")
+            return cached_data
+        except (pickle.PickleError, EOFError, OSError, FileNotFoundError) as e:
+            print(f"[DRYRUN CACHE] Failed to load cache: {e}")
             try:
-                with open(cache_file, "rb") as f:
-                    cached_data = pickle.load(f)
-                cache_age = time.time() - os.path.getmtime(cache_file)
-                print(f"[DRYRUN CACHE HIT] {self.sqlfile} (age: {cache_age:.0f}s)")
-                return cached_data
-            except (pickle.PickleError, EOFError, OSError) as e:
-                print(f"[DRYRUN CACHE] Failed to load cache: {e}")
-                return None
-
-        return None
+                if cache_file.exists():
+                    cache_file.unlink()
+            except OSError:
+                pass
+            return None

     def _save_cached_result(self, cache_key, result):
-        """Save dry run result to disk cache."""
-        cache_dir = os.path.join(tempfile.gettempdir(), "bigquery_etl_dryrun_cache")
-        os.makedirs(cache_dir, exist_ok=True)
-        cache_file = os.path.join(cache_dir, f"dryrun_{cache_key}.pkl")
+        """Save dry run result to disk cache using atomic write."""
+        cache_dir = Path(tempfile.gettempdir()) / "bigquery_etl_dryrun_cache"
+        cache_dir.mkdir(parents=True, exist_ok=True)
+        cache_file = cache_dir / f"dryrun_{cache_key}.pkl"

         try:
-            with open(cache_file, "wb") as f:
+            # write to temporary file first, then atomically rename
+            # this prevents race conditions where readers get partial files
+            temp_file = Path(str(cache_file) + f".tmp.{os.getpid()}")
+            with open(temp_file, "wb") as f:
                 pickle.dump(result, f)
+                f.flush()
+                os.fsync(f.fileno())  # Ensure data is written to disk
+
+            temp_file.replace(cache_file)

             # save table metadata separately if present
             if (
@@ -291,49 +314,73 @@ def _save_cached_result(self, cache_key, result):
             )
         except (pickle.PickleError, OSError) as e:
             print(f"[DRYRUN CACHE] Failed to save cache: {e}")
+            try:
+                temp_file = Path(str(cache_file) + f".tmp.{os.getpid()}")
+                if temp_file.exists():
+                    temp_file.unlink()
+            except OSError:
+                pass

     def _get_cached_table_metadata(self, table_identifier, ttl_seconds=None):
         """Load cached table metadata from disk based on table identifier."""
         if ttl_seconds is None:
             ttl_seconds = ConfigLoader.get("dry_run", "cache_ttl_seconds", fallback=900)

-        cache_dir = os.path.join(tempfile.gettempdir(), "bigquery_etl_dryrun_cache")
-        os.makedirs(cache_dir, exist_ok=True)
+        cache_dir = Path(tempfile.gettempdir()) / "bigquery_etl_dryrun_cache"
+        cache_dir.mkdir(parents=True, exist_ok=True)
         # table identifier as cache key
         table_cache_key = hashlib.sha256(table_identifier.encode()).hexdigest()
-        cache_file = os.path.join(cache_dir, f"table_metadata_{table_cache_key}.pkl")
+        cache_file = cache_dir / f"table_metadata_{table_cache_key}.pkl"
+
+        try:
+            if not cache_file.exists():
+                return None

-        if os.path.exists(cache_file):
             # check if cache is expired
-            file_age = time.time() - os.path.getmtime(cache_file)
+            file_age = time.time() - cache_file.stat().st_mtime

             if file_age > ttl_seconds:
                 try:
-                    os.remove(cache_file)
+                    cache_file.unlink()
                 except OSError:
                     pass
                 return None

+            cached_data = pickle.loads(cache_file.read_bytes())
+            return cached_data
+        except (pickle.PickleError, EOFError, OSError, FileNotFoundError) as e:
+            print(f"[TABLE METADATA] Failed to load cache for {table_identifier}: {e}")
             try:
-                with open(cache_file, "rb") as f:
-                    cached_data = pickle.load(f)
-                return cached_data
-            except (pickle.PickleError, EOFError, OSError):
-                return None
-        return None
+                if cache_file.exists():
+                    cache_file.unlink()
+            except OSError:
+                pass
+            return None

     def _save_cached_table_metadata(self, table_identifier, metadata):
-        """Save table metadata to disk cache."""
-        cache_dir = os.path.join(tempfile.gettempdir(), "bigquery_etl_dryrun_cache")
-        os.makedirs(cache_dir, exist_ok=True)
+        """Save table metadata to disk cache using atomic write."""
+        cache_dir = Path(tempfile.gettempdir()) / "bigquery_etl_dryrun_cache"
+        cache_dir.mkdir(parents=True, exist_ok=True)
         table_cache_key = hashlib.sha256(table_identifier.encode()).hexdigest()
-        cache_file = os.path.join(cache_dir, f"table_metadata_{table_cache_key}.pkl")
+        cache_file = cache_dir / f"table_metadata_{table_cache_key}.pkl"

         try:
-            with open(cache_file, "wb") as f:
+            # write to temporary file first, then atomically rename
+            temp_file = Path(str(cache_file) + f".tmp.{os.getpid()}")
+            with open(temp_file, "wb") as f:
                 pickle.dump(metadata, f)
+                f.flush()
+                os.fsync(f.fileno())
+
+            temp_file.replace(cache_file)
         except (pickle.PickleError, OSError) as e:
             print(f"[TABLE METADATA] Failed to save cache for {table_identifier}: {e}")
+            try:
+                temp_file = Path(str(cache_file) + f".tmp.{os.getpid()}")
+                if temp_file.exists():
+                    temp_file.unlink()
+            except OSError:
+                pass

     @cached_property
     def dry_run_result(self):
@@ -343,7 +390,7 @@ def dry_run_result(self):
         else:
             sql = self.get_sql()

-        # Check cache first (if caching is enabled)
+        # check cache first (if caching is enabled)
         if sql is not None and self.use_cache:
             cache_key = self._get_cache_key(sql)
             cached_result = self._get_cached_result(cache_key)
@@ -470,8 +517,9 @@ def dry_run_result(self):

         self.dry_run_duration = time.time() - start_time

-        # Save to cache (if caching is enabled)
-        if self.use_cache:
+        # Save to cache (if caching is enabled and result is valid)
+        # Don't cache errors to allow retries
+        if self.use_cache and result.get("valid"):
             self._save_cached_result(cache_key, result)

         return result
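For context, the write-then-rename pattern adopted in `_save_cached_result` and `_save_cached_table_metadata` can be sketched in isolation. This is a minimal, standalone illustration only; the `atomic_pickle_dump` helper name is invented for the example and is not part of this commit, and it assumes the temp file and the target live on the same filesystem, where `Path.replace` is an atomic rename:

import os
import pickle
import tempfile
from pathlib import Path


def atomic_pickle_dump(obj, cache_file: Path) -> None:
    """Write a pickle so readers never observe a partially written file."""
    # stage the bytes in a pid-suffixed temp file next to the target
    temp_file = Path(str(cache_file) + f".tmp.{os.getpid()}")
    with open(temp_file, "wb") as f:
        pickle.dump(obj, f)
        f.flush()
        os.fsync(f.fileno())  # make sure the bytes hit disk before the rename
    # the rename is atomic on the same filesystem, so concurrent readers
    # see either the previous cache file or the complete new one
    temp_file.replace(cache_file)


if __name__ == "__main__":
    target = Path(tempfile.gettempdir()) / "bigquery_etl_dryrun_cache_demo.pkl"
    atomic_pickle_dump({"valid": True}, target)
    print(pickle.loads(target.read_bytes()))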

bigquery_etl/schema/__init__.py

Lines changed: 14 additions & 3 deletions

@@ -13,6 +13,7 @@
 from google.cloud.bigquery import SchemaField

 from .. import dryrun
+from ..config import ConfigLoader

 SCHEMA_FILE = "schema.yaml"

@@ -58,24 +59,34 @@ def from_json(cls, json_schema):
         return cls(json_schema)

     @classmethod
-    def for_table(cls, project, dataset, table, partitioned_by=None, *args, **kwargs):
+    def for_table(
+        cls,
+        project,
+        dataset,
+        table,
+        partitioned_by=None,
+        filename="query.sql",
+        *args,
+        **kwargs,
+    ):
         """Get the schema for a BigQuery table."""
         query = f"SELECT * FROM `{project}.{dataset}.{table}`"

         if partitioned_by:
             query += f" WHERE DATE(`{partitioned_by}`) = DATE('2020-01-01')"

         try:
+            sql_dir = ConfigLoader.get("default", "sql_dir")
             return cls(
                 dryrun.DryRun(
-                    os.path.join(project, dataset, table, "query.sql"),
+                    os.path.join(sql_dir, project, dataset, table, filename),
                     query,
                     project=project,
                     dataset=dataset,
                     table=table,
                     *args,
                     **kwargs,
-                ).get_schema()
+                ).get_table_schema()
             )
         except Exception as e:
             print(f"Cannot get schema for {project}.{dataset}.{table}: {e}")

bigquery_etl/schema/stable_table_schema.py

Lines changed: 1 addition & 13 deletions

@@ -3,7 +3,6 @@
 import json
 import os
 import pickle
-import shutil
 import tarfile
 import tempfile
 import urllib.request
@@ -52,17 +51,6 @@ def sortkey(self):
         )


-def _clear_dryrun_cache():
-    """Clear dry run cache when new schemas are downloaded."""
-    cache_dir = os.path.join(tempfile.gettempdir(), "bigquery_etl_dryrun_cache")
-    if os.path.exists(cache_dir):
-        try:
-            shutil.rmtree(cache_dir)
-            print(f"Cleared dry run cache at {cache_dir}")
-        except OSError as e:
-            print(f"Warning: Failed to clear dry run cache: {e}")
-
-
 def prod_schemas_uri():
     """Return URI for the schemas tarball deployed to shared-prod.

@@ -105,7 +93,7 @@ def get_stable_table_schemas() -> List[SchemaFile]:

     # Clear dry run cache when downloading new schemas
     # Schema changes could affect dry run results
-    _clear_dryrun_cache()
+    DryRun.clear_cache()

     with urllib.request.urlopen(schemas_uri) as f:
         tarbytes = BytesIO(f.read())
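Because the cache-clearing logic now lives on `DryRun` itself, it can also be invoked directly, for example from a Python shell when stale dry run results are suspected; a minimal sketch, assuming the package is importable:

from bigquery_etl.dryrun import DryRun

# removes <tempdir>/bigquery_etl_dryrun_cache if it exists; a no-op otherwise
DryRun.clear_cache()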

tests/test_dryrun.py

Lines changed: 22 additions & 5 deletions

@@ -212,8 +212,11 @@ def test_cache_key_generation(self, tmp_query_path):
         cache_key3 = dryrun._get_cache_key(different_sql)
         assert cache_key1 != cache_key3

-    def test_cache_save_and_load(self, tmp_query_path):
+    def test_cache_save_and_load(self, tmp_query_path, monkeypatch, tmp_path):
         """Test that dry run results can be saved and loaded from cache."""
+        # Use isolated cache directory for this test to avoid interference from other tests
+        monkeypatch.setattr("tempfile.gettempdir", lambda: str(tmp_path))
+
         query_file = tmp_query_path / "query.sql"
         query_file.write_text("SELECT 123")

@@ -237,8 +240,11 @@ def test_cache_save_and_load(self, tmp_query_path):
         assert cached_result["valid"] is True
         assert cached_result["schema"]["fields"][0]["name"] == "test"

-    def test_cache_expiration(self, tmp_query_path):
+    def test_cache_expiration(self, tmp_query_path, monkeypatch, tmp_path):
         """Test that cache expires after TTL."""
+        # Use isolated cache directory for this test to avoid interference from other tests
+        monkeypatch.setattr("tempfile.gettempdir", lambda: str(tmp_path))
+
         query_file = tmp_query_path / "query.sql"
         query_file.write_text("SELECT 123")

@@ -257,8 +263,11 @@ def test_cache_expiration(self, tmp_query_path):
         expired = dryrun._get_cached_result(cache_key, ttl_seconds=0)
         assert expired is None

-    def test_cache_respects_sql_changes(self, tmp_query_path):
+    def test_cache_respects_sql_changes(self, tmp_query_path, monkeypatch, tmp_path):
         """Test that changing SQL content creates a different cache entry."""
+        # Use isolated cache directory for this test to avoid interference from other tests
+        monkeypatch.setattr("tempfile.gettempdir", lambda: str(tmp_path))
+
         query_file = tmp_query_path / "query.sql"

         # First SQL
@@ -286,8 +295,11 @@ def test_cache_respects_sql_changes(self, tmp_query_path):
         cached2 = dryrun2._get_cached_result(cache_key2)
         assert cached2 is None

-    def test_table_metadata_cache(self, tmp_query_path):
+    def test_table_metadata_cache(self, tmp_query_path, monkeypatch, tmp_path):
         """Test that table metadata can be cached by table identifier."""
+        # Use isolated cache directory for this test to avoid interference from other tests
+        monkeypatch.setattr("tempfile.gettempdir", lambda: str(tmp_path))
+
         query_file = tmp_query_path / "query.sql"
         query_file.write_text("SELECT 123")

@@ -314,8 +326,13 @@ def test_table_metadata_cache(self, tmp_query_path):
         assert cached_metadata["schema"]["fields"][0]["name"] == "col1"
         assert cached_metadata["tableType"] == "TABLE"

-    def test_table_metadata_cache_different_tables(self, tmp_query_path):
+    def test_table_metadata_cache_different_tables(
+        self, tmp_query_path, monkeypatch, tmp_path
+    ):
         """Test that different tables have separate cache entries."""
+        # Use isolated cache directory for this test to avoid interference from other tests
+        monkeypatch.setattr("tempfile.gettempdir", lambda: str(tmp_path))
+
         query_file = tmp_query_path / "query.sql"
         query_file.write_text("SELECT 123")

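The isolation pattern repeated in these tests boils down to redirecting `tempfile.gettempdir` so cache files land under pytest's `tmp_path`. A standalone sketch under that assumption; the `cache_dir` helper below is illustrative and not part of the test suite:

import tempfile
from pathlib import Path


def cache_dir() -> Path:
    # mirrors how the dry run cache resolves its directory
    return Path(tempfile.gettempdir()) / "bigquery_etl_dryrun_cache"


def test_cache_dir_is_isolated(monkeypatch, tmp_path):
    # after patching, anything using tempfile.gettempdir() writes under tmp_path,
    # so parallel tests and leftover caches on the machine cannot interfere
    monkeypatch.setattr("tempfile.gettempdir", lambda: str(tmp_path))
    assert cache_dir() == tmp_path / "bigquery_etl_dryrun_cache"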