Bradarr committed on
Commit 37e38a4 · verified · 1 Parent(s): 2542cd8

Update app.py

Files changed (1)
  1. app.py +44 -26
app.py CHANGED
@@ -1,6 +1,6 @@
+
 # Importing required libraries
 import warnings
-
 warnings.filterwarnings("ignore")
 
 import os
@@ -16,7 +16,6 @@ import spaces
 # Download gguf model files (Simplified for the specified models)
 huggingface_token = os.getenv("HUGGINGFACE_TOKEN") # Ensure token is set
 
-
 def download_model(repo_id, filename):
     try:
         hf_hub_download(
@@ -37,17 +36,16 @@ if not os.path.exists("./models/google.gemma-3-1b-pt.Q4_K_M.gguf"):
 if not os.path.exists("./models/google.gemma-3-12b-pt.Q4_K_M.gguf"):
     download_model("DevQuasar/google.gemma-3-12b-pt-GGUF", "google.gemma-3-12b-pt.Q4_K_M.gguf")
 if not os.path.exists("./models/google.gemma-3-4b-pt.Q4_K_M.gguf"): # Example from original, in case needed.
-    download_model("DevQuasar/google.gemma-3-4b-pt-GGUF", "google.gemma-3-4b-pt.Q4_K_M.gguf")
+    download_model("DevQuasar/google.gemma-3-4b-pt-GGUF", "google.gemma-3-4b-pt.Q4_K_M.gguf")
 
 
 # Set the title and description
-title = "Gemma Markdown Generation"
-description = """Gemma models for markdown text generation and notebook continuation. This interface is designed for generating markdown text continuations, not for interactive chat."""
+title = "Gemma Text Generation"
+description = """Gemma models for text generation and notebook continuation. This interface is designed for generating text continuations, not for interactive chat."""
 
 llm = None
 llm_model = None
 
-
 @spaces.GPU
 def generate_text(
     prompt: str,
@@ -60,6 +58,18 @@ def generate_text(
 ) -> Generator[str, None, None]:
     """
     Generates text based on a prompt, using the specified Gemma model.
+
+    Args:
+        prompt (str): The initial text to continue.
+        model (str): The model file to use (without path).
+        max_tokens (int): Maximum number of tokens to generate.
+        temperature (float): Controls randomness.
+        top_p (float): Nucleus sampling parameter.
+        top_k (int): Top-k sampling parameter.
+        repeat_penalty (float): Penalty for repeating tokens.
+
+    Yields:
+        str: Generated text chunks, streamed as they become available.
     """
     try:
         global llm
@@ -68,7 +78,7 @@ def generate_text(
         model_path = os.path.join("models", model)
         if not os.path.exists(model_path):
             raise FileNotFoundError(f"Model file not found: {model_path}")
-
+
         # Load the model (only if it's a new model)
         if llm is None or llm_model != model:
             logging.info(f"Loading model: {model}")
@@ -76,10 +86,10 @@ def generate_text(
                 model_path=model_path,
                 flash_attn=True,
                 n_gpu_layers=999, # Adjust based on your GPU availability
-                n_ctx=4096, # Context window size. Can increase.
-                n_threads=4, # Adjust as needed for performance.
+                n_ctx=4096, # Context window size. Can increase.
+                n_threads=4, # Adjust as needed for performance.
                 n_threads_batch=4,
-                verbose=False, # Reduce unnecessary verbosity
+                verbose=False # Reduce unnecessary verbosity
             )
             llm_model = model
 
@@ -91,8 +101,8 @@ def generate_text(
             top_p=top_p,
             top_k=top_k,
             repeat_penalty=repeat_penalty,
-            stream=True, # Ensure streaming is on
-            stop=["<|im_end|>", "<|endoftext|>", "<|file_separator|>"], # Add appropriate stop tokens
+            stream=True, # Ensure streaming is on
+            stop=["<|im_end|>", "<|endoftext|>", "<|file_separator|>"], # Add appropriate stop tokens
         ):
             text_chunk = token["choices"][0]["text"]
             yield text_chunk
@@ -102,8 +112,8 @@
 
 
 def clear_history():
-    """Clears the text input."""
-    return ""
+    """Clears the text input."""
+    return ""
 
 
 with gr.Blocks(theme="Ocean", title=title) as demo:
@@ -118,8 +128,13 @@ with gr.Blocks(theme="Ocean", title=title) as demo:
                 lines=10,
             )
             clear_button = gr.Button("Clear Input")
-
-            output_markdown = gr.Markdown(label="Generated Markdown")
+
+            output_textbox = gr.Textbox( # Changed to Textbox for streaming
+                label="Generated Text",
+                lines=10, # Added lines for better display of longer outputs
+                interactive=False # Output shouldn't be editable
+            )
+
 
         with gr.Column(scale=1):
             submit_button = gr.Button("Generate", variant="primary")
@@ -177,15 +192,16 @@ with gr.Blocks(theme="Ocean", title=title) as demo:
                 info="Penalize repeated words (higher = less repetition)",
            )
 
-    def streaming_markdown_output(prompt, model, max_tokens, temperature, top_p, top_k, repeat_penalty):
-        """Wraps the generator for Gradio and renders as Markdown."""
-        generated_markdown = ""
-        for text_chunk in generate_text(prompt, model, max_tokens, temperature, top_p, top_k, repeat_penalty):
-            generated_markdown += text_chunk
-            yield generated_markdown
+
+    def streaming_output(prompt, model, max_tokens, temperature, top_p, top_k, repeat_penalty):
+        """Wraps the generator for Gradio."""
+        generated_text = ""
+        for text_chunk in generate_text(prompt, model, max_tokens, temperature, top_p, top_k, repeat_penalty):
+            generated_text += text_chunk
+            yield generated_text
 
     submit_button.click(
-        streaming_markdown_output,
+        streaming_output,
         [
             input_textbox,
             model_dropdown,
@@ -195,10 +211,12 @@ with gr.Blocks(theme="Ocean", title=title) as demo:
             top_k_slider,
             repeat_penalty_slider,
         ],
-        output_markdown,
+        output_textbox,
     )
-
+
     clear_button.click(clear_history, [], input_textbox)
 
+
+
 if __name__ == "__main__":
-    demo.launch(debug=False, share=False)
+    demo.launch(debug=False, share=False) # Added share=False for clearer local-only run.
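
For reference, a minimal, self-contained sketch of the generator-streaming pattern this revision adopts: a plain Python generator accumulates text chunks and yields the growing string, and Gradio re-renders the bound Textbox on every yield. The llama.cpp call is replaced by a dummy token source here (fake_token_stream is illustrative, not part of app.py), so the snippet runs without a GPU or downloaded models.

# Sketch: stream text into a Gradio Textbox via a generator function.
import time
import gradio as gr

def fake_token_stream(prompt: str):
    # Stand-in for the llm(...) call: echo the prompt back one word at a time.
    for word in prompt.split():
        time.sleep(0.1)
        yield word + " "

def streaming_output(prompt: str):
    """Accumulate chunks and yield the growing string; Gradio updates the output on each yield."""
    generated_text = ""
    for chunk in fake_token_stream(prompt):
        generated_text += chunk
        yield generated_text

with gr.Blocks() as demo:
    input_textbox = gr.Textbox(label="Prompt", lines=3)
    output_textbox = gr.Textbox(label="Generated Text", lines=5, interactive=False)
    gr.Button("Generate").click(streaming_output, input_textbox, output_textbox)

if __name__ == "__main__":
    demo.launch()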