diff --git a/fastchat/serve/huggingface_api_worker.py b/fastchat/serve/huggingface_api_worker.py index 7eef50e47..2d0611fe5 100644 --- a/fastchat/serve/huggingface_api_worker.py +++ b/fastchat/serve/huggingface_api_worker.py @@ -19,6 +19,7 @@ import asyncio import json import uuid +import os from typing import List, Optional import requests @@ -300,6 +301,13 @@ def create_huggingface_api_worker(): default=None, help="Overwrite the random seed for each generation.", ) + parser.add_argument( + "--ssl", + action="store_true", + required=False, + default=False, + help="Enable SSL. Requires OS Environment variables 'SSL_KEYFILE' and 'SSL_CERTFILE'.", + ) args = parser.parse_args() with open(args.model_info_file, "r", encoding="UTF-8") as f: @@ -388,4 +396,14 @@ def create_huggingface_api_worker(): if __name__ == "__main__": args, workers = create_huggingface_api_worker() - uvicorn.run(app, host=args.host, port=args.port, log_level="info") + if args.ssl: + uvicorn.run( + app, + host=args.host, + port=args.port, + log_level="info", + ssl_keyfile=os.environ["SSL_KEYFILE"], + ssl_certfile=os.environ["SSL_CERTFILE"], + ) + else: + uvicorn.run(app, host=args.host, port=args.port, log_level="info") diff --git a/fastchat/serve/model_worker.py b/fastchat/serve/model_worker.py index c18f0aa9e..d69c8cd14 100644 --- a/fastchat/serve/model_worker.py +++ b/fastchat/serve/model_worker.py @@ -291,6 +291,13 @@ def create_model_worker(): parser.add_argument( "--debug", type=bool, default=False, help="Print debugging messages" ) + parser.add_argument( + "--ssl", + action="store_true", + required=False, + default=False, + help="Enable SSL. Requires OS Environment variables 'SSL_KEYFILE' and 'SSL_CERTFILE'.", + ) args = parser.parse_args() logger.info(f"args: {args}") @@ -359,4 +366,14 @@ def create_model_worker(): if __name__ == "__main__": args, worker = create_model_worker() - uvicorn.run(app, host=args.host, port=args.port, log_level="info") + if args.ssl: + uvicorn.run( + app, + host=args.host, + port=args.port, + log_level="info", + ssl_keyfile=os.environ["SSL_KEYFILE"], + ssl_certfile=os.environ["SSL_CERTFILE"], + ) + else: + uvicorn.run(app, host=args.host, port=args.port, log_level="info") diff --git a/fastchat/serve/multi_model_worker.py b/fastchat/serve/multi_model_worker.py index f77ff4447..aafb7fbb8 100644 --- a/fastchat/serve/multi_model_worker.py +++ b/fastchat/serve/multi_model_worker.py @@ -190,6 +190,13 @@ def create_multi_model_worker(): parser.add_argument("--limit-worker-concurrency", type=int, default=5) parser.add_argument("--stream-interval", type=int, default=2) parser.add_argument("--no-register", action="store_true") + parser.add_argument( + "--ssl", + action="store_true", + required=False, + default=False, + help="Enable SSL. Requires OS Environment variables 'SSL_KEYFILE' and 'SSL_CERTFILE'.", + ) args = parser.parse_args() logger.info(f"args: {args}") @@ -279,4 +286,14 @@ def create_multi_model_worker(): if __name__ == "__main__": args, workers = create_multi_model_worker() - uvicorn.run(app, host=args.host, port=args.port, log_level="info") + if args.ssl: + uvicorn.run( + app, + host=args.host, + port=args.port, + log_level="info", + ssl_keyfile=os.environ["SSL_KEYFILE"], + ssl_certfile=os.environ["SSL_CERTFILE"], + ) + else: + uvicorn.run(app, host=args.host, port=args.port, log_level="info")