Files
hf-test/docext-test/docext_api_test.py
2025-07-20 10:52:08 +08:00

127 lines
3.9 KiB
Python

import time
from gradio_client import Client, handle_file
import json
import re
import os
# def extract_invoice_info(markdown_text):
# try:
# # 提取发票号码
# invoice_number = re.search(r'发票号码:\s*(\d+)', markdown_text)
# if not invoice_number:
# raise ValueError("无法提取发票号码")
# # 提取销售方名称
# seller_section = markdown_text.split('销售方信息')[-1]
# seller_name = re.search(r'名称:\s*(.*?)\n', seller_section)
# if not seller_name:
# raise ValueError("无法提取销售方名称")
# # 提取小写金额
# amount = re.search(r'\(小写\)\s*¥(\d+\.\d+)', markdown_text)
# if not amount:
# raise ValueError("无法提取小写金额")
# return {
# "发票号码": invoice_number.group(1),
# "销售方名称": seller_name.group(1).strip(),
# "金额": amount.group(1)
# }
# except Exception as e:
# print(f"提取信息时出错: {e}")
# return None
# Compiled once at import time. Each pattern accepts both ASCII and
# full-width (CJK) punctuation -- ":"/"：", "("/"（", ")"/"）", "¥"/"￥" --
# because OCR output for Chinese invoices may contain either form.
_INVOICE_NUMBER_RE = re.compile(r'发票号码[:：]\s*(\d+)')
# [^\n]* (instead of ".*?\n") also matches a name on the very last line.
_SELLER_NAME_RE = re.compile(r'名称[:：]\s*([^\n]*)')
# Optional "-" for negative (red-letter) amounts; "," for thousands separators;
# the fractional part is optional so integer amounts are accepted too.
_AMOUNT_RE = re.compile(r'[(（]小写[)）]\s*[¥￥]\s*(-?\d[\d,]*(?:\.\d+)?)')


def extract_invoice_info(markdown_text):
    """Extract invoice number, seller name and amount from invoice markdown.

    Args:
        markdown_text: Markdown text of one invoice (as returned by the
            docext conversion in this script).

    Returns:
        A dict with the keys "发票号码" (invoice number), "销售方名称"
        (seller name) and "金额" (amount in figures, with thousands
        separators removed), all as strings. Returns None -- after printing
        the reason -- if the input is not a string or any field is missing.
    """
    try:
        if not isinstance(markdown_text, str):
            raise TypeError(
                f"expected markdown text as str, got {type(markdown_text).__name__}"
            )

        # Invoice number (发票号码).
        invoice_match = _INVOICE_NUMBER_RE.search(markdown_text)
        if not invoice_match:
            raise ValueError("未找到发票号码信息")
        invoice_number = invoice_match.group(1)

        # Seller name (销售方名称): only search the text after the LAST
        # "销售方信息" (seller information) heading, so that an earlier
        # "名称" field (presumably the buyer's) is not picked up.
        _, heading, seller_section = markdown_text.rpartition('销售方信息')
        if not heading:
            raise ValueError("未找到销售方信息部分")
        seller_match = _SELLER_NAME_RE.search(seller_section)
        seller_name = seller_match.group(1).strip() if seller_match else ""
        if not seller_name:
            # An empty value is as unusable as a missing field.
            raise ValueError("未找到销售方名称")

        # Amount in figures (小写金额).
        amount_match = _AMOUNT_RE.search(markdown_text)
        if not amount_match:
            raise ValueError("未找到金额信息")
        amount = amount_match.group(1).replace(",", "")
    except (TypeError, ValueError) as e:
        # Best-effort parsing: report why extraction failed and return None
        # so a caller processing many files can carry on.
        print(f"解析发票信息时出错: {str(e)}")
        return None
    return {
        "发票号码": invoice_number,
        "销售方名称": seller_name,
        "金额": amount,
    }
def convert_pdf_to_markdown(
    file_paths: list[str],
    client,
    api_name: str = "/process_markdown_streaming",
):
    """Convert PDF/image files to markdown via a docext Gradio server.

    Args:
        file_paths: Paths of the files to send together in one request. A
            single path (str or os.PathLike) is also accepted and treated as
            a one-item list.
        client: A connected ``gradio_client.Client`` for the docext server.
            The server URL and authentication are configured on the client.
        api_name: Name of the Gradio endpoint to call.

    Returns:
        Whatever ``client.predict`` returns for the endpoint. Callers in this
        script treat it as the markdown text (presumably a ``str`` -- verify
        against the docext server version in use).
    """
    # A bare string would otherwise be iterated character by character,
    # uploading one bogus "file" per character.
    if isinstance(file_paths, (str, os.PathLike)):
        file_paths = [file_paths]
    file_inputs = [{"image": handle_file(path)} for path in file_paths]
    # predict() is a blocking call. For this streaming endpoint we rely on it
    # returning the final output rather than intermediate chunks.
    return client.predict(images=file_inputs, api_name=api_name)
def get_pdf_files(directory):
    """Recursively collect paths of PDF files under *directory*.

    Files are matched by extension, case-insensitively (".pdf", ".PDF", ...).

    Args:
        directory: Root directory to search. A missing or unreadable
            directory yields an empty list, because os.walk ignores errors
            by default.

    Returns:
        List of file paths (each joined onto *directory*), sorted so the
        processing order is the same on every run and filesystem.
    """
    pdf_paths = []
    for root, _subdirs, names in os.walk(directory):
        pdf_paths.extend(
            os.path.join(root, name)
            for name in names
            if name.lower().endswith('.pdf')
        )
    # os.walk / os.listdir return entries in arbitrary, filesystem-dependent
    # order; sort for reproducible batch runs.
    return sorted(pdf_paths)
if __name__ == "__main__":
    # Example usage.
    # The client URL can be localhost or a public share URL such as
    # `https://6986bdd23daef6f7eb.gradio.live/`. Share links typically expire,
    # so the URL, credentials and input directory can be overridden through
    # environment variables instead of editing this file.
    CLIENT_URL = os.environ.get("DOCEXT_URL", "https://61d79ea57016de2c8d.gradio.live/")
    auth = (
        os.environ.get("DOCEXT_USERNAME", "admin"),
        os.environ.get("DOCEXT_PASSWORD", "admin"),
    )
    client = Client(CLIENT_URL, auth=auth)
    pdf_directory = os.environ.get("DOCEXT_PDF_DIR", "pdfs")
    pdf_files = get_pdf_files(pdf_directory)
    if not pdf_files:
        print(f"No PDF files found under {pdf_directory!r}")
    for pdf_file in pdf_files:
        print(f"Found PDF file: {pdf_file}")
        try:
            # Single-file conversion: one request per PDF.
            markdown_content = convert_pdf_to_markdown([pdf_file], client)
        except Exception as e:
            # Top-level boundary: report the failure and continue, so one bad
            # file or a transient server error does not abort the whole batch.
            print(f"Failed to convert {pdf_file}: {e}")
            continue
        invoice_info = extract_invoice_info(markdown_content)
        print(f"Extracted invoice info: {invoice_info}")