mirror of
https://github.com/open-webui/open-webui.git
synced 2025-12-13 21:05:19 +00:00
refac: get_sorted_pipelines()
This commit is contained in:
parent
7ffd75b991
commit
144581a7df
1 changed files with 12 additions and 29 deletions
|
|
@ -764,9 +764,7 @@ app.add_middleware(ChatCompletionMiddleware)
|
||||||
##################################
|
##################################
|
||||||
|
|
||||||
|
|
||||||
def filter_pipeline(payload, user):
|
def get_sorted_filters(model_id):
|
||||||
user = {"id": user.id, "email": user.email, "name": user.name, "role": user.role}
|
|
||||||
model_id = payload["model"]
|
|
||||||
filters = [
|
filters = [
|
||||||
model
|
model
|
||||||
for model in app.state.MODELS.values()
|
for model in app.state.MODELS.values()
|
||||||
|
|
@ -782,6 +780,13 @@ def filter_pipeline(payload, user):
|
||||||
)
|
)
|
||||||
]
|
]
|
||||||
sorted_filters = sorted(filters, key=lambda x: x["pipeline"]["priority"])
|
sorted_filters = sorted(filters, key=lambda x: x["pipeline"]["priority"])
|
||||||
|
return sorted_filters
|
||||||
|
|
||||||
|
|
||||||
|
def filter_pipeline(payload, user):
|
||||||
|
user = {"id": user.id, "email": user.email, "name": user.name, "role": user.role}
|
||||||
|
model_id = payload["model"]
|
||||||
|
sorted_filters = get_sorted_filters(model_id)
|
||||||
|
|
||||||
model = app.state.MODELS[model_id]
|
model = app.state.MODELS[model_id]
|
||||||
|
|
||||||
|
|
@ -814,18 +819,11 @@ def filter_pipeline(payload, user):
|
||||||
print(f"Connection error: {e}")
|
print(f"Connection error: {e}")
|
||||||
|
|
||||||
if r is not None:
|
if r is not None:
|
||||||
try:
|
|
||||||
res = r.json()
|
res = r.json()
|
||||||
except:
|
|
||||||
pass
|
|
||||||
if "detail" in res:
|
if "detail" in res:
|
||||||
raise Exception(r.status_code, res["detail"])
|
raise Exception(r.status_code, res["detail"])
|
||||||
|
|
||||||
else:
|
if "pipeline" not in app.state.MODELS[model_id] and "task" in payload:
|
||||||
pass
|
|
||||||
|
|
||||||
if "pipeline" not in app.state.MODELS[model_id]:
|
|
||||||
if "task" in payload:
|
|
||||||
del payload["task"]
|
del payload["task"]
|
||||||
|
|
||||||
return payload
|
return payload
|
||||||
|
|
@ -1061,22 +1059,7 @@ async def chat_completed(form_data: dict, user=Depends(get_verified_user)):
|
||||||
)
|
)
|
||||||
model = app.state.MODELS[model_id]
|
model = app.state.MODELS[model_id]
|
||||||
|
|
||||||
filters = [
|
sorted_filters = get_sorted_filters(model_id)
|
||||||
model
|
|
||||||
for model in app.state.MODELS.values()
|
|
||||||
if "pipeline" in model
|
|
||||||
and "type" in model["pipeline"]
|
|
||||||
and model["pipeline"]["type"] == "filter"
|
|
||||||
and (
|
|
||||||
model["pipeline"]["pipelines"] == ["*"]
|
|
||||||
or any(
|
|
||||||
model_id == target_model_id
|
|
||||||
for target_model_id in model["pipeline"]["pipelines"]
|
|
||||||
)
|
|
||||||
)
|
|
||||||
]
|
|
||||||
|
|
||||||
sorted_filters = sorted(filters, key=lambda x: x["pipeline"]["priority"])
|
|
||||||
if "pipeline" in model:
|
if "pipeline" in model:
|
||||||
sorted_filters = [model] + sorted_filters
|
sorted_filters = [model] + sorted_filters
|
||||||
|
|
||||||
|
|
|
||||||
Loading…
Reference in a new issue