Skip to content

Commit 7b0a5f3

Browse files
authored
Merge pull request #29 from jkawamoto/structured-output
Refactor YouTube transcript retrieval to structured model
2 parents 2808172 + b1d6f02 commit 7b0a5f3

File tree

2 files changed

+41
-16
lines changed

2 files changed

+41
-16
lines changed

src/mcp_youtube_transcript/__init__.py

Lines changed: 15 additions & 8 deletions
Original file line numberDiff line numberDiff line change
@@ -8,15 +8,15 @@
88
from contextlib import asynccontextmanager
99
from dataclasses import dataclass
1010
from functools import lru_cache, partial
11-
from typing import AsyncIterator
11+
from typing import AsyncIterator, Tuple
1212
from typing import Final
1313
from urllib.parse import urlparse, parse_qs
1414

1515
import requests
1616
from bs4 import BeautifulSoup
1717
from mcp.server import FastMCP
1818
from mcp.server.fastmcp import Context
19-
from pydantic import Field
19+
from pydantic import Field, BaseModel
2020
from youtube_transcript_api import YouTubeTranscriptApi
2121
from youtube_transcript_api.proxies import WebshareProxyConfig, GenericProxyConfig, ProxyConfig
2222

@@ -35,7 +35,7 @@ async def _app_lifespan(_server: FastMCP, proxy_config: ProxyConfig | None) -> A
3535

3636

3737
@lru_cache
38-
def _get_transcript(ctx: AppContext, video_id: str, lang: str) -> str:
38+
def _get_transcript(ctx: AppContext, video_id: str, lang: str) -> Tuple[str, str]:
3939
if lang == "en":
4040
languages = ["en"]
4141
else:
@@ -46,11 +46,17 @@ def _get_transcript(ctx: AppContext, video_id: str, lang: str) -> str:
4646
)
4747
page.raise_for_status()
4848
soup = BeautifulSoup(page.text, "html.parser")
49-
title = soup.title.string if soup.title else "Transcript"
49+
title = soup.title.string if soup.title and soup.title.string else "Transcript"
5050

5151
transcripts = ctx.ytt_api.fetch(video_id, languages=languages)
52+
return title, "\n".join((item.text for item in transcripts))
5253

53-
return f"# {title}\n" + "\n".join((item.text for item in transcripts))
54+
55+
class Transcript(BaseModel):
56+
"""Transcript of a YouTube video."""
57+
58+
title: str = Field(description="Title of the video")
59+
transcript: str = Field(description="Transcript of the video")
5460

5561

5662
def server(
@@ -74,7 +80,7 @@ async def get_transcript(
7480
ctx: Context,
7581
url: str = Field(description="The URL of the YouTube video"),
7682
lang: str = Field(description="The preferred language for the transcript", default="en"),
77-
) -> str:
83+
) -> Transcript:
7884
"""Retrieves the transcript of a YouTube video."""
7985
parsed_url = urlparse(url)
8086
if parsed_url.hostname == "youtu.be":
@@ -86,9 +92,10 @@ async def get_transcript(
8692
video_id = q[0]
8793

8894
app_ctx: AppContext = ctx.request_context.lifespan_context # type: ignore
89-
return _get_transcript(app_ctx, video_id, lang)
95+
title, transcript = _get_transcript(app_ctx, video_id, lang)
96+
return Transcript(title=title, transcript=transcript)
9097

9198
return mcp
9299

93100

94-
__all__: Final = ["server"]
101+
__all__: Final = ["server", "Transcript"]

tests/test_mcp.py

Lines changed: 26 additions & 8 deletions
Original file line numberDiff line numberDiff line change
@@ -15,6 +15,8 @@
1515
from mcp.types import TextContent
1616
from youtube_transcript_api import YouTubeTranscriptApi
1717

18+
from mcp_youtube_transcript import Transcript
19+
1820
params = StdioServerParameters(command="uv", args=["run", "mcp-youtube-transcript"])
1921

2022

@@ -44,14 +46,18 @@ async def test_get_transcript(mcp_client_session: ClientSession) -> None:
4446
video_id = "LPZh9BOjkQs"
4547

4648
title = fetch_title(video_id, "en")
47-
expect = f"# {title}\n" + "\n".join((item.text for item in YouTubeTranscriptApi().fetch(video_id)))
49+
expect = Transcript(
50+
title=title, transcript="\n".join((item.text for item in YouTubeTranscriptApi().fetch(video_id)))
51+
)
4852

4953
res = await mcp_client_session.call_tool(
5054
"get_transcript",
5155
arguments={"url": f"https//www.youtube.com/watch?v={video_id}"},
5256
)
5357
assert isinstance(res.content[0], TextContent)
54-
assert res.content[0].text == expect
58+
59+
transcript = Transcript.model_validate_json(res.content[0].text)
60+
assert transcript == expect
5561
assert not res.isError
5662

5763

@@ -61,14 +67,18 @@ async def test_get_transcript_with_language(mcp_client_session: ClientSession) -
6167
video_id = "WjAXZkQSE2U"
6268

6369
title = fetch_title(video_id, "ja")
64-
expect = f"# {title}\n" + "\n".join((item.text for item in YouTubeTranscriptApi().fetch(video_id, ["ja"])))
70+
expect = Transcript(
71+
title=title, transcript="\n".join((item.text for item in YouTubeTranscriptApi().fetch(video_id, ["ja"])))
72+
)
6573

6674
res = await mcp_client_session.call_tool(
6775
"get_transcript",
6876
arguments={"url": f"https//www.youtube.com/watch?v={video_id}", "lang": "ja"},
6977
)
7078
assert isinstance(res.content[0], TextContent)
71-
assert res.content[0].text == expect
79+
80+
transcript = Transcript.model_validate_json(res.content[0].text)
81+
assert transcript == expect
7282
assert not res.isError
7383

7484

@@ -80,7 +90,9 @@ async def test_get_transcript_fallback_language(
8090
video_id = "LPZh9BOjkQs"
8191

8292
title = fetch_title(video_id, "en")
83-
expect = f"# {title}\n" + "\n".join((item.text for item in YouTubeTranscriptApi().fetch(video_id)))
93+
expect = Transcript(
94+
title=title, transcript="\n".join((item.text for item in YouTubeTranscriptApi().fetch(video_id)))
95+
)
8496

8597
res = await mcp_client_session.call_tool(
8698
"get_transcript",
@@ -90,7 +102,9 @@ async def test_get_transcript_fallback_language(
90102
},
91103
)
92104
assert isinstance(res.content[0], TextContent)
93-
assert res.content[0].text == expect
105+
106+
transcript = Transcript.model_validate_json(res.content[0].text)
107+
assert transcript == expect
94108
assert not res.isError
95109

96110

@@ -115,12 +129,16 @@ async def test_get_transcript_with_short_url(mcp_client_session: ClientSession)
115129
video_id = "LPZh9BOjkQs"
116130

117131
title = fetch_title(video_id, "en")
118-
expect = f"# {title}\n" + "\n".join((item.text for item in YouTubeTranscriptApi().fetch(video_id)))
132+
expect = Transcript(
133+
title=title, transcript="\n".join((item.text for item in YouTubeTranscriptApi().fetch(video_id)))
134+
)
119135

120136
res = await mcp_client_session.call_tool(
121137
"get_transcript",
122138
arguments={"url": f"https://youtu.be/{video_id}"},
123139
)
124140
assert isinstance(res.content[0], TextContent)
125-
assert res.content[0].text == expect
141+
142+
transcript = Transcript.model_validate_json(res.content[0].text)
143+
assert transcript == expect
126144
assert not res.isError

0 commit comments

Comments
 (0)