|
7 | 7 | from collections.abc import AsyncIterator, Callable, Iterator, Mapping, Sequence |
8 | 8 | from dataclasses import dataclass |
9 | 9 | from operator import itemgetter |
10 | | -from typing import Any, Literal, cast |
| 10 | +from typing import TYPE_CHECKING, Any, Literal, cast |
| 11 | + |
| 12 | +if TYPE_CHECKING: |
| 13 | + from langchain_huggingface.llms.huggingface_endpoint import HuggingFaceEndpoint |
| 14 | + from langchain_huggingface.llms.huggingface_pipeline import HuggingFacePipeline |
11 | 15 |
|
12 | 16 | from langchain_core.callbacks.manager import ( |
13 | 17 | AsyncCallbackManagerForLLMRun, |
@@ -599,6 +603,51 @@ def _set_model_profile(self) -> Self: |
599 | 603 | self.profile = _get_default_model_profile(self.model_id) |
600 | 604 | return self |
601 | 605 |
|
| 606 | + @classmethod |
| 607 | + def from_model_id( |
| 608 | + cls, |
| 609 | + model_id: str, |
| 610 | + task: str | None = None, |
| 611 | + backend: Literal["pipeline", "endpoint", "text-gen"] = "pipeline", |
| 612 | + **kwargs: Any, |
| 613 | + ) -> ChatHuggingFace: |
| 614 | + """Construct a ChatHuggingFace model from a model_id. |
| 615 | +
|
| 616 | + Args: |
| 617 | + model_id: The model ID of the Hugging Face model. |
| 618 | + task: The task to perform (e.g., "text-generation"). |
| 619 | + backend: The backend to use. One of "pipeline", "endpoint", "text-gen". |
| 620 | + **kwargs: Additional arguments to pass to the backend or ChatHuggingFace. |
| 621 | + """ |
| 622 | + llm: ( |
| 623 | + Any # HuggingFacePipeline, HuggingFaceEndpoint, HuggingFaceTextGenInference |
| 624 | + ) |
| 625 | + if backend == "pipeline": |
| 626 | + from langchain_huggingface.llms.huggingface_pipeline import ( |
| 627 | + HuggingFacePipeline, |
| 628 | + ) |
| 629 | + |
| 630 | + llm = HuggingFacePipeline.from_model_id( |
| 631 | + model_id=model_id, task=cast(str, task), **kwargs |
| 632 | + ) |
| 633 | + elif backend == "endpoint": |
| 634 | + from langchain_huggingface.llms.huggingface_endpoint import ( |
| 635 | + HuggingFaceEndpoint, |
| 636 | + ) |
| 637 | + |
| 638 | + llm = HuggingFaceEndpoint(repo_id=model_id, task=task, **kwargs) |
| 639 | + elif backend == "text-gen": |
| 640 | + from langchain_community.llms.huggingface_text_gen_inference import ( # type: ignore[import-not-found] |
| 641 | + HuggingFaceTextGenInference, |
| 642 | + ) |
| 643 | + |
| 644 | + llm = HuggingFaceTextGenInference(inference_server_url=model_id, **kwargs) |
| 645 | + else: |
| 646 | + msg = f"Unknown backend: {backend}" |
| 647 | + raise ValueError(msg) |
| 648 | + |
| 649 | + return cls(llm=llm, **kwargs) |
| 650 | + |
602 | 651 | def _create_chat_result(self, response: dict) -> ChatResult: |
603 | 652 | generations = [] |
604 | 653 | token_usage = response.get("usage", {}) |
|
0 commit comments