Mirror of https://github.com/open-webui/open-webui.git (synced 2025-12-12 04:15:25 +00:00)
Compare commits (880 commits)
397 changed files with 38640 additions and 19617 deletions
.github/ISSUE_TEMPLATE/bug_report.yaml (8 changes, vendored)

@@ -11,7 +11,9 @@ body:
         ## Important Notes

-        - **Before submitting a bug report**: Please check the [Issues](https://github.com/open-webui/open-webui/issues) and [Discussions](https://github.com/open-webui/open-webui/discussions) sections to see if a similar issue has already been reported. If unsure, start a discussion first, as this helps us efficiently focus on improving the project. Duplicates may be closed without notice. **Please search for existing issues and discussions.**
+        - **Before submitting a bug report**: Please check the [Issues](https://github.com/open-webui/open-webui/issues) and [Discussions](https://github.com/open-webui/open-webui/discussions) sections to see if a similar issue has already been reported. If unsure, start a discussion first, as this helps us efficiently focus on improving the project. Duplicates may be closed without notice. **Please search for existing issues AND discussions. No matter open or closed.**
+        - Check for opened, **but also for (recently) CLOSED issues** as the issue you are trying to report **might already have been fixed on the dev branch!**

         - **Respectful collaboration**: Open WebUI is a volunteer-driven project with a single maintainer and contributors who also have full-time jobs. Please be constructive and respectful in your communication.

@@ -19,6 +21,8 @@ body:
         - **Bug Reproducibility**: If a bug cannot be reproduced using a `:main` or `:dev` Docker setup or with `pip install` on Python 3.11, community assistance may be required. In such cases, we will move it to the "[Issues](https://github.com/open-webui/open-webui/discussions/categories/issues)" Discussions section. Your help is appreciated!
+
+        - **Scope**: If you want to report a SECURITY VULNERABILITY, then do so through our [GitHub security page](https://github.com/open-webui/open-webui/security).
   - type: checkboxes
     id: issue-check
     attributes:

@@ -29,6 +33,8 @@ body:
           required: true
         - label: I have searched for any existing and/or related discussions.
           required: true
+        - label: I have also searched in the CLOSED issues AND CLOSED discussions and found no related items (your issue might already be addressed on the development branch!).
+          required: true
         - label: I am using the latest version of Open WebUI.
           required: true
.github/ISSUE_TEMPLATE/feature_request.yaml (26 changes, vendored)

@@ -8,8 +8,19 @@ body:
       value: |
         ## Important Notes
         ### Before submitting
-        Please check the [Issues](https://github.com/open-webui/open-webui/issues) or [Discussions](https://github.com/open-webui/open-webui/discussions) to see if a similar request has been posted.
+        Please check the **open AND closed** [Issues](https://github.com/open-webui/open-webui/issues) AND [Discussions](https://github.com/open-webui/open-webui/discussions) to see if a similar request has been posted.
         It's likely we're already tracking it! If you’re unsure, start a discussion post first.
+
+        #### Scope
+
+        If your feature request is likely to take more than a quick coding session to implement, test and verify, then open it in the **Ideas** section of the [Discussions](https://github.com/open-webui/open-webui/discussions) instead.
+        **We will close and force move your feature request to the Ideas section, if we believe your feature request is not trivial/quick to implement.**
+        This is to ensure the issues tab is used only for issues, quickly addressable feature requests and tracking tickets by the maintainers.
+        Other feature requests belong in the **Ideas** section of the [Discussions](https://github.com/open-webui/open-webui/discussions).
+
+        If your feature request might impact others in the community, definitely open a discussion instead and evaluate whether and how to implement it.
+
         This will help us efficiently focus on improving the project.

         ### Collaborate respectfully

@@ -22,7 +33,6 @@ body:
         We appreciate your time and ask that you **respect ours**.
-
         ### Contributing
         If you encounter an issue, we highly encourage you to submit a pull request or fork the project. We actively work to prevent contributor burnout to maintain the quality and continuity of Open WebUI.

@@ -35,14 +45,22 @@ body:
       label: Check Existing Issues
       description: Please confirm that you've checked for existing similar requests
       options:
-        - label: I have searched the existing issues and discussions.
+        - label: I have searched for all existing **open AND closed** issues and discussions for similar requests. I have found none that is comparable to my request.
+          required: true
+  - type: checkboxes
+    id: feature-scope
+    attributes:
+      label: Verify Feature Scope
+      description: Please confirm the feature's scope is within the described scope
+      options:
+        - label: I have read through and understood the scope definition for feature requests in the Issues section. I believe my feature request meets the definition and belongs in the Issues section instead of the Discussions.
           required: true
   - type: textarea
     id: problem-description
     attributes:
       label: Problem Description
       description: Is your feature request related to a problem? Please provide a clear and concise description of what the problem is.
-      placeholder: "Ex. I'm always frustrated when..."
+      placeholder: "Ex. I'm always frustrated when... / Not related to a problem"
     validations:
       required: true
   - type: textarea
.github/pull_request_template.md (18 changes, vendored)

@@ -1,17 +1,20 @@
 # Pull Request Checklist

-### Note to first-time contributors: Please open a discussion post in [Discussions](https://github.com/open-webui/open-webui/discussions) and describe your changes before submitting a pull request.
+### Note to first-time contributors: Please open a discussion post in [Discussions](https://github.com/open-webui/open-webui/discussions) to discuss your idea/fix with the community before creating a pull request, and describe your changes before submitting a pull request.
+
+This is to ensure large feature PRs are discussed with the community first, before starting work on it. If the community does not want this feature or it is not relevant for Open WebUI as a project, it can be identified in the discussion before working on the feature and submitting the PR.

 **Before submitting, make sure you've checked the following:**

-- [ ] **Target branch:** Please verify that the pull request targets the `dev` branch.
+- [ ] **Target branch:** Verify that the pull request targets the `dev` branch. **Not targeting the `dev` branch will lead to immediate closure of the PR.**
-- [ ] **Description:** Provide a concise description of the changes made in this pull request.
+- [ ] **Description:** Provide a concise description of the changes made in this pull request down below.
 - [ ] **Changelog:** Ensure a changelog entry following the format of [Keep a Changelog](https://keepachangelog.com/) is added at the bottom of the PR description.
-- [ ] **Documentation:** Have you updated relevant documentation [Open WebUI Docs](https://github.com/open-webui/docs), or other documentation sources?
+- [ ] **Documentation:** If necessary, update relevant documentation [Open WebUI Docs](https://github.com/open-webui/docs) like environment variables, the tutorials, or other documentation sources.
 - [ ] **Dependencies:** Are there any new dependencies? Have you updated the dependency versions in the documentation?
-- [ ] **Testing:** Have you written and run sufficient tests to validate the changes?
+- [ ] **Testing:** Perform manual tests to **verify the implemented fix/feature works as intended AND does not break any other functionality**. Take this as an opportunity to **make screenshots of the feature/fix and include it in the PR description**.
+- [ ] **Agentic AI Code:** Confirm this Pull Request is **not written by any AI Agent** or has at least **gone through additional human review AND manual testing**. If any AI Agent is the co-author of this PR, it may lead to immediate closure of the PR.
 - [ ] **Code review:** Have you performed a self-review of your code, addressing any coding standard issues and ensuring adherence to the project's coding standards?
-- [ ] **Prefix:** To clearly categorize this pull request, prefix the pull request title using one of the following:
+- [ ] **Title Prefix:** To clearly categorize this pull request, prefix the pull request title using one of the following:
   - **BREAKING CHANGE**: Significant changes that may affect compatibility
   - **build**: Changes that affect the build system or external dependencies
   - **ci**: Changes to our continuous integration processes or workflows

@@ -74,3 +77,6 @@
 ### Contributor License Agreement

 By submitting this pull request, I confirm that I have read and fully agree to the [Contributor License Agreement (CLA)](https://github.com/open-webui/open-webui/blob/main/CONTRIBUTOR_LICENSE_AGREEMENT), and I am providing my contributions under its terms.
+
+> [!NOTE]
+> Deleting the CLA section will lead to immediate closure of your PR and it will not be merged in.
.github/workflows/docker-build.yaml (6 changes, vendored)

@@ -141,6 +141,9 @@ jobs:
           platform=${{ matrix.platform }}
           echo "PLATFORM_PAIR=${platform//\//-}" >> $GITHUB_ENV

+      - name: Delete huge unnecessary tools folder
+        run: rm -rf /opt/hostedtoolcache
+
       - name: Checkout repository
         uses: actions/checkout@v5

@@ -243,6 +246,9 @@ jobs:
           platform=${{ matrix.platform }}
           echo "PLATFORM_PAIR=${platform//\//-}" >> $GITHUB_ENV

+      - name: Delete huge unnecessary tools folder
+        run: rm -rf /opt/hostedtoolcache
+
       - name: Checkout repository
         uses: actions/checkout@v5
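The step added in both hunks clears the preinstalled toolcache to free disk space on the GitHub-hosted runner before the large multi-platform image builds. A quick way to check how much space such a step reclaims (a standalone snippet, not part of the workflow itself):

```python
import shutil

# Report free disk space on the root filesystem; run before and after
# removing /opt/hostedtoolcache to see how much was reclaimed.
total, used, free = shutil.disk_usage("/")
print(f"total={total / 2**30:.1f} GiB, used={used / 2**30:.1f} GiB, free={free / 2**30:.1f} GiB")
```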
CHANGELOG.md (438 changes)

@@ -5,25 +5,439 @@ All notable changes to this project will be documented in this file.
 The format is based on [Keep a Changelog](https://keepachangelog.com/en/1.1.0/),
 and this project adheres to [Semantic Versioning](https://semver.org/spec/v2.0.0.html).

+## [0.6.41] - 2025-12-02
+
+### Added
+
+- 🚦 Sign-in rate limiting was implemented to protect against brute force attacks, limiting login attempts to 15 per 3-minute window per email address using Redis with automatic fallback to in-memory storage when Redis is unavailable. [Commit](https://github.com/open-webui/open-webui/commit/7b166370432414ce8f186747fb098e0c70fb2d6b)
+- 📂 Administrators can now globally disable the folders feature and control user-level folder permissions through the admin panel, enabling minimalist interface configurations for deployments that don't require workspace organization features. [#19529](https://github.com/open-webui/open-webui/pull/19529), [#19210](https://github.com/open-webui/open-webui/discussions/19210), [#18459](https://github.com/open-webui/open-webui/discussions/18459), [#18299](https://github.com/open-webui/open-webui/discussions/18299)
+- 👥 Group channels were introduced as a new channel type enabling membership-based collaboration spaces where users explicitly join as members rather than accessing through permissions, with support for public or private visibility, automatic member inclusion from specified user groups, member role tracking with invitation metadata, and post-creation member management allowing channel managers to add or remove members through the channel info modal. [Commit](https://github.com/open-webui/open-webui/commit/f589b7c1895a6a77166c047891acfa21bc0936c4), [Commit](https://github.com/open-webui/open-webui/commit/3f1d9ccbf8443a2fa5278f36202bad930a216680)
+- 💬 Direct Message channels were introduced with a dedicated channel type selector and multi-user member selection interface, enabling private conversations between specific users without requiring full channel visibility. [Commit](https://github.com/open-webui/open-webui/commit/64b4d5d9c280b926746584aaf92b447d09deb386)
+- 📨 Direct Message channels now support a complete user-to-user messaging system with member-based access control, automatic deduplication for one-on-one conversations, optional channel naming, and distinct visual presentation using participant avatars instead of channel icons. [Commit](https://github.com/open-webui/open-webui/commit/acccb9afdd557274d6296c70258bb897bbb6652f)
+- 🙈 Users can now hide Direct Message channels from their sidebar while preserving message history, with automatic reactivation when new messages arrive from other participants, providing a cleaner interface for managing active conversations. [Commit](https://github.com/open-webui/open-webui/commit/acccb9afdd557274d6296c70258bb897bbb6652f)
+- ☑️ A comprehensive user selection component was added to the channel creation modal, featuring search functionality, sortable user lists, pagination support, and multi-select checkboxes for building Direct Message participant lists. [Commit](https://github.com/open-webui/open-webui/commit/acccb9afdd557274d6296c70258bb897bbb6652f)
+- 🔴 Channel unread message count tracking was implemented with visual badge indicators in the sidebar, automatically updating counts in real-time and marking messages as read when users view channels, with join/leave functionality to manage membership status. [Commit](https://github.com/open-webui/open-webui/commit/64b4d5d9c280b926746584aaf92b447d09deb386)
+- 📌 Message pinning functionality was added to channels, allowing users to pin important messages for easy reference with visual highlighting, a dedicated pinned messages modal accessible from the navbar, and complete backend support for tracking pinned status, pin timestamp, and the user who pinned each message. [Commit](https://github.com/open-webui/open-webui/commit/64b4d5d9c280b926746584aaf92b447d09deb386), [Commit](https://github.com/open-webui/open-webui/commit/aae2fce17355419d9c29f8100409108037895201)
+- 🟢 Direct Message channels now display an active status indicator for one-on-one conversations, showing a green dot when the other participant is currently online or a gray dot when offline. [Commit](https://github.com/open-webui/open-webui/commit/4b6773885cd7527c5a56b963781dac5e95105eec), [Commit](https://github.com/open-webui/open-webui/commit/39645102d14f34e71b34e5ddce0625790be33f6f)
+- 🆔 Users can now start Direct Message conversations directly from user profile previews by clicking the "Message" button, enabling quick access to private messaging without navigating away from the current channel. [Commit](https://github.com/open-webui/open-webui/commit/a0826ec9fedb56320532616d568fa59dda831d4e)
+- ⚡ Channel messages now appear instantly when sent using optimistic UI rendering, displaying with a pending state while the server confirms delivery, providing a more responsive messaging experience. [Commit](https://github.com/open-webui/open-webui/commit/25994dd3da90600401f53596d4e4fb067c1b8eaa)
+- 👍 Channel message reactions now display the names of users who reacted when hovering over the emoji, showing up to three names with a count for additional reactors. [Commit](https://github.com/open-webui/open-webui/commit/05e79bdd0c7af70b631e958924e3656db1013b80)
+- 🛠️ Channel creators can now edit and delete their own group and DM channels without requiring administrator privileges, enabling users to manage the channels they create independently. [Commit](https://github.com/open-webui/open-webui/commit/f589b7c1895a6a77166c047891acfa21bc0936c4)
+- 🔌 A new API endpoint was added to directly get or create a Direct Message channel with a specific user by their ID, streamlining programmatic DM channel creation for integrations and frontend workflows. [Commit](https://github.com/open-webui/open-webui/commit/f589b7c1895a6a77166c047891acfa21bc0936c4)
+- 💭 Users can now set a custom status with an emoji and message that displays in profile previews, the sidebar user menu, and Direct Message channel items in the sidebar, with the ability to clear status at any time, providing visibility into availability or current focus similar to team communication platforms. [Commit](https://github.com/open-webui/open-webui/commit/51621ba91a982e52da168ce823abffd11ad3e4fa), [Commit](https://github.com/open-webui/open-webui/commit/f5e8d4d5a004115489c35725408b057e24dfe318)
+- 📤 A group export API endpoint was added, enabling administrators to export complete group data including member lists for backup and migration purposes. [Commit](https://github.com/open-webui/open-webui/commit/09b6ea38c579659f8ca43ae5ea3746df3ac561ad)
+- 📡 A new API endpoint was added to retrieve all users belonging to a specific group, enabling programmatic access to group membership information for administrative workflows. [Commit](https://github.com/open-webui/open-webui/commit/01868e856a10f474f74fbd1b4425dafdf949222f)
+- 👁️ The admin user list now displays an active status indicator next to each user, showing a visual green dot for users who have been active within the last three minutes. [Commit](https://github.com/open-webui/open-webui/commit/1b095d12ff2465b83afa94af89ded9593f8a8655)
+- 🔑 The admin user edit modal now displays OAuth identity information with a per-provider breakdown, showing each linked identity provider and its associated subject identifier separately. [#19573](https://github.com/open-webui/open-webui/pull/19573)
+- 🧩 OAuth role claim parsing now respects the "OAUTH_ROLES_SEPARATOR" configuration, enabling proper parsing of roles returned as comma-separated strings and providing consistent behavior with group claim handling. [#19514](https://github.com/open-webui/open-webui/pull/19514)
+- 🎛️ Channel feature access can now be controlled through both the "USER_PERMISSIONS_FEATURES_CHANNELS" environment variable and group permission toggles in the admin panel, allowing administrators to restrict channel functionality for specific users or groups while defaulting to enabled for all users. [Commit](https://github.com/open-webui/open-webui/commit/f589b7c1895a6a77166c047891acfa21bc0936c4)
+- 🎨 The model editor interface was refined with access control settings moved to a dedicated modal, group member counts now displayed when configuring permissions, reorganized layout with improved visual hierarchy, and redesigned prompt suggestions cards with tooltips for field guidance. [Commit](https://github.com/open-webui/open-webui/commit/e65d92fc6f49da5ca059e1c65a729e7973354b99), [Commit](https://github.com/open-webui/open-webui/commit/9d39b9b42c653ee2acf2674b2df343ecbceb4954)
+- 🏗️ Knowledge base file management was rebuilt with a dedicated database table replacing the previous JSON array storage, enabling pagination support for large knowledge bases, significantly faster file listing performance, and more reliable file-knowledge base relationship tracking. [Commit](https://github.com/open-webui/open-webui/commit/d19023288e2ca40f86e2dc3fd9f230540f3e70d7)
+- ☁️ Azure Document Intelligence model selection was added, allowing administrators to specify which model to use for document processing via the "DOCUMENT_INTELLIGENCE_MODEL" environment variable or admin UI setting, with "prebuilt-layout" as the default. [#19692](https://github.com/open-webui/open-webui/pull/19692), [Docs:#872](https://github.com/open-webui/docs/pull/872)
+- 🚀 Milvus multitenancy vector database performance was improved by removing manual flush calls after upsert operations, eliminating rate limit errors and reducing load on etcd and MinIO/S3 storage by allowing Milvus to manage segment persistence automatically via its WAL and auto-flush policies. [#19680](https://github.com/open-webui/open-webui/pull/19680)
+- ✨ Various improvements were implemented across the frontend and backend to enhance performance, stability, and security.
+- 🌍 Translations for German, French, Portuguese (Brazil), Catalan, Simplified Chinese, and Traditional Chinese were enhanced and expanded.
+
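The sign-in rate-limiting entry above (15 attempts per 3-minute window per email, Redis with an in-memory fallback) describes a classic fixed-window counter. A minimal sketch of that pattern, assuming redis-py; the key format and function names are illustrative, not Open WebUI's actual code:

```python
import time

try:
    import redis  # optional: fall back to process-local memory if unavailable
    _client = redis.Redis.from_url("redis://localhost:6379/0")
    _client.ping()
except Exception:
    _client = None

WINDOW_SECONDS = 180  # 3-minute window, per the changelog entry
MAX_ATTEMPTS = 15     # 15 attempts per window

_memory: dict[str, tuple[int, float]] = {}  # email -> (count, window start)

def allow_signin(email: str) -> bool:
    """Return True if this email address may attempt another sign-in."""
    key = f"signin:{email.lower()}"
    if _client is not None:
        count = _client.incr(key)                # atomic per-email counter
        if count == 1:
            _client.expire(key, WINDOW_SECONDS)  # first hit opens the window
        return count <= MAX_ATTEMPTS
    # In-memory fallback: correct only within a single process.
    count, start = _memory.get(key, (0, time.time()))
    if time.time() - start > WINDOW_SECONDS:
        count, start = 0, time.time()            # window expired, reset
    _memory[key] = (count + 1, start)
    return count + 1 <= MAX_ATTEMPTS
```

Keying the counter on the email address rather than the client IP is what lets the limit hold against distributed attempts on a single account.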
+### Fixed
+
+- 🔄 Tool call response token duplication was fixed by removing redundant message history additions in non-native function calling mode, resolving an issue where tool results were included twice in the context and causing 2x token consumption. [#19656](https://github.com/open-webui/open-webui/issues/19656), [Commit](https://github.com/open-webui/open-webui/commit/52ccab8)
+- 🛡️ Web search domain filtering was corrected to properly block results when any resolved hostname or IP address matches a blocked domain, preventing blocked sites from appearing in search results due to permissive hostname resolution logic that previously allowed results through if any single resolved address passed the filter. [#19670](https://github.com/open-webui/open-webui/pull/19670), [#19669](https://github.com/open-webui/open-webui/issues/19669)
+- 🧠 Custom models based on Ollama or OpenAI now properly inherit the connection type from their base model, ensuring they appear correctly in the "Local" or "External" model selection tabs instead of only appearing under "All". [#19183](https://github.com/open-webui/open-webui/issues/19183), [Commit](https://github.com/open-webui/open-webui/commit/39f7575)
+- 🐍 SentenceTransformers embedding initialization was fixed by updating the transformers dependency to version 4.57.3, resolving a regression in v0.6.40 where document ingestion failed with "'NoneType' object has no attribute 'encode'" errors due to a bug in transformers 4.57.2. [#19512](https://github.com/open-webui/open-webui/issues/19512), [#19513](https://github.com/open-webui/open-webui/pull/19513)
+- 📈 Active user count accuracy was significantly improved by replacing the socket-based USER_POOL tracking with a database-backed heartbeat mechanism, resolving long-standing issues where Redis deployments displayed inflated user counts due to stale sessions never being cleaned up on disconnect. [#16074](https://github.com/open-webui/open-webui/discussions/16074), [Commit](https://github.com/open-webui/open-webui/commit/70948f8803e417459d5203839f8077fdbfbbb213)
+- 👥 Default group assignment now applies consistently across all user registration methods including OAuth/SSO, LDAP, and admin-created users, fixing an issue where the "DEFAULT_GROUP_ID" setting was only being applied to users who signed up via the email/password signup form. [#19685](https://github.com/open-webui/open-webui/pull/19685)
+- 🔦 Model list filtering in workspaces was corrected to properly include models shared with user groups, ensuring members can view models they have write access to through group permissions. [#19461](https://github.com/open-webui/open-webui/issues/19461), [Commit](https://github.com/open-webui/open-webui/commit/69722ba973768a5f689f2e2351bf583a8db9bba8)
+- 🖼️ User profile image display in preview contexts was fixed by resolving a Pydantic validation error that prevented proper rendering. [Commit](https://github.com/open-webui/open-webui/commit/c7eb7136893b0ddfdc5d55ffc7a05bd84a00f5d6)
+- 🔒 Redis TLS connection failures were resolved by updating the python-socketio dependency to version 5.15.0, restoring support for the "rediss://" URL schema. [#19480](https://github.com/open-webui/open-webui/issues/19480), [#19488](https://github.com/open-webui/open-webui/pull/19488)
+- 📝 MCP tool server configuration was corrected to properly handle the "Function Name Filter List" as both string and list types, preventing AttributeError when the field is empty and ensuring backward compatibility. [#19486](https://github.com/open-webui/open-webui/issues/19486), [Commit](https://github.com/open-webui/open-webui/commit/c5b73d71843edc024325d4a6e625ec939a747279), [Commit](https://github.com/open-webui/open-webui/commit/477097c2e42985c14892301d0127314629d07df1)
+- 📎 Web page attachment failures causing TypeError on metadata checks were resolved by correcting async threadpool parameter passing in vector database operations. [#19493](https://github.com/open-webui/open-webui/issues/19493), [Commit](https://github.com/open-webui/open-webui/commit/4370dee79e19d77062c03fba81780cb3b779fca3)
+- 💾 Model allowlist persistence in multi-worker deployments was fixed by implementing Redis-based shared state for the internal models dictionary, ensuring configuration changes are consistently visible across all worker processes. [#19395](https://github.com/open-webui/open-webui/issues/19395), [Commit](https://github.com/open-webui/open-webui/commit/b5e5617d7f7ad3e4eec9f15f4cc7f07cb5afc2fa)
+- ⏳ Chat history infinite loading was prevented by enhancing message data structure to properly track parent message relationships, resolving issues where missing parentId fields caused perpetual loading states. [#19225](https://github.com/open-webui/open-webui/issues/19225), [Commit](https://github.com/open-webui/open-webui/commit/ff4b1b9862d15adfa15eac17d2ce066c3d8ae38f)
+- 🩹 Database migration robustness was improved by automatically detecting and correcting missing primary key constraints on the user table, ensuring successful schema upgrades for databases with non-standard configurations. [#19487](https://github.com/open-webui/open-webui/discussions/19487), [Commit](https://github.com/open-webui/open-webui/commit/453ea9b9a167c0b03d86c46e6efd086bf10056ce)
+- 🏷️ OAuth group assignment now updates correctly on first login when users transition from admin to user role, ensuring group memberships reflect immediately when group management is enabled. [#19475](https://github.com/open-webui/open-webui/issues/19475), [#19476](https://github.com/open-webui/open-webui/pull/19476)
+- 💡 Knowledge base file tooltips now properly display the parent collection name when referencing files with the hash symbol, preventing confusion between identically-named files in different collections. [#19491](https://github.com/open-webui/open-webui/issues/19491), [Commit](https://github.com/open-webui/open-webui/commit/3fe5a47b0ff84ac97f8e4ff56a19fa2ec065bf66)
+- 🔐 Knowledge base file access inconsistencies were resolved where authorized non-admin users received "Not found" or permission errors for certain files due to race conditions during upload causing mismatched collection_name values, with file access validation now properly checking against knowledge base file associations. [#18689](https://github.com/open-webui/open-webui/issues/18689), [#19523](https://github.com/open-webui/open-webui/pull/19523), [Commit](https://github.com/open-webui/open-webui/commit/e301d1962e45900ababd3eabb7e9a2ad275a5761)
+- 📦 Knowledge API batch file addition endpoint was corrected to properly handle async operations, resolving 500 Internal Server Error responses when adding multiple files simultaneously. [#19538](https://github.com/open-webui/open-webui/issues/19538), [Commit](https://github.com/open-webui/open-webui/commit/28659f60d94feb4f6a99bb1a5b54d7f45e5ea10f)
+- 🤖 Embedding model auto-update functionality was fixed to properly respect the "RAG_EMBEDDING_MODEL_AUTO_UPDATE" setting by correctly passing the flag to the model path resolver, ensuring models update as expected when the auto-update option is enabled. [#19687](https://github.com/open-webui/open-webui/pull/19687)
+- 📉 API response payload sizes were dramatically reduced by removing base64-encoded profile images from most endpoints, eliminating multi-megabyte responses caused by high-resolution avatars and enabling better browser caching. [#19519](https://github.com/open-webui/open-webui/issues/19519), [Commit](https://github.com/open-webui/open-webui/commit/384753c4c17f62a68d38af4bbcf55a21ee08e0f2)
+- 📞 Redundant API calls on the admin user overview page were eliminated by consolidating reactive statements, reducing four duplicate requests to a single efficient call and significantly improving page load performance. [#19509](https://github.com/open-webui/open-webui/issues/19509), [Commit](https://github.com/open-webui/open-webui/commit/9f89cc5e9f7e1c6c9e2bc91177e08df7c79f66f9)
+- 🧹 Duplicate API calls on the workspace models page were eliminated by removing redundant model list fetching, reducing two identical requests to a single call and improving page responsiveness. [#19517](https://github.com/open-webui/open-webui/issues/19517), [Commit](https://github.com/open-webui/open-webui/commit/d1bbf6be7a4d1d53fa8ad46ca4f62fc4b2e6a8cb)
+- 🔘 The model valves button was corrected to prevent unintended form submission by adding explicit button type attribute, ensuring it no longer triggers message sending when the input area contains text. [#19534](https://github.com/open-webui/open-webui/pull/19534)
+- 🗑️ Ollama model deletion was fixed by correcting the request payload format and ensuring the model selector properly displays the placeholder option. [Commit](https://github.com/open-webui/open-webui/commit/0f3156651c64bc5af188a65fc2908bdcecf30c74)
+- 🎨 Image generation in temporary chats was fixed by correctly handling local chat sessions that are not persisted to the database. [Commit](https://github.com/open-webui/open-webui/commit/a7c7993bbf3a21cb7ba416525b89233cf2ad877f)
+- 🕵️♂️ Audit logging was fixed by correctly awaiting the async user authentication call, resolving failures where coroutine objects were passed instead of user data. [#19658](https://github.com/open-webui/open-webui/pull/19658), [Commit](https://github.com/open-webui/open-webui/commit/dba86bc)
+- 🌙 Dark mode select dropdown styling was corrected to use proper background colors, fixing an issue where dropdown borders and hover states appeared white instead of matching the dark theme. [#19693](https://github.com/open-webui/open-webui/pull/19693), [#19442](https://github.com/open-webui/open-webui/issues/19442)
+- 🔍 Milvus vector database query filtering was fixed by correcting string quote handling in filter expressions and using the proper parameter name for queries, resolving false "duplicate content detected" errors that prevented uploading multiple files to knowledge bases. [#19602](https://github.com/open-webui/open-webui/pull/19602), [#18119](https://github.com/open-webui/open-webui/issues/18119), [#16345](https://github.com/open-webui/open-webui/issues/16345), [#17088](https://github.com/open-webui/open-webui/issues/17088), [#18485](https://github.com/open-webui/open-webui/issues/18485)
+- 🆙 Milvus multitenancy vector database was updated to use query_iterator() for improved robustness and consistency with the standard Milvus implementation, fixing the same false duplicate detection errors and improving handling of large result sets in multi-tenant deployments. [#19695](https://github.com/open-webui/open-webui/pull/19695)
+
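The active-user fix above swaps socket-session bookkeeping for a database heartbeat: each connected client periodically writes a timestamp, and the count is just a query over a recency window, so stale sessions age out with no disconnect handling at all. A minimal sketch with sqlite3 and hypothetical function names, using the three-minute window the admin-list entry mentions:

```python
import sqlite3
import time

ACTIVE_WINDOW = 180  # seconds; "active within the last three minutes"

db = sqlite3.connect(":memory:")
db.execute("CREATE TABLE user (id TEXT PRIMARY KEY, last_active_at INTEGER)")

def heartbeat(user_id: str) -> None:
    # Called periodically by each connected client; upserts the timestamp.
    db.execute(
        "INSERT INTO user (id, last_active_at) VALUES (?, ?) "
        "ON CONFLICT(id) DO UPDATE SET last_active_at = excluded.last_active_at",
        (user_id, int(time.time())),
    )

def active_user_count() -> int:
    # No cleanup pass needed: sessions that stop heartbeating simply
    # fall out of the recency window.
    cutoff = int(time.time()) - ACTIVE_WINDOW
    row = db.execute(
        "SELECT COUNT(*) FROM user WHERE last_active_at >= ?", (cutoff,)
    ).fetchone()
    return row[0]

heartbeat("u1")
print(active_user_count())  # -> 1
```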
+### Changed
+
+- ⚠️ **IMPORTANT for Multi-Instance Deployments** — This release includes database schema changes; multi-worker, multi-server, or load-balanced deployments must update all instances simultaneously rather than performing rolling updates, as running mixed versions will cause application failures due to schema incompatibility between old and new instances.
+- 👮 Channel creation is now restricted to administrators only, with the channel add button hidden for regular users to maintain organizational control over communication channels. [Commit](https://github.com/open-webui/open-webui/commit/421aba7cd7cd708168b1f2565026c74525a67905)
+- ➖ The active user count indicator was removed from the bottom-left user menu in the sidebar to streamline the interface. [Commit](https://github.com/open-webui/open-webui/commit/848f3fd4d86ca66656e0ff0335773945af8d7d8d)
+- 🗂️ The user table was restructured with API keys migrated to a dedicated table supporting future multi-key functionality, OAuth data storage converted to a JSON structure enabling multiple identity providers per user account, and internal column types optimized from TEXT to JSON for the "info" and "settings" fields, with automatic migration preserving all existing data and associations. [#19573](https://github.com/open-webui/open-webui/pull/19573)
+- 🔄 The knowledge base API was restructured to support the new file relationship model.
+
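The user-table restructure above (a dedicated API-key table, OAuth data as JSON keyed by provider, JSON columns for "info" and "settings") can be pictured with a SQLAlchemy sketch. All names and column shapes here are illustrative guesses at the described structure, not the actual schema or migration:

```python
from sqlalchemy import JSON, Column, String, Text, create_engine
from sqlalchemy.orm import declarative_base

Base = declarative_base()

class User(Base):
    __tablename__ = "user"
    id = Column(String, primary_key=True)
    # One JSON object per account, keyed by provider, allows multiple
    # identity providers: {"google": {"sub": "..."}, "github": {"sub": "..."}}
    oauth = Column(JSON, nullable=True)
    info = Column(JSON, nullable=True)      # formerly TEXT
    settings = Column(JSON, nullable=True)  # formerly TEXT

class ApiKey(Base):
    # A dedicated table means a user can hold several keys in the future.
    __tablename__ = "api_key"
    id = Column(String, primary_key=True)
    user_id = Column(String, nullable=False)
    key = Column(Text, nullable=False)

engine = create_engine("sqlite://")
Base.metadata.create_all(engine)
```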
+## [0.6.40] - 2025-11-25
+
+### Fixed
+
+- 🗄️ A critical PostgreSQL user listing performance issue was resolved by removing a redundant count operation that caused severe database slowdowns and potential timeouts when viewing user lists in admin panels.
+
+## [0.6.39] - 2025-11-25
+
+### Added
+
+- 💬 A user list modal was added to channels, displaying all users with access and featuring search, sorting, and pagination capabilities. [Commit](https://github.com/open-webui/open-webui/commit/c0e120353824be00a2ef63cbde8be5d625bd6fd0)
+- 💬 Channel navigation now displays the total number of users with access to the channel. [Commit](https://github.com/open-webui/open-webui/commit/3b5710d0cd445cf86423187f5ee7c40472a0df0b)
+- 🔌 Tool servers and MCP connections now support function name filtering, allowing administrators to selectively enable or block specific functions using allow/block lists. [Commit](https://github.com/open-webui/open-webui/commit/743199f2d097ae1458381bce450d9025a0ab3f3d)
+- ⚡ A toggle to disable parallel embedding processing was added via "ENABLE_ASYNC_EMBEDDING", allowing sequential processing for rate-limited or resource-constrained local embedding setups. [#19444](https://github.com/open-webui/open-webui/pull/19444)
+- 🔄 Various improvements were implemented across the frontend and backend to enhance performance, stability, and security.
+- 🌐 Localization improvements were made for German (de-DE) and Portuguese (Brazil) translations.
+
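The "ENABLE_ASYNC_EMBEDDING" toggle above chooses between firing embedding batches concurrently and walking through them one at a time. A minimal asyncio sketch of that switch; `embed_batch` is a stand-in, not the project's real embedding call:

```python
import asyncio
import os

ENABLE_ASYNC_EMBEDDING = os.environ.get("ENABLE_ASYNC_EMBEDDING", "true").lower() == "true"

async def embed_batch(batch: list[str]) -> list[list[float]]:
    # Stand-in for a real embedding request (HTTP call or local model).
    await asyncio.sleep(0.1)
    return [[0.0, 0.0, 0.0] for _ in batch]

async def embed_all(batches: list[list[str]]) -> list[list[float]]:
    if ENABLE_ASYNC_EMBEDDING:
        # Parallel: fastest, but can trip rate limits on hosted APIs.
        chunks = await asyncio.gather(*(embed_batch(b) for b in batches))
    else:
        # Sequential: gentler on rate-limited or resource-constrained setups.
        chunks = [await embed_batch(b) for b in batches]
    return [vec for chunk in chunks for vec in chunk]

print(asyncio.run(embed_all([["a", "b"], ["c"]])))
```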
### Fixed
|
||||||
|
|
||||||
|
- 📝 Inline citations now render correctly within markdown lists and nested elements instead of displaying as "undefined" values. [#19452](https://github.com/open-webui/open-webui/issues/19452)
|
||||||
|
- 👥 Group member selection now works correctly without randomly selecting other users or causing the user list to jump around. [#19426](https://github.com/open-webui/open-webui/issues/19426)
- 👥 Admin panel user list now displays the correct total user count and properly paginates 30 items per page after fixing database query issues with group member joins. [#19429](https://github.com/open-webui/open-webui/issues/19429)
- 🔍 Knowledge base reindexing now works correctly after resolving async execution chain issues by implementing threadpool workers for embedding operations. [#19434](https://github.com/open-webui/open-webui/pull/19434)
- 🖼️ OpenAI image generation now works correctly after fixing a connection adapter error caused by incorrect URL formatting. [#19435](https://github.com/open-webui/open-webui/pull/19435)
### Changed
- 🔧 BREAKING: Docling configuration has been consolidated from individual environment variables into a single "DOCLING_PARAMS" JSON configuration and now supports API key authentication via "DOCLING_API_KEY", requiring users to migrate existing Docling settings to the new format (see the migration sketch after this list). [#16841](https://github.com/open-webui/open-webui/issues/16841), [#19427](https://github.com/open-webui/open-webui/pull/19427)
- 🔧 The environment variable "REPLACE_IMAGE_URLS_IN_CHAT_RESPONSE" has been renamed to "ENABLE_CHAT_RESPONSE_BASE64_IMAGE_URL_CONVERSION" for naming consistency.
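
For the "DOCLING_PARAMS" consolidation above, a minimal migration sketch assuming a plain environment file; the JSON keys shown are illustrative placeholders rather than a confirmed schema, so check the Docling documentation for the exact options your deployment supports.

```bash
# Hypothetical sketch: individual DOCLING_* variables are replaced by one
# JSON blob plus an optional API key (keys inside the JSON are examples).
DOCLING_PARAMS='{"do_ocr": true, "ocr_lang": ["en"]}'
DOCLING_API_KEY=your-docling-api-key
```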
## [0.6.38] - 2025-11-24

### Fixed
- 🔍 Hybrid search now works reliably after recent changes.
- 🛠️ Tool server saving now handles errors gracefully, preventing failed saves from impacting the UI.
- 🔐 SSO/OIDC authentication code was fixed to improve login reliability and better handle edge cases.

## [0.6.37] - 2025-11-24

### Added
- 🔐 Granular sharing permissions are now available with two-tiered control separating group sharing from public sharing, allowing administrators to independently configure whether users can share workspace items with groups or make them publicly accessible. Separate permission toggles exist for models, knowledge bases, prompts, tools, and notes, configurable via "USER_PERMISSIONS_WORKSPACE_MODELS_ALLOW_SHARING", "USER_PERMISSIONS_WORKSPACE_MODELS_ALLOW_PUBLIC_SHARING", and corresponding environment variables for the other workspace item types, while groups can now be configured to opt out of sharing via the "Allow Group Sharing" setting. [Commit](https://github.com/open-webui/open-webui/commit/7be750bcbb40da91912a0a66b7ab791effdcc3b6), [Commit](https://github.com/open-webui/open-webui/commit/f69e37a8507d6d57382d6670641b367f3127f90a)
- 🔐 Password policy enforcement is now available with configurable validation rules, allowing administrators to require specific password complexity requirements via "ENABLE_PASSWORD_VALIDATION" and "PASSWORD_VALIDATION_REGEX_PATTERN" environment variables, with the default pattern requiring a minimum of 8 characters including uppercase, lowercase, digit, and special character (see the example after this list). [#17794](https://github.com/open-webui/open-webui/pull/17794)
- 🔐 Granular import and export permissions are now available for workspace items, introducing six separate permission toggles for models, prompts, and tools that are disabled by default for enhanced security. [#19242](https://github.com/open-webui/open-webui/pull/19242)
- 👥 Default group assignment is now available for new users, allowing administrators to automatically assign newly registered users to a specified group for streamlined access control to models, prompts, and tools, particularly useful for organizations with group-based model access policies. [#19325](https://github.com/open-webui/open-webui/pull/19325), [#17842](https://github.com/open-webui/open-webui/issues/17842)
- 🔒 Password-based authentication can now be fully disabled via "ENABLE_PASSWORD_AUTH" environment variable, enforcing SSO-only authentication and preventing password login fallback when SSO is configured. [#19113](https://github.com/open-webui/open-webui/pull/19113)
- 🖼️ Large stream chunk handling was implemented to support models that generate images directly in their output responses, with configurable buffer size via "CHAT_STREAM_RESPONSE_CHUNK_MAX_BUFFER_SIZE" environment variable, resolving compatibility issues with models like Gemini 2.5 Flash Image. [#18884](https://github.com/open-webui/open-webui/pull/18884), [#17626](https://github.com/open-webui/open-webui/issues/17626)
- 🖼️ Streaming response middleware now handles images in delta updates with automatic base64 conversion, enabling proper display of images from models using the "choices[0].delta.images.image_url" format such as Gemini 2.5 Flash Image Preview on OpenRouter. [#19073](https://github.com/open-webui/open-webui/pull/19073), [#19019](https://github.com/open-webui/open-webui/issues/19019)
- 📈 Model list API performance was optimized by pre-fetching user group memberships and removing profile image URLs from response payloads, significantly reducing both database queries and payload size for instances with large model lists, with profile images now served dynamically via dedicated endpoints. [#19097](https://github.com/open-webui/open-webui/pull/19097), [#18950](https://github.com/open-webui/open-webui/issues/18950)
- ⏩ Batch file processing performance was improved by reducing database queries by 67% while ensuring data consistency between vector and relational databases. [#18953](https://github.com/open-webui/open-webui/pull/18953)
- 🚀 Chat import performance was dramatically improved by replacing individual per-chat API requests with a bulk import endpoint, reducing import time by up to 95% for large chat collections and providing user feedback via toast notifications displaying the number of successfully imported chats. [#17861](https://github.com/open-webui/open-webui/pull/17861)
- ⚡ Socket event broadcasting performance was optimized by implementing user-specific rooms, significantly reducing server overhead particularly for users with multiple concurrent sessions. [#18996](https://github.com/open-webui/open-webui/pull/18996)
- 🗄️ Weaviate is now supported as a vector database option, providing an additional choice for RAG document storage alongside existing ChromaDB, Milvus, Qdrant, and OpenSearch integrations. [#14747](https://github.com/open-webui/open-webui/pull/14747)
- 🗄️ PostgreSQL pgvector now supports HNSW index types and large dimensional embeddings exceeding 2000 dimensions through automatic halfvec type selection, with configurable index methods via "PGVECTOR_INDEX_METHOD", "PGVECTOR_HNSW_M", "PGVECTOR_HNSW_EF_CONSTRUCTION", and "PGVECTOR_IVFFLAT_LISTS" environment variables (see the example after this list). [#19158](https://github.com/open-webui/open-webui/pull/19158), [#16890](https://github.com/open-webui/open-webui/issues/16890)
- 🔍 Azure AI Search is now supported as a web search provider, enabling integration with Azure's cognitive search services via "AZURE_AI_SEARCH_API_KEY", "AZURE_AI_SEARCH_ENDPOINT", and "AZURE_AI_SEARCH_INDEX_NAME" configuration. [#19104](https://github.com/open-webui/open-webui/pull/19104)
- ⚡ External embedding generation now processes API requests in parallel instead of sequential batches, reducing document processing time by 10-50x when using OpenAI, Azure OpenAI, or Ollama embedding providers, with large PDFs now processing in seconds instead of minutes. [#19296](https://github.com/open-webui/open-webui/pull/19296)
- 💨 Base64 image conversion is now available for markdown content in chat responses, automatically uploading embedded images exceeding 1KB and replacing them with file URLs to reduce payload size and resource consumption, configurable via "REPLACE_IMAGE_URLS_IN_CHAT_RESPONSE" environment variable. [#19076](https://github.com/open-webui/open-webui/pull/19076)
- 🎨 OpenAI image generation now supports additional API parameters including quality settings for GPT Image 1, configurable via "IMAGES_OPENAI_API_PARAMS" environment variable or through the admin interface, enabling cost-effective image generation with low, medium, or high quality options. [#19228](https://github.com/open-webui/open-webui/issues/19228)
- 🖼️ Image editing can now be independently enabled or disabled via admin settings, allowing administrators to control whether sequential image prompts trigger image editing or new image generation, configurable via "ENABLE_IMAGE_EDIT" environment variable. [#19284](https://github.com/open-webui/open-webui/issues/19284)
- 🔐 SSRF protection was implemented with a configurable URL blocklist that prevents access to cloud metadata endpoints and private networks, with default protections for AWS, Google Cloud, Azure, and Alibaba Cloud metadata services, customizable via "WEB_FETCH_FILTER_LIST" environment variable (see the example after this list). [#19201](https://github.com/open-webui/open-webui/pull/19201)
- ⚡ Workspace models page now supports server-side pagination, dramatically improving load times and usability for instances with large numbers of workspace models.
- 🔍 Hybrid search now indexes file metadata including filenames, titles, headings, sources, and snippets alongside document content, enabling keyword queries to surface documents where search terms appear only in metadata, configurable via "ENABLE_RAG_HYBRID_SEARCH_ENRICHED_TEXTS" environment variable. [#19095](https://github.com/open-webui/open-webui/pull/19095)
- 📂 Knowledge base upload page now supports folder drag-and-drop with recursive directory handling, enabling batch uploads of entire directory structures instead of requiring individual file selection. [#19320](https://github.com/open-webui/open-webui/pull/19320)
- 🤖 Model cloning is now available in admin settings, allowing administrators to quickly create workspace models based on existing base models through a "Clone" option in the model dropdown menu. [#17937](https://github.com/open-webui/open-webui/pull/17937)
- 🎨 UI scale adjustment is now available in interface settings, allowing users to increase the size of the entire interface from 1.0x to 1.5x for improved accessibility and readability, particularly beneficial for users with visual impairments. [#19186](https://github.com/open-webui/open-webui/pull/19186)
- 📌 Default pinned models can now be configured by administrators for all new users, mirroring the behavior of default models where admin-configured defaults apply only to users who haven't customized their pinned models, configurable via "DEFAULT_PINNED_MODELS" environment variable. [#19273](https://github.com/open-webui/open-webui/pull/19273)
- 🎙️ Text-to-Speech and Speech-to-Text services now receive user information headers when "ENABLE_FORWARD_USER_INFO_HEADERS" is enabled, allowing external TTS and STT providers to implement user-specific personalization, rate limiting, and usage tracking. [#19323](https://github.com/open-webui/open-webui/pull/19323), [#19312](https://github.com/open-webui/open-webui/issues/19312)
- 🎙️ Voice mode now supports custom system prompts via "VOICE_MODE_PROMPT_TEMPLATE" configuration, allowing administrators to control response style and behavior for voice interactions. [#18607](https://github.com/open-webui/open-webui/pull/18607)
- 🔧 WebSocket and Redis configuration options are now available including debug logging controls, custom ping timeout and interval settings, and arbitrary Redis connection options via "WEBSOCKET_SERVER_LOGGING", "WEBSOCKET_SERVER_ENGINEIO_LOGGING", "WEBSOCKET_SERVER_PING_TIMEOUT", "WEBSOCKET_SERVER_PING_INTERVAL", and "WEBSOCKET_REDIS_OPTIONS" environment variables. [#19091](https://github.com/open-webui/open-webui/pull/19091)
- 🔧 MCP OAuth dynamic client registration now automatically detects and uses the appropriate token endpoint authentication method from server-supported options, enabling compatibility with OAuth servers that only support "client_secret_basic" instead of "client_secret_post". [#19193](https://github.com/open-webui/open-webui/issues/19193)
- 🔧 Custom headers can now be configured for remote MCP and OpenAPI tool server connections, enabling integration with services that require additional authentication headers. [#18918](https://github.com/open-webui/open-webui/issues/18918)
- 🔍 Perplexity Search now supports custom API endpoints via "PERPLEXITY_SEARCH_API_URL" configuration and automatically forwards user information headers to enable personalized search experiences. [#19147](https://github.com/open-webui/open-webui/pull/19147)
- 🔍 User information headers can now be optionally forwarded to external web search engines when "ENABLE_FORWARD_USER_INFO_HEADERS" is enabled. [#19043](https://github.com/open-webui/open-webui/pull/19043)
- 📊 Daily active user metric is now available for monitoring, tracking unique users active since midnight UTC via the "webui.users.active.today" Prometheus gauge. [#19236](https://github.com/open-webui/open-webui/pull/19236), [#19234](https://github.com/open-webui/open-webui/issues/19234)
- 📊 Audit log file path is now configurable via "AUDIT_LOGS_FILE_PATH" environment variable, enabling storage in separate volumes or custom locations. [#19173](https://github.com/open-webui/open-webui/pull/19173)
- 🎨 Sidebar collapse states for model lists and group information are now persistent across page refreshes, remembering user preferences through browser-based storage. [#19159](https://github.com/open-webui/open-webui/issues/19159)
- 🎨 Background image display was enhanced with semi-transparent overlays for navbar and sidebar, creating a seamless and visually cohesive design across the entire interface. [#19157](https://github.com/open-webui/open-webui/issues/19157)
- 📋 Tables in chat messages now include a copy button that appears on hover, enabling quick copying of table content alongside the existing CSV export functionality. [#19162](https://github.com/open-webui/open-webui/issues/19162)
- 📝 Notes can now be created directly via the "/notes/new" URL endpoint with optional title and content query parameters, enabling faster note creation through bookmarks and shortcuts (see the example after this list). [#19195](https://github.com/open-webui/open-webui/issues/19195)
- 🏷️ Tag suggestions are now context-aware, displaying only relevant tags when creating or editing models versus chat conversations, preventing confusion between model and chat tags. [#19135](https://github.com/open-webui/open-webui/issues/19135)
- ✍️ Prompt autocompletion is now available independently of the rich text input setting, improving accessibility to the feature. [#19150](https://github.com/open-webui/open-webui/issues/19150)
- 🔄 Various improvements were implemented across the frontend and backend to enhance performance, stability, and security.
- 🌐 Translations for Simplified Chinese, Traditional Chinese, Portuguese (Brazil), Catalan, Spanish (Spain), Finnish, Irish, Farsi, Swedish, Danish, German, Korean, and Thai were improved and expanded.
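
For the password policy item above, a sketch of enabling validation; the regex is an illustrative pattern consistent with the described default (minimum 8 characters with uppercase, lowercase, digit, and special character), not necessarily the exact shipped value.

```bash
# Illustrative: require 8+ characters with upper, lower, digit, and special.
ENABLE_PASSWORD_VALIDATION=true
PASSWORD_VALIDATION_REGEX_PATTERN='^(?=.*[a-z])(?=.*[A-Z])(?=.*\d)(?=.*[^A-Za-z0-9]).{8,}$'
```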
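
For the pgvector item above, a sketch of switching to HNSW indexing; the numeric values are pgvector's common defaults, used here only as examples.

```bash
# Illustrative pgvector tuning; halfvec selection for >2000 dimensions is
# automatic and needs no extra setting.
PGVECTOR_INDEX_METHOD=hnsw
PGVECTOR_HNSW_M=16
PGVECTOR_HNSW_EF_CONSTRUCTION=64
```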
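
For the SSRF protection item above, a sketch of extending the blocklist; the entries and the separator format are assumptions, and the defaults already cover the major cloud metadata endpoints.

```bash
# Hypothetical blocklist extension (entry list and separator are assumed).
WEB_FETCH_FILTER_LIST='169.254.169.254;metadata.google.internal;10.0.0.0/8'
```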
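
For the "/notes/new" endpoint above, an illustrative bookmark target; the host is a placeholder and the query values must be URL-encoded.

```bash
# Illustrative: open a pre-filled note from a shortcut or bookmark.
xdg-open "https://webui.example.com/notes/new?title=Standup&content=Action%20items"
```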
### Fixed
- 🤖 Model update functionality now works correctly, resolving a database parameter binding error that prevented saving changes to model configurations via the Save & Update button. [#19335](https://github.com/open-webui/open-webui/issues/19335)
- 🖼️ Multiple input images for image editing and generation are now correctly passed as an array using the "image[]" parameter syntax, enabling proper multi-image reference functionality with models like GPT Image 1. [#19339](https://github.com/open-webui/open-webui/issues/19339)
- 📱 PWA installations on iOS now properly refresh after server container restarts, resolving freezing issues by automatically unregistering service workers when version or deployment changes are detected. [#19316](https://github.com/open-webui/open-webui/pull/19316)
- 🗄️ S3 Vectors collection detection now correctly handles buckets with more than 2000 indexes by using direct index lookup instead of paginated list scanning, improving performance by approximately 8x and enabling RAG queries to work reliably at scale. [#19238](https://github.com/open-webui/open-webui/pull/19238), [#19233](https://github.com/open-webui/open-webui/issues/19233)
- 📈 Feedback retrieval performance was optimized by eliminating N+1 query patterns through database joins, adding server-side pagination and sorting, significantly reducing database load for instances with large feedback datasets. [#17976](https://github.com/open-webui/open-webui/pull/17976)
- 🔍 Chat search now works correctly with PostgreSQL when chat data contains null bytes, with comprehensive sanitization preventing null bytes during data writes, cleaning existing data on read, and stripping null bytes during search queries to ensure reliable search functionality. [#15616](https://github.com/open-webui/open-webui/issues/15616)
- 🔍 Hybrid search with reranking now correctly handles attribute validation, preventing errors when collection results lack expected structure. [#19025](https://github.com/open-webui/open-webui/pull/19025), [#17046](https://github.com/open-webui/open-webui/issues/17046)
- 🔎 Reranking functionality now works correctly after recent refactoring, resolving crashes caused by incorrect function argument handling. [#19270](https://github.com/open-webui/open-webui/pull/19270)
- 🤖 Azure OpenAI models now support the "reasoning_effort" parameter, enabling proper configuration of reasoning capabilities for models like GPT-5.1 which default to no reasoning without this setting. [#19290](https://github.com/open-webui/open-webui/issues/19290)
- 🤖 Models with very long IDs can now be deleted correctly, resolving URL length limitations that previously prevented management operations on such models. [#18230](https://github.com/open-webui/open-webui/pull/18230)
- 🤖 Model-level streaming settings now correctly apply to API requests, ensuring "Stream Chat Response" toggle properly controls the streaming parameter. [#19154](https://github.com/open-webui/open-webui/issues/19154)
- 🖼️ Image editing configuration now correctly preserves independent OpenAI API endpoints and keys, preventing them from being overwritten by image generation settings. [#19003](https://github.com/open-webui/open-webui/issues/19003)
- 🎨 Gemini image edit settings now display correctly in the admin panel, fixing an incorrect configuration key reference that prevented proper rendering of edit options. [#19200](https://github.com/open-webui/open-webui/pull/19200)
- 🖌️ Image generation settings menu now loads correctly, resolving validation errors with AUTOMATIC1111 API authentication parameters. [#19187](https://github.com/open-webui/open-webui/issues/19187), [#19246](https://github.com/open-webui/open-webui/issues/19246)
- 📅 Date formatting in chat search and admin user chat search now correctly respects the "DEFAULT_LOCALE" environment variable, displaying dates according to the configured locale instead of always using MM/DD/YYYY format. [#19305](https://github.com/open-webui/open-webui/pull/19305), [#19020](https://github.com/open-webui/open-webui/issues/19020)
- 📝 RAG template query placeholder escaping logic was corrected to prevent unintended replacements of context values when query placeholders appear in retrieved content. [#19102](https://github.com/open-webui/open-webui/pull/19102), [#19101](https://github.com/open-webui/open-webui/issues/19101)
- 📄 RAG template prompt duplication was eliminated by removing redundant user query section from the default template. [#19099](https://github.com/open-webui/open-webui/pull/19099), [#19098](https://github.com/open-webui/open-webui/issues/19098)
- 📋 MinerU local mode configuration no longer incorrectly requires an API key, allowing proper use of local content extraction without external API credentials. [#19258](https://github.com/open-webui/open-webui/issues/19258)
- 📊 Excel file uploads now work correctly with the addition of the missing msoffcrypto-tool dependency, resolving import errors introduced by the unstructured package upgrade. [#19153](https://github.com/open-webui/open-webui/issues/19153)
- 📑 Docling parameters now properly handle JSON serialization, preventing exceptions and ensuring configuration changes are saved correctly. [#19072](https://github.com/open-webui/open-webui/pull/19072)
- 🛠️ UserValves configuration now correctly isolates settings per tool, preventing configuration contamination when multiple tools with UserValves are used simultaneously. [#19185](https://github.com/open-webui/open-webui/pull/19185), [#15569](https://github.com/open-webui/open-webui/issues/15569)
- 🔧 Tool selection prompt now correctly handles user messages without duplication, removing redundant query prefixes and improving prompt clarity. [#19122](https://github.com/open-webui/open-webui/pull/19122), [#19121](https://github.com/open-webui/open-webui/issues/19121)
- 📝 Notes chat feature now correctly submits messages to the completions endpoint, resolving errors that prevented AI model interactions. [#19079](https://github.com/open-webui/open-webui/pull/19079)
- 📝 Note PDF downloads now sanitize HTML content using DOMPurify before rendering, preventing potential DOM-based XSS attacks from malicious content in notes. [Commit](https://github.com/open-webui/open-webui/commit/03cc6ce8eb5c055115406e2304fbf7e3338b8dce)
- 📁 Archived chats now have their folder associations automatically removed to prevent unintended deletion when their previous folder is deleted. [#14578](https://github.com/open-webui/open-webui/issues/14578)
- 🔐 ElevenLabs API key is now properly obfuscated in the admin settings page, preventing plain text exposure of sensitive credentials. [#19262](https://github.com/open-webui/open-webui/pull/19262), [#19260](https://github.com/open-webui/open-webui/issues/19260)
- 🔧 MCP OAuth server metadata discovery now follows the correct specification order, ensuring proper authentication flow compliance. [#19244](https://github.com/open-webui/open-webui/pull/19244)
- 🔒 API key endpoint restrictions now properly enforce access controls for all endpoints including SCIM, preventing unintended access when "API_KEY_ALLOWED_ENDPOINTS" is configured. [#19168](https://github.com/open-webui/open-webui/issues/19168)
- 🔓 OAuth role claim parsing now supports both flat and nested claim structures, enabling compatibility with OAuth providers that deliver claims as direct properties on the user object rather than nested structures. [#19286](https://github.com/open-webui/open-webui/pull/19286)
- 🔑 OAuth MCP server verification now correctly extracts the access token value for authorization headers instead of sending the entire token dictionary. [#19149](https://github.com/open-webui/open-webui/pull/19149), [#19148](https://github.com/open-webui/open-webui/issues/19148)
- ⚙️ OAuth dynamic client registration now correctly converts empty strings to None for optional fields, preventing validation failures in MCP package integration. [#19144](https://github.com/open-webui/open-webui/pull/19144), [#19129](https://github.com/open-webui/open-webui/issues/19129)
- 🔐 OIDC authentication now correctly passes client credentials in access token requests, ensuring compatibility with providers that require these parameters per RFC 6749. [#19132](https://github.com/open-webui/open-webui/pull/19132), [#19131](https://github.com/open-webui/open-webui/issues/19131)
- 🔗 OAuth client creation now respects configured token endpoint authentication methods instead of defaulting to basic authentication, preventing failures with servers that don't support basic auth. [#19165](https://github.com/open-webui/open-webui/pull/19165)
- 📋 Text copied from chat responses in Chrome now pastes without background formatting, improving readability when pasting into word processors. [#19083](https://github.com/open-webui/open-webui/issues/19083)
### Changed
- 🗄️ Group membership data storage was refactored from JSON arrays to a dedicated relational database table, significantly improving query performance and scalability for instances with large numbers of users and groups, while API responses now return member counts instead of full user ID arrays. [#19239](https://github.com/open-webui/open-webui/pull/19239)
- 📄 MinerU parameter handling was refactored to pass parameters directly to the API, improving flexibility and fixing VLM backend configuration. [#19105](https://github.com/open-webui/open-webui/pull/19105), [#18446](https://github.com/open-webui/open-webui/discussions/18446)
- 🔐 API key creation is now controlled by granular user and group permissions, with the "ENABLE_API_KEY" environment variable renamed to "ENABLE_API_KEYS" and disabled by default, requiring explicit configuration at both the global and user permission levels, while related environment variables "ENABLE_API_KEY_ENDPOINT_RESTRICTIONS" and "API_KEY_ALLOWED_ENDPOINTS" were renamed to "ENABLE_API_KEYS_ENDPOINT_RESTRICTIONS" and "API_KEYS_ALLOWED_ENDPOINTS" respectively. [#18336](https://github.com/open-webui/open-webui/pull/18336)
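
A sketch of the API key variable renames described in the last item, assuming a plain environment file; the endpoint value is illustrative.

```bash
# Before (deprecated names):
#   ENABLE_API_KEY=true
#   ENABLE_API_KEY_ENDPOINT_RESTRICTIONS=true
#   API_KEY_ALLOWED_ENDPOINTS=/api/models
# After (new names; API keys are now disabled unless explicitly enabled):
ENABLE_API_KEYS=true
ENABLE_API_KEYS_ENDPOINT_RESTRICTIONS=true
API_KEYS_ALLOWED_ENDPOINTS=/api/models
```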
## [0.6.36] - 2025-11-07

### Added
- 🔐 OAuth group parsing now supports configurable separators via the "OAUTH_GROUPS_SEPARATOR" environment variable, enabling proper handling of semicolon-separated group claims from providers like CILogon. [#18987](https://github.com/open-webui/open-webui/pull/18987), [#18979](https://github.com/open-webui/open-webui/issues/18979)
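
A minimal sketch of the separator setting above; the claim value is a made-up example.

```bash
# Illustrative: a semicolon-separated claim like "staff;admins;ml-team"
# now resolves to three groups instead of one.
OAUTH_GROUPS_SEPARATOR=";"
```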
### Fixed
- 🛠️ Tool calling functionality is restored by correcting asynchronous function handling in tool parameter updates. [#18981](https://github.com/open-webui/open-webui/issues/18981)
- 🖼️ The ComfyUI image edit workflow editor modal now opens correctly when clicking the Edit button. [#18978](https://github.com/open-webui/open-webui/issues/18978)
- 🔥 Firecrawl import errors are resolved by implementing lazy loading and using the correct class name. [#18973](https://github.com/open-webui/open-webui/issues/18973)
- 🔌 Socket.IO CORS warning is resolved by properly configuring CORS origins for Socket.IO connections. [Commit](https://github.com/open-webui/open-webui/commit/639d26252e528c9c37a5f553b11eb94376d8792d)
## [0.6.35] - 2025-11-06

### Added
- 🖼️ Image generation system received a comprehensive overhaul with major new capabilities, including full image editing support that allows users to modify existing images using text prompts with OpenAI, Gemini, or ComfyUI engines. The update adds Gemini 2.5 Flash Image (Nano Banana) support and Qwen Image Edit integration, resolves base64-encoded image display issues, streamlines AUTOMATIC1111 configuration by consolidating parameters into a flexible JSON parameters field, and enhances the UI with a code editor modal for ComfyUI workflow management. [#17434](https://github.com/open-webui/open-webui/pull/17434), [#16976](https://github.com/open-webui/open-webui/issues/16976), [Commit](https://github.com/open-webui/open-webui/commit/8e5690aab4f632a57027e2acf880b8f89a8717c0), [Commit](https://github.com/open-webui/open-webui/commit/72f8539fd2e679fec0762945f22f4b8a6920afa0), [Commit](https://github.com/open-webui/open-webui/commit/8d34fcb586eeee1fac6da2f991518b8a68b00b72), [Commit](https://github.com/open-webui/open-webui/commit/72900cd686de1fa6be84b5a8a2fc857cff7b91b8)
- 🔒 CORS origin validation was added to WebSocket connections as a defense-in-depth security measure against cross-site WebSocket hijacking attacks. [#18411](https://github.com/open-webui/open-webui/pull/18411), [#18410](https://github.com/open-webui/open-webui/issues/18410)
- 🔄 Automatic page refresh now occurs when a version update is detected via WebSocket connection, ensuring users always run the latest version without cache issues. [Commit](https://github.com/open-webui/open-webui/commit/989f192c92d2fe55daa31336e7971e21798b96ae)
- 🐍 Experimental initial preparations were made for Python 3.13 compatibility by updating dependencies with security enhancements and cryptographic improvements. [#18430](https://github.com/open-webui/open-webui/pull/18430), [#18424](https://github.com/open-webui/open-webui/pull/18424)
- ⚡ Image compression now preserves the original image format instead of converting to PNG, significantly reducing file sizes and improving chat loading performance. [#18506](https://github.com/open-webui/open-webui/pull/18506)
- 🎤 Mistral Voxtral model support was added for speech-to-text, including voxtral-small and voxtral-mini models with both transcription and chat completion API support. [#18934](https://github.com/open-webui/open-webui/pull/18934)
- 🔊 Text-to-speech now uses a global audio queue system to prevent overlapping playback, ensuring only one TTS instance plays at a time with proper stop/start controls and automatic cleanup when switching between messages. [#16152](https://github.com/open-webui/open-webui/pull/16152), [#18744](https://github.com/open-webui/open-webui/pull/18744), [#16150](https://github.com/open-webui/open-webui/issues/16150)
- 🔊 ELEVENLABS_API_BASE_URL environment variable now allows configuration of custom ElevenLabs API endpoints, enabling support for EU residency API requirements. [#18402](https://github.com/open-webui/open-webui/issues/18402)
- 🔐 OAUTH_ROLES_SEPARATOR environment variable now allows custom role separators for OAuth roles that contain commas, useful for roles specified in LDAP syntax (see the example after this list). [#18572](https://github.com/open-webui/open-webui/pull/18572)
- 📄 External document loaders can now optionally forward user information headers when ENABLE_FORWARD_USER_INFO_HEADERS is enabled, enabling cost tracking, audit logs, and usage analytics for external services. [#18731](https://github.com/open-webui/open-webui/pull/18731)
- 📄 MISTRAL_OCR_API_BASE_URL environment variable now allows configuration of custom Mistral OCR API endpoints for flexible deployment options. [Commit](https://github.com/open-webui/open-webui/commit/415b93c7c35c2e2db4425e6da1b88b3750f496b0)
- ⌨️ Keyboard shortcut hints are now displayed on sidebar buttons with a refactored shortcuts modal that accurately reflects all available hotkeys across different keyboard layouts. [#18473](https://github.com/open-webui/open-webui/pull/18473)
- 🛠️ Tooltips now display tool descriptions when hovering over tool names on the model edit page, improving usability and providing immediate context. [#18707](https://github.com/open-webui/open-webui/pull/18707)
- 📝 "Create a new note" from the search modal now immediately creates a new private note and opens it in the editor instead of navigating to the generic notes page. [#18255](https://github.com/open-webui/open-webui/pull/18255)
- 🖨️ Code block output now preserves whitespace formatting with monospace font to accurately reflect terminal behavior. [#18352](https://github.com/open-webui/open-webui/pull/18352)
- ✏️ Edit button is now available in the three-dot menu of models in the workspace section for quick access to model editing, with the menu reorganized for better user experience and Edit, Clone, Copy Link, and Share options logically grouped. [#18574](https://github.com/open-webui/open-webui/pull/18574)
- 📌 Sidebar models section is now collapsible, allowing users to expand and collapse the pinned models list for better sidebar organization. [Commit](https://github.com/open-webui/open-webui/commit/82c08a3b5d189f81c96b6548cc872198771015b0)
- 🌙 Dark mode styles for select elements were added using Tailwind CSS classes, improving consistency across the interface. [#18636](https://github.com/open-webui/open-webui/pull/18636)
- 🔄 Various improvements were implemented across the frontend and backend to enhance performance, stability, and security.
- 🌐 Translations for Portuguese (Brazil), Greek, German, Traditional Chinese, Simplified Chinese, Spanish, Georgian, Danish, and Estonian were enhanced and expanded.
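
For the OAUTH_ROLES_SEPARATOR item above, a sketch of why a non-comma separator helps with LDAP-style role names; the values are illustrative.

```bash
# Illustrative: a role like "cn=admins,ou=groups" contains commas, so the
# default comma separator would split it apart; use ";" between roles.
OAUTH_ROLES_SEPARATOR=";"
```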
### Fixed
- 🔒 Server-Sent Event (SSE) code injection vulnerability in Direct Connections is resolved by blocking event emission from untrusted external model servers; event emitters from direct connected model servers are no longer supported, preventing arbitrary JavaScript execution in user browsers. [Commit](https://github.com/open-webui/open-webui/commit/8af6a4cf21b756a66cd58378a01c60f74c39b7ca)
- 🛡️ DOM XSS vulnerability in "Insert Prompt as Rich Text" is resolved by sanitizing HTML content with DOMPurify before rendering. [Commit](https://github.com/open-webui/open-webui/commit/eb9c4c0e358c274aea35f21c2856c0a20051e5f1)
- ⚙️ MCP server cancellation scope corruption is prevented by reversing disconnection order to follow LIFO and properly handling exceptions, resolving 100% CPU usage when resuming chats with expired tokens or using multiple streamable MCP servers. [#18537](https://github.com/open-webui/open-webui/pull/18537)
- 🔧 UI freeze when querying models with knowledge bases containing inconsistent distance metrics is resolved by properly initializing the distances array in citations. [#18585](https://github.com/open-webui/open-webui/pull/18585)
- 🤖 Duplicate model IDs from multiple OpenAI endpoints are now automatically deduplicated server-side, preventing frontend crashes for users with unified gateway proxies that aggregate multiple providers. [Commit](https://github.com/open-webui/open-webui/commit/fdf7ca11d4f3cc8fe63e81c98dc0d1e48e52ba36)
- 🔐 Login failures with passwords longer than 72 bytes are resolved by safely truncating oversized passwords for bcrypt compatibility. [#18157](https://github.com/open-webui/open-webui/issues/18157)
- 🔐 OAuth 2.1 MCP tool connections now automatically re-register clients when stored client IDs become stale, preventing unauthorized_client errors after editing tool endpoints and providing detailed error messages for callback failures. [#18415](https://github.com/open-webui/open-webui/pull/18415), [#18309](https://github.com/open-webui/open-webui/issues/18309)
- 🔓 OAuth 2.1 discovery, metadata fetching, and dynamic client registration now correctly use HTTP proxy environment variables when trust_env is enabled. [Commit](https://github.com/open-webui/open-webui/commit/bafeb76c411483bd6b135f0edbcdce048120f264)
- 🔌 MCP server connection failures now display clear error messages in the chat interface instead of silently failing. [#18892](https://github.com/open-webui/open-webui/pull/18892), [#18889](https://github.com/open-webui/open-webui/issues/18889)
- 💬 Chat titles are now properly generated even when title auto-generation is disabled in interface settings, fixing an issue where chats would remain labeled as "New chat". [#18761](https://github.com/open-webui/open-webui/pull/18761), [#18717](https://github.com/open-webui/open-webui/issues/18717), [#6478](https://github.com/open-webui/open-webui/issues/6478)
- 🔍 Chat query errors are prevented by properly validating and handling the "order_by" parameter to ensure requested columns exist. [#18400](https://github.com/open-webui/open-webui/pull/18400), [#18452](https://github.com/open-webui/open-webui/pull/18452)
- 🔧 Root-level max_tokens parameter is no longer dropped when proxying to Ollama, properly converting to num_predict to limit output token length as intended (see the example after this list). [#18618](https://github.com/open-webui/open-webui/issues/18618)
- 🔑 Self-hosted Marker instances can now be used without requiring an API key, while keeping it optional for datalab Marker service users. [#18617](https://github.com/open-webui/open-webui/issues/18617)
- 🔧 OpenAPI specification endpoint conflict between "/api/v1/models" and "/api/v1/models/" is resolved by changing the models router endpoint to "/list", preventing duplicate operationId errors when generating TypeScript API clients. [#18758](https://github.com/open-webui/open-webui/issues/18758)
- 🏷️ Model tags are now de-duplicated case-insensitively in both the model selector and workspace models page, preventing duplicate entries with different capitalization from appearing in filter dropdowns. [#18716](https://github.com/open-webui/open-webui/pull/18716), [#18711](https://github.com/open-webui/open-webui/issues/18711)
- 📄 Docling RAG parameter configuration is now correctly saved in the admin UI by fixing the typo in the "DOCLING_PARAMS" parameter name. [#18390](https://github.com/open-webui/open-webui/pull/18390)
- 📃 Tika document processing now automatically detects content types instead of relying on potentially incorrect browser-provided mime-types, improving file handling accuracy for formats like RTF. [#18765](https://github.com/open-webui/open-webui/pull/18765), [#18683](https://github.com/open-webui/open-webui/issues/18683)
- 🖼️ Image and video uploads to knowledge bases now display proper error messages instead of showing an infinite spinner when the content extraction engine does not support these file types. [#18514](https://github.com/open-webui/open-webui/issues/18514)
- 📝 Notes PDF export now properly detects and applies dark mode styling consistently across both the notes list and individual note pages, with a shared utility function to eliminate code duplication. [#18526](https://github.com/open-webui/open-webui/issues/18526)
- 💭 Details tags for reasoning content are now correctly identified and rendered even when the same tag is present in user messages. [#18840](https://github.com/open-webui/open-webui/pull/18840), [#18294](https://github.com/open-webui/open-webui/issues/18294)
- 📊 Mermaid and Vega rendering errors now display inline with the code instead of showing repetitive toast notifications, improving user experience when models generate invalid diagram syntax. [Commit](https://github.com/open-webui/open-webui/commit/fdc0f04a8b7dd0bc9f9dc0e7e30854f7a0eea3e9)
- 📈 Mermaid diagram rendering errors no longer cause UI unavailability or display error messages below the input box. [#18493](https://github.com/open-webui/open-webui/pull/18493), [#18340](https://github.com/open-webui/open-webui/issues/18340)
- 🔗 Web search SSL verification is now asynchronous, preventing the website from hanging during web search operations. [#18714](https://github.com/open-webui/open-webui/pull/18714), [#18699](https://github.com/open-webui/open-webui/issues/18699)
- 🌍 Web search results now correctly use HTTP proxy environment variables when WEB_SEARCH_TRUST_ENV is enabled. [#18667](https://github.com/open-webui/open-webui/pull/18667), [#7008](https://github.com/open-webui/open-webui/discussions/7008)
- 🔍 Google Programmable Search Engine now properly includes referer headers, enabling API keys with HTTP referrer restrictions configured in Google Cloud Console. [#18871](https://github.com/open-webui/open-webui/pull/18871), [#18870](https://github.com/open-webui/open-webui/issues/18870)
- ⚡ YouTube video transcript fetching now works correctly when using a proxy connection. [#18419](https://github.com/open-webui/open-webui/pull/18419)
- 🎙️ Speech-to-text transcription no longer deletes or replaces existing text in the prompt input field, properly preserving any previously entered content. [#18540](https://github.com/open-webui/open-webui/issues/18540)
- 🎙️ The "Instant Auto-Send After Voice Transcription" setting now functions correctly and automatically sends transcribed text when enabled. [#18466](https://github.com/open-webui/open-webui/issues/18466)
- ⚙️ Chat settings now load properly when reopening a tab or starting a new session by initializing defaults when sessionStorage is empty. [#18438](https://github.com/open-webui/open-webui/pull/18438)
- 🔎 Folder tag search in the sidebar now correctly handles folder names with multiple spaces by replacing all spaces with underscores. [Commit](https://github.com/open-webui/open-webui/commit/a8fe979af68e47e4e4bb3eb76e48d93d60cd2a45)
- 🛠️ Functions page now updates immediately after deleting a function, removing the need for a manual page reload. [#18912](https://github.com/open-webui/open-webui/pull/18912), [#18908](https://github.com/open-webui/open-webui/issues/18908)
- 🛠️ Native tool calling now properly supports sequential tool calls with shared context, allowing tools to access images and data from previous tool executions in the same conversation. [#18664](https://github.com/open-webui/open-webui/pull/18664)
- 🎯 Globally enabled actions in the model editor now correctly apply as global instead of being treated as disabled. [#18577](https://github.com/open-webui/open-webui/pull/18577)
- 📋 Clipboard images pasted via the "{{CLIPBOARD}}" prompt variable are now correctly converted to base64 format before being sent to the backend, resolving base64 encoding errors. [#18432](https://github.com/open-webui/open-webui/pull/18432), [#18425](https://github.com/open-webui/open-webui/issues/18425)
- 📋 File list is now cleared when switching to models that do not support file uploads, preventing files from being sent to incompatible models. [#18496](https://github.com/open-webui/open-webui/pull/18496)
- 📂 Move menu no longer displays when folders are empty. [#18484](https://github.com/open-webui/open-webui/pull/18484)
- 📁 Folder and channel creation now validates that names are not empty, preventing creation of folders or channels with no name and showing an error toast if attempted. [#18564](https://github.com/open-webui/open-webui/pull/18564)
- 🖊️ Rich text input no longer removes text between equals signs when pasting code with comparison operators. [#18551](https://github.com/open-webui/open-webui/issues/18551)
- ⌨️ Keyboard shortcuts now display the correct keys for international and non-QWERTY keyboard layouts by detecting the user's layout using the Keyboard API. [#18533](https://github.com/open-webui/open-webui/pull/18533)
- 🌐 "Attach Webpage" button now displays with correct disabled styling when a model does not support file uploads. [#18483](https://github.com/open-webui/open-webui/pull/18483)
- 🎚️ Divider no longer displays in the integrations menu when no integrations are enabled. [#18487](https://github.com/open-webui/open-webui/pull/18487)
- 📱 Chat controls button is now properly hidden on mobile for users without admin or explicit chat control permissions. [#18641](https://github.com/open-webui/open-webui/pull/18641)
- 📍 User menu, download submenu, and move submenu are now repositioned to prevent overlap with the Chat Controls sidebar when it is open. [Commit](https://github.com/open-webui/open-webui/commit/414ab51cb6df1ab0d6c85ac6c1f2c5c9a5f8e2aa)
- 🎯 Artifacts button no longer appears in the chat menu when there are no artifacts to display. [Commit](https://github.com/open-webui/open-webui/commit/ed6449d35f84f68dc75ee5c6b3f4748a3fda0096)
- 🎨 Artifacts view now automatically displays when opening an existing conversation containing artifacts, improving user experience. [#18215](https://github.com/open-webui/open-webui/pull/18215)
- 🖌️ Formatting toolbar is no longer hidden under images or code blocks in chat and now displays correctly above all message content.
- 🎨 Layout shift near system instructions is prevented by properly rendering the chat component when system prompts are empty. [#18594](https://github.com/open-webui/open-webui/pull/18594)
- 📐 Modal layout shift caused by scrollbar appearance is prevented by adding a stable scrollbar gutter. [#18591](https://github.com/open-webui/open-webui/pull/18591)
- ✨ Spacing between icon and label in the user menu dropdown items is now consistent. [#18595](https://github.com/open-webui/open-webui/pull/18595)
- 💬 Duplicate prompt suggestions no longer cause the webpage to freeze or throw JavaScript errors by implementing proper key management with composite keys. [#18841](https://github.com/open-webui/open-webui/pull/18841), [#18566](https://github.com/open-webui/open-webui/issues/18566)
- 🔍 Chat preview loading in the search modal now works correctly for all search results by fixing an index boundary check that previously caused out-of-bounds errors. [#18911](https://github.com/open-webui/open-webui/pull/18911)
- ♿ Screen reader support was enhanced by wrapping messages in semantic elements with descriptive aria-labels, adding "Assistant is typing" and "Response complete" announcements for improved accessibility. [#18735](https://github.com/open-webui/open-webui/pull/18735)
- 🔒 Incorrect await call in the OAuth 2.1 flow is removed, eliminating a logged exception during authentication. [#18236](https://github.com/open-webui/open-webui/pull/18236)
- 🛡️ Duplicate crossorigin attribute in the manifest file was removed. [#18413](https://github.com/open-webui/open-webui/pull/18413)
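
For the max_tokens fix above, a sketch of the now-working behavior against the OpenAI-compatible chat endpoint; the host, API key, and model name are placeholders.

```bash
# Illustrative: a root-level max_tokens on an Ollama-backed model is now
# converted to Ollama's num_predict instead of being silently dropped.
curl -s https://webui.example.com/api/chat/completions \
  -H "Authorization: Bearer $OPEN_WEBUI_API_KEY" \
  -H "Content-Type: application/json" \
  -d '{"model": "llama3.2:latest", "max_tokens": 128, "messages": [{"role": "user", "content": "Summarize briefly."}]}'
```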
### Changed
- 🔄 Firecrawl integration was refactored to use the official Firecrawl SDK instead of direct HTTP requests and langchain_community FireCrawlLoader, improving reliability and performance with batch scraping support and enhanced error handling. [#18635](https://github.com/open-webui/open-webui/pull/18635)
- 📄 MinerU content extraction engine now only supports PDF files following the upstream removal of LibreOffice document conversion in version 2.0.0; users needing to process office documents should convert them to PDF format first. [#18448](https://github.com/open-webui/open-webui/issues/18448)
## [0.6.34] - 2025-10-16

### Added
- 📄 MinerU is now supported as a document parser backend, with support for both local and managed API deployments. [#18306](https://github.com/open-webui/open-webui/pull/18306)
- 🔒 JWT token expiration default is now set to 4 weeks instead of never expiring, with security warnings displayed in backend logs and admin UI when set to unlimited (see the example after this list). [#18261](https://github.com/open-webui/open-webui/pull/18261), [#18262](https://github.com/open-webui/open-webui/pull/18262)
- ⚡ Page loading performance is improved by preventing unnecessary API requests when sidebar folders are not expanded. [#18179](https://github.com/open-webui/open-webui/pull/18179), [#17476](https://github.com/open-webui/open-webui/issues/17476)
- 📁 File hash values are now included in the knowledge endpoint response, enabling efficient file synchronization through hash comparison. [#18284](https://github.com/open-webui/open-webui/pull/18284), [#18283](https://github.com/open-webui/open-webui/issues/18283)
- 🎨 Chat dialog scrollbar visibility is improved by increasing its width, making it easier to use for navigation. [#18369](https://github.com/open-webui/open-webui/pull/18369), [#11782](https://github.com/open-webui/open-webui/issues/11782)
- 🔄 Various improvements were implemented across the frontend and backend to enhance performance, stability, and security.
- 🌐 Translations for Catalan, Chinese, Czech, Finnish, German, Kabyle, Korean, Portuguese (Brazil), Spanish, Thai, and Turkish were enhanced and expanded.
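
For the JWT expiration item above, a sketch of pinning an explicit token lifetime; the variable name and duration syntax are assumptions that should be verified against the linked PRs.

```bash
# Illustrative: set an explicit 4-week lifetime; "-1" (never expire) now
# triggers security warnings in logs and the admin UI.
JWT_EXPIRES_IN=4w
```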
### Fixed
- 📚 Focused retrieval mode now works correctly, preventing the system from forcing full context mode and loading all documents in a knowledge base regardless of settings. [#18133](https://github.com/open-webui/open-webui/issues/18133)
- 🔧 Filter inlet functions now correctly execute on tool call continuations, ensuring parameter persistence throughout tool interactions. [#18222](https://github.com/open-webui/open-webui/issues/18222)
- 🛠️ External tool servers now properly support DELETE requests with body data. [#18289](https://github.com/open-webui/open-webui/pull/18289), [#18287](https://github.com/open-webui/open-webui/issues/18287)
- 🗄️ Oracle23ai vector database client now correctly handles variable initialization, resolving UnboundLocalError when retrieving items from collections. [#18356](https://github.com/open-webui/open-webui/issues/18356)
- 🔧 Model auto-pull functionality now works correctly even when user settings remain unmodified. [#18324](https://github.com/open-webui/open-webui/pull/18324)
- 🎨 Duplicate HTML content in artifacts is now prevented by improving code block detection logic. [#18195](https://github.com/open-webui/open-webui/pull/18195), [#6154](https://github.com/open-webui/open-webui/issues/6154)
- 💬 Pinned chats now appear in the Reference Chats list and can be referenced in conversations. [#18288](https://github.com/open-webui/open-webui/issues/18288)
- 📝 Misleading knowledge base warning text in documents settings is clarified to correctly instruct users about reindexing vectors. [#18263](https://github.com/open-webui/open-webui/pull/18263)
- 🔔 Toast notifications can now be dismissed even when a modal is open. [#18260](https://github.com/open-webui/open-webui/pull/18260)
- 🔘 The "Chats" button in the sidebar now correctly toggles chat list visibility without navigating away from the current page. [#18232](https://github.com/open-webui/open-webui/pull/18232)
- 🎯 The Integrations menu no longer closes prematurely when clicking outside the Valves modal. [#18310](https://github.com/open-webui/open-webui/pull/18310)
- 🛠️ Tool ID display issues where "undefined" was incorrectly shown in the interface are now resolved. [#18178](https://github.com/open-webui/open-webui/pull/18178)
- 🛠️ Model management issues caused by excessively long model IDs are now prevented through validation that limits model IDs to 256 characters. [#18125](https://github.com/open-webui/open-webui/issues/18125)
## [0.6.33] - 2025-10-08

### Added
- 🎨 Workspace interface received a comprehensive redesign across Models, Knowledge, Prompts, and Tools sections, featuring reorganized controls, view filters for created vs shared items, tag selectors, improved visual hierarchy, and streamlined import/export functionality. [Commit](https://github.com/open-webui/open-webui/commit/2c59a288603d8c5f004f223ee00fef37cc763a8e), [Commit](https://github.com/open-webui/open-webui/commit/6050c86ab6ef6b8c96dd3f99c62a6867011b67a4), [Commit](https://github.com/open-webui/open-webui/commit/96ecb47bc71c072aa34ef2be10781b042bef4e8c), [Commit](https://github.com/open-webui/open-webui/commit/2250d102b28075a9611696e911536547abb8b38a), [Commit](https://github.com/open-webui/open-webui/commit/23c8f6d507bfee75ab0015a3e2972d5c26f7e9bf), [Commit](https://github.com/open-webui/open-webui/commit/a743b16728c6ae24b8befbc2d7f24eb9e20c4ad5)
- 🛠️ Functions admin interface received a comprehensive redesign with creator attribution display, ownership filters for created vs shared items, improved organization, and refined styling. [Commit](https://github.com/open-webui/open-webui/commit/f5e1a42f51acc0b9d5b63a33c1ca2e42470239c1)
- ⚡ Page initialization performance is significantly improved through parallel data loading and optimized folder API calls, reducing initial page load time. [#17559](https://github.com/open-webui/open-webui/pull/17559), [#17889](https://github.com/open-webui/open-webui/pull/17889)
- ⚡ Chat overview component is now dynamically loaded on demand, reducing initial page bundle size by approximately 470KB and improving first-screen loading speed. [#17595](https://github.com/open-webui/open-webui/pull/17595)
- 📁 Folders can now be attached to chats using the "#" command, automatically expanding to include all files within the folder for streamlined knowledge base integration. [Commit](https://github.com/open-webui/open-webui/commit/d2cb78179d66dc85188172a08622d4c97a2ea1ee)
- 📱 Progressive Web App now supports Android share target functionality, allowing users to share web pages, YouTube videos, and text directly to Open WebUI from the system share menu. [#17633](https://github.com/open-webui/open-webui/pull/17633), [#17125](https://github.com/open-webui/open-webui/issues/17125)
- 🗄️ Redis session storage is now available as an experimental option for OAuth authentication flows via the ENABLE_STAR_SESSIONS_MIDDLEWARE environment variable, providing shared session state across multi-replica deployments to address CSRF errors, though currently only basic Redis setups are supported. [#17223](https://github.com/open-webui/open-webui/pull/17223), [#15373](https://github.com/open-webui/open-webui/issues/15373), [Docs:Commit](https://github.com/open-webui/docs/commit/14052347f165d1b597615370373d7289ce44c7f9)
- 📊 Vega and Vega-Lite chart visualization renderers are now supported in code blocks, enabling inline rendering of data visualizations with automatic compilation of Vega-Lite specifications. [#18033](https://github.com/open-webui/open-webui/pull/18033), [#18040](https://github.com/open-webui/open-webui/pull/18040), [#18022](https://github.com/open-webui/open-webui/issues/18022)
- 🔗 OpenAI connections now support custom HTTP headers, enabling users to configure authentication and routing headers for specific deployment requirements. [#18021](https://github.com/open-webui/open-webui/pull/18021), [#9732](https://github.com/open-webui/open-webui/discussions/9732)
- 🔐 OpenID Connect authentication now supports OIDC providers without email scope via the ENABLE_OAUTH_WITHOUT_EMAIL environment variable, enabling compatibility with identity providers that don't expose email addresses. [#18047](https://github.com/open-webui/open-webui/pull/18047), [#18045](https://github.com/open-webui/open-webui/issues/18045)
- 🤖 Ollama model management modal now features individual model update cancellation, comprehensive tooltips for all buttons, and streamlined notification behavior to reduce toast spam. [#16863](https://github.com/open-webui/open-webui/pull/16863)
- ☁️ OneDrive file picker now includes search functionality and "My Organization" pivot for business accounts, enabling easier file discovery across organizational content. [#17930](https://github.com/open-webui/open-webui/pull/17930), [#17929](https://github.com/open-webui/open-webui/issues/17929)
- 📊 Chat overview flow diagram now supports toggling between vertical and horizontal layout orientations for improved visualization flexibility. [#17941](https://github.com/open-webui/open-webui/pull/17941)
- 🔊 OpenAI Text-to-Speech engine now supports additional parameters, allowing users to customize TTS behavior with provider-specific options via JSON configuration. [#17985](https://github.com/open-webui/open-webui/issues/17985), [#17188](https://github.com/open-webui/open-webui/pull/17188)
- 🛠️ Tool server list now displays server name, URL, and type (OpenAPI or MCP) for easier identification and management. [#18062](https://github.com/open-webui/open-webui/issues/18062)
- 📁 Folders now remember the last selected model, automatically applying it when starting new chats within that folder. [#17836](https://github.com/open-webui/open-webui/issues/17836)
- 🔢 Ollama embedding endpoint now supports the optional dimensions parameter for controlling embedding output size, compatible with Ollama v0.11.11 and later. [#17942](https://github.com/open-webui/open-webui/pull/17942)
- ⚡ Workspace knowledge page load time is improved by removing redundant API calls, enhancing overall responsiveness. [#18057](https://github.com/open-webui/open-webui/pull/18057)
- ⚡ File metadata query performance is enhanced by selecting only relevant columns instead of retrieving entire records, reducing database overhead (see the query sketch after this list). [#18013](https://github.com/open-webui/open-webui/pull/18013)
- 📄 Note PDF exports now include titles and properly render in dark mode with appropriate background colors. [Commit](https://github.com/open-webui/open-webui/commit/216fb5c3db1a223ffe6e72d97aa9551fe0e2d028)
- 📄 Docling document extraction now supports additional parameters for VLM pipeline configuration, enabling customized vision model settings. [#17363](https://github.com/open-webui/open-webui/pull/17363)
- ⚙️ Server startup script now supports passing arbitrary arguments to uvicorn, enabling custom server configuration options. [#17919](https://github.com/open-webui/open-webui/pull/17919), [#17918](https://github.com/open-webui/open-webui/issues/17918)
- 🔄 Various improvements were implemented across the frontend and backend to enhance performance, stability, and security.
- 🌐 Translations for German, Danish, Spanish, Korean, Portuguese (Brazil), Simplified Chinese, and Traditional Chinese were enhanced and expanded.
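The Vega-Lite renderer above consumes ordinary chart specifications. A minimal sketch of the kind of spec it compiles and renders inline, written here as a Python dict for readability (the field values are illustrative, not part of Open WebUI's API):

```python
# A minimal Vega-Lite bar-chart spec; pasting the equivalent JSON into a
# chat code block is what the new inline renderer picks up.
vega_lite_spec = {
    "$schema": "https://vega.github.io/schema/vega-lite/v5.json",
    "description": "Illustrative bar chart",
    "data": {"values": [{"x": "a", "y": 3}, {"x": "b", "y": 7}]},
    "mark": "bar",
    "encoding": {
        "x": {"field": "x", "type": "nominal"},
        "y": {"field": "y", "type": "quantitative"},
    },
}
```
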
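The optional dimensions parameter maps directly onto Ollama's embedding API. A hedged sketch of such a request (endpoint and field names follow Ollama's documented /api/embed interface; the model name is only an example, and support for the parameter depends on the model):

```python
import json
import urllib.request

# Ask a local Ollama (v0.11.11+) for embeddings reduced to 512 dimensions.
payload = {"model": "embeddinggemma", "input": "hello world", "dimensions": 512}
req = urllib.request.Request(
    "http://localhost:11434/api/embed",
    data=json.dumps(payload).encode("utf-8"),
    headers={"Content-Type": "application/json"},
)
with urllib.request.urlopen(req) as resp:
    embeddings = json.load(resp)["embeddings"]
print(len(embeddings[0]))  # expected: 512 when the model honors the parameter
```
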
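The column-selection optimization behind the file-metadata change is the standard SQLAlchemy pattern of querying specific columns instead of hydrating full ORM entities. A generic sketch under assumed names (the File model and its fields are illustrative, not Open WebUI's actual schema):

```python
from sqlalchemy import Column, String, select
from sqlalchemy.orm import Session, declarative_base

Base = declarative_base()

class File(Base):  # illustrative model
    __tablename__ = "file"
    id = Column(String, primary_key=True)
    filename = Column(String)
    data = Column(String)  # potentially large payload we want to avoid loading

def list_file_metadata(session: Session):
    # Selecting only the needed columns returns lightweight Row tuples
    # instead of full File objects (and their large columns).
    stmt = select(File.id, File.filename)
    return session.execute(stmt).all()
```
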
### Fixed
- 💬 System prompts are no longer duplicated in chat requests, eliminating confusion and excessive token usage caused by repeated instructions being sent to models. [#17198](https://github.com/open-webui/open-webui/issues/17198), [#16855](https://github.com/open-webui/open-webui/issues/16855)
- 🔐 MCP OAuth 2.1 authentication now complies with the standard by implementing PKCE with S256 code challenge method and explicitly passing client credentials during token authorization, resolving "code_challenge: Field required" and "client_id: Field required" errors when connecting to OAuth-secured MCP servers (see the PKCE sketch after this list). [Commit](https://github.com/open-webui/open-webui/commit/911a114ad459f5deebd97543c13c2b90196efb54), [#18010](https://github.com/open-webui/open-webui/issues/18010), [#18087](https://github.com/open-webui/open-webui/pull/18087)
- 🔐 OAuth signup flow now handles password hashing correctly by migrating from passlib to native bcrypt, preventing failures when passwords exceed 72 bytes (see the hashing sketch after this list). [#17917](https://github.com/open-webui/open-webui/issues/17917)
- 🔐 OAuth token refresh errors are resolved by properly registering and storing OAuth clients, fixing "Constructor parameter should be str" exceptions for Google, Microsoft, and OIDC providers. [#17829](https://github.com/open-webui/open-webui/issues/17829)
- 🔐 OAuth server metadata URL is now correctly accessed via the proper attribute, fixing automatic token refresh and logout functionality for Microsoft OAuth provider when OPENID_PROVIDER_URL is not set. [#18065](https://github.com/open-webui/open-webui/pull/18065)
- 🔐 OAuth credential decryption failures now allow the application to start gracefully with clear error messages instead of crashing, preventing complete service outages when WEBUI_SECRET_KEY mismatches occur during database migrations or environment changes. [#18094](https://github.com/open-webui/open-webui/pull/18094), [#18092](https://github.com/open-webui/open-webui/issues/18092)
- 🔐 OAuth 2.1 server discovery now correctly attempts all configured discovery URLs in sequence instead of only trying the first URL. [#17906](https://github.com/open-webui/open-webui/pull/17906), [#17904](https://github.com/open-webui/open-webui/issues/17904), [#18026](https://github.com/open-webui/open-webui/pull/18026)
- 🔐 Login redirect now correctly honors the redirect query parameter after authentication, ensuring users are returned to their intended destination with query parameters intact instead of defaulting to the homepage. [#18071](https://github.com/open-webui/open-webui/issues/18071)
- ☁️ OneDrive Business integration authentication regression is resolved, ensuring the popup now properly triggers when connecting to OneDrive accounts. [#17902](https://github.com/open-webui/open-webui/pull/17902), [#17825](https://github.com/open-webui/open-webui/discussions/17825), [#17816](https://github.com/open-webui/open-webui/issues/17816)
- 👥 Default group settings now persist correctly after page navigation, ensuring configuration changes are properly saved and retained. [#17899](https://github.com/open-webui/open-webui/issues/17899), [#18003](https://github.com/open-webui/open-webui/issues/18003)
- 📁 Folder data integrity is now verified on retrieval, automatically fixing orphaned folders with invalid parent references and ensuring proper cascading deletion of nested folder structures. [Commit](https://github.com/open-webui/open-webui/commit/5448618dd5ea181b9635b77040cef60926a902ff)
- 🗄️ The experimental ENABLE_STAR_SESSIONS_MIDDLEWARE feature is now strictly opt-in, so Redis Sentinel and Redis Cluster deployments are no longer affected by it unless explicitly enabled, preventing ReadOnlyError failures when connecting to read replicas in multi-node Redis setups. [#18073](https://github.com/open-webui/open-webui/issues/18073)
- 📊 Mermaid and Vega diagram rendering now displays error toast notifications when syntax errors are detected, helping users identify and fix diagram issues instead of silently failing. [#18068](https://github.com/open-webui/open-webui/pull/18068)
- 🤖 Reasoning models that return reasoning_content instead of content no longer cause NoneType errors during chat title generation, follow-up suggestions, and tag generation. [#18080](https://github.com/open-webui/open-webui/pull/18080)
- 📚 Citation rendering now correctly handles multiple source references in a single bracket, parsing formats like [1,2] and [1, 2] into separate clickable citation links (see the parsing sketch after this list). [#18120](https://github.com/open-webui/open-webui/pull/18120)
- 🔍 Web search now handles individual source failures gracefully, continuing to process remaining sources instead of failing entirely when a single URL is unreachable or returns an error. [Commit](https://github.com/open-webui/open-webui/commit/e000494e488090c5f66989a2b3f89d3eaeb7946b), [Commit](https://github.com/open-webui/open-webui/commit/53e98620bff38ab9280aee5165af0a704bdd99b9)
- 🔍 Hybrid search with reranking now handles empty result sets gracefully instead of crashing with ValueError when all results are filtered out due to relevance thresholds. [#18096](https://github.com/open-webui/open-webui/issues/18096)
- 🔍 Reranking models without defined padding tokens now work correctly by automatically falling back to eos_token_id as pad_token_id, fixing "Cannot handle batch sizes > 1" errors for models like Qwen3-Reranker (see the fallback sketch after this list). [#18108](https://github.com/open-webui/open-webui/pull/18108), [#16027](https://github.com/open-webui/open-webui/discussions/16027)
- 🔍 Model selector search now correctly returns results for non-admin users by dynamically updating the search index when the model list changes, fixing a race condition that caused empty search results. [#17996](https://github.com/open-webui/open-webui/pull/17996), [#17960](https://github.com/open-webui/open-webui/pull/17960)
- ⚡ Task model function calling performance is improved by excluding base64 image data from payloads, significantly reducing token count and memory usage when images are present in conversations. [#17897](https://github.com/open-webui/open-webui/pull/17897)
- 🤖 Text selection "Ask" action now correctly recognizes and uses local models configured via direct connections instead of only showing external provider models. [#17896](https://github.com/open-webui/open-webui/issues/17896)
- 🛑 Task cancellation API now returns accurate response status, correctly reporting successful cancellations instead of incorrectly indicating failures. [#17920](https://github.com/open-webui/open-webui/issues/17920)
- 💬 Follow-up query suggestions are now generated and displayed in temporary chats, matching the behavior of saved chats. [#14987](https://github.com/open-webui/open-webui/issues/14987)
- 🔊 Azure Text-to-Speech now properly escapes special characters like ampersands in SSML, preventing HTTP 400 errors and ensuring audio generation succeeds for all text content (see the escaping sketch after this list). [#17962](https://github.com/open-webui/open-webui/issues/17962)
- 🛠️ OpenAPI tool server calls with optional parameters now execute successfully even when no arguments are provided, removing the incorrect requirement for a request body. [#18036](https://github.com/open-webui/open-webui/issues/18036)
- 🛠️ MCP mode tool server connections no longer incorrectly validate the OpenAPI path field, allowing seamless switching between OpenAPI and MCP connection types. [#17989](https://github.com/open-webui/open-webui/pull/17989), [#17988](https://github.com/open-webui/open-webui/issues/17988)
- 🛠️ Third-party tool responses containing non-UTF8 or invalid byte sequences are now handled gracefully without causing request failures. [#17882](https://github.com/open-webui/open-webui/pull/17882)
- 🎨 Workspace filter dropdown now correctly renders model tags as strings instead of displaying individual characters, fixing broken filtering interface when models have multiple tags. [#18034](https://github.com/open-webui/open-webui/issues/18034)
- ⌨️ Ctrl+Enter keyboard shortcut now correctly sends messages in mobile and narrow browser views on Chrome instead of inserting newlines. [#17975](https://github.com/open-webui/open-webui/issues/17975)
- ⌨️ Tab characters are now preserved when pasting code or formatted text into the chat input box in plain text mode. [#17958](https://github.com/open-webui/open-webui/issues/17958)
- 📋 Text selection copying from the chat input box now correctly copies only the selected text instead of the entire textbox content. [#17911](https://github.com/open-webui/open-webui/issues/17911)
- 🔍 Web search query logging now uses debug level instead of info level, preventing user search queries from appearing in production logs. [#17888](https://github.com/open-webui/open-webui/pull/17888)
- 📝 Debug print statements in middleware were removed to prevent excessive log pollution and respect configured logging levels. [#17943](https://github.com/open-webui/open-webui/issues/17943)
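The PKCE requirement in the MCP OAuth fix is mechanical: derive an S256 code challenge from a random verifier, then return the verifier during token exchange. A standards-level sketch per RFC 7636 (not Open WebUI's internal code):

```python
import base64
import hashlib
import secrets

def make_pkce_pair() -> tuple[str, str]:
    # URL-safe, unpadded verifier (43-128 chars per RFC 7636).
    verifier = base64.urlsafe_b64encode(secrets.token_bytes(32)).rstrip(b"=").decode()
    # S256 challenge = BASE64URL(SHA256(verifier)), also without padding.
    digest = hashlib.sha256(verifier.encode("ascii")).digest()
    challenge = base64.urlsafe_b64encode(digest).rstrip(b"=").decode()
    return verifier, challenge

verifier, challenge = make_pkce_pair()
# The authorization request carries code_challenge=<challenge> with
# code_challenge_method=S256; the token request carries code_verifier=<verifier>
# together with the client credentials, which is what the fix adds.
```
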
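The 72-byte limit behind the OAuth signup fix is a property of bcrypt itself: input past 72 bytes is truncated or rejected depending on the wrapper. A hedged sketch of hashing with the native bcrypt package (generic usage, not the project's exact helper):

```python
import bcrypt

def hash_password(password: str) -> bytes:
    data = password.encode("utf-8")
    # bcrypt only considers the first 72 bytes; some wrappers raise instead of
    # silently truncating, which is what broke long OAuth-generated secrets.
    return bcrypt.hashpw(data[:72], bcrypt.gensalt())

def verify_password(password: str, hashed: bytes) -> bool:
    return bcrypt.checkpw(password.encode("utf-8")[:72], hashed)
```
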
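Splitting multi-source brackets like [1,2] into individual citation links is a small parsing exercise. An illustrative sketch of the idea only (the regex and function are hypothetical; the real implementation lives in the Svelte frontend):

```python
import re

def split_citations(text: str) -> str:
    # Turn "[1,2]" or "[1, 2]" into "[1][2]" so each number can become
    # its own clickable citation link.
    def expand(match: re.Match) -> str:
        numbers = [n.strip() for n in match.group(1).split(",")]
        return "".join(f"[{n}]" for n in numbers)

    return re.sub(r"\[(\d+(?:\s*,\s*\d+)+)\]", expand, text)

print(split_citations("Fact A [1,2] and fact B [3, 4]."))
# -> "Fact A [1][2] and fact B [3][4]."
```
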
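The reranker fix uses a common Hugging Face idiom: batched inference needs a padding token, and some models ship without one. A hedged sketch of the fallback (standard transformers usage; the model name is an example):

```python
from transformers import AutoTokenizer

tokenizer = AutoTokenizer.from_pretrained("Qwen/Qwen3-Reranker-0.6B")

# Without a pad token, padded batches ("batch sizes > 1") fail; reusing the
# EOS token as padding is the conventional fallback.
if tokenizer.pad_token_id is None:
    tokenizer.pad_token_id = tokenizer.eos_token_id
```
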
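The Azure TTS fix comes down to XML escaping: SSML is XML, so raw ampersands and angle brackets in chat text make the service reject the request. A minimal sketch using the standard library (the SSML envelope and voice name are generic examples, not Open WebUI's exact template):

```python
from xml.sax.saxutils import escape

def to_ssml(text: str, voice: str = "en-US-JennyNeural") -> str:
    # escape() converts & < > into &amp; &lt; &gt; so input like "AT&T" no
    # longer produces malformed SSML and an HTTP 400 from the service.
    safe = escape(text)
    return (
        '<speak version="1.0" xml:lang="en-US">'
        f'<voice name="{voice}">{safe}</voice>'
        "</speak>"
    )

print(to_ssml("AT&T results <today>"))
```
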
### Changed
- 🗄️ Milvus vector database dependency is updated from pymilvus 2.5.0 to 2.6.2, ensuring compatibility with newer Milvus versions but requiring users on older Milvus instances to either upgrade their database or manually downgrade the pymilvus package. [#18066](https://github.com/open-webui/open-webui/pull/18066)
## [0.6.32] - 2025-09-29
### Added
- ⚡ JSON model import moved to backend processing for significant performance improvements when importing large model files. [#17871](https://github.com/open-webui/open-webui/pull/17871)
- ⚠️ Visual warnings for group permissions that display when a permission is disabled in a group but remains enabled in the default user role, clarifying inheritance behavior for administrators. [#17848](https://github.com/open-webui/open-webui/pull/17848)
- 🗄️ Milvus multi-tenancy mode using shared collections with resource ID filtering for improved scalability, mirroring the existing Qdrant implementation and configurable via ENABLE_MILVUS_MULTITENANCY_MODE environment variable. [#17837](https://github.com/open-webui/open-webui/pull/17837)
- 🛠️ Enhanced tool result processing with improved error handling, better MCP tool result handling, and performance improvements for embedded UI components. [Commit](https://github.com/open-webui/open-webui/commit/4f06f29348b2c9d71c87d1bbe5b748a368f5101f)
- 👥 New user groups now automatically inherit default group permissions, streamlining the admin setup process by eliminating manual permission configuration. [#17843](https://github.com/open-webui/open-webui/pull/17843)
- 🗂️ Bulk unarchive functionality for all chats, providing a single backend endpoint to efficiently restore all archived chats at once (see the endpoint sketch after this list). [#17857](https://github.com/open-webui/open-webui/pull/17857)
- 🏷️ Browser tab title toggle setting allows users to control whether chat titles appear in the browser tab or display only "Open WebUI". [#17851](https://github.com/open-webui/open-webui/pull/17851)
- 💬 Reply-to-message functionality in channels, allowing users to reply directly to specific messages with visual threading and context display. [Commit](https://github.com/open-webui/open-webui/commit/1a18928c94903ad1f1f0391b8ade042c3e60205b)
- 🔧 Tool server import and export functionality, allowing direct upload of openapi.json and openapi.yaml files as an alternative to URL-based configuration. [#14446](https://github.com/open-webui/open-webui/issues/14446)
- 🔧 User valve configuration for Functions is now available in the integration menu, providing consistent management alongside Tools. [#17784](https://github.com/open-webui/open-webui/issues/17784)
- 🔐 Admin permission toggle for controlling public sharing of notes, configurable via USER_PERMISSIONS_NOTES_ALLOW_PUBLIC_SHARING environment variable. [#17801](https://github.com/open-webui/open-webui/pull/17801), [Docs:#715](https://github.com/open-webui/docs/pull/715)
- 🗄️ DISKANN index type support for Milvus vector database with configurable maximum degree and search list size parameters. [#17770](https://github.com/open-webui/open-webui/pull/17770), [Docs:Commit](https://github.com/open-webui/docs/commit/cec50ab4d4b659558ca1ccd4b5e6fc024f05fb83)
- 🔄 Various improvements were implemented across the frontend and backend to enhance performance, stability, and security.
- 🌐 Translations for Chinese (Simplified & Traditional) and Bosnian (Latin) were enhanced and expanded.
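The bulk-unarchive feature is a single backend endpoint rather than per-chat calls from the client. A hypothetical FastAPI sketch of the shape such an endpoint takes (the route path, auth dependency, and data-access helper are invented for illustration; see the PR for the real implementation):

```python
from fastapi import APIRouter, Depends

router = APIRouter()

def get_current_user():  # stand-in for the real auth dependency
    class User:
        id = "demo-user"
    return User()

def unarchive_all_for_user(user_id: str) -> int:
    # Placeholder for a single UPDATE ... WHERE user_id = ? AND archived = true.
    return 0

@router.post("/chats/unarchive/all")  # hypothetical route
async def unarchive_all_chats(user=Depends(get_current_user)):
    # One UPDATE restores every archived chat for this user in a single
    # round trip, instead of one client-issued request per chat.
    restored = unarchive_all_for_user(user.id)
    return {"status": True, "restored": restored}
```
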
### Fixed
- 🛠️ MCP tool calls are now correctly routed to the appropriate server when multiple streamable-http MCP servers are enabled, preventing "Tool not found" errors. [#17817](https://github.com/open-webui/open-webui/issues/17817)
- 🛠️ External tool servers (OpenAPI/MCP) now properly process and return tool results to the model, restoring functionality that was broken in v0.6.31. [#17764](https://github.com/open-webui/open-webui/issues/17764)
- 🐛 The "Set as default" option is now reliably clickable in model and filter selection menus, fixing cases where the interface appeared unresponsive.
|
- 🔧 User valve detection now correctly identifies valves in imported tool code, ensuring gear icons appear in the integrations menu for all tools with user valves. [#17765](https://github.com/open-webui/open-webui/issues/17765)
- 🔐 MCP OAuth discovery now correctly handles multi-tenant configurations by including subpaths in metadata URL discovery. [#17768](https://github.com/open-webui/open-webui/issues/17768)
- 🗄️ Milvus query operations now correctly use -1 instead of None for unlimited queries, preventing TypeError exceptions. [#17769](https://github.com/open-webui/open-webui/pull/17769), [#17088](https://github.com/open-webui/open-webui/issues/17088)
- 📁 File upload error messages are now displayed when files are modified during upload, preventing user confusion on Android and Windows devices. [#17777](https://github.com/open-webui/open-webui/pull/17777)
- 🎨 MessageInput Integrations button hover effect now displays correctly with proper visual feedback. [#17767](https://github.com/open-webui/open-webui/pull/17767)
- 🎯 "Set as default" label positioning is fixed to ensure it remains clickable in all scenarios, including multi-model configurations. [#17779](https://github.com/open-webui/open-webui/pull/17779)
- 🎛️ Floating buttons now correctly retrieve message context by using the proper messageId parameter in createMessagesList calls. [#17823](https://github.com/open-webui/open-webui/pull/17823)
- 📌 Pinned chats are now properly cleared from the sidebar after archiving all chats, ensuring UI consistency without requiring a page refresh. [#17832](https://github.com/open-webui/open-webui/pull/17832)
- 🗑️ Delete confirmation modals now properly truncate long names for Notes, Prompts, Tools, and Functions to prevent modal overflow. [#17812](https://github.com/open-webui/open-webui/pull/17812)
- 🌐 Internationalization function calls now use proper Svelte store subscription syntax, preventing "i18n.t is not a function" errors on the model creation page. [#17819](https://github.com/open-webui/open-webui/pull/17819)
- 🎨 Playground chat interface button layout is corrected to prevent vertical text rendering for Assistant/User role buttons. [#17819](https://github.com/open-webui/open-webui/pull/17819)
- 🏷️ UI text truncation is improved across multiple components including usernames in admin panels, arena model names, model tags, and filter tags to prevent layout overflow issues. [#17805](https://github.com/open-webui/open-webui/pull/17805), [#17803](https://github.com/open-webui/open-webui/pull/17803), [#17791](https://github.com/open-webui/open-webui/pull/17791), [#17796](https://github.com/open-webui/open-webui/pull/17796)
## [0.6.31] - 2025-09-25
README.md

@@ -17,7 +17,7 @@ Passionate about open-source AI? [Join our team →](https://careers.openwebui.c
 
 
 > [!TIP]
-> **Looking for an [Enterprise Plan](https://docs.openwebui.com/enterprise)?** – **[Speak with Our Sales Team Today!](mailto:sales@openwebui.com)**
+> **Looking for an [Enterprise Plan](https://docs.openwebui.com/enterprise)?** – **[Speak with Our Sales Team Today!](https://docs.openwebui.com/enterprise)**
 >
 > Get **enhanced capabilities**, including **custom theming and branding**, **Service Level Agreement (SLA) support**, **Long-Term Support (LTS) versions**, and **more!**
 

@@ -31,32 +31,44 @@ For more information, be sure to check out our [Open WebUI Documentation](https:
 
 - 🛡️ **Granular Permissions and User Groups**: By allowing administrators to create detailed user roles and permissions, we ensure a secure user environment. This granularity not only enhances security but also allows for customized user experiences, fostering a sense of ownership and responsibility amongst users.
 
-- 🔄 **SCIM 2.0 Support**: Enterprise-grade user and group provisioning through SCIM 2.0 protocol, enabling seamless integration with identity providers like Okta, Azure AD, and Google Workspace for automated user lifecycle management.
-
 - 📱 **Responsive Design**: Enjoy a seamless experience across Desktop PC, Laptop, and Mobile devices.
 
 - 📱 **Progressive Web App (PWA) for Mobile**: Enjoy a native app-like experience on your mobile device with our PWA, providing offline access on localhost and a seamless user interface.
 
 - ✒️🔢 **Full Markdown and LaTeX Support**: Elevate your LLM experience with comprehensive Markdown and LaTeX capabilities for enriched interaction.
 
-- 🎤📹 **Hands-Free Voice/Video Call**: Experience seamless communication with integrated hands-free voice and video call features, allowing for a more dynamic and interactive chat environment.
+- 🎤📹 **Hands-Free Voice/Video Call**: Experience seamless communication with integrated hands-free voice and video call features using multiple Speech-to-Text providers (Local Whisper, OpenAI, Deepgram, Azure) and Text-to-Speech engines (Azure, ElevenLabs, OpenAI, Transformers, WebAPI), allowing for dynamic and interactive chat environments.
 
 - 🛠️ **Model Builder**: Easily create Ollama models via the Web UI. Create and add custom characters/agents, customize chat elements, and import models effortlessly through [Open WebUI Community](https://openwebui.com/) integration.
 
 - 🐍 **Native Python Function Calling Tool**: Enhance your LLMs with built-in code editor support in the tools workspace. Bring Your Own Function (BYOF) by simply adding your pure Python functions, enabling seamless integration with LLMs.
 
-- 📚 **Local RAG Integration**: Dive into the future of chat interactions with groundbreaking Retrieval Augmented Generation (RAG) support. This feature seamlessly integrates document interactions into your chat experience. You can load documents directly into the chat or add files to your document library, effortlessly accessing them using the `#` command before a query.
+- 💾 **Persistent Artifact Storage**: Built-in key-value storage API for artifacts, enabling features like journals, trackers, leaderboards, and collaborative tools with both personal and shared data scopes across sessions.
 
-- 🔍 **Web Search for RAG**: Perform web searches using providers like `SearXNG`, `Google PSE`, `Brave Search`, `serpstack`, `serper`, `Serply`, `DuckDuckGo`, `TavilySearch`, `SearchApi` and `Bing` and inject the results directly into your chat experience.
+- 📚 **Local RAG Integration**: Dive into the future of chat interactions with groundbreaking Retrieval Augmented Generation (RAG) support using your choice of 9 vector databases and multiple content extraction engines (Tika, Docling, Document Intelligence, Mistral OCR, External loaders). Load documents directly into chat or add files to your document library, effortlessly accessing them using the `#` command before a query.
+
+- 🔍 **Web Search for RAG**: Perform web searches using 15+ providers including `SearXNG`, `Google PSE`, `Brave Search`, `Kagi`, `Mojeek`, `Tavily`, `Perplexity`, `serpstack`, `serper`, `Serply`, `DuckDuckGo`, `SearchApi`, `SerpApi`, `Bing`, `Jina`, `Exa`, `Sougou`, `Azure AI Search`, and `Ollama Cloud`, injecting results directly into your chat experience.
 
 - 🌐 **Web Browsing Capability**: Seamlessly integrate websites into your chat experience using the `#` command followed by a URL. This feature allows you to incorporate web content directly into your conversations, enhancing the richness and depth of your interactions.
 
-- 🎨 **Image Generation Integration**: Seamlessly incorporate image generation capabilities using options such as AUTOMATIC1111 API or ComfyUI (local), and OpenAI's DALL-E (external), enriching your chat experience with dynamic visual content.
+- 🎨 **Image Generation & Editing Integration**: Create and edit images using multiple engines including OpenAI's DALL-E, Gemini, ComfyUI (local), and AUTOMATIC1111 (local), with support for both generation and prompt-based editing workflows.
 
 - ⚙️ **Many Models Conversations**: Effortlessly engage with various models simultaneously, harnessing their unique strengths for optimal responses. Enhance your experience by leveraging a diverse set of models in parallel.
 
 - 🔐 **Role-Based Access Control (RBAC)**: Ensure secure access with restricted permissions; only authorized individuals can access your Ollama, and exclusive model creation/pulling rights are reserved for administrators.
 
+- 🗄️ **Flexible Database & Storage Options**: Choose from SQLite (with optional encryption), PostgreSQL, or configure cloud storage backends (S3, Google Cloud Storage, Azure Blob Storage) for scalable deployments.
+
+- 🔍 **Advanced Vector Database Support**: Select from 9 vector database options including ChromaDB, PGVector, Qdrant, Milvus, Elasticsearch, OpenSearch, Pinecone, S3Vector, and Oracle 23ai for optimal RAG performance.
+
+- 🔐 **Enterprise Authentication**: Full support for LDAP/Active Directory integration, SCIM 2.0 automated provisioning, and SSO via trusted headers alongside OAuth providers. Enterprise-grade user and group provisioning through SCIM 2.0 protocol, enabling seamless integration with identity providers like Okta, Azure AD, and Google Workspace for automated user lifecycle management.
+
+- ☁️ **Cloud-Native Integration**: Native support for Google Drive and OneDrive/SharePoint file picking, enabling seamless document import from enterprise cloud storage.
+
+- 📊 **Production Observability**: Built-in OpenTelemetry support for traces, metrics, and logs, enabling comprehensive monitoring with your existing observability stack.
+
+- ⚖️ **Horizontal Scalability**: Redis-backed session management and WebSocket support for multi-worker and multi-node deployments behind load balancers.
 
 - 🌐🌍 **Multilingual Support**: Experience Open WebUI in your preferred language with our internationalization (i18n) support. Join us in expanding our supported languages! We're actively seeking contributors!
 
 - 🧩 **Pipelines, Open WebUI Plugin Support**: Seamlessly integrate custom logic and Python libraries into Open WebUI using [Pipelines Plugin Framework](https://github.com/open-webui/pipelines). Launch your Pipelines instance, set the OpenAI URL to the Pipelines URL, and explore endless possibilities. [Examples](https://github.com/open-webui/pipelines/tree/main/examples) include **Function Calling**, User **Rate Limiting** to control access, **Usage Monitoring** with tools like Langfuse, **Live Translation with LibreTranslate** for multilingual support, **Toxic Message Filtering** and much more.
 

@@ -65,43 +77,6 @@ For more information, be sure to check out our [Open WebUI Documentation](https:
 
 Want to learn more about Open WebUI's features? Check out our [Open WebUI documentation](https://docs.openwebui.com/features) for a comprehensive overview!
 
-## Sponsors 🙌
-
-#### Emerald
-
-<table>
-<!-- <tr>
-<td>
-<a href="https://n8n.io/" target="_blank">
-<img src="https://docs.openwebui.com/sponsors/logos/n8n.png" alt="n8n" style="width: 8rem; height: 8rem; border-radius: .75rem;" />
-</a>
-</td>
-<td>
-<a href="https://n8n.io/">n8n</a> • Does your interface have a backend yet?<br>Try <a href="https://n8n.io/">n8n</a>
-</td>
-</tr> -->
-<tr>
-<td>
-<a href="https://tailscale.com/blog/self-host-a-local-ai-stack/?utm_source=OpenWebUI&utm_medium=paid-ad-placement&utm_campaign=OpenWebUI-Docs" target="_blank">
-<img src="https://docs.openwebui.com/sponsors/logos/tailscale.png" alt="Tailscale" style="width: 8rem; height: 8rem; border-radius: .75rem;" />
-</a>
-</td>
-<td>
-<a href="https://tailscale.com/blog/self-host-a-local-ai-stack/?utm_source=OpenWebUI&utm_medium=paid-ad-placement&utm_campaign=OpenWebUI-Docs">Tailscale</a> • Connect self-hosted AI to any device with Tailscale
-</td>
-</tr>
-<tr>
-<td>
-<a href="https://warp.dev/open-webui" target="_blank">
-<img src="https://docs.openwebui.com/sponsors/logos/warp.png" alt="Warp" style="width: 8rem; height: 8rem; border-radius: .75rem;" />
-</a>
-</td>
-<td>
-<a href="https://warp.dev/open-webui">Warp</a> • The intelligent terminal for developers
-</td>
-</tr>
-</table>
-
 ---
 
 We are incredibly grateful for the generous support of our sponsors. Their contributions help us to maintain and improve our project, ensuring we can continue to deliver quality work to our community. Thank you!

File diff suppressed because it is too large

@@ -38,13 +38,14 @@ class ERROR_MESSAGES(str, Enum):
     ID_TAKEN = "Uh-oh! This id is already registered. Please choose another id string."
     MODEL_ID_TAKEN = "Uh-oh! This model id is already registered. Please choose another model id string."
     NAME_TAG_TAKEN = "Uh-oh! This name tag is already registered. Please choose another name tag string."
+    MODEL_ID_TOO_LONG = "The model id is too long. Please make sure your model id is less than 256 characters long."
 
     INVALID_TOKEN = (
         "Your session has expired or the token is invalid. Please sign in again."
     )
     INVALID_CRED = "The email or password provided is incorrect. Please check for typos and try logging in again."
     INVALID_EMAIL_FORMAT = "The email format you entered is invalid. Please double-check and make sure you're using a valid email address (e.g., yourname@example.com)."
-    INVALID_PASSWORD = (
+    INCORRECT_PASSWORD = (
         "The password provided is incorrect. Please check for typos and try again."
     )
     INVALID_TRUSTED_HEADER = "Your provider has not provided a trusted header. Please contact your administrator for assistance."

@@ -104,6 +105,10 @@ class ERROR_MESSAGES(str, Enum):
     )
     FILE_NOT_PROCESSED = "Extracted content is not available for this file. Please ensure that the file is processed before proceeding."
 
+    INVALID_PASSWORD = lambda err="": (
+        err if err else "The password does not meet the required validation criteria."
+    )
+
 
 class TASKS(str, Enum):
     def __str__(self) -> str:
@@ -8,6 +8,8 @@ import shutil
 from uuid import uuid4
 from pathlib import Path
 from cryptography.hazmat.primitives import serialization
+import re
+
 
 import markdown
 from bs4 import BeautifulSoup

@@ -135,6 +137,9 @@ else:
     PACKAGE_DATA = {"version": "0.0.0"}
 
 VERSION = PACKAGE_DATA["version"]
 
+
+DEPLOYMENT_ID = os.environ.get("DEPLOYMENT_ID", "")
 INSTANCE_ID = os.environ.get("INSTANCE_ID", str(uuid4()))
 

@@ -212,6 +217,11 @@ ENABLE_FORWARD_USER_INFO_HEADERS = (
     os.environ.get("ENABLE_FORWARD_USER_INFO_HEADERS", "False").lower() == "true"
 )
 
+# Experimental feature, may be removed in future
+ENABLE_STAR_SESSIONS_MIDDLEWARE = (
+    os.environ.get("ENABLE_STAR_SESSIONS_MIDDLEWARE", "False").lower() == "true"
+)
+
 ####################################
 # WEBUI_BUILD_HASH
 ####################################

@@ -421,6 +431,17 @@ WEBUI_AUTH_TRUSTED_GROUPS_HEADER = os.environ.get(
 )
 
 
+ENABLE_PASSWORD_VALIDATION = (
+    os.environ.get("ENABLE_PASSWORD_VALIDATION", "False").lower() == "true"
+)
+PASSWORD_VALIDATION_REGEX_PATTERN = os.environ.get(
+    "PASSWORD_VALIDATION_REGEX_PATTERN",
+    "^(?=.*[a-z])(?=.*[A-Z])(?=.*\d)(?=.*[^\w\s]).{8,}$",
+)
+
+PASSWORD_VALIDATION_REGEX_PATTERN = re.compile(PASSWORD_VALIDATION_REGEX_PATTERN)
+
+
 BYPASS_MODEL_ACCESS_CONTROL = (
     os.environ.get("BYPASS_MODEL_ACCESS_CONTROL", "False").lower() == "true"
 )

@@ -468,7 +489,9 @@ ENABLE_COMPRESSION_MIDDLEWARE = (
 ####################################
 # OAUTH Configuration
 ####################################
+ENABLE_OAUTH_EMAIL_FALLBACK = (
+    os.environ.get("ENABLE_OAUTH_EMAIL_FALLBACK", "False").lower() == "true"
+)
 
 ENABLE_OAUTH_ID_TOKEN_COOKIE = (
     os.environ.get("ENABLE_OAUTH_ID_TOKEN_COOKIE", "True").lower() == "true"

@@ -482,12 +505,14 @@ OAUTH_SESSION_TOKEN_ENCRYPTION_KEY = os.environ.get(
     "OAUTH_SESSION_TOKEN_ENCRYPTION_KEY", WEBUI_SECRET_KEY
 )
 
 
 ####################################
 # SCIM Configuration
 ####################################
 
-SCIM_ENABLED = os.environ.get("SCIM_ENABLED", "False").lower() == "true"
+ENABLE_SCIM = (
+    os.environ.get("ENABLE_SCIM", os.environ.get("SCIM_ENABLED", "False")).lower()
+    == "true"
+)
 SCIM_TOKEN = os.environ.get("SCIM_TOKEN", "")
 
 ####################################

@@ -535,6 +560,11 @@ else:
 # CHAT
 ####################################
 
+ENABLE_CHAT_RESPONSE_BASE64_IMAGE_URL_CONVERSION = (
+    os.environ.get("ENABLE_CHAT_RESPONSE_BASE64_IMAGE_URL_CONVERSION", "False").lower()
+    == "true"
+)
+
 CHAT_RESPONSE_STREAM_DELTA_CHUNK_SIZE = os.environ.get(
     "CHAT_RESPONSE_STREAM_DELTA_CHUNK_SIZE", "1"
 )

@@ -563,6 +593,21 @@ else:
     CHAT_RESPONSE_MAX_TOOL_CALL_RETRIES = 30
 
 
+CHAT_STREAM_RESPONSE_CHUNK_MAX_BUFFER_SIZE = os.environ.get(
+    "CHAT_STREAM_RESPONSE_CHUNK_MAX_BUFFER_SIZE", ""
+)
+
+if CHAT_STREAM_RESPONSE_CHUNK_MAX_BUFFER_SIZE == "":
+    CHAT_STREAM_RESPONSE_CHUNK_MAX_BUFFER_SIZE = None
+else:
+    try:
+        CHAT_STREAM_RESPONSE_CHUNK_MAX_BUFFER_SIZE = int(
+            CHAT_STREAM_RESPONSE_CHUNK_MAX_BUFFER_SIZE
+        )
+    except Exception:
+        CHAT_STREAM_RESPONSE_CHUNK_MAX_BUFFER_SIZE = None
+
+
 ####################################
 # WEBSOCKET SUPPORT
 ####################################

@@ -574,6 +619,17 @@ ENABLE_WEBSOCKET_SUPPORT = (
 
 WEBSOCKET_MANAGER = os.environ.get("WEBSOCKET_MANAGER", "")
 
+WEBSOCKET_REDIS_OPTIONS = os.environ.get("WEBSOCKET_REDIS_OPTIONS", "")
+if WEBSOCKET_REDIS_OPTIONS == "":
+    log.debug("No WEBSOCKET_REDIS_OPTIONS provided, defaulting to None")
+    WEBSOCKET_REDIS_OPTIONS = None
+else:
+    try:
+        WEBSOCKET_REDIS_OPTIONS = json.loads(WEBSOCKET_REDIS_OPTIONS)
+    except Exception:
+        log.warning("Invalid WEBSOCKET_REDIS_OPTIONS, defaulting to None")
+        WEBSOCKET_REDIS_OPTIONS = None
+
 WEBSOCKET_REDIS_URL = os.environ.get("WEBSOCKET_REDIS_URL", REDIS_URL)
 WEBSOCKET_REDIS_CLUSTER = (
     os.environ.get("WEBSOCKET_REDIS_CLUSTER", str(REDIS_CLUSTER)).lower() == "true"

@@ -588,6 +644,23 @@ except ValueError:
 
 WEBSOCKET_SENTINEL_HOSTS = os.environ.get("WEBSOCKET_SENTINEL_HOSTS", "")
 WEBSOCKET_SENTINEL_PORT = os.environ.get("WEBSOCKET_SENTINEL_PORT", "26379")
+WEBSOCKET_SERVER_LOGGING = (
+    os.environ.get("WEBSOCKET_SERVER_LOGGING", "False").lower() == "true"
+)
+WEBSOCKET_SERVER_ENGINEIO_LOGGING = (
+    os.environ.get("WEBSOCKET_SERVER_LOGGING", "False").lower() == "true"
+)
+WEBSOCKET_SERVER_PING_TIMEOUT = os.environ.get("WEBSOCKET_SERVER_PING_TIMEOUT", "20")
+try:
+    WEBSOCKET_SERVER_PING_TIMEOUT = int(WEBSOCKET_SERVER_PING_TIMEOUT)
+except ValueError:
+    WEBSOCKET_SERVER_PING_TIMEOUT = 20
+
+WEBSOCKET_SERVER_PING_INTERVAL = os.environ.get("WEBSOCKET_SERVER_PING_INTERVAL", "25")
+try:
+    WEBSOCKET_SERVER_PING_INTERVAL = int(WEBSOCKET_SERVER_PING_INTERVAL)
+except ValueError:
+    WEBSOCKET_SERVER_PING_INTERVAL = 25
 
 
 AIOHTTP_CLIENT_TIMEOUT = os.environ.get("AIOHTTP_CLIENT_TIMEOUT", "")

@@ -700,7 +773,9 @@ if OFFLINE_MODE:
 # AUDIT LOGGING
 ####################################
 # Where to store log file
-AUDIT_LOGS_FILE_PATH = f"{DATA_DIR}/audit.log"
+# Defaults to the DATA_DIR/audit.log. To set AUDIT_LOGS_FILE_PATH you need to
+# provide the whole path, like: /app/audit.log
+AUDIT_LOGS_FILE_PATH = os.getenv("AUDIT_LOGS_FILE_PATH", f"{DATA_DIR}/audit.log")
 # Maximum size of a file before rotating into a new log file
 AUDIT_LOG_FILE_ROTATION_SIZE = os.getenv("AUDIT_LOG_FILE_ROTATION_SIZE", "10MB")
@@ -8,6 +8,7 @@ import shutil
 import sys
 import time
 import random
+import re
 from uuid import uuid4
 
 

@@ -60,11 +61,11 @@ from open_webui.utils import logger
 from open_webui.utils.audit import AuditLevel, AuditLoggingMiddleware
 from open_webui.utils.logger import start_logger
 from open_webui.socket.main import (
+    MODELS,
     app as socket_app,
     periodic_usage_pool_cleanup,
     get_event_emitter,
     get_models_in_use,
-    get_active_user_ids,
 )
 from open_webui.routers import (
     audio,

@@ -145,9 +146,7 @@ from open_webui.config import (
     # Image
     AUTOMATIC1111_API_AUTH,
     AUTOMATIC1111_BASE_URL,
-    AUTOMATIC1111_CFG_SCALE,
-    AUTOMATIC1111_SAMPLER,
-    AUTOMATIC1111_SCHEDULER,
+    AUTOMATIC1111_PARAMS,
     COMFYUI_BASE_URL,
     COMFYUI_API_KEY,
     COMFYUI_WORKFLOW,

@@ -161,8 +160,23 @@ from open_webui.config import (
     IMAGES_OPENAI_API_BASE_URL,
     IMAGES_OPENAI_API_VERSION,
     IMAGES_OPENAI_API_KEY,
+    IMAGES_OPENAI_API_PARAMS,
     IMAGES_GEMINI_API_BASE_URL,
     IMAGES_GEMINI_API_KEY,
+    IMAGES_GEMINI_ENDPOINT_METHOD,
+    ENABLE_IMAGE_EDIT,
+    IMAGE_EDIT_ENGINE,
+    IMAGE_EDIT_MODEL,
+    IMAGE_EDIT_SIZE,
+    IMAGES_EDIT_OPENAI_API_BASE_URL,
+    IMAGES_EDIT_OPENAI_API_KEY,
+    IMAGES_EDIT_OPENAI_API_VERSION,
+    IMAGES_EDIT_GEMINI_API_BASE_URL,
+    IMAGES_EDIT_GEMINI_API_KEY,
+    IMAGES_EDIT_COMFYUI_BASE_URL,
+    IMAGES_EDIT_COMFYUI_API_KEY,
+    IMAGES_EDIT_COMFYUI_WORKFLOW,
+    IMAGES_EDIT_COMFYUI_WORKFLOW_NODES,
     # Audio
     AUDIO_STT_ENGINE,
     AUDIO_STT_MODEL,

@@ -174,13 +188,17 @@ from open_webui.config import (
     AUDIO_STT_AZURE_LOCALES,
     AUDIO_STT_AZURE_BASE_URL,
     AUDIO_STT_AZURE_MAX_SPEAKERS,
-    AUDIO_TTS_API_KEY,
+    AUDIO_STT_MISTRAL_API_KEY,
+    AUDIO_STT_MISTRAL_API_BASE_URL,
+    AUDIO_STT_MISTRAL_USE_CHAT_COMPLETIONS,
     AUDIO_TTS_ENGINE,
     AUDIO_TTS_MODEL,
+    AUDIO_TTS_VOICE,
     AUDIO_TTS_OPENAI_API_BASE_URL,
     AUDIO_TTS_OPENAI_API_KEY,
+    AUDIO_TTS_OPENAI_PARAMS,
+    AUDIO_TTS_API_KEY,
     AUDIO_TTS_SPLIT_ON,
-    AUDIO_TTS_VOICE,
     AUDIO_TTS_AZURE_SPEECH_REGION,
     AUDIO_TTS_AZURE_SPEECH_BASE_URL,
     AUDIO_TTS_AZURE_SPEECH_OUTPUT_FORMAT,

@@ -212,6 +230,7 @@ from open_webui.config import (
     RAG_RERANKING_MODEL_TRUST_REMOTE_CODE,
     RAG_EMBEDDING_ENGINE,
     RAG_EMBEDDING_BATCH_SIZE,
+    ENABLE_ASYNC_EMBEDDING,
     RAG_TOP_K,
     RAG_TOP_K_RERANKER,
     RAG_RELEVANCE_THRESHOLD,

@@ -241,24 +260,21 @@ from open_webui.config import (
     DATALAB_MARKER_DISABLE_IMAGE_EXTRACTION,
     DATALAB_MARKER_FORMAT_LINES,
     DATALAB_MARKER_OUTPUT_FORMAT,
+    MINERU_API_MODE,
+    MINERU_API_URL,
+    MINERU_API_KEY,
+    MINERU_PARAMS,
     DATALAB_MARKER_USE_LLM,
     EXTERNAL_DOCUMENT_LOADER_URL,
     EXTERNAL_DOCUMENT_LOADER_API_KEY,
     TIKA_SERVER_URL,
     DOCLING_SERVER_URL,
-    DOCLING_DO_OCR,
-    DOCLING_FORCE_OCR,
-    DOCLING_OCR_ENGINE,
-    DOCLING_OCR_LANG,
-    DOCLING_PDF_BACKEND,
-    DOCLING_TABLE_MODE,
-    DOCLING_PIPELINE,
-    DOCLING_DO_PICTURE_DESCRIPTION,
-    DOCLING_PICTURE_DESCRIPTION_MODE,
-    DOCLING_PICTURE_DESCRIPTION_LOCAL,
-    DOCLING_PICTURE_DESCRIPTION_API,
+    DOCLING_API_KEY,
+    DOCLING_PARAMS,
     DOCUMENT_INTELLIGENCE_ENDPOINT,
     DOCUMENT_INTELLIGENCE_KEY,
+    DOCUMENT_INTELLIGENCE_MODEL,
+    MISTRAL_OCR_API_BASE_URL,
     MISTRAL_OCR_API_KEY,
     RAG_TEXT_SPLITTER,
     TIKTOKEN_ENCODING_NAME,

@@ -297,6 +313,7 @@ from open_webui.config import (
     PERPLEXITY_API_KEY,
     PERPLEXITY_MODEL,
     PERPLEXITY_SEARCH_CONTEXT_USAGE,
+    PERPLEXITY_SEARCH_API_URL,
     SOUGOU_API_SID,
     SOUGOU_API_SK,
     KAGI_SEARCH_API_KEY,

@@ -314,6 +331,7 @@ from open_webui.config import (
     ENABLE_ONEDRIVE_PERSONAL,
     ENABLE_ONEDRIVE_BUSINESS,
     ENABLE_RAG_HYBRID_SEARCH,
+    ENABLE_RAG_HYBRID_SEARCH_ENRICHED_TEXTS,
     ENABLE_RAG_LOCAL_WEB_FETCH,
     ENABLE_WEB_LOADER_SSL_VERIFICATION,
     ENABLE_GOOGLE_DRIVE_INTEGRATION,

@@ -332,9 +350,10 @@ from open_webui.config import (
     JWT_EXPIRES_IN,
     ENABLE_SIGNUP,
     ENABLE_LOGIN_FORM,
-    ENABLE_API_KEY,
-    ENABLE_API_KEY_ENDPOINT_RESTRICTIONS,
-    API_KEY_ALLOWED_ENDPOINTS,
+    ENABLE_API_KEYS,
+    ENABLE_API_KEYS_ENDPOINT_RESTRICTIONS,
+    API_KEYS_ALLOWED_ENDPOINTS,
+    ENABLE_FOLDERS,
     ENABLE_CHANNELS,
     ENABLE_NOTES,
     ENABLE_COMMUNITY_SHARING,

@@ -344,10 +363,12 @@ from open_webui.config import (
     BYPASS_ADMIN_ACCESS_CONTROL,
     USER_PERMISSIONS,
     DEFAULT_USER_ROLE,
+    DEFAULT_GROUP_ID,
     PENDING_USER_OVERLAY_CONTENT,
     PENDING_USER_OVERLAY_TITLE,
     DEFAULT_PROMPT_SUGGESTIONS,
     DEFAULT_MODELS,
+    DEFAULT_PINNED_MODELS,
     DEFAULT_ARENA_MODEL,
     MODEL_ORDER_LIST,
     EVALUATION_ARENA_MODELS,

@@ -406,6 +427,7 @@ from open_webui.config import (
     TAGS_GENERATION_PROMPT_TEMPLATE,
     IMAGE_PROMPT_GENERATION_PROMPT_TEMPLATE,
     TOOLS_FUNCTION_CALLING_PROMPT_TEMPLATE,
+    VOICE_MODE_PROMPT_TEMPLATE,
     QUERY_GENERATION_PROMPT_TEMPLATE,
     AUTOCOMPLETE_GENERATION_PROMPT_TEMPLATE,
     AUTOCOMPLETE_GENERATION_INPUT_MAX_LENGTH,

@@ -427,6 +449,7 @@ from open_webui.env import (
     SAFE_MODE,
     SRC_LOG_LEVELS,
     VERSION,
+    DEPLOYMENT_ID,
     INSTANCE_ID,
     WEBUI_BUILD_HASH,
     WEBUI_SECRET_KEY,

@@ -437,7 +460,7 @@ from open_webui.env import (
     WEBUI_AUTH_TRUSTED_NAME_HEADER,
     WEBUI_AUTH_SIGNOUT_REDIRECT_URL,
     # SCIM
-    SCIM_ENABLED,
+    ENABLE_SCIM,
     SCIM_TOKEN,
     ENABLE_COMPRESSION_MIDDLEWARE,
     ENABLE_WEBSOCKET_SUPPORT,

@@ -447,6 +470,7 @@ from open_webui.env import (
     ENABLE_OTEL,
     EXTERNAL_PWA_MANIFEST_URL,
     AIOHTTP_CLIENT_SESSION_SSL,
+    ENABLE_STAR_SESSIONS_MIDDLEWARE,
 )
 
 

@@ -474,9 +498,11 @@ from open_webui.utils.auth import (
 )
 from open_webui.utils.plugin import install_tool_and_function_dependencies
 from open_webui.utils.oauth import (
+    get_oauth_client_info_with_dynamic_client_registration,
+    encrypt_data,
+    decrypt_data,
     OAuthManager,
     OAuthClientManager,
-    decrypt_data,
     OAuthClientInformationFull,
 )
 from open_webui.utils.security_headers import SecurityHeadersMiddleware

@@ -690,7 +716,7 @@ app.state.config.ENABLE_DIRECT_CONNECTIONS = ENABLE_DIRECT_CONNECTIONS
 #
 ########################################
 
-app.state.SCIM_ENABLED = SCIM_ENABLED
+app.state.ENABLE_SCIM = ENABLE_SCIM
 app.state.SCIM_TOKEN = SCIM_TOKEN
 
 ########################################

@@ -712,11 +738,11 @@ app.state.config.WEBUI_URL = WEBUI_URL
 app.state.config.ENABLE_SIGNUP = ENABLE_SIGNUP
 app.state.config.ENABLE_LOGIN_FORM = ENABLE_LOGIN_FORM
 
-app.state.config.ENABLE_API_KEY = ENABLE_API_KEY
-app.state.config.ENABLE_API_KEY_ENDPOINT_RESTRICTIONS = (
-    ENABLE_API_KEY_ENDPOINT_RESTRICTIONS
+app.state.config.ENABLE_API_KEYS = ENABLE_API_KEYS
+app.state.config.ENABLE_API_KEYS_ENDPOINT_RESTRICTIONS = (
+    ENABLE_API_KEYS_ENDPOINT_RESTRICTIONS
 )
-app.state.config.API_KEY_ALLOWED_ENDPOINTS = API_KEY_ALLOWED_ENDPOINTS
+app.state.config.API_KEYS_ALLOWED_ENDPOINTS = API_KEYS_ALLOWED_ENDPOINTS
 
 app.state.config.JWT_EXPIRES_IN = JWT_EXPIRES_IN
 

@@ -725,8 +751,13 @@ app.state.config.ADMIN_EMAIL = ADMIN_EMAIL
 
 
 app.state.config.DEFAULT_MODELS = DEFAULT_MODELS
+app.state.config.DEFAULT_PINNED_MODELS = DEFAULT_PINNED_MODELS
+app.state.config.MODEL_ORDER_LIST = MODEL_ORDER_LIST
+
 
 app.state.config.DEFAULT_PROMPT_SUGGESTIONS = DEFAULT_PROMPT_SUGGESTIONS
 app.state.config.DEFAULT_USER_ROLE = DEFAULT_USER_ROLE
+app.state.config.DEFAULT_GROUP_ID = DEFAULT_GROUP_ID
+
 app.state.config.PENDING_USER_OVERLAY_CONTENT = PENDING_USER_OVERLAY_CONTENT
 app.state.config.PENDING_USER_OVERLAY_TITLE = PENDING_USER_OVERLAY_TITLE

@@ -736,9 +767,9 @@ app.state.config.RESPONSE_WATERMARK = RESPONSE_WATERMARK
 app.state.config.USER_PERMISSIONS = USER_PERMISSIONS
 app.state.config.WEBHOOK_URL = WEBHOOK_URL
 app.state.config.BANNERS = WEBUI_BANNERS
-app.state.config.MODEL_ORDER_LIST = MODEL_ORDER_LIST
 
 
+app.state.config.ENABLE_FOLDERS = ENABLE_FOLDERS
 app.state.config.ENABLE_CHANNELS = ENABLE_CHANNELS
 app.state.config.ENABLE_NOTES = ENABLE_NOTES
 app.state.config.ENABLE_COMMUNITY_SHARING = ENABLE_COMMUNITY_SHARING

@@ -814,6 +845,9 @@ app.state.config.FILE_IMAGE_COMPRESSION_HEIGHT = FILE_IMAGE_COMPRESSION_HEIGHT
 app.state.config.RAG_FULL_CONTEXT = RAG_FULL_CONTEXT
 app.state.config.BYPASS_EMBEDDING_AND_RETRIEVAL = BYPASS_EMBEDDING_AND_RETRIEVAL
 app.state.config.ENABLE_RAG_HYBRID_SEARCH = ENABLE_RAG_HYBRID_SEARCH
+app.state.config.ENABLE_RAG_HYBRID_SEARCH_ENRICHED_TEXTS = (
+    ENABLE_RAG_HYBRID_SEARCH_ENRICHED_TEXTS
+)
 app.state.config.ENABLE_WEB_LOADER_SSL_VERIFICATION = ENABLE_WEB_LOADER_SSL_VERIFICATION
 
 app.state.config.CONTENT_EXTRACTION_ENGINE = CONTENT_EXTRACTION_ENGINE

@@ -834,20 +868,17 @@ app.state.config.EXTERNAL_DOCUMENT_LOADER_URL = EXTERNAL_DOCUMENT_LOADER_URL
 app.state.config.EXTERNAL_DOCUMENT_LOADER_API_KEY = EXTERNAL_DOCUMENT_LOADER_API_KEY
 app.state.config.TIKA_SERVER_URL = TIKA_SERVER_URL
 app.state.config.DOCLING_SERVER_URL = DOCLING_SERVER_URL
-app.state.config.DOCLING_DO_OCR = DOCLING_DO_OCR
-app.state.config.DOCLING_FORCE_OCR = DOCLING_FORCE_OCR
-app.state.config.DOCLING_OCR_ENGINE = DOCLING_OCR_ENGINE
-app.state.config.DOCLING_OCR_LANG = DOCLING_OCR_LANG
-app.state.config.DOCLING_PDF_BACKEND = DOCLING_PDF_BACKEND
+app.state.config.DOCLING_API_KEY = DOCLING_API_KEY
+app.state.config.DOCLING_PARAMS = DOCLING_PARAMS
|
|
||||||
app.state.config.DOCLING_TABLE_MODE = DOCLING_TABLE_MODE
|
|
||||||
app.state.config.DOCLING_PIPELINE = DOCLING_PIPELINE
|
|
||||||
app.state.config.DOCLING_DO_PICTURE_DESCRIPTION = DOCLING_DO_PICTURE_DESCRIPTION
|
|
||||||
app.state.config.DOCLING_PICTURE_DESCRIPTION_MODE = DOCLING_PICTURE_DESCRIPTION_MODE
|
|
||||||
app.state.config.DOCLING_PICTURE_DESCRIPTION_LOCAL = DOCLING_PICTURE_DESCRIPTION_LOCAL
|
|
||||||
app.state.config.DOCLING_PICTURE_DESCRIPTION_API = DOCLING_PICTURE_DESCRIPTION_API
|
|
||||||
app.state.config.DOCUMENT_INTELLIGENCE_ENDPOINT = DOCUMENT_INTELLIGENCE_ENDPOINT
|
app.state.config.DOCUMENT_INTELLIGENCE_ENDPOINT = DOCUMENT_INTELLIGENCE_ENDPOINT
|
||||||
app.state.config.DOCUMENT_INTELLIGENCE_KEY = DOCUMENT_INTELLIGENCE_KEY
|
app.state.config.DOCUMENT_INTELLIGENCE_KEY = DOCUMENT_INTELLIGENCE_KEY
|
||||||
|
app.state.config.DOCUMENT_INTELLIGENCE_MODEL = DOCUMENT_INTELLIGENCE_MODEL
|
||||||
|
app.state.config.MISTRAL_OCR_API_BASE_URL = MISTRAL_OCR_API_BASE_URL
|
||||||
app.state.config.MISTRAL_OCR_API_KEY = MISTRAL_OCR_API_KEY
|
app.state.config.MISTRAL_OCR_API_KEY = MISTRAL_OCR_API_KEY
|
||||||
|
app.state.config.MINERU_API_MODE = MINERU_API_MODE
|
||||||
|
app.state.config.MINERU_API_URL = MINERU_API_URL
|
||||||
|
app.state.config.MINERU_API_KEY = MINERU_API_KEY
|
||||||
|
app.state.config.MINERU_PARAMS = MINERU_PARAMS
|
||||||
|
|
||||||
app.state.config.TEXT_SPLITTER = RAG_TEXT_SPLITTER
|
app.state.config.TEXT_SPLITTER = RAG_TEXT_SPLITTER
|
||||||
app.state.config.TIKTOKEN_ENCODING_NAME = TIKTOKEN_ENCODING_NAME
|
app.state.config.TIKTOKEN_ENCODING_NAME = TIKTOKEN_ENCODING_NAME
|
||||||
|
|
@ -858,6 +889,7 @@ app.state.config.CHUNK_OVERLAP = CHUNK_OVERLAP
|
||||||
app.state.config.RAG_EMBEDDING_ENGINE = RAG_EMBEDDING_ENGINE
|
app.state.config.RAG_EMBEDDING_ENGINE = RAG_EMBEDDING_ENGINE
|
||||||
app.state.config.RAG_EMBEDDING_MODEL = RAG_EMBEDDING_MODEL
|
app.state.config.RAG_EMBEDDING_MODEL = RAG_EMBEDDING_MODEL
|
||||||
app.state.config.RAG_EMBEDDING_BATCH_SIZE = RAG_EMBEDDING_BATCH_SIZE
|
app.state.config.RAG_EMBEDDING_BATCH_SIZE = RAG_EMBEDDING_BATCH_SIZE
|
||||||
|
app.state.config.ENABLE_ASYNC_EMBEDDING = ENABLE_ASYNC_EMBEDDING
|
||||||
|
|
||||||
app.state.config.RAG_RERANKING_ENGINE = RAG_RERANKING_ENGINE
|
app.state.config.RAG_RERANKING_ENGINE = RAG_RERANKING_ENGINE
|
||||||
app.state.config.RAG_RERANKING_MODEL = RAG_RERANKING_MODEL
|
app.state.config.RAG_RERANKING_MODEL = RAG_RERANKING_MODEL
|
||||||
|
|
@ -927,6 +959,7 @@ app.state.config.EXA_API_KEY = EXA_API_KEY
|
||||||
app.state.config.PERPLEXITY_API_KEY = PERPLEXITY_API_KEY
|
app.state.config.PERPLEXITY_API_KEY = PERPLEXITY_API_KEY
|
||||||
app.state.config.PERPLEXITY_MODEL = PERPLEXITY_MODEL
|
app.state.config.PERPLEXITY_MODEL = PERPLEXITY_MODEL
|
||||||
app.state.config.PERPLEXITY_SEARCH_CONTEXT_USAGE = PERPLEXITY_SEARCH_CONTEXT_USAGE
|
app.state.config.PERPLEXITY_SEARCH_CONTEXT_USAGE = PERPLEXITY_SEARCH_CONTEXT_USAGE
|
||||||
|
app.state.config.PERPLEXITY_SEARCH_API_URL = PERPLEXITY_SEARCH_API_URL
|
||||||
app.state.config.SOUGOU_API_SID = SOUGOU_API_SID
|
app.state.config.SOUGOU_API_SID = SOUGOU_API_SID
|
||||||
app.state.config.SOUGOU_API_SK = SOUGOU_API_SK
|
app.state.config.SOUGOU_API_SK = SOUGOU_API_SK
|
||||||
app.state.config.EXTERNAL_WEB_SEARCH_URL = EXTERNAL_WEB_SEARCH_URL
|
app.state.config.EXTERNAL_WEB_SEARCH_URL = EXTERNAL_WEB_SEARCH_URL
|
||||||
|
|
@ -951,9 +984,7 @@ app.state.YOUTUBE_LOADER_TRANSLATION = None
|
||||||
|
|
||||||
try:
|
try:
|
||||||
app.state.ef = get_ef(
|
app.state.ef = get_ef(
|
||||||
app.state.config.RAG_EMBEDDING_ENGINE,
|
app.state.config.RAG_EMBEDDING_ENGINE, app.state.config.RAG_EMBEDDING_MODEL
|
||||||
app.state.config.RAG_EMBEDDING_MODEL,
|
|
||||||
RAG_EMBEDDING_MODEL_AUTO_UPDATE,
|
|
||||||
)
|
)
|
||||||
if (
|
if (
|
||||||
app.state.config.ENABLE_RAG_HYBRID_SEARCH
|
app.state.config.ENABLE_RAG_HYBRID_SEARCH
|
||||||
|
|
@ -964,7 +995,6 @@ try:
|
||||||
app.state.config.RAG_RERANKING_MODEL,
|
app.state.config.RAG_RERANKING_MODEL,
|
||||||
app.state.config.RAG_EXTERNAL_RERANKER_URL,
|
app.state.config.RAG_EXTERNAL_RERANKER_URL,
|
||||||
app.state.config.RAG_EXTERNAL_RERANKER_API_KEY,
|
app.state.config.RAG_EXTERNAL_RERANKER_API_KEY,
|
||||||
RAG_RERANKING_MODEL_AUTO_UPDATE,
|
|
||||||
)
|
)
|
||||||
else:
|
else:
|
||||||
app.state.rf = None
|
app.state.rf = None
|
||||||
|
|
@ -1049,27 +1079,42 @@ app.state.config.IMAGE_GENERATION_ENGINE = IMAGE_GENERATION_ENGINE
|
||||||
app.state.config.ENABLE_IMAGE_GENERATION = ENABLE_IMAGE_GENERATION
|
app.state.config.ENABLE_IMAGE_GENERATION = ENABLE_IMAGE_GENERATION
|
||||||
app.state.config.ENABLE_IMAGE_PROMPT_GENERATION = ENABLE_IMAGE_PROMPT_GENERATION
|
app.state.config.ENABLE_IMAGE_PROMPT_GENERATION = ENABLE_IMAGE_PROMPT_GENERATION
|
||||||
|
|
||||||
|
app.state.config.IMAGE_GENERATION_MODEL = IMAGE_GENERATION_MODEL
|
||||||
|
app.state.config.IMAGE_SIZE = IMAGE_SIZE
|
||||||
|
app.state.config.IMAGE_STEPS = IMAGE_STEPS
|
||||||
|
|
||||||
app.state.config.IMAGES_OPENAI_API_BASE_URL = IMAGES_OPENAI_API_BASE_URL
|
app.state.config.IMAGES_OPENAI_API_BASE_URL = IMAGES_OPENAI_API_BASE_URL
|
||||||
app.state.config.IMAGES_OPENAI_API_VERSION = IMAGES_OPENAI_API_VERSION
|
app.state.config.IMAGES_OPENAI_API_VERSION = IMAGES_OPENAI_API_VERSION
|
||||||
app.state.config.IMAGES_OPENAI_API_KEY = IMAGES_OPENAI_API_KEY
|
app.state.config.IMAGES_OPENAI_API_KEY = IMAGES_OPENAI_API_KEY
|
||||||
|
app.state.config.IMAGES_OPENAI_API_PARAMS = IMAGES_OPENAI_API_PARAMS
|
||||||
|
|
||||||
app.state.config.IMAGES_GEMINI_API_BASE_URL = IMAGES_GEMINI_API_BASE_URL
|
app.state.config.IMAGES_GEMINI_API_BASE_URL = IMAGES_GEMINI_API_BASE_URL
|
||||||
app.state.config.IMAGES_GEMINI_API_KEY = IMAGES_GEMINI_API_KEY
|
app.state.config.IMAGES_GEMINI_API_KEY = IMAGES_GEMINI_API_KEY
|
||||||
|
app.state.config.IMAGES_GEMINI_ENDPOINT_METHOD = IMAGES_GEMINI_ENDPOINT_METHOD
|
||||||
app.state.config.IMAGE_GENERATION_MODEL = IMAGE_GENERATION_MODEL
|
|
||||||
|
|
||||||
app.state.config.AUTOMATIC1111_BASE_URL = AUTOMATIC1111_BASE_URL
|
app.state.config.AUTOMATIC1111_BASE_URL = AUTOMATIC1111_BASE_URL
|
||||||
app.state.config.AUTOMATIC1111_API_AUTH = AUTOMATIC1111_API_AUTH
|
app.state.config.AUTOMATIC1111_API_AUTH = AUTOMATIC1111_API_AUTH
|
||||||
app.state.config.AUTOMATIC1111_CFG_SCALE = AUTOMATIC1111_CFG_SCALE
|
app.state.config.AUTOMATIC1111_PARAMS = AUTOMATIC1111_PARAMS
|
||||||
app.state.config.AUTOMATIC1111_SAMPLER = AUTOMATIC1111_SAMPLER
|
|
||||||
app.state.config.AUTOMATIC1111_SCHEDULER = AUTOMATIC1111_SCHEDULER
|
|
||||||
app.state.config.COMFYUI_BASE_URL = COMFYUI_BASE_URL
|
app.state.config.COMFYUI_BASE_URL = COMFYUI_BASE_URL
|
||||||
app.state.config.COMFYUI_API_KEY = COMFYUI_API_KEY
|
app.state.config.COMFYUI_API_KEY = COMFYUI_API_KEY
|
||||||
app.state.config.COMFYUI_WORKFLOW = COMFYUI_WORKFLOW
|
app.state.config.COMFYUI_WORKFLOW = COMFYUI_WORKFLOW
|
||||||
app.state.config.COMFYUI_WORKFLOW_NODES = COMFYUI_WORKFLOW_NODES
|
app.state.config.COMFYUI_WORKFLOW_NODES = COMFYUI_WORKFLOW_NODES
|
||||||
|
|
||||||
app.state.config.IMAGE_SIZE = IMAGE_SIZE
|
|
||||||
app.state.config.IMAGE_STEPS = IMAGE_STEPS
|
app.state.config.ENABLE_IMAGE_EDIT = ENABLE_IMAGE_EDIT
|
||||||
|
app.state.config.IMAGE_EDIT_ENGINE = IMAGE_EDIT_ENGINE
|
||||||
|
app.state.config.IMAGE_EDIT_MODEL = IMAGE_EDIT_MODEL
|
||||||
|
app.state.config.IMAGE_EDIT_SIZE = IMAGE_EDIT_SIZE
|
||||||
|
app.state.config.IMAGES_EDIT_OPENAI_API_BASE_URL = IMAGES_EDIT_OPENAI_API_BASE_URL
|
||||||
|
app.state.config.IMAGES_EDIT_OPENAI_API_KEY = IMAGES_EDIT_OPENAI_API_KEY
|
||||||
|
app.state.config.IMAGES_EDIT_OPENAI_API_VERSION = IMAGES_EDIT_OPENAI_API_VERSION
|
||||||
|
app.state.config.IMAGES_EDIT_GEMINI_API_BASE_URL = IMAGES_EDIT_GEMINI_API_BASE_URL
|
||||||
|
app.state.config.IMAGES_EDIT_GEMINI_API_KEY = IMAGES_EDIT_GEMINI_API_KEY
|
||||||
|
app.state.config.IMAGES_EDIT_COMFYUI_BASE_URL = IMAGES_EDIT_COMFYUI_BASE_URL
|
||||||
|
app.state.config.IMAGES_EDIT_COMFYUI_API_KEY = IMAGES_EDIT_COMFYUI_API_KEY
|
||||||
|
app.state.config.IMAGES_EDIT_COMFYUI_WORKFLOW = IMAGES_EDIT_COMFYUI_WORKFLOW
|
||||||
|
app.state.config.IMAGES_EDIT_COMFYUI_WORKFLOW_NODES = IMAGES_EDIT_COMFYUI_WORKFLOW_NODES
|
||||||
|
|
||||||
|
|
||||||
########################################
|
########################################
|
||||||
|
|
@ -1095,11 +1140,21 @@ app.state.config.AUDIO_STT_AZURE_LOCALES = AUDIO_STT_AZURE_LOCALES
|
||||||
app.state.config.AUDIO_STT_AZURE_BASE_URL = AUDIO_STT_AZURE_BASE_URL
|
app.state.config.AUDIO_STT_AZURE_BASE_URL = AUDIO_STT_AZURE_BASE_URL
|
||||||
app.state.config.AUDIO_STT_AZURE_MAX_SPEAKERS = AUDIO_STT_AZURE_MAX_SPEAKERS
|
app.state.config.AUDIO_STT_AZURE_MAX_SPEAKERS = AUDIO_STT_AZURE_MAX_SPEAKERS
|
||||||
|
|
||||||
app.state.config.TTS_OPENAI_API_BASE_URL = AUDIO_TTS_OPENAI_API_BASE_URL
|
app.state.config.AUDIO_STT_MISTRAL_API_KEY = AUDIO_STT_MISTRAL_API_KEY
|
||||||
app.state.config.TTS_OPENAI_API_KEY = AUDIO_TTS_OPENAI_API_KEY
|
app.state.config.AUDIO_STT_MISTRAL_API_BASE_URL = AUDIO_STT_MISTRAL_API_BASE_URL
|
||||||
|
app.state.config.AUDIO_STT_MISTRAL_USE_CHAT_COMPLETIONS = (
|
||||||
|
AUDIO_STT_MISTRAL_USE_CHAT_COMPLETIONS
|
||||||
|
)
|
||||||
|
|
||||||
app.state.config.TTS_ENGINE = AUDIO_TTS_ENGINE
|
app.state.config.TTS_ENGINE = AUDIO_TTS_ENGINE
|
||||||
|
|
||||||
app.state.config.TTS_MODEL = AUDIO_TTS_MODEL
|
app.state.config.TTS_MODEL = AUDIO_TTS_MODEL
|
||||||
app.state.config.TTS_VOICE = AUDIO_TTS_VOICE
|
app.state.config.TTS_VOICE = AUDIO_TTS_VOICE
|
||||||
|
|
||||||
|
app.state.config.TTS_OPENAI_API_BASE_URL = AUDIO_TTS_OPENAI_API_BASE_URL
|
||||||
|
app.state.config.TTS_OPENAI_API_KEY = AUDIO_TTS_OPENAI_API_KEY
|
||||||
|
app.state.config.TTS_OPENAI_PARAMS = AUDIO_TTS_OPENAI_PARAMS
|
||||||
|
|
||||||
app.state.config.TTS_API_KEY = AUDIO_TTS_API_KEY
|
app.state.config.TTS_API_KEY = AUDIO_TTS_API_KEY
|
||||||
app.state.config.TTS_SPLIT_ON = AUDIO_TTS_SPLIT_ON
|
app.state.config.TTS_SPLIT_ON = AUDIO_TTS_SPLIT_ON
|
||||||
|
|
||||||
|
|
@ -1152,6 +1207,7 @@ app.state.config.AUTOCOMPLETE_GENERATION_PROMPT_TEMPLATE = (
|
||||||
app.state.config.AUTOCOMPLETE_GENERATION_INPUT_MAX_LENGTH = (
|
app.state.config.AUTOCOMPLETE_GENERATION_INPUT_MAX_LENGTH = (
|
||||||
AUTOCOMPLETE_GENERATION_INPUT_MAX_LENGTH
|
AUTOCOMPLETE_GENERATION_INPUT_MAX_LENGTH
|
||||||
)
|
)
|
||||||
|
app.state.config.VOICE_MODE_PROMPT_TEMPLATE = VOICE_MODE_PROMPT_TEMPLATE
|
||||||
|
|
||||||
|
|
||||||
########################################
|
########################################
|
||||||
|
|
@ -1160,7 +1216,11 @@ app.state.config.AUTOCOMPLETE_GENERATION_INPUT_MAX_LENGTH = (
|
||||||
#
|
#
|
||||||
########################################
|
########################################
|
||||||
|
|
||||||
app.state.MODELS = {}
|
app.state.MODELS = MODELS
|
||||||
|
|
||||||
|
# Add the middleware to the app
|
||||||
|
if ENABLE_COMPRESSION_MIDDLEWARE:
|
||||||
|
app.add_middleware(CompressMiddleware)
|
||||||
|
|
||||||
|
|
||||||
class RedirectMiddleware(BaseHTTPMiddleware):
|
class RedirectMiddleware(BaseHTTPMiddleware):
|
||||||
|
|
@ -1170,12 +1230,32 @@ class RedirectMiddleware(BaseHTTPMiddleware):
|
||||||
path = request.url.path
|
path = request.url.path
|
||||||
query_params = dict(parse_qs(urlparse(str(request.url)).query))
|
query_params = dict(parse_qs(urlparse(str(request.url)).query))
|
||||||
|
|
||||||
|
redirect_params = {}
|
||||||
|
|
||||||
# Check for the specific watch path and the presence of 'v' parameter
|
# Check for the specific watch path and the presence of 'v' parameter
|
||||||
if path.endswith("/watch") and "v" in query_params:
|
if path.endswith("/watch") and "v" in query_params:
|
||||||
# Extract the first 'v' parameter
|
# Extract the first 'v' parameter
|
||||||
video_id = query_params["v"][0]
|
youtube_video_id = query_params["v"][0]
|
||||||
encoded_video_id = urlencode({"youtube": video_id})
|
redirect_params["youtube"] = youtube_video_id
|
||||||
redirect_url = f"/?{encoded_video_id}"
|
|
||||||
|
if "shared" in query_params and len(query_params["shared"]) > 0:
|
||||||
|
# PWA share_target support
|
||||||
|
|
||||||
|
text = query_params["shared"][0]
|
||||||
|
if text:
|
||||||
|
urls = re.match(r"https://\S+", text)
|
||||||
|
if urls:
|
||||||
|
from open_webui.retrieval.loaders.youtube import _parse_video_id
|
||||||
|
|
||||||
|
if youtube_video_id := _parse_video_id(urls[0]):
|
||||||
|
redirect_params["youtube"] = youtube_video_id
|
||||||
|
else:
|
||||||
|
redirect_params["load-url"] = urls[0]
|
||||||
|
else:
|
||||||
|
redirect_params["q"] = text
|
||||||
|
|
||||||
|
if redirect_params:
|
||||||
|
redirect_url = f"/?{urlencode(redirect_params)}"
|
||||||
return RedirectResponse(url=redirect_url)
|
return RedirectResponse(url=redirect_url)
|
||||||
|
|
||||||
# Proceed with the normal flow of other requests
|
# Proceed with the normal flow of other requests
|
||||||
|
|
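For orientation, a minimal, self-contained sketch (not part of the diff) of the decision tree the reworked middleware applies to a shared payload. `_parse_video_id` is stubbed with an assumed pattern here; the real helper lives in `open_webui.retrieval.loaders.youtube`.

```python
import re
from urllib.parse import urlencode


def _parse_video_id(url):
    # Hypothetical stand-in for the real _parse_video_id helper.
    m = re.search(r"(?:youtube\.com/watch\?v=|youtu\.be/)([\w-]{11})", url)
    return m.group(1) if m else None


def share_redirect(text: str) -> str:
    """Mirror of the middleware's decision tree for a PWA share payload."""
    params = {}
    urls = re.match(r"https://\S+", text)
    if urls:
        if video_id := _parse_video_id(urls[0]):
            params["youtube"] = video_id  # shared link is a YouTube video
        else:
            params["load-url"] = urls[0]  # any other https URL
    else:
        params["q"] = text  # plain text becomes a search query
    return f"/?{urlencode(params)}"


print(share_redirect("https://youtu.be/dQw4w9WgXcQ"))  # /?youtube=dQw4w9WgXcQ
print(share_redirect("hello world"))                   # /?q=hello+world
```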
@@ -1183,14 +1263,53 @@ class RedirectMiddleware(BaseHTTPMiddleware):
         return response


-# Add the middleware to the app
-if ENABLE_COMPRESSION_MIDDLEWARE:
-    app.add_middleware(CompressMiddleware)
-
 app.add_middleware(RedirectMiddleware)
 app.add_middleware(SecurityHeadersMiddleware)


+class APIKeyRestrictionMiddleware(BaseHTTPMiddleware):
+    async def dispatch(self, request: Request, call_next):
+        auth_header = request.headers.get("Authorization")
+        token = None
+
+        if auth_header:
+            scheme, token = auth_header.split(" ")
+
+        # Only apply restrictions if an sk- API key is used
+        if token and token.startswith("sk-"):
+            # Check if restrictions are enabled
+            if request.app.state.config.ENABLE_API_KEYS_ENDPOINT_RESTRICTIONS:
+                allowed_paths = [
+                    path.strip()
+                    for path in str(
+                        request.app.state.config.API_KEYS_ALLOWED_ENDPOINTS
+                    ).split(",")
+                    if path.strip()
+                ]
+
+                request_path = request.url.path
+
+                # Match exact path or prefix path
+                is_allowed = any(
+                    request_path == allowed or request_path.startswith(allowed + "/")
+                    for allowed in allowed_paths
+                )
+
+                if not is_allowed:
+                    return JSONResponse(
+                        status_code=status.HTTP_403_FORBIDDEN,
+                        content={
+                            "detail": "API key not allowed to access this endpoint."
+                        },
+                    )
+
+        response = await call_next(request)
+        return response
+
+
+app.add_middleware(APIKeyRestrictionMiddleware)
+
+
 @app.middleware("http")
 async def commit_session_after_request(request: Request, call_next):
     response = await call_next(request)
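The allow-list check above reduces to exact-or-prefix path matching on the comma-separated endpoint list. A standalone sketch with illustrative values:

```python
def is_endpoint_allowed(request_path: str, allowed_endpoints: str) -> bool:
    """Exact match or path-segment prefix match, as in the middleware above."""
    allowed_paths = [p.strip() for p in allowed_endpoints.split(",") if p.strip()]
    return any(
        request_path == allowed or request_path.startswith(allowed + "/")
        for allowed in allowed_paths
    )


allow = "/api/v1/chats, /api/models"
assert is_endpoint_allowed("/api/models", allow)           # exact match
assert is_endpoint_allowed("/api/v1/chats/abc", allow)     # prefix match
assert not is_endpoint_allowed("/api/modelsx", allow)      # no partial-segment match
```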
@@ -1206,7 +1325,7 @@ async def check_url(request: Request, call_next):
         request.headers.get("Authorization")
     )

-    request.state.enable_api_key = app.state.config.ENABLE_API_KEY
+    request.state.enable_api_keys = app.state.config.ENABLE_API_KEYS
     response = await call_next(request)
     process_time = int(time.time()) - start_time
     response.headers["X-Process-Time"] = str(process_time)

@@ -1281,7 +1400,7 @@ app.include_router(
 app.include_router(utils.router, prefix="/api/v1/utils", tags=["utils"])

 # SCIM 2.0 API for identity management
-if SCIM_ENABLED:
+if ENABLE_SCIM:
     app.include_router(scim.router, prefix="/api/v1/scim/v2", tags=["scim"])


@@ -1318,6 +1437,10 @@ async def get_models(
         if "pipeline" in model and model["pipeline"].get("type", None) == "filter":
             continue

+        # Remove profile image URL to reduce payload size
+        if model.get("info", {}).get("meta", {}).get("profile_image_url"):
+            model["info"]["meta"].pop("profile_image_url", None)
+
         try:
             model_tags = [
                 tag.get("name")

@@ -1440,6 +1563,9 @@ async def chat_completion(
     reasoning_tags = form_data.get("params", {}).get("reasoning_tags")

     # Model Params
+    if model_info_params.get("stream_response") is not None:
+        form_data["stream"] = model_info_params.get("stream_response")
+
     if model_info_params.get("stream_delta_chunk_size"):
         stream_delta_chunk_size = model_info_params.get("stream_delta_chunk_size")

@@ -1450,6 +1576,7 @@ async def chat_completion(
         "user_id": user.id,
         "chat_id": form_data.pop("chat_id", None),
         "message_id": form_data.pop("id", None),
+        "parent_message_id": form_data.pop("parent_id", None),
         "session_id": form_data.pop("session_id", None),
         "filter_ids": form_data.pop("filter_ids", []),
         "tool_ids": form_data.get("tool_ids", None),

@@ -1474,7 +1601,7 @@ async def chat_completion(
     }

     if metadata.get("chat_id") and (user and user.role != "admin"):
-        if metadata["chat_id"] != "local":
+        if not metadata["chat_id"].startswith("local:"):
             chat = Chats.get_chat_by_id_and_user_id(metadata["chat_id"], user.id)
             if chat is None:
                 raise HTTPException(

@@ -1501,10 +1628,12 @@ async def chat_completion(
         response = await chat_completion_handler(request, form_data, user)
         if metadata.get("chat_id") and metadata.get("message_id"):
             try:
+                if not metadata["chat_id"].startswith("local:"):
                     Chats.upsert_message_to_chat_by_id_and_message_id(
                         metadata["chat_id"],
                         metadata["message_id"],
                         {
+                            "parentId": metadata.get("parent_message_id", None),
                             "model": model_id,
                         },
                     )

@@ -1518,20 +1647,26 @@ async def chat_completion(
         log.info("Chat processing was cancelled")
         try:
             event_emitter = get_event_emitter(metadata)
-            await event_emitter(
-                {"type": "chat:tasks:cancel"},
-            )
+            await asyncio.shield(
+                event_emitter(
+                    {"type": "chat:tasks:cancel"},
+                )
+            )
         except Exception as e:
             pass
+        finally:
+            raise  # re-raise to ensure proper task cancellation handling
     except Exception as e:
         log.debug(f"Error processing chat payload: {e}")
         if metadata.get("chat_id") and metadata.get("message_id"):
             # Update the chat message with the error
             try:
+                if not metadata["chat_id"].startswith("local:"):
                     Chats.upsert_message_to_chat_by_id_and_message_id(
                         metadata["chat_id"],
                         metadata["message_id"],
                         {
+                            "parentId": metadata.get("parent_message_id", None),
                             "error": {"content": str(e)},
                         },
                     )
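Why wrap the cancel notification in `asyncio.shield`? While a task's `CancelledError` handler is awaiting, a further `cancel()` would abort an unshielded await before the event is delivered. A toy demonstration of the pattern; the queue and names are illustrative, not the project's API:

```python
import asyncio


async def notify(queue: asyncio.Queue):
    # Stands in for event_emitter({"type": "chat:tasks:cancel"})
    await asyncio.sleep(0)
    await queue.put("chat:tasks:cancel")


async def worker(queue: asyncio.Queue):
    try:
        await asyncio.sleep(3600)  # the long-running chat task
    except asyncio.CancelledError:
        # Shield the notification so another cancel() cannot abort it,
        # then re-raise so the task still finishes as cancelled.
        await asyncio.shield(notify(queue))
        raise


async def main():
    queue: asyncio.Queue = asyncio.Queue()
    task = asyncio.create_task(worker(queue))
    await asyncio.sleep(0.01)
    task.cancel()
    try:
        await task
    except asyncio.CancelledError:
        pass
    print(await queue.get())  # the cancel event still arrived


asyncio.run(main())
```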
@@ -1552,7 +1687,7 @@ async def chat_completion(
     finally:
         try:
             if mcp_clients := metadata.get("mcp_clients"):
-                for client in mcp_clients.values():
+                for client in reversed(mcp_clients.values()):
                     await client.disconnect()
         except Exception as e:
             log.debug(f"Error cleaning up: {e}")

@@ -1703,7 +1838,7 @@ async def get_app_config(request: Request):
         "auth_trusted_header": bool(app.state.AUTH_TRUSTED_EMAIL_HEADER),
         "enable_signup_password_confirmation": ENABLE_SIGNUP_PASSWORD_CONFIRMATION,
         "enable_ldap": app.state.config.ENABLE_LDAP,
-        "enable_api_key": app.state.config.ENABLE_API_KEY,
+        "enable_api_keys": app.state.config.ENABLE_API_KEYS,
         "enable_signup": app.state.config.ENABLE_SIGNUP,
         "enable_login_form": app.state.config.ENABLE_LOGIN_FORM,
         "enable_websocket": ENABLE_WEBSOCKET_SUPPORT,

@@ -1711,6 +1846,7 @@ async def get_app_config(request: Request):
         **(
             {
                 "enable_direct_connections": app.state.config.ENABLE_DIRECT_CONNECTIONS,
+                "enable_folders": app.state.config.ENABLE_FOLDERS,
                 "enable_channels": app.state.config.ENABLE_CHANNELS,
                 "enable_notes": app.state.config.ENABLE_NOTES,
                 "enable_web_search": app.state.config.ENABLE_WEB_SEARCH,

@@ -1741,6 +1877,7 @@ async def get_app_config(request: Request):
         **(
             {
                 "default_models": app.state.config.DEFAULT_MODELS,
+                "default_pinned_models": app.state.config.DEFAULT_PINNED_MODELS,
                 "default_prompt_suggestions": app.state.config.DEFAULT_PROMPT_SUGGESTIONS,
                 "user_count": user_count,
                 "code": {

@@ -1842,6 +1979,7 @@ async def update_webhook_url(form_data: UrlForm, user=Depends(get_admin_user)):
 async def get_app_version():
     return {
         "version": VERSION,
+        "deployment_id": DEPLOYMENT_ID,
     }


@@ -1881,7 +2019,10 @@ async def get_current_usage(user=Depends(get_verified_user)):
     This is an experimental endpoint and subject to change.
     """
     try:
-        return {"model_ids": get_models_in_use(), "user_ids": get_active_user_ids()}
+        return {
+            "model_ids": get_models_in_use(),
+            "user_count": Users.get_active_user_count(),
+        }
     except Exception as e:
         log.error(f"Error getting usage statistics: {e}")
         raise HTTPException(status_code=500, detail="Internal Server Error")

@@ -1898,18 +2039,26 @@ if len(app.state.config.TOOL_SERVER_CONNECTIONS) > 0:
         if tool_server_connection.get("type", "openapi") == "mcp":
             server_id = tool_server_connection.get("info", {}).get("id")
             auth_type = tool_server_connection.get("auth_type", "none")

             if server_id and auth_type == "oauth_2.1":
                 oauth_client_info = tool_server_connection.get("info", {}).get(
                     "oauth_client_info", ""
                 )

-                oauth_client_info = decrypt_data(oauth_client_info)
-                app.state.oauth_client_manager.add_client(
-                    f"mcp:{server_id}", OAuthClientInformationFull(**oauth_client_info)
-                )
+                try:
+                    oauth_client_info = decrypt_data(oauth_client_info)
+                    app.state.oauth_client_manager.add_client(
+                        f"mcp:{server_id}",
+                        OAuthClientInformationFull(**oauth_client_info),
+                    )
+                except Exception as e:
+                    log.error(
+                        f"Error adding OAuth client for MCP tool server {server_id}: {e}"
+                    )
+                    pass

 try:
-    if REDIS_URL:
+    if ENABLE_STAR_SESSIONS_MIDDLEWARE:
         redis_session_store = RedisStore(
             url=REDIS_URL,
             prefix=(f"{REDIS_KEY_PREFIX}:session:" if REDIS_KEY_PREFIX else "session:"),

@@ -1936,6 +2085,64 @@ except Exception as e:
     )


+async def register_client(request, client_id: str) -> bool:
+    server_type, server_id = client_id.split(":", 1)
+
+    connection = None
+    connection_idx = None
+
+    for idx, conn in enumerate(request.app.state.config.TOOL_SERVER_CONNECTIONS or []):
+        if conn.get("type", "openapi") == server_type:
+            info = conn.get("info", {})
+            if info.get("id") == server_id:
+                connection = conn
+                connection_idx = idx
+                break
+
+    if connection is None or connection_idx is None:
+        log.warning(
+            f"Unable to locate MCP tool server configuration for client {client_id} during re-registration"
+        )
+        return False
+
+    server_url = connection.get("url")
+    oauth_server_key = (connection.get("config") or {}).get("oauth_server_key")
+
+    try:
+        oauth_client_info = (
+            await get_oauth_client_info_with_dynamic_client_registration(
+                request,
+                client_id,
+                server_url,
+                oauth_server_key,
+            )
+        )
+    except Exception as e:
+        log.error(f"Dynamic client re-registration failed for {client_id}: {e}")
+        return False
+
+    try:
+        request.app.state.config.TOOL_SERVER_CONNECTIONS[connection_idx] = {
+            **connection,
+            "info": {
+                **connection.get("info", {}),
+                "oauth_client_info": encrypt_data(
+                    oauth_client_info.model_dump(mode="json")
+                ),
+            },
+        }
+    except Exception as e:
+        log.error(
+            f"Failed to persist updated OAuth client info for tool server {client_id}: {e}"
+        )
+        return False
+
+    oauth_client_manager.remove_client(client_id)
+    oauth_client_manager.add_client(client_id, oauth_client_info)
+    log.info(f"Re-registered OAuth client {client_id} for tool server")
+    return True
+
+
 @app.get("/oauth/clients/{client_id}/authorize")
 async def oauth_client_authorize(
     client_id: str,

@@ -1943,6 +2150,41 @@ async def oauth_client_authorize(
     response: Response,
     user=Depends(get_verified_user),
 ):
+    # ensure_valid_client_registration
+    client = oauth_client_manager.get_client(client_id)
+    client_info = oauth_client_manager.get_client_info(client_id)
+    if client is None or client_info is None:
+        raise HTTPException(status.HTTP_404_NOT_FOUND)
+
+    if not await oauth_client_manager._preflight_authorization_url(client, client_info):
+        log.info(
+            "Detected invalid OAuth client %s; attempting re-registration",
+            client_id,
+        )
+
+        registered = await register_client(request, client_id)
+        if not registered:
+            raise HTTPException(
+                status_code=status.HTTP_500_INTERNAL_SERVER_ERROR,
+                detail="Failed to re-register OAuth client",
+            )
+
+        client = oauth_client_manager.get_client(client_id)
+        client_info = oauth_client_manager.get_client_info(client_id)
+        if client is None or client_info is None:
+            raise HTTPException(
+                status_code=status.HTTP_500_INTERNAL_SERVER_ERROR,
+                detail="OAuth client unavailable after re-registration",
+            )
+
+        if not await oauth_client_manager._preflight_authorization_url(
+            client, client_info
+        ):
+            raise HTTPException(
+                status_code=status.HTTP_500_INTERNAL_SERVER_ERROR,
+                detail="OAuth client registration is still invalid after re-registration",
+            )
+
     return await oauth_client_manager.handle_authorize(request, client_id=client_id)
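Condensed, the authorize-time recovery added above follows a preflight / re-register / retry-once shape. A hedged sketch of that control flow; `manager` and `register` stand in for `oauth_client_manager` and `register_client`, and the `request` argument is dropped for brevity:

```python
async def authorize_with_recovery(client_id: str, manager, register) -> str:
    """Preflight the stored OAuth client; re-register once if it is stale."""
    client = manager.get_client(client_id)
    info = manager.get_client_info(client_id)
    if client is None or info is None:
        raise LookupError(client_id)  # unknown client -> 404 in the real endpoint

    if not await manager._preflight_authorization_url(client, info):
        # The authorization server rejected the stored registration:
        # run dynamic client registration again, then re-check once.
        if not await register(client_id):
            raise RuntimeError("re-registration failed")
        client = manager.get_client(client_id)
        info = manager.get_client_info(client_id)
        if not (client and info) or not await manager._preflight_authorization_url(
            client, info
        ):
            raise RuntimeError("client still invalid after re-registration")

    return await manager.handle_authorize(client_id=client_id)
```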
@@ -2004,6 +2246,11 @@ async def get_manifest_json():
                 "purpose": "maskable",
             },
         ],
+        "share_target": {
+            "action": "/",
+            "method": "GET",
+            "params": {"text": "shared"},
+        },
     }

@@ -0,0 +1,103 @@
+"""Update messages and channel member table
+
+Revision ID: 2f1211949ecc
+Revises: 37f288994c47
+Create Date: 2025-11-27 03:07:56.200231
+
+"""
+
+from typing import Sequence, Union
+
+from alembic import op
+import sqlalchemy as sa
+import open_webui.internal.db
+
+
+# revision identifiers, used by Alembic.
+revision: str = "2f1211949ecc"
+down_revision: Union[str, None] = "37f288994c47"
+branch_labels: Union[str, Sequence[str], None] = None
+depends_on: Union[str, Sequence[str], None] = None
+
+
+def upgrade() -> None:
+    # New columns to be added to channel_member table
+    op.add_column("channel_member", sa.Column("status", sa.Text(), nullable=True))
+    op.add_column(
+        "channel_member",
+        sa.Column(
+            "is_active",
+            sa.Boolean(),
+            nullable=False,
+            default=True,
+            server_default=sa.sql.expression.true(),
+        ),
+    )
+
+    op.add_column(
+        "channel_member",
+        sa.Column(
+            "is_channel_muted",
+            sa.Boolean(),
+            nullable=False,
+            default=False,
+            server_default=sa.sql.expression.false(),
+        ),
+    )
+    op.add_column(
+        "channel_member",
+        sa.Column(
+            "is_channel_pinned",
+            sa.Boolean(),
+            nullable=False,
+            default=False,
+            server_default=sa.sql.expression.false(),
+        ),
+    )
+
+    op.add_column("channel_member", sa.Column("data", sa.JSON(), nullable=True))
+    op.add_column("channel_member", sa.Column("meta", sa.JSON(), nullable=True))
+
+    op.add_column(
+        "channel_member", sa.Column("joined_at", sa.BigInteger(), nullable=False)
+    )
+    op.add_column(
+        "channel_member", sa.Column("left_at", sa.BigInteger(), nullable=True)
+    )
+
+    op.add_column(
+        "channel_member", sa.Column("last_read_at", sa.BigInteger(), nullable=True)
+    )
+
+    op.add_column(
+        "channel_member", sa.Column("updated_at", sa.BigInteger(), nullable=True)
+    )
+
+    # New columns to be added to message table
+    op.add_column(
+        "message",
+        sa.Column(
+            "is_pinned",
+            sa.Boolean(),
+            nullable=False,
+            default=False,
+            server_default=sa.sql.expression.false(),
+        ),
+    )
+    op.add_column("message", sa.Column("pinned_at", sa.BigInteger(), nullable=True))
+    op.add_column("message", sa.Column("pinned_by", sa.Text(), nullable=True))
+
+
+def downgrade() -> None:
+    op.drop_column("channel_member", "updated_at")
+    op.drop_column("channel_member", "last_read_at")
+
+    op.drop_column("channel_member", "left_at")
+    op.drop_column("channel_member", "joined_at")
+
+    op.drop_column("channel_member", "meta")
+    op.drop_column("channel_member", "data")
+
+    op.drop_column("channel_member", "is_channel_pinned")
+    op.drop_column("channel_member", "is_channel_muted")
+
+    op.drop_column("channel_member", "is_active")
+    op.drop_column("channel_member", "status")
+
+    op.drop_column("message", "pinned_by")
+    op.drop_column("message", "pinned_at")
+    op.drop_column("message", "is_pinned")
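A side note on the boolean columns added above: in Alembic, `default=` is a client-side default applied only to future Python-level inserts, while `server_default=` is emitted into the DDL and is what lets a NOT NULL column be added to a table that already has rows. A minimal sketch, assuming a hypothetical `example` table; it must run inside an Alembic migration context:

```python
import sqlalchemy as sa
from alembic import op


def add_flag_column() -> None:
    # server_default backfills existing rows so the NOT NULL constraint
    # can be satisfied at ALTER TABLE time; default= only affects future
    # inserts issued through SQLAlchemy.
    op.add_column(
        "example",  # hypothetical table name
        sa.Column(
            "is_active",
            sa.Boolean(),
            nullable=False,
            default=True,
            server_default=sa.sql.expression.true(),
        ),
    )
```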
@@ -0,0 +1,146 @@
+"""add_group_member_table
+
+Revision ID: 37f288994c47
+Revises: a5c220713937
+Create Date: 2025-11-17 03:45:25.123939
+
+"""
+
+import uuid
+import time
+import json
+from typing import Sequence, Union
+
+from alembic import op
+import sqlalchemy as sa
+
+
+# revision identifiers, used by Alembic.
+revision: str = "37f288994c47"
+down_revision: Union[str, None] = "a5c220713937"
+branch_labels: Union[str, Sequence[str], None] = None
+depends_on: Union[str, Sequence[str], None] = None
+
+
+def upgrade() -> None:
+    # 1. Create new table
+    op.create_table(
+        "group_member",
+        sa.Column("id", sa.Text(), primary_key=True, unique=True, nullable=False),
+        sa.Column(
+            "group_id",
+            sa.Text(),
+            sa.ForeignKey("group.id", ondelete="CASCADE"),
+            nullable=False,
+        ),
+        sa.Column(
+            "user_id",
+            sa.Text(),
+            sa.ForeignKey("user.id", ondelete="CASCADE"),
+            nullable=False,
+        ),
+        sa.Column("created_at", sa.BigInteger(), nullable=True),
+        sa.Column("updated_at", sa.BigInteger(), nullable=True),
+        sa.UniqueConstraint("group_id", "user_id", name="uq_group_member_group_user"),
+    )
+
+    connection = op.get_bind()
+
+    # 2. Read existing group with user_ids JSON column
+    group_table = sa.Table(
+        "group",
+        sa.MetaData(),
+        sa.Column("id", sa.Text()),
+        sa.Column("user_ids", sa.JSON()),  # JSON stored as text in SQLite + PG
+    )
+
+    results = connection.execute(
+        sa.select(group_table.c.id, group_table.c.user_ids)
+    ).fetchall()
+
+    # 3. Insert members into group_member table
+    gm_table = sa.Table(
+        "group_member",
+        sa.MetaData(),
+        sa.Column("id", sa.Text()),
+        sa.Column("group_id", sa.Text()),
+        sa.Column("user_id", sa.Text()),
+        sa.Column("created_at", sa.BigInteger()),
+        sa.Column("updated_at", sa.BigInteger()),
+    )
+
+    now = int(time.time())
+    for group_id, user_ids in results:
+        if not user_ids:
+            continue
+
+        if isinstance(user_ids, str):
+            try:
+                user_ids = json.loads(user_ids)
+            except Exception:
+                continue  # skip invalid JSON
+
+        if not isinstance(user_ids, list):
+            continue
+
+        rows = [
+            {
+                "id": str(uuid.uuid4()),
+                "group_id": group_id,
+                "user_id": uid,
+                "created_at": now,
+                "updated_at": now,
+            }
+            for uid in user_ids
+        ]
+
+        if rows:
+            connection.execute(gm_table.insert(), rows)
+
+    # 4. Optionally drop the old column
+    with op.batch_alter_table("group") as batch:
+        batch.drop_column("user_ids")
+
+
+def downgrade():
+    # Reverse: restore user_ids column
+    with op.batch_alter_table("group") as batch:
+        batch.add_column(sa.Column("user_ids", sa.JSON()))
+
+    connection = op.get_bind()
+    gm_table = sa.Table(
+        "group_member",
+        sa.MetaData(),
+        sa.Column("group_id", sa.Text()),
+        sa.Column("user_id", sa.Text()),
+        sa.Column("created_at", sa.BigInteger()),
+        sa.Column("updated_at", sa.BigInteger()),
+    )
+
+    group_table = sa.Table(
+        "group",
+        sa.MetaData(),
+        sa.Column("id", sa.Text()),
+        sa.Column("user_ids", sa.JSON()),
+    )
+
+    # Build JSON arrays again
+    results = connection.execute(sa.select(group_table.c.id)).fetchall()
+
+    for (group_id,) in results:
+        members = connection.execute(
+            sa.select(gm_table.c.user_id).where(gm_table.c.group_id == group_id)
+        ).fetchall()
+
+        member_ids = [m[0] for m in members]
+
+        connection.execute(
+            group_table.update()
+            .where(group_table.c.id == group_id)
+            .values(user_ids=member_ids)
+        )
+
+    # Drop the new table
+    op.drop_table("group_member")
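The data move in this migration is the classic explode-a-JSON-array-into-rows pattern. A toy illustration of the mapping with made-up IDs, independent of Alembic:

```python
import json
import time
import uuid

# One legacy group row whose membership lives in a JSON array column.
group_row = {"id": "g1", "user_ids": json.dumps(["u1", "u2"])}

now = int(time.time())
rows = [
    {
        "id": str(uuid.uuid4()),
        "group_id": group_row["id"],
        "user_id": uid,
        "created_at": now,
        "updated_at": now,
    }
    for uid in json.loads(group_row["user_ids"])
]

# rows -> one group_member record per (group, user) pair, replacing the
# old group.user_ids JSON array; the (group_id, user_id) unique
# constraint then guards against duplicate memberships.
print(rows)
```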
@@ -20,18 +20,46 @@ depends_on: Union[str, Sequence[str], None] = None


 def upgrade() -> None:
+    # Ensure 'id' column in 'user' table is unique and primary key (ForeignKey constraint)
+    inspector = sa.inspect(op.get_bind())
+    columns = inspector.get_columns("user")
+
+    pk_columns = inspector.get_pk_constraint("user")["constrained_columns"]
+    id_column = next((col for col in columns if col["name"] == "id"), None)
+
+    if id_column and not id_column.get("unique", False):
+        unique_constraints = inspector.get_unique_constraints("user")
+        unique_columns = {tuple(u["column_names"]) for u in unique_constraints}
+
+        with op.batch_alter_table("user") as batch_op:
+            # If primary key is wrong, drop it
+            if pk_columns and pk_columns != ["id"]:
+                batch_op.drop_constraint(
+                    inspector.get_pk_constraint("user")["name"], type_="primary"
+                )
+
+            # Add unique constraint if missing
+            if ("id",) not in unique_columns:
+                batch_op.create_unique_constraint("uq_user_id", ["id"])
+
+            # Re-create correct primary key
+            batch_op.create_primary_key("pk_user_id", ["id"])
+
     # Create oauth_session table
     op.create_table(
         "oauth_session",
-        sa.Column("id", sa.Text(), nullable=False),
-        sa.Column("user_id", sa.Text(), nullable=False),
+        sa.Column("id", sa.Text(), primary_key=True, nullable=False, unique=True),
+        sa.Column(
+            "user_id",
+            sa.Text(),
+            sa.ForeignKey("user.id", ondelete="CASCADE"),
+            nullable=False,
+        ),
         sa.Column("provider", sa.Text(), nullable=False),
         sa.Column("token", sa.Text(), nullable=False),
         sa.Column("expires_at", sa.BigInteger(), nullable=False),
         sa.Column("created_at", sa.BigInteger(), nullable=False),
         sa.Column("updated_at", sa.BigInteger(), nullable=False),
-        sa.PrimaryKeyConstraint("id"),
-        sa.ForeignKeyConstraint(["user_id"], ["user.id"], ondelete="CASCADE"),
     )

     # Create indexes for better performance
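The constraint repair in this hunk leans on SQLAlchemy's runtime inspector. A small sketch of what it reports, using an in-memory SQLite table with a deliberately wrong composite primary key:

```python
import sqlalchemy as sa

engine = sa.create_engine("sqlite:///:memory:")
with engine.begin() as conn:
    conn.execute(
        sa.text('CREATE TABLE "user" (id TEXT, name TEXT, PRIMARY KEY (id, name))')
    )

inspector = sa.inspect(engine)
print(inspector.get_pk_constraint("user")["constrained_columns"])  # ['id', 'name']
print([u["column_names"] for u in inspector.get_unique_constraints("user")])  # []

# The migration compares these against the expected ['id'] primary key and
# only rebuilds constraints inside batch_alter_table when they differ, which
# keeps the upgrade idempotent across databases in different states.
```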
@@ -0,0 +1,169 @@
+"""Add knowledge_file table
+
+Revision ID: 3e0e00844bb0
+Revises: 90ef40d4714e
+Create Date: 2025-12-02 06:54:19.401334
+
+"""
+
+from typing import Sequence, Union
+
+from alembic import op
+import sqlalchemy as sa
+from sqlalchemy import inspect
+import open_webui.internal.db
+
+import time
+import json
+import uuid
+
+# revision identifiers, used by Alembic.
+revision: str = "3e0e00844bb0"
+down_revision: Union[str, None] = "90ef40d4714e"
+branch_labels: Union[str, Sequence[str], None] = None
+depends_on: Union[str, Sequence[str], None] = None
+
+
+def upgrade() -> None:
+    # 1. Create new table
+    op.create_table(
+        "knowledge_file",
+        sa.Column("id", sa.Text(), primary_key=True),
+        sa.Column("user_id", sa.Text(), nullable=False),
+        sa.Column(
+            "knowledge_id",
+            sa.Text(),
+            sa.ForeignKey("knowledge.id", ondelete="CASCADE"),
+            nullable=False,
+        ),
+        sa.Column(
+            "file_id",
+            sa.Text(),
+            sa.ForeignKey("file.id", ondelete="CASCADE"),
+            nullable=False,
+        ),
+        sa.Column("created_at", sa.BigInteger(), nullable=False),
+        sa.Column("updated_at", sa.BigInteger(), nullable=False),
+        # indexes
+        sa.Index("ix_knowledge_file_knowledge_id", "knowledge_id"),
+        sa.Index("ix_knowledge_file_file_id", "file_id"),
+        sa.Index("ix_knowledge_file_user_id", "user_id"),
+        # unique constraints
+        sa.UniqueConstraint(
+            "knowledge_id", "file_id", name="uq_knowledge_file_knowledge_file"
+        ),  # prevent duplicate entries
+    )
+
+    connection = op.get_bind()
+
+    # 2. Read existing knowledge rows, whose data JSON column holds file_ids
+    knowledge_table = sa.Table(
+        "knowledge",
+        sa.MetaData(),
+        sa.Column("id", sa.Text()),
+        sa.Column("user_id", sa.Text()),
+        sa.Column("data", sa.JSON()),  # JSON stored as text in SQLite + PG
+    )
+
+    results = connection.execute(
+        sa.select(
+            knowledge_table.c.id, knowledge_table.c.user_id, knowledge_table.c.data
+        )
+    ).fetchall()
+
+    # 3. Insert file references into knowledge_file table
+    kf_table = sa.Table(
+        "knowledge_file",
+        sa.MetaData(),
+        sa.Column("id", sa.Text()),
+        sa.Column("user_id", sa.Text()),
+        sa.Column("knowledge_id", sa.Text()),
+        sa.Column("file_id", sa.Text()),
+        sa.Column("created_at", sa.BigInteger()),
+        sa.Column("updated_at", sa.BigInteger()),
+    )
+
+    file_table = sa.Table(
+        "file",
+        sa.MetaData(),
+        sa.Column("id", sa.Text()),
+    )
+
+    now = int(time.time())
+    for knowledge_id, user_id, data in results:
+        if not data:
+            continue
+
+        if isinstance(data, str):
+            try:
+                data = json.loads(data)
+            except Exception:
+                continue  # skip invalid JSON
+
+        if not isinstance(data, dict):
+            continue
+
+        file_ids = data.get("file_ids", [])
+
+        for file_id in file_ids:
+            file_exists = connection.execute(
+                sa.select(file_table.c.id).where(file_table.c.id == file_id)
+            ).fetchone()
+
+            if not file_exists:
+                continue  # skip non-existing files
+
+            row = {
+                "id": str(uuid.uuid4()),
+                "user_id": user_id,
+                "knowledge_id": knowledge_id,
+                "file_id": file_id,
+                "created_at": now,
+                "updated_at": now,
+            }
+            connection.execute(kf_table.insert().values(**row))
+
+    with op.batch_alter_table("knowledge") as batch:
+        batch.drop_column("data")
+
+
+def downgrade() -> None:
+    # 1. Add back the old data column
+    op.add_column("knowledge", sa.Column("data", sa.JSON(), nullable=True))
+
+    connection = op.get_bind()
+
+    # 2. Read knowledge_file entries and reconstruct data JSON
+    knowledge_table = sa.Table(
+        "knowledge",
+        sa.MetaData(),
+        sa.Column("id", sa.Text()),
+        sa.Column("data", sa.JSON()),
+    )
+
+    kf_table = sa.Table(
+        "knowledge_file",
+        sa.MetaData(),
+        sa.Column("id", sa.Text()),
+        sa.Column("knowledge_id", sa.Text()),
+        sa.Column("file_id", sa.Text()),
+    )
+
+    results = connection.execute(sa.select(knowledge_table.c.id)).fetchall()
+
+    for (knowledge_id,) in results:
+        file_ids = connection.execute(
+            sa.select(kf_table.c.file_id).where(kf_table.c.knowledge_id == knowledge_id)
+        ).fetchall()
+
+        file_ids_list = [fid for (fid,) in file_ids]
+
+        data_json = {"file_ids": file_ids_list}
+
+        connection.execute(
+            knowledge_table.update()
+            .where(knowledge_table.c.id == knowledge_id)
+            .values(data=data_json)
+        )
+
+    # 3. Drop the knowledge_file table
+    op.drop_table("knowledge_file")
@@ -0,0 +1,81 @@
+"""Update channel and channel members table
+
+Revision ID: 90ef40d4714e
+Revises: b10670c03dd5
+Create Date: 2025-11-30 06:33:38.790341
+
+"""
+
+from typing import Sequence, Union
+
+from alembic import op
+import sqlalchemy as sa
+import open_webui.internal.db
+
+
+# revision identifiers, used by Alembic.
+revision: str = "90ef40d4714e"
+down_revision: Union[str, None] = "b10670c03dd5"
+branch_labels: Union[str, Sequence[str], None] = None
+depends_on: Union[str, Sequence[str], None] = None
+
+
+def upgrade() -> None:
+    # Update 'channel' table
+    op.add_column("channel", sa.Column("is_private", sa.Boolean(), nullable=True))
+
+    op.add_column("channel", sa.Column("archived_at", sa.BigInteger(), nullable=True))
+    op.add_column("channel", sa.Column("archived_by", sa.Text(), nullable=True))
+
+    op.add_column("channel", sa.Column("deleted_at", sa.BigInteger(), nullable=True))
+    op.add_column("channel", sa.Column("deleted_by", sa.Text(), nullable=True))
+
+    op.add_column("channel", sa.Column("updated_by", sa.Text(), nullable=True))
+
+    # Update 'channel_member' table
+    op.add_column("channel_member", sa.Column("role", sa.Text(), nullable=True))
+    op.add_column("channel_member", sa.Column("invited_by", sa.Text(), nullable=True))
+    op.add_column(
+        "channel_member", sa.Column("invited_at", sa.BigInteger(), nullable=True)
+    )
+
+    # Create 'channel_webhook' table
+    op.create_table(
+        "channel_webhook",
+        sa.Column("id", sa.Text(), primary_key=True, unique=True, nullable=False),
+        sa.Column("user_id", sa.Text(), nullable=False),
+        sa.Column(
+            "channel_id",
+            sa.Text(),
+            sa.ForeignKey("channel.id", ondelete="CASCADE"),
+            nullable=False,
+        ),
+        sa.Column("name", sa.Text(), nullable=False),
+        sa.Column("profile_image_url", sa.Text(), nullable=True),
+        sa.Column("token", sa.Text(), nullable=False),
+        sa.Column("last_used_at", sa.BigInteger(), nullable=True),
+        sa.Column("created_at", sa.BigInteger(), nullable=False),
+        sa.Column("updated_at", sa.BigInteger(), nullable=False),
+    )
+
+
+def downgrade() -> None:
+    # Downgrade 'channel' table
+    op.drop_column("channel", "is_private")
+    op.drop_column("channel", "archived_at")
+    op.drop_column("channel", "archived_by")
+    op.drop_column("channel", "deleted_at")
+    op.drop_column("channel", "deleted_by")
+    op.drop_column("channel", "updated_by")
+
+    # Downgrade 'channel_member' table
+    op.drop_column("channel_member", "role")
+    op.drop_column("channel_member", "invited_by")
+    op.drop_column("channel_member", "invited_at")
+
+    # Drop 'channel_webhook' table
+    op.drop_table("channel_webhook")
@ -0,0 +1,251 @@
|
||||||
|
"""Update user table
|
||||||
|
|
||||||
|
Revision ID: b10670c03dd5
|
||||||
|
Revises: 2f1211949ecc
|
||||||
|
Create Date: 2025-11-28 04:55:31.737538
|
||||||
|
|
||||||
|
"""
|
||||||
|
|
||||||
|
from typing import Sequence, Union
|
||||||
|
|
||||||
|
from alembic import op
|
||||||
|
import sqlalchemy as sa
|
||||||
|
|
||||||
|
|
||||||
|
import open_webui.internal.db
|
||||||
|
import json
|
||||||
|
import time
|
||||||
|
|
||||||
|
# revision identifiers, used by Alembic.
|
||||||
|
revision: str = "b10670c03dd5"
|
||||||
|
down_revision: Union[str, None] = "2f1211949ecc"
|
||||||
|
branch_labels: Union[str, Sequence[str], None] = None
|
||||||
|
depends_on: Union[str, Sequence[str], None] = None
|
||||||
|
|
||||||
|
|
||||||
|
def _drop_sqlite_indexes_for_column(table_name, column_name, conn):
|
||||||
|
"""
|
||||||
|
SQLite requires manual removal of any indexes referencing a column
|
||||||
|
before ALTER TABLE ... DROP COLUMN can succeed.
|
||||||
|
"""
|
||||||
|
indexes = conn.execute(sa.text(f"PRAGMA index_list('{table_name}')")).fetchall()
|
||||||
|
|
||||||
|
for idx in indexes:
|
||||||
|
index_name = idx[1] # index name
|
||||||
|
# Get indexed columns
|
||||||
|
idx_info = conn.execute(
|
||||||
|
sa.text(f"PRAGMA index_info('{index_name}')")
|
||||||
|
).fetchall()
|
||||||
|
|
||||||
|
indexed_cols = [row[2] for row in idx_info] # col names
|
||||||
|
if column_name in indexed_cols:
|
||||||
|
conn.execute(sa.text(f"DROP INDEX IF EXISTS {index_name}"))
|
||||||
|
|
||||||
|
|
||||||
|
def _convert_column_to_json(table: str, column: str):
    conn = op.get_bind()
    dialect = conn.dialect.name

    # SQLite cannot ALTER COLUMN, so the column must be recreated
    if dialect == "sqlite":
        # 1. Add temporary column
        op.add_column(table, sa.Column(f"{column}_json", sa.JSON(), nullable=True))

        # 2. Load old data
        rows = conn.execute(sa.text(f'SELECT id, {column} FROM "{table}"')).fetchall()

        for row in rows:
            uid, raw = row
            if raw is None:
                parsed = None
            else:
                try:
                    parsed = json.loads(raw)
                except Exception:
                    parsed = None  # safe fallback behavior

            conn.execute(
                sa.text(f'UPDATE "{table}" SET {column}_json = :val WHERE id = :id'),
                # 'is not None' keeps falsy-but-valid JSON values like {} or 0
                {"val": json.dumps(parsed) if parsed is not None else None, "id": uid},
            )

        # 3. Drop old TEXT column
        op.drop_column(table, column)

        # 4. Rename new JSON column to the original name
        op.alter_column(table, f"{column}_json", new_column_name=column)

    else:
        # PostgreSQL supports a direct CAST
        op.alter_column(
            table,
            column,
            type_=sa.JSON(),
            postgresql_using=f"{column}::json",
        )
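
For reference, the per-row transform the SQLite branch applies, in isolation (illustrative values; note that unparseable text deliberately degrades to NULL):

import json

for raw in ['{"a": 1}', "not json", None]:
    try:
        parsed = json.loads(raw) if raw is not None else None
    except Exception:
        parsed = None
    stored = json.dumps(parsed) if parsed is not None else None
    print(repr(raw), "->", repr(stored))
# '{"a": 1}' -> '{"a": 1}'
# 'not json' -> None
# None       -> None
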
def _convert_column_to_text(table: str, column: str):
    conn = op.get_bind()
    dialect = conn.dialect.name

    if dialect == "sqlite":
        op.add_column(table, sa.Column(f"{column}_text", sa.Text(), nullable=True))

        rows = conn.execute(sa.text(f'SELECT id, {column} FROM "{table}"')).fetchall()

        for uid, raw in rows:
            conn.execute(
                sa.text(f'UPDATE "{table}" SET {column}_text = :val WHERE id = :id'),
                # 'is not None' keeps falsy-but-valid JSON values like {} or 0
                {"val": json.dumps(raw) if raw is not None else None, "id": uid},
            )

        op.drop_column(table, column)
        op.alter_column(table, f"{column}_text", new_column_name=column)

    else:
        op.alter_column(
            table,
            column,
            type_=sa.Text(),
            postgresql_using=f"to_json({column})::text",
        )
def upgrade() -> None:
    op.add_column(
        "user", sa.Column("profile_banner_image_url", sa.Text(), nullable=True)
    )
    op.add_column("user", sa.Column("timezone", sa.String(), nullable=True))

    op.add_column("user", sa.Column("presence_state", sa.String(), nullable=True))
    op.add_column("user", sa.Column("status_emoji", sa.String(), nullable=True))
    op.add_column("user", sa.Column("status_message", sa.Text(), nullable=True))
    op.add_column(
        "user", sa.Column("status_expires_at", sa.BigInteger(), nullable=True)
    )

    op.add_column("user", sa.Column("oauth", sa.JSON(), nullable=True))

    # Convert info (TEXT/JSONField) → JSON
    _convert_column_to_json("user", "info")
    # Convert settings (TEXT/JSONField) → JSON
    _convert_column_to_json("user", "settings")

    op.create_table(
        "api_key",
        sa.Column("id", sa.Text(), primary_key=True, unique=True),
        sa.Column("user_id", sa.Text(), sa.ForeignKey("user.id", ondelete="CASCADE")),
        sa.Column("key", sa.Text(), unique=True, nullable=False),
        sa.Column("data", sa.JSON(), nullable=True),
        sa.Column("expires_at", sa.BigInteger(), nullable=True),
        sa.Column("last_used_at", sa.BigInteger(), nullable=True),
        sa.Column("created_at", sa.BigInteger(), nullable=False),
        sa.Column("updated_at", sa.BigInteger(), nullable=False),
    )

    conn = op.get_bind()
    users = conn.execute(
        sa.text('SELECT id, oauth_sub FROM "user" WHERE oauth_sub IS NOT NULL')
    ).fetchall()

    for uid, oauth_sub in users:
        if oauth_sub:
            # Example formats supported:
            #   provider@sub
            #   plain sub (stored as {"oidc": {"sub": sub}})
            if "@" in oauth_sub:
                provider, sub = oauth_sub.split("@", 1)
            else:
                provider, sub = "oidc", oauth_sub

            oauth_json = json.dumps({provider: {"sub": sub}})
            conn.execute(
                sa.text('UPDATE "user" SET oauth = :oauth WHERE id = :id'),
                {"oauth": oauth_json, "id": uid},
            )
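
The loop above supports both legacy `oauth_sub` shapes; a quick illustration of the mapping (values are made up):

import json

def to_oauth_json(oauth_sub: str) -> str:
    # "provider@sub" keeps its provider; a bare sub is treated as OIDC
    if "@" in oauth_sub:
        provider, sub = oauth_sub.split("@", 1)
    else:
        provider, sub = "oidc", oauth_sub
    return json.dumps({provider: {"sub": sub}})

print(to_oauth_json("google@108234"))  # {"google": {"sub": "108234"}}
print(to_oauth_json("abc-def-123"))    # {"oidc": {"sub": "abc-def-123"}}
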
    users_with_keys = conn.execute(
        sa.text('SELECT id, api_key FROM "user" WHERE api_key IS NOT NULL')
    ).fetchall()
    now = int(time.time())

    for uid, api_key in users_with_keys:
        if api_key:
            conn.execute(
                sa.text(
                    """
                    INSERT INTO api_key (id, user_id, key, created_at, updated_at)
                    VALUES (:id, :user_id, :key, :created_at, :updated_at)
                    """
                ),
                {
                    "id": f"key_{uid}",
                    "user_id": uid,
                    "key": api_key,
                    "created_at": now,
                    "updated_at": now,
                },
            )

    if conn.dialect.name == "sqlite":
        _drop_sqlite_indexes_for_column("user", "api_key", conn)
        _drop_sqlite_indexes_for_column("user", "oauth_sub", conn)

    with op.batch_alter_table("user") as batch_op:
        batch_op.drop_column("api_key")
        batch_op.drop_column("oauth_sub")
def downgrade() -> None:
    # --- 1. Restore old oauth_sub column ---
    op.add_column("user", sa.Column("oauth_sub", sa.Text(), nullable=True))

    conn = op.get_bind()
    users = conn.execute(
        sa.text('SELECT id, oauth FROM "user" WHERE oauth IS NOT NULL')
    ).fetchall()

    for uid, oauth in users:
        try:
            data = json.loads(oauth)
            provider = list(data.keys())[0]
            sub = data[provider].get("sub")
            oauth_sub = f"{provider}@{sub}"
        except Exception:
            oauth_sub = None

        conn.execute(
            sa.text('UPDATE "user" SET oauth_sub = :oauth_sub WHERE id = :id'),
            {"oauth_sub": oauth_sub, "id": uid},
        )

    op.drop_column("user", "oauth")

    # --- 2. Restore api_key field ---
    op.add_column("user", sa.Column("api_key", sa.String(), nullable=True))

    # Restore values from api_key
    keys = conn.execute(sa.text("SELECT user_id, key FROM api_key")).fetchall()
    for uid, key in keys:
        conn.execute(
            sa.text('UPDATE "user" SET api_key = :key WHERE id = :id'),
            {"key": key, "id": uid},
        )

    # Drop new table
    op.drop_table("api_key")

    with op.batch_alter_table("user") as batch_op:
        batch_op.drop_column("profile_banner_image_url")
        batch_op.drop_column("timezone")

        batch_op.drop_column("presence_state")
        batch_op.drop_column("status_emoji")
        batch_op.drop_column("status_message")
        batch_op.drop_column("status_expires_at")

    # Convert info (JSON) → TEXT
    _convert_column_to_text("user", "info")
    # Convert settings (JSON) → TEXT
    _convert_column_to_text("user", "settings")
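
Note that the downgrade flattens the JSON `oauth` blob back to a single `provider@sub` string and keeps only the first provider key; anything beyond that is lost. Illustration (values are made up):

import json

data = json.loads('{"google": {"sub": "108234"}, "github": {"sub": "x1"}}')
provider = list(data.keys())[0]  # dict preserves insertion order: "google"
oauth_sub = f"{provider}@{data[provider].get('sub')}"  # "google@108234"
# the second provider ("github") is dropped on downgrade
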
@@ -3,11 +3,10 @@ import uuid
 from typing import Optional

 from open_webui.internal.db import Base, get_db
-from open_webui.models.users import UserModel, Users
+from open_webui.models.users import UserModel, UserProfileImageResponse, Users
 from open_webui.env import SRC_LOG_LEVELS
 from pydantic import BaseModel
 from sqlalchemy import Boolean, Column, String, Text
-from open_webui.utils.auth import verify_password


 log = logging.getLogger(__name__)
 log.setLevel(SRC_LOG_LEVELS["MODELS"])

@@ -20,7 +19,7 @@ log.setLevel(SRC_LOG_LEVELS["MODELS"])
 class Auth(Base):
     __tablename__ = "auth"

-    id = Column(String, primary_key=True)
+    id = Column(String, primary_key=True, unique=True)
     email = Column(String)
     password = Column(Text)
     active = Column(Boolean)

@@ -47,15 +46,7 @@ class ApiKey(BaseModel):
     api_key: Optional[str] = None


-class UserResponse(BaseModel):
-    id: str
-    email: str
-    name: str
-    role: str
-    profile_image_url: str
-
-
-class SigninResponse(Token, UserResponse):
+class SigninResponse(Token, UserProfileImageResponse):
     pass

@@ -97,7 +88,7 @@ class AuthsTable:
         name: str,
         profile_image_url: str = "/user.png",
         role: str = "pending",
-        oauth_sub: Optional[str] = None,
+        oauth: Optional[dict] = None,
     ) -> Optional[UserModel]:
         with get_db() as db:
             log.info("insert_new_auth")

@@ -111,7 +102,7 @@ class AuthsTable:
             db.add(result)

             user = Users.insert_new_user(
-                id, name, email, profile_image_url, role, oauth_sub
+                id, name, email, profile_image_url, role, oauth=oauth
             )

             db.commit()

@@ -122,7 +113,9 @@ class AuthsTable:
             else:
                 return None

-    def authenticate_user(self, email: str, password: str) -> Optional[UserModel]:
+    def authenticate_user(
+        self, email: str, verify_password: callable
+    ) -> Optional[UserModel]:
         log.info(f"authenticate_user: {email}")

         user = Users.get_user_by_email(email)

@@ -133,7 +126,7 @@ class AuthsTable:
         with get_db() as db:
             auth = db.query(Auth).filter_by(id=user.id, active=True).first()
             if auth:
-                if verify_password(password, auth.password):
+                if verify_password(auth.password):
                     return user
             else:
                 return None
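
With the new signature, `authenticate_user` no longer sees the plaintext password; the caller binds it into a verifier and the table only applies that verifier to the stored hash. A hedged sketch of a caller (`form_data` and the surrounding route are illustrative; `Auths` is the module's `AuthsTable` instance):

from functools import partial

# partial(verify_password, plain) yields a one-argument callable that the
# table invokes as verify_password(plain, stored_hash)
user = Auths.authenticate_user(
    form_data.email,
    partial(verify_password, form_data.password),
)
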
@@ -4,10 +4,13 @@ import uuid
 from typing import Optional

 from open_webui.internal.db import Base, get_db
-from open_webui.utils.access_control import has_access
+from open_webui.models.groups import Groups

 from pydantic import BaseModel, ConfigDict
-from sqlalchemy import BigInteger, Boolean, Column, String, Text, JSON
+from sqlalchemy.dialects.postgresql import JSONB
+
+from sqlalchemy import BigInteger, Boolean, Column, String, Text, JSON, case, cast
 from sqlalchemy import or_, func, select, and_, text
 from sqlalchemy.sql import exists

@@ -19,19 +22,30 @@ from sqlalchemy.sql import exists
 class Channel(Base):
     __tablename__ = "channel"

-    id = Column(Text, primary_key=True)
+    id = Column(Text, primary_key=True, unique=True)
     user_id = Column(Text)
     type = Column(Text, nullable=True)

     name = Column(Text)
     description = Column(Text, nullable=True)

+    # Used to indicate if the channel is private (for 'group' type channels)
+    is_private = Column(Boolean, nullable=True)
+
     data = Column(JSON, nullable=True)
     meta = Column(JSON, nullable=True)
     access_control = Column(JSON, nullable=True)

     created_at = Column(BigInteger)

     updated_at = Column(BigInteger)
+    updated_by = Column(Text, nullable=True)
+
+    archived_at = Column(BigInteger, nullable=True)
+    archived_by = Column(Text, nullable=True)
+
+    deleted_at = Column(BigInteger, nullable=True)
+    deleted_by = Column(Text, nullable=True)


 class ChannelModel(BaseModel):
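
The new `archived_*` / `deleted_*` columns implement soft archive/delete: rows stay in place and read paths filter them out. The canonical "live channels" predicate used by the queries further down, as a sketch (assuming `db` is a session from `get_db()`):

# Only channels that are neither soft-deleted nor archived
live_channels = db.query(Channel).filter(
    Channel.deleted_at.is_(None),
    Channel.archived_at.is_(None),
)
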
@@ -39,17 +53,122 @@ class ChannelModel(BaseModel):

     id: str
     user_id: str

     type: Optional[str] = None

     name: str
     description: Optional[str] = None

+    is_private: Optional[bool] = None
+
     data: Optional[dict] = None
     meta: Optional[dict] = None
     access_control: Optional[dict] = None

-    created_at: int  # timestamp in epoch
-    updated_at: int  # timestamp in epoch
+    created_at: int  # timestamp in epoch (time_ns)
+
+    updated_at: int  # timestamp in epoch (time_ns)
+    updated_by: Optional[str] = None
+
+    archived_at: Optional[int] = None  # timestamp in epoch (time_ns)
+    archived_by: Optional[str] = None
+
+    deleted_at: Optional[int] = None  # timestamp in epoch (time_ns)
+    deleted_by: Optional[str] = None
+
+
+class ChannelMember(Base):
+    __tablename__ = "channel_member"
+
+    id = Column(Text, primary_key=True, unique=True)
+    channel_id = Column(Text, nullable=False)
+    user_id = Column(Text, nullable=False)
+
+    role = Column(Text, nullable=True)
+    status = Column(Text, nullable=True)
+
+    is_active = Column(Boolean, nullable=False, default=True)
+
+    is_channel_muted = Column(Boolean, nullable=False, default=False)
+    is_channel_pinned = Column(Boolean, nullable=False, default=False)
+
+    data = Column(JSON, nullable=True)
+    meta = Column(JSON, nullable=True)
+
+    invited_at = Column(BigInteger, nullable=True)
+    invited_by = Column(Text, nullable=True)
+
+    joined_at = Column(BigInteger)
+    left_at = Column(BigInteger, nullable=True)
+
+    last_read_at = Column(BigInteger, nullable=True)
+
+    created_at = Column(BigInteger)
+    updated_at = Column(BigInteger)
+
+
+class ChannelMemberModel(BaseModel):
+    model_config = ConfigDict(from_attributes=True)
+
+    id: str
+    channel_id: str
+    user_id: str
+
+    role: Optional[str] = None
+    status: Optional[str] = None
+
+    is_active: bool = True
+
+    is_channel_muted: bool = False
+    is_channel_pinned: bool = False
+
+    data: Optional[dict] = None
+    meta: Optional[dict] = None
+
+    invited_at: Optional[int] = None  # timestamp in epoch (time_ns)
+    invited_by: Optional[str] = None
+
+    joined_at: Optional[int] = None  # timestamp in epoch (time_ns)
+    left_at: Optional[int] = None  # timestamp in epoch (time_ns)
+
+    last_read_at: Optional[int] = None  # timestamp in epoch (time_ns)
+
+    created_at: Optional[int] = None  # timestamp in epoch (time_ns)
+    updated_at: Optional[int] = None  # timestamp in epoch (time_ns)
+
+
+class ChannelWebhook(Base):
+    __tablename__ = "channel_webhook"
+
+    id = Column(Text, primary_key=True, unique=True)
+    channel_id = Column(Text, nullable=False)
+    user_id = Column(Text, nullable=False)
+
+    name = Column(Text, nullable=False)
+    profile_image_url = Column(Text, nullable=True)
+
+    token = Column(Text, nullable=False)
+    last_used_at = Column(BigInteger, nullable=True)
+
+    created_at = Column(BigInteger, nullable=False)
+    updated_at = Column(BigInteger, nullable=False)
+
+
+class ChannelWebhookModel(BaseModel):
+    model_config = ConfigDict(from_attributes=True)
+
+    id: str
+    channel_id: str
+    user_id: str
+
+    name: str
+    profile_image_url: Optional[str] = None
+
+    token: str
+    last_used_at: Optional[int] = None  # timestamp in epoch (time_ns)
+
+    created_at: int  # timestamp in epoch (time_ns)
+    updated_at: int  # timestamp in epoch (time_ns)


 ####################
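
All timestamps on the member and webhook tables are nanosecond epoch integers (`time.time_ns()`), not seconds; consumers convert on the way out. For example:

import time
from datetime import datetime, timezone

now_ns = time.time_ns()
# A stored time_ns value converted back to an aware datetime for display
dt = datetime.fromtimestamp(now_ns / 1e9, tz=timezone.utc)
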
@@ -58,26 +177,94 @@ class ChannelModel(BaseModel):


 class ChannelResponse(ChannelModel):
+    is_manager: bool = False
     write_access: bool = False

+    user_count: Optional[int] = None


 class ChannelForm(BaseModel):
-    name: str
+    name: str = ""
     description: Optional[str] = None
+    is_private: Optional[bool] = None
     data: Optional[dict] = None
     meta: Optional[dict] = None
     access_control: Optional[dict] = None
+    group_ids: Optional[list[str]] = None
+    user_ids: Optional[list[str]] = None
+
+
+class CreateChannelForm(ChannelForm):
+    type: Optional[str] = None


 class ChannelTable:
+    def _collect_unique_user_ids(
+        self,
+        invited_by: str,
+        user_ids: Optional[list[str]] = None,
+        group_ids: Optional[list[str]] = None,
+    ) -> set[str]:
+        """
+        Collect unique user ids from:
+        - invited_by
+        - user_ids
+        - each group in group_ids
+        Returns a set for efficient SQL diffing.
+        """
+        users = set(user_ids or [])
+        users.add(invited_by)
+
+        for group_id in group_ids or []:
+            users.update(Groups.get_group_user_ids_by_id(group_id))
+
+        return users
+
+    def _create_membership_models(
+        self,
+        channel_id: str,
+        invited_by: str,
+        user_ids: set[str],
+    ) -> list[ChannelMember]:
+        """
+        Takes a set of NEW user IDs (already filtered to exclude existing members).
+        Returns ORM ChannelMember objects to be added.
+        """
+        now = int(time.time_ns())
+        memberships = []
+
+        for uid in user_ids:
+            model = ChannelMemberModel(
+                **{
+                    "id": str(uuid.uuid4()),
+                    "channel_id": channel_id,
+                    "user_id": uid,
+                    "status": "joined",
+                    "is_active": True,
+                    "is_channel_muted": False,
+                    "is_channel_pinned": False,
+                    "invited_at": now,
+                    "invited_by": invited_by,
+                    "joined_at": now,
+                    "left_at": None,
+                    "last_read_at": now,
+                    "created_at": now,
+                    "updated_at": now,
+                }
+            )
+            memberships.append(ChannelMember(**model.model_dump()))
+
+        return memberships
+
     def insert_new_channel(
-        self, type: Optional[str], form_data: ChannelForm, user_id: str
+        self, form_data: CreateChannelForm, user_id: str
     ) -> Optional[ChannelModel]:
         with get_db() as db:
             channel = ChannelModel(
                 **{
                     **form_data.model_dump(),
-                    "type": type,
+                    "type": form_data.type if form_data.type else None,
                     "name": form_data.name.lower(),
                     "id": str(uuid.uuid4()),
                     "user_id": user_id,

@@ -85,9 +272,21 @@ class ChannelTable:
                     "updated_at": int(time.time_ns()),
                 }
             )

             new_channel = Channel(**channel.model_dump())

+            if form_data.type in ["group", "dm"]:
+                users = self._collect_unique_user_ids(
+                    invited_by=user_id,
+                    user_ids=form_data.user_ids,
+                    group_ids=form_data.group_ids,
+                )
+                memberships = self._create_membership_models(
+                    channel_id=new_channel.id,
+                    invited_by=user_id,
+                    user_ids=users,
+                )
+
+                db.add_all(memberships)
             db.add(new_channel)
             db.commit()
             return channel
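
For `group` and `dm` channels, creation now materializes membership rows in the same transaction, seeded from `user_ids`, the expanded `group_ids`, and the creator. A hedged call sketch (`Channels` is the module's `ChannelTable` instance; values are made up):

channel = Channels.insert_new_channel(
    CreateChannelForm(
        type="group",
        name="Design",
        is_private=True,
        user_ids=["u1", "u2"],
        group_ids=["g1"],  # every member of g1 is added too
    ),
    user_id="u0",  # creator; also recorded as inviter and member
)
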
@@ -97,16 +296,346 @@ class ChannelTable:
             channels = db.query(Channel).all()
             return [ChannelModel.model_validate(channel) for channel in channels]

-    def get_channels_by_user_id(
-        self, user_id: str, permission: str = "read"
-    ) -> list[ChannelModel]:
-        channels = self.get_channels()
-        return [
-            channel
-            for channel in channels
-            if channel.user_id == user_id
-            or has_access(user_id, permission, channel.access_control)
-        ]
+    def _has_permission(self, db, query, filter: dict, permission: str = "read"):
+        group_ids = filter.get("group_ids", [])
+        user_id = filter.get("user_id")
+
+        dialect_name = db.bind.dialect.name
+
+        # Public access
+        conditions = []
+        if group_ids or user_id:
+            conditions.extend(
+                [
+                    Channel.access_control.is_(None),
+                    cast(Channel.access_control, String) == "null",
+                ]
+            )
+
+        # User-level permission
+        if user_id:
+            conditions.append(Channel.user_id == user_id)
+
+        # Group-level permission
+        if group_ids:
+            group_conditions = []
+            for gid in group_ids:
+                if dialect_name == "sqlite":
+                    group_conditions.append(
+                        Channel.access_control[permission]["group_ids"].contains([gid])
+                    )
+                elif dialect_name == "postgresql":
+                    group_conditions.append(
+                        cast(
+                            Channel.access_control[permission]["group_ids"],
+                            JSONB,
+                        ).contains([gid])
+                    )
+            conditions.append(or_(*group_conditions))
+
+        if conditions:
+            query = query.filter(or_(*conditions))
+
+        return query
+
+    def get_channels_by_user_id(self, user_id: str) -> list[ChannelModel]:
+        with get_db() as db:
+            user_group_ids = [
+                group.id for group in Groups.get_groups_by_member_id(user_id)
+            ]
+
+            membership_channels = (
+                db.query(Channel)
+                .join(ChannelMember, Channel.id == ChannelMember.channel_id)
+                .filter(
+                    Channel.deleted_at.is_(None),
+                    Channel.archived_at.is_(None),
+                    Channel.type.in_(["group", "dm"]),
+                    ChannelMember.user_id == user_id,
+                    ChannelMember.is_active.is_(True),
+                )
+                .all()
+            )
+
+            query = db.query(Channel).filter(
+                Channel.deleted_at.is_(None),
+                Channel.archived_at.is_(None),
+                or_(
+                    Channel.type.is_(None),  # True NULL/None
+                    Channel.type == "",  # Empty string
+                    and_(Channel.type != "group", Channel.type != "dm"),
+                ),
+            )
+            query = self._has_permission(
+                db, query, {"user_id": user_id, "group_ids": user_group_ids}
+            )
+
+            standard_channels = query.all()
+
+            all_channels = membership_channels + standard_channels
+            return [ChannelModel.model_validate(c) for c in all_channels]
+
+    def get_dm_channel_by_user_ids(self, user_ids: list[str]) -> Optional[ChannelModel]:
+        with get_db() as db:
+            # Ensure uniqueness in case a list with duplicates is passed
+            unique_user_ids = list(set(user_ids))
+
+            match_count = func.sum(
+                case(
+                    (ChannelMember.user_id.in_(unique_user_ids), 1),
+                    else_=0,
+                )
+            )
+
+            subquery = (
+                db.query(ChannelMember.channel_id)
+                .group_by(ChannelMember.channel_id)
+                # 1. Channel must have exactly len(user_ids) members
+                .having(func.count(ChannelMember.user_id) == len(unique_user_ids))
+                # 2. All those members must be in unique_user_ids
+                .having(match_count == len(unique_user_ids))
+                .subquery()
+            )
+
+            channel = (
+                db.query(Channel)
+                .filter(
+                    Channel.id.in_(subquery),
+                    Channel.type == "dm",
+                )
+                .first()
+            )
+
+            return ChannelModel.model_validate(channel) if channel else None
+
+    def add_members_to_channel(
+        self,
+        channel_id: str,
+        invited_by: str,
+        user_ids: Optional[list[str]] = None,
+        group_ids: Optional[list[str]] = None,
+    ) -> list[ChannelMemberModel]:
+        with get_db() as db:
+            # 1. Collect all user_ids including groups + inviter
+            requested_users = self._collect_unique_user_ids(
+                invited_by, user_ids, group_ids
+            )
+
+            existing_users = {
+                row.user_id
+                for row in db.query(ChannelMember.user_id)
+                .filter(ChannelMember.channel_id == channel_id)
+                .all()
+            }
+
+            new_user_ids = requested_users - existing_users
+            if not new_user_ids:
+                return []  # Nothing to add
+
+            new_memberships = self._create_membership_models(
+                channel_id, invited_by, new_user_ids
+            )
+
+            db.add_all(new_memberships)
+            db.commit()
+
+            return [
+                ChannelMemberModel.model_validate(membership)
+                for membership in new_memberships
+            ]
+
+    def remove_members_from_channel(
+        self,
+        channel_id: str,
+        user_ids: list[str],
+    ) -> int:
+        with get_db() as db:
+            result = (
+                db.query(ChannelMember)
+                .filter(
+                    ChannelMember.channel_id == channel_id,
+                    ChannelMember.user_id.in_(user_ids),
+                )
+                .delete(synchronize_session=False)
+            )
+            db.commit()
+            return result  # number of rows deleted
+
+    def is_user_channel_manager(self, channel_id: str, user_id: str) -> bool:
+        with get_db() as db:
+            # Check if the user is the creator of the channel
+            # or has a 'manager' role in ChannelMember
+            channel = db.query(Channel).filter(Channel.id == channel_id).first()
+            if channel and channel.user_id == user_id:
+                return True
+
+            membership = (
+                db.query(ChannelMember)
+                .filter(
+                    ChannelMember.channel_id == channel_id,
+                    ChannelMember.user_id == user_id,
+                    ChannelMember.role == "manager",
+                )
+                .first()
+            )
+            return membership is not None
+
+    def join_channel(
+        self, channel_id: str, user_id: str
+    ) -> Optional[ChannelMemberModel]:
+        with get_db() as db:
+            # Check if the membership already exists
+            existing_membership = (
+                db.query(ChannelMember)
+                .filter(
+                    ChannelMember.channel_id == channel_id,
+                    ChannelMember.user_id == user_id,
+                )
+                .first()
+            )
+            if existing_membership:
+                return ChannelMemberModel.model_validate(existing_membership)
+
+            # Create new membership
+            channel_member = ChannelMemberModel(
+                **{
+                    "id": str(uuid.uuid4()),
+                    "channel_id": channel_id,
+                    "user_id": user_id,
+                    "status": "joined",
+                    "is_active": True,
+                    "is_channel_muted": False,
+                    "is_channel_pinned": False,
+                    "joined_at": int(time.time_ns()),
+                    "left_at": None,
+                    "last_read_at": int(time.time_ns()),
+                    "created_at": int(time.time_ns()),
+                    "updated_at": int(time.time_ns()),
+                }
+            )
+            new_membership = ChannelMember(**channel_member.model_dump())
+
+            db.add(new_membership)
+            db.commit()
+            return channel_member
+
+    def leave_channel(self, channel_id: str, user_id: str) -> bool:
+        with get_db() as db:
+            membership = (
+                db.query(ChannelMember)
+                .filter(
+                    ChannelMember.channel_id == channel_id,
+                    ChannelMember.user_id == user_id,
+                )
+                .first()
+            )
+            if not membership:
+                return False
+
+            membership.status = "left"
+            membership.is_active = False
+            membership.left_at = int(time.time_ns())
+            membership.updated_at = int(time.time_ns())
+
+            db.commit()
+            return True
+
+    def get_member_by_channel_and_user_id(
+        self, channel_id: str, user_id: str
+    ) -> Optional[ChannelMemberModel]:
+        with get_db() as db:
+            membership = (
+                db.query(ChannelMember)
+                .filter(
+                    ChannelMember.channel_id == channel_id,
+                    ChannelMember.user_id == user_id,
+                )
+                .first()
+            )
+            return ChannelMemberModel.model_validate(membership) if membership else None
+
+    def get_members_by_channel_id(self, channel_id: str) -> list[ChannelMemberModel]:
+        with get_db() as db:
+            memberships = (
+                db.query(ChannelMember)
+                .filter(ChannelMember.channel_id == channel_id)
+                .all()
+            )
+            return [
+                ChannelMemberModel.model_validate(membership)
+                for membership in memberships
+            ]
+
+    def pin_channel(self, channel_id: str, user_id: str, is_pinned: bool) -> bool:
+        with get_db() as db:
+            membership = (
+                db.query(ChannelMember)
+                .filter(
+                    ChannelMember.channel_id == channel_id,
+                    ChannelMember.user_id == user_id,
+                )
+                .first()
+            )
+            if not membership:
+                return False
+
+            membership.is_channel_pinned = is_pinned
+            membership.updated_at = int(time.time_ns())
+
+            db.commit()
+            return True
+
+    def update_member_last_read_at(self, channel_id: str, user_id: str) -> bool:
+        with get_db() as db:
+            membership = (
+                db.query(ChannelMember)
+                .filter(
+                    ChannelMember.channel_id == channel_id,
+                    ChannelMember.user_id == user_id,
+                )
+                .first()
+            )
+            if not membership:
+                return False
+
+            membership.last_read_at = int(time.time_ns())
+            membership.updated_at = int(time.time_ns())
+
+            db.commit()
+            return True
+
+    def update_member_active_status(
+        self, channel_id: str, user_id: str, is_active: bool
+    ) -> bool:
+        with get_db() as db:
+            membership = (
+                db.query(ChannelMember)
+                .filter(
+                    ChannelMember.channel_id == channel_id,
+                    ChannelMember.user_id == user_id,
+                )
+                .first()
+            )
+            if not membership:
+                return False
+
+            membership.is_active = is_active
+            membership.updated_at = int(time.time_ns())
+
+            db.commit()
+            return True
+
+    def is_user_channel_member(self, channel_id: str, user_id: str) -> bool:
+        with get_db() as db:
+            membership = (
+                db.query(ChannelMember)
+                .filter(
+                    ChannelMember.channel_id == channel_id,
+                    ChannelMember.user_id == user_id,
+                )
+                .first()
+            )
+            return membership is not None
+
     def get_channel_by_id(self, id: str) -> Optional[ChannelModel]:
         with get_db() as db:
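
The DM lookup above pins down "exactly this set of members" with two HAVING clauses over the grouped membership rows; in spirit the subquery compiles to roughly:

# SELECT channel_id FROM channel_member
# GROUP BY channel_id
# HAVING COUNT(user_id) = :n                                      -- exactly n members
#    AND SUM(CASE WHEN user_id IN (:ids) THEN 1 ELSE 0 END) = :n  -- all of them ours
#
# n members in total, and all n inside the requested set, means the member
# sets are equal (assuming one membership row per user per channel).
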
@@ -122,8 +651,12 @@ class ChannelTable:
                 return None

             channel.name = form_data.name
+            channel.description = form_data.description
+            channel.is_private = form_data.is_private
+
             channel.data = form_data.data
             channel.meta = form_data.meta

             channel.access_control = form_data.access_control
             channel.updated_at = int(time.time_ns())
@@ -26,7 +26,7 @@ log.setLevel(SRC_LOG_LEVELS["MODELS"])
 class Chat(Base):
     __tablename__ = "chat"

-    id = Column(String, primary_key=True)
+    id = Column(String, primary_key=True, unique=True)
     user_id = Column(String)
     title = Column(Text)
     chat = Column(JSON)

@@ -92,6 +92,10 @@ class ChatImportForm(ChatForm):
     updated_at: Optional[int] = None


+class ChatsImportForm(BaseModel):
+    chats: list[ChatImportForm]
+
+
 class ChatTitleMessagesForm(BaseModel):
     title: str
     messages: list[dict]

@@ -123,6 +127,43 @@ class ChatTitleIdResponse(BaseModel):


 class ChatTable:
+    def _clean_null_bytes(self, obj):
+        """
+        Recursively remove actual null bytes (\x00) and unicode escape \\u0000
+        from strings inside dict/list structures.
+        Safe for JSON objects.
+        """
+        if isinstance(obj, str):
+            return obj.replace("\x00", "").replace("\u0000", "")
+        elif isinstance(obj, dict):
+            return {k: self._clean_null_bytes(v) for k, v in obj.items()}
+        elif isinstance(obj, list):
+            return [self._clean_null_bytes(v) for v in obj]
+        return obj
+
+    def _sanitize_chat_row(self, chat_item):
+        """
+        Clean a Chat SQLAlchemy model's title + chat JSON,
+        and return True if anything changed.
+        """
+        changed = False
+
+        # Clean title
+        if chat_item.title:
+            cleaned = self._clean_null_bytes(chat_item.title)
+            if cleaned != chat_item.title:
+                chat_item.title = cleaned
+                changed = True
+
+        # Clean JSON
+        if chat_item.chat:
+            cleaned = self._clean_null_bytes(chat_item.chat)
+            if cleaned != chat_item.chat:
+                chat_item.chat = cleaned
+                changed = True
+
+        return changed
+
     def insert_new_chat(self, user_id: str, form_data: ChatForm) -> Optional[ChatModel]:
         with get_db() as db:
             id = str(uuid.uuid4())
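
What `_clean_null_bytes` does to nested data, for reference (in Python source, "\u0000" and "\x00" are the same character, so the second `replace` is defensive rather than load-bearing):

sample = {"title": "hi\x00there", "messages": ["a\u0000b", {"ok": "fine"}]}
# after cleaning: {"title": "hithere", "messages": ["ab", {"ok": "fine"}]}
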
@@ -130,68 +171,76 @@ class ChatTable:
                 **{
                     "id": id,
                     "user_id": user_id,
-                    "title": (
+                    "title": self._clean_null_bytes(
                         form_data.chat["title"]
                         if "title" in form_data.chat
                         else "New Chat"
                     ),
-                    "chat": form_data.chat,
+                    "chat": self._clean_null_bytes(form_data.chat),
                     "folder_id": form_data.folder_id,
                     "created_at": int(time.time()),
                     "updated_at": int(time.time()),
                 }
             )

-            result = Chat(**chat.model_dump())
-            db.add(result)
+            chat_item = Chat(**chat.model_dump())
+            db.add(chat_item)
             db.commit()
-            db.refresh(result)
-            return ChatModel.model_validate(result) if result else None
+            db.refresh(chat_item)
+            return ChatModel.model_validate(chat_item) if chat_item else None

-    def import_chat(
-        self, user_id: str, form_data: ChatImportForm
-    ) -> Optional[ChatModel]:
-        with get_db() as db:
-            id = str(uuid.uuid4())
-            chat = ChatModel(
-                **{
-                    "id": id,
-                    "user_id": user_id,
-                    "title": (
-                        form_data.chat["title"]
-                        if "title" in form_data.chat
-                        else "New Chat"
-                    ),
-                    "chat": form_data.chat,
-                    "meta": form_data.meta,
-                    "pinned": form_data.pinned,
-                    "folder_id": form_data.folder_id,
-                    "created_at": (
-                        form_data.created_at
-                        if form_data.created_at
-                        else int(time.time())
-                    ),
-                    "updated_at": (
-                        form_data.updated_at
-                        if form_data.updated_at
-                        else int(time.time())
-                    ),
-                }
-            )
-
-            result = Chat(**chat.model_dump())
-            db.add(result)
-            db.commit()
-            db.refresh(result)
-            return ChatModel.model_validate(result) if result else None
+    def _chat_import_form_to_chat_model(
+        self, user_id: str, form_data: ChatImportForm
+    ) -> ChatModel:
+        id = str(uuid.uuid4())
+        chat = ChatModel(
+            **{
+                "id": id,
+                "user_id": user_id,
+                "title": self._clean_null_bytes(
+                    form_data.chat["title"] if "title" in form_data.chat else "New Chat"
+                ),
+                "chat": self._clean_null_bytes(form_data.chat),
+                "meta": form_data.meta,
+                "pinned": form_data.pinned,
+                "folder_id": form_data.folder_id,
+                "created_at": (
+                    form_data.created_at if form_data.created_at else int(time.time())
+                ),
+                "updated_at": (
+                    form_data.updated_at if form_data.updated_at else int(time.time())
+                ),
+            }
+        )
+        return chat
+
+    def import_chats(
+        self, user_id: str, chat_import_forms: list[ChatImportForm]
+    ) -> list[ChatModel]:
+        with get_db() as db:
+            chats = []
+
+            for form_data in chat_import_forms:
+                chat = self._chat_import_form_to_chat_model(user_id, form_data)
+                chats.append(Chat(**chat.model_dump()))
+
+            db.add_all(chats)
+            db.commit()
+            return [ChatModel.model_validate(chat) for chat in chats]

     def update_chat_by_id(self, id: str, chat: dict) -> Optional[ChatModel]:
         try:
             with get_db() as db:
                 chat_item = db.get(Chat, id)
-                chat_item.chat = chat
-                chat_item.title = chat["title"] if "title" in chat else "New Chat"
+                chat_item.chat = self._clean_null_bytes(chat)
+                chat_item.title = (
+                    self._clean_null_bytes(chat["title"])
+                    if "title" in chat
+                    else "New Chat"
+                )

                 chat_item.updated_at = int(time.time())

                 db.commit()
                 db.refresh(chat_item)
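
`import_chats` replaces per-chat commits with one `add_all` and a single commit for the whole batch. A hedged usage sketch (`Chats` is the module's `ChatTable` instance; the payload shape is illustrative):

forms = [
    ChatImportForm(chat=item["chat"], meta=item.get("meta", {}))
    for item in payload  # e.g. a parsed export file
]
imported = Chats.import_chats(user_id, forms)  # one transaction for the batch
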
@@ -297,6 +346,27 @@ class ChatTable:
             chat["history"] = history
             return self.update_chat_by_id(id, chat)

+    def add_message_files_by_id_and_message_id(
+        self, id: str, message_id: str, files: list[dict]
+    ) -> list[dict]:
+        chat = self.get_chat_by_id(id)
+        if chat is None:
+            return None
+
+        chat = chat.chat
+        history = chat.get("history", {})
+
+        message_files = []
+
+        if message_id in history.get("messages", {}):
+            message_files = history["messages"][message_id].get("files", [])
+            message_files = message_files + files
+            history["messages"][message_id]["files"] = message_files
+
+        chat["history"] = history
+        self.update_chat_by_id(id, chat)
+        return message_files
+
     def insert_shared_chat_by_chat_id(self, chat_id: str) -> Optional[ChatModel]:
         with get_db() as db:
             # Get the existing chat to share

@@ -405,6 +475,7 @@ class ChatTable:
         with get_db() as db:
             chat = db.get(Chat, id)
             chat.archived = not chat.archived
+            chat.folder_id = None
             chat.updated_at = int(time.time())
             db.commit()
             db.refresh(chat)
@@ -440,7 +511,10 @@ class ChatTable:
         order_by = filter.get("order_by")
         direction = filter.get("direction")

-        if order_by and direction and getattr(Chat, order_by):
+        if order_by and direction:
+            if not getattr(Chat, order_by, None):
+                raise ValueError("Invalid order_by field")
+
             if direction.lower() == "asc":
                 query = query.order_by(getattr(Chat, order_by).asc())
             elif direction.lower() == "desc":

@@ -502,6 +576,7 @@ class ChatTable:
         user_id: str,
         include_archived: bool = False,
         include_folders: bool = False,
+        include_pinned: bool = False,
         skip: Optional[int] = None,
         limit: Optional[int] = None,
     ) -> list[ChatTitleIdResponse]:

@@ -511,6 +586,7 @@ class ChatTable:
         if not include_folders:
             query = query.filter_by(folder_id=None)

-        query = query.filter(or_(Chat.pinned == False, Chat.pinned == None))
+        if not include_pinned:
+            query = query.filter(or_(Chat.pinned == False, Chat.pinned == None))

         if not include_archived:
@@ -556,8 +632,15 @@ class ChatTable:
     def get_chat_by_id(self, id: str) -> Optional[ChatModel]:
         try:
             with get_db() as db:
-                chat = db.get(Chat, id)
-                return ChatModel.model_validate(chat)
+                chat_item = db.get(Chat, id)
+                if chat_item is None:
+                    return None
+
+                if self._sanitize_chat_row(chat_item):
+                    db.commit()
+                    db.refresh(chat_item)
+
+                return ChatModel.model_validate(chat_item)
         except Exception:
             return None

@@ -760,21 +843,32 @@ class ChatTable:
                 )

             elif dialect_name == "postgresql":
-                # PostgreSQL relies on proper JSON query for search
-                postgres_content_sql = (
-                    "EXISTS ("
-                    " SELECT 1 "
-                    " FROM json_array_elements(Chat.chat->'messages') AS message "
-                    " WHERE LOWER(message->>'content') LIKE '%' || :content_key || '%'"
-                    ")"
-                )
+                # PostgreSQL doesn't allow null bytes in text. We filter those out by
+                # checking the JSON representation for \u0000 before text extraction
+
+                # Safety filter: JSON field must not contain \u0000
+                query = query.filter(text("Chat.chat::text NOT LIKE '%\\\\u0000%'"))
+
+                # Safety filter: title must not contain actual null bytes
+                query = query.filter(text("Chat.title::text NOT LIKE '%\\x00%'"))
+
+                postgres_content_sql = """
+                    EXISTS (
+                        SELECT 1
+                        FROM json_array_elements(Chat.chat->'messages') AS message
+                        WHERE json_typeof(message->'content') = 'string'
+                        AND LOWER(message->>'content') LIKE '%' || :content_key || '%'
+                    )
+                """
                 postgres_content_clause = text(postgres_content_sql)

                 query = query.filter(
                     or_(
                         Chat.title.ilike(bindparam("title_key")),
                         postgres_content_clause,
-                    ).params(title_key=f"%{search_text}%", content_key=search_text)
-                )
+                    )
+                ).params(title_key=f"%{search_text}%", content_key=search_text.lower())

             # Check if there are any tags to filter, it should have all the tags
             if "none" in tag_ids:
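
Two behavioral fixes ride along in the PostgreSQL search path: `json_typeof(...) = 'string'` skips messages whose `content` is a non-string JSON value (which `->>` would otherwise stringify), and the bind param is now lowercased to match the `LOWER(...)` on the left-hand side. The old code compared `LOWER(content)` against the raw `search_text`, so mixed-case queries silently missed:

# old: LOWER(content) LIKE '%Foo%'   -> never matches "foo"
# new: LOWER(content) LIKE '%foo%'   -> content_key=search_text.lower()
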
@@ -1049,6 +1143,20 @@ class ChatTable:
         except Exception:
             return False

+    def move_chats_by_user_id_and_folder_id(
+        self, user_id: str, folder_id: str, new_folder_id: Optional[str]
+    ) -> bool:
+        try:
+            with get_db() as db:
+                db.query(Chat).filter_by(user_id=user_id, folder_id=folder_id).update(
+                    {"folder_id": new_folder_id}
+                )
+                db.commit()
+
+                return True
+        except Exception:
+            return False
+
     def delete_shared_chats_by_user_id(self, user_id: str) -> bool:
         try:
             with get_db() as db:
@@ -4,7 +4,7 @@ import uuid
 from typing import Optional

 from open_webui.internal.db import Base, get_db
-from open_webui.models.chats import Chats
+from open_webui.models.users import User

 from open_webui.env import SRC_LOG_LEVELS
 from pydantic import BaseModel, ConfigDict

@@ -21,7 +21,7 @@ log.setLevel(SRC_LOG_LEVELS["MODELS"])

 class Feedback(Base):
     __tablename__ = "feedback"
-    id = Column(Text, primary_key=True)
+    id = Column(Text, primary_key=True, unique=True)
     user_id = Column(Text)
     version = Column(BigInteger, default=0)
     type = Column(Text)

@@ -92,6 +92,28 @@ class FeedbackForm(BaseModel):
     model_config = ConfigDict(extra="allow")


+class UserResponse(BaseModel):
+    id: str
+    name: str
+    email: str
+    role: str = "pending"
+
+    last_active_at: int  # timestamp in epoch
+    updated_at: int  # timestamp in epoch
+    created_at: int  # timestamp in epoch
+
+    model_config = ConfigDict(from_attributes=True)
+
+
+class FeedbackUserResponse(FeedbackResponse):
+    user: Optional[UserResponse] = None
+
+
+class FeedbackListResponse(BaseModel):
+    items: list[FeedbackUserResponse]
+    total: int
+
+
 class FeedbackTable:
     def insert_new_feedback(
         self, user_id: str, form_data: FeedbackForm
@@ -143,6 +165,70 @@ class FeedbackTable:
         except Exception:
             return None

+    def get_feedback_items(
+        self, filter: dict = {}, skip: int = 0, limit: int = 30
+    ) -> FeedbackListResponse:
+        with get_db() as db:
+            query = db.query(Feedback, User).join(User, Feedback.user_id == User.id)
+
+            if filter:
+                order_by = filter.get("order_by")
+                direction = filter.get("direction")
+
+                if order_by == "username":
+                    if direction == "asc":
+                        query = query.order_by(User.name.asc())
+                    else:
+                        query = query.order_by(User.name.desc())
+                elif order_by == "model_id":
+                    # it's stored in feedback.data['model_id']
+                    if direction == "asc":
+                        query = query.order_by(
+                            Feedback.data["model_id"].as_string().asc()
+                        )
+                    else:
+                        query = query.order_by(
+                            Feedback.data["model_id"].as_string().desc()
+                        )
+                elif order_by == "rating":
+                    # it's stored in feedback.data['rating']
+                    if direction == "asc":
+                        query = query.order_by(
+                            Feedback.data["rating"].as_string().asc()
+                        )
+                    else:
+                        query = query.order_by(
+                            Feedback.data["rating"].as_string().desc()
+                        )
+                elif order_by == "updated_at":
+                    if direction == "asc":
+                        query = query.order_by(Feedback.updated_at.asc())
+                    else:
+                        query = query.order_by(Feedback.updated_at.desc())
+
+            else:
+                query = query.order_by(Feedback.created_at.desc())
+
+            # Count BEFORE pagination
+            total = query.count()
+
+            if skip:
+                query = query.offset(skip)
+            if limit:
+                query = query.limit(limit)
+
+            items = query.all()
+
+            feedbacks = []
+            for feedback, user in items:
+                feedback_model = FeedbackModel.model_validate(feedback)
+                user_model = UserResponse.model_validate(user)
+                feedbacks.append(
+                    FeedbackUserResponse(**feedback_model.model_dump(), user=user_model)
+                )
+
+            return FeedbackListResponse(items=feedbacks, total=total)
+
     def get_all_feedbacks(self) -> list[FeedbackModel]:
         with get_db() as db:
             return [
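
Usage sketch for the new paginated listing (`Feedbacks` is the module's `FeedbackTable` instance; filter keys as handled above):

page = Feedbacks.get_feedback_items(
    filter={"order_by": "updated_at", "direction": "desc"},
    skip=0,
    limit=30,
)
print(page.total, len(page.items))  # total rows vs. rows on this page
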
File model (`file` table):

@@ -17,7 +17,7 @@ log.setLevel(SRC_LOG_LEVELS["MODELS"])
 
 class File(Base):
     __tablename__ = "file"
-    id = Column(String, primary_key=True)
+    id = Column(String, primary_key=True, unique=True)
     user_id = Column(String)
     hash = Column(Text, nullable=True)
 
@@ -82,6 +82,7 @@ class FileModelResponse(BaseModel):
 
 class FileMetadataResponse(BaseModel):
     id: str
+    hash: Optional[str] = None
     meta: dict
     created_at: int  # timestamp in epoch
     updated_at: int  # timestamp in epoch
 
@@ -97,6 +98,12 @@ class FileForm(BaseModel):
     access_control: Optional[dict] = None
 
 
+class FileUpdateForm(BaseModel):
+    hash: Optional[str] = None
+    data: Optional[dict] = None
+    meta: Optional[dict] = None
+
+
 class FilesTable:
     def insert_new_file(self, user_id: str, form_data: FileForm) -> Optional[FileModel]:
         with get_db() as db:
@@ -147,6 +154,7 @@ class FilesTable:
             file = db.get(File, id)
             return FileMetadataResponse(
                 id=file.id,
+                hash=file.hash,
                 meta=file.meta,
                 created_at=file.created_at,
                 updated_at=file.updated_at,
@@ -182,11 +190,14 @@ class FilesTable:
             return [
                 FileMetadataResponse(
                     id=file.id,
+                    hash=file.hash,
                     meta=file.meta,
                     created_at=file.created_at,
                     updated_at=file.updated_at,
                 )
-                for file in db.query(File)
+                for file in db.query(
+                    File.id, File.hash, File.meta, File.created_at, File.updated_at
+                )
                 .filter(File.id.in_(ids))
                 .order_by(File.updated_at.desc())
                 .all()
@@ -199,6 +210,29 @@ class FilesTable:
                 for file in db.query(File).filter_by(user_id=user_id).all()
             ]
 
+    def update_file_by_id(
+        self, id: str, form_data: FileUpdateForm
+    ) -> Optional[FileModel]:
+        with get_db() as db:
+            try:
+                file = db.query(File).filter_by(id=id).first()
+
+                if form_data.hash is not None:
+                    file.hash = form_data.hash
+
+                if form_data.data is not None:
+                    file.data = {**(file.data if file.data else {}), **form_data.data}
+
+                if form_data.meta is not None:
+                    file.meta = {**(file.meta if file.meta else {}), **form_data.meta}
+
+                file.updated_at = int(time.time())
+                db.commit()
+                return FileModel.model_validate(file)
+            except Exception as e:
+                log.exception(f"Error updating file completely by id: {e}")
+                return None
+
     def update_file_hash_by_id(self, id: str, hash: str) -> Optional[FileModel]:
         with get_db() as db:
             try:
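The new `update_file_by_id` merges partial updates into the stored JSON rather than replacing it wholesale. A minimal sketch of that merge semantics, using made-up values (nothing below comes from the diff itself):

```python
# Dict-merge semantics as applied to `meta` (and likewise `data`):
# keys in the patch win, keys absent from the patch survive.
existing_meta = {"name": "report.pdf", "size": 1024}   # hypothetical stored value
patch = {"size": 2048, "collection_name": "docs"}      # hypothetical FileUpdateForm.meta

merged = {**(existing_meta or {}), **patch}
assert merged == {"name": "report.pdf", "size": 2048, "collection_name": "docs"}
```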
Folder model (`folder` table):

@@ -23,7 +23,7 @@ log.setLevel(SRC_LOG_LEVELS["MODELS"])
 
 class Folder(Base):
     __tablename__ = "folder"
-    id = Column(Text, primary_key=True)
+    id = Column(Text, primary_key=True, unique=True)
    parent_id = Column(Text, nullable=True)
    user_id = Column(Text)
    name = Column(Text)
Function model (`function` table):

@@ -3,7 +3,7 @@ import time
 from typing import Optional
 
 from open_webui.internal.db import Base, JSONField, get_db
-from open_webui.models.users import Users
+from open_webui.models.users import Users, UserModel
 from open_webui.env import SRC_LOG_LEVELS
 from pydantic import BaseModel, ConfigDict
 from sqlalchemy import BigInteger, Boolean, Column, String, Text, Index
@@ -19,7 +19,7 @@ log.setLevel(SRC_LOG_LEVELS["MODELS"])
 class Function(Base):
     __tablename__ = "function"
 
-    id = Column(String, primary_key=True)
+    id = Column(String, primary_key=True, unique=True)
     user_id = Column(String)
     name = Column(Text)
     type = Column(Text)
@@ -76,6 +76,10 @@ class FunctionWithValvesModel(BaseModel):
 ####################
 
 
+class FunctionUserResponse(FunctionModel):
+    user: Optional[UserModel] = None
+
+
 class FunctionResponse(BaseModel):
     id: str
     user_id: str
@@ -203,6 +207,28 @@ class FunctionsTable:
                 FunctionModel.model_validate(function) for function in functions
             ]
 
+    def get_function_list(self) -> list[FunctionUserResponse]:
+        with get_db() as db:
+            functions = db.query(Function).order_by(Function.updated_at.desc()).all()
+            user_ids = list(set(func.user_id for func in functions))
+
+            users = Users.get_users_by_user_ids(user_ids) if user_ids else []
+            users_dict = {user.id: user for user in users}
+
+            return [
+                FunctionUserResponse.model_validate(
+                    {
+                        **FunctionModel.model_validate(func).model_dump(),
+                        "user": (
+                            users_dict.get(func.user_id).model_dump()
+                            if func.user_id in users_dict
+                            else None
+                        ),
+                    }
+                )
+                for func in functions
+            ]
+
     def get_functions_by_type(
         self, type: str, active_only=False
     ) -> list[FunctionModel]:
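`get_function_list` fetches all owners with one `get_users_by_user_ids` call and joins them in memory, avoiding a per-function user query. A pure-Python sketch of that batching pattern, with plain dicts standing in for ORM rows (all values illustrative):

```python
functions = [
    {"id": "f1", "user_id": "u1"},
    {"id": "f2", "user_id": "u2"},
    {"id": "f3", "user_id": "u1"},
]
users = [{"id": "u1", "name": "Ada"}, {"id": "u2", "name": "Lin"}]  # one batched query

users_dict = {u["id"]: u for u in users}  # O(1) lookups instead of N queries
enriched = [{**f, "user": users_dict.get(f["user_id"])} for f in functions]
assert enriched[2]["user"]["name"] == "Ada"
```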
Group model (`group` table); membership moves to a new `group_member` join table:

@@ -11,7 +11,18 @@ from open_webui.models.files import FileMetadataResponse
 
 
 from pydantic import BaseModel, ConfigDict
-from sqlalchemy import BigInteger, Column, String, Text, JSON, func
+from sqlalchemy import (
+    BigInteger,
+    Column,
+    String,
+    Text,
+    JSON,
+    and_,
+    func,
+    ForeignKey,
+    cast,
+    or_,
+)
 
 
 log = logging.getLogger(__name__)
@@ -35,14 +46,12 @@ class Group(Base):
     meta = Column(JSON, nullable=True)
 
     permissions = Column(JSON, nullable=True)
-    user_ids = Column(JSON, nullable=True)
 
     created_at = Column(BigInteger)
     updated_at = Column(BigInteger)
 
 
 class GroupModel(BaseModel):
-    model_config = ConfigDict(from_attributes=True)
     id: str
     user_id: str
@@ -53,44 +62,64 @@ class GroupModel(BaseModel):
     meta: Optional[dict] = None
 
     permissions: Optional[dict] = None
-    user_ids: list[str] = []
 
     created_at: int  # timestamp in epoch
     updated_at: int  # timestamp in epoch
 
+    model_config = ConfigDict(from_attributes=True)
+
+
+class GroupMember(Base):
+    __tablename__ = "group_member"
+
+    id = Column(Text, unique=True, primary_key=True)
+    group_id = Column(
+        Text,
+        ForeignKey("group.id", ondelete="CASCADE"),
+        nullable=False,
+    )
+    user_id = Column(Text, nullable=False)
+    created_at = Column(BigInteger, nullable=True)
+    updated_at = Column(BigInteger, nullable=True)
+
+
+class GroupMemberModel(BaseModel):
+    id: str
+    group_id: str
+    user_id: str
+    created_at: Optional[int] = None  # timestamp in epoch
+    updated_at: Optional[int] = None  # timestamp in epoch
+
 
 ####################
 # Forms
 ####################
 
 
-class GroupResponse(BaseModel):
-    id: str
-    user_id: str
-    name: str
-    description: str
-    permissions: Optional[dict] = None
-    data: Optional[dict] = None
-    meta: Optional[dict] = None
-    user_ids: list[str] = []
-    created_at: int  # timestamp in epoch
-    updated_at: int  # timestamp in epoch
+class GroupResponse(GroupModel):
+    member_count: Optional[int] = None
 
 
 class GroupForm(BaseModel):
     name: str
     description: str
     permissions: Optional[dict] = None
+    data: Optional[dict] = None
 
 
 class UserIdsForm(BaseModel):
     user_ids: Optional[list[str]] = None
 
 
-class GroupUpdateForm(GroupForm, UserIdsForm):
+class GroupUpdateForm(GroupForm):
     pass
 
 
+class GroupListResponse(BaseModel):
+    items: list[GroupResponse] = []
+    total: int = 0
+
+
 class GroupTable:
     def insert_new_group(
         self, user_id: str, form_data: GroupForm
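This is the core of the change: group membership moves out of the JSON `user_ids` column on `group` and into a normalized `group_member` join table. A hedged sketch of how existing rows could be backfilled; the migration below is illustrative only and is not part of this diff:

```python
import time
import uuid

# Hypothetical input: groups read from the old schema's JSON column.
old_groups = [{"id": "g1", "user_ids": ["u1", "u2"]}]

now = int(time.time())
rows = [
    {
        "id": str(uuid.uuid4()),
        "group_id": group["id"],
        "user_id": user_id,
        "created_at": now,
        "updated_at": now,
    }
    for group in old_groups
    for user_id in (group["user_ids"] or [])
]
# One group_member row per (group_id, user_id) pair; membership lookups
# become indexed joins instead of LIKE matches on serialized JSON.
assert len(rows) == 2
```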
@@ -119,24 +148,94 @@ class GroupTable:
         except Exception:
             return None
 
-    def get_groups(self) -> list[GroupModel]:
+    def get_all_groups(self) -> list[GroupModel]:
         with get_db() as db:
+            groups = db.query(Group).order_by(Group.updated_at.desc()).all()
+            return [GroupModel.model_validate(group) for group in groups]
+
+    def get_groups(self, filter) -> list[GroupResponse]:
+        with get_db() as db:
+            query = db.query(Group)
+
+            if filter:
+                if "query" in filter:
+                    query = query.filter(Group.name.ilike(f"%{filter['query']}%"))
+                if "member_id" in filter:
+                    query = query.join(
+                        GroupMember, GroupMember.group_id == Group.id
+                    ).filter(GroupMember.user_id == filter["member_id"])
+
+                if "share" in filter:
+                    share_value = filter["share"]
+                    json_share = Group.data["config"]["share"].as_boolean()
+
+                    if share_value:
+                        query = query.filter(
+                            or_(
+                                Group.data.is_(None),
+                                json_share.is_(None),
+                                json_share == True,
+                            )
+                        )
+                    else:
+                        query = query.filter(
+                            and_(Group.data.isnot(None), json_share == False)
+                        )
+            groups = query.order_by(Group.updated_at.desc()).all()
             return [
-                GroupModel.model_validate(group)
-                for group in db.query(Group).order_by(Group.updated_at.desc()).all()
+                GroupResponse.model_validate(
+                    {
+                        **GroupModel.model_validate(group).model_dump(),
+                        "member_count": self.get_group_member_count_by_id(group.id),
+                    }
+                )
+                for group in groups
             ]
 
+    def search_groups(
+        self, filter: Optional[dict] = None, skip: int = 0, limit: int = 30
+    ) -> GroupListResponse:
+        with get_db() as db:
+            query = db.query(Group)
+
+            if filter:
+                if "query" in filter:
+                    query = query.filter(Group.name.ilike(f"%{filter['query']}%"))
+                if "member_id" in filter:
+                    query = query.join(
+                        GroupMember, GroupMember.group_id == Group.id
+                    ).filter(GroupMember.user_id == filter["member_id"])
+
+                if "share" in filter:
+                    # 'share' is stored in data JSON, support both sqlite and postgres
+                    share_value = filter["share"]
+                    print("Filtering by share:", share_value)
+                    query = query.filter(
+                        Group.data.op("->>")("share") == str(share_value)
+                    )
+
+            total = query.count()
+            query = query.order_by(Group.updated_at.desc())
+            groups = query.offset(skip).limit(limit).all()
+
+            return {
+                "items": [
+                    GroupResponse.model_validate(
+                        **GroupModel.model_validate(group).model_dump(),
+                        member_count=self.get_group_member_count_by_id(group.id),
+                    )
+                    for group in groups
+                ],
+                "total": total,
+            }
+
     def get_groups_by_member_id(self, user_id: str) -> list[GroupModel]:
         with get_db() as db:
             return [
                 GroupModel.model_validate(group)
                 for group in db.query(Group)
-                .filter(
-                    func.json_array_length(Group.user_ids) > 0
-                )  # Ensure array exists
-                .filter(
-                    Group.user_ids.cast(String).like(f'%"{user_id}"%')
-                )  # String-based check
+                .join(GroupMember, GroupMember.group_id == Group.id)
+                .filter(GroupMember.user_id == user_id)
                 .order_by(Group.updated_at.desc())
                 .all()
             ]
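One thing worth flagging in `search_groups` as shown: pydantic v2's `model_validate` takes a single object, so calling it with keyword arguments (`model_validate(**dump, member_count=...)`) would raise a `TypeError`. The dict form used by `get_groups` above is presumably what is intended. A minimal sketch with a trimmed stand-in model:

```python
from typing import Optional
from pydantic import BaseModel

class GroupResponse(BaseModel):  # trimmed stand-in for the real model
    id: str
    name: str
    member_count: Optional[int] = None

group_dump = {"id": "g1", "name": "staff"}  # as from GroupModel(...).model_dump()
resp = GroupResponse.model_validate({**group_dump, "member_count": 3})
assert resp.member_count == 3
```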
@@ -149,13 +248,64 @@ class GroupTable:
         except Exception:
             return None
 
-    def get_group_user_ids_by_id(self, id: str) -> Optional[str]:
-        group = self.get_group_by_id(id)
-        if group:
-            return group.user_ids
-        else:
-            return None
+    def get_group_user_ids_by_id(self, id: str) -> Optional[list[str]]:
+        with get_db() as db:
+            members = (
+                db.query(GroupMember.user_id).filter(GroupMember.group_id == id).all()
+            )
+
+            if not members:
+                return None
+
+            return [m[0] for m in members]
+
+    def get_group_user_ids_by_ids(self, group_ids: list[str]) -> dict[str, list[str]]:
+        with get_db() as db:
+            members = (
+                db.query(GroupMember.group_id, GroupMember.user_id)
+                .filter(GroupMember.group_id.in_(group_ids))
+                .all()
+            )
+
+            group_user_ids: dict[str, list[str]] = {
+                group_id: [] for group_id in group_ids
+            }
+
+            for group_id, user_id in members:
+                group_user_ids[group_id].append(user_id)
+
+            return group_user_ids
+
+    def set_group_user_ids_by_id(self, group_id: str, user_ids: list[str]) -> None:
+        with get_db() as db:
+            # Delete existing members
+            db.query(GroupMember).filter(GroupMember.group_id == group_id).delete()
+
+            # Insert new members
+            now = int(time.time())
+            new_members = [
+                GroupMember(
+                    id=str(uuid.uuid4()),
+                    group_id=group_id,
+                    user_id=user_id,
+                    created_at=now,
+                    updated_at=now,
+                )
+                for user_id in user_ids
+            ]
+
+            db.add_all(new_members)
+            db.commit()
+
+    def get_group_member_count_by_id(self, id: str) -> int:
+        with get_db() as db:
+            count = (
+                db.query(func.count(GroupMember.user_id))
+                .filter(GroupMember.group_id == id)
+                .scalar()
+            )
+            return count if count else 0
+
     def update_group_by_id(
         self, id: str, form_data: GroupUpdateForm, overwrite: bool = False
     ) -> Optional[GroupModel]:
@@ -195,20 +345,29 @@ class GroupTable:
     def remove_user_from_all_groups(self, user_id: str) -> bool:
         with get_db() as db:
             try:
-                groups = self.get_groups_by_member_id(user_id)
-
-                for group in groups:
-                    group.user_ids.remove(user_id)
-                    db.query(Group).filter_by(id=group.id).update(
-                        {
-                            "user_ids": group.user_ids,
-                            "updated_at": int(time.time()),
-                        }
-                    )
-                db.commit()
+                # Find all groups the user belongs to
+                groups = (
+                    db.query(Group)
+                    .join(GroupMember, GroupMember.group_id == Group.id)
+                    .filter(GroupMember.user_id == user_id)
+                    .all()
+                )
+
+                # Remove the user from each group
+                for group in groups:
+                    db.query(GroupMember).filter(
+                        GroupMember.group_id == group.id, GroupMember.user_id == user_id
+                    ).delete()
+
+                    db.query(Group).filter_by(id=group.id).update(
+                        {"updated_at": int(time.time())}
+                    )
+
+                db.commit()
                 return True
 
             except Exception:
+                db.rollback()
                 return False
 
     def create_groups_by_group_names(
@@ -216,7 +375,7 @@ class GroupTable:
     ) -> list[GroupModel]:
 
         # check for existing groups
-        existing_groups = self.get_groups()
+        existing_groups = self.get_all_groups()
         existing_group_names = {group.name for group in existing_groups}
 
         new_groups = []
@@ -246,37 +405,61 @@ class GroupTable:
     def sync_groups_by_group_names(self, user_id: str, group_names: list[str]) -> bool:
         with get_db() as db:
             try:
-                groups = db.query(Group).filter(Group.name.in_(group_names)).all()
-                group_ids = [group.id for group in groups]
-
-                # Remove user from groups not in the new list
-                existing_groups = self.get_groups_by_member_id(user_id)
-
-                for group in existing_groups:
-                    if group.id not in group_ids:
-                        group.user_ids.remove(user_id)
-                        db.query(Group).filter_by(id=group.id).update(
-                            {
-                                "user_ids": group.user_ids,
-                                "updated_at": int(time.time()),
-                            }
-                        )
-
-                # Add user to new groups
-                for group in groups:
-                    if user_id not in group.user_ids:
-                        group.user_ids.append(user_id)
-                        db.query(Group).filter_by(id=group.id).update(
-                            {
-                                "user_ids": group.user_ids,
-                                "updated_at": int(time.time()),
-                            }
-                        )
+                now = int(time.time())
+
+                # 1. Groups that SHOULD contain the user
+                target_groups = (
+                    db.query(Group).filter(Group.name.in_(group_names)).all()
+                )
+                target_group_ids = {g.id for g in target_groups}
+
+                # 2. Groups the user is CURRENTLY in
+                existing_group_ids = {
+                    g.id
+                    for g in db.query(Group)
+                    .join(GroupMember, GroupMember.group_id == Group.id)
+                    .filter(GroupMember.user_id == user_id)
+                    .all()
+                }
+
+                # 3. Determine adds + removals
+                groups_to_add = target_group_ids - existing_group_ids
+                groups_to_remove = existing_group_ids - target_group_ids
+
+                # 4. Remove in one bulk delete
+                if groups_to_remove:
+                    db.query(GroupMember).filter(
+                        GroupMember.user_id == user_id,
+                        GroupMember.group_id.in_(groups_to_remove),
+                    ).delete(synchronize_session=False)
+
+                    db.query(Group).filter(Group.id.in_(groups_to_remove)).update(
+                        {"updated_at": now}, synchronize_session=False
+                    )
+
+                # 5. Bulk insert missing memberships
+                for group_id in groups_to_add:
+                    db.add(
+                        GroupMember(
+                            id=str(uuid.uuid4()),
+                            group_id=group_id,
+                            user_id=user_id,
+                            created_at=now,
+                            updated_at=now,
+                        )
+                    )
+
+                if groups_to_add:
+                    db.query(Group).filter(Group.id.in_(groups_to_add)).update(
+                        {"updated_at": now}, synchronize_session=False
+                    )
 
                 db.commit()
                 return True
 
             except Exception as e:
                 log.exception(e)
+                db.rollback()
                 return False
 
     def add_users_to_group(
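The rewritten `sync_groups_by_group_names` reduces membership reconciliation to set arithmetic before touching the database. The idea in isolation (values are illustrative):

```python
target_group_ids = {"g1", "g2", "g3"}    # groups the user should be in
existing_group_ids = {"g2", "g4"}        # groups the user is in now

groups_to_add = target_group_ids - existing_group_ids
groups_to_remove = existing_group_ids - target_group_ids

assert groups_to_add == {"g1", "g3"}
assert groups_to_remove == {"g4"}
# Removals collapse into one bulk DELETE; additions become one INSERT each,
# replacing the old per-group read-modify-write of the JSON user_ids column.
```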
@@ -288,21 +471,31 @@ class GroupTable:
                 if not group:
                     return None
 
-                group_user_ids = group.user_ids
-                if not group_user_ids or not isinstance(group_user_ids, list):
-                    group_user_ids = []
-
-                group_user_ids = list(set(group_user_ids))  # Deduplicate
-
-                for user_id in user_ids:
-                    if user_id not in group_user_ids:
-                        group_user_ids.append(user_id)
-
-                group.user_ids = group_user_ids
-                group.updated_at = int(time.time())
+                now = int(time.time())
+
+                for user_id in user_ids or []:
+                    try:
+                        db.add(
+                            GroupMember(
+                                id=str(uuid.uuid4()),
+                                group_id=id,
+                                user_id=user_id,
+                                created_at=now,
+                                updated_at=now,
+                            )
+                        )
+                        db.flush()  # Detect unique constraint violation early
+                    except Exception:
+                        db.rollback()  # Clear failed INSERT
+                        db.begin()  # Start a new transaction
+                        continue  # Duplicate → ignore
+
+                group.updated_at = now
+
                 db.commit()
                 db.refresh(group)
 
                 return GroupModel.model_validate(group)
 
             except Exception as e:
                 log.exception(e)
                 return None
@@ -316,23 +509,22 @@ class GroupTable:
                 if not group:
                     return None
 
-                group_user_ids = group.user_ids
-
-                if not group_user_ids or not isinstance(group_user_ids, list):
+                if not user_ids:
                     return GroupModel.model_validate(group)
 
-                group_user_ids = list(set(group_user_ids))  # Deduplicate
-
+                # Remove each user from group_member
                 for user_id in user_ids:
-                    if user_id in group_user_ids:
-                        group_user_ids.remove(user_id)
+                    db.query(GroupMember).filter(
+                        GroupMember.group_id == id, GroupMember.user_id == user_id
+                    ).delete()
 
-                group.user_ids = group_user_ids
+                # Update group timestamp
                 group.updated_at = int(time.time())
 
                 db.commit()
                 db.refresh(group)
                 return GroupModel.model_validate(group)
 
             except Exception as e:
                 log.exception(e)
                 return None
Knowledge model (`knowledge` table); files are now linked through a new `knowledge_file` join table:

@@ -7,13 +7,21 @@ import uuid
 from open_webui.internal.db import Base, get_db
 from open_webui.env import SRC_LOG_LEVELS
 
-from open_webui.models.files import FileMetadataResponse
+from open_webui.models.files import File, FileModel, FileMetadataResponse
 from open_webui.models.groups import Groups
 from open_webui.models.users import Users, UserResponse
 
 
 from pydantic import BaseModel, ConfigDict
-from sqlalchemy import BigInteger, Column, String, Text, JSON
+from sqlalchemy import (
+    BigInteger,
+    Column,
+    ForeignKey,
+    String,
+    Text,
+    JSON,
+    UniqueConstraint,
+)
 
 from open_webui.utils.access_control import has_access
 
@@ -34,9 +42,7 @@ class Knowledge(Base):
     name = Column(Text)
     description = Column(Text)
 
-    data = Column(JSON, nullable=True)
     meta = Column(JSON, nullable=True)
 
     access_control = Column(JSON, nullable=True)  # Controls data access levels.
     # Defines access control rules for this entry.
     # - `None`: Public access, available to all users with the "user" role.
@@ -67,7 +73,6 @@ class KnowledgeModel(BaseModel):
     name: str
     description: str
 
-    data: Optional[dict] = None
     meta: Optional[dict] = None
 
     access_control: Optional[dict] = None
@@ -76,11 +81,42 @@ class KnowledgeModel(BaseModel):
     updated_at: int  # timestamp in epoch
 
 
+class KnowledgeFile(Base):
+    __tablename__ = "knowledge_file"
+
+    id = Column(Text, unique=True, primary_key=True)
+
+    knowledge_id = Column(
+        Text, ForeignKey("knowledge.id", ondelete="CASCADE"), nullable=False
+    )
+    file_id = Column(Text, ForeignKey("file.id", ondelete="CASCADE"), nullable=False)
+    user_id = Column(Text, nullable=False)
+
+    created_at = Column(BigInteger, nullable=False)
+    updated_at = Column(BigInteger, nullable=False)
+
+    __table_args__ = (
+        UniqueConstraint(
+            "knowledge_id", "file_id", name="uq_knowledge_file_knowledge_file"
+        ),
+    )
+
+
+class KnowledgeFileModel(BaseModel):
+    id: str
+    knowledge_id: str
+    file_id: str
+    user_id: str
+
+    created_at: int  # timestamp in epoch
+    updated_at: int  # timestamp in epoch
+
+    model_config = ConfigDict(from_attributes=True)
+
+
 ####################
 # Forms
 ####################
 
 
 class KnowledgeUserModel(KnowledgeModel):
     user: Optional[UserResponse] = None
@@ -96,7 +132,6 @@ class KnowledgeUserResponse(KnowledgeUserModel):
 class KnowledgeForm(BaseModel):
     name: str
     description: str
-    data: Optional[dict] = None
     access_control: Optional[dict] = None
 
 
@@ -182,6 +217,100 @@ class KnowledgeTable:
         except Exception:
             return None
 
+    def get_knowledges_by_file_id(self, file_id: str) -> list[KnowledgeModel]:
+        try:
+            with get_db() as db:
+                knowledges = (
+                    db.query(Knowledge)
+                    .join(KnowledgeFile, Knowledge.id == KnowledgeFile.knowledge_id)
+                    .filter(KnowledgeFile.file_id == file_id)
+                    .all()
+                )
+                return [
+                    KnowledgeModel.model_validate(knowledge) for knowledge in knowledges
+                ]
+        except Exception:
+            return []
+
+    def get_files_by_id(self, knowledge_id: str) -> list[FileModel]:
+        try:
+            with get_db() as db:
+                files = (
+                    db.query(File)
+                    .join(KnowledgeFile, File.id == KnowledgeFile.file_id)
+                    .filter(KnowledgeFile.knowledge_id == knowledge_id)
+                    .all()
+                )
+                return [FileModel.model_validate(file) for file in files]
+        except Exception:
+            return []
+
+    def get_file_metadatas_by_id(self, knowledge_id: str) -> list[FileMetadataResponse]:
+        try:
+            with get_db() as db:
+                files = self.get_files_by_id(knowledge_id)
+                return [FileMetadataResponse(**file.model_dump()) for file in files]
+        except Exception:
+            return []
+
+    def add_file_to_knowledge_by_id(
+        self, knowledge_id: str, file_id: str, user_id: str
+    ) -> Optional[KnowledgeFileModel]:
+        with get_db() as db:
+            knowledge_file = KnowledgeFileModel(
+                **{
+                    "id": str(uuid.uuid4()),
+                    "knowledge_id": knowledge_id,
+                    "file_id": file_id,
+                    "user_id": user_id,
+                    "created_at": int(time.time()),
+                    "updated_at": int(time.time()),
+                }
+            )
+
+            try:
+                result = KnowledgeFile(**knowledge_file.model_dump())
+                db.add(result)
+                db.commit()
+                db.refresh(result)
+                if result:
+                    return KnowledgeFileModel.model_validate(result)
+                else:
+                    return None
+            except Exception:
+                return None
+
+    def remove_file_from_knowledge_by_id(self, knowledge_id: str, file_id: str) -> bool:
+        try:
+            with get_db() as db:
+                db.query(KnowledgeFile).filter_by(
+                    knowledge_id=knowledge_id, file_id=file_id
+                ).delete()
+                db.commit()
+                return True
+        except Exception:
+            return False
+
+    def reset_knowledge_by_id(self, id: str) -> Optional[KnowledgeModel]:
+        try:
+            with get_db() as db:
+                # Delete all knowledge_file entries for this knowledge_id
+                db.query(KnowledgeFile).filter_by(knowledge_id=id).delete()
+                db.commit()
+
+                # Update the knowledge entry's updated_at timestamp
+                db.query(Knowledge).filter_by(id=id).update(
+                    {
+                        "updated_at": int(time.time()),
+                    }
+                )
+                db.commit()
+
+                return self.get_knowledge_by_id(id=id)
+        except Exception as e:
+            log.exception(e)
+            return None
+
     def update_knowledge_by_id(
         self, id: str, form_data: KnowledgeForm, overwrite: bool = False
     ) -> Optional[KnowledgeModel]:
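`knowledge_file` pairs each knowledge base with its files and enforces one row per pair via the composite `UniqueConstraint`. A self-contained sketch of that constraint's effect, using a throwaway in-memory SQLite table (column names mirror the diff; everything else is illustrative):

```python
import sqlite3

con = sqlite3.connect(":memory:")
con.execute(
    """
    CREATE TABLE knowledge_file (
        id TEXT PRIMARY KEY,
        knowledge_id TEXT NOT NULL,
        file_id TEXT NOT NULL,
        UNIQUE (knowledge_id, file_id)
    )
    """
)
con.execute("INSERT INTO knowledge_file VALUES ('1', 'k1', 'f1')")
try:
    con.execute("INSERT INTO knowledge_file VALUES ('2', 'k1', 'f1')")
except sqlite3.IntegrityError:
    print("duplicate (knowledge_id, file_id) pair rejected")  # expected path
```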
Memory model (`memory` table):

@@ -14,7 +14,7 @@ from sqlalchemy import BigInteger, Column, String, Text
 class Memory(Base):
     __tablename__ = "memory"
 
-    id = Column(String, primary_key=True)
+    id = Column(String, primary_key=True, unique=True)
     user_id = Column(String)
     content = Column(Text)
     updated_at = Column(BigInteger)
Message model (`message` and `message_reaction` tables):

@@ -5,7 +5,8 @@ from typing import Optional
 
 from open_webui.internal.db import Base, get_db
 from open_webui.models.tags import TagModel, Tag, Tags
-from open_webui.models.users import Users, UserNameResponse
+from open_webui.models.users import Users, User, UserNameResponse
+from open_webui.models.channels import Channels, ChannelMember
 
 
 from pydantic import BaseModel, ConfigDict
@@ -20,7 +21,7 @@ from sqlalchemy.sql import exists
 
 class MessageReaction(Base):
     __tablename__ = "message_reaction"
-    id = Column(Text, primary_key=True)
+    id = Column(Text, primary_key=True, unique=True)
     user_id = Column(Text)
     message_id = Column(Text)
     name = Column(Text)
@@ -39,7 +40,7 @@ class MessageReactionModel(BaseModel):
 
 class Message(Base):
     __tablename__ = "message"
-    id = Column(Text, primary_key=True)
+    id = Column(Text, primary_key=True, unique=True)
 
     user_id = Column(Text)
     channel_id = Column(Text, nullable=True)
@@ -47,6 +48,11 @@ class Message(Base):
     reply_to_id = Column(Text, nullable=True)
     parent_id = Column(Text, nullable=True)
 
+    # Pins
+    is_pinned = Column(Boolean, nullable=False, default=False)
+    pinned_at = Column(BigInteger, nullable=True)
+    pinned_by = Column(Text, nullable=True)
+
     content = Column(Text)
     data = Column(JSON, nullable=True)
     meta = Column(JSON, nullable=True)
@@ -65,12 +71,17 @@ class MessageModel(BaseModel):
     reply_to_id: Optional[str] = None
     parent_id: Optional[str] = None
 
+    # Pins
+    is_pinned: bool = False
+    pinned_by: Optional[str] = None
+    pinned_at: Optional[int] = None  # timestamp in epoch (time_ns)
+
     content: str
     data: Optional[dict] = None
     meta: Optional[dict] = None
 
-    created_at: int  # timestamp in epoch
-    updated_at: int  # timestamp in epoch
+    created_at: int  # timestamp in epoch (time_ns)
+    updated_at: int  # timestamp in epoch (time_ns)
 
 
 ####################
@@ -79,6 +90,7 @@ class MessageModel(BaseModel):
 
 
 class MessageForm(BaseModel):
+    temp_id: Optional[str] = None
     content: str
     reply_to_id: Optional[str] = None
     parent_id: Optional[str] = None
@@ -88,7 +100,7 @@ class MessageForm(BaseModel):
 
 class Reactions(BaseModel):
     name: str
-    user_ids: list[str]
+    users: list[dict]
     count: int
 
 
@@ -100,6 +112,10 @@ class MessageReplyToResponse(MessageUserResponse):
     reply_to_message: Optional[MessageUserResponse] = None
 
 
+class MessageWithReactionsResponse(MessageUserResponse):
+    reactions: list[Reactions]
+
+
 class MessageResponse(MessageReplyToResponse):
     latest_reply_at: Optional[int]
     reply_count: int
@@ -111,9 +127,11 @@ class MessageTable:
         self, form_data: MessageForm, channel_id: str, user_id: str
     ) -> Optional[MessageModel]:
         with get_db() as db:
-            id = str(uuid.uuid4())
+            channel_member = Channels.join_channel(channel_id, user_id)
+
+            id = str(uuid.uuid4())
             ts = int(time.time_ns())
 
             message = MessageModel(
                 **{
                     "id": id,
@@ -121,6 +139,9 @@ class MessageTable:
                     "channel_id": channel_id,
                     "reply_to_id": form_data.reply_to_id,
                     "parent_id": form_data.parent_id,
+                    "is_pinned": False,
+                    "pinned_at": None,
+                    "pinned_by": None,
                     "content": form_data.content,
                     "data": form_data.data,
                     "meta": form_data.meta,
@@ -128,8 +149,8 @@ class MessageTable:
                     "updated_at": ts,
                 }
             )
 
             result = Message(**message.model_dump())
 
             db.add(result)
             db.commit()
             db.refresh(result)
@@ -280,6 +301,30 @@ class MessageTable:
             )
             return messages
 
+    def get_last_message_by_channel_id(self, channel_id: str) -> Optional[MessageModel]:
+        with get_db() as db:
+            message = (
+                db.query(Message)
+                .filter_by(channel_id=channel_id)
+                .order_by(Message.created_at.desc())
+                .first()
+            )
+            return MessageModel.model_validate(message) if message else None
+
+    def get_pinned_messages_by_channel_id(
+        self, channel_id: str, skip: int = 0, limit: int = 50
+    ) -> list[MessageModel]:
+        with get_db() as db:
+            all_messages = (
+                db.query(Message)
+                .filter_by(channel_id=channel_id, is_pinned=True)
+                .order_by(Message.pinned_at.desc())
+                .offset(skip)
+                .limit(limit)
+                .all()
+            )
+            return [MessageModel.model_validate(message) for message in all_messages]
+
     def update_message_by_id(
         self, id: str, form_data: MessageForm
     ) -> Optional[MessageModel]:
@@ -299,10 +344,44 @@ class MessageTable:
             db.refresh(message)
             return MessageModel.model_validate(message) if message else None
 
+    def update_is_pinned_by_id(
+        self, id: str, is_pinned: bool, pinned_by: Optional[str] = None
+    ) -> Optional[MessageModel]:
+        with get_db() as db:
+            message = db.get(Message, id)
+            message.is_pinned = is_pinned
+            message.pinned_at = int(time.time_ns()) if is_pinned else None
+            message.pinned_by = pinned_by if is_pinned else None
+            db.commit()
+            db.refresh(message)
+            return MessageModel.model_validate(message) if message else None
+
+    def get_unread_message_count(
+        self, channel_id: str, user_id: str, last_read_at: Optional[int] = None
+    ) -> int:
+        with get_db() as db:
+            query = db.query(Message).filter(
+                Message.channel_id == channel_id,
+                Message.parent_id == None,  # only count top-level messages
+                Message.created_at > (last_read_at if last_read_at else 0),
+            )
+            if user_id:
+                query = query.filter(Message.user_id != user_id)
+            return query.count()
+
     def add_reaction_to_message(
         self, id: str, user_id: str, name: str
     ) -> Optional[MessageReactionModel]:
         with get_db() as db:
+            # check for existing reaction
+            existing_reaction = (
+                db.query(MessageReaction)
+                .filter_by(message_id=id, user_id=user_id, name=name)
+                .first()
+            )
+            if existing_reaction:
+                return MessageReactionModel.model_validate(existing_reaction)
+
             reaction_id = str(uuid.uuid4())
             reaction = MessageReactionModel(
                 id=reaction_id,
@@ -319,17 +398,30 @@ class MessageTable:
 
     def get_reactions_by_message_id(self, id: str) -> list[Reactions]:
         with get_db() as db:
-            all_reactions = db.query(MessageReaction).filter_by(message_id=id).all()
+            # JOIN User so all user info is fetched in one query
+            results = (
+                db.query(MessageReaction, User)
+                .join(User, MessageReaction.user_id == User.id)
+                .filter(MessageReaction.message_id == id)
+                .all()
+            )
 
             reactions = {}
-            for reaction in all_reactions:
+
+            for reaction, user in results:
                 if reaction.name not in reactions:
                     reactions[reaction.name] = {
                         "name": reaction.name,
-                        "user_ids": [],
+                        "users": [],
                         "count": 0,
                     }
-                reactions[reaction.name]["user_ids"].append(reaction.user_id)
+
+                reactions[reaction.name]["users"].append(
+                    {
+                        "id": user.id,
+                        "name": user.name,
+                    }
+                )
                 reactions[reaction.name]["count"] += 1
 
             return [Reactions(**reaction) for reaction in reactions.values()]
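`get_reactions_by_message_id` now joins `User` once and groups rows in Python, so each reaction carries user names instead of bare ids. The grouping step in isolation, with tuples standing in for ORM result rows (all values illustrative):

```python
results = [  # (reaction_name, user) pairs
    ("thumbsup", {"id": "u1", "name": "Ada"}),
    ("thumbsup", {"id": "u2", "name": "Lin"}),
    ("tada", {"id": "u1", "name": "Ada"}),
]

reactions: dict[str, dict] = {}
for name, user in results:
    entry = reactions.setdefault(name, {"name": name, "users": [], "count": 0})
    entry["users"].append({"id": user["id"], "name": user["name"]})
    entry["count"] += 1

assert reactions["thumbsup"]["count"] == 2
assert reactions["tada"]["users"] == [{"id": "u1", "name": "Ada"}]
```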
Model registry (`model` table):

@@ -6,13 +6,15 @@ from open_webui.internal.db import Base, JSONField, get_db
 from open_webui.env import SRC_LOG_LEVELS
 
 from open_webui.models.groups import Groups
-from open_webui.models.users import Users, UserResponse
+from open_webui.models.users import User, UserModel, Users, UserResponse
 
 
 from pydantic import BaseModel, ConfigDict
 
-from sqlalchemy import or_, and_, func
+from sqlalchemy import String, cast, or_, and_, func
 from sqlalchemy.dialects import postgresql, sqlite
+from sqlalchemy.dialects.postgresql import JSONB
 from sqlalchemy import BigInteger, Column, Text, JSON, Boolean
 
 
@@ -53,7 +55,7 @@ class ModelMeta(BaseModel):
 class Model(Base):
     __tablename__ = "model"
 
-    id = Column(Text, primary_key=True)
+    id = Column(Text, primary_key=True, unique=True)
     """
     The model's id as used in the API. If set to an existing model, it will override the model.
     """
@@ -133,6 +135,11 @@ class ModelResponse(ModelModel):
     pass
 
 
+class ModelListResponse(BaseModel):
+    items: list[ModelUserResponse]
+    total: int
+
+
 class ModelForm(BaseModel):
     id: str
     base_model_id: Optional[str] = None
@@ -215,6 +222,135 @@ class ModelsTable:
             or has_access(user_id, permission, model.access_control, user_group_ids)
         ]
 
+    def _has_permission(self, db, query, filter: dict, permission: str = "read"):
+        group_ids = filter.get("group_ids", [])
+        user_id = filter.get("user_id")
+
+        dialect_name = db.bind.dialect.name
+
+        # Public access
+        conditions = []
+        if group_ids or user_id:
+            conditions.extend(
+                [
+                    Model.access_control.is_(None),
+                    cast(Model.access_control, String) == "null",
+                ]
+            )
+
+        # User-level permission
+        if user_id:
+            conditions.append(Model.user_id == user_id)
+
+        # Group-level permission
+        if group_ids:
+            group_conditions = []
+            for gid in group_ids:
+                if dialect_name == "sqlite":
+                    group_conditions.append(
+                        Model.access_control[permission]["group_ids"].contains([gid])
+                    )
+                elif dialect_name == "postgresql":
+                    group_conditions.append(
+                        cast(
+                            Model.access_control[permission]["group_ids"],
+                            JSONB,
+                        ).contains([gid])
+                    )
+            conditions.append(or_(*group_conditions))
+
+        if conditions:
+            query = query.filter(or_(*conditions))
+
+        return query
+
+    def search_models(
+        self, user_id: str, filter: dict = {}, skip: int = 0, limit: int = 30
+    ) -> ModelListResponse:
+        with get_db() as db:
+            # Join GroupMember so we can order by group_id when requested
+            query = db.query(Model, User).outerjoin(User, User.id == Model.user_id)
+            query = query.filter(Model.base_model_id != None)
+
+            if filter:
+                query_key = filter.get("query")
+                if query_key:
+                    query = query.filter(
+                        or_(
+                            Model.name.ilike(f"%{query_key}%"),
+                            Model.base_model_id.ilike(f"%{query_key}%"),
+                        )
+                    )
+
+                view_option = filter.get("view_option")
+                if view_option == "created":
+                    query = query.filter(Model.user_id == user_id)
+                elif view_option == "shared":
+                    query = query.filter(Model.user_id != user_id)
+
+                # Apply access control filtering
+                query = self._has_permission(
+                    db,
+                    query,
+                    filter,
+                    permission="write",
+                )
+
+                tag = filter.get("tag")
+                if tag:
+                    # TODO: This is a simple implementation and should be improved for performance
+                    like_pattern = f'%"{tag.lower()}"%'  # `"tag"` inside JSON array
+                    meta_text = func.lower(cast(Model.meta, String))
+
+                    query = query.filter(meta_text.like(like_pattern))
+
+                order_by = filter.get("order_by")
+                direction = filter.get("direction")
+
+                if order_by == "name":
+                    if direction == "asc":
+                        query = query.order_by(Model.name.asc())
+                    else:
+                        query = query.order_by(Model.name.desc())
+                elif order_by == "created_at":
+                    if direction == "asc":
+                        query = query.order_by(Model.created_at.asc())
+                    else:
+                        query = query.order_by(Model.created_at.desc())
+                elif order_by == "updated_at":
+                    if direction == "asc":
+                        query = query.order_by(Model.updated_at.asc())
+                    else:
+                        query = query.order_by(Model.updated_at.desc())
+
+            else:
+                query = query.order_by(Model.created_at.desc())
+
+            # Count BEFORE pagination
+            total = query.count()
+
+            if skip:
+                query = query.offset(skip)
+            if limit:
+                query = query.limit(limit)
+
+            items = query.all()
+
+            models = []
+            for model, user in items:
+                models.append(
+                    ModelUserResponse(
+                        **ModelModel.model_validate(model).model_dump(),
+                        user=(
+                            UserResponse(**UserModel.model_validate(user).model_dump())
+                            if user
+                            else None
+                        ),
+                    )
+                )
+
+            return ModelListResponse(items=models, total=total)
+
     def get_model_by_id(self, id: str) -> Optional[ModelModel]:
         try:
             with get_db() as db:
@@ -244,11 +380,9 @@ class ModelsTable:
         try:
             with get_db() as db:
                 # update only the fields that are present in the model
-                result = (
-                    db.query(Model)
-                    .filter_by(id=id)
-                    .update(model.model_dump(exclude={"id"}))
-                )
+                data = model.model_dump(exclude={"id"})
+                result = db.query(Model).filter_by(id=id).update(data)
                 db.commit()
 
                 model = db.get(Model, id)
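`search_models` counts the filtered set before applying `offset`/`limit`, so `total` always reflects the whole result set rather than the current page. The contract, sketched over a plain list (data is illustrative):

```python
filtered = [f"model-{i}" for i in range(95)]

def paginate(rows: list[str], skip: int = 0, limit: int = 30) -> dict:
    total = len(rows)  # count BEFORE pagination
    return {"items": rows[skip : skip + limit], "total": total}

page = paginate(filtered, skip=60, limit=30)
assert page["total"] == 95
assert page["items"][0] == "model-60" and len(page["items"]) == 30
```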
Note model (`note` table):

@@ -23,7 +23,7 @@ from sqlalchemy.sql import exists
 class Note(Base):
     __tablename__ = "note"
 
-    id = Column(Text, primary_key=True)
+    id = Column(Text, primary_key=True, unique=True)
     user_id = Column(Text)
 
     title = Column(Text)
OAuth session model (`oauth_session` table):

@@ -25,7 +25,7 @@ log.setLevel(SRC_LOG_LEVELS["MODELS"])
 class OAuthSession(Base):
     __tablename__ = "oauth_session"
 
-    id = Column(Text, primary_key=True)
+    id = Column(Text, primary_key=True, unique=True)
     user_id = Column(Text, nullable=False)
     provider = Column(Text, nullable=False)
     token = Column(
@@ -262,5 +262,16 @@ class OAuthSessionTable:
             log.error(f"Error deleting OAuth sessions by user ID: {e}")
             return False
 
+    def delete_sessions_by_provider(self, provider: str) -> bool:
+        """Delete all OAuth sessions for a provider"""
+        try:
+            with get_db() as db:
+                db.query(OAuthSession).filter_by(provider=provider).delete()
+                db.commit()
+                return True
+        except Exception as e:
+            log.error(f"Error deleting OAuth sessions by provider {provider}: {e}")
+            return False
+
 
 OAuthSessions = OAuthSessionTable()
Tool model (`tool` table):

@@ -24,7 +24,7 @@ log.setLevel(SRC_LOG_LEVELS["MODELS"])
 class Tool(Base):
     __tablename__ = "tool"
 
-    id = Column(String, primary_key=True)
+    id = Column(String, primary_key=True, unique=True)
     user_id = Column(String)
     name = Column(Text)
     content = Column(Text)
@ -6,13 +6,28 @@ from open_webui.internal.db import Base, JSONField, get_db
|
||||||
|
|
||||||
from open_webui.env import DATABASE_USER_ACTIVE_STATUS_UPDATE_INTERVAL
|
from open_webui.env import DATABASE_USER_ACTIVE_STATUS_UPDATE_INTERVAL
|
||||||
from open_webui.models.chats import Chats
|
from open_webui.models.chats import Chats
|
||||||
from open_webui.models.groups import Groups
|
from open_webui.models.groups import Groups, GroupMember
|
||||||
|
from open_webui.models.channels import ChannelMember
|
||||||
|
|
||||||
|
|
||||||
from open_webui.utils.misc import throttle
|
from open_webui.utils.misc import throttle
|
||||||
|
|
||||||
|
|
||||||
from pydantic import BaseModel, ConfigDict
|
from pydantic import BaseModel, ConfigDict
|
||||||
from sqlalchemy import BigInteger, Column, String, Text, Date
|
from sqlalchemy import (
|
||||||
from sqlalchemy import or_
|
BigInteger,
|
||||||
|
JSON,
|
||||||
|
Column,
|
||||||
|
String,
|
||||||
|
Boolean,
|
||||||
|
Text,
|
||||||
|
Date,
|
||||||
|
exists,
|
||||||
|
select,
|
||||||
|
cast,
|
||||||
|
)
|
||||||
|
from sqlalchemy import or_, case
|
||||||
|
from sqlalchemy.dialects.postgresql import JSONB
|
||||||
|
|
||||||
import datetime
|
import datetime
|
||||||
|
|
||||||
|
|
@ -21,59 +36,71 @@ import datetime
|
||||||
####################
|
####################
|
||||||
|
|
||||||
|
|
||||||
class User(Base):
|
|
||||||
__tablename__ = "user"
|
|
||||||
|
|
||||||
id = Column(String, primary_key=True)
|
|
||||||
name = Column(String)
|
|
||||||
|
|
||||||
email = Column(String)
|
|
||||||
username = Column(String(50), nullable=True)
|
|
||||||
|
|
||||||
role = Column(String)
|
|
||||||
profile_image_url = Column(Text)
|
|
||||||
|
|
||||||
bio = Column(Text, nullable=True)
|
|
||||||
gender = Column(Text, nullable=True)
|
|
||||||
date_of_birth = Column(Date, nullable=True)
|
|
||||||
|
|
||||||
info = Column(JSONField, nullable=True)
|
|
||||||
settings = Column(JSONField, nullable=True)
|
|
||||||
|
|
||||||
api_key = Column(String, nullable=True, unique=True)
|
|
||||||
oauth_sub = Column(Text, unique=True)
|
|
||||||
|
|
||||||
last_active_at = Column(BigInteger)
|
|
||||||
|
|
||||||
updated_at = Column(BigInteger)
|
|
||||||
created_at = Column(BigInteger)
|
|
||||||
|
|
||||||
|
|
||||||
class UserSettings(BaseModel):
|
class UserSettings(BaseModel):
|
||||||
ui: Optional[dict] = {}
|
ui: Optional[dict] = {}
|
||||||
model_config = ConfigDict(extra="allow")
|
model_config = ConfigDict(extra="allow")
|
||||||
pass
|
pass
|
||||||
|
|
||||||
|
|
||||||
|
class User(Base):
|
||||||
|
__tablename__ = "user"
|
||||||
|
|
||||||
|
id = Column(String, primary_key=True, unique=True)
|
||||||
|
email = Column(String)
|
||||||
|
username = Column(String(50), nullable=True)
|
||||||
|
role = Column(String)
|
||||||
|
|
||||||
|
name = Column(String)
|
||||||
|
|
||||||
|
profile_image_url = Column(Text)
|
||||||
|
profile_banner_image_url = Column(Text, nullable=True)
|
||||||
|
|
||||||
|
bio = Column(Text, nullable=True)
|
||||||
|
gender = Column(Text, nullable=True)
|
||||||
|
date_of_birth = Column(Date, nullable=True)
|
||||||
|
timezone = Column(String, nullable=True)
|
||||||
|
|
||||||
|
presence_state = Column(String, nullable=True)
|
||||||
|
status_emoji = Column(String, nullable=True)
|
||||||
|
status_message = Column(Text, nullable=True)
|
||||||
|
status_expires_at = Column(BigInteger, nullable=True)
|
||||||
|
|
||||||
|
info = Column(JSON, nullable=True)
|
||||||
|
settings = Column(JSON, nullable=True)
|
||||||
|
|
||||||
|
oauth = Column(JSON, nullable=True)
|
||||||
|
|
||||||
|
last_active_at = Column(BigInteger)
|
||||||
|
updated_at = Column(BigInteger)
|
||||||
|
created_at = Column(BigInteger)
|
||||||
|
|
||||||
|
|
||||||
class UserModel(BaseModel):
|
class UserModel(BaseModel):
|
||||||
id: str
|
id: str
|
||||||
name: str
|
|
||||||
|
|
||||||
email: str
|
email: str
|
||||||
username: Optional[str] = None
|
username: Optional[str] = None
|
||||||
|
|
||||||
role: str = "pending"
|
role: str = "pending"
|
||||||
|
|
||||||
|
name: str
|
||||||
|
|
||||||
profile_image_url: str
|
profile_image_url: str
|
||||||
|
profile_banner_image_url: Optional[str] = None
|
||||||
|
|
||||||
bio: Optional[str] = None
|
bio: Optional[str] = None
|
||||||
gender: Optional[str] = None
|
gender: Optional[str] = None
|
||||||
date_of_birth: Optional[datetime.date] = None
|
date_of_birth: Optional[datetime.date] = None
|
||||||
|
timezone: Optional[str] = None
|
||||||
|
|
||||||
|
presence_state: Optional[str] = None
|
||||||
|
status_emoji: Optional[str] = None
|
||||||
|
status_message: Optional[str] = None
|
||||||
|
status_expires_at: Optional[int] = None
|
||||||
|
|
||||||
info: Optional[dict] = None
|
info: Optional[dict] = None
|
||||||
settings: Optional[UserSettings] = None
|
settings: Optional[UserSettings] = None
|
||||||
|
|
||||||
api_key: Optional[str] = None
|
oauth: Optional[dict] = None
|
||||||
oauth_sub: Optional[str] = None
|
|
||||||
|
|
||||||
last_active_at: int # timestamp in epoch
|
last_active_at: int # timestamp in epoch
|
||||||
updated_at: int # timestamp in epoch
|
updated_at: int # timestamp in epoch
|
||||||
|
|
@@ -82,6 +109,38 @@ class UserModel(BaseModel):

    model_config = ConfigDict(from_attributes=True)


+class UserStatusModel(UserModel):
+    is_active: bool = False
+
+    model_config = ConfigDict(from_attributes=True)
+
+
+class ApiKey(Base):
+    __tablename__ = "api_key"
+
+    id = Column(Text, primary_key=True, unique=True)
+    user_id = Column(Text, nullable=False)
+    key = Column(Text, unique=True, nullable=False)
+    data = Column(JSON, nullable=True)
+    expires_at = Column(BigInteger, nullable=True)
+    last_used_at = Column(BigInteger, nullable=True)
+    created_at = Column(BigInteger, nullable=False)
+    updated_at = Column(BigInteger, nullable=False)
+
+
+class ApiKeyModel(BaseModel):
+    id: str
+    user_id: str
+    key: str
+    data: Optional[dict] = None
+    expires_at: Optional[int] = None
+    last_used_at: Optional[int] = None
+    created_at: int  # timestamp in epoch
+    updated_at: int  # timestamp in epoch
+
+    model_config = ConfigDict(from_attributes=True)


####################
# Forms
####################
@@ -95,12 +154,31 @@ class UpdateProfileForm(BaseModel):

    date_of_birth: Optional[datetime.date] = None


+class UserGroupIdsModel(UserModel):
+    group_ids: list[str] = []
+
+
+class UserModelResponse(UserModel):
+    model_config = ConfigDict(extra="allow")
+
+
class UserListResponse(BaseModel):
-    users: list[UserModel]
+    users: list[UserModelResponse]
    total: int


-class UserInfoResponse(BaseModel):
+class UserGroupIdsListResponse(BaseModel):
+    users: list[UserGroupIdsModel]
+    total: int
+
+
+class UserStatus(BaseModel):
+    status_emoji: Optional[str] = None
+    status_message: Optional[str] = None
+    status_expires_at: Optional[int] = None
+
+
+class UserInfoResponse(UserStatus):
    id: str
    name: str
    email: str
@@ -112,6 +190,12 @@ class UserIdNameResponse(BaseModel):

    name: str


+class UserIdNameStatusResponse(UserStatus):
+    id: str
+    name: str
+    is_active: Optional[bool] = None
+
+
class UserInfoListResponse(BaseModel):
    users: list[UserInfoResponse]
    total: int
@@ -122,18 +206,18 @@ class UserIdNameListResponse(BaseModel):

    total: int


-class UserResponse(BaseModel):
-    id: str
-    name: str
-    email: str
-    role: str
-    profile_image_url: str
-
-
class UserNameResponse(BaseModel):
    id: str
    name: str
    role: str


+class UserResponse(UserNameResponse):
+    email: str
+
+
+class UserProfileImageResponse(UserNameResponse):
+    email: str
    profile_image_url: str
@@ -158,20 +242,20 @@ class UsersTable:

        email: str,
        profile_image_url: str = "/user.png",
        role: str = "pending",
-        oauth_sub: Optional[str] = None,
+        oauth: Optional[dict] = None,
    ) -> Optional[UserModel]:
        with get_db() as db:
            user = UserModel(
                **{
                    "id": id,
-                    "name": name,
                    "email": email,
+                    "name": name,
                    "role": role,
                    "profile_image_url": profile_image_url,
                    "last_active_at": int(time.time()),
                    "created_at": int(time.time()),
                    "updated_at": int(time.time()),
-                    "oauth_sub": oauth_sub,
+                    "oauth": oauth,
                }
            )
            result = User(**user.model_dump())
@@ -194,8 +278,13 @@ class UsersTable:

    def get_user_by_api_key(self, api_key: str) -> Optional[UserModel]:
        try:
            with get_db() as db:
-                user = db.query(User).filter_by(api_key=api_key).first()
-                return UserModel.model_validate(user)
+                user = (
+                    db.query(User)
+                    .join(ApiKey, User.id == ApiKey.user_id)
+                    .filter(ApiKey.key == api_key)
+                    .first()
+                )
+                return UserModel.model_validate(user) if user else None
        except Exception:
            return None
@@ -207,12 +296,23 @@ class UsersTable:

        except Exception:
            return None

-    def get_user_by_oauth_sub(self, sub: str) -> Optional[UserModel]:
+    def get_user_by_oauth_sub(self, provider: str, sub: str) -> Optional[UserModel]:
        try:
-            with get_db() as db:
-                user = db.query(User).filter_by(oauth_sub=sub).first()
-                return UserModel.model_validate(user)
-        except Exception:
+            with get_db() as db:  # type: Session
+                dialect_name = db.bind.dialect.name
+
+                query = db.query(User)
+                if dialect_name == "sqlite":
+                    query = query.filter(User.oauth.contains({provider: {"sub": sub}}))
+                elif dialect_name == "postgresql":
+                    query = query.filter(
+                        User.oauth[provider].cast(JSONB)["sub"].astext == sub
+                    )
+
+                user = query.first()
+                return UserModel.model_validate(user) if user else None
+        except Exception as e:
+            # You may want to log the exception here
            return None

    def get_users(
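The dialect branching above exists because JSON lookups have no portable SQL operator: the SQLite path goes through SQLAlchemy's `.contains()` against the serialized JSON, while PostgreSQL extracts the nested key via `JSONB`. Below is a minimal standalone sketch of the same kind of lookup on SQLite, using `json_extract` rather than `.contains()` for clarity; the table, ids, and values are illustrative, not from the repository:

from sqlalchemy import JSON, Column, String, create_engine, func
from sqlalchemy.orm import Session, declarative_base

Base = declarative_base()


class DemoUser(Base):
    __tablename__ = "demo_user"
    id = Column(String, primary_key=True)
    oauth = Column(JSON, nullable=True)


engine = create_engine("sqlite://")
Base.metadata.create_all(engine)

with Session(engine) as db:
    db.add(DemoUser(id="u1", oauth={"google": {"sub": "123"}}))
    db.commit()

    # SQLite's json_extract pulls the nested "sub" value out of the JSON column
    user = (
        db.query(DemoUser)
        .filter(func.json_extract(DemoUser.oauth, "$.google.sub") == "123")
        .first()
    )
    print(user.id if user else None)  # -> u1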
@@ -222,6 +322,7 @@ class UsersTable:

        limit: Optional[int] = None,
    ) -> dict:
        with get_db() as db:
+            # Join GroupMember so we can order by group_id when requested
            query = db.query(User)

            if filter:
@@ -234,14 +335,76 @@ class UsersTable:

                        )
                    )

+                channel_id = filter.get("channel_id")
+                if channel_id:
+                    query = query.filter(
+                        exists(
+                            select(ChannelMember.id).where(
+                                ChannelMember.user_id == User.id,
+                                ChannelMember.channel_id == channel_id,
+                            )
+                        )
+                    )
+
+                user_ids = filter.get("user_ids")
+                group_ids = filter.get("group_ids")
+
+                if isinstance(user_ids, list) and isinstance(group_ids, list):
+                    # If both are empty lists, return no users
+                    if not user_ids and not group_ids:
+                        return {"users": [], "total": 0}
+
+                if user_ids:
+                    query = query.filter(User.id.in_(user_ids))
+
+                if group_ids:
+                    query = query.filter(
+                        exists(
+                            select(GroupMember.id).where(
+                                GroupMember.user_id == User.id,
+                                GroupMember.group_id.in_(group_ids),
+                            )
+                        )
+                    )
+
+                roles = filter.get("roles")
+                if roles:
+                    include_roles = [role for role in roles if not role.startswith("!")]
+                    exclude_roles = [role[1:] for role in roles if role.startswith("!")]
+
+                    if include_roles:
+                        query = query.filter(User.role.in_(include_roles))
+                    if exclude_roles:
+                        query = query.filter(~User.role.in_(exclude_roles))
+
                order_by = filter.get("order_by")
                direction = filter.get("direction")

-                if order_by == "name":
+                if order_by and order_by.startswith("group_id:"):
+                    group_id = order_by.split(":", 1)[1]
+
+                    # Subquery that checks if the user belongs to the group
+                    membership_exists = exists(
+                        select(GroupMember.id).where(
+                            GroupMember.user_id == User.id,
+                            GroupMember.group_id == group_id,
+                        )
+                    )
+
+                    # CASE: user in group → 1, user not in group → 0
+                    group_sort = case((membership_exists, 1), else_=0)
+
+                    if direction == "asc":
+                        query = query.order_by(group_sort.asc(), User.name.asc())
+                    else:
+                        query = query.order_by(group_sort.desc(), User.name.asc())
+
+                elif order_by == "name":
                    if direction == "asc":
                        query = query.order_by(User.name.asc())
                    else:
                        query = query.order_by(User.name.desc())

                elif order_by == "email":
                    if direction == "asc":
                        query = query.order_by(User.email.asc())
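The `group_sort` CASE expression implements a two-level sort: under `desc()`, members of the requested group (value 1) come before non-members (value 0), and each bucket is then ordered alphabetically by name. The same ordering, sketched in plain Python for illustration (the names are made up):

# (name, is_member) pairs; values are illustrative
users = [("zoe", True), ("amy", False), ("bob", True)]

# Members first, then alphabetical within each bucket - mirroring
# order_by(group_sort.desc(), User.name.asc()) above.
ordered = sorted(users, key=lambda u: (0 if u[1] else 1, u[0]))
print(ordered)  # [('bob', True), ('zoe', True), ('amy', False)]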
@@ -274,18 +437,32 @@ class UsersTable:

                    else:
                        query = query.order_by(User.created_at.desc())

-            if skip:
+            # Count BEFORE pagination
+            total = query.count()
+
+            # correct pagination logic
+            if skip is not None:
                query = query.offset(skip)
-            if limit:
+            if limit is not None:
                query = query.limit(limit)

            users = query.all()
            return {
                "users": [UserModel.model_validate(user) for user in users],
-                "total": db.query(User).count(),
+                "total": total,
            }

-    def get_users_by_user_ids(self, user_ids: list[str]) -> list[UserModel]:
+    def get_users_by_group_id(self, group_id: str) -> list[UserModel]:
+        with get_db() as db:
+            users = (
+                db.query(User)
+                .join(GroupMember, User.id == GroupMember.user_id)
+                .filter(GroupMember.group_id == group_id)
+                .all()
+            )
+            return [UserModel.model_validate(user) for user in users]
+
+    def get_users_by_user_ids(self, user_ids: list[str]) -> list[UserStatusModel]:
        with get_db() as db:
            users = db.query(User).filter(User.id.in_(user_ids)).all()
            return [UserModel.model_validate(user) for user in users]
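Two fixes land together in that hunk: the total is now counted on the filtered query before `offset`/`limit` are applied (the old `db.query(User).count()` ignored the filters entirely), and `skip=0` no longer disables the offset, since the check is `is not None` rather than truthiness. The pattern in isolation, as a sketch over any SQLAlchemy query object:

def paginate(query, skip=None, limit=None):
    # Count the filtered set before slicing, so "total" matches the filters
    total = query.count()
    if skip is not None:  # "is not None" keeps skip=0 meaningful
        query = query.offset(skip)
    if limit is not None:
        query = query.limit(limit)
    return {"items": query.all(), "total": total}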
@@ -322,6 +499,15 @@ class UsersTable:

        except Exception:
            return None

+    def get_num_users_active_today(self) -> Optional[int]:
+        with get_db() as db:
+            current_timestamp = int(datetime.datetime.now().timestamp())
+            today_midnight_timestamp = current_timestamp - (current_timestamp % 86400)
+            query = db.query(User).filter(
+                User.last_active_at > today_midnight_timestamp
+            )
+            return query.count()
+
    def update_user_role_by_id(self, id: str, role: str) -> Optional[UserModel]:
        try:
            with get_db() as db:
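`current_timestamp % 86400` is the number of seconds elapsed since the last UTC day boundary, so subtracting it floors the timestamp to midnight UTC; note this is UTC midnight, not the server's local midnight, which would need timezone math. A quick check of the arithmetic:

import time

now = int(time.time())
utc_midnight = now - (now % 86400)

print(utc_midnight % 86400)  # 0: always an exact multiple of one day
print(now - utc_midnight)    # seconds elapsed in the current UTC day (< 86400)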
@@ -332,6 +518,21 @@ class UsersTable:

        except Exception:
            return None

+    def update_user_status_by_id(
+        self, id: str, form_data: UserStatus
+    ) -> Optional[UserModel]:
+        try:
+            with get_db() as db:
+                db.query(User).filter_by(id=id).update(
+                    {**form_data.model_dump(exclude_none=True)}
+                )
+                db.commit()
+
+                user = db.query(User).filter_by(id=id).first()
+                return UserModel.model_validate(user)
+        except Exception:
+            return None
+
    def update_user_profile_image_url_by_id(
        self, id: str, profile_image_url: str
    ) -> Optional[UserModel]:

@@ -348,7 +549,7 @@ class UsersTable:

            return None

    @throttle(DATABASE_USER_ACTIVE_STATUS_UPDATE_INTERVAL)
-    def update_user_last_active_by_id(self, id: str) -> Optional[UserModel]:
+    def update_last_active_by_id(self, id: str) -> Optional[UserModel]:
        try:
            with get_db() as db:
                db.query(User).filter_by(id=id).update(
@@ -361,16 +562,35 @@ class UsersTable:

        except Exception:
            return None

-    def update_user_oauth_sub_by_id(
-        self, id: str, oauth_sub: str
-    ) -> Optional[UserModel]:
+    def update_user_oauth_by_id(
+        self, id: str, provider: str, sub: str
+    ) -> Optional[UserModel]:
+        """
+        Update or insert an OAuth provider/sub pair into the user's oauth JSON field.
+        Example resulting structure:
+        {
+            "google": { "sub": "123" },
+            "github": { "sub": "abc" }
+        }
+        """
        try:
            with get_db() as db:
-                db.query(User).filter_by(id=id).update({"oauth_sub": oauth_sub})
+                user = db.query(User).filter_by(id=id).first()
+                if not user:
+                    return None
+
+                # Load existing oauth JSON or create empty
+                oauth = user.oauth or {}
+
+                # Update or insert provider entry
+                oauth[provider] = {"sub": sub}
+
+                # Persist updated JSON
+                db.query(User).filter_by(id=id).update({"oauth": oauth})
                db.commit()

-                user = db.query(User).filter_by(id=id).first()
                return UserModel.model_validate(user)

        except Exception:
            return None
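The read-modify-write above reduces to a small dict merge: per the docstring, each provider keys its own `{"sub": ...}` entry, so linking a second provider never clobbers the first. The merge step in isolation:

def merge_oauth(oauth, provider, sub):
    # Upsert one provider entry; existing providers are preserved
    oauth = dict(oauth or {})
    oauth[provider] = {"sub": sub}
    return oauth

print(merge_oauth({"google": {"sub": "123"}}, "github", "abc"))
# -> {'google': {'sub': '123'}, 'github': {'sub': 'abc'}}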
@@ -424,23 +644,45 @@ class UsersTable:

        except Exception:
            return False

-    def update_user_api_key_by_id(self, id: str, api_key: str) -> bool:
-        try:
-            with get_db() as db:
-                result = db.query(User).filter_by(id=id).update({"api_key": api_key})
-                db.commit()
-                return True if result == 1 else False
-        except Exception:
-            return False
-
    def get_user_api_key_by_id(self, id: str) -> Optional[str]:
        try:
            with get_db() as db:
-                user = db.query(User).filter_by(id=id).first()
-                return user.api_key
+                api_key = db.query(ApiKey).filter_by(user_id=id).first()
+                return api_key.key if api_key else None
        except Exception:
            return None

+    def update_user_api_key_by_id(self, id: str, api_key: str) -> bool:
+        try:
+            with get_db() as db:
+                db.query(ApiKey).filter_by(user_id=id).delete()
+                db.commit()
+
+                now = int(time.time())
+                new_api_key = ApiKey(
+                    id=f"key_{id}",
+                    user_id=id,
+                    key=api_key,
+                    created_at=now,
+                    updated_at=now,
+                )
+                db.add(new_api_key)
+                db.commit()
+
+                return True
+
+        except Exception:
+            return False
+
+    def delete_user_api_key_by_id(self, id: str) -> bool:
+        try:
+            with get_db() as db:
+                db.query(ApiKey).filter_by(user_id=id).delete()
+                db.commit()
+                return True
+        except Exception:
+            return False
+
    def get_valid_user_ids(self, user_ids: list[str]) -> list[str]:
        with get_db() as db:
            users = db.query(User).filter(User.id.in_(user_ids)).all()
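With keys now living in the `api_key` table, rotation is a delete-then-insert under the fixed id `key_{user_id}`, so each user holds at most one key at a time. A hypothetical usage sketch (requires an initialized Open WebUI database; the ids and key strings are illustrative):

Users.update_user_api_key_by_id("u1", "sk-first")
Users.update_user_api_key_by_id("u1", "sk-second")  # replaces, never appends
assert Users.get_user_api_key_by_id("u1") == "sk-second"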
@@ -454,5 +696,23 @@ class UsersTable:

        else:
            return None

+    def get_active_user_count(self) -> int:
+        with get_db() as db:
+            # Consider user active if last_active_at within the last 3 minutes
+            three_minutes_ago = int(time.time()) - 180
+            count = (
+                db.query(User).filter(User.last_active_at >= three_minutes_ago).count()
+            )
+            return count
+
+    def is_user_active(self, user_id: str) -> bool:
+        with get_db() as db:
+            user = db.query(User).filter_by(id=user_id).first()
+            if user and user.last_active_at:
+                # Consider user active if last_active_at within the last 3 minutes
+                three_minutes_ago = int(time.time()) - 180
+                return user.last_active_at >= three_minutes_ago
+            return False


Users = UsersTable()
@@ -5,6 +5,7 @@ from urllib.parse import quote

from langchain_core.document_loaders import BaseLoader
from langchain_core.documents import Document
+from open_webui.utils.headers import include_user_info_headers
from open_webui.env import SRC_LOG_LEVELS

log = logging.getLogger(__name__)

@@ -18,6 +19,7 @@ class ExternalDocumentLoader(BaseLoader):

        url: str,
        api_key: str,
        mime_type=None,
+        user=None,
        **kwargs,
    ) -> None:
        self.url = url

@@ -26,6 +28,8 @@ class ExternalDocumentLoader(BaseLoader):

        self.file_path = file_path
        self.mime_type = mime_type

+        self.user = user
+
    def load(self) -> List[Document]:
        with open(self.file_path, "rb") as f:
            data = f.read()

@@ -42,6 +46,9 @@ class ExternalDocumentLoader(BaseLoader):

        except:
            pass

+        if self.user is not None:
+            headers = include_user_info_headers(headers, self.user)
+
        url = self.url
        if url.endswith("/"):
            url = url[:-1]
@@ -27,6 +27,7 @@ from open_webui.retrieval.loaders.external_document import ExternalDocumentLoader

from open_webui.retrieval.loaders.mistral import MistralLoader
from open_webui.retrieval.loaders.datalab_marker import DatalabMarkerLoader
+from open_webui.retrieval.loaders.mineru import MinerULoader

from open_webui.env import SRC_LOG_LEVELS, GLOBAL_LOG_LEVEL

@@ -131,8 +132,9 @@ class TikaLoader:


class DoclingLoader:
-    def __init__(self, url, file_path=None, mime_type=None, params=None):
+    def __init__(self, url, api_key=None, file_path=None, mime_type=None, params=None):
        self.url = url.rstrip("/")
+        self.api_key = api_key
        self.file_path = file_path
        self.mime_type = mime_type

@@ -140,6 +142,10 @@ class DoclingLoader:

    def load(self) -> list[Document]:
        with open(self.file_path, "rb") as f:
+            headers = {}
+            if self.api_key:
+                headers["Authorization"] = f"Bearer {self.api_key}"
+
            files = {
                "files": (
                    self.file_path,
@@ -148,60 +154,15 @@ class DoclingLoader:

                )
            }

-            params = {"image_export_mode": "placeholder"}
-
-            if self.params:
-                if self.params.get("do_picture_description"):
-                    params["do_picture_description"] = self.params.get(
-                        "do_picture_description"
-                    )
-
-                    picture_description_mode = self.params.get(
-                        "picture_description_mode", ""
-                    ).lower()
-
-                    if picture_description_mode == "local" and self.params.get(
-                        "picture_description_local", {}
-                    ):
-                        params["picture_description_local"] = json.dumps(
-                            self.params.get("picture_description_local", {})
-                        )
-
-                    elif picture_description_mode == "api" and self.params.get(
-                        "picture_description_api", {}
-                    ):
-                        params["picture_description_api"] = json.dumps(
-                            self.params.get("picture_description_api", {})
-                        )
-
-                params["do_ocr"] = self.params.get("do_ocr")
-                params["force_ocr"] = self.params.get("force_ocr")
-
-                if (
-                    self.params.get("do_ocr")
-                    and self.params.get("ocr_engine")
-                    and self.params.get("ocr_lang")
-                ):
-                    params["ocr_engine"] = self.params.get("ocr_engine")
-                    params["ocr_lang"] = [
-                        lang.strip()
-                        for lang in self.params.get("ocr_lang").split(",")
-                        if lang.strip()
-                    ]
-
-                if self.params.get("pdf_backend"):
-                    params["pdf_backend"] = self.params.get("pdf_backend")
-
-                if self.params.get("table_mode"):
-                    params["table_mode"] = self.params.get("table_mode")
-
-                if self.params.get("pipeline"):
-                    params["pipeline"] = self.params.get("pipeline")
-
-            endpoint = f"{self.url}/v1/convert/file"
-            r = requests.post(endpoint, files=files, data=params)
+            r = requests.post(
+                f"{self.url}/v1/convert/file",
+                files=files,
+                data={
+                    "image_export_mode": "placeholder",
+                    **self.params,
+                },
+                headers=headers,
+            )

            if r.ok:
                result = r.json()
                document_data = result.get("document", {})

@@ -210,7 +171,6 @@ class DoclingLoader:

            metadata = {"Content-Type": self.mime_type} if self.mime_type else {}

            log.debug("Docling extracted text: %s", text)

            return [Document(page_content=text, metadata=metadata)]
        else:
            error_msg = f"Error calling Docling API: {r.reason}"
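The per-key translation logic moves out of the loader: whatever the caller puts in `params` is now merged over a single default and sent as form data, with the new optional Bearer header alongside. The payload construction in isolation (the key and param values are illustrative):

api_key = "dlk-example"                      # illustrative
params = {"do_ocr": True, "ocr_lang": "en"}  # caller-supplied Docling params

headers = {}
if api_key:
    headers["Authorization"] = f"Bearer {api_key}"

data = {"image_export_mode": "placeholder", **params}
print(headers)  # {'Authorization': 'Bearer dlk-example'}
print(data)     # {'image_export_mode': 'placeholder', 'do_ocr': True, 'ocr_lang': 'en'}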
@@ -227,6 +187,7 @@ class DoclingLoader:

class Loader:
    def __init__(self, engine: str = "", **kwargs):
        self.engine = engine
+        self.user = kwargs.get("user", None)
        self.kwargs = kwargs

    def load(

@@ -263,6 +224,7 @@ class Loader:

                url=self.kwargs.get("EXTERNAL_DOCUMENT_LOADER_URL"),
                api_key=self.kwargs.get("EXTERNAL_DOCUMENT_LOADER_API_KEY"),
                mime_type=file_content_type,
+                user=self.user,
            )
        elif self.engine == "tika" and self.kwargs.get("TIKA_SERVER_URL"):
            if self._is_text_file(file_ext, file_content_type):

@@ -271,7 +233,6 @@ class Loader:

            loader = TikaLoader(
                url=self.kwargs.get("TIKA_SERVER_URL"),
                file_path=file_path,
-                mime_type=file_content_type,
                extract_images=self.kwargs.get("PDF_EXTRACT_IMAGES"),
            )
        elif (

@@ -338,6 +299,7 @@ class Loader:

            loader = DoclingLoader(
                url=self.kwargs.get("DOCLING_SERVER_URL"),
+                api_key=self.kwargs.get("DOCLING_API_KEY", None),
                file_path=file_path,
                mime_type=file_content_type,
                params=params,

@@ -346,11 +308,9 @@ class Loader:

            self.engine == "document_intelligence"
            and self.kwargs.get("DOCUMENT_INTELLIGENCE_ENDPOINT") != ""
            and (
-                file_ext in ["pdf", "xls", "xlsx", "docx", "ppt", "pptx"]
+                file_ext in ["pdf", "docx", "ppt", "pptx"]
                or file_content_type
                in [
-                    "application/vnd.ms-excel",
-                    "application/vnd.openxmlformats-officedocument.spreadsheetml.sheet",
                    "application/vnd.openxmlformats-officedocument.wordprocessingml.document",
                    "application/vnd.ms-powerpoint",
                    "application/vnd.openxmlformats-officedocument.presentationml.presentation",

@@ -362,12 +322,24 @@ class Loader:

                    file_path=file_path,
                    api_endpoint=self.kwargs.get("DOCUMENT_INTELLIGENCE_ENDPOINT"),
                    api_key=self.kwargs.get("DOCUMENT_INTELLIGENCE_KEY"),
+                    api_model=self.kwargs.get("DOCUMENT_INTELLIGENCE_MODEL"),
                )
            else:
                loader = AzureAIDocumentIntelligenceLoader(
                    file_path=file_path,
                    api_endpoint=self.kwargs.get("DOCUMENT_INTELLIGENCE_ENDPOINT"),
                    azure_credential=DefaultAzureCredential(),
+                    api_model=self.kwargs.get("DOCUMENT_INTELLIGENCE_MODEL"),
+                )
+        elif self.engine == "mineru" and file_ext in [
+            "pdf"
+        ]:  # MinerU currently only supports PDF
+            loader = MinerULoader(
+                file_path=file_path,
+                api_mode=self.kwargs.get("MINERU_API_MODE", "local"),
+                api_url=self.kwargs.get("MINERU_API_URL", "http://localhost:8000"),
+                api_key=self.kwargs.get("MINERU_API_KEY", ""),
+                params=self.kwargs.get("MINERU_PARAMS", {}),
            )
        elif (
            self.engine == "mistral_ocr"

@@ -376,16 +348,9 @@ class Loader:

            in ["pdf"]  # Mistral OCR currently only supports PDF and images
        ):
            loader = MistralLoader(
-                api_key=self.kwargs.get("MISTRAL_OCR_API_KEY"), file_path=file_path
-            )
-        elif (
-            self.engine == "external"
-            and self.kwargs.get("MISTRAL_OCR_API_KEY") != ""
-            and file_ext
-            in ["pdf"]  # Mistral OCR currently only supports PDF and images
-        ):
-            loader = MistralLoader(
-                api_key=self.kwargs.get("MISTRAL_OCR_API_KEY"), file_path=file_path
+                base_url=self.kwargs.get("MISTRAL_OCR_API_BASE_URL"),
+                api_key=self.kwargs.get("MISTRAL_OCR_API_KEY"),
+                file_path=file_path,
            )
        else:
            if file_ext == "pdf":
backend/open_webui/retrieval/loaders/mineru.py (new file, 522 lines)
@@ -0,0 +1,522 @@

import os
import time
import requests
import logging
import tempfile
import zipfile
from typing import List, Optional
from langchain_core.documents import Document
from fastapi import HTTPException, status

log = logging.getLogger(__name__)


class MinerULoader:
    """
    MinerU document parser loader supporting both Cloud API and Local API modes.

    Cloud API: Uses MinerU managed service with async task-based processing
    Local API: Uses self-hosted MinerU API with synchronous processing
    """

    def __init__(
        self,
        file_path: str,
        api_mode: str = "local",
        api_url: str = "http://localhost:8000",
        api_key: str = "",
        params: dict = None,
    ):
        self.file_path = file_path
        self.api_mode = api_mode.lower()
        self.api_url = api_url.rstrip("/")
        self.api_key = api_key

        # Parse params dict with defaults (read via self.params so a None
        # params argument cannot raise AttributeError)
        self.params = params or {}
        self.enable_ocr = self.params.get("enable_ocr", False)
        self.enable_formula = self.params.get("enable_formula", True)
        self.enable_table = self.params.get("enable_table", True)
        self.language = self.params.get("language", "en")
        self.model_version = self.params.get("model_version", "pipeline")

        self.page_ranges = self.params.pop("page_ranges", "")

        # Validate API mode
        if self.api_mode not in ["local", "cloud"]:
            raise ValueError(
                f"Invalid API mode: {self.api_mode}. Must be 'local' or 'cloud'"
            )

        # Validate Cloud API requirements
        if self.api_mode == "cloud" and not self.api_key:
            raise ValueError("API key is required for Cloud API mode")

    def load(self) -> List[Document]:
        """
        Main entry point for loading and parsing the document.
        Routes to Cloud or Local API based on api_mode.
        """
        try:
            if self.api_mode == "cloud":
                return self._load_cloud_api()
            else:
                return self._load_local_api()
        except Exception as e:
            log.error(f"Error loading document with MinerU: {e}")
            raise

    def _load_local_api(self) -> List[Document]:
        """
        Load document using Local API (synchronous).
        Posts file to /file_parse endpoint and gets immediate response.
        """
        log.info(f"Using MinerU Local API at {self.api_url}")

        filename = os.path.basename(self.file_path)

        # Build form data for Local API
        form_data = {
            **self.params,
            "return_md": "true",
        }

        # Page ranges (Local API uses start_page_id and end_page_id)
        if self.page_ranges:
            # For simplicity, if page_ranges is specified, log a warning
            # Full page range parsing would require parsing the string
            log.warning(
                f"Page ranges '{self.page_ranges}' specified but Local API uses different format. "
                "Consider using start_page_id/end_page_id parameters if needed."
            )

        try:
            with open(self.file_path, "rb") as f:
                files = {"files": (filename, f, "application/octet-stream")}

                log.info(f"Sending file to MinerU Local API: {filename}")
                log.debug(f"Local API parameters: {form_data}")

                response = requests.post(
                    f"{self.api_url}/file_parse",
                    data=form_data,
                    files=files,
                    timeout=300,  # 5 minute timeout for large documents
                )
                response.raise_for_status()

        except FileNotFoundError:
            raise HTTPException(
                status.HTTP_404_NOT_FOUND, detail=f"File not found: {self.file_path}"
            )
        except requests.Timeout:
            raise HTTPException(
                status.HTTP_504_GATEWAY_TIMEOUT,
                detail="MinerU Local API request timed out",
            )
        except requests.HTTPError as e:
            error_detail = f"MinerU Local API request failed: {e}"
            if e.response is not None:
                try:
                    error_data = e.response.json()
                    error_detail += f" - {error_data}"
                except:
                    error_detail += f" - {e.response.text}"
            raise HTTPException(status.HTTP_400_BAD_REQUEST, detail=error_detail)
        except Exception as e:
            raise HTTPException(
                status.HTTP_500_INTERNAL_SERVER_ERROR,
                detail=f"Error calling MinerU Local API: {str(e)}",
            )

        # Parse response
        try:
            result = response.json()
        except ValueError as e:
            raise HTTPException(
                status.HTTP_502_BAD_GATEWAY,
                detail=f"Invalid JSON response from MinerU Local API: {e}",
            )

        # Extract markdown content from response
        if "results" not in result:
            raise HTTPException(
                status.HTTP_502_BAD_GATEWAY,
                detail="MinerU Local API response missing 'results' field",
            )

        results = result["results"]
        if not results:
            raise HTTPException(
                status.HTTP_400_BAD_REQUEST,
                detail="MinerU returned empty results",
            )

        # Get the first (and typically only) result
        file_result = list(results.values())[0]
        markdown_content = file_result.get("md_content", "")

        if not markdown_content:
            raise HTTPException(
                status.HTTP_400_BAD_REQUEST,
                detail="MinerU returned empty markdown content",
            )

        log.info(f"Successfully parsed document with MinerU Local API: {filename}")

        # Create metadata
        metadata = {
            "source": filename,
            "api_mode": "local",
            "backend": result.get("backend", "unknown"),
            "version": result.get("version", "unknown"),
        }

        return [Document(page_content=markdown_content, metadata=metadata)]

    def _load_cloud_api(self) -> List[Document]:
        """
        Load document using Cloud API (asynchronous).
        Uses batch upload endpoint to avoid need for public file URLs.
        """
        log.info(f"Using MinerU Cloud API at {self.api_url}")

        filename = os.path.basename(self.file_path)

        # Step 1: Request presigned upload URL
        batch_id, upload_url = self._request_upload_url(filename)

        # Step 2: Upload file to presigned URL
        self._upload_to_presigned_url(upload_url)

        # Step 3: Poll for results
        result = self._poll_batch_status(batch_id, filename)

        # Step 4: Download and extract markdown from ZIP
        markdown_content = self._download_and_extract_zip(
            result["full_zip_url"], filename
        )

        log.info(f"Successfully parsed document with MinerU Cloud API: {filename}")

        # Create metadata
        metadata = {
            "source": filename,
            "api_mode": "cloud",
            "batch_id": batch_id,
        }

        return [Document(page_content=markdown_content, metadata=metadata)]

    def _request_upload_url(self, filename: str) -> tuple:
        """
        Request presigned upload URL from Cloud API.
        Returns (batch_id, upload_url).
        """
        headers = {
            "Authorization": f"Bearer {self.api_key}",
            "Content-Type": "application/json",
        }

        # Build request body
        request_body = {
            **self.params,
            "files": [
                {
                    "name": filename,
                    "is_ocr": self.enable_ocr,
                }
            ],
        }

        # Add page ranges if specified
        if self.page_ranges:
            request_body["files"][0]["page_ranges"] = self.page_ranges

        log.info(f"Requesting upload URL for: {filename}")
        log.debug(f"Cloud API request body: {request_body}")

        try:
            response = requests.post(
                f"{self.api_url}/file-urls/batch",
                headers=headers,
                json=request_body,
                timeout=30,
            )
            response.raise_for_status()
        except requests.HTTPError as e:
            error_detail = f"Failed to request upload URL: {e}"
            if e.response is not None:
                try:
                    error_data = e.response.json()
                    error_detail += f" - {error_data.get('msg', error_data)}"
                except:
                    error_detail += f" - {e.response.text}"
            raise HTTPException(status.HTTP_400_BAD_REQUEST, detail=error_detail)
        except Exception as e:
            raise HTTPException(
                status.HTTP_500_INTERNAL_SERVER_ERROR,
                detail=f"Error requesting upload URL: {str(e)}",
            )

        try:
            result = response.json()
        except ValueError as e:
            raise HTTPException(
                status.HTTP_502_BAD_GATEWAY,
                detail=f"Invalid JSON response: {e}",
            )

        # Check for API error response
        if result.get("code") != 0:
            raise HTTPException(
                status.HTTP_400_BAD_REQUEST,
                detail=f"MinerU Cloud API error: {result.get('msg', 'Unknown error')}",
            )

        data = result.get("data", {})
        batch_id = data.get("batch_id")
        file_urls = data.get("file_urls", [])

        if not batch_id or not file_urls:
            raise HTTPException(
                status.HTTP_502_BAD_GATEWAY,
                detail="MinerU Cloud API response missing batch_id or file_urls",
            )

        upload_url = file_urls[0]
        log.info(f"Received upload URL for batch: {batch_id}")

        return batch_id, upload_url

    def _upload_to_presigned_url(self, upload_url: str) -> None:
        """
        Upload file to presigned URL (no authentication needed).
        """
        log.info("Uploading file to presigned URL")

        try:
            with open(self.file_path, "rb") as f:
                response = requests.put(
                    upload_url,
                    data=f,
                    timeout=300,  # 5 minute timeout for large files
                )
                response.raise_for_status()
        except FileNotFoundError:
            raise HTTPException(
                status.HTTP_404_NOT_FOUND, detail=f"File not found: {self.file_path}"
            )
        except requests.Timeout:
            raise HTTPException(
                status.HTTP_504_GATEWAY_TIMEOUT,
                detail="File upload to presigned URL timed out",
            )
        except requests.HTTPError as e:
            raise HTTPException(
                status.HTTP_400_BAD_REQUEST,
                detail=f"Failed to upload file to presigned URL: {e}",
            )
        except Exception as e:
            raise HTTPException(
                status.HTTP_500_INTERNAL_SERVER_ERROR,
                detail=f"Error uploading file: {str(e)}",
            )

        log.info("File uploaded successfully")

    def _poll_batch_status(self, batch_id: str, filename: str) -> dict:
        """
        Poll batch status until completion.
        Returns the result dict for the file.
        """
        headers = {
            "Authorization": f"Bearer {self.api_key}",
        }

        max_iterations = 300  # 10 minutes max (2 seconds per iteration)
        poll_interval = 2  # seconds

        log.info(f"Polling batch status: {batch_id}")

        for iteration in range(max_iterations):
            try:
                response = requests.get(
                    f"{self.api_url}/extract-results/batch/{batch_id}",
                    headers=headers,
                    timeout=30,
                )
                response.raise_for_status()
            except requests.HTTPError as e:
                error_detail = f"Failed to poll batch status: {e}"
                if e.response is not None:
                    try:
                        error_data = e.response.json()
                        error_detail += f" - {error_data.get('msg', error_data)}"
                    except:
                        error_detail += f" - {e.response.text}"
                raise HTTPException(status.HTTP_400_BAD_REQUEST, detail=error_detail)
            except Exception as e:
                raise HTTPException(
                    status.HTTP_500_INTERNAL_SERVER_ERROR,
                    detail=f"Error polling batch status: {str(e)}",
                )

            try:
                result = response.json()
            except ValueError as e:
                raise HTTPException(
                    status.HTTP_502_BAD_GATEWAY,
                    detail=f"Invalid JSON response while polling: {e}",
                )

            # Check for API error response
            if result.get("code") != 0:
                raise HTTPException(
                    status.HTTP_400_BAD_REQUEST,
                    detail=f"MinerU Cloud API error: {result.get('msg', 'Unknown error')}",
                )

            data = result.get("data", {})
            extract_result = data.get("extract_result", [])

            # Find our file in the batch results
            file_result = None
            for item in extract_result:
                if item.get("file_name") == filename:
                    file_result = item
                    break

            if not file_result:
                raise HTTPException(
                    status.HTTP_502_BAD_GATEWAY,
                    detail=f"File {filename} not found in batch results",
                )

            state = file_result.get("state")

            if state == "done":
                log.info(f"Processing complete for {filename}")
                return file_result
            elif state == "failed":
                error_msg = file_result.get("err_msg", "Unknown error")
                raise HTTPException(
                    status.HTTP_400_BAD_REQUEST,
                    detail=f"MinerU processing failed: {error_msg}",
                )
            elif state in ["waiting-file", "pending", "running", "converting"]:
                # Still processing
                if iteration % 10 == 0:  # Log every 20 seconds
                    log.info(
                        f"Processing status: {state} (iteration {iteration + 1}/{max_iterations})"
                    )
                time.sleep(poll_interval)
            else:
                log.warning(f"Unknown state: {state}")
                time.sleep(poll_interval)

        # Timeout
        raise HTTPException(
            status.HTTP_504_GATEWAY_TIMEOUT,
            detail="MinerU processing timed out after 10 minutes",
        )

    def _download_and_extract_zip(self, zip_url: str, filename: str) -> str:
        """
        Download ZIP file from CDN and extract markdown content.
        Returns the markdown content as a string.
        """
        log.info(f"Downloading results from: {zip_url}")

        try:
            response = requests.get(zip_url, timeout=60)
            response.raise_for_status()
        except requests.HTTPError as e:
            raise HTTPException(
                status.HTTP_400_BAD_REQUEST,
                detail=f"Failed to download results ZIP: {e}",
            )
        except Exception as e:
            raise HTTPException(
                status.HTTP_500_INTERNAL_SERVER_ERROR,
                detail=f"Error downloading results: {str(e)}",
            )

        # Save ZIP to temporary file and extract
        try:
            with tempfile.NamedTemporaryFile(delete=False, suffix=".zip") as tmp_zip:
                tmp_zip.write(response.content)
                tmp_zip_path = tmp_zip.name

            with tempfile.TemporaryDirectory() as tmp_dir:
                # Extract ZIP
                with zipfile.ZipFile(tmp_zip_path, "r") as zip_ref:
                    zip_ref.extractall(tmp_dir)

                # Find markdown file - search recursively for any .md file
                markdown_content = None
                found_md_path = None

                # First, list all files in the ZIP for debugging
                all_files = []
                for root, dirs, files in os.walk(tmp_dir):
                    for file in files:
                        full_path = os.path.join(root, file)
                        all_files.append(full_path)
                        # Look for any .md file
                        if file.endswith(".md"):
                            found_md_path = full_path
                            log.info(f"Found markdown file at: {full_path}")
                            try:
                                with open(full_path, "r", encoding="utf-8") as f:
                                    markdown_content = f.read()
                                if (
                                    markdown_content
                                ):  # Use the first non-empty markdown file
                                    break
                            except Exception as e:
                                log.warning(f"Failed to read {full_path}: {e}")
                    if markdown_content:
                        break

                if markdown_content is None:
                    log.error(f"Available files in ZIP: {all_files}")
                    # Try to provide more helpful error message
                    md_files = [f for f in all_files if f.endswith(".md")]
                    if md_files:
                        error_msg = (
                            f"Found .md files but couldn't read them: {md_files}"
                        )
                    else:
                        error_msg = (
                            f"No .md files found in ZIP. Available files: {all_files}"
                        )
                    raise HTTPException(
                        status.HTTP_502_BAD_GATEWAY,
                        detail=error_msg,
                    )

            # Clean up temporary ZIP file
            os.unlink(tmp_zip_path)

        except zipfile.BadZipFile as e:
            raise HTTPException(
                status.HTTP_502_BAD_GATEWAY,
                detail=f"Invalid ZIP file received: {e}",
            )
        except Exception as e:
            raise HTTPException(
                status.HTTP_500_INTERNAL_SERVER_ERROR,
                detail=f"Error extracting ZIP: {str(e)}",
            )

        if not markdown_content:
            raise HTTPException(
                status.HTTP_400_BAD_REQUEST,
                detail="Extracted markdown content is empty",
            )

        log.info(
            f"Successfully extracted markdown content ({len(markdown_content)} characters)"
        )
        return markdown_content
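A hypothetical usage sketch for the new loader (the path and server URL are illustrative, and a MinerU server must actually be reachable): local mode is a single synchronous POST, while cloud mode runs the four-step upload/poll/download flow above and can take up to the 10-minute polling budget (300 iterations at 2 seconds each).

loader = MinerULoader(
    file_path="/tmp/example.pdf",     # illustrative path
    api_mode="local",                 # or "cloud" with api_key set
    api_url="http://localhost:8000",
    params={"enable_ocr": False, "language": "en"},
)
docs = loader.load()
print(docs[0].metadata)  # e.g. {'source': 'example.pdf', 'api_mode': 'local', ...}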
@@ -30,10 +30,9 @@ class MistralLoader:

    - Enhanced error handling with retryable error classification
    """

-    BASE_API_URL = "https://api.mistral.ai/v1"
-
    def __init__(
        self,
+        base_url: str,
        api_key: str,
        file_path: str,
        timeout: int = 300,  # 5 minutes default

@@ -55,6 +54,9 @@ class MistralLoader:

        if not os.path.exists(file_path):
            raise FileNotFoundError(f"File not found at {file_path}")

+        self.base_url = (
+            base_url.rstrip("/") if base_url else "https://api.mistral.ai/v1"
+        )
        self.api_key = api_key
        self.file_path = file_path
        self.timeout = timeout

@@ -240,7 +242,7 @@ class MistralLoader:

        in a context manager to minimize memory usage duration.
        """
        log.info("Uploading file to Mistral API")
-        url = f"{self.BASE_API_URL}/files"
+        url = f"{self.base_url}/files"

        def upload_request():
            # MEMORY OPTIMIZATION: Use context manager to minimize file handle lifetime

@@ -275,7 +277,7 @@ class MistralLoader:

    async def _upload_file_async(self, session: aiohttp.ClientSession) -> str:
        """Async file upload with streaming for better memory efficiency."""
-        url = f"{self.BASE_API_URL}/files"
+        url = f"{self.base_url}/files"

        async def upload_request():
            # Create multipart writer for streaming upload

@@ -321,7 +323,7 @@ class MistralLoader:

    def _get_signed_url(self, file_id: str) -> str:
        """Retrieves a temporary signed URL for the uploaded file (sync version)."""
        log.info(f"Getting signed URL for file ID: {file_id}")
-        url = f"{self.BASE_API_URL}/files/{file_id}/url"
+        url = f"{self.base_url}/files/{file_id}/url"
        params = {"expiry": 1}
        signed_url_headers = {**self.headers, "Accept": "application/json"}

@@ -346,7 +348,7 @@ class MistralLoader:

        self, session: aiohttp.ClientSession, file_id: str
    ) -> str:
        """Async signed URL retrieval."""
-        url = f"{self.BASE_API_URL}/files/{file_id}/url"
+        url = f"{self.base_url}/files/{file_id}/url"
        params = {"expiry": 1}

        headers = {**self.headers, "Accept": "application/json"}

@@ -373,7 +375,7 @@ class MistralLoader:

    def _process_ocr(self, signed_url: str) -> Dict[str, Any]:
        """Sends the signed URL to the OCR endpoint for processing (sync version)."""
        log.info("Processing OCR via Mistral API")
-        url = f"{self.BASE_API_URL}/ocr"
+        url = f"{self.base_url}/ocr"
        ocr_headers = {
            **self.headers,
            "Content-Type": "application/json",

@@ -407,7 +409,7 @@ class MistralLoader:

        self, session: aiohttp.ClientSession, signed_url: str
    ) -> Dict[str, Any]:
        """Async OCR processing with timing metrics."""
-        url = f"{self.BASE_API_URL}/ocr"
+        url = f"{self.base_url}/ocr"

        headers = {
            **self.headers,

@@ -446,7 +448,7 @@ class MistralLoader:

    def _delete_file(self, file_id: str) -> None:
        """Deletes the file from Mistral storage (sync version)."""
        log.info(f"Deleting uploaded file ID: {file_id}")
-        url = f"{self.BASE_API_URL}/files/{file_id}"
+        url = f"{self.base_url}/files/{file_id}"

        try:
            response = requests.delete(

@@ -467,7 +469,7 @@ class MistralLoader:

        async def delete_request():
            self._debug_log(f"Deleting file ID: {file_id}")
            async with session.delete(
-                url=f"{self.BASE_API_URL}/files/{file_id}",
+                url=f"{self.base_url}/files/{file_id}",
                headers=self.headers,
                timeout=aiohttp.ClientTimeout(
                    total=self.cleanup_timeout
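Replacing the `BASE_API_URL` class constant with a constructor argument lets the loader target a proxy or self-hosted gateway while keeping the hosted endpoint as the fallback. The normalization in isolation (the proxy URL is illustrative):

def normalize_base_url(base_url):
    return base_url.rstrip("/") if base_url else "https://api.mistral.ai/v1"

print(normalize_base_url(None))                         # https://api.mistral.ai/v1
print(normalize_base_url("https://proxy.example/v1/"))  # https://proxy.example/v1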
@@ -83,6 +83,7 @@ class YoutubeLoader:

                TranscriptsDisabled,
                YouTubeTranscriptApi,
            )
+            from youtube_transcript_api.proxies import GenericProxyConfig
        except ImportError:
            raise ImportError(
                'Could not import "youtube_transcript_api" Python package. '

@@ -90,10 +91,9 @@ class YoutubeLoader:

            )

        if self.proxy_url:
-            youtube_proxies = {
-                "http": self.proxy_url,
-                "https": self.proxy_url,
-            }
+            youtube_proxies = GenericProxyConfig(
+                http_url=self.proxy_url, https_url=self.proxy_url
+            )
            log.debug(f"Using proxy URL: {self.proxy_url[:14]}...")
        else:
            youtube_proxies = None

@@ -157,3 +157,10 @@ class YoutubeLoader:

                f"No transcript found for any of the specified languages: {languages_tried}. Verify if the video has transcripts, add more languages if needed."
            )
        raise NoTranscriptFound(self.video_id, self.language, list(transcript_list))

+    async def aload(self) -> Generator[Document, None, None]:
+        """Asynchronously load YouTube transcripts into `Document` objects."""
+        import asyncio
+
+        loop = asyncio.get_event_loop()
+        return await loop.run_in_executor(None, self.load)
|
||||||
|
|
||||||
from open_webui.env import ENABLE_FORWARD_USER_INFO_HEADERS, SRC_LOG_LEVELS
|
from open_webui.env import ENABLE_FORWARD_USER_INFO_HEADERS, SRC_LOG_LEVELS
|
||||||
from open_webui.retrieval.models.base_reranker import BaseReranker
|
from open_webui.retrieval.models.base_reranker import BaseReranker
|
||||||
|
from open_webui.utils.headers import include_user_info_headers
|
||||||
|
|
||||||
|
|
||||||
log = logging.getLogger(__name__)
|
log = logging.getLogger(__name__)
|
||||||
|
|
@ -40,22 +41,17 @@ class ExternalReranker(BaseReranker):
|
||||||
log.info(f"ExternalReranker:predict:model {self.model}")
|
log.info(f"ExternalReranker:predict:model {self.model}")
|
||||||
log.info(f"ExternalReranker:predict:query {query}")
|
log.info(f"ExternalReranker:predict:query {query}")
|
||||||
|
|
||||||
r = requests.post(
|
headers = {
|
||||||
f"{self.url}",
|
|
||||||
headers={
|
|
||||||
"Content-Type": "application/json",
|
"Content-Type": "application/json",
|
||||||
"Authorization": f"Bearer {self.api_key}",
|
"Authorization": f"Bearer {self.api_key}",
|
||||||
**(
|
|
||||||
{
|
|
||||||
"X-OpenWebUI-User-Name": quote(user.name, safe=" "),
|
|
||||||
"X-OpenWebUI-User-Id": user.id,
|
|
||||||
"X-OpenWebUI-User-Email": user.email,
|
|
||||||
"X-OpenWebUI-User-Role": user.role,
|
|
||||||
}
|
}
|
||||||
if ENABLE_FORWARD_USER_INFO_HEADERS and user
|
|
||||||
else {}
|
if ENABLE_FORWARD_USER_INFO_HEADERS and user:
|
||||||
),
|
headers = include_user_info_headers(headers, user)
|
||||||
},
|
|
||||||
|
r = requests.post(
|
||||||
|
f"{self.url}",
|
||||||
|
headers=headers,
|
||||||
json=payload,
|
json=payload,
|
||||||
)
|
)
|
||||||
|
|
||||||
|
|
|
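The repeated inline header-building is replaced by a shared helper across modules. Its exact implementation is not shown in this diff; a minimal sketch consistent with the headers the removed inline code built might look like this (an illustration, not the real open_webui.utils.headers source):

    from urllib.parse import quote

    def include_user_info_headers(headers: dict, user) -> dict:
        # Attach the X-OpenWebUI-* identity headers when a user object is present.
        if user is not None:
            headers.update(
                {
                    "X-OpenWebUI-User-Name": quote(user.name, safe=" "),
                    "X-OpenWebUI-User-Id": user.id,
                    "X-OpenWebUI-User-Email": user.email,
                    "X-OpenWebUI-User-Role": user.role,
                }
            )
        return headers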
[File diff suppressed because it is too large]
@@ -200,23 +200,24 @@ class MilvusClient(VectorDBBase):
     def query(self, collection_name: str, filter: dict, limit: int = -1):
         connections.connect(uri=MILVUS_URI, token=MILVUS_TOKEN, db_name=MILVUS_DB)

-        # Construct the filter string for querying
         collection_name = collection_name.replace("-", "_")
         if not self.has_collection(collection_name):
             log.warning(
                 f"Query attempted on non-existent collection: {self.collection_prefix}_{collection_name}"
             )
             return None
-        filter_string = " && ".join(
-            [
-                f'metadata["{key}"] == {json.dumps(value)}'
-                for key, value in filter.items()
-            ]
-        )
+        filter_expressions = []
+        for key, value in filter.items():
+            if isinstance(value, str):
+                filter_expressions.append(f'metadata["{key}"] == "{value}"')
+            else:
+                filter_expressions.append(f'metadata["{key}"] == {value}')
+
+        filter_string = " && ".join(filter_expressions)

         collection = Collection(f"{self.collection_prefix}_{collection_name}")
         collection.load()
-        all_results = []

         try:
             log.info(

@@ -224,24 +225,25 @@ class MilvusClient(VectorDBBase):
             )

             iterator = collection.query_iterator(
-                filter=filter_string,
+                expr=filter_string,
                 output_fields=[
                     "id",
                     "data",
                     "metadata",
                 ],
-                limit=limit,  # Pass the limit directly; -1 means no limit.
+                limit=limit if limit > 0 else -1,
             )

+            all_results = []
             while True:
-                result = iterator.next()
-                if not result:
+                batch = iterator.next()
+                if not batch:
                     iterator.close()
                     break
-                all_results += result
+                all_results.extend(batch)

-            log.info(f"Total results from query: {len(all_results)}")
-            return self._result_to_get_result([all_results])
+            log.debug(f"Total results from query: {len(all_results)}")
+            return self._result_to_get_result([all_results] if all_results else [[]])

         except Exception as e:
             log.exception(
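Worth noting: pymilvus's plain `query` is bounded by a server-side page cap, while `query_iterator` drains matches in batches. The pattern in isolation, assuming `collection` is a loaded pymilvus `Collection` and the filter value is a placeholder:

    iterator = collection.query_iterator(
        expr='metadata["source"] == "notes.txt"',  # hypothetical filter
        output_fields=["id", "data", "metadata"],
        limit=-1,  # -1 drains every match, batch by batch
    )
    all_results = []
    while True:
        batch = iterator.next()
        if not batch:
            iterator.close()
            break
        all_results.extend(batch)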
@@ -157,7 +157,6 @@ class MilvusClient(VectorDBBase):
             for item in items
         ]
         collection.insert(entities)
-        collection.flush()

     def search(
         self, collection_name: str, vectors: List[List[float]], limit: int

@@ -263,15 +262,23 @@ class MilvusClient(VectorDBBase):
             else:
                 expr.append(f"metadata['{key}'] == {value}")

-        results = collection.query(
+        iterator = collection.query_iterator(
             expr=" and ".join(expr),
             output_fields=["id", "text", "metadata"],
-            limit=limit,
+            limit=limit if limit else -1,
         )

-        ids = [res["id"] for res in results]
-        documents = [res["text"] for res in results]
-        metadatas = [res["metadata"] for res in results]
+        all_results = []
+        while True:
+            batch = iterator.next()
+            if not batch:
+                iterator.close()
+                break
+            all_results.extend(batch)
+
+        ids = [res["id"] for res in all_results]
+        documents = [res["text"] for res in all_results]
+        metadatas = [res["metadata"] for res in all_results]

         return GetResult(ids=[ids], documents=[documents], metadatas=[metadatas])
@@ -717,7 +717,7 @@ class Oracle23aiClient(VectorDBBase):
         )

         try:
-            limit = limit or 1000
+            limit = 1000  # Hardcoded limit for get operation

             with self.get_connection() as connection:
                 with connection.cursor() as cursor:
@@ -1,4 +1,4 @@
-from typing import Optional, List, Dict, Any
+from typing import Optional, List, Dict, Any, Tuple
 import logging
 import json
 from sqlalchemy import (

@@ -22,7 +22,7 @@ from sqlalchemy.pool import NullPool, QueuePool

 from sqlalchemy.orm import declarative_base, scoped_session, sessionmaker
 from sqlalchemy.dialects.postgresql import JSONB, array
-from pgvector.sqlalchemy import Vector
+from pgvector.sqlalchemy import Vector, HALFVEC
 from sqlalchemy.ext.mutable import MutableDict
 from sqlalchemy.exc import NoSuchTableError

@@ -44,11 +44,20 @@ from open_webui.config import (
     PGVECTOR_POOL_MAX_OVERFLOW,
     PGVECTOR_POOL_TIMEOUT,
     PGVECTOR_POOL_RECYCLE,
+    PGVECTOR_INDEX_METHOD,
+    PGVECTOR_HNSW_M,
+    PGVECTOR_HNSW_EF_CONSTRUCTION,
+    PGVECTOR_IVFFLAT_LISTS,
+    PGVECTOR_USE_HALFVEC,
 )

 from open_webui.env import SRC_LOG_LEVELS

 VECTOR_LENGTH = PGVECTOR_INITIALIZE_MAX_VECTOR_LENGTH
+USE_HALFVEC = PGVECTOR_USE_HALFVEC
+
+VECTOR_TYPE_FACTORY = HALFVEC if USE_HALFVEC else Vector
+VECTOR_OPCLASS = "halfvec_cosine_ops" if USE_HALFVEC else "vector_cosine_ops"
 Base = declarative_base()

 log = logging.getLogger(__name__)

@@ -67,7 +76,7 @@ class DocumentChunk(Base):
     __tablename__ = "document_chunk"

     id = Column(Text, primary_key=True)
-    vector = Column(Vector(dim=VECTOR_LENGTH), nullable=True)
+    vector = Column(VECTOR_TYPE_FACTORY(dim=VECTOR_LENGTH), nullable=True)
     collection_name = Column(Text, nullable=False)

     if PGVECTOR_PGCRYPTO:
@@ -157,13 +166,9 @@ class PgvectorClient(VectorDBBase):
             connection = self.session.connection()
             Base.metadata.create_all(bind=connection)

-            # Create an index on the vector column if it doesn't exist
-            self.session.execute(
-                text(
-                    "CREATE INDEX IF NOT EXISTS idx_document_chunk_vector "
-                    "ON document_chunk USING ivfflat (vector vector_cosine_ops) WITH (lists = 100);"
-                )
-            )
+            index_method, index_options = self._vector_index_configuration()
+            self._ensure_vector_index(index_method, index_options)
             self.session.execute(
                 text(
                     "CREATE INDEX IF NOT EXISTS idx_document_chunk_collection_name "

@@ -177,6 +182,78 @@ class PgvectorClient(VectorDBBase):
             log.exception(f"Error during initialization: {e}")
             raise

+    @staticmethod
+    def _extract_index_method(index_def: Optional[str]) -> Optional[str]:
+        if not index_def:
+            return None
+        try:
+            after_using = index_def.lower().split("using ", 1)[1]
+            return after_using.split()[0]
+        except (IndexError, AttributeError):
+            return None
+
+    def _vector_index_configuration(self) -> Tuple[str, str]:
+        if PGVECTOR_INDEX_METHOD:
+            index_method = PGVECTOR_INDEX_METHOD
+            log.info(
+                "Using vector index method '%s' from PGVECTOR_INDEX_METHOD.",
+                index_method,
+            )
+        elif USE_HALFVEC:
+            index_method = "hnsw"
+            log.info(
+                "VECTOR_LENGTH=%s exceeds 2000; using halfvec column type with hnsw index.",
+                VECTOR_LENGTH,
+            )
+        else:
+            index_method = "ivfflat"
+
+        if index_method == "hnsw":
+            index_options = f"WITH (m = {PGVECTOR_HNSW_M}, ef_construction = {PGVECTOR_HNSW_EF_CONSTRUCTION})"
+        else:
+            index_options = f"WITH (lists = {PGVECTOR_IVFFLAT_LISTS})"
+
+        return index_method, index_options
+
+    def _ensure_vector_index(self, index_method: str, index_options: str) -> None:
+        index_name = "idx_document_chunk_vector"
+        existing_index_def = self.session.execute(
+            text(
+                """
+                SELECT indexdef
+                FROM pg_indexes
+                WHERE schemaname = current_schema()
+                  AND tablename = 'document_chunk'
+                  AND indexname = :index_name
+                """
+            ),
+            {"index_name": index_name},
+        ).scalar()
+
+        existing_method = self._extract_index_method(existing_index_def)
+        if existing_method and existing_method != index_method:
+            raise RuntimeError(
+                f"Existing pgvector index '{index_name}' uses method '{existing_method}' but configuration now "
+                f"requires '{index_method}'. Automatic rebuild is disabled to prevent long-running maintenance. "
+                "Drop the index manually (optionally after tuning maintenance_work_mem/max_parallel_maintenance_workers) "
+                "and recreate it with the new method before restarting Open WebUI."
+            )
+
+        if not existing_index_def:
+            index_sql = (
+                f"CREATE INDEX IF NOT EXISTS {index_name} "
+                f"ON document_chunk USING {index_method} (vector {VECTOR_OPCLASS})"
+            )
+            if index_options:
+                index_sql = f"{index_sql} {index_options}"
+            self.session.execute(text(index_sql))
+            log.info(
+                "Ensured vector index '%s' using %s%s.",
+                index_name,
+                index_method,
+                f" {index_options}" if index_options else "",
+            )
+
     def check_vector_length(self) -> None:
         """
         Check if the VECTOR_LENGTH matches the existing vector column dimension in the database.
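To make the new behavior concrete, this is the statement `_ensure_vector_index` would assemble on a fresh database with halfvec enabled; the `m` and `ef_construction` values below are hypothetical stand-ins for whatever `PGVECTOR_HNSW_M` and `PGVECTOR_HNSW_EF_CONSTRUCTION` resolve to:

    # index_method="hnsw", opclass selected by USE_HALFVEC; illustrative values only.
    index_sql = (
        "CREATE INDEX IF NOT EXISTS idx_document_chunk_vector "
        "ON document_chunk USING hnsw (vector halfvec_cosine_ops) "
        "WITH (m = 16, ef_construction = 64)"
    )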
@@ -196,17 +273,20 @@ class PgvectorClient(VectorDBBase):
             if "vector" in document_chunk_table.columns:
                 vector_column = document_chunk_table.columns["vector"]
                 vector_type = vector_column.type
-                if isinstance(vector_type, Vector):
-                    db_vector_length = vector_type.dim
-                    if db_vector_length != VECTOR_LENGTH:
+                expected_type = HALFVEC if USE_HALFVEC else Vector
+
+                if not isinstance(vector_type, expected_type):
+                    raise Exception(
+                        "The 'vector' column type does not match the expected type "
+                        f"('{expected_type.__name__}') for VECTOR_LENGTH {VECTOR_LENGTH}."
+                    )
+
+                db_vector_length = getattr(vector_type, "dim", None)
+                if db_vector_length is not None and db_vector_length != VECTOR_LENGTH:
                     raise Exception(
                         f"VECTOR_LENGTH {VECTOR_LENGTH} does not match existing vector column dimension {db_vector_length}. "
                         "Cannot change vector size after initialization without migrating the data."
                     )
-                else:
-                    raise Exception(
-                        "The 'vector' column exists but is not of type 'Vector'."
-                    )
             else:
                 raise Exception(
                     "The 'vector' column does not exist in the 'document_chunk' table."

@@ -360,11 +440,11 @@ class PgvectorClient(VectorDBBase):
         num_queries = len(vectors)

         def vector_expr(vector):
-            return cast(array(vector), Vector(VECTOR_LENGTH))
+            return cast(array(vector), VECTOR_TYPE_FACTORY(VECTOR_LENGTH))

         # Create the values for query vectors
         qid_col = column("qid", Integer)
-        q_vector_col = column("q_vector", Vector(VECTOR_LENGTH))
+        q_vector_col = column("q_vector", VECTOR_TYPE_FACTORY(VECTOR_LENGTH))
         query_vectors = (
             values(qid_col, q_vector_col)
             .data(
@@ -117,15 +117,16 @@ class S3VectorClient(VectorDBBase):

     def has_collection(self, collection_name: str) -> bool:
         """
-        Check if a vector index (collection) exists in the S3 vector bucket.
+        Check if a vector index exists using direct lookup.
+        This avoids pagination issues with list_indexes() and is significantly faster.
         """

         try:
-            response = self.client.list_indexes(vectorBucketName=self.bucket_name)
-            indexes = response.get("indexes", [])
-            return any(idx.get("indexName") == collection_name for idx in indexes)
+            self.client.get_index(
+                vectorBucketName=self.bucket_name, indexName=collection_name
+            )
+            return True
         except Exception as e:
-            log.error(f"Error listing indexes: {e}")
+            log.error(f"Error checking if index '{collection_name}' exists: {e}")
             return False

     def delete_collection(self, collection_name: str) -> None:
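The same existence probe in isolation, assuming boto3's `s3vectors` client exposes the call shape used above; any missing-index error simply falls through to `False`, and the bucket and index names are placeholders:

    import boto3

    client = boto3.client("s3vectors")
    try:
        client.get_index(vectorBucketName="my-vector-bucket", indexName="my-index")
        exists = True
    except Exception:
        exists = False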
backend/open_webui/retrieval/vector/dbs/weaviate.py (new file, 340 lines)

@@ -0,0 +1,340 @@
+import weaviate
+import re
+import uuid
+from typing import Any, Dict, List, Optional, Union
+
+from open_webui.retrieval.vector.main import (
+    VectorDBBase,
+    VectorItem,
+    SearchResult,
+    GetResult,
+)
+from open_webui.retrieval.vector.utils import process_metadata
+from open_webui.config import (
+    WEAVIATE_HTTP_HOST,
+    WEAVIATE_HTTP_PORT,
+    WEAVIATE_GRPC_PORT,
+    WEAVIATE_API_KEY,
+)
+
+
+def _convert_uuids_to_strings(obj: Any) -> Any:
+    """
+    Recursively convert UUID objects to strings in nested data structures.
+
+    This function handles:
+    - UUID objects -> string
+    - Dictionaries with UUID values
+    - Lists/Tuples with UUID values
+    - Nested combinations of the above
+
+    Args:
+        obj: Any object that might contain UUIDs
+
+    Returns:
+        The same object structure with UUIDs converted to strings
+    """
+    if isinstance(obj, uuid.UUID):
+        return str(obj)
+    elif isinstance(obj, dict):
+        return {key: _convert_uuids_to_strings(value) for key, value in obj.items()}
+    elif isinstance(obj, (list, tuple)):
+        return type(obj)(_convert_uuids_to_strings(item) for item in obj)
+    elif isinstance(obj, (str, int, float, bool, type(None))):
+        return obj
+    else:
+        return obj
+
+
+class WeaviateClient(VectorDBBase):
+    def __init__(self):
+        self.url = WEAVIATE_HTTP_HOST
+        try:
+            # Build connection parameters
+            connection_params = {
+                "host": WEAVIATE_HTTP_HOST,
+                "port": WEAVIATE_HTTP_PORT,
+                "grpc_port": WEAVIATE_GRPC_PORT,
+            }
+
+            # Only add auth_credentials if WEAVIATE_API_KEY exists and is not empty
+            if WEAVIATE_API_KEY:
+                connection_params["auth_credentials"] = (
+                    weaviate.classes.init.Auth.api_key(WEAVIATE_API_KEY)
+                )
+
+            self.client = weaviate.connect_to_local(**connection_params)
+            self.client.connect()
+        except Exception as e:
+            raise ConnectionError(f"Failed to connect to Weaviate: {e}") from e
+
+    def _sanitize_collection_name(self, collection_name: str) -> str:
+        """Sanitize collection name to be a valid Weaviate class name."""
+        if not isinstance(collection_name, str) or not collection_name.strip():
+            raise ValueError("Collection name must be a non-empty string")
+
+        # Requirements for a valid Weaviate class name:
+        # The collection name must begin with a capital letter.
+        # The name can only contain letters, numbers, and the underscore (_) character. Spaces are not allowed.
+
+        # Replace hyphens with underscores and keep only alphanumeric characters
+        name = re.sub(r"[^a-zA-Z0-9_]", "", collection_name.replace("-", "_"))
+        name = name.strip("_")
+
+        if not name:
+            raise ValueError(
+                "Could not sanitize collection name to be a valid Weaviate class name"
+            )
+
+        # Ensure it starts with a letter and is capitalized
+        if not name[0].isalpha():
+            name = "C" + name
+
+        return name[0].upper() + name[1:]
+
+    def has_collection(self, collection_name: str) -> bool:
+        sane_collection_name = self._sanitize_collection_name(collection_name)
+        return self.client.collections.exists(sane_collection_name)
+
+    def delete_collection(self, collection_name: str) -> None:
+        sane_collection_name = self._sanitize_collection_name(collection_name)
+        if self.client.collections.exists(sane_collection_name):
+            self.client.collections.delete(sane_collection_name)
+
+    def _create_collection(self, collection_name: str) -> None:
+        self.client.collections.create(
+            name=collection_name,
+            vector_config=weaviate.classes.config.Configure.Vectors.self_provided(),
+            properties=[
+                weaviate.classes.config.Property(
+                    name="text", data_type=weaviate.classes.config.DataType.TEXT
+                ),
+            ],
+        )
+
+    def insert(self, collection_name: str, items: List[VectorItem]) -> None:
+        sane_collection_name = self._sanitize_collection_name(collection_name)
+        if not self.client.collections.exists(sane_collection_name):
+            self._create_collection(sane_collection_name)
+
+        collection = self.client.collections.get(sane_collection_name)
+
+        with collection.batch.fixed_size(batch_size=100) as batch:
+            for item in items:
+                item_uuid = str(uuid.uuid4()) if not item["id"] else str(item["id"])
+
+                properties = {"text": item["text"]}
+                if item["metadata"]:
+                    clean_metadata = _convert_uuids_to_strings(
+                        process_metadata(item["metadata"])
+                    )
+                    clean_metadata.pop("text", None)
+                    properties.update(clean_metadata)
+
+                batch.add_object(
+                    properties=properties, uuid=item_uuid, vector=item["vector"]
+                )
+
+    def upsert(self, collection_name: str, items: List[VectorItem]) -> None:
+        sane_collection_name = self._sanitize_collection_name(collection_name)
+        if not self.client.collections.exists(sane_collection_name):
+            self._create_collection(sane_collection_name)
+
+        collection = self.client.collections.get(sane_collection_name)
+
+        with collection.batch.fixed_size(batch_size=100) as batch:
+            for item in items:
+                item_uuid = str(item["id"]) if item["id"] else None
+
+                properties = {"text": item["text"]}
+                if item["metadata"]:
+                    clean_metadata = _convert_uuids_to_strings(
+                        process_metadata(item["metadata"])
+                    )
+                    clean_metadata.pop("text", None)
+                    properties.update(clean_metadata)
+
+                batch.add_object(
+                    properties=properties, uuid=item_uuid, vector=item["vector"]
+                )
+
+    def search(
+        self, collection_name: str, vectors: List[List[Union[float, int]]], limit: int
+    ) -> Optional[SearchResult]:
+        sane_collection_name = self._sanitize_collection_name(collection_name)
+        if not self.client.collections.exists(sane_collection_name):
+            return None
+
+        collection = self.client.collections.get(sane_collection_name)
+
+        result_ids, result_documents, result_metadatas, result_distances = (
+            [],
+            [],
+            [],
+            [],
+        )
+
+        for vector_embedding in vectors:
+            try:
+                response = collection.query.near_vector(
+                    near_vector=vector_embedding,
+                    limit=limit,
+                    return_metadata=weaviate.classes.query.MetadataQuery(distance=True),
+                )
+
+                ids = [str(obj.uuid) for obj in response.objects]
+                documents = []
+                metadatas = []
+                distances = []
+
+                for obj in response.objects:
+                    properties = dict(obj.properties) if obj.properties else {}
+                    documents.append(properties.pop("text", ""))
+                    metadatas.append(_convert_uuids_to_strings(properties))
+
+                # Weaviate has cosine distance, 2 (worst) -> 0 (best). Re-ordering to 0 -> 1
+                raw_distances = [
+                    (
+                        obj.metadata.distance
+                        if obj.metadata and obj.metadata.distance
+                        else 2.0
+                    )
+                    for obj in response.objects
+                ]
+                distances = [(2 - dist) / 2 for dist in raw_distances]
+
+                result_ids.append(ids)
+                result_documents.append(documents)
+                result_metadatas.append(metadatas)
+                result_distances.append(distances)
+            except Exception:
+                result_ids.append([])
+                result_documents.append([])
+                result_metadatas.append([])
+                result_distances.append([])
+
+        return SearchResult(
+            **{
+                "ids": result_ids,
+                "documents": result_documents,
+                "metadatas": result_metadatas,
+                "distances": result_distances,
+            }
+        )
+
+    def query(
+        self, collection_name: str, filter: Dict, limit: Optional[int] = None
+    ) -> Optional[GetResult]:
+        sane_collection_name = self._sanitize_collection_name(collection_name)
+        if not self.client.collections.exists(sane_collection_name):
+            return None
+
+        collection = self.client.collections.get(sane_collection_name)
+
+        weaviate_filter = None
+        if filter:
+            for key, value in filter.items():
+                prop_filter = weaviate.classes.query.Filter.by_property(name=key).equal(
+                    value
+                )
+                weaviate_filter = (
+                    prop_filter
+                    if weaviate_filter is None
+                    else weaviate.classes.query.Filter.all_of(
+                        [weaviate_filter, prop_filter]
+                    )
+                )
+
+        try:
+            response = collection.query.fetch_objects(
+                filters=weaviate_filter, limit=limit
+            )
+
+            ids = [str(obj.uuid) for obj in response.objects]
+            documents = []
+            metadatas = []
+
+            for obj in response.objects:
+                properties = dict(obj.properties) if obj.properties else {}
+                documents.append(properties.pop("text", ""))
+                metadatas.append(_convert_uuids_to_strings(properties))
+
+            return GetResult(
+                **{
+                    "ids": [ids],
+                    "documents": [documents],
+                    "metadatas": [metadatas],
+                }
+            )
+        except Exception:
+            return None
+
+    def get(self, collection_name: str) -> Optional[GetResult]:
+        sane_collection_name = self._sanitize_collection_name(collection_name)
+        if not self.client.collections.exists(sane_collection_name):
+            return None
+
+        collection = self.client.collections.get(sane_collection_name)
+        ids, documents, metadatas = [], [], []
+
+        try:
+            for item in collection.iterator():
+                ids.append(str(item.uuid))
+                properties = dict(item.properties) if item.properties else {}
+                documents.append(properties.pop("text", ""))
+                metadatas.append(_convert_uuids_to_strings(properties))
+
+            if not ids:
+                return None
+
+            return GetResult(
+                **{
+                    "ids": [ids],
+                    "documents": [documents],
+                    "metadatas": [metadatas],
+                }
+            )
+        except Exception:
+            return None
+
+    def delete(
+        self,
+        collection_name: str,
+        ids: Optional[List[str]] = None,
+        filter: Optional[Dict] = None,
+    ) -> None:
+        sane_collection_name = self._sanitize_collection_name(collection_name)
+        if not self.client.collections.exists(sane_collection_name):
+            return
+
+        collection = self.client.collections.get(sane_collection_name)
+
+        try:
+            if ids:
+                for item_id in ids:
+                    collection.data.delete_by_id(uuid=item_id)
+            elif filter:
+                weaviate_filter = None
+                for key, value in filter.items():
+                    prop_filter = weaviate.classes.query.Filter.by_property(
+                        name=key
+                    ).equal(value)
+                    weaviate_filter = (
+                        prop_filter
+                        if weaviate_filter is None
+                        else weaviate.classes.query.Filter.all_of(
+                            [weaviate_filter, prop_filter]
+                        )
+                    )
+
+                if weaviate_filter:
+                    collection.data.delete_many(where=weaviate_filter)
+        except Exception:
+            pass
+
+    def reset(self) -> None:
+        try:
+            for collection_name in self.client.collections.list_all().keys():
+                self.client.collections.delete(collection_name)
+        except Exception:
+            pass
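A quick usage sketch for the new backend; the collection name, vector values, and metadata are placeholders, and the item layout follows the `VectorItem` shape used by the other clients in this module:

    import uuid

    client = WeaviateClient()
    client.insert(
        "user-memory",  # sanitized internally to a valid class name ("User_memory")
        [
            {
                "id": str(uuid.uuid4()),
                "text": "hello world",
                "vector": [0.1, 0.2, 0.3],
                "metadata": {"source": "notes.txt"},
            }
        ],
    )
    hits = client.search("user-memory", vectors=[[0.1, 0.2, 0.3]], limit=5)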
@@ -67,6 +67,10 @@ class Vector:
                 from open_webui.retrieval.vector.dbs.oracle23ai import Oracle23aiClient

                 return Oracle23aiClient()
+            case VectorType.WEAVIATE:
+                from open_webui.retrieval.vector.dbs.weaviate import WeaviateClient
+
+                return WeaviateClient()
             case _:
                 raise ValueError(f"Unsupported vector type: {vector_type}")
@@ -11,3 +11,4 @@ class VectorType(StrEnum):
     PGVECTOR = "pgvector"
     ORACLE23AI = "oracle23ai"
     S3VECTOR = "s3vector"
+    WEAVIATE = "weaviate"
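With the enum value and factory case wired up, selecting the backend presumably comes down to environment configuration via the existing vector-DB switch; the host and port values below are illustrative defaults, not values taken from this diff:

    VECTOR_DB=weaviate
    WEAVIATE_HTTP_HOST=localhost
    WEAVIATE_HTTP_PORT=8080
    WEAVIATE_GRPC_PORT=50051
    WEAVIATE_API_KEY=  # optional; auth is only attached when non-empty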
||||||
128
backend/open_webui/retrieval/web/azure.py
Normal file
128
backend/open_webui/retrieval/web/azure.py
Normal file
|
|
@ -0,0 +1,128 @@
|
||||||
|
import logging
|
||||||
|
from typing import Optional
|
||||||
|
from open_webui.retrieval.web.main import SearchResult, get_filtered_results
|
||||||
|
from open_webui.env import SRC_LOG_LEVELS
|
||||||
|
|
||||||
|
log = logging.getLogger(__name__)
|
||||||
|
log.setLevel(SRC_LOG_LEVELS["RAG"])
|
||||||
|
|
||||||
|
"""
|
||||||
|
Azure AI Search integration for Open WebUI.
|
||||||
|
Documentation: https://learn.microsoft.com/en-us/python/api/overview/azure/search-documents-readme?view=azure-python
|
||||||
|
|
||||||
|
Required package: azure-search-documents
|
||||||
|
Install: pip install azure-search-documents
|
||||||
|
"""
|
||||||
|
|
||||||
|
|
||||||
|
def search_azure(
|
||||||
|
api_key: str,
|
||||||
|
endpoint: str,
|
||||||
|
index_name: str,
|
||||||
|
query: str,
|
||||||
|
count: int,
|
||||||
|
filter_list: Optional[list[str]] = None,
|
||||||
|
) -> list[SearchResult]:
|
||||||
|
"""
|
||||||
|
Search using Azure AI Search.
|
||||||
|
|
||||||
|
Args:
|
||||||
|
api_key: Azure Search API key (query key or admin key)
|
||||||
|
endpoint: Azure Search service endpoint (e.g., https://myservice.search.windows.net)
|
||||||
|
index_name: Name of the search index to query
|
||||||
|
query: Search query string
|
||||||
|
count: Number of results to return
|
||||||
|
filter_list: Optional list of domains to filter results
|
||||||
|
|
||||||
|
Returns:
|
||||||
|
List of SearchResult objects with link, title, and snippet
|
||||||
|
"""
|
||||||
|
try:
|
||||||
|
from azure.core.credentials import AzureKeyCredential
|
||||||
|
from azure.search.documents import SearchClient
|
||||||
|
except ImportError:
|
||||||
|
log.error(
|
||||||
|
"azure-search-documents package is not installed. "
|
||||||
|
"Install it with: pip install azure-search-documents"
|
||||||
|
)
|
||||||
|
raise ImportError(
|
||||||
|
"azure-search-documents is required for Azure AI Search. "
|
||||||
|
"Install it with: pip install azure-search-documents"
|
||||||
|
)
|
||||||
|
|
||||||
|
try:
|
||||||
|
# Create search client with API key authentication
|
||||||
|
credential = AzureKeyCredential(api_key)
|
||||||
|
search_client = SearchClient(
|
||||||
|
endpoint=endpoint, index_name=index_name, credential=credential
|
||||||
|
)
|
||||||
|
|
||||||
|
# Perform the search
|
||||||
|
results = search_client.search(search_text=query, top=count)
|
||||||
|
|
||||||
|
# Convert results to list and extract fields
|
||||||
|
search_results = []
|
||||||
|
for result in results:
|
||||||
|
# Azure AI Search returns documents with custom schemas
|
||||||
|
# We need to extract common fields that might represent URL, title, and content
|
||||||
|
# Common field names to look for:
|
||||||
|
result_dict = dict(result)
|
||||||
|
|
||||||
|
# Try to find URL field (common names)
|
||||||
|
link = (
|
||||||
|
result_dict.get("url")
|
||||||
|
or result_dict.get("link")
|
||||||
|
or result_dict.get("uri")
|
||||||
|
or result_dict.get("metadata_storage_path")
|
||||||
|
or ""
|
||||||
|
)
|
||||||
|
|
||||||
|
# Try to find title field (common names)
|
||||||
|
title = (
|
||||||
|
result_dict.get("title")
|
||||||
|
or result_dict.get("name")
|
||||||
|
or result_dict.get("metadata_title")
|
||||||
|
or result_dict.get("metadata_storage_name")
|
||||||
|
or None
|
||||||
|
)
|
||||||
|
|
||||||
|
# Try to find content/snippet field (common names)
|
||||||
|
snippet = (
|
||||||
|
result_dict.get("content")
|
||||||
|
or result_dict.get("snippet")
|
||||||
|
or result_dict.get("description")
|
||||||
|
or result_dict.get("summary")
|
||||||
|
or result_dict.get("text")
|
||||||
|
or None
|
||||||
|
)
|
||||||
|
|
||||||
|
# Truncate snippet if too long
|
||||||
|
if snippet and len(snippet) > 500:
|
||||||
|
snippet = snippet[:497] + "..."
|
||||||
|
|
||||||
|
if link: # Only add if we found a valid link
|
||||||
|
search_results.append(
|
||||||
|
{
|
||||||
|
"link": link,
|
||||||
|
"title": title,
|
||||||
|
"snippet": snippet,
|
||||||
|
}
|
||||||
|
)
|
||||||
|
|
||||||
|
# Apply domain filtering if specified
|
||||||
|
if filter_list:
|
||||||
|
search_results = get_filtered_results(search_results, filter_list)
|
||||||
|
|
||||||
|
# Convert to SearchResult objects
|
||||||
|
return [
|
||||||
|
SearchResult(
|
||||||
|
link=result["link"],
|
||||||
|
title=result.get("title"),
|
||||||
|
snippet=result.get("snippet"),
|
||||||
|
)
|
||||||
|
for result in search_results
|
||||||
|
]
|
||||||
|
|
||||||
|
except Exception as ex:
|
||||||
|
log.error(f"Azure AI Search error: {ex}")
|
||||||
|
raise ex
|
||||||
|
|
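A hypothetical invocation of the new provider; the endpoint, index name, and key are placeholders, and the field extraction above adapts to whatever schema the target index defines:

    results = search_azure(
        api_key="<azure-query-key>",
        endpoint="https://myservice.search.windows.net",
        index_name="web-index",
        query="open webui rag",
        count=5,
        filter_list=["learn.microsoft.com"],
    )
    for r in results:
        print(r.link, r.title)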
@@ -2,27 +2,42 @@ import logging
 from typing import Optional, List

 import requests
-from open_webui.retrieval.web.main import SearchResult, get_filtered_results
+from fastapi import Request

 from open_webui.env import SRC_LOG_LEVELS

+from open_webui.retrieval.web.main import SearchResult, get_filtered_results
+from open_webui.utils.headers import include_user_info_headers


 log = logging.getLogger(__name__)
 log.setLevel(SRC_LOG_LEVELS["RAG"])


 def search_external(
+    request: Request,
     external_url: str,
     external_api_key: str,
     query: str,
     count: int,
     filter_list: Optional[List[str]] = None,
+    user=None,
 ) -> List[SearchResult]:
     try:
-        response = requests.post(
-            external_url,
-            headers={
+        headers = {
             "User-Agent": "Open WebUI (https://github.com/open-webui/open-webui) RAG Bot",
             "Authorization": f"Bearer {external_api_key}",
-            },
+        }
+        headers = include_user_info_headers(headers, user)
+
+        chat_id = getattr(request.state, "chat_id", None)
+        if chat_id:
+            headers["X-OpenWebUI-Chat-Id"] = str(chat_id)
+
+        response = requests.post(
+            external_url,
+            headers=headers,
             json={
                 "query": query,
                 "count": count,
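On the receiving end, an external search service can now correlate requests per conversation via the new header. A hypothetical FastAPI endpoint reading it:

    from typing import Optional
    from fastapi import FastAPI, Header

    app = FastAPI()

    @app.post("/search")
    async def search(x_openwebui_chat_id: Optional[str] = Header(default=None)):
        # FastAPI maps this parameter from the X-OpenWebUI-Chat-Id header;
        # it is absent outside chat contexts.
        return {"chat_id": x_openwebui_chat_id}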
@@ -1,11 +1,10 @@
 import logging
 from typing import Optional, List
-from urllib.parse import urljoin

-import requests
 from open_webui.retrieval.web.main import SearchResult, get_filtered_results
 from open_webui.env import SRC_LOG_LEVELS


 log = logging.getLogger(__name__)
 log.setLevel(SRC_LOG_LEVELS["RAG"])

@@ -18,27 +17,20 @@ def search_firecrawl(
     filter_list: Optional[List[str]] = None,
 ) -> List[SearchResult]:
     try:
-        firecrawl_search_url = urljoin(firecrawl_url, "/v1/search")
-        response = requests.post(
-            firecrawl_search_url,
-            headers={
-                "User-Agent": "Open WebUI (https://github.com/open-webui/open-webui) RAG Bot",
-                "Authorization": f"Bearer {firecrawl_api_key}",
-            },
-            json={
-                "query": query,
-                "limit": count,
-            },
+        from firecrawl import FirecrawlApp
+
+        firecrawl = FirecrawlApp(api_key=firecrawl_api_key, api_url=firecrawl_url)
+        response = firecrawl.search(
+            query=query, limit=count, ignore_invalid_urls=True, timeout=count * 3
         )
-        response.raise_for_status()
-        results = response.json().get("data", [])
+        results = response.web
         if filter_list:
             results = get_filtered_results(results, filter_list)
         results = [
             SearchResult(
-                link=result.get("url"),
-                title=result.get("title"),
-                snippet=result.get("description"),
+                link=result.url,
+                title=result.title,
+                snippet=result.description,
             )
             for result in results[:count]
         ]
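The handler now goes through the official SDK instead of hand-rolled HTTP. A standalone sketch of the same call path; the key and URL are placeholders, and the `response.web` result attributes follow the shape the new code relies on:

    from firecrawl import FirecrawlApp

    firecrawl = FirecrawlApp(api_key="<firecrawl-key>", api_url="https://api.firecrawl.dev")
    response = firecrawl.search(query="open webui", limit=5, ignore_invalid_urls=True, timeout=15)
    for item in response.web:
        print(item.url, item.title, item.description)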
@@ -15,6 +15,7 @@ def search_google_pse(
     query: str,
     count: int,
     filter_list: Optional[list[str]] = None,
+    referer: Optional[str] = None,
 ) -> list[SearchResult]:
     """Search using Google's Programmable Search Engine API and return the results as a list of SearchResult objects.
     Handles pagination for counts greater than 10.

@@ -30,7 +31,11 @@ def search_google_pse(
         list[SearchResult]: A list of SearchResult objects.
     """
     url = "https://www.googleapis.com/customsearch/v1"

     headers = {"Content-Type": "application/json"}
+    if referer:
+        headers["Referer"] = referer
+
     all_results = []
     start_index = 1  # Google PSE start parameter is 1-based
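The new parameter matters when the PSE API key is restricted to specific HTTP referrers; callers can now send a matching header. A hedged invocation, with the leading parameters assumed rather than taken from this hunk:

    results = search_google_pse(
        api_key="<pse-api-key>",          # assumed parameter
        search_engine_id="<engine-id>",   # assumed parameter
        query="open webui",
        count=10,
        referer="https://chat.example.com",
    )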
@@ -5,18 +5,38 @@ from urllib.parse import urlparse

 from pydantic import BaseModel

+from open_webui.retrieval.web.utils import resolve_hostname
+from open_webui.utils.misc import is_string_allowed


 def get_filtered_results(results, filter_list):
     if not filter_list:
         return results

     filtered_results = []

     for result in results:
         url = result.get("url") or result.get("link", "") or result.get("href", "")
         if not validators.url(url):
             continue

         domain = urlparse(url).netloc
-        if any(domain.endswith(filtered_domain) for filtered_domain in filter_list):
+        if not domain:
+            continue
+
+        hostnames = [domain]
+
+        try:
+            ipv4_addresses, ipv6_addresses = resolve_hostname(domain)
+            hostnames.extend(ipv4_addresses)
+            hostnames.extend(ipv6_addresses)
+        except Exception:
+            pass
+
+        if is_string_allowed(hostnames, filter_list):
             filtered_results.append(result)
+            continue

     return filtered_results
@@ -3,6 +3,7 @@ from typing import Optional, Literal
 import requests

 from open_webui.retrieval.web.main import SearchResult, get_filtered_results
+from open_webui.utils.headers import include_user_info_headers
 from open_webui.env import SRC_LOG_LEVELS


@@ -15,6 +16,8 @@ def search_perplexity_search(
     query: str,
     count: int,
     filter_list: Optional[list[str]] = None,
+    api_url: str = "https://api.perplexity.ai/search",
+    user=None,
 ) -> list[SearchResult]:
     """Search using Perplexity API and return the results as a list of SearchResult objects.

@@ -23,6 +26,8 @@ def search_perplexity_search(
         query (str): The query to search for
         count (int): Maximum number of results to return
         filter_list (Optional[list[str]]): List of domains to filter results
+        api_url (str): Custom API URL (defaults to https://api.perplexity.ai/search)
+        user: Optional user object for forwarding user info headers
     """

@@ -30,8 +35,11 @@ def search_perplexity_search(
     if hasattr(api_key, "__str__"):
         api_key = str(api_key)

+    if hasattr(api_url, "__str__"):
+        api_url = str(api_url)
+
     try:
-        url = "https://api.perplexity.ai/search"
+        url = api_url

         # Create payload for the API call
         payload = {

@@ -44,6 +52,10 @@ def search_perplexity_search(
         "Content-Type": "application/json",
     }

+    # Forward user info headers if user is provided
+    if user is not None:
+        headers = include_user_info_headers(headers, user)
+
     # Make the API request
     response = requests.request("POST", url, json=payload, headers=headers)
     # Parse the JSON response
@ -4,7 +4,6 @@ import socket
|
||||||
import ssl
|
import ssl
|
||||||
import urllib.parse
|
import urllib.parse
|
||||||
import urllib.request
|
import urllib.request
|
||||||
from collections import defaultdict
|
|
||||||
from datetime import datetime, time, timedelta
|
from datetime import datetime, time, timedelta
|
||||||
from typing import (
|
from typing import (
|
||||||
Any,
|
Any,
|
||||||
|
|
@ -17,13 +16,15 @@ from typing import (
|
||||||
Union,
|
Union,
|
||||||
Literal,
|
Literal,
|
||||||
)
|
)
|
||||||
|
|
||||||
|
from fastapi.concurrency import run_in_threadpool
|
||||||
import aiohttp
|
import aiohttp
|
||||||
import certifi
|
import certifi
|
||||||
import validators
|
import validators
|
||||||
from langchain_community.document_loaders import PlaywrightURLLoader, WebBaseLoader
|
from langchain_community.document_loaders import PlaywrightURLLoader, WebBaseLoader
|
||||||
from langchain_community.document_loaders.firecrawl import FireCrawlLoader
|
|
||||||
from langchain_community.document_loaders.base import BaseLoader
|
from langchain_community.document_loaders.base import BaseLoader
|
||||||
from langchain_core.documents import Document
|
from langchain_core.documents import Document
|
||||||
|
|
||||||
from open_webui.retrieval.loaders.tavily import TavilyLoader
|
from open_webui.retrieval.loaders.tavily import TavilyLoader
|
||||||
from open_webui.retrieval.loaders.external_web import ExternalWebLoader
|
from open_webui.retrieval.loaders.external_web import ExternalWebLoader
|
||||||
from open_webui.constants import ERROR_MESSAGES
|
from open_webui.constants import ERROR_MESSAGES
|
||||||
|
|
@ -38,17 +39,46 @@ from open_webui.config import (
|
||||||
TAVILY_EXTRACT_DEPTH,
|
TAVILY_EXTRACT_DEPTH,
|
||||||
EXTERNAL_WEB_LOADER_URL,
|
EXTERNAL_WEB_LOADER_URL,
|
||||||
EXTERNAL_WEB_LOADER_API_KEY,
|
EXTERNAL_WEB_LOADER_API_KEY,
|
||||||
|
WEB_FETCH_FILTER_LIST,
|
||||||
)
|
)
|
||||||
from open_webui.env import SRC_LOG_LEVELS, AIOHTTP_CLIENT_SESSION_SSL
|
from open_webui.env import SRC_LOG_LEVELS
|
||||||
|
from open_webui.utils.misc import is_string_allowed
|
||||||
|
|
||||||
log = logging.getLogger(__name__)
|
log = logging.getLogger(__name__)
|
||||||
log.setLevel(SRC_LOG_LEVELS["RAG"])
|
log.setLevel(SRC_LOG_LEVELS["RAG"])
|
||||||
|
|
||||||
|
|
||||||
|
def resolve_hostname(hostname):
|
||||||
|
# Get address information
|
||||||
|
addr_info = socket.getaddrinfo(hostname, None)
|
||||||
|
|
||||||
|
# Extract IP addresses from address information
|
||||||
|
ipv4_addresses = [info[4][0] for info in addr_info if info[0] == socket.AF_INET]
|
||||||
|
ipv6_addresses = [info[4][0] for info in addr_info if info[0] == socket.AF_INET6]
|
||||||
|
|
||||||
|
return ipv4_addresses, ipv6_addresses
|
||||||
|
|
||||||
|
|
||||||
def validate_url(url: Union[str, Sequence[str]]):
|
def validate_url(url: Union[str, Sequence[str]]):
|
||||||
if isinstance(url, str):
|
if isinstance(url, str):
|
||||||
if isinstance(validators.url(url), validators.ValidationError):
|
if isinstance(validators.url(url), validators.ValidationError):
|
||||||
raise ValueError(ERROR_MESSAGES.INVALID_URL)
|
raise ValueError(ERROR_MESSAGES.INVALID_URL)
|
||||||
|
|
||||||
|
parsed_url = urllib.parse.urlparse(url)
|
||||||
|
|
||||||
|
# Protocol validation - only allow http/https
|
||||||
|
if parsed_url.scheme not in ["http", "https"]:
|
||||||
|
log.warning(
|
||||||
|
f"Blocked non-HTTP(S) protocol: {parsed_url.scheme} in URL: {url}"
|
||||||
|
)
|
||||||
|
raise ValueError(ERROR_MESSAGES.INVALID_URL)
|
||||||
|
|
||||||
|
# Blocklist check using unified filtering logic
|
||||||
|
if WEB_FETCH_FILTER_LIST:
|
||||||
|
if not is_string_allowed(url, WEB_FETCH_FILTER_LIST):
|
||||||
|
log.warning(f"URL blocked by filter list: {url}")
|
||||||
|
raise ValueError(ERROR_MESSAGES.INVALID_URL)
|
||||||
|
|
||||||
if not ENABLE_RAG_LOCAL_WEB_FETCH:
|
if not ENABLE_RAG_LOCAL_WEB_FETCH:
|
||||||
# Local web fetch is disabled, filter out any URLs that resolve to private IP addresses
|
# Local web fetch is disabled, filter out any URLs that resolve to private IP addresses
|
||||||
parsed_url = urllib.parse.urlparse(url)
|
parsed_url = urllib.parse.urlparse(url)
|
||||||
|
|
@ -75,22 +105,12 @@ def safe_validate_urls(url: Sequence[str]) -> Sequence[str]:
|
||||||
try:
|
try:
|
||||||
if validate_url(u):
|
if validate_url(u):
|
||||||
valid_urls.append(u)
|
valid_urls.append(u)
|
||||||
except ValueError:
|
except Exception as e:
|
||||||
|
log.debug(f"Invalid URL {u}: {str(e)}")
|
||||||
continue
|
continue
|
||||||
return valid_urls
|
return valid_urls
|
||||||
|
|
||||||
|
|
||||||
def resolve_hostname(hostname):
|
|
||||||
# Get address information
|
|
||||||
addr_info = socket.getaddrinfo(hostname, None)
|
|
||||||
|
|
||||||
# Extract IP addresses from address information
|
|
||||||
ipv4_addresses = [info[4][0] for info in addr_info if info[0] == socket.AF_INET]
|
|
||||||
ipv6_addresses = [info[4][0] for info in addr_info if info[0] == socket.AF_INET6]
|
|
||||||
|
|
||||||
return ipv4_addresses, ipv6_addresses
|
|
||||||
|
|
||||||
|
|
||||||
def extract_metadata(soup, url):
|
def extract_metadata(soup, url):
|
||||||
metadata = {"source": url}
|
metadata = {"source": url}
|
||||||
if title := soup.find("title"):
|
if title := soup.find("title"):
|
||||||
|
|
@ -141,13 +161,13 @@ class RateLimitMixin:
|
||||||
|
|
||||||
|
|
||||||
class URLProcessingMixin:
|
class URLProcessingMixin:
|
||||||
def _verify_ssl_cert(self, url: str) -> bool:
|
async def _verify_ssl_cert(self, url: str) -> bool:
|
||||||
"""Verify SSL certificate for a URL."""
|
"""Verify SSL certificate for a URL."""
|
||||||
return verify_ssl_cert(url)
|
return await run_in_threadpool(verify_ssl_cert, url)
|
||||||
|
|
||||||
async def _safe_process_url(self, url: str) -> bool:
|
async def _safe_process_url(self, url: str) -> bool:
|
||||||
"""Perform safety checks before processing a URL."""
|
"""Perform safety checks before processing a URL."""
|
||||||
if self.verify_ssl and not self._verify_ssl_cert(url):
|
if self.verify_ssl and not await self._verify_ssl_cert(url):
|
||||||
raise ValueError(f"SSL certificate verification failed for {url}")
|
raise ValueError(f"SSL certificate verification failed for {url}")
|
||||||
await self._wait_for_rate_limit()
|
await self._wait_for_rate_limit()
|
||||||
return True
|
return True
|
||||||
|
|
@ -188,13 +208,12 @@ class SafeFireCrawlLoader(BaseLoader, RateLimitMixin, URLProcessingMixin):
|
||||||
(uses FIRE_CRAWL_API_KEY environment variable if not provided).
|
(uses FIRE_CRAWL_API_KEY environment variable if not provided).
|
||||||
api_url: Base URL for FireCrawl API. Defaults to official API endpoint.
|
api_url: Base URL for FireCrawl API. Defaults to official API endpoint.
|
||||||
mode: Operation mode selection:
|
mode: Operation mode selection:
|
||||||
- 'crawl': Website crawling mode (default)
|
- 'crawl': Website crawling mode
|
||||||
- 'scrape': Direct page scraping
|
- 'scrape': Direct page scraping (default)
|
||||||
- 'map': Site map generation
|
- 'map': Site map generation
|
||||||
proxy: Proxy override settings for the FireCrawl API.
|
proxy: Proxy override settings for the FireCrawl API.
|
||||||
params: The parameters to pass to the Firecrawl API.
|
params: The parameters to pass to the Firecrawl API.
|
||||||
Examples include crawlerOptions.
|
For more details, visit: https://docs.firecrawl.dev/sdks/python#batch-scrape
|
||||||
For more details, visit: https://github.com/mendableai/firecrawl-py
|
|
||||||
"""
|
"""
|
||||||
proxy_server = proxy.get("server") if proxy else None
|
proxy_server = proxy.get("server") if proxy else None
|
||||||
if trust_env and not proxy_server:
|
     if trust_env and not proxy_server:

@@ -214,50 +233,88 @@ class SafeFireCrawlLoader(BaseLoader, RateLimitMixin, URLProcessingMixin):
         self.api_key = api_key
         self.api_url = api_url
         self.mode = mode
-        self.params = params
+        self.params = params or {}

     def lazy_load(self) -> Iterator[Document]:
-        """Load documents concurrently using FireCrawl."""
-        for url in self.web_paths:
-            try:
-                self._safe_process_url_sync(url)
-                loader = FireCrawlLoader(
-                    url=url,
-                    api_key=self.api_key,
-                    api_url=self.api_url,
-                    mode=self.mode,
-                    params=self.params,
-                )
-                for document in loader.lazy_load():
-                    if not document.metadata.get("source"):
-                        document.metadata["source"] = document.metadata.get("sourceURL")
-                    yield document
-            except Exception as e:
-                if self.continue_on_failure:
-                    log.exception(f"Error loading {url}: {e}")
-                    continue
-                raise e
+        """Load documents using FireCrawl batch_scrape."""
+        log.debug(
+            "Starting FireCrawl batch scrape for %d URLs, mode: %s, params: %s",
+            len(self.web_paths),
+            self.mode,
+            self.params,
+        )
+        try:
+            from firecrawl import FirecrawlApp
+
+            firecrawl = FirecrawlApp(api_key=self.api_key, api_url=self.api_url)
+            result = firecrawl.batch_scrape(
+                self.web_paths,
+                formats=["markdown"],
+                skip_tls_verification=not self.verify_ssl,
+                ignore_invalid_urls=True,
+                remove_base64_images=True,
+                max_age=300000,  # 5 minutes https://docs.firecrawl.dev/features/fast-scraping#common-maxage-values
+                wait_timeout=len(self.web_paths) * 3,
+                **self.params,
+            )
+
+            if result.status != "completed":
+                raise RuntimeError(
+                    f"FireCrawl batch scrape did not complete successfully. result: {result}"
+                )
+
+            for data in result.data:
+                metadata = data.metadata or {}
+                yield Document(
+                    page_content=data.markdown or "",
+                    metadata={"source": metadata.url or metadata.source_url or ""},
+                )
+        except Exception as e:
+            if self.continue_on_failure:
+                log.exception(f"Error extracting content from URLs: {e}")
+            else:
+                raise e

     async def alazy_load(self):
         """Async version of lazy_load."""
-        for url in self.web_paths:
-            try:
-                await self._safe_process_url(url)
-                loader = FireCrawlLoader(
-                    url=url,
-                    api_key=self.api_key,
-                    api_url=self.api_url,
-                    mode=self.mode,
-                    params=self.params,
-                )
-                async for document in loader.alazy_load():
-                    if not document.metadata.get("source"):
-                        document.metadata["source"] = document.metadata.get("sourceURL")
-                    yield document
-            except Exception as e:
-                if self.continue_on_failure:
-                    log.exception(f"Error loading {url}: {e}")
-                    continue
-                raise e
+        log.debug(
+            "Starting FireCrawl batch scrape for %d URLs, mode: %s, params: %s",
+            len(self.web_paths),
+            self.mode,
+            self.params,
+        )
+        try:
+            from firecrawl import FirecrawlApp
+
+            firecrawl = FirecrawlApp(api_key=self.api_key, api_url=self.api_url)
+            result = firecrawl.batch_scrape(
+                self.web_paths,
+                formats=["markdown"],
+                skip_tls_verification=not self.verify_ssl,
+                ignore_invalid_urls=True,
+                remove_base64_images=True,
+                max_age=300000,  # 5 minutes https://docs.firecrawl.dev/features/fast-scraping#common-maxage-values
+                wait_timeout=len(self.web_paths) * 3,
+                **self.params,
+            )
+
+            if result.status != "completed":
+                raise RuntimeError(
+                    f"FireCrawl batch scrape did not complete successfully. result: {result}"
+                )
+
+            for data in result.data:
+                metadata = data.metadata or {}
+                yield Document(
+                    page_content=data.markdown or "",
+                    metadata={"source": metadata.url or metadata.source_url or ""},
+                )
+        except Exception as e:
+            if self.continue_on_failure:
+                log.exception(f"Error extracting content from URLs: {e}")
+            else:
+                raise e
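The rewrite above trades one FireCrawlLoader round trip per URL for a single batch_scrape call covering all of self.web_paths, then maps each batch result onto a Document. A minimal usage sketch under that reading; the constructor keyword names are assumed to mirror the attributes the class assigns, and the key and URLs are placeholders:

# Usage sketch for the batch-scrape loader above (keyword names and
# values are illustrative, not confirmed against the full class).
loader = SafeFireCrawlLoader(
    web_paths=["https://example.com/a", "https://example.com/b"],
    api_key="fc-...",                      # placeholder key
    api_url="https://api.firecrawl.dev",   # placeholder endpoint
    mode="scrape",
    params={},
    verify_ssl=True,
    continue_on_failure=True,
)

# One batch_scrape call now covers every URL, so iterating yields
# Documents for the whole batch instead of per-URL requests.
for doc in loader.lazy_load():
    print(doc.metadata["source"], len(doc.page_content))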
@@ -603,6 +660,10 @@ def get_web_loader(
     # Check if the URLs are valid
     safe_urls = safe_validate_urls([urls] if isinstance(urls, str) else urls)

+    if not safe_urls:
+        log.warning(f"All provided URLs were blocked or invalid: {urls}")
+        raise ValueError(ERROR_MESSAGES.INVALID_URL)
+
     web_loader_args = {
         "web_paths": safe_urls,
         "verify_ssl": verify_ssl,
@@ -3,6 +3,8 @@ import json
 import logging
 import os
 import uuid
+import html
+import base64
 from functools import lru_cache
 from pydub import AudioSegment
 from pydub.silence import split_on_silence

@@ -14,7 +16,6 @@ import aiohttp
 import aiofiles
 import requests
 import mimetypes
-from urllib.parse import urljoin, quote

 from fastapi import (
     Depends,

@@ -33,18 +34,20 @@ from pydantic import BaseModel


 from open_webui.utils.auth import get_admin_user, get_verified_user
+from open_webui.utils.headers import include_user_info_headers
 from open_webui.config import (
     WHISPER_MODEL_AUTO_UPDATE,
     WHISPER_MODEL_DIR,
     CACHE_DIR,
     WHISPER_LANGUAGE,
+    ELEVENLABS_API_BASE_URL,
 )

 from open_webui.constants import ERROR_MESSAGES
 from open_webui.env import (
+    ENV,
     AIOHTTP_CLIENT_SESSION_SSL,
     AIOHTTP_CLIENT_TIMEOUT,
-    ENV,
     SRC_LOG_LEVELS,
     DEVICE_TYPE,
     ENABLE_FORWARD_USER_INFO_HEADERS,

@@ -153,6 +156,7 @@ def set_faster_whisper_model(model: str, auto_update: bool = False):
 class TTSConfigForm(BaseModel):
     OPENAI_API_BASE_URL: str
     OPENAI_API_KEY: str
+    OPENAI_PARAMS: Optional[dict] = None
     API_KEY: str
     ENGINE: str
     MODEL: str

@@ -176,6 +180,9 @@ class STTConfigForm(BaseModel):
     AZURE_LOCALES: str
     AZURE_BASE_URL: str
     AZURE_MAX_SPEAKERS: str
+    MISTRAL_API_KEY: str
+    MISTRAL_API_BASE_URL: str
+    MISTRAL_USE_CHAT_COMPLETIONS: bool


 class AudioConfigUpdateForm(BaseModel):

@@ -189,6 +196,7 @@ async def get_audio_config(request: Request, user=Depends(get_admin_user)):
         "tts": {
             "OPENAI_API_BASE_URL": request.app.state.config.TTS_OPENAI_API_BASE_URL,
             "OPENAI_API_KEY": request.app.state.config.TTS_OPENAI_API_KEY,
+            "OPENAI_PARAMS": request.app.state.config.TTS_OPENAI_PARAMS,
             "API_KEY": request.app.state.config.TTS_API_KEY,
             "ENGINE": request.app.state.config.TTS_ENGINE,
             "MODEL": request.app.state.config.TTS_MODEL,

@@ -211,6 +219,9 @@ async def get_audio_config(request: Request, user=Depends(get_admin_user)):
             "AZURE_LOCALES": request.app.state.config.AUDIO_STT_AZURE_LOCALES,
             "AZURE_BASE_URL": request.app.state.config.AUDIO_STT_AZURE_BASE_URL,
             "AZURE_MAX_SPEAKERS": request.app.state.config.AUDIO_STT_AZURE_MAX_SPEAKERS,
+            "MISTRAL_API_KEY": request.app.state.config.AUDIO_STT_MISTRAL_API_KEY,
+            "MISTRAL_API_BASE_URL": request.app.state.config.AUDIO_STT_MISTRAL_API_BASE_URL,
+            "MISTRAL_USE_CHAT_COMPLETIONS": request.app.state.config.AUDIO_STT_MISTRAL_USE_CHAT_COMPLETIONS,
         },
     }

@@ -221,6 +232,7 @@ async def update_audio_config(
 ):
     request.app.state.config.TTS_OPENAI_API_BASE_URL = form_data.tts.OPENAI_API_BASE_URL
     request.app.state.config.TTS_OPENAI_API_KEY = form_data.tts.OPENAI_API_KEY
+    request.app.state.config.TTS_OPENAI_PARAMS = form_data.tts.OPENAI_PARAMS
     request.app.state.config.TTS_API_KEY = form_data.tts.API_KEY
     request.app.state.config.TTS_ENGINE = form_data.tts.ENGINE
     request.app.state.config.TTS_MODEL = form_data.tts.MODEL

@@ -251,6 +263,13 @@ async def update_audio_config(
     request.app.state.config.AUDIO_STT_AZURE_MAX_SPEAKERS = (
         form_data.stt.AZURE_MAX_SPEAKERS
     )
+    request.app.state.config.AUDIO_STT_MISTRAL_API_KEY = form_data.stt.MISTRAL_API_KEY
+    request.app.state.config.AUDIO_STT_MISTRAL_API_BASE_URL = (
+        form_data.stt.MISTRAL_API_BASE_URL
+    )
+    request.app.state.config.AUDIO_STT_MISTRAL_USE_CHAT_COMPLETIONS = (
+        form_data.stt.MISTRAL_USE_CHAT_COMPLETIONS
+    )

     if request.app.state.config.STT_ENGINE == "":
         request.app.state.faster_whisper_model = set_faster_whisper_model(

@@ -261,12 +280,13 @@ async def update_audio_config(

     return {
         "tts": {
-            "OPENAI_API_BASE_URL": request.app.state.config.TTS_OPENAI_API_BASE_URL,
-            "OPENAI_API_KEY": request.app.state.config.TTS_OPENAI_API_KEY,
-            "API_KEY": request.app.state.config.TTS_API_KEY,
             "ENGINE": request.app.state.config.TTS_ENGINE,
             "MODEL": request.app.state.config.TTS_MODEL,
             "VOICE": request.app.state.config.TTS_VOICE,
+            "OPENAI_API_BASE_URL": request.app.state.config.TTS_OPENAI_API_BASE_URL,
+            "OPENAI_API_KEY": request.app.state.config.TTS_OPENAI_API_KEY,
+            "OPENAI_PARAMS": request.app.state.config.TTS_OPENAI_PARAMS,
+            "API_KEY": request.app.state.config.TTS_API_KEY,
             "SPLIT_ON": request.app.state.config.TTS_SPLIT_ON,
             "AZURE_SPEECH_REGION": request.app.state.config.TTS_AZURE_SPEECH_REGION,
             "AZURE_SPEECH_BASE_URL": request.app.state.config.TTS_AZURE_SPEECH_BASE_URL,

@@ -285,6 +305,9 @@ async def update_audio_config(
             "AZURE_LOCALES": request.app.state.config.AUDIO_STT_AZURE_LOCALES,
             "AZURE_BASE_URL": request.app.state.config.AUDIO_STT_AZURE_BASE_URL,
             "AZURE_MAX_SPEAKERS": request.app.state.config.AUDIO_STT_AZURE_MAX_SPEAKERS,
+            "MISTRAL_API_KEY": request.app.state.config.AUDIO_STT_MISTRAL_API_KEY,
+            "MISTRAL_API_BASE_URL": request.app.state.config.AUDIO_STT_MISTRAL_API_BASE_URL,
+            "MISTRAL_USE_CHAT_COMPLETIONS": request.app.state.config.AUDIO_STT_MISTRAL_USE_CHAT_COMPLETIONS,
         },
     }
@@ -336,23 +359,22 @@ async def speech(request: Request, user=Depends(get_verified_user)):
         async with aiohttp.ClientSession(
             timeout=timeout, trust_env=True
         ) as session:
+            payload = {
+                **payload,
+                **(request.app.state.config.TTS_OPENAI_PARAMS or {}),
+            }
+
+            headers = {
+                "Content-Type": "application/json",
+                "Authorization": f"Bearer {request.app.state.config.TTS_OPENAI_API_KEY}",
+            }
+            if ENABLE_FORWARD_USER_INFO_HEADERS:
+                headers = include_user_info_headers(headers, user)
+
             r = await session.post(
                 url=f"{request.app.state.config.TTS_OPENAI_API_BASE_URL}/audio/speech",
                 json=payload,
-                headers={
-                    "Content-Type": "application/json",
-                    "Authorization": f"Bearer {request.app.state.config.TTS_OPENAI_API_KEY}",
-                    **(
-                        {
-                            "X-OpenWebUI-User-Name": quote(user.name, safe=" "),
-                            "X-OpenWebUI-User-Id": user.id,
-                            "X-OpenWebUI-User-Email": user.email,
-                            "X-OpenWebUI-User-Role": user.role,
-                        }
-                        if ENABLE_FORWARD_USER_INFO_HEADERS
-                        else {}
-                    ),
-                },
+                headers=headers,
                 ssl=AIOHTTP_CLIENT_SESSION_SSL,
             )
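Because the configured TTS_OPENAI_PARAMS dict is spread after the request payload, admin-configured keys win over whatever the client sent. A minimal sketch of that merge semantics; the values are hypothetical:

payload = {"model": "tts-1", "voice": "alloy", "speed": 1.0}
admin_params = {"speed": 1.25}  # stand-in for TTS_OPENAI_PARAMS

merged = {**payload, **(admin_params or {})}
assert merged["speed"] == 1.25  # the later spread wins on key collisions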
@@ -403,7 +425,7 @@ async def speech(request: Request, user=Depends(get_verified_user)):
             timeout=timeout, trust_env=True
         ) as session:
             async with session.post(
-                f"https://api.elevenlabs.io/v1/text-to-speech/{voice_id}",
+                f"{ELEVENLABS_API_BASE_URL}/v1/text-to-speech/{voice_id}",
                 json={
                     "text": payload["input"],
                     "model_id": request.app.state.config.TTS_MODEL,
@@ -458,7 +480,7 @@ async def speech(request: Request, user=Depends(get_verified_user)):

         try:
             data = f"""<speak version="1.0" xmlns="http://www.w3.org/2001/10/synthesis" xml:lang="{locale}">
-                <voice name="{language}">{payload["input"]}</voice>
+                <voice name="{language}">{html.escape(payload["input"])}</voice>
             </speak>"""
             timeout = aiohttp.ClientTimeout(total=AIOHTTP_CLIENT_TIMEOUT)
             async with aiohttp.ClientSession(
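Escaping the user's text before interpolating it into the SSML document prevents the input from being parsed as markup (or from breaking the XML outright). A short self-contained illustration; the voice name is a placeholder:

import html

user_input = 'Hello <break time="10s"/> & goodbye'  # hypothetical user text

# Unescaped, the string would be interpreted as SSML; escaped, it is read literally.
ssml = f'<voice name="en-US-Placeholder">{html.escape(user_input)}</voice>'
# -> <voice ...>Hello &lt;break time=&quot;10s&quot;/&gt; &amp; goodbye</voice>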
@@ -542,7 +564,7 @@ async def speech(request: Request, user=Depends(get_verified_user)):
     return FileResponse(file_path)


-def transcription_handler(request, file_path, metadata):
+def transcription_handler(request, file_path, metadata, user=None):
     filename = os.path.basename(file_path)
     file_dir = os.path.dirname(file_path)
     id = filename.split(".")[0]

@@ -593,11 +615,15 @@ def transcription_handler(request, file_path, metadata):
         if language:
             payload["language"] = language

+        headers = {
+            "Authorization": f"Bearer {request.app.state.config.STT_OPENAI_API_KEY}"
+        }
+        if user and ENABLE_FORWARD_USER_INFO_HEADERS:
+            headers = include_user_info_headers(headers, user)
+
         r = requests.post(
             url=f"{request.app.state.config.STT_OPENAI_API_BASE_URL}/audio/transcriptions",
-            headers={
-                "Authorization": f"Bearer {request.app.state.config.STT_OPENAI_API_KEY}"
-            },
+            headers=headers,
             files={"file": (filename, open(file_path, "rb"))},
             data=payload,
         )
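Both TTS and STT now route the X-OpenWebUI-* forwarding through the shared include_user_info_headers helper instead of an inline dict. A sketch of what such a helper plausibly looks like, mirroring the inline dict the speech() hunk removes; the real implementation lives in open_webui.utils.headers and may differ:

from urllib.parse import quote

def include_user_info_headers(headers: dict, user) -> dict:
    # Sketch only: reproduces the X-OpenWebUI-* fields the diff removed
    # from the inline headers dict; not confirmed against the real helper.
    return {
        **headers,
        "X-OpenWebUI-User-Name": quote(user.name, safe=" "),
        "X-OpenWebUI-User-Id": user.id,
        "X-OpenWebUI-User-Email": user.email,
        "X-OpenWebUI-User-Role": user.role,
    }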
@@ -818,8 +844,190 @@
                 detail=detail if detail else "Open WebUI: Server Connection Error",
             )

+    elif request.app.state.config.STT_ENGINE == "mistral":
+        # Check file exists
+        if not os.path.exists(file_path):
+            raise HTTPException(status_code=400, detail="Audio file not found")
+
+        # Check file size
+        file_size = os.path.getsize(file_path)
+        if file_size > MAX_FILE_SIZE:
+            raise HTTPException(
+                status_code=400,
+                detail=f"File size exceeds limit of {MAX_FILE_SIZE_MB}MB",
+            )
+
+        api_key = request.app.state.config.AUDIO_STT_MISTRAL_API_KEY
+        api_base_url = (
+            request.app.state.config.AUDIO_STT_MISTRAL_API_BASE_URL
+            or "https://api.mistral.ai/v1"
+        )
+        use_chat_completions = (
+            request.app.state.config.AUDIO_STT_MISTRAL_USE_CHAT_COMPLETIONS
+        )
+
+        if not api_key:
+            raise HTTPException(
+                status_code=400,
+                detail="Mistral API key is required for Mistral STT",
+            )
+
+        r = None
+        try:
+            # Use voxtral-mini-latest as the default model for transcription
+            model = request.app.state.config.STT_MODEL or "voxtral-mini-latest"
+
+            log.info(
+                f"Mistral STT - model: {model}, "
+                f"method: {'chat_completions' if use_chat_completions else 'transcriptions'}"
+            )
+
+            if use_chat_completions:
+                # Use chat completions API with audio input
+                # This method requires mp3 or wav format
+                audio_file_to_use = file_path
+
+                if is_audio_conversion_required(file_path):
+                    log.debug("Converting audio to mp3 for chat completions API")
+                    converted_path = convert_audio_to_mp3(file_path)
+                    if converted_path:
+                        audio_file_to_use = converted_path
+                    else:
+                        log.error("Audio conversion failed")
+                        raise HTTPException(
+                            status_code=500,
+                            detail="Audio conversion failed. Chat completions API requires mp3 or wav format.",
+                        )
+
+                # Read and encode audio file as base64
+                with open(audio_file_to_use, "rb") as audio_file:
+                    audio_base64 = base64.b64encode(audio_file.read()).decode("utf-8")
+
+                # Prepare chat completions request
+                url = f"{api_base_url}/chat/completions"
+
+                # Add language instruction if specified
+                language = metadata.get("language", None) if metadata else None
+                if language:
+                    text_instruction = f"Transcribe this audio exactly as spoken in {language}. Do not translate it."
+                else:
+                    text_instruction = "Transcribe this audio exactly as spoken in its original language. Do not translate it to another language."
+
+                payload = {
+                    "model": model,
+                    "messages": [
+                        {
+                            "role": "user",
+                            "content": [
+                                {
+                                    "type": "input_audio",
+                                    "input_audio": audio_base64,
+                                },
+                                {"type": "text", "text": text_instruction},
+                            ],
+                        }
+                    ],
+                }
+
+                r = requests.post(
+                    url=url,
+                    json=payload,
+                    headers={
+                        "Authorization": f"Bearer {api_key}",
+                        "Content-Type": "application/json",
+                    },
+                )
+
+                r.raise_for_status()
+                response = r.json()
+
+                # Extract transcript from chat completion response
+                transcript = (
+                    response.get("choices", [{}])[0]
+                    .get("message", {})
+                    .get("content", "")
+                    .strip()
+                )
+                if not transcript:
+                    raise ValueError("Empty transcript in response")
+
+                data = {"text": transcript}
+
+            else:
+                # Use dedicated transcriptions API
+                url = f"{api_base_url}/audio/transcriptions"
+
+                # Determine the MIME type
+                mime_type, _ = mimetypes.guess_type(file_path)
+                if not mime_type:
+                    mime_type = "audio/webm"
+
+                # Use context manager to ensure file is properly closed
+                with open(file_path, "rb") as audio_file:
+                    files = {"file": (filename, audio_file, mime_type)}
+                    data_form = {"model": model}
+
+                    # Add language if specified in metadata
+                    language = metadata.get("language", None) if metadata else None
+                    if language:
+                        data_form["language"] = language
+
+                    r = requests.post(
+                        url=url,
+                        files=files,
+                        data=data_form,
+                        headers={
+                            "Authorization": f"Bearer {api_key}",
+                        },
+                    )
+
+                r.raise_for_status()
+                response = r.json()
+
+                # Extract transcript from response
+                transcript = response.get("text", "").strip()
+                if not transcript:
+                    raise ValueError("Empty transcript in response")
+
+                data = {"text": transcript}
+
+            # Save transcript to json file (consistent with other providers)
+            transcript_file = f"{file_dir}/{id}.json"
+            with open(transcript_file, "w") as f:
+                json.dump(data, f)
+
+            log.debug(data)
+            return data
+
+        except ValueError as e:
+            log.exception("Error parsing Mistral response")
+            raise HTTPException(
+                status_code=500,
+                detail=f"Failed to parse Mistral response: {str(e)}",
+            )
+        except requests.exceptions.RequestException as e:
+            log.exception(e)
+            detail = None
+
+            try:
+                if r is not None and r.status_code != 200:
+                    res = r.json()
+                    if "error" in res:
+                        detail = f"External: {res['error'].get('message', '')}"
+                    else:
+                        detail = f"External: {r.text}"
+            except Exception:
+                detail = f"External: {e}"
+
+            raise HTTPException(
+                status_code=getattr(r, "status_code", 500) if r else 500,
+                detail=detail if detail else "Open WebUI: Server Connection Error",
+            )
+
+
-def transcribe(request: Request, file_path: str, metadata: Optional[dict] = None):
+def transcribe(
+    request: Request, file_path: str, metadata: Optional[dict] = None, user=None
+):
     log.info(f"transcribe: {file_path} {metadata}")

     if is_audio_conversion_required(file_path):
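For reference, a standalone sketch of the chat-completions transcription request the Mistral branch above assembles: the audio is base64-encoded into an "input_audio" content part alongside a text instruction. The payload shape mirrors the diff; the API key and file path are placeholders:

import base64
import requests

with open("sample.mp3", "rb") as f:
    audio_b64 = base64.b64encode(f.read()).decode("utf-8")

r = requests.post(
    "https://api.mistral.ai/v1/chat/completions",
    headers={
        "Authorization": "Bearer <MISTRAL_API_KEY>",  # placeholder
        "Content-Type": "application/json",
    },
    json={
        "model": "voxtral-mini-latest",
        "messages": [{
            "role": "user",
            "content": [
                {"type": "input_audio", "input_audio": audio_b64},
                {"type": "text", "text": "Transcribe this audio exactly as spoken."},
            ],
        }],
    },
)
# The transcript comes back as an ordinary chat completion message.
print(r.json()["choices"][0]["message"]["content"])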
@@ -846,7 +1054,9 @@ def transcribe(request: Request, file_path: str, metadata: Optional[dict] = None
     with ThreadPoolExecutor() as executor:
         # Submit tasks for each chunk_path
         futures = [
-            executor.submit(transcription_handler, request, chunk_path, metadata)
+            executor.submit(
+                transcription_handler, request, chunk_path, metadata, user
+            )
             for chunk_path in chunk_paths
         ]
         # Gather results as they complete

@@ -981,7 +1191,7 @@ def transcription(
     if language:
         metadata = {"language": language}

-    result = transcribe(request, file_path, metadata)
+    result = transcribe(request, file_path, metadata, user)

     return {
         **result,

@@ -1027,7 +1237,7 @@ def get_available_models(request: Request) -> list[dict]:
     elif request.app.state.config.TTS_ENGINE == "elevenlabs":
         try:
             response = requests.get(
-                "https://api.elevenlabs.io/v1/models",
+                f"{ELEVENLABS_API_BASE_URL}/v1/models",
                 headers={
                     "xi-api-key": request.app.state.config.TTS_API_KEY,
                     "Content-Type": "application/json",

@@ -1131,7 +1341,7 @@ def get_elevenlabs_voices(api_key: str) -> dict:
     try:
         # TODO: Add retries
         response = requests.get(
-            "https://api.elevenlabs.io/v1/voices",
+            f"{ELEVENLABS_API_BASE_URL}/v1/voices",
             headers={
                 "xi-api-key": api_key,
                 "Content-Type": "application/json",

@@ -4,6 +4,8 @@ import time
 import datetime
 import logging
 from aiohttp import ClientSession
+import urllib


 from open_webui.models.auths import (
     AddUserForm,

@@ -15,9 +17,13 @@ from open_webui.models.auths import (
     SigninResponse,
     SignupForm,
     UpdatePasswordForm,
-    UserResponse,
 )
-from open_webui.models.users import Users, UpdateProfileForm
+from open_webui.models.users import (
+    UserProfileImageResponse,
+    Users,
+    UpdateProfileForm,
+    UserStatus,
+)
 from open_webui.models.groups import Groups
 from open_webui.models.oauth_sessions import OAuthSessions

@@ -35,12 +41,20 @@ from open_webui.env import (
 )
 from fastapi import APIRouter, Depends, HTTPException, Request, status
 from fastapi.responses import RedirectResponse, Response, JSONResponse
-from open_webui.config import OPENID_PROVIDER_URL, ENABLE_OAUTH_SIGNUP, ENABLE_LDAP
+from open_webui.config import (
+    OPENID_PROVIDER_URL,
+    ENABLE_OAUTH_SIGNUP,
+    ENABLE_LDAP,
+    ENABLE_PASSWORD_AUTH,
+)
 from pydantic import BaseModel

 from open_webui.utils.misc import parse_duration, validate_email_format
 from open_webui.utils.auth import (
+    validate_password,
+    verify_password,
     decode_token,
+    invalidate_token,
     create_api_key,
     create_token,
     get_admin_user,

@@ -50,7 +64,12 @@ from open_webui.utils.auth import (
     get_http_authorization_cred,
 )
 from open_webui.utils.webhook import post_webhook
-from open_webui.utils.access_control import get_permissions
+from open_webui.utils.access_control import get_permissions, has_permission
+from open_webui.utils.groups import apply_default_group_assignment
+
+from open_webui.utils.redis import get_redis_client
+from open_webui.utils.rate_limit import RateLimiter
+

 from typing import Optional, List
@@ -64,17 +83,21 @@ router = APIRouter()
 log = logging.getLogger(__name__)
 log.setLevel(SRC_LOG_LEVELS["MAIN"])

+signin_rate_limiter = RateLimiter(
+    redis_client=get_redis_client(), limit=5 * 3, window=60 * 3
+)
+
 ############################
 # GetSessionUser
 ############################


-class SessionUserResponse(Token, UserResponse):
+class SessionUserResponse(Token, UserProfileImageResponse):
     expires_at: Optional[int] = None
     permissions: Optional[dict] = None


-class SessionUserInfoResponse(SessionUserResponse):
+class SessionUserInfoResponse(SessionUserResponse, UserStatus):
     bio: Optional[str] = None
     gender: Optional[str] = None
     date_of_birth: Optional[datetime.date] = None
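The module-level limiter above allows at most 15 sign-in attempts (5 * 3) per 180-second window (60 * 3), keyed later by lowercased email. The real RateLimiter in open_webui.utils.rate_limit is Redis-backed and its internals are not shown here; the following is only a sketch of the fixed-window contract visible in the diff (constructor kwargs plus an is_limited(key) check):

import time

class FixedWindowRateLimiterSketch:
    # Assumption-laden stand-in for open_webui.utils.rate_limit.RateLimiter;
    # mirrors only the constructor/is_limited usage shown above.
    def __init__(self, redis_client=None, limit=15, window=180):
        self.redis = redis_client  # unused in this in-memory sketch
        self.limit = limit
        self.window = window
        self._counts = {}

    def is_limited(self, key: str) -> bool:
        # Count attempts within the current fixed window for this key.
        bucket = int(time.time() // self.window)
        count = self._counts.get((key, bucket), 0) + 1
        self._counts[(key, bucket)] = count
        return count > self.limit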
@@ -131,6 +154,9 @@ async def get_session_user(
         "bio": user.bio,
         "gender": user.gender,
         "date_of_birth": user.date_of_birth,
+        "status_emoji": user.status_emoji,
+        "status_message": user.status_message,
+        "status_expires_at": user.status_expires_at,
         "permissions": user_permissions,
     }

@@ -140,7 +166,7 @@ async def get_session_user(
 ############################


-@router.post("/update/profile", response_model=UserResponse)
+@router.post("/update/profile", response_model=UserProfileImageResponse)
 async def update_profile(
     form_data: UpdateProfileForm, session_user=Depends(get_verified_user)
 ):

@@ -169,13 +195,19 @@ async def update_password(
     if WEBUI_AUTH_TRUSTED_EMAIL_HEADER:
         raise HTTPException(400, detail=ERROR_MESSAGES.ACTION_PROHIBITED)
     if session_user:
-        user = Auths.authenticate_user(session_user.email, form_data.password)
+        user = Auths.authenticate_user(
+            session_user.email, lambda pw: verify_password(form_data.password, pw)
+        )
+
         if user:
+            try:
+                validate_password(form_data.password)
+            except Exception as e:
+                raise HTTPException(400, detail=str(e))
             hashed = get_password_hash(form_data.new_password)
             return Auths.update_user_password_by_id(user.id, hashed)
         else:
-            raise HTTPException(400, detail=ERROR_MESSAGES.INVALID_PASSWORD)
+            raise HTTPException(400, detail=ERROR_MESSAGES.INCORRECT_PASSWORD)
     else:
         raise HTTPException(400, detail=ERROR_MESSAGES.INVALID_CRED)
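Note the inversion in the authenticate_user contract above: the caller no longer hands over the plaintext password; instead it supplies a callable that is given the stored hash and returns True on a match, keeping hash comparison logic out of the auth-table layer. A self-contained sketch of that contract with illustrative stand-in names:

from typing import Optional

USERS = {"user@example.com": {"id": "u1", "hash": "hash(secret)"}}  # toy store

def fake_verify(candidate: str, stored_hash: str) -> bool:
    # Stand-in for open_webui.utils.auth.verify_password (bcrypt in reality).
    return stored_hash == f"hash({candidate})"

def authenticate_user(email: str, verify) -> Optional[dict]:
    # Mirrors the new contract: the verifier receives only the stored hash.
    record = USERS.get(email)
    if record and verify(record["hash"]):
        return {"id": record["id"], "email": email}
    return None

# The caller closes over the candidate password, as in the hunk above:
user = authenticate_user("user@example.com", lambda pw: fake_verify("secret", pw))
assert user is not None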
@@ -185,7 +217,17 @@ async def update_password(
 ############################
 @router.post("/ldap", response_model=SessionUserResponse)
 async def ldap_auth(request: Request, response: Response, form_data: LdapForm):
-    ENABLE_LDAP = request.app.state.config.ENABLE_LDAP
+    # Security checks FIRST - before loading any config
+    if not request.app.state.config.ENABLE_LDAP:
+        raise HTTPException(400, detail="LDAP authentication is not enabled")
+
+    if not ENABLE_PASSWORD_AUTH:
+        raise HTTPException(
+            status_code=status.HTTP_403_FORBIDDEN,
+            detail=ERROR_MESSAGES.ACTION_PROHIBITED,
+        )
+
+    # NOW load LDAP config variables
     LDAP_SERVER_LABEL = request.app.state.config.LDAP_SERVER_LABEL
     LDAP_SERVER_HOST = request.app.state.config.LDAP_SERVER_HOST
     LDAP_SERVER_PORT = request.app.state.config.LDAP_SERVER_PORT

@@ -206,9 +248,6 @@ async def ldap_auth(request: Request, response: Response, form_data: LdapForm):
         else "ALL"
     )

-    if not ENABLE_LDAP:
-        raise HTTPException(400, detail="LDAP authentication is not enabled")
-
     try:
         tls = Tls(
             validate=LDAP_VALIDATE_CERT,

@@ -379,6 +418,11 @@ async def ldap_auth(request: Request, response: Response, form_data: LdapForm):
                         500, detail=ERROR_MESSAGES.CREATE_USER_ERROR
                     )

+                apply_default_group_assignment(
+                    request.app.state.config.DEFAULT_GROUP_ID,
+                    user.id,
+                )
+
             except HTTPException:
                 raise
             except Exception as err:

@@ -427,7 +471,6 @@ async def ldap_auth(request: Request, response: Response, form_data: LdapForm):
         ):
             if ENABLE_LDAP_GROUP_CREATION:
                 Groups.create_groups_by_group_names(user.id, user_groups)
-
             try:
                 Groups.sync_groups_by_group_names(user.id, user_groups)
                 log.info(

@@ -463,6 +506,12 @@ async def ldap_auth(request: Request, response: Response, form_data: LdapForm):

 @router.post("/signin", response_model=SessionUserResponse)
 async def signin(request: Request, response: Response, form_data: SigninForm):
+    if not ENABLE_PASSWORD_AUTH:
+        raise HTTPException(
+            status_code=status.HTTP_403_FORBIDDEN,
+            detail=ERROR_MESSAGES.ACTION_PROHIBITED,
+        )
+
     if WEBUI_AUTH_TRUSTED_EMAIL_HEADER:
         if WEBUI_AUTH_TRUSTED_EMAIL_HEADER not in request.headers:
             raise HTTPException(400, detail=ERROR_MESSAGES.INVALID_TRUSTED_HEADER)

@@ -472,6 +521,10 @@ async def signin(request: Request, response: Response, form_data: SigninForm):

         if WEBUI_AUTH_TRUSTED_NAME_HEADER:
             name = request.headers.get(WEBUI_AUTH_TRUSTED_NAME_HEADER, email)
+            try:
+                name = urllib.parse.unquote(name, encoding="utf-8")
+            except Exception as e:
+                pass

         if not Users.get_user_by_email(email.lower()):
             await signup(

@@ -495,7 +548,9 @@ async def signin(request: Request, response: Response, form_data: SigninForm):
             admin_password = "admin"

         if Users.get_user_by_email(admin_email.lower()):
-            user = Auths.authenticate_user(admin_email.lower(), admin_password)
+            user = Auths.authenticate_user(
+                admin_email.lower(), lambda pw: verify_password(admin_password, pw)
+            )
         else:
             if Users.has_users():
                 raise HTTPException(400, detail=ERROR_MESSAGES.EXISTING_USERS)

@@ -506,9 +561,28 @@ async def signin(request: Request, response: Response, form_data: SigninForm):
                 SignupForm(email=admin_email, password=admin_password, name="User"),
             )

-            user = Auths.authenticate_user(admin_email.lower(), admin_password)
+            user = Auths.authenticate_user(
+                admin_email.lower(), lambda pw: verify_password(admin_password, pw)
+            )
     else:
-        user = Auths.authenticate_user(form_data.email.lower(), form_data.password)
+        if signin_rate_limiter.is_limited(form_data.email.lower()):
+            raise HTTPException(
+                status_code=status.HTTP_429_TOO_MANY_REQUESTS,
+                detail=ERROR_MESSAGES.RATE_LIMIT_EXCEEDED,
+            )
+
+        password_bytes = form_data.password.encode("utf-8")
+        if len(password_bytes) > 72:
+            # TODO: Implement other hashing algorithms that support longer passwords
+            log.info("Password too long, truncating to 72 bytes for bcrypt")
+            password_bytes = password_bytes[:72]
+
+        # decode safely - ignore incomplete UTF-8 sequences
+        form_data.password = password_bytes.decode("utf-8", errors="ignore")
+
+        user = Auths.authenticate_user(
+            form_data.email.lower(), lambda pw: verify_password(form_data.password, pw)
+        )

     if user:
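Why the truncation above decodes with errors="ignore": bcrypt only considers the first 72 bytes of its input, and cutting a UTF-8 string at a fixed byte offset can split a multi-byte character in half. A short illustration:

password = "p" * 71 + "é"            # 73 bytes in UTF-8; bcrypt caps input at 72
raw = password.encode("utf-8")[:72]  # the cut lands inside the 2-byte "é"

# raw.decode("utf-8") would raise UnicodeDecodeError on the dangling byte
safe = raw.decode("utf-8", errors="ignore")
assert safe == "p" * 71              # the half-character is silently dropped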
@@ -590,16 +664,14 @@ async def signup(request: Request, response: Response, form_data: SignupForm):
         raise HTTPException(400, detail=ERROR_MESSAGES.EMAIL_TAKEN)

     try:
-        role = "admin" if not has_users else request.app.state.config.DEFAULT_USER_ROLE
-
-        # The password passed to bcrypt must be 72 bytes or fewer. If it is longer, it will be truncated before hashing.
-        if len(form_data.password.encode("utf-8")) > 72:
-            raise HTTPException(
-                status.HTTP_400_BAD_REQUEST,
-                detail=ERROR_MESSAGES.PASSWORD_TOO_LONG,
-            )
+        try:
+            validate_password(form_data.password)
+        except Exception as e:
+            raise HTTPException(400, detail=str(e))

         hashed = get_password_hash(form_data.password)
+
+        role = "admin" if not has_users else request.app.state.config.DEFAULT_USER_ROLE
         user = Auths.insert_new_auth(
             form_data.email.lower(),
             hashed,

@@ -655,6 +727,11 @@ async def signup(request: Request, response: Response, form_data: SignupForm):
             # Disable signup after the first user is created
             request.app.state.config.ENABLE_SIGNUP = False

+            apply_default_group_assignment(
+                request.app.state.config.DEFAULT_GROUP_ID,
+                user.id,
+            )
+
         return {
             "token": token,
             "token_type": "Bearer",

@@ -675,6 +752,19 @@ async def signup(request: Request, response: Response, form_data: SignupForm):

 @router.get("/signout")
 async def signout(request: Request, response: Response):
+
+    # get auth token from headers or cookies
+    token = None
+    auth_header = request.headers.get("Authorization")
+    if auth_header:
+        auth_cred = get_http_authorization_cred(auth_header)
+        token = auth_cred.credentials
+    else:
+        token = request.cookies.get("token")
+
+    if token:
+        await invalidate_token(request, token)
+
     response.delete_cookie("token")
     response.delete_cookie("oui-session")
     response.delete_cookie("oauth_id_token")

@@ -745,7 +835,9 @@ async def signout(request: Request, response: Response):


 @router.post("/add", response_model=SigninResponse)
-async def add_user(form_data: AddUserForm, user=Depends(get_admin_user)):
+async def add_user(
+    request: Request, form_data: AddUserForm, user=Depends(get_admin_user)
+):
     if not validate_email_format(form_data.email.lower()):
         raise HTTPException(
             status.HTTP_400_BAD_REQUEST, detail=ERROR_MESSAGES.INVALID_EMAIL_FORMAT

@@ -755,6 +847,11 @@ async def add_user(
         raise HTTPException(400, detail=ERROR_MESSAGES.EMAIL_TAKEN)

     try:
+        try:
+            validate_password(form_data.password)
+        except Exception as e:
+            raise HTTPException(400, detail=str(e))
+
         hashed = get_password_hash(form_data.password)
         user = Auths.insert_new_auth(
             form_data.email.lower(),

@@ -765,6 +862,11 @@ async def add_user(
         )

         if user:
+            apply_default_group_assignment(
+                request.app.state.config.DEFAULT_GROUP_ID,
+                user.id,
+            )
+
             token = create_token(data={"id": user.id})
             return {
                 "token": token,

@@ -826,13 +928,15 @@ async def get_admin_config(request: Request, user=Depends(get_admin_user)):
         "SHOW_ADMIN_DETAILS": request.app.state.config.SHOW_ADMIN_DETAILS,
         "WEBUI_URL": request.app.state.config.WEBUI_URL,
         "ENABLE_SIGNUP": request.app.state.config.ENABLE_SIGNUP,
-        "ENABLE_API_KEY": request.app.state.config.ENABLE_API_KEY,
-        "ENABLE_API_KEY_ENDPOINT_RESTRICTIONS": request.app.state.config.ENABLE_API_KEY_ENDPOINT_RESTRICTIONS,
-        "API_KEY_ALLOWED_ENDPOINTS": request.app.state.config.API_KEY_ALLOWED_ENDPOINTS,
+        "ENABLE_API_KEYS": request.app.state.config.ENABLE_API_KEYS,
+        "ENABLE_API_KEYS_ENDPOINT_RESTRICTIONS": request.app.state.config.ENABLE_API_KEYS_ENDPOINT_RESTRICTIONS,
+        "API_KEYS_ALLOWED_ENDPOINTS": request.app.state.config.API_KEYS_ALLOWED_ENDPOINTS,
         "DEFAULT_USER_ROLE": request.app.state.config.DEFAULT_USER_ROLE,
+        "DEFAULT_GROUP_ID": request.app.state.config.DEFAULT_GROUP_ID,
         "JWT_EXPIRES_IN": request.app.state.config.JWT_EXPIRES_IN,
         "ENABLE_COMMUNITY_SHARING": request.app.state.config.ENABLE_COMMUNITY_SHARING,
         "ENABLE_MESSAGE_RATING": request.app.state.config.ENABLE_MESSAGE_RATING,
+        "ENABLE_FOLDERS": request.app.state.config.ENABLE_FOLDERS,
         "ENABLE_CHANNELS": request.app.state.config.ENABLE_CHANNELS,
         "ENABLE_NOTES": request.app.state.config.ENABLE_NOTES,
         "ENABLE_USER_WEBHOOKS": request.app.state.config.ENABLE_USER_WEBHOOKS,

@@ -846,13 +950,15 @@ class AdminConfig(BaseModel):
     SHOW_ADMIN_DETAILS: bool
     WEBUI_URL: str
     ENABLE_SIGNUP: bool
-    ENABLE_API_KEY: bool
-    ENABLE_API_KEY_ENDPOINT_RESTRICTIONS: bool
-    API_KEY_ALLOWED_ENDPOINTS: str
+    ENABLE_API_KEYS: bool
+    ENABLE_API_KEYS_ENDPOINT_RESTRICTIONS: bool
+    API_KEYS_ALLOWED_ENDPOINTS: str
     DEFAULT_USER_ROLE: str
+    DEFAULT_GROUP_ID: str
     JWT_EXPIRES_IN: str
     ENABLE_COMMUNITY_SHARING: bool
     ENABLE_MESSAGE_RATING: bool
+    ENABLE_FOLDERS: bool
     ENABLE_CHANNELS: bool
     ENABLE_NOTES: bool
     ENABLE_USER_WEBHOOKS: bool

@@ -869,20 +975,23 @@ async def update_admin_config(
     request.app.state.config.WEBUI_URL = form_data.WEBUI_URL
     request.app.state.config.ENABLE_SIGNUP = form_data.ENABLE_SIGNUP

-    request.app.state.config.ENABLE_API_KEY = form_data.ENABLE_API_KEY
-    request.app.state.config.ENABLE_API_KEY_ENDPOINT_RESTRICTIONS = (
-        form_data.ENABLE_API_KEY_ENDPOINT_RESTRICTIONS
-    )
-    request.app.state.config.API_KEY_ALLOWED_ENDPOINTS = (
-        form_data.API_KEY_ALLOWED_ENDPOINTS
-    )
+    request.app.state.config.ENABLE_API_KEYS = form_data.ENABLE_API_KEYS
+    request.app.state.config.ENABLE_API_KEYS_ENDPOINT_RESTRICTIONS = (
+        form_data.ENABLE_API_KEYS_ENDPOINT_RESTRICTIONS
+    )
+    request.app.state.config.API_KEYS_ALLOWED_ENDPOINTS = (
+        form_data.API_KEYS_ALLOWED_ENDPOINTS
+    )

+    request.app.state.config.ENABLE_FOLDERS = form_data.ENABLE_FOLDERS
     request.app.state.config.ENABLE_CHANNELS = form_data.ENABLE_CHANNELS
     request.app.state.config.ENABLE_NOTES = form_data.ENABLE_NOTES

     if form_data.DEFAULT_USER_ROLE in ["pending", "user", "admin"]:
         request.app.state.config.DEFAULT_USER_ROLE = form_data.DEFAULT_USER_ROLE

+    request.app.state.config.DEFAULT_GROUP_ID = form_data.DEFAULT_GROUP_ID
+
     pattern = r"^(-1|0|(-?\d+(\.\d+)?)(ms|s|m|h|d|w))$"

     # Check if the input string matches the pattern

@@ -909,13 +1018,15 @@ async def update_admin_config(
         "SHOW_ADMIN_DETAILS": request.app.state.config.SHOW_ADMIN_DETAILS,
         "WEBUI_URL": request.app.state.config.WEBUI_URL,
         "ENABLE_SIGNUP": request.app.state.config.ENABLE_SIGNUP,
-        "ENABLE_API_KEY": request.app.state.config.ENABLE_API_KEY,
-        "ENABLE_API_KEY_ENDPOINT_RESTRICTIONS": request.app.state.config.ENABLE_API_KEY_ENDPOINT_RESTRICTIONS,
-        "API_KEY_ALLOWED_ENDPOINTS": request.app.state.config.API_KEY_ALLOWED_ENDPOINTS,
+        "ENABLE_API_KEYS": request.app.state.config.ENABLE_API_KEYS,
+        "ENABLE_API_KEYS_ENDPOINT_RESTRICTIONS": request.app.state.config.ENABLE_API_KEYS_ENDPOINT_RESTRICTIONS,
+        "API_KEYS_ALLOWED_ENDPOINTS": request.app.state.config.API_KEYS_ALLOWED_ENDPOINTS,
         "DEFAULT_USER_ROLE": request.app.state.config.DEFAULT_USER_ROLE,
+        "DEFAULT_GROUP_ID": request.app.state.config.DEFAULT_GROUP_ID,
         "JWT_EXPIRES_IN": request.app.state.config.JWT_EXPIRES_IN,
         "ENABLE_COMMUNITY_SHARING": request.app.state.config.ENABLE_COMMUNITY_SHARING,
         "ENABLE_MESSAGE_RATING": request.app.state.config.ENABLE_MESSAGE_RATING,
+        "ENABLE_FOLDERS": request.app.state.config.ENABLE_FOLDERS,
         "ENABLE_CHANNELS": request.app.state.config.ENABLE_CHANNELS,
         "ENABLE_NOTES": request.app.state.config.ENABLE_NOTES,
         "ENABLE_USER_WEBHOOKS": request.app.state.config.ENABLE_USER_WEBHOOKS,

@@ -1036,9 +1147,11 @@ async def update_ldap_config(
 # create api key
 @router.post("/api_key", response_model=ApiKey)
 async def generate_api_key(request: Request, user=Depends(get_current_user)):
-    if not request.app.state.config.ENABLE_API_KEY:
+    if not request.app.state.config.ENABLE_API_KEYS or not has_permission(
+        user.id, "features.api_keys", request.app.state.config.USER_PERMISSIONS
+    ):
         raise HTTPException(
-            status.HTTP_403_FORBIDDEN,
+            status_code=status.HTTP_403_FORBIDDEN,
             detail=ERROR_MESSAGES.API_KEY_CREATION_NOT_ALLOWED,
         )

@@ -1056,8 +1169,7 @@ async def generate_api_key(request: Request, user=Depends(get_current_user)):
 # delete api key
 @router.delete("/api_key", response_model=bool)
 async def delete_api_key(user=Depends(get_current_user)):
-    success = Users.update_user_api_key_by_id(user.id, None)
-    return success
+    return Users.delete_user_api_key_by_id(user.id)


 # get api key

@@ -7,8 +7,20 @@ from fastapi import APIRouter, Depends, HTTPException, Request, status, Backgrou
 from pydantic import BaseModel


-from open_webui.socket.main import sio, get_user_ids_from_room
-from open_webui.models.users import Users, UserNameResponse
+from open_webui.socket.main import (
+    emit_to_users,
+    enter_room_for_users,
+    sio,
+    get_user_ids_from_room,
+)
+from open_webui.models.users import (
+    UserIdNameResponse,
+    UserIdNameStatusResponse,
+    UserListResponse,
+    UserModelResponse,
+    Users,
+    UserNameResponse,
+)

 from open_webui.models.groups import Groups
 from open_webui.models.channels import (

@@ -16,11 +28,13 @@ from open_webui.models.channels import (
     ChannelModel,
     ChannelForm,
     ChannelResponse,
+    CreateChannelForm,
 )
 from open_webui.models.messages import (
     Messages,
     MessageModel,
     MessageResponse,
+    MessageWithReactionsResponse,
     MessageForm,
 )

@@ -38,7 +52,12 @@ from open_webui.utils.chat import generate_chat_completion


 from open_webui.utils.auth import get_admin_user, get_verified_user
-from open_webui.utils.access_control import has_access, get_users_with_access
+from open_webui.utils.access_control import (
+    has_access,
+    get_users_with_access,
+    get_permitted_group_and_user_ids,
+    has_permission,
+)
 from open_webui.utils.webhook import post_webhook
 from open_webui.utils.channels import extract_mentions, replace_mentions
@@ -52,9 +71,64 @@ router = APIRouter()
 ############################


-@router.get("/", response_model=list[ChannelModel])
-async def get_channels(user=Depends(get_verified_user)):
-    return Channels.get_channels_by_user_id(user.id)
+class ChannelListItemResponse(ChannelModel):
+    user_ids: Optional[list[str]] = None  # 'dm' channels only
+    users: Optional[list[UserIdNameStatusResponse]] = None  # 'dm' channels only
+
+    last_message_at: Optional[int] = None  # timestamp in epoch (time_ns)
+    unread_count: int = 0
+
+
+@router.get("/", response_model=list[ChannelListItemResponse])
+async def get_channels(request: Request, user=Depends(get_verified_user)):
+    if user.role != "admin" and not has_permission(
+        user.id, "features.channels", request.app.state.config.USER_PERMISSIONS
+    ):
+        raise HTTPException(
+            status_code=status.HTTP_401_UNAUTHORIZED,
+            detail=ERROR_MESSAGES.UNAUTHORIZED,
+        )
+
+    channels = Channels.get_channels_by_user_id(user.id)
+    channel_list = []
+    for channel in channels:
+        last_message = Messages.get_last_message_by_channel_id(channel.id)
+        last_message_at = last_message.created_at if last_message else None
+
+        channel_member = Channels.get_member_by_channel_and_user_id(channel.id, user.id)
+        unread_count = (
+            Messages.get_unread_message_count(
+                channel.id, user.id, channel_member.last_read_at
+            )
+            if channel_member
+            else 0
+        )
+
+        user_ids = None
+        users = None
+        if channel.type == "dm":
+            user_ids = [
+                member.user_id
+                for member in Channels.get_members_by_channel_id(channel.id)
+            ]
+            users = [
+                UserIdNameStatusResponse(
+                    **{**user.model_dump(), "is_active": Users.is_user_active(user.id)}
+                )
+                for user in Users.get_users_by_user_ids(user_ids)
+            ]
+
+        channel_list.append(
+            ChannelListItemResponse(
+                **channel.model_dump(),
+                user_ids=user_ids,
+                users=users,
+                last_message_at=last_message_at,
+                unread_count=unread_count,
+            )
+        )
+
+    return channel_list


 @router.get("/list", response_model=list[ChannelModel])
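For reference, a sketch of the JSON shape one list item now carries after the change above; the field set is inferred from the response model and the UserIdNameStatusResponse name, and the ids and timestamps are made up:

# Illustrative item from GET /api/v1/channels/ after this change.
{
    "id": "ch-123",
    "type": "dm",
    "name": "",
    "user_ids": ["u-1", "u-2"],
    "users": [{"id": "u-2", "name": "Alice", "is_active": True}],
    "last_message_at": 1733900000000000000,  # time_ns epoch
    "unread_count": 3,
}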
@ -64,16 +138,141 @@ async def get_all_channels(user=Depends(get_verified_user)):
|
||||||
return Channels.get_channels_by_user_id(user.id)
|
return Channels.get_channels_by_user_id(user.id)
|
||||||
|
|
||||||
|
|
||||||
|
############################
|
||||||
|
# GetDMChannelByUserId
|
||||||
|
############################
|
||||||
|
|
||||||
|
|
||||||
|
@router.get("/users/{user_id}", response_model=Optional[ChannelModel])
|
||||||
|
async def get_dm_channel_by_user_id(
|
||||||
|
request: Request, user_id: str, user=Depends(get_verified_user)
|
||||||
|
):
|
||||||
|
if user.role != "admin" and not has_permission(
|
||||||
|
user.id, "features.channels", request.app.state.config.USER_PERMISSIONS
|
||||||
|
):
|
||||||
|
raise HTTPException(
|
||||||
|
status_code=status.HTTP_401_UNAUTHORIZED,
|
||||||
|
detail=ERROR_MESSAGES.UNAUTHORIZED,
|
||||||
|
)
|
||||||
|
|
||||||
|
try:
|
||||||
|
existing_channel = Channels.get_dm_channel_by_user_ids([user.id, user_id])
|
||||||
|
if existing_channel:
|
||||||
|
participant_ids = [
|
||||||
|
member.user_id
|
||||||
|
for member in Channels.get_members_by_channel_id(existing_channel.id)
|
||||||
|
]
|
||||||
|
|
||||||
|
await emit_to_users(
|
||||||
|
"events:channel",
|
||||||
|
{"data": {"type": "channel:created"}},
|
||||||
|
participant_ids,
|
||||||
|
)
|
||||||
|
await enter_room_for_users(
|
||||||
|
f"channel:{existing_channel.id}", participant_ids
|
||||||
|
)
|
||||||
|
|
||||||
|
Channels.update_member_active_status(existing_channel.id, user.id, True)
|
||||||
|
return ChannelModel(**existing_channel.model_dump())
|
||||||
|
|
||||||
|
channel = Channels.insert_new_channel(
|
||||||
|
CreateChannelForm(
|
||||||
|
type="dm",
|
||||||
|
name="",
|
||||||
|
user_ids=[user_id],
|
||||||
|
),
|
||||||
|
user.id,
|
||||||
|
)
|
||||||
|
|
||||||
|
if channel:
|
||||||
|
participant_ids = [
|
||||||
|
member.user_id
|
||||||
|
for member in Channels.get_members_by_channel_id(channel.id)
|
||||||
|
]
|
||||||
|
|
||||||
|
await emit_to_users(
|
||||||
|
"events:channel",
|
||||||
|
{"data": {"type": "channel:created"}},
|
||||||
|
participant_ids,
|
||||||
|
)
|
||||||
|
await enter_room_for_users(f"channel:{channel.id}", participant_ids)
|
||||||
|
|
||||||
|
return ChannelModel(**channel.model_dump())
|
||||||
|
else:
|
||||||
|
raise Exception("Error creating channel")
|
||||||
|
except Exception as e:
|
||||||
|
log.exception(e)
|
||||||
|
raise HTTPException(
|
||||||
|
status_code=status.HTTP_400_BAD_REQUEST, detail=ERROR_MESSAGES.DEFAULT()
|
||||||
|
)
|
||||||
|
|
||||||
|
|
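
The handler above is get-or-create: requesting a DM channel for the same peer twice should yield the same channel id, with the second call merely re-activating membership. A sketch under the same assumed base URL and token as before:

import requests

BASE_URL = "http://localhost:8080"  # assumed
HEADERS = {"Authorization": "Bearer sk-..."}  # placeholder

peer_id = "some-user-id"  # hypothetical peer
url = f"{BASE_URL}/api/v1/channels/users/{peer_id}"
first = requests.get(url, headers=HEADERS).json()
second = requests.get(url, headers=HEADERS).json()
assert first["id"] == second["id"]  # the existing DM channel is reused
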
 ############################
 # CreateNewChannel
 ############################


 @router.post("/create", response_model=Optional[ChannelModel])
-async def create_new_channel(form_data: ChannelForm, user=Depends(get_admin_user)):
+async def create_new_channel(
+    request: Request, form_data: CreateChannelForm, user=Depends(get_verified_user)
+):
+    if user.role != "admin" and not has_permission(
+        user.id, "features.channels", request.app.state.config.USER_PERMISSIONS
+    ):
+        raise HTTPException(
+            status_code=status.HTTP_401_UNAUTHORIZED,
+            detail=ERROR_MESSAGES.UNAUTHORIZED,
+        )
+
+    if form_data.type not in ["group", "dm"] and user.role != "admin":
+        # Only admins can create standard channels (joined by default)
+        raise HTTPException(
+            status_code=status.HTTP_401_UNAUTHORIZED,
+            detail=ERROR_MESSAGES.UNAUTHORIZED,
+        )
+
     try:
-        channel = Channels.insert_new_channel(None, form_data, user.id)
+        if form_data.type == "dm":
+            existing_channel = Channels.get_dm_channel_by_user_ids(
+                [user.id, *form_data.user_ids]
+            )
+            if existing_channel:
+                participant_ids = [
+                    member.user_id
+                    for member in Channels.get_members_by_channel_id(
+                        existing_channel.id
+                    )
+                ]
+                await emit_to_users(
+                    "events:channel",
+                    {"data": {"type": "channel:created"}},
+                    participant_ids,
+                )
+                await enter_room_for_users(
+                    f"channel:{existing_channel.id}", participant_ids
+                )
+
+                Channels.update_member_active_status(existing_channel.id, user.id, True)
+                return ChannelModel(**existing_channel.model_dump())
+
+        channel = Channels.insert_new_channel(form_data, user.id)
+
+        if channel:
+            participant_ids = [
+                member.user_id
+                for member in Channels.get_members_by_channel_id(channel.id)
+            ]
+
+            await emit_to_users(
+                "events:channel",
+                {"data": {"type": "channel:created"}},
+                participant_ids,
+            )
+            await enter_room_for_users(f"channel:{channel.id}", participant_ids)
+
             return ChannelModel(**channel.model_dump())
+        else:
+            raise Exception("Error creating channel")
     except Exception as e:
         log.exception(e)
         raise HTTPException(
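
A sketch of the new creation route from a non-admin account; per the guard above, only "group" and "dm" types are accepted for such users. The request fields mirror CreateChannelForm as used in this diff (type, name, user_ids); any further fields of that form are not shown here:

import requests

BASE_URL = "http://localhost:8080"  # assumed
HEADERS = {"Authorization": "Bearer sk-..."}  # placeholder

resp = requests.post(
    f"{BASE_URL}/api/v1/channels/create",
    headers=HEADERS,
    json={"type": "group", "name": "project-x", "user_ids": ["u1", "u2"]},
)
resp.raise_for_status()
print(resp.json()["id"])
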
@@ -86,7 +285,15 @@ async def create_new_channel(form_data: ChannelForm, user=Depends(get_admin_user
 ############################


-@router.get("/{id}", response_model=Optional[ChannelResponse])
+class ChannelFullResponse(ChannelResponse):
+    user_ids: Optional[list[str]] = None  # 'group'/'dm' channels only
+    users: Optional[list[UserIdNameStatusResponse]] = None  # 'group'/'dm' channels only
+
+    last_read_at: Optional[int] = None  # timestamp in epoch (time_ns)
+    unread_count: int = 0
+
+
+@router.get("/{id}", response_model=Optional[ChannelFullResponse])
 async def get_channel_by_id(id: str, user=Depends(get_verified_user)):
     channel = Channels.get_channel_by_id(id)
     if not channel:

@@ -94,6 +301,44 @@ async def get_channel_by_id(id: str, user=Depends(get_verified_user)):
             status_code=status.HTTP_404_NOT_FOUND, detail=ERROR_MESSAGES.NOT_FOUND
         )
+
+    user_ids = None
+    users = None
+
+    if channel.type in ["group", "dm"]:
+        if not Channels.is_user_channel_member(channel.id, user.id):
+            raise HTTPException(
+                status_code=status.HTTP_403_FORBIDDEN, detail=ERROR_MESSAGES.DEFAULT()
+            )
+
+        user_ids = [
+            member.user_id for member in Channels.get_members_by_channel_id(channel.id)
+        ]
+
+        users = [
+            UserIdNameStatusResponse(
+                **{**user.model_dump(), "is_active": Users.is_user_active(user.id)}
+            )
+            for user in Users.get_users_by_user_ids(user_ids)
+        ]
+
+        channel_member = Channels.get_member_by_channel_and_user_id(channel.id, user.id)
+        unread_count = Messages.get_unread_message_count(
+            channel.id, user.id, channel_member.last_read_at if channel_member else None
+        )
+
+        return ChannelFullResponse(
+            **{
+                **channel.model_dump(),
+                "user_ids": user_ids,
+                "users": users,
+                "is_manager": Channels.is_user_channel_manager(channel.id, user.id),
+                "write_access": True,
+                "user_count": len(user_ids),
+                "last_read_at": channel_member.last_read_at if channel_member else None,
+                "unread_count": unread_count,
+            }
+        )
+    else:
         if user.role != "admin" and not has_access(
             user.id, type="read", access_control=channel.access_control
         ):
@@ -105,14 +350,240 @@ async def get_channel_by_id(id: str, user=Depends(get_verified_user)):
             user.id, type="write", access_control=channel.access_control, strict=False
         )

-        return ChannelResponse(
+        user_count = len(get_users_with_access("read", channel.access_control))
+
+        channel_member = Channels.get_member_by_channel_and_user_id(channel.id, user.id)
+        unread_count = Messages.get_unread_message_count(
+            channel.id, user.id, channel_member.last_read_at if channel_member else None
+        )
+
+        return ChannelFullResponse(
             **{
                 **channel.model_dump(),
+                "user_ids": user_ids,
+                "users": users,
+                "is_manager": Channels.is_user_channel_manager(channel.id, user.id),
                 "write_access": write_access or user.role == "admin",
+                "user_count": user_count,
+                "last_read_at": channel_member.last_read_at if channel_member else None,
+                "unread_count": unread_count,
             }
         )
+
+
+############################
+# GetChannelMembersById
+############################
+
+
+PAGE_ITEM_COUNT = 30
+
+
+@router.get("/{id}/members", response_model=UserListResponse)
+async def get_channel_members_by_id(
+    id: str,
+    query: Optional[str] = None,
+    order_by: Optional[str] = None,
+    direction: Optional[str] = None,
+    page: Optional[int] = 1,
+    user=Depends(get_verified_user),
+):
+
+    channel = Channels.get_channel_by_id(id)
+    if not channel:
+        raise HTTPException(
+            status_code=status.HTTP_404_NOT_FOUND, detail=ERROR_MESSAGES.NOT_FOUND
+        )
+
+    limit = PAGE_ITEM_COUNT
+
+    page = max(1, page)
+    skip = (page - 1) * limit
+
+    if channel.type in ["group", "dm"]:
+        if not Channels.is_user_channel_member(channel.id, user.id):
+            raise HTTPException(
+                status_code=status.HTTP_403_FORBIDDEN, detail=ERROR_MESSAGES.DEFAULT()
+            )
+
+    if channel.type == "dm":
+        user_ids = [
+            member.user_id for member in Channels.get_members_by_channel_id(channel.id)
+        ]
+        users = Users.get_users_by_user_ids(user_ids)
+        total = len(users)
+
+        return {
+            "users": [
+                UserModelResponse(
+                    **user.model_dump(), is_active=Users.is_user_active(user.id)
+                )
+                for user in users
+            ],
+            "total": total,
+        }
+    else:
+        filter = {}
+
+        if query:
+            filter["query"] = query
+        if order_by:
+            filter["order_by"] = order_by
+        if direction:
+            filter["direction"] = direction
+
+        if channel.type == "group":
+            filter["channel_id"] = channel.id
+        else:
+            filter["roles"] = ["!pending"]
+            permitted_ids = get_permitted_group_and_user_ids(
+                "read", channel.access_control
+            )
+            if permitted_ids:
+                filter["user_ids"] = permitted_ids.get("user_ids")
+                filter["group_ids"] = permitted_ids.get("group_ids")
+
+        result = Users.get_users(filter=filter, skip=skip, limit=limit)
+
+        users = result["users"]
+        total = result["total"]
+
+        return {
+            "users": [
+                UserModelResponse(
+                    **user.model_dump(), is_active=Users.is_user_active(user.id)
+                )
+                for user in users
+            ],
+            "total": total,
+        }
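
Member listing pages through results with a fixed PAGE_ITEM_COUNT of 30; the skip/limit arithmetic is simple enough to mirror client-side when building a pager:

PAGE_ITEM_COUNT = 30  # mirrors the constant in the handler above

def page_window(page: int, limit: int = PAGE_ITEM_COUNT) -> tuple[int, int]:
    # Same arithmetic as the endpoint: clamp to page 1, then offset.
    page = max(1, page)
    return (page - 1) * limit, limit

assert page_window(0) == (0, 30)   # out-of-range pages clamp to the first page
assert page_window(4) == (90, 30)  # page 4 covers items 91-120
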
+#################################################
+# UpdateIsActiveMemberByIdAndUserId
+#################################################
+
+
+class UpdateActiveMemberForm(BaseModel):
+    is_active: bool
+
+
+@router.post("/{id}/members/active", response_model=bool)
+async def update_is_active_member_by_id_and_user_id(
+    id: str,
+    form_data: UpdateActiveMemberForm,
+    user=Depends(get_verified_user),
+):
+    channel = Channels.get_channel_by_id(id)
+    if not channel:
+        raise HTTPException(
+            status_code=status.HTTP_404_NOT_FOUND, detail=ERROR_MESSAGES.NOT_FOUND
+        )
+
+    if not Channels.is_user_channel_member(channel.id, user.id):
+        raise HTTPException(
+            status_code=status.HTTP_404_NOT_FOUND, detail=ERROR_MESSAGES.NOT_FOUND
+        )
+
+    Channels.update_member_active_status(channel.id, user.id, form_data.is_active)
+    return True
+
+
+#################################################
+# AddMembersById
+#################################################
+
+
+class UpdateMembersForm(BaseModel):
+    user_ids: list[str] = []
+    group_ids: list[str] = []
+
+
+@router.post("/{id}/update/members/add")
+async def add_members_by_id(
+    request: Request,
+    id: str,
+    form_data: UpdateMembersForm,
+    user=Depends(get_verified_user),
+):
+    if user.role != "admin" and not has_permission(
+        user.id, "features.channels", request.app.state.config.USER_PERMISSIONS
+    ):
+        raise HTTPException(
+            status_code=status.HTTP_401_UNAUTHORIZED,
+            detail=ERROR_MESSAGES.UNAUTHORIZED,
+        )
+
+    channel = Channels.get_channel_by_id(id)
+    if not channel:
+        raise HTTPException(
+            status_code=status.HTTP_404_NOT_FOUND, detail=ERROR_MESSAGES.NOT_FOUND
+        )
+
+    if channel.user_id != user.id and user.role != "admin":
+        raise HTTPException(
+            status_code=status.HTTP_403_FORBIDDEN, detail=ERROR_MESSAGES.DEFAULT()
+        )
+
+    try:
+        memberships = Channels.add_members_to_channel(
+            channel.id, user.id, form_data.user_ids, form_data.group_ids
+        )
+
+        return memberships
+    except Exception as e:
+        log.exception(e)
+        raise HTTPException(
+            status_code=status.HTTP_400_BAD_REQUEST, detail=ERROR_MESSAGES.DEFAULT()
+        )
+
+
+#################################################
+#
+#################################################
+
+
+class RemoveMembersForm(BaseModel):
+    user_ids: list[str] = []
+
+
+@router.post("/{id}/update/members/remove")
+async def remove_members_by_id(
+    request: Request,
+    id: str,
+    form_data: RemoveMembersForm,
+    user=Depends(get_verified_user),
+):
+    if user.role != "admin" and not has_permission(
+        user.id, "features.channels", request.app.state.config.USER_PERMISSIONS
+    ):
+        raise HTTPException(
+            status_code=status.HTTP_401_UNAUTHORIZED,
+            detail=ERROR_MESSAGES.UNAUTHORIZED,
+        )
+
+    channel = Channels.get_channel_by_id(id)
+    if not channel:
+        raise HTTPException(
+            status_code=status.HTTP_404_NOT_FOUND, detail=ERROR_MESSAGES.NOT_FOUND
+        )
+
+    if channel.user_id != user.id and user.role != "admin":
+        raise HTTPException(
+            status_code=status.HTTP_403_FORBIDDEN, detail=ERROR_MESSAGES.DEFAULT()
+        )
+
+    try:
+        deleted = Channels.remove_members_from_channel(channel.id, form_data.user_ids)
+
+        return deleted
+    except Exception as e:
+        log.exception(e)
+        raise HTTPException(
+            status_code=status.HTTP_400_BAD_REQUEST, detail=ERROR_MESSAGES.DEFAULT()
+        )
+
+
 ############################
 # UpdateChannelById
 ############################
@@ -120,14 +591,27 @@ async def get_channel_by_id(id: str, user=Depends(get_verified_user)):


 @router.post("/{id}/update", response_model=Optional[ChannelModel])
 async def update_channel_by_id(
-    id: str, form_data: ChannelForm, user=Depends(get_admin_user)
+    request: Request, id: str, form_data: ChannelForm, user=Depends(get_verified_user)
 ):
+    if user.role != "admin" and not has_permission(
+        user.id, "features.channels", request.app.state.config.USER_PERMISSIONS
+    ):
+        raise HTTPException(
+            status_code=status.HTTP_401_UNAUTHORIZED,
+            detail=ERROR_MESSAGES.UNAUTHORIZED,
+        )
+
     channel = Channels.get_channel_by_id(id)
     if not channel:
         raise HTTPException(
             status_code=status.HTTP_404_NOT_FOUND, detail=ERROR_MESSAGES.NOT_FOUND
         )
+
+    if channel.user_id != user.id and user.role != "admin":
+        raise HTTPException(
+            status_code=status.HTTP_403_FORBIDDEN, detail=ERROR_MESSAGES.DEFAULT()
+        )
+
     try:
         channel = Channels.update_channel_by_id(id, form_data)
         return ChannelModel(**channel.model_dump())

@@ -144,13 +628,28 @@ async def update_channel_by_id(


 @router.delete("/{id}/delete", response_model=bool)
-async def delete_channel_by_id(id: str, user=Depends(get_admin_user)):
+async def delete_channel_by_id(
+    request: Request, id: str, user=Depends(get_verified_user)
+):
+    if user.role != "admin" and not has_permission(
+        user.id, "features.channels", request.app.state.config.USER_PERMISSIONS
+    ):
+        raise HTTPException(
+            status_code=status.HTTP_401_UNAUTHORIZED,
+            detail=ERROR_MESSAGES.UNAUTHORIZED,
+        )
+
     channel = Channels.get_channel_by_id(id)
     if not channel:
         raise HTTPException(
             status_code=status.HTTP_404_NOT_FOUND, detail=ERROR_MESSAGES.NOT_FOUND
         )
+
+    if channel.user_id != user.id and user.role != "admin":
+        raise HTTPException(
+            status_code=status.HTTP_403_FORBIDDEN, detail=ERROR_MESSAGES.DEFAULT()
+        )
+
     try:
         Channels.delete_channel_by_id(id)
         return True
@@ -180,6 +679,12 @@ async def get_channel_messages(
             status_code=status.HTTP_404_NOT_FOUND, detail=ERROR_MESSAGES.NOT_FOUND
         )
+
+    if channel.type in ["group", "dm"]:
+        if not Channels.is_user_channel_member(channel.id, user.id):
+            raise HTTPException(
+                status_code=status.HTTP_403_FORBIDDEN, detail=ERROR_MESSAGES.DEFAULT()
+            )
+    else:
         if user.role != "admin" and not has_access(
             user.id, type="read", access_control=channel.access_control
         ):

@@ -187,6 +692,10 @@ async def get_channel_messages(
                 status_code=status.HTTP_403_FORBIDDEN, detail=ERROR_MESSAGES.DEFAULT()
             )
+
+    channel_member = Channels.join_channel(
+        id, user.id
+    )  # Ensure user is a member of the channel

     message_list = Messages.get_messages_by_channel_id(id, skip, limit)
     users = {}

@@ -216,6 +725,62 @@ async def get_channel_messages(
     return messages
+
+
+############################
+# GetPinnedChannelMessages
+############################
+
+PAGE_ITEM_COUNT_PINNED = 20
+
+
+@router.get("/{id}/messages/pinned", response_model=list[MessageWithReactionsResponse])
+async def get_pinned_channel_messages(
+    id: str, page: int = 1, user=Depends(get_verified_user)
+):
+    channel = Channels.get_channel_by_id(id)
+    if not channel:
+        raise HTTPException(
+            status_code=status.HTTP_404_NOT_FOUND, detail=ERROR_MESSAGES.NOT_FOUND
+        )
+
+    if channel.type in ["group", "dm"]:
+        if not Channels.is_user_channel_member(channel.id, user.id):
+            raise HTTPException(
+                status_code=status.HTTP_403_FORBIDDEN, detail=ERROR_MESSAGES.DEFAULT()
+            )
+    else:
+        if user.role != "admin" and not has_access(
+            user.id, type="read", access_control=channel.access_control
+        ):
+            raise HTTPException(
+                status_code=status.HTTP_403_FORBIDDEN, detail=ERROR_MESSAGES.DEFAULT()
+            )
+
+    page = max(1, page)
+    skip = (page - 1) * PAGE_ITEM_COUNT_PINNED
+    limit = PAGE_ITEM_COUNT_PINNED
+
+    message_list = Messages.get_pinned_messages_by_channel_id(id, skip, limit)
+    users = {}
+
+    messages = []
+    for message in message_list:
+        if message.user_id not in users:
+            user = Users.get_user_by_id(message.user_id)
+            users[message.user_id] = user
+
+        messages.append(
+            MessageWithReactionsResponse(
+                **{
+                    **message.model_dump(),
+                    "reactions": Messages.get_reactions_by_message_id(message.id),
+                    "user": UserNameResponse(**users[message.user_id].model_dump()),
+                }
+            )
+        )
+
+    return messages
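
Pinned messages come back 20 per page (PAGE_ITEM_COUNT_PINNED). A paging sketch; the route and field names follow the handler and response models above, while the base URL and token are placeholders:

import requests

BASE_URL = "http://localhost:8080"  # assumed
HEADERS = {"Authorization": "Bearer sk-..."}  # placeholder

channel_id = "channel-id"  # hypothetical
page = 1
while True:
    batch = requests.get(
        f"{BASE_URL}/api/v1/channels/{channel_id}/messages/pinned",
        headers=HEADERS,
        params={"page": page},
    ).json()
    if not batch:
        break
    for message in batch:
        print(message["user"]["name"], message["content"])
    page += 1
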
 ############################
 # PostNewMessage
 ############################

@@ -225,7 +790,9 @@ async def send_notification(name, webui_url, channel, message, active_user_ids):
     users = get_users_with_access("read", channel.access_control)

     for user in users:
-        if user.id not in active_user_ids:
+        if (user.id not in active_user_ids) and Channels.is_user_channel_member(
+            channel.id, user.id
+        ):
             if user.settings:
                 webhook_url = user.settings.ui.get("notifications", {}).get(
                     "webhook_url", None

@@ -340,11 +907,12 @@ async def model_response_handler(request, channel, message, user):
                 if file.get("type", "") == "image":
                     images.append(file.get("url", ""))

+    thread_history_string = "\n\n".join(thread_history)
     system_message = {
         "role": "system",
         "content": f"You are {model.get('name', model_id)}, participating in a threaded conversation. Be concise and conversational."
         + (
-            f"Here's the thread history:\n\n{''.join([f'{msg}' for msg in thread_history])}\n\nContinue the conversation naturally as {model.get('name', model_id)}, addressing the most recent message while being aware of the full context."
+            f"Here's the thread history:\n\n\n{thread_history_string}\n\n\nContinue the conversation naturally as {model.get('name', model_id)}, addressing the most recent message while being aware of the full context."
             if thread_history
             else ""
         ),

@@ -384,6 +952,7 @@ async def model_response_handler(request, channel, message, user):
         )

         if res:
+            if res.get("choices", []) and len(res["choices"]) > 0:
                 await update_message_by_id(
                     channel.id,
                     response_message.id,

@@ -397,6 +966,20 @@ async def model_response_handler(request, channel, message, user):
                     ),
                     user,
                 )
+            elif res.get("error", None):
+                await update_message_by_id(
+                    channel.id,
+                    response_message.id,
+                    MessageForm(
+                        **{
+                            "content": f"Error: {res['error']}",
+                            "meta": {
+                                "done": True,
+                            },
+                        }
+                    ),
+                    user,
+                )
     except Exception as e:
         log.info(e)
         pass
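
The prompt change above swaps the old unseparated join for a precomputed "\n\n" join, so thread entries no longer run together in the system prompt. The difference in miniature:

thread_history = ["user_a: hi", "user_b: hello", "user_a: any update?"]

old = "".join(f"{msg}" for msg in thread_history)  # entries fused into one line
new = "\n\n".join(thread_history)                  # entries separated by blank lines

assert "\n" not in old
assert new.count("\n\n") == 2
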
@@ -413,6 +996,12 @@ async def new_message_handler(
             status_code=status.HTTP_404_NOT_FOUND, detail=ERROR_MESSAGES.NOT_FOUND
         )
+
+    if channel.type in ["group", "dm"]:
+        if not Channels.is_user_channel_member(channel.id, user.id):
+            raise HTTPException(
+                status_code=status.HTTP_403_FORBIDDEN, detail=ERROR_MESSAGES.DEFAULT()
+            )
+    else:
         if user.role != "admin" and not has_access(
             user.id, type="write", access_control=channel.access_control, strict=False
         ):

@@ -423,20 +1012,28 @@ async def new_message_handler(
     try:
         message = Messages.insert_new_message(form_data, channel.id, user.id)
         if message:
+            if channel.type in ["group", "dm"]:
+                members = Channels.get_members_by_channel_id(channel.id)
+                for member in members:
+                    if not member.is_active:
+                        Channels.update_member_active_status(
+                            channel.id, member.user_id, True
+                        )
+
             message = Messages.get_message_by_id(message.id)
             event_data = {
                 "channel_id": channel.id,
                 "message_id": message.id,
                 "data": {
                     "type": "message",
-                    "data": message.model_dump(),
+                    "data": {"temp_id": form_data.temp_id, **message.model_dump()},
                 },
                 "user": UserNameResponse(**user.model_dump()).model_dump(),
                 "channel": channel.model_dump(),
             }

             await sio.emit(
-                "channel-events",
+                "events:channel",
                 event_data,
                 to=f"channel:{channel.id}",
             )

@@ -447,7 +1044,7 @@ async def new_message_handler(

             if parent_message:
                 await sio.emit(
-                    "channel-events",
+                    "events:channel",
                     {
                         "channel_id": channel.id,
                         "message_id": parent_message.id,
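
Echoing temp_id in the events:channel payload lets a client reconcile the persisted message with its optimistic placeholder. A minimal sketch of that reconciliation; the event shape follows event_data above, everything else is hypothetical client code:

pending = {}  # temp_id -> locally rendered placeholder

def on_channel_event(event: dict) -> None:
    data = event.get("data", {})
    if data.get("type") != "message":
        return
    message = data.get("data", {})
    temp_id = message.get("temp_id")
    if temp_id and temp_id in pending:
        pending.pop(temp_id)  # drop the placeholder
        render(message)       # render the server-assigned message instead

def render(message: dict) -> None:
    print(message["id"], message.get("content", ""))
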
@@ -521,6 +1118,12 @@ async def get_channel_message(
             status_code=status.HTTP_404_NOT_FOUND, detail=ERROR_MESSAGES.NOT_FOUND
         )
+
+    if channel.type in ["group", "dm"]:
+        if not Channels.is_user_channel_member(channel.id, user.id):
+            raise HTTPException(
+                status_code=status.HTTP_403_FORBIDDEN, detail=ERROR_MESSAGES.DEFAULT()
+            )
+    else:
         if user.role != "admin" and not has_access(
             user.id, type="read", access_control=channel.access_control
         ):

@@ -549,6 +1152,69 @@ async def get_channel_message(
         )
+
+
+############################
+# PinChannelMessage
+############################
+
+
+class PinMessageForm(BaseModel):
+    is_pinned: bool
+
+
+@router.post(
+    "/{id}/messages/{message_id}/pin", response_model=Optional[MessageUserResponse]
+)
+async def pin_channel_message(
+    id: str, message_id: str, form_data: PinMessageForm, user=Depends(get_verified_user)
+):
+    channel = Channels.get_channel_by_id(id)
+    if not channel:
+        raise HTTPException(
+            status_code=status.HTTP_404_NOT_FOUND, detail=ERROR_MESSAGES.NOT_FOUND
+        )
+
+    if channel.type in ["group", "dm"]:
+        if not Channels.is_user_channel_member(channel.id, user.id):
+            raise HTTPException(
+                status_code=status.HTTP_403_FORBIDDEN, detail=ERROR_MESSAGES.DEFAULT()
+            )
+    else:
+        if user.role != "admin" and not has_access(
+            user.id, type="read", access_control=channel.access_control
+        ):
+            raise HTTPException(
+                status_code=status.HTTP_403_FORBIDDEN, detail=ERROR_MESSAGES.DEFAULT()
+            )
+
+    message = Messages.get_message_by_id(message_id)
+    if not message:
+        raise HTTPException(
+            status_code=status.HTTP_404_NOT_FOUND, detail=ERROR_MESSAGES.NOT_FOUND
+        )
+
+    if message.channel_id != id:
+        raise HTTPException(
+            status_code=status.HTTP_400_BAD_REQUEST, detail=ERROR_MESSAGES.DEFAULT()
+        )
+
+    try:
+        Messages.update_is_pinned_by_id(message_id, form_data.is_pinned, user.id)
+        message = Messages.get_message_by_id(message_id)
+        return MessageUserResponse(
+            **{
+                **message.model_dump(),
+                "user": UserNameResponse(
+                    **Users.get_user_by_id(message.user_id).model_dump()
+                ),
+            }
+        )
+    except Exception as e:
+        log.exception(e)
+        raise HTTPException(
+            status_code=status.HTTP_400_BAD_REQUEST, detail=ERROR_MESSAGES.DEFAULT()
+        )
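
Pinning and unpinning share one route, toggled by the is_pinned flag of PinMessageForm. A sketch with placeholder ids, base URL, and token:

import requests

BASE_URL = "http://localhost:8080"  # assumed
HEADERS = {"Authorization": "Bearer sk-..."}  # placeholder

channel_id, message_id = "channel-id", "message-id"  # hypothetical
resp = requests.post(
    f"{BASE_URL}/api/v1/channels/{channel_id}/messages/{message_id}/pin",
    headers=HEADERS,
    json={"is_pinned": True},  # False unpins via the same route
)
resp.raise_for_status()
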
############################
|
############################
|
||||||
# GetChannelThreadMessages
|
# GetChannelThreadMessages
|
||||||
############################
|
############################
|
||||||
|
|
@ -570,6 +1236,12 @@ async def get_channel_thread_messages(
|
||||||
status_code=status.HTTP_404_NOT_FOUND, detail=ERROR_MESSAGES.NOT_FOUND
|
status_code=status.HTTP_404_NOT_FOUND, detail=ERROR_MESSAGES.NOT_FOUND
|
||||||
)
|
)
|
||||||
|
|
||||||
|
if channel.type in ["group", "dm"]:
|
||||||
|
if not Channels.is_user_channel_member(channel.id, user.id):
|
||||||
|
raise HTTPException(
|
||||||
|
status_code=status.HTTP_403_FORBIDDEN, detail=ERROR_MESSAGES.DEFAULT()
|
||||||
|
)
|
||||||
|
else:
|
||||||
if user.role != "admin" and not has_access(
|
if user.role != "admin" and not has_access(
|
||||||
user.id, type="read", access_control=channel.access_control
|
user.id, type="read", access_control=channel.access_control
|
||||||
):
|
):
|
||||||
|
|
@ -629,10 +1301,18 @@ async def update_message_by_id(
|
||||||
status_code=status.HTTP_400_BAD_REQUEST, detail=ERROR_MESSAGES.DEFAULT()
|
status_code=status.HTTP_400_BAD_REQUEST, detail=ERROR_MESSAGES.DEFAULT()
|
||||||
)
|
)
|
||||||
|
|
||||||
|
if channel.type in ["group", "dm"]:
|
||||||
|
if not Channels.is_user_channel_member(channel.id, user.id):
|
||||||
|
raise HTTPException(
|
||||||
|
status_code=status.HTTP_403_FORBIDDEN, detail=ERROR_MESSAGES.DEFAULT()
|
||||||
|
)
|
||||||
|
else:
|
||||||
if (
|
if (
|
||||||
user.role != "admin"
|
user.role != "admin"
|
||||||
and message.user_id != user.id
|
and message.user_id != user.id
|
||||||
and not has_access(user.id, type="read", access_control=channel.access_control)
|
and not has_access(
|
||||||
|
user.id, type="read", access_control=channel.access_control
|
||||||
|
)
|
||||||
):
|
):
|
||||||
raise HTTPException(
|
raise HTTPException(
|
||||||
status_code=status.HTTP_403_FORBIDDEN, detail=ERROR_MESSAGES.DEFAULT()
|
status_code=status.HTTP_403_FORBIDDEN, detail=ERROR_MESSAGES.DEFAULT()
|
||||||
|
|
@ -644,7 +1324,7 @@ async def update_message_by_id(
|
||||||
|
|
||||||
if message:
|
if message:
|
||||||
await sio.emit(
|
await sio.emit(
|
||||||
"channel-events",
|
"events:channel",
|
||||||
{
|
{
|
||||||
"channel_id": channel.id,
|
"channel_id": channel.id,
|
||||||
"message_id": message.id,
|
"message_id": message.id,
|
||||||
|
|
@ -685,6 +1365,12 @@ async def add_reaction_to_message(
|
||||||
status_code=status.HTTP_404_NOT_FOUND, detail=ERROR_MESSAGES.NOT_FOUND
|
status_code=status.HTTP_404_NOT_FOUND, detail=ERROR_MESSAGES.NOT_FOUND
|
||||||
)
|
)
|
||||||
|
|
||||||
|
if channel.type in ["group", "dm"]:
|
||||||
|
if not Channels.is_user_channel_member(channel.id, user.id):
|
||||||
|
raise HTTPException(
|
||||||
|
status_code=status.HTTP_403_FORBIDDEN, detail=ERROR_MESSAGES.DEFAULT()
|
||||||
|
)
|
||||||
|
else:
|
||||||
if user.role != "admin" and not has_access(
|
if user.role != "admin" and not has_access(
|
||||||
user.id, type="write", access_control=channel.access_control, strict=False
|
user.id, type="write", access_control=channel.access_control, strict=False
|
||||||
):
|
):
|
||||||
|
|
@ -708,7 +1394,7 @@ async def add_reaction_to_message(
|
||||||
message = Messages.get_message_by_id(message_id)
|
message = Messages.get_message_by_id(message_id)
|
||||||
|
|
||||||
await sio.emit(
|
await sio.emit(
|
||||||
"channel-events",
|
"events:channel",
|
||||||
{
|
{
|
||||||
"channel_id": channel.id,
|
"channel_id": channel.id,
|
||||||
"message_id": message.id,
|
"message_id": message.id,
|
||||||
|
|
@ -748,6 +1434,12 @@ async def remove_reaction_by_id_and_user_id_and_name(
|
||||||
status_code=status.HTTP_404_NOT_FOUND, detail=ERROR_MESSAGES.NOT_FOUND
|
status_code=status.HTTP_404_NOT_FOUND, detail=ERROR_MESSAGES.NOT_FOUND
|
||||||
)
|
)
|
||||||
|
|
||||||
|
if channel.type in ["group", "dm"]:
|
||||||
|
if not Channels.is_user_channel_member(channel.id, user.id):
|
||||||
|
raise HTTPException(
|
||||||
|
status_code=status.HTTP_403_FORBIDDEN, detail=ERROR_MESSAGES.DEFAULT()
|
||||||
|
)
|
||||||
|
else:
|
||||||
if user.role != "admin" and not has_access(
|
if user.role != "admin" and not has_access(
|
||||||
user.id, type="write", access_control=channel.access_control, strict=False
|
user.id, type="write", access_control=channel.access_control, strict=False
|
||||||
):
|
):
|
||||||
|
|
@ -774,7 +1466,7 @@ async def remove_reaction_by_id_and_user_id_and_name(
|
||||||
message = Messages.get_message_by_id(message_id)
|
message = Messages.get_message_by_id(message_id)
|
||||||
|
|
||||||
await sio.emit(
|
await sio.emit(
|
||||||
"channel-events",
|
"events:channel",
|
||||||
{
|
{
|
||||||
"channel_id": channel.id,
|
"channel_id": channel.id,
|
||||||
"message_id": message.id,
|
"message_id": message.id,
|
||||||
|
|
@ -825,11 +1517,20 @@ async def delete_message_by_id(
|
||||||
status_code=status.HTTP_400_BAD_REQUEST, detail=ERROR_MESSAGES.DEFAULT()
|
status_code=status.HTTP_400_BAD_REQUEST, detail=ERROR_MESSAGES.DEFAULT()
|
||||||
)
|
)
|
||||||
|
|
||||||
|
if channel.type in ["group", "dm"]:
|
||||||
|
if not Channels.is_user_channel_member(channel.id, user.id):
|
||||||
|
raise HTTPException(
|
||||||
|
status_code=status.HTTP_403_FORBIDDEN, detail=ERROR_MESSAGES.DEFAULT()
|
||||||
|
)
|
||||||
|
else:
|
||||||
if (
|
if (
|
||||||
user.role != "admin"
|
user.role != "admin"
|
||||||
and message.user_id != user.id
|
and message.user_id != user.id
|
||||||
and not has_access(
|
and not has_access(
|
||||||
user.id, type="write", access_control=channel.access_control, strict=False
|
user.id,
|
||||||
|
type="write",
|
||||||
|
access_control=channel.access_control,
|
||||||
|
strict=False,
|
||||||
)
|
)
|
||||||
):
|
):
|
||||||
raise HTTPException(
|
raise HTTPException(
|
||||||
|
|
@ -839,7 +1540,7 @@ async def delete_message_by_id(
|
||||||
try:
|
try:
|
||||||
Messages.delete_message_by_id(message_id)
|
Messages.delete_message_by_id(message_id)
|
||||||
await sio.emit(
|
await sio.emit(
|
||||||
"channel-events",
|
"events:channel",
|
||||||
{
|
{
|
||||||
"channel_id": channel.id,
|
"channel_id": channel.id,
|
||||||
"message_id": message.id,
|
"message_id": message.id,
|
||||||
|
|
@ -862,7 +1563,7 @@ async def delete_message_by_id(
|
||||||
|
|
||||||
if parent_message:
|
if parent_message:
|
||||||
await sio.emit(
|
await sio.emit(
|
||||||
"channel-events",
|
"events:channel",
|
||||||
{
|
{
|
||||||
"channel_id": channel.id,
|
"channel_id": channel.id,
|
||||||
"message_id": parent_message.id,
|
"message_id": parent_message.id,
|
||||||
|
|
|
||||||
|
|
@@ -7,6 +7,7 @@ from open_webui.socket.main import get_event_emitter
 from open_webui.models.chats import (
     ChatForm,
     ChatImportForm,
+    ChatsImportForm,
     ChatResponse,
     Chats,
     ChatTitleIdResponse,

@@ -39,6 +40,7 @@ router = APIRouter()
 def get_session_user_chat_list(
     user=Depends(get_verified_user),
     page: Optional[int] = None,
+    include_pinned: Optional[bool] = False,
     include_folders: Optional[bool] = False,
 ):
     try:

@@ -47,11 +49,15 @@ def get_session_user_chat_list(
             skip = (page - 1) * limit

             return Chats.get_chat_title_id_list_by_user_id(
-                user.id, include_folders=include_folders, skip=skip, limit=limit
+                user.id,
+                include_folders=include_folders,
+                include_pinned=include_pinned,
+                skip=skip,
+                limit=limit,
             )
         else:
             return Chats.get_chat_title_id_list_by_user_id(
-                user.id, include_folders=include_folders
+                user.id, include_folders=include_folders, include_pinned=include_pinned
             )
     except Exception as e:
         log.exception(e)

@@ -137,26 +143,15 @@ async def create_new_chat(form_data: ChatForm, user=Depends(get_verified_user)):
 ############################
-# ImportChat
+# ImportChats
 ############################


-@router.post("/import", response_model=Optional[ChatResponse])
-async def import_chat(form_data: ChatImportForm, user=Depends(get_verified_user)):
+@router.post("/import", response_model=list[ChatResponse])
+async def import_chats(form_data: ChatsImportForm, user=Depends(get_verified_user)):
     try:
-        chat = Chats.import_chat(user.id, form_data)
-        if chat:
-            tags = chat.meta.get("tags", [])
-            for tag_id in tags:
-                tag_id = tag_id.replace(" ", "_").lower()
-                tag_name = " ".join([word.capitalize() for word in tag_id.split("_")])
-                if (
-                    tag_id != "none"
-                    and Tags.get_tag_by_name_and_user_id(tag_name, user.id) is None
-                ):
-                    Tags.insert_new_tag(tag_name, user.id)
-
-            return ChatResponse(**chat.model_dump())
+        chats = Chats.import_chats(user.id, form_data.chats)
+        return chats
     except Exception as e:
         log.exception(e)
         raise HTTPException(
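
/import now takes a list of chats and returns a list of ChatResponse objects, and the inline tag bookkeeping has moved out of the router. A sketch of the new request shape; the wrapper key "chats" and the per-item fields follow ChatsImportForm/ChatImportForm as used in this diff, while the inner chat payloads are abbreviated placeholders:

import requests

BASE_URL = "http://localhost:8080"  # assumed
HEADERS = {"Authorization": "Bearer sk-..."}  # placeholder

payload = {
    "chats": [
        {"chat": {"title": "Imported chat 1"}, "pinned": False, "folder_id": None},
        {"chat": {"title": "Imported chat 2"}, "pinned": True, "folder_id": None},
    ]
}
resp = requests.post(f"{BASE_URL}/api/v1/chats/import", headers=HEADERS, json=payload)
print([chat["id"] for chat in resp.json()])  # one ChatResponse per imported chat
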
@@ -223,7 +218,7 @@ async def get_chat_list_by_folder_id(
     folder_id: str, page: Optional[int] = 1, user=Depends(get_verified_user)
 ):
     try:
-        limit = 60
+        limit = 10
         skip = (page - 1) * limit

         return [

@@ -653,8 +648,9 @@ async def clone_chat_by_id(
             "title": form_data.title if form_data.title else f"Clone of {chat.title}",
         }

-        chat = Chats.import_chat(
+        chats = Chats.import_chats(
             user.id,
+            [
                 ChatImportForm(
                     **{
                         "chat": updated_chat,

@@ -662,10 +658,18 @@ async def clone_chat_by_id(
                         "pinned": chat.pinned,
                         "folder_id": chat.folder_id,
                     }
-            ),
+                )
+            ],
         )
+
+        if chats:
+            chat = chats[0]
             return ChatResponse(**chat.model_dump())
+        else:
+            raise HTTPException(
+                status_code=status.HTTP_500_INTERNAL_SERVER_ERROR,
+                detail=ERROR_MESSAGES.DEFAULT(),
+            )
     else:
         raise HTTPException(
             status_code=status.HTTP_401_UNAUTHORIZED, detail=ERROR_MESSAGES.DEFAULT()

@@ -693,8 +697,9 @@ async def clone_shared_chat_by_id(id: str, user=Depends(get_verified_user)):
             "title": f"Clone of {chat.title}",
         }

-        chat = Chats.import_chat(
+        chats = Chats.import_chats(
             user.id,
+            [
                 ChatImportForm(
                     **{
                         "chat": updated_chat,

@@ -702,9 +707,18 @@ async def clone_shared_chat_by_id(id: str, user=Depends(get_verified_user)):
                         "pinned": chat.pinned,
                         "folder_id": chat.folder_id,
                     }
-                ),
+                )
+            ],
         )
+
+        if chats:
+            chat = chats[0]
             return ChatResponse(**chat.model_dump())
+        else:
+            raise HTTPException(
+                status_code=status.HTTP_500_INTERNAL_SERVER_ERROR,
+                detail=ERROR_MESSAGES.DEFAULT(),
+            )
     else:
         raise HTTPException(
             status_code=status.HTTP_401_UNAUTHORIZED, detail=ERROR_MESSAGES.DEFAULT()
@@ -1,4 +1,5 @@
 import logging
+import copy
 from fastapi import APIRouter, Depends, Request, HTTPException
 from pydantic import BaseModel, ConfigDict
 import aiohttp

@@ -15,6 +16,7 @@ from open_webui.utils.tools import (
     set_tool_servers,
 )
 from open_webui.utils.mcp.client import MCPClient
+from open_webui.models.oauth_sessions import OAuthSessions

 from open_webui.env import SRC_LOG_LEVELS

@@ -142,6 +144,7 @@ class ToolServerConnection(BaseModel):
     path: str
     type: Optional[str] = "openapi"  # openapi, mcp
     auth_type: Optional[str]
+    headers: Optional[dict | str] = None
     key: Optional[str]
     config: Optional[dict]

@@ -165,6 +168,21 @@ async def set_tool_servers_config(
     form_data: ToolServersConfigForm,
     user=Depends(get_admin_user),
 ):
+    for connection in request.app.state.config.TOOL_SERVER_CONNECTIONS:
+        server_type = connection.get("type", "openapi")
+        auth_type = connection.get("auth_type", "none")
+
+        if auth_type == "oauth_2.1":
+            # Remove existing OAuth clients for tool servers
+            server_id = connection.get("info", {}).get("id")
+            client_key = f"{server_type}:{server_id}"
+
+            try:
+                request.app.state.oauth_client_manager.remove_client(client_key)
+            except:
+                pass
+
+    # Set new tool server connections
     request.app.state.config.TOOL_SERVER_CONNECTIONS = [
         connection.model_dump() for connection in form_data.TOOL_SERVER_CONNECTIONS
     ]
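
Stale OAuth clients are now evicted before the new connection list is stored, keyed by the same type:id pair used at registration time. The key construction, isolated:

def client_key(connection: dict) -> str:
    # Mirrors the registry key built in the eviction loop above.
    server_type = connection.get("type", "openapi")
    server_id = connection.get("info", {}).get("id")
    return f"{server_type}:{server_id}"

assert client_key({"type": "mcp", "info": {"id": "abc"}}) == "mcp:abc"
assert client_key({"info": {"id": "xyz"}}) == "openapi:xyz"  # type defaults to openapi
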
@@ -176,6 +194,7 @@ async def set_tool_servers_config(
         if server_type == "mcp":
             server_id = connection.get("info", {}).get("id")
             auth_type = connection.get("auth_type", "none")

             if auth_type == "oauth_2.1" and server_id:
                 try:
                     oauth_client_info = connection.get("info", {}).get(

@@ -183,7 +202,7 @@ async def set_tool_servers_config(
                     )
                     oauth_client_info = decrypt_data(oauth_client_info)

-                    await request.app.state.oauth_client_manager.add_client(
+                    request.app.state.oauth_client_manager.add_client(
                         f"{server_type}:{server_id}",
                         OAuthClientInformationFull(**oauth_client_info),
                     )

@@ -211,9 +230,9 @@ async def verify_tool_servers_config(
                 log.debug(
                     f"Trying to fetch OAuth 2.1 discovery document from {discovery_url}"
                 )
-                async with aiohttp.ClientSession() as session:
+                async with aiohttp.ClientSession(trust_env=True) as session:
                     async with session.get(
-                        discovery_urls[0]
+                        discovery_url
                     ) as oauth_server_metadata_response:
                         if oauth_server_metadata_response.status == 200:
                             try:

@@ -234,7 +253,7 @@ async def verify_tool_servers_config(
                                 )
                                 raise HTTPException(
                                     status_code=400,
-                                    detail=f"Failed to parse OAuth 2.1 discovery document from {discovery_urls[0]}",
+                                    detail=f"Failed to parse OAuth 2.1 discovery document from {discovery_url}",
                                 )

                 raise HTTPException(

@@ -252,18 +271,26 @@ async def verify_tool_servers_config(
             elif form_data.auth_type == "session":
                 token = request.state.token.credentials
             elif form_data.auth_type == "system_oauth":
+                oauth_token = None
                 try:
                     if request.cookies.get("oauth_session_id", None):
-                        token = await request.app.state.oauth_manager.get_oauth_token(
+                        oauth_token = await request.app.state.oauth_manager.get_oauth_token(
                             user.id,
                             request.cookies.get("oauth_session_id", None),
                         )
+
+                        if oauth_token:
+                            token = oauth_token.get("access_token", "")
                 except Exception as e:
                     pass

             if token:
                 headers = {"Authorization": f"Bearer {token}"}
+
+            if form_data.headers and isinstance(form_data.headers, dict):
+                if headers is None:
+                    headers = {}
+                headers.update(form_data.headers)
+
             await client.connect(form_data.url, headers=headers)
             specs = await client.list_tool_specs()
             return {

@@ -281,6 +308,7 @@ async def verify_tool_servers_config(
             await client.disconnect()
         else:  # openapi
             token = None
+            headers = None
             if form_data.auth_type == "bearer":
                 token = form_data.key
             elif form_data.auth_type == "session":

@@ -288,15 +316,29 @@ async def verify_tool_servers_config(
             elif form_data.auth_type == "system_oauth":
                 try:
                     if request.cookies.get("oauth_session_id", None):
-                        token = await request.app.state.oauth_manager.get_oauth_token(
+                        oauth_token = (
+                            await request.app.state.oauth_manager.get_oauth_token(
                                 user.id,
                                 request.cookies.get("oauth_session_id", None),
                             )
+                        )
+
+                        if oauth_token:
+                            token = oauth_token.get("access_token", "")
+
                 except Exception as e:
                     pass

+            if token:
+                headers = {"Authorization": f"Bearer {token}"}
+
+            if form_data.headers and isinstance(form_data.headers, dict):
+                if headers is None:
+                    headers = {}
+                headers.update(form_data.headers)
+
             url = get_tool_server_url(form_data.url, form_data.path)
-            return await get_tool_server_data(token, url)
+            return await get_tool_server_data(url, headers=headers)
         except HTTPException as e:
             raise e
         except Exception as e:
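
Both the MCP and the OpenAPI branches now fold optional user-supplied headers on top of whatever Authorization header the auth type produced. The same merge, isolated into a helper for clarity (a sketch, not code from the diff):

from typing import Optional

def build_headers(token: Optional[str], extra: Optional[dict]) -> Optional[dict]:
    headers = {"Authorization": f"Bearer {token}"} if token else None
    if extra and isinstance(extra, dict):
        if headers is None:
            headers = {}
        headers.update(extra)  # user-supplied headers win on key collisions
    return headers

assert build_headers(None, None) is None
assert build_headers("t", {"X-Trace": "1"}) == {
    "Authorization": "Bearer t",
    "X-Trace": "1",
}
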
@@ -421,6 +463,7 @@ async def set_code_execution_config(
 ############################
 class ModelsConfigForm(BaseModel):
     DEFAULT_MODELS: Optional[str]
+    DEFAULT_PINNED_MODELS: Optional[str]
     MODEL_ORDER_LIST: Optional[list[str]]

@@ -428,6 +471,7 @@ class ModelsConfigForm(BaseModel):
 async def get_models_config(request: Request, user=Depends(get_admin_user)):
     return {
         "DEFAULT_MODELS": request.app.state.config.DEFAULT_MODELS,
+        "DEFAULT_PINNED_MODELS": request.app.state.config.DEFAULT_PINNED_MODELS,
         "MODEL_ORDER_LIST": request.app.state.config.MODEL_ORDER_LIST,
     }

@@ -437,9 +481,11 @@ async def set_models_config(
     request: Request, form_data: ModelsConfigForm, user=Depends(get_admin_user)
 ):
     request.app.state.config.DEFAULT_MODELS = form_data.DEFAULT_MODELS
+    request.app.state.config.DEFAULT_PINNED_MODELS = form_data.DEFAULT_PINNED_MODELS
     request.app.state.config.MODEL_ORDER_LIST = form_data.MODEL_ORDER_LIST
     return {
         "DEFAULT_MODELS": request.app.state.config.DEFAULT_MODELS,
+        "DEFAULT_PINNED_MODELS": request.app.state.config.DEFAULT_PINNED_MODELS,
         "MODEL_ORDER_LIST": request.app.state.config.MODEL_ORDER_LIST,
     }
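
DEFAULT_PINNED_MODELS rides alongside the existing fields in both the GET and POST payloads; like DEFAULT_MODELS it is typed as an optional string. A sketch of updating it, assuming the configs router is mounted at /api/v1/configs and the string holds comma-separated model ids (the value format is an assumption):

import requests

BASE_URL = "http://localhost:8080"  # assumed
HEADERS = {"Authorization": "Bearer sk-..."}  # admin token placeholder

resp = requests.post(
    f"{BASE_URL}/api/v1/configs/models",
    headers=HEADERS,
    json={
        "DEFAULT_MODELS": "gpt-4o",
        "DEFAULT_PINNED_MODELS": "gpt-4o,llama3.1",
        "MODEL_ORDER_LIST": ["gpt-4o", "llama3.1"],
    },
)
print(resp.json()["DEFAULT_PINNED_MODELS"])
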
@@ -7,6 +7,8 @@ from open_webui.models.feedbacks import (
     FeedbackModel,
     FeedbackResponse,
     FeedbackForm,
+    FeedbackUserResponse,
+    FeedbackListResponse,
     Feedbacks,
 )

@@ -56,35 +58,10 @@ async def update_config(
     }


-class UserResponse(BaseModel):
-    id: str
-    name: str
-    email: str
-    role: str = "pending"
-
-    last_active_at: int  # timestamp in epoch
-    updated_at: int  # timestamp in epoch
-    created_at: int  # timestamp in epoch
-
-
-class FeedbackUserResponse(FeedbackResponse):
-    user: Optional[UserResponse] = None
-
-
-@router.get("/feedbacks/all", response_model=list[FeedbackUserResponse])
+@router.get("/feedbacks/all", response_model=list[FeedbackResponse])
 async def get_all_feedbacks(user=Depends(get_admin_user)):
     feedbacks = Feedbacks.get_all_feedbacks()
-    feedback_list = []
-    for feedback in feedbacks:
-        user = Users.get_user_by_id(feedback.user_id)
-        feedback_list.append(
-            FeedbackUserResponse(
-                **feedback.model_dump(),
-                user=UserResponse(**user.model_dump()) if user else None,
-            )
-        )
-    return feedback_list
+    return feedbacks


 @router.delete("/feedbacks/all")

@@ -111,6 +88,31 @@ async def delete_feedbacks(user=Depends(get_verified_user)):
     return success


+PAGE_ITEM_COUNT = 30
+
+
+@router.get("/feedbacks/list", response_model=FeedbackListResponse)
+async def get_feedbacks(
+    order_by: Optional[str] = None,
+    direction: Optional[str] = None,
+    page: Optional[int] = 1,
+    user=Depends(get_admin_user),
+):
+    limit = PAGE_ITEM_COUNT
+
+    page = max(1, page)
+    skip = (page - 1) * limit
+
+    filter = {}
+    if order_by:
+        filter["order_by"] = order_by
+    if direction:
+        filter["direction"] = direction
+
+    result = Feedbacks.get_feedback_items(filter=filter, skip=skip, limit=limit)
+    return result
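
The new /feedbacks/list route pages 30 items at a time and defers filtering and sorting to Feedbacks.get_feedback_items. A sketch of calling it; the accepted order_by values and the exact FeedbackListResponse fields are not shown in this diff, so both are left generic:

import requests

BASE_URL = "http://localhost:8080"  # assumed
HEADERS = {"Authorization": "Bearer sk-..."}  # admin token placeholder

resp = requests.get(
    f"{BASE_URL}/api/v1/evaluations/feedbacks/list",
    headers=HEADERS,
    params={"order_by": "created_at", "direction": "desc", "page": 1},
)
result = resp.json()
print(sorted(result.keys()))  # inspect the FeedbackListResponse shape
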
||||||
@router.post("/feedback", response_model=FeedbackModel)
|
@router.post("/feedback", response_model=FeedbackModel)
|
||||||
async def create_feedback(
|
async def create_feedback(
|
||||||
request: Request,
|
request: Request,
|
||||||
|
|
|
||||||
|
|
@@ -22,6 +22,7 @@ from fastapi import (
 )

 from fastapi.responses import FileResponse, StreamingResponse

 from open_webui.constants import ERROR_MESSAGES
 from open_webui.env import SRC_LOG_LEVELS
 from open_webui.retrieval.vector.factory import VECTOR_DB_CLIENT
@@ -34,12 +35,19 @@ from open_webui.models.files import (
     Files,
 )
 from open_webui.models.knowledge import Knowledges
+from open_webui.models.groups import Groups

 from open_webui.routers.knowledge import get_knowledge, get_knowledge_list
 from open_webui.routers.retrieval import ProcessFileForm, process_file
 from open_webui.routers.audio import transcribe

 from open_webui.storage.provider import Storage

 from open_webui.utils.auth import get_admin_user, get_verified_user
+from open_webui.utils.access_control import has_access

 from pydantic import BaseModel

 log = logging.getLogger(__name__)
@@ -53,31 +61,37 @@ router = APIRouter()
 ############################


+# TODO: Optimize this function to use the knowledge_file table for faster lookups.
 def has_access_to_file(
     file_id: Optional[str], access_type: str, user=Depends(get_verified_user)
 ) -> bool:
     file = Files.get_file_by_id(file_id)
     log.debug(f"Checking if user has {access_type} access to file")

     if not file:
         raise HTTPException(
             status_code=status.HTTP_404_NOT_FOUND,
             detail=ERROR_MESSAGES.NOT_FOUND,
         )

-    has_access = False
-    knowledge_base_id = file.meta.get("collection_name") if file.meta else None
+    knowledge_bases = Knowledges.get_knowledges_by_file_id(file_id)
+    user_group_ids = {group.id for group in Groups.get_groups_by_member_id(user.id)}

+    for knowledge_base in knowledge_bases:
+        if knowledge_base.user_id == user.id or has_access(
+            user.id, access_type, knowledge_base.access_control, user_group_ids
+        ):
+            return True
+
+    knowledge_base_id = file.meta.get("collection_name") if file.meta else None
     if knowledge_base_id:
         knowledge_bases = Knowledges.get_knowledge_bases_by_user_id(
             user.id, access_type
         )
         for knowledge_base in knowledge_bases:
             if knowledge_base.id == knowledge_base_id:
-                has_access = True
-                break
+                return True

-    return has_access
+    return False


 ############################
@@ -102,7 +116,7 @@ def process_uploaded_file(request, file, file_path, file_item, file_metadata, us
             )
         ):
             file_path = Storage.get_file(file_path)
-            result = transcribe(request, file_path, file_metadata)
+            result = transcribe(request, file_path, file_metadata, user)

             process_file(
                 request,
@@ -115,6 +129,10 @@ def process_uploaded_file(request, file, file_path, file_item, file_metadata, us
             request.app.state.config.CONTENT_EXTRACTION_ENGINE == "external"
         ):
             process_file(request, ProcessFileForm(file_id=file_item.id), user=user)
+        else:
+            raise Exception(
+                f"File type {file.content_type} is not supported for processing"
+            )
     else:
         log.info(
             f"File type {file.content_type} is not provided, but trying to process anyway"
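Note: the rewritten has_access_to_file returns on the first knowledge base that grants access. A standalone sketch of that early-return check, with dicts standing in for the ORM rows; the access_control shape used here is an assumption for illustration only:

def file_access_granted(user_id, access_type, knowledge_bases, user_group_ids):
    for kb in knowledge_bases:
        # The owner always has access.
        if kb["user_id"] == user_id:
            return True
        # Otherwise, grant access if any of the user's groups is listed.
        allowed = set(kb.get("access_control", {}).get(access_type, {}).get("group_ids", []))
        if allowed & user_group_ids:
            return True
    return False

kbs = [{"user_id": "u2", "access_control": {"read": {"group_ids": ["g1"]}}}]
assert file_access_granted("u1", "read", kbs, {"g1"})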
@@ -46,11 +46,35 @@ router = APIRouter()


 @router.get("/", response_model=list[FolderNameIdResponse])
-async def get_folders(user=Depends(get_verified_user)):
+async def get_folders(request: Request, user=Depends(get_verified_user)):
+    if request.app.state.config.ENABLE_FOLDERS is False:
+        raise HTTPException(
+            status_code=status.HTTP_403_FORBIDDEN,
+            detail=ERROR_MESSAGES.ACCESS_PROHIBITED,
+        )
+
+    if user.role != "admin" and not has_permission(
+        user.id,
+        "features.folders",
+        request.app.state.config.USER_PERMISSIONS,
+    ):
+        raise HTTPException(
+            status_code=status.HTTP_403_FORBIDDEN,
+            detail=ERROR_MESSAGES.ACCESS_PROHIBITED,
+        )
+
     folders = Folders.get_folders_by_user_id(user.id)

     # Verify folder data integrity
+    folder_list = []
     for folder in folders:
+        if folder.parent_id and not Folders.get_folder_by_id_and_user_id(
+            folder.parent_id, user.id
+        ):
+            folder = Folders.update_folder_parent_id_by_id_and_user_id(
+                folder.id, user.id, None
+            )
+
         if folder.data:
             if "files" in folder.data:
                 valid_files = []
@@ -74,12 +98,9 @@ async def get_folders(user=Depends(get_verified_user)):
                     folder.id, user.id, FolderUpdateForm(data=folder.data)
                 )

-    return [
-        {
-            **folder.model_dump(),
-        }
-        for folder in folders
-    ]
+        folder_list.append(FolderNameIdResponse(**folder.model_dump()))
+
+    return folder_list


 ############################
@@ -253,7 +274,10 @@ async def update_folder_is_expanded_by_id(


 @router.delete("/{id}")
 async def delete_folder_by_id(
-    request: Request, id: str, user=Depends(get_verified_user)
+    request: Request,
+    id: str,
+    delete_contents: Optional[bool] = True,
+    user=Depends(get_verified_user),
 ):
     if Chats.count_chats_by_folder_id_and_user_id(id, user.id):
         chat_delete_permission = has_permission(
@@ -265,12 +289,21 @@ async def delete_folder_by_id(
                 detail=ERROR_MESSAGES.ACCESS_PROHIBITED,
             )

-    folder = Folders.get_folder_by_id_and_user_id(id, user.id)
+    folders = []
+    folders.append(Folders.get_folder_by_id_and_user_id(id, user.id))
+    while folders:
+        folder = folders.pop()
         if folder:
             try:
                 folder_ids = Folders.delete_folder_by_id_and_user_id(id, user.id)

                 for folder_id in folder_ids:
+                    if delete_contents:
                         Chats.delete_chats_by_user_id_and_folder_id(user.id, folder_id)
+                    else:
+                        Chats.move_chats_by_user_id_and_folder_id(
+                            user.id, folder_id, None
+                        )

                 return True
             except Exception as e:
@@ -280,6 +313,13 @@ async def delete_folder_by_id(
                     status_code=status.HTTP_400_BAD_REQUEST,
                     detail=ERROR_MESSAGES.DEFAULT("Error deleting folder"),
                 )
+            finally:
+                # Get all subfolders
+                subfolders = Folders.get_folders_by_parent_id_and_user_id(
+                    folder.id, user.id
+                )
+                folders.extend(subfolders)
+
         else:
             raise HTTPException(
                 status_code=status.HTTP_404_NOT_FOUND,
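Note: delete_folder_by_id now walks subfolders with an explicit stack instead of recursion. A minimal sketch of that traversal over an in-memory parent map (hypothetical names, not the Folders API):

def delete_subtree(root_id: str, children: dict[str, list[str]]) -> list[str]:
    deleted = []
    stack = [root_id]
    while stack:
        folder_id = stack.pop()
        deleted.append(folder_id)                   # delete this folder
        stack.extend(children.get(folder_id, []))   # then queue its subfolders
    return deleted

tree = {"root": ["a", "b"], "a": ["a1"]}
assert set(delete_subtree("root", tree)) == {"root", "a", "b", "a1"}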
@@ -10,6 +10,7 @@ from open_webui.models.functions import (
     FunctionForm,
     FunctionModel,
     FunctionResponse,
+    FunctionUserResponse,
     FunctionWithValvesModel,
     Functions,
 )
@@ -42,6 +43,11 @@ async def get_functions(user=Depends(get_verified_user)):
     return Functions.get_functions()


+@router.get("/list", response_model=list[FunctionUserResponse])
+async def get_function_list(user=Depends(get_admin_user)):
+    return Functions.get_function_list()
+
+
 ############################
 # ExportFunctions
 ############################
@@ -3,7 +3,7 @@ from pathlib import Path
 from typing import Optional
 import logging

-from open_webui.models.users import Users
+from open_webui.models.users import Users, UserInfoResponse
 from open_webui.models.groups import (
     Groups,
     GroupForm,
@@ -31,11 +31,18 @@ router = APIRouter()


 @router.get("/", response_model=list[GroupResponse])
-async def get_groups(user=Depends(get_verified_user)):
-    if user.role == "admin":
-        return Groups.get_groups()
-    else:
-        return Groups.get_groups_by_member_id(user.id)
+async def get_groups(share: Optional[bool] = None, user=Depends(get_verified_user)):
+    filter = {}
+    if user.role != "admin":
+        filter["member_id"] = user.id
+
+    if share is not None:
+        filter["share"] = share
+
+    groups = Groups.get_groups(filter=filter)
+
+    return groups


 ############################
@@ -48,7 +55,10 @@ async def create_new_group(form_data: GroupForm, user=Depends(get_admin_user)):
     try:
         group = Groups.insert_new_group(user.id, form_data)
         if group:
-            return group
+            return GroupResponse(
+                **group.model_dump(),
+                member_count=Groups.get_group_member_count_by_id(group.id),
+            )
         else:
             raise HTTPException(
                 status_code=status.HTTP_400_BAD_REQUEST,
@@ -71,7 +81,10 @@ async def create_new_group(form_data: GroupForm, user=Depends(get_admin_user)):
 async def get_group_by_id(id: str, user=Depends(get_admin_user)):
     group = Groups.get_group_by_id(id)
     if group:
-        return group
+        return GroupResponse(
+            **group.model_dump(),
+            member_count=Groups.get_group_member_count_by_id(group.id),
+        )
     else:
         raise HTTPException(
             status_code=status.HTTP_401_UNAUTHORIZED,
@@ -79,6 +92,50 @@ async def get_group_by_id(id: str, user=Depends(get_admin_user)):
         )


+############################
+# ExportGroupById
+############################
+
+
+class GroupExportResponse(GroupResponse):
+    user_ids: list[str] = []
+    pass
+
+
+@router.get("/id/{id}/export", response_model=Optional[GroupExportResponse])
+async def export_group_by_id(id: str, user=Depends(get_admin_user)):
+    group = Groups.get_group_by_id(id)
+    if group:
+        return GroupExportResponse(
+            **group.model_dump(),
+            member_count=Groups.get_group_member_count_by_id(group.id),
+            user_ids=Groups.get_group_user_ids_by_id(group.id),
+        )
+    else:
+        raise HTTPException(
+            status_code=status.HTTP_401_UNAUTHORIZED,
+            detail=ERROR_MESSAGES.NOT_FOUND,
+        )
+
+
+############################
+# GetUsersInGroupById
+############################
+
+
+@router.post("/id/{id}/users", response_model=list[UserInfoResponse])
+async def get_users_in_group(id: str, user=Depends(get_admin_user)):
+    try:
+        users = Users.get_users_by_group_id(id)
+        return users
+    except Exception as e:
+        log.exception(f"Error adding users to group {id}: {e}")
+        raise HTTPException(
+            status_code=status.HTTP_400_BAD_REQUEST,
+            detail=ERROR_MESSAGES.DEFAULT(e),
+        )
+
+
 ############################
 # UpdateGroupById
 ############################
@@ -89,12 +146,12 @@ async def update_group_by_id(
     id: str, form_data: GroupUpdateForm, user=Depends(get_admin_user)
 ):
     try:
-        if form_data.user_ids:
-            form_data.user_ids = Users.get_valid_user_ids(form_data.user_ids)
-
         group = Groups.update_group_by_id(id, form_data)
         if group:
-            return group
+            return GroupResponse(
+                **group.model_dump(),
+                member_count=Groups.get_group_member_count_by_id(group.id),
+            )
         else:
             raise HTTPException(
                 status_code=status.HTTP_400_BAD_REQUEST,
@@ -123,7 +180,10 @@ async def add_user_to_group(

     group = Groups.add_users_to_group(id, form_data.user_ids)
     if group:
-        return group
+        return GroupResponse(
+            **group.model_dump(),
+            member_count=Groups.get_group_member_count_by_id(group.id),
+        )
     else:
         raise HTTPException(
             status_code=status.HTTP_400_BAD_REQUEST,
@@ -144,7 +204,10 @@ async def remove_users_from_group(
     try:
         group = Groups.remove_users_from_group(id, form_data.user_ids)
         if group:
-            return group
+            return GroupResponse(
+                **group.model_dump(),
+                member_count=Groups.get_group_member_count_by_id(group.id),
+            )
         else:
             raise HTTPException(
                 status_code=status.HTTP_400_BAD_REQUEST,

File diff suppressed because it is too large
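Note: every group route above now wraps the ORM row in a GroupResponse carrying a computed member_count. A sketch of that Pydantic pattern (the fields besides member_count are invented for illustration):

from typing import Optional
from pydantic import BaseModel

class GroupModel(BaseModel):
    id: str
    name: str

class GroupResponse(GroupModel):
    member_count: Optional[int] = None

group = GroupModel(id="g1", name="Editors")
# The routes splat the stored row and attach the count separately:
response = GroupResponse(**group.model_dump(), member_count=3)
assert response.member_count == 3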
@@ -1,6 +1,7 @@
 from typing import List, Optional
 from pydantic import BaseModel
 from fastapi import APIRouter, Depends, HTTPException, status, Request, Query
+from fastapi.concurrency import run_in_threadpool
 import logging

 from open_webui.models.knowledge import (
@@ -41,97 +42,38 @@ router = APIRouter()


 @router.get("/", response_model=list[KnowledgeUserResponse])
 async def get_knowledge(user=Depends(get_verified_user)):
+    # Return knowledge bases with read access
     knowledge_bases = []

     if user.role == "admin" and BYPASS_ADMIN_ACCESS_CONTROL:
         knowledge_bases = Knowledges.get_knowledge_bases()
     else:
         knowledge_bases = Knowledges.get_knowledge_bases_by_user_id(user.id, "read")

-    # Get files for each knowledge base
-    knowledge_with_files = []
-    for knowledge_base in knowledge_bases:
-        files = []
-        if knowledge_base.data:
-            files = Files.get_file_metadatas_by_ids(
-                knowledge_base.data.get("file_ids", [])
-            )
-
-            # Check if all files exist
-            if len(files) != len(knowledge_base.data.get("file_ids", [])):
-                missing_files = list(
-                    set(knowledge_base.data.get("file_ids", []))
-                    - set([file.id for file in files])
-                )
-                if missing_files:
-                    data = knowledge_base.data or {}
-                    file_ids = data.get("file_ids", [])
-
-                    for missing_file in missing_files:
-                        file_ids.remove(missing_file)
-
-                    data["file_ids"] = file_ids
-                    Knowledges.update_knowledge_data_by_id(
-                        id=knowledge_base.id, data=data
-                    )
-
-                    files = Files.get_file_metadatas_by_ids(file_ids)
-
-        knowledge_with_files.append(
-            KnowledgeUserResponse(
-                **knowledge_base.model_dump(),
-                files=files,
-            )
-        )
-    return knowledge_with_files
+    return [
+        KnowledgeUserResponse(
+            **knowledge_base.model_dump(),
+            files=Knowledges.get_file_metadatas_by_id(knowledge_base.id),
+        )
+        for knowledge_base in knowledge_bases
+    ]


 @router.get("/list", response_model=list[KnowledgeUserResponse])
 async def get_knowledge_list(user=Depends(get_verified_user)):
+    # Return knowledge bases with write access
     knowledge_bases = []

     if user.role == "admin" and BYPASS_ADMIN_ACCESS_CONTROL:
         knowledge_bases = Knowledges.get_knowledge_bases()
     else:
         knowledge_bases = Knowledges.get_knowledge_bases_by_user_id(user.id, "write")

-    # Get files for each knowledge base
-    knowledge_with_files = []
-    for knowledge_base in knowledge_bases:
-        files = []
-        if knowledge_base.data:
-            files = Files.get_file_metadatas_by_ids(
-                knowledge_base.data.get("file_ids", [])
-            )
-
-            # Check if all files exist
-            if len(files) != len(knowledge_base.data.get("file_ids", [])):
-                missing_files = list(
-                    set(knowledge_base.data.get("file_ids", []))
-                    - set([file.id for file in files])
-                )
-                if missing_files:
-                    data = knowledge_base.data or {}
-                    file_ids = data.get("file_ids", [])
-
-                    for missing_file in missing_files:
-                        file_ids.remove(missing_file)
-
-                    data["file_ids"] = file_ids
-                    Knowledges.update_knowledge_data_by_id(
-                        id=knowledge_base.id, data=data
-                    )
-
-                    files = Files.get_file_metadatas_by_ids(file_ids)
-
-        knowledge_with_files.append(
-            KnowledgeUserResponse(
-                **knowledge_base.model_dump(),
-                files=files,
-            )
-        )
-    return knowledge_with_files
+    return [
+        KnowledgeUserResponse(
+            **knowledge_base.model_dump(),
+            files=Knowledges.get_file_metadatas_by_id(knowledge_base.id),
+        )
+        for knowledge_base in knowledge_bases
+    ]


 ############################
@@ -191,26 +133,9 @@ async def reindex_knowledge_files(request: Request, user=Depends(get_verified_us

     log.info(f"Starting reindexing for {len(knowledge_bases)} knowledge bases")

-    deleted_knowledge_bases = []
-
     for knowledge_base in knowledge_bases:
-        # -- Robust error handling for missing or invalid data
-        if not knowledge_base.data or not isinstance(knowledge_base.data, dict):
-            log.warning(
-                f"Knowledge base {knowledge_base.id} has no data or invalid data ({knowledge_base.data!r}). Deleting."
-            )
-            try:
-                Knowledges.delete_knowledge_by_id(id=knowledge_base.id)
-                deleted_knowledge_bases.append(knowledge_base.id)
-            except Exception as e:
-                log.error(
-                    f"Failed to delete invalid knowledge base {knowledge_base.id}: {e}"
-                )
-            continue
-
-        try:
-            file_ids = knowledge_base.data.get("file_ids", [])
-            files = Files.get_files_by_ids(file_ids)
+        files = Knowledges.get_files_by_id(knowledge_base.id)
+
         try:
             if VECTOR_DB_CLIENT.has_collection(collection_name=knowledge_base.id):
                 VECTOR_DB_CLIENT.delete_collection(
@@ -223,7 +148,8 @@ async def reindex_knowledge_files(request: Request, user=Depends(get_verified_us
         failed_files = []
         for file in files:
             try:
-                process_file(
+                await run_in_threadpool(
+                    process_file,
                     request,
                     ProcessFileForm(
                         file_id=file.id, collection_name=knowledge_base.id
@@ -249,9 +175,7 @@ async def reindex_knowledge_files(request: Request, user=Depends(get_verified_us
     for failed in failed_files:
         log.warning(f"File ID: {failed['file_id']}, Error: {failed['error']}")

-    log.info(
-        f"Reindexing completed. Deleted {len(deleted_knowledge_bases)} invalid knowledge bases: {deleted_knowledge_bases}"
-    )
+    log.info(f"Reindexing completed.")
     return True


@@ -269,19 +193,15 @@ async def get_knowledge_by_id(id: str, user=Depends(get_verified_user)):
     knowledge = Knowledges.get_knowledge_by_id(id=id)

     if knowledge:
-
         if (
             user.role == "admin"
             or knowledge.user_id == user.id
             or has_access(user.id, "read", knowledge.access_control)
         ):
-
-            file_ids = knowledge.data.get("file_ids", []) if knowledge.data else []
-            files = Files.get_file_metadatas_by_ids(file_ids)
-
             return KnowledgeFilesResponse(
                 **knowledge.model_dump(),
-                files=files,
+                files=Knowledges.get_file_metadatas_by_id(knowledge.id),
             )
     else:
         raise HTTPException(
@@ -333,12 +253,9 @@ async def update_knowledge_by_id(

     knowledge = Knowledges.update_knowledge_by_id(id=id, form_data=form_data)
     if knowledge:
-        file_ids = knowledge.data.get("file_ids", []) if knowledge.data else []
-        files = Files.get_file_metadatas_by_ids(file_ids)
-
         return KnowledgeFilesResponse(
             **knowledge.model_dump(),
-            files=files,
+            files=Knowledges.get_file_metadatas_by_id(knowledge.id),
         )
     else:
         raise HTTPException(
@@ -364,7 +281,6 @@ def add_file_to_knowledge_by_id(
     user=Depends(get_verified_user),
 ):
     knowledge = Knowledges.get_knowledge_by_id(id=id)
-
     if not knowledge:
         raise HTTPException(
             status_code=status.HTTP_400_BAD_REQUEST,
@@ -393,6 +309,11 @@ def add_file_to_knowledge_by_id(
             detail=ERROR_MESSAGES.FILE_NOT_PROCESSED,
         )

+    # Add file to knowledge base
+    Knowledges.add_file_to_knowledge_by_id(
+        knowledge_id=id, file_id=form_data.file_id, user_id=user.id
+    )
+
     # Add content to the vector database
     try:
         process_file(
@@ -408,31 +329,9 @@ def add_file_to_knowledge_by_id(
         )

     if knowledge:
-        data = knowledge.data or {}
-        file_ids = data.get("file_ids", [])
-
-        if form_data.file_id not in file_ids:
-            file_ids.append(form_data.file_id)
-            data["file_ids"] = file_ids
-
-            knowledge = Knowledges.update_knowledge_data_by_id(id=id, data=data)
-
-            if knowledge:
-                files = Files.get_file_metadatas_by_ids(file_ids)
-
-                return KnowledgeFilesResponse(
-                    **knowledge.model_dump(),
-                    files=files,
-                )
-            else:
-                raise HTTPException(
-                    status_code=status.HTTP_400_BAD_REQUEST,
-                    detail=ERROR_MESSAGES.DEFAULT("knowledge"),
-                )
-        else:
-            raise HTTPException(
-                status_code=status.HTTP_400_BAD_REQUEST,
-                detail=ERROR_MESSAGES.DEFAULT("file_id"),
-            )
+        return KnowledgeFilesResponse(
+            **knowledge.model_dump(),
+            files=Knowledges.get_file_metadatas_by_id(knowledge.id),
+        )
     else:
         raise HTTPException(
@@ -492,14 +391,9 @@ def update_file_from_knowledge_by_id(
         )

     if knowledge:
-        data = knowledge.data or {}
-        file_ids = data.get("file_ids", [])
-
-        files = Files.get_file_metadatas_by_ids(file_ids)
-
         return KnowledgeFilesResponse(
             **knowledge.model_dump(),
-            files=files,
+            files=Knowledges.get_file_metadatas_by_id(knowledge.id),
         )
     else:
         raise HTTPException(
@@ -544,11 +438,19 @@ def remove_file_from_knowledge_by_id(
             detail=ERROR_MESSAGES.NOT_FOUND,
         )

+    Knowledges.remove_file_from_knowledge_by_id(
+        knowledge_id=id, file_id=form_data.file_id
+    )
+
     # Remove content from the vector database
     try:
         VECTOR_DB_CLIENT.delete(
             collection_name=knowledge.id, filter={"file_id": form_data.file_id}
-        )
+        )  # Remove by file_id first
+
+        VECTOR_DB_CLIENT.delete(
+            collection_name=knowledge.id, filter={"hash": file.hash}
+        )  # Remove by hash as well in case of duplicates
     except Exception as e:
         log.debug("This was most likely caused by bypassing embedding processing")
         log.debug(e)
@@ -569,31 +471,9 @@ def remove_file_from_knowledge_by_id(
         Files.delete_file_by_id(form_data.file_id)

     if knowledge:
-        data = knowledge.data or {}
-        file_ids = data.get("file_ids", [])
-
-        if form_data.file_id in file_ids:
-            file_ids.remove(form_data.file_id)
-            data["file_ids"] = file_ids
-
-            knowledge = Knowledges.update_knowledge_data_by_id(id=id, data=data)
-
-            if knowledge:
-                files = Files.get_file_metadatas_by_ids(file_ids)
-
-                return KnowledgeFilesResponse(
-                    **knowledge.model_dump(),
-                    files=files,
-                )
-            else:
-                raise HTTPException(
-                    status_code=status.HTTP_400_BAD_REQUEST,
-                    detail=ERROR_MESSAGES.DEFAULT("knowledge"),
-                )
-        else:
-            raise HTTPException(
-                status_code=status.HTTP_400_BAD_REQUEST,
-                detail=ERROR_MESSAGES.DEFAULT("file_id"),
-            )
+        return KnowledgeFilesResponse(
+            **knowledge.model_dump(),
+            files=Knowledges.get_file_metadatas_by_id(knowledge.id),
+        )
     else:
         raise HTTPException(
@@ -695,8 +575,7 @@ async def reset_knowledge_by_id(id: str, user=Depends(get_verified_user)):
         log.debug(e)
         pass

-    knowledge = Knowledges.update_knowledge_data_by_id(id=id, data={"file_ids": []})
-
+    knowledge = Knowledges.reset_knowledge_by_id(id=id)
     return knowledge


@@ -706,7 +585,7 @@ async def reset_knowledge_by_id(id: str, user=Depends(get_verified_user)):


 @router.post("/{id}/files/batch/add", response_model=Optional[KnowledgeFilesResponse])
-def add_files_to_knowledge_batch(
+async def add_files_to_knowledge_batch(
     request: Request,
     id: str,
     form_data: list[KnowledgeFileIdForm],
@@ -746,7 +625,7 @@ def add_files_to_knowledge_batch(

     # Process files
     try:
-        result = process_files_batch(
+        result = await process_files_batch(
             request=request,
             form_data=BatchProcessFilesForm(files=files, collection_name=id),
             user=user,
@@ -757,25 +636,19 @@ def add_files_to_knowledge_batch(
         )
         raise HTTPException(status_code=status.HTTP_400_BAD_REQUEST, detail=str(e))

-    # Add successful files to knowledge base
-    data = knowledge.data or {}
-    existing_file_ids = data.get("file_ids", [])
-
     # Only add files that were successfully processed
     successful_file_ids = [r.file_id for r in result.results if r.status == "completed"]
     for file_id in successful_file_ids:
-        if file_id not in existing_file_ids:
-            existing_file_ids.append(file_id)
-
-    data["file_ids"] = existing_file_ids
-    knowledge = Knowledges.update_knowledge_data_by_id(id=id, data=data)
+        Knowledges.add_file_to_knowledge_by_id(
+            knowledge_id=id, file_id=file_id, user_id=user.id
+        )

     # If there were any errors, include them in the response
     if result.errors:
         error_details = [f"{err.file_id}: {err.error}" for err in result.errors]
         return KnowledgeFilesResponse(
             **knowledge.model_dump(),
-            files=Files.get_file_metadatas_by_ids(existing_file_ids),
+            files=Knowledges.get_file_metadatas_by_id(knowledge.id),
             warnings={
                 "message": "Some files failed to process",
                 "errors": error_details,
@@ -784,5 +657,5 @@ def add_files_to_knowledge_batch(

     return KnowledgeFilesResponse(
         **knowledge.model_dump(),
-        files=Files.get_file_metadatas_by_ids(existing_file_ids),
+        files=Knowledges.get_file_metadatas_by_id(knowledge.id),
     )
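Note: the knowledge routes above drop the JSON file_ids list in favor of calls like add_file_to_knowledge_by_id and get_file_metadatas_by_id, which points at a join table between knowledge bases and files (the files.py TODO mentions a knowledge_file table). A sketch of that association in SQLAlchemy; the table and column names are guesses, not the project's actual schema:

from sqlalchemy import Column, MetaData, String, Table, create_engine, insert, select

metadata = MetaData()
# Hypothetical join table linking knowledge bases to files.
knowledge_file = Table(
    "knowledge_file",
    metadata,
    Column("knowledge_id", String, primary_key=True),
    Column("file_id", String, primary_key=True),
    Column("user_id", String),
)

engine = create_engine("sqlite:///:memory:")
metadata.create_all(engine)

with engine.begin() as conn:
    conn.execute(insert(knowledge_file).values(knowledge_id="kb1", file_id="f1", user_id="u1"))
    rows = conn.execute(
        select(knowledge_file.c.file_id).where(knowledge_file.c.knowledge_id == "kb1")
    ).fetchall()
    assert [row.file_id for row in rows] == ["f1"]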
@@ -1,6 +1,7 @@
 from fastapi import APIRouter, Depends, HTTPException, Request
 from pydantic import BaseModel
 import logging
+import asyncio
 from typing import Optional

 from open_webui.models.memories import Memories, MemoryModel
@@ -17,7 +18,7 @@ router = APIRouter()

 @router.get("/ef")
 async def get_embeddings(request: Request):
-    return {"result": request.app.state.EMBEDDING_FUNCTION("hello world")}
+    return {"result": await request.app.state.EMBEDDING_FUNCTION("hello world")}


 ############################
@@ -51,15 +52,15 @@ async def add_memory(
 ):
     memory = Memories.insert_new_memory(user.id, form_data.content)

+    vector = await request.app.state.EMBEDDING_FUNCTION(memory.content, user=user)
+
     VECTOR_DB_CLIENT.upsert(
         collection_name=f"user-memory-{user.id}",
         items=[
             {
                 "id": memory.id,
                 "text": memory.content,
-                "vector": request.app.state.EMBEDDING_FUNCTION(
-                    memory.content, user=user
-                ),
+                "vector": vector,
                 "metadata": {"created_at": memory.created_at},
             }
         ],
@@ -86,9 +87,11 @@ async def query_memory(
     if not memories:
         raise HTTPException(status_code=404, detail="No memories found for user")

+    vector = await request.app.state.EMBEDDING_FUNCTION(form_data.content, user=user)
+
     results = VECTOR_DB_CLIENT.search(
         collection_name=f"user-memory-{user.id}",
-        vectors=[request.app.state.EMBEDDING_FUNCTION(form_data.content, user=user)],
+        vectors=[vector],
         limit=form_data.k,
     )
@@ -105,21 +108,28 @@ async def reset_memory_from_vector_db(
     VECTOR_DB_CLIENT.delete_collection(f"user-memory-{user.id}")

     memories = Memories.get_memories_by_user_id(user.id)

+    # Generate vectors in parallel
+    vectors = await asyncio.gather(
+        *[
+            request.app.state.EMBEDDING_FUNCTION(memory.content, user=user)
+            for memory in memories
+        ]
+    )
+
     VECTOR_DB_CLIENT.upsert(
         collection_name=f"user-memory-{user.id}",
         items=[
             {
                 "id": memory.id,
                 "text": memory.content,
-                "vector": request.app.state.EMBEDDING_FUNCTION(
-                    memory.content, user=user
-                ),
+                "vector": vectors[idx],
                 "metadata": {
                     "created_at": memory.created_at,
                     "updated_at": memory.updated_at,
                 },
             }
-            for memory in memories
+            for idx, memory in enumerate(memories)
         ],
     )
@@ -164,15 +174,15 @@ async def update_memory_by_id(
         raise HTTPException(status_code=404, detail="Memory not found")

     if form_data.content is not None:
+        vector = await request.app.state.EMBEDDING_FUNCTION(memory.content, user=user)
+
         VECTOR_DB_CLIENT.upsert(
             collection_name=f"user-memory-{user.id}",
             items=[
                 {
                     "id": memory.id,
                     "text": memory.content,
-                    "vector": request.app.state.EMBEDDING_FUNCTION(
-                        memory.content, user=user
-                    ),
+                    "vector": vector,
                     "metadata": {
                         "created_at": memory.created_at,
                         "updated_at": memory.updated_at,
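Note: the memories router now awaits the embedding function and, on reset, fans out one call per memory with asyncio.gather, pairing vectors[idx] with memories[idx]. A self-contained sketch of that fan-out with a stub embedder:

import asyncio

async def embed(text: str) -> list[float]:
    # Stub for request.app.state.EMBEDDING_FUNCTION; returns a fake vector.
    await asyncio.sleep(0)
    return [float(len(text))]

async def embed_all(contents: list[str]) -> list[list[float]]:
    # gather preserves input order, so index i matches content i.
    return await asyncio.gather(*[embed(c) for c in contents])

vectors = asyncio.run(embed_all(["alpha", "be"]))
assert vectors == [[5.0], [2.0]]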
@@ -5,11 +5,12 @@ import json
 import asyncio
 import logging

+from open_webui.models.groups import Groups
 from open_webui.models.models import (
     ModelForm,
     ModelModel,
     ModelResponse,
-    ModelUserResponse,
+    ModelListResponse,
     Models,
 )
@@ -35,17 +36,56 @@ log = logging.getLogger(__name__)
 router = APIRouter()


+def is_valid_model_id(model_id: str) -> bool:
+    return model_id and len(model_id) <= 256
+
+
 ###########################
 # GetModels
 ###########################


-@router.get("/", response_model=list[ModelUserResponse])
-async def get_models(id: Optional[str] = None, user=Depends(get_verified_user)):
-    if user.role == "admin" and BYPASS_ADMIN_ACCESS_CONTROL:
-        return Models.get_models()
-    else:
-        return Models.get_models_by_user_id(user.id)
+PAGE_ITEM_COUNT = 30
+
+
+@router.get(
+    "/list", response_model=ModelListResponse
+)  # do NOT use "/" as path, conflicts with main.py
+async def get_models(
+    query: Optional[str] = None,
+    view_option: Optional[str] = None,
+    tag: Optional[str] = None,
+    order_by: Optional[str] = None,
+    direction: Optional[str] = None,
+    page: Optional[int] = 1,
+    user=Depends(get_verified_user),
+):
+    limit = PAGE_ITEM_COUNT
+
+    page = max(1, page)
+    skip = (page - 1) * limit
+
+    filter = {}
+    if query:
+        filter["query"] = query
+    if view_option:
+        filter["view_option"] = view_option
+    if tag:
+        filter["tag"] = tag
+    if order_by:
+        filter["order_by"] = order_by
+    if direction:
+        filter["direction"] = direction
+
+    if not user.role == "admin" or not BYPASS_ADMIN_ACCESS_CONTROL:
+        groups = Groups.get_groups_by_member_id(user.id)
+        if groups:
+            filter["group_ids"] = [group.id for group in groups]
+
+        filter["user_id"] = user.id
+
+    return Models.search_models(user.id, filter=filter, skip=skip, limit=limit)


 ###########################
@@ -58,6 +98,30 @@ async def get_base_models(user=Depends(get_admin_user)):
     return Models.get_base_models()


+###########################
+# GetModelTags
+###########################
+
+
+@router.get("/tags", response_model=list[str])
+async def get_model_tags(user=Depends(get_verified_user)):
+    if user.role == "admin" and BYPASS_ADMIN_ACCESS_CONTROL:
+        models = Models.get_models()
+    else:
+        models = Models.get_models_by_user_id(user.id)
+
+    tags_set = set()
+    for model in models:
+        if model.meta:
+            meta = model.meta.model_dump()
+            for tag in meta.get("tags", []):
+                tags_set.add((tag.get("name")))
+
+    tags = [tag for tag in tags_set]
+    tags.sort()
+    return tags
+
+
 ############################
 # CreateNewModel
 ############################
@@ -84,6 +148,12 @@ async def create_new_model(
                 detail=ERROR_MESSAGES.MODEL_ID_TAKEN,
             )

+    if not is_valid_model_id(form_data.id):
+        raise HTTPException(
+            status_code=status.HTTP_400_BAD_REQUEST,
+            detail=ERROR_MESSAGES.MODEL_ID_TOO_LONG,
+        )
+
     else:
         model = Models.insert_new_model(form_data, user.id)
         if model:
@@ -101,8 +171,19 @@ create_new_model(


 @router.get("/export", response_model=list[ModelModel])
-async def export_models(user=Depends(get_admin_user)):
+async def export_models(request: Request, user=Depends(get_verified_user)):
+    if user.role != "admin" and not has_permission(
+        user.id, "workspace.models_export", request.app.state.config.USER_PERMISSIONS
+    ):
+        raise HTTPException(
+            status_code=status.HTTP_401_UNAUTHORIZED,
+            detail=ERROR_MESSAGES.UNAUTHORIZED,
+        )
+
+    if user.role == "admin" and BYPASS_ADMIN_ACCESS_CONTROL:
         return Models.get_models()
+    else:
+        return Models.get_models_by_user_id(user.id)


 ############################
@@ -116,15 +197,25 @@ class ModelsImportForm(BaseModel):

 @router.post("/import", response_model=bool)
 async def import_models(
-    user: str = Depends(get_admin_user), form_data: ModelsImportForm = (...)
+    request: Request,
+    user=Depends(get_verified_user),
+    form_data: ModelsImportForm = (...),
 ):
+    if user.role != "admin" and not has_permission(
+        user.id, "workspace.models_import", request.app.state.config.USER_PERMISSIONS
+    ):
+        raise HTTPException(
+            status_code=status.HTTP_401_UNAUTHORIZED,
+            detail=ERROR_MESSAGES.UNAUTHORIZED,
+        )
     try:
         data = form_data.models
         if isinstance(data, list):
             for model_data in data:
                 # Here, you can add logic to validate model_data if needed
                 model_id = model_data.get("id")
-                if model_id:
+                if model_id and is_valid_model_id(model_id):
                     existing_model = Models.get_model_by_id(model_id)
                     if existing_model:
                         # Update existing model
@@ -170,6 +261,10 @@ async def sync_models(
 ###########################


+class ModelIdForm(BaseModel):
+    id: str
+
+
 # Note: We're not using the typical url path param here, but instead using a query parameter to allow '/' in the id
 @router.get("/model", response_model=Optional[ModelResponse])
 async def get_model_by_id(id: str, user=Depends(get_verified_user)):
@@ -216,6 +311,7 @@ async def get_model_profile_image(id: str, user=Depends(get_verified_user)):
             )
         except Exception as e:
             pass
+
             return FileResponse(f"{STATIC_DIR}/favicon.png")
     else:
         return FileResponse(f"{STATIC_DIR}/favicon.png")
@@ -263,12 +359,10 @@ async def toggle_model_by_id(id: str, user=Depends(get_verified_user)):


 @router.post("/model/update", response_model=Optional[ModelModel])
 async def update_model_by_id(
-    id: str,
     form_data: ModelForm,
     user=Depends(get_verified_user),
 ):
-    model = Models.get_model_by_id(id)
+    model = Models.get_model_by_id(form_data.id)

     if not model:
         raise HTTPException(
             status_code=status.HTTP_401_UNAUTHORIZED,
@@ -285,7 +379,7 @@ async def update_model_by_id(
             detail=ERROR_MESSAGES.ACCESS_PROHIBITED,
         )

-    model = Models.update_model_by_id(id, form_data)
+    model = Models.update_model_by_id(form_data.id, ModelForm(**form_data.model_dump()))
     return model


@@ -294,9 +388,9 @@ async def update_model_by_id(
 ############################


-@router.delete("/model/delete", response_model=bool)
-async def delete_model_by_id(id: str, user=Depends(get_verified_user)):
-    model = Models.get_model_by_id(id)
+@router.post("/model/delete", response_model=bool)
+async def delete_model_by_id(form_data: ModelIdForm, user=Depends(get_verified_user)):
+    model = Models.get_model_by_id(form_data.id)
     if not model:
         raise HTTPException(
             status_code=status.HTTP_401_UNAUTHORIZED,
@@ -313,7 +407,7 @@ async def delete_model_by_id(id: str, user=Depends(get_verified_user)):
             detail=ERROR_MESSAGES.UNAUTHORIZED,
         )

-    result = Models.delete_model_by_id(id)
+    result = Models.delete_model_by_id(form_data.id)
     return result
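Note: the new models /list route builds a filter dict containing only the query parameters the caller actually supplied before delegating to Models.search_models. A sketch of that optional-filter pattern (the search itself is out of scope here):

from typing import Optional

def build_filter(
    query: Optional[str] = None,
    tag: Optional[str] = None,
    order_by: Optional[str] = None,
    direction: Optional[str] = None,
) -> dict:
    # Include a key only when the corresponding parameter was provided.
    filter = {}
    if query:
        filter["query"] = query
    if tag:
        filter["tag"] = tag
    if order_by:
        filter["order_by"] = order_by
    if direction:
        filter["direction"] = direction
    return filter

assert build_filter(query="llama", direction="desc") == {"query": "llama", "direction": "desc"}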
@ -16,8 +16,8 @@ from urllib.parse import urlparse
|
||||||
import aiohttp
|
import aiohttp
|
||||||
from aiocache import cached
|
from aiocache import cached
|
||||||
import requests
|
import requests
|
||||||
from urllib.parse import quote
|
|
||||||
|
|
||||||
|
from open_webui.utils.headers import include_user_info_headers
|
||||||
from open_webui.models.chats import Chats
|
from open_webui.models.chats import Chats
|
||||||
from open_webui.models.users import UserModel
|
from open_webui.models.users import UserModel
|
||||||
|
|
||||||
|
|
@ -82,22 +82,17 @@ async def send_get_request(url, key=None, user: UserModel = None):
|
||||||
timeout = aiohttp.ClientTimeout(total=AIOHTTP_CLIENT_TIMEOUT_MODEL_LIST)
|
timeout = aiohttp.ClientTimeout(total=AIOHTTP_CLIENT_TIMEOUT_MODEL_LIST)
|
||||||
try:
|
try:
|
||||||
async with aiohttp.ClientSession(timeout=timeout, trust_env=True) as session:
|
async with aiohttp.ClientSession(timeout=timeout, trust_env=True) as session:
|
||||||
async with session.get(
|
headers = {
|
||||||
url,
|
|
||||||
headers={
|
|
||||||
"Content-Type": "application/json",
|
"Content-Type": "application/json",
|
||||||
**({"Authorization": f"Bearer {key}"} if key else {}),
|
**({"Authorization": f"Bearer {key}"} if key else {}),
|
||||||
**(
|
|
||||||
{
|
|
||||||
"X-OpenWebUI-User-Name": quote(user.name, safe=" "),
|
|
||||||
"X-OpenWebUI-User-Id": user.id,
|
|
||||||
"X-OpenWebUI-User-Email": user.email,
|
|
||||||
"X-OpenWebUI-User-Role": user.role,
|
|
||||||
}
|
}
|
||||||
if ENABLE_FORWARD_USER_INFO_HEADERS and user
|
|
||||||
else {}
|
if ENABLE_FORWARD_USER_INFO_HEADERS and user:
|
||||||
),
|
headers = include_user_info_headers(headers, user)
|
||||||
},
|
|
||||||
|
async with session.get(
|
||||||
|
url,
|
||||||
|
headers=headers,
|
||||||
ssl=AIOHTTP_CLIENT_SESSION_SSL,
|
ssl=AIOHTTP_CLIENT_SESSION_SSL,
|
||||||
) as response:
|
) as response:
|
||||||
return await response.json()
|
return await response.json()
|
||||||
|
|
@ -133,28 +128,20 @@ async def send_post_request(
|
||||||
trust_env=True, timeout=aiohttp.ClientTimeout(total=AIOHTTP_CLIENT_TIMEOUT)
|
trust_env=True, timeout=aiohttp.ClientTimeout(total=AIOHTTP_CLIENT_TIMEOUT)
|
||||||
)
|
)
|
||||||
|
|
||||||
|
headers = {
|
||||||
|
"Content-Type": "application/json",
|
||||||
|
**({"Authorization": f"Bearer {key}"} if key else {}),
|
||||||
|
}
|
||||||
|
|
||||||
|
if ENABLE_FORWARD_USER_INFO_HEADERS and user:
|
||||||
|
headers = include_user_info_headers(headers, user)
|
||||||
|
if metadata and metadata.get("chat_id"):
|
||||||
|
headers["X-OpenWebUI-Chat-Id"] = metadata.get("chat_id")
|
||||||
|
|
||||||
r = await session.post(
|
r = await session.post(
|
||||||
url,
|
url,
|
||||||
data=payload,
|
data=payload,
|
||||||
headers={
|
headers=headers,
|
||||||
"Content-Type": "application/json",
|
|
||||||
**({"Authorization": f"Bearer {key}"} if key else {}),
|
|
||||||
**(
|
|
||||||
{
|
|
||||||
"X-OpenWebUI-User-Name": quote(user.name, safe=" "),
|
|
||||||
"X-OpenWebUI-User-Id": user.id,
|
|
||||||
"X-OpenWebUI-User-Email": user.email,
|
|
||||||
"X-OpenWebUI-User-Role": user.role,
|
|
||||||
**(
|
|
||||||
{"X-OpenWebUI-Chat-Id": metadata.get("chat_id")}
|
|
||||||
if metadata and metadata.get("chat_id")
|
|
||||||
else {}
|
|
||||||
),
|
|
||||||
}
|
|
||||||
if ENABLE_FORWARD_USER_INFO_HEADERS and user
|
|
||||||
else {}
|
|
||||||
),
|
|
||||||
},
|
|
||||||
ssl=AIOHTTP_CLIENT_SESSION_SSL,
|
ssl=AIOHTTP_CLIENT_SESSION_SSL,
|
||||||
)
|
)
|
||||||
|
|
||||||
|
|
@@ -246,21 +233,16 @@ async def verify_connection(
         timeout=aiohttp.ClientTimeout(total=AIOHTTP_CLIENT_TIMEOUT_MODEL_LIST),
     ) as session:
         try:
+            headers = {
+                **({"Authorization": f"Bearer {key}"} if key else {}),
+            }
+
+            if ENABLE_FORWARD_USER_INFO_HEADERS and user:
+                headers = include_user_info_headers(headers, user)
+
             async with session.get(
                 f"{url}/api/version",
-                headers={
-                    **({"Authorization": f"Bearer {key}"} if key else {}),
-                    **(
-                        {
-                            "X-OpenWebUI-User-Name": quote(user.name, safe=" "),
-                            "X-OpenWebUI-User-Id": user.id,
-                            "X-OpenWebUI-User-Email": user.email,
-                            "X-OpenWebUI-User-Role": user.role,
-                        }
-                        if ENABLE_FORWARD_USER_INFO_HEADERS and user
-                        else {}
-                    ),
-                },
+                headers=headers,
                 ssl=AIOHTTP_CLIENT_SESSION_SSL,
             ) as r:
                 if r.status != 200:

@@ -469,22 +451,17 @@ async def get_ollama_tags(
 
     r = None
     try:
+        headers = {
+            **({"Authorization": f"Bearer {key}"} if key else {}),
+        }
+
+        if ENABLE_FORWARD_USER_INFO_HEADERS and user:
+            headers = include_user_info_headers(headers, user)
+
         r = requests.request(
             method="GET",
             url=f"{url}/api/tags",
-            headers={
-                **({"Authorization": f"Bearer {key}"} if key else {}),
-                **(
-                    {
-                        "X-OpenWebUI-User-Name": quote(user.name, safe=" "),
-                        "X-OpenWebUI-User-Id": user.id,
-                        "X-OpenWebUI-User-Email": user.email,
-                        "X-OpenWebUI-User-Role": user.role,
-                    }
-                    if ENABLE_FORWARD_USER_INFO_HEADERS and user
-                    else {}
-                ),
-            },
+            headers=headers,
         )
         r.raise_for_status()

@@ -838,23 +815,18 @@ async def copy_model(
     key = get_api_key(url_idx, url, request.app.state.config.OLLAMA_API_CONFIGS)
 
     try:
+        headers = {
+            "Content-Type": "application/json",
+            **({"Authorization": f"Bearer {key}"} if key else {}),
+        }
+
+        if ENABLE_FORWARD_USER_INFO_HEADERS and user:
+            headers = include_user_info_headers(headers, user)
+
         r = requests.request(
             method="POST",
             url=f"{url}/api/copy",
-            headers={
-                "Content-Type": "application/json",
-                **({"Authorization": f"Bearer {key}"} if key else {}),
-                **(
-                    {
-                        "X-OpenWebUI-User-Name": quote(user.name, safe=" "),
-                        "X-OpenWebUI-User-Id": user.id,
-                        "X-OpenWebUI-User-Email": user.email,
-                        "X-OpenWebUI-User-Role": user.role,
-                    }
-                    if ENABLE_FORWARD_USER_INFO_HEADERS and user
-                    else {}
-                ),
-            },
+            headers=headers,
             data=form_data.model_dump_json(exclude_none=True).encode(),
         )
         r.raise_for_status()

@@ -907,25 +879,21 @@ async def delete_model(
     url = request.app.state.config.OLLAMA_BASE_URLS[url_idx]
     key = get_api_key(url_idx, url, request.app.state.config.OLLAMA_API_CONFIGS)
 
+    r = None
     try:
+        headers = {
+            "Content-Type": "application/json",
+            **({"Authorization": f"Bearer {key}"} if key else {}),
+        }
+
+        if ENABLE_FORWARD_USER_INFO_HEADERS and user:
+            headers = include_user_info_headers(headers, user)
+
         r = requests.request(
             method="DELETE",
             url=f"{url}/api/delete",
-            data=json.dumps(form_data).encode(),
-            headers={
-                "Content-Type": "application/json",
-                **({"Authorization": f"Bearer {key}"} if key else {}),
-                **(
-                    {
-                        "X-OpenWebUI-User-Name": quote(user.name, safe=" "),
-                        "X-OpenWebUI-User-Id": user.id,
-                        "X-OpenWebUI-User-Email": user.email,
-                        "X-OpenWebUI-User-Role": user.role,
-                    }
-                    if ENABLE_FORWARD_USER_INFO_HEADERS and user
-                    else {}
-                ),
-            },
+            headers=headers,
+            json=form_data,
         )
         r.raise_for_status()
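A side note on the delete_model hunk: switching from data=json.dumps(form_data).encode() to json=form_data lets requests do the serialization and set Content-Type itself, so the explicit header entry is now redundant but harmless. A quick illustration of that behavior:

import requests

# requests serializes the json= payload and adds Content-Type for us;
# the endpoint URL here is a hypothetical local Ollama instance.
req = requests.Request(
    method="DELETE",
    url="http://localhost:11434/api/delete",
    json={"name": "llama3"},
).prepare()

print(req.headers["Content-Type"])  # application/json
print(req.body)                     # b'{"name": "llama3"}'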
@@ -973,24 +941,16 @@ async def show_model_info(
     key = get_api_key(url_idx, url, request.app.state.config.OLLAMA_API_CONFIGS)
 
     try:
-        r = requests.request(
-            method="POST",
-            url=f"{url}/api/show",
-            headers={
-                "Content-Type": "application/json",
-                **({"Authorization": f"Bearer {key}"} if key else {}),
-                **(
-                    {
-                        "X-OpenWebUI-User-Name": quote(user.name, safe=" "),
-                        "X-OpenWebUI-User-Id": user.id,
-                        "X-OpenWebUI-User-Email": user.email,
-                        "X-OpenWebUI-User-Role": user.role,
-                    }
-                    if ENABLE_FORWARD_USER_INFO_HEADERS and user
-                    else {}
-                ),
-            },
-            data=json.dumps(form_data).encode(),
-        )
+        headers = {
+            "Content-Type": "application/json",
+            **({"Authorization": f"Bearer {key}"} if key else {}),
+        }
+
+        if ENABLE_FORWARD_USER_INFO_HEADERS and user:
+            headers = include_user_info_headers(headers, user)
+
+        r = requests.request(
+            method="POST", url=f"{url}/api/show", headers=headers, json=form_data
+        )
         r.raise_for_status()
@@ -1020,6 +980,10 @@ class GenerateEmbedForm(BaseModel):
     options: Optional[dict] = None
     keep_alive: Optional[Union[int, str]] = None
 
+    model_config = ConfigDict(
+        extra="allow",
+    )
+
 
 @router.post("/api/embed")
 @router.post("/api/embed/{url_idx}")
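The model_config = ConfigDict(extra="allow") added to GenerateEmbedForm lets undeclared fields pass through the form, which matters when proxying arbitrary options to the upstream API. A minimal sketch of the pydantic v2 behavior:

from pydantic import BaseModel, ConfigDict

class EmbedForm(BaseModel):
    # trimmed stand-in for GenerateEmbedForm
    model_config = ConfigDict(extra="allow")
    model: str

form = EmbedForm(model="all-minilm", truncate=True)  # undeclared field is kept
print(form.model_dump_json(exclude_none=True))
# {"model":"all-minilm","truncate":true}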
@@ -1060,23 +1024,18 @@ async def embed(
         form_data.model = form_data.model.replace(f"{prefix_id}.", "")
 
     try:
+        headers = {
+            "Content-Type": "application/json",
+            **({"Authorization": f"Bearer {key}"} if key else {}),
+        }
+
+        if ENABLE_FORWARD_USER_INFO_HEADERS and user:
+            headers = include_user_info_headers(headers, user)
+
         r = requests.request(
             method="POST",
             url=f"{url}/api/embed",
-            headers={
-                "Content-Type": "application/json",
-                **({"Authorization": f"Bearer {key}"} if key else {}),
-                **(
-                    {
-                        "X-OpenWebUI-User-Name": quote(user.name, safe=" "),
-                        "X-OpenWebUI-User-Id": user.id,
-                        "X-OpenWebUI-User-Email": user.email,
-                        "X-OpenWebUI-User-Role": user.role,
-                    }
-                    if ENABLE_FORWARD_USER_INFO_HEADERS and user
-                    else {}
-                ),
-            },
+            headers=headers,
             data=form_data.model_dump_json(exclude_none=True).encode(),
         )
         r.raise_for_status()

@@ -1147,23 +1106,18 @@ async def embeddings(
         form_data.model = form_data.model.replace(f"{prefix_id}.", "")
 
     try:
+        headers = {
+            "Content-Type": "application/json",
+            **({"Authorization": f"Bearer {key}"} if key else {}),
+        }
+
+        if ENABLE_FORWARD_USER_INFO_HEADERS and user:
+            headers = include_user_info_headers(headers, user)
+
         r = requests.request(
             method="POST",
             url=f"{url}/api/embeddings",
-            headers={
-                "Content-Type": "application/json",
-                **({"Authorization": f"Bearer {key}"} if key else {}),
-                **(
-                    {
-                        "X-OpenWebUI-User-Name": quote(user.name, safe=" "),
-                        "X-OpenWebUI-User-Id": user.id,
-                        "X-OpenWebUI-User-Email": user.email,
-                        "X-OpenWebUI-User-Role": user.role,
-                    }
-                    if ENABLE_FORWARD_USER_INFO_HEADERS and user
-                    else {}
-                ),
-            },
+            headers=headers,
             data=form_data.model_dump_json(exclude_none=True).encode(),
         )
         r.raise_for_status()
@@ -7,7 +7,6 @@ from typing import Optional
 import aiohttp
 from aiocache import cached
 import requests
-from urllib.parse import quote
 
 from azure.identity import DefaultAzureCredential, get_bearer_token_provider

@@ -45,10 +44,12 @@ from open_webui.utils.payload import (
 )
 from open_webui.utils.misc import (
     convert_logit_bias_input_to_json,
+    stream_chunks_handler,
 )
 
 from open_webui.utils.auth import get_admin_user, get_verified_user
 from open_webui.utils.access_control import has_access
+from open_webui.utils.headers import include_user_info_headers
 
 
 log = logging.getLogger(__name__)
@@ -66,21 +67,16 @@ async def send_get_request(url, key=None, user: UserModel = None):
     timeout = aiohttp.ClientTimeout(total=AIOHTTP_CLIENT_TIMEOUT_MODEL_LIST)
     try:
         async with aiohttp.ClientSession(timeout=timeout, trust_env=True) as session:
+            headers = {
+                **({"Authorization": f"Bearer {key}"} if key else {}),
+            }
+
+            if ENABLE_FORWARD_USER_INFO_HEADERS and user:
+                headers = include_user_info_headers(headers, user)
+
             async with session.get(
                 url,
-                headers={
-                    **({"Authorization": f"Bearer {key}"} if key else {}),
-                    **(
-                        {
-                            "X-OpenWebUI-User-Name": quote(user.name, safe=" "),
-                            "X-OpenWebUI-User-Id": user.id,
-                            "X-OpenWebUI-User-Email": user.email,
-                            "X-OpenWebUI-User-Role": user.role,
-                        }
-                        if ENABLE_FORWARD_USER_INFO_HEADERS and user
-                        else {}
-                    ),
-                },
+                headers=headers,
                 ssl=AIOHTTP_CLIENT_SESSION_SSL,
             ) as response:
                 return await response.json()
@@ -140,23 +136,13 @@ async def get_headers_and_cookies(
             if "openrouter.ai" in url
             else {}
         ),
-        **(
-            {
-                "X-OpenWebUI-User-Name": quote(user.name, safe=" "),
-                "X-OpenWebUI-User-Id": user.id,
-                "X-OpenWebUI-User-Email": user.email,
-                "X-OpenWebUI-User-Role": user.role,
-                **(
-                    {"X-OpenWebUI-Chat-Id": metadata.get("chat_id")}
-                    if metadata and metadata.get("chat_id")
-                    else {}
-                ),
-            }
-            if ENABLE_FORWARD_USER_INFO_HEADERS
-            else {}
-        ),
     }
 
+    if ENABLE_FORWARD_USER_INFO_HEADERS and user:
+        headers = include_user_info_headers(headers, user)
+
+    if metadata and metadata.get("chat_id"):
+        headers["X-OpenWebUI-Chat-Id"] = metadata.get("chat_id")
+
     token = None
     auth_type = config.get("auth_type")

@@ -190,6 +176,9 @@ async def get_headers_and_cookies(
     if token:
         headers["Authorization"] = f"Bearer {token}"
 
+    if config.get("headers") and isinstance(config.get("headers"), dict):
+        headers = {**headers, **config.get("headers")}
+
     return headers, cookies
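In the second get_headers_and_cookies hunk, connection-level headers from the config are merged after the computed ones, so on a key collision (Authorization included) the config value wins. The precedence in isolation:

headers = {"Authorization": "Bearer computed", "Accept": "application/json"}
config = {"headers": {"Authorization": "Bearer override", "X-Custom": "1"}}

# later spread wins: config-supplied headers override computed ones
if config.get("headers") and isinstance(config.get("headers"), dict):
    headers = {**headers, **config.get("headers")}

print(headers["Authorization"])  # Bearer override
print(headers["X-Custom"])       # 1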
@@ -498,30 +487,9 @@ async def get_all_models(request: Request, user: UserModel) -> dict[str, list]:
                 return response
         return None
 
-    def merge_models_lists(model_lists):
-        log.debug(f"merge_models_lists {model_lists}")
-        merged_list = []
-
-        for idx, models in enumerate(model_lists):
-            if models is not None and "error" not in models:
-                merged_list.extend(
-                    [
-                        {
-                            **model,
-                            "name": model.get("name", model["id"]),
-                            "owned_by": "openai",
-                            "openai": model,
-                            "connection_type": model.get("connection_type", "external"),
-                            "urlIdx": idx,
-                        }
-                        for model in models
-                        if (model.get("id") or model.get("name"))
-                        and (
-                            "api.openai.com"
-                            not in request.app.state.config.OPENAI_API_BASE_URLS[idx]
-                            or not any(
-                                name in model["id"]
+    def is_supported_openai_models(model_id):
+        if any(
+            name in model_id
             for name in [
                 "babbage",
                 "dall-e",

@@ -530,18 +498,44 @@ async def get_all_models(request: Request, user: UserModel) -> dict[str, list]:
                 "tts",
                 "whisper",
             ]
-                            )
-                        )
-                    ]
-                )
+        ):
+            return False
+        return True
 
-        return merged_list
+    def get_merged_models(model_lists):
+        log.debug(f"merge_models_lists {model_lists}")
+        models = {}
 
-    models = {"data": merge_models_lists(map(extract_data, responses))}
+        for idx, model_list in enumerate(model_lists):
+            if model_list is not None and "error" not in model_list:
+                for model in model_list:
+                    model_id = model.get("id") or model.get("name")
+
+                    if (
+                        "api.openai.com"
+                        in request.app.state.config.OPENAI_API_BASE_URLS[idx]
+                        and not is_supported_openai_models(model_id)
+                    ):
+                        # Skip unwanted OpenAI models
+                        continue
+
+                    if model_id and model_id not in models:
+                        models[model_id] = {
+                            **model,
+                            "name": model.get("name", model_id),
+                            "owned_by": "openai",
+                            "openai": model,
+                            "connection_type": model.get("connection_type", "external"),
+                            "urlIdx": idx,
+                        }
+
+        return models
+
+    models = get_merged_models(map(extract_data, responses))
     log.debug(f"models: {models}")
 
-    request.app.state.OPENAI_MODELS = {model["id"]: model for model in models["data"]}
-    return models
+    request.app.state.OPENAI_MODELS = models
+    return {"data": list(models.values())}
 
 
 @router.get("/models")
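Besides splitting the unsupported-model filter out into is_supported_openai_models, the rewrite changes semantics: models are now keyed by id, so a model exposed by several connections is kept once, from the lowest urlIdx, where the old list-based merge kept every duplicate. The dedup in miniature:

# Reduced sketch of get_merged_models' dedup: first occurrence wins.
model_lists = [
    [{"id": "gpt-4o"}, {"id": "custom-model"}],  # connection 0
    [{"id": "gpt-4o"}],                          # connection 1 (duplicate)
]

models = {}
for idx, model_list in enumerate(model_lists):
    for model in model_list:
        model_id = model.get("id") or model.get("name")
        if model_id and model_id not in models:
            models[model_id] = {**model, "urlIdx": idx}

print([(m["id"], m["urlIdx"]) for m in models.values()])
# [('gpt-4o', 0), ('custom-model', 0)]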
@@ -754,6 +748,7 @@ def get_azure_allowed_params(api_version: str) -> set[str]:
         "response_format",
         "seed",
         "max_completion_tokens",
+        "reasoning_effort",
     }
 
     try:
@@ -944,7 +939,7 @@ async def generate_chat_completion(
         if "text/event-stream" in r.headers.get("Content-Type", ""):
             streaming = True
             return StreamingResponse(
-                r.content,
+                stream_chunks_handler(r.content),
                 status_code=r.status,
                 headers=dict(r.headers),
                 background=BackgroundTask(
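The streamed body is now routed through stream_chunks_handler, imported earlier from open_webui.utils.misc; its implementation is not part of this compare view. Conceptually it wraps the upstream byte stream so chunks can be buffered or re-framed in one place before reaching the client. A hypothetical pass-through sketch, not the real implementation:

# Hypothetical sketch only; the real stream_chunks_handler lives in
# open_webui/utils/misc.py and is not shown in this diff.
async def stream_chunks_handler(stream):
    async for chunk in stream:  # aiohttp's StreamReader is async-iterable
        yield chunk             # a real handler may re-frame SSE events here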
@@ -48,8 +48,15 @@ async def get_prompt_list(user=Depends(get_verified_user)):
 async def create_new_prompt(
     request: Request, form_data: PromptForm, user=Depends(get_verified_user)
 ):
-    if user.role != "admin" and not has_permission(
+    if user.role != "admin" and not (
+        has_permission(
         user.id, "workspace.prompts", request.app.state.config.USER_PERMISSIONS
+        )
+        or has_permission(
+            user.id,
+            "workspace.prompts_import",
+            request.app.state.config.USER_PERMISSIONS,
+        )
     ):
         raise HTTPException(
             status_code=status.HTTP_401_UNAUTHORIZED,
@@ -5,6 +5,7 @@ import os
 import shutil
 import asyncio
+import re
 import uuid
 from datetime import datetime
 from pathlib import Path

@@ -31,7 +32,7 @@ from langchain.text_splitter import RecursiveCharacterTextSplitter, TokenTextSpl
 from langchain_text_splitters import MarkdownHeaderTextSplitter
 from langchain_core.documents import Document
 
-from open_webui.models.files import FileModel, Files
+from open_webui.models.files import FileModel, FileUpdateForm, Files
 from open_webui.models.knowledge import Knowledges
 from open_webui.storage.provider import Storage

@@ -63,6 +64,7 @@ from open_webui.retrieval.web.serply import search_serply
 from open_webui.retrieval.web.serpstack import search_serpstack
 from open_webui.retrieval.web.tavily import search_tavily
 from open_webui.retrieval.web.bing import search_bing
+from open_webui.retrieval.web.azure import search_azure
 from open_webui.retrieval.web.exa import search_exa
 from open_webui.retrieval.web.perplexity import search_perplexity
 from open_webui.retrieval.web.sougou import search_sougou

@@ -70,6 +72,7 @@ from open_webui.retrieval.web.firecrawl import search_firecrawl
 from open_webui.retrieval.web.external import search_external
 
 from open_webui.retrieval.utils import (
+    get_content_from_url,
     get_embedding_function,
     get_reranking_function,
     get_model_path,
@@ -120,7 +123,7 @@ log.setLevel(SRC_LOG_LEVELS["RAG"])
 def get_ef(
     engine: str,
     embedding_model: str,
-    auto_update: bool = False,
+    auto_update: bool = RAG_EMBEDDING_MODEL_AUTO_UPDATE,
 ):
     ef = None
     if embedding_model and engine == "":

@@ -145,7 +148,7 @@ def get_rf(
     reranking_model: Optional[str] = None,
     external_reranker_url: str = "",
     external_reranker_api_key: str = "",
-    auto_update: bool = False,
+    auto_update: bool = RAG_RERANKING_MODEL_AUTO_UPDATE,
 ):
     rf = None
     if reranking_model:
@@ -189,6 +192,26 @@ def get_rf(
                 log.error(f"CrossEncoder: {e}")
                 raise Exception(ERROR_MESSAGES.DEFAULT("CrossEncoder error"))
 
+            # Safely adjust pad_token_id if missing as some models do not have this in config
+            try:
+                model_cfg = getattr(rf, "model", None)
+                if model_cfg and hasattr(model_cfg, "config"):
+                    cfg = model_cfg.config
+                    if getattr(cfg, "pad_token_id", None) is None:
+                        # Fallback to eos_token_id when available
+                        eos = getattr(cfg, "eos_token_id", None)
+                        if eos is not None:
+                            cfg.pad_token_id = eos
+                            log.debug(
+                                f"Missing pad_token_id detected; set to eos_token_id={eos}"
+                            )
+                        else:
+                            log.warning(
+                                "Neither pad_token_id nor eos_token_id present in model config"
+                            )
+            except Exception as e2:
+                log.warning(f"Failed to adjust pad_token_id on CrossEncoder: {e2}")
+
     return rf
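The pad_token_id fallback matters because batched CrossEncoder inference pads shorter sequences, and some checkpoints ship without a pad token configured. The same defensive logic against a bare Hugging Face config object, values illustrative:

# Minimal sketch of the fallback above, assuming a Hugging Face-style
# PretrainedConfig; the token ids here are made up for the example.
from transformers import PretrainedConfig

cfg = PretrainedConfig()
cfg.eos_token_id = 2  # assume the model defines an EOS token

if getattr(cfg, "pad_token_id", None) is None:
    eos = getattr(cfg, "eos_token_id", None)
    if eos is not None:
        cfg.pad_token_id = eos  # padding reuses the EOS id for batching

print(cfg.pad_token_id)  # -> 2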
@@ -218,13 +241,14 @@ class SearchForm(BaseModel):
 async def get_status(request: Request):
     return {
         "status": True,
-        "chunk_size": request.app.state.config.CHUNK_SIZE,
-        "chunk_overlap": request.app.state.config.CHUNK_OVERLAP,
-        "template": request.app.state.config.RAG_TEMPLATE,
-        "embedding_engine": request.app.state.config.RAG_EMBEDDING_ENGINE,
-        "embedding_model": request.app.state.config.RAG_EMBEDDING_MODEL,
-        "reranking_model": request.app.state.config.RAG_RERANKING_MODEL,
-        "embedding_batch_size": request.app.state.config.RAG_EMBEDDING_BATCH_SIZE,
+        "CHUNK_SIZE": request.app.state.config.CHUNK_SIZE,
+        "CHUNK_OVERLAP": request.app.state.config.CHUNK_OVERLAP,
+        "RAG_TEMPLATE": request.app.state.config.RAG_TEMPLATE,
+        "RAG_EMBEDDING_ENGINE": request.app.state.config.RAG_EMBEDDING_ENGINE,
+        "RAG_EMBEDDING_MODEL": request.app.state.config.RAG_EMBEDDING_MODEL,
+        "RAG_RERANKING_MODEL": request.app.state.config.RAG_RERANKING_MODEL,
+        "RAG_EMBEDDING_BATCH_SIZE": request.app.state.config.RAG_EMBEDDING_BATCH_SIZE,
+        "ENABLE_ASYNC_EMBEDDING": request.app.state.config.ENABLE_ASYNC_EMBEDDING,
     }

@@ -232,9 +256,10 @@ async def get_status(request: Request):
 async def get_embedding_config(request: Request, user=Depends(get_admin_user)):
     return {
         "status": True,
-        "embedding_engine": request.app.state.config.RAG_EMBEDDING_ENGINE,
-        "embedding_model": request.app.state.config.RAG_EMBEDDING_MODEL,
-        "embedding_batch_size": request.app.state.config.RAG_EMBEDDING_BATCH_SIZE,
+        "RAG_EMBEDDING_ENGINE": request.app.state.config.RAG_EMBEDDING_ENGINE,
+        "RAG_EMBEDDING_MODEL": request.app.state.config.RAG_EMBEDDING_MODEL,
+        "RAG_EMBEDDING_BATCH_SIZE": request.app.state.config.RAG_EMBEDDING_BATCH_SIZE,
+        "ENABLE_ASYNC_EMBEDDING": request.app.state.config.ENABLE_ASYNC_EMBEDDING,
         "openai_config": {
             "url": request.app.state.config.RAG_OPENAI_API_BASE_URL,
             "key": request.app.state.config.RAG_OPENAI_API_KEY,
@@ -271,18 +296,13 @@ class EmbeddingModelUpdateForm(BaseModel):
     openai_config: Optional[OpenAIConfigForm] = None
     ollama_config: Optional[OllamaConfigForm] = None
     azure_openai_config: Optional[AzureOpenAIConfigForm] = None
-    embedding_engine: str
-    embedding_model: str
-    embedding_batch_size: Optional[int] = 1
+    RAG_EMBEDDING_ENGINE: str
+    RAG_EMBEDDING_MODEL: str
+    RAG_EMBEDDING_BATCH_SIZE: Optional[int] = 1
+    ENABLE_ASYNC_EMBEDDING: Optional[bool] = True
 
 
-@router.post("/embedding/update")
-async def update_embedding_config(
-    request: Request, form_data: EmbeddingModelUpdateForm, user=Depends(get_admin_user)
-):
-    log.info(
-        f"Updating embedding model: {request.app.state.config.RAG_EMBEDDING_MODEL} to {form_data.embedding_model}"
-    )
+def unload_embedding_model(request: Request):
     if request.app.state.config.RAG_EMBEDDING_ENGINE == "":
         # unloads current internal embedding model and clears VRAM cache
         request.app.state.ef = None

@@ -295,9 +315,25 @@ async def update_embedding_config(
 
         if torch.cuda.is_available():
             torch.cuda.empty_cache()
 
 
+@router.post("/embedding/update")
+async def update_embedding_config(
+    request: Request, form_data: EmbeddingModelUpdateForm, user=Depends(get_admin_user)
+):
+    log.info(
+        f"Updating embedding model: {request.app.state.config.RAG_EMBEDDING_MODEL} to {form_data.RAG_EMBEDDING_MODEL}"
+    )
+    unload_embedding_model(request)
     try:
-        request.app.state.config.RAG_EMBEDDING_ENGINE = form_data.embedding_engine
-        request.app.state.config.RAG_EMBEDDING_MODEL = form_data.embedding_model
+        request.app.state.config.RAG_EMBEDDING_ENGINE = form_data.RAG_EMBEDDING_ENGINE
+        request.app.state.config.RAG_EMBEDDING_MODEL = form_data.RAG_EMBEDDING_MODEL
+        request.app.state.config.RAG_EMBEDDING_BATCH_SIZE = (
+            form_data.RAG_EMBEDDING_BATCH_SIZE
+        )
+        request.app.state.config.ENABLE_ASYNC_EMBEDDING = (
+            form_data.ENABLE_ASYNC_EMBEDDING
+        )
 
         if request.app.state.config.RAG_EMBEDDING_ENGINE in [
             "ollama",

@@ -331,10 +367,6 @@ async def update_embedding_config(
                 form_data.azure_openai_config.version
             )
 
-        request.app.state.config.RAG_EMBEDDING_BATCH_SIZE = (
-            form_data.embedding_batch_size
-        )
-
         request.app.state.ef = get_ef(
             request.app.state.config.RAG_EMBEDDING_ENGINE,
             request.app.state.config.RAG_EMBEDDING_MODEL,

@@ -368,13 +400,15 @@ async def update_embedding_config(
                 if request.app.state.config.RAG_EMBEDDING_ENGINE == "azure_openai"
                 else None
             ),
+            enable_async=request.app.state.config.ENABLE_ASYNC_EMBEDDING,
         )
 
         return {
             "status": True,
-            "embedding_engine": request.app.state.config.RAG_EMBEDDING_ENGINE,
-            "embedding_model": request.app.state.config.RAG_EMBEDDING_MODEL,
-            "embedding_batch_size": request.app.state.config.RAG_EMBEDDING_BATCH_SIZE,
+            "RAG_EMBEDDING_ENGINE": request.app.state.config.RAG_EMBEDDING_ENGINE,
+            "RAG_EMBEDDING_MODEL": request.app.state.config.RAG_EMBEDDING_MODEL,
+            "RAG_EMBEDDING_BATCH_SIZE": request.app.state.config.RAG_EMBEDDING_BATCH_SIZE,
+            "ENABLE_ASYNC_EMBEDDING": request.app.state.config.ENABLE_ASYNC_EMBEDDING,
             "openai_config": {
                 "url": request.app.state.config.RAG_OPENAI_API_BASE_URL,
                 "key": request.app.state.config.RAG_OPENAI_API_KEY,
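Because the form fields were renamed, clients of the /embedding/update endpoint must now send config-style uppercase keys; RAG_EMBEDDING_ENGINE and RAG_EMBEDDING_MODEL are required, so an old payload using embedding_engine/embedding_model would fail validation. An example body under the new schema, with illustrative values:

payload = {
    "RAG_EMBEDDING_ENGINE": "ollama",
    "RAG_EMBEDDING_MODEL": "nomic-embed-text",  # illustrative model name
    "RAG_EMBEDDING_BATCH_SIZE": 8,
    "ENABLE_ASYNC_EMBEDDING": True,
}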
@@ -408,6 +442,7 @@ async def get_rag_config(request: Request, user=Depends(get_admin_user)):
         "RAG_FULL_CONTEXT": request.app.state.config.RAG_FULL_CONTEXT,
         # Hybrid search settings
         "ENABLE_RAG_HYBRID_SEARCH": request.app.state.config.ENABLE_RAG_HYBRID_SEARCH,
+        "ENABLE_RAG_HYBRID_SEARCH_ENRICHED_TEXTS": request.app.state.config.ENABLE_RAG_HYBRID_SEARCH_ENRICHED_TEXTS,
         "TOP_K_RERANKER": request.app.state.config.TOP_K_RERANKER,
         "RELEVANCE_THRESHOLD": request.app.state.config.RELEVANCE_THRESHOLD,
         "HYBRID_BM25_WEIGHT": request.app.state.config.HYBRID_BM25_WEIGHT,

@@ -429,20 +464,18 @@ async def get_rag_config(request: Request, user=Depends(get_admin_user)):
         "EXTERNAL_DOCUMENT_LOADER_API_KEY": request.app.state.config.EXTERNAL_DOCUMENT_LOADER_API_KEY,
         "TIKA_SERVER_URL": request.app.state.config.TIKA_SERVER_URL,
         "DOCLING_SERVER_URL": request.app.state.config.DOCLING_SERVER_URL,
-        "DOCLING_DO_OCR": request.app.state.config.DOCLING_DO_OCR,
-        "DOCLING_FORCE_OCR": request.app.state.config.DOCLING_FORCE_OCR,
-        "DOCLING_OCR_ENGINE": request.app.state.config.DOCLING_OCR_ENGINE,
-        "DOCLING_OCR_LANG": request.app.state.config.DOCLING_OCR_LANG,
-        "DOCLING_PDF_BACKEND": request.app.state.config.DOCLING_PDF_BACKEND,
-        "DOCLING_TABLE_MODE": request.app.state.config.DOCLING_TABLE_MODE,
-        "DOCLING_PIPELINE": request.app.state.config.DOCLING_PIPELINE,
-        "DOCLING_DO_PICTURE_DESCRIPTION": request.app.state.config.DOCLING_DO_PICTURE_DESCRIPTION,
-        "DOCLING_PICTURE_DESCRIPTION_MODE": request.app.state.config.DOCLING_PICTURE_DESCRIPTION_MODE,
-        "DOCLING_PICTURE_DESCRIPTION_LOCAL": request.app.state.config.DOCLING_PICTURE_DESCRIPTION_LOCAL,
-        "DOCLING_PICTURE_DESCRIPTION_API": request.app.state.config.DOCLING_PICTURE_DESCRIPTION_API,
+        "DOCLING_API_KEY": request.app.state.config.DOCLING_API_KEY,
+        "DOCLING_PARAMS": request.app.state.config.DOCLING_PARAMS,
         "DOCUMENT_INTELLIGENCE_ENDPOINT": request.app.state.config.DOCUMENT_INTELLIGENCE_ENDPOINT,
         "DOCUMENT_INTELLIGENCE_KEY": request.app.state.config.DOCUMENT_INTELLIGENCE_KEY,
+        "DOCUMENT_INTELLIGENCE_MODEL": request.app.state.config.DOCUMENT_INTELLIGENCE_MODEL,
+        "MISTRAL_OCR_API_BASE_URL": request.app.state.config.MISTRAL_OCR_API_BASE_URL,
         "MISTRAL_OCR_API_KEY": request.app.state.config.MISTRAL_OCR_API_KEY,
+        # MinerU settings
+        "MINERU_API_MODE": request.app.state.config.MINERU_API_MODE,
+        "MINERU_API_URL": request.app.state.config.MINERU_API_URL,
+        "MINERU_API_KEY": request.app.state.config.MINERU_API_KEY,
+        "MINERU_PARAMS": request.app.state.config.MINERU_PARAMS,
         # Reranking settings
         "RAG_RERANKING_MODEL": request.app.state.config.RAG_RERANKING_MODEL,
         "RAG_RERANKING_ENGINE": request.app.state.config.RAG_RERANKING_ENGINE,

@@ -499,6 +532,7 @@ async def get_rag_config(request: Request, user=Depends(get_admin_user)):
         "PERPLEXITY_API_KEY": request.app.state.config.PERPLEXITY_API_KEY,
         "PERPLEXITY_MODEL": request.app.state.config.PERPLEXITY_MODEL,
         "PERPLEXITY_SEARCH_CONTEXT_USAGE": request.app.state.config.PERPLEXITY_SEARCH_CONTEXT_USAGE,
+        "PERPLEXITY_SEARCH_API_URL": request.app.state.config.PERPLEXITY_SEARCH_API_URL,
         "SOUGOU_API_SID": request.app.state.config.SOUGOU_API_SID,
         "SOUGOU_API_SK": request.app.state.config.SOUGOU_API_SK,
         "WEB_LOADER_ENGINE": request.app.state.config.WEB_LOADER_ENGINE,

@@ -556,6 +590,7 @@ class WebConfig(BaseModel):
     PERPLEXITY_API_KEY: Optional[str] = None
     PERPLEXITY_MODEL: Optional[str] = None
     PERPLEXITY_SEARCH_CONTEXT_USAGE: Optional[str] = None
+    PERPLEXITY_SEARCH_API_URL: Optional[str] = None
     SOUGOU_API_SID: Optional[str] = None
     SOUGOU_API_SK: Optional[str] = None
     WEB_LOADER_ENGINE: Optional[str] = None

@@ -583,6 +618,7 @@ class ConfigForm(BaseModel):
 
     # Hybrid search settings
     ENABLE_RAG_HYBRID_SEARCH: Optional[bool] = None
+    ENABLE_RAG_HYBRID_SEARCH_ENRICHED_TEXTS: Optional[bool] = None
     TOP_K_RERANKER: Optional[int] = None
     RELEVANCE_THRESHOLD: Optional[float] = None
     HYBRID_BM25_WEIGHT: Optional[float] = None

@@ -590,6 +626,7 @@ class ConfigForm(BaseModel):
     # Content extraction settings
     CONTENT_EXTRACTION_ENGINE: Optional[str] = None
     PDF_EXTRACT_IMAGES: Optional[bool] = None
+
     DATALAB_MARKER_API_KEY: Optional[str] = None
     DATALAB_MARKER_API_BASE_URL: Optional[str] = None
     DATALAB_MARKER_ADDITIONAL_CONFIG: Optional[str] = None

@@ -601,26 +638,26 @@ class ConfigForm(BaseModel):
     DATALAB_MARKER_FORMAT_LINES: Optional[bool] = None
     DATALAB_MARKER_USE_LLM: Optional[bool] = None
     DATALAB_MARKER_OUTPUT_FORMAT: Optional[str] = None
+
     EXTERNAL_DOCUMENT_LOADER_URL: Optional[str] = None
     EXTERNAL_DOCUMENT_LOADER_API_KEY: Optional[str] = None
+
     TIKA_SERVER_URL: Optional[str] = None
     DOCLING_SERVER_URL: Optional[str] = None
-    DOCLING_DO_OCR: Optional[bool] = None
-    DOCLING_FORCE_OCR: Optional[bool] = None
-    DOCLING_OCR_ENGINE: Optional[str] = None
-    DOCLING_OCR_LANG: Optional[str] = None
-    DOCLING_PDF_BACKEND: Optional[str] = None
-    DOCLING_TABLE_MODE: Optional[str] = None
-    DOCLING_PIPELINE: Optional[str] = None
-    DOCLING_DO_PICTURE_DESCRIPTION: Optional[bool] = None
-    DOCLING_PICTURE_DESCRIPTION_MODE: Optional[str] = None
-    DOCLING_PICTURE_DESCRIPTION_LOCAL: Optional[dict] = None
-    DOCLING_PICTURE_DESCRIPTION_API: Optional[dict] = None
+    DOCLING_API_KEY: Optional[str] = None
+    DOCLING_PARAMS: Optional[dict] = None
     DOCUMENT_INTELLIGENCE_ENDPOINT: Optional[str] = None
     DOCUMENT_INTELLIGENCE_KEY: Optional[str] = None
+    DOCUMENT_INTELLIGENCE_MODEL: Optional[str] = None
+    MISTRAL_OCR_API_BASE_URL: Optional[str] = None
     MISTRAL_OCR_API_KEY: Optional[str] = None
+
+    # MinerU settings
+    MINERU_API_MODE: Optional[str] = None
+    MINERU_API_URL: Optional[str] = None
+    MINERU_API_KEY: Optional[str] = None
+    MINERU_PARAMS: Optional[dict] = None
 
     # Reranking settings
     RAG_RERANKING_MODEL: Optional[str] = None
     RAG_RERANKING_ENGINE: Optional[str] = None
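The eleven scalar DOCLING_* options collapse into a single DOCLING_PARAMS dict that is handed to the loader verbatim (see the process_file hunk further down), so adding a Docling option no longer requires touching the config plumbing. The keys mirror the removed fields; the values below are illustrative only, not taken from this diff:

DOCLING_PARAMS = {
    "do_ocr": True,
    "force_ocr": False,
    "ocr_engine": "easyocr",      # illustrative
    "ocr_lang": "en",             # illustrative
    "pdf_backend": "dlparse_v2",  # illustrative
    "table_mode": "fast",         # illustrative
    "pipeline": "standard",       # illustrative
    "do_picture_description": False,
    "picture_description_mode": "",
    "picture_description_local": {},
    "picture_description_api": {},
}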
@@ -679,6 +716,11 @@ async def update_rag_config(
         if form_data.ENABLE_RAG_HYBRID_SEARCH is not None
         else request.app.state.config.ENABLE_RAG_HYBRID_SEARCH
     )
+    request.app.state.config.ENABLE_RAG_HYBRID_SEARCH_ENRICHED_TEXTS = (
+        form_data.ENABLE_RAG_HYBRID_SEARCH_ENRICHED_TEXTS
+        if form_data.ENABLE_RAG_HYBRID_SEARCH_ENRICHED_TEXTS is not None
+        else request.app.state.config.ENABLE_RAG_HYBRID_SEARCH_ENRICHED_TEXTS
+    )
 
     request.app.state.config.TOP_K_RERANKER = (
         form_data.TOP_K_RERANKER

@@ -782,63 +824,16 @@ async def update_rag_config(
         if form_data.DOCLING_SERVER_URL is not None
         else request.app.state.config.DOCLING_SERVER_URL
     )
-    request.app.state.config.DOCLING_DO_OCR = (
-        form_data.DOCLING_DO_OCR
-        if form_data.DOCLING_DO_OCR is not None
-        else request.app.state.config.DOCLING_DO_OCR
-    )
-    request.app.state.config.DOCLING_FORCE_OCR = (
-        form_data.DOCLING_FORCE_OCR
-        if form_data.DOCLING_FORCE_OCR is not None
-        else request.app.state.config.DOCLING_FORCE_OCR
-    )
-    request.app.state.config.DOCLING_OCR_ENGINE = (
-        form_data.DOCLING_OCR_ENGINE
-        if form_data.DOCLING_OCR_ENGINE is not None
-        else request.app.state.config.DOCLING_OCR_ENGINE
-    )
-    request.app.state.config.DOCLING_OCR_LANG = (
-        form_data.DOCLING_OCR_LANG
-        if form_data.DOCLING_OCR_LANG is not None
-        else request.app.state.config.DOCLING_OCR_LANG
-    )
-    request.app.state.config.DOCLING_PDF_BACKEND = (
-        form_data.DOCLING_PDF_BACKEND
-        if form_data.DOCLING_PDF_BACKEND is not None
-        else request.app.state.config.DOCLING_PDF_BACKEND
-    )
-    request.app.state.config.DOCLING_TABLE_MODE = (
-        form_data.DOCLING_TABLE_MODE
-        if form_data.DOCLING_TABLE_MODE is not None
-        else request.app.state.config.DOCLING_TABLE_MODE
-    )
-    request.app.state.config.DOCLING_PIPELINE = (
-        form_data.DOCLING_PIPELINE
-        if form_data.DOCLING_PIPELINE is not None
-        else request.app.state.config.DOCLING_PIPELINE
-    )
-    request.app.state.config.DOCLING_DO_PICTURE_DESCRIPTION = (
-        form_data.DOCLING_DO_PICTURE_DESCRIPTION
-        if form_data.DOCLING_DO_PICTURE_DESCRIPTION is not None
-        else request.app.state.config.DOCLING_DO_PICTURE_DESCRIPTION
-    )
-
-    request.app.state.config.DOCLING_PICTURE_DESCRIPTION_MODE = (
-        form_data.DOCLING_PICTURE_DESCRIPTION_MODE
-        if form_data.DOCLING_PICTURE_DESCRIPTION_MODE is not None
-        else request.app.state.config.DOCLING_PICTURE_DESCRIPTION_MODE
-    )
-    request.app.state.config.DOCLING_PICTURE_DESCRIPTION_LOCAL = (
-        form_data.DOCLING_PICTURE_DESCRIPTION_LOCAL
-        if form_data.DOCLING_PICTURE_DESCRIPTION_LOCAL is not None
-        else request.app.state.config.DOCLING_PICTURE_DESCRIPTION_LOCAL
-    )
-    request.app.state.config.DOCLING_PICTURE_DESCRIPTION_API = (
-        form_data.DOCLING_PICTURE_DESCRIPTION_API
-        if form_data.DOCLING_PICTURE_DESCRIPTION_API is not None
-        else request.app.state.config.DOCLING_PICTURE_DESCRIPTION_API
-    )
+    request.app.state.config.DOCLING_API_KEY = (
+        form_data.DOCLING_API_KEY
+        if form_data.DOCLING_API_KEY is not None
+        else request.app.state.config.DOCLING_API_KEY
+    )
+    request.app.state.config.DOCLING_PARAMS = (
+        form_data.DOCLING_PARAMS
+        if form_data.DOCLING_PARAMS is not None
+        else request.app.state.config.DOCLING_PARAMS
+    )
 
     request.app.state.config.DOCUMENT_INTELLIGENCE_ENDPOINT = (
         form_data.DOCUMENT_INTELLIGENCE_ENDPOINT
         if form_data.DOCUMENT_INTELLIGENCE_ENDPOINT is not None

@@ -849,12 +844,45 @@ async def update_rag_config(
         if form_data.DOCUMENT_INTELLIGENCE_KEY is not None
         else request.app.state.config.DOCUMENT_INTELLIGENCE_KEY
     )
+    request.app.state.config.DOCUMENT_INTELLIGENCE_MODEL = (
+        form_data.DOCUMENT_INTELLIGENCE_MODEL
+        if form_data.DOCUMENT_INTELLIGENCE_MODEL is not None
+        else request.app.state.config.DOCUMENT_INTELLIGENCE_MODEL
+    )
+
+    request.app.state.config.MISTRAL_OCR_API_BASE_URL = (
+        form_data.MISTRAL_OCR_API_BASE_URL
+        if form_data.MISTRAL_OCR_API_BASE_URL is not None
+        else request.app.state.config.MISTRAL_OCR_API_BASE_URL
+    )
     request.app.state.config.MISTRAL_OCR_API_KEY = (
         form_data.MISTRAL_OCR_API_KEY
         if form_data.MISTRAL_OCR_API_KEY is not None
         else request.app.state.config.MISTRAL_OCR_API_KEY
     )
+
+    # MinerU settings
+    request.app.state.config.MINERU_API_MODE = (
+        form_data.MINERU_API_MODE
+        if form_data.MINERU_API_MODE is not None
+        else request.app.state.config.MINERU_API_MODE
+    )
+    request.app.state.config.MINERU_API_URL = (
+        form_data.MINERU_API_URL
+        if form_data.MINERU_API_URL is not None
+        else request.app.state.config.MINERU_API_URL
+    )
+    request.app.state.config.MINERU_API_KEY = (
+        form_data.MINERU_API_KEY
+        if form_data.MINERU_API_KEY is not None
+        else request.app.state.config.MINERU_API_KEY
+    )
+    request.app.state.config.MINERU_PARAMS = (
+        form_data.MINERU_PARAMS
+        if form_data.MINERU_PARAMS is not None
+        else request.app.state.config.MINERU_PARAMS
+    )
 
     # Reranking settings
     if request.app.state.config.RAG_RERANKING_ENGINE == "":
         # Unloading the internal reranker and clear VRAM memory
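All of these assignments repeat one pattern: None is the "not provided" sentinel, so an omitted field keeps the current value while explicit False, 0 or "" are applied. Reduced to a helper:

def updated(new, current):
    # mirrors the repeated `x if x is not None else config.x` pattern above
    return new if new is not None else current

assert updated(None, "old") == "old"  # omitted field -> keep current value
assert updated(False, True) is False  # explicit False is respected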
@@ -906,7 +934,6 @@ async def update_rag_config(
             request.app.state.config.RAG_RERANKING_MODEL,
             request.app.state.config.RAG_EXTERNAL_RERANKER_URL,
             request.app.state.config.RAG_EXTERNAL_RERANKER_API_KEY,
-            True,
         )
 
         request.app.state.RERANKING_FUNCTION = get_reranking_function(

@@ -1036,6 +1063,9 @@ async def update_rag_config(
         request.app.state.config.PERPLEXITY_SEARCH_CONTEXT_USAGE = (
             form_data.web.PERPLEXITY_SEARCH_CONTEXT_USAGE
         )
+        request.app.state.config.PERPLEXITY_SEARCH_API_URL = (
+            form_data.web.PERPLEXITY_SEARCH_API_URL
+        )
         request.app.state.config.SOUGOU_API_SID = form_data.web.SOUGOU_API_SID
         request.app.state.config.SOUGOU_API_SK = form_data.web.SOUGOU_API_SK
@@ -1104,20 +1134,18 @@ async def update_rag_config(
         "EXTERNAL_DOCUMENT_LOADER_API_KEY": request.app.state.config.EXTERNAL_DOCUMENT_LOADER_API_KEY,
         "TIKA_SERVER_URL": request.app.state.config.TIKA_SERVER_URL,
         "DOCLING_SERVER_URL": request.app.state.config.DOCLING_SERVER_URL,
-        "DOCLING_DO_OCR": request.app.state.config.DOCLING_DO_OCR,
-        "DOCLING_FORCE_OCR": request.app.state.config.DOCLING_FORCE_OCR,
-        "DOCLING_OCR_ENGINE": request.app.state.config.DOCLING_OCR_ENGINE,
-        "DOCLING_OCR_LANG": request.app.state.config.DOCLING_OCR_LANG,
-        "DOCLING_PDF_BACKEND": request.app.state.config.DOCLING_PDF_BACKEND,
-        "DOCLING_TABLE_MODE": request.app.state.config.DOCLING_TABLE_MODE,
-        "DOCLING_PIPELINE": request.app.state.config.DOCLING_PIPELINE,
-        "DOCLING_DO_PICTURE_DESCRIPTION": request.app.state.config.DOCLING_DO_PICTURE_DESCRIPTION,
-        "DOCLING_PICTURE_DESCRIPTION_MODE": request.app.state.config.DOCLING_PICTURE_DESCRIPTION_MODE,
-        "DOCLING_PICTURE_DESCRIPTION_LOCAL": request.app.state.config.DOCLING_PICTURE_DESCRIPTION_LOCAL,
-        "DOCLING_PICTURE_DESCRIPTION_API": request.app.state.config.DOCLING_PICTURE_DESCRIPTION_API,
+        "DOCLING_API_KEY": request.app.state.config.DOCLING_API_KEY,
+        "DOCLING_PARAMS": request.app.state.config.DOCLING_PARAMS,
         "DOCUMENT_INTELLIGENCE_ENDPOINT": request.app.state.config.DOCUMENT_INTELLIGENCE_ENDPOINT,
         "DOCUMENT_INTELLIGENCE_KEY": request.app.state.config.DOCUMENT_INTELLIGENCE_KEY,
+        "DOCUMENT_INTELLIGENCE_MODEL": request.app.state.config.DOCUMENT_INTELLIGENCE_MODEL,
+        "MISTRAL_OCR_API_BASE_URL": request.app.state.config.MISTRAL_OCR_API_BASE_URL,
         "MISTRAL_OCR_API_KEY": request.app.state.config.MISTRAL_OCR_API_KEY,
+        # MinerU settings
+        "MINERU_API_MODE": request.app.state.config.MINERU_API_MODE,
+        "MINERU_API_URL": request.app.state.config.MINERU_API_URL,
+        "MINERU_API_KEY": request.app.state.config.MINERU_API_KEY,
+        "MINERU_PARAMS": request.app.state.config.MINERU_PARAMS,
         # Reranking settings
         "RAG_RERANKING_MODEL": request.app.state.config.RAG_RERANKING_MODEL,
         "RAG_RERANKING_ENGINE": request.app.state.config.RAG_RERANKING_ENGINE,

@@ -1174,6 +1202,7 @@ async def update_rag_config(
         "PERPLEXITY_API_KEY": request.app.state.config.PERPLEXITY_API_KEY,
         "PERPLEXITY_MODEL": request.app.state.config.PERPLEXITY_MODEL,
         "PERPLEXITY_SEARCH_CONTEXT_USAGE": request.app.state.config.PERPLEXITY_SEARCH_CONTEXT_USAGE,
+        "PERPLEXITY_SEARCH_API_URL": request.app.state.config.PERPLEXITY_SEARCH_API_URL,
         "SOUGOU_API_SID": request.app.state.config.SOUGOU_API_SID,
         "SOUGOU_API_SK": request.app.state.config.SOUGOU_API_SK,
         "WEB_LOADER_ENGINE": request.app.state.config.WEB_LOADER_ENGINE,
@@ -1227,7 +1256,7 @@ def save_docs_to_vector_db(
 
         return ", ".join(docs_info)
 
-    log.info(
+    log.debug(
         f"save_docs_to_vector_db: document {_get_docs_info(docs)} {collection_name}"
     )

@@ -1374,11 +1403,14 @@ def save_docs_to_vector_db(
             ),
         )
 
-        embeddings = embedding_function(
-            list(map(lambda x: x.replace("\n", " "), texts)),
-            prefix=RAG_EMBEDDING_CONTENT_PREFIX,
-            user=user,
-        )
+        # Run async embedding in sync context
+        embeddings = asyncio.run(
+            embedding_function(
+                list(map(lambda x: x.replace("\n", " "), texts)),
+                prefix=RAG_EMBEDDING_CONTENT_PREFIX,
+                user=user,
+            )
+        )
         log.info(f"embeddings generated {len(embeddings)} for {len(texts)} items")
 
         items = [
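save_docs_to_vector_db stays synchronous, so the now-async embedding_function is driven with asyncio.run. That is safe only off the event-loop thread: asyncio.run raises RuntimeError inside a running loop, which is why the endpoints below dispatch this function through run_in_threadpool. The pattern in isolation:

import asyncio

async def embed(texts):
    # stand-in for the async embedding_function
    return [[0.0, 0.0, 0.0] for _ in texts]

def save_docs(texts):
    # sync context: fine in a worker thread with no running event loop
    return asyncio.run(embed(texts))

print(save_docs(["a", "b"]))  # [[0.0, 0.0, 0.0], [0.0, 0.0, 0.0]]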
@@ -1416,6 +1448,9 @@ def process_file(
     form_data: ProcessFileForm,
     user=Depends(get_verified_user),
 ):
+    """
+    Process a file and save its content to the vector database.
+    """
     if user.role == "admin":
         file = Files.get_file_by_id(form_data.file_id)
     else:

@@ -1495,6 +1530,7 @@ def process_file(
                     file_path = Storage.get_file(file_path)
                     loader = Loader(
                         engine=request.app.state.config.CONTENT_EXTRACTION_ENGINE,
+                        user=user,
                         DATALAB_MARKER_API_KEY=request.app.state.config.DATALAB_MARKER_API_KEY,
                         DATALAB_MARKER_API_BASE_URL=request.app.state.config.DATALAB_MARKER_API_BASE_URL,
                         DATALAB_MARKER_ADDITIONAL_CONFIG=request.app.state.config.DATALAB_MARKER_ADDITIONAL_CONFIG,

@@ -1510,23 +1546,18 @@ def process_file(
                         EXTERNAL_DOCUMENT_LOADER_API_KEY=request.app.state.config.EXTERNAL_DOCUMENT_LOADER_API_KEY,
                         TIKA_SERVER_URL=request.app.state.config.TIKA_SERVER_URL,
                         DOCLING_SERVER_URL=request.app.state.config.DOCLING_SERVER_URL,
-                        DOCLING_PARAMS={
-                            "do_ocr": request.app.state.config.DOCLING_DO_OCR,
-                            "force_ocr": request.app.state.config.DOCLING_FORCE_OCR,
-                            "ocr_engine": request.app.state.config.DOCLING_OCR_ENGINE,
-                            "ocr_lang": request.app.state.config.DOCLING_OCR_LANG,
-                            "pdf_backend": request.app.state.config.DOCLING_PDF_BACKEND,
-                            "table_mode": request.app.state.config.DOCLING_TABLE_MODE,
-                            "pipeline": request.app.state.config.DOCLING_PIPELINE,
-                            "do_picture_description": request.app.state.config.DOCLING_DO_PICTURE_DESCRIPTION,
-                            "picture_description_mode": request.app.state.config.DOCLING_PICTURE_DESCRIPTION_MODE,
-                            "picture_description_local": request.app.state.config.DOCLING_PICTURE_DESCRIPTION_LOCAL,
-                            "picture_description_api": request.app.state.config.DOCLING_PICTURE_DESCRIPTION_API,
-                        },
+                        DOCLING_API_KEY=request.app.state.config.DOCLING_API_KEY,
+                        DOCLING_PARAMS=request.app.state.config.DOCLING_PARAMS,
                         PDF_EXTRACT_IMAGES=request.app.state.config.PDF_EXTRACT_IMAGES,
                         DOCUMENT_INTELLIGENCE_ENDPOINT=request.app.state.config.DOCUMENT_INTELLIGENCE_ENDPOINT,
                         DOCUMENT_INTELLIGENCE_KEY=request.app.state.config.DOCUMENT_INTELLIGENCE_KEY,
+                        DOCUMENT_INTELLIGENCE_MODEL=request.app.state.config.DOCUMENT_INTELLIGENCE_MODEL,
+                        MISTRAL_OCR_API_BASE_URL=request.app.state.config.MISTRAL_OCR_API_BASE_URL,
                         MISTRAL_OCR_API_KEY=request.app.state.config.MISTRAL_OCR_API_KEY,
+                        MINERU_API_MODE=request.app.state.config.MINERU_API_MODE,
+                        MINERU_API_URL=request.app.state.config.MINERU_API_URL,
+                        MINERU_API_KEY=request.app.state.config.MINERU_API_KEY,
+                        MINERU_PARAMS=request.app.state.config.MINERU_PARAMS,
                     )
                     docs = loader.load(
                         file.filename, file.meta.get("content_type"), file_path
@@ -1647,7 +1678,7 @@ class ProcessTextForm(BaseModel):
 
 
 @router.post("/process/text")
-def process_text(
+async def process_text(
     request: Request,
     form_data: ProcessTextForm,
     user=Depends(get_verified_user),

@@ -1665,7 +1696,9 @@ def process_text(
     text_content = form_data.content
     log.debug(f"text_content: {text_content}")
 
-    result = save_docs_to_vector_db(request, docs, collection_name, user=user)
+    result = await run_in_threadpool(
+        save_docs_to_vector_db, request, docs, collection_name, user=user
+    )
     if result:
         return {
             "status": True,
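process_text (and process_web below) are now async endpoints that push the blocking vector-DB write onto Starlette's thread pool, keeping the event loop responsive while chunks are embedded and stored. The call shape, reduced:

import asyncio
from starlette.concurrency import run_in_threadpool

def blocking_save(name):
    # stand-in for save_docs_to_vector_db: synchronous, potentially slow
    return f"saved {name}"

async def endpoint():
    # the callable comes first; positional and keyword args pass through
    return await run_in_threadpool(blocking_save, "docs")

print(asyncio.run(endpoint()))  # saved docs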
@@ -1680,51 +1713,8 @@ def process_text(
 
 
 @router.post("/process/youtube")
-def process_youtube_video(
-    request: Request, form_data: ProcessUrlForm, user=Depends(get_verified_user)
-):
-    try:
-        collection_name = form_data.collection_name
-        if not collection_name:
-            collection_name = calculate_sha256_string(form_data.url)[:63]
-
-        loader = YoutubeLoader(
-            form_data.url,
-            language=request.app.state.config.YOUTUBE_LOADER_LANGUAGE,
-            proxy_url=request.app.state.config.YOUTUBE_LOADER_PROXY_URL,
-        )
-
-        docs = loader.load()
-        content = " ".join([doc.page_content for doc in docs])
-        log.debug(f"text_content: {content}")
-
-        save_docs_to_vector_db(
-            request, docs, collection_name, overwrite=True, user=user
-        )
-
-        return {
-            "status": True,
-            "collection_name": collection_name,
-            "filename": form_data.url,
-            "file": {
-                "data": {
-                    "content": content,
-                },
-                "meta": {
-                    "name": form_data.url,
-                },
-            },
-        }
-    except Exception as e:
-        log.exception(e)
-        raise HTTPException(
-            status_code=status.HTTP_400_BAD_REQUEST,
-            detail=ERROR_MESSAGES.DEFAULT(e),
-        )
-
-
 @router.post("/process/web")
-def process_web(
+async def process_web(
     request: Request, form_data: ProcessUrlForm, user=Depends(get_verified_user)
 ):
     try:

@@ -1732,19 +1722,19 @@ def process_web(
         if not collection_name:
             collection_name = calculate_sha256_string(form_data.url)[:63]
 
-        loader = get_web_loader(
-            form_data.url,
-            verify_ssl=request.app.state.config.ENABLE_WEB_LOADER_SSL_VERIFICATION,
-            requests_per_second=request.app.state.config.WEB_LOADER_CONCURRENT_REQUESTS,
+        content, docs = await run_in_threadpool(
+            get_content_from_url, request, form_data.url
         )
-        docs = loader.load()
-        content = " ".join([doc.page_content for doc in docs])
 
         log.debug(f"text_content: {content}")
 
         if not request.app.state.config.BYPASS_WEB_SEARCH_EMBEDDING_AND_RETRIEVAL:
-            save_docs_to_vector_db(
-                request, docs, collection_name, overwrite=True, user=user
+            await run_in_threadpool(
+                save_docs_to_vector_db,
+                request,
+                docs,
+                collection_name,
+                overwrite=True,
+                user=user,
             )
         else:
             collection_name = None
@ -1771,7 +1761,9 @@ def process_web(
|
||||||
)
|
)
|
||||||
|
|
||||||
|
|
||||||
def search_web(request: Request, engine: str, query: str) -> list[SearchResult]:
|
def search_web(
|
||||||
|
request: Request, engine: str, query: str, user=None
|
||||||
|
) -> list[SearchResult]:
|
||||||
"""Search the web using a search engine and return the results as a list of SearchResult objects.
|
"""Search the web using a search engine and return the results as a list of SearchResult objects.
|
||||||
Will look for a search engine API key in environment variables in the following order:
|
Will look for a search engine API key in environment variables in the following order:
|
||||||
- SEARXNG_QUERY_URL
|
- SEARXNG_QUERY_URL
|
||||||
|
|
@ -1810,6 +1802,8 @@ def search_web(request: Request, engine: str, query: str) -> list[SearchResult]:
|
||||||
query,
|
query,
|
||||||
request.app.state.config.WEB_SEARCH_RESULT_COUNT,
|
request.app.state.config.WEB_SEARCH_RESULT_COUNT,
|
||||||
request.app.state.config.WEB_SEARCH_DOMAIN_FILTER_LIST,
|
request.app.state.config.WEB_SEARCH_DOMAIN_FILTER_LIST,
|
||||||
|
request.app.state.config.PERPLEXITY_SEARCH_API_URL,
|
||||||
|
user,
|
||||||
)
|
)
|
||||||
else:
|
else:
|
||||||
raise Exception("No PERPLEXITY_API_KEY found in environment variables")
|
raise Exception("No PERPLEXITY_API_KEY found in environment variables")
|
||||||
|
|
@ -1846,6 +1840,7 @@ def search_web(request: Request, engine: str, query: str) -> list[SearchResult]:
|
||||||
query,
|
query,
|
||||||
request.app.state.config.WEB_SEARCH_RESULT_COUNT,
|
request.app.state.config.WEB_SEARCH_RESULT_COUNT,
|
||||||
request.app.state.config.WEB_SEARCH_DOMAIN_FILTER_LIST,
|
request.app.state.config.WEB_SEARCH_DOMAIN_FILTER_LIST,
|
||||||
|
referer=request.app.state.config.WEBUI_URL,
|
||||||
)
|
)
|
||||||
else:
|
else:
|
||||||
raise Exception(
|
raise Exception(
|
||||||
|
|
@ -1986,6 +1981,24 @@ def search_web(request: Request, engine: str, query: str) -> list[SearchResult]:
|
||||||
request.app.state.config.WEB_SEARCH_RESULT_COUNT,
|
request.app.state.config.WEB_SEARCH_RESULT_COUNT,
|
||||||
request.app.state.config.WEB_SEARCH_DOMAIN_FILTER_LIST,
|
request.app.state.config.WEB_SEARCH_DOMAIN_FILTER_LIST,
|
||||||
)
|
)
|
||||||
|
elif engine == "azure":
|
||||||
|
if (
|
||||||
|
request.app.state.config.AZURE_AI_SEARCH_API_KEY
|
||||||
|
and request.app.state.config.AZURE_AI_SEARCH_ENDPOINT
|
||||||
|
and request.app.state.config.AZURE_AI_SEARCH_INDEX_NAME
|
||||||
|
):
|
||||||
|
return search_azure(
|
||||||
|
request.app.state.config.AZURE_AI_SEARCH_API_KEY,
|
||||||
|
request.app.state.config.AZURE_AI_SEARCH_ENDPOINT,
|
||||||
|
request.app.state.config.AZURE_AI_SEARCH_INDEX_NAME,
|
||||||
|
query,
|
||||||
|
request.app.state.config.WEB_SEARCH_RESULT_COUNT,
|
||||||
|
request.app.state.config.WEB_SEARCH_DOMAIN_FILTER_LIST,
|
||||||
|
)
|
||||||
|
else:
|
||||||
|
raise Exception(
|
||||||
|
"AZURE_AI_SEARCH_API_KEY, AZURE_AI_SEARCH_ENDPOINT, and AZURE_AI_SEARCH_INDEX_NAME are required for Azure AI Search"
|
||||||
|
)
|
||||||
elif engine == "exa":
|
elif engine == "exa":
|
||||||
return search_exa(
|
return search_exa(
|
||||||
request.app.state.config.EXA_API_KEY,
|
request.app.state.config.EXA_API_KEY,
|
||||||
|
|
@ -2028,11 +2041,13 @@ def search_web(request: Request, engine: str, query: str) -> list[SearchResult]:
|
||||||
)
|
)
|
||||||
elif engine == "external":
|
elif engine == "external":
|
||||||
return search_external(
|
return search_external(
|
||||||
|
request,
|
||||||
request.app.state.config.EXTERNAL_WEB_SEARCH_URL,
|
request.app.state.config.EXTERNAL_WEB_SEARCH_URL,
|
||||||
request.app.state.config.EXTERNAL_WEB_SEARCH_API_KEY,
|
request.app.state.config.EXTERNAL_WEB_SEARCH_API_KEY,
|
||||||
query,
|
query,
|
||||||
request.app.state.config.WEB_SEARCH_RESULT_COUNT,
|
request.app.state.config.WEB_SEARCH_RESULT_COUNT,
|
||||||
request.app.state.config.WEB_SEARCH_DOMAIN_FILTER_LIST,
|
request.app.state.config.WEB_SEARCH_DOMAIN_FILTER_LIST,
|
||||||
|
user=user,
|
||||||
)
|
)
|
||||||
else:
|
else:
|
||||||
raise Exception("No search engine API key found in environment variables")
|
raise Exception("No search engine API key found in environment variables")
|
||||||
|
|
@ -2047,7 +2062,7 @@ async def process_web_search(
|
||||||
result_items = []
|
result_items = []
|
||||||
|
|
||||||
try:
|
try:
|
||||||
logging.info(
|
logging.debug(
|
||||||
f"trying to web search with {request.app.state.config.WEB_SEARCH_ENGINE, form_data.queries}"
|
f"trying to web search with {request.app.state.config.WEB_SEARCH_ENGINE, form_data.queries}"
|
||||||
)
|
)
|
||||||
|
|
||||||
|
|
@ -2057,6 +2072,7 @@ async def process_web_search(
|
||||||
request,
|
request,
|
||||||
request.app.state.config.WEB_SEARCH_ENGINE,
|
request.app.state.config.WEB_SEARCH_ENGINE,
|
||||||
query,
|
query,
|
||||||
|
user,
|
||||||
)
|
)
|
||||||
for query in form_data.queries
|
for query in form_data.queries
|
||||||
]
|
]
|
||||||
|
|
@ -2081,6 +2097,12 @@ async def process_web_search(
|
||||||
detail=ERROR_MESSAGES.WEB_SEARCH_ERROR(e),
|
detail=ERROR_MESSAGES.WEB_SEARCH_ERROR(e),
|
||||||
)
|
)
|
||||||
|
|
||||||
|
if len(urls) == 0:
|
||||||
|
raise HTTPException(
|
||||||
|
status_code=status.HTTP_404_NOT_FOUND,
|
||||||
|
detail=ERROR_MESSAGES.DEFAULT("No results found from web search"),
|
||||||
|
)
|
||||||
|
|
||||||
try:
|
try:
|
||||||
if request.app.state.config.BYPASS_WEB_SEARCH_WEB_LOADER:
|
if request.app.state.config.BYPASS_WEB_SEARCH_WEB_LOADER:
|
||||||
search_results = [
|
search_results = [
|
||||||
|
|
@ -2176,7 +2198,7 @@ class QueryDocForm(BaseModel):
|
||||||
|
|
||||||
|
|
||||||
@router.post("/query/doc")
|
@router.post("/query/doc")
|
||||||
def query_doc_handler(
|
async def query_doc_handler(
|
||||||
request: Request,
|
request: Request,
|
||||||
form_data: QueryDocForm,
|
form_data: QueryDocForm,
|
||||||
user=Depends(get_verified_user),
|
user=Depends(get_verified_user),
|
||||||
|
|
@ -2189,7 +2211,7 @@ def query_doc_handler(
|
||||||
collection_results[form_data.collection_name] = VECTOR_DB_CLIENT.get(
|
collection_results[form_data.collection_name] = VECTOR_DB_CLIENT.get(
|
||||||
collection_name=form_data.collection_name
|
collection_name=form_data.collection_name
|
||||||
)
|
)
|
||||||
return query_doc_with_hybrid_search(
|
return await query_doc_with_hybrid_search(
|
||||||
collection_name=form_data.collection_name,
|
collection_name=form_data.collection_name,
|
||||||
collection_result=collection_results[form_data.collection_name],
|
collection_result=collection_results[form_data.collection_name],
|
||||||
query=form_data.query,
|
query=form_data.query,
|
||||||
|
|
@ -2199,8 +2221,8 @@ def query_doc_handler(
|
||||||
k=form_data.k if form_data.k else request.app.state.config.TOP_K,
|
k=form_data.k if form_data.k else request.app.state.config.TOP_K,
|
||||||
reranking_function=(
|
reranking_function=(
|
||||||
(
|
(
|
||||||
lambda sentences: request.app.state.RERANKING_FUNCTION(
|
lambda query, documents: request.app.state.RERANKING_FUNCTION(
|
||||||
sentences, user=user
|
query, documents, user=user
|
||||||
)
|
)
|
||||||
)
|
)
|
||||||
if request.app.state.RERANKING_FUNCTION
|
if request.app.state.RERANKING_FUNCTION
|
||||||
|
|
@ -2221,11 +2243,12 @@ def query_doc_handler(
|
||||||
user=user,
|
user=user,
|
||||||
)
|
)
|
||||||
else:
|
else:
|
||||||
|
query_embedding = await request.app.state.EMBEDDING_FUNCTION(
|
||||||
|
form_data.query, prefix=RAG_EMBEDDING_QUERY_PREFIX, user=user
|
||||||
|
)
|
||||||
return query_doc(
|
return query_doc(
|
||||||
collection_name=form_data.collection_name,
|
collection_name=form_data.collection_name,
|
||||||
query_embedding=request.app.state.EMBEDDING_FUNCTION(
|
query_embedding=query_embedding,
|
||||||
form_data.query, prefix=RAG_EMBEDDING_QUERY_PREFIX, user=user
|
|
||||||
),
|
|
||||||
k=form_data.k if form_data.k else request.app.state.config.TOP_K,
|
k=form_data.k if form_data.k else request.app.state.config.TOP_K,
|
||||||
user=user,
|
user=user,
|
||||||
)
|
)
|
||||||
|
|
@ -2245,10 +2268,11 @@ class QueryCollectionsForm(BaseModel):
|
||||||
r: Optional[float] = None
|
r: Optional[float] = None
|
||||||
hybrid: Optional[bool] = None
|
hybrid: Optional[bool] = None
|
||||||
hybrid_bm25_weight: Optional[float] = None
|
hybrid_bm25_weight: Optional[float] = None
|
||||||
|
enable_enriched_texts: Optional[bool] = None
|
||||||
|
|
||||||
|
|
||||||
@router.post("/query/collection")
|
@router.post("/query/collection")
|
||||||
def query_collection_handler(
|
async def query_collection_handler(
|
||||||
request: Request,
|
request: Request,
|
||||||
form_data: QueryCollectionsForm,
|
form_data: QueryCollectionsForm,
|
||||||
user=Depends(get_verified_user),
|
user=Depends(get_verified_user),
|
||||||
|
|
@ -2257,7 +2281,7 @@ def query_collection_handler(
|
||||||
if request.app.state.config.ENABLE_RAG_HYBRID_SEARCH and (
|
if request.app.state.config.ENABLE_RAG_HYBRID_SEARCH and (
|
||||||
form_data.hybrid is None or form_data.hybrid
|
form_data.hybrid is None or form_data.hybrid
|
||||||
):
|
):
|
||||||
return query_collection_with_hybrid_search(
|
return await query_collection_with_hybrid_search(
|
||||||
collection_names=form_data.collection_names,
|
collection_names=form_data.collection_names,
|
||||||
queries=[form_data.query],
|
queries=[form_data.query],
|
||||||
embedding_function=lambda query, prefix: request.app.state.EMBEDDING_FUNCTION(
|
embedding_function=lambda query, prefix: request.app.state.EMBEDDING_FUNCTION(
|
||||||
|
|
@ -2266,8 +2290,8 @@ def query_collection_handler(
|
||||||
k=form_data.k if form_data.k else request.app.state.config.TOP_K,
|
k=form_data.k if form_data.k else request.app.state.config.TOP_K,
|
||||||
reranking_function=(
|
reranking_function=(
|
||||||
(
|
(
|
||||||
lambda sentences: request.app.state.RERANKING_FUNCTION(
|
lambda query, documents: request.app.state.RERANKING_FUNCTION(
|
||||||
sentences, user=user
|
query, documents, user=user
|
||||||
)
|
)
|
||||||
)
|
)
|
||||||
if request.app.state.RERANKING_FUNCTION
|
if request.app.state.RERANKING_FUNCTION
|
||||||
|
|
@ -2285,9 +2309,14 @@ def query_collection_handler(
|
||||||
if form_data.hybrid_bm25_weight
|
if form_data.hybrid_bm25_weight
|
||||||
else request.app.state.config.HYBRID_BM25_WEIGHT
|
else request.app.state.config.HYBRID_BM25_WEIGHT
|
||||||
),
|
),
|
||||||
|
enable_enriched_texts=(
|
||||||
|
form_data.enable_enriched_texts
|
||||||
|
if form_data.enable_enriched_texts is not None
|
||||||
|
else request.app.state.config.ENABLE_RAG_HYBRID_SEARCH_ENRICHED_TEXTS
|
||||||
|
),
|
||||||
)
|
)
|
||||||
else:
|
else:
|
||||||
return query_collection(
|
return await query_collection(
|
||||||
collection_names=form_data.collection_names,
|
collection_names=form_data.collection_names,
|
||||||
queries=[form_data.query],
|
queries=[form_data.query],
|
||||||
embedding_function=lambda query, prefix: request.app.state.EMBEDDING_FUNCTION(
|
embedding_function=lambda query, prefix: request.app.state.EMBEDDING_FUNCTION(
|
||||||
|
|
@ -2369,7 +2398,7 @@ if ENV == "dev":
|
||||||
@router.get("/ef/{text}")
|
@router.get("/ef/{text}")
|
||||||
async def get_embeddings(request: Request, text: Optional[str] = "Hello World!"):
|
async def get_embeddings(request: Request, text: Optional[str] = "Hello World!"):
|
||||||
return {
|
return {
|
||||||
"result": request.app.state.EMBEDDING_FUNCTION(
|
"result": await request.app.state.EMBEDDING_FUNCTION(
|
||||||
text, prefix=RAG_EMBEDDING_QUERY_PREFIX
|
text, prefix=RAG_EMBEDDING_QUERY_PREFIX
|
||||||
)
|
)
|
||||||
}
|
}
|
||||||
|
|
@ -2392,7 +2421,7 @@ class BatchProcessFilesResponse(BaseModel):
|
||||||
|
|
||||||
|
|
||||||
@router.post("/process/files/batch")
|
@router.post("/process/files/batch")
|
||||||
def process_files_batch(
|
async def process_files_batch(
|
||||||
request: Request,
|
request: Request,
|
||||||
form_data: BatchProcessFilesForm,
|
form_data: BatchProcessFilesForm,
|
||||||
user=Depends(get_verified_user),
|
user=Depends(get_verified_user),
|
||||||
|
|
@ -2400,16 +2429,19 @@ def process_files_batch(
|
||||||
"""
|
"""
|
||||||
Process a batch of files and save them to the vector database.
|
Process a batch of files and save them to the vector database.
|
||||||
"""
|
"""
|
||||||
results: List[BatchProcessFilesResult] = []
|
|
||||||
errors: List[BatchProcessFilesResult] = []
|
|
||||||
collection_name = form_data.collection_name
|
collection_name = form_data.collection_name
|
||||||
|
|
||||||
|
file_results: List[BatchProcessFilesResult] = []
|
||||||
|
file_errors: List[BatchProcessFilesResult] = []
|
||||||
|
file_updates: List[FileUpdateForm] = []
|
||||||
|
|
||||||
# Prepare all documents first
|
# Prepare all documents first
|
||||||
all_docs: List[Document] = []
|
all_docs: List[Document] = []
|
||||||
|
|
||||||
for file in form_data.files:
|
for file in form_data.files:
|
||||||
try:
|
try:
|
||||||
text_content = file.data.get("content", "")
|
text_content = file.data.get("content", "")
|
||||||
|
|
||||||
docs: List[Document] = [
|
docs: List[Document] = [
|
||||||
Document(
|
Document(
|
||||||
page_content=text_content.replace("<br/>", "\n"),
|
page_content=text_content.replace("<br/>", "\n"),
|
||||||
|
|
@ -2423,45 +2455,49 @@ def process_files_batch(
|
||||||
)
|
)
|
||||||
]
|
]
|
||||||
|
|
||||||
hash = calculate_sha256_string(text_content)
|
|
||||||
Files.update_file_hash_by_id(file.id, hash)
|
|
||||||
Files.update_file_data_by_id(file.id, {"content": text_content})
|
|
||||||
|
|
||||||
all_docs.extend(docs)
|
all_docs.extend(docs)
|
||||||
results.append(BatchProcessFilesResult(file_id=file.id, status="prepared"))
|
|
||||||
|
file_updates.append(
|
||||||
|
FileUpdateForm(
|
||||||
|
hash=calculate_sha256_string(text_content),
|
||||||
|
data={"content": text_content},
|
||||||
|
)
|
||||||
|
)
|
||||||
|
file_results.append(
|
||||||
|
BatchProcessFilesResult(file_id=file.id, status="prepared")
|
||||||
|
)
|
||||||
|
|
||||||
except Exception as e:
|
except Exception as e:
|
||||||
log.error(f"process_files_batch: Error processing file {file.id}: {str(e)}")
|
log.error(f"process_files_batch: Error processing file {file.id}: {str(e)}")
|
||||||
errors.append(
|
file_errors.append(
|
||||||
BatchProcessFilesResult(file_id=file.id, status="failed", error=str(e))
|
BatchProcessFilesResult(file_id=file.id, status="failed", error=str(e))
|
||||||
)
|
)
|
||||||
|
|
||||||
# Save all documents in one batch
|
# Save all documents in one batch
|
||||||
if all_docs:
|
if all_docs:
|
||||||
try:
|
try:
|
||||||
save_docs_to_vector_db(
|
await run_in_threadpool(
|
||||||
request=request,
|
save_docs_to_vector_db,
|
||||||
docs=all_docs,
|
request,
|
||||||
collection_name=collection_name,
|
all_docs,
|
||||||
|
collection_name,
|
||||||
add=True,
|
add=True,
|
||||||
user=user,
|
user=user,
|
||||||
)
|
)
|
||||||
|
|
||||||
# Update all files with collection name
|
# Update all files with collection name
|
||||||
for result in results:
|
for file_update, file_result in zip(file_updates, file_results):
|
||||||
Files.update_file_metadata_by_id(
|
Files.update_file_by_id(id=file_result.file_id, form_data=file_update)
|
||||||
result.file_id, {"collection_name": collection_name}
|
file_result.status = "completed"
|
||||||
)
|
|
||||||
result.status = "completed"
|
|
||||||
|
|
||||||
except Exception as e:
|
except Exception as e:
|
||||||
log.error(
|
log.error(
|
||||||
f"process_files_batch: Error saving documents to vector DB: {str(e)}"
|
f"process_files_batch: Error saving documents to vector DB: {str(e)}"
|
||||||
)
|
)
|
||||||
for result in results:
|
for file_result in file_results:
|
||||||
result.status = "failed"
|
file_result.status = "failed"
|
||||||
errors.append(
|
file_errors.append(
|
||||||
BatchProcessFilesResult(file_id=result.file_id, error=str(e))
|
BatchProcessFilesResult(file_id=file_result.file_id, error=str(e))
|
||||||
)
|
)
|
||||||
|
|
||||||
return BatchProcessFilesResponse(results=results, errors=errors)
|
return BatchProcessFilesResponse(results=file_results, errors=file_errors)
|
||||||
|
|
|
||||||
@@ -256,15 +256,16 @@ def get_scim_auth(
         )

     # Check if SCIM is enabled
-    scim_enabled = getattr(request.app.state, "SCIM_ENABLED", False)
+    enable_scim = getattr(request.app.state, "ENABLE_SCIM", False)
     log.info(
-        f"SCIM auth check - raw SCIM_ENABLED: {scim_enabled}, type: {type(scim_enabled)}"
+        f"SCIM auth check - raw ENABLE_SCIM: {enable_scim}, type: {type(enable_scim)}"
     )

     # Handle both PersistentConfig and direct value
-    if hasattr(scim_enabled, "value"):
-        scim_enabled = scim_enabled.value
-    log.info(f"SCIM enabled status after conversion: {scim_enabled}")
-    if not scim_enabled:
+    if hasattr(enable_scim, "value"):
+        enable_scim = enable_scim.value
+    if not enable_scim:
         raise HTTPException(
             status_code=status.HTTP_403_FORBIDDEN,
             detail="SCIM is not enabled",
@@ -348,8 +349,10 @@ def user_to_scim(user: UserModel, request: Request) -> SCIMUser:

 def group_to_scim(group: GroupModel, request: Request) -> SCIMGroup:
     """Convert internal Group model to SCIM Group"""
+    member_ids = Groups.get_group_user_ids_by_id(group.id)
     members = []
-    for user_id in group.user_ids:
+    for user_id in member_ids:
         user = Users.get_user_by_id(user_id)
         if user:
             members.append(
@@ -716,7 +719,7 @@ async def get_groups(
 ):
     """List SCIM Groups"""
     # Get all groups
-    groups_list = Groups.get_groups()
+    groups_list = Groups.get_all_groups()

     # Apply pagination
     total = len(groups_list)
@@ -795,9 +798,11 @@ async def create_group(
         update_form = GroupUpdateForm(
             name=new_group.name,
             description=new_group.description,
-            user_ids=member_ids,
         )

         Groups.update_group_by_id(new_group.id, update_form)
+        Groups.set_group_user_ids_by_id(new_group.id, member_ids)

     new_group = Groups.get_group_by_id(new_group.id)

     return group_to_scim(new_group, request)
@@ -829,7 +834,7 @@ async def update_group(
     # Handle members if provided
     if group_data.members is not None:
         member_ids = [member.value for member in group_data.members]
-        update_form.user_ids = member_ids
+        Groups.set_group_user_ids_by_id(group_id, member_ids)

     # Update group
     updated_group = Groups.update_group_by_id(group_id, update_form)
@@ -862,7 +867,6 @@ async def patch_group(
     update_form = GroupUpdateForm(
         name=group.name,
         description=group.description,
-        user_ids=group.user_ids.copy() if group.user_ids else [],
     )

     for operation in patch_data.Operations:
@@ -875,21 +879,22 @@ async def patch_group(
                 update_form.name = value
             elif path == "members":
                 # Replace all members
-                update_form.user_ids = [member["value"] for member in value]
+                Groups.set_group_user_ids_by_id(
+                    group_id, [member["value"] for member in value]
+                )

         elif op == "add":
             if path == "members":
                 # Add members
                 if isinstance(value, list):
                     for member in value:
                         if isinstance(member, dict) and "value" in member:
-                            if member["value"] not in update_form.user_ids:
-                                update_form.user_ids.append(member["value"])
+                            Groups.add_users_to_group(group_id, [member["value"]])
         elif op == "remove":
             if path and path.startswith("members[value eq"):
                 # Remove specific member
                 member_id = path.split('"')[1]
-                if member_id in update_form.user_ids:
-                    update_form.user_ids.remove(member_id)
+                Groups.remove_users_from_group(group_id, [member_id])

     # Update group
     updated_group = Groups.update_group_by_id(group_id, update_form)

@@ -33,6 +33,7 @@ from open_webui.config import (
     DEFAULT_AUTOCOMPLETE_GENERATION_PROMPT_TEMPLATE,
     DEFAULT_EMOJI_GENERATION_PROMPT_TEMPLATE,
     DEFAULT_MOA_GENERATION_PROMPT_TEMPLATE,
+    DEFAULT_VOICE_MODE_PROMPT_TEMPLATE,
 )
 from open_webui.env import SRC_LOG_LEVELS

@@ -68,6 +69,7 @@ async def get_task_config(request: Request, user=Depends(get_verified_user)):
         "ENABLE_RETRIEVAL_QUERY_GENERATION": request.app.state.config.ENABLE_RETRIEVAL_QUERY_GENERATION,
         "QUERY_GENERATION_PROMPT_TEMPLATE": request.app.state.config.QUERY_GENERATION_PROMPT_TEMPLATE,
         "TOOLS_FUNCTION_CALLING_PROMPT_TEMPLATE": request.app.state.config.TOOLS_FUNCTION_CALLING_PROMPT_TEMPLATE,
+        "VOICE_MODE_PROMPT_TEMPLATE": request.app.state.config.VOICE_MODE_PROMPT_TEMPLATE,
     }


@@ -87,6 +89,7 @@ class TaskConfigForm(BaseModel):
     ENABLE_RETRIEVAL_QUERY_GENERATION: bool
     QUERY_GENERATION_PROMPT_TEMPLATE: str
     TOOLS_FUNCTION_CALLING_PROMPT_TEMPLATE: str
+    VOICE_MODE_PROMPT_TEMPLATE: Optional[str]


 @router.post("/config/update")
@@ -136,6 +139,10 @@ async def update_task_config(
         form_data.TOOLS_FUNCTION_CALLING_PROMPT_TEMPLATE
     )

+    request.app.state.config.VOICE_MODE_PROMPT_TEMPLATE = (
+        form_data.VOICE_MODE_PROMPT_TEMPLATE
+    )
+
     return {
         "TASK_MODEL": request.app.state.config.TASK_MODEL,
         "TASK_MODEL_EXTERNAL": request.app.state.config.TASK_MODEL_EXTERNAL,
@@ -152,6 +159,7 @@ async def update_task_config(
         "ENABLE_RETRIEVAL_QUERY_GENERATION": request.app.state.config.ENABLE_RETRIEVAL_QUERY_GENERATION,
         "QUERY_GENERATION_PROMPT_TEMPLATE": request.app.state.config.QUERY_GENERATION_PROMPT_TEMPLATE,
         "TOOLS_FUNCTION_CALLING_PROMPT_TEMPLATE": request.app.state.config.TOOLS_FUNCTION_CALLING_PROMPT_TEMPLATE,
+        "VOICE_MODE_PROMPT_TEMPLATE": request.app.state.config.VOICE_MODE_PROMPT_TEMPLATE,
     }

@@ -247,9 +247,19 @@ async def load_tool_from_url(


 @router.get("/export", response_model=list[ToolModel])
-async def export_tools(user=Depends(get_admin_user)):
-    tools = Tools.get_tools()
-    return tools
+async def export_tools(request: Request, user=Depends(get_verified_user)):
+    if user.role != "admin" and not has_permission(
+        user.id, "workspace.tools_export", request.app.state.config.USER_PERMISSIONS
+    ):
+        raise HTTPException(
+            status_code=status.HTTP_401_UNAUTHORIZED,
+            detail=ERROR_MESSAGES.UNAUTHORIZED,
+        )
+
+    if user.role == "admin" and BYPASS_ADMIN_ACCESS_CONTROL:
+        return Tools.get_tools()
+    else:
+        return Tools.get_tools_by_user_id(user.id, "read")


 ############################
@@ -263,8 +273,13 @@ async def create_new_tools(
     form_data: ToolForm,
     user=Depends(get_verified_user),
 ):
-    if user.role != "admin" and not has_permission(
-        user.id, "workspace.tools", request.app.state.config.USER_PERMISSIONS
+    if user.role != "admin" and not (
+        has_permission(
+            user.id, "workspace.tools", request.app.state.config.USER_PERMISSIONS
+        )
+        or has_permission(
+            user.id, "workspace.tools_import", request.app.state.config.USER_PERMISSIONS
+        )
     ):
         raise HTTPException(
             status_code=status.HTTP_401_UNAUTHORIZED,

@@ -6,7 +6,7 @@ import io

 from fastapi import APIRouter, Depends, HTTPException, Request, status
 from fastapi.responses import Response, StreamingResponse, FileResponse
-from pydantic import BaseModel
+from pydantic import BaseModel, ConfigDict


 from open_webui.models.auths import Auths
@@ -16,26 +16,27 @@ from open_webui.models.groups import Groups
 from open_webui.models.chats import Chats
 from open_webui.models.users import (
     UserModel,
-    UserListResponse,
+    UserGroupIdsModel,
+    UserGroupIdsListResponse,
     UserInfoListResponse,
-    UserIdNameListResponse,
     UserRoleUpdateForm,
+    UserStatus,
     Users,
     UserSettings,
     UserUpdateForm,
 )


-from open_webui.socket.main import (
-    get_active_status_by_user_id,
-    get_active_user_ids,
-    get_user_active_status,
-)
 from open_webui.constants import ERROR_MESSAGES
 from open_webui.env import SRC_LOG_LEVELS, STATIC_DIR


-from open_webui.utils.auth import get_admin_user, get_password_hash, get_verified_user
+from open_webui.utils.auth import (
+    get_admin_user,
+    get_password_hash,
+    get_verified_user,
+    validate_password,
+)
 from open_webui.utils.access_control import get_permissions, has_permission


@@ -45,23 +46,6 @@ log.setLevel(SRC_LOG_LEVELS["MODELS"])
 router = APIRouter()


-############################
-# GetActiveUsers
-############################
-
-
-@router.get("/active")
-async def get_active_users(
-    user=Depends(get_verified_user),
-):
-    """
-    Get a list of active users.
-    """
-    return {
-        "user_ids": get_active_user_ids(),
-    }
-
-
 ############################
 # GetUsers
 ############################
@@ -70,7 +54,7 @@ async def get_active_users(
 PAGE_ITEM_COUNT = 30


-@router.get("/", response_model=UserListResponse)
+@router.get("/", response_model=UserGroupIdsListResponse)
 async def get_users(
     query: Optional[str] = None,
     order_by: Optional[str] = None,
@@ -91,7 +75,25 @@ async def get_users(
     if direction:
         filter["direction"] = direction

-    return Users.get_users(filter=filter, skip=skip, limit=limit)
+    result = Users.get_users(filter=filter, skip=skip, limit=limit)
+
+    users = result["users"]
+    total = result["total"]
+
+    return {
+        "users": [
+            UserGroupIdsModel(
+                **{
+                    **user.model_dump(),
+                    "group_ids": [
+                        group.id for group in Groups.get_groups_by_member_id(user.id)
+                    ],
+                }
+            )
+            for user in users
+        ],
+        "total": total,
+    }


 @router.get("/all", response_model=UserInfoListResponse)
@@ -101,20 +103,31 @@ async def get_all_users(
     return Users.get_users()


-@router.get("/search", response_model=UserIdNameListResponse)
+@router.get("/search", response_model=UserInfoListResponse)
 async def search_users(
     query: Optional[str] = None,
+    order_by: Optional[str] = None,
+    direction: Optional[str] = None,
+    page: Optional[int] = 1,
     user=Depends(get_verified_user),
 ):
     limit = PAGE_ITEM_COUNT

-    page = 1  # Always return the first page for search
+    page = max(1, page)
     skip = (page - 1) * limit

     filter = {}
     if query:
         filter["query"] = query

+    filter = {}
+    if query:
+        filter["query"] = query
+    if order_by:
+        filter["order_by"] = order_by
+    if direction:
+        filter["direction"] = direction
+
     return Users.get_users(filter=filter, skip=skip, limit=limit)


@@ -150,13 +163,24 @@ class WorkspacePermissions(BaseModel):
     knowledge: bool = False
     prompts: bool = False
     tools: bool = False
+    models_import: bool = False
+    models_export: bool = False
+    prompts_import: bool = False
+    prompts_export: bool = False
+    tools_import: bool = False
+    tools_export: bool = False


 class SharingPermissions(BaseModel):
-    public_models: bool = True
-    public_knowledge: bool = True
-    public_prompts: bool = True
+    models: bool = False
+    public_models: bool = False
+    knowledge: bool = False
+    public_knowledge: bool = False
+    prompts: bool = False
+    public_prompts: bool = False
+    tools: bool = False
     public_tools: bool = True
+    notes: bool = False
     public_notes: bool = True


@@ -183,11 +207,15 @@ class ChatPermissions(BaseModel):


 class FeaturesPermissions(BaseModel):
+    api_keys: bool = False
+    notes: bool = True
+    channels: bool = True
+    folders: bool = True
     direct_tool_servers: bool = False

     web_search: bool = True
     image_generation: bool = True
     code_interpreter: bool = True
-    notes: bool = True


 class UserPermissions(BaseModel):
@@ -272,6 +300,43 @@ async def update_user_settings_by_session_user(
     )


+############################
+# GetUserStatusBySessionUser
+############################
+
+
+@router.get("/user/status")
+async def get_user_status_by_session_user(user=Depends(get_verified_user)):
+    user = Users.get_user_by_id(user.id)
+    if user:
+        return user
+    else:
+        raise HTTPException(
+            status_code=status.HTTP_400_BAD_REQUEST,
+            detail=ERROR_MESSAGES.USER_NOT_FOUND,
+        )
+
+
+############################
+# UpdateUserStatusBySessionUser
+############################
+
+
+@router.post("/user/status/update")
+async def update_user_status_by_session_user(
+    form_data: UserStatus, user=Depends(get_verified_user)
+):
+    user = Users.get_user_by_id(user.id)
+    if user:
+        user = Users.update_user_status_by_id(user.id, form_data)
+        return user
+    else:
+        raise HTTPException(
+            status_code=status.HTTP_400_BAD_REQUEST,
+            detail=ERROR_MESSAGES.USER_NOT_FOUND,
+        )
+
+
 ############################
 # GetUserInfoBySessionUser
 ############################
@@ -323,13 +388,15 @@ async def update_user_info_by_session_user(
 ############################


-class UserResponse(BaseModel):
+class UserActiveResponse(UserStatus):
     name: str
-    profile_image_url: str
-    active: Optional[bool] = None
+    profile_image_url: Optional[str] = None
+    is_active: bool
+    model_config = ConfigDict(extra="allow")


-@router.get("/{user_id}", response_model=UserResponse)
+@router.get("/{user_id}", response_model=UserActiveResponse)
 async def get_user_by_id(user_id: str, user=Depends(get_verified_user)):
     # Check if user_id is a shared chat
     # If it is, get the user_id from the chat
@@ -347,11 +414,10 @@ async def get_user_by_id(user_id: str, user=Depends(get_verified_user)):
     user = Users.get_user_by_id(user_id)

     if user:
-        return UserResponse(
+        return UserActiveResponse(
             **{
-                "name": user.name,
-                "profile_image_url": user.profile_image_url,
-                "active": get_active_status_by_user_id(user_id),
+                **user.model_dump(),
+                "is_active": Users.is_user_active(user_id),
             }
         )
     else:
@@ -361,7 +427,7 @@ async def get_user_by_id(user_id: str, user=Depends(get_verified_user)):
         )


-@router.get("/{user_id}/oauth/sessions", response_model=Optional[dict])
+@router.get("/{user_id}/oauth/sessions")
 async def get_user_oauth_sessions_by_id(user_id: str, user=Depends(get_admin_user)):
     sessions = OAuthSessions.get_sessions_by_user_id(user_id)
     if sessions and len(sessions) > 0:
@@ -418,7 +484,7 @@ async def get_user_profile_image_by_id(user_id: str, user=Depends(get_verified_u
 @router.get("/{user_id}/active", response_model=dict)
 async def get_user_active_status_by_id(user_id: str, user=Depends(get_verified_user)):
     return {
-        "active": get_user_active_status(user_id),
+        "active": Users.is_user_active(user_id),
     }


@@ -471,8 +537,12 @@ async def update_user_by_id(
         )

     if form_data.password:
+        try:
+            validate_password(form_data.password)
+        except Exception as e:
+            raise HTTPException(400, detail=str(e))
+
         hashed = get_password_hash(form_data.password)
-        log.debug(f"hashed: {hashed}")
         Auths.update_user_password_by_id(user_id, hashed)

     Auths.update_email_by_id(user_id, form_data.email.lower())

@@ -124,12 +124,3 @@ async def download_db(user=Depends(get_admin_user)):
         media_type="application/octet-stream",
         filename="webui.db",
     )
-
-
-@router.get("/litellm/config")
-async def download_litellm_config_yaml(user=Depends(get_admin_user)):
-    return FileResponse(
-        f"{DATA_DIR}/litellm/config.yaml",
-        media_type="application/octet-stream",
-        filename="config.yaml",
-    )

@@ -18,7 +18,12 @@ from open_webui.utils.redis import (
     get_sentinel_url_from_env,
 )

+from open_webui.config import (
+    CORS_ALLOW_ORIGIN,
+)
+
 from open_webui.env import (
+    VERSION,
     ENABLE_WEBSOCKET_SUPPORT,
     WEBSOCKET_MANAGER,
     WEBSOCKET_REDIS_URL,
@@ -27,6 +32,11 @@ from open_webui.env import (
     WEBSOCKET_SENTINEL_PORT,
     WEBSOCKET_SENTINEL_HOSTS,
     REDIS_KEY_PREFIX,
+    WEBSOCKET_REDIS_OPTIONS,
+    WEBSOCKET_SERVER_PING_TIMEOUT,
+    WEBSOCKET_SERVER_PING_INTERVAL,
+    WEBSOCKET_SERVER_LOGGING,
+    WEBSOCKET_SERVER_ENGINEIO_LOGGING,
 )
 from open_webui.utils.auth import decode_token
 from open_webui.socket.utils import RedisDict, RedisLock, YdocManager
@@ -48,30 +58,44 @@ log.setLevel(SRC_LOG_LEVELS["SOCKET"])

 REDIS = None

+# Configure CORS for Socket.IO
+SOCKETIO_CORS_ORIGINS = "*" if CORS_ALLOW_ORIGIN == ["*"] else CORS_ALLOW_ORIGIN
+
 if WEBSOCKET_MANAGER == "redis":
     if WEBSOCKET_SENTINEL_HOSTS:
         mgr = socketio.AsyncRedisManager(
             get_sentinel_url_from_env(
                 WEBSOCKET_REDIS_URL, WEBSOCKET_SENTINEL_HOSTS, WEBSOCKET_SENTINEL_PORT
-            )
+            ),
+            redis_options=WEBSOCKET_REDIS_OPTIONS,
         )
     else:
-        mgr = socketio.AsyncRedisManager(WEBSOCKET_REDIS_URL)
+        mgr = socketio.AsyncRedisManager(
+            WEBSOCKET_REDIS_URL, redis_options=WEBSOCKET_REDIS_OPTIONS
+        )
     sio = socketio.AsyncServer(
-        cors_allowed_origins=[],
+        cors_allowed_origins=SOCKETIO_CORS_ORIGINS,
         async_mode="asgi",
         transports=(["websocket"] if ENABLE_WEBSOCKET_SUPPORT else ["polling"]),
         allow_upgrades=ENABLE_WEBSOCKET_SUPPORT,
         always_connect=True,
         client_manager=mgr,
+        logger=WEBSOCKET_SERVER_LOGGING,
+        ping_interval=WEBSOCKET_SERVER_PING_INTERVAL,
+        ping_timeout=WEBSOCKET_SERVER_PING_TIMEOUT,
+        engineio_logger=WEBSOCKET_SERVER_ENGINEIO_LOGGING,
     )
 else:
     sio = socketio.AsyncServer(
-        cors_allowed_origins=[],
+        cors_allowed_origins=SOCKETIO_CORS_ORIGINS,
         async_mode="asgi",
         transports=(["websocket"] if ENABLE_WEBSOCKET_SUPPORT else ["polling"]),
         allow_upgrades=ENABLE_WEBSOCKET_SUPPORT,
         always_connect=True,
+        logger=WEBSOCKET_SERVER_LOGGING,
+        ping_interval=WEBSOCKET_SERVER_PING_INTERVAL,
+        ping_timeout=WEBSOCKET_SERVER_PING_TIMEOUT,
+        engineio_logger=WEBSOCKET_SERVER_ENGINEIO_LOGGING,
     )


@@ -94,14 +118,16 @@ if WEBSOCKET_MANAGER == "redis":
     redis_sentinels = get_sentinels_from_env(
         WEBSOCKET_SENTINEL_HOSTS, WEBSOCKET_SENTINEL_PORT
     )
-    SESSION_POOL = RedisDict(
-        f"{REDIS_KEY_PREFIX}:session_pool",
+    MODELS = RedisDict(
+        f"{REDIS_KEY_PREFIX}:models",
         redis_url=WEBSOCKET_REDIS_URL,
         redis_sentinels=redis_sentinels,
         redis_cluster=WEBSOCKET_REDIS_CLUSTER,
     )
-    USER_POOL = RedisDict(
-        f"{REDIS_KEY_PREFIX}:user_pool",
+    SESSION_POOL = RedisDict(
+        f"{REDIS_KEY_PREFIX}:session_pool",
         redis_url=WEBSOCKET_REDIS_URL,
         redis_sentinels=redis_sentinels,
         redis_cluster=WEBSOCKET_REDIS_CLUSTER,
@@ -124,8 +150,9 @@ if WEBSOCKET_MANAGER == "redis":
     renew_func = clean_up_lock.renew_lock
     release_func = clean_up_lock.release_lock
 else:
+    MODELS = {}
+
     SESSION_POOL = {}
-    USER_POOL = {}
     USAGE_POOL = {}

     aquire_func = release_func = renew_func = lambda: True
@@ -201,16 +228,6 @@ def get_models_in_use():
     return models_in_use


-def get_active_user_ids():
-    """Get the list of active user IDs."""
-    return list(USER_POOL.keys())
-
-
-def get_user_active_status(user_id):
-    """Check if a user is currently active."""
-    return user_id in USER_POOL
-
-
 def get_user_id_from_session_pool(sid):
     user = SESSION_POOL.get(sid)
     if user:
@@ -236,10 +253,36 @@ def get_user_ids_from_room(room):
     return active_user_ids


-def get_active_status_by_user_id(user_id):
-    if user_id in USER_POOL:
-        return True
-    return False
+async def emit_to_users(event: str, data: dict, user_ids: list[str]):
+    """
+    Send a message to specific users using their user:{id} rooms.
+
+    Args:
+        event (str): The event name to emit.
+        data (dict): The payload/data to send.
+        user_ids (list[str]): The target users' IDs.
+    """
+    try:
+        for user_id in user_ids:
+            await sio.emit(event, data, room=f"user:{user_id}")
+    except Exception as e:
+        log.debug(f"Failed to emit event {event} to users {user_ids}: {e}")
+
+
+async def enter_room_for_users(room: str, user_ids: list[str]):
+    """
+    Make all sessions of a user join a specific room.
+    Args:
+        room (str): The room to join.
+        user_ids (list[str]): The target user's IDs.
+    """
+    try:
+        for user_id in user_ids:
+            session_ids = get_session_ids_from_room(f"user:{user_id}")
+            for sid in session_ids:
+                await sio.enter_room(sid, room)
+    except Exception as e:
+        log.debug(f"Failed to make users {user_ids} join room {room}: {e}")


 @sio.on("usage")
@@ -269,10 +312,7 @@ async def connect(sid, environ, auth):
         SESSION_POOL[sid] = user.model_dump(
             exclude=["date_of_birth", "bio", "gender"]
         )
-        if user.id in USER_POOL:
-            USER_POOL[user.id] = USER_POOL[user.id] + [sid]
-        else:
-            USER_POOL[user.id] = [sid]
+        await sio.enter_room(sid, f"user:{user.id}")


 @sio.on("user-join")
@@ -290,20 +330,34 @@ async def user_join(sid, data):
     if not user:
         return

-    SESSION_POOL[sid] = user.model_dump(exclude=["date_of_birth", "bio", "gender"])
-    if user.id in USER_POOL:
-        USER_POOL[user.id] = USER_POOL[user.id] + [sid]
-    else:
-        USER_POOL[user.id] = [sid]
+    SESSION_POOL[sid] = user.model_dump(
+        exclude=[
+            "profile_image_url",
+            "profile_banner_image_url",
+            "date_of_birth",
+            "bio",
+            "gender",
+        ]
+    )
+
+    await sio.enter_room(sid, f"user:{user.id}")

     # Join all the channels
     channels = Channels.get_channels_by_user_id(user.id)
     log.debug(f"{channels=}")
     for channel in channels:
         await sio.enter_room(sid, f"channel:{channel.id}")

     return {"id": user.id, "name": user.name}


+@sio.on("heartbeat")
+async def heartbeat(sid, data):
+    user = SESSION_POOL.get(sid)
+    if user:
+        Users.update_last_active_by_id(user["id"])
+
+
 @sio.on("join-channels")
 async def join_channel(sid, data):
     auth = data["auth"] if "auth" in data else None
@@ -356,7 +410,7 @@ async def join_note(sid, data):
     await sio.enter_room(sid, f"note:{note.id}")


-@sio.on("channel-events")
+@sio.on("events:channel")
 async def channel_events(sid, data):
     room = f"channel:{data['channel_id']}"
     participants = sio.manager.get_participants(
@@ -371,17 +425,24 @@ async def channel_events(sid, data):
     event_data = data["data"]
     event_type = event_data["type"]

+    user = SESSION_POOL.get(sid)
+
+    if not user:
+        return
+
     if event_type == "typing":
         await sio.emit(
-            "channel-events",
+            "events:channel",
             {
                 "channel_id": data["channel_id"],
                 "message_id": data.get("message_id", None),
                 "data": event_data,
-                "user": UserNameResponse(**SESSION_POOL[sid]).model_dump(),
+                "user": UserNameResponse(**user).model_dump(),
             },
             room=room,
         )
+    elif event_type == "last_read_at":
+        Channels.update_member_last_read_at(data["channel_id"], user["id"])


 @sio.on("ydoc:document:join")
@@ -625,13 +686,6 @@ async def disconnect(sid):
     if sid in SESSION_POOL:
         user = SESSION_POOL[sid]
         del SESSION_POOL[sid]
-
-        user_id = user["id"]
-        USER_POOL[user_id] = [_sid for _sid in USER_POOL[user_id] if _sid != sid]
-
-        if len(USER_POOL[user_id]) == 0:
-            del USER_POOL[user_id]
-
         await YDOC_MANAGER.remove_user_from_all_documents(sid)
     else:
         pass
@@ -641,34 +695,24 @@ async def disconnect(sid):
 def get_event_emitter(request_info, update_db=True):
     async def __event_emitter__(event_data):
         user_id = request_info["user_id"]
+        chat_id = request_info["chat_id"]
+        message_id = request_info["message_id"]
+
-        session_ids = list(
-            set(
-                USER_POOL.get(user_id, [])
-                + (
-                    [request_info.get("session_id")]
-                    if request_info.get("session_id")
-                    else []
-                )
-            )
-        )
-
-        emit_tasks = [
-            sio.emit(
-                "chat-events",
+        await sio.emit(
+            "events",
             {
-                "chat_id": request_info.get("chat_id", None),
-                "message_id": request_info.get("message_id", None),
+                "chat_id": chat_id,
+                "message_id": message_id,
                 "data": event_data,
             },
-            to=session_id,
+            room=f"user:{user_id}",
         )
-            for session_id in session_ids
-        ]
-
-        await asyncio.gather(*emit_tasks)
-
-        if update_db:
+        if (
+            update_db
+            and message_id
+            and not request_info.get("chat_id", "").startswith("local:")
+        ):
             if "type" in event_data and event_data["type"] == "status":
                 Chats.add_message_status_to_chat_by_id_and_message_id(
                     request_info["chat_id"],
@@ -758,13 +802,20 @@ def get_event_emitter(request_info, update_db=True):
                 },
             )

+    if (
+        "user_id" in request_info
+        and "chat_id" in request_info
+        and "message_id" in request_info
+    ):
         return __event_emitter__
+    else:
+        return None


 def get_event_call(request_info):
     async def __event_caller__(event_data):
         response = await sio.call(
-            "chat-events",
+            "events",
             {
                 "chat_id": request_info.get("chat_id", None),
                 "message_id": request_info.get("message_id", None),
@@ -774,7 +825,14 @@ def get_event_call(request_info):
             )
         return response

+    if (
+        "session_id" in request_info
+        and "chat_id" in request_info
+        and "message_id" in request_info
+    ):
         return __event_caller__
+    else:
+        return None


 get_event_caller = get_event_call

|
|
@@ -86,6 +86,15 @@ class RedisDict:
     def items(self):
         return [(k, json.loads(v)) for k, v in self.redis.hgetall(self.name).items()]

+    def set(self, mapping: dict):
+        pipe = self.redis.pipeline()
+
+        pipe.delete(self.name)
+        if mapping:
+            pipe.hset(self.name, mapping={k: json.dumps(v) for k, v in mapping.items()})
+
+        pipe.execute()
+
     def get(self, key, default=None):
         try:
             return self[key]
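The new RedisDict.set replaces the whole hash in one shot. A minimal sketch of the same pipeline pattern with redis-py:

import json
import redis

r = redis.Redis()

# Queue DELETE + HSET and send them in one round trip. A redis-py pipeline
# is wrapped in MULTI/EXEC by default, so readers never observe the hash
# half-written between the two commands.
pipe = r.pipeline()
pipe.delete("mydict")
pipe.hset("mydict", mapping={k: json.dumps(v) for k, v in {"a": 1}.items()})
pipe.execute()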
@@ -164,7 +164,10 @@ async def stop_task(redis, task_id: str):
         # Task successfully canceled
         return {"status": True, "message": f"Task {task_id} successfully stopped."}

-    return {"status": False, "message": f"Failed to stop task {task_id}."}
+    if task.cancelled() or task.done():
+        return {"status": True, "message": f"Task {task_id} successfully cancelled."}
+
+    return {"status": True, "message": f"Cancellation requested for {task_id}."}


 async def stop_item_tasks(redis: Redis, item_id: str):
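The distinction between the two new messages follows asyncio's cancellation semantics: Task.cancel() only requests cancellation, and the task actually finishes as cancelled on the next await. A self-contained illustration:

import asyncio

async def worker():
    await asyncio.sleep(3600)

async def main():
    task = asyncio.create_task(worker())
    await asyncio.sleep(0)      # let the task start
    task.cancel()               # only *requests* cancellation...
    try:
        await task              # ...it completes as cancelled once awaited
    except asyncio.CancelledError:
        pass
    print(task.cancelled(), task.done())  # True True

asyncio.run(main())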
@@ -105,6 +105,22 @@ def has_permission(
     return get_permission(default_permissions, permission_hierarchy)


+def get_permitted_group_and_user_ids(
+    type: str = "write", access_control: Optional[dict] = None
+) -> Union[Dict[str, List[str]], None]:
+    if access_control is None:
+        return None
+
+    permission_access = access_control.get(type, {})
+    permitted_group_ids = permission_access.get("group_ids", [])
+    permitted_user_ids = permission_access.get("user_ids", [])
+
+    return {
+        "group_ids": permitted_group_ids,
+        "user_ids": permitted_user_ids,
+    }
+
+
 def has_access(
     user_id: str,
     type: str = "write",

@@ -122,9 +138,12 @@ def has_access(
     user_groups = Groups.get_groups_by_member_id(user_id)
     user_group_ids = {group.id for group in user_groups}

-    permission_access = access_control.get(type, {})
-    permitted_group_ids = permission_access.get("group_ids", [])
-    permitted_user_ids = permission_access.get("user_ids", [])
+    permitted_ids = get_permitted_group_and_user_ids(type, access_control)
+    if permitted_ids is None:
+        return False
+
+    permitted_group_ids = permitted_ids.get("group_ids", [])
+    permitted_user_ids = permitted_ids.get("user_ids", [])

     return user_id in permitted_user_ids or any(
         group_id in permitted_group_ids for group_id in user_group_ids

@@ -136,18 +155,20 @@ def get_users_with_access(
     type: str = "write", access_control: Optional[dict] = None
 ) -> list[UserModel]:
     if access_control is None:
-        result = Users.get_users()
+        result = Users.get_users(filter={"roles": ["!pending"]})
         return result.get("users", [])

-    permission_access = access_control.get(type, {})
-    permitted_group_ids = permission_access.get("group_ids", [])
-    permitted_user_ids = permission_access.get("user_ids", [])
+    permitted_ids = get_permitted_group_and_user_ids(type, access_control)
+    if permitted_ids is None:
+        return []
+
+    permitted_group_ids = permitted_ids.get("group_ids", [])
+    permitted_user_ids = permitted_ids.get("user_ids", [])

     user_ids_with_access = set(permitted_user_ids)

-    for group_id in permitted_group_ids:
-        group_user_ids = Groups.get_group_user_ids_by_id(group_id)
-        if group_user_ids:
-            user_ids_with_access.update(group_user_ids)
+    group_user_ids_map = Groups.get_group_user_ids_by_ids(permitted_group_ids)
+    for user_ids in group_user_ids_map.values():
+        user_ids_with_access.update(user_ids)

     return Users.get_users_by_user_ids(list(user_ids_with_access))
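An illustrative example of the access_control payload shape these helpers consume, inferred from the extraction logic above:

access_control = {
    "read":  {"group_ids": ["g1"], "user_ids": []},
    "write": {"group_ids": [],     "user_ids": ["u42"]},
}

permitted = get_permitted_group_and_user_ids("write", access_control)
# -> {"group_ids": [], "user_ids": ["u42"]}
# has_access("u42", "write", access_control) -> True
# access_control=None yields None from the helper; has_access treats that
# as "no grant recorded" and returns False.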
@@ -194,7 +194,7 @@ class AuditLoggingMiddleware:
         auth_header = request.headers.get("Authorization")

         try:
-            user = get_current_user(
+            user = await get_current_user(
                 request, None, None, get_http_authorization_cred(auth_header)
             )
             return user
@@ -6,7 +6,7 @@ import hmac
 import hashlib
 import requests
 import os
+import bcrypt

 from cryptography.hazmat.primitives.ciphers.aead import AESGCM
 from cryptography.hazmat.primitives.asymmetric import ed25519

@@ -21,13 +21,18 @@ from typing import Optional, Union, List, Dict

 from opentelemetry import trace


+from open_webui.utils.access_control import has_permission
 from open_webui.models.users import Users

 from open_webui.constants import ERROR_MESSAGES

 from open_webui.env import (
+    ENABLE_PASSWORD_VALIDATION,
     OFFLINE_MODE,
     LICENSE_BLOB,
+    PASSWORD_VALIDATION_REGEX_PATTERN,
+    REDIS_KEY_PREFIX,
     pk,
     WEBUI_SECRET_KEY,
     TRUSTED_SIGNATURE_KEY,

@@ -38,11 +43,8 @@ from open_webui.env import (

 from fastapi import BackgroundTasks, Depends, HTTPException, Request, Response, status
 from fastapi.security import HTTPAuthorizationCredentials, HTTPBearer
-from passlib.context import CryptContext
-
-
-logging.getLogger("passlib").setLevel(logging.ERROR)

 log = logging.getLogger(__name__)
 log.setLevel(SRC_LOG_LEVELS["OAUTH"])

@@ -155,17 +157,37 @@ def get_license_data(app, key):


 bearer_security = HTTPBearer(auto_error=False)
-pwd_context = CryptContext(schemes=["bcrypt"], deprecated="auto")


-def verify_password(plain_password, hashed_password):
-    return (
-        pwd_context.verify(plain_password, hashed_password) if hashed_password else None
-    )
+def get_password_hash(password: str) -> str:
+    """Hash a password using bcrypt"""
+    return bcrypt.hashpw(password.encode("utf-8"), bcrypt.gensalt()).decode("utf-8")
+
+
+def validate_password(password: str) -> bool:
+    # The password passed to bcrypt must be 72 bytes or fewer. If it is longer, it will be truncated before hashing.
+    if len(password.encode("utf-8")) > 72:
+        raise Exception(
+            ERROR_MESSAGES.PASSWORD_TOO_LONG,
+        )
+
+    if ENABLE_PASSWORD_VALIDATION:
+        if not PASSWORD_VALIDATION_REGEX_PATTERN.match(password):
+            raise Exception(ERROR_MESSAGES.INVALID_PASSWORD())

-def get_password_hash(password):
-    return pwd_context.hash(password)
+    return True
+
+
+def verify_password(plain_password: str, hashed_password: str) -> bool:
+    """Verify a password against its hash"""
+    return (
+        bcrypt.checkpw(
+            plain_password.encode("utf-8"),
+            hashed_password.encode("utf-8"),
+        )
+        if hashed_password
+        else None
+    )


 def create_token(data: dict, expires_delta: Union[timedelta, None] = None) -> str:
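The change drops passlib's CryptContext and calls the bcrypt package directly. A minimal round trip showing why no separate salt storage is needed:

import bcrypt

# bcrypt embeds the salt and cost factor inside the hash string itself,
# so verification only needs the stored hash.
hashed = bcrypt.hashpw(b"s3cret", bcrypt.gensalt()).decode("utf-8")

assert bcrypt.checkpw(b"s3cret", hashed.encode("utf-8"))
assert not bcrypt.checkpw(b"wrong", hashed.encode("utf-8"))

# Only the first 72 bytes of a password are significant to bcrypt (depending
# on the library version, longer inputs are truncated or rejected), which is
# why validate_password above rejects longer inputs outright.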
@@ -175,6 +197,9 @@ def create_token(data: dict, expires_delta: Union[timedelta, None] = None) -> str:
         expire = datetime.now(UTC) + expires_delta
         payload.update({"exp": expire})

+    jti = str(uuid.uuid4())
+    payload.update({"jti": jti})
+
     encoded_jwt = jwt.encode(payload, SESSION_SECRET, algorithm=ALGORITHM)
     return encoded_jwt

@@ -187,6 +212,43 @@ def decode_token(token: str) -> Optional[dict]:
         return None


+async def is_valid_token(request, decoded) -> bool:
+    # Require Redis to check revoked tokens
+    if request.app.state.redis:
+        jti = decoded.get("jti")
+
+        if jti:
+            revoked = await request.app.state.redis.get(
+                f"{REDIS_KEY_PREFIX}:auth:token:{jti}:revoked"
+            )
+            if revoked:
+                return False
+
+    return True
+
+
+async def invalidate_token(request, token):
+    decoded = decode_token(token)
+
+    # Require Redis to store revoked tokens
+    if request.app.state.redis:
+        jti = decoded.get("jti")
+        exp = decoded.get("exp")
+
+        if jti and exp:
+            ttl = exp - int(
+                datetime.now(UTC).timestamp()
+            )  # Calculate time-to-live for the token
+
+            if ttl > 0:
+                # Store the revoked token in Redis with an expiration time
+                await request.app.state.redis.set(
+                    f"{REDIS_KEY_PREFIX}:auth:token:{jti}:revoked",
+                    "1",
+                    ex=ttl,
+                )
+
+
 def extract_token_from_auth_header(auth_header: str):
     return auth_header[len("Bearer ") :]
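The jti claim plus the Redis denylist gives revocable stateless tokens: the denylist entry expires exactly when the token would anyway, so it never grows stale. A self-contained sketch of the same scheme, assuming PyJWT and redis-py's asyncio client (the secret and key naming here are illustrative):

import time
import uuid

import jwt                      # PyJWT
import redis.asyncio as redis_async

SECRET = "change-me"            # illustrative secret
r = redis_async.Redis()

def issue(sub: str, ttl: int = 3600) -> str:
    # A unique jti per token lets one token be revoked without touching others.
    payload = {"sub": sub, "exp": int(time.time()) + ttl, "jti": str(uuid.uuid4())}
    return jwt.encode(payload, SECRET, algorithm="HS256")

async def revoke(token: str) -> None:
    decoded = jwt.decode(token, SECRET, algorithms=["HS256"])
    ttl = decoded["exp"] - int(time.time())
    if ttl > 0:
        # Denylist entry lives exactly as long as the token itself.
        await r.set(f"auth:token:{decoded['jti']}:revoked", "1", ex=ttl)

async def is_revoked(decoded: dict) -> bool:
    return await r.get(f"auth:token:{decoded['jti']}:revoked") is not None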
@@ -206,7 +268,7 @@ def get_http_authorization_cred(auth_header: Optional[str]):
     return None


-def get_current_user(
+async def get_current_user(
     request: Request,
     response: Response,
     background_tasks: BackgroundTasks,

@@ -225,30 +287,7 @@ def get_current_user(

     # auth by api key
     if token.startswith("sk-"):
-        if not request.state.enable_api_key:
-            raise HTTPException(
-                status.HTTP_403_FORBIDDEN, detail=ERROR_MESSAGES.API_KEY_NOT_ALLOWED
-            )
-
-        if request.app.state.config.ENABLE_API_KEY_ENDPOINT_RESTRICTIONS:
-            allowed_paths = [
-                path.strip()
-                for path in str(
-                    request.app.state.config.API_KEY_ALLOWED_ENDPOINTS
-                ).split(",")
-            ]
-
-            # Check if the request path matches any allowed endpoint.
-            if not any(
-                request.url.path == allowed
-                or request.url.path.startswith(allowed + "/")
-                for allowed in allowed_paths
-            ):
-                raise HTTPException(
-                    status.HTTP_403_FORBIDDEN, detail=ERROR_MESSAGES.API_KEY_NOT_ALLOWED
-                )
-
-        user = get_current_user_by_api_key(token)
+        user = get_current_user_by_api_key(request, token)

         # Add user info to current span
         current_span = trace.get_current_span()

@@ -261,7 +300,6 @@ def get_current_user(
         return user

     # auth by jwt token
-
     try:
         try:
             data = decode_token(token)

@@ -272,6 +310,12 @@ def get_current_user(
             )

         if data is not None and "id" in data:
+            if data.get("jti") and not await is_valid_token(request, data):
+                raise HTTPException(
+                    status_code=status.HTTP_401_UNAUTHORIZED,
+                    detail="Invalid token",
+                )
+
             user = Users.get_user_by_id(data["id"])
             if user is None:
                 raise HTTPException(

@@ -300,9 +344,7 @@ def get_current_user(
             # Refresh the user's last active timestamp asynchronously
             # to prevent blocking the request
             if background_tasks:
-                background_tasks.add_task(
-                    Users.update_user_last_active_by_id, user.id
-                )
+                background_tasks.add_task(Users.update_last_active_by_id, user.id)
             return user
         else:
             raise HTTPException(

@@ -324,7 +366,7 @@ def get_current_user(
         raise e


-def get_current_user_by_api_key(api_key: str):
+def get_current_user_by_api_key(request, api_key: str):
     user = Users.get_user_by_api_key(api_key)

     if user is None:

@@ -332,7 +374,19 @@ def get_current_user_by_api_key(api_key: str):
             status_code=status.HTTP_401_UNAUTHORIZED,
             detail=ERROR_MESSAGES.INVALID_TOKEN,
         )
-    else:
+
+    if not request.state.enable_api_keys or (
+        user.role != "admin"
+        and not has_permission(
+            user.id,
+            "features.api_keys",
+            request.app.state.config.USER_PERMISSIONS,
+        )
+    ):
+        raise HTTPException(
+            status.HTTP_403_FORBIDDEN, detail=ERROR_MESSAGES.API_KEY_NOT_ALLOWED
+        )

     # Add user info to current span
     current_span = trace.get_current_span()
     if current_span:

@@ -341,8 +395,7 @@ def get_current_user_by_api_key(api_key: str):
         current_span.set_attribute("client.user.role", user.role)
         current_span.set_attribute("client.auth.type", "api_key")

-    Users.update_user_last_active_by_id(user.id)
+    Users.update_last_active_by_id(user.id)

     return user
@@ -80,6 +80,7 @@ async def generate_direct_chat_completion(
     event_caller = get_event_call(metadata)

     channel = f"{user_id}:{session_id}:{request_id}"
+    logging.info(f"WebSocket channel: {channel}")

     if form_data.get("stream"):
         q = asyncio.Queue()

@@ -121,7 +122,10 @@ async def generate_direct_chat_completion(

                     yield f"data: {json.dumps(data)}\n\n"
                 elif isinstance(data, str):
-                    yield data
+                    if "data:" in data:
+                        yield f"{data}\n\n"
+                    else:
+                        yield f"data: {data}\n\n"
         except Exception as e:
             log.debug(f"Error in event generator: {e}")
             pass
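The string branch now normalizes chunks into Server-Sent Events framing: every event must be a "data:" line terminated by a blank line. A minimal sketch of the rule the change enforces:

# SSE framing: wrap bare payloads in a data: field, and make sure every
# frame (already-prefixed or not) ends with the required blank line.
def to_sse_frame(chunk: str) -> str:
    if "data:" in chunk:              # already framed upstream; just terminate
        return f"{chunk}\n\n"
    return f"data: {chunk}\n\n"

assert to_sse_frame("hello") == "data: hello\n\n"
assert to_sse_frame("data: [DONE]") == "data: [DONE]\n\n"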
@@ -1,5 +1,5 @@
 from open_webui.routers.images import (
-    load_b64_image_data,
+    get_image_data,
     upload_image,
 )

@@ -16,13 +16,18 @@ from open_webui.routers.files import upload_file_handler
 import mimetypes
 import base64
 import io
+import re
+
+
+BASE64_IMAGE_URL_PREFIX = re.compile(r"data:image/\w+;base64,", re.IGNORECASE)
+MARKDOWN_IMAGE_URL_PATTERN = re.compile(r"!\[(.*?)\]\((.+?)\)", re.IGNORECASE)


 def get_image_url_from_base64(request, base64_image_string, metadata, user):
-    if "data:image/png;base64" in base64_image_string:
+    if BASE64_IMAGE_URL_PREFIX.match(base64_image_string):
         image_url = ""
         # Extract base64 image data from the line
-        image_data, content_type = load_b64_image_data(base64_image_string)
+        image_data, content_type = get_image_data(base64_image_string)
         if image_data is not None:
             image_url = upload_image(
                 request,

@@ -35,6 +40,19 @@ def get_image_url_from_base64(request, base64_image_string, metadata, user):
     return None


+def convert_markdown_base64_images(request, content: str, metadata, user):
+    def replace(match):
+        base64_string = match.group(2)
+        MIN_REPLACEMENT_URL_LENGTH = 1024
+        if len(base64_string) > MIN_REPLACEMENT_URL_LENGTH:
+            url = get_image_url_from_base64(request, base64_string, metadata, user)
+            if url:
+                return f"![{match.group(1)}]({url})"
+        return match.group(0)
+
+    return MARKDOWN_IMAGE_URL_PATTERN.sub(replace, content)
+
+
 def load_b64_audio_data(b64_str):
     try:
         if "," in b64_str:
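A standalone demonstration of the substitution above, with the upload step stubbed out (the replacement URL here is a stand-in for the real hosted path):

import re

MARKDOWN_IMAGE_URL_PATTERN = re.compile(r"!\[(.*?)\]\((.+?)\)", re.IGNORECASE)

def convert(content: str) -> str:
    def replace(match):
        url = match.group(2)
        if url.startswith("data:image/") and len(url) > 1024:
            url = "/cache/image/abc123.png"   # stand-in for the real upload
        return f"![{match.group(1)}]({url})"
    return MARKDOWN_IMAGE_URL_PATTERN.sub(replace, content)

print(convert("before ![alt](data:image/png;base64," + "A" * 2000 + ") after"))
# -> before ![alt](/cache/image/abc123.png) after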
backend/open_webui/utils/groups.py (new file)

@@ -0,0 +1,24 @@
+import logging
+from open_webui.models.groups import Groups
+
+log = logging.getLogger(__name__)
+
+
+def apply_default_group_assignment(
+    default_group_id: str,
+    user_id: str,
+) -> None:
+    """
+    Apply default group assignment to a user if default_group_id is provided.
+
+    Args:
+        default_group_id: ID of the default group to add the user to
+        user_id: ID of the user to add to the default group
+    """
+    if default_group_id:
+        try:
+            Groups.add_users_to_group(default_group_id, [user_id])
+        except Exception as e:
+            log.error(
+                f"Failed to add user {user_id} to default group {default_group_id}: {e}"
+            )
backend/open_webui/utils/headers.py (new file)

@@ -0,0 +1,11 @@
+from urllib.parse import quote
+
+
+def include_user_info_headers(headers, user):
+    return {
+        **headers,
+        "X-OpenWebUI-User-Name": quote(user.name, safe=" "),
+        "X-OpenWebUI-User-Id": user.id,
+        "X-OpenWebUI-User-Email": user.email,
+        "X-OpenWebUI-User-Role": user.role,
+    }
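The quote() call matters because HTTP header values must stay ASCII-safe. A small usage sketch with a hypothetical stand-in for the user model:

from urllib.parse import quote

class User:  # hypothetical stand-in for the real user model
    name, id = "Jörg Müller", "u1"

headers = {"Content-Type": "application/json"}
headers = {
    **headers,
    # quote() percent-encodes non-ASCII characters so the value is a legal
    # header; safe=" " keeps spaces readable.
    "X-OpenWebUI-User-Name": quote(User.name, safe=" "),
    "X-OpenWebUI-User-Id": User.id,
}
print(headers["X-OpenWebUI-User-Name"])  # J%C3%B6rg M%C3%BCller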
@@ -2,6 +2,8 @@ import asyncio
 import json
 import logging
 import random
+import requests
+import aiohttp
 import urllib.parse
 import urllib.request
 from typing import Optional

@@ -91,6 +93,25 @@ def get_images(ws, prompt, client_id, base_url, api_key):
     return {"data": output_images}


+async def comfyui_upload_image(image_file_item, base_url, api_key):
+    url = f"{base_url}/api/upload/image"
+    headers = {}
+
+    if api_key:
+        headers["Authorization"] = f"Bearer {api_key}"
+
+    _, (filename, file_bytes, mime_type) = image_file_item
+
+    form = aiohttp.FormData()
+    form.add_field("image", file_bytes, filename=filename, content_type=mime_type)
+    form.add_field("type", "input")  # required by ComfyUI
+
+    async with aiohttp.ClientSession() as session:
+        async with session.post(url, data=form, headers=headers) as resp:
+            resp.raise_for_status()
+            return await resp.json()
+
+
 class ComfyUINodeInput(BaseModel):
     type: Optional[str] = None
     node_ids: list[str] = []
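A runnable sketch of the same multipart upload against a local ComfyUI instance; the endpoint shape and response fields are assumed from comfyui_upload_image above:

import asyncio
import aiohttp

async def upload_png(path: str, base_url: str) -> dict:
    form = aiohttp.FormData()
    with open(path, "rb") as f:
        # ComfyUI expects the file under the "image" field plus a "type" field.
        form.add_field("image", f.read(), filename="input.png",
                       content_type="image/png")
    form.add_field("type", "input")
    async with aiohttp.ClientSession() as session:
        async with session.post(f"{base_url}/api/upload/image", data=form) as resp:
            resp.raise_for_status()
            return await resp.json()  # typically echoes name/subfolder/type

# asyncio.run(upload_png("input.png", "http://127.0.0.1:8188"))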
@@ -103,7 +124,7 @@ class ComfyUIWorkflow(BaseModel):
     nodes: list[ComfyUINodeInput]


-class ComfyUIGenerateImageForm(BaseModel):
+class ComfyUICreateImageForm(BaseModel):
     workflow: ComfyUIWorkflow

     prompt: str

@@ -116,8 +137,8 @@ class ComfyUIGenerateImageForm(BaseModel):
     seed: Optional[int] = None


-async def comfyui_generate_image(
-    model: str, payload: ComfyUIGenerateImageForm, client_id, base_url, api_key
+async def comfyui_create_image(
+    model: str, payload: ComfyUICreateImageForm, client_id, base_url, api_key
 ):
     ws_url = base_url.replace("http://", "ws://").replace("https://", "wss://")
     workflow = json.loads(payload.workflow.workflow)

@@ -191,3 +212,102 @@ async def comfyui_create_image(
     ws.close()

     return images
+
+
+class ComfyUIEditImageForm(BaseModel):
+    workflow: ComfyUIWorkflow
+
+    image: str | list[str]
+    prompt: str
+    width: Optional[int] = None
+    height: Optional[int] = None
+    n: Optional[int] = None
+
+    steps: Optional[int] = None
+    seed: Optional[int] = None
+
+
+async def comfyui_edit_image(
+    model: str, payload: ComfyUIEditImageForm, client_id, base_url, api_key
+):
+    ws_url = base_url.replace("http://", "ws://").replace("https://", "wss://")
+    workflow = json.loads(payload.workflow.workflow)
+
+    for node in payload.workflow.nodes:
+        if node.type:
+            if node.type == "model":
+                for node_id in node.node_ids:
+                    workflow[node_id]["inputs"][node.key] = model
+            elif node.type == "image":
+                if isinstance(payload.image, list):
+                    # check if multiple images are provided
+                    for idx, node_id in enumerate(node.node_ids):
+                        if idx < len(payload.image):
+                            workflow[node_id]["inputs"][node.key] = payload.image[idx]
+                else:
+                    for node_id in node.node_ids:
+                        workflow[node_id]["inputs"][node.key] = payload.image
+            elif node.type == "prompt":
+                for node_id in node.node_ids:
+                    workflow[node_id]["inputs"][
+                        node.key if node.key else "text"
+                    ] = payload.prompt
+            elif node.type == "negative_prompt":
+                for node_id in node.node_ids:
+                    workflow[node_id]["inputs"][
+                        node.key if node.key else "text"
+                    ] = payload.negative_prompt
+            elif node.type == "width":
+                for node_id in node.node_ids:
+                    workflow[node_id]["inputs"][
+                        node.key if node.key else "width"
+                    ] = payload.width
+            elif node.type == "height":
+                for node_id in node.node_ids:
+                    workflow[node_id]["inputs"][
+                        node.key if node.key else "height"
+                    ] = payload.height
+            elif node.type == "n":
+                for node_id in node.node_ids:
+                    workflow[node_id]["inputs"][
+                        node.key if node.key else "batch_size"
+                    ] = payload.n
+            elif node.type == "steps":
+                for node_id in node.node_ids:
+                    workflow[node_id]["inputs"][
+                        node.key if node.key else "steps"
+                    ] = payload.steps
+            elif node.type == "seed":
+                seed = (
+                    payload.seed
+                    if payload.seed
+                    else random.randint(0, 1125899906842624)
+                )
+                for node_id in node.node_ids:
+                    workflow[node_id]["inputs"][node.key] = seed
+        else:
+            for node_id in node.node_ids:
+                workflow[node_id]["inputs"][node.key] = node.value
+
+    try:
+        ws = websocket.WebSocket()
+        headers = {"Authorization": f"Bearer {api_key}"}
+        ws.connect(f"{ws_url}/ws?clientId={client_id}", header=headers)
+        log.info("WebSocket connection established.")
+    except Exception as e:
+        log.exception(f"Failed to connect to WebSocket server: {e}")
+        return None
+
+    try:
+        log.info("Sending workflow to WebSocket server.")
+        log.info(f"Workflow: {workflow}")
+        images = await asyncio.to_thread(
+            get_images, ws, workflow, client_id, base_url, api_key
+        )
+    except Exception as e:
+        log.exception(f"Error while receiving images: {e}")
+        images = None
+
+    ws.close()
+
+    return images
@@ -2,6 +2,8 @@ import asyncio
 from typing import Optional
 from contextlib import AsyncExitStack

+import anyio
+
 from mcp import ClientSession
 from mcp.client.auth import OAuthClientProvider, TokenStorage
 from mcp.client.streamable_http import streamablehttp_client

@@ -11,25 +13,28 @@ from mcp.shared.auth import OAuthClientInformationFull, OAuthClientMetadata, OAuthToken
 class MCPClient:
     def __init__(self):
         self.session: Optional[ClientSession] = None
-        self.exit_stack = AsyncExitStack()
+        self.exit_stack = None

     async def connect(self, url: str, headers: Optional[dict] = None):
-        try:
-            self._streams_context = streamablehttp_client(url, headers=headers)
-
-            transport = await self.exit_stack.enter_async_context(self._streams_context)
-            read_stream, write_stream, _ = transport
-
-            self._session_context = ClientSession(
-                read_stream, write_stream
-            )  # pylint: disable=W0201
-
-            self.session = await self.exit_stack.enter_async_context(
-                self._session_context
-            )
-            await self.session.initialize()
-        except Exception as e:
-            await self.disconnect()
-            raise e
+        async with AsyncExitStack() as exit_stack:
+            try:
+                self._streams_context = streamablehttp_client(url, headers=headers)
+
+                transport = await exit_stack.enter_async_context(self._streams_context)
+                read_stream, write_stream, _ = transport
+
+                self._session_context = ClientSession(
+                    read_stream, write_stream
+                )  # pylint: disable=W0201
+
+                self.session = await exit_stack.enter_async_context(
+                    self._session_context
+                )
+                with anyio.fail_after(10):
+                    await self.session.initialize()
+                self.exit_stack = exit_stack.pop_all()
+            except Exception as e:
+                await asyncio.shield(self.disconnect())
+                raise e

     async def list_tool_specs(self) -> Optional[dict]:
|
||||||
from starlette.responses import Response, StreamingResponse, JSONResponse
|
from starlette.responses import Response, StreamingResponse, JSONResponse
|
||||||
|
|
||||||
|
|
||||||
|
from open_webui.utils.misc import is_string_allowed
|
||||||
from open_webui.models.oauth_sessions import OAuthSessions
|
from open_webui.models.oauth_sessions import OAuthSessions
|
||||||
from open_webui.models.chats import Chats
|
from open_webui.models.chats import Chats
|
||||||
from open_webui.models.folders import Folders
|
from open_webui.models.folders import Folders
|
||||||
|
|
@ -31,7 +32,6 @@ from open_webui.models.users import Users
|
||||||
from open_webui.socket.main import (
|
from open_webui.socket.main import (
|
||||||
get_event_call,
|
get_event_call,
|
||||||
get_event_emitter,
|
get_event_emitter,
|
||||||
get_active_status_by_user_id,
|
|
||||||
)
|
)
|
||||||
from open_webui.routers.tasks import (
|
from open_webui.routers.tasks import (
|
||||||
generate_queries,
|
generate_queries,
|
||||||
|
|
@ -40,12 +40,15 @@ from open_webui.routers.tasks import (
|
||||||
generate_image_prompt,
|
generate_image_prompt,
|
||||||
generate_chat_tags,
|
generate_chat_tags,
|
||||||
)
|
)
|
||||||
from open_webui.routers.retrieval import process_web_search, SearchForm
|
from open_webui.routers.retrieval import (
|
||||||
|
process_web_search,
|
||||||
|
SearchForm,
|
||||||
|
)
|
||||||
from open_webui.routers.images import (
|
from open_webui.routers.images import (
|
||||||
load_b64_image_data,
|
|
||||||
image_generations,
|
image_generations,
|
||||||
GenerateImageForm,
|
CreateImageForm,
|
||||||
upload_image,
|
image_edits,
|
||||||
|
EditImageForm,
|
||||||
)
|
)
|
||||||
from open_webui.routers.pipelines import (
|
from open_webui.routers.pipelines import (
|
||||||
process_pipeline_inlet_filter,
|
process_pipeline_inlet_filter,
|
||||||
|
|
@ -55,7 +58,7 @@ from open_webui.routers.memories import query_memory, QueryMemoryForm
|
||||||
|
|
||||||
from open_webui.utils.webhook import post_webhook
|
from open_webui.utils.webhook import post_webhook
|
||||||
from open_webui.utils.files import (
|
from open_webui.utils.files import (
|
||||||
get_audio_url_from_base64,
|
convert_markdown_base64_images,
|
||||||
get_file_url_from_base64,
|
get_file_url_from_base64,
|
||||||
get_image_url_from_base64,
|
get_image_url_from_base64,
|
||||||
)
|
)
|
||||||
|
|
@ -76,16 +79,19 @@ from open_webui.utils.task import (
|
||||||
)
|
)
|
||||||
from open_webui.utils.misc import (
|
from open_webui.utils.misc import (
|
||||||
deep_update,
|
deep_update,
|
||||||
|
extract_urls,
|
||||||
get_message_list,
|
get_message_list,
|
||||||
add_or_update_system_message,
|
add_or_update_system_message,
|
||||||
add_or_update_user_message,
|
add_or_update_user_message,
|
||||||
get_last_user_message,
|
get_last_user_message,
|
||||||
|
get_last_user_message_item,
|
||||||
get_last_assistant_message,
|
get_last_assistant_message,
|
||||||
get_system_message,
|
get_system_message,
|
||||||
prepend_to_first_user_message_content,
|
prepend_to_first_user_message_content,
|
||||||
convert_logit_bias_input_to_json,
|
convert_logit_bias_input_to_json,
|
||||||
|
get_content_from_message,
|
||||||
)
|
)
|
||||||
from open_webui.utils.tools import get_tools
|
from open_webui.utils.tools import get_tools, get_updated_tool_function
|
||||||
from open_webui.utils.plugin import load_function_module_by_id
|
from open_webui.utils.plugin import load_function_module_by_id
|
||||||
from open_webui.utils.filter import (
|
from open_webui.utils.filter import (
|
||||||
get_sorted_filter_ids,
|
get_sorted_filter_ids,
|
||||||
|
|
@ -98,6 +104,7 @@ from open_webui.utils.mcp.client import MCPClient
|
||||||
|
|
||||||
from open_webui.config import (
|
from open_webui.config import (
|
||||||
CACHE_DIR,
|
CACHE_DIR,
|
||||||
|
DEFAULT_VOICE_MODE_PROMPT_TEMPLATE,
|
||||||
DEFAULT_TOOLS_FUNCTION_CALLING_PROMPT_TEMPLATE,
|
DEFAULT_TOOLS_FUNCTION_CALLING_PROMPT_TEMPLATE,
|
||||||
DEFAULT_CODE_INTERPRETER_PROMPT,
|
DEFAULT_CODE_INTERPRETER_PROMPT,
|
||||||
CODE_INTERPRETER_BLOCKED_MODULES,
|
CODE_INTERPRETER_BLOCKED_MODULES,
|
||||||
|
|
@ -105,6 +112,7 @@ from open_webui.config import (
|
||||||
from open_webui.env import (
|
from open_webui.env import (
|
||||||
SRC_LOG_LEVELS,
|
SRC_LOG_LEVELS,
|
||||||
GLOBAL_LOG_LEVEL,
|
GLOBAL_LOG_LEVEL,
|
||||||
|
ENABLE_CHAT_RESPONSE_BASE64_IMAGE_URL_CONVERSION,
|
||||||
CHAT_RESPONSE_STREAM_DELTA_CHUNK_SIZE,
|
CHAT_RESPONSE_STREAM_DELTA_CHUNK_SIZE,
|
||||||
CHAT_RESPONSE_MAX_TOOL_CALL_RETRIES,
|
CHAT_RESPONSE_MAX_TOOL_CALL_RETRIES,
|
||||||
BYPASS_MODEL_ACCESS_CONTROL,
|
BYPASS_MODEL_ACCESS_CONTROL,
|
||||||
|
|
@@ -147,7 +155,7 @@ def process_tool_result(
     if isinstance(tool_result, HTMLResponse):
         content_disposition = tool_result.headers.get("Content-Disposition", "")
         if "inline" in content_disposition:
-            content = tool_result.body.decode("utf-8")
+            content = tool_result.body.decode("utf-8", "replace")
             tool_result_embeds.append(content)

             if 200 <= tool_result.status_code < 300:

@@ -175,7 +183,7 @@ def process_tool_result(
                     "message": f"{tool_function_name}: Unexpected status code {tool_result.status_code} from embedded UI result.",
                 }
         else:
-            tool_result = tool_result.body.decode("utf-8")
+            tool_result = tool_result.body.decode("utf-8", "replace")

     elif (tool_type == "external" and isinstance(tool_result, tuple)) or (
         direct_tool and isinstance(tool_result, list) and len(tool_result) == 2

@@ -283,7 +291,7 @@ async def chat_completion_tools_handler(
             content = None
             if hasattr(response, "body_iterator"):
                 async for chunk in response.body_iterator:
-                    data = json.loads(chunk.decode("utf-8"))
+                    data = json.loads(chunk.decode("utf-8", "replace"))
                     content = data["choices"][0]["message"]["content"]

             # Cleanup any remaining background tasks if necessary
@@ -296,19 +304,27 @@ async def chat_completion_tools_handler(
 def get_tools_function_calling_payload(messages, task_model_id, content):
     user_message = get_last_user_message(messages)

+    if user_message and messages and messages[-1]["role"] == "user":
+        # Remove the last user message to avoid duplication
+        messages = messages[:-1]
+
     recent_messages = messages[-4:] if len(messages) > 4 else messages
     chat_history = "\n".join(
-        f"{message['role'].upper()}: \"\"\"{message['content']}\"\"\""
+        f"{message['role'].upper()}: \"\"\"{get_content_from_message(message)}\"\"\""
         for message in recent_messages
     )

-    prompt = f"History:\n{chat_history}\nQuery: {user_message}"
+    prompt = (
+        f"History:\n{chat_history}\nQuery: {user_message}"
+        if chat_history
+        else f"Query: {user_message}"
+    )

     return {
         "model": task_model_id,
         "messages": [
             {"role": "system", "content": content},
-            {"role": "user", "content": f"Query: {prompt}"},
+            {"role": "user", "content": prompt},
         ],
         "stream": False,
         "metadata": {"task": str(TASKS.FUNCTION_CALLING)},
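The old version wrapped the already-assembled prompt in a second "Query: " prefix and repeated the last user turn in both History and Query; the change removes both duplications. A minimal reproduction of the new assembly:

messages = [{"role": "user", "content": "What's the weather in Paris?"}]
user_message = "What's the weather in Paris?"

# Drop the last user turn from history so it appears only once, in the Query line.
history_messages = messages[:-1]
chat_history = "\n".join(
    f"{m['role'].upper()}: \"\"\"{m['content']}\"\"\"" for m in history_messages
)
prompt = (
    f"History:\n{chat_history}\nQuery: {user_message}"
    if chat_history
    else f"Query: {user_message}"
)
print(prompt)  # Query: What's the weather in Paris?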
@@ -442,12 +458,6 @@ async def chat_completion_tools_handler(
                     }
                 )

-                print(
-                    f"Tool {tool_function_name} result: {tool_result}",
-                    tool_result_files,
-                    tool_result_embeds,
-                )
-
                 if tool_result:
                     tool = tools[tool_function_name]
                     tool_id = tool.get("tool_id", "")

@@ -475,12 +485,6 @@ async def chat_completion_tools_handler(
                         }
                     )

-                    # Citation is not enabled for this tool
-                    body["messages"] = add_or_update_user_message(
-                        f"\nTool `{tool_name}` Output: {tool_result}",
-                        body["messages"],
-                    )
-
                     if (
                         tools[tool_function_name]
                         .get("metadata", {})
@@ -712,10 +716,56 @@ async def chat_web_search_handler(
     return form_data


+def get_last_images(message_list):
+    images = []
+    for message in reversed(message_list):
+        images_flag = False
+        for file in message.get("files", []):
+            if file.get("type") == "image":
+                images.append(file.get("url"))
+                images_flag = True
+
+        if images_flag:
+            break
+
+    return images
+
+
+def get_image_urls(delta_images, request, metadata, user) -> list[str]:
+    if not isinstance(delta_images, list):
+        return []
+
+    image_urls = []
+    for img in delta_images:
+        if not isinstance(img, dict) or img.get("type") != "image_url":
+            continue
+
+        url = img.get("image_url", {}).get("url")
+        if not url:
+            continue
+
+        if url.startswith("data:image/png;base64"):
+            url = get_image_url_from_base64(request, url, metadata, user)
+
+        image_urls.append(url)
+
+    return image_urls
+
+
 async def chat_image_generation_handler(
     request: Request, form_data: dict, extra_params: dict, user
 ):
+    metadata = extra_params.get("__metadata__", {})
+    chat_id = metadata.get("chat_id", None)
+    if not chat_id:
+        return form_data
+
     __event_emitter__ = extra_params["__event_emitter__"]

+    if chat_id.startswith("local:"):
+        message_list = form_data.get("messages", [])
+    else:
+        chat = Chats.get_chat_by_id_and_user_id(chat_id, user.id)
+
     await __event_emitter__(
         {
             "type": "status",
@@ -723,48 +773,23 @@ async def chat_image_generation_handler(
         }
     )

-    messages = form_data["messages"]
-    user_message = get_last_user_message(messages)
+    messages_map = chat.chat.get("history", {}).get("messages", {})
+    message_id = chat.chat.get("history", {}).get("currentId")
+    message_list = get_message_list(messages_map, message_id)
+
+    user_message = get_last_user_message(message_list)

     prompt = user_message
-    negative_prompt = ""
+    input_images = get_last_images(message_list)

-    if request.app.state.config.ENABLE_IMAGE_PROMPT_GENERATION:
-        try:
-            res = await generate_image_prompt(
-                request,
-                {
-                    "model": form_data["model"],
-                    "messages": messages,
-                },
-                user,
-            )
-
-            response = res["choices"][0]["message"]["content"]
-
-            try:
-                bracket_start = response.find("{")
-                bracket_end = response.rfind("}") + 1
-
-                if bracket_start == -1 or bracket_end == -1:
-                    raise Exception("No JSON object found in the response")
-
-                response = response[bracket_start:bracket_end]
-                response = json.loads(response)
-                prompt = response.get("prompt", [])
-            except Exception as e:
-                prompt = user_message
-
-        except Exception as e:
-            log.exception(e)
-            prompt = user_message
-
     system_message_content = ""

+    if len(input_images) > 0 and request.app.state.config.ENABLE_IMAGE_EDIT:
+        # Edit image(s)
         try:
-            images = await image_generations(
+            images = await image_edits(
                 request=request,
-                form_data=GenerateImageForm(**{"prompt": prompt}),
+                form_data=EditImageForm(**{"prompt": prompt, "image": input_images}),
                 user=user,
             )

@@ -790,9 +815,17 @@ async def chat_image_generation_handler(
                 }
             )

-            system_message_content = "<context>User is shown the generated image, tell the user that the image has been generated</context>"
+            system_message_content = "<context>The requested image has been created and is now being shown to the user. Let them know that it has been generated.</context>"
         except Exception as e:
-            log.exception(e)
+            log.debug(e)
+
+            error_message = ""
+            if isinstance(e, HTTPException):
+                if e.detail and isinstance(e.detail, dict):
+                    error_message = e.detail.get("message", str(e.detail))
+                else:
+                    error_message = str(e.detail)

             await __event_emitter__(
                 {
                     "type": "status",

@@ -803,7 +836,91 @@ async def chat_image_generation_handler(
                 }
             )

-            system_message_content = "<context>Unable to generate an image, tell the user that an error occurred</context>"
+            system_message_content = f"<context>Image generation was attempted but failed. The system is currently unable to generate the image. Tell the user that the following error occurred: {error_message}</context>"
+
+    else:
+        # Create image(s)
+        if request.app.state.config.ENABLE_IMAGE_PROMPT_GENERATION:
+            try:
+                res = await generate_image_prompt(
+                    request,
+                    {
+                        "model": form_data["model"],
+                        "messages": form_data["messages"],
+                    },
+                    user,
+                )
+
+                response = res["choices"][0]["message"]["content"]
+
+                try:
+                    bracket_start = response.find("{")
+                    bracket_end = response.rfind("}") + 1
+
+                    if bracket_start == -1 or bracket_end == -1:
+                        raise Exception("No JSON object found in the response")
+
+                    response = response[bracket_start:bracket_end]
+                    response = json.loads(response)
+                    prompt = response.get("prompt", [])
+                except Exception as e:
+                    prompt = user_message
+
+            except Exception as e:
+                log.exception(e)
+                prompt = user_message
+
+        try:
+            images = await image_generations(
+                request=request,
+                form_data=CreateImageForm(**{"prompt": prompt}),
+                user=user,
+            )
+
+            await __event_emitter__(
+                {
+                    "type": "status",
+                    "data": {"description": "Image created", "done": True},
+                }
+            )
+
+            await __event_emitter__(
+                {
+                    "type": "files",
+                    "data": {
+                        "files": [
+                            {
+                                "type": "image",
+                                "url": image["url"],
+                            }
+                            for image in images
+                        ]
+                    },
+                }
+            )
+
+            system_message_content = "<context>The requested image has been created by the system successfully and is now being shown to the user. Let the user know that the image they requested has been generated and is now shown in the chat.</context>"
+        except Exception as e:
+            log.debug(e)
+
+            error_message = ""
+            if isinstance(e, HTTPException):
+                if e.detail and isinstance(e.detail, dict):
+                    error_message = e.detail.get("message", str(e.detail))
+                else:
+                    error_message = str(e.detail)
+
+            await __event_emitter__(
+                {
+                    "type": "status",
+                    "data": {
+                        "description": f"An error occurred while generating an image",
+                        "done": True,
+                    },
+                }
+            )
+
+            system_message_content = f"<context>Image generation was attempted but failed because of an error. The system is currently unable to generate the image. Tell the user that the following error occurred: {error_message}</context>"

     if system_message_content:
         form_data["messages"] = add_or_update_system_message(
@@ -853,10 +970,6 @@ async def chat_completion_files_handler(
             except:
                 pass

-    if len(queries) == 0:
-        queries = [get_last_user_message(body["messages"])]
-
-    if not all_full_context:
         await __event_emitter__(
             {
                 "type": "status",

@@ -868,13 +981,12 @@ async def chat_completion_files_handler(
                 }
             )

+    if len(queries) == 0:
+        queries = [get_last_user_message(body["messages"])]
+
     try:
-        # Offload get_sources_from_items to a separate thread
-        loop = asyncio.get_running_loop()
-        with ThreadPoolExecutor() as executor:
-            sources = await loop.run_in_executor(
-                executor,
-                lambda: get_sources_from_items(
+        # Directly await async get_sources_from_items (no thread needed - fully async now)
+        sources = await get_sources_from_items(
             request=request,
             items=files,
             queries=queries,

@@ -884,8 +996,8 @@ async def chat_completion_files_handler(
             k=request.app.state.config.TOP_K,
             reranking_function=(
                 (
-                    lambda sentences: request.app.state.RERANKING_FUNCTION(
-                        sentences, user=user
+                    lambda query, documents: request.app.state.RERANKING_FUNCTION(
+                        query, documents, user=user
                     )
                 )
                 if request.app.state.RERANKING_FUNCTION

@@ -898,7 +1010,6 @@ async def chat_completion_files_handler(
             full_context=all_full_context
             or request.app.state.config.RAG_FULL_CONTEXT,
             user=user,
-                ),
         )
     except Exception as e:
         log.exception(e)

@@ -906,7 +1017,6 @@ async def chat_completion_files_handler(
     log.debug(f"rag_contexts:sources: {sources}")

     unique_ids = set()
-
     for source in sources or []:
         if not source or len(source.keys()) == 0:
             continue

@@ -925,7 +1035,6 @@ async def chat_completion_files_handler(
             unique_ids.add(_id)

     sources_count = len(unique_ids)
-
     await __event_emitter__(
         {
             "type": "status",
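The reranker is now called with (query, documents) instead of a single flattened list of sentence pairs. A stand-in showing the adapted call shape; the scoring function here is a hypothetical toy, where real implementations would run a cross-encoder model:

def reranking_function(query: str, documents: list[str], user=None) -> list[float]:
    # toy relevance score: word overlap between query and each document
    return [float(len(set(query.split()) & set(d.split()))) for d in documents]

rerank = lambda query, documents: reranking_function(query, documents, user=None)
print(rerank("red fox", ["the red fox", "a blue whale"]))  # [2.0, 0.0]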
@@ -999,16 +1108,16 @@ async def process_chat_payload(request, form_data, user, metadata, model):
     log.debug(f"form_data: {form_data}")

     system_message = get_system_message(form_data.get("messages", []))
-    if system_message:
+    if system_message:  # Chat Controls/User Settings
         try:
             form_data = apply_system_prompt_to_body(
-                system_message.get("content"), form_data, metadata, user
-            )
+                system_message.get("content"), form_data, metadata, user, replace=True
+            )  # Required to handle system prompt variables
         except:
             pass

     event_emitter = get_event_emitter(metadata)
-    event_call = get_event_call(metadata)
+    event_caller = get_event_call(metadata)

     oauth_token = None
     try:

@@ -1022,14 +1131,13 @@ async def process_chat_payload(request, form_data, user, metadata, model):

     extra_params = {
         "__event_emitter__": event_emitter,
-        "__event_call__": event_call,
+        "__event_call__": event_caller,
         "__user__": user.model_dump() if isinstance(user, UserModel) else {},
         "__metadata__": metadata,
+        "__oauth_token__": oauth_token,
         "__request__": request,
         "__model__": model,
-        "__oauth_token__": oauth_token,
     }

     # Initialize events to store additional event to be sent to the client
     # Initialize contexts and citation
     if getattr(request.state, "direct", False) and hasattr(request.state, "model"):
@@ -1140,6 +1248,18 @@ async def process_chat_payload(request, form_data, user, metadata, model):

     features = form_data.pop("features", None)
     if features:
+        if "voice" in features and features["voice"]:
+            if request.app.state.config.VOICE_MODE_PROMPT_TEMPLATE != None:
+                if request.app.state.config.VOICE_MODE_PROMPT_TEMPLATE != "":
+                    template = request.app.state.config.VOICE_MODE_PROMPT_TEMPLATE
+                else:
+                    template = DEFAULT_VOICE_MODE_PROMPT_TEMPLATE
+
+                form_data["messages"] = add_or_update_system_message(
+                    template,
+                    form_data["messages"],
+                )
+
         if "memory" in features and features["memory"]:
             form_data = await chat_memory_handler(
                 request, form_data, extra_params, user
@@ -1168,8 +1288,28 @@ async def process_chat_payload(request, form_data, user, metadata, model):
     tool_ids = form_data.pop("tool_ids", None)
     files = form_data.pop("files", None)

-    # Remove files duplicates
+    prompt = get_last_user_message(form_data["messages"])
+    # TODO: re-enable URL extraction from prompt
+    # urls = []
+    # if prompt and len(prompt or "") < 500 and (not files or len(files) == 0):
+    #     urls = extract_urls(prompt)
+
     if files:
+        if not files:
+            files = []
+
+        for file_item in files:
+            if file_item.get("type", "file") == "folder":
+                # Get folder files
+                folder_id = file_item.get("id", None)
+                if folder_id:
+                    folder = Folders.get_folder_by_id_and_user_id(folder_id, user.id)
+                    if folder and folder.data and "files" in folder.data:
+                        files = [f for f in files if f.get("id", None) != folder_id]
+                        files = [*files, *folder.data["files"]]
+
+        # files = [*files, *[{"type": "url", "url": url, "name": url} for url in urls]]
+        # Remove duplicate files based on their content
         files = list({json.dumps(f, sort_keys=True): f for f in files}.values())

     metadata = {
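The dedup line kept above works because json.dumps with sort_keys=True is a canonical form of each dict, so entries that differ only in key order collapse to one. A standalone illustration:

import json

files = [
    {"id": "f1", "type": "file"},
    {"type": "file", "id": "f1"},   # same content, different key order
    {"id": "f2", "type": "file"},
]
files = list({json.dumps(f, sort_keys=True): f for f in files}.values())
print(len(files))  # 2 -- the two f1 entries collapsed into one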
@ -1250,20 +1389,29 @@ async def process_chat_payload(request, form_data, user, metadata, model):
|
||||||
log.error(f"Error getting OAuth token: {e}")
|
log.error(f"Error getting OAuth token: {e}")
|
||||||
oauth_token = None
|
oauth_token = None
|
||||||
|
|
||||||
|
connection_headers = mcp_server_connection.get("headers", None)
|
||||||
|
if connection_headers and isinstance(connection_headers, dict):
|
||||||
|
for key, value in connection_headers.items():
|
||||||
|
headers[key] = value
|
||||||
|
|
||||||
mcp_clients[server_id] = MCPClient()
|
mcp_clients[server_id] = MCPClient()
|
||||||
await mcp_clients[server_id].connect(
|
await mcp_clients[server_id].connect(
|
||||||
url=mcp_server_connection.get("url", ""),
|
url=mcp_server_connection.get("url", ""),
|
||||||
headers=headers if headers else None,
|
headers=headers if headers else None,
|
||||||
)
|
)
|
||||||
|
|
||||||
|
function_name_filter_list = mcp_server_connection.get(
|
||||||
|
"config", {}
|
||||||
|
).get("function_name_filter_list", "")
|
||||||
|
|
||||||
|
if isinstance(function_name_filter_list, str):
|
||||||
|
function_name_filter_list = function_name_filter_list.split(",")
|
||||||
|
|
||||||
tool_specs = await mcp_clients[server_id].list_tool_specs()
|
tool_specs = await mcp_clients[server_id].list_tool_specs()
|
||||||
for tool_spec in tool_specs:
|
for tool_spec in tool_specs:
|
||||||
|
|
||||||
def make_tool_function(client, function_name):
|
def make_tool_function(client, function_name):
|
||||||
async def tool_function(**kwargs):
|
async def tool_function(**kwargs):
|
||||||
print(kwargs)
|
|
||||||
print(client)
|
|
||||||
print(await client.list_tool_specs())
|
|
||||||
return await client.call_tool(
|
return await client.call_tool(
|
||||||
function_name,
|
function_name,
|
||||||
function_args=kwargs,
|
function_args=kwargs,
|
||||||
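Note: make_tool_function exists so each generated tool_function closes over its own client and function_name. A closure defined directly inside the for loop would late-bind the loop variables, so every tool would end up calling the last tool's name. A small illustration of that pitfall, independent of the MCP code:

    # Late binding: every closure sees the final value of `name`.
    broken = [lambda: name for name in ("a", "b", "c")]
    assert [f() for f in broken] == ["c", "c", "c"]

    # Factory (as in make_tool_function): each closure gets its own binding.
    def make(n):
        return lambda: n

    fixed = [make(name) for name in ("a", "b", "c")]
    assert [f() for f in fixed] == ["a", "b", "c"]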
@@ -1271,6 +1419,13 @@ async def process_chat_payload(request, form_data, user, metadata, model):

                         return tool_function

+                    if function_name_filter_list:
+                        if not is_string_allowed(
+                            tool_spec["name"], function_name_filter_list
+                        ):
+                            # Skip this function
+                            continue
+
                     tool_function = make_tool_function(
                         mcp_clients[server_id], tool_spec["name"]
                     )

@@ -1287,6 +1442,17 @@ async def process_chat_payload(request, form_data, user, metadata, model):

                     }
             except Exception as e:
                 log.debug(e)
+                if event_emitter:
+                    await event_emitter(
+                        {
+                            "type": "chat:message:error",
+                            "data": {
+                                "error": {
+                                    "content": f"Failed to connect to MCP server '{server_id}'"
+                                }
+                            },
+                        }
+                    )
                 continue

         tools_dict = await get_tools(

@@ -1300,6 +1466,7 @@ async def process_chat_payload(request, form_data, user, metadata, model):

                 "__files__": metadata.get("files", []),
             },
         )

         if mcp_tools_dict:
             tools_dict = {**tools_dict, **mcp_tools_dict}

@@ -1370,8 +1537,6 @@ async def process_chat_payload(request, form_data, user, metadata, model):

         )

         context_string = context_string.strip()

-        prompt = get_last_user_message(form_data["messages"])
         if prompt is None:
             raise Exception("No user message found")

@@ -1410,10 +1575,6 @@ async def process_chat_payload(request, form_data, user, metadata, model):

             }
         )

-    print("Final form_data:", form_data)
-    print("Final metadata:", metadata)
-    print("Final events:", events)
-
     return form_data, metadata, events

@@ -1421,10 +1582,13 @@ async def process_chat_response(
     request, response, form_data, user, metadata, model, events, tasks
 ):
     async def background_tasks_handler():
+        message = None
+        messages = []
+
+        if "chat_id" in metadata and not metadata["chat_id"].startswith("local:"):
             messages_map = Chats.get_messages_map_by_chat_id(metadata["chat_id"])
             message = messages_map.get(metadata["message_id"]) if messages_map else None

-        if message:
             message_list = get_message_list(messages_map, metadata["message_id"])

             # Remove details tags and files from the messages.

@@ -1457,7 +1621,14 @@ async def process_chat_response(
                         "content": content,
                     }
                 )
+        else:
+            # Local temp chat, get the model and message from the form_data
+            message = get_last_user_message_item(form_data.get("messages", []))
+            messages = form_data.get("messages", [])
+            if message:
+                message["model"] = form_data.get("model")

+        if message and "model" in message:
             if tasks and messages:
                 if (
                     TASKS.FOLLOW_UP_GENERATION in tasks

@@ -1476,11 +1647,13 @@ async def process_chat_response(

                     if res and isinstance(res, dict):
                         if len(res.get("choices", [])) == 1:
-                            follow_ups_string = (
-                                res.get("choices", [])[0]
-                                .get("message", {})
-                                .get("content", "")
+                            response_message = res.get("choices", [])[0].get(
+                                "message", {}
                             )
+
+                            follow_ups_string = response_message.get(
+                                "content"
+                            ) or response_message.get("reasoning_content", "")
                         else:
                             follow_ups_string = ""
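Note: the reworked extraction reads the choice's message once and falls back from content to reasoning_content, which matters for reasoning models that may leave content empty. A minimal sketch with an illustrative response payload:

    res = {"choices": [{"message": {"content": "", "reasoning_content": "Some follow-ups..."}}]}

    response_message = res.get("choices", [])[0].get("message", {})
    follow_ups_string = response_message.get("content") or response_message.get(
        "reasoning_content", ""
    )
    assert follow_ups_string == "Some follow-ups..."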
@@ -1493,15 +1666,6 @@ async def process_chat_response(

                             follow_ups = json.loads(follow_ups_string).get(
                                 "follow_ups", []
                             )

-                            Chats.upsert_message_to_chat_by_id_and_message_id(
-                                metadata["chat_id"],
-                                metadata["message_id"],
-                                {
-                                    "followUps": follow_ups,
-                                },
-                            )
-
                             await event_emitter(
                                 {
                                     "type": "chat:message:follow_ups",

@@ -1510,16 +1674,29 @@ async def process_chat_response(
                                     },
                                 }
                             )

+                            if not metadata.get("chat_id", "").startswith("local:"):
+                                Chats.upsert_message_to_chat_by_id_and_message_id(
+                                    metadata["chat_id"],
+                                    metadata["message_id"],
+                                    {
+                                        "followUps": follow_ups,
+                                    },
+                                )
+
                     except Exception as e:
                         pass

+            if not metadata.get("chat_id", "").startswith(
+                "local:"
+            ):  # Only update titles and tags for non-temp chats
                 if TASKS.TITLE_GENERATION in tasks:
                     user_message = get_last_user_message(messages)
                     if user_message and len(user_message) > 100:
                         user_message = user_message[:100] + "..."

+                    title = None
                     if tasks[TASKS.TITLE_GENERATION]:

                         res = await generate_title(
                             request,
                             {

@@ -1532,12 +1709,16 @@ async def process_chat_response(

                         if res and isinstance(res, dict):
                             if len(res.get("choices", [])) == 1:
-                                title_string = (
-                                    res.get("choices", [])[0]
-                                    .get("message", {})
-                                    .get(
-                                        "content", message.get("content", user_message)
+                                response_message = res.get("choices", [])[0].get(
+                                    "message", {}
                                 )
+
+                                title_string = (
+                                    response_message.get("content")
+                                    or response_message.get(
+                                        "reasoning_content",
+                                    )
+                                    or message.get("content", user_message)
                                 )
                             else:
                                 title_string = ""

@@ -1556,7 +1737,9 @@ async def process_chat_response(

                         if not title:
                             title = messages[0].get("content", user_message)

-                        Chats.update_chat_title_by_id(metadata["chat_id"], title)
+                        Chats.update_chat_title_by_id(
+                            metadata["chat_id"], title
+                        )

                         await event_emitter(
                             {

@@ -1564,7 +1747,8 @@ async def process_chat_response(
                                 "data": title,
                             }
                         )
-                elif len(messages) == 2:
+
+                if title == None and len(messages) == 2:
                     title = messages[0].get("content", user_message)

                     Chats.update_chat_title_by_id(metadata["chat_id"], title)

@@ -1589,11 +1773,13 @@ async def process_chat_response(

                     if res and isinstance(res, dict):
                         if len(res.get("choices", [])) == 1:
-                            tags_string = (
-                                res.get("choices", [])[0]
-                                .get("message", {})
-                                .get("content", "")
+                            response_message = res.get("choices", [])[0].get(
+                                "message", {}
                             )
+
+                            tags_string = response_message.get(
+                                "content"
+                            ) or response_message.get("reasoning_content", "")
                         else:
                             tags_string = ""

@@ -1642,7 +1828,9 @@ async def process_chat_response(
                     response.body, bytes
                 ):
                     try:
-                        response_data = json.loads(response.body.decode("utf-8"))
+                        response_data = json.loads(
+                            response.body.decode("utf-8", "replace")
+                        )
                     except json.JSONDecodeError:
                         response_data = {
                             "error": {"detail": "Invalid JSON response"}

@@ -1718,7 +1906,7 @@ async def process_chat_response(
             )

             # Send a webhook notification if the user is not active
-            if not get_active_status_by_user_id(user.id):
+            if not Users.is_user_active(user.id):
                 webhook_url = Users.get_user_webhook_url_by_id(user.id)
                 if webhook_url:
                     await post_webhook(

@@ -1896,10 +2084,12 @@ async def process_chat_response(
                     content = f"{content}{tool_calls_display_content}"

                 elif block["type"] == "reasoning":
-                    reasoning_display_content = "\n".join(
-                        (f"> {line}" if not line.startswith(">") else line)
-                        for line in block["content"].splitlines()
-                    )
+                    reasoning_display_content = html.escape(
+                        "\n".join(
+                            (f"> {line}" if not line.startswith(">") else line)
+                            for line in block["content"].splitlines()
+                        )
+                    )

                     reasoning_duration = block.get("duration", None)
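Note: escaping the blockquoted reasoning with html.escape keeps model output such as <script> tags from being interpreted as markup when the collapsed reasoning block is rendered. Roughly:

    import html

    block = {"content": "first line\n<script>alert(1)</script>"}
    reasoning_display_content = html.escape(
        "\n".join(
            (f"> {line}" if not line.startswith(">") else line)
            for line in block["content"].splitlines()
        )
    )
    assert "&lt;script&gt;" in reasoning_display_content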
@@ -2276,7 +2466,11 @@ async def process_chat_response(

                 last_delta_data = None

                 async for line in response.body_iterator:
-                    line = line.decode("utf-8") if isinstance(line, bytes) else line
+                    line = (
+                        line.decode("utf-8", "replace")
+                        if isinstance(line, bytes)
+                        else line
+                    )
                     data = line

                     # Skip empty lines
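Note: decoding with errors="replace" keeps the streaming loop alive when a chunk boundary cuts a multi-byte UTF-8 sequence in half; the strict default would raise UnicodeDecodeError and abort the stream. For example:

    chunk = "héllo".encode("utf-8")[:2]  # cut inside the two-byte "é"
    assert chunk.decode("utf-8", "replace") == "h\ufffd"  # U+FFFD replacement char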
@@ -2302,7 +2496,9 @@ async def process_chat_response(
                         )

                     if data:
-                        if "event" in data:
+                        if "event" in data and not getattr(
+                            request.state, "direct", False
+                        ):
                             await event_emitter(data.get("event", {}))

                         if "selected_model_id" in data:

@@ -2410,6 +2606,26 @@ async def process_chat_response(
                                         "arguments"
                                     ] += delta_arguments

+                                image_urls = get_image_urls(
+                                    delta.get("images", []), request, metadata, user
+                                )
+                                if image_urls:
+                                    message_files = Chats.add_message_files_by_id_and_message_id(
+                                        metadata["chat_id"],
+                                        metadata["message_id"],
+                                        [
+                                            {"type": "image", "url": url}
+                                            for url in image_urls
+                                        ],
+                                    )
+
+                                    await event_emitter(
+                                        {
+                                            "type": "files",
+                                            "data": {"files": message_files},
+                                        }
+                                    )
+
                                 value = delta.get("content")

                                 reasoning_content = (

@@ -2468,6 +2684,11 @@ async def process_chat_response(
                                         }
                                     )

+                                    if ENABLE_CHAT_RESPONSE_BASE64_IMAGE_URL_CONVERSION:
+                                        value = convert_markdown_base64_images(
+                                            request, value, metadata, user
+                                        )
+
                                     content = f"{content}{value}"
                                     if not content_blocks:
                                         content_blocks.append(

@@ -2619,8 +2840,6 @@ async def process_chat_response(
                         results = []

                         for tool_call in response_tool_calls:
-
-                            print("tool_call", tool_call)
                             tool_call_id = tool_call.get("id", "")
                             tool_function_name = tool_call.get("function", {}).get(
                                 "name", ""

@@ -2695,7 +2914,16 @@ async def process_chat_response(
                                     )

                                 else:
-                                    tool_function = tool["callable"]
+                                    tool_function = get_updated_tool_function(
+                                        function=tool["callable"],
+                                        extra_params={
+                                            "__messages__": form_data.get(
+                                                "messages", []
+                                            ),
+                                            "__files__": metadata.get("files", []),
+                                        },
+                                    )
+
                                     tool_result = await tool_function(
                                         **tool_function_params
                                     )

@@ -2751,9 +2979,9 @@ async def process_chat_response(

                         try:
                             new_form_data = {
+                                **form_data,
                                 "model": model_id,
                                 "stream": True,
-                                "tools": form_data["tools"],
                                 "messages": [
                                     *form_data["messages"],
                                     *convert_content_blocks_to_messages(

@@ -2927,6 +3155,7 @@ async def process_chat_response(

                 try:
                     new_form_data = {
+                        **form_data,
                         "model": model_id,
                         "stream": True,
                         "messages": [

@@ -2972,7 +3201,7 @@ async def process_chat_response(
             )

             # Send a webhook notification if the user is not active
-            if not get_active_status_by_user_id(user.id):
+            if not Users.is_user_active(user.id):
                 webhook_url = Users.get_user_webhook_url_by_id(user.id)
                 if webhook_url:
                     await post_webhook(
@@ -6,12 +6,13 @@ import uuid
 import logging
 from datetime import timedelta
 from pathlib import Path
-from typing import Callable, Optional
+from typing import Callable, Optional, Sequence, Union
 import json
+import aiohttp


 import collections.abc
-from open_webui.env import SRC_LOG_LEVELS
+from open_webui.env import SRC_LOG_LEVELS, CHAT_STREAM_RESPONSE_CHUNK_MAX_BUFFER_SIZE

 log = logging.getLogger(__name__)
 log.setLevel(SRC_LOG_LEVELS["MAIN"])

@@ -26,6 +27,49 @@ def deep_update(d, u):
     return d


+def get_allow_block_lists(filter_list):
+    allow_list = []
+    block_list = []
+
+    if filter_list:
+        for d in filter_list:
+            if d.startswith("!"):
+                # Domains starting with "!" → blocked
+                block_list.append(d[1:].strip())
+            else:
+                # Domains starting without "!" → allowed
+                allow_list.append(d.strip())
+
+    return allow_list, block_list
+
+
+def is_string_allowed(
+    string: Union[str, Sequence[str]], filter_list: Optional[list[str]] = None
+) -> bool:
+    """
+    Checks if a string is allowed based on the provided filter list.
+
+    :param string: The string or sequence of strings to check (e.g., domain or hostname).
+    :param filter_list: List of allowed/blocked strings. Strings starting with "!" are blocked.
+    :return: True if the string or sequence of strings is allowed, False otherwise.
+    """
+    if not filter_list:
+        return True
+
+    allow_list, block_list = get_allow_block_lists(filter_list)
+    strings = [string] if isinstance(string, str) else list(string)
+
+    # If allow list is non-empty, require domain to match one of them
+    if allow_list:
+        if not any(s.endswith(allowed) for s in strings for allowed in allow_list):
+            return False
+
+    # Block list always removes matches
+    if any(s.endswith(blocked) for s in strings for blocked in block_list):
+        return False
+
+    return True
+
+
 def get_message_list(messages_map, message_id):
     """
     Reconstructs a list of messages in order up to the specified message_id.
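Note: the filter syntax treats plain entries as an allow list and "!"-prefixed entries as a block list, both matched by suffix (str.endswith). A few illustrative calls, assuming the two definitions above are in scope:

    # Allow anything under example.com except blocked.example.com.
    filter_list = ["example.com", "!blocked.example.com"]

    assert is_string_allowed("api.example.com", filter_list) is True
    assert is_string_allowed("blocked.example.com", filter_list) is False
    assert is_string_allowed("other.org", filter_list) is False  # not in allow list
    assert is_string_allowed("anything", None) is True  # no filter -> allowed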
@@ -136,6 +180,14 @@ def update_message_content(message: dict, content: str, append: bool = True) ->
     return message


+def replace_system_message_content(content: str, messages: list[dict]) -> dict:
+    for message in messages:
+        if message["role"] == "system":
+            message["content"] = content
+            break
+    return messages
+
+
 def add_or_update_system_message(
     content: str, messages: list[dict], append: bool = False
 ):

@@ -523,3 +575,76 @@ def throttle(interval: float = 10.0):
         return wrapper

     return decorator
+
+
+def extract_urls(text: str) -> list[str]:
+    # Regex pattern to match URLs
+    url_pattern = re.compile(
+        r"(https?://[^\s]+)", re.IGNORECASE
+    )  # Matches http and https URLs
+    return url_pattern.findall(text)
+
+
+def stream_chunks_handler(stream: aiohttp.StreamReader):
+    """
+    Handle stream response chunks, supporting large data chunks that exceed the original 16kb limit.
+    When a single line exceeds max_buffer_size, returns an empty JSON string {} and skips subsequent data
+    until encountering normally sized data.
+
+    :param stream: The stream reader to handle.
+    :return: An async generator that yields the stream data.
+    """
+
+    max_buffer_size = CHAT_STREAM_RESPONSE_CHUNK_MAX_BUFFER_SIZE
+    if max_buffer_size is None or max_buffer_size <= 0:
+        return stream
+
+    async def yield_safe_stream_chunks():
+        buffer = b""
+        skip_mode = False
+
+        async for data, _ in stream.iter_chunks():
+            if not data:
+                continue
+
+            # In skip_mode, if buffer already exceeds the limit, clear it (it's part of an oversized line)
+            if skip_mode and len(buffer) > max_buffer_size:
+                buffer = b""
+
+            lines = (buffer + data).split(b"\n")
+
+            # Process complete lines (except the last possibly incomplete fragment)
+            for i in range(len(lines) - 1):
+                line = lines[i]
+
+                if skip_mode:
+                    # Skip mode: check if current line is small enough to exit skip mode
+                    if len(line) <= max_buffer_size:
+                        skip_mode = False
+                        yield line
+                    else:
+                        yield b"data: {}"
+                else:
+                    # Normal mode: check if line exceeds limit
+                    if len(line) > max_buffer_size:
+                        skip_mode = True
+                        yield b"data: {}"
+                        log.info(f"Skip mode triggered, line size: {len(line)}")
+                    else:
+                        yield line
+
+            # Save the last incomplete fragment
+            buffer = lines[-1]
+
+            # Check if buffer exceeds limit
+            if not skip_mode and len(buffer) > max_buffer_size:
+                skip_mode = True
+                log.info(f"Skip mode triggered, buffer size: {len(buffer)}")
+                # Clear oversized buffer to prevent unlimited growth
+                buffer = b""
+
+        # Process remaining buffer data
+        if buffer and not skip_mode:
+            yield buffer
+
+    return yield_safe_stream_chunks()
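Note: a hedged sketch of how stream_chunks_handler is meant to be consumed on an aiohttp response body; the URL and surrounding session handling here are illustrative only, not the repo's actual call site:

    import aiohttp

    async def consume_chat_stream(url: str):
        async with aiohttp.ClientSession() as session:
            async with session.get(url) as resp:
                # Lines longer than the configured buffer size come back as
                # b"data: {}" placeholders instead of overflowing the reader.
                async for line in stream_chunks_handler(resp.content):
                    print(line)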
@@ -6,12 +6,14 @@ import sys
 from aiocache import cached
 from fastapi import Request

+from open_webui.socket.utils import RedisDict
 from open_webui.routers import openai, ollama
 from open_webui.functions import get_function_models


 from open_webui.models.functions import Functions
 from open_webui.models.models import Models
+from open_webui.models.groups import Groups


 from open_webui.utils.plugin import (

@@ -166,7 +168,8 @@ async def get_all_models(request, refresh: bool = False, user: UserModel = None):
             action_ids = []
             filter_ids = []

-            if "info" in model and "meta" in model["info"]:
+            if "info" in model:
+                if "meta" in model["info"]:
                     action_ids.extend(
                         model["info"]["meta"].get("actionIds", [])
                     )

@@ -174,6 +177,10 @@ async def get_all_models(request, refresh: bool = False, user: UserModel = None):
                         model["info"]["meta"].get("filterIds", [])
                     )

+                if "params" in model["info"]:
+                    # Remove params to avoid exposing sensitive info
+                    del model["info"]["params"]
+
                 model["action_ids"] = action_ids
                 model["filter_ids"] = filter_ids
             else:

@@ -182,22 +189,45 @@ async def get_all_models(request, refresh: bool = False, user: UserModel = None):
         elif custom_model.is_active and (
             custom_model.id not in [model["id"] for model in models]
         ):
+            # Custom model based on a base model
             owned_by = "openai"
+            connection_type = None
+
             pipe = None

+            for m in models:
+                if (
+                    custom_model.base_model_id == m["id"]
+                    or custom_model.base_model_id == m["id"].split(":")[0]
+                ):
+                    owned_by = m.get("owned_by", "unknown")
+                    if "pipe" in m:
+                        pipe = m["pipe"]
+
+                    connection_type = m.get("connection_type", None)
+                    break
+
+            model = {
+                "id": f"{custom_model.id}",
+                "name": custom_model.name,
+                "object": "model",
+                "created": custom_model.created_at,
+                "owned_by": owned_by,
+                "connection_type": connection_type,
+                "preset": True,
+                **({"pipe": pipe} if pipe is not None else {}),
+            }
+
+            info = custom_model.model_dump()
+            if "params" in info:
+                # Remove params to avoid exposing sensitive info
+                del info["params"]
+
+            model["info"] = info
+
             action_ids = []
             filter_ids = []

-            for model in models:
-                if (
-                    custom_model.base_model_id == model["id"]
-                    or custom_model.base_model_id == model["id"].split(":")[0]
-                ):
-                    owned_by = model.get("owned_by", "unknown owner")
-                    if "pipe" in model:
-                        pipe = model["pipe"]
-                    break
-
             if custom_model.meta:
                 meta = custom_model.meta.model_dump()

@@ -207,20 +237,10 @@ async def get_all_models(request, refresh: bool = False, user: UserModel = None):
                 if "filterIds" in meta:
                     filter_ids.extend(meta["filterIds"])

-            models.append(
-                {
-                    "id": f"{custom_model.id}",
-                    "name": custom_model.name,
-                    "object": "model",
-                    "created": custom_model.created_at,
-                    "owned_by": owned_by,
-                    "info": custom_model.model_dump(),
-                    "preset": True,
-                    **({"pipe": pipe} if pipe is not None else {}),
-                    "action_ids": action_ids,
-                    "filter_ids": filter_ids,
-                }
-            )
+            model["action_ids"] = action_ids
+            model["filter_ids"] = filter_ids
+
+            models.append(model)

     # Process action_ids to get the actions
     def get_action_items_from_module(function, module):

@@ -309,7 +329,12 @@ async def get_all_models(request, refresh: bool = False, user: UserModel = None):

     log.debug(f"get_all_models() returned {len(models)} models")

-    request.app.state.MODELS = {model["id"]: model for model in models}
+    models_dict = {model["id"]: model for model in models}
+    if isinstance(request.app.state.MODELS, RedisDict):
+        request.app.state.MODELS.set(models_dict)
+    else:
+        request.app.state.MODELS = models_dict

     return models
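Note: app.state.MODELS may now be a Redis-backed mapping shared across workers, so the refresh mutates the shared store via .set() instead of rebinding the attribute, which would only affect the local process. The dispatch, condensed (RedisDict is the helper imported above; the diff implies .set() replaces the stored mapping):

    def refresh_models_cache(state, models: list[dict]) -> None:
        models_dict = {model["id"]: model for model in models}
        if isinstance(state.MODELS, RedisDict):
            # Shared Redis store: update in place so every worker sees the refresh.
            state.MODELS.set(models_dict)
        else:
            # Plain in-process dict: rebinding the attribute is enough.
            state.MODELS = models_dict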
@@ -343,6 +368,7 @@ def get_filtered_models(models, user):
         or (user.role == "admin" and not BYPASS_ADMIN_ACCESS_CONTROL)
     ) and not BYPASS_MODEL_ACCESS_CONTROL:
         filtered_models = []
+        user_group_ids = {group.id for group in Groups.get_groups_by_member_id(user.id)}
         for model in models:
             if model.get("arena"):
                 if has_access(

@@ -351,6 +377,7 @@ def get_filtered_models(models, user):
                     access_control=model.get("info", {})
                     .get("meta", {})
                     .get("access_control", {}),
+                    user_group_ids=user_group_ids,
                 ):
                     filtered_models.append(model)
                     continue

@@ -364,6 +391,7 @@ def get_filtered_models(models, user):
                     user.id,
                     type="read",
                     access_control=model_info.access_control,
+                    user_group_ids=user_group_ids,
                 )
             ):
                 filtered_models.append(model)
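Note: the user_group_ids additions above precompute the user's group memberships once per request and hand them to has_access, instead of letting each has_access call resolve them again; with N models that turns N group lookups into one. Condensed, the loop now reads:

    user_group_ids = {group.id for group in Groups.get_groups_by_member_id(user.id)}
    for model in models:
        if has_access(
            user.id,
            type="read",
            access_control=model.get("info", {}).get("meta", {}).get("access_control", {}),
            user_group_ids=user_group_ids,  # reused, not re-fetched per model
        ):
            filtered_models.append(model)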
@@ -1,4 +1,5 @@
 import base64
+import copy
 import hashlib
 import logging
 import mimetypes

@@ -13,7 +14,7 @@ import fnmatch
 import time
 import secrets
 from cryptography.fernet import Fernet
+from typing import Literal

 import aiohttp
 from authlib.integrations.starlette_client import OAuth

@@ -41,6 +42,8 @@ from open_webui.config import (
     ENABLE_OAUTH_GROUP_MANAGEMENT,
     ENABLE_OAUTH_GROUP_CREATION,
     OAUTH_BLOCKED_GROUPS,
+    OAUTH_GROUPS_SEPARATOR,
+    OAUTH_ROLES_SEPARATOR,
     OAUTH_ROLES_CLAIM,
     OAUTH_SUB_CLAIM,
     OAUTH_GROUPS_CLAIM,

@@ -51,6 +54,7 @@ from open_webui.config import (
     OAUTH_ADMIN_ROLES,
     OAUTH_ALLOWED_DOMAINS,
     OAUTH_UPDATE_PICTURE_ON_LOGIN,
+    OAUTH_ACCESS_TOKEN_REQUEST_INCLUDE_CLIENT_ID,
     WEBHOOK_URL,
     JWT_EXPIRES_IN,
     AppConfig,

@@ -62,17 +66,28 @@ from open_webui.env import (
     WEBUI_AUTH_COOKIE_SAME_SITE,
     WEBUI_AUTH_COOKIE_SECURE,
     ENABLE_OAUTH_ID_TOKEN_COOKIE,
+    ENABLE_OAUTH_EMAIL_FALLBACK,
     OAUTH_CLIENT_INFO_ENCRYPTION_KEY,
 )
 from open_webui.utils.misc import parse_duration
 from open_webui.utils.auth import get_password_hash, create_token
 from open_webui.utils.webhook import post_webhook
+from open_webui.utils.groups import apply_default_group_assignment

 from mcp.shared.auth import (
-    OAuthClientMetadata,
+    OAuthClientMetadata as MCPOAuthClientMetadata,
     OAuthMetadata,
 )

+from authlib.oauth2.rfc6749.errors import OAuth2Error
+
+
+class OAuthClientMetadata(MCPOAuthClientMetadata):
+    token_endpoint_auth_method: Literal[
+        "none", "client_secret_basic", "client_secret_post"
+    ] = "client_secret_post"
+    pass
+
+
 class OAuthClientInformationFull(OAuthClientMetadata):
     issuer: Optional[str] = None  # URL of the OAuth server that issued this client

@@ -82,6 +97,8 @@ class OAuthClientInformationFull(OAuthClientMetadata):
     client_id_issued_at: int | None = None
     client_secret_expires_at: int | None = None

+    server_metadata: Optional[OAuthMetadata] = None  # Fetched from the OAuth server
+

 from open_webui.env import SRC_LOG_LEVELS, GLOBAL_LOG_LEVEL

@@ -147,6 +164,37 @@ def decrypt_data(data: str):
         raise


+def _build_oauth_callback_error_message(e: Exception) -> str:
+    """
+    Produce a user-facing callback error string with actionable context.
+    Keeps the message short and strips newlines for safe redirect usage.
+    """
+    if isinstance(e, OAuth2Error):
+        parts = [p for p in [e.error, e.description] if p]
+        detail = " - ".join(parts)
+    elif isinstance(e, HTTPException):
+        detail = e.detail if isinstance(e.detail, str) else str(e.detail)
+    elif isinstance(e, aiohttp.ClientResponseError):
+        detail = f"Upstream provider returned {e.status}: {e.message}"
+    elif isinstance(e, aiohttp.ClientError):
+        detail = str(e)
+    elif isinstance(e, KeyError):
+        missing = str(e).strip("'")
+        if missing.lower() == "state":
+            detail = "Missing state parameter in callback (session may have expired)"
+        else:
+            detail = f"Missing expected key '{missing}' in OAuth response"
+    else:
+        detail = str(e)
+
+    detail = detail.replace("\n", " ").strip()
+    if not detail:
+        detail = e.__class__.__name__
+
+    message = f"OAuth callback failed: {detail}"
+    return message[:197] + "..." if len(message) > 200 else message
+
+
 def is_in_blocked_groups(group_name: str, groups: list) -> bool:
     """
     Check if a group name matches any blocked pattern.
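Note: the helper maps each exception family to a short single-line detail and caps the message at 200 characters so it can be embedded safely in a redirect query string. For instance, the missing-state case:

    msg = _build_oauth_callback_error_message(KeyError("state"))
    assert msg == "OAuth callback failed: Missing state parameter in callback (session may have expired)"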
@@ -200,22 +248,31 @@ def get_parsed_and_base_url(server_url) -> tuple[urllib.parse.ParseResult, str]:
 def get_discovery_urls(server_url) -> list[str]:
     parsed, base_url = get_parsed_and_base_url(server_url)

-    urls = [
+    urls = []
+
+    if parsed.path and parsed.path != "/":
+        # Generate discovery URLs based on https://modelcontextprotocol.io/specification/draft/basic/authorization#authorization-server-metadata-discovery
+        tenant = parsed.path.rstrip("/")
+        urls.extend(
+            [
+                urllib.parse.urljoin(
+                    base_url,
+                    f"/.well-known/oauth-authorization-server{tenant}",
+                ),
+                urllib.parse.urljoin(
+                    base_url, f"/.well-known/openid-configuration{tenant}"
+                ),
+                urllib.parse.urljoin(
+                    base_url, f"{tenant}/.well-known/openid-configuration"
+                ),
+            ]
+        )
+
+    urls.extend(
+        [
             urllib.parse.urljoin(base_url, "/.well-known/oauth-authorization-server"),
             urllib.parse.urljoin(base_url, "/.well-known/openid-configuration"),
         ]
-
-    if parsed.path and parsed.path != "/":
-        urls.append(
-            urllib.parse.urljoin(
-                base_url,
-                f"/.well-known/oauth-authorization-server{parsed.path.rstrip('/')}",
-            )
-        )
-        urls.append(
-            urllib.parse.urljoin(
-                base_url, f"/.well-known/openid-configuration{parsed.path.rstrip('/')}"
-            )
-        )
     )

     return urls
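Note: for an issuer URL that carries a tenant path, the reordered logic now emits the path-aware well-known documents first (per the MCP authorization spec draft) and falls back to the root-level documents. Illustratively, assuming base_url is just the scheme and host:

    urls = get_discovery_urls("https://auth.example.com/tenant1")
    # Expected order:
    #   https://auth.example.com/.well-known/oauth-authorization-server/tenant1
    #   https://auth.example.com/.well-known/openid-configuration/tenant1
    #   https://auth.example.com/tenant1/.well-known/openid-configuration
    #   https://auth.example.com/.well-known/oauth-authorization-server
    #   https://auth.example.com/.well-known/openid-configuration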
@@ -242,13 +299,12 @@ async def get_oauth_client_info_with_dynamic_client_registration(
         redirect_uris=[f"{redirect_base_url}/oauth/clients/{client_id}/callback"],
         grant_types=["authorization_code", "refresh_token"],
         response_types=["code"],
-        token_endpoint_auth_method="client_secret_post",
     )

     # Attempt to fetch OAuth server metadata to get registration endpoint & scopes
     discovery_urls = get_discovery_urls(oauth_server_url)
     for url in discovery_urls:
-        async with aiohttp.ClientSession() as session:
+        async with aiohttp.ClientSession(trust_env=True) as session:
             async with session.get(
                 url, ssl=AIOHTTP_CLIENT_SESSION_SSL
             ) as oauth_server_metadata_response:

@@ -265,6 +321,17 @@ async def get_oauth_client_info_with_dynamic_client_registration(
                         oauth_client_metadata.scope = " ".join(
                             oauth_server_metadata.scopes_supported
                         )
+
+                    if (
+                        oauth_server_metadata.token_endpoint_auth_methods_supported
+                        and oauth_client_metadata.token_endpoint_auth_method
+                        not in oauth_server_metadata.token_endpoint_auth_methods_supported
+                    ):
+                        # Pick the first supported method from the server
+                        oauth_client_metadata.token_endpoint_auth_method = oauth_server_metadata.token_endpoint_auth_methods_supported[
+                            0
+                        ]
+
                     break
                 except Exception as e:
                     log.error(f"Error parsing OAuth metadata from {url}: {e}")

@@ -284,7 +351,7 @@ async def get_oauth_client_info_with_dynamic_client_registration(
     )

     # Perform dynamic client registration and return client info
-    async with aiohttp.ClientSession() as session:
+    async with aiohttp.ClientSession(trust_env=True) as session:
         async with session.post(
             registration_url, json=registration_data, ssl=AIOHTTP_CLIENT_SESSION_SSL
         ) as oauth_client_registration_response:

@@ -292,10 +359,18 @@ async def get_oauth_client_info_with_dynamic_client_registration(
                 registration_response_json = (
                     await oauth_client_registration_response.json()
                 )

+                # The mcp package requires optional unset values to be None. If an empty string is passed, it gets validated and fails.
+                # This replaces all empty strings with None.
+                registration_response_json = {
+                    k: (None if v == "" else v)
+                    for k, v in registration_response_json.items()
+                }
                 oauth_client_info = OAuthClientInformationFull.model_validate(
                     {
                         **registration_response_json,
                         **{"issuer": oauth_server_metadata_url},
+                        **{"server_metadata": oauth_server_metadata},
                     }
                 )
                 log.info(
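Note: the comprehension above normalizes the registration response before pydantic validation, because the mcp models only accept None, not "", for unset optional fields such as URLs. Standalone:

    registration_response_json = {"client_id": "abc", "client_uri": "", "scope": "openid"}
    normalized = {k: (None if v == "" else v) for k, v in registration_response_json.items()}
    assert normalized == {"client_id": "abc", "client_uri": None, "scope": "openid"}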
@@ -331,20 +406,45 @@ class OAuthClientManager:
         self.clients = {}

     def add_client(self, client_id, oauth_client_info: OAuthClientInformationFull):
-        self.clients[client_id] = {
-            "client": self.oauth.register(
-                name=client_id,
-                client_id=oauth_client_info.client_id,
-                client_secret=oauth_client_info.client_secret,
-                client_kwargs=(
-                    {"scope": oauth_client_info.scope}
-                    if oauth_client_info.scope
-                    else {}
-                ),
-                server_metadata_url=(
-                    oauth_client_info.issuer if oauth_client_info.issuer else None
-                ),
-            ),
+        kwargs = {
+            "name": client_id,
+            "client_id": oauth_client_info.client_id,
+            "client_secret": oauth_client_info.client_secret,
+            "client_kwargs": {
+                **(
+                    {"scope": oauth_client_info.scope}
+                    if oauth_client_info.scope
+                    else {}
+                ),
+                **(
+                    {
+                        "token_endpoint_auth_method": oauth_client_info.token_endpoint_auth_method
+                    }
+                    if oauth_client_info.token_endpoint_auth_method
+                    else {}
+                ),
+            },
+            "server_metadata_url": (
+                oauth_client_info.issuer if oauth_client_info.issuer else None
+            ),
+        }
+
+        if (
+            oauth_client_info.server_metadata
+            and oauth_client_info.server_metadata.code_challenge_methods_supported
+        ):
+            if (
+                isinstance(
+                    oauth_client_info.server_metadata.code_challenge_methods_supported,
+                    list,
+                )
+                and "S256"
+                in oauth_client_info.server_metadata.code_challenge_methods_supported
+            ):
+                kwargs["code_challenge_method"] = "S256"
+
+        self.clients[client_id] = {
+            "client": self.oauth.register(**kwargs),
             "client_info": oauth_client_info,
         }
         return self.clients[client_id]

@@ -353,6 +453,82 @@ class OAuthClientManager:
         if client_id in self.clients:
             del self.clients[client_id]
             log.info(f"Removed OAuth client {client_id}")
+
+            if hasattr(self.oauth, "_clients"):
+                if client_id in self.oauth._clients:
+                    self.oauth._clients.pop(client_id, None)
+
+            if hasattr(self.oauth, "_registry"):
+                if client_id in self.oauth._registry:
+                    self.oauth._registry.pop(client_id, None)
+
+            return True
+
+    async def _preflight_authorization_url(
+        self, client, client_info: OAuthClientInformationFull
+    ) -> bool:
+        # TODO: Replace this logic with a more robust OAuth client registration validation
+        # Only perform preflight checks for Starlette OAuth clients
+        if not hasattr(client, "create_authorization_url"):
+            return True
+
+        redirect_uri = None
+        if client_info.redirect_uris:
+            redirect_uri = str(client_info.redirect_uris[0])
+
+        try:
+            auth_data = await client.create_authorization_url(redirect_uri=redirect_uri)
+            authorization_url = auth_data.get("url")
+
+            if not authorization_url:
+                return True
+        except Exception as e:
+            log.debug(
+                f"Skipping OAuth preflight for client {client_info.client_id}: {e}",
+            )
+            return True
+
+        try:
+            async with aiohttp.ClientSession(trust_env=True) as session:
+                async with session.get(
+                    authorization_url,
+                    allow_redirects=False,
+                    ssl=AIOHTTP_CLIENT_SESSION_SSL,
+                ) as resp:
+                    if resp.status < 400:
+                        return True
+                    response_text = await resp.text()
+
+                    error = None
+                    error_description = ""
+
+                    content_type = resp.headers.get("content-type", "")
+                    if "application/json" in content_type:
+                        try:
+                            payload = json.loads(response_text)
+                            error = payload.get("error")
+                            error_description = payload.get("error_description", "")
+                        except:
+                            pass
+                    else:
+                        error_description = response_text
+
+                    error_message = f"{error or ''} {error_description or ''}".lower()
+
+                    if any(
+                        keyword in error_message
+                        for keyword in ("invalid_client", "invalid client", "client id")
+                    ):
+                        log.warning(
+                            f"OAuth client preflight detected invalid registration for {client_info.client_id}: {error} {error_description}"
+                        )
+
+                        return False
+        except Exception as e:
+            log.debug(
+                f"Skipping OAuth preflight network check for client {client_info.client_id}: {e}"
+            )
+
         return True

     def get_client(self, client_id):

@@ -367,8 +543,8 @@ class OAuthClientManager:
         if client_id in self.clients:
             client = self.clients[client_id]
             return (
-                client.server_metadata_url
-                if hasattr(client, "server_metadata_url")
+                client._server_metadata_url
+                if hasattr(client, "_server_metadata_url")
                 else None
             )
         return None

@@ -543,7 +719,6 @@ class OAuthClientManager:
         client = self.get_client(client_id)
         if client is None:
             raise HTTPException(404)
-
         client_info = self.get_client_info(client_id)
         if client_info is None:
             raise HTTPException(404)

@@ -551,7 +726,8 @@ class OAuthClientManager:
         redirect_uri = (
             client_info.redirect_uris[0] if client_info.redirect_uris else None
         )
-        return await client.authorize_redirect(request, str(redirect_uri))
+        redirect_uri_str = str(redirect_uri) if redirect_uri else None
+        return await client.authorize_redirect(request, redirect_uri_str)

     async def handle_callback(self, request, client_id: str, user_id: str, response):
         client = self.get_client(client_id)

@@ -560,7 +736,18 @@ class OAuthClientManager:

         error_message = None
         try:
-            token = await client.authorize_access_token(request)
+            client_info = self.get_client_info(client_id)
+
+            auth_params = {}
+            if (
+                client_info
+                and hasattr(client_info, "client_id")
+                and hasattr(client_info, "client_secret")
+            ):
+                auth_params["client_id"] = client_info.client_id
+                auth_params["client_secret"] = client_info.client_secret
+
+            token = await client.authorize_access_token(request, **auth_params)
             if token:
                 try:
                     # Add timestamp for tracking

@@ -593,8 +780,14 @@ class OAuthClientManager:
                 error_message = "Failed to obtain OAuth token"
                 log.warning(error_message)
         except Exception as e:
-            error_message = "OAuth callback error"
-            log.warning(f"OAuth callback error: {e}")
+            error_message = _build_oauth_callback_error_message(e)
+            log.warning(
+                "OAuth callback error for user_id=%s client_id=%s: %s",
+                user_id,
+                client_id,
+                error_message,
+                exc_info=True,
+            )

         redirect_url = (
             str(request.app.state.config.WEBUI_URL or request.base_url)

@@ -602,7 +795,9 @@ class OAuthClientManager:

         if error_message:
             log.debug(error_message)
-            redirect_url = f"{redirect_url}/?error={error_message}"
+            redirect_url = (
+                f"{redirect_url}/?error={urllib.parse.quote_plus(error_message)}"
+            )
             return RedirectResponse(url=redirect_url, headers=response.headers)

         response = RedirectResponse(url=redirect_url, headers=response.headers)

@@ -615,8 +810,14 @@ class OAuthManager:
         self.app = app

         self._clients = {}
-        for _, provider_config in OAUTH_PROVIDERS.items():
-            provider_config["register"](self.oauth)
+        for name, provider_config in OAUTH_PROVIDERS.items():
+            if "register" not in provider_config:
+                log.error(f"OAuth provider {name} missing register function")
+                continue
+
+            client = provider_config["register"](self.oauth)
+            self._clients[name] = client

     def get_client(self, provider_name):
         if provider_name not in self._clients:

@@ -627,8 +828,8 @@ class OAuthManager:
         if provider_name in self._clients:
             client = self._clients[provider_name]
             return (
-                client.server_metadata_url
-                if hasattr(client, "server_metadata_url")
+                client._server_metadata_url
+                if hasattr(client, "_server_metadata_url")
                 else None
             )
         return None

@@ -825,11 +1026,21 @@ class OAuthManager:
                 for nested_claim in nested_claims:
                     claim_data = claim_data.get(nested_claim, {})

+            # Try flat claim structure as alternative
+            if not claim_data:
+                claim_data = user_data.get(oauth_claim, {})
+
             oauth_roles = []

             if isinstance(claim_data, list):
                 oauth_roles = claim_data
-            if isinstance(claim_data, str) or isinstance(claim_data, int):
+            elif isinstance(claim_data, str):
+                # Split by the configured separator if present
+                if OAUTH_ROLES_SEPARATOR and OAUTH_ROLES_SEPARATOR in claim_data:
+                    oauth_roles = claim_data.split(OAUTH_ROLES_SEPARATOR)
+                else:
+                    oauth_roles = [claim_data]
+            elif isinstance(claim_data, int):
                 oauth_roles = [str(claim_data)]

             log.debug(f"Oauth Roles claim: {oauth_claim}")

@@ -883,12 +1094,16 @@ class OAuthManager:
         if isinstance(claim_data, list):
             user_oauth_groups = claim_data
         elif isinstance(claim_data, str):
+            # Split by the configured separator if present
+            if OAUTH_GROUPS_SEPARATOR in claim_data:
+                user_oauth_groups = claim_data.split(OAUTH_GROUPS_SEPARATOR)
+            else:
                 user_oauth_groups = [claim_data]
         else:
             user_oauth_groups = []

         user_current_groups: list[GroupModel] = Groups.get_groups_by_member_id(user.id)
-        all_available_groups: list[GroupModel] = Groups.get_groups()
+        all_available_groups: list[GroupModel] = Groups.get_all_groups()

         # Create groups if they don't exist and creation is enabled
         if auth_manager_config.ENABLE_OAUTH_GROUP_CREATION:
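Note: with a separator configured, a single string claim can carry several group names (the roles branch above behaves the same way via OAUTH_ROLES_SEPARATOR). Illustratively, with a comma as the assumed separator:

    OAUTH_GROUPS_SEPARATOR = ","  # illustrative value; the real one comes from config

    claim_data = "engineering,platform,admins"
    if OAUTH_GROUPS_SEPARATOR in claim_data:
        user_oauth_groups = claim_data.split(OAUTH_GROUPS_SEPARATOR)
    else:
        user_oauth_groups = [claim_data]
    assert user_oauth_groups == ["engineering", "platform", "admins"]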
|
|
@ -932,7 +1147,7 @@ class OAuthManager:
|
||||||
|
|
||||||
# Refresh the list of all available groups if any were created
|
# Refresh the list of all available groups if any were created
|
||||||
if groups_created:
|
if groups_created:
|
||||||
all_available_groups = Groups.get_groups()
|
all_available_groups = Groups.get_all_groups()
|
||||||
log.debug("Refreshed list of all available groups after creation.")
|
log.debug("Refreshed list of all available groups after creation.")
|
||||||
|
|
||||||
log.debug(f"Oauth Groups claim: {oauth_claim}")
|
log.debug(f"Oauth Groups claim: {oauth_claim}")
|
||||||
|
|
@@ -953,23 +1168,21 @@ class OAuthManager:
                     log.debug(
                         f"Removing user from group {group_model.name} as it is no longer in their oauth groups"
                     )
-                    user_ids = group_model.user_ids
-                    user_ids = [i for i in user_ids if i != user.id]
+                    Groups.remove_users_from_group(group_model.id, [user.id])

                    # In case a group is created, but perms are never assigned to the group by hitting "save"
                    group_permissions = group_model.permissions
                    if not group_permissions:
                        group_permissions = default_permissions

-                    update_form = GroupUpdateForm(
-                        name=group_model.name,
-                        description=group_model.description,
-                        permissions=group_permissions,
-                        user_ids=user_ids,
-                    )
-                    Groups.update_group_by_id(
-                        id=group_model.id, form_data=update_form, overwrite=False
-                    )
+                    Groups.update_group_by_id(
+                        id=group_model.id,
+                        form_data=GroupUpdateForm(
+                            name=group_model.name,
+                            description=group_model.description,
+                            permissions=group_permissions,
+                        ),
+                        overwrite=False,
+                    )

             # Add user to new groups
@@ -985,22 +1198,21 @@ class OAuthManager:
                         f"Adding user to group {group_model.name} as it was found in their oauth groups"
                     )

-                    user_ids = group_model.user_ids
-                    user_ids.append(user.id)
+                    Groups.add_users_to_group(group_model.id, [user.id])

                    # In case a group is created, but perms are never assigned to the group by hitting "save"
                    group_permissions = group_model.permissions
                    if not group_permissions:
                        group_permissions = default_permissions

-                    update_form = GroupUpdateForm(
-                        name=group_model.name,
-                        description=group_model.description,
-                        permissions=group_permissions,
-                        user_ids=user_ids,
-                    )
-                    Groups.update_group_by_id(
-                        id=group_model.id, form_data=update_form, overwrite=False
-                    )
+                    Groups.update_group_by_id(
+                        id=group_model.id,
+                        form_data=GroupUpdateForm(
+                            name=group_model.name,
+                            description=group_model.description,
+                            permissions=group_permissions,
+                        ),
+                        overwrite=False,
+                    )

     async def _process_picture_url(
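Both hunks above swap manual user_ids surgery for dedicated membership helpers, so the group row no longer has to be rewritten wholesale just to add or drop one member. A before/after sketch (method names taken from the diff; surrounding objects assumed):

    # before: read-modify-write of the full membership list
    user_ids = [i for i in group_model.user_ids if i != user.id]
    # ...persisted via GroupUpdateForm(user_ids=user_ids)

    # after: targeted membership operations
    Groups.remove_users_from_group(group_model.id, [user.id])
    Groups.add_users_to_group(group_model.id, [user.id])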
@@ -1067,10 +1279,26 @@ class OAuthManager:
         error_message = None
         try:
             client = self.get_client(provider)

+            auth_params = {}
+
+            if client:
+                if (
+                    hasattr(client, "client_id")
+                    and OAUTH_ACCESS_TOKEN_REQUEST_INCLUDE_CLIENT_ID
+                ):
+                    auth_params["client_id"] = client.client_id
+
             try:
-                token = await client.authorize_access_token(request)
+                token = await client.authorize_access_token(request, **auth_params)
             except Exception as e:
-                log.warning(f"OAuth callback error: {e}")
+                detailed_error = _build_oauth_callback_error_message(e)
+                log.warning(
+                    "OAuth callback error during authorize_access_token for provider %s: %s",
+                    provider,
+                    detailed_error,
+                    exc_info=True,
+                )
                 raise HTTPException(400, detail=ERROR_MESSAGES.INVALID_CRED)

             # Try to get userinfo from the token first, some providers include it there
@@ -1101,7 +1329,10 @@ class OAuthManager:
                 log.warning(f"OAuth callback failed, sub is missing: {user_data}")
                 raise HTTPException(400, detail=ERROR_MESSAGES.INVALID_CRED)

-            provider_sub = f"{provider}@{sub}"
+            oauth_data = {}
+            oauth_data[provider] = {
+                "sub": sub,
+            }

             # Email extraction
             email_claim = auth_manager_config.OAUTH_EMAIL_CLAIM
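The identity key moves from one flat provider_sub string to a per-provider mapping, which leaves room to link multiple providers to one account. Values here are illustrative:

    # before: provider and subject fused into a single column
    provider_sub = "google@1234567890"

    # after: a dict keyed by provider name
    oauth_data = {
        "google": {"sub": "1234567890"},
    }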
@@ -1147,11 +1378,13 @@ class OAuthManager:
                     except Exception as e:
                         log.warning(f"Error fetching GitHub email: {e}")
                         raise HTTPException(400, detail=ERROR_MESSAGES.INVALID_CRED)
+                elif ENABLE_OAUTH_EMAIL_FALLBACK:
+                    email = f"{provider}@{sub}.local"
                 else:
                     log.warning(f"OAuth callback failed, email is missing: {user_data}")
                     raise HTTPException(400, detail=ERROR_MESSAGES.INVALID_CRED)
-            email = email.lower()

+            email = email.lower()
             # If allowed domains are configured, check if the email domain is in the list
             if (
                 "*" not in auth_manager_config.OAUTH_ALLOWED_DOMAINS
@@ -1164,7 +1397,7 @@ class OAuthManager:
                 raise HTTPException(400, detail=ERROR_MESSAGES.INVALID_CRED)

             # Check if the user exists
-            user = Users.get_user_by_oauth_sub(provider_sub)
+            user = Users.get_user_by_oauth_sub(provider, sub)
             if not user:
                 # If the user does not exist, check if merging is enabled
                 if auth_manager_config.OAUTH_MERGE_ACCOUNTS_BY_EMAIL:
@@ -1172,12 +1405,15 @@ class OAuthManager:
                     user = Users.get_user_by_email(email)
                     if user:
                         # Update the user with the new oauth sub
-                        Users.update_user_oauth_sub_by_id(user.id, provider_sub)
+                        Users.update_user_oauth_by_id(user.id, provider, sub)

             if user:
                 determined_role = self.get_user_role(user, user_data)
                 if user.role != determined_role:
                     Users.update_user_role_by_id(user.id, determined_role)
+                    # Update the user object in memory as well,
+                    # to avoid problems with the ENABLE_OAUTH_GROUP_MANAGEMENT check below
+                    user.role = determined_role
                 # Update profile picture if enabled and different from current
                 if auth_manager_config.OAUTH_UPDATE_PICTURE_ON_LOGIN:
                     picture_claim = auth_manager_config.OAUTH_PICTURE_CLAIM
|
@ -1228,7 +1464,7 @@ class OAuthManager:
|
||||||
name=name,
|
name=name,
|
||||||
profile_image_url=picture_url,
|
profile_image_url=picture_url,
|
||||||
role=self.get_user_role(None, user_data),
|
role=self.get_user_role(None, user_data),
|
||||||
oauth_sub=provider_sub,
|
oauth=oauth_data,
|
||||||
)
|
)
|
||||||
|
|
||||||
if auth_manager_config.WEBHOOK_URL:
|
if auth_manager_config.WEBHOOK_URL:
|
||||||
|
|
@@ -1242,6 +1478,12 @@ class OAuthManager:
                             "user": user.model_dump_json(exclude_none=True),
                         },
                     )

+                apply_default_group_assignment(
+                    request.app.state.config.DEFAULT_GROUP_ID,
+                    user.id,
+                )
+
             else:
                 raise HTTPException(
                     status.HTTP_403_FORBIDDEN,
backend/open_webui/utils/payload.py

@@ -2,6 +2,7 @@ from open_webui.utils.task import prompt_template, prompt_variables_template
 from open_webui.utils.misc import (
     deep_update,
     add_or_update_system_message,
+    replace_system_message_content,
 )

 from typing import Callable, Optional
@@ -10,7 +11,11 @@ import json

 # inplace function: form_data is modified
 def apply_system_prompt_to_body(
-    system: Optional[str], form_data: dict, metadata: Optional[dict] = None, user=None
+    system: Optional[str],
+    form_data: dict,
+    metadata: Optional[dict] = None,
+    user=None,
+    replace: bool = False,
 ) -> dict:
     if not system:
         return form_data
@@ -24,9 +29,15 @@ def apply_system_prompt_to_body(
     # Legacy (API Usage)
     system = prompt_template(system, user)

+    if replace:
+        form_data["messages"] = replace_system_message_content(
+            system, form_data.get("messages", [])
+        )
+    else:
-    form_data["messages"] = add_or_update_system_message(
-        system, form_data.get("messages", [])
-    )
+        form_data["messages"] = add_or_update_system_message(
+            system, form_data.get("messages", [])
+        )

     return form_data
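A sketch of how the new replace flag changes the call's behavior (message contents illustrative; the function mutates form_data in place, as its comment notes):

    body = {
        "messages": [
            {"role": "system", "content": "old prompt"},
            {"role": "user", "content": "hi"},
        ]
    }

    # default: merge/update the system message (legacy behavior)
    apply_system_prompt_to_body("new prompt", body)

    # replace=True: overwrite the system message content outright
    apply_system_prompt_to_body("new prompt", body, replace=True)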
@@ -286,6 +297,10 @@ def convert_payload_openai_to_ollama(openai_payload: dict) -> dict:
     if "tools" in openai_payload:
         ollama_payload["tools"] = openai_payload["tools"]

+    if "max_tokens" in openai_payload:
+        ollama_payload["num_predict"] = openai_payload["max_tokens"]
+        del openai_payload["max_tokens"]
+
     # If there are advanced parameters in the payload, format them in Ollama's options field
     if openai_payload.get("options"):
         ollama_payload["options"] = openai_payload["options"]
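OpenAI's max_tokens has no top-level Ollama equivalent; it maps to num_predict and is removed from the source payload so it is not forwarded twice. Illustrative call:

    openai_payload = {"model": "llama3", "max_tokens": 256, "messages": []}
    ollama_payload = convert_payload_openai_to_ollama(openai_payload)
    # ollama_payload["num_predict"] == 256; "max_tokens" is gone from openai_payload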
backend/open_webui/utils/rate_limit.py (new file, 139 lines)
@@ -0,0 +1,139 @@
+import time
+from typing import Optional, Dict
+from open_webui.env import REDIS_KEY_PREFIX
+
+
+class RateLimiter:
+    """
+    General-purpose rate limiter using Redis with a rolling window strategy.
+    Falls back to in-memory storage if Redis is not available.
+    """
+
+    # In-memory fallback storage
+    _memory_store: Dict[str, Dict[int, int]] = {}
+
+    def __init__(
+        self,
+        redis_client,
+        limit: int,
+        window: int,
+        bucket_size: int = 60,
+        enabled: bool = True,
+    ):
+        """
+        :param redis_client: Redis client instance or None
+        :param limit: Max allowed events in the window
+        :param window: Time window in seconds
+        :param bucket_size: Bucket resolution
+        :param enabled: Turn on/off rate limiting globally
+        """
+        self.r = redis_client
+        self.limit = limit
+        self.window = window
+        self.bucket_size = bucket_size
+        self.num_buckets = window // bucket_size
+        self.enabled = enabled
+
+    def _bucket_key(self, key: str, bucket_index: int) -> str:
+        return f"{REDIS_KEY_PREFIX}:ratelimit:{key.lower()}:{bucket_index}"
+
+    def _current_bucket(self) -> int:
+        return int(time.time()) // self.bucket_size
+
+    def _redis_available(self) -> bool:
+        return self.r is not None
+
+    def is_limited(self, key: str) -> bool:
+        """
+        Main rate-limit check.
+        Gracefully handles missing or failing Redis.
+        """
+        if not self.enabled:
+            return False
+
+        if self._redis_available():
+            try:
+                return self._is_limited_redis(key)
+            except Exception:
+                return self._is_limited_memory(key)
+        else:
+            return self._is_limited_memory(key)
+
+    def get_count(self, key: str) -> int:
+        if not self.enabled:
+            return 0
+
+        if self._redis_available():
+            try:
+                return self._get_count_redis(key)
+            except Exception:
+                return self._get_count_memory(key)
+        else:
+            return self._get_count_memory(key)
+
+    def remaining(self, key: str) -> int:
+        used = self.get_count(key)
+        return max(0, self.limit - used)
+
+    def _is_limited_redis(self, key: str) -> bool:
+        now_bucket = self._current_bucket()
+        bucket_key = self._bucket_key(key, now_bucket)
+
+        attempts = self.r.incr(bucket_key)
+        if attempts == 1:
+            self.r.expire(bucket_key, self.window + self.bucket_size)
+
+        # Collect buckets
+        buckets = [
+            self._bucket_key(key, now_bucket - i) for i in range(self.num_buckets + 1)
+        ]
+
+        counts = self.r.mget(buckets)
+        total = sum(int(c) for c in counts if c)
+
+        return total > self.limit
+
+    def _get_count_redis(self, key: str) -> int:
+        now_bucket = self._current_bucket()
+        buckets = [
+            self._bucket_key(key, now_bucket - i) for i in range(self.num_buckets + 1)
+        ]
+        counts = self.r.mget(buckets)
+        return sum(int(c) for c in counts if c)
+
+    def _is_limited_memory(self, key: str) -> bool:
+        now_bucket = self._current_bucket()
+
+        # Init storage
+        if key not in self._memory_store:
+            self._memory_store[key] = {}
+
+        store = self._memory_store[key]
+
+        # Increment bucket
+        store[now_bucket] = store.get(now_bucket, 0) + 1
+
+        # Drop expired buckets
+        min_bucket = now_bucket - self.num_buckets
+        expired = [b for b in store if b < min_bucket]
+        for b in expired:
+            del store[b]
+
+        # Count totals
+        total = sum(store.values())
+        return total > self.limit
+
+    def _get_count_memory(self, key: str) -> int:
+        now_bucket = self._current_bucket()
+        if key not in self._memory_store:
+            return 0
+
+        store = self._memory_store[key]
+        min_bucket = now_bucket - self.num_buckets
+
+        # Remove expired
+        expired = [b for b in store if b < min_bucket]
+        for b in expired:
+            del store[b]
+
+        return sum(store.values())
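A usage sketch for the new limiter (key and limits illustrative; note that is_limited() both records the event and checks the rolling total, and passing redis_client=None exercises the in-memory fallback):

    limiter = RateLimiter(
        redis_client=None,  # no Redis: falls back to the in-process store
        limit=5,            # at most 5 events...
        window=300,         # ...per rolling 5-minute window
        bucket_size=60,     # counted in 1-minute buckets
    )

    key = "login:user@example.com"
    if limiter.is_limited(key):
        print("Too many attempts, try again later")
    print(limiter.remaining(key))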
@ -5,7 +5,13 @@ import logging
|
||||||
|
|
||||||
import redis
|
import redis
|
||||||
|
|
||||||
from open_webui.env import REDIS_SENTINEL_MAX_RETRY_COUNT
|
from open_webui.env import (
|
||||||
|
REDIS_CLUSTER,
|
||||||
|
REDIS_SENTINEL_HOSTS,
|
||||||
|
REDIS_SENTINEL_MAX_RETRY_COUNT,
|
||||||
|
REDIS_SENTINEL_PORT,
|
||||||
|
REDIS_URL,
|
||||||
|
)
|
||||||
|
|
||||||
log = logging.getLogger(__name__)
|
log = logging.getLogger(__name__)
|
||||||
|
|
||||||
|
|
@ -108,6 +114,21 @@ def parse_redis_service_url(redis_url):
|
||||||
}
|
}
|
||||||
|
|
||||||
|
|
||||||
|
def get_redis_client(async_mode=False):
|
||||||
|
try:
|
||||||
|
return get_redis_connection(
|
||||||
|
redis_url=REDIS_URL,
|
||||||
|
redis_sentinels=get_sentinels_from_env(
|
||||||
|
REDIS_SENTINEL_HOSTS, REDIS_SENTINEL_PORT
|
||||||
|
),
|
||||||
|
redis_cluster=REDIS_CLUSTER,
|
||||||
|
async_mode=async_mode,
|
||||||
|
)
|
||||||
|
except Exception as e:
|
||||||
|
log.debug(f"Failed to get Redis client: {e}")
|
||||||
|
return None
|
||||||
|
|
||||||
|
|
||||||
def get_redis_connection(
|
def get_redis_connection(
|
||||||
redis_url,
|
redis_url,
|
||||||
redis_sentinels,
|
redis_sentinels,
|
||||||
|
|
|
||||||
|
|
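The new helper wraps connection setup (URL, sentinels, cluster mode) behind one call and degrades to None instead of raising, so callers such as the RateLimiter above can treat Redis as optional. Sketch:

    redis_client = get_redis_client()  # configured client, or None on failure
    limiter = RateLimiter(redis_client, limit=5, window=300)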
backend/open_webui/utils/task.py

@@ -208,20 +208,21 @@ def rag_template(template: str, context: str, query: str):
     if "[query]" in context:
         query_placeholder = "{{QUERY" + str(uuid.uuid4()) + "}}"
         template = template.replace("[query]", query_placeholder)
-        query_placeholders.append(query_placeholder)
+        query_placeholders.append((query_placeholder, "[query]"))

     if "{{QUERY}}" in context:
         query_placeholder = "{{QUERY" + str(uuid.uuid4()) + "}}"
         template = template.replace("{{QUERY}}", query_placeholder)
-        query_placeholders.append(query_placeholder)
+        query_placeholders.append((query_placeholder, "{{QUERY}}"))

     template = template.replace("[context]", context)
     template = template.replace("{{CONTEXT}}", context)

     template = template.replace("[query]", query)
     template = template.replace("{{QUERY}}", query)

-    for query_placeholder in query_placeholders:
-        template = template.replace(query_placeholder, query)
+    for query_placeholder, original_placeholder in query_placeholders:
+        template = template.replace(query_placeholder, original_placeholder)

     return template
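The fix closes a subtle substitution bug: when a retrieved document itself contains the literal "[query]" or "{{QUERY}}", the old final pass replaced the temporary UUID placeholder with the user's query, injecting it into the document body. The tuple now remembers the original literal and restores it instead. Worked example (values illustrative):

    template = "Context: [context]\nQuestion: [query]"
    context = "The doc literally says [query] here."
    query = "What is X?"

    # old: the document's "[query]" also became "What is X?"
    # new: the document keeps its literal "[query]" text; only the
    #      template's own [query] slot receives the user's query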
backend/open_webui/utils/telemetry/metrics.py

@@ -45,7 +45,6 @@ from open_webui.env import (
     OTEL_METRICS_OTLP_SPAN_EXPORTER,
     OTEL_METRICS_EXPORTER_OTLP_INSECURE,
 )
-from open_webui.socket.main import get_active_user_ids
 from open_webui.models.users import Users

 _EXPORT_INTERVAL_MILLIS = 10_000  # 10 seconds
@@ -99,6 +98,9 @@ def _build_meter_provider(resource: Resource) -> MeterProvider:
         View(
             instrument_name="webui.users.active",
         ),
+        View(
+            instrument_name="webui.users.active.today",
+        ),
     ]

     provider = MeterProvider(
@@ -132,7 +134,7 @@ def setup_metrics(app: FastAPI, resource: Resource) -> None:
     ) -> Sequence[metrics.Observation]:
         return [
             metrics.Observation(
-                value=len(get_active_user_ids()),
+                value=Users.get_active_user_count(),
             )
         ]

@@ -159,6 +161,18 @@ def setup_metrics(app: FastAPI, resource: Resource) -> None:
         callbacks=[observe_active_users],
     )

+    def observe_users_active_today(
+        options: metrics.CallbackOptions,
+    ) -> Sequence[metrics.Observation]:
+        return [metrics.Observation(value=Users.get_num_users_active_today())]
+
+    meter.create_observable_gauge(
+        name="webui.users.active.today",
+        description="Number of users active since midnight today",
+        unit="users",
+        callbacks=[observe_users_active_today],
+    )
+
     # FastAPI middleware
     @app.middleware("http")
     async def _metrics_middleware(request: Request, call_next):
backend/open_webui/utils/tools.py

@@ -34,6 +34,7 @@ from langchain_core.utils.function_calling import (
 )


+from open_webui.utils.misc import is_string_allowed
 from open_webui.models.tools import Tools
 from open_webui.models.users import UserModel
 from open_webui.utils.plugin import load_tool_module_by_id
@@ -85,9 +86,26 @@ def get_async_tool_function_and_apply_extra_params(
     update_wrapper(new_function, function)
     new_function.__signature__ = new_sig

+    new_function.__function__ = function  # type: ignore
+    new_function.__extra_params__ = extra_params  # type: ignore
+
     return new_function


+def get_updated_tool_function(function: Callable, extra_params: dict):
+    # Get the original function and merge updated params
+    __function__ = getattr(function, "__function__", None)
+    __extra_params__ = getattr(function, "__extra_params__", None)
+
+    if __function__ is not None and __extra_params__ is not None:
+        return get_async_tool_function_and_apply_extra_params(
+            __function__,
+            {**__extra_params__, **extra_params},
+        )
+
+    return function
+
+
 async def get_tools(
     request: Request, tool_ids: list[str], user: UserModel, extra_params: dict
 ) -> dict[str, dict]:
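Stashing the original callable and its bound extras on the wrapper is what makes get_updated_tool_function work: it can rebuild the wrapper with merged parameters instead of reloading the tool module. A minimal sketch (tool function and params illustrative):

    async def my_tool(query: str, __user__=None):
        return query

    bound = get_async_tool_function_and_apply_extra_params(
        my_tool, {"__user__": {"id": "u1"}}
    )
    # later, refresh the injected context without touching the module:
    rebound = get_updated_tool_function(bound, {"__user__": {"id": "u2"}})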
@@ -132,13 +150,28 @@ async def get_tools(
             )

             specs = tool_server_data.get("specs", [])
+            function_name_filter_list = tool_server_connection.get(
+                "config", {}
+            ).get("function_name_filter_list", "")
+
+            if isinstance(function_name_filter_list, str):
+                function_name_filter_list = function_name_filter_list.split(",")
+
             for spec in specs:
                 function_name = spec["name"]
+                if function_name_filter_list:
+                    if not is_string_allowed(
+                        function_name, function_name_filter_list
+                    ):
+                        # Skip this function
+                        continue

                 auth_type = tool_server_connection.get("auth_type", "bearer")

                 cookies = {}
-                headers = {}
+                headers = {
+                    "Content-Type": "application/json",
+                }

                 if auth_type == "bearer":
                     headers["Authorization"] = (
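The connection config's function_name_filter_list is a comma-separated string that is split into a list; every spec name must then pass is_string_allowed before the tool is registered. A sketch with illustrative values (the exact matching rules of is_string_allowed are not shown in this diff):

    filter_list = "get_weather,get_time".split(",")  # ["get_weather", "get_time"]
    for spec in specs:
        if filter_list and not is_string_allowed(spec["name"], filter_list):
            continue  # function is skipped, never exposed to the model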
@@ -160,7 +193,10 @@ async def get_tools(
                         f"Bearer {oauth_token.get('access_token', '')}"
                     )

-                headers["Content-Type"] = "application/json"
+                connection_headers = tool_server_connection.get("headers", None)
+                if connection_headers and isinstance(connection_headers, dict):
+                    for key, value in connection_headers.items():
+                        headers[key] = value

                 def make_tool_function(
                     function_name, tool_server_data, headers
@@ -215,14 +251,16 @@ async def get_tools(
             module, _ = load_tool_module_by_id(tool_id)
             request.app.state.TOOLS[tool_id] = module

-        extra_params["__id__"] = tool_id
+        __user__ = {
+            **extra_params["__user__"],
+        }

         # Set valves for the tool
         if hasattr(module, "valves") and hasattr(module, "Valves"):
             valves = Tools.get_tool_valves_by_id(tool_id) or {}
             module.valves = module.Valves(**valves)
         if hasattr(module, "UserValves"):
-            extra_params["__user__"]["valves"] = module.UserValves(  # type: ignore
+            __user__["valves"] = module.UserValves(  # type: ignore
                 **Tools.get_user_valves_by_id_and_user_id(tool_id, user.id)
             )

@@ -244,7 +282,12 @@ async def get_tools(
             function_name = spec["name"]
             tool_function = getattr(module, function_name)
             callable = get_async_tool_function_and_apply_extra_params(
-                tool_function, extra_params
+                tool_function,
+                {
+                    **extra_params,
+                    "__id__": tool_id,
+                    "__user__": __user__,
+                },
             )

             # TODO: Support Pydantic models as parameters
|
||||||
return tool_servers
|
return tool_servers
|
||||||
|
|
||||||
|
|
||||||
async def get_tool_server_data(token: str, url: str) -> Dict[str, Any]:
|
async def get_tool_server_data(url: str, headers: Optional[dict]) -> Dict[str, Any]:
|
||||||
headers = {
|
_headers = {
|
||||||
"Accept": "application/json",
|
"Accept": "application/json",
|
||||||
"Content-Type": "application/json",
|
"Content-Type": "application/json",
|
||||||
}
|
}
|
||||||
if token:
|
|
||||||
headers["Authorization"] = f"Bearer {token}"
|
if headers:
|
||||||
|
_headers.update(headers)
|
||||||
|
|
||||||
error = None
|
error = None
|
||||||
try:
|
try:
|
||||||
timeout = aiohttp.ClientTimeout(total=AIOHTTP_CLIENT_TIMEOUT_TOOL_SERVER_DATA)
|
timeout = aiohttp.ClientTimeout(total=AIOHTTP_CLIENT_TIMEOUT_TOOL_SERVER_DATA)
|
||||||
async with aiohttp.ClientSession(timeout=timeout, trust_env=True) as session:
|
async with aiohttp.ClientSession(timeout=timeout, trust_env=True) as session:
|
||||||
async with session.get(
|
async with session.get(
|
||||||
url, headers=headers, ssl=AIOHTTP_CLIENT_SESSION_TOOL_SERVER_SSL
|
url, headers=_headers, ssl=AIOHTTP_CLIENT_SESSION_TOOL_SERVER_SSL
|
||||||
) as response:
|
) as response:
|
||||||
if response.status != 200:
|
if response.status != 200:
|
||||||
error_body = await response.json()
|
error_body = await response.json()
|
||||||
|
|
@ -627,7 +671,10 @@ async def get_tool_servers_data(servers: List[Dict[str, Any]]) -> List[Dict[str,
|
||||||
openapi_path = server.get("path", "openapi.json")
|
openapi_path = server.get("path", "openapi.json")
|
||||||
spec_url = get_tool_server_url(server_url, openapi_path)
|
spec_url = get_tool_server_url(server_url, openapi_path)
|
||||||
# Fetch from URL
|
# Fetch from URL
|
||||||
task = get_tool_server_data(token, spec_url)
|
task = get_tool_server_data(
|
||||||
|
spec_url,
|
||||||
|
{"Authorization": f"Bearer {token}"} if token else None,
|
||||||
|
)
|
||||||
elif spec_type == "json" and server.get("spec", ""):
|
elif spec_type == "json" and server.get("spec", ""):
|
||||||
# Use provided JSON spec
|
# Use provided JSON spec
|
||||||
spec_json = None
|
spec_json = None
|
||||||
|
|
@@ -748,17 +795,13 @@ async def execute_tool_server(
     if operation.get("requestBody", {}).get("content"):
         if params:
             body_params = params
-        else:
-            raise Exception(
-                f"Request body expected for operation '{name}' but none found."
-            )

     async with aiohttp.ClientSession(
         trust_env=True, timeout=aiohttp.ClientTimeout(total=AIOHTTP_CLIENT_TIMEOUT)
     ) as session:
         request_method = getattr(session, http_method.lower())

-        if http_method in ["post", "put", "patch"]:
+        if http_method in ["post", "put", "patch", "delete"]:
             async with request_method(
                 final_url,
                 json=body_params,
backend/open_webui/utils/webhook.py

@@ -51,7 +51,7 @@ async def post_webhook(name: str, url: str, message: str, event_data: dict) -> bool:
     payload = {**event_data}

     log.debug(f"payload: {payload}")
-    async with aiohttp.ClientSession() as session:
+    async with aiohttp.ClientSession(trust_env=True) as session:
         async with session.post(url, json=payload) as r:
             r_text = await r.text()
             r.raise_for_status()
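A note on the one-line webhook change: aiohttp only honors proxy environment variables (HTTP_PROXY/HTTPS_PROXY/NO_PROXY) when the session is created with trust_env=True, so webhook delivery now works from deployments that must egress through a proxy:

    # before: proxy env vars ignored for the webhook POST
    session = aiohttp.ClientSession()
    # after: proxy settings from the environment are respected
    session = aiohttp.ClientSession(trust_env=True)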
Some files were not shown because too many files have changed in this diff.