mirror of
https://github.com/microsoft/qlib.git
synced 2026-07-02 02:21:18 +08:00
Merge pull request #1527 from microsoft/xuyang1/add_openai_api_support
add openai interface support
This commit is contained in:
@@ -1,8 +1,10 @@
|
||||
import fire
|
||||
from qlib.finco.task import WorkflowManager
|
||||
from dotenv import load_dotenv
|
||||
|
||||
|
||||
def main(prompt):
|
||||
load_dotenv(verbose=True, override=True)
|
||||
wm = WorkflowManager()
|
||||
wm.run(prompt)
|
||||
|
||||
|
||||
@@ -1,8 +1,24 @@
|
||||
# TODO: use pydantic for other modules in Qlib
|
||||
from pydantic import BaseSettings
|
||||
from pydantic import (BaseSettings)
|
||||
|
||||
import os
|
||||
|
||||
class Conf(BaseSettings):
|
||||
"""module specific settings."""
|
||||
class Config():
|
||||
_instance = None
|
||||
def __new__(cls, *args, **kwargs):
|
||||
if cls._instance is None:
|
||||
cls._instance = super().__new__(cls, *args, **kwargs)
|
||||
return cls._instance
|
||||
|
||||
def __init__(self):
|
||||
self.use_azure = os.getenv("USE_AZURE") == "True"
|
||||
self.temperature = 0.5 if os.getenv("TEMPERATURE") is None else float(os.getenv("TEMPERATURE"))
|
||||
self.max_tokens = 8000 if os.getenv("MAX_TOKENS") is None else int(os.getenv("MAX_TOKENS"))
|
||||
|
||||
self.openai_api_key = os.getenv("OPENAI_API_KEY")
|
||||
self.use_azure = os.getenv("USE_AZURE") == "True"
|
||||
self.azure_api_base = os.getenv("AZURE_API_BASE")
|
||||
self.azure_api_version = os.getenv("AZURE_API_VERSION")
|
||||
self.model = os.getenv("MODEL") or ("gpt-35-turbo" if self.use_azure else "gpt-3.5-turbo")
|
||||
|
||||
...
|
||||
self.max_retry = os.getenv("MAX_RETRY")
|
||||
@@ -1,4 +1,6 @@
|
||||
import openai
|
||||
from typing import Optional
|
||||
from qlib.finco.conf import Config
|
||||
|
||||
|
||||
def example():
|
||||
@@ -12,3 +14,49 @@ def example():
|
||||
],
|
||||
)
|
||||
print(response)
|
||||
|
||||
def try_create_chat_completion(max_retry=10, **kwargs):
|
||||
cfg = Config()
|
||||
max_retry = cfg.max_retry if cfg.max_retry is not None else max_retry
|
||||
for i in range(max_retry):
|
||||
try:
|
||||
response = create_chat_completion(**kwargs)
|
||||
return response
|
||||
except openai.error.RateLimitError as e:
|
||||
print(e)
|
||||
print(f"Retrying {i+1}th time...")
|
||||
continue
|
||||
raise Exception(f"Failed to create chat completion after {max_retry} retries.")
|
||||
|
||||
def create_chat_completion(
|
||||
messages,
|
||||
model = None,
|
||||
temperature: float = None,
|
||||
max_tokens: Optional[int] = None,
|
||||
) -> str:
|
||||
cfg = Config()
|
||||
|
||||
if temperature is None:
|
||||
temperature = cfg.temperature
|
||||
if max_tokens is None:
|
||||
max_tokens = cfg.max_tokens
|
||||
|
||||
openai.api_key = cfg.openai_api_key
|
||||
if cfg.use_azure:
|
||||
openai.api_type = "azure"
|
||||
openai.api_base = cfg.azure_api_base
|
||||
openai.api_version = cfg.azure_api_version
|
||||
response = openai.ChatCompletion.create(
|
||||
engine=cfg.model,
|
||||
messages=messages,
|
||||
max_tokens=cfg.max_tokens,
|
||||
)
|
||||
else:
|
||||
response = openai.ChatCompletion.create(
|
||||
model=cfg.model,
|
||||
messages=messages,
|
||||
)
|
||||
return response
|
||||
|
||||
if __name__ == "__main__":
|
||||
create_chat_completion()
|
||||
Reference in New Issue
Block a user