159 lines
5.7 KiB
Python
159 lines
5.7 KiB
Python
import time
|
|
from gradio_client import Client, handle_file
|
|
import json
|
|
import re
|
|
import os
|
|
|
|
def save_to_json(data, filename):
    """Save ``data`` to ``filename`` as pretty-printed UTF-8 JSON.

    Parent directories are created if needed. Non-ASCII text (e.g. Chinese
    invoice fields) is written as-is rather than ``\\uXXXX``-escaped.

    Args:
        data: Any JSON-serializable object.
        filename: Destination path. A bare filename with no directory
            component is written to the current working directory.

    Raises:
        TypeError: If ``data`` is not JSON-serializable (no file is created
            or truncated in that case).
        OSError: If the directory cannot be created or the file written.
    """
    # Serialize before opening the file so a non-serializable object does
    # not leave a truncated/empty file behind.
    text = json.dumps(data, ensure_ascii=False, indent=4)

    directory = os.path.dirname(filename)
    # os.makedirs("") raises FileNotFoundError, so only create a directory
    # when the path actually contains one.
    if directory:
        os.makedirs(directory, exist_ok=True)

    with open(filename, 'w', encoding='utf-8') as f:
        f.write(text)
|
|
|
|
# Chinese invoices mix ASCII and full-width punctuation depending on the issuer
# and the OCR engine, so each pattern accepts both forms:
# ":" / "：", "()" / "（）", and "¥" (U+00A5) / "￥" (U+FFE5).
_INVOICE_NUMBER_RE = re.compile(r'发票号码[:：]\s*(\d+)')
_SELLER_NAME_RE = re.compile(r'名称[:：]\s*(.*)')
_AMOUNT_RE = re.compile(r'[(（]小写[)）]\s*[¥￥]\s*(\d+(?:\.\d+)?)')
# HTML table markup emitted by the markdown converter. Tag attributes are
# tolerated and cell contents may span several lines.
_TBODY_RE = re.compile(r'<tbody\b[^>]*>(.*?)(?:</tbody>|\Z)', re.DOTALL)
_TABLE_ROW_RE = re.compile(r'<tr\b[^>]*>.*?</tr>', re.DOTALL)
_TABLE_CELL_RE = re.compile(r'<td\b[^>]*>(.*?)</td>', re.DOTALL)
_INTEGER_RE = re.compile(r'\d+')

# Placeholder stored for an item field whose cell could not be found.
_NOT_FOUND = "无"


def _extract_invoice_number(text):
    """Return the digits following the first "发票号码:" label."""
    match = _INVOICE_NUMBER_RE.search(text)
    if not match:
        raise ValueError("未找到发票号码信息")
    return match.group(1)


def _extract_seller_name(text):
    """Return the first "名称:" value after the last "销售方信息" heading."""
    _, heading, seller_part = text.rpartition('销售方信息')
    if not heading:
        raise ValueError("未找到销售方信息部分")
    match = _SELLER_NAME_RE.search(seller_part)
    if not match:
        raise ValueError("未找到销售方名称")
    return match.group(1).strip()


def _extract_amount(text):
    """Return the "(小写)" total amount as a string, e.g. "113.00"."""
    match = _AMOUNT_RE.search(text)
    if not match:
        raise ValueError("未找到金额信息")
    return match.group(1)


def _extract_items(text):
    """Parse the rows of the first <tbody> into name/model/quantity dicts."""
    tbody = _TBODY_RE.search(text)
    if not tbody:
        raise ValueError("未找到商品明细部分")
    # Only rows of the first table body are considered, so rows from any
    # later table cannot leak into the item list.
    rows = _TABLE_ROW_RE.findall(tbody.group(1))
    # The last two rows are skipped (presumably the 合计 / 价税合计 summary
    # rows — confirm against real converter output), so at least one item
    # row must remain.
    if len(rows) < 3:
        raise ValueError("商品明细数据不完整")

    items = []
    for row in rows[:-2]:
        cells = [cell.strip() for cell in _TABLE_CELL_RE.findall(row)]
        name = cells[0] if cells else _NOT_FOUND
        model = cells[1] if len(cells) > 1 else _NOT_FOUND
        # Column layout is not fixed, so take the first purely numeric cell.
        quantity = next(
            (cell for cell in cells if _INTEGER_RE.fullmatch(cell)),
            _NOT_FOUND,
        )
        print("项目名称:", name)
        print("规格型号:", model)
        print("数量:", quantity)
        items.append({"name": name, "model": model, "quantity": quantity})
    return items


def extract_invoice_info(markdown_text):
    """Extract key invoice fields from converter output (markdown + HTML table).

    Args:
        markdown_text: Text produced by the PDF-to-markdown conversion.

    Returns:
        A dict with keys ``"invoice_number"``, ``"seller_name"``,
        ``"total_amount"`` (all ``str``) and ``"items"`` (a list of
        ``{"name", "model", "quantity"}`` dicts; a missing cell is stored as
        ``"无"``), or ``None`` if the input is not a string or any required
        section is missing. The failure reason is printed.
    """
    try:
        if not isinstance(markdown_text, str):
            raise TypeError(
                f"expected markdown text as str, got {type(markdown_text).__name__}"
            )
        invoice_number = _extract_invoice_number(markdown_text)
        seller_name = _extract_seller_name(markdown_text)
        amount = _extract_amount(markdown_text)
        print("发票号码:", invoice_number)
        print("销售方名称:", seller_name)
        print("金额:", amount)
        items = _extract_items(markdown_text)
    except Exception as e:
        # Deliberate best-effort boundary: callers treat None as "skip file".
        print(f"解析发票信息时出错: {str(e)}")
        return None

    return {
        "invoice_number": invoice_number,
        "seller_name": seller_name,
        "total_amount": amount,
        "items": items,
    }
|
|
|
|
def convert_pdf_to_markdown(
    file_paths: list[str],
    client,
    api_name: str = "/process_markdown_streaming",
):
    """Convert PDF/image files to markdown via a docext Gradio server.

    Args:
        file_paths: Paths of the files to convert, sent together in one
            request. A single path passed as a plain ``str`` (or path-like)
            is accepted and treated as a one-item list.
        client: A connected ``gradio_client.Client`` for the docext server;
            server URL and authentication are configured when it is created.
        api_name: Gradio endpoint to call. ``Client.predict`` blocks and
            returns the endpoint's final output, so the streaming endpoint
            still produces a single result here.

    Returns:
        Whatever ``client.predict`` returns for the endpoint; callers treat
        it as the markdown text. NOTE(review): confirm the endpoint returns a
        plain ``str`` rather than a tuple/list.
    """
    # Iterating a bare string would upload one "file" per character.
    if isinstance(file_paths, (str, os.PathLike)):
        file_paths = [file_paths]

    file_inputs = [{"image": handle_file(path)} for path in file_paths]
    return client.predict(images=file_inputs, api_name=api_name)
|
|
|
|
def get_pdf_files(directory):
    """Recursively collect ``.pdf`` files (case-insensitive) under *directory*.

    Sub-directories and files are visited in sorted order, so the returned
    list (and therefore the processing and log order) is reproducible;
    ``os.walk`` on its own yields entries in arbitrary file-system order.

    Returns:
        list[str]: Matching paths, each joined onto the directory it was
        found in. Empty if *directory* does not exist, because ``os.walk``
        silently yields nothing in that case.
    """
    pdf_files = []
    for root, dirs, files in os.walk(directory):
        # Sorting in place controls the order in which os.walk descends.
        dirs.sort()
        pdf_files.extend(
            os.path.join(root, name)
            for name in sorted(files)
            if name.lower().endswith('.pdf')
        )
    return pdf_files
|
|
|
|
if __name__ == "__main__":
    # Batch job: convert every PDF under the input directory to markdown with
    # the docext server, extract invoice fields, and write one JSON file per
    # successfully parsed invoice.
    #
    # Settings may be overridden via environment variables; the defaults are
    # the previously hard-coded values. The client URL may be a local host or
    # a public Gradio share URL like `https://6986bdd23daef6f7eb.gradio.live/`.
    # NOTE(review): avoid relying on the default admin/admin credentials.
    CLIENT_URL = os.environ.get("DOCEXT_URL", "http://172.29.57.6:9998/")
    username = os.environ.get("DOCEXT_USERNAME", "admin")
    password = os.environ.get("DOCEXT_PASSWORD", "admin")
    pdf_directory = os.environ.get("INVOICE_PDF_DIR", "pdfs")
    output_dir = os.environ.get("INVOICE_OUTPUT_DIR", "output")

    client = Client(CLIENT_URL, auth=(username, password))

    pdf_files = get_pdf_files(pdf_directory)
    print(pdf_files)

    failed = []
    for pdf_file in pdf_files:
        print(f"Found PDF file: {pdf_file}")
        try:
            markdown_content = convert_pdf_to_markdown([pdf_file], client)
        except Exception as e:
            # One bad file or a transient server error should not abort the
            # rest of the batch.
            print(f"Failed to convert {pdf_file}: {e}")
            failed.append(pdf_file)
            continue

        invoice_info = extract_invoice_info(markdown_content)
        print(f"Extracted invoice info: {invoice_info}")
        if not invoice_info:
            failed.append(pdf_file)
            continue

        # PDFs are discovered recursively, so mirror the input sub-directory
        # layout; using only the basename would let same-named PDFs in
        # different folders overwrite each other's JSON.
        relative_path = os.path.relpath(pdf_file, pdf_directory)
        base_name = os.path.splitext(relative_path)[0]
        print(f"Base name: {base_name}")
        json_file = os.path.join(output_dir, f"{base_name}.json")
        print(f"JSON file path: {json_file}")

        save_to_json(invoice_info, json_file)
        print(f"发票信息已保存到: {json_file}")

    if failed:
        print(f"{len(failed)} file(s) could not be processed: {failed}")
|