RAG (Retrieval-Augmented Generation)

For certain LLM tasks like answering questions, providing relevant context is essential. One common architecture is a two-stage RAG pipeline:

  1. Offline stage: Preprocess and index documents ("building the index").

  2. Online stage: Given a question, generate answers by retrieving the most relevant context.
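
In BrainyFlow terms, both stages share one memory object: the offline flow fills it with chunks, embeddings, and the index, and the online flow reads those (plus the question) to produce an answer. End to end, usage looks roughly like this, with both flows defined further down this page (run inside an async function):

memory = {"files": ["doc1.txt", "doc2.txt"]}
await OfflineFlow.run(memory)    # Stage 1: populates memory["all_chunks"], memory["index"], ...
memory["question"] = "Why do people like cats?"
await OnlineFlow.run(memory)     # Stage 2: fills memory["answer"]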


Stage 1: Offline Indexing

We need three conceptual steps:

  1. ChunkDocs – chunks the raw text of each file.

  2. EmbedDocs – embeds each chunk.

  3. StoreIndex – stores the embeddings in a vector index.

In the implementation below, chunking and embedding are each split into a trigger node (which fans out the work) and a worker node (which handles a single file or chunk), so the offline flow ends up with five nodes: TriggerChunkingNode, ChunkFileNode, TriggerEmbeddingNode, EmbedChunkNode, and StoreIndexNode.

import asyncio
import os # used for file existence checks
from brainyflow import Node, Flow, Memory, ParallelFlow

# Assume get_embedding, create_index, search_index, and call_llm are defined elsewhere
# async def get_embedding(text: str) -> list[float]: ...
# def create_index(embeddings: list[list[float]]) -> Any: ... # Returns an index object
# def search_index(index: Any, query_embedding: list[float], top_k: int) -> tuple[list[list[int]], list[list[float]]]: ...
# async def call_llm(prompt: str) -> str: ...
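
# For a quick self-contained run, you could substitute minimal stand-ins like the
# ones below: a toy hash-based "embedding", a brute-force in-memory index, and a
# stubbed LLM call. These are illustrative assumptions only; in practice you would
# plug in a real embedding model, vector store, and LLM client.
import hashlib
from typing import Any

async def get_embedding(text: str) -> list[float]:
    # Deterministic toy embedding: normalised bytes of an MD5 digest.
    digest = hashlib.md5(text.encode("utf-8")).digest()
    return [b / 255.0 for b in digest]

def create_index(embeddings: list[list[float]]) -> Any:
    # Toy "index": just keep the vectors in memory.
    return {"vectors": embeddings}

def search_index(index: Any, query_embedding: list[float], top_k: int = 1) -> tuple[list[list[int]], list[list[float]]]:
    # Brute-force nearest neighbours by dot product, returning ([[ids]], [[scores]]).
    scored = sorted(
        ((sum(q * v for q, v in zip(query_embedding, vec)), i)
         for i, vec in enumerate(index["vectors"])),
        reverse=True,
    )[:top_k]
    return [[i for _, i in scored]], [[s for s, _ in scored]]

async def call_llm(prompt: str) -> str:
    # Placeholder LLM call; replace with your provider's client.
    return f"(mock answer based on a prompt of {len(prompt)} characters)"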

# --- Stage 1: Offline Indexing Nodes ---

# 1a. Node to trigger chunking for each file
class TriggerChunkingNode(Node):
    async def prep(self, memory):
        return memory.files or []

    async def exec(self, files: list):
         # Optional: could return file count or validate paths
         return len(files)

    async def post(self, memory, files: list, file_count: int):
        print(f"Triggering chunking for {file_count} files.")
        memory.all_chunks = [] # Initialize chunk store
        memory.chunk_metadata = [] # Store metadata like source file
        for index, filepath in enumerate(files):
            if os.path.exists(filepath): # Basic check
                 self.trigger('chunk_file', { "filepath": filepath, "file_index": index })
            else:
                 print(f"Warning: File not found {filepath}")
        # Trigger next major step after attempting all files
        self.trigger('embed_chunks')

# 1b. Node to chunk a single file
class ChunkFileNode(Node):
    async def prep(self, memory):
        # Read filepath from local memory
        return memory.filepath, memory.file_index

    async def exec(self, prep_res):
        filepath, file_index = prep_res
        print(f"Chunking {filepath}")
        try:
            with open(filepath, "r", encoding="utf-8") as f:
                text = f.read()
            # Simple fixed-size chunking
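            # (A production chunker would typically add overlap between chunks and
            #  respect sentence or paragraph boundaries; fixed-size slices keep the
            #  example short.)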
            chunks = []
            size = 100
            for i in range(0, len(text), size):
                chunks.append(text[i : i + size])
            return chunks, filepath # Pass filepath for metadata
        except Exception as e:
            print(f"Error chunking {filepath}: {e}")
            return [], filepath # Return empty list on error

    async def post(self, memory, prep_res, exec_res):
        chunks, filepath = exec_res
        file_index = prep_res[1]
        # Append chunks and their source metadata to global lists
        # Note: If using ParallelFlow, direct append might lead to race conditions.
        # Consider storing per-file results then combining in the next step.
        start_index = len(memory.all_chunks)
        memory.all_chunks.extend(chunks)
        for i, chunk in enumerate(chunks):
             memory.chunk_metadata.append({"source": filepath, "chunk_index_in_file": i, "global_chunk_index": start_index + i})
        # This node doesn't trigger further processing for this specific file branch

# 1c. Node to trigger embedding for each chunk
class TriggerEmbeddingNode(Node):
     async def prep(self, memory):
         # This node runs after all 'chunk_file' triggers are processed by the Flow
         return memory.all_chunks or []

     async def exec(self, chunks: list):
         return len(chunks)

     async def post(self, memory, chunks: list, chunk_count: int):
         print(f"Triggering embedding for {chunk_count} chunks.")
         memory.all_embeds = [None] * chunk_count # Pre-allocate list for parallel writes
         for index, chunk in enumerate(chunks):
             # Pass chunk and its global index via forkingData
             self.trigger('embed_chunk', { "chunk": chunk, "global_index": index })
         # Trigger storing index after all embedding triggers are fired
         self.trigger('store_index')

# 1d. Node to embed a single chunk
class EmbedChunkNode(Node):
     async def prep(self, memory):
         # Read chunk and global index from local memory
         return memory.chunk, memory.global_index

     async def exec(self, prep_res):
         chunk, index = prep_res
         # print(f"Embedding chunk {index}") # Can be noisy
         return await get_embedding(chunk), index # Pass index through

     async def post(self, memory, prep_res, exec_res):
         embedding, index = exec_res
         # Store embedding at the correct index in the pre-allocated list
         memory.all_embeds[index] = embedding
         # This node doesn't trigger further processing for this chunk branch

# 1e. Node to store the final index
class StoreIndexNode(Node):
    async def prep(self, memory):
        # Read all embeddings from global memory
        # Filter out potential None values if embedding failed for some chunks
        embeddings = [emb for emb in (memory.all_embeds or []) if emb is not None]
        if len(embeddings) != len(memory.all_embeds or []):
             print(f"Warning: Some chunks failed to embed. Indexing {len(embeddings)} embeddings.")
        return embeddings

    async def exec(self, all_embeds: list):
        if not all_embeds:
             print("No embeddings to store.")
             return None
        print(f"Storing index for {len(all_embeds)} embeddings.")
        # Create a vector index (implementation depends on library)
        index = create_index(all_embeds)
        return index

    async def post(self, memory, prep_res, index):
        # Store the created index in global memory
        memory.index = index
        if index:
             print('Index created and stored.')
        else:
             print('Index creation skipped.')
        # End of offline flow

# --- Offline Flow Definition ---
trigger_chunking = TriggerChunkingNode()
chunk_file = ChunkFileNode()
trigger_embedding = TriggerEmbeddingNode()
embed_chunk = EmbedChunkNode()
store_index = StoreIndexNode()

# Define transitions using syntax sugar
trigger_chunking - 'chunk_file' >> chunk_file
trigger_chunking - 'embed_chunks' >> trigger_embedding
trigger_embedding - 'embed_chunk' >> embed_chunk
trigger_embedding - 'store_index' >> store_index

# Use ParallelFlow for potentially faster chunking and embedding
OfflineFlow = ParallelFlow(start=trigger_chunking)
# Or sequential: OfflineFlow = Flow(start=trigger_chunking)

Usage example:

# --- Offline Flow Execution ---
async def run_offline():
    # Create dummy files for example
    if not os.path.exists('doc1.txt'):
        with open('doc1.txt', 'w', encoding='utf-8') as f:
            f.write('Alice was beginning to get very tired.')
    if not os.path.exists('doc2.txt'):
        with open('doc2.txt', 'w', encoding='utf-8') as f:
            f.write('The quick brown fox jumps over the lazy dog.')

    initial_memory = {
        "files": ["doc1.txt", "doc2.txt"], # Example file paths
    }
    print('Starting offline indexing flow...')

    await OfflineFlow.run(initial_memory)

    print('Offline indexing complete.')
    # Clean up dummy files
    # os.remove('doc1.txt')
    # os.remove('doc2.txt')
    return initial_memory # Return memory containing index, chunks, embeds

# asyncio.run(run_offline()) # Example call

Stage 2: Online Query & Answer

We have three nodes:

  1. EmbedQuery – embeds the user's question.

  2. RetrieveDocs – retrieves the most relevant chunk from the index.

  3. GenerateAnswer – calls the LLM with the question and the retrieved chunk to produce the final answer.

# --- Stage 2: Online Query Nodes ---

# 2a. Embed Query Node
class EmbedQueryNode(Node):
    async def prep(self, memory):
        return memory.question # Read from memory

    async def exec(self, question):
        print(f"Embedding query: \"{question}\"")
        return await get_embedding(question)

    async def post(self, memory, prep_res, q_emb):
        memory.q_emb = q_emb # Write to memory
        self.trigger('retrieve_docs')

# 2b. Retrieve Docs Node
class RetrieveDocsNode(Node):
    async def prep(self, memory):
        # Need query embedding, index, and original chunks
        # Also retrieve metadata to know the source
        return memory.q_emb, memory.index, memory.all_chunks, memory.chunk_metadata

    async def exec(self, inputs):
        q_emb, index, chunks, metadata = inputs
        if q_emb is None or index is None or not chunks:
            raise ValueError("Missing data for retrieval in memory")
        print("Retrieving relevant chunk...")
        # Assuming search_index returns [[ids]], [[distances]]
        I, D = search_index(index, q_emb, top_k=1)
        if not I or not I[0]:
             return "Could not find relevant chunk.", None
        best_global_id = I[0][0]
        if best_global_id >= len(chunks):
             return "Index out of bounds.", None

        relevant_chunk = chunks[best_global_id]
        relevant_metadata = metadata[best_global_id] if metadata and best_global_id < len(metadata) else {}
        return relevant_chunk, relevant_metadata

    async def post(self, memory, prep_res, exec_res):
        relevant_chunk, relevant_metadata = exec_res
        memory.retrieved_chunk = relevant_chunk # Write to memory
        memory.retrieved_metadata = relevant_metadata # Write metadata too
        print(f"Retrieved chunk: {relevant_chunk[:60]}... (Source: {relevant_metadata.get('source', 'N/A')})")
        self.trigger('generate_answer')

# 2c. Generate Answer Node
class GenerateAnswerNode(Node):
    async def prep(self, memory):
        return memory.question, memory.retrieved_chunk # Read from memory

    async def exec(self, inputs):
        question, chunk = inputs
        if not chunk or chunk == "Could not find relevant chunk.":
             return "Sorry, I couldn't find relevant information to answer the question."
        prompt = f"Using the following context, answer the question.\nContext: {chunk}\nQuestion: {question}\nAnswer:"
        print("Generating final answer...")
        return await call_llm(prompt)

    async def post(self, memory, prep_res, answer):
        memory.answer = answer # Write to memory
        print("Answer:", answer)
        # End of online flow

# --- Online Flow Definition ---
embed_qnode = EmbedQueryNode()
retrieve_node = RetrieveDocsNode()
generate_node = GenerateAnswerNode()

# Define transitions using syntax sugar
embed_qnode - 'retrieve_docs' >> retrieve_node
retrieve_node - 'generate_answer' >> generate_node

OnlineFlow = Flow(start=embed_qnode)

Usage example:

# --- Online Flow Execution ---
async def run_online(memory_from_offline: dict):
    # Add the user's question to the memory from the offline stage
    memory_from_offline["question"] = "Why do people like cats?"

    print(f"\nStarting online RAG flow for question: \"{memory_from_offline['question']}\"")
    await OnlineFlow.run(memory_from_offline) # Pass memory object
    # final answer in memory_from_offline["answer"]
    print("Final Answer:", memory_from_offline.get("answer", "N/A")) # Read from memory
    return memory_from_offline

# Example usage combining both stages
async def main():
    # Mock external functions if not defined
    # global get_embedding, create_index, search_index, call_llm
    # get_embedding = ...
    # create_index = ...
    # search_index = ...
    # call_llm = ...

    memory_after_offline = await run_offline()
    if memory_after_offline.get("index"): # Only run online if index exists
        await run_online(memory_after_offline)
    else:
        print("Skipping online flow due to missing index.")

if __name__ == "__main__":
    # Note: Ensure dummy files exist or are created before running
    # For simplicity, file creation moved to run_offline
    asyncio.run(main())
