Improving the accuracy of LLM responses with GraphRAG

By Bilal Shareef

We explore GraphRAG by implementing it in a JavaScript application

Generative AI (GenAI) is emerging as a powerful technology across the tech industry. It generates text responses to user prompts and questions, and it is based on large language models (LLMs) trained on massive datasets.

Training LLMs is an expensive process, which makes regularly retraining them with up-to-date data impractical. Consequently, an LLM's knowledge is often outdated. On top of that, LLMs have an occasional tendency to provide incorrect responses.

Retrieval-Augmented Generation (RAG) addresses this problem by giving LLMs access to additional data sources without retraining, resulting in better answers. An even more effective approach is GraphRAG, an extension of RAG that can produce still more accurate answers.

In this blog post, we will explore GraphRAG by implementing it in a JavaScript application. We drew inspiration from a LangChain blog post that explores GraphRAG with a Python application. LangChain is a framework for building applications based on LLMs; it provides tools and abstractions to improve the customisation, accuracy, and relevancy of the information generated by the models. We will be using the LangChain JS framework in the application demonstrated in this blog post.

What is RAG?

Before delving into GraphRAG, let's first understand what RAG is. RAG is an AI technique that can enhance the quality of GenAI by enabling LLMs to access additional data without the need for retraining.

It offers a method for optimising the output of an LLM with targeted information without modifying the underlying model itself. This means that the RAG approach in GenAI can provide more contextually appropriate answers based on current data.

Typically, RAG implementations require vector databases that enable rapid storage and retrieval of data fed into the LLM.

What is GraphRAG?

GraphRAG is an extension of RAG where a graph database is used alongside a vector database to store and retrieve data. Vector databases excel in handling unstructured data, but they cannot efficiently store interconnected and structured information. That's where a graph database comes into play.

The concept behind GraphRAG is to harness the structured nature of data to enrich the depth and contextuality of retrieved information. Graph databases organise data into nodes and relationships, aiding LLMs in better understanding the data and providing more accurate answers.

As GraphRAG extends RAG, we can integrate structured data obtained from the graph database with unstructured data obtained from the vector database and then input it into the LLM. The model can thus utilise both structured and unstructured data to generate improved responses. In our demonstration, we will be using a Neo4j database to store both structured and unstructured data.
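To make this more concrete, the same movie script line could end up stored twice: once as a plain text chunk (unstructured) and once as nodes and relationships (structured). The patterns below are purely illustrative; the actual node labels and relationship types are whatever the LLM extracts.

(:Person {id: 'Tony Stark'})-[:OWNS]->(:Building {id: 'Stark Tower'})
(:Person {id: 'Thor'})-[:BROTHER_OF]->(:Person {id: 'Loki'})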

Now that we understand what RAG and GraphRAG are, let's build an application that demonstrates this.

The application consists of two parts: the loader and the retriever.

  1. The loader fetches the source content from a knowledge base and stores the data in the Neo4j database.

  2. The retriever fetches data from the Neo4j database and feeds it into the LLM to obtain responses. The retriever runs on demand whenever a question or prompt is sent.

The code for the application explained in this blog post is available on GitHub.

Environment setup

In our application, alongside LangChain and Neo4j, we will be using OpenAI. So let’s set these up first.

There are two options for setting up a Neo4j database. The easiest option is to use Neo4j Aura, which offers a free instance of the Neo4j database. The other option is to set up a local instance of the Neo4j database by downloading the Neo4j Desktop application and creating a local database instance.

Once the Neo4j database instance is ready, we will have the URI, username, and password of the database instance. We also need an OpenAI API key to use their LLM models. Set this data in environment variables so that the application can read it from there.

const url = process.env.NEO4J_URI
const username = process.env.NEO4J_USERNAME
const password = process.env.NEO4J_PASSWORD
const openAIApiKey = process.env.OPENAI_API_KEY
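The code snippets in this post omit import statements for brevity. Assuming a recent LangChain JS setup (the langchain, @langchain/openai, @langchain/community, and @langchain/core packages, plus neo4j-driver and zod), the modules used throughout can be imported roughly as follows; exact paths may differ between LangChain versions:

// Import paths assume a recent LangChain JS release and may differ in older versions
import { ChatOpenAI, OpenAIEmbeddings } from '@langchain/openai'
import { Neo4jGraph } from '@langchain/community/graphs/neo4j_graph'
import { Neo4jVectorStore } from '@langchain/community/vectorstores/neo4j_vector'
import { IMSDBLoader } from '@langchain/community/document_loaders/web/imsdb'
import { LLMGraphTransformer } from '@langchain/community/experimental/graph_transformers/llm'
import { TokenTextSplitter } from 'langchain/text_splitter'
import { Document } from '@langchain/core/documents'
import { ChatPromptTemplate, PromptTemplate } from '@langchain/core/prompts'
import { StringOutputParser } from '@langchain/core/output_parsers'
import { RunnableSequence, RunnablePassthrough } from '@langchain/core/runnables'
import { createStructuredOutputRunnable } from 'langchain/chains/openai_functions'
import { z } from 'zod'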

Loader

Let's dive into the coding part of the loader.

We will first initialise the LLM by creating an instance of ChatOpenAI, passing in the temperature, the model name, and the OpenAI API key.

For our scenario, where we want the LLM to respond based on the stored knowledge base, there isn’t a need for the LLM to be creative. So let’s set the temperature to 0 to ensure we get deterministic outputs.

OpenAI offers a range of LLMs. The gpt-4-turbo model is the most recent at the time of writing this article; it is trained on data up to December 2023 and is more capable and reliable. However, we will be using gpt-3.5-turbo, as it is cheaper and sufficient for our demonstration.

Also, let’s initialise the Neo4jGraph database.

const llm = new ChatOpenAI({
  temperature: 0,
  modelName: 'gpt-3.5-turbo',
  openAIApiKey
})

const graph = await Neo4jGraph.initialize({ url, username, password })

Then, load the source text content from a knowledge base. In our application, we will use the IMSDBLoader provided by LangChain to retrieve text content. LangChain offers a variety of text loaders that allow us to easily fetch text content from various web resources.

The IMSDBLoader is one such loader provided by LangChain, enabling us to fetch content from IMSDb, a website where movie scripts can be found.

We will retrieve data from just one page (the script of the movie The Avengers) on IMSDb. The load method of the loader returns an array of documents (one document per page). Since we are loading only one page in our example below, the length of the returned array will be 1.

const loader = new IMSDBLoader(
  'https://imsdb.com/scripts/Avengers,-The-(2012).html'
)
const rawDocs = await loader.load()

Now that we have the raw text content loaded, let’s split it into multiple smaller chunks. LangChain provides several text splitters that enable us to divide long documents into smaller segments. We will utilise the TokenTextSplitter utility class for chunking, specifying only the chunk size and the chunk overlap, both measured in tokens.

Next, we will iterate over the returned list of chunks and create a new list with Document objects for each chunk.

const textSplitter = new TokenTextSplitter({
  chunkSize: 512,
  chunkOverlap: 24
})

let documents = []
for (let i = 0; i < rawDocs.length; i++) {
  // Split each raw document into smaller chunks
  const chunks = await textSplitter.splitText(rawDocs[i].pageContent)
  // Wrap each chunk in a Document, keeping the original metadata
  // and recording the chunk's 1-based position
  const processedDocs = chunks.map(
    (chunk, index) =>
      new Document({
        pageContent: chunk,
        metadata: {
          a: index + 1,
          ...rawDocs[i].metadata
        }
      })
  )
  documents.push(...processedDocs)
}

The next step is to convert the list of documents into a list of graph documents. LangChain offers an LLMGraphTransformer class, which performs this conversion for us. Graph Documents comprise a blend of structured and unstructured data. Structured data is represented as nodes and relationships, while unstructured data is the actual source text content.

LLMGraphTransformer utilises the LLM to determine the nodes and relationships for a given text document.

const llmTransformer = new LLMGraphTransformer({ llm })
const graphDocuments = await llmTransformer.convertToGraphDocuments(documents)
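If you want to see what the transformer produced before saving it, each graph document exposes the extracted nodes and relationships alongside the source chunk. A quick, optional inspection (the exact entities will vary from run to run):

const firstGraphDoc = graphDocuments[0]
console.log(firstGraphDoc.nodes) // e.g. entries such as { id: 'Tony Stark', type: 'Person' }
console.log(firstGraphDoc.relationships) // e.g. { source: ..., target: ..., type: 'OWNS' }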

Save the graph documents to the Neo4j database using the addGraphDocuments method. Setting baseEntityLabel to true adds a generic __Entity__ label to every extracted node, and includeSource stores each source chunk as a Document node linked to its entities via MENTIONS relationships; we will rely on both of these later in the retriever.

await graph.addGraphDocuments(graphDocuments, {
  baseEntityLabel: true,
  includeSource: true
})

The loader is fairly simple, as you can see above. We loaded the source content with the IMSDBLoader and split it into smaller chunks. Next, we converted the text chunks into graph documents with the help of LLMGraphTransformer and stored them in the Neo4j database.

Now that the loader code is ready, give it a try by running it. You should see the graph documents for the passed source content saved successfully in the Neo4j database.
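As a quick sanity check (an optional step, not required by the rest of the application), you can count the stored entity nodes through the same graph handle; the __Entity__ label is present because we set baseEntityLabel to true:

const entityCount = await graph.query(
  'MATCH (n:__Entity__) RETURN count(n) AS count'
)
console.log(entityCount) // e.g. [ { count: 123 } ]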

Retriever

Let's dive into the coding part of the retriever. We will start by initialising the LLM and the Neo4jGraph database, just as we did in the loader.

Next, we need to initialise the Neo4j vector store, which we will use to fetch the unstructured data stored by the loader. We will accomplish this by invoking the fromExistingGraph static method of Neo4jVectorStore, which creates a vector index over the existing Document nodes (computing embeddings for any that don't have them yet) and returns a vector store instance. We will use this instance later in the application to fetch the unstructured text data.

const neo4jVectorIndex = await Neo4jVectorStore.fromExistingGraph(
  new OpenAIEmbeddings({ openAIApiKey }),
  {
    url,
    username,
    password,
    searchType: 'hybrid',
    nodeLabel: 'Document',
    textNodeProperties: ['text'],
    embeddingNodeProperty: 'embedding'
  }
)

The next step is to define a chain that enables us to extract entities from a given string or question.

In LangChain, chains represent sequences of calls — whether to an LLM, a tool, or a data preprocessing step. In our application, we will use chains to deliver prompt templates and inputs to the LLM to obtain generated data.

An entity is essentially an object or thing that exists in the real world. For example, in the question "Who is Tony Stark?", "Tony Stark" is an entity. There could be more than one entity in a given string as well. For example, in the question "How are Thor and Loki related?", "Thor" and "Loki" are entities.

In the following code, we are defining a prompt template that instructs the LLM to extract entities from the passed question. Additionally, we are defining the output schema, which specifies the structure of the output we require — a list of entities.

The entity chain is created by passing the prompt, output schema, and the LLM to the createStructuredOutputRunnable function. When the entity chain is invoked, the prompt and the output schema are fed to the LLM, which extracts and returns the entities.

const entitiesSchema = z
  .object({
    names: z
      .array(z.string())
      .describe(
        'All the person, organization, or business entities that appear in the text'
      )
  })
  .describe('Identifying information about entities.')

const prompt = ChatPromptTemplate.fromMessages([
  [
    'system',
    'You are extracting organization and person entities from the text.'
  ],
  [
    'human',
    'Use the given format to extract information from the following input: {question}'
  ]
])

const entityChain = createStructuredOutputRunnable({
  outputSchema: entitiesSchema,
  prompt,
  llm
})
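For example, invoking the entity chain with a question that mentions two characters should return an object matching the schema defined above, roughly like this:

const extracted = await entityChain.invoke({
  question: 'How are Thor and Loki related?'
})
console.log(extracted.names) // expected to be along the lines of [ 'Thor', 'Loki' ]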

Next, let’s execute a Neo4j query to create a full-text index, if it doesn’t already exist. The full-text index lets us search entity nodes by keyword, which is how we will map the entities extracted from a question to nodes in the graph later in the application.

This query will create the index when you run the retriever for the first time. It won’t perform any actions in subsequent runs.

await graph.query(
  'CREATE FULLTEXT INDEX entity IF NOT EXISTS FOR (e:__Entity__) ON EACH [e.id]'
)

Then, let’s define a structured retriever function that invokes the entity chain to extract entities from the question. For each entity, it executes a query to fetch matching nodes and relationships from the Neo4j database.

async function structuredRetriever(question) {
  let result = ''
  const entities = await entityChain.invoke({ question })

  for (const entity of entities.names) {
    // Find the top matching entity nodes via the full-text index, then collect
    // their incoming and outgoing relationships (ignoring MENTIONS links to source chunks)
    const response = await graph.query(
      `CALL db.index.fulltext.queryNodes('entity', $query, {limit:2})
      YIELD node,score
      CALL {
        MATCH (node)-[r:!MENTIONS]->(neighbor)
        RETURN node.id + ' - ' + type(r) + ' -> ' + neighbor.id AS output
        UNION
        MATCH (node)<-[r:!MENTIONS]-(neighbor)
        RETURN neighbor.id + ' - ' + type(r) + ' -> ' + node.id AS output
      }
      RETURN output LIMIT 50`,
      { query: generateFullTextQuery(entity) }
    )

    result += response.map(el => el.output).join('\n') + '\n'
  }

  return result
}
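Note that structuredRetriever relies on a generateFullTextQuery helper that is not shown above. A minimal sketch of it, which strips characters that have special meaning in Lucene query syntax and requires a fuzzy match (~2, allowing up to two edits) on every word of the entity name, could look like this:

function generateFullTextQuery(input) {
  // Remove characters that are special in Lucene syntax, then join the words
  // with AND so all of them must (fuzzily) match an indexed entity id
  const words = input.replace(/[^a-zA-Z0-9\s]/g, '').split(/\s+/).filter(Boolean)
  return words.map(word => `${word}~2`).join(' AND ')
}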

Now we can define the final retriever function. It takes the question as input and then retrieves the structured and unstructured data for the passed question. The structured data is fetched from the structuredRetriever we have defined above, and the unstructured data is retrieved from the vector index.

async function retriever(question) {
  console.log('Standalone Question - ' + question)
  // Structured data: graph relationships for the entities found in the question
  const structuredData = await structuredRetriever(question)

  // Unstructured data: similar text chunks retrieved from the vector index
  const similaritySearchResults =
    await neo4jVectorIndex.similaritySearch(question)
  const unstructuredData = similaritySearchResults.map(el => el.pageContent)

  const finalData = `Structured data:
  ${structuredData}
  Unstructured data:
  ${unstructuredData.map(content => `#Document ${content}`).join('\n')}
      `
  return finalData
}

The code we have written so far covers data retrieval from the database for a given question. Now let’s implement a few chains and link them together to make the entire process work.

In our application, we aim to maintain the context of the conversation to ensure continuity between questions. For instance, a user might ask, "Who is Tony Stark?" and then follow up with, "Does he own the Stark Tower?" In this scenario, the LLM needs to comprehend who "he" refers to in the context of the conversation to generate the correct response.

We will define a standalone chain that can understand the context based on the conversation history and then return a new question with the ambiguous parts resolved. In the example provided, when "Does he own the Stark Tower?" is asked, the chain will determine the context with the assistance of the conversation history and return a standalone question: "Does Tony Stark own the Stark Tower?".

const standaloneTemplate = `Given the following conversation and a follow up question, rephrase the follow up question to be a standalone question.
Chat History:
{conversationHistory}
Follow Up Input: {question}
Standalone question:`

const standalonePrompt = PromptTemplate.fromTemplate(standaloneTemplate)

const standaloneQuestionChain = standalonePrompt
  .pipe(llm)
  .pipe(new StringOutputParser())

Then, we will define a retriever chain that retrieves context (structured and unstructured data) from the database for the generated standalone question by invoking the retriever function we implemented above.

const retrieverChain = RunnableSequence.from([
  prevResult => prevResult.standaloneQuestion,
  retriever
])

Now, let’s define a chain that can generate the final response based on the context and the original question.

const answerTemplate = `You are a helpful and enthusiastic support bot who can answer any question based on the context provided and the conversation history. Try to find the answer in the context. If the answer is not given in the context, find the answer in the conversation history if possible. If you really don't know the answer, say "I am sorry, I don't know the answer to that." and don't try to make up the answer. Always speak as if you are chatting to a friend.

context:{context}
question: {question}
answer:`

const answerPrompt = PromptTemplate.fromTemplate(answerTemplate)

const answerChain = answerPrompt.pipe(llm).pipe(new StringOutputParser())

Finally, let’s connect the above three chains into a single sequential chain. The idea is to first invoke the standalone chain to obtain the standalone question, then invoke the retriever chain to acquire the context, and finally feed the question and context to the answer chain to generate the final response.

const chain = RunnableSequence.from([
  {
    standaloneQuestion: standaloneQuestionChain,
    originalInput: new RunnablePassthrough()
  },
  {
    context: retrieverChain,
    question: ({ originalInput }) => originalInput.question,
    conversationHistory: ({ originalInput }) => originalInput.conversationHistory
  },
  answerChain
])

Now let's ask our questions one after the other by invoking the final chain. Additionally, we can capture each of the responses into an array to maintain the conversation history.

const conversationHistory = []

function logResult(result) {
  console.log(`Search Result - ${result}\n`)
}

async function ask(question) {
  console.log(`Search Query - ${question}`)
  const answer = await chain.invoke({
    question,
    conversationHistory: formatChatHistory(conversationHistory)
  })
  conversationHistory.push(question)
  conversationHistory.push(answer)
  logResult(answer)
}

await ask('Loki is a native of which planet?')
await ask('What is the name of his brother?')
await ask('Who is the villain among the two?')
await ask('Who is Tony Stark?')
await ask('Does he own the Stark Tower?')
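One piece not shown above is the formatChatHistory helper used inside ask. A minimal sketch, assuming conversationHistory holds alternating question and answer strings as pushed above, might look like this:

function formatChatHistory(history) {
  // Even indexes are the user's questions, odd indexes are the bot's answers
  return history
    .map((entry, index) => (index % 2 === 0 ? `Human: ${entry}` : `AI: ${entry}`))
    .join('\n')
}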

Below are the answers returned by the LLM for each of the questions asked.

Search Query - Loki is a native of which planet?
Standalone Question - What planet is Loki a native of?
Search Result - Loki is a native of Asgard.

Search Query - What is the name of his brother?
Standalone Question - What is the name of Loki's brother?
Search Result - His brother's name is Thor.

Search Query - Who is the villain among the two?
Standalone Question - Which of the two, Loki and Thor, is considered the villain?
Search Result - The villain is Loki.

Search Query - Who is Tony Stark?
Standalone Question - Who is Tony Stark?
Search Result - Tony Stark is a character who is also known as Iron Man. He is a genius engineer, industrialist, and philanthropist who uses his suit of armor to fight crime and protect the world.

Search Query - Does he own the Stark Tower?
Standalone Question - Does Tony Stark own the Stark Tower?
Search Result - Yes, Tony Stark owns the Stark Tower.

As you can see, our application was able to generate standalone questions based on the context of the conversation and also returned the correct responses for each of them.

Conclusion

To summarise what we have covered in this blog post: we implemented a loader module that fetches raw text data from a web source, converts it into graph documents, and stores them in the graph database. We then implemented a retriever module that takes a user's question, extracts entities from it, retrieves both structured and unstructured data related to the entities and the question, and feeds this data to the LLM to generate an answer.

GraphRAG greatly benefits businesses by improving customer service through precise and relevant interactions. For industries relying on extensive knowledge bases, the ability to generate accurate information can streamline processes, reduce errors, and enhance decision-making. Additionally, the adaptive nature of GraphRAG allows businesses to tailor responses based on specific contextual needs, providing a more personalised and efficient user experience. Ultimately, the integration of GraphRAG can give businesses a competitive advantage in an ever-evolving digital landscape.

References

Here are some interesting blog posts that will help expand your understanding of GraphRAG: