TharaKavin commited on
Commit
c42a70a
·
verified ·
1 Parent(s): 4ba3a74

Update embedder.py

Browse files
Files changed (1) hide show
  1. embedder.py +41 -40
embedder.py CHANGED
@@ -1,40 +1,41 @@
1
- from sentence_transformers import SentenceTransformer
2
- import faiss
3
- import numpy as np
4
-
5
- class VectorStore:
6
- def __init__(self):
7
- self.index = None
8
- self.chunks = []
9
- self.model = None # lazy load
10
-
11
- def load_model(self):
12
- if self.model is None:
13
- print("Loading model...")
14
- self.model = SentenceTransformer("all-MiniLM-L6-v2")
15
-
16
- def create_index(self, chunks):
17
- self.load_model()
18
-
19
- self.chunks = chunks
20
- embeddings = self.model.encode(chunks)
21
-
22
- if len(embeddings.shape) == 1:
23
- embeddings = np.array([embeddings])
24
- else:
25
- embeddings = np.array(embeddings)
26
-
27
- dim = embeddings.shape[1]
28
- self.index = faiss.IndexFlatL2(dim)
29
- self.index.add(embeddings)
30
-
31
- def retrieve(self, query, k=3):
32
- self.load_model()
33
-
34
- query_embedding = self.model.encode([query])
35
-
36
- if len(query_embedding.shape) == 1:
37
- query_embedding = np.array([query_embedding])
38
-
39
- distances, indices = self.index.search(query_embedding, k)
40
- return [self.chunks[i] for i in indices[0]]
 
 
1
+ from sentence_transformers import SentenceTransformer
2
+ import faiss
3
+ import numpy as np
4
+
5
+ class VectorStore:
6
+ def __init__(self):
7
+ self.index = None
8
+ self.chunks = []
9
+ self.model = None # lazy load
10
+
11
+ def load_model(self):
12
+ if self.model is None:
13
+ print("Loading model...")
14
+ self.model = SentenceTransformer("all-MiniLM-L6-v2")
15
+
16
+ def create_index(self, chunks):
17
+ self.load_model()
18
+
19
+ self.chunks = chunks
20
+ embeddings = self.model.encode(chunks)
21
+
22
+ if len(embeddings.shape) == 1:
23
+ embeddings = np.array([embeddings])
24
+ else:
25
+ embeddings = np.array(embeddings)
26
+
27
+ dim = embeddings.shape[1]
28
+ self.index = faiss.IndexFlatL2(dim)
29
+ self.index.add(embeddings)
30
+
31
+ def retrieve(self, query, k=3):
32
+ self.load_model()
33
+
34
+ query_embedding = self.model.encode([query])
35
+
36
+ if len(query_embedding.shape) == 1:
37
+ query_embedding = np.array([query_embedding])
38
+
39
+ distances, indices = self.index.search(query_embedding, k)
40
+ return [self.chunks[i] for i in indices[0]]
41
+