@@ -25,15 +25,15 @@ class BlobPipelineStorage(PipelineStorage):
2525
2626 _connection_string : str | None
2727 _container_name : str
28- _path_prefix : str
28+ _base_dir : str | None
2929 _encoding : str
3030 _storage_account_blob_url : str | None
3131
3232 def __init__ (self , ** kwargs : Any ) -> None :
3333 """Create a new BlobStorage instance."""
3434 connection_string = kwargs .get ("connection_string" )
3535 storage_account_blob_url = kwargs .get ("storage_account_blob_url" )
36- path_prefix = kwargs .get ("base_dir" )
36+ base_dir = kwargs .get ("base_dir" )
3737 container_name = kwargs ["container_name" ]
3838 if container_name is None :
3939 msg = "No container name provided for blob storage."
@@ -42,7 +42,9 @@ def __init__(self, **kwargs: Any) -> None:
4242 msg = "No storage account blob url provided for blob storage."
4343 raise ValueError (msg )
4444
45- logger .info ("Creating blob storage at %s" , container_name )
45+ logger .info (
46+ "Creating blob storage at [%s] and base_dir [%s]" , container_name , base_dir
47+ )
4648 if connection_string :
4749 self ._blob_service_client = BlobServiceClient .from_connection_string (
4850 connection_string
@@ -59,18 +61,13 @@ def __init__(self, **kwargs: Any) -> None:
5961 self ._encoding = kwargs .get ("encoding" , "utf-8" )
6062 self ._container_name = container_name
6163 self ._connection_string = connection_string
62- self ._path_prefix = path_prefix or ""
64+ self ._base_dir = base_dir
6365 self ._storage_account_blob_url = storage_account_blob_url
6466 self ._storage_account_name = (
6567 storage_account_blob_url .split ("//" )[1 ].split ("." )[0 ]
6668 if storage_account_blob_url
6769 else None
6870 )
69- logger .debug (
70- "creating blob storage at container=%s, path=%s" ,
71- self ._container_name ,
72- self ._path_prefix ,
73- )
7471 self ._create_container ()
7572
7673 def _create_container (self ) -> None :
@@ -82,6 +79,7 @@ def _create_container(self) -> None:
8279 for container in self ._blob_service_client .list_containers ()
8380 ]
8481 if container_name not in container_names :
82+ logger .debug ("Creating new container [%s]" , container_name )
8583 self ._blob_service_client .create_container (container_name )
8684
8785 def _delete_container (self ) -> None :
@@ -100,31 +98,26 @@ def _container_exists(self) -> bool:
10098 def find (
10199 self ,
102100 file_pattern : re .Pattern [str ],
103- base_dir : str | None = None ,
104- max_count = - 1 ,
105101 ) -> Iterator [str ]:
106102 """Find blobs in a container using a file pattern.
107103
108104 Params:
109- base_dir: The name of the base container.
110105 file_pattern: The file pattern to use.
111- max_count: The maximum number of blobs to return. If -1, all blobs are returned.
112106
113107 Returns
114108 -------
115109 An iterator of blob names and their corresponding regex matches.
116110 """
117- base_dir = base_dir or ""
118-
119111 logger .info (
120- "search container %s for files matching %s " ,
112+ "Search container [%s] in base_dir [%s] for files matching [%s] " ,
121113 self ._container_name ,
114+ self ._base_dir ,
122115 file_pattern .pattern ,
123116 )
124117
125118 def _blobname (blob_name : str ) -> str :
126- if blob_name .startswith (self ._path_prefix ):
127- blob_name = blob_name .replace (self ._path_prefix , "" , 1 )
119+ if self . _base_dir and blob_name .startswith (self ._base_dir ):
120+ blob_name = blob_name .replace (self ._base_dir , "" , 1 )
128121 if blob_name .startswith ("/" ):
129122 blob_name = blob_name [1 :]
130123 return blob_name
@@ -133,37 +126,35 @@ def _blobname(blob_name: str) -> str:
133126 container_client = self ._blob_service_client .get_container_client (
134127 self ._container_name
135128 )
136- all_blobs = list (container_client .list_blobs ())
137-
129+ all_blobs = list (container_client .list_blobs (self . _base_dir ))
130+ logger . debug ( "All blobs: %s" , [ blob . name for blob in all_blobs ])
138131 num_loaded = 0
139132 num_total = len (list (all_blobs ))
140133 num_filtered = 0
141134 for blob in all_blobs :
142135 match = file_pattern .search (blob .name )
143- if match and blob . name . startswith ( base_dir ) :
136+ if match :
144137 yield _blobname (blob .name )
145138 num_loaded += 1
146- if max_count > 0 and num_loaded >= max_count :
147- break
148139 else :
149140 num_filtered += 1
150- logger .debug (
151- "Blobs loaded: %d, filtered: %d, total: %d" ,
152- num_loaded ,
153- num_filtered ,
154- num_total ,
155- )
141+ logger .debug (
142+ "Blobs loaded: %d, filtered: %d, total: %d" ,
143+ num_loaded ,
144+ num_filtered ,
145+ num_total ,
146+ )
156147 except Exception : # noqa: BLE001
157148 logger .warning (
158149 "Error finding blobs: base_dir=%s, file_pattern=%s" ,
159- base_dir ,
150+ self . _base_dir ,
160151 file_pattern ,
161152 )
162153
163154 async def get (
164155 self , key : str , as_bytes : bool | None = False , encoding : str | None = None
165156 ) -> Any :
166- """Get a value from the cache ."""
157+ """Get a value from the blob ."""
167158 try :
168159 key = self ._keyname (key )
169160 container_client = self ._blob_service_client .get_container_client (
@@ -181,7 +172,7 @@ async def get(
181172 return blob_data
182173
183174 async def set (self , key : str , value : Any , encoding : str | None = None ) -> None :
184- """Set a value in the cache ."""
175+ """Set a value in the blob ."""
185176 try :
186177 key = self ._keyname (key )
187178 container_client = self ._blob_service_client .get_container_client (
@@ -196,46 +187,8 @@ async def set(self, key: str, value: Any, encoding: str | None = None) -> None:
196187 except Exception :
197188 logger .exception ("Error setting key %s: %s" , key )
198189
199- def _set_df_json (self , key : str , dataframe : Any ) -> None :
200- """Set a json dataframe."""
201- if self ._connection_string is None and self ._storage_account_name :
202- dataframe .to_json (
203- self ._abfs_url (key ),
204- storage_options = {
205- "account_name" : self ._storage_account_name ,
206- "credential" : DefaultAzureCredential (),
207- },
208- orient = "records" ,
209- lines = True ,
210- force_ascii = False ,
211- )
212- else :
213- dataframe .to_json (
214- self ._abfs_url (key ),
215- storage_options = {"connection_string" : self ._connection_string },
216- orient = "records" ,
217- lines = True ,
218- force_ascii = False ,
219- )
220-
221- def _set_df_parquet (self , key : str , dataframe : Any ) -> None :
222- """Set a parquet dataframe."""
223- if self ._connection_string is None and self ._storage_account_name :
224- dataframe .to_parquet (
225- self ._abfs_url (key ),
226- storage_options = {
227- "account_name" : self ._storage_account_name ,
228- "credential" : DefaultAzureCredential (),
229- },
230- )
231- else :
232- dataframe .to_parquet (
233- self ._abfs_url (key ),
234- storage_options = {"connection_string" : self ._connection_string },
235- )
236-
237190 async def has (self , key : str ) -> bool :
238- """Check if a key exists in the cache ."""
191+ """Check if a key exists in the blob ."""
239192 key = self ._keyname (key )
240193 container_client = self ._blob_service_client .get_container_client (
241194 self ._container_name
@@ -244,7 +197,7 @@ async def has(self, key: str) -> bool:
244197 return blob_client .exists ()
245198
246199 async def delete (self , key : str ) -> None :
247- """Delete a key from the cache ."""
200+ """Delete a key from the blob ."""
248201 key = self ._keyname (key )
249202 container_client = self ._blob_service_client .get_container_client (
250203 self ._container_name
@@ -259,7 +212,7 @@ def child(self, name: str | None) -> "PipelineStorage":
259212 """Create a child storage instance."""
260213 if name is None :
261214 return self
262- path = str (Path (self ._path_prefix ) / name )
215+ path = str (Path (self ._base_dir ) / name ) if self . _base_dir else name
263216 return BlobPipelineStorage (
264217 connection_string = self ._connection_string ,
265218 container_name = self ._container_name ,
@@ -275,15 +228,10 @@ def keys(self) -> list[str]:
275228
276229 def _keyname (self , key : str ) -> str :
277230 """Get the key name."""
278- return str (Path (self ._path_prefix ) / key )
279-
280- def _abfs_url (self , key : str ) -> str :
281- """Get the ABFS URL."""
282- path = str (Path (self ._container_name ) / self ._path_prefix / key )
283- return f"abfs://{ path } "
231+ return str (Path (self ._base_dir ) / key ) if self ._base_dir else key
284232
285233 async def get_creation_date (self , key : str ) -> str :
286- """Get a value from the cache ."""
234+ """Get creation date for the blob ."""
287235 try :
288236 key = self ._keyname (key )
289237 container_client = self ._blob_service_client .get_container_client (
0 commit comments