|
| 1 | +from unittest.mock import patch |
| 2 | + |
| 3 | +from surge.api_resource import APIResource |
| 4 | +from surge.rubrics import Rubric |
| 5 | + |
| 6 | + |
| 7 | +def test_evaluate_with_all_params(): |
| 8 | + """Test evaluate method with all parameters provided""" |
| 9 | + with patch.object(Rubric, "post") as mock_post: |
| 10 | + mock_post.return_value = { |
| 11 | + "answer": True, |
| 12 | + "explanation": 'The text explicitly mentions two animals: "fox" and "dog." Therefore, it contains an animal, satisfying the rubric.', |
| 13 | + } |
| 14 | + |
| 15 | + result = Rubric.evaluate( |
| 16 | + text_for_grading="The quick brown fox jumps over the lazy dog", |
| 17 | + rubric_text="Check if the text contains an animal", |
| 18 | + prompt="Grade this text based on the rubric", |
| 19 | + api_key="test_key", |
| 20 | + ) |
| 21 | + |
| 22 | + mock_post.assert_called_once_with( |
| 23 | + "evaluate_rubric", |
| 24 | + { |
| 25 | + "text_for_grading": "The quick brown fox jumps over the lazy dog", |
| 26 | + "rubric_text": "Check if the text contains an animal", |
| 27 | + "prompt": "Grade this text based on the rubric", |
| 28 | + }, |
| 29 | + api_key="test_key", |
| 30 | + ) |
| 31 | + |
| 32 | + assert result["answer"] == True |
| 33 | + assert "fox" in result["explanation"] or "dog" in result["explanation"] |
| 34 | + |
| 35 | + |
| 36 | +def test_evaluate_without_prompt(): |
| 37 | + """Test evaluate method without optional prompt parameter""" |
| 38 | + with patch.object(Rubric, "post") as mock_post: |
| 39 | + mock_post.return_value = { |
| 40 | + "answer": False, |
| 41 | + "explanation": "The text does not contain any animals.", |
| 42 | + } |
| 43 | + |
| 44 | + result = Rubric.evaluate( |
| 45 | + text_for_grading="The quick brown car drives down the road", |
| 46 | + rubric_text="Check if the text contains an animal", |
| 47 | + ) |
| 48 | + |
| 49 | + mock_post.assert_called_once_with( |
| 50 | + "evaluate_rubric", |
| 51 | + { |
| 52 | + "text_for_grading": "The quick brown car drives down the road", |
| 53 | + "rubric_text": "Check if the text contains an animal", |
| 54 | + }, |
| 55 | + api_key=None, |
| 56 | + ) |
| 57 | + |
| 58 | + assert result["answer"] == False |
| 59 | + assert "explanation" in result |
| 60 | + |
| 61 | + |
| 62 | +def test_evaluate_returns_dict(): |
| 63 | + """Test that evaluate returns a dictionary with expected keys""" |
| 64 | + with patch.object(Rubric, "post") as mock_post: |
| 65 | + mock_post.return_value = {"answer": True, "explanation": "Test explanation"} |
| 66 | + |
| 67 | + result = Rubric.evaluate( |
| 68 | + text_for_grading="Sample text", rubric_text="Sample rubric" |
| 69 | + ) |
| 70 | + |
| 71 | + assert isinstance(result, dict) |
| 72 | + assert "answer" in result |
| 73 | + assert "explanation" in result |
| 74 | + assert isinstance(result["answer"], bool) |
| 75 | + assert isinstance(result["explanation"], str) |
| 76 | + |
| 77 | + |
| 78 | +def test_rubric_inherits_from_api_resource(): |
| 79 | + """Test that Rubric class inherits from APIResource""" |
| 80 | + assert issubclass(Rubric, APIResource) |
0 commit comments