Skip to content
Merged
24 changes: 16 additions & 8 deletions src/FixedEffectModel.jl
Original file line number Diff line number Diff line change
Expand Up @@ -122,7 +122,7 @@ end

# Does the formula have InteractionTerms?
function has_cont_fe_interaction(x::FormulaTerm)
if x.rhs isa Term # only one term
if x.rhs isa AbstractTerm # only one term
is_cont_fe_int(x)
elseif hasfield(typeof(x.rhs), :lhs) # Is an IV term
false # Is this correct?
Expand All @@ -133,19 +133,27 @@ end

function StatsAPI.predict(m::FixedEffectModel, data)
Tables.istable(data) ||
throw(ArgumentError("expected second argument to be a Table, got $(typeof(data))"))
throw(ArgumentError("Expected second argument to be a Table, got $(typeof(data))"))

has_cont_fe_interaction(m.formula) &&
throw(ArgumentError("Interaction of fixed effect and continuous variable detected in formula; this is currently not supported in `predict`"))

# only fixed effects
cdata = StatsModels.columntable(data)
cols, nonmissings = StatsModels.missing_omit(cdata, m.formula_schema.rhs)
Xnew = modelmatrix(m.formula_schema, cols)
if all(nonmissings)
out = Xnew * m.coef
nrows = length(Tables.rows(cdata))
if m.formula_schema.rhs == MatrixTerm((InterceptTerm{false}(),))
has_fe(m) || throw(ArgumentError("To be used with predict, a model requires regressors or fixed effects"))
out = zeros(Float64, nrows)
nonmissings = trues(nrows)
else
out = Vector{Union{Float64, Missing}}(missing, length(Tables.rows(cdata)))
out[nonmissings] = Xnew * m.coef
cols, nonmissings = StatsModels.missing_omit(cdata, m.formula_schema.rhs)
Xnew = modelmatrix(m.formula_schema, cols)
if all(nonmissings)
out = Xnew * m.coef
else
out = Vector{Union{Float64, Missing}}(missing, nrows)
out[nonmissings] = Xnew * m.coef
end
end

# Join FE estimates onto data and sum row-wise
Expand Down
3 changes: 2 additions & 1 deletion src/fit.jl
Original file line number Diff line number Diff line change
Expand Up @@ -124,7 +124,8 @@ function StatsAPI.fit(::Type{FixedEffectModel},
has_iv = formula_iv != FormulaTerm(ConstantTerm(0), ConstantTerm(0))
formula, formula_fes = parse_fe(formula)
has_fes = formula_fes != FormulaTerm(ConstantTerm(0), ConstantTerm(0))
save_fes = (save == :fe) | ((save == :all) & has_fes)
# when save = :fe but there are no fixed effects in the formula, don't save fixed effects
save_fes = save ∈ (:fe, :all) && has_fes
has_weights = weights !== nothing

# Compute feM, an AbstractFixedEffectSolver
Expand Down
7 changes: 7 additions & 0 deletions test/predict.jl
Original file line number Diff line number Diff line change
Expand Up @@ -162,6 +162,13 @@ end
4.0 .* (df.g4 .== "h") .* df.z
m = reg(df, @formula(y ~ x + fe(g1) + fe(g2)&fe(g3) + fe(g4)&z))
@test_throws ArgumentError pred = predict(m, df)


# only fixed effects
df = DataFrame(y=rand(10), id = rand(1:2, 10), t = rand(1:2, 10))
out1 = predict(reg(df, @formula(y ~ fe(id) + fe(t)), save = :fe), df)
out2 = predict(reg(df, @formula(y ~ 1 + fe(id) + fe(t)), save = :fe), df)
@test all(out1 .≈ out2)
end

@testset "Continuous/FE detection" begin
Expand Down