diff --git a/src/murfey/server/api/session_shared.py b/src/murfey/server/api/session_shared.py index 644b9b35..3a9955fa 100644 --- a/src/murfey/server/api/session_shared.py +++ b/src/murfey/server/api/session_shared.py @@ -1,4 +1,5 @@ import logging +import os from pathlib import Path from typing import Dict, List @@ -136,11 +137,48 @@ def get_foil_hole(session_id: int, fh_name: int, db) -> Dict[str, int]: return {f[1].tag: f[0].id for f in foil_holes} -def find_upstream_visits(session_id: int, db: SQLModelSession): +def find_upstream_visits(session_id: int, db: SQLModelSession, max_depth: int = 2): """ Returns a nested dictionary, in which visits and the full paths to their directories are further grouped by instrument name. """ + + def _recursive_search( + dirpath: str | Path, + search_string: str, + partial_match: bool = True, + max_depth: int = 1, + result: dict[str, Path] | None = None, + ): + # If no dictionary was passed in, create a new dictionary + if result is None: + result = {} + # Stop recursing for this route once max depth hits 0 + if max_depth == 0: + return result + + # Walk through the directories + for entry in os.scandir(dirpath): + if entry.is_dir(): + # Update dictionary with match and stop recursing for this route + if ( + search_string in entry.name + if partial_match + else search_string == entry.name + ): + if result is not None: # MyPy needs this 'is not None' check + result[entry.name] = Path(entry.path) + else: + # Continue searching down this route until max depth is reached + result = _recursive_search( + dirpath=entry.path, + search_string=search_string, + partial_match=partial_match, + max_depth=max_depth - 1, + result=result, + ) + return result + murfey_session = db.exec( select(MurfeySession).where(MurfeySession.id == session_id) ).one() @@ -155,12 +193,13 @@ def find_upstream_visits(session_id: int, db: SQLModelSession): upstream_instrument, upstream_data_dir, ) in machine_config.upstream_data_directories.items(): - # Looks for visit name in file path - current_upstream_visits = {} - for visit_path in Path(upstream_data_dir).glob(f"{visit_name.split('-')[0]}-*"): - if visit_path.is_dir(): - current_upstream_visits[visit_path.name] = visit_path - upstream_visits[upstream_instrument] = current_upstream_visits + # Recursively look for matching visit names under current directory + upstream_visits[upstream_instrument] = _recursive_search( + dirpath=upstream_data_dir, + search_string=f"{visit_name.split('-')[0]}-", + partial_match=True, + max_depth=max_depth, + ) return upstream_visits diff --git a/tests/server/api/test_session_shared.py b/tests/server/api/test_session_shared.py index c0f4c109..67db321d 100644 --- a/tests/server/api/test_session_shared.py +++ b/tests/server/api/test_session_shared.py @@ -9,10 +9,11 @@ from tests.conftest import ExampleVisit +@pytest.mark.parametrize("recurse", (True, False)) def test_find_upstream_visits( mocker: MockerFixture, tmp_path: Path, - # murfey_db_session, + recurse: bool, ): # Get the visit, instrument name, and session ID visit_name_root = f"{ExampleVisit.proposal_code}{ExampleVisit.proposal_number}" @@ -40,7 +41,10 @@ def test_find_upstream_visits( # Only directories should be picked up upstream_visit.mkdir(parents=True, exist_ok=True) upstream_visits[upstream_instrument] = {upstream_visit.stem: upstream_visit} - upstream_data_dirs[upstream_instrument] = upstream_visit.parent + # Check that the function can cope with recursive searching + upstream_data_dirs[upstream_instrument] = ( + upstream_visit.parent.parent if recurse else upstream_visit.parent + ) else: upstream_visit.parent.mkdir(parents=True, exist_ok=True) upstream_visit.touch(exist_ok=True)