| |
| |
|
|
| |
| |
|
|
|
|
| import numpy as np |
| import pandas as pd |
| import pytest |
|
|
| from src.utils import train_test_split_and_feature_extraction |
|
|
| |
| |
| |
|
|
|
|
@pytest.fixture
def big_fake_data():
    """Build a synthetic 100-row DataFrame for the split/extraction test.

    Columns, in order:
        * ``id`` -- integers 1..100.
        * ``image`` -- image path strings (``"path/<i>.jpg"``); this column
          shares the ``image`` prefix but is not an embedding feature.
        * ``image_0`` .. ``image_9`` -- random floats in [0, 1) standing in
          for image embedding features.
        * ``text_0`` .. ``text_10`` -- random floats in [0, 1) standing in
          for text embedding features.
        * ``class_id`` -- one of ``"label1"``, ``"label2"``, ``"label3"``.

    A fixed-seed generator is used so the fixture's content is identical on
    every run, which keeps any test failure reproducible.
    """
    num_rows = 100
    num_image_columns = 10
    num_text_columns = 11

    # Local, seeded generator: the data does not depend on (or perturb)
    # NumPy's global random state, so every run sees the same frame.
    rng = np.random.default_rng(0)

    data = {
        "id": np.arange(1, num_rows + 1),
        "image": [f"path/{i}.jpg" for i in range(1, num_rows + 1)],
    }

    # Image embedding feature columns.
    for i in range(num_image_columns):
        data[f"image_{i}"] = rng.random(num_rows)

    # Text embedding feature columns.
    for i in range(num_text_columns):
        data[f"text_{i}"] = rng.random(num_rows)

    # Label column drawn uniformly from three classes.
    data["class_id"] = rng.choice(["label1", "label2", "label3"], size=num_rows)

    return pd.DataFrame(data)
|
|
|
|
def test_train_test_split_and_feature_extraction(big_fake_data):
    """Check column discovery, split sizes and split reproducibility.

    ``train_test_split_and_feature_extraction`` is unpacked here as
    ``(train_df, test_df, text_columns, image_columns, label_columns)``.
    """
    train_df, test_df, text_columns, image_columns, label_columns = (
        train_test_split_and_feature_extraction(
            big_fake_data, test_size=0.3, random_state=42
        )
    )

    # Column discovery: list equality also pins the column order.
    assert text_columns == [f"text_{i}" for i in range(11)], (
        "The text embedding columns extraction is incorrect"
    )
    assert image_columns == [f"image_{i}" for i in range(10)], (
        "The image embedding columns extraction is incorrect"
    )
    assert label_columns == ["class_id"], (
        "The label column extraction is incorrect, should be 'class_id'"
    )

    # The raw ``image`` path column shares the ``image`` prefix but holds
    # file paths, so it must not be picked up as an embedding feature.
    assert "image" not in image_columns, (
        "'image' path column must not be included in the image embedding columns"
    )

    # Split sizes are row counts; with the fixture's 100 rows and
    # test_size=0.3 they equal the percentages, but report them as counts.
    num_rows = len(big_fake_data)
    assert len(train_df) == 70, (
        f"Train split should have 70 of {num_rows} rows (70%), "
        f"but got {len(train_df)} rows"
    )
    assert len(test_df) == 30, (
        f"Test split should have 30 of {num_rows} rows (30%), "
        f"but got {len(test_df)} rows"
    )

    # Indices observed on the first call; a second call with the same
    # random_state must reproduce them exactly.
    first_train_indices = train_df.index.tolist()
    first_test_indices = test_df.index.tolist()

    train_df_recheck, test_df_recheck, _, _, _ = (
        train_test_split_and_feature_extraction(
            big_fake_data, test_size=0.3, random_state=42
        )
    )

    assert first_train_indices == train_df_recheck.index.tolist(), (
        "Train set indices are not consistent with the random state"
    )
    assert first_test_indices == test_df_recheck.index.tolist(), (
        "Test set indices are not consistent with the random state"
    )
|
|
|
|
if __name__ == "__main__":
    # Run only this module's tests (a bare pytest.main() collects whatever
    # the cwd/rootdir config points at) and propagate pytest's exit status
    # so a failing run exits non-zero.
    raise SystemExit(pytest.main([__file__]))
|
|