diff --git a/.circleci/workflows.yml b/.circleci/workflows.yml index 609fa8bd90c..022bd1972e7 100644 --- a/.circleci/workflows.yml +++ b/.circleci/workflows.yml @@ -1246,6 +1246,7 @@ workflows: - validate-views - validate-metadata - dry-run-sql + - test-routines - test-routines: requires: - deploy-changes-to-stage diff --git a/bigquery_etl/cli/query.py b/bigquery_etl/cli/query.py index 9bda1a23fdc..7ed7f86a9d8 100644 --- a/bigquery_etl/cli/query.py +++ b/bigquery_etl/cli/query.py @@ -2300,6 +2300,9 @@ def _update_query_schema( query_schema = Schema.from_query_file( query_file_path, content=sql_content, + project=project_name, + dataset=dataset_name, + table=table_name, use_cloud_function=use_cloud_function, respect_skip=respect_dryrun_skip, sql_dir=sql_dir, diff --git a/bigquery_etl/dryrun.py b/bigquery_etl/dryrun.py index d8ba2a039bc..75bb4ff968b 100644 --- a/bigquery_etl/dryrun.py +++ b/bigquery_etl/dryrun.py @@ -12,10 +12,15 @@ """ import glob +import hashlib import json +import os +import pickle import random import re +import shutil import sys +import tempfile import time from enum import Enum from os.path import basename, dirname, exists @@ -106,10 +111,12 @@ def __init__( dataset=None, table=None, billing_project=None, + use_cache=True, ): """Instantiate DryRun class.""" self.sqlfile = sqlfile self.content = content + self.use_cache = use_cache self.query_parameters = query_parameters self.strip_dml = strip_dml self.use_cloud_function = use_cloud_function @@ -192,6 +199,17 @@ def skipped_files(sql_dir=ConfigLoader.get("default", "sql_dir")) -> Set[str]: return skip_files + @staticmethod + def clear_cache(): + """Clear dry run cache directory.""" + cache_dir = Path(tempfile.gettempdir()) / "bigquery_etl_dryrun_cache" + if cache_dir.exists(): + try: + shutil.rmtree(cache_dir) + print(f"Cleared dry run cache at {cache_dir}") + except OSError as e: + print(f"Warning: Failed to clear dry run cache: {e}") + def skip(self): """Determine if dry run should be skipped.""" return self.respect_skip and self.sqlfile in self.skipped_files( @@ -225,6 +243,108 @@ def get_sql(self): return sql + def _get_cache_key(self, sql): + """Generate cache key based on SQL content and other parameters.""" + cache_input = f"{sql}|{self.project}|{self.dataset}|{self.table}" + return hashlib.sha256(cache_input.encode()).hexdigest() + + @staticmethod + def _get_cache_dir(): + """Get the cache directory path.""" + cache_dir = Path(tempfile.gettempdir()) / "bigquery_etl_dryrun_cache" + cache_dir.mkdir(parents=True, exist_ok=True) + return cache_dir + + def _read_cache_file(self, cache_file, ttl_seconds): + """Read and return cached data from a pickle file with TTL check.""" + try: + if not cache_file.exists(): + return None + + # check if cache is expired + file_age = time.time() - cache_file.stat().st_mtime + if file_age > ttl_seconds: + try: + cache_file.unlink() + except OSError: + pass + return None + + cached_data = pickle.loads(cache_file.read_bytes()) + return cached_data + except (pickle.PickleError, EOFError, OSError, FileNotFoundError) as e: + print(f"[CACHE] Failed to load {cache_file}: {e}") + try: + if cache_file.exists(): + cache_file.unlink() + except OSError: + pass + return None + + @staticmethod + def _write_cache_file(cache_file, data): + """Write data to a cache file using atomic write.""" + try: + # write to temporary file first, then atomically rename + # this prevents race conditions where readers get partial files + # include random bytes to handle thread pool scenarios where threads share 
same PID + temp_file = Path( + str(cache_file) + f".tmp.{os.getpid()}.{os.urandom(4).hex()}" + ) + with open(temp_file, "wb") as f: + pickle.dump(data, f) + f.flush() + os.fsync(f.fileno()) # Ensure data is written to disk + + temp_file.replace(cache_file) + except (pickle.PickleError, OSError) as e: + print(f"[CACHE] Failed to save {cache_file}: {e}") + try: + if "temp_file" in locals() and temp_file.exists(): + temp_file.unlink() + except (OSError, NameError): + pass + + def _get_cached_result(self, cache_key, ttl_seconds=None): + """Load cached dry run result from disk.""" + if ttl_seconds is None: + ttl_seconds = ConfigLoader.get("dry_run", "cache_ttl_seconds", fallback=900) + + cache_file = self._get_cache_dir() / f"dryrun_{cache_key}.pkl" + return self._read_cache_file(cache_file, ttl_seconds) + + def _save_cached_result(self, cache_key, result): + """Save dry run result to disk cache using atomic write.""" + cache_file = self._get_cache_dir() / f"dryrun_{cache_key}.pkl" + self._write_cache_file(cache_file, result) + + # save table metadata separately if present + if ( + result + and "tableMetadata" in result + and self.project + and self.dataset + and self.table + ): + table_identifier = f"{self.project}.{self.dataset}.{self.table}" + self._save_cached_table_metadata(table_identifier, result["tableMetadata"]) + + def _get_cached_table_metadata(self, table_identifier, ttl_seconds=None): + """Load cached table metadata from disk based on table identifier.""" + if ttl_seconds is None: + ttl_seconds = ConfigLoader.get("dry_run", "cache_ttl_seconds", fallback=900) + + # table identifier as cache key + table_cache_key = hashlib.sha256(table_identifier.encode()).hexdigest() + cache_file = self._get_cache_dir() / f"table_metadata_{table_cache_key}.pkl" + return self._read_cache_file(cache_file, ttl_seconds) + + def _save_cached_table_metadata(self, table_identifier, metadata): + """Save table metadata to disk cache using atomic write.""" + table_cache_key = hashlib.sha256(table_identifier.encode()).hexdigest() + cache_file = self._get_cache_dir() / f"table_metadata_{table_cache_key}.pkl" + self._write_cache_file(cache_file, metadata) + @cached_property def dry_run_result(self): """Dry run the provided SQL file.""" @@ -233,6 +353,14 @@ def dry_run_result(self): else: sql = self.get_sql() + # check cache first (if caching is enabled) + if sql is not None and self.use_cache: + cache_key = self._get_cache_key(sql) + cached_result = self._get_cached_result(cache_key) + if cached_result is not None: + self.dry_run_duration = 0 # Cached result, no actual dry run + return cached_result + query_parameters = [] if self.query_parameters: for parameter_name, parameter_type in self.query_parameters.items(): @@ -351,6 +479,12 @@ def dry_run_result(self): } self.dry_run_duration = time.time() - start_time + + # Save to cache (if caching is enabled and result is valid) + # Don't cache errors to allow retries + if self.use_cache and result.get("valid"): + self._save_cached_result(cache_key, result) + return result except Exception as e: @@ -476,6 +610,13 @@ def get_table_schema(self): ): return self.dry_run_result["tableMetadata"]["schema"] + # Check if table metadata is cached (if caching is enabled) + if self.use_cache and self.project and self.dataset and self.table: + table_identifier = f"{self.project}.{self.dataset}.{self.table}" + cached_metadata = self._get_cached_table_metadata(table_identifier) + if cached_metadata: + return cached_metadata["schema"] + return [] def get_dataset_labels(self): @@ 
-565,6 +706,13 @@ def validate_schema(self): return True query_file_path = Path(self.sqlfile) + table_name = query_file_path.parent.name + dataset_name = query_file_path.parent.parent.name + project_name = query_file_path.parent.parent.parent.name + self.project = project_name + self.dataset = dataset_name + self.table = table_name + query_schema = Schema.from_json(self.get_schema()) if self.errors(): # ignore file when there are errors that self.get_schema() did not raise @@ -576,26 +724,7 @@ def validate_schema(self): click.echo(f"No schema file defined for {query_file_path}", err=True) return True - table_name = query_file_path.parent.name - dataset_name = query_file_path.parent.parent.name - project_name = query_file_path.parent.parent.parent.name - - partitioned_by = None - if ( - self.metadata - and self.metadata.bigquery - and self.metadata.bigquery.time_partitioning - ): - partitioned_by = self.metadata.bigquery.time_partitioning.field - - table_schema = Schema.for_table( - project_name, - dataset_name, - table_name, - client=self.client, - id_token=self.id_token, - partitioned_by=partitioned_by, - ) + table_schema = Schema.from_json(self.get_table_schema()) # This check relies on the new schema being deployed to prod if not query_schema.compatible(table_schema): diff --git a/bigquery_etl/schema/__init__.py b/bigquery_etl/schema/__init__.py index 0e1369109d4..65aa7e11570 100644 --- a/bigquery_etl/schema/__init__.py +++ b/bigquery_etl/schema/__init__.py @@ -13,6 +13,7 @@ from google.cloud.bigquery import SchemaField from .. import dryrun +from ..config import ConfigLoader SCHEMA_FILE = "schema.yaml" @@ -58,7 +59,16 @@ def from_json(cls, json_schema): return cls(json_schema) @classmethod - def for_table(cls, project, dataset, table, partitioned_by=None, *args, **kwargs): + def for_table( + cls, + project, + dataset, + table, + partitioned_by=None, + filename="query.sql", + *args, + **kwargs, + ): """Get the schema for a BigQuery table.""" query = f"SELECT * FROM `{project}.{dataset}.{table}`" @@ -66,16 +76,17 @@ def for_table(cls, project, dataset, table, partitioned_by=None, *args, **kwargs query += f" WHERE DATE(`{partitioned_by}`) = DATE('2020-01-01')" try: + sql_dir = ConfigLoader.get("default", "sql_dir") return cls( dryrun.DryRun( - os.path.join(project, dataset, table, "query.sql"), + os.path.join(sql_dir, project, dataset, table, filename), query, project=project, dataset=dataset, table=table, *args, **kwargs, - ).get_schema() + ).get_table_schema() ) except Exception as e: print(f"Cannot get schema for {project}.{dataset}.{table}: {e}") diff --git a/bigquery_etl/schema/stable_table_schema.py b/bigquery_etl/schema/stable_table_schema.py index f2abbc6442b..86ca43fb156 100644 --- a/bigquery_etl/schema/stable_table_schema.py +++ b/bigquery_etl/schema/stable_table_schema.py @@ -59,7 +59,9 @@ def prod_schemas_uri(): with the most recent production schemas deploy. 
""" dryrun = DryRun( - "moz-fx-data-shared-prod/telemetry_derived/foo/query.sql", content="SELECT 1" + "moz-fx-data-shared-prod/telemetry_derived/foo/query.sql", + content="SELECT 1", + use_cache=False, ) build_id = dryrun.get_dataset_labels()["schemas_build_id"] commit_hash = build_id.split("_")[-1] @@ -88,6 +90,11 @@ def get_stable_table_schemas() -> List[SchemaFile]: print(f"Failed to load cached schemas: {e}, re-downloading...") print(f"Downloading schemas from {schemas_uri}") + + # Clear dry run cache when downloading new schemas + # Schema changes could affect dry run results + DryRun.clear_cache() + with urllib.request.urlopen(schemas_uri) as f: tarbytes = BytesIO(f.read()) diff --git a/bqetl_project.yaml b/bqetl_project.yaml index bd8dda21846..37e693d38bf 100644 --- a/bqetl_project.yaml +++ b/bqetl_project.yaml @@ -32,6 +32,7 @@ dry_run: function_accounts: - bigquery-etl-dryrun@moz-fx-data-shared-prod.iam.gserviceaccount.com - bigquery-etl-dryrun@moz-fx-data-shar-nonprod-efed.iam.gserviceaccount.com + cache_ttl_seconds: 900 # Cache dry run results for 15 minutes (900 seconds) skip: ## skip all data-observability-dev queries due to CI lacking permissions in that project. # TODO: once data observability platform assessment concludes this should be removed. diff --git a/sql_generators/README.md b/sql_generators/README.md index 3b2ceef8a28..751c9bf7a80 100644 --- a/sql_generators/README.md +++ b/sql_generators/README.md @@ -9,3 +9,4 @@ The directories in `sql_generators/` represent the generated queries and will co Each `__init__.py` file needs to implement a `generate()` method that is configured as a [click command](https://click.palletsprojects.com/en/8.0.x/). The `bqetl` CLI will automatically add these commands to the `./bqetl query generate` command group. After changes to a schema or adding new tables, the schema is automatically derived from the query and deployed the next day in DAG [bqetl_artifact_deployment](https://workflow.telemetry.mozilla.org/dags/bqetl_artifact_deployment/grid). Alternatively, it can be manually generated and deployed using `./bqetl generate all` and `./bqetl query schema deploy`. 
+ diff --git a/tests/test_dryrun.py b/tests/test_dryrun.py index 21a9f7a0848..45636b190c2 100644 --- a/tests/test_dryrun.py +++ b/tests/test_dryrun.py @@ -192,3 +192,194 @@ def test_dryrun_metrics_query(self, tmp_query_path): dryrun = DryRun(sqlfile=str(query_file)) assert dryrun.is_valid() + + def test_cache_key_generation(self, tmp_query_path): + """Test that cache keys are generated consistently.""" + query_file = tmp_query_path / "query.sql" + sql_content = "SELECT 123" + query_file.write_text(sql_content) + + dryrun = DryRun(str(query_file)) + cache_key1 = dryrun._get_cache_key(sql_content) + cache_key2 = dryrun._get_cache_key(sql_content) + + # Same SQL should produce same cache key + assert cache_key1 == cache_key2 + assert len(cache_key1) == 64 # SHA256 hex digest length + + # Different SQL should produce different cache key + different_sql = "SELECT 456" + cache_key3 = dryrun._get_cache_key(different_sql) + assert cache_key1 != cache_key3 + + def test_cache_save_and_load(self, tmp_query_path, monkeypatch, tmp_path): + """Test that dry run results can be saved and loaded from cache.""" + # Use isolated cache directory for this test to avoid interference from other tests + monkeypatch.setattr("tempfile.gettempdir", lambda: str(tmp_path)) + + query_file = tmp_query_path / "query.sql" + query_file.write_text("SELECT 123") + + dryrun = DryRun(str(query_file)) + sql = dryrun.get_sql() + cache_key = dryrun._get_cache_key(sql) + + # Mock result data + test_result = { + "valid": True, + "schema": {"fields": [{"name": "test", "type": "STRING"}]}, + } + + # Save to cache + dryrun._save_cached_result(cache_key, test_result) + + # Load from cache + cached_result = dryrun._get_cached_result(cache_key) + + assert cached_result is not None + assert cached_result["valid"] is True + assert cached_result["schema"]["fields"][0]["name"] == "test" + + def test_cache_expiration(self, tmp_query_path, monkeypatch, tmp_path): + """Test that cache expires after TTL.""" + # Use isolated cache directory for this test to avoid interference from other tests + monkeypatch.setattr("tempfile.gettempdir", lambda: str(tmp_path)) + + query_file = tmp_query_path / "query.sql" + query_file.write_text("SELECT 123") + + dryrun = DryRun(str(query_file)) + sql = dryrun.get_sql() + cache_key = dryrun._get_cache_key(sql) + + test_result = {"valid": True} + dryrun._save_cached_result(cache_key, test_result) + + # Should be cached with short TTL + cached = dryrun._get_cached_result(cache_key, ttl_seconds=10) + assert cached is not None + + # Should be expired with very short TTL + expired = dryrun._get_cached_result(cache_key, ttl_seconds=0) + assert expired is None + + def test_cache_respects_sql_changes(self, tmp_query_path, monkeypatch, tmp_path): + """Test that changing SQL content creates a different cache entry.""" + # Use isolated cache directory for this test to avoid interference from other tests + monkeypatch.setattr("tempfile.gettempdir", lambda: str(tmp_path)) + + query_file = tmp_query_path / "query.sql" + + # First SQL + query_file.write_text("SELECT 123") + dryrun1 = DryRun(str(query_file)) + sql1 = dryrun1.get_sql() + cache_key1 = dryrun1._get_cache_key(sql1) + test_result1 = {"valid": True, "data": "first"} + dryrun1._save_cached_result(cache_key1, test_result1) + + # Second SQL + query_file.write_text("SELECT 456") + dryrun2 = DryRun(str(query_file)) + sql2 = dryrun2.get_sql() + cache_key2 = dryrun2._get_cache_key(sql2) + + # Cache keys should be different + assert cache_key1 != cache_key2 + + # First cache 
should still exist + cached1 = dryrun1._get_cached_result(cache_key1) + assert cached1["data"] == "first" + + # Second cache should not exist yet + cached2 = dryrun2._get_cached_result(cache_key2) + assert cached2 is None + + def test_table_metadata_cache(self, tmp_query_path, monkeypatch, tmp_path): + """Test that table metadata can be cached by table identifier.""" + # Use isolated cache directory for this test to avoid interference from other tests + monkeypatch.setattr("tempfile.gettempdir", lambda: str(tmp_path)) + + query_file = tmp_query_path / "query.sql" + query_file.write_text("SELECT 123") + + dryrun = DryRun( + str(query_file), + project="test-project", + dataset="test_dataset", + table="test_table", + ) + + table_identifier = f"{dryrun.project}.{dryrun.dataset}.{dryrun.table}" + test_metadata = { + "schema": {"fields": [{"name": "col1", "type": "STRING"}]}, + "tableType": "TABLE", + } + + # Save table metadata + dryrun._save_cached_table_metadata(table_identifier, test_metadata) + + # Load table metadata + cached_metadata = dryrun._get_cached_table_metadata(table_identifier) + + assert cached_metadata is not None + assert cached_metadata["schema"]["fields"][0]["name"] == "col1" + assert cached_metadata["tableType"] == "TABLE" + + def test_table_metadata_cache_different_tables( + self, tmp_query_path, monkeypatch, tmp_path + ): + """Test that different tables have separate cache entries.""" + # Use isolated cache directory for this test to avoid interference from other tests + monkeypatch.setattr("tempfile.gettempdir", lambda: str(tmp_path)) + + query_file = tmp_query_path / "query.sql" + query_file.write_text("SELECT 123") + + # Table 1 + dryrun1 = DryRun( + str(query_file), project="test-project", dataset="dataset1", table="table1" + ) + table1_id = f"{dryrun1.project}.{dryrun1.dataset}.{dryrun1.table}" + metadata1 = {"schema": {"fields": [{"name": "table1_col"}]}} + dryrun1._save_cached_table_metadata(table1_id, metadata1) + + # Table 2 + dryrun2 = DryRun( + str(query_file), project="test-project", dataset="dataset2", table="table2" + ) + table2_id = f"{dryrun2.project}.{dryrun2.dataset}.{dryrun2.table}" + metadata2 = {"schema": {"fields": [{"name": "table2_col"}]}} + dryrun2._save_cached_table_metadata(table2_id, metadata2) + + # Both should be cached independently + cached1 = dryrun1._get_cached_table_metadata(table1_id) + cached2 = dryrun2._get_cached_table_metadata(table2_id) + + assert cached1["schema"]["fields"][0]["name"] == "table1_col" + assert cached2["schema"]["fields"][0]["name"] == "table2_col" + + def test_use_cache_false_disables_caching(self, tmp_query_path): + """Test that use_cache=False disables all caching functionality.""" + query_file = tmp_query_path / "query.sql" + query_file.write_text("SELECT 123") + + # First, create a cache entry with caching enabled + dryrun_with_cache = DryRun(str(query_file), use_cache=True) + result1 = dryrun_with_cache.dry_run_result + assert result1["valid"] + + # Verify cache was created + sql = dryrun_with_cache.get_sql() + cache_key = dryrun_with_cache._get_cache_key(sql) + cached = dryrun_with_cache._get_cached_result(cache_key) + assert cached is not None + + # Now create a new DryRun with use_cache=False + dryrun_no_cache = DryRun(str(query_file), use_cache=False) + + # Even though cache exists, it should not be used + # We can't easily verify this without mocking the API call, + # but we can verify the flag is set correctly + assert dryrun_no_cache.use_cache is False + assert dryrun_with_cache.use_cache is True
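
A minimal usage sketch of the caching surface added above (the DryRun use_cache flag, DryRun.clear_cache, and the TTL read from dry_run.cache_ttl_seconds). The query path is illustrative; the import path matches bigquery_etl/dryrun.py, and this is a sketch of intended use rather than part of the change itself.

    from bigquery_etl.dryrun import DryRun

    # Default behaviour: a dry run of identical SQL within the TTL
    # (900s per bqetl_project.yaml) is served from the on-disk cache.
    dryrun = DryRun("sql/moz-fx-data-shared-prod/telemetry_derived/foo/query.sql")
    result = dryrun.dry_run_result      # only "valid" results are written to the cache
    schema = dryrun.get_table_schema()  # may be served from the table-metadata cache

    # Opt out when the result must reflect live state, as prod_schemas_uri() does.
    fresh = DryRun(
        "sql/moz-fx-data-shared-prod/telemetry_derived/foo/query.sql",
        content="SELECT 1",
        use_cache=False,
    ).dry_run_result

    # Drop every cached entry, as get_stable_table_schemas() does after
    # downloading a new schemas tarball.
    DryRun.clear_cache()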
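
And a small inspection sketch for the on-disk layout those helpers write, assuming the cache directory and file-name scheme from _get_cache_dir, _get_cached_result and _get_cached_table_metadata above (expired entries are deleted lazily on read, so stale files may linger until the next lookup):

    import tempfile
    import time
    from pathlib import Path

    cache_dir = Path(tempfile.gettempdir()) / "bigquery_etl_dryrun_cache"
    # dryrun_<sha256(sql|project|dataset|table)>.pkl      full dry-run results
    # table_metadata_<sha256(project.dataset.table)>.pkl  table metadata only
    if cache_dir.exists():
        for f in sorted(cache_dir.glob("*.pkl")):
            age = time.time() - f.stat().st_mtime
            print(f"{f.name}  age={age:.0f}s")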