Spaces:
Running
on
CPU Upgrade
Running
on
CPU Upgrade
| from bokeh.events import Tap | |
| from bokeh.io import curdoc | |
| from bokeh.layouts import column | |
| from bokeh.models import Div, TextInput, RadioButtonGroup, TextAreaInput, Span, Button, Panel, Tabs | |
| from bokeh.models.tools import CrosshairTool | |
| from demo_utils import ( | |
| get_data, | |
| prompt_boolq, | |
| pvp_colors, | |
| ctl_colors, | |
| clf_colors, | |
| reduct, | |
| task_best_pattern, | |
| plot_polygons_bokeh, | |
| advantage_text, | |
| data_difference, | |
| calculate_overlap, | |
| circ_easing, | |
| average_advantage_text, | |
| plot_three_polygons_bokeh, | |
| tasks, | |
| metric_tap, | |
| neutral_tasks, pattern_graph, | |
| ) | |
| from text import text1, text2, text3, text4, initial_passage, initial_question, text5 | |
########################################################################################################################
# Basic dimensions
########################################################################################################################
# Layout sizes. text_width is used below as a CSS pixel width ("{text_width}px") and as widget max_width.
# NOTE(review): plot_width, plot_height, sidebar_width, in_text_plot_height and widget_size are not
# referenced elsewhere in this script — confirm whether they can be dropped.
plot_width, plot_height = 1200, 400
sidebar_width, widget_size = 400, 400
in_text_plot_height = 300
text_width = 800
########################################################################################################################
# Patternification widget
########################################################################################################################
# Inputs: a passage, a question and a choice between three patterns, all centred.
passage = TextAreaInput(title="Passage", rows=3, value=initial_passage, max_width=text_width, align="center")
question = TextInput(title="Question", value=initial_question, max_width=text_width, align="center")
radio_button_group = RadioButtonGroup(
    labels=["Pattern 1", "Pattern 2", "Pattern 3"],
    active=0,
    max_width=text_width,
    align="center",
)

# Inline CSS for the grey read-out boxes (used by the prompt preview and the advantage box).
box_style = {
    "display": "block",
    "margin": "0 auto",
    "width": f"{text_width}px",
    "text-align": "center",
    "white-space": "pre-wrap",
    "background": "#f4f4f4",
    "border": "1px solid #ddd",
    "color": "#666",
    "page-break-inside": "avoid",
    "font-size": "15px",
    "line-height": "1.6",
    "max-width": "100%",
    "overflow": "hidden",
    "min-height": "30px",
    "word-wrap": "break-word",
}

# Preview of the prompt that prompt_boolq builds from the current widget values.
initial_prompt = prompt_boolq(passage.value, question.value, radio_button_group.active)
prompt_box = Div(
    text=initial_prompt,
    width=text_width,
    style=box_style,
    sizing_mode="scale_width",
    align="center",
)
def update_prompt(attrname, old, new):
    """Bokeh ``on_change`` callback: re-render the prompt preview from the widgets.

    The ``attrname``/``old``/``new`` arguments Bokeh passes are ignored. The current
    values are read directly from ``passage``, ``question`` and ``radio_button_group``.
    """
    current_passage = passage.value
    current_question = question.value
    pattern_index = radio_button_group.active
    prompt_box.text = prompt_boolq(current_passage, current_question, pattern_index)
# Refresh the prompt preview whenever any of the three inputs changes.
for _widget, _watched_property in ((passage, "value"), (question, "value"), (radio_button_group, "active")):
    _widget.on_change(_watched_property, update_prompt)

patternification = column(
    passage,
    question,
    radio_button_group,
    prompt_box,
    sizing_mode="scale_width",
    align="center",
)
########################################################################################################################
# Advantage diagram
########################################################################################################################
# Parallel lists: entry i of each one belongs to the i-th tab of ``advantage_all_figures``.
advantage_plots_per_task = []
overlapping_range_per_task = []
training_points_per_task = []
clf_results_per_task = []
pvp_results_per_task = []
advantage_tabs = []
advantage_box = Div(
    text="Click within the comparison region to compute the data advantage for a performance level",
    width=text_width,
    style=box_style,
    sizing_mode="scale_width",
    align="center",
)


def _add_advantage_tab(task, title, training_points, classifier_performances, pattern_performances,
                       x_log_scale=False):
    """Build the classifier-vs-pattern plot for ``task`` and append it as a tab titled ``title``.

    Appends one entry to each ``*_per_task`` list and to ``advantage_tabs``. It also
    registers a Tap handler that calls ``metric_tap`` with this tab's own data.

    Args:
        task: task name passed to ``reduct``/``plot_polygons_bokeh``; also the key into ``task_best_pattern``.
        title: tab title.
        training_points: iterable of training-set sizes; a fresh list copy is stored.
        classifier_performances, pattern_performances: raw results as returned by ``get_data``.
        x_log_scale: when True, ask ``plot_polygons_bokeh`` for a logarithmic x axis.
    """
    index = len(advantage_plots_per_task)
    training_points_per_task.append(list(training_points))
    clf_results_per_task.append(reduct(classifier_performances, "accmax"))
    pvp_results_per_task.append(reduct(pattern_performances, "accmax", task_best_pattern[task], "normal"))
    # Pass x_log_scale only when requested, so the regular tab keeps plot_polygons_bokeh's own default.
    extra_options = {"x_log_scale": True} if x_log_scale else {}
    plot = plot_polygons_bokeh(
        task, training_points_per_task[index], clf_results_per_task[index], pvp_results_per_task[index],
        clf_colors, pvp_colors, **extra_options
    )
    plot.align = "center"
    plot.add_tools(CrosshairTool(dimensions="width", line_alpha=0.2))
    advantage_plots_per_task.append(plot)
    overlapping_range_per_task.append(calculate_overlap(clf_results_per_task[index], pvp_results_per_task[index]))
    advantage_tabs.append(Panel(child=plot, title=title))
    # Bind this tab's index as a default argument (avoids the late-binding-closure pitfall). The handler
    # then always uses its own plot's data instead of whatever tab is active when the tap arrives.
    plot.on_event(
        Tap,
        lambda event, i=index: metric_tap(
            event,
            overlapping_range_per_task[i],
            training_points_per_task[i],
            clf_results_per_task[i],
            pvp_results_per_task[i],
            advantage_box,
            advantage_plots_per_task[i],
        ),
    )


for task in tasks:
    training_points, classifier_performances, pattern_performances = get_data(task)
    # Materialise once so the extra MNLI tab below gets the same points even if get_data returned an iterator.
    training_points = list(training_points)
    _add_advantage_tab(task, task, training_points, classifier_performances, pattern_performances)
    if task == "MNLI":
        # MNLI additionally gets a tab with a logarithmic x axis.
        _add_advantage_tab(task, "MNLI (log scale)", training_points, classifier_performances,
                           pattern_performances, x_log_scale=True)

# Built once every panel exists; the Tap handlers no longer need this name at construction time.
advantage_all_figures = Tabs(tabs=advantage_tabs, align="center")
def on_integrate_click():
    """Sweep the active tab's comparison region and report the average data advantage.

    ``overlapping_range_per_task[tab]`` (low = element 0, high = element 1) is sampled at
    ``frames - 1`` interior points; both endpoints are skipped. ``data_difference`` is evaluated
    at each sample. The dashed marker line on the plot is moved through the samples and coloured
    ``clf_colors[0]`` for a negative advantage, ``pvp_colors[0]`` otherwise. The mean advantage is
    written to ``advantage_box`` via ``average_advantage_text``.
    """
    frames = 200
    # Read the active tab once; every lookup below refers to the same tab.
    tab = advantage_all_figures.active
    overlap = overlapping_range_per_task[tab]
    points = training_points_per_task[tab]
    clf_results = clf_results_per_task[tab]
    pvp_results = pvp_results_per_task[tab]
    plot = advantage_plots_per_task[tab]
    low, high = overlap[0], overlap[1]

    def line_color_for(advantage):
        # Colour convention for the marker: classifier colour when the advantage is negative.
        return clf_colors[0] if advantage < 0 else pvp_colors[0]

    metric_values = [low + (high - low) * (i / frames) for i in range(1, frames)]
    advantages = [data_difference(value, overlap, points, clf_results, pvp_results) for value in metric_values]

    # Reuse the marker if it is the plot's last renderer; otherwise add one. The emptiness
    # check avoids an IndexError on a plot with no renderers. The marker starts at the first
    # sample, coloured by that sample's advantage sign.
    if plot.renderers and isinstance(plot.renderers[-1], Span):
        metric_line = plot.renderers[-1]
        metric_line.location = metric_values[0]
        metric_line.line_color = line_color_for(advantages[0])
    else:
        metric_line = Span(
            location=metric_values[0],
            line_alpha=0.7,
            dimension="width",
            line_color=line_color_for(advantages[0]),
            line_dash="dashed",
            line_width=1,
        )
        plot.renderers.extend([metric_line])

    for value, advantage in zip(metric_values, advantages):
        metric_line.location = value
        metric_line.line_color = line_color_for(advantage)

    advantage_box.text = average_advantage_text(sum(advantages) / len(advantages))
# Button that runs on_integrate_click for the active advantage tab.
integrate = Button(
    label="Integrate over the whole region!",
    width=175,
    max_width=175,
    align="center",
)
integrate.on_click(on_integrate_click)
def on_tab_change(attr, old, new):
    """Reset the advantage read-out to its instructions when the active advantage tab changes.

    Used as the ``active`` change callback of ``advantage_all_figures``; the Bokeh callback
    arguments are not used.
    """
    instructions = "Click within the comparison region to compute the data advantage for a performance level"
    advantage_box.text = instructions
advantage_all_figures.on_change("active", on_tab_change)

# Advantage section: tabbed plots, then the read-out box, then the integrate button.
advantage_column = column(
    advantage_all_figures,
    advantage_box,
    integrate,
    sizing_mode="scale_width",
)
########################################################################################################################
# Null verbalizer diagram
########################################################################################################################
null_tabs = []


def _null_verbalizer_panel(task, title, training_points, clf_results, pvp_results, ctl_results, x_log_scale=False):
    """Return a tab ``Panel`` titled ``title`` holding the three-polygon plot for ``task``.

    The plot compares the classifier, pattern ("normal") and control ("neutral") results. It
    is centred and given a horizontal crosshair. When ``x_log_scale`` is True,
    ``plot_three_polygons_bokeh`` is asked for a logarithmic x axis.
    """
    # Pass x_log_scale only when requested, so the regular tab keeps the plotting helper's own default.
    extra_options = {"x_log_scale": True} if x_log_scale else {}
    plot = plot_three_polygons_bokeh(
        task, training_points, clf_results, pvp_results, ctl_results, clf_colors, pvp_colors, ctl_colors,
        **extra_options
    )
    plot.align = "center"
    plot.add_tools(CrosshairTool(dimensions="width", line_alpha=0.2))
    return Panel(child=plot, title=title)


for task in neutral_tasks:
    training_points, classifier_performances, pattern_performances = get_data(task)
    training_points = list(training_points)
    clf_results = reduct(classifier_performances, "accmax")
    pvp_results = reduct(pattern_performances, "accmax", task_best_pattern[task], "normal")
    ctl_results = reduct(pattern_performances, "accmax", task_best_pattern[task], "neutral")
    null_tabs.append(_null_verbalizer_panel(task, task, training_points, clf_results, pvp_results, ctl_results))
    if task == "MNLI":
        # MNLI additionally gets a tab with a logarithmic x axis, built from the same results.
        null_tabs.append(_null_verbalizer_panel(task, "MNLI (log scale)", training_points, clf_results,
                                                pvp_results, ctl_results, x_log_scale=True))

# Built once every panel exists (no throw-away placeholder Tabs).
null_all_figures = Tabs(tabs=null_tabs, align="center")
########################################################################################################################
# Patterns diagram
########################################################################################################################
pattern_tabs = []
for task in tasks:
    # One centred pattern_graph plot with a horizontal crosshair per task.
    pattern_plot = pattern_graph(task)
    pattern_plot.align = "center"
    pattern_plot.add_tools(CrosshairTool(dimensions="width", line_alpha=0.2))
    pattern_tabs.append(Panel(child=pattern_plot, title=task))

# Built once every panel exists (no throw-away placeholder Tabs).
pattern_all_figures = Tabs(tabs=pattern_tabs, align="center")
########################################################################################################################
# Add write-up text
########################################################################################################################
# Inline CSS shared by the five prose sections of the write-up.
main_text_style = {
    "min-height": "100px",
    "overflow": "hidden",
    "display": "block",
    "margin": "auto",
    "width": f"{text_width}px",
    "font-size": "18px",
}

# One centred Div per prose section, created in order text1 .. text5.
textbox1, textbox2, textbox3, textbox4, textbox5 = (
    Div(text=section, style=main_text_style, align="center")
    for section in (text1, text2, text3, text4, text5)
)
########################################################################################################################
# Set up layouts and add to document
########################################################################################################################
# Page order: prose sections interleaved with the interactive sections.
main_body = column(
    textbox1,
    patternification,
    textbox2,
    advantage_column,
    textbox3,
    null_all_figures,
    textbox4,
    pattern_all_figures,
    textbox5,
    sizing_mode="scale_width",
    align="center",
)

doc = curdoc()
doc.add_root(main_body)
doc.title = "How many data points is a prompt worth ?"