
Commit bf6a5eb

fix(huggingface): Helper logic for init_chat_model with HuggingFace backend (#34259)
1 parent 5720dea commit bf6a5eb

File tree

3 files changed: +53 -5 lines changed

  • libs/langchain_v1/langchain/chat_models/base.py
  • libs/langchain/langchain_classic/chat_models/base.py
  • libs/partners/huggingface/langchain_huggingface/chat_models/huggingface.py

libs/langchain/langchain_classic/chat_models/base.py

Lines changed: 2 additions & 3 deletions

@@ -444,10 +444,9 @@ def _init_chat_model_helper(
     if model_provider == "huggingface":
         _check_pkg("langchain_huggingface")
-        from langchain_huggingface import ChatHuggingFace, HuggingFacePipeline
+        from langchain_huggingface import ChatHuggingFace

-        llm = HuggingFacePipeline.from_model_id(model_id=model, **kwargs)
-        return ChatHuggingFace(llm=llm)
+        return ChatHuggingFace.from_model_id(model_id=model, **kwargs)

     if model_provider == "groq":
         _check_pkg("langchain_groq")

libs/langchain_v1/langchain/chat_models/base.py

Lines changed: 1 addition & 1 deletion

@@ -405,7 +405,7 @@ def _init_chat_model_helper(
         _check_pkg("langchain_huggingface")
         from langchain_huggingface import ChatHuggingFace

-        return ChatHuggingFace(model_id=model, **kwargs)
+        return ChatHuggingFace.from_model_id(model_id=model, **kwargs)
     if model_provider == "groq":
         _check_pkg("langchain_groq")
         from langchain_groq import ChatGroq
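The v1 helper previously passed `model_id` straight to the `ChatHuggingFace` constructor, which expects a pre-built `llm` rather than a bare model ID. The new classmethod encapsulates the manual construction that callers otherwise needed, sketched here for the default pipeline backend (model ID illustrative):

```python
from langchain_huggingface import ChatHuggingFace, HuggingFacePipeline

# What from_model_id does for backend="pipeline": build the local
# transformers pipeline first, then wrap it in the chat model.
llm = HuggingFacePipeline.from_model_id(
    model_id="HuggingFaceH4/zephyr-7b-beta",  # illustrative
    task="text-generation",
)
chat = ChatHuggingFace(llm=llm)
```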

libs/partners/huggingface/langchain_huggingface/chat_models/huggingface.py

Lines changed: 50 additions & 1 deletion

@@ -7,7 +7,11 @@
 from collections.abc import AsyncIterator, Callable, Iterator, Mapping, Sequence
 from dataclasses import dataclass
 from operator import itemgetter
-from typing import Any, Literal, cast
+from typing import TYPE_CHECKING, Any, Literal, cast
+
+if TYPE_CHECKING:
+    from langchain_huggingface.llms.huggingface_endpoint import HuggingFaceEndpoint
+    from langchain_huggingface.llms.huggingface_pipeline import HuggingFacePipeline

 from langchain_core.callbacks.manager import (
     AsyncCallbackManagerForLLMRun,
@@ -599,6 +603,51 @@ def _set_model_profile(self) -> Self:
         self.profile = _get_default_model_profile(self.model_id)
         return self

+    @classmethod
+    def from_model_id(
+        cls,
+        model_id: str,
+        task: str | None = None,
+        backend: Literal["pipeline", "endpoint", "text-gen"] = "pipeline",
+        **kwargs: Any,
+    ) -> ChatHuggingFace:
+        """Construct a ChatHuggingFace model from a model_id.
+
+        Args:
+            model_id: The model ID of the Hugging Face model.
+            task: The task to perform (e.g., "text-generation").
+            backend: The backend to use. One of "pipeline", "endpoint", "text-gen".
+            **kwargs: Additional arguments to pass to the backend or ChatHuggingFace.
+        """
+        llm: (
+            Any  # HuggingFacePipeline, HuggingFaceEndpoint, HuggingFaceTextGenInference
+        )
+        if backend == "pipeline":
+            from langchain_huggingface.llms.huggingface_pipeline import (
+                HuggingFacePipeline,
+            )
+
+            llm = HuggingFacePipeline.from_model_id(
+                model_id=model_id, task=cast(str, task), **kwargs
+            )
+        elif backend == "endpoint":
+            from langchain_huggingface.llms.huggingface_endpoint import (
+                HuggingFaceEndpoint,
+            )
+
+            llm = HuggingFaceEndpoint(repo_id=model_id, task=task, **kwargs)
+        elif backend == "text-gen":
+            from langchain_community.llms.huggingface_text_gen_inference import (  # type: ignore[import-not-found]
+                HuggingFaceTextGenInference,
+            )
+
+            llm = HuggingFaceTextGenInference(inference_server_url=model_id, **kwargs)
+        else:
+            msg = f"Unknown backend: {backend}"
+            raise ValueError(msg)
+
+        return cls(llm=llm, **kwargs)
+
     def _create_chat_result(self, response: dict) -> ChatResult:
         generations = []
         token_usage = response.get("usage", {})
