Skip to content
Open
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
17 changes: 15 additions & 2 deletions aworld/models/model_response.py
Original file line number Diff line number Diff line change
Expand Up @@ -35,7 +35,7 @@ class Function(BaseModel):
Represents a function call made by a model
"""
name: str
arguments: str = None
arguments: Optional[str] = None


class ToolCall(BaseModel):
Expand All @@ -46,6 +46,7 @@ class ToolCall(BaseModel):
id: str
type: str = "function"
function: Function = None
extra_content: Optional[dict] = None

# name: str = None
# arguments: str = None
Expand Down Expand Up @@ -82,11 +83,18 @@ def from_dict(cls, data: Dict[str, Any]) -> 'ToolCall':
arguments = json.dumps(arguments, ensure_ascii=False)

function = Function(name=name, arguments=arguments)
if 'model_extra' in data and 'extra_content' in data['model_extra']:
extra_content = data['model_extra']['extra_content']
else:
extra_content = None
if 'extra_content' in data:
extra_content = data['extra_content']

return cls(
id=tool_id,
type=tool_type,
function=function,
extra_content=extra_content
# name=name,
# arguments=arguments,
)
Expand All @@ -104,7 +112,8 @@ def to_dict(self) -> Dict[str, Any]:
"function": {
"name": self.function.name,
"arguments": self.function.arguments
}
},
"extra_content": self.extra_content
}

def __repr__(self):
Expand Down Expand Up @@ -296,6 +305,10 @@ def from_openai_response(cls, response: Any) -> 'ModelResponse':
"name": function.name if hasattr(function, 'name') else None,
"arguments": function.arguments if hasattr(function, 'arguments') else None
}
if hasattr(tool_call, 'model_extra'):
model_extra = tool_call.model_extra
if model_extra:
tool_call_dict["model_extra"] = model_extra
processed_tool_calls.append(ToolCall.from_dict(tool_call_dict))

if message_dict and processed_tool_calls:
Expand Down