Skip to content

Commit c7e63fc

Browse files
authored
add Rubric and Rubric.evaluate (#155)
1 parent d030e3a commit c7e63fc

File tree

4 files changed

+119
-1
lines changed

4 files changed

+119
-1
lines changed

setup.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -1,6 +1,6 @@
11
from setuptools import setup, find_packages
22

3-
VERSION = "1.5.17"
3+
VERSION = "1.5.18"
44

55
with open("requirements.txt") as f:
66
requirements = f.read().splitlines()

surge/__init__.py

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -4,6 +4,7 @@
44
from surge.tasks import Task
55
from surge.teams import Team
66
from surge.reports import Report
7+
from surge.rubrics import Rubric
78

89
api_key = os.environ.get("SURGE_API_KEY", None)
910
base_url = os.environ.get("SURGE_BASE_URL", "https://app.surgehq.ai/api")

surge/rubrics.py

Lines changed: 37 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,37 @@
1+
from surge.api_resource import APIResource
2+
3+
4+
class Rubric(APIResource):
5+
6+
@classmethod
7+
def evaluate(
8+
cls,
9+
text_for_grading: str,
10+
rubric_text: str,
11+
prompt: str = None,
12+
api_key: str = None,
13+
):
14+
"""
15+
Evaluate text against a rubric using AI grading.
16+
17+
Arguments:
18+
text_for_grading (str): The text content to be graded.
19+
rubric_text (str): The rubric or criteria to evaluate against.
20+
prompt (str, optional): Additional instructions for how to grade the text.
21+
api_key (str, optional): API key to use for this request.
22+
23+
Returns:
24+
dict: A dictionary containing:
25+
- answer (bool): Whether the text meets the rubric criteria.
26+
- explanation (str): An explanation of the grading decision.
27+
"""
28+
endpoint = "evaluate_rubric"
29+
params = {
30+
"text_for_grading": text_for_grading,
31+
"rubric_text": rubric_text,
32+
}
33+
if prompt is not None:
34+
params["prompt"] = prompt
35+
36+
response_json = cls.post(endpoint, params, api_key=api_key)
37+
return response_json

tests/test_rubrics.py

Lines changed: 80 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,80 @@
1+
from unittest.mock import patch
2+
3+
from surge.api_resource import APIResource
4+
from surge.rubrics import Rubric
5+
6+
7+
def test_evaluate_with_all_params():
8+
"""Test evaluate method with all parameters provided"""
9+
with patch.object(Rubric, "post") as mock_post:
10+
mock_post.return_value = {
11+
"answer": True,
12+
"explanation": 'The text explicitly mentions two animals: "fox" and "dog." Therefore, it contains an animal, satisfying the rubric.',
13+
}
14+
15+
result = Rubric.evaluate(
16+
text_for_grading="The quick brown fox jumps over the lazy dog",
17+
rubric_text="Check if the text contains an animal",
18+
prompt="Grade this text based on the rubric",
19+
api_key="test_key",
20+
)
21+
22+
mock_post.assert_called_once_with(
23+
"evaluate_rubric",
24+
{
25+
"text_for_grading": "The quick brown fox jumps over the lazy dog",
26+
"rubric_text": "Check if the text contains an animal",
27+
"prompt": "Grade this text based on the rubric",
28+
},
29+
api_key="test_key",
30+
)
31+
32+
assert result["answer"] == True
33+
assert "fox" in result["explanation"] or "dog" in result["explanation"]
34+
35+
36+
def test_evaluate_without_prompt():
37+
"""Test evaluate method without optional prompt parameter"""
38+
with patch.object(Rubric, "post") as mock_post:
39+
mock_post.return_value = {
40+
"answer": False,
41+
"explanation": "The text does not contain any animals.",
42+
}
43+
44+
result = Rubric.evaluate(
45+
text_for_grading="The quick brown car drives down the road",
46+
rubric_text="Check if the text contains an animal",
47+
)
48+
49+
mock_post.assert_called_once_with(
50+
"evaluate_rubric",
51+
{
52+
"text_for_grading": "The quick brown car drives down the road",
53+
"rubric_text": "Check if the text contains an animal",
54+
},
55+
api_key=None,
56+
)
57+
58+
assert result["answer"] == False
59+
assert "explanation" in result
60+
61+
62+
def test_evaluate_returns_dict():
63+
"""Test that evaluate returns a dictionary with expected keys"""
64+
with patch.object(Rubric, "post") as mock_post:
65+
mock_post.return_value = {"answer": True, "explanation": "Test explanation"}
66+
67+
result = Rubric.evaluate(
68+
text_for_grading="Sample text", rubric_text="Sample rubric"
69+
)
70+
71+
assert isinstance(result, dict)
72+
assert "answer" in result
73+
assert "explanation" in result
74+
assert isinstance(result["answer"], bool)
75+
assert isinstance(result["explanation"], str)
76+
77+
78+
def test_rubric_inherits_from_api_resource():
79+
"""Test that Rubric class inherits from APIResource"""
80+
assert issubclass(Rubric, APIResource)

0 commit comments

Comments
 (0)