Update app.py
app.py CHANGED
@@ -1,6 +1,6 @@
+
 # Importing required libraries
 import warnings
-
 warnings.filterwarnings("ignore")
 
 import os
@@ -16,7 +16,6 @@ import spaces
 # Download gguf model files (Simplified for the specified models)
 huggingface_token = os.getenv("HUGGINGFACE_TOKEN")  # Ensure token is set
 
-
 def download_model(repo_id, filename):
     try:
         hf_hub_download(
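The body of download_model is cut off in this view; a plausible minimal sketch of such a helper, assuming hf_hub_download from huggingface_hub (the local_dir and token arguments are real huggingface_hub API; the error handling shown is an assumption):

# Sketch only: one plausible shape for the truncated helper.
import os
import logging

from huggingface_hub import hf_hub_download

def download_model(repo_id, filename):
    try:
        hf_hub_download(
            repo_id=repo_id,
            filename=filename,
            local_dir="./models",                  # keep all GGUF files together
            token=os.getenv("HUGGINGFACE_TOKEN"),  # needed for gated repos
        )
    except Exception as e:
        logging.error(f"Failed to download {filename} from {repo_id}: {e}")
        raise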
@@ -37,17 +36,16 @@ if not os.path.exists("./models/google.gemma-3-1b-pt.Q4_K_M.gguf"):
 if not os.path.exists("./models/google.gemma-3-12b-pt.Q4_K_M.gguf"):
     download_model("DevQuasar/google.gemma-3-12b-pt-GGUF", "google.gemma-3-12b-pt.Q4_K_M.gguf")
 if not os.path.exists("./models/google.gemma-3-4b-pt.Q4_K_M.gguf"):  # Example from original, in case needed.
-
+    download_model("DevQuasar/google.gemma-3-4b-pt-GGUF", "google.gemma-3-4b-pt.Q4_K_M.gguf")
 
 
 # Set the title and description
-title = "Gemma
-description = """Gemma models for
+title = "Gemma Text Generation"
+description = """Gemma models for text generation and notebook continuation. This interface is designed for generating text continuations, not for interactive chat."""
 
 llm = None
 llm_model = None
 
-
 @spaces.GPU
 def generate_text(
     prompt: str,
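The @spaces.GPU decorator visible in the hunk above comes from the `spaces` package preinstalled on Hugging Face Spaces; on ZeroGPU hardware it attaches a GPU only for the duration of the decorated call. A minimal sketch of the pattern, with a placeholder body (the optional duration argument is real ZeroGPU API; the function name is illustrative):

import spaces

@spaces.GPU(duration=120)  # optional: seconds of GPU time to reserve per call
def heavy_inference(prompt: str) -> str:
    # On a ZeroGPU Space, CUDA is available only inside this call.
    return prompt  # placeholder; the real app runs llama.cpp inference here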
@@ -60,6 +58,18 @@ def generate_text(
 ) -> Generator[str, None, None]:
     """
     Generates text based on a prompt, using the specified Gemma model.
+
+    Args:
+        prompt (str): The initial text to continue.
+        model (str): The model file to use (without path).
+        max_tokens (int): Maximum number of tokens to generate.
+        temperature (float): Controls randomness.
+        top_p (float): Nucleus sampling parameter.
+        top_k (int): Top-k sampling parameter.
+        repeat_penalty (float): Penalty for repeating tokens.
+
+    Yields:
+        str: Generated text chunks, streamed as they become available.
     """
     try:
         global llm
@@ -68,7 +78,7 @@ def generate_text(
         model_path = os.path.join("models", model)
         if not os.path.exists(model_path):
             raise FileNotFoundError(f"Model file not found: {model_path}")
-
+
         # Load the model (only if it's a new model)
         if llm is None or llm_model != model:
             logging.info(f"Loading model: {model}")
@@ -76,10 +86,10 @@ def generate_text(
                 model_path=model_path,
                 flash_attn=True,
                 n_gpu_layers=999,  # Adjust based on your GPU availability
-                n_ctx=4096,
-                n_threads=4,
+                n_ctx=4096,  # Context window size. Can increase.
+                n_threads=4,  # Adjust as needed for performance.
                 n_threads_batch=4,
-                verbose=False
+                verbose=False  # Reduce unnecessary verbosity
             )
             llm_model = model
 
@@ -91,8 +101,8 @@ def generate_text(
             top_p=top_p,
             top_k=top_k,
             repeat_penalty=repeat_penalty,
-            stream=True,
-            stop=["<|im_end|>",
+            stream=True,  # Ensure streaming is on
+            stop=["<|im_end|>", "<|endoftext|>", "<|file_separator|>"],  # Add appropriate stop tokens
         ):
             text_chunk = token["choices"][0]["text"]
             yield text_chunk
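Pulling the fragments from the two hunks above together: the llama-cpp-python pattern being built is to construct a Llama once, then call it with stream=True so completion chunks arrive one at a time. A self-contained sketch under that assumption (the model path and sampling values are illustrative):

from llama_cpp import Llama

llm = Llama(
    model_path="models/google.gemma-3-1b-pt.Q4_K_M.gguf",  # illustrative path
    n_gpu_layers=999,  # offload as many layers as fit on the GPU
    n_ctx=4096,
    verbose=False,
)

# Calling the model invokes create_completion(); stream=True yields chunks.
for chunk in llm(
    "The capital of France is",
    max_tokens=64,
    temperature=0.7,
    top_p=0.95,
    top_k=40,
    repeat_penalty=1.1,
    stream=True,
    stop=["<|im_end|>"],
):
    print(chunk["choices"][0]["text"], end="", flush=True)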
@@ -102,8 +112,8 @@ def generate_text(
 
 
 def clear_history():
-
-
+    """Clears the text input."""
+    return ""
 
 
 with gr.Blocks(theme="Ocean", title=title) as demo:
@@ -118,8 +128,13 @@ with gr.Blocks(theme="Ocean", title=title) as demo:
                 lines=10,
             )
             clear_button = gr.Button("Clear Input")
-
-
+
+            output_textbox = gr.Textbox(  # Changed to Textbox for streaming
+                label="Generated Text",
+                lines=10,  # Added lines for better display of longer outputs
+                interactive=False  # Output shouldn't be editable
+            )
+
 
         with gr.Column(scale=1):
             submit_button = gr.Button("Generate", variant="primary")
@@ -177,15 +192,16 @@ with gr.Blocks(theme="Ocean", title=title) as demo:
                 info="Penalize repeated words (higher = less repetition)",
             )
 
-
-
-
-
-
-
+
+    def streaming_output(prompt, model, max_tokens, temperature, top_p, top_k, repeat_penalty):
+        """Wraps the generator for Gradio."""
+        generated_text = ""
+        for text_chunk in generate_text(prompt, model, max_tokens, temperature, top_p, top_k, repeat_penalty):
+            generated_text += text_chunk
+            yield generated_text
 
     submit_button.click(
-
+        streaming_output,
         [
             input_textbox,
             model_dropdown,
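The streaming_output wrapper added above follows Gradio's generator convention: each yield replaces the output component's entire value, so the wrapper accumulates chunks rather than yielding them individually. A self-contained sketch of the same pattern, with an echo function standing in for the model call:

import time
import gradio as gr

def stream_echo(text):
    # Each yielded value REPLACES the textbox contents, so accumulate chunks.
    shown = ""
    for ch in text:
        shown += ch
        time.sleep(0.02)  # stand-in for per-token model latency
        yield shown

with gr.Blocks() as demo:
    inp = gr.Textbox(label="Prompt")
    out = gr.Textbox(label="Output", interactive=False)
    gr.Button("Run").click(stream_echo, inp, out)

if __name__ == "__main__":
    demo.launch()

One caveat: on Gradio 3.x, generator handlers only stream if the event queue is enabled (demo.queue() before launch()); Gradio 4 and later enable the queue by default.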
@@ -195,10 +211,12 @@ with gr.Blocks(theme="Ocean", title=title) as demo:
             top_k_slider,
             repeat_penalty_slider,
         ],
-
+        output_textbox,
     )
-
+
     clear_button.click(clear_history, [], input_textbox)
 
+
+
 if __name__ == "__main__":
-    demo.launch(debug=False, share=False)
+    demo.launch(debug=False, share=False)  # Added share=False for clearer local-only run.
|