diff --git a/langchain_g4f/G4FLLM.py b/langchain_g4f/G4FLLM.py
index 2568229..d110905 100644
--- a/langchain_g4f/G4FLLM.py
+++ b/langchain_g4f/G4FLLM.py
@@ -1,9 +1,10 @@
 from typing import Any, List, Mapping, Optional, Union
+from functools import partial
 
 from g4f import ChatCompletion
 from g4f.models import Model
 from g4f.Provider.base_provider import BaseProvider
-from langchain.callbacks.manager import CallbackManagerForLLMRun
+from langchain.callbacks.manager import CallbackManagerForLLMRun, AsyncCallbackManagerForLLMRun
 from langchain.llms.base import LLM
 from langchain.llms.utils import enforce_stop_tokens
@@ -42,6 +43,39 @@ class G4FLLM(LLM):
         if stop is not None:
             text = enforce_stop_tokens(text, stop)
         return text
+
+    async def _acall(
+        self,
+        prompt: str,
+        stop: Optional[List[str]] = None,
+        run_manager: Optional[AsyncCallbackManagerForLLMRun] = None,
+        **kwargs: Any,
+    ) -> str:
+        create_kwargs = {} if self.create_kwargs is None else self.create_kwargs.copy()
+        create_kwargs["model"] = self.model
+        if self.provider is not None:
+            create_kwargs["provider"] = self.provider
+        if self.auth is not None:
+            create_kwargs["auth"] = self.auth
+
+        text_callback = None
+        if run_manager:
+            text_callback = partial(run_manager.on_llm_new_token)
+
+        # Stream the completion so each token can be forwarded to the async callback manager.
+        text = ""
+        for token in ChatCompletion.create(
+            messages=[{"role": "user", "content": prompt}],
+            stream=True,
+            **create_kwargs,
+        ):
+            if text_callback:
+                await text_callback(token)
+            text += token
+        # Match the sync _call path: apply stop tokens before returning.
+        if stop is not None:
+            text = enforce_stop_tokens(text, stop)
+        return text
 
     @property
     def _identifying_params(self) -> Mapping[str, Any]:
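
For reviewers, a minimal smoke test sketch for the new async path. Assumptions not in the diff: g4f is installed, models.gpt_35_turbo is available, and the prompt is illustrative only. LangChain's LLM.agenerate dispatches each prompt through _acall:

import asyncio

from g4f import models

from langchain_g4f import G4FLLM


async def main() -> None:
    # Model choice is an assumption; any Model/provider combination should work.
    llm = G4FLLM(model=models.gpt_35_turbo)
    # agenerate() routes through the new _acall(); streamed tokens reach any
    # async callback handlers attached to the run manager.
    result = await llm.agenerate(["Hello, how are you?"])
    print(result.generations[0][0].text)


asyncio.run(main())

One design note worth flagging: ChatCompletion.create is synchronous, so the streaming loop in _acall blocks the event loop between tokens; offloading it with asyncio.to_thread, or using g4f's create_async if the installed version provides it, could be a follow-up.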