Spaces:
Running
Running
| """Unit tests for HuggingFace model and provider validator.""" | |
| from unittest.mock import AsyncMock, MagicMock, patch | |
| import pytest | |
| from src.utils.hf_model_validator import ( | |
| extract_oauth_token, | |
| get_available_models, | |
| get_available_providers, | |
| get_models_for_provider, | |
| validate_model_provider_combination, | |
| validate_oauth_token, | |
| ) | |
class TestExtractOAuthToken:
    """Behavior of ``extract_oauth_token`` for each input shape it accepts."""

    def test_extract_from_oauth_token_object(self) -> None:
        """An object exposing a ``.token`` attribute yields that attribute's value."""
        expected = "hf_test_token_123"
        token_holder = MagicMock()
        token_holder.token = expected

        assert extract_oauth_token(token_holder) == expected

    def test_extract_from_string(self) -> None:
        """A plain string token is handed back unchanged."""
        raw = "hf_test_token_123"

        assert extract_oauth_token(raw) == raw

    def test_extract_none(self) -> None:
        """``None`` in gives ``None`` out."""
        assert extract_oauth_token(None) is None

    def test_extract_invalid_object(self) -> None:
        """An object without ``.token`` yields ``None`` and logs exactly one warning."""
        tokenless = MagicMock()
        # MagicMock invents attributes on access; deleting ``token`` makes
        # any later ``tokenless.token`` lookup raise AttributeError.
        del tokenless.token

        with patch("src.utils.hf_model_validator.logger") as fake_logger:
            outcome = extract_oauth_token(tokenless)

        assert outcome is None
        fake_logger.warning.assert_called_once()
class TestGetAvailableProviders:
    """Tests for get_available_providers function."""

    async def test_get_providers_with_cache(self) -> None:
        """Should return cached providers if available.

        Calls ``get_available_providers`` twice with identical inputs and
        checks that the second call yields the same providers as the first.
        """
        with patch("src.utils.hf_model_validator.HfApi") as mock_api_class:
            mock_api = MagicMock()
            mock_api_class.return_value = mock_api
            # Provider discovery reads the keys of model_info's provider mapping.
            mock_model_info = MagicMock()
            mock_model_info.inference_provider_mapping = {
                "hf-inference": MagicMock(),
                "nebius": MagicMock(),
            }
            mock_api.model_info.return_value = mock_model_info

            with patch("src.utils.hf_model_validator.settings") as mock_settings:
                mock_settings.get_hf_fallback_models_list.return_value = [
                    "meta-llama/Llama-3.1-8B-Instruct"
                ]
                # First call may populate a cache; the second should be served
                # from it (or rediscovered from the same mocks).
                first = await get_available_providers(token="test_token")
                second = await get_available_providers(token="test_token")

        assert "auto" in first
        assert len(first) > 1
        # Order-insensitive: the contract is "same providers", not "same list".
        # NOTE(review): API call counts are not asserted because cache
        # keying/TTL are not visible from this file.
        assert sorted(second) == sorted(first)

    async def test_get_providers_fallback_to_known(self) -> None:
        """Should fall back to known providers if discovery fails."""
        # NOTE(review): if get_available_providers keeps a module-level cache,
        # a result cached by an earlier test could satisfy these assertions
        # without exercising the fallback path; confirm cache isolation.
        with patch("src.utils.hf_model_validator.HfApi") as mock_api_class:
            mock_api = MagicMock()
            mock_api_class.return_value = mock_api
            mock_api.model_info.side_effect = Exception("API error")

            with patch("src.utils.hf_model_validator.settings") as mock_settings:
                mock_settings.get_hf_fallback_models_list.return_value = [
                    "meta-llama/Llama-3.1-8B-Instruct"
                ]
                providers = await get_available_providers(token="test_token")

        # Should return known providers as fallback
        assert "auto" in providers
        assert len(providers) > 0

    async def test_get_providers_no_token(self) -> None:
        """Should work without token."""
        with patch("src.utils.hf_model_validator.HfApi") as mock_api_class:
            mock_api = MagicMock()
            mock_api_class.return_value = mock_api
            mock_api.model_info.side_effect = Exception("API error")

            with patch("src.utils.hf_model_validator.settings") as mock_settings:
                mock_settings.get_hf_fallback_models_list.return_value = [
                    "meta-llama/Llama-3.1-8B-Instruct"
                ]
                providers = await get_available_providers(token=None)

        # Should return known providers as fallback
        assert "auto" in providers
class TestGetAvailableModels:
    """Tests for get_available_models function."""

    async def test_get_models_with_token(self) -> None:
        """Should fetch models with token."""
        with patch("src.utils.hf_model_validator.HfApi") as mock_api_class:
            mock_api = MagicMock()
            mock_api_class.return_value = mock_api
            # list_models yields model objects; only ``.id`` matters here.
            mock_model1 = MagicMock()
            mock_model1.id = "model1"
            mock_model2 = MagicMock()
            mock_model2.id = "model2"
            mock_api.list_models.return_value = [mock_model1, mock_model2]

            models = await get_available_models(token="test_token", limit=10)

        assert len(models) == 2
        assert "model1" in models
        assert "model2" in models
        mock_api.list_models.assert_called_once()

    async def test_get_models_with_provider_filter(self) -> None:
        """Should pass the provider filter to list_models and return its models."""
        with patch("src.utils.hf_model_validator.HfApi") as mock_api_class:
            mock_api = MagicMock()
            mock_api_class.return_value = mock_api
            mock_model = MagicMock()
            mock_model.id = "model1"
            mock_api.list_models.return_value = [mock_model]

            models = await get_available_models(
                token="test_token",
                inference_provider="nebius",
                limit=10,
            )

        # The filtered listing must actually surface in the result.
        assert "model1" in models
        # The provider must be forwarded to list_models as a keyword argument.
        call_kwargs = mock_api.list_models.call_args.kwargs
        assert call_kwargs.get("inference_provider") == "nebius"

    async def test_get_models_fallback_on_error(self) -> None:
        """Should return fallback models on error."""
        with patch("src.utils.hf_model_validator.HfApi") as mock_api_class:
            mock_api = MagicMock()
            mock_api_class.return_value = mock_api
            mock_api.list_models.side_effect = Exception("API error")

            models = await get_available_models(token="test_token", limit=10)

        # Should return fallback models (settings is not patched here, so this
        # presumably relies on the real configured fallback list).
        assert len(models) > 0
        assert "meta-llama/Llama-3.1-8B-Instruct" in models
class TestValidateModelProviderCombination:
    """Which model/provider pairings validate_model_provider_combination accepts."""

    @staticmethod
    def _stub_provider_mapping(mock_api_class: MagicMock, *provider_names: str) -> None:
        """Make ``HfApi().model_info(...)`` report exactly ``provider_names`` as providers."""
        api = MagicMock()
        mock_api_class.return_value = api
        info = MagicMock()
        info.inference_provider_mapping = {name: MagicMock() for name in provider_names}
        api.model_info.return_value = info

    async def test_validate_auto_provider(self) -> None:
        """The 'auto' provider must always be accepted."""
        is_valid, error_msg = await validate_model_provider_combination(
            model_id="test-model", provider="auto", token="test_token"
        )

        assert is_valid is True
        assert error_msg is None

    async def test_validate_none_provider(self) -> None:
        """A ``None`` provider is accepted just like 'auto'."""
        is_valid, error_msg = await validate_model_provider_combination(
            model_id="test-model", provider=None, token="test_token"
        )

        assert is_valid is True
        assert error_msg is None

    async def test_validate_valid_combination(self) -> None:
        """A provider present in the model's mapping is accepted."""
        with patch("src.utils.hf_model_validator.HfApi") as api_cls:
            self._stub_provider_mapping(api_cls, "nebius", "hf-inference")
            is_valid, error_msg = await validate_model_provider_combination(
                model_id="test-model", provider="nebius", token="test_token"
            )

        assert is_valid is True
        assert error_msg is None

    async def test_validate_invalid_combination(self) -> None:
        """A provider absent from the mapping is rejected with a message naming it."""
        with patch("src.utils.hf_model_validator.HfApi") as api_cls:
            # Mapping deliberately omits the requested provider.
            self._stub_provider_mapping(api_cls, "hf-inference")
            is_valid, error_msg = await validate_model_provider_combination(
                model_id="test-model", provider="nebius", token="test_token"
            )

        assert is_valid is False
        assert error_msg is not None
        assert "nebius" in error_msg

    async def test_validate_fireworks_variants(self) -> None:
        """'fireworks' is accepted when the mapping lists 'fireworks-ai'."""
        with patch("src.utils.hf_model_validator.HfApi") as api_cls:
            self._stub_provider_mapping(api_cls, "fireworks-ai")
            is_valid, error_msg = await validate_model_provider_combination(
                model_id="test-model", provider="fireworks", token="test_token"
            )

        assert is_valid is True
        assert error_msg is None

    async def test_validate_graceful_on_error(self) -> None:
        """An API failure degrades gracefully: the pairing is reported as valid."""
        with patch("src.utils.hf_model_validator.HfApi") as api_cls:
            api = MagicMock()
            api_cls.return_value = api
            api.model_info.side_effect = Exception("API error")
            is_valid, _error_msg = await validate_model_provider_combination(
                model_id="test-model", provider="nebius", token="test_token"
            )

        # Let the real inference request decide rather than rejecting up front.
        assert is_valid is True
class TestGetModelsForProvider:
    """Tests for get_models_for_provider function."""

    async def test_get_models_for_provider(self) -> None:
        """Should delegate to get_available_models with the provider as filter."""
        # get_available_models is awaited, so patch it with an AsyncMock
        # explicitly instead of relying on patch's async auto-detection.
        with patch(
            "src.utils.hf_model_validator.get_available_models",
            new_callable=AsyncMock,
        ) as mock_get_models:
            mock_get_models.return_value = ["model1", "model2"]
            models = await get_models_for_provider(
                provider="nebius",
                token="test_token",
                limit=10,
            )

        assert len(models) == 2
        # assert_awaited_* also proves the coroutine was awaited, not just created.
        mock_get_models.assert_awaited_once_with(
            token="test_token",
            task="text-generation",
            limit=10,
            inference_provider="nebius",
        )

    async def test_get_models_normalize_fireworks(self) -> None:
        """Should normalize 'fireworks' to 'fireworks-ai'."""
        with patch(
            "src.utils.hf_model_validator.get_available_models",
            new_callable=AsyncMock,
        ) as mock_get_models:
            mock_get_models.return_value = ["model1"]
            models = await get_models_for_provider(
                provider="fireworks",
                token="test_token",
            )

        assert list(models) == ["model1"]
        # Should call with "fireworks-ai" not "fireworks"
        call_kwargs = mock_get_models.call_args.kwargs
        assert call_kwargs["inference_provider"] == "fireworks-ai"
class TestValidateOAuthToken:
    """Tests for validate_oauth_token function."""

    async def test_validate_none_token(self) -> None:
        """Should return invalid for None token."""
        result = await validate_oauth_token(None)

        assert result["is_valid"] is False
        assert result["error"] == "No token provided"

    async def test_validate_invalid_format(self) -> None:
        """Should return invalid for malformed token."""
        result = await validate_oauth_token("short")

        assert result["is_valid"] is False
        assert "Invalid token format" in result["error"]

    async def test_validate_valid_token(self) -> None:
        """Should validate valid token and return resources."""
        with patch("src.utils.hf_model_validator.HfApi") as mock_api_class:
            mock_api = MagicMock()
            mock_api_class.return_value = mock_api
            # whoami supplies the user identity for the token.
            mock_api.whoami.return_value = {"name": "testuser", "fullname": "Test User"}

            # The resource helpers are awaited, so patch them as AsyncMock
            # explicitly; kwargs are forwarded to the AsyncMock constructor.
            with patch(
                "src.utils.hf_model_validator.get_available_models",
                new_callable=AsyncMock,
                return_value=["model1", "model2"],
            ), patch(
                "src.utils.hf_model_validator.get_available_providers",
                new_callable=AsyncMock,
                return_value=["auto", "nebius"],
            ):
                result = await validate_oauth_token("hf_valid_token_123")

        assert result["is_valid"] is True
        assert result["username"] == "testuser"
        assert result["has_inference_api_scope"] is True
        assert len(result["available_models"]) == 2
        assert len(result["available_providers"]) == 2

    async def test_validate_token_without_scope(self) -> None:
        """Should detect missing inference-api scope."""
        with patch("src.utils.hf_model_validator.HfApi") as mock_api_class:
            mock_api = MagicMock()
            mock_api_class.return_value = mock_api
            mock_api.whoami.return_value = {"name": "testuser"}

            # Model listing fails as it would for a token lacking the scope.
            with patch(
                "src.utils.hf_model_validator.get_available_models",
                new_callable=AsyncMock,
                side_effect=Exception("403 Forbidden"),
            ), patch(
                "src.utils.hf_model_validator.get_available_providers",
                new_callable=AsyncMock,
                return_value=["auto"],
            ):
                result = await validate_oauth_token("hf_token_without_scope")

        assert result["is_valid"] is True  # Token is valid
        assert result["has_inference_api_scope"] is False  # But no scope
        assert "inference-api scope" in result["error"]

    async def test_validate_invalid_token(self) -> None:
        """Should return invalid for token that fails authentication."""
        with patch("src.utils.hf_model_validator.HfApi") as mock_api_class:
            mock_api = MagicMock()
            mock_api_class.return_value = mock_api
            mock_api.whoami.side_effect = Exception("401 Unauthorized")

            result = await validate_oauth_token("hf_invalid_token")

        assert result["is_valid"] is False
        assert "could not authenticate" in result["error"]