From ef4dd5edc57581d8f2f55ddfacfbfe5208b69b6e Mon Sep 17 00:00:00 2001 From: Matthieu Gomez Date: Tue, 20 Jan 2026 15:26:17 -0500 Subject: [PATCH 01/10] Refactor prediction logic for fixed effects model Allow only fixed effects in predict --- src/FixedEffectModel.jl | 18 ++++++++++++------ 1 file changed, 12 insertions(+), 6 deletions(-) diff --git a/src/FixedEffectModel.jl b/src/FixedEffectModel.jl index dd87f92..c1b2433 100644 --- a/src/FixedEffectModel.jl +++ b/src/FixedEffectModel.jl @@ -138,14 +138,20 @@ function StatsAPI.predict(m::FixedEffectModel, data) has_cont_fe_interaction(m.formula) && throw(ArgumentError("Interaction of fixed effect and continuous variable detected in formula; this is currently not supported in `predict`")) + # only fixed effects cdata = StatsModels.columntable(data) - cols, nonmissings = StatsModels.missing_omit(cdata, m.formula_schema.rhs) - Xnew = modelmatrix(m.formula_schema, cols) - if all(nonmissings) - out = Xnew * m.coef + nrows = length(Tables.rows(cdata)) + if m.formula_schema.rhs == m.formula_schema.rhs == MatrixTerm((InterceptTerm{false}(),)) + out = zeros(Float64, nrows) else - out = Vector{Union{Float64, Missing}}(missing, length(Tables.rows(cdata))) - out[nonmissings] = Xnew * m.coef + cols, nonmissings = StatsModels.missing_omit(cdata, m.formula_schema.rhs) + Xnew = modelmatrix(m.formula_schema, cols) + if all(nonmissings) + out = Xnew * m.coef + else + out = Vector{Union{Float64, Missing}}(missing, nrows) + out[nonmissings] = Xnew * m.coef + end end # Join FE estimates onto data and sum row-wise From a845746dd3177ec19534ef6b4960338a38734c23 Mon Sep 17 00:00:00 2001 From: Matthieu Gomez Date: Tue, 20 Jan 2026 15:28:23 -0500 Subject: [PATCH 02/10] Enhance predict.jl with fixed effects tests Add tests for fixed effects prediction in predict.jl --- test/predict.jl | 7 +++++++ 1 file changed, 7 insertions(+) diff --git a/test/predict.jl b/test/predict.jl index 82bc1bf..076738a 100644 --- a/test/predict.jl +++ b/test/predict.jl @@ -162,6 +162,13 @@ end 4.0 .* (df.g4 .== "h") .* df.z m = reg(df, @formula(y ~ x + fe(g1) + fe(g2)&fe(g3) + fe(g4)&z)) @test_throws ArgumentError pred = predict(m, df) + + + # only fixed effects + f = DataFrame(y=rand(10), id = rand(1:2, 10), t = rand(1:2, 10)) + out1 = predict(reg(df, @formula(y ~ fe(id) + fe(t)), save = :fe), df) + out2 = predict(reg(df, @formula(y ~ 1 + fe(id) + fe(t)), save = :fe), df) + @test out1 .≈ out2 end @testset "Continuous/FE detection" begin From 8dadd3f8772681ea8ad872b19aac514eff62cf8d Mon Sep 17 00:00:00 2001 From: Matthieu Gomez Date: Tue, 20 Jan 2026 15:29:59 -0500 Subject: [PATCH 03/10] Fix condition check for formula schema in FixedEffectModel --- src/FixedEffectModel.jl | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/src/FixedEffectModel.jl b/src/FixedEffectModel.jl index c1b2433..4e9346f 100644 --- a/src/FixedEffectModel.jl +++ b/src/FixedEffectModel.jl @@ -141,7 +141,7 @@ function StatsAPI.predict(m::FixedEffectModel, data) # only fixed effects cdata = StatsModels.columntable(data) nrows = length(Tables.rows(cdata)) - if m.formula_schema.rhs == m.formula_schema.rhs == MatrixTerm((InterceptTerm{false}(),)) + if m.formula_schema.rhs == MatrixTerm((InterceptTerm{false}(),)) out = zeros(Float64, nrows) else cols, nonmissings = StatsModels.missing_omit(cdata, m.formula_schema.rhs) From 0534f2a0ecdd7e25bb21e94547426a5808cab7a7 Mon Sep 17 00:00:00 2001 From: Matthieu Gomez Date: Tue, 20 Jan 2026 15:38:07 -0500 Subject: [PATCH 04/10] Improve save argument handling for fixed effects Add error handling for save keyword argument when no fixed effects are present. --- src/fit.jl | 6 +++++- 1 file changed, 5 insertions(+), 1 deletion(-) diff --git a/src/fit.jl b/src/fit.jl index 9ad567f..ac683a5 100644 --- a/src/fit.jl +++ b/src/fit.jl @@ -124,7 +124,11 @@ function StatsAPI.fit(::Type{FixedEffectModel}, has_iv = formula_iv != FormulaTerm(ConstantTerm(0), ConstantTerm(0)) formula, formula_fes = parse_fe(formula) has_fes = formula_fes != FormulaTerm(ConstantTerm(0), ConstantTerm(0)) - save_fes = (save == :fe) | ((save == :all) & has_fes) + if save == :fe && !has_fes + throw("the save keyword argument is set to :fe but there are no fixed effects in the formula.") + end + + save_fes = save ∈ (:fe, :all) && has_fes has_weights = weights !== nothing # Compute feM, an AbstractFixedEffectSolver From 9aded469bf46b1665e8db9454435d0946b9dacb0 Mon Sep 17 00:00:00 2001 From: Matthieu Gomez Date: Tue, 20 Jan 2026 15:39:24 -0500 Subject: [PATCH 05/10] Refactor fixed effects saving condition Remove error throw for missing fixed effects when saving. --- src/fit.jl | 5 +---- 1 file changed, 1 insertion(+), 4 deletions(-) diff --git a/src/fit.jl b/src/fit.jl index ac683a5..b4d6dad 100644 --- a/src/fit.jl +++ b/src/fit.jl @@ -124,10 +124,7 @@ function StatsAPI.fit(::Type{FixedEffectModel}, has_iv = formula_iv != FormulaTerm(ConstantTerm(0), ConstantTerm(0)) formula, formula_fes = parse_fe(formula) has_fes = formula_fes != FormulaTerm(ConstantTerm(0), ConstantTerm(0)) - if save == :fe && !has_fes - throw("the save keyword argument is set to :fe but there are no fixed effects in the formula.") - end - + # when save = :fe but there are no fixed effects in the formula, don't save fixed effects save_fes = save ∈ (:fe, :all) && has_fes has_weights = weights !== nothing From 7ad7bb978ff0837b1ead8f89363f241b12b09204 Mon Sep 17 00:00:00 2001 From: Matthieu Gomez Date: Tue, 20 Jan 2026 15:41:41 -0500 Subject: [PATCH 06/10] Fix variable name for DataFrame in test case --- test/predict.jl | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/test/predict.jl b/test/predict.jl index 076738a..2365a37 100644 --- a/test/predict.jl +++ b/test/predict.jl @@ -165,7 +165,7 @@ end # only fixed effects - f = DataFrame(y=rand(10), id = rand(1:2, 10), t = rand(1:2, 10)) + df = DataFrame(y=rand(10), id = rand(1:2, 10), t = rand(1:2, 10)) out1 = predict(reg(df, @formula(y ~ fe(id) + fe(t)), save = :fe), df) out2 = predict(reg(df, @formula(y ~ 1 + fe(id) + fe(t)), save = :fe), df) @test out1 .≈ out2 From e8d4697741c0e25736a05888f52124561a740df0 Mon Sep 17 00:00:00 2001 From: Matthieu Gomez Date: Tue, 20 Jan 2026 15:55:21 -0500 Subject: [PATCH 07/10] Update type check for interaction terms in formula --- src/FixedEffectModel.jl | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/src/FixedEffectModel.jl b/src/FixedEffectModel.jl index 4e9346f..c0a78cc 100644 --- a/src/FixedEffectModel.jl +++ b/src/FixedEffectModel.jl @@ -122,7 +122,7 @@ end # Does the formula have InteractionTerms? function has_cont_fe_interaction(x::FormulaTerm) - if x.rhs isa Term # only one term + if x.rhs isa AbstractTerm # only one term is_cont_fe_int(x) elseif hasfield(typeof(x.rhs), :lhs) # Is an IV term false # Is this correct? From 2d7ed1a93bb1b0d7d371bf54a724ca8e4a025bd2 Mon Sep 17 00:00:00 2001 From: Matthieu Gomez Date: Tue, 20 Jan 2026 16:00:12 -0500 Subject: [PATCH 08/10] Add nonmissings initialization for FixedEffectModel --- src/FixedEffectModel.jl | 1 + 1 file changed, 1 insertion(+) diff --git a/src/FixedEffectModel.jl b/src/FixedEffectModel.jl index c0a78cc..b1d57a2 100644 --- a/src/FixedEffectModel.jl +++ b/src/FixedEffectModel.jl @@ -143,6 +143,7 @@ function StatsAPI.predict(m::FixedEffectModel, data) nrows = length(Tables.rows(cdata)) if m.formula_schema.rhs == MatrixTerm((InterceptTerm{false}(),)) out = zeros(Float64, nrows) + nonmissings = trues(nrows) else cols, nonmissings = StatsModels.missing_omit(cdata, m.formula_schema.rhs) Xnew = modelmatrix(m.formula_schema, cols) From f2417ebd843f0f9a7c103f892fdfedbd12ff00f8 Mon Sep 17 00:00:00 2001 From: Matthieu Gomez Date: Tue, 20 Jan 2026 16:09:02 -0500 Subject: [PATCH 09/10] Fix error message in StatsAPI.predict function Corrected error message formatting in predict function. --- src/FixedEffectModel.jl | 3 ++- 1 file changed, 2 insertions(+), 1 deletion(-) diff --git a/src/FixedEffectModel.jl b/src/FixedEffectModel.jl index b1d57a2..0eb67f9 100644 --- a/src/FixedEffectModel.jl +++ b/src/FixedEffectModel.jl @@ -133,7 +133,7 @@ end function StatsAPI.predict(m::FixedEffectModel, data) Tables.istable(data) || - throw(ArgumentError("expected second argument to be a Table, got $(typeof(data))")) + throw(ArgumentError("Expected second argument to be a Table, got $(typeof(data))")) has_cont_fe_interaction(m.formula) && throw(ArgumentError("Interaction of fixed effect and continuous variable detected in formula; this is currently not supported in `predict`")) @@ -142,6 +142,7 @@ function StatsAPI.predict(m::FixedEffectModel, data) cdata = StatsModels.columntable(data) nrows = length(Tables.rows(cdata)) if m.formula_schema.rhs == MatrixTerm((InterceptTerm{false}(),)) + has_fe(m) || throw(ArgumentError("To be used with predict, a model requires regressors or fixed effects")) out = zeros(Float64, nrows) nonmissings = trues(nrows) else From 147bcc1255b0de400eae11ad8180c6f2dee7c3a5 Mon Sep 17 00:00:00 2001 From: Matthieu Gomez Date: Tue, 20 Jan 2026 16:09:55 -0500 Subject: [PATCH 10/10] Change test condition to use 'all' for comparison --- test/predict.jl | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/test/predict.jl b/test/predict.jl index 2365a37..f1a85f7 100644 --- a/test/predict.jl +++ b/test/predict.jl @@ -168,7 +168,7 @@ end df = DataFrame(y=rand(10), id = rand(1:2, 10), t = rand(1:2, 10)) out1 = predict(reg(df, @formula(y ~ fe(id) + fe(t)), save = :fe), df) out2 = predict(reg(df, @formula(y ~ 1 + fe(id) + fe(t)), save = :fe), df) - @test out1 .≈ out2 + @test all(out1 .≈ out2) end @testset "Continuous/FE detection" begin