Skip to content

Commit f902504

Browse files
authored
feat: support multimodal embeddings (#237)
* Inline multimodal entities into existing models * apply ruff * bump to 0.7.0b1 * fix: remove tenant_id from invoke_multimodal_embedding * tests: add rerank * apply ruff * fix * fix: typing
1 parent 61c7481 commit f902504

File tree

13 files changed

+445
-10
lines changed

13 files changed

+445
-10
lines changed

python/dify_plugin/core/entities/invocation.py

Lines changed: 2 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -6,7 +6,9 @@ class InvokeType(Enum):
66
LLM = "llm"
77
LLMStructuredOutput = "llm_structured_output"
88
TextEmbedding = "text_embedding"
9+
MultimodalEmbedding = "multimodal_embedding"
910
Rerank = "rerank"
11+
MultimodalRerank = "multimodal_rerank"
1012
TTS = "tts"
1113
Speech2Text = "speech2text"
1214
Moderation = "moderation"

python/dify_plugin/core/entities/plugin/request.py

Lines changed: 23 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -1,4 +1,4 @@
1-
from collections.abc import Mapping
1+
from collections.abc import Mapping, Sequence
22
from enum import StrEnum
33
from typing import Any
44

@@ -9,7 +9,7 @@
99
OnlineDriveBrowseFilesRequest,
1010
OnlineDriveDownloadFileRequest,
1111
)
12-
from dify_plugin.entities.model import ModelType
12+
from dify_plugin.entities.model import EmbeddingInputType, ModelType
1313
from dify_plugin.entities.model.message import (
1414
AssistantPromptMessage,
1515
PromptMessage,
@@ -19,6 +19,7 @@
1919
ToolPromptMessage,
2020
UserPromptMessage,
2121
)
22+
from dify_plugin.entities.model.text_embedding import MultiModalContent
2223
from dify_plugin.entities.provider_config import CredentialType
2324
from dify_plugin.entities.trigger import Subscription
2425

@@ -59,8 +60,10 @@ class ModelActions(StrEnum):
5960
InvokeLLM = "invoke_llm"
6061
GetLLMNumTokens = "get_llm_num_tokens"
6162
InvokeTextEmbedding = "invoke_text_embedding"
63+
InvokeMultimodalEmbedding = "invoke_multimodal_embedding"
6264
GetTextEmbeddingNumTokens = "get_text_embedding_num_tokens"
6365
InvokeRerank = "invoke_rerank"
66+
InvokeMultimodalRerank = "invoke_multimodal_rerank"
6467
InvokeTTS = "invoke_tts"
6568
GetTTSVoices = "get_tts_model_voices"
6669
InvokeSpeech2Text = "invoke_speech2text"
@@ -202,6 +205,14 @@ class ModelInvokeTextEmbeddingRequest(PluginAccessModelRequest):
202205
texts: list[str]
203206

204207

208+
class ModelInvokeMultimodalEmbeddingRequest(PluginAccessModelRequest):
    """Request payload for the ``invoke_multimodal_embedding`` model action.

    Carries the multimodal contents to embed together with the embedding
    input type.
    """

    # Action identifier; defaults to the multimodal-embedding action.
    action: ModelActions = ModelActions.InvokeMultimodalEmbedding
    # Multimodal embedding requests are declared with the TEXT_EMBEDDING model type.
    model_type: ModelType = ModelType.TEXT_EMBEDDING

    # Contents to embed (no default declared).
    documents: list[MultiModalContent]
    # Embedding input type; defaults to DOCUMENT.
    input_type: EmbeddingInputType = EmbeddingInputType.DOCUMENT
214+
215+
205216
class ModelGetTextEmbeddingNumTokens(PluginAccessModelRequest):
206217
action: ModelActions = ModelActions.GetTextEmbeddingNumTokens
207218

@@ -217,6 +228,16 @@ class ModelInvokeRerankRequest(PluginAccessModelRequest):
217228
top_n: int | None
218229

219230

231+
class ModelInvokeMultimodalRerankRequest(PluginAccessModelRequest):
    """Request payload for the ``invoke_multimodal_rerank`` model action.

    Carries the multimodal query, the candidate documents and the optional
    filtering parameters.
    """

    # Action identifier; defaults to the multimodal-rerank action.
    action: ModelActions = ModelActions.InvokeMultimodalRerank
    model_type: ModelType = ModelType.RERANK

    # Content the documents are ranked against.
    query: MultiModalContent
    # Candidate contents to rerank.
    docs: Sequence[MultiModalContent]
    # Both fields accept None but declare no default here.
    # NOTE(review): under pydantic v2 that makes them required -- confirm against the base class.
    score_threshold: float | None
    top_n: int | None
239+
240+
220241
class ModelInvokeTTSRequest(PluginAccessModelRequest):
221242
action: ModelActions = ModelActions.InvokeTTS
222243

python/dify_plugin/core/plugin_executor.py

Lines changed: 28 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -22,6 +22,8 @@
2222
ModelGetTTSVoices,
2323
ModelInvokeLLMRequest,
2424
ModelInvokeModerationRequest,
25+
ModelInvokeMultimodalEmbeddingRequest,
26+
ModelInvokeMultimodalRerankRequest,
2527
ModelInvokeRerankRequest,
2628
ModelInvokeSpeech2TextRequest,
2729
ModelInvokeTextEmbeddingRequest,
@@ -219,6 +221,18 @@ def invoke_text_embedding(self, session: Session, data: ModelInvokeTextEmbedding
219221
else:
220222
raise ValueError(f"Model `{data.model_type}` not found for provider `{data.provider}`")
221223

224+
def invoke_multimodal_embedding(self, session: Session, data: ModelInvokeMultimodalEmbeddingRequest):
    """Dispatch a multimodal embedding request to the provider's model.

    Resolves the model instance registered for ``data.provider`` and
    ``data.model_type`` and forwards the request fields to its
    ``invoke_multimodal`` method.

    Args:
        session: Current session; not read by this method.
        data: The multimodal embedding request.

    Returns:
        Whatever the model's ``invoke_multimodal`` returns.

    Raises:
        ValueError: If the resolved instance is not a ``TextEmbeddingModel``.
    """
    model_instance = self.registration.get_model_instance(data.provider, data.model_type)
    if not isinstance(model_instance, TextEmbeddingModel):
        raise ValueError(f"Model `{data.model_type}` not found for provider `{data.provider}`")

    return model_instance.invoke_multimodal(
        data.model,
        data.credentials,
        data.documents,
        user=data.user_id,
        input_type=data.input_type,
    )
235+
222236
def get_text_embedding_num_tokens(self, session: Session, data: ModelGetTextEmbeddingNumTokens):
223237
model_instance = self.registration.get_model_instance(data.provider, data.model_type)
224238
if isinstance(model_instance, TextEmbeddingModel):
@@ -247,6 +261,20 @@ def invoke_rerank(self, session: Session, data: ModelInvokeRerankRequest):
247261
else:
248262
raise ValueError(f"Model `{data.model_type}` not found for provider `{data.provider}`")
249263

264+
def invoke_multimodal_rerank(self, session: Session, data: ModelInvokeMultimodalRerankRequest):
    """Dispatch a multimodal rerank request to the provider's model.

    Resolves the model instance registered for ``data.provider`` and
    ``data.model_type`` and forwards the request fields to its
    ``invoke_multimodal`` method.

    Args:
        session: Current session; not read by this method.
        data: The multimodal rerank request.

    Returns:
        Whatever the model's ``invoke_multimodal`` returns.

    Raises:
        ValueError: If the resolved instance is not a ``RerankModel``.
    """
    model_instance = self.registration.get_model_instance(data.provider, data.model_type)
    if not isinstance(model_instance, RerankModel):
        raise ValueError(f"Model `{data.model_type}` not found for provider `{data.provider}`")

    return model_instance.invoke_multimodal(
        data.model,
        data.credentials,
        data.query,
        data.docs,
        score_threshold=data.score_threshold,
        top_n=data.top_n,
        user=data.user_id,
    )
277+
250278
def invoke_tts(self, session: Session, data: ModelInvokeTTSRequest):
251279
model_instance = self.registration.get_model_instance(data.provider, data.model_type)
252280
if isinstance(model_instance, TTSModel):

python/dify_plugin/entities/model/rerank.py

Lines changed: 24 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -1,4 +1,4 @@
1-
from pydantic import BaseModel, ConfigDict
1+
from pydantic import BaseModel, ConfigDict, Field
22

33
from dify_plugin.entities.model import BaseModelConfig, ModelType
44

@@ -32,3 +32,26 @@ class RerankModelConfig(BaseModelConfig):
3232
top_n: int
3333

3434
model_config = ConfigDict(protected_namespaces=())
35+
36+
37+
class MultiModalRerankResult(BaseModel):
    """Rerank response produced by a multimodal rerank model.

    Both fields are required (declared with ``Field(...)``).
    """

    model: str = Field(..., description="Identifier of the model producing the reranked documents.")
    docs: list[RerankDocument] = Field(..., description="Reranked documents with scores.")
42+
43+
44+
class MultiModalRerankModelConfig(BaseModelConfig):
    """Configuration payload for invoking a multimodal rerank model.

    ``score_threshold`` and ``top_n`` are optional and default to ``None``.
    """

    model_type: ModelType = ModelType.RERANK
    score_threshold: float | None = Field(
        default=None,
        description="Optional threshold for filtering documents based on score.",
    )
    top_n: int | None = Field(
        default=None,
        description="Optional limit on the number of documents returned.",
    )

    # Empty protected_namespaces, matching the other model config classes in this file.
    model_config = ConfigDict(protected_namespaces=())

python/dify_plugin/entities/model/text_embedding.py

Lines changed: 33 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -1,6 +1,7 @@
11
from decimal import Decimal
2+
from enum import StrEnum
23

3-
from pydantic import BaseModel, ConfigDict
4+
from pydantic import BaseModel, ConfigDict, Field
45

56
from dify_plugin.entities.model import BaseModelConfig, ModelType, ModelUsage
67

@@ -37,3 +38,34 @@ class TextEmbeddingModelConfig(BaseModelConfig):
3738
model_type: ModelType = ModelType.TEXT_EMBEDDING
3839

3940
model_config = ConfigDict(protected_namespaces=())
41+
42+
43+
class MultiModalContentType(StrEnum):
44+
"""Supported content types for multimodal inputs."""
45+
46+
TEXT = "text"
47+
IMAGE = "image"
48+
49+
50+
class MultiModalContent(BaseModel):
    """A multimodal content payload provided by the caller.

    Pairs the raw payload string with the modality it represents. Both
    fields are required (declared with ``Field(...)``).
    """

    content: str = Field(..., description="The payload content, plain text or base64 encoded file data.")
    content_type: MultiModalContentType = Field(..., description="The modality of the provided content.")
55+
56+
57+
class MultiModalEmbeddingResult(BaseModel):
    """Embedding response produced by a multimodal embedding model.

    All fields are required (declared with ``Field(...)``).
    """

    model: str = Field(..., description="Identifier of the model generating embeddings.")
    # One vector (list of floats) per embedded content.
    # NOTE(review): ordering relative to the inputs is not established here -- confirm.
    embeddings: list[list[float]] = Field(..., description="Embedding vectors for provided contents.")
    usage: EmbeddingUsage = Field(..., description="Usage metrics associated with the inference.")
63+
64+
65+
class MultiModalEmbeddingModelConfig(BaseModelConfig):
    """Configuration payload for invoking a multimodal embedding model."""

    # Multimodal embedding configs are declared with the TEXT_EMBEDDING model type.
    model_type: ModelType = ModelType.TEXT_EMBEDDING
    # Required field (declared with ``Field(...)``).
    tenant_id: str = Field(..., description="Vendor tenant identifier associated with the dataset.")

    # Empty protected_namespaces, matching the other model config classes in this file.
    model_config = ConfigDict(protected_namespaces=())

python/dify_plugin/interfaces/model/rerank_model.py

Lines changed: 48 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -1,7 +1,9 @@
11
from abc import abstractmethod
2+
from collections.abc import Sequence
23

34
from dify_plugin.entities.model import ModelType
4-
from dify_plugin.entities.model.rerank import RerankResult
5+
from dify_plugin.entities.model.rerank import MultiModalRerankResult, RerankResult
6+
from dify_plugin.entities.model.text_embedding import MultiModalContent
57
from dify_plugin.interfaces.model.ai_model import AIModel
68

79

@@ -41,6 +43,23 @@ def _invoke(
4143
"""
4244
raise NotImplementedError
4345

46+
def _invoke_multimodal(
    self,
    model: str,
    credentials: dict,
    query: MultiModalContent,
    docs: Sequence[MultiModalContent],
    score_threshold: float | None = None,
    top_n: int | None = None,
    user: str | None = None,
) -> MultiModalRerankResult:
    """Invoke a multimodal rerank model.

    Default implementation: always raises. Implement this method in a
    subclass to support multimodal rerank invocations.

    Args:
        model: Model name.
        credentials: Model credentials.
        query: Content the documents are ranked against.
        docs: Candidate contents to rerank.
        score_threshold: Optional score threshold.
        top_n: Optional limit on the number of results.
        user: Optional user identifier.

    Raises:
        NotImplementedError: Always, naming the concrete class that lacks
            an implementation.
    """
    raise NotImplementedError(
        f"{self.__class__.__name__} does not implement `_invoke_multimodal`. "
        "Implement this method to support multimodal rerank invocations."
    )
62+
4463
############################################################
4564
# For executor use only #
4665
############################################################
@@ -73,3 +92,31 @@ def invoke(
7392
return self._invoke(model, credentials, query, docs, score_threshold, top_n, user)
7493
except Exception as e:
7594
raise self._transform_invoke_error(e) from e
95+
96+
def invoke_multimodal(
    self,
    model: str,
    credentials: dict,
    query: MultiModalContent,
    docs: Sequence[MultiModalContent],
    score_threshold: float | None = None,
    top_n: int | None = None,
    user: str | None = None,
) -> MultiModalRerankResult:
    """Invoke a multimodal rerank model.

    Runs ``_invoke_multimodal`` inside ``self.timing_context()`` and
    forwards all arguments positionally.

    Raises:
        NotImplementedError: Propagated unchanged when the subclass does not
            implement ``_invoke_multimodal``.
        Exception: Any other failure is re-raised as the value returned by
            ``self._transform_invoke_error``, chained to the original error.
    """
    with self.timing_context():
        try:
            return self._invoke_multimodal(
                model,
                credentials,
                query,
                docs,
                score_threshold,
                top_n,
                user,
            )
        except NotImplementedError:
            # Kept out of the generic handler below so the "not implemented"
            # signal is not converted by _transform_invoke_error.
            raise
        except Exception as e:
            raise self._transform_invoke_error(e) from e

python/dify_plugin/interfaces/model/text_embedding_model.py

Lines changed: 44 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -3,7 +3,11 @@
33
from pydantic import ConfigDict
44

55
from dify_plugin.entities.model import EmbeddingInputType, ModelPropertyKey, ModelType
6-
from dify_plugin.entities.model.text_embedding import TextEmbeddingResult
6+
from dify_plugin.entities.model.text_embedding import (
7+
MultiModalContent,
8+
MultiModalEmbeddingResult,
9+
TextEmbeddingResult,
10+
)
711
from dify_plugin.interfaces.model.ai_model import AIModel
812

913

@@ -42,6 +46,21 @@ def _invoke(
4246
"""
4347
raise NotImplementedError
4448

49+
def _invoke_multimodal(
    self,
    model: str,
    credentials: dict,
    documents: list[MultiModalContent],
    user: str | None = None,
    input_type: EmbeddingInputType = EmbeddingInputType.DOCUMENT,
) -> MultiModalEmbeddingResult:
    """Invoke a multimodal embedding model.

    Default implementation: always raises. Implement this method in a
    subclass to support multimodal embeddings.

    Args:
        model: Model name.
        credentials: Model credentials.
        documents: Contents to embed.
        user: Optional user identifier.
        input_type: Embedding input type; defaults to DOCUMENT.

    Raises:
        NotImplementedError: Always, naming the concrete class that lacks
            an implementation.
    """
    raise NotImplementedError(
        f"{self.__class__.__name__} does not implement `_invoke_multimodal`. "
        "Implement this method to support multimodal embeddings."
    )
63+
4564
@abstractmethod
4665
def get_num_tokens(self, model: str, credentials: dict, texts: list[str]) -> list[int]:
4766
"""
@@ -115,3 +134,27 @@ def invoke(
115134
return self._invoke(model, credentials, texts, user, input_type)
116135
except Exception as e:
117136
raise self._transform_invoke_error(e) from e
137+
138+
def invoke_multimodal(
    self,
    model: str,
    credentials: dict,
    documents: list[MultiModalContent],
    user: str | None = None,
    input_type: EmbeddingInputType = EmbeddingInputType.DOCUMENT,
) -> MultiModalEmbeddingResult:
    """Invoke a multimodal embedding model.

    Runs ``_invoke_multimodal`` inside ``self.timing_context()`` and
    forwards all arguments positionally.

    Raises:
        NotImplementedError: Propagated unchanged when the subclass does not
            implement ``_invoke_multimodal``.
        Exception: Any other failure is re-raised as the value returned by
            ``self._transform_invoke_error``, chained to the original error.
    """
    with self.timing_context():
        try:
            return self._invoke_multimodal(
                model,
                credentials,
                documents,
                user,
                input_type,
            )
        except NotImplementedError:
            # Kept out of the generic handler below so the "not implemented"
            # signal is not converted by _transform_invoke_error.
            raise
        except Exception as e:
            raise self._transform_invoke_error(e) from e

python/dify_plugin/invocations/model/rerank.py

Lines changed: 28 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -1,6 +1,12 @@
11
from dify_plugin.core.entities.invocation import InvokeType
22
from dify_plugin.core.runtime import BackwardsInvocation
3-
from dify_plugin.entities.model.rerank import RerankModelConfig, RerankResult
3+
from dify_plugin.entities.model.rerank import (
4+
MultiModalRerankModelConfig,
5+
MultiModalRerankResult,
6+
RerankModelConfig,
7+
RerankResult,
8+
)
9+
from dify_plugin.entities.model.text_embedding import MultiModalContent
410

511

612
class RerankInvocation(BackwardsInvocation[RerankResult]):
@@ -20,3 +26,24 @@ def invoke(self, model_config: RerankModelConfig, docs: list[str], query: str) -
2026
return data
2127

2228
raise Exception("No response from rerank")
29+
30+
def invoke_multimodal(
    self,
    model_config: MultiModalRerankModelConfig,
    query: MultiModalContent,
    docs: list[MultiModalContent],
) -> MultiModalRerankResult:
    """Send a multimodal rerank request through the backwards-invocation channel.

    The payload is the dumped ``model_config`` plus ``query`` and ``docs``
    keys. ``MultiModalContent`` values are dumped to plain data; anything
    else is passed through untouched.

    Returns:
        The first item yielded by ``_backwards_invoke``.

    Raises:
        Exception: If ``_backwards_invoke`` yields nothing.
    """
    payload = {
        **model_config.model_dump(),
        "query": query.model_dump() if isinstance(query, MultiModalContent) else query,
        "docs": [doc.model_dump() if isinstance(doc, MultiModalContent) else doc for doc in docs],
    }

    # Only the first yielded item is used; the loop returns on its first iteration.
    for data in self._backwards_invoke(
        InvokeType.MultimodalRerank,
        MultiModalRerankResult,
        payload,
    ):
        return data

    raise Exception("No response from multimodal rerank")

0 commit comments

Comments
 (0)