diff --git a/.gitignore b/.gitignore index 44bbb13..10073a9 100644 --- a/.gitignore +++ b/.gitignore @@ -159,3 +159,4 @@ cython_debug/ # option (not recommended) you can uncomment the following to ignore the entire idea folder. .idea/ .vscode/ +test.py diff --git a/langchain_g4f/G4FLLM.py b/langchain_g4f/G4FLLM.py index d110905..cf9a894 100644 --- a/langchain_g4f/G4FLLM.py +++ b/langchain_g4f/G4FLLM.py @@ -8,7 +8,7 @@ from langchain.callbacks.manager import CallbackManagerForLLMRun, AsyncCallbackM from langchain.llms.base import LLM from langchain.llms.utils import enforce_stop_tokens - +MAX_TRIES = 5 class G4FLLM(LLM): model: Union[Model, str] provider: Optional[type[BaseProvider]] = None @@ -33,16 +33,25 @@ class G4FLLM(LLM): if self.auth is not None: create_kwargs["auth"] = self.auth - text = ChatCompletion.create( - messages=[{"role": "user", "content": prompt}], - **create_kwargs, - ) + for i in range(MAX_TRIES): + try: + text = ChatCompletion.create( + messages=[{"role": "user", "content": prompt}], + **create_kwargs, + ) + + # Generator -> str + text = text if type(text) is str else "".join(text) + if stop is not None: + text = enforce_stop_tokens(text, stop) + if text: + return text + print(f"Empty response, trying {i+1} of {MAX_TRIES}") + except Exception as e: + print(f"Error in G4FLLM._call: {e}, trying {i+1} of {MAX_TRIES}") + return "" + - # Generator -> str - text = text if type(text) is str else "".join(text) - if stop is not None: - text = enforce_stop_tokens(text, stop) - return text async def _acall(self, prompt: str, stop: Optional[List[str]] = None, run_manager: Optional[AsyncCallbackManagerForLLMRun] = None, **kwargs: Any) -> str: create_kwargs = {} if self.create_kwargs is None else self.create_kwargs.copy() create_kwargs["model"] = self.model