from typing import List
import requests
from llama_index.core.embeddings import BaseEmbedding
from pydantic import Field

class OpenAICompatibleEmbedding(BaseEmbedding):
    """Embedding client for any OpenAI-compatible /embeddings endpoint."""

    model: str = Field(...)
    api_key: str = Field(...)
    api_base: str = Field(...)
    timeout: int = Field(default=60)

    def __init__(
        self,
        model: str,
        api_key: str,
        api_base: str,
        timeout: int = 60,
        **kwargs,
    ):
        # BaseEmbedding is a pydantic model: fields must be passed through
        # super().__init__() rather than assigned directly on self.
        super().__init__(
            model=model,
            api_key=api_key,
            api_base=api_base.rstrip("/"),
            timeout=timeout,
            **kwargs,
        )

    # ---------- low-level call ----------
    def _embed(self, texts: List[str]) -> List[List[float]]:
        url = f"{self.api_base}/embeddings"
        headers = {
            "Authorization": f"Bearer {self.api_key}",
            "Content-Type": "application/json",
        }
        payload = {
            "model": self.model,
            "input": texts,
        }
        resp = requests.post(
            url,
            headers=headers,
            json=payload,
            timeout=self.timeout,
        )
        resp.raise_for_status()
        data = resp.json()
        return [item["embedding"] for item in data["data"]]

    # ---------- document embeddings ----------
    def _get_text_embedding(self, text: str) -> List[float]:
        return self._embed([text])[0]

    def _get_text_embeddings(self, texts: List[str]) -> List[List[float]]:
        return self._embed(texts)

    async def _aget_text_embedding(self, text: str) -> List[float]:
        return self._get_text_embedding(text)

    async def _aget_text_embeddings(self, texts: List[str]) -> List[List[float]]:
        return self._get_text_embeddings(texts)

    # ---------- query embeddings (REQUIRED) ----------
    def _get_query_embedding(self, query: str) -> List[float]:
        # bge-m3 uses the same embedding space for queries and documents
        return self._embed([query])[0]

    async def _aget_query_embedding(self, query: str) -> List[float]:
        return self._get_query_embedding(query)
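

# ---------- usage sketch (illustrative) ----------
# Minimal example of wiring this class into LlamaIndex as the global embedding
# model. The endpoint URL, model name, and API key below are placeholders, not
# values from the original file.
if __name__ == "__main__":
    from llama_index.core import Settings

    embed_model = OpenAICompatibleEmbedding(
        model="bge-m3",
        api_key="sk-placeholder",
        api_base="http://localhost:8000/v1",
    )
    Settings.embed_model = embed_model

    # Quick smoke test: embed one query string and report the vector size.
    vec = embed_model.get_query_embedding("hello world")
    print(f"embedding dimension: {len(vec)}")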