diff --git a/src/FixedEffectModel.jl b/src/FixedEffectModel.jl index dd87f92..0eb67f9 100644 --- a/src/FixedEffectModel.jl +++ b/src/FixedEffectModel.jl @@ -122,7 +122,7 @@ end # Does the formula have InteractionTerms? function has_cont_fe_interaction(x::FormulaTerm) - if x.rhs isa Term # only one term + if x.rhs isa AbstractTerm # only one term is_cont_fe_int(x) elseif hasfield(typeof(x.rhs), :lhs) # Is an IV term false # Is this correct? @@ -133,19 +133,27 @@ end function StatsAPI.predict(m::FixedEffectModel, data) Tables.istable(data) || - throw(ArgumentError("expected second argument to be a Table, got $(typeof(data))")) + throw(ArgumentError("Expected second argument to be a Table, got $(typeof(data))")) has_cont_fe_interaction(m.formula) && throw(ArgumentError("Interaction of fixed effect and continuous variable detected in formula; this is currently not supported in `predict`")) + # only fixed effects cdata = StatsModels.columntable(data) - cols, nonmissings = StatsModels.missing_omit(cdata, m.formula_schema.rhs) - Xnew = modelmatrix(m.formula_schema, cols) - if all(nonmissings) - out = Xnew * m.coef + nrows = length(Tables.rows(cdata)) + if m.formula_schema.rhs == MatrixTerm((InterceptTerm{false}(),)) + has_fe(m) || throw(ArgumentError("To be used with predict, a model requires regressors or fixed effects")) + out = zeros(Float64, nrows) + nonmissings = trues(nrows) else - out = Vector{Union{Float64, Missing}}(missing, length(Tables.rows(cdata))) - out[nonmissings] = Xnew * m.coef + cols, nonmissings = StatsModels.missing_omit(cdata, m.formula_schema.rhs) + Xnew = modelmatrix(m.formula_schema, cols) + if all(nonmissings) + out = Xnew * m.coef + else + out = Vector{Union{Float64, Missing}}(missing, nrows) + out[nonmissings] = Xnew * m.coef + end end # Join FE estimates onto data and sum row-wise diff --git a/src/fit.jl b/src/fit.jl index 9ad567f..b4d6dad 100644 --- a/src/fit.jl +++ b/src/fit.jl @@ -124,7 +124,8 @@ function StatsAPI.fit(::Type{FixedEffectModel}, has_iv = formula_iv != FormulaTerm(ConstantTerm(0), ConstantTerm(0)) formula, formula_fes = parse_fe(formula) has_fes = formula_fes != FormulaTerm(ConstantTerm(0), ConstantTerm(0)) - save_fes = (save == :fe) | ((save == :all) & has_fes) + # when save = :fe but there are no fixed effects in the formula, don't save fixed effects + save_fes = save ∈ (:fe, :all) && has_fes has_weights = weights !== nothing # Compute feM, an AbstractFixedEffectSolver diff --git a/test/predict.jl b/test/predict.jl index 82bc1bf..f1a85f7 100644 --- a/test/predict.jl +++ b/test/predict.jl @@ -162,6 +162,13 @@ end 4.0 .* (df.g4 .== "h") .* df.z m = reg(df, @formula(y ~ x + fe(g1) + fe(g2)&fe(g3) + fe(g4)&z)) @test_throws ArgumentError pred = predict(m, df) + + + # only fixed effects + df = DataFrame(y=rand(10), id = rand(1:2, 10), t = rand(1:2, 10)) + out1 = predict(reg(df, @formula(y ~ fe(id) + fe(t)), save = :fe), df) + out2 = predict(reg(df, @formula(y ~ 1 + fe(id) + fe(t)), save = :fe), df) + @test all(out1 .≈ out2) end @testset "Continuous/FE detection" begin