| | import csv |
| | import os |
| |
|
| | import pytest |
| |
|
| | output_path = 'regression_result' |
| | model = 'internlm2-chat-7b-hf' |
| | dataset = 'siqa' |
| |
|
| |
|
| | @pytest.fixture() |
| | def result_scores(): |
| | file = find_csv_files(output_path) |
| | if file is None: |
| | return None |
| | return read_csv_file(file) |
| |
|
| |
|
| | @pytest.mark.usefixtures('result_scores') |
| | class TestChatScore: |
| | """Test cases for chat model.""" |
| |
|
| | def test_model_dataset_score(self, result_scores): |
| | result_score = result_scores.get(model).get(dataset) |
| | assert_score(result_score, 79.53) |
| |
|
| |
|
| | def assert_score(score, baseline): |
| | if score is None or score == '-': |
| | assert False, 'value is none' |
| | if float(score) < (baseline * 1.03) and float(score) > (baseline * 0.97): |
| | print(score + ' between ' + str(baseline * 0.97) + ' and ' + |
| | str(baseline * 1.03)) |
| | assert True |
| | else: |
| | assert False, score + ' not between ' + str( |
| | baseline * 0.97) + ' and ' + str(baseline * 1.03) |
| |
|
| |
|
| | def find_csv_files(directory): |
| | csv_files = [] |
| | for root, dirs, files in os.walk(directory): |
| | for file in files: |
| | if file.endswith('.csv'): |
| | csv_files.append(os.path.join(root, file)) |
| | if len(csv_files) > 1: |
| | raise 'have more than 1 result file, please check the result manually' |
| | if len(csv_files) == 0: |
| | return None |
| | return csv_files[0] |
| |
|
| |
|
| | def read_csv_file(file_path): |
| | with open(file_path, 'r') as csvfile: |
| | reader = csv.DictReader(csvfile) |
| | filtered_data = [] |
| |
|
| | for row in reader: |
| | filtered_row = { |
| | k: v |
| | for k, v in row.items() |
| | if k not in ['version', 'metric', 'mode'] |
| | } |
| | filtered_data.append(filtered_row) |
| |
|
| | result = {} |
| | for data in filtered_data: |
| | dataset = data.get('dataset') |
| | for key in data.keys(): |
| | if key == 'dataset': |
| | continue |
| | else: |
| | if key in result.keys(): |
| | result.get(key)[dataset] = data.get(key) |
| | else: |
| | result[key] = {dataset: data.get(key)} |
| | return result |
| |
|