Artigo
· Ago. 21 29min de leitura

Texto para IRIS SQL com LangChain

Um experimento sobre como usar a framework LangChain, Busca por Vetor IRIS e LLMs para gerar SQL compatível com IRIS de prompts de usuários.

Esse artigo foi baseado neste notebook. Você pode rodar com um ambiente pronto para uso com esta aplicação no OpenExchange.

Setup

Primeiro, precisamos instalar as livrarias necessárias:

!pip install --upgrade --quiet langchain langchain-openai langchain-iris pandas

Em seguida, importamos os módulos requeridos e definimos o ambiente:

import os
import datetime
import hashlib
from copy import deepcopy
from sqlalchemy import create_engine
import getpass
import pandas as pd
from langchain_core.prompts import PromptTemplate, ChatPromptTemplate
from langchain_core.example_selectors import SemanticSimilarityExampleSelector
from langchain_openai import OpenAIEmbeddings, ChatOpenAI
from langchain.docstore.document import Document
from langchain_community.document_loaders import DataFrameLoader
from langchain.text_splitter import CharacterTextSplitter
from langchain_core.output_parsers import StrOutputParser
from langchain.globals import set_llm_cache
from langchain.cache import SQLiteCache
from langchain_iris import IRISVector

Vamos usar SQLiteCache para manter as chamadas de LLM em cache:

# Cache for LLM calls
set_llm_cache(SQLiteCache(database_path=".langchain.db"))

Defina os parâmetros da conexão da base de dados IRIS:

# IRIS database connection parameters
os.environ["ISC_LOCAL_SQL_HOSTNAME"] = "localhost"
os.environ["ISC_LOCAL_SQL_PORT"] = "1972"
os.environ["ISC_LOCAL_SQL_NAMESPACE"] = "IRISAPP"
os.environ["ISC_LOCAL_SQL_USER"] = "_system"
os.environ["ISC_LOCAL_SQL_PWD"] = "SYS"

Se a chave da OpenAI API não está definida, peça ao usuário para fornecê-la:

if not "OPENAI_API_KEY" in os.environ:
    os.environ["OPENAI_API_KEY"] = getpass.getpass()

Crie a string de conxão para a base de dados IRIS:

# IRIS database connection string
args = {
    'hostname': os.getenv("ISC_LOCAL_SQL_HOSTNAME"), 
    'port': os.getenv("ISC_LOCAL_SQL_PORT"), 
    'namespace': os.getenv("ISC_LOCAL_SQL_NAMESPACE"), 
    'username': os.getenv("ISC_LOCAL_SQL_USER"), 
    'password': os.getenv("ISC_LOCAL_SQL_PWD")
}
iris_conn_str = f"iris://{args['username']}:{args['password']}@{args['hostname']}:{args['port']}/{args['namespace']}"

Estabeleça a conexão para a base de dados IRIS:

# Connection to IRIS database
engine = create_engine(iris_conn_str)
cnx = engine.connect().connection

Prepare um dicionário para guardar informações de contexto para o prompt do sistema:

# Dict for context information for system prompt
context = {}
context["top_k"] = 3

Criação de Prompt

Para transformar inputs de usuário em consultas SQL compatíveis com a base de dados IRIS, precisamos criar um prompt efetivo para o modelo de linguagem. Começamos com um prompt inicial que fornece instruções básicas para gerar queries SQL. O template é derivado de prompts default do LangChain's para MSSQL e customizado para a base de dados IRIS.

# Basic prompt template with IRIS database SQL instructions
iris_sql_template = """
Você é um expert em InterSystems IRIS. Dada uma questão como entrada, crie primeiro uma consulta com sintaxe correta para rodar e retorne a resposta para a questão de entrada.
A não ser que o usuário especifique na questão um número específico de exemplos a obter, pesquise por no máximo {top_k} resultados usando a cláusula TOP para o InterSystems IRIS. Você pode ordenar resultados para os dados mais informativos na base de dados.
Nunca pesquise todas as colunas de uma tabela. Você deve consultar apenas as colunas necessárias para responder a questão. Envolva cada nome de coluna em aspas simples para deontar que são identificadores delimitados.
Preste atenção para usar apenas nomes de colunas que você pode ver nas tabelas abaixo. Cuidado para não consultar colunas que não existem. Além disso, preste atenção em qual coluna está em qual tabela.
Atente-se em usar a função CAST(CURRENT_DATE as date) para buscar a data atual, se a questão envolve "hoje".
Use aspas duplas para delimitar identificadores de colunas.
Retorne apenas SQL puro; não aplique nenhum tipo de formatação.
"""

Esse prompt básico configura o modelo de linguagem (LLM) para funcionar como um expert em SQL com guia específico para a base de dados IRIS. Em seguida, fornecemos um prompt auxiliar com informação sobre a base de dados para evitar exageros.

# SQL template extension for including tables context information
tables_prompt_template = """
Only use the following tables:
{table_info}
"""

Para melhorar a acurácia das respostas LLM, usamos uma técnica chamada few-shot prompting. Isso envolve apresentar alguns exemplos para a LLM.

# SQL template extension for including few shots
prompt_sql_few_shots_template = """
Below are a number of examples of questions and their corresponding SQL queries.

{examples_value}
"""

Definimos o template para exemplos few-shot:

# Few shots prompt template
example_prompt_template = "User input: {input}\nSQL query: {query}"
example_prompt = PromptTemplate.from_template(example_prompt_template)

Construímos o prompt de usuário usando o template few-shot:

# User prompt template
user_prompt = "\n" + example_prompt.invoke({"input": "{input}", "query": ""}).to_string()

Finalmente, podemos compor todos os prompts para criar o prompt final:

# Complete prompt template
prompt = (
    ChatPromptTemplate.from_messages([("system", iris_sql_template)])
    + ChatPromptTemplate.from_messages([("system", tables_prompt_template)])
    + ChatPromptTemplate.from_messages([("system", prompt_sql_few_shots_template)])
    + ChatPromptTemplate.from_messages([("human", user_prompt)])
)
prompt

Esse prompt espera as variáveis examples_value, input, table_info, etop_k.

É assim que o prompt é estruturado:

ChatPromptTemplate(
    input_variables=['examples_value', 'input', 'table_info', 'top_k'], 
    messages=[
        SystemMessagePromptTemplate(
            prompt=PromptTemplate(
                input_variables=['top_k'], 
                template=iris_sql_template
            )
        ), 
        SystemMessagePromptTemplate(
            prompt=PromptTemplate(
                input_variables=['table_info'], 
                template=tables_prompt_template
            )
        ), 
        SystemMessagePromptTemplate(
            prompt=PromptTemplate(
                input_variables=['examples_value'], 
                template=prompt_sql_few_shots_template
            )
        ), 
        HumanMessagePromptTemplate(
            prompt=PromptTemplate(
                input_variables=['input'], 
                template=user_prompt
            )
        )
    ]
)

Para visualizar como o prompt será enviado a LLM, podemos usar valores no espaço reservado para as variáveis requeridas:

prompt_value = prompt.invoke({
    "top_k": "<top_k>",
    "table_info": "<table_info>",
    "examples_value": "<examples_value>",
    "input": "<input>"
})
print(prompt_value.to_string())
Sistema: 
Você é um expert em InterSystems IRIS. Dada uma questão como entrada, crie primeiro uma consulta com sintaxe correta para rodar e retorne a resposta para a questão de entrada.
A não ser que o usuário especifique na questão um número específico de exemplos a obter, pesquise por no máximo {top_k} resultados usando a cláusula TOP para o InterSystems IRIS. Você pode ordenar resultados para os dados mais informativos na base de dados.
Nunca pesquise todas as colunas de uma tabela. Você deve consultar apenas as colunas necessárias para responder a questão. Envolva cada nome de coluna em aspas simples para deontar que são identificadores delimitados.
Preste atenção para usar apenas nomes de colunas que você pode ver nas tabelas abaixo. Cuidado para não consultar colunas que não existem. Além disso, preste atenção em qual coluna está em qual tabela.
Atente-se em usar a função CAST(CURRENT_DATE as date) para buscar a data atual, se a questão envolve "hoje".
Use aspas duplas para delimitar identificadores de colunas.
Retorne apenas SQL puro; não aplique nenhum tipo de formatação.

Sistema: 
Use  apenas as tabelas a seguir:
<table_info>

Sistema: 
Abaixo, seguem alguns exemplos de questões e suas consultas SQL correspondentes.

<examples_value>

Humano: 
Input do usuário: <input>
Consulta SQL: 

Agora estamos prontos para enviar esse prompt para o LLM providenciando as variáveis necessárias. Vamos seguir para o próximo passo quando estiver pronto.

Fornecendo informações de tabela

Para criar consultas SQL precisas, precisamos fornecer ao modelo de linguagem (LLM) informações detalhadas sobre as tabelas das bases de dados. Sem essa informação, o LLM deve gerar consultas que parecem plausíveis mas são incorretas devido a exageros. Então, nosso primeiro passo é criar uma função que retorna definições de tabelas da base de dados IRIS.

Função para retornar definições de tabelas

A função a seguir consulta a variável INFORMATION_SCHEMA para buscar as definições de tabela para um schema (esquema) específico. Se uma tabela específica for fornecida, ela retorna a definição daquela tabela; caso contrário, retorna definições para todas as tabelas naquele esquema.

def get_table_definitions_array(cnx, schema, table=None):
    cursor = cnx.cursor()

    # Base query to get columns information
    query = """
    SELECT TABLE_SCHEMA, TABLE_NAME, COLUMN_NAME, DATA_TYPE, IS_NULLABLE, COLUMN_DEFAULT, PRIMARY_KEY, null EXTRA
    FROM INFORMATION_SCHEMA.COLUMNS
    WHERE TABLE_SCHEMA = %s
    """

    # Parameters for the query
    params = [schema]

    # Adding optional filters
    if table:
        query += " AND TABLE_NAME = %s"
        params.append(table)

    # Execute the query
    cursor.execute(query, params)

    # Fetch the results
    rows = cursor.fetchall()

    # Process the results to generate the table definition(s)
    table_definitions = {}
    for row in rows:
        table_schema, table_name, column_name, column_type, is_nullable, column_default, column_key, extra = row
        if table_name not in table_definitions:
            table_definitions[table_name] = []
        table_definitions[table_name].append({
            "column_name": column_name,
            "column_type": column_type,
            "is_nullable": is_nullable,
            "column_default": column_default,
            "column_key": column_key,
            "extra": extra
        })

    primary_keys = {}

    # Build the output string
    result = []
    for table_name, columns in table_definitions.items():
        table_def = f"CREATE TABLE {schema}.{table_name} (\n"
        column_definitions = []
        for column in columns:
            column_def = f"  {column['column_name']} {column['column_type']}"
            if column['is_nullable'] == "NO":
                column_def += " NOT NULL"
            if column['column_default'] is not None:
                column_def += f" DEFAULT {column['column_default']}"
            if column['extra']:
                column_def += f" {column['extra']}"
            column_definitions.append(column_def)
        if table_name in primary_keys:
            pk_def = f"  PRIMARY KEY ({', '.join(primary_keys[table_name])})"
            column_definitions.append(pk_def)
        table_def += ",\n".join(column_definitions)
        table_def += "\n);"
        result.append(table_def)

    return result

Retornar definições de tabela para um Esquema

Para esse exemplo, usamos o esquema Aviation (aviação), que está disponível aqui.

# Retrieve table definitions for the Aviation schema
tables = get_table_definitions_array(cnx, "Aviation")
print(tables)

Essa função retorna os predicados CREATE TABLE para todas as tabelas no esquema Aviation:

[
    'CREATE TABLE Aviation.Aircraft (\n  Event bigint NOT NULL,\n  ID varchar NOT NULL,\n  AccidentExplosion varchar,\n  AccidentFire varchar,\n  AirFrameHours varchar,\n  AirFrameHoursSince varchar,\n  AirFrameHoursSinceLastInspection varchar,\n  AircraftCategory varchar,\n  AircraftCertMaxGrossWeight integer,\n  AircraftHomeBuilt varchar,\n  AircraftKey integer NOT NULL,\n  AircraftManufacturer varchar,\n  AircraftModel varchar,\n  AircraftRegistrationClass varchar,\n  AircraftSerialNo varchar,\n  AircraftSeries varchar,\n  Damage varchar,\n  DepartureAirportId varchar,\n  DepartureCity varchar,\n  DepartureCountry varchar,\n  DepartureSameAsEvent varchar,\n  DepartureState varchar,\n  DepartureTime integer,\n  DepartureTimeZone varchar,\n  DestinationAirportId varchar,\n  DestinationCity varchar,\n  DestinationCountry varchar,\n  DestinationSameAsLocal varchar,\n  DestinationState varchar,\n  EngineCount integer,\n  EvacuationOccurred varchar,\n  EventId varchar NOT NULL,\n  FlightMedical varchar,\n  FlightMedicalType varchar,\n  FlightPhase integer,\n  FlightPlan varchar,\n  FlightPlanActivated varchar,\n  FlightSiteSeeing varchar,\n  FlightType varchar,\n  GearType varchar,\n  LastInspectionDate timestamp,\n  LastInspectionType varchar,\n  Missing varchar,\n  OperationDomestic varchar,\n  OperationScheduled varchar,\n  OperationType varchar,\n  OperatorCertificate varchar,\n  OperatorCertificateNum varchar,\n  OperatorCode varchar,\n  OperatorCountry varchar,\n  OperatorIndividual varchar,\n  OperatorName varchar,\n  OperatorState varchar,\n  Owner varchar,\n  OwnerCertified varchar,\n  OwnerCountry varchar,\n  OwnerState varchar,\n  RegistrationNumber varchar,\n  ReportedToICAO varchar,\n  SeatsCabinCrew integer,\n  SeatsFlightCrew integer,\n  SeatsPassengers integer,\n  SeatsTotal integer,\n  SecondPilot varchar,\n  childsub bigint NOT NULL DEFAULT $i(^Aviation.EventC("Aircraft"))\n);',
    'CREATE TABLE Aviation.Crew (\n  Aircraft varchar NOT NULL,\n  ID varchar NOT NULL,\n  Age integer,\n  AircraftKey integer NOT NULL,\n  Category varchar,\n  CrewNumber integer NOT NULL,\n  EventId varchar NOT NULL,\n  Injury varchar,\n  MedicalCertification varchar,\n  MedicalCertificationDate timestamp,\n  MedicalCertificationValid varchar,\n  Seat varchar,\n  SeatbeltUsed varchar,\n  Sex varchar,\n  ShoulderHarnessUsed varchar,\n  ToxicologyTestPerformed varchar,\n  childsub bigint NOT NULL DEFAULT $i(^Aviation.AircraftC("Crew"))\n);',
    'CREATE TABLE Aviation.Event (\n  ID bigint NOT NULL DEFAULT $i(^Aviation.EventD),\n  AirportDirection integer,\n  AirportDistance varchar,\n  AirportElevation integer,\n  AirportLocation varchar,\n  AirportName varchar,\n  Altimeter varchar,\n  EventDate timestamp,\n  EventId varchar NOT NULL,\n  EventTime integer,\n  FAADistrictOffice varchar,\n  InjuriesGroundFatal integer,\n  InjuriesGroundMinor integer,\n  InjuriesGroundSerious integer,\n  InjuriesHighest varchar,\n  InjuriesTotal integer,\n  InjuriesTotalFatal integer,\n  InjuriesTotalMinor integer,\n  InjuriesTotalNone integer,\n  InjuriesTotalSerious integer,\n  InvestigatingAgency varchar,\n  LightConditions varchar,\n  LocationCity varchar,\n  LocationCoordsLatitude double,\n  LocationCoordsLongitude double,\n  LocationCountry varchar,\n  LocationSiteZipCode varchar,\n  LocationState varchar,\n  MidAir varchar,\n  NTSBId varchar,\n  NarrativeCause varchar,\n  NarrativeFull varchar,\n  NarrativeSummary varchar,\n  OnGroundCollision varchar,\n  SkyConditionCeiling varchar,\n  SkyConditionCeilingHeight integer,\n  SkyConditionNonCeiling varchar,\n  SkyConditionNonCeilingHeight integer,\n  TimeZone varchar,\n  Type varchar,\n  Visibility varchar,\n  WeatherAirTemperature integer,\n  WeatherPrecipitation varchar,\n  WindDirection integer,\n  WindDirectionIndicator varchar,\n  WindGust integer,\n  WindGustIndicator varchar,\n  WindVelocity integer,\n  WindVelocityIndicator varchar\n);'
]

Com essas definições de tabela, podemos seguir para o próximo passo, que é integrá-las ao nosso prompt para o LLM. Isso assegura que o LLM tem informações precisas e compreensivas sobre o esquema da base de dados ao gerar consultas SQL.

Selecionando as tabelas mais relevantes

Ao trabalhar com bases de dados, especialmente as maiores, enviar a linguagem de definição de dados (Data Definition Language ou DDL) para todas as tabelas em um prompt pode ser impraticável. Ao passo que essa abordagem pode funcionar para bases de dados menores, as bases de dados reais frequentemente contém centenas ou milhares de tabelas, sendo ineficiente processá-las por completo.

No entanto, é improvável que um modelo de linguagem precise saber cada tabela numa base de dados para gerar consultas SQL efetivamente. Para endereçar esse desafio, podemos nivelar as capacidades da procura semântica para selecionar apenas tabelas relevantes baseadas na consulta do usuário.

Abordagem

Nós podemos conseguir isso usando a procura semântica com a Procura de Vetor IRIS. Note que este método é mais efetivo se os seus identificadores de elementos SQL (como tabelas, campos e chaves) têm nomes significativos. Se seus identificadores forem códigos abstratos, consiere usar um dicionário de dados no lugar.

Passos

  1. Retornar informação de tabela

Primeiro, vamos extrair as definições de tabelas para um DataFrame pandas:

# Retrieve table definitions into a pandas DataFrame
table_def = get_table_definitions_array(cnx=cnx, schema='Aviation')
table_df = pd.DataFrame(data=table_def, columns=["col_def"])
table_df["id"] = table_df.index + 1
table_df

O DataFrame (table_df) deve se parecer com algo assim:

col_def id
0 CREATE TABLE Aviation.Aircraft (\n Event bigi... 1
1 CREATE TABLE Aviation.Crew (\n Aircraft varch... 2
2 CREATE TABLE Aviation.Event (\n ID bigint NOT... 3
  1. Separar as definições em documentos:

Em seguida, divida as definições de tabelas em documentos LangChain. Esse passo é crucial para lidar com grandes conjuntos de texto e extrair textos embutidos:

loader = DataFrameLoader(table_df, page_content_column="col_def")
documents = loader.load()
text_splitter = CharacterTextSplitter(chunk_size=400, chunk_overlap=20, separator="\n")
tables_docs = text_splitter.split_documents(documents)
tables_docs

A lista resultante tables_docs contém documentos divididos com metadados, como o seguinte:

[Document(metadata={'id': 1}, page_content='CREATE TABLE Aviation.Aircraft (\n  Event bigint NOT NULL,\n  ID varchar NOT NULL,\n  ...'),
 Document(metadata={'id': 2}, page_content='CREATE TABLE Aviation.Crew (\n  Aircraft varchar NOT NULL,\n  ID varchar NOT NULL,\n  ...'),
 Document(metadata={'id': 3}, page_content='CREATE TABLE Aviation.Event (\n  ID bigint NOT NULL DEFAULT $i(^Aviation.EventD),\n  ...')]
  1. Extrair embutidos e guardar em IRIS

Agora, use a classe IRISVector do langchain-iris para extrair vetores embutidos e guardá-los:

tables_vector_store = IRISVector.from_documents(
    embedding=OpenAIEmbeddings(), 
    documents=tables_docs,
    connection_string=iris_conn_str,
    collection_name="sql_tables",
    pre_delete_collection=True
)

Nota: A bandeira pre_delete_collection é definida como True (verdadeiro) para propósitos de demonstração, para assegurar uma coleção nova em cada rodada de teste. Em um ambinte produtivo, essa bandeira geralmente deve ser definida como False (falso).

  1. Achar documentos relevantes

Com tabela de embutidos guardada, você agora pode consultar tabelas relevantes baseadas no input do usuário:

input_query = "List the first 2 manufacturers"
relevant_tables_docs = tables_vector_store.similarity_search(input_query, k=3)
relevant_tables_docs

Por exemplo, consultar por manufaturas pode retornar:

[Document(metadata={'id': 1}, page_content='GearType varchar,\n  LastInspectionDate timestamp,\n  ...'),
 Document(metadata={'id': 1}, page_content='AircraftModel varchar,\n  AircraftRegistrationClass varchar,\n  ...'),
 Document(metadata={'id': 3}, page_content='LocationSiteZipCode varchar,\n  LocationState varchar,\n  ...')]

Para os metadados, você pode ver que apenas o ID 1 da tabela (Aviation.Aircraft) é relevante, que linha com a consulta.

  1. Lidando com casos extremos

Apesar desta abordagem ser geralmente efetiva, pode não ser sempre perfeita. Por exemplo, consultar por áreas de queda pode retornar tabelas menos relevantes:

input_query = "List the top 10 most crash sites"
relevant_tables_docs = tables_vector_store.similarity_search(input_query, k=3)
relevant_tables_docs

Os resultados podem incluir:

[Document(metadata={'id': 3}, page_content='LocationSiteZipCode varchar,\n  LocationState varchar,\n  ...'),
 Document(metadata={'id': 3}, page_content='InjuriesGroundSerious integer,\n  InjuriesHighest varchar,\n  ...'),
 Document(metadata={'id': 1}, page_content='CREATE TABLE Aviation.Aircraft (\n  Event bigint NOT NULL,\n  ID varchar NOT NULL,\n  ...')]

Apesar de retornar a tabela correta Aviation.Event duas vezes, a tabela Aviation.Aircraft também pode aparecer, o que pode ser melhorado com filtros adicionais ou limitações. Isso está além do escopo deste exemplo e vai ser deixado para implementações futuras.

  1. Defina uma função para retornar tabelas relevantes

Para automatizar este processo, defina uma função para filtrar e retornar as tabelas relevantes baseado no input do usuário:

def get_relevant_tables(user_input, tables_vector_store, table_df):
    relevant_tables_docs = tables_vector_store.similarity_search(user_input)
    relevant_tables_docs_indices = [x.metadata["id"] for x in relevant_tables_docs]
    indices = table_df["id"].isin(relevant_tables_docs_indices)
    relevant_tables_array = [x for x in table_df[indices]["col_def"]]
    return relevant_tables_array

Essa função vai ajudar em retornar de maneira eficiente apenas as tabelas relevantes a enviar ao LLM, reduzindo o tamanho do prompt e melhorando a performance da consulta num geral.

Selecionando os exemplos mais relevantes (Few-Shot Prompting)

Ao trabalhar com modelos de linguagem (LLMs), fornecer exemplos relevantes ajuda a assegurar respostas precisas e contextualmente apropriadas. Esses exemplos, referidos como exemplos "few-shot", guiam o LLM para entender a estrutura e contexto das consultas que deve manusear.
No nosso caso, precisamos popular a variável examples_value com um conjunto diverso de consultas SQL que cobrem um espectro vasto de sinitaxe IRIS SQL e as tabelas disponíveis na base de dados. Isso ajuda a prevenir que o LLM gere consultas incorretas ou irrelevantes.

Definindo Consultas de Exemplo

Abaixo, uma lista de queries de exemplo desenhadas para ilustrar várias operações SQL:

examples = [
    {"input": "List all aircrafts.", "query": "SELECT * FROM Aviation.Aircraft"},
    {"input": "Find all incidents for the aircraft with ID 'N12345'.", "query": "SELECT * FROM Aviation.Event WHERE EventId IN (SELECT EventId FROM Aviation.Aircraft WHERE ID = 'N12345')"},
    {"input": "List all incidents in the 'Commercial' operation type.", "query": "SELECT * FROM Aviation.Event WHERE EventId IN (SELECT EventId FROM Aviation.Aircraft WHERE OperationType = 'Commercial')"},
    {"input": "Ache o número total de incidentes.", "query": "SELECT COUNT(*) FROM Aviation.Event"},
    {"input": "Liste todos os incidentes que ocorreram no 'Canadá'.", "query": "SELECT * FROM Aviation.Event WHERE LocationCountry = 'Canada'"},
    {"input": "Quantos incidentes estão associados com a aeronave com AircraftKey 5?", "query": "SELECT COUNT(*) FROM Aviation.Aircraft WHERE AircraftKey = 5"},
    {"input": "Ache o número total de diferentes aeronaves envolvidas em incidentes.", "query": "SELECT COUNT(DISTINCT AircraftKey) FROM Aviation.Aircraft"},
    {"input": "Liste todos os incidentes que ocorreram após 5 PM.", "query": "SELECT * FROM Aviation.Event WHERE EventTime > 1700"},
    {"input": "Quem são os top 5 operadores por número de incidentes?", "query": "SELECT TOP 5 OperatorName, COUNT(*) AS IncidentCount FROM Aviation.Aircraft GROUP BY OperatorName ORDER BY IncidentCount DESC"},
    {"input": "Quais incidentes ocorreram no ano 2020?", "query": "SELECT * FROM Aviation.Event WHERE YEAR(EventDate) = '2020'"},
    {"input": "Qual foi o mês com mais incidentes no ano 2020?", "query": "SELECT TOP 1 MONTH(EventDate) EventMonth, COUNT(*) EventCount FROM Aviation.Event WHERE YEAR(EventDate) = '2020' GROUP BY MONTH(EventDate) ORDER BY EventCount DESC"},
    {"input": "Quantos membros de tripulação foram envolvidos em incidentes?", "query": "SELECT COUNT(*) FROM Aviation.Crew"},
    {"input": "Liste todos os incidentes com informações detalhadas da aeronave para os incidentes que ocorreram no ano 2012.", "query": "SELECT e.EventId, e.EventDate, a.AircraftManufacturer, a.AircraftModel, a.AircraftCategory FROM Aviation.Event e JOIN Aviation.Aircraft a ON e.EventId = a.EventId WHERE Year(e.EventDate) = 2012"},
    {"input": "Ache todos os incidentes onde houve mais de 5 feridos e inclua a fabricante da aeronave e modelo.", "query": "SELECT e.EventId, e.InjuriesTotal, a.AircraftManufacturer, a.AircraftModel FROM Aviation.Event e JOIN Aviation.Aircraft a ON e.EventId = a.EventId WHERE e.InjuriesTotal > 5"},
    {"input": "Liste todos os membros da tripulação envolvidos em incidentes com machucados sérios, junto da da data de incidente e localização.", "query": "SELECT c.CrewNumber AS 'Crew Number', c.Age, c.Sex AS Gender, e.EventDate AS 'Event Date', e.LocationCity AS 'Location City', e.LocationState AS 'Location State' FROM Aviation.Crew c JOIN Aviation.Event e ON c.EventId = e.EventId WHERE c.Injury = 'Serious'"}
]

[ Nota do tradutor: as frases foram traduzidas para fins de compreensão, mas os as informações de colunas nas frases deveriam idealmente ser deixadas em inglês para o interpretador melhor associar com as colunas na tabela, que estão em inglês. Segue um breve glossário:
Aviation.Aircraft = Aviação.Aeronave
key = chave
month = mês
count = contagem
OpertationType = tipo de operação
OperatorName = nome do operador
LocationCountry = país de localização
State = estado
City = cidade
Crew = Tripulação
Injury = ferimento ]

Selecionando Exemplos Relevantes

Dada a lista de exemplos que vai sempre expandir, não é uma boa ideia fornecer todas elas ao LLM. Ao invés disso, vamos usar a Busca Vetorial IRIS com a classe SemanticSimilarityExampleSelector to identify the most relevant examples based on user prompts.

Defina o seletor de exemplo:

example_selector = SemanticSimilarityExampleSelector.from_examples(
    examples,
    OpenAIEmbeddings(),
    IRISVector,
    k=5,
    input_keys=["input"],
    connection_string=iris_conn_str,
    collection_name="sql_samples",
    pre_delete_collection=True
)

Nota: A bandeira pre_delete_collection é usada aqui para propósitos de demonstração para assegura uma coleção nova em cada rodada de teste. Em um ambiente produtivo, essa bandeira deve ser definida como False para evitar deleções desnecessárias.

Consulte o Seletor:

Para buscar os exemplos mais relevantes para uma dada entrada, use o seletor como se segue:

input_query = "Busque todos os eventos em 2010 informando o Event Id e date, location city e state, aircraft manufacturer e model."
relevant_examples = example_selector.select_examples({"input": input_query})

Os resultados devem parecer como algo assim:

[{'input': 'Liste todos os incidents com informações detalhadas sobre a Aircraft para os incidentes que ocorreram no year 2012.', 'query': 'SELECT e.EventId, e.EventDate, a.AircraftManufacturer, a.AircraftModel, a.AircraftCategory FROM Aviation.Event e JOIN Aviation.Aircraft a ON e.EventId = a.EventId WHERE Year(e.EventDate) = 2012'},
 {'input': "Busque todos os incidents para a Aircraft de ID 'N12345'.", 'query': "SELECT * FROM Aviation.Event WHERE EventId IN (SELECT EventId FROM Aviation.Aircraft WHERE ID = 'N12345')"},
 {'input': 'Ache todos os incidents onde houve mais de 5 injuries e inclua o aircraft manufacturer e model.', 'query': 'SELECT e.EventId, e.InjuriesTotal, a.AircraftManufacturer, a.AircraftModel FROM Aviation.Event e JOIN Aviation.Aircraft a ON e.EventId = a.EventId WHERE e.InjuriesTotal > 5'},
 {'input': 'Liste todas as aircrafts.', 'query': 'SELECT * FROM Aviation.Aircraft'},
 {'input': 'Ache o número total de aircrafts distintas envolvidas em incidents.', 'query': 'SELECT COUNT(DISTINCT AircraftKey) FROM Aviation.Aircraft'}]

Se você especificamente precisa de exemplos relacionados a quantidades, você pode procurar o seletor conforme:

input_query = "Qual foi o número de incidentes envolvendo a aircraft Boeing."
quantity_examples = example_selector.select_examples({"input": input_query})

A saída será:

[{'input': 'Quantos incidents estão associados à aircraft de AircraftKey 5?', 'query': 'SELECT COUNT(*) FROM Aviation.Aircraft WHERE AircraftKey = 5'},
 {'input': 'Busque o número total de aircrafts distintas envolvidas em incidents.', 'query': 'SELECT COUNT(DISTINCT AircraftKey) FROM Aviation.Aircraft'},
 {'input': 'Quantos membros de tripulação foram envolvidos em incidents?', 'query': 'SELECT COUNT(*) FROM Aviation.Crew'},
 {'input': 'Ache todos os incidents onde houve mais de 5 ferimentos e inclua a aircraft manufacturer e model.', 'query': 'SELECT e.EventId, e.InjuriesTotal, a.AircraftManufacturer, a.AircraftModel FROM Aviation.Event e JOIN Aviation.Aircraft a ON e.EventId = a.EventId WHERE e.InjuriesTotal > 5'},
 {'input': 'Liste todos os incidents com informação detalhada sobre as aircrafts para incidents que ocorreram no ano 2012.', 'query': 'SELECT e.EventId, e.EventDate, a.AircraftManufacturer, a.AircraftModel, a.AircraftCategory FROM Aviation.Event e JOIN Aviation.Aircraft a ON e.EventId = a.EventId WHERE Year(e.EventDate) = 2012'}]

A saída inclui exemplos que especificamente buscam contagens e quantidades.

Considerações futuras

Apesar do SemanticSimilarityExampleSelector ser poderoso, é importante nota que nem todos os exemplos serão perfeitos. Melhorias futuras podem incluir adição de filtros e limites para excluir resultados menos relevantes, assegurando que apenas os exemplos mais apropriados sejam fornecidos ao LLM.

Teste de precisão

Para acessar a performance do prompt e da geração de consultas SQL, precisamos definir e rodar uma série de testes. O objetivo é avaliar se a LLM gera de maneira ótima as consultas SQL baseadas em inputs de usuários e sem o uso dos few shots baseados em exemplos.

Função para gerar consultas SQL

Nós começamos definindo uma função que usa LLM para gerar consultas SQL baseadas no contexto providenciado, prompt, entrada do usuários e outros parâmetros:

def get_sql_from_text(context, prompt, user_input, use_few_shots, tables_vector_store, table_df, example_selector=None, example_prompt=None):
    relevant_tables = get_relevant_tables(user_input, tables_vector_store, table_df)
    context["table_info"] = "\n\n".join(relevant_tables)

    examples = example_selector.select_examples({"input": user_input}) if example_selector else []
    context["examples_value"] = "\n\n".join([
        example_prompt.invoke(x).to_string() for x in examples
    ])

    model = ChatOpenAI(model="gpt-3.5-turbo", temperature=0)
    output_parser = StrOutputParser()
    chain_model = prompt | model | output_parser

    response = chain_model.invoke({
        "top_k": context["top_k"],
        "table_info": context["table_info"],
        "examples_value": context["examples_value"],
        "input": user_input
    })
    return response

Execute o Prompt

Teste o prompt com e sem exemplos:

# Prompt execution **with** few shots
input = "Busque todos os eventos em 2010 informando o Event Id e date, location city e state, aircraft manufacturer e model."
response_with_few_shots = get_sql_from_text(
    context, 
    prompt, 
    user_input=input, 
    use_few_shots=True, 
    tables_vector_store=tables_vector_store, 
    table_df=table_df,
    example_selector=example_selector, 
    example_prompt=example_prompt,
)
print(response_with_few_shots)
SELECT e.EventId, e.EventDate, e.LocationCity, e.LocationState, a.AircraftManufacturer, a.AircraftModel
FROM Aviation.Event e
JOIN Aviation.Aircraft a ON e.EventId = a.EventId
WHERE Year(e.EventDate) = 2010
# Prompt execution **without** few shots
input = "Busque todos os eventos em 2010 informando o Event Id e date, location city e state, aircraft manufacturer e model."
response_with_no_few_shots = get_sql_from_text(
    context, 
    prompt, 
    user_input=input, 
    use_few_shots=False, 
    tables_vector_store=tables_vector_store, 
    table_df=table_df,
)
print(response_with_no_few_shots)
SELECT TOP 3 "EventId", "EventDate", "LocationCity", "LocationState", "AircraftManufacturer", "AircraftModel"
FROM Aviation.Event e
JOIN Aviation.Aircraft a ON e.ID = a.Event
WHERE e.EventDate >= '2010-01-01' AND e.EventDate < '2011-01-01'
Utility Functions for Testing

Para testar as queries geradas por SQL, definimos algumas funções úteis:

def execute_sql_query(cnx, query):
    try:
        cursor = cnx.cursor()
        cursor.execute(query)
        rows = cursor.fetchall()
        return rows
    except:
        print('Error running query:')
        print(query)
        print('-'*80)
    return None

def sql_result_equals(cnx, query, expected):
    rows = execute_sql_query(cnx, query)
    result = [set(row._asdict().values()) for row in rows or []]
    if result != expected and rows is not None:
        print('Result not as expected for query:')
        print(query)
        print('-'*80)
    return result == expected
# SQL test for prompt **with** few shots
print("SQL is OK" if not execute_sql_query(cnx, response_with_few_shots) is None else "SQL is not OK")
    SQL is OK
# SQL test for prompt **without** few shots
print("SQL is OK" if not execute_sql_query(cnx, response_with_no_few_shots) is None else "SQL is not OK")
    error on running query: 
    SELECT TOP 3 "EventId", "EventDate", "LocationCity", "LocationState", "AircraftManufacturer", "AircraftModel"
    FROM Aviation.Event e
    JOIN Aviation.Aircraft a ON e.ID = a.Event
    WHERE e.EventDate >= '2010-01-01' AND e.EventDate < '2011-01-01'
    --------------------------------------------------------------------------------
    SQL is not OK

Defina e Execute Testes

Defina um conjunto de casos de teste e rode eles:

tests = [{
    "input": "Quais são os 3 anos com mais eventos gravados?",
    "expected": [{128, 2003}, {122, 2007}, {117, 2005}]
},{
    "input": "Quantos incidentes envolveram a aircraft Boeing.",
    "expected": [{5}]
},{
    "input": "Quantos incidentes resultaram em fatalidades.",
    "expected": [{237}]
},{
    "input": "Liste o event Id e date e, crew number, age e gender para incidents que ocorreram em 2013.",
    "expected": [{1, datetime.datetime(2013, 3, 4, 11, 6), '20130305X71252', 59, 'M'},
                 {1, datetime.datetime(2013, 1, 1, 15, 0), '20130101X94035', 32, 'M'},
                 {2, datetime.datetime(2013, 1, 1, 15, 0), '20130101X94035', 35, 'M'},
                 {1, datetime.datetime(2013, 1, 12, 15, 0), '20130113X42535', 25, 'M'},
                 {2, datetime.datetime(2013, 1, 12, 15, 0), '20130113X42535', 34, 'M'},
                 {1, datetime.datetime(2013, 2, 1, 15, 0), '20130203X53401', 29, 'M'},
                 {1, datetime.datetime(2013, 2, 15, 15, 0), '20130218X70747', 27, 'M'},
                 {1, datetime.datetime(2013, 3, 2, 15, 0), '20130303X21011', 49, 'M'},
                 {1, datetime.datetime(2013, 3, 23, 13, 52), '20130326X85150', 'M', None}]
},{
    "input": "Ache o total de incidents ocorridos em United States.",
    "expected": [{1178}]
},{
    "input": "Liste todos as coordenadas de latitude e longitude de incidentes que resultaram em mais de 5 feridos em 2010.",
    "expected": [{-78.76833333333333, 43.25277777777778}]
},{
    "input": "Busque todos os incidentes em 2010 informando o Event Id e date, location city e state, aircraft manufacturer e model.",
    "expected": [
        {datetime.datetime(2010, 5, 20, 13, 43), '20100520X60222', 'CIRRUS DESIGN CORP', 'Farmingdale', 'New York', 'SR22'},
        {datetime.datetime(2010, 4, 11, 15, 0), '20100411X73253', 'CZECH AIRCRAFT WORKS SPOL SRO', 'Millbrook', 'New York', 'SPORTCRUISER'},
        {'108', datetime.datetime(2010, 1, 9, 12, 55), '20100111X41106', 'Bayport', 'New York', 'STINSON'},
        {datetime.datetime(2010, 8, 1, 14, 20), '20100801X85218', 'A185F', 'CESSNA', 'New York', 'Newfane'}
    ]
}]

Avaliação de precisão

Rode os testes e calcule a precisão:

def execute_tests(cnx, context, prompt, use_few_shots, tables_vector_store, table_df, example_selector, example_prompt):
    tests_generated_sql = [(x, get_sql_from_text(
            context, 
            prompt, 
            user_input=x['input'], 
            use_few_shots=use_few_shots, 
            tables_vector_store=tables_vector_store, 
            table_df=table_df,
            example_selector=example_selector if use_few_shots else None, 
            example_prompt=example_prompt if use_few_shots else None,
        )) for x in deepcopy(tests)]

    tests_sql_executions = [(x[0], sql_result_equals(cnx, x[1], x[0]['expected'])) 
                            for x in tests_generated_sql]

    accuracy = sum(1 for i in tests_sql_executions if i[1] == True) / len(tests_sql_executions)
    print(f'Accuracy: {accuracy}')
    print('-'*80)

Resultados

# Testes de precisão para prompts executados **sem** few shots
use_few_shots = False
execute_tests(
    cnx,
    context, 
    prompt, 
    use_few_shots, 
    tables_vector_store, 
    table_df, 
    example_selector, 
    example_prompt
)
    error on running query: 
    SELECT "EventDate", COUNT("EventId") as "TotalEvents"
    FROM Aviation.Event
    GROUP BY "EventDate"
    ORDER BY "TotalEvents" DESC
    TOP 3;
    --------------------------------------------------------------------------------
    error on running query: 
    SELECT "EventId", "EventDate", "C"."CrewNumber", "C"."Age", "C"."Sex"
    FROM "Aviation.Event" AS "E"
    JOIN "Aviation.Crew" AS "C" ON "E"."ID" = "C"."EventId"
    WHERE "E"."EventDate" >= '2013-01-01' AND "E"."EventDate" < '2014-01-01'
    --------------------------------------------------------------------------------
    result not expected for query: 
    SELECT TOP 3 "e"."EventId", "e"."EventDate", "e"."LocationCity", "e"."LocationState", "a"."AircraftManufacturer", "a"."AircraftModel"
    FROM "Aviation"."Event" AS "e"
    JOIN "Aviation"."Aircraft" AS "a" ON "e"."ID" = "a"."Event"
    WHERE "e"."EventDate" >= '2010-01-01' AND "e"."EventDate" < '2011-01-01'
    --------------------------------------------------------------------------------
    accuracy: 0.5714285714285714
    --------------------------------------------------------------------------------
# Testes de precisão para prompts executados **com** few shots
use_few_shots = True
execute_tests(
    cnx,
    context, 
    prompt, 
    use_few_shots, 
    tables_vector_store, 
    table_df, 
    example_selector, 
    example_prompt
)
    error on running query: 
    SELECT e.EventId, e.EventDate, e.LocationCity, e.LocationState, a.AircraftManufacturer, a.AircraftModel
    FROM Aviation.Event e
    JOIN Aviation.Aircraft a ON e.EventId = a.EventId
    WHERE Year(e.EventDate) = 2010 TOP 3
    --------------------------------------------------------------------------------
    accuracy: 0.8571428571428571
    --------------------------------------------------------------------------------

Conclusão

A acurácia das consultas SQL geradas com exemplos (few shots) é aproximadamente 49% maior se comparada àquelas geradas sem exemplos (85% vs. 57%).

References

Discussão (0)1
Entre ou crie uma conta para continuar