# DeepCritical / tests/unit/utils/test_hf_model_validator.py
# Joseph Pollack
# attempts fix 403 and settings
# 8fa2ce6 unverified
"""Unit tests for HuggingFace model and provider validator."""
from unittest.mock import AsyncMock, MagicMock, patch
import pytest
from src.utils.hf_model_validator import (
extract_oauth_token,
get_available_models,
get_available_providers,
get_models_for_provider,
validate_model_provider_combination,
validate_oauth_token,
)
class TestExtractOAuthToken:
    """Covers the input shapes handled by ``extract_oauth_token``."""

    def test_extract_from_oauth_token_object(self) -> None:
        """An object exposing ``.token`` should yield that attribute's value."""
        expected = "hf_test_token_123"
        token_holder = MagicMock()
        token_holder.token = expected

        assert extract_oauth_token(token_holder) == expected

    def test_extract_from_string(self) -> None:
        """A plain string token should come back unchanged."""
        raw_token = "hf_test_token_123"

        assert extract_oauth_token(raw_token) == raw_token

    def test_extract_none(self) -> None:
        """``None`` in should give ``None`` out."""
        assert extract_oauth_token(None) is None

    def test_extract_invalid_object(self) -> None:
        """An object lacking ``.token`` should yield ``None`` and log one warning."""
        tokenless = MagicMock()
        # MagicMock fabricates attributes on demand; deleting one makes any
        # later access to it raise AttributeError.
        del tokenless.token

        with patch("src.utils.hf_model_validator.logger") as fake_logger:
            extracted = extract_oauth_token(tokenless)

        assert extracted is None
        fake_logger.warning.assert_called_once()
class TestGetAvailableProviders:
    """Exercises ``get_available_providers`` with ``HfApi`` and ``settings`` mocked out."""

    _API_TARGET = "src.utils.hf_model_validator.HfApi"
    _SETTINGS_TARGET = "src.utils.hf_model_validator.settings"
    # Kept as a tuple so each test hands the mock a fresh list.
    _FALLBACK_MODELS = ("meta-llama/Llama-3.1-8B-Instruct",)

    async def _providers_when_lookup_fails(self, token):
        """Return the providers reported while every ``model_info`` call raises.

        ``token`` is forwarded unchanged to ``get_available_providers``.
        """
        with patch(self._API_TARGET) as api_cls, patch(self._SETTINGS_TARGET) as fake_settings:
            broken_hub = MagicMock()
            broken_hub.model_info.side_effect = Exception("API error")
            api_cls.return_value = broken_hub
            fake_settings.get_hf_fallback_models_list.return_value = list(self._FALLBACK_MODELS)
            return await get_available_providers(token=token)

    @pytest.mark.asyncio
    async def test_get_providers_with_cache(self) -> None:
        """Successful discovery should yield "auto" plus at least one more provider.

        NOTE(review): only one call is made, so despite the test name no
        cache hit is exercised here.
        """
        with patch(self._API_TARGET) as api_cls, patch(self._SETTINGS_TARGET) as fake_settings:
            # model_info reports two providers for whichever model is queried.
            model_details = MagicMock()
            model_details.inference_provider_mapping = {
                provider: MagicMock() for provider in ("hf-inference", "nebius")
            }
            hub = MagicMock()
            hub.model_info.return_value = model_details
            api_cls.return_value = hub
            fake_settings.get_hf_fallback_models_list.return_value = list(self._FALLBACK_MODELS)

            found = await get_available_providers(token="test_token")

        assert "auto" in found
        assert len(found) > 1

    @pytest.mark.asyncio
    async def test_get_providers_fallback_to_known(self) -> None:
        """A failing API should still produce a non-empty list containing "auto"."""
        found = await self._providers_when_lookup_fails("test_token")

        assert "auto" in found
        assert len(found) > 0

    @pytest.mark.asyncio
    async def test_get_providers_no_token(self) -> None:
        """Passing ``token=None`` should still return a list containing "auto"."""
        found = await self._providers_when_lookup_fails(None)

        assert "auto" in found
class TestGetAvailableModels:
    """Tests for get_available_models function."""

    @pytest.mark.asyncio
    async def test_get_models_with_token(self) -> None:
        """Should return the ids of the models listed by the API."""
        with patch("src.utils.hf_model_validator.HfApi") as mock_api_class:
            mock_api = MagicMock()
            mock_api_class.return_value = mock_api

            # list_models yields objects; only their ``id`` attribute is set here.
            mock_model1 = MagicMock()
            mock_model1.id = "model1"
            mock_model2 = MagicMock()
            mock_model2.id = "model2"
            mock_api.list_models.return_value = [mock_model1, mock_model2]

            models = await get_available_models(token="test_token", limit=10)

            assert len(models) == 2
            assert "model1" in models
            assert "model2" in models
            mock_api.list_models.assert_called_once()

    @pytest.mark.asyncio
    async def test_get_models_with_provider_filter(self) -> None:
        """Should forward ``inference_provider`` to ``HfApi.list_models``."""
        with patch("src.utils.hf_model_validator.HfApi") as mock_api_class:
            mock_api = MagicMock()
            mock_api_class.return_value = mock_api

            mock_model = MagicMock()
            mock_model.id = "model1"
            mock_api.list_models.return_value = [mock_model]

            # The returned list is irrelevant here; only the outgoing call is checked.
            await get_available_models(
                token="test_token",
                inference_provider="nebius",
                limit=10,
            )

            # Guard first: if no call happened, ``call_args`` is None and reading
            # it would fail with an unhelpful TypeError instead of a clear message.
            mock_api.list_models.assert_called()
            call_kwargs = mock_api.list_models.call_args.kwargs
            assert call_kwargs.get("inference_provider") == "nebius"

    @pytest.mark.asyncio
    async def test_get_models_fallback_on_error(self) -> None:
        """Should return fallback models when ``list_models`` raises.

        NOTE(review): settings are not patched in this test, so the expected
        model id presumably comes from the project's real fallback
        configuration - confirm it cannot be overridden by the environment.
        """
        with patch("src.utils.hf_model_validator.HfApi") as mock_api_class:
            mock_api = MagicMock()
            mock_api_class.return_value = mock_api
            mock_api.list_models.side_effect = Exception("API error")

            models = await get_available_models(token="test_token", limit=10)

            # Should return fallback models
            assert len(models) > 0
            assert "meta-llama/Llama-3.1-8B-Instruct" in models
class TestValidateModelProviderCombination:
    """Exercises ``validate_model_provider_combination`` against a mocked ``HfApi``."""

    _API_TARGET = "src.utils.hf_model_validator.HfApi"

    @staticmethod
    def _hub_offering(*provider_names):
        """Build an ``HfApi`` stand-in whose ``model_info`` maps the given providers."""
        details = MagicMock()
        details.inference_provider_mapping = {name: MagicMock() for name in provider_names}
        hub = MagicMock()
        hub.model_info.return_value = details
        return hub

    @pytest.mark.asyncio
    async def test_validate_auto_provider(self) -> None:
        """The "auto" provider should be accepted with no error message."""
        ok, reason = await validate_model_provider_combination(
            model_id="test-model",
            provider="auto",
            token="test_token",
        )

        assert ok is True
        assert reason is None

    @pytest.mark.asyncio
    async def test_validate_none_provider(self) -> None:
        """A ``None`` provider should be accepted with no error message."""
        ok, reason = await validate_model_provider_combination(
            model_id="test-model",
            provider=None,
            token="test_token",
        )

        assert ok is True
        assert reason is None

    @pytest.mark.asyncio
    async def test_validate_valid_combination(self) -> None:
        """A provider present in the model's mapping should be accepted."""
        with patch(self._API_TARGET) as api_cls:
            api_cls.return_value = self._hub_offering("nebius", "hf-inference")

            ok, reason = await validate_model_provider_combination(
                model_id="test-model",
                provider="nebius",
                token="test_token",
            )

        assert ok is True
        assert reason is None

    @pytest.mark.asyncio
    async def test_validate_invalid_combination(self) -> None:
        """A provider absent from the mapping should be rejected and named in the error."""
        with patch(self._API_TARGET) as api_cls:
            # The mapping deliberately omits the provider requested below.
            api_cls.return_value = self._hub_offering("hf-inference")

            ok, reason = await validate_model_provider_combination(
                model_id="test-model",
                provider="nebius",
                token="test_token",
            )

        assert ok is False
        assert reason is not None
        assert "nebius" in reason

    @pytest.mark.asyncio
    async def test_validate_fireworks_variants(self) -> None:
        """Requesting "fireworks" should pass when the mapping lists "fireworks-ai"."""
        with patch(self._API_TARGET) as api_cls:
            api_cls.return_value = self._hub_offering("fireworks-ai")

            ok, reason = await validate_model_provider_combination(
                model_id="test-model",
                provider="fireworks",
                token="test_token",
            )

        assert ok is True
        assert reason is None

    @pytest.mark.asyncio
    async def test_validate_graceful_on_error(self) -> None:
        """A failing ``model_info`` lookup should still report the combination as valid."""
        with patch(self._API_TARGET) as api_cls:
            failing_hub = MagicMock()
            failing_hub.model_info.side_effect = Exception("API error")
            api_cls.return_value = failing_hub

            ok, _ = await validate_model_provider_combination(
                model_id="test-model",
                provider="nebius",
                token="test_token",
            )

        # Validity is left for the real request to decide, so True is expected.
        assert ok is True
class TestGetModelsForProvider:
    """Tests for get_models_for_provider function."""

    @pytest.mark.asyncio
    async def test_get_models_for_provider(self) -> None:
        """Should delegate to get_available_models with the given provider."""
        with patch("src.utils.hf_model_validator.get_available_models") as mock_get_models:
            mock_get_models.return_value = ["model1", "model2"]

            models = await get_models_for_provider(
                provider="nebius",
                token="test_token",
                limit=10,
            )

            assert len(models) == 2
            mock_get_models.assert_called_once_with(
                token="test_token",
                task="text-generation",
                limit=10,
                inference_provider="nebius",
            )

    @pytest.mark.asyncio
    async def test_get_models_normalize_fireworks(self) -> None:
        """Should normalize 'fireworks' to 'fireworks-ai'."""
        with patch("src.utils.hf_model_validator.get_available_models") as mock_get_models:
            mock_get_models.return_value = ["model1"]

            # The returned list is irrelevant here; only the delegated call is checked.
            await get_models_for_provider(
                provider="fireworks",
                token="test_token",
            )

            # Guard first: if no call happened, ``call_args`` is None and reading
            # it would fail with an unhelpful TypeError instead of a clear message.
            mock_get_models.assert_called()
            # Should call with "fireworks-ai" not "fireworks"
            call_kwargs = mock_get_models.call_args.kwargs
            assert call_kwargs["inference_provider"] == "fireworks-ai"
class TestValidateOAuthToken:
    """Exercises ``validate_oauth_token`` across missing, malformed, valid and rejected tokens."""

    _API_TARGET = "src.utils.hf_model_validator.HfApi"
    _MODELS_TARGET = "src.utils.hf_model_validator.get_available_models"
    _PROVIDERS_TARGET = "src.utils.hf_model_validator.get_available_providers"

    @pytest.mark.asyncio
    async def test_validate_none_token(self) -> None:
        """A ``None`` token should be reported invalid with "No token provided"."""
        report = await validate_oauth_token(None)

        assert report["is_valid"] is False
        assert report["error"] == "No token provided"

    @pytest.mark.asyncio
    async def test_validate_invalid_format(self) -> None:
        """The string "short" should be rejected with a format error."""
        report = await validate_oauth_token("short")

        assert report["is_valid"] is False
        assert "Invalid token format" in report["error"]

    @pytest.mark.asyncio
    async def test_validate_valid_token(self) -> None:
        """An accepted token should report the username, scope and discovered resources."""
        with patch(self._API_TARGET) as api_cls:
            hub = MagicMock()
            # whoami succeeding is what marks the token as authenticated in this test.
            hub.whoami.return_value = {"name": "testuser", "fullname": "Test User"}
            api_cls.return_value = hub

            with patch(self._MODELS_TARGET) as fake_models:
                with patch(self._PROVIDERS_TARGET) as fake_providers:
                    fake_models.return_value = ["model1", "model2"]
                    fake_providers.return_value = ["auto", "nebius"]

                    report = await validate_oauth_token("hf_valid_token_123")

        assert report["is_valid"] is True
        assert report["username"] == "testuser"
        assert report["has_inference_api_scope"] is True
        assert len(report["available_models"]) == 2
        assert len(report["available_providers"]) == 2

    @pytest.mark.asyncio
    async def test_validate_token_without_scope(self) -> None:
        """A token that authenticates but cannot list models should be flagged as lacking scope."""
        with patch(self._API_TARGET) as api_cls:
            hub = MagicMock()
            hub.whoami.return_value = {"name": "testuser"}
            api_cls.return_value = hub

            with patch(self._MODELS_TARGET) as fake_models:
                with patch(self._PROVIDERS_TARGET) as fake_providers:
                    # Model listing fails, standing in for a missing inference-api scope.
                    fake_models.side_effect = Exception("403 Forbidden")
                    fake_providers.return_value = ["auto"]

                    report = await validate_oauth_token("hf_token_without_scope")

        # The token itself authenticates; only the scope is missing.
        assert report["is_valid"] is True
        assert report["has_inference_api_scope"] is False
        assert "inference-api scope" in report["error"]

    @pytest.mark.asyncio
    async def test_validate_invalid_token(self) -> None:
        """A token for which ``whoami`` raises should be reported invalid."""
        with patch(self._API_TARGET) as api_cls:
            rejecting_hub = MagicMock()
            rejecting_hub.whoami.side_effect = Exception("401 Unauthorized")
            api_cls.return_value = rejecting_hub

            report = await validate_oauth_token("hf_invalid_token")

        assert report["is_valid"] is False
        assert "could not authenticate" in report["error"]