Spaces:
Sleeping
Sleeping
| from io import BytesIO | |
| import streamlit as st | |
| import base64 | |
| from transformers import AutoModel, AutoTokenizer | |
| from graphviz import Digraph | |
| import json | |
def display_tree(output):
    """Render a dependency tree for the parsed sentence and embed it in the page.

    Parameters
    ----------
    output : list[dict]
        One dict per word with keys ``word``, ``dep_head_idx`` (0-based index
        of the syntactic head, ``-1`` for the root) and ``dep_func`` (the
        dependency label for the head -> word edge).
    """
    # Width grows with sentence length; height is fixed at 5 inches.
    size = f'{len(output)},5'
    dpi = '300'
    img_format = 'svg'  # renamed from `format` to avoid shadowing the builtin

    # Initialize Digraph object
    dot = Digraph(engine='dot', format=img_format)
    dot.attr('graph', rankdir='LR', rank='same', size=size, dpi=dpi)

    # Add nodes and edges
    for i, word_info in enumerate(output):
        word = word_info['word']  # Prepare word for RTL display
        head_idx = word_info['dep_head_idx']
        dep_func = word_info['dep_func']
        dot.node(str(i), word)
        # Create an invisible edge from the previous word to this one to enforce order
        if i > 0:
            dot.edge(str(i), str(i - 1), style='invis')
        if head_idx != -1:
            # Graphviz booleans are conventionally lowercase; constraint='false'
            # keeps dependency edges from influencing node ranking, so the
            # words stay in linear sentence order.
            dot.edge(str(head_idx), str(i), label=dep_func, constraint='false')

    # Embed the SVG inline (base64 data URI) inside a scrollable container.
    # dot.pipe() renders in memory, so no temp file is written to disk.
    st.markdown(
        f"""
        <div style="height:250px; width:75vw; overflow:auto; border:1px solid #ccc; margin-left:-15vw">
            <img src="data:image/svg+xml;base64,{base64.b64encode(dot.pipe(format='svg')).decode()}"
                 style="display: block; margin: auto; max-height: 240px;">
        </div>
        """, unsafe_allow_html=True)
def display_download(disp_string):
    """Offer *disp_string* for download as a UTF-8 plain-text file."""
    buffer = BytesIO(disp_string.encode())
    st.download_button(
        label="⬇️ Download text file",
        data=buffer,
        file_name="parsed_output.txt",
        mime="text/plain",
    )
# --- App setup: page title, auth, model load, and input widgets ---------

# Streamlit app title
st.title('DictaBERT-Joint Visualizer')

# Load Hugging Face token
hf_token = st.secrets["HF_TOKEN"]  # Assuming you've set up the token in Streamlit secrets

# Authenticate and load model.
# NOTE(review): `use_auth_token` is deprecated in newer transformers releases
# in favor of `token=` — confirm against the pinned transformers version.
# trust_remote_code is required because dictabert-joint ships custom model code.
tokenizer = AutoTokenizer.from_pretrained('dicta-il/dictabert-joint', use_auth_token=hf_token)
model = AutoModel.from_pretrained('dicta-il/dictabert-joint', use_auth_token=hf_token, trust_remote_code=True)
model.eval()

# Checkbox for the compute_mst parameter (passed to model.predict below)
compute_mst = st.checkbox('Compute Maximum Spanning Tree', value=True)

# Output format selector; lowercased to match the values model.predict expects.
# index=1 makes 'UD' the default selection.
output_style = st.selectbox(
    'Output Style: ',
    ('JSON', 'UD', 'IAHLT_UD'), index=1).lower()

# User input
sentence = st.text_input('Enter a sentence to analyze:')
if sentence:
    # Echo the input sentence back to the user.
    st.text(sentence)

    # Model prediction — predict() takes a batch; we pass a single sentence
    # and take the first (only) result.
    output = model.predict([sentence], tokenizer, compute_syntax_mst=compute_mst, output_style=output_style)[0]

    if output_style == 'ud' or output_style == 'iahlt_ud':
        ud_output = output
        # Convert CoNLL-U lines to the format display_tree expects:
        # [dict(word, dep_head_idx, dep_func)]. The first two lines are the
        # '# sent_id' / '# text' comment headers, hence the [2:] slice.
        tree = []
        for l in ud_output[2:]:
            parts = l.split('\t')
            # Skip multi-word-token range lines (CoNLL-U IDs like '2-3').
            if '-' in parts[0]:
                continue
            # CoNLL-U HEAD is 1-based with 0 = root, so subtracting 1 yields
            # a 0-based head index with -1 marking the root.
            tree.append(dict(word=parts[1], dep_head_idx=int(parts[6]) - 1, dep_func=parts[7]))
        display_tree(tree)
        display_download('\n'.join(ud_output))

        # Construct the table as a Markdown string
        table_md = "<div dir='rtl' style='text-align: right;'>\n\n"  # Start with RTL div
        # CSS that shrinks and crops the embedded Google Translate iframe
        # so only the translation widget area is visible.
        st.markdown("""<style>
.google-translate-place {
  width: 256px;
  height: 128px;
}
.google-translate-crop {
  width: 256px;
  height: 128px;
  overflow: scroll;
  position: absolute;
}
.google-translate {
  transform: scale(0.75);
  transform-origin: 180px 200px;
  position: relative;
  left: -200px; top: -180px;
  width: 2560px; height: 5120px;
  position: absolute;
}
</style>""", unsafe_allow_html=True)

        # Add the UD header lines
        table_md += "##" + ud_output[0] + "\n"
        table_md += "##" + ud_output[1] + "\n"
        # Table header
        table_md += "| " + " | ".join(["ID", "FORM", "LEMMA", "UPOS", "XPOS", "FEATS", "HEAD", "DEPREL", "DEPS", "MISC"]) + " |\n"
        # Table alignment
        table_md += "| " + " | ".join(["---"] * 10) + " |\n"
        for line in ud_output[2:]:
            # Each UD line as a table row. Escape markdown specials so cell
            # content cannot break the table: '_' would toggle emphasis and a
            # literal '|' would split the row into extra cells. (The previous
            # replace('|', '|') was a no-op bug.)
            cells = line.replace('_', '\\_').replace('|', '\\|').split('\t')
            wrd = cells[2]
            # Raw string avoids the invalid '\_' escape sequence; the value is
            # the escaped placeholder lemma produced by the replace above.
            if wrd != r"\_":
                # Replace the LEMMA cell with lookup widgets for the word:
                # Google Translate, Google Ngrams, two freeali.se helpers,
                # and a dict.com link.
                cells[2] = "<div class='google-translate-place'><div class='google-translate-crop'><iframe class='google-translate' src='https://www.google.com/search?igu=1&q=" + wrd + "+in+English+google+translate&authuser=0&hl=en-US' width='256' height='128'></iframe></div></div><br/>"
                cells[2] += "<iframe src='https://books.google.com/ngrams/interactive_chart?content=" + wrd + "_*&year_start=1800&year_end=2022&corpus=iw&smoothing=50' width='256' height='128'></iframe><br/>"
                cells[2] += "<iframe src='https://freeali.se/freealise/translate/loader.htm?q=" + wrd + "&a=conj' width='256' height='128'></iframe><br/>"
                cells[2] += "<iframe src='https://freeali.se/freealise/translate/loader.htm?q=" + wrd + "&a=def' width='256' height='128'></iframe><br/>"
                cells[2] += "<a href='https://dict.com/hebrew-english/" + wrd + "' target='_blank'>" + wrd + "</a>"
            table_md += "| " + " | ".join(cells) + " |\n"
        table_md += "</div>"  # Close the RTL div

        # Display the table using a single markdown call
        st.markdown(table_md, unsafe_allow_html=True)
    else:
        # JSON output style: display the tree built from each token's
        # 'syntax' dict (same word/dep_head_idx/dep_func keys)...
        tree = [w['syntax'] for w in output['tokens']]
        display_tree(tree)
        json_output = json.dumps(output, ensure_ascii=False, indent=2)
        display_download(json_output)
        # ...and the full json
        st.markdown("```json\n" + json_output + "\n```")