Spaces:
Runtime error
Runtime error
| ## LIBRARIES ### | |
| from cProfile import label | |
| from tkinter import font | |
| from turtle import width | |
| import streamlit as st | |
| import pandas as pd | |
| from datetime import datetime | |
| import plotly.express as px | |
def select_plot_data(df: pd.DataFrame, quantile_low: float, qunatile_high: float) -> pd.DataFrame:
    """Keep only the models whose mean usage lies strictly between two quantiles.

    The frame is transposed so that every value of the ``Model`` column
    becomes a column of its own, and the resulting row labels are replaced by
    the output of ``date_range``. The per-model mean is then compared with the
    requested quantiles of all per-model means.

    Args:
        df: Table with a ``Model`` column; the remaining column labels must be
            accepted by ``date_range``. Missing values are treated as 0.
        quantile_low: Lower quantile, as passed to ``Series.quantile``
            (exclusive bound).
        qunatile_high: Upper quantile (exclusive bound). The misspelled name
            is kept so existing keyword callers keep working.

    Returns:
        The transposed frame restricted to the selected model columns.
    """
    # Fill on a copy instead of in place so the caller's frame is not modified.
    df_plot = df.fillna(0).set_index('Model').T
    df_plot.index = date_range(df_plot)

    # One mean per model; the quantiles are taken over these means.
    model_means = df_plot.mean()
    lower_bound = model_means.quantile(quantile_low)
    upper_bound = model_means.quantile(qunatile_high)

    # Strict inequalities on both sides, evaluated for all models at once.
    keep = (model_means > lower_bound) & (model_means < upper_bound)
    # A positional boolean array avoids label alignment on the column axis.
    return df_plot.loc[:, keep.to_numpy()]
def read_file_to_df(file):
    """Load a CSV source into a pandas DataFrame.

    Args:
        file: Path or file-like object, handed unchanged to ``pd.read_csv``.

    Returns:
        The DataFrame produced by ``pd.read_csv`` with its default options.
    """
    frame = pd.read_csv(file)
    return frame
def date_range(df):
    """Convert the frame's index labels into short ``M/D/YY`` date strings.

    Each index label must be a string in the ``%Y-%m-%dT%H:%M:%S.%fZ`` format.
    Month and day are written without zero padding and the year is reduced to
    its last two digits.

    Args:
        df: Object whose ``index.to_list()`` yields the timestamp strings.

    Returns:
        A list with one formatted date string per index label, in order.

    Raises:
        ValueError: If a label does not match the expected timestamp format.
    """
    timestamp_format = '%Y-%m-%dT%H:%M:%S.%fZ'
    # Parse every label exactly once, then format from the parsed date.
    parsed_dates = [datetime.strptime(label, timestamp_format).date() for label in df.index.to_list()]
    return [f"{d.month}/{d.day}/{str(d.year)[-2:]}" for d in parsed_dates]
if __name__ == "__main__":
    ### STREAMLIT APP CONFIG ###
    st.set_page_config(layout="wide", page_title="HF Hub Model Usage Visualization")
    st.header("Model Usage Visualization")

    # Collapsible help text explaining the controls and the plots.
    with st.expander("How to read and interact with the plot:"):
        st.markdown("The plots below visualize weekly usage for HF models categorized by the model creation time.")
        st.markdown("Select the model creation time range you want to visualize using the dropdown menu below.")
        st.markdown("Choose the quantile range to filter out models with high or low usage.")
        st.markdown("The plots are interactive. Hover over the points to see the model name and the number of weekly mean usage. Click on the legend to hide/show the models.")

    # Every widget is created with a `key`, so Streamlit already keeps its
    # value in st.session_state under that key; no manual seeding is needed
    # and the returned values can be used directly.
    model_init_year = st.multiselect("Model creation year", ["before_2021", "2021", "2022"], key="model_init_year", default="2022")
    popularity_low = st.slider("Model popularity quantile (lower limit) ", min_value=0.0, max_value=1.0, step=0.01, value=0.90, key="popularity_low")
    popularity_high = st.slider("Model popularity quantile (upper limit) ", min_value=0.0, max_value=1.0, step=0.01, value=0.99, key="popularity_high")

    with st.container():
        # One line chart per selected creation-year bucket.
        for year in model_init_year:
            # Reserve the chart's slot before the data is loaded and filtered.
            plotly_spot = st.empty()
            df = read_file_to_df(f"./assets/{year}/model_usage.csv")
            df_plot_data = select_plot_data(df, popularity_low, popularity_high)
            fig = px.line(df_plot_data, title=f"Models created in {year}", labels={"index": "Weeks", "value": "Usage", "variable": "Model"})
            with plotly_spot:
                st.plotly_chart(fig, use_container_width=True)