Commit eb7cc1e

Kravchie authored and meta-codesync[bot] committed
Add Llama API endpoint
Reviewed By: laurendeason, kwiha

Differential Revision: D86491153

fbshipit-source-id: 20e49c303358bcc230ea3003f9e53da6c31b55a6
1 parent 48a199c commit eb7cc1e

3 files changed (+298, -1 lines changed)

CybersecurityBenchmarks/README.md

Lines changed: 1 addition & 1 deletion

@@ -119,7 +119,7 @@ export DATASETS=$PWD/CybersecurityBenchmarks/datasets

Each benchmark can run tests for multiple LLMs. Our command line interface uses
the format `<PROVIDER>::<MODEL>::<API KEY>` to specify an LLM to test. We
-currently support APIs from OPENAI, and TOGETHER. For OpenAI compatible endpoints,
+currently support APIs from OPENAI, GOOGLE, ANTHROPIC, LLAMA and TOGETHER. For OpenAI compatible endpoints,
you can also specify a custom base URL by using this format: `<PROVIDER>::<MODEL>::<API KEY>::<BASE URL>`.
The followings are a few examples:

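Following the format above, the Llama API support added by this commit can be selected with a spec such as:

LLAMA::Llama-4-Scout-17B-16E-Instruct-FP8::<API KEY>

(The model name is taken from valid_models() in the new llms/meta.py below; the CLI option that wraps this spec, e.g. an --llm-under-test style flag, is whichever one the README's existing run commands use.)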
CybersecurityBenchmarks/benchmark/llm.py

Lines changed: 3 additions & 0 deletions

@@ -24,6 +24,7 @@
    RetriesExceededException,
    UnretriableQueryException,
)
+from .llms.meta import LLAMA
from .llms.openai import OPENAI
from .llms.together import TOGETHER

@@ -84,5 +85,7 @@ def create(
        return ANTHROPIC(config)
    if provider == "GOOGLEGENAI":
        return GOOGLEGENAI(config)
+    if provider == "LLAMA":
+        return LLAMA(config)

    raise ValueError(f"Unknown provider: {provider}")
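A minimal sketch of how a "<PROVIDER>::<MODEL>::<API KEY>" spec could reach the new branch above; the spec splitting and the LLMConfig keyword arguments are assumptions for illustration, not the repo's actual create() body:

# assuming llm.py's surrounding imports are in scope, e.g.:
#   from .llms.meta import LLAMA
#   from .llms.llm_base import LLM, LLMConfig   # import path assumed
def create_llama_from_spec(spec: str) -> LLM:
    # e.g. spec = "LLAMA::Llama-4-Scout-17B-16E-Instruct-FP8::<API KEY>"
    provider, model, api_key = spec.split("::", 2)
    if provider != "LLAMA":
        raise ValueError(f"Unknown provider: {provider}")
    # hypothetical LLMConfig field names
    return LLAMA(LLMConfig(model=model, api_key=api_key))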
CybersecurityBenchmarks/benchmark/llms/meta.py

Lines changed: 294 additions & 0 deletions
@@ -0,0 +1,294 @@
# Copyright (c) Meta Platforms, Inc. and affiliates.
#
# This source code is licensed under the MIT license found in the
# LICENSE file in the root directory of this source tree.

# pyre-strict

from __future__ import annotations

import json
import logging
from typing import Any, Dict, List, Optional

import openai
from typing_extensions import override

from ..benchmark_utils import image_to_b64
from .llm_base import (
    DEFAULT_MAX_TOKENS,
    DEFAULT_TEMPERATURE,
    DEFAULT_TOP_P,
    LLM,
    LLMConfig,
    UnretriableQueryException,
)

LOG: logging.Logger = logging.getLogger(__name__)


class LLAMA(LLM):
    """Accessing Meta Llama API"""

    def __init__(self, config: LLMConfig) -> None:
        super().__init__(config)
        self.client = openai.OpenAI(
            api_key=self.api_key,
            base_url="https://api.llama.com/v1",
        )

    def _extract_response(self, response: Any) -> str:
        """
        Helper method to extract content from Meta Llama API response.

        Meta's API returns a different structure:
        response.completion_message = {
            (...)
            'content': {'type': 'text', 'text': 'actual response text'}
        }
        """
        text = ""

        try:
            if hasattr(response, "completion_message") and response.completion_message:
                completion_msg = response.completion_message

                # Extract the text content
                content = completion_msg.get("content", {})
                text = content.get("text")

        except AttributeError as e:
            raise UnretriableQueryException(
                f"Unexpected response structure from Meta Llama API: {e}. Response: {response}"
            )

        if text is None or text == "":
            raise ValueError("Extracted response is empty.")

        return text

    def _build_response_format(
        self, guided_decode_json_schema: str
    ) -> Optional[Dict[str, Any]]:
        """
        Build response_format for Meta Llama API.

        Meta uses a different format than OpenAI:
        {
            "type": "json_schema",
            "json_schema": {
                "name": "ResponseSchema",
                "schema": { ... actual JSON schema ... }
            }
        }
        """

        try:
            # Parse the JSON schema if it is a string
            if isinstance(guided_decode_json_schema, str):
                schema = json.loads(guided_decode_json_schema)
            else:
                schema = guided_decode_json_schema

            # Build Meta's required format
            return {
                "type": "json_schema",
                "json_schema": {"name": "ResponseSchema", "schema": schema},
            }
        except json.JSONDecodeError as e:
            LOG.warning(
                f"Failed to parse JSON schema: {e}. Proceeding without response_format."
            )
            return None

    @override
    def chat(
        self,
        prompt_with_history: List[str],
        guided_decode_json_schema: Optional[str] = None,
        temperature: float = DEFAULT_TEMPERATURE,
        top_p: float = DEFAULT_TOP_P,
    ) -> str:
        messages = []
        for i in range(len(prompt_with_history)):
            if i % 2 == 0:
                messages.append({"role": "user", "content": prompt_with_history[i]})
            else:
                messages.append(
                    {"role": "assistant", "content": prompt_with_history[i]}
                )

        params: Dict[str, Any] = {
            "model": self.model,
            "messages": messages,
            "max_tokens": DEFAULT_MAX_TOKENS,
            "temperature": temperature,
            "top_p": top_p,
        }

        if guided_decode_json_schema is not None:
            response_format = self._build_response_format(guided_decode_json_schema)
            if response_format is not None:
                params["response_format"] = response_format
        response = self.client.chat.completions.create(**params)

        return self._extract_response(response)

    @override
    def chat_with_system_prompt(
        self,
        system_prompt: str,
        prompt_with_history: List[str],
        guided_decode_json_schema: Optional[str] = None,
        temperature: float = DEFAULT_TEMPERATURE,
        top_p: float = DEFAULT_TOP_P,
    ) -> str:
        messages = [{"role": "system", "content": system_prompt}]
        for i in range(len(prompt_with_history)):
            if i % 2 == 0:
                messages.append({"role": "user", "content": prompt_with_history[i]})
            else:
                messages.append(
                    {"role": "assistant", "content": prompt_with_history[i]}
                )

        level = logging.getLogger().level
        logging.getLogger().setLevel(logging.WARNING)

        params: Dict[str, Any] = {
            "model": self.model,
            "messages": messages,
            "max_tokens": DEFAULT_MAX_TOKENS,
            "temperature": temperature,
            "top_p": top_p,
        }

        if guided_decode_json_schema is not None:
            response_format = self._build_response_format(guided_decode_json_schema)
            if response_format is not None:
                params["response_format"] = response_format
        response = self.client.chat.completions.create(**params)

        logging.getLogger().setLevel(level)

        return self._extract_response(response)

    @override
    def query(
        self,
        prompt: str,
        guided_decode_json_schema: Optional[str] = None,
        temperature: float = DEFAULT_TEMPERATURE,
        top_p: float = DEFAULT_TOP_P,
    ) -> str:
        params: Dict[str, Any] = {
            "model": self.model,
            "messages": [
                {"role": "user", "content": prompt},
            ],
            "max_tokens": DEFAULT_MAX_TOKENS,
            "temperature": temperature,
            "top_p": top_p,
        }

        if guided_decode_json_schema is not None:
            response_format = self._build_response_format(guided_decode_json_schema)
            if response_format is not None:
                params["response_format"] = response_format
        response = self.client.chat.completions.create(**params)

        return self._extract_response(response)

    @override
    def query_with_system_prompt(
        self,
        system_prompt: str,
        prompt: str,
        guided_decode_json_schema: Optional[str] = None,
        temperature: float = DEFAULT_TEMPERATURE,
        top_p: float = DEFAULT_TOP_P,
    ) -> str:
        params: Dict[str, Any] = {
            "model": self.model,
            "messages": [
                {"role": "system", "content": system_prompt},
                {"role": "user", "content": prompt},
            ],
            "max_tokens": DEFAULT_MAX_TOKENS,
            "temperature": temperature,
            "top_p": top_p,
        }

        if guided_decode_json_schema is not None:
            response_format = self._build_response_format(guided_decode_json_schema)
            if response_format is not None:
                params["response_format"] = response_format
        response = self.client.chat.completions.create(**params)

        return self._extract_response(response)

    @override
    def query_multimodal(
        self,
        system_prompt: Optional[str] = None,
        text_prompt: Optional[str] = None,
        image_paths: Optional[List[str]] = None,
        audio_paths: Optional[List[str]] = None,
        max_tokens: int = DEFAULT_MAX_TOKENS,
        temperature: float = DEFAULT_TEMPERATURE,
        top_p: float = DEFAULT_TOP_P,
    ) -> str:
        if audio_paths and len(audio_paths) > 0:
            raise UnretriableQueryException("Audio inputs are not supported yet.")

        if text_prompt is None and (image_paths is None or len(image_paths) == 0):
            raise ValueError(
                "At least one of text_prompt or image_paths must be given."
            )

        # Build OpenAI-compatible message list
        messages = []
        if system_prompt:
            messages.append({"role": "system", "content": system_prompt})

        # Compose the user content as a list for multimodal prompts
        user_content: list[dict[str, str | dict[str, str]]] = []

        if text_prompt:
            user_content.append({"type": "text", "text": text_prompt})

        if image_paths and len(image_paths) > 0:
            # Llama models do not support more than 9 attachments
            if len(image_paths) > 9:
                LOG.warning(
                    f"Found {len(image_paths)} image_paths, but only using the first 9."
                )
                image_paths = image_paths[:9]

            for image_path in image_paths:
                image_data = image_to_b64(image_path)
                user_content.append(
                    {
                        "type": "image_url",
                        "image_url": {"url": f"data:image/png;base64,{image_data}"},
                    }
                )

        messages.append({"role": "user", "content": user_content})

        response = self.client.chat.completions.create(
            model=self.model,
            messages=messages,
            max_tokens=max_tokens,
            temperature=temperature,
            top_p=top_p,
        )

        return self._extract_response(response)

    @override
    def valid_models(self) -> list[str]:
        return [
            "Llama-4-Maverick-17B-128E-Instruct-FP8",
            "Llama-4-Scout-17B-16E-Instruct-FP8",
        ]
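For context, a minimal usage sketch of the new class; the import path for LLMConfig and its constructor fields are assumptions based on the imports above, not something shown in this diff:

from CybersecurityBenchmarks.benchmark.llms.meta import LLAMA
from CybersecurityBenchmarks.benchmark.llms.llm_base import LLMConfig  # path assumed

# Hypothetical LLMConfig fields; the real dataclass may take more arguments.
config = LLMConfig(model="Llama-4-Scout-17B-16E-Instruct-FP8", api_key="<API KEY>")
llm = LLAMA(config)

# Plain-text query against https://api.llama.com/v1
print(llm.query("Name one common mitigation for SQL injection."))

# Guided decoding: _build_response_format() wraps the schema string into Meta's
# {"type": "json_schema", "json_schema": {"name": "ResponseSchema", "schema": ...}} shape.
schema = '{"type": "object", "properties": {"answer": {"type": "string"}}}'
print(llm.query("Answer in JSON.", guided_decode_json_schema=schema))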
