Skip to content
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
53 changes: 46 additions & 7 deletions src/murfey/server/api/session_shared.py
Original file line number Diff line number Diff line change
@@ -1,4 +1,5 @@
import logging
import os
from pathlib import Path
from typing import Dict, List

Expand Down Expand Up @@ -136,11 +137,48 @@ def get_foil_hole(session_id: int, fh_name: int, db) -> Dict[str, int]:
return {f[1].tag: f[0].id for f in foil_holes}


def find_upstream_visits(session_id: int, db: SQLModelSession):
def find_upstream_visits(session_id: int, db: SQLModelSession, max_depth: int = 2):
"""
Returns a nested dictionary, in which visits and the full paths to their directories
are further grouped by instrument name.
"""

def _recursive_search(
dirpath: str | Path,
search_string: str,
partial_match: bool = True,
max_depth: int = 1,
result: dict[str, Path] | None = None,
):
# If no dictionary was passed in, create a new dictionary
if result is None:
result = {}
# Stop recursing for this route once max depth hits 0
if max_depth == 0:
return result

# Walk through the directories
for entry in os.scandir(dirpath):
if entry.is_dir():
# Update dictionary with match and stop recursing for this route
if (
search_string in entry.name
if partial_match
else search_string == entry.name
):
if result is not None: # MyPy needs this 'is not None' check
result[entry.name] = Path(entry.path)
else:
# Continue searching down this route until max depth is reached
result = _recursive_search(
dirpath=entry.path,
search_string=search_string,
partial_match=partial_match,
max_depth=max_depth - 1,
result=result,
)
return result

murfey_session = db.exec(
select(MurfeySession).where(MurfeySession.id == session_id)
).one()
Expand All @@ -155,12 +193,13 @@ def find_upstream_visits(session_id: int, db: SQLModelSession):
upstream_instrument,
upstream_data_dir,
) in machine_config.upstream_data_directories.items():
# Looks for visit name in file path
current_upstream_visits = {}
for visit_path in Path(upstream_data_dir).glob(f"{visit_name.split('-')[0]}-*"):
if visit_path.is_dir():
current_upstream_visits[visit_path.name] = visit_path
upstream_visits[upstream_instrument] = current_upstream_visits
# Recursively look for matching visit names under current directory
upstream_visits[upstream_instrument] = _recursive_search(
dirpath=upstream_data_dir,
search_string=f"{visit_name.split('-')[0]}-",
partial_match=True,
max_depth=max_depth,
)
return upstream_visits


Expand Down
8 changes: 6 additions & 2 deletions tests/server/api/test_session_shared.py
Original file line number Diff line number Diff line change
Expand Up @@ -9,10 +9,11 @@
from tests.conftest import ExampleVisit


@pytest.mark.parametrize("recurse", (True, False))
def test_find_upstream_visits(
mocker: MockerFixture,
tmp_path: Path,
# murfey_db_session,
recurse: bool,
):
# Get the visit, instrument name, and session ID
visit_name_root = f"{ExampleVisit.proposal_code}{ExampleVisit.proposal_number}"
Expand Down Expand Up @@ -40,7 +41,10 @@ def test_find_upstream_visits(
# Only directories should be picked up
upstream_visit.mkdir(parents=True, exist_ok=True)
upstream_visits[upstream_instrument] = {upstream_visit.stem: upstream_visit}
upstream_data_dirs[upstream_instrument] = upstream_visit.parent
# Check that the function can cope with recursive searching
upstream_data_dirs[upstream_instrument] = (
upstream_visit.parent.parent if recurse else upstream_visit.parent
)
else:
upstream_visit.parent.mkdir(parents=True, exist_ok=True)
upstream_visit.touch(exist_ok=True)
Expand Down