Source code for src.embeddings.bert

from sentence_transformers import SentenceTransformer
import pandas as pd
from tqdm import tqdm
from src.utils.utils import check_path_exists, save
import os


class Bert(object):
    """
    A class to create bert embeddings.

    Methods:
        transform(raw_texts: pd.Series, store: str = None):
            Transforms series of unpreprocessed strings to bert embeddings
            and optionally stores them.
    """

    def __init__(self):
        """
        Constructs bert object using a pretrained model.
        """
        # NOTE: don't preprocess texts beforehand; the model is meant to be
        # fed the raw, unpreprocessed strings.
        self.model = SentenceTransformer("multi-qa-MiniLM-L6-cos-v1")

    def transform(self, raw_texts: pd.Series, store: str = None):
        """
        Transform Series of unpreprocessed strings to bert embeddings.

        Args:
            raw_texts (pd.Series): Series of unpreprocessed strings
            store (str, optional): Target file path. When given, the
                embeddings are handed to the project's ``save`` helper after
                the parent directory has been checked with
                ``check_path_exists``. Defaults to None (nothing is stored).

        Returns:
            bert_vec (list): List containing bert embeddings, one entry per
                element of ``raw_texts`` in iteration order
        """
        # Encode one text per call (kept from the original implementation);
        # tqdm only adds a progress bar around the iteration.
        bert_vec = [self.model.encode(text) for text in tqdm(raw_texts)]

        if store is not None:
            target_dir = os.path.dirname(store)
            # A bare file name (e.g. "emb.pkl") has an empty dirname, i.e. it
            # targets the current working directory; there is no directory
            # to check in that case, so the helper is skipped.
            if target_dir:
                check_path_exists(target_dir)
            save(bert_vec, store)

        return bert_vec