Commit
Merge branch 'frontier'
binary-husky committed Sep 8, 2024
2 parents dd66ca2 + ab32c31 commit 8222f63
Showing 8 changed files with 147 additions and 14 deletions.
1 change: 1 addition & 0 deletions .gitignore
@@ -160,3 +160,4 @@ test.*
 temp.*
 objdump*
 *.min.*.js
+TODO
1 change: 0 additions & 1 deletion TODO

This file was deleted.

26 changes: 21 additions & 5 deletions crazy_functions/Rag_Interface.py
@@ -1,7 +1,14 @@
 from toolbox import CatchException, update_ui, get_conf, get_log_folder, update_ui_lastest_msg
 from crazy_functions.crazy_utils import input_clipping
 from crazy_functions.crazy_utils import request_gpt_model_in_new_thread_with_ui_alive
-from crazy_functions.rag_fns.llama_index_worker import LlamaIndexRagWorker
 
+VECTOR_STORE_TYPE = "Milvus"
+
+if VECTOR_STORE_TYPE == "Simple":
+    from crazy_functions.rag_fns.llama_index_worker import LlamaIndexRagWorker
+if VECTOR_STORE_TYPE == "Milvus":
+    from crazy_functions.rag_fns.milvus_worker import MilvusRagWorker as LlamaIndexRagWorker
 
 
 RAG_WORKER_REGISTER = {}
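Note on the hunk above: the module-level VECTOR_STORE_TYPE flag picks the backend at import time, and both branches bind their class to the same LlamaIndexRagWorker alias, so the rest of the plugin never branches on the store type. A minimal self-contained sketch of this import-time dispatch (class names here are illustrative, not from the commit):

    VECTOR_STORE_TYPE = "Milvus"

    class SimpleBackend:
        """Stand-in for the pure-LlamaIndex worker."""

    class MilvusBackend:
        """Stand-in for the Milvus-backed worker."""

    # Choose once at import time; downstream code only ever sees the alias.
    Backend = SimpleBackend if VECTOR_STORE_TYPE == "Simple" else MilvusBackend
    worker = Backend()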

@@ -14,16 +21,25 @@ def Rag问答(txt, llm_kwargs, plugin_kwargs, chatbot, history, system_prompt, u
 
     # 1. we retrieve rag worker from global context
     user_name = chatbot.get_user()
+    checkpoint_dir = get_log_folder(user_name, plugin_name='experimental_rag')
     if user_name in RAG_WORKER_REGISTER:
         rag_worker = RAG_WORKER_REGISTER[user_name]
     else:
         rag_worker = RAG_WORKER_REGISTER[user_name] = LlamaIndexRagWorker(
             user_name,
             llm_kwargs,
-            checkpoint_dir=get_log_folder(user_name, plugin_name='experimental_rag'),
+            checkpoint_dir=checkpoint_dir,
             auto_load_checkpoint=True)
+    current_context = f"{VECTOR_STORE_TYPE} @ {checkpoint_dir}"
+    tip = "提示:输入“清空向量数据库”可以清空RAG向量数据库"
+    if txt == "清空向量数据库":
+        chatbot.append([txt, f'正在清空 ({current_context}) ...'])
+        yield from update_ui(chatbot=chatbot, history=history) # refresh the UI
+        rag_worker.purge()
+        yield from update_ui_lastest_msg('已清空', chatbot, history, delay=0) # refresh the UI
+        return
 
-    chatbot.append([txt, '正在召回知识 ...'])
+    chatbot.append([txt, f'正在召回知识 ({current_context}) ...'])
     yield from update_ui(chatbot=chatbot, history=history) # refresh the UI
 
     # 2. clip history to reduce token consumption
@@ -68,8 +84,8 @@ def Rag问答(txt, llm_kwargs, plugin_kwargs, chatbot, history, system_prompt, u
     )
 
     # 5. remember what has been asked / answered
-    yield from update_ui_lastest_msg(model_say + '</br></br>' + '对话记忆中, 请稍等 ...', chatbot, history, delay=0.5) # refresh the UI
+    yield from update_ui_lastest_msg(model_say + '</br></br>' + f'对话记忆中, 请稍等 ({current_context}) ...', chatbot, history, delay=0.5) # refresh the UI
     rag_worker.remember_qa(i_say_to_remember, model_say)
     history.extend([i_say, model_say])
 
-    yield from update_ui_lastest_msg(model_say, chatbot, history, delay=0) # refresh the UI
+    yield from update_ui_lastest_msg(model_say, chatbot, history, delay=0, msg=tip) # refresh the UI
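Taken together, the plugin changes above (a) hoist the checkpoint path into checkpoint_dir and reuse it, (b) surface the active backend and path through current_context, and (c) add a purge command: sending "清空向量数据库" ("clear the vector database") rebuilds the store, and the tip string reminding users of this is shown via the new msg argument (see the toolbox.py change below). A condensed sketch of the per-user lazy cache that RAG_WORKER_REGISTER implements (get_worker is a hypothetical helper, not part of the commit):

    RAG_WORKER_REGISTER = {}

    def get_worker(user_name, make_worker):
        # Build the worker on a user's first request, then reuse it, so the
        # vector index is loaded from disk at most once per user.
        if user_name not in RAG_WORKER_REGISTER:
            RAG_WORKER_REGISTER[user_name] = make_worker()
        return RAG_WORKER_REGISTER[user_name]

    # usage: rag_worker = get_worker(user_name, lambda: LlamaIndexRagWorker(...))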
10 changes: 6 additions & 4 deletions crazy_functions/rag_fns/llama_index_worker.py
@@ -1,4 +1,7 @@
 import llama_index
+import os
+import atexit
+from typing import List
 from llama_index.core import Document
 from llama_index.core.schema import TextNode
 from request_llms.embed_models.openai_embed import OpenAiEmbeddingModel
@@ -38,6 +41,7 @@ def does_checkpoint_exist(self, checkpoint_dir=None):
         return True
 
     def save_to_checkpoint(self, checkpoint_dir=None):
+        print(f'saving vector store to: {checkpoint_dir}')
         if checkpoint_dir is None: checkpoint_dir = self.checkpoint_dir
         self.vs_index.storage_context.persist(persist_dir=checkpoint_dir)
 
@@ -65,7 +69,8 @@ def __init__(self, user_name, llm_kwargs, auto_load_checkpoint=True, checkpoint_
         if auto_load_checkpoint:
             self.vs_index = self.load_from_checkpoint(checkpoint_dir)
         else:
-            self.vs_index = self.create_new_vs()
+            self.vs_index = self.create_new_vs(checkpoint_dir)
+        atexit.register(lambda: self.save_to_checkpoint(checkpoint_dir))
 
     def assign_embedding_model(self):
         pass
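Two side notes on these hunks, judging only from the diff: the new print in save_to_checkpoint runs before the checkpoint_dir fallback, so a call with no argument logs "saving vector store to: None" even though the correct directory is used afterwards; and registering a lambda that closes over self keeps the worker (and its index) alive until interpreter shutdown, which is presumably intended here. A small sketch of the exit-time persistence hook under those assumptions:

    import atexit

    class PersistedIndex:
        def __init__(self, checkpoint_dir):
            self.checkpoint_dir = checkpoint_dir
            # Flush once when the interpreter exits; the lambda holds a
            # reference to self until shutdown.
            atexit.register(lambda: self.save(self.checkpoint_dir))

        def save(self, checkpoint_dir=None):
            if checkpoint_dir is None:
                checkpoint_dir = self.checkpoint_dir  # fall back before logging to avoid the None quirk
            print(f'saving vector store to: {checkpoint_dir}')

    idx = PersistedIndex('/tmp/vs_checkpoint')  # hypothetical path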
@@ -117,6 +122,3 @@ def generate_node_array_preview(self, nodes):
         buf = "\n".join(([f"(No.{i+1} | score {n.score:.3f}): {n.text}" for i, n in enumerate(nodes)]))
         if self.debug_mode: print(buf)
         return buf
-
-
-
107 changes: 107 additions & 0 deletions crazy_functions/rag_fns/milvus_worker.py
@@ -0,0 +1,107 @@
+import llama_index
+import os
+import atexit
+from typing import List
+from llama_index.core import Document
+from llama_index.core.schema import TextNode
+from request_llms.embed_models.openai_embed import OpenAiEmbeddingModel
+from shared_utils.connect_void_terminal import get_chat_default_kwargs
+from llama_index.core import VectorStoreIndex, SimpleDirectoryReader
+from crazy_functions.rag_fns.vector_store_index import GptacVectorStoreIndex
+from llama_index.core.ingestion import run_transformations
+from llama_index.core import PromptTemplate
+from llama_index.core.response_synthesizers import TreeSummarize
+from llama_index.core import StorageContext
+from llama_index.vector_stores.milvus import MilvusVectorStore
+from crazy_functions.rag_fns.llama_index_worker import LlamaIndexRagWorker
+
+DEFAULT_QUERY_GENERATION_PROMPT = """\
+Now, you have context information as below:
+---------------------
+{context_str}
+---------------------
+Answer the user request below (use the context information if necessary, otherwise you can ignore them):
+---------------------
+{query_str}
+"""
+
+QUESTION_ANSWER_RECORD = """\
+{{
+    "type": "This is a previous conversation with the user",
+    "question": "{question}",
+    "answer": "{answer}",
+}}
+"""
+
+
+class MilvusSaveLoad():
+
+    def does_checkpoint_exist(self, checkpoint_dir=None):
+        import os, glob
+        if checkpoint_dir is None: checkpoint_dir = self.checkpoint_dir
+        if not os.path.exists(checkpoint_dir): return False
+        if len(glob.glob(os.path.join(checkpoint_dir, "*.json"))) == 0: return False
+        return True
+
+    def save_to_checkpoint(self, checkpoint_dir=None):
+        print(f'saving vector store to: {checkpoint_dir}')
+        # if checkpoint_dir is None: checkpoint_dir = self.checkpoint_dir
+        # self.vs_index.storage_context.persist(persist_dir=checkpoint_dir)
+
+    def load_from_checkpoint(self, checkpoint_dir=None):
+        if checkpoint_dir is None: checkpoint_dir = self.checkpoint_dir
+        if self.does_checkpoint_exist(checkpoint_dir=checkpoint_dir):
+            print('loading checkpoint from disk')
+            from llama_index.core import StorageContext, load_index_from_storage
+            storage_context = StorageContext.from_defaults(persist_dir=checkpoint_dir)
+            try:
+                self.vs_index = load_index_from_storage(storage_context, embed_model=self.embed_model)
+                return self.vs_index
+            except:
+                return self.create_new_vs(checkpoint_dir)
+        else:
+            return self.create_new_vs(checkpoint_dir)
+
+    def create_new_vs(self, checkpoint_dir, overwrite=False):
+        vector_store = MilvusVectorStore(
+            uri=os.path.join(checkpoint_dir, "milvus_demo.db"),
+            dim=self.embed_model.embedding_dimension(),
+            overwrite=overwrite
+        )
+        storage_context = StorageContext.from_defaults(vector_store=vector_store)
+        index = GptacVectorStoreIndex.default_vector_store(storage_context=storage_context, embed_model=self.embed_model)
+        return index
+
+    def purge(self):
+        self.vs_index = self.create_new_vs(self.checkpoint_dir, overwrite=True)
+
+
+class MilvusRagWorker(MilvusSaveLoad, LlamaIndexRagWorker):
+
+    def __init__(self, user_name, llm_kwargs, auto_load_checkpoint=True, checkpoint_dir=None) -> None:
+        self.debug_mode = True
+        self.embed_model = OpenAiEmbeddingModel(llm_kwargs)
+        self.user_name = user_name
+        self.checkpoint_dir = checkpoint_dir
+        if auto_load_checkpoint:
+            self.vs_index = self.load_from_checkpoint(checkpoint_dir)
+        else:
+            self.vs_index = self.create_new_vs(checkpoint_dir)
+        atexit.register(lambda: self.save_to_checkpoint(checkpoint_dir))
+
+    def inspect_vector_store(self):
+        # This function is for debugging
+        try:
+            self.vs_index.storage_context.index_store.to_dict()
+            docstore = self.vs_index.storage_context.docstore.docs
+            if not docstore.items():
+                raise ValueError("cannot inspect")
+            vector_store_preview = "\n".join([f"{_id} | {tn.text}" for _id, tn in docstore.items()])
+        except:
+            dummy_retrieve_res: List["NodeWithScore"] = self.vs_index.as_retriever().retrieve(' ')
+            vector_store_preview = "\n".join(
+                [f"{node.id_} | {node.text}" for node in dummy_retrieve_res]
+            )
+        print('\n++ --------inspect_vector_store begin--------')
+        print(vector_store_preview)
+        print('oo --------inspect_vector_store end--------')
+        return vector_store_preview
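The essential difference from llama_index_worker.py: the index lives in a MilvusVectorStore pointed at a local .db file, which pymilvus (>= 2.4.2, pinned below in requirements.txt) runs as in-process Milvus Lite, so no separate Milvus server is needed and the store persists itself into that file — presumably why save_to_checkpoint is stubbed out to a bare print. purge() recreates the collection with overwrite=True. Note also that does_checkpoint_exist still looks for the Simple store's *.json files, so on restart the Milvus path appears to fall through to create_new_vs, which reattaches to the existing .db (overwrite=False) rather than going through load_index_from_storage. A minimal standalone sketch of the storage setup (the dimension 1536 is an illustrative assumption; the real code asks the embedding model):

    import os
    from llama_index.core import StorageContext
    from llama_index.vector_stores.milvus import MilvusVectorStore

    def make_local_milvus_storage(checkpoint_dir, dim=1536, overwrite=False):
        # A plain file path as `uri` (rather than an http:// endpoint) makes
        # pymilvus run Milvus Lite in-process and persist into this file.
        vector_store = MilvusVectorStore(
            uri=os.path.join(checkpoint_dir, "milvus_demo.db"),
            dim=dim,              # must match the embedding model's output size
            overwrite=overwrite,  # True drops the existing collection (purge)
        )
        return StorageContext.from_defaults(vector_store=vector_store)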
8 changes: 7 additions & 1 deletion request_llms/embed_models/openai_embed.py
@@ -71,7 +71,13 @@ def compute_embedding(self, text="这是要计算嵌入的文本", llm_kwargs:di
         embedding = res.data[0].embedding
         return embedding
 
-    def embedding_dimension(self, llm_kwargs):
+    def embedding_dimension(self, llm_kwargs=None):
+        # load kwargs
+        if llm_kwargs is None:
+            llm_kwargs = self.llm_kwargs
+        if llm_kwargs is None:
+            raise RuntimeError("llm_kwargs is not provided!")
+
         from .bridge_all_embed import embed_model_info
         return embed_model_info[llm_kwargs['embed_model']]['embed_dimension']
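The new default makes llm_kwargs optional: since MilvusRagWorker constructs the model as OpenAiEmbeddingModel(llm_kwargs), create_new_vs can call self.embed_model.embedding_dimension() with no argument and fall back to the kwargs stored at construction. Usage sketch (assuming, as the fallback implies, that the constructor stores self.llm_kwargs):

    embed_model = OpenAiEmbeddingModel(llm_kwargs)
    dim = embed_model.embedding_dimension()            # falls back to self.llm_kwargs
    dim = embed_model.embedding_dimension(llm_kwargs)  # explicit form still works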

4 changes: 3 additions & 1 deletion requirements.txt
@@ -7,7 +7,9 @@ tiktoken>=0.3.3
 requests[socks]
 pydantic==2.5.2
 llama-index==0.10.47
-protobuf==3.18
+llama-index-vector-stores-milvus==0.1.16
+pymilvus==2.4.2
+protobuf==3.20
 transformers>=4.27.1,<4.42
 scipdf_parser>=0.52
 anthropic>=0.18.1
4 changes: 2 additions & 2 deletions toolbox.py
@@ -178,15 +178,15 @@ def update_ui(chatbot:ChatBotWithCookies, history, msg="正常", **kwargs): #
     yield cookies, chatbot_gr, history, msg
 
 
-def update_ui_lastest_msg(lastmsg:str, chatbot:ChatBotWithCookies, history:list, delay=1): # refresh the UI
+def update_ui_lastest_msg(lastmsg:str, chatbot:ChatBotWithCookies, history:list, delay=1, msg="正常"): # refresh the UI
     """
     Refresh the user interface
     """
     if len(chatbot) == 0:
         chatbot.append(["update_ui_last_msg", lastmsg])
     chatbot[-1] = list(chatbot[-1])
     chatbot[-1][-1] = lastmsg
-    yield from update_ui(chatbot=chatbot, history=history)
+    yield from update_ui(chatbot=chatbot, history=history, msg=msg)
     time.sleep(delay)
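The new msg keyword threads a caller-supplied status string through to update_ui, which previously always reported the default "正常" ("normal") on last-message refreshes; Rag问答 uses it to keep the purge tip visible after the answer renders. Usage sketch:

    # before: the status line always showed the default "正常" ("normal")
    yield from update_ui_lastest_msg(model_say, chatbot, history, delay=0)
    # after: callers can attach their own status text
    yield from update_ui_lastest_msg(model_say, chatbot, history, delay=0, msg=tip)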


