Skip to content

Commit

Permalink
Merge pull request #613 from samiamjidkhan/dmg-backend
Browse files Browse the repository at this point in the history
image and text mode fix
  • Loading branch information
AlexCheema authored Jan 21, 2025
2 parents 819ec76 + 5c4ce53 commit 410d901
Showing 1 changed file with 49 additions and 20 deletions.
69 changes: 49 additions & 20 deletions exo/api/chatgpt_api.py
Original file line number Diff line number Diff line change
Expand Up @@ -139,11 +139,23 @@ def remap_messages(messages: List[Message]) -> List[Message]:
def build_prompt(tokenizer, _messages: List[Message], tools: Optional[List[Dict]] = None):
messages = remap_messages(_messages)
chat_template_args = {"conversation": [m.to_dict() for m in messages], "tokenize": False, "add_generation_prompt": True}
if tools: chat_template_args["tools"] = tools

prompt = tokenizer.apply_chat_template(**chat_template_args)
print(f"!!! Prompt: {prompt}")
return prompt
if tools:
chat_template_args["tools"] = tools

try:
prompt = tokenizer.apply_chat_template(**chat_template_args)
if DEBUG >= 3: print(f"!!! Prompt: {prompt}")
return prompt
except UnicodeEncodeError:
# Handle Unicode encoding by ensuring everything is UTF-8
chat_template_args["conversation"] = [
{k: v.encode('utf-8').decode('utf-8') if isinstance(v, str) else v
for k, v in m.to_dict().items()}
for m in messages
]
prompt = tokenizer.apply_chat_template(**chat_template_args)
if DEBUG >= 3: print(f"!!! Prompt (UTF-8 encoded): {prompt}")
return prompt


def parse_message(data: dict):
Expand Down Expand Up @@ -213,11 +225,16 @@ def __init__(
cors.add(self.app.router.add_post("/download", self.handle_post_download), {"*": cors_options})
cors.add(self.app.router.add_get("/topology", self.handle_get_topology), {"*": cors_options})

# Add static routes
if "__compiled__" not in globals():
self.static_dir = Path(__file__).parent.parent/"tinychat"
self.app.router.add_get("/", self.handle_root)
self.app.router.add_static("/", self.static_dir, name="static")
self.app.router.add_static('/images/', get_exo_images_dir(), name='static_images')

# Always add images route, regardless of compilation status
self.images_dir = get_exo_images_dir()
self.images_dir.mkdir(parents=True, exist_ok=True)
self.app.router.add_static('/images/', self.images_dir, name='static_images')

self.app.middlewares.append(self.timeout_middleware)
self.app.middlewares.append(self.log_request)
Expand Down Expand Up @@ -509,20 +526,32 @@ async def stream_image(_request_id: str, result, is_finished: bool):
await response.write(json.dumps({'progress': get_progress_bar((result[0]), (result[1]))}).encode('utf-8') + b'\n')

elif isinstance(result, np.ndarray):
im = Image.fromarray(np.array(result))
images_folder = get_exo_images_dir()
# Save the image to a file
image_filename = f"{_request_id}.png"
image_path = images_folder/image_filename
im.save(image_path)
image_url = request.app.router['static_images'].url_for(filename=image_filename)
base_url = f"{request.scheme}://{request.host}"
# Construct the full URL correctly
full_image_url = base_url + str(image_url)

await response.write(json.dumps({'images': [{'url': str(full_image_url), 'content_type': 'image/png'}]}).encode('utf-8') + b'\n')
if is_finished:
await response.write_eof()
try:
im = Image.fromarray(np.array(result))
# Save the image to a file
image_filename = f"{_request_id}.png"
image_path = self.images_dir/image_filename
im.save(image_path)

# Get URL for the saved image
try:
image_url = request.app.router['static_images'].url_for(filename=image_filename)
base_url = f"{request.scheme}://{request.host}"
full_image_url = base_url + str(image_url)

await response.write(json.dumps({'images': [{'url': str(full_image_url), 'content_type': 'image/png'}]}).encode('utf-8') + b'\n')
except KeyError as e:
if DEBUG >= 2: print(f"Error getting image URL: {e}")
# Fallback to direct file path if URL generation fails
await response.write(json.dumps({'images': [{'url': str(image_path), 'content_type': 'image/png'}]}).encode('utf-8') + b'\n')

if is_finished:
await response.write_eof()

except Exception as e:
if DEBUG >= 2: print(f"Error processing image: {e}")
if DEBUG >= 2: traceback.print_exc()
await response.write(json.dumps({'error': str(e)}).encode('utf-8') + b'\n')

stream_task = None

Expand Down

0 comments on commit 410d901

Please sign in to comment.