-
Notifications
You must be signed in to change notification settings - Fork 0
/
Copy pathcalc_embeddings.py
49 lines (39 loc) · 1.32 KB
/
calc_embeddings.py
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
# Embeddings calculation for the chunks
# Date: 2024/10/06, 2025/01/26
# Author: [email protected]
import sqlite3
import sys
from contextlib import closing

# Extend the import path (relative to the current working directory) so the
# project-local `rag` package can be imported below.
sys.path.append("..")
sys.path.append("../cx")

from rag import embeddings
from rag import vector_db
DB_PATH = "../db/documents.db"
VECTOR_DB_PATH = "../db/embeddings.db"
# NOTE(review): both paths resolve against the current working directory,
# not this file's location, so the script must be launched from its own dir.

# Read all the chunks.
# sqlite3's connection context manager only commits/rolls back a transaction;
# it does NOT close the connection. closing() releases the database handle
# as soon as the rows have been fetched.
with closing(sqlite3.connect(DB_PATH)) as conn:
    rows = conn.execute("SELECT collection, context, chunk FROM chunks").fetchall()

# Group [context, chunk] pairs by collection name, preserving row order.
# Only collections that have at least one row get an entry.
data = {}
for collection, context, chunk in rows:
    data.setdefault(collection, []).append([context, chunk])
N = 10  # Batch size for embeddings calculation

# Calculate embeddings, N chunks at a time, for every collection.
# records maps collection -> list of (context, vector, chunk) tuples.
records = {}
for collection, items in data.items():
    print(f"Collection: {collection}")
    for start in range(0, len(items), N):
        contexts, chunks = zip(*items[start:start + N])
        print(contexts[0])  # progress: first context of the current batch
        # Materialize the result so its length can be checked even if
        # get_embedding returns an iterator rather than a sequence.
        vectors = list(embeddings.get_embedding(chunks))
        # zip() stops at the shortest input, so a short embedding result would
        # silently drop chunks; fail loudly instead of losing data.
        if len(vectors) != len(chunks):
            raise ValueError(
                f"Expected {len(chunks)} embeddings for collection "
                f"{collection!r}, got {len(vectors)}"
            )
        records.setdefault(collection, []).extend(zip(contexts, vectors, chunks))
# Persist the results: open one VectorDB handle per collection (all sharing
# VECTOR_DB_PATH and the embedding dimension) and hand it that collection's
# (context, vector, chunk) rows.
print("Saving the data in the database...")
for name, embedded_rows in records.items():
    store = vector_db.VectorDB(VECTOR_DB_PATH, name, embeddings.DIMENSION)
    store.save(embedded_rows)