diff --git a/sqlmesh/core/macros.py b/sqlmesh/core/macros.py index af7c344081..da5f7db91c 100644 --- a/sqlmesh/core/macros.py +++ b/sqlmesh/core/macros.py @@ -43,7 +43,6 @@ Executable, SqlValue, format_evaluated_code_exception, - prepare_env, ) if t.TYPE_CHECKING: @@ -178,7 +177,8 @@ def __init__( schema: t.Optional[MappingSchema] = None, runtime_stage: RuntimeStage = RuntimeStage.LOADING, resolve_table: t.Optional[t.Callable[[str | exp.Table], str]] = None, - resolve_tables: t.Optional[t.Callable[[exp.Expression], exp.Expression]] = None, + resolve_tables: t.Optional[t.Callable[[ + exp.Expression], exp.Expression]] = None, snapshots: t.Optional[t.Dict[str, Snapshot]] = None, default_catalog: t.Optional[str] = None, path: t.Optional[Path] = None, @@ -198,7 +198,8 @@ def __init__( "MacroEvaluator": MacroEvaluator, } self.python_env = python_env or {} - self.macros = {normalize_macro_name(k): v.func for k, v in macro.get_registry().items()} + self.macros = {normalize_macro_name( + k): v.func for k, v in macro.get_registry().items()} self.columns_to_types_called = False self.default_catalog = default_catalog @@ -210,15 +211,44 @@ def __init__( self._environment_naming_info = environment_naming_info self._model_fqn = model_fqn - prepare_env(self.python_env, self.env) - for k, v in self.python_env.items(): - if v.is_definition: - self.macros[normalize_macro_name(k)] = self.env[v.name or k] - elif v.is_import and getattr(self.env.get(k), c.SQLMESH_MACRO, None): - self.macros[normalize_macro_name(k)] = self.env[k] - elif v.is_value: - value = self.env[k] - if k in ( + # Track executables not loaded yet for lazy loading + self._unloaded_executables: t.Dict[str, Executable] = {} + # Track failed imports to provide helpful error messages + self._failed_imports: t.Dict[str, Exception] = {} + + # Load python_env, defer imports that might fail + # Allows projects to share state without all dependencies + self._load_python_env() + + def _load_python_env(self) -> None: + """Load 
python environment with lazy import loading. + + This method implements lazy loading to allow projects to share + state databases without requiring all external dependencies: + + 1. VALUES: Loaded immediately (no dependencies) + 2. DEFINITIONS: Loaded immediately (backward compatibility) + - Won't fail at definition time even if imports missing + - Failures happen at call time + 3. IMPORTS: Deferred until first macro call + - Allows projects to share state without all dependencies + - Loaded automatically before any macro execution + + When a macro is called, all deferred imports are loaded first, + ensuring the macro has access to its dependencies. + """ + # Sort to process imports first, then values, then definitions + sorted_items = sorted( + self.python_env.items(), + key=lambda item: 0 if item[1].is_import else 1, + ) + + for name, executable in sorted_items: + if executable.is_value: + # Load values immediately + self.env[name] = eval(executable.payload) + value = self.env[name] + if name in ( c.SQLMESH_VARS, c.SQLMESH_VARS_METADATA, c.SQLMESH_BLUEPRINT_VARS, @@ -232,20 +262,95 @@ def __init__( ) for var_name, var_value in value.items() } + self.locals[name] = value + elif executable.is_definition: + # Load definitions immediately for backward compatibility + # They won't fail at definition time even if imports are missing + exec(executable.payload, self.env) + if executable.alias and executable.name: + self.env[executable.alias] = self.env[executable.name] + # Register as macro if it's a macro definition + func_name = executable.name or name + self.macros[normalize_macro_name( + func_name)] = self.env[func_name] + elif executable.is_import: + self._unloaded_executables[name] = executable + + def _ensure_executable_loaded(self, name: str) -> bool: + """Lazily load an executable if it hasn't been loaded yet. 
- self.locals[k] = value + Args: + name: The name of the executable to load + + Returns: + True if the executable was loaded successfully, False otherwise + """ + if name in self.env or name not in self._unloaded_executables: + return name in self.env + + if name in self._failed_imports: + return False + + executable = self._unloaded_executables[name] + + try: + exec(executable.payload, self.env) + if executable.alias and executable.name: + self.env[executable.alias] = self.env[executable.name] + + # If it's a macro import, register it + # For imports, the actual imported name might differ from the key + imported_name = executable.name or name + if executable.is_import and getattr( + self.env.get(imported_name), c.SQLMESH_MACRO, None + ): + self.macros[normalize_macro_name( + imported_name)] = self.env[imported_name] + + del self._unloaded_executables[name] + return True + except Exception as e: + self._failed_imports[name] = e + return False def send( self, name: str, *args: t.Any, **kwargs: t.Any ) -> t.Union[None, exp.Expression, t.List[exp.Expression]]: - func = self.macros.get(normalize_macro_name(name)) + normalized_name = normalize_macro_name(name) + func = self.macros.get(normalized_name) + + if not callable(func): + for exec_name in self._unloaded_executables: + if normalize_macro_name(exec_name) == normalized_name: + if self._ensure_executable_loaded(exec_name): + func = self.macros.get(normalized_name) + break if not callable(func): + for exec_name, error in self._failed_imports.items(): + if normalize_macro_name(exec_name) == normalized_name: + raise MacroEvalError( + f"Macro '{name}' could not be loaded due to a " + "missing dependency. This may be caused by a macro " + "from another project that shares the same state " + "database. If you don't use this macro, you can " + "safely ignore it by not calling it. 
" + f"Original error: {error}" + ) + raise MacroEvalError(f"Macro '{name}' does not exist.") + # Before calling the macro, load all deferred imports + # This ensures macros can reference imports even if they were deferred + for import_name in list(self._unloaded_executables.keys()): + import_exec = self._unloaded_executables.get(import_name) + if import_exec and import_exec.is_import: + self._ensure_executable_loaded(import_name) + try: return call_macro( - func, self.dialect, self._path, provided_args=(self, *args), provided_kwargs=kwargs + func, self.dialect, self._path, provided_args=( + self, *args), provided_kwargs=kwargs ) # type: ignore except Exception as e: raise MacroEvalError( @@ -273,7 +378,8 @@ def evaluate_macros( if var_name not in self.locals and var_name not in variables: if not isinstance(node.parent, StagedFilePath): - raise SQLMeshError(f"Macro variable '{node.name}' is undefined.") + raise SQLMeshError( + f"Macro variable '{node.name}' is undefined.") return node @@ -287,7 +393,8 @@ def evaluate_macros( ) return exp.convert( - self.transform(value) if isinstance(value, exp.Expression) else value + self.transform(value) if isinstance( + value, exp.Expression) else value ) if isinstance(node, exp.Identifier) and "@" in node.this: text = self.template(node.this, {}) @@ -344,11 +451,13 @@ def evaluate(self, node: MacroFunc) -> exp.Expression | t.List[exp.Expression] | if isinstance(node.expression, exp.Lambda): _, fn = _norm_var_arg_lambda(self, node.expression) self.macros[normalize_macro_name(node.name)] = lambda _, *args: fn( - args[0] if len(args) == 1 else exp.Tuple(expressions=list(args)) + args[0] if len(args) == 1 else exp.Tuple( + expressions=list(args)) ) else: # Make variables defined through `@DEF` case-insensitive - self.locals[node.name.lower()] = self.transform(node.expression) + self.locals[node.name.lower()] = self.transform( + node.expression) return node @@ -379,7 +488,8 @@ def evaluate(self, node: MacroFunc) -> exp.Expression | 
t.List[exp.Expression] | return None if isinstance(result, (tuple, list)): - result = [self.parse_one(item) for item in result if item is not None] + result = [self.parse_one(item) + for item in result if item is not None] if ( len(result) == 1 @@ -392,13 +502,13 @@ def evaluate(self, node: MacroFunc) -> exp.Expression | t.List[exp.Expression] | - and that output is something that _norm_var_arg_lambda() will unpack into varargs > (a list containing a single item of type exp.Tuple/exp.Array) then we will get inconsistent behaviour depending on if this node emits a list with a single item vs multiple items. - + In the first case, emitting a list containing a single array item will cause that array to get unpacked and its *members* passed to the calling macro In the second case, emitting a list containing multiple array items will cause each item to get passed as-is to the calling macro - + To prevent this inconsistency, we wrap this node output in an exp.Array so that _norm_var_arg_lambda() can "unpack" that into the actual argument we want to pass to the parent macro function - + Note we only do this for evaluation results that get passed as an argument to another macro, because when the final result is given to something like SELECT, we still want that to be unpacked into a list of items like: - SELECT ARRAY(1), ARRAY(2) @@ -467,7 +577,8 @@ def columns_to_types(self, model_name: TableName | exp.Column) -> t.Dict[str, ex model_name = exp.to_table(normalized_model_name) columns_to_types = ( - self._schema.find(model_name, ensure_data_types=True) if self._schema else None + self._schema.find( + model_name, ensure_data_types=True) if self._schema else None ) if columns_to_types is None: snapshot = self.get_snapshot(model_name) @@ -475,7 +586,8 @@ def columns_to_types(self, model_name: TableName | exp.Column) -> t.Dict[str, ex columns_to_types = snapshot.node.columns_to_types # type: ignore if columns_to_types is None: - raise SQLMeshError(f"Schema for model 
'{model_name}' can't be statically determined.") + raise SQLMeshError( + f"Schema for model '{model_name}' can't be statically determined.") return columns_to_types @@ -515,13 +627,15 @@ def this_model(self) -> str: """Returns the resolved name of the surrounding model.""" this_model = self.locals.get("this_model") if not this_model: - raise SQLMeshError("Model name is not available in the macro evaluator.") + raise SQLMeshError( + "Model name is not available in the macro evaluator.") return this_model.sql(dialect=self.dialect, identify=True, comments=False) @property def this_model_fqn(self) -> str: if self._model_fqn is None: - raise SQLMeshError("Model name is not available in the macro evaluator.") + raise SQLMeshError( + "Model name is not available in the macro evaluator.") return self._model_fqn @property @@ -548,21 +662,24 @@ def snapshots(self) -> t.Dict[str, Snapshot]: def this_env(self) -> str: """Returns the name of the current environment in before after all.""" if "this_env" not in self.locals: - raise SQLMeshError("Environment name is only available in before_all and after_all") + raise SQLMeshError( + "Environment name is only available in before_all and after_all") return self.locals["this_env"] @property def schemas(self) -> t.List[str]: """Returns the schemas of the current environment in before after all macros.""" if "schemas" not in self.locals: - raise SQLMeshError("Schemas are only available in before_all and after_all") + raise SQLMeshError( + "Schemas are only available in before_all and after_all") return self.locals["schemas"] @property def views(self) -> t.List[str]: """Returns the views of the current environment in before after all macros.""" if "views" not in self.locals: - raise SQLMeshError("Views are only available in before_all and after_all") + raise SQLMeshError( + "Views are only available in before_all and after_all") return self.locals["views"] def var(self, var_name: str, default: t.Optional[t.Any] = None) -> 
t.Optional[t.Any]: @@ -659,7 +776,8 @@ def substitute( return exp.convert(evaluator.locals[name]) if SQLMESH_MACRO_PREFIX in node.name: return node.__class__( - this=evaluator.template(node.name, {k: v.name for k, v in args.items()}) + this=evaluator.template( + node.name, {k: v.name for k, v in args.items()}) ) elif isinstance(node, MacroFunc): local_copy = evaluator.locals.copy() @@ -687,7 +805,8 @@ def substitute( { expression.name.lower(): arg for expression, arg in zip( - func.expressions, args.expressions if isinstance(args, exp.Tuple) else [args] + func.expressions, args.expressions if isinstance( + args, exp.Tuple) else [args] ) }, ) @@ -899,13 +1018,15 @@ def star( "The 'except_' argument in @STAR will soon be deprecated. Use 'exclude' instead." ) if not isinstance(exclude, (exp.Array, exp.Tuple)): - raise SQLMeshError(f"Invalid exclude_ '{exclude}'. Expected an array.") + raise SQLMeshError( + f"Invalid exclude_ '{exclude}'. Expected an array.") if prefix and not isinstance(prefix, exp.Literal): raise SQLMeshError(f"Invalid prefix '{prefix}'. Expected a literal.") if suffix and not isinstance(suffix, exp.Literal): raise SQLMeshError(f"Invalid suffix '{suffix}'. Expected a literal.") if not isinstance(quote_identifiers, exp.Boolean): - raise SQLMeshError(f"Invalid quote_identifiers '{quote_identifiers}'. Expected a boolean.") + raise SQLMeshError( + f"Invalid quote_identifiers '{quote_identifiers}'. 
Expected a boolean.") excluded_names = { normalize_identifiers(excluded, dialect=evaluator.dialect).name @@ -993,8 +1114,8 @@ def safe_add(_: MacroEvaluator, *fields: exp.Expression) -> exp.Case: return ( exp.Case() .when(exp.and_(*(field.is_(exp.null()) for field in fields)), exp.null()) - .else_(reduce(lambda a, b: a + b, [exp.func("COALESCE", field, 0) for field in fields])) # type: ignore - ) + .else_(reduce(lambda a, b: a + b, [exp.func("COALESCE", field, 0) for field in fields])) + ) # type: ignore @macro() @@ -1011,7 +1132,8 @@ def safe_sub(_: MacroEvaluator, *fields: exp.Expression) -> exp.Case: return ( exp.Case() .when(exp.and_(*(field.is_(exp.null()) for field in fields)), exp.null()) - .else_(reduce(lambda a, b: a - b, [exp.func("COALESCE", field, 0) for field in fields])) # type: ignore + # type: ignore + .else_(reduce(lambda a, b: a - b, [exp.func("COALESCE", field, 0) for field in fields])) ) @@ -1056,7 +1178,8 @@ def union( """ if not args: - raise SQLMeshError("At least one table is required for the @UNION macro.") + raise SQLMeshError( + "At least one table is required for the @UNION macro.") arg_idx = 0 # Check for condition @@ -1064,7 +1187,8 @@ def union( if isinstance(condition, bool): arg_idx += 1 if arg_idx >= len(args): - raise SQLMeshError("Expected more arguments after the condition of the `@UNION` macro.") + raise SQLMeshError( + "Expected more arguments after the condition of the `@UNION` macro.") # Check for union type type_ = exp.Literal.string("ALL") @@ -1073,7 +1197,8 @@ def union( arg_idx += 1 kind = type_.name.upper() if kind not in ("ALL", "DISTINCT"): - raise SQLMeshError(f"Invalid type '{type_}'. Expected 'ALL' or 'DISTINCT'.") + raise SQLMeshError( + f"Invalid type '{type_}'. 
Expected 'ALL' or 'DISTINCT'.") # Remaining args should be tables tables = [ @@ -1136,10 +1261,12 @@ def haversine_distance( "ASIN", exp.func( "SQRT", - exp.func("POWER", exp.func("SIN", exp.func("RADIANS", (lat2 - lat1) / 2)), 2) + exp.func("POWER", exp.func("SIN", exp.func( + "RADIANS", (lat2 - lat1) / 2)), 2) + exp.func("COS", exp.func("RADIANS", lat1)) * exp.func("COS", exp.func("RADIANS", lat2)) - * exp.func("POWER", exp.func("SIN", exp.func("RADIANS", (lon2 - lon1) / 2)), 2), + * exp.func("POWER", exp.func("SIN", + exp.func("RADIANS", (lon2 - lon1) / 2)), 2), ), ) * conversion_rate @@ -1223,7 +1350,8 @@ def var( ) -> exp.Expression: """Returns the value of a variable or the default value if the variable is not set.""" if not var_name.is_string: - raise SQLMeshError(f"Invalid variable name '{var_name.sql()}'. Expected a string literal.") + raise SQLMeshError( + f"Invalid variable name '{var_name.sql()}'. Expected a string literal.") return exp.convert(evaluator.var(var_name.this, default)) @@ -1276,7 +1404,8 @@ def deduplicate( partition_clause = exp.tuple_(*partition_by) order_expressions = [ - evaluator.transform(parse_one(order_item, into=exp.Ordered, dialect=evaluator.dialect)) + evaluator.transform( + parse_one(order_item, into=exp.Ordered, dialect=evaluator.dialect)) for order_item in order_by ] @@ -1331,7 +1460,8 @@ def date_spine( try: if start_date.is_string and end_date.is_string: - start_date_obj = datetime.strptime(start_date_name, "%Y-%m-%d").date() + start_date_obj = datetime.strptime( + start_date_name, "%Y-%m-%d").date() end_date_obj = datetime.strptime(end_date_name, "%Y-%m-%d").date() else: start_date_obj = None @@ -1356,9 +1486,11 @@ def date_spine( "databricks", "postgres", ): - date_interval = exp.Interval(this=exp.Literal.number(3), unit=exp.var("month")) + date_interval = exp.Interval( + this=exp.Literal.number(3), unit=exp.var("month")) else: - date_interval = exp.Interval(this=exp.Literal.number(1), unit=exp.var(datepart_name)) + 
date_interval = exp.Interval( + this=exp.Literal.number(1), unit=exp.var(datepart_name)) generate_date_array = exp.func( "GENERATE_DATE_ARRAY", @@ -1368,7 +1500,8 @@ def date_spine( ) alias_name = f"date_{datepart_name}" - exploded = exp.alias_(exp.func("unnest", generate_date_array), "_exploded", table=[alias_name]) + exploded = exp.alias_( + exp.func("unnest", generate_date_array), "_exploded", table=[alias_name]) return exp.select(alias_name).from_(exploded) @@ -1405,7 +1538,8 @@ def resolve_template( "'s3://data-bucket/prod/test_catalog/sqlmesh__test/test__test_model__2517971505'" """ if "this_model" in evaluator.locals: - this_model = exp.to_table(evaluator.locals["this_model"], dialect=evaluator.dialect) + this_model = exp.to_table( + evaluator.locals["this_model"], dialect=evaluator.dialect) template_str: str = template.this result = ( template_str.replace("@{catalog_name}", this_model.catalog) @@ -1472,9 +1606,11 @@ def call_macro( # https://docs.python.org/3/library/inspect.html#inspect.BoundArguments.arguments param = sig.parameters[arg] if param.kind is inspect.Parameter.VAR_POSITIONAL: - bound.arguments[arg] = tuple(_coerce(v, typ, dialect, path) for v in value) + bound.arguments[arg] = tuple( + _coerce(v, typ, dialect, path) for v in value) elif param.kind is inspect.Parameter.VAR_KEYWORD: - bound.arguments[arg] = {k: _coerce(v, typ, dialect, path) for k, v in value.items()} + bound.arguments[arg] = {k: _coerce( + v, typ, dialect, path) for k, v in value.items()} else: bound.arguments[arg] = _coerce(value, typ, dialect, path) @@ -1567,7 +1703,8 @@ def _coerce( _coerce(expr, generic[i], dialect, path) for i, expr in enumerate(expr.expressions) ) - raise SQLMeshError(f"{base_err_msg} Expected {len(generic)} items.") + raise SQLMeshError( + f"{base_err_msg} Expected {len(generic)} items.") if base is list and isinstance(expr, (exp.Array, exp.Tuple)): generic = t.get_args(typ) if not generic: diff --git a/tests/core/test_macros.py 
b/tests/core/test_macros.py index fb10f64b27..531304c3af 100644 --- a/tests/core/test_macros.py +++ b/tests/core/test_macros.py @@ -65,7 +65,8 @@ def cte_tag_name(evaluator: MacroEvaluator, with_: exp.Select): for cte in with_.find_all(exp.CTE): name = cte.alias_or_name for query in cte.find_all(exp.Select): - query.select(exp.Literal.string(name).as_("source"), copy=False) + query.select(exp.Literal.string( + name).as_("source"), copy=False) return with_ @macro() @@ -104,7 +105,8 @@ def test_literal_type(evaluator, a: t.Literal["test_literal_a", "test_literal_b" return MacroEvaluator( "hive", - {"test": Executable(name="test", payload="def test(_):\n return 'test'")}, + {"test": Executable( + name="test", payload="def test(_):\n return 'test'")}, ) @@ -121,7 +123,8 @@ def test_star(assert_exp_eq) -> None: dialect="tsql", ) evaluator = MacroEvaluator(schema=schema, dialect="tsql") - assert_exp_eq(evaluator.transform(parse_one(sql, read="tsql")), expected_sql, dialect="tsql") + assert_exp_eq(evaluator.transform(parse_one(sql, read="tsql")), + expected_sql, dialect="tsql") sql = "SELECT @STAR(foo, exclude := [SomeColumn]) FROM foo" expected_sql = "SELECT CAST(`foo`.`a` AS STRING) AS `a` FROM foo" @@ -153,7 +156,8 @@ def test_star(assert_exp_eq) -> None: dialect="tsql", ) evaluator = MacroEvaluator(schema=schema, dialect="tsql") - assert_exp_eq(evaluator.transform(parse_one(sql, read="tsql")), expected_sql, dialect="tsql") + assert_exp_eq(evaluator.transform(parse_one(sql, read="tsql")), + expected_sql, dialect="tsql") sql = """SELECT @STAR(foo) FROM foo""" expected_sql = ( @@ -211,9 +215,11 @@ def test_star(assert_exp_eq) -> None: def test_start_no_column_types(assert_exp_eq) -> None: sql = """SELECT @STAR(foo) FROM foo""" expected_sql = """SELECT [foo].[a] AS [a] FROM foo""" - schema = MappingSchema({"foo": {"a": exp.DataType.build("UNKNOWN")}}, dialect="tsql") + schema = MappingSchema( + {"foo": {"a": exp.DataType.build("UNKNOWN")}}, dialect="tsql") evaluator = 
MacroEvaluator(schema=schema, dialect="tsql") - assert_exp_eq(evaluator.transform(parse_one(sql, read="tsql")), expected_sql, dialect="tsql") + assert_exp_eq(evaluator.transform(parse_one(sql, read="tsql")), + expected_sql, dialect="tsql") def test_case(macro_evaluator: MacroEvaluator) -> None: @@ -239,7 +245,8 @@ def test_macro_var(macro_evaluator): macro_evaluator.dialect = "snowflake" assert e.find(StagedFilePath) is not None - assert macro_evaluator.transform(e).sql(dialect="snowflake") == "SELECT a FROM @path, t2" + assert macro_evaluator.transform(e).sql( + dialect="snowflake") == "SELECT a FROM @path, t2" # Referencing a var that doesn't exist in the evaluator's scope should raise macro_evaluator.locals = {} @@ -250,17 +257,21 @@ def test_macro_var(macro_evaluator): assert "Macro variable 'y' is undefined." in str(ex.value) # Parsing a "parameter" like Snowflake's $1 should not produce a MacroVar expression - e = parse_one("select $1 from @path (file_format => bla.foo)", read="snowflake") + e = parse_one("select $1 from @path (file_format => bla.foo)", + read="snowflake") assert e.find(exp.Parameter) is e.selects[0] assert e.find(StagedFilePath) # test no space - e = parse_one("select $1 from @path(file_format => bla.foo)", read="snowflake") + e = parse_one("select $1 from @path(file_format => bla.foo)", + read="snowflake") assert e.find(StagedFilePath) - assert e.sql(dialect="snowflake") == "SELECT $1 FROM @path (FILE_FORMAT => bla.foo)" + assert e.sql( + dialect="snowflake") == "SELECT $1 FROM @path (FILE_FORMAT => bla.foo)" macro_evaluator.locals = {"x": 1} macro_evaluator.dialect = "snowflake" - e = parse_one("COPY INTO @'s3://example/foo_@{x}.csv' FROM a.b.c", read="snowflake") + e = parse_one( + "COPY INTO @'s3://example/foo_@{x}.csv' FROM a.b.c", read="snowflake") assert ( macro_evaluator.transform(e).sql(dialect="snowflake") == "COPY INTO 's3://example/foo_1.csv' FROM a.b.c" @@ -274,7 +285,8 @@ def test_macro_str_replace(macro_evaluator): def 
test_macro_custom(macro_evaluator, assert_exp_eq): - assert_exp_eq(macro_evaluator.transform(parse_one("SELECT @TEST()")), "SELECT test") + assert_exp_eq(macro_evaluator.transform( + parse_one("SELECT @TEST()")), "SELECT test") def test_ast_correctness(macro_evaluator): @@ -376,7 +388,8 @@ def test_ast_correctness(macro_evaluator): ("SELECT @REDUCE([1], (x, y) -> x + y)", "SELECT 1", {}), ("SELECT @REDUCE([1, 2], (x, y) -> x + y)", "SELECT 1 + 2", {}), ("SELECT @REDUCE([[1]], (x, y) -> x + y)", "SELECT ARRAY(1)", {}), - ("SELECT @REDUCE([[1, 2]], (x, y) -> x + y)", "SELECT ARRAY(1, 2)", {}), + ("SELECT @REDUCE([[1, 2]], (x, y) -> x + y)", + "SELECT ARRAY(1, 2)", {}), ( """select @EACH([a, b, c], x -> column like x AS @SQL('@{x}_y', 'Identifier')), @x""", "SELECT column LIKE a AS a_y, column LIKE b AS b_y, column LIKE c AS c_y, '3'", @@ -384,13 +397,15 @@ def test_ast_correctness(macro_evaluator): ), ("SELECT @EACH([1], a -> [@a])", "SELECT ARRAY(1)", {}), ("SELECT @EACH([1, 2], a -> [@a])", "SELECT ARRAY(1), ARRAY(2)", {}), - ("SELECT @REDUCE(@EACH([1], a -> [@a]), (x, y) -> x + y)", "SELECT ARRAY(1)", {}), + ("SELECT @REDUCE(@EACH([1], a -> [@a]), (x, y) -> x + y)", + "SELECT ARRAY(1)", {}), ( "SELECT @REDUCE(@EACH([1, 2], a -> [@a]), (x, y) -> x + y)", "SELECT ARRAY(1) + ARRAY(2)", {}, ), - ("SELECT @REDUCE([[1],[2]], (x, y) -> x + y)", "SELECT ARRAY(1) + ARRAY(2)", {}), + ("SELECT @REDUCE([[1],[2]], (x, y) -> x + y)", + "SELECT ARRAY(1) + ARRAY(2)", {}), ( """@WITH(@do_with) all_cities as (select * from city) select all_cities""", "WITH all_cities AS (SELECT * FROM city) SELECT all_cities", @@ -641,27 +656,35 @@ def test_macro_coercion(macro_evaluator: MacroEvaluator, assert_exp_eq): assert coerce(exp.Literal.number(1.1), float) == 1.1 assert coerce(exp.Literal.string("Hi mom"), str) == "Hi mom" assert coerce(exp.true(), bool) is True - assert coerce(exp.Literal.string("2020-01-01"), datetime) == to_datetime("2020-01-01") - assert 
coerce(exp.Literal.string("2020-01-01"), date) == to_date("2020-01-01") + assert coerce(exp.Literal.string("2020-01-01"), + datetime) == to_datetime("2020-01-01") + assert coerce(exp.Literal.string("2020-01-01"), + date) == to_date("2020-01-01") # Coercing a string literal to a column should return a column with the same name - assert_exp_eq(coerce(exp.Literal.string("order"), exp.Column), exp.column("order")) + assert_exp_eq(coerce(exp.Literal.string("order"), + exp.Column), exp.column("order")) # Not possible to coerce this string literal Cast to an exp.Column node -- so it should just return the input assert_exp_eq( - coerce(exp.Literal.string("order::date"), exp.Column), exp.Literal.string("order::date") + coerce(exp.Literal.string("order::date"), + exp.Column), exp.Literal.string("order::date") ) # This however, is correctly coercible since it's a cast assert_exp_eq( - coerce(exp.Literal.string("order::date"), exp.Cast), exp.cast(exp.column("order"), "DATE") + coerce(exp.Literal.string("order::date"), exp.Cast), exp.cast( + exp.column("order"), "DATE") ) # Here we resolve ambiguity via the user type hint - assert_exp_eq(coerce(exp.Literal.string("order"), exp.Identifier), exp.to_identifier("order")) - assert_exp_eq(coerce(exp.Literal.string("order"), exp.Table), exp.table_("order")) + assert_exp_eq(coerce(exp.Literal.string("order"), + exp.Identifier), exp.to_identifier("order")) + assert_exp_eq(coerce(exp.Literal.string("order"), + exp.Table), exp.table_("order")) # Resolve a union type hint by choosing the first one that works assert_exp_eq( - coerce(exp.Literal.string("order::date"), t.Union[exp.Column, exp.Cast]), + coerce(exp.Literal.string("order::date"), + t.Union[exp.Column, exp.Cast]), exp.cast(exp.column("order"), "DATE"), ) @@ -673,7 +696,8 @@ def test_macro_coercion(macro_evaluator: MacroEvaluator, assert_exp_eq): # From a string literal to a Select should parse the string literal, and the inverse operation works as well assert_exp_eq( - 
coerce(exp.Literal.string("SELECT 1 FROM a"), exp.Select), parse_one("SELECT 1 FROM a") + coerce(exp.Literal.string("SELECT 1 FROM a"), + exp.Select), parse_one("SELECT 1 FROM a") ) assert coerce(parse_one("SELECT 1 FROM a"), SQL) == "SELECT 1 FROM a" @@ -686,16 +710,21 @@ def test_macro_coercion(macro_evaluator: MacroEvaluator, assert_exp_eq): # Generics work as well, recursively resolving inner types assert coerce(parse_one("[1, 2, 3]"), t.List[int]) == [1, 2, 3] - assert coerce(parse_one("[1, 2, 3]"), t.Tuple[int, int, float]) == (1, 2, 3.0) + assert coerce(parse_one("[1, 2, 3]"), + t.Tuple[int, int, float]) == (1, 2, 3.0) assert coerce(parse_one("[1, 2, 3]"), t.Tuple[int, ...]) == (1, 2, 3) - assert coerce(parse_one("[1, 2, 3]"), t.Tuple[int, str, float]) == (1, "2", 3.0) + assert coerce(parse_one("[1, 2, 3]"), + t.Tuple[int, str, float]) == (1, "2", 3.0) assert coerce( - parse_one("[1, 2, [3]]"), t.Tuple[int, str, t.Union[float, t.Tuple[float, ...]]] + parse_one("[1, 2, [3]]"), t.Tuple[int, str, + t.Union[float, t.Tuple[float, ...]]] ) == (1, "2", (3.0,)) # Using exp.Expression will always return the input expression - assert coerce(parse_one("order", into=exp.Column), exp.Expression) == exp.column("order") - assert coerce(exp.Literal.string("OK"), exp.Expression) == exp.Literal.string("OK") + assert coerce(parse_one("order", into=exp.Column), + exp.Expression) == exp.column("order") + assert coerce(exp.Literal.string("OK"), + exp.Expression) == exp.Literal.string("OK") # Strict flag allows raising errors and is used when recursively coercing expressions # otherwise, in general, we want to be lenient and just warn the user when something is not possible @@ -739,7 +768,8 @@ def test_macro_parameter_resolution(macro_evaluator): MacroEvalError, match=".*'pos_only' parameter is positional only, but was passed as a keyword|.*missing a required positional-only argument: 'pos_only'|.*missing a required argument: 'a1'", ): - 
macro_evaluator.evaluate(parse_one("@test_arg_resolution(pos_only := 1)")) + macro_evaluator.evaluate( + parse_one("@test_arg_resolution(pos_only := 1)")) with pytest.raises(MacroEvalError, match=".*too many positional arguments"): macro_evaluator.evaluate(parse_one("@test_arg_resolution(1, 2, 3)")) @@ -771,7 +801,8 @@ def test_macro_first_value_ignore_respect_nulls(assert_exp_eq) -> None: actual_expr = d.parse_one( "SELECT FIRST_VALUE(@test(x) IGNORE NULLS) OVER (ORDER BY y) AS column_test" ) - assert_exp_eq(evaluator.transform(actual_expr), expected_sql, dialect="duckdb") + assert_exp_eq(evaluator.transform(actual_expr), + expected_sql, dialect="duckdb") expected_sql = ( "SELECT FIRST_VALUE(x RESPECT NULLS) OVER (ORDER BY y NULLS FIRST) AS column_test" @@ -779,7 +810,8 @@ def test_macro_first_value_ignore_respect_nulls(assert_exp_eq) -> None: actual_expr = d.parse_one( "SELECT FIRST_VALUE(@test(x) RESPECT NULLS) OVER (ORDER BY y) AS column_test" ) - assert_exp_eq(evaluator.transform(actual_expr), expected_sql, dialect="duckdb") + assert_exp_eq(evaluator.transform(actual_expr), + expected_sql, dialect="duckdb") DEDUPLICATE_SQL = """ @@ -844,7 +876,8 @@ def test_macro_first_value_ignore_respect_nulls(assert_exp_eq) -> None: def test_deduplicate(assert_exp_eq, dialect, sql, expected_sql): schema = MappingSchema({}, dialect=dialect) evaluator = MacroEvaluator(schema=schema, dialect=dialect) - assert_exp_eq(evaluator.transform(parse_one(sql)), expected_sql, dialect=dialect) + assert_exp_eq(evaluator.transform(parse_one(sql)), + expected_sql, dialect=dialect) def test_deduplicate_error_handling(macro_evaluator): @@ -853,21 +886,24 @@ def test_deduplicate_error_handling(macro_evaluator): SQLMeshError, match="partition_by must be a list of columns: \\[, cast\\( as \\)\\]", ): - macro_evaluator.evaluate(parse_one("@deduplicate(my_table, user_id, ['timestamp DESC'])")) + macro_evaluator.evaluate( + parse_one("@deduplicate(my_table, user_id, ['timestamp DESC'])")) # Test 
error handling: non-list order_by with pytest.raises( SQLMeshError, match="order_by must be a list of strings, optional - nulls ordering: \\[' nulls '\\]", ): - macro_evaluator.evaluate(parse_one("@deduplicate(my_table, [user_id], 'timestamp DESC')")) + macro_evaluator.evaluate( + parse_one("@deduplicate(my_table, [user_id], 'timestamp DESC')")) # Test error handling: empty order_by with pytest.raises( SQLMeshError, match="order_by must be a list of strings, optional - nulls ordering: \\[' nulls '\\]", ): - macro_evaluator.evaluate(parse_one("@deduplicate(my_table, [user_id], [])")) + macro_evaluator.evaluate( + parse_one("@deduplicate(my_table, [user_id], [])")) @pytest.mark.parametrize( @@ -1019,7 +1055,8 @@ def test_date_spine(assert_exp_eq, dialect, date_part): FROM _generated_dates ) AS _generated_dates """ - assert_exp_eq(evaluator.transform(parse_one(date_spine_macro)), expected_sql, dialect=dialect) + assert_exp_eq(evaluator.transform(parse_one(date_spine_macro)), + expected_sql, dialect=dialect) def test_date_spine_error_handling(macro_evaluator): @@ -1028,28 +1065,32 @@ def test_date_spine_error_handling(macro_evaluator): MacroEvalError, match=".*Invalid datepart 'invalid'. 
Expected: 'day', 'week', 'month', 'quarter', or 'year'", ): - macro_evaluator.evaluate(parse_one("@date_spine('invalid', '2022-01-01', '2024-12-31')")) + macro_evaluator.evaluate( + parse_one("@date_spine('invalid', '2022-01-01', '2024-12-31')")) # Test error handling: invalid start_date format with pytest.raises( MacroEvalError, match=".*Invalid date format - start_date and end_date must be in format: YYYY-MM-DD", ): - macro_evaluator.evaluate(parse_one("@date_spine('day', '2022/01/01', '2024-12-31')")) + macro_evaluator.evaluate( + parse_one("@date_spine('day', '2022/01/01', '2024-12-31')")) # Test error handling: invalid end_date format with pytest.raises( MacroEvalError, match=".*Invalid date format - start_date and end_date must be in format: YYYY-MM-DD", ): - macro_evaluator.evaluate(parse_one("@date_spine('day', '2022-01-01', '2024/12/31')")) + macro_evaluator.evaluate( + parse_one("@date_spine('day', '2022-01-01', '2024/12/31')")) # Test error handling: start_date after end_date with pytest.raises( MacroEvalError, match=".*Invalid date range - start_date '2024-12-31' is after end_date '2022-01-01'.", ): - macro_evaluator.evaluate(parse_one("@date_spine('day', '2024-12-31', '2022-01-01')")) + macro_evaluator.evaluate( + parse_one("@date_spine('day', '2024-12-31', '2022-01-01')")) def test_macro_union(assert_exp_eq, macro_evaluator: MacroEvaluator): @@ -1079,7 +1120,8 @@ def test_resolve_template_literal(): evaluator.transform(parsed_sql) evaluator.locals.update( - {"this_model": exp.to_table("test_catalog.sqlmesh__test.test__test_model__2517971505")} + {"this_model": exp.to_table( + "test_catalog.sqlmesh__test.test__test_model__2517971505")} ) assert ( @@ -1090,7 +1132,8 @@ def test_resolve_template_literal(): # Evaluating evaluator = MacroEvaluator(runtime_stage=RuntimeStage.EVALUATING) evaluator.locals.update( - {"this_model": exp.to_table("test_catalog.sqlmesh__test.test__test_model__2517971505")} + {"this_model": exp.to_table( + 
"test_catalog.sqlmesh__test.test__test_model__2517971505")} ) assert ( evaluator.transform(parsed_sql).sql() @@ -1105,7 +1148,8 @@ def test_resolve_template_table(): evaluator = MacroEvaluator(runtime_stage=RuntimeStage.CREATING) evaluator.locals.update( - {"this_model": exp.to_table("test_catalog.sqlmesh__test.test__test_model__2517971505")} + {"this_model": exp.to_table( + "test_catalog.sqlmesh__test.test__test_model__2517971505")} ) assert ( @@ -1166,3 +1210,231 @@ def test_macro_coerce_literal_type(macro_evaluator): expression = d.parse_one("@TEST_LITERAL_TYPE(1.0)") with pytest.raises(MacroEvalError, match=".*Coercion failed"): macro_evaluator.transform(expression) + + +def test_lazy_macro_loading_with_missing_dependency(): + """Test that MacroEvaluator can be created even when python_env contains + imports with missing dependencies, as long as those macros aren't used.""" + from sqlmesh.utils.metaprogramming import ExecutableKind + + # Create python_env with a macro that imports a non-existent module + python_env = { + "missing_module_macro": Executable( + payload="from nonexistent_module import helper", + kind=ExecutableKind.IMPORT, + ), + "valid_value": Executable.value(42), + } + + # Should not raise during initialization + evaluator = MacroEvaluator(python_env=python_env) + + # Valid values should be accessible + assert evaluator.locals["valid_value"] == 42 + + # Trying to use the macro with missing dependency should raise helpful error + with pytest.raises(MacroEvalError) as exc_info: + evaluator.send("missing_module_macro") + + error_msg = str(exc_info.value).lower() + assert "missing dependency" in error_msg + assert "another project" in error_msg + + +def test_lazy_macro_loading_success(): + """Test that macros are successfully loaded when dependencies are available.""" + from sqlmesh.utils.metaprogramming import ExecutableKind + + python_env = { + "math_module": Executable( + payload="import math", + kind=ExecutableKind.IMPORT, + ), + } + + evaluator 
= MacroEvaluator(python_env=python_env) + + assert "math_module" in evaluator._unloaded_executables + + assert evaluator._ensure_executable_loaded("math_module") + + assert "math" in evaluator.env + assert "math_module" not in evaluator._unloaded_executables + + +def test_lazy_macro_loading_definition_with_dependency(): + """Test that macro definitions work even when their imports are deferred.""" + from sqlmesh.utils.metaprogramming import ExecutableKind + + python_env = { + "bad_import": Executable( + payload="from fake_package import fake_func", + kind=ExecutableKind.IMPORT, + ), + "macro_using_bad_import": Executable( + payload=""" +def my_macro(evaluator): + return fake_func() +""", + kind=ExecutableKind.DEFINITION, + name="my_macro", + ), + } + + # Should initialize without error (import is deferred) + evaluator = MacroEvaluator(python_env=python_env) + + # The macro definition should be loaded immediately + assert normalize_macro_name("my_macro") in evaluator.macros + + # The bad import should be deferred + assert "bad_import" in evaluator._unloaded_executables + + # Calling the macro should fail because bad_import can't be loaded + # When the macro is called, we try to load all imports first + with pytest.raises(MacroEvalError): # type: ignore + evaluator.send("my_macro") + + +def test_mixed_python_env_with_partial_loading(): + """Test environment with mix of loadable and unloadable executables.""" + from sqlmesh.utils.metaprogramming import ExecutableKind + + python_env = { + "good_import": Executable( + payload="import json", + kind=ExecutableKind.IMPORT, + ), + "bad_import": Executable( + payload="from fake_package import fake_func", + kind=ExecutableKind.IMPORT, + ), + "value": Executable.value({"key": "value"}), + } + + # Should initialize without error + evaluator = MacroEvaluator(python_env=python_env) + + # Good import should be deferred (not loaded yet) + assert "good_import" in evaluator._unloaded_executables + + # Bad import should also be 
deferred + assert "bad_import" in evaluator._unloaded_executables + + # Values should be loaded immediately + assert evaluator.locals["value"] == {"key": "value"} + + # Can successfully load good import on demand + assert evaluator._ensure_executable_loaded("good_import") + assert "json" in evaluator.env + + # Cannot load bad import + assert not evaluator._ensure_executable_loaded("bad_import") + assert "bad_import" in evaluator._failed_imports + + +def test_lazy_loading_with_macro_decorator(): + """Test that @macro() decorated functions are loaded correctly with lazy loading.""" + from sqlmesh.utils.metaprogramming import ExecutableKind + + # Simulate a macro definition that would be in python_env + python_env = { + "test_lazy_macro": Executable( + payload=""" +def test_lazy_macro(evaluator): + return 'lazy_loaded' +""", + kind=ExecutableKind.DEFINITION, + name="test_lazy_macro", + ), + } + + evaluator = MacroEvaluator(python_env=python_env) + + # The macro should have been loaded successfully + assert "test_lazy_macro" in evaluator.env + # And registered as a macro + assert normalize_macro_name("test_lazy_macro") in evaluator.macros + + +def test_lazy_loading_error_is_cached(): + """Test that failed imports are cached to avoid repeated attempts.""" + from sqlmesh.utils.metaprogramming import ExecutableKind + + python_env = { + "failing_import": Executable( + payload="from nonexistent_package import something", + kind=ExecutableKind.IMPORT, + ), + } + + evaluator = MacroEvaluator(python_env=python_env) + + # Try to load it (will fail) + assert not evaluator._ensure_executable_loaded("failing_import") + + # Should be in failed imports + assert "failing_import" in evaluator._failed_imports + + # Second attempt should use cached failure (not try to load again) + assert not evaluator._ensure_executable_loaded("failing_import") + + +def test_lazy_loading_preserves_values(): + """Test that SQLMESH_VARS and other special values are loaded correctly.""" + from 
sqlmesh.utils.metaprogramming import ExecutableKind, SqlValue
+
+    python_env = {
+        c.SQLMESH_VARS: Executable.value({
+            "my_var": "test_value",
+            "sql_var": SqlValue(sql="SELECT 1"),
+        }),
+        c.SQLMESH_BLUEPRINT_VARS: Executable.value({
+            "blueprint_var": "blueprint_value",
+        }),
+    }
+
+    evaluator = MacroEvaluator(python_env=python_env)
+
+    # Special variables should be loaded and processed
+    assert c.SQLMESH_VARS in evaluator.locals
+    assert evaluator.locals[c.SQLMESH_VARS]["my_var"] == "test_value"
+
+    # SQL values should be parsed
+    assert isinstance(
+        evaluator.locals[c.SQLMESH_VARS]["sql_var"], exp.Expression)
+
+    assert c.SQLMESH_BLUEPRINT_VARS in evaluator.locals
+    assert evaluator.locals[c.SQLMESH_BLUEPRINT_VARS]["blueprint_var"] == "blueprint_value"
+
+
+def test_macro_evaluator_backward_compatibility():
+    """Ensure existing macro behavior is unchanged with lazy loading."""
+    from sqlmesh.utils.metaprogramming import ExecutableKind
+
+    # Test with traditional python_env (no lazy loading needed)
+    python_env = {
+        "var": Executable.value(100),
+        "simple_macro": Executable(
+            payload="""
+def simple_macro(evaluator):
+    return evaluator.locals.get('var', 0) * 2
+""",
+            kind=ExecutableKind.DEFINITION,
+            name="simple_macro",
+        ),
+    }
+
+    evaluator = MacroEvaluator(python_env=python_env)
+
+    # Macro should be loaded and available
+    assert normalize_macro_name("simple_macro") in evaluator.macros
+
+    # Should work exactly as before
+    result = evaluator.send("simple_macro")
+    assert result == 200
+
+
+def normalize_macro_name(name: str) -> str:
+    """Return the registry key MacroEvaluator uses for a macro: '@' + upper-cased name."""
+    return f"@{name.upper()}"