| | """ |
| | This example loads the pre-trained bert-base-nli-mean-tokens models from the server. |
| | It then fine-tunes this model for some epochs on the STS benchmark dataset. |
| | """ |
| | from torch.utils.data import DataLoader |
| | import math |
| | from sentence_transformers import SentenceTransformer, SentencesDataset, LoggingHandler, losses |
| | from sentence_transformers.evaluation import EmbeddingSimilarityEvaluator |
| | from sentence_transformers.readers import STSDataReader |
| | import logging |
| | from datetime import datetime |
| |
|
| |
|
| | |
# Route sentence-transformers log output through the library's LoggingHandler
# so training progress is printed with timestamps.
logging.basicConfig(format='%(asctime)s - %(message)s',
                    datefmt='%Y-%m-%d %H:%M:%S',
                    level=logging.INFO,
                    handlers=[LoggingHandler()])

# Path of the previously saved SentenceTransformer model to continue training from.
model_name = "../saved_models"
train_batch_size = 32
num_epochs = 4

# BUG FIX: model_name contains path separators ("../saved_models"); embedding it
# verbatim created a save path with "/" and ".." fragments
# ("output/quora_continue_training-../saved_models-<ts>"), i.e. an odd nested
# directory. Replace "/" with "-" so the run folder stays directly under "output/".
model_save_path = 'output/quora_continue_training-' + model_name.replace("/", "-") + '-' + datetime.now().strftime("%Y-%m-%d_%H-%M-%S")

# Reader for the Quora duplicate-questions file: sentence pair in columns 4 and 5,
# binary duplicate label in column 6, normalized into [0, 1] (max_score=1).
sts_reader = STSDataReader('../data/quora', normalize_scores=True, s1_col_idx=4, s2_col_idx=5, score_col_idx=6, max_score=1)
| |
|
| | |
# Load the saved SentenceTransformer from disk to continue training.
model = SentenceTransformer(model_name)

logging.info("Read Quora train dataset")
# Build the training pipeline: examples -> dataset -> shuffled batches -> loss.
train_examples = sts_reader.get_examples('train.csv')
train_data = SentencesDataset(examples=train_examples, model=model)
train_dataloader = DataLoader(train_data, batch_size=train_batch_size, shuffle=True)
train_loss = losses.CosineSimilarityLoss(model=model)
| |
|
| |
|
logging.info("Read Quora dev dataset")
# Dev-set pipeline mirrors training, but without shuffling, feeding the
# embedding-similarity evaluator used during fit().
dev_examples = sts_reader.get_examples('dev.csv')
dev_data = SentencesDataset(dev_examples, model)
dev_dataloader = DataLoader(dev_data, batch_size=train_batch_size, shuffle=False)
evaluator = EmbeddingSimilarityEvaluator(dev_dataloader)
| |
|
| |
|
| | |
# Use 10% of the total number of training steps for learning-rate warmup.
# NOTE: expression order kept exactly as before to preserve float rounding.
warmup_steps = math.ceil(len(train_data) * num_epochs / train_batch_size * 0.1)
logging.info("Warmup-steps: %s", warmup_steps)
| |
|
| |
|
| | |
# Fine-tune the model; it evaluates on the dev set every 1000 steps and
# writes checkpoints to model_save_path.
model.fit(
    train_objectives=[(train_dataloader, train_loss)],
    evaluator=evaluator,
    epochs=num_epochs,
    warmup_steps=warmup_steps,
    evaluation_steps=1000,
    output_path=model_save_path,
)
| |
|
| |
|
| | |
| | |
| | |
| | |
| | |
| | |
| | |
| | |
| | |
| | |
| | |
| |
|
| |
|