Source code for sycamore.transforms.extract_table
import io
import json
import logging
from abc import abstractmethod, ABC
from collections import OrderedDict
from typing import Optional, Any
import pypdf
from botocore.exceptions import ClientError
from textractor import Textractor
from textractor.data.constants import TextractFeatures
from textractor.parsers import response_parser
from sycamore.data import BoundingBox, Document, Element
from sycamore.data.document import DocumentPropertyTypes
logger = logging.getLogger("sycamore")
[docs]
class MissingS3UploadPath(Exception):
    """Raised when an S3 upload path is needed but one wasn't provided."""
class TableExtractor(ABC):
    """Abstract interface for objects that extract tables from a document."""

    @abstractmethod
    def extract_tables(self, document: Document) -> Document:
        """Extract tables for ``document`` and return a document.

        Concrete subclasses must override this method; the base
        implementation does nothing and returns ``None``.
        """
[docs]
class TextractTableExtractor(TableExtractor):
    """
    TextractTableExtractor utilizes Amazon Textract to extract tables from documents.

    This class inherits from TableExtractor and is designed for extracting tables from documents using Amazon Textract,
    a cloud-based document text and data extraction service from AWS.

    Args:
        profile_name: The AWS profile name to use for authentication. Default is None.
        region_name: The AWS region name where the Textract service is available.
        kms_key_id: The AWS Key Management Service (KMS) key ID for encryption.
        s3_upload_root: S3 location (``s3://...``) under which documents that are not already
            stored in S3 are uploaded before analysis. Must be set for such documents.

    Example:
         .. code-block:: python

            table_extractor = TextractTableExtractor(profile_name="my-profile", region_name="us-east-1")

            context = sycamore.init()
            pdf_docset = (
                context.read.binary(paths, binary_format="pdf")
                .partition(partitioner=ArynPartitioner(), table_extractor=table_extractor)
            )
    """

    def __init__(
        self,
        profile_name: Optional[str] = None,
        region_name: Optional[str] = None,
        kms_key_id: str = "",
        s3_upload_root: str = "",
    ):
        self._profile_name = profile_name
        self._region_name = region_name
        self._kms_key_id: str = kms_key_id
        self._s3_upload_root: str = s3_upload_root

    def get_textract_result(self, document: Document):
        """Run Textract table analysis for ``document`` and return the Textractor result.

        The document is located through ``document.properties["path"]``. Paths that
        already start with ``s3://`` are analysed in place; any other path is uploaded
        below ``s3_upload_root`` first.

        Raises:
            MissingS3UploadPath: the document is not in S3 and ``s3_upload_root`` does
                not start with ``s3://``.
        """
        extractor = Textractor(self._profile_name, self._region_name, self._kms_key_id)
        path = document.properties["path"]

        # If the document is already in S3, don't upload it again.
        if path.startswith("s3://"):
            return extractor.start_document_analysis(path, TextractFeatures.TABLES)

        if not self._s3_upload_root.startswith("s3://"):
            raise MissingS3UploadPath(
                f"Document {path!r} is not in S3 and s3_upload_root "
                f"({self._s3_upload_root!r}) is not an s3:// location"
            )

        # TODO: https://github.com/aryn-ai/sycamore/issues/173 - implement content-hash uploading
        # If we manually upload based on a hash, we can avoid repeated uploads and storage
        # of the same file.
        suffix = path if path.startswith("/") else "/" + path
        # os.path.join("s3://foo", "/abc") -> "/abc"; which is not what we want, so
        # concatenate by hand. Trailing slashes on the root are dropped so the
        # resulting key does not contain an accidental "//".
        dest = self._s3_upload_root.rstrip("/") + suffix
        return extractor.start_document_analysis(path, TextractFeatures.TABLES, s3_upload_path=dest)

    @staticmethod
    def get_tables_from_textract_result(result):
        """Convert every table in a Textractor result into a ``Table`` element.

        Each element's text is the table title (when present), the table's CSV
        rendering, and each footer, every part terminated by a newline. The
        element also carries the Textract table id, its page number and its
        bounding box.

        Returns:
            A list of elements, one per entry in ``result.tables``.
        """

        # https://docs.aws.amazon.com/textract/latest/dg/API_BoundingBox.html
        def bbox_to_coord(bbox):
            # (x, y, width, height) -> (x1, y1, x2, y2)
            return bbox.x, bbox.y, bbox.x + bbox.width, bbox.y + bbox.height

        all_tables = []
        for table in result.tables:
            element = Element()
            element.type = "Table"
            element.properties["boxes"] = []
            element.properties["id"] = table.id
            element.properties[DocumentPropertyTypes.PAGE_NUMBER] = table.page

            pieces = []
            if table.title:
                pieces.append(table.title.text)
            pieces.append(table.to_csv())
            pieces.extend(footer.text for footer in table.footers)
            element.text_representation = "".join(piece + "\n" for piece in pieces)

            element.bbox = BoundingBox(*bbox_to_coord(table.bbox))
            all_tables.append(element)

        return all_tables

    def extract_tables(self, document: Document) -> Document:
        """Append the tables Textract finds to ``document.elements`` and return the document."""
        textract_result = self.get_textract_result(document)
        tables = self.get_tables_from_textract_result(textract_result)
        document.elements = document.elements + tables
        return document
[docs]
class CachedTextractTableExtractor(TextractTableExtractor):
    """
    Extends TextractTableExtractor with S3 based cache support for raw Textract results.

    CachedTextractTableExtractor overrides the 'get_textract_result' method by doing the following:
    1. if cache hit for current document, get from cache and return, otherwise continue
    2. if run_full_textract is enabled, call textractor on the whole document and go to step 4
    3. else clip pages which contain tables and run table extraction using textractor
    4. update cache accordingly based on textractor result and return result

    Args:
        s3_cache_location: S3 location (``s3://bucket[/prefix]``) where cached results are stored.
        run_full_textract: When True, analyse the whole document instead of only the pages
            that already contain ``Table`` elements.
        s3_textract_upload_path: S3 path handed to Textractor as its upload location.
            Required when ``run_full_textract`` is False.
        profile_name: The AWS profile name to use for authentication. Default is None.
        region_name: The AWS region name where the Textract service is available.
        kms_key_id: The AWS Key Management Service (KMS) key ID for encryption.
    """

    def __init__(
        self,
        s3_cache_location,
        run_full_textract: bool = False,
        s3_textract_upload_path: str = "",
        profile_name: Optional[str] = None,
        region_name: Optional[str] = None,
        kms_key_id: str = "",
    ):
        # The parent stores profile_name, region_name and kms_key_id.
        super().__init__(profile_name, region_name, kms_key_id)
        self._s3_cache_location = s3_cache_location
        self._run_full_textract = run_full_textract
        self._s3_textract_upload_path = s3_textract_upload_path

    def _cache_bucket_and_key(self, cache_id: str):
        """Return the ``(bucket, key)`` under which ``cache_id`` is stored in the cache location."""
        parts = self._s3_cache_location.replace("s3://", "").strip("/").split("/", 1)
        bucket = parts[0]
        # Only prepend a prefix when the cache location has one after the bucket name.
        key = "/".join([parts[1], cache_id]) if len(parts) == 2 else cache_id
        return bucket, key

    def _get_cached_textract_result(self, s3, cache_id: str):
        """Get cache from S3.

        Returns:
            ``(parsed_textract_result, document_page_mapping)`` on a hit, or
            ``(None, None)`` when the cache object does not exist.

        Raises:
            ClientError: for any S3 error other than ``NoSuchKey``.
        """
        try:
            bucket, key = self._cache_bucket_and_key(cache_id)
            response = s3.get_object(Bucket=bucket, Key=key)
            parsed_response = json.loads(response["Body"].read())
            return response_parser.parse(parsed_response["textract_result"]), parsed_response["document_page_mapping"]
        except ClientError as e:
            if e.response["Error"]["Code"] == "NoSuchKey":
                return None, None
            raise

    def _cache_textract_result(self, s3, cache_id: str, result: Any, document_page_mapping: list):
        """Put the raw Textract response and its page mapping into S3 as one JSON object."""
        bucket, key = self._cache_bucket_and_key(cache_id)
        json_str = json.dumps({"document_page_mapping": document_page_mapping, "textract_result": result.response})
        s3.put_object(Body=json_str, Bucket=bucket, Key=key)

    @staticmethod
    def _cache_id(s3, object_path: str) -> str:
        """Return the cache id of an S3 object: its ETag with the double quotes removed.

        ``object_path`` is split into bucket and key at the first ``/`` after the
        ``s3://`` scheme has been removed.
        """
        parts = object_path.replace("s3://", "").split("/", 1)
        response = s3.head_object(Bucket=parts[0], Key=parts[1])
        return response["ETag"].replace('"', "")

    def get_textract_result(self, document: Document):
        """Return ``(textract_result, document_page_mapping)`` for ``document``, using the S3 cache.

        ``document_page_mapping`` lists the original page numbers of the pages that were
        sent to Textract, in order: entry ``i`` is the real page of clipped page ``i + 1``.
        It is empty when the whole document was analysed, because the page numbers in
        the result are then already the real ones.

        Returns:
            ``(None, None)`` when nothing is cached, full analysis is disabled and the
            document has no ``Table`` elements.

        Raises:
            RuntimeError: pages need to be clipped but no textract upload path was configured.
        """
        import boto3

        s3 = boto3.client("s3")
        path = document.properties["path"]
        cache_id = self._cache_id(s3, path)
        try:
            textract_result, document_page_mapping = self._get_cached_textract_result(s3, cache_id)
            if textract_result:
                logger.info("Textract cache hit for %s", path)
                return textract_result, document_page_mapping
        except Exception as e:
            # Best effort: a broken cache read falls through to a fresh Textract run.
            logger.exception("Error in reading from cache %s", str(e))

        # cache miss
        if self._run_full_textract:
            # The whole document is analysed, so Textract's page numbers are already the
            # real ones. An empty mapping tells extract_tables not to remap them.
            # NOTE: cache entries written before this fix may still carry a non-empty
            # mapping next to a full-document result.
            document_page_mapping = []
        else:
            # Ordered, de-duplicated pages that hold at least one Table element.
            document_page_mapping = list(
                OrderedDict.fromkeys(
                    element.properties[DocumentPropertyTypes.PAGE_NUMBER]
                    for element in document.elements
                    if element.type == "Table"
                )
            )
            # no pages with tables found and no full execution
            if not document_page_mapping:
                return None, None

        logger.info("Textract cache miss for %s", path)
        extractor = Textractor(self._profile_name, self._region_name, self._kms_key_id)

        if self._run_full_textract:
            textract_result = extractor.start_document_analysis(
                path,
                TextractFeatures.TABLES,
                s3_upload_path=self._s3_textract_upload_path,
                save_image=False,
            )
        else:
            # When no s3 upload path exists, it's assumed textractor won't run even cache miss
            if not self._s3_textract_upload_path:
                raise RuntimeError("Missing textract upload path")

            # Clip the pages which have tables into a new tmp pdf and upload for textract
            pdf_reader = pypdf.PdfReader(io.BytesIO(document.data["binary_representation"]))
            pdf_writer = pypdf.PdfWriter()
            for page_number in document_page_mapping:
                # Document page numbers are 1-based; pdf_reader.pages is 0-indexed.
                pdf_writer.add_page(pdf_reader.pages[page_number - 1])
            output_pdf_stream = io.BytesIO()
            pdf_writer.write(output_pdf_stream)

            # Do textract
            textract_result = extractor.start_document_analysis(
                output_pdf_stream.getvalue(),
                TextractFeatures.TABLES,
                s3_upload_path=self._s3_textract_upload_path,
                save_image=False,
            )

        self._cache_textract_result(s3, cache_id, textract_result, document_page_mapping)
        return textract_result, document_page_mapping

    def extract_tables(self, document: Document) -> Document:
        """Append the tables found by Textract to ``document.elements`` and return the document.

        When a page mapping is present, the page numbers reported for the clipped
        PDF are translated back to the original document's page numbers.
        """
        textract_result, document_page_mapping = self.get_textract_result(document)
        if textract_result:
            tables = self.get_tables_from_textract_result(textract_result)
            # put back actual page numbers
            if document_page_mapping:
                for table in tables:
                    clipped_page = table.properties[DocumentPropertyTypes.PAGE_NUMBER]
                    table.properties[DocumentPropertyTypes.PAGE_NUMBER] = document_page_mapping[clipped_page - 1]
            document.elements = document.elements + tables
        return document