Spaces:
Build error
Build error
danseith
commited on
Commit
·
4db26d9
1
Parent(s):
b7321ed
search string generator
Browse files
app.py
CHANGED
|
@@ -1,6 +1,8 @@
|
|
| 1 |
import gradio as gr
|
| 2 |
import numpy as np
|
| 3 |
import torch
|
|
|
|
|
|
|
| 4 |
from transformers import pipeline
|
| 5 |
from transformers.pipelines import PIPELINE_REGISTRY, FillMaskPipeline
|
| 6 |
from transformers import AutoModelForMaskedLM
|
|
@@ -9,38 +11,45 @@ ex_str1 = "A crustless sandwich made from two slices of baked bread. The sandwic
|
|
| 9 |
"crustless bread pieces. The bread pieces have the same general outer shape defined by an outer periphery " \
|
| 10 |
"with central portions surrounded by an outer peripheral area, the bread pieces being at least partially " \
|
| 11 |
"crimped together at the outer peripheral area."
|
|
|
|
| 12 |
|
| 13 |
ex_str2 = "The present disclosure provides a DNA-targeting RNA that comprises a targeting sequence and, together with" \
|
| 14 |
" a modifying polypeptide, provides for site-specific modification of a target DNA and/or a polypeptide" \
|
| 15 |
" associated with the target DNA. "
|
|
|
|
| 16 |
|
| 17 |
ex_str3 = "The graphite plane is composed of a two-dimensional hexagonal lattice of carbon atoms and the plate has a " \
|
| 18 |
"length and a width parallel to the graphite plane and a thickness orthogonal to the graphite plane with at " \
|
| 19 |
"least one of the length, width, and thickness values being 100 nanometers or smaller. "
|
|
|
|
| 20 |
|
| 21 |
-
tab_two_examples = [[ex_str1,
|
| 22 |
-
[ex_str2,
|
| 23 |
-
[ex_str3,
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 24 |
|
| 25 |
-
|
| 26 |
-
['The present disclosure provides a DNA-targeting RNA that comprises a targeting _.'],
|
| 27 |
-
['The _ plane is composed of a two-dimensional hexagonal lattice of carbon atoms.']
|
| 28 |
-
]
|
| 29 |
|
| 30 |
|
| 31 |
-
def add_mask(text):
|
| 32 |
split_text = text.split()
|
| 33 |
-
|
|
|
|
|
|
|
| 34 |
# If the user supplies a mask, don't add more
|
| 35 |
if '_' in split_text:
|
| 36 |
u_pos = [i for i, s in enumerate(split_text) if '_' in s][0]
|
| 37 |
split_text[u_pos] = '[MASK]'
|
| 38 |
return ' '.join(split_text), '[MASK]'
|
| 39 |
|
| 40 |
-
idx = np.random.randint(len(split_text), size=1).astype(int)[0]
|
| 41 |
# Don't mask certain words
|
| 42 |
num_iters = 0
|
| 43 |
-
while split_text[idx].lower() in
|
| 44 |
num_iters += 1
|
| 45 |
idx = np.random.randint(len(split_text), size=1).astype(int)[0]
|
| 46 |
if num_iters > 10:
|
|
@@ -148,6 +157,7 @@ PIPELINE_REGISTRY.register_pipeline(
|
|
| 148 |
)
|
| 149 |
scrambler = pipeline("temp-scale", model="anferico/bert-for-patents")
|
| 150 |
|
|
|
|
| 151 |
|
| 152 |
def sample_output(out, sampling):
|
| 153 |
score_to_str = {out[k]: k for k in out.keys()}
|
|
@@ -167,10 +177,10 @@ def unmask_single(text, temp=1):
|
|
| 167 |
return out
|
| 168 |
|
| 169 |
|
| 170 |
-
def unmask(text, temp, rounds):
|
| 171 |
sampling = 'multi'
|
| 172 |
for _ in range(rounds):
|
| 173 |
-
masked_text, masked = add_mask(text)
|
| 174 |
split_text = masked_text.split()
|
| 175 |
res = scrambler(masked_text, temp=temp, top_k=15)
|
| 176 |
mask_pos = [i for i, t in enumerate(split_text) if 'MASK' in t][0]
|
|
@@ -194,51 +204,140 @@ def unmask(text, temp, rounds):
|
|
| 194 |
return ''.join(text)
|
| 195 |
|
| 196 |
|
| 197 |
-
|
| 198 |
-
|
| 199 |
-
|
| 200 |
-
|
| 201 |
-
|
| 202 |
-
|
| 203 |
-
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 204 |
|
| 205 |
-
|
|
|
|
|
|
|
| 206 |
description1 = """<p>
|
|
|
|
|
|
|
|
|
|
|
|
|
| 207 |
This is a model based on
|
| 208 |
<a href= "https://github.com/google/patents-public-data/blob/master/models/BERT%20for%20Patents.md">Patent BERT</a> created by Google.
|
| 209 |
-
|
| 210 |
-
|
| 211 |
-
<strong>Note:</strong> You can only add one '_' per submission.
|
| 212 |
<br/>
|
| 213 |
<p/>"""
|
| 214 |
-
|
| 215 |
-
|
| 216 |
-
|
| 217 |
-
|
| 218 |
-
|
| 219 |
-
|
| 220 |
-
|
| 221 |
-
|
| 222 |
-
|
| 223 |
-
|
| 224 |
-
|
| 225 |
-
|
| 226 |
-
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 227 |
allow_flagging='never',
|
| 228 |
title=title1,
|
| 229 |
description=description1
|
| 230 |
)
|
| 231 |
|
| 232 |
-
|
| 233 |
-
|
| 234 |
-
|
| 235 |
-
|
| 236 |
-
|
| 237 |
-
|
| 238 |
-
|
| 239 |
-
|
| 240 |
-
)
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 241 |
|
| 242 |
gr.TabbedInterface(
|
| 243 |
-
[
|
| 244 |
).launch()
|
|
|
|
| 1 |
import gradio as gr
|
| 2 |
import numpy as np
|
| 3 |
import torch
|
| 4 |
+
from nltk.stem import PorterStemmer
|
| 5 |
+
from collections import defaultdict
|
| 6 |
from transformers import pipeline
|
| 7 |
from transformers.pipelines import PIPELINE_REGISTRY, FillMaskPipeline
|
| 8 |
from transformers import AutoModelForMaskedLM
|
|
|
|
| 11 |
"crustless bread pieces. The bread pieces have the same general outer shape defined by an outer periphery " \
|
| 12 |
"with central portions surrounded by an outer peripheral area, the bread pieces being at least partially " \
|
| 13 |
"crimped together at the outer peripheral area."
|
| 14 |
+
ex_key1 = "sandwich bread crimped"
|
| 15 |
|
| 16 |
ex_str2 = "The present disclosure provides a DNA-targeting RNA that comprises a targeting sequence and, together with" \
|
| 17 |
" a modifying polypeptide, provides for site-specific modification of a target DNA and/or a polypeptide" \
|
| 18 |
" associated with the target DNA. "
|
| 19 |
+
ex_key2 = "DNA target modification"
|
| 20 |
|
| 21 |
ex_str3 = "The graphite plane is composed of a two-dimensional hexagonal lattice of carbon atoms and the plate has a " \
|
| 22 |
"length and a width parallel to the graphite plane and a thickness orthogonal to the graphite plane with at " \
|
| 23 |
"least one of the length, width, and thickness values being 100 nanometers or smaller. "
|
| 24 |
+
ex_key3 = "graphite lattice orthogonal "
|
| 25 |
|
| 26 |
+
tab_two_examples = [[ex_str1, ex_key1],
|
| 27 |
+
[ex_str2, ex_key2],
|
| 28 |
+
[ex_str3, ex_key3]]
|
| 29 |
+
#
|
| 30 |
+
# tab_one_examples = [['A crustless _ made from two slices of baked bread.'],
|
| 31 |
+
# ['The present disclosure provides a DNA-targeting RNA that comprises a targeting _.'],
|
| 32 |
+
# ['The _ plane is composed of a two-dimensional hexagonal lattice of carbon atoms.']
|
| 33 |
+
# ]
|
| 34 |
|
| 35 |
+
ignore = ['a', 'an', 'the', 'is', 'and', 'or']
|
|
|
|
|
|
|
|
|
|
| 36 |
|
| 37 |
|
| 38 |
+
def add_mask(text, lower_bound=0, index=None):
|
| 39 |
split_text = text.split()
|
| 40 |
+
if index is not None:
|
| 41 |
+
split_text[index] = '[MASK]'
|
| 42 |
+
return ' '.join(split_text), None
|
| 43 |
# If the user supplies a mask, don't add more
|
| 44 |
if '_' in split_text:
|
| 45 |
u_pos = [i for i, s in enumerate(split_text) if '_' in s][0]
|
| 46 |
split_text[u_pos] = '[MASK]'
|
| 47 |
return ' '.join(split_text), '[MASK]'
|
| 48 |
|
| 49 |
+
idx = np.random.randint(low=lower_bound, high=len(split_text), size=1).astype(int)[0]
|
| 50 |
# Don't mask certain words
|
| 51 |
num_iters = 0
|
| 52 |
+
while split_text[idx].lower() in ignore:
|
| 53 |
num_iters += 1
|
| 54 |
idx = np.random.randint(len(split_text), size=1).astype(int)[0]
|
| 55 |
if num_iters > 10:
|
|
|
|
| 157 |
)
|
| 158 |
scrambler = pipeline("temp-scale", model="anferico/bert-for-patents")
|
| 159 |
|
| 160 |
+
generator = pipeline('text-generation', model='gpt2')
|
| 161 |
|
| 162 |
def sample_output(out, sampling):
|
| 163 |
score_to_str = {out[k]: k for k in out.keys()}
|
|
|
|
| 177 |
return out
|
| 178 |
|
| 179 |
|
| 180 |
+
def unmask(text, temp, rounds, lower_bound=0):
|
| 181 |
sampling = 'multi'
|
| 182 |
for _ in range(rounds):
|
| 183 |
+
masked_text, masked = add_mask(text, lower_bound)
|
| 184 |
split_text = masked_text.split()
|
| 185 |
res = scrambler(masked_text, temp=temp, top_k=15)
|
| 186 |
mask_pos = [i for i, t in enumerate(split_text) if 'MASK' in t][0]
|
|
|
|
| 204 |
return ''.join(text)
|
| 205 |
|
| 206 |
|
| 207 |
+
def autocomplete(text, temp):
|
| 208 |
+
output = generator(text, max_length=30, num_return_sequences=1)
|
| 209 |
+
gpt_out = output[0]['generated_text']
|
| 210 |
+
# diff = gpt_out.replace(text, '')
|
| 211 |
+
patent_bert_out = unmask(gpt_out, temp=temp, rounds=5, lower_bound=len(text.split()))
|
| 212 |
+
# Take the output from gpt-2 and randomly mask, if a mask is confident, swap it in. Iterate 5 times
|
| 213 |
+
return patent_bert_out
|
| 214 |
+
|
| 215 |
+
|
| 216 |
+
def extract_keywords(text, queries):
|
| 217 |
+
q_dict = {}
|
| 218 |
+
temp = 1 # set temperature to 1
|
| 219 |
+
for query in queries.split():
|
| 220 |
+
# Iterate through text and mask each token
|
| 221 |
+
ps = PorterStemmer()
|
| 222 |
+
top_scores = defaultdict(list)
|
| 223 |
+
top_k_range = 10
|
| 224 |
+
indices = [i for i, t in enumerate(text.split()) if t.lower() == query.lower()]
|
| 225 |
+
for i in indices:
|
| 226 |
+
masked_text, masked = add_mask(text, index=i)
|
| 227 |
+
res = scrambler(masked_text, temp=temp, top_k=top_k_range)
|
| 228 |
+
out = {item["token_str"]: item["score"] for item in res}
|
| 229 |
+
sorted_keys = sorted(out, key=out.get)
|
| 230 |
+
# If the key does not appear, floor its rank for that round
|
| 231 |
+
for rank, token_str in enumerate(sorted_keys):
|
| 232 |
+
stemmed = ps.stem(token_str)
|
| 233 |
+
if token_str not in top_scores.keys():
|
| 234 |
+
top_scores[stemmed].append(0)
|
| 235 |
+
norm_rank = rank / top_k_range
|
| 236 |
+
top_scores[stemmed].append(norm_rank)
|
| 237 |
+
|
| 238 |
+
# Calc mean
|
| 239 |
+
for key in top_scores.keys():
|
| 240 |
+
top_scores[key] = np.mean(top_scores[key])
|
| 241 |
+
# Normalize
|
| 242 |
+
for key in top_scores.keys():
|
| 243 |
+
top_scores[key] = top_scores[key] / np.sum(list(top_scores.values()))
|
| 244 |
+
# Get top_k
|
| 245 |
+
top_n = sorted(list(top_scores.values()))[-3]
|
| 246 |
+
for key in list(top_scores.keys()):
|
| 247 |
+
if top_scores[key] < top_n:
|
| 248 |
+
del top_scores[key]
|
| 249 |
+
q_dict[query] = top_scores
|
| 250 |
+
|
| 251 |
+
keywords = ''
|
| 252 |
+
for i, q in enumerate(q_dict.keys()):
|
| 253 |
+
keywords += '['
|
| 254 |
+
for ii, k in enumerate(q_dict[q].keys()):
|
| 255 |
+
keywords += k
|
| 256 |
+
if ii < len(q_dict[q].keys()) - 1:
|
| 257 |
+
keywords += ' OR '
|
| 258 |
+
else:
|
| 259 |
+
keywords += ']'
|
| 260 |
+
if i < len(q_dict.keys()) - 1:
|
| 261 |
+
keywords += ' AND '
|
| 262 |
+
# keywords = set([k for q in q_dict.keys() for k in q_dict[q].keys()])
|
| 263 |
+
# search_str = ' OR '.join(keywords)
|
| 264 |
+
output = [q_dict[q] for q in q_dict]
|
| 265 |
+
output.append(keywords)
|
| 266 |
+
return output
|
| 267 |
+
# fig, ax = plt.subplots(nrows=1, ncols=3)
|
| 268 |
+
# for q in q_dict:
|
| 269 |
+
# ax.bar(q_dict[q])
|
| 270 |
+
# return fig
|
| 271 |
+
|
| 272 |
+
label0 = gr.Label(label='keyword 1', num_top_classes=3)
|
| 273 |
+
label01 = gr.Label(label='keyword 2', num_top_classes=3)
|
| 274 |
+
label02 = gr.Label(label='keyword 3', num_top_classes=3)
|
| 275 |
+
textbox02 = gr.Textbox(label="Input Keywords", lines=3)
|
| 276 |
+
textbox01 = gr.Textbox(label="Input Keywords", placeholder="Type keywords here", lines=1)
|
| 277 |
+
textbox0 = gr.Textbox(label="Input Sentences", placeholder="Type sentences here", lines=5)
|
| 278 |
+
|
| 279 |
+
output_textbox0 = gr.Textbox(label='Search String of Keywords', placeholder="Output will appear here", lines=4)
|
| 280 |
+
# temp_slider0 = gr.Slider(1.0, 3.0, value=1.0, label='Creativity')
|
| 281 |
|
| 282 |
+
textbox1 = gr.Textbox(label="Input Sentence", lines=5)
|
| 283 |
+
# output_textbox1 = gr.Textbox(placeholder="Output will appear here", lines=4)
|
| 284 |
+
title1 = "Patent-BERT: Context-Dependent Synonym Generator"
|
| 285 |
description1 = """<p>
|
| 286 |
+
Try inserting a few sentences from a patent, and pick keywords for the model to analyze. The model will analyze the
|
| 287 |
+
context of the keywords in the sentences and generate the top five most likely candidates for each word.
|
| 288 |
+
Can be used for more creative patent drafting or patent searches using the generated search string.
|
| 289 |
+
|
| 290 |
This is a model based on
|
| 291 |
<a href= "https://github.com/google/patents-public-data/blob/master/models/BERT%20for%20Patents.md">Patent BERT</a> created by Google.
|
| 292 |
+
|
| 293 |
+
<strong>Note:</strong> Current pipeline only allows for three keyword submission.
|
|
|
|
| 294 |
<br/>
|
| 295 |
<p/>"""
|
| 296 |
+
|
| 297 |
+
# textbox2 = gr.Textbox(label="Input Sentences", lines=5)
|
| 298 |
+
# output_textbox2 = gr.Textbox(placeholder="Output will appear here", lines=4)
|
| 299 |
+
# temp_slider2 = gr.Slider(1.0, 3.0, value=1.0, label='Creativity')
|
| 300 |
+
# edit_slider2 = gr.Slider(1, 20, step=1, value=1.0, label='Number of edits')
|
| 301 |
+
|
| 302 |
+
|
| 303 |
+
# title2 = "Patent-BERT Sentence Remix-er: Multiple Edits"
|
| 304 |
+
# description2 = """<p>
|
| 305 |
+
#
|
| 306 |
+
# Try typing in a sentence for the model to remix. Adjust the 'creativity' scale bar to change the
|
| 307 |
+
# the model's confidence in its likely substitutions and the 'number of edits' for the number of edits you want
|
| 308 |
+
# the model to attempt to make. The words substituted in the output sentence will be enclosed in asterisks (e.g., *word*).
|
| 309 |
+
# <br/> <p/> """
|
| 310 |
+
|
| 311 |
+
demo0 = gr.Interface(
|
| 312 |
+
fn=extract_keywords,
|
| 313 |
+
inputs=[textbox0, textbox01],
|
| 314 |
+
outputs=[label0, label01, label02, output_textbox0],
|
| 315 |
+
examples=tab_two_examples,
|
| 316 |
allow_flagging='never',
|
| 317 |
title=title1,
|
| 318 |
description=description1
|
| 319 |
)
|
| 320 |
|
| 321 |
+
# demo1 = gr.Interface(
|
| 322 |
+
# fn=unmask_single,
|
| 323 |
+
# inputs=[textbox1],
|
| 324 |
+
# outputs='label',
|
| 325 |
+
# examples=tab_one_examples,
|
| 326 |
+
# allow_flagging='never',
|
| 327 |
+
# title=title1,
|
| 328 |
+
# description=description1
|
| 329 |
+
# )
|
| 330 |
+
|
| 331 |
+
# demo2 = gr.Interface(
|
| 332 |
+
# fn=unmask,
|
| 333 |
+
# inputs=[textbox2, temp_slider2, edit_slider2],
|
| 334 |
+
# outputs=[output_textbox2],
|
| 335 |
+
# examples=tab_two_examples,
|
| 336 |
+
# allow_flagging='never',
|
| 337 |
+
# title=title2,
|
| 338 |
+
# description=description2
|
| 339 |
+
# )
|
| 340 |
|
| 341 |
gr.TabbedInterface(
|
| 342 |
+
[demo0], ["Keyword generator"]
|
| 343 |
).launch()
|