import logging
import os
from typing import Optional, Sequence, Callable, Union
from sycamore.llms.llms import LLM, LLMMode
from sycamore.llms.prompts.prompts import SycamorePrompt, RenderedPrompt
from sycamore.plan_nodes import Node
from sycamore.transforms.map import MapBatch
from sycamore.data import Document, Element
from sycamore.utils.threading import run_coros_threadsafe
import asyncio
logger = logging.getLogger(__name__)
async def _infer_prompts_async(prompts: list[RenderedPrompt], llm: LLM) -> list[str]:
el = asyncio.get_running_loop()
awaitables = [llm.generate_async(prompt=p, llm_kwargs={}) for p in prompts]
tasks = [el.create_task(aw) for aw in awaitables]
return await asyncio.gather(*tasks)
def _run_new_thread(loop: asyncio.AbstractEventLoop) -> None:
asyncio.set_event_loop(loop)
loop.run_forever()
def _infer_prompts(
prompts: list[RenderedPrompt],
llm: LLM,
llm_mode: LLMMode,
) -> list[str]:
if llm_mode == LLMMode.SYNC:
res = []
for p in prompts:
if len(p.messages) == 0:
res.append("")
continue
try:
s = llm.generate(prompt=p)
res.append(s)
if all_prompt_dir := os.environ.get("LLM_DEBUG_DIR"):
from datetime import datetime
from pathlib import Path
now = datetime.now().isoformat()
path = Path(all_prompt_dir) / f"{now}.txt"
logger.info(f"Saving prompt and result to {path}")
with open(path, "w") as f:
f.write(p.to_human_readable())
f.write("\n\n--------------------------------------------\n\n")
f.write(s)
except Exception:
bad_prompt_path = os.environ.get("BAD_PROMPT_PATH", "/tmp/bad_prompt.txt")
with open(bad_prompt_path, "w") as f:
f.write(p.to_human_readable())
logger.error(f"Error generating prompt. Wrote failing prompt to $BAD_PROMPT_PATH:{bad_prompt_path}")
raise
return res
elif llm_mode == LLMMode.ASYNC:
nonempty = [(i, p) for i, p in enumerate(prompts) if len(p.messages) > 0]
res = [""] * len(prompts)
coroutines = [llm.generate_async(prompt=p, llm_kwargs={}) for _, p in nonempty]
responses = run_coros_threadsafe(coroutines)
for (i, _), rs in zip(nonempty, responses):
res[i] = rs
return res
elif llm_mode == LLMMode.BATCH:
return llm.generate_batch(prompts=prompts)
else:
raise NotImplementedError("Unknown LLM Mode")
[docs]
class LLMMap(MapBatch):
"""The LLMMap transform renders each Document in a docset into
a prompt for an LLM, calls the LLM, and attaches the output to
the document.
Args:
child: Child node in the sycamore execution graph
prompt: The SycamorePrompt to use to render each document.
Must implement the ``render_document`` method.
output_field: The name of the field in doc.properties in which
to store the llm output
llm: The llm to use for inference.
llm_mode: How to call the llm - sync/async/batch. All LLMs do not
necessarily implement all options.
iteration_var: Name of the document property to increment with every
invalid response. Default is None, which means no re-try.
validate: Function to determine whether an LLM response is valid.
Default is 'everything is valid'
max_tries: Hard limit on the number of LLM calls per document. Default
is 5
Example:
.. code-block:: python
prompt = EntityExtractorZeroShotGuidancePrompt.set(entity="title")
docset.llm_map(
prompt=prompt,
output_field="title",
llm=OpenAI(OpenAIModels.GPT_4O_MINI)
)
"""
def __init__(
self,
child: Optional[Node],
prompt: SycamorePrompt,
output_field: str,
llm: LLM,
llm_mode: Optional[LLMMode] = None,
iteration_var: Optional[str] = None,
validate: Callable[[Document], bool] = lambda d: True,
max_tries: int = 5,
filter: Callable[[Document], bool] = lambda d: True,
**kwargs,
):
self._prompt = prompt
self._validate_prompt()
self._output_field = output_field
self._llm = llm
self._llm_mode = llm_mode if llm_mode is not None else llm.default_mode()
self._iteration_var = iteration_var
self._validate = validate
self._max_tries = max_tries
self._filter = filter
super().__init__(child, f=self.llm_map, **kwargs)
def llm_map(self, documents: list[Document]) -> list[Document]:
if self._iteration_var is not None:
for d in documents:
d.properties[self._iteration_var] = 0
skips = [not self._filter(d) for d in documents]
tries = 0
while not all(skips) and tries < self._max_tries:
tries += 1
rendered_and_index = [
(self._prompt.render_document(d), i) for sk, d, i in zip(skips, documents, range(len(skips))) if not sk
]
rendered = []
for r, i in rendered_and_index:
if len(r.messages) == 0:
skips[i] = True
else:
rendered.append(r)
if len(rendered) == 0:
break
results = _infer_prompts(rendered, self._llm, self._llm_mode)
ri = 0
for i in range(len(documents)):
if skips[i]:
continue
documents[i].properties[self._output_field] = results[ri]
skips[i] = self._validate(documents[i])
ri += 1
if self._iteration_var is not None and not skips[i]:
documents[i].properties[self._iteration_var] += 1
if self._iteration_var is None:
break
return documents
def _validate_prompt(self):
doc = Document()
try:
_ = self._prompt.render_document(doc)
except NotImplementedError as e:
raise e
except Exception:
pass
[docs]
class LLMMapElements(MapBatch):
"""The LLMMapElements transform renders each Element for each
Document in a docset into a prompt for an LLM, calls the LLM,
and attaches the output to the element.
Args:
child: Child node in the sycamore execution graph
prompt: The SycamorePrompt to use to render each element.
Must implement the ``render_element`` method.
output_field: The name of the field in elt.properties in which
to store the llm output.
llm: The llm to use for inference.
llm_mode: How to call the llm - sync/async/batch. All LLMs do not
necessarily implement all options.
iteration_var: Name of the element property to increment with every
invalid response. Default is None, which means no re-try.
validate: Function to determine whether an LLM response is valid.
Default is 'everything is valid'
max_tries: Hard limit on the number of LLM calls per element. Default
is 5
Example:
.. code-block:: python
prompt = TextSummarizerGuidancePrompt
docset.llm_map_elements(
prompt = prompt,
output_field = "summary",
llm = OpenAI(OpenAIModels.GPT_4O)
"""
def __init__(
self,
child: Optional[Node],
prompt: SycamorePrompt,
output_field: str,
llm: LLM,
llm_mode: Optional[LLMMode] = None,
iteration_var: Optional[str] = None,
validate: Callable[[Element], bool] = lambda e: True,
max_tries: int = 5,
filter: Callable[[Element], bool] = lambda e: True,
**kwargs,
):
self._prompt = prompt
self._validate_prompt()
self._output_field = output_field
self._llm = llm
self._llm_mode = llm_mode if llm_mode is not None else llm.default_mode()
self._iteration_var = iteration_var
self._validate = validate
self._max_tries = max_tries
self._filter = filter
super().__init__(child, f=self.llm_map_elements, **kwargs)
def llm_map_elements(self, documents: list[Document]) -> list[Document]:
elt_doc_pairs = [(e, d) for d in documents for e in d.elements]
if self._iteration_var is not None:
for e, _ in elt_doc_pairs:
e.properties[self._iteration_var] = 0
skips = [not self._filter(e) for e, _ in elt_doc_pairs]
tries = 0
while not all(skips) and tries < self._max_tries:
tries += 1
rendered_and_index = [
(self._prompt.render_element(e, d), i)
for sk, (e, d), i in zip(skips, elt_doc_pairs, range(len(skips)))
if not sk
]
rendered = []
for r, i in rendered_and_index:
if len(r.messages) == 0:
skips[i] = True
else:
rendered.append(r)
if len(rendered) == 0:
break
results = _infer_prompts(rendered, self._llm, self._llm_mode)
ri = 0
for i in range(len(elt_doc_pairs)):
if skips[i]:
continue
elt, doc = elt_doc_pairs[i]
elt.properties[self._output_field] = results[ri]
skips[i] = self._validate(elt)
ri += 1
if self._iteration_var is not None:
elt.properties[self._iteration_var] += 1
if self._iteration_var is None:
break
last_doc = None
new_elts = []
for e, d in elt_doc_pairs:
if last_doc is not None and last_doc.doc_id != d.doc_id:
last_doc.elements = new_elts
new_elts = []
new_elts.append(e)
last_doc = d
if last_doc is not None:
last_doc.elements = new_elts
return documents
def _validate_prompt(self):
doc = Document()
elt = Element()
try:
_ = self._prompt.render_element(elt, doc)
except NotImplementedError as e:
raise e
except Exception:
pass
def _as_sequences(ls: list[Union[RenderedPrompt, Sequence[RenderedPrompt]]]) -> list[Sequence[RenderedPrompt]]:
return [[p] if isinstance(p, RenderedPrompt) else p for p in ls]