utils.py
import inspect
import logging
from pathlib import Path
from typing import Dict, List

import mlflow
import numpy as np
from omegaconf import DictConfig, OmegaConf
from sentence_transformers import CrossEncoder, SentenceTransformer

file_query_embeddings = 'query_embeddings.npy'


def get_logger() -> logging.Logger:
    """Return a logger named after the calling function."""
    caller = inspect.stack()[1][3]
    logging.basicConfig(
        level=logging.INFO,
        format='%(asctime)s - %(levelname)s - %(message)s'
    )
    return logging.getLogger(caller)
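
# Usage sketch (hypothetical caller name): any function can grab a logger
# named after itself, e.g.
#
#     def embed_questions():
#         logger = get_logger()
#         logger.info('Embedding questions...')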


def save_query_embeddings(
    config: DictConfig,
    query_embeddings: np.ndarray
) -> None:
    """Save query embeddings to disk, creating the parent directory if needed."""
    path_embeddings = get_path_of_query_embeddings(config)
    path_embeddings.parent.mkdir(parents=True, exist_ok=True)
    with open(path_embeddings, 'wb') as file:
        np.save(file, query_embeddings)


def load_query_embeddings(
    config: DictConfig,
    questions: List[str]
) -> np.ndarray:
    """Load cached query embeddings and check they match the number of questions."""
    path_embeddings = get_path_of_query_embeddings(config)
    with open(path_embeddings, 'rb') as file:
        query_embeddings = np.load(file)
    assert len(questions) == len(query_embeddings)
    return query_embeddings


def check_if_query_embeddings_exist(config: DictConfig) -> bool:
    """Return True if cached query embeddings exist for this config."""
    path_embeddings = get_path_of_query_embeddings(config)
    return path_embeddings.is_file()


def get_path_of_query_embeddings(config: DictConfig) -> Path:
    """Build the path where query embeddings are cached for this config."""
    path_embeddings = Path().resolve().joinpath(
        config.embeddings,
        config.preprocess.dataset,
        config.retrieval.model.model_name,
        file_query_embeddings
    )
    return path_embeddings
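
# Illustrative layout produced by get_path_of_query_embeddings, assuming
# config values such as embeddings='embeddings', preprocess.dataset='squad'
# and retrieval.model.model_name='sentence-transformers/all-MiniLM-L6-v2'
# (example values, not taken from the project config):
#
#     <cwd>/embeddings/squad/sentence-transformers/all-MiniLM-L6-v2/query_embeddings.npy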


def log_retrieval_experiment_mlflow(
    config: DictConfig,
    top_k: int,
    metrics: Dict[str, float],
    keys_to_remove: List[str]
) -> None:
    """Log retrieval metrics, parameters and config of an experiment to MLflow."""
    logger = get_logger()
    logger.info('Logging results with MLflow...')
    config = select_relevant_config(config, keys_to_remove)
cr = config.retrieval
cfg_mlflow = config.mlflow
mlflow.set_tracking_uri(Path().resolve().joinpath(config.experiments))
mlflow.set_experiment(cfg_mlflow.experiment_name)
    # llama_index's HuggingFaceEmbedding does not expose the parameter count,
    # so load the model with sentence-transformers and count parameters there.
    model_name = cr.model.model_name
    model = SentenceTransformer(model_name, device='cpu')
params_to_log = {}
n_params = sum(p.numel() for p in model.parameters())
params_to_log.update({
'embedding_model_params': f'{n_params:,}',
'embedding_model_name': model_name,
'top_k': top_k,
'chunker': cr.chunker.name,
**cr.chunker.params
})
if cr.reranking:
model_name = cr.reranker.model_name
model = CrossEncoder(model_name, device='cpu')
n_params = sum(p.numel() for p in model.model.parameters())
params_to_log.update({
'reranker': model_name, 'reranker_params': f'{n_params:,}'
})
tags = {
**cfg_mlflow.tags,
'chunker': cr.chunker.name,
'dataset': config.preprocess.dataset
}
    with mlflow.start_run(
        run_name=cfg_mlflow.run_name,
        description=cfg_mlflow.description
    ):
mlflow.set_tags(tags)
mlflow.log_dict(OmegaConf.to_container(config), 'config.yaml')
mlflow.log_metrics(metrics)
mlflow.log_params(params_to_log)
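
# Illustrative call (hypothetical metric names and config handle):
#
#     log_retrieval_experiment_mlflow(
#         cfg, top_k=5,
#         metrics={'hit_rate': 0.82, 'mrr': 0.61},
#         keys_to_remove=['colbert_retrieval']
#     )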


def log_colbert_experiment_mlflow(
    config: DictConfig,
    top_k: int,
    metrics: Dict[str, float],
    keys_to_remove: List[str]
) -> None:
    """Log ColBERT retrieval metrics, parameters and config to MLflow."""
    logger = get_logger()
    logger.info('Logging results with MLflow...')
    config = select_relevant_config(config, keys_to_remove)
ccr = config.colbert_retrieval
cfg_mlflow = config.mlflow
mlflow.set_tracking_uri(Path().resolve().joinpath(config.experiments))
mlflow.set_experiment(cfg_mlflow.experiment_name)
params_to_log = {
'model_name': ccr.model_name,
'max_document_length': ccr.max_document_length,
'top_k': top_k
}
tags = {
**cfg_mlflow.tags,
'dataset': config.preprocess.dataset
}
    with mlflow.start_run(
        run_name=cfg_mlflow.run_name,
        description=cfg_mlflow.description
    ):
mlflow.set_tags(tags)
mlflow.log_dict(OmegaConf.to_container(config), 'config.yaml')
mlflow.log_metrics(metrics)
mlflow.log_params(params_to_log)


def select_relevant_config(
    config: DictConfig,
    keys_to_remove: List[str]
) -> DictConfig:
    """Remove config sections that are irrelevant to the current experiment."""
    config._set_flag("struct", False)
    if 'pipeline' in config.keys():
        del config['pipeline']
    for key in keys_to_remove:
        if key in config.keys():
            del config[key]
    config._set_flag("struct", True)
    return config
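
# Illustrative behaviour on a toy config (section names other than 'pipeline'
# and 'retrieval' are hypothetical): the 'pipeline' section and every key in
# keys_to_remove are dropped before logging.
#
#     cfg = OmegaConf.create({
#         'pipeline': ['preprocess', 'retrieval'],
#         'retrieval': {'reranking': False},
#         'generation': {'temperature': 0.7},
#     })
#     cfg = select_relevant_config(cfg, keys_to_remove=['generation'])
#     # -> cfg keeps only the 'retrieval' section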


def recursive_config_update(
    cfg_original: dict,
    cfg_update: dict
) -> dict:
    """Recursively merge cfg_update into cfg_original in place and return it."""
    for key, value in cfg_update.items():
        if (
            isinstance(value, dict)
            and key in cfg_original
            and isinstance(cfg_original[key], dict)
        ):
            recursive_config_update(cfg_original[key], value)
        else:
            cfg_original[key] = value
    return cfg_original
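
# Illustrative merge with toy dicts: nested keys are updated in place while
# untouched keys in the original are preserved.
#
#     base = {'retrieval': {'top_k': 5, 'reranking': False}, 'seed': 42}
#     override = {'retrieval': {'top_k': 10}}
#     recursive_config_update(base, override)
#     # base == {'retrieval': {'top_k': 10, 'reranking': False}, 'seed': 42}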