
Commit 16c984e

fix(langchain-classic): fix init_chat_model for HuggingFace models (#33943)
1 parent: 13dd115
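
This commit makes init_chat_model usable with HuggingFace models. In practice it targets calls like the following, adapted from the new unit test added below (before the fix, such a call could fail with a ValidationError about ChatHuggingFace's missing llm field):

    from langchain.chat_models.base import init_chat_model

    # Generation kwargs such as temperature and max_tokens are now routed into
    # HuggingFacePipeline's pipeline_kwargs (max_tokens becomes max_new_tokens),
    # and "text-generation" is used as the default task.
    llm = init_chat_model(
        model="microsoft/Phi-3-mini-4k-instruct",
        model_provider="huggingface",
        temperature=0,
        max_tokens=1024,
    )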

4 files changed: +163 −26 lines changed

libs/partners/huggingface/langchain_huggingface/chat_models/huggingface.py

Lines changed: 47 additions & 1 deletion
@@ -627,8 +627,54 @@ def from_model_id(
                 HuggingFacePipeline,
             )
 
+            task = task if task is not None else "text-generation"
+
+            # Separate pipeline-specific kwargs from ChatHuggingFace kwargs
+            # Parameters that should go to HuggingFacePipeline.from_model_id
+            pipeline_specific_kwargs = {}
+
+            # Extract pipeline-specific parameters
+            pipeline_keys = [
+                "backend",
+                "device",
+                "device_map",
+                "model_kwargs",
+                "pipeline_kwargs",
+                "batch_size",
+            ]
+            for key in pipeline_keys:
+                if key in kwargs:
+                    pipeline_specific_kwargs[key] = kwargs.pop(key)
+
+            # Remaining kwargs (temperature, max_tokens, etc.) should go to
+            # pipeline_kwargs for generation parameters, which ChatHuggingFace
+            # will inherit from the LLM
+            if "pipeline_kwargs" not in pipeline_specific_kwargs:
+                pipeline_specific_kwargs["pipeline_kwargs"] = {}
+
+            # Add generation parameters to pipeline_kwargs
+            # Map max_tokens to max_new_tokens for HuggingFace pipeline
+            generation_params = {}
+            for k, v in list(kwargs.items()):
+                if k == "max_tokens":
+                    generation_params["max_new_tokens"] = v
+                    kwargs.pop(k)
+                elif k in (
+                    "temperature",
+                    "max_new_tokens",
+                    "top_p",
+                    "top_k",
+                    "repetition_penalty",
+                    "do_sample",
+                ):
+                    generation_params[k] = v
+                    kwargs.pop(k)
+
+            pipeline_specific_kwargs["pipeline_kwargs"].update(generation_params)
+
+            # Create the HuggingFacePipeline
             llm = HuggingFacePipeline.from_model_id(
-                model_id=model_id, task=cast(str, task), **kwargs
+                model_id=model_id, task=task, **pipeline_specific_kwargs
             )
         elif backend == "endpoint":
             from langchain_huggingface.llms.huggingface_endpoint import (
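
For reference, here is a minimal standalone sketch of the kwargs split this hunk performs. The split_kwargs helper and the sample values are illustrative only, not part of langchain_huggingface; the kwargs left over after the split are presumably passed on to ChatHuggingFace itself.

    # Hypothetical helper mirroring the split performed inside from_model_id.
    def split_kwargs(kwargs: dict) -> tuple[dict, dict]:
        pipeline_keys = [
            "backend", "device", "device_map",
            "model_kwargs", "pipeline_kwargs", "batch_size",
        ]
        # Keys destined for HuggingFacePipeline.from_model_id
        pipeline_specific = {k: kwargs.pop(k) for k in pipeline_keys if k in kwargs}
        pipeline_specific.setdefault("pipeline_kwargs", {})

        # Generation parameters are folded into pipeline_kwargs;
        # max_tokens is renamed to the HF-native max_new_tokens.
        generation = {}
        if "max_tokens" in kwargs:
            generation["max_new_tokens"] = kwargs.pop("max_tokens")
        for k in ("temperature", "max_new_tokens", "top_p", "top_k",
                  "repetition_penalty", "do_sample"):
            if k in kwargs:
                generation[k] = kwargs.pop(k)
        pipeline_specific["pipeline_kwargs"].update(generation)
        return pipeline_specific, kwargs

    pipeline_specific, remaining = split_kwargs(
        {"device_map": "auto", "temperature": 0, "max_tokens": 1024, "verbose": True}
    )
    assert pipeline_specific == {
        "device_map": "auto",
        "pipeline_kwargs": {"max_new_tokens": 1024, "temperature": 0},
    }
    assert remaining == {"verbose": True}  # left in kwargs for ChatHuggingFace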

libs/partners/huggingface/pyproject.toml

Lines changed: 2 additions & 0 deletions
@@ -46,6 +46,7 @@ test = [
     "langchain-core",
     "langchain-tests",
     "langchain-community",
+    "langchain",
 ]
 lint = ["ruff>=0.13.1,<0.14.0"]
 dev = [
@@ -61,6 +62,7 @@ typing = [
 [tool.uv.sources]
 langchain-core = { path = "../../core", editable = true }
 langchain-tests = { path = "../../standard-tests", editable = true }
+langchain = { path = "../../langchain_v1", editable = true }
 
 [tool.mypy]
 disallow_untyped_defs = "True"
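
These two additions make the monorepo's local langchain package (langchain_v1) available to this package's test environment; the new unit test below imports init_chat_model from langchain.chat_models.base and would otherwise fail to resolve that dependency.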

libs/partners/huggingface/tests/unit_tests/test_chat_models.py

Lines changed: 52 additions & 0 deletions
@@ -337,3 +337,55 @@ def test_profile() -> None:
         llm=empty_llm,
     )
     assert model.profile
+
+
+def test_init_chat_model_huggingface() -> None:
+    """Test that init_chat_model works with HuggingFace models.
+
+    This test verifies that the fix for issue #28226 works correctly.
+    The issue was that init_chat_model didn't properly handle HuggingFace
+    model initialization, particularly the required 'task' parameter and
+    parameter separation between HuggingFacePipeline and ChatHuggingFace.
+    """
+    from langchain.chat_models.base import init_chat_model
+
+    # Test basic initialization with default task
+    # Note: This test may skip in CI if model download fails, but it verifies
+    # that the initialization code path works correctly
+    try:
+        llm = init_chat_model(
+            model="microsoft/Phi-3-mini-4k-instruct",
+            model_provider="huggingface",
+            temperature=0,
+            max_tokens=1024,
+        )
+
+        # Verify that ChatHuggingFace was created successfully
+        assert llm is not None
+        from langchain_huggingface import ChatHuggingFace
+
+        assert isinstance(llm, ChatHuggingFace)
+
+        # Verify that the llm attribute is set (this was the bug - it was missing)
+        assert hasattr(llm, "llm")
+        assert llm.llm is not None
+
+        # Test with explicit task parameter
+        llm2 = init_chat_model(
+            model="microsoft/Phi-3-mini-4k-instruct",
+            model_provider="huggingface",
+            task="text-generation",
+            temperature=0.5,
+        )
+        assert isinstance(llm2, ChatHuggingFace)
+        assert llm2.llm is not None
+    except (
+        ImportError,
+        OSError,
+        RuntimeError,
+        ValueError,
+    ) as e:
+        # If model download fails in CI, skip the test rather than failing
+        # The important part is that the code path doesn't raise ValidationError
+        # about missing 'llm' field, which was the original bug
+        pytest.skip(f"Skipping test due to model download/initialization error: {e}")

libs/partners/huggingface/uv.lock

Lines changed: 62 additions & 25 deletions
Some generated files are not rendered by default.
