| | |
| | """ |
| | Created on Fri May 26 14:07:22 2023 |
| | |
| | @author: vibin |
| | """ |
| |
|
| | import streamlit as st |
| | from pandasql import sqldf |
| | import pandas as pd |
| | import re |
| | from typing import List |
| | from transformers import AutoTokenizer, AutoModelForSeq2SeqLM, pipeline |
| | import re |
| |
|
| |
|
@st.cache_resource()
def tapas_model():
    """Build the TAPAS table-question-answering pipeline.

    Decorated with ``st.cache_resource`` so Streamlit can hand back the
    already-built pipeline object on later calls.

    Returns:
        The ``transformers`` pipeline created for the
        ``table-question-answering`` task with the
        ``google/tapas-base-finetuned-wtq`` checkpoint.
    """
    tqa_pipeline = pipeline(
        task="table-question-answering",
        model="google/tapas-base-finetuned-wtq",
    )
    return tqa_pipeline
| |
|
@st.cache_resource()
def prepare_input(question: str, table: List[str]):
    """Tokenize a question and a list of column names into model input ids.

    The prompt sent to the tokenizer has the form
    ``question: <question> table: <col1>,<col2>,...``.

    This function reads the module-level name ``tokenizer``; the caller must
    have assigned it before the first call.

    Args:
        question: Natural-language question to translate.
        table: Column names; they are joined with commas into the prompt.

    Returns:
        The ``input_ids`` produced by the tokenizer, requested as PyTorch
        tensors (``return_tensors="pt"``).
    """
    table_prefix = "table:"
    question_prefix = "question:"
    join_table = ",".join(table)
    inputs = f"{question_prefix} {question} {table_prefix} {join_table}"
    # Passing ``max_length`` on its own only truncates implicitly and makes
    # the tokenizer emit a warning; ``truncation=True`` states the intent.
    encoded = tokenizer(
        inputs,
        max_length=512,
        truncation=True,
        return_tensors="pt",
    )
    return encoded.input_ids
| |
|
# The return value is a plain string, so it is cached as data rather than as
# a shared resource object.
@st.cache_data()
def inference(question: str, table: List[str]) -> str:
    """Generate the model's SQL text for a question over the given columns.

    This function reads the module-level names ``model`` and ``tokenizer``;
    the caller must have assigned both before the first call.

    Args:
        question: Natural-language question to translate.
        table: Column names forwarded to ``prepare_input``.

    Returns:
        The first generated sequence, decoded with special tokens removed.
    """
    input_data = prepare_input(question=question, table=table)
    # Move the ids to wherever the model's weights live before generating.
    input_data = input_data.to(model.device)
    # Beam search only: ``top_k`` is a sampling option and is ignored unless
    # ``do_sample=True``, so it is not passed here.
    outputs = model.generate(inputs=input_data, num_beams=10, max_length=700)
    return tokenizer.decode(token_ids=outputs[0], skip_special_tokens=True)
| |
|
@st.cache_resource()
def tokmod(tok_md):
    """Load the tokenizer and the seq2seq model for one checkpoint.

    Args:
        tok_md: Identifier handed to ``from_pretrained`` for both the
            tokenizer and the model.

    Returns:
        A ``(tokenizer, model)`` tuple, in that order.
    """
    return (
        AutoTokenizer.from_pretrained(tok_md),
        AutoModelForSeq2SeqLM.from_pretrained(tok_md),
    )
| |
|
| |
|
| | |
| |
|
# ---------------------------------------------------------------------------
# Helpers that repair the raw SQL text produced by the Text2SQL model.
# The model output handled here uses the placeholder table name "table",
# unquoted comparison values and aggregates written as "SUM col".
# ---------------------------------------------------------------------------

# Right-hand side of an "=" comparison, up to the next clause keyword or the
# end of the statement.  Keywords are matched case-sensitively and as whole
# words so that values such as "BRAND" or "Research and Development" survive.
_COMPARISON_RE = re.compile(
    r"=\s*(.+?)(\s+(?:AND|OR|GROUP BY|ORDER BY|LIMIT)\b|\s*$)",
    re.DOTALL,
)

# Aggregate keyword followed by a bare operand, e.g. "SUM Version",
# "COUNT *" or "COUNT DISTINCT Type".
_AGGREGATE_RE = re.compile(
    r"\b(AVG|COUNT|MAX|MIN|SUM) ((?:DISTINCT )?(?:\w+|\*))"
)

# A single-quoted SQL literal ('' is an escaped quote inside it).
_LITERAL_RE = re.compile(r"('(?:[^']|'')*')")


def _quote_comparison_values(sql: str) -> str:
    """Wrap the value on the right of every ``=`` in single quotes."""

    def _quote(match):
        value = match.group(1).strip()
        already_quoted = len(value) >= 2 and value[0] == value[-1] == "'"
        if not value or already_quoted:
            return match.group(0)
        escaped = value.replace("'", "''")
        return f"= '{escaped}'{match.group(2)}"

    return _COMPARISON_RE.sub(_quote, sql)


def _wrap_aggregates(sql: str) -> str:
    """Turn ``SUM col`` into ``SUM(col)`` everywhere outside quoted literals."""
    pieces = _LITERAL_RE.split(sql)
    # With one capturing group, odd indexes hold the quoted literals; only the
    # even-indexed pieces are SQL proper and may be rewritten.
    pieces[::2] = [_AGGREGATE_RE.sub(r"\1(\2)", piece) for piece in pieces[::2]]
    return "".join(pieces)


def _postprocess_sql(raw_sql: str, table_name: str) -> str:
    """Convert raw model output into SQL that can run against ``table_name``."""
    sql = re.sub(r"\btable\b", table_name, raw_sql)
    sql = _quote_comparison_values(sql)
    return _wrap_aggregates(sql)


# ---------------------------------------------------------------------------
# Page layout: the sidebar radio selects one of two demo pages.
# ---------------------------------------------------------------------------
nav = st.sidebar.radio("Navigation", ["TAPAS", "Text2SQL"])
if nav == "TAPAS":

    # Three equal columns are used only to centre the title.
    col1, col2, col3 = st.columns(3)
    col2.title("TAPAS")

    col3, col4 = st.columns([3, 12])
    col4.text("Tabular Data Text Extraction using text")

    table = pd.read_csv("data.csv")
    # Every cell is converted to str before the table is shown and queried.
    table = table.astype(str)
    st.text("DataSet - ")
    st.dataframe(table, width=3000, height=400)

    # Empty titles are used as vertical spacers.
    st.title("")

    lst_q = [
        "Which country has low medicare",
        "Who are the patients from india",
        "Patients who have Edema",
        "CUI code for diabetes patients",
        "Patients having oxygen less than 94 but 91",
    ]

    v2 = st.selectbox("Choose your text", lst_q, index=0)

    st.title("")

    sql_txt = st.text_area("TAPAS Input", v2)

    if st.button("Predict"):
        if not sql_txt.strip():
            # Do not hand an empty query to the pipeline.
            st.warning("Please enter a question before predicting.")
        else:
            tqa = tapas_model()
            txt_sql = tqa(table=table, query=sql_txt)["answer"]
            st.text("Output - ")
            st.success(f"{txt_sql}")

elif nav == "Text2SQL":

    col1, col2, col3 = st.columns(3)
    col2.title("Text2SQL")

    col3, col4 = st.columns([1, 20])
    col4.text("Text will be converted to SQL Query and can extract the data from DataSet")

    df_qna = pd.read_csv("qnacsv.csv", encoding='unicode_escape')

    st.title("")

    st.text("DataSet - ")
    st.dataframe(df_qna, width=3000, height=500)

    st.title("")

    lst_q = [
        "what interface is measure indicator code = 72_HR_ABX and version is 1 and source is TD",
        "get class code with measure = 72_HR_ABX",
        "get sum of version for Class_Code is Antibiotic Stewardship",
        "what interface is measure indicator code = 72_HR_ABX",
    ]
    v2 = st.selectbox("Choose your text", lst_q, index=0)

    st.title("")

    sql_txt = st.text_area("Text for SQL Conversion", v2)

    if st.button("Predict"):
        if not sql_txt.strip():
            # Do not run the model on an empty prompt.
            st.warning("Please enter a question before predicting.")
        else:
            tok_model = "juierror/flan-t5-text2sql-with-schema"
            # ``prepare_input`` and ``inference`` read these two module-level
            # names, so they must be assigned before ``inference`` is called.
            tokenizer, model = tokmod(tok_model)

            table_name = "df_qna"
            table_col = [
                "Type",
                "Class_Code",
                "Version",
                "Measure_Indicator_Code",
                "Measure_Indicator_Name",
                "Description_Definition",
                "Source",
                "Interfaces",
            ]

            txt_sql = inference(question=sql_txt, table=table_col)
            txt_sql = _postprocess_sql(txt_sql, table_name)

            st.success(f"{txt_sql}")
            # NOTE(review): this executes model-generated SQL derived from
            # free-form user text.  Only ``df_qna`` is exposed to the query
            # through the explicit env mapping below.
            query_result = sqldf(txt_sql, {table_name: df_qna})

            st.text("Output - ")
            st.write(query_result)
| | |
| | |
| | |
| | |
| | |
| | |
| | |
| | |
| |
|
| | |
| | |
| | |
| | |
| | |
| | |
| | |
| | |
| | |
| | |
| | |
| | |
| | |
| | |
| | |
| | |
| | |
| | |