diff --git a/Dockerfile b/Dockerfile index 8104c72729..1e8361fd25 100644 --- a/Dockerfile +++ b/Dockerfile @@ -22,6 +22,9 @@ ARG OLLAMA_API_BASE_URL='/ollama/api' ENV ENV=prod ENV OLLAMA_API_BASE_URL $OLLAMA_API_BASE_URL +ENV WEBUI_AUTH "" +ENV WEBUI_DB_URL "" +ENV WEBUI_JWT_SECRET_KEY "SECRET_KEY" WORKDIR /app COPY --from=build /app/build /app/build diff --git a/backend/apps/ollama/main.py b/backend/apps/ollama/main.py index bb211cfe43..53daefac0a 100644 --- a/backend/apps/ollama/main.py +++ b/backend/apps/ollama/main.py @@ -1,4 +1,4 @@ -from flask import Flask, request, Response +from flask import Flask, request, Response, jsonify from flask_cors import CORS @@ -6,7 +6,10 @@ import requests import json -from config import OLLAMA_API_BASE_URL +from apps.web.models.users import Users +from constants import ERROR_MESSAGES +from utils.utils import extract_token_from_auth_header +from config import OLLAMA_API_BASE_URL, WEBUI_AUTH app = Flask(__name__) CORS( @@ -22,12 +25,40 @@ TARGET_SERVER_URL = OLLAMA_API_BASE_URL def proxy(path): # Combine the base URL of the target server with the requested path target_url = f"{TARGET_SERVER_URL}/{path}" - print(target_url) + print(path) # Get data from the original request data = request.get_data() headers = dict(request.headers) + # Basic RBAC support + if WEBUI_AUTH: + if "Authorization" in headers: + token = extract_token_from_auth_header(headers["Authorization"]) + user = Users.get_user_by_token(token) + if user: + # Only user and admin roles can access + if user.role in ["user", "admin"]: + if path in ["pull", "delete", "push", "copy", "create"]: + # Only admin role can perform actions above + if user.role == "admin": + pass + else: + return ( + jsonify({"detail": ERROR_MESSAGES.ACCESS_PROHIBITED}), + 401, + ) + else: + pass + else: + return jsonify({"detail": ERROR_MESSAGES.ACCESS_PROHIBITED}), 401 + else: + return jsonify({"detail": ERROR_MESSAGES.UNAUTHORIZED}), 401 + else: + return jsonify({"detail": ERROR_MESSAGES.UNAUTHORIZED}), 401 + else: + pass + # Make a request to the target server target_response = requests.request( method=request.method, diff --git a/backend/apps/web/main.py b/backend/apps/web/main.py new file mode 100644 index 0000000000..0380f0c34e --- /dev/null +++ b/backend/apps/web/main.py @@ -0,0 +1,26 @@ +from fastapi import FastAPI, Request, Depends, HTTPException +from fastapi.middleware.cors import CORSMiddleware + +from apps.web.routers import auths, users +from config import WEBUI_VERSION, WEBUI_AUTH + +app = FastAPI() + +origins = ["*"] + +app.add_middleware( + CORSMiddleware, + allow_origins=origins, + allow_credentials=True, + allow_methods=["*"], + allow_headers=["*"], +) + + +app.include_router(auths.router, prefix="/auths", tags=["auths"]) +app.include_router(users.router, prefix="/users", tags=["users"]) + + +@app.get("/") +async def get_status(): + return {"status": True, "version": WEBUI_VERSION, "auth": WEBUI_AUTH} diff --git a/backend/apps/web/models/auths.py b/backend/apps/web/models/auths.py new file mode 100644 index 0000000000..41c82efdbe --- /dev/null +++ b/backend/apps/web/models/auths.py @@ -0,0 +1,103 @@ +from pydantic import BaseModel +from typing import List, Union, Optional +import time +import uuid + + +from apps.web.models.users import UserModel, Users +from utils.utils import ( + verify_password, + get_password_hash, + bearer_scheme, + create_token, +) + +import config + +DB = config.DB + +#################### +# DB MODEL +#################### + + +class AuthModel(BaseModel): + id: str + email: str + password: str + active: bool = True + + +#################### +# Forms +#################### + + +class Token(BaseModel): + token: str + token_type: str + + +class UserResponse(BaseModel): + id: str + email: str + name: str + role: str + profile_image_url: str + + +class SigninResponse(Token, UserResponse): + pass + + +class SigninForm(BaseModel): + email: str + password: str + + +class SignupForm(BaseModel): + name: str + email: str + password: str + + +class AuthsTable: + def __init__(self, db): + self.db = db + self.table = db.auths + + def insert_new_auth( + self, email: str, password: str, name: str, role: str = "pending" + ) -> Optional[UserModel]: + print("insert_new_auth") + + id = str(uuid.uuid4()) + + auth = AuthModel( + **{"id": id, "email": email, "password": password, "active": True} + ) + result = self.table.insert_one(auth.model_dump()) + user = Users.insert_new_user(id, name, email, role) + + print(result, user) + if result and user: + return user + else: + return None + + def authenticate_user(self, email: str, password: str) -> Optional[UserModel]: + print("authenticate_user") + + auth = self.table.find_one({"email": email, "active": True}) + + if auth: + if verify_password(password, auth["password"]): + user = self.db.users.find_one({"id": auth["id"]}) + return UserModel(**user) + else: + return None + else: + return None + + +Auths = AuthsTable(DB) diff --git a/backend/apps/web/models/users.py b/backend/apps/web/models/users.py new file mode 100644 index 0000000000..4dc3fc7a6d --- /dev/null +++ b/backend/apps/web/models/users.py @@ -0,0 +1,97 @@ +from pydantic import BaseModel +from typing import List, Union, Optional +from pymongo import ReturnDocument +import time + +from utils.utils import decode_token +from utils.misc import get_gravatar_url + +from config import DB + +#################### +# User DB Schema +#################### + + +class UserModel(BaseModel): + id: str + name: str + email: str + role: str = "pending" + profile_image_url: str = "/user.png" + created_at: int # timestamp in epoch + + +#################### +# Forms +#################### + + +class UserRoleUpdateForm(BaseModel): + id: str + role: str + + +class UsersTable: + def __init__(self, db): + self.db = db + self.table = db.users + + def insert_new_user( + self, id: str, name: str, email: str, role: str = "pending" + ) -> Optional[UserModel]: + user = UserModel( + **{ + "id": id, + "name": name, + "email": email, + "role": role, + "profile_image_url": get_gravatar_url(email), + "created_at": int(time.time()), + } + ) + result = self.table.insert_one(user.model_dump()) + + if result: + return user + else: + return None + + def get_user_by_email(self, email: str) -> Optional[UserModel]: + user = self.table.find_one({"email": email}, {"_id": False}) + + if user: + return UserModel(**user) + else: + return None + + def get_user_by_token(self, token: str) -> Optional[UserModel]: + data = decode_token(token) + + if data != None and "email" in data: + return self.get_user_by_email(data["email"]) + else: + return None + + def get_users(self, skip: int = 0, limit: int = 50) -> List[UserModel]: + return [ + UserModel(**user) + for user in list( + self.table.find({}, {"_id": False}).skip(skip).limit(limit) + ) + ] + + def get_num_users(self) -> Optional[int]: + return self.table.count_documents({}) + + def update_user_by_id(self, id: str, updated: dict) -> Optional[UserModel]: + user = self.table.find_one_and_update( + {"id": id}, {"$set": updated}, return_document=ReturnDocument.AFTER + ) + return UserModel(**user) + + def update_user_role_by_id(self, id: str, role: str) -> Optional[UserModel]: + return self.update_user_by_id(id, {"role": role}) + + +Users = UsersTable(DB) diff --git a/backend/apps/web/routers/auths.py b/backend/apps/web/routers/auths.py new file mode 100644 index 0000000000..023e1914d2 --- /dev/null +++ b/backend/apps/web/routers/auths.py @@ -0,0 +1,111 @@ +from fastapi import Response +from fastapi import Depends, FastAPI, HTTPException, status +from datetime import datetime, timedelta +from typing import List, Union + +from fastapi import APIRouter +from pydantic import BaseModel +import time +import uuid + +from apps.web.models.auths import ( + SigninForm, + SignupForm, + UserResponse, + SigninResponse, + Auths, +) +from apps.web.models.users import Users + + +from utils.utils import ( + get_password_hash, + bearer_scheme, + create_token, +) +from utils.misc import get_gravatar_url +from constants import ERROR_MESSAGES + + +router = APIRouter() + +############################ +# GetSessionUser +############################ + + +@router.get("/", response_model=UserResponse) +async def get_session_user(cred=Depends(bearer_scheme)): + token = cred.credentials + user = Users.get_user_by_token(token) + if user: + return { + "id": user.id, + "email": user.email, + "name": user.name, + "role": user.role, + "profile_image_url": user.profile_image_url, + } + else: + raise HTTPException( + status_code=status.HTTP_401_UNAUTHORIZED, + detail=ERROR_MESSAGES.INVALID_TOKEN, + ) + + +############################ +# SignIn +############################ + + +@router.post("/signin", response_model=SigninResponse) +async def signin(form_data: SigninForm): + user = Auths.authenticate_user(form_data.email.lower(), form_data.password) + if user: + token = create_token(data={"email": user.email}) + + return { + "token": token, + "token_type": "Bearer", + "id": user.id, + "email": user.email, + "name": user.name, + "role": user.role, + "profile_image_url": user.profile_image_url, + } + else: + raise HTTPException(400, detail=ERROR_MESSAGES.INVALID_CRED) + + +############################ +# SignUp +############################ + + +@router.post("/signup", response_model=SigninResponse) +async def signup(form_data: SignupForm): + if not Users.get_user_by_email(form_data.email.lower()): + try: + role = "admin" if Users.get_num_users() == 0 else "pending" + hashed = get_password_hash(form_data.password) + user = Auths.insert_new_auth(form_data.email, hashed, form_data.name, role) + + if user: + token = create_token(data={"email": user.email}) + # response.set_cookie(key='token', value=token, httponly=True) + + return { + "token": token, + "token_type": "Bearer", + "id": user.id, + "email": user.email, + "name": user.name, + "role": user.role, + "profile_image_url": user.profile_image_url, + } + else: + raise HTTPException(500, detail=ERROR_MESSAGES.DEFAULT(err)) + except Exception as err: + raise HTTPException(500, detail=ERROR_MESSAGES.DEFAULT(err)) + else: + raise HTTPException(400, detail=ERROR_MESSAGES.DEFAULT()) diff --git a/backend/apps/web/routers/users.py b/backend/apps/web/routers/users.py new file mode 100644 index 0000000000..08437bd34b --- /dev/null +++ b/backend/apps/web/routers/users.py @@ -0,0 +1,75 @@ +from fastapi import Response +from fastapi import Depends, FastAPI, HTTPException, status +from datetime import datetime, timedelta +from typing import List, Union, Optional + +from fastapi import APIRouter +from pydantic import BaseModel +import time +import uuid + +from apps.web.models.users import UserModel, UserRoleUpdateForm, Users + +from utils.utils import ( + get_password_hash, + bearer_scheme, + create_token, +) +from constants import ERROR_MESSAGES + +router = APIRouter() + +############################ +# GetUsers +############################ + + +@router.get("/", response_model=List[UserModel]) +async def get_users(skip: int = 0, limit: int = 50, cred=Depends(bearer_scheme)): + token = cred.credentials + user = Users.get_user_by_token(token) + + if user: + if user.role == "admin": + return Users.get_users(skip, limit) + else: + raise HTTPException( + status_code=status.HTTP_403_FORBIDDEN, + detail=ERROR_MESSAGES.ACCESS_PROHIBITED, + ) + else: + raise HTTPException( + status_code=status.HTTP_401_UNAUTHORIZED, + detail=ERROR_MESSAGES.INVALID_TOKEN, + ) + + +############################ +# UpdateUserRole +############################ + + +@router.post("/update/role", response_model=Optional[UserModel]) +async def update_user_role(form_data: UserRoleUpdateForm, cred=Depends(bearer_scheme)): + token = cred.credentials + user = Users.get_user_by_token(token) + + if user: + if user.role == "admin": + if user.id != form_data.id: + return Users.update_user_role_by_id(form_data.id, form_data.role) + else: + raise HTTPException( + status_code=status.HTTP_403_FORBIDDEN, + detail=ERROR_MESSAGES.ACTION_PROHIBITED, + ) + else: + raise HTTPException( + status_code=status.HTTP_403_FORBIDDEN, + detail=ERROR_MESSAGES.ACCESS_PROHIBITED, + ) + else: + raise HTTPException( + status_code=status.HTTP_401_UNAUTHORIZED, + detail=ERROR_MESSAGES.INVALID_TOKEN, + ) diff --git a/backend/config.py b/backend/config.py index 2a33818d9a..33e5a25be0 100644 --- a/backend/config.py +++ b/backend/config.py @@ -1,11 +1,24 @@ -import sys -import os from dotenv import load_dotenv, find_dotenv +from pymongo import MongoClient +from constants import ERROR_MESSAGES + +from secrets import token_bytes +from base64 import b64encode +import os + load_dotenv(find_dotenv()) +#################################### +# ENV (dev,test,prod) +#################################### + ENV = os.environ.get("ENV", "dev") +#################################### +# OLLAMA_API_BASE_URL +#################################### + OLLAMA_API_BASE_URL = os.environ.get( "OLLAMA_API_BASE_URL", "http://localhost:11434/api" ) @@ -13,3 +26,41 @@ OLLAMA_API_BASE_URL = os.environ.get( if ENV == "prod": if OLLAMA_API_BASE_URL == "/ollama/api": OLLAMA_API_BASE_URL = "http://host.docker.internal:11434/api" + +#################################### +# WEBUI_VERSION +#################################### + +WEBUI_VERSION = os.environ.get("WEBUI_VERSION", "v1.0.0-alpha.11") + +#################################### +# WEBUI_AUTH +#################################### + + +WEBUI_AUTH = True if os.environ.get("WEBUI_AUTH", "TRUE") == "TRUE" else False + + +#################################### +# WEBUI_DB +#################################### + + +WEBUI_DB_URL = os.environ.get("WEBUI_DB_URL", "mongodb://root:root@localhost:27017/") + +if WEBUI_AUTH and WEBUI_DB_URL == "": + raise ValueError(ERROR_MESSAGES.ENV_VAR_NOT_FOUND) + + +DB_CLIENT = MongoClient(f"{WEBUI_DB_URL}?authSource=admin") +DB = DB_CLIENT["ollama-webui"] + + +#################################### +# WEBUI_JWT_SECRET_KEY +#################################### + +WEBUI_JWT_SECRET_KEY = os.environ.get("WEBUI_JWT_SECRET_KEY", "t0p-s3cr3t") + +if WEBUI_AUTH and WEBUI_JWT_SECRET_KEY == "": + raise ValueError(ERROR_MESSAGES.ENV_VAR_NOT_FOUND) diff --git a/backend/constants.py b/backend/constants.py new file mode 100644 index 0000000000..b383957b34 --- /dev/null +++ b/backend/constants.py @@ -0,0 +1,24 @@ +from enum import Enum + + +class MESSAGES(str, Enum): + DEFAULT = lambda msg="": f"{msg if msg else ''}" + + +class ERROR_MESSAGES(str, Enum): + def __str__(self) -> str: + return super().__str__() + + DEFAULT = lambda err="": f"Something went wrong :/\n{err if err else ''}" + ENV_VAR_NOT_FOUND = "Required environment variable not found. Terminating now." + INVALID_TOKEN = ( + "Your session has expired or the token is invalid. Please sign in again." + ) + INVALID_CRED = "The email or password provided is incorrect. Please check for typos and try logging in again." + UNAUTHORIZED = "401 Unauthorized" + ACCESS_PROHIBITED = "You do not have permission to access this resource. Please contact your administrator for assistance." + ACTION_PROHIBITED = ( + "The requested action has been restricted as a security measure." + ) + USER_NOT_FOUND = "We could not find what you're looking for :/" + MALICIOUS = "Unusual activities detected, please try again in a few minutes." diff --git a/backend/main.py b/backend/main.py index 2851df20f0..24bad0c92f 100644 --- a/backend/main.py +++ b/backend/main.py @@ -1,16 +1,14 @@ -import time -import sys - from fastapi import FastAPI, Request from fastapi.staticfiles import StaticFiles - from fastapi import HTTPException -from starlette.exceptions import HTTPException as StarletteHTTPException - from fastapi.middleware.wsgi import WSGIMiddleware from fastapi.middleware.cors import CORSMiddleware +from starlette.exceptions import HTTPException as StarletteHTTPException from apps.ollama.main import app as ollama_app +from apps.web.main import app as webui_app + +import time class SPAStaticFiles(StaticFiles): @@ -47,5 +45,6 @@ async def check_url(request: Request, call_next): return response +app.mount("/api/v1", webui_app) app.mount("/ollama/api", WSGIMiddleware(ollama_app)) app.mount("/", SPAStaticFiles(directory="../build", html=True), name="spa-static-files") diff --git a/backend/utils/misc.py b/backend/utils/misc.py new file mode 100644 index 0000000000..e4011b34c9 --- /dev/null +++ b/backend/utils/misc.py @@ -0,0 +1,15 @@ +import hashlib + + +def get_gravatar_url(email): + # Trim leading and trailing whitespace from + # an email address and force all characters + # to lower case + address = str(email).strip().lower() + + # Create a SHA256 hash of the final string + hash_object = hashlib.sha256(address.encode()) + hash_hex = hash_object.hexdigest() + + # Grab the actual image URL + return f"https://www.gravatar.com/avatar/{hash_hex}?d=mp" diff --git a/backend/utils/utils.py b/backend/utils/utils.py new file mode 100644 index 0000000000..62e6958fb2 --- /dev/null +++ b/backend/utils/utils.py @@ -0,0 +1,68 @@ +from fastapi.security import HTTPBasicCredentials, HTTPBearer +from pydantic import BaseModel +from typing import Union, Optional + +from passlib.context import CryptContext +from datetime import datetime, timedelta +import requests +import jwt + +import config + +JWT_SECRET_KEY = config.WEBUI_JWT_SECRET_KEY +ALGORITHM = "HS256" + +############## +# Auth Utils +############## + +bearer_scheme = HTTPBearer() +pwd_context = CryptContext(schemes=["bcrypt"], deprecated="auto") + + +def verify_password(plain_password, hashed_password): + return ( + pwd_context.verify(plain_password, hashed_password) if hashed_password else None + ) + + +def get_password_hash(password): + return pwd_context.hash(password) + + +def create_token(data: dict, expires_delta: Union[timedelta, None] = None) -> str: + payload = data.copy() + + if expires_delta: + expire = datetime.utcnow() + expires_delta + payload.update({"exp": expire}) + + encoded_jwt = jwt.encode(payload, JWT_SECRET_KEY, algorithm=ALGORITHM) + return encoded_jwt + + +def decode_token(token: str) -> Optional[dict]: + try: + decoded = jwt.decode(token, JWT_SECRET_KEY, options={"verify_signature": False}) + return decoded + except Exception as e: + return None + + +def extract_token_from_auth_header(auth_header: str): + return auth_header[len("Bearer ") :] + + +def verify_token(request): + try: + bearer = request.headers["authorization"] + if bearer: + token = bearer[len("Bearer ") :] + decoded = jwt.decode( + token, JWT_SECRET_KEY, options={"verify_signature": False} + ) + return decoded + else: + return None + except Exception as e: + return None diff --git a/compose.yaml b/compose.yaml index b503635429..39366c6e75 100644 --- a/compose.yaml +++ b/compose.yaml @@ -22,6 +22,17 @@ services: restart: unless-stopped image: ollama/ollama:latest + + # Uncomment below for WIP: Auth support + # ollama-webui-db: + # image: mongo + # container_name: ollama-webui-db + # restart: always + # # Make sure to change the username/password! + # environment: + # MONGO_INITDB_ROOT_USERNAME: root + # MONGO_INITDB_ROOT_PASSWORD: example + ollama-webui: build: context: . @@ -32,10 +43,16 @@ services: container_name: ollama-webui depends_on: - ollama + # Uncomment below for WIP: Auth support + # - ollama-webui-db ports: - 3000:8080 environment: - "OLLAMA_API_BASE_URL=http://ollama:11434/api" + # Uncomment below for WIP: Auth support + # - "WEBUI_AUTH=TRUE" + # - "WEBUI_DB_URL=mongodb://root:example@ollama-webui-db:27017/" + # - "WEBUI_JWT_SECRET_KEY=SECRET_KEY" extra_hosts: - host.docker.internal:host-gateway restart: unless-stopped diff --git a/src/app.css b/src/app.css index 7afbeafe8b..27d2d7ef5c 100644 --- a/src/app.css +++ b/src/app.css @@ -4,8 +4,13 @@ font-display: swap; } +@font-face { + font-family: 'Mona Sans'; + src: url('/assets/fonts/Mona-Sans.woff2'); + font-display: swap; +} + html { - @apply bg-gray-800; word-break: break-word; } diff --git a/src/lib/components/chat/MessageInput.svelte b/src/lib/components/chat/MessageInput.svelte new file mode 100644 index 0000000000..dcd464291d --- /dev/null +++ b/src/lib/components/chat/MessageInput.svelte @@ -0,0 +1,282 @@ + + +
+
-
+