This repository was archived by the owner on Nov 1, 2022. It is now read-only.

Modify the way the results are shown #4

Merged

Merged 3 commits on Jan 3, 2022
Changes from all commits
37 changes: 15 additions & 22 deletions app.py
@@ -2,8 +2,6 @@
 from __future__ import annotations
 
 import functools
-from collections import defaultdict
-from itertools import chain
 from typing import Any
 from typing import Callable
 from typing import Mapping
@@ -13,7 +11,6 @@
 import environ
 import fasttext # not working with python3.9
 import gradio as gr
-from tokenizers.pre_tokenizers import Whitespace
 from transformers.pipelines import pipeline
 from transformers.pipelines.base import Pipeline
 from transformers.pipelines.token_classification import AggregationStrategy
@@ -127,8 +124,8 @@ def predict(
     supported_languages: tuple[str, ...] = ("fr", "de"),
 ) -> tuple[
     Mapping[str, float],
-    str,
     Mapping[str, float],
+    str,
     Sequence[tuple[str, str | None]],
     Sequence[tuple[str, str | None]],
 ]:
@@ -189,27 +186,23 @@ def extract_entities(
         predict_fn: Callable,
         query: str,
     ) -> Sequence[tuple[str, str | None]]:
-        def get_entity(pred: Mapping[str, str]):
-            return pred.get("entity", pred.get("entity_group", None))
 
-        mapping = defaultdict(lambda: None)
-        mapping.update(**{pred["word"]: get_entity(pred) for pred in predict_fn(query)})
+        predictions = predict_fn(query)
 
-        query_processed = Whitespace().pre_tokenize_str(query)
-        res = tuple(
-            chain.from_iterable(
-                ((word, mapping[word]), (" ", None)) for word, _ in query_processed
-            ),
-        )
-        print(res)
-        return res
+        if len(predictions) == 0:
+            return [(query, None)]
+        else:
+            return [
+                (pred["word"], pred.get("entity_group", pred.get("entity", None)))
+                for pred in predictions
+            ]
 
     languages = predict_lang(query)
     translation = translate_query(query, languages)
     classifications = classify_query(translation, categories)
     general_entities = extract_entities(models.ner, query)
     recipe_entities = extract_entities(models.recipe, translation)
-    return languages, translation, classifications, general_entities, recipe_entities
+    return languages, classifications, translation, general_entities, recipe_entities
 
 
 def main():
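
For reference, here is a minimal, self-contained sketch of what the reworked extract_entities returns. It replays the new logic on a hand-written prediction list instead of calling a real transformers pipeline, so the sample words, labels and scores are illustrative only; the dict keys mirror what the code actually reads.

from __future__ import annotations

from typing import Any, Mapping, Sequence


def extract_entities_demo(
    predictions: Sequence[Mapping[str, Any]],
    query: str,
) -> Sequence[tuple[str, str | None]]:
    # Same branching as the new extract_entities, but taking predictions directly.
    if len(predictions) == 0:
        # No entities found: hand the whole query back with no highlight.
        return [(query, None)]
    return [
        # Aggregated pipelines expose "entity_group"; token-level ones expose "entity".
        (pred["word"], pred.get("entity_group", pred.get("entity", None)))
        for pred in predictions
    ]


# Hand-written sample shaped like an aggregated "ner" pipeline result.
sample = [{"word": "chocolate cake", "entity_group": "FOOD", "score": 0.98}]

print(extract_entities_demo(sample, "chocolate cake recipe"))
# [('chocolate cake', 'FOOD')]
print(extract_entities_demo([], "hello"))
# [('hello', None)]

Each (word, label) pair feeds straight into gr.outputs.HighlightedText, and the (query, None) fallback keeps the widget from rendering an empty result when no entity is detected.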
@@ -254,7 +247,7 @@ def extract_commas_separated_values(value: str) -> Sequence[str]:
             load_fn=lambda: pipeline(
                 "ner",
                 model=cfg.ner.general,
-                aggregation_strategy=AggregationStrategy.MAX,
+                aggregation_strategy=AggregationStrategy.SIMPLE,
             ),
         ),
         recipe=Predictor(
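
A note on the strategy swap: both AggregationStrategy.MAX and AggregationStrategy.SIMPLE make the token-classification pipeline merge sub-word predictions into entity spans reported under the "entity_group" key, which is exactly the key the reworked extract_entities reads; SIMPLE is the less strict variant that does not require word-aligned tokenization. A hedged sketch of the call, with a placeholder checkpoint standing in for cfg.ner.general:

from transformers.pipelines import pipeline
from transformers.pipelines.token_classification import AggregationStrategy

# "dslim/bert-base-NER" is an illustrative stand-in; the app loads the model
# name from its configuration (cfg.ner.general).
ner = pipeline(
    "ner",
    model="dslim/bert-base-NER",
    aggregation_strategy=AggregationStrategy.SIMPLE,  # the string "simple" also works
)

for pred in ner("Angela Merkel visited Paris"):
    # With aggregation enabled, each prediction carries "entity_group", "word",
    # "score", "start" and "end" instead of per-token "entity" labels.
    print(pred["word"], pred["entity_group"], round(float(pred["score"]), 3))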
@@ -282,15 +275,15 @@ def extract_commas_separated_values(value: str) -> Sequence[str]:
                 type="auto",
                 label="Language identification",
             ),
-            gr.outputs.Textbox(
-                label="English query",
-                type="auto",
-            ),
             gr.outputs.Label(
                 num_top_classes=cfg.classification.max_results,
                 type="auto",
                 label="Predicted categories",
             ),
+            gr.outputs.Textbox(
+                label="English query",
+                type="auto",
+            ),
             gr.outputs.HighlightedText(label="NER generic"),
             gr.outputs.HighlightedText(label="NER Recipes"),
         ],
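
Because gradio's Interface maps the values returned by fn onto the output components by position, moving the English-query Textbox below the category Label mirrors the new return order of predict (languages, classifications, translation, general_entities, recipe_entities). A minimal sketch under that assumption, with a fake predict function and the same legacy gr.inputs/gr.outputs API the app uses:

import gradio as gr


def fake_predict(query: str):
    # Dummy values standing in for the real models; only the shapes matter here.
    languages = {"fr": 0.9, "de": 0.1}  # -> Label
    classifications = {"recipes": 0.8}  # -> Label
    translation = "translated query"  # -> Textbox
    general_entities = [(query, None)]  # -> HighlightedText
    recipe_entities = [(query, None)]  # -> HighlightedText
    return languages, classifications, translation, general_entities, recipe_entities


iface = gr.Interface(
    fn=fake_predict,
    inputs=gr.inputs.Textbox(label="Query"),
    outputs=[
        gr.outputs.Label(label="Language identification"),
        gr.outputs.Label(label="Predicted categories"),
        gr.outputs.Textbox(label="English query"),
        gr.outputs.HighlightedText(label="NER generic"),
        gr.outputs.HighlightedText(label="NER Recipes"),
    ],
)

if __name__ == "__main__":
    iface.launch()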