Skip to content

Commit 82f0c3c

Browse files
hwittenborn and Copilot authored
Support model names with slashes on Gemini endpoints (#17743)
* Support model names with slashes on Gemini endpoints
* Fix test
* Update tests/proxy_unit_tests/test_google_endpoint_routing.py
  Co-authored-by: Copilot <175728472+Copilot@users.noreply.github.com>
* Update tests/proxy_unit_tests/test_google_endpoint_routing.py
  Co-authored-by: Copilot <175728472+Copilot@users.noreply.github.com>
* Update tests/proxy_unit_tests/test_google_endpoint_routing.py
  Co-authored-by: Copilot <175728472+Copilot@users.noreply.github.com>
* Update tests/proxy_unit_tests/test_google_endpoint_routing.py
  Co-authored-by: Copilot <175728472+Copilot@users.noreply.github.com>
* Update tests/proxy_unit_tests/test_google_endpoint_routing.py
  Co-authored-by: Copilot <175728472+Copilot@users.noreply.github.com>
* Update tests/proxy_unit_tests/test_google_endpoint_routing.py
  Co-authored-by: Copilot <175728472+Copilot@users.noreply.github.com>
---------
Co-authored-by: Copilot <175728472+Copilot@users.noreply.github.com>
1 parent 8b125ff commit 82f0c3c

File tree

2 files changed

+105
-6
lines changed

2 files changed

+105
-6
lines changed

litellm/proxy/google_endpoints/endpoints.py

Lines changed: 6 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -13,11 +13,11 @@
1313

1414

1515
@router.post(
16-
"/v1beta/models/{model_name}:generateContent",
16+
"/v1beta/models/{model_name:path}:generateContent",
1717
dependencies=[Depends(user_api_key_auth)],
1818
)
1919
@router.post(
20-
"/models/{model_name}:generateContent", dependencies=[Depends(user_api_key_auth)]
20+
"/models/{model_name:path}:generateContent", dependencies=[Depends(user_api_key_auth)]
2121
)
2222
async def google_generate_content(
2323
request: Request,
@@ -50,11 +50,11 @@ async def google_generate_content(
5050

5151

5252
@router.post(
53-
"/v1beta/models/{model_name}:streamGenerateContent",
53+
"/v1beta/models/{model_name:path}:streamGenerateContent",
5454
dependencies=[Depends(user_api_key_auth)],
5555
)
5656
@router.post(
57-
"/models/{model_name}:streamGenerateContent",
57+
"/models/{model_name:path}:streamGenerateContent",
5858
dependencies=[Depends(user_api_key_auth)],
5959
)
6060
async def google_stream_generate_content(
@@ -95,12 +95,12 @@ async def google_stream_generate_content(
9595

9696

9797
@router.post(
98-
"/v1beta/models/{model_name}:countTokens",
98+
"/v1beta/models/{model_name:path}:countTokens",
9999
dependencies=[Depends(user_api_key_auth)],
100100
response_model=TokenCountDetailsResponse,
101101
)
102102
@router.post(
103-
"/models/{model_name}:countTokens",
103+
"/models/{model_name:path}:countTokens",
104104
dependencies=[Depends(user_api_key_auth)],
105105
response_model=TokenCountDetailsResponse,
106106
)
tests/proxy_unit_tests/test_google_endpoint_routing.py

Lines changed: 99 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,99 @@
1+
2+
import json
3+
import os
4+
import sys
5+
from unittest.mock import AsyncMock, MagicMock, patch
6+
7+
import pytest
8+
import yaml
9+
10+
sys.path.insert(0, os.path.abspath("../.."))
11+
12+
from litellm.proxy._types import UserAPIKeyAuth
13+
from litellm.proxy.google_endpoints.endpoints import google_generate_content
14+
from fastapi import Request, Response
15+
from fastapi.datastructures import Headers
16+
from litellm.proxy.proxy_server import initialize
17+
from litellm.utils import ModelResponse
18+
19+
@pytest.fixture
def mock_user_api_key_dict():
    """Provide a ``UserAPIKeyAuth`` filled with fixed test credentials.

    The key carries an ``internal_user`` role, a 100.0 budget with zero
    spend, empty cache controls and metadata, and no TPM/RPM limits.
    """
    auth_fields = dict(
        api_key="test_api_key",
        user_id="test_user_id",
        user_email="test@example.com",
        team_id="test_team_id",
        max_budget=100.0,
        spend=0.0,
        user_role="internal_user",
        allowed_cache_controls=[],
        metadata={},
        tpm_limit=None,
        rpm_limit=None,
    )
    return UserAPIKeyAuth(**auth_fields)
35+
36+
37+
@pytest.fixture
def mock_request(request):
    """Build a ``MagicMock`` standing in for a FastAPI ``Request``.

    Intended for indirect parametrization: ``request.param`` is a dict whose
    ``"path"`` entry becomes ``url.path`` and whose optional ``"payload"``
    entry (default ``{}``) is returned, JSON-encoded as UTF-8 bytes, by the
    awaitable ``body``.
    """
    params = request.param

    fake_request = MagicMock(spec=Request)
    fake_request.method = "POST"
    fake_request.headers = Headers({"content-type": "application/json"})
    fake_request.url.path = params.get("path")

    async def read_body():
        # Serialized on every await rather than once at fixture setup.
        serialized = json.dumps(params.get("payload", {}))
        return serialized.encode('utf-8')

    fake_request.body = read_body
    return fake_request
50+
51+
52+
@pytest.fixture
def mock_response():
    """Provide a ``MagicMock`` constrained to the FastAPI ``Response`` spec."""
    fake_response = MagicMock(spec=Response)
    return fake_response
56+
57+
58+
@pytest.mark.asyncio
@pytest.mark.parametrize(
    "mock_request",
    [
        {
            "path": "/v1beta/models/bedrock/claude-sonnet-3.7:generateContent",
            "payload": {
                "contents": [
                    {
                        "parts": [
                            {"text": "The quick brown fox jumps over the lazy dog."}
                        ]
                    }
                ]
            },
        }
    ],
    indirect=True,
)
async def test_google_generate_content_with_slashes_in_model_name(
    mock_request, mock_response, mock_user_api_key_dict
):
    """
    Test that the google_generate_content endpoint correctly handles model names with slashes.

    The proxy is initialized from a throwaway config exposing a single model
    named ``bedrock/claude-sonnet-3.7``; the router's ``agenerate_content`` is
    patched, and the test asserts it is called exactly once with that model
    name unchanged.

    NOTE(review): the handler is invoked directly with ``model_name`` already
    supplied, so this does not exercise URL matching of the
    ``{model_name:path}`` route itself -- consider a client-level test for that.
    """
    # Local import keeps this block self-contained; the module's import
    # section does not bring in ``tempfile``.
    import tempfile

    config = {
        "model_list": [
            {
                "model_name": "bedrock/claude-sonnet-3.7",
                "litellm_params": {
                    "model": "bedrock/anthropic.claude-3-sonnet-20240229-v1:0",
                },
            }
        ]
    }

    # Write the config into a private temporary directory instead of next to
    # this test file: a fixed ``test_config.yaml`` in the source tree would
    # overwrite (and then delete) any existing file of that name and would
    # collide when tests run in parallel. The directory and its contents are
    # removed on exit, even if the test fails.
    with tempfile.TemporaryDirectory() as tmp_dir:
        config_fp = os.path.join(tmp_dir, "test_config.yaml")
        with open(config_fp, "w", encoding="utf-8") as f:
            yaml.dump(config, f)

        await initialize(config=config_fp)

        with patch(
            "litellm.proxy.proxy_server.llm_router.agenerate_content",
            new_callable=AsyncMock,
        ) as mock_agenerate_content:
            mock_agenerate_content.return_value = ModelResponse()

            await google_generate_content(
                request=mock_request,
                model_name="bedrock/claude-sonnet-3.7",
                fastapi_response=mock_response,
                user_api_key_dict=mock_user_api_key_dict,
            )

            mock_agenerate_content.assert_called_once()
            _, call_kwargs = mock_agenerate_content.call_args
            assert call_kwargs["model"] == "bedrock/claude-sonnet-3.7"

0 commit comments

Comments
 (0)