
Commit 6033e4f

Storage fixes and cleanup (#2118)
* Fix pipeline recursion
* Remove base_dir from storage.find
* Remove max_count from storage.find
* Remove prefix on storage integ test
* Add base_dir in creation_date test
* Wrap base_dir in Path
* Use constants for input/update directories
1 parent 6b03af6 commit 6033e4f

File tree

10 files changed (+86 -158 lines)


packages/graphrag/graphrag/config/defaults.py

Lines changed: 6 additions & 4 deletions
@@ -24,7 +24,9 @@
     EN_STOP_WORDS,
 )
 
+DEFAULT_INPUT_BASE_DIR = "input"
 DEFAULT_OUTPUT_BASE_DIR = "output"
+DEFAULT_UPDATE_OUTPUT_BASE_DIR = "update_output"
 DEFAULT_CHAT_MODEL_ID = "default_chat_model"
 DEFAULT_CHAT_MODEL_TYPE = ModelType.Chat
 DEFAULT_CHAT_MODEL_AUTH_TYPE = AuthType.APIKey
@@ -229,7 +231,7 @@ class StorageDefaults:
     """Default values for storage."""
 
     type: ClassVar[StorageType] = StorageType.file
-    base_dir: str = DEFAULT_OUTPUT_BASE_DIR
+    base_dir: str | None = None
     connection_string: None = None
     container_name: None = None
     storage_account_blob_url: None = None
@@ -240,7 +242,7 @@ class StorageDefaults:
 class InputStorageDefaults(StorageDefaults):
     """Default values for input storage."""
 
-    base_dir: str = "input"
+    base_dir: str | None = DEFAULT_INPUT_BASE_DIR
 
 
 @dataclass
@@ -310,7 +312,7 @@ class LocalSearchDefaults:
 class OutputDefaults(StorageDefaults):
     """Default values for output."""
 
-    base_dir: str = DEFAULT_OUTPUT_BASE_DIR
+    base_dir: str | None = DEFAULT_OUTPUT_BASE_DIR
 
 
 @dataclass
@@ -362,7 +364,7 @@ class SummarizeDescriptionsDefaults:
 class UpdateIndexOutputDefaults(StorageDefaults):
     """Default values for update index output."""
 
-    base_dir: str = "update_output"
+    base_dir: str | None = DEFAULT_UPDATE_OUTPUT_BASE_DIR
 
 
 @dataclass
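
Net effect of this file: the generic StorageDefaults no longer assumes an output directory, and each concrete storage pins its default to a shared constant instead of a repeated string literal. A minimal, self-contained sketch of that pattern (the real classes carry many more fields):

from dataclasses import dataclass

DEFAULT_INPUT_BASE_DIR = "input"

@dataclass
class StorageDefaults:
    # Base storage is now directory-agnostic; None means "not set".
    base_dir: str | None = None

@dataclass
class InputStorageDefaults(StorageDefaults):
    # Concrete storages override the default with a named constant.
    base_dir: str | None = DEFAULT_INPUT_BASE_DIR

assert StorageDefaults().base_dir is None
assert InputStorageDefaults().base_dir == "input"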

packages/graphrag/graphrag/config/models/graph_rag_config.py

Lines changed: 6 additions & 4 deletions
@@ -152,7 +152,7 @@ def _validate_input_pattern(self) -> None:
     def _validate_input_base_dir(self) -> None:
         """Validate the input base directory."""
         if self.input.storage.type == defs.StorageType.file:
-            if self.input.storage.base_dir.strip() == "":
+            if not self.input.storage.base_dir:
                 msg = "input storage base directory is required for file input storage. Please rerun `graphrag init` and set the input storage configuration."
                 raise ValueError(msg)
             self.input.storage.base_dir = str(
@@ -167,14 +167,16 @@ def _validate_input_base_dir(self) -> None:
 
     output: StorageConfig = Field(
         description="The output configuration.",
-        default=StorageConfig(),
+        default=StorageConfig(
+            base_dir=graphrag_config_defaults.output.base_dir,
+        ),
     )
     """The output configuration."""
 
     def _validate_output_base_dir(self) -> None:
         """Validate the output base directory."""
         if self.output.type == defs.StorageType.file:
-            if self.output.base_dir.strip() == "":
+            if not self.output.base_dir:
                 msg = "output base directory is required for file output. Please rerun `graphrag init` and set the output configuration."
                 raise ValueError(msg)
             self.output.base_dir = str(
@@ -192,7 +194,7 @@ def _validate_output_base_dir(self) -> None:
     def _validate_update_index_output_base_dir(self) -> None:
         """Validate the update index output base directory."""
         if self.update_index_output.type == defs.StorageType.file:
-            if self.update_index_output.base_dir.strip() == "":
+            if not self.update_index_output.base_dir:
                 msg = "update_index_output base directory is required for file output. Please rerun `graphrag init` and set the update_index_output configuration."
                 raise ValueError(msg)
             self.update_index_output.base_dir = str(
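
The guard rewrite is more than style: with base_dir now typed str | None, the old `.strip() == ""` check would raise AttributeError on None, while `not base_dir` rejects both None and the empty string in one test (though a whitespace-only string, which the old check caught, now passes). A hedged sketch of the idea, using a hypothetical helper rather than the actual validators:

def _require_base_dir(base_dir: str | None, label: str) -> str:
    # Truthiness covers None and "" together; calling .strip() on
    # None would raise AttributeError instead of ValueError.
    if not base_dir:
        msg = f"{label} base directory is required for file storage."
        raise ValueError(msg)
    return base_dir

assert _require_base_dir("output", "output") == "output"
for bad in (None, ""):
    try:
        _require_base_dir(bad, "output")
    except ValueError:
        pass  # both rejected by the same check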

packages/graphrag/graphrag/config/models/storage_config.py

Lines changed: 1 addition & 1 deletion
@@ -18,7 +18,7 @@ class StorageConfig(BaseModel):
         description="The storage type to use.",
         default=graphrag_config_defaults.storage.type,
     )
-    base_dir: str = Field(
+    base_dir: str | None = Field(
         description="The base directory for the output.",
         default=graphrag_config_defaults.storage.base_dir,
     )
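
This is the model-level counterpart to the defaults change: widening the annotation lets a config omit base_dir entirely, with the effective per-storage default applied elsewhere. A trimmed sketch assuming pydantic, with a literal None standing in for graphrag_config_defaults.storage.base_dir:

from pydantic import BaseModel, Field

class StorageConfigSketch(BaseModel):
    # Optional annotation: omitting base_dir is now valid.
    base_dir: str | None = Field(
        description="The base directory for the output.",
        default=None,
    )

assert StorageConfigSketch().base_dir is None
assert StorageConfigSketch(base_dir="output").base_dir == "output"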

packages/graphrag/graphrag/storage/blob_pipeline_storage.py

Lines changed: 28 additions & 80 deletions
@@ -25,15 +25,15 @@ class BlobPipelineStorage(PipelineStorage):
 
     _connection_string: str | None
     _container_name: str
-    _path_prefix: str
+    _base_dir: str | None
     _encoding: str
     _storage_account_blob_url: str | None
 
     def __init__(self, **kwargs: Any) -> None:
         """Create a new BlobStorage instance."""
         connection_string = kwargs.get("connection_string")
         storage_account_blob_url = kwargs.get("storage_account_blob_url")
-        path_prefix = kwargs.get("base_dir")
+        base_dir = kwargs.get("base_dir")
         container_name = kwargs["container_name"]
         if container_name is None:
             msg = "No container name provided for blob storage."
@@ -42,7 +42,9 @@ def __init__(self, **kwargs: Any) -> None:
             msg = "No storage account blob url provided for blob storage."
             raise ValueError(msg)
 
-        logger.info("Creating blob storage at %s", container_name)
+        logger.info(
+            "Creating blob storage at [%s] and base_dir [%s]", container_name, base_dir
+        )
         if connection_string:
             self._blob_service_client = BlobServiceClient.from_connection_string(
                 connection_string
@@ -59,18 +61,13 @@ def __init__(self, **kwargs: Any) -> None:
         self._encoding = kwargs.get("encoding", "utf-8")
         self._container_name = container_name
         self._connection_string = connection_string
-        self._path_prefix = path_prefix or ""
+        self._base_dir = base_dir
         self._storage_account_blob_url = storage_account_blob_url
         self._storage_account_name = (
             storage_account_blob_url.split("//")[1].split(".")[0]
             if storage_account_blob_url
             else None
         )
-        logger.debug(
-            "creating blob storage at container=%s, path=%s",
-            self._container_name,
-            self._path_prefix,
-        )
         self._create_container()
 
     def _create_container(self) -> None:
@@ -82,6 +79,7 @@ def _create_container(self) -> None:
             for container in self._blob_service_client.list_containers()
         ]
         if container_name not in container_names:
+            logger.debug("Creating new container [%s]", container_name)
             self._blob_service_client.create_container(container_name)
 
     def _delete_container(self) -> None:
@@ -100,31 +98,26 @@ def _container_exists(self) -> bool:
     def find(
         self,
         file_pattern: re.Pattern[str],
-        base_dir: str | None = None,
-        max_count=-1,
     ) -> Iterator[str]:
         """Find blobs in a container using a file pattern.
 
         Params:
-            base_dir: The name of the base container.
             file_pattern: The file pattern to use.
-            max_count: The maximum number of blobs to return. If -1, all blobs are returned.
 
         Returns
         -------
             An iterator of blob names and their corresponding regex matches.
         """
-        base_dir = base_dir or ""
-
         logger.info(
-            "search container %s for files matching %s",
+            "Search container [%s] in base_dir [%s] for files matching [%s]",
             self._container_name,
+            self._base_dir,
             file_pattern.pattern,
         )
 
         def _blobname(blob_name: str) -> str:
-            if blob_name.startswith(self._path_prefix):
-                blob_name = blob_name.replace(self._path_prefix, "", 1)
+            if self._base_dir and blob_name.startswith(self._base_dir):
+                blob_name = blob_name.replace(self._base_dir, "", 1)
             if blob_name.startswith("/"):
                 blob_name = blob_name[1:]
             return blob_name
@@ -133,37 +126,35 @@ def _blobname(blob_name: str) -> str:
             container_client = self._blob_service_client.get_container_client(
                 self._container_name
             )
-            all_blobs = list(container_client.list_blobs())
-
+            all_blobs = list(container_client.list_blobs(self._base_dir))
+            logger.debug("All blobs: %s", [blob.name for blob in all_blobs])
             num_loaded = 0
             num_total = len(list(all_blobs))
             num_filtered = 0
             for blob in all_blobs:
                 match = file_pattern.search(blob.name)
-                if match and blob.name.startswith(base_dir):
+                if match:
                     yield _blobname(blob.name)
                     num_loaded += 1
-                    if max_count > 0 and num_loaded >= max_count:
-                        break
                 else:
                     num_filtered += 1
-                logger.debug(
-                    "Blobs loaded: %d, filtered: %d, total: %d",
-                    num_loaded,
-                    num_filtered,
-                    num_total,
-                )
+            logger.debug(
+                "Blobs loaded: %d, filtered: %d, total: %d",
+                num_loaded,
+                num_filtered,
+                num_total,
+            )
         except Exception:  # noqa: BLE001
             logger.warning(
                 "Error finding blobs: base_dir=%s, file_pattern=%s",
-                base_dir,
+                self._base_dir,
                 file_pattern,
             )
 
     async def get(
         self, key: str, as_bytes: bool | None = False, encoding: str | None = None
     ) -> Any:
-        """Get a value from the cache."""
+        """Get a value from the blob."""
         try:
             key = self._keyname(key)
             container_client = self._blob_service_client.get_container_client(
@@ -181,7 +172,7 @@ async def get(
             return blob_data
 
     async def set(self, key: str, value: Any, encoding: str | None = None) -> None:
-        """Set a value in the cache."""
+        """Set a value in the blob."""
         try:
             key = self._keyname(key)
             container_client = self._blob_service_client.get_container_client(
@@ -196,46 +187,8 @@ async def set(self, key: str, value: Any, encoding: str | None = None) -> None:
         except Exception:
             logger.exception("Error setting key %s: %s", key)
 
-    def _set_df_json(self, key: str, dataframe: Any) -> None:
-        """Set a json dataframe."""
-        if self._connection_string is None and self._storage_account_name:
-            dataframe.to_json(
-                self._abfs_url(key),
-                storage_options={
-                    "account_name": self._storage_account_name,
-                    "credential": DefaultAzureCredential(),
-                },
-                orient="records",
-                lines=True,
-                force_ascii=False,
-            )
-        else:
-            dataframe.to_json(
-                self._abfs_url(key),
-                storage_options={"connection_string": self._connection_string},
-                orient="records",
-                lines=True,
-                force_ascii=False,
-            )
-
-    def _set_df_parquet(self, key: str, dataframe: Any) -> None:
-        """Set a parquet dataframe."""
-        if self._connection_string is None and self._storage_account_name:
-            dataframe.to_parquet(
-                self._abfs_url(key),
-                storage_options={
-                    "account_name": self._storage_account_name,
-                    "credential": DefaultAzureCredential(),
-                },
-            )
-        else:
-            dataframe.to_parquet(
-                self._abfs_url(key),
-                storage_options={"connection_string": self._connection_string},
-            )
-
     async def has(self, key: str) -> bool:
-        """Check if a key exists in the cache."""
+        """Check if a key exists in the blob."""
         key = self._keyname(key)
         container_client = self._blob_service_client.get_container_client(
             self._container_name
@@ -244,7 +197,7 @@ async def has(self, key: str) -> bool:
         return blob_client.exists()
 
     async def delete(self, key: str) -> None:
-        """Delete a key from the cache."""
+        """Delete a key from the blob."""
         key = self._keyname(key)
         container_client = self._blob_service_client.get_container_client(
             self._container_name
@@ -259,7 +212,7 @@ def child(self, name: str | None) -> "PipelineStorage":
         """Create a child storage instance."""
         if name is None:
             return self
-        path = str(Path(self._path_prefix) / name)
+        path = str(Path(self._base_dir) / name) if self._base_dir else name
         return BlobPipelineStorage(
             connection_string=self._connection_string,
             container_name=self._container_name,
@@ -275,15 +228,10 @@ def keys(self) -> list[str]:
 
     def _keyname(self, key: str) -> str:
         """Get the key name."""
-        return str(Path(self._path_prefix) / key)
-
-    def _abfs_url(self, key: str) -> str:
-        """Get the ABFS URL."""
-        path = str(Path(self._container_name) / self._path_prefix / key)
-        return f"abfs://{path}"
+        return str(Path(self._base_dir) / key) if self._base_dir else key
 
     async def get_creation_date(self, key: str) -> str:
-        """Get a value from the cache."""
+        """Get creation date for the blob."""
         try:
             key = self._keyname(key)
             container_client = self._blob_service_client.get_container_client(
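
The recurring expression `str(Path(self._base_dir) / key) if self._base_dir else key` is the whole prefixing contract this diff settles on: keys pass through untouched when no base_dir is configured, and are joined under it otherwise, with _blobname applying the inverse when listing. A standalone sketch of both directions, as hypothetical free functions rather than the class methods:

from pathlib import Path

def keyname(base_dir: str | None, key: str) -> str:
    # Join under the prefix only when one is configured.
    return str(Path(base_dir) / key) if base_dir else key

def blobname(base_dir: str | None, blob_name: str) -> str:
    # Strip the prefix exactly once, then any leading slash.
    if base_dir and blob_name.startswith(base_dir):
        blob_name = blob_name.replace(base_dir, "", 1)
    return blob_name.removeprefix("/")

# POSIX path separators assumed in these checks.
assert keyname(None, "entities.parquet") == "entities.parquet"
assert keyname("output", "entities.parquet") == "output/entities.parquet"
assert blobname("output", "output/entities.parquet") == "entities.parquet"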

packages/graphrag/graphrag/storage/cosmosdb_pipeline_storage.py

Lines changed: 12 additions & 18 deletions
@@ -77,7 +77,7 @@ def __init__(self, **kwargs: Any) -> None:
         )
         self._no_id_prefixes = []
         logger.debug(
-            "creating cosmosdb storage with account: %s and database: %s and container: %s",
+            "Creating cosmosdb storage with account [%s] and database [%s] and container [%s]",
             self._cosmosdb_account_name,
             self._database_name,
             self._container_name,
@@ -120,23 +120,18 @@ def _delete_container(self) -> None:
     def find(
         self,
         file_pattern: re.Pattern[str],
-        base_dir: str | None = None,
-        max_count=-1,
     ) -> Iterator[str]:
         """Find documents in a Cosmos DB container using a file pattern regex.
 
         Params:
-            base_dir: The name of the base directory (not used in Cosmos DB context).
             file_pattern: The file pattern to use.
-            max_count: The maximum number of documents to return. If -1, all documents are returned.
 
         Returns
         -------
             An iterator of document IDs and their corresponding regex matches.
         """
-        base_dir = base_dir or ""
         logger.info(
-            "search container %s for documents matching %s",
+            "Search container [%s] for documents matching [%s]",
             self._container_name,
             file_pattern.pattern,
         )
@@ -156,6 +151,7 @@ def find(
                     enable_cross_partition_query=True,
                 )
             )
+            logger.debug("All items: %s", [item["id"] for item in items])
             num_loaded = 0
             num_total = len(items)
             if num_total == 0:
@@ -166,20 +162,18 @@ def find(
                 if match:
                     yield item["id"]
                     num_loaded += 1
-                    if max_count > 0 and num_loaded >= max_count:
-                        break
                 else:
                     num_filtered += 1
 
-                progress_status = _create_progress_status(
-                    num_loaded, num_filtered, num_total
-                )
-                logger.debug(
-                    "Progress: %s (%d/%d completed)",
-                    progress_status.description,
-                    progress_status.completed_items,
-                    progress_status.total_items,
-                )
+            progress_status = _create_progress_status(
+                num_loaded, num_filtered, num_total
+            )
+            logger.debug(
+                "Progress: %s (%d/%d completed)",
+                progress_status.description,
+                progress_status.completed_items,
+                progress_status.total_items,
+            )
         except Exception:  # noqa: BLE001
             logger.warning(
                 "An error occurred while searching for documents in Cosmos DB."
