4 changes: 2 additions & 2 deletions libs/elasticsearch/README.md
@@ -119,7 +119,7 @@ A caching layer for LLMs that uses Elasticsearch.
Simple example:

```python
-from langchain.globals import set_llm_cache
+from langchain_core.globals import set_llm_cache

from langchain_elasticsearch import ElasticsearchCache

@@ -151,7 +151,7 @@ The new cache class can be applied also to a pre-existing cache index:
import json
from typing import Any, Dict, List

-from langchain.globals import set_llm_cache
+from langchain_core.globals import set_llm_cache
from langchain_core.caches import RETURN_VAL_TYPE

from langchain_elasticsearch import ElasticsearchCache
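For orientation, this is what the updated README usage looks like end to end: a minimal sketch, assuming a local Elasticsearch at http://localhost:9200 (the URL and index name are placeholders, not part of this diff):

```python
from langchain_core.globals import set_llm_cache

from langchain_elasticsearch import ElasticsearchCache

# Register the Elasticsearch-backed LLM cache process-wide.
set_llm_cache(
    ElasticsearchCache(
        es_url="http://localhost:9200",  # placeholder connection URL
        index_name="llm-chat-cache",     # placeholder index name
    )
)
```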
22 changes: 6 additions & 16 deletions libs/elasticsearch/pyproject.toml
@@ -1,6 +1,6 @@
[tool.poetry]
name = "langchain-elasticsearch"
version = "0.4.0"
version = "0.5.0"
description = "An integration package connecting Elasticsearch and LangChain"
authors = []
readme = "README.md"
@@ -12,8 +12,8 @@ license = "MIT"

[tool.poetry.dependencies]
python = ">=3.10,<4.0"
langchain-core = "^0.3.0"
elasticsearch = {version = ">=8.16.0,<9.0.0", extras = ["vectorstore_mmr"]}
langchain-core = ">=0.3.0,<2.0.0"
elasticsearch = {version = ">=8.16.0,<9.0.0", extras = ["vectorstore_mmr"]} # Pin <9.0 to match ES 8.x server

[tool.poetry.group.test]
optional = true
@@ -24,8 +24,9 @@ freezegun = "^1.2.2"
pytest-mock = "^3.10.0"
syrupy = "^4.0.2"
pytest-watcher = "^0.3.4"
pytest-asyncio = "^0.21.1"
langchain = ">=0.3.10,<1.0.0"
pytest-asyncio = "^0.23.0"
langchain = ">=1.0.0,<2.0.0"
langchain-classic = ">=1.0.0,<2.0.0"
aiohttp = "^3.8.3"

[tool.poetry.group.codespell]
@@ -74,18 +75,7 @@ requires = ["poetry-core>=1.0.0"]
build-backend = "poetry.core.masonry.api"

[tool.pytest.ini_options]
-# --strict-markers will raise errors on unknown marks.
-# https://docs.pytest.org/en/7.1.x/how-to/mark.html#raising-errors-on-unknown-marks
-#
-# https://docs.pytest.org/en/7.1.x/reference/reference.html
-# --strict-config any warnings encountered while parsing the `pytest`
-# section of the configuration file raise errors.
-#
-# https://github.com/tophat/syrupy
-# --snapshot-warn-unused Prints a warning on unused snapshots rather than fail the test suite.
addopts = "--snapshot-warn-unused --strict-markers --strict-config --durations=5"
-# Registering custom markers.
-# https://docs.pytest.org/en/7.1.x/example/markers.html#registering-markers
markers = [
    "requires: mark tests as requiring a specific library",
    "asyncio: mark tests as requiring asyncio",
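The dependency lines above carry the substance of the version bump: langchain-core is relaxed to anything below 2.0, the test suite moves to langchain 1.x, and the memory helpers now come from the new langchain-classic package. A quick, standard-library-only way to check that a resolved environment matches these ranges (a sketch, not part of the PR):

```python
import importlib.metadata as md

# Packages whose version ranges this PR changes or introduces.
for pkg in ("langchain-core", "elasticsearch", "langchain", "langchain-classic"):
    try:
        print(f"{pkg}: {md.version(pkg)}")
    except md.PackageNotFoundError:
        print(f"{pkg}: not installed")
```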
49 changes: 46 additions & 3 deletions libs/elasticsearch/tests/integration_tests/_async/test_cache.py
@@ -1,9 +1,9 @@
-from typing import AsyncGenerator, Dict, Union
+import json
+from typing import Any, AsyncGenerator, Dict, List, Union

import pytest
from elasticsearch.helpers import BulkIndexError
-from langchain.embeddings.cache import _value_serializer
-from langchain.globals import set_llm_cache
+from langchain_core.globals import set_llm_cache
from langchain_core.language_models import BaseChatModel

from langchain_elasticsearch import (
@@ -14,6 +14,49 @@
from ._test_utilities import clear_test_indices, create_es_client, read_env


+def _value_serializer(value: List[float]) -> bytes:
+    """Serialize embedding values to bytes (replaces private langchain function)."""
+    return json.dumps(value).encode()
+
+
+@pytest.fixture(autouse=True)
+async def _close_async_caches(
+    monkeypatch: pytest.MonkeyPatch,
+) -> AsyncGenerator[None, None]:
+    """Ensure cache clients close cleanly to avoid aiohttp warnings."""
+    created_clients: List = []
+
+    original_cache_init = AsyncElasticsearchCache.__init__
+    original_store_init = AsyncElasticsearchEmbeddingsCache.__init__
+
+    def wrapped_cache_init(self, *args: Any, **kwargs: Any) -> None:
+        original_cache_init(self, *args, **kwargs)
+        created_clients.append(self._es_client)
+
+    def wrapped_store_init(self, *args: Any, **kwargs: Any) -> None:
+        original_store_init(self, *args, **kwargs)
+        created_clients.append(self._es_client)
+
+    monkeypatch.setattr(AsyncElasticsearchCache, "__init__", wrapped_cache_init)
+    monkeypatch.setattr(
+        AsyncElasticsearchEmbeddingsCache, "__init__", wrapped_store_init
+    )
+    try:
+        yield
+    finally:
+        for client in created_clients:
+            close = getattr(client, "close", None)
+            if close:
+                try:
+                    await close()
+                except Exception:
+                    pass
+        monkeypatch.setattr(AsyncElasticsearchCache, "__init__", original_cache_init)
+        monkeypatch.setattr(
+            AsyncElasticsearchEmbeddingsCache, "__init__", original_store_init
+        )


@pytest.fixture
async def es_env_fx() -> Union[dict, AsyncGenerator]:
    params = read_env()
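The inlined `_value_serializer` replaces a private helper that is gone from langchain 1.x; since it feeds the stored cache values, its output must stay byte-for-byte stable. A hedged sanity check of that behavior (the exact byte layout follows from `json.dumps` defaults):

```python
import json
from typing import List


def _value_serializer(value: List[float]) -> bytes:
    """Serialize embedding values to bytes, matching the test helper above."""
    return json.dumps(value).encode()


# json.dumps uses ", " as its default separator, so the bytes look like this:
assert _value_serializer([1.5, 2.0]) == b"[1.5, 2.0]"
```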
libs/elasticsearch/tests/integration_tests/_async/test_chat_history.py
@@ -2,8 +2,15 @@
from typing import AsyncIterator

import pytest
-from langchain.memory import ConversationBufferMemory
+from elasticsearch import AsyncElasticsearch
from langchain_core.messages import AIMessage, HumanMessage, message_to_dict
+from langchain_classic.memory import ConversationBufferMemory
+
+pytestmark = [
+    pytest.mark.filterwarnings(
+        "ignore:Please see the migration guide.*:langchain_core._api.deprecation.LangChainDeprecationWarning"
+    )
+]

from langchain_elasticsearch.chat_history import AsyncElasticsearchChatMessageHistory

@@ -23,11 +30,11 @@

class TestElasticsearch:
    @pytest.fixture
-    async def elasticsearch_connection(self) -> AsyncIterator[dict]:
+    async def elasticsearch_connection(self) -> AsyncIterator[AsyncElasticsearch]:
        params = read_env()
        es = create_es_client(params)

-        yield params
+        yield es

        await clear_test_indices(es)
        await es.close()
@@ -38,12 +45,14 @@ def index_name(self) -> str:
return f"test_{uuid.uuid4().hex}"

async def test_memory_with_message_store(
self, elasticsearch_connection: dict, index_name: str
self, elasticsearch_connection: AsyncElasticsearch, index_name: str
) -> None:
"""Test the memory with a message store."""
# setup Elasticsearch as a message store
message_history = AsyncElasticsearchChatMessageHistory(
**elasticsearch_connection, index=index_name, session_id="test-session"
es_connection=elasticsearch_connection,
index=index_name,
session_id="test-session",
)

memory = ConversationBufferMemory(
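The net effect of this file's changes: the fixture now yields a real AsyncElasticsearch client, the history is built with an explicit `es_connection` instead of a `**params` splat, and ConversationBufferMemory comes from langchain-classic. A minimal sketch of the new wiring (the URL, index, and memory key are placeholders; the real tests read connection parameters from the environment):

```python
from elasticsearch import AsyncElasticsearch
from langchain_classic.memory import ConversationBufferMemory

from langchain_elasticsearch.chat_history import AsyncElasticsearchChatMessageHistory

es = AsyncElasticsearch(hosts=["http://localhost:9200"])  # placeholder URL

history = AsyncElasticsearchChatMessageHistory(
    es_connection=es,  # explicit client instead of **connection_params
    index="chat-history-test",
    session_id="test-session",
)

memory = ConversationBufferMemory(
    memory_key="chat_history", chat_memory=history, return_messages=True
)
```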
libs/elasticsearch/tests/integration_tests/_async/test_retrievers.py
@@ -69,7 +69,7 @@ async def test_user_agent_header(
), f"The string '{user_agent}' does not match the expected pattern."

await index_test_data(es_client, index_name, "text")
await retriever.aget_relevant_documents("foo")
await retriever.ainvoke("foo")

search_request = es_client.transport.requests[-1] # type: ignore[attr-defined]
user_agent = search_request["headers"]["User-Agent"]
@@ -133,7 +133,7 @@ def body_func(query: str) -> Dict:
        )

        await index_test_data(es_client, index_name, text_field)
-        result = await retriever.aget_relevant_documents("foo")
+        result = await retriever.ainvoke("foo")

        assert {r.page_content for r in result} == {"foo", "foo bar", "foo baz"}
        assert {r.metadata["_id"] for r in result} == {"3", "1", "5"}
@@ -171,7 +171,7 @@ def body_func(query: str) -> Dict:

        await index_test_data(es_client, index_name_1, text_field_1)
        await index_test_data(es_client, index_name_2, text_field_2)
-        result = await retriever.aget_relevant_documents("foo")
+        result = await retriever.ainvoke("foo")

        # matches from both indices
        assert sorted([(r.page_content, r.metadata["_index"]) for r in result]) == [
@@ -206,7 +206,7 @@ def id_as_content(hit: Mapping[str, Any]) -> Document:
        )

        await index_test_data(es_client, index_name, text_field)
-        result = await retriever.aget_relevant_documents("foo")
+        result = await retriever.ainvoke("foo")

        assert [r.page_content for r in result] == ["3", "1", "5"]
        assert [r.metadata for r in result] == [meta, meta, meta]
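All four changes in this file are the same mechanical rename: `aget_relevant_documents` was deprecated in favor of the Runnable-style `ainvoke` (and `get_relevant_documents` → `invoke` on the sync side). A self-contained sketch of the pattern using a toy retriever, not the ElasticsearchRetriever under test:

```python
from typing import List

from langchain_core.callbacks import CallbackManagerForRetrieverRun
from langchain_core.documents import Document
from langchain_core.retrievers import BaseRetriever


class EchoRetriever(BaseRetriever):
    """Toy retriever that echoes the query back as a single document."""

    def _get_relevant_documents(
        self, query: str, *, run_manager: CallbackManagerForRetrieverRun
    ) -> List[Document]:
        return [Document(page_content=query)]


retriever = EchoRetriever()

# Old, deprecated:  retriever.get_relevant_documents("foo")
# New, Runnable API:
docs = retriever.invoke("foo")
assert docs[0].page_content == "foo"

# Async counterpart (inside an async test):
#     docs = await retriever.ainvoke("foo")
```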
libs/elasticsearch/tests/integration_tests/_async/test_vectorstores.py
@@ -14,6 +14,11 @@
from ._test_utilities import clear_test_indices, create_es_client, read_env

logging.basicConfig(level=logging.DEBUG)
+pytestmark = [
+    pytest.mark.filterwarnings(
+        "ignore:Deprecated field \\[rank\\] used, replaced by \\[retriever\\]:elasticsearch.ElasticsearchWarning"
+    )
+]

"""
cd tests/integration_tests
@@ -27,6 +32,28 @@


class TestElasticsearch:
+    @pytest.fixture(autouse=True)
+    async def _close_async_stores(self, monkeypatch: pytest.MonkeyPatch) -> AsyncIterator[None]:
+        created: list[AsyncElasticsearchStore] = []
+        original_init = AsyncElasticsearchStore.__init__
+
+        def wrapped_init(self, *args: Any, **kwargs: Any) -> None:  # type: ignore[misc]
+            original_init(self, *args, **kwargs)
+            created.append(self)
+
+        monkeypatch.setattr(AsyncElasticsearchStore, "__init__", wrapped_init)
+        try:
+            yield
+        finally:
+            for store in created:
+                aclose = getattr(store, "aclose", None)
+                if aclose:
+                    try:
+                        await aclose()
+                    except Exception:
+                        pass
+            monkeypatch.setattr(AsyncElasticsearchStore, "__init__", original_init)

    @pytest.fixture
    async def es_params(self) -> AsyncIterator[dict]:
        params = read_env()
@@ -104,7 +131,7 @@ async def test_search_with_relevance_threshold(
search_type="similarity_score_threshold",
search_kwargs={"score_threshold": similarity_of_second_ranked},
)
output = await retriever.aget_relevant_documents(query=query_string)
output = await retriever.ainvoke(query_string)

assert output == [
top3[0][0],
@@ -145,7 +172,7 @@ async def test_search_by_vector_with_relevance_threshold(
search_type="similarity_score_threshold",
search_kwargs={"score_threshold": similarity_of_second_ranked},
)
output = await retriever.aget_relevant_documents(query=query_string)
output = await retriever.ainvoke(query_string)

assert output == [
top3[0][0],
@@ -1081,7 +1108,7 @@ async def test_elasticsearch_with_relevance_threshold(
search_type="similarity_score_threshold",
search_kwargs={"score_threshold": similarity_of_second_ranked},
)
output = await retriever.aget_relevant_documents(query=query_string)
output = await retriever.ainvoke(query_string)

assert output == [
top3[0][0],
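The autouse fixture above patches `__init__` to track every AsyncElasticsearchStore a test creates and closes them afterwards, which keeps aiohttp from warning about unclosed client sessions. The same defensive close works outside tests; a sketch assuming the store exposes `aclose()` (the fixture's `getattr` guard suggests it may not always):

```python
from langchain_elasticsearch import AsyncElasticsearchStore


async def use_store(store: AsyncElasticsearchStore) -> None:
    try:
        ...  # add documents, run searches, etc.
    finally:
        # Mirror the fixture: close the underlying async client if possible,
        # so aiohttp does not warn about unclosed sessions.
        aclose = getattr(store, "aclose", None)
        if aclose:
            await aclose()
```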
12 changes: 9 additions & 3 deletions libs/elasticsearch/tests/integration_tests/_sync/test_cache.py
@@ -1,19 +1,25 @@
-from typing import Dict, Generator, Union
+import json
+from typing import Dict, Generator, List, Union

import pytest
from elasticsearch.helpers import BulkIndexError
-from langchain.embeddings.cache import _value_serializer
-from langchain.globals import set_llm_cache
+from langchain_core.globals import set_llm_cache
from langchain_core.language_models import BaseChatModel

from langchain_elasticsearch import (
    ElasticsearchCache,
    ElasticsearchEmbeddingsCache,
)


from ._test_utilities import clear_test_indices, create_es_client, read_env


+def _value_serializer(value: List[float]) -> bytes:
+    """Serialize embedding values to bytes (replaces private langchain function)."""
+    return json.dumps(value).encode()


@pytest.fixture
def es_env_fx() -> Union[dict, Generator]:
    params = read_env()
libs/elasticsearch/tests/integration_tests/_sync/test_chat_history.py
@@ -2,8 +2,14 @@
from typing import Iterator

import pytest
-from langchain.memory import ConversationBufferMemory
from langchain_core.messages import AIMessage, HumanMessage, message_to_dict
+from langchain_classic.memory import ConversationBufferMemory
+
+pytestmark = [
+    pytest.mark.filterwarnings(
+        "ignore:Please see the migration guide.*:langchain_core._api.deprecation.LangChainDeprecationWarning"
+    )
+]

from langchain_elasticsearch.chat_history import ElasticsearchChatMessageHistory

libs/elasticsearch/tests/integration_tests/_sync/test_retrievers.py
@@ -65,7 +65,7 @@ def test_user_agent_header(self, es_client: Elasticsearch, index_name: str) -> N
), f"The string '{user_agent}' does not match the expected pattern."

index_test_data(es_client, index_name, "text")
retriever.get_relevant_documents("foo")
retriever.invoke("foo")

search_request = es_client.transport.requests[-1] # type: ignore[attr-defined]
user_agent = search_request["headers"]["User-Agent"]
@@ -127,7 +127,7 @@ def body_func(query: str) -> Dict:
        )

        index_test_data(es_client, index_name, text_field)
-        result = retriever.get_relevant_documents("foo")
+        result = retriever.invoke("foo")

        assert {r.page_content for r in result} == {"foo", "foo bar", "foo baz"}
        assert {r.metadata["_id"] for r in result} == {"3", "1", "5"}
@@ -165,7 +165,7 @@ def body_func(query: str) -> Dict:

        index_test_data(es_client, index_name_1, text_field_1)
        index_test_data(es_client, index_name_2, text_field_2)
-        result = retriever.get_relevant_documents("foo")
+        result = retriever.invoke("foo")

        # matches from both indices
        assert sorted([(r.page_content, r.metadata["_index"]) for r in result]) == [
@@ -198,7 +198,7 @@ def id_as_content(hit: Mapping[str, Any]) -> Document:
        )

        index_test_data(es_client, index_name, text_field)
-        result = retriever.get_relevant_documents("foo")
+        result = retriever.invoke("foo")

        assert [r.page_content for r in result] == ["3", "1", "5"]
        assert [r.metadata for r in result] == [meta, meta, meta]