Improving my NBA stats SQL Agent Part 2

littlereddotdata
6 min read · Aug 16, 2024


Illustration: https://icons8.com/illustrations/author/627444

Introduction

See Part 1 here

Code for Part 1 is here and code for Part 2 is here

Previously, we used a simple SQLDatabase agent in Langchain so that we could ask questions in natural language against our NBA stats database instead of having to write SQL. We also set up evaluation using MLflow Traces to give a baseline measurement of this agent’s performance.
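
As a quick refresher, the Part 1 baseline looked roughly like this (a minimal sketch: the database URI is a placeholder, and the exact setup lives in the Part 1 code linked above):

from langchain_community.agent_toolkits import SQLDatabaseToolkit, create_sql_agent
from langchain_community.chat_models import ChatDatabricks
from langchain_community.utilities import SQLDatabase

# Part 1 baseline: a generic SQLDatabase agent over the NBA stats database
db = SQLDatabase.from_uri("sqlite:///nba_games.db")  # placeholder URI
llm = ChatDatabricks(endpoint="databricks-meta-llama-3-1-70b-instruct")
agent = create_sql_agent(llm=llm, toolkit=SQLDatabaseToolkit(db=db, llm=llm), verbose=True)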

Our work is not done, though! Building a text-to-SQL agent means paying attention to many components, especially if we want a quality agent. So in Part 2, we will go through some changes we made to our original implementation.

It’s also worth noting that making a change does not necessarily mean performance will improve. This is why we set up an evaluation process FIRST, so that we can measure whether a change actually leads to an improvement.

1. Use SparkSQL agent and toolkit instead of SQLDatabase agent and toolkit

We originally used the SQLDatabaseToolkit with a SQLDatabase agent as a starting point. Looking more closely, though, we see that the SQLDatabase agent is built to cover a range of SQL dialects, whereas we are focused only on executing SparkSQL against our database.

So we can try using the SparkSQL agent and toolkit instead to see whether focusing on SparkSQL will increase the accuracy and relevancy of the SQL our agent generates.

from langchain_community.agent_toolkits import SparkSQLToolkit, create_spark_sql_agent
from langchain_community.chat_models import ChatDatabricks
from langchain_community.utilities.spark_sql import SparkSQL

spark_sql = SparkSQL(catalog="main", schema="nba_sql_agent")
llm = ChatDatabricks(endpoint="databricks-meta-llama-3-1-70b-instruct")
toolkit = SparkSQLToolkit(db=spark_sql, llm=llm)
agent = create_spark_sql_agent(llm=llm, toolkit=toolkit, verbose=True, agent_executor_kwargs={"handle_parsing_errors": True})
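
We can then query the new agent the same way as before (the question below is just an illustration):

agent.invoke({"input": "How many home games did the Boston Celtics win in 2023?"})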

2. Include table metadata within the prompt

When we look at the implementations of the SparkSQL agent and the SQLDatabase agent, we see that both agents have tools to list the available database tables and their metadata.

For example, with SQLDatabaseToolkit().get_tools() we have:

[QuerySQLDataBaseTool(),
 InfoSQLDatabaseTool(),
 ListSQLDatabaseTool(),
 QuerySQLCheckerTool()]

From looking at our MLflow Traces, we see that our agent can take extra time to iterate on a SQL query error, because every time it handles an error it needs to: 1. decide to use the InfoSQLDatabaseTool() and the ListSQLDatabaseTool(), and 2. use the tools and evaluate their output.

Since we only have one table in our database, we can include the table description, column description and few-shot examples directly in our prompt. By doing this, we can reduce the number of tool calls our Agent has to make and reduce our latency.

import json

table_desc = json.dumps({
    "description": "The 'nba_games' table contains data about NBA games, including details about the teams and game statistics. It includes information such as the game date, matchup, and win-loss record. This data can be used for various purposes, including game analysis. It can also be used to identify trends and patterns in player performance and team performance over time."
})

column_comments = json.dumps({
    "season_id": "Unique identifier for the NBA season.",
    "team_id_home": "Identifier for the home team in the game.",
    "team_abbreviation_home": "Abbreviated representation of the home team's name.",
    "team_name_home": "Full name of the home team in the game.",
    "game_id": "Unique identifier for the game.",
    "game_date": "Date when the game was played.",
    "matchup_home": "Opposing team's abbreviation for the home team.",
})

few_shot_examples = json.dumps([
    {
        "Question": "How many away games did the Boston Celtics win in 2023?",
        "Answer": """SELECT count(season_id) AS total_wins, year(game_date) AS year
FROM nba_games
WHERE wl_away = "W" AND team_name_away = "Boston Celtics" AND year(game_date) = 2023
GROUP BY year(game_date)""",
    }
])

3. Format prompt using ChatPromptTemplate

Now that our prompt is more complex, we can use the Langchain ChatPromptTemplate to organize the system prompt, user prompt and chat history separately.

from langchain_core.prompts.chat import SystemMessagePromptTemplate, HumanMessagePromptTemplate
from langchain.prompts import ChatPromptTemplate, PromptTemplate, MessagesPlaceholder

system_message_template = """You are an expert in the American NBA. Your task is to answer user questions about games, teams and players.
Think step-by-step about what SQL query you need in order to answer the user's question. Include the SQL query you finally used in your final output.
This is a description of the table: {table_desc}
This is an explanation of the column comments: {column_comments}
Here are some examples of the user question and the corresponding SQL query:
{few_shot_examples}"""

system_message = SystemMessagePromptTemplate(
    prompt=PromptTemplate(
        input_variables=["table_desc", "column_comments", "few_shot_examples"],
        template=system_message_template,
    )
)

chat_prompt_str = ChatPromptTemplate.from_messages([
    system_message,
    MessagesPlaceholder(variable_name="chat_history", optional=True),
    HumanMessagePromptTemplate(prompt=PromptTemplate(input_variables=["input"], template="{input}")),
])
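
Because the table metadata is static, one convenient option (not shown in the original snippet) is to bind it to the prompt up front with partial(), so that only the user input, and optionally the chat history, needs to be supplied at query time:

chat_prompt = chat_prompt_str.partial(
    table_desc=table_desc,
    column_comments=column_comments,
    few_shot_examples=few_shot_examples,
)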

4. Helper functions to parse user questions for the LLM Agent

Users may not ask questions in a way that directly translates into a valid SQL query. For example, sports fans might refer to Michael Jordan as “Air Jordan” or just “Jordan”. To help our Agent, we can translate these terms into the entities that are actually stored in our database (in this case the player’s full name). While it can seem feasible to use an LLM to parse colloquial player names into their full form, it’s much faster and easier to start with a rule-based solution.

# Parse player names if shortforms or nicknames are used
def parse_player_names(query):
    known_nicknames = ["iceman", "air jordan"]
    nicknames_in_query = [n for n in known_nicknames if n in query]
    player_nickname_map = {"iceman": "George Gervin", "air jordan": "Michael Jordan"}

    for n in nicknames_in_query:
        try:
            real_name = player_nickname_map[n]
            query = query.replace(n, real_name)
        except KeyError:
            pass
    return query
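
For example:

parse_player_names("how many points did air jordan score in 1997")
# -> 'how many points did Michael Jordan score in 1997'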

Then, within our LLM chain, we can call this function on the user’s input as part of the Langchain Expression Language (LCEL).

from operator import itemgetter
from langchain_core.runnables import RunnableLambda

chain = (
    {"input": itemgetter("messages") | RunnableLambda(parse_player_names)}
    | agent  # the AgentExecutor created earlier consumes the cleaned "input"
)
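
Invoking the chain then looks something like this (the input is illustrative):

chain.invoke({"messages": "how many championships did air jordan win"})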

5. Customise the SparkSQLToolkit

As we saw before, the toolkit that we pass to our Agent is really a class that encapsulates a list of tools. Instead of keeping the tools that list the tables and fetch table information from our database, which takes time, we have already included this information directly in our prompt, so we can leave the original ListSparkSQLTool and InfoSparkSQLTool out of the toolkit.

from typing import List

from langchain_community.agent_toolkits.base import BaseToolkit
from langchain_community.tools.spark_sql.tool import QuerySparkSQLTool
from langchain_community.utilities.spark_sql import SparkSQL
from langchain_core.language_models import BaseLanguageModel
from langchain_core.pydantic_v1 import Field
from langchain_core.tools import BaseTool

class SparkSQLToolkit(BaseToolkit):
    """Toolkit for interacting with Spark SQL.

    Parameters:
        db: SparkSQL. The Spark SQL database.
        llm: BaseLanguageModel. The language model.
    """

    db: SparkSQL = Field(exclude=True)
    llm: BaseLanguageModel = Field(exclude=True)

    class Config:
        arbitrary_types_allowed = True

    def get_tools(self) -> List[BaseTool]:
        """Get the tools in the toolkit."""
        return [
            QuerySparkSQLTool(db=self.db),
            NewQueryCheckerTool(db=self.db, llm=self.llm),  # defined in section 6 below
        ]
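
We can then construct the agent with this slimmed-down toolkit in place of the stock one:

custom_toolkit = SparkSQLToolkit(db=spark_sql, llm=llm)
agent = create_spark_sql_agent(
    llm=llm,
    toolkit=custom_toolkit,
    verbose=True,
    agent_executor_kwargs={"handle_parsing_errors": True},
)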

6. Customise tool prompts

Some tools used by our agent have prompts associated with them. For example, the QueryCheckerTool contains a prompt that asks the LLM to check the validity of a SQL query before it is executed. If we see persistent errors caused by certain kinds of syntax mistakes, we can update this prompt to be more relevant, and then include the new prompt in our replacement query-checker tool in the SparkSQLToolkit.

# original prompt
"""
Double check the Spark SQL query above for common mistakes, including:
- Using NOT IN with NULL values
- Using UNION when UNION ALL should have been used
- Using BETWEEN for exclusive ranges
- Data type mismatch in predicates
- Properly quoting identifiers
- Using the correct number of arguments for functions
- Casting to the correct data type
- Using the proper columns for joins

If there are any of the above mistakes, rewrite the query. If there are no mistakes, just reproduce the original query.
"""
NEW_QUERY_CHECKER = """
{query}
Double check the Spark SQL query above for common mistakes, including:
- Using NOT IN with NULL values
- Using UNION when UNION ALL should have been used
- Using BETWEEN for exclusive ranges
- Data type mismatch in predicates
- Properly quoting identifiers
- Using the correct number of arguments for functions
- Casting to the correct data type
- Using the proper columns for joins
- Make sure your string matching formatting is correct

If there are any of the above mistakes, rewrite the query. If there are no mistakes, just reproduce the original query."""
from typing import Any, Dict, Optional

from langchain_community.tools.spark_sql.tool import BaseSparkSQLTool
from langchain_core.callbacks import (
    AsyncCallbackManagerForToolRun,
    CallbackManagerForToolRun,
)
from langchain_core.language_models import BaseLanguageModel
from langchain_core.prompts import PromptTemplate
from langchain_core.pydantic_v1 import Field, root_validator
from langchain_core.tools import BaseTool

class NewQueryCheckerTool(BaseSparkSQLTool, BaseTool):
    """Use an LLM to check if a query is correct.
    Adapted from https://www.patterns.app/blog/2023/01/18/crunchbot-sql-analyst-gpt/"""

    template: str = NEW_QUERY_CHECKER
    llm: BaseLanguageModel
    llm_chain: Any = Field(init=False)
    name: str = "query_checker_sql_db"
    description: str = """
    Use this tool to double check if your query is correct before executing it.
    Always use this tool before executing a query with query_sql_db!
    """

    @root_validator(pre=True)
    def initialize_llm_chain(cls, values: Dict[str, Any]) -> Dict[str, Any]:
        if "llm_chain" not in values:
            from langchain.chains.llm import LLMChain

            values["llm_chain"] = LLMChain(
                llm=values.get("llm"),  # type: ignore[arg-type]
                prompt=PromptTemplate(
                    template=NEW_QUERY_CHECKER, input_variables=["query"]
                ),
            )
        if values["llm_chain"].prompt.input_variables != ["query"]:
            raise ValueError(
                "LLM chain for QueryCheckerTool needs to use ['query'] as "
                "input_variables for the embedded prompt"
            )
        return values

    def _run(
        self,
        query: str,
        run_manager: Optional[CallbackManagerForToolRun] = None,
    ) -> str:
        """Use the LLM to check the query."""
        return self.llm_chain.predict(
            query=query, callbacks=run_manager.get_child() if run_manager else None
        )

    async def _arun(
        self,
        query: str,
        run_manager: Optional[AsyncCallbackManagerForToolRun] = None,
    ) -> str:
        return await self.llm_chain.apredict(
            query=query, callbacks=run_manager.get_child() if run_manager else None
        )

Conclusion

These are some ideas for customising and improving the results of a text-to-SQL agent. Tools and agents from libraries may benefit from customisation, and it’s helpful to see how tools and agents really work so we can tune them as needed.
