使用Llama3和Ollama来增强RAG

Published: 07 May 2024 Category: llm

在这篇文章中,我们将探讨如何利用Meta新发布的最先进的开源大型语言模型Llama-3,实现在完全本地化基础设施上的进阶版RAG(检索增强生成)。这篇文章是使用Llama-3进行进阶RAG实施的实战指南。

简介: 在这篇文章中,我们将创建一个进阶版的RAG,它会根据输入到管道的研究论文来回答用户查询。构建这个pipeline所使用的技术栈如下:

  • Ollama嵌入模型(mxbai-embed-large)
  • Ollama量化Llama-3 8b模型
  • 本地托管的Qdrant向量数据库

从技术选型中可以明显看出两个优势:成本为0,并且信息高度安全和私密。

什么是HyDE?

HyDE(Hypothetical Document Embeddings, 假设文档嵌入)源自Gao等人在2022年发表的名为《Precise Zero-Shot Dense Retrieval without Relevance Labels》的论文中的创新工作。这项研究的主要目标是改进依赖语义嵌入相似性的零样本稠密检索(也就是基于嵌入向量的检索),HyDE由两个步骤来完成。

在第一步(步骤1)中,通过指令提示词引导大语言模型(以GPT-3为例)生成基于原始查询的假设文档。这个过程是针对查询的问题精心定制的,确保了“假设文档”的相关性。

进入第二步,通过一个“无监督对比编码器”的Contriever将生成的假设文档转化为嵌入向量。这个编码器将假设文档转换为向量表示,然后用于后续的相似性搜索和检索任务。

HyDE的基本功能是通过两个关键组件将文档转换为向量嵌入。第一步使用语言模型执行生成任务,旨在捕捉假设文档中的相关性,即使存在事实不准确。随后,由对比编码器管理的文档-文档相似性任务细化嵌入过程,过滤掉多余细节,提高效率。

值得注意的是,HyDE的表现超过了现有的无监督稠密检索器,如Contriever。此外,它在多样化的任务和语言上展示出与经过微调的检索器相当的性能表现。这种方法将稠密检索简化为两个连贯的任务,标志着基于语义嵌入的检索方法的重大进步。

实现:

from llama_index.core import (
    SimpleDirectoryReader,
    VectorStoreIndex,
    StorageContext,
    Settings,
    get_response_synthesizer)
from llama_index.core.query_engine import RetrieverQueryEngine, TransformQueryEngine
from llama_index.core.node_parser import SentenceSplitter
from llama_index.core.schema import TextNode, MetadataMode
from llama_index.vector_stores.qdrant import QdrantVectorStore
from llama_index.embeddings.ollama import OllamaEmbedding
from llama_index.llms.ollama import Ollama
from llama_index.core.retrievers import VectorIndexRetriever
from llama_index.core.indices.query.query_transform import HyDEQueryTransform
import qdrant_client
import logging

初始化

logging.basicConfig(level=logging.INFO)
logger = logging.getLogger(__name__)

# load the local data directory and chunk the data for further processing
docs = SimpleDirectoryReader(input_dir="data", required_exts=[".pdf"]).load_data(show_progress=True)
text_parser = SentenceSplitter(chunk_size=512, chunk_overlap=100)

text_chunks = []
doc_ids = []
nodes = []

创建向量数据库以便存储嵌入向量。

# Create a local Qdrant vector store
logger.info("initializing the vector store related objects")
client = qdrant_client.QdrantClient(host="localhost", port=6333)
vector_store = QdrantVectorStore(client=client, collection_name="research_papers")

加载嵌入及大语言模型。

# local vector embeddings model
logger.info("initializing the OllamaEmbedding")
embed_model = OllamaEmbedding(model_name='mxbai-embed-large', base_url='http://localhost:11434')
logger.info("initializing the global settings")
Settings.embed_model = embed_model
Settings.llm = Ollama(model="llama3", base_url='http://localhost:11434')
Settings.transformations = [text_parser]

创建节点、向量存储、HyDE transformer并进行查询。

logger.info("enumerating docs")
for doc_idx, doc in enumerate(docs):
    curr_text_chunks = text_parser.split_text(doc.text)
    text_chunks.extend(curr_text_chunks)
    doc_ids.extend([doc_idx] * len(curr_text_chunks))

logger.info("enumerating text_chunks")
for idx, text_chunk in enumerate(text_chunks):
    node = TextNode(text=text_chunk)
    src_doc = docs[doc_ids[idx]]
    node.metadata = src_doc.metadata
    nodes.append(node)

logger.info("enumerating nodes")
for node in nodes:
    node_embedding = embed_model.get_text_embedding(
        node.get_content(metadata_mode=MetadataMode.ALL)
    )
    node.embedding = node_embedding

logger.info("initializing the storage context")
storage_context = StorageContext.from_defaults(vector_store=vector_store)
logger.info("indexing the nodes in VectorStoreIndex")
index = VectorStoreIndex(
    nodes=nodes,
    storage_context=storage_context,
    transformations=Settings.transformations,
)

logger.info("initializing the VectorIndexRetriever with top_k as 5")
vector_retriever = VectorIndexRetriever(index=index, similarity_top_k=5)
response_synthesizer = get_response_synthesizer()
logger.info("creating the RetrieverQueryEngine instance")
vector_query_engine = RetrieverQueryEngine(
    retriever=vector_retriever,
    response_synthesizer=response_synthesizer,
)
logger.info("creating the HyDEQueryTransform instance")
hyde = HyDEQueryTransform(include_original=True)
hyde_query_engine = TransformQueryEngine(vector_query_engine, hyde)

logger.info("retrieving the response to the query")
response = hyde_query_engine.query(
    str_or_query_bundle="what are all the data sets used in the experiment and told in the paper")
print(response)

client.close()

上述代码首先把日志级别配置成INFO,以便查看日志,然后从本地目录加载PDF数据,并将其分割成文本块。它设置了一个Qdrant向量存储来存储研究论文的嵌入向量,并初始化了一个Ollama文本嵌入模型,用来将文本生成嵌入。设置好全局配置,再处理文本块并将其与文档ID关联起来。再根据这些块创建文本节点,保留元数据,并使用Ollama模型为这些节点生成嵌入。然后设置存储上下文,以便在Qdrant向量存储中索引文本嵌入,以便后续进行检索。配置完向量检索器用来检索相似的嵌入向量,再初始化一个查询引擎用于处理查询。同时还设置了一个HyDE查询transformer,用于增强查询处理。最后,执行了一个查询,以检索有关论文实验中提到的数据集相关的信息,并输出响应。

输出:

结论: 通过利用Meta的Llama-3等先进的模型,结合HyDE的改进方法以及Ollama的能力,我们可以构建出优秀的RAG管道。通过细致地微调关键的超参数,如topk、chunksize和chunk_overlap,可以达到更高的准确性和效率。通过结合先进的工具的和精心的优化,能够释放系统的全部潜能,确保解决方案的创新性和领先性,同时还最大程度保障了隐私和安全。