Skip to content

Commit

Permalink
feat: get message subgraph api support multiple kb (pingcap#619)
Browse files: browse the repository at this point in the history
part of pingcap#618

- Using langfuse_instrumentor instead of CallbackManager
  • Loading branch information
Mini256 committed Feb 11, 2025
1 parent 89f1fd4 commit a5fb4dc
Show file tree
Hide file tree
Showing 48 changed files with 2,204 additions and 2,012 deletions.
2 changes: 1 addition & 1 deletion backend/app/api/admin_routes/chat_engine.py
Original file line number Diff line number Diff line change
Expand Up @@ -2,7 +2,7 @@
from fastapi_pagination import Params, Page

from app.api.deps import SessionDep, CurrentSuperuserDep
from app.rag.chat_config import ChatEngineConfig
from app.rag.chat.config import ChatEngineConfig
from app.repositories import chat_engine_repo
from app.models import ChatEngine, ChatEngineUpdate

Expand Down
6 changes: 3 additions & 3 deletions backend/app/api/admin_routes/evaluation/evaluation_task.py
Original file line number Diff line number Diff line change
Expand Up @@ -119,7 +119,7 @@ def get_evaluation_task_summary(
evaluation_task_id: int, session: SessionDep, user: CurrentSuperuserDep
) -> EvaluationTaskSummary:
task = must_get(session, EvaluationTask, evaluation_task_id)
return get_evaluation_task_summary(task, session)
return get_summary_for_evaluation_task(task, session)


@router.get("/admin/evaluation/tasks")
Expand All @@ -135,7 +135,7 @@ def list_evaluation_task(
task_page: Page[EvaluationTask] = paginate(session, stmt, params)
summaries: List[EvaluationTaskSummary] = []
for task in task_page.items:
summaries.append(get_evaluation_task_summary(task, session))
summaries.append(get_summary_for_evaluation_task(task, session))

return Page[EvaluationTaskSummary](
items=summaries,
Expand Down Expand Up @@ -169,7 +169,7 @@ def list_evaluation_task_items(
return paginate(session, stmt, params)


def get_evaluation_task_summary(
def get_summary_for_evaluation_task(
evaluation_task: EvaluationTask, session: Session
) -> EvaluationTaskSummary:
status_counts = (
Expand Down
2 changes: 0 additions & 2 deletions backend/app/api/admin_routes/evaluation/tools.py
Original file line number Diff line number Diff line change
@@ -1,6 +1,4 @@
from http.client import HTTPException
from typing import TypeVar, Type

from fastapi import status, HTTPException
from sqlmodel import SQLModel, Session

Expand Down
4 changes: 2 additions & 2 deletions backend/app/api/admin_routes/knowledge_base/graph/routes.py
Original file line number Diff line number Diff line change
Expand Up @@ -226,7 +226,7 @@ def legacy_search_graph(session: SessionDep, kb_id: int, request: GraphSearchReq
try:
kb = knowledge_base_repo.must_get(session, kb_id)
graph_store = get_kb_tidb_graph_store(session, kb)
entities, relations = graph_store.retrieve_with_weight(
entities, relationships = graph_store.retrieve_with_weight(
request.query,
[],
request.depth,
Expand All @@ -236,7 +236,7 @@ def legacy_search_graph(session: SessionDep, kb_id: int, request: GraphSearchReq
)
return {
"entities": entities,
"relationships": relations,
"relationships": relationships,
}
except KBNotFound as e:
raise e
Expand Down
133 changes: 133 additions & 0 deletions backend/app/api/admin_routes/legacy_retrieve.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,133 @@
import logging
from typing import Optional, List

from fastapi import APIRouter
from sqlmodel import Session
from app.models import Document
from app.api.admin_routes.models import ChatEngineBasedRetrieveRequest
from app.api.deps import SessionDep, CurrentSuperuserDep
from llama_index.core.schema import NodeWithScore

from app.exceptions import InternalServerError, KBNotFound
from app.rag.chat.config import ChatEngineConfig
from app.rag.chat.retrieve.retrieve_flow import RetrieveFlow

router = APIRouter()
logger = logging.getLogger(__name__)


def get_override_engine_config(
    db_session: Session,
    engine_name: str,
    # Override chat engine config.
    top_k: Optional[int] = None,
    similarity_top_k: Optional[int] = None,
    oversampling_factor: Optional[int] = None,
    refine_question_with_kg: Optional[bool] = None,
) -> ChatEngineConfig:
    """Load a chat engine config and apply the given per-request overrides.

    The config is loaded with ``ChatEngineConfig.load_from_db`` and then
    mutated in place; an override left as ``None`` keeps the stored value.

    Args:
        db_session: Database session passed to ``load_from_db``.
        engine_name: Name of the chat engine whose config is loaded.
        top_k: Replacement for ``vector_search.top_k``.
        similarity_top_k: Replacement for ``vector_search.similarity_top_k``.
        oversampling_factor: Replacement for
            ``vector_search.oversampling_factor``.
        refine_question_with_kg: Replacement for the top-level
            ``refine_question_with_kg`` flag.

    Returns:
        The loaded config with every non-``None`` override applied.
    """
    config = ChatEngineConfig.load_from_db(db_session, engine_name)

    # Overrides that live on the nested vector-search settings object.
    vector_search_overrides = {
        "similarity_top_k": similarity_top_k,
        "oversampling_factor": oversampling_factor,
        "top_k": top_k,
    }
    for field_name, override in vector_search_overrides.items():
        if override is not None:
            setattr(config.vector_search, field_name, override)

    # This flag lives on the engine config itself, not on vector_search.
    if refine_question_with_kg is not None:
        config.refine_question_with_kg = refine_question_with_kg

    return config


@router.get("/admin/retrieve/documents", deprecated=True)
def legacy_retrieve_documents(
    session: SessionDep,
    user: CurrentSuperuserDep,
    question: str,
    chat_engine: str = "default",
    # Override chat engine config.
    top_k: Optional[int] = 5,
    similarity_top_k: Optional[int] = None,
    oversampling_factor: Optional[int] = 5,
    refine_question_with_kg: Optional[bool] = True,
) -> List[Document]:
    """Deprecated endpoint: retrieve documents for a question.

    Builds the chat engine config with the request's overrides applied,
    then runs ``RetrieveFlow.retrieve_documents`` on the question.

    Args:
        session: Database session dependency.
        user: Not used in the body; presumably the dependency restricts
            the route to superusers -- confirm against ``app.api.deps``.
        question: Query text handed to the retrieve flow.
        chat_engine: Name of the chat engine whose config is loaded.
        top_k: Override for ``vector_search.top_k``.
        similarity_top_k: Override for ``vector_search.similarity_top_k``.
        oversampling_factor: Override for
            ``vector_search.oversampling_factor``.
        refine_question_with_kg: Override for the config flag of the
            same name.

    Raises:
        KBNotFound: Propagated unchanged.
        InternalServerError: For any other exception, after logging it.
    """
    try:
        overridden_config = get_override_engine_config(
            session,
            chat_engine,
            top_k=top_k,
            similarity_top_k=similarity_top_k,
            oversampling_factor=oversampling_factor,
            refine_question_with_kg=refine_question_with_kg,
        )
        flow = RetrieveFlow(
            db_session=session,
            engine_name=chat_engine,
            engine_config=overridden_config,
        )
        return flow.retrieve_documents(question)
    except KBNotFound:
        # Let the dedicated "knowledge base not found" error through as-is.
        raise
    except Exception as err:
        logger.exception(err)
        raise InternalServerError()


@router.get("/admin/embedding_retrieve", deprecated=True)
def legacy_retrieve_chunks(
    session: SessionDep,
    user: CurrentSuperuserDep,
    question: str,
    chat_engine: str = "default",
    # Override chat engine config.
    top_k: Optional[int] = 5,
    similarity_top_k: Optional[int] = None,
    oversampling_factor: Optional[int] = 5,
    # Annotated explicitly: without a type FastAPI treats the query value
    # as ``Any`` and passes the raw string through, so "false" would be a
    # truthy string written straight into the engine config.
    refine_question_with_kg: Optional[bool] = False,
) -> List[NodeWithScore]:
    """Deprecated endpoint: retrieve scored chunks for a question.

    Builds the chat engine config with the request's overrides applied,
    then runs ``RetrieveFlow.retrieve`` on the question.

    Args:
        session: Database session dependency.
        user: Not used in the body; presumably the dependency restricts
            the route to superusers -- confirm against ``app.api.deps``.
        question: Query text handed to the retrieve flow.
        chat_engine: Name of the chat engine whose config is loaded.
        top_k: Override for ``vector_search.top_k``.
        similarity_top_k: Override for ``vector_search.similarity_top_k``.
        oversampling_factor: Override for
            ``vector_search.oversampling_factor``.
        refine_question_with_kg: Override for the config flag of the
            same name; defaults to ``False`` for this endpoint.

    Raises:
        KBNotFound: Propagated unchanged.
        InternalServerError: For any other exception, after logging it.
    """
    try:
        engine_config = get_override_engine_config(
            db_session=session,
            engine_name=chat_engine,
            top_k=top_k,
            similarity_top_k=similarity_top_k,
            oversampling_factor=oversampling_factor,
            refine_question_with_kg=refine_question_with_kg,
        )
        retriever = RetrieveFlow(
            db_session=session,
            engine_name=chat_engine,
            engine_config=engine_config,
        )
        return retriever.retrieve(question)
    except KBNotFound:
        # Let the dedicated "knowledge base not found" error through as-is.
        raise
    except Exception as e:
        logger.exception(e)
        raise InternalServerError()


@router.post("/admin/embedding_retrieve", deprecated=True)
def legacy_retrieve_chunks_2(
    session: SessionDep,
    user: CurrentSuperuserDep,
    request: ChatEngineBasedRetrieveRequest,
) -> List[NodeWithScore]:
    """Deprecated POST endpoint: retrieve scored chunks for a query.

    Reads the engine name, the query and the config overrides from the
    request body, builds the overridden chat engine config, then runs
    ``RetrieveFlow.retrieve`` on ``request.query``.

    Args:
        session: Database session dependency.
        user: Not used in the body; presumably the dependency restricts
            the route to superusers -- confirm against ``app.api.deps``.
        request: Body carrying ``chat_engine``, ``query``, ``top_k``,
            ``similarity_top_k``, ``oversampling_factor`` and
            ``refine_question_with_kg``.

    Raises:
        KBNotFound: Propagated unchanged.
        InternalServerError: For any other exception, after logging it.
    """
    try:
        engine_name = request.chat_engine
        overridden_config = get_override_engine_config(
            session,
            engine_name,
            top_k=request.top_k,
            similarity_top_k=request.similarity_top_k,
            oversampling_factor=request.oversampling_factor,
            refine_question_with_kg=request.refine_question_with_kg,
        )
        flow = RetrieveFlow(
            db_session=session,
            engine_name=engine_name,
            engine_config=overridden_config,
        )
        return flow.retrieve(request.query)
    except KBNotFound:
        # Let the dedicated "knowledge base not found" error through as-is.
        raise
    except Exception as err:
        logger.exception(err)
        raise InternalServerError()
5 changes: 4 additions & 1 deletion backend/app/api/admin_routes/models.py
Original file line number Diff line number Diff line change
Expand Up @@ -27,6 +27,9 @@ class KnowledgeBaseDescriptor(BaseModel):
id: int
name: str

def __hash__(self):
return hash(self.id)


class DataSourceDescriptor(BaseModel):
id: int
Expand All @@ -48,4 +51,4 @@ class ChatEngineBasedRetrieveRequest(BaseModel):
top_k: Optional[int] = 5
similarity_top_k: Optional[int] = None
oversampling_factor: Optional[int] = 5
enable_kg_enhance_query_refine: Optional[bool] = False
refine_question_with_kg: Optional[bool] = False
93 changes: 0 additions & 93 deletions backend/app/api/admin_routes/retrieve_old.py

This file was deleted.

2 changes: 1 addition & 1 deletion backend/app/api/admin_routes/semantic_cache.py
Original file line number Diff line number Diff line change
Expand Up @@ -4,7 +4,7 @@

from fastapi import APIRouter, Body
from app.api.deps import SessionDep, CurrentSuperuserDep
from app.rag.chat_config import ChatEngineConfig
from app.rag.chat.config import ChatEngineConfig
from app.rag.semantic_cache import SemanticCacheManager, SemanticItem

router = APIRouter()
Expand Down
20 changes: 2 additions & 18 deletions backend/app/api/main.py
Original file line number Diff line number Diff line change
Expand Up @@ -38,9 +38,9 @@
from app.api.admin_routes import (
chat_engine as admin_chat_engine,
feedback as admin_feedback,
legacy_retrieve as admin_legacy_retrieve,
site_setting as admin_site_settings,
upload as admin_upload,
retrieve_old as admin_retrieve_old,
stats as admin_stats,
semantic_cache as admin_semantic_cache,
langfuse as admin_langfuse,
Expand Down Expand Up @@ -85,7 +85,7 @@
api_router.include_router(admin_embedding_model_router, tags=["admin/embedding_model"])
api_router.include_router(admin_reranker_model_router, tags=["admin/reranker_model"])
api_router.include_router(admin_langfuse.router, tags=["admin/langfuse"])
api_router.include_router(admin_retrieve_old.router, tags=["admin/retrieve_old"])
api_router.include_router(admin_legacy_retrieve.router, tags=["admin/retrieve_old"])
api_router.include_router(admin_stats.router, tags=["admin/stats"])
api_router.include_router(admin_semantic_cache.router, tags=["admin/semantic_cache"])
api_router.include_router(admin_evaluation_task.router, tags=["admin/evaluation/task"])
Expand All @@ -98,19 +98,3 @@
api_router.include_router(
fastapi_users.get_auth_router(auth_backend), prefix="/auth", tags=["auth"]
)

# api_router.include_router(
# fastapi_users.get_register_router(UserRead, UserCreate),
# prefix="/auth",
# tags=["auth"],
# )
# api_router.include_router(
# fastapi_users.get_reset_password_router(),
# prefix="/auth",
# tags=["auth"],
# )
# api_router.include_router(
# fastapi_users.get_verify_router(UserRead),
# prefix="/auth",
# tags=["auth"],
# )
Loading

0 comments on commit a5fb4dc

Please sign in to comment.